Skip to content

Commit aa4ea95

Browse files
committed
implement client, server and tests.
1 parent 69d1125 commit aa4ea95

3 files changed

Lines changed: 276 additions & 0 deletions

File tree

python/rcs/rpc/client.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import gymnasium as gym
2+
import rpyc
3+
from rpyc.utils.classic import obtain
4+
5+
class RcsClient(gym.Env):
6+
def __init__(self, host='localhost', port=50051):
7+
super().__init__()
8+
self.conn = rpyc.connect(host, port)
9+
self.server = self.conn.root
10+
# Optionally, fetch spaces from server if needed
11+
# self.observation_space = ...
12+
# self.action_space = ...
13+
14+
def step(self, action):
15+
return self.server.step(action)
16+
17+
def reset(self, **kwargs):
18+
return self.server.reset(**kwargs)
19+
20+
def get_obs(self):
21+
return self.server.get_obs()
22+
23+
@property
24+
def unwrapped(self):
25+
return self.server.unwrapped()
26+
27+
@property
28+
def action_space(self):
29+
return obtain(self.server.action_space())
30+
31+
def close(self):
32+
self.conn.close()

python/rcs/rpc/server.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
2+
# import wrapper
3+
from gymnasium import Wrapper
4+
import rpyc
5+
from rpyc.utils.server import ThreadedServer
6+
rpyc.core.protocol.DEFAULT_CONFIG['allow_pickle'] = True
7+
8+
@rpyc.service
9+
class RcsServer(Wrapper, rpyc.Service):
10+
def __init__(self, env, host='localhost', port=50051):
11+
super().__init__(env)
12+
self.host = host
13+
self.port = port
14+
15+
@rpyc.exposed
16+
def step(self, action):
17+
"""Perform a step in the environment using the Wrapper base class."""
18+
return super().step(action)
19+
20+
@rpyc.exposed
21+
def reset(self, **kwargs):
22+
"""Reset the environment using the Wrapper base class."""
23+
return super().reset(**kwargs)
24+
25+
@rpyc.exposed
26+
def get_obs(self):
27+
"""Get the current observation using the Wrapper base class if available."""
28+
if hasattr(super(), 'get_obs'):
29+
return super().get_obs()
30+
elif hasattr(self.env, 'get_obs'):
31+
return self.env.get_obs()
32+
else:
33+
raise NotImplementedError("The environment does not have a get_obs method.")
34+
35+
@rpyc.exposed
36+
def unwrapped(self):
37+
"""Return the unwrapped environment using the Wrapper base class."""
38+
return super().unwrapped
39+
40+
@rpyc.exposed
41+
def action_space(self):
42+
"""Return the action space using the Wrapper base class."""
43+
return super().action_space
44+
45+
def start(self):
46+
import time
47+
print(f"Starting RcsServer RPC (looped OneShotServer) on {self.host}:{self.port}")
48+
t = ThreadedServer(self, port=self.port)
49+
t.start()

