diff --git a/kernels/softmax_kernel.py b/kernels/softmax_kernel.py index a1d71623e..ec0e1c2ca 100644 --- a/kernels/softmax_kernel.py +++ b/kernels/softmax_kernel.py @@ -26,13 +26,15 @@ BLOCK_THREADS = 256 WARP_SIZE = get_warp_size() -VEC_WIDTH = 8 def build_softmax_module(M: int, N: int, dtype_str: str = "f32"): - tile_cols = BLOCK_THREADS * VEC_WIDTH - RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) elem_bits = 32 if dtype_str == "f32" else 16 + # BufferCopy128b moves one 128-bit transaction per lane, so the register + # vector width must satisfy vec_width * elem_bits == 128 (8 for 16-bit, 4 for f32). + vec_width = 128 // elem_bits + tile_cols = BLOCK_THREADS * vec_width + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) @fx.struct class SharedStorage: @@ -101,7 +103,7 @@ def block_reduce(val, mode, s_red_buffer): # ================================================================== # Fast path: N is a multiple of tile_cols # ================================================================== - if const_expr(False and N >= tile_cols and N % tile_cols == 0): + if const_expr(N >= tile_cols and N % tile_cols == 0): num_tiles = N // tile_cols # ── Layout API: buffer-backed tensors + tiled access ───── A_buf = fx.rocdl.make_buffer_tensor(A) @@ -110,18 +112,18 @@ def block_reduce(val, mode, s_red_buffer): row_a = fx.slice(A_buf, (bid, None)) row_c = fx.slice(C_buf, (bid, None)) - a_div = fx.logical_divide(row_a, fx.make_layout(VEC_WIDTH, 1)) - c_div = fx.logical_divide(row_c, fx.make_layout(VEC_WIDTH, 1)) + a_div = fx.logical_divide(row_a, fx.make_layout(vec_width, 1)) + c_div = fx.logical_divide(row_c, fx.make_layout(vec_width, 1)) copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) def _load_vec(div_tensor, idx): - r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype) + r = fx.make_rmem_tensor(vec_width, elem_dtype) fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) return fx.memref_load_vec(r) def _store_vec(val, div_tensor, idx): - r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype) + r = fx.make_rmem_tensor(vec_width, elem_dtype) fx.memref_store_vec(val, r) fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) @@ -145,7 +147,7 @@ def _store_vec(val, div_tensor, idx): for i in range_constexpr(num_tiles): x = row_buffer[i] scaled = (x - global_max) * c_log2e - exp_val = fmath.exp2(scaled, fastmath=True) + exp_val = fmath.exp2(scaled, fastmath=fm_fast) row_buffer[i] = exp_val red_sum = exp_val.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sum = thread_sum + red_sum