Skip to content

Commit 65758b2

Browse files
committed
fix: RandomObjectPos include setting orientation
1 parent a416988 commit 65758b2

1 file changed

Lines changed: 24 additions & 5 deletions

File tree

python/rcs/envs/sim.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,26 +290,45 @@ class RandomObjectPos(SimWrapper):
290290
include_rotation (bool): Whether to include rotation in the randomization.
291291
"""
292292

293-
def __init__(self, env: gym.Env, simulation: sim.Sim, joint_name: str, init_object_pose: rcs.common.Pose, include_rotation: bool = False):
293+
def __init__(self, env: gym.Env,
294+
simulation: sim.Sim,
295+
joint_name: str,
296+
init_object_pose: rcs.common.Pose,
297+
include_position: bool = True,
298+
include_rotation: bool = False):
294299
super().__init__(env, simulation)
295300
self.joint_name = joint_name
296301
self.init_object_pose = init_object_pose
302+
self.include_position = include_position
297303
self.include_rotation = include_rotation
298304

299305
def reset(
300306
self, seed: int | None = None, options: dict[str, Any] | None = None
301307
) -> tuple[dict[str, Any], dict[str, Any]]:
308+
if(options is not None and "RandomObjectPos.init_object_pose" in options.keys()):
309+
assert isinstance(options["RandomObjectPos.init_object_pose"], rcs.common.Pose), \
310+
"RandomObjectPos.init_object_pose must be a rcs.common.Pose"
311+
312+
self.init_object_pose = options["RandomObjectPos.init_object_pose"]
313+
print("Got random object pos!\n", self.init_object_pose)
314+
del options["RandomObjectPos.init_object_pose"]
302315
obs, info = super().reset(seed=seed, options=options)
303316
self.sim.step(1)
304317

318+
305319
pos_z = self.init_object_pose.translation()[2]
306-
pos_x = self.init_object_pose.translation()[0] + np.random.random() * 0.2 - 0.1
307-
pos_y = self.init_object_pose.translation()[1] + np.random.random() * 0.2 - 0.1
320+
if(self.include_position):
321+
pos_x = self.init_object_pose.translation()[0] + np.random.random() * 0.2 - 0.1
322+
pos_y = self.init_object_pose.translation()[1] + np.random.random() * 0.2 - 0.1
323+
else:
324+
pos_x = self.init_object_pose.translation()[0]
325+
pos_y = self.init_object_pose.translation()[1]
308326

327+
quat = self.init_object_pose.rotation_q() # xyzw format
309328
if self.include_rotation:
310-
self.sim.data.joint(self.joint_name).qpos = [pos_x, pos_y, pos_z, 2 * np.random.random() - 1, 0, 0, 1]
329+
self.sim.data.joint(self.joint_name).qpos = [pos_x, pos_y, pos_z, 2 * np.random.random() - quat[3], quat[0], quat[1], quat[2]]
311330
else:
312-
self.sim.data.joint(self.joint_name).qpos = [pos_x, pos_y, pos_z, 0, 0, 0, 1]
331+
self.sim.data.joint(self.joint_name).qpos = [pos_x, pos_y, pos_z, quat[3], quat[0], quat[1], quat[2]]
313332

314333
return obs, info
315334

0 commit comments

Comments
 (0)