Skip to content

Commit 6afbd26

Browse files
committed
refactor(wrappers): base environment is now platform based
- base environment is no longer the robot but for hardware an empty environment and for simulation the simulator - greenlet implementation: MultiRobotWrapper steps the simulation only once for environment wraps - adapted initialization of the environments wraps - added tests to test sim greenlet implementation
1 parent 833cf06 commit 6afbd26

26 files changed

Lines changed: 619 additions & 491 deletions

File tree

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ Flexibly compose your Gymnasium environment to fit your exact training needs. *F
4141
```python
4242
from time import sleep
4343

44-
import gymnasium as gym
4544
import numpy as np
4645
from rcs._core.sim import SimConfig
4746
from rcs.camera.sim import SimCameraSet
@@ -51,7 +50,8 @@ from rcs.envs.base import (
5150
GripperWrapper,
5251
RelativeActionSpace,
5352
RelativeTo,
54-
RobotEnv,
53+
RobotWrapper,
54+
SimEnv,
5555
)
5656
from rcs.envs.sim import GripperWrapperSim, RobotSimWrapper
5757
from rcs.envs.utils import (
@@ -82,18 +82,19 @@ if __name__ == "__main__":
8282

8383
# base env
8484
robot = rcs.sim.SimRobot(simulation, ik, robot_cfg)
85-
env: gym.Env = RobotEnv(robot, ControlMode.CARTESIAN_TQuat)
85+
env = SimEnv(simulation)
86+
env = RobotWrapper(env, robot, ControlMode.CARTESIAN_TQuat)
8687

8788
# gripper
8889
gripper = sim.SimGripper(simulation, gripper_cfg)
8990
env = GripperWrapper(env, gripper, binary=True)
9091

91-
env = RobotSimWrapper(env, simulation)
92-
env = GripperWrapperSim(env, gripper)
92+
env = RobotSimWrapper(env)
93+
env = GripperWrapperSim(env)
9394

9495
# camera
9596
camera_set = SimCameraSet(simulation, cameras, physical_units=True, render_on_demand=True)
96-
env = CameraSetWrapper(env, camera_set, include_depth=True)
97+
env = CameraSetWrapper(env, camera_set, include_depth=True) # type: ignore
9798

9899
# relative actions bounded by 10cm translation and 10 degree rotation
99100
env = RelativeActionSpace(env, max_mov=(0.1, np.deg2rad(10)), relative_to=RelativeTo.LAST_STEP)
@@ -104,14 +105,13 @@ if __name__ == "__main__":
104105
env.reset()
105106

106107
# access low level robot api to get current cartesian position
107-
print(env.unwrapped.robot.get_cartesian_position())
108+
print(env.get_wrapper_attr("robot").get_cartesian_position())
108109

109110
for _ in range(10):
110111
# move 1cm in x direction (forward) and close gripper
111112
act = {"tquat": [0.01, 0, 0, 0, 0, 0, 1], "gripper": [0]}
112113
obs, reward, terminated, truncated, info = env.step(act)
113114
print(obs)
114-
115115
```
116116

117117
> **Note:** This and other examples can be found in the [`examples/`]() folder.

examples/fr3/fr3_env_cartesian_control.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def main():
5656
env_rel.reset()
5757

5858
# access low level robot api to get current cartesian position
59-
print(env_rel.unwrapped.robot.get_cartesian_position()) # type: ignore
59+
print(env_rel.get_wrapper_attr("robot").get_cartesian_position()) # type: ignore
6060

6161
for _ in range(100):
6262
for _ in range(10):

examples/fr3/fr3_env_joint_control.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def main():
5656
input("the robot is going to move, press enter whenever you are ready")
5757

5858
# access low level robot api to get current cartesian position
59-
print(env_rel.unwrapped.robot.get_joint_position()) # type: ignore
59+
print(env_rel.get_wrapper_attr("robot").get_joint_position()) # type: ignore
6060

6161
for _ in range(100):
6262
obs, info = env_rel.reset()

examples/fr3/fr3_readme.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from time import sleep
22

