@@ -49,6 +49,7 @@ def __init__(
4949 check_home_collision : bool = True ,
5050 to_joint_control : bool = False ,
5151 sim_gui : bool = True ,
52+ truncate_on_collision : bool = True ,
5253 ):
5354 super ().__init__ (env )
5455 self .unwrapped : FR3Env
@@ -58,6 +59,7 @@ def __init__(
5859 self ._logger = logging .getLogger (__name__ )
5960 self .check_home_collision = check_home_collision
6061 self .to_joint_control = to_joint_control
62+ self .truncate_on_collision = truncate_on_collision
6163 if to_joint_control :
6264 assert (
6365 self .unwrapped .get_unwrapped_control_mode (- 2 ) == ControlMode .JOINTS
@@ -80,6 +82,11 @@ def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], SupportsFloat, b
8082 if info ["collision" ]:
8183 self ._logger .warning ("Collision detected! %s" , info )
8284 action [self .unwrapped .joints_key ] = self .unwrapped .robot .get_joint_position ()
85+ if self .truncate_on_collision :
86+ if self .last_obs is None :
87+ msg = "Collision detected in the first step!"
88+ raise RuntimeError (msg )
89+ return self .last_obs [0 ], 0 , True , True , info
8390
8491 obs , reward , done , truncated , info = super ().step (action )
8592 self .last_obs = obs , info
@@ -114,6 +121,7 @@ def env_from_xml_paths(
114121 tcp_offset : rcsss .common .Pose | None = None ,
115122 control_mode : ControlMode | None = None ,
116123 sim_gui : bool = True ,
124+ truncate_on_collision : bool = True ,
117125 ) -> "CollisionGuard" :
118126 assert isinstance (env .unwrapped , FR3Env )
119127 simulation = sim .Sim (mjmld )
@@ -140,4 +148,12 @@ def env_from_xml_paths(
140148 gripper_cfg = sim .FHConfig ()
141149 fh = sim .FrankaHand (simulation , id , gripper_cfg )
142150 c_env = GripperWrapper (c_env , fh )
143- return cls (env , simulation , c_env , check_home_collision , to_joint_control , sim_gui )
151+ return cls (
152+ env = env ,
153+ simulation = simulation ,
154+ collision_env = c_env ,
155+ check_home_collision = check_home_collision ,
156+ to_joint_control = to_joint_control ,
157+ sim_gui = sim_gui ,
158+ truncate_on_collision = truncate_on_collision ,
159+ )
0 commit comments