From 99bedaf4b48d6d9eb7d02c1d7497675ebe758233 Mon Sep 17 00:00:00 2001 From: Eli Gottlieb Date: Tue, 26 May 2026 21:30:10 +0000 Subject: [PATCH] feat(sft): surface multimodal payload through build_training_sample MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit build_training_sample now returns a RenderedTrainingSample struct (token_ids, loss_mask, multi_modal_data, mm_token_type_ids) instead of a bare (token_ids, loss_mask) tuple. The mm fields are populated only when the renderer emitted media — None for text-only renderers and text-only samples — so text token_ids/loss_mask are byte-identical to before. This lets prime-rl's SFT consume one helper for both text and VLM instead of re-implementing the render+mask path inline for VLMs. mm_token_type_ids (0=text, 1=image, 2=video) are built from the rendered placeholder ranges at full token-stream length; the consumer truncates/shifts them in lockstep with token_ids. Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/__init__.py | 2 ++ renderers/base.py | 63 +++++++++++++++++++++++++++++++++++-- tests/test_build_helpers.py | 20 ++++++++++-- tests/test_is_content.py | 3 +- 4 files changed, 82 insertions(+), 6 deletions(-) diff --git a/renderers/__init__.py b/renderers/__init__.py index 9fd385e..7f49495 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -19,6 +19,7 @@ PlaceholderRange, RenderedConversation, RenderedTokens, + RenderedTrainingSample, Renderer, RendererPool, TextPart, @@ -164,6 +165,7 @@ def __dir__() -> list[str]: "Qwen3VLRendererConfig", "RenderedConversation", "RenderedTokens", + "RenderedTrainingSample", "Renderer", "RendererConfig", "RendererPool", diff --git a/renderers/base.py b/renderers/base.py index b1d397f..52491a7 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -1534,6 +1534,42 @@ def _resolve_auto_config( # --------------------------------------------------------------------------- +# Match prime-rl's multimodal token type convention: 0=text, 1=image, 2=video. +_MM_TYPE_ID: dict[str, int] = {"image": 1, "video": 2} + + +@dataclass(frozen=True) +class RenderedTrainingSample: + """Output of :func:`build_training_sample`. + + ``token_ids`` and ``loss_mask`` are always populated. ``multi_modal_data`` + and ``mm_token_type_ids`` are populated only when a multimodal renderer + actually emitted media (both ``None`` for text-only renderers and for + text-only samples through a VLM renderer), so the text path is unchanged. + """ + + token_ids: list[int] + loss_mask: list[bool] + multi_modal_data: "MultiModalData | None" = None + mm_token_type_ids: list[int] | None = None + + +def _build_mm_token_type_ids( + mm_placeholders: dict[str, list[PlaceholderRange]], length: int +) -> list[int]: + """Per-token modality flags (0=text, 1=image, 2=video) from placeholder ranges.""" + ids = [0] * length + for modality, ranges in mm_placeholders.items(): + type_id = _MM_TYPE_ID.get(modality, 0) + if type_id == 0: + continue + for r in ranges: + end = min(r.offset + r.length, length) + for i in range(r.offset, end): + ids[i] = type_id + return ids + + def build_training_sample( renderer: Renderer, messages: list[Message], @@ -1541,8 +1577,12 @@ def build_training_sample( role_to_mask: Callable[[Message], bool] | None = None, tools: list[ToolSpec] | None = None, content_sft_roles: "set[str] | frozenset[str] | None" = None, -) -> tuple[list[int], list[bool]]: - """Build (token_ids, loss_mask) for supervised training. +) -> RenderedTrainingSample: + """Build a :class:`RenderedTrainingSample` for supervised training. + + Returns ``token_ids`` + ``loss_mask`` (always), plus ``multi_modal_data`` + and ``mm_token_type_ids`` when the renderer emitted media (``None`` for + text — the text token_ids/loss_mask are byte-identical to before). Single render() call + message_indices → per-token mask. Replaces build_incremental_token_mask (O(N) renders → O(1)). @@ -1631,7 +1671,24 @@ def build_training_sample( loss_mask.append(True) else: loss_mask.append(role_to_mask(msg)) - return rendered.token_ids, loss_mask + + # Surface the multimodal payload for VLM renderers. ``None`` for text + # renderers and for text-only samples (empty media) so downstream + # ``multi_modal_data is not None`` is a reliable "has media" check. + mm = rendered.multi_modal_data + if mm is not None and mm.is_empty(): + mm = None + mm_token_type_ids = ( + _build_mm_token_type_ids(mm.mm_placeholders, len(rendered.token_ids)) + if mm is not None and mm.mm_placeholders + else None + ) + return RenderedTrainingSample( + token_ids=rendered.token_ids, + loss_mask=loss_mask, + multi_modal_data=mm, + mm_token_type_ids=mm_token_type_ids, + ) def _common_prefix_len(a: list[int], b: list[int]) -> int: diff --git a/tests/test_build_helpers.py b/tests/test_build_helpers.py index 9fecc71..8779b6a 100644 --- a/tests/test_build_helpers.py +++ b/tests/test_build_helpers.py @@ -4,6 +4,17 @@ """ from renderers import build_training_sample, build_trajectory_step +from renderers.base import PlaceholderRange, _build_mm_token_type_ids + + +def test_build_mm_token_type_ids_marks_ranges(): + """Image runs → 1, video runs → 2, everything else → 0; clips at length.""" + placeholders = { + "image": [PlaceholderRange(offset=2, length=3)], # tokens 2,3,4 + "video": [PlaceholderRange(offset=7, length=2)], # tokens 7,8 + } + ids = _build_mm_token_type_ids(placeholders, length=10) + assert ids == [0, 0, 1, 1, 1, 0, 0, 2, 2, 0] def _expected(tokenizer, messages, **kwargs): @@ -29,10 +40,14 @@ def test_build_training_sample_ids_match(model_name, tokenizer, renderer): {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}, ] - ids, mask = build_training_sample( + sample = build_training_sample( renderer, msgs, role_to_mask=lambda m: m["role"] == "assistant" ) + ids = sample.token_ids assert ids == _expected(tokenizer, msgs) + # text-only sample carries no multimodal payload + assert sample.multi_modal_data is None + assert sample.mm_token_type_ids is None def test_build_training_sample_has_trainable_tokens(model_name, tokenizer, renderer): @@ -41,9 +56,10 @@ def test_build_training_sample_has_trainable_tokens(model_name, tokenizer, rende {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}, ] - ids, mask = build_training_sample( + sample = build_training_sample( renderer, msgs, role_to_mask=lambda m: m["role"] == "assistant" ) + ids, mask = sample.token_ids, sample.loss_mask assert sum(mask) > 0 assert len(mask) == len(ids) diff --git a/tests/test_is_content.py b/tests/test_is_content.py index ddac1f5..cd01507 100644 --- a/tests/test_is_content.py +++ b/tests/test_is_content.py @@ -353,12 +353,13 @@ def test_build_training_sample_content_sft_roles_picks_up_tool_body( {"role": "tool", "content": "done", "tool_call_id": "call_z"}, {"role": "assistant", "content": "OK."}, ] - ids, mask = build_training_sample( + sample = build_training_sample( renderer, msgs, role_to_mask=lambda m: m["role"] == "assistant", content_sft_roles={"tool"}, ) + ids, mask = sample.token_ids, sample.loss_mask assert len(mask) == len(ids) # We need at least one trainable tool-body token if the renderer