3-
import gymnasium as gym
43
import numpy as np
54
from rcs._core.sim import SimConfig
65
from rcs.camera.sim import SimCameraSet
@@ -10,7 +9,8 @@
109
GripperWrapper,
1110
RelativeActionSpace,
1211
RelativeTo,
13-
RobotEnv,
12+
RobotWrapper,
13+
SimEnv,
1414
)
1515
from rcs.envs.sim import GripperWrapperSim, RobotSimWrapper
1616
from rcs.envs.utils import (
@@ -41,14 +41,15 @@
4141

4242
# base env
4343
robot = rcs.sim.SimRobot(simulation, ik, robot_cfg)
44-
env: gym.Env = RobotEnv(robot, ControlMode.CARTESIAN_TQuat)
44+
env = SimEnv(simulation)
45+
env = RobotWrapper(env, robot, ControlMode.CARTESIAN_TQuat)
4546

4647
# gripper
4748
gripper = sim.SimGripper(simulation, gripper_cfg)
4849
env = GripperWrapper(env, gripper, binary=True)
4950

50-
env = RobotSimWrapper(env, simulation)
51-
env = GripperWrapperSim(env, gripper)
51+
env = RobotSimWrapper(env)
52+
env = GripperWrapperSim(env)
5253

5354
# camera
5455
camera_set = SimCameraSet(simulation, cameras, physical_units=True, render_on_demand=True)
@@ -63,7 +64,7 @@
6364
env.reset()
6465

6566
# access low level robot api to get current cartesian position
66-
print(env.unwrapped.robot.get_cartesian_position())
67+
print(env.get_wrapper_attr("robot").get_cartesian_position())
6768

6869
for _ in range(10):
6970
# move 1cm in x direction (forward) and close gripper

examples/fr3/grasp_demo.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import mujoco
77
import numpy as np
88
from rcs._core.common import Pose
9-
from rcs.envs.base import GripperWrapper, RobotEnv
9+
from rcs._core.sim import SimRobot
10+
from rcs.envs.base import GripperWrapper
1011
from rcs.envs.creators import FR3SimplePickUpSimEnvCreator
1112

1213
logger = logging.getLogger(__name__)
@@ -16,8 +17,8 @@
1617
class PickUpDemo:
1718
def __init__(self, env: gym.Env):
1819
self.env = env
19-
self.unwrapped: RobotEnv = cast(RobotEnv, self.env.unwrapped)
20-
self.home_pose = self.unwrapped.robot.get_cartesian_position()
20+
self._robot = cast(SimRobot, self.env.get_wrapper_attr("robot"))
21+
self.home_pose = self._robot.get_cartesian_position()
2122

2223
def _action(self, pose: Pose, gripper: list[float]) -> dict[str, Any]:
2324
return {"xyzrpy": pose.xyzrpy(), "gripper": [gripper]}
@@ -32,7 +33,7 @@ def get_object_pose(self, geom_name) -> Pose:
3233
) * Pose(
3334
rpy_vector=np.array([0, 0, np.pi]), translation=[0, 0, 0] # type: ignore
3435
)
35-
return self.unwrapped.robot.to_pose_in_robot_coordinates(obj_pose_world_coordinates)
36+
return self._robot.to_pose_in_robot_coordinates(obj_pose_world_coordinates)
3637

3738
def generate_waypoints(self, start_pose: Pose, end_pose: Pose, num_waypoints: int) -> list[Pose]:
3839
waypoints = []
@@ -45,12 +46,13 @@ def step(self, action: dict) -> dict:
4546
return self.env.step(action)[0]
4647

4748
def plan_linear_motion(self, geom_name: str, delta_up: float, num_waypoints: int = 20) -> list[Pose]:
48-
end_eff_pose = self.unwrapped.robot.get_cartesian_position()
49+
end_eff_pose = self._robot.get_cartesian_position()
4950
goal_pose = self.get_object_pose(geom_name=geom_name)
5051
goal_pose *= Pose(translation=np.array([0, 0, delta_up]), quaternion=np.array([1, 0, 0, 0])) # type: ignore
5152
return self.generate_waypoints(end_eff_pose, goal_pose, num_waypoints=num_waypoints)
5253