python/tests/test_rpc.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import multiprocessing
2+
import time
3+
import socket
4+
import sys
5+
import traceback
6+
import os
7+
import pytest
8+
from typing import Optional # Add this import at the top
9+
from rcs.envs.creators import SimEnvCreator
10+
from rcs.envs.utils import (
11+
default_mujoco_cameraset_cfg,
12+
default_sim_gripper_cfg,
13+
default_sim_robot_cfg,
14+
)
15+
from rcs.envs.base import ControlMode, RelativeTo
16+
from rcs.rpc.server import RcsServer
17+
from rcs.rpc.client import RcsClient
18+
19+
HOST = "127.0.0.1"
20+
21+
def get_free_port() -> int:
22+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
23+
s.bind((HOST, 0))
24+
return s.getsockname()[1]
25+
26+
def wait_for_port(
27+
host: str,
28+
port: int,
29+
timeout: float,
30+
server_proc: Optional[multiprocessing.Process] = None,
31+
err_q: Optional[multiprocessing.Queue] = None
32+
) -> None:
33+
start = time.time()
34+
last_exc = None
35+
while time.time() - start < timeout:
36+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
37+
s.settimeout(0.5)
38+
try:
39+
if s.connect_ex((host, port)) == 0:
40+
return
41+
except OSError as e:
42+
last_exc = e
43+
# If the server process died, surface its error immediately
44+
if server_proc is not None and not server_proc.is_alive():
45+
server_err = None
46+
if err_q is not None:
47+
try:
48+
server_err = err_q.get_nowait()
49+
except Exception:
50+
pass
51+
msg = f"Server process exited early (exitcode={server_proc.exitcode})."
52+
if server_err:
53+
msg += f"\nServer traceback:\n{server_err}"
54+
raise RuntimeError(msg)
55+
time.sleep(0.2)
56+
server_err = None
57+
if err_q is not None:
58+
try:
59+
server_err = err_q.get_nowait()
60+
except Exception:
61+
pass
62+
msg = f"Timed out waiting for {host}:{port} to open."
63+
if last_exc:
64+
msg += f" Last socket error: {last_exc}"
65+
if server_proc is not None and not server_proc.is_alive():
66+
msg += f" Server exitcode={server_proc.exitcode}."
67+
if server_err:
68+
msg += f"\nServer traceback:\n{server_err}"
69+
raise TimeoutError(msg)
70+
71+
def run_server(host: str, port: int, err_q: multiprocessing.Queue) -> None:
72+
try:
73+
env = SimEnvCreator()(
74+
control_mode=ControlMode.JOINTS,
75+
collision_guard=False,
76+
robot_cfg=default_sim_robot_cfg(),
77+
gripper_cfg=default_sim_gripper_cfg(),
78+
# Disabled to avoid rendering problem in python subprocess.
79+
#cameras=default_mujoco_cameraset_cfg(),
80+
max_relative_movement=0.1,
81+
relative_to=RelativeTo.LAST_STEP,
82+
)
83+
# Bind explicitly to IPv4 loopback
84+
server = RcsServer(env, host=host, port=port)
85+
try:
86+
server.start()
87+
finally:
88+
# If start returns (non-blocking implementation), keep process alive
89+
while True:
90+
time.sleep(1)
91+
except Exception:
92+
tb = "".join(traceback.format_exception(*sys.exc_info()))
93+
try:
94+
err_q.put(tb)
95+
except Exception:
96+
pass
97+
sys.exit(1)
98+
99+
def _mp_context() -> multiprocessing.context.BaseContext:
100+
# Prefer spawn to avoid fork-related issues with GL/MuJoCo/threaded libs
101+
methods = multiprocessing.get_all_start_methods()
102+
if "spawn" in methods:
103+
return multiprocessing.get_context("spawn")
104+
if "forkserver" in methods:
105+
return multiprocessing.get_context("forkserver")
106+
return multiprocessing.get_context(methods[0])
107+
108+
def _external_server_from_env() -> tuple[str, int] | None:
109+
# Set RCS_TEST_HOST and RCS_TEST_PORT to reuse an already running server.
110+
host = os.getenv("RCS_TEST_HOST")
111+
port = os.getenv("RCS_TEST_PORT")
112+
if host and port:
113+
try:
114+
return host, int(port)
115+
except ValueError:
116+
pass
117+
# Convenience: RCS_TEST_REUSE_SERVER=1 will use HOST + default port 50055
118+
if os.getenv("RCS_TEST_REUSE_SERVER") == "1":
119+
return HOST, 50055
120+
return None
121+
122+
def test_run_server_starts_and_stops():
123+
# Skip if reusing an external server
124+
ext = _external_server_from_env()
125+
if ext:
126+
pytest.skip("External server reuse enabled via env; skipping spawn test.")
127+
ctx = _mp_context()
128+
err_q = ctx.Queue()
129+
port = get_free_port()
130+
server_proc = ctx.Process(target=run_server, args=(HOST, port, err_q))
131+
server_proc.start()
132+
try:
133+
wait_for_port(HOST, port, timeout=120.0, server_proc=server_proc, err_q=err_q)
134+
assert server_proc.is_alive(), "Server process did not start as expected."
135+
finally:
136+
if server_proc.is_alive():
137+
server_proc.terminate()
138+
server_proc.join(timeout=5)
139+
assert not server_proc.is_alive(), "Server process did not terminate as expected."
140+
141+
class TestRcsClientServer:
142+
@classmethod
143+
def setup_class(cls):
144+
ext = _external_server_from_env()
145+
if ext:
146+
cls.host, cls.port = ext
147+
cls.server_proc = None
148+
cls.err_q = None
149+
wait_for_port(cls.host, cls.port, timeout=60.0)
150+
cls.client = RcsClient(host=cls.host, port=cls.port)
151+
return
152+
153+
ctx = _mp_context()
154+
cls.err_q = ctx.Queue()
155+
cls.host, cls.port = HOST, get_free_port()
156+
cls.server_proc = ctx.Process(target=run_server, args=(cls.host, cls.port, cls.err_q))
157+
cls.server_proc.start()
158+
# Wait until the server is actually listening or fail early if it crashed
159+
wait_for_port(cls.host, cls.port, timeout=180.0, server_proc=cls.server_proc, err_q=cls.err_q)
160+
cls.client = RcsClient(host=cls.host, port=cls.port)
161+
162+
@classmethod
163+
def teardown_class(cls):
164+
try:
165+
if getattr(cls, "client", None):
166+
cls.client.close()
167+
finally:
168+
if getattr(cls, "server_proc", None) and cls.server_proc and cls.server_proc.is_alive():
169+
cls.server_proc.terminate()
170+
cls.server_proc.join(timeout=5)
171+
172+
def test_reset(self):
173+
obs, info = self.client.reset()
174+
assert obs is not None, "reset did not return an observation"
175+
176+
def test_step(self):
177+
self.client.reset()
178+
act = self.client.action_space.sample()
179+
step_result = self.client.step(act)
180+
assert isinstance(step_result, (tuple, list)), "step did not return a tuple or list"
181+
182+
def test_get_obs(self):
183+
self.client.reset()
184+
obs2 = self.client.get_obs()
185+
assert obs2 is not None, "get_obs did not return an observation"
186+
187+
def test_unwrapped(self):
188+
_ = self.client.unwrapped
189+
190+
def test_close(self):
191+
self.client.close()
192+
# Reconnect for further tests
193+
wait_for_port(self.__class__.host, self.__class__.port, timeout=15.0,
194+
server_proc=self.__class__.server_proc, err_q=self.__class__.err_q)
195+
self.__class__.client = RcsClient(host=self.__class__.host, port=self.__class__.port)

0 commit comments

Comments
 (0)