Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ dependencies = [
"prime-pydantic-config>=0.3.0.dev83",
]

[project.optional-dependencies]
vision = [
"pillow>=12.2.0",
"torch>=2.11.0",
"torchvision>=0.26.0",
]

[tool.hatch.version]
source = "vcs"
# Tags look like ``renderers-v0.1.8`` (prefix matches the publish.yml
Expand Down
2 changes: 2 additions & 0 deletions renderers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
LagunaXS2RendererConfig,
Llama3RendererConfig,
MiniMaxM2RendererConfig,
MultimodalOutput,
Nemotron3RendererConfig,
Nemotron3UltraRendererConfig,
Qwen35RendererConfig,
Expand Down Expand Up @@ -144,6 +145,7 @@ def __dir__() -> list[str]:
"Message",
"MiniMaxM2Renderer",
"MiniMaxM2RendererConfig",
"MultimodalOutput",
"MultiModalData",
"MultimodalRenderer",
"Nemotron3Renderer",
Expand Down
22 changes: 10 additions & 12 deletions renderers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,10 @@ class PlaceholderRange:
class MultiModalData:
"""Multimodal sidecar produced alongside the token stream.

Renderer output is framework-agnostic: ``mm_items[modality][i]`` is a
plain ``dict`` mirroring the per-item output of a HuggingFace processor
(e.g. ``{"pixel_values": Tensor, "image_grid_thw": Tensor}`` for
Qwen3-VL images). Translation to engine-specific wire formats — vLLM's
``MultiModalKwargsItem``, SGLang's payload, etc. — happens in the
inference glue layer (see ``renderers.client``).
``mm_items[modality][i]`` follows the renderer's configured
``multimodal_output``. The default ``"raw"`` mode emits JSON-safe image
descriptor envelopes for inference paths. ``"processed"`` emits
image-processor payloads such as ``pixel_values`` for SFT/training paths.
"""

