@@ -352,35 +352,54 @@ def _zero_contribution_inputs(template: PackedTensors) -> PackedTensors:
     return dummy


+def resolve_global_grad_accumulation_sequences(
+    global_grad_accumulation_sequences: int | None,
+) -> int:
+    dp_world_size = ps.get_data_parallel_world_size()
+    if global_grad_accumulation_sequences is None:
+        return dp_world_size
+    return global_grad_accumulation_sequences
+
+
 def resolve_local_grad_accumulation_sequences(
-    global_grad_accumulation_sequences: int,
+    global_grad_accumulation_sequences: int | None,
 ) -> int:
+    resolved_global_grad_accumulation_sequences = (
+        resolve_global_grad_accumulation_sequences(
+            global_grad_accumulation_sequences=global_grad_accumulation_sequences
+        )
+    )
     dp_world_size = ps.get_data_parallel_world_size()
     if (
-        global_grad_accumulation_sequences <= 0
-        or global_grad_accumulation_sequences % dp_world_size != 0
+        resolved_global_grad_accumulation_sequences <= 0
+        or resolved_global_grad_accumulation_sequences % dp_world_size != 0
     ):
         raise RuntimeError(
             "Invalid global grad accumulation / DP world size combination: "
-            f"global_grad_accumulation_sequences={global_grad_accumulation_sequences}, "
+            f"global_grad_accumulation_sequences={resolved_global_grad_accumulation_sequences}, "
             f"dp_world_size={dp_world_size}"
         )
-    return global_grad_accumulation_sequences // dp_world_size
+    return resolved_global_grad_accumulation_sequences // dp_world_size


 def build_micro_sample_indices(
     step_index: int,
     num_sequences: int,
-    global_grad_accumulation_sequences: int,
+    global_grad_accumulation_sequences: int | None,
 ) -> list[int | None]:
     dp_rank = ps.get_data_parallel_rank()
+    resolved_global_grad_accumulation_sequences = (
+        resolve_global_grad_accumulation_sequences(
+            global_grad_accumulation_sequences=global_grad_accumulation_sequences
+        )
+    )
     dp_world_size = ps.get_data_parallel_world_size()
     local_grad_accumulation_sequences = resolve_local_grad_accumulation_sequences(
-        global_grad_accumulation_sequences=global_grad_accumulation_sequences,
+        global_grad_accumulation_sequences=resolved_global_grad_accumulation_sequences,
     )
-    base_global_sample_index = step_index * global_grad_accumulation_sequences
+    base_global_sample_index = step_index * resolved_global_grad_accumulation_sequences
     global_step_indices: list[int | None] = []
-    for offset in range(global_grad_accumulation_sequences):
+    for offset in range(resolved_global_grad_accumulation_sequences):
         global_sample_index = base_global_sample_index + offset
         global_step_indices.append(
             global_sample_index if global_sample_index < num_sequences else None
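To make the new resolution semantics concrete, here is a minimal, self-contained sketch, not part of the PR: the `_StubPS` class and the shortened function names are stand-ins for illustration only.

# Minimal sketch of the resolution semantics introduced above. `_StubPS`
# stands in for the real `ps` parallel-state module; names are hypothetical.
class _StubPS:
    @staticmethod
    def get_data_parallel_world_size() -> int:
        return 4

ps = _StubPS()

def resolve_global(n: int | None) -> int:
    # None falls back to one sequence per data-parallel rank.
    return ps.get_data_parallel_world_size() if n is None else n

def resolve_local(n: int | None) -> int:
    resolved = resolve_global(n)
    dp = ps.get_data_parallel_world_size()
    if resolved <= 0 or resolved % dp != 0:
        raise RuntimeError(f"invalid: global={resolved}, dp_world_size={dp}")
    return resolved // dp

assert resolve_global(None) == 4   # default: dp_world_size
assert resolve_local(8) == 2       # 8 sequences over 4 ranks -> 2 per rank
try:
    resolve_local(6)               # 6 % 4 != 0 -> rejected
except RuntimeError:
    pass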
@@ -479,10 +498,15 @@ def run_training_step(
         micro_sample_indices = [sample_index]

     if moe_routing_replay_controller is not None:
+        resolved_global_grad_accumulation_sequences = (
+            resolve_global_grad_accumulation_sequences(
+                config.grad_accumulation_sequences
+            )
+        )
         moe_routing_replay_controller.set_step(
             step_index=step_index,
             sample_index=micro_sample_indices,
-            global_grad_accumulation_sequences=config.grad_accumulation_sequences,
+            global_grad_accumulation_sequences=resolved_global_grad_accumulation_sequences,
         )

     device = next(model_chunks[0].parameters()).device
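The call-site change above means the controller never sees None. A hypothetical stand-alone illustration; `Config`, `DummyController`, and `resolve` are stand-ins, not part of this PR:

# Hypothetical illustration of the call-site change: with
# grad_accumulation_sequences=None in the config, the controller now
# receives the resolved dp_world_size rather than None.
from dataclasses import dataclass

@dataclass
class Config:
    grad_accumulation_sequences: int | None = None

class DummyController:
    def set_step(self, step_index, sample_index, global_grad_accumulation_sequences):
        assert global_grad_accumulation_sequences is not None

def resolve(n: int | None, dp_world_size: int = 4) -> int:
    return dp_world_size if n is None else n

DummyController().set_step(
    step_index=0,
    sample_index=[0],
    global_grad_accumulation_sequences=resolve(Config().grad_accumulation_sequences),
)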
@@ -532,6 +556,7 @@ def run_training_step(
     if new_logprobs is None or raw_loss_sum is None:
         raise RuntimeError("run_training_step did not produce outputs")

+    # num_tokens is reduced in place across ranks by finalize_model_grads().
     finalize_model_grads_extended(model_chunks, num_tokens=num_tokens)
     update_successful, grad_norm, num_zeros_in_grad = _optimizer_step(
         optimizer,
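For readers unfamiliar with the in-place reduction the new comment refers to, here is a minimal sketch of the general pattern using torch.distributed; it is an illustrative assumption, not the actual body of finalize_model_grads().

# Sketch of an in-place cross-rank reduction like the one the comment
# describes: after all_reduce, every rank's num_tokens tensor holds the
# global sum. Illustrative only.
import torch
import torch.distributed as dist

def reduce_num_tokens_in_place(num_tokens: torch.Tensor) -> None:
    if dist.is_initialized():
        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM)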