5354
def execute_motion(self, waypoints: list[Pose], gripper: list[float] = GripperWrapper.BINARY_GRIPPER_OPEN) -> dict:
55+
obs = {}
5456
for i in range(len(waypoints)):
5557
obs = self.step(self._action(waypoints[i], gripper))
5658
return obs
@@ -65,13 +67,13 @@ def grasp(self, geom_name: str):
6567
self.execute_motion(waypoints=waypoints, gripper=GripperWrapper.BINARY_GRIPPER_OPEN)
6668

6769
for _ in range(4):
68-
self.step(self._action(self.unwrapped.robot.get_cartesian_position(), GripperWrapper.BINARY_GRIPPER_CLOSED))
70+
self.step(self._action(self._robot.get_cartesian_position(), GripperWrapper.BINARY_GRIPPER_CLOSED))
6971

7072
waypoints = self.plan_linear_motion(geom_name=geom_name, delta_up=0.2, num_waypoints=60)
7173
self.execute_motion(waypoints=waypoints, gripper=GripperWrapper.BINARY_GRIPPER_CLOSED)
7274

7375
def move_home(self):
74-
end_eff_pose = self.unwrapped.robot.get_cartesian_position()
76+
end_eff_pose = self._robot.get_cartesian_position()
7577
waypoints = self.generate_waypoints(end_eff_pose, self.home_pose, num_waypoints=60)
7678
self.execute_motion(waypoints=waypoints, gripper=GripperWrapper.BINARY_GRIPPER_CLOSED)
7779

@@ -90,7 +92,7 @@ def main():
9092
sleep(3)
9193
for _ in range(100):
9294
env.reset()
93-
print(env.unwrapped.robot.get_cartesian_position().translation()) # type: ignore
95+
print(env.get_wrapper_attr("robot").get_cartesian_position().translation()) # type: ignore
9496
controller = PickUpDemo(env)
9597
controller.pickup("box_geom")
9698

examples/fr3/grasp_digit_demo.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import mujoco
66
import numpy as np
77
from rcs._core.common import Pose
8-
from rcs.envs.base import GripperWrapper, RobotEnv
8+
from rcs._core.sim import SimRobot
9+
from rcs.envs.base import GripperWrapper
910
from rcs_tacto.creators import FR3TactoSimplePickUpSimEnvCreator
1011
from tqdm import tqdm
1112

@@ -16,8 +17,8 @@
1617
class PickUpDemo:
1718
def __init__(self, env: gym.Env):
1819
self.env = env
19-
self.unwrapped: RobotEnv = cast(RobotEnv, self.env.unwrapped)
20-
self.home_pose = self.unwrapped.robot.get_cartesian_position()
20+
self._robot = cast(SimRobot, self.env.get_wrapper_attr("robot"))
21+
self.home_pose = self._robot.get_cartesian_position()
2122

2223
def _action(self, pose: Pose, gripper: list[float]) -> dict[str, Any]:
2324
return {"xyzrpy": pose.xyzrpy(), "gripper": gripper}
@@ -30,7 +31,7 @@ def get_object_pose(self, geom_name) -> Pose:
3031
obj_pose_world_coordinates = Pose(
3132
translation=data.geom_xpos[geom_id], rotation=data.geom_xmat[geom_id].reshape(3, 3)
3233
)
33-
return self.unwrapped.robot.to_pose_in_robot_coordinates(obj_pose_world_coordinates)
34+
return self._robot.to_pose_in_robot_coordinates(obj_pose_world_coordinates)
3435

3536
def generate_waypoints(self, start_pose: Pose, end_pose: Pose, num_waypoints: int) -> list[Pose]:
3637
waypoints = []
@@ -43,7 +44,7 @@ def step(self, action: dict) -> dict:
4344
return self.env.step(action)[0]
4445

