diff --git a/pyproject.toml b/pyproject.toml index dc75397..3db0939 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,13 @@ dependencies = [ "prime-pydantic-config>=0.3.0.dev83", ] +[project.optional-dependencies] +vision = [ + "pillow>=12.2.0", + "torch>=2.11.0", + "torchvision>=0.26.0", +] + [tool.hatch.version] source = "vcs" # Tags look like ``renderers-v0.1.8`` (prefix matches the publish.yml diff --git a/renderers/__init__.py b/renderers/__init__.py index 9fd385e..bb8e6ba 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -55,6 +55,7 @@ LagunaXS2RendererConfig, Llama3RendererConfig, MiniMaxM2RendererConfig, + MultimodalOutput, Nemotron3RendererConfig, Nemotron3UltraRendererConfig, Qwen35RendererConfig, @@ -144,6 +145,7 @@ def __dir__() -> list[str]: "Message", "MiniMaxM2Renderer", "MiniMaxM2RendererConfig", + "MultimodalOutput", "MultiModalData", "MultimodalRenderer", "Nemotron3Renderer", diff --git a/renderers/base.py b/renderers/base.py index b1d397f..81d3d60 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -205,12 +205,10 @@ class PlaceholderRange: class MultiModalData: """Multimodal sidecar produced alongside the token stream. - Renderer output is framework-agnostic: ``mm_items[modality][i]`` is a - plain ``dict`` mirroring the per-item output of a HuggingFace processor - (e.g. ``{"pixel_values": Tensor, "image_grid_thw": Tensor}`` for - Qwen3-VL images). Translation to engine-specific wire formats — vLLM's - ``MultiModalKwargsItem``, SGLang's payload, etc. — happens in the - inference glue layer (see ``renderers.client``). + ``mm_items[modality][i]`` follows the renderer's configured + ``multimodal_output``. The default ``"raw"`` mode emits JSON-safe image + descriptor envelopes for inference paths. ``"processed"`` emits + image-processor payloads such as ``pixel_values`` for SFT/training paths. """ mm_hashes: dict[str, list[str]] = field(default_factory=dict) @@ -760,8 +758,8 @@ def bridge_to_next_turn( Text-only renderers return :class:`RenderedTokens` with ``multi_modal_data=None``. Multimodal renderers (see :class:`MultimodalRenderer`) populate ``multi_modal_data`` so - the caller can recover placeholder offsets + per-item processed - tensors for the new full prompt; they also accept a + the caller can recover placeholder offsets + per-item image + descriptors for the new full prompt; they also accept a ``previous_multi_modal_data`` kwarg via the :class:`MultimodalRenderer` Protocol override. @@ -821,8 +819,8 @@ def bridge_to_next_turn( the combined token sequence and silently falls back to hash-cache lookup (or errors) - returns :class:`RenderedTokens` (not ``list[int]``) so the - caller can recover the placeholder offsets + per-item - processed tensors for the new full prompt + caller can recover the placeholder offsets + per-item image + descriptors for the new full prompt """ ... @@ -1475,7 +1473,7 @@ def _resolve_auto_config( model_name = getattr(tokenizer, "name_or_path", "") renderer_name = MODEL_RENDERER_MAP.get(model_name) - preserve_carry = {} + preserve_carry: dict[str, Any] = {"multimodal_output": auto.multimodal_output} if auto.thinking_retention is not None: preserve_carry["thinking_retention"] = auto.thinking_retention @@ -1526,7 +1524,7 @@ def _resolve_auto_config( "reasoning_parser=...) to enable structured output parsing.", model_name or "", ) - return DefaultRendererConfig() + return DefaultRendererConfig(multimodal_output=auto.multimodal_output) # --------------------------------------------------------------------------- diff --git a/renderers/client.py b/renderers/client.py index 0c63c0e..c8015da 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -15,6 +15,7 @@ import json import logging from collections.abc import Mapping +from dataclasses import replace from typing import Any, cast import httpx @@ -170,11 +171,12 @@ async def generate( attribution (``is_content`` / ``sampled_mask`` / ``message_indices`` / ``message_roles``) into the result without re-rendering. - For multimodal renderers (e.g. ``Qwen3VLRenderer``), the call goes + For multimodal renderers, the call goes through ``renderer.render(...)`` to recover the ``multi_modal_data`` sidecar, then serializes it to vLLM's ``features`` schema (mm_hashes, - mm_placeholders, kwargs_data) before POSTing. The serializer imports - ``vllm.*`` lazily so text-only consumers never pay for the import. + mm_placeholders, kwargs_data) before POSTing. Raw image ``kwargs_data`` + slots always carry a descriptor ref — every image (current and prior + turns) is sent as a pointer that the inference endpoint materializes. ``max_prompt_len`` controls the pre-flight overflow check. When the rendered prompt is strictly longer than the cap, the request is never @@ -211,12 +213,14 @@ async def generate( def _prepare(): if prompt_ids is not None: # Caller-supplied prompt; if they also gave us pre-computed - # attribution (e.g. the bridge path in verifiers), thread it - # through unchanged. + # attribution (e.g. the bridge path in verifiers), thread it through. + prompt_mm_data = multi_modal_data + if prompt_mm_data is None and prompt_attribution is not None: + prompt_mm_data = prompt_attribution.multi_modal_data return ( list(prompt_ids), renderer.get_stop_token_ids(), - multi_modal_data, + prompt_mm_data, prompt_attribution, ) rendered = renderer.render(messages, tools=tools, add_generation_prompt=True) @@ -248,11 +252,19 @@ def _prepare(): "token_ids": prompt_ids, "sampling_params": sp, } - features = ( - _build_mm_features(renderer, mm_data) - if mm_data and not mm_data.is_empty() - else None - ) + + def _features_and_descriptor_mm() -> tuple[ + dict[str, Any] | None, MultiModalData | None + ]: + if mm_data is None or mm_data.is_empty(): + return None, mm_data + # Every image carries its raw ref (the pointer); persisted mm_data keeps it, + # so prior-turn images carry their ref forward unchanged. + return _build_vllm_mm_features(mm_data), mm_data + + features, out_mm_data = await _maybe_offload(renderer, _features_and_descriptor_mm) + if prompt_attr is not None and out_mm_data is not None: + prompt_attr = replace(prompt_attr, multi_modal_data=out_mm_data) if features is not None: body["features"] = features if cache_salt is not None: @@ -322,7 +334,7 @@ def _prepare(): # The mm sidecar consumed on the request side, surfaced back so # callers can persist it on the trajectory step for downstream # multi-turn bridging and training-sample construction. - "multi_modal_data": mm_data, + "multi_modal_data": out_mm_data, # The renderer's per-token attribution for the prompt — either # the RenderedTokens computed here via renderer.render(...) or # the one threaded in by the caller alongside prompt_ids (the @@ -334,113 +346,75 @@ def _prepare(): } -def _build_mm_features( - renderer: Renderer | RendererPool, - mm_data: MultiModalData, -) -> dict[str, Any] | None: +def _build_vllm_mm_features(mm_data: MultiModalData) -> dict[str, Any]: """Serialize ``MultiModalData`` to vLLM's ``/inference/v1/generate`` features payload. - vLLM's ``MultiModalFeatures`` carries three things: hashes (for cache - lookup), placeholder positions (so the engine knows where in the - token stream each item lives), and per-item ``MultiModalKwargsItem`` - base64-encoded. The encoding requires vLLM-side type info — what - fields belong to each modality, how they batch — and is currently - model-family specific. For now we dispatch on the renderer class; - extend the dispatch table as more multimodal renderers land. - - NOTE — future engine pluggability: this encoder is vLLM 0.20-specific - (uses ``vllm.multimodal.inputs.MultiModalKwargsItems``, - ``vllm.entrypoints.serve.disagg.mm_serde.encode_mm_kwargs_item``, and - ``_create_qwen2vl_field_factory``). When a second inference engine - arrives (SGLang, MAX, ...) the renderer client should be parameterized - on engine: either (a) move the encoder onto the renderer as - ``encode_mm_for_(mm_data)`` methods, or (b) accept an - ``Encoder`` strategy at the ``generate(...)`` call site. The data type - (``MultiModalData``) is already framework-agnostic and does not need - to change. Don't pre-build the abstraction with one engine in tree. + vLLM's ``MultiModalFeatures`` carries three things: hashes, placeholder + positions (so the engine knows where in the token stream each item lives), + and one raw ref per item. Raw multimodal descriptors use the common envelope + emitted by renderers; family-specific geometry stays inside the descriptor + payload and is interpreted downstream by prime-rl/vLLM adapters. """ - from renderers.qwen3_vl import Qwen3VLRenderer - from renderers.qwen35 import Qwen35Renderer - - # Type dispatch only needs the renderer class. Pools expose - # ``renderer_cls`` as a snapshot attribute, so we don't have to check - # out a slot just to read ``type(r)``. - renderer_cls = ( - renderer.renderer_cls if isinstance(renderer, RendererPool) else type(renderer) - ) - - # Qwen3-VL and Qwen3.5 both ship ``pixel_values`` + ``image_grid_thw`` - # via the shared Qwen2-VL field factory. ``spatial_merge_size=2`` is - # the family default and matches every Qwen-VL processor in tree. - if issubclass(renderer_cls, (Qwen3VLRenderer, Qwen35Renderer)): - return _build_qwen_vl_features(mm_data, spatial_merge_size=2) - - raise NotImplementedError( - f"Multimodal serialization not implemented for {renderer_cls.__name__}. " - "Add a dispatch branch in renderers.client._build_mm_features." + from renderers.mm_store import ( + RAW_MM_ITEM_KIND, + raw_mm_ref, ) - -def _build_qwen_vl_features( - mm_data: MultiModalData, *, spatial_merge_size: int -) -> dict[str, Any]: - """vLLM features payload for the Qwen-VL family (Qwen2-VL / Qwen3-VL). - - Stacks per-image processor outputs back into a batched ``BatchFeature``, - runs the Qwen2-VL field factory (shared across the family), wraps as - ``MultiModalKwargsItems``, base64-encodes each item, and assembles a - JSON-serializable dict matching vLLM's ``MultiModalFeatures`` schema. - - Returns ``None`` semantics live one level up — this helper assumes - the caller already verified ``mm_data`` is non-empty. - """ - try: - import torch - from transformers.feature_extraction_utils import BatchFeature - from vllm.entrypoints.serve.disagg.mm_serde import encode_mm_kwargs_item - from vllm.model_executor.models.qwen2_vl import _create_qwen2vl_field_factory - from vllm.multimodal.inputs import MultiModalKwargsItems - except ImportError as exc: - raise RuntimeError( - "Multimodal generate via /inference/v1/generate requires `vllm` " - "and `torch` to encode the features payload. Install vLLM in this " - "environment, or pre-build features upstream." - ) from exc - out: dict[str, Any] = { "mm_hashes": {}, "mm_placeholders": {}, "kwargs_data": {}, } - image_items = mm_data.mm_items.get("image") or [] - if image_items: - # mm_items now ship numpy arrays (the renderer is torch-free); - # convert at this vLLM-glue boundary where torch is already a - # hard dependency. - pixel_values = torch.cat( - [torch.as_tensor(it["pixel_values"]) for it in image_items], dim=0 - ) - image_grid_thw = torch.cat( - [torch.as_tensor(it["image_grid_thw"]) for it in image_items], dim=0 - ) - hf_inputs = BatchFeature( - data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw} - ) - config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs) - kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config) - encoded = [encode_mm_kwargs_item(it) for it in kwargs_items["image"]] - out["kwargs_data"]["image"] = encoded - out["mm_hashes"]["image"] = list(mm_data.mm_hashes.get("image") or []) - out["mm_placeholders"]["image"] = [ - {"offset": p.offset, "length": p.length} - for p in mm_data.mm_placeholders.get("image") or [] - ] - - # If kwargs_data is empty across all modalities, drop the key so vLLM - # falls back to the hash-only (cache-hit) path. Otherwise hand it the - # full payload. - if not any(out["kwargs_data"].values()): - out["kwargs_data"] = None + for source_modality, items in mm_data.mm_items.items(): + if not items: + continue + mm_hashes = list(mm_data.mm_hashes.get(source_modality) or []) + placeholders = list(mm_data.mm_placeholders.get(source_modality) or []) + if len(mm_hashes) != len(items) or len(placeholders) != len(items): + raise ValueError( + "Multimodal sidecar length mismatch: " + f"modality={source_modality} items={len(items)} " + f"hashes={len(mm_hashes)} placeholders={len(placeholders)}" + ) + + for idx, item in enumerate(items): + if item.get("kind") != RAW_MM_ITEM_KIND: + raise NotImplementedError( + "renderers.client.generate() requires raw multimodal " + "descriptor envelopes (multimodal_output='raw'); " + f"got item keys {sorted(item)} for modality {source_modality!r}." + ) + feature_modality = item.get("vllm_modality") or source_modality + if not isinstance(feature_modality, str) or not feature_modality: + raise ValueError("raw multimodal item has invalid vllm_modality") + + raw_image_uri = item.get("raw_image_uri") + family = item.get("family") + fingerprint = item.get("layout_fingerprint") + payload = item.get("payload") + if not isinstance(raw_image_uri, str) or not raw_image_uri: + raise ValueError("raw multimodal item is missing raw_image_uri") + if not isinstance(family, str) or not family: + raise ValueError("raw multimodal item is missing family") + if not isinstance(fingerprint, str) or not fingerprint: + raise ValueError("raw multimodal item is missing layout_fingerprint") + if not isinstance(payload, dict): + raise ValueError("raw multimodal item payload must be a dict") + + out["mm_hashes"].setdefault(feature_modality, []).append(mm_hashes[idx]) + out["mm_placeholders"].setdefault(feature_modality, []).append( + {"offset": placeholders[idx].offset, "length": placeholders[idx].length} + ) + out["kwargs_data"].setdefault(feature_modality, []).append( + raw_mm_ref( + family=family, + fingerprint=fingerprint, + modality=feature_modality, + mm_hash=mm_hashes[idx], + raw_image_uri=raw_image_uri, + payload=payload, + ) + ) return out diff --git a/renderers/configs.py b/renderers/configs.py index 1e6f3f2..5be3362 100644 --- a/renderers/configs.py +++ b/renderers/configs.py @@ -52,6 +52,9 @@ def _reject_thinking_retention_conflict( ThinkingRetention = Literal["tool_cycle", "all"] """User-facing historical thinking/analysis retention override.""" +MultimodalOutput = Literal["raw", "processed"] +"""Renderer multimodal sidecar format.""" + ResolvedThinkingRetention = Literal["template", "tool_cycle", "all"] """Internal bridge policy after template kwargs have been resolved.""" @@ -87,9 +90,15 @@ class BaseRendererConfig(BaseConfig): to the Python chat-template implementation and its explicit template kwargs.""" + multimodal_output: MultimodalOutput = "raw" + """Multimodal sidecar format: + + - ``"raw"`` — emit JSON-safe image refs/descriptors for inference paths. + - ``"processed"`` — emit image-processor payloads for SFT/training paths.""" + # Fields that are renderer-internal — not forwarded to (or mirrored # by) ``apply_chat_template``. Override in subclasses that hold - # non-template config (e.g. ``image_cache_max``, GptOss's + # non-template config (e.g. GptOss's # ``use_system_prompt`` / ``knowledge_cutoff`` / ``model_identity``, # or fields that exist as renderer conventions without a Jinja # analogue like DeepSeek V3 / Kimi K2 ``enable_thinking``). @@ -116,10 +125,9 @@ def template_field_names(cls) -> frozenset[str]: class AutoRendererConfig(BaseRendererConfig): """Resolve the renderer from ``tokenizer.name_or_path`` at construction - time via ``MODEL_RENDERER_MAP``. Carries only the shared - ``thinking_retention`` field when explicitly set; template kwargs require - an explicit renderer choice so template-dependent behaviour stays visible - at the call site.""" + time via ``MODEL_RENDERER_MAP``. Carries the shared base fields into the + concrete renderer config; template kwargs require an explicit renderer + choice so template-dependent behaviour stays visible at the call site.""" name: Literal["auto"] = "auto" @@ -198,12 +206,6 @@ class Qwen35RendererConfig(BaseRendererConfig): running across the entire conversation. Mirrors the chat template's ``add_vision_id`` toggle.""" - image_cache_max: int = 256 - """FIFO bound on the per-renderer image processor cache. Renderer- - internal — not a Jinja chat-template kwarg.""" - - _internal_fields = frozenset({"image_cache_max"}) - class Qwen36RendererConfig(BaseRendererConfig): """Qwen3.6 renderer config. Inherits Qwen3.5's template surface.""" @@ -221,11 +223,6 @@ class Qwen36RendererConfig(BaseRendererConfig): last real user query. Mirrors the Qwen3.6 chat template's native ``preserve_thinking`` kwarg.""" - image_cache_max: int = 256 - """See :class:`Qwen35RendererConfig.image_cache_max`.""" - - _internal_fields = frozenset({"image_cache_max"}) - @model_validator(mode="after") def _check_thinking_retention(self): _reject_thinking_retention_conflict( @@ -245,11 +242,6 @@ class Qwen3VLRendererConfig(BaseRendererConfig): add_vision_id: bool = False """See :class:`Qwen35RendererConfig.add_vision_id`.""" - image_cache_max: int = 256 - """See :class:`Qwen35RendererConfig.image_cache_max`.""" - - _internal_fields = frozenset({"image_cache_max"}) - class GLM5RendererConfig(BaseRendererConfig): """GLM-5 renderer config.""" @@ -398,11 +390,6 @@ class KimiK25RendererConfig(BaseRendererConfig): ``thinking`` (not ``enable_thinking``) to match the upstream chat template's native variable name.""" - image_cache_max: int = 256 - """See :class:`Qwen35RendererConfig.image_cache_max`.""" - - _internal_fields = frozenset({"image_cache_max"}) - class LagunaXS2RendererConfig(BaseRendererConfig): """Laguna XS.2 renderer config.""" @@ -661,6 +648,7 @@ def config_from_name(name: str) -> BaseRendererConfig | None: "LagunaXS2RendererConfig", "Llama3RendererConfig", "MiniMaxM2RendererConfig", + "MultimodalOutput", "Nemotron3RendererConfig", "Nemotron3UltraRendererConfig", "Qwen35RendererConfig", diff --git a/renderers/kimi_k25.py b/renderers/kimi_k25.py index 8235933..4fb576c 100644 --- a/renderers/kimi_k25.py +++ b/renderers/kimi_k25.py @@ -22,7 +22,9 @@ from __future__ import annotations import json +import math import re +from dataclasses import dataclass from typing import Any from transformers.tokenization_utils import PreTrainedTokenizer @@ -43,12 +45,17 @@ trim_to_turn_close, ) from renderers.configs import KimiK25RendererConfig +from renderers.mm_store import image_layout_fingerprint, raw_mm_item from renderers.parsing import _reasoning_end_token_index, parse_kimi_k2_section from renderers.qwen3_vl import ( - _image_hash, + _image_content_hash, + _image_dimensions, + _image_source, _is_image_part, _is_video_part, _load_pil_image, + _pil_image_hash, + _raw_image_uri, ) # --------------------------------------------------------------------------- @@ -57,6 +64,9 @@ _DEFAULT_SYSTEM_PROMPT = "You are Kimi, an AI assistant created by Moonshot AI." +KIMI_K25_FAMILY = "kimi_k25" +KIMI_K25_VLLM_MODALITY = "vision_chunk" + # --------------------------------------------------------------------------- # TypeScript-style tool declaration # --------------------------------------------------------------------------- @@ -402,6 +412,154 @@ def _encode_tools_typescript(tools: list[ToolSpec]) -> str: return "# Tools\n\n## functions\nnamespace functions {\n" + functions_str + "\n}\n" +@dataclass(frozen=True) +class KimiK25ImageLayoutSpec: + patch_size: int = 14 + merge_kernel_size: int = 2 + in_patch_limit: int = 16384 + patch_limit_on_one_side: int = 512 + fixed_output_tokens: int | None = None + image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5) + image_std: tuple[float, float, float] = (0.5, 0.5, 0.5) + + +KIMI_K25_IMAGE_LAYOUT = KimiK25ImageLayoutSpec() + + +@dataclass(frozen=True) +class KimiImageLayoutDescriptor: + mm_hash: str + grid_thws: list[list[int]] + num_media_tokens: int + fingerprint: str + raw_image_uri: str + + +def _ceil_to_factor(value: int, factor: int) -> int: + return max(factor, math.ceil(value / factor) * factor) + + +def _kimi_resize_config( + width: int, height: int, layout: KimiK25ImageLayoutSpec +) -> tuple[int, int, int]: + """Kimi MoonViT/NavIT image resize layout without materializing pixels.""" + if height <= 0 or width <= 0: + raise ValueError(f"image dimensions must be positive, got {height}x{width}") + patch_size = layout.patch_size + patch_limit_pixels = layout.patch_limit_on_one_side * patch_size + s1 = math.sqrt( + layout.in_patch_limit + / (max(1.0, width // patch_size) * max(1.0, height // patch_size)) + ) + s2 = patch_limit_pixels / width + s3 = patch_limit_pixels / height + scale = min(1.0, s1, s2, s3) + resized_w = min(max(1, int(width * scale)), patch_limit_pixels) + resized_h = min(max(1, int(height * scale)), patch_limit_pixels) + + factor = layout.merge_kernel_size * patch_size + padded_w = _ceil_to_factor(resized_w, factor) + padded_h = _ceil_to_factor(resized_h, factor) + if layout.fixed_output_tokens is not None: + num_tokens = layout.fixed_output_tokens + else: + num_tokens = (padded_h // factor) * (padded_w // factor) + return padded_w, padded_h, int(num_tokens) + + +def describe_kimi_image_layout(part: dict[str, Any]) -> KimiImageLayoutDescriptor: + source = _image_source(part) + height, width = _image_dimensions(source) + layout = KIMI_K25_IMAGE_LAYOUT + padded_w, padded_h, num_media_tokens = _kimi_resize_config(width, height, layout) + grid_thws = [[1, padded_h // layout.patch_size, padded_w // layout.patch_size]] + fingerprint = image_layout_fingerprint( + family=KIMI_K25_FAMILY, + patch_size=layout.patch_size, + merge_kernel_size=layout.merge_kernel_size, + in_patch_limit=layout.in_patch_limit, + patch_limit_on_one_side=layout.patch_limit_on_one_side, + fixed_output_tokens=layout.fixed_output_tokens, + image_mean=list(layout.image_mean), + image_std=list(layout.image_std), + ) + return KimiImageLayoutDescriptor( + mm_hash=_image_content_hash(source), + grid_thws=grid_thws, + num_media_tokens=num_media_tokens, + fingerprint=fingerprint, + raw_image_uri=_raw_image_uri(source), + ) + + +def kimi_image_item_for_render(part: dict[str, Any]) -> tuple[int, str, dict[str, Any]]: + desc = describe_kimi_image_layout(part) + item = raw_mm_item( + modality="image", + family=KIMI_K25_FAMILY, + layout_fingerprint=desc.fingerprint, + payload={ + "grid_thws": desc.grid_thws, + "num_media_tokens": desc.num_media_tokens, + }, + raw_image_uri=desc.raw_image_uri, + vllm_modality=KIMI_K25_VLLM_MODALITY, + ) + return 1, desc.mm_hash, item + + +def load_kimi_processor(tokenizer): + try: + from transformers import AutoProcessor + except ImportError as exc: + raise RuntimeError( + "Processed multimodal rendering requires transformers with " + "AutoProcessor support." + ) from exc + + name = getattr(tokenizer, "name_or_path", None) + if not name: + raise RuntimeError( + "KimiK25Renderer needs a processor for multimodal_output='processed'. " + "Inject `renderer._processor` or load the tokenizer with a known " + "name_or_path." + ) + + from renderers.base import TRUSTED_REVISIONS + + kwargs: dict[str, Any] = {"trust_remote_code": True} + revision = TRUSTED_REVISIONS.get(name) + if revision is not None: + kwargs["revision"] = revision + return AutoProcessor.from_pretrained(name, **kwargs) + + +def kimi_processed_image_item_for_render( + part: dict[str, Any], + *, + processor: Any, + image_cache: dict[str, tuple[Any, int]], +) -> tuple[int, str, dict[str, Any]]: + pil = _load_pil_image(part) + image_hash = _pil_image_hash(pil) + cached = image_cache.get(image_hash) + if cached is not None: + out, _num_patches = cached + else: + img_proc = processor.image_processor + media_item = {"type": "image", "image": pil} + out = img_proc.preprocess([media_item], return_tensors="np") + num_patches = int(img_proc.media_tokens_calculator(media_item)) + if len(image_cache) >= 256: + image_cache.pop(next(iter(image_cache))) + image_cache[image_hash] = (out, num_patches) + item = { + "pixel_values": out["pixel_values"], + "grid_thws": out["grid_thws"], + } + return 1, image_hash, item + + # --------------------------------------------------------------------------- # Kimi K2.5 response parsing (mirrors K2 format, same token structure) # --------------------------------------------------------------------------- @@ -594,11 +752,10 @@ def __init__( self, tokenizer: PreTrainedTokenizer, config: KimiK25RendererConfig | None = None, - *, - processor: Any = None, ): self._tokenizer = tokenizer - self._processor = processor + self._processor: Any = None + self._image_cache: dict[str, tuple[Any, int]] = {} self.config = config or KimiK25RendererConfig() self.effective_thinking_retention = resolve_thinking_retention( self.config, @@ -638,13 +795,6 @@ def __init__( # The stop token for generation self._endoftext: int | None = self._try_token_id("<|endoftext|>") - # Per-instance image-processor cache (FIFO-bounded). Same shape as - # ``Qwen3VLRenderer._image_cache`` — keyed by content hash, value is - # ``(processor_out, num_patches)``. ``num_patches`` is informational - # for Kimi (we emit a single placeholder regardless), but kept for - # consistency / debugging. - self._image_cache: dict[str, tuple[Any, int]] = {} - @property def mm_token_type_id_map(self) -> dict[int, int]: """Token-id → modality marker. For Kimi K2.5 only ``<|media_pad|>`` @@ -652,54 +802,6 @@ def mm_token_type_id_map(self) -> dict[int, int]: internally from ``pixel_values``.""" return {self._media_pad: 1} - def _get_processor(self): - if self._processor is not None: - return self._processor - from transformers import AutoProcessor - - name = getattr(self._tokenizer, "name_or_path", None) - if not name: - raise RuntimeError( - "KimiK25Renderer needs a processor to render image content. " - "Pass `processor=AutoProcessor.from_pretrained(name, trust_remote_code=True, " - "revision=)` to the constructor, or load the tokenizer with a " - "known name_or_path so the processor can be auto-loaded." - ) - # Kimi's processor is custom Python in the model repo and requires - # trust_remote_code=True. Callers using ``create_renderer_pool`` go - # through ``load_tokenizer`` which already pins the revision; for - # auto-load here, we delegate to AutoProcessor with the same flag. - self._processor = AutoProcessor.from_pretrained(name, trust_remote_code=True) - return self._processor - - def _process_image(self, part: dict[str, Any]): - """Resolve, process, and characterize a single image part for Kimi K2.5. - - Returns ``(pil, processor_out, num_patches, image_hash)`` where - ``processor_out`` contains ``pixel_values`` and ``grid_thws`` - (Kimi's keys; differ from Qwen-VL's ``image_grid_thw``). Single - ``<|media_pad|>`` per image in the token stream; the patch count - is informational only. - """ - pil = _load_pil_image(part) - h = _image_hash(pil) - cached = self._image_cache.get(h) - if cached is not None: - out, num_patches = cached - return pil, out, num_patches, h - proc = self._get_processor() - img_proc = proc.image_processor - # Kimi's vision processor takes a media-dict shape, not raw PIL. - media_item = {"type": "image", "image": pil} - out = img_proc.preprocess([media_item], return_tensors="np") - # Patch count via the processor's own calculator (matches the - # model's per-patch attention count); kept for debugging. - num_patches = int(img_proc.media_tokens_calculator(media_item)) - if len(self._image_cache) >= self.config.image_cache_max: - self._image_cache.pop(next(iter(self._image_cache))) - self._image_cache[h] = (out, num_patches) - return pil, out, num_patches, h - # ------------------------------------------------------------------ # Token helpers # ------------------------------------------------------------------ @@ -723,6 +825,22 @@ def _encode(self, text: str) -> list[int]: return [] return self._tokenizer.encode(text, add_special_tokens=False) + def _get_processor(self): + if self._processor is None: + self._processor = load_kimi_processor(self._tokenizer) + return self._processor + + def _image_item_for_render( + self, part: dict[str, Any] + ) -> tuple[int, str, dict[str, Any]]: + if self.config.multimodal_output == "processed": + return kimi_processed_image_item_for_render( + part, + processor=self._get_processor(), + image_cache=self._image_cache, + ) + return kimi_image_item_for_render(part) + # ------------------------------------------------------------------ # Core render # ------------------------------------------------------------------ @@ -820,7 +938,7 @@ def emit_image( ``<|media_content|>``, ``<|media_end|>``, the trailing ``\\n``) are template-injected scaffold. """ - _, out, _num_patches, h = self._process_image(part) + _placeholder_len, h, mm_item = self._image_item_for_render(part) emit_special( self._media_begin, msg_idx, is_sampled=is_sampled, is_content=False ) @@ -843,16 +961,7 @@ def emit_image( mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=1) ) - # ``grid_thws`` (Kimi) is the per-image equivalent of Qwen-VL's - # ``image_grid_thw``. Ship under Kimi's native key so the - # orchestrator's generic ``torch.cat``-based packer routes it - # directly into the model's forward kwargs. - mm_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "grid_thws": out["grid_thws"], - } - ) + mm_items.setdefault("image", []).append(mm_item) # ── Tool declaration prefix (comes first) ── # K2.5/K2.6's tokenizer auto-computes ``tools_ts_str`` and threads @@ -1114,7 +1223,7 @@ def emit_image( is_sampled: bool = False, is_content: bool = False, ) -> None: - _, out, _num_patches, h = self._process_image(part) + _placeholder_len, h, mm_item = self._image_item_for_render(part) emit_special(self._media_begin, msg_idx) emit_text("image", msg_idx) emit_special(self._media_content, msg_idx) @@ -1128,12 +1237,7 @@ def emit_image( new_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=1) ) - new_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "grid_thws": out["grid_thws"], - } - ) + new_items.setdefault("image", []).append(mm_item) # Bridge handles user/system/tool only (reject_assistant_in_extension # blocks assistants), so no hist/suffix split needed. diff --git a/renderers/mm_store.py b/renderers/mm_store.py new file mode 100644 index 0000000..ddeb9bf --- /dev/null +++ b/renderers/mm_store.py @@ -0,0 +1,241 @@ +"""Run-scoped image asset helpers for multimodal rendering. + +The default renderer multimodal mode does not ship processed image features. +Images are written once into the run output tree and messages carry ``file://`` +URLs to those files. Renderers then emit lightweight image refs for vLLM only +when the engine needs to process an image. +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import os +import re +import threading +from dataclasses import dataclass +from pathlib import Path + +# Contract: must match prime_rl.utils.run_assets.IMAGE_OFFLOAD_DIR_ENV. +IMAGE_OFFLOAD_DIR_ENV = "VF_RENDERER_IMAGE_OFFLOAD_DIR" + +IMAGE_REF_PREFIX = "mmraw" +RAW_MM_ITEM_KIND = "prime_raw_mm_item" + +_SAFE = { + "multimodal family": re.compile(r"^[A-Za-z0-9_.-]+$"), + "raw multimodal modality": re.compile(r"^[A-Za-z0-9_.-]+$"), + "image layout fingerprint": re.compile(r"^[a-f0-9]{16,64}$"), + "image hash": re.compile(r"^[a-f0-9]{16,128}$"), + "raw multimodal ref payload segment": re.compile(r"^[A-Za-z0-9_-]*$"), +} + +_MEDIA_TYPE_EXT = { + "jpeg": ".jpg", + "jpg": ".jpg", + "png": ".png", + "webp": ".webp", + "gif": ".gif", +} + + +def _ensure_safe(label: str, value: str) -> str: + if not _SAFE[label].fullmatch(value): + raise ValueError(f"Invalid {label}: {value!r}") + return value + + +def run_image_dir() -> Path: + """Resolve the directory for raw image assets for a run.""" + explicit = os.getenv(IMAGE_OFFLOAD_DIR_ENV, "").strip() + if explicit: + return Path(explicit).resolve() + raise RuntimeError( + f"Set {IMAGE_OFFLOAD_DIR_ENV} before resolving raw image assets." + ) + + +def _media_type_ext(media_type: str) -> str: + subtype = media_type.split("/", 1)[-1].split(";", 1)[0].strip().lower() + return _MEDIA_TYPE_EXT.get(subtype, ".img") + + +def offload_image_to_run_assets( + url: object, image_dir: Path | None = None +) -> str | None: + """Decode a base64 data image into the run image assets directory. + + Returns a ``file://`` URL when ``url`` was rewritten and ``None`` for + non-data-image values. Writes are content-addressed and atomic. + """ + if not isinstance(url, str) or not url.startswith("data:image/"): + return None + marker = ";base64," + if marker not in url: + return None + + header, b64 = url.split(marker, 1) + try: + raw = base64.b64decode(b64) + except Exception: + return None + + root = (image_dir or run_image_dir()).resolve() + root.mkdir(parents=True, exist_ok=True) + digest = hashlib.sha256(raw).hexdigest()[:16] + path = root / f"{digest}{_media_type_ext(header[len('data:') :])}" + if not path.exists(): + tmp = path.with_name(f".{path.name}.{os.getpid()}.{threading.get_ident()}.tmp") + tmp.write_bytes(raw) + os.replace(tmp, path) + else: + try: + path.touch() + except OSError: + pass + return path.as_uri() + + +def _json_fingerprint_value(value: object) -> str: + return json.dumps(value, sort_keys=True, separators=(",", ":"), default=str) + + +def _encode_ref_payload(payload: dict[str, object] | None) -> str: + raw = json.dumps(payload or {}, sort_keys=True, separators=(",", ":")).encode( + "utf-8" + ) + return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=") + + +def _decode_ref_payload(encoded: str) -> dict[str, object]: + _ensure_safe("raw multimodal ref payload segment", encoded) + padded = encoded + "=" * (-len(encoded) % 4) + payload = json.loads( + base64.urlsafe_b64decode(padded.encode("ascii")).decode("utf-8") + ) + if not isinstance(payload, dict): + raise ValueError("Raw multimodal ref payload must decode to a dict") + return payload + + +def image_layout_fingerprint(*, family: str, **values: object) -> str: + """Stable adapter-owned fingerprint for raw multimodal layout contracts.""" + _ensure_safe("multimodal family", family) + encoded_values = ":".join( + f"{key}={_json_fingerprint_value(values[key])}" for key in sorted(values) + ) + raw = f"image-layout:{family}:{encoded_values}".encode("utf-8") + return hashlib.sha256(raw).hexdigest()[:32] + + +def raw_mm_item( + *, + modality: str, + family: str, + layout_fingerprint: str, + payload: dict[str, object], + raw_image_uri: str, + vllm_modality: str | None = None, +) -> dict[str, object]: + """Build the JSON-safe raw multimodal descriptor envelope. + + ``payload`` is intentionally adapter-owned. Shared consumers may route by + ``family`` and validate the common envelope, but must not inspect adapter + payload keys. + """ + _ensure_safe("multimodal family", family) + _ensure_safe("raw multimodal modality", modality) + _ensure_safe("image layout fingerprint", layout_fingerprint) + out: dict[str, object] = { + "kind": RAW_MM_ITEM_KIND, + "modality": modality, + "family": family, + "layout_fingerprint": layout_fingerprint, + "payload": payload, + } + if vllm_modality is not None: + out["vllm_modality"] = vllm_modality + out["raw_image_uri"] = raw_image_uri + return out + + +@dataclass(frozen=True) +class RawMMRef: + family: str + fingerprint: str + modality: str + mm_hash: str + payload: dict[str, object] + raw_image_uri: str + + +def raw_mm_ref( + *, + family: str, + fingerprint: str, + modality: str, + mm_hash: str, + raw_image_uri: str, + payload: dict[str, object] | None = None, +) -> str: + """Generic raw multimodal asset ref. + + Adapter-owned details stay in the descriptor payload so refs can serve + future families without baking shape names into the wire id. + """ + _ensure_safe("multimodal family", family) + _ensure_safe("image layout fingerprint", fingerprint) + _ensure_safe("raw multimodal modality", modality) + _ensure_safe("image hash", mm_hash) + + ref_payload: dict[str, object] = { + "family": family, + "fingerprint": fingerprint, + "modality": modality, + "mm_hash": mm_hash, + "payload": payload or {}, + "raw_image_uri": raw_image_uri, + } + + return f"{IMAGE_REF_PREFIX}:{_encode_ref_payload(ref_payload)}" + + +def split_raw_mm_ref(ref: str) -> RawMMRef: + parts = ref.split(":") + if len(parts) != 2 or parts[0] != IMAGE_REF_PREFIX: + raise ValueError(f"Invalid raw multimodal ref shape: {ref!r}") + + payload = _decode_ref_payload(parts[1]) + family = payload.get("family") + fingerprint = payload.get("fingerprint") + modality = payload.get("modality") + mm_hash = payload.get("mm_hash") + raw_image_uri = payload.get("raw_image_uri") + item_payload = payload.get("payload") + + if not isinstance(family, str): + raise ValueError("Raw multimodal ref is missing family") + if not isinstance(fingerprint, str): + raise ValueError("Raw multimodal ref is missing fingerprint") + if not isinstance(modality, str): + raise ValueError("Raw multimodal ref is missing modality") + if not isinstance(mm_hash, str): + raise ValueError("Raw multimodal ref is missing mm_hash") + if not isinstance(raw_image_uri, str): + raise ValueError("Raw multimodal ref is missing raw_image_uri") + if not isinstance(item_payload, dict): + raise ValueError("Raw multimodal ref payload must be a dict") + + return RawMMRef( + family=_ensure_safe("multimodal family", family), + fingerprint=_ensure_safe("image layout fingerprint", fingerprint), + modality=_ensure_safe("raw multimodal modality", modality), + mm_hash=_ensure_safe("image hash", mm_hash), + payload=item_payload, + raw_image_uri=raw_image_uri, + ) + + +def is_raw_mm_ref(ref: object) -> bool: + return isinstance(ref, str) and ref.startswith(f"{IMAGE_REF_PREFIX}:") diff --git a/renderers/qwen35.py b/renderers/qwen35.py index 3ebbbc6..ccce03f 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -7,9 +7,10 @@ processor class ``Qwen3VLProcessor``). When a user/tool message carries an ``ImagePart``, the renderer emits the same ``<|vision_start|>``+N×``<|image_pad|>`` +``<|vision_end|>`` expansion as the HF chat template (``N = -image_grid_thw.prod() // merge_size**2``) and ships processed pixel_values via -``RenderedTokens.multi_modal_data``. Text-only inputs take the original fast -path and remain byte-identical to ``apply_chat_template``. +image_grid_thw.prod() // merge_size**2``) using the renderer's baked image +layout spec. By default, vLLM receives run image refs for images it must +process; ``multimodal_output="processed"`` emits image-processor payloads for +SFT/training callers. """ from __future__ import annotations @@ -36,10 +37,11 @@ from renderers.configs import Qwen35RendererConfig from renderers.parsing import parse_qwen35 from renderers.qwen3_vl import ( - _image_hash, _is_image_part, _is_video_part, - _load_pil_image, + load_qwen_processor, + qwen_image_item_for_render, + qwen_processed_image_item_for_render, ) # --------------------------------------------------------------------------- @@ -117,11 +119,10 @@ def __init__( self, tokenizer: PreTrainedTokenizer, config: Qwen35RendererConfig | None = None, - *, - processor: Any = None, ): self._tokenizer = tokenizer - self._processor = processor + self._processor: Any = None + self._image_cache: dict[str, tuple[Any, int]] = {} cfg = config or type(self)._config_cls() # ``enable_thinking=None`` defers to the model's known default (see # ``_ENABLE_THINKING_DEFAULTS``). Materialise here so downstream reads @@ -158,11 +159,6 @@ def __init__( self._image_pad = self._token_id("<|image_pad|>") self._video_pad = self._token_id("<|video_pad|>") - # Per-instance image-processor cache; see Qwen3VLRenderer for the - # rationale (FIFO-bounded; same image seen across rollouts / - # bridge re-renders). - self._image_cache: dict[str, tuple[Any, int]] = {} - @property def mm_token_type_id_map(self) -> dict[int, int]: """Token-id → modality marker (1 = image, 2 = video) used by the @@ -171,46 +167,6 @@ def mm_token_type_id_map(self) -> dict[int, int]: """ return {self._image_pad: 1, self._video_pad: 2} - def _get_processor(self): - if self._processor is not None: - return self._processor - from transformers import AutoProcessor - - name = getattr(self._tokenizer, "name_or_path", None) - if not name: - raise RuntimeError( - "Qwen35Renderer needs a processor to render image / video parts. " - "Pass `processor=AutoProcessor.from_pretrained(...)` to the " - "constructor, or load the tokenizer with a known name_or_path " - "so the processor can be auto-loaded." - ) - self._processor = AutoProcessor.from_pretrained(name) - return self._processor - - def _process_image(self, part: dict[str, Any]): - """Resolve, process, and characterize a single image part. - - Returns ``(pil, processor_out, num_image_tokens, image_hash)``. - Mirrors ``Qwen3VLRenderer._process_image``: hashes the loaded PIL, - consults ``self._image_cache``, runs the HF image processor on - miss, FIFO-evicts on overflow. - """ - pil = _load_pil_image(part) - h = _image_hash(pil) - cached = self._image_cache.get(h) - if cached is not None: - out, num_image_tokens = cached - return pil, out, num_image_tokens, h - proc = self._get_processor() - out = proc.image_processor(images=[pil], return_tensors="np") - grid_thw = out["image_grid_thw"][0] - merge_size = proc.image_processor.merge_size - num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) - if len(self._image_cache) >= self.config.image_cache_max: - self._image_cache.pop(next(iter(self._image_cache))) - self._image_cache[h] = (out, num_image_tokens) - return pil, out, num_image_tokens, h - @staticmethod def _content_has_media(content: Any) -> bool: """True when ``content`` is a structured list containing image / video parts.""" @@ -233,6 +189,22 @@ def _encode(self, text: str) -> list[int]: return [] return self._tokenizer.encode(text, add_special_tokens=False) + def _get_processor(self): + if self._processor is None: + self._processor = load_qwen_processor(self._tokenizer, type(self).__name__) + return self._processor + + def _image_item_for_render( + self, part: dict[str, Any] + ) -> tuple[int, str, dict[str, Any]]: + if self.config.multimodal_output == "processed": + return qwen_processed_image_item_for_render( + part, + processor=self._get_processor(), + image_cache=self._image_cache, + ) + return qwen_image_item_for_render(part) + # ------------------------------------------------------------------ # Content rendering (mirrors the render_content Jinja macro) # ------------------------------------------------------------------ @@ -378,7 +350,7 @@ def emit_image(part: dict[str, Any], msg_idx: int) -> None: # image data, so they ARE body content (is_content=True); # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` # specials are template scaffold. - _, out, n, h = self._process_image(part) + n, h, mm_item = self._image_item_for_render(part) vision_counts["image"] += 1 if self.config.add_vision_id: emit_text( @@ -400,12 +372,7 @@ def emit_image(part: dict[str, Any], msg_idx: int) -> None: mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - mm_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + mm_items.setdefault("image", []).append(mm_item) def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None: """Emit a user message whose content list contains image parts. @@ -729,7 +696,7 @@ def emit_text_segments( content_mask.append(is_content) def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None: - _, out, n, h = self._process_image(part) + n, h, mm_item = self._image_item_for_render(part) vision_counts["image"] += 1 if self.config.add_vision_id: emit_text(f"Picture {vision_counts['image']}: ", msg_idx) @@ -742,12 +709,7 @@ def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None: new_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - new_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + new_items.setdefault("image", []).append(mm_item) def emit_user_with_media(content_list: list[Any], msg_idx: int) -> None: emit_special(self._im_start, msg_idx) diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index 4823e78..2e424ff 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -1,4 +1,4 @@ -"""Qwen3-VL renderer with multimodal (image + video) support. +"""Qwen3-VL renderer with multimodal image support. Produces a token stream that matches ``Qwen3VLProcessor.apply_chat_template`` byte-for-byte for text-only inputs and emits the same @@ -6,14 +6,10 @@ for image inputs as the HF processor (``N = image_grid_thw.prod() // merge_size**2``). -Image data is shipped to the inference engine via -``RenderedTokens.multi_modal_data``: ``mm_placeholders`` records the -``(offset, length)`` span of each image's placeholder tokens in the -prompt, ``mm_items`` carries the per-image processor output -(``pixel_values``, ``image_grid_thw``), and ``mm_hashes`` carries a -stable identifier for cache lookup. The wire-format conversion to -vLLM's ``/inference/v1/generate`` ``features`` field lives in -``renderers.client``. +By default, image data is shipped to the inference engine via run image refs, +not processed image-processor payloads. ``multimodal_output="processed"`` +instead emits processor payloads for SFT/training callers that need +``pixel_values`` directly. BPE boundary discipline: text runs that the chat template emits contiguously (e.g. ``"user\\n" + content_text``) must be encoded as a @@ -22,6 +18,9 @@ tokens (``<|im_start|>``, ``<|im_end|>``, ````, ``<|vision_start|>``…), which act as atomic boundaries the template also can't merge across. + +Video-shaped content parts are detected and rejected explicitly; video +materialization is not implemented yet. """ from __future__ import annotations @@ -30,8 +29,11 @@ import hashlib import io import json +import math +from dataclasses import dataclass +from pathlib import Path from typing import Any -from urllib.parse import urlparse +from urllib.parse import unquote, urlparse from transformers.tokenization_utils import PreTrainedTokenizer @@ -50,6 +52,10 @@ trim_to_turn_close, ) from renderers.configs import Qwen3VLRendererConfig +from renderers.mm_store import ( + image_layout_fingerprint, + raw_mm_item, +) from renderers.parsing import parse_qwen3 _TOOLS_HEADER = ( @@ -96,32 +102,97 @@ def _is_video_part(item: Any) -> bool: return bool(item.get("video")) or bool(item.get("video_url")) -def _load_pil_image(item: dict[str, Any]): - """Resolve an ImagePart to a PIL Image. +@dataclass(frozen=True) +class QwenVLImageLayoutSpec: + patch_size: int = 16 + temporal_patch_size: int = 2 + merge_size: int = 2 + min_pixels: int = 65536 + max_pixels: int = 16777216 + + +QWEN_VL_IMAGE_LAYOUT = QwenVLImageLayoutSpec() +_PROCESSED_IMAGE_CACHE_MAX = 256 + + +@dataclass(frozen=True) +class QwenImageLayoutDescriptor: + mm_hash: str + image_grid_thw: list[list[int]] + num_image_tokens: int + fingerprint: str + raw_image_uri: str + + +def _smart_resize( + height: int, + width: int, + *, + factor: int, + min_pixels: int, + max_pixels: int, +) -> tuple[int, int]: + """Qwen image resize math without materializing resized pixels.""" + if height <= 0 or width <= 0: + raise ValueError(f"image dimensions must be positive, got {height}x{width}") + if max(height, width) / min(height, width) > 200: + raise ValueError( + "absolute aspect ratio must be smaller than 200, got " + f"{max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +def _image_source(item: dict[str, Any]) -> Any: + if "image" in item: + return item["image"] + if "image_url" in item: + image_url = item.get("image_url") + return image_url.get("url") if isinstance(image_url, dict) else image_url + return item.get("url") or item.get("path") + + +def _file_path_from_source(source: Any) -> Path | None: + if not isinstance(source, str): + return None + parsed = urlparse(source) + if parsed.scheme == "file": + return Path(unquote(parsed.path)).resolve() + if parsed.scheme == "": + return Path(source).resolve() + return None + + +def _offloaded_image_path(source: Any) -> Path: + path = _file_path_from_source(source) + if path is None: + raise ValueError( + "v1 multimodal image rendering requires offloaded file:// image assets" + ) + return path - Accepts pre-loaded PIL Images, raw bytes, filesystem paths, - ``file://``/``http(s)://`` URLs, and ``data:image/...;base64,...`` URIs. - """ + +def _load_pil_image(item: dict[str, Any]): + """Resolve an ImagePart to a PIL Image for processed multimodal output.""" try: from PIL import Image except ImportError as exc: raise RuntimeError( - "Pillow is required for multimodal rendering. Install with " - "`pip install Pillow` (or `pip install renderers[multimodal]`)." + "Processed multimodal rendering requires Pillow. Install " + "`renderers[vision]` or provide Pillow in the caller environment." ) from exc - raw: Any - if "image" in item: - raw = item["image"] - elif "image_url" in item: - # OpenAI canonical shape is ``image_url: {"url": "..."}`` — but - # some VLM processors (Kimi K2.5 / K2.6) hand a raw PIL / str - # directly under ``image_url``. Accept both. - iu = item.get("image_url") - raw = iu.get("url") if isinstance(iu, dict) else iu - else: - raw = item.get("url") or item.get("path") - + raw = _image_source(item) if isinstance(raw, Image.Image): return raw.convert("RGB") if raw.mode != "RGB" else raw @@ -135,7 +206,6 @@ def _load_pil_image(item: dict[str, Any]): ) if raw.startswith("data:"): - # data:image/png;base64,XXXX _, _, payload = raw.partition(",") return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB") @@ -143,28 +213,135 @@ def _load_pil_image(item: dict[str, Any]): if parsed.scheme in ("http", "https"): import urllib.request - with urllib.request.urlopen(raw) as resp: # noqa: S310 — user-supplied URL + with urllib.request.urlopen(raw) as resp: # noqa: S310 return Image.open(io.BytesIO(resp.read())).convert("RGB") - if parsed.scheme == "file" or parsed.scheme == "": - path = parsed.path if parsed.scheme == "file" else raw + if parsed.scheme in ("file", ""): + path = unquote(parsed.path) if parsed.scheme == "file" else raw return Image.open(path).convert("RGB") raise ValueError(f"Unsupported image URL scheme: {parsed.scheme!r} in {raw!r}") -def _image_hash(pil_image) -> str: - """Stable per-image identifier for cache lookup. +def _image_dimensions(source: Any) -> tuple[int, int]: + try: + from PIL import Image + except ImportError as exc: + raise RuntimeError( + "Pillow is required to read image dimensions for multimodal rendering." + ) from exc - Uses the resolved RGB bytes so two ``ImagePart``\\s pointing at the - same logical image (path, in-memory, data URI) hash identically. - """ + with Image.open(_offloaded_image_path(source)) as image: + return image.height, image.width + + +def _image_content_hash(source: Any) -> str: + return hashlib.sha256(_offloaded_image_path(source).read_bytes()).hexdigest()[:32] + + +def _pil_image_hash(pil_image) -> str: h = hashlib.sha256() h.update(pil_image.tobytes()) h.update(f"{pil_image.size}".encode()) return h.hexdigest()[:32] +def _raw_image_uri(source: Any) -> str: + return _offloaded_image_path(source).as_uri() + + +def describe_qwen_image_layout(part: dict[str, Any]) -> QwenImageLayoutDescriptor: + """Return Qwen image layout metadata without invoking an image processor.""" + source = _image_source(part) + height, width = _image_dimensions(source) + layout = QWEN_VL_IMAGE_LAYOUT + resized_h, resized_w = _smart_resize( + height, + width, + factor=layout.patch_size * layout.merge_size, + min_pixels=layout.min_pixels, + max_pixels=layout.max_pixels, + ) + grid_t = 1 + grid_h = resized_h // layout.patch_size + grid_w = resized_w // layout.patch_size + num_image_tokens = ( + grid_t * grid_h * grid_w // (layout.merge_size * layout.merge_size) + ) + fingerprint = image_layout_fingerprint( + family="qwen_vl", + patch_size=layout.patch_size, + merge_size=layout.merge_size, + temporal_patch_size=layout.temporal_patch_size, + min_pixels=layout.min_pixels, + max_pixels=layout.max_pixels, + ) + return QwenImageLayoutDescriptor( + mm_hash=_image_content_hash(source), + image_grid_thw=[[grid_t, grid_h, grid_w]], + num_image_tokens=num_image_tokens, + fingerprint=fingerprint, + raw_image_uri=_raw_image_uri(source), + ) + + +def qwen_image_item_for_render(part: dict[str, Any]) -> tuple[int, str, dict[str, Any]]: + desc = describe_qwen_image_layout(part) + item = raw_mm_item( + modality="image", + family="qwen_vl", + layout_fingerprint=desc.fingerprint, + payload={"image_grid_thw": desc.image_grid_thw}, + raw_image_uri=desc.raw_image_uri, + ) + return desc.num_image_tokens, desc.mm_hash, item + + +def load_qwen_processor(tokenizer, renderer_name: str): + try: + from transformers import AutoProcessor + except ImportError as exc: + raise RuntimeError( + "Processed multimodal rendering requires transformers with " + "AutoProcessor support." + ) from exc + + name = getattr(tokenizer, "name_or_path", None) + if not name: + raise RuntimeError( + f"{renderer_name} needs a processor for multimodal_output='processed'. " + "Inject `renderer._processor` or load the tokenizer with a known " + "name_or_path." + ) + return AutoProcessor.from_pretrained(name) + + +def qwen_processed_image_item_for_render( + part: dict[str, Any], + *, + processor: Any, + image_cache: dict[str, tuple[Any, int]], +) -> tuple[int, str, dict[str, Any]]: + pil = _load_pil_image(part) + image_hash = _pil_image_hash(pil) + cached = image_cache.get(image_hash) + if cached is not None: + out, num_image_tokens = cached + else: + out = processor.image_processor(images=[pil], return_tensors="np") + grid_thw = out["image_grid_thw"][0] + merge_size = processor.image_processor.merge_size + num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) + if len(image_cache) >= _PROCESSED_IMAGE_CACHE_MAX: + image_cache.pop(next(iter(image_cache))) + image_cache[image_hash] = (out, num_image_tokens) + item = { + "pixel_values": out["pixel_values"], + "image_grid_thw": out["image_grid_thw"], + } + return num_image_tokens, image_hash, item + + class _Emitter: """Token-stream builder with BPE-safe text buffering. @@ -298,11 +475,6 @@ class Qwen3VLRenderer: config: Typed renderer config (see :class:`renderers.Qwen3VLRendererConfig`). Defaults to a blank config with template defaults. - processor: Optional ``Qwen3VLProcessor``. Required when rendering - messages that contain image / video parts. If not supplied, - the renderer lazy-loads it via ``AutoProcessor.from_pretrained`` - keyed off ``tokenizer.name_or_path`` the first time a - multimodal part is seen. Qwen3-VL has no historical reasoning channel in this renderer. The default bridge policy therefore resolves to ``"all"``; explicit @@ -313,11 +485,10 @@ def __init__( self, tokenizer: PreTrainedTokenizer, config: Qwen3VLRendererConfig | None = None, - *, - processor: Any = None, ): self._tokenizer = tokenizer - self._processor = processor + self._processor: Any = None + self._image_cache: dict[str, tuple[Any, int]] = {} self.config = config or Qwen3VLRendererConfig() self.effective_thinking_retention = resolve_thinking_retention( self.config, @@ -337,16 +508,6 @@ def __init__( self._image_pad = self._token_id("<|image_pad|>") self._video_pad = self._token_id("<|video_pad|>") - # Per-instance image-processor cache. The HF image processor is the - # most expensive step on the renderer hot path (~tens of ms per - # image for typical grid_thw). The same image gets re-seen across - # ``rollouts_per_example`` rollouts of one example and (for - # multi-turn) across turn boundaries when the bridge re-renders - # rather than extends. Cache keyed by content hash — values are - # tuples of ``(processor_out, num_image_tokens)`` — bounded to - # avoid unbounded growth on long-lived pools. - self._image_cache: dict[str, tuple[Any, int]] = {} - def _token_id(self, token: str) -> int: tid = self._tokenizer.convert_tokens_to_ids(token) assert isinstance(tid, int) and tid != self._tokenizer.unk_token_id, ( @@ -373,20 +534,20 @@ def _encode(self, text: str) -> list[int]: return self._tokenizer.encode(text, add_special_tokens=False) def _get_processor(self): - if self._processor is not None: - return self._processor - from transformers import AutoProcessor + if self._processor is None: + self._processor = load_qwen_processor(self._tokenizer, type(self).__name__) + return self._processor - name = getattr(self._tokenizer, "name_or_path", None) - if not name: - raise RuntimeError( - "Qwen3VLRenderer needs a processor to render image / video parts. " - "Pass `processor=AutoProcessor.from_pretrained(...)` to the " - "constructor, or load the tokenizer with a known name_or_path " - "so the processor can be auto-loaded." + def _image_item_for_render( + self, part: dict[str, Any] + ) -> tuple[int, str, dict[str, Any]]: + if self.config.multimodal_output == "processed": + return qwen_processed_image_item_for_render( + part, + processor=self._get_processor(), + image_cache=self._image_cache, ) - self._processor = AutoProcessor.from_pretrained(name) - return self._processor + return qwen_image_item_for_render(part) @staticmethod def _render_text_content(content: Any) -> str: @@ -426,31 +587,6 @@ def _is_user_query_message(msg: Message) -> bool: and content.endswith("") ) - def _process_image(self, part: dict[str, Any]): - """Resolve, process, and characterize a single image part. - - Returns ``(pil, processor_out, num_image_tokens, image_hash)``. - Hashes the loaded PIL first and consults ``self._image_cache``; - on hit the HF image-processor call is skipped entirely. - """ - pil = _load_pil_image(part) - h = _image_hash(pil) - cached = self._image_cache.get(h) - if cached is not None: - out, num_image_tokens = cached - return pil, out, num_image_tokens, h - proc = self._get_processor() - out = proc.image_processor(images=[pil], return_tensors="np") - grid_thw = out["image_grid_thw"][0] - merge_size = proc.image_processor.merge_size - num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) - if len(self._image_cache) >= self.config.image_cache_max: - # FIFO eviction — Python dicts preserve insertion order, so - # ``next(iter(...))`` is the oldest key. - self._image_cache.pop(next(iter(self._image_cache))) - self._image_cache[h] = (out, num_image_tokens) - return pil, out, num_image_tokens, h - def render( self, messages: list[Message], @@ -480,7 +616,7 @@ def emit_image(part: dict[str, Any]) -> None: # image data, so they ARE body content (is_content=True); # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` # markers are renderer-emitted scaffold. - _, out, n, h = self._process_image(part) + n, h, mm_item = self._image_item_for_render(part) vision_counts["image"] += 1 if self.config.add_vision_id: em.text( @@ -497,12 +633,7 @@ def emit_image(part: dict[str, Any]) -> None: mm_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - mm_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + mm_items.setdefault("image", []).append(mm_item) def render_media_content(content: Any) -> None: """Emit a user/tool content list with media handled inline. @@ -752,7 +883,7 @@ def bridge_to_next_turn( vision_counts = {"image": prev_image_count, "video": prev_video_count} def emit_image(part: dict[str, Any]) -> None: - _, out, n, h = self._process_image(part) + n, h, mm_item = self._image_item_for_render(part) vision_counts["image"] += 1 if self.config.add_vision_id: em.text( @@ -769,12 +900,7 @@ def emit_image(part: dict[str, Any]) -> None: new_placeholders.setdefault("image", []).append( PlaceholderRange(offset=offset, length=n) ) - new_items.setdefault("image", []).append( - { - "pixel_values": out["pixel_values"], - "image_grid_thw": out["image_grid_thw"], - } - ) + new_items.setdefault("image", []).append(mm_item) def render_media_content(content: Any) -> None: if isinstance(content, str): diff --git a/tests/test_client.py b/tests/test_client.py index 1cc1000..ac60a27 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,6 @@ import asyncio import base64 +import hashlib import json import httpx @@ -13,13 +14,15 @@ ) from renderers.client import generate +_OPENAI_TOOL = {"type": "function", "function": {"name": "echo"}} + class _FakeRenderer: supports_tools = True def render(self, messages, *, tools=None, add_generation_prompt=False): assert messages == [{"role": "user", "content": "hi"}] - assert tools == [{"type": "function", "function": {"name": "echo"}}] + assert tools == [_OPENAI_TOOL] assert add_generation_prompt is True # Populate the full attribution surface so the test can verify # ``generate`` threads it through to the result dict unchanged. @@ -101,6 +104,21 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): ) +def test_offload_image_to_run_assets_writes_content_addressed_file(tmp_path): + from renderers.mm_store import offload_image_to_run_assets + + raw = b"png-ish bytes" + url = "data:image/png;base64," + base64.b64encode(raw).decode("ascii") + + file_url = offload_image_to_run_assets(url, image_dir=tmp_path) + + assert file_url is not None + assert file_url.startswith("file://") + path = tmp_path / file_url.rsplit("/", 1)[-1] + assert path.name == f"{hashlib.sha256(raw).hexdigest()[:16]}.png" + assert path.read_bytes() == raw + + def test_generate_builds_request_body_and_parses_response(): client = _FakeClient() renderer = _FakeRenderer() @@ -111,7 +129,7 @@ def test_generate_builds_request_body_and_parses_response(): renderer=renderer, messages=[{"role": "user", "content": "hi"}], model="test-model", - tools=[{"type": "function", "function": {"name": "echo"}}], + tools=[_OPENAI_TOOL], sampling_params={"temperature": 0.3, "max_tokens": 7, "min_tokens": 2}, cache_salt="ckpt-42", ) @@ -119,9 +137,7 @@ def test_generate_builds_request_body_and_parses_response(): # The client must plumb `tools` through to parse_response so XML-style # parsers can preserve declared-string args verbatim. - assert renderer._last_parse_tools == [ - {"type": "function", "function": {"name": "echo"}} - ] + assert renderer._last_parse_tools == [_OPENAI_TOOL] assert len(client.calls) == 1 # /inference/v1/generate is mounted at the server root, so we post to @@ -204,7 +220,7 @@ def test_generate_does_not_promote_finish_reason_for_malformed_tool_calls(): renderer=_MalformedToolRenderer(), messages=[{"role": "user", "content": "hi"}], model="test-model", - tools=[{"type": "function", "function": {"name": "echo"}}], + tools=[_OPENAI_TOOL], ) ) assert result["finish_reason"] == "stop" @@ -281,74 +297,72 @@ def test_generate_threads_prompt_attribution_through_prebuilt_prompt_path(): @pytest.mark.parametrize( - "model_id,renderer_class_path", + "family,payload,expected_modality,vllm_modality", [ - ("Qwen/Qwen3-VL-4B-Instruct", "renderers.qwen3_vl:Qwen3VLRenderer"), - ("Qwen/Qwen3.5-2B", "renderers.qwen35:Qwen35Renderer"), + ("qwen_vl", {"image_grid_thw": [[1, 2, 2]]}, "image", None), + ( + "kimi_k25", + {"grid_thws": [[1, 2, 2]], "num_media_tokens": 1}, + "vision_chunk", + "vision_chunk", + ), ], - ids=["qwen3_vl", "qwen35"], + ids=["default_image_modality", "kimi_vllm_modality"], ) -def test_generate_serializes_multimodal_features_for_qwen_vl_family( - model_id, renderer_class_path +def test_generate_serializes_raw_mm_refs( + tmp_path, family, payload, expected_modality, vllm_modality ): - """When the renderer emits ``MultiModalData``, ``generate`` translates - it into vLLM's ``features`` payload (mm_hashes + mm_placeholders + - base64-encoded kwargs_data) and sticks it in the request body. Covers - every renderer routed through ``_build_qwen_vl_features``.""" - import importlib - - pytest.importorskip("torch") - pytest.importorskip("vllm", reason="vllm needed for features serialization") + """``generate`` serializes raw multimodal envelopes to vLLM refs. - import torch as _torch + The client owns only the generic wire shape: hashes, placeholder spans, + and one raw ref per item. Family-specific payload keys stay opaque here. + """ from renderers.base import ( MultiModalData, PlaceholderRange, - load_tokenizer, + ) + from renderers.mm_store import ( + image_layout_fingerprint, + raw_mm_item, + split_raw_mm_ref, ) - mod_name, cls_name = renderer_class_path.split(":") - renderer_cls = getattr(importlib.import_module(mod_name), cls_name) - - # Build a minimal real renderer so type dispatch in - # _build_mm_features hits the qwen branch. The tokenizer is only - # touched in __init__ to grab special-token ids; render() / etc. - # aren't called here because we pre-supply prompt_ids + mm_data. - tokenizer = load_tokenizer(model_id) - renderer = renderer_cls(tokenizer) + image_dir = tmp_path / "run_rawtest" / "assets" / "images" + image_dir.mkdir(parents=True) + image_path = image_dir / "image.png" + image_path.write_bytes(b"image-bytes") + image_uri = image_path.as_uri() + fingerprint = image_layout_fingerprint(family=family, revision="test") + mm_hash = "a" * 32 - # Two synthetic 1×2×2 images. Field factory expects pixel_values - # shape ``(sum_HW, embed_dim)`` and grid_thw shape ``(N, 3)``; the - # values themselves don't matter for the encoding round-trip. mm_data = MultiModalData( - mm_hashes={"image": ["aaa", "bbb"]}, + mm_hashes={"image": [mm_hash]}, mm_placeholders={ "image": [ PlaceholderRange(offset=5, length=1), - PlaceholderRange(offset=10, length=1), ] }, mm_items={ "image": [ - { - "pixel_values": _torch.zeros(4, 8, dtype=_torch.float32), - "image_grid_thw": _torch.tensor([[1, 2, 2]], dtype=_torch.int64), - }, - { - "pixel_values": _torch.zeros(4, 8, dtype=_torch.float32), - "image_grid_thw": _torch.tensor([[1, 2, 2]], dtype=_torch.int64), - }, + raw_mm_item( + modality="image", + family=family, + layout_fingerprint=fingerprint, + payload=payload, + raw_image_uri=image_uri, + vllm_modality=vllm_modality, + ), ], }, ) client = _FakeClient() - asyncio.run( + result = asyncio.run( generate( client=client, - renderer=renderer, + renderer=_NoRenderRenderer(), messages=[], - model="qwen3-vl", + model="test-model", prompt_ids=list(range(20)), multi_modal_data=mm_data, sampling_params={"max_tokens": 4}, @@ -358,17 +372,28 @@ def test_generate_serializes_multimodal_features_for_qwen_vl_family( body = client.calls[0]["body"] assert "features" in body, "multimodal call should attach features" features = body["features"] - assert features["mm_hashes"] == {"image": ["aaa", "bbb"]} + assert features["mm_hashes"] == {expected_modality: [mm_hash]} assert features["mm_placeholders"] == { - "image": [{"offset": 5, "length": 1}, {"offset": 10, "length": 1}], + expected_modality: [{"offset": 5, "length": 1}], } - assert "kwargs_data" in features - assert features["kwargs_data"] is not None - assert "image" in features["kwargs_data"] - assert len(features["kwargs_data"]["image"]) == 2 - # Items are base64 strings (encode_mm_kwargs_item output). - for item in features["kwargs_data"]["image"]: - assert isinstance(item, str) and len(item) > 0 + refs = features["kwargs_data"][expected_modality] + assert len(refs) == 1 + ref = split_raw_mm_ref(refs[0]) + assert ref.payload == payload + assert ( + ref.family, + ref.fingerprint, + ref.modality, + ref.mm_hash, + ref.raw_image_uri, + ) == ( + family, + fingerprint, + expected_modality, + mm_hash, + image_uri, + ) + assert result["multi_modal_data"] is mm_data # --------------------------------------------------------------------------- diff --git a/tests/test_multimodal.py b/tests/test_multimodal.py index 6b06add..1a20e45 100644 --- a/tests/test_multimodal.py +++ b/tests/test_multimodal.py @@ -138,6 +138,14 @@ def tiny_image(): return Image.new("RGB", (224, 224), color=(128, 192, 255)) +@pytest.fixture +def offloaded_tiny_image(tmp_path, tiny_image): + """Renderer-side image fixture: v1 renderers require offloaded file assets.""" + path = tmp_path / "tiny.png" + tiny_image.save(path) + return path.as_uri() + + # --------------------------------------------------------------------------- # Modality → (renderer-side content part, processor-side image-list builder). # Each modality has its own "make a content part" / "extract source images" @@ -451,7 +459,9 @@ def _supports_tool_message_images(renderer) -> bool: @pytest.mark.parametrize( "mm_model_name,modality", _CASES, ids=[f"{m}|{mo}" for m, mo in _CASES] ) -def test_multimodal_byte_parity_vs_processor(mm_model_name, modality, tiny_image): +def test_multimodal_byte_parity_vs_processor( + mm_model_name, modality, tiny_image, offloaded_tiny_image +): """Token byte-parity with ``processor.apply_chat_template`` + ``processor(...)``. Locks in the property that lets the inference engine see byte-identical @@ -464,8 +474,14 @@ def test_multimodal_byte_parity_vs_processor(mm_model_name, modality, tiny_image kit = _modality_kit(modality, mm_model_name) tokenizer, processor, renderer = _load_processor_and_renderer(mm_model_name) - for case in _build_cases(kit["make_part"], tiny_image): - messages, add_gp = case.values + renderer_cases = _build_cases(kit["make_part"], offloaded_tiny_image) + processor_cases = _build_cases(kit["make_part"], tiny_image) + for renderer_case, processor_case in zip( + renderer_cases, processor_cases, strict=True + ): + messages, add_gp = renderer_case.values + processor_messages, processor_add_gp = processor_case.values + assert add_gp == processor_add_gp # Ours. ours = renderer.render_ids(messages, add_generation_prompt=add_gp) @@ -473,10 +489,10 @@ def test_multimodal_byte_parity_vs_processor(mm_model_name, modality, tiny_image # Theirs: family-specific processor call. Qwen-VL is a two-step # (apply_chat_template + processor(images=, text=)); Kimi K2.5 is # a one-shot processor(messages=). - theirs = kit["processor_input_ids"](processor, messages, add_gp) + theirs = kit["processor_input_ids"](processor, processor_messages, add_gp) assert ours == theirs, ( - f"{mm_model_name} / {modality} / case={case.id}: " + f"{mm_model_name} / {modality} / case={renderer_case.id}: " f"renderer diverges from processor.\n" f" ours[:80]={ours[:80]}\n theirs[:80]={theirs[:80]}\n" f" len(ours)={len(ours)} len(theirs)={len(theirs)}" @@ -486,7 +502,9 @@ def test_multimodal_byte_parity_vs_processor(mm_model_name, modality, tiny_image @pytest.mark.parametrize( "mm_model_name,modality", _CASES, ids=[f"{m}|{mo}" for m, mo in _CASES] ) -def test_multimodal_placeholders_match_pad_runs(mm_model_name, modality, tiny_image): +def test_multimodal_placeholders_match_pad_runs( + mm_model_name, modality, offloaded_tiny_image +): """``mm_placeholders`` exactly cover the runs of the modality's pad token.""" if not _hf_snapshot_cached(mm_model_name): pytest.skip(f"{mm_model_name}: HF snapshot not cached locally") @@ -495,7 +513,7 @@ def test_multimodal_placeholders_match_pad_runs(mm_model_name, modality, tiny_im tokenizer, _, renderer = _load_processor_and_renderer(mm_model_name) pad_id = tokenizer.convert_tokens_to_ids(kit["placeholder_token"]) - for case in _build_cases(kit["make_part"], tiny_image): + for case in _build_cases(kit["make_part"], offloaded_tiny_image): messages, add_gp = case.values rendered = renderer.render(messages, add_generation_prompt=add_gp) @@ -529,7 +547,7 @@ def test_multimodal_placeholders_match_pad_runs(mm_model_name, modality, tiny_im "mm_model_name,modality", _CASES, ids=[f"{m}|{mo}" for m, mo in _CASES] ) def test_multimodal_bridge_extends_and_carries_mm_data( - mm_model_name, modality, tiny_image + mm_model_name, modality, offloaded_tiny_image ): """Bridge-to-next-turn invariants for the multimodal case. @@ -567,7 +585,7 @@ def test_multimodal_bridge_extends_and_carries_mm_data( { "role": "user", "content": [ - kit["make_part"](tiny_image), + kit["make_part"](offloaded_tiny_image), {"type": "text", "text": "Turn one."}, ], } @@ -576,7 +594,7 @@ def test_multimodal_bridge_extends_and_carries_mm_data( { "role": "user", "content": [ - kit["make_part"](tiny_image), + kit["make_part"](offloaded_tiny_image), {"type": "text", "text": "Turn two."}, ], } @@ -665,7 +683,9 @@ def test_modality_registry_models_route_to_renderer(): @pytest.mark.parametrize( "mm_model_name,modality", _CASES, ids=[f"{m}|{mo}" for m, mo in _CASES] ) -def test_tool_response_image_byte_parity(mm_model_name, modality, tiny_image): +def test_tool_response_image_byte_parity( + mm_model_name, modality, tiny_image, offloaded_tiny_image +): """Tool-message image parity vs ``processor.apply_chat_template`` + ``processor(...)``. Browser-agent SFT traces carry post-action screenshots as ``tool`` @@ -688,12 +708,18 @@ def test_tool_response_image_byte_parity(mm_model_name, modality, tiny_image): f"{type(renderer).__name__} does not yet emit images inside tool responses" ) - for case in _build_tool_image_cases(kit["make_part"], tiny_image): - messages, add_gp = case.values + renderer_cases = _build_tool_image_cases(kit["make_part"], offloaded_tiny_image) + processor_cases = _build_tool_image_cases(kit["make_part"], tiny_image) + for renderer_case, processor_case in zip( + renderer_cases, processor_cases, strict=True + ): + messages, add_gp = renderer_case.values + processor_messages, processor_add_gp = processor_case.values + assert add_gp == processor_add_gp ours = renderer.render_ids(messages, add_generation_prompt=add_gp) - theirs = kit["processor_input_ids"](processor, messages, add_gp) + theirs = kit["processor_input_ids"](processor, processor_messages, add_gp) assert ours == theirs, ( - f"{mm_model_name} / tool / case={case.id}: " + f"{mm_model_name} / tool / case={renderer_case.id}: " f"renderer diverges from processor.\n" f" len(ours)={len(ours)} len(theirs)={len(theirs)}\n" f" ours[:60]={ours[:60]}\n theirs[:60]={theirs[:60]}" @@ -750,7 +776,7 @@ def _qwen_vl_processor_input_ids_with_kwargs( ) @pytest.mark.parametrize("add_vision_id", [True, False]) def test_add_vision_id_parity_vs_processor( - mm_model_name, modality, add_vision_id, tiny_image + mm_model_name, modality, add_vision_id, tiny_image, offloaded_tiny_image ): """Parity for ``add_vision_id`` across image-bearing shapes. @@ -775,15 +801,21 @@ def test_add_vision_id_parity_vs_processor( if hasattr(renderer, "_processor") and renderer._processor is None: renderer._processor = processor - for case in _build_cases(kit["make_part"], tiny_image): - messages, add_gp = case.values + renderer_cases = _build_cases(kit["make_part"], offloaded_tiny_image) + processor_cases = _build_cases(kit["make_part"], tiny_image) + for renderer_case, processor_case in zip( + renderer_cases, processor_cases, strict=True + ): + messages, add_gp = renderer_case.values + processor_messages, processor_add_gp = processor_case.values + assert add_gp == processor_add_gp ours = renderer.render_ids(messages, add_generation_prompt=add_gp) theirs = _qwen_vl_processor_input_ids_with_kwargs( - processor, messages, add_gp, add_vision_id=add_vision_id + processor, processor_messages, add_gp, add_vision_id=add_vision_id ) assert ours == theirs, ( f"{mm_model_name} / add_vision_id={add_vision_id} / " - f"case={case.id}: renderer diverges from processor.\n" + f"case={renderer_case.id}: renderer diverges from processor.\n" f" ours[:80]={ours[:80]}\n theirs[:80]={theirs[:80]}\n" f" len(ours)={len(ours)} len(theirs)={len(theirs)}" ) @@ -795,7 +827,7 @@ def test_add_vision_id_parity_vs_processor( ids=[f"{m}|{mo}" for m, mo in _ADD_VISION_ID_CASES], ) def test_bridge_refuses_when_add_vision_id_loses_prior_count( - mm_model_name, modality, tiny_image + mm_model_name, modality, offloaded_tiny_image ): """When ``add_vision_id=True``, the bridge needs the prior turn's image / video count to keep the ``Picture N:`` numbering correct. @@ -832,7 +864,7 @@ def test_bridge_refuses_when_add_vision_id_loses_prior_count( { "role": "user", "content": [ - kit["make_part"](tiny_image), + kit["make_part"](offloaded_tiny_image), {"type": "text", "text": "Turn one."}, ], } @@ -841,7 +873,7 @@ def test_bridge_refuses_when_add_vision_id_loses_prior_count( { "role": "user", "content": [ - kit["make_part"](tiny_image), + kit["make_part"](offloaded_tiny_image), {"type": "text", "text": "Turn two."}, ], } @@ -898,7 +930,7 @@ def test_is_image_part_treats_type_field_as_authoritative(): ``text: None`` added to every image part). The classifier must treat the ``type`` field as authoritative when present — falling back to a key-presence check on ``image_url`` would misclassify the text - part and the renderer would later raise on ``_load_pil_image(None)``. + part and the renderer would later try to resolve ``None`` as an image. """ from renderers.qwen3_vl import _is_image_part, _is_video_part diff --git a/tests/test_multimodal_output_modes.py b/tests/test_multimodal_output_modes.py new file mode 100644 index 0000000..ed898b4 --- /dev/null +++ b/tests/test_multimodal_output_modes.py @@ -0,0 +1,71 @@ +import numpy as np +import pytest + +from renderers.kimi_k25 import kimi_processed_image_item_for_render +from renderers.qwen3_vl import qwen_processed_image_item_for_render + + +def _tiny_image_path(tmp_path): + Image = pytest.importorskip("PIL.Image") + path = tmp_path / "tiny.png" + Image.new("RGB", (16, 16), color=(120, 80, 40)).save(path) + return path + + +def test_qwen_processed_image_item_emits_processor_payload(tmp_path): + class _ImageProcessor: + merge_size = 2 + + def __call__(self, images, return_tensors): + assert len(images) == 1 + assert return_tensors == "np" + return { + "pixel_values": np.ones((4, 3), dtype=np.float32), + "image_grid_thw": np.array([[1, 4, 4]], dtype=np.int64), + } + + class _Processor: + image_processor = _ImageProcessor() + + num_tokens, image_hash, item = qwen_processed_image_item_for_render( + {"type": "image", "image": str(_tiny_image_path(tmp_path))}, + processor=_Processor(), + image_cache={}, + ) + + assert num_tokens == 4 + assert len(image_hash) == 32 + assert set(item) == {"pixel_values", "image_grid_thw"} + assert item["pixel_values"].shape == (4, 3) + assert item["image_grid_thw"].tolist() == [[1, 4, 4]] + + +def test_kimi_processed_image_item_emits_processor_payload(tmp_path): + class _ImageProcessor: + def preprocess(self, media, return_tensors): + assert len(media) == 1 + assert media[0]["type"] == "image" + assert return_tensors == "np" + return { + "pixel_values": np.ones((2, 3), dtype=np.float32), + "grid_thws": np.array([[1, 2, 2]], dtype=np.int64), + } + + def media_tokens_calculator(self, media): + assert media["type"] == "image" + return 2 + + class _Processor: + image_processor = _ImageProcessor() + + placeholder_len, image_hash, item = kimi_processed_image_item_for_render( + {"type": "image", "image": str(_tiny_image_path(tmp_path))}, + processor=_Processor(), + image_cache={}, + ) + + assert placeholder_len == 1 + assert len(image_hash) == 32 + assert set(item) == {"pixel_values", "grid_thws"} + assert item["pixel_values"].shape == (2, 3) + assert item["grid_thws"].tolist() == [[1, 2, 2]] diff --git a/tests/test_renderer_config.py b/tests/test_renderer_config.py index a35f270..4da5d57 100644 --- a/tests/test_renderer_config.py +++ b/tests/test_renderer_config.py @@ -101,13 +101,14 @@ def __init__(self, tokenizer, config): renderer = create_renderer( SimpleNamespace(name_or_path="fake/qwen35"), - AutoRendererConfig(thinking_retention="all"), + AutoRendererConfig(thinking_retention="all", multimodal_output="processed"), ) assert isinstance(renderer.config, Qwen35RendererConfig) assert renderer.config.thinking_retention == "all" + assert renderer.config.multimodal_output == "processed" # Template-level kwargs stay at their per-renderer defaults — auto - # carries only the thinking_retention flag. + # carries only shared base fields. assert renderer.config.add_vision_id is False diff --git a/uv.lock b/uv.lock index 2c6f5e6..05058e1 100644 --- a/uv.lock +++ b/uv.lock @@ -1360,6 +1360,13 @@ dependencies = [ { name = "transformers" }, ] +[package.optional-dependencies] +vision = [ + { name = "pillow" }, + { name = "torch" }, + { name = "torchvision" }, +] + [package.dev-dependencies] dev = [ { name = "pillow" }, @@ -1378,10 +1385,14 @@ requires-dist = [ { name = "numpy" }, { name = "openai", specifier = ">=1.108.1" }, { name = "openai-harmony", specifier = ">=0.0.4" }, + { name = "pillow", marker = "extra == 'vision'", specifier = ">=12.2.0" }, { name = "prime-pydantic-config", specifier = ">=0.3.0.dev83" }, { name = "tiktoken" }, + { name = "torch", marker = "extra == 'vision'", specifier = ">=2.11.0" }, + { name = "torchvision", marker = "extra == 'vision'", specifier = ">=0.26.0" }, { name = "transformers", specifier = ">=4.50.0" }, ] +provides-extras = ["vision"] [package.metadata.requires-dev] dev = [