Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions renderers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
PlaceholderRange,
RenderedConversation,
RenderedTokens,
RenderedTrainingSample,
Renderer,
RendererPool,
TextPart,
Expand Down Expand Up @@ -164,6 +165,7 @@ def __dir__() -> list[str]:
"Qwen3VLRendererConfig",
"RenderedConversation",
"RenderedTokens",
"RenderedTrainingSample",
"Renderer",
"RendererConfig",
"RendererPool",
Expand Down
63 changes: 60 additions & 3 deletions renderers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,15 +1534,55 @@ def _resolve_auto_config(
# ---------------------------------------------------------------------------


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment here saying that this mapping matches the convention in prime-rl, just to be crystal clear?

# Match prime-rl's multimodal token type convention: 0=text, 1=image, 2=video.
# Text has no entry on purpose: 0 is the fill value for every position that is
# not covered by a known placeholder range.
_MM_TYPE_ID: dict[str, int] = {"image": 1, "video": 2}


@dataclass(frozen=True)
class RenderedTrainingSample:
    """Output of :func:`build_training_sample`.

    ``token_ids`` and ``loss_mask`` are always populated. ``multi_modal_data``
    and ``mm_token_type_ids`` are populated only when a multimodal renderer
    actually emitted media (both ``None`` for text-only renderers and for
    text-only samples through a VLM renderer), so the text path is unchanged.
    """

    # Rendered token ids for the whole conversation.
    token_ids: list[int]
    # Per-token training mask; one entry per element of ``token_ids``.
    loss_mask: list[bool]
    # Media payload from the renderer, or ``None`` when no media was emitted.
    multi_modal_data: "MultiModalData | None" = None
    # Per-token modality flags (see ``_MM_TYPE_ID``), or ``None`` for text.
    mm_token_type_ids: list[int] | None = None


def _build_mm_token_type_ids(
    mm_placeholders: dict[str, list[PlaceholderRange]], length: int
) -> list[int]:
    """Per-token modality flags (0=text, 1=image, 2=video) from placeholder ranges.

    Args:
        mm_placeholders: Placeholder ranges keyed by modality name. Each range
            only needs ``offset`` and ``length`` attributes. Modalities without
            an entry in ``_MM_TYPE_ID`` are ignored and their tokens stay ``0``.
        length: Number of tokens in the rendered sequence; the result always
            has exactly this many entries.

    Returns:
        A list of ``length`` ints where every position covered by a known
        modality's range carries that modality's type id and all others are 0.

    Ranges are clipped to ``[0, length)`` on both sides. Clipping the start
    matters: a negative offset would otherwise be treated as a negative list
    index and silently flag tokens at the *end* of the sequence.
    """
    ids = [0] * length
    for modality, ranges in mm_placeholders.items():
        type_id = _MM_TYPE_ID.get(modality, 0)
        if type_id == 0:
            # Unknown modality: nothing to mark, positions stay text (0).
            continue
        for r in ranges:
            start = max(r.offset, 0)
            end = min(r.offset + r.length, length)
            if start < end:
                # Same-length slice assignment keeps ``len(ids) == length``.
                ids[start:end] = [type_id] * (end - start)
    return ids


def build_training_sample(
renderer: Renderer,
messages: list[Message],
*,
role_to_mask: Callable[[Message], bool] | None = None,
tools: list[ToolSpec] | None = None,
content_sft_roles: "set[str] | frozenset[str] | None" = None,
) -> tuple[list[int], list[bool]]:
"""Build (token_ids, loss_mask) for supervised training.
) -> RenderedTrainingSample:
"""Build a :class:`RenderedTrainingSample` for supervised training.

Returns ``token_ids`` + ``loss_mask`` (always), plus ``multi_modal_data``
and ``mm_token_type_ids`` when the renderer emitted media (``None`` for
text — the text token_ids/loss_mask are byte-identical to before).

Single render() call + message_indices → per-token mask.
Replaces build_incremental_token_mask (O(N) renders → O(1)).
Expand Down Expand Up @@ -1631,7 +1671,24 @@ def build_training_sample(
loss_mask.append(True)
else:
loss_mask.append(role_to_mask(msg))
return rendered.token_ids, loss_mask

# Surface the multimodal payload for VLM renderers. ``None`` for text
# renderers and for text-only samples (empty media) so downstream
# ``multi_modal_data is not None`` is a reliable "has media" check.
mm = rendered.multi_modal_data
if mm is not None and mm.is_empty():
mm = None
mm_token_type_ids = (
_build_mm_token_type_ids(mm.mm_placeholders, len(rendered.token_ids))
if mm is not None and mm.mm_placeholders
else None
)
return RenderedTrainingSample(
token_ids=rendered.token_ids,
loss_mask=loss_mask,
multi_modal_data=mm,
mm_token_type_ids=mm_token_type_ids,
)


def _common_prefix_len(a: list[int], b: list[int]) -> int:
Expand Down
20 changes: 18 additions & 2 deletions tests/test_build_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
"""

from renderers import build_training_sample, build_trajectory_step
from renderers.base import PlaceholderRange, _build_mm_token_type_ids


def test_build_mm_token_type_ids_marks_ranges():

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need these three tests.

"""Image runs → 1, video runs → 2, everything else → 0; clips at length."""
placeholders = {
"image": [PlaceholderRange(offset=2, length=3)], # tokens 2,3,4
"video": [PlaceholderRange(offset=7, length=2)], # tokens 7,8
}
ids = _build_mm_token_type_ids(placeholders, length=10)
assert ids == [0, 0, 1, 1, 1, 0, 0, 2, 2, 0]


def _expected(tokenizer, messages, **kwargs):
Expand All @@ -29,10 +40,14 @@ def test_build_training_sample_ids_match(model_name, tokenizer, renderer):
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
]
ids, mask = build_training_sample(
sample = build_training_sample(
renderer, msgs, role_to_mask=lambda m: m["role"] == "assistant"
)
ids = sample.token_ids
assert ids == _expected(tokenizer, msgs)
# text-only sample carries no multimodal payload
assert sample.multi_modal_data is None
assert sample.mm_token_type_ids is None


def test_build_training_sample_has_trainable_tokens(model_name, tokenizer, renderer):
Expand All @@ -41,9 +56,10 @@ def test_build_training_sample_has_trainable_tokens(model_name, tokenizer, rende
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
]
ids, mask = build_training_sample(
sample = build_training_sample(
renderer, msgs, role_to_mask=lambda m: m["role"] == "assistant"
)
ids, mask = sample.token_ids, sample.loss_mask
assert sum(mask) > 0
assert len(mask) == len(ids)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_is_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,13 @@ def test_build_training_sample_content_sft_roles_picks_up_tool_body(
{"role": "tool", "content": "done", "tool_call_id": "call_z"},
{"role": "assistant", "content": "OK."},
]
ids, mask = build_training_sample(
sample = build_training_sample(
renderer,
msgs,
role_to_mask=lambda m: m["role"] == "assistant",
content_sft_roles={"tool"},
)
ids, mask = sample.token_ids, sample.loss_mask
assert len(mask) == len(ids)

# We need at least one trainable tool-body token if the renderer
Expand Down
Loading