Skip to content

Commit d5d054b

Browse files
DeepMindcopybara-github
authored andcommitted
Update state items to include plugin state in mujoco engine.
PiperOrigin-RevId: 500738879 Change-Id: I339e26c7cdc928890d348654109c39b8f76ebc2a
1 parent 321717b commit d5d054b

2 files changed

Lines changed: 59 additions & 2 deletions

File tree

dm_control/mujoco/engine.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,12 +538,21 @@ def _physics_state_items(self):
538538
"""Returns list of arrays making up internal physics simulation state.
539539
540540
The physics state consists of the state variables, their derivatives and
541-
actuation activations.
541+
actuation activations. If the model contains plugins, then the state will
542+
also contain any plugin state.
542543
543544
Returns:
544545
List of NumPy arrays containing full physics simulation state.
545546
"""
546-
return [self.data.qpos, self.data.qvel, self.data.act]
547+
if self.model.nplugin > 0:
548+
return [
549+
self.data.qpos,
550+
self.data.qvel,
551+
self.data.act,
552+
self.data.plugin_state,
553+
]
554+
else:
555+
return [self.data.qpos, self.data.qvel, self.data.act]
547556

548557
# Named views of simulation data.
549558

dm_control/mujoco/engine_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,10 @@ def testNamedViews(self):
370370

371371
def testSetGetPhysicsState(self):
372372
physics_state = self._physics.get_state()
373+
374+
# qpos, qvel, act
375+
self.assertLen(self._physics._physics_state_items(), 3)
376+
373377
self._physics.set_state(physics_state)
374378

375379
new_physics_state = np.random.random_sample(physics_state.shape)
@@ -378,6 +382,50 @@ def testSetGetPhysicsState(self):
378382
np.testing.assert_allclose(new_physics_state,
379383
self._physics.get_state())
380384

385+
def testSetGetPhysicsStateWithPlugin(self):
386+
# Model copied from mujoco/test/plugin/elasticity/elasticity_test.cc
387+
model_with_cable_plugin = """
388+
<mujoco>
389+
<option gravity="0 0 0"/>
390+
<extension>
391+
<required plugin="mujoco.elasticity.cable"/>
392+
</extension>
393+
<worldbody>
394+
<geom type="plane" size="0 0 1" quat="1 0 0 0"/>
395+
<site name="reference" pos="0 0 0"/>
396+
<composite type="cable" curve="s" count="41 1 1" size="1" offset="0 0 1" initial="none">
397+
<plugin plugin="mujoco.elasticity.cable">
398+
<config key="twist" value="1e6"/>
399+
<config key="bend" value="1e9"/>
400+
</plugin>
401+
<joint kind="main" damping="2"/>
402+
<geom type="capsule" size=".005" density="1"/>
403+
</composite>
404+
</worldbody>
405+
<contact>
406+
<exclude body1="B_first" body2="B_last"/>
407+
</contact>
408+
<sensor>
409+
<framepos objtype="site" objname="S_last"/>
410+
</sensor>
411+
<actuator>
412+
<motor site="S_last" gear="0 0 0 0 1 0" ctrllimited="true" ctrlrange="0 4"/>
413+
</actuator>
414+
</mujoco>
415+
"""
416+
physics = engine.Physics.from_xml_string(model_with_cable_plugin)
417+
physics_state = physics.get_state()
418+
419+
# qpos, qvel, act, plugin_state
420+
self.assertLen(physics._physics_state_items(), 4)
421+
422+
physics.set_state(physics_state)
423+
424+
new_physics_state = np.random.random_sample(physics_state.shape)
425+
physics.set_state(new_physics_state)
426+
427+
np.testing.assert_allclose(new_physics_state, physics.get_state())
428+
381429
def testSetInvalidPhysicsState(self):
382430
badly_shaped_state = np.repeat(self._physics.get_state(), repeats=2)
383431

0 commit comments

Comments
 (0)