Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions aten/plena/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,20 @@ class PlenaCompiler(
program-builder helpers on top. Operations eagerly emit ISA text.
"""

def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False):
def __init__(
self,
mlen: int = 64,
blen: int = 4,
real_data_ratio: float = 1.125,
unroll_loops: bool = False,
mram_tile_capacity: int = 4,
):
"""
Args:
mlen: Matrix tile size (default 64)
blen: Vector tile size (default 4)
real_data_ratio: HBM storage ratio (MXFP8 format = 1.125)
mram_tile_capacity: Number of mlen x mlen tiles that fit in MRAM.
unroll_loops: If True, unroll sub-projection and attention helper loops
at ASM-gen time to eliminate C_LOOP_START/END overhead.
Overridden by the ATEN_UNROLL env var ("1"=True, "0"=False).
Expand All @@ -46,7 +54,13 @@ def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125
unroll_loops = True
elif _env_unroll == "0":
unroll_loops = False
super().__init__(mlen=mlen, blen=blen, real_data_ratio=real_data_ratio, unroll_loops=unroll_loops)
super().__init__(
mlen=mlen,
blen=blen,
real_data_ratio=real_data_ratio,
unroll_loops=unroll_loops,
mram_tile_capacity=mram_tile_capacity,
)

# HBM address auto-allocation
self._next_hbm_addr: int = 0
Expand Down
16 changes: 14 additions & 2 deletions aten/plena/isa_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,21 @@ class IsaCompiler(

_ONLINE_SOFTMAX_FPSRAM_BASE = 10

def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False):
def __init__(
self,
mlen: int = 64,
blen: int = 4,
real_data_ratio: float = 1.125,
unroll_loops: bool = False,
mram_tile_capacity: int = 4,
):
# MemoryStateMixin.__init__ sets dimensions, layout tables, and memory allocators.
super().__init__(mlen=mlen, blen=blen, unroll_loops=unroll_loops)
super().__init__(
mlen=mlen,
blen=blen,
unroll_loops=unroll_loops,
mram_tile_capacity=mram_tile_capacity,
)
self.real_data_ratio = real_data_ratio
self.register_allocator = RegisterAllocator()
self.generated_code = ""
Expand Down
18 changes: 14 additions & 4 deletions aten/plena/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,22 @@ class MRAMAllocator(MemoryAllocatorBase):
"""
Matrix RAM address allocator.

Each sub-block is mlen x mlen = 4096 elements; aligned to mlen*mlen.
Default total_size=16384 holds 4 sub-blocks (MAX_K_TILES=4).
Each sub-block is mlen x mlen elements; aligned to mlen*mlen.
By default the allocator holds four matrix tiles.
"""

def __init__(self, total_size: int = MLEN * MLEN * 4):
super().__init__(total_size=total_size, alignment=MLEN * MLEN, mem_name="MRAM")
def __init__(self, total_size: int | None = None, *, mlen: int = MLEN, tile_capacity: int = 4):
if mlen <= 0:
raise ValueError(f"mlen must be > 0, got {mlen}")
if tile_capacity <= 0:
raise ValueError(f"tile_capacity must be > 0, got {tile_capacity}")

self.mlen = mlen
self.tile_capacity = tile_capacity
self.tile_elems = mlen * mlen
if total_size is None:
total_size = self.tile_elems * tile_capacity
super().__init__(total_size=total_size, alignment=self.tile_elems, mem_name="MRAM")

def allocate(self, name: str, size: int) -> int:
return self._vmm.allocate(name, size)
Expand Down
20 changes: 17 additions & 3 deletions aten/plena/memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,32 @@
class MemoryStateMixin:
"""Sub-matrix layout manager for PLENA HBM, VRAM, MRAM, and FPRAM state."""

def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = False):
def __init__(
self,
mlen: int = MLEN,
blen: int = BLEN,
unroll_loops: bool = False,
mram_tile_capacity: int = 4,
):
if mlen <= 0:
raise ValueError(f"mlen must be > 0, got {mlen}")
if mram_tile_capacity <= 0:
raise ValueError(f"mram_tile_capacity must be > 0, got {mram_tile_capacity}")

self.mlen = mlen
self.blen = blen
self.unroll_loops = unroll_loops
self.mram_tile_capacity = mram_tile_capacity
self.mram_tile_elems = mlen * mlen
self.mram_capacity_elems = self.mram_tile_capacity * self.mram_tile_elems

