Commit 3d8d1f5

Fix Megatron row-parallel LoRA grad sync

1 parent f6cd445

1 file changed: src/art/megatron/lora.py (6 additions, 2 deletions)
@@ -398,7 +398,9 @@ def __init__(
             update={
                 "sharded": False,
                 "shard_dim": None,
-                "grad_sync_op": GRAD_SYNC_OP_SUM,  # sum replicated TP contributions
+                # Row-parallel output uses TP collectives whose backward already gives
+                # replicated B the full output gradient on each TP rank.
+                "grad_sync_op": GRAD_SYNC_OP_NONE,
             }
         )
         self.lora = LoRA(
@@ -689,7 +691,9 @@ def __init__(
                 "sharded": False,
                 "shard_dim": None,
                 "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN,
-                "grad_sync_op": GRAD_SYNC_OP_SUM,  # we handle this with extended finalize_grads
+                # Expert row-parallel output follows the same pattern: replicated B
+                # already sees the full gradient from the backward TP collective.
+                "grad_sync_op": GRAD_SYNC_OP_NONE,
             }
         )
         self.lora = LoRA(
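
Why GRAD_SYNC_OP_NONE is the right setting here: in a row-parallel layer the output is formed by a tensor-parallel all-reduce, and the backward of that all-reduce hands every TP rank the full output gradient. A replicated LoRA B therefore already accumulates its complete gradient locally, and an extra SUM across TP ranks would scale it by the TP world size. The single-process sketch below emulates a TP group of two ranks to illustrate this. It is illustrative only, not the repo's code: the names tp, X_shards, A_shards, and dB_per_rank are invented here, and an explicit Python sum stands in for the forward all-reduce.

    # Minimal sketch (assumed names, single process): why a replicated LoRA B on a
    # row-parallel layer needs no gradient all-reduce across TP ranks.
    import torch

    torch.manual_seed(0)
    tp = 2                   # emulated tensor-parallel world size
    X = torch.randn(4, 8)    # full activation (batch 4, hidden 8)
    A = torch.randn(8, 2)    # LoRA A, row-sharded across TP ranks like the base weight
    B0 = torch.randn(2, 6)   # LoRA B, replicated on every rank

    X_shards = X.chunk(tp, dim=1)   # each rank sees a slice of the hidden dim
    A_shards = A.chunk(tp, dim=0)   # and the matching slice of A's rows

    dB_per_rank = []
    for rank in range(tp):
        B = B0.clone().requires_grad_(True)
        # Forward: each rank computes a partial X_i @ A_i; the all-reduce (emulated
        # by the explicit sum) leaves the full low-rank activation Z on every rank.
        Z = sum(X_shards[i] @ A_shards[i] for i in range(tp))
        Y = Z @ B
        # Backward of the output all-reduce replicates the full dY on every rank.
        Y.backward(torch.ones_like(Y))
        dB_per_rank.append(B.grad)

    # Reference gradient computed without any sharding.
    B = B0.clone().requires_grad_(True)
    Y_ref = (X @ A) @ B
    Y_ref.backward(torch.ones_like(Y_ref))
    dB_full = B.grad

    # Every emulated rank already holds the complete gradient for B, so a further
    # SUM over TP ranks would multiply it by tp.
    assert torch.allclose(dB_per_rank[0], dB_per_rank[1])
    assert torch.allclose(dB_per_rank[0], dB_full, atol=1e-5)

Under this layout the previous GRAD_SYNC_OP_SUM would have multiplied B's gradient by the TP size, which is the mismatch the commit removes.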
