Skip to content

Commit 53c0834

Browse files
authored
Merge pull request #152 from utn-mi/juelg/fix-collision-guard
fix(collision guard): hard sim update and identity action
2 parents d453279 + ebe0d0b commit 53c0834

5 files changed

Lines changed: 34 additions & 12 deletions

File tree

python/rcsss/_core/sim.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class FR3(rcsss._core.common.Robot):
8989
def __init__(self, sim: Sim, id: str, ik: rcsss._core.common.IK) -> None: ...
9090
def get_parameters(self) -> FR3Config: ...
9191
def get_state(self) -> FR3State: ...
92+
def set_joints_hard(self, q: numpy.ndarray[typing.Literal[7], numpy.dtype[numpy.float64]]) -> None: ...
9293
def set_parameters(self, cfg: FR3Config) -> bool: ...
9394

9495
class FR3Config(rcsss._core.common.RConfig):

python/rcsss/envs/sim.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
check_home_collision: bool = True,
5050
to_joint_control: bool = False,
5151
sim_gui: bool = True,
52+
truncate_on_collision: bool = True,
5253
):
5354
super().__init__(env)
5455
self.unwrapped: FR3Env
@@ -58,6 +59,7 @@ def __init__(
5859
self._logger = logging.getLogger(__name__)
5960
self.check_home_collision = check_home_collision
6061
self.to_joint_control = to_joint_control
62+
self.truncate_on_collision = truncate_on_collision
6163
if to_joint_control:
6264
assert (
6365
self.unwrapped.get_unwrapped_control_mode(-2) == ControlMode.JOINTS
@@ -68,23 +70,23 @@ def __init__(
6870
self.sim.open_gui()
6971

7072
def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], SupportsFloat, bool, bool, dict[str, Any]]:
71-
# TODO: we should set the state of the sim to the state of the real robot
73+
74+
self.collision_env.get_wrapper_attr("robot").set_joints_hard(self.unwrapped.robot.get_joint_position())
7275
_, _, _, _, info = self.collision_env.step(action)
76+
7377
if self.to_joint_control:
7478
fr3_env = self.collision_env.unwrapped
7579
assert isinstance(fr3_env, FR3Env), "Collision env must be an FR3Env instance."
7680
action[self.unwrapped.joints_key] = fr3_env.robot.get_joint_position()
7781

78-
# modify action to be joint angles down stream
79-
if info["collision"] or not info["ik_success"] or not info["is_sim_converged"]:
80-
# return old obs, with truncated and print warning
81-
self._logger.warning("Collision detected! Truncating episode: %s", info)
82-
if self.last_obs is None:
83-
msg = "Collisions detected and no old observation."
84-
raise RuntimeError(msg)
85-
old_obs, old_info = self.last_obs
86-
old_info.update(info)
87-
return old_obs, 0, False, True, old_info
82+
if info["collision"]:
83+
self._logger.warning("Collision detected! %s", info)
84+
action[self.unwrapped.joints_key] = self.unwrapped.robot.get_joint_position()
85+
if self.truncate_on_collision:
86+
if self.last_obs is None:
87+
msg = "Collision detected in the first step!"
88+
raise RuntimeError(msg)
89+
return self.last_obs[0], 0, True, True, info
8890

8991
obs, reward, done, truncated, info = super().step(action)
9092
self.last_obs = obs, info
@@ -119,6 +121,7 @@ def env_from_xml_paths(
119121
tcp_offset: rcsss.common.Pose | None = None,
120122
control_mode: ControlMode | None = None,
121123
sim_gui: bool = True,
124+
truncate_on_collision: bool = True,
122125
) -> "CollisionGuard":
123126
assert isinstance(env.unwrapped, FR3Env)
124127
simulation = sim.Sim(mjmld)
@@ -145,4 +148,12 @@ def env_from_xml_paths(
145148
gripper_cfg = sim.FHConfig()
146149
fh = sim.FrankaHand(simulation, id, gripper_cfg)
147150
c_env = GripperWrapper(c_env, fh)
148-
return cls(env, simulation, c_env, check_home_collision, to_joint_control, sim_gui)
151+
return cls(
152+
env=env,
153+
simulation=simulation,
154+
collision_env=c_env,
155+
check_home_collision=check_home_collision,
156+
to_joint_control=to_joint_control,
157+
sim_gui=sim_gui,
158+
truncate_on_collision=truncate_on_collision,
159+
)

src/pybind/rcsss.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ PYBIND11_MODULE(_core, m) {
470470
py::arg("sim"), py::arg("id"), py::arg("ik"))
471471
.def("get_parameters", &rcs::sim::FR3::get_parameters)
472472
.def("set_parameters", &rcs::sim::FR3::set_parameters, py::arg("cfg"))
473+
.def("set_joints_hard", &rcs::sim::FR3::set_joints_hard, py::arg("q"))
473474
.def("get_state", &rcs::sim::FR3::get_state);
474475
py::enum_<rcs::sim::CameraType>(sim, "CameraType")
475476
.value("free", rcs::sim::CameraType::free)

src/sim/FR3.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,14 @@ void FR3::m_reset() {
208208
}
209209
}
210210

211+
void FR3::set_joints_hard(const common::Vector7d& q) {
212+
for (size_t i = 0; i < std::size(this->ids.joints); ++i) {
213+
size_t jnt_id = this->ids.joints[i];
214+
size_t jnt_qposadr = this->sim->m->jnt_qposadr[jnt_id];
215+
this->sim->d->qpos[jnt_qposadr] = q[i];
216+
}
217+
}
218+
211219
common::Pose FR3::get_base_pose_in_world_coordinates() {
212220
auto id = mj_name2id(this->sim->m, mjOBJ_BODY,
213221
(std::string("base_") + this->id).c_str());

src/sim/FR3.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class FR3 : public common::Robot {
5050
common::Pose get_base_pose_in_world_coordinates() override;
5151
std::optional<std::shared_ptr<common::IK>> get_ik() override;
5252
void reset() override;
53+
void set_joints_hard(const common::Vector7d &q);
5354

5455
private:
5556
FR3Config cfg;

0 commit comments

Comments
 (0)