@@ -398,7 +398,9 @@ def __init__(
     update = {
         "sharded": False,
         "shard_dim": None,
-        "grad_sync_op": GRAD_SYNC_OP_SUM,  # sum replicated TP contributions
+        # Row-parallel output uses TP collectives whose backward already gives
+        # replicated B the full output gradient on each TP rank.
+        "grad_sync_op": GRAD_SYNC_OP_NONE,
     }
 )
 self.lora = LoRA(
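
The new comment rests on a standard autograd fact: an all-reduce in the forward pass has an identity backward, so every TP rank receives the full upstream gradient. If the replicated LoRA B also consumes a full (all-reduced) activation, its gradient comes out identical and complete on every rank, and no extra sum is needed. A minimal standalone sketch of that argument, assuming the LoRA intermediate is all-reduced before B is applied; `_AllReduceSum` and `row_parallel_lora_path` are illustrative names, not this repo's API:

```python
import torch
import torch.distributed as dist

class _AllReduceSum(torch.autograd.Function):
    """Forward: sum partial activations over the TP group.
    Backward: identity -- each rank keeps the full upstream gradient."""

    @staticmethod
    def forward(ctx, x):
        x = x.clone()
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
        return x

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

def row_parallel_lora_path(x_shard, lora_a_shard, lora_b):
    # x_shard:      [batch, in_features // tp]  input shard on this rank
    # lora_a_shard: [r, in_features // tp]      A sharded like the base weight
    # lora_b:       [out_features, r]           B replicated on every TP rank
    partial = x_shard @ lora_a_shard.t()   # partial rank-r activation
    full = _AllReduceSum.apply(partial)    # identical full activation everywhere
    # B multiplies a full activation and, in backward, receives the full output
    # gradient (all-reduce backward is identity), so B.grad is already identical
    # on all TP ranks and needs no extra GRAD_SYNC_OP_SUM.
    return full @ lora_b.t()
```

Under that assumption, keeping GRAD_SYNC_OP_SUM would have summed identical gradients across the TP group, scaling B's gradient by the TP world size.
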
@@ -689,7 +691,9 @@ def __init__(
     "sharded": False,
     "shard_dim": None,
     "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN,
-    "grad_sync_op": GRAD_SYNC_OP_SUM,  # we handle this with extended finalize_grads
+    # Expert row-parallel output follows the same pattern: replicated B
+    # already sees the full gradient from the backward TP collective.
+    "grad_sync_op": GRAD_SYNC_OP_NONE,
 }
 )
 self.lora = LoRA(
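
The removed comment mentioned handling the sum in an extended finalize_grads. A hypothetical sketch of how such a pass might consume this metadata, skipping parameters now marked GRAD_SYNC_OP_NONE; the `meta` attribute, `get_group` helper, and constant values are assumptions for illustration, not this repo's API:

```python
import torch.distributed as dist

# Illustrative stand-ins for the constants referenced in the diff.
GRAD_SYNC_OP_NONE = "none"
GRAD_SYNC_OP_SUM = "sum"

def finalize_grads(params, get_group):
    """Hypothetical grad-finalization pass driven by per-param metadata."""
    for p in params:
        meta = p.meta  # dict shaped like the `update` payloads above (assumed)
        if p.grad is None or meta.get("grad_sync_op") != GRAD_SYNC_OP_SUM:
            continue  # replicated B with GRAD_SYNC_OP_NONE is skipped here
        # Only parameters whose per-rank grads are partial contributions
        # still need an explicit sum over their sync domain.
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM,
                        group=get_group(meta["grad_sync_domain"]))
```
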