Skip to content

Commit 931ad67

Browse files
nimrod-gileadicopybara-github
authored andcommitted
Fix binding to a single stateful actuator.
When binding single elements, the named_index argument to Binding is a string, rather than a list of strings. Most code in the Binding class handles this, but for stateful actuators, _filter_stateful_actuators didn't handle the case of the single actuator name, which caused a crash. PiperOrigin-RevId: 450437063 Change-Id: I7e47287d162fb2c765a013d86fb104f09caebd64
1 parent 41d0c73 commit 931ad67

2 files changed

Lines changed: 22 additions & 1 deletion

File tree

dm_control/mjcf/physics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def _get_actuator_state_fields():
106106

107107
def _filter_stateful_actuators(physics, actuator_names):
108108
"""Removes any stateless actuators from the list of actuator names."""
109+
if isinstance(actuator_names, str):
110+
actuator_names = [actuator_names]
111+
109112
if physics.model.na:
110113
# MuJoCo requires that stateful actuators always come after stateless
111114
# actuators in the model, so we keep actuator names only if their

dm_control/mjcf/physics_test.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def test_construct_and_reload_from_mjcf_model(self):
6666
('geom', True),
6767
('geom', False),
6868
('joint', True),
69-
('joint', False))
69+
('joint', False),
70+
('actuator', True),
71+
('actuator', False))
7072
def test_id(self, namespace, single_element):
7173
elements, full_identifiers = self.sample_elements(namespace, single_element)
7274
actual = self.physics.bind(elements).element_id
@@ -204,6 +206,22 @@ def test_bind_worldbody(self):
204206
mass = physics.bind(model.worldbody).subtreemass
205207
self.assertEqual(mass, expected_mass)
206208

209+
def test_bind_stateful_actuator(self):
210+
model = mjcf.RootElement()
211+
body = model.worldbody.add('body')
212+
body.add('joint', name='joint')
213+
body.add('geom', type='sphere', size=[1])
214+
215+
model.actuator.add(
216+
'general', name='act1', joint='joint', dyntype='integrator')
217+
218+
physics = mjcf.Physics.from_mjcf_model(model)
219+
actuator = model.find('actuator', 'act1')
220+
binding = physics.bind(actuator)
221+
222+
# This used to crash
223+
self.assertEqual(0, binding.act)
224+
207225
def test_caching(self):
208226
all_joints = self.model.find_all('joint')
209227

0 commit comments

Comments
 (0)