Skip to content

Commit 1db96ed

Browse files
committed
upd: cleaned up code and add support for geoms
1 parent c6a386d commit 1db96ed

1 file changed

Lines changed: 129 additions & 52 deletions

File tree

python/rcs/ompl/mj_ompl.py

Lines changed: 129 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import xml.etree.ElementTree as ET
1212
from rcs._core.common import Pose
1313

14-
print("OMPL modules imported successfully.")
1514
DEFAULT_PLANNING_TIME = 5.0 # Default time allowed for planning in seconds
1615
INTERPOLATE_NUM = 500 # Number of points to interpolate between start and goal states
1716

@@ -67,7 +66,8 @@ def __init__(self,
6766
robot_root_name:str="base_0",
6867
robot_joint_name:str="fr3_joint#_0",
6968
robot_actuator_name:str="fr3_joint#_0",
70-
obstacle_body_names: list = None):
69+
obstacle_body_names: list = None,
70+
obstacle_geom_names: list = None):
7171
'''
7272
Initialize the robot object with the given parameters.
7373
It is essentially a thin wrapper around the RobotEnv (i.e. MuJoCo variables),
@@ -79,14 +79,14 @@ def __init__(self,
7979
- env: The RobotEnv environment in which the robot operates.
8080
- njoints: Number of joints in the robot.
8181
- robot_xml_name: Path to the robot's XML file.
82-
This file will be used to query the <body>s of the robot.
82+
This file will be used to query the <body>s of the robot to get collision checking info.
8383
- robot_root_name: Name of the root body of the robot,
8484
i.e. the top level <body>'s name.
8585
- robot_joint_name: Pattern of the robot's joints to be controlled,
8686
where # will be replaced by 1 to njoints.
8787
- robot_actuator_name: Pattern of the robot's actuators to be controlled,
8888
where # will be replaced by 1 to njoints.
89-
Actuators are not controlled, but is used to validate
89+
Actuators are not controlled, but is used to validate
9090
that the number of joints and actuators match.
9191
- obstacle_body_names: List of names of other <body>s to be checked for collision.
9292
If None, it will be set to an empty list.
@@ -117,39 +117,33 @@ def __init__(self,
117117
self.actuator_ids = [mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_ACTUATOR, name) \
118118
for name in self.actuator_names \
119119
if mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_ACTUATOR, name) > -1]
120-
self._set_obstacle_body_ids(obstacle_body_names)
120+
self.obstacle_body_ids = set()
121+
self.obstacle_geom_ids = set()
122+
self._add_obstacle_body_ids(obstacle_body_names)
123+
self._add_obstacle_geom_ids(obstacle_geom_names)
121124

122125
# Check if the number of joints, links, and actuators match
123126
assert len(self.joint_ids) == len(self.actuator_ids), \
124127
f"Mismatch in number of joints and actuators for robot {robot_name}."
125128

126129
self.njoints = len(self.joint_ids)
127130

128-
def _set_obstacle_body_ids(self, obstacle_body_names: list):
129-
"""
130-
Set the obstacle bodies for the robot.
131-
132-
Args:
133-
obstacle_body_names (list): List of names of other bodies to be checked for path planning.
134-
"""
135-
self.obstacle_body_ids = set([mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_BODY, name) \
136-
for name in obstacle_body_names \
137-
if mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_BODY, name) > -1])
138-
139131
def _remove_obstacle_body_ids(self, obstacle_body_names: list|str):
140132
"""
141133
Remove specified obstacle bodies from the robot's obstacle checks, given their names.
142134
143135
Args:
144136
obstacle_body_names (list): List of names of bodies to be removed from obstacle checks.
145137
"""
138+
if obstacle_body_names is None:
139+
return
146140
obstacle_body_names = obstacle_body_names if isinstance(obstacle_body_names, list) else [obstacle_body_names]
147141
for name in obstacle_body_names:
148142
body_id = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_BODY, name)
149143
if body_id in self.obstacle_body_ids:
150144
self.obstacle_body_ids.remove(body_id)
151145
else:
152-
warnings.warn(f"obstacle body name {name} does not exist in the model. Skipping removal.")
146+
raise RuntimeError(f"_remove_obstacle_body_ids: obstacle body name {name} does not exist in the set.")
153147

154148
def _add_obstacle_body_ids(self, obstacle_body_names: list|str):
155149
"""
@@ -158,14 +152,50 @@ def _add_obstacle_body_ids(self, obstacle_body_names: list|str):
158152
Args:
159153
obstacle_body_names (list): List of names of bodies to be added to obstacle checks.
160154
"""
155+
if obstacle_body_names is None:
156+
return
161157
obstacle_body_names = obstacle_body_names if isinstance(obstacle_body_names, list) else [obstacle_body_names]
162158
for name in obstacle_body_names:
163159
body_id = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_BODY, name)
164160
if body_id > -1:
165161
self.obstacle_body_ids.add(body_id)
166162
else:
167-
warnings.warn("obstacle body name {name} does not exist in the model. Skipping addition.")
163+
raise RuntimeError(f"_add_obstacle_body_ids: obstacle body name {name} does not exist in the model.")
164+
165+
def _add_obstacle_geom_ids(self, obstacle_geom_names: list|str):
166+
"""
167+
Add specified obstacle geoms to the robot's obstacle checks, given their names.
168168
169+
Args:
170+
obstacle_geom_names (list): List of names of geoms to be added to obstacle checks.
171+
"""
172+
if obstacle_geom_names is None:
173+
return
174+
obstacle_geom_names = obstacle_geom_names if isinstance(obstacle_geom_names, list) else [obstacle_geom_names]
175+
for name in obstacle_geom_names:
176+
geom_id = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_GEOM, name)
177+
if geom_id > -1:
178+
self.obstacle_geom_ids.add(geom_id)
179+
else:
180+
raise RuntimeError(f"_add_obstacle_geom_ids: obstacle geom name {name} does not exist in the model.")
181+
182+
def _remove_obstacle_geom_ids(self, obstacle_geom_names: list|str):
183+
"""
184+
Remove specified obstacle geoms from the robot's obstacle checks, given their names.
185+
186+
Args:
187+
obstacle_geom_names (list): List of names of geoms to be removed from obstacle checks.
188+
"""
189+
if obstacle_geom_names is None:
190+
return
191+
obstacle_geom_names = obstacle_geom_names if isinstance(obstacle_geom_names, list) else [obstacle_geom_names]
192+
for name in obstacle_geom_names:
193+
geom_id = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_GEOM, name)
194+
if geom_id in self.obstacle_geom_ids:
195+
self.obstacle_geom_ids.remove(geom_id)
196+
else:
197+
raise RuntimeError(f"_remove_obstacle_geom_ids: obstacle geom name {name} does not exist in the set.")
198+
169199
def check_collision(self, qpos: ob.State):
170200
'''
171201
Checks for collisions when the robot is set to the given joint positions.
@@ -201,6 +231,11 @@ def check_collision(self, qpos: ob.State):
201231
(b1 in self.robot_body_ids and b2 in self.robot_body_ids):
202232
has_collision = True
203233
break
234+
# Check for collisions with geoms
235+
if (c.geom1 in self.obstacle_geom_ids or c.geom2 in self.obstacle_geom_ids) and \
236+
(b1 in self.robot_body_ids or b2 in self.robot_body_ids):
237+
has_collision = True
238+
break
204239