4546
def plan_linear_motion(self, geom_name: str, delta_up: float, num_waypoints: int = 200) -> list[Pose]:
46-
end_eff_pose = self.unwrapped.robot.get_cartesian_position()
47+
end_eff_pose = self._robot.get_cartesian_position()
4748
goal_pose = self.get_object_pose(geom_name=geom_name)
4849
goal_pose *= Pose(translation=np.array([0, 0, delta_up]), quaternion=np.array([1, 0, 0, 0])) # type: ignore
4950
return self.generate_waypoints(end_eff_pose, goal_pose, num_waypoints=num_waypoints)
@@ -68,7 +69,7 @@ def grasp(self, geom_name: str):
6869
self.execute_motion(waypoints=waypoints, gripper=GripperWrapper.BINARY_GRIPPER_CLOSED)
6970

7071
def move_home(self):
71-
end_eff_pose = self.unwrapped.robot.get_cartesian_position()
72+
end_eff_pose = self._robot.get_cartesian_position()
7273
waypoints = self.generate_waypoints(end_eff_pose, self.home_pose, num_waypoints=10)
7374
self.execute_motion(waypoints=waypoints, gripper=GripperWrapper.BINARY_GRIPPER_CLOSED)
7475

examples/fr3/grasp_ompl_demo.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import mujoco
77
import numpy as np
88
from rcs._core.common import Pose
9-
from rcs.envs.base import ControlMode, GripperWrapper, RobotEnv
9+
from rcs._core.sim import SimRobot
10+
from rcs.envs.base import ControlMode, GripperWrapper
1011
from rcs.envs.creators import FR3SimplePickUpSimEnvCreator
1112
from rcs.ompl.mj_ompl import MjOMPL
1213

@@ -38,9 +39,9 @@
3839
class OmplTrajectoryDemo:
3940
def __init__(self, env: gym.Env, planner: MjOMPL):
4041
self.env = env
41-
self.unwrapped: RobotEnv = cast(RobotEnv, self.env.unwrapped)
42-
self.home_pose: Pose = self.unwrapped.robot.get_cartesian_position()
43-
self.home_qpos: np.ndarray = self.unwrapped.robot.get_joint_position()
42+
self._robot = cast(SimRobot, self.env.get_wrapper_attr("robot"))
43+
self.home_pose: Pose = self._robot.get_cartesian_position()
44+
self.home_qpos: np.ndarray = self._robot.get_joint_position()
4445
self.sol_path = None
4546
self.planner = planner
4647

@@ -60,7 +61,7 @@ def get_object_pose(self, geom_name) -> Pose:
6061
) * Pose(
6162
rpy_vector=np.array([0, 0, np.pi]), translation=[0, 0, 0] # type: ignore
6263
)
63-
return self.unwrapped.robot.to_pose_in_robot_coordinates(obj_pose_world_coordinates)
64+
return self._robot.to_pose_in_robot_coordinates(obj_pose_world_coordinates)
6465

6566
def plan_path_to_object(self, obj_name: str, delta_up):
6667
self.move_home()
@@ -83,7 +84,7 @@ def approach_and_grasp(self, obj_name: str, delta_up: float = 0.2):
8384

