Support parallel envs in Pi0RemotePolicy#688
Conversation
Replace the num_envs==1 assert with a per-env action cache and a loop over envs in get_action. openpi's wire format is one obs per request, so this still issues one server call per env per chunk refill; the correctness payoff is that any rollout (policy_runner or eval_runner) can now request num_envs>1 instead of being silently single-env. Concretely: - Pi0EmbodimentAdapter.extract(observation) gains an env_id parameter so adapters slice one env at a time. Pi0DroidAdapter swaps its hardcoded [0] indexing for [env_id]. This is a breaking change for any external Pi0EmbodimentAdapter subclass. - Pi0RemotePolicy keeps per-env _cached_action_chunks and _next_chunk_steps lists, lazy-allocated on the first get_action when num_envs is observable. Mid-rollout num_envs changes assert. - reset(env_ids) clears only the listed envs' caches, so per-env termination (the common case in parallel rollouts) leaves untouched envs replaying their cached chunks uninterrupted. - Connection-drop reconnect still flushes all per-env caches to preserve the prior "stale server state" defensive behaviour. Tests cover the num_envs>1 loop, partial reset semantics, and the reconnect-flush behaviour; the single-env tests are unchanged. Signed-off-by: Clemens Volk <cvolk@nvidia.com>
There was a problem hiding this comment.
Code Review for PR #688: Support parallel envs in Pi0RemotePolicy
Thank you for this well-structured PR! The implementation cleanly extends Pi0RemotePolicy to support parallel environments while maintaining backward compatibility for the single-env case. Here are my findings:
✅ Strengths
- Well-documented breaking change: The
Pi0EmbodimentAdapter.extract()interface change is clearly noted in the PR description and commit message. - Efficient lazy allocation: Per-env caches are only allocated on first
get_action()call whennum_envsis observable. - Clean partial reset semantics:
reset(env_ids)correctly preserves caches for untouched envs, which is important for per-env termination in parallel rollouts. - Comprehensive test coverage: New tests for parallel envs, partial reset, and existing reconnect behavior.
🔄 Update (2026-05-20)
Reviewed new commit ae326000. Changes since previous review:
✅ Fixed: 0-dim env_ids tensor handling in reset() - now uses .reshape(-1).tolist() as suggested.
Remaining: The P1 concern about step counters being mutated before exception rollback (line 175) was not addressed in this commit. This is a minor edge case for non-websocket exceptions.
Overall, looking good! 🎉
Greptile SummaryThis PR extends
Confidence Score: 3/5Safe to merge for single-env workloads; multi-env rollouts that hit a websocket reconnect mid-loop will receive an action batch that mixes pre-reconnect and post-reconnect model outputs for different envs in the same step. The reconnect flush in _call_server_with_retry wipes all per-env caches and step counters while get_action is still mid-loop. Envs processed before the reconnect have their actions already staged in a local list, but their caches and counters are reset to 0, while the env that triggered the reconnect gets a fresh chunk. The resulting action tensor has incoherent provenance across envs, and step counters diverge from what was actually returned. isaaclab_arena_openpi/policy/pi0_remote_policy.py — specifically the interaction between the per-env loop in get_action and the all-env flush in _call_server_with_retry. Important Files Changed
Sequence DiagramsequenceDiagram
participant Caller
participant Pi0RemotePolicy
participant Pi0EmbodimentAdapter
participant OpenPiServer
Caller->>Pi0RemotePolicy: get_action(env, obs)
Pi0RemotePolicy->>Pi0RemotePolicy: _maybe_init_per_env_state(N)
loop for env_id in 0..N-1
alt cache miss
Pi0RemotePolicy->>Pi0EmbodimentAdapter: extract(obs, env_id)
Pi0EmbodimentAdapter-->>Pi0RemotePolicy: extracted obs
Pi0RemotePolicy->>Pi0EmbodimentAdapter: pack_request(extracted, task)
Pi0EmbodimentAdapter-->>Pi0RemotePolicy: wire-format request
Pi0RemotePolicy->>OpenPiServer: infer(request)
OpenPiServer-->>Pi0RemotePolicy: actions (H, action_dim)
Pi0RemotePolicy->>Pi0RemotePolicy: store chunk, reset step to 0
end
Pi0RemotePolicy->>Pi0RemotePolicy: append chunk row, increment step
end
Pi0RemotePolicy->>Pi0RemotePolicy: np.stack into (N, action_dim)
Pi0RemotePolicy-->>Caller: torch.Tensor shape (N, action_dim)
Caller->>Pi0RemotePolicy: reset(env_ids)
Pi0RemotePolicy->>Pi0RemotePolicy: clear selected env caches and counters
|
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Summary
Arena rollouts can now request
num_envs > 1instead of being restricted to single-env.reset(env_ids)clears only the listed envs' caches, so per-env termination leaves untouched envs replaying uninterrupted.