Skip to content

[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear#3052

Open
allenphilipj wants to merge 3 commits into
NVIDIA:mainfrom
allenphilipj:codex-grouped-linear-fp8-cudagraph-skip
Open

[PyTorch] Propagate FP8 graph weight update flag in GroupedLinear#3052
allenphilipj wants to merge 3 commits into
NVIDIA:mainfrom
allenphilipj:codex-grouped-linear-fp8-cudagraph-skip

Conversation

@allenphilipj
Copy link
Copy Markdown

Summary:

  • Propagate the FP8 graph-capture skip_fp8_weight_update tensor through GroupedLinear.
  • Align GroupedLinear graph-capture handling with Linear, LayerNormLinear, and LayerNormMLP.
  • Add a focused regression test for the forwarded skip tensor and graph-compatible is_first_microbatch behavior.

Validation:

  • git diff --check
  • python3 -m py_compile transformer_engine/pytorch/module/grouped_linear.py tests/pytorch/test_cuda_graphs.py
  • Not run: focused pytest, because pytest is not installed in this local environment.

Fixes #3051

@allenphilipj allenphilipj requested a review from ksivaman as a code owner May 28, 2026 12:36
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 28, 2026
Signed-off-by: allenphilipj <allenphilipj@users.noreply.github.com>
@allenphilipj allenphilipj force-pushed the codex-grouped-linear-fp8-cudagraph-skip branch from 937ef34 to 80304fa Compare May 28, 2026 12:40
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR fixes a missing propagation of the skip_fp8_weight_update_tensor flag in GroupedLinear.forward, bringing it in line with Linear, LayerNormLinear, and LayerNormMLP. Previously, grouped_linear.py always passed None for skip_fp8_weight_update to _GroupedLinear, so weight-update skipping during CUDA-graph capture was silently ignored for grouped GEMMs.

  • grouped_linear.py: Reads FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor when fp8_graph_capturing() is True, overrides is_first_microbatch to False (same logic as peers), and passes the tensor instead of None into non_tensor_args.
  • test_cuda_graphs.py: Adds a focused unit test that monkeypatches the graph-capture state, intercepts _GroupedLinear.forward, and asserts both the is_first_microbatch override and the propagated skip tensor are correct.

Confidence Score: 5/5

Safe to merge — the change is a targeted one-liner fix that mirrors identical logic already present in three sibling modules.

The fix is a direct port of the same four-line block from Linear, LayerNormLinear, and LayerNormMLP. The previously hardcoded None caused the FP8 weight-update skip tensor to be silently dropped during CUDA-graph capture for GroupedLinear, but the surrounding quantize_weight call already accepts and correctly handles the tensor at line 205 — so the fix slots in without any structural changes. The accompanying test exercises both the propagated tensor identity and the is_first_microbatch override, with named constants documenting the tuple layout.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/grouped_linear.py Added graph-capture skip-tensor propagation identical to Linear/LayerNormLinear/LayerNormMLP; replaces hardcoded None with the actual tensor and overrides is_first_microbatch to False.
tests/pytorch/test_cuda_graphs.py New regression test verifies skip-tensor propagation and is_first_microbatch override via monkeypatching; uses named index constants to document and anchor the non_tensor_args layout.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant GroupedLinear
    participant FP8GlobalStateManager
    participant _GroupedLinear

    Caller->>GroupedLinear: forward(inp, m_splits, is_first_microbatch)
    GroupedLinear->>FP8GlobalStateManager: fp8_graph_capturing()
    alt Graph capture in progress
        FP8GlobalStateManager-->>GroupedLinear: True
        GroupedLinear->>FP8GlobalStateManager: quantization_state.skip_fp8_weight_update_tensor
        FP8GlobalStateManager-->>GroupedLinear: skip_tensor (not None)
        GroupedLinear->>GroupedLinear: "override is_first_microbatch = False"
    else Normal execution
        FP8GlobalStateManager-->>GroupedLinear: False
        GroupedLinear->>GroupedLinear: "skip_fp8_weight_update = None"
    end
    GroupedLinear->>_GroupedLinear: forward(..., non_tensor_args[skip_fp8_weight_update], ...)
    _GroupedLinear->>_GroupedLinear: "quantize_weight(skip_update_flag=skip_fp8_weight_update)"
    _GroupedLinear-->>GroupedLinear: output
    GroupedLinear-->>Caller: output
Loading

Reviews (3): Last reviewed commit: "Merge branch 'main' into codex-grouped-l..." | Re-trigger Greptile

Comment thread tests/pytorch/test_cuda_graphs.py Outdated
allenphilipj and others added 2 commits May 28, 2026 13:54
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: allenphilipj <allen.philip@intercom.io>
@ksivaman
Copy link
Copy Markdown
Member

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[PyTorch] GroupedLinear does not propagate skip_fp8_weight_update during FP8 CUDA graph capture

2 participants