From 1fca36452bdd93c6ce472fe7e6501a447986610b Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 15:11:18 +0200 Subject: [PATCH 01/17] Remove isaaclab_arena/remote_policy/ directory Delete the entire networking module: PolicyServer, PolicyClient, MessageSerializer, ServerSidePolicy, ActionProtocol, RemotePolicyConfig, and the server runner CLI. These duplicate Isaac-GR00T's gr00t/policy/server_client.py. Each policy framework owns its own server and client. Signed-off-by: Clemens Volk --- isaaclab_arena/remote_policy/__init__.py | 22 -- .../remote_policy/action_protocol.py | 94 -------- .../remote_policy/message_serializer.py | 130 ----------- isaaclab_arena/remote_policy/policy_client.py | 131 ----------- isaaclab_arena/remote_policy/policy_server.py | 216 ------------------ .../remote_policy/remote_policy_config.py | 18 -- .../remote_policy_server_runner.py | 116 ---------- .../remote_policy/server_side_policy.py | 207 ----------------- 8 files changed, 934 deletions(-) delete mode 100644 isaaclab_arena/remote_policy/__init__.py delete mode 100644 isaaclab_arena/remote_policy/action_protocol.py delete mode 100644 isaaclab_arena/remote_policy/message_serializer.py delete mode 100644 isaaclab_arena/remote_policy/policy_client.py delete mode 100644 isaaclab_arena/remote_policy/policy_server.py delete mode 100644 isaaclab_arena/remote_policy/remote_policy_config.py delete mode 100644 isaaclab_arena/remote_policy/remote_policy_server_runner.py delete mode 100644 isaaclab_arena/remote_policy/server_side_policy.py diff --git a/isaaclab_arena/remote_policy/__init__.py b/isaaclab_arena/remote_policy/__init__.py deleted file mode 100644 index 6b2258a35..000000000 --- a/isaaclab_arena/remote_policy/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from .action_protocol import ActionMode, ActionProtocol, ChunkingActionProtocol -from .message_serializer import MessageSerializer -from .policy_client import PolicyClient -from .policy_server import PolicyServer -from .remote_policy_config import RemotePolicyConfig -from .server_side_policy import ServerSidePolicy - -__all__ = [ - "RemotePolicyConfig", - "ServerSidePolicy", - "MessageSerializer", - "PolicyClient", - "PolicyServer", - "ActionMode", - "ActionProtocol", - "ChunkingActionProtocol", -] diff --git a/isaaclab_arena/remote_policy/action_protocol.py b/isaaclab_arena/remote_policy/action_protocol.py deleted file mode 100644 index 5d4df2203..000000000 --- a/isaaclab_arena/remote_policy/action_protocol.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from typing import Any, ClassVar - - -class ActionMode(str, Enum): - """Action output mode of a policy. - - Currently only CHUNK is used. - Other modes can be added later if needed. - """ - - CHUNK = "chunk" - - -@dataclass -class ActionProtocol(ABC): - """Base handshake/config for a policy's action output. - - - Encapsulates the ActionMode. - - Holds common fields (action_dim, observation_keys). - - Subclasses add mode-specific fields (e.g. chunk_length). - """ - - # Subclasses must override this. - MODE: ClassVar[ActionMode | None] = None - - # Common fields for all modes. - action_dim: int - observation_keys: list[str] - - def __post_init__(self) -> None: - """Validate that subclasses configured MODE properly.""" - mode = type(self).MODE - if mode is None: - raise NotImplementedError(f"{type(self).__name__} must define MODE as an ActionMode.") - - @classmethod - @abstractmethod - def from_dict(cls, data: dict[str, Any]) -> ActionProtocol: - """Build protocol config from server-side config dict.""" - - @abstractmethod - def to_dict(self) -> dict[str, Any]: - """Serialize protocol config to a dict for RPC.""" - - @property - def mode(self) -> ActionMode: - return self.MODE - - -@dataclass -class ChunkingActionProtocol(ActionProtocol): - """ActionProtocol for CHUNK mode. - - action_chunk_length: - Number of actions that the client-side policy consumes from each - chunk at a time during post-processing. - action_horizon: - Total length of the action sequence produced by the model for a - single query. - """ - - MODE: ClassVar[ActionMode] = ActionMode.CHUNK - - # Mode-specific field. - action_chunk_length: int - action_horizon: int - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> ChunkingActionProtocol: - return cls( - action_dim=int(data["action_dim"]), - observation_keys=list(data["observation_keys"]), - action_chunk_length=int(data["action_chunk_length"]), - action_horizon=int(data["action_horizon"]), - ) - - def to_dict(self) -> dict[str, Any]: - return { - "action_mode": self.mode.value, - "action_dim": self.action_dim, - "observation_keys": self.observation_keys, - "action_chunk_length": self.action_chunk_length, - "action_horizon": self.action_horizon, - } diff --git a/isaaclab_arena/remote_policy/message_serializer.py b/isaaclab_arena/remote_policy/message_serializer.py deleted file mode 100644 index 94167cd2d..000000000 --- a/isaaclab_arena/remote_policy/message_serializer.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import io -import numpy as np -from dataclasses import asdict, is_dataclass -from enum import Enum -from typing import Any - -import msgpack - - -class MessageSerializer: - """Msgpack-based serializer for dict-based policy messages. - - Supports: - - standard Python types, - - dataclasses (via to_json_serializable), - - numpy.ndarray (tagged as __ndarray_class__), - - generic binary blobs (tagged as __blob_class__). - """ - - @staticmethod - def to_bytes(data: Any) -> bytes: - """Serialize a Python object to bytes using msgpack.""" - return msgpack.packb(data, default=MessageSerializer._encode_custom) - - @staticmethod - def from_bytes(data: bytes) -> Any: - """Deserialize bytes into Python objects, decoding custom tags.""" - return msgpack.unpackb(data, object_hook=MessageSerializer._decode_custom) - - # ------------------------------------------------------------------ # - # Custom encode / decode - # ------------------------------------------------------------------ # - - @staticmethod - def _decode_custom(obj: Any) -> Any: - """Decode tagged structures created in _encode_custom. - - This function is registered as the `object_hook` for msgpack.unpackb, - so it is called once for every decoded map/dict. - - - If the dict contains a special tag (e.g. '__ndarray_class__' or - '__blob_class__'), it is converted back into the corresponding - high-level type (numpy array, blob, etc.). - - If the dict has no special tag, it is returned unchanged. In that - case the object stays as whatever type msgpack's default decoder - produced (dict, list, int, str, ...). - - Untagged values and non-dict types are therefore handled entirely - by msgpack's built-in decoder. - """ - if not isinstance(obj, dict): - return obj - - # numpy array - if "__ndarray_class__" in obj: - return np.load(io.BytesIO(obj["as_npy"]), allow_pickle=False) - - # generic binary blob - if "__blob_class__" in obj: - return { - "mime": obj.get("mime"), - "data": obj.get("as_bytes"), - } - - # other tagged types can be added here - return obj - - @staticmethod - def _encode_custom(obj: Any) -> Any: - """Encode special Python objects into msgpack-friendly structures.""" - - # numpy array -> npy bytes - if isinstance(obj, np.ndarray): - output = io.BytesIO() - np.save(output, obj, allow_pickle=False) - return {"__ndarray_class__": True, "as_npy": output.getvalue()} - - # generic binary blob: bytes / bytearray - if isinstance(obj, (bytes, bytearray)): - return { - "__blob_class__": True, - "mime": None, - "as_bytes": bytes(obj), - } - - # optional: custom Image/Frame types with to_bytes() and mime attribute - if hasattr(obj, "to_bytes") and hasattr(obj, "mime"): - return { - "__blob_class__": True, - "mime": getattr(obj, "mime"), - "as_bytes": obj.to_bytes(), - } - - # fall back to JSON-serializable representation - return to_json_serializable(obj) - - -def to_json_serializable(obj: Any) -> Any: - """Recursively convert dataclasses and numpy arrays to JSON-serializable format. - - This is useful when encoding configuration objects or metadata. - """ - if is_dataclass(obj) and not isinstance(obj, type): - return to_json_serializable(asdict(obj)) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, np.integer): - return int(obj) - elif isinstance(obj, np.floating): - return float(obj) - elif isinstance(obj, np.bool_): - return bool(obj) - elif isinstance(obj, dict): - return {key: to_json_serializable(value) for key, value in obj.items()} - elif isinstance(obj, (list, tuple, set)): - return [to_json_serializable(item) for item in obj] - elif isinstance(obj, (str, int, float, bool, type(None))): - return obj - elif isinstance(obj, Enum): - return obj.name - else: - # Fallback: convert to string - return str(obj) diff --git a/isaaclab_arena/remote_policy/policy_client.py b/isaaclab_arena/remote_policy/policy_client.py deleted file mode 100644 index 04e25b2e0..000000000 --- a/isaaclab_arena/remote_policy/policy_client.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import warnings -from typing import Any - -import zmq - -from isaaclab_arena.remote_policy.message_serializer import MessageSerializer -from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig - - -class PolicyClient: - """Synchronous client for talking to a PolicyServer over ZeroMQ.""" - - def __init__(self, config: RemotePolicyConfig) -> None: - self._config = config - self._context = zmq.Context() - self._socket = self._context.socket(zmq.REQ) - self._socket.setsockopt(zmq.RCVTIMEO, self._config.timeout_ms) - self._socket.connect(f"tcp://{self._config.host}:{self._config.port}") - - # ------------------------------------------------------------------ # - # Public API - # ------------------------------------------------------------------ # - - def ping(self) -> bool: - """Check if the server is reachable.""" - try: - self.call_endpoint("ping", requires_input=False) - return True - except Exception as exc: - warnings.warn( - f"[PolicyClient] Failed to ping remote policy server at {self._config.host}:{self._config.port}: {exc}" - ) - return False - - def reset(self, env_ids=None, options: dict[str, Any] | None = None) -> Any: - """Reset remote policy state.""" - resp = self.call_endpoint( - endpoint="reset", - data={"env_ids": env_ids, "options": options}, - requires_input=True, - ) - if isinstance(resp, dict): - status = resp.get("status") - if status not in ("reset_success", "ok", "reset_ok", None): - raise RuntimeError(f"Remote reset failed with status={status}, resp={resp}") - return resp - - def kill(self) -> Any: - """Ask remote server to stop main loop.""" - return self.call_endpoint("kill", requires_input=False) - - def get_action( - self, - observation: dict[str, Any], - ) -> dict[str, Any]: - """Send policy_observations and get back policy action dict.""" - payload: dict[str, Any] = {"observation": observation} - - resp = self.call_endpoint( - endpoint="get_action", - data=payload, - requires_input=True, - ) - return resp - - def get_init_info(self, requested_action_mode: str) -> dict[str, Any]: - """Call get_init_info on the server with a requested_action_mode. - - Args: - requested_action_mode: ActionMode value (e.g. "chunk"). - - Returns: - A dict returned by the server, expected to contain: - - "status" - - "message" (optional) - - "config" (on success) - """ - payload = {"requested_action_mode": requested_action_mode} - resp = self.call_endpoint( - "get_init_info", - data=payload, - requires_input=True, - ) - if not isinstance(resp, dict): - raise TypeError(f"Expected dict from get_init_info, got {type(resp)!r}") - return resp - - def set_task_description(self, task_description: str | None) -> dict[str, Any]: - """Send task description to the remote policy.""" - payload: dict[str, Any] = {"task_description": task_description} - resp = self.call_endpoint( - endpoint="set_task_description", - data=payload, - requires_input=True, - ) - if not isinstance(resp, dict): - raise TypeError(f"Expected dict from set_task_description, got {type(resp)!r}") - return resp - - def call_endpoint( - self, - endpoint: str, - data: dict[str, Any] | None = None, - requires_input: bool = True, - ) -> Any: - """Generic RPC helper.""" - request: dict[str, Any] = {"endpoint": endpoint} - if requires_input: - request["data"] = data or {} - if self._config.api_token: - request["api_token"] = self._config.api_token - - self._socket.send(MessageSerializer.to_bytes(request)) - message = self._socket.recv() - response = MessageSerializer.from_bytes(message) - - if isinstance(response, dict) and "error" in response: - raise RuntimeError(f"Server error: {response['error']}") - return response - - def close(self) -> None: - """Close the underlying ZeroMQ socket and context.""" - self._socket.close() - self._context.term() diff --git a/isaaclab_arena/remote_policy/policy_server.py b/isaaclab_arena/remote_policy/policy_server.py deleted file mode 100644 index 27a1d8498..000000000 --- a/isaaclab_arena/remote_policy/policy_server.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -import zmq - -from isaaclab_arena.remote_policy.message_serializer import MessageSerializer -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy - - -@dataclass -class EndpointHandler: - handler: Callable[..., Any] - requires_input: bool = True - - -class PolicyServer: - def __init__( - self, - policy: ServerSidePolicy, - host: str = "*", - port: int = 5555, - api_token: str | None = None, - timeout_ms: int = 15000, - ) -> None: - self._policy = policy - self._running = True - self._context = zmq.Context() - self._socket = self._context.socket(zmq.REP) - self._socket.setsockopt(zmq.RCVTIMEO, timeout_ms) - bind_addr = f"tcp://{host}:{port}" - print(f"[PolicyServer] binding on {bind_addr}") - self._socket.bind(bind_addr) - self._api_token = api_token - - self._endpoints: dict[str, EndpointHandler] = {} - self._register_default_endpoints() - - def _register_default_endpoints(self) -> None: - self.register_endpoint("ping", self._handle_ping, requires_input=False) - self.register_endpoint("kill", self._handle_kill, requires_input=False) - self.register_endpoint("get_action", self._handle_get_action, requires_input=True) - self.register_endpoint("reset", self._handle_reset, requires_input=True) - self.register_endpoint("get_init_info", self._handle_get_init_info, requires_input=True) - self.register_endpoint("set_task_description", self._handle_set_task_description, requires_input=True) - print(f"[PolicyServer] registered endpoints: {list(self._endpoints.keys())}") - - def register_endpoint( - self, - name: str, - handler: Callable[..., Any], - requires_input: bool = True, - ) -> None: - self._endpoints[name] = EndpointHandler(handler=handler, requires_input=requires_input) - - def _handle_get_init_info( - self, - requested_action_mode: str, - ) -> dict[str, Any]: - print(f"[PolicyServer] handle get_init_info: requested_action_mode={requested_action_mode!r}") - resp = self._policy.get_init_info(requested_action_mode=requested_action_mode) - if not isinstance(resp, dict): - raise TypeError(f"Policy.get_init_info() must return dict, got {type(resp)!r}") - return resp - - def _handle_set_task_description( - self, - task_description: str | None = None, - **_: Any, - ) -> dict[str, Any]: - print(f"[PolicyServer] handle set_task_description: {task_description!r}") - resp = self._policy.set_task_description(task_description) - if not isinstance(resp, dict): - raise TypeError(f"Policy.set_task_description() must return dict, got {type(resp)!r}") - return resp - - def _handle_ping(self) -> dict[str, Any]: - print("[PolicyServer] handle ping") - return {"status": "ok"} - - def _handle_kill(self) -> dict[str, Any]: - print("[PolicyServer] handle kill -> stopping") - self._running = False - return {"status": "stopping"} - - def _handle_get_action( - self, - observation: dict[str, Any], - options: dict[str, Any] | None = None, - **_: Any, - ) -> dict[str, Any]: - print("[PolicyServer] handle get_action") - if options is not None: - print(f" options keys: {list(options.keys())}") - action, info = self._policy.get_action( - observation=observation, - options=options, - ) - - if not isinstance(action, dict): - raise TypeError(f"Policy.get_action() must return (dict, dict), got action type={type(action)!r}") - if not isinstance(info, dict): - raise TypeError(f"Policy.get_action() must return (dict, dict), got info type={type(info)!r}") - - merged: dict[str, Any] = {} - merged.update(action) - if any(k in merged for k in info.keys()): - raise ValueError(f"Policy info keys conflict with action keys: {set(merged.keys()) & set(info.keys())}") - merged.update(info) - - return merged - - def _handle_reset(self, env_ids=None, options=None, **_: Any) -> dict[str, Any]: - print(f"[PolicyServer] handle reset: env_ids={env_ids}, options={options}") - status: dict[str, Any] = {"status": "reset_success"} - if hasattr(self._policy, "reset"): - resp = self._policy.reset(env_ids=env_ids, reset_options=options) - if isinstance(resp, dict): - status.update(resp) - return status - - def _validate_token(self, request: dict[str, Any]) -> bool: - if self._api_token is None: - return True - ok = request.get("api_token") == self._api_token - if not ok: - print("[PolicyServer] invalid api_token in request") - return ok - - def run(self) -> None: - addr = self._socket.getsockopt_string(zmq.LAST_ENDPOINT) - print(f"[PolicyServer] listening on {addr}, api_token={self._api_token!r}") - while self._running: - try: - raw = self._socket.recv() - print(f"[PolicyServer] received {len(raw)} bytes") - request = MessageSerializer.from_bytes(raw) - - if not isinstance(request, dict): - raise TypeError(f"Expected dict request, got {type(request)!r}") - - print(f"[PolicyServer] request keys: {list(request.keys())}") - - if not self._validate_token(request): - self._socket.send(MessageSerializer.to_bytes({"error": "Unauthorized: invalid api_token"})) - continue - - endpoint = request.get("endpoint", "get_action") - if "endpoint" not in request: - self._socket.send(MessageSerializer.to_bytes({"error": "Missing 'endpoint' in request"})) - continue - - endpoint = request["endpoint"] - - handler = self._endpoints.get(endpoint) - if handler is None: - raise ValueError(f"Unknown endpoint: {endpoint}") - print(f"[PolicyServer] dispatch endpoint='{endpoint}'") - - data = request.get("data", {}) or {} - if not isinstance(data, dict): - raise TypeError(f"Expected dict data, got {type(data)!r}") - - if handler.requires_input: - result = handler.handler(**data) - else: - result = handler.handler() - - resp_bytes = MessageSerializer.to_bytes(result) - print(f"[PolicyServer] sending response ({len(resp_bytes)} bytes)") - self._socket.send(resp_bytes) - except zmq.Again: - # timeout, loop again - continue - except Exception as exc: - import traceback - - print(f"[PolicyServer] Error: {exc}") - print(traceback.format_exc()) - self._socket.send(MessageSerializer.to_bytes({"error": str(exc)})) - - def close(self) -> None: - """Stop the main loop and close ZMQ resources.""" - self._running = False - try: - self._socket.close(0) - except Exception as exc: - print(f"[PolicyServer] socket.close() error: {exc}") - try: - self._context.term() - except Exception as exc: - print(f"[PolicyServer] context.term() error: {exc}") - - @staticmethod - def start( - policy: ServerSidePolicy, - host: str = "*", - port: int = 5555, - api_token: str | None = None, - timeout_ms: int = 15000, - ) -> None: - server = PolicyServer( - policy=policy, - host=host, - port=port, - api_token=api_token, - timeout_ms=timeout_ms, - ) - server.run() diff --git a/isaaclab_arena/remote_policy/remote_policy_config.py b/isaaclab_arena/remote_policy/remote_policy_config.py deleted file mode 100644 index a256f14cb..000000000 --- a/isaaclab_arena/remote_policy/remote_policy_config.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass -class RemotePolicyConfig: - """Configuration for using a remote PolicyServer.""" - - host: str - port: int - api_token: str | None = None - timeout_ms: int = 15000 diff --git a/isaaclab_arena/remote_policy/remote_policy_server_runner.py b/isaaclab_arena/remote_policy/remote_policy_server_runner.py deleted file mode 100644 index c96cd00b6..000000000 --- a/isaaclab_arena/remote_policy/remote_policy_server_runner.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - - -from __future__ import annotations - -import argparse -from importlib import import_module - -from isaaclab_arena.remote_policy.policy_server import PolicyServer -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy - - -def get_policy_cls(policy_type: str) -> type[ServerSidePolicy]: - """Dynamically import and return a ServerSidePolicy subclass. - - The policy_type argument must be a fully qualified Python path of the form: - "package.subpackage.module.ClassName" - """ - print(f"[remote_policy_server_runner] Importing server-side policy from: {policy_type}") - if "." not in policy_type: - raise ValueError( - "policy_type must be a dotted Python import path of the form " - "'module.submodule.ClassName', " - f"got: {policy_type!r}" - ) - module_path, class_name = policy_type.rsplit(".", 1) - module = import_module(module_path) - policy_cls = getattr(module, class_name) - return policy_cls - - -def build_base_parser() -> argparse.ArgumentParser: - """Build the base CLI parser for the remote policy server. - - This parser only contains arguments that are common to all server-side policies. - Policy-specific arguments are added later by the selected ServerSidePolicy subclass. - """ - parser = argparse.ArgumentParser("IsaacLab Arena Remote Policy Server") - - # Generic server options. - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=5555) - parser.add_argument("--api_token", type=str, default=None) - parser.add_argument("--timeout_ms", type=int, default=5000) - - # Which ServerSidePolicy implementation to run. - parser.add_argument( - "--policy_type", - type=str, - required=True, - help=( - "Dotted Python path of the server-side policy to run, e.g. " - "'isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy'." - ), - ) - return parser - - -def parse_args() -> argparse.Namespace: - """Parse CLI arguments in two stages. - - 1) Parse only the base arguments to discover which policy class to use. - 2) Let that class extend the parser with its own arguments, then parse again. - """ - # Stage 1: parse base args to get policy_type. - base_parser = build_base_parser() - base_args, _ = base_parser.parse_known_args() - - policy_cls = get_policy_cls(base_args.policy_type) - print(f"[remote_policy_server_runner] Requested server-side policy: {base_args.policy_type} -> {policy_cls}") - - # Stage 2: build a fresh parser, extend it with policy-specific arguments, then parse fully. - full_parser = build_base_parser() - if not hasattr(policy_cls, "add_args_to_parser"): - raise TypeError( - f"Server-side policy class {policy_cls} must define a static 'add_args_to_parser(parser)' method." - ) - full_parser = policy_cls.add_args_to_parser(full_parser) # type: ignore[assignment] - - args = full_parser.parse_args() - return args - - -def main() -> None: - """Entry point for running a remote policy server. - - The script: - 1) Parses CLI arguments in two stages. - 2) Instantiates the requested ServerSidePolicy via its from_args() helper. - 3) Wraps it in a PolicyServer and starts the RPC loop. - """ - args = parse_args() - - policy_cls = get_policy_cls(args.policy_type) - if not hasattr(policy_cls, "from_args"): - raise TypeError(f"Server-side policy class {policy_cls} must define a static 'from_args(args)' method.") - - # Construct the server-side policy from CLI arguments. - policy = policy_cls.from_args(args) # type: ignore[call-arg] - - # Start the RPC server. - server = PolicyServer( - policy=policy, - host=args.host, - port=args.port, - api_token=args.api_token, - timeout_ms=args.timeout_ms, - ) - server.run() - - -if __name__ == "__main__": - main() diff --git a/isaaclab_arena/remote_policy/server_side_policy.py b/isaaclab_arena/remote_policy/server_side_policy.py deleted file mode 100644 index 8b96f3bfb..000000000 --- a/isaaclab_arena/remote_policy/server_side_policy.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -# Copyright (c) 2025-2026, -# The Isaac Lab Arena Project Developers -# (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import argparse -from abc import ABC, abstractmethod -from typing import Any - -from isaaclab_arena.remote_policy.action_protocol import ActionMode, ActionProtocol - - -class ServerSidePolicy(ABC): - """Base class for server-side remote policies. - - This class defines: - * The protocol- and handshake-related API that the PolicyServer relies on. - * A minimal configuration hook via ``config_class`` and ``from_dict``. - * A CLI construction pattern via ``add_args_to_parser`` and ``from_args``, - mirroring the design of :class:`isaaclab_arena.policy.policy_base.PolicyBase` - on the client side. - - Concrete server-side policies (e.g. GR00T-based ones) should: - * Implement ``_build_protocol()`` and the core RPC methods. - * Optionally define a dataclass as ``config_class``. - * Implement ``add_args_to_parser(parser)`` and ``from_args(args)`` - so they can be instantiated directly from command-line arguments. - """ - - # Optional: subclasses can define this to enable from_dict() - config_class: type | None = None - - def __init__(self, config: Any | None = None) -> None: - """Base constructor for server-side policies. - - Args: - config: Optional configuration object (for example, a dataclass - instance). Subclasses are free to interpret this as needed. - """ - self.config = config - self._protocol: ActionProtocol | None = None - self._task_description: str | None = None - - # ------------------------------------------------------------------ - # Config helpers (mirroring PolicyBase.from_dict) - # ------------------------------------------------------------------ - - @classmethod - def from_dict(cls, config_dict: dict[str, Any]) -> ServerSidePolicy: - """Create a policy instance from a configuration dictionary. - - Path: dict -> ConfigDataclass -> Policy instance - - This mirrors :meth:`PolicyBase.from_dict` on the client side. - """ - if cls.config_class is None: - raise NotImplementedError(f"{cls.__name__} must define 'config_class' to use from_dict().") - - config = cls.config_class(**config_dict) # type: ignore[misc] - return cls(config) # type: ignore[call-arg] - - # ------------------------------------------------------------------ - # Protocol / handshake API - # ------------------------------------------------------------------ - - @abstractmethod - def _build_protocol(self) -> ActionProtocol: - """Subclasses must build and return an ActionProtocol instance.""" - raise NotImplementedError - - @property - def protocol(self) -> ActionProtocol: - """Return the ActionProtocol associated with this policy. - - The protocol is lazily constructed on first access via ``_build_protocol()``. - """ - if self._protocol is None: - self._protocol = self._build_protocol() - if self._protocol.mode is None: - raise ValueError(f"{self.__class__.__name__} has an ActionProtocol with mode=None, which is not allowed.") - return self._protocol - - def get_init_info(self, requested_action_mode: str) -> dict[str, Any]: - """Handle the initial handshake with the client. - - Checks that the requested action mode is valid and supported by - this policy's ActionProtocol, and returns either an error status - or the protocol configuration as a plain dictionary. - """ - proto = self.protocol - - try: - requested_mode_enum = ActionMode(requested_action_mode) - except ValueError: - return { - "status": "invalid_action_mode", - "message": f"Requested action_mode={requested_action_mode!r} is invalid.", - } - - if requested_mode_enum is not proto.mode: - return { - "status": "unsupported_action_mode", - "message": ( - f"Requested action_mode={requested_mode_enum.value!r} " - "is not supported by this policy. " - f"Supported: {proto.mode.value!r}." - ), - } - - return { - "status": "success", - "config": proto.to_dict(), - } - - # ------------------------------------------------------------------ - # Core RPC methods (to be used by PolicyServer) - # ------------------------------------------------------------------ - - @abstractmethod - def get_action( - self, - observation: dict[str, Any], - ) -> dict[str, Any]: - """Compute one or more actions given an observation payload. - - Args: - observation: Flat observation dictionary received from the client. - - Returns: - A dictionary that must contain at least an ``"action"`` entry - whose structure is compatible with the negotiated ActionProtocol. - """ - raise NotImplementedError - - def reset(self) -> None: - """Reset the policy state. - - Subclasses may override this if they maintain per-environment or - global state that needs to be cleared between episodes. - """ - ... - - def set_task_description( - self, - task_description: str | None, - ) -> dict[str, Any]: - """Set the task description and return a small status/config payload. - - The default implementation stores the description locally and - echoes it back. Subclasses can override this to perform additional - updates or validation. - """ - self._task_description = task_description - return {"task_description": self._task_description or ""} - - # ------------------------------------------------------------------ - # Shared helpers - # ------------------------------------------------------------------ - - def unpack_observation(self, flat_obs: dict[str, Any]) -> dict[str, Any]: - """Convert a flat dotted-key observation dict into a nested dict. - - For example, a key ``"camera_obs.pov.rgb"`` becomes - ``nested["camera_obs"]["pov"]["rgb"]``. - """ - nested: dict[str, Any] = {} - for key_path, value in flat_obs.items(): - cur = nested - parts = key_path.split(".") - for k in parts[:-1]: - cur = cur.setdefault(k, {}) - cur[parts[-1]] = value - return nested - - # ------------------------------------------------------------------ - # CLI helpers (to mirror PolicyBase.add_args_to_parser / from_args) - # ------------------------------------------------------------------ - - @staticmethod - @abstractmethod - def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add policy-specific CLI arguments to the parser. - - Server-side policies are expected to implement this so that - :mod:`remote_policy_server_runner` can delegate CLI argument - definitions to the selected policy class. - """ - raise NotImplementedError("ServerSidePolicy subclasses must implement add_args_to_parser().") - - @staticmethod - @abstractmethod - def from_args(args: argparse.Namespace) -> ServerSidePolicy: - """Construct a server-side policy instance from CLI arguments. - - This mirrors the ``from_args(args)`` pattern used by client-side - policies deriving from :class:`PolicyBase`. - """ - raise NotImplementedError("ServerSidePolicy subclasses must implement from_args(args).") From 34182ba54c6776a67a03915c52965cb90c969e4d Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 15:11:38 +0200 Subject: [PATCH 02/17] Remove ClientSidePolicy and ActionChunkingClientSidePolicy These classes wrapped Arena's now-deleted PolicyClient. Framework wrappers will use each framework's native client instead. Signed-off-by: Clemens Volk --- isaaclab_arena/policy/__init__.py | 1 - .../policy/action_chunking_client.py | 157 -------------- isaaclab_arena/policy/client_side_policy.py | 200 ------------------ 3 files changed, 358 deletions(-) delete mode 100644 isaaclab_arena/policy/action_chunking_client.py delete mode 100644 isaaclab_arena/policy/client_side_policy.py diff --git a/isaaclab_arena/policy/__init__.py b/isaaclab_arena/policy/__init__.py index 66caa51a1..4a0bb1e9d 100644 --- a/isaaclab_arena/policy/__init__.py +++ b/isaaclab_arena/policy/__init__.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: Apache-2.0 from .action_chunking import ActionChunkingState -from .action_chunking_client import * from .replay_action_policy import * from .rsl_rl_action_policy import * from .zero_action_policy import * diff --git a/isaaclab_arena/policy/action_chunking_client.py b/isaaclab_arena/policy/action_chunking_client.py deleted file mode 100644 index 3cdc5856a..000000000 --- a/isaaclab_arena/policy/action_chunking_client.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import argparse -import gymnasium as gym -import torch -from typing import Any - -from isaaclab_arena.policy.action_chunking import ActionChunkingState -from isaaclab_arena.policy.client_side_policy import ClientSidePolicy -from isaaclab_arena.remote_policy.action_protocol import ChunkingActionProtocol -from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig - - -class ActionChunkingClientSidePolicy(ClientSidePolicy): - """Client-side policy that consumes fixed-length action chunks sequentially.""" - - def __init__( - self, - config: Any, - num_envs: int, - device: str, - remote_config: RemotePolicyConfig, - ) -> None: - super().__init__(config=config, remote_config=remote_config, protocol_cls=ChunkingActionProtocol) - - self._num_envs = num_envs - self._device = device - - assert self.protocol.action_chunk_length <= self.protocol.action_horizon, ( - f"protocol.action_chunk_length ({self.protocol.action_chunk_length}) " - f"must be <= protocol.action_horizon ({self.protocol.action_horizon})" - ) - # Shared chunking state (unified with local Gr00tClosedloopPolicy) - self._chunking_state = ActionChunkingState( - num_envs=self._num_envs, - action_chunk_length=self.protocol.action_chunk_length, - action_horizon=self.protocol.action_horizon, - action_dim=self.protocol.action_dim, - device=self._device, - dtype=torch.float32, - ) - - self.task_description: str | None = None - - # ---------------------- CLI ---------------------------------------- - - @staticmethod - def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add CLI arguments for ActionChunkingClientSidePolicy.""" - # Shared remote policy args. - parser = ClientSidePolicy.add_remote_args_to_parser(parser) - - # Policy-specific args. - group = parser.add_argument_group( - "Action Chunking Client Policy", - "Arguments for client-side action chunking policy.", - ) - group.add_argument( - "--policy_device", - type=str, - default="cuda", - help="Device to use for the policy-related operations.", - ) - return parser - - @staticmethod - def from_args(args: argparse.Namespace) -> ActionChunkingClientSidePolicy: - """Create an ActionChunkingClientSidePolicy from CLI arguments.""" - remote_config = ClientSidePolicy.build_remote_config_from_args(args) - return ActionChunkingClientSidePolicy( - config=None, - num_envs=args.num_envs, - device=args.policy_device, - remote_config=remote_config, - ) - - # ---------------------- Task description ---------------------------- - - def set_task_description(self, task_description: str | None) -> str: - """Set the task description on both client-side and remote policy.""" - self.task_description = task_description - # Always notify the server so it can set _task_description (server uses config default when None) - self.remote_client.call_endpoint( - "set_task_description", - data={"task_description": task_description}, - requires_input=True, - ) - return self.task_description or "" - - # ---------------------- Chunking logic ------------------------------ - - def _request_new_chunk( - self, - observation: dict[str, Any], - ) -> torch.Tensor: - """Request a new action chunk from the remote policy and validate it.""" - protocol = self.protocol - packed_obs = self.pack_observation_for_server(observation) - - resp = self.remote_client.get_action(packed_obs) - if not isinstance(resp, dict): - raise TypeError(f"Expected dict from get_action, got {type(resp)!r}") - if "action" not in resp: - raise KeyError("Remote response does not contain key 'action' for ActionChunkingClientSidePolicy.") - - raw_chunk = resp["action"] - if not isinstance(raw_chunk, torch.Tensor): - raw_chunk = torch.tensor(raw_chunk, dtype=torch.float32, device=self._device) - else: - raw_chunk = raw_chunk.to(self._device, dtype=torch.float32) - - if raw_chunk.shape[0] != self._num_envs: - raise ValueError(f"Expected batch size {self._num_envs}, got {raw_chunk.shape[0]}") - if raw_chunk.shape[1] < protocol.action_chunk_length: - raise ValueError( - f"Expected at least {protocol.action_chunk_length} actions per chunk, got {raw_chunk.shape[1]}" - ) - if raw_chunk.shape[2] != protocol.action_dim: - raise ValueError(f"Expected action_dim {protocol.action_dim}, got {raw_chunk.shape[2]}") - - return raw_chunk - - def get_action( - self, - env: gym.Env, - observation: gym.spaces.Dict, - ) -> torch.Tensor: - """Return one action per env step, consuming action chunks sequentially.""" - - def fetch_chunk() -> torch.Tensor: - return self._request_new_chunk(observation) - - return self._chunking_state.get_action(fetch_chunk) - - def reset(self, env_ids: torch.Tensor | None = None) -> None: - """Reset client-side chunking state and remote policy state.""" - if env_ids is None: - env_ids = torch.arange( - self._num_envs, - device=self._device, - dtype=torch.long, - ) - - self._chunking_state.reset(env_ids) - - # Reset remote state via ClientSidePolicy. - super().reset(env_ids=env_ids) diff --git a/isaaclab_arena/policy/client_side_policy.py b/isaaclab_arena/policy/client_side_policy.py deleted file mode 100644 index 44068dc8d..000000000 --- a/isaaclab_arena/policy/client_side_policy.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import argparse -import torch -from typing import Any - -from isaaclab_arena.policy.policy_base import PolicyBase -from isaaclab_arena.remote_policy.action_protocol import ActionMode, ActionProtocol -from isaaclab_arena.remote_policy.policy_client import PolicyClient -from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig - - -class ClientSidePolicy(PolicyBase): - """Base class for policies that query a remote policy server. - - Responsibilities: - - Manage RemotePolicyConfig and PolicyClient. - - Handshake with the server via get_init_info(). - - Provide observation packing based on observation_keys. - - Provide shared CLI helpers for remote-related arguments. - - Subclasses: - - Must implement get_action(). - """ - - def __init__(self, config: Any, remote_config: RemotePolicyConfig, protocol_cls: type[ActionProtocol]) -> None: - super().__init__(config=config) - - if protocol_cls.MODE is None: - raise ValueError(f"{protocol_cls.__name__}.MODE must be defined as an ActionMode.") - - self.protocol_cls = protocol_cls - requested_action_mode: ActionMode = protocol_cls.MODE - - self._remote_config = remote_config - self._client = PolicyClient(config=self._remote_config) - - # 1) Ping server to ensure connectivity. - if not self._client.ping(): - raise RuntimeError( - f"Failed to connect to remote policy server at {self._remote_config.host}:{self._remote_config.port}." - ) - - # 2) Handshake: send requested_action_mode, parse response. - init_resp = self._client.get_init_info(requested_action_mode=requested_action_mode.value) - - if not isinstance(init_resp, dict): - raise TypeError(f"Expected dict from get_init_info, got {type(init_resp)!r}") - - status = init_resp.get("status", "error") - if status != "success": - message = init_resp.get("message", "no message") - raise RuntimeError(f"Remote policy get_init_info failed with status='{status}': {message}") - - cfg_dict = init_resp.get("config") - if not isinstance(cfg_dict, dict): - raise TypeError( - f"Remote policy get_init_info must return a 'config' dict inside the response, got {type(cfg_dict)!r}" - ) - - self._protocol: ActionProtocol = self.protocol_cls.from_dict(cfg_dict) - - # ---------------------- properties ---------------------------------- - @property - def protocol(self) -> ActionProtocol: - return self._protocol - - @property - def action_mode(self) -> ActionMode: - return self._protocol.mode - - @property - def action_dim(self) -> int: - return self._protocol.action_dim - - @property - def observation_keys(self) -> list[str]: - return list(self._protocol.observation_keys) - - @property - def remote_config(self) -> RemotePolicyConfig: - return self._remote_config - - @property - def remote_client(self) -> PolicyClient: - return self._client - - @property - def is_remote(self) -> bool: - return True - - # ---------------------- observation packing ------------------------- - @staticmethod - def _get_nested_observation(observation: dict[str, Any], key_path: str) -> Any: - """Get a nested value from a dict using 'a.b.c' path.""" - cur: Any = observation - - for k in key_path.split("."): - cur = cur[k] - return cur - - def pack_observation_for_server( - self, - observation: dict[str, Any], - ) -> dict[str, Any]: - """Pack selected observation entries into a flat CPU dict for the server. - - Uses `self.observation_keys` from ClientSidePolicyConfig and: - - Extracts values using nested key paths. - - Moves torch.Tensor values to CPU numpy arrays. - """ - packed: dict[str, Any] = {} - for key_path in self.observation_keys: - value = self._get_nested_observation(observation, key_path) - if isinstance(value, torch.Tensor): - value = value.detach().cpu().numpy() - packed[key_path] = value - return packed - - def reset(self, env_ids: torch.Tensor | None = None) -> None: - """Optionally reset remote policy state. - - Client-side state should be reset in subclasses. - """ - env_ids_list = None - if env_ids is not None: - env_ids_list = env_ids.detach().cpu().tolist() - self._client.reset(env_ids=env_ids_list, options=None) - - def shutdown_remote(self, kill_server: bool = False) -> None: - """Clean up the remote client and optionally stop the remote server.""" - if kill_server: - try: - self._client.call_endpoint("kill", requires_input=False) - except Exception as exc: - print(f"[ClientSidePolicy] Failed to send kill to remote server: {exc}") - self._client.close() - - # ---------------------- shared CLI helpers -------------------------- - - @staticmethod - def add_remote_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add shared remote-policy arguments to the parser. - - This should be called from subclass.add_args_to_parser(). - """ - group = parser.add_argument_group( - "Remote Policy", - "Arguments for connecting to a remote policy server.", - ) - group.add_argument( - "--remote_host", - type=str, - default=None, - required=True, - help="Remote policy server host.", - ) - group.add_argument( - "--remote_port", - type=int, - default=5555, - help="Remote policy server port.", - ) - group.add_argument( - "--remote_api_token", - type=str, - default=None, - help="API token for the remote policy server.", - ) - group.add_argument( - "--remote_timeout_ms", - type=int, - default=15000, - help="Timeout (ms) for remote policy requests.", - ) - group.add_argument( - "--remote_kill_on_exit", - action="store_true", - help="If set, send a 'kill' request to the remote policy server when the run finishes.", - ) - return parser - - @staticmethod - def build_remote_config_from_args(args: argparse.Namespace) -> RemotePolicyConfig: - """Construct RemotePolicyConfig from CLI arguments. - - Assumes add_remote_args_to_parser() has been called on the parser. - """ - - return RemotePolicyConfig( - host=args.remote_host, - port=args.remote_port, - api_token=args.remote_api_token, - timeout_ms=args.remote_timeout_ms, - ) From b56f0d8b63a1166bf9bc508f70d8c6c75c23a857 Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 15:11:59 +0200 Subject: [PATCH 03/17] Remove Gr00tRemoteServerSidePolicy This class wrapped GR00T's Gr00tPolicy inside Arena's ServerSidePolicy abstraction. The GR00T repo already provides its own server via gr00t/eval/run_gr00t_server.py. Signed-off-by: Clemens Volk --- .../policy/gr00t_remote_policy.py | 227 ------------------ 1 file changed, 227 deletions(-) delete mode 100644 isaaclab_arena_gr00t/policy/gr00t_remote_policy.py diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_policy.py deleted file mode 100644 index 3dbbed49b..000000000 --- a/isaaclab_arena_gr00t/policy/gr00t_remote_policy.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import argparse -import os -import sys -from dataclasses import dataclass -from typing import Any - -# Same as local (Isaac Sim): GR00T deps appended via .pth; do not prepend so system packages (numpy, cv2, tokenizers 0.21) are used first. -_GROOT_DEPS_DIR = os.environ.get("GROOT_DEPS_DIR") -if _GROOT_DEPS_DIR and _GROOT_DEPS_DIR not in sys.path: - sys.path.append(_GROOT_DEPS_DIR) - -from gr00t.policy.gr00t_policy import Gr00tPolicy - -from isaaclab_arena.remote_policy.action_protocol import ChunkingActionProtocol -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy -from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode -from isaaclab_arena_gr00t.policy.gr00t_core import ( - Gr00tBasePolicyArgs, - build_gr00t_action_tensor, - build_gr00t_policy_observations, - compute_action_dim, - extract_obs_numpy_from_packed, - load_gr00t_joint_configs, - load_gr00t_policy_from_config, -) -from isaaclab_arena_gr00t.utils.io_utils import create_config_from_yaml, load_gr00t_modality_config_from_file, to_numpy - - -@dataclass -class Gr00tRemotePolicyArgs(Gr00tBasePolicyArgs): - """Configuration for Gr00tRemoteServerSidePolicy. - - Reuses policy_config_yaml_path and policy_device from the base. - """ - - @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> Gr00tRemotePolicyArgs: - return cls( - policy_config_yaml_path=args.policy_config_yaml_path, - policy_device=args.policy_device, - ) - - -class Gr00tRemoteServerSidePolicy(ServerSidePolicy): - """Server-side wrapper around Gr00tPolicy.""" - - config_class = Gr00tRemotePolicyArgs - - def __init__(self, config: Gr00tRemotePolicyArgs) -> None: - super().__init__(config) - - print(f"[Gr00tRemoteServerSidePolicy] loading config from: {config.policy_config_yaml_path}") - self.policy_config: Gr00tClosedloopPolicyConfig = create_config_from_yaml( - config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig - ) - print( - "[Gr00tRemoteServerSidePolicy] config:\n" - f" model_path = {self.policy_config.model_path}\n" - f" embodiment_tag = {self.policy_config.embodiment_tag}\n" - f" task_mode_name = {self.policy_config.task_mode_name}\n" - f" action_horizon = {self.policy_config.action_horizon}\n" - f" action_chunk_len = {self.policy_config.action_chunk_length}\n" - f" pov_cam_name_sim = {self.policy_config.pov_cam_name_sim}\n" - f" policy_device = {self.policy_config.policy_device}\n" - ) - - self.device = config.policy_device - self.task_mode = TaskMode(self.policy_config.task_mode_name) - - # Joint configurations - ( - self.policy_joints_config, - self.robot_action_joints_config, - self.robot_state_joints_config, - ) = load_gr00t_joint_configs(self.policy_config) - - # Modality config - self.modality_configs = load_gr00t_modality_config_from_file( - self.policy_config.modality_config_path, - self.policy_config.embodiment_tag, - ) - - # Action dimensions - self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) - self.action_chunk_length = self.policy_config.action_chunk_length - self.action_horizon = self.policy_config.action_horizon - - # Underlying GR00T policy - self.policy: Gr00tPolicy = load_gr00t_policy_from_config(self.policy_config) - print("[Gr00tRemoteServerSidePolicy] Gr00tPolicy loaded successfully") - - # Required observation keys for protocol (one key per camera) - self.camera_names: list[str] = self.policy_config.pov_cam_name_sim - self.required_observation_keys: list[str] = [f"camera_obs.{cam}" for cam in self.camera_names] + [ - "policy.robot_joint_pos" - ] - - # Task description will be set via set_task_description RPC - self._task_description: str | None = None - - # ---------------------- CLI helpers (server-side) ------------------- - - @staticmethod - def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add server-side GR00T remote policy arguments.""" - group = parser.add_argument_group( - "Gr00t Remote Server Policy", - "Arguments for GR00T-based server-side remote policy.", - ) - group.add_argument( - "--policy_config_yaml_path", - type=str, - required=True, - help="Path to the GR00T closedloop policy config YAML file.", - ) - group.add_argument( - "--policy_device", - type=str, - default="cuda", - help="Device to use for server-side GR00T inference (default: cuda).", - ) - return parser - - @staticmethod - def from_args(args: argparse.Namespace) -> Gr00tRemoteServerSidePolicy: - """Create a Gr00tRemoteServerSidePolicy from CLI arguments.""" - config = Gr00tRemotePolicyArgs.from_cli_args(args) - return Gr00tRemoteServerSidePolicy(config) - - # ------------ protocol ------------ - - def _build_protocol(self) -> ChunkingActionProtocol: - proto = ChunkingActionProtocol( - action_dim=self.action_dim, - observation_keys=self.required_observation_keys, - action_chunk_length=self.action_chunk_length, - action_horizon=self.action_horizon, - ) - print(f"[Gr00tRemoteServerSidePolicy] protocol mode = {proto.mode.value}") - return proto - - # ------------------------------------------------------------------ # - # Helper methods - # ------------------------------------------------------------------ # - - def _build_policy_observations( - self, - observation: dict[str, Any], - camera_names: list[str], - ) -> dict[str, Any]: - """Convert packed numpy observation into numpy GR00T policy inputs. - - Uses ``extract_obs_numpy_from_packed`` as the single explicit - data-extraction boundary for the remote pipeline, then delegates - to the shared core preprocessing. - """ - assert self._task_description is not None, "Task description is not set" - - rgb_list_np, joint_pos_sim_np = extract_obs_numpy_from_packed( - observation, camera_names, self.unpack_observation - ) - - return build_gr00t_policy_observations( - rgb_list_np=rgb_list_np, - joint_pos_sim_np=joint_pos_sim_np, - task_description=self._task_description, - policy_config=self.policy_config, - robot_state_joints_config=self.robot_state_joints_config, - policy_joints_config=self.policy_joints_config, - modality_configs=self.modality_configs, - ) - - # ------------------------------------------------------------------ # - # ServerSidePolicy interface - # ------------------------------------------------------------------ # - - def set_task_description(self, task_description: str | None) -> dict[str, Any]: - if task_description is None: - task_description = self.policy_config.language_instruction - if not task_description: - raise ValueError( - "No language instruction provided. Set 'language_instruction' in the job config, " - "pass --language_instruction on the CLI, or define 'task_description' on the task class." - ) - self._task_description = task_description - return {"status": "ok"} - - def get_action( - self, observation: dict[str, Any], options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: - # 1) Shared numpy-based preprocessing - policy_observations = self._build_policy_observations(observation, self.camera_names) - - # 2) GR00T forward pass - robot_action_policy, _ = self.policy.get_action(policy_observations) - - # 3) Postprocessing (shared with closedloop) - action_tensor = build_gr00t_action_tensor( - robot_action_policy=robot_action_policy, - task_mode=self.task_mode, - policy_joints_config=self.policy_joints_config, - robot_action_joints_config=self.robot_action_joints_config, - device=self.device, - embodiment_tag=self.policy_config.embodiment_tag, - ) - - assert action_tensor.shape[1] >= self.action_chunk_length - - action_chunk = to_numpy(action_tensor) - # NOTE(huikang, 2026-02-06): Currently, it seems that the output action length is action_horizon, - # but the action chunk post-process actually handles a length of action_chunk_length. - # It looks like we can transmit a tensor of length action_chunk_length. At the moment, action_chunk_length and action_horizon are the same. - action: dict[str, Any] = {"action": action_chunk} - info: dict[str, Any] = {} - return action, info - - def reset(self, env_ids: list[int] | None = None, reset_options: dict[str, Any] | None = None) -> dict[str, Any]: - # placeholder for future reset options from GR00T repo - self.policy.reset() - return {"status": "reset_success"} From aa9c495c920915c53e6ad1bb0a958a9806b08d0b Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 15:12:17 +0200 Subject: [PATCH 04/17] Remove tests for deleted remote policy classes test_policy_client.py and test_action_chunking_client.py tested the now-removed PolicyClient and ActionChunkingClientSidePolicy. Signed-off-by: Clemens Volk --- .../tests/test_action_chunking_client.py | 187 ------------------ isaaclab_arena/tests/test_policy_client.py | 142 ------------- 2 files changed, 329 deletions(-) delete mode 100644 isaaclab_arena/tests/test_action_chunking_client.py delete mode 100644 isaaclab_arena/tests/test_policy_client.py diff --git a/isaaclab_arena/tests/test_action_chunking_client.py b/isaaclab_arena/tests/test_action_chunking_client.py deleted file mode 100644 index aab81f67f..000000000 --- a/isaaclab_arena/tests/test_action_chunking_client.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -# Copyright (c) 2025-2026, -# The Isaac Lab Arena Project Developers -# (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import numpy as np -import threading -import time -from typing import Any - -import pytest - -from isaaclab_arena.remote_policy.action_protocol import ActionProtocol, ChunkingActionProtocol -from isaaclab_arena.remote_policy.policy_server import PolicyServer -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy -from isaaclab_arena.tests.utils.constants import TestConstants -from isaaclab_arena.tests.utils.subprocess import run_subprocess - -HEADLESS = True -NUM_STEPS = 2 -HOST = "127.0.0.1" -PORT = 5563 # test-only port, avoid conflicts - - -# ====================================================================================== -# Dummy server-side policy using the real ChunkingActionProtocol -# ====================================================================================== - - -class _DummyChunkingServerPolicy(ServerSidePolicy): - """Server-side policy that uses ChunkingActionProtocol and returns fixed chunks.""" - - def __init__(self, action_dim: int = 50, chunk_length: int = 4) -> None: - super().__init__(config=None) - self._action_dim = action_dim - self._chunk_length = chunk_length - self._counter = 0 - - def _build_protocol(self) -> ActionProtocol: - return ChunkingActionProtocol( - action_dim=self._action_dim, - observation_keys=["policy.robot_joint_pos"], - action_chunk_length=self._chunk_length, - action_horizon=self._chunk_length, - ) - - def get_action( - self, - observation: dict[str, Any], - options: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Return (batch, chunk_length, action_dim) array with a simple pattern. - - The options argument is accepted to match PolicyServer._handle_get_action, - but is not used in this dummy implementation. - """ - first_key = next(iter(observation.keys())) - batch = int(np.shape(observation[first_key])[0]) - - base_value = float(self._counter) - self._counter += 1 - - chunk = np.full( - (batch, self._chunk_length, self._action_dim), - fill_value=base_value, - dtype=np.float32, - ) - # IMPORTANT: return a dict containing "action" and "info" - return {"action": chunk}, {} - - # NEW: match what PolicyServer._handle_reset expects - def reset(self, env_ids: list[int] | None = None, reset_options: dict[str, Any] | None = None) -> dict[str, Any]: - """Reset policy state for the given environment ids. - - The implementation here is trivial; it just returns an OK status - and does not keep any per-env state. - """ - return {"status": "ok"} - - @staticmethod - def add_args_to_parser(parser: Any) -> Any: - return parser - - @staticmethod - def from_args(args: Any) -> _DummyChunkingServerPolicy: - return _DummyChunkingServerPolicy() - - -# ====================================================================================== -# Helper to start/stop a PolicyServer in background -# ====================================================================================== - - -@pytest.fixture -def running_dummy_chunking_server() -> PolicyServer: - """Start a PolicyServer with _DummyChunkingServerPolicy on localhost.""" - policy = _DummyChunkingServerPolicy(chunk_length=4) - server = PolicyServer( - policy=policy, - host=HOST, - port=PORT, - api_token=None, - timeout_ms=2_000, - ) - - thread = threading.Thread(target=server.run, daemon=True) - thread.start() - # Give the server a short time to bind and start. - time.sleep(0.2) - - try: - yield server - finally: - # Ask the server to stop and wait for the thread. - server.running = False - thread.join(timeout=5.0) - - if hasattr(server, "close"): - server.close() - assert not thread.is_alive() - - -# ====================================================================================== -# Helper to call policy_runner (same style as existing tests) -# ====================================================================================== - - -def _run_policy_runner_with_action_chunking_client() -> None: - """Run policy_runner.py with ActionChunkingClientSidePolicy in remote mode. - - The remote host/port are set to the dummy server started by the fixture. - """ - args: list[str] = [ - TestConstants.python_path, - f"{TestConstants.evaluation_dir}/policy_runner.py", - ] - - args.extend([ - "--policy_type", - "isaaclab_arena.policy.action_chunking_client.ActionChunkingClientSidePolicy", - ]) - - args.extend([ - "--remote_host", - HOST, - "--remote_port", - str(PORT), - "--remote_kill_on_exit", - ]) - - args.extend(["--num_steps", str(NUM_STEPS)]) - if HEADLESS: - args.append("--headless") - - args.append("galileo_g1_locomanip_pick_and_place") - args.extend(["--embodiment", "g1_wbc_joint"]) - args.extend(["--object", "brown_box"]) - - run_subprocess(args) - - -# ====================================================================================== -# Test -# ====================================================================================== - - -@pytest.mark.with_subprocess -def test_action_chunking_client_end_to_end_with_dummy_chunking_server( - running_dummy_chunking_server: PolicyServer, -) -> None: - """End-to-end test: dummy chunking server + ActionChunkingClientSidePolicy + policy_runner. - - This verifies that: - - The dummy PolicyServer using ChunkingActionProtocol can be reached on HOST:PORT. - - ActionChunkingClientSidePolicy can connect to it via policy_runner.py. - - The process exits successfully for a short rollout. - """ - _run_policy_runner_with_action_chunking_client() diff --git a/isaaclab_arena/tests/test_policy_client.py b/isaaclab_arena/tests/test_policy_client.py deleted file mode 100644 index 42fc18b17..000000000 --- a/isaaclab_arena/tests/test_policy_client.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -import threading -import time -from typing import Any - -import pytest -import zmq - -from isaaclab_arena.remote_policy.message_serializer import MessageSerializer -from isaaclab_arena.remote_policy.policy_client import PolicyClient -from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig - - -class _DummyServer: - """Minimal test server that emulates a subset of PolicyServer behavior. - - It only understands the endpoints used by PolicyClient tests and always - responds with well-formed msgpack-encoded dictionaries. - """ - - def __init__(self, host: str = "127.0.0.1", port: int = 5557, api_token: str | None = None) -> None: - self._host = host - self._port = port - self._api_token = api_token - self._context = zmq.Context() - self._socket = self._context.socket(zmq.REP) - self._running = False - self._thread: threading.Thread | None = None - - def start(self) -> None: - """Start the server loop in a background thread.""" - bind_addr = f"tcp://{self._host}:{self._port}" - self._socket.bind(bind_addr) - self._running = True - self._thread = threading.Thread(target=self._loop, daemon=True) - self._thread.start() - - def stop(self) -> None: - """Stop the server loop and close the socket.""" - self._running = False - if self._thread is not None: - self._thread.join(timeout=5.0) - self._socket.close(0) - self._context.term() - - def _loop(self) -> None: - """Event loop that receives one request and sends one response.""" - while self._running: - try: - message = self._socket.recv(flags=zmq.NOBLOCK) - except zmq.Again: - time.sleep(0.01) - continue - - request: dict[str, Any] = MessageSerializer.from_bytes(message) - - # Real code uses "api_token" on the wire. - if self._api_token is not None: - if request.get("api_token") != self._api_token: - response: dict[str, Any] = {"error": "invalid apitoken"} - self._socket.send(MessageSerializer.to_bytes(response)) - continue - - endpoint = request.get("endpoint", "") - data = request.get("data", {}) or {} - - if endpoint == "get_action": - # Return a minimal valid action payload; client expects a dict. - resp = {"action": [[0.0, 1.0, 2.0]], "info": {"dummy": True}} - elif endpoint == "get_init_info": - resp = {"obs_keys": ["rgb", "depth"], "action_dim": 3} - elif endpoint == "set_task_description": - desc = data.get("task_description", "") - resp = {"status": "ok", "echo": desc} - elif endpoint == "ping": - resp = {"status": "alive"} - else: - resp = {"error": f"unknown endpoint {endpoint!r}"} - - self._socket.send(MessageSerializer.to_bytes(resp)) - - -@pytest.fixture -def dummy_server() -> _DummyServer: - """Fixture that starts a dummy server and tears it down after the test.""" - server = _DummyServer(host="127.0.0.1", port=5557, api_token="SECRET") - server.start() - # Give the background thread a short time to bind the socket. - time.sleep(0.1) - try: - yield server - finally: - server.stop() - - -def test_policy_client_call_endpoint_and_get_action(dummy_server: _DummyServer) -> None: - """PolicyClient should be able to call endpoints and parse responses.""" - config = RemotePolicyConfig(host="127.0.0.1", port=5557, api_token="SECRET", timeout_ms=2000) - client = PolicyClient(config=config) - - # Test ping endpoint without input. - resp = client.call_endpoint(endpoint="ping", data=None, requires_input=False) - assert isinstance(resp, dict) - assert resp.get("status") == "alive" - - # Test get_action endpoint with dummy observation. - action_resp = client.get_action({ - "rgb": "dummy", # Content does not matter for this dummy server. - }) - assert isinstance(action_resp, dict) - assert "action" in action_resp - assert "info" in action_resp - - action = action_resp["action"] - assert isinstance(action, list) - assert len(action) == 1 - assert len(action[0]) == 3 - - client.close() - - -def test_policy_client_get_init_info_and_set_task_description(dummy_server: _DummyServer) -> None: - """get_init_info and set_task_description should return dictionaries.""" - config = RemotePolicyConfig(host="127.0.0.1", port=5557, api_token="SECRET", timeout_ms=2000) - client = PolicyClient(config=config) - - init_info = client.get_init_info({"dummy": True}) - assert isinstance(init_info, dict) - assert "obs_keys" in init_info - assert "action_dim" in init_info - - desc = "open the microwave door" - status = client.set_task_description(desc) - assert isinstance(status, dict) - assert status.get("status") == "ok" - assert status.get("echo") == desc - - client.close() From 7281e8c4fedd3985244806ac0ae3c5f1f278243a Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 15:12:34 +0200 Subject: [PATCH 05/17] Remove is_remote property from PolicyBase The remote/local distinction is no longer an Arena-level concern. Each framework wrapper handles its own lifecycle internally. Signed-off-by: Clemens Volk --- isaaclab_arena/policy/policy_base.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/isaaclab_arena/policy/policy_base.py b/isaaclab_arena/policy/policy_base.py index bf594ea56..d927c8d72 100644 --- a/isaaclab_arena/policy/policy_base.py +++ b/isaaclab_arena/policy/policy_base.py @@ -86,11 +86,6 @@ def length(self) -> int | None: """Get the length of the policy (for dataset-driven policies).""" pass - @property - def is_remote(self) -> bool: - """Check if policy is run remotely.""" - return False - @staticmethod @abstractmethod def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: From d20d7f58ddd6a24f81b3459d72cb558cd3da0c15 Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 15:13:03 +0200 Subject: [PATCH 06/17] Remove remote policy shutdown from policy_runner.py The is_remote/shutdown_remote pattern is removed. Framework wrappers handle their own client lifecycle internally. Signed-off-by: Clemens Volk --- isaaclab_arena/evaluation/policy_runner.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/isaaclab_arena/evaluation/policy_runner.py b/isaaclab_arena/evaluation/policy_runner.py index 6859e65dd..be2d24ce3 100644 --- a/isaaclab_arena/evaluation/policy_runner.py +++ b/isaaclab_arena/evaluation/policy_runner.py @@ -204,18 +204,10 @@ def main(): # Each rank prints its own metrics as it can be different due to random seed print(f"[Rank {local_rank}/{world_size}] Metrics: {metrics}") - # NOTE(huikang, 2025-12-30)Explicitly clean up the remote policy client / server. - # Do NOT rely on a __del__ destructor in policy for this, since destructors are - # triggered implicitly and their execution time (or even whether they run) - # is not guaranteed, which makes resource cleanup unreliable. - if policy.is_remote: - policy.shutdown_remote(kill_server=args_cli.remote_kill_on_exit) - # Close the environment. env.close() if __name__ == "__main__": - # TODO(xinjie.yao, 2026.03.31): Remove it after policy sever-client is implemented properly in v0.3. ensure_groot_deps_in_path() main() From 722bb3d348dfd8b09047af00cd955ca98935cb56 Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 15:13:25 +0200 Subject: [PATCH 07/17] Update comments referencing deleted remote classes Signed-off-by: Clemens Volk --- isaaclab_arena/policy/action_chunking.py | 4 ++-- isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/isaaclab_arena/policy/action_chunking.py b/isaaclab_arena/policy/action_chunking.py index fbfb32755..a2b410e42 100644 --- a/isaaclab_arena/policy/action_chunking.py +++ b/isaaclab_arena/policy/action_chunking.py @@ -14,8 +14,8 @@ class ActionChunkingState: """Holds chunk buffer, per-env index, and refill flag; provides get_action(fetch_chunk_fn). - Used by both Gr00tClosedloopPolicy (local) and ActionChunkingClientSidePolicy (remote) - so chunking behavior is identical. + Used by Gr00tClosedloopPolicy and any framework-specific remote wrapper + so chunking behavior is identical across local and remote policies. """ def __init__( diff --git a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py index 8186ff709..7b5afa824 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py @@ -105,7 +105,7 @@ def __init__(self, config: Gr00tClosedloopPolicyArgs): self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) self.action_chunk_length = self.policy_config.action_chunk_length - # Shared chunking state (unified with remote ActionChunkingClientSidePolicy) + # Shared chunking state (reused by remote framework wrappers) self._chunking_state = ActionChunkingState( num_envs=self.num_envs, action_chunk_length=self.action_chunk_length, From f2dd1b7914ce6103beeeadc4065b71703fbe3947 Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 15:13:44 +0200 Subject: [PATCH 08/17] Remove Arena's GR00T policy server Docker files Dockerfile.gr00t_server launched Arena's remote_policy_server_runner, which is now deleted. The GR00T repo provides its own server via gr00t/eval/run_gr00t_server.py with its own Docker setup. Signed-off-by: Clemens Volk --- docker/Dockerfile.gr00t_server | 38 ----- docker/run_gr00t_server.sh | 248 --------------------------------- 2 files changed, 286 deletions(-) delete mode 100644 docker/Dockerfile.gr00t_server delete mode 100755 docker/run_gr00t_server.sh diff --git a/docker/Dockerfile.gr00t_server b/docker/Dockerfile.gr00t_server deleted file mode 100644 index 227e30d66..000000000 --- a/docker/Dockerfile.gr00t_server +++ /dev/null @@ -1,38 +0,0 @@ -FROM nvcr.io/nvidia/pytorch:24.07-py3 - -ARG WORKDIR="/workspace" -ARG GROOT_DEPS_GROUP="base" -ENV WORKDIR=${WORKDIR} -ENV GROOT_DEPS_GROUP=${GROOT_DEPS_GROUP} -WORKDIR "${WORKDIR}" - -RUN apt-get update && apt-get install -y \ - git \ - git-lfs \ - cmake \ - && rm -rf /var/lib/apt/lists/* - -RUN pip install --upgrade pip - -COPY ./submodules/Isaac-GR00T ${WORKDIR}/submodules/Isaac-GR00T - -COPY docker/setup/install_gr00t_deps.sh /tmp/install_gr00t_deps.sh -RUN chmod +x /tmp/install_gr00t_deps.sh && \ - /tmp/install_gr00t_deps.sh --server && \ - rm -f /tmp/install_gr00t_deps.sh - -RUN pip install --no-cache-dir --upgrade "opencv-python-headless==4.8.0.74" -# GN1.6 uses termcolor 3.2.0 -RUN pip install --no-cache-dir termcolor==3.2.0 - -COPY isaaclab_arena/remote_policy ${WORKDIR}/isaaclab_arena/remote_policy -COPY isaaclab_arena_gr00t ${WORKDIR}/isaaclab_arena_gr00t -COPY isaaclab_arena_g1 ${WORKDIR}/isaaclab_arena_g1 - -RUN pip install --no-cache-dir pyzmq msgpack - -ENV PYTHONPATH=${WORKDIR} -# So gr00t_remote_policy loads transformers/tokenizers from GR00T deps (e.g. tokenizers 0.21.x) not system site-packages. -ENV GROOT_DEPS_DIR=/opt/groot_deps - -ENTRYPOINT ["python", "-u", "-m", "isaaclab_arena.remote_policy.remote_policy_server_runner"] diff --git a/docker/run_gr00t_server.sh b/docker/run_gr00t_server.sh deleted file mode 100755 index fca3d0d1a..000000000 --- a/docker/run_gr00t_server.sh +++ /dev/null @@ -1,248 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -# ------------------------- -# User-configurable defaults -# ------------------------- - -# Default mount directories on the host machine -DATASETS_DIR="${DATASETS_DIR:-$HOME/datasets}" -MODELS_DIR="${MODELS_DIR:-$HOME/models}" -EVAL_DIR="${EVAL_DIR:-$HOME/eval}" - -# Docker image name and tag for the GR00T policy server -DOCKER_IMAGE_NAME="${DOCKER_IMAGE_NAME:-gr00t_policy_server}" -DOCKER_VERSION_TAG="${DOCKER_VERSION_TAG:-latest}" - -# Rebuild controls -FORCE_REBUILD="${FORCE_REBUILD:-false}" -NO_CACHE="" - -# Server parameters (can also be overridden via environment variables) -HOST="${HOST:-0.0.0.0}" -PORT="${PORT:-5555}" -API_TOKEN="${API_TOKEN:-}" -TIMEOUT_MS="${TIMEOUT_MS:-5000}" -POLICY_TYPE="${POLICY_TYPE:-gr00t_closedloop}" -POLICY_CONFIG_YAML_PATH="${POLICY_CONFIG_YAML_PATH:-/workspace/isaaclab_arena_gr00t/gr1_manip_gr00t_closedloop_config.yaml}" - -# GPU selection for docker --gpus (can also be overridden via environment variables) -# Examples: -# all -> use all GPUs -# 1 -> use 1 GPU (count) -# "device=0" -> use GPU 0 -# "device=0,1" -> use GPU 0 and 1 -GPUS="${GPUS:-all}" - -# ------------------------- -# Help message -# ------------------------- -usage() { - script_name=$(basename "$0") - cat < Path to datasets on the host. Default: "$DATASETS_DIR". - -m Path to models on the host. Default: "$MODELS_DIR". - -e Path to evaluation data on the host. Default: "$EVAL_DIR". - -n Docker image name. Default: "$DOCKER_IMAGE_NAME". - -g GPU selection for docker --gpus. Default: "all". - Examples: "all", "1", "device=0", "device=0,1". - -r Force rebuilding of the Docker image. - -R Force rebuilding of the Docker image, without cache. - -Server-specific options (passed through to the policy server entrypoint): - --host HOST - --port PORT - --api_token TOKEN - --timeout_ms MS - --policy_type TYPE - --policy_config_yaml_path PATH - -Examples: - # Minimal: use defaults, just build & run server - bash $script_name - - # Custom models directory, port and single GPU (GPU 0) - bash $script_name -m /data/models -g "device=0" --port 6000 --api_token MY_TOKEN - - # Custom image name, force rebuild, datasets/eval mounts, and multiple GPUs - bash $script_name -n gr00t_server -r \\ - -d /data/datasets -m /data/models -e /data/eval \\ - -g "device=0,1" \\ - --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy \\ - --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/gr1_manip_gr00t_closedloop_config.yaml -EOF -} - -# ------------------------- -# Parse docker/path options (short flags, like run_docker.sh) -# ------------------------- -DOCKER_ARGS_DONE=false -SERVER_ARGS=() - -while [[ $# -gt 0 ]]; do - if [ "$DOCKER_ARGS_DONE" = false ]; then - case "$1" in - -v) - # Enable verbose mode for debugging - set -x - shift 1 - ;; - -d) - # Set host datasets directory - DATASETS_DIR="$2" - shift 2 - ;; - -m) - # Set host models directory - MODELS_DIR="$2" - shift 2 - ;; - -e) - # Set host eval directory - EVAL_DIR="$2" - shift 2 - ;; - -n) - # Set Docker image name - DOCKER_IMAGE_NAME="$2" - shift 2 - ;; - -g) - # Set GPU selection for docker --gpus - GPUS="$2" - shift 2 - ;; - -r) - # Force rebuild of Docker image - FORCE_REBUILD="true" - shift 1 - ;; - -R) - # Force rebuild of Docker image, without cache - FORCE_REBUILD="true" - NO_CACHE="--no-cache" - shift 1 - ;; - -h|--help) - usage - exit 0 - ;; - --host|--port|--api_token|--timeout_ms|--policy_type|--policy_config_yaml_path) - # From here on, treat everything as server args and stop parsing docker flags - DOCKER_ARGS_DONE=true - SERVER_ARGS+=("$1") - shift 1 - ;; - --*) - # Unknown long option at docker level -> treat as server arg - DOCKER_ARGS_DONE=true - SERVER_ARGS+=("$1") - shift 1 - ;; - *) - # Anything else -> treat as server arg - DOCKER_ARGS_DONE=true - SERVER_ARGS+=("$1") - shift 1 - ;; - esac - else - # Additional server arguments after docker/path args - SERVER_ARGS+=("$1") - shift 1 - fi -done - -# If no server args were passed, use defaults -if [ ${#SERVER_ARGS[@]} -eq 0 ]; then - SERVER_ARGS=( - --host "${HOST}" - --port "${PORT}" - --api_token "${API_TOKEN}" - --timeout_ms "${TIMEOUT_MS}" - --policy_type "${POLICY_TYPE}" - --policy_config_yaml_path "${POLICY_CONFIG_YAML_PATH}" - ) -fi - -echo "Host paths:" -echo " DATASETS_DIR = ${DATASETS_DIR}" -echo " MODELS_DIR = ${MODELS_DIR}" -echo " EVAL_DIR = ${EVAL_DIR}" -echo "Docker image:" -echo " ${DOCKER_IMAGE_NAME}:${DOCKER_VERSION_TAG}" -echo "GPU:" -echo " --gpus ${GPUS}" -echo "Rebuild:" -echo " FORCE_REBUILD = ${FORCE_REBUILD}, NO_CACHE = '${NO_CACHE}'" -echo "Server args:" -printf ' %q ' "${SERVER_ARGS[@]}"; echo - -# ------------------------- -# 1) Build the Docker image -# ------------------------- - -IMAGE_TAG_FULL="${DOCKER_IMAGE_NAME}:${DOCKER_VERSION_TAG}" - -# 1) Decide whether to build -SHOULD_BUILD=false - -if [ "${FORCE_REBUILD}" = "true" ]; then - # -r or -R: force rebuild - SHOULD_BUILD=true -else - # Without force flag: only build if the image does not exist locally - if [ -z "$(docker images -q "${IMAGE_TAG_FULL}")" ]; then - SHOULD_BUILD=true - fi -fi - -# 2) Build or skip -if [ "${SHOULD_BUILD}" = "true" ]; then - echo "Building Docker image ${IMAGE_TAG_FULL}..." - docker build \ - ${NO_CACHE} \ - -f docker/Dockerfile.gr00t_server \ - -t "${IMAGE_TAG_FULL}" \ - . -else - echo "Docker image ${IMAGE_TAG_FULL} already exists. Skipping rebuild." - echo "Use -r or -R to force rebuilding the image." -fi - -# ------------------------- -# 2) Run the container -# ------------------------- - -DOCKER_RUN_ARGS=( - --rm - --gpus "${GPUS}" - --net host - --name gr00t_policy_server_container - -v "${MODELS_DIR}":/models -) - -# Only mount datasets / eval if the directories exist on host -if [ -d "${DATASETS_DIR}" ]; then - DOCKER_RUN_ARGS+=(-v "${DATASETS_DIR}":/datasets) -fi - -if [ -d "${EVAL_DIR}" ]; then - DOCKER_RUN_ARGS+=(-v "${EVAL_DIR}":/eval) -fi - -# Pass through so gr00t_remote_policy can print path/debug info (e.g. GROOT_DEBUG_PATH=1). -if [ -n "${GROOT_DEBUG_PATH:-}" ]; then - DOCKER_RUN_ARGS+=(-e "GROOT_DEBUG_PATH=${GROOT_DEBUG_PATH}") -fi - -docker run "${DOCKER_RUN_ARGS[@]}" \ - "${IMAGE_TAG_FULL}" \ - "${SERVER_ARGS[@]}" From ac6a93d42fc691dd71c79c55f1d0f80cb6af0c5f Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 15:19:38 +0200 Subject: [PATCH 09/17] Add Gr00tRemoteClosedloopPolicy using GR00T's native client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thin wrapper that connects to a GR00T policy server via GR00T's own PolicyClient (gr00t.policy.server_client). Reuses the same obs/action translation pipeline as the local Gr00tClosedloopPolicy — only the inference call changes from in-process to remote. Signed-off-by: Clemens Volk --- .../policy/gr00t_remote_closedloop_policy.py | 235 ++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py new file mode 100644 index 000000000..584fd34ce --- /dev/null +++ b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py @@ -0,0 +1,235 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""GR00T remote closed-loop policy using GR00T's native PolicyClient. + +This policy connects to a GR00T policy server (launched via +``gr00t/eval/run_gr00t_server.py``) and reuses the same observation/action +translation pipeline as the local ``Gr00tClosedloopPolicy``. +""" + +from __future__ import annotations + +import argparse +import gymnasium as gym +import torch +from dataclasses import dataclass, field +from typing import Any + +from gr00t.policy.server_client import PolicyClient as Gr00tPolicyClient + +from isaaclab_arena.policy.action_chunking import ActionChunkingState +from isaaclab_arena.policy.policy_base import PolicyBase +from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode +from isaaclab_arena_gr00t.policy.gr00t_core import ( + Gr00tBasePolicyArgs, + build_gr00t_action_tensor, + build_gr00t_policy_observations, + compute_action_dim, + extract_obs_numpy_from_torch, + load_gr00t_joint_configs, +) +from isaaclab_arena_gr00t.utils.io_utils import ( + create_config_from_yaml, + load_gr00t_modality_config_from_file, +) + + +@dataclass +class Gr00tRemoteClosedloopPolicyArgs(Gr00tBasePolicyArgs): + """Configuration for Gr00tRemoteClosedloopPolicy. + + Inherits policy_config_yaml_path and policy_device from Gr00tBasePolicyArgs, + and adds remote server connection parameters and num_envs. + """ + + num_envs: int = field(default=1, metadata={"help": "Number of environments to simulate"}) + remote_host: str = field(default="localhost", metadata={"help": "GR00T policy server hostname"}) + remote_port: int = field(default=5555, metadata={"help": "GR00T policy server port"}) + remote_api_token: str | None = field(default=None, metadata={"help": "API token for the policy server"}) + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> Gr00tRemoteClosedloopPolicyArgs: + """Create configuration from parsed CLI arguments.""" + return cls( + policy_config_yaml_path=args.policy_config_yaml_path, + policy_device=args.policy_device, + num_envs=args.num_envs, + remote_host=args.remote_host, + remote_port=args.remote_port, + remote_api_token=getattr(args, "remote_api_token", None), + ) + + +class Gr00tRemoteClosedloopPolicy(PolicyBase): + """GR00T closed-loop policy that delegates inference to a remote GR00T server. + + Uses GR00T's native ``PolicyClient`` (from ``gr00t.policy.server_client``) + to communicate with a GR00T policy server. The observation/action translation + pipeline is identical to the local ``Gr00tClosedloopPolicy``. + + Server side (run independently): + python gr00t/eval/run_gr00t_server.py \\ + --model_path nvidia/GR00T-N1.6-DROID \\ + --embodiment_tag OXE_DROID --device cuda --host 0.0.0.0 --port 5555 + + Client side (Arena evaluation): + python policy_runner.py \\ + --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_closedloop_policy.Gr00tRemoteClosedloopPolicy \\ + --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/droid_manip_gr00t_closedloop_config.yaml \\ + --remote_host 10.0.0.1 --remote_port 5555 \\ + --enable_cameras --num_episodes 5 \\ + pick_and_place_maple_table --embodiment droid_abs_joint_pos + """ + + name = "gr00t_remote_closedloop" + config_class = Gr00tRemoteClosedloopPolicyArgs + + def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs): + super().__init__(config) + + # Policy config (for obs/action translation — no model loading) + self.policy_config: Gr00tClosedloopPolicyConfig = create_config_from_yaml( + config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig + ) + self.num_envs = config.num_envs + self.device = config.policy_device + self.task_mode = TaskMode(self.policy_config.task_mode_name) + + # Joint configs (for sim↔policy joint remapping) + ( + self.policy_joints_config, + self.robot_action_joints_config, + self.robot_state_joints_config, + ) = load_gr00t_joint_configs(self.policy_config) + + # Modality config (for building GR00T observation dicts) + self.modality_configs = load_gr00t_modality_config_from_file( + self.policy_config.modality_config_path, + self.policy_config.embodiment_tag, + ) + + # Action / chunk shapes + self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) + self.action_chunk_length = self.policy_config.action_chunk_length + + # Chunking state (same as local policy) + self._chunking_state = ActionChunkingState( + num_envs=self.num_envs, + action_chunk_length=self.action_chunk_length, + action_horizon=self.policy_config.action_horizon, + action_dim=self.action_dim, + device=self.device, + dtype=torch.float, + ) + + # Connect to GR00T's native policy server + self._client = Gr00tPolicyClient( + host=config.remote_host, + port=config.remote_port, + api_token=config.remote_api_token, + strict=False, + ) + if not self._client.ping(): + raise ConnectionError( + f"Cannot reach GR00T policy server at {config.remote_host}:{config.remote_port}" + ) + + self.task_description: str | None = None + + # ---------------------- CLI helpers ------------------- + + @staticmethod + def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + "Gr00t Remote Closedloop Policy", + "Arguments for GR00T remote closed-loop policy evaluation.", + ) + group.add_argument( + "--policy_config_yaml_path", + type=str, + required=True, + help="Path to the Gr00t closedloop policy config YAML file", + ) + group.add_argument( + "--policy_device", + type=str, + default="cuda", + help="Device for Arena-side tensor operations (default: cuda)", + ) + group.add_argument("--remote_host", type=str, default="localhost", help="GR00T policy server hostname") + group.add_argument("--remote_port", type=int, default=5555, help="GR00T policy server port") + group.add_argument("--remote_api_token", type=str, default=None, help="API token for the policy server") + return parser + + @staticmethod + def from_args(args: argparse.Namespace) -> Gr00tRemoteClosedloopPolicy: + config = Gr00tRemoteClosedloopPolicyArgs.from_cli_args(args) + return Gr00tRemoteClosedloopPolicy(config) + + # ---------------------- Policy interface ------------------- + + def set_task_description(self, task_description: str | None) -> str: + if task_description is None: + task_description = self.policy_config.language_instruction + if not task_description: + raise ValueError( + "No language instruction provided. Set 'language_instruction' in the job config, " + "pass --language_instruction on the CLI, or define 'task_description' on the task class." + ) + self.task_description = task_description + return self.task_description + + def get_action(self, env: gym.Env, observation: dict[str, Any]) -> torch.Tensor: + def fetch_chunk() -> torch.Tensor: + return self._get_action_chunk(observation, self.policy_config.pov_cam_name_sim) + + return self._chunking_state.get_action(fetch_chunk) + + def _get_action_chunk( + self, observation: dict[str, Any], camera_names: list[str] | str = "robot_head_cam_rgb" + ) -> torch.Tensor: + """Get an action chunk from the remote GR00T server. + + Same pipeline as Gr00tClosedloopPolicy.get_action_chunk(), but calls + GR00T's PolicyClient instead of a local Gr00tPolicy. + """ + if isinstance(camera_names, str): + camera_names = [camera_names] + + # 1. Reuse the same obs translation as local policy + assert self.task_description is not None, "Task description is not set" + rgb_list_np, joint_pos_sim_np = extract_obs_numpy_from_torch(nested_obs=observation, camera_names=camera_names) + policy_observations = build_gr00t_policy_observations( + rgb_list_np=rgb_list_np, + joint_pos_sim_np=joint_pos_sim_np, + task_description=self.task_description, + policy_config=self.policy_config, + robot_state_joints_config=self.robot_state_joints_config, + policy_joints_config=self.policy_joints_config, + modality_configs=self.modality_configs, + ) + + # 2. Call GR00T's own client + robot_action_policy, _ = self._client.get_action(policy_observations) + + # 3. Reuse the same action translation as local policy + action_tensor = build_gr00t_action_tensor( + robot_action_policy=robot_action_policy, + task_mode=self.task_mode, + policy_joints_config=self.policy_joints_config, + robot_action_joints_config=self.robot_action_joints_config, + device=self.device, + embodiment_tag=self.policy_config.embodiment_tag, + ) + + assert action_tensor.shape[0] == self.num_envs and action_tensor.shape[1] >= self.action_chunk_length + return action_tensor + + def reset(self, env_ids: torch.Tensor | None = None): + if env_ids is None: + env_ids = slice(None) + self._client.reset() + self._chunking_state.reset(env_ids) From 3645a4c15f85b392f0e0c11e316bd59183a1a0d9 Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 18:41:54 +0200 Subject: [PATCH 10/17] Move Gr00tPolicy model loading out of gr00t_core.py Remove load_gr00t_policy_from_config() and the top-level Gr00tPolicy import from gr00t_core.py so the module no longer pulls in transformers/torch model stack at import time. The local Gr00tClosedloopPolicy now loads the model via its own _load_policy() method. This keeps gr00t_core.py lightweight for the remote wrapper which only needs obs/action translation. Signed-off-by: Clemens Volk --- .../policy/gr00t_closedloop_policy.py | 22 +++++++++------ isaaclab_arena_gr00t/policy/gr00t_core.py | 27 ------------------- 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py index 7b5afa824..86e89ee45 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py @@ -14,6 +14,11 @@ from typing import Any from gr00t.data.embodiment_tags import EmbodimentTag +# NOTE: Gr00tPolicy is a heavy import (transformers, model loading). This local policy +# loads the model in-process and requires the full GR00T ML stack. For production +# evaluation, use Gr00tRemoteClosedloopPolicy which delegates inference to a remote +# GR00T server and has no heavy dependencies. This local policy may be removed in a +# future release if all workflows move to the remote path. from gr00t.policy.gr00t_policy import Gr00tPolicy from isaaclab_arena.policy.action_chunking import ActionChunkingState @@ -27,7 +32,6 @@ compute_action_dim, extract_obs_numpy_from_torch, load_gr00t_joint_configs, - load_gr00t_policy_from_config, ) from isaaclab_arena_gr00t.utils.eagle_config_compat import apply_eagle_config_compat from isaaclab_arena_gr00t.utils.io_utils import ( @@ -76,7 +80,7 @@ def __init__(self, config: Gr00tClosedloopPolicyArgs): self.policy_config: Gr00tClosedloopPolicyConfig = create_config_from_yaml( config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig ) - self.policy: Gr00tPolicy = load_gr00t_policy_from_config(self.policy_config) + self.policy: Gr00tPolicy = self._load_policy() # Basic attributes self.num_envs = config.num_envs @@ -156,16 +160,18 @@ def load_sim_action_joints_config(self, action_config_path: Path) -> dict[str, A """Load the simulation action joint config from the data config.""" return load_robot_joints_config_from_yaml(action_config_path) - def load_policy(self) -> Gr00tPolicy: - """Load the dataset, whose iterator will be used as the policy.""" - assert Path( - self.policy_config.model_path - ).exists(), f"Dataset path {self.policy_config.dataset_path} does not exist" + def _load_policy(self) -> Gr00tPolicy: + """Load the GR00T policy model in-process.""" + model_path = self.policy_config.model_path + is_hf_id = bool(model_path and "/" in model_path and not model_path.startswith(("/", "."))) + assert ( + Path(model_path).exists() or is_hf_id + ), f"Model path {model_path} does not exist and is not a HuggingFace model id" apply_eagle_config_compat() return Gr00tPolicy( - model_path=self.policy_config.model_path, + model_path=model_path, embodiment_tag=EmbodimentTag[self.policy_config.embodiment_tag], device=self.device, strict=True, diff --git a/isaaclab_arena_gr00t/policy/gr00t_core.py b/isaaclab_arena_gr00t/policy/gr00t_core.py index c53241b09..abbf6e8f5 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_core.py +++ b/isaaclab_arena_gr00t/policy/gr00t_core.py @@ -35,7 +35,6 @@ from typing import Any from gr00t.data.embodiment_tags import EmbodimentTag -from gr00t.policy.gr00t_policy import Gr00tPolicy from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.policy_constants import ( NUM_BASE_HEIGHT_CMD, @@ -82,32 +81,6 @@ class Gr00tBasePolicyArgs: # --------------------------------------------------------------------------- # -def load_gr00t_policy_from_config(policy_config: Gr00tClosedloopPolicyConfig) -> Gr00tPolicy: - """Instantiate a GR00T policy from the closed-loop config. - - Args: - policy_config: Loaded closed-loop config (model path, embodiment, device). - - Returns: - Loaded ``Gr00tPolicy`` on the configured device. - - Raises: - AssertionError: If ``policy_config.model_path`` does not exist. - """ - model_path = policy_config.model_path - # HuggingFace Hub repo IDs use "owner/repo" format (e.g. "nvidia/GR00T-N1.6-DROID"). - is_hf_id = bool(model_path and "/" in model_path and not model_path.startswith(("/", "."))) - assert ( - Path(model_path).exists() or is_hf_id - ), f"Model path {model_path} does not exist and is not a HuggingFace model id" - return Gr00tPolicy( - model_path=policy_config.model_path, - embodiment_tag=EmbodimentTag[policy_config.embodiment_tag], - device=policy_config.policy_device, - strict=True, - ) - - def load_gr00t_joint_configs( policy_config: Gr00tClosedloopPolicyConfig, ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: From 17cf47df3c042bda89586b698e71d9f019f33e6c Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 18:55:28 +0200 Subject: [PATCH 11/17] Inline load_modality_config to avoid training stack import Replace the import of gr00t.experiment.launch_finetune (which pulls in tyro, FinetuneConfig, and the training experiment module) with the 10-line inlined equivalent. The remote wrapper no longer requires any GR00T training dependencies. Signed-off-by: Clemens Volk --- isaaclab_arena_gr00t/utils/io_utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/isaaclab_arena_gr00t/utils/io_utils.py b/isaaclab_arena_gr00t/utils/io_utils.py index 15f25a6e5..3f3d87d97 100644 --- a/isaaclab_arena_gr00t/utils/io_utils.py +++ b/isaaclab_arena_gr00t/utils/io_utils.py @@ -204,13 +204,22 @@ def load_gr00t_modality_config_from_file(modality_config_path: str | Path, embod Returns: modality_configs: Modality configurations """ + import importlib + import sys + from gr00t.configs.data.embodiment_configs import MODALITY_CONFIGS from gr00t.data.embodiment_tags import EmbodimentTag - from gr00t.experiment.launch_finetune import load_modality_config if modality_config_path: - # Import module for side-effect registration - load_modality_config(modality_config_path) + # Import the modality config module for side-effect registration. + # Inlined from gr00t.experiment.launch_finetune.load_modality_config() + # to avoid pulling in the full training stack (tyro, FinetuneConfig, etc.). + path = Path(modality_config_path) + if path.exists() and path.suffix == ".py": + sys.path.append(str(path.parent)) + importlib.import_module(path.stem) + else: + raise FileNotFoundError(f"Modality config path does not exist: {modality_config_path}") # Get the embodiment tag from policy config and convert to EmbodimentTag enum # Handle case-insensitive lookup (e.g., "NEW_EMBODIMENT" or "new_embodiment" both work) From cbd2b797b06005e1f9373df29c9b6188d4e28edc Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Wed, 15 Apr 2026 19:12:53 +0200 Subject: [PATCH 12/17] Auto-discover Isaac-GR00T submodule for lightweight client imports Update ensure_groot_deps_in_path to automatically find and add the Isaac-GR00T submodule to sys.path if gr00t is not already importable. This allows the remote wrapper to import PolicyClient without requiring a full pip install of the GR00T package or manual PYTHONPATH setup. Signed-off-by: Clemens Volk --- isaaclab_arena_gr00t/utils/groot_path.py | 49 ++++++++++++++++-------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/isaaclab_arena_gr00t/utils/groot_path.py b/isaaclab_arena_gr00t/utils/groot_path.py index 8c797e5fa..6893b692e 100644 --- a/isaaclab_arena_gr00t/utils/groot_path.py +++ b/isaaclab_arena_gr00t/utils/groot_path.py @@ -5,23 +5,40 @@ import os import sys +from pathlib import Path -# TODO(xinjie.yao, 2026.03.31): Remove it after policy sever-client is implemented properly in v0.3. -def ensure_groot_deps_in_path(reexec_argv: list[str] | None = None) -> None: - """Prepend ``GROOT_DEPS_DIR`` to ``PYTHONPATH`` and re-exec the process so - GR00T dependencies are importable before Isaac Sim's bundled packages. +def _find_groot_submodule() -> str | None: + """Locate the Isaac-GR00T submodule relative to the repo root.""" + # Walk up from this file to find the repo root (where submodules/ lives) + current = Path(__file__).resolve() + for parent in current.parents: + candidate = parent / "submodules" / "Isaac-GR00T" + if candidate.is_dir(): + return str(candidate) + return None - The function is guarded by the ``_GROOT_PYTHONPATH_APPLIED`` env-var so it - only re-execs once. If ``GROOT_DEPS_DIR`` is not set the call is a no-op. - Args: - reexec_argv: The argv list to pass to ``os.execv`` after the Python - interpreter. Defaults to ``sys.argv`` (i.e. re-run the current - script with the same arguments). Pass - ``["-m", "pytest"] + sys.argv[1:]`` when bootstrapping from a - pytest conftest so the test runner is invoked correctly. +def ensure_groot_in_path() -> None: + """Ensure the Isaac-GR00T submodule is importable. + + Adds the submodule to sys.path if ``gr00t`` is not already importable. + This allows the lightweight client imports (PolicyClient, MsgSerializer) + without requiring a full ``pip install`` of the GR00T package. + + Also prepends ``GROOT_DEPS_DIR`` to ``PYTHONPATH`` and re-execs the + process if set, so GR00T's pip dependencies are importable before + Isaac Sim's bundled packages. """ + # 1. Add submodule to sys.path if gr00t is not already importable + try: + import gr00t # noqa: F401 + except ModuleNotFoundError: + submodule_path = _find_groot_submodule() + if submodule_path and submodule_path not in sys.path: + sys.path.insert(0, submodule_path) + + # 2. Handle GROOT_DEPS_DIR re-exec (for heavy deps like transformers) deps_dir = os.environ.get("GROOT_DEPS_DIR") if not deps_dir or os.environ.get("_GROOT_PYTHONPATH_APPLIED") == "1": return @@ -29,6 +46,8 @@ def ensure_groot_deps_in_path(reexec_argv: list[str] | None = None) -> None: os.environ["PYTHONPATH"] = deps_dir + os.pathsep + os.environ.get("PYTHONPATH", "") os.environ["_GROOT_PYTHONPATH_APPLIED"] = "1" - if reexec_argv is None: - reexec_argv = sys.argv - os.execv(sys.executable, [sys.executable] + reexec_argv) + os.execv(sys.executable, [sys.executable] + sys.argv) + + +# Keep old name as alias for backward compatibility +ensure_groot_deps_in_path = ensure_groot_in_path From 9c325acbc004eb1c05ddf6e92f623447826b5d33 Mon Sep 17 00:00:00 2001 From: Xinjie Yao Date: Wed, 15 Apr 2026 16:03:55 -0700 Subject: [PATCH 13/17] make action chunk another class --- isaaclab_arena/policy/__init__.py | 3 +- isaaclab_arena/policy/action_chunking.py | 18 ++++--- isaaclab_arena/policy/action_scheduler.py | 47 +++++++++++++++++++ .../policy/gr00t_closedloop_policy.py | 28 ++++++----- .../policy/gr00t_remote_closedloop_policy.py | 28 ++++++----- 5 files changed, 91 insertions(+), 33 deletions(-) create mode 100644 isaaclab_arena/policy/action_scheduler.py diff --git a/isaaclab_arena/policy/__init__.py b/isaaclab_arena/policy/__init__.py index 4a0bb1e9d..9290be08a 100644 --- a/isaaclab_arena/policy/__init__.py +++ b/isaaclab_arena/policy/__init__.py @@ -3,7 +3,8 @@ # # SPDX-License-Identifier: Apache-2.0 -from .action_chunking import ActionChunkingState +from .action_scheduler import ActionScheduler +from .action_chunking import ActionChunkScheduler, ActionChunkingState from .replay_action_policy import * from .rsl_rl_action_policy import * from .zero_action_policy import * diff --git a/isaaclab_arena/policy/action_chunking.py b/isaaclab_arena/policy/action_chunking.py index a2b410e42..e367e17d5 100644 --- a/isaaclab_arena/policy/action_chunking.py +++ b/isaaclab_arena/policy/action_chunking.py @@ -3,19 +3,21 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Shared action chunking state and logic for local and remote policies.""" +"""ActionChunkScheduler: buffer a model chunk and step through it sequentially.""" from __future__ import annotations import torch from collections.abc import Callable +from isaaclab_arena.policy.action_scheduler import ActionScheduler -class ActionChunkingState: - """Holds chunk buffer, per-env index, and refill flag; provides get_action(fetch_chunk_fn). - Used by Gr00tClosedloopPolicy and any framework-specific remote wrapper - so chunking behavior is identical across local and remote policies. +class ActionChunkScheduler(ActionScheduler): + """Buffers one action chunk and replays it one step at a time. + + Fetches a new chunk from the model only when the current one is exhausted. + Per-env tracking allows environments to refetch independently. """ def __init__( @@ -79,10 +81,14 @@ def get_action(self, fetch_chunk_fn: Callable[[], torch.Tensor]) -> torch.Tensor return action - def reset(self, env_ids: torch.Tensor | None = None) -> None: + def reset(self, env_ids: torch.Tensor | slice | None = None) -> None: """Reset chunking state for the given envs (all if None).""" if env_ids is None: env_ids = slice(None) self.current_action_chunk[env_ids] = 0.0 self.current_action_index[env_ids] = -1 self.env_requires_new_chunk[env_ids] = True + + +# Backwards-compatibility alias +ActionChunkingState = ActionChunkScheduler diff --git a/isaaclab_arena/policy/action_scheduler.py b/isaaclab_arena/policy/action_scheduler.py new file mode 100644 index 000000000..18f91d70b --- /dev/null +++ b/isaaclab_arena/policy/action_scheduler.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Abstract base class for action scheduling strategies.""" + +from __future__ import annotations + +import torch +from abc import ABC, abstractmethod +from collections.abc import Callable + + +class ActionScheduler(ABC): + """Translates raw model chunk outputs into per-step actions. + + The policy calls ``get_action(fetch_chunk_fn)`` at every environment step. + The scheduler controls when to query the model and how to derive a single + action from one or more model outputs. + + Concrete implementations include: + - ``ActionChunkScheduler``: buffer one chunk, step through it sequentially, + refetch when exhausted. + - ``TemporalEnsemblingScheduler``: always query the model, blend overlapping + chunks with exponential decay weights (ACT-style). + - ``PassThroughScheduler``: always query the model, return the first action + in the chunk. + """ + + @abstractmethod + def get_action(self, fetch_chunk_fn: Callable[[], torch.Tensor]) -> torch.Tensor: + """Return one action per env for the current timestep. + + Args: + fetch_chunk_fn: Callable that queries the model and returns a chunk + tensor of shape ``(num_envs, horizon, action_dim)``. + + Returns: + Action tensor of shape ``(num_envs, action_dim)``. + """ + ... + + @abstractmethod + def reset(self, env_ids: torch.Tensor | slice | None = None) -> None: + """Reset scheduler state for the given envs (all envs if None).""" + ... diff --git a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py index 86e89ee45..32af9f56d 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py @@ -21,7 +21,8 @@ # future release if all workflows move to the remote path. from gr00t.policy.gr00t_policy import Gr00tPolicy -from isaaclab_arena.policy.action_chunking import ActionChunkingState +from isaaclab_arena.policy.action_chunking import ActionChunkScheduler +from isaaclab_arena.policy.action_scheduler import ActionScheduler from isaaclab_arena.policy.policy_base import PolicyBase from isaaclab_arena.utils.multiprocess import get_local_rank, get_world_size from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode @@ -72,7 +73,7 @@ class Gr00tClosedloopPolicy(PolicyBase): name = "gr00t_closedloop" config_class = Gr00tClosedloopPolicyArgs - def __init__(self, config: Gr00tClosedloopPolicyArgs): + def __init__(self, config: Gr00tClosedloopPolicyArgs, action_scheduler: ActionScheduler | None = None): """Initialize Gr00tClosedloopPolicy from a configuration dataclass.""" super().__init__(config) @@ -109,15 +110,16 @@ def __init__(self, config: Gr00tClosedloopPolicyArgs): self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) self.action_chunk_length = self.policy_config.action_chunk_length - # Shared chunking state (reused by remote framework wrappers) - self._chunking_state = ActionChunkingState( - num_envs=self.num_envs, - action_chunk_length=self.action_chunk_length, - action_horizon=self.policy_config.action_horizon, - action_dim=self.action_dim, - device=self.device, - dtype=torch.float, - ) + if action_scheduler is None: + action_scheduler = ActionChunkScheduler( + num_envs=self.num_envs, + action_chunk_length=self.action_chunk_length, + action_horizon=self.policy_config.action_horizon, + action_dim=self.action_dim, + device=self.device, + dtype=torch.float, + ) + self._action_scheduler = action_scheduler # task description of task being evaluated. It will be set by the task being evaluated. self.task_description: str | None = None @@ -222,7 +224,7 @@ def get_action(self, env: gym.Env, observation: dict[str, Any]) -> torch.Tensor: def fetch_chunk() -> torch.Tensor: return self.get_action_chunk(observation, self.policy_config.pov_cam_name_sim) - return self._chunking_state.get_action(fetch_chunk) + return self._action_scheduler.get_action(fetch_chunk) def get_action_chunk( self, observation: dict[str, Any], camera_names: list[str] | str = "robot_head_cam_rgb" @@ -255,4 +257,4 @@ def reset(self, env_ids: torch.Tensor | None = None): env_ids = slice(None) # placeholder for future reset options from GR00T repo self.policy.reset() - self._chunking_state.reset(env_ids) + self._action_scheduler.reset(env_ids) diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py index 584fd34ce..4c2db68e5 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py @@ -20,7 +20,8 @@ from gr00t.policy.server_client import PolicyClient as Gr00tPolicyClient -from isaaclab_arena.policy.action_chunking import ActionChunkingState +from isaaclab_arena.policy.action_chunking import ActionChunkScheduler +from isaaclab_arena.policy.action_scheduler import ActionScheduler from isaaclab_arena.policy.policy_base import PolicyBase from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode from isaaclab_arena_gr00t.policy.gr00t_core import ( @@ -87,7 +88,7 @@ class Gr00tRemoteClosedloopPolicy(PolicyBase): name = "gr00t_remote_closedloop" config_class = Gr00tRemoteClosedloopPolicyArgs - def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs): + def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs, action_scheduler: ActionScheduler | None = None): super().__init__(config) # Policy config (for obs/action translation — no model loading) @@ -115,15 +116,16 @@ def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs): self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) self.action_chunk_length = self.policy_config.action_chunk_length - # Chunking state (same as local policy) - self._chunking_state = ActionChunkingState( - num_envs=self.num_envs, - action_chunk_length=self.action_chunk_length, - action_horizon=self.policy_config.action_horizon, - action_dim=self.action_dim, - device=self.device, - dtype=torch.float, - ) + if action_scheduler is None: + action_scheduler = ActionChunkScheduler( + num_envs=self.num_envs, + action_chunk_length=self.action_chunk_length, + action_horizon=self.policy_config.action_horizon, + action_dim=self.action_dim, + device=self.device, + dtype=torch.float, + ) + self._action_scheduler = action_scheduler # Connect to GR00T's native policy server self._client = Gr00tPolicyClient( @@ -186,7 +188,7 @@ def get_action(self, env: gym.Env, observation: dict[str, Any]) -> torch.Tensor: def fetch_chunk() -> torch.Tensor: return self._get_action_chunk(observation, self.policy_config.pov_cam_name_sim) - return self._chunking_state.get_action(fetch_chunk) + return self._action_scheduler.get_action(fetch_chunk) def _get_action_chunk( self, observation: dict[str, Any], camera_names: list[str] | str = "robot_head_cam_rgb" @@ -232,4 +234,4 @@ def reset(self, env_ids: torch.Tensor | None = None): if env_ids is None: env_ids = slice(None) self._client.reset() - self._chunking_state.reset(env_ids) + self._action_scheduler.reset(env_ids) From fa9ed2b50ce1ed8009cf650231a589f73886ae5c Mon Sep 17 00:00:00 2001 From: Xinjie Yao Date: Wed, 15 Apr 2026 16:42:10 -0700 Subject: [PATCH 14/17] gpu rank assignment --- isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py index 4c2db68e5..fdea97fdc 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py @@ -23,6 +23,7 @@ from isaaclab_arena.policy.action_chunking import ActionChunkScheduler from isaaclab_arena.policy.action_scheduler import ActionScheduler from isaaclab_arena.policy.policy_base import PolicyBase +from isaaclab_arena.utils.multiprocess import get_local_rank, get_world_size from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode from isaaclab_arena_gr00t.policy.gr00t_core import ( Gr00tBasePolicyArgs, @@ -97,6 +98,8 @@ def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs, action_scheduler: Ac ) self.num_envs = config.num_envs self.device = config.policy_device + if get_world_size() > 1 and "cuda" in self.device: + self.device = f"cuda:{get_local_rank()}" self.task_mode = TaskMode(self.policy_config.task_mode_name) # Joint configs (for sim↔policy joint remapping) From edd2979954df56ed3b73bb12d04bb59f56ed6e5d Mon Sep 17 00:00:00 2001 From: Xinjie Yao Date: Thu, 16 Apr 2026 21:38:43 -0700 Subject: [PATCH 15/17] Add evaluation timing stats and parallel-env efficiency tracking - Add TimingStats utility (utils/timing.py) for accumulating wall-clock measurements per named category with avg/total/count/pct reporting - Instrument policy_runner rollout loop: measure get_action, env_step, and reset; print every 100 steps and at end of run - Track per-env reset counts and inference fetch efficiency to quantify wasted compute in parallel-env setups - Add fetch-efficiency stats to ActionChunkScheduler: n_fetch_calls, avg_envs_per_fetch, fetch_efficiency, per_env_fetch_count - Instrument Gr00tRemoteClosedloopPolicy with timing on obs_extract, obs_pack, inference_wait, and action_build phases Signed-off-by: Xinjie Yao --- isaaclab_arena/evaluation/policy_runner.py | 63 +++++++++++++++++-- isaaclab_arena/policy/action_chunking.py | 30 +++++++++ isaaclab_arena/utils/timing.py | 46 ++++++++++++++ .../policy/gr00t_remote_closedloop_policy.py | 56 +++++++++++------ 4 files changed, 171 insertions(+), 24 deletions(-) create mode 100644 isaaclab_arena/utils/timing.py diff --git a/isaaclab_arena/evaluation/policy_runner.py b/isaaclab_arena/evaluation/policy_runner.py index be2d24ce3..dfdb8c3fb 100644 --- a/isaaclab_arena/evaluation/policy_runner.py +++ b/isaaclab_arena/evaluation/policy_runner.py @@ -14,6 +14,7 @@ from isaaclab_arena.utils.isaaclab_utils.simulation_app import SimulationAppContext from isaaclab_arena.utils.multiprocess import get_local_rank, get_world_size from isaaclab_arena.utils.random import set_seed +from isaaclab_arena.utils.timing import TimingStats from isaaclab_arena_environments.cli import get_arena_builder_from_cli, get_isaaclab_arena_environments_cli_parser from isaaclab_arena_gr00t.utils.groot_path import ensure_groot_deps_in_path @@ -64,10 +65,43 @@ def rollout_policy( assert num_steps is not None or num_episodes is not None, "Either num_steps or num_episodes must be provided" assert num_steps is None or num_episodes is None, "Only one of num_steps or num_episodes must be provided" + _TIMING_PRINT_INTERVAL = 100 + pbar = None + timing = TimingStats() + + # Per-env reset counter — initialised lazily once num_envs is known. + per_env_reset_counts: torch.Tensor | None = None + + def _parallel_env_report(num_steps_done: int) -> str: + """Build a human-readable parallel-env efficiency report.""" + if per_env_reset_counts is None: + return "" + num_envs = per_env_reset_counts.shape[0] + lines = [f"[Parallel-env stats] num_envs={num_envs} steps={num_steps_done}"] + lines.append(f" per_env_resets : {per_env_reset_counts.tolist()}") + if hasattr(policy, "action_scheduler_stats") and hasattr(policy, "action_chunk_length"): + s = policy.action_scheduler_stats + ideal = num_steps_done / policy.action_chunk_length + actual = s["n_fetch_calls"] + overhead_pct = (actual - ideal) / ideal * 100 if ideal > 0 else 0.0 + eff = s["fetch_efficiency"] + lines.append( + f" inference calls : actual={actual} ideal={ideal:.1f}" + f" overhead={overhead_pct:+.1f}%" + ) + lines.append( + f" avg_envs_per_fetch : {s['avg_envs_per_fetch']:.2f}/{num_envs}" + f" efficiency={eff:.1%}" + f" (1.0 = all envs needed fetch, <1.0 = wasted compute)" + ) + lines.append(f" per_env_fetch_count : {s['per_env_fetch_count']}") + return "\n".join(lines) + try: obs, _ = env.reset() policy.reset() + per_env_reset_counts = torch.zeros(env.unwrapped.num_envs, dtype=torch.int64) # Determine language instruction: CLI/job-level override takes precedence over the task's own # description. Use unwrapped to reach the base env through any gym wrappers (e.g. OrderEnforcing). task_description = language_instruction or env.unwrapped.cfg.isaaclab_arena_env.task.get_task_description() @@ -84,17 +118,22 @@ def rollout_policy( while True: with torch.inference_mode(): - actions = policy.get_action(env, obs) - obs, _, terminated, truncated, _ = env.step(actions) + with timing.measure("get_action"): + actions = policy.get_action(env, obs) + + with timing.measure("env_step"): + obs, _, terminated, truncated, _ = env.step(actions) if terminated.any() or truncated.any(): - # Only reset policy for those envs that are terminated or truncated + env_ids = (terminated | truncated).nonzero().flatten() print( f"Resetting policy for terminated env_ids: {terminated.nonzero().flatten()}" f" and truncated env_ids: {truncated.nonzero().flatten()}" ) - env_ids = (terminated | truncated).nonzero().flatten() - policy.reset(env_ids=env_ids) + with timing.measure("reset"): + if per_env_reset_counts is not None: + per_env_reset_counts[env_ids.cpu()] += 1 + policy.reset(env_ids=env_ids) # Break if number of episodes is reached completed_episodes = env_ids.shape[0] num_episodes_completed += completed_episodes @@ -109,6 +148,14 @@ def rollout_policy( if num_steps_completed >= num_steps: break + if num_steps_completed % _TIMING_PRINT_INTERVAL == 0: + print(timing.summary("Runner (cumulative)")) + if hasattr(policy, "timing_stats"): + print(policy.timing_stats.summary("Policy (cumulative)")) + report = _parallel_env_report(num_steps_completed) + if report: + print(report) + pbar.close() except Exception as e: @@ -117,6 +164,12 @@ def rollout_policy( raise RuntimeError(f"Error rolling out policy: {e}") else: + print(timing.summary("Runner (final)")) + if hasattr(policy, "timing_stats"): + print(policy.timing_stats.summary("Policy (final)")) + report = _parallel_env_report(num_steps_completed) + if report: + print(report.replace("cumulative", "final")) # Only compute metrics if env has a non-None metrics list (e.g. NoTask leaves metrics as None). # Use unwrapped to reach the base env through any gym wrappers (e.g. OrderEnforcing) diff --git a/isaaclab_arena/policy/action_chunking.py b/isaaclab_arena/policy/action_chunking.py index e367e17d5..c17caf712 100644 --- a/isaaclab_arena/policy/action_chunking.py +++ b/isaaclab_arena/policy/action_chunking.py @@ -45,6 +45,12 @@ def __init__( self.current_action_index = torch.zeros(num_envs, dtype=torch.int32, device=device) self.env_requires_new_chunk = torch.ones(num_envs, dtype=torch.bool, device=device) + # Fetch-efficiency tracking: how many times each env triggered a chunk fetch, + # and how many envs actually needed the fetch vs. total (wasted compute detection). + self._n_fetch_calls: int = 0 + self._total_envs_needed: int = 0 + self._per_env_fetch_count = torch.zeros(num_envs, dtype=torch.int64, device=device) + def get_action(self, fetch_chunk_fn: Callable[[], torch.Tensor]) -> torch.Tensor: """Return one action per env, refilling the chunk when needed. @@ -52,6 +58,12 @@ def get_action(self, fetch_chunk_fn: Callable[[], torch.Tensor]) -> torch.Tensor with horizon >= action_chunk_length. """ if self.env_requires_new_chunk.any(): + # Track which envs triggered this fetch before calling fetch_chunk_fn. + needed_mask = self.env_requires_new_chunk.clone() + self._n_fetch_calls += 1 + self._total_envs_needed += int(needed_mask.sum().item()) + self._per_env_fetch_count[needed_mask] += 1 + # compute a new action chunk for the envs that require a new action chunk new_chunk = fetch_chunk_fn() mask = self.env_requires_new_chunk @@ -81,6 +93,24 @@ def get_action(self, fetch_chunk_fn: Callable[[], torch.Tensor]) -> torch.Tensor return action + @property + def fetch_stats(self) -> dict: + """Fetch-efficiency stats useful for parallel-env analysis. + + Returns a dict with: + n_fetch_calls - total inference calls made + avg_envs_per_fetch - mean envs that actually needed the fetch + fetch_efficiency - avg_envs_per_fetch / num_envs (1.0 = perfectly sync'd) + per_env_fetch_count - list of per-env fetch trigger counts + """ + avg = self._total_envs_needed / self._n_fetch_calls if self._n_fetch_calls > 0 else 0.0 + return { + "n_fetch_calls": self._n_fetch_calls, + "avg_envs_per_fetch": avg, + "fetch_efficiency": avg / self.num_envs if self.num_envs > 0 else 1.0, + "per_env_fetch_count": self._per_env_fetch_count.tolist(), + } + def reset(self, env_ids: torch.Tensor | slice | None = None) -> None: """Reset chunking state for the given envs (all if None).""" if env_ids is None: diff --git a/isaaclab_arena/utils/timing.py b/isaaclab_arena/utils/timing.py new file mode 100644 index 000000000..b004dd9e3 --- /dev/null +++ b/isaaclab_arena/utils/timing.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Simple timing-statistics accumulator for benchmarking the evaluation loop.""" + +from __future__ import annotations + +import time +from collections import defaultdict +from contextlib import contextmanager + + +class TimingStats: + """Accumulates wall-clock timings per named category and reports averages.""" + + def __init__(self) -> None: + self._totals: dict[str, float] = defaultdict(float) + self._counts: dict[str, int] = defaultdict(int) + + def record(self, key: str, elapsed: float) -> None: + self._totals[key] += elapsed + self._counts[key] += 1 + + @contextmanager + def measure(self, key: str): + t0 = time.perf_counter() + yield + self.record(key, time.perf_counter() - t0) + + def summary(self, label: str = "Timing Stats") -> str: + if not self._totals: + return f"[{label}] No data recorded." + total_wall = sum(self._totals.values()) + lines = [f"[{label}]"] + for key, total in self._totals.items(): + count = self._counts[key] + avg_ms = (total / count) * 1000 if count > 0 else 0.0 + pct = total / total_wall * 100 if total_wall > 0 else 0.0 + lines.append(f" {key:<30s} avg={avg_ms:7.1f}ms total={total:7.2f}s n={count:5d} ({pct:4.1f}%)") + return "\n".join(lines) + + def reset(self) -> None: + self._totals.clear() + self._counts.clear() diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py index fdea97fdc..ffe9e2eb0 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py @@ -24,6 +24,7 @@ from isaaclab_arena.policy.action_scheduler import ActionScheduler from isaaclab_arena.policy.policy_base import PolicyBase from isaaclab_arena.utils.multiprocess import get_local_rank, get_world_size +from isaaclab_arena.utils.timing import TimingStats from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode from isaaclab_arena_gr00t.policy.gr00t_core import ( Gr00tBasePolicyArgs, @@ -143,6 +144,15 @@ def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs, action_scheduler: Ac ) self.task_description: str | None = None + self._timing = TimingStats() + + @property + def timing_stats(self) -> TimingStats: + return self._timing + + @property + def action_scheduler_stats(self) -> dict: + return self._action_scheduler.fetch_stats # ---------------------- CLI helpers ------------------- @@ -206,29 +216,36 @@ def _get_action_chunk( # 1. Reuse the same obs translation as local policy assert self.task_description is not None, "Task description is not set" - rgb_list_np, joint_pos_sim_np = extract_obs_numpy_from_torch(nested_obs=observation, camera_names=camera_names) - policy_observations = build_gr00t_policy_observations( - rgb_list_np=rgb_list_np, - joint_pos_sim_np=joint_pos_sim_np, - task_description=self.task_description, - policy_config=self.policy_config, - robot_state_joints_config=self.robot_state_joints_config, - policy_joints_config=self.policy_joints_config, - modality_configs=self.modality_configs, - ) + with self._timing.measure("obs_extract"): + rgb_list_np, joint_pos_sim_np = extract_obs_numpy_from_torch( + nested_obs=observation, camera_names=camera_names + ) + + with self._timing.measure("obs_pack"): + policy_observations = build_gr00t_policy_observations( + rgb_list_np=rgb_list_np, + joint_pos_sim_np=joint_pos_sim_np, + task_description=self.task_description, + policy_config=self.policy_config, + robot_state_joints_config=self.robot_state_joints_config, + policy_joints_config=self.policy_joints_config, + modality_configs=self.modality_configs, + ) # 2. Call GR00T's own client - robot_action_policy, _ = self._client.get_action(policy_observations) + with self._timing.measure("inference_wait"): + robot_action_policy, _ = self._client.get_action(policy_observations) # 3. Reuse the same action translation as local policy - action_tensor = build_gr00t_action_tensor( - robot_action_policy=robot_action_policy, - task_mode=self.task_mode, - policy_joints_config=self.policy_joints_config, - robot_action_joints_config=self.robot_action_joints_config, - device=self.device, - embodiment_tag=self.policy_config.embodiment_tag, - ) + with self._timing.measure("action_build"): + action_tensor = build_gr00t_action_tensor( + robot_action_policy=robot_action_policy, + task_mode=self.task_mode, + policy_joints_config=self.policy_joints_config, + robot_action_joints_config=self.robot_action_joints_config, + device=self.device, + embodiment_tag=self.policy_config.embodiment_tag, + ) assert action_tensor.shape[0] == self.num_envs and action_tensor.shape[1] >= self.action_chunk_length return action_tensor @@ -238,3 +255,4 @@ def reset(self, env_ids: torch.Tensor | None = None): env_ids = slice(None) self._client.reset() self._action_scheduler.reset(env_ids) + From 9fb5a451b0d6489971a91a3b2e768219e5b7659c Mon Sep 17 00:00:00 2001 From: Xinjie Yao Date: Thu, 16 Apr 2026 23:45:35 -0700 Subject: [PATCH 16/17] Add SyncedBatchActionScheduler and --scheduler switch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds SyncedBatchActionScheduler to action_chunking.py: instead of triggering inference whenever any env needs a new chunk, it waits until all envs are exhausted, then makes a single full-batch call. Envs that finish their chunk early hold their current robot state (joint positions from observation) until the batch is ready. Tradeoff vs default chunk scheduler: - Server batch is always full (100% efficiency at any N) - Reduces inference calls from ~6× overhead at N=10 to exactly ideal - Envs that reset early are frozen for up to (chunk_length-1) steps, resulting in slightly fewer episodes completed per run Adds --scheduler {chunk,synced_batch} argument to Gr00tRemoteClosedloopPolicy CLI. Selecting synced_batch routes to the new Gr00tRemoteClosedloopPolicySyncedBatch subclass. Signed-off-by: Xinjie Yao --- isaaclab_arena/policy/action_chunking.py | 49 +++++++++++ .../policy/gr00t_remote_closedloop_policy.py | 88 +++++++++++++++++-- 2 files changed, 128 insertions(+), 9 deletions(-) diff --git a/isaaclab_arena/policy/action_chunking.py b/isaaclab_arena/policy/action_chunking.py index c17caf712..7aa2cdc88 100644 --- a/isaaclab_arena/policy/action_chunking.py +++ b/isaaclab_arena/policy/action_chunking.py @@ -120,5 +120,54 @@ def reset(self, env_ids: torch.Tensor | slice | None = None) -> None: self.env_requires_new_chunk[env_ids] = True +class SyncedBatchActionScheduler(ActionChunkScheduler): + """ActionChunkScheduler that waits until ALL envs need a new chunk before calling inference. + + Envs that exhaust their chunk early hold their current robot state + (joint positions passed as hold_action) until every env is ready. + Only then is one full-batch inference call made for all envs together. + + Tradeoff vs ActionChunkScheduler: + - Server batch is always full (N envs, never wasted) + - Envs that reset early hold their post-reset state for up to + (chunk_length - 1) steps before receiving a fresh chunk + """ + + def get_action( + self, + fetch_chunk_fn: Callable[[], torch.Tensor], + hold_action: torch.Tensor, + ) -> torch.Tensor: + """Return one action per env, fetching only when all envs need a new chunk. + + Args: + fetch_chunk_fn: Returns (num_envs, horizon, action_dim) when called. + hold_action: (num_envs, action_dim) current robot joint state; applied + to envs that are waiting for others to catch up. + """ + if self.env_requires_new_chunk.all(): + self._n_fetch_calls += 1 + self._total_envs_needed += self.num_envs + self._per_env_fetch_count += 1 + + new_chunk = fetch_chunk_fn() + self.current_action_chunk[:] = new_chunk + self.current_action_index[:] = 0 + self.env_requires_new_chunk[:] = False + + waiting = self.env_requires_new_chunk + batch_idx = torch.arange(self.num_envs, device=self.device) + action = self.current_action_chunk[batch_idx, self.current_action_index.clamp(min=0)] + action[waiting] = hold_action[waiting] + + self.current_action_index[~waiting] += 1 + exhausted = (~waiting) & (self.current_action_index >= self.action_chunk_length) + self.current_action_chunk[exhausted] = 0.0 + self.current_action_index[exhausted] = -1 + self.env_requires_new_chunk[exhausted] = True + + return action + + # Backwards-compatibility alias ActionChunkingState = ActionChunkScheduler diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py index ffe9e2eb0..43287af37 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py @@ -14,13 +14,14 @@ import argparse import gymnasium as gym +import numpy as np import torch from dataclasses import dataclass, field from typing import Any from gr00t.policy.server_client import PolicyClient as Gr00tPolicyClient -from isaaclab_arena.policy.action_chunking import ActionChunkScheduler +from isaaclab_arena.policy.action_chunking import ActionChunkScheduler, SyncedBatchActionScheduler from isaaclab_arena.policy.action_scheduler import ActionScheduler from isaaclab_arena.policy.policy_base import PolicyBase from isaaclab_arena.utils.multiprocess import get_local_rank, get_world_size @@ -34,10 +35,7 @@ extract_obs_numpy_from_torch, load_gr00t_joint_configs, ) -from isaaclab_arena_gr00t.utils.io_utils import ( - create_config_from_yaml, - load_gr00t_modality_config_from_file, -) +from isaaclab_arena_gr00t.utils.io_utils import create_config_from_yaml, load_gr00t_modality_config_from_file @dataclass @@ -78,11 +76,21 @@ class Gr00tRemoteClosedloopPolicy(PolicyBase): --model_path nvidia/GR00T-N1.6-DROID \\ --embodiment_tag OXE_DROID --device cuda --host 0.0.0.0 --port 5555 - Client side (Arena evaluation): + Client side — default chunk scheduler (fetches for all envs when any needs a new chunk): + python policy_runner.py \\ + --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_closedloop_policy.Gr00tRemoteClosedloopPolicy \\ + --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/droid_manip_gr00t_closedloop_config.yaml \\ + --remote_host 10.0.0.1 --remote_port 5555 \\ + --enable_cameras --num_episodes 5 \\ + pick_and_place_maple_table --embodiment droid_abs_joint_pos + + Client side — synced_batch scheduler (waits until ALL envs need a new chunk, then does one + full-batch call; envs that finish their chunk early hold their current robot state): python policy_runner.py \\ --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_closedloop_policy.Gr00tRemoteClosedloopPolicy \\ --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/droid_manip_gr00t_closedloop_config.yaml \\ --remote_host 10.0.0.1 --remote_port 5555 \\ + --scheduler synced_batch \\ --enable_cameras --num_episodes 5 \\ pick_and_place_maple_table --embodiment droid_abs_joint_pos """ @@ -139,9 +147,7 @@ def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs, action_scheduler: Ac strict=False, ) if not self._client.ping(): - raise ConnectionError( - f"Cannot reach GR00T policy server at {config.remote_host}:{config.remote_port}" - ) + raise ConnectionError(f"Cannot reach GR00T policy server at {config.remote_host}:{config.remote_port}") self.task_description: str | None = None self._timing = TimingStats() @@ -177,11 +183,25 @@ def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentPars group.add_argument("--remote_host", type=str, default="localhost", help="GR00T policy server hostname") group.add_argument("--remote_port", type=int, default=5555, help="GR00T policy server port") group.add_argument("--remote_api_token", type=str, default=None, help="API token for the policy server") + group.add_argument( + "--scheduler", + type=str, + default="chunk", + choices=["chunk", "synced_batch"], + help=( + "Action chunk scheduler: 'chunk' fetches for all envs when any needs a new chunk;" + " 'synced_batch' waits until all envs need a new chunk then does one full-batch call" + " (envs that finish early hold their current robot state)" + ), + ) return parser @staticmethod def from_args(args: argparse.Namespace) -> Gr00tRemoteClosedloopPolicy: config = Gr00tRemoteClosedloopPolicyArgs.from_cli_args(args) + scheduler = getattr(args, "scheduler", "chunk") + if scheduler == "synced_batch": + return Gr00tRemoteClosedloopPolicySyncedBatch(config) return Gr00tRemoteClosedloopPolicy(config) # ---------------------- Policy interface ------------------- @@ -256,3 +276,53 @@ def reset(self, env_ids: torch.Tensor | None = None): self._client.reset() self._action_scheduler.reset(env_ids) + +class Gr00tRemoteClosedloopPolicySyncedBatch(Gr00tRemoteClosedloopPolicy): + """GR00T remote policy that waits for all envs to need a new chunk before calling inference. + + Uses SyncedBatchActionScheduler: envs that exhaust their chunk early hold their + current robot state (joint positions from obs) until all envs are ready, then one + full-batch inference call is made for all envs together. + + Activate via: + --policy_type ... --scheduler synced_batch + """ + + name = "gr00t_remote_closedloop_synced_batch" + + def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs, action_scheduler: ActionScheduler | None = None): + if action_scheduler is None: + _policy_config: Gr00tClosedloopPolicyConfig = create_config_from_yaml( + config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig + ) + _action_dim = compute_action_dim( + TaskMode(_policy_config.task_mode_name), + load_gr00t_joint_configs(_policy_config)[1], + ) + action_scheduler = SyncedBatchActionScheduler( + num_envs=config.num_envs, + action_chunk_length=_policy_config.action_chunk_length, + action_horizon=_policy_config.action_horizon, + action_dim=_action_dim, + device=config.policy_device, + dtype=torch.float, + ) + super().__init__(config, action_scheduler) + + def get_action(self, env: gym.Env, observation: dict[str, Any]) -> torch.Tensor: + hold_action = self._extract_hold_action(observation) + + def fetch_chunk() -> torch.Tensor: + return self._get_action_chunk(observation, self.policy_config.pov_cam_name_sim) + + return self._action_scheduler.get_action(fetch_chunk, hold_action) + + def _extract_hold_action(self, observation: dict[str, Any]) -> torch.Tensor: + """Return current joint positions mapped to action space as the hold action.""" + _, joint_pos_sim_np = extract_obs_numpy_from_torch(nested_obs=observation, camera_names=[]) + hold_np = np.zeros((self.num_envs, self.action_dim), dtype=np.float64) + for joint_name, action_idx in self.robot_action_joints_config.items(): + if joint_name in self.robot_state_joints_config: + state_idx = self.robot_state_joints_config[joint_name] + hold_np[:, action_idx] = joint_pos_sim_np[:, state_idx] + return torch.tensor(hold_np, dtype=torch.float, device=self.device) From 77b5f23cdfa438fa21a9ac69545678768e474e12 Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Thu, 16 Apr 2026 16:57:10 +0200 Subject: [PATCH 17/17] Skip model_path validation in remote wrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The remote wrapper doesn't load the model — the server does. The YAML config may reference a model_path that only exists on the server machine. Replace it with a dummy value so __post_init__ validation passes. Signed-off-by: Clemens Volk --- .../policy/gr00t_remote_closedloop_policy.py | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py index 43287af37..4514003dd 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py @@ -35,7 +35,24 @@ extract_obs_numpy_from_torch, load_gr00t_joint_configs, ) -from isaaclab_arena_gr00t.utils.io_utils import create_config_from_yaml, load_gr00t_modality_config_from_file +from isaaclab_arena_gr00t.utils.io_utils import ( + create_config_from_yaml, + load_config_from_yaml, + load_gr00t_modality_config_from_file, +) + + +def _load_config_skip_model_path(yaml_path: str) -> Gr00tClosedloopPolicyConfig: + """Load Gr00tClosedloopPolicyConfig but skip model_path validation. + + The remote wrapper doesn't load the model — the server does. The YAML + may reference a model_path that only exists on the server machine. + We replace it with a dummy HuggingFace-style ID so __post_init__ + validation passes without requiring the file on disk. + """ + data = load_config_from_yaml(yaml_path, Gr00tClosedloopPolicyConfig) + data["model_path"] = "remote/server-side-model" + return Gr00tClosedloopPolicyConfig(**data) @dataclass @@ -101,9 +118,12 @@ class Gr00tRemoteClosedloopPolicy(PolicyBase): def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs, action_scheduler: ActionScheduler | None = None): super().__init__(config) - # Policy config (for obs/action translation — no model loading) - self.policy_config: Gr00tClosedloopPolicyConfig = create_config_from_yaml( - config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig + # Policy config (for obs/action translation — no model loading). + # The YAML may contain a model_path for the local policy, but the remote + # wrapper doesn't load any model — the server does. Override model_path + # to skip the local-filesystem validation in __post_init__. + self.policy_config: Gr00tClosedloopPolicyConfig = _load_config_skip_model_path( + config.policy_config_yaml_path ) self.num_envs = config.num_envs self.device = config.policy_device