fix: prevent MXFP8 amax aliasing with dSoftmaxOffset in bwd #4

Open
vedaanta wants to merge 1 commit into cyanguwa:add_mxfp8 from vedaanta:fix/mxfp8-dsink-amax-aliasing
@vedaanta

Summary

  • In fused_attn_fp8_bwd, the amax pointers for dQ, dK, and dV all pointed to output_dQ->amax.dptr. For MXFP8, which does not use
    per-tensor amax, PyTorch's caching allocator could reuse this memory for the dSoftmaxOffset tensor.
  • cuDNN's second bprop kernel then wrote amax_dQ to that address, corrupting d_softmax_offset[0].
  • Fix: use each tensor's own amax/scale pointers and, when aliasing with dSoftmaxOffset is detected, allocate scratch space for the
    amax outputs.
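
The aliasing check described in the fix can be sketched in host-side C++ as below. This is a minimal illustration only, not the code in this PR: `ranges_overlap` and `select_amax_target` are hypothetical helper names, and the sketch assumes the amax output is a single `float` and that the size of the dSoftmaxOffset buffer is known in bytes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not part of Transformer Engine or cuDNN).
// Returns true when byte ranges [a, a + a_bytes) and [b, b + b_bytes) overlap.
bool ranges_overlap(const void* a, std::size_t a_bytes,
                    const void* b, std::size_t b_bytes) {
  if (a == nullptr || b == nullptr || a_bytes == 0 || b_bytes == 0) {
    return false;  // a null or empty range cannot alias anything
  }
  const auto a0 = reinterpret_cast<std::uintptr_t>(a);
  const auto b0 = reinterpret_cast<std::uintptr_t>(b);
  return a0 < b0 + b_bytes && b0 < a0 + a_bytes;
}

// Hypothetical helper: chooses where an amax value should be written.
// If the tensor's own amax slot overlaps the dSoftmaxOffset buffer,
// the caller-provided scratch slot is returned instead, so the write
// cannot corrupt d_softmax_offset. Otherwise the tensor's own slot is used.
float* select_amax_target(float* own_amax,
                          const void* d_softmax_offset,
                          std::size_t d_softmax_offset_bytes,
                          float* scratch) {
  assert(scratch != nullptr);  // the caller must supply scratch space
  if (ranges_overlap(own_amax, sizeof(float),
                     d_softmax_offset, d_softmax_offset_bytes)) {
    return scratch;
  }
  return own_amax;
}
```

Under these assumptions, a caller would run `select_amax_target` once each for dQ, dK, and dV, passing each tensor's own amax pointer, and hand the returned addresses to the backward kernel.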

Test plan

  • test_dpa_fp8_vs_f16[mxfp8-True-True-bshd_bshd_bshd-fp8_16-dtype0]: d_softmax_offset RMSE dropped from 0.60 to 0.15
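
For context on the metric, the sketch below computes a plain root-mean-square error in C++. It assumes the standard RMSE definition; the exact formula the test uses is not shown in this PR, and the function name `rmse` is illustrative.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Standard root-mean-square error between two equally sized vectors.
// Assumption: the test's own metric may be normalized differently.
double rmse(const std::vector<double>& actual,
            const std::vector<double>& expected) {
  if (actual.size() != expected.size() || actual.empty()) {
    throw std::invalid_argument("rmse: inputs must be non-empty and equal in size");
  }
  double sum_sq = 0.0;
  for (std::size_t i = 0; i < actual.size(); ++i) {
    const double diff = actual[i] - expected[i];
    sum_sq += diff * diff;
  }
  return std::sqrt(sum_sq / static_cast<double>(actual.size()));
}
```

A single corrupted element is enough to move this metric: with made-up values, `rmse({3, 0, 0, 0}, {0, 0, 0, 0})` evaluates to 1.5 even though three of the four entries match exactly, which is consistent with a stray write to `d_softmax_offset[0]` inflating the error.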

…ttention backward

  In fused_attn_fp8_bwd, the amax pointers for dQ/dK/dV all pointed to
  output_dQ->amax.dptr. For MXFP8 (which doesn't use per-tensor amax),
  PyTorch's caching allocator could reuse this memory for the
  dSoftmaxOffset tensor. cuDNN's second bprop kernel then wrote amax_dQ
  to that address, corrupting d_softmax_offset[0].

  Fix: use each tensor's own amax/scale pointers, and when aliasing with
  dSoftmaxOffset is detected, allocate scratch space for the amax outputs.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
