Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions kernels/softmax_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@

BLOCK_THREADS = 256
WARP_SIZE = get_warp_size()
VEC_WIDTH = 8


def build_softmax_module(M: int, N: int, dtype_str: str = "f32"):
tile_cols = BLOCK_THREADS * VEC_WIDTH
RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE)
elem_bits = 32 if dtype_str == "f32" else 16
# BufferCopy128b moves one 128-bit transaction per lane, so the register
# vector must hold exactly 128 bits: vec_width * elem_bits == 128, i.e.
# vec_width is 8 for 16-bit element types and 4 for f32.
vec_width = 128 // elem_bits
tile_cols = BLOCK_THREADS * vec_width
Comment thread
coderfeli marked this conversation as resolved.
RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE)

@fx.struct
class SharedStorage:
Expand Down Expand Up @@ -101,7 +103,7 @@ def block_reduce(val, mode, s_red_buffer):
# ==================================================================
# Fast path: N is a multiple of tile_cols
# ==================================================================
if const_expr(False and N >= tile_cols and N % tile_cols == 0):
if const_expr(N >= tile_cols and N % tile_cols == 0):
Comment thread
coderfeli marked this conversation as resolved.
num_tiles = N // tile_cols
# ── Layout API: buffer-backed tensors + tiled access ─────
A_buf = fx.rocdl.make_buffer_tensor(A)
Expand All @@ -110,18 +112,18 @@ def block_reduce(val, mode, s_red_buffer):
row_a = fx.slice(A_buf, (bid, None))
row_c = fx.slice(C_buf, (bid, None))

a_div = fx.logical_divide(row_a, fx.make_layout(VEC_WIDTH, 1))
c_div = fx.logical_divide(row_c, fx.make_layout(VEC_WIDTH, 1))
a_div = fx.logical_divide(row_a, fx.make_layout(vec_width, 1))
c_div = fx.logical_divide(row_c, fx.make_layout(vec_width, 1))

copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits)

def _load_vec(div_tensor, idx):
r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype)
r = fx.make_rmem_tensor(vec_width, elem_dtype)
fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r)
return fx.memref_load_vec(r)

def _store_vec(val, div_tensor, idx):
r = fx.make_rmem_tensor(VEC_WIDTH, elem_dtype)
r = fx.make_rmem_tensor(vec_width, elem_dtype)
fx.memref_store_vec(val, r)
fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx)))

Expand All @@ -145,7 +147,7 @@ def _store_vec(val, div_tensor, idx):
for i in range_constexpr(num_tiles):
x = row_buffer[i]
scaled = (x - global_max) * c_log2e
exp_val = fmath.exp2(scaled, fastmath=True)
exp_val = fmath.exp2(scaled, fastmath=fm_fast)
row_buffer[i] = exp_val
red_sum = exp_val.reduce(ReductionOp.ADD, fastmath=fm_fast)
thread_sum = thread_sum + red_sum
Expand Down
Loading