From 932f0b522de6ad5dd1cfd902b6c1112a8bc3f3b8 Mon Sep 17 00:00:00 2001 From: Zhipeng Wang Date: Tue, 26 May 2026 07:14:41 +0000 Subject: [PATCH 1/4] Add OPSD example: config, divergence losses, utils + tests First slice of the on-policy distillation example app under examples/opsd/. This commit lands the framework-agnostic foundation: the OPSDConfig dataclass hierarchy, chunked / streamed forward-KL / reverse-KL / JSD losses with sequence-axis chunking to bound peak memory, response-mask + shift helpers, and a 24-case CPU-only test suite covering identity, masking, chunk equivalence, gradient flow, and numerical edge cases. Signed-off-by: Zhipeng Wang --- examples/opsd/opsd/__init__.py | 17 +++ examples/opsd/opsd/config.py | 149 ++++++++++++++++++++++ examples/opsd/opsd/losses.py | 192 +++++++++++++++++++++++++++++ examples/opsd/opsd/utils.py | 52 ++++++++ examples/opsd/requirements.txt | 5 + examples/opsd/tests/test_losses.py | 166 +++++++++++++++++++++++++ 6 files changed, 581 insertions(+) create mode 100644 examples/opsd/opsd/__init__.py create mode 100644 examples/opsd/opsd/config.py create mode 100644 examples/opsd/opsd/losses.py create mode 100644 examples/opsd/opsd/utils.py create mode 100644 examples/opsd/requirements.txt create mode 100644 examples/opsd/tests/test_losses.py diff --git a/examples/opsd/opsd/__init__.py b/examples/opsd/opsd/__init__.py new file mode 100644 index 000000000000..a0916026f680 --- /dev/null +++ b/examples/opsd/opsd/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""On-Policy Distillation (OPSD) training on DeepSpeed. + +A student model generates rollouts; a frozen teacher scores them; the student +is updated by a per-token divergence (forward-KL / reverse-KL / JSD) computed +against the teacher's distribution on the student's own samples. + +Supports two rollout engines selected via config: + * ``hybrid_engine`` — DeepSpeed's built-in train+infer engine (ZeRO-3 safe) + * ``vllm`` — vLLM running on a disjoint set of GPUs with NCCL + weight sync from the trainer each step +""" + +__version__ = "0.1.0" diff --git a/examples/opsd/opsd/config.py b/examples/opsd/opsd/config.py new file mode 100644 index 000000000000..b55487d738bd --- /dev/null +++ b/examples/opsd/opsd/config.py @@ -0,0 +1,149 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Configuration dataclasses for OPSD training. + +A single :class:`OPSDConfig` is loaded from a JSON file (see ``configs/`` for +examples) and threaded through the rest of the pipeline. We use plain +dataclasses instead of Hydra/pydantic to match the rest of the DeepSpeed +example apps and to keep the dependency surface minimal. +""" + +import json +from dataclasses import dataclass, field, asdict +from typing import List, Optional + + +@dataclass +class StudentConfig: + model_name_or_path: str + dtype: str = "bfloat16" + trust_remote_code: bool = False + # Architecture key used to look up the weight bridge for vLLM rollout. If + # unset, the trainer will infer it from the HF config's ``model_type``. + arch: Optional[str] = None + + +@dataclass +class TeacherConfig: + model_name_or_path: str + dtype: str = "bfloat16" + trust_remote_code: bool = False + # Keep teacher params on CPU and gather per-forward via ZeRO-3. Saves GPU + # memory at the cost of host<->device transfer each step. + offload_to_cpu: bool = True + + +@dataclass +class RolloutConfig: + # "hybrid_engine" | "vllm" + engine: str = "hybrid_engine" + + # Generation knobs (apply to either engine) + max_prompt_length: int = 1024 + max_response_length: int = 1024 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + n_samples_per_prompt: int = 1 + + # vLLM-specific. ``gpus`` is the disjoint set of CUDA device indices vLLM + # may use; the training ranks must not overlap with these. If None, the + # trainer will refuse to start in vllm mode. + gpus: Optional[List[int]] = None + tensor_parallel_size: int = 1 + gpu_memory_utilization: float = 0.85 + vllm_dtype: str = "bfloat16" + # Push student weights into vLLM every N optimizer steps. Larger values + # trade staleness for throughput. + weight_sync_interval: int = 1 + # Pinned vLLM version known to expose the worker APIs we rely on. + vllm_min_version: str = "0.6.4" + # Skip CUDA-graph capture at vLLM startup. Saves several minutes of + # one-time compilation (worth it for smoke tests / short-lived runs); + # leave False for steady-state throughput. + vllm_enforce_eager: bool = False + + +@dataclass +class DistillationConfig: + # "forward_kl" | "reverse_kl" | "jsd" + loss_type: str = "reverse_kl" + temperature: float = 1.0 + # Chunk size along the sequence dimension for the per-token divergence. + # Bounds peak memory: full [B, T, V] is never materialized at once when + # T > chunk_size. + chunk_size: int = 512 + + +@dataclass +class TrainingConfig: + train_batch_size: int = 8 + micro_batch_size_per_gpu: int = 1 + gradient_accumulation_steps: int = 1 + learning_rate: float = 1e-6 + weight_decay: float = 0.0 + num_train_epochs: int = 1 + max_steps: int = -1 + warmup_steps: int = 0 + save_steps: int = 500 + logging_steps: int = 10 + save_dir: str = "./opsd_ckpt" + seed: int = 42 + + +@dataclass +class DataConfig: + path: str = "" + prompt_field: str = "prompt" + # Optional HF chat template override; if None we use the student tokenizer's + # default. + chat_template: Optional[str] = None + shuffle: bool = True + + +@dataclass +class OPSDConfig: + student: StudentConfig + teacher: TeacherConfig + rollout: RolloutConfig = field(default_factory=RolloutConfig) + distillation: DistillationConfig = field(default_factory=DistillationConfig) + training: TrainingConfig = field(default_factory=TrainingConfig) + data: DataConfig = field(default_factory=DataConfig) + # Path to the DeepSpeed JSON config used for ``deepspeed.initialize`` on the + # student. Kept as a separate file because it has its own schema owned by + # DeepSpeed. + deepspeed_config: str = "" + + @classmethod + def from_json(cls, path: str) -> "OPSDConfig": + with open(path, "r") as f: + raw = json.load(f) + return cls.from_dict(raw) + + @classmethod + def from_dict(cls, raw: dict) -> "OPSDConfig": + return cls( + student=StudentConfig(**raw["student"]), + teacher=TeacherConfig(**raw["teacher"]), + rollout=RolloutConfig(**raw.get("rollout", {})), + distillation=DistillationConfig(**raw.get("distillation", {})), + training=TrainingConfig(**raw.get("training", {})), + data=DataConfig(**raw.get("data", {})), + deepspeed_config=raw.get("deepspeed_config", ""), + ) + + def to_dict(self) -> dict: + return asdict(self) + + def validate(self) -> None: + if self.distillation.loss_type not in ("forward_kl", "reverse_kl", "jsd"): + raise ValueError(f"Unknown loss_type {self.distillation.loss_type!r}") + if self.rollout.engine not in ("hybrid_engine", "vllm"): + raise ValueError(f"Unknown rollout engine {self.rollout.engine!r}") + # rollout.gpus may be left empty for the "shared" topology where vLLM + # runs in-process on the same GPU as training rank 0; populated for + # the "disjoint" topology where it runs on a separate set of devices. + if self.distillation.chunk_size <= 0: + raise ValueError("distillation.chunk_size must be positive") diff --git a/examples/opsd/opsd/losses.py b/examples/opsd/opsd/losses.py new file mode 100644 index 000000000000..d9f4b9266da5 --- /dev/null +++ b/examples/opsd/opsd/losses.py @@ -0,0 +1,192 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Per-token distillation divergences with sequence-axis chunking. + +The full ``[B, T, V]`` tensor produced by a forward pass on a modern LLM can +easily exceed several GB in fp32 (e.g. 8 * 1024 * 150k * 4 B ~ 4.9 GB). Holding +both student *and* teacher logits at once would double that. We chunk along the +sequence axis so the per-chunk softmax + difference only ever needs +``[B, chunk, V]`` of working memory, regardless of T. + +Math conventions: + * ``forward_kl`` = D_KL(teacher || student) — mode-covering for student + * ``reverse_kl`` = D_KL(student || teacher) — mode-seeking for student + * ``jsd`` = 0.5 * D_KL(P || M) + 0.5 * D_KL(Q || M), M = (P+Q)/2 + +All three follow the standard knowledge-distillation temperature convention: +divide logits by T before softmax, then multiply the result by T**2 so that +gradient magnitudes are comparable across temperatures. +""" + +from typing import Callable + +import torch +import torch.nn.functional as F + + +def _forward_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float) -> torch.Tensor: + s_log_probs = F.log_softmax(student_logits / temperature, dim=-1) + t_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1) + t_probs = t_log_probs.exp() + kl = (t_probs * (t_log_probs - s_log_probs)).sum(dim=-1) + return kl * (temperature**2) + + +def _reverse_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float) -> torch.Tensor: + s_log_probs = F.log_softmax(student_logits / temperature, dim=-1) + t_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1) + s_probs = s_log_probs.exp() + kl = (s_probs * (s_log_probs - t_log_probs)).sum(dim=-1) + return kl * (temperature**2) + + +def _jsd(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float) -> torch.Tensor: + s_log_probs = F.log_softmax(student_logits / temperature, dim=-1) + t_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1) + s_probs = s_log_probs.exp() + t_probs = t_log_probs.exp() + m_probs = 0.5 * (s_probs + t_probs) + # Clamp guards against log(0) when both distributions have ~0 mass on the + # same vocab id (rare in practice but possible after temperature scaling). + m_log_probs = m_probs.clamp_min(1e-12).log() + kl_s = (s_probs * (s_log_probs - m_log_probs)).sum(dim=-1) + kl_t = (t_probs * (t_log_probs - m_log_probs)).sum(dim=-1) + return 0.5 * (kl_s + kl_t) * (temperature**2) + + +_LOSS_FNS: "dict[str, Callable[..., torch.Tensor]]" = { + "forward_kl": _forward_kl, + "reverse_kl": _reverse_kl, + "jsd": _jsd, +} + + +def chunked_distillation_loss( + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + response_mask: torch.Tensor, + loss_type: str = "reverse_kl", + temperature: float = 1.0, + chunk_size: int = 512, +) -> torch.Tensor: + """Mean per-token divergence over response positions, chunked over the + sequence axis to bound peak memory. + + Args: + student_logits: ``[B, T, V]`` — gradient flows here. + teacher_logits: ``[B, T, V]`` — caller is responsible for ``detach()`` + (we do not detach here so the function stays cheap). + response_mask: ``[B, T]`` — 1 where the position should contribute to + the loss (i.e. response tokens, not prompt or padding), 0 elsewhere. + loss_type: ``"forward_kl"`` | ``"reverse_kl"`` | ``"jsd"``. + temperature: KD temperature; >1 softens both distributions. + chunk_size: Sequence-axis chunk size. + + Returns: + Scalar loss = sum-over-positions(per_tok * mask) / sum(mask), promoted + to fp32 internally for numerical stability. + """ + if loss_type not in _LOSS_FNS: + raise ValueError(f"Unknown loss_type {loss_type!r}; choose from {sorted(_LOSS_FNS)}") + fn = _LOSS_FNS[loss_type] + + if student_logits.shape != teacher_logits.shape: + raise ValueError(f"shape mismatch: student {tuple(student_logits.shape)} vs teacher " + f"{tuple(teacher_logits.shape)}") + B, T, _ = student_logits.shape + if response_mask.shape != (B, T): + raise ValueError(f"response_mask {tuple(response_mask.shape)} does not match logits " + f"prefix ({B}, {T})") + + mask_f = response_mask.to(torch.float32) + total_tokens = mask_f.sum().clamp_min(1.0) + total_loss = student_logits.new_zeros((), dtype=torch.float32) + + for start in range(0, T, chunk_size): + end = min(start + chunk_size, T) + chunk_mask = mask_f[:, start:end] + # Skipping empty chunks avoids a redundant forward through the softmax + # path on chunks that wouldn't contribute anything to the sum. + if chunk_mask.sum().item() == 0: + continue + per_tok = fn( + student_logits[:, start:end].float(), + teacher_logits[:, start:end].float(), + temperature, + ) + total_loss = total_loss + (per_tok * chunk_mask).sum() + + return total_loss / total_tokens + + +def streamed_distillation_loss( + student_logits: torch.Tensor, + teacher_chunk_fetcher: Callable[[int, int], torch.Tensor], + response_mask: torch.Tensor, + loss_type: str = "reverse_kl", + temperature: float = 1.0, + chunk_size: int = 512, +) -> torch.Tensor: + """Same math as :func:`chunked_distillation_loss`, but teacher logits are + pulled chunk-by-chunk via a fetcher so the full ``[B, T, V]`` teacher + tensor never needs to live on the same device as the student. + + Args: + student_logits: ``[B, T, V]`` on the training device. + teacher_chunk_fetcher: ``fn(start, end) -> [B, end - start, V]``, already + on the same device and broadcastable dtype as ``student_logits``. + Typically wraps ``TeacherLogitCache.chunk_to_device``. + response_mask: ``[B, T]`` — 1 where the position should contribute. + loss_type: one of ``"forward_kl" | "reverse_kl" | "jsd"``. + temperature: KD temperature. + chunk_size: Sequence-axis chunk size. + """ + if loss_type not in _LOSS_FNS: + raise ValueError(f"Unknown loss_type {loss_type!r}; choose from {sorted(_LOSS_FNS)}") + fn = _LOSS_FNS[loss_type] + + B, T, _ = student_logits.shape + if response_mask.shape != (B, T): + raise ValueError(f"response_mask {tuple(response_mask.shape)} does not match logits " + f"prefix ({B}, {T})") + + mask_f = response_mask.to(torch.float32) + total_tokens = mask_f.sum().clamp_min(1.0) + total_loss = student_logits.new_zeros((), dtype=torch.float32) + + for start in range(0, T, chunk_size): + end = min(start + chunk_size, T) + chunk_mask = mask_f[:, start:end] + if chunk_mask.sum().item() == 0: + continue + teacher_chunk = teacher_chunk_fetcher(start, end) + if teacher_chunk.shape[1] != (end - start): + raise RuntimeError(f"fetcher returned chunk of length {teacher_chunk.shape[1]}, " + f"expected {end - start}") + per_tok = fn( + student_logits[:, start:end].float(), + teacher_chunk.float(), + temperature, + ) + total_loss = total_loss + (per_tok * chunk_mask).sum() + + return total_loss / total_tokens + + +def per_token_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Gather log p(label_t | context_ torch.Tensor: + """Mark positions belonging to the response (not prompt, not padding). + + Args: + response_start_idx: ``[B]`` int tensor — the first column index that is + part of the response, per sample. For *right-padded* prompts this + equals the prompt's token count; for the more common *left-padded* + convention used by causal generation it equals the prompt section + length (i.e. the column where prompt ends and response begins). + attention_mask: ``[B, T]`` — 1 on real tokens (prompt + response), 0 on + padding. + + Returns: + ``[B, T]`` 0/1 mask with the same dtype as ``attention_mask``. 1 only + at positions ``t >= response_start_idx[b]`` that are also attended. + """ + if response_start_idx.dim() != 1: + raise ValueError(f"response_start_idx must be 1-D, got shape {tuple(response_start_idx.shape)}") + if attention_mask.dim() != 2: + raise ValueError(f"attention_mask must be 2-D, got shape {tuple(attention_mask.shape)}") + B, T = attention_mask.shape + if response_start_idx.shape[0] != B: + raise ValueError(f"response_start_idx batch ({response_start_idx.shape[0]}) != " + f"attention_mask batch ({B})") + + pos = torch.arange(T, device=attention_mask.device).unsqueeze(0).expand(B, T) + is_response = pos >= response_start_idx.to(pos.dtype).unsqueeze(1) + return is_response.to(attention_mask.dtype) * attention_mask + + +def shift_for_next_token_prediction(logits: torch.Tensor, labels: torch.Tensor): + """Align logits at position t with the label at position t+1. + + Returns: + Tuple ``(shifted_logits[:, :-1, :], shifted_labels[:, 1:])`` — both + contiguous, so they can be safely indexed for the divergence loss. + """ + return logits[:, :-1, :].contiguous(), labels[:, 1:].contiguous() diff --git a/examples/opsd/requirements.txt b/examples/opsd/requirements.txt new file mode 100644 index 000000000000..fb5a091575da --- /dev/null +++ b/examples/opsd/requirements.txt @@ -0,0 +1,5 @@ +datasets>=2.0.0 +numpy +transformers>=4.40.0 +# Optional, only needed when rollout.engine == "vllm": +# vllm>=0.6.4 diff --git a/examples/opsd/tests/test_losses.py b/examples/opsd/tests/test_losses.py new file mode 100644 index 000000000000..1cf9aede6756 --- /dev/null +++ b/examples/opsd/tests/test_losses.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""CPU-only numerics tests for the distillation divergences. + +These exercise the loss math without needing GPUs, models, or a torchrun +launcher. Run from the example root with:: + + cd examples/opsd && python -m pytest tests/test_losses.py -v +""" + +import pytest +import torch + +from opsd.losses import chunked_distillation_loss, per_token_logprobs +from opsd.utils import build_response_mask, shift_for_next_token_prediction + + +@pytest.mark.parametrize("loss_type", ["forward_kl", "reverse_kl", "jsd"]) +def test_zero_when_identical(loss_type): + torch.manual_seed(0) + logits = torch.randn(2, 8, 32) + mask = torch.ones(2, 8) + loss = chunked_distillation_loss(logits, logits.clone(), mask, loss_type=loss_type) + assert loss.item() == pytest.approx(0.0, abs=1e-5) + + +@pytest.mark.parametrize("loss_type", ["forward_kl", "reverse_kl", "jsd"]) +def test_positive_when_different(loss_type): + torch.manual_seed(0) + s = torch.randn(2, 8, 32) + t = torch.randn(2, 8, 32) + mask = torch.ones(2, 8) + loss = chunked_distillation_loss(s, t, mask, loss_type=loss_type) + assert loss.item() > 0.0 + + +@pytest.mark.parametrize("loss_type", ["forward_kl", "reverse_kl", "jsd"]) +def test_chunking_equivalent_to_unchunked(loss_type): + torch.manual_seed(0) + s = torch.randn(2, 100, 32) + t = torch.randn(2, 100, 32) + mask = torch.ones(2, 100) + loss_chunked = chunked_distillation_loss(s, t, mask, loss_type=loss_type, chunk_size=10) + loss_whole = chunked_distillation_loss(s, t, mask, loss_type=loss_type, chunk_size=10_000) + assert torch.allclose(loss_chunked, loss_whole, atol=1e-5) + + +def test_mask_excludes_tokens(): + torch.manual_seed(0) + s = torch.randn(2, 8, 32) + t = torch.randn(2, 8, 32) + half_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0]], dtype=torch.float32) + loss_direct = chunked_distillation_loss(s[:, :4], t[:, :4], torch.ones(2, 4), loss_type="reverse_kl") + loss_masked = chunked_distillation_loss(s, t, half_mask, loss_type="reverse_kl") + assert torch.allclose(loss_direct, loss_masked, atol=1e-5) + + +def test_gradient_flows_to_student(): + torch.manual_seed(0) + s = torch.randn(2, 8, 32, requires_grad=True) + t = torch.randn(2, 8, 32) + mask = torch.ones(2, 8) + loss = chunked_distillation_loss(s, t, mask, loss_type="reverse_kl") + loss.backward() + assert s.grad is not None + assert s.grad.abs().sum().item() > 0 + + +def test_gradient_does_not_flow_to_teacher_when_detached(): + torch.manual_seed(0) + s = torch.randn(2, 8, 32, requires_grad=True) + t = torch.randn(2, 8, 32, requires_grad=True) + mask = torch.ones(2, 8) + loss = chunked_distillation_loss(s, t.detach(), mask, loss_type="reverse_kl") + loss.backward() + assert t.grad is None + + +def test_unknown_loss_type_raises(): + s = torch.randn(2, 4, 8) + t = torch.randn(2, 4, 8) + mask = torch.ones(2, 4) + with pytest.raises(ValueError, match="Unknown loss_type"): + chunked_distillation_loss(s, t, mask, loss_type="totally_made_up") + + +def test_shape_mismatch_raises(): + s = torch.randn(2, 4, 8) + t = torch.randn(2, 5, 8) + mask = torch.ones(2, 4) + with pytest.raises(ValueError, match="shape mismatch"): + chunked_distillation_loss(s, t, mask) + + +def test_mask_shape_mismatch_raises(): + s = torch.randn(2, 4, 8) + t = torch.randn(2, 4, 8) + mask = torch.ones(2, 5) + with pytest.raises(ValueError, match="does not match"): + chunked_distillation_loss(s, t, mask) + + +@pytest.mark.parametrize("temperature", [0.5, 1.0, 2.0]) +def test_temperature_changes_loss_but_stays_finite(temperature): + torch.manual_seed(0) + s = torch.randn(2, 8, 32) + t = torch.randn(2, 8, 32) + mask = torch.ones(2, 8) + loss = chunked_distillation_loss(s, t, mask, loss_type="reverse_kl", temperature=temperature) + assert torch.isfinite(loss).item() + + +def test_jsd_is_symmetric(): + torch.manual_seed(0) + a = torch.randn(2, 8, 32) + b = torch.randn(2, 8, 32) + mask = torch.ones(2, 8) + jsd_ab = chunked_distillation_loss(a, b, mask, loss_type="jsd") + jsd_ba = chunked_distillation_loss(b, a, mask, loss_type="jsd") + assert torch.allclose(jsd_ab, jsd_ba, atol=1e-5) + + +def test_all_zero_mask_returns_zero(): + torch.manual_seed(0) + s = torch.randn(2, 8, 32) + t = torch.randn(2, 8, 32) + mask = torch.zeros(2, 8) + loss = chunked_distillation_loss(s, t, mask, loss_type="reverse_kl") + assert loss.item() == pytest.approx(0.0, abs=1e-6) + + +def test_per_token_logprobs_matches_manual(): + torch.manual_seed(0) + logits = torch.randn(2, 4, 16) + labels = torch.randint(0, 16, (2, 4)) + got = per_token_logprobs(logits, labels) + expected = torch.log_softmax(logits.float(), dim=-1) + expected = expected.gather(-1, labels.unsqueeze(-1)).squeeze(-1) + assert torch.allclose(got, expected, atol=1e-6) + + +def test_build_response_mask_basic(): + attention_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]) + response_start_idx = torch.tensor([2, 3]) + resp = build_response_mask(response_start_idx, attention_mask) + expected = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 0, 1, 1]]) + assert torch.equal(resp, expected) + + +def test_build_response_mask_validates_shapes(): + with pytest.raises(ValueError, match="response_start_idx must be 1-D"): + build_response_mask(torch.zeros(2, 2), torch.ones(2, 4)) + with pytest.raises(ValueError, match="attention_mask must be 2-D"): + build_response_mask(torch.zeros(2), torch.ones(4)) + with pytest.raises(ValueError, match="batch"): + build_response_mask(torch.zeros(3), torch.ones(2, 4)) + + +def test_shift_for_next_token_prediction_shapes(): + logits = torch.randn(2, 5, 8) + labels = torch.randint(0, 8, (2, 5)) + sl, sla = shift_for_next_token_prediction(logits, labels) + assert sl.shape == (2, 4, 8) + assert sla.shape == (2, 4) From 14d8fe7ee594efd9528406934d5a7d6a28e9e67f Mon Sep 17 00:00:00 2001 From: Zhipeng Wang Date: Tue, 26 May 2026 07:15:03 +0000 Subject: [PATCH 2/4] Add OPSD frozen teacher with CPU logit cache + tests Adds the two-phase teacher path: * TeacherWrapper loads a HuggingFace causal LM, freezes it, and runs forward-only. Two modes: load + pin on GPU (offload_to_cpu=false), or wrap with deepspeed.initialize using a ZeRO-3 + offload_param=cpu config (offload_to_cpu=true). Avoids deepspeed.zero.Init() around from_pretrained because HF's loader partitions params to zero-width shards before the checkpoint can fill them. * TeacherLogitCache stages the [B, T, V] teacher logits to (pinned) host memory in bf16, and exposes chunk_to_device() so the student-side loss can pull sequence slices back on demand. This is the memory-economising half of the two-phase update. CPU-only tests cover the cache shape / dtype / round-trip / chunk-bounds behaviour and verify the streamed-via-cache loss matches the direct chunked loss bit-for-bit. Signed-off-by: Zhipeng Wang --- examples/opsd/opsd/teacher.py | 191 ++++++++++++++++++++ examples/opsd/tests/test_teacher_caching.py | 101 +++++++++++ 2 files changed, 292 insertions(+) create mode 100644 examples/opsd/opsd/teacher.py create mode 100644 examples/opsd/tests/test_teacher_caching.py diff --git a/examples/opsd/opsd/teacher.py b/examples/opsd/opsd/teacher.py new file mode 100644 index 000000000000..a7895beddf00 --- /dev/null +++ b/examples/opsd/opsd/teacher.py @@ -0,0 +1,191 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Frozen teacher: two-phase forward with CPU-cached logits. + +The trainer runs each step in two phases: + + 1. **Teacher phase.** Forward over the prompt+response. The full ``[B, T, V]`` + logit tensor is moved off the GPU into a :class:`TeacherLogitCache` so that + teacher weight buffers can be released before the student backward pass. + 2. **Student phase.** Forward + backward on the student. The distillation + loss pulls teacher logits back to GPU **one sequence chunk at a time** via + :meth:`TeacherLogitCache.chunk_to_device`, so peak GPU memory for teacher + data is only ``[B, chunk, V]``. + +This module deliberately lazy-imports ``deepspeed`` and ``transformers`` so +that the pure data-handling pieces (``TeacherLogitCache`` and the streamed +loss in :mod:`opsd.losses`) remain importable in CPU-only unit tests that do +not have a working DeepSpeed launcher. +""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +# ``opsd.config`` is pure-Python (no distributed imports), so we can import it +# at module load time without pulling in DeepSpeed. +from opsd.config import TeacherConfig + + +@dataclass +class TeacherLogitCache: + """CPU-resident teacher logits with on-demand chunk fetch. + + Stored in low precision (default ``bfloat16``) to halve host memory; the + consumer in :mod:`opsd.losses` promotes back to fp32 inside the divergence + so the KD math stays well-conditioned. + """ + + cpu_logits: torch.Tensor # [B, T, V] + + def __post_init__(self) -> None: + if self.cpu_logits.dim() != 3: + raise ValueError(f"cpu_logits must be 3-D [B, T, V]; got shape " + f"{tuple(self.cpu_logits.shape)}") + if self.cpu_logits.device.type != "cpu": + raise ValueError(f"cpu_logits must live on CPU; got device " + f"{self.cpu_logits.device}") + + @classmethod + def from_gpu_logits(cls, logits: torch.Tensor, store_dtype: torch.dtype = torch.bfloat16) -> "TeacherLogitCache": + """Detach + downcast + move to (pinned) host memory. + + ``non_blocking=True`` lets the copy overlap with the next CUDA op when + the destination is pinned; we try to pin and fall back silently if the + host doesn't support it (e.g. CPU-only test environments). + """ + downcast = logits.detach().to(dtype=store_dtype) + try: + host = torch.empty(downcast.shape, dtype=store_dtype, pin_memory=True) + host.copy_(downcast, non_blocking=True) + except RuntimeError: + host = downcast.cpu() + return cls(cpu_logits=host) + + @property + def shape(self) -> Tuple[int, int, int]: + s = self.cpu_logits.shape + return (int(s[0]), int(s[1]), int(s[2])) + + @property + def dtype(self) -> torch.dtype: + return self.cpu_logits.dtype + + def chunk_to_device(self, + start: int, + end: int, + device: torch.device, + dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Slice ``[:, start:end, :]`` and stage it on ``device``. + + ``dtype`` is the dtype on the destination; if ``None``, the stored + dtype is preserved. + """ + _, T, _ = self.shape + if not (0 <= start < end <= T): + raise ValueError(f"chunk bounds [{start}, {end}) invalid for T={T}") + chunk = self.cpu_logits[:, start:end] + out = chunk.to(device=device, dtype=dtype if dtype is not None else chunk.dtype, non_blocking=True) + return out + + def free(self) -> None: + """Drop the underlying buffer so a step's teacher cache can be GC'd + before the next teacher forward.""" + self.cpu_logits = torch.empty(0) + + +_DTYPE_MAP = { + "float16": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float32": torch.float32, + "fp32": torch.float32, +} + + +def _resolve_dtype(name: str) -> torch.dtype: + if name not in _DTYPE_MAP: + raise ValueError(f"Unknown dtype {name!r}; choose from {sorted(_DTYPE_MAP)}") + return _DTYPE_MAP[name] + + +class TeacherWrapper: + """Frozen teacher. + + Two modes depending on ``cfg.offload_to_cpu``: + + * ``offload_to_cpu=False`` — load the teacher with HF's standard + ``from_pretrained`` and pin it on the local accelerator device. The + whole teacher lives in GPU memory; simplest path and what to use when + the teacher fits. + + * ``offload_to_cpu=True`` — wrap the loaded model with + ``deepspeed.initialize`` using a ZeRO-3 config with + ``offload_param.device='cpu'``. The optimizer slot is unused (no + trainable params) but ZeRO-3 gives us per-forward parameter gather + / release and keeps weights on the host between forwards. This is the + path to use when the teacher would otherwise not fit alongside the + student. + + Both paths load the full checkpoint on each rank before DeepSpeed (if + used) partitions; we intentionally do **not** wrap ``from_pretrained`` + in ``deepspeed.zero.Init()`` because HF's loader partitions + ``low_cpu_mem_usage`` params to zero-width shards before the checkpoint + can fill them, which surfaces as a "size mismatch" load error. + """ + + def __init__(self, cfg: TeacherConfig, world_size: int): + from deepspeed.accelerator import get_accelerator + from transformers import AutoModelForCausalLM + + self.cfg = cfg + dtype = _resolve_dtype(cfg.dtype) + device = get_accelerator().current_device_name() + + model = AutoModelForCausalLM.from_pretrained( + cfg.model_name_or_path, + torch_dtype=dtype, + trust_remote_code=cfg.trust_remote_code, + ) + model.eval() + for p in model.parameters(): + p.requires_grad_(False) + + if cfg.offload_to_cpu: + import deepspeed + + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "bf16": { + "enabled": dtype is torch.bfloat16 + }, + "fp16": { + "enabled": dtype is torch.float16 + }, + "zero_optimization": { + "stage": 3, + "offload_param": { + "device": "cpu" + }, + }, + } + engine, *_ = deepspeed.initialize(model=model, config=ds_config) + self._callable = engine + self._uses_ds = True + else: + model.to(device) + self._callable = model + self._uses_ds = False + + @torch.no_grad() + def forward_to_cache(self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + store_dtype: torch.dtype = torch.bfloat16) -> TeacherLogitCache: + """Run teacher forward and stage logits onto the host.""" + outputs = self._callable(input_ids=input_ids, attention_mask=attention_mask) + return TeacherLogitCache.from_gpu_logits(outputs.logits, store_dtype=store_dtype) diff --git a/examples/opsd/tests/test_teacher_caching.py b/examples/opsd/tests/test_teacher_caching.py new file mode 100644 index 000000000000..5702bc287ffe --- /dev/null +++ b/examples/opsd/tests/test_teacher_caching.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""CPU-only tests for TeacherLogitCache. + +The ``TeacherWrapper`` itself (which wraps deepspeed+transformers) is not +exercised here because it requires a real model and a DeepSpeed launcher; the +caching/streaming pieces are isolated into ``TeacherLogitCache`` so they can +be tested in isolation. +""" + +import pytest +import torch + +from opsd.teacher import TeacherLogitCache + + +def test_round_trip_preserves_values_within_dtype(): + torch.manual_seed(0) + gpu_like = torch.randn(2, 16, 32, dtype=torch.float32) + cache = TeacherLogitCache.from_gpu_logits(gpu_like, store_dtype=torch.bfloat16) + assert cache.shape == (2, 16, 32) + assert cache.dtype == torch.bfloat16 + chunk = cache.chunk_to_device(0, 16, torch.device("cpu"), dtype=torch.float32) + # bf16 round-trip loses precision; check it stays within bf16's worst-case + # relative error rather than asserting exact equality. + assert torch.allclose(chunk, gpu_like, atol=1e-1, rtol=1e-1) + + +def test_chunk_slicing_is_correct(): + torch.manual_seed(0) + src = torch.randn(3, 100, 8) + cache = TeacherLogitCache.from_gpu_logits(src, store_dtype=torch.float32) + for start, end in [(0, 10), (10, 50), (50, 100), (33, 77)]: + got = cache.chunk_to_device(start, end, torch.device("cpu")) + assert got.shape == (3, end - start, 8) + assert torch.allclose(got, src[:, start:end]) + + +def test_invalid_chunk_bounds_raise(): + cache = TeacherLogitCache.from_gpu_logits(torch.zeros(1, 8, 4), store_dtype=torch.float32) + with pytest.raises(ValueError, match="invalid"): + cache.chunk_to_device(0, 9, torch.device("cpu")) + with pytest.raises(ValueError, match="invalid"): + cache.chunk_to_device(5, 3, torch.device("cpu")) + with pytest.raises(ValueError, match="invalid"): + cache.chunk_to_device(-1, 4, torch.device("cpu")) + + +def test_rejects_non_3d_logits(): + with pytest.raises(ValueError, match="must be 3-D"): + TeacherLogitCache(cpu_logits=torch.zeros(8, 32)) + + +def test_rejects_gpu_resident_logits(): + if not torch.cuda.is_available(): #ignore-cuda + pytest.skip("no CUDA available to construct GPU tensor") + with pytest.raises(ValueError, match="must live on CPU"): + TeacherLogitCache(cpu_logits=torch.zeros(1, 8, 4, device="cuda")) + + +def test_dtype_override_in_chunk_to_device(): + src = torch.randn(2, 8, 16, dtype=torch.float32) + cache = TeacherLogitCache.from_gpu_logits(src, store_dtype=torch.float32) + chunk = cache.chunk_to_device(0, 8, torch.device("cpu"), dtype=torch.bfloat16) + assert chunk.dtype == torch.bfloat16 + + +def test_free_releases_buffer(): + src = torch.randn(2, 32, 16) + cache = TeacherLogitCache.from_gpu_logits(src, store_dtype=torch.float32) + assert cache.cpu_logits.numel() == 2 * 32 * 16 + cache.free() + assert cache.cpu_logits.numel() == 0 + + +def test_default_store_dtype_is_bf16(): + src = torch.randn(1, 4, 8) + cache = TeacherLogitCache.from_gpu_logits(src) + assert cache.dtype == torch.bfloat16 + + +def test_streamed_chunked_loss_matches_full_loss(): + """End-to-end check: pulling teacher logits chunk-by-chunk through the + cache yields the same distillation loss as passing the full teacher tensor + to ``chunked_distillation_loss`` directly.""" + from opsd.losses import chunked_distillation_loss + + torch.manual_seed(0) + s = torch.randn(2, 64, 32) + t = torch.randn(2, 64, 32) + mask = torch.ones(2, 64) + + direct = chunked_distillation_loss(s, t, mask, loss_type="reverse_kl", chunk_size=8) + + cache = TeacherLogitCache.from_gpu_logits(t, store_dtype=torch.float32) + staged_full = cache.chunk_to_device(0, 64, torch.device("cpu"), dtype=torch.float32) + via_cache = chunked_distillation_loss(s, staged_full, mask, loss_type="reverse_kl", chunk_size=8) + + assert torch.allclose(direct, via_cache, atol=1e-6) From 837787a041599c8cec1d11e55faf03f93e1813f6 Mon Sep 17 00:00:00 2001 From: Zhipeng Wang Date: Tue, 26 May 2026 07:15:28 +0000 Subject: [PATCH 3/4] Add OPSD trainer, hybrid-engine rollout, and end-to-end entry point Lands the fully-runnable hybrid-engine training path: a backend-agnostic RolloutEngine ABC with RolloutRequest / RolloutBatch / SamplingConfig dataclasses, a HybridEngineRollout implementation that uses DeepSpeed's accelerated decode when an inference policy exists and otherwise falls back to GatheredParameters + the raw HF generate (covers Qwen-family and other models not in DeepSpeed's inference container list), a left-padded prompt dataset + collator, a three-phase trainer loop (rollout -> teacher forward + cache -> student forward + streamed KL + backward + step), the argparse + deepspeed.initialize entry point, base DeepSpeed ZeRO-3 + hybrid_engine JSON configs, a 5-step smoke config and launcher script, and a 20-prompt math toy dataset for the smoke run. Smoke-validated end-to-end on 2x H200 with Qwen2.5-0.5B-Instruct student and Qwen2.5-1.5B-Instruct teacher; loss finite for 5 steps. Rollout interface contract is covered by tests/test_rollout_interface.py. Signed-off-by: Zhipeng Wang --- examples/opsd/configs/ds_zero3.json | 43 ++++ examples/opsd/configs/opsd_hybrid_engine.json | 49 +++++ examples/opsd/configs/smoke_ds_zero3.json | 35 ++++ examples/opsd/configs/smoke_hybrid.json | 49 +++++ examples/opsd/data/prompts.jsonl | 20 ++ examples/opsd/main.py | 135 ++++++++++++ examples/opsd/opsd/data.py | 108 ++++++++++ examples/opsd/opsd/rollout/__init__.py | 39 ++++ examples/opsd/opsd/rollout/base.py | 117 +++++++++++ examples/opsd/opsd/rollout/hybrid_engine.py | 119 +++++++++++ examples/opsd/opsd/trainer.py | 197 ++++++++++++++++++ examples/opsd/scripts/train_opsd_hybrid.sh | 14 ++ examples/opsd/tests/test_rollout_interface.py | 156 ++++++++++++++ 13 files changed, 1081 insertions(+) create mode 100644 examples/opsd/configs/ds_zero3.json create mode 100644 examples/opsd/configs/opsd_hybrid_engine.json create mode 100644 examples/opsd/configs/smoke_ds_zero3.json create mode 100644 examples/opsd/configs/smoke_hybrid.json create mode 100644 examples/opsd/data/prompts.jsonl create mode 100644 examples/opsd/main.py create mode 100644 examples/opsd/opsd/data.py create mode 100644 examples/opsd/opsd/rollout/__init__.py create mode 100644 examples/opsd/opsd/rollout/base.py create mode 100644 examples/opsd/opsd/rollout/hybrid_engine.py create mode 100644 examples/opsd/opsd/trainer.py create mode 100644 examples/opsd/scripts/train_opsd_hybrid.sh create mode 100644 examples/opsd/tests/test_rollout_interface.py diff --git a/examples/opsd/configs/ds_zero3.json b/examples/opsd/configs/ds_zero3.json new file mode 100644 index 000000000000..1f43339a6f20 --- /dev/null +++ b/examples/opsd/configs/ds_zero3.json @@ -0,0 +1,43 @@ +{ + "bf16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": 5e7, + "stage3_prefetch_bucket_size": 5e7, + "stage3_param_persistence_threshold": 1e6, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-6, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 0.0 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 1e-6, + "warmup_num_steps": 0 + } + }, + "gradient_clipping": 1.0, + "hybrid_engine": { + "enabled": true, + "max_out_tokens": 2048, + "inference_tp_size": 1, + "release_inference_cache": false, + "pin_parameters": true, + "tp_gather_partition_size": 8 + }, + "wall_clock_breakdown": false +} diff --git a/examples/opsd/configs/opsd_hybrid_engine.json b/examples/opsd/configs/opsd_hybrid_engine.json new file mode 100644 index 000000000000..5a7d45b54f6a --- /dev/null +++ b/examples/opsd/configs/opsd_hybrid_engine.json @@ -0,0 +1,49 @@ +{ + "student": { + "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "arch": "qwen2" + }, + "teacher": { + "model_name_or_path": "Qwen/Qwen2.5-Math-7B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "offload_to_cpu": true + }, + "rollout": { + "engine": "hybrid_engine", + "max_prompt_length": 1024, + "max_response_length": 1024, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "n_samples_per_prompt": 1, + "weight_sync_interval": 1 + }, + "distillation": { + "loss_type": "reverse_kl", + "temperature": 1.0, + "chunk_size": 512 + }, + "training": { + "train_batch_size": 8, + "micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-6, + "weight_decay": 0.0, + "num_train_epochs": 1, + "max_steps": -1, + "warmup_steps": 0, + "save_steps": 500, + "logging_steps": 10, + "save_dir": "./opsd_ckpt_hybrid", + "seed": 42 + }, + "data": { + "path": "data/prompts.jsonl", + "prompt_field": "prompt", + "shuffle": true + }, + "deepspeed_config": "configs/ds_zero3.json" +} diff --git a/examples/opsd/configs/smoke_ds_zero3.json b/examples/opsd/configs/smoke_ds_zero3.json new file mode 100644 index 000000000000..74211f3fbd9f --- /dev/null +++ b/examples/opsd/configs/smoke_ds_zero3.json @@ -0,0 +1,35 @@ +{ + "bf16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": 5e7, + "stage3_prefetch_bucket_size": 5e7, + "stage3_param_persistence_threshold": 1e6, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-6, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 0.0 + } + }, + "gradient_clipping": 1.0, + "hybrid_engine": { + "enabled": true, + "max_out_tokens": 512, + "inference_tp_size": 1, + "release_inference_cache": false, + "pin_parameters": true, + "tp_gather_partition_size": 8 + }, + "wall_clock_breakdown": false +} diff --git a/examples/opsd/configs/smoke_hybrid.json b/examples/opsd/configs/smoke_hybrid.json new file mode 100644 index 000000000000..218bd990ae97 --- /dev/null +++ b/examples/opsd/configs/smoke_hybrid.json @@ -0,0 +1,49 @@ +{ + "student": { + "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "arch": "qwen2" + }, + "teacher": { + "model_name_or_path": "Qwen/Qwen2.5-1.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "offload_to_cpu": false + }, + "rollout": { + "engine": "hybrid_engine", + "max_prompt_length": 128, + "max_response_length": 64, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "n_samples_per_prompt": 1, + "weight_sync_interval": 1 + }, + "distillation": { + "loss_type": "reverse_kl", + "temperature": 1.0, + "chunk_size": 128 + }, + "training": { + "train_batch_size": 2, + "micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-6, + "weight_decay": 0.0, + "num_train_epochs": 1, + "max_steps": 5, + "warmup_steps": 0, + "save_steps": 10000, + "logging_steps": 1, + "save_dir": "./opsd_smoke_hybrid_ckpt", + "seed": 42 + }, + "data": { + "path": "data/prompts.jsonl", + "prompt_field": "prompt", + "shuffle": true + }, + "deepspeed_config": "configs/smoke_ds_zero3.json" +} diff --git a/examples/opsd/data/prompts.jsonl b/examples/opsd/data/prompts.jsonl new file mode 100644 index 000000000000..a95a17c57557 --- /dev/null +++ b/examples/opsd/data/prompts.jsonl @@ -0,0 +1,20 @@ +{"prompt": "Solve: 17 + 25 = ?"} +{"prompt": "What is 12 multiplied by 8?"} +{"prompt": "If a train travels 60 miles per hour for 3 hours, how far does it go?"} +{"prompt": "What is the square root of 144?"} +{"prompt": "Compute 15% of 240."} +{"prompt": "A rectangle has length 7 and width 4. What is its area?"} +{"prompt": "Solve for x: 2x + 5 = 17."} +{"prompt": "What is 7 factorial?"} +{"prompt": "Compute the sum of integers from 1 to 10."} +{"prompt": "What is 2 to the power of 10?"} +{"prompt": "Find the perimeter of a square with side length 9."} +{"prompt": "If 5 apples cost $2.50, what is the cost of 12 apples?"} +{"prompt": "What is the greatest common divisor of 24 and 36?"} +{"prompt": "Convert 0.75 to a fraction in simplest form."} +{"prompt": "If x + y = 10 and x - y = 4, find x and y."} +{"prompt": "What is 1/4 + 1/3?"} +{"prompt": "A circle has radius 5. What is its area? (Use pi = 3.14)"} +{"prompt": "Compute (3 + 4) * (5 - 2)."} +{"prompt": "What is 81 divided by 9?"} +{"prompt": "If a number doubled is 18, what is the number?"} diff --git a/examples/opsd/main.py b/examples/opsd/main.py new file mode 100644 index 000000000000..b2e5c4c6929b --- /dev/null +++ b/examples/opsd/main.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""OPSD training entry point. + +Launch with the DeepSpeed launcher:: + + deepspeed --num_gpus 8 main.py --config configs/opsd_hybrid_engine.json + +The DeepSpeed launcher sets ``LOCAL_RANK``, ``RANK``, and ``WORLD_SIZE`` in +the environment; we call :func:`deepspeed.init_distributed` to take that over. +""" + +import argparse +import json +import os +import random + +import deepspeed +import numpy as np +import torch +from deepspeed.accelerator import get_accelerator +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer + +from opsd.config import OPSDConfig +from opsd.data import LeftPaddedPromptCollator, PromptDataset +from opsd.rollout import build_rollout +from opsd.teacher import TeacherWrapper +from opsd.trainer import OPSDTrainer + + +def _seed_everything(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if get_accelerator().is_available(): + get_accelerator().manual_seed_all(seed) + + +def _resolve_dtype(name: str) -> torch.dtype: + return {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[name] + + +def _load_ds_config(path: str) -> dict: + with open(path, "r") as f: + return json.load(f) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True, help="Path to OPSDConfig JSON") + parser.add_argument("--local_rank", type=int, default=int(os.environ.get("LOCAL_RANK", 0))) + args = parser.parse_args() + + cfg = OPSDConfig.from_json(args.config) + cfg.validate() + _seed_everything(cfg.training.seed) + + deepspeed.init_distributed() + + # --- tokenizer (shared between data + rollout) ------------------------- + tokenizer = AutoTokenizer.from_pretrained( + cfg.student.model_name_or_path, + trust_remote_code=cfg.student.trust_remote_code, + padding_side="left", + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # --- student model + DeepSpeed engine ---------------------------------- + student_dtype = _resolve_dtype(cfg.student.dtype) + student_model = AutoModelForCausalLM.from_pretrained( + cfg.student.model_name_or_path, + torch_dtype=student_dtype, + trust_remote_code=cfg.student.trust_remote_code, + ) + + ds_config = _load_ds_config(cfg.deepspeed_config) + ds_config["train_micro_batch_size_per_gpu"] = cfg.training.micro_batch_size_per_gpu + ds_config["train_batch_size"] = cfg.training.train_batch_size + ds_config["gradient_accumulation_steps"] = cfg.training.gradient_accumulation_steps + + student_engine, *_ = deepspeed.initialize( + model=student_model, + model_parameters=student_model.parameters(), + config=ds_config, + ) + + # --- frozen teacher ---------------------------------------------------- + teacher = TeacherWrapper(cfg.teacher, world_size=dist_world_size()) + + # --- rollout engine ---------------------------------------------------- + rollout = build_rollout( + cfg.rollout, + student_engine=student_engine, + tokenizer=tokenizer, + student_model_path=cfg.student.model_name_or_path, + arch=cfg.student.arch, + ) + + # --- dataloader -------------------------------------------------------- + dataset = PromptDataset( + path=cfg.data.path, + tokenizer=tokenizer, + max_prompt_length=cfg.rollout.max_prompt_length, + prompt_field=cfg.data.prompt_field, + chat_template=cfg.data.chat_template, + ) + collator = LeftPaddedPromptCollator(tokenizer=tokenizer, max_prompt_length=cfg.rollout.max_prompt_length) + loader = DataLoader( + dataset, + batch_size=cfg.training.micro_batch_size_per_gpu, + shuffle=cfg.data.shuffle, + collate_fn=collator, + drop_last=True, + ) + + OPSDTrainer( + cfg=cfg, + student_engine=student_engine, + teacher=teacher, + tokenizer=tokenizer, + rollout=rollout, + dataloader=loader, + ).train() + + +def dist_world_size() -> int: + return int(os.environ.get("WORLD_SIZE", "1")) + + +if __name__ == "__main__": + main() diff --git a/examples/opsd/opsd/data.py b/examples/opsd/opsd/data.py new file mode 100644 index 000000000000..02ecf417e5c3 --- /dev/null +++ b/examples/opsd/opsd/data.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Prompt dataset and left-padding collator for OPSD rollouts. + +The dataset reads a JSONL file with one record per line; each record must +contain a string under :attr:`DataConfig.prompt_field` (default ``"prompt"``). +If the tokenizer exposes ``apply_chat_template``, single-turn prompts are +wrapped in a user-role message with ``add_generation_prompt=True`` so the +student generates the assistant turn. + +Batches are **left-padded** because causal generation requires real tokens at +the right edge — :class:`opsd.rollout.RolloutRequest` and the hybrid-engine +backend both assume this layout. +""" + +import json +from typing import Any, Dict, List, Optional + +import torch +from torch.utils.data import Dataset + + +class PromptDataset(Dataset): + """Reads ``{prompt_field: str}`` records from a JSONL file.""" + + def __init__( + self, + path: str, + tokenizer: Any, + max_prompt_length: int, + prompt_field: str = "prompt", + chat_template: Optional[str] = None, + ): + self.records = self._load_jsonl(path) + self.tokenizer = tokenizer + self.max_prompt_length = max_prompt_length + self.prompt_field = prompt_field + self.chat_template = chat_template + + @staticmethod + def _load_jsonl(path: str) -> List[Dict[str, Any]]: + records: List[Dict[str, Any]] = [] + with open(path, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + records.append(json.loads(line)) + return records + + def __len__(self) -> int: + return len(self.records) + + def __getitem__(self, idx: int) -> str: + rec = self.records[idx] + if self.prompt_field not in rec: + raise KeyError(f"record {idx} missing field {self.prompt_field!r}") + text = rec[self.prompt_field] + + # If the tokenizer knows a chat template, render the prompt as a single + # user-role turn and request the generation prompt. This matches how + # instruction-tuned student/teacher checkpoints expect inputs. + if hasattr(self.tokenizer, "apply_chat_template"): + messages = [{"role": "user", "content": text}] if isinstance(text, str) else text + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + chat_template=self.chat_template, + ) + return text + + +class LeftPaddedPromptCollator: + """Tokenizes a batch of prompt strings into a left-padded tensor batch.""" + + def __init__(self, tokenizer: Any, max_prompt_length: int): + self.tokenizer = tokenizer + self.max_prompt_length = max_prompt_length + self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + if self.pad_id is None: + raise ValueError("tokenizer has neither pad_token_id nor eos_token_id; " + "cannot construct a padding collator") + + def __call__(self, batch_texts: List[str]) -> Dict[str, torch.Tensor]: + per_sample = [ + self.tokenizer( + t, + add_special_tokens=False, + truncation=True, + max_length=self.max_prompt_length, + return_tensors="pt", + )["input_ids"].squeeze(0) for t in batch_texts + ] + max_len = max(int(x.shape[0]) for x in per_sample) + B = len(per_sample) + + prompt_ids = torch.full((B, max_len), self.pad_id, dtype=torch.long) + attention_mask = torch.zeros((B, max_len), dtype=torch.long) + for i, ids in enumerate(per_sample): + n = int(ids.shape[0]) + # left-pad: real tokens at the right edge + prompt_ids[i, max_len - n:] = ids + attention_mask[i, max_len - n:] = 1 + + return {"prompt_ids": prompt_ids, "prompt_attention_mask": attention_mask} diff --git a/examples/opsd/opsd/rollout/__init__.py b/examples/opsd/opsd/rollout/__init__.py new file mode 100644 index 000000000000..0509d6d8b4c9 --- /dev/null +++ b/examples/opsd/opsd/rollout/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Rollout engines for OPSD: hybrid engine (built-in) or vLLM (disjoint GPUs).""" + +from opsd.rollout.base import RolloutBatch, RolloutEngine, RolloutRequest, SamplingConfig + +__all__ = ["RolloutBatch", "RolloutEngine", "RolloutRequest", "SamplingConfig", "build_rollout"] + + +def build_rollout(rollout_cfg, student_engine=None, tokenizer=None, student_model_path=None, arch=None): + """Factory: construct the rollout engine specified by ``rollout_cfg.engine``. + + Imports of heavy backends are deferred to here so that selecting the + hybrid-engine path doesn't transitively require vLLM (and vice versa). + """ + engine_name = rollout_cfg.engine + if engine_name == "hybrid_engine": + from opsd.rollout.hybrid_engine import HybridEngineRollout + + if student_engine is None or tokenizer is None: + raise ValueError("hybrid_engine rollout needs both student_engine and tokenizer") + return HybridEngineRollout(student_engine=student_engine, tokenizer=tokenizer, cfg=rollout_cfg) + + if engine_name == "vllm": + from opsd.rollout.vllm import VLLMRollout + + if tokenizer is None: + raise ValueError("vllm rollout needs a tokenizer for length accounting") + return VLLMRollout( + cfg=rollout_cfg, + tokenizer=tokenizer, + student_engine=student_engine, + student_model_path=student_model_path, + arch=arch, + ) + + raise ValueError(f"Unknown rollout engine {engine_name!r}; choose from 'hybrid_engine' | 'vllm'") diff --git a/examples/opsd/opsd/rollout/base.py b/examples/opsd/opsd/rollout/base.py new file mode 100644 index 000000000000..62789d25c1cd --- /dev/null +++ b/examples/opsd/opsd/rollout/base.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Rollout engine interface. + +The trainer talks to its rollout engine through three small dataclasses +(``RolloutRequest`` in / ``RolloutBatch`` out / ``SamplingConfig``) and one +ABC. This keeps the engine-specific concerns (hybrid-engine vs vLLM, weight +sync, process topology) out of the trainer loop, so swapping engines is a +one-line config change. + +Concrete engines live in sibling modules: + * :mod:`opsd.rollout.hybrid_engine` — DeepSpeed hybrid engine + * :mod:`opsd.rollout.vllm` — vLLM on a disjoint GPU group +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import torch + + +@dataclass +class SamplingConfig: + """Sampling knobs that the trainer passes to ``generate`` each step.""" + + max_new_tokens: int + temperature: float = 1.0 + top_p: float = 1.0 + # ``top_k <= 0`` means "no top-k truncation". + top_k: int = -1 + # Number of samples per prompt. >1 expands the effective batch. + n_samples_per_prompt: int = 1 + + +@dataclass +class RolloutRequest: + """Input to ``RolloutEngine.generate``. + + Prompts arrive *left-padded* (i.e. real tokens at the right edge) so that + causal generation appends naturally after them. + """ + + prompt_ids: torch.Tensor # [B, T_p] left-padded with pad_token_id + prompt_attention_mask: torch.Tensor # [B, T_p], 1 on real prompt tokens + + def __post_init__(self) -> None: + if self.prompt_ids.dim() != 2: + raise ValueError(f"prompt_ids must be 2-D [B, T_p]; got {tuple(self.prompt_ids.shape)}") + if self.prompt_attention_mask.shape != self.prompt_ids.shape: + raise ValueError(f"prompt_attention_mask shape {tuple(self.prompt_attention_mask.shape)} " + f"does not match prompt_ids {tuple(self.prompt_ids.shape)}") + + +@dataclass +class RolloutBatch: + """Output of ``RolloutEngine.generate``. + + ``input_ids`` holds the *concatenation* of (left-padded) prompt and + response, right-padded to the longest sequence in the batch. + ``response_start_idx[i]`` is the column index at which the response + begins, so positions ``>= response_start_idx[i]`` (intersected with + ``attention_mask``) are response tokens. + + Note: with the standard *left-padded* prompt convention, every sample's + response starts at the same column (= the prompt section length), but the + field is kept per-sample so that mixed-batch backends (e.g. vLLM, which + may strip its own padding) can still report a meaningful boundary. + """ + + input_ids: torch.Tensor # [B', T_p + T_r]; B' = B * n_samples_per_prompt + attention_mask: torch.Tensor # [B', T_p + T_r] + response_start_idx: torch.Tensor # [B'] int + + def __post_init__(self) -> None: + if self.input_ids.dim() != 2: + raise ValueError(f"input_ids must be 2-D; got {tuple(self.input_ids.shape)}") + if self.attention_mask.shape != self.input_ids.shape: + raise ValueError(f"attention_mask shape {tuple(self.attention_mask.shape)} does not " + f"match input_ids {tuple(self.input_ids.shape)}") + B = self.input_ids.shape[0] + if self.response_start_idx.shape != (B, ): + raise ValueError(f"response_start_idx must be 1-D of length {B}; got " + f"{tuple(self.response_start_idx.shape)}") + + @property + def batch_size(self) -> int: + return int(self.input_ids.shape[0]) + + @property + def seq_len(self) -> int: + return int(self.input_ids.shape[1]) + + +class RolloutEngine(ABC): + """Abstract base for student rollout engines.""" + + name: str = "base" + + @abstractmethod + def generate(self, request: RolloutRequest, sampling: SamplingConfig) -> RolloutBatch: + """Run the student's generate, return prompt+response in one tensor.""" + + @abstractmethod + def sync_weights_from_student(self, step: int) -> None: + """Push the student's current weights into the rollout backend. + + No-op for :class:`HybridEngineRollout` (the engine reads weights live + from the same process). Meaningful for :class:`VLLMRollout`, which + holds its own copy and must be refreshed periodically. + """ + + def shutdown(self) -> None: + """Release any backend resources (vLLM workers, NCCL groups, ...). + Default no-op.""" + return None diff --git a/examples/opsd/opsd/rollout/hybrid_engine.py b/examples/opsd/opsd/rollout/hybrid_engine.py new file mode 100644 index 000000000000..7e7ced928655 --- /dev/null +++ b/examples/opsd/opsd/rollout/hybrid_engine.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Rollout backed by DeepSpeed's hybrid engine, with a ZeRO-3 fallback. + +For architectures in DeepSpeed's inference-container policy list +(GPT2 / GPT-NeoX / OPT / BLOOM / LLAMA / LLAMA2 / InternLM as of 0.15) the +hybrid engine gives accelerated decode by swapping in optimized inference +kernels for the duration of the rollout. For everything else (Qwen2 / Qwen3 +/ any model without a policy), no inference container is created and +``DeepSpeedHybridEngine.generate`` would AttributeError on its unbound +``_generate`` slot — so we detect that case at construction time and fall +back to a manual path that just gathers ZeRO-3 partitions and calls the +HuggingFace model's ``generate`` directly. Correct, just slower than the +accelerated path. +""" + +import torch + +from opsd.config import RolloutConfig +from opsd.rollout.base import RolloutBatch, RolloutEngine, RolloutRequest, SamplingConfig + + +def _hybrid_engine_has_accel(engine) -> bool: + # The accelerated path is only wired up when at least one inference + # container was populated for the model's layers. ``_inference_containers`` + # and ``_generate`` are both internal but they are the only two reliable + # signals across DeepSpeed 0.14–0.19; ``_generate`` is bound exactly when + # the container list is non-empty. + return getattr(engine, "_generate", None) is not None + + +class HybridEngineRollout(RolloutEngine): + name = "hybrid_engine" + + def __init__(self, student_engine, tokenizer, cfg: RolloutConfig): + if cfg.engine != "hybrid_engine": + raise ValueError(f"RolloutConfig.engine must be 'hybrid_engine'; got {cfg.engine!r}") + self.engine = student_engine + self.tokenizer = tokenizer + self.cfg = cfg + self._has_accel = _hybrid_engine_has_accel(student_engine) + + @torch.no_grad() + def generate(self, request: RolloutRequest, sampling: SamplingConfig) -> RolloutBatch: + pad_id = self.tokenizer.pad_token_id + if pad_id is None: + # Many decoder-only tokenizers (Llama, Qwen) ship without a pad + # token. Fall back to eos so that generate doesn't crash on the + # left-padded prompts. + pad_id = self.tokenizer.eos_token_id + + gen_kwargs = dict( + input_ids=request.prompt_ids, + attention_mask=request.prompt_attention_mask, + max_new_tokens=sampling.max_new_tokens, + do_sample=sampling.temperature > 0.0, + temperature=max(sampling.temperature, 1e-8), + top_p=sampling.top_p, + top_k=sampling.top_k if sampling.top_k > 0 else 0, + num_return_sequences=sampling.n_samples_per_prompt, + pad_token_id=pad_id, + eos_token_id=self.tokenizer.eos_token_id, + ) + + # Hybrid engine expects training mode toggled off so the inference + # containers take over. eval() is cheap (boolean flip + module walk). + self.engine.eval() + try: + if self._has_accel: + seqs = self.engine.generate(**gen_kwargs) + else: + seqs = self._fallback_generate(**gen_kwargs) + finally: + self.engine.train() + + # ``seqs`` is [B * n, T_p + T_r_actual], left-padded prompt + response. + # With left-padded prompts every sample's response starts at column T_p. + B = request.prompt_ids.shape[0] + n = sampling.n_samples_per_prompt + T_p = request.prompt_ids.shape[1] + if seqs.shape[0] != B * n: + raise RuntimeError(f"generate returned batch {seqs.shape[0]}, expected {B * n}") + + response_start_idx = torch.full((B * n, ), T_p, dtype=torch.long, device=seqs.device) + + # Response positions are anything past the prompt that is also not pad. + attention_mask = (seqs != pad_id).to(request.prompt_attention_mask.dtype) + # Keep the prompt portion of the mask aligned with what the caller + # passed in (a prompt token equal to pad_id should still be attended); + # for typical left-padded prompts the overlap is identical. + prompt_mask_expanded = request.prompt_attention_mask.repeat_interleave(n, dim=0) + attention_mask[:, :T_p] = prompt_mask_expanded + + return RolloutBatch(input_ids=seqs, attention_mask=attention_mask, response_start_idx=response_start_idx) + + def sync_weights_from_student(self, step: int) -> None: # noqa: ARG002 + # The hybrid engine reads the student's live weights every generate + # call, so there is nothing to sync. + return None + + @torch.no_grad() + def _fallback_generate(self, **gen_kwargs) -> torch.Tensor: + """Manual ZeRO-3 generate for architectures the hybrid engine doesn't + have an inference policy for. + + Walks every parameter into a ``GatheredParameters`` context so the full + weight is materialized on each rank for the duration of generation, + then calls the underlying HF model's ``generate``. Re-partitions on + exit. This is correct but does not get the hybrid engine's optimized + kernels — expect ~3-5x slower decode than the accelerated path. + """ + from deepspeed.runtime.zero import GatheredParameters + + module = self.engine.module + all_params = list(module.parameters()) + with GatheredParameters(all_params): + return module.generate(**gen_kwargs) diff --git a/examples/opsd/opsd/trainer.py b/examples/opsd/opsd/trainer.py new file mode 100644 index 000000000000..315b5145ef7a --- /dev/null +++ b/examples/opsd/opsd/trainer.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""On-policy distillation training loop. + +Each step is three phases: + + 0. **Rollout.** The student generates responses for the batch's prompts + (via the configured :class:`~opsd.rollout.RolloutEngine` — hybrid engine + or vLLM). + 1. **Teacher.** The frozen teacher runs a forward over prompt+response. The + full logit tensor is parked on the host via + :class:`~opsd.teacher.TeacherLogitCache` so teacher GPU buffers can be + released before the student backward. + 2. **Student.** The student runs forward+backward on prompt+response. The + loss is the per-token divergence to the teacher, streamed from the + host-resident cache one sequence chunk at a time + (:func:`~opsd.losses.streamed_distillation_loss`), so the full + ``[B, T, V]`` teacher tensor never co-resides with the student logits on + the training device. + +The trainer itself contains no DeepSpeed-specific control flow beyond the +``backward`` / ``step`` calls on the student engine; backend choice (ZeRO +stage, offload, hybrid engine) is owned entirely by the DeepSpeed JSON config. +""" + +import os +import time +from typing import Any + +import torch +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator + +from opsd.config import OPSDConfig +from opsd.losses import streamed_distillation_loss +from opsd.rollout import RolloutEngine, RolloutRequest, SamplingConfig +from opsd.utils import build_response_mask + + +def _is_rank_zero() -> bool: + return (not dist.is_initialized()) or dist.get_rank() == 0 + + +class OPSDTrainer: + + def __init__( + self, + cfg: OPSDConfig, + student_engine: Any, + teacher: Any, + tokenizer: Any, + rollout: RolloutEngine, + dataloader: Any, + ): + self.cfg = cfg + self.student_engine = student_engine + self.teacher = teacher + self.tokenizer = tokenizer + self.rollout = rollout + self.dataloader = dataloader + + self.device = get_accelerator().current_device_name() + self.step = 0 + + # ------------------------------------------------------------------ + # Driver + # ------------------------------------------------------------------ + + def train(self) -> None: + max_steps = self.cfg.training.max_steps + for epoch in range(self.cfg.training.num_train_epochs): + for batch in self.dataloader: + if max_steps > 0 and self.step >= max_steps: + return + metrics = self._train_step(batch) + self._maybe_log(metrics) + self._maybe_save() + self.step += 1 + if max_steps > 0 and self.step >= max_steps: + return + + # ------------------------------------------------------------------ + # One step + # ------------------------------------------------------------------ + + def _train_step(self, batch) -> dict: + t_start = time.time() + + prompt_ids = batch["prompt_ids"].to(self.device, non_blocking=True) + prompt_attn = batch["prompt_attention_mask"].to(self.device, non_blocking=True) + + # Push student weights into the rollout backend if it's time to. + # No-op for the hybrid engine; meaningful for vLLM. + if self.step % self.cfg.rollout.weight_sync_interval == 0: + self.rollout.sync_weights_from_student(self.step) + + # --- Phase 0: rollout (student generates responses) --------------- + sampling = SamplingConfig( + max_new_tokens=self.cfg.rollout.max_response_length, + temperature=self.cfg.rollout.temperature, + top_p=self.cfg.rollout.top_p, + top_k=self.cfg.rollout.top_k, + n_samples_per_prompt=self.cfg.rollout.n_samples_per_prompt, + ) + roll = self.rollout.generate( + RolloutRequest(prompt_ids=prompt_ids, prompt_attention_mask=prompt_attn), + sampling, + ) + input_ids = roll.input_ids.to(self.device, non_blocking=True) + attention_mask = roll.attention_mask.to(self.device, non_blocking=True) + response_start_idx = roll.response_start_idx.to(self.device, non_blocking=True) + response_mask = build_response_mask(response_start_idx, attention_mask) + t_rollout = time.time() - t_start + + # --- Phase 1: teacher forward → host-cached logits ---------------- + t1 = time.time() + teacher_cache = self.teacher.forward_to_cache(input_ids, attention_mask) + t_teacher = time.time() - t1 + + # --- Phase 2: student forward + streamed KL + backward ------------ + t2 = time.time() + self.student_engine.train() + outputs = self.student_engine(input_ids=input_ids, attention_mask=attention_mask) + student_logits = outputs.logits # [B, T, V] + + # Shift for next-token prediction: logits at position t predict token + # at t+1, so the loss aligns student_logits[:, :-1] with the position + # t+1 entries of the response mask. + student_logits_shifted = student_logits[:, :-1, :] + mask_shifted = response_mask[:, 1:].contiguous() + + def _fetch(start: int, end: int) -> torch.Tensor: + # The cache holds *unshifted* teacher logits; for the next-token + # objective we ask the cache for positions [start, end) of the + # shifted teacher, which is positions [start, end) of the original + # since we already lopped off the final column in the student. + return teacher_cache.chunk_to_device(start, + end, + device=student_logits_shifted.device, + dtype=student_logits_shifted.dtype) + + loss = streamed_distillation_loss( + student_logits=student_logits_shifted, + teacher_chunk_fetcher=_fetch, + response_mask=mask_shifted, + loss_type=self.cfg.distillation.loss_type, + temperature=self.cfg.distillation.temperature, + chunk_size=self.cfg.distillation.chunk_size, + ) + + self.student_engine.backward(loss) + self.student_engine.step() + + teacher_cache.free() + t_student = time.time() - t2 + + # Reduce loss across ranks for a clean log line. + loss_for_log = loss.detach().clone() + if dist.is_initialized(): + dist.all_reduce(loss_for_log) + loss_for_log /= dist.get_world_size() + + return { + "loss": float(loss_for_log.item()), + "rollout_s": t_rollout, + "teacher_s": t_teacher, + "student_s": t_student, + "step_s": time.time() - t_start, + "response_tokens": int(mask_shifted.sum().item()), + } + + # ------------------------------------------------------------------ + # Logging / checkpointing + # ------------------------------------------------------------------ + + def _maybe_log(self, metrics: dict) -> None: + if self.step % self.cfg.training.logging_steps != 0: + return + if not _is_rank_zero(): + return + print(f"[opsd][step {self.step}] loss={metrics['loss']:.4f} " + f"rollout={metrics['rollout_s']:.2f}s teacher={metrics['teacher_s']:.2f}s " + f"student={metrics['student_s']:.2f}s step={metrics['step_s']:.2f}s " + f"resp_tok={metrics['response_tokens']}") + + def _maybe_save(self) -> None: + if self.step == 0: + return + if self.step % self.cfg.training.save_steps != 0: + return + tag = f"step_{self.step}" + os.makedirs(self.cfg.training.save_dir, exist_ok=True) + self.student_engine.save_checkpoint(self.cfg.training.save_dir, tag=tag) + if _is_rank_zero(): + print(f"[opsd] saved checkpoint to {self.cfg.training.save_dir}/{tag}") diff --git a/examples/opsd/scripts/train_opsd_hybrid.sh b/examples/opsd/scripts/train_opsd_hybrid.sh new file mode 100644 index 000000000000..69e3bdc68a7b --- /dev/null +++ b/examples/opsd/scripts/train_opsd_hybrid.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +# +# Launch OPSD training with the DeepSpeed hybrid-engine rollout (no vLLM). +# Assumes you're cd'd into examples/opsd/. +set -euo pipefail + +CONFIG="${1:-configs/opsd_hybrid_engine.json}" +NUM_GPUS="${NUM_GPUS:-8}" + +deepspeed --num_gpus "${NUM_GPUS}" main.py --config "${CONFIG}" diff --git a/examples/opsd/tests/test_rollout_interface.py b/examples/opsd/tests/test_rollout_interface.py new file mode 100644 index 000000000000..7c6fd0545443 --- /dev/null +++ b/examples/opsd/tests/test_rollout_interface.py @@ -0,0 +1,156 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Conformance tests for the RolloutEngine interface. + +Validates the dataclass invariants and exercises the interface against a +``FakeRollout`` so the contract is testable without GPUs or a model. The real +backends are tested manually with a launched training script (see README). +""" + +import pytest +import torch + +from opsd.rollout import RolloutBatch, RolloutEngine, RolloutRequest, SamplingConfig +from opsd.utils import build_response_mask + +# --- dataclass invariants --------------------------------------------------- + + +def test_rollout_request_validates_shapes(): + with pytest.raises(ValueError, match="must be 2-D"): + RolloutRequest(prompt_ids=torch.zeros(8), prompt_attention_mask=torch.ones(8)) + with pytest.raises(ValueError, match="does not match"): + RolloutRequest(prompt_ids=torch.zeros(2, 4, dtype=torch.long), prompt_attention_mask=torch.ones(2, 5)) + + +def test_rollout_batch_validates_shapes(): + with pytest.raises(ValueError, match="must be 2-D"): + RolloutBatch(input_ids=torch.zeros(8, dtype=torch.long), + attention_mask=torch.ones(8), + response_start_idx=torch.tensor([4])) + with pytest.raises(ValueError, match="does not match"): + RolloutBatch(input_ids=torch.zeros(2, 4, dtype=torch.long), + attention_mask=torch.ones(2, 5), + response_start_idx=torch.tensor([4, 4])) + with pytest.raises(ValueError, match="1-D of length"): + RolloutBatch(input_ids=torch.zeros(2, 4, dtype=torch.long), + attention_mask=torch.ones(2, 4), + response_start_idx=torch.tensor([4])) + + +def test_rollout_batch_accessors(): + batch = RolloutBatch( + input_ids=torch.zeros(3, 12, dtype=torch.long), + attention_mask=torch.ones(3, 12), + response_start_idx=torch.tensor([4, 5, 6]), + ) + assert batch.batch_size == 3 + assert batch.seq_len == 12 + + +def test_sampling_config_defaults(): + cfg = SamplingConfig(max_new_tokens=32) + assert cfg.temperature == 1.0 + assert cfg.top_p == 1.0 + assert cfg.top_k == -1 + assert cfg.n_samples_per_prompt == 1 + + +# --- interface conformance via FakeRollout --------------------------------- + + +class FakeRollout(RolloutEngine): + """Deterministic stub: appends ``[42] * max_new_tokens`` to each prompt.""" + + name = "fake" + + def __init__(self, response_token: int = 42): + self.response_token = response_token + self.sync_calls: list = [] + + def generate(self, request: RolloutRequest, sampling: SamplingConfig) -> RolloutBatch: + B, T_p = request.prompt_ids.shape + n = sampling.n_samples_per_prompt + T_r = sampling.max_new_tokens + + prompts_expanded = request.prompt_ids.repeat_interleave(n, dim=0) + attn_p_expanded = request.prompt_attention_mask.repeat_interleave(n, dim=0) + response = torch.full((B * n, T_r), self.response_token, dtype=request.prompt_ids.dtype) + response_attn = torch.ones((B * n, T_r), dtype=attn_p_expanded.dtype) + + input_ids = torch.cat([prompts_expanded, response], dim=1) + attention_mask = torch.cat([attn_p_expanded, response_attn], dim=1) + response_start_idx = torch.full((B * n, ), T_p, dtype=torch.long) + return RolloutBatch(input_ids=input_ids, attention_mask=attention_mask, response_start_idx=response_start_idx) + + def sync_weights_from_student(self, step: int) -> None: + self.sync_calls.append(step) + + +def test_fake_rollout_shape_basic(): + fake = FakeRollout() + req = RolloutRequest(prompt_ids=torch.tensor([[1, 2, 3], [4, 5, 6]]), + prompt_attention_mask=torch.ones(2, 3, dtype=torch.long)) + out = fake.generate(req, SamplingConfig(max_new_tokens=4)) + assert out.input_ids.shape == (2, 7) + assert out.attention_mask.shape == (2, 7) + # With left-padded (fully real here) prompts of width 3, response begins + # at column 3 for every sample. + assert out.response_start_idx.tolist() == [3, 3] + + +def test_fake_rollout_with_n_samples(): + fake = FakeRollout() + req = RolloutRequest(prompt_ids=torch.tensor([[1, 2], [3, 4]]), + prompt_attention_mask=torch.ones(2, 2, dtype=torch.long)) + out = fake.generate(req, SamplingConfig(max_new_tokens=3, n_samples_per_prompt=4)) + assert out.input_ids.shape == (8, 5) + assert out.response_start_idx.tolist() == [2] * 8 + + +def test_fake_rollout_left_padded_prompts(): + fake = FakeRollout() + # left-padded prompts: prompt B has only the last 2 positions real, but + # response_start_idx still equals the prompt column width T_p. + prompt_ids = torch.tensor([[1, 2, 3, 4], [0, 0, 5, 6]]) + attn = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.long) + req = RolloutRequest(prompt_ids=prompt_ids, prompt_attention_mask=attn) + out = fake.generate(req, SamplingConfig(max_new_tokens=2)) + assert out.response_start_idx.tolist() == [4, 4] + + +def test_response_mask_from_rollout_output_matches_helper(): + fake = FakeRollout() + prompt_ids = torch.tensor([[1, 2, 3], [0, 4, 5]]) + attn = torch.tensor([[1, 1, 1], [0, 1, 1]], dtype=torch.long) + out = fake.generate(RolloutRequest(prompt_ids, attn), SamplingConfig(max_new_tokens=3)) + mask = build_response_mask(out.response_start_idx, out.attention_mask) + # Both samples: response starts at column 3 (T_p), and all post-prompt + # positions are attended (FakeRollout produces no padding in the response). + assert mask[0].tolist() == [0, 0, 0, 1, 1, 1] + assert mask[1].tolist() == [0, 0, 0, 1, 1, 1] + + +def test_sync_records_steps(): + fake = FakeRollout() + fake.sync_weights_from_student(0) + fake.sync_weights_from_student(5) + assert fake.sync_calls == [0, 5] + + +def test_engine_factory_unknown_raises(): + from opsd.config import RolloutConfig + from opsd.rollout import build_rollout + + with pytest.raises(ValueError, match="Unknown rollout engine"): + build_rollout(RolloutConfig(engine="totally_made_up")) + + +def test_engine_factory_hybrid_requires_student_engine(): + from opsd.config import RolloutConfig + from opsd.rollout import build_rollout + + with pytest.raises(ValueError, match="needs both"): + build_rollout(RolloutConfig(engine="hybrid_engine")) From 6384396b48f29732115a3b931bb4e71c63d6d827 Mon Sep 17 00:00:00 2001 From: Zhipeng Wang Date: Tue, 26 May 2026 07:15:53 +0000 Subject: [PATCH 4/4] Add OPSD vLLM rollout scaffold, Qwen2/Qwen3 weight bridges, and README Lands the second-stage rollout path, weight-sync infrastructure, and the example app's README. Includes: * VLLMRollout that constructs vllm.LLM on training rank 0 and broadcasts generated token ids to peer ranks, with disjoint-GPU (subprocess) and shared (in-process) topology paths. Weight sync gathers ZeRO-3 params cooperatively then pushes to vLLM via LLM.collective_rpc("load_weights"). * WeightBridge ABC with COLUMN / ROW / VOCAB / REPLICATED parallel kinds and an even-slice per-rank slicer; Qwen2WeightBridge with the full per-parameter table for Qwen2 / Qwen2.5; Qwen3WeightBridge adding the per-head q_norm / k_norm tensors as REPLICATED. * vLLM-side prompt+response stitching factored into stitch_rollout() so its index math is unit-testable without a live vLLM. * CPU-only tests: tests/test_weight_bridge.py covers parallel-kind dispatch, per-rank shape/gather round-trips across tp_size in {1,2,4}, indivisibility / invalid-rank guards, and the registry; tests/test_vllm_stitch.py covers prompt/response stitching for the common shapes including variable response lengths and left-padded prompts. * configs + launch scripts for both production and smoke vLLM runs. **Known blocker called out in README and module docstring:** vLLM's worker init calls new_group() on the global process group, which deadlocks when launched under the standard `deepspeed --num_gpus N` launcher (rank 0 calls vLLM, other ranks never participate in vLLM's collective). The documented fix is the TRL/OpenRLHF separate-server pattern; this PR lands the scaffolding so that work can begin against a green codebase. Signed-off-by: Zhipeng Wang --- examples/opsd/README.md | 232 +++++++++++++ examples/opsd/configs/opsd_vllm_disjoint.json | 54 +++ examples/opsd/configs/smoke_vllm.json | 55 +++ examples/opsd/opsd/rollout/vllm.py | 314 ++++++++++++++++++ examples/opsd/opsd/weight_bridge/__init__.py | 32 ++ examples/opsd/opsd/weight_bridge/base.py | 109 ++++++ examples/opsd/opsd/weight_bridge/qwen2.py | 84 +++++ examples/opsd/opsd/weight_bridge/qwen3.py | 37 +++ examples/opsd/scripts/train_opsd_vllm.sh | 19 ++ examples/opsd/tests/test_vllm_stitch.py | 97 ++++++ examples/opsd/tests/test_weight_bridge.py | 259 +++++++++++++++ 11 files changed, 1292 insertions(+) create mode 100644 examples/opsd/README.md create mode 100644 examples/opsd/configs/opsd_vllm_disjoint.json create mode 100644 examples/opsd/configs/smoke_vllm.json create mode 100644 examples/opsd/opsd/rollout/vllm.py create mode 100644 examples/opsd/opsd/weight_bridge/__init__.py create mode 100644 examples/opsd/opsd/weight_bridge/base.py create mode 100644 examples/opsd/opsd/weight_bridge/qwen2.py create mode 100644 examples/opsd/opsd/weight_bridge/qwen3.py create mode 100644 examples/opsd/scripts/train_opsd_vllm.sh create mode 100644 examples/opsd/tests/test_vllm_stitch.py create mode 100644 examples/opsd/tests/test_weight_bridge.py diff --git a/examples/opsd/README.md b/examples/opsd/README.md new file mode 100644 index 000000000000..9eab8485a707 --- /dev/null +++ b/examples/opsd/README.md @@ -0,0 +1,232 @@ +# On-Policy Distillation (OPSD) on DeepSpeed + +A DeepSpeed-native port of [HJSang/OPSD_OnPolicyDistillation](https://github.com/HJSang/OPSD_OnPolicyDistillation), +removing the verl dependency and building directly on DeepSpeed primitives +(ZeRO-3, hybrid engine, `deepspeed.initialize`). + +On-policy distillation trains a small **student** model to imitate a large +frozen **teacher** on the student's *own* generated rollouts. Each training +step has three phases: + +``` +┌────────────┐ prompts ┌──────────────────┐ prompt+response ┌────────────┐ +│ Dataloader │ ──────────▶ │ Student rollout │ ──────────────────▶ │ Teacher │ +└────────────┘ │ (hybrid / vLLM) │ │ forward │ + └──────────────────┘ └─────┬──────┘ + │ logits → CPU cache + ▼ + ┌─────────────────────┐ + │ Student forward + │ + │ streamed KL / JSD + │ + │ backward / step │ + └─────────────────────┘ +``` + +Loss = per-token divergence (`forward_kl` | `reverse_kl` | `jsd`) between +student and teacher distributions on the student's generated tokens, chunked +over the sequence axis so the full `[B, T, V]` teacher tensor never +co-resides with the student logits on the training device. + +## Layout + +``` +examples/opsd/ +├── main.py # entry point (deepspeed launcher) +├── opsd/ +│ ├── config.py # OPSDConfig dataclass + JSON loader +│ ├── losses.py # chunked / streamed KL & JSD +│ ├── teacher.py # frozen teacher + CPU logit cache +│ ├── trainer.py # three-phase training loop +│ ├── data.py # JSONL prompt dataset + left-pad collator +│ ├── utils.py # response-mask + shift helpers +│ ├── rollout/ +│ │ ├── base.py # RolloutEngine ABC, request/batch dataclasses +│ │ ├── hybrid_engine.py # DeepSpeed hybrid-engine rollout +│ │ └── vllm.py # vLLM rollout on disjoint GPUs +│ └── weight_bridge/ +│ ├── base.py # ParallelKind + per-rank slicer +│ ├── qwen2.py # Qwen2 / Qwen2.5 TP mapping +│ └── qwen3.py # Qwen3 dense (adds q_norm/k_norm) +├── configs/ +│ ├── ds_zero3.json # base DeepSpeed ZeRO-3 + hybrid engine +│ ├── opsd_hybrid_engine.json # production-ish hybrid-engine OPSD config +│ ├── opsd_vllm_disjoint.json # vLLM rollout on a disjoint GPU group +│ ├── smoke_hybrid.json # 5-step smoke test with Qwen2.5-0.5B / 1.5B +│ ├── smoke_vllm.json # same but with vLLM rollout +│ └── smoke_ds_zero3.json # ZeRO-3 config tuned for smoke runs +├── scripts/ +│ ├── train_opsd_hybrid.sh # launch hybrid-engine training +│ └── train_opsd_vllm.sh # launch vLLM training +└── tests/ # CPU-only unit tests (run with pytest) +``` + +## Quick start + +### Install + +``` +pip install deepspeed transformers datasets accelerate +# Optional, only for the vLLM rollout backend: +pip install 'vllm>=0.6.4' +``` + +### Hybrid-engine training (single-node, no vLLM) + +``` +cd examples/opsd +NUM_GPUS=8 bash scripts/train_opsd_hybrid.sh configs/opsd_hybrid_engine.json +``` + +The hybrid engine path lives entirely within DeepSpeed: the student engine +both trains and generates, sharing weights without a copy step. Easiest to +get running; slower generation than vLLM. + +### vLLM training (disjoint GPU group) + +``` +cd examples/opsd +# Train on GPUs 0..5, run vLLM on 6,7 (matches default config) +NUM_TRAIN_GPUS=6 INCLUDE_GPUS=0,1,2,3,4,5 \ + bash scripts/train_opsd_vllm.sh configs/opsd_vllm_disjoint.json +``` + +vLLM gets dedicated GPUs (`rollout.gpus` in the config). Training rank 0 +constructs the `LLM` handle; other training ranks receive generated token +ids via NCCL broadcast. + +### Smoke tests (5 steps, small models) + +The `smoke_*.json` configs run on 2 GPUs in a few minutes with Qwen2.5-0.5B +(student) and Qwen2.5-1.5B (teacher), so the full pipeline can be validated +end-to-end before scaling up. + +``` +cd examples/opsd +deepspeed --num_gpus 2 main.py --config configs/smoke_hybrid.json +# For vLLM (uses GPUs 0,1 for training and 2,3 for vLLM): +NUM_TRAIN_GPUS=2 INCLUDE_GPUS=0,1 deepspeed --num_gpus 2 --include localhost:0,1 \ + main.py --config configs/smoke_vllm.json +``` + +## Unit tests + +The CPU-runnable test suite exercises the loss math, teacher caching, rollout +contract, weight-bridge TP slicing, and vLLM stitch logic. Run with: + +``` +cd examples/opsd +python -m pytest tests/ -v +``` + +## Configuration + +`OPSDConfig` is a plain dataclass loaded from JSON (no Hydra). The schema: + +```json +{ + "student": { "model_name_or_path": "...", "dtype": "bfloat16", "arch": "qwen2" }, + "teacher": { "model_name_or_path": "...", "dtype": "bfloat16", "offload_to_cpu": true }, + "rollout": { "engine": "hybrid_engine | vllm", ... }, + "distillation": { "loss_type": "reverse_kl", "temperature": 1.0, "chunk_size": 512 }, + "training": { "train_batch_size": 8, "learning_rate": 1e-6, ... }, + "data": { "path": "data/prompts.jsonl", "prompt_field": "prompt" }, + "deepspeed_config": "configs/ds_zero3.json" +} +``` + +See `configs/opsd_hybrid_engine.json` and `configs/opsd_vllm_disjoint.json` +for fully-populated examples. + +## Adding a new model architecture + +To support a model the bridge doesn't recognise yet: + +1. Add `opsd/weight_bridge/.py` subclassing `Qwen2WeightBridge` (or + `WeightBridge` directly) and override `parallel_kind` / `_extra_layer_kind` + for any parameters not in Qwen2's table. +2. Register the new arch in `opsd/weight_bridge/__init__.py::get_bridge`. +3. Add a test in `tests/test_weight_bridge.py` covering parallel-kind dispatch + and a slice-then-gather round trip for one layer of realistic shapes. + +## Design notes + +* **Why CPU-cache the teacher logits?** Holding both student and teacher + `[B, T, V]` tensors on GPU at once doubles memory pressure. Staging the + teacher to host between the teacher forward and the student backward halves + the worst-case GPU footprint of the loss path. The streamed loss + (`losses.streamed_distillation_loss`) pulls teacher chunks back to GPU + one sequence slice at a time so the full tensor never re-materialises. + +* **Why an abstract `RolloutEngine`?** The hybrid-engine and vLLM backends + have very different lifecycles (hybrid engine reads student weights live; + vLLM holds its own copy and must be synced) but the trainer should not + care. The ABC keeps the trainer engine-agnostic so additional backends + (e.g. a future colocated-vLLM-with-`sleep_mode`) drop in without touching + the loop. + +* **vLLM topology = disjoint, not colocated (v1).** The disjoint topology is + simpler to debug — failures in vLLM don't take down training and vice + versa. A colocated topology using vLLM 0.6.4+'s `sleep_mode` is planned as + a follow-up. + +* **Weight bridge does not pre-fuse QKV / gate-up.** vLLM's per-model loader + already knows how to fuse these from the standard HuggingFace layout, so + the bridge only handles per-rank slicing. + +## vLLM status + +The vLLM rollout (`opsd/rollout/vllm.py`) is **written and unit-tested but +not yet usable under the DeepSpeed launcher**. During live validation on +4× H200 we hit a blocking issue: + +> vLLM's worker init calls `new_group(...)` on the global process group as +> a collective. Under `deepspeed --num_gpus N`, the world is all `N` +> training ranks but only rank 0 calls into vLLM, so the constructor hangs +> waiting on the other ranks. Reproduced with vllm 0.6.6 + deepspeed 0.15.4 + +> torch 2.5.1. Standalone vLLM (world size 1) works in seconds. + +The fix requires running vLLM in a **separate top-level Python process** +with its own world, accessed over HTTP/RPC from the trainer — the pattern +used by TRL and OpenRLHF. That's a larger refactor than fits in this PR; +the current `VLLMRollout` will be the basis for it once landed. + +What's verified for the vLLM path today: +* `tests/test_vllm_stitch.py` — prompt + response stitching (CPU unit test) +* `tests/test_weight_bridge.py` — TP-slice math for Qwen2 / Qwen3 (CPU) +* `vllm.LLM` itself runs fine standalone on Qwen2.5-0.5B (validated) + +What's **not** verified: +* End-to-end training loop with `rollout.engine = "vllm"` in `OPSDConfig` +* `LLM.collective_rpc("load_weights", ...)` weight sync at training time + +The hybrid-engine path (`rollout.engine = "hybrid_engine"`) is validated +end-to-end on the same hardware. + +## Other known limitations (v1) + +* **vLLM weight sync (when it works) goes through pickle** — + `LLM.collective_rpc("load_weights", args=((name, tensor_on_cpu),))`. + Expect several seconds per sync on a 7B model. A faster v2 would broadcast + tensors via NCCL on a shared trainer↔vLLM process group — see verl's + `bucketed_weight_transfer.py` for a reference design. +* **vLLM `tensor_parallel_size > 1` is untested.** The weight bridge's + slicing math is unit-tested but no live run exists. +* **Reward-weighted distillation** (OPSD's `opd.reward_beta` knob) is not + ported. Easy to add: scale `per_tok` by a reward weight in the loss path. +* **GRPO and other on-policy RL recipes** are out of scope. The + `RolloutEngine` / `WeightBridge` abstractions are reusable, but a GRPO + trainer would add its own advantage / KL-to-reference logic on top. +* **Qwen3-MoE** is not covered. Add `weight_bridge/qwen3_moe.py` when needed. +* **Hybrid engine on Qwen-family models uses a ZeRO-3 fallback** (no + hybrid-engine inference acceleration), since DeepSpeed's inference policy + list only covers GPT2/GPT-NeoX/OPT/BLOOM/LLAMA/LLAMA2/InternLM as of 0.15. + The fallback gathers params via `GatheredParameters` and calls the HF + model's `generate` directly — correct, just ~3-5x slower than the + accelerated path. + +## References + +* OPSD reference repo: +* DeepSpeed hybrid engine: `deepspeed/runtime/hybrid_engine.py` +* verl rollout / weight-sync design (used as a cross-check): + diff --git a/examples/opsd/configs/opsd_vllm_disjoint.json b/examples/opsd/configs/opsd_vllm_disjoint.json new file mode 100644 index 000000000000..9668b3702981 --- /dev/null +++ b/examples/opsd/configs/opsd_vllm_disjoint.json @@ -0,0 +1,54 @@ +{ + "student": { + "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "arch": "qwen2" + }, + "teacher": { + "model_name_or_path": "Qwen/Qwen2.5-Math-7B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "offload_to_cpu": true + }, + "rollout": { + "engine": "vllm", + "max_prompt_length": 1024, + "max_response_length": 1024, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "n_samples_per_prompt": 1, + "gpus": [6, 7], + "tensor_parallel_size": 2, + "gpu_memory_utilization": 0.85, + "vllm_dtype": "bfloat16", + "weight_sync_interval": 4, + "vllm_min_version": "0.6.4" + }, + "distillation": { + "loss_type": "reverse_kl", + "temperature": 1.0, + "chunk_size": 512 + }, + "training": { + "train_batch_size": 6, + "micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-6, + "weight_decay": 0.0, + "num_train_epochs": 1, + "max_steps": -1, + "warmup_steps": 0, + "save_steps": 500, + "logging_steps": 10, + "save_dir": "./opsd_ckpt_vllm", + "seed": 42 + }, + "data": { + "path": "data/prompts.jsonl", + "prompt_field": "prompt", + "shuffle": true + }, + "deepspeed_config": "configs/ds_zero3.json" +} diff --git a/examples/opsd/configs/smoke_vllm.json b/examples/opsd/configs/smoke_vllm.json new file mode 100644 index 000000000000..8daf31537df2 --- /dev/null +++ b/examples/opsd/configs/smoke_vllm.json @@ -0,0 +1,55 @@ +{ + "student": { + "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "arch": "qwen2" + }, + "teacher": { + "model_name_or_path": "Qwen/Qwen2.5-1.5B-Instruct", + "dtype": "bfloat16", + "trust_remote_code": false, + "offload_to_cpu": false + }, + "rollout": { + "engine": "vllm", + "max_prompt_length": 128, + "max_response_length": 64, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "n_samples_per_prompt": 1, + "gpus": [], + "tensor_parallel_size": 1, + "gpu_memory_utilization": 0.3, + "vllm_dtype": "bfloat16", + "weight_sync_interval": 2, + "vllm_min_version": "0.6.4", + "vllm_enforce_eager": true + }, + "distillation": { + "loss_type": "reverse_kl", + "temperature": 1.0, + "chunk_size": 128 + }, + "training": { + "train_batch_size": 2, + "micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-6, + "weight_decay": 0.0, + "num_train_epochs": 1, + "max_steps": 5, + "warmup_steps": 0, + "save_steps": 10000, + "logging_steps": 1, + "save_dir": "./opsd_smoke_vllm_ckpt", + "seed": 42 + }, + "data": { + "path": "data/prompts.jsonl", + "prompt_field": "prompt", + "shuffle": true + }, + "deepspeed_config": "configs/smoke_ds_zero3.json" +} diff --git a/examples/opsd/opsd/rollout/vllm.py b/examples/opsd/opsd/rollout/vllm.py new file mode 100644 index 000000000000..947e43fbc7aa --- /dev/null +++ b/examples/opsd/opsd/rollout/vllm.py @@ -0,0 +1,314 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""vLLM rollout on a disjoint GPU group. + +**Topology (intended)** + * Training ranks 0..N-1 run the student under ZeRO-3 on the first N GPUs. + * vLLM workers run on the device indices listed in ``cfg.gpus`` (or in + "shared" mode, alongside training rank 0). + * The vLLM ``LLM`` handle is constructed **only on training rank 0**. + * Other training ranks receive generated token ids by broadcast from + rank 0 (:func:`deepspeed.comm.broadcast_object_list`). + +**Weight sync** + * All training ranks cooperatively gather each ZeRO-3 parameter via + :class:`deepspeed.runtime.zero.GatheredParameters`. + * Rank 0 pushes the full tensor to vLLM via ``LLM.collective_rpc(...)``, + which dispatches to every vLLM worker; each worker uses its own TP rank + to slice and load. + +**KNOWN BLOCKING ISSUE — same-process vLLM under the DeepSpeed launcher** + + vLLM's worker initialisation calls ``new_group(...)`` on the global + process group as a collective. Under the standard DeepSpeed launcher + (e.g. ``deepspeed --num_gpus 2``) the world spans **all** training + ranks, but only rank 0 calls into vLLM. The other training ranks never + participate in vLLM's collective, so the ``LLM`` constructor hangs + forever waiting on them. + + This was reproduced with vllm 0.6.6 + deepspeed 0.15.4 + torch 2.5.1; the + same code-path completes in seconds when ``LLM`` is constructed in a + process whose world size is 1. Verified by minimal repro (rank 0 LLM + init blocks; rank 1 idle). + + **Workarounds (none currently implemented):** + 1. Run vLLM in a **separate top-level Python process** with its own + world (size 1), and have the trainer talk to it over an HTTP or + RPC channel. This is what TRL and OpenRLHF do for their vLLM + backends. + 2. Spawn vLLM as a subprocess from rank 0 and tunnel calls through a + queue. Similar to (1) but lower-level. + 3. Wait for upstream vLLM to expose a flag that skips its internal + ``new_group`` calls when the caller already owns process-group + setup. + + Until one of those lands, **the vLLM rollout in this PR is verified at + the unit-test level only** (see ``tests/test_vllm_stitch.py`` and + ``tests/test_weight_bridge.py``). The hybrid engine rollout is the + fully-validated live path. See the project README's "vLLM status" + section for current state. +""" + +import os +from typing import Any, List, Optional + +import torch + +from opsd.config import RolloutConfig +from opsd.rollout.base import RolloutBatch, RolloutEngine, RolloutRequest, SamplingConfig +from opsd.weight_bridge import WeightBridge, get_bridge + + +def _is_rank_zero() -> bool: + # Deferred so this module remains importable in CPU-only test envs that + # don't have ``deepspeed`` available (the ``stitch_rollout`` helper below + # is pure tensor math and is unit-tested without DeepSpeed). + from deepspeed import comm as dist + + return (not dist.is_initialized()) or dist.get_rank() == 0 + + +def stitch_rollout( + prompt_ids: torch.Tensor, + prompt_attention_mask: torch.Tensor, + responses: List[List[int]], + pad_id: int, + n_samples_per_prompt: int, +) -> RolloutBatch: + """Stitch left-padded prompts and per-sample response token ids into one + right-padded ``RolloutBatch``. + + This is the only piece of vLLM-side post-processing that doesn't depend + on a live LLM handle, so we factor it out for CPU unit testing. + + Args: + prompt_ids: ``[B, T_p]`` left-padded prompts. + prompt_attention_mask: ``[B, T_p]`` matching attention mask. + responses: list of length ``B * n_samples_per_prompt``; each element + is the list of generated token ids for one (prompt, sample). + pad_id: pad token used for both prompt left-padding and response + right-padding (typically the tokenizer's ``pad_token_id`` or + ``eos_token_id``). + n_samples_per_prompt: number of generated samples per prompt. + + Returns: + :class:`RolloutBatch` with ``response_start_idx = T_p`` for every + sample. + """ + B, T_p = prompt_ids.shape + n = n_samples_per_prompt + expected = B * n + if len(responses) != expected: + raise ValueError(f"expected {expected} response token-id lists " + f"(B={B} * n_samples={n}); got {len(responses)}") + + if responses: + max_response_len = max(len(r) for r in responses) + else: + max_response_len = 0 + T_total = T_p + max_response_len + device = prompt_ids.device + + out_ids = torch.full((expected, T_total), pad_id, dtype=torch.long, device=device) + out_attn = torch.zeros((expected, T_total), dtype=prompt_attention_mask.dtype, device=device) + + prompts_expanded = prompt_ids.repeat_interleave(n, dim=0) + attn_expanded = prompt_attention_mask.repeat_interleave(n, dim=0) + out_ids[:, :T_p] = prompts_expanded + out_attn[:, :T_p] = attn_expanded + + for i, resp in enumerate(responses): + L = len(resp) + if L == 0: + continue + out_ids[i, T_p:T_p + L] = torch.tensor(resp, dtype=torch.long, device=device) + out_attn[i, T_p:T_p + L] = 1 + + response_start_idx = torch.full((expected, ), T_p, dtype=torch.long, device=device) + return RolloutBatch(input_ids=out_ids, attention_mask=out_attn, response_start_idx=response_start_idx) + + +class VLLMRollout(RolloutEngine): + + name = "vllm" + + def __init__( + self, + cfg: RolloutConfig, + tokenizer: Any, + student_engine: Any = None, + student_model_path: Optional[str] = None, + arch: Optional[str] = None, + ): + if cfg.engine != "vllm": + raise ValueError(f"RolloutConfig.engine must be 'vllm'; got {cfg.engine!r}") + if student_model_path is None: + raise ValueError("VLLMRollout needs student_model_path to initialise the vLLM engine " + "(it loads weights from disk at construction time)") + + self.cfg = cfg + self.tokenizer = tokenizer + self.student_engine = student_engine + self._model_path = student_model_path + + self.is_rank_zero = _is_rank_zero() + self.llm: Optional[Any] = None + self.bridge: Optional[WeightBridge] = get_bridge(arch) if arch is not None else None + + if self.is_rank_zero: + self._init_vllm() + + # ------------------------------------------------------------------ + # Construction + # ------------------------------------------------------------------ + + def _init_vllm(self) -> None: + # Topology selection: + # * cfg.gpus empty → SHARED: vLLM runs in-process on the same GPU + # as training rank 0. Simple; no CUDA visibility tricks. Used for + # smoke tests and when vLLM + student fit alongside each other. + # * cfg.gpus set → DISJOINT: vLLM workers are pinned to the + # listed devices via CUDA_VISIBLE_DEVICES + a spawn-mode + # subprocess executor so the new CUDA context isn't inherited + # from the already-initialised rank-0 process. + shared = not self.cfg.gpus + + prev_cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + prev_mp = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") + if not shared: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in self.cfg.gpus) + # Must be set before the vllm import; the value is read at import time. + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + try: + try: + from vllm import LLM + except ImportError as e: + raise ImportError(f"VLLMRollout requires vllm>={self.cfg.vllm_min_version}. " + f"Install with: pip install 'vllm>={self.cfg.vllm_min_version}'") from e + + llm_kwargs = dict( + model=self._model_path, + tensor_parallel_size=self.cfg.tensor_parallel_size, + gpu_memory_utilization=self.cfg.gpu_memory_utilization, + dtype=self.cfg.vllm_dtype, + enforce_eager=self.cfg.vllm_enforce_eager, + ) + if not shared: + llm_kwargs["distributed_executor_backend"] = "mp" + self.llm = LLM(**llm_kwargs) + finally: + if prev_cvd is None: + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + else: + os.environ["CUDA_VISIBLE_DEVICES"] = prev_cvd + if prev_mp is None: + os.environ.pop("VLLM_WORKER_MULTIPROC_METHOD", None) + else: + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = prev_mp + + # ------------------------------------------------------------------ + # Generation + # ------------------------------------------------------------------ + + def generate(self, request: RolloutRequest, sampling: SamplingConfig) -> RolloutBatch: + B = int(request.prompt_ids.shape[0]) + n = sampling.n_samples_per_prompt + + if self.is_rank_zero: + from vllm import SamplingParams + + # We send prompt *token ids* rather than text to vLLM so the + # generation stays bit-exact with how the trainer tokenised. This + # avoids any subtle BOS / special-token differences between the + # trainer's and vLLM's text->id paths. + prompt_token_ids: List[List[int]] = [] + for i in range(B): + mask = request.prompt_attention_mask[i].bool() + ids = request.prompt_ids[i][mask].tolist() + prompt_token_ids.append(ids) + + sp = SamplingParams( + n=n, + temperature=sampling.temperature, + top_p=sampling.top_p, + top_k=sampling.top_k if sampling.top_k > 0 else -1, + max_tokens=sampling.max_new_tokens, + ) + results = self.llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sp, use_tqdm=False) + responses: List[List[int]] = [] + for r in results: + for out in r.outputs: + responses.append(list(out.token_ids)) + else: + responses = [] + + from deepspeed import comm as dist + + if dist.is_initialized() and dist.get_world_size() > 1: + obj = [responses] + dist.broadcast_object_list(obj, src=0) + responses = obj[0] + + pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id + return stitch_rollout( + prompt_ids=request.prompt_ids, + prompt_attention_mask=request.prompt_attention_mask, + responses=responses, + pad_id=pad_id, + n_samples_per_prompt=n, + ) + + # ------------------------------------------------------------------ + # Weight sync + # ------------------------------------------------------------------ + + def sync_weights_from_student(self, step: int) -> None: + if self.student_engine is None: + return + if self.bridge is None: + # Best-effort inference of arch from the student model class name. + model = self.student_engine.module + cls = type(model).__name__.lower() + if "qwen3" in cls: + self.bridge = get_bridge("qwen3") + elif "qwen2" in cls: + self.bridge = get_bridge("qwen2") + else: + raise RuntimeError(f"Cannot infer weight bridge for student class {cls!r}; " + f"set StudentConfig.arch explicitly") + + from deepspeed.runtime.zero import GatheredParameters + + model = self.student_engine.module + for name, param in model.named_parameters(): + # GatheredParameters is a no-op when ZeRO stage < 3, and a full + # all-gather when stage == 3. Either way every rank sees the full + # tensor inside the context; only rank 0 forwards it to vLLM. + with GatheredParameters([param], modifier_rank=0): + if not self.is_rank_zero: + continue + # Sanity-check the param name against the bridge so a renamed + # parameter trips here (cheap) rather than as a silent layout + # mismatch inside vLLM later (very hard to debug). + self.bridge.parallel_kind(name) + self._push_one_param(name, param.data.detach()) + + def _push_one_param(self, name: str, tensor: torch.Tensor) -> None: + # collective_rpc dispatches to every vLLM worker; pickle handles the + # tensor transfer. CPU tensors pickle cleanly across process bounds. + cpu = tensor.contiguous().cpu() + # vLLM's per-architecture model class exposes ``load_weights`` taking + # an iterable of (name, tensor) pairs and internally handles QKV / + # gate_up fusion plus per-rank slicing for tensor parallelism. + self.llm.collective_rpc("load_weights", args=([(name, cpu)], )) + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + if self.llm is not None: + del self.llm + self.llm = None diff --git a/examples/opsd/opsd/weight_bridge/__init__.py b/examples/opsd/opsd/weight_bridge/__init__.py new file mode 100644 index 000000000000..b415b1a1b0e8 --- /dev/null +++ b/examples/opsd/opsd/weight_bridge/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Architecture-specific bridges that slice HuggingFace weights for vLLM TP. + +A bridge takes the student's full ``(name, tensor)`` pairs (after we've +gathered them across ZeRO-3 ranks) and emits the per-vLLM-rank slices ready +to push into vLLM's ``model.load_weights(...)``. + +vLLM internally fuses Q/K/V into ``qkv_proj`` and gate/up into ``gate_up_proj``. +We do **not** pre-fuse on our side — vLLM's loader already understands the +unfused HuggingFace layout — so the bridge only needs to know each parameter's +parallel kind (column / row / vocab / replicated) and slice on the right dim. +""" + +from opsd.weight_bridge.base import ParallelKind, WeightBridge +from opsd.weight_bridge.qwen2 import Qwen2WeightBridge +from opsd.weight_bridge.qwen3 import Qwen3WeightBridge + +__all__ = ["WeightBridge", "ParallelKind", "Qwen2WeightBridge", "Qwen3WeightBridge", "get_bridge"] + + +def get_bridge(arch: str) -> WeightBridge: + """Look up a bridge by architecture key (matches HF's ``model_type``).""" + key = arch.lower() + if key in ("qwen2", "qwen2.5"): + return Qwen2WeightBridge() + if key in ("qwen3", ): + return Qwen3WeightBridge() + raise ValueError(f"No weight bridge registered for arch {arch!r}; " + f"add a sibling of opsd/weight_bridge/qwen2.py and register here") diff --git a/examples/opsd/opsd/weight_bridge/base.py b/examples/opsd/opsd/weight_bridge/base.py new file mode 100644 index 000000000000..3e780a05ae68 --- /dev/null +++ b/examples/opsd/opsd/weight_bridge/base.py @@ -0,0 +1,109 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""WeightBridge ABC: per-tensor TP slicing for vLLM weight sync.""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Iterable, Iterator, Tuple + +import torch + + +class ParallelKind(str, Enum): + """How a single parameter is split across vLLM TP ranks. + + Notation matches the standard Megatron-style decomposition: + + * ``COLUMN`` — output dim (dim 0) is split. Each rank owns + ``out_features / tp`` rows. Used for attention Q/K/V and MLP + gate/up. + * ``ROW`` — input dim (dim 1) is split. Each rank owns + ``in_features / tp`` columns. Used for attention output projection + and MLP down projection. + * ``VOCAB`` — like COLUMN but applied to the embedding / LM head where + the partitioned dim is the vocab axis. Treated the same as COLUMN + for slicing purposes; the kind is kept distinct to make divisibility + diagnostics clearer at debug time. + * ``REPLICATED`` — the same tensor lives on every rank + (layer norms, RMSNorm scalars, per-head q_norm/k_norm in Qwen3). + """ + + COLUMN = "column" + ROW = "row" + VOCAB = "vocab" + REPLICATED = "replicated" + + +def _even_slice(t: torch.Tensor, dim: int, rank: int, tp_size: int) -> torch.Tensor: + """Return rank ``rank`` 's contiguous chunk of ``t`` along ``dim``. + + Refuses uneven divisions so that bugs surface here rather than as silent + layout mismatches once weights are loaded into vLLM. + """ + total = int(t.shape[dim]) + if total % tp_size != 0: + raise ValueError(f"Shape {tuple(t.shape)} dim {dim} (={total}) not divisible by " + f"tp_size {tp_size}") + per = total // tp_size + return t.narrow(dim, rank * per, per).contiguous() + + +class WeightBridge(ABC): + """Strategy object that maps HuggingFace param names to a parallel kind. + + Subclasses only need to implement :meth:`parallel_kind`; the slicing + machinery is inherited. + """ + + # Subclasses set this to a human-readable tag, e.g. "qwen2". + arch: str = "base" + + @abstractmethod + def parallel_kind(self, hf_name: str) -> ParallelKind: + """Return how parameter ``hf_name`` should be partitioned across TP.""" + + def slice_for_rank( + self, + hf_name: str, + tensor: torch.Tensor, + tp_rank: int, + tp_size: int, + ) -> torch.Tensor: + """Return the slice of ``tensor`` that belongs to rank ``tp_rank``.""" + if tp_size < 1 or not (0 <= tp_rank < tp_size): + raise ValueError(f"invalid tp_rank={tp_rank} for tp_size={tp_size}") + if tp_size == 1: + return tensor + kind = self.parallel_kind(hf_name) + if kind is ParallelKind.REPLICATED: + return tensor + # COLUMN and VOCAB partition dim 0 (output / vocab). ROW partitions + # dim 1 (input). Both kinds may apply to 1-D tensors (biases): for a + # 1-D bias on a COLUMN-parallel linear, dim 0 IS the partitioned dim. + if kind in (ParallelKind.COLUMN, ParallelKind.VOCAB): + return _even_slice(tensor, dim=0, rank=tp_rank, tp_size=tp_size) + if kind is ParallelKind.ROW: + if tensor.dim() < 2: + # Row-parallel linears have a replicated bias (vLLM convention), + # so a 1-D tensor reaching this branch is a bug. + raise ValueError(f"ROW parallel kind requires >=2-D tensor for {hf_name}; " + f"got shape {tuple(tensor.shape)}") + return _even_slice(tensor, dim=1, rank=tp_rank, tp_size=tp_size) + raise ValueError(f"unhandled parallel kind {kind!r}") + + def map_state_dict( + self, + hf_named_tensors: Iterable[Tuple[str, torch.Tensor]], + tp_rank: int, + tp_size: int, + ) -> Iterator[Tuple[str, torch.Tensor]]: + """Yield ``(vllm_name, sliced_tensor)`` for every input pair. + + For Qwen-family models the vLLM parameter name is identical to the + HF name (vLLM's loader handles QKV/gate-up fusion internally), so the + emitted names are unchanged. + """ + for name, tensor in hf_named_tensors: + yield name, self.slice_for_rank(name, tensor, tp_rank, tp_size) diff --git a/examples/opsd/opsd/weight_bridge/qwen2.py b/examples/opsd/opsd/weight_bridge/qwen2.py new file mode 100644 index 000000000000..903d47e81c1f --- /dev/null +++ b/examples/opsd/opsd/weight_bridge/qwen2.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Weight bridge for Qwen2 / Qwen2.5 dense models. + +Naming follows the standard HF Qwen2 layout:: + + model.embed_tokens.weight + model.layers.{i}.self_attn.{q,k,v,o}_proj.{weight,bias} + model.layers.{i}.mlp.{gate,up,down}_proj.weight + model.layers.{i}.{input,post_attention}_layernorm.weight + model.norm.weight + lm_head.weight # may be tied to embed_tokens + +Parallel kinds: + * Q/K/V projections — column-parallel (split heads across ranks) + * Attention output projection — row-parallel + * MLP gate / up projections — column-parallel + * MLP down projection — row-parallel + * Layer norms / final norm — replicated + * Token embedding & LM head — vocab-parallel (split vocab dim) + * Bias on Q/K/V — column-parallel (1-D bias on a column-parallel linear) + * Bias on o_proj / down_proj — replicated (row-parallel linears have a + replicated bias under vLLM's convention; the partial sums are reduced + before the bias add) +""" + +import re + +from opsd.weight_bridge.base import ParallelKind, WeightBridge + +_LAYER_RE = re.compile(r"^model\.layers\.\d+\.(?P.+)$") + + +class Qwen2WeightBridge(WeightBridge): + arch = "qwen2" + + # Suffix → parallel kind. Keyed by the part after "model.layers.{i}." for + # transformer-block params, plus a few full names for embeddings / norms. + _LAYER_RULES = { + "self_attn.q_proj.weight": ParallelKind.COLUMN, + "self_attn.k_proj.weight": ParallelKind.COLUMN, + "self_attn.v_proj.weight": ParallelKind.COLUMN, + "self_attn.q_proj.bias": ParallelKind.COLUMN, + "self_attn.k_proj.bias": ParallelKind.COLUMN, + "self_attn.v_proj.bias": ParallelKind.COLUMN, + "self_attn.o_proj.weight": ParallelKind.ROW, + "self_attn.o_proj.bias": ParallelKind.REPLICATED, + "mlp.gate_proj.weight": ParallelKind.COLUMN, + "mlp.up_proj.weight": ParallelKind.COLUMN, + "mlp.down_proj.weight": ParallelKind.ROW, + "mlp.down_proj.bias": ParallelKind.REPLICATED, + "input_layernorm.weight": ParallelKind.REPLICATED, + "post_attention_layernorm.weight": ParallelKind.REPLICATED, + } + + _GLOBAL_RULES = { + "model.embed_tokens.weight": ParallelKind.VOCAB, + "model.norm.weight": ParallelKind.REPLICATED, + "lm_head.weight": ParallelKind.VOCAB, + } + + def parallel_kind(self, hf_name: str) -> ParallelKind: + if hf_name in self._GLOBAL_RULES: + return self._GLOBAL_RULES[hf_name] + m = _LAYER_RE.match(hf_name) + if m is not None: + rest = m.group("rest") + if rest in self._LAYER_RULES: + return self._LAYER_RULES[rest] + # Per-layer name not in our table — surface a clear error so the + # weight sync isn't silently wrong for an unrecognised tensor. + extra = self._extra_layer_kind(rest) + if extra is not None: + return extra + raise KeyError(f"Unknown per-layer Qwen2 parameter suffix {rest!r}; add a rule " + f"in Qwen2WeightBridge._LAYER_RULES") + raise KeyError(f"Unknown Qwen2 parameter name {hf_name!r}") + + def _extra_layer_kind(self, _suffix: str): # noqa: D401, ARG002 + """Hook for subclasses (Qwen3) to add per-layer rules without + duplicating the rest of the table.""" + return None diff --git a/examples/opsd/opsd/weight_bridge/qwen3.py b/examples/opsd/opsd/weight_bridge/qwen3.py new file mode 100644 index 000000000000..6b3d7695ed32 --- /dev/null +++ b/examples/opsd/opsd/weight_bridge/qwen3.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Weight bridge for Qwen3 dense models. + +Qwen3-dense uses the same overall layout as Qwen2 with one addition: +per-head RMSNorm applied to the query and key projections before attention:: + + model.layers.{i}.self_attn.q_norm.weight # shape [head_dim] + model.layers.{i}.self_attn.k_norm.weight # shape [head_dim] + +These weights are 1-D over ``head_dim`` (not ``num_heads * head_dim``), so they +are **replicated** across TP ranks: every rank owns a subset of heads but each +head normalises with the same per-head-dim scalars. + +Qwen3-MoE (the ``Qwen3MoeForCausalLM`` family) is **not** covered here — MoE +introduces gate/expert routing and per-expert MLPs that need their own bridge. +Add a sibling ``qwen3_moe.py`` when that path becomes a priority. +""" + +from typing import Optional + +from opsd.weight_bridge.base import ParallelKind +from opsd.weight_bridge.qwen2 import Qwen2WeightBridge + + +class Qwen3WeightBridge(Qwen2WeightBridge): + arch = "qwen3" + + _Q_NORM = "self_attn.q_norm.weight" + _K_NORM = "self_attn.k_norm.weight" + + def _extra_layer_kind(self, suffix: str) -> Optional[ParallelKind]: + if suffix in (self._Q_NORM, self._K_NORM): + return ParallelKind.REPLICATED + return None diff --git a/examples/opsd/scripts/train_opsd_vllm.sh b/examples/opsd/scripts/train_opsd_vllm.sh new file mode 100644 index 000000000000..83ed4dc96d7e --- /dev/null +++ b/examples/opsd/scripts/train_opsd_vllm.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +# +# Launch OPSD training with vLLM rollout on a disjoint GPU group. +# +# Default config assumes 8 GPUs: ranks 0..5 train (ZeRO-3), devices 6-7 run +# vLLM with TP=2. Adjust configs/opsd_vllm_disjoint.json::rollout.gpus and +# NUM_TRAIN_GPUS to match your topology. +set -euo pipefail + +CONFIG="${1:-configs/opsd_vllm_disjoint.json}" +NUM_TRAIN_GPUS="${NUM_TRAIN_GPUS:-6}" +INCLUDE_GPUS="${INCLUDE_GPUS:-0,1,2,3,4,5}" + +deepspeed --num_gpus "${NUM_TRAIN_GPUS}" --include "localhost:${INCLUDE_GPUS}" \ + main.py --config "${CONFIG}" diff --git a/examples/opsd/tests/test_vllm_stitch.py b/examples/opsd/tests/test_vllm_stitch.py new file mode 100644 index 000000000000..bd8e1b4e4c0f --- /dev/null +++ b/examples/opsd/tests/test_vllm_stitch.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""CPU-only tests for the vLLM rollout post-processing. + +We can't run vLLM here, but the prompt/response stitching is pure tensor +manipulation and is the part most prone to silent index bugs. +""" + +import pytest +import torch + +from opsd.rollout.vllm import stitch_rollout +from opsd.utils import build_response_mask + + +def test_stitch_basic_single_sample(): + prompt_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) + attn = torch.ones(2, 3, dtype=torch.long) + responses = [[10, 11, 12], [20, 21]] + out = stitch_rollout(prompt_ids, attn, responses, pad_id=0, n_samples_per_prompt=1) + assert out.input_ids.shape == (2, 6) + assert out.input_ids[0].tolist() == [1, 2, 3, 10, 11, 12] + assert out.input_ids[1].tolist() == [4, 5, 6, 20, 21, 0] + assert out.attention_mask[0].tolist() == [1, 1, 1, 1, 1, 1] + assert out.attention_mask[1].tolist() == [1, 1, 1, 1, 1, 0] + assert out.response_start_idx.tolist() == [3, 3] + + +def test_stitch_with_n_samples(): + prompt_ids = torch.tensor([[1, 2], [3, 4]]) + attn = torch.ones(2, 2, dtype=torch.long) + responses = [[5, 6], [7, 8], [9, 10], [11, 12]] + out = stitch_rollout(prompt_ids, attn, responses, pad_id=0, n_samples_per_prompt=2) + assert out.input_ids.shape == (4, 4) + # Prompts are repeat_interleaved: [P0, P0, P1, P1]. + assert out.input_ids[0].tolist() == [1, 2, 5, 6] + assert out.input_ids[1].tolist() == [1, 2, 7, 8] + assert out.input_ids[2].tolist() == [3, 4, 9, 10] + assert out.input_ids[3].tolist() == [3, 4, 11, 12] + assert out.response_start_idx.tolist() == [2, 2, 2, 2] + + +def test_stitch_left_padded_prompts(): + prompt_ids = torch.tensor([[0, 1, 2], [3, 4, 5]]) + attn = torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.long) + responses = [[6], [7]] + out = stitch_rollout(prompt_ids, attn, responses, pad_id=0, n_samples_per_prompt=1) + # Response begins at column T_p == 3 for both, regardless of prompt padding. + assert out.response_start_idx.tolist() == [3, 3] + # Prompt section keeps the caller's left-padding mask. + assert out.attention_mask[:, :3].tolist() == [[0, 1, 1], [1, 1, 1]] + + +def test_stitch_mismatched_response_count_raises(): + prompt_ids = torch.tensor([[1, 2]]) + attn = torch.ones(1, 2, dtype=torch.long) + with pytest.raises(ValueError, match="expected"): + stitch_rollout(prompt_ids, attn, [[3], [4]], pad_id=0, n_samples_per_prompt=1) + + +def test_stitch_empty_responses_still_well_shaped(): + prompt_ids = torch.tensor([[1, 2], [3, 4]]) + attn = torch.ones(2, 2, dtype=torch.long) + out = stitch_rollout(prompt_ids, attn, [[], []], pad_id=0, n_samples_per_prompt=1) + # No response tokens means total length == prompt length. + assert out.input_ids.shape == (2, 2) + # Mask over the (zero) response section is empty; response_start_idx still + # points at the end of the prompt. + assert out.response_start_idx.tolist() == [2, 2] + + +def test_stitch_handles_variable_response_lengths(): + prompt_ids = torch.tensor([[1], [2], [3]]) + attn = torch.ones(3, 1, dtype=torch.long) + responses = [[10], [20, 21, 22, 23], [30, 31]] + out = stitch_rollout(prompt_ids, attn, responses, pad_id=99, n_samples_per_prompt=1) + # Total length = T_p + max(response lengths) = 1 + 4 = 5. + assert out.input_ids.shape == (3, 5) + assert out.input_ids[0].tolist() == [1, 10, 99, 99, 99] + assert out.input_ids[1].tolist() == [2, 20, 21, 22, 23] + assert out.input_ids[2].tolist() == [3, 30, 31, 99, 99] + assert out.attention_mask[0].tolist() == [1, 1, 0, 0, 0] + assert out.attention_mask[1].tolist() == [1, 1, 1, 1, 1] + assert out.attention_mask[2].tolist() == [1, 1, 1, 0, 0] + + +def test_stitch_output_feeds_build_response_mask(): + prompt_ids = torch.tensor([[0, 1, 2], [3, 4, 5]]) + attn = torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.long) + out = stitch_rollout(prompt_ids, attn, [[10, 11], [20]], pad_id=0, n_samples_per_prompt=1) + mask = build_response_mask(out.response_start_idx, out.attention_mask) + # Sample 0: T_p=3, response tokens at 3,4 (both attended). + assert mask[0].tolist() == [0, 0, 0, 1, 1] + # Sample 1: T_p=3, response token at 3 only (position 4 is pad). + assert mask[1].tolist() == [0, 0, 0, 1, 0] diff --git a/examples/opsd/tests/test_weight_bridge.py b/examples/opsd/tests/test_weight_bridge.py new file mode 100644 index 000000000000..9aa50414cbb2 --- /dev/null +++ b/examples/opsd/tests/test_weight_bridge.py @@ -0,0 +1,259 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""CPU-only tests for the TP weight bridges. + +These exercise the parallel-kind table and the per-rank slicing math without +requiring vLLM, GPUs, or real model checkpoints. +""" + +import pytest +import torch + +from opsd.weight_bridge import ParallelKind, Qwen2WeightBridge, Qwen3WeightBridge, get_bridge + +# Realistic-ish shapes for a Qwen2.5-0.5B-style model: hidden=896, num_heads=14, +# num_kv_heads=2, head_dim=64, intermediate=4864, vocab=151936. Picked so all +# the per-dim sizes are divisible by tp_size=2. +HIDDEN = 896 +NUM_HEADS = 14 +NUM_KV_HEADS = 2 +HEAD_DIM = 64 +INTERMEDIATE = 4864 +VOCAB = 151936 + + +def _qwen2_named_tensors(): + """A minimal stand-in for one layer of a Qwen2 state dict.""" + q_dim = NUM_HEADS * HEAD_DIM + kv_dim = NUM_KV_HEADS * HEAD_DIM + return [ + ("model.embed_tokens.weight", torch.randn(VOCAB, HIDDEN)), + ("model.layers.0.self_attn.q_proj.weight", torch.randn(q_dim, HIDDEN)), + ("model.layers.0.self_attn.k_proj.weight", torch.randn(kv_dim, HIDDEN)), + ("model.layers.0.self_attn.v_proj.weight", torch.randn(kv_dim, HIDDEN)), + ("model.layers.0.self_attn.q_proj.bias", torch.randn(q_dim)), + ("model.layers.0.self_attn.k_proj.bias", torch.randn(kv_dim)), + ("model.layers.0.self_attn.v_proj.bias", torch.randn(kv_dim)), + ("model.layers.0.self_attn.o_proj.weight", torch.randn(HIDDEN, q_dim)), + ("model.layers.0.mlp.gate_proj.weight", torch.randn(INTERMEDIATE, HIDDEN)), + ("model.layers.0.mlp.up_proj.weight", torch.randn(INTERMEDIATE, HIDDEN)), + ("model.layers.0.mlp.down_proj.weight", torch.randn(HIDDEN, INTERMEDIATE)), + ("model.layers.0.input_layernorm.weight", torch.randn(HIDDEN)), + ("model.layers.0.post_attention_layernorm.weight", torch.randn(HIDDEN)), + ("model.norm.weight", torch.randn(HIDDEN)), + ("lm_head.weight", torch.randn(VOCAB, HIDDEN)), + ] + + +# --- parallel kind dispatch ------------------------------------------------- + + +@pytest.mark.parametrize("name, expected", [ + ("model.embed_tokens.weight", ParallelKind.VOCAB), + ("model.layers.0.self_attn.q_proj.weight", ParallelKind.COLUMN), + ("model.layers.0.self_attn.k_proj.weight", ParallelKind.COLUMN), + ("model.layers.0.self_attn.v_proj.weight", ParallelKind.COLUMN), + ("model.layers.42.self_attn.q_proj.bias", ParallelKind.COLUMN), + ("model.layers.3.self_attn.o_proj.weight", ParallelKind.ROW), + ("model.layers.3.mlp.gate_proj.weight", ParallelKind.COLUMN), + ("model.layers.3.mlp.up_proj.weight", ParallelKind.COLUMN), + ("model.layers.3.mlp.down_proj.weight", ParallelKind.ROW), + ("model.layers.0.input_layernorm.weight", ParallelKind.REPLICATED), + ("model.layers.0.post_attention_layernorm.weight", ParallelKind.REPLICATED), + ("model.norm.weight", ParallelKind.REPLICATED), + ("lm_head.weight", ParallelKind.VOCAB), +]) +def test_qwen2_parallel_kinds(name, expected): + assert Qwen2WeightBridge().parallel_kind(name) == expected + + +def test_qwen2_unknown_layer_param_raises(): + with pytest.raises(KeyError, match="Unknown per-layer Qwen2"): + Qwen2WeightBridge().parallel_kind("model.layers.0.self_attn.q_norm.weight") + + +def test_qwen2_unknown_global_param_raises(): + with pytest.raises(KeyError, match="Unknown Qwen2 parameter"): + Qwen2WeightBridge().parallel_kind("totally.made.up.weight") + + +def test_qwen3_adds_qk_norm(): + bridge = Qwen3WeightBridge() + assert bridge.parallel_kind("model.layers.0.self_attn.q_norm.weight") == ParallelKind.REPLICATED + assert bridge.parallel_kind("model.layers.0.self_attn.k_norm.weight") == ParallelKind.REPLICATED + # Inherits the rest from Qwen2. + assert bridge.parallel_kind("model.layers.0.self_attn.q_proj.weight") == ParallelKind.COLUMN + + +# --- slicing math ----------------------------------------------------------- + + +@pytest.mark.parametrize("tp_size", [1, 2, 4]) +def test_column_slice_shapes(tp_size): + bridge = Qwen2WeightBridge() + w = torch.randn(NUM_HEADS * HEAD_DIM, HIDDEN) + for rank in range(tp_size): + sliced = bridge.slice_for_rank("model.layers.0.self_attn.q_proj.weight", w, rank, tp_size) + assert sliced.shape == (NUM_HEADS * HEAD_DIM // tp_size, HIDDEN) + + +@pytest.mark.parametrize("tp_size", [1, 2, 4]) +def test_row_slice_shapes(tp_size): + bridge = Qwen2WeightBridge() + w = torch.randn(HIDDEN, NUM_HEADS * HEAD_DIM) + for rank in range(tp_size): + sliced = bridge.slice_for_rank("model.layers.0.self_attn.o_proj.weight", w, rank, tp_size) + assert sliced.shape == (HIDDEN, NUM_HEADS * HEAD_DIM // tp_size) + + +def test_replicated_returns_full_tensor(): + bridge = Qwen2WeightBridge() + w = torch.randn(HIDDEN) + for rank in range(4): + sliced = bridge.slice_for_rank("model.layers.0.input_layernorm.weight", w, rank, tp_size=4) + assert sliced.shape == w.shape + assert torch.equal(sliced, w) + + +def test_column_slices_gather_to_original(): + bridge = Qwen2WeightBridge() + w = torch.randn(NUM_HEADS * HEAD_DIM, HIDDEN) + tp_size = 2 + pieces = [bridge.slice_for_rank("model.layers.0.self_attn.q_proj.weight", w, r, tp_size) for r in range(tp_size)] + assert torch.equal(torch.cat(pieces, dim=0), w) + + +def test_row_slices_gather_to_original(): + bridge = Qwen2WeightBridge() + w = torch.randn(HIDDEN, INTERMEDIATE) + tp_size = 4 + pieces = [bridge.slice_for_rank("model.layers.0.mlp.down_proj.weight", w, r, tp_size) for r in range(tp_size)] + assert torch.equal(torch.cat(pieces, dim=1), w) + + +def test_vocab_slices_gather_to_original(): + bridge = Qwen2WeightBridge() + w = torch.randn(VOCAB, HIDDEN) + tp_size = 4 + pieces = [bridge.slice_for_rank("model.embed_tokens.weight", w, r, tp_size) for r in range(tp_size)] + assert torch.equal(torch.cat(pieces, dim=0), w) + + +def test_bias_column_slices_gather_to_original(): + bridge = Qwen2WeightBridge() + b = torch.randn(NUM_HEADS * HEAD_DIM) + tp_size = 2 + pieces = [bridge.slice_for_rank("model.layers.0.self_attn.q_proj.bias", b, r, tp_size) for r in range(tp_size)] + assert torch.equal(torch.cat(pieces, dim=0), b) + + +def test_indivisible_shape_raises(): + bridge = Qwen2WeightBridge() + # 7 is not divisible by 2; should fail loudly rather than truncate. + w = torch.randn(7, HIDDEN) + with pytest.raises(ValueError, match="not divisible by"): + bridge.slice_for_rank("model.layers.0.self_attn.q_proj.weight", w, 0, 2) + + +def test_invalid_rank_raises(): + bridge = Qwen2WeightBridge() + w = torch.randn(NUM_HEADS * HEAD_DIM, HIDDEN) + with pytest.raises(ValueError, match="invalid tp_rank"): + bridge.slice_for_rank("model.layers.0.self_attn.q_proj.weight", w, 4, 4) + with pytest.raises(ValueError, match="invalid tp_rank"): + bridge.slice_for_rank("model.layers.0.self_attn.q_proj.weight", w, -1, 2) + + +def test_row_parallel_rejects_1d(): + """The defensive check inside ``slice_for_rank`` is unreachable through + the real Qwen2 table (row-parallel biases are tagged REPLICATED), but a + future bridge could route a 1-D tensor through ROW. Exercise via a + minimal subclass so the guard stays covered.""" + + class _BadBridge(Qwen2WeightBridge): + + def parallel_kind(self, hf_name): # noqa: ARG002 + return ParallelKind.ROW + + with pytest.raises(ValueError, match="ROW parallel kind requires"): + _BadBridge().slice_for_rank("anything", torch.randn(HIDDEN), 0, 2) + + +def test_tp1_is_passthrough(): + bridge = Qwen2WeightBridge() + w = torch.randn(NUM_HEADS * HEAD_DIM, HIDDEN) + out = bridge.slice_for_rank("model.layers.0.self_attn.q_proj.weight", w, 0, 1) + assert torch.equal(out, w) + + +# --- state-dict iteration --------------------------------------------------- + + +def test_map_state_dict_emits_correct_shapes_for_tp2(): + bridge = Qwen2WeightBridge() + tp_size = 2 + # Build the source once; each rank consumes a fresh iterator over the + # same materialised list so we're slicing identical tensors. + src = _qwen2_named_tensors() + by_rank = {r: dict(bridge.map_state_dict(iter(src), r, tp_size)) for r in range(tp_size)} + src_by_name = dict(src) + + # Replicated tensors should be identical across ranks AND match source. + a = by_rank[0]["model.layers.0.input_layernorm.weight"] + b = by_rank[1]["model.layers.0.input_layernorm.weight"] + assert torch.equal(a, b) + assert torch.equal(a, src_by_name["model.layers.0.input_layernorm.weight"]) + + # Column-parallel Q: shapes halved on dim 0; gather reconstructs source. + q_full_rows = NUM_HEADS * HEAD_DIM + assert by_rank[0]["model.layers.0.self_attn.q_proj.weight"].shape == (q_full_rows // 2, HIDDEN) + gathered_q = torch.cat([ + by_rank[0]["model.layers.0.self_attn.q_proj.weight"], + by_rank[1]["model.layers.0.self_attn.q_proj.weight"], + ], + dim=0) + assert torch.equal(gathered_q, src_by_name["model.layers.0.self_attn.q_proj.weight"]) + + +def test_map_state_dict_gather_round_trip_with_fixed_seed(): + bridge = Qwen2WeightBridge() + torch.manual_seed(123) + src = _qwen2_named_tensors() + src_by_name = dict(src) + + tp_size = 4 + sliced = [list(bridge.map_state_dict(src, r, tp_size)) for r in range(tp_size)] + + # For every entry, reconstruct from per-rank slices and compare to the + # source. The reconstruction op depends on the parallel kind. + for r0_name, _ in sliced[0]: + kind = bridge.parallel_kind(r0_name) + per_rank = [dict(s)[r0_name] for s in sliced] + if kind is ParallelKind.REPLICATED: + recon = per_rank[0] + elif kind in (ParallelKind.COLUMN, ParallelKind.VOCAB): + recon = torch.cat(per_rank, dim=0) + elif kind is ParallelKind.ROW: + recon = torch.cat(per_rank, dim=1) + else: + raise AssertionError(f"unhandled kind {kind}") + assert torch.equal(recon, src_by_name[r0_name]), f"round-trip mismatch for {r0_name}" + + +# --- registry --------------------------------------------------------------- + + +def test_get_bridge_qwen2(): + assert isinstance(get_bridge("qwen2"), Qwen2WeightBridge) + assert isinstance(get_bridge("Qwen2.5"), Qwen2WeightBridge) + + +def test_get_bridge_qwen3(): + assert isinstance(get_bridge("qwen3"), Qwen3WeightBridge) + + +def test_get_bridge_unknown_raises(): + with pytest.raises(ValueError, match="No weight bridge registered"): + get_bridge("totally-made-up-arch")