diff --git a/verifiers/clients/renderer_client.py b/verifiers/clients/renderer_client.py index d7a58d0e8..5881482f2 100644 --- a/verifiers/clients/renderer_client.py +++ b/verifiers/clients/renderer_client.py @@ -8,6 +8,7 @@ concurrent rollouts tokenize in parallel instead of blocking the event loop. """ +import asyncio import json import threading from collections.abc import Mapping @@ -56,6 +57,7 @@ UserMessage, ) from verifiers.utils.client_utils import setup_openai_client +from verifiers.utils.multimodal import prepare_images_inplace # Module-level bridge counters. Incremented by every RendererClient instance # that tries to stitch a multi-turn prompt; callers (e.g. prime-rl's @@ -472,6 +474,7 @@ def _get_renderer_or_pool( async def to_native_prompt( self, messages: Messages ) -> tuple[list[RendererMessage], dict]: + await asyncio.to_thread(prepare_images_inplace, messages) return ( _attach_tool_call_names([_to_renderer_message(m) for m in messages]), {}, diff --git a/verifiers/types.py b/verifiers/types.py index feac3b168..80ad8c137 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -213,7 +213,7 @@ class ResponseTokens(CustomBaseModel): completion_logprobs: list[float] routed_experts: RoutedExpertsPayload | None = None # Renderer-emitted multimodal sidecar (renderers.base.MultiModalData) - # carrying processed pixel_values / placeholder ranges per modality. + # carrying raw image descriptors / placeholder ranges per modality. # Populated by the renderer client when the rollout went through a # multimodal-aware renderer; ``None`` otherwise. Stored as ``Any`` to # avoid a hard import dependency on ``renderers`` at this layer. @@ -260,7 +260,7 @@ class TrajectoryStepTokens(TypedDict): is_truncated: bool routed_experts: RoutedExpertsPayload | None # Renderer-emitted multimodal sidecar (renderers.base.MultiModalData) - # carrying processed pixel_values / placeholder ranges per modality. + # carrying raw image descriptors / placeholder ranges per modality. # ``NotRequired`` because text-only rollouts (and non-renderer client # types) never populate it. multi_modal_data: NotRequired[Any] diff --git a/verifiers/utils/multimodal.py b/verifiers/utils/multimodal.py new file mode 100644 index 000000000..d9b9c6939 --- /dev/null +++ b/verifiers/utils/multimodal.py @@ -0,0 +1,90 @@ +"""Multimodal ingress helpers for renderer-backed training.""" + +from __future__ import annotations + +from importlib import import_module +from pathlib import Path +from typing import Any + + +def _offload_image_url(url: object, image_dir: Path | None) -> str | None: + try: + offload_image_to_run_assets = getattr( + import_module("renderers.mm_store"), + "offload_image_to_run_assets", + ) + except ( + ImportError, + AttributeError, + ) as exc: # pragma: no cover - dependency-version guard + raise RuntimeError( + "Multimodal training requires a renderers version with raw image " + "asset offload support." + ) from exc + + return offload_image_to_run_assets(url, image_dir=image_dir) + + +def _image_source_url(source: Any) -> object: + if isinstance(source, dict): + return source.get("url") + return getattr(source, "url", None) + + +def _set_image_source_url(source: Any, url: str) -> None: + if isinstance(source, dict): + source["url"] = url + else: + source.url = url + + +def _require_file_image_url(source: Any) -> None: + url = _image_source_url(source) + if not isinstance(url, str) or not url.startswith("file://"): + raise RuntimeError( + "multimodal training requires image_url entries to be offloaded " + "to file:// run image assets" + ) + + +def _prepare_image_source(source: Any, *, image_dir: Path | None) -> None: + result = _offload_image_url(_image_source_url(source), image_dir) + if result is not None: + _set_image_source_url(source, result) + _require_file_image_url(source) + + +def prepare_images_inplace(value: Any, *, image_dir: Path | None = None) -> None: + """Offload image URLs reachable from ``value`` to run image assets. + + Handles OpenAI wire dicts/lists and the pydantic v0/v1 message/content-part + models used by trajectories and traces. + """ + if isinstance(value, dict): + if value.get("type") == "image_url": + source = value.get("image_url") + if source is not None: + _prepare_image_source(source, image_dir=image_dir) + for child in value.values(): + prepare_images_inplace(child, image_dir=image_dir) + return + + if isinstance(value, list): + for child in value: + prepare_images_inplace(child, image_dir=image_dir) + return + + if isinstance(value, tuple): + for child in value: + prepare_images_inplace(child, image_dir=image_dir) + return + + if getattr(value, "type", None) == "image_url": + source = getattr(value, "image_url", None) + if source is not None: + _prepare_image_source(source, image_dir=image_dir) + return + + content = getattr(value, "content", None) + if isinstance(content, (list, tuple)): + prepare_images_inplace(content, image_dir=image_dir) diff --git a/verifiers/v1/ARCHITECTURE.md b/verifiers/v1/ARCHITECTURE.md index 4251aec7f..5b5afda70 100644 --- a/verifiers/v1/ARCHITECTURE.md +++ b/verifiers/v1/ARCHITECTURE.md @@ -70,12 +70,16 @@ end to end: each surviving context window is just another root→leaf path. `Trace.to_record()` (`trace.py`) is the JSON record dump (`model_dump(mode="json")`) for `results.jsonl` / W&B tables, minus the per-node training tensors (`MessageNode.multi_modal_data`, -`routed_experts`, via `_NODE_DUMP_EXCLUDE`): those hold raw numpy bytes that can't round-trip JSON -(the dump raises `UnicodeDecodeError` on real expert ids) and bloat every line. Computed views +`routed_experts`, via `_NODE_DUMP_EXCLUDE`): routed-expert tensors hold raw numpy bytes that can't +round-trip JSON (the dump raises `UnicodeDecodeError` on real expert ids), and multimodal +descriptors are trainer sidecars rather than rollout records. Computed views (`reward`, `branches`, `num_turns`, per-span `duration`) are pydantic properties, so they're never serialized and recompute on load; `state` is excluded. The tensors still reach the trainer over the env-server *wire*, which uses msgpack `model_dump(mode="python")` and carries them as raw `bin` bytes -(not base64) via the field serializers on `MessageNode` (`graph.py`); only the JSON record strips them. +(not base64) via the field serializers on `MessageNode` (`graph.py`); only the JSON record strips +them. Multimodal training uses raw run-image assets: the train client rewrites base64 image parts to +`file://` refs before tracing, and `MessageNode.multi_modal_data` carries lightweight renderer +descriptors (hashes, placeholder ranges, image metadata/refs) rather than image processor outputs. ### Branching: message-level vs renderer-level, and the token invariant @@ -111,9 +115,10 @@ The renderer client avoids the break entirely when it can: instead of re-renderi each turn, the train client (`clients/train.py`) calls `renderer.bridge_to_next_turn(...)`, which keeps the prior `prompt_ids + completion_ids` **verbatim** and only renders the new tail. Verbatim prior ⇒ the stored prefix matches token-for-token ⇒ no fork, one linear branch, invariant intact. -The token-identity check in `commit` is the backstop for when the bridge can't apply (the renderer -returns `None`, multimodal, the eval relay): the break still surfaces as honest branches rather than -silent corruption. +For multimodal renderers, the train client also passes the reusable prefix's `multi_modal_data` so +prior image placeholders and descriptors remain aligned. The token-identity check in `commit` is the +backstop for when the bridge can't apply (the renderer returns `None`, the eval relay): the break +still surfaces as honest branches rather than silent corruption. ## Model access — interception, dialects, clients diff --git a/verifiers/v1/cli/dashboard/eval.py b/verifiers/v1/cli/dashboard/eval.py index 78b0edd15..93bdfa576 100644 --- a/verifiers/v1/cli/dashboard/eval.py +++ b/verifiers/v1/cli/dashboard/eval.py @@ -190,10 +190,12 @@ def _breakdown(done: list[Trace]) -> Table | None: names.extend(n for n in getattr(trace, source) if n not in names) if not names: continue - segments = [ - f"{name} {format_mean(done, lambda t, n=name, s=source: getattr(t, s).get(n, 0.0))}" - for name in names - ] + segments = [] + for name in names: + value = format_mean( + done, lambda t, n=name, s=source: getattr(t, s).get(n, 0.0) + ) + segments.append(f"{name} {value}") grid.add_row(label, " · ".join(segments)) # Resource use over every completed rollout (errored ones still spent tokens/time): tokens and diff --git a/verifiers/v1/clients/client.py b/verifiers/v1/clients/client.py index 7642b0986..dbbc19c79 100644 --- a/verifiers/v1/clients/client.py +++ b/verifiers/v1/clients/client.py @@ -33,6 +33,19 @@ class RelayReply: class Client(ABC): + async def prepare_request_body(self, dialect: Dialect, body: dict) -> dict: + """Normalize a provider request before the interception server parses/traces it. + + Relay clients keep the request verbatim. Training clients may rewrite heavy + in-process payloads (for example base64 images) into stable run-asset refs so the + trace, renderer, and trainer all see the same cheap message content. + """ + return body + + async def prepare_messages(self, dialect: Dialect, messages: list) -> list: + """Normalize typed simulator messages before adding them to the wire body/trace.""" + return messages + @abstractmethod async def get_response( self, diff --git a/verifiers/v1/clients/train.py b/verifiers/v1/clients/train.py index b8c8b7a39..36b612240 100644 --- a/verifiers/v1/clients/train.py +++ b/verifiers/v1/clients/train.py @@ -8,6 +8,7 @@ needs a running vLLM engine. """ +import asyncio import json from collections.abc import Mapping from typing import Any @@ -16,6 +17,7 @@ from renderers import RenderedTokens from renderers import OverlongPromptError as RendererOverlongPromptError from renderers import RendererConfig +from renderers.base import is_multimodal from verifiers.v1.clients.client import SESSION_ID_HEADER, Client from verifiers.v1.dialects import FINISH_REASONS, ChatDialect, Dialect, parse_tools @@ -32,6 +34,7 @@ TurnTokens, Usage, ) +from verifiers.utils.multimodal import prepare_images_inplace def tool_to_wire(tool: Tool) -> dict: @@ -167,16 +170,6 @@ def _is_valid_incremental_tail(messages: list[dict[str, Any]]) -> bool: return all(role == "tool" for role in roles) -def _has_multimodal_content(messages) -> bool: - for message in messages: - content = getattr(message, "content", None) - if not isinstance(content, list): - continue - if any(getattr(part, "type", None) == "image_url" for part in content): - return True - return False - - class TrainClient(Client): """Renders prompts to token ids and calls a vLLM `/inference/v1/generate` engine.""" @@ -213,6 +206,16 @@ def _renderer_pool( ) return self._pool + async def prepare_request_body(self, dialect: Dialect, body: dict) -> dict: + if isinstance(dialect, ChatDialect): + await asyncio.to_thread(prepare_images_inplace, body) + return body + + async def prepare_messages(self, dialect: Dialect, messages: list) -> list: + if isinstance(dialect, ChatDialect): + await asyncio.to_thread(prepare_images_inplace, messages) + return messages + async def get_response( self, dialect: Dialect, @@ -263,23 +266,24 @@ async def get_response( ) bridged_turn: PendingTurn | None = None - # Only build the (O(context)) previous-turn token ids once the cheap guards pass — a - # multimodal prompt or a tail that isn't a clean `[tool*, user?]` extension can't bridge. - can_bridge = ( - turn is not None - and not _has_multimodal_content(prompt) - and _is_valid_incremental_tail(wire_messages) - ) + # Only build the (O(context)) previous-turn token ids once the cheap guards pass: a + # tail that isn't a clean `[tool*, user?]` extension can't bridge. + can_bridge = turn is not None and _is_valid_incremental_tail(wire_messages) previous_ids = turn.previous_token_ids() if can_bridge else None if previous_ids is not None: previous_prompt_ids, previous_completion_ids = previous_ids def bridge(): + kwargs: dict[str, Any] = {"tools": wire_tools} + if is_multimodal(renderer): + kwargs["previous_multi_modal_data"] = ( + turn.previous_multi_modal_data() + ) return renderer.bridge_to_next_turn( previous_prompt_ids, previous_completion_ids, wire_messages, - tools=wire_tools, + **kwargs, ) bridged = await _maybe_offload(renderer, bridge) diff --git a/verifiers/v1/graph.py b/verifiers/v1/graph.py index 371bd7174..a8104d537 100644 --- a/verifiers/v1/graph.py +++ b/verifiers/v1/graph.py @@ -60,6 +60,46 @@ def _decode_ndarray(d: dict) -> np.ndarray: return np.frombuffer(d["data"], dtype=np.dtype(d["dtype"])).reshape(d["shape"]) +_PROCESSED_MM_KEYS = frozenset({"pixel_values", "image_embeds", "image_features"}) + + +def _contains_processed_mm_key(value: Any) -> bool: + if isinstance(value, dict): + return bool(_PROCESSED_MM_KEYS.intersection(value)) or any( + _contains_processed_mm_key(v) for v in value.values() + ) + if isinstance(value, (list, tuple)): + return any(_contains_processed_mm_key(v) for v in value) + return False + + +def _validate_raw_mm_item(item: Any) -> dict[str, Any]: + if not isinstance(item, dict): + raise TypeError( + "v1 multimodal sidecars must be raw image descriptor dicts, " + f"got {type(item).__name__}" + ) + if _contains_processed_mm_key(item): + raise TypeError( + "v1 multimodal sidecars must be raw image descriptors, " + "not processed multimodal payloads" + ) + if not isinstance(item.get("raw_image_uri"), str) or not item["raw_image_uri"]: + raise ValueError("v1 multimodal sidecars require raw_image_uri") + return dict(item) + + +def _validate_raw_mm_data(mmd: MultiModalData) -> MultiModalData: + return MultiModalData( + mm_hashes={k: list(v) for k, v in mmd.mm_hashes.items()}, + mm_placeholders={k: list(v) for k, v in mmd.mm_placeholders.items()}, + mm_items={ + modality: [_validate_raw_mm_item(item) for item in items] + for modality, items in mmd.mm_items.items() + }, + ) + + class MessageNode(StrictBaseModel): """One message in the graph: a message plus the tokens it adds to the cumulative sequence. Concatenating a root→leaf path's nodes reconstructs that branch's full token @@ -97,14 +137,16 @@ class MessageNode(StrictBaseModel): finish_reason: FinishReason = None """The response's finish reason (assistant nodes only) — kept for truncation detection.""" multi_modal_data: MultiModalData | None = None - """The renderer items for the images this message's content introduces (pixel tensors, - grids, hashes, placeholders) — the only carrier of the pixels from the env server to the - trainer. `Branch.multi_modal_data` concatenates them along the path into the training - `mm_kwargs`. Rides the wire as raw bytes (msgpack `bin`) since pydantic can't JSON the numpy; - kept off disk by the dump-site `exclude` in prime-rl (the tensors bloat the rollout jsonl).""" + """The renderer items for images this message introduces. + + With the raw-image path, items are lightweight descriptors (hashes, grid metadata, and + optional run-image refs), not image processor tensors. `Branch.multi_modal_data` concatenates + them along the path for the trainer. Old processed-payload sidecars are rejected. + """ usage: Usage | None = None - """Provider-reported token usage for this message's response (assistant nodes). Preserved - on the wire and on disk, including cache-read tokens when the provider reports them.""" + """Provider-reported token usage for this message's response (assistant nodes). Preserved on + the wire and on disk so dashboards can show token counts and cost even when the endpoint + returns no token ids.""" routed_experts: np.ndarray | None = None """This node's slice of the MoE expert-routing array — uint8 `[len(token_ids), layers, top_k]`, the expert ids inference selected for exactly this node's tokens. Attributed from @@ -116,10 +158,10 @@ class MessageNode(StrictBaseModel): @field_serializer("multi_modal_data") def serialize_multi_modal_data(self, mmd: MultiModalData | None) -> dict | None: - """`MultiModalData` -> msgpack-safe dict so the pixel tensors ride the wire; numpy - `mm_items` values become raw-bytes `__nd__` dicts (every renderer emits `return_tensors="np"`).""" + """`MultiModalData` -> msgpack-safe raw descriptor dict.""" if mmd is None: return None + mmd = _validate_raw_mm_data(mmd) return { "mm_hashes": {k: list(v) for k, v in mmd.mm_hashes.items()}, "mm_placeholders": { @@ -127,9 +169,7 @@ def serialize_multi_modal_data(self, mmd: MultiModalData | None) -> dict | None: for modality, ranges in mmd.mm_placeholders.items() }, "mm_items": { - modality: [ - {k: _encode_ndarray(v) for k, v in item.items()} for item in items - ] + modality: [dict(item) for item in items] for modality, items in mmd.mm_items.items() }, } @@ -137,25 +177,29 @@ def serialize_multi_modal_data(self, mmd: MultiModalData | None) -> dict | None: @field_validator("multi_modal_data", mode="before") @classmethod def deserialize_multi_modal_data(cls, value: Any) -> MultiModalData | None: - if value is None or isinstance(value, MultiModalData): + if value is None: return value + if isinstance(value, MultiModalData): + return _validate_raw_mm_data(value) if not isinstance(value, dict): raise TypeError(f"cannot build MultiModalData from {type(value).__name__}") - return MultiModalData( - mm_hashes={k: list(v) for k, v in (value.get("mm_hashes") or {}).items()}, - mm_placeholders={ - modality: [ - PlaceholderRange(offset=p["offset"], length=p["length"]) - for p in ranges - ] - for modality, ranges in (value.get("mm_placeholders") or {}).items() - }, - mm_items={ - modality: [ - {k: _decode_ndarray(v) for k, v in item.items()} for item in items - ] - for modality, items in (value.get("mm_items") or {}).items() - }, + return _validate_raw_mm_data( + MultiModalData( + mm_hashes={ + k: list(v) for k, v in (value.get("mm_hashes") or {}).items() + }, + mm_placeholders={ + modality: [ + PlaceholderRange(offset=p["offset"], length=p["length"]) + for p in ranges + ] + for modality, ranges in (value.get("mm_placeholders") or {}).items() + }, + mm_items={ + modality: list(items) + for modality, items in (value.get("mm_items") or {}).items() + }, + ) ) @field_serializer("routed_experts") @@ -304,6 +348,23 @@ def prompt_message_spans( for span in tail_spans ] + def previous_multi_modal_data(self) -> MultiModalData | None: + """Concatenate multimodal sidecars attached to the reusable prefix.""" + merged = MultiModalData() + found = False + for nid in self.prefix_node_ids: + mmd = self.trace.nodes[nid].multi_modal_data + if mmd is None or mmd.is_empty(): + continue + found = True + for modality, items in mmd.mm_items.items(): + merged.mm_items.setdefault(modality, []).extend(items) + for modality, hashes in mmd.mm_hashes.items(): + merged.mm_hashes.setdefault(modality, []).extend(hashes) + for modality, placeholders in mmd.mm_placeholders.items(): + merged.mm_placeholders.setdefault(modality, []).extend(placeholders) + return merged if found else None + def commit(self, response: Response) -> None: _commit_turn(self, response) @@ -360,8 +421,9 @@ def _attribute_mm( renderer emits items per modality in prompt order (message order, then content-part order), so we walk the path advancing a per-modality cursor over every message's media but write only the nodes created this turn — `path[:num_reused]` is the reused prefix, already - attributed when first created. Item order is all training needs; placeholder offsets aren't - carried.""" + attributed when first created. Each node gets the hashes/items/placeholders for exactly the + media it introduced, preserving vLLM multimodal-list alignment when those node sidecars are + later merged for bridge or training.""" if mmd is None or mmd.is_empty(): return cursors: dict[str, int] = {} @@ -371,6 +433,7 @@ def _attribute_mm( continue node_items: dict[str, list] = {} node_hashes: dict[str, list] = {} + node_placeholders: dict[str, list[PlaceholderRange]] = {} for part in content: modality = _part_modality(part) if modality is None: @@ -382,13 +445,20 @@ def _attribute_mm( continue items = mmd.mm_items.get(modality) or [] hashes = mmd.mm_hashes.get(modality) or [] + placeholders = mmd.mm_placeholders.get(modality) or [] if k < len(items): node_items.setdefault(modality, []).append(items[k]) if k < len(hashes): node_hashes.setdefault(modality, []).append(hashes[k]) - if node_items: - trace.nodes[node_id].multi_modal_data = MultiModalData( - mm_items=node_items, mm_hashes=node_hashes + if k < len(placeholders): + node_placeholders.setdefault(modality, []).append(placeholders[k]) + if node_items or node_hashes or node_placeholders: + trace.nodes[node_id].multi_modal_data = _validate_raw_mm_data( + MultiModalData( + mm_items=node_items, + mm_hashes=node_hashes, + mm_placeholders=node_placeholders, + ) ) diff --git a/verifiers/v1/interception/server.py b/verifiers/v1/interception/server.py index ddf7ffac6..881b14228 100644 --- a/verifiers/v1/interception/server.py +++ b/verifiers/v1/interception/server.py @@ -37,6 +37,7 @@ from verifiers.v1.dialects.base import is_sse_done_event from verifiers.v1 import graph from verifiers.v1.errors import ( + InterceptionError, OverlongPromptError, RolloutError, TasksetError, @@ -262,6 +263,18 @@ async def handle_request( # alias after parsing so the wire body does not survive model inference. request._read_bytes = None del raw + try: + body = await session.ctx.client.prepare_request_body(dialect, body) + except RolloutError as e: + return self._fail(session, dialect, e) + except Exception as e: + return self._fail( + session, + dialect, + InterceptionError( + f"request preparation failed: {type(e).__name__}: {e}" + ), + ) logger.debug( "intercept %s: id=%s stream=%s", request.path, @@ -288,7 +301,18 @@ async def handle_request( and session.trace.num_turns == 0 ): if session.opening is None: - session.opening = await session.user("") + try: + session.opening = await session.ctx.client.prepare_messages( + dialect, await session.user("") + ) + except RolloutError as e: + return self._fail(session, dialect, e) + except Exception as e: + return self._fail( + session, + dialect, + UserError(f"user simulator failed: {type(e).__name__}: {e}"), + ) body = dialect.extend(body, None, session.opening) prompt = [*prompt, *session.opening] # If the simulator ended at the open (its taskset's `@stop` now fires), the loop's @@ -383,6 +407,9 @@ async def handle_request( return _completion_response(completion) try: user_messages = await session.user(response.message.content or "") + user_messages = await session.ctx.client.prepare_messages( + dialect, user_messages + ) except RolloutError as e: return self._fail(session, dialect, e) except Exception as e: diff --git a/verifiers/v1/legacy.py b/verifiers/v1/legacy.py index 02410b012..370a2a059 100644 --- a/verifiers/v1/legacy.py +++ b/verifiers/v1/legacy.py @@ -13,6 +13,7 @@ v1 stays importable without the v0 package present. """ +import asyncio import contextlib import logging from pathlib import Path @@ -162,6 +163,7 @@ def _to_v1_tokens(raw: Any) -> TurnTokens | None: prompt_ids=list(raw.get("prompt_ids") or []), completion_ids=list(raw.get("completion_ids") or []), completion_logprobs=list(raw.get("completion_logprobs") or []), + multi_modal_data=raw.get("multi_modal_data"), ) @@ -360,6 +362,20 @@ def _v0_client(self, client_config: ClientConfig, model: str): self._clients[key] = resolve_client(v0_config) return self._clients[key] + async def _state_output_with_live_trajectory(self, state: Any) -> dict: + """Build v0 rollout output metadata while preserving live trajectory sidecars. + + The JSON save path deltas ``tokens.multi_modal_data`` to avoid repeated + cumulative multimodal sidecars. Trace reconstruction needs the live, + cumulative sidecar for each turn so image descriptors align with the + full prompt the renderer saw. + """ + from verifiers.utils.save_utils import state_to_output + + out = await asyncio.to_thread(state_to_output, state, []) + out["trajectory"] = state.get("trajectory", []) + return out + async def _run_v0( self, task_idx: int, @@ -368,13 +384,13 @@ async def _run_v0( sampling: SamplingConfig, ) -> dict: client = self._v0_client(client_config, model) - return await self.env.run_rollout( + state = await self.env._run_rollout_state( input=dict(self.dataset[task_idx]), client=client, model=model, sampling_args=sampling.model_dump(exclude_none=True), - state_columns=["trajectory"], ) + return await self._state_output_with_live_trajectory(state) async def _run_rollout(self, req: RunRolloutRequest) -> RunRolloutResponse: out = await self._run_v0(req.task_idx, req.client, req.model, req.sampling) @@ -385,12 +401,14 @@ async def _run_rollout(self, req: RunRolloutRequest) -> RunRolloutResponse: async def _run_group(self, req: RunGroupRequest) -> RunGroupResponse: client = self._v0_client(req.client, req.model) # run_group scores the rollouts together so group/preference reward funcs apply. - outs = await self.env.run_group( + states = await self.env._run_group_states( group_inputs=[dict(self.dataset[req.task_idx]) for _ in range(req.n)], client=client, model=req.model, sampling_args=req.sampling.model_dump(exclude_none=True), - state_columns=["trajectory"], + ) + outs = await asyncio.gather( + *(self._state_output_with_live_trajectory(state) for state in states) ) traces = [ rollout_output_to_trace(out, req.task_idx).model_dump() for out in outs diff --git a/verifiers/v1/trace.py b/verifiers/v1/trace.py index dc0792c78..3ac6f68ba 100644 --- a/verifiers/v1/trace.py +++ b/verifiers/v1/trace.py @@ -133,10 +133,12 @@ def logprobs(self) -> list[float]: @property def multi_modal_data(self) -> MultiModalData | None: - """The branch's multimodal sidecar — every node's images concatenated in path (token) - order. None when the branch has no images. Drives the training `mm_kwargs` (the renderer - items per modality); the per-token `mm_token_type_ids` come from the token ids, so no - placeholder offsets are carried. Never persisted (node mm is transient).""" + """The branch's multimodal sidecar — every node's images concatenated in path order. + + None when the branch has no images. The raw-image path carries lightweight descriptors + plus placeholder ranges, so downstream vLLM/training multimodal payloads can align hashes, + placeholders, and item refs without reprocessing images in the env worker. + """ merged = MultiModalData() found = False for node in self.nodes: @@ -148,6 +150,8 @@ def multi_modal_data(self) -> MultiModalData | None: merged.mm_items.setdefault(modality, []).extend(items) for modality, hashes in mmd.mm_hashes.items(): merged.mm_hashes.setdefault(modality, []).extend(hashes) + for modality, placeholders in mmd.mm_placeholders.items(): + merged.mm_placeholders.setdefault(modality, []).extend(placeholders) return merged if found else None @property diff --git a/verifiers/v1/types.py b/verifiers/v1/types.py index b86604acb..52b3d8356 100644 --- a/verifiers/v1/types.py +++ b/verifiers/v1/types.py @@ -240,8 +240,8 @@ class TurnTokens(StrictBaseModel): default=None, exclude=True ) is_content: list[bool] | None = Field(default=None, exclude=True) - # Transient carrier (excluded): the renderer's multimodal sidecar (image tensors + offsets), - # attributed per node by the turn's `commit`, then dropped — never persisted. + # Transient carrier (excluded): the renderer's multimodal sidecar (raw-image descriptors, + # hashes, and placeholder offsets), attributed per node by the turn's `commit`, then dropped. multi_modal_data: MultiModalData | None = Field(default=None, exclude=True) # Transient carrier (excluded): the MoE expert-routing data from `generate` (expert ids # per token), attributed per node by the turn's `commit` into `MessageNode.routed_experts`,