Skip to content

Commit 34bf5db

Browse files
authored
Merge branch 'google-deepmind:main' into musculoskeletal_dog_creation
2 parents ffa2358 + 241851f commit 34bf5db

16 files changed

Lines changed: 443 additions & 70 deletions

File tree

dm_control/_render/pyopengl/egl_renderer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,13 @@ def _platform_init(self, unused_max_width, unused_max_height):
103103
"""Initializes this EGL context."""
104104
num_configs = ctypes.c_long(0)
105105
config_size = 1
106-
config = EGL.EGLConfig()
106+
# ctypes syntax for making an array of length config_size.
107+
configs = (EGL.EGLConfig * config_size)()
107108
EGL.eglReleaseThread()
108109
EGL.eglChooseConfig(
109110
EGL_DISPLAY,
110111
EGL_ATTRIBUTES,
111-
ctypes.byref(config),
112+
configs,
112113
config_size,
113114
num_configs)
114115
if num_configs.value < 1:
@@ -117,7 +118,7 @@ def _platform_init(self, unused_max_width, unused_max_height):
117118
'desired attributes: {}'.format(EGL_ATTRIBUTES))
118119
EGL.eglBindAPI(EGL.EGL_OPENGL_API)
119120
self._context = EGL.eglCreateContext(
120-
EGL_DISPLAY, config, EGL.EGL_NO_CONTEXT, None)
121+
EGL_DISPLAY, configs[0], EGL.EGL_NO_CONTEXT, None)
121122
if not self._context:
122123
raise RuntimeError('Cannot create an EGL context.')
123124

dm_control/composer/environment.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -292,19 +292,26 @@ def control_timestep(self):
292292
class Environment(_CommonEnvironment, dm_env.Environment):
293293
"""Reinforcement learning environment for Composer tasks."""
294294

