[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear#3052
allenphilipj wants to merge 3 commits into
Signed-off-by: allenphilipj <allenphilipj@users.noreply.github.com>
937ef34 to 80304fa
Greptile Summary
This PR fixes a missing propagation of the
Confidence Score: 5/5
Safe to merge: the change is a small, targeted fix that mirrors logic already present in three sibling modules.
The fix is a direct port of the same four-line block from Linear, LayerNormLinear, and LayerNormMLP. The previously hardcoded None caused the FP8 weight-update skip tensor to be silently dropped during CUDA-graph capture for GroupedLinear. The surrounding quantize_weight call already accepts and correctly handles the tensor at line 205, so the fix slots in without any structural changes. The accompanying test checks both that the propagated tensor is the same object that was supplied and that is_first_microbatch is overridden, with named constants documenting the tuple layout.
No files require special attention.
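To make the described regression concrete, here is a minimal sketch in plain Python that contrasts the hardcoded None with the propagated tensor. It does not use Transformer Engine. The function names, the two-element tuple, and the index constants are hypothetical stand-ins for the real argument packing, and leaving is_first_microbatch untouched in the "before" case is an assumption.

```python
from typing import Any, Optional, Tuple

# Hypothetical layout of the packed non-tensor arguments (illustrative only;
# the real tuple in GroupedLinear has a different, longer layout).
IS_FIRST_MICROBATCH_IDX = 0
SKIP_FP8_WEIGHT_UPDATE_IDX = 1


def pack_args_before_fix(
    is_first_microbatch: Optional[bool], skip_tensor: Optional[Any]
) -> Tuple[Optional[bool], Optional[Any]]:
    """Behaviour described as the bug: the skip flag is hardcoded to None."""
    return (is_first_microbatch, None)


def pack_args_after_fix(
    is_first_microbatch: Optional[bool], skip_tensor: Optional[Any]
) -> Tuple[Optional[bool], Optional[Any]]:
    """Behaviour described as the fix: forward the tensor unchanged and
    force is_first_microbatch to False whenever a skip tensor is present."""
    if skip_tensor is not None:
        is_first_microbatch = False
    return (is_first_microbatch, skip_tensor)


# A plain object stands in for the FP8 weight-update skip tensor.
skip_tensor = object()

before = pack_args_before_fix(True, skip_tensor)
after = pack_args_after_fix(True, skip_tensor)

# Before the fix the tensor never reaches the inner function.
assert before[SKIP_FP8_WEIGHT_UPDATE_IDX] is None
# After the fix the very same object is forwarded (identity, not equality).
assert after[SKIP_FP8_WEIGHT_UPDATE_IDX] is skip_tensor
# The first-microbatch flag is overridden when a skip tensor is present.
assert after[IS_FIRST_MICROBATCH_IDX] is False
```

Indexing through named constants rather than bare integers keeps the assertions readable if the tuple layout changes, which appears to be the motivation for the constants mentioned in the review.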
Sequence Diagram
sequenceDiagram
    participant Caller
    participant GroupedLinear
    participant FP8GlobalStateManager
    participant _GroupedLinear
    Caller->>GroupedLinear: forward(inp, m_splits, is_first_microbatch)
    GroupedLinear->>FP8GlobalStateManager: fp8_graph_capturing()
    alt Graph capture in progress
        FP8GlobalStateManager-->>GroupedLinear: True
        GroupedLinear->>FP8GlobalStateManager: quantization_state.skip_fp8_weight_update_tensor
        FP8GlobalStateManager-->>GroupedLinear: skip_tensor (not None)
        GroupedLinear->>GroupedLinear: "override is_first_microbatch = False"
    else Normal execution
        FP8GlobalStateManager-->>GroupedLinear: False
        GroupedLinear->>GroupedLinear: "skip_fp8_weight_update = None"
    end
    GroupedLinear->>_GroupedLinear: forward(..., non_tensor_args[skip_fp8_weight_update], ...)
    _GroupedLinear->>_GroupedLinear: "quantize_weight(skip_update_flag=skip_fp8_weight_update)"
    _GroupedLinear-->>GroupedLinear: output
    GroupedLinear-->>Caller: output
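The branch in the diagram can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code from the PR: StateManagerStub and QuantizationState are hypothetical stand-ins for FP8GlobalStateManager and its state object, and resolve_skip_flag is an invented helper name.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass
class QuantizationState:
    """Hypothetical stand-in for the state object named in the diagram."""
    graph_capturing: bool = False
    skip_fp8_weight_update_tensor: Optional[Any] = None


class StateManagerStub:
    """Hypothetical stand-in for FP8GlobalStateManager (not the real class)."""
    quantization_state = QuantizationState()

    @classmethod
    def fp8_graph_capturing(cls) -> bool:
        return cls.quantization_state.graph_capturing


def resolve_skip_flag(
    is_first_microbatch: Optional[bool],
) -> Tuple[Optional[bool], Optional[Any]]:
    """Return (is_first_microbatch, skip_fp8_weight_update) per the diagram."""
    if StateManagerStub.fp8_graph_capturing():
        # Graph capture in progress: pick up the skip tensor from global state.
        skip = StateManagerStub.quantization_state.skip_fp8_weight_update_tensor
    else:
        # Normal execution: there is nothing to propagate.
        skip = None
    if skip is not None:
        # With a skip tensor present, the first-microbatch hint is overridden.
        is_first_microbatch = False
    return is_first_microbatch, skip


# Normal execution: the caller's hint passes through untouched.
StateManagerStub.quantization_state = QuantizationState(False, None)
normal_result = resolve_skip_flag(True)

# Graph capture: the tensor is propagated and the hint is forced to False.
sentinel = object()  # stands in for the real skip tensor
StateManagerStub.quantization_state = QuantizationState(True, sentinel)
capture_result = resolve_skip_flag(True)
```

In this sketch the returned pair is what would be handed on to the inner forward, where the skip flag reaches quantize_weight as skip_update_flag.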
Reviews (3): Last reviewed commit: "Merge branch 'main' into codex-grouped-l..."
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: allenphilipj <allen.philip@intercom.io>
/te-ci pytorch
Summary:
Validation:
Fixes #3051