@@ -244,8 +244,7 @@ def step_sim(self):
244244 else :
245245 self .sim .step_until_convergence ()
246246
247- def reset_sim (self ):
248- self .sim .reset ()
247+ def apply_sim_state (self ):
249248 self .sim .step (1 )
250249
251250
@@ -255,7 +254,20 @@ def reset(
255254 if self .main_greenlet is not None :
256255 self .main_greenlet .switch ()
257256 else :
258- self .reset_sim ()
257+ self .apply_sim_state ()
258+ return super ().reset (seed = seed , options = options )
259+
260+ class CoverWrapper (gym .Wrapper ):
261+ """The CoverWrapper must be the last wrapper on the stack
262+
263+ Only strictly necessary for simulator environments, but also works for hardware environments.
264+ It takes care of resetting the simulator before any other wrapper resets its state, already assuming
265+ a fresh simulator state.
266+ """
267+ def reset (self , * , seed : int | None = None , options : dict [str , Any ] | None = None ) -> tuple [dict [str , Any ], dict [str , Any ]]:
268+ if self .env .get_wrapper_attr ("PLATFORM" ) == RobotPlatform .SIMULATION :
269+ sim = cast (simulation .Sim , self .get_wrapper_attr ("sim" ))
270+ sim .reset ()
259271 return super ().reset (seed = seed , options = options )
260272
261273
@@ -333,25 +345,28 @@ def action(self, action: dict[str, Any]) -> dict[str, Any]:
333345 ):
334346 msg = "Given type is not matching control mode!"
335347 raise RuntimeError (msg )
348+ last_action = self .prev_action
336349 self .prev_action = copy .deepcopy (action )
337350
351+ # shallow copy
352+ action = dict (action )
338353 if self .get_base_control_mode () == ControlMode .JOINTS and (
339- self . prev_action is None
340- or not np .allclose (action [self .joints_key ], self . prev_action [self .joints_key ], atol = 1e-03 , rtol = 0 )
354+ last_action is None
355+ or not np .allclose (action [self .joints_key ], last_action [self .joints_key ], atol = 1e-03 , rtol = 0 )
341356 ):
342357 self .robot .set_joint_position (action [self .joints_key ])
343358 action .pop (self .joints_key )
344359 elif self .get_base_control_mode () == ControlMode .CARTESIAN_TRPY and (
345- self . prev_action is None
346- or not np .allclose (action [self .trpy_key ], self . prev_action [self .trpy_key ], atol = 1e-03 , rtol = 0 )
360+ last_action is None
361+ or not np .allclose (action [self .trpy_key ], last_action [self .trpy_key ], atol = 1e-03 , rtol = 0 )
347362 ):
348363 self .robot .set_cartesian_position (
349364 common .Pose (translation = action [self .trpy_key ][:3 ], rpy_vector = action [self .trpy_key ][3 :])
350365 )
351366 action .pop (self .trpy_key )
352367 elif self .get_base_control_mode () == ControlMode .CARTESIAN_TQuat and (
353- self . prev_action is None
354- or not np .allclose (action [self .tquat_key ], self . prev_action [self .tquat_key ], atol = 1e-03 , rtol = 0 )
368+ last_action is None
369+ or not np .allclose (action [self .tquat_key ], last_action [self .tquat_key ], atol = 1e-03 , rtol = 0 )
355370 ):
356371 self .robot .set_cartesian_position (
357372 common .Pose (translation = action [self .tquat_key ][:3 ], quaternion = action [self .tquat_key ][3 :])
@@ -361,18 +376,13 @@ def action(self, action: dict[str, Any]) -> dict[str, Any]:
361376
362377 def observation (self , observation : dict , info : dict [str , Any ]) -> tuple [dict [str , Any ], dict [str , Any ]]:
363378 observation .update (self .get_robot_obs ())
364- # if self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION:
365- # sim_robot = cast(SimRobot, self.robot)
366- # state = sim_robot.get_state()
367- # info["collision"] = state.collision
368- # info["ik_success"] = state.ik_success
369- # info["is_sim_converged"] = self.env.get_wrapper_attr("sim").is_converged()
370379 return observation , info
371380
372381
373382 def reset (
374383 self , * , seed : int | None = None , options : dict [str , Any ] | None = None
375384 ) -> tuple [dict [str , Any ], dict [str , Any ]]:
385+ self .prev_action = None
376386 self .robot .reset ()
377387 if self .home_on_reset :
378388 self .robot .move_home ()
@@ -405,6 +415,7 @@ def __init__(
405415 else :
406416 self .robot2world = robot2world
407417 self .lead_env : gym .Env | None = None
418+ self .sim : simulation .Sim | None = None
408419
409420 # make sure all envs are the same type (sim/real)
410421 for env in self .envs :
@@ -416,6 +427,9 @@ def __init__(
416427 self ._runs_in_sim = self .PLATFORM == RobotPlatform .SIMULATION
417428 if self ._runs_in_sim :
418429 self ._inject_main_greenlet ()
430+ assert isinstance (self .lead_env , SimEnv ), "something is wrong with the env, the base should be type SimEnv"
431+ self .sim = self .lead_env .get_wrapper_attr ("sim" )
432+
419433
420434 def _inject_main_greenlet (self ):
421435 main_gr = getcurrent ()
@@ -471,9 +485,7 @@ def make_step_gr(env_to_step):
471485 if self ._runs_in_sim :
472486 # SIM path: 3. UP: Gather observations
473487 # Resume robot greenlet. It returns the step results.
474- res = step_greenlets [key ].switch ()
475- ob , r , t , tr , i = res
476-
488+ ob , r , t , tr , info [key ] = step_greenlets [key ].switch ()
477489 else :
478490 # HARDWARE path
479491 act = self ._translate_pose (key , action [key ], to_world = False )
@@ -510,9 +522,9 @@ def make_reset_gr(env_to_reset, s, o):
510522 reset_greenlets [key ] = gr
511523 gr .switch ()
512524
513- # SIM path: 2. SIM: reset
525+ # SIM path: 2. SIM: apply state from rested wrappers
514526 assert isinstance (self .lead_env , SimEnv )
515- self .lead_env .reset_sim ()
527+ self .lead_env .apply_sim_state ()
516528
517529
518530 for key , env in self .envs .items ():
0 commit comments