diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 16114cb80..96ab2b7b9 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -207,7 +207,7 @@ def _set_target_inputs( model_input.targets.append(target_input) - def _get_label_counts(self, mask: torch.Tensor): + def _get_label_counts(self, mask: torch.Tensor) -> torch.Tensor: # Count the number of non-masked labels in each document through cumulative sums. mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)]) length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0) @@ -215,7 +215,6 @@ def _get_label_counts(self, mask: torch.Tensor): labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1] # Expand to one entry per token: find each token's document index via the sorted # length cumsum, then look up that document's label count. - # TODO: Document index already computed in `LengthModelInputPreprocessor`. document_index = torch.searchsorted( length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right" ) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 29720b90b..2920c1334 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -21,6 +21,16 @@ class ScheduleConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + docs_per_step: int = Field( + default=0, + desc="Target number of documents (rollouts) per optimizer step, globally across all data-parallel ranks. " + "When >0, each training step dynamically accumulates microbatches until the globally all-reduced " + "document count reaches this value, then triggers the optimizer step. " + "depth_first_micro_batches is ignored when this is set. " + "0 = use depth_first_micro_batches as-is (fixed microbatch count per step).", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) breadth_first_micro_batches: int = Field( default=1, desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.", diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index b2e212946..128b95e8e 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -320,7 +320,8 @@ def _preprocess_data( if context.schedule.phase.is_training else None ) - model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)] + n_micro_batches = context.schedule._eff_sequential_micro_batches + model_inputs = [next(data_iterator) for _ in range(n_micro_batches)] model_inputs[0][0].share_batch_data( [model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed ) @@ -336,7 +337,7 @@ def _preprocess_data( extra_kwargs={ "grad_output": grad_output, "micro_batch": micro_batch, - "num_micro_batches": self._config.sequential_micro_batches, + "num_micro_batches": n_micro_batches, "micro_batch_splits": self._config.micro_batch_splits, }, ) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 6f7bf1d95..845b5df82 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -115,15 +115,17 @@ def __init__( batch_meta: list[ModelInput], distributed_config: DistributedConfig, phase: PhaseType, + _depth_first_override: int | None = None, ): super().__init__(config) + self._depth_first_override = _depth_first_override self._multi_stage = multi_stage self._distributed_config = distributed_config self._num_stages = len(self._multi_stage.stages) self._phase = phase self._is_training = self._phase.is_training - if self._config.num_inputs < self._distributed_config.pipeline_parallel: + if self._eff_num_inputs < self._distributed_config.pipeline_parallel: warnings.warn("Not enough input to achieve true pipeline parallelism.") # Setup the activation metas. @@ -155,9 +157,25 @@ def __init__( def phase(self) -> PhaseType: return self._phase + @property + def _eff_depth_first(self) -> int: + return ( + self._depth_first_override + if self._depth_first_override is not None + else self._config.depth_first_micro_batches + ) + + @property + def _eff_sequential_micro_batches(self) -> int: + return self._eff_depth_first * self._config.breadth_first_micro_batches + + @property + def _eff_num_inputs(self) -> int: + return self._eff_sequential_micro_batches * self._config.micro_batch_splits + @property def samples_per_batch(self) -> int: - return self._config.sequential_micro_batches * self._distributed_config.batch_data_parallel + return self._eff_sequential_micro_batches * self._distributed_config.batch_data_parallel def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]: return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank]) @@ -189,7 +207,7 @@ def _create_index(self) -> None: Assert.in_range( step.index, 0, - self._config.num_inputs, + self._eff_num_inputs, ) Assert.incl(step.type_, (StepType.forward, StepType.backward)) step.global_index = i @@ -205,7 +223,7 @@ def _create_index(self) -> None: Assert.custom(all, self._device_steps) # Consistency checks step_map = self._step_map.copy() - for data_index in range(self._config.num_inputs): + for data_index in range(self._eff_num_inputs): for type_ in (StepType.forward, StepType.backward): for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages): assert ( @@ -470,14 +488,11 @@ def _create_steps(self) -> tuple[list[Step], int]: first_grad_stage += 1 else: first_grad_stage = self._num_stages - for depth_first_micro_batch in range(self._config.depth_first_micro_batches): + for depth_first_micro_batch in range(self._eff_depth_first): for stage in range(self._num_stages): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in range(self._config.micro_batch_splits): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, @@ -492,10 +507,7 @@ def _create_steps(self) -> tuple[list[Step], int]: for stage in reversed(range(first_grad_stage, self._num_stages)): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in reversed(range(self._config.micro_batch_splits)): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 1ed18c449..77a88377e 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -115,10 +115,12 @@ def setup(self, distributed: Distributed, run: Run) -> None: preprocessing_config = self._multi_stage.get_preprocessing_config( PhaseType.training, self._config.schedule.micro_batch_splits ) + self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size) + self._schedule_cache: dict[int, Schedule] = {} self._schedule = Schedule( config=self._config.schedule, multi_stage=self._multi_stage, - batch_meta=preprocessing_config.get_input_meta(self._data.config.micro_batch_size), + batch_meta=self._single_mb_meta, distributed_config=self._config.model.distributed, phase=PhaseType.training, ) @@ -140,6 +142,41 @@ def setup(self, distributed: Distributed, run: Run) -> None: self._is_setup = True + def _get_or_build_schedule(self, n_microbatches: int) -> Schedule: + if n_microbatches not in self._schedule_cache: + bfmb = self._config.schedule.breadth_first_micro_batches + depth_first = n_microbatches // bfmb + self._schedule_cache[n_microbatches] = Schedule( + config=self._config.schedule, + multi_stage=self._multi_stage, + batch_meta=self._single_mb_meta, + distributed_config=self._config.model.distributed, + phase=PhaseType.training, + _depth_first_override=depth_first, + ) + return self._schedule_cache[n_microbatches] + + def _prefetch_to_doc_target(self, data_iterator) -> list: + target = self._config.schedule.docs_per_step + bfmb = self._config.schedule.breadth_first_micro_batches + buffer = [] + total_docs = 0 + while total_docs < target: + mb = next(data_iterator) + mb[0].share_batch_data(mb, self._distributed) + total_docs += mb[0].num_documents_in_batch + buffer.append(mb) + Assert.eq( + len(buffer) % bfmb, + 0, + msg=f"Fetched {len(buffer)} microbatches not divisible by breadth_first_micro_batches={bfmb}", + ) + # Reset num_documents_in_batch to the step total on all microbatches + for mb in buffer: + for mi in mb: + mi.num_documents_in_batch = total_docs + return buffer + @abc.abstractmethod def _get_data(self) -> Data: pass @@ -220,12 +257,22 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Data loader hates getting all micro-batches at once. # (Also preprocessing adds overhead) - reduced_losses, update_successful, train_metrics = self._runner.run_step( - train_iterator, - self._schedule, - iteration=self._completed_steps, - return_metrics=is_logging, - ) + if self._config.schedule.docs_per_step > 0: + buffer = self._prefetch_to_doc_target(train_iterator) + step_schedule = self._get_or_build_schedule(len(buffer)) + reduced_losses, update_successful, train_metrics = self._runner.run_step( + iter(buffer), + step_schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) + else: + reduced_losses, update_successful, train_metrics = self._runner.run_step( + train_iterator, + self._schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) # Advanced, skipped, and Nan iterations. if update_successful: diff --git a/tests/layers/test_docs_per_step.py b/tests/layers/test_docs_per_step.py new file mode 100644 index 000000000..b288a934f --- /dev/null +++ b/tests/layers/test_docs_per_step.py @@ -0,0 +1,204 @@ +""" +Unit tests for docs_per_step. + +Covers: + 1. Divisor scaling in fused_grpo_loss_forward_backward + 2. Schedule._eff_depth_first / _eff_sequential_micro_batches / _eff_num_inputs properties + 3. Trainer._prefetch_to_doc_target accumulation logic +""" + +import dataclasses +import types + +import pytest +import torch + +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.layers.language_model.loss.policy_gradient import fused_grpo_loss_forward_backward + +device = "cuda" if torch.cuda.is_available() else "cpu" +_atol = 1e-4 if device == "cuda" else 1e-5 + + +# --------------------------------------------------------------------------- +# 1. Divisor-scaling correctness in raw kernels +# --------------------------------------------------------------------------- + + +def test_grpo_divisor_scales_loss(): + """Halving the divisor should double the loss.""" + torch.manual_seed(10) + n_tok, vocab = 16, 32 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + + d1 = float(n_tok) + d2 = float(n_tok) * 2 + + loss1, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=d1) + loss2, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=d2) + + assert ( + abs(loss1.item() - 2.0 * loss2.item()) < _atol * 10 + ), f"Expected loss(d1) ≈ 2*loss(d2), got {loss1.item():.6f} vs {2*loss2.item():.6f}" + + +# --------------------------------------------------------------------------- +# 2. Schedule._eff_* properties +# --------------------------------------------------------------------------- + + +def _make_bare_schedule(depth_first: int, breadth_first: int, splits: int, override: int | None) -> Schedule: + """Create a Schedule with __init__ bypassed to test the _eff_* properties only.""" + config = ScheduleConfig( + depth_first_micro_batches=depth_first, + breadth_first_micro_batches=breadth_first, + micro_batch_splits=splits, + ) + sched = object.__new__(Schedule) + # Minimal attributes used by the three _eff_* properties. + object.__setattr__(sched, "_config", config) + object.__setattr__(sched, "_depth_first_override", override) + # samples_per_batch also needs _distributed_config.batch_data_parallel + fake_distributed = types.SimpleNamespace(batch_data_parallel=1) + object.__setattr__(sched, "_distributed_config", fake_distributed) + return sched + + +def test_schedule_eff_properties_no_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=None) + assert sched._eff_depth_first == 4 + assert sched._eff_sequential_micro_batches == 8 # 4 * 2 + assert sched._eff_num_inputs == 24 # 8 * 3 + assert sched.samples_per_batch == 8 # 8 * dp=1 + + +def test_schedule_eff_properties_with_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=7) + assert sched._eff_depth_first == 7 # override wins + assert sched._eff_sequential_micro_batches == 14 # 7 * 2 + assert sched._eff_num_inputs == 42 # 14 * 3 + assert sched.samples_per_batch == 14 # 14 * dp=1 + + +def test_schedule_eff_properties_override_equals_config(): + """Override equal to config value → same result as no override.""" + sched_no = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, override=None) + sched_yes = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, override=3) + assert sched_no._eff_depth_first == sched_yes._eff_depth_first + assert sched_no._eff_sequential_micro_batches == sched_yes._eff_sequential_micro_batches + assert sched_no._eff_num_inputs == sched_yes._eff_num_inputs + + +def test_schedule_samples_per_batch_uses_eff(): + """samples_per_batch should scale with _eff_sequential, not config.sequential.""" + sched = _make_bare_schedule(depth_first=2, breadth_first=2, splits=1, override=5) + # Config says depth_first=2 → sequential=4; override=5 → eff_sequential=10 + assert sched._eff_sequential_micro_batches == 10 + assert sched.samples_per_batch == 10 # dp=1 + + +# --------------------------------------------------------------------------- +# 3. _prefetch_to_doc_target accumulation logic +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class _FakeMicrobatch: + """Stub for a single split of one microbatch.""" + + num_documents: int + num_documents_in_batch: int | None = None + + @classmethod + def share_batch_data(cls, inputs, distributed): + """Mimic TokenModelInput.share_batch_data with group=None (single process).""" + if inputs[0].num_documents_in_batch is None: + total = sum(inp.num_documents for inp in inputs) + for inp in inputs: + inp.num_documents_in_batch = total + + +def _fake_iterator(doc_counts: list[int]): + """Yield [_FakeMicrobatch(n)] for each n in doc_counts.""" + for n in doc_counts: + yield [_FakeMicrobatch(num_documents=n)] + + +class _StubTrainer: + """Concrete stub that exposes only the interface _prefetch_to_doc_target needs.""" + + # Borrow the method directly so it runs against this stub's attributes. + from fast_llm.engine.training.trainer import Trainer as _Trainer + + _prefetch_to_doc_target = _Trainer._prefetch_to_doc_target + + +def _make_fake_trainer(docs_per_step: int, bfmb: int = 1): + """Create a _StubTrainer with the attributes _prefetch_to_doc_target reads.""" + schedule_cfg = types.SimpleNamespace( + docs_per_step=docs_per_step, + breadth_first_micro_batches=bfmb, + ) + config = types.SimpleNamespace(schedule=schedule_cfg) + distributed = types.SimpleNamespace(batch_data_group=None) + + trainer = _StubTrainer() + trainer._config = config + trainer._distributed = distributed + return trainer + + +def test_prefetch_stops_at_target(): + """Buffer should stop growing once cumulative docs ≥ docs_per_step.""" + trainer = _make_fake_trainer(docs_per_step=6, bfmb=1) + # Each microbatch has 2 docs; need ≥6 → expect 3 microbatches + it = _fake_iterator([2, 2, 2, 2, 2]) + buffer = trainer._prefetch_to_doc_target(it) + + assert len(buffer) == 3, f"Expected 3 microbatches, got {len(buffer)}" + + +def test_prefetch_resets_num_documents_in_batch(): + """After the call, every microbatch input has num_documents_in_batch = step total.""" + trainer = _make_fake_trainer(docs_per_step=5, bfmb=1) + # 3 docs, 3 docs → total=6 (overshoots 5, stops after 2nd) + it = _fake_iterator([3, 3, 3]) + buffer = trainer._prefetch_to_doc_target(it) + + step_total = sum(mb[0].num_documents for mb in buffer) + for mb in buffer: + for mi in mb: + assert ( + mi.num_documents_in_batch == step_total + ), f"Expected num_documents_in_batch={step_total}, got {mi.num_documents_in_batch}" + + +def test_prefetch_overshoot_is_included(): + """A microbatch that pushes the total over the target IS included (not dropped).""" + trainer = _make_fake_trainer(docs_per_step=5, bfmb=1) + it = _fake_iterator([4, 4]) # 4 < 5, then 8 ≥ 5 → 2 microbatches + buffer = trainer._prefetch_to_doc_target(it) + assert len(buffer) == 2 + assert buffer[-1][0].num_documents_in_batch == 8 # step total = 4+4 + + +def test_prefetch_divisibility_check(): + """Raises when fetched count is not divisible by breadth_first_micro_batches.""" + trainer = _make_fake_trainer(docs_per_step=4, bfmb=2) + # Each microbatch has 5 docs → only 1 mb needed, but 1 % 2 != 0 + it = _fake_iterator([5, 5, 5]) + with pytest.raises(Exception): + trainer._prefetch_to_doc_target(it) + + +def test_prefetch_exact_divisibility(): + """No error when fetched count is exactly divisible by breadth_first_micro_batches.""" + trainer = _make_fake_trainer(docs_per_step=4, bfmb=2) + # 2 docs each → need ≥4 → fetch 2 microbatches → 2 % 2 == 0 + it = _fake_iterator([2, 2, 2, 2]) + buffer = trainer._prefetch_to_doc_target(it) + assert len(buffer) == 2