Skip to content

Commit 2e64da0

Browse files
committed
Default Megatron grad accumulation by DP size
1 parent 511d72c commit 2e64da0

4 files changed

Lines changed: 43 additions & 13 deletions

File tree

src/art/local/backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,9 @@ async def _train_model(
805805
packed_tensors, f"{get_model_dir(model=model, art_path=self._path)}/tensors"
806806
)
807807
# Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train())
808-
grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences))
808+
grad_accumulation_sequences = max(
809+
1, int(config.grad_accumulation_sequences or 1)
810+
)
809811
estimated_gradient_steps = math.ceil(
810812
disk_packed_tensors["num_sequences"] / grad_accumulation_sequences
811813
)

src/art/megatron/shared.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
configure_moe_routing_replay,
2929
load_adapter_into_model,
3030
print0,
31+
resolve_global_grad_accumulation_sequences,
3132
run_training_step,
3233
select_indexed_inputs,
3334
select_micro_inputs,
@@ -119,7 +120,9 @@ def run_megatron_rl_job(
119120
template = _clone_packed_tensors(select_indexed_inputs(packed_tensors, 0))
120121
zero_template = _zero_contribution_inputs(template)
121122
num_sequences = job.disk_packed_tensors["num_sequences"]
122-
global_grad_accumulation_sequences = job.config.grad_accumulation_sequences
123+
global_grad_accumulation_sequences = resolve_global_grad_accumulation_sequences(
124+
job.config.grad_accumulation_sequences
125+
)
123126
num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences)
124127
for step_index in range(num_steps):
125128
micro_indices = build_micro_sample_indices(

src/art/megatron/train.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -352,35 +352,54 @@ def _zero_contribution_inputs(template: PackedTensors) -> PackedTensors:
352352
return dummy
353353

354354

355+
def resolve_global_grad_accumulation_sequences(
356+
global_grad_accumulation_sequences: int | None,
357+
) -> int:
358+
dp_world_size = ps.get_data_parallel_world_size()
359+
if global_grad_accumulation_sequences is None:
360+
return dp_world_size
361+
return global_grad_accumulation_sequences
362+
363+
355364
def resolve_local_grad_accumulation_sequences(
356-
global_grad_accumulation_sequences: int,
365+
global_grad_accumulation_sequences: int | None,
357366
) -> int:
367+
resolved_global_grad_accumulation_sequences = (
368+
resolve_global_grad_accumulation_sequences(
369+
global_grad_accumulation_sequences=global_grad_accumulation_sequences
370+
)
371+
)
358372
dp_world_size = ps.get_data_parallel_world_size()
359373
if (
360-
global_grad_accumulation_sequences <= 0
361-
or global_grad_accumulation_sequences % dp_world_size != 0
374+
resolved_global_grad_accumulation_sequences <= 0
375+
or resolved_global_grad_accumulation_sequences % dp_world_size != 0
362376
):
363377
raise RuntimeError(
364378
"Invalid global grad accumulation / DP world size combination: "
365-
f"global_grad_accumulation_sequences={global_grad_accumulation_sequences}, "
379+
f"global_grad_accumulation_sequences={resolved_global_grad_accumulation_sequences}, "
366380
f"dp_world_size={dp_world_size}"
367381
)
368-
return global_grad_accumulation_sequences // dp_world_size
382+
return resolved_global_grad_accumulation_sequences // dp_world_size
369383

370384

371385
def build_micro_sample_indices(
372386
step_index: int,
373387
num_sequences: int,
374-
global_grad_accumulation_sequences: int,
388+
global_grad_accumulation_sequences: int | None,
375389
) -> list[int | None]:
376390
dp_rank = ps.get_data_parallel_rank()
391+
resolved_global_grad_accumulation_sequences = (
392+
resolve_global_grad_accumulation_sequences(
393+
global_grad_accumulation_sequences=global_grad_accumulation_sequences
394+
)
395+
)
377396
dp_world_size = ps.get_data_parallel_world_size()
378397
local_grad_accumulation_sequences = resolve_local_grad_accumulation_sequences(
379-
global_grad_accumulation_sequences=global_grad_accumulation_sequences,
398+
global_grad_accumulation_sequences=resolved_global_grad_accumulation_sequences,
380399
)
381-
base_global_sample_index = step_index * global_grad_accumulation_sequences
400+
base_global_sample_index = step_index * resolved_global_grad_accumulation_sequences
382401
global_step_indices: list[int | None] = []
383-
for offset in range(global_grad_accumulation_sequences):
402+
for offset in range(resolved_global_grad_accumulation_sequences):
384403
global_sample_index = base_global_sample_index + offset
385404
global_step_indices.append(
386405
global_sample_index if global_sample_index < num_sequences else None
@@ -479,10 +498,15 @@ def run_training_step(
479498
micro_sample_indices = [sample_index]
480499

481500
if moe_routing_replay_controller is not None:
501+
resolved_global_grad_accumulation_sequences = (
502+
resolve_global_grad_accumulation_sequences(
503+
config.grad_accumulation_sequences
504+
)
505+
)
482506
moe_routing_replay_controller.set_step(
483507
step_index=step_index,
484508
sample_index=micro_sample_indices,
485-
global_grad_accumulation_sequences=config.grad_accumulation_sequences,
509+
global_grad_accumulation_sequences=resolved_global_grad_accumulation_sequences,
486510
)
487511

488512
device = next(model_chunks[0].parameters()).device
@@ -532,6 +556,7 @@ def run_training_step(
532556
if new_logprobs is None or raw_loss_sum is None:
533557
raise RuntimeError("run_training_step did not produce outputs")
534558

559+
# num_tokens is reduced in place across ranks by finalize_model_grads().
535560
finalize_model_grads_extended(model_chunks, num_tokens=num_tokens)
536561
update_successful, grad_norm, num_zeros_in_grad = _optimizer_step(
537562
optimizer,

src/art/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class TrainConfig(pydantic.BaseModel):
1818
learning_rate: float = 5e-6
1919
kl_penalty_coef: float = 0.0
20-
grad_accumulation_sequences: int = pydantic.Field(default=1, ge=1)
20+
grad_accumulation_sequences: int | None = pydantic.Field(default=None, ge=1)
2121

2222

2323
class TrainSFTConfig(pydantic.BaseModel):

0 commit comments

Comments (0)