[Bugfix] softmax: enable vectorized fast path (fix fastmath attr + f32 128b copy width) (#627) #650

Merged
coderfeli merged 2 commits into ROCm:main from jhinpan:fix-issue-627-softmax-vec-fastpath on Jun 4, 2026

@jhinpan jhinpan commented Jun 4, 2026

Summary

Re-enables the dead-coded vectorized (BufferCopy128b) fast path in kernels/softmax_kernel.py and fixes the two compile failures that kept it switched off (#627).

Three changes, all in the fast-path branch:

  1. Drop the `False and` guard (`if const_expr(False and ...)`) so an N that is a multiple of the tile width actually takes the vectorized path.
  2. `fastmath=True` → `fastmath=fm_fast`: the Python bool was stringified into an invalid `#arith.fastmath<True>` attribute (`MLIRError: expected ... one of: none, reassoc, ...`). The generic path already used `arith.FastMathFlags.fast`.
  3. `vec_width = 128 // elem_bits` (8 for 16-bit, 4 for f32) instead of a hard-coded `VEC_WIDTH = 8`.
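
The fastmath failure in item 2 can be illustrated with a small stand-alone sketch. It does not use the real MLIR Python bindings: `FastMathFlags` below is a hypothetical stand-in enum, and `fastmath_attr_text` only mimics naive string interpolation, which is an assumption about how the bool ended up in the attribute text.

```python
from enum import Enum


class FastMathFlags(Enum):
    """Hypothetical stand-in for arith.FastMathFlags (not the real binding)."""
    none = "none"
    fast = "fast"


# Flag keywords accepted inside #arith.fastmath<...> by the arith dialect.
VALID_FLAGS = {"none", "reassoc", "nnan", "ninf", "nsz",
               "arcp", "contract", "afn", "fast"}

PREFIX = "#arith.fastmath<"


def fastmath_attr_text(flag):
    """Format a flag the way plain string interpolation would (assumed)."""
    name = flag.value if isinstance(flag, FastMathFlags) else str(flag)
    return f"{PREFIX}{name}>"


def is_valid_attr(text):
    """Check that the text between the angle brackets is a known keyword."""
    return text[len(PREFIX):-1] in VALID_FLAGS


bad = fastmath_attr_text(True)                 # Python bool
good = fastmath_attr_text(FastMathFlags.fast)  # enum value
print(bad, is_valid_attr(bad))
print(good, is_valid_attr(good))
```

Passing the enum value keeps the attribute text inside the accepted keyword set, which is why reusing the generic path's `fm_fast` avoids the parse error.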

Root cause of the Invalid cast! assertion

The hard LLVM abort

llvm/lib/IR/Instructions.cpp: CastInst::Create: Assertion `castIsValid(op, S, Ty) && "Invalid cast!"' failed.

is f32-specific, not bf16. The copy atom is BufferCopy128b — one 128-bit transaction per lane — but the register vector was always VEC_WIDTH = 8 elements:

| dtype | 8 × elem_bits | vs 128-bit copy |
|---|---|---|
| f16/bf16 | 8 × 16 = 128b | matches ✅ |
| f32 | 8 × 32 = 256b | mismatch → invalid cast ❌ |

Tying the vector width to the transaction (128 // elem_bits → 4 elements for f32) makes the register vector match the 128-bit copy and lets f32 lower cleanly. This is the kernel's own documented "Vectorized Loads/Stores (WIDTH=8/4)" intent.
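
The width arithmetic can be checked in a few lines. This is a minimal sketch, not the kernel code: `COPY_BITS`, `ELEM_BITS`, `vec_width`, and `register_bits` are illustrative names introduced here.

```python
COPY_BITS = 128  # one BufferCopy128b transaction per lane

# Element sizes for the dtypes discussed in this PR.
ELEM_BITS = {"f16": 16, "bf16": 16, "f32": 32}


def vec_width(dtype):
    """Elements per register vector so that it fills exactly one copy."""
    return COPY_BITS // ELEM_BITS[dtype]


def register_bits(dtype, width):
    """Total bit width of a register vector holding `width` elements."""
    return width * ELEM_BITS[dtype]


for dtype in ELEM_BITS:
    old = register_bits(dtype, 8)                 # hard-coded VEC_WIDTH = 8
    new = register_bits(dtype, vec_width(dtype))  # derived width
    print(f"{dtype}: old={old}b new={new}b copy={COPY_BITS}b "
          f"old_ok={old == COPY_BITS} new_ok={new == COPY_BITS}")
```

Only the f32 row fails the old check, matching the observation that the abort is f32-specific.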

Verification (MI350X / gfx950, ROCm 7.2)

Correctness vs torch.softmax — fast path (single + multi-tile) and generic (non-multiple N) path, all dtypes:

| M × N | dtype | path | max abs err |
|---|---|---|---|
| 256 × 1024 | f32 | fast (w=4) | 9.3e-10 |
| 256 × 2048 | f32 | fast (w=4, 2 tiles) | 4.7e-10 |
| 256 × 2048 | f16 | fast (w=8) | 1.9e-06 |
| 4096 × 2048 | bf16 | fast (w=8) | 1.6e-05 |
| 64 × 4096 | bf16 | fast (2 tiles) | 7.5e-06 |
| 128 × 2000 | bf16 | generic | 1.5e-05 |
| 8192 × 8192 | bf16 | fast | 3.9e-06 |
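
For readers who want to reproduce the "max abs err" metric without a GPU, here is a pure-Python sketch of the comparison. It is not the test harness used for the table: `softmax_row` stands in for `torch.softmax`, and `softmax_row_exp2` is an assumed model of an exp2-based kernel (using exp(x) = 2^(x · log2 e)); the real kernel's exact scaling is not shown in this PR text.

```python
import math

LOG2_E = math.log2(math.e)


def softmax_row(row):
    """Reference: numerically stable softmax over one row."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_row_exp2(row):
    """Same result computed through base-2 exponentials (assumed kernel style)."""
    m = max(row)
    exps = [2.0 ** ((x - m) * LOG2_E) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def max_abs_err(got, ref):
    """Largest elementwise absolute difference between two 2-D lists."""
    return max(abs(g - r)
               for got_row, ref_row in zip(got, ref)
               for g, r in zip(got_row, ref_row))


# Small deterministic M x N input (values are arbitrary test data).
rows = [[math.sin(0.37 * (i * 16 + j)) * 4.0 for j in range(16)]
        for i in range(4)]
ref = [softmax_row(r) for r in rows]
got = [softmax_row_exp2(r) for r in rows]
print("max abs err:", max_abs_err(got, ref))
```

In double precision the two formulations agree far more tightly than the f32 and bf16 figures in the table, so this only demonstrates the metric, not the kernel's accuracy.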

IR confirms the vectorized path is actually taken (buffer_load_dwordx4, vector<8xf32> / vector<4xf32>, llvm.vector.reduce.fmax/fadd).

Performance — honest A/B (kernel time, fast vs scalar, identical shapes)

| shape | generic | fast | speedup |
|---|---|---|---|
| 8192 × 8192 bf16 | 52.8 µs | 49.5 µs | 1.07× |
| 16384 × 8192 bf16 | 108.9 µs | 103.7 µs | 1.05× |
| 4096 × 16384 bf16 | 51.2 µs | 49.8 µs | 1.03× |
| 8192 × 8192 f32 | 100.2 µs | 102.1 µs | 0.98× |
| ≤ 2048-row shapes | ~47 µs | ~47 µs | ~1.00× (launch-bound) |
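
The speedup column is simply generic time divided by fast time. A short sketch that recomputes it from the measured rows (the `timings_us` list and `speedup` helper are names introduced here for illustration):

```python
# (shape, generic kernel time in µs, fast kernel time in µs), from the table.
timings_us = [
    ("8192 x 8192 bf16", 52.8, 49.5),
    ("16384 x 8192 bf16", 108.9, 103.7),
    ("4096 x 16384 bf16", 51.2, 49.8),
    ("8192 x 8192 f32", 100.2, 102.1),
]


def speedup(generic_us, fast_us):
    """Ratio above 1.0 means the fast path is quicker."""
    return generic_us / fast_us


for shape, generic_us, fast_us in timings_us:
    print(f"{shape}: {speedup(generic_us, fast_us):.2f}x")
```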

This is on-par to ~+7% on large bf16, neutral elsewhere — not a 2×. Both paths already saturate HBM (~5 TB/s) because the scalar loads coalesce across the wavefront, and both register-buffer the entire row, so vectorization only trims instruction count. (The 2.05× in the original sweep was FlyDSL-vs-Triton, not fast-vs-scalar.)
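
The HBM-saturation claim can be sanity-checked from the timings. The sketch below assumes the kernel moves each element exactly twice (one read, one write) and nothing else, and uses decimal terabytes; both are simplifying assumptions, and `effective_tb_per_s` is a helper name introduced here.

```python
def effective_tb_per_s(m, n, elem_bytes, kernel_us):
    """Effective bandwidth assuming one read and one write per element."""
    bytes_moved = 2 * m * n * elem_bytes
    seconds = kernel_us * 1e-6
    return bytes_moved / seconds / 1e12  # decimal TB/s


# (label, M, N, bytes per element, kernel time in µs) from the A/B table.
cases = [
    ("bf16 8192x8192 generic", 8192, 8192, 2, 52.8),
    ("bf16 8192x8192 fast", 8192, 8192, 2, 49.5),
    ("bf16 16384x8192 fast", 16384, 8192, 2, 103.7),
    ("f32 8192x8192 fast", 8192, 8192, 4, 102.1),
]

for label, m, n, elem_bytes, us in cases:
    print(f"{label}: {effective_tb_per_s(m, n, elem_bytes, us):.2f} TB/s")
```

Under that assumption every large shape lands a little above 5 TB/s on both paths, which is consistent with the memory-bound explanation given above.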

The value here is correctness: the change turns a latent Invalid cast! abort (for anyone who re-enables the path, or for f32) into a correct, enabled vectorized path, and removes the False and dead-code trap. A larger win would require redesigning the full-row register buffering into a streaming/tiled scheme — out of scope for this fix.

Fixes #627

🤖 Generated with Claude Code

…2 128b copy width)

Re-enables the dead-coded BufferCopy128b fast path in softmax_kernel.py:

- Drop the `False and` guard that made the fast path unreachable, so any
  N that is a multiple of the tile width takes the vectorized load/store.
- fastmath=True -> fm_fast: a Python bool was emitting an invalid
  `#arith.fastmath<True>` attribute (MLIRError). The generic path already
  used arith.FastMathFlags.fast.
- vec_width = 128 // elem_bits (8 for 16-bit, 4 for f32) so the register
  vector matches the single 128-bit BufferCopy128b transaction. f32 was
  8x32=256 bits against a 128-bit copy atom, which aborted in codegen with
  `CastInst::Create: Invalid cast!`. This is the kernel's documented
  "WIDTH=8/4" intent.

Verified on MI350X / gfx950 vs torch.softmax: f32 near-exact (max_err ~5e-10),
bf16/f16 within atol, across single- and multi-tile shapes and the generic
(non-multiple-N) path. Kernel-time A/B vs the scalar path is on-par to +7%
on large bf16 and neutral elsewhere (both are HBM-bound and buffer the whole
row in registers) -- not the 2x originally hypothesized, but the change turns
a latent Invalid-cast abort into a correct, enabled vectorized path.

Fixes ROCm#627

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 4, 2026 02:04

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR adjusts the softmax kernel's vectorization and enables an optimized "fast path" when N aligns with the tile width, aiming to improve correctness and performance for both 16-bit and f32 inputs.

Changes:

  • Compute vector width from element bit-width to match 128-bit buffer copy transactions.
  • Enable the fast path when N is a multiple of tile_cols.
  • Make exp2 fast-math behavior consistent with the existing fm_fast flag.
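
A rough sketch of the dispatch rule described in these bullets. `BLOCK_THREADS = 256` is an assumption, inferred only from the tile counts reported in the PR's verification table (for example, N = 2048 is two tiles for f32 but one for f16); `tile_cols` and `select_path` are illustrative names, not the kernel's API.

```python
COPY_BITS = 128      # width of one BufferCopy128b transaction
BLOCK_THREADS = 256  # assumption: threads cooperating on one row tile


def tile_cols(elem_bits):
    """Columns covered by one tile: each thread copies one 128-bit vector."""
    return BLOCK_THREADS * (COPY_BITS // elem_bits)


def select_path(n, elem_bits):
    """Return (path, tiles): fast only when N is a whole number of tiles."""
    cols = tile_cols(elem_bits)
    if n % cols == 0:
        return ("fast", n // cols)
    return ("generic", 0)


for n, bits in [(1024, 32), (2048, 32), (2048, 16), (4096, 16), (2000, 16)]:
    print(n, bits, select_path(n, bits))
```

With that assumed block size the sketch reproduces the path and tile count listed for each verified shape, including the generic fallback for N = 2000.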

Comment thread kernels/softmax_kernel.py
Comment thread kernels/softmax_kernel.py
@coderfeli coderfeli merged commit 24c697e into ROCm:main Jun 4, 2026
9 checks passed

Development

Successfully merging this pull request may close these issues.

softmax: vectorized fast path is dead-coded off and does not compile (fastmath=True + LLVM cast assertion)