Skip to content

Commit 229a775

Browse files
committed
feat: added a configurable RandomObjectPos wrapper for flexible object placement in collector
1 parent 3ecd701 commit 229a775

1 file changed

Lines changed: 36 additions & 0 deletions

File tree

python/rcs/envs/sim.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,43 @@ def env_from_xml_paths(
276276
truncate_on_collision=truncate_on_collision,
277277
)
278278

279+
class RandomObjectPos(SimWrapper):
280+
"""
281+
Wrapper to randomly re-place an object in the lab environments.
282+
Given the object's joint name and initial pose, its x, y coordinates are randomized, while z remains fixed.
283+
If include_rotation is true, the object's z-axis rotation (yaw) is also randomized.
284+
285+
Args:
286+
env (gym.Env): The environment to wrap.
287+
simulation (sim.Sim): The simulation instance.
288+
joint_name (str): The name of the free joint attached to the object to manipulate.
289+
init_object_pose (rcs.common.Pose): The initial pose of the object.
290+
include_rotation (bool): Whether to include rotation in the randomization.
291+
"""
292+
293+
def __init__(self, env: gym.Env, simulation: sim.Sim, joint_name: str, init_object_pose: rcs.common.Pose, include_rotation: bool = False):
294+
super().__init__(env, simulation)
295+
self.joint_name = joint_name
296+
self.init_object_pose = init_object_pose
297+
self.include_rotation = include_rotation
279298

299+
def reset(
300+
self, seed: int | None = None, options: dict[str, Any] | None = None
301+
) -> tuple[dict[str, Any], dict[str, Any]]:
302+
obs, info = super().reset(seed=seed, options=options)
303+
self.sim.step(1)
304+
305+
pos_z = self.init_object_pose.translation()[2]
306+
pos_x = self.init_object_pose.translation()[0] + np.random.random() * 0.2 - 0.1
307+
pos_y = self.init_object_pose.translation()[1] + np.random.random() * 0.2 - 0.1
308+
309+
if self.include_rotation:
310+
self.sim.data.joint(self.joint_name).qpos = [pos_x, pos_y, pos_z, 2 * np.random.random() - 1, 0, 0, 1]
311+
else:
312+
self.sim.data.joint(self.joint_name).qpos = [pos_x, pos_y, pos_z, 0, 0, 0, 1]
313+
314+
return obs, info
315+
280316
class RandomCubePos(SimWrapper):
281317
"""Wrapper to randomly place cube in the lab environments."""
282318

0 commit comments

Comments
 (0)