From 5b7585999ceb21273dafef4feed212f651df51e2 Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Thu, 4 Jun 2026 02:03:40 +0000 Subject: [PATCH] [Bugfix] softmax: enable vectorized fast path (fix fastmath attr + f32 128b copy width) Re-enables the dead-coded BufferCopy128b fast path in softmax_kernel.py: - Drop the `False and` guard that made the fast path unreachable, so any N that is a multiple of the tile width takes the vectorized load/store. - fastmath=True -> fm_fast: a Python bool was emitting an invalid `#arith.fastmath` attribute (MLIRError). The generic path already used arith.FastMathFlags.fast. - vec_width = 128 // elem_bits (8 for 16-bit, 4 for f32) so the register vector matches the single 128-bit BufferCopy128b transaction. f32 was 8x32=256 bits against a 128-bit copy atom, which aborted in codegen with `CastInst::Create: Invalid cast!`. This is the kernel's documented "WIDTH=8/4" intent. Verified on MI350X / gfx950 vs torch.softmax: f32 matches to within rounding (max_err ~5e-10), bf16/f16 within atol, across single- and multi-tile shapes and the generic (non-multiple-N) path. Kernel-time A/B vs the scalar path is on-par to +7% on large bf16 and neutral elsewhere (both are HBM-bound and buffer the whole row in registers) -- not the 2x originally hypothesized, but the change turns a latent Invalid-cast abort into a correct, enabled vectorized path. 
Fixes #627 Co-Authored-By: Claude Opus 4.8 (1M context) --- kernels/softmax_kernel.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/kernels/softmax_kernel.py b/kernels/softmax_kernel.py index a1d71623e..ec0e1c2ca 100644 --- a/kernels/softmax_kernel.py +++ b/kernels/softmax_kernel.py @@ -26,13 +26,15 @@ BLOCK_THREADS = 256 WARP_SIZE = get_warp_size() -VEC_WIDTH = 8 def build_softmax_module(M: int, N: int, dtype_str: str = "f32"): - tile_cols = BLOCK_THREADS * VEC_WIDTH - RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 + # BufferCopy128b moves one 128-bit transaction per lane, so the register + # vector width must satisfy vec_width * elem_bits == 128 (8 for 16-bit, 4 for f32). + vec_width = 128 // elem_bits + tile_cols = BLOCK_THREADS * vec_width + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) @fx.struct class SharedStorage: @@ -101,7 +103,7 @@ def block_reduce(val, mode, s_red_buffer): # ================================================================== # Fast path: N is a multiple of tile_cols # ================================================================== - if const_expr(False and N >= tile_cols and N % tile_cols == 0): + if const_expr(N >= tile_cols and N % tile_cols == 0): num_tiles = N // tile_cols # ── Layout API: buffer-backed tensors + tiled access ───── A_buf = fx.rocdl.make_buffer_tensor(A) @@ -110,18 +112,18 @@ def block_reduce(val, mode, s_red_buffer): row_a = fx.slice(A_buf, (bid, None)) row_c = fx.slice(C_buf, (bid, None)) - a_div = fx.logical_divide(row_a, fx.make_layout(VEC_WIDTH, 1)) - c_div = fx.logical_divide(row_c, fx.make_layout(VEC_WIDTH, 1)) + a_div = fx.logical_divide(row_a, fx.make_layout(vec_width, 1)) + c_div = fx.logical_divide(row_c, fx.make_layout(vec_width, 1)) copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) def _load_vec(div_tensor, idx): - r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype) + 
r = fx.make_rmem_tensor(vec_width, elem_dtype) fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) return fx.memref_load_vec(r) def _store_vec(val, div_tensor, idx): - r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype) + r = fx.make_rmem_tensor(vec_width, elem_dtype) fx.memref_store_vec(val, r) fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) @@ -145,7 +147,7 @@ def _store_vec(val, div_tensor, idx): for i in range_constexpr(num_tiles): x = row_buffer[i] scaled = (x - global_max) * c_log2e - exp_val = fmath.exp2(scaled, fastmath=True) + exp_val = fmath.exp2(scaled, fastmath=fm_fast) row_buffer[i] = exp_val red_sum = exp_val.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sum = thread_sum + red_sum