# Layout tables
self.hbm_matrices: dict[str, MatrixBlockLayout] = {}
self.vram_matrices: dict[str, VRAMMatrixBlockLayout] = {}
self.fpram_matrices: dict[str, FPRAMObjectLayout] = {}
# Memory Allocators
self.vram_allocator = VRAMAllocator()
self.mram_allocator = MRAMAllocator()
self.vram_allocator = VRAMAllocator(alignment=mlen)
self.mram_allocator = MRAMAllocator(mlen=mlen, tile_capacity=mram_tile_capacity)
self.fpram_allocator = FPRAMAllocator()

def __contains__(self, name: str) -> bool:
Expand Down
13 changes: 7 additions & 6 deletions aten/plena/program_matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

from compiler.aten.plena.vars import InputVar, TensorVar, VRAMMatrixVar

MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements


def _iter_k_chunks(num_k_tiles: int):
def _iter_k_chunks(num_k_tiles: int, max_k_tiles: int):
if max_k_tiles <= 0:
raise ValueError(f"max_k_tiles must be > 0, got {max_k_tiles}")
k_start = 0
while k_start < num_k_tiles:
k_end = min(k_start + MAX_K_TILES, num_k_tiles)
k_end = min(k_start + max_k_tiles, num_k_tiles)
yield k_start, k_end - k_start
k_start = k_end

