Skip to content

Commit e3d4c0a

Browse files
saran-tcopybara-github
authored andcommitted
Stop requiring stateful actuators to come after stateless ones.
PiperOrigin-RevId: 487201455 Change-Id: Iec3076769486b763053f32b5941678b4d6583eba
1 parent 822cd01 commit e3d4c0a

4 files changed

Lines changed: 38 additions & 95 deletions

File tree

dm_control/mjcf/element.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,47 +1134,24 @@ def to_xml(self, prefix_root=None, debug_context=None,
11341134

11351135

11361136
class _ActuatorElement(_ElementImpl):
1137-
"""Specialized object representing an <actuator> element.
1137+
"""Specialized object representing an <actuator> element."""
11381138

1139-
This is necessary because MuJoCo requires that all 3rd-order actuators (i.e.
1140-
those with internal dynamics) come after all 2nd-order actuators in the
1141-
generated XML.
1142-
"""
11431139
__slots__ = ()
11441140

1145-
def _is_third_order_actuator(self, child):
1146-
if child.tag == 'general':
1147-
return child.dyntype and child.dyntype != 'none'
1148-
elif child.tag == 'cylinder':
1149-
return True # The `<cylinder>` shortcut has 'filter' dynamics.
1150-
else:
1151-
return False # No other actuator shortcuts have internal dynamics.
1152-
11531141
def _children_to_xml(self, xml_element, prefix_root, debug_context=None,
11541142
*,
11551143
precision=constants.XML_DEFAULT_PRECISION,
11561144
zero_threshold=0):
1157-
second_order = []
1158-
third_order = []
11591145
debug_comments = {}
11601146
for child in self.all_children():
11611147
child_xml = child.to_xml(prefix_root, debug_context,
11621148
precision=precision,
11631149
zero_threshold=zero_threshold)
1164-
if (child_xml.attrib or len(child_xml) # pylint: disable=g-explicit-length-test
1165-
or child.spec.repeated or child.spec.on_demand):
1166-
if self._is_third_order_actuator(child):
1167-
third_order.append(child_xml)
1168-
else:
1169-
second_order.append(child_xml)
1170-
if debugging.debug_mode() and debug_context:
1171-
debug_comment = debug_context.register_element_for_debugging(child)
1172-
debug_comments[child_xml] = debug_comment
1173-
if len(child_xml) > 0: # pylint: disable=g-explicit-length-test
1174-
child_xml.insert(0, copy.deepcopy(debug_comment))
1175-
# Ensure that all second-order actuators come before third-order actuators
1176-
# in the XML.
1177-
for child_xml in second_order + third_order:
1150+
if debugging.debug_mode() and debug_context:
1151+
debug_comment = debug_context.register_element_for_debugging(child)
1152+
debug_comments[child_xml] = debug_comment
1153+
if len(child_xml) > 0: # pylint: disable=g-explicit-length-test
1154+
child_xml.insert(0, copy.deepcopy(debug_comment))
11781155
xml_element.append(child_xml)
11791156
if debugging.debug_mode() and debug_context:
11801157
xml_element.append(debug_comments[child_xml])

dm_control/mjcf/physics.py

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -91,36 +91,6 @@ def _get_attributes(size_names, strip_prefixes):
9191
return out
9292

9393

94-
# Fields related to the internal states of actuators (i.e. with a leading
95-
# dimension of 'na') require special treatment.
96-
def _get_actuator_state_fields():
97-
actuator_state_fields = []
98-
for sizes_dict in sizes.array_sizes.values():
99-
for field_name, dimensions in sizes_dict.items():
100-
if dimensions[0] == 'na':
101-
actuator_state_fields.append(field_name)
102-
return frozenset(actuator_state_fields)
103-
104-
_ACTUATOR_STATE_FIELDS = _get_actuator_state_fields()
105-
106-
107-
def _filter_stateful_actuators(physics, actuator_names):
108-
"""Removes any stateless actuators from the list of actuator names."""
109-
if isinstance(actuator_names, str):
110-
actuator_names = [actuator_names]
111-
112-
if physics.model.na:
113-
# MuJoCo requires that stateful actuators always come after stateless
114-
# actuators in the model, so we keep actuator names only if their
115-
# corresponding IDs are >= to the total number of stateless actuators.
116-
num_stateless_actuators = physics.model.nu - physics.model.na
117-
return [
118-
name for name in actuator_names
119-
if physics.model.name2id(name, 'actuator') >= num_stateless_actuators]
120-
else:
121-
return []
122-
123-
12494
_ATTRIBUTES = {
12595
'actuator': _get_attributes(['na', 'nu'], strip_prefixes=['actuator']),
12696
'body': _get_attributes(['nbody'], strip_prefixes=['body']),
@@ -320,14 +290,7 @@ def _get_cached_array_and_index(self, name):
320290
try:
321291
index = self._array_index_cache[name]
322292
except KeyError:
323-
# If we are indexing into a field relating to actuator internal states
324-
# then we must first remove the names of any stateless actuators.
325-
if name in _ACTUATOR_STATE_FIELDS:
326-
named_index = _filter_stateful_actuators(
327-
self._physics, self._named_index)
328-
else:
329-
named_index = self._named_index
330-
index = named_indexer._convert_key(named_index) # pylint: disable=protected-access
293+
index = named_indexer._convert_key(self._named_index) # pylint: disable=protected-access
331294
self._array_index_cache[name] = index
332295
return array, index
333296

@@ -376,7 +339,7 @@ def __getattr__(self, name):
376339

377340
if self._physics.is_dirty and not triggers_dirty:
378341
self._physics.forward()
379-
if isinstance(index, int) and array.ndim == 1:
342+
if np.issubdtype(type(index), np.integer) and array.ndim == 1:
380343
# Case where indexing results in a scalar.
381344
out = array[index]
382345
else:

dm_control/mujoco/index.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
_RAGGED_ADDRS = {
9595
'nq': 'jnt_qposadr',
9696
'nv': 'jnt_dofadr',
97+
'na': 'actuator_actadr',
9798
'nsensordata': 'sensor_adr',
9899
'nnumericdata': 'numeric_adr',
99100
}
@@ -218,6 +219,8 @@ def _get_size_name_to_element_names(model):
218219
# For example, the element names for "nv" axis come from "njnt".
219220
for size_name, address_field_name in _RAGGED_ADDRS.items():
220221
donor = 'n' + address_field_name.split('_')[0]
222+
if donor == 'nactuator':
223+
donor = 'nu'
221224
if donor in size_name_to_element_names:
222225
size_name_to_element_names[size_name] = size_name_to_element_names[donor]
223226

@@ -230,13 +233,6 @@ def _get_size_name_to_element_names(model):
230233
assert None not in mocap_body_names
231234
size_name_to_element_names['nmocap'] = mocap_body_names
232235

233-
# Arrays with dimension `na` correspond to stateful actuators. MuJoCo's
234-
# compiler requires that these are always defined after stateless actuators,
235-
# so we only need the final `na` elements in the list of all actuator names.
236-
if model.na:
237-
all_actuator_names = size_name_to_element_names['nu']
238-
size_name_to_element_names['na'] = all_actuator_names[-model.na:]
239-
240236
return size_name_to_element_names
241237

242238

@@ -255,8 +251,11 @@ def _get_size_name_to_element_sizes(model):
255251

256252
for size_name, address_field_name in _RAGGED_ADDRS.items():
257253
addresses = getattr(model, address_field_name).ravel()
258-
total_length = getattr(model, size_name)
259-
element_sizes = np.diff(np.r_[addresses, total_length])
254+
if size_name == 'na':
255+
element_sizes = np.where(addresses == -1, 0, 1)
256+
else:
257+
total_length = getattr(model, size_name)
258+
element_sizes = np.diff(np.r_[addresses, total_length])
260259
size_name_to_element_sizes[size_name] = element_sizes
261260

262261
return size_name_to_element_sizes
@@ -282,7 +281,9 @@ def make_axis_indexers(model):
282281
element_names = size_name_to_element_names[size_name]
283282
if size_name in _RAGGED_ADDRS:
284283
element_sizes = size_name_to_element_sizes[size_name]
285-
indexer = RaggedNamedAxis(element_names, element_sizes)
284+
singleton = (size_name == 'na')
285+
indexer = RaggedNamedAxis(element_names, element_sizes,
286+
singleton=singleton)
286287
else:
287288
indexer = RegularNamedAxis(element_names)
288289
axis_indexers[size_name] = indexer
@@ -377,12 +378,13 @@ def names(self):
377378
class RaggedNamedAxis(Axis):
378379
"""Represents an axis where the named elements may vary in size."""
379380

380-
def __init__(self, element_names, element_sizes):
381+
def __init__(self, element_names, element_sizes, singleton=False):
381382
"""Initializes a new `RaggedNamedAxis` instance.
382383
383384
Args:
384385
element_names: A list or array containing the element names.
385386
element_sizes: A list or array containing the size of each element.
387+
singleton: Whether to reduce singleton slices to scalars.
386388
"""
387389
names_to_slices = {}
388390
names_to_indices = {}
@@ -391,7 +393,10 @@ def __init__(self, element_names, element_sizes):
391393
for name, size in zip(element_names, element_sizes):
392394
# Don't add unnamed elements to the dicts.
393395
if name:
394-
names_to_slices[name] = slice(offset, offset + size)
396+
if size == 1 and singleton:
397+
names_to_slices[name] = offset
398+
else:
399+
names_to_slices[name] = slice(offset, offset + size)
395400
names_to_indices[name] = range(offset, offset + size)
396401
offset += size
397402

@@ -400,29 +405,29 @@ def __init__(self, element_names, element_sizes):
400405
self._names_to_slices = names_to_slices
401406
self._names_to_indices = names_to_indices
402407

403-
def convert_key_item(self, key):
408+
def convert_key_item(self, key_item):
404409
"""Converts a named indexing expression to a numpy-friendly index."""
405410

406-
_validate_key_item(key)
411+
_validate_key_item(key_item)
407412

408-
if isinstance(key, str):
409-
key = self._names_to_slices[util.to_native_string(key)]
413+
if isinstance(key_item, str):
414+
key_item = self._names_to_slices[util.to_native_string(key_item)]
410415

411-
elif isinstance(key, (list, np.ndarray)):
416+
elif isinstance(key_item, (list, np.ndarray)):
412417
# We assume that either all or none of the items in the sequence are
413418
# strings representing names. If there is a mix, we will let NumPy throw
414419
# an error when trying to index with the returned key.
415-
if isinstance(key[0], str):
420+
if isinstance(key_item[0], str):
416421
new_key = []
417-
for k in key:
422+
for k in key_item:
418423
idx = self._names_to_indices[util.to_native_string(k)]
419424
if isinstance(idx, int):
420425
new_key.append(idx)
421426
else:
422427
new_key.extend(idx)
423-
key = new_key
428+
key_item = new_key
424429

425-
return key
430+
return key_item
426431

427432
@property
428433
def names(self):

dm_control/mujoco/testing/assets/model_with_third_order_actuators.xml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
</body>
77
</worldbody>
88
<actuator>
9-
<!-- Second-order actuators -->
10-
<motor name="motor" joint="slide_joint"/>
11-
<velocity name="velocity" joint="slide_joint"/>
12-
<!-- Third-order actuators -->
13-
<cylinder name="cylinder" joint="slide_joint"/>
14-
<general name="general" joint="slide_joint" dyntype="integrator" biastype="affine" dynprm="1 0 0"/>
9+
<motor name="motor" joint="slide_joint"/> <!-- Second-order -->
10+
<cylinder name="cylinder" joint="slide_joint"/> <!-- Third-order -->
11+
<velocity name="velocity" joint="slide_joint"/> <!-- Second-order -->
12+
<general name="general" joint="slide_joint" dyntype="integrator" biastype="affine" dynprm="1 0 0"/> <!-- Third-order -->
1513
</actuator>
1614
</mujoco>

0 commit comments

Comments
 (0)