@@ -39,9 +39,9 @@ def __init__(self, env, simulation: sim.Sim, sim_wrapper: Type[SimWrapper] | Non
3939 if sim_wrapper is not None :
4040 env = sim_wrapper (env , simulation )
4141 super ().__init__ (env )
42- # self.unwrapped: RobotEnv
43- # assert isinstance(self.unwrapped.robot, sim.SimRobot), "Robot must be a sim.SimRobot instance."
44- # self.sim_robot = cast(sim.SimRobot, self.unwrapped.robot)
42+ self .unwrapped : RobotEnv
43+ assert isinstance (self .unwrapped .robot , sim .SimRobot ), "Robot must be a sim.SimRobot instance."
44+ self .sim_robot = cast (sim .SimRobot , self .unwrapped .robot )
4545 self .sim = simulation
4646 cfg = self .sim .get_config ()
4747 self .frame_rate = SimpleFrameRate (1 / cfg .frequency , "RobotSimWrapper" )
@@ -56,19 +56,19 @@ def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, boo
5656 self .frame_rate ()
5757
5858 else :
59- # self.sim_robot.clear_collision_flag()
59+ self .sim_robot .clear_collision_flag ()
6060 self .sim .step_until_convergence ()
61- # state = self.sim_robot.get_state()
62- # if "collision" not in info:
63- # info["collision"] = state.collision
64- # else:
65- # info["collision"] = info["collision"] or state.collision
66- # info["ik_success"] = state.ik_success
61+ state = self .sim_robot .get_state ()
62+ if "collision" not in info :
63+ info ["collision" ] = state .collision
64+ else :
65+ info ["collision" ] = info ["collision" ] or state .collision
66+ info ["ik_success" ] = state .ik_success
6767 info ["is_sim_converged" ] = self .sim .is_converged ()
6868 # truncate episode if collision
69- # obs.update(self.unwrapped.get_obs())
70- # return obs, 0, False, info["collision"] or not state.ik_success, info
71- return obs , 0 , False , False , info
69+ obs .update (self .unwrapped .get_obs ())
70+ return obs , 0 , False , info ["collision" ] or not state .ik_success , info
71+ # return obs, 0, False, False, info
7272
7373 def reset (
7474 self , * , seed : int | None = None , options : dict [str , Any ] | None = None
@@ -77,7 +77,7 @@ def reset(
7777 obs , info = super ().reset (seed = seed , options = options )
7878 self .sim .step (1 )
7979 # todo: an obs method that is recursive over wrappers would be needed
80- # obs.update(self.unwrapped.get_obs())
80+ obs .update (self .unwrapped .get_obs ())
8181 return obs , info
8282
8383
0 commit comments