diff --git a/aten/plena/compiler.py b/aten/plena/compiler.py index 6d4c4d1..0fb46c9 100644 --- a/aten/plena/compiler.py +++ b/aten/plena/compiler.py @@ -31,12 +31,20 @@ class PlenaCompiler( program-builder helpers on top. Operations eagerly emit ISA text. """ - def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): + def __init__( + self, + mlen: int = 64, + blen: int = 4, + real_data_ratio: float = 1.125, + unroll_loops: bool = False, + mram_tile_capacity: int = 4, + ): """ Args: mlen: Matrix tile size (default 64) blen: Vector tile size (default 4) real_data_ratio: HBM storage ratio (MXFP8 format = 1.125) + mram_tile_capacity: Number of mlen x mlen tiles that fit in MRAM. unroll_loops: If True, unroll sub-projection and attention helper loops at ASM-gen time to eliminate C_LOOP_START/END overhead. Overridden by the ATEN_UNROLL env var ("1"=True, "0"=False). @@ -46,7 +54,13 @@ def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125 unroll_loops = True elif _env_unroll == "0": unroll_loops = False - super().__init__(mlen=mlen, blen=blen, real_data_ratio=real_data_ratio, unroll_loops=unroll_loops) + super().__init__( + mlen=mlen, + blen=blen, + real_data_ratio=real_data_ratio, + unroll_loops=unroll_loops, + mram_tile_capacity=mram_tile_capacity, + ) # HBM address auto-allocation self._next_hbm_addr: int = 0 diff --git a/aten/plena/isa_compiler.py b/aten/plena/isa_compiler.py index 6bb6c09..4621d65 100644 --- a/aten/plena/isa_compiler.py +++ b/aten/plena/isa_compiler.py @@ -36,9 +36,21 @@ class IsaCompiler( _ONLINE_SOFTMAX_FPSRAM_BASE = 10 - def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): + def __init__( + self, + mlen: int = 64, + blen: int = 4, + real_data_ratio: float = 1.125, + unroll_loops: bool = False, + mram_tile_capacity: int = 4, + ): # MemoryStateMixin.__init__ sets dimensions, layout tables, and memory allocators. - super().__init__(mlen=mlen, blen=blen, unroll_loops=unroll_loops) + super().__init__( + mlen=mlen, + blen=blen, + unroll_loops=unroll_loops, + mram_tile_capacity=mram_tile_capacity, + ) self.real_data_ratio = real_data_ratio self.register_allocator = RegisterAllocator() self.generated_code = "" diff --git a/aten/plena/memory.py b/aten/plena/memory.py index e09bf54..6b50ec1 100644 --- a/aten/plena/memory.py +++ b/aten/plena/memory.py @@ -357,12 +357,22 @@ class MRAMAllocator(MemoryAllocatorBase): """ Matrix RAM address allocator. - Each sub-block is mlen x mlen = 4096 elements; aligned to mlen*mlen. - Default total_size=16384 holds 4 sub-blocks (MAX_K_TILES=4). + Each sub-block is mlen x mlen elements; aligned to mlen*mlen. + By default the allocator holds four matrix tiles. """ - def __init__(self, total_size: int = MLEN * MLEN * 4): - super().__init__(total_size=total_size, alignment=MLEN * MLEN, mem_name="MRAM") + def __init__(self, total_size: int | None = None, *, mlen: int = MLEN, tile_capacity: int = 4): + if mlen <= 0: + raise ValueError(f"mlen must be > 0, got {mlen}") + if tile_capacity <= 0: + raise ValueError(f"tile_capacity must be > 0, got {tile_capacity}") + + self.mlen = mlen + self.tile_capacity = tile_capacity + self.tile_elems = mlen * mlen + if total_size is None: + total_size = self.tile_elems * tile_capacity + super().__init__(total_size=total_size, alignment=self.tile_elems, mem_name="MRAM") def allocate(self, name: str, size: int) -> int: return self._vmm.allocate(name, size) diff --git a/aten/plena/memory_state.py b/aten/plena/memory_state.py index 6e4482e..cae8bd2 100644 --- a/aten/plena/memory_state.py +++ b/aten/plena/memory_state.py @@ -18,18 +18,32 @@ class MemoryStateMixin: """Sub-matrix layout manager for PLENA HBM, VRAM, MRAM, and FPRAM state.""" - def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = False): + def __init__( + self, + mlen: int = MLEN, + blen: int = BLEN, + unroll_loops: bool = False, + mram_tile_capacity: int = 4, + ): + if mlen <= 0: + raise ValueError(f"mlen must be > 0, got {mlen}") + if mram_tile_capacity <= 0: + raise ValueError(f"mram_tile_capacity must be > 0, got {mram_tile_capacity}") + self.mlen = mlen self.blen = blen self.unroll_loops = unroll_loops + self.mram_tile_capacity = mram_tile_capacity + self.mram_tile_elems = mlen * mlen + self.mram_capacity_elems = self.mram_tile_capacity * self.mram_tile_elems # Layout tables self.hbm_matrices: dict[str, MatrixBlockLayout] = {} self.vram_matrices: dict[str, VRAMMatrixBlockLayout] = {} self.fpram_matrices: dict[str, FPRAMObjectLayout] = {} # Memory Allocators - self.vram_allocator = VRAMAllocator() - self.mram_allocator = MRAMAllocator() + self.vram_allocator = VRAMAllocator(alignment=mlen) + self.mram_allocator = MRAMAllocator(mlen=mlen, tile_capacity=mram_tile_capacity) self.fpram_allocator = FPRAMAllocator() def __contains__(self, name: str) -> bool: diff --git a/aten/plena/program_matrix_ops.py b/aten/plena/program_matrix_ops.py index 22224c6..6d21019 100644 --- a/aten/plena/program_matrix_ops.py +++ b/aten/plena/program_matrix_ops.py @@ -6,13 +6,13 @@ from compiler.aten.plena.vars import InputVar, TensorVar, VRAMMatrixVar -MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements - -def _iter_k_chunks(num_k_tiles: int): +def _iter_k_chunks(num_k_tiles: int, max_k_tiles: int): + if max_k_tiles <= 0: + raise ValueError(f"max_k_tiles must be > 0, got {max_k_tiles}") k_start = 0 while k_start < num_k_tiles: - k_end = min(k_start + MAX_K_TILES, num_k_tiles) + k_end = min(k_start + max_k_tiles, num_k_tiles) yield k_start, k_end - k_start k_start = k_end @@ -137,6 +137,7 @@ def linear_projection(self, input_var: VRAMMatrixVar, weight_var: InputVar, name raise ValueError(f"out_features ({out_features}) must be a multiple of mlen ({mlen})") num_col_blocks = out_features // mlen num_k_tiles = math.ceil(k_total / mlen) + max_k_tiles = self.mram_tile_capacity # When rows is not a multiple of mlen the hardware still operates on # full tiles; only the first `rows` rows contain valid output. @@ -154,7 +155,7 @@ def emit_projection(row_idx, col_idx, target, target_row_idx, target_col_idx, ** **k_split, ) - if num_k_tiles <= MAX_K_TILES: + if num_k_tiles <= max_k_tiles: for col_idx in range(num_col_blocks): for row_idx in range(num_row_blocks): emit_projection(row_idx, col_idx, output, row_idx, col_idx) @@ -163,7 +164,7 @@ def emit_projection(row_idx, col_idx, target, target_row_idx, target_col_idx, ** # Temp buffer for one partial-sum tile. Allocating the full output shape # here can overlap with the real output for wide projections. temp = self.alloc(f"{name}_temp", mlen, mlen) - for k_chunk_idx, (k_block_start, k_block_count) in enumerate(_iter_k_chunks(num_k_tiles)): + for k_chunk_idx, (k_block_start, k_block_count) in enumerate(_iter_k_chunks(num_k_tiles, max_k_tiles)): k_split = { "k_block_start": k_block_start, "k_block_count": k_block_count, diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 2242fc7..57633b9 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -270,6 +270,7 @@ def compile_native_hf_decoder( layer_idx_start: int = 0, mlen: int = 64, blen: int = 4, + mram_tile_capacity: int = 4, seed: int = 42, golden_precision: str = "hardware", verbose: bool = False, @@ -305,7 +306,10 @@ def _verbose(message: str = ""): f" decoder: hidden={hidden}, inter={inter}, heads={num_heads}/{num_kv_heads}, " f"head_dim={head_dim}" ) - print(f" compile: seq_len={seq_len}, mlen={mlen}, blen={blen}, total_q_dim={total_q_dim}") + print( + f" compile: seq_len={seq_len}, mlen={mlen}, blen={blen}, " + f"mram_tile_capacity={mram_tile_capacity}, total_q_dim={total_q_dim}" + ) print("=" * 80) # ----------------------------------------------------------- weights @@ -367,6 +371,7 @@ def _verbose(message: str = ""): cos_table, sin_table, mlen=mlen, + max_k_tiles=mram_tile_capacity, precision=golden_policy, trace=lambda i, x: _verbose(f" After layer {i}: X_gold[0,:4] = {x[0, :4].tolist()}"), ) @@ -384,6 +389,7 @@ def _verbose(message: str = ""): cos_table, sin_table, mlen=mlen, + max_k_tiles=mram_tile_capacity, precision=ReferencePrecision.from_mode("hf_fp32"), trace=lambda i, x: _verbose(f" After layer {i}: X_hf[0,:4] = {x[0, :4].tolist()}"), ) @@ -396,7 +402,12 @@ def _verbose(message: str = ""): registry = OpRegistry.load() registry.set_backend(Backend.PLENA) - prog = PlenaCompiler(mlen=mlen, blen=blen, real_data_ratio=REAL_DATA_RATIO) + prog = PlenaCompiler( + mlen=mlen, + blen=blen, + real_data_ratio=REAL_DATA_RATIO, + mram_tile_capacity=mram_tile_capacity, + ) # Shared inputs x_input = prog.input("X", shape=(seq_len, hidden)) diff --git a/aten/reference.py b/aten/reference.py index bde2020..8b201a5 100644 --- a/aten/reference.py +++ b/aten/reference.py @@ -111,6 +111,7 @@ def run_decoder_reference( sin_table: torch.Tensor, *, mlen: int, + max_k_tiles: int = _HW_MAX_K_TILES, precision: ReferencePrecision, trace: Callable[[int, torch.Tensor], None] | None = None, ) -> torch.Tensor: @@ -128,9 +129,10 @@ def run_decoder_reference( cos_table, sin_table, mlen, + max_k_tiles, precision, ) - x = _ffn_block_ref(x, layer, mlen, precision) + x = _ffn_block_ref(x, layer, mlen, max_k_tiles, precision) if trace is not None: trace(layer_idx, x) @@ -146,18 +148,19 @@ def _attention_block_ref( cos_table: torch.Tensor, sin_table: torch.Tensor, mlen: int, + max_k_tiles: int, precision: ReferencePrecision, ) -> torch.Tensor: quantize = precision.quantize residual = x.clone() x_normed = _rms_norm_ref(x, layer.eps, precision) - q_full = _round(_linear_ref(x_normed, quantize(layer.w_q), mlen, precision), precision) + q_full = _round(_linear_ref(x_normed, quantize(layer.w_q), mlen, max_k_tiles, precision), precision) k_heads = [] v_heads = [] for kv_h in range(config.num_kv_heads): - k_h = _linear_ref(x_normed, quantize(layer.w_k_heads[kv_h]), mlen, precision) - v_h = _linear_ref(x_normed, quantize(layer.w_v_heads[kv_h]), mlen, precision) + k_h = _linear_ref(x_normed, quantize(layer.w_k_heads[kv_h]), mlen, max_k_tiles, precision) + v_h = _linear_ref(x_normed, quantize(layer.w_v_heads[kv_h]), mlen, max_k_tiles, precision) k_h = _rope_ref(k_h, rope_matrix, cos_table, sin_table, precision) k_heads.append(_hbm_round_ref(k_h, precision)) v_heads.append(_hbm_round_ref(v_h, precision)) @@ -171,7 +174,7 @@ def _attention_block_ref( o_heads.append(_flash_attn_ref(q_h, k_heads[kv_h], v_heads[kv_h], scale, causal=True)) attn_out = _round(torch.cat(o_heads, dim=1), precision) - o_proj = _round(_linear_ref(attn_out, quantize(layer.w_o), mlen, precision), precision) + o_proj = _round(_linear_ref(attn_out, quantize(layer.w_o), mlen, max_k_tiles, precision), precision) return _residual_add_ref(o_proj, residual, precision) @@ -179,15 +182,16 @@ def _ffn_block_ref( x: torch.Tensor, layer: LayerWeights, mlen: int, + max_k_tiles: int, precision: ReferencePrecision, ) -> torch.Tensor: quantize = precision.quantize residual = x.clone() x_normed = _rms_norm_ref(x, layer.eps, precision) - up_out = _linear_ref(x_normed, quantize(layer.w_up), mlen, precision) - gate_out = _linear_ref(x_normed, quantize(layer.w_gate), mlen, precision) + up_out = _linear_ref(x_normed, quantize(layer.w_up), mlen, max_k_tiles, precision) + gate_out = _linear_ref(x_normed, quantize(layer.w_gate), mlen, max_k_tiles, precision) silu_gate = precision.to_inter(F.silu(_round(up_out, precision)) * _round(gate_out, precision)) - x = _linear_ref(precision.from_inter(silu_gate), quantize(layer.w_down), mlen, precision) + x = _linear_ref(precision.from_inter(silu_gate), quantize(layer.w_down), mlen, max_k_tiles, precision) return _residual_add_ref(_round(x, precision), residual, precision) @@ -205,6 +209,7 @@ def _linear_ref( x: torch.Tensor, weight: torch.Tensor, mlen: int, + max_k_tiles: int, precision: ReferencePrecision, ) -> torch.Tensor: if not precision.use_ksplit: @@ -213,7 +218,7 @@ def _linear_ref( x, weight, mlen=mlen, - max_k_tiles=_HW_MAX_K_TILES, + max_k_tiles=max_k_tiles, to_inter=precision.to_inter, from_inter=precision.from_inter, ) diff --git a/aten/tests/test_plena_compiler.py b/aten/tests/test_plena_compiler.py index 1011f4d..cd51e69 100644 --- a/aten/tests/test_plena_compiler.py +++ b/aten/tests/test_plena_compiler.py @@ -138,6 +138,70 @@ def test_alloc_at_correct_address(): print(" PASS test_alloc_at_correct_address") +def test_mram_allocator_scales_with_runtime_mlen(): + """MRAMAllocator capacity/alignment must use runtime mlen, not module MLEN=64.""" + from compiler.aten.plena.memory import MRAMAllocator + + alloc = MRAMAllocator(mlen=256, tile_capacity=4) + tile_elems = 256 * 256 + + assert alloc.alignment == tile_elems + assert alloc.total_size == 4 * tile_elems + assert alloc.tile_elems == tile_elems + assert alloc.tile_capacity == 4 + + addrs = [alloc.allocate(f"tile_{i}", tile_elems) for i in range(4)] + assert addrs == [0, tile_elems, 2 * tile_elems, 3 * tile_elems] + + try: + alloc.allocate("overflow", tile_elems) + except MemoryError: + pass + else: + raise AssertionError("MRAMAllocator accepted a fifth tile despite tile_capacity=4") + + print(" PASS test_mram_allocator_scales_with_runtime_mlen") + + +def test_compiler_threads_runtime_memory_geometry(): + """PlenaCompiler must pass runtime mlen/capacity into VRAM and MRAM allocators.""" + from compiler.aten.plena import PlenaCompiler + + prog = PlenaCompiler(mlen=256, blen=64, mram_tile_capacity=3) + tile_elems = 256 * 256 + + assert prog.mram_tile_capacity == 3 + assert prog.mram_tile_elems == tile_elems + assert prog.mram_capacity_elems == 3 * tile_elems + assert prog.mram_allocator.alignment == tile_elems + assert prog.mram_allocator.total_size == 3 * tile_elems + assert prog.vram_allocator.alignment == 256 + + print(" PASS test_compiler_threads_runtime_memory_geometry") + + +def test_linear_projection_uses_runtime_mram_tile_capacity(): + """K-split should follow prog.mram_tile_capacity instead of a module constant.""" + from compiler.aten.plena import PlenaCompiler + + prog = PlenaCompiler(mlen=128, blen=4, mram_tile_capacity=2) + x_input = prog.input("X", shape=(128, 384), prestaged_vram_addr=0) + x = prog.load_batch(x_input, name="X") + w = prog.input("W", shape=(384, 128)) + + prog.linear_projection(x, w, name="Y") + code = prog.compile() + + # K=384 at mlen=128 is 3 K-tiles. With runtime capacity 2 this must split + # into two projection chunks and accumulate the second partial sum. + assert prog.mram_allocator.total_size == 2 * 128 * 128 + assert "V_ADD_VV" in code + assert "linear_out_temp" not in code + assert "Y_temp" in code + + print(" PASS test_linear_projection_uses_runtime_mram_tile_capacity") + + def test_fix_large_immediates_roundtrip(): """_fix_large_immediates must preserve exact address values.""" from compiler.aten.plena_frontend import _fix_large_immediates @@ -287,6 +351,9 @@ def test_native_compile_assembles(): test_vram_fill_zero_all_column_blocks, test_vram_add_all_column_blocks, test_alloc_at_correct_address, + test_mram_allocator_scales_with_runtime_mlen, + test_compiler_threads_runtime_memory_geometry, + test_linear_projection_uses_runtime_mram_tile_capacity, test_fix_large_immediates_roundtrip, test_fix_large_immediates_preserves_relative_adds, test_rotate_half_matrix_identity,