8485
obj_pose_grasp = obj_pose_og * Pose(translation=np.array([0, 0, delta_up]), quaternion=np.array([1, 0, 0, 0])) # type: ignore
8586
waypoints = self.generate_waypoints(
86-
start_pose=self.unwrapped.robot.get_cartesian_position(), end_pose=obj_pose_grasp, num_waypoints=5
87+
start_pose=self._robot.get_cartesian_position(), end_pose=obj_pose_grasp, num_waypoints=5
8788
)
8889
for waypoint in waypoints:
8990
self.step(self._jaction(waypoint, GripperWrapper.BINARY_GRIPPER_OPEN)) # type: ignore
@@ -108,7 +109,7 @@ def execute_motion(self, waypoints: list[Pose], gripper: list[float] = GripperWr
108109
return obs
109110

110111
def move_home(self):
111-
end_eff_pose = self.unwrapped.robot.get_cartesian_position()
112+
end_eff_pose = self._robot.get_cartesian_position()
112113
waypoints = self.generate_waypoints(end_eff_pose, self.home_pose, num_waypoints=15)
113114
self.execute_motion(waypoints=waypoints, gripper=GripperWrapper.BINARY_GRIPPER_CLOSED)
114115

extensions/rcs_fr3/src/rcs_fr3/creators.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
ControlMode,
1414
GripperWrapper,
1515
HandWrapper,
16+
HardwareEnv,
1617
MultiRobotWrapper,
1718
RelativeActionSpace,
1819
RelativeTo,
19-
RobotEnv,
20+
RobotWrapper,
2021
)
2122
from rcs.envs.creators import RCSHardwareEnvCreator
2223
from rcs.hand.tilburg_hand import TilburgHand
@@ -91,7 +92,8 @@ def __call__( # type: ignore
9192
robot = hw.Franka(ip, ik)
9293
robot.set_config(robot_cfg)
9394

94-
env: gym.Env = RobotEnv(robot, ControlMode.JOINTS if collision_guard is not None else control_mode)
95+
env = HardwareEnv()
96+
env = RobotWrapper(env, robot, ControlMode.JOINTS if collision_guard is not None else control_mode)
9597

9698
env = FR3HW(env)
9799
if isinstance(gripper_cfg, hw.FHConfig):
@@ -154,7 +156,8 @@ def __call__( # type: ignore
154156

155157
envs = {}
156158
for key, ip in name2ip.items():
157-
env: gym.Env = RobotEnv(robots[key], control_mode)
159+
env = HardwareEnv()
160+
env = RobotWrapper(env, robots[key], control_mode)
158161
env = FR3HW(env)
159162
if gripper_cfg is not None:
160163
gripper = hw.FrankaHand(ip, gripper_cfg)

extensions/rcs_fr3/src/rcs_fr3/envs.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Any, SupportsFloat, cast
33

44
import gymnasium as gym
5-
from rcs.envs.base import RobotEnv
5+
from rcs._core.common import RobotPlatform
66
from rcs_fr3._core import hw
77

88
_logger = logging.getLogger(__name__)
@@ -11,9 +11,9 @@
1111
class FR3HW(gym.Wrapper):
1212
def __init__(self, env):
1313
super().__init__(env)
14-
self.unwrapped: RobotEnv
15-
assert isinstance(self.unwrapped.robot, hw.Franka), "Robot must be a hw.Franka instance."
16-
self.hw_robot = cast(hw.Franka, self.unwrapped.robot)
14+
assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.HARDWARE, "Base environment must be hardware."
15+
assert isinstance(self.get_wrapper_attr("robot"), hw.Franka), "Robot must be a hw.Franka instance."
16+
self.hw_robot = cast(hw.Franka, self.get_wrapper_attr("robot"))
1717
self._robot_state_keys: list[str] | None = None
1818

1919
def step(self, action: Any) -> tuple[dict[str, Any], SupportsFloat, bool, bool, dict]:
@@ -24,14 +24,12 @@ def step(self, action: Any) -> tuple[dict[str, Any], SupportsFloat, bool, bool,
2424
except hw.exceptions.FrankaControlException as e:
2525
_logger.error("FrankaControlException: %s", e)
2626
self.hw_robot.automatic_error_recovery()
27-
# TODO: this does not work if some wrappers are in between
28-
# FR3HW and RobotEnv
2927
return self.get_obs(), 0, False, True, {}
3028

3129
def get_obs(self, obs: dict | None = None) -> dict[str, Any]:
3230
if obs is None:
33-
obs = dict(self.unwrapped.get_obs())
34-
robot_state = cast(hw.FrankaState, self.unwrapped.robot.get_state())
31+
obs = dict(self.get_wrapper_attr("get_robot_obs")())
32+
robot_state = cast(hw.FrankaState, self.hw_robot.get_state())
3533
obs["robot_state"] = self._rs2dict(robot_state.robot_state)
3634
return obs
3735

@@ -44,7 +42,7 @@ def _rs2dict(self, state: hw.RobotState):
4442
return {key: getattr(state, key) for key in self._robot_state_keys}
4543

4644
def reset(
47-
self, seed: int | None = None, options: dict[str, Any] | None = None
45+
self, *, seed: int | None = None, options: dict[str, Any] | None = None
4846
) -> tuple[dict[str, Any], dict[str, Any]]:
4947
return super().reset(seed=seed, options=options)
5048

0 commit comments

Comments
 (0)