295-
def __init__(self, task, time_limit=float('inf'), random_state=None,
296-
n_sub_steps=None,
297-
raise_exception_on_physics_error=True,
298-
strip_singleton_obs_buffer_dim=False,
299-
max_reset_attempts=1,
300-
delayed_observation_padding=ObservationPadding.ZERO,
301-
legacy_step: bool = True):
295+
def __init__(
296+
self,
297+
task,
298+
time_limit=float('inf'),
299+
random_state=None,
300+
n_sub_steps=None,
301+
raise_exception_on_physics_error=True,
302+
strip_singleton_obs_buffer_dim=False,
303+
max_reset_attempts=1,
304+
recompile_mjcf_every_episode=True,
305+
fixed_initial_state=False,
306+
delayed_observation_padding=ObservationPadding.ZERO,
307+
legacy_step: bool = True,
308+
):
302309
"""Initializes an instance of `Environment`.
303310
304311
Args:
305312
task: Instance of `composer.base.Task`.
306-
time_limit: (optional) A float, the time limit in seconds beyond which
307-
an episode is forced to terminate.
313+
time_limit: (optional) A float, the time limit in seconds beyond which an
314+
episode is forced to terminate.
308315
random_state: (optional) an int seed or `np.random.RandomState` instance.
309316
n_sub_steps: (DEPRECATED) An integer, number of physics steps to take per
310317
agent control step. New code should instead override the
@@ -313,15 +320,22 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
313320
`PhysicsError` should be raised as an exception. If `False`, physics
314321
errors will result in the current episode being terminated with a
315322
warning logged, and a new episode started.
316-
strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`,
317-
the array shape of observations with `buffer_size == 1` will not have a
318-
leading buffer dimension.
323+
strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`, the array
324+
shape of observations with `buffer_size == 1` will not have a leading
325+
buffer dimension.
319326
max_reset_attempts: (optional) Maximum number of times to try resetting
320-
the environment. If an `EpisodeInitializationError` is raised
321-
during this process, an environment reset is reattempted up to this
322-
number of times. If this count is exceeded then the most recent
323-
exception will be allowed to propagate. Defaults to 1, i.e. no failure
324-
is allowed.
327+
the environment. If an `EpisodeInitializationError` is raised during
328+
this process, an environment reset is reattempted up to this number of
329+
times. If this count is exceeded then the most recent exception will be
330+
allowed to propagate. Defaults to 1, i.e. no failure is allowed.
331+
recompile_mjcf_every_episode: If True, will recompile the mjcf model
332+
between episodes. If False, the `initialize_episode_mjcf` and
333+
`after_compile` steps are skipped after the first reset, which allows a
334+
speedup if no changes are made to the model.
335+
fixed_initial_state: If True the starting state of every single episode
336+
will be the same, meaning an identical sequence of actions will lead to
337+
an identical final state. If False, will randomize the starting state at
338+
every episode.
325339
delayed_observation_padding: (optional) An `ObservationPadding` enum value
326340
specifying the padding behavior of the initial buffers for delayed
327341
observables. If `ZERO` then the buffer is initially filled with zeroes.
@@ -340,6 +354,10 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
340354
delayed_observation_padding=delayed_observation_padding,
341355
legacy_step=legacy_step)
342356
self._max_reset_attempts = max_reset_attempts
357+
self._recompile_mjcf_every_episode = recompile_mjcf_every_episode
358+
self._mjcf_never_compiled = True
359+
self._fixed_initial_state = fixed_initial_state
360+
self._fixed_random_state = self._random_state.get_state()
343361
self._reset_next_step = True
344362

345363
def reset(self):
@@ -355,8 +373,15 @@ def reset(self):
355373
raise
356374

357375
def _reset_attempt(self):
358-
self._hooks.initialize_episode_mjcf(self._random_state)
359-
self._recompile_physics_and_update_observables()
376+
if self._recompile_mjcf_every_episode or self._mjcf_never_compiled:
377+
if self._fixed_initial_state:
378+
self._random_state.set_state(self._fixed_random_state)
379+
self._hooks.initialize_episode_mjcf(self._random_state)
380+
self._recompile_physics_and_update_observables()
381+
self._mjcf_never_compiled = False
382+
383+
if self._fixed_initial_state:
384+
self._random_state.set_state(self._fixed_random_state)
360385
with self._physics.reset_context():
361386
self._hooks.initialize_episode(self._physics_proxy, self._random_state)
362387
self._observation_updater.reset(self._physics_proxy, self._random_state)

dm_control/composer/environment_test.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ def initialize_episode(self, physics, random_state):
5353
raise composer.EpisodeInitializationError()
5454

5555

56+
class DummyTaskWithRandomObservation(composer.NullTask):
57+
58+
def __init__(self):
59+
null_entity = composer.ModelWrapperEntity(mjcf.RootElement())
60+
super().__init__(null_entity)
61+
62+
self._observation = [0.0] * 1000
63+
64+
def initialize_episode(self, physics, random_state):
65+
del physics
66+
self._observation = random_state.randint(1000, size=1000)
67+
68+
@property
69+
def task_observables(self):
70+
random_int = observable.Generic(lambda physics: self._observation)
71+
random_int.enabled = True
72+
return {'random_int': random_int}
73+
74+
5675
class EnvironmentTest(parameterized.TestCase):
5776

5877
def test_failed_resets(self):
@@ -96,5 +115,48 @@ def test_can_provide_observation(self):
96115
self.assertLen(obs, 1)
97116
np.testing.assert_array_equal(obs['time'], env.physics.time())
98117

118+
def test_dont_compile_mjcf_between_episodes(self):
119+
class AfterCompileHook(object):
120+
121+
def __init__(self):
122+
self.after_compile_call_count = 0
123+
124+
def __call__(self, physics, random_state):
125+
del physics, random_state
126+
self.after_compile_call_count += 1
127+
128+
after_compile_hook = AfterCompileHook()
129+
task = DummyTask()
130+
env = composer.Environment(task, recompile_mjcf_every_episode=False)
131+
env.add_extra_hook('after_compile', after_compile_hook)
132+
env.reset()
133+
self.assertEqual(after_compile_hook.after_compile_call_count, 1)
134+
for _ in range(4):
135+
env.reset()
136+
env.step([])
137+
138+
# Check the hook is not called.
139+
self.assertEqual(after_compile_hook.after_compile_call_count, 1)
140+
141+
def test_fixed_initial_state(self):
142+
task = DummyTaskWithRandomObservation()
143+
fixed_env = composer.Environment(task, fixed_initial_state=True)
144+
non_fixed_env = composer.Environment(task, fixed_initial_state=False)
145+
fixed_obs = fixed_env.reset().observation['random_int']
146+
non_fixed_obs = non_fixed_env.reset().observation['random_int']
147+
for _ in range(3):
148+
np.testing.assert_array_equal(
149+
fixed_env.reset().observation['random_int'], fixed_obs
150+
)
151+
self.assertTrue(
152+
np.any(
153+
np.not_equal(
154+
np.asarray(non_fixed_obs),
155+
np.asarray(non_fixed_env.reset().observation['random_int']),
156+
)
157+
)
158+
)
159+
160+
99161
if __name__ == '__main__':
100162
absltest.main()

dm_control/locomotion/examples/explore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
from absl import app
1919

20-
from dm_control import viewer
2120
from dm_control.locomotion.examples import basic_cmu_2019
21+
from dm_control import viewer
2222

2323

2424
def main(unused_argv):

dm_control/locomotion/mocap/loader.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@
2222
from dm_control.locomotion.mocap import mocap_pb2
2323
from dm_control.locomotion.mocap import trajectory
2424
from dm_control.utils import transformations as tr
25-
import numpy as np
26-
2725
from google.protobuf import descriptor
26+
import numpy as np
2827

2928

3029
class TrajectoryLoader(metaclass=abc.ABCMeta):

dm_control/locomotion/mocap/loader_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
from dm_control.locomotion.mocap import loader
2121
from dm_control.locomotion.mocap import mocap_pb2
2222
from dm_control.locomotion.mocap import trajectory
23-
2423
from google.protobuf import descriptor
2524
from google.protobuf import text_format
25+
2626
from dm_control.utils import io as resources
2727

2828
TEXTPROTOS = [

dm_control/locomotion/soccer/explore.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# limitations under the License.
1414
# ============================================================================
1515

16-
"""Interactive viewer for MuJoCo soccer enviornmnet."""
16+
"""Interactive viewer for MuJoCo soccer environment."""
1717

1818
import functools
1919
from absl import app
2020
from absl import flags
21-
from dm_control import viewer
2221
from dm_control.locomotion import soccer
22+
from dm_control import viewer
2323

2424
FLAGS = flags.FLAGS
2525

dm_control/mjcf/physics_test.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def test_error_when_pickling_synchronizing_array_wrapper(self):
579579
mjcf_physics._PICKLING_NOT_SUPPORTED.format(type=type(xpos_view))):
580580
pickle.dumps(xpos_view)
581581

582-
def test_plugins(self):
582+
def test_plugins_elasticity(self):
583583
root = mjcf.RootElement()
584584
root.extension.add('plugin', plugin='mujoco.elasticity.cable')
585585

@@ -603,7 +603,36 @@ def test_plugins(self):
603603
composite.geom.rgba = [0.8, 0.2, 0.1, 1]
604604
composite.geom.condim = 1
605605

606-
mjcf.Physics.from_mjcf_model(root)
606+
physics = mjcf.Physics.from_mjcf_model(root)
607+
physics.step()
608+
609+
def test_plugins_sdf(self):
610+
root = mjcf.RootElement()
611+
root.option.sdf_iterations = 10
612+
root.option.sdf_initpoints = 40
613+
614+
extension = root.extension.add('plugin', plugin='mujoco.sdf.torus')
615+
instance = extension.add('instance', name='torus')
616+
instance.add('config', key='radius1', value='0.35')
617+
instance.add('config', key='radius2', value='0.15')
618+
619+
# Replicate example in mujoco/model/plugin/sdf/torus.xml
620+
mesh = root.asset.add('mesh', name='torus')
621+
mesh.add('plugin', instance='torus')
622+
623+
# Test we can add SDF geom to the worldbody.
624+
worldbody_geom = root.worldbody.add(
625+
'geom', type='sdf', mesh='torus', rgba=[.2, .2, .8, 1])
626+
worldbody_geom.add('plugin', instance='torus')
627+
628+
# Test we can add SDF geom to a body.
629+
body = root.worldbody.add('body', pos=[-1, 0, 3.8])
630+
body.add('freejoint')
631+
body_geom = body.add('geom', type='sdf', mesh='torus', rgba=[.2, .2, .8, 1])
632+
body_geom.add('plugin', instance='torus')
633+
634+
physics = mjcf.Physics.from_mjcf_model(root)
635+
physics.step()
607636

608637
if __name__ == '__main__':
609638
absltest.main()

0 commit comments

Comments
 (0)