205240
# Reset the simulation
206241
self.data.qpos = qpos_init
@@ -231,7 +266,7 @@ def ik(self, pose, q0=None, tcp_offset=None):
231266
numpy.ndarray: Joint positions that achieve the desired pose, or None if no solution is found.
232267
"""
233268
tcp_offset = tcp_offset if tcp_offset is not None else self.franka_hand_tcp
234-
return self.env.robot.get_ik().ik(pose, q0, tcp_offset)
269+
return self.env.unwrapped.robot.get_ik().ik(pose, q0, tcp_offset)
235270

236271

237272
class MjOStateSpace(ob.RealVectorStateSpace):
@@ -259,32 +294,64 @@ def set_state_sampler(self, state_sampler):
259294

260295
class MjOMPL():
261296
def __init__(self,
262-
robot:MjORobot=None,
263-
robot_name:str=None,
264-
robot_xml_name:str=None,
265-
robot_root_name:str=None,
266-
robot_joint_name:str=None,
267-
robot_actuator_name:str=None,
268-
njoints:int=None,
269-
robot_env: RobotEnv = None,
270-
obstacle_body_names: list = None):
271-
272-
if(robot is None):
273-
if robot_name is None or robot_xml_name is None or robot_root_name is None or \
274-
robot_joint_name is None or robot_actuator_name is None or njoints is None:
275-
raise ValueError("If robot is not provided, all parameters must be provided.")
276-
# Create a new MjORobot instance
277-
robot = MjORobot(robot_name,
278-
robot_env,
279-
njoints=njoints,
280-
robot_xml_name=robot_xml_name,
281-
robot_root_name=robot_root_name,
282-
robot_joint_name=robot_joint_name,
283-
robot_actuator_name=robot_actuator_name,
284-
obstacle_body_names=obstacle_body_names)
297+
robot_env: RobotEnv,
298+
robot_name:str,
299+
robot_xml_name:str,
300+
robot_root_name:str,
301+
robot_joint_name:str,
302+
robot_actuator_name:str,
303+
njoints:int,
304+
obstacle_body_names: list = None,
305+
obstacle_geom_names: list = None,
306+
interpolate_num:int=INTERPOLATE_NUM,
307+
default_planning_time:float=DEFAULT_PLANNING_TIME):
308+
'''
309+
Initialize the OMPL planner with the given parameters.
310+
Besides setting up the OMPL planning context, it instantiates an MjORobot instance,
311+
which is a thin wrapper around the RobotEnv (i.e. MuJoCo variables),
312+
for easily extracting the relevant bodies, joints, etc. from the MuJoCo model.
285313
314+
315+
Parameters:
316+
- robot_env: The RobotEnv environment in which the robot operates.
317+
- robot_name: Name of the robot.
318+
319+
- robot_xml_name: Path to the robot's XML file.
320+
This file will be used to query the <body>s of the robot to get collision checking info.
321+
- robot_root_name: Name of the root body of the robot,
322+
i.e. the top level <body>'s name.
323+
- robot_joint_name: Pattern of the robot's joints to be controlled,
324+
where # will be replaced by 1 to njoints.
325+
- robot_actuator_name: Pattern of the robot's actuators to be controlled,
326+
where # will be replaced by 1 to njoints.
327+
Actuators are not controlled, but is used to validate
328+
that the number of joints and actuators match.
329+
- njoints: Number of joints in the robot.
330+
- obstacle_body_names: List of names of other mjOBJ_BODYs to be checked for collision.
331+
If None, it will be set to an empty list.
332+
- obstacle_geom_names: List of names of other mjOBJ_GEOMs to be checked for collision.
333+
If None, it will be set to an empty list.
334+
- interpolate_num=100 (optional): Number of points to interpolate between start and goal states.
335+
- default_planning_time=5.0 (optional): Default time allowed for planning in seconds.
336+
337+
'''
338+
if robot_name is None or robot_xml_name is None or robot_root_name is None or \
339+
robot_joint_name is None or robot_actuator_name is None or njoints is None:
340+
raise ValueError("Initialization values are missing.")
341+
# Create a new MjORobot instance
342+
robot = MjORobot(robot_name,
343+
robot_env,
344+
njoints=njoints,
345+
robot_xml_name=robot_xml_name,
346+
robot_root_name=robot_root_name,
347+
robot_joint_name=robot_joint_name,
348+
robot_actuator_name=robot_actuator_name,
349+
obstacle_body_names=obstacle_body_names,
350+
obstacle_geom_names=obstacle_geom_names)
351+
286352
self.robot = robot
287-
self.interpolate_num = INTERPOLATE_NUM
353+
self.interpolate_num = interpolate_num
354+
self.default_planning_time = default_planning_time
288355

289356
# Create the planning space
290357
self.space = MjOStateSpace(robot.njoints)
@@ -358,8 +425,7 @@ def _plan_start_goal(self, start: np.ndarray, goal: np.ndarray, allowed_time = D
358425
res (bool): True if a solution was found, False otherwise.
359426
sol_path_list (list): List of joint positions in the solution path, if found.
360427
'''
361-
print("start_planning")
362-
print(self.planner.params())
428+
print("Planner params: \n", self.planner.params())
363429

364430
# set the start and goal states;
365431
s = ob.State(self.space)
@@ -381,8 +447,6 @@ def _plan_start_goal(self, start: np.ndarray, goal: np.ndarray, allowed_time = D
381447
sol_path_geometric.interpolate(self.interpolate_num)
382448
sol_path_states = sol_path_geometric.getStates()
383449
sol_path_list = [self.state_to_list(state) for state in sol_path_states]
384-
# print(len(sol_path_list))
385-
# print(sol_path_list)
386450
for sol_path in sol_path_list:
387451
self.is_state_valid(sol_path)
388452
res = True
@@ -427,26 +491,39 @@ def set_state_sampler(self, state_sampler):
427491
def add_collision_bodies(self, obstacle_body_names: list|str):
428492
"""
429493
Add specified obstacle bodies to the robot's obstacle checks.
494+
Prints a warning if the body name does not exist in the model.
430495
431496
Args:
432497
obstacle_body_names (list|str): List of names of bodies to be added to obstacle checks.
433498
"""
434499
self.robot._add_obstacle_body_ids(obstacle_body_names)
500+
501+
def add_collision_geoms(self, obstacle_geom_names: list|str):
502+
"""
503+
Add specified obstacle geometries to the robot's obstacle checks.
504+
Prints a warning if the geometry name does not exist in the model.
505+
506+
Args:
507+
obstacle_geom_names (list|str): List of names of geometries to be added to obstacle checks.
508+
"""
509+
self.robot._add_obstacle_geom_ids(obstacle_geom_names)
435510

436511
def remove_collision_bodies(self, obstacle_body_names: list|str):
437512
"""
438513
Remove specified obstacle bodies from the robot's obstacle checks.
514+
Prints a warning if the body name does not exist in the current set of obstacles.
439515
440516
Args:
441517
obstacle_body_names (list|str): List of names of bodies to be removed from obstacle checks.
442518
"""
443519
self.robot._remove_obstacle_body_ids(obstacle_body_names)
444520

445-
def set_collision_bodies(self, obstacle_body_names: list|str):
521+
def remove_collision_geoms(self, obstacle_geom_names: list|str):
446522
"""
447-
Set the obstacle bodies for the robot.
448-
523+
Remove specified obstacle geometries from the robot's obstacle checks.
524+
Prints a warning if the geometry name does not exist in the current set of obstacles.
525+
449526
Args:
450-
obstacle_body_names (list|str): List of names of bodies to be set as obstacles.
527+
obstacle_geom_names (list|str): List of names of geometries to be removed from obstacle checks.
451528
"""
452-
self.robot._set_obstacle_body_ids(obstacle_body_names)
529+
self.robot._remove_obstacle_geom_ids(obstacle_geom_names)

0 commit comments

Comments
 (0)