Expand Down Expand Up @@ -137,6 +137,7 @@ def linear_projection(self, input_var: VRAMMatrixVar, weight_var: InputVar, name
raise ValueError(f"out_features ({out_features}) must be a multiple of mlen ({mlen})")
num_col_blocks = out_features // mlen
num_k_tiles = math.ceil(k_total / mlen)
max_k_tiles = self.mram_tile_capacity

# When rows is not a multiple of mlen the hardware still operates on
# full tiles; only the first `rows` rows contain valid output.
Expand All @@ -154,7 +155,7 @@ def emit_projection(row_idx, col_idx, target, target_row_idx, target_col_idx, **
**k_split,
)

if num_k_tiles <= MAX_K_TILES:
if num_k_tiles <= max_k_tiles:
for col_idx in range(num_col_blocks):
for row_idx in range(num_row_blocks):
emit_projection(row_idx, col_idx, output, row_idx, col_idx)
Expand All @@ -163,7 +164,7 @@ def emit_projection(row_idx, col_idx, target, target_row_idx, target_col_idx, **
# Temp buffer for one partial-sum tile. Allocating the full output shape
# here can overlap with the real output for wide projections.
temp = self.alloc(f"{name}_temp", mlen, mlen)
for k_chunk_idx, (k_block_start, k_block_count) in enumerate(_iter_k_chunks(num_k_tiles)):
for k_chunk_idx, (k_block_start, k_block_count) in enumerate(_iter_k_chunks(num_k_tiles, max_k_tiles)):
k_split = {
"k_block_start": k_block_start,
"k_block_count": k_block_count,
Expand Down
15 changes: 13 additions & 2 deletions aten/plena_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def compile_native_hf_decoder(
layer_idx_start: int = 0,
mlen: int = 64,
blen: int = 4,
mram_tile_capacity: int = 4,
seed: int = 42,
golden_precision: str = "hardware",
verbose: bool = False,
Expand Down Expand Up @@ -305,7 +306,10 @@ def _verbose(message: str = ""):
f" decoder: hidden={hidden}, inter={inter}, heads={num_heads}/{num_kv_heads}, "
f"head_dim={head_dim}"
)
print(f" compile: seq_len={seq_len}, mlen={mlen}, blen={blen}, total_q_dim={total_q_dim}")
print(
f" compile: seq_len={seq_len}, mlen={mlen}, blen={blen}, "
f"mram_tile_capacity={mram_tile_capacity}, total_q_dim={total_q_dim}"
)
print("=" * 80)

# ----------------------------------------------------------- weights
Expand Down Expand Up @@ -367,6 +371,7 @@ def _verbose(message: str = ""):
cos_table,
sin_table,
mlen=mlen,
max_k_tiles=mram_tile_capacity,
precision=golden_policy,
trace=lambda i, x: _verbose(f" After layer {i}: X_gold[0,:4] = {x[0, :4].tolist()}"),
)
Expand All @@ -384,6 +389,7 @@ def _verbose(message: str = ""):
cos_table,
sin_table,
mlen=mlen,
max_k_tiles=mram_tile_capacity,
precision=ReferencePrecision.from_mode("hf_fp32"),
trace=lambda i, x: _verbose(f" After layer {i}: X_hf[0,:4] = {x[0, :4].tolist()}"),
)
Expand All @@ -396,7 +402,12 @@ def _verbose(message: str = ""):
registry = OpRegistry.load()
registry.set_backend(Backend.PLENA)

prog = PlenaCompiler(mlen=mlen, blen=blen, real_data_ratio=REAL_DATA_RATIO)
prog = PlenaCompiler(
mlen=mlen,
blen=blen,
real_data_ratio=REAL_DATA_RATIO,
mram_tile_capacity=mram_tile_capacity,
)

# Shared inputs
x_input = prog.input("X", shape=(seq_len, hidden))
Expand Down
23 changes: 14 additions & 9 deletions aten/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def run_decoder_reference(
sin_table: torch.Tensor,
*,
mlen: int,
max_k_tiles: int = _HW_MAX_K_TILES,
precision: ReferencePrecision,
trace: Callable[[int, torch.Tensor], None] | None = None,
) -> torch.Tensor:
Expand All @@ -128,9 +129,10 @@ def run_decoder_reference(
cos_table,
sin_table,
mlen,
max_k_tiles,
precision,
)
x = _ffn_block_ref(x, layer, mlen, precision)
x = _ffn_block_ref(x, layer, mlen, max_k_tiles, precision)

if trace is not None:
trace(layer_idx, x)
Expand All @@ -146,18 +148,19 @@ def _attention_block_ref(
cos_table: torch.Tensor,
sin_table: torch.Tensor,
mlen: int,
max_k_tiles: int,
precision: ReferencePrecision,
) -> torch.Tensor:
quantize = precision.quantize
residual = x.clone()
x_normed = _rms_norm_ref(x, layer.eps, precision)
q_full = _round(_linear_ref(x_normed, quantize(layer.w_q), mlen, precision), precision)
q_full = _round(_linear_ref(x_normed, quantize(layer.w_q), mlen, max_k_tiles, precision), precision)

k_heads = []
v_heads = []
for kv_h in range(config.num_kv_heads):
k_h = _linear_ref(x_normed, quantize(layer.w_k_heads[kv_h]), mlen, precision)
v_h = _linear_ref(x_normed, quantize(layer.w_v_heads[kv_h]), mlen, precision)
k_h = _linear_ref(x_normed, quantize(layer.w_k_heads[kv_h]), mlen, max_k_tiles, precision)
v_h = _linear_ref(x_normed, quantize(layer.w_v_heads[kv_h]), mlen, max_k_tiles, precision)
k_h = _rope_ref(k_h, rope_matrix, cos_table, sin_table, precision)
k_heads.append(_hbm_round_ref(k_h, precision))
v_heads.append(_hbm_round_ref(v_h, precision))
Expand All @@ -171,23 +174,24 @@ def _attention_block_ref(
o_heads.append(_flash_attn_ref(q_h, k_heads[kv_h], v_heads[kv_h], scale, causal=True))

attn_out = _round(torch.cat(o_heads, dim=1), precision)
o_proj = _round(_linear_ref(attn_out, quantize(layer.w_o), mlen, precision), precision)
o_proj = _round(_linear_ref(attn_out, quantize(layer.w_o), mlen, max_k_tiles, precision), precision)
return _residual_add_ref(o_proj, residual, precision)


def _ffn_block_ref(
x: torch.Tensor,
layer: LayerWeights,
mlen: int,
max_k_tiles: int,
precision: ReferencePrecision,
) -> torch.Tensor:
quantize = precision.quantize
residual = x.clone()
x_normed = _rms_norm_ref(x, layer.eps, precision)
up_out = _linear_ref(x_normed, quantize(layer.w_up), mlen, precision)
gate_out = _linear_ref(x_normed, quantize(layer.w_gate), mlen, precision)
up_out = _linear_ref(x_normed, quantize(layer.w_up), mlen, max_k_tiles, precision)
gate_out = _linear_ref(x_normed, quantize(layer.w_gate), mlen, max_k_tiles, precision)
silu_gate = precision.to_inter(F.silu(_round(up_out, precision)) * _round(gate_out, precision))
x = _linear_ref(precision.from_inter(silu_gate), quantize(layer.w_down), mlen, precision)
x = _linear_ref(precision.from_inter(silu_gate), quantize(layer.w_down), mlen, max_k_tiles, precision)
return _residual_add_ref(_round(x, precision), residual, precision)


Expand All @@ -205,6 +209,7 @@ def _linear_ref(
x: torch.Tensor,
weight: torch.Tensor,
mlen: int,
max_k_tiles: int,
precision: ReferencePrecision,
) -> torch.Tensor:
if not precision.use_ksplit:
Expand All @@ -213,7 +218,7 @@ def _linear_ref(
x,
weight,
mlen=mlen,
max_k_tiles=_HW_MAX_K_TILES,
max_k_tiles=max_k_tiles,
to_inter=precision.to_inter,
from_inter=precision.from_inter,
)
Expand Down
67 changes: 67 additions & 0 deletions aten/tests/test_plena_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,70 @@ def test_alloc_at_correct_address():
print(" PASS test_alloc_at_correct_address")


def test_mram_allocator_scales_with_runtime_mlen():
"""MRAMAllocator capacity/alignment must use runtime mlen, not module MLEN=64."""
from compiler.aten.plena.memory import MRAMAllocator

alloc = MRAMAllocator(mlen=256, tile_capacity=4)
tile_elems = 256 * 256

assert alloc.alignment == tile_elems
assert alloc.total_size == 4 * tile_elems
assert alloc.tile_elems == tile_elems
assert alloc.tile_capacity == 4

addrs = [alloc.allocate(f"tile_{i}", tile_elems) for i in range(4)]
assert addrs == [0, tile_elems, 2 * tile_elems, 3 * tile_elems]

try:
alloc.allocate("overflow", tile_elems)
except MemoryError:
pass
else:
raise AssertionError("MRAMAllocator accepted a fifth tile despite tile_capacity=4")

print(" PASS test_mram_allocator_scales_with_runtime_mlen")


def test_compiler_threads_runtime_memory_geometry():
"""PlenaCompiler must pass runtime mlen/capacity into VRAM and MRAM allocators."""
from compiler.aten.plena import PlenaCompiler

prog = PlenaCompiler(mlen=256, blen=64, mram_tile_capacity=3)
tile_elems = 256 * 256

assert prog.mram_tile_capacity == 3
assert prog.mram_tile_elems == tile_elems
assert prog.mram_capacity_elems == 3 * tile_elems
assert prog.mram_allocator.alignment == tile_elems
assert prog.mram_allocator.total_size == 3 * tile_elems
assert prog.vram_allocator.alignment == 256

print(" PASS test_compiler_threads_runtime_memory_geometry")


def test_linear_projection_uses_runtime_mram_tile_capacity():
"""K-split should follow prog.mram_tile_capacity instead of a module constant."""
from compiler.aten.plena import PlenaCompiler

prog = PlenaCompiler(mlen=128, blen=4, mram_tile_capacity=2)
x_input = prog.input("X", shape=(128, 384), prestaged_vram_addr=0)
x = prog.load_batch(x_input, name="X")
w = prog.input("W", shape=(384, 128))

prog.linear_projection(x, w, name="Y")
code = prog.compile()

# K=384 at mlen=128 is 3 K-tiles. With runtime capacity 2 this must split
# into two projection chunks and accumulate the second partial sum.
assert prog.mram_allocator.total_size == 2 * 128 * 128
assert "V_ADD_VV" in code
assert "linear_out_temp" not in code
assert "Y_temp" in code

print(" PASS test_linear_projection_uses_runtime_mram_tile_capacity")


def test_fix_large_immediates_roundtrip():
"""_fix_large_immediates must preserve exact address values."""
from compiler.aten.plena_frontend import _fix_large_immediates
Expand Down Expand Up @@ -287,6 +351,9 @@ def test_native_compile_assembles():
test_vram_fill_zero_all_column_blocks,
test_vram_add_all_column_blocks,
test_alloc_at_correct_address,
test_mram_allocator_scales_with_runtime_mlen,
test_compiler_threads_runtime_memory_geometry,
test_linear_projection_uses_runtime_mram_tile_capacity,
test_fix_large_immediates_roundtrip,
test_fix_large_immediates_preserves_relative_adds,
test_rotate_half_matrix_identity,
Expand Down
Loading