Skip to content

Commit 690b2cf

Browse files
nimrod-gileadicopybara-github
authored andcommitted
Make sure fields in mjData are initialized in index_test.
The test was reading fields that aren't initialized until mj_forward is called. PiperOrigin-RevId: 518012755 Change-Id: I424b712bd1b443fc3639ea686c8ed4a473059c1b
1 parent 87c695c commit 690b2cf

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

dm_control/mujoco/index_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from dm_control.mujoco import wrapper
2424
from dm_control.mujoco.testing import assets
2525
from dm_control.mujoco.wrapper.mjbindings import sizes
26+
import mujoco
2627
import numpy as np
2728

2829
MODEL = assets.get_contents('cartpole.xml')
@@ -58,6 +59,7 @@ def setUp(self):
5859
super().setUp()
5960
self._model = wrapper.MjModel.from_xml_string(MODEL)
6061
self._data = wrapper.MjData(self._model)
62+
mujoco.mj_forward(self._model.ptr, self._data.ptr)
6163

6264
self._size_to_axis_indexer = index.make_axis_indexers(self._model)
6365

@@ -285,8 +287,8 @@ def testBuildIndexersForEdgeCases(self, xml_string):
285287
index.struct_indexer(data, 'mjdata', size_to_axis_indexer)
286288

287289
# pylint: disable=undefined-variable
288-
@parameterized.parameters([
289-
name for name in dir(np.ndarray)
290+
@parameterized.named_parameters([
291+
(name, name) for name in dir(np.ndarray)
290292
if not name.startswith('_') # Exclude 'private' attributes
291293
and name not in ('ctypes', 'flat') # Can't compare via identity/equality
292294
])
@@ -312,6 +314,7 @@ def testFieldIndexerDir(self):
312314

313315

314316
def _iter_indexers(model, data):
317+
mujoco.mj_forward(model.ptr, data.ptr)
315318
size_to_axis_indexer = index.make_axis_indexers(model)
316319
all_fields = collections.OrderedDict()
317320
for struct, struct_name in ((model, 'mjmodel'), (data, 'mjdata')):

0 commit comments

Comments
 (0)