Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions fast_llm/data/document/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,14 @@ def _set_target_inputs(

model_input.targets.append(target_input)

def _get_label_counts(self, mask: torch.Tensor):
def _get_label_counts(self, mask: torch.Tensor) -> torch.Tensor:
# Count the number of non-masked labels in each document through cumulative sums.
mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0)
label_count_cumsum = mask_cumsum[length_cumsum]
labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
# Expand to one entry per token: find each token's document index via the sorted
# length cumsum, then look up that document's label count.
# TODO: Document index already computed in `LengthModelInputPreprocessor`.
document_index = torch.searchsorted(
length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right"
)
Expand Down
10 changes: 10 additions & 0 deletions fast_llm/engine/schedule/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ class ScheduleConfig(Config):
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
docs_per_step: int = Field(
default=0,
desc="Target number of documents (rollouts) per optimizer step, globally across all data-parallel ranks. "
"When >0, each training step dynamically accumulates microbatches until the globally all-reduced "
"document count reaches this value, then triggers the optimizer step. "
"depth_first_micro_batches is ignored when this is set. "
"0 = use depth_first_micro_batches as-is (fixed microbatch count per step).",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
breadth_first_micro_batches: int = Field(
default=1,
desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.",
Expand Down
5 changes: 3 additions & 2 deletions fast_llm/engine/schedule/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ def _preprocess_data(
if context.schedule.phase.is_training
else None
)
model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)]
n_micro_batches = context.schedule._eff_sequential_micro_batches
model_inputs = [next(data_iterator) for _ in range(n_micro_batches)]
model_inputs[0][0].share_batch_data(
[model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed
)
Expand All @@ -336,7 +337,7 @@ def _preprocess_data(
extra_kwargs={
"grad_output": grad_output,
"micro_batch": micro_batch,
"num_micro_batches": self._config.sequential_micro_batches,
"num_micro_batches": n_micro_batches,
"micro_batch_splits": self._config.micro_batch_splits,
},
)
Expand Down
38 changes: 25 additions & 13 deletions fast_llm/engine/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,17 @@ def __init__(
batch_meta: list[ModelInput],
distributed_config: DistributedConfig,
phase: PhaseType,
_depth_first_override: int | None = None,
):
super().__init__(config)
self._depth_first_override = _depth_first_override
self._multi_stage = multi_stage
self._distributed_config = distributed_config
self._num_stages = len(self._multi_stage.stages)
self._phase = phase
self._is_training = self._phase.is_training

if self._config.num_inputs < self._distributed_config.pipeline_parallel:
if self._eff_num_inputs < self._distributed_config.pipeline_parallel:
warnings.warn("Not enough input to achieve true pipeline parallelism.")

# Setup the activation metas.
Expand Down Expand Up @@ -155,9 +157,25 @@ def __init__(
def phase(self) -> PhaseType:
return self._phase

@property
def _eff_depth_first(self) -> int:
return (
self._depth_first_override
if self._depth_first_override is not None
else self._config.depth_first_micro_batches
)

@property
def _eff_sequential_micro_batches(self) -> int:
return self._eff_depth_first * self._config.breadth_first_micro_batches

@property
def _eff_num_inputs(self) -> int:
return self._eff_sequential_micro_batches * self._config.micro_batch_splits

@property
def samples_per_batch(self) -> int:
return self._config.sequential_micro_batches * self._distributed_config.batch_data_parallel
return self._eff_sequential_micro_batches * self._distributed_config.batch_data_parallel

def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]:
return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank])
Expand Down Expand Up @@ -189,7 +207,7 @@ def _create_index(self) -> None:
Assert.in_range(
step.index,
0,
self._config.num_inputs,
self._eff_num_inputs,
)
Assert.incl(step.type_, (StepType.forward, StepType.backward))
step.global_index = i
Expand All @@ -205,7 +223,7 @@ def _create_index(self) -> None:
Assert.custom(all, self._device_steps)
# Consistency checks
step_map = self._step_map.copy()
for data_index in range(self._config.num_inputs):
for data_index in range(self._eff_num_inputs):
for type_ in (StepType.forward, StepType.backward):
for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages):
assert (
Expand Down Expand Up @@ -470,14 +488,11 @@ def _create_steps(self) -> tuple[list[Step], int]:
first_grad_stage += 1
else:
first_grad_stage = self._num_stages
for depth_first_micro_batch in range(self._config.depth_first_micro_batches):
for depth_first_micro_batch in range(self._eff_depth_first):
for stage in range(self._num_stages):
for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches):
for micro_batch_split in range(self._config.micro_batch_splits):
micro_batch = (
breadth_first_micro_batch * self._config.depth_first_micro_batches
+ depth_first_micro_batch
)
micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch
steps.append(
Step(
stage=stage,
Expand All @@ -492,10 +507,7 @@ def _create_steps(self) -> tuple[list[Step], int]:
for stage in reversed(range(first_grad_stage, self._num_stages)):
for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches):
for micro_batch_split in reversed(range(self._config.micro_batch_splits)):
micro_batch = (
breadth_first_micro_batch * self._config.depth_first_micro_batches
+ depth_first_micro_batch
)
micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch
steps.append(
Step(
stage=stage,
Expand Down
61 changes: 54 additions & 7 deletions fast_llm/engine/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@ def setup(self, distributed: Distributed, run: Run) -> None:
preprocessing_config = self._multi_stage.get_preprocessing_config(
PhaseType.training, self._config.schedule.micro_batch_splits
)
self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size)
self._schedule_cache: dict[int, Schedule] = {}
self._schedule = Schedule(
config=self._config.schedule,
multi_stage=self._multi_stage,
batch_meta=preprocessing_config.get_input_meta(self._data.config.micro_batch_size),
batch_meta=self._single_mb_meta,
distributed_config=self._config.model.distributed,
phase=PhaseType.training,
)
Expand All @@ -140,6 +142,41 @@ def setup(self, distributed: Distributed, run: Run) -> None:

self._is_setup = True

def _get_or_build_schedule(self, n_microbatches: int) -> Schedule:
if n_microbatches not in self._schedule_cache:
bfmb = self._config.schedule.breadth_first_micro_batches
depth_first = n_microbatches // bfmb
self._schedule_cache[n_microbatches] = Schedule(
config=self._config.schedule,
multi_stage=self._multi_stage,
batch_meta=self._single_mb_meta,
distributed_config=self._config.model.distributed,
phase=PhaseType.training,
_depth_first_override=depth_first,
)
return self._schedule_cache[n_microbatches]

def _prefetch_to_doc_target(self, data_iterator) -> list:
target = self._config.schedule.docs_per_step
bfmb = self._config.schedule.breadth_first_micro_batches
buffer = []
total_docs = 0
while total_docs < target:
mb = next(data_iterator)
mb[0].share_batch_data(mb, self._distributed)
total_docs += mb[0].num_documents_in_batch
buffer.append(mb)
Assert.eq(
len(buffer) % bfmb,
0,
msg=f"Fetched {len(buffer)} microbatches not divisible by breadth_first_micro_batches={bfmb}",
)
# Reset num_documents_in_batch to the step total on all microbatches
for mb in buffer:
for mi in mb:
mi.num_documents_in_batch = total_docs
return buffer

@abc.abstractmethod
def _get_data(self) -> Data:
pass
Expand Down Expand Up @@ -220,12 +257,22 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:

# TODO: Data loader hates getting all micro-batches at once.
# (Also preprocessing adds overhead)
reduced_losses, update_successful, train_metrics = self._runner.run_step(
train_iterator,
self._schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)
if self._config.schedule.docs_per_step > 0:
buffer = self._prefetch_to_doc_target(train_iterator)
step_schedule = self._get_or_build_schedule(len(buffer))
reduced_losses, update_successful, train_metrics = self._runner.run_step(
iter(buffer),
step_schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)
else:
reduced_losses, update_successful, train_metrics = self._runner.run_step(
train_iterator,
self._schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)

# Advanced, skipped, and Nan iterations.
if update_successful:
Expand Down
Loading
Loading