@@ -290,26 +290,45 @@ class RandomObjectPos(SimWrapper):
290290 include_rotation (bool): Whether to include rotation in the randomization.
291291 """
292292
293- def __init__ (self , env : gym .Env , simulation : sim .Sim , joint_name : str , init_object_pose : rcs .common .Pose , include_rotation : bool = False ):
293+ def __init__ (self , env : gym .Env ,
294+ simulation : sim .Sim ,
295+ joint_name : str ,
296+ init_object_pose : rcs .common .Pose ,
297+ include_position : bool = True ,
298+ include_rotation : bool = False ):
294299 super ().__init__ (env , simulation )
295300 self .joint_name = joint_name
296301 self .init_object_pose = init_object_pose
302+ self .include_position = include_position
297303 self .include_rotation = include_rotation
298304
299305 def reset (
300306 self , seed : int | None = None , options : dict [str , Any ] | None = None
301307 ) -> tuple [dict [str , Any ], dict [str , Any ]]:
308+ if (options is not None and "RandomObjectPos.init_object_pose" in options .keys ()):
309+ assert isinstance (options ["RandomObjectPos.init_object_pose" ], rcs .common .Pose ), \
310+ "RandomObjectPos.init_object_pose must be a rcs.common.Pose"
311+
312+ self .init_object_pose = options ["RandomObjectPos.init_object_pose" ]
313+ print ("Got random object pos!\n " , self .init_object_pose )
314+ del options ["RandomObjectPos.init_object_pose" ]
302315 obs , info = super ().reset (seed = seed , options = options )
303316 self .sim .step (1 )
304317
318+
305319 pos_z = self .init_object_pose .translation ()[2 ]
306- pos_x = self .init_object_pose .translation ()[0 ] + np .random .random () * 0.2 - 0.1
307- pos_y = self .init_object_pose .translation ()[1 ] + np .random .random () * 0.2 - 0.1
320+ if (self .include_position ):
321+ pos_x = self .init_object_pose .translation ()[0 ] + np .random .random () * 0.2 - 0.1
322+ pos_y = self .init_object_pose .translation ()[1 ] + np .random .random () * 0.2 - 0.1
323+ else :
324+ pos_x = self .init_object_pose .translation ()[0 ]
325+ pos_y = self .init_object_pose .translation ()[1 ]
308326
327+ quat = self .init_object_pose .rotation_q () # xyzw format
309328 if self .include_rotation :
310- self .sim .data .joint (self .joint_name ).qpos = [pos_x , pos_y , pos_z , 2 * np .random .random () - 1 , 0 , 0 , 1 ]
329+ self .sim .data .joint (self .joint_name ).qpos = [pos_x , pos_y , pos_z , 2 * np .random .random () - quat [ 3 ], quat [ 0 ], quat [ 1 ], quat [ 2 ] ]
311330 else :
312- self .sim .data .joint (self .joint_name ).qpos = [pos_x , pos_y , pos_z , 0 , 0 , 0 , 1 ]
331+ self .sim .data .joint (self .joint_name ).qpos = [pos_x , pos_y , pos_z , quat [ 3 ], quat [ 0 ], quat [ 1 ], quat [ 2 ] ]
313332
314333 return obs , info
315334
0 commit comments