From 6b8ae9884b223c798cb1643d865169bdd276de22 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Mon, 18 May 2026 23:52:37 +0100 Subject: [PATCH] Loop ATen MHA attention helpers --- aten/plena/compiler.py | 6 +- aten/plena/isa_attention.py | 325 ++++++++++++++++++++++++++++++++++-- aten/plena/isa_compiler.py | 1 + 3 files changed, 312 insertions(+), 20 deletions(-) diff --git a/aten/plena/compiler.py b/aten/plena/compiler.py index 91e006b..6d4c4d1 100644 --- a/aten/plena/compiler.py +++ b/aten/plena/compiler.py @@ -37,9 +37,9 @@ def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125 mlen: Matrix tile size (default 64) blen: Vector tile size (default 4) real_data_ratio: HBM storage ratio (MXFP8 format = 1.125) - unroll_loops: If True, unroll sub-projection loops at ASM-gen time to - eliminate C_LOOP_START/END overhead. Overridden by the - ATEN_UNROLL env var ("1"=True, "0"=False). + unroll_loops: If True, unroll sub-projection and attention helper loops + at ASM-gen time to eliminate C_LOOP_START/END overhead. + Overridden by the ATEN_UNROLL env var ("1"=True, "0"=False). """ _env_unroll = os.environ.get("ATEN_UNROLL", "") if _env_unroll == "1": diff --git a/aten/plena/isa_attention.py b/aten/plena/isa_attention.py index 455a639..15b9a80 100644 --- a/aten/plena/isa_attention.py +++ b/aten/plena/isa_attention.py @@ -30,11 +30,20 @@ def _online_softmax_asm( [mlen, 2*mlen): m_res = exp(m_old - m_curr) [2*mlen, 3*mlen): l_old / l_new """ - gp_regs = self.register_allocator.allocate_gp(4) + if getattr(self, "unroll_attention", False): + return self._online_softmax_asm_unrolled( + mlen=mlen, + s_address=s_address, + m_start_address=m_start_address, + scale=scale, + ) + + gp_regs = self.register_allocator.allocate_gp(5) gp_s = gp_regs[0] gp_m_addr = gp_regs[1] gp_m_res_addr = gp_regs[2] gp_l_addr = gp_regs[3] + gp_loop = gp_regs[4] # Fixed FP register allocation for online softmax pipeline. # These registers are shared across _online_softmax_asm, _scale_o_asm, @@ -57,12 +66,82 @@ def _online_softmax_asm( lines.append(f"S_ADDI_INT gp{gp_l_addr}, gp{gp_m_res_addr}, {mlen}") # scale factor is pre-loaded at FP SRAM addr 1 by the flash-attention driver. + if scale != 1.0: + lines.append(f"S_LD_FP f{fp_scale}, gp0, 1") + + lines.append(f"C_LOOP_START gp{gp_loop}, {mlen}") + lines.append(f"S_LD_FP f{fp_m_old}, gp{gp_m_addr}, 0") + lines.append(f"S_ADD_FP f{fp_m_res}, f{fp_m_old}, f0") + + if scale != 1.0: + lines.append(f"V_MUL_VF gp{gp_s}, gp{gp_s}, f{fp_scale}, 0") + + lines.append(f"V_RED_MAX f{fp_row_max}, gp{gp_s}, 0") + + # m_curr = max(row_max, m_old) — online softmax must retain the running max. + lines.append(f"S_MAX_FP f{fp_m_old}, f{fp_row_max}, f{fp_m_old}") + + lines.append(f"S_SUB_FP f{fp_m_res}, f{fp_m_res}, f{fp_m_old}") + lines.append(f"S_EXP_FP f{fp_m_res}, f{fp_m_res}, 0") + + lines.append(f"S_ST_FP f{fp_m_res}, gp{gp_m_res_addr}, 0") + lines.append(f"S_ST_FP f{fp_m_old}, gp{gp_m_addr}, 0") + + lines.append(f"V_SUB_VF gp{gp_s}, gp{gp_s}, f{fp_m_old}, 0, 0") + lines.append(f"V_EXP_V gp{gp_s}, gp{gp_s}, 0, 0") + + lines.append(f"S_LD_FP f{fp_l_old}, gp{gp_l_addr}, 0") + + lines.append(f"S_ADD_FP f{fp_sum_p}, f0, f0") + lines.append(f"V_RED_SUM f{fp_sum_p}, gp{gp_s}, 0, 0") + + lines.append(f"S_MUL_FP f{fp_l_old}, f{fp_l_old}, f{fp_m_res}") + lines.append(f"S_ADD_FP f{fp_l_old}, f{fp_l_old}, f{fp_sum_p}") + + lines.append(f"S_ST_FP f{fp_l_old}, gp{gp_l_addr}, 0") + + lines.append(f"S_ADDI_INT gp{gp_s}, gp{gp_s}, {mlen}") + lines.append(f"S_ADDI_INT gp{gp_m_addr}, gp{gp_m_addr}, 1") + lines.append(f"S_ADDI_INT gp{gp_m_res_addr}, gp{gp_m_res_addr}, 1") + lines.append(f"S_ADDI_INT gp{gp_l_addr}, gp{gp_l_addr}, 1") + lines.append(f"C_LOOP_END gp{gp_loop}") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _online_softmax_asm_unrolled( + self, + mlen: int, + s_address: int, + m_start_address: int, + scale: float = 1.0, + ) -> str: + """Legacy Python-unrolled online softmax emission, kept for A/B comparisons.""" + gp_regs = self.register_allocator.allocate_gp(4) + gp_s = gp_regs[0] + gp_m_addr = gp_regs[1] + gp_m_res_addr = gp_regs[2] + gp_l_addr = gp_regs[3] + + fp_m_old = 1 + fp_m_res = 2 + fp_l_old = 3 + fp_sum_p = 4 + fp_scale = 5 + fp_row_max = 6 + + lines = [] + lines.append("; === Online Softmax ===") + lines.append(f"S_ADDI_INT gp{gp_s}, gp0, {s_address}") + lines.append(f"S_ADDI_INT gp{gp_m_addr}, gp0, {m_start_address}") + lines.append(f"S_ADDI_INT gp{gp_m_res_addr}, gp{gp_m_addr}, {mlen}") + lines.append(f"S_ADDI_INT gp{gp_l_addr}, gp{gp_m_res_addr}, {mlen}") + if scale != 1.0: lines.append(f"S_LD_FP f{fp_scale}, gp0, 1") for row in range(mlen): lines.append(f"; Row {row}") - lines.append(f"S_LD_FP f{fp_m_old}, gp{gp_m_addr}, {row}") lines.append(f"S_ADD_FP f{fp_m_res}, f{fp_m_old}, f0") @@ -122,14 +201,29 @@ def _pv_multiply_asm( """ assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" - gp_regs = self.register_allocator.allocate_gp(5) + if getattr(self, "unroll_attention", False): + return self._pv_multiply_asm_unrolled( + mlen=mlen, + blen=blen, + head_dim=head_dim, + p_address=p_address, + v_hbm_offset_reg=v_hbm_offset_reg, + v_hbm_offset=v_hbm_offset, + pv_address=pv_address, + ) + + gp_regs = self.register_allocator.allocate_gp(8) gp_p = gp_regs[0] gp_v = gp_regs[1] gp_pv = gp_regs[2] gp_hbm = gp_regs[3] gp_stride = gp_regs[4] + gp_pv_col_base = gp_regs[5] + gp_v_loop = gp_regs[6] + gp_p_loop = gp_regs[7] num_v_col_blocks = head_dim // mlen + tiles_per_mlen = mlen // blen lines = [] lines.append("; === PV Multiply (P @ V) using M_MM ===") @@ -157,21 +251,71 @@ def _pv_multiply_asm( lines.append(f"S_ADDI_INT gp{gp_hbm}, gp0, {v_block_hbm_offset}") lines.append(f"H_PREFETCH_M gp{gp_v}, gp{gp_hbm}, a{v_hbm_offset_reg}, 1, 1") - # mat_offset constraint: < mlen and a multiple of blen. + pv_col_block_base = pv_address + v_col_block * mlen * mlen + lines.append(f"S_ADDI_INT gp{gp_pv_col_base}, gp0, {pv_col_block_base}") + lines.append(f"C_LOOP_START gp{gp_v_loop}, {tiles_per_mlen}") + lines.append(f"S_ADDI_INT gp{gp_p}, gp0, {p_address}") + lines.append(f"S_ADDI_INT gp{gp_pv}, gp{gp_pv_col_base}, 0") + lines.append(f"C_LOOP_START gp{gp_p_loop}, {tiles_per_mlen}") + lines.append(f"M_MM 0, gp{gp_v}, gp{gp_p}") + lines.append(f"M_MM_WO gp{gp_pv}, gp{gp_stride}, 0") + lines.append(f"S_ADDI_INT gp{gp_p}, gp{gp_p}, {blen * mlen}") + lines.append(f"S_ADDI_INT gp{gp_pv}, gp{gp_pv}, {blen * mlen}") + lines.append(f"C_LOOP_END gp{gp_p_loop}") + lines.append(f"S_ADDI_INT gp{gp_v}, gp{gp_v}, {blen}") + lines.append(f"S_ADDI_INT gp{gp_pv_col_base}, gp{gp_pv_col_base}, {blen}") + lines.append(f"C_LOOP_END gp{gp_v_loop}") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _pv_multiply_asm_unrolled( + self, + mlen: int, + blen: int, + head_dim: int, + p_address: int, + v_hbm_offset_reg: int, + v_hbm_offset: int, + pv_address: int, + ) -> str: + """Legacy Python-unrolled P @ V emission, kept for A/B comparisons.""" + gp_regs = self.register_allocator.allocate_gp(5) + gp_p = gp_regs[0] + gp_v = gp_regs[1] + gp_pv = gp_regs[2] + gp_hbm = gp_regs[3] + gp_stride = gp_regs[4] + + num_v_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === PV Multiply (P @ V) using M_MM ===") + lines.append(f"; P: ({mlen}, {mlen}) @ V: ({mlen}, {head_dim}) -> PV: ({mlen}, {head_dim})") + lines.append("; M_MM: (blen, mlen) @ (mlen, blen) -> (blen, blen), K=mlen in one shot") + lines.append(f"; V split into {num_v_col_blocks} column blocks of width {mlen}") + lines.append("; Storage layout: (batch, mlen, hidden/mlen), column-block major") + lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") + + for v_col_block in range(num_v_col_blocks): + lines.append( + f"; --- V column block {v_col_block} (columns {v_col_block * mlen} to {(v_col_block + 1) * mlen - 1}) ---" + ) + v_block_hbm_offset = v_hbm_offset + v_col_block * mlen + lines.append(f"S_ADDI_INT gp{gp_v}, gp0, 0") + lines.append(f"S_ADDI_INT gp{gp_hbm}, gp0, {v_block_hbm_offset}") + lines.append(f"H_PREFETCH_M gp{gp_v}, gp{gp_hbm}, a{v_hbm_offset_reg}, 1, 1") + for v_col in range(mlen // blen): lines.append(f"; V column {v_col_block * mlen + v_col * blen}") - v_msram_offset = v_col * blen lines.append(f"S_ADDI_INT gp{gp_v}, gp0, {v_msram_offset}") for p_row in range(mlen // blen): p_row_addr = p_address + p_row * blen * mlen lines.append(f"S_ADDI_INT gp{gp_p}, gp0, {p_row_addr}") - lines.append(f"M_MM 0, gp{gp_v}, gp{gp_p}") - # PV[row, col] addr = base + col_block * mlen * mlen + row * mlen + col_in_block - # with row = p_row * blen and col_in_block = v_col * blen. pv_offset = v_col_block * mlen * mlen + p_row * blen * mlen + v_col * blen lines.append(f"S_ADDI_INT gp{gp_pv}, gp0, {pv_address + pv_offset}") lines.append(f"M_MM_WO gp{gp_pv}, gp{gp_stride}, 0") @@ -191,6 +335,70 @@ def _scale_o_asm( """Scale each row of O by m_res: O[row] *= m_res[row].""" assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + if getattr(self, "unroll_attention", False): + return self._scale_o_asm_unrolled( + mlen=mlen, + head_dim=head_dim, + seq_len=seq_len, + m_res_address=m_res_address, + o_address=o_address, + row_offset=row_offset, + ) + + gp_regs = self.register_allocator.allocate_gp(4) + gp_m_res = gp_regs[0] + gp_o_row_base = gp_regs[1] + gp_o = gp_regs[2] + gp_row_loop = gp_regs[3] + fp_m_res = 1 + + num_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === Scale O by m_res ===") + lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") + lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") + + if num_col_blocks == 1: + o_addr = o_address + row_offset * mlen + lines.append(f"S_ADDI_INT gp{gp_m_res}, gp0, {m_res_address}") + lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") + lines.append(f"C_LOOP_START gp{gp_row_loop}, {mlen}") + lines.append(f"S_LD_FP f{fp_m_res}, gp{gp_m_res}, 0") + lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_m_res}, 0") + lines.append(f"S_ADDI_INT gp{gp_m_res}, gp{gp_m_res}, 1") + lines.append(f"S_ADDI_INT gp{gp_o}, gp{gp_o}, {mlen}") + lines.append(f"C_LOOP_END gp{gp_row_loop}") + else: + gp_col_loop = self.register_allocator.allocate_gp(1)[0] + o_addr = o_address + row_offset * mlen + lines.append(f"S_ADDI_INT gp{gp_m_res}, gp0, {m_res_address}") + lines.append(f"S_ADDI_INT gp{gp_o_row_base}, gp0, {o_addr}") + lines.append(f"C_LOOP_START gp{gp_row_loop}, {mlen}") + lines.append(f"S_LD_FP f{fp_m_res}, gp{gp_m_res}, 0") + lines.append(f"S_ADDI_INT gp{gp_o}, gp{gp_o_row_base}, 0") + lines.append(f"C_LOOP_START gp{gp_col_loop}, {num_col_blocks}") + lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_m_res}, 0") + lines.append(f"S_ADDI_INT gp{gp_o}, gp{gp_o}, {seq_len * mlen}") + lines.append(f"C_LOOP_END gp{gp_col_loop}") + lines.append(f"S_ADDI_INT gp{gp_m_res}, gp{gp_m_res}, 1") + lines.append(f"S_ADDI_INT gp{gp_o_row_base}, gp{gp_o_row_base}, {mlen}") + lines.append(f"C_LOOP_END gp{gp_row_loop}") + self.register_allocator.free_gp([gp_col_loop]) + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _scale_o_asm_unrolled( + self, + mlen: int, + head_dim: int, + seq_len: int, + m_res_address: int, + o_address: int, + row_offset: int = 0, + ) -> str: + """Legacy Python-unrolled O *= m_res emission, kept for A/B comparisons.""" gp_regs = self.register_allocator.allocate_gp(2) gp_m_res = gp_regs[0] gp_o = gp_regs[1] @@ -202,7 +410,6 @@ def _scale_o_asm( lines.append("; === Scale O by m_res ===") lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") - lines.append(f"S_ADDI_INT gp{gp_m_res}, gp0, {m_res_address}") for row in range(mlen): @@ -271,6 +478,73 @@ def _final_scaling_asm( """ assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + if getattr(self, "unroll_attention", False): + return self._final_scaling_asm_unrolled( + mlen=mlen, + head_dim=head_dim, + seq_len=seq_len, + l_address=l_address, + o_address=o_address, + row_offset=row_offset, + ) + + gp_regs = self.register_allocator.allocate_gp(4) + gp_l = gp_regs[0] + gp_o_row_base = gp_regs[1] + gp_o = gp_regs[2] + gp_row_loop = gp_regs[3] + fp_l = 1 + + num_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === Final Scaling O = O / l ===") + lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") + lines.append("; Storage layout: (seq_len, mlen, head_dim/mlen), column-block major") + lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") + + if num_col_blocks == 1: + o_addr = o_address + row_offset * mlen + lines.append(f"S_ADDI_INT gp{gp_l}, gp0, {l_address}") + lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") + lines.append(f"C_LOOP_START gp{gp_row_loop}, {mlen}") + lines.append(f"S_LD_FP f{fp_l}, gp{gp_l}, 0") + lines.append(f"S_RECI_FP f{fp_l}, f{fp_l}, 0") + lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_l}, 0") + lines.append(f"S_ADDI_INT gp{gp_l}, gp{gp_l}, 1") + lines.append(f"S_ADDI_INT gp{gp_o}, gp{gp_o}, {mlen}") + lines.append(f"C_LOOP_END gp{gp_row_loop}") + else: + gp_col_loop = self.register_allocator.allocate_gp(1)[0] + o_addr = o_address + row_offset * mlen + lines.append(f"S_ADDI_INT gp{gp_l}, gp0, {l_address}") + lines.append(f"S_ADDI_INT gp{gp_o_row_base}, gp0, {o_addr}") + lines.append(f"C_LOOP_START gp{gp_row_loop}, {mlen}") + lines.append(f"S_LD_FP f{fp_l}, gp{gp_l}, 0") + lines.append(f"S_RECI_FP f{fp_l}, f{fp_l}, 0") + lines.append(f"S_ADDI_INT gp{gp_o}, gp{gp_o_row_base}, 0") + lines.append(f"C_LOOP_START gp{gp_col_loop}, {num_col_blocks}") + lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_l}, 0") + lines.append(f"S_ADDI_INT gp{gp_o}, gp{gp_o}, {seq_len * mlen}") + lines.append(f"C_LOOP_END gp{gp_col_loop}") + lines.append(f"S_ADDI_INT gp{gp_l}, gp{gp_l}, 1") + lines.append(f"S_ADDI_INT gp{gp_o_row_base}, gp{gp_o_row_base}, {mlen}") + lines.append(f"C_LOOP_END gp{gp_row_loop}") + self.register_allocator.free_gp([gp_col_loop]) + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _final_scaling_asm_unrolled( + self, + mlen: int, + head_dim: int, + seq_len: int, + l_address: int, + o_address: int, + row_offset: int = 0, + ) -> str: + """Legacy Python-unrolled final O /= l emission, kept for A/B comparisons.""" gp_regs = self.register_allocator.allocate_gp(2) gp_l = gp_regs[0] gp_o = gp_regs[1] @@ -283,7 +557,6 @@ def _final_scaling_asm( lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") lines.append("; Storage layout: (seq_len, mlen, head_dim/mlen), column-block major") lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") - lines.append(f"S_ADDI_INT gp{gp_l}, gp0, {l_address}") for row in range(mlen): @@ -306,8 +579,9 @@ def _reset_fpsram_asm( value_address: int, # FP SRAM slot: 0 = zero, 2 = -inf ) -> str: """Reset a region of FP SRAM to the value at value_address.""" - gp_regs = self.register_allocator.allocate_gp(1) + gp_regs = self.register_allocator.allocate_gp(2) gp_addr = gp_regs[0] + gp_loop = gp_regs[1] lines = [] lines.append(f"; Reset FP SRAM [{start_address}, {start_address + count})") @@ -316,8 +590,14 @@ def _reset_fpsram_asm( # Use f1 for FP scalar - FP registers don't go through GP allocator lines.append(f"S_LD_FP f1, gp0, {value_address}") - for i in range(count): - lines.append(f"S_ST_FP f1, gp{gp_addr}, {i}") + if getattr(self, "unroll_attention", False): + for i in range(count): + lines.append(f"S_ST_FP f1, gp{gp_addr}, {i}") + else: + lines.append(f"C_LOOP_START gp{gp_loop}, {count}") + lines.append(f"S_ST_FP f1, gp{gp_addr}, 0") + lines.append(f"S_ADDI_INT gp{gp_addr}, gp{gp_addr}, 1") + lines.append(f"C_LOOP_END gp{gp_loop}") self.register_allocator.free_gp(gp_regs) return "\n".join(lines) + "\n" @@ -337,8 +617,9 @@ def _reset_vram_asm( V_MUL_VF processes mlen elements at a time; when cols > mlen, each row is split into cols // mlen mlen-wide blocks. """ - gp_regs = self.register_allocator.allocate_gp(1) + gp_regs = self.register_allocator.allocate_gp(2) gp_addr = gp_regs[0] + gp_loop = gp_regs[1] num_col_blocks = (cols + mlen - 1) // mlen @@ -348,12 +629,22 @@ def _reset_vram_asm( lines.append("; Storage layout: (total_rows, mlen, cols/mlen), column-block major") lines.append(f"; total_rows = {total_rows}, row_offset = {row_offset}") - for row in range(rows): - actual_row = row_offset + row + if getattr(self, "unroll_attention", False): + for row in range(rows): + actual_row = row_offset + row + for col_block in range(num_col_blocks): + addr = start_address + col_block * total_rows * mlen + actual_row * mlen + lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {addr}") + lines.append(f"V_MUL_VF gp{gp_addr}, gp{gp_addr}, f0, 0") + else: for col_block in range(num_col_blocks): - addr = start_address + col_block * total_rows * mlen + actual_row * mlen + addr = start_address + col_block * total_rows * mlen + row_offset * mlen + lines.append(f"; Column block {col_block}") lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {rows}") lines.append(f"V_MUL_VF gp{gp_addr}, gp{gp_addr}, f0, 0") + lines.append(f"S_ADDI_INT gp{gp_addr}, gp{gp_addr}, {mlen}") + lines.append(f"C_LOOP_END gp{gp_loop}") self.register_allocator.free_gp(gp_regs) return "\n".join(lines) + "\n" diff --git a/aten/plena/isa_compiler.py b/aten/plena/isa_compiler.py index 6deb05b..6bb6c09 100644 --- a/aten/plena/isa_compiler.py +++ b/aten/plena/isa_compiler.py @@ -42,6 +42,7 @@ def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125 self.real_data_ratio = real_data_ratio self.register_allocator = RegisterAllocator() self.generated_code = "" + self.unroll_attention = unroll_loops def load_batch( self,