
[JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x #3056

Open
KshitijLakhani wants to merge 9 commits into NVIDIA:main from KshitijLakhani:klakhani/feat/enable-headdim256-bwd-sm100


Collaborator

@KshitijLakhani KshitijLakhani commented May 28, 2026

Description

Support for D=256 BWD on Blackwell (CC 10.x) via the C++ API, which TE uses, was added in cuDNN 9.23 and cuDNN FE 1.24. This PR enables that support in TE attention.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Add a guard when picking the backend (sub-backend) in TE common.
Add tests for the D=256 case in TE PyT and TE JAX.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/feat/enable-headdim256-bwd-sm100 branch from 51ad582 to d177ecf Compare May 28, 2026 23:05
@KshitijLakhani KshitijLakhani changed the title [JAX] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x [JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x May 28, 2026
pre-commit-ci Bot and others added 4 commits May 28, 2026 23:06
…n fused attn

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani marked this pull request as ready for review May 29, 2026 22:56
@KshitijLakhani KshitijLakhani requested a review from cyanguwa as a code owner May 29, 2026 22:56
@KshitijLakhani KshitijLakhani self-assigned this May 29, 2026
Contributor

greptile-apps Bot commented May 29, 2026

Greptile Summary

This PR enables cuDNN 9.23 / FE 1.24's dedicated deterministic SDPA backward kernel for head-dim 256 on Blackwell (SM10.x / CC 10.x) GPUs. The C++ backend selector, JAX test skip guards, and PyTorch test cases are all updated in concert.

  • fused_attn.cpp adds a new compound condition that selects the new BWD path only when d_qk == d_v == 256, is_training, sm_arch ∈ [100, 110), cuDNN ≥ 9.23, non-paged layout, no bias, no dropout, vanilla softmax, and the window-size rule (full window OR causal mask with right window −1 or 0).
  • Both the JAX and PyTorch tests add matching guards and new parametrized cases; the PyTorch comment block contains a duplicated trailing sentence that should be removed, and the JAX skip logic for non-1HSS bias configs diverges slightly from the C++ gate (those configs won't use the new kernel path but will not fail).
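
To make the compound condition above easier to follow, here is a minimal Python sketch of the same predicate. It is not the code in fused_attn.cpp: the `AttnConfig` fields and the `d256_bwd_gate` function are hypothetical names chosen for illustration, and the logic is assumed only from the bullet above.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class AttnConfig:
    """Hypothetical stand-in for the inputs the backend selector inspects."""
    d_qk: int
    d_v: int
    is_training: bool
    sm_arch: int           # e.g. 100 for CC 10.0
    cudnn_version: int     # e.g. 92300 for cuDNN 9.23.0
    paged_kv: bool
    has_bias: bool
    dropout: float
    vanilla_softmax: bool
    causal_mask: bool
    window: tuple          # (left, right); -1 means unbounded


def d256_bwd_gate(cfg: AttnConfig) -> bool:
    """Return True when the summarized D=256 BWD conditions all hold."""
    shape_and_hw = (
        cfg.d_qk == 256
        and cfg.d_v == 256
        and cfg.is_training
        and 100 <= cfg.sm_arch < 110
        and cfg.cudnn_version >= 92300
    )
    features = (
        not cfg.paged_kv
        and not cfg.has_bias
        and cfg.dropout == 0.0
        and cfg.vanilla_softmax
    )
    left, right = cfg.window
    # Full window, or a causal mask whose right window is -1 or 0.
    window_ok = (left == -1 and right == -1) or (
        cfg.causal_mask and right in (-1, 0)
    )
    return shape_and_hw and features and window_ok


base = AttnConfig(256, 256, True, 100, 92300, False, False, 0.0, True,
                  False, (-1, -1))
print(d256_bwd_gate(base))                        # baseline config
print(d256_bwd_gate(replace(base, has_bias=True)))  # bias disables the path
```

Writing the gate as three named groups (shape and hardware, feature flags, window rule) keeps each restriction separately testable, which is useful when comparing it against the test-side skip guards.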

Confidence Score: 4/5

The C++ backend gate is additive and well-guarded; existing paths are unchanged and the new path only activates under a very specific combination of hardware, cuDNN version, and kernel parameters.

The core logic in fused_attn.cpp looks correct and conservative. The two test-side issues (a duplicate comment line and a JAX skip-guard that allows some bias configs to pass without actually exercising the new kernel) are minor and do not affect production correctness. No existing behavior is altered.

The JAX test's _check_configs bias skip logic (tests/jax/test_fused_attn.py, lines 465-475) deserves a second look to ensure the skip conditions exactly match the C++ gate.

Important Files Changed

Filename Overview
transformer_engine/common/fused_attn/fused_attn.cpp Adds a new head-dim guard condition enabling the cuDNN 9.23+ deterministic D=256 BWD path on SM10.x. The condition correctly restricts to d_qk == d_v == 256, is_training, sm_arch [100, 110), cuDNN >= 92300, non-paged layout, no bias, no dropout, vanilla softmax, and the appropriate window-size rule (full window OR causal mask with right window -1/0).
tests/pytorch/attention/test_attention.py Adds test_dpa_fused_attn_hdim256 gated on cuDNN >= (9,23,0) and SM100/SM103. Tests no-mask, padding, causal+SWA, and padding-causal+GQA variants. Contains a duplicated trailing comment fragment.
tests/jax/test_fused_attn.py Adds skip guards in FusedAttnRunner._check_configs for D=256 BWD on SM10.x (cuDNN version, head dim symmetry, dropout, softmax type, window size). The skip logic for non-1HSS bias diverges from the C++ gate, meaning some configs will silently use a different backend.
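
The overview mentions the version both as the tuple (9,23,0) and as the integer 92300. The sketch below shows how the two forms relate, assuming the cuDNN 9 encoding major * 10000 + minor * 100 + patch, which is consistent with 9.23 appearing as 92300 above. Both helper names are hypothetical and are not TE functions.

```python
def cudnn_version_int(major: int, minor: int, patch: int) -> int:
    """Pack a (major, minor, patch) triple into a single integer.

    Assumes the cuDNN 9 style encoding; older major versions may differ.
    """
    return major * 10000 + minor * 100 + patch


def meets_minimum(version: tuple, minimum: tuple = (9, 23, 0)) -> bool:
    """Compare version triples lexicographically, as a test guard might."""
    return tuple(version) >= tuple(minimum)


print(cudnn_version_int(9, 23, 0))
print(meets_minimum((9, 22, 1)))
```

Tuple comparison and the packed integer order versions identically as long as minor and patch stay below 100.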

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_get_fused_attn_backend called] --> B{dtype FP16/BF16?}
    B -- No --> Z[Other backend]
    B -- Yes --> C{head_dim_qk/v <= 128?}
    C -- Yes --> ARB[NVTE_F16_Arbitrary backend]
    C -- No --> D{d_qk==d_v==256 AND is_training\nAND sm_arch in 100-109\nAND cuDNN >= 9.23?}
    D -- Yes --> E{bias==NO_BIAS\ndropout==0\nsoftmax==VANILLA?}
    E -- No --> SKIP[Skip: Fall through to next condition]
    E -- Yes --> F{window_size == -1,-1\nOR causal mask + right_win in -1,0?}
    F -- No --> SKIP
    F -- Yes --> ARB
    D -- No --> G{d_qk/v <= 256 AND Hopper?\nOR Blackwell fprop?\nOR other existing rules}
    G -- Yes --> ARB
    G -- No --> SKIP

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +382 to +383
# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
# (for non-causal masks) full-window attention.

P2 Duplicated comment fragment

The comment block ends with a repeated phrase: line 383 (# (for non-causal masks) full-window attention.) is a verbatim fragment of line 382, left over from editing. It should be removed.

Suggested change
- # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
- # (for non-causal masks) full-window attention.
+ # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.


Comment on lines +465 to +475
# Non-learnable bias is fine (bias is allowed as an input); only dBias is
# unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s]
# (see test_backward), so gate on that.
unsupported = None
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
    unsupported = "pre-scale bias"
elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
    unsupported = (
        "bias gradients (dBias); frozen/non-learnable bias inputs"
        " (i.e. non-1HSS bias shapes) are supported"
    )

P2 JAX skip logic diverges from C++ backend gate for non-1HSS bias

The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in fused_attn.cpp requires bias_type == NVTE_NO_BIAS for the new D=256 BWD path, meaning any config with attn_bias_type != NO_BIAS && bias_shape != _1HSS will silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it.
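
The divergence described above can be made concrete with a small self-contained sketch. The enums below are stand-ins that only mimic the names quoted in the review, `quoted_skip_reason` mirrors the quoted skip logic, and `cpp_gate_allows` encodes the C++ gate only as the review describes it (NO_BIAS required). None of this is the actual test or backend code.

```python
from enum import Enum


class AttnBiasType(Enum):
    """Stand-in for the bias-type enum named in the quoted test code."""
    NO_BIAS = "no_bias"
    PRE_SCALE_BIAS = "pre_scale_bias"
    POST_SCALE_BIAS = "post_scale_bias"


class BiasShape(Enum):
    """Stand-in for the bias-shape enum; member set is an assumption."""
    _1HSS = "1HSS"
    _B1SS = "B1SS"
    _BHSS = "BHSS"
    _11SS = "11SS"


def quoted_skip_reason(bias_type, bias_shape):
    """Mirror of the quoted skip logic: a reason string, or None to run."""
    if bias_type == AttnBiasType.PRE_SCALE_BIAS:
        return "pre-scale bias"
    if bias_type != AttnBiasType.NO_BIAS and bias_shape == BiasShape._1HSS:
        return "bias gradients (dBias)"
    return None


def cpp_gate_allows(bias_type):
    """C++ gate as described in the review: only NO_BIAS qualifies."""
    return bias_type == AttnBiasType.NO_BIAS


def silent_fallback_configs():
    """Configs the test would run that the D=256 BWD gate would reject."""
    return [
        (t, s)
        for t in AttnBiasType
        for s in BiasShape
        if quoted_skip_reason(t, s) is None and not cpp_gate_allows(t)
    ]


for bias_type, bias_shape in silent_fallback_configs():
    print(bias_type.name, bias_shape.name)
```

Under these assumptions the listed configs are exactly the post-scale bias cases with a non-1HSS shape: they are not skipped, yet they would not select the new kernel path.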
