@@ -49,6 +49,7 @@ def __init__(
4949 check_home_collision : bool = True ,
5050 to_joint_control : bool = False ,
5151 sim_gui : bool = True ,
52+ truncate_on_collision : bool = True ,
5253 ):
5354 super ().__init__ (env )
5455 self .unwrapped : FR3Env
@@ -58,6 +59,7 @@ def __init__(
5859 self ._logger = logging .getLogger (__name__ )
5960 self .check_home_collision = check_home_collision
6061 self .to_joint_control = to_joint_control
62+ self .truncate_on_collision = truncate_on_collision
6163 if to_joint_control :
6264 assert (
6365 self .unwrapped .get_unwrapped_control_mode (- 2 ) == ControlMode .JOINTS
@@ -68,23 +70,23 @@ def __init__(
6870 self .sim .open_gui ()
6971
7072 def step (self , action : dict [str , Any ]) -> tuple [dict [str , Any ], SupportsFloat , bool , bool , dict [str , Any ]]:
71- # TODO: we should set the state of the sim to the state of the real robot
73+
74+ self .collision_env .get_wrapper_attr ("robot" ).set_joints_hard (self .unwrapped .robot .get_joint_position ())
7275 _ , _ , _ , _ , info = self .collision_env .step (action )
76+
7377 if self .to_joint_control :
7478 fr3_env = self .collision_env .unwrapped
7579 assert isinstance (fr3_env , FR3Env ), "Collision env must be an FR3Env instance."
7680 action [self .unwrapped .joints_key ] = fr3_env .robot .get_joint_position ()
7781
78- # modify action to be joint angles down stream
79- if info ["collision" ] or not info ["ik_success" ] or not info ["is_sim_converged" ]:
80- # return old obs, with truncated and print warning
81- self ._logger .warning ("Collision detected! Truncating episode: %s" , info )
82- if self .last_obs is None :
83- msg = "Collisions detected and no old observation."
84- raise RuntimeError (msg )
85- old_obs , old_info = self .last_obs
86- old_info .update (info )
87- return old_obs , 0 , False , True , old_info
82+ if info ["collision" ]:
83+ self ._logger .warning ("Collision detected! %s" , info )
84+ action [self .unwrapped .joints_key ] = self .unwrapped .robot .get_joint_position ()
85+ if self .truncate_on_collision :
86+ if self .last_obs is None :
87+ msg = "Collision detected in the first step!"
88+ raise RuntimeError (msg )
89+ return self .last_obs [0 ], 0 , True , True , info
8890
8991 obs , reward , done , truncated , info = super ().step (action )
9092 self .last_obs = obs , info
@@ -119,6 +121,7 @@ def env_from_xml_paths(
119121 tcp_offset : rcsss .common .Pose | None = None ,
120122 control_mode : ControlMode | None = None ,
121123 sim_gui : bool = True ,
124+ truncate_on_collision : bool = True ,
122125 ) -> "CollisionGuard" :
123126 assert isinstance (env .unwrapped , FR3Env )
124127 simulation = sim .Sim (mjmld )
@@ -145,4 +148,12 @@ def env_from_xml_paths(
145148 gripper_cfg = sim .FHConfig ()
146149 fh = sim .FrankaHand (simulation , id , gripper_cfg )
147150 c_env = GripperWrapper (c_env , fh )
148- return cls (env , simulation , c_env , check_home_collision , to_joint_control , sim_gui )
151+ return cls (
152+ env = env ,
153+ simulation = simulation ,
154+ collision_env = c_env ,
155+ check_home_collision = check_home_collision ,
156+ to_joint_control = to_joint_control ,
157+ sim_gui = sim_gui ,
158+ truncate_on_collision = truncate_on_collision ,
159+ )
0 commit comments