Skip to content

Commit 8278a3e

Browse files
kevinzakkacopybara-github
authored andcommitted
Allow custom composer Arena XMLs.
PiperOrigin-RevId: 515346126 Change-Id: I4e9947bf4549d3202fa460d77273709aa43dac41
1 parent de21947 commit 8278a3e

1 file changed

Lines changed: 26 additions & 4 deletions

File tree

dm_control/composer/arena.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,36 @@
2626
class Arena(entity_module.Entity):
2727
"""The base empty arena that defines global settings for Composer."""
2828

29-
def _build(self, name=None):
29+
def __init__(self, *args, **kwargs):
30+
self._mjcf_root = None # Declare that _mjcf_root exists to allay pytype.
31+
super().__init__(*args, **kwargs)
32+
33+
# _build uses *args and **kwargs rather than named arguments, to get
34+
# around a signature-mismatch error from pytype in derived classes.
35+
36+
def _build(self, *args, **kwargs) -> None:
3037
"""Initializes this arena.
3138
39+
The function takes two arguments through args, kwargs:
40+
name: A string, the name of this arena. If `None`, use the model name
41+
defined in the MJCF file.
42+
xml_path: An optional path to an XML file that will override the default
43+
composer arena MJCF.
44+
3245
Args:
33-
name: (optional) A string, the name of this arena. If `None`, use the
34-
model name defined in the MJCF file.
46+
*args: See above.
47+
**kwargs: See above.
3548
"""
36-
self._mjcf_root = mjcf.from_path(_ARENA_XML_PATH)
49+
if args:
50+
name = args[0]
51+
else:
52+
name = kwargs.get('name', None)
53+
if len(args) > 1:
54+
xml_path = args[1]
55+
else:
56+
xml_path = kwargs.get('xml_path', None)
57+
58+
self._mjcf_root = mjcf.from_path(xml_path or _ARENA_XML_PATH)
3759
if name:
3860
self._mjcf_root.model = name
3961

0 commit comments

Comments
 (0)