Skip to content

Commit ebe0d0b

Browse files
committed
fix(collision guard): back compatibility truncate episode
For backward compatibility with tests there is a default true option which truncates the episode when a collision occurred
1 parent d5fe17a commit ebe0d0b

1 file changed

Lines changed: 17 additions & 1 deletion

File tree

python/rcsss/envs/sim.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
check_home_collision: bool = True,
5050
to_joint_control: bool = False,
5151
sim_gui: bool = True,
52+
truncate_on_collision: bool = True,
5253
):
5354
super().__init__(env)
5455
self.unwrapped: FR3Env
@@ -58,6 +59,7 @@ def __init__(
5859
self._logger = logging.getLogger(__name__)
5960
self.check_home_collision = check_home_collision
6061
self.to_joint_control = to_joint_control
62+
self.truncate_on_collision = truncate_on_collision
6163
if to_joint_control:
6264
assert (
6365
self.unwrapped.get_unwrapped_control_mode(-2) == ControlMode.JOINTS
@@ -80,6 +82,11 @@ def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], SupportsFloat, b
8082
if info["collision"]:
8183
self._logger.warning("Collision detected! %s", info)
8284
action[self.unwrapped.joints_key] = self.unwrapped.robot.get_joint_position()
85+
if self.truncate_on_collision:
86+
if self.last_obs is None:
87+
msg = "Collision detected in the first step!"
88+
raise RuntimeError(msg)
89+
return self.last_obs[0], 0, True, True, info
8390

8491
obs, reward, done, truncated, info = super().step(action)
8592
self.last_obs = obs, info
@@ -114,6 +121,7 @@ def env_from_xml_paths(
114121
tcp_offset: rcsss.common.Pose | None = None,
115122
control_mode: ControlMode | None = None,
116123
sim_gui: bool = True,
124+
truncate_on_collision: bool = True,
117125
) -> "CollisionGuard":
118126
assert isinstance(env.unwrapped, FR3Env)
119127
simulation = sim.Sim(mjmld)
@@ -140,4 +148,12 @@ def env_from_xml_paths(
140148
gripper_cfg = sim.FHConfig()
141149
fh = sim.FrankaHand(simulation, id, gripper_cfg)
142150
c_env = GripperWrapper(c_env, fh)
143-
return cls(env, simulation, c_env, check_home_collision, to_joint_control, sim_gui)
151+
return cls(
152+
env=env,
153+
simulation=simulation,
154+
collision_env=c_env,
155+
check_home_collision=check_home_collision,
156+
to_joint_control=to_joint_control,
157+
sim_gui=sim_gui,
158+
truncate_on_collision=truncate_on_collision,
159+
)

0 commit comments

Comments
 (0)