[Bugfix] softmax: enable vectorized fast path (fix fastmath attr + f32 128b copy width) (#627) #650
Merged
…2 128b copy width)

Re-enables the dead-coded BufferCopy128b fast path in softmax_kernel.py:

- Drop the `False and` guard that made the fast path unreachable, so any N that is a multiple of the tile width takes the vectorized load/store.
- fastmath=True -> fm_fast: a Python bool was emitting an invalid `#arith.fastmath<True>` attribute (MLIRError). The generic path already used arith.FastMathFlags.fast.
- vec_width = 128 // elem_bits (8 for 16-bit, 4 for f32) so the register vector matches the single 128-bit BufferCopy128b transaction. f32 was 8x32=256 bits against a 128-bit copy atom, which aborted in codegen with `CastInst::Create: Invalid cast!`. This is the kernel's documented "WIDTH=8/4" intent.

Verified on MI350X / gfx950 vs torch.softmax: f32 bit-exact (max_err ~5e-10), bf16/f16 within atol, across single- and multi-tile shapes and the generic (non-multiple-N) path. Kernel-time A/B vs the scalar path is on-par to +7% on large bf16 and neutral elsewhere (both are HBM-bound and buffer the whole row in registers) -- not the 2x originally hypothesized, but the change turns a latent Invalid-cast abort into a correct, enabled vectorized path.

Fixes ROCm#627

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Contributor
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR adjusts the softmax kernel’s vectorization and enables an optimized “fast path” when N aligns with the tile width, aiming to improve correctness/performance across f16 vs f32.
Changes:
- Compute vector width from element bit-width to match 128-bit buffer copy transactions.
- Enable the fast path when `N` is a multiple of `tile_cols`.
- Make `exp2` fast-math behavior consistent with the existing `fm_fast` flag.
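As a rough illustration of the first two changes, the width rule and the fast-path condition can be written in plain Python. This is a sketch rather than the kernel code: `COPY_BITS`, `vec_width_for`, and `takes_fast_path` are hypothetical names introduced here, and `tile_cols` is simply taken as a parameter.

```python
COPY_BITS = 128  # width of one BufferCopy128b transaction, per the PR description


def vec_width_for(elem_bits: int) -> int:
    """Elements per register vector so that the vector is exactly 128 bits wide."""
    if elem_bits <= 0 or COPY_BITS % elem_bits != 0:
        raise ValueError(f"unsupported element width: {elem_bits}")
    return COPY_BITS // elem_bits


def takes_fast_path(n: int, tile_cols: int) -> bool:
    """Hypothetical predicate: the vectorized path applies only to whole tiles."""
    return n % tile_cols == 0


for name, bits in (("bf16", 16), ("f16", 16), ("f32", 32)):
    width = vec_width_for(bits)
    # Compare the derived vector size with the old fixed width of 8 elements.
    print(name, "vec_width =", width, "bits =", width * bits, "old bits =", 8 * bits)
```

With the old fixed width of 8, only the 16-bit types produce a 128-bit vector; the derived width keeps every dtype at 128 bits.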
coderfeli
approved these changes
Jun 4, 2026
15 tasks
Summary
Re-enables the dead-coded vectorized (`BufferCopy128b`) fast path in `kernels/softmax_kernel.py` and fixes the two compile failures that kept it switched off (#627).

Three changes, all in the fast-path branch:

- Drop the `False and` guard (`if const_expr(False and ...)`) so an `N` that is a multiple of the tile width actually takes the vectorized path.
- `fastmath=True` → `fastmath=fm_fast`: the Python bool was stringified into an invalid `#arith.fastmath<True>` attribute (`MLIRError: expected ... one of: none, reassoc, ...`). The generic path already used `arith.FastMathFlags.fast`.
- `vec_width = 128 // elem_bits` (8 for 16-bit, 4 for f32) instead of a hard-coded `VEC_WIDTH = 8`.

Root cause of the `Invalid cast!` assertion

The hard LLVM abort (`CastInst::Create: Invalid cast!`) is f32-specific, not bf16. The copy atom is `BufferCopy128b` (one 128-bit transaction per lane), but the register vector was always `VEC_WIDTH = 8` elements, that is, `8 × elem_bits` bits wide: 8 × 16 = 128 bits for the 16-bit types, which matches the copy, but 8 × 32 = 256 bits for f32, which does not.

Tying the vector width to the transaction (`128 // elem_bits` → 4 lanes for f32) makes the register vector match the 128-bit copy and lets f32 lower cleanly. This is the kernel's own documented "Vectorized Loads/Stores (WIDTH=8/4)" intent.

Verification (MI350X / gfx950, ROCm 7.2)
Correctness vs `torch.softmax`, covering the fast path (single- and multi-tile) and the generic (non-multiple `N`) path, all dtypes: f32 bit-exact (max_err ~5e-10), bf16/f16 within atol.

IR confirms the vectorized path is actually taken (`buffer_load_dwordx4`, `vector<8xf32>` / `vector<4xf32>`, `llvm.vector.reduce.fmax` / `fadd`).

Performance: honest A/B (kernel time, fast vs scalar, identical shapes)
This is on-par to ~+7% on large bf16, neutral elsewhere — not a 2×. Both paths already saturate HBM (~5 TB/s) because the scalar loads coalesce across the wavefront, and both register-buffer the entire row, so vectorization only trims instruction count. (The 2.05× in the original sweep was FlyDSL-vs-Triton, not fast-vs-scalar.)
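The HBM-bound argument can be checked with back-of-envelope arithmetic. The sketch below assumes softmax reads each input element once and writes each output element once, uses the ~5 TB/s figure quoted above, and picks a hypothetical shape (the PR's benchmark shapes are not reproduced here). `min_kernel_time_us` is an illustrative helper, not part of the repository.

```python
def min_kernel_time_us(rows: int, cols: int, elem_bytes: int,
                       hbm_bytes_per_s: float = 5e12) -> float:
    """Lower bound on kernel time when memory traffic is the only cost.

    Assumes one read and one write per element, with the row held in
    registers in between (as both paths do, per the PR text).
    """
    bytes_moved = 2 * rows * cols * elem_bytes
    return bytes_moved / hbm_bytes_per_s * 1e6


# Hypothetical shape chosen only for illustration.
rows, cols = 32768, 8192
for name, elem_bytes in (("bf16", 2), ("f32", 4)):
    print(f"{name}: >= {min_kernel_time_us(rows, cols, elem_bytes):.1f} us")
```

This bound depends only on bytes moved, not on how many load instructions issue them, so a path that already runs near it has little to gain from wider loads. That is consistent with the on-par result reported here.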
The value here is correctness: the change turns a latent `Invalid cast!` abort (for anyone who re-enables the path, or for f32) into a correct, enabled vectorized path, and removes the `False and` dead-code trap. A larger win would require redesigning the full-row register buffering into a streaming/tiled scheme, which is out of scope for this fix.

Fixes #627
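For readers unfamiliar with the `#arith.fastmath<True>` failure mentioned in the summary, here is a minimal pure-Python sketch of how a bool can end up in attribute text. `FastMathFlag` and `render_fastmath_attr` are hypothetical stand-ins, not the real MLIR bindings, and the keyword set is the arith dialect's flag list as I understand it.

```python
from enum import Enum


class FastMathFlag(Enum):
    """Hypothetical stand-in for a fast-math flag enum (not the real binding)."""
    NONE = "none"
    FAST = "fast"


# Assumed set of keywords accepted inside #arith.fastmath<...>.
VALID_FLAGS = {"none", "reassoc", "nnan", "ninf", "nsz",
               "arcp", "contract", "afn", "fast"}


def render_fastmath_attr(flag) -> str:
    """Render the attribute text and reject anything that is not a known keyword."""
    text = flag.value if isinstance(flag, FastMathFlag) else str(flag)
    if text not in VALID_FLAGS:
        raise ValueError(f"invalid attribute #arith.fastmath<{text}>")
    return f"#arith.fastmath<{text}>"


print(render_fastmath_attr(FastMathFlag.FAST))
try:
    render_fastmath_attr(True)  # str(True) is "True", which is not a flag keyword
except ValueError as err:
    print(err)
```

Passing an enum value instead of a bool keeps the rendered text inside the accepted keyword set, which is the shape of the `fastmath=fm_fast` fix.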
🤖 Generated with Claude Code