mm_hashes: dict[str, list[str]] = field(default_factory=dict)
Expand Down Expand Up @@ -760,8 +758,8 @@ def bridge_to_next_turn(
Text-only renderers return :class:`RenderedTokens` with
``multi_modal_data=None``. Multimodal renderers (see
:class:`MultimodalRenderer`) populate ``multi_modal_data`` so
the caller can recover placeholder offsets + per-item processed
tensors for the new full prompt; they also accept a
the caller can recover placeholder offsets + per-item image
descriptors for the new full prompt; they also accept a
``previous_multi_modal_data`` kwarg via the
:class:`MultimodalRenderer` Protocol override.

Expand Down Expand Up @@ -821,8 +819,8 @@ def bridge_to_next_turn(
the combined token sequence and silently falls back to
hash-cache lookup (or errors)
- returns :class:`RenderedTokens` (not ``list[int]``) so the
caller can recover the placeholder offsets + per-item
processed tensors for the new full prompt
caller can recover the placeholder offsets + per-item image
descriptors for the new full prompt
"""
...

Expand Down Expand Up @@ -1475,7 +1473,7 @@ def _resolve_auto_config(
model_name = getattr(tokenizer, "name_or_path", "")
renderer_name = MODEL_RENDERER_MAP.get(model_name)

preserve_carry = {}
preserve_carry: dict[str, Any] = {"multimodal_output": auto.multimodal_output}
if auto.thinking_retention is not None:
preserve_carry["thinking_retention"] = auto.thinking_retention

Expand Down Expand Up @@ -1526,7 +1524,7 @@ def _resolve_auto_config(
"reasoning_parser=...) to enable structured output parsing.",
model_name or "<unnamed tokenizer>",
)
return DefaultRendererConfig()
return DefaultRendererConfig(multimodal_output=auto.multimodal_output)


# ---------------------------------------------------------------------------
Expand Down
192 changes: 83 additions & 109 deletions renderers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
from collections.abc import Mapping
from dataclasses import replace
from typing import Any, cast

import httpx
Expand Down Expand Up @@ -170,11 +171,12 @@ async def generate(
attribution (``is_content`` / ``sampled_mask`` / ``message_indices`` /
``message_roles``) into the result without re-rendering.

For multimodal renderers (e.g. ``Qwen3VLRenderer``), the call goes
For multimodal renderers, the call goes
through ``renderer.render(...)`` to recover the ``multi_modal_data``
sidecar, then serializes it to vLLM's ``features`` schema (mm_hashes,
mm_placeholders, kwargs_data) before POSTing. The serializer imports
``vllm.*`` lazily so text-only consumers never pay for the import.
mm_placeholders, kwargs_data) before POSTing. Raw image ``kwargs_data``
slots always carry a descriptor ref — every image (current and prior
turns) is sent as a pointer that the inference endpoint materializes.

``max_prompt_len`` controls the pre-flight overflow check. When the
rendered prompt is strictly longer than the cap, the request is never
Expand Down Expand Up @@ -211,12 +213,14 @@ async def generate(
def _prepare():
if prompt_ids is not None:
# Caller-supplied prompt; if they also gave us pre-computed
# attribution (e.g. the bridge path in verifiers), thread it
# through unchanged.
# attribution (e.g. the bridge path in verifiers), thread it through.
prompt_mm_data = multi_modal_data
if prompt_mm_data is None and prompt_attribution is not None:
prompt_mm_data = prompt_attribution.multi_modal_data
return (
list(prompt_ids),
renderer.get_stop_token_ids(),
multi_modal_data,
prompt_mm_data,
prompt_attribution,
)
rendered = renderer.render(messages, tools=tools, add_generation_prompt=True)
Expand Down Expand Up @@ -248,11 +252,19 @@ def _prepare():
"token_ids": prompt_ids,
"sampling_params": sp,
}
features = (
_build_mm_features(renderer, mm_data)
if mm_data and not mm_data.is_empty()
else None
)

def _features_and_descriptor_mm() -> tuple[
dict[str, Any] | None, MultiModalData | None
]:
if mm_data is None or mm_data.is_empty():
return None, mm_data
# Every image carries its raw ref (the pointer); persisted mm_data keeps it,
# so prior-turn images carry their ref forward unchanged.
return _build_vllm_mm_features(mm_data), mm_data

features, out_mm_data = await _maybe_offload(renderer, _features_and_descriptor_mm)
if prompt_attr is not None and out_mm_data is not None:
prompt_attr = replace(prompt_attr, multi_modal_data=out_mm_data)
Comment thread
cursor[bot] marked this conversation as resolved.
if features is not None:
body["features"] = features
if cache_salt is not None:
Expand Down Expand Up @@ -322,7 +334,7 @@ def _prepare():
# The mm sidecar consumed on the request side, surfaced back so
# callers can persist it on the trajectory step for downstream
# multi-turn bridging and training-sample construction.
"multi_modal_data": mm_data,
"multi_modal_data": out_mm_data,
# The renderer's per-token attribution for the prompt — either
# the RenderedTokens computed here via renderer.render(...) or
# the one threaded in by the caller alongside prompt_ids (the
Expand All @@ -334,113 +346,75 @@ def _prepare():
}


def _build_mm_features(
renderer: Renderer | RendererPool,
mm_data: MultiModalData,
) -> dict[str, Any] | None:
def _build_vllm_mm_features(mm_data: MultiModalData) -> dict[str, Any]:
"""Serialize ``MultiModalData`` to vLLM's ``/inference/v1/generate`` features payload.

vLLM's ``MultiModalFeatures`` carries three things: hashes (for cache
lookup), placeholder positions (so the engine knows where in the
token stream each item lives), and per-item ``MultiModalKwargsItem``
base64-encoded. The encoding requires vLLM-side type info — what
fields belong to each modality, how they batch — and is currently
model-family specific. For now we dispatch on the renderer class;
extend the dispatch table as more multimodal renderers land.

NOTE — future engine pluggability: this encoder is vLLM 0.20-specific
(uses ``vllm.multimodal.inputs.MultiModalKwargsItems``,
``vllm.entrypoints.serve.disagg.mm_serde.encode_mm_kwargs_item``, and
``_create_qwen2vl_field_factory``). When a second inference engine
arrives (SGLang, MAX, ...) the renderer client should be parameterized
on engine: either (a) move the encoder onto the renderer as
``encode_mm_for_<engine>(mm_data)`` methods, or (b) accept an
``Encoder`` strategy at the ``generate(...)`` call site. The data type
(``MultiModalData``) is already framework-agnostic and does not need
to change. Don't pre-build the abstraction with one engine in tree.
vLLM's ``MultiModalFeatures`` carries three things: hashes, placeholder
positions (so the engine knows where in the token stream each item lives),
and one raw ref per item. Raw multimodal descriptors use the common envelope
emitted by renderers; family-specific geometry stays inside the descriptor
payload and is interpreted downstream by prime-rl/vLLM adapters.
"""
from renderers.qwen3_vl import Qwen3VLRenderer
from renderers.qwen35 import Qwen35Renderer

# Type dispatch only needs the renderer class. Pools expose
# ``renderer_cls`` as a snapshot attribute, so we don't have to check
# out a slot just to read ``type(r)``.
renderer_cls = (
renderer.renderer_cls if isinstance(renderer, RendererPool) else type(renderer)
)

# Qwen3-VL and Qwen3.5 both ship ``pixel_values`` + ``image_grid_thw``
# via the shared Qwen2-VL field factory. ``spatial_merge_size=2`` is
# the family default and matches every Qwen-VL processor in tree.
if issubclass(renderer_cls, (Qwen3VLRenderer, Qwen35Renderer)):
return _build_qwen_vl_features(mm_data, spatial_merge_size=2)

raise NotImplementedError(
f"Multimodal serialization not implemented for {renderer_cls.__name__}. "
"Add a dispatch branch in renderers.client._build_mm_features."
from renderers.mm_store import (
RAW_MM_ITEM_KIND,
raw_mm_ref,
)


def _build_qwen_vl_features(
mm_data: MultiModalData, *, spatial_merge_size: int
) -> dict[str, Any]:
"""vLLM features payload for the Qwen-VL family (Qwen2-VL / Qwen3-VL).

Stacks per-image processor outputs back into a batched ``BatchFeature``,
runs the Qwen2-VL field factory (shared across the family), wraps as
``MultiModalKwargsItems``, base64-encodes each item, and assembles a
JSON-serializable dict matching vLLM's ``MultiModalFeatures`` schema.

Returns ``None`` semantics live one level up — this helper assumes
the caller already verified ``mm_data`` is non-empty.
"""
try:
import torch
from transformers.feature_extraction_utils import BatchFeature
from vllm.entrypoints.serve.disagg.mm_serde import encode_mm_kwargs_item
from vllm.model_executor.models.qwen2_vl import _create_qwen2vl_field_factory
from vllm.multimodal.inputs import MultiModalKwargsItems
except ImportError as exc:
raise RuntimeError(
"Multimodal generate via /inference/v1/generate requires `vllm` "
"and `torch` to encode the features payload. Install vLLM in this "
"environment, or pre-build features upstream."
) from exc

out: dict[str, Any] = {
"mm_hashes": {},
"mm_placeholders": {},
"kwargs_data": {},
}

image_items = mm_data.mm_items.get("image") or []
if image_items:
# mm_items now ship numpy arrays (the renderer is torch-free);
# convert at this vLLM-glue boundary where torch is already a
# hard dependency.
pixel_values = torch.cat(
[torch.as_tensor(it["pixel_values"]) for it in image_items], dim=0
)
image_grid_thw = torch.cat(
[torch.as_tensor(it["image_grid_thw"]) for it in image_items], dim=0
)
hf_inputs = BatchFeature(
data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}
)
config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs)
kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config)
encoded = [encode_mm_kwargs_item(it) for it in kwargs_items["image"]]
out["kwargs_data"]["image"] = encoded
out["mm_hashes"]["image"] = list(mm_data.mm_hashes.get("image") or [])
out["mm_placeholders"]["image"] = [
{"offset": p.offset, "length": p.length}
for p in mm_data.mm_placeholders.get("image") or []
]

# If kwargs_data is empty across all modalities, drop the key so vLLM
# falls back to the hash-only (cache-hit) path. Otherwise hand it the
# full payload.
if not any(out["kwargs_data"].values()):
out["kwargs_data"] = None
for source_modality, items in mm_data.mm_items.items():
if not items:
continue
mm_hashes = list(mm_data.mm_hashes.get(source_modality) or [])
placeholders = list(mm_data.mm_placeholders.get(source_modality) or [])
if len(mm_hashes) != len(items) or len(placeholders) != len(items):
raise ValueError(
"Multimodal sidecar length mismatch: "
f"modality={source_modality} items={len(items)} "
f"hashes={len(mm_hashes)} placeholders={len(placeholders)}"
)

for idx, item in enumerate(items):
if item.get("kind") != RAW_MM_ITEM_KIND:
raise NotImplementedError(
"renderers.client.generate() requires raw multimodal "
"descriptor envelopes (multimodal_output='raw'); "
f"got item keys {sorted(item)} for modality {source_modality!r}."
)
feature_modality = item.get("vllm_modality") or source_modality
if not isinstance(feature_modality, str) or not feature_modality:
raise ValueError("raw multimodal item has invalid vllm_modality")

raw_image_uri = item.get("raw_image_uri")
family = item.get("family")
fingerprint = item.get("layout_fingerprint")
payload = item.get("payload")
if not isinstance(raw_image_uri, str) or not raw_image_uri:
raise ValueError("raw multimodal item is missing raw_image_uri")
if not isinstance(family, str) or not family:
raise ValueError("raw multimodal item is missing family")
if not isinstance(fingerprint, str) or not fingerprint:
raise ValueError("raw multimodal item is missing layout_fingerprint")
if not isinstance(payload, dict):
raise ValueError("raw multimodal item payload must be a dict")

out["mm_hashes"].setdefault(feature_modality, []).append(mm_hashes[idx])
out["mm_placeholders"].setdefault(feature_modality, []).append(
{"offset": placeholders[idx].offset, "length": placeholders[idx].length}
)
out["kwargs_data"].setdefault(feature_modality, []).append(
raw_mm_ref(
family=family,
fingerprint=fingerprint,
modality=feature_modality,
mm_hash=mm_hashes[idx],
raw_image_uri=raw_image_uri,
payload=payload,
)
Comment thread
cursor[bot] marked this conversation as resolved.
)

return out
Loading
Loading