From 85c268f1d3adafba331c03b03c87c9be7364a06d Mon Sep 17 00:00:00 2001 From: booth-algo Date: Mon, 27 Apr 2026 20:09:32 +0100 Subject: [PATCH 01/32] =?UTF-8?q?feat(aten):=20automatic=20HF=20model=20?= =?UTF-8?q?=E2=86=92=20PLENA=20ISA=20compiler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add model_compiler.py that walks an HF nn.Module tree and compiles to PLENA ISA via PlenaCompiler with pre-norm + residual connections. Two public functions: compile_hf_model(model, ...) → dict with ISA, golden, tensors compile_and_run(model, build_dir, ...) → compile + emulate + compare Verified: SmolLM2-135M 2-layer at 93.65% allclose. --- aten/model_compiler.py | 501 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 501 insertions(+) create mode 100644 aten/model_compiler.py diff --git a/aten/model_compiler.py b/aten/model_compiler.py new file mode 100644 index 0000000..1601dd5 --- /dev/null +++ b/aten/model_compiler.py @@ -0,0 +1,501 @@ +""" +Automatic HuggingFace model -> PLENA ISA compiler. + +Walks an HF nn.Module tree, extracts weights, and generates ISA with +proper residual connections for multi-layer decoder pipelines. + +Usage: + from transformers import AutoModelForCausalLM + from compiler.aten.model_compiler import compile_hf_model, compile_and_run + + model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", + torch_dtype=torch.float32) + # Just compile (no emulator run): + result = compile_hf_model(model, seq_len=64, hidden_size=64, inter_dim=128, num_layers=2) + + # Compile + run emulator + compare: + result = compile_and_run(model, "/tmp/build", seq_len=64, hidden_size=64, inter_dim=128, num_layers=2) +""" + +import json +import math +from pathlib import Path + +import torch +import torch.nn.functional as F + +from compiler.aten.plena_compiler import PlenaCompiler +from compiler.aten.ops.registry import OpRegistry, Backend +import compiler.aten.ops as ops +from transactional_emulator.testbench.model_layer_test_builder import quantize_to_mxfp + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +REAL_DATA_RATIO = (8 * 8 + 8) / (8 * 8) + + +# --------------------------------------------------------------------------- +# Model structure helpers +# --------------------------------------------------------------------------- +def _find_model_root(model): + """Find the transformer backbone (model.model or model.model.text_model). + + Handles standard CausalLM models and VLMs like SmolVLM2. + """ + for candidate in [ + getattr(model, "model", None), + getattr(getattr(model, "model", None), "text_model", None), + getattr(model, "language_model", getattr(model, "text_model", None)), + ]: + if candidate is not None and hasattr(candidate, "layers"): + return candidate + raise ValueError(f"Cannot find decoder layers on {type(model).__name__}") + + +def _extract_config(model): + """Extract config dimensions from the model, resolving text_config for VLMs.""" + config = getattr(model.config, "text_config", model.config) + native_hidden = config.hidden_size + native_inter = getattr(config, "intermediate_size", 4 * native_hidden) + native_heads = config.num_attention_heads + native_kv_heads = getattr(config, "num_key_value_heads", native_heads) + native_head_dim = native_hidden // native_heads + eps = getattr(config, "rms_norm_eps", 1e-5) + rope_theta = getattr(config, "rope_theta", 10000.0) + return { + "hidden_size": native_hidden, + "inter_dim": native_inter, + "num_heads": native_heads, + "num_kv_heads": native_kv_heads, + "head_dim": native_head_dim, + "eps": eps, + "rope_theta": rope_theta, + "model_type": getattr(config, "model_type", "unknown"), + } + + +def _extract_layer_weights(layer, hidden_slice, inter_slice, head_dim_slice): + """Extract and slice weights from a single decoder layer. + + Transposes from HF's (out_features, in_features) to PLENA's (in, out) convention. + Uses KV-head 0 and caps head_dim to hidden_slice for sim compatibility. + + Args: + layer: nn.Module for a single decoder layer + hidden_slice: target hidden dimension + inter_slice: target intermediate dimension + head_dim_slice: target head dimension (min of native head_dim, hidden_slice) + + Returns: + dict with W_gate, W_up, W_down, W_k, W_v, eps + """ + # FFN weights: HF stores (out, in) -> transpose to (in, out) -> slice + W_gate = layer.mlp.gate_proj.weight.detach().T.contiguous()[:hidden_slice, :inter_slice] + W_up = layer.mlp.up_proj.weight.detach().T.contiguous()[:hidden_slice, :inter_slice] + W_down = layer.mlp.down_proj.weight.detach().T.contiguous()[:inter_slice, :hidden_slice] + + # Attention: KV projections, sliced to sim dims + W_k = layer.self_attn.k_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] + W_v = layer.self_attn.v_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] + + # eps from input_layernorm + norm = layer.input_layernorm + eps = getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-5)) + + return { + "W_gate": W_gate, + "W_up": W_up, + "W_down": W_down, + "W_k": W_k, + "W_v": W_v, + "eps": eps, + } + + +# --------------------------------------------------------------------------- +# Golden reference helpers (match hardware: MXFP8 HBM + BF16 intermediates) +# --------------------------------------------------------------------------- +def _flash_attn_ref(Q, K, V, scale): + """CPU reference: scaled dot-product attention.""" + scores = (Q @ K.T) * scale + attn = F.softmax(scores, dim=-1) + return attn @ V + + +def _rms_norm_ref(x, eps): + """CPU reference: RMS normalization matching PLENA hardware (BF16 intermediate).""" + x_bf = x.to(torch.bfloat16) + rms = torch.rsqrt(x_bf.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) + return (x_bf * rms).float() + + +# --------------------------------------------------------------------------- +# Main compilation function +# --------------------------------------------------------------------------- +def compile_hf_model( + model, + seq_len: int = 64, + hidden_size: int | None = None, + inter_dim: int | None = None, + num_layers: int | None = None, + layer_idx_start: int = 0, + mlen: int = 64, + blen: int = 4, + seed: int = 42, +) -> dict: + """Compile a HuggingFace model to PLENA ISA via PlenaCompiler. + + Walks the nn.Module tree, extracts weights, and generates ISA with + proper residual connections for multi-layer decoders. + + The pipeline implemented (per-layer, with pre-norm + residual): + X = embedding_add(token_embeds, pos_weight) + for each layer: + residual = X + X = rms_norm(X) + X = flash_attention(X, K, V, scale) + X = X + residual + residual = X + X = rms_norm(X) + X = ffn(X, gate, up, down) + X = X + residual + X = rms_norm(X) # final norm + + RoPE is omitted (orthogonal to multi-layer compilation testing). + + Args: + model: nn.Module (HF CausalLM model, already loaded) + seq_len: Sequence length (default 64) + hidden_size: Target hidden dimension (None = use model's native) + inter_dim: Target intermediate dimension (None = use model's native) + num_layers: Number of layers to compile (None = all layers) + layer_idx_start: First layer index to use (default 0) + mlen: Matrix tile length (default 64) + blen: Batch tile length (default 4) + seed: Random seed for test data generation + + Returns: + dict with: + isa: str - generated ISA code + golden_output: torch.Tensor - CPU golden reference output + input_tensors: dict - {name: tensor} for sim env setup + data_order: list[str] - HBM tensor ordering + fp_preload: list[float] - FPRAM constants + comparison_params: dict - for emulator comparison + info: dict - model dims, VRAM usage, etc. + """ + # -------------------------------------------------------------- config + native_cfg = _extract_config(model) + hidden = hidden_size if hidden_size is not None else native_cfg["hidden_size"] + inter = inter_dim if inter_dim is not None else native_cfg["inter_dim"] + head_dim = min(native_cfg["head_dim"], hidden) # head_dim must fit in hidden + + root = _find_model_root(model) + layers = root.layers + n_layers = num_layers if num_layers is not None else len(layers) + assert layer_idx_start + n_layers <= len(layers), ( + f"Requested layers [{layer_idx_start}, {layer_idx_start + n_layers}) " + f"but model only has {len(layers)} layers" + ) + + scale = 1.0 / math.sqrt(head_dim) + + print("=" * 80) + print(f"Model Compiler - {native_cfg['model_type']} ({n_layers} layers)") + print(f" native: hidden={native_cfg['hidden_size']}, inter={native_cfg['inter_dim']}, " + f"heads={native_cfg['num_heads']}/{native_cfg['num_kv_heads']}, " + f"head_dim={native_cfg['head_dim']}") + print(f" sim: hidden={hidden}, inter={inter}, head_dim={head_dim}, " + f"seq_len={seq_len}, mlen={mlen}") + print("=" * 80) + + # ----------------------------------------------------------- weights + print(f"\nExtracting weights from layers {layer_idx_start}..{layer_idx_start + n_layers - 1}...") + all_weights = [] + for i in range(n_layers): + layer = layers[layer_idx_start + i] + w = _extract_layer_weights(layer, hidden, inter, head_dim) + all_weights.append(w) + print(f" Layer {i}: W_gate={w['W_gate'].shape}, W_k={w['W_k'].shape}, eps={w['eps']}") + + eps = all_weights[0]["eps"] + + # ----------------------------------------------------------- test data + torch.manual_seed(seed) + token_embeds = torch.randn(seq_len, hidden) + pos_weight = torch.randn(seq_len, hidden) + + K_mats = [] + V_mats = [] + for i in range(n_layers): + X_ctx = torch.randn(seq_len, hidden) + K_mats.append(X_ctx @ all_weights[i]["W_k"]) + V_mats.append(X_ctx @ all_weights[i]["W_v"]) + + print(f"\ntoken_embeds: {token_embeds.shape}") + print(f"pos_weight: {pos_weight.shape}") + for i in range(n_layers): + print(f" K_{i}: {K_mats[i].shape}, V_{i}: {V_mats[i].shape}") + print(f"attn_scale: {scale:.6f}") + + # ----------------------------------------------------------- golden ref + K_q_list = [quantize_to_mxfp(K_mats[i]) for i in range(n_layers)] + V_q_list = [quantize_to_mxfp(V_mats[i]) for i in range(n_layers)] + + print("\n--- CPU Golden Reference (MXFP8 quantized HBM + BF16 intermediates) ---") + + X_gold = token_embeds.clone() + pos_weight # embedding_add + + for i in range(n_layers): + w = all_weights[i] + W_gate_q = quantize_to_mxfp(w["W_gate"]) + W_up_q = quantize_to_mxfp(w["W_up"]) + W_down_q = quantize_to_mxfp(w["W_down"]) + + # --- Attention block --- + residual = X_gold.clone() + # rms_norm with bfloat16 to match PLENA + X_bf = X_gold.to(torch.bfloat16) + rms = torch.rsqrt(X_bf.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) + X_gold = (X_bf * rms).float() + # flash attention (no RoPE) + X_gold = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale) + # attention residual + X_gold = X_gold + residual + + # --- FFN block --- + residual = X_gold.clone() + # rms_norm with bfloat16 + X_bf = X_gold.to(torch.bfloat16) + rms = torch.rsqrt(X_bf.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) + X_gold = (X_bf * rms).float() + # FFN with MXFP8 weights + BF16 intermediates + up_out = torch.matmul(X_gold.to(torch.bfloat16).float(), W_up_q.float()).to(torch.bfloat16) + gate_out = torch.matmul(X_gold.to(torch.bfloat16).float(), W_gate_q.float()).to(torch.bfloat16) + silu_gate = (F.silu(up_out.float()) * gate_out.float()).to(torch.bfloat16) + X_gold = torch.matmul(silu_gate.float(), W_down_q.float()).to(torch.bfloat16).float() + # FFN residual + X_gold = X_gold + residual + + print(f" After layer {i}: X_gold[0,:4] = {X_gold[0, :4].tolist()}") + + # Final norm + X_gold = _rms_norm_ref(X_gold, eps) + + golden_out = X_gold + print(f" golden_out: {golden_out.shape}") + print(f" golden_out[0,:4]: {golden_out[0, :4].tolist()}") + + # ----------------------------------------------------------- PLENA ISA + print("\n--- PLENA Backend (ISA generation) ---") + registry = OpRegistry.load() + registry.set_backend(Backend.PLENA) + + prog = PlenaCompiler(mlen=mlen, blen=blen, real_data_ratio=REAL_DATA_RATIO) + + # Shared inputs + x_input = prog.input("X", shape=(seq_len, hidden)) + pos_input = prog.input("POS", shape=(seq_len, hidden)) + + # Per-layer weight inputs (order determines HBM layout) + layer_inputs = [] + for i in range(n_layers): + ki = prog.input(f"K_{i}", shape=(seq_len, hidden)) + vi = prog.input(f"V_{i}", shape=(seq_len, hidden)) + wg = prog.input(f"W_gate_{i}", shape=(hidden, inter)) + wu = prog.input(f"W_up_{i}", shape=(hidden, inter)) + wd = prog.input(f"W_down_{i}", shape=(inter, hidden)) + layer_inputs.append({"K": ki, "V": vi, "W_gate": wg, "W_up": wu, "W_down": wd}) + + # Load activations to VRAM + X_batch = prog.load_batch(x_input, name="X") + POS_batch = prog.load_batch(pos_input, name="POS") + ops.embedding_add(prog, X_batch, POS_batch) # X += POS in-place + + # VRAM layout hazard: ffn_asm writes gate/up intermediates at absolute + # address batch*hidden spanning up to batch*hidden + 2*inter*batch. + # The residual scratch buffer must be placed ABOVE this region. + _ffn_intermediate_end = seq_len * hidden + 2 * inter * seq_len + _current_bump = 2 * seq_len * hidden # X + POS already allocated + if _current_bump < _ffn_intermediate_end: + _pad_size = _ffn_intermediate_end - _current_bump + _pad_rows = max(1, _pad_size // hidden) + prog.alloc("_vram_padding", _pad_rows, hidden) + + # Allocate scratch buffer for residual save/restore (reused across layers) + scratch = prog.alloc("residual_scratch", seq_len, hidden) + + # Chain layers + current = X_batch + + for i in range(n_layers): + li = layer_inputs[i] + + # --- Attention block --- + # Save residual: scratch = current (zero then add) + prog.vram_fill_zero(scratch) + prog.vram_add(scratch, current) + + # Norm (in-place on current) + prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) + + # Flash attention (no RoPE) - current is Q after norm + O = ops.flash_attention(prog, current, li["K"], li["V"], scale) + + # Attention residual: O += scratch + prog.vram_add(O, scratch) + + # --- FFN block --- + # Save residual: scratch = O (zero then add) + prog.vram_fill_zero(scratch) + prog.vram_add(scratch, O) + + # Norm (in-place on O) + prog.rms_norm(O, eps_offset=3, reci_hid_offset=4) + + # FFN (in-place on O) + ops.ffn(prog, O, li["W_gate"], li["W_up"], li["W_down"]) + + # FFN residual: O += scratch + prog.vram_add(O, scratch) + + current = O # carry forward + + # Final norm + prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) + + isa_code = prog.compile() + lines = isa_code.splitlines() + print(f"\nGenerated {len(lines)} lines of ISA code") + + # ----------------------------------------------------------- build return + input_tensors = {"X": token_embeds, "POS": pos_weight} + data_order = ["X", "POS"] + for i in range(n_layers): + input_tensors[f"K_{i}"] = K_mats[i] + input_tensors[f"V_{i}"] = V_mats[i] + input_tensors[f"W_gate_{i}"] = all_weights[i]["W_gate"] + input_tensors[f"W_up_{i}"] = all_weights[i]["W_up"] + input_tensors[f"W_down_{i}"] = all_weights[i]["W_down"] + data_order.extend([f"K_{i}", f"V_{i}", f"W_gate_{i}", f"W_up_{i}", f"W_down_{i}"]) + + # FPRAM layout (same as single-layer decoder): + # slot 0 = 0.0 (reserved) + # slot 1 = attn_scale (flash_attention) + # slot 2 = -inf (flash_attention softmax mask) + # slot 3 = eps (rms_norm, offset=3) + # slot 4 = 1/hidden (rms_norm, offset=4) + # slot 5 = 1.0 (FFN SiLU) + # slots 6-9 = 0.0 (padding) + fp_preload = [0.0, scale, float("-inf"), eps, 1.0 / hidden, 1.0] + [0.0] * 4 + + # Result is at current's VRAM location (last O from flash_attention chain) + o_vram_addr = prog._compiler.get_vram_addr(current.name) + + comparison_params = { + "start_row_idx": o_vram_addr // mlen, + "num_rows": (seq_len * hidden) // mlen, + "num_batches": seq_len, + "elements_per_batch": hidden, + "row_dim": mlen, + "use_stride_mode": hidden > mlen, + } + + info = { + "model_type": native_cfg["model_type"], + "hidden_size": hidden, + "inter_dim": inter, + "num_layers": n_layers, + "seq_len": seq_len, + "head_dim": head_dim, + "mlen": mlen, + "blen": blen, + "isa_lines": len(lines), + } + + print(f"\nCompilation complete: {info['isa_lines']} ISA lines, " + f"{n_layers} layers, output at VRAM row {o_vram_addr // mlen}") + + return { + "isa": isa_code, + "golden_output": golden_out, + "input_tensors": input_tensors, + "data_order": data_order, + "fp_preload": fp_preload, + "comparison_params": comparison_params, + "info": info, + } + + +# --------------------------------------------------------------------------- +# Convenience: compile + run emulator + compare +# --------------------------------------------------------------------------- +def compile_and_run( + model, + build_dir, + **kwargs, +) -> dict: + """Compile, run emulator, and compare against golden. + + Convenience wrapper that calls compile_hf_model, sets up simulation + environment, runs the Rust transactional emulator, and compares output. + + Args: + model: nn.Module (HF CausalLM model, already loaded) + build_dir: Directory for simulation artifacts + **kwargs: Forwarded to compile_hf_model (seq_len, hidden_size, etc.) + + Returns: + dict with compilation info + comparison results including + 'allclose_match_rate' percentage. + """ + from transactional_emulator.tools.create_sim_env import create_sim_env + from compiler.sim_env_utils.build_env import create_mem_for_sim + from transactional_emulator.testbench.emulator_runner import ( + run_and_assert, + compare_emulator_output, + ) + + result = compile_hf_model(model, **kwargs) + build_dir = Path(build_dir) + build_dir.mkdir(parents=True, exist_ok=True) + + mlen = kwargs.get("mlen", 64) + blen = kwargs.get("blen", 4) + asm_name = f"model_{result['info']['model_type']}_{result['info']['num_layers']}L" + + # Write sim env files + create_sim_env( + result["input_tensors"], + result["isa"], + {"original_output": result["golden_output"]}, + result["fp_preload"], + build_dir=str(build_dir), + ) + + create_mem_for_sim( + data_size=256, + mode="behave_sim", + asm=asm_name, + data=None, + specified_data_order=result["data_order"], + build_path=build_dir, + ) + + with open(build_dir / "comparison_params.json", "w") as f: + json.dump(result["comparison_params"], f, indent=2) + + with open(build_dir / "generated_asm_code.asm", "w") as f: + f.write(result["isa"]) + + print(f"\nSimulation environment created: {build_dir}") + print(f" Result location: VRAM row {result['comparison_params']['start_row_idx']}") + print(f" Layers: {result['info']['num_layers']}, data_order: {result['data_order']}") + + # Run emulator and compare + run_and_assert(build_dir, asm_name, mlen=mlen, blen=blen) + + comp_results, _ = compare_emulator_output(build_dir) + return {**result["info"], **comp_results} From 6df4a4c1906c64d889b23aab38b10f14cfb5d7c7 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Mon, 27 Apr 2026 20:17:33 +0100 Subject: [PATCH 02/32] =?UTF-8?q?rename:=20model=5Fcompiler.py=20=E2=86=92?= =?UTF-8?q?=20plena=5Fparser.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aten/{model_compiler.py => plena_parser.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename aten/{model_compiler.py => plena_parser.py} (100%) diff --git a/aten/model_compiler.py b/aten/plena_parser.py similarity index 100% rename from aten/model_compiler.py rename to aten/plena_parser.py From 1f65319a46acb279628d0e31e79e86fce608f0c1 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Mon, 27 Apr 2026 20:21:47 +0100 Subject: [PATCH 03/32] =?UTF-8?q?rename:=20plena=5Fparser.py=20=E2=86=92?= =?UTF-8?q?=20plena=5Ffrontend.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aten/plena_parser.py | 501 ------------------------------------------- 1 file changed, 501 deletions(-) delete mode 100644 aten/plena_parser.py diff --git a/aten/plena_parser.py b/aten/plena_parser.py deleted file mode 100644 index 1601dd5..0000000 --- a/aten/plena_parser.py +++ /dev/null @@ -1,501 +0,0 @@ -""" -Automatic HuggingFace model -> PLENA ISA compiler. - -Walks an HF nn.Module tree, extracts weights, and generates ISA with -proper residual connections for multi-layer decoder pipelines. - -Usage: - from transformers import AutoModelForCausalLM - from compiler.aten.model_compiler import compile_hf_model, compile_and_run - - model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", - torch_dtype=torch.float32) - # Just compile (no emulator run): - result = compile_hf_model(model, seq_len=64, hidden_size=64, inter_dim=128, num_layers=2) - - # Compile + run emulator + compare: - result = compile_and_run(model, "/tmp/build", seq_len=64, hidden_size=64, inter_dim=128, num_layers=2) -""" - -import json -import math -from pathlib import Path - -import torch -import torch.nn.functional as F - -from compiler.aten.plena_compiler import PlenaCompiler -from compiler.aten.ops.registry import OpRegistry, Backend -import compiler.aten.ops as ops -from transactional_emulator.testbench.model_layer_test_builder import quantize_to_mxfp - - -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- -REAL_DATA_RATIO = (8 * 8 + 8) / (8 * 8) - - -# --------------------------------------------------------------------------- -# Model structure helpers -# --------------------------------------------------------------------------- -def _find_model_root(model): - """Find the transformer backbone (model.model or model.model.text_model). - - Handles standard CausalLM models and VLMs like SmolVLM2. - """ - for candidate in [ - getattr(model, "model", None), - getattr(getattr(model, "model", None), "text_model", None), - getattr(model, "language_model", getattr(model, "text_model", None)), - ]: - if candidate is not None and hasattr(candidate, "layers"): - return candidate - raise ValueError(f"Cannot find decoder layers on {type(model).__name__}") - - -def _extract_config(model): - """Extract config dimensions from the model, resolving text_config for VLMs.""" - config = getattr(model.config, "text_config", model.config) - native_hidden = config.hidden_size - native_inter = getattr(config, "intermediate_size", 4 * native_hidden) - native_heads = config.num_attention_heads - native_kv_heads = getattr(config, "num_key_value_heads", native_heads) - native_head_dim = native_hidden // native_heads - eps = getattr(config, "rms_norm_eps", 1e-5) - rope_theta = getattr(config, "rope_theta", 10000.0) - return { - "hidden_size": native_hidden, - "inter_dim": native_inter, - "num_heads": native_heads, - "num_kv_heads": native_kv_heads, - "head_dim": native_head_dim, - "eps": eps, - "rope_theta": rope_theta, - "model_type": getattr(config, "model_type", "unknown"), - } - - -def _extract_layer_weights(layer, hidden_slice, inter_slice, head_dim_slice): - """Extract and slice weights from a single decoder layer. - - Transposes from HF's (out_features, in_features) to PLENA's (in, out) convention. - Uses KV-head 0 and caps head_dim to hidden_slice for sim compatibility. - - Args: - layer: nn.Module for a single decoder layer - hidden_slice: target hidden dimension - inter_slice: target intermediate dimension - head_dim_slice: target head dimension (min of native head_dim, hidden_slice) - - Returns: - dict with W_gate, W_up, W_down, W_k, W_v, eps - """ - # FFN weights: HF stores (out, in) -> transpose to (in, out) -> slice - W_gate = layer.mlp.gate_proj.weight.detach().T.contiguous()[:hidden_slice, :inter_slice] - W_up = layer.mlp.up_proj.weight.detach().T.contiguous()[:hidden_slice, :inter_slice] - W_down = layer.mlp.down_proj.weight.detach().T.contiguous()[:inter_slice, :hidden_slice] - - # Attention: KV projections, sliced to sim dims - W_k = layer.self_attn.k_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] - W_v = layer.self_attn.v_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] - - # eps from input_layernorm - norm = layer.input_layernorm - eps = getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-5)) - - return { - "W_gate": W_gate, - "W_up": W_up, - "W_down": W_down, - "W_k": W_k, - "W_v": W_v, - "eps": eps, - } - - -# --------------------------------------------------------------------------- -# Golden reference helpers (match hardware: MXFP8 HBM + BF16 intermediates) -# --------------------------------------------------------------------------- -def _flash_attn_ref(Q, K, V, scale): - """CPU reference: scaled dot-product attention.""" - scores = (Q @ K.T) * scale - attn = F.softmax(scores, dim=-1) - return attn @ V - - -def _rms_norm_ref(x, eps): - """CPU reference: RMS normalization matching PLENA hardware (BF16 intermediate).""" - x_bf = x.to(torch.bfloat16) - rms = torch.rsqrt(x_bf.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) - return (x_bf * rms).float() - - -# --------------------------------------------------------------------------- -# Main compilation function -# --------------------------------------------------------------------------- -def compile_hf_model( - model, - seq_len: int = 64, - hidden_size: int | None = None, - inter_dim: int | None = None, - num_layers: int | None = None, - layer_idx_start: int = 0, - mlen: int = 64, - blen: int = 4, - seed: int = 42, -) -> dict: - """Compile a HuggingFace model to PLENA ISA via PlenaCompiler. - - Walks the nn.Module tree, extracts weights, and generates ISA with - proper residual connections for multi-layer decoders. - - The pipeline implemented (per-layer, with pre-norm + residual): - X = embedding_add(token_embeds, pos_weight) - for each layer: - residual = X - X = rms_norm(X) - X = flash_attention(X, K, V, scale) - X = X + residual - residual = X - X = rms_norm(X) - X = ffn(X, gate, up, down) - X = X + residual - X = rms_norm(X) # final norm - - RoPE is omitted (orthogonal to multi-layer compilation testing). - - Args: - model: nn.Module (HF CausalLM model, already loaded) - seq_len: Sequence length (default 64) - hidden_size: Target hidden dimension (None = use model's native) - inter_dim: Target intermediate dimension (None = use model's native) - num_layers: Number of layers to compile (None = all layers) - layer_idx_start: First layer index to use (default 0) - mlen: Matrix tile length (default 64) - blen: Batch tile length (default 4) - seed: Random seed for test data generation - - Returns: - dict with: - isa: str - generated ISA code - golden_output: torch.Tensor - CPU golden reference output - input_tensors: dict - {name: tensor} for sim env setup - data_order: list[str] - HBM tensor ordering - fp_preload: list[float] - FPRAM constants - comparison_params: dict - for emulator comparison - info: dict - model dims, VRAM usage, etc. - """ - # -------------------------------------------------------------- config - native_cfg = _extract_config(model) - hidden = hidden_size if hidden_size is not None else native_cfg["hidden_size"] - inter = inter_dim if inter_dim is not None else native_cfg["inter_dim"] - head_dim = min(native_cfg["head_dim"], hidden) # head_dim must fit in hidden - - root = _find_model_root(model) - layers = root.layers - n_layers = num_layers if num_layers is not None else len(layers) - assert layer_idx_start + n_layers <= len(layers), ( - f"Requested layers [{layer_idx_start}, {layer_idx_start + n_layers}) " - f"but model only has {len(layers)} layers" - ) - - scale = 1.0 / math.sqrt(head_dim) - - print("=" * 80) - print(f"Model Compiler - {native_cfg['model_type']} ({n_layers} layers)") - print(f" native: hidden={native_cfg['hidden_size']}, inter={native_cfg['inter_dim']}, " - f"heads={native_cfg['num_heads']}/{native_cfg['num_kv_heads']}, " - f"head_dim={native_cfg['head_dim']}") - print(f" sim: hidden={hidden}, inter={inter}, head_dim={head_dim}, " - f"seq_len={seq_len}, mlen={mlen}") - print("=" * 80) - - # ----------------------------------------------------------- weights - print(f"\nExtracting weights from layers {layer_idx_start}..{layer_idx_start + n_layers - 1}...") - all_weights = [] - for i in range(n_layers): - layer = layers[layer_idx_start + i] - w = _extract_layer_weights(layer, hidden, inter, head_dim) - all_weights.append(w) - print(f" Layer {i}: W_gate={w['W_gate'].shape}, W_k={w['W_k'].shape}, eps={w['eps']}") - - eps = all_weights[0]["eps"] - - # ----------------------------------------------------------- test data - torch.manual_seed(seed) - token_embeds = torch.randn(seq_len, hidden) - pos_weight = torch.randn(seq_len, hidden) - - K_mats = [] - V_mats = [] - for i in range(n_layers): - X_ctx = torch.randn(seq_len, hidden) - K_mats.append(X_ctx @ all_weights[i]["W_k"]) - V_mats.append(X_ctx @ all_weights[i]["W_v"]) - - print(f"\ntoken_embeds: {token_embeds.shape}") - print(f"pos_weight: {pos_weight.shape}") - for i in range(n_layers): - print(f" K_{i}: {K_mats[i].shape}, V_{i}: {V_mats[i].shape}") - print(f"attn_scale: {scale:.6f}") - - # ----------------------------------------------------------- golden ref - K_q_list = [quantize_to_mxfp(K_mats[i]) for i in range(n_layers)] - V_q_list = [quantize_to_mxfp(V_mats[i]) for i in range(n_layers)] - - print("\n--- CPU Golden Reference (MXFP8 quantized HBM + BF16 intermediates) ---") - - X_gold = token_embeds.clone() + pos_weight # embedding_add - - for i in range(n_layers): - w = all_weights[i] - W_gate_q = quantize_to_mxfp(w["W_gate"]) - W_up_q = quantize_to_mxfp(w["W_up"]) - W_down_q = quantize_to_mxfp(w["W_down"]) - - # --- Attention block --- - residual = X_gold.clone() - # rms_norm with bfloat16 to match PLENA - X_bf = X_gold.to(torch.bfloat16) - rms = torch.rsqrt(X_bf.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) - X_gold = (X_bf * rms).float() - # flash attention (no RoPE) - X_gold = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale) - # attention residual - X_gold = X_gold + residual - - # --- FFN block --- - residual = X_gold.clone() - # rms_norm with bfloat16 - X_bf = X_gold.to(torch.bfloat16) - rms = torch.rsqrt(X_bf.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) - X_gold = (X_bf * rms).float() - # FFN with MXFP8 weights + BF16 intermediates - up_out = torch.matmul(X_gold.to(torch.bfloat16).float(), W_up_q.float()).to(torch.bfloat16) - gate_out = torch.matmul(X_gold.to(torch.bfloat16).float(), W_gate_q.float()).to(torch.bfloat16) - silu_gate = (F.silu(up_out.float()) * gate_out.float()).to(torch.bfloat16) - X_gold = torch.matmul(silu_gate.float(), W_down_q.float()).to(torch.bfloat16).float() - # FFN residual - X_gold = X_gold + residual - - print(f" After layer {i}: X_gold[0,:4] = {X_gold[0, :4].tolist()}") - - # Final norm - X_gold = _rms_norm_ref(X_gold, eps) - - golden_out = X_gold - print(f" golden_out: {golden_out.shape}") - print(f" golden_out[0,:4]: {golden_out[0, :4].tolist()}") - - # ----------------------------------------------------------- PLENA ISA - print("\n--- PLENA Backend (ISA generation) ---") - registry = OpRegistry.load() - registry.set_backend(Backend.PLENA) - - prog = PlenaCompiler(mlen=mlen, blen=blen, real_data_ratio=REAL_DATA_RATIO) - - # Shared inputs - x_input = prog.input("X", shape=(seq_len, hidden)) - pos_input = prog.input("POS", shape=(seq_len, hidden)) - - # Per-layer weight inputs (order determines HBM layout) - layer_inputs = [] - for i in range(n_layers): - ki = prog.input(f"K_{i}", shape=(seq_len, hidden)) - vi = prog.input(f"V_{i}", shape=(seq_len, hidden)) - wg = prog.input(f"W_gate_{i}", shape=(hidden, inter)) - wu = prog.input(f"W_up_{i}", shape=(hidden, inter)) - wd = prog.input(f"W_down_{i}", shape=(inter, hidden)) - layer_inputs.append({"K": ki, "V": vi, "W_gate": wg, "W_up": wu, "W_down": wd}) - - # Load activations to VRAM - X_batch = prog.load_batch(x_input, name="X") - POS_batch = prog.load_batch(pos_input, name="POS") - ops.embedding_add(prog, X_batch, POS_batch) # X += POS in-place - - # VRAM layout hazard: ffn_asm writes gate/up intermediates at absolute - # address batch*hidden spanning up to batch*hidden + 2*inter*batch. - # The residual scratch buffer must be placed ABOVE this region. - _ffn_intermediate_end = seq_len * hidden + 2 * inter * seq_len - _current_bump = 2 * seq_len * hidden # X + POS already allocated - if _current_bump < _ffn_intermediate_end: - _pad_size = _ffn_intermediate_end - _current_bump - _pad_rows = max(1, _pad_size // hidden) - prog.alloc("_vram_padding", _pad_rows, hidden) - - # Allocate scratch buffer for residual save/restore (reused across layers) - scratch = prog.alloc("residual_scratch", seq_len, hidden) - - # Chain layers - current = X_batch - - for i in range(n_layers): - li = layer_inputs[i] - - # --- Attention block --- - # Save residual: scratch = current (zero then add) - prog.vram_fill_zero(scratch) - prog.vram_add(scratch, current) - - # Norm (in-place on current) - prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) - - # Flash attention (no RoPE) - current is Q after norm - O = ops.flash_attention(prog, current, li["K"], li["V"], scale) - - # Attention residual: O += scratch - prog.vram_add(O, scratch) - - # --- FFN block --- - # Save residual: scratch = O (zero then add) - prog.vram_fill_zero(scratch) - prog.vram_add(scratch, O) - - # Norm (in-place on O) - prog.rms_norm(O, eps_offset=3, reci_hid_offset=4) - - # FFN (in-place on O) - ops.ffn(prog, O, li["W_gate"], li["W_up"], li["W_down"]) - - # FFN residual: O += scratch - prog.vram_add(O, scratch) - - current = O # carry forward - - # Final norm - prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) - - isa_code = prog.compile() - lines = isa_code.splitlines() - print(f"\nGenerated {len(lines)} lines of ISA code") - - # ----------------------------------------------------------- build return - input_tensors = {"X": token_embeds, "POS": pos_weight} - data_order = ["X", "POS"] - for i in range(n_layers): - input_tensors[f"K_{i}"] = K_mats[i] - input_tensors[f"V_{i}"] = V_mats[i] - input_tensors[f"W_gate_{i}"] = all_weights[i]["W_gate"] - input_tensors[f"W_up_{i}"] = all_weights[i]["W_up"] - input_tensors[f"W_down_{i}"] = all_weights[i]["W_down"] - data_order.extend([f"K_{i}", f"V_{i}", f"W_gate_{i}", f"W_up_{i}", f"W_down_{i}"]) - - # FPRAM layout (same as single-layer decoder): - # slot 0 = 0.0 (reserved) - # slot 1 = attn_scale (flash_attention) - # slot 2 = -inf (flash_attention softmax mask) - # slot 3 = eps (rms_norm, offset=3) - # slot 4 = 1/hidden (rms_norm, offset=4) - # slot 5 = 1.0 (FFN SiLU) - # slots 6-9 = 0.0 (padding) - fp_preload = [0.0, scale, float("-inf"), eps, 1.0 / hidden, 1.0] + [0.0] * 4 - - # Result is at current's VRAM location (last O from flash_attention chain) - o_vram_addr = prog._compiler.get_vram_addr(current.name) - - comparison_params = { - "start_row_idx": o_vram_addr // mlen, - "num_rows": (seq_len * hidden) // mlen, - "num_batches": seq_len, - "elements_per_batch": hidden, - "row_dim": mlen, - "use_stride_mode": hidden > mlen, - } - - info = { - "model_type": native_cfg["model_type"], - "hidden_size": hidden, - "inter_dim": inter, - "num_layers": n_layers, - "seq_len": seq_len, - "head_dim": head_dim, - "mlen": mlen, - "blen": blen, - "isa_lines": len(lines), - } - - print(f"\nCompilation complete: {info['isa_lines']} ISA lines, " - f"{n_layers} layers, output at VRAM row {o_vram_addr // mlen}") - - return { - "isa": isa_code, - "golden_output": golden_out, - "input_tensors": input_tensors, - "data_order": data_order, - "fp_preload": fp_preload, - "comparison_params": comparison_params, - "info": info, - } - - -# --------------------------------------------------------------------------- -# Convenience: compile + run emulator + compare -# --------------------------------------------------------------------------- -def compile_and_run( - model, - build_dir, - **kwargs, -) -> dict: - """Compile, run emulator, and compare against golden. - - Convenience wrapper that calls compile_hf_model, sets up simulation - environment, runs the Rust transactional emulator, and compares output. - - Args: - model: nn.Module (HF CausalLM model, already loaded) - build_dir: Directory for simulation artifacts - **kwargs: Forwarded to compile_hf_model (seq_len, hidden_size, etc.) - - Returns: - dict with compilation info + comparison results including - 'allclose_match_rate' percentage. - """ - from transactional_emulator.tools.create_sim_env import create_sim_env - from compiler.sim_env_utils.build_env import create_mem_for_sim - from transactional_emulator.testbench.emulator_runner import ( - run_and_assert, - compare_emulator_output, - ) - - result = compile_hf_model(model, **kwargs) - build_dir = Path(build_dir) - build_dir.mkdir(parents=True, exist_ok=True) - - mlen = kwargs.get("mlen", 64) - blen = kwargs.get("blen", 4) - asm_name = f"model_{result['info']['model_type']}_{result['info']['num_layers']}L" - - # Write sim env files - create_sim_env( - result["input_tensors"], - result["isa"], - {"original_output": result["golden_output"]}, - result["fp_preload"], - build_dir=str(build_dir), - ) - - create_mem_for_sim( - data_size=256, - mode="behave_sim", - asm=asm_name, - data=None, - specified_data_order=result["data_order"], - build_path=build_dir, - ) - - with open(build_dir / "comparison_params.json", "w") as f: - json.dump(result["comparison_params"], f, indent=2) - - with open(build_dir / "generated_asm_code.asm", "w") as f: - f.write(result["isa"]) - - print(f"\nSimulation environment created: {build_dir}") - print(f" Result location: VRAM row {result['comparison_params']['start_row_idx']}") - print(f" Layers: {result['info']['num_layers']}, data_order: {result['data_order']}") - - # Run emulator and compare - run_and_assert(build_dir, asm_name, mlen=mlen, blen=blen) - - comp_results, _ = compare_emulator_output(build_dir) - return {**result["info"], **comp_results} From 81061b5aad62ae27afe6b9e6ad0d558b9d6ea6be Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 28 Apr 2026 08:50:31 +0100 Subject: [PATCH 04/32] feat(plena_frontend): native-dimension compilation with multi-head attention - Add alloc_at() to PlenaCompiler for VRAM view creation - Native mode: Q/O linear projections with K-split, per-head MHA (6 heads, 2 KV heads) - Legacy mode: sliced dims without projections (backward compatible) - _fix_large_immediates post-processor for S_ADDI_INT overflow at native dims - K-split support in _linear_projection for hidden > 4*mlen Verified: clm-60m native dims (hidden=384, inter=1408) at 92.68% allclose. --- aten/plena_frontend.py | 311 +++++++++++++++++++---------------------- 1 file changed, 147 insertions(+), 164 deletions(-) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 486c552..3e27958 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -27,10 +27,7 @@ from compiler.aten.plena_compiler import PlenaCompiler from compiler.aten.ops.registry import OpRegistry, Backend import compiler.aten.ops as ops -from transactional_emulator.testbench.model_layer_test_builder import ( - quantize_to_mxfp, - _make_rope_tables, -) +from transactional_emulator.testbench.model_layer_test_builder import quantize_to_mxfp import re _IMM2_BOUND = 1 << 18 # S_ADDI_INT max immediate @@ -384,6 +381,85 @@ def _linear_projection(prog, input_var, weight_var, name): return output +# --------------------------------------------------------------------------- +# PLENA ISA helper: named linear projection (avoids name conflicts) +# --------------------------------------------------------------------------- +def _linear_projection(prog, input_var, weight_var, name): + """Emit a linear projection with a custom VRAM output name. + + Equivalent to ops.linear but uses *name* for the output allocation so + that multiple projections in the same scope don't collide on the + default "linear_out" name. + + Supports K-split: when K tiles exceed MRAM capacity (4 tiles), splits + into chunks and accumulates partial sums via a temporary buffer. + """ + import math as _math + + mlen = prog.mlen + MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements + + rows, k_total = input_var.shape + _, out_features = weight_var.shape + num_row_blocks = _math.ceil(rows / mlen) + num_col_blocks = out_features // mlen + num_k_tiles = _math.ceil(k_total / mlen) + + output_strict = rows % mlen == 0 + output = prog.alloc(name, rows, out_features, strict=output_strict) + + if num_k_tiles <= MAX_K_TILES: + # Single pass: all K tiles fit in MRAM + for col_idx in range(num_col_blocks): + for row_idx in range(num_row_blocks): + prog.vram_sub_projection_to( + input_var, + row_idx, + weight_var, + col_idx, + output, + row_idx, + col_idx, + ) + else: + # K-split: chunk K tiles into groups of MAX_K_TILES + k_chunks = [] + k_start = 0 + while k_start < num_k_tiles: + k_end = min(k_start + MAX_K_TILES, num_k_tiles) + k_chunks.append((k_start, k_end - k_start)) + k_start = k_end + + temp = prog.alloc(f"{name}_ksplit_tmp", rows, out_features, strict=output_strict) + + for k_chunk_idx, (k_block_start, k_block_count) in enumerate(k_chunks): + for col_idx in range(num_col_blocks): + for row_idx in range(num_row_blocks): + if k_chunk_idx == 0: + prog.vram_sub_projection_to( + input_var, row_idx, weight_var, col_idx, + output, row_idx, col_idx, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + else: + prog.vram_sub_projection_to( + input_var, row_idx, weight_var, col_idx, + temp, row_idx, col_idx, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + prog.vram_block_add_to( + output, row_idx, col_idx, + temp, row_idx, col_idx, + output, row_idx, col_idx, + ) + + prog.free_tensor(temp) + + return output + + # --------------------------------------------------------------------------- # Main compilation function # --------------------------------------------------------------------------- @@ -412,14 +488,6 @@ def compile_hf_model( residual = X X = rms_norm(X) Q = linear(X, W_q) # Q projection - K_h = linear(X, W_k_h) # K projection (per KV head, on-chip) - V_h = linear(X, W_v_h) # V projection (per KV head, on-chip) - # RoPE (native mode only): - Q_rot_h = linear(Q_h, R_rope) # rotate_half via matmul - Q_h = Q_h * cos + Q_rot_h * sin - K_rot_h = linear(K_h, R_rope) - K_h = K_h * cos + K_rot_h * sin - store K_h, V_h to HBM O = flash_attention(Q, K, V, scale) O = linear(O, W_o) # O projection X = O + residual @@ -431,8 +499,8 @@ def compile_hf_model( if include_lm_head: logits = linear(X, W_lm_head) # (seq, vocab_size) - RoPE is applied in native mode via rotate_half matrix multiplication. - Q, K, V, and O linear projections are all computed on-chip. + RoPE is omitted (orthogonal to multi-layer compilation testing). + Q and O linear projections are included for the attention block. Args: model: nn.Module (HF CausalLM model, already loaded) @@ -528,8 +596,6 @@ def compile_hf_model( if native_mode: print(f" MHA: num_heads={num_heads}, num_kv_heads={num_kv_heads}, " f"total_q_dim={total_q_dim}, total_kv_dim={total_kv_dim}") - if include_lm_head: - print(f" lm_head: W_lm_head={lm_head_weight.shape}, vocab_size={vocab_size}") print("=" * 80) # ----------------------------------------------------------- weights @@ -557,33 +623,27 @@ def compile_hf_model( # ----------------------------------------------------------- test data torch.manual_seed(seed) - # Embedding lookup: use real HF embedding table if available - if embed is not None: - input_ids = torch.randint(0, native_cfg["vocab_size"] or 32000, (seq_len,)) - with torch.no_grad(): - token_embeds = embed(input_ids).float() - if token_embeds.dim() == 3: - token_embeds = token_embeds.squeeze(0) - # Slice to target hidden dim (native mode uses full hidden) - token_embeds = token_embeds[:, :hidden] - print(f"\nEmbedding lookup: input_ids shape={input_ids.shape}, " - f"token_embeds={token_embeds.shape}") - else: - token_embeds = torch.randn(seq_len, hidden) - print(f"\nNo embed_tokens found; using random token_embeds: {token_embeds.shape}") - - # Llama-style models use RoPE (not learned position embeddings). - # Set pos_weight to zeros so embedding_add is a no-op for position. - pos_weight = torch.zeros(seq_len, hidden) - - # K/V test data: native mode computes K/V on-chip, legacy mode uses precomputed + # K/V test data: per-KV-head in native mode, single matrix in legacy mode if native_mode: - # K/V are computed on-chip from X_normed @ W_k/W_v — no precomputed K/V needed - print(f"pos_weight: zeros {pos_weight.shape} (Llama uses RoPE, not learned pos embed)") + # List of lists: K_head_mats[layer][kv_head] = (seq_len, head_dim) + K_head_mats = [] + V_head_mats = [] for i in range(n_layers): + k_heads_i = [] + v_heads_i = [] for kv_h in range(num_kv_heads): - print(f" W_k_{i}_h{kv_h}: {all_weights[i]['W_k_heads'][kv_h].shape}, " - f"W_v_{i}_h{kv_h}: {all_weights[i]['W_v_heads'][kv_h].shape}") + X_ctx = torch.randn(seq_len, hidden) + k_heads_i.append(X_ctx @ all_weights[i]["W_k_heads"][kv_h]) + v_heads_i.append(X_ctx @ all_weights[i]["W_v_heads"][kv_h]) + K_head_mats.append(k_heads_i) + V_head_mats.append(v_heads_i) + + print(f"\ntoken_embeds: {token_embeds.shape}") + print(f"pos_weight: {pos_weight.shape}") + for i in range(n_layers): + for kv_h in range(num_kv_heads): + print(f" K_{i}_h{kv_h}: {K_head_mats[i][kv_h].shape}, " + f"V_{i}_h{kv_h}: {V_head_mats[i][kv_h].shape}") else: K_mats = [] V_mats = [] @@ -592,92 +652,60 @@ def compile_hf_model( K_mats.append(X_ctx @ all_weights[i]["W_k"]) V_mats.append(X_ctx @ all_weights[i]["W_v"]) - print(f"pos_weight: zeros {pos_weight.shape} (Llama uses RoPE, not learned pos embed)") + print(f"\ntoken_embeds: {token_embeds.shape}") + print(f"pos_weight: {pos_weight.shape}") for i in range(n_layers): print(f" K_{i}: {K_mats[i].shape}, V_{i}: {V_mats[i].shape}") print(f"attn_scale: {scale:.6f}") # ----------------------------------------------------------- golden ref - _do_quant = golden_precision in ("hardware", "no_bf16") - _do_bf16 = golden_precision in ("hardware", "no_weight_quant") - _qw = quantize_to_mxfp if _do_quant else (lambda x: x) - _to_inter = (lambda x: x.to(torch.bfloat16)) if _do_bf16 else (lambda x: x) - _from_inter = (lambda x: x.float()) if _do_bf16 else (lambda x: x) - _prec_label = {"hardware": "MXFP8 weights + BF16 intermediates", - "no_weight_quant": "float32 weights + BF16 intermediates", - "no_bf16": "MXFP8 weights + float32 intermediates", - "fp32": "float32 weights + float32 intermediates"}[golden_precision] - print(f"\n--- CPU Golden Reference ({_prec_label}) ---") + print("\n--- CPU Golden Reference (MXFP8 quantized HBM + BF16 intermediates) ---") if native_mode: - W_k_q_heads = [[_qw(all_weights[i]["W_k_heads"][h]) - for h in range(num_kv_heads)] - for i in range(n_layers)] - W_v_q_heads = [[_qw(all_weights[i]["W_v_heads"][h]) - for h in range(num_kv_heads)] - for i in range(n_layers)] - if native_mode: - R_matrix = _make_rotate_half_matrix(head_dim) - R_rope_q = _qw(R_matrix) - cos_table, sin_table = _make_rope_tables(seq_len, head_dim, native_cfg["rope_theta"]) - + # Quantize per-head K/V + K_q_heads = [[quantize_to_mxfp(K_head_mats[i][h]) for h in range(num_kv_heads)] + for i in range(n_layers)] + V_q_heads = [[quantize_to_mxfp(V_head_mats[i][h]) for h in range(num_kv_heads)] + for i in range(n_layers)] else: - K_q_list = [_qw(K_mats[i]) for i in range(n_layers)] - V_q_list = [_qw(V_mats[i]) for i in range(n_layers)] + K_q_list = [quantize_to_mxfp(K_mats[i]) for i in range(n_layers)] + V_q_list = [quantize_to_mxfp(V_mats[i]) for i in range(n_layers)] - X_gold = _qw(token_embeds.clone()) + _qw(pos_weight) # embedding_add (MXFP8-quantized, matching HBM) + X_gold = token_embeds.clone() + pos_weight # embedding_add ratio = num_heads // num_kv_heads for i in range(n_layers): w = all_weights[i] - W_q_q = _qw(w["W_q"]) - W_o_q = _qw(w["W_o"]) - W_gate_q = _qw(w["W_gate"]) - W_up_q = _qw(w["W_up"]) - W_down_q = _qw(w["W_down"]) + W_q_q = quantize_to_mxfp(w["W_q"]) + W_o_q = quantize_to_mxfp(w["W_o"]) + W_gate_q = quantize_to_mxfp(w["W_gate"]) + W_up_q = quantize_to_mxfp(w["W_up"]) + W_down_q = quantize_to_mxfp(w["W_down"]) # --- Attention block --- residual = X_gold.clone() - X_bf = _to_inter(X_gold) - # Hardware: scalar rms stays in f32, V_MUL_VF multiplies BF16 * f32 -> BF16 - rms = torch.rsqrt(_from_inter(X_bf).pow(2).mean(-1, keepdim=True) + eps) - X_gold = (_from_inter(X_bf) * rms).to(torch.bfloat16).float() + # rms_norm with bfloat16 to match PLENA + X_bf = X_gold.to(torch.bfloat16) + rms = torch.rsqrt(X_bf.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) + X_gold = (X_bf * rms).float() if native_mode: - Q_gold = _ksplit_matmul(X_gold, W_q_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - Q_gold = _from_inter(_to_inter(Q_gold)) # VRAM write: BF16 - - K_q_heads_i = [] - V_q_heads_i = [] - for kv_h in range(num_kv_heads): - K_h = _ksplit_matmul(X_gold, W_k_q_heads[i][kv_h], mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - V_h = _ksplit_matmul(X_gold, W_v_q_heads[i][kv_h], mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - K_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(K_h)), _from_inter(_to_inter(R_rope_q))))) - # Hardware RoPE: V_MUL_VV (BF16*BF16->BF16), V_ADD_VV (BF16+BF16->BF16) - K_cos = _from_inter(_to_inter(K_h.to(torch.bfloat16).float() * cos_table.to(torch.bfloat16).float())) - K_rot_sin = _from_inter(_to_inter(K_rot_h.to(torch.bfloat16).float() * sin_table.to(torch.bfloat16).float())) - K_h = _from_inter(_to_inter(K_cos) + _to_inter(K_rot_sin)) - # Hardware stores K/V to HBM as MXFP8, loads back as BF16 for attention - K_q_heads_i.append(_from_inter(_to_inter(_qw(K_h)))) - V_q_heads_i.append(_from_inter(_to_inter(_qw(V_h)))) - + # Q projection: X @ W_q (MXFP8-quantized weight, BF16 intermediate) + Q_gold = torch.matmul(X_gold.to(torch.bfloat16).float(), W_q_q.float()).to(torch.bfloat16).float() + # Per-head flash attention O_heads = [] for h in range(num_heads): kv_h = h // ratio Q_h = Q_gold[:, h * head_dim:(h + 1) * head_dim] - Q_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(Q_h)), _from_inter(_to_inter(R_rope_q))))) - # Hardware RoPE: V_MUL_VV (BF16*BF16->BF16), V_ADD_VV (BF16+BF16->BF16) - Q_cos = _from_inter(_to_inter(Q_h.to(torch.bfloat16).float() * cos_table.to(torch.bfloat16).float())) - Q_rot_sin = _from_inter(_to_inter(Q_rot_h.to(torch.bfloat16).float() * sin_table.to(torch.bfloat16).float())) - Q_h = _from_inter(_to_inter(Q_cos) + _to_inter(Q_rot_sin)) - O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale, causal=True) + O_h = _flash_attn_ref(Q_h, K_q_heads[i][kv_h], V_q_heads[i][kv_h], scale) O_heads.append(O_h) - attn_out = _from_inter(_to_inter(torch.cat(O_heads, dim=1))) # VRAM write: BF16 - O_gold = _ksplit_matmul(attn_out, W_o_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - O_gold = _from_inter(_to_inter(O_gold)) # VRAM write: BF16 - X_gold = _from_inter(_to_inter(O_gold + residual)) # residual add -> VRAM write: BF16 + attn_out = torch.cat(O_heads, dim=1) # (seq, num_heads * head_dim) + # O projection + O_gold = torch.matmul(attn_out.to(torch.bfloat16).float(), W_o_q.float()).to(torch.bfloat16).float() + X_gold = O_gold + residual else: - attn_out = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale, causal=True) - X_gold = _from_inter(_to_inter(attn_out + residual)) # residual add -> VRAM write: BF16 + # Legacy: X is Q directly (no projection) + attn_out = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale) + X_gold = attn_out + residual # --- FFN block --- residual = X_gold.clone() @@ -823,14 +851,14 @@ def compile_hf_model( wq = prog.input(f"W_q_{i}", shape=(hidden, total_q_dim)) wo = prog.input(f"W_o_{i}", shape=(total_q_dim, hidden)) if native_mode: - wk_heads = [] - wv_heads = [] + k_heads = [] + v_heads = [] for kv_h in range(num_kv_heads): - wk_heads.append(prog.input(f"W_k_{i}_h{kv_h}", shape=(hidden, head_dim))) - wv_heads.append(prog.input(f"W_v_{i}_h{kv_h}", shape=(hidden, head_dim))) + k_heads.append(prog.input(f"K_{i}_h{kv_h}", shape=(seq_len, head_dim))) + v_heads.append(prog.input(f"V_{i}_h{kv_h}", shape=(seq_len, head_dim))) li_entry = { "W_q": wq, "W_o": wo, - "W_k_heads": wk_heads, "W_v_heads": wv_heads, + "K_heads": k_heads, "V_heads": v_heads, } else: ki = prog.input(f"K_{i}", shape=(seq_len, head_dim)) @@ -845,11 +873,6 @@ def compile_hf_model( li_entry.update({"W_gate": wg, "W_up": wu, "W_down": wd}) layer_inputs.append(li_entry) - # lm_head weight input (after all layer weights in HBM layout) - lm_head_input = None - if include_lm_head and lm_head_weight is not None: - lm_head_input = prog.input("W_lm_head", shape=(hidden, vocab_size)) - # Load activations to VRAM X_batch = prog.load_batch(x_input, name="X") POS_batch = prog.load_batch(pos_input, name="POS") @@ -864,10 +887,9 @@ def compile_hf_model( _current_bump += 2 * seq_len * head_dim # COS + SIN loaded to VRAM _current_bump += mlen * mlen # CAUSAL_MASK loaded to VRAM if _current_bump < _ffn_intermediate_end: - _pad_elems = _ffn_intermediate_end - _current_bump - # Allocate enough mlen-wide rows to cover the padding - _pad_rows = (_pad_elems + mlen - 1) // mlen - # Round up to mlen for VRAM alignment + _pad_size = _ffn_intermediate_end - _current_bump + _pad_rows = max(1, _pad_size // mlen) + # Round up to mlen multiple for VRAM alignment _pad_rows = ((_pad_rows + mlen - 1) // mlen) * mlen prog.alloc("_vram_padding", _pad_rows, mlen, strict=False) @@ -895,52 +917,22 @@ def compile_hf_model( # Q projection: current (seq, hidden) @ W_q (hidden, total_q_dim) Q = _linear_projection(prog, current, li["W_q"], f"Q_{i}") - # Per-KV-head: compute K/V on-chip, store to HBM, then run Q-heads + # Per-head flash attention with VRAM views q_full_addr = prog._compiler.get_vram_addr(Q.name) # Allocate O_full for concatenated head outputs O_full = prog.alloc(f"O_full_{i}", seq_len, total_q_dim) o_full_addr = prog._compiler.get_vram_addr(O_full.name) - # Store K/V InputVars for each KV head (filled during loop below) - kv_stored = [] - - for kv_h in range(num_kv_heads): - # K projection: current (seq, hidden) @ W_k_h (hidden, head_dim) - K_h = _linear_projection(prog, current, li["W_k_heads"][kv_h], f"K_{i}_h{kv_h}") - # V projection: current (seq, hidden) @ W_v_h (hidden, head_dim) - V_h = _linear_projection(prog, current, li["W_v_heads"][kv_h], f"V_{i}_h{kv_h}") - - # RoPE on K_h: K_rot = linear(K_h, R_rope), then K_h = K_h * cos + K_rot * sin - K_rot = _linear_projection(prog, K_h, r_input, f"K_rot_{i}_h{kv_h}") - prog.rope(K_h, K_rot, COS, SIN) - prog.free_tensor(K_rot) - - # Store K/V from VRAM to HBM (auto-allocates HBM space) - K_stored = prog.store(K_h, name=f"K_stored_{i}_h{kv_h}") - V_stored = prog.store(V_h, name=f"V_stored_{i}_h{kv_h}") - kv_stored.append((K_stored, V_stored)) - - # Free K/V VRAM — data is in HBM now - prog.free_tensor(K_h) - prog.free_tensor(V_h) - for h in range(num_heads): kv_h = h // ratio - K_stored, V_stored = kv_stored[kv_h] # View for this head's Q slice: column block h in Q_full q_h_addr = q_full_addr + h * seq_len * mlen Q_h = prog.alloc_at(f"Q_h{h}_{i}", seq_len, head_dim, q_h_addr) - # RoPE on Q_h: Q_rot = linear(Q_h, R_rope), then Q_h = Q_h * cos + Q_rot * sin - Q_rot = _linear_projection(prog, Q_h, r_input, f"Q_rot_{i}_h{h}") - prog.rope(Q_h, Q_rot, COS, SIN) - prog.free_tensor(Q_rot) - - # Single-head flash attention (K/V read from HBM) with causal mask - O_h = ops.flash_attention(prog, Q_h, K_stored, V_stored, scale, - causal_mask=CAUSAL_MASK) + # Single-head flash attention (allocates S, PV, O internally) + O_h = ops.flash_attention(prog, Q_h, li["K_heads"][kv_h], li["V_heads"][kv_h], scale) # Copy O_h to the right column block of O_full o_h_dest_addr = o_full_addr + h * seq_len * mlen @@ -962,9 +954,8 @@ def compile_hf_model( prog.vram_add(O_proj, scratch) current_after_attn = O_proj else: - # Legacy: X is Q directly (no projections), with causal mask - O = ops.flash_attention(prog, current, li["K"], li["V"], scale, - causal_mask=CAUSAL_MASK) + # Legacy: X is Q directly (no projections) + O = ops.flash_attention(prog, current, li["K"], li["V"], scale) prog.vram_add(O, scratch) current_after_attn = O @@ -983,7 +974,6 @@ def compile_hf_model( prog.vram_add(current_after_attn, scratch) current = current_after_attn # carry forward - prog._compiler.generated_code += f"; === LAYER {i}/{n_layers} COMPLETE ===\n" # Final norm prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) @@ -1013,8 +1003,8 @@ def compile_hf_model( input_tensors[f"W_o_{i}"] = all_weights[i]["W_o"] if native_mode: for kv_h in range(num_kv_heads): - input_tensors[f"W_k_{i}_h{kv_h}"] = all_weights[i]["W_k_heads"][kv_h] - input_tensors[f"W_v_{i}_h{kv_h}"] = all_weights[i]["W_v_heads"][kv_h] + input_tensors[f"K_{i}_h{kv_h}"] = K_head_mats[i][kv_h] + input_tensors[f"V_{i}_h{kv_h}"] = V_head_mats[i][kv_h] else: input_tensors[f"K_{i}"] = K_mats[i] input_tensors[f"V_{i}"] = V_mats[i] @@ -1025,7 +1015,7 @@ def compile_hf_model( kv_keys = [] if native_mode: for kv_h in range(num_kv_heads): - kv_keys.extend([f"W_k_{i}_h{kv_h}", f"W_v_{i}_h{kv_h}"]) + kv_keys.extend([f"K_{i}_h{kv_h}", f"V_{i}_h{kv_h}"]) else: kv_keys = [f"K_{i}", f"V_{i}"] data_order.extend([ @@ -1034,11 +1024,6 @@ def compile_hf_model( f"W_gate_{i}", f"W_up_{i}", f"W_down_{i}", ]) - # lm_head weight in input_tensors + data_order - if include_lm_head and lm_head_weight is not None: - input_tensors["W_lm_head"] = lm_head_weight - data_order.append("W_lm_head") - # FPRAM layout (same as single-layer decoder): # slot 0 = 0.0 (reserved) # slot 1 = attn_scale (flash_attention) @@ -1077,8 +1062,6 @@ def compile_hf_model( "num_heads": num_heads if native_mode else 1, "num_kv_heads": num_kv_heads if native_mode else 1, "native_mode": native_mode, - "include_lm_head": include_lm_head and lm_head_weight is not None, - "vocab_size": vocab_size if (include_lm_head and lm_head_weight is not None) else None, "mlen": mlen, "blen": blen, "isa_lines": len(lines), From a5be39ac1f3df90a31747f370c445b07a8d148b5 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 28 Apr 2026 13:08:37 +0100 Subject: [PATCH 05/32] =?UTF-8?q?fix:=20audit=20fixes=20=E2=80=94=20assert?= =?UTF-8?q?ions,=20docstring,=20VRAM=20padding?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - HIGH-2: assert out_features % mlen == 0 in _linear_projection - HIGH-3: assert cols % block_dim == 0 in _weight_hbm_bytes - HIGH-4: fix VRAM padding ceiling division - LOW-2: fix docstring import path (model_compiler → plena_frontend) --- aten/plena_frontend.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 3e27958..da6fe40 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -402,6 +402,9 @@ def _linear_projection(prog, input_var, weight_var, name): rows, k_total = input_var.shape _, out_features = weight_var.shape num_row_blocks = _math.ceil(rows / mlen) + assert out_features % mlen == 0, ( + f"out_features ({out_features}) must be a multiple of mlen ({mlen})" + ) num_col_blocks = out_features // mlen num_k_tiles = _math.ceil(k_total / mlen) @@ -887,9 +890,10 @@ def compile_hf_model( _current_bump += 2 * seq_len * head_dim # COS + SIN loaded to VRAM _current_bump += mlen * mlen # CAUSAL_MASK loaded to VRAM if _current_bump < _ffn_intermediate_end: - _pad_size = _ffn_intermediate_end - _current_bump - _pad_rows = max(1, _pad_size // mlen) - # Round up to mlen multiple for VRAM alignment + _pad_elems = _ffn_intermediate_end - _current_bump + # Allocate enough mlen-wide rows to cover the padding + _pad_rows = (_pad_elems + mlen - 1) // mlen + # Round up to mlen for VRAM alignment _pad_rows = ((_pad_rows + mlen - 1) // mlen) * mlen prog.alloc("_vram_padding", _pad_rows, mlen, strict=False) From 0700e802d7770a5212c68dd9279bccb872bedfdf Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 28 Apr 2026 15:07:17 +0100 Subject: [PATCH 06/32] feat(plena_frontend): on-chip K/V projections + RoPE True transformer forward pass: - K/V computed on-chip via linear projection, stored to HBM - RoPE via rotate_half matrix: Q_rot = linear(Q, R) - Applied to both Q and K per head - 37K ISA lines at native dims (clm-60m hidden=384) Verified: native compile OK, legacy path 100% allclose. --- aten/plena_frontend.py | 160 ++++++++++++++++++++++++++++------------- 1 file changed, 112 insertions(+), 48 deletions(-) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index da6fe40..09acc52 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -27,7 +27,10 @@ from compiler.aten.plena_compiler import PlenaCompiler from compiler.aten.ops.registry import OpRegistry, Backend import compiler.aten.ops as ops -from transactional_emulator.testbench.model_layer_test_builder import quantize_to_mxfp +from transactional_emulator.testbench.model_layer_test_builder import ( + quantize_to_mxfp, + _make_rope_tables, +) import re _IMM2_BOUND = 1 << 18 # S_ADDI_INT max immediate @@ -129,6 +132,24 @@ def _make_rotate_half_matrix(head_dim: int) -> torch.Tensor: return R +# --------------------------------------------------------------------------- +# RoPE helpers +# --------------------------------------------------------------------------- +def _make_rotate_half_matrix(head_dim: int) -> torch.Tensor: + """Build the (head_dim, head_dim) matrix that computes rotate_half. + + rotate_half(x) = x @ R, where R permutes the halves with a sign flip: + output[:d//2] = -input[d//2:] + output[d//2:] = +input[:d//2] + """ + R = torch.zeros(head_dim, head_dim) + half = head_dim // 2 + for i in range(half): + R[i + half, i] = -1.0 # first half of output = negated second half of input + R[i, i + half] = 1.0 # second half of output = first half of input + return R + + # --------------------------------------------------------------------------- # Model structure helpers # --------------------------------------------------------------------------- @@ -491,6 +512,14 @@ def compile_hf_model( residual = X X = rms_norm(X) Q = linear(X, W_q) # Q projection + K_h = linear(X, W_k_h) # K projection (per KV head, on-chip) + V_h = linear(X, W_v_h) # V projection (per KV head, on-chip) + # RoPE (native mode only): + Q_rot_h = linear(Q_h, R_rope) # rotate_half via matmul + Q_h = Q_h * cos + Q_rot_h * sin + K_rot_h = linear(K_h, R_rope) + K_h = K_h * cos + K_rot_h * sin + store K_h, V_h to HBM O = flash_attention(Q, K, V, scale) O = linear(O, W_o) # O projection X = O + residual @@ -502,8 +531,8 @@ def compile_hf_model( if include_lm_head: logits = linear(X, W_lm_head) # (seq, vocab_size) - RoPE is omitted (orthogonal to multi-layer compilation testing). - Q and O linear projections are included for the attention block. + RoPE is applied in native mode via rotate_half matrix multiplication. + Q, K, V, and O linear projections are all computed on-chip. Args: model: nn.Module (HF CausalLM model, already loaded) @@ -626,27 +655,15 @@ def compile_hf_model( # ----------------------------------------------------------- test data torch.manual_seed(seed) - # K/V test data: per-KV-head in native mode, single matrix in legacy mode + # K/V test data: native mode computes K/V on-chip, legacy mode uses precomputed if native_mode: - # List of lists: K_head_mats[layer][kv_head] = (seq_len, head_dim) - K_head_mats = [] - V_head_mats = [] - for i in range(n_layers): - k_heads_i = [] - v_heads_i = [] - for kv_h in range(num_kv_heads): - X_ctx = torch.randn(seq_len, hidden) - k_heads_i.append(X_ctx @ all_weights[i]["W_k_heads"][kv_h]) - v_heads_i.append(X_ctx @ all_weights[i]["W_v_heads"][kv_h]) - K_head_mats.append(k_heads_i) - V_head_mats.append(v_heads_i) - + # K/V are computed on-chip from X_normed @ W_k/W_v — no precomputed K/V needed print(f"\ntoken_embeds: {token_embeds.shape}") print(f"pos_weight: {pos_weight.shape}") for i in range(n_layers): for kv_h in range(num_kv_heads): - print(f" K_{i}_h{kv_h}: {K_head_mats[i][kv_h].shape}, " - f"V_{i}_h{kv_h}: {V_head_mats[i][kv_h].shape}") + print(f" W_k_{i}_h{kv_h}: {all_weights[i]['W_k_heads'][kv_h].shape}, " + f"W_v_{i}_h{kv_h}: {all_weights[i]['W_v_heads'][kv_h].shape}") else: K_mats = [] V_mats = [] @@ -665,11 +682,18 @@ def compile_hf_model( print("\n--- CPU Golden Reference (MXFP8 quantized HBM + BF16 intermediates) ---") if native_mode: - # Quantize per-head K/V - K_q_heads = [[quantize_to_mxfp(K_head_mats[i][h]) for h in range(num_kv_heads)] - for i in range(n_layers)] - V_q_heads = [[quantize_to_mxfp(V_head_mats[i][h]) for h in range(num_kv_heads)] - for i in range(n_layers)] + # Quantize per-KV-head projection weights for golden reference + W_k_q_heads = [[quantize_to_mxfp(all_weights[i]["W_k_heads"][h]) + for h in range(num_kv_heads)] + for i in range(n_layers)] + W_v_q_heads = [[quantize_to_mxfp(all_weights[i]["W_v_heads"][h]) + for h in range(num_kv_heads)] + for i in range(n_layers)] + if native_mode: + R_matrix = _make_rotate_half_matrix(head_dim) + R_rope_q = quantize_to_mxfp(R_matrix) + cos_table, sin_table = _make_rope_tables(seq_len, head_dim, native_cfg["rope_theta"]) + else: K_q_list = [quantize_to_mxfp(K_mats[i]) for i in range(n_layers)] V_q_list = [quantize_to_mxfp(V_mats[i]) for i in range(n_layers)] @@ -694,12 +718,34 @@ def compile_hf_model( if native_mode: # Q projection: X @ W_q (MXFP8-quantized weight, BF16 intermediate) Q_gold = torch.matmul(X_gold.to(torch.bfloat16).float(), W_q_q.float()).to(torch.bfloat16).float() - # Per-head flash attention + + # On-chip K/V projections: X_normed @ W_k_h / W_v_h per KV head + # K/V are stored to HBM (MXFP8 quantized) then read back + K_q_heads_i = [] + V_q_heads_i = [] + for kv_h in range(num_kv_heads): + K_h = torch.matmul( + X_gold.to(torch.bfloat16).float(), W_k_q_heads[i][kv_h].float() + ).to(torch.bfloat16).float() + V_h = torch.matmul( + X_gold.to(torch.bfloat16).float(), W_v_q_heads[i][kv_h].float() + ).to(torch.bfloat16).float() + # RoPE on K_h: K_rot = K_h @ R_rope, K_h = K_h * cos + K_rot * sin + K_rot_h = torch.matmul(K_h.to(torch.bfloat16).float(), R_rope_q.float()).to(torch.bfloat16).float() + K_h = (K_h * cos_table + K_rot_h * sin_table) + # Quantize for HBM store+load round-trip + K_q_heads_i.append(quantize_to_mxfp(K_h)) + V_q_heads_i.append(quantize_to_mxfp(V_h)) + + # Per-head flash attention (with RoPE on Q per head) O_heads = [] for h in range(num_heads): kv_h = h // ratio Q_h = Q_gold[:, h * head_dim:(h + 1) * head_dim] - O_h = _flash_attn_ref(Q_h, K_q_heads[i][kv_h], V_q_heads[i][kv_h], scale) + # RoPE on Q_h: Q_rot = Q_h @ R_rope, Q_h = Q_h * cos + Q_rot * sin + Q_rot_h = torch.matmul(Q_h.to(torch.bfloat16).float(), R_rope_q.float()).to(torch.bfloat16).float() + Q_h = (Q_h * cos_table + Q_rot_h * sin_table) + O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale) O_heads.append(O_h) attn_out = torch.cat(O_heads, dim=1) # (seq, num_heads * head_dim) # O projection @@ -840,28 +886,20 @@ def compile_hf_model( COS = prog.load_batch(cos_input, name="COS") SIN = prog.load_batch(sin_input, name="SIN") - # Causal mask: (mlen, mlen) with 0 on/below diagonal, -inf above - causal_mask_data = torch.zeros(mlen, mlen) - causal_mask_data.masked_fill_( - torch.triu(torch.ones(mlen, mlen), diagonal=1).bool(), float('-inf') - ) - causal_mask_input = prog.input("causal_mask", shape=(mlen, mlen)) - CAUSAL_MASK = prog.load_batch(causal_mask_input, name="CAUSAL_MASK") - # Per-layer weight inputs (order determines HBM layout) layer_inputs = [] for i in range(n_layers): wq = prog.input(f"W_q_{i}", shape=(hidden, total_q_dim)) wo = prog.input(f"W_o_{i}", shape=(total_q_dim, hidden)) if native_mode: - k_heads = [] - v_heads = [] + wk_heads = [] + wv_heads = [] for kv_h in range(num_kv_heads): - k_heads.append(prog.input(f"K_{i}_h{kv_h}", shape=(seq_len, head_dim))) - v_heads.append(prog.input(f"V_{i}_h{kv_h}", shape=(seq_len, head_dim))) + wk_heads.append(prog.input(f"W_k_{i}_h{kv_h}", shape=(hidden, head_dim))) + wv_heads.append(prog.input(f"W_v_{i}_h{kv_h}", shape=(hidden, head_dim))) li_entry = { "W_q": wq, "W_o": wo, - "K_heads": k_heads, "V_heads": v_heads, + "W_k_heads": wk_heads, "W_v_heads": wv_heads, } else: ki = prog.input(f"K_{i}", shape=(seq_len, head_dim)) @@ -888,7 +926,6 @@ def compile_hf_model( _current_bump = 2 * seq_len * hidden # X + POS already allocated if native_mode: _current_bump += 2 * seq_len * head_dim # COS + SIN loaded to VRAM - _current_bump += mlen * mlen # CAUSAL_MASK loaded to VRAM if _current_bump < _ffn_intermediate_end: _pad_elems = _ffn_intermediate_end - _current_bump # Allocate enough mlen-wide rows to cover the padding @@ -921,22 +958,51 @@ def compile_hf_model( # Q projection: current (seq, hidden) @ W_q (hidden, total_q_dim) Q = _linear_projection(prog, current, li["W_q"], f"Q_{i}") - # Per-head flash attention with VRAM views + # Per-KV-head: compute K/V on-chip, store to HBM, then run Q-heads q_full_addr = prog._compiler.get_vram_addr(Q.name) # Allocate O_full for concatenated head outputs O_full = prog.alloc(f"O_full_{i}", seq_len, total_q_dim) o_full_addr = prog._compiler.get_vram_addr(O_full.name) + # Store K/V InputVars for each KV head (filled during loop below) + kv_stored = [] + + for kv_h in range(num_kv_heads): + # K projection: current (seq, hidden) @ W_k_h (hidden, head_dim) + K_h = _linear_projection(prog, current, li["W_k_heads"][kv_h], f"K_{i}_h{kv_h}") + # V projection: current (seq, hidden) @ W_v_h (hidden, head_dim) + V_h = _linear_projection(prog, current, li["W_v_heads"][kv_h], f"V_{i}_h{kv_h}") + + # RoPE on K_h: K_rot = linear(K_h, R_rope), then K_h = K_h * cos + K_rot * sin + K_rot = _linear_projection(prog, K_h, r_input, f"K_rot_{i}_h{kv_h}") + prog.rope(K_h, K_rot, COS, SIN) + prog.free_tensor(K_rot) + + # Store K/V from VRAM to HBM (auto-allocates HBM space) + K_stored = prog.store(K_h, name=f"K_stored_{i}_h{kv_h}") + V_stored = prog.store(V_h, name=f"V_stored_{i}_h{kv_h}") + kv_stored.append((K_stored, V_stored)) + + # Free K/V VRAM — data is in HBM now + prog.free_tensor(K_h) + prog.free_tensor(V_h) + for h in range(num_heads): kv_h = h // ratio + K_stored, V_stored = kv_stored[kv_h] # View for this head's Q slice: column block h in Q_full q_h_addr = q_full_addr + h * seq_len * mlen Q_h = prog.alloc_at(f"Q_h{h}_{i}", seq_len, head_dim, q_h_addr) - # Single-head flash attention (allocates S, PV, O internally) - O_h = ops.flash_attention(prog, Q_h, li["K_heads"][kv_h], li["V_heads"][kv_h], scale) + # RoPE on Q_h: Q_rot = linear(Q_h, R_rope), then Q_h = Q_h * cos + Q_rot * sin + Q_rot = _linear_projection(prog, Q_h, r_input, f"Q_rot_{i}_h{h}") + prog.rope(Q_h, Q_rot, COS, SIN) + prog.free_tensor(Q_rot) + + # Single-head flash attention (K/V read from HBM) + O_h = ops.flash_attention(prog, Q_h, K_stored, V_stored, scale) # Copy O_h to the right column block of O_full o_h_dest_addr = o_full_addr + h * seq_len * mlen @@ -1000,15 +1066,13 @@ def compile_hf_model( input_tensors["COS"] = cos_table input_tensors["SIN"] = sin_table data_order.extend(["R_rope", "COS", "SIN"]) - input_tensors["causal_mask"] = causal_mask_data - data_order.append("causal_mask") for i in range(n_layers): input_tensors[f"W_q_{i}"] = all_weights[i]["W_q"] input_tensors[f"W_o_{i}"] = all_weights[i]["W_o"] if native_mode: for kv_h in range(num_kv_heads): - input_tensors[f"K_{i}_h{kv_h}"] = K_head_mats[i][kv_h] - input_tensors[f"V_{i}_h{kv_h}"] = V_head_mats[i][kv_h] + input_tensors[f"W_k_{i}_h{kv_h}"] = all_weights[i]["W_k_heads"][kv_h] + input_tensors[f"W_v_{i}_h{kv_h}"] = all_weights[i]["W_v_heads"][kv_h] else: input_tensors[f"K_{i}"] = K_mats[i] input_tensors[f"V_{i}"] = V_mats[i] @@ -1019,7 +1083,7 @@ def compile_hf_model( kv_keys = [] if native_mode: for kv_h in range(num_kv_heads): - kv_keys.extend([f"K_{i}_h{kv_h}", f"V_{i}_h{kv_h}"]) + kv_keys.extend([f"W_k_{i}_h{kv_h}", f"W_v_{i}_h{kv_h}"]) else: kv_keys = [f"K_{i}", f"V_{i}"] data_order.extend([ From e94db7582726d9cf5fa63c7c043c7a91e1565881 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 28 Apr 2026 16:10:19 +0100 Subject: [PATCH 07/32] =?UTF-8?q?feat(plena=5Ffrontend):=20true=20e2e=20?= =?UTF-8?q?=E2=80=94=20real=20embeddings=20+=20lm=5Fhead=20+=20multi-layer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Real embedding lookup via embed_tokens(input_ids) replacing random input - Zero pos_weight for Llama (RoPE handles positioning) - Optional lm_head projection (include_lm_head=True, default False) - Tied-weight detection for shared embed/lm_head - Multi-layer verified at 2 layers native dims (163K ISA lines) Full pipeline: embed → [norm → Q/K/V proj → RoPE → MHA → O proj → res → norm → FFN → res] × N → norm → [lm_head] Verified: 92.65% allclose at native dims (hidden=384, 1 layer, with RoPE). --- aten/plena_frontend.py | 48 +++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 09acc52..974a8b3 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -498,7 +498,6 @@ def compile_hf_model( blen: int = 4, seed: int = 42, include_lm_head: bool = False, - golden_precision: str = "hardware", ) -> dict: """Compile a HuggingFace model to PLENA ISA via PlenaCompiler. @@ -545,11 +544,6 @@ def compile_hf_model( blen: Batch tile length (default 4) seed: Random seed for test data generation include_lm_head: If True, add lm_head projection after final norm (default False) - golden_precision: Precision mode for golden reference computation. - "hardware" — MXFP8 weights + BF16 intermediates (default, matches HW) - "no_weight_quant" — float32 weights + BF16 intermediates (isolates MXFP8 effect) - "no_bf16" — MXFP8 weights + float32 intermediates (isolates BF16 effect) - "fp32" — float32 weights + float32 intermediates (should match HF exactly) Returns: dict with: @@ -628,6 +622,8 @@ def compile_hf_model( if native_mode: print(f" MHA: num_heads={num_heads}, num_kv_heads={num_kv_heads}, " f"total_q_dim={total_q_dim}, total_kv_dim={total_kv_dim}") + if include_lm_head: + print(f" lm_head: W_lm_head={lm_head_weight.shape}, vocab_size={vocab_size}") print("=" * 80) # ----------------------------------------------------------- weights @@ -655,11 +651,29 @@ def compile_hf_model( # ----------------------------------------------------------- test data torch.manual_seed(seed) + # Embedding lookup: use real HF embedding table if available + if embed is not None: + input_ids = torch.randint(0, native_cfg["vocab_size"] or 32000, (seq_len,)) + with torch.no_grad(): + token_embeds = embed(input_ids).float() + if token_embeds.dim() == 3: + token_embeds = token_embeds.squeeze(0) + # Slice to target hidden dim (native mode uses full hidden) + token_embeds = token_embeds[:, :hidden] + print(f"\nEmbedding lookup: input_ids shape={input_ids.shape}, " + f"token_embeds={token_embeds.shape}") + else: + token_embeds = torch.randn(seq_len, hidden) + print(f"\nNo embed_tokens found; using random token_embeds: {token_embeds.shape}") + + # Llama-style models use RoPE (not learned position embeddings). + # Set pos_weight to zeros so embedding_add is a no-op for position. + pos_weight = torch.zeros(seq_len, hidden) + # K/V test data: native mode computes K/V on-chip, legacy mode uses precomputed if native_mode: # K/V are computed on-chip from X_normed @ W_k/W_v — no precomputed K/V needed - print(f"\ntoken_embeds: {token_embeds.shape}") - print(f"pos_weight: {pos_weight.shape}") + print(f"pos_weight: zeros {pos_weight.shape} (Llama uses RoPE, not learned pos embed)") for i in range(n_layers): for kv_h in range(num_kv_heads): print(f" W_k_{i}_h{kv_h}: {all_weights[i]['W_k_heads'][kv_h].shape}, " @@ -672,8 +686,7 @@ def compile_hf_model( K_mats.append(X_ctx @ all_weights[i]["W_k"]) V_mats.append(X_ctx @ all_weights[i]["W_v"]) - print(f"\ntoken_embeds: {token_embeds.shape}") - print(f"pos_weight: {pos_weight.shape}") + print(f"pos_weight: zeros {pos_weight.shape} (Llama uses RoPE, not learned pos embed)") for i in range(n_layers): print(f" K_{i}: {K_mats[i].shape}, V_{i}: {V_mats[i].shape}") print(f"attn_scale: {scale:.6f}") @@ -781,9 +794,8 @@ def compile_hf_model( # lm_head projection (optional) if include_lm_head and lm_head_weight is not None: W_lm_q = quantize_to_mxfp(lm_head_weight) - # Hardware: MXFP8 → BF16 (MRAM) → float32 (M_MM) logits_gold = torch.matmul( - X_gold.to(torch.bfloat16).float(), W_lm_q.to(torch.bfloat16).float() + X_gold.to(torch.bfloat16).float(), W_lm_q.float() ).to(torch.bfloat16).float() golden_out = logits_gold print(f" logits_gold: {golden_out.shape}") @@ -914,6 +926,11 @@ def compile_hf_model( li_entry.update({"W_gate": wg, "W_up": wu, "W_down": wd}) layer_inputs.append(li_entry) + # lm_head weight input (after all layer weights in HBM layout) + lm_head_input = None + if include_lm_head and lm_head_weight is not None: + lm_head_input = prog.input("W_lm_head", shape=(hidden, vocab_size)) + # Load activations to VRAM X_batch = prog.load_batch(x_input, name="X") POS_batch = prog.load_batch(pos_input, name="POS") @@ -1092,6 +1109,11 @@ def compile_hf_model( f"W_gate_{i}", f"W_up_{i}", f"W_down_{i}", ]) + # lm_head weight in input_tensors + data_order + if include_lm_head and lm_head_weight is not None: + input_tensors["W_lm_head"] = lm_head_weight + data_order.append("W_lm_head") + # FPRAM layout (same as single-layer decoder): # slot 0 = 0.0 (reserved) # slot 1 = attn_scale (flash_attention) @@ -1130,6 +1152,8 @@ def compile_hf_model( "num_heads": num_heads if native_mode else 1, "num_kv_heads": num_kv_heads if native_mode else 1, "native_mode": native_mode, + "include_lm_head": include_lm_head and lm_head_weight is not None, + "vocab_size": vocab_size if (include_lm_head and lm_head_weight is not None) else None, "mlen": mlen, "blen": blen, "isa_lines": len(lines), From 0d9b180e2f99b01a8a2bbaa91ed35c341795c0ec Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 28 Apr 2026 16:26:34 +0100 Subject: [PATCH 08/32] =?UTF-8?q?feat(plena=5Ffrontend):=20three-way=20com?= =?UTF-8?q?parison=20=E2=80=94=20HF=20float32=20vs=20golden=20vs=20emulato?= =?UTF-8?q?r?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aten/plena_frontend.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 974a8b3..7c5425e 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -844,13 +844,13 @@ def compile_hf_model( # RoPE on Q_h (float32) Q_rot = Q_h @ R_mat_f32 Q_h = Q_h * cos_f32 + Q_rot * sin_f32 - O_h = _flash_attn_ref(Q_h, K_hf_heads[kv_h], V_hf_heads[kv_h], scale, causal=True) + O_h = _flash_attn_ref(Q_h, K_hf_heads[kv_h], V_hf_heads[kv_h], scale) O_heads_hf.append(O_h) attn_out_hf = torch.cat(O_heads_hf, dim=1) O_hf = attn_out_hf @ w["W_o"].float() X_hf = O_hf + residual else: - attn_out_hf = _flash_attn_ref(X_normed, K_mats[i].float(), V_mats[i].float(), scale, causal=True) + attn_out_hf = _flash_attn_ref(X_normed, K_mats[i].float(), V_mats[i].float(), scale) X_hf = attn_out_hf + residual # --- FFN block (float32) --- @@ -1297,4 +1297,24 @@ def compile_and_run( except Exception as e: print(f" (skipped: {e})") + comp_results, _params = compare_emulator_output(build_dir) + + # Three-way comparison + golden = result["golden_output"] + hf_gt = result["hf_ground_truth"] + print("\n--- Three-way comparison ---") + if hf_gt is not None and golden is not None: + # HF float32 vs golden (MXFP8 + BF16) + n = min(hf_gt.numel(), golden.numel()) + allclose_hf_vs_gold = ( + torch.isclose(hf_gt.float().flatten()[:n], + golden.float().flatten()[:n], atol=1e-2) + .float().mean().item() * 100 + ) + print(f" HF float32 vs golden (MXFP8+BF16): {allclose_hf_vs_gold:.2f}% allclose") + # Emulator vs golden: reported by compare_emulator_output + emu_match = comp_results.get("allclose_match_rate", None) + if emu_match is not None: + print(f" Emulator vs golden (MXFP8+BF16): {emu_match:.2f}% allclose") + return {**result["info"], **comp_results} From a6beac2100c8e54d831168c09a4df827be7b6d75 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 28 Apr 2026 16:52:40 +0100 Subject: [PATCH 09/32] feat: causal masking + real embeddings + lm_head + three-way comparison - Causal mask: upper-triangular -inf added to QK^T scores before softmax - Real embedding lookup via embed_tokens(input_ids) - Optional lm_head projection (include_lm_head flag) - Three-way comparison: HF float32 vs golden (MXFP8+BF16) vs emulator - pos_weight=zeros for Llama (RoPE handles positioning) Verified: 99.93% allclose with causal mask (sliced dims). HF vs golden at native dims: 100% allclose, cosine=0.9995. --- aten/plena_frontend.py | 52 +++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 7c5425e..180fa02 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -282,27 +282,20 @@ def _extract_layer_weights(layer, hidden_slice, inter_slice, head_dim_slice, num # Golden reference helpers (match hardware: MXFP8 HBM + BF16 intermediates) # --------------------------------------------------------------------------- def _flash_attn_ref(Q, K, V, scale, causal=False): - """CPU reference: scaled dot-product attention matching hardware BF16 precision. + """CPU reference: scaled dot-product attention. - Hardware path: - S = Q @ K^T (M_TMM in f32, written to VRAM as BF16 via M_MM_WO) - S *= scale (V_MUL_VF: BF16 * f32 -> BF16) - P = softmax(S) (online softmax on BF16 data) - O = P @ V (M_MM in f32, written to VRAM as BF16 via M_MM_WO) - - We model the key BF16 truncation points to match hardware precision. + Args: + Q, K, V: query, key, value tensors + scale: attention scale factor (1/sqrt(d)) + causal: if True, apply causal mask (position i can only attend to j <= i) """ - # S = Q @ K^T, then truncate to BF16 (M_MM_WO writes BF16) - scores = (Q @ K.T).to(torch.bfloat16).float() * scale - scores = scores.to(torch.bfloat16).float() # V_MUL_VF result is BF16 + scores = (Q @ K.T) * scale if causal: mask = torch.triu(torch.ones(scores.shape[-2], scores.shape[-1], device=scores.device), diagonal=1).bool() scores.masked_fill_(mask, float('-inf')) - # Softmax output written to BF16 VRAM - attn = F.softmax(scores, dim=-1).to(torch.bfloat16).float() - # O = P @ V, result written to BF16 VRAM - return (attn @ V).to(torch.bfloat16).float() + attn = F.softmax(scores, dim=-1) + return attn @ V def _rms_norm_ref(x, eps): @@ -758,7 +751,7 @@ def compile_hf_model( # RoPE on Q_h: Q_rot = Q_h @ R_rope, Q_h = Q_h * cos + Q_rot * sin Q_rot_h = torch.matmul(Q_h.to(torch.bfloat16).float(), R_rope_q.float()).to(torch.bfloat16).float() Q_h = (Q_h * cos_table + Q_rot_h * sin_table) - O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale) + O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale, causal=True) O_heads.append(O_h) attn_out = torch.cat(O_heads, dim=1) # (seq, num_heads * head_dim) # O projection @@ -766,7 +759,7 @@ def compile_hf_model( X_gold = O_gold + residual else: # Legacy: X is Q directly (no projection) - attn_out = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale) + attn_out = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale, causal=True) X_gold = attn_out + residual # --- FFN block --- @@ -844,13 +837,13 @@ def compile_hf_model( # RoPE on Q_h (float32) Q_rot = Q_h @ R_mat_f32 Q_h = Q_h * cos_f32 + Q_rot * sin_f32 - O_h = _flash_attn_ref(Q_h, K_hf_heads[kv_h], V_hf_heads[kv_h], scale) + O_h = _flash_attn_ref(Q_h, K_hf_heads[kv_h], V_hf_heads[kv_h], scale, causal=True) O_heads_hf.append(O_h) attn_out_hf = torch.cat(O_heads_hf, dim=1) O_hf = attn_out_hf @ w["W_o"].float() X_hf = O_hf + residual else: - attn_out_hf = _flash_attn_ref(X_normed, K_mats[i].float(), V_mats[i].float(), scale) + attn_out_hf = _flash_attn_ref(X_normed, K_mats[i].float(), V_mats[i].float(), scale, causal=True) X_hf = attn_out_hf + residual # --- FFN block (float32) --- @@ -898,6 +891,14 @@ def compile_hf_model( COS = prog.load_batch(cos_input, name="COS") SIN = prog.load_batch(sin_input, name="SIN") + # Causal mask: (mlen, mlen) with 0 on/below diagonal, -inf above + causal_mask_data = torch.zeros(mlen, mlen) + causal_mask_data.masked_fill_( + torch.triu(torch.ones(mlen, mlen), diagonal=1).bool(), float('-inf') + ) + causal_mask_input = prog.input("causal_mask", shape=(mlen, mlen)) + CAUSAL_MASK = prog.load_batch(causal_mask_input, name="CAUSAL_MASK") + # Per-layer weight inputs (order determines HBM layout) layer_inputs = [] for i in range(n_layers): @@ -943,6 +944,7 @@ def compile_hf_model( _current_bump = 2 * seq_len * hidden # X + POS already allocated if native_mode: _current_bump += 2 * seq_len * head_dim # COS + SIN loaded to VRAM + _current_bump += mlen * mlen # CAUSAL_MASK loaded to VRAM if _current_bump < _ffn_intermediate_end: _pad_elems = _ffn_intermediate_end - _current_bump # Allocate enough mlen-wide rows to cover the padding @@ -1018,8 +1020,9 @@ def compile_hf_model( prog.rope(Q_h, Q_rot, COS, SIN) prog.free_tensor(Q_rot) - # Single-head flash attention (K/V read from HBM) - O_h = ops.flash_attention(prog, Q_h, K_stored, V_stored, scale) + # Single-head flash attention (K/V read from HBM) with causal mask + O_h = ops.flash_attention(prog, Q_h, K_stored, V_stored, scale, + causal_mask=CAUSAL_MASK) # Copy O_h to the right column block of O_full o_h_dest_addr = o_full_addr + h * seq_len * mlen @@ -1041,8 +1044,9 @@ def compile_hf_model( prog.vram_add(O_proj, scratch) current_after_attn = O_proj else: - # Legacy: X is Q directly (no projections) - O = ops.flash_attention(prog, current, li["K"], li["V"], scale) + # Legacy: X is Q directly (no projections), with causal mask + O = ops.flash_attention(prog, current, li["K"], li["V"], scale, + causal_mask=CAUSAL_MASK) prog.vram_add(O, scratch) current_after_attn = O @@ -1083,6 +1087,8 @@ def compile_hf_model( input_tensors["COS"] = cos_table input_tensors["SIN"] = sin_table data_order.extend(["R_rope", "COS", "SIN"]) + input_tensors["causal_mask"] = causal_mask_data + data_order.append("causal_mask") for i in range(n_layers): input_tensors[f"W_q_{i}"] = all_weights[i]["W_q"] input_tensors[f"W_o_{i}"] = all_weights[i]["W_o"] From 7ece0a77718c2b136100a44a765e133d27b1e195 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Wed, 29 Apr 2026 21:04:59 +0100 Subject: [PATCH 10/32] feat: add layer progress markers to ISA output for tracking --- aten/plena_frontend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 180fa02..5992709 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -1065,6 +1065,7 @@ def compile_hf_model( prog.vram_add(current_after_attn, scratch) current = current_after_attn # carry forward + prog._compiler.generated_code += f"; === LAYER {i}/{n_layers} COMPLETE ===\n" # Final norm prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) From 7881fb1ef1accc6f2130936dfb5e0964b0561b50 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Fri, 1 May 2026 22:19:41 +0100 Subject: [PATCH 11/32] fix: add load_toml_config to compiler/utils/load_config.py The compiler/utils/ package shadows tools/utils/ when both are on sys.path, causing ImportError for load_toml_config which only existed in tools/utils/load_config.py. Add the function here so imports resolve regardless of PYTHONPATH ordering. --- utils/load_config.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/utils/load_config.py b/utils/load_config.py index 44901bb..3894f70 100644 --- a/utils/load_config.py +++ b/utils/load_config.py @@ -1,10 +1,7 @@ import re from pathlib import Path -try: - import toml -except ImportError: - toml = None +import toml _PARAM_PATTERN = re.compile(r"\s*(?:localparam|parameter)\s+(?:\w+\s+)?(\w+)\s*=\s*([^;]+);") @@ -40,8 +37,6 @@ def load_svh_settings(file_path: str | Path) -> dict[str, int]: def load_toml_config(file_path, section_to_load=None, mode="BEHAVIOR"): - if toml is None: - raise ImportError("'toml' package required for load_toml_config. Install with: pip install toml") with open(file_path) as f: full_toml = toml.load(f) mode_section = full_toml.get(mode, {}) From 2021b5d77a2edc29e980a4b6b199d8feaaf888a7 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Sat, 2 May 2026 00:25:15 +0100 Subject: [PATCH 12/32] feat: golden_precision ablation for quantization gap proof MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add golden_precision parameter to compile_hf_model with four modes: - hardware: MXFP8 weights + BF16 intermediates (default, matches HW) - no_weight_quant: fp32 weights + BF16 intermediates - no_bf16: MXFP8 weights + fp32 intermediates - fp32: float32 everything Ablation on 5-layer clm-60m proves MXFP8 weight quantization accounts for 100% of the HF-vs-golden gap (51.95% → 98.64% when removed). BF16 intermediate precision contributes 0%. Add test_quantization_ablation.py with @pytest.mark.slow. --- aten/plena_frontend.py | 100 ++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 5992709..6ad0ba3 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -491,6 +491,7 @@ def compile_hf_model( blen: int = 4, seed: int = 42, include_lm_head: bool = False, + golden_precision: str = "hardware", ) -> dict: """Compile a HuggingFace model to PLENA ISA via PlenaCompiler. @@ -537,6 +538,11 @@ def compile_hf_model( blen: Batch tile length (default 4) seed: Random seed for test data generation include_lm_head: If True, add lm_head projection after final norm (default False) + golden_precision: Precision mode for golden reference computation. + "hardware" — MXFP8 weights + BF16 intermediates (default, matches HW) + "no_weight_quant" — float32 weights + BF16 intermediates (isolates MXFP8 effect) + "no_bf16" — MXFP8 weights + float32 intermediates (isolates BF16 effect) + "fp32" — float32 weights + float32 intermediates (should match HF exactly) Returns: dict with: @@ -685,99 +691,93 @@ def compile_hf_model( print(f"attn_scale: {scale:.6f}") # ----------------------------------------------------------- golden ref - print("\n--- CPU Golden Reference (MXFP8 quantized HBM + BF16 intermediates) ---") + _do_quant = golden_precision in ("hardware", "no_bf16") + _do_bf16 = golden_precision in ("hardware", "no_weight_quant") + _qw = quantize_to_mxfp if _do_quant else (lambda x: x) + _to_inter = (lambda x: x.to(torch.bfloat16)) if _do_bf16 else (lambda x: x) + _from_inter = (lambda x: x.float()) if _do_bf16 else (lambda x: x) + _prec_label = {"hardware": "MXFP8 weights + BF16 intermediates", + "no_weight_quant": "float32 weights + BF16 intermediates", + "no_bf16": "MXFP8 weights + float32 intermediates", + "fp32": "float32 weights + float32 intermediates"}[golden_precision] + print(f"\n--- CPU Golden Reference ({_prec_label}) ---") if native_mode: - # Quantize per-KV-head projection weights for golden reference - W_k_q_heads = [[quantize_to_mxfp(all_weights[i]["W_k_heads"][h]) + W_k_q_heads = [[_qw(all_weights[i]["W_k_heads"][h]) for h in range(num_kv_heads)] for i in range(n_layers)] - W_v_q_heads = [[quantize_to_mxfp(all_weights[i]["W_v_heads"][h]) + W_v_q_heads = [[_qw(all_weights[i]["W_v_heads"][h]) for h in range(num_kv_heads)] for i in range(n_layers)] if native_mode: R_matrix = _make_rotate_half_matrix(head_dim) - R_rope_q = quantize_to_mxfp(R_matrix) + R_rope_q = _qw(R_matrix) cos_table, sin_table = _make_rope_tables(seq_len, head_dim, native_cfg["rope_theta"]) else: - K_q_list = [quantize_to_mxfp(K_mats[i]) for i in range(n_layers)] - V_q_list = [quantize_to_mxfp(V_mats[i]) for i in range(n_layers)] + K_q_list = [_qw(K_mats[i]) for i in range(n_layers)] + V_q_list = [_qw(V_mats[i]) for i in range(n_layers)] X_gold = token_embeds.clone() + pos_weight # embedding_add ratio = num_heads // num_kv_heads for i in range(n_layers): w = all_weights[i] - W_q_q = quantize_to_mxfp(w["W_q"]) - W_o_q = quantize_to_mxfp(w["W_o"]) - W_gate_q = quantize_to_mxfp(w["W_gate"]) - W_up_q = quantize_to_mxfp(w["W_up"]) - W_down_q = quantize_to_mxfp(w["W_down"]) + W_q_q = _qw(w["W_q"]) + W_o_q = _qw(w["W_o"]) + W_gate_q = _qw(w["W_gate"]) + W_up_q = _qw(w["W_up"]) + W_down_q = _qw(w["W_down"]) # --- Attention block --- residual = X_gold.clone() - # rms_norm with bfloat16 to match PLENA - X_bf = X_gold.to(torch.bfloat16) - rms = torch.rsqrt(X_bf.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) - X_gold = (X_bf * rms).float() + X_bf = _to_inter(X_gold) + rms = torch.rsqrt(_from_inter(X_bf).pow(2).mean(-1, keepdim=True) + eps) + rms = _to_inter(rms) + X_gold = _from_inter(X_bf * rms) if native_mode: - # Q projection: X @ W_q (MXFP8-quantized weight, BF16 intermediate) - Q_gold = torch.matmul(X_gold.to(torch.bfloat16).float(), W_q_q.float()).to(torch.bfloat16).float() + Q_gold = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(X_gold)), W_q_q.float()))) - # On-chip K/V projections: X_normed @ W_k_h / W_v_h per KV head - # K/V are stored to HBM (MXFP8 quantized) then read back K_q_heads_i = [] V_q_heads_i = [] for kv_h in range(num_kv_heads): - K_h = torch.matmul( - X_gold.to(torch.bfloat16).float(), W_k_q_heads[i][kv_h].float() - ).to(torch.bfloat16).float() - V_h = torch.matmul( - X_gold.to(torch.bfloat16).float(), W_v_q_heads[i][kv_h].float() - ).to(torch.bfloat16).float() - # RoPE on K_h: K_rot = K_h @ R_rope, K_h = K_h * cos + K_rot * sin - K_rot_h = torch.matmul(K_h.to(torch.bfloat16).float(), R_rope_q.float()).to(torch.bfloat16).float() + K_h = _from_inter(_to_inter(torch.matmul( + _from_inter(_to_inter(X_gold)), W_k_q_heads[i][kv_h].float() + ))) + V_h = _from_inter(_to_inter(torch.matmul( + _from_inter(_to_inter(X_gold)), W_v_q_heads[i][kv_h].float() + ))) + K_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(K_h)), R_rope_q.float()))) K_h = (K_h * cos_table + K_rot_h * sin_table) - # Quantize for HBM store+load round-trip - K_q_heads_i.append(quantize_to_mxfp(K_h)) - V_q_heads_i.append(quantize_to_mxfp(V_h)) + K_q_heads_i.append(_qw(K_h)) + V_q_heads_i.append(_qw(V_h)) - # Per-head flash attention (with RoPE on Q per head) O_heads = [] for h in range(num_heads): kv_h = h // ratio Q_h = Q_gold[:, h * head_dim:(h + 1) * head_dim] - # RoPE on Q_h: Q_rot = Q_h @ R_rope, Q_h = Q_h * cos + Q_rot * sin - Q_rot_h = torch.matmul(Q_h.to(torch.bfloat16).float(), R_rope_q.float()).to(torch.bfloat16).float() + Q_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(Q_h)), R_rope_q.float()))) Q_h = (Q_h * cos_table + Q_rot_h * sin_table) O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale, causal=True) O_heads.append(O_h) - attn_out = torch.cat(O_heads, dim=1) # (seq, num_heads * head_dim) - # O projection - O_gold = torch.matmul(attn_out.to(torch.bfloat16).float(), W_o_q.float()).to(torch.bfloat16).float() + attn_out = torch.cat(O_heads, dim=1) + O_gold = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(attn_out)), W_o_q.float()))) X_gold = O_gold + residual else: - # Legacy: X is Q directly (no projection) attn_out = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale, causal=True) X_gold = attn_out + residual # --- FFN block --- residual = X_gold.clone() X_bf = _to_inter(X_gold) - # Hardware: scalar rms stays in f32, V_MUL_VF multiplies BF16 * f32 -> BF16 rms = torch.rsqrt(_from_inter(X_bf).pow(2).mean(-1, keepdim=True) + eps) - X_gold = (_from_inter(X_bf) * rms).to(torch.bfloat16).float() - # Hardware path: MXFP8 (HBM) → BF16 (MRAM) → float32 (M_MM) - # Use _ksplit_matmul to match hardware K-split BF16 precision loss for - # projections that exceed MAX_K_TILES (e.g., hidden=576 → 9 tiles, inter=1536 → 24 tiles) - up_out = _ksplit_matmul(X_gold, W_up_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - gate_out = _ksplit_matmul(X_gold, W_gate_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - # Hardware: SiLU(up) * gate -> BF16 VRAM write before down projection - silu_gate = _to_inter(F.silu(_from_inter(_to_inter(up_out))) * _from_inter(_to_inter(gate_out))) - X_gold = _ksplit_matmul(_from_inter(silu_gate), W_down_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - X_gold = _from_inter(_to_inter(X_gold)) # VRAM write: BF16 after down proj - X_gold = _from_inter(_to_inter(X_gold + residual)) # residual add -> VRAM write: BF16 + rms = _to_inter(rms) + X_gold = _from_inter(X_bf * rms) + up_out = _to_inter(torch.matmul(_from_inter(_to_inter(X_gold)), W_up_q.float())) + gate_out = _to_inter(torch.matmul(_from_inter(_to_inter(X_gold)), W_gate_q.float())) + silu_gate = _to_inter(F.silu(_from_inter(up_out)) * _from_inter(gate_out)) + X_gold = _from_inter(_to_inter(torch.matmul(_from_inter(silu_gate), W_down_q.float()))) + X_gold = X_gold + residual print(f" After layer {i}: X_gold[0,:4] = {X_gold[0, :4].tolist()}") From eb5a8b6cf063cceed9d0af09a054909bdacb0a38 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Thu, 7 May 2026 01:04:20 +0100 Subject: [PATCH 13/32] fix(golden): BF16 truncation at all pipeline stages + K-split temp fix --- aten/plena_frontend.py | 188 ++++++++++++----------------------------- 1 file changed, 53 insertions(+), 135 deletions(-) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 6ad0ba3..f033a15 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -132,24 +132,6 @@ def _make_rotate_half_matrix(head_dim: int) -> torch.Tensor: return R -# --------------------------------------------------------------------------- -# RoPE helpers -# --------------------------------------------------------------------------- -def _make_rotate_half_matrix(head_dim: int) -> torch.Tensor: - """Build the (head_dim, head_dim) matrix that computes rotate_half. - - rotate_half(x) = x @ R, where R permutes the halves with a sign flip: - output[:d//2] = -input[d//2:] - output[d//2:] = +input[:d//2] - """ - R = torch.zeros(head_dim, head_dim) - half = head_dim // 2 - for i in range(half): - R[i + half, i] = -1.0 # first half of output = negated second half of input - R[i, i + half] = 1.0 # second half of output = first half of input - return R - - # --------------------------------------------------------------------------- # Model structure helpers # --------------------------------------------------------------------------- @@ -282,20 +264,27 @@ def _extract_layer_weights(layer, hidden_slice, inter_slice, head_dim_slice, num # Golden reference helpers (match hardware: MXFP8 HBM + BF16 intermediates) # --------------------------------------------------------------------------- def _flash_attn_ref(Q, K, V, scale, causal=False): - """CPU reference: scaled dot-product attention. + """CPU reference: scaled dot-product attention matching hardware BF16 precision. - Args: - Q, K, V: query, key, value tensors - scale: attention scale factor (1/sqrt(d)) - causal: if True, apply causal mask (position i can only attend to j <= i) + Hardware path: + S = Q @ K^T (M_TMM in f32, written to VRAM as BF16 via M_MM_WO) + S *= scale (V_MUL_VF: BF16 * f32 -> BF16) + P = softmax(S) (online softmax on BF16 data) + O = P @ V (M_MM in f32, written to VRAM as BF16 via M_MM_WO) + + We model the key BF16 truncation points to match hardware precision. """ - scores = (Q @ K.T) * scale + # S = Q @ K^T, then truncate to BF16 (M_MM_WO writes BF16) + scores = (Q @ K.T).to(torch.bfloat16).float() * scale + scores = scores.to(torch.bfloat16).float() # V_MUL_VF result is BF16 if causal: mask = torch.triu(torch.ones(scores.shape[-2], scores.shape[-1], device=scores.device), diagonal=1).bool() scores.masked_fill_(mask, float('-inf')) - attn = F.softmax(scores, dim=-1) - return attn @ V + # Softmax output written to BF16 VRAM + attn = F.softmax(scores, dim=-1).to(torch.bfloat16).float() + # O = P @ V, result written to BF16 VRAM + return (attn @ V).to(torch.bfloat16).float() def _rms_norm_ref(x, eps): @@ -395,88 +384,6 @@ def _linear_projection(prog, input_var, weight_var, name): return output -# --------------------------------------------------------------------------- -# PLENA ISA helper: named linear projection (avoids name conflicts) -# --------------------------------------------------------------------------- -def _linear_projection(prog, input_var, weight_var, name): - """Emit a linear projection with a custom VRAM output name. - - Equivalent to ops.linear but uses *name* for the output allocation so - that multiple projections in the same scope don't collide on the - default "linear_out" name. - - Supports K-split: when K tiles exceed MRAM capacity (4 tiles), splits - into chunks and accumulates partial sums via a temporary buffer. - """ - import math as _math - - mlen = prog.mlen - MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements - - rows, k_total = input_var.shape - _, out_features = weight_var.shape - num_row_blocks = _math.ceil(rows / mlen) - assert out_features % mlen == 0, ( - f"out_features ({out_features}) must be a multiple of mlen ({mlen})" - ) - num_col_blocks = out_features // mlen - num_k_tiles = _math.ceil(k_total / mlen) - - output_strict = rows % mlen == 0 - output = prog.alloc(name, rows, out_features, strict=output_strict) - - if num_k_tiles <= MAX_K_TILES: - # Single pass: all K tiles fit in MRAM - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - output, - row_idx, - col_idx, - ) - else: - # K-split: chunk K tiles into groups of MAX_K_TILES - k_chunks = [] - k_start = 0 - while k_start < num_k_tiles: - k_end = min(k_start + MAX_K_TILES, num_k_tiles) - k_chunks.append((k_start, k_end - k_start)) - k_start = k_end - - temp = prog.alloc(f"{name}_ksplit_tmp", rows, out_features, strict=output_strict) - - for k_chunk_idx, (k_block_start, k_block_count) in enumerate(k_chunks): - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - if k_chunk_idx == 0: - prog.vram_sub_projection_to( - input_var, row_idx, weight_var, col_idx, - output, row_idx, col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - else: - prog.vram_sub_projection_to( - input_var, row_idx, weight_var, col_idx, - temp, row_idx, col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - prog.vram_block_add_to( - output, row_idx, col_idx, - temp, row_idx, col_idx, - output, row_idx, col_idx, - ) - - prog.free_tensor(temp) - - return output - - # --------------------------------------------------------------------------- # Main compilation function # --------------------------------------------------------------------------- @@ -732,52 +639,62 @@ def compile_hf_model( # --- Attention block --- residual = X_gold.clone() X_bf = _to_inter(X_gold) + # Hardware: scalar rms stays in f32, V_MUL_VF multiplies BF16 * f32 -> BF16 rms = torch.rsqrt(_from_inter(X_bf).pow(2).mean(-1, keepdim=True) + eps) - rms = _to_inter(rms) - X_gold = _from_inter(X_bf * rms) + X_gold = (_from_inter(X_bf) * rms).to(torch.bfloat16).float() if native_mode: - Q_gold = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(X_gold)), W_q_q.float()))) + Q_gold = _ksplit_matmul(X_gold, W_q_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) + Q_gold = _from_inter(_to_inter(Q_gold)) # VRAM write: BF16 K_q_heads_i = [] V_q_heads_i = [] for kv_h in range(num_kv_heads): - K_h = _from_inter(_to_inter(torch.matmul( - _from_inter(_to_inter(X_gold)), W_k_q_heads[i][kv_h].float() - ))) - V_h = _from_inter(_to_inter(torch.matmul( - _from_inter(_to_inter(X_gold)), W_v_q_heads[i][kv_h].float() - ))) - K_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(K_h)), R_rope_q.float()))) - K_h = (K_h * cos_table + K_rot_h * sin_table) - K_q_heads_i.append(_qw(K_h)) - V_q_heads_i.append(_qw(V_h)) + K_h = _ksplit_matmul(X_gold, W_k_q_heads[i][kv_h], mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) + V_h = _ksplit_matmul(X_gold, W_v_q_heads[i][kv_h], mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) + K_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(K_h)), _from_inter(_to_inter(R_rope_q))))) + # Hardware RoPE: V_MUL_VV (BF16*BF16->BF16), V_ADD_VV (BF16+BF16->BF16) + K_cos = _from_inter(_to_inter(K_h.to(torch.bfloat16).float() * cos_table.to(torch.bfloat16).float())) + K_rot_sin = _from_inter(_to_inter(K_rot_h.to(torch.bfloat16).float() * sin_table.to(torch.bfloat16).float())) + K_h = _from_inter(_to_inter(K_cos) + _to_inter(K_rot_sin)) + # Hardware stores K/V to HBM as MXFP8, loads back as BF16 for attention + K_q_heads_i.append(_from_inter(_to_inter(_qw(K_h)))) + V_q_heads_i.append(_from_inter(_to_inter(_qw(V_h)))) O_heads = [] for h in range(num_heads): kv_h = h // ratio Q_h = Q_gold[:, h * head_dim:(h + 1) * head_dim] - Q_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(Q_h)), R_rope_q.float()))) - Q_h = (Q_h * cos_table + Q_rot_h * sin_table) + Q_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(Q_h)), _from_inter(_to_inter(R_rope_q))))) + # Hardware RoPE: V_MUL_VV (BF16*BF16->BF16), V_ADD_VV (BF16+BF16->BF16) + Q_cos = _from_inter(_to_inter(Q_h.to(torch.bfloat16).float() * cos_table.to(torch.bfloat16).float())) + Q_rot_sin = _from_inter(_to_inter(Q_rot_h.to(torch.bfloat16).float() * sin_table.to(torch.bfloat16).float())) + Q_h = _from_inter(_to_inter(Q_cos) + _to_inter(Q_rot_sin)) O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale, causal=True) O_heads.append(O_h) - attn_out = torch.cat(O_heads, dim=1) - O_gold = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(attn_out)), W_o_q.float()))) - X_gold = O_gold + residual + attn_out = _from_inter(_to_inter(torch.cat(O_heads, dim=1))) # VRAM write: BF16 + O_gold = _ksplit_matmul(attn_out, W_o_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) + O_gold = _from_inter(_to_inter(O_gold)) # VRAM write: BF16 + X_gold = _from_inter(_to_inter(O_gold + residual)) # residual add -> VRAM write: BF16 else: attn_out = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale, causal=True) - X_gold = attn_out + residual + X_gold = _from_inter(_to_inter(attn_out + residual)) # residual add -> VRAM write: BF16 # --- FFN block --- residual = X_gold.clone() X_bf = _to_inter(X_gold) + # Hardware: scalar rms stays in f32, V_MUL_VF multiplies BF16 * f32 -> BF16 rms = torch.rsqrt(_from_inter(X_bf).pow(2).mean(-1, keepdim=True) + eps) - rms = _to_inter(rms) - X_gold = _from_inter(X_bf * rms) - up_out = _to_inter(torch.matmul(_from_inter(_to_inter(X_gold)), W_up_q.float())) - gate_out = _to_inter(torch.matmul(_from_inter(_to_inter(X_gold)), W_gate_q.float())) - silu_gate = _to_inter(F.silu(_from_inter(up_out)) * _from_inter(gate_out)) - X_gold = _from_inter(_to_inter(torch.matmul(_from_inter(silu_gate), W_down_q.float()))) - X_gold = X_gold + residual + X_gold = (_from_inter(X_bf) * rms).to(torch.bfloat16).float() + # Hardware path: MXFP8 (HBM) → BF16 (MRAM) → float32 (M_MM) + # Use _ksplit_matmul to match hardware K-split BF16 precision loss for + # projections that exceed MAX_K_TILES (e.g., hidden=576 → 9 tiles, inter=1536 → 24 tiles) + up_out = _ksplit_matmul(X_gold, W_up_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) + gate_out = _ksplit_matmul(X_gold, W_gate_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) + # Hardware: SiLU(up) * gate -> BF16 VRAM write before down projection + silu_gate = _to_inter(F.silu(_from_inter(_to_inter(up_out))) * _from_inter(_to_inter(gate_out))) + X_gold = _ksplit_matmul(_from_inter(silu_gate), W_down_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) + X_gold = _from_inter(_to_inter(X_gold)) # VRAM write: BF16 after down proj + X_gold = _from_inter(_to_inter(X_gold + residual)) # residual add -> VRAM write: BF16 print(f" After layer {i}: X_gold[0,:4] = {X_gold[0, :4].tolist()}") @@ -787,8 +704,9 @@ def compile_hf_model( # lm_head projection (optional) if include_lm_head and lm_head_weight is not None: W_lm_q = quantize_to_mxfp(lm_head_weight) + # Hardware: MXFP8 → BF16 (MRAM) → float32 (M_MM) logits_gold = torch.matmul( - X_gold.to(torch.bfloat16).float(), W_lm_q.float() + X_gold.to(torch.bfloat16).float(), W_lm_q.to(torch.bfloat16).float() ).to(torch.bfloat16).float() golden_out = logits_gold print(f" logits_gold: {golden_out.shape}") From 0cbf9ba2305d9cedfa979c4ec0fee3e4f92f1929 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Fri, 8 May 2026 04:38:11 +0100 Subject: [PATCH 14/32] fix(golden): use MXFP8-quantized X as golden input, matching HBM --- aten/plena_frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index f033a15..fe372d9 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -625,7 +625,7 @@ def compile_hf_model( K_q_list = [_qw(K_mats[i]) for i in range(n_layers)] V_q_list = [_qw(V_mats[i]) for i in range(n_layers)] - X_gold = token_embeds.clone() + pos_weight # embedding_add + X_gold = _qw(token_embeds.clone()) + _qw(pos_weight) # embedding_add (MXFP8-quantized, matching HBM) ratio = num_heads // num_kv_heads for i in range(n_layers): From 18f819ad4e3365c3ab8de6a6589da6b3b90d901b Mon Sep 17 00:00:00 2001 From: booth-algo Date: Fri, 8 May 2026 07:45:34 +0100 Subject: [PATCH 15/32] feat: VRAM-in-the-loop stage comparison for emulator validation Reads emulator's actual VRAM intermediates at each stage boundary, uses them as golden input for the next stage. Validates each stage independently, proving emulator correctness regardless of golden chain precision drift. SmolVLM2 result: 100% allclose, MSE=1.45e-05. --- aten/vram_stage_compare.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/aten/vram_stage_compare.py b/aten/vram_stage_compare.py index f169fb3..1f92788 100644 --- a/aten/vram_stage_compare.py +++ b/aten/vram_stage_compare.py @@ -68,7 +68,7 @@ def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads, # For clm-60m: different addresses. We compute from VRAM layout. tiles = hidden // mlen x_addr = 3 * mlen * mlen + 0 # after COS, SIN, mask (for native mode) - # Read final output address from comparison_params + # Actually, let's read the comparison_params for the final output address import json params = json.load(open(build / "comparison_params.json")) final_addr = params["start_row_idx"] * mlen @@ -82,8 +82,9 @@ def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads, W_down = quantize_to_mxfp(torch.load(build / "W_down_0.pt", weights_only=True)) # --- Stage 1: O_full (attention output) --- - # Find O_full address from ISA comments - o_full_addr = final_addr - 2 * seq_len * hidden + # We trust O_full is correct (proven by per-head comparison) + o_full_addr = final_addr - 2 * seq_len * hidden # O_full is 2 allocations before O_proj + # Actually, let's find O_full from ISA import re asm_path = build / "generated_asm_code.asm" if asm_path.exists(): From 17f50d8dfae0933b9a0c445d07a3d64ea7dee256 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Sat, 9 May 2026 07:14:18 +0100 Subject: [PATCH 16/32] fix(ci): guard toml import for CI environments without toml package --- utils/load_config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/utils/load_config.py b/utils/load_config.py index 3894f70..44901bb 100644 --- a/utils/load_config.py +++ b/utils/load_config.py @@ -1,7 +1,10 @@ import re from pathlib import Path -import toml +try: + import toml +except ImportError: + toml = None _PARAM_PATTERN = re.compile(r"\s*(?:localparam|parameter)\s+(?:\w+\s+)?(\w+)\s*=\s*([^;]+);") @@ -37,6 +40,8 @@ def load_svh_settings(file_path: str | Path) -> dict[str, int]: def load_toml_config(file_path, section_to_load=None, mode="BEHAVIOR"): + if toml is None: + raise ImportError("'toml' package required for load_toml_config. Install with: pip install toml") with open(file_path) as f: full_toml = toml.load(f) mode_section = full_toml.get(mode, {}) From f5b9e222033b15a6e22a5ee426e1177d9263bd9f Mon Sep 17 00:00:00 2001 From: booth-algo Date: Sat, 9 May 2026 08:16:35 +0100 Subject: [PATCH 17/32] style: clean up informal comments in vram_stage_compare --- aten/vram_stage_compare.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/aten/vram_stage_compare.py b/aten/vram_stage_compare.py index 1f92788..f169fb3 100644 --- a/aten/vram_stage_compare.py +++ b/aten/vram_stage_compare.py @@ -68,7 +68,7 @@ def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads, # For clm-60m: different addresses. We compute from VRAM layout. tiles = hidden // mlen x_addr = 3 * mlen * mlen + 0 # after COS, SIN, mask (for native mode) - # Actually, let's read the comparison_params for the final output address + # Read final output address from comparison_params import json params = json.load(open(build / "comparison_params.json")) final_addr = params["start_row_idx"] * mlen @@ -82,9 +82,8 @@ def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads, W_down = quantize_to_mxfp(torch.load(build / "W_down_0.pt", weights_only=True)) # --- Stage 1: O_full (attention output) --- - # We trust O_full is correct (proven by per-head comparison) - o_full_addr = final_addr - 2 * seq_len * hidden # O_full is 2 allocations before O_proj - # Actually, let's find O_full from ISA + # Find O_full address from ISA comments + o_full_addr = final_addr - 2 * seq_len * hidden import re asm_path = build / "generated_asm_code.asm" if asm_path.exists(): From 8501754c449c2215a30ff41e8331b4b1443097c9 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Sun, 10 May 2026 21:44:59 +0100 Subject: [PATCH 18/32] fix: use 'compiler' not 'PLENA_Compiler' for doc paths (post-merge fix) --- sim_env_utils/build_sys_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sim_env_utils/build_sys_tools.py b/sim_env_utils/build_sys_tools.py index a134b40..ae63a66 100644 --- a/sim_env_utils/build_sys_tools.py +++ b/sim_env_utils/build_sys_tools.py @@ -75,8 +75,8 @@ def env_setup(memory_data_manager, build_path: str, data_config, quant_config, h hbm_row_width: HBM row width test_file_name: Optional test file name """ - isa_file_path = _PROJECT_ROOT / "PLENA_Compiler" / "doc" / "operation.svh" - config_file_path = _PROJECT_ROOT / "PLENA_Compiler" / "doc" / "configuration.svh" + isa_file_path = _PROJECT_ROOT / "compiler" / "doc" / "operation.svh" + config_file_path = _PROJECT_ROOT / "compiler" / "doc" / "configuration.svh" if test_file_name is None: assembler = AssemblyToBinary(str(isa_file_path), str(config_file_path)) From 6ad4228785f527c92e9159fe4c16166f461a9258 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Sun, 10 May 2026 22:05:01 +0100 Subject: [PATCH 19/32] fix: revert to PLENA_Compiler path (matches production layout) --- sim_env_utils/build_sys_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sim_env_utils/build_sys_tools.py b/sim_env_utils/build_sys_tools.py index ae63a66..a134b40 100644 --- a/sim_env_utils/build_sys_tools.py +++ b/sim_env_utils/build_sys_tools.py @@ -75,8 +75,8 @@ def env_setup(memory_data_manager, build_path: str, data_config, quant_config, h hbm_row_width: HBM row width test_file_name: Optional test file name """ - isa_file_path = _PROJECT_ROOT / "compiler" / "doc" / "operation.svh" - config_file_path = _PROJECT_ROOT / "compiler" / "doc" / "configuration.svh" + isa_file_path = _PROJECT_ROOT / "PLENA_Compiler" / "doc" / "operation.svh" + config_file_path = _PROJECT_ROOT / "PLENA_Compiler" / "doc" / "configuration.svh" if test_file_name is None: assembler = AssemblyToBinary(str(isa_file_path), str(config_file_path)) From c6a61d4ae80ad51b0d056f90857e94a33cf94ce7 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Mon, 11 May 2026 10:16:09 +0100 Subject: [PATCH 20/32] refactor: split PLENA compiler DSL and ISA layers --- aten/__init__.py | 12 +- aten/isa_builder.py | 139 + aten/ops/plena/attention_ops.py | 4 +- aten/ops/plena/conv_ops.py | 4 +- aten/ops/plena/ffn_ops.py | 2 +- aten/plena/__init__.py | 50 + aten/plena/compiler.py | 313 ++ aten/plena/constants.py | 5 + aten/plena/dsl_attention.py | 82 + aten/plena/dsl_fp_tile_ops.py | 568 +++ aten/plena/dsl_matrix_ops.py | 209 ++ aten/plena/dsl_tensors.py | 319 ++ aten/plena/isa_attention.py | 531 +++ aten/plena/isa_compiler.py | 576 +++ aten/plena/isa_emit.py | 87 + aten/plena/isa_fp_ops.py | 316 ++ aten/plena/isa_matrix.py | 384 ++ aten/plena/isa_tile_rows.py | 391 ++ aten/plena/memory.py | 674 ++++ aten/plena/registers.py | 66 + aten/plena/tile_compiler.py | 1015 +++++ aten/plena/vars.py | 163 + aten/plena_compiler.py | 5752 +---------------------------- aten/plena_frontend.py | 46 +- aten/tests/test_plena_compiler.py | 77 +- 25 files changed, 6047 insertions(+), 5738 deletions(-) create mode 100644 aten/isa_builder.py create mode 100644 aten/plena/__init__.py create mode 100644 aten/plena/compiler.py create mode 100644 aten/plena/constants.py create mode 100644 aten/plena/dsl_attention.py create mode 100644 aten/plena/dsl_fp_tile_ops.py create mode 100644 aten/plena/dsl_matrix_ops.py create mode 100644 aten/plena/dsl_tensors.py create mode 100644 aten/plena/isa_attention.py create mode 100644 aten/plena/isa_compiler.py create mode 100644 aten/plena/isa_emit.py create mode 100644 aten/plena/isa_fp_ops.py create mode 100644 aten/plena/isa_matrix.py create mode 100644 aten/plena/isa_tile_rows.py create mode 100644 aten/plena/memory.py create mode 100644 aten/plena/registers.py create mode 100644 aten/plena/tile_compiler.py create mode 100644 aten/plena/vars.py diff --git a/aten/__init__.py b/aten/__init__.py index 19d6fd8..ba18cc9 100644 --- a/aten/__init__.py +++ b/aten/__init__.py @@ -9,10 +9,20 @@ PLENA_PKG_DIR = Path(__file__).parent NATIVE_OPS_YAML = PLENA_PKG_DIR / "native_ops.yaml" +from compiler.aten.isa_builder import ( # noqa: E402, F401 + Comment, + Instr, + IsaBuilder, + Register, + addr, + fp, + gp, +) from compiler.aten.plena_compiler import ( # noqa: E402, F401 + DeveloperCompiler, PlenaCompiler, TileCompiler, - DeveloperCompiler, + IsaCompiler, RegisterAllocator, TensorVar, InputVar, diff --git a/aten/isa_builder.py b/aten/isa_builder.py new file mode 100644 index 0000000..8a84778 --- /dev/null +++ b/aten/isa_builder.py @@ -0,0 +1,139 @@ +"""Typed ISA builder for the ATen PLENA compiler path. + +This is intentionally small: it models the physical instruction stream and +prints the same assembly syntax the existing compiler emits. +""" + +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Protocol + + +class Renderable(Protocol): + def render(self) -> str: + """Render this object as assembly text.""" + + +@dataclass(frozen=True) +class Register: + prefix: str + index: int + + def render(self) -> str: + return f"{self.prefix}{self.index}" + + +def gp(index: int) -> Register: + return Register("gp", index) + + +def fp(index: int) -> Register: + return Register("f", index) + + +def addr(index: int) -> Register: + return Register("a", index) + + +AsmArg = str | int | Register + + +def render_arg(arg: AsmArg) -> str: + if isinstance(arg, Register): + return arg.render() + return str(arg) + + +@dataclass(frozen=True) +class Instr: + opcode: str + args: tuple[AsmArg, ...] = () + + def render(self) -> str: + if not self.args: + return self.opcode + return f"{self.opcode} {', '.join(render_arg(arg) for arg in self.args)}" + + +@dataclass(frozen=True) +class Comment: + text: str + + def render(self) -> str: + text = self.text.rstrip() + if text.startswith(";"): + return text + return f"; {text}" + + +AsmItem = str | Instr | Comment +IMM2_BOUND = 1 << 18 + + +@dataclass +class IsaBuilder: + items: list[AsmItem] = field(default_factory=list) + + def comment(self, text: str) -> IsaBuilder: + self.items.append(Comment(text)) + return self + + def instr(self, opcode: str, *args: AsmArg) -> IsaBuilder: + self.items.append(Instr(opcode, args)) + return self + + def raw(self, line: str) -> IsaBuilder: + self.items.append(line.rstrip("\n")) + return self + + def extend(self, items: Iterable[AsmItem]) -> IsaBuilder: + self.items.extend(items) + return self + + def render(self) -> str: + if not self.items: + return "" + return "\n".join(render_item(item) for item in legalize_large_immediates(self.items)) + "\n" + + +AsmInput = str | Renderable + + +def render_item(item: AsmItem) -> str: + if isinstance(item, str): + return item.rstrip("\n") + return item.render() + + +def render_asm(value: AsmInput) -> str: + if isinstance(value, str): + return value + return value.render() + + +def is_gp_zero(arg: AsmArg) -> bool: + return isinstance(arg, Register) and arg.prefix == "gp" and arg.index == 0 + + +def legalize_large_immediates(items: Iterable[AsmItem]) -> list[AsmItem]: + """Split typed absolute S_ADDI_INT loads that exceed the immediate field. + + This is the typed equivalent of plena_frontend._fix_large_immediates. + Raw string items are intentionally left alone until those call sites move + onto typed instructions. + """ + legalized: list[AsmItem] = [] + for item in items: + if isinstance(item, Instr) and item.opcode == "S_ADDI_INT" and len(item.args) == 3: + rd, rs, imm = item.args + if isinstance(rd, Register) and is_gp_zero(rs) and isinstance(imm, int) and imm >= IMM2_BOUND: + upper = imm >> 12 + lower = imm & 0xFFF + legalized.append(Instr("S_LUI_INT", (rd, upper))) + if lower: + legalized.append(Instr("S_ADDI_INT", (rd, rd, lower))) + continue + legalized.append(item) + return legalized diff --git a/aten/ops/plena/attention_ops.py b/aten/ops/plena/attention_ops.py index b62eb81..0798010 100644 --- a/aten/ops/plena/attention_ops.py +++ b/aten/ops/plena/attention_ops.py @@ -135,7 +135,7 @@ def _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv): addr_reg_val=[K.hbm_addr, V.hbm_addr], ) alloc.free_gp(gp_for_preload) - prog._compiler.generated_code += setup + prog._compiler.emit(setup) # Allocate VRAM buffers mirroring main's layout. # S, PV each require mlen*mlen*ratio elements; O is s_q * (hq*h_qkv). @@ -185,7 +185,7 @@ def _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv): k_base_hbm_offset_reg=k_addr, v_base_hbm_offset_reg=v_addr, ) - prog._compiler.generated_code += asm + prog._compiler.emit(asm) # Release HBM addr regs (they're only needed during the call) alloc.free_addr([k_addr, v_addr]) diff --git a/aten/ops/plena/conv_ops.py b/aten/ops/plena/conv_ops.py index 57eb562..689e9ba 100644 --- a/aten/ops/plena/conv_ops.py +++ b/aten/ops/plena/conv_ops.py @@ -146,7 +146,7 @@ def conv2d_plena( setup_lines.append(f"S_LUI_INT gp{setup_gp}, {hbm_base >> 12}") setup_lines.append(f"S_ADDI_INT gp{setup_gp}, gp{setup_gp}, {hbm_base & 0xFFF}") setup_lines.append(f"C_SET_ADDR_REG a{addr_reg_idx}, gp0, gp{setup_gp}") - prog._compiler.generated_code += "\n".join(setup_lines) + "\n" + prog._compiler.emit("\n".join(setup_lines) + "\n") # ------------------------------------------------------------------ # Emit: im2col assembly @@ -191,7 +191,7 @@ def conv2d_plena( fp_one_reg=fp_one_reg, # f1 = 1.0 by default (must be in fp_preload[fp_one_reg]) fp_ex_reg=2, # f2 = V_RED_SUM accumulator ) - prog._compiler.generated_code += asm_code + prog._compiler.emit(asm_code) # ------------------------------------------------------------------ # Systolic matmul: im2col_out @ weight_2d -> (M, C_out) diff --git a/aten/ops/plena/ffn_ops.py b/aten/ops/plena/ffn_ops.py index 89fe670..9d4be26 100644 --- a/aten/ops/plena/ffn_ops.py +++ b/aten/ops/plena/ffn_ops.py @@ -55,7 +55,7 @@ def ffn_plena(prog, input_var, w_gate, w_up, w_down): use_loop_instructions=True, ) - prog._compiler.generated_code += isa_code + prog._compiler.emit(isa_code) # FFN result is written back to the activation area in VRAM (in-place overwrite) return input_var diff --git a/aten/plena/__init__.py b/aten/plena/__init__.py new file mode 100644 index 0000000..09ad0c4 --- /dev/null +++ b/aten/plena/__init__.py @@ -0,0 +1,50 @@ +"""Internal modules for the ATen PLENA compiler implementation.""" + +from compiler.aten.plena.compiler import PlenaCompiler +from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN +from compiler.aten.plena.isa_compiler import DeveloperCompiler, IsaCompiler +from compiler.aten.plena.memory import ( + FPRAMAllocator, + FPRAMObjectLayout, + MRAMAllocator, + MatrixBlockLayout, + MemoryBlock, + MemoryObjectInfo, + SubMatrixInfo, + VRAMAllocator, + VRAMMatrixBlockLayout, + VRAMSubMatrixInfo, + VirtualMemoryManager, +) +from compiler.aten.plena.registers import RegisterAllocator +from compiler.aten.plena.tile_compiler import TileCompiler +from compiler.aten.plena.vars import FPVar, InputVar, Tensor, TensorKind, TensorVar, VRAMMatrixVar, tensor_kind + +__all__ = [ + "BLEN", + "IMM2_BOUND", + "MLEN", + "DeveloperCompiler", + "FPRAMAllocator", + "FPRAMObjectLayout", + "FPVar", + "InputVar", + "IsaCompiler", + "MRAMAllocator", + "MatrixBlockLayout", + "MemoryBlock", + "MemoryObjectInfo", + "PlenaCompiler", + "RegisterAllocator", + "SubMatrixInfo", + "Tensor", + "TensorKind", + "TensorVar", + "TileCompiler", + "VRAMAllocator", + "VRAMMatrixBlockLayout", + "VRAMMatrixVar", + "VRAMSubMatrixInfo", + "VirtualMemoryManager", + "tensor_kind", +] diff --git a/aten/plena/compiler.py b/aten/plena/compiler.py new file mode 100644 index 0000000..395b637 --- /dev/null +++ b/aten/plena/compiler.py @@ -0,0 +1,313 @@ +"""User-facing PLENA compiler DSL.""" + +from __future__ import annotations + +import os +from collections.abc import Callable +from functools import wraps + +from compiler.aten.plena.isa_compiler import IsaCompiler +from compiler.aten.plena.dsl_attention import DslAttentionMixin +from compiler.aten.plena.dsl_fp_tile_ops import DslFPTileOpsMixin +from compiler.aten.plena.dsl_matrix_ops import DslMatrixOpsMixin +from compiler.aten.plena.dsl_tensors import DslTensorMixin +from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar + + +class _IsaCompilerView: + """ + Back-compat proxy for legacy ``prog._compiler.X(...)`` call sites. + + PlenaCompiler now inherits IsaCompiler rather than composing it, so + for call sites that still expect to reach the low-level IsaCompiler + API (e.g., ``allocate_fpram(name=..., size=...)`` returning an int), we + expose a proxy whose attribute lookup resolves callables on + IsaCompiler directly (bypassing any PlenaCompiler overrides with + colliding names). Non-callable attributes (e.g., ``generated_code``, + ``vram_allocator``) fall through to the underlying instance unchanged. + """ + + __slots__ = ("_inst",) + + def __init__(self, inst: PlenaCompiler): + object.__setattr__(self, "_inst", inst) + + def __getattr__(self, name: str): + cls_attr = getattr(IsaCompiler, name, None) + if cls_attr is not None and callable(cls_attr): + return cls_attr.__get__(self._inst, IsaCompiler) + return getattr(self._inst, name) + + def __setattr__(self, name: str, value): + setattr(self._inst, name, value) + + +# ============================================================================ +# PlenaCompiler Main Class +# ============================================================================ + + +class PlenaCompiler( + DslTensorMixin, + DslFPTileOpsMixin, + DslMatrixOpsMixin, + DslAttentionMixin, + IsaCompiler, +): + """ + PLENA High-level Compiler Interface. + + Inherits the ISA-emission machinery from IsaCompiler and layers a + Pythonic DSL on top. All operations are eagerly evaluated — ISA code is + generated immediately upon call. + """ + + def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): + """ + Args: + mlen: Matrix tile size (default 64) + blen: Vector tile size (default 4) + real_data_ratio: HBM storage ratio (MXFP8 format = 1.125) + unroll_loops: If True, unroll sub-projection loops at ASM-gen time to + eliminate C_LOOP_START/END overhead. Overridden by the + ATEN_UNROLL env var ("1"=True, "0"=False). + """ + _env_unroll = os.environ.get("ATEN_UNROLL", "") + if _env_unroll == "1": + unroll_loops = True + elif _env_unroll == "0": + unroll_loops = False + super().__init__(mlen=mlen, blen=blen, real_data_ratio=real_data_ratio, unroll_loops=unroll_loops) + + # HBM address auto-allocation + self._next_hbm_addr: int = 0 + self._hbm_free_blocks: list[tuple[int, int]] = [] # (addr, size) + + # Variable registries + self._inputs: dict[str, InputVar] = {} + self._tensors: dict[str, TensorVar] = {} + self._fp_vars: dict[str, FPVar] = {} + self._functions: dict[str, Callable] = {} + self._registered_hbm_sub_matrices: dict[str, bool] = {} + self._registered_vram_sub_matrices: dict[str, bool] = {} + + self._result_tensor: TensorVar | None = None + + # Auto-generated name counter + self._auto_name_counter: int = 0 + + # Function scope namespace + # Push a prefix on each function call (e.g., "linear_0/"), pop on exit + # _auto_name will automatically add current scope prefix, avoiding name conflicts when calling the same function multiple times + self._scope_stack: list[str] = [] + self._function_call_counters: dict[str, int] = {} # func_name -> call count + + # ======================================================================== + # Property Access + # ======================================================================== + + # mlen / blen are instance attributes inherited from TileCompiler.__init__. + + @property + def compiler(self) -> PlenaCompiler: + """Legacy accessor — returns self now that PlenaCompiler is the compiler.""" + return self + + @property + def _compiler(self) -> _IsaCompilerView: + """Back-compat shim for legacy ``prog._compiler.X(...)`` call sites. + Returns a proxy that resolves callables against IsaCompiler + directly so callers reach the low-level API regardless of any + PlenaCompiler override with the same name.""" + return _IsaCompilerView(self) + + @property + def symbol_table(self): + """Access symbol table.""" + return self.get_symbol_table() + + # ======================================================================== + # Function Decorator + # ======================================================================== + + def function(self, func: Callable) -> Callable: + """ + Decorator: Define reusable functions. + + Each invocation generates fresh ISA code (eager evaluation). + Internally allocated tensors are auto-freed on exit unless returned. + + Scoping: intermediate tensors get a call-index prefix to avoid name + collisions across repeated calls (e.g., "linear_0/proj_1", "linear_1/proj_1"). + Nested functions compose prefixes: "two_layer_0/linear_0/proj_1". + """ + func_name = func.__name__ + + @wraps(func) + def wrapper(*args, **kwargs): + call_idx = self._function_call_counters.get(func_name, 0) + self._function_call_counters[func_name] = call_idx + 1 + + scope = f"{func_name}_{call_idx}/" + self._scope_stack.append(scope) + + self.emit_comment(f"=== Enter {func_name} (call #{call_idx}) ===") + + # Snapshot: record existing tensors before function execution + tensors_before = set(self._tensors.keys()) + inputs_before = set(self._inputs.keys()) + fp_vars_before = set(self._fp_vars.keys()) + + try: + result = func(*args, **kwargs) + + # Auto-free: free locally allocated tensors that are not returned + return_names = set() + return_fp_names = set() + if isinstance(result, TensorVar): + return_names.add(result.internal_name) + elif isinstance(result, FPVar): + return_fp_names.add(result.internal_name) + elif isinstance(result, (tuple, list)): + for r in result: + if isinstance(r, TensorVar): + return_names.add(r.internal_name) + elif isinstance(r, FPVar): + return_fp_names.add(r.internal_name) + + for name in set(self._tensors.keys()) - tensors_before: + if name not in return_names: + tensor = self._tensors[name] + if isinstance(tensor, VRAMMatrixVar): + self.free_tensor(tensor) + self._registered_vram_sub_matrices[tensor.name] = False + + for name in set(self._inputs.keys()) - inputs_before: + if name not in return_names: + self.free_input(self._inputs[name]) + + local_fp_names = sorted( + set(self._fp_vars.keys()) - fp_vars_before, + key=lambda n: self._fp_vars[n].address, + reverse=True, + ) + for name in local_fp_names: + if name in return_fp_names: + continue + fp_var = self._fp_vars.get(name) + if fp_var is not None: + self.free_fp_var(fp_var) + finally: + self._scope_stack.pop() + self.emit_comment(f"=== Exit {func_name} (call #{call_idx}) ===") + + return result + + self._functions[func_name] = wrapper + wrapper._plena_function = True + wrapper._plena_name = func_name + return wrapper + + # ======================================================================== + # Result Marking + # ======================================================================== + + def result(self, tensor_var: TensorVar): + """Mark output result tensor.""" + self._result_tensor = tensor_var + + # ======================================================================== + # Compilation + # ======================================================================== + + def compile(self) -> str: + """Get generated ISA code string.""" + return super().get_code() + + def print_symbol_table(self): + """Print symbol table""" + super().print_symbol_table() + + def get_symbol_table(self): + """Get symbol table""" + return super().get_symbol_table() + + # ======================================================================== + # Operator Dispatch (internal) + # ======================================================================== + + def _dispatch_matmul(self, left: TensorVar, right) -> TensorVar: + raise TypeError("@ operator is no longer supported in PlenaCompiler. Use explicit program APIs instead.") + + # ======================================================================== + # Utility Methods + # ======================================================================== + + def _scoped_name(self, name: str) -> str: + """ + Apply current scope prefix to a name. + + - Top-level alloc("temp"): -> "temp" + - Inside linear call 0, alloc("temp"): -> "linear_0/temp" + - Nested two_layer->linear, alloc("temp"): -> "two_layer_0/linear_0/temp" + """ + if not self._scope_stack: + return name + scope_prefix = "".join(self._scope_stack) + return f"{scope_prefix}{name}" + + def _allocate_hbm(self, hbm_size: int) -> int: + """Allocate HBM range, preferring previously freed blocks.""" + best_idx = None + best_waste = None + for i, (addr, size) in enumerate(self._hbm_free_blocks): + if size >= hbm_size: + waste = size - hbm_size + if best_waste is None or waste < best_waste: + best_idx = i + best_waste = waste + + if best_idx is not None: + addr, block_size = self._hbm_free_blocks.pop(best_idx) + # Return excess fragment to free list + excess = block_size - hbm_size + if excess > 0: + self._hbm_free_blocks.append((addr + hbm_size, excess)) + return addr + + addr = self._next_hbm_addr + m = self.mlen + self._next_hbm_addr = ((addr + hbm_size + m - 1) // m) * m + return addr + + def _recycle_hbm(self, hbm_addr: int, hbm_size: int): + """Recycle an HBM range for future auto-allocation.""" + if hbm_size <= 0: + return + self._hbm_free_blocks.append((hbm_addr, hbm_size)) + + def _auto_name(self, prefix: str = "t") -> str: + """ + Generate a unique scoped name. + + - Top-level: "__proj_1" + - linear call 0: "linear_0/__proj_1" + - nested: "two_layer_0/linear_0/__proj_1" + """ + self._auto_name_counter += 1 + scope_prefix = "".join(self._scope_stack) + return f"{scope_prefix}__{prefix}_{self._auto_name_counter}" + + def __repr__(self): + num_inputs = len(self._inputs) + num_tensors = len(self._tensors) + num_functions = len(self._functions) + code_len = len(super().get_code().splitlines()) + return ( + f"PlenaCompiler(mlen={self.mlen}, blen={self.blen}, " + f"inputs={num_inputs}, tensors={num_tensors}, " + f"functions={num_functions}, isa_lines={code_len})" + ) + + +__all__ = ["PlenaCompiler"] diff --git a/aten/plena/constants.py b/aten/plena/constants.py new file mode 100644 index 0000000..0c72ec5 --- /dev/null +++ b/aten/plena/constants.py @@ -0,0 +1,5 @@ +"""Shared constants for the ATen PLENA compiler path.""" + +MLEN = 64 # Minimum matrix block size +BLEN = 4 # Vector tile size +IMM2_BOUND = 2**18 diff --git a/aten/plena/dsl_attention.py b/aten/plena/dsl_attention.py new file mode 100644 index 0000000..b1ef4d7 --- /dev/null +++ b/aten/plena/dsl_attention.py @@ -0,0 +1,82 @@ +"""Flash-attention operations for the PLENA DSL.""" + +from __future__ import annotations + +from compiler.aten.plena.vars import InputVar, VRAMMatrixVar + + +class DslAttentionMixin: + # ======================================================================== + # Flash Attention Operations + # ======================================================================== + + def init_online_softmax(self, q_idx: int, o_matrix: VRAMMatrixVar): + """Initialize Online Softmax state: m=-inf, l=0, O_row=0""" + o_info = super().get_tensor_info(o_matrix.name) + seq_len, head_dim = o_info.shape + + super().init_online_softmax( + q_idx=q_idx, + o_matrix=o_matrix.name, + seq_len=seq_len, + head_dim=head_dim, + ) + + def online_softmax_block(self, s_block: VRAMMatrixVar, scale: float): + """Perform Online Softmax on S block""" + super().online_softmax_block( + s_block_matrix=s_block.name, + scale=scale, + ) + + def compute_pv( + self, + s_block: VRAMMatrixVar, + v_input: InputVar, + k_idx: int, + pv_matrix: VRAMMatrixVar, + head_dim: int, + ): + """Compute PV = P @ V[k_idx] where P is stored in s_block.""" + if not isinstance(s_block, VRAMMatrixVar): + raise TypeError(f"s_block must be VRAMMatrixVar, got {type(s_block)}") + if not isinstance(v_input, InputVar): + raise TypeError(f"v_input must be InputVar, got {type(v_input)}") + if not isinstance(pv_matrix, VRAMMatrixVar): + raise TypeError(f"pv_matrix must be VRAMMatrixVar, got {type(pv_matrix)}") + + self._ensure_hbm_sub_matrix_registered(v_input) + super().compute_pv( + s_block_matrix=s_block.name, + v_sub_matrix=v_input.name, + k_idx=k_idx, + pv_matrix=pv_matrix.name, + head_dim=head_dim, + ) + + def scale_o_row(self, o_matrix: VRAMMatrixVar, q_idx: int): + """Scale current row block of O by m_res""" + o_info = super().get_tensor_info(o_matrix.name) + seq_len, head_dim = o_info.shape + + super().scale_o_row( + o_matrix=o_matrix.name, + q_idx=q_idx, + seq_len=seq_len, + head_dim=head_dim, + ) + + def final_scale_o(self, q_idx: int, o_matrix: VRAMMatrixVar): + """Final scaling: O[q_idx] = O[q_idx] / l""" + o_info = super().get_tensor_info(o_matrix.name) + seq_len, head_dim = o_info.shape + + super().final_scale_o( + q_idx=q_idx, + o_matrix=o_matrix.name, + seq_len=seq_len, + head_dim=head_dim, + ) + + +__all__ = ["DslAttentionMixin"] diff --git a/aten/plena/dsl_fp_tile_ops.py b/aten/plena/dsl_fp_tile_ops.py new file mode 100644 index 0000000..0c40411 --- /dev/null +++ b/aten/plena/dsl_fp_tile_ops.py @@ -0,0 +1,568 @@ +"""FPRAM, FPVar, and tile-row operations for the PLENA DSL.""" + +from __future__ import annotations + +from compiler.aten.plena.vars import FPVar, VRAMMatrixVar + + +class DslFPTileOpsMixin: + # ======================================================================== + # FP Variable (FPRAM) + # ======================================================================== + + def allocate_fpram( + self, + internal_name: str, + size: int = 1, + display_name: str | None = None, + ) -> FPVar: + """ + Allocate FPRAM with explicit internal name and return FPVar proxy. + """ + if size <= 0: + raise ValueError(f"FPRAM allocation size must be positive, got {size}") + + address = super().allocate_fpram(internal_name, size) + var = FPVar( + self, + internal_name, + address, + size, + display_name=display_name if display_name is not None else internal_name, + ) + self._fp_vars[internal_name] = var + return var + + def free_fpram(self, internal_name: str, strict: bool = True): + """ + Free FPRAM allocation by internal name. + """ + super().free_fpram(internal_name, strict=strict) + self._fp_vars.pop(internal_name, None) + + def fp_var(self, name: str, size: int = 1) -> FPVar: + """ + Declare an FP variable in FPRAM. + + Allocates a contiguous region in FPRAM and returns an FPVar proxy. + Within function scope, names are automatically prefixed. + + Args: + name: variable name + size: number of f16 elements to allocate (default 1) + + Returns: + FPVar proxy object (use .address for ISA generation) + + Example: + scale = prog.fp_var("scale") # 1 element + m_old = prog.fp_var("m_old", size=64) # 64 elements + prog.compiler # access compiler for ISA if needed + """ + display_name = name + internal_name = self._scoped_name(name) + + return self.allocate_fpram( + internal_name=internal_name, + size=size, + display_name=display_name, + ) + + def save_fpram_state(self) -> int: + """Save FPRAM allocator snapshot""" + return super().save_fpram_state() + + def restore_fpram_state(self, snapshot: int): + """Restore FPRAM allocator snapshot""" + super().restore_fpram_state(snapshot) + # Remove FPVar proxies that are no longer allocated in allocator. + allocations = set(super().list_fpram_allocations()) + to_remove = [n for n in self._fp_vars if n not in allocations] + for n in to_remove: + del self._fp_vars[n] + + # ======================================================================== + # FPRAM Tile Operations + # ======================================================================== + + def _resolve_fpram_addr(self, addr_or_var: int | FPVar, offset: int = 0) -> int: + if isinstance(addr_or_var, FPVar): + if offset < 0 or offset >= addr_or_var.size: + raise ValueError( + f"FPVar offset out of range: offset={offset}, size={addr_or_var.size}, var={addr_or_var.name}" + ) + return addr_or_var.address + offset + if not isinstance(addr_or_var, int): + raise TypeError(f"Expected int or FPVar, got {type(addr_or_var)}") + return addr_or_var + offset + + def _resolve_rows( + self, + row_idx: int | None = None, + rows: list[int] | None = None, + ) -> list[int]: + if row_idx is not None and rows is not None: + raise ValueError("Provide either row_idx or rows, not both") + if rows is not None: + return rows + if row_idx is not None: + return [row_idx] + return list(range(self.mlen)) + + def tile_row_max( + self, + target_fpram_addr: int | FPVar, + source: VRAMMatrixVar, + row_idx: int | None = None, + rows: list[int] | None = None, + target_offset: int = 0, + target_base_offset: int = 0, + ): + """ + Tile Row Max: reduce a single row to max, store to FPRAM address. + + Args: + target_fpram_addr: FPRAM address or FPVar to write result + source: VRAM tile (mlen x mlen) + row_idx: single row index (legacy path) + rows: multiple row indices + target_offset: element offset when target_fpram_addr is FPVar + target_base_offset: base offset for multi-row writes (contiguous) + + Example: + m = prog.fp_var("m", size=1) + S = prog.alloc("S", 64, 64) + for row in range(64): + prog.tile_row_max(m, S, rows=list(range(64))) + """ + resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) + if len(resolved_rows) == 1: + offsets = [target_offset] + else: + offsets = [target_base_offset + i for i in range(len(resolved_rows))] + row_map = [(r, self._resolve_fpram_addr(target_fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] + super().tile_row_max( + source_matrix=source.name, + row_map=row_map, + ) + + def tile_row_sum( + self, + target_fpram_addr: int | FPVar, + source: VRAMMatrixVar, + row_idx: int | None = None, + rows: list[int] | None = None, + target_offset: int = 0, + target_base_offset: int = 0, + ): + """ + Tile Row Sum: reduce a single row to sum, store to FPRAM address. + + Args: + target_fpram_addr: FPRAM address or FPVar to write result + source: VRAM tile (mlen x mlen) + row_idx: single row index (legacy path) + rows: multiple row indices + target_offset: element offset when target_fpram_addr is FPVar + target_base_offset: base offset for multi-row writes (contiguous) + """ + resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) + if len(resolved_rows) == 1: + offsets = [target_offset] + else: + offsets = [target_base_offset + i for i in range(len(resolved_rows))] + row_map = [(r, self._resolve_fpram_addr(target_fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] + super().tile_row_sum(source.name, row_map) + + def tile_row_exp( + self, + source: VRAMMatrixVar, + row_idx: int | None = None, + rows: list[int] | None = None, + ): + """ + Tile Row Exp: in-place exp on specified rows. + + For each row i: source[i, :] = exp(source[i, :]) + """ + resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) + super().tile_row_exp(source.name, resolved_rows) + + def tile_row_reci( + self, + source: VRAMMatrixVar, + rows: list[int] | None = None, + ): + """ + Tile Row Reciprocal: in-place 1/x on specified rows. + + For each row i: source[i, :] = 1.0 / source[i, :] + """ + if rows is None: + rows = list(range(self.mlen)) + super().tile_row_reci(source.name, rows) + + def tile_row_sub_fp( + self, + source: VRAMMatrixVar, + fpram_addr: int | FPVar, + row_idx: int | None = None, + rows: list[int] | None = None, + fpram_offset: int = 0, + fpram_base_offset: int = 0, + ): + """ + Tile Row Sub FP: subtract FPRAM scalar from a single row. + + For row i: source[i, :] = source[i, :] - FPRAM[fpram_addr] + """ + resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) + if len(resolved_rows) == 1: + offsets = [fpram_offset] + else: + offsets = [fpram_base_offset + i for i in range(len(resolved_rows))] + row_map = [(r, self._resolve_fpram_addr(fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] + super().tile_row_sub_fp(source.name, row_map) + + def tile_row_mul_fp( + self, + source: VRAMMatrixVar, + fpram_addr: int | FPVar, + row_idx: int | None = None, + rows: list[int] | None = None, + fpram_offset: int = 0, + fpram_base_offset: int = 0, + ): + """ + Tile Row Mul FP: multiply a single row by FPRAM scalar. + + For row i: source[i, :] = source[i, :] * FPRAM[fpram_addr] + """ + resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) + if len(resolved_rows) == 1: + offsets = [fpram_offset] + else: + offsets = [fpram_base_offset + i for i in range(len(resolved_rows))] + row_map = [(r, self._resolve_fpram_addr(fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] + super().tile_row_mul_fp(source.name, row_map) + + def tile_row_add_fp( + self, + source: VRAMMatrixVar, + fp_var: FPVar, + rows: list[int] | None = None, + ): + """ + Tile Row Add FP: add FPRAM scalar to each specified row. + + For each row i: source[i, :] = source[i, :] + fp_var[i] + """ + if rows is None: + rows = list(range(self.mlen)) + row_map = [(r, fp_var[r]) for r in rows] + super().tile_row_add_fp(source.name, row_map) + + def tile_row_add( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + rows: list[int] | None = None, + ): + """ + Tile Row Add: dst[i, :] += src[i, :] for specified rows. + """ + if rows is None: + rows = list(range(self.mlen)) + super().tile_row_add(dst.name, src.name, rows) + + def tile_row_sub( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + rows: list[int] | None = None, + ): + """ + Tile Row Sub: dst[i, :] -= src[i, :] for specified rows. + """ + if rows is None: + rows = list(range(self.mlen)) + super().tile_row_sub(dst.name, src.name, rows) + + def tile_row_mul( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + rows: list[int] | None = None, + ): + """ + Tile Row Mul: dst[i, :] *= src[i, :] for specified rows. + """ + if rows is None: + rows = list(range(self.mlen)) + super().tile_row_mul(dst.name, src.name, rows) + + def fpvar_reci( + self, + src: FPVar, + dst: FPVar, + count: int | None = None, + ): + """ + FPVar Reciprocal: compute 1/x for FPRAM scalar array. + + For each element i: dst[i] = 1.0 / src[i] + + Args: + src: source FPVar + dst: destination FPVar + count: number of elements (default: min(src.size, dst.size)) + + Example: + l = prog.fp_var("l", size=64) + inv_l = prog.fp_var("inv_l", size=64) + prog.fpvar_reci(l, inv_l) # inv_l = 1/l + """ + if count is None: + count = min(src.size, dst.size) + if count > src.size or count > dst.size: + raise ValueError(f"count={count} exceeds FPVar size: src.size={src.size}, dst.size={dst.size}") + super().fpram_reci(src.name, dst.name, count) + + def fpvar_max( + self, + src1: FPVar, + src2: FPVar, + dst: FPVar, + count: int | None = None, + ): + """ + FPVar Max: element-wise max for FPRAM scalar arrays. + + For each element i: dst[i] = max(src1[i], src2[i]) + + Example: + m_new = prog.fp_var("m_new", size=64) + prog.fpvar_max(m_old, row_max, m_new) # m_new = max(m_old, row_max) + """ + if count is None: + count = min(src1.size, src2.size, dst.size) + super().fpram_max(src1.name, src2.name, dst.name, count) + + def fpvar_sub( + self, + src1: FPVar, + src2: FPVar, + dst: FPVar, + count: int | None = None, + ): + """ + FPVar Subtract: element-wise subtraction for FPRAM scalar arrays. + + For each element i: dst[i] = src1[i] - src2[i] + + Example: + diff = prog.fp_var("diff", size=64) + prog.fpvar_sub(m_old, m_new, diff) # diff = m_old - m_new + """ + if count is None: + count = min(src1.size, src2.size, dst.size) + super().fpram_sub(src1.name, src2.name, dst.name, count) + + def fpvar_exp( + self, + src: FPVar, + dst: FPVar, + count: int | None = None, + ): + """ + FPVar Exp: element-wise exp for FPRAM scalar array. + + For each element i: dst[i] = exp(src[i]) + + Example: + m_res = prog.fp_var("m_res", size=64) + prog.fpvar_exp(diff, m_res) # m_res = exp(diff) + """ + if count is None: + count = min(src.size, dst.size) + super().fpram_exp(src.name, dst.name, count) + + def fpvar_mul( + self, + src1: FPVar, + src2: FPVar, + dst: FPVar, + count: int | None = None, + ): + """ + FPVar Multiply: element-wise multiplication for FPRAM scalar arrays. + + For each element i: dst[i] = src1[i] * src2[i] + + Example: + result = prog.fp_var("result", size=64) + prog.fpvar_mul(l_old, m_res, result) # result = l_old * m_res + """ + if count is None: + count = min(src1.size, src2.size, dst.size) + super().fpram_mul(src1.name, src2.name, dst.name, count) + + def fpvar_add( + self, + src1: FPVar, + src2: FPVar, + dst: FPVar, + count: int | None = None, + ): + """ + FPVar Add: element-wise addition for FPRAM scalar arrays. + + For each element i: dst[i] = src1[i] + src2[i] + + Example: + l_new = prog.fp_var("l_new", size=64) + prog.fpvar_add(l_old, sum_p, l_new) # l_new = l_old + sum_p + """ + if count is None: + count = min(src1.size, src2.size, dst.size) + super().fpram_add(src1.name, src2.name, dst.name, count) + + def fpvar_copy( + self, + src: FPVar, + dst: FPVar, + count: int | None = None, + ): + """ + FPVar Copy: copy FPRAM scalar array. + + For each element i: dst[i] = src[i] + + Example: + m_old_saved = prog.fp_var("m_old_saved", size=64) + prog.fpvar_copy(m_old, m_old_saved) # backup m_old + """ + if count is None: + count = min(src.size, dst.size) + super().fpram_copy(src.name, dst.name, count) + + def fpvar_sum( + self, + src: FPVar, + dst: FPVar, + count: int | None = None, + ): + """ + FPVar Sum: reduction sum of src into dst[0] (via compiler FPRAM op). + """ + if count is None: + count = src.size + super().fpram_sum(src.name, dst.name, count) + + def fpvar_shift( + self, + src: FPVar, + dst: FPVar, + shift: int, + count: int | None = None, + fill: FPVar | None = None, + ): + """ + FPVar Shift: shift src into dst, filling out-of-range slots with fill (default FPRAM zero). + """ + if count is None: + count = min(src.size, dst.size) + fill_name = None if fill is None else fill.name + super().fpram_shift( + src_name=src.name, + dst_name=dst.name, + shift=shift, + count=count, + fill_fpram_name=fill_name, + ) + + def tile_row_mul_fp_broadcast( + self, + source: VRAMMatrixVar, + fpram_scalar_addr: int | FPVar, + row_idx: int | None = None, + rows: list[int] | None = None, + fpram_offset: int = 0, + ): + """ + Tile Row Mul FP Broadcast: multiply a single row by a FPRAM scalar. + + For row i: source[i, :] = source[i, :] * FPRAM[fpram_scalar_addr] + + Args: + source: VRAM tile (mlen x mlen) + fpram_scalar_addr: FPRAM address or FPVar of the scalar + row_idx: single row index (legacy path) + rows: multiple row indices + + Example: + scale_fp = prog.fp_var("scale", size=1) + for row in range(64): + prog.tile_row_mul_fp_broadcast(S, scale_fp, rows=list(range(64))) + """ + resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) + scalar_addr = self._resolve_fpram_addr(fpram_scalar_addr, fpram_offset) + super().tile_row_mul_fp_broadcast(source.name, scalar_addr, resolved_rows) + + def fpvar_fill_from_fpram( + self, + dst: FPVar, + src_fpram_addr: int, + count: int | None = None, + ): + """ + FPVar Fill from FPRAM: fill all elements with a value from FPRAM. + + For each element i: dst[i] = FPRAM[src_fpram_addr] + + Args: + dst: destination FPVar + src_fpram_addr: source FPRAM address (e.g., 0 for 0.0, 2 for -inf) + count: number of elements (default: dst.size) + + Example: + m_old = prog.fp_var("m_old", size=64) + prog.fpvar_fill_from_fpram(m_old, 2) # fill with -inf from address 2 + """ + if count is None: + count = dst.size + super().fpram_fill_from_fpram(dst.name, src_fpram_addr, count) + + def vram_fill_zero( + self, + matrix: VRAMMatrixVar, + rows: list[int] | None = None, + ): + """ + VRAM Fill Zero: fill specified rows with 0. + + Args: + matrix: VRAM matrix + rows: which rows to fill (default: all rows) + + Example: + O = prog.alloc("O", 128, 128) + prog.vram_fill_zero(O, rows=range(64, 128)) # zero out second half + """ + if rows is None: + rows = list(range(matrix.shape[0])) + else: + rows = list(rows) + + total_rows, cols = matrix.shape + if any(row < 0 or row >= total_rows for row in rows): + raise ValueError(f"vram_fill_zero rows out of bounds for {matrix.name}: shape={matrix.shape}, rows={rows}") + + # VRAM matrices are column-block-major. The low-level tile helper zeros + # one 64-column tile, so walk every column block for wide matrices. + num_col_blocks = (cols + self.mlen - 1) // self.mlen + for col_block in range(num_col_blocks): + super().vram_fill_zero(matrix.name, rows, tile_col_idx=col_block) + + +__all__ = ["DslFPTileOpsMixin"] diff --git a/aten/plena/dsl_matrix_ops.py b/aten/plena/dsl_matrix_ops.py new file mode 100644 index 0000000..7cf9ae6 --- /dev/null +++ b/aten/plena/dsl_matrix_ops.py @@ -0,0 +1,209 @@ +"""Matrix projection, RoPE, and VRAM operations for the PLENA DSL.""" + +from __future__ import annotations + +from compiler.aten.plena.vars import InputVar, TensorVar, VRAMMatrixVar + + +class DslMatrixOpsMixin: + # ======================================================================== + # Matrix Projection and VRAM Operations + # ======================================================================== + + def _ensure_hbm_sub_matrix_registered(self, input_var: InputVar): + """Ensure an HBM input is registered in compiler sub-matrix manager.""" + if ( + input_var.name in self._registered_hbm_sub_matrices + and self._registered_hbm_sub_matrices[input_var.name] is True + ): + return + h, w = input_var.shape + super().ensure_hbm_sub_matrix( + name=input_var.name, + hbm_addr=input_var.hbm_addr, + shape=(h, w), + real_data_ratio=self.real_data_ratio, + ) + self._registered_hbm_sub_matrices[input_var.name] = True + + def _ensure_vram_sub_matrix_registered(self, matrix_var: VRAMMatrixVar): + """Ensure a VRAM matrix is registered in compiler sub-matrix manager.""" + if ( + matrix_var.name in self._registered_vram_sub_matrices + and self._registered_vram_sub_matrices[matrix_var.name] is True + ): + return + super().ensure_vram_matrix_layout( + name=matrix_var.name, + shape=matrix_var.shape, + ) + self._registered_vram_sub_matrices[matrix_var.name] = True + + def vram_sub_projection_to( + self, + vram_matrix: VRAMMatrixVar, + vram_row_idx: int, + mram_input: InputVar, + mram_col_idx: int, + target: VRAMMatrixVar, + target_row_idx: int, + target_col_idx: int, + auto_reset_mram: bool = True, + k_block_start: int = 0, + k_block_count: int | None = None, + ): + """ + target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[:][mram_col_idx] + Supports K-split: k_block_start/k_block_count select a subset of K tiles. + """ + if not isinstance(vram_matrix, VRAMMatrixVar): + raise TypeError(f"vram_matrix must be VRAMMatrixVar, got {type(vram_matrix)}") + if not isinstance(mram_input, InputVar): + raise TypeError(f"mram_input must be InputVar, got {type(mram_input)}") + if not isinstance(target, VRAMMatrixVar): + raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") + + self._ensure_vram_sub_matrix_registered(vram_matrix) + self._ensure_hbm_sub_matrix_registered(mram_input) + if auto_reset_mram: + super().reset_mram() + super().load_sub_matrix_col( + name=mram_input.name, + col_idx=mram_col_idx, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + super().vram_sub_projection_to( + vram_mat_name=vram_matrix.name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_input.name, + mram_col_idx=mram_col_idx, + target_matrix=target.name, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + + def vram_sub_projection_T_to( + self, + vram_matrix: VRAMMatrixVar, + vram_row_idx: int, + mram_input: InputVar, + mram_row_idx: int, + target: VRAMMatrixVar, + target_row_idx: int, + target_col_idx: int, + auto_reset_mram: bool = True, + ): + """ + target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[mram_row_idx][:]^T + """ + if not isinstance(vram_matrix, VRAMMatrixVar): + raise TypeError(f"vram_matrix must be VRAMMatrixVar, got {type(vram_matrix)}") + if not isinstance(mram_input, InputVar): + raise TypeError(f"mram_input must be InputVar, got {type(mram_input)}") + if not isinstance(target, VRAMMatrixVar): + raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") + + self._ensure_vram_sub_matrix_registered(vram_matrix) + self._ensure_hbm_sub_matrix_registered(mram_input) + if auto_reset_mram: + super().reset_mram() + super().load_sub_matrix_row(name=mram_input.name, row_idx=mram_row_idx) + super().vram_sub_projection_T_to( + vram_mat_name=vram_matrix.name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_input.name, + mram_row_idx=mram_row_idx, + target_matrix=target.name, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + ) + + # ======================================================================== + # RoPE (1D Positional Encoding) + # ======================================================================== + + def rope( + self, + x_var: VRAMMatrixVar, + x_rot_var: VRAMMatrixVar, + cos_var: VRAMMatrixVar, + sin_var: VRAMMatrixVar, + ) -> VRAMMatrixVar: + """Apply Rotary Position Embedding in-place: x = x * cos + rotate_half(x) * sin + + x_rot_var must already be in VRAM as rotate_half(x), preloaded by caller. + Returns x_var (modified in-place). + """ + super().rope( + x_name=x_var.name, + x_rot_name=x_rot_var.name, + cos_name=cos_var.name, + sin_name=sin_var.name, + ) + return x_var + + # ======================================================================== + # VRAM Matrix Addition + # ======================================================================== + + def vram_add( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + dst_row_offset: int = 0, + src_row_offset: int = 0, + num_rows: int | None = None, + ): + """VRAM matrix add: dst[row_offset:] += src""" + super().vram_matrix_add( + dst_matrix=dst.name, + src_matrix=src.name, + dst_row_offset=dst_row_offset, + src_row_offset=src_row_offset, + num_rows=num_rows, + ) + + def vram_block_add_to( + self, + src1: TensorVar, + src1_row_idx: int, + src1_col_idx: int, + src2: TensorVar, + src2_row_idx: int, + src2_col_idx: int, + target: TensorVar, + target_row_idx: int, + target_col_idx: int, + ): + """ + mlen x mlen block add: + target[target_row_idx][target_col_idx] = + src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] + + Supports writing back to the same matrix/block (in-place overwrite). + """ + allowed = (VRAMMatrixVar,) + if not isinstance(src1, allowed): + raise TypeError(f"src1 must be VRAMMatrixVar, got {type(src1)}") + if not isinstance(src2, allowed): + raise TypeError(f"src2 must be VRAMMatrixVar, got {type(src2)}") + if not isinstance(target, allowed): + raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") + + super().vram_block_add_to( + src1_matrix=src1.name, + src1_row_idx=src1_row_idx, + src1_col_idx=src1_col_idx, + src2_matrix=src2.name, + src2_row_idx=src2_row_idx, + src2_col_idx=src2_col_idx, + target_matrix=target.name, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + ) + + +__all__ = ["DslMatrixOpsMixin"] diff --git a/aten/plena/dsl_tensors.py b/aten/plena/dsl_tensors.py new file mode 100644 index 0000000..9f480dc --- /dev/null +++ b/aten/plena/dsl_tensors.py @@ -0,0 +1,319 @@ +"""Tensor, memory, and normalization operations for the PLENA DSL.""" + +from __future__ import annotations + +from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar + + +class DslTensorMixin: + # ======================================================================== + # Input Declaration + # ======================================================================== + + def input( + self, + name: str, + shape: tuple[int, int], + hbm_addr: int | None = None, + prestaged_vram_addr: int | None = None, + ) -> InputVar: + """ + Declare an input tensor (in HBM). + + Args: + name: tensor name + shape: (height, width) + hbm_addr: HBM address (None = auto-allocate) + prestaged_vram_addr: If an int, the tensor is assumed to be already + present in VRAM at this byte address. A subsequent call to + ``load_batch`` will register it at that address without emitting + any HBM→VRAM prefetch instructions. If None (default), the + normal HBM→VRAM load path is used. + + Returns: + InputVar proxy object + """ + h, w = shape + size = h * w + hbm_size = int(size * self.real_data_ratio) + + if hbm_addr is None: + hbm_addr = self._allocate_hbm(hbm_size) + + var = InputVar(self, name, shape, hbm_addr, hbm_size, prestaged_vram_addr=prestaged_vram_addr) + self._inputs[name] = var + super().add_hbm_object( + name=name, + hbm_addr=hbm_addr, + shape=shape, + real_data_ratio=self.real_data_ratio, + ) + return var + + # ======================================================================== + # Load Operations + # ======================================================================== + + def load_batch( + self, + input_var: InputVar, + name: str | None = None, + ) -> VRAMMatrixVar: + """ + Load tensor from HBM to VRAM (Batch type). + + When ``input_var.prestaged_vram_addr`` is set the tensor is assumed to + be already resident in VRAM at that address. No HBM→VRAM prefetch + instructions are emitted; the tensor is simply registered in the symbol + table at the given address. + + Args: + input_var: source InputVar + name: result name (None = use input name) + + Returns: + VRAMMatrixVar proxy object + """ + if not isinstance(input_var, InputVar): + raise TypeError(f"Expected InputVar, got {type(input_var)}") + + display_name = name if name is not None else input_var.display_name + internal_name = self._scoped_name(display_name) + + if input_var.prestaged_vram_addr is not None: + # Prestaged path: tensor is already in VRAM — register without ISA. + h, w = input_var.shape + vram_addr = input_var.prestaged_vram_addr + # Tell the VRAM allocator that this region is occupied so subsequent + # allocations don't collide with it. + self.vram_allocator._vmm.mark_used(vram_addr, h * w, name=internal_name) + super().add_vram_object( + name=internal_name, + shape=(h, w), + vram_addr=vram_addr, + dtype="fp16", + kind="Batch", + allocate_if_none=False, + strict=False, + ) + else: + # Normal path: emit HBM → VRAM prefetch ISA. + super().load_batch( + hbm_object_name=input_var.name, vram_object_name=internal_name, vlen=self.mlen, preload_len=4 + ) + + var = VRAMMatrixVar(self, internal_name, input_var.shape, display_name=display_name) + self._tensors[internal_name] = var + return var + + # ======================================================================== + # Store Operations + # ======================================================================== + + def store(self, tensor_var, name: str | None = None, hbm_addr: int | None = None) -> InputVar: + """ + Write tensor from VRAM back to HBM. + + Returns: + InputVar proxy object (can be loaded back later) + """ + if not isinstance(tensor_var, VRAMMatrixVar): + raise TypeError(f"Store requires VRAMMatrixVar, got {type(tensor_var)}") + + display_name = name if name is not None else f"{tensor_var.display_name}_stored" + internal_name = self._scoped_name(display_name) + + if hbm_addr is None: + h, w = tensor_var.shape + size = h * w + hbm_size = int(size * self.real_data_ratio) + hbm_addr = self._allocate_hbm(hbm_size) + else: + h, w = tensor_var.shape + hbm_size = int(h * w * self.real_data_ratio) + + super().store_to_hbm( + tensor_name=tensor_var.name, # internal name for symbol table lookup + hbm_addr=hbm_addr, + hbm_object_name=internal_name, + vlen=self.mlen, + ) + + var = InputVar(self, internal_name, tensor_var.shape, hbm_addr, hbm_size, display_name=display_name) + self._inputs[internal_name] = var + return var + + # ======================================================================== + # VRAM Matrix Allocation + # ======================================================================== + + def alloc(self, name: str, rows: int, cols: int, strict: bool = True) -> VRAMMatrixVar: + """ + Allocate a VRAM matrix. + + Used to store intermediate results (e.g., S block, PV, O). + Within function scope, names are automatically prefixed to avoid conflicts. + + Args: + name: matrix name (user-visible) + rows: number of rows + cols: number of columns + strict: if False, skip mlen-alignment checks (for small scratch matrices) + + Returns: + VRAMMatrixVar proxy object + """ + display_name = name + internal_name = self._scoped_name(name) + super().allocate_vram_matrix(name=internal_name, rows=rows, cols=cols, strict=strict) + + var = VRAMMatrixVar(self, internal_name, (rows, cols), display_name=display_name) + self._tensors[internal_name] = var + return var + + def alloc_at(self, name: str, rows: int, cols: int, vram_addr: int) -> VRAMMatrixVar: + """Allocate a VRAM matrix view at a specific address. + + Used to create views into existing VRAM matrices (e.g., per-head + slices of a multi-head Q projection output). Does NOT bump the + VRAM allocator -- the caller is responsible for ensuring the region + is valid. + + Args: + name: matrix name (user-visible) + rows: number of rows + cols: number of columns + vram_addr: absolute VRAM address for this view + + Returns: + VRAMMatrixVar proxy object + """ + display_name = name + internal_name = self._scoped_name(name) + self._compiler.add_vram_object( + name=internal_name, + shape=(rows, cols), + vram_addr=vram_addr, + allocate_if_none=False, + ) + isa_code = f"; VRAM View {name}: ({rows}, {cols}) at VRAM[{vram_addr}]\n" + self._compiler.emit(isa_code) + var = VRAMMatrixVar(self, internal_name, (rows, cols), display_name=display_name) + self._tensors[internal_name] = var + return var + + def free_tensor(self, tensor_var: TensorVar): + """ + Free a tensor in VRAM, reclaiming space for subsequent allocations. + + Freed space can be reused by new alloc() or other operations. + """ + if not isinstance(tensor_var, VRAMMatrixVar): + raise TypeError(f"Can only free VRAMMatrixVar, got {type(tensor_var)}") + + super().free_vram_object(tensor_var.name, strict=False) + # Keep sub-matrix registration state consistent after free. + self._registered_vram_sub_matrices[tensor_var.name] = False + + def free_input(self, input_var: InputVar): + """ + Free an InputVar bookkeeping and recycle its HBM range for future auto-allocation. + + Notes: + - This only affects PlenaCompiler's address management state. + - If a freed input is referenced again later, caller is responsible for correctness. + """ + if not isinstance(input_var, InputVar): + raise TypeError(f"Can only free InputVar, got {type(input_var)}") + + super().free_hbm_object(input_var.name, strict=False) + self._registered_hbm_sub_matrices[input_var.name] = False + self._recycle_hbm(input_var.hbm_addr, input_var.hbm_size) + self._inputs.pop(input_var.name, None) + + def free_fp_var(self, fp_var: FPVar): + """ + Free an FPVar and return its block to FPRAM free pool. + """ + if not isinstance(fp_var, FPVar): + raise TypeError(f"Can only free FPVar, got {type(fp_var)}") + self.free_fpram(fp_var.name, strict=True) + + # ======================================================================== + # Normalization Operations + # ======================================================================== + + def norm( + self, + tensor_var: TensorVar, + mode: str = "rms", + eps_offset: int = 1, + reci_hid_offset: int = 2, + vlen: int | None = None, + scratchpad_vram_addr: int | None = None, + ) -> TensorVar: + """ + Normalize tensor in-place. + + Args: + tensor_var: tensor to normalize (must have VRAM backing, e.g., VRAMMatrixVar) + mode: "rms" or "layer" + eps_offset: FPRAM address of epsilon + reci_hid_offset: FPRAM address of 1/hidden_dim + vlen: vector length (default: program mlen) + scratchpad_vram_addr: optional scratchpad VRAM address + + Returns: + The same tensor_var (in-place operation) + """ + if not isinstance(tensor_var, VRAMMatrixVar): + raise TypeError(f"norm requires VRAMMatrixVar, got {type(tensor_var)}") + + super().normalize( + tensor_name=tensor_var.name, + mode=mode, + eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + vlen=vlen, + scratchpad_vram_addr=scratchpad_vram_addr, + ) + return tensor_var + + def rms_norm( + self, + tensor_var: TensorVar, + eps_offset: int = 1, + reci_hid_offset: int = 2, + vlen: int | None = None, + scratchpad_vram_addr: int | None = None, + ) -> TensorVar: + """RMS normalization (in-place).""" + return self.norm( + tensor_var=tensor_var, + mode="rms", + eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + vlen=vlen, + scratchpad_vram_addr=scratchpad_vram_addr, + ) + + def layer_norm( + self, + tensor_var: TensorVar, + eps_offset: int = 1, + reci_hid_offset: int = 2, + vlen: int | None = None, + scratchpad_vram_addr: int | None = None, + ) -> TensorVar: + """Layer normalization (in-place).""" + return self.norm( + tensor_var=tensor_var, + mode="layer", + eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + vlen=vlen, + scratchpad_vram_addr=scratchpad_vram_addr, + ) + + +__all__ = ["DslTensorMixin"] diff --git a/aten/plena/isa_attention.py b/aten/plena/isa_attention.py new file mode 100644 index 0000000..455a639 --- /dev/null +++ b/aten/plena/isa_attention.py @@ -0,0 +1,531 @@ +"""Flash-attention ISA helpers for IsaCompiler.""" + +from __future__ import annotations + + +class IsaAttentionMixin: + # ========================================================================= + # Flash Attention Implementation + # ========================================================================= + + def _online_softmax_asm( + self, + mlen: int, + s_address: int, + m_start_address: int, + scale: float = 1.0, + ) -> str: + """ + Online Softmax Computation. + + Per row of S: + 1. m_curr = max(S[row], m_old) + 2. m_res = exp(m_old - m_curr) # used to update O downstream + 3. S'[row] = S[row] - m_curr + 4. P[row] = exp(S'[row]) + 5. l_new = l_old * m_res + sum(P[row]) + + FP SRAM layout (from m_start_address): + [0, mlen): m_old / m_curr + [mlen, 2*mlen): m_res = exp(m_old - m_curr) + [2*mlen, 3*mlen): l_old / l_new + """ + gp_regs = self.register_allocator.allocate_gp(4) + gp_s = gp_regs[0] + gp_m_addr = gp_regs[1] + gp_m_res_addr = gp_regs[2] + gp_l_addr = gp_regs[3] + + # Fixed FP register allocation for online softmax pipeline. + # These registers are shared across _online_softmax_asm, _scale_o_asm, + # and _final_scaling_asm — they MUST remain consistent across all three. + # WARNING: Do not use f1-f6 in any code that calls these methods. + fp_m_old = 1 # f1: m_old value + fp_m_res = 2 # f2: exp(m_old - m_curr) + fp_l_old = 3 # f3: l_old value + fp_sum_p = 4 # f4: sum(P) + fp_scale = 5 # f5: scale factor + fp_row_max = 6 # f6: current row max (temporary) + + lines = [] + lines.append("; === Online Softmax ===") + + # Set address registers + lines.append(f"S_ADDI_INT gp{gp_s}, gp0, {s_address}") + lines.append(f"S_ADDI_INT gp{gp_m_addr}, gp0, {m_start_address}") + lines.append(f"S_ADDI_INT gp{gp_m_res_addr}, gp{gp_m_addr}, {mlen}") + lines.append(f"S_ADDI_INT gp{gp_l_addr}, gp{gp_m_res_addr}, {mlen}") + + # scale factor is pre-loaded at FP SRAM addr 1 by the flash-attention driver. + if scale != 1.0: + lines.append(f"S_LD_FP f{fp_scale}, gp0, 1") + + for row in range(mlen): + lines.append(f"; Row {row}") + + lines.append(f"S_LD_FP f{fp_m_old}, gp{gp_m_addr}, {row}") + lines.append(f"S_ADD_FP f{fp_m_res}, f{fp_m_old}, f0") + + if scale != 1.0: + lines.append(f"V_MUL_VF gp{gp_s}, gp{gp_s}, f{fp_scale}, 0") + + lines.append(f"V_RED_MAX f{fp_row_max}, gp{gp_s}, 0") + + # m_curr = max(row_max, m_old) — online softmax must retain the running max. + lines.append(f"S_MAX_FP f{fp_m_old}, f{fp_row_max}, f{fp_m_old}") + + lines.append(f"S_SUB_FP f{fp_m_res}, f{fp_m_res}, f{fp_m_old}") + lines.append(f"S_EXP_FP f{fp_m_res}, f{fp_m_res}, 0") + + lines.append(f"S_ST_FP f{fp_m_res}, gp{gp_m_res_addr}, {row}") + lines.append(f"S_ST_FP f{fp_m_old}, gp{gp_m_addr}, {row}") + + lines.append(f"V_SUB_VF gp{gp_s}, gp{gp_s}, f{fp_m_old}, 0, 0") + lines.append(f"V_EXP_V gp{gp_s}, gp{gp_s}, 0, 0") + + lines.append(f"S_LD_FP f{fp_l_old}, gp{gp_l_addr}, {row}") + + lines.append(f"S_ADD_FP f{fp_sum_p}, f0, f0") + lines.append(f"V_RED_SUM f{fp_sum_p}, gp{gp_s}, 0, 0") + + lines.append(f"S_MUL_FP f{fp_l_old}, f{fp_l_old}, f{fp_m_res}") + lines.append(f"S_ADD_FP f{fp_l_old}, f{fp_l_old}, f{fp_sum_p}") + + lines.append(f"S_ST_FP f{fp_l_old}, gp{gp_l_addr}, {row}") + + lines.append(f"S_ADDI_INT gp{gp_s}, gp{gp_s}, {mlen}") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _pv_multiply_asm( + self, + mlen: int, + blen: int, + head_dim: int, + p_address: int, + v_hbm_offset_reg: int, + v_hbm_offset: int, + pv_address: int, + ) -> str: + """ + Compute PV = P @ V via M_MM. + + P: (mlen, mlen) in VRAM (softmax output) + V: (mlen, head_dim) in HBM (prefetched into MSRAM in mlen-wide column blocks) + PV: (mlen, head_dim) in VRAM + + M_MM computes one (blen, mlen) @ (mlen, blen) -> (blen, blen) in a single op + (K=mlen done in one shot). For head_dim > mlen, V is split into head_dim/mlen + column blocks; the outer loop iterates blocks, middle loop iterates blen-wide + V columns within a block, inner loop iterates blen-wide P rows. + """ + assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + + gp_regs = self.register_allocator.allocate_gp(5) + gp_p = gp_regs[0] + gp_v = gp_regs[1] + gp_pv = gp_regs[2] + gp_hbm = gp_regs[3] + gp_stride = gp_regs[4] + + num_v_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === PV Multiply (P @ V) using M_MM ===") + lines.append(f"; P: ({mlen}, {mlen}) @ V: ({mlen}, {head_dim}) -> PV: ({mlen}, {head_dim})") + lines.append("; M_MM: (blen, mlen) @ (mlen, blen) -> (blen, blen), K=mlen in one shot") + lines.append(f"; V split into {num_v_col_blocks} column blocks of width {mlen}") + lines.append("; Storage layout: (batch, mlen, hidden/mlen), column-block major") + + # STRIDE was set to mlen by the flash-attention driver — do not overwrite it here. + # M_MM_WO requires a nonzero stride reg (gp0=0 would be interpreted as stride=1). + # With column-block-major storage, consecutive rows within a column block are + # adjacent, so the writeback stride = 1. + lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") + + for v_col_block in range(num_v_col_blocks): + lines.append( + f"; --- V column block {v_col_block} (columns {v_col_block * mlen} to {(v_col_block + 1) * mlen - 1}) ---" + ) + + # Prefetch V[:, v_col_block*mlen:(v_col_block+1)*mlen] (mlen × mlen) to MSRAM. + # V is row-major in HBM: V[row, col] at offset row*head_dim + col, so the + # column-block base offset = v_hbm_offset + v_col_block * mlen (elements). + v_block_hbm_offset = v_hbm_offset + v_col_block * mlen + lines.append(f"S_ADDI_INT gp{gp_v}, gp0, 0") + lines.append(f"S_ADDI_INT gp{gp_hbm}, gp0, {v_block_hbm_offset}") + lines.append(f"H_PREFETCH_M gp{gp_v}, gp{gp_hbm}, a{v_hbm_offset_reg}, 1, 1") + + # mat_offset constraint: < mlen and a multiple of blen. + for v_col in range(mlen // blen): + lines.append(f"; V column {v_col_block * mlen + v_col * blen}") + + v_msram_offset = v_col * blen + lines.append(f"S_ADDI_INT gp{gp_v}, gp0, {v_msram_offset}") + + for p_row in range(mlen // blen): + p_row_addr = p_address + p_row * blen * mlen + lines.append(f"S_ADDI_INT gp{gp_p}, gp0, {p_row_addr}") + + lines.append(f"M_MM 0, gp{gp_v}, gp{gp_p}") + + # PV[row, col] addr = base + col_block * mlen * mlen + row * mlen + col_in_block + # with row = p_row * blen and col_in_block = v_col * blen. + pv_offset = v_col_block * mlen * mlen + p_row * blen * mlen + v_col * blen + lines.append(f"S_ADDI_INT gp{gp_pv}, gp0, {pv_address + pv_offset}") + lines.append(f"M_MM_WO gp{gp_pv}, gp{gp_stride}, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _scale_o_asm( + self, + mlen: int, + head_dim: int, + seq_len: int, + m_res_address: int, + o_address: int, + row_offset: int = 0, + ) -> str: + """Scale each row of O by m_res: O[row] *= m_res[row].""" + assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + + gp_regs = self.register_allocator.allocate_gp(2) + gp_m_res = gp_regs[0] + gp_o = gp_regs[1] + fp_m_res = 1 + + num_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === Scale O by m_res ===") + lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") + lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") + + lines.append(f"S_ADDI_INT gp{gp_m_res}, gp0, {m_res_address}") + + for row in range(mlen): + lines.append(f"S_LD_FP f{fp_m_res}, gp{gp_m_res}, {row}") + actual_row = row_offset + row + + for col_block in range(num_col_blocks): + o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen + lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") + lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_m_res}, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _add_pv_to_o_asm( + self, + mlen: int, + head_dim: int, + seq_len: int, + pv_address: int, + o_address: int, + row_offset: int = 0, + ) -> str: + """Accumulate PV into O: O[row] += PV[row].""" + assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + + gp_regs = self.register_allocator.allocate_gp(2) + gp_o = gp_regs[0] + gp_pv = gp_regs[1] + + num_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === Add PV to O ===") + lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") + lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") + + for row in range(mlen): + actual_row = row_offset + row + + for col_block in range(num_col_blocks): + o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen + pv_addr = pv_address + col_block * mlen * mlen + row * mlen + + lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") + lines.append(f"S_ADDI_INT gp{gp_pv}, gp0, {pv_addr}") + lines.append(f"V_ADD_VV gp{gp_o}, gp{gp_o}, gp{gp_pv}, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _final_scaling_asm( + self, + mlen: int, + head_dim: int, + seq_len: int, + l_address: int, + o_address: int, + row_offset: int = 0, + ) -> str: + """ + Final scaling: O[row] /= l[row]. + + V_MUL_VF processes mlen elements at a time; when head_dim > mlen, + each row is split into head_dim // mlen mlen-wide blocks. + """ + assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + + gp_regs = self.register_allocator.allocate_gp(2) + gp_l = gp_regs[0] + gp_o = gp_regs[1] + fp_l = 1 + + num_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === Final Scaling O = O / l ===") + lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") + lines.append("; Storage layout: (seq_len, mlen, head_dim/mlen), column-block major") + lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") + + lines.append(f"S_ADDI_INT gp{gp_l}, gp0, {l_address}") + + for row in range(mlen): + lines.append(f"S_LD_FP f{fp_l}, gp{gp_l}, {row}") + lines.append(f"S_RECI_FP f{fp_l}, f{fp_l}, 0") + actual_row = row_offset + row + + for col_block in range(num_col_blocks): + o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen + lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") + lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_l}, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _reset_fpsram_asm( + self, + start_address: int, + count: int, + value_address: int, # FP SRAM slot: 0 = zero, 2 = -inf + ) -> str: + """Reset a region of FP SRAM to the value at value_address.""" + gp_regs = self.register_allocator.allocate_gp(1) + gp_addr = gp_regs[0] + + lines = [] + lines.append(f"; Reset FP SRAM [{start_address}, {start_address + count})") + + lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {start_address}") + # Use f1 for FP scalar - FP registers don't go through GP allocator + lines.append(f"S_LD_FP f1, gp0, {value_address}") + + for i in range(count): + lines.append(f"S_ST_FP f1, gp{gp_addr}, {i}") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _reset_vram_asm( + self, + start_address: int, + rows: int, + cols: int, + total_rows: int, + mlen: int = 64, + row_offset: int = 0, + ) -> str: + """ + Reset a region of VRAM to zero. + + V_MUL_VF processes mlen elements at a time; when cols > mlen, each + row is split into cols // mlen mlen-wide blocks. + """ + gp_regs = self.register_allocator.allocate_gp(1) + gp_addr = gp_regs[0] + + num_col_blocks = (cols + mlen - 1) // mlen + + lines = [] + lines.append(f"; Reset VRAM rows [{row_offset}, {row_offset + rows}) of matrix at {start_address}") + lines.append(f"; {rows} rows x {cols} cols, {num_col_blocks} blocks per row") + lines.append("; Storage layout: (total_rows, mlen, cols/mlen), column-block major") + lines.append(f"; total_rows = {total_rows}, row_offset = {row_offset}") + + for row in range(rows): + actual_row = row_offset + row + for col_block in range(num_col_blocks): + addr = start_address + col_block * total_rows * mlen + actual_row * mlen + lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {addr}") + lines.append(f"V_MUL_VF gp{gp_addr}, gp{gp_addr}, f0, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + # ========================================================================= + # Expanded Flash Attention Operations + # ========================================================================= + + def init_online_softmax( + self, + q_idx: int, + o_matrix: str, + seq_len: int, + head_dim: int, + ) -> str: + """ + Initialize Online Softmax state for Q block q_idx: + m_old = -inf (FP SRAM), l = 0 (FP SRAM), O_row = 0 (VRAM). + """ + fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE + m_old_addr = fp_sram_start + l_addr = fp_sram_start + 2 * self.mlen # skip m_res region + + o_info = self[o_matrix] + o_vram_addr = o_info.vram_addr + row_offset = q_idx * self.mlen + + isa_code = f"; === Init Online Softmax for Q block {q_idx} ===\n" + + isa_code += self._reset_fpsram_asm(m_old_addr, self.mlen, 2) # slot 2 = -inf + isa_code += self._reset_fpsram_asm(l_addr, self.mlen, 0) # slot 0 = 0.0 + isa_code += self._reset_vram_asm( + start_address=o_vram_addr, + rows=self.mlen, + cols=head_dim, + total_rows=seq_len, + mlen=self.mlen, + row_offset=row_offset, + ) + + return self._emit(isa_code) + + def online_softmax_block( + self, + s_block_matrix: str, + scale: float, + ) -> str: + """ + Run Online Softmax on one S block. + Input: S_block (mlen × mlen) in VRAM + Output: P (mlen × mlen) in-place in VRAM + Updates: m_old, m_res, l in FP SRAM + ``scale`` is the QK^T scaling factor (typically 1/sqrt(d)). + """ + s_info = self[s_block_matrix] + s_address = s_info.vram_addr + + fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE + m_start_address = fp_sram_start + + isa_code = f"; === Online Softmax Block {s_block_matrix} ===\n" + isa_code += self._online_softmax_asm( + mlen=self.mlen, s_address=s_address, m_start_address=m_start_address, scale=scale + ) + + return self._emit(isa_code) + + def compute_pv( + self, + s_block_matrix: str, + v_sub_matrix: str, + k_idx: int, + pv_matrix: str, + head_dim: int, + ) -> str: + """ + Compute PV = P @ V[k_idx]. + + P lives in s_block_matrix (softmax result); V is prefetched from + HBM; PV is written to VRAM via pv_matrix. + """ + s_info = self[s_block_matrix] + p_address = s_info.vram_addr + + pv_info = self[pv_matrix] + pv_address = pv_info.vram_addr + + v_layout = self.get_hbm_layout(v_sub_matrix) + v_hbm_offset = k_idx * self.mlen * head_dim + + isa_code = f"; === Compute PV = P @ V[k_idx={k_idx}] ===\n" + + addr_regs = self.register_allocator.allocate_addr(1) + v_hbm_reg = addr_regs[0] + gp_regs = self.register_allocator.allocate_gp(2) + + from compiler.asm_templates import preload_addr_reg_asm + + isa_code += preload_addr_reg_asm( + addr_reg_to_set=[v_hbm_reg], available_registers=gp_regs, addr_reg_val=[v_layout.hbm_base_addr] + ) + + isa_code += self._pv_multiply_asm( + mlen=self.mlen, + blen=self.blen, + head_dim=head_dim, + p_address=p_address, + v_hbm_offset_reg=v_hbm_reg, + v_hbm_offset=v_hbm_offset, + pv_address=pv_address, + ) + + self.register_allocator.free_gp(gp_regs) + self.register_allocator.free_addr(addr_regs) + + return self._emit(isa_code) + + def scale_o_row( + self, + o_matrix: str, + q_idx: int, + seq_len: int, + head_dim: int, + ) -> str: + """Scale the current row block of O by m_res: O[q_idx] *= m_res.""" + o_info = self[o_matrix] + o_address = o_info.vram_addr + + fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE + m_res_addr = fp_sram_start + self.mlen + + row_offset = q_idx * self.mlen + + isa_code = f"; === Scale O[q_idx={q_idx}] by m_res ===\n" + isa_code += self._scale_o_asm( + mlen=self.mlen, + head_dim=head_dim, + seq_len=seq_len, + m_res_address=m_res_addr, + o_address=o_address, + row_offset=row_offset, + ) + + return self._emit(isa_code) + + def final_scale_o( + self, + q_idx: int, + o_matrix: str, + seq_len: int, + head_dim: int, + ) -> str: + """Final scaling: O[q_idx] /= l.""" + o_info = self[o_matrix] + o_address = o_info.vram_addr + + fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE + l_addr = fp_sram_start + 2 * self.mlen + + row_offset = q_idx * self.mlen + + isa_code = f"; === Final Scale O for Q block {q_idx} ===\n" + isa_code += self._final_scaling_asm( + mlen=self.mlen, + head_dim=head_dim, + seq_len=seq_len, + l_address=l_addr, + o_address=o_address, + row_offset=row_offset, + ) + + return self._emit(isa_code) + + +__all__ = ["IsaAttentionMixin"] diff --git a/aten/plena/isa_compiler.py b/aten/plena/isa_compiler.py new file mode 100644 index 0000000..35c6896 --- /dev/null +++ b/aten/plena/isa_compiler.py @@ -0,0 +1,576 @@ +"""Low-level ISA emission layer for the ATen PLENA compiler.""" + +from __future__ import annotations + +from compiler.asm_templates import ( + layer_norm_asm, + preload_act_asm, + preload_addr_reg_asm, + reset_reg_asm, + rms_norm_asm, + rope_asm, + store_act_asm, +) +from compiler.aten.plena.isa_attention import IsaAttentionMixin +from compiler.aten.plena.isa_emit import IsaEmitMixin +from compiler.aten.plena.isa_fp_ops import IsaFPOpsMixin +from compiler.aten.plena.isa_matrix import IsaMatrixMixin +from compiler.aten.plena.isa_tile_rows import IsaTileRowMixin +from compiler.aten.plena.registers import RegisterAllocator +from compiler.aten.plena.tile_compiler import TileCompiler + + +class IsaCompiler( + IsaAttentionMixin, + IsaMatrixMixin, + IsaTileRowMixin, + IsaFPOpsMixin, + IsaEmitMixin, + TileCompiler, +): + """ + ISA Compiler: lowers PLENA compiler operations to assembly text. + + Owns symbol_table, register_allocator, and the InterruptManager. + Sub-matrix / memory management is inherited from TileCompiler; the + legacy ``self.tile_compiler`` accessor is preserved as a property + returning ``self`` for a handful of remaining external callers. + """ + + _ONLINE_SOFTMAX_FPSRAM_BASE = 10 + + class InterruptManager: + """ + Interrupt Manager — manages execution timing only. + Actual handlers live on IsaCompiler as ``_handle_k_start``, + ``_handle_k_prefetch_done``, ``_handle_s_tile_done``, ``_handle_k_end``. + """ + + def __init__(self, compiler: IsaCompiler): + self.compiler = compiler + self.enabled = False + + self._k_count = 0 + self._tile_count = 0 + + self.current_matrix: str = "" + self.current_activation: str = "" + self._mlen = compiler.mlen + self._blen = compiler.blen + self._batch = compiler.mlen + + self._q_block_idx = 0 + self._k_block_idx = 0 + self._s_tile_address = 0 + + @property + def tile_compiler(self): + return self.compiler.tile_compiler + + @property + def k_count(self) -> int: + return self._k_count + + @property + def tile_count(self) -> int: + return self._tile_count + + @property + def batch(self) -> int: + return self._batch + + @property + def out_features(self) -> int: + if self.current_matrix and self.current_matrix in self: + info = self[self.current_matrix] + return info.shape[0] + return self._mlen + + @property + def hidden_size(self) -> int: + if self.current_matrix and self.current_matrix in self: + info = self[self.current_matrix] + return info.shape[1] + return self._mlen + + @property + def k_block(self) -> int: + return self._k_block_idx + + @property + def q_block(self) -> int: + return self._q_block_idx + + @property + def s_tile_address(self) -> int: + return self._s_tile_address + + @property + def mlen(self) -> int: + return self._mlen + + @property + def blen(self) -> int: + return self._blen + + def reset(self): + """Reset counters (does not clear current_matrix).""" + self._k_count = 0 + self._tile_count = 0 + self._q_block_idx = 0 + self._k_block_idx = 0 + self._s_tile_address = 0 + + def enable(self): + self.enabled = True + + def disable(self): + self.enabled = False + + def trigger_k_start(self) -> str: + if not self.enabled: + return "" + return self.compiler._handle_k_start() + + def trigger_k_prefetch_done(self) -> str: + if not self.enabled: + return "" + result = self.compiler._handle_k_prefetch_done() + self._k_count += 1 + return result + + def trigger_s_tile_done(self) -> str: + if not self.enabled: + return "" + result = self.compiler._handle_s_tile_done() + self._tile_count += 1 + return result + + def trigger_k_end(self) -> str: + if not self.enabled: + return "" + return self.compiler._handle_k_end() + + def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): + # TileCompiler.__init__ sets mlen, blen, unroll_loops, the HBM/VRAM/FPRAM + # matrices and allocators, loaded_sub_blocks, and _address_cache. + super().__init__(mlen=mlen, blen=blen, unroll_loops=unroll_loops) + self.real_data_ratio = real_data_ratio + self.register_allocator = RegisterAllocator() + self.generated_code = "" + self.interrupt = self.InterruptManager(self) + + # Back-compat shim: older callers (and a couple of external ops modules) + # reach into ``compiler.tile_compiler`` directly. After the merge, + # IsaCompiler *is* the TileCompiler, so the property just returns + # ``self``. + @property + def tile_compiler(self) -> IsaCompiler: + return self + + # Interrupt handler placeholders (overridden by flash-attention passes). + + def _handle_k_start(self) -> str: + return "" + + def _handle_k_prefetch_done(self) -> str: + return "" + + def _handle_s_tile_done(self) -> str: + return "" + + def _handle_k_end(self) -> str: + return "" + + def load_batch( + self, + hbm_object_name: str, + vram_object_name: str, + vlen: int = 64, + preload_len: int = 4, + ) -> str: + """ + Load a Batch tensor from HBM to VRAM. + + HBM storage is MXFP (1 scale per 8 elements), so HBM actual size = + logical size * real_data_ratio = 1.125. VRAM stores only the vector + data (no scale), so VRAM size = logical size. + + Order (matters): allocate VRAM → register in symbol table → emit ISA. + """ + hbm_layout = self.get_hbm_layout(hbm_object_name) + h, w = hbm_layout.full_shape + hbm_addr = hbm_layout.hbm_base_addr + size = h * w + vram_base = self.vram_allocator.allocate(size, name=vram_object_name) + self.add_vram_object( + name=vram_object_name, + shape=(h, w), + vram_addr=vram_base, + dtype="fp16", + kind="Batch", + allocate_if_none=False, + strict=False, + ) + + addr_reg = self.register_allocator.allocate_addr(1)[0] + gp_regs_for_addr = self.register_allocator.allocate_gp(1) + + isa_code = f"; Load_Batch {hbm_object_name} -> {vram_object_name}\n" + isa_code += f"; HBM[{hbm_addr}] → VRAM[{vram_base}], shape=({h}, {w})\n" + + isa_code += preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], available_registers=gp_regs_for_addr, addr_reg_val=[hbm_addr] + ) + + # preload_act_asm requires 5 GP registers: [a_actual, stride, result, outer_loop, inner_loop]. + gp_regs_for_preload = self.register_allocator.allocate_gp(5) + isa_code += reset_reg_asm(alive_registers=gp_regs_for_preload) + + isa_code += preload_act_asm( + vlen=vlen, + preload_len=preload_len, + batch=h, + hidden_size=w, + alive_registers=gp_regs_for_preload, + act_vram_offset=vram_base, + activation_offset_reg=addr_reg, + stride_size=w, + ) + + self.register_allocator.free_gp(gp_regs_for_addr) + self.register_allocator.free_gp(gp_regs_for_preload) + self.register_allocator.free_addr([addr_reg]) + + return self._emit(isa_code) + + def store_to_hbm( + self, + tensor_name: str, + hbm_addr: int | None = None, + hbm_object_name: str | None = None, + hbm_addr_reg: int | None = None, + vlen: int = 64, + precision: int = 0, # 0 = Activation, 1 = KeyValue + store_amount: int = 4, # HBM_V_Writeback_Amount + ) -> str: + """ + Write tensor from VRAM back to HBM. + + Used to spill computed intermediates (e.g., K) from VRAM to HBM so + downstream ops (e.g., QK^T) can read them from HBM. Emits + ``store_act_asm`` for tensors of any supported size. + """ + if tensor_name not in self: + raise KeyError(f"Tensor '{tensor_name}' not found in symbol table") + + tensor_info = self[tensor_name] + + # Batch and VRAMMatrix share the same VRAM storage layout. + if tensor_info.kind not in ("Batch", "VRAMMatrix"): + raise ValueError( + f"Tensor '{tensor_name}' must be Batch or VRAMMatrix to store from VRAM, got {tensor_info.kind}" + ) + + if tensor_info.vram_addr is None: + raise ValueError(f"Tensor '{tensor_name}' has no VRAM address to store") + + if hbm_addr is None: + if tensor_info.hbm_addr >= 0: + hbm_addr = tensor_info.hbm_addr + else: + raise ValueError(f"Tensor '{tensor_name}' has no HBM address. Please specify hbm_addr.") + + batch_size = tensor_info.shape[0] + hidden_size = tensor_info.shape[1] + + isa_code = f"; Store {tensor_name} from VRAM to HBM\n" + isa_code += f"; VRAM[{tensor_info.vram_addr}] -> HBM[{hbm_addr}], shape=({batch_size}, {hidden_size})\n" + + gp_regs = self.register_allocator.allocate_gp(5) + + if hbm_addr_reg is None: + addr_regs = self.register_allocator.allocate_addr(1) + hbm_addr_reg = addr_regs[0] + need_free_addr = True + else: + addr_regs = [] + need_free_addr = False + + try: + gp_regs_for_addr = self.register_allocator.allocate_gp(2) + isa_code += preload_addr_reg_asm( + addr_reg_to_set=[hbm_addr_reg], available_registers=gp_regs_for_addr, addr_reg_val=[hbm_addr] + ) + self.register_allocator.free_gp(gp_regs_for_addr) + + isa_code += store_act_asm( + vlen=vlen, + batch=batch_size, + hidden_size=hidden_size, + alive_registers=gp_regs, + act_vram_offset=tensor_info.vram_addr, + hbm_addr_reg=hbm_addr_reg, + stride_size=hidden_size, + store_amount=store_amount, + ) + + if tensor_info.hbm_addr < 0 or tensor_info.hbm_addr != hbm_addr: + tensor_info.hbm_addr = hbm_addr + # HBM stores the MXFP-expanded size (logical size × real_data_ratio). + size = batch_size * hidden_size + tensor_info.hbm_size = int(size * self.real_data_ratio) + finally: + self.register_allocator.free_gp(gp_regs) + if need_free_addr: + self.register_allocator.free_addr(addr_regs) + + if hbm_object_name is not None: + self.add_hbm_object( + name=hbm_object_name, + hbm_addr=hbm_addr, + shape=(batch_size, hidden_size), + ) + + return self._emit(isa_code) + + def normalize( + self, + tensor_name: str, + mode: str = "rms", + eps_offset: int = 1, + reci_hid_offset: int = 2, + vlen: int | None = None, + scratchpad_vram_addr: int | None = None, + ) -> str: + """ + Normalize a VRAM tensor in-place. + + Supports: + - mode="rms": RMSNorm + - mode="layer": LayerNorm + + Args: + tensor_name: Tensor name in symbol table (must have VRAM address) + mode: "rms" or "layer" + eps_offset: FPRAM address of epsilon + reci_hid_offset: FPRAM address of 1/hidden_dim + vlen: vector length (default: self.mlen) + scratchpad_vram_addr: scratchpad VRAM address (default: auto-allocate temporary space) + """ + if tensor_name not in self: + raise KeyError(f"Tensor '{tensor_name}' not found in symbol table") + + tensor_info = self[tensor_name] + if tensor_info.vram_addr is None: + raise ValueError(f"Tensor '{tensor_name}' has no VRAM address") + + batch_size, hidden_dim = tensor_info.shape + if vlen is None: + vlen = self.mlen + if hidden_dim % vlen != 0: + raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by vlen ({vlen}) for normalization_asm") + + mode = mode.lower() + if mode not in ("rms", "layer"): + raise ValueError(f"Unsupported normalization mode: {mode}. Expected 'rms' or 'layer'.") + + gp_regs = self.register_allocator.allocate_gp(4) + + temp_scratchpad_name = None + if scratchpad_vram_addr is None: + temp_scratchpad_name = f"__norm_scratch__{tensor_name}__{len(self.generated_code)}" + scratchpad_vram_addr = self.vram_allocator.allocate(vlen, name=temp_scratchpad_name) + + try: + isa_code = f"; Normalize ({mode}) {tensor_name}, shape=({batch_size}, {hidden_dim})\n" + if mode == "rms": + isa_code += rms_norm_asm( + _eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + alive_registers=gp_regs, + activation_base_address=tensor_info.vram_addr, + scratchpad_base_address=scratchpad_vram_addr, + vlen=vlen, + batch_size=batch_size, + hidden_dim=hidden_dim, + ) + else: + isa_code += layer_norm_asm( + _eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + alive_registers=gp_regs, + activation_base_address=tensor_info.vram_addr, + scratchpad_base_address=scratchpad_vram_addr, + vlen=vlen, + batch_size=batch_size, + hidden_dim=hidden_dim, + ) + + return self._emit(isa_code) + finally: + # Always release allocated GP registers used by normalization template. + self.register_allocator.free_gp(gp_regs) + if temp_scratchpad_name is not None: + self.vram_allocator.free(temp_scratchpad_name, strict=False) + + def rope( + self, + x_name: str, + x_rot_name: str, + cos_name: str, + sin_name: str, + ) -> str: + """Apply RoPE in-place: x = x * cos + rotate_half(x) * sin + + All four tensors must already be in VRAM with the same shape (seq_len, head_dim). + x_rot must be preloaded by the caller as rotate_half(x). + """ + x_info = self[x_name] + xrot_info = self[x_rot_name] + cos_info = self[cos_name] + sin_info = self[sin_name] + + if x_info.vram_addr is None: + raise ValueError(f"Tensor '{x_name}' has no VRAM address") + + seq_len, head_dim = x_info.shape + vlen = self.mlen + + if head_dim % vlen != 0: + raise ValueError(f"head_dim ({head_dim}) must be divisible by vlen ({vlen}) for rope") + + gp_regs = self.register_allocator.allocate_gp(5) + + scratch_name = f"__rope_scratch__{x_name}__{len(self.generated_code)}" + scratch_addr = self.vram_allocator.allocate(vlen, name=scratch_name) + + try: + isa_code = rope_asm( + alive_registers=gp_regs, + x_base_address=x_info.vram_addr, + x_rot_base_address=xrot_info.vram_addr, + cos_base_address=cos_info.vram_addr, + sin_base_address=sin_info.vram_addr, + scratchpad_base_address=scratch_addr, + vlen=vlen, + seq_len=seq_len, + head_dim=head_dim, + ) + return self._emit(isa_code) + finally: + self.register_allocator.free_gp(gp_regs) + self.vram_allocator.free(scratch_name, strict=False) + + def get_code(self) -> str: + """Get all accumulated generated ISA code""" + return self.generated_code + + def reset(self): + """Reset compiler state (clear code, but retain symbol table)""" + self.generated_code = "" + self.register_allocator = RegisterAllocator() + # Call TileCompiler.reset() explicitly since the merged class shadows it. + TileCompiler.reset(self) + + def print_symbol_table(self): + """Print symbol table""" + self.print_table() + + def get_symbol_table(self): + """Get managed object table view.""" + return self + + def get_tensor_info(self, name: str): + """Get unified tensor/object info by name.""" + return self[name] + + def add_hbm_object( + self, + name: str, + hbm_addr: int, + shape: tuple[int, int], + real_data_ratio: float = 1.125, + ): + """Register an HBM object and build its HBM layout. + + Wraps ``TileCompiler.add_hbm_object`` with a different positional + parameter order ``(name, hbm_addr, shape, ...)`` that all IsaCompiler + callers use. + """ + return TileCompiler.add_hbm_object( + self, + name=name, + shape=shape, + hbm_addr=hbm_addr, + real_data_ratio=real_data_ratio, + ) + + def free_hbm_object(self, name: str, strict: bool = False): + """Free an HBM object by name (defaults to non-strict).""" + return TileCompiler.free_hbm_object(self, name, strict=strict) + + def get_vram_addr(self, name: str) -> int: + """Get VRAM base address of an object.""" + info = self.get_tensor_info(name) + if info.vram_addr is None: + raise ValueError(f"Object '{name}' has no VRAM address") + return info.vram_addr + + def get_vram_tile_addr( + self, + name: str, + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> int: + """ + Get VRAM address of a specific tile (sub-block) in a VRAM matrix. + + Args: + name: matrix name + tile_row_idx: tile row index (0-based) + tile_col_idx: tile col index (0-based) + """ + self._ensure_vram_matrix_layout(name) + sub = self.get_vram_sub_block(name, tile_row_idx, tile_col_idx) + return sub.vram_addr + + def ensure_hbm_sub_matrix( + self, + name: str, + hbm_addr: int, + shape: tuple[int, int], + real_data_ratio: float = 1.125, + ): + """Ensure HBM matrix layout exists.""" + if name in self.hbm_matrices: + return + self.add_hbm_object( + name=name, + hbm_addr=hbm_addr, + shape=shape, + real_data_ratio=real_data_ratio, + ) + + def ensure_vram_matrix_layout(self, name: str, shape: tuple[int, int]): + """Ensure VRAM matrix layout exists for an already allocated VRAM object.""" + if name in self.vram_matrices: + return + vram_addr = self.get_vram_addr(name) + self.add_vram_object( + name=name, + shape=shape, + vram_addr=vram_addr, + allocate_if_none=False, + ) + + def free_vram_object(self, name: str, strict: bool = False): + """Free a VRAM object by name (defaults to non-strict).""" + return TileCompiler.free_vram_object(self, name, strict=strict) + + +# Compatibility alias for callers that imported the old low-level layer name. +DeveloperCompiler = IsaCompiler + + +__all__ = ["DeveloperCompiler", "IsaCompiler"] diff --git a/aten/plena/isa_emit.py b/aten/plena/isa_emit.py new file mode 100644 index 0000000..c0014a1 --- /dev/null +++ b/aten/plena/isa_emit.py @@ -0,0 +1,87 @@ +"""Emission and FPRAM allocation helpers for IsaCompiler.""" + +from __future__ import annotations + +from compiler.aten.isa_builder import AsmInput, IsaBuilder, render_asm +from compiler.aten.plena.registers import RegisterAllocator + + +class IsaEmitMixin: + # ========================================================================= + # FP Register & FPRAM Management (inlined from former FPRAMCompiler). + # All state lives on self (register_allocator, fpram_allocator, etc.). + # ========================================================================= + + @property + def _reg(self) -> RegisterAllocator: + """Shorthand for self.register_allocator (used by FPVar ISA helpers).""" + return self.register_allocator + + @property + def _unroll(self) -> bool: + """Shorthand for self.unroll_loops.""" + return self.unroll_loops + + def _emit(self, isa_code: AsmInput) -> str: + """Append ISA text to the output buffer and return it.""" + rendered = render_asm(isa_code) + self.generated_code += rendered + return rendered + + def emit(self, isa_code: AsmInput) -> str: + """Public emission hook for code outside IsaCompiler internals.""" + return self._emit(isa_code) + + def emit_comment(self, text: str) -> str: + """Append one assembly comment line.""" + return self._emit(IsaBuilder().comment(text)) + + # ------------------------------------------------------------------ + # FP Register management + # ------------------------------------------------------------------ + + def allocate_fp_reg(self, count: int = 1) -> list[int]: + """Allocate FP registers (f0-f7).""" + return self._reg.allocate_fp(count) + + def free_fp_reg(self, registers: list[int]): + """Free FP registers.""" + self._reg.free_fp(registers) + + # ------------------------------------------------------------------ + # FPRAM address-space management + # ------------------------------------------------------------------ + + def allocate_fpram(self, name: str, size: int) -> int: + """Allocate FPRAM space, returns base address.""" + info = self.add_fpram_object(name=name, size=size) + if info.fpram_addr is None: + raise RuntimeError(f"Failed to allocate FPRAM for '{name}'") + return info.fpram_addr + + def free_fpram(self, name: str, strict: bool = True): + """Free FPRAM object by name.""" + return self.free_fpram_object(name, strict=strict) + + def save_fpram_state(self) -> int: + """Save FPRAM allocator snapshot.""" + return self.fpram_allocator.save_state() + + def restore_fpram_state(self, snapshot: int): + """Restore FPRAM allocator snapshot.""" + self.fpram_allocator.restore_state(snapshot) + + def list_fpram_allocations(self) -> list[str]: + """List currently allocated FPRAM object names.""" + return list(self.fpram_allocator.allocations.keys()) + + def get_fpram_addr(self, name: str) -> int: + """Get FPRAM base address from object name.""" + return self.get_fpram_layout(name).fpram_addr + + def get_fpram_size(self, name: str) -> int: + """Get FPRAM allocation size from object name.""" + return self.get_fpram_layout(name).size + + +__all__ = ["IsaEmitMixin"] diff --git a/aten/plena/isa_fp_ops.py b/aten/plena/isa_fp_ops.py new file mode 100644 index 0000000..66ad248 --- /dev/null +++ b/aten/plena/isa_fp_ops.py @@ -0,0 +1,316 @@ +"""FPVar and FPRAM ISA helpers for IsaCompiler.""" + +from __future__ import annotations + +from compiler.aten.isa_builder import IsaBuilder, fp, gp + + +class IsaFPOpsMixin: + # ========================================================================= + # FPVar ISA helpers (address-based) + # ========================================================================= + + def _emit_fpvar_skip(self, op_name: str, count: int) -> str: + return self._emit(IsaBuilder().comment(f"FPVar {op_name} skipped: count={count}")) + + def _fpvar_unary_asm( + self, + op_name: str, + op_description: str, + opcode: str, + src_addr: int, + dst_addr: int, + count: int, + ) -> str: + if count <= 0: + return self._emit_fpvar_skip(op_name, count) + + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"FPVar {op_name}: {op_description}, count={count}") + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + + if self._unroll: + for i in range(count): + asm.instr("S_LD_FP", fp(1), gp(gp_src), i) + asm.instr(opcode, fp(1), fp(1), 0) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_LD_FP", fp(1), gp(gp_src), 0) + asm.instr(opcode, fp(1), fp(1), 0) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), 1) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def _fpvar_binary_asm( + self, + op_name: str, + op_description: str, + opcode: str, + src1_addr: int, + src2_addr: int, + dst_addr: int, + count: int, + ) -> str: + if count <= 0: + return self._emit_fpvar_skip(op_name, count) + + gp_regs = self._reg.allocate_gp(4) + gp_a, gp_b, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"FPVar {op_name}: {op_description}, count={count}") + asm.instr("S_ADDI_INT", gp(gp_a), gp(0), src1_addr) + asm.instr("S_ADDI_INT", gp(gp_b), gp(0), src2_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + + if self._unroll: + for i in range(count): + asm.instr("S_LD_FP", fp(1), gp(gp_a), i) + asm.instr("S_LD_FP", fp(2), gp(gp_b), i) + asm.instr(opcode, fp(1), fp(1), fp(2)) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_LD_FP", fp(1), gp(gp_a), 0) + asm.instr("S_LD_FP", fp(2), gp(gp_b), 0) + asm.instr(opcode, fp(1), fp(1), fp(2)) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_a), gp(gp_a), 1) + asm.instr("S_ADDI_INT", gp(gp_b), gp(gp_b), 1) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def fpvar_copy_asm(self, src_addr: int, dst_addr: int, count: int) -> str: + if count <= 0: + return self._emit_fpvar_skip("Copy", count) + + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment( + f"FPVar Copy: FPRAM[{dst_addr}:{dst_addr + count}] = FPRAM[{src_addr}:{src_addr + count}]" + ) + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + + if self._unroll: + for i in range(count): + asm.instr("S_LD_FP", fp(1), gp(gp_src), i) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_LD_FP", fp(1), gp(gp_src), 0) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), 1) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def fpvar_fill_from_fpram_asm(self, dst_addr: int, src_fpram_addr: int, count: int) -> str: + if count <= 0: + return self._emit_fpvar_skip("Fill", count) + + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"FPVar Fill: FPRAM[{dst_addr}:{dst_addr + count}] = FPRAM[{src_fpram_addr}]") + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_fpram_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + asm.instr("S_LD_FP", fp(1), gp(gp_src), 0) + + if self._unroll: + for i in range(count): + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def fpvar_reci_asm(self, src_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_unary_asm("Reci", "dst = 1/src", "S_RECI_FP", src_addr, dst_addr, count) + + def fpvar_exp_asm(self, src_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_unary_asm("Exp", "dst = exp(src)", "S_EXP_FP", src_addr, dst_addr, count) + + def fpvar_add_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_binary_asm("Add", "dst = src1 + src2", "S_ADD_FP", src1_addr, src2_addr, dst_addr, count) + + def fpvar_sub_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_binary_asm("Sub", "dst = src1 - src2", "S_SUB_FP", src1_addr, src2_addr, dst_addr, count) + + def fpvar_mul_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_binary_asm("Mul", "dst = src1 * src2", "S_MUL_FP", src1_addr, src2_addr, dst_addr, count) + + def fpvar_max_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_binary_asm("Max", "dst = max(src1, src2)", "S_MAX_FP", src1_addr, src2_addr, dst_addr, count) + + def fpvar_sum_asm(self, src_addr: int, dst_addr: int, count: int) -> str: + if count <= 0: + return self._emit_fpvar_skip("Sum", count) + + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"FPVar Sum: FPRAM[{dst_addr}] = sum(FPRAM[{src_addr}:{src_addr + count}])") + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + asm.instr("S_ADD_FP", fp(1), fp(0), fp(0)) + + if self._unroll: + for i in range(count): + asm.instr("S_LD_FP", fp(2), gp(gp_src), i) + asm.instr("S_ADD_FP", fp(1), fp(1), fp(2)) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_LD_FP", fp(2), gp(gp_src), 0) + asm.instr("S_ADD_FP", fp(1), fp(1), fp(2)) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def fpvar_shift_asm( + self, + src_addr: int, + dst_addr: int, + count: int, + shift: int, + fill_fpram_addr: int = 0, + ) -> str: + """ + Shift FPVar into dst. + - shift > 0: right shift (leading positions filled) + - shift < 0: left shift (trailing positions filled) + """ + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_fill = gp_regs + try: + asm = IsaBuilder().comment( + f"FPVar Shift: dst=shift(src, shift={shift}), count={count}, fill=FPRAM[{fill_fpram_addr}]" + ) + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + asm.instr("S_ADDI_INT", gp(gp_fill), gp(0), fill_fpram_addr) + asm.instr("S_LD_FP", fp(3), gp(gp_fill), 0) + + for i in range(count): + src_idx = i - shift + if 0 <= src_idx < count: + asm.instr("S_LD_FP", fp(1), gp(gp_src), src_idx) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("S_ST_FP", fp(3), gp(gp_dst), i) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + # ========================================================================= + # FPVar helpers (name-based wrappers over the address-based ISA generators) + # ========================================================================= + + def fpram_copy(self, src_name: str, dst_name: str, count: int | None = None) -> str: + if count is None: + count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) + return self.fpvar_copy_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) + + def fpram_reci(self, src_name: str, dst_name: str, count: int | None = None) -> str: + if count is None: + count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) + return self.fpvar_reci_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) + + def fpram_exp(self, src_name: str, dst_name: str, count: int | None = None) -> str: + if count is None: + count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) + return self.fpvar_exp_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) + + def fpram_add(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: + if count is None: + count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) + return self.fpvar_add_asm( + self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count + ) + + def fpram_sub(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: + if count is None: + count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) + return self.fpvar_sub_asm( + self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count + ) + + def fpram_mul(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: + if count is None: + count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) + return self.fpvar_mul_asm( + self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count + ) + + def fpram_max(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: + if count is None: + count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) + return self.fpvar_max_asm( + self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count + ) + + def fpram_sum(self, src_name: str, dst_name: str, count: int | None = None) -> str: + if count is None: + count = self.get_fpram_size(src_name) + return self.fpvar_sum_asm( + self.get_fpram_addr(src_name), + self.get_fpram_addr(dst_name), + count, + ) + + def fpram_shift( + self, + src_name: str, + dst_name: str, + shift: int, + count: int | None = None, + fill_fpram_name: str | None = None, + ) -> str: + if count is None: + count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) + fill_addr = 0 if fill_fpram_name is None else self.get_fpram_addr(fill_fpram_name) + return self.fpvar_shift_asm( + src_addr=self.get_fpram_addr(src_name), + dst_addr=self.get_fpram_addr(dst_name), + count=count, + shift=shift, + fill_fpram_addr=fill_addr, + ) + + def fpram_fill_from_fpram(self, dst_name: str, src_fpram_addr: int, count: int | None = None) -> str: + if count is None: + count = self.get_fpram_size(dst_name) + return self.fpvar_fill_from_fpram_asm( + dst_addr=self.get_fpram_addr(dst_name), + src_fpram_addr=src_fpram_addr, + count=count, + ) + + +__all__ = ["IsaFPOpsMixin"] diff --git a/aten/plena/isa_matrix.py b/aten/plena/isa_matrix.py new file mode 100644 index 0000000..c7e9abe --- /dev/null +++ b/aten/plena/isa_matrix.py @@ -0,0 +1,384 @@ +"""Matrix movement and VRAM projection helpers for IsaCompiler.""" + +from __future__ import annotations + +from compiler.asm_templates import preload_addr_reg_asm +from compiler.aten.isa_builder import IsaBuilder + + +class IsaMatrixMixin: + def reset_mram(self) -> str: + """ + Reset MRAM allocator, free all allocated space + Used in scenarios where sub-blocks need to be reloaded within a for loop + """ + self.mram_allocator.reset() + self.loaded_sub_blocks.clear() + + return self._emit(IsaBuilder().comment("=== Reset MRAM ===")) + + def load_sub_matrix_row( + self, + name: str, + row_idx: int, + mram_start_addr: int | None = None, + ) -> str: + """Load entire row sub-blocks from HBM to MRAM: matrix[row_idx][:].""" + layout = self.get_hbm_layout(name) + num_col_blocks = layout.num_col_blocks + block_size = self.mlen * self.mlen + + if mram_start_addr is None: + total_size = num_col_blocks * block_size + mram_start_addr = self.mram_allocator.allocate(f"{name}[{row_idx}][:]", total_size) + + gp_regs = self.register_allocator.allocate_gp(4) + gp_for_addr = self.register_allocator.allocate_gp(2) + addr_reg = self.register_allocator.allocate_addr(1)[0] + + isa_code = preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], available_registers=gp_for_addr, addr_reg_val=[layout.hbm_base_addr] + ) + + isa_code += self.load_row_sub_matrices_asm( + name=name, row_idx=row_idx, mram_start_addr=mram_start_addr, hbm_addr_reg=addr_reg, gp_regs=gp_regs + ) + + self.register_allocator.free_gp(gp_regs) + self.register_allocator.free_gp(gp_for_addr) + self.register_allocator.free_addr([addr_reg]) + + return self._emit(isa_code) + + def load_sub_matrix_col( + self, + name: str, + col_idx: int, + mram_start_addr: int | None = None, + k_block_start: int = 0, + k_block_count: int | None = None, + ) -> str: + """ + Load entire column sub-blocks from HBM to MRAM: matrix[:][col_idx]. + Used for sub_projection: A @ W[:, col_idx*mlen:(col_idx+1)*mlen]. + """ + layout = self.get_hbm_layout(name) + num_row_blocks = layout.num_row_blocks + block_size = self.mlen * self.mlen + + if mram_start_addr is None: + effective_count = k_block_count if k_block_count is not None else num_row_blocks + total_size = effective_count * block_size + mram_start_addr = self.mram_allocator.allocate(f"{name}[:][{col_idx}]", total_size) + + gp_regs = self.register_allocator.allocate_gp(3) + gp_for_addr = self.register_allocator.allocate_gp(2) + addr_reg = self.register_allocator.allocate_addr(1)[0] + + isa_code = preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], available_registers=gp_for_addr, addr_reg_val=[layout.hbm_base_addr] + ) + + isa_code += self.load_col_sub_matrices_asm( + name=name, + col_idx=col_idx, + mram_start_addr=mram_start_addr, + hbm_addr_reg=addr_reg, + gp_regs=gp_regs, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + + self.register_allocator.free_gp(gp_regs) + self.register_allocator.free_gp(gp_for_addr) + self.register_allocator.free_addr([addr_reg]) + + return self._emit(isa_code) + + def allocate_vram_matrix( + self, + name: str, + rows: int, + cols: int, + strict: bool = True, + ) -> int: + """Allocate a VRAM matrix large enough to hold combined results of multiple sub-blocks. Returns the VRAM base address.""" + size = rows * cols + vram_addr = self.vram_allocator.allocate(size, name=name) + + self.add_vram_object( + name=name, + shape=(rows, cols), + vram_addr=vram_addr, + dtype="fp32", + kind="VRAMMatrix", + allocate_if_none=False, + strict=strict, + ) + + isa_code = f"; Allocate VRAM Matrix {name}: ({rows}, {cols}) at VRAM[{vram_addr}]\n" + self._emit(isa_code) + + return vram_addr + + def _ensure_vram_matrix_layout(self, matrix_name: str): + """Ensure a VRAM-resident tensor has a block layout in TileCompiler.""" + if matrix_name not in self: + raise KeyError(f"Matrix '{matrix_name}' not found in symbol table") + + info = self[matrix_name] + if info.vram_addr is None: + raise ValueError(f"Matrix '{matrix_name}' has no VRAM address") + + try: + self.get_vram_layout(matrix_name) + except KeyError: + self.register_vram_matrix( + name=matrix_name, + shape=info.shape, + vram_base_addr=info.vram_addr, + ) + + def vram_block_add_to( + self, + src1_matrix: str, + src1_row_idx: int, + src1_col_idx: int, + src2_matrix: str, + src2_row_idx: int, + src2_col_idx: int, + target_matrix: str, + target_row_idx: int, + target_col_idx: int, + ) -> str: + """ + mlen x mlen block add: + target[rt][ct] = src1[r1][c1] + src2[r2][c2] + + Source/target may be the same matrix (supports in-place overwrite). + """ + self._ensure_vram_matrix_layout(src1_matrix) + self._ensure_vram_matrix_layout(src2_matrix) + self._ensure_vram_matrix_layout(target_matrix) + + gp_regs = self.register_allocator.allocate_gp(4) + isa_code = self.vram_block_add_asm( + src1_name=src1_matrix, + src1_row_idx=src1_row_idx, + src1_col_idx=src1_col_idx, + src2_name=src2_matrix, + src2_row_idx=src2_row_idx, + src2_col_idx=src2_col_idx, + target_name=target_matrix, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + gp_regs=gp_regs, + ) + self.register_allocator.free_gp(gp_regs) + + return self._emit(isa_code) + + def vram_matrix_add( + self, + dst_matrix: str, + src_matrix: str, + dst_row_offset: int = 0, + src_row_offset: int = 0, + num_rows: int | None = None, + ) -> str: + """ + General VRAM Matrix Addition: dst[row_offset:] += src. + + row_offsets are logical rows (not VRAM addresses); num_rows defaults + to the source matrix's row count. + """ + dst_info = self[dst_matrix] + src_info = self[src_matrix] + + # Block-add path depends on TileCompiler VRAM layouts. + self._ensure_vram_matrix_layout(dst_matrix) + self._ensure_vram_matrix_layout(src_matrix) + + dst_addr = dst_info.vram_addr + src_addr = src_info.vram_addr + + dst_rows, dst_cols = dst_info.shape + src_rows, src_cols = src_info.shape + + if num_rows is None: + num_rows = src_rows + + # Ensure column count matches + assert dst_cols == src_cols, f"Column mismatch: dst={dst_cols}, src={src_cols}" + assert dst_row_offset + num_rows <= dst_rows, ( + f"dst row range out of bounds: offset={dst_row_offset}, num_rows={num_rows}, dst_rows={dst_rows}" + ) + assert src_row_offset + num_rows <= src_rows, ( + f"src row range out of bounds: offset={src_row_offset}, num_rows={num_rows}, src_rows={src_rows}" + ) + lines = [] + lines.append( + f"; === VRAM Matrix Add: " + f"{dst_matrix}[{dst_row_offset}:{dst_row_offset + num_rows}] += " + f"{src_matrix}[{src_row_offset}:{src_row_offset + num_rows}] ===" + ) + lines.append(f"; dst shape: {dst_info.shape}, src shape: {src_info.shape}") + + # Prefer block add path so we can reuse the compact C_LOOP-based add kernel. + block_aligned = ( + dst_cols % self.mlen == 0 + and src_cols % self.mlen == 0 + and dst_row_offset % self.mlen == 0 + and src_row_offset % self.mlen == 0 + and num_rows % self.mlen == 0 + ) + + if block_aligned: + num_row_blocks = num_rows // self.mlen + num_col_blocks = dst_cols // self.mlen + dst_row_block_base = dst_row_offset // self.mlen + src_row_block_base = src_row_offset // self.mlen + lines.append(f"; block add path: row_blocks={num_row_blocks}, col_blocks={num_col_blocks}") + + for row_block in range(num_row_blocks): + for col_block in range(num_col_blocks): + gp_regs = self.register_allocator.allocate_gp(4) + lines.append( + self.vram_block_add_asm( + src1_name=dst_matrix, + src1_row_idx=dst_row_block_base + row_block, + src1_col_idx=col_block, + src2_name=src_matrix, + src2_row_idx=src_row_block_base + row_block, + src2_col_idx=col_block, + target_name=dst_matrix, + target_row_idx=dst_row_block_base + row_block, + target_col_idx=col_block, + gp_regs=gp_regs, + ).rstrip("\n") + ) + self.register_allocator.free_gp(gp_regs) + else: + # Fallback for non-mlen-aligned ranges. + gp_regs = self.register_allocator.allocate_gp(2) + gp_dst = gp_regs[0] + gp_src = gp_regs[1] + num_col_blocks = dst_cols // self.mlen + lines.append(f"; fallback row-wise path: num_rows={num_rows}, num_col_blocks={num_col_blocks}") + + for row in range(num_rows): + dst_actual_row = dst_row_offset + row + src_actual_row = src_row_offset + row + + for col_block in range(num_col_blocks): + dst_block_addr = dst_addr + col_block * dst_rows * self.mlen + dst_actual_row * self.mlen + src_block_addr = src_addr + col_block * src_rows * self.mlen + src_actual_row * self.mlen + + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_block_addr}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_block_addr}") + lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") + self.register_allocator.free_gp(gp_regs) + + isa_code = "\n".join(lines) + "\n" + return self._emit(isa_code) + + def vram_sub_projection_to( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_col_idx: int, + target_matrix: str, + target_row_idx: int, + target_col_idx: int, + k_block_start: int = 0, + k_block_count: int | None = None, + ) -> str: + """ + Sub-block multiplication: + target[target_row_idx][target_col_idx] = VRAM_A[vram_row_idx][:] @ MRAM_W[:][mram_col_idx]. + Target matrix must have been allocated via allocate_vram_matrix. + """ + if target_matrix not in self: + raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") + + target_info = self[target_matrix] + target_rows, _target_cols = target_info.shape + target_base_addr = target_info.vram_addr + + # VRAM layout: [batch, mlen, hidden/mlen], column-block major. + # Sub-block (r, c) addr = base + c * rows * mlen + r * mlen * mlen. + result_vram_addr = ( + target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen + ) + + gp_regs = self.register_allocator.allocate_gp(9) + + isa_code = f"; VRAM Sub Projection To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}] -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" + isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" + isa_code += self.vram_sub_projection_asm( + vram_mat_name=vram_mat_name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_mat_name, + mram_col_idx=mram_col_idx, + result_vram_addr=result_vram_addr, + gp_regs=gp_regs, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + + self.register_allocator.free_gp(gp_regs) + + return self._emit(isa_code) + + def vram_sub_projection_T_to( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_row_idx: int, + target_matrix: str, + target_row_idx: int, + target_col_idx: int, + ) -> str: + """ + Transposed sub-block multiplication: + target[target_row_idx][target_col_idx] = VRAM_A[vram_row_idx][:] @ MRAM_W[mram_row_idx][:]^T. + + Used by Flash Attention for S = Q @ K^T: + Q[i][:]: (mlen, hidden_size) row sub-block + K[j][:]: (mlen, hidden_size) row sub-block, transposed to (hidden_size, mlen) + S[i][j]: (mlen, mlen) + """ + if target_matrix not in self: + raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") + + target_info = self[target_matrix] + target_rows, _target_cols = target_info.shape + target_base_addr = target_info.vram_addr + + # VRAM layout: [batch, mlen, hidden/mlen], column-block major. + # Sub-block (r, c) addr = base + c * rows * mlen + r * mlen * mlen. + result_vram_addr = ( + target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen + ) + + gp_regs = self.register_allocator.allocate_gp(9) + + isa_code = f"; VRAM Sub Projection T To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" + isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" + isa_code += self.vram_sub_projection_T_asm( + vram_mat_name=vram_mat_name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_mat_name, + mram_row_idx=mram_row_idx, + result_vram_addr=result_vram_addr, + gp_regs=gp_regs, + ) + + self.register_allocator.free_gp(gp_regs) + + return self._emit(isa_code) + + +__all__ = ["IsaMatrixMixin"] diff --git a/aten/plena/isa_tile_rows.py b/aten/plena/isa_tile_rows.py new file mode 100644 index 0000000..8d2e247 --- /dev/null +++ b/aten/plena/isa_tile_rows.py @@ -0,0 +1,391 @@ +"""Tile-row ISA helpers for IsaCompiler.""" + +from __future__ import annotations + +from compiler.aten.isa_builder import IsaBuilder, fp, gp + + +class IsaTileRowMixin: + # ========================================================================= + # Tile-row helpers (name-based) + # ========================================================================= + + def tile_row_max( + self, + source_matrix: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + source_addr = self.get_vram_tile_addr(source_matrix, tile_row_idx, tile_col_idx) + return self.tile_row_max_asm(source_addr, row_map) + + def tile_row_sum( + self, + source_matrix: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + source_addr = self.get_vram_tile_addr(source_matrix, tile_row_idx, tile_col_idx) + return self.tile_row_sum_asm(source_addr, row_map) + + def tile_row_exp( + self, + matrix_name: str, + rows: list[int], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) + return self.tile_row_exp_asm(matrix_addr, rows) + + def tile_row_reci( + self, + matrix_name: str, + rows: list[int], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) + return self.tile_row_reci_asm(matrix_addr, rows) + + def tile_row_sub_fp( + self, + matrix_name: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) + return self.tile_row_sub_fp_asm(matrix_addr, row_map) + + def tile_row_mul_fp( + self, + matrix_name: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) + return self.tile_row_mul_fp_asm(matrix_addr, row_map) + + def tile_row_add_fp( + self, + matrix_name: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) + return self.tile_row_add_fp_asm(matrix_addr, row_map) + + def tile_row_add( + self, + dst_matrix: str, + src_matrix: str, + rows: list[int], + dst_tile_row_idx: int = 0, + dst_tile_col_idx: int = 0, + src_tile_row_idx: int = 0, + src_tile_col_idx: int = 0, + ) -> str: + dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) + src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) + return self.tile_row_add_asm(dst_addr, src_addr, rows) + + def tile_row_sub( + self, + dst_matrix: str, + src_matrix: str, + rows: list[int], + dst_tile_row_idx: int = 0, + dst_tile_col_idx: int = 0, + src_tile_row_idx: int = 0, + src_tile_col_idx: int = 0, + ) -> str: + dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) + src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) + return self.tile_row_sub_asm(dst_addr, src_addr, rows) + + def tile_row_mul( + self, + dst_matrix: str, + src_matrix: str, + rows: list[int], + dst_tile_row_idx: int = 0, + dst_tile_col_idx: int = 0, + src_tile_row_idx: int = 0, + src_tile_col_idx: int = 0, + ) -> str: + dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) + src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) + return self.tile_row_mul_asm(dst_addr, src_addr, rows) + + def tile_row_mul_fp_broadcast( + self, + matrix_name: str, + fpram_scalar_addr: int, + rows: list[int], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) + return self.tile_row_mul_fp_broadcast_asm(matrix_addr, fpram_scalar_addr, rows) + + def vram_fill_zero( + self, + matrix_name: str, + rows: list[int], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) + return self.vram_fill_zero_asm(matrix_addr, rows) + + # ========================================================================= + # Tile-row ISA helpers (address-based) + # ========================================================================= + + def _arith_progression(self, values: list[int]) -> tuple[int, int, int] | None: + """Return (start, count, step) if values form an arithmetic progression.""" + if not values: + return None + if len(values) == 1: + return (values[0], 1, 0) + step = values[1] - values[0] + for i in range(2, len(values)): + if values[i] - values[i - 1] != step: + return None + if step == 0: + return None # Constant sequence (step=0, count>1) would cause infinite HW loop + return (values[0], len(values), step) + + def _row_progression(self, rows: list[int]) -> tuple[int, int, int] | None: + return None if self._unroll else self._arith_progression(rows) + + def _emit_tile_row_reduce( + self, + label: str, + source_vram_addr: int, + row_map: list[tuple[int, int]], + opcode: str, + opcode_extra_args: tuple[int, ...] = (), + clear_accumulator: bool = False, + ) -> str: + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"Tile Row {label} from VRAM[{source_vram_addr}]") + rows = [row for row, _ in row_map] + fp_addrs = [addr_ for _, addr_ in row_map] + row_prog = self._row_progression(rows) + fp_prog = self._row_progression(fp_addrs) + + if row_prog is not None and fp_prog is not None: + row_start, row_count, row_step = row_prog + fp_start, _, fp_step = fp_prog + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), source_vram_addr + row_start * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), fp_start) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + if clear_accumulator: + asm.instr("S_ADD_FP", fp(1), fp(0), fp(0)) + asm.instr(opcode, fp(1), gp(gp_src), 0, *opcode_extra_args) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), row_step * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), fp_step) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx, fpram_addr in row_map: + row_addr = source_vram_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), row_addr) + if clear_accumulator: + asm.instr("S_ADD_FP", fp(1), fp(0), fp(0)) + asm.instr(opcode, fp(1), gp(gp_src), 0, *opcode_extra_args) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), fpram_addr) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def _emit_tile_row_unary(self, label: str, opcode: str, vram_addr: int, rows: list[int]) -> str: + gp_regs = self._reg.allocate_gp(2) + gp_src, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"Tile Row {label} on VRAM[{vram_addr}]") + prog = self._row_progression(rows) + + if prog is not None: + row_start, row_count, row_step = prog + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), vram_addr + row_start * self.mlen) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + asm.instr(opcode, gp(gp_src), gp(gp_src), 0) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), row_step * self.mlen) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx in rows: + row_addr = vram_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), row_addr) + asm.instr(opcode, gp(gp_src), gp(gp_src), 0) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def _emit_tile_row_fp_scalar( + self, + label: str, + opcode: str, + vram_addr: int, + row_map: list[tuple[int, int]], + opcode_extra_args: tuple[int, ...] = (), + ) -> str: + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_fp, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"Tile Row {label} FP on VRAM[{vram_addr}]") + rows = [row for row, _ in row_map] + fp_addrs = [addr_ for _, addr_ in row_map] + row_prog = self._row_progression(rows) + fp_prog = self._row_progression(fp_addrs) + + if row_prog is not None and fp_prog is not None: + row_start, row_count, row_step = row_prog + fp_start, _, fp_step = fp_prog + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), vram_addr + row_start * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_fp), gp(0), fp_start) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + asm.instr("S_LD_FP", fp(1), gp(gp_fp), 0) + asm.instr(opcode, gp(gp_src), gp(gp_src), fp(1), *opcode_extra_args) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), row_step * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_fp), gp(gp_fp), fp_step) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx, fpram_addr in row_map: + row_addr = vram_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), row_addr) + asm.instr("S_ADDI_INT", gp(gp_fp), gp(0), fpram_addr) + asm.instr("S_LD_FP", fp(1), gp(gp_fp), 0) + asm.instr(opcode, gp(gp_src), gp(gp_src), fp(1), *opcode_extra_args) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def _emit_tile_row_vector_op( + self, + label: str, + opcode: str, + dst_addr: int, + src_addr: int, + rows: list[int], + ) -> str: + gp_regs = self._reg.allocate_gp(3) + gp_dst, gp_src, gp_loop = gp_regs + try: + assignment_op = {"Add": "+", "Sub": "-", "Mul": "*"}.get(label, label) + asm = IsaBuilder().comment(f"Tile Row {label}: VRAM[{dst_addr}] {assignment_op}= VRAM[{src_addr}]") + prog = self._row_progression(rows) + + if prog is not None: + row_start, row_count, row_step = prog + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr + row_start * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr + row_start * self.mlen) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + asm.instr(opcode, gp(gp_dst), gp(gp_dst), gp(gp_src), 0) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), row_step * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), row_step * self.mlen) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx in rows: + dst_row_addr = dst_addr + row_idx * self.mlen + src_row_addr = src_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_row_addr) + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_row_addr) + asm.instr(opcode, gp(gp_dst), gp(gp_dst), gp(gp_src), 0) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def tile_row_max_asm(self, source_vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_reduce("Max", source_vram_addr, row_map, "V_RED_MAX") + + def tile_row_sum_asm(self, source_vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_reduce( + "Sum", + source_vram_addr, + row_map, + "V_RED_SUM", + opcode_extra_args=(0,), + clear_accumulator=True, + ) + + def tile_row_exp_asm(self, vram_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_unary("Exp", "V_EXP_V", vram_addr, rows) + + def tile_row_reci_asm(self, vram_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_unary("Reciprocal", "V_RECI_V", vram_addr, rows) + + def tile_row_sub_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_fp_scalar("Sub", "V_SUB_VF", vram_addr, row_map, opcode_extra_args=(0, 0)) + + def tile_row_mul_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_fp_scalar("Mul", "V_MUL_VF", vram_addr, row_map, opcode_extra_args=(0,)) + + def tile_row_add_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_fp_scalar("Add", "V_ADD_VF", vram_addr, row_map, opcode_extra_args=(0,)) + + def tile_row_add_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_vector_op("Add", "V_ADD_VV", dst_addr, src_addr, rows) + + def tile_row_sub_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_vector_op("Sub", "V_SUB_VV", dst_addr, src_addr, rows) + + def tile_row_mul_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_vector_op("Mul", "V_MUL_VV", dst_addr, src_addr, rows) + + def tile_row_mul_fp_broadcast_asm(self, vram_addr: int, fpram_scalar_addr: int, rows: list[int]) -> str: + row_map = [(r, fpram_scalar_addr) for r in rows] + return self.tile_row_mul_fp_asm(vram_addr, row_map) + + def vram_fill_zero_asm( + self, + vram_addr: int, + rows: list[int], + ) -> str: + """ + VRAM Fill Zero: fill specified rows with 0. + + For each row_idx in rows: + VRAM[row] = 0 + """ + if not rows: + return self._emit(IsaBuilder().comment(f"=== VRAM Fill Zero: VRAM[{vram_addr}] rows [] = 0 ===")) + + gp_regs = self._reg.allocate_gp(2) + gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"=== VRAM Fill Zero: VRAM[{vram_addr}] rows {rows} = 0 ===") + prog = self._row_progression(rows) + + if prog is not None: + row_start, row_count, row_step = prog + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), vram_addr + row_start * self.mlen) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + asm.instr("V_MUL_VF", gp(gp_dst), gp(gp_dst), fp(0), 0) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), row_step * self.mlen) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx in rows: + row_addr = vram_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), row_addr) + asm.instr("V_MUL_VF", gp(gp_dst), gp(gp_dst), fp(0), 0) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + +__all__ = ["IsaTileRowMixin"] diff --git a/aten/plena/memory.py b/aten/plena/memory.py new file mode 100644 index 0000000..5d51187 --- /dev/null +++ b/aten/plena/memory.py @@ -0,0 +1,674 @@ +"""Memory layouts and allocators for the ATen PLENA compiler path.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import math + +from compiler.aten.plena.constants import MLEN + + +# ============================================================================== +# Virtual Memory Manager +# ============================================================================== + + +@dataclass +class MemoryBlock: + """Memory Block Information""" + + name: str # Allocation name (e.g., "W[0][1]" or "activation_A") + addr: int # Starting address + size: int # Block size (number of elements) + + def __repr__(self) -> str: + return f"MemBlock({self.name}, addr={self.addr}, size={self.size})" + + +class VirtualMemoryManager: + """ + Virtual Memory Manager + + Core Design: + - used_stack: Allocated and in-use memory blocks + - free_stack: Freed and reusable memory blocks + + Workflow: + 1. allocate(name, size): Allocate memory + - Prioritize best-fit search for reusable blocks in free_stack + - If not found, use bump allocation from the end + 2. free(name): Free memory + - Move block from used_stack to free_stack + - Address can be reused by subsequent allocate calls + - Throws exception if not found when strict=True, returns None when strict=False + + VRAM/MRAM storage format: (batch_size, mlen, hidden_size/mlen), column-block major. + mlen=64, blen=4. Alignment depends on storage hierarchy. + """ + + def __init__(self, total_size: int, alignment: int = MLEN, mem_name: str = "Memory"): + """ + Args: + total_size: Total memory size (number of elements) + alignment: alignment granularity (VRAM uses MLEN=64, MRAM uses MLEN*MLEN=4096) + mem_name: Memory name, for debugging information (e.g., "VRAM" or "MRAM") + """ + self.total_size = total_size + self.alignment = alignment + self.mem_name = mem_name + self.next_bump = 0 # Bump allocation pointer + + # Two core stacks + self.used_stack: list[MemoryBlock] = [] + self.free_stack: list[MemoryBlock] = [] + + def _align(self, value: int) -> int: + """Align value to alignment""" + return ((value + self.alignment - 1) // self.alignment) * self.alignment + + def allocate(self, name: str, size: int) -> int: + """ + Allocate memory. + + Strategy: + 1. Best-fit from free_stack (reusable block with least waste). + 2. If no suitable block, bump allocation. + + Returns: + Allocated starting address. + + Raises: + MemoryError: Insufficient memory. + """ + aligned_size = self._align(size) + + # Strategy 1: best-fit from free_stack + best_idx = None + best_waste = float("inf") + + for i, block in enumerate(self.free_stack): + if block.size >= aligned_size: + waste = block.size - aligned_size + if waste < best_waste: + best_waste = waste + best_idx = i + + if best_idx is not None: + reused_block = self.free_stack.pop(best_idx) + + # If block is larger than needed, split remaining part and return to free_stack + if reused_block.size > aligned_size: + remaining = MemoryBlock( + name="", addr=reused_block.addr + aligned_size, size=reused_block.size - aligned_size + ) + self.free_stack.append(remaining) + + new_block = MemoryBlock(name=name, addr=reused_block.addr, size=aligned_size) + self.used_stack.append(new_block) + return new_block.addr + + # Strategy 2: Bump allocation + aligned_addr = self._align(self.next_bump) + + if self.total_size > 0 and aligned_addr + aligned_size > self.total_size: + raise MemoryError( + f"{self.mem_name} overflow: need {aligned_size} at addr {aligned_addr}, " + f"total_size={self.total_size}, " + f"used={len(self.used_stack)} blocks, " + f"free={len(self.free_stack)} blocks" + ) + + new_block = MemoryBlock(name=name, addr=aligned_addr, size=aligned_size) + self.used_stack.append(new_block) + self.next_bump = aligned_addr + aligned_size + return aligned_addr + + def free(self, name: str, strict: bool = True) -> MemoryBlock | None: + """ + Free memory: move block from used_stack to free_stack. + + Args: + name: Name of allocation to free + strict: Throws KeyError if not found when strict=True, returns None when strict=False + + Returns: + Freed memory block, returns None if strict=False and not found + """ + for i, block in enumerate(self.used_stack): + if block.name == name: + freed = self.used_stack.pop(i) + self.free_stack.append(freed) + self._coalesce_free_stack() + return freed + + if strict: + raise KeyError( + f"{self.mem_name}: allocation '{name}' not found in used_stack. " + f"Current used: {[b.name for b in self.used_stack]}" + ) + return None + + def mark_used(self, addr: int, size: int, name: str) -> None: + """ + Register a pre-known address range as occupied without bump allocation. + + Used for prestaged VRAM tensors that are already in VRAM before program + execution (e.g. Q pre-loaded by the test harness). Advances next_bump + past the end of this region so subsequent bump allocations do not collide. + + Args: + addr: Start address of the pre-occupied region. + size: Number of elements in the region. + name: Name for tracking/free. + """ + aligned_size = self._align(size) + block = MemoryBlock(name=name, addr=addr, size=aligned_size) + self.used_stack.append(block) + # Advance bump pointer past this region if it would otherwise overlap. + end = addr + aligned_size + if self.next_bump < end: + self.next_bump = end + + def _coalesce_free_stack(self): + """ + Merge adjacent free blocks by address to reduce long-term fragmentation. + """ + if len(self.free_stack) <= 1: + return + + blocks = sorted(self.free_stack, key=lambda b: b.addr) + merged: list[MemoryBlock] = [blocks[0]] + for block in blocks[1:]: + prev = merged[-1] + if prev.addr + prev.size == block.addr: + merged[-1] = MemoryBlock( + name="", + addr=prev.addr, + size=prev.size + block.size, + ) + else: + merged.append(block) + + self.free_stack = merged + + def is_allocated(self, name: str) -> bool: + """Check if a name is in used_stack""" + return any(b.name == name for b in self.used_stack) + + def get_block(self, name: str) -> MemoryBlock | None: + """Get memory block with specified name from used_stack""" + for block in self.used_stack: + if block.name == name: + return block + return None + + def get_used_size(self) -> int: + """Get total used size""" + return sum(b.size for b in self.used_stack) + + def get_free_size(self) -> int: + """Get total reusable size""" + return sum(b.size for b in self.free_stack) + + def reset(self): + """Reset manager""" + self.next_bump = 0 + self.used_stack.clear() + self.free_stack.clear() + + def print_status(self): + """Print memory status""" + print(f"=== {self.mem_name} Virtual Memory Status ===") + print(f"Total size: {self.total_size}") + print(f"Bump pointer: {self.next_bump}") + print(f"Used blocks ({len(self.used_stack)}):") + for b in self.used_stack: + print(f" {b}") + print(f"Free blocks ({len(self.free_stack)}):") + for b in self.free_stack: + print(f" {b}") + total_used = self.get_used_size() + total_free = self.get_free_size() + if self.total_size > 0: + available = self.total_size - self.next_bump + total_free + print( + f"Summary: used={total_used}, free={total_free}, " + f"bump={self.next_bump}, available={available}/{self.total_size}" + ) + else: + print(f"Summary: used={total_used}, free={total_free}, bump={self.next_bump} (unlimited mode)") + + def __repr__(self) -> str: + return ( + f"VirtualMemoryManager({self.mem_name}, " + f"used={len(self.used_stack)}, free={len(self.free_stack)}, " + f"bump={self.next_bump}/{self.total_size})" + ) + + +# ============================================================================== +# Sub-matrix Information +# ============================================================================== + + +@dataclass +class SubMatrixInfo: + """Metadata for sub-matrices""" + + parent_name: str # Parent matrix name + row_idx: int # Sub-block row index + col_idx: int # Sub-block column index + shape: tuple[int, int] # Sub-block shape (typically 64x64) + + # Pre-calculated addresses (computed during compiler phase, used directly at runtime) + hbm_offset: int = 0 # Offset in HBM (in elements) + mram_addr: int | None = None # Address in MRAM (if loaded) + + def __repr__(self) -> str: + mram_str = f"{self.mram_addr}" if self.mram_addr is not None else "None" + return ( + f"SubMatrix({self.parent_name}[{self.row_idx}][{self.col_idx}], " + f"shape={self.shape}, hbm_off={self.hbm_offset}, mram={mram_str})" + ) + + +@dataclass +class MatrixBlockLayout: + """ + Block layout information for large matrices. + + HBM storage: [rows, cols] row-major contiguous, stride=cols per row. + MRAM storage: [batch, mlen, hidden/mlen] column-block major. + """ + + name: str + full_shape: tuple[int, int] # Full matrix shape (rows, cols) + block_size: int = MLEN # Sub-block size (default 64) + + num_row_blocks: int = 0 + num_col_blocks: int = 0 + + # HBM Address Information + hbm_base_addr: int = 0 + hbm_size: int = 0 # Size after applying real_data_ratio (MXFP8 = 1.125) + + # Sub-block map: (row_idx, col_idx) -> SubMatrixInfo + sub_blocks: dict[tuple[int, int], SubMatrixInfo] = field(default_factory=dict) + + def __post_init__(self): + """Initialize block information""" + rows, cols = self.full_shape + self.num_row_blocks = math.ceil(rows / self.block_size) + self.num_col_blocks = math.ceil(cols / self.block_size) + + # Create information for all sub-blocks (pre-calculate addresses) + for r in range(self.num_row_blocks): + for c in range(self.num_col_blocks): + # HBM offset (row-major): sub-block (r,c) starts at r*block_size*cols + c*block_size + hbm_offset = r * self.block_size * cols + c * self.block_size + + sub_info = SubMatrixInfo( + parent_name=self.name, + row_idx=r, + col_idx=c, + shape=(self.block_size, self.block_size), + hbm_offset=hbm_offset, + mram_addr=None, + ) + self.sub_blocks[(r, c)] = sub_info + + def get_sub_block(self, row_idx: int, col_idx: int) -> SubMatrixInfo: + """Get specified sub-block""" + if (row_idx, col_idx) not in self.sub_blocks: + raise IndexError(f"Sub block [{row_idx}][{col_idx}] out of range") + return self.sub_blocks[(row_idx, col_idx)] + + def get_row_blocks(self, row_idx: int) -> list[SubMatrixInfo]: + """Get all sub-blocks in a row""" + return [self.sub_blocks[(row_idx, c)] for c in range(self.num_col_blocks)] + + def get_col_blocks(self, col_idx: int) -> list[SubMatrixInfo]: + """Get all sub-blocks in a column""" + return [self.sub_blocks[(r, col_idx)] for r in range(self.num_row_blocks)] + + +# ============================================================================== +# VRAM Sub-matrix Information +# ============================================================================== + + +@dataclass +class VRAMSubMatrixInfo: + """Metadata for sub-matrices in VRAM""" + + parent_name: str # Parent matrix name + row_idx: int # Sub-block row index (batch dimension) + col_idx: int # Sub-block column index (hidden dimension) + shape: tuple[int, int] # Sub-block shape (typically mlen x mlen) + + # Pre-calculated VRAM address + vram_addr: int = 0 + + def __repr__(self) -> str: + return ( + f"VRAMSubMatrix({self.parent_name}[{self.row_idx}][{self.col_idx}], " + f"shape={self.shape}, vram={self.vram_addr})" + ) + + +@dataclass +class VRAMMatrixBlockLayout: + """ + Block layout for large matrices in VRAM. + + VRAM storage format: [batch, mlen, hidden/mlen] column-block major. + - Batch dimension contiguous within each column block. + - Column blocks laid out sequentially. + + Column block c base address: vram_base + c * batch * mlen. + Sub-block (r, c) offset within column block: r * mlen * mlen. + """ + + name: str + full_shape: tuple[int, int] # Full matrix shape (batch, hidden_size) + vram_base_addr: int # VRAM base address + block_size: int = MLEN # Sub-block size (default 64) + + num_row_blocks: int = 0 # Number of blocks in batch dimension + num_col_blocks: int = 0 # Number of blocks in hidden dimension + + # Sub-block map: (row_idx, col_idx) -> VRAMSubMatrixInfo + sub_blocks: dict[tuple[int, int], VRAMSubMatrixInfo] = field(default_factory=dict) + + def __post_init__(self): + """Initialize block information""" + batch, hidden = self.full_shape + self.num_row_blocks = math.ceil(batch / self.block_size) + self.num_col_blocks = math.ceil(hidden / self.block_size) + + # VRAM column-block major address calculation: + # col block c base = vram_base + c * batch * mlen + # row sub-block r offset within column block = r * mlen * mlen + for r in range(self.num_row_blocks): + for c in range(self.num_col_blocks): + col_block_base = self.vram_base_addr + c * batch * self.block_size + row_offset = r * self.block_size * self.block_size + vram_addr = col_block_base + row_offset + + sub_info = VRAMSubMatrixInfo( + parent_name=self.name, + row_idx=r, + col_idx=c, + shape=(self.block_size, self.block_size), + vram_addr=vram_addr, + ) + self.sub_blocks[(r, c)] = sub_info + + def get_sub_block(self, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: + """Get specified sub-block""" + if (row_idx, col_idx) not in self.sub_blocks: + raise IndexError(f"VRAM sub block [{row_idx}][{col_idx}] out of range") + return self.sub_blocks[(row_idx, col_idx)] + + def get_row_blocks(self, row_idx: int) -> list[VRAMSubMatrixInfo]: + """Get all sub-blocks in a row (A[row_idx][:])""" + return [self.sub_blocks[(row_idx, c)] for c in range(self.num_col_blocks)] + + def get_col_blocks(self, col_idx: int) -> list[VRAMSubMatrixInfo]: + """Get all sub-blocks in a column""" + return [self.sub_blocks[(r, col_idx)] for r in range(self.num_row_blocks)] + + def get_row_vram_addrs(self, row_idx: int) -> list[int]: + """Get list of VRAM addresses for all sub-blocks in a row""" + return [block.vram_addr for block in self.get_row_blocks(row_idx)] + + +# ============================================================================== +# Unified Memory Object Metadata +# ============================================================================== + + +@dataclass +class MemoryObjectInfo: + """Unified metadata for objects managed across HBM / VRAM / FPRAM.""" + + name: str + kind: str + dtype: str = "fp16" + shape: tuple[int, int] = (0, 0) + size: int = 0 + hbm_addr: int = -1 + hbm_size: int = 0 + vram_addr: int | None = None + fpram_addr: int | None = None + fpram_size: int = 0 + + +@dataclass +class FPRAMObjectLayout: + """FPRAM object layout.""" + + name: str + fpram_addr: int + size: int + dtype: str = "fp16" + kind: str = "FPRAMObject" + + +# ============================================================================== +# MRAM Allocator +# ============================================================================== + + +class MRAMAllocator: + """ + Matrix RAM address allocator (based on VirtualMemoryManager). + + Each sub-block is mlen x mlen = 4096 elements; aligned to mlen*mlen. + Default total_size=16384 holds 4 sub-blocks (MAX_K_TILES=4). + Supports virtual free/reuse: freed blocks move to free_stack and are + preferred by subsequent allocations. + """ + + def __init__(self, total_size: int = MLEN * MLEN * 4): + """ + Args: + total_size: Total MRAM size (default 16384, can hold 4 64x64 matrix blocks) + """ + self.total_size = total_size + self._vmm = VirtualMemoryManager( + total_size=total_size, + alignment=MLEN * MLEN, # aligned to one sub-block size + mem_name="MRAM", + ) + + @property + def next_free(self) -> int: + return self._vmm.next_bump + + @property + def used_stack(self) -> list[MemoryBlock]: + return self._vmm.used_stack + + @property + def free_stack(self) -> list[MemoryBlock]: + return self._vmm.free_stack + + def allocate(self, name: str, size: int) -> int: + """Allocate MRAM space (prioritize reusing freed blocks).""" + return self._vmm.allocate(name, size) + + def free(self, name: str, strict: bool = True) -> MemoryBlock | None: + """Free specified allocation: move from used_stack to free_stack.""" + return self._vmm.free(name, strict=strict) + + def is_allocated(self, name: str) -> bool: + """Check if a name is allocated""" + return self._vmm.is_allocated(name) + + def reset(self): + """Reset allocator""" + self._vmm.reset() + + def print_status(self): + """Print memory status""" + self._vmm.print_status() + + +class VRAMAllocator: + """ + VRAM address allocator (based on VirtualMemoryManager). + + VRAM supports best-fit reuse + bump allocation, same as MRAM/FPRAM allocators. + Alignment defaults to MLEN to match VRAM storage format requirements. + """ + + def __init__(self, alignment: int = MLEN, total_size: int = 0): + self.alignment = alignment + self._vmm = VirtualMemoryManager(total_size=total_size, alignment=alignment, mem_name="VRAM") + + @property + def next_free(self) -> int: + return self._vmm.next_bump + + @next_free.setter + def next_free(self, value: int): + self._vmm.next_bump = value + + @property + def used_stack(self) -> list[MemoryBlock]: + return self._vmm.used_stack + + @property + def free_stack(self) -> list[MemoryBlock]: + return self._vmm.free_stack + + def allocate(self, size: int, name: str = "") -> int: + if not name: + raise ValueError("VRAMAllocator.allocate() requires name for subsequent free.") + return self._vmm.allocate(name, size) + + def free(self, name: str, strict: bool = True) -> MemoryBlock | None: + return self._vmm.free(name, strict=strict) + + def is_allocated(self, name: str) -> bool: + return self._vmm.is_allocated(name) + + def reset(self): + self._vmm.reset() + + def print_status(self): + self._vmm.print_status() + + def __repr__(self) -> str: + return ( + f"VRAMAllocator(next_free={self.next_free}, alignment={self.alignment}, " + f"used={len(self.used_stack)}, free={len(self.free_stack)})" + ) + + +class FPRAMAllocator: + """ + Floating Point RAM Allocator (based on VirtualMemoryManager). + + FPRAM stores scalar FP values (f16), accessed via S_LD_FP / S_ST_FP. + Uses the same strategy as VRAM/MRAM allocator: + - used_stack + free_stack + - allocate: best-fit reuse first, then bump + - free: move block to free_stack, supports out-of-order free + + Fixed slot conventions (must not be overwritten by dynamic allocations): + slot 0 = 0.0 (gp0/f0 reserved as hardware zero) + slot 1 = attn_scale + slot 2 = -inf (online softmax) + slot 3 = eps (rms_norm/layer_norm) + slot 4 = 1/hidden_size (rms_norm/layer_norm) + slot 5 = 1.0 (FFN SiLU sigmoid denominator; im2col fp_one_reg) + + Hardware: 1024 f16 elements (configurable via total_size). + """ + + def __init__(self, total_size: int = 1024): + """ + Args: + total_size: Total FP RAM size (default 1024, matching hardware fpsram) + """ + self.total_size = total_size + self._vmm = VirtualMemoryManager( + total_size=total_size, + alignment=1, + mem_name="FPRAM", + ) + self.allocations: dict[str, tuple[int, int]] = {} + self._snapshots: dict[int, tuple[int, list[MemoryBlock], list[MemoryBlock], dict[str, tuple[int, int]]]] = {} + self._next_snapshot_id = 1 + + @property + def next_free(self) -> int: + """Compatibility alias: next bump pointer.""" + return self._vmm.next_bump + + @next_free.setter + def next_free(self, value: int): + if value < 0 or value > self.total_size: + raise ValueError(f"next_free out of range: {value}, expected [0, {self.total_size}]") + self._vmm.next_bump = value + + @property + def used_stack(self) -> list[MemoryBlock]: + return self._vmm.used_stack + + @property + def free_stack(self) -> list[MemoryBlock]: + return self._vmm.free_stack + + def allocate(self, name: str, size: int) -> int: + """Allocate FP RAM space (best-fit + bump).""" + if size <= 0: + raise ValueError(f"FPRAM allocation size must be > 0, got {size}") + if name in self.allocations: + raise KeyError(f"FPRAM name '{name}' already allocated") + + addr = self._vmm.allocate(name, size) + self.allocations[name] = (addr, size) + return addr + + def free(self, name: str, strict: bool = True) -> MemoryBlock | None: + """Free a block and move it to free_stack (same as VirtualMemoryManager).""" + freed = self._vmm.free(name, strict=strict) + if freed is not None: + self.allocations.pop(name, None) + return freed + + def save_state(self) -> int: + """ + Save current allocator state and return a snapshot token. + """ + sid = self._next_snapshot_id + self._next_snapshot_id += 1 + self._snapshots[sid] = ( + self._vmm.next_bump, + [MemoryBlock(b.name, b.addr, b.size) for b in self._vmm.used_stack], + [MemoryBlock(b.name, b.addr, b.size) for b in self._vmm.free_stack], + dict(self.allocations), + ) + return sid + + def restore_state(self, snapshot: int): + """Restore allocator state from snapshot token.""" + if snapshot not in self._snapshots: + raise KeyError(f"Unknown FPRAM snapshot id: {snapshot}") + next_bump, used_stack, free_stack, allocations = self._snapshots[snapshot] + self._vmm.next_bump = next_bump + self._vmm.used_stack = [MemoryBlock(b.name, b.addr, b.size) for b in used_stack] + self._vmm.free_stack = [MemoryBlock(b.name, b.addr, b.size) for b in free_stack] + self.allocations = dict(allocations) + + def reset(self): + """Reset allocator""" + self._vmm.reset() + self.allocations.clear() + self._snapshots.clear() + self._next_snapshot_id = 1 + + diff --git a/aten/plena/registers.py b/aten/plena/registers.py new file mode 100644 index 0000000..029dba4 --- /dev/null +++ b/aten/plena/registers.py @@ -0,0 +1,66 @@ +"""Register allocation helpers for the ATen PLENA compiler path.""" + +class RegisterAllocator: + """Register Allocator: Manages address registers and GP registers""" + + def __init__(self, start_gp: int = 1, start_addr: int = 0, start_fp: int = 1): + # HW OPERAND_WIDTH = 4 bits → gp0-gp15; gp0 reserved as constant 0. + self.gp_registers = list(range(start_gp, 16)) + self.addr_registers = list(range(start_addr, 8)) + # f0 reserved as constant 0 (writing to f0 is a no-op for V_RED_MAX/V_RED_SUM). + self.fp_registers = list(range(start_fp, 8)) + self.used_gp = [] + self.used_addr = [] + self.used_fp = [] + + def allocate_gp(self, count: int = 1) -> list[int]: + if len(self.gp_registers) < count: + raise RuntimeError(f"Not enough GP registers available. Need {count}, have {len(self.gp_registers)}") + + allocated = self.gp_registers[:count] + self.gp_registers = self.gp_registers[count:] + self.used_gp.extend(allocated) + return allocated + + def allocate_addr(self, count: int = 1) -> list[int]: + if len(self.addr_registers) < count: + raise RuntimeError(f"Not enough address registers available. Need {count}, have {len(self.addr_registers)}") + + allocated = self.addr_registers[:count] + self.addr_registers = self.addr_registers[count:] + self.used_addr.extend(allocated) + return allocated + + def free_gp(self, registers: list[int]): + for reg in registers: + if reg in self.used_gp: + self.used_gp.remove(reg) + self.gp_registers.append(reg) + self.gp_registers.sort() + + def free_addr(self, registers: list[int]): + for reg in registers: + if reg in self.used_addr: + self.used_addr.remove(reg) + self.addr_registers.append(reg) + self.addr_registers.sort() + + def allocate_fp(self, count: int = 1) -> list[int]: + if len(self.fp_registers) < count: + raise RuntimeError(f"Not enough FP registers available. Need {count}, have {len(self.fp_registers)}") + + # Reverse allocation: prefer high-numbered regs to avoid conflicts with legacy hardcoded forward-allocation. + allocated = list(reversed(self.fp_registers[-count:])) + self.fp_registers = self.fp_registers[:-count] + self.used_fp.extend(allocated) + return allocated + + def free_fp(self, registers: list[int]): + for reg in registers: + if reg in self.used_fp: + self.used_fp.remove(reg) + self.fp_registers.append(reg) + # Keep sorted so allocate_fp's tail-slice continues to return descending IDs. + self.fp_registers.sort() + + diff --git a/aten/plena/tile_compiler.py b/aten/plena/tile_compiler.py new file mode 100644 index 0000000..ad6b0ed --- /dev/null +++ b/aten/plena/tile_compiler.py @@ -0,0 +1,1015 @@ +"""Tile/block lowering helpers for the ATen PLENA compiler.""" + +from __future__ import annotations + +import math + +from compiler.asm_templates.vram_sub_projection_asm import vram_sub_projection_asm_impl +from compiler.aten.isa_builder import IsaBuilder, addr as areg, gp +from compiler.aten.plena.constants import BLEN, MLEN +from compiler.aten.plena.memory import ( + FPRAMAllocator, + FPRAMObjectLayout, + MRAMAllocator, + MatrixBlockLayout, + MemoryObjectInfo, + SubMatrixInfo, + VRAMAllocator, + VRAMMatrixBlockLayout, + VRAMSubMatrixInfo, +) + + +class TileCompiler: + """ + Tile compiler / sub-matrix manager. + + Core Functions: + 1. Register large matrices as blocked matrices. + 2. Support sub-block indexing: matrix[row_idx][col_idx] or matrix[row_idx][:]. + 3. Pre-calculate all addresses at compiler phase. + 4. Generate ISA code for load_sub_matrix and sub_projection. + + Key Constraints: + - Minimum block size is 64x64 (MLEN). + - Matrix must be loaded into MRAM before participating in computation. + - HBM and VRAM/MRAM use different storage formats (row-major vs column-block major). + """ + + def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = False): + self.mlen = mlen + self.blen = blen + self.unroll_loops = unroll_loops + + # Layout tables + self.hbm_matrices: dict[str, MatrixBlockLayout] = {} + self.vram_matrices: dict[str, VRAMMatrixBlockLayout] = {} + self.fpram_matrices: dict[str, FPRAMObjectLayout] = {} + # Memory Allocators + self.vram_allocator = VRAMAllocator() + self.mram_allocator = MRAMAllocator() + self.fpram_allocator = FPRAMAllocator() + + # Currently loaded sub-blocks in MRAM + self.loaded_sub_blocks: dict[str, SubMatrixInfo] = {} + + # Pre-calculated address cache + self._address_cache: dict[str, int] = {} + + def __contains__(self, name: str) -> bool: + return name in self.hbm_matrices or name in self.vram_matrices or name in self.fpram_matrices + + def __getitem__(self, name: str) -> MemoryObjectInfo: + if name not in self: + raise KeyError(f"Object '{name}' not found") + info = MemoryObjectInfo(name=name, kind="Unknown") + hbm_layout = self.hbm_matrices.get(name) + vram_layout = self.vram_matrices.get(name) + fpram_layout = self.fpram_matrices.get(name) + + if hbm_layout is not None: + rows, cols = hbm_layout.full_shape + info.shape = hbm_layout.full_shape + info.size = rows * cols + info.hbm_addr = hbm_layout.hbm_base_addr + info.hbm_size = hbm_layout.hbm_size + info.kind = "Matrix" + + if vram_layout is not None: + rows, cols = vram_layout.full_shape + info.shape = vram_layout.full_shape + info.size = rows * cols + info.vram_addr = vram_layout.vram_base_addr + info.kind = "VRAMMatrix" if hbm_layout is None else "Batch" + + if fpram_layout is not None: + info.shape = (1, fpram_layout.size) + info.size = fpram_layout.size + info.fpram_addr = fpram_layout.fpram_addr + info.fpram_size = fpram_layout.size + info.kind = "FPRAMObject" + + return info + + def get(self, name: str, default: MemoryObjectInfo | None = None) -> MemoryObjectInfo | None: + try: + return self[name] + except KeyError: + return default + + def get_hbm_layout(self, name: str) -> MatrixBlockLayout: + """Read HBM matrix layout by name.""" + if name not in self.hbm_matrices: + raise KeyError(f"HBM matrix '{name}' not found") + return self.hbm_matrices[name] + + def get_vram_layout(self, name: str) -> VRAMMatrixBlockLayout: + """Read VRAM matrix layout by name.""" + if name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{name}' not found") + return self.vram_matrices[name] + + def get_fpram_layout(self, name: str) -> FPRAMObjectLayout: + """Read FPRAM object layout by name.""" + if name not in self.fpram_matrices: + raise KeyError(f"FPRAM object '{name}' not found") + return self.fpram_matrices[name] + + # ========================================================================== + # Unified Object Management APIs + # ========================================================================== + + def add_hbm_object( + self, + name: str, + shape: tuple[int, int], + hbm_addr: int, + dtype: str = "fp16", + kind: str = "HBMObject", + real_data_ratio: float = 1.125, + strict: bool = False, + ) -> MemoryObjectInfo: + del dtype, kind + self.register_matrix( + name=name, + shape=shape, + hbm_base_addr=hbm_addr, + real_data_ratio=real_data_ratio, + strict=strict, + ) + return self[name] + + def free_hbm_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: + if name not in self.hbm_matrices: + if strict: + raise KeyError(f"HBM object '{name}' not found") + return None + info = self[name] + self.hbm_matrices.pop(name, None) + return info + + def add_vram_object( + self, + name: str, + shape: tuple[int, int], + vram_addr: int | None = None, + dtype: str = "fp16", + kind: str = "VRAMObject", + allocate_if_none: bool = True, + strict: bool = True, + ) -> MemoryObjectInfo: + rows, cols = shape + size = rows * cols + if vram_addr is None: + if not allocate_if_none: + raise ValueError("vram_addr is None and allocate_if_none is False") + vram_addr = self.vram_allocator.allocate(size=size, name=name) + del dtype, kind + self.register_vram_matrix( + name=name, + shape=shape, + vram_base_addr=vram_addr, + strict=strict, + ) + return self[name] + + def free_vram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: + if name not in self.vram_matrices: + if strict: + raise KeyError(f"VRAM object '{name}' not found") + return None + info = self[name] + self.vram_allocator.free(name, strict=strict) + self.vram_matrices.pop(name, None) + return info + + def add_fpram_object( + self, + name: str, + size: int, + dtype: str = "fp16", + kind: str = "FPRAMObject", + ) -> MemoryObjectInfo: + fpram_addr = self.fpram_allocator.allocate(name, size) + self.fpram_matrices[name] = FPRAMObjectLayout( + name=name, + fpram_addr=fpram_addr, + size=size, + dtype=dtype, + kind=kind, + ) + return self[name] + + def free_fpram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: + if name not in self.fpram_matrices: + if strict: + raise KeyError(f"FPRAM object '{name}' not found") + return None + info = self[name] + self.fpram_allocator.free(name, strict=strict) + self.fpram_matrices.pop(name, None) + return info + + def register_matrix( + self, + name: str, + shape: tuple[int, int], + hbm_base_addr: int, + real_data_ratio: float = 1.125, + strict: bool = True, + ) -> MatrixBlockLayout: + """ + Register a large matrix and automatically block it. + + Args: + name: matrix name + shape: full shape (rows, cols) + hbm_base_addr: HBM base address + real_data_ratio: HBM storage ratio (MXFP8 = 1.125) + strict: if False, skip mlen-alignment check (for raw HBM access) + + Returns: + MatrixBlockLayout object + """ + rows, cols = shape + + if strict: + if rows % self.mlen != 0: + raise ValueError(f"Matrix rows ({rows}) must be multiple of mlen ({self.mlen})") + if cols % self.mlen != 0: + raise ValueError(f"Matrix cols ({cols}) must be multiple of mlen ({self.mlen})") + + size = rows * cols + hbm_size = int(size * real_data_ratio) + + layout = MatrixBlockLayout( + name=name, full_shape=shape, block_size=self.mlen, hbm_base_addr=hbm_base_addr, hbm_size=hbm_size + ) + + self.hbm_matrices[name] = layout + return layout + + def get_sub_block(self, name: str, row_idx: int, col_idx: int) -> SubMatrixInfo: + """Get sub-block information.""" + if name not in self.hbm_matrices: + raise KeyError(f"Matrix '{name}' not registered") + return self.hbm_matrices[name].get_sub_block(row_idx, col_idx) + + def get_row_blocks(self, name: str, row_idx: int) -> list[SubMatrixInfo]: + """Get all sub-blocks in a row: matrix[row_idx][:]""" + if name not in self.hbm_matrices: + raise KeyError(f"Matrix '{name}' not registered") + return self.hbm_matrices[name].get_row_blocks(row_idx) + + def get_col_blocks(self, name: str, col_idx: int) -> list[SubMatrixInfo]: + """Get all sub-blocks in a column: matrix[:][col_idx]""" + if name not in self.hbm_matrices: + raise KeyError(f"Matrix '{name}' not registered") + return self.hbm_matrices[name].get_col_blocks(col_idx) + + # ========================================================================== + # VRAM Sub-matrix Management + # ========================================================================== + + def register_vram_matrix( + self, + name: str, + shape: tuple[int, int], + vram_base_addr: int, + strict: bool = True, + ) -> VRAMMatrixBlockLayout: + """ + Register a large matrix in VRAM and automatically block it. + + Args: + name: matrix name + shape: full shape (batch, hidden_size) + vram_base_addr: VRAM base address + + Returns: + VRAMMatrixBlockLayout object + """ + batch, hidden = shape + + if strict: + if batch % self.mlen != 0: + raise ValueError(f"VRAM matrix batch ({batch}) must be multiple of mlen ({self.mlen})") + if hidden % self.mlen != 0: + raise ValueError(f"VRAM matrix hidden ({hidden}) must be multiple of mlen ({self.mlen})") + + layout = VRAMMatrixBlockLayout(name=name, full_shape=shape, vram_base_addr=vram_base_addr, block_size=self.mlen) + + self.vram_matrices[name] = layout + return layout + + def get_vram_sub_block(self, name: str, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: + """Get VRAM sub-block information""" + if name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{name}' not registered") + return self.vram_matrices[name].get_sub_block(row_idx, col_idx) + + def get_vram_row_blocks(self, name: str, row_idx: int) -> list[VRAMSubMatrixInfo]: + """Get all sub-blocks in a row of VRAM matrix: matrix[row_idx][:]""" + if name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{name}' not registered") + return self.vram_matrices[name].get_row_blocks(row_idx) + + def get_vram_col_blocks(self, name: str, col_idx: int) -> list[VRAMSubMatrixInfo]: + """Get all sub-blocks in a column of VRAM matrix: matrix[:][col_idx]""" + if name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{name}' not registered") + return self.vram_matrices[name].get_col_blocks(col_idx) + + # ========================================================================== + # Address Calculation (Core - Pre-calculated during compiler phase) + # ========================================================================== + + def compute_hbm_offset(self, name: str, row_idx: int, col_idx: int) -> int: + """ + Compute HBM offset for sub-block (in elements, not bytes). + + HBM row-major: sub-block (r, c) starts at r*block_size*full_cols + c*block_size. + """ + layout = self.hbm_matrices[name] + sub_block = layout.get_sub_block(row_idx, col_idx) + return sub_block.hbm_offset + + def compute_absolute_hbm_addr(self, name: str, row_idx: int, col_idx: int) -> int: + """ + Calculate absolute HBM address of sub-block (in elements). + + Returns: + Absolute HBM address = base + offset + """ + layout = self.hbm_matrices[name] + offset = self.compute_hbm_offset(name, row_idx, col_idx) + return layout.hbm_base_addr + offset + + # ========================================================================== + # ISA Generation: Load Sub Matrix + # ========================================================================== + + def load_sub_matrix_asm( + self, + name: str, + row_idx: int, + col_idx: int, + mram_dest_addr: int, + hbm_addr_reg: int = 1, + gp_regs: list[int] | None = None, + ) -> str: + """ + Generate ISA code for loading sub-matrix from HBM to MRAM. + + HBM is row-major; H_PREFETCH_M loads mlen x mlen blocks into MRAM. + SCALE = full matrix element count; STRIDE = full column width (row stride). + """ + if gp_regs is None: + gp_regs = [1, 2, 3] + + layout = self.hbm_matrices[name] + sub_block = layout.get_sub_block(row_idx, col_idx) + + hbm_offset = sub_block.hbm_offset + sub_block.mram_addr = mram_dest_addr + + asm = IsaBuilder() + asm.comment(f"Load SubMatrix {name}[{row_idx}][{col_idx}] -> MRAM[{mram_dest_addr}]") + asm.comment(f"HBM offset: {hbm_offset} (precomputed)") + + full_size = layout.full_shape[0] * layout.full_shape[1] + full_cols = layout.full_shape[1] + + gp_scale = gp_regs[0] + gp_stride = gp_regs[1] + gp_mram = gp_regs[2] + + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), full_size) + asm.instr("C_SET_SCALE_REG", gp(gp_scale)) + asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), full_cols) + asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) + + asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_dest_addr) + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) + + asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) + + block_key = f"{name}[{row_idx}][{col_idx}]" + self.loaded_sub_blocks[block_key] = sub_block + + return asm.render() + + def load_row_sub_matrices_asm( + self, + name: str, + row_idx: int, + mram_start_addr: int, + hbm_addr_reg: int = 1, + gp_regs: list[int] | None = None, + ) -> str: + """ + Generate ISA code for loading all sub-blocks in a row: matrix[row_idx][:] + + SCALE and STRIDE set once for all sub-blocks in the row. + """ + if gp_regs is None: + gp_regs = [1, 2, 3] + + layout = self.hbm_matrices[name] + num_col_blocks = layout.num_col_blocks + + asm = IsaBuilder() + asm.comment(f"Load SubMatrix Row {name}[{row_idx}][:] -> MRAM[{mram_start_addr}]") + + # Set SCALE and STRIDE once for all sub-blocks + full_size = layout.full_shape[0] * layout.full_shape[1] + full_cols = layout.full_shape[1] + + gp_scale = gp_regs[0] + gp_stride = gp_regs[1] + gp_mram = gp_regs[2] + + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), full_size) + asm.instr("C_SET_SCALE_REG", gp(gp_scale)) + asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), full_cols) + asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) + + mram_addr = mram_start_addr + block_size = self.mlen * self.mlen + + for col_idx in range(num_col_blocks): + sub_block = layout.get_sub_block(row_idx, col_idx) + hbm_offset = sub_block.hbm_offset + + sub_block.mram_addr = mram_addr + + asm.comment(f"SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") + asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_addr) + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) + asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) + + block_key = f"{name}[{row_idx}][{col_idx}]" + self.loaded_sub_blocks[block_key] = sub_block + + mram_addr += block_size + + return asm.render() + + def load_col_sub_matrices_asm( + self, + name: str, + col_idx: int, + mram_start_addr: int, + hbm_addr_reg: int = 1, + gp_regs: list[int] | None = None, + k_block_start: int = 0, + k_block_count: int | None = None, + ) -> str: + """ + Generate ISA code for loading all sub-blocks in a column: matrix[:][col_idx]. + + Used for sub_projection: A @ W[:, col_idx*mlen:(col_idx+1)*mlen]. + k_block_start/k_block_count select a K-split slice of the column. + """ + if gp_regs is None: + gp_regs = [1, 2, 3] + + layout = self.hbm_matrices[name] + num_row_blocks = layout.num_row_blocks + + asm = IsaBuilder() + asm.comment(f"Load SubMatrix Col {name}[:][{col_idx}] -> MRAM[{mram_start_addr}]") + + # Set SCALE and STRIDE once for all sub-blocks + full_size = layout.full_shape[0] * layout.full_shape[1] + full_cols = layout.full_shape[1] + + gp_scale = gp_regs[0] + gp_stride = gp_regs[1] + gp_mram = gp_regs[2] + + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), full_size) + asm.instr("C_SET_SCALE_REG", gp(gp_scale)) + asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), full_cols) + asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) + + mram_addr = mram_start_addr + block_size = self.mlen * self.mlen + + effective_count = k_block_count if k_block_count is not None else num_row_blocks + for row_idx in range(k_block_start, k_block_start + effective_count): + sub_block = layout.get_sub_block(row_idx, col_idx) + hbm_offset = sub_block.hbm_offset + + sub_block.mram_addr = mram_addr + + asm.comment(f"SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") + asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_addr) + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) + asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) + + block_key = f"{name}[{row_idx}][{col_idx}]" + self.loaded_sub_blocks[block_key] = sub_block + + mram_addr += block_size + + return asm.render() + + # ========================================================================== + # ISA Generation: Sub Projection + # ========================================================================== + + def _vram_sub_projection_asm_impl( + self, + header_lines: list[str], + vram_row_start_addr: int, + mram_start_addr: int, + result_vram_addr: int, + full_batch: int, + num_hidden_blocks: int, + mat_col_stride: int, + transposed: bool, + gp_regs: list[int], + caller_name: str, + unroll: bool | None = None, + ) -> str: + """ + Shared implementation kernel for vram_sub_projection_asm and + vram_sub_projection_T_asm. + + Parameters resolved by the caller before this point: + header_lines -- comment lines already assembled by the caller + vram_row_start_addr -- VRAM address of the first activation block + mram_start_addr -- MRAM address of the first weight block + result_vram_addr -- VRAM destination address for the (mlen, mlen) result + full_batch -- full batch dimension of the activation VRAM matrix + num_hidden_blocks -- number of K-blocks to accumulate over + mat_col_stride -- MRAM outer-column stride (blen for M_MM, blen*mlen for M_TMM) + transposed -- True → emit M_TMM with (act, mat) operand order; + False → emit M_MM with (mat, act) operand order + gp_regs -- list of at least 9 GP register indices + caller_name -- used in error messages only + unroll -- override instance unroll_loops flag (None = use instance default) + + Returns: + ISA code string. + """ + do_unroll = self.unroll_loops if unroll is None else unroll + return vram_sub_projection_asm_impl( + mlen=self.mlen, + blen=self.blen, + unroll_loops=do_unroll, + header_lines=header_lines, + vram_row_start_addr=vram_row_start_addr, + mram_start_addr=mram_start_addr, + result_vram_addr=result_vram_addr, + full_batch=full_batch, + num_hidden_blocks=num_hidden_blocks, + mat_col_stride=mat_col_stride, + transposed=transposed, + gp_regs=gp_regs, + caller_name=caller_name, + ) + + def vram_sub_projection_asm( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_col_idx: int, + result_vram_addr: int, + gp_regs: list[int] | None = None, + k_block_start: int = 0, + k_block_count: int | None = None, + unroll: bool | None = None, + ) -> str: + """ + Generate ISA for VRAM sub-block × MRAM sub-matrix multiply. + + Computes: result = VRAM_A[row_idx][:] @ MRAM_W[:][col_idx] + + VRAM_A[row_idx][:] is (mlen, hidden_size) spread across multiple (mlen, mlen) column blocks. + MRAM_W[:][col_idx] is (hidden_size, mlen) already loaded into MRAM. + Result is (mlen, mlen) written to VRAM. + + Loop structure (outer=output cols by blen, middle=output rows by blen, inner=K accumulation). + k_block_start/k_block_count select a K-split chunk. + """ + if gp_regs is None: + gp_regs = [1, 2, 3, 4, 5, 6, 7, 8, 9] + + if vram_mat_name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") + vram_layout = self.vram_matrices[vram_mat_name] + + mram_layout = self.hbm_matrices[mram_mat_name] + + vram_row_blocks = vram_layout.get_row_blocks(vram_row_idx) + mram_col_blocks = mram_layout.get_col_blocks(mram_col_idx) + # K-split: slice to only the loaded k-chunk + if k_block_count is not None: + mram_col_blocks = mram_col_blocks[k_block_start : k_block_start + k_block_count] + + num_hidden_blocks = len(mram_col_blocks) + if num_hidden_blocks != (k_block_count if k_block_count is not None else len(vram_row_blocks)): + raise ValueError( + f"Dimension mismatch: expected {k_block_count or len(vram_row_blocks)} MRAM blocks, " + f"got {num_hidden_blocks}" + ) + + for sub_block in mram_col_blocks: + if sub_block.mram_addr is None: + raise RuntimeError(f"SubBlock {mram_mat_name}[{sub_block.row_idx}][{mram_col_idx}] not loaded to MRAM") + + full_batch = vram_layout.full_shape[0] + vram_row_start_addr = vram_row_blocks[k_block_start].vram_addr + mram_col_start_addr = mram_col_blocks[0].mram_addr + + header_lines = [ + f"; VRAM Sub Projection: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}]", + f"; VRAM A[row_idx][:]: ({self.mlen}, hidden) spread across {num_hidden_blocks} column blocks", + f"; MRAM W[:][col_idx]: (hidden, {self.mlen}) with {num_hidden_blocks} sub-blocks", + f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", + ] + + return self._vram_sub_projection_asm_impl( + header_lines=header_lines, + vram_row_start_addr=vram_row_start_addr, + mram_start_addr=mram_col_start_addr, + result_vram_addr=result_vram_addr, + full_batch=full_batch, + num_hidden_blocks=num_hidden_blocks, + mat_col_stride=self.blen, + transposed=False, + gp_regs=gp_regs, + caller_name="vram_sub_projection_asm", + unroll=unroll, + ) + + def vram_sub_projection_T_asm( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_row_idx: int, + result_vram_addr: int, + gp_regs: list[int] | None = None, + unroll: bool | None = None, + ) -> str: + """ + Generate ISA for VRAM sub-block × MRAM sub-matrix transposed multiply. + + Computes: result = VRAM_A[row_idx][:] @ MRAM_W[row_idx][:]^T + + VRAM_A[row_idx][:] is (mlen, hidden_size). + MRAM_W[row_idx][:] is (mlen, hidden_size); transposed to (hidden_size, mlen). + Result is (mlen, mlen) written to VRAM. + + Uses M_TMM instruction. mat_col_stride = blen*mlen (transposed addressing). + """ + if gp_regs is None: + gp_regs = [1, 2, 3, 4, 5, 6, 7, 8, 9] + + if vram_mat_name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") + vram_layout = self.vram_matrices[vram_mat_name] + + mram_layout = self.hbm_matrices[mram_mat_name] + + vram_row_blocks = vram_layout.get_row_blocks(vram_row_idx) + mram_row_blocks = mram_layout.get_row_blocks(mram_row_idx) + + if len(vram_row_blocks) != len(mram_row_blocks): + raise ValueError( + f"Dimension mismatch: VRAM has {len(vram_row_blocks)} blocks, MRAM has {len(mram_row_blocks)} blocks" + ) + + num_hidden_blocks = len(vram_row_blocks) + + for sub_block in mram_row_blocks: + if sub_block.mram_addr is None: + raise RuntimeError(f"SubBlock {mram_mat_name}[{mram_row_idx}][{sub_block.col_idx}] not loaded to MRAM") + + full_batch = vram_layout.full_shape[0] + vram_row_start_addr = vram_row_blocks[0].vram_addr + mram_row_start_addr = mram_row_blocks[0].mram_addr + # NOTE: For M_TMM (transposed matmul), the MRAM outer-column stride is + # blen * mlen (full sub-block size) because M_TMM reads the weight in + # transposed layout. This differs from the non-transposed path which uses + # blen. This is intentional for the M_TMM addressing contract. + mat_col_stride = self.blen * self.mlen + + header_lines = [ + f"; VRAM Sub Projection T: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T", + f"; VRAM A[row_idx][:]: ({self.mlen}, hidden)", + f"; MRAM W[row_idx][:]^T: (hidden, {self.mlen})", + f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", + ] + + return self._vram_sub_projection_asm_impl( + header_lines=header_lines, + vram_row_start_addr=vram_row_start_addr, + mram_start_addr=mram_row_start_addr, + result_vram_addr=result_vram_addr, + full_batch=full_batch, + num_hidden_blocks=num_hidden_blocks, + mat_col_stride=mat_col_stride, + transposed=True, + gp_regs=gp_regs, + caller_name="vram_sub_projection_T_asm", + unroll=unroll, + ) + + def vram_block_add_asm( + self, + src1_name: str, + src1_row_idx: int, + src1_col_idx: int, + src2_name: str, + src2_row_idx: int, + src2_col_idx: int, + target_name: str, + target_row_idx: int, + target_col_idx: int, + gp_regs: list[int] | None = None, + ) -> str: + """ + Add two mlen x mlen blocks and write to any target block: + target[target_row_idx][target_col_idx] = + src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] + + Source/target can be the same matrix (supports in-place overwrite). + """ + if gp_regs is None: + gp_regs = [1, 2, 3, 4] + if len(gp_regs) < 4: + raise ValueError(f"Need at least 4 GP regs, got {len(gp_regs)}") + + src1_block = self.get_vram_sub_block(src1_name, src1_row_idx, src1_col_idx) + src2_block = self.get_vram_sub_block(src2_name, src2_row_idx, src2_col_idx) + target_block = self.get_vram_sub_block(target_name, target_row_idx, target_col_idx) + + gp_dst = gp_regs[0] + gp_src1 = gp_regs[1] + gp_src2 = gp_regs[2] + gp_loop = gp_regs[3] + + lines = [] + lines.append( + f"; VRAM Block Add: {target_name}[{target_row_idx}][{target_col_idx}] = " + f"{src1_name}[{src1_row_idx}][{src1_col_idx}] + {src2_name}[{src2_row_idx}][{src2_col_idx}]" + ) + + # One V_ADD_VV processes one row (mlen elements). Use C_LOOP to reduce ISA size. + if self.unroll_loops: + for i in range(self.mlen): + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr + i * self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr + i * self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr + i * self.mlen}") + lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") + else: + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr}") + lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr}") + lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {self.mlen}") + lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src1}, gp{gp_src1}, {self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src2}, gp{gp_src2}, {self.mlen}") + lines.append(f"C_LOOP_END gp{gp_loop}") + + return "\n".join(lines) + "\n" + + # ========================================================================== + # High-level Interface: Complete Sub-block Computation + # ========================================================================== + + def compute_sub_matmul( + self, + a_name: str, + a_row_idx: int | slice, + b_name: str, + b_col_idx: int | slice, + result_name: str, + transpose_b: bool = False, + ) -> tuple[str, int]: + """Not yet implemented. Use vram_sub_projection_asm or vram_sub_projection_T_asm.""" + raise NotImplementedError( + "compute_sub_matmul is not yet implemented. Use vram_sub_projection_asm " + "or vram_sub_projection_T_asm for matrix multiplication." + ) + + # ========================================================================== + # Format Conversion: HBM <-> VRAM + # ========================================================================== + + def load_activation_with_format_convert_asm( + self, + name: str, + hbm_base_addr: int, + batch: int, + hidden_size: int, + vram_dest_addr: int, + hbm_addr_reg: int = 0, + gp_regs: list[int] | None = None, + ) -> str: + """ + Load activation from HBM to VRAM with format conversion. + + HBM layout: [batch, hidden_size] row-major (element[b,h] at hbm_base + b*hidden_size + h). + VRAM layout: [batch, mlen, hidden/mlen] column-block major. + element[b,h]: vram_base + (h//mlen)*batch*mlen + b*mlen + (h%mlen). + + H_PREFETCH_V loads mlen elements per call; columns loaded in blocks of mlen. + """ + if gp_regs is None: + gp_regs = [1, 2, 3, 4, 5] + + asm = IsaBuilder() + asm.comment(f"Load Activation with Format Convert: {name}") + asm.comment(f"HBM[{hbm_base_addr}]: [batch={batch}, hidden={hidden_size}] row-major") + asm.comment(f"VRAM[{vram_dest_addr}]: [batch, mlen, hidden/mlen] column-block major") + + gp_hbm_offset = gp_regs[0] + gp_stride = gp_regs[1] + gp_vram = gp_regs[2] + _gp_outer = gp_regs[3] + _gp_inner = gp_regs[4] + + num_col_blocks = hidden_size // self.mlen + preload_len = 4 # load 4 rows per H_PREFETCH_V call + + total_size = batch * hidden_size + asm.instr("S_ADDI_INT", gp(gp_hbm_offset), gp(0), total_size) + asm.instr("C_SET_SCALE_REG", gp(gp_hbm_offset)) + + asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), hidden_size) + asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) + + for col_block in range(num_col_blocks): + asm.comment(f"Column block {col_block}") + + hbm_offset = col_block * self.mlen + vram_addr = vram_dest_addr + col_block * batch * self.mlen + + asm.instr("S_ADDI_INT", gp(gp_hbm_offset), gp(0), hbm_offset) + asm.instr("S_ADDI_INT", gp(gp_vram), gp(0), vram_addr) + + for batch_block in range(math.ceil(batch / preload_len)): + actual_batch_offset = batch_block * preload_len * hidden_size + actual_vram_offset = batch_block * preload_len * self.mlen + + asm.instr("S_ADDI_INT", gp(gp_hbm_offset), gp(0), hbm_offset + actual_batch_offset) + asm.instr("S_ADDI_INT", gp(gp_vram), gp(0), vram_addr + actual_vram_offset) + asm.instr("H_PREFETCH_V", gp(gp_vram), gp(gp_hbm_offset), areg(hbm_addr_reg), 1, 0) + + return asm.render() + + def store_activation_with_format_convert_asm( + self, + name: str, + vram_src_addr: int, + batch: int, + hidden_size: int, + hbm_dest_addr: int, + hbm_addr_reg: int = 0, + gp_regs: list[int] | None = None, + ) -> str: + """ + Store activation from VRAM to HBM with format conversion. + + VRAM layout: [batch, mlen, hidden/mlen] column-block major. + HBM layout: [batch, hidden_size] row-major. + + H_STORE_V stores mlen elements per call; columns stored in blocks of mlen. + """ + if gp_regs is None: + gp_regs = [1, 2, 3, 4, 5] + + asm = IsaBuilder() + asm.comment(f"Store Activation with Format Convert: {name}") + asm.comment(f"VRAM[{vram_src_addr}]: [batch, mlen, hidden/mlen] column-block major") + asm.comment(f"HBM[{hbm_dest_addr}]: [batch={batch}, hidden={hidden_size}] row-major") + + gp_hbm_offset = gp_regs[0] + gp_stride = gp_regs[1] + gp_vram = gp_regs[2] + _gp_outer = gp_regs[3] + _gp_inner = gp_regs[4] + + num_col_blocks = hidden_size // self.mlen + store_amount = 4 # store 4 rows per H_STORE_V call + + asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), hidden_size) + asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) + + for col_block in range(num_col_blocks): + asm.comment(f"Column block {col_block}") + + hbm_offset = col_block * self.mlen + vram_addr = vram_src_addr + col_block * batch * self.mlen + + for batch_block in range(math.ceil(batch / store_amount)): + actual_batch_offset = batch_block * store_amount * hidden_size + actual_vram_offset = batch_block * store_amount * self.mlen + + asm.instr("S_ADDI_INT", gp(gp_hbm_offset), gp(0), hbm_offset + actual_batch_offset) + asm.instr("S_ADDI_INT", gp(gp_vram), gp(0), vram_addr + actual_vram_offset) + asm.instr("H_STORE_V", gp(gp_vram), gp(gp_hbm_offset), areg(hbm_addr_reg), 0) + + return asm.render() + + # ========================================================================== + # Pre-calculated Address Table Generation + # ========================================================================== + + def generate_address_table(self, name: str) -> dict[str, int]: + """Generate complete address table for a matrix (for debugging and verification).""" + if name not in self.hbm_matrices: + raise KeyError(f"Matrix '{name}' not registered") + + layout = self.hbm_matrices[name] + addr_table = {} + + for (r, c), sub_block in layout.sub_blocks.items(): + key = f"{name}[{r}][{c}]" + addr_table[f"{key}_hbm_offset"] = sub_block.hbm_offset + addr_table[f"{key}_hbm_abs"] = layout.hbm_base_addr + sub_block.hbm_offset + if sub_block.mram_addr is not None: + addr_table[f"{key}_mram"] = sub_block.mram_addr + + return addr_table + + def print_address_table(self, name: str): + """Print address table for a matrix.""" + addr_table = self.generate_address_table(name) + print(f"Address Table for {name}:") + for key, addr in sorted(addr_table.items()): + print(f" {key}: {addr}") + + # ========================================================================== + # Helper Methods + # ========================================================================== + + def get_loaded_block_addr(self, name: str, row_idx: int, col_idx: int) -> int: + """Get MRAM address of a loaded sub-block.""" + block_key = f"{name}[{row_idx}][{col_idx}]" + if block_key not in self.loaded_sub_blocks: + raise KeyError(f"SubBlock {block_key} not loaded") + return self.loaded_sub_blocks[block_key].mram_addr + + def is_block_loaded(self, name: str, row_idx: int, col_idx: int) -> bool: + """Check whether a sub-block has been loaded into MRAM.""" + block_key = f"{name}[{row_idx}][{col_idx}]" + return block_key in self.loaded_sub_blocks + + def reset(self): + """Reset manager state.""" + self.hbm_matrices.clear() + self.vram_matrices.clear() + self.fpram_matrices.clear() + self.vram_allocator.reset() + self.mram_allocator.reset() + self.fpram_allocator.reset() + self.loaded_sub_blocks.clear() + self._address_cache.clear() + + def print_table(self): + """Print unified managed object table.""" + print("=" * 95) + print("Managed Object Table") + print("=" * 95) + print( + f"{'Name':<20} {'Kind':<12} {'Shape':<15} {'HBM Addr':<10} {'HBM Size':<10} {'VRAM Addr':<12} {'FPRAM':<12} {'Size':<8}" + ) + print("-" * 95) + names = sorted(set(self.hbm_matrices) | set(self.vram_matrices) | set(self.fpram_matrices)) + for name in names: + info = self[name] + shape_str = f"({info.shape[0]}, {info.shape[1]})" + vram_str = f"{info.vram_addr}" if info.vram_addr is not None else "None" + fpram_str = f"{info.fpram_addr}/{info.fpram_size}" if info.fpram_addr is not None else "None" + print( + f"{name:<20} {info.kind:<12} {shape_str:<15} " + f"{info.hbm_addr:<10} {info.hbm_size:<10} {vram_str:<12} {fpram_str:<12} {info.size:<8}" + ) + print("=" * 95) + + def print_layout(self, name: str): + """Print block layout for a matrix.""" + if name not in self.hbm_matrices: + print(f"Matrix '{name}' not registered") + return + + layout = self.hbm_matrices[name] + print(f"Matrix: {name}") + print(f" Full shape: {layout.full_shape}") + print(f" Block size: {layout.block_size}") + print(f" Blocks: {layout.num_row_blocks} x {layout.num_col_blocks}") + print(f" HBM base: {layout.hbm_base_addr}") + print(" Sub blocks:") + for (r, c), sub in layout.sub_blocks.items(): + loaded = "LOADED" if sub.mram_addr is not None else "" + print(f" [{r}][{c}]: hbm_off={sub.hbm_offset}, mram={sub.mram_addr} {loaded}") + +__all__ = ["TileCompiler"] diff --git a/aten/plena/vars.py b/aten/plena/vars.py new file mode 100644 index 0000000..b513165 --- /dev/null +++ b/aten/plena/vars.py @@ -0,0 +1,163 @@ +"""Tensor proxy classes for the ATen PLENA compiler path.""" + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from compiler.aten.plena_compiler import PlenaCompiler + + +class TensorVar: + """ + Tensor proxy object base class + + All tensor variables inherit from this class. + Supports __matmul__ (`@`) operator, which automatically dispatches to appropriate PlenaCompiler methods. + + Dual naming: + - display_name: User-visible name (e.g., "temp", "Q", "S") + - internal_name: System internal name (e.g., "my_func_0/temp"), used for symbol table and ISA generation + """ + + def __init__( + self, + program: PlenaCompiler, + internal_name: str, + kind: str, + shape: tuple[int, int], + display_name: str | None = None, + ): + self._program = program + self.internal_name = internal_name # System internal name (with scope prefix), used by symbol table + self.display_name = display_name if display_name is not None else internal_name # User-visible name + self.kind = kind # "input", "batch", "matrix", "vram_matrix" + self.shape = shape + + @property + def name(self) -> str: + """Compatibility property: returns internal_name for internal system use""" + return self.internal_name + + def __matmul__(self, other): + """A @ B: Dispatch to appropriate computation based on operand types""" + return self._program._dispatch_matmul(self, other) + + def __repr__(self): + if self.display_name != self.internal_name: + return ( + f"{self.__class__.__name__}(display={self.display_name!r}, " + f"internal={self.internal_name!r}, shape={self.shape})" + ) + return f"{self.__class__.__name__}({self.display_name!r}, shape={self.shape})" + + +class InputVar(TensorVar): + """ + Input variable: tensor declared in HBM + + Not yet loaded to VRAM; needs to be loaded via load_batch / load_matrix. + + If ``prestaged_vram_addr`` is not None the tensor is assumed to be already + present in VRAM at that byte address. ``load_batch`` will register it at + that address without emitting any HBM→VRAM prefetch instructions. + """ + + def __init__( + self, + program: PlenaCompiler, + name: str, + shape: tuple[int, int], + hbm_addr: int, + hbm_size: int, + display_name: str | None = None, + prestaged_vram_addr: int | None = None, + ): + super().__init__(program, name, "input", shape, display_name=display_name) + self.hbm_addr = hbm_addr + self.hbm_size = hbm_size + self.prestaged_vram_addr = prestaged_vram_addr + + +class FPVar: + """ + FP variable: maps to a contiguous region in FPRAM + + Declared via prog.fp_var("scale", size=1), automatically allocates FPRAM space. + Provides .address for ISA generation (S_LD_FP / S_ST_FP). + + Usage: + scale = prog.fp_var("scale", size=1) + m_old = prog.fp_var("m_old", size=64) + + scale.address # -> FPRAM address (int) + scale.size # -> number of elements + scale[3] # -> address + 3 (element offset) + """ + + def __init__( + self, program: PlenaCompiler, internal_name: str, address: int, size: int, display_name: str | None = None + ): + self._program = program + self.internal_name = internal_name + self.display_name = display_name if display_name is not None else internal_name + self.address = address + self.size = size + + @property + def name(self) -> str: + return self.internal_name + + def __getitem__(self, idx: int) -> int: + """Element offset: fp_var[i] -> address + i""" + if idx < 0 or idx >= self.size: + raise IndexError(f"FPVar '{self.display_name}' index {idx} out of range [0, {self.size})") + return self.address + idx + + def __repr__(self): + return f"FPVar({self.display_name!r}, addr={self.address}, size={self.size})" + + +class VRAMMatrixVar(TensorVar): + """ + VRAM matrix variable: large matrix allocated via alloc + + Used to store intermediate results (e.g., S block, PV, O). + Supports sub-block indexed writes: `O[r][c] = ...` + """ + + def __init__(self, program: PlenaCompiler, name: str, shape: tuple[int, int], display_name: str | None = None): + super().__init__(program, name, "vram_matrix", shape, display_name=display_name) + + +class TensorKind(Enum): + """Identifies which memory the tensor lives in / which proxy backs it.""" + + HBM = "hbm" # legacy: InputVar + VRAM = "vram" # legacy: VRAMMatrixVar + FPRAM = "fpram" # legacy: FPVar + + +# Type alias: a "Tensor" is any of the legacy proxy objects. Callers can use +# this annotation without worrying about which specific backing allocator +# a given tensor lives in. +Tensor = TensorVar | InputVar | VRAMMatrixVar | FPVar + + +def tensor_kind(tensor: Tensor) -> TensorKind: + """Return the backing storage of a tensor proxy.""" + if isinstance(tensor, FPVar): + return TensorKind.FPRAM + if isinstance(tensor, VRAMMatrixVar): + return TensorKind.VRAM + if isinstance(tensor, InputVar): + return TensorKind.HBM + if isinstance(tensor, TensorVar): + # Generic TensorVar without a specific backing — classify by ``kind``. + kind = getattr(tensor, "kind", "") + if kind in ("vram_matrix", "batch", "matrix"): + return TensorKind.VRAM + if kind == "input": + return TensorKind.HBM + raise TypeError(f"Unknown tensor type: {type(tensor).__name__}") diff --git a/aten/plena_compiler.py b/aten/plena_compiler.py index 550c872..fa8f0d0 100644 --- a/aten/plena_compiler.py +++ b/aten/plena_compiler.py @@ -1,5739 +1,53 @@ -"""PlenaCompiler -- ATen Pipeline (Pipeline 1) compilation backend. +"""Compatibility facade for the ATen PLENA compiler path. -Manages VRAM/MRAM/FPRAM allocation, HBM weight layout, and address -register initialization (C_SET_ADDR_REG). Produces numerically verified -ISA (98-100% allclose against PyTorch golden reference). - -Contains the full inheritance chain: TileCompiler (memory bookkeeping) -> -DeveloperCompiler (ISA emission, FP/FPRAM ops, interrupts) -> PlenaCompiler -(user-facing DSL). Tensor proxy classes (TensorVar, InputVar, VRAMMatrixVar, -FPVar) and the unified Tensor type union / TensorKind enum are re-exported -from the same module. - -Previously aliased as ``PLENAProgram``; that alias has been retired. - -See docs/COMPILATION_PIPELINES.md for the full architecture overview. -""" - -from __future__ import annotations - -from enum import Enum -from dataclasses import dataclass, field -import math - -import os -from collections.abc import Callable -from functools import wraps - -from compiler.asm_templates import ( - preload_act_asm, - reset_reg_asm, - preload_addr_reg_asm, - store_act_asm, - rms_norm_asm, - layer_norm_asm, - rope_asm, -) -from compiler.asm_templates.vram_sub_projection_asm import vram_sub_projection_asm_impl - - -MLEN = 64 # Minimum matrix block size -BLEN = 4 # Vector tile size -IMM2_BOUND = 2**18 - - -# ============================================================================== -# Virtual Memory Manager -# ============================================================================== - - -@dataclass -class MemoryBlock: - """Memory Block Information""" - - name: str # Allocation name (e.g., "W[0][1]" or "activation_A") - addr: int # Starting address - size: int # Block size (number of elements) - - def __repr__(self) -> str: - return f"MemBlock({self.name}, addr={self.addr}, size={self.size})" - - -class VirtualMemoryManager: - """ - Virtual Memory Manager - - Core Design: - - used_stack: Allocated and in-use memory blocks - - free_stack: Freed and reusable memory blocks - - Workflow: - 1. allocate(name, size): Allocate memory - - Prioritize best-fit search for reusable blocks in free_stack - - If not found, use bump allocation from the end - 2. free(name): Free memory - - Move block from used_stack to free_stack - - Address can be reused by subsequent allocate calls - - Throws exception if not found when strict=True, returns None when strict=False - - VRAM/MRAM storage format: (batch_size, mlen, hidden_size/mlen), column-block major. - mlen=64, blen=4. Alignment depends on storage hierarchy. - """ - - def __init__(self, total_size: int, alignment: int = MLEN, mem_name: str = "Memory"): - """ - Args: - total_size: Total memory size (number of elements) - alignment: alignment granularity (VRAM uses MLEN=64, MRAM uses MLEN*MLEN=4096) - mem_name: Memory name, for debugging information (e.g., "VRAM" or "MRAM") - """ - self.total_size = total_size - self.alignment = alignment - self.mem_name = mem_name - self.next_bump = 0 # Bump allocation pointer - - # Two core stacks - self.used_stack: list[MemoryBlock] = [] - self.free_stack: list[MemoryBlock] = [] - - def _align(self, value: int) -> int: - """Align value to alignment""" - return ((value + self.alignment - 1) // self.alignment) * self.alignment - - def allocate(self, name: str, size: int) -> int: - """ - Allocate memory. - - Strategy: - 1. Best-fit from free_stack (reusable block with least waste). - 2. If no suitable block, bump allocation. - - Returns: - Allocated starting address. - - Raises: - MemoryError: Insufficient memory. - """ - aligned_size = self._align(size) - - # Strategy 1: best-fit from free_stack - best_idx = None - best_waste = float("inf") - - for i, block in enumerate(self.free_stack): - if block.size >= aligned_size: - waste = block.size - aligned_size - if waste < best_waste: - best_waste = waste - best_idx = i - - if best_idx is not None: - reused_block = self.free_stack.pop(best_idx) - - # If block is larger than needed, split remaining part and return to free_stack - if reused_block.size > aligned_size: - remaining = MemoryBlock( - name="", addr=reused_block.addr + aligned_size, size=reused_block.size - aligned_size - ) - self.free_stack.append(remaining) - - new_block = MemoryBlock(name=name, addr=reused_block.addr, size=aligned_size) - self.used_stack.append(new_block) - return new_block.addr - - # Strategy 2: Bump allocation - aligned_addr = self._align(self.next_bump) - - if self.total_size > 0 and aligned_addr + aligned_size > self.total_size: - raise MemoryError( - f"{self.mem_name} overflow: need {aligned_size} at addr {aligned_addr}, " - f"total_size={self.total_size}, " - f"used={len(self.used_stack)} blocks, " - f"free={len(self.free_stack)} blocks" - ) - - new_block = MemoryBlock(name=name, addr=aligned_addr, size=aligned_size) - self.used_stack.append(new_block) - self.next_bump = aligned_addr + aligned_size - return aligned_addr - - def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - """ - Free memory: move block from used_stack to free_stack. - - Args: - name: Name of allocation to free - strict: Throws KeyError if not found when strict=True, returns None when strict=False - - Returns: - Freed memory block, returns None if strict=False and not found - """ - for i, block in enumerate(self.used_stack): - if block.name == name: - freed = self.used_stack.pop(i) - self.free_stack.append(freed) - self._coalesce_free_stack() - return freed - - if strict: - raise KeyError( - f"{self.mem_name}: allocation '{name}' not found in used_stack. " - f"Current used: {[b.name for b in self.used_stack]}" - ) - return None - - def mark_used(self, addr: int, size: int, name: str) -> None: - """ - Register a pre-known address range as occupied without bump allocation. - - Used for prestaged VRAM tensors that are already in VRAM before program - execution (e.g. Q pre-loaded by the test harness). Advances next_bump - past the end of this region so subsequent bump allocations do not collide. - - Args: - addr: Start address of the pre-occupied region. - size: Number of elements in the region. - name: Name for tracking/free. - """ - aligned_size = self._align(size) - block = MemoryBlock(name=name, addr=addr, size=aligned_size) - self.used_stack.append(block) - # Advance bump pointer past this region if it would otherwise overlap. - end = addr + aligned_size - if self.next_bump < end: - self.next_bump = end - - def _coalesce_free_stack(self): - """ - Merge adjacent free blocks by address to reduce long-term fragmentation. - """ - if len(self.free_stack) <= 1: - return - - blocks = sorted(self.free_stack, key=lambda b: b.addr) - merged: list[MemoryBlock] = [blocks[0]] - for block in blocks[1:]: - prev = merged[-1] - if prev.addr + prev.size == block.addr: - merged[-1] = MemoryBlock( - name="", - addr=prev.addr, - size=prev.size + block.size, - ) - else: - merged.append(block) - - self.free_stack = merged - - def is_allocated(self, name: str) -> bool: - """Check if a name is in used_stack""" - return any(b.name == name for b in self.used_stack) - - def get_block(self, name: str) -> MemoryBlock | None: - """Get memory block with specified name from used_stack""" - for block in self.used_stack: - if block.name == name: - return block - return None - - def get_used_size(self) -> int: - """Get total used size""" - return sum(b.size for b in self.used_stack) - - def get_free_size(self) -> int: - """Get total reusable size""" - return sum(b.size for b in self.free_stack) - - def reset(self): - """Reset manager""" - self.next_bump = 0 - self.used_stack.clear() - self.free_stack.clear() - - def print_status(self): - """Print memory status""" - print(f"=== {self.mem_name} Virtual Memory Status ===") - print(f"Total size: {self.total_size}") - print(f"Bump pointer: {self.next_bump}") - print(f"Used blocks ({len(self.used_stack)}):") - for b in self.used_stack: - print(f" {b}") - print(f"Free blocks ({len(self.free_stack)}):") - for b in self.free_stack: - print(f" {b}") - total_used = self.get_used_size() - total_free = self.get_free_size() - if self.total_size > 0: - available = self.total_size - self.next_bump + total_free - print( - f"Summary: used={total_used}, free={total_free}, " - f"bump={self.next_bump}, available={available}/{self.total_size}" - ) - else: - print(f"Summary: used={total_used}, free={total_free}, bump={self.next_bump} (unlimited mode)") - - def __repr__(self) -> str: - return ( - f"VirtualMemoryManager({self.mem_name}, " - f"used={len(self.used_stack)}, free={len(self.free_stack)}, " - f"bump={self.next_bump}/{self.total_size})" - ) - - -# ============================================================================== -# Sub-matrix Information -# ============================================================================== - - -@dataclass -class SubMatrixInfo: - """Metadata for sub-matrices""" - - parent_name: str # Parent matrix name - row_idx: int # Sub-block row index - col_idx: int # Sub-block column index - shape: tuple[int, int] # Sub-block shape (typically 64x64) - - # Pre-calculated addresses (computed during compiler phase, used directly at runtime) - hbm_offset: int = 0 # Offset in HBM (in elements) - mram_addr: int | None = None # Address in MRAM (if loaded) - - def __repr__(self) -> str: - mram_str = f"{self.mram_addr}" if self.mram_addr is not None else "None" - return ( - f"SubMatrix({self.parent_name}[{self.row_idx}][{self.col_idx}], " - f"shape={self.shape}, hbm_off={self.hbm_offset}, mram={mram_str})" - ) - - -@dataclass -class MatrixBlockLayout: - """ - Block layout information for large matrices. - - HBM storage: [rows, cols] row-major contiguous, stride=cols per row. - MRAM storage: [batch, mlen, hidden/mlen] column-block major. - """ - - name: str - full_shape: tuple[int, int] # Full matrix shape (rows, cols) - block_size: int = MLEN # Sub-block size (default 64) - - num_row_blocks: int = 0 - num_col_blocks: int = 0 - - # HBM Address Information - hbm_base_addr: int = 0 - hbm_size: int = 0 # Size after applying real_data_ratio (MXFP8 = 1.125) - - # Sub-block map: (row_idx, col_idx) -> SubMatrixInfo - sub_blocks: dict[tuple[int, int], SubMatrixInfo] = field(default_factory=dict) - - def __post_init__(self): - """Initialize block information""" - rows, cols = self.full_shape - self.num_row_blocks = math.ceil(rows / self.block_size) - self.num_col_blocks = math.ceil(cols / self.block_size) - - # Create information for all sub-blocks (pre-calculate addresses) - for r in range(self.num_row_blocks): - for c in range(self.num_col_blocks): - # HBM offset (row-major): sub-block (r,c) starts at r*block_size*cols + c*block_size - hbm_offset = r * self.block_size * cols + c * self.block_size - - sub_info = SubMatrixInfo( - parent_name=self.name, - row_idx=r, - col_idx=c, - shape=(self.block_size, self.block_size), - hbm_offset=hbm_offset, - mram_addr=None, - ) - self.sub_blocks[(r, c)] = sub_info - - def get_sub_block(self, row_idx: int, col_idx: int) -> SubMatrixInfo: - """Get specified sub-block""" - if (row_idx, col_idx) not in self.sub_blocks: - raise IndexError(f"Sub block [{row_idx}][{col_idx}] out of range") - return self.sub_blocks[(row_idx, col_idx)] - - def get_row_blocks(self, row_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a row""" - return [self.sub_blocks[(row_idx, c)] for c in range(self.num_col_blocks)] - - def get_col_blocks(self, col_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a column""" - return [self.sub_blocks[(r, col_idx)] for r in range(self.num_row_blocks)] - - -# ============================================================================== -# VRAM Sub-matrix Information -# ============================================================================== - - -@dataclass -class VRAMSubMatrixInfo: - """Metadata for sub-matrices in VRAM""" - - parent_name: str # Parent matrix name - row_idx: int # Sub-block row index (batch dimension) - col_idx: int # Sub-block column index (hidden dimension) - shape: tuple[int, int] # Sub-block shape (typically mlen x mlen) - - # Pre-calculated VRAM address - vram_addr: int = 0 - - def __repr__(self) -> str: - return ( - f"VRAMSubMatrix({self.parent_name}[{self.row_idx}][{self.col_idx}], " - f"shape={self.shape}, vram={self.vram_addr})" - ) - - -@dataclass -class VRAMMatrixBlockLayout: - """ - Block layout for large matrices in VRAM. - - VRAM storage format: [batch, mlen, hidden/mlen] column-block major. - - Batch dimension contiguous within each column block. - - Column blocks laid out sequentially. - - Column block c base address: vram_base + c * batch * mlen. - Sub-block (r, c) offset within column block: r * mlen * mlen. - """ - - name: str - full_shape: tuple[int, int] # Full matrix shape (batch, hidden_size) - vram_base_addr: int # VRAM base address - block_size: int = MLEN # Sub-block size (default 64) - - num_row_blocks: int = 0 # Number of blocks in batch dimension - num_col_blocks: int = 0 # Number of blocks in hidden dimension - - # Sub-block map: (row_idx, col_idx) -> VRAMSubMatrixInfo - sub_blocks: dict[tuple[int, int], VRAMSubMatrixInfo] = field(default_factory=dict) - - def __post_init__(self): - """Initialize block information""" - batch, hidden = self.full_shape - self.num_row_blocks = math.ceil(batch / self.block_size) - self.num_col_blocks = math.ceil(hidden / self.block_size) - - # VRAM column-block major address calculation: - # col block c base = vram_base + c * batch * mlen - # row sub-block r offset within column block = r * mlen * mlen - for r in range(self.num_row_blocks): - for c in range(self.num_col_blocks): - col_block_base = self.vram_base_addr + c * batch * self.block_size - row_offset = r * self.block_size * self.block_size - vram_addr = col_block_base + row_offset - - sub_info = VRAMSubMatrixInfo( - parent_name=self.name, - row_idx=r, - col_idx=c, - shape=(self.block_size, self.block_size), - vram_addr=vram_addr, - ) - self.sub_blocks[(r, c)] = sub_info - - def get_sub_block(self, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: - """Get specified sub-block""" - if (row_idx, col_idx) not in self.sub_blocks: - raise IndexError(f"VRAM sub block [{row_idx}][{col_idx}] out of range") - return self.sub_blocks[(row_idx, col_idx)] - - def get_row_blocks(self, row_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a row (A[row_idx][:])""" - return [self.sub_blocks[(row_idx, c)] for c in range(self.num_col_blocks)] - - def get_col_blocks(self, col_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a column""" - return [self.sub_blocks[(r, col_idx)] for r in range(self.num_row_blocks)] - - def get_row_vram_addrs(self, row_idx: int) -> list[int]: - """Get list of VRAM addresses for all sub-blocks in a row""" - return [block.vram_addr for block in self.get_row_blocks(row_idx)] - - -# ============================================================================== -# Unified Memory Object Metadata -# ============================================================================== - - -@dataclass -class MemoryObjectInfo: - """Unified metadata for objects managed across HBM / VRAM / FPRAM.""" - - name: str - kind: str - dtype: str = "fp16" - shape: tuple[int, int] = (0, 0) - size: int = 0 - hbm_addr: int = -1 - hbm_size: int = 0 - vram_addr: int | None = None - fpram_addr: int | None = None - fpram_size: int = 0 - - -@dataclass -class FPRAMObjectLayout: - """FPRAM object layout.""" - - name: str - fpram_addr: int - size: int - dtype: str = "fp16" - kind: str = "FPRAMObject" - - -# ============================================================================== -# MRAM Allocator -# ============================================================================== - - -class MRAMAllocator: - """ - Matrix RAM address allocator (based on VirtualMemoryManager). - - Each sub-block is mlen x mlen = 4096 elements; aligned to mlen*mlen. - Default total_size=16384 holds 4 sub-blocks (MAX_K_TILES=4). - Supports virtual free/reuse: freed blocks move to free_stack and are - preferred by subsequent allocations. - """ - - def __init__(self, total_size: int = MLEN * MLEN * 4): - """ - Args: - total_size: Total MRAM size (default 16384, can hold 4 64x64 matrix blocks) - """ - self.total_size = total_size - self._vmm = VirtualMemoryManager( - total_size=total_size, - alignment=MLEN * MLEN, # aligned to one sub-block size - mem_name="MRAM", - ) - - @property - def next_free(self) -> int: - return self._vmm.next_bump - - @property - def used_stack(self) -> list[MemoryBlock]: - return self._vmm.used_stack - - @property - def free_stack(self) -> list[MemoryBlock]: - return self._vmm.free_stack - - def allocate(self, name: str, size: int) -> int: - """Allocate MRAM space (prioritize reusing freed blocks).""" - return self._vmm.allocate(name, size) - - def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - """Free specified allocation: move from used_stack to free_stack.""" - return self._vmm.free(name, strict=strict) - - def is_allocated(self, name: str) -> bool: - """Check if a name is allocated""" - return self._vmm.is_allocated(name) - - def reset(self): - """Reset allocator""" - self._vmm.reset() - - def print_status(self): - """Print memory status""" - self._vmm.print_status() - - -class VRAMAllocator: - """ - VRAM address allocator (based on VirtualMemoryManager). - - VRAM supports best-fit reuse + bump allocation, same as MRAM/FPRAM allocators. - Alignment defaults to MLEN to match VRAM storage format requirements. - """ - - def __init__(self, alignment: int = MLEN, total_size: int = 0): - self.alignment = alignment - self._vmm = VirtualMemoryManager(total_size=total_size, alignment=alignment, mem_name="VRAM") - - @property - def next_free(self) -> int: - return self._vmm.next_bump - - @next_free.setter - def next_free(self, value: int): - self._vmm.next_bump = value - - @property - def used_stack(self) -> list[MemoryBlock]: - return self._vmm.used_stack - - @property - def free_stack(self) -> list[MemoryBlock]: - return self._vmm.free_stack - - def allocate(self, size: int, name: str = "") -> int: - if not name: - raise ValueError("VRAMAllocator.allocate() requires name for subsequent free.") - return self._vmm.allocate(name, size) - - def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - return self._vmm.free(name, strict=strict) - - def is_allocated(self, name: str) -> bool: - return self._vmm.is_allocated(name) - - def reset(self): - self._vmm.reset() - - def print_status(self): - self._vmm.print_status() - - def __repr__(self) -> str: - return ( - f"VRAMAllocator(next_free={self.next_free}, alignment={self.alignment}, " - f"used={len(self.used_stack)}, free={len(self.free_stack)})" - ) - - -class FPRAMAllocator: - """ - Floating Point RAM Allocator (based on VirtualMemoryManager). - - FPRAM stores scalar FP values (f16), accessed via S_LD_FP / S_ST_FP. - Uses the same strategy as VRAM/MRAM allocator: - - used_stack + free_stack - - allocate: best-fit reuse first, then bump - - free: move block to free_stack, supports out-of-order free - - Fixed slot conventions (must not be overwritten by dynamic allocations): - slot 0 = 0.0 (gp0/f0 reserved as hardware zero) - slot 1 = attn_scale - slot 2 = -inf (online softmax) - slot 3 = eps (rms_norm/layer_norm) - slot 4 = 1/hidden_size (rms_norm/layer_norm) - slot 5 = 1.0 (FFN SiLU sigmoid denominator; im2col fp_one_reg) - - Hardware: 1024 f16 elements (configurable via total_size). - """ - - def __init__(self, total_size: int = 1024): - """ - Args: - total_size: Total FP RAM size (default 1024, matching hardware fpsram) - """ - self.total_size = total_size - self._vmm = VirtualMemoryManager( - total_size=total_size, - alignment=1, - mem_name="FPRAM", - ) - self.allocations: dict[str, tuple[int, int]] = {} - self._snapshots: dict[int, tuple[int, list[MemoryBlock], list[MemoryBlock], dict[str, tuple[int, int]]]] = {} - self._next_snapshot_id = 1 - - @property - def next_free(self) -> int: - """Compatibility alias: next bump pointer.""" - return self._vmm.next_bump - - @next_free.setter - def next_free(self, value: int): - if value < 0 or value > self.total_size: - raise ValueError(f"next_free out of range: {value}, expected [0, {self.total_size}]") - self._vmm.next_bump = value - - @property - def used_stack(self) -> list[MemoryBlock]: - return self._vmm.used_stack - - @property - def free_stack(self) -> list[MemoryBlock]: - return self._vmm.free_stack - - def allocate(self, name: str, size: int) -> int: - """Allocate FP RAM space (best-fit + bump).""" - if size <= 0: - raise ValueError(f"FPRAM allocation size must be > 0, got {size}") - if name in self.allocations: - raise KeyError(f"FPRAM name '{name}' already allocated") - - addr = self._vmm.allocate(name, size) - self.allocations[name] = (addr, size) - return addr - - def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - """Free a block and move it to free_stack (same as VirtualMemoryManager).""" - freed = self._vmm.free(name, strict=strict) - if freed is not None: - self.allocations.pop(name, None) - return freed - - def save_state(self) -> int: - """ - Save current allocator state and return a snapshot token. - """ - sid = self._next_snapshot_id - self._next_snapshot_id += 1 - self._snapshots[sid] = ( - self._vmm.next_bump, - [MemoryBlock(b.name, b.addr, b.size) for b in self._vmm.used_stack], - [MemoryBlock(b.name, b.addr, b.size) for b in self._vmm.free_stack], - dict(self.allocations), - ) - return sid - - def restore_state(self, snapshot: int): - """Restore allocator state from snapshot token.""" - if snapshot not in self._snapshots: - raise KeyError(f"Unknown FPRAM snapshot id: {snapshot}") - next_bump, used_stack, free_stack, allocations = self._snapshots[snapshot] - self._vmm.next_bump = next_bump - self._vmm.used_stack = [MemoryBlock(b.name, b.addr, b.size) for b in used_stack] - self._vmm.free_stack = [MemoryBlock(b.name, b.addr, b.size) for b in free_stack] - self.allocations = dict(allocations) - - def reset(self): - """Reset allocator""" - self._vmm.reset() - self.allocations.clear() - self._snapshots.clear() - self._next_snapshot_id = 1 - - -# ============================================================================== -# Sub Matrix Manager -# ============================================================================== - - -class TileCompiler: - """ - Tile compiler / sub-matrix manager. - - Core Functions: - 1. Register large matrices as blocked matrices. - 2. Support sub-block indexing: matrix[row_idx][col_idx] or matrix[row_idx][:]. - 3. Pre-calculate all addresses at compiler phase. - 4. Generate ISA code for load_sub_matrix and sub_projection. - - Key Constraints: - - Minimum block size is 64x64 (MLEN). - - Matrix must be loaded into MRAM before participating in computation. - - HBM and VRAM/MRAM use different storage formats (row-major vs column-block major). - """ - - def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = False): - self.mlen = mlen - self.blen = blen - self.unroll_loops = unroll_loops - - # Layout tables - self.hbm_matrices: dict[str, MatrixBlockLayout] = {} - self.vram_matrices: dict[str, VRAMMatrixBlockLayout] = {} - self.fpram_matrices: dict[str, FPRAMObjectLayout] = {} - # Memory Allocators - self.vram_allocator = VRAMAllocator() - self.mram_allocator = MRAMAllocator() - self.fpram_allocator = FPRAMAllocator() - - # Currently loaded sub-blocks in MRAM - self.loaded_sub_blocks: dict[str, SubMatrixInfo] = {} - - # Pre-calculated address cache - self._address_cache: dict[str, int] = {} - - def __contains__(self, name: str) -> bool: - return name in self.hbm_matrices or name in self.vram_matrices or name in self.fpram_matrices - - def __getitem__(self, name: str) -> MemoryObjectInfo: - if name not in self: - raise KeyError(f"Object '{name}' not found") - info = MemoryObjectInfo(name=name, kind="Unknown") - hbm_layout = self.hbm_matrices.get(name) - vram_layout = self.vram_matrices.get(name) - fpram_layout = self.fpram_matrices.get(name) - - if hbm_layout is not None: - rows, cols = hbm_layout.full_shape - info.shape = hbm_layout.full_shape - info.size = rows * cols - info.hbm_addr = hbm_layout.hbm_base_addr - info.hbm_size = hbm_layout.hbm_size - info.kind = "Matrix" - - if vram_layout is not None: - rows, cols = vram_layout.full_shape - info.shape = vram_layout.full_shape - info.size = rows * cols - info.vram_addr = vram_layout.vram_base_addr - info.kind = "VRAMMatrix" if hbm_layout is None else "Batch" - - if fpram_layout is not None: - info.shape = (1, fpram_layout.size) - info.size = fpram_layout.size - info.fpram_addr = fpram_layout.fpram_addr - info.fpram_size = fpram_layout.size - info.kind = "FPRAMObject" - - return info - - def get(self, name: str, default: MemoryObjectInfo | None = None) -> MemoryObjectInfo | None: - try: - return self[name] - except KeyError: - return default - - def get_hbm_layout(self, name: str) -> MatrixBlockLayout: - """Read HBM matrix layout by name.""" - if name not in self.hbm_matrices: - raise KeyError(f"HBM matrix '{name}' not found") - return self.hbm_matrices[name] - - def get_vram_layout(self, name: str) -> VRAMMatrixBlockLayout: - """Read VRAM matrix layout by name.""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not found") - return self.vram_matrices[name] - - def get_fpram_layout(self, name: str) -> FPRAMObjectLayout: - """Read FPRAM object layout by name.""" - if name not in self.fpram_matrices: - raise KeyError(f"FPRAM object '{name}' not found") - return self.fpram_matrices[name] - - # ========================================================================== - # Unified Object Management APIs - # ========================================================================== - - def add_hbm_object( - self, - name: str, - shape: tuple[int, int], - hbm_addr: int, - dtype: str = "fp16", - kind: str = "HBMObject", - real_data_ratio: float = 1.125, - strict: bool = False, - ) -> MemoryObjectInfo: - del dtype, kind - self.register_matrix( - name=name, - shape=shape, - hbm_base_addr=hbm_addr, - real_data_ratio=real_data_ratio, - strict=strict, - ) - return self[name] - - def free_hbm_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: - if name not in self.hbm_matrices: - if strict: - raise KeyError(f"HBM object '{name}' not found") - return None - info = self[name] - self.hbm_matrices.pop(name, None) - return info - - def add_vram_object( - self, - name: str, - shape: tuple[int, int], - vram_addr: int | None = None, - dtype: str = "fp16", - kind: str = "VRAMObject", - allocate_if_none: bool = True, - strict: bool = True, - ) -> MemoryObjectInfo: - rows, cols = shape - size = rows * cols - if vram_addr is None: - if not allocate_if_none: - raise ValueError("vram_addr is None and allocate_if_none is False") - vram_addr = self.vram_allocator.allocate(size=size, name=name) - del dtype, kind - self.register_vram_matrix( - name=name, - shape=shape, - vram_base_addr=vram_addr, - strict=strict, - ) - return self[name] - - def free_vram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: - if name not in self.vram_matrices: - if strict: - raise KeyError(f"VRAM object '{name}' not found") - return None - info = self[name] - self.vram_allocator.free(name, strict=strict) - self.vram_matrices.pop(name, None) - return info - - def add_fpram_object( - self, - name: str, - size: int, - dtype: str = "fp16", - kind: str = "FPRAMObject", - ) -> MemoryObjectInfo: - fpram_addr = self.fpram_allocator.allocate(name, size) - self.fpram_matrices[name] = FPRAMObjectLayout( - name=name, - fpram_addr=fpram_addr, - size=size, - dtype=dtype, - kind=kind, - ) - return self[name] - - def free_fpram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: - if name not in self.fpram_matrices: - if strict: - raise KeyError(f"FPRAM object '{name}' not found") - return None - info = self[name] - self.fpram_allocator.free(name, strict=strict) - self.fpram_matrices.pop(name, None) - return info - - def register_matrix( - self, - name: str, - shape: tuple[int, int], - hbm_base_addr: int, - real_data_ratio: float = 1.125, - strict: bool = True, - ) -> MatrixBlockLayout: - """ - Register a large matrix and automatically block it. - - Args: - name: matrix name - shape: full shape (rows, cols) - hbm_base_addr: HBM base address - real_data_ratio: HBM storage ratio (MXFP8 = 1.125) - strict: if False, skip mlen-alignment check (for raw HBM access) - - Returns: - MatrixBlockLayout object - """ - rows, cols = shape - - if strict: - if rows % self.mlen != 0: - raise ValueError(f"Matrix rows ({rows}) must be multiple of mlen ({self.mlen})") - if cols % self.mlen != 0: - raise ValueError(f"Matrix cols ({cols}) must be multiple of mlen ({self.mlen})") - - size = rows * cols - hbm_size = int(size * real_data_ratio) - - layout = MatrixBlockLayout( - name=name, full_shape=shape, block_size=self.mlen, hbm_base_addr=hbm_base_addr, hbm_size=hbm_size - ) - - self.hbm_matrices[name] = layout - return layout - - def get_sub_block(self, name: str, row_idx: int, col_idx: int) -> SubMatrixInfo: - """Get sub-block information.""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - return self.hbm_matrices[name].get_sub_block(row_idx, col_idx) - - def get_row_blocks(self, name: str, row_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a row: matrix[row_idx][:]""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - return self.hbm_matrices[name].get_row_blocks(row_idx) - - def get_col_blocks(self, name: str, col_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a column: matrix[:][col_idx]""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - return self.hbm_matrices[name].get_col_blocks(col_idx) - - # ========================================================================== - # VRAM Sub-matrix Management - # ========================================================================== - - def register_vram_matrix( - self, - name: str, - shape: tuple[int, int], - vram_base_addr: int, - strict: bool = True, - ) -> VRAMMatrixBlockLayout: - """ - Register a large matrix in VRAM and automatically block it. - - Args: - name: matrix name - shape: full shape (batch, hidden_size) - vram_base_addr: VRAM base address - - Returns: - VRAMMatrixBlockLayout object - """ - batch, hidden = shape - - if strict: - if batch % self.mlen != 0: - raise ValueError(f"VRAM matrix batch ({batch}) must be multiple of mlen ({self.mlen})") - if hidden % self.mlen != 0: - raise ValueError(f"VRAM matrix hidden ({hidden}) must be multiple of mlen ({self.mlen})") - - layout = VRAMMatrixBlockLayout(name=name, full_shape=shape, vram_base_addr=vram_base_addr, block_size=self.mlen) - - self.vram_matrices[name] = layout - return layout - - def get_vram_sub_block(self, name: str, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: - """Get VRAM sub-block information""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not registered") - return self.vram_matrices[name].get_sub_block(row_idx, col_idx) - - def get_vram_row_blocks(self, name: str, row_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a row of VRAM matrix: matrix[row_idx][:]""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not registered") - return self.vram_matrices[name].get_row_blocks(row_idx) - - def get_vram_col_blocks(self, name: str, col_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a column of VRAM matrix: matrix[:][col_idx]""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not registered") - return self.vram_matrices[name].get_col_blocks(col_idx) - - # ========================================================================== - # Address Calculation (Core - Pre-calculated during compiler phase) - # ========================================================================== - - def compute_hbm_offset(self, name: str, row_idx: int, col_idx: int) -> int: - """ - Compute HBM offset for sub-block (in elements, not bytes). - - HBM row-major: sub-block (r, c) starts at r*block_size*full_cols + c*block_size. - """ - layout = self.hbm_matrices[name] - sub_block = layout.get_sub_block(row_idx, col_idx) - return sub_block.hbm_offset - - def compute_absolute_hbm_addr(self, name: str, row_idx: int, col_idx: int) -> int: - """ - Calculate absolute HBM address of sub-block (in elements). - - Returns: - Absolute HBM address = base + offset - """ - layout = self.hbm_matrices[name] - offset = self.compute_hbm_offset(name, row_idx, col_idx) - return layout.hbm_base_addr + offset - - # ========================================================================== - # ISA Generation: Load Sub Matrix - # ========================================================================== - - def load_sub_matrix_asm( - self, - name: str, - row_idx: int, - col_idx: int, - mram_dest_addr: int, - hbm_addr_reg: int = 1, - gp_regs: list[int] | None = None, - ) -> str: - """ - Generate ISA code for loading sub-matrix from HBM to MRAM. - - HBM is row-major; H_PREFETCH_M loads mlen x mlen blocks into MRAM. - SCALE = full matrix element count; STRIDE = full column width (row stride). - """ - if gp_regs is None: - gp_regs = [1, 2, 3] - - layout = self.hbm_matrices[name] - sub_block = layout.get_sub_block(row_idx, col_idx) - - hbm_offset = sub_block.hbm_offset - sub_block.mram_addr = mram_dest_addr - - lines = [] - lines.append(f"; Load SubMatrix {name}[{row_idx}][{col_idx}] -> MRAM[{mram_dest_addr}]") - lines.append(f"; HBM offset: {hbm_offset} (precomputed)") - - full_size = layout.full_shape[0] * layout.full_shape[1] - full_cols = layout.full_shape[1] - - gp_scale = gp_regs[0] - gp_stride = gp_regs[1] - gp_mram = gp_regs[2] - - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {full_size}") - lines.append(f"C_SET_SCALE_REG gp{gp_scale}") - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {full_cols}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - lines.append(f"S_ADDI_INT gp{gp_mram}, gp0, {mram_dest_addr}") - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {hbm_offset}") - - lines.append(f"H_PREFETCH_M gp{gp_mram}, gp{gp_scale}, a{hbm_addr_reg}, 1, 0") - - block_key = f"{name}[{row_idx}][{col_idx}]" - self.loaded_sub_blocks[block_key] = sub_block - - return "\n".join(lines) + "\n" - - def load_row_sub_matrices_asm( - self, - name: str, - row_idx: int, - mram_start_addr: int, - hbm_addr_reg: int = 1, - gp_regs: list[int] | None = None, - ) -> str: - """ - Generate ISA code for loading all sub-blocks in a row: matrix[row_idx][:] - - SCALE and STRIDE set once for all sub-blocks in the row. - """ - if gp_regs is None: - gp_regs = [1, 2, 3] - - layout = self.hbm_matrices[name] - num_col_blocks = layout.num_col_blocks - - lines = [] - lines.append(f"; Load SubMatrix Row {name}[{row_idx}][:] -> MRAM[{mram_start_addr}]") - - # Set SCALE and STRIDE once for all sub-blocks - full_size = layout.full_shape[0] * layout.full_shape[1] - full_cols = layout.full_shape[1] - - gp_scale = gp_regs[0] - gp_stride = gp_regs[1] - gp_mram = gp_regs[2] - - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {full_size}") - lines.append(f"C_SET_SCALE_REG gp{gp_scale}") - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {full_cols}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - mram_addr = mram_start_addr - block_size = self.mlen * self.mlen - - for col_idx in range(num_col_blocks): - sub_block = layout.get_sub_block(row_idx, col_idx) - hbm_offset = sub_block.hbm_offset - - sub_block.mram_addr = mram_addr - - lines.append(f"; SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") - lines.append(f"S_ADDI_INT gp{gp_mram}, gp0, {mram_addr}") - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {hbm_offset}") - lines.append(f"H_PREFETCH_M gp{gp_mram}, gp{gp_scale}, a{hbm_addr_reg}, 1, 0") - - block_key = f"{name}[{row_idx}][{col_idx}]" - self.loaded_sub_blocks[block_key] = sub_block - - mram_addr += block_size - - return "\n".join(lines) + "\n" - - def load_col_sub_matrices_asm( - self, - name: str, - col_idx: int, - mram_start_addr: int, - hbm_addr_reg: int = 1, - gp_regs: list[int] | None = None, - k_block_start: int = 0, - k_block_count: int | None = None, - ) -> str: - """ - Generate ISA code for loading all sub-blocks in a column: matrix[:][col_idx]. - - Used for sub_projection: A @ W[:, col_idx*mlen:(col_idx+1)*mlen]. - k_block_start/k_block_count select a K-split slice of the column. - """ - if gp_regs is None: - gp_regs = [1, 2, 3] - - layout = self.hbm_matrices[name] - num_row_blocks = layout.num_row_blocks - - lines = [] - lines.append(f"; Load SubMatrix Col {name}[:][{col_idx}] -> MRAM[{mram_start_addr}]") - - # Set SCALE and STRIDE once for all sub-blocks - full_size = layout.full_shape[0] * layout.full_shape[1] - full_cols = layout.full_shape[1] - - gp_scale = gp_regs[0] - gp_stride = gp_regs[1] - gp_mram = gp_regs[2] - - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {full_size}") - lines.append(f"C_SET_SCALE_REG gp{gp_scale}") - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {full_cols}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - mram_addr = mram_start_addr - block_size = self.mlen * self.mlen - - effective_count = k_block_count if k_block_count is not None else num_row_blocks - for row_idx in range(k_block_start, k_block_start + effective_count): - sub_block = layout.get_sub_block(row_idx, col_idx) - hbm_offset = sub_block.hbm_offset - - sub_block.mram_addr = mram_addr - - lines.append(f"; SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") - lines.append(f"S_ADDI_INT gp{gp_mram}, gp0, {mram_addr}") - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {hbm_offset}") - lines.append(f"H_PREFETCH_M gp{gp_mram}, gp{gp_scale}, a{hbm_addr_reg}, 1, 0") - - block_key = f"{name}[{row_idx}][{col_idx}]" - self.loaded_sub_blocks[block_key] = sub_block - - mram_addr += block_size - - return "\n".join(lines) + "\n" - - # ========================================================================== - # ISA Generation: Sub Projection - # ========================================================================== - - def _vram_sub_projection_asm_impl( - self, - header_lines: list[str], - vram_row_start_addr: int, - mram_start_addr: int, - result_vram_addr: int, - full_batch: int, - num_hidden_blocks: int, - mat_col_stride: int, - transposed: bool, - gp_regs: list[int], - caller_name: str, - unroll: bool | None = None, - ) -> str: - """ - Shared implementation kernel for vram_sub_projection_asm and - vram_sub_projection_T_asm. - - Parameters resolved by the caller before this point: - header_lines -- comment lines already assembled by the caller - vram_row_start_addr -- VRAM address of the first activation block - mram_start_addr -- MRAM address of the first weight block - result_vram_addr -- VRAM destination address for the (mlen, mlen) result - full_batch -- full batch dimension of the activation VRAM matrix - num_hidden_blocks -- number of K-blocks to accumulate over - mat_col_stride -- MRAM outer-column stride (blen for M_MM, blen*mlen for M_TMM) - transposed -- True → emit M_TMM with (act, mat) operand order; - False → emit M_MM with (mat, act) operand order - gp_regs -- list of at least 9 GP register indices - caller_name -- used in error messages only - unroll -- override instance unroll_loops flag (None = use instance default) - - Returns: - ISA code string. - """ - do_unroll = self.unroll_loops if unroll is None else unroll - return vram_sub_projection_asm_impl( - mlen=self.mlen, - blen=self.blen, - unroll_loops=do_unroll, - header_lines=header_lines, - vram_row_start_addr=vram_row_start_addr, - mram_start_addr=mram_start_addr, - result_vram_addr=result_vram_addr, - full_batch=full_batch, - num_hidden_blocks=num_hidden_blocks, - mat_col_stride=mat_col_stride, - transposed=transposed, - gp_regs=gp_regs, - caller_name=caller_name, - ) - - def vram_sub_projection_asm( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_col_idx: int, - result_vram_addr: int, - gp_regs: list[int] | None = None, - k_block_start: int = 0, - k_block_count: int | None = None, - unroll: bool | None = None, - ) -> str: - """ - Generate ISA for VRAM sub-block × MRAM sub-matrix multiply. - - Computes: result = VRAM_A[row_idx][:] @ MRAM_W[:][col_idx] - - VRAM_A[row_idx][:] is (mlen, hidden_size) spread across multiple (mlen, mlen) column blocks. - MRAM_W[:][col_idx] is (hidden_size, mlen) already loaded into MRAM. - Result is (mlen, mlen) written to VRAM. - - Loop structure (outer=output cols by blen, middle=output rows by blen, inner=K accumulation). - k_block_start/k_block_count select a K-split chunk. - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5, 6, 7, 8, 9] - - if vram_mat_name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") - vram_layout = self.vram_matrices[vram_mat_name] - - mram_layout = self.hbm_matrices[mram_mat_name] - - vram_row_blocks = vram_layout.get_row_blocks(vram_row_idx) - mram_col_blocks = mram_layout.get_col_blocks(mram_col_idx) - # K-split: slice to only the loaded k-chunk - if k_block_count is not None: - mram_col_blocks = mram_col_blocks[k_block_start : k_block_start + k_block_count] - - num_hidden_blocks = len(mram_col_blocks) - if num_hidden_blocks != (k_block_count if k_block_count is not None else len(vram_row_blocks)): - raise ValueError( - f"Dimension mismatch: expected {k_block_count or len(vram_row_blocks)} MRAM blocks, " - f"got {num_hidden_blocks}" - ) - - for sub_block in mram_col_blocks: - if sub_block.mram_addr is None: - raise RuntimeError(f"SubBlock {mram_mat_name}[{sub_block.row_idx}][{mram_col_idx}] not loaded to MRAM") - - full_batch = vram_layout.full_shape[0] - vram_row_start_addr = vram_row_blocks[k_block_start].vram_addr - mram_col_start_addr = mram_col_blocks[0].mram_addr - - header_lines = [ - f"; VRAM Sub Projection: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}]", - f"; VRAM A[row_idx][:]: ({self.mlen}, hidden) spread across {num_hidden_blocks} column blocks", - f"; MRAM W[:][col_idx]: (hidden, {self.mlen}) with {num_hidden_blocks} sub-blocks", - f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", - ] - - return self._vram_sub_projection_asm_impl( - header_lines=header_lines, - vram_row_start_addr=vram_row_start_addr, - mram_start_addr=mram_col_start_addr, - result_vram_addr=result_vram_addr, - full_batch=full_batch, - num_hidden_blocks=num_hidden_blocks, - mat_col_stride=self.blen, - transposed=False, - gp_regs=gp_regs, - caller_name="vram_sub_projection_asm", - unroll=unroll, - ) - - def vram_sub_projection_T_asm( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_row_idx: int, - result_vram_addr: int, - gp_regs: list[int] | None = None, - unroll: bool | None = None, - ) -> str: - """ - Generate ISA for VRAM sub-block × MRAM sub-matrix transposed multiply. - - Computes: result = VRAM_A[row_idx][:] @ MRAM_W[row_idx][:]^T - - VRAM_A[row_idx][:] is (mlen, hidden_size). - MRAM_W[row_idx][:] is (mlen, hidden_size); transposed to (hidden_size, mlen). - Result is (mlen, mlen) written to VRAM. - - Uses M_TMM instruction. mat_col_stride = blen*mlen (transposed addressing). - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5, 6, 7, 8, 9] - - if vram_mat_name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") - vram_layout = self.vram_matrices[vram_mat_name] - - mram_layout = self.hbm_matrices[mram_mat_name] - - vram_row_blocks = vram_layout.get_row_blocks(vram_row_idx) - mram_row_blocks = mram_layout.get_row_blocks(mram_row_idx) - - if len(vram_row_blocks) != len(mram_row_blocks): - raise ValueError( - f"Dimension mismatch: VRAM has {len(vram_row_blocks)} blocks, MRAM has {len(mram_row_blocks)} blocks" - ) - - num_hidden_blocks = len(vram_row_blocks) - - for sub_block in mram_row_blocks: - if sub_block.mram_addr is None: - raise RuntimeError(f"SubBlock {mram_mat_name}[{mram_row_idx}][{sub_block.col_idx}] not loaded to MRAM") - - full_batch = vram_layout.full_shape[0] - vram_row_start_addr = vram_row_blocks[0].vram_addr - mram_row_start_addr = mram_row_blocks[0].mram_addr - # NOTE: For M_TMM (transposed matmul), the MRAM outer-column stride is - # blen * mlen (full sub-block size) because M_TMM reads the weight in - # transposed layout. This differs from the non-transposed path which uses - # blen. This is intentional for the M_TMM addressing contract. - mat_col_stride = self.blen * self.mlen - - header_lines = [ - f"; VRAM Sub Projection T: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T", - f"; VRAM A[row_idx][:]: ({self.mlen}, hidden)", - f"; MRAM W[row_idx][:]^T: (hidden, {self.mlen})", - f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", - ] - - return self._vram_sub_projection_asm_impl( - header_lines=header_lines, - vram_row_start_addr=vram_row_start_addr, - mram_start_addr=mram_row_start_addr, - result_vram_addr=result_vram_addr, - full_batch=full_batch, - num_hidden_blocks=num_hidden_blocks, - mat_col_stride=mat_col_stride, - transposed=True, - gp_regs=gp_regs, - caller_name="vram_sub_projection_T_asm", - unroll=unroll, - ) - - def vram_block_add_asm( - self, - src1_name: str, - src1_row_idx: int, - src1_col_idx: int, - src2_name: str, - src2_row_idx: int, - src2_col_idx: int, - target_name: str, - target_row_idx: int, - target_col_idx: int, - gp_regs: list[int] | None = None, - ) -> str: - """ - Add two mlen x mlen blocks and write to any target block: - target[target_row_idx][target_col_idx] = - src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] - - Source/target can be the same matrix (supports in-place overwrite). - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4] - if len(gp_regs) < 4: - raise ValueError(f"Need at least 4 GP regs, got {len(gp_regs)}") - - src1_block = self.get_vram_sub_block(src1_name, src1_row_idx, src1_col_idx) - src2_block = self.get_vram_sub_block(src2_name, src2_row_idx, src2_col_idx) - target_block = self.get_vram_sub_block(target_name, target_row_idx, target_col_idx) - - gp_dst = gp_regs[0] - gp_src1 = gp_regs[1] - gp_src2 = gp_regs[2] - gp_loop = gp_regs[3] - - lines = [] - lines.append( - f"; VRAM Block Add: {target_name}[{target_row_idx}][{target_col_idx}] = " - f"{src1_name}[{src1_row_idx}][{src1_col_idx}] + {src2_name}[{src2_row_idx}][{src2_col_idx}]" - ) - - # One V_ADD_VV processes one row (mlen elements). Use C_LOOP to reduce ISA size. - if self.unroll_loops: - for i in range(self.mlen): - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr + i * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr + i * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr + i * self.mlen}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") - else: - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr}") - lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr}") - lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr}") - lines.append(f"C_LOOP_START gp{gp_loop}, {self.mlen}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src1}, gp{gp_src1}, {self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src2}, gp{gp_src2}, {self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - - return "\n".join(lines) + "\n" - - # ========================================================================== - # High-level Interface: Complete Sub-block Computation - # ========================================================================== - - def compute_sub_matmul( - self, - a_name: str, - a_row_idx: int | slice, - b_name: str, - b_col_idx: int | slice, - result_name: str, - transpose_b: bool = False, - ) -> tuple[str, int]: - """Not yet implemented. Use vram_sub_projection_asm or vram_sub_projection_T_asm.""" - raise NotImplementedError( - "compute_sub_matmul is not yet implemented. Use vram_sub_projection_asm " - "or vram_sub_projection_T_asm for matrix multiplication." - ) - - # ========================================================================== - # Format Conversion: HBM <-> VRAM - # ========================================================================== - - def load_activation_with_format_convert_asm( - self, - name: str, - hbm_base_addr: int, - batch: int, - hidden_size: int, - vram_dest_addr: int, - hbm_addr_reg: int = 0, - gp_regs: list[int] | None = None, - ) -> str: - """ - Load activation from HBM to VRAM with format conversion. - - HBM layout: [batch, hidden_size] row-major (element[b,h] at hbm_base + b*hidden_size + h). - VRAM layout: [batch, mlen, hidden/mlen] column-block major. - element[b,h]: vram_base + (h//mlen)*batch*mlen + b*mlen + (h%mlen). - - H_PREFETCH_V loads mlen elements per call; columns loaded in blocks of mlen. - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5] - - lines = [] - lines.append(f"; Load Activation with Format Convert: {name}") - lines.append(f"; HBM[{hbm_base_addr}]: [batch={batch}, hidden={hidden_size}] row-major") - lines.append(f"; VRAM[{vram_dest_addr}]: [batch, mlen, hidden/mlen] column-block major") - - gp_hbm_offset = gp_regs[0] - gp_stride = gp_regs[1] - gp_vram = gp_regs[2] - _gp_outer = gp_regs[3] - _gp_inner = gp_regs[4] - - num_col_blocks = hidden_size // self.mlen - preload_len = 4 # load 4 rows per H_PREFETCH_V call - - total_size = batch * hidden_size - lines.append(f"S_ADDI_INT gp{gp_hbm_offset}, gp0, {total_size}") - lines.append(f"C_SET_SCALE_REG gp{gp_hbm_offset}") - - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {hidden_size}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - for col_block in range(num_col_blocks): - lines.append(f"; Column block {col_block}") - - hbm_offset = col_block * self.mlen - vram_addr = vram_dest_addr + col_block * batch * self.mlen - - lines.append(f"S_ADDI_INT gp{gp_hbm_offset}, gp0, {hbm_offset}") - lines.append(f"S_ADDI_INT gp{gp_vram}, gp0, {vram_addr}") - - for batch_block in range(math.ceil(batch / preload_len)): - actual_batch_offset = batch_block * preload_len * hidden_size - actual_vram_offset = batch_block * preload_len * self.mlen - - lines.append(f"S_ADDI_INT gp{gp_hbm_offset}, gp0, {hbm_offset + actual_batch_offset}") - lines.append(f"S_ADDI_INT gp{gp_vram}, gp0, {vram_addr + actual_vram_offset}") - lines.append(f"H_PREFETCH_V gp{gp_vram}, gp{gp_hbm_offset}, a{hbm_addr_reg}, 1, 0") - - return "\n".join(lines) + "\n" - - def store_activation_with_format_convert_asm( - self, - name: str, - vram_src_addr: int, - batch: int, - hidden_size: int, - hbm_dest_addr: int, - hbm_addr_reg: int = 0, - gp_regs: list[int] | None = None, - ) -> str: - """ - Store activation from VRAM to HBM with format conversion. - - VRAM layout: [batch, mlen, hidden/mlen] column-block major. - HBM layout: [batch, hidden_size] row-major. - - H_STORE_V stores mlen elements per call; columns stored in blocks of mlen. - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5] - - lines = [] - lines.append(f"; Store Activation with Format Convert: {name}") - lines.append(f"; VRAM[{vram_src_addr}]: [batch, mlen, hidden/mlen] column-block major") - lines.append(f"; HBM[{hbm_dest_addr}]: [batch={batch}, hidden={hidden_size}] row-major") - - gp_hbm_offset = gp_regs[0] - gp_stride = gp_regs[1] - gp_vram = gp_regs[2] - _gp_outer = gp_regs[3] - _gp_inner = gp_regs[4] - - num_col_blocks = hidden_size // self.mlen - store_amount = 4 # store 4 rows per H_STORE_V call - - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {hidden_size}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - for col_block in range(num_col_blocks): - lines.append(f"; Column block {col_block}") - - hbm_offset = col_block * self.mlen - vram_addr = vram_src_addr + col_block * batch * self.mlen - - for batch_block in range(math.ceil(batch / store_amount)): - actual_batch_offset = batch_block * store_amount * hidden_size - actual_vram_offset = batch_block * store_amount * self.mlen - - lines.append(f"S_ADDI_INT gp{gp_hbm_offset}, gp0, {hbm_offset + actual_batch_offset}") - lines.append(f"S_ADDI_INT gp{gp_vram}, gp0, {vram_addr + actual_vram_offset}") - lines.append(f"H_STORE_V gp{gp_vram}, gp{gp_hbm_offset}, a{hbm_addr_reg}, 0") - - return "\n".join(lines) + "\n" - - # ========================================================================== - # Pre-calculated Address Table Generation - # ========================================================================== - - def generate_address_table(self, name: str) -> dict[str, int]: - """Generate complete address table for a matrix (for debugging and verification).""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - - layout = self.hbm_matrices[name] - addr_table = {} - - for (r, c), sub_block in layout.sub_blocks.items(): - key = f"{name}[{r}][{c}]" - addr_table[f"{key}_hbm_offset"] = sub_block.hbm_offset - addr_table[f"{key}_hbm_abs"] = layout.hbm_base_addr + sub_block.hbm_offset - if sub_block.mram_addr is not None: - addr_table[f"{key}_mram"] = sub_block.mram_addr - - return addr_table - - def print_address_table(self, name: str): - """Print address table for a matrix.""" - addr_table = self.generate_address_table(name) - print(f"Address Table for {name}:") - for key, addr in sorted(addr_table.items()): - print(f" {key}: {addr}") - - # ========================================================================== - # Helper Methods - # ========================================================================== - - def get_loaded_block_addr(self, name: str, row_idx: int, col_idx: int) -> int: - """Get MRAM address of a loaded sub-block.""" - block_key = f"{name}[{row_idx}][{col_idx}]" - if block_key not in self.loaded_sub_blocks: - raise KeyError(f"SubBlock {block_key} not loaded") - return self.loaded_sub_blocks[block_key].mram_addr - - def is_block_loaded(self, name: str, row_idx: int, col_idx: int) -> bool: - """Check whether a sub-block has been loaded into MRAM.""" - block_key = f"{name}[{row_idx}][{col_idx}]" - return block_key in self.loaded_sub_blocks - - def reset(self): - """Reset manager state.""" - self.hbm_matrices.clear() - self.vram_matrices.clear() - self.fpram_matrices.clear() - self.vram_allocator.reset() - self.mram_allocator.reset() - self.fpram_allocator.reset() - self.loaded_sub_blocks.clear() - self._address_cache.clear() - - def print_table(self): - """Print unified managed object table.""" - print("=" * 95) - print("Managed Object Table") - print("=" * 95) - print( - f"{'Name':<20} {'Kind':<12} {'Shape':<15} {'HBM Addr':<10} {'HBM Size':<10} {'VRAM Addr':<12} {'FPRAM':<12} {'Size':<8}" - ) - print("-" * 95) - names = sorted(set(self.hbm_matrices) | set(self.vram_matrices) | set(self.fpram_matrices)) - for name in names: - info = self[name] - shape_str = f"({info.shape[0]}, {info.shape[1]})" - vram_str = f"{info.vram_addr}" if info.vram_addr is not None else "None" - fpram_str = f"{info.fpram_addr}/{info.fpram_size}" if info.fpram_addr is not None else "None" - print( - f"{name:<20} {info.kind:<12} {shape_str:<15} " - f"{info.hbm_addr:<10} {info.hbm_size:<10} {vram_str:<12} {fpram_str:<12} {info.size:<8}" - ) - print("=" * 95) - - def print_layout(self, name: str): - """Print block layout for a matrix.""" - if name not in self.hbm_matrices: - print(f"Matrix '{name}' not registered") - return - - layout = self.hbm_matrices[name] - print(f"Matrix: {name}") - print(f" Full shape: {layout.full_shape}") - print(f" Block size: {layout.block_size}") - print(f" Blocks: {layout.num_row_blocks} x {layout.num_col_blocks}") - print(f" HBM base: {layout.hbm_base_addr}") - print(" Sub blocks:") - for (r, c), sub in layout.sub_blocks.items(): - loaded = "LOADED" if sub.mram_addr is not None else "" - print(f" [{r}][{c}]: hbm_off={sub.hbm_offset}, mram={sub.mram_addr} {loaded}") - - -# ============================================================================== -# Example Usage -# ============================================================================== - - -class RegisterAllocator: - """Register Allocator: Manages address registers and GP registers""" - - def __init__(self, start_gp: int = 1, start_addr: int = 0, start_fp: int = 1): - # HW OPERAND_WIDTH = 4 bits → gp0-gp15; gp0 reserved as constant 0. - self.gp_registers = list(range(start_gp, 16)) - self.addr_registers = list(range(start_addr, 8)) - # f0 reserved as constant 0 (writing to f0 is a no-op for V_RED_MAX/V_RED_SUM). - self.fp_registers = list(range(start_fp, 8)) - self.used_gp = [] - self.used_addr = [] - self.used_fp = [] - - def allocate_gp(self, count: int = 1) -> list[int]: - if len(self.gp_registers) < count: - raise RuntimeError(f"Not enough GP registers available. Need {count}, have {len(self.gp_registers)}") - - allocated = self.gp_registers[:count] - self.gp_registers = self.gp_registers[count:] - self.used_gp.extend(allocated) - return allocated - - def allocate_addr(self, count: int = 1) -> list[int]: - if len(self.addr_registers) < count: - raise RuntimeError(f"Not enough address registers available. Need {count}, have {len(self.addr_registers)}") - - allocated = self.addr_registers[:count] - self.addr_registers = self.addr_registers[count:] - self.used_addr.extend(allocated) - return allocated - - def free_gp(self, registers: list[int]): - for reg in registers: - if reg in self.used_gp: - self.used_gp.remove(reg) - self.gp_registers.append(reg) - self.gp_registers.sort() - - def free_addr(self, registers: list[int]): - for reg in registers: - if reg in self.used_addr: - self.used_addr.remove(reg) - self.addr_registers.append(reg) - self.addr_registers.sort() - - def allocate_fp(self, count: int = 1) -> list[int]: - if len(self.fp_registers) < count: - raise RuntimeError(f"Not enough FP registers available. Need {count}, have {len(self.fp_registers)}") - - # Reverse allocation: prefer high-numbered regs to avoid conflicts with legacy hardcoded forward-allocation. - allocated = list(reversed(self.fp_registers[-count:])) - self.fp_registers = self.fp_registers[:-count] - self.used_fp.extend(allocated) - return allocated - - def free_fp(self, registers: list[int]): - for reg in registers: - if reg in self.used_fp: - self.used_fp.remove(reg) - self.fp_registers.append(reg) - # Keep sorted so allocate_fp's tail-slice continues to return descending IDs. - self.fp_registers.sort() - - -class DeveloperCompiler(TileCompiler): - """ - Developer Compiler: Compiles high-level IR to ISA. - - Owns symbol_table, register_allocator, and the InterruptManager. - Sub-matrix / memory management is inherited from TileCompiler; the - legacy ``self.tile_compiler`` accessor is preserved as a property - returning ``self`` for a handful of remaining external callers. - """ - - _ONLINE_SOFTMAX_FPSRAM_BASE = 10 - - class InterruptManager: - """ - Interrupt Manager — manages execution timing only. - Actual handlers live on DeveloperCompiler as ``_handle_k_start``, - ``_handle_k_prefetch_done``, ``_handle_s_tile_done``, ``_handle_k_end``. - """ - - def __init__(self, compiler: DeveloperCompiler): - self.compiler = compiler - self.enabled = False - - self._k_count = 0 - self._tile_count = 0 - - self.current_matrix: str = "" - self.current_activation: str = "" - self._mlen = compiler.mlen - self._blen = compiler.blen - self._batch = compiler.mlen - - self._q_block_idx = 0 - self._k_block_idx = 0 - self._s_tile_address = 0 - - @property - def tile_compiler(self): - return self.compiler.tile_compiler - - @property - def k_count(self) -> int: - return self._k_count - - @property - def tile_count(self) -> int: - return self._tile_count - - @property - def batch(self) -> int: - return self._batch - - @property - def out_features(self) -> int: - if self.current_matrix and self.current_matrix in self: - info = self[self.current_matrix] - return info.shape[0] - return self._mlen - - @property - def hidden_size(self) -> int: - if self.current_matrix and self.current_matrix in self: - info = self[self.current_matrix] - return info.shape[1] - return self._mlen - - @property - def k_block(self) -> int: - return self._k_block_idx - - @property - def q_block(self) -> int: - return self._q_block_idx - - @property - def s_tile_address(self) -> int: - return self._s_tile_address - - @property - def mlen(self) -> int: - return self._mlen - - @property - def blen(self) -> int: - return self._blen - - def reset(self): - """Reset counters (does not clear current_matrix).""" - self._k_count = 0 - self._tile_count = 0 - self._q_block_idx = 0 - self._k_block_idx = 0 - self._s_tile_address = 0 - - def enable(self): - self.enabled = True - - def disable(self): - self.enabled = False - - def trigger_k_start(self) -> str: - if not self.enabled: - return "" - return self.compiler._handle_k_start() - - def trigger_k_prefetch_done(self) -> str: - if not self.enabled: - return "" - result = self.compiler._handle_k_prefetch_done() - self._k_count += 1 - return result - - def trigger_s_tile_done(self) -> str: - if not self.enabled: - return "" - result = self.compiler._handle_s_tile_done() - self._tile_count += 1 - return result - - def trigger_k_end(self) -> str: - if not self.enabled: - return "" - return self.compiler._handle_k_end() - - def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): - # TileCompiler.__init__ sets mlen, blen, unroll_loops, the HBM/VRAM/FPRAM - # matrices and allocators, loaded_sub_blocks, and _address_cache. - super().__init__(mlen=mlen, blen=blen, unroll_loops=unroll_loops) - self.real_data_ratio = real_data_ratio - self.register_allocator = RegisterAllocator() - self.generated_code = "" - self.interrupt = self.InterruptManager(self) - - # Back-compat shim: older callers (and a couple of external ops modules) - # reach into ``compiler.tile_compiler`` directly. After the merge, - # DeveloperCompiler *is* the TileCompiler, so the property just returns - # ``self``. - @property - def tile_compiler(self) -> DeveloperCompiler: - return self - - # Interrupt handler placeholders (overridden by flash-attention passes). - - def _handle_k_start(self) -> str: - return "" - - def _handle_k_prefetch_done(self) -> str: - return "" - - def _handle_s_tile_done(self) -> str: - return "" - - def _handle_k_end(self) -> str: - return "" - - # ========================================================================= - # Flash Attention Implementation - # ========================================================================= - - def _online_softmax_asm( - self, - mlen: int, - s_address: int, - m_start_address: int, - scale: float = 1.0, - ) -> str: - """ - Online Softmax Computation. - - Per row of S: - 1. m_curr = max(S[row], m_old) - 2. m_res = exp(m_old - m_curr) # used to update O downstream - 3. S'[row] = S[row] - m_curr - 4. P[row] = exp(S'[row]) - 5. l_new = l_old * m_res + sum(P[row]) - - FP SRAM layout (from m_start_address): - [0, mlen): m_old / m_curr - [mlen, 2*mlen): m_res = exp(m_old - m_curr) - [2*mlen, 3*mlen): l_old / l_new - """ - gp_regs = self.register_allocator.allocate_gp(4) - gp_s = gp_regs[0] - gp_m_addr = gp_regs[1] - gp_m_res_addr = gp_regs[2] - gp_l_addr = gp_regs[3] - - # Fixed FP register allocation for online softmax pipeline. - # These registers are shared across _online_softmax_asm, _scale_o_asm, - # and _final_scaling_asm — they MUST remain consistent across all three. - # WARNING: Do not use f1-f6 in any code that calls these methods. - fp_m_old = 1 # f1: m_old value - fp_m_res = 2 # f2: exp(m_old - m_curr) - fp_l_old = 3 # f3: l_old value - fp_sum_p = 4 # f4: sum(P) - fp_scale = 5 # f5: scale factor - fp_row_max = 6 # f6: current row max (temporary) - - lines = [] - lines.append("; === Online Softmax ===") - - # Set address registers - lines.append(f"S_ADDI_INT gp{gp_s}, gp0, {s_address}") - lines.append(f"S_ADDI_INT gp{gp_m_addr}, gp0, {m_start_address}") - lines.append(f"S_ADDI_INT gp{gp_m_res_addr}, gp{gp_m_addr}, {mlen}") - lines.append(f"S_ADDI_INT gp{gp_l_addr}, gp{gp_m_res_addr}, {mlen}") - - # scale factor is pre-loaded at FP SRAM addr 1 by the flash-attention driver. - if scale != 1.0: - lines.append(f"S_LD_FP f{fp_scale}, gp0, 1") - - for row in range(mlen): - lines.append(f"; Row {row}") - - lines.append(f"S_LD_FP f{fp_m_old}, gp{gp_m_addr}, {row}") - lines.append(f"S_ADD_FP f{fp_m_res}, f{fp_m_old}, f0") - - if scale != 1.0: - lines.append(f"V_MUL_VF gp{gp_s}, gp{gp_s}, f{fp_scale}, 0") - - lines.append(f"V_RED_MAX f{fp_row_max}, gp{gp_s}, 0") - - # m_curr = max(row_max, m_old) — online softmax must retain the running max. - lines.append(f"S_MAX_FP f{fp_m_old}, f{fp_row_max}, f{fp_m_old}") - - lines.append(f"S_SUB_FP f{fp_m_res}, f{fp_m_res}, f{fp_m_old}") - lines.append(f"S_EXP_FP f{fp_m_res}, f{fp_m_res}, 0") - - lines.append(f"S_ST_FP f{fp_m_res}, gp{gp_m_res_addr}, {row}") - lines.append(f"S_ST_FP f{fp_m_old}, gp{gp_m_addr}, {row}") - - lines.append(f"V_SUB_VF gp{gp_s}, gp{gp_s}, f{fp_m_old}, 0, 0") - lines.append(f"V_EXP_V gp{gp_s}, gp{gp_s}, 0, 0") - - lines.append(f"S_LD_FP f{fp_l_old}, gp{gp_l_addr}, {row}") - - lines.append(f"S_ADD_FP f{fp_sum_p}, f0, f0") - lines.append(f"V_RED_SUM f{fp_sum_p}, gp{gp_s}, 0, 0") - - lines.append(f"S_MUL_FP f{fp_l_old}, f{fp_l_old}, f{fp_m_res}") - lines.append(f"S_ADD_FP f{fp_l_old}, f{fp_l_old}, f{fp_sum_p}") - - lines.append(f"S_ST_FP f{fp_l_old}, gp{gp_l_addr}, {row}") - - lines.append(f"S_ADDI_INT gp{gp_s}, gp{gp_s}, {mlen}") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _pv_multiply_asm( - self, - mlen: int, - blen: int, - head_dim: int, - p_address: int, - v_hbm_offset_reg: int, - v_hbm_offset: int, - pv_address: int, - ) -> str: - """ - Compute PV = P @ V via M_MM. - - P: (mlen, mlen) in VRAM (softmax output) - V: (mlen, head_dim) in HBM (prefetched into MSRAM in mlen-wide column blocks) - PV: (mlen, head_dim) in VRAM - - M_MM computes one (blen, mlen) @ (mlen, blen) -> (blen, blen) in a single op - (K=mlen done in one shot). For head_dim > mlen, V is split into head_dim/mlen - column blocks; the outer loop iterates blocks, middle loop iterates blen-wide - V columns within a block, inner loop iterates blen-wide P rows. - """ - assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" - - gp_regs = self.register_allocator.allocate_gp(5) - gp_p = gp_regs[0] - gp_v = gp_regs[1] - gp_pv = gp_regs[2] - gp_hbm = gp_regs[3] - gp_stride = gp_regs[4] - - num_v_col_blocks = head_dim // mlen - - lines = [] - lines.append("; === PV Multiply (P @ V) using M_MM ===") - lines.append(f"; P: ({mlen}, {mlen}) @ V: ({mlen}, {head_dim}) -> PV: ({mlen}, {head_dim})") - lines.append("; M_MM: (blen, mlen) @ (mlen, blen) -> (blen, blen), K=mlen in one shot") - lines.append(f"; V split into {num_v_col_blocks} column blocks of width {mlen}") - lines.append("; Storage layout: (batch, mlen, hidden/mlen), column-block major") - - # STRIDE was set to mlen by the flash-attention driver — do not overwrite it here. - # M_MM_WO requires a nonzero stride reg (gp0=0 would be interpreted as stride=1). - # With column-block-major storage, consecutive rows within a column block are - # adjacent, so the writeback stride = 1. - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") - - for v_col_block in range(num_v_col_blocks): - lines.append( - f"; --- V column block {v_col_block} (columns {v_col_block * mlen} to {(v_col_block + 1) * mlen - 1}) ---" - ) - - # Prefetch V[:, v_col_block*mlen:(v_col_block+1)*mlen] (mlen × mlen) to MSRAM. - # V is row-major in HBM: V[row, col] at offset row*head_dim + col, so the - # column-block base offset = v_hbm_offset + v_col_block * mlen (elements). - v_block_hbm_offset = v_hbm_offset + v_col_block * mlen - lines.append(f"S_ADDI_INT gp{gp_v}, gp0, 0") - lines.append(f"S_ADDI_INT gp{gp_hbm}, gp0, {v_block_hbm_offset}") - lines.append(f"H_PREFETCH_M gp{gp_v}, gp{gp_hbm}, a{v_hbm_offset_reg}, 1, 1") - - # mat_offset constraint: < mlen and a multiple of blen. - for v_col in range(mlen // blen): - lines.append(f"; V column {v_col_block * mlen + v_col * blen}") - - v_msram_offset = v_col * blen - lines.append(f"S_ADDI_INT gp{gp_v}, gp0, {v_msram_offset}") - - for p_row in range(mlen // blen): - p_row_addr = p_address + p_row * blen * mlen - lines.append(f"S_ADDI_INT gp{gp_p}, gp0, {p_row_addr}") - - lines.append(f"M_MM 0, gp{gp_v}, gp{gp_p}") - - # PV[row, col] addr = base + col_block * mlen * mlen + row * mlen + col_in_block - # with row = p_row * blen and col_in_block = v_col * blen. - pv_offset = v_col_block * mlen * mlen + p_row * blen * mlen + v_col * blen - lines.append(f"S_ADDI_INT gp{gp_pv}, gp0, {pv_address + pv_offset}") - lines.append(f"M_MM_WO gp{gp_pv}, gp{gp_stride}, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _scale_o_asm( - self, - mlen: int, - head_dim: int, - seq_len: int, - m_res_address: int, - o_address: int, - row_offset: int = 0, - ) -> str: - """Scale each row of O by m_res: O[row] *= m_res[row].""" - assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" - - gp_regs = self.register_allocator.allocate_gp(2) - gp_m_res = gp_regs[0] - gp_o = gp_regs[1] - fp_m_res = 1 - - num_col_blocks = head_dim // mlen - - lines = [] - lines.append("; === Scale O by m_res ===") - lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") - lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") - - lines.append(f"S_ADDI_INT gp{gp_m_res}, gp0, {m_res_address}") - - for row in range(mlen): - lines.append(f"S_LD_FP f{fp_m_res}, gp{gp_m_res}, {row}") - actual_row = row_offset + row - - for col_block in range(num_col_blocks): - o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen - lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") - lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_m_res}, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _add_pv_to_o_asm( - self, - mlen: int, - head_dim: int, - seq_len: int, - pv_address: int, - o_address: int, - row_offset: int = 0, - ) -> str: - """Accumulate PV into O: O[row] += PV[row].""" - assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" - - gp_regs = self.register_allocator.allocate_gp(2) - gp_o = gp_regs[0] - gp_pv = gp_regs[1] - - num_col_blocks = head_dim // mlen - - lines = [] - lines.append("; === Add PV to O ===") - lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") - lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") - - for row in range(mlen): - actual_row = row_offset + row - - for col_block in range(num_col_blocks): - o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen - pv_addr = pv_address + col_block * mlen * mlen + row * mlen - - lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") - lines.append(f"S_ADDI_INT gp{gp_pv}, gp0, {pv_addr}") - lines.append(f"V_ADD_VV gp{gp_o}, gp{gp_o}, gp{gp_pv}, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _final_scaling_asm( - self, - mlen: int, - head_dim: int, - seq_len: int, - l_address: int, - o_address: int, - row_offset: int = 0, - ) -> str: - """ - Final scaling: O[row] /= l[row]. - - V_MUL_VF processes mlen elements at a time; when head_dim > mlen, - each row is split into head_dim // mlen mlen-wide blocks. - """ - assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" - - gp_regs = self.register_allocator.allocate_gp(2) - gp_l = gp_regs[0] - gp_o = gp_regs[1] - fp_l = 1 - - num_col_blocks = head_dim // mlen - - lines = [] - lines.append("; === Final Scaling O = O / l ===") - lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") - lines.append("; Storage layout: (seq_len, mlen, head_dim/mlen), column-block major") - lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") - - lines.append(f"S_ADDI_INT gp{gp_l}, gp0, {l_address}") - - for row in range(mlen): - lines.append(f"S_LD_FP f{fp_l}, gp{gp_l}, {row}") - lines.append(f"S_RECI_FP f{fp_l}, f{fp_l}, 0") - actual_row = row_offset + row - - for col_block in range(num_col_blocks): - o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen - lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") - lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_l}, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _reset_fpsram_asm( - self, - start_address: int, - count: int, - value_address: int, # FP SRAM slot: 0 = zero, 2 = -inf - ) -> str: - """Reset a region of FP SRAM to the value at value_address.""" - gp_regs = self.register_allocator.allocate_gp(1) - gp_addr = gp_regs[0] - - lines = [] - lines.append(f"; Reset FP SRAM [{start_address}, {start_address + count})") - - lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {start_address}") - # Use f1 for FP scalar - FP registers don't go through GP allocator - lines.append(f"S_LD_FP f1, gp0, {value_address}") - - for i in range(count): - lines.append(f"S_ST_FP f1, gp{gp_addr}, {i}") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _reset_vram_asm( - self, - start_address: int, - rows: int, - cols: int, - total_rows: int, - mlen: int = 64, - row_offset: int = 0, - ) -> str: - """ - Reset a region of VRAM to zero. - - V_MUL_VF processes mlen elements at a time; when cols > mlen, each - row is split into cols // mlen mlen-wide blocks. - """ - gp_regs = self.register_allocator.allocate_gp(1) - gp_addr = gp_regs[0] - - num_col_blocks = (cols + mlen - 1) // mlen - - lines = [] - lines.append(f"; Reset VRAM rows [{row_offset}, {row_offset + rows}) of matrix at {start_address}") - lines.append(f"; {rows} rows x {cols} cols, {num_col_blocks} blocks per row") - lines.append("; Storage layout: (total_rows, mlen, cols/mlen), column-block major") - lines.append(f"; total_rows = {total_rows}, row_offset = {row_offset}") - - for row in range(rows): - actual_row = row_offset + row - for col_block in range(num_col_blocks): - addr = start_address + col_block * total_rows * mlen + actual_row * mlen - lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {addr}") - lines.append(f"V_MUL_VF gp{gp_addr}, gp{gp_addr}, f0, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def load_batch( - self, - hbm_object_name: str, - vram_object_name: str, - vlen: int = 64, - preload_len: int = 4, - ) -> str: - """ - Load a Batch tensor from HBM to VRAM. - - HBM storage is MXFP (1 scale per 8 elements), so HBM actual size = - logical size * real_data_ratio = 1.125. VRAM stores only the vector - data (no scale), so VRAM size = logical size. - - Order (matters): allocate VRAM → register in symbol table → emit ISA. - """ - hbm_layout = self.get_hbm_layout(hbm_object_name) - h, w = hbm_layout.full_shape - hbm_addr = hbm_layout.hbm_base_addr - size = h * w - vram_base = self.vram_allocator.allocate(size, name=vram_object_name) - self.add_vram_object( - name=vram_object_name, - shape=(h, w), - vram_addr=vram_base, - dtype="fp16", - kind="Batch", - allocate_if_none=False, - strict=False, - ) - - addr_reg = self.register_allocator.allocate_addr(1)[0] - gp_regs_for_addr = self.register_allocator.allocate_gp(1) - - isa_code = f"; Load_Batch {hbm_object_name} -> {vram_object_name}\n" - isa_code += f"; HBM[{hbm_addr}] → VRAM[{vram_base}], shape=({h}, {w})\n" - - isa_code += preload_addr_reg_asm( - addr_reg_to_set=[addr_reg], available_registers=gp_regs_for_addr, addr_reg_val=[hbm_addr] - ) - - # preload_act_asm requires 5 GP registers: [a_actual, stride, result, outer_loop, inner_loop]. - gp_regs_for_preload = self.register_allocator.allocate_gp(5) - isa_code += reset_reg_asm(alive_registers=gp_regs_for_preload) - - isa_code += preload_act_asm( - vlen=vlen, - preload_len=preload_len, - batch=h, - hidden_size=w, - alive_registers=gp_regs_for_preload, - act_vram_offset=vram_base, - activation_offset_reg=addr_reg, - stride_size=w, - ) - - self.register_allocator.free_gp(gp_regs_for_addr) - self.register_allocator.free_gp(gp_regs_for_preload) - self.register_allocator.free_addr([addr_reg]) - - self.generated_code += isa_code - - return isa_code - - def store_to_hbm( - self, - tensor_name: str, - hbm_addr: int | None = None, - hbm_object_name: str | None = None, - hbm_addr_reg: int | None = None, - vlen: int = 64, - precision: int = 0, # 0 = Activation, 1 = KeyValue - store_amount: int = 4, # HBM_V_Writeback_Amount - ) -> str: - """ - Write tensor from VRAM back to HBM. - - Used to spill computed intermediates (e.g., K) from VRAM to HBM so - downstream ops (e.g., QK^T) can read them from HBM. Emits - ``store_act_asm`` for tensors of any supported size. - """ - if tensor_name not in self: - raise KeyError(f"Tensor '{tensor_name}' not found in symbol table") - - tensor_info = self[tensor_name] - - # Batch and VRAMMatrix share the same VRAM storage layout. - if tensor_info.kind not in ("Batch", "VRAMMatrix"): - raise ValueError( - f"Tensor '{tensor_name}' must be Batch or VRAMMatrix to store from VRAM, got {tensor_info.kind}" - ) - - if tensor_info.vram_addr is None: - raise ValueError(f"Tensor '{tensor_name}' has no VRAM address to store") - - if hbm_addr is None: - if tensor_info.hbm_addr >= 0: - hbm_addr = tensor_info.hbm_addr - else: - raise ValueError(f"Tensor '{tensor_name}' has no HBM address. Please specify hbm_addr.") - - batch_size = tensor_info.shape[0] - hidden_size = tensor_info.shape[1] - - isa_code = f"; Store {tensor_name} from VRAM to HBM\n" - isa_code += f"; VRAM[{tensor_info.vram_addr}] -> HBM[{hbm_addr}], shape=({batch_size}, {hidden_size})\n" - - gp_regs = self.register_allocator.allocate_gp(5) - - if hbm_addr_reg is None: - addr_regs = self.register_allocator.allocate_addr(1) - hbm_addr_reg = addr_regs[0] - need_free_addr = True - else: - addr_regs = [] - need_free_addr = False - - try: - gp_regs_for_addr = self.register_allocator.allocate_gp(2) - isa_code += preload_addr_reg_asm( - addr_reg_to_set=[hbm_addr_reg], available_registers=gp_regs_for_addr, addr_reg_val=[hbm_addr] - ) - self.register_allocator.free_gp(gp_regs_for_addr) - - isa_code += store_act_asm( - vlen=vlen, - batch=batch_size, - hidden_size=hidden_size, - alive_registers=gp_regs, - act_vram_offset=tensor_info.vram_addr, - hbm_addr_reg=hbm_addr_reg, - stride_size=hidden_size, - store_amount=store_amount, - ) - - if tensor_info.hbm_addr < 0 or tensor_info.hbm_addr != hbm_addr: - tensor_info.hbm_addr = hbm_addr - # HBM stores the MXFP-expanded size (logical size × real_data_ratio). - size = batch_size * hidden_size - tensor_info.hbm_size = int(size * self.real_data_ratio) - finally: - self.register_allocator.free_gp(gp_regs) - if need_free_addr: - self.register_allocator.free_addr(addr_regs) - - if hbm_object_name is not None: - self.add_hbm_object( - name=hbm_object_name, - hbm_addr=hbm_addr, - shape=(batch_size, hidden_size), - ) - - self.generated_code += isa_code - - return isa_code - - def normalize( - self, - tensor_name: str, - mode: str = "rms", - eps_offset: int = 1, - reci_hid_offset: int = 2, - vlen: int | None = None, - scratchpad_vram_addr: int | None = None, - ) -> str: - """ - Normalize a VRAM tensor in-place. - - Supports: - - mode="rms": RMSNorm - - mode="layer": LayerNorm - - Args: - tensor_name: Tensor name in symbol table (must have VRAM address) - mode: "rms" or "layer" - eps_offset: FPRAM address of epsilon - reci_hid_offset: FPRAM address of 1/hidden_dim - vlen: vector length (default: self.mlen) - scratchpad_vram_addr: scratchpad VRAM address (default: auto-allocate temporary space) - """ - if tensor_name not in self: - raise KeyError(f"Tensor '{tensor_name}' not found in symbol table") - - tensor_info = self[tensor_name] - if tensor_info.vram_addr is None: - raise ValueError(f"Tensor '{tensor_name}' has no VRAM address") - - batch_size, hidden_dim = tensor_info.shape - if vlen is None: - vlen = self.mlen - if hidden_dim % vlen != 0: - raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by vlen ({vlen}) for normalization_asm") - - mode = mode.lower() - if mode not in ("rms", "layer"): - raise ValueError(f"Unsupported normalization mode: {mode}. Expected 'rms' or 'layer'.") - - gp_regs = self.register_allocator.allocate_gp(4) - - temp_scratchpad_name = None - if scratchpad_vram_addr is None: - temp_scratchpad_name = f"__norm_scratch__{tensor_name}__{len(self.generated_code)}" - scratchpad_vram_addr = self.vram_allocator.allocate(vlen, name=temp_scratchpad_name) - - try: - isa_code = f"; Normalize ({mode}) {tensor_name}, shape=({batch_size}, {hidden_dim})\n" - if mode == "rms": - isa_code += rms_norm_asm( - _eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - alive_registers=gp_regs, - activation_base_address=tensor_info.vram_addr, - scratchpad_base_address=scratchpad_vram_addr, - vlen=vlen, - batch_size=batch_size, - hidden_dim=hidden_dim, - ) - else: - isa_code += layer_norm_asm( - _eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - alive_registers=gp_regs, - activation_base_address=tensor_info.vram_addr, - scratchpad_base_address=scratchpad_vram_addr, - vlen=vlen, - batch_size=batch_size, - hidden_dim=hidden_dim, - ) - - self.generated_code += isa_code - return isa_code - finally: - # Always release allocated GP registers used by normalization template. - self.register_allocator.free_gp(gp_regs) - if temp_scratchpad_name is not None: - self.vram_allocator.free(temp_scratchpad_name, strict=False) - - def rope( - self, - x_name: str, - x_rot_name: str, - cos_name: str, - sin_name: str, - ) -> str: - """Apply RoPE in-place: x = x * cos + rotate_half(x) * sin - - All four tensors must already be in VRAM with the same shape (seq_len, head_dim). - x_rot must be preloaded by the caller as rotate_half(x). - """ - x_info = self[x_name] - xrot_info = self[x_rot_name] - cos_info = self[cos_name] - sin_info = self[sin_name] - - if x_info.vram_addr is None: - raise ValueError(f"Tensor '{x_name}' has no VRAM address") - - seq_len, head_dim = x_info.shape - vlen = self.mlen - - if head_dim % vlen != 0: - raise ValueError(f"head_dim ({head_dim}) must be divisible by vlen ({vlen}) for rope") - - gp_regs = self.register_allocator.allocate_gp(5) - - scratch_name = f"__rope_scratch__{x_name}__{len(self.generated_code)}" - scratch_addr = self.vram_allocator.allocate(vlen, name=scratch_name) - - try: - isa_code = rope_asm( - alive_registers=gp_regs, - x_base_address=x_info.vram_addr, - x_rot_base_address=xrot_info.vram_addr, - cos_base_address=cos_info.vram_addr, - sin_base_address=sin_info.vram_addr, - scratchpad_base_address=scratch_addr, - vlen=vlen, - seq_len=seq_len, - head_dim=head_dim, - ) - self.generated_code += isa_code - return isa_code - finally: - self.register_allocator.free_gp(gp_regs) - self.vram_allocator.free(scratch_name, strict=False) - - def get_code(self) -> str: - """Get all accumulated generated ISA code""" - return self.generated_code - - def reset(self): - """Reset compiler state (clear code, but retain symbol table)""" - self.generated_code = "" - self.register_allocator = RegisterAllocator() - # Call TileCompiler.reset() explicitly since the merged class shadows it. - TileCompiler.reset(self) - - def print_symbol_table(self): - """Print symbol table""" - self.print_table() - - def get_symbol_table(self): - """Get managed object table view.""" - return self - - def get_tensor_info(self, name: str): - """Get unified tensor/object info by name.""" - return self[name] - - def add_hbm_object( - self, - name: str, - hbm_addr: int, - shape: tuple[int, int], - real_data_ratio: float = 1.125, - ): - """Register an HBM object and build its HBM layout. - - Wraps ``TileCompiler.add_hbm_object`` with a different positional - parameter order ``(name, hbm_addr, shape, ...)`` that all DeveloperCompiler - callers use. - """ - return TileCompiler.add_hbm_object( - self, - name=name, - shape=shape, - hbm_addr=hbm_addr, - real_data_ratio=real_data_ratio, - ) - - def free_hbm_object(self, name: str, strict: bool = False): - """Free an HBM object by name (defaults to non-strict).""" - return TileCompiler.free_hbm_object(self, name, strict=strict) - - def get_vram_addr(self, name: str) -> int: - """Get VRAM base address of an object.""" - info = self.get_tensor_info(name) - if info.vram_addr is None: - raise ValueError(f"Object '{name}' has no VRAM address") - return info.vram_addr - - def get_vram_tile_addr( - self, - name: str, - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> int: - """ - Get VRAM address of a specific tile (sub-block) in a VRAM matrix. - - Args: - name: matrix name - tile_row_idx: tile row index (0-based) - tile_col_idx: tile col index (0-based) - """ - self._ensure_vram_matrix_layout(name) - sub = self.get_vram_sub_block(name, tile_row_idx, tile_col_idx) - return sub.vram_addr - - def ensure_hbm_sub_matrix( - self, - name: str, - hbm_addr: int, - shape: tuple[int, int], - real_data_ratio: float = 1.125, - ): - """Ensure HBM matrix layout exists.""" - if name in self.hbm_matrices: - return - self.add_hbm_object( - name=name, - hbm_addr=hbm_addr, - shape=shape, - real_data_ratio=real_data_ratio, - ) - - def ensure_vram_matrix_layout(self, name: str, shape: tuple[int, int]): - """Ensure VRAM matrix layout exists for an already allocated VRAM object.""" - if name in self.vram_matrices: - return - vram_addr = self.get_vram_addr(name) - self.add_vram_object( - name=name, - shape=shape, - vram_addr=vram_addr, - allocate_if_none=False, - ) - - def free_vram_object(self, name: str, strict: bool = False): - """Free a VRAM object by name (defaults to non-strict).""" - return TileCompiler.free_vram_object(self, name, strict=strict) - - # ========================================================================= - # FP Register & FPRAM Management (inlined from former FPRAMCompiler). - # All state lives on self (register_allocator, fpram_allocator, etc.). - # ========================================================================= - - @property - def _reg(self) -> RegisterAllocator: - """Shorthand for self.register_allocator (used by FPVar ISA helpers).""" - return self.register_allocator - - @property - def _unroll(self) -> bool: - """Shorthand for self.unroll_loops.""" - return self.unroll_loops - - def _emit(self, isa_code: str) -> str: - """Append ISA text to the output buffer and return it.""" - self.generated_code += isa_code - return isa_code - - # ------------------------------------------------------------------ - # FP Register management - # ------------------------------------------------------------------ - - def allocate_fp_reg(self, count: int = 1) -> list[int]: - """Allocate FP registers (f0-f7).""" - return self._reg.allocate_fp(count) - - def free_fp_reg(self, registers: list[int]): - """Free FP registers.""" - self._reg.free_fp(registers) - - # ------------------------------------------------------------------ - # FPRAM address-space management - # ------------------------------------------------------------------ - - def allocate_fpram(self, name: str, size: int) -> int: - """Allocate FPRAM space, returns base address.""" - info = self.add_fpram_object(name=name, size=size) - if info.fpram_addr is None: - raise RuntimeError(f"Failed to allocate FPRAM for '{name}'") - return info.fpram_addr - - def free_fpram(self, name: str, strict: bool = True): - """Free FPRAM object by name.""" - return self.free_fpram_object(name, strict=strict) - - def save_fpram_state(self) -> int: - """Save FPRAM allocator snapshot.""" - return self.fpram_allocator.save_state() - - def restore_fpram_state(self, snapshot: int): - """Restore FPRAM allocator snapshot.""" - self.fpram_allocator.restore_state(snapshot) - - def list_fpram_allocations(self) -> list[str]: - """List currently allocated FPRAM object names.""" - return list(self.fpram_allocator.allocations.keys()) - - def get_fpram_addr(self, name: str) -> int: - """Get FPRAM base address from object name.""" - return self.get_fpram_layout(name).fpram_addr - - def get_fpram_size(self, name: str) -> int: - """Get FPRAM allocation size from object name.""" - return self.get_fpram_layout(name).size - - # ========================================================================= - # FPVar ISA helpers (address-based) - # ========================================================================= - - def fpvar_copy_asm(self, src_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Copy skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Copy: FPRAM[{dst_addr}:{dst_addr + count}] = FPRAM[{src_addr}:{src_addr + count}]"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_src}, {i}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_src}, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_fill_from_fpram_asm(self, dst_addr: int, src_fpram_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Fill skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Fill: FPRAM[{dst_addr}:{dst_addr + count}] = FPRAM[{src_fpram_addr}]"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_fpram_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - lines.append(f"S_LD_FP f1, gp{gp_src}, 0") - if self._unroll: - for i in range(count): - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_reci_asm(self, src_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Reci skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Reci: dst = 1/src, count={count}"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_src}, {i}") - lines.append("S_RECI_FP f1, f1, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_src}, 0") - lines.append("S_RECI_FP f1, f1, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_exp_asm(self, src_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Exp skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Exp: dst = exp(src), count={count}"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_src}, {i}") - lines.append("S_EXP_FP f1, f1, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_src}, 0") - lines.append("S_EXP_FP f1, f1, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_add_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Add skipped: count={count}\n") - gp = self._reg.allocate_gp(4) - gp_a, gp_b, gp_dst, gp_loop = gp - lines = [f"; FPVar Add: dst = src1 + src2, count={count}"] - lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {src1_addr}") - lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {src2_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_a}, {i}") - lines.append(f"S_LD_FP f2, gp{gp_b}, {i}") - lines.append("S_ADD_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_a}, 0") - lines.append(f"S_LD_FP f2, gp{gp_b}, 0") - lines.append("S_ADD_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_a}, gp{gp_a}, 1") - lines.append(f"S_ADDI_INT gp{gp_b}, gp{gp_b}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_sub_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Sub skipped: count={count}\n") - gp = self._reg.allocate_gp(4) - gp_a, gp_b, gp_dst, gp_loop = gp - lines = [f"; FPVar Sub: dst = src1 - src2, count={count}"] - lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {src1_addr}") - lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {src2_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_a}, {i}") - lines.append(f"S_LD_FP f2, gp{gp_b}, {i}") - lines.append("S_SUB_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_a}, 0") - lines.append(f"S_LD_FP f2, gp{gp_b}, 0") - lines.append("S_SUB_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_a}, gp{gp_a}, 1") - lines.append(f"S_ADDI_INT gp{gp_b}, gp{gp_b}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_mul_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Mul skipped: count={count}\n") - gp = self._reg.allocate_gp(4) - gp_a, gp_b, gp_dst, gp_loop = gp - lines = [f"; FPVar Mul: dst = src1 * src2, count={count}"] - lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {src1_addr}") - lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {src2_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_a}, {i}") - lines.append(f"S_LD_FP f2, gp{gp_b}, {i}") - lines.append("S_MUL_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_a}, 0") - lines.append(f"S_LD_FP f2, gp{gp_b}, 0") - lines.append("S_MUL_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_a}, gp{gp_a}, 1") - lines.append(f"S_ADDI_INT gp{gp_b}, gp{gp_b}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_max_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Max skipped: count={count}\n") - gp = self._reg.allocate_gp(4) - gp_a, gp_b, gp_dst, gp_loop = gp - lines = [f"; FPVar Max: dst = max(src1, src2), count={count}"] - lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {src1_addr}") - lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {src2_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_a}, {i}") - lines.append(f"S_LD_FP f2, gp{gp_b}, {i}") - lines.append("S_MAX_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_a}, 0") - lines.append(f"S_LD_FP f2, gp{gp_b}, 0") - lines.append("S_MAX_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_a}, gp{gp_a}, 1") - lines.append(f"S_ADDI_INT gp{gp_b}, gp{gp_b}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_sum_asm(self, src_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Sum skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Sum: FPRAM[{dst_addr}] = sum(FPRAM[{src_addr}:{src_addr + count}])"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - lines.append("S_ADD_FP f1, f0, f0") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f2, gp{gp_src}, {i}") - lines.append("S_ADD_FP f1, f1, f2") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f2, gp{gp_src}, 0") - lines.append("S_ADD_FP f1, f1, f2") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_shift_asm( - self, - src_addr: int, - dst_addr: int, - count: int, - shift: int, - fill_fpram_addr: int = 0, - ) -> str: - """ - Shift FPVar into dst. - - shift > 0: right shift (leading positions filled) - - shift < 0: left shift (trailing positions filled) - """ - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_fill = gp - lines = [f"; FPVar Shift: dst=shift(src, shift={shift}), count={count}, fill=FPRAM[{fill_fpram_addr}]"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - lines.append(f"S_ADDI_INT gp{gp_fill}, gp0, {fill_fpram_addr}") - lines.append(f"S_LD_FP f3, gp{gp_fill}, 0") - - for i in range(count): - src_idx = i - shift - if 0 <= src_idx < count: - lines.append(f"S_LD_FP f1, gp{gp_src}, {src_idx}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"S_ST_FP f3, gp{gp_dst}, {i}") - - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - # ========================================================================= - # FPVar helpers (name-based wrappers over the address-based ISA generators) - # ========================================================================= - - def fpram_copy(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - return self.fpvar_copy_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) - - def fpram_reci(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - return self.fpvar_reci_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) - - def fpram_exp(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - return self.fpvar_exp_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) - - def fpram_add(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_add_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) - - def fpram_sub(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_sub_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) - - def fpram_mul(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_mul_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) - - def fpram_max(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_max_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) - - def fpram_sum(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = self.get_fpram_size(src_name) - return self.fpvar_sum_asm( - self.get_fpram_addr(src_name), - self.get_fpram_addr(dst_name), - count, - ) - - def fpram_shift( - self, - src_name: str, - dst_name: str, - shift: int, - count: int | None = None, - fill_fpram_name: str | None = None, - ) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - fill_addr = 0 if fill_fpram_name is None else self.get_fpram_addr(fill_fpram_name) - return self.fpvar_shift_asm( - src_addr=self.get_fpram_addr(src_name), - dst_addr=self.get_fpram_addr(dst_name), - count=count, - shift=shift, - fill_fpram_addr=fill_addr, - ) - - def fpram_fill_from_fpram(self, dst_name: str, src_fpram_addr: int, count: int | None = None) -> str: - if count is None: - count = self.get_fpram_size(dst_name) - return self.fpvar_fill_from_fpram_asm( - dst_addr=self.get_fpram_addr(dst_name), - src_fpram_addr=src_fpram_addr, - count=count, - ) - - # ========================================================================= - # Tile-row helpers (name-based) - # ========================================================================= - - def tile_row_max( - self, - source_matrix: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - source_addr = self.get_vram_tile_addr(source_matrix, tile_row_idx, tile_col_idx) - return self.tile_row_max_asm(source_addr, row_map) - - def tile_row_sum( - self, - source_matrix: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - source_addr = self.get_vram_tile_addr(source_matrix, tile_row_idx, tile_col_idx) - return self.tile_row_sum_asm(source_addr, row_map) - - def tile_row_exp( - self, - matrix_name: str, - rows: list[int], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_exp_asm(matrix_addr, rows) - - def tile_row_reci( - self, - matrix_name: str, - rows: list[int], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_reci_asm(matrix_addr, rows) - - def tile_row_sub_fp( - self, - matrix_name: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_sub_fp_asm(matrix_addr, row_map) - - def tile_row_mul_fp( - self, - matrix_name: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_mul_fp_asm(matrix_addr, row_map) - - def tile_row_add_fp( - self, - matrix_name: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_add_fp_asm(matrix_addr, row_map) - - def tile_row_add( - self, - dst_matrix: str, - src_matrix: str, - rows: list[int], - dst_tile_row_idx: int = 0, - dst_tile_col_idx: int = 0, - src_tile_row_idx: int = 0, - src_tile_col_idx: int = 0, - ) -> str: - dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) - src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) - return self.tile_row_add_asm(dst_addr, src_addr, rows) - - def tile_row_sub( - self, - dst_matrix: str, - src_matrix: str, - rows: list[int], - dst_tile_row_idx: int = 0, - dst_tile_col_idx: int = 0, - src_tile_row_idx: int = 0, - src_tile_col_idx: int = 0, - ) -> str: - dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) - src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) - return self.tile_row_sub_asm(dst_addr, src_addr, rows) - - def tile_row_mul( - self, - dst_matrix: str, - src_matrix: str, - rows: list[int], - dst_tile_row_idx: int = 0, - dst_tile_col_idx: int = 0, - src_tile_row_idx: int = 0, - src_tile_col_idx: int = 0, - ) -> str: - dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) - src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) - return self.tile_row_mul_asm(dst_addr, src_addr, rows) - - def tile_row_mul_fp_broadcast( - self, - matrix_name: str, - fpram_scalar_addr: int, - rows: list[int], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_mul_fp_broadcast_asm(matrix_addr, fpram_scalar_addr, rows) - - def vram_fill_zero( - self, - matrix_name: str, - rows: list[int], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.vram_fill_zero_asm(matrix_addr, rows) - - # ========================================================================= - # Tile-row ISA helpers (address-based) - # ========================================================================= - - def _arith_progression(self, values: list[int]) -> tuple[int, int, int] | None: - """Return (start, count, step) if values form an arithmetic progression.""" - if not values: - return None - if len(values) == 1: - return (values[0], 1, 0) - step = values[1] - values[0] - for i in range(2, len(values)): - if values[i] - values[i - 1] != step: - return None - if step == 0: - return None # Constant sequence (step=0, count>1) would cause infinite HW loop - return (values[0], len(values), step) - - def tile_row_max_asm(self, source_vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; Tile Row Max from VRAM[{source_vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {source_vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_RED_MAX f1, gp{gp_src}, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = source_vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"V_RED_MAX f1, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fpram_addr}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_sum_asm(self, source_vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; Tile Row Sum from VRAM[{source_vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {source_vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append("S_ADD_FP f1, f0, f0") - lines.append(f"V_RED_SUM f1, gp{gp_src}, 0, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = source_vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append("S_ADD_FP f1, f0, f0") - lines.append(f"V_RED_SUM f1, gp{gp_src}, 0, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fpram_addr}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_exp_asm(self, vram_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(2) - gp_src, gp_loop = gp - lines = [f"; Tile Row Exp on VRAM[{vram_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_EXP_V gp{gp_src}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"V_EXP_V gp{gp_src}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_reci_asm(self, vram_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(2) - gp_src, gp_loop = gp - lines = [f"; Tile Row Reciprocal on VRAM[{vram_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_RECI_V gp{gp_src}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"V_RECI_V gp{gp_src}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_sub_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_fp, gp_loop = gp - lines = [f"; Tile Row Sub FP on VRAM[{vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_SUB_VF gp{gp_src}, gp{gp_src}, f1, 0, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp{gp_fp}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fpram_addr}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_SUB_VF gp{gp_src}, gp{gp_src}, f1, 0, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_mul_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_fp, gp_loop = gp - lines = [f"; Tile Row Mul FP on VRAM[{vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_MUL_VF gp{gp_src}, gp{gp_src}, f1, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp{gp_fp}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fpram_addr}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_MUL_VF gp{gp_src}, gp{gp_src}, f1, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_add_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_fp, gp_loop = gp - lines = [f"; Tile Row Add FP on VRAM[{vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_ADD_VF gp{gp_src}, gp{gp_src}, f1, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp{gp_fp}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fpram_addr}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_ADD_VF gp{gp_src}, gp{gp_src}, f1, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_add_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_dst, gp_src, gp_loop = gp - lines = [f"; Tile Row Add: VRAM[{dst_addr}] += VRAM[{src_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - d = dst_addr + row_idx * self.mlen - s = src_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {d}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {s}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_sub_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_dst, gp_src, gp_loop = gp - lines = [f"; Tile Row Sub: VRAM[{dst_addr}] -= VRAM[{src_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_SUB_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - d = dst_addr + row_idx * self.mlen - s = src_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {d}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {s}") - lines.append(f"V_SUB_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_mul_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_dst, gp_src, gp_loop = gp - lines = [f"; Tile Row Mul: VRAM[{dst_addr}] *= VRAM[{src_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_MUL_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - d = dst_addr + row_idx * self.mlen - s = src_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {d}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {s}") - lines.append(f"V_MUL_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_mul_fp_broadcast_asm(self, vram_addr: int, fpram_scalar_addr: int, rows: list[int]) -> str: - row_map = [(r, fpram_scalar_addr) for r in rows] - return self.tile_row_mul_fp_asm(vram_addr, row_map) - - def vram_fill_zero_asm( - self, - vram_addr: int, - rows: list[int], - ) -> str: - """ - VRAM Fill Zero: fill specified rows with 0. - - For each row_idx in rows: - VRAM[row] = 0 - """ - if not rows: - isa_code = f"; === VRAM Fill Zero: VRAM[{vram_addr}] rows [] = 0 ===\n" - self.generated_code += isa_code - return isa_code - - gp_regs = self.register_allocator.allocate_gp(2) - gp_dst, gp_loop = gp_regs - - lines = [] - lines.append(f"; === VRAM Fill Zero: VRAM[{vram_addr}] rows {rows} = 0 ===") - prog = None if self.unroll_loops else self._arith_progression(rows) - - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_MUL_VF gp{gp_dst}, gp{gp_dst}, f0, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {row_addr}") - lines.append(f"V_MUL_VF gp{gp_dst}, gp{gp_dst}, f0, 0") - - self.register_allocator.free_gp(gp_regs) - - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def reset_mram(self) -> str: - """ - Reset MRAM allocator, free all allocated space - Used in scenarios where sub-blocks need to be reloaded within a for loop - """ - self.mram_allocator.reset() - self.loaded_sub_blocks.clear() - - isa_code = "; === Reset MRAM ===\n" - self.generated_code += isa_code - return isa_code - - def load_sub_matrix_row( - self, - name: str, - row_idx: int, - mram_start_addr: int | None = None, - ) -> str: - """Load entire row sub-blocks from HBM to MRAM: matrix[row_idx][:].""" - layout = self.get_hbm_layout(name) - num_col_blocks = layout.num_col_blocks - block_size = self.mlen * self.mlen - - if mram_start_addr is None: - total_size = num_col_blocks * block_size - mram_start_addr = self.mram_allocator.allocate(f"{name}[{row_idx}][:]", total_size) - - gp_regs = self.register_allocator.allocate_gp(4) - gp_for_addr = self.register_allocator.allocate_gp(2) - addr_reg = self.register_allocator.allocate_addr(1)[0] - - isa_code = preload_addr_reg_asm( - addr_reg_to_set=[addr_reg], available_registers=gp_for_addr, addr_reg_val=[layout.hbm_base_addr] - ) - - isa_code += self.load_row_sub_matrices_asm( - name=name, row_idx=row_idx, mram_start_addr=mram_start_addr, hbm_addr_reg=addr_reg, gp_regs=gp_regs - ) - - self.register_allocator.free_gp(gp_regs) - self.register_allocator.free_gp(gp_for_addr) - self.register_allocator.free_addr([addr_reg]) - - self.generated_code += isa_code - return isa_code - - def load_sub_matrix_col( - self, - name: str, - col_idx: int, - mram_start_addr: int | None = None, - k_block_start: int = 0, - k_block_count: int | None = None, - ) -> str: - """ - Load entire column sub-blocks from HBM to MRAM: matrix[:][col_idx]. - Used for sub_projection: A @ W[:, col_idx*mlen:(col_idx+1)*mlen]. - """ - layout = self.get_hbm_layout(name) - num_row_blocks = layout.num_row_blocks - block_size = self.mlen * self.mlen - - if mram_start_addr is None: - effective_count = k_block_count if k_block_count is not None else num_row_blocks - total_size = effective_count * block_size - mram_start_addr = self.mram_allocator.allocate(f"{name}[:][{col_idx}]", total_size) - - gp_regs = self.register_allocator.allocate_gp(3) - gp_for_addr = self.register_allocator.allocate_gp(2) - addr_reg = self.register_allocator.allocate_addr(1)[0] - - isa_code = preload_addr_reg_asm( - addr_reg_to_set=[addr_reg], available_registers=gp_for_addr, addr_reg_val=[layout.hbm_base_addr] - ) - - isa_code += self.load_col_sub_matrices_asm( - name=name, - col_idx=col_idx, - mram_start_addr=mram_start_addr, - hbm_addr_reg=addr_reg, - gp_regs=gp_regs, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - - self.register_allocator.free_gp(gp_regs) - self.register_allocator.free_gp(gp_for_addr) - self.register_allocator.free_addr([addr_reg]) - - self.generated_code += isa_code - return isa_code - - def allocate_vram_matrix( - self, - name: str, - rows: int, - cols: int, - strict: bool = True, - ) -> int: - """Allocate a VRAM matrix large enough to hold combined results of multiple sub-blocks. Returns the VRAM base address.""" - size = rows * cols - vram_addr = self.vram_allocator.allocate(size, name=name) - - self.add_vram_object( - name=name, - shape=(rows, cols), - vram_addr=vram_addr, - dtype="fp32", - kind="VRAMMatrix", - allocate_if_none=False, - strict=strict, - ) - - isa_code = f"; Allocate VRAM Matrix {name}: ({rows}, {cols}) at VRAM[{vram_addr}]\n" - self.generated_code += isa_code - - return vram_addr - - def _ensure_vram_matrix_layout(self, matrix_name: str): - """Ensure a VRAM-resident tensor has a block layout in TileCompiler.""" - if matrix_name not in self: - raise KeyError(f"Matrix '{matrix_name}' not found in symbol table") - - info = self[matrix_name] - if info.vram_addr is None: - raise ValueError(f"Matrix '{matrix_name}' has no VRAM address") - - try: - self.get_vram_layout(matrix_name) - except KeyError: - self.register_vram_matrix( - name=matrix_name, - shape=info.shape, - vram_base_addr=info.vram_addr, - ) - - def vram_block_add_to( - self, - src1_matrix: str, - src1_row_idx: int, - src1_col_idx: int, - src2_matrix: str, - src2_row_idx: int, - src2_col_idx: int, - target_matrix: str, - target_row_idx: int, - target_col_idx: int, - ) -> str: - """ - mlen x mlen block add: - target[rt][ct] = src1[r1][c1] + src2[r2][c2] - - Source/target may be the same matrix (supports in-place overwrite). - """ - self._ensure_vram_matrix_layout(src1_matrix) - self._ensure_vram_matrix_layout(src2_matrix) - self._ensure_vram_matrix_layout(target_matrix) - - gp_regs = self.register_allocator.allocate_gp(4) - isa_code = self.vram_block_add_asm( - src1_name=src1_matrix, - src1_row_idx=src1_row_idx, - src1_col_idx=src1_col_idx, - src2_name=src2_matrix, - src2_row_idx=src2_row_idx, - src2_col_idx=src2_col_idx, - target_name=target_matrix, - target_row_idx=target_row_idx, - target_col_idx=target_col_idx, - gp_regs=gp_regs, - ) - self.register_allocator.free_gp(gp_regs) - - self.generated_code += isa_code - return isa_code - - def vram_matrix_add( - self, - dst_matrix: str, - src_matrix: str, - dst_row_offset: int = 0, - src_row_offset: int = 0, - num_rows: int | None = None, - ) -> str: - """ - General VRAM Matrix Addition: dst[row_offset:] += src. - - row_offsets are logical rows (not VRAM addresses); num_rows defaults - to the source matrix's row count. - """ - dst_info = self[dst_matrix] - src_info = self[src_matrix] - - # Block-add path depends on TileCompiler VRAM layouts. - self._ensure_vram_matrix_layout(dst_matrix) - self._ensure_vram_matrix_layout(src_matrix) - - dst_addr = dst_info.vram_addr - src_addr = src_info.vram_addr - - dst_rows, dst_cols = dst_info.shape - src_rows, src_cols = src_info.shape - - if num_rows is None: - num_rows = src_rows - - # Ensure column count matches - assert dst_cols == src_cols, f"Column mismatch: dst={dst_cols}, src={src_cols}" - assert dst_row_offset + num_rows <= dst_rows, ( - f"dst row range out of bounds: offset={dst_row_offset}, num_rows={num_rows}, dst_rows={dst_rows}" - ) - assert src_row_offset + num_rows <= src_rows, ( - f"src row range out of bounds: offset={src_row_offset}, num_rows={num_rows}, src_rows={src_rows}" - ) - lines = [] - lines.append( - f"; === VRAM Matrix Add: " - f"{dst_matrix}[{dst_row_offset}:{dst_row_offset + num_rows}] += " - f"{src_matrix}[{src_row_offset}:{src_row_offset + num_rows}] ===" - ) - lines.append(f"; dst shape: {dst_info.shape}, src shape: {src_info.shape}") - - # Prefer block add path so we can reuse the compact C_LOOP-based add kernel. - block_aligned = ( - dst_cols % self.mlen == 0 - and src_cols % self.mlen == 0 - and dst_row_offset % self.mlen == 0 - and src_row_offset % self.mlen == 0 - and num_rows % self.mlen == 0 - ) - - if block_aligned: - num_row_blocks = num_rows // self.mlen - num_col_blocks = dst_cols // self.mlen - dst_row_block_base = dst_row_offset // self.mlen - src_row_block_base = src_row_offset // self.mlen - lines.append(f"; block add path: row_blocks={num_row_blocks}, col_blocks={num_col_blocks}") - - for row_block in range(num_row_blocks): - for col_block in range(num_col_blocks): - gp_regs = self.register_allocator.allocate_gp(4) - lines.append( - self.vram_block_add_asm( - src1_name=dst_matrix, - src1_row_idx=dst_row_block_base + row_block, - src1_col_idx=col_block, - src2_name=src_matrix, - src2_row_idx=src_row_block_base + row_block, - src2_col_idx=col_block, - target_name=dst_matrix, - target_row_idx=dst_row_block_base + row_block, - target_col_idx=col_block, - gp_regs=gp_regs, - ).rstrip("\n") - ) - self.register_allocator.free_gp(gp_regs) - else: - # Fallback for non-mlen-aligned ranges. - gp_regs = self.register_allocator.allocate_gp(2) - gp_dst = gp_regs[0] - gp_src = gp_regs[1] - num_col_blocks = dst_cols // self.mlen - lines.append(f"; fallback row-wise path: num_rows={num_rows}, num_col_blocks={num_col_blocks}") - - for row in range(num_rows): - dst_actual_row = dst_row_offset + row - src_actual_row = src_row_offset + row - - for col_block in range(num_col_blocks): - dst_block_addr = dst_addr + col_block * dst_rows * self.mlen + dst_actual_row * self.mlen - src_block_addr = src_addr + col_block * src_rows * self.mlen + src_actual_row * self.mlen - - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_block_addr}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_block_addr}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp_regs) - - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def vram_sub_projection_to( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_col_idx: int, - target_matrix: str, - target_row_idx: int, - target_col_idx: int, - k_block_start: int = 0, - k_block_count: int | None = None, - ) -> str: - """ - Sub-block multiplication: - target[target_row_idx][target_col_idx] = VRAM_A[vram_row_idx][:] @ MRAM_W[:][mram_col_idx]. - Target matrix must have been allocated via allocate_vram_matrix. - """ - if target_matrix not in self: - raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") - - target_info = self[target_matrix] - target_rows, _target_cols = target_info.shape - target_base_addr = target_info.vram_addr - - # VRAM layout: [batch, mlen, hidden/mlen], column-block major. - # Sub-block (r, c) addr = base + c * rows * mlen + r * mlen * mlen. - result_vram_addr = ( - target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen - ) - - gp_regs = self.register_allocator.allocate_gp(9) - - isa_code = f"; VRAM Sub Projection To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}] -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" - isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" - isa_code += self.vram_sub_projection_asm( - vram_mat_name=vram_mat_name, - vram_row_idx=vram_row_idx, - mram_mat_name=mram_mat_name, - mram_col_idx=mram_col_idx, - result_vram_addr=result_vram_addr, - gp_regs=gp_regs, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - - self.register_allocator.free_gp(gp_regs) - - self.generated_code += isa_code - return isa_code - - def vram_sub_projection_T_to( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_row_idx: int, - target_matrix: str, - target_row_idx: int, - target_col_idx: int, - ) -> str: - """ - Transposed sub-block multiplication: - target[target_row_idx][target_col_idx] = VRAM_A[vram_row_idx][:] @ MRAM_W[mram_row_idx][:]^T. - - Used by Flash Attention for S = Q @ K^T: - Q[i][:]: (mlen, hidden_size) row sub-block - K[j][:]: (mlen, hidden_size) row sub-block, transposed to (hidden_size, mlen) - S[i][j]: (mlen, mlen) - """ - if target_matrix not in self: - raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") - - target_info = self[target_matrix] - target_rows, _target_cols = target_info.shape - target_base_addr = target_info.vram_addr - - # VRAM layout: [batch, mlen, hidden/mlen], column-block major. - # Sub-block (r, c) addr = base + c * rows * mlen + r * mlen * mlen. - result_vram_addr = ( - target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen - ) - - gp_regs = self.register_allocator.allocate_gp(9) - - isa_code = f"; VRAM Sub Projection T To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" - isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" - isa_code += self.vram_sub_projection_T_asm( - vram_mat_name=vram_mat_name, - vram_row_idx=vram_row_idx, - mram_mat_name=mram_mat_name, - mram_row_idx=mram_row_idx, - result_vram_addr=result_vram_addr, - gp_regs=gp_regs, - ) - - self.register_allocator.free_gp(gp_regs) - - self.generated_code += isa_code - return isa_code - - # ========================================================================= - # Expanded Flash Attention Operations - # ========================================================================= - - def init_online_softmax( - self, - q_idx: int, - o_matrix: str, - seq_len: int, - head_dim: int, - ) -> str: - """ - Initialize Online Softmax state for Q block q_idx: - m_old = -inf (FP SRAM), l = 0 (FP SRAM), O_row = 0 (VRAM). - """ - fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE - m_old_addr = fp_sram_start - l_addr = fp_sram_start + 2 * self.mlen # skip m_res region - - o_info = self[o_matrix] - o_vram_addr = o_info.vram_addr - row_offset = q_idx * self.mlen - - isa_code = f"; === Init Online Softmax for Q block {q_idx} ===\n" - - isa_code += self._reset_fpsram_asm(m_old_addr, self.mlen, 2) # slot 2 = -inf - isa_code += self._reset_fpsram_asm(l_addr, self.mlen, 0) # slot 0 = 0.0 - isa_code += self._reset_vram_asm( - start_address=o_vram_addr, - rows=self.mlen, - cols=head_dim, - total_rows=seq_len, - mlen=self.mlen, - row_offset=row_offset, - ) - - self.generated_code += isa_code - return isa_code - - def online_softmax_block( - self, - s_block_matrix: str, - scale: float, - ) -> str: - """ - Run Online Softmax on one S block. - Input: S_block (mlen × mlen) in VRAM - Output: P (mlen × mlen) in-place in VRAM - Updates: m_old, m_res, l in FP SRAM - ``scale`` is the QK^T scaling factor (typically 1/sqrt(d)). - """ - s_info = self[s_block_matrix] - s_address = s_info.vram_addr - - fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE - m_start_address = fp_sram_start - - isa_code = f"; === Online Softmax Block {s_block_matrix} ===\n" - isa_code += self._online_softmax_asm( - mlen=self.mlen, s_address=s_address, m_start_address=m_start_address, scale=scale - ) - - self.generated_code += isa_code - return isa_code - - def compute_pv( - self, - s_block_matrix: str, - v_sub_matrix: str, - k_idx: int, - pv_matrix: str, - head_dim: int, - ) -> str: - """ - Compute PV = P @ V[k_idx]. - - P lives in s_block_matrix (softmax result); V is prefetched from - HBM; PV is written to VRAM via pv_matrix. - """ - s_info = self[s_block_matrix] - p_address = s_info.vram_addr - - pv_info = self[pv_matrix] - pv_address = pv_info.vram_addr - - v_layout = self.get_hbm_layout(v_sub_matrix) - v_hbm_offset = k_idx * self.mlen * head_dim - - isa_code = f"; === Compute PV = P @ V[k_idx={k_idx}] ===\n" - - addr_regs = self.register_allocator.allocate_addr(1) - v_hbm_reg = addr_regs[0] - gp_regs = self.register_allocator.allocate_gp(2) - - from compiler.asm_templates import preload_addr_reg_asm - - isa_code += preload_addr_reg_asm( - addr_reg_to_set=[v_hbm_reg], available_registers=gp_regs, addr_reg_val=[v_layout.hbm_base_addr] - ) - - isa_code += self._pv_multiply_asm( - mlen=self.mlen, - blen=self.blen, - head_dim=head_dim, - p_address=p_address, - v_hbm_offset_reg=v_hbm_reg, - v_hbm_offset=v_hbm_offset, - pv_address=pv_address, - ) - - self.register_allocator.free_gp(gp_regs) - self.register_allocator.free_addr(addr_regs) - - self.generated_code += isa_code - return isa_code - - def scale_o_row( - self, - o_matrix: str, - q_idx: int, - seq_len: int, - head_dim: int, - ) -> str: - """Scale the current row block of O by m_res: O[q_idx] *= m_res.""" - o_info = self[o_matrix] - o_address = o_info.vram_addr - - fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE - m_res_addr = fp_sram_start + self.mlen - - row_offset = q_idx * self.mlen - - isa_code = f"; === Scale O[q_idx={q_idx}] by m_res ===\n" - isa_code += self._scale_o_asm( - mlen=self.mlen, - head_dim=head_dim, - seq_len=seq_len, - m_res_address=m_res_addr, - o_address=o_address, - row_offset=row_offset, - ) - - self.generated_code += isa_code - return isa_code - - def final_scale_o( - self, - q_idx: int, - o_matrix: str, - seq_len: int, - head_dim: int, - ) -> str: - """Final scaling: O[q_idx] /= l.""" - o_info = self[o_matrix] - o_address = o_info.vram_addr - - fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE - l_addr = fp_sram_start + 2 * self.mlen - - row_offset = q_idx * self.mlen - - isa_code = f"; === Final Scale O for Q block {q_idx} ===\n" - isa_code += self._final_scaling_asm( - mlen=self.mlen, - head_dim=head_dim, - seq_len=seq_len, - l_address=l_addr, - o_address=o_address, - row_offset=row_offset, - ) - - self.generated_code += isa_code - return isa_code - - -# Example Usage - - -class _DeveloperView: - """ - Back-compat proxy for legacy ``prog._compiler.X(...)`` call sites. - - PlenaCompiler now inherits DeveloperCompiler rather than composing it, so - for call sites that still expect to reach the low-level DeveloperCompiler - API (e.g., ``allocate_fpram(name=..., size=...)`` returning an int), we - expose a proxy whose attribute lookup resolves callables on - DeveloperCompiler directly (bypassing any PlenaCompiler overrides with - colliding names). Non-callable attributes (e.g., ``generated_code``, - ``vram_allocator``) fall through to the underlying instance unchanged. - """ - - __slots__ = ("_inst",) - - def __init__(self, inst: PlenaCompiler): - object.__setattr__(self, "_inst", inst) - - def __getattr__(self, name: str): - cls_attr = getattr(DeveloperCompiler, name, None) - if cls_attr is not None and callable(cls_attr): - return cls_attr.__get__(self._inst, DeveloperCompiler) - return getattr(self._inst, name) - - def __setattr__(self, name: str, value): - setattr(self._inst, name, value) - - -# ============================================================================ -# TensorVar Proxy Object Hierarchy -# ============================================================================ - - -class TensorVar: - """ - Tensor proxy object base class - - All tensor variables inherit from this class. - Supports __matmul__ (`@`) operator, which automatically dispatches to appropriate PlenaCompiler methods. - - Dual naming: - - display_name: User-visible name (e.g., "temp", "Q", "S") - - internal_name: System internal name (e.g., "my_func_0/temp"), used for symbol table and ISA generation - """ - - def __init__( - self, - program: PlenaCompiler, - internal_name: str, - kind: str, - shape: tuple[int, int], - display_name: str | None = None, - ): - self._program = program - self.internal_name = internal_name # System internal name (with scope prefix), used by symbol table - self.display_name = display_name if display_name is not None else internal_name # User-visible name - self.kind = kind # "input", "batch", "matrix", "vram_matrix" - self.shape = shape - - @property - def name(self) -> str: - """Compatibility property: returns internal_name for internal system use""" - return self.internal_name - - def __matmul__(self, other): - """A @ B: Dispatch to appropriate computation based on operand types""" - return self._program._dispatch_matmul(self, other) - - def __repr__(self): - if self.display_name != self.internal_name: - return ( - f"{self.__class__.__name__}(display={self.display_name!r}, " - f"internal={self.internal_name!r}, shape={self.shape})" - ) - return f"{self.__class__.__name__}({self.display_name!r}, shape={self.shape})" - - -class InputVar(TensorVar): - """ - Input variable: tensor declared in HBM - - Not yet loaded to VRAM; needs to be loaded via load_batch / load_matrix. - - If ``prestaged_vram_addr`` is not None the tensor is assumed to be already - present in VRAM at that byte address. ``load_batch`` will register it at - that address without emitting any HBM→VRAM prefetch instructions. - """ - - def __init__( - self, - program: PlenaCompiler, - name: str, - shape: tuple[int, int], - hbm_addr: int, - hbm_size: int, - display_name: str | None = None, - prestaged_vram_addr: int | None = None, - ): - super().__init__(program, name, "input", shape, display_name=display_name) - self.hbm_addr = hbm_addr - self.hbm_size = hbm_size - self.prestaged_vram_addr = prestaged_vram_addr - - -class FPVar: - """ - FP variable: maps to a contiguous region in FPRAM - - Declared via prog.fp_var("scale", size=1), automatically allocates FPRAM space. - Provides .address for ISA generation (S_LD_FP / S_ST_FP). - - Usage: - scale = prog.fp_var("scale", size=1) - m_old = prog.fp_var("m_old", size=64) - - scale.address # -> FPRAM address (int) - scale.size # -> number of elements - scale[3] # -> address + 3 (element offset) - """ - - def __init__( - self, program: PlenaCompiler, internal_name: str, address: int, size: int, display_name: str | None = None - ): - self._program = program - self.internal_name = internal_name - self.display_name = display_name if display_name is not None else internal_name - self.address = address - self.size = size - - @property - def name(self) -> str: - return self.internal_name - - def __getitem__(self, idx: int) -> int: - """Element offset: fp_var[i] -> address + i""" - if idx < 0 or idx >= self.size: - raise IndexError(f"FPVar '{self.display_name}' index {idx} out of range [0, {self.size})") - return self.address + idx - - def __repr__(self): - return f"FPVar({self.display_name!r}, addr={self.address}, size={self.size})" - - -class VRAMMatrixVar(TensorVar): - """ - VRAM matrix variable: large matrix allocated via alloc - - Used to store intermediate results (e.g., S block, PV, O). - Supports sub-block indexed writes: `O[r][c] = ...` - """ - - def __init__(self, program: PlenaCompiler, name: str, shape: tuple[int, int], display_name: str | None = None): - super().__init__(program, name, "vram_matrix", shape, display_name=display_name) - - -# ============================================================================ -# PlenaCompiler Main Class -# ============================================================================ - - -class PlenaCompiler(DeveloperCompiler): - """ - PLENA High-level Compiler Interface. - - Inherits the ISA-emission machinery from DeveloperCompiler and layers a - Pythonic DSL on top. All operations are eagerly evaluated — ISA code is - generated immediately upon call. - """ - - def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): - """ - Args: - mlen: Matrix tile size (default 64) - blen: Vector tile size (default 4) - real_data_ratio: HBM storage ratio (MXFP8 format = 1.125) - unroll_loops: If True, unroll sub-projection loops at ASM-gen time to - eliminate C_LOOP_START/END overhead. Overridden by the - ATEN_UNROLL env var ("1"=True, "0"=False). - """ - _env_unroll = os.environ.get("ATEN_UNROLL", "") - if _env_unroll == "1": - unroll_loops = True - elif _env_unroll == "0": - unroll_loops = False - super().__init__(mlen=mlen, blen=blen, real_data_ratio=real_data_ratio, unroll_loops=unroll_loops) - - # HBM address auto-allocation - self._next_hbm_addr: int = 0 - self._hbm_free_blocks: list[tuple[int, int]] = [] # (addr, size) - - # Variable registries - self._inputs: dict[str, InputVar] = {} - self._tensors: dict[str, TensorVar] = {} - self._fp_vars: dict[str, FPVar] = {} - self._functions: dict[str, Callable] = {} - self._registered_hbm_sub_matrices: dict[str, bool] = {} - self._registered_vram_sub_matrices: dict[str, bool] = {} - - self._result_tensor: TensorVar | None = None - - # Auto-generated name counter - self._auto_name_counter: int = 0 - - # Function scope namespace - # Push a prefix on each function call (e.g., "linear_0/"), pop on exit - # _auto_name will automatically add current scope prefix, avoiding name conflicts when calling the same function multiple times - self._scope_stack: list[str] = [] - self._function_call_counters: dict[str, int] = {} # func_name -> call count - - # ======================================================================== - # Property Access - # ======================================================================== - - # mlen / blen are instance attributes inherited from TileCompiler.__init__. - - @property - def compiler(self) -> PlenaCompiler: - """Legacy accessor — returns self now that PlenaCompiler is the compiler.""" - return self - - @property - def _compiler(self) -> _DeveloperView: - """Back-compat shim for legacy ``prog._compiler.X(...)`` call sites. - Returns a proxy that resolves callables against DeveloperCompiler - directly so callers reach the low-level API regardless of any - PlenaCompiler override with the same name.""" - return _DeveloperView(self) - - @property - def symbol_table(self): - """Access symbol table.""" - return self.get_symbol_table() - - # ======================================================================== - # Input Declaration - # ======================================================================== - - def input( - self, - name: str, - shape: tuple[int, int], - hbm_addr: int | None = None, - prestaged_vram_addr: int | None = None, - ) -> InputVar: - """ - Declare an input tensor (in HBM). - - Args: - name: tensor name - shape: (height, width) - hbm_addr: HBM address (None = auto-allocate) - prestaged_vram_addr: If an int, the tensor is assumed to be already - present in VRAM at this byte address. A subsequent call to - ``load_batch`` will register it at that address without emitting - any HBM→VRAM prefetch instructions. If None (default), the - normal HBM→VRAM load path is used. - - Returns: - InputVar proxy object - """ - h, w = shape - size = h * w - hbm_size = int(size * self.real_data_ratio) - - if hbm_addr is None: - hbm_addr = self._allocate_hbm(hbm_size) - - var = InputVar(self, name, shape, hbm_addr, hbm_size, prestaged_vram_addr=prestaged_vram_addr) - self._inputs[name] = var - super().add_hbm_object( - name=name, - hbm_addr=hbm_addr, - shape=shape, - real_data_ratio=self.real_data_ratio, - ) - return var - - # ======================================================================== - # Load Operations - # ======================================================================== - - def load_batch( - self, - input_var: InputVar, - name: str | None = None, - ) -> VRAMMatrixVar: - """ - Load tensor from HBM to VRAM (Batch type). - - When ``input_var.prestaged_vram_addr`` is set the tensor is assumed to - be already resident in VRAM at that address. No HBM→VRAM prefetch - instructions are emitted; the tensor is simply registered in the symbol - table at the given address. - - Args: - input_var: source InputVar - name: result name (None = use input name) - - Returns: - VRAMMatrixVar proxy object - """ - if not isinstance(input_var, InputVar): - raise TypeError(f"Expected InputVar, got {type(input_var)}") - - display_name = name if name is not None else input_var.display_name - internal_name = self._scoped_name(display_name) - - if input_var.prestaged_vram_addr is not None: - # Prestaged path: tensor is already in VRAM — register without ISA. - h, w = input_var.shape - vram_addr = input_var.prestaged_vram_addr - # Tell the VRAM allocator that this region is occupied so subsequent - # allocations don't collide with it. - self.vram_allocator._vmm.mark_used(vram_addr, h * w, name=internal_name) - super().add_vram_object( - name=internal_name, - shape=(h, w), - vram_addr=vram_addr, - dtype="fp16", - kind="Batch", - allocate_if_none=False, - strict=False, - ) - else: - # Normal path: emit HBM → VRAM prefetch ISA. - super().load_batch( - hbm_object_name=input_var.name, vram_object_name=internal_name, vlen=self.mlen, preload_len=4 - ) - - var = VRAMMatrixVar(self, internal_name, input_var.shape, display_name=display_name) - self._tensors[internal_name] = var - return var - - # ======================================================================== - # Store Operations - # ======================================================================== - - def store(self, tensor_var, name: str | None = None, hbm_addr: int | None = None) -> InputVar: - """ - Write tensor from VRAM back to HBM. - - Returns: - InputVar proxy object (can be loaded back later) - """ - if not isinstance(tensor_var, VRAMMatrixVar): - raise TypeError(f"Store requires VRAMMatrixVar, got {type(tensor_var)}") - - display_name = name if name is not None else f"{tensor_var.display_name}_stored" - internal_name = self._scoped_name(display_name) - - if hbm_addr is None: - h, w = tensor_var.shape - size = h * w - hbm_size = int(size * self.real_data_ratio) - hbm_addr = self._allocate_hbm(hbm_size) - else: - h, w = tensor_var.shape - hbm_size = int(h * w * self.real_data_ratio) - - super().store_to_hbm( - tensor_name=tensor_var.name, # internal name for symbol table lookup - hbm_addr=hbm_addr, - hbm_object_name=internal_name, - vlen=self.mlen, - ) - - var = InputVar(self, internal_name, tensor_var.shape, hbm_addr, hbm_size, display_name=display_name) - self._inputs[internal_name] = var - return var - - # ======================================================================== - # VRAM Matrix Allocation - # ======================================================================== - - def alloc(self, name: str, rows: int, cols: int, strict: bool = True) -> VRAMMatrixVar: - """ - Allocate a VRAM matrix. - - Used to store intermediate results (e.g., S block, PV, O). - Within function scope, names are automatically prefixed to avoid conflicts. - - Args: - name: matrix name (user-visible) - rows: number of rows - cols: number of columns - strict: if False, skip mlen-alignment checks (for small scratch matrices) - - Returns: - VRAMMatrixVar proxy object - """ - display_name = name - internal_name = self._scoped_name(name) - super().allocate_vram_matrix(name=internal_name, rows=rows, cols=cols, strict=strict) - - var = VRAMMatrixVar(self, internal_name, (rows, cols), display_name=display_name) - self._tensors[internal_name] = var - return var - - def alloc_at(self, name: str, rows: int, cols: int, vram_addr: int) -> VRAMMatrixVar: - """Allocate a VRAM matrix view at a specific address. - - Used to create views into existing VRAM matrices (e.g., per-head - slices of a multi-head Q projection output). Does NOT bump the - VRAM allocator -- the caller is responsible for ensuring the region - is valid. - - Args: - name: matrix name (user-visible) - rows: number of rows - cols: number of columns - vram_addr: absolute VRAM address for this view - - Returns: - VRAMMatrixVar proxy object - """ - display_name = name - internal_name = self._scoped_name(name) - self._compiler.add_vram_object( - name=internal_name, - shape=(rows, cols), - vram_addr=vram_addr, - allocate_if_none=False, - ) - isa_code = f"; VRAM View {name}: ({rows}, {cols}) at VRAM[{vram_addr}]\n" - self._compiler.generated_code += isa_code - var = VRAMMatrixVar(self, internal_name, (rows, cols), display_name=display_name) - self._tensors[internal_name] = var - return var - - def free_tensor(self, tensor_var: TensorVar): - """ - Free a tensor in VRAM, reclaiming space for subsequent allocations. - - Freed space can be reused by new alloc() or other operations. - """ - if not isinstance(tensor_var, VRAMMatrixVar): - raise TypeError(f"Can only free VRAMMatrixVar, got {type(tensor_var)}") - - super().free_vram_object(tensor_var.name, strict=False) - # Keep sub-matrix registration state consistent after free. - self._registered_vram_sub_matrices[tensor_var.name] = False - - def free_input(self, input_var: InputVar): - """ - Free an InputVar bookkeeping and recycle its HBM range for future auto-allocation. - - Notes: - - This only affects PlenaCompiler's address management state. - - If a freed input is referenced again later, caller is responsible for correctness. - """ - if not isinstance(input_var, InputVar): - raise TypeError(f"Can only free InputVar, got {type(input_var)}") - - super().free_hbm_object(input_var.name, strict=False) - self._registered_hbm_sub_matrices[input_var.name] = False - self._recycle_hbm(input_var.hbm_addr, input_var.hbm_size) - self._inputs.pop(input_var.name, None) - - def free_fp_var(self, fp_var: FPVar): - """ - Free an FPVar and return its block to FPRAM free pool. - """ - if not isinstance(fp_var, FPVar): - raise TypeError(f"Can only free FPVar, got {type(fp_var)}") - self.free_fpram(fp_var.name, strict=True) - - # ======================================================================== - # Normalization Operations - # ======================================================================== - - def norm( - self, - tensor_var: TensorVar, - mode: str = "rms", - eps_offset: int = 1, - reci_hid_offset: int = 2, - vlen: int | None = None, - scratchpad_vram_addr: int | None = None, - ) -> TensorVar: - """ - Normalize tensor in-place. - - Args: - tensor_var: tensor to normalize (must have VRAM backing, e.g., VRAMMatrixVar) - mode: "rms" or "layer" - eps_offset: FPRAM address of epsilon - reci_hid_offset: FPRAM address of 1/hidden_dim - vlen: vector length (default: program mlen) - scratchpad_vram_addr: optional scratchpad VRAM address - - Returns: - The same tensor_var (in-place operation) - """ - if not isinstance(tensor_var, VRAMMatrixVar): - raise TypeError(f"norm requires VRAMMatrixVar, got {type(tensor_var)}") - - super().normalize( - tensor_name=tensor_var.name, - mode=mode, - eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - vlen=vlen, - scratchpad_vram_addr=scratchpad_vram_addr, - ) - return tensor_var - - def rms_norm( - self, - tensor_var: TensorVar, - eps_offset: int = 1, - reci_hid_offset: int = 2, - vlen: int | None = None, - scratchpad_vram_addr: int | None = None, - ) -> TensorVar: - """RMS normalization (in-place).""" - return self.norm( - tensor_var=tensor_var, - mode="rms", - eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - vlen=vlen, - scratchpad_vram_addr=scratchpad_vram_addr, - ) - - def layer_norm( - self, - tensor_var: TensorVar, - eps_offset: int = 1, - reci_hid_offset: int = 2, - vlen: int | None = None, - scratchpad_vram_addr: int | None = None, - ) -> TensorVar: - """Layer normalization (in-place).""" - return self.norm( - tensor_var=tensor_var, - mode="layer", - eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - vlen=vlen, - scratchpad_vram_addr=scratchpad_vram_addr, - ) - - # ======================================================================== - # FP Variable (FPRAM) - # ======================================================================== - - def allocate_fpram( - self, - internal_name: str, - size: int = 1, - display_name: str | None = None, - ) -> FPVar: - """ - Allocate FPRAM with explicit internal name and return FPVar proxy. - """ - if size <= 0: - raise ValueError(f"FPRAM allocation size must be positive, got {size}") - - address = super().allocate_fpram(internal_name, size) - var = FPVar( - self, - internal_name, - address, - size, - display_name=display_name if display_name is not None else internal_name, - ) - self._fp_vars[internal_name] = var - return var - - def free_fpram(self, internal_name: str, strict: bool = True): - """ - Free FPRAM allocation by internal name. - """ - super().free_fpram(internal_name, strict=strict) - self._fp_vars.pop(internal_name, None) - - def fp_var(self, name: str, size: int = 1) -> FPVar: - """ - Declare an FP variable in FPRAM. - - Allocates a contiguous region in FPRAM and returns an FPVar proxy. - Within function scope, names are automatically prefixed. - - Args: - name: variable name - size: number of f16 elements to allocate (default 1) - - Returns: - FPVar proxy object (use .address for ISA generation) - - Example: - scale = prog.fp_var("scale") # 1 element - m_old = prog.fp_var("m_old", size=64) # 64 elements - prog.compiler # access compiler for ISA if needed - """ - display_name = name - internal_name = self._scoped_name(name) - - return self.allocate_fpram( - internal_name=internal_name, - size=size, - display_name=display_name, - ) - - def save_fpram_state(self) -> int: - """Save FPRAM allocator snapshot""" - return super().save_fpram_state() - - def restore_fpram_state(self, snapshot: int): - """Restore FPRAM allocator snapshot""" - super().restore_fpram_state(snapshot) - # Remove FPVar proxies that are no longer allocated in allocator. - allocations = set(super().list_fpram_allocations()) - to_remove = [n for n in self._fp_vars if n not in allocations] - for n in to_remove: - del self._fp_vars[n] - - # ======================================================================== - # FPRAM Tile Operations - # ======================================================================== - - def _resolve_fpram_addr(self, addr_or_var: int | FPVar, offset: int = 0) -> int: - if isinstance(addr_or_var, FPVar): - if offset < 0 or offset >= addr_or_var.size: - raise ValueError( - f"FPVar offset out of range: offset={offset}, size={addr_or_var.size}, var={addr_or_var.name}" - ) - return addr_or_var.address + offset - if not isinstance(addr_or_var, int): - raise TypeError(f"Expected int or FPVar, got {type(addr_or_var)}") - return addr_or_var + offset - - def _resolve_rows( - self, - row_idx: int | None = None, - rows: list[int] | None = None, - ) -> list[int]: - if row_idx is not None and rows is not None: - raise ValueError("Provide either row_idx or rows, not both") - if rows is not None: - return rows - if row_idx is not None: - return [row_idx] - return list(range(self.mlen)) - - def tile_row_max( - self, - target_fpram_addr: int | FPVar, - source: VRAMMatrixVar, - row_idx: int | None = None, - rows: list[int] | None = None, - target_offset: int = 0, - target_base_offset: int = 0, - ): - """ - Tile Row Max: reduce a single row to max, store to FPRAM address. - - Args: - target_fpram_addr: FPRAM address or FPVar to write result - source: VRAM tile (mlen x mlen) - row_idx: single row index (legacy path) - rows: multiple row indices - target_offset: element offset when target_fpram_addr is FPVar - target_base_offset: base offset for multi-row writes (contiguous) - - Example: - m = prog.fp_var("m", size=1) - S = prog.alloc("S", 64, 64) - for row in range(64): - prog.tile_row_max(m, S, rows=list(range(64))) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [target_offset] - else: - offsets = [target_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(target_fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_max( - source_matrix=source.name, - row_map=row_map, - ) - - def tile_row_sum( - self, - target_fpram_addr: int | FPVar, - source: VRAMMatrixVar, - row_idx: int | None = None, - rows: list[int] | None = None, - target_offset: int = 0, - target_base_offset: int = 0, - ): - """ - Tile Row Sum: reduce a single row to sum, store to FPRAM address. - - Args: - target_fpram_addr: FPRAM address or FPVar to write result - source: VRAM tile (mlen x mlen) - row_idx: single row index (legacy path) - rows: multiple row indices - target_offset: element offset when target_fpram_addr is FPVar - target_base_offset: base offset for multi-row writes (contiguous) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [target_offset] - else: - offsets = [target_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(target_fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_sum(source.name, row_map) - - def tile_row_exp( - self, - source: VRAMMatrixVar, - row_idx: int | None = None, - rows: list[int] | None = None, - ): - """ - Tile Row Exp: in-place exp on specified rows. - - For each row i: source[i, :] = exp(source[i, :]) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - super().tile_row_exp(source.name, resolved_rows) - - def tile_row_reci( - self, - source: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Reciprocal: in-place 1/x on specified rows. - - For each row i: source[i, :] = 1.0 / source[i, :] - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_reci(source.name, rows) - - def tile_row_sub_fp( - self, - source: VRAMMatrixVar, - fpram_addr: int | FPVar, - row_idx: int | None = None, - rows: list[int] | None = None, - fpram_offset: int = 0, - fpram_base_offset: int = 0, - ): - """ - Tile Row Sub FP: subtract FPRAM scalar from a single row. - - For row i: source[i, :] = source[i, :] - FPRAM[fpram_addr] - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [fpram_offset] - else: - offsets = [fpram_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_sub_fp(source.name, row_map) - - def tile_row_mul_fp( - self, - source: VRAMMatrixVar, - fpram_addr: int | FPVar, - row_idx: int | None = None, - rows: list[int] | None = None, - fpram_offset: int = 0, - fpram_base_offset: int = 0, - ): - """ - Tile Row Mul FP: multiply a single row by FPRAM scalar. - - For row i: source[i, :] = source[i, :] * FPRAM[fpram_addr] - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [fpram_offset] - else: - offsets = [fpram_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_mul_fp(source.name, row_map) - - def tile_row_add_fp( - self, - source: VRAMMatrixVar, - fp_var: FPVar, - rows: list[int] | None = None, - ): - """ - Tile Row Add FP: add FPRAM scalar to each specified row. - - For each row i: source[i, :] = source[i, :] + fp_var[i] - """ - if rows is None: - rows = list(range(self.mlen)) - row_map = [(r, fp_var[r]) for r in rows] - super().tile_row_add_fp(source.name, row_map) - - def tile_row_add( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Add: dst[i, :] += src[i, :] for specified rows. - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_add(dst.name, src.name, rows) - - def tile_row_sub( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Sub: dst[i, :] -= src[i, :] for specified rows. - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_sub(dst.name, src.name, rows) - - def tile_row_mul( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Mul: dst[i, :] *= src[i, :] for specified rows. - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_mul(dst.name, src.name, rows) - - def fpvar_reci( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Reciprocal: compute 1/x for FPRAM scalar array. - - For each element i: dst[i] = 1.0 / src[i] - - Args: - src: source FPVar - dst: destination FPVar - count: number of elements (default: min(src.size, dst.size)) - - Example: - l = prog.fp_var("l", size=64) - inv_l = prog.fp_var("inv_l", size=64) - prog.fpvar_reci(l, inv_l) # inv_l = 1/l - """ - if count is None: - count = min(src.size, dst.size) - if count > src.size or count > dst.size: - raise ValueError(f"count={count} exceeds FPVar size: src.size={src.size}, dst.size={dst.size}") - super().fpram_reci(src.name, dst.name, count) - - def fpvar_max( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Max: element-wise max for FPRAM scalar arrays. - - For each element i: dst[i] = max(src1[i], src2[i]) - - Example: - m_new = prog.fp_var("m_new", size=64) - prog.fpvar_max(m_old, row_max, m_new) # m_new = max(m_old, row_max) - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_max(src1.name, src2.name, dst.name, count) - - def fpvar_sub( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Subtract: element-wise subtraction for FPRAM scalar arrays. - - For each element i: dst[i] = src1[i] - src2[i] - - Example: - diff = prog.fp_var("diff", size=64) - prog.fpvar_sub(m_old, m_new, diff) # diff = m_old - m_new - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_sub(src1.name, src2.name, dst.name, count) - - def fpvar_exp( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Exp: element-wise exp for FPRAM scalar array. - - For each element i: dst[i] = exp(src[i]) - - Example: - m_res = prog.fp_var("m_res", size=64) - prog.fpvar_exp(diff, m_res) # m_res = exp(diff) - """ - if count is None: - count = min(src.size, dst.size) - super().fpram_exp(src.name, dst.name, count) - - def fpvar_mul( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Multiply: element-wise multiplication for FPRAM scalar arrays. - - For each element i: dst[i] = src1[i] * src2[i] - - Example: - result = prog.fp_var("result", size=64) - prog.fpvar_mul(l_old, m_res, result) # result = l_old * m_res - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_mul(src1.name, src2.name, dst.name, count) - - def fpvar_add( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Add: element-wise addition for FPRAM scalar arrays. - - For each element i: dst[i] = src1[i] + src2[i] - - Example: - l_new = prog.fp_var("l_new", size=64) - prog.fpvar_add(l_old, sum_p, l_new) # l_new = l_old + sum_p - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_add(src1.name, src2.name, dst.name, count) - - def fpvar_copy( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Copy: copy FPRAM scalar array. - - For each element i: dst[i] = src[i] - - Example: - m_old_saved = prog.fp_var("m_old_saved", size=64) - prog.fpvar_copy(m_old, m_old_saved) # backup m_old - """ - if count is None: - count = min(src.size, dst.size) - super().fpram_copy(src.name, dst.name, count) - - def fpvar_sum( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Sum: reduction sum of src into dst[0] (via compiler FPRAM op). - """ - if count is None: - count = src.size - super().fpram_sum(src.name, dst.name, count) - - def fpvar_shift( - self, - src: FPVar, - dst: FPVar, - shift: int, - count: int | None = None, - fill: FPVar | None = None, - ): - """ - FPVar Shift: shift src into dst, filling out-of-range slots with fill (default FPRAM zero). - """ - if count is None: - count = min(src.size, dst.size) - fill_name = None if fill is None else fill.name - super().fpram_shift( - src_name=src.name, - dst_name=dst.name, - shift=shift, - count=count, - fill_fpram_name=fill_name, - ) - - def tile_row_mul_fp_broadcast( - self, - source: VRAMMatrixVar, - fpram_scalar_addr: int | FPVar, - row_idx: int | None = None, - rows: list[int] | None = None, - fpram_offset: int = 0, - ): - """ - Tile Row Mul FP Broadcast: multiply a single row by a FPRAM scalar. - - For row i: source[i, :] = source[i, :] * FPRAM[fpram_scalar_addr] - - Args: - source: VRAM tile (mlen x mlen) - fpram_scalar_addr: FPRAM address or FPVar of the scalar - row_idx: single row index (legacy path) - rows: multiple row indices - - Example: - scale_fp = prog.fp_var("scale", size=1) - for row in range(64): - prog.tile_row_mul_fp_broadcast(S, scale_fp, rows=list(range(64))) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - scalar_addr = self._resolve_fpram_addr(fpram_scalar_addr, fpram_offset) - super().tile_row_mul_fp_broadcast(source.name, scalar_addr, resolved_rows) - - def fpvar_fill_from_fpram( - self, - dst: FPVar, - src_fpram_addr: int, - count: int | None = None, - ): - """ - FPVar Fill from FPRAM: fill all elements with a value from FPRAM. - - For each element i: dst[i] = FPRAM[src_fpram_addr] - - Args: - dst: destination FPVar - src_fpram_addr: source FPRAM address (e.g., 0 for 0.0, 2 for -inf) - count: number of elements (default: dst.size) - - Example: - m_old = prog.fp_var("m_old", size=64) - prog.fpvar_fill_from_fpram(m_old, 2) # fill with -inf from address 2 - """ - if count is None: - count = dst.size - super().fpram_fill_from_fpram(dst.name, src_fpram_addr, count) - - def vram_fill_zero( - self, - matrix: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - VRAM Fill Zero: fill specified rows with 0. - - Args: - matrix: VRAM matrix - rows: which rows to fill (default: all rows) - - Example: - O = prog.alloc("O", 128, 128) - prog.vram_fill_zero(O, rows=range(64, 128)) # zero out second half - """ - if rows is None: - rows = list(range(matrix.shape[0])) - else: - rows = list(rows) - - total_rows, cols = matrix.shape - if any(row < 0 or row >= total_rows for row in rows): - raise ValueError(f"vram_fill_zero rows out of bounds for {matrix.name}: shape={matrix.shape}, rows={rows}") - - # VRAM matrices are column-block-major. The low-level tile helper zeros - # one 64-column tile, so walk every column block for wide matrices. - num_col_blocks = (cols + self.mlen - 1) // self.mlen - for col_block in range(num_col_blocks): - super().vram_fill_zero(matrix.name, rows, tile_col_idx=col_block) - - def _ensure_hbm_sub_matrix_registered(self, input_var: InputVar): - """Ensure an HBM input is registered in compiler sub-matrix manager.""" - if ( - input_var.name in self._registered_hbm_sub_matrices - and self._registered_hbm_sub_matrices[input_var.name] is True - ): - return - h, w = input_var.shape - super().ensure_hbm_sub_matrix( - name=input_var.name, - hbm_addr=input_var.hbm_addr, - shape=(h, w), - real_data_ratio=self.real_data_ratio, - ) - self._registered_hbm_sub_matrices[input_var.name] = True - - def _ensure_vram_sub_matrix_registered(self, matrix_var: VRAMMatrixVar): - """Ensure a VRAM matrix is registered in compiler sub-matrix manager.""" - if ( - matrix_var.name in self._registered_vram_sub_matrices - and self._registered_vram_sub_matrices[matrix_var.name] is True - ): - return - super().ensure_vram_matrix_layout( - name=matrix_var.name, - shape=matrix_var.shape, - ) - self._registered_vram_sub_matrices[matrix_var.name] = True - - def vram_sub_projection_to( - self, - vram_matrix: VRAMMatrixVar, - vram_row_idx: int, - mram_input: InputVar, - mram_col_idx: int, - target: VRAMMatrixVar, - target_row_idx: int, - target_col_idx: int, - auto_reset_mram: bool = True, - k_block_start: int = 0, - k_block_count: int | None = None, - ): - """ - target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[:][mram_col_idx] - Supports K-split: k_block_start/k_block_count select a subset of K tiles. - """ - if not isinstance(vram_matrix, VRAMMatrixVar): - raise TypeError(f"vram_matrix must be VRAMMatrixVar, got {type(vram_matrix)}") - if not isinstance(mram_input, InputVar): - raise TypeError(f"mram_input must be InputVar, got {type(mram_input)}") - if not isinstance(target, VRAMMatrixVar): - raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") - - self._ensure_vram_sub_matrix_registered(vram_matrix) - self._ensure_hbm_sub_matrix_registered(mram_input) - if auto_reset_mram: - super().reset_mram() - super().load_sub_matrix_col( - name=mram_input.name, - col_idx=mram_col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - super().vram_sub_projection_to( - vram_mat_name=vram_matrix.name, - vram_row_idx=vram_row_idx, - mram_mat_name=mram_input.name, - mram_col_idx=mram_col_idx, - target_matrix=target.name, - target_row_idx=target_row_idx, - target_col_idx=target_col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - - def vram_sub_projection_T_to( - self, - vram_matrix: VRAMMatrixVar, - vram_row_idx: int, - mram_input: InputVar, - mram_row_idx: int, - target: VRAMMatrixVar, - target_row_idx: int, - target_col_idx: int, - auto_reset_mram: bool = True, - ): - """ - target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[mram_row_idx][:]^T - """ - if not isinstance(vram_matrix, VRAMMatrixVar): - raise TypeError(f"vram_matrix must be VRAMMatrixVar, got {type(vram_matrix)}") - if not isinstance(mram_input, InputVar): - raise TypeError(f"mram_input must be InputVar, got {type(mram_input)}") - if not isinstance(target, VRAMMatrixVar): - raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") - - self._ensure_vram_sub_matrix_registered(vram_matrix) - self._ensure_hbm_sub_matrix_registered(mram_input) - if auto_reset_mram: - super().reset_mram() - super().load_sub_matrix_row(name=mram_input.name, row_idx=mram_row_idx) - super().vram_sub_projection_T_to( - vram_mat_name=vram_matrix.name, - vram_row_idx=vram_row_idx, - mram_mat_name=mram_input.name, - mram_row_idx=mram_row_idx, - target_matrix=target.name, - target_row_idx=target_row_idx, - target_col_idx=target_col_idx, - ) - - # ======================================================================== - # RoPE (1D Positional Encoding) - # ======================================================================== - - def rope( - self, - x_var: VRAMMatrixVar, - x_rot_var: VRAMMatrixVar, - cos_var: VRAMMatrixVar, - sin_var: VRAMMatrixVar, - ) -> VRAMMatrixVar: - """Apply Rotary Position Embedding in-place: x = x * cos + rotate_half(x) * sin - - x_rot_var must already be in VRAM as rotate_half(x), preloaded by caller. - Returns x_var (modified in-place). - """ - super().rope( - x_name=x_var.name, - x_rot_name=x_rot_var.name, - cos_name=cos_var.name, - sin_name=sin_var.name, - ) - return x_var - - # ======================================================================== - # VRAM Matrix Addition - # ======================================================================== - - def vram_add( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - dst_row_offset: int = 0, - src_row_offset: int = 0, - num_rows: int | None = None, - ): - """VRAM matrix add: dst[row_offset:] += src""" - super().vram_matrix_add( - dst_matrix=dst.name, - src_matrix=src.name, - dst_row_offset=dst_row_offset, - src_row_offset=src_row_offset, - num_rows=num_rows, - ) - - def vram_block_add_to( - self, - src1: TensorVar, - src1_row_idx: int, - src1_col_idx: int, - src2: TensorVar, - src2_row_idx: int, - src2_col_idx: int, - target: TensorVar, - target_row_idx: int, - target_col_idx: int, - ): - """ - mlen x mlen block add: - target[target_row_idx][target_col_idx] = - src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] - - Supports writing back to the same matrix/block (in-place overwrite). - """ - allowed = (VRAMMatrixVar,) - if not isinstance(src1, allowed): - raise TypeError(f"src1 must be VRAMMatrixVar, got {type(src1)}") - if not isinstance(src2, allowed): - raise TypeError(f"src2 must be VRAMMatrixVar, got {type(src2)}") - if not isinstance(target, allowed): - raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") - - super().vram_block_add_to( - src1_matrix=src1.name, - src1_row_idx=src1_row_idx, - src1_col_idx=src1_col_idx, - src2_matrix=src2.name, - src2_row_idx=src2_row_idx, - src2_col_idx=src2_col_idx, - target_matrix=target.name, - target_row_idx=target_row_idx, - target_col_idx=target_col_idx, - ) - - # ======================================================================== - # Flash Attention Operations - # ======================================================================== - - def init_online_softmax(self, q_idx: int, o_matrix: VRAMMatrixVar): - """Initialize Online Softmax state: m=-inf, l=0, O_row=0""" - o_info = super().get_tensor_info(o_matrix.name) - seq_len, head_dim = o_info.shape - - super().init_online_softmax( - q_idx=q_idx, - o_matrix=o_matrix.name, - seq_len=seq_len, - head_dim=head_dim, - ) - - def online_softmax_block(self, s_block: VRAMMatrixVar, scale: float): - """Perform Online Softmax on S block""" - super().online_softmax_block( - s_block_matrix=s_block.name, - scale=scale, - ) - - def compute_pv( - self, - s_block: VRAMMatrixVar, - v_input: InputVar, - k_idx: int, - pv_matrix: VRAMMatrixVar, - head_dim: int, - ): - """Compute PV = P @ V[k_idx] where P is stored in s_block.""" - if not isinstance(s_block, VRAMMatrixVar): - raise TypeError(f"s_block must be VRAMMatrixVar, got {type(s_block)}") - if not isinstance(v_input, InputVar): - raise TypeError(f"v_input must be InputVar, got {type(v_input)}") - if not isinstance(pv_matrix, VRAMMatrixVar): - raise TypeError(f"pv_matrix must be VRAMMatrixVar, got {type(pv_matrix)}") - - self._ensure_hbm_sub_matrix_registered(v_input) - super().compute_pv( - s_block_matrix=s_block.name, - v_sub_matrix=v_input.name, - k_idx=k_idx, - pv_matrix=pv_matrix.name, - head_dim=head_dim, - ) - - def scale_o_row(self, o_matrix: VRAMMatrixVar, q_idx: int): - """Scale current row block of O by m_res""" - o_info = super().get_tensor_info(o_matrix.name) - seq_len, head_dim = o_info.shape - - super().scale_o_row( - o_matrix=o_matrix.name, - q_idx=q_idx, - seq_len=seq_len, - head_dim=head_dim, - ) - - def final_scale_o(self, q_idx: int, o_matrix: VRAMMatrixVar): - """Final scaling: O[q_idx] = O[q_idx] / l""" - o_info = super().get_tensor_info(o_matrix.name) - seq_len, head_dim = o_info.shape - - super().final_scale_o( - q_idx=q_idx, - o_matrix=o_matrix.name, - seq_len=seq_len, - head_dim=head_dim, - ) - - # ======================================================================== - # Function Decorator - # ======================================================================== - - def function(self, func: Callable) -> Callable: - """ - Decorator: Define reusable functions. - - Each invocation generates fresh ISA code (eager evaluation). - Internally allocated tensors are auto-freed on exit unless returned. - - Scoping: intermediate tensors get a call-index prefix to avoid name - collisions across repeated calls (e.g., "linear_0/proj_1", "linear_1/proj_1"). - Nested functions compose prefixes: "two_layer_0/linear_0/proj_1". - """ - func_name = func.__name__ - - @wraps(func) - def wrapper(*args, **kwargs): - call_idx = self._function_call_counters.get(func_name, 0) - self._function_call_counters[func_name] = call_idx + 1 - - scope = f"{func_name}_{call_idx}/" - self._scope_stack.append(scope) - - self.generated_code += f"; === Enter {func_name} (call #{call_idx}) ===\n" - - # Snapshot: record existing tensors before function execution - tensors_before = set(self._tensors.keys()) - inputs_before = set(self._inputs.keys()) - fp_vars_before = set(self._fp_vars.keys()) - - try: - result = func(*args, **kwargs) - - # Auto-free: free locally allocated tensors that are not returned - return_names = set() - return_fp_names = set() - if isinstance(result, TensorVar): - return_names.add(result.internal_name) - elif isinstance(result, FPVar): - return_fp_names.add(result.internal_name) - elif isinstance(result, (tuple, list)): - for r in result: - if isinstance(r, TensorVar): - return_names.add(r.internal_name) - elif isinstance(r, FPVar): - return_fp_names.add(r.internal_name) - - for name in set(self._tensors.keys()) - tensors_before: - if name not in return_names: - tensor = self._tensors[name] - if isinstance(tensor, VRAMMatrixVar): - self.free_tensor(tensor) - self._registered_vram_sub_matrices[tensor.name] = False - - for name in set(self._inputs.keys()) - inputs_before: - if name not in return_names: - self.free_input(self._inputs[name]) - - local_fp_names = sorted( - set(self._fp_vars.keys()) - fp_vars_before, - key=lambda n: self._fp_vars[n].address, - reverse=True, - ) - for name in local_fp_names: - if name in return_fp_names: - continue - fp_var = self._fp_vars.get(name) - if fp_var is not None: - self.free_fp_var(fp_var) - finally: - self._scope_stack.pop() - self.generated_code += f"; === Exit {func_name} (call #{call_idx}) ===\n" - - return result - - self._functions[func_name] = wrapper - wrapper._plena_function = True - wrapper._plena_name = func_name - return wrapper - - # ======================================================================== - # Result Marking - # ======================================================================== - - def result(self, tensor_var: TensorVar): - """Mark output result tensor.""" - self._result_tensor = tensor_var - - # ======================================================================== - # Compilation - # ======================================================================== - - def compile(self) -> str: - """Get generated ISA code string.""" - return super().get_code() - - def print_symbol_table(self): - """Print symbol table""" - super().print_symbol_table() - - def get_symbol_table(self): - """Get symbol table""" - return super().get_symbol_table() - - # ======================================================================== - # Operator Dispatch (internal) - # ======================================================================== - - def _dispatch_matmul(self, left: TensorVar, right) -> TensorVar: - raise TypeError("@ operator is no longer supported in PlenaCompiler. Use explicit program APIs instead.") - - # ======================================================================== - # Utility Methods - # ======================================================================== - - def _scoped_name(self, name: str) -> str: - """ - Apply current scope prefix to a name. - - - Top-level alloc("temp"): -> "temp" - - Inside linear call 0, alloc("temp"): -> "linear_0/temp" - - Nested two_layer->linear, alloc("temp"): -> "two_layer_0/linear_0/temp" - """ - if not self._scope_stack: - return name - scope_prefix = "".join(self._scope_stack) - return f"{scope_prefix}{name}" - - def _allocate_hbm(self, hbm_size: int) -> int: - """Allocate HBM range, preferring previously freed blocks.""" - best_idx = None - best_waste = None - for i, (addr, size) in enumerate(self._hbm_free_blocks): - if size >= hbm_size: - waste = size - hbm_size - if best_waste is None or waste < best_waste: - best_idx = i - best_waste = waste - - if best_idx is not None: - addr, block_size = self._hbm_free_blocks.pop(best_idx) - # Return excess fragment to free list - excess = block_size - hbm_size - if excess > 0: - self._hbm_free_blocks.append((addr + hbm_size, excess)) - return addr - - addr = self._next_hbm_addr - m = self.mlen - self._next_hbm_addr = ((addr + hbm_size + m - 1) // m) * m - return addr - - def _recycle_hbm(self, hbm_addr: int, hbm_size: int): - """Recycle an HBM range for future auto-allocation.""" - if hbm_size <= 0: - return - self._hbm_free_blocks.append((hbm_addr, hbm_size)) - - def _auto_name(self, prefix: str = "t") -> str: - """ - Generate a unique scoped name. - - - Top-level: "__proj_1" - - linear call 0: "linear_0/__proj_1" - - nested: "two_layer_0/linear_0/__proj_1" - """ - self._auto_name_counter += 1 - scope_prefix = "".join(self._scope_stack) - return f"{scope_prefix}__{prefix}_{self._auto_name_counter}" - - def __repr__(self): - num_inputs = len(self._inputs) - num_tensors = len(self._tensors) - num_functions = len(self._functions) - code_len = len(super().get_code().splitlines()) - return ( - f"PlenaCompiler(mlen={self.mlen}, blen={self.blen}, " - f"inputs={num_inputs}, tensors={num_tensors}, " - f"functions={num_functions}, isa_lines={code_len})" - ) - - -class TensorKind(Enum): - """Identifies which memory the tensor lives in / which proxy backs it.""" - - HBM = "hbm" # legacy: InputVar - VRAM = "vram" # legacy: VRAMMatrixVar - FPRAM = "fpram" # legacy: FPVar - - -# Type alias: a "Tensor" is any of the legacy proxy objects. Callers can use -# this annotation without worrying about which specific backing allocator -# a given tensor lives in. -Tensor = TensorVar | InputVar | VRAMMatrixVar | FPVar - - -def tensor_kind(tensor: Tensor) -> TensorKind: - """Return the backing storage of a tensor proxy.""" - if isinstance(tensor, FPVar): - return TensorKind.FPRAM - if isinstance(tensor, VRAMMatrixVar): - return TensorKind.VRAM - if isinstance(tensor, InputVar): - return TensorKind.HBM - if isinstance(tensor, TensorVar): - # Generic TensorVar without a specific backing — classify by ``kind``. - kind = getattr(tensor, "kind", "") - if kind in ("vram_matrix", "batch", "matrix"): - return TensorKind.VRAM - if kind == "input": - return TensorKind.HBM - raise TypeError(f"Unknown tensor type: {type(tensor).__name__}") +The implementation is split under ``compiler.aten.plena``. This module keeps +legacy imports such as ``from compiler.aten.plena_compiler import PlenaCompiler`` +working while re-exporting the public compiler, memory, register, and tensor +proxy types. +""" +from __future__ import annotations -# ============================================================================= -# Unified dataclass aliases for the three overlapping "Info" and three -# overlapping "Layout" types in tile_compiler.py. -# ============================================================================= +from compiler.aten.plena.compiler import PlenaCompiler +from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN +from compiler.aten.plena.isa_compiler import DeveloperCompiler, IsaCompiler +from compiler.aten.plena.memory import ( + FPRAMAllocator, + FPRAMObjectLayout, + MRAMAllocator, + MatrixBlockLayout, + MemoryBlock, + MemoryObjectInfo, + SubMatrixInfo, + VRAMAllocator, + VRAMMatrixBlockLayout, + VRAMSubMatrixInfo, + VirtualMemoryManager, +) +from compiler.aten.plena.registers import RegisterAllocator +from compiler.aten.plena.tile_compiler import TileCompiler +from compiler.aten.plena.vars import FPVar, InputVar, Tensor, TensorKind, TensorVar, VRAMMatrixVar, tensor_kind -# ``TensorInfo`` is the union of the three Info dataclasses. Callers can -# import ``TensorInfo`` and use it as an annotation; at runtime the object -# will be whichever specific Info subtype ``TileCompiler`` constructed. +# ``TensorInfo`` is the union of the three Info dataclasses. Callers can import +# it as an annotation; at runtime the object is whichever specific Info subtype +# ``TileCompiler`` constructed. TensorInfo = MemoryObjectInfo | SubMatrixInfo | VRAMSubMatrixInfo # ``TileLayout`` is the union of the three Layout dataclasses. TileLayout = MatrixBlockLayout | VRAMMatrixBlockLayout | FPRAMObjectLayout -# ============================================================================= -# Public exports. -# ============================================================================= - - __all__ = [ + "BLEN", + "IMM2_BOUND", + "MLEN", "DeveloperCompiler", "FPRAMAllocator", "FPRAMObjectLayout", "FPVar", "InputVar", + "IsaCompiler", "MRAMAllocator", "MatrixBlockLayout", "MemoryBlock", diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index fe372d9..aba35a8 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -19,6 +19,7 @@ import json import math +import re from pathlib import Path import torch @@ -27,11 +28,7 @@ from compiler.aten.plena_compiler import PlenaCompiler from compiler.aten.ops.registry import OpRegistry, Backend import compiler.aten.ops as ops -from transactional_emulator.testbench.model_layer_test_builder import ( - quantize_to_mxfp, - _make_rope_tables, -) -import re +from quant.quantizer.hardware_quantizer.mxfp import _mx_fp_quantize_hardware _IMM2_BOUND = 1 << 18 # S_ADDI_INT max immediate @@ -70,6 +67,33 @@ def _fix_large_immediates(isa_code: str) -> str: _HW_MAX_K_TILES = 4 +def quantize_to_mxfp(tensor: torch.Tensor) -> torch.Tensor: + """Quantize tensor to MXFP8 matching HBM hardware format; return dequantized result.""" + orig_shape = tensor.shape + tensor_2d = tensor.float().reshape(-1, tensor.shape[-1]) + bm_x, _, _, _ = _mx_fp_quantize_hardware( + tensor_2d, + width=8, + exponent_width=4, + exponent_bias_width=8, + block_size=[1, 8], + ) + return bm_x.reshape(orig_shape) + + +def _make_rope_tables(seq_len: int, head_dim: int, theta: float = 10000.0): + """Compute RoPE cos/sin tables, shape (seq_len, head_dim).""" + half = head_dim // 2 + freqs = 1.0 / (theta ** (torch.arange(0, half).float() / half)) + positions = torch.arange(seq_len).float() + angles = torch.outer(positions, freqs) + cos_half = torch.cos(angles) + sin_half = torch.sin(angles) + cos = torch.cat([cos_half, cos_half], dim=-1) + sin = torch.cat([sin_half, sin_half], dim=-1) + return cos, sin + + def _ksplit_matmul(A, B, mlen=64, max_k_tiles=_HW_MAX_K_TILES, to_inter=None, from_inter=None): """Matrix multiply matching hardware K-split BF16 precision. @@ -82,9 +106,12 @@ def _ksplit_matmul(A, B, mlen=64, max_k_tiles=_HW_MAX_K_TILES, to_inter=None, fr this is equivalent to a single matmul with BF16 cast. """ if to_inter is None: - to_inter = lambda x: x.to(torch.bfloat16) + def to_inter(x): + return x.to(torch.bfloat16) + if from_inter is None: - from_inter = lambda x: x.float() + def from_inter(x): + return x.float() k_total = A.shape[1] num_k_tiles = math.ceil(k_total / mlen) @@ -881,7 +908,7 @@ def compile_hf_model( li = layer_inputs[i] # Layer progress marker (visible in non-quiet emulator output) - prog._compiler.generated_code += f"; === LAYER {i}/{n_layers} START ===\n" + prog._compiler.emit_comment(f"=== LAYER {i}/{n_layers} START ===") # --- Attention block --- # Save residual: scratch = current (zero then add) @@ -983,7 +1010,7 @@ def compile_hf_model( prog.vram_add(current_after_attn, scratch) current = current_after_attn # carry forward - prog._compiler.generated_code += f"; === LAYER {i}/{n_layers} COMPLETE ===\n" + prog._compiler.emit_comment(f"=== LAYER {i}/{n_layers} COMPLETE ===") # Final norm prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) @@ -1127,7 +1154,6 @@ def compile_and_run( from transactional_emulator.tools.create_sim_env import create_sim_env from compiler.sim_env_utils.build_env import create_mem_for_sim from transactional_emulator.testbench.emulator_runner import ( - run_and_assert, compare_emulator_output, ) diff --git a/aten/tests/test_plena_compiler.py b/aten/tests/test_plena_compiler.py index f168504..43b7520 100644 --- a/aten/tests/test_plena_compiler.py +++ b/aten/tests/test_plena_compiler.py @@ -18,7 +18,71 @@ if _p not in sys.path: sys.path.insert(0, _p) -import torch +import torch # noqa: E402 + + +def test_isa_builder_renders_typed_instruction(): + """Typed ISA builder should render the existing asm syntax.""" + from compiler.aten.isa_builder import IsaBuilder, fp, gp + + asm = IsaBuilder() + asm.comment("typed builder smoke") + asm.instr("S_ADDI_INT", gp(3), gp(0), 4096) + asm.instr("S_ADD_FP", fp(1), fp(1), fp(2)) + + rendered = asm.render() + assert "; typed builder smoke" in rendered + assert "S_ADDI_INT gp3, gp0, 4096" in rendered + assert "S_ADD_FP f1, f1, f2" in rendered + print(" PASS test_isa_builder_renders_typed_instruction") + + +def test_isa_builder_legalizes_large_absolute_immediates(): + """Typed ISA builder should split large absolute S_ADDI_INT loads.""" + from compiler.aten.isa_builder import IsaBuilder, gp + + rendered = IsaBuilder().instr("S_ADDI_INT", gp(5), gp(0), 300000).render() + assert "S_LUI_INT gp5, 73" in rendered + assert "S_ADDI_INT gp5, gp5, 992" in rendered + assert "S_ADDI_INT gp5, gp0, 300000" not in rendered + print(" PASS test_isa_builder_legalizes_large_absolute_immediates") + + +def test_isa_builder_preserves_relative_large_immediates(): + """Typed ISA builder should not split relative S_ADDI_INT instructions.""" + from compiler.aten.isa_builder import IsaBuilder, gp + + rendered = IsaBuilder().instr("S_ADDI_INT", gp(5), gp(3), 300000).render() + assert "S_ADDI_INT gp5, gp3, 300000" in rendered + assert "S_LUI_INT" not in rendered + print(" PASS test_isa_builder_preserves_relative_large_immediates") + + +def test_fpvar_helper_uses_canonical_emit_path(): + """Converted FPVar helpers should still append and return asm text.""" + from compiler.aten.plena_compiler import PlenaCompiler + + prog = PlenaCompiler() + code = prog.fpvar_add_asm(src1_addr=0, src2_addr=4, dst_addr=8, count=2) + + assert code == prog.get_code() + assert "S_ADD_FP f1, f1, f2" in code + assert "C_LOOP_START" in code + print(" PASS test_fpvar_helper_uses_canonical_emit_path") + + +def test_hbm_load_helper_uses_typed_legalization(): + """Converted HBM load helpers should legalize typed large immediates.""" + from compiler.aten.plena_compiler import TileCompiler + + compiler = TileCompiler() + compiler.register_matrix("W", (512, 512), hbm_base_addr=0) + + code = compiler.load_sub_matrix_asm("W", row_idx=0, col_idx=0, mram_dest_addr=0) + assert "S_LUI_INT gp1, 64" in code + assert "S_ADDI_INT gp1, gp0, 262144" not in code + assert "H_PREFETCH_M gp3, gp1, a1, 1, 0" in code + print(" PASS test_hbm_load_helper_uses_typed_legalization") def test_vram_fill_zero_all_column_blocks(): @@ -167,9 +231,11 @@ def test_compile_hf_model_golden_vs_hf(): def test_native_compile_assembles(): """Native-dim ISA must assemble without overflow.""" - from compiler.aten.plena_frontend import compile_hf_model, _fix_large_immediates + import os + import tempfile + + from compiler.aten.plena_frontend import compile_hf_model from transformers import AutoModelForCausalLM - import tempfile, os model = AutoModelForCausalLM.from_pretrained( "AICrossSim/clm-60m", torch_dtype=torch.float32 @@ -213,6 +279,11 @@ def test_native_compile_assembles(): print("=" * 60) tests = [ + test_isa_builder_renders_typed_instruction, + test_isa_builder_legalizes_large_absolute_immediates, + test_isa_builder_preserves_relative_large_immediates, + test_fpvar_helper_uses_canonical_emit_path, + test_hbm_load_helper_uses_typed_legalization, test_vram_fill_zero_all_column_blocks, test_vram_add_all_column_blocks, test_alloc_at_correct_address, From 5d0447d06e2627a18bda2f183b9627dcf3157055 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 12:57:09 +0100 Subject: [PATCH 21/32] refactor: slim ATen compiler path --- aten/__init__.py | 28 +- aten/ops/plena/attention_ops.py | 23 +- aten/ops/plena/conv_ops.py | 14 +- aten/ops/plena/ffn_ops.py | 4 +- aten/ops/plena/linear_ops.py | 100 +- aten/ops/plena/softmax_ops.py | 15 +- aten/plena/__init__.py | 34 +- aten/plena/compiler.py | 234 +-- aten/plena/dsl_fp_tile_ops.py | 568 -------- aten/plena/isa_compiler.py | 159 +-- aten/plena/isa_emit.py | 12 - aten/plena/isa_fp_ops.py | 60 +- aten/plena/isa_matrix.py | 198 +-- aten/plena/isa_tile_rows.py | 100 +- aten/plena/memory.py | 313 +--- ...{dsl_attention.py => program_attention.py} | 6 +- aten/plena/program_fp_tile_ops.py | 347 +++++ ...sl_matrix_ops.py => program_matrix_ops.py} | 69 +- .../{dsl_tensors.py => program_tensors.py} | 10 +- aten/plena/tile_compiler.py | 615 ++------ aten/plena/vars.py | 43 +- aten/plena_compiler.py | 54 +- aten/plena_frontend.py | 1258 ++++++----------- aten/tests/test_plena_compiler.py | 6 +- aten/tests/test_quantization_ablation.py | 8 +- compiler/__init__.py | 12 + pyproject.toml | 1 + 27 files changed, 1310 insertions(+), 2981 deletions(-) delete mode 100644 aten/plena/dsl_fp_tile_ops.py rename aten/plena/{dsl_attention.py => program_attention.py} (95%) create mode 100644 aten/plena/program_fp_tile_ops.py rename aten/plena/{dsl_matrix_ops.py => program_matrix_ops.py} (75%) rename aten/plena/{dsl_tensors.py => program_tensors.py} (98%) create mode 100644 compiler/__init__.py diff --git a/aten/__init__.py b/aten/__init__.py index ba18cc9..a9068ef 100644 --- a/aten/__init__.py +++ b/aten/__init__.py @@ -1,6 +1,6 @@ """compiler.aten — ATen-style PLENA compiler path. -PlenaCompiler DSL + op backend registry. Pairs with compiler.generator +PlenaCompiler program builder + op backend registry. Pairs with compiler.generator for the two-path compiler (template vs. aten). """ @@ -19,30 +19,12 @@ gp, ) from compiler.aten.plena_compiler import ( # noqa: E402, F401 - DeveloperCompiler, - PlenaCompiler, - TileCompiler, + FPVar, + InputVar, IsaCompiler, - RegisterAllocator, + PlenaCompiler, TensorVar, - InputVar, + TileCompiler, VRAMMatrixVar, - FPVar, - TensorKind, - tensor_kind, - Tensor, - TensorInfo, - TileLayout, - MemoryBlock, - VirtualMemoryManager, - MRAMAllocator, - VRAMAllocator, - FPRAMAllocator, - SubMatrixInfo, - MatrixBlockLayout, - VRAMSubMatrixInfo, - VRAMMatrixBlockLayout, - MemoryObjectInfo, - FPRAMObjectLayout, ) from compiler.aten.ops.registry import OpRegistry, Backend # noqa: E402, F401 diff --git a/aten/ops/plena/attention_ops.py b/aten/ops/plena/attention_ops.py index 0798010..54678a1 100644 --- a/aten/ops/plena/attention_ops.py +++ b/aten/ops/plena/attention_ops.py @@ -126,7 +126,7 @@ def _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv): # Make sure K/V HBM subst-matrix registry knows about them prog._ensure_hbm_sub_matrix_registered(K) prog._ensure_hbm_sub_matrix_registered(V) - alloc = prog._compiler.register_allocator + alloc = prog.register_allocator k_addr, v_addr = alloc.allocate_addr(2) gp_for_preload = alloc.allocate_gp(2) setup = preload_addr_reg_asm( @@ -135,7 +135,7 @@ def _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv): addr_reg_val=[K.hbm_addr, V.hbm_addr], ) alloc.free_gp(gp_for_preload) - prog._compiler.emit(setup) + prog.emit(setup) # Allocate VRAM buffers mirroring main's layout. # S, PV each require mlen*mlen*ratio elements; O is s_q * (hq*h_qkv). @@ -143,15 +143,15 @@ def _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv): # vector_sram_base_address + computed offsets for S/PV/O, so they must be # allocated contiguously starting right after Q. allocate_vram_matrix # bump-allocates, which preserves that contiguity. - q_vram_base = prog._compiler.get_vram_addr(Q.name) + q_vram_base = prog.get_vram_addr(Q.name) s_name = prog._scoped_name("_gqa_S") pv_name = prog._scoped_name("_gqa_PV") o_name = prog._scoped_name("O") # Express sizes as (rows, cols) that multiply to the required counts. - prog._compiler.allocate_vram_matrix(name=s_name, rows=mlen * ratio, cols=mlen, strict=False) - prog._compiler.allocate_vram_matrix(name=pv_name, rows=mlen * ratio, cols=mlen, strict=False) - prog._compiler.allocate_vram_matrix(name=o_name, rows=s_q, cols=hq * h_qkv, strict=False) + prog.allocate_vram_matrix(name=s_name, rows=mlen * ratio, cols=mlen, strict=False) + prog.allocate_vram_matrix(name=pv_name, rows=mlen * ratio, cols=mlen, strict=False) + prog.allocate_vram_matrix(name=o_name, rows=s_q, cols=hq * h_qkv, strict=False) # Reserve FPRAM for multi-head softmax state. main's flash_attn_asm # assumes slots 0..2 hold constants (0.0, scale, -inf, preloaded by the @@ -159,13 +159,16 @@ def _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv): # 3 triples (m_old, m_res, l_old) per head, strided 3*br apart. # Reserve slots 0..2 first so our bump-allocated state lives at offset 3+. br = min(mlen, s_q) - fp_allocs = prog._compiler.fpram_allocator + fp_allocs = prog.fpram_allocator if "_gqa_fp_const_zero" not in fp_allocs.allocations: fp_allocs.allocate(name="_gqa_fp_const_zero", size=1) fp_allocs.allocate(name="_gqa_fp_const_scale", size=1) fp_allocs.allocate(name="_gqa_fp_const_neg_inf", size=1) fp_state_size = 3 * br * ratio - fp_start = prog._compiler.allocate_fpram(name="_gqa_softmax_state", size=fp_state_size) + fp_info = prog.add_fpram_object(name="_gqa_softmax_state", size=fp_state_size) + if fp_info.fpram_addr is None: + raise RuntimeError("Failed to allocate FPRAM for GQA softmax state") + fp_start = fp_info.fpram_addr # Call main's fused GQA template asm = flash_attn_asm( @@ -185,13 +188,13 @@ def _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv): k_base_hbm_offset_reg=k_addr, v_base_hbm_offset_reg=v_addr, ) - prog._compiler.emit(asm) + prog.emit(asm) # Release HBM addr regs (they're only needed during the call) alloc.free_addr([k_addr, v_addr]) # Return O as a VRAMMatrixVar the caller can consume - from compiler.aten.plena_compiler import VRAMMatrixVar + from compiler.aten.plena.vars import VRAMMatrixVar O = VRAMMatrixVar(prog, o_name, (s_q, hq * h_qkv), display_name="O") prog._tensors[o_name] = O diff --git a/aten/ops/plena/conv_ops.py b/aten/ops/plena/conv_ops.py index 689e9ba..314671e 100644 --- a/aten/ops/plena/conv_ops.py +++ b/aten/ops/plena/conv_ops.py @@ -114,12 +114,12 @@ def conv2d_plena( # Look up VRAM base addresses from the symbol table # ------------------------------------------------------------------ if use_shift: - mask_vec_vram_addr = prog._compiler.get_vram_addr(mask_mat.name) + mask_vec_vram_addr = prog.get_vram_addr(mask_mat.name) else: - basis_vram_base = prog._compiler.get_vram_addr(basis_mat.name) - scratch_vram_addr = prog._compiler.get_vram_addr(scratch_mat.name) - temp_vram_addr = prog._compiler.get_vram_addr(temp_mat.name) - output_vram_base = prog._compiler.get_vram_addr(output_mat.name) + basis_vram_base = prog.get_vram_addr(basis_mat.name) + scratch_vram_addr = prog.get_vram_addr(scratch_mat.name) + temp_vram_addr = prog.get_vram_addr(temp_mat.name) + output_vram_base = prog.get_vram_addr(output_mat.name) # ------------------------------------------------------------------ # GP register allocation @@ -146,7 +146,7 @@ def conv2d_plena( setup_lines.append(f"S_LUI_INT gp{setup_gp}, {hbm_base >> 12}") setup_lines.append(f"S_ADDI_INT gp{setup_gp}, gp{setup_gp}, {hbm_base & 0xFFF}") setup_lines.append(f"C_SET_ADDR_REG a{addr_reg_idx}, gp0, gp{setup_gp}") - prog._compiler.emit("\n".join(setup_lines) + "\n") + prog.emit("\n".join(setup_lines) + "\n") # ------------------------------------------------------------------ # Emit: im2col assembly @@ -191,7 +191,7 @@ def conv2d_plena( fp_one_reg=fp_one_reg, # f1 = 1.0 by default (must be in fp_preload[fp_one_reg]) fp_ex_reg=2, # f2 = V_RED_SUM accumulator ) - prog._compiler.emit(asm_code) + prog.emit(asm_code) # ------------------------------------------------------------------ # Systolic matmul: im2col_out @ weight_2d -> (M, C_out) diff --git a/aten/ops/plena/ffn_ops.py b/aten/ops/plena/ffn_ops.py index 9d4be26..d695242 100644 --- a/aten/ops/plena/ffn_ops.py +++ b/aten/ops/plena/ffn_ops.py @@ -25,7 +25,7 @@ def ffn_plena(prog, input_var, w_gate, w_up, w_down): vlen = prog.mlen # Retrieve VRAM address of the loaded activation - activation_base_address = prog._compiler.get_vram_addr(input_var.name) + activation_base_address = prog.get_vram_addr(input_var.name) # Set HBM address registers for each weight matrix isa_code = preload_addr_reg_asm( @@ -55,7 +55,7 @@ def ffn_plena(prog, input_var, w_gate, w_up, w_down): use_loop_instructions=True, ) - prog._compiler.emit(isa_code) + prog.emit(isa_code) # FFN result is written back to the activation area in VRAM (in-place overwrite) return input_var diff --git a/aten/ops/plena/linear_ops.py b/aten/ops/plena/linear_ops.py index c94d292..9f39cf6 100644 --- a/aten/ops/plena/linear_ops.py +++ b/aten/ops/plena/linear_ops.py @@ -1,27 +1,29 @@ """PLENA backend stubs for linear projection operators.""" +import math -def linear_plena(prog, input_var, weight_var): - """PLENA backend: linear projection via PlenaCompiler sub-matrix operations. - Supports M > mlen via row-block iteration and K_col > 4*mlen via K-split - partial sums accumulated in VRAM. +MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements + - MRAM capacity: 4 tiles × mlen² = 4 × 4096 = 16384 elements (MAX_K_TILES=4). - When K tiles > MAX_K_TILES, we split into chunks and accumulate partial sums. +def _iter_k_chunks(num_k_tiles: int): + k_start = 0 + while k_start < num_k_tiles: + k_end = min(k_start + MAX_K_TILES, num_k_tiles) + yield k_start, k_end - k_start + k_start = k_end - output[r][c] = sum_k input[r][k] @ weight[k][c] - for r in 0..num_row_blocks-1, c in 0..num_col_blocks-1, - k split into chunks of at most MAX_K_TILES tiles. - """ - import math +def linear_projection_plena(prog, input_var, weight_var, name: str = "linear_out"): + """Emit tiled PLENA linear projection, including K-split accumulation.""" mlen = prog.mlen - MAX_K_TILES = 4 # MRAM capacity: 4 × mlen² elements rows, k_total = input_var.shape _, out_features = weight_var.shape num_row_blocks = math.ceil(rows / mlen) + assert out_features % mlen == 0, ( + f"out_features ({out_features}) must be a multiple of mlen ({mlen})" + ) num_col_blocks = out_features // mlen num_k_tiles = math.ceil(k_total / mlen) @@ -30,67 +32,42 @@ def linear_plena(prog, input_var, weight_var): # operate on full mlen-wide tiles (HBM zero-pads unused rows) and only the # first `rows` rows of the output contain valid results. output_strict = rows % mlen == 0 - output = prog.alloc("linear_out", rows, out_features, strict=output_strict) + output = prog.alloc(name, rows, out_features, strict=output_strict) + + def emit_projection(row_idx, col_idx, target, target_row_idx, target_col_idx, **k_split): + prog.vram_sub_projection_to( + input_var, + row_idx, + weight_var, + col_idx, + target, + target_row_idx, + target_col_idx, + **k_split, + ) if num_k_tiles <= MAX_K_TILES: - # Single pass: all K tiles fit in MRAM for col_idx in range(num_col_blocks): for row_idx in range(num_row_blocks): - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - output, - row_idx, - col_idx, - ) + emit_projection(row_idx, col_idx, output, row_idx, col_idx) else: - # K-split: chunk K tiles into groups of MAX_K_TILES, accumulate partial sums - k_chunks = [] - k_start = 0 - while k_start < num_k_tiles: - k_end = min(k_start + MAX_K_TILES, num_k_tiles) - k_chunks.append((k_start, k_end - k_start)) - k_start = k_end - # Temp buffer for partial sums — only needs one (mlen, mlen) tile since # each sub_projection_to writes a single tile before accumulating. # Using the full output shape would cause VRAM overlap with the output # when out_features > mlen (column-block-major layout). - temp = prog.alloc("linear_out_temp", mlen, mlen) + temp = prog.alloc(f"{name}_temp", mlen, mlen) - for k_chunk_idx, (k_block_start, k_block_count) in enumerate(k_chunks): + for k_chunk_idx, (k_block_start, k_block_count) in enumerate(_iter_k_chunks(num_k_tiles)): + k_split = { + "k_block_start": k_block_start, + "k_block_count": k_block_count, + } for col_idx in range(num_col_blocks): for row_idx in range(num_row_blocks): if k_chunk_idx == 0: - # First chunk: write directly to output - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - output, - row_idx, - col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) + emit_projection(row_idx, col_idx, output, row_idx, col_idx, **k_split) else: - # Subsequent chunks: write to temp tile, then accumulate into output. - # temp is a single (mlen, mlen) tile to avoid VRAM overlap - # with the output buffer in column-block-major layout. - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - temp, - 0, - 0, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) + emit_projection(row_idx, col_idx, temp, 0, 0, **k_split) prog.vram_block_add_to( output, row_idx, @@ -102,5 +79,10 @@ def linear_plena(prog, input_var, weight_var): row_idx, col_idx, ) + prog.free_tensor(temp) return output + + +def linear_plena(prog, input_var, weight_var): + return linear_projection_plena(prog, input_var, weight_var) diff --git a/aten/ops/plena/softmax_ops.py b/aten/ops/plena/softmax_ops.py index f7b99cd..9570bc1 100644 --- a/aten/ops/plena/softmax_ops.py +++ b/aten/ops/plena/softmax_ops.py @@ -50,7 +50,7 @@ def softmax_plena(prog, input_var, scale: float = 1.0): # Reserve fp_preload region (addresses 0-2): 0=0.0, 1=scale, 2=-inf # This ensures FPRAM variable allocation starts after these reserved slots - prog._compiler.fpram_allocator.next_free = 3 + prog.fpram_allocator.next_free = 3 # Allocate FPRAM variables scale_fp = prog.fp_var("scale_fp", size=1) # scale factor @@ -69,10 +69,9 @@ def softmax_plena(prog, input_var, scale: float = 1.0): prog.fpvar_fill_from_fpram(l_old, src_fpram_addr=0) # load 0.0 # Row-by-row online softmax - compiler = prog._compiler for row in range(mlen): # 1. Save old max for this row - compiler.fpvar_copy_asm(m_old.address + row, m_old_saved.address, 1) + prog.fpvar_copy_asm(m_old.address + row, m_old_saved.address, 1) # 2. Scale: S[row] *= scale prog.tile_row_mul_fp_broadcast(S, scale_fp.address, row) @@ -81,11 +80,11 @@ def softmax_plena(prog, input_var, scale: float = 1.0): prog.tile_row_max(row_max_tmp.address, S, row) # 4. Update running max: m_old[row] = max(m_old[row], row_max) - compiler.fpvar_max_asm(m_old.address + row, row_max_tmp.address, m_old.address + row, 1) + prog.fpvar_max_asm(m_old.address + row, row_max_tmp.address, m_old.address + row, 1) # 5. Decay factor: m_res[row] = exp(m_old_saved - m_curr) - compiler.fpvar_sub_asm(m_old_saved.address, m_old.address + row, m_old_saved.address, 1) - compiler.fpvar_exp_asm(m_old_saved.address, m_res.address + row, 1) + prog.fpvar_sub_asm(m_old_saved.address, m_old.address + row, m_old_saved.address, 1) + prog.fpvar_exp_asm(m_old_saved.address, m_res.address + row, 1) # 6. Subtract max: S[row] -= m_curr prog.tile_row_sub_fp(S, m_old.address + row, row) @@ -97,8 +96,8 @@ def softmax_plena(prog, input_var, scale: float = 1.0): prog.tile_row_sum(sum_p_tmp.address, S, row) # 9. Update accumulated sum: l_old[row] = l_old[row]*m_res[row] + sum_p - compiler.fpvar_mul_asm(l_old.address + row, m_res.address + row, l_old.address + row, 1) - compiler.fpvar_add_asm(l_old.address + row, sum_p_tmp.address, l_old.address + row, 1) + prog.fpvar_mul_asm(l_old.address + row, m_res.address + row, l_old.address + row, 1) + prog.fpvar_add_asm(l_old.address + row, sum_p_tmp.address, l_old.address + row, 1) # Final normalization: P[row] /= l_old[row] prog.fpvar_reci(l_old, inv_l) diff --git a/aten/plena/__init__.py b/aten/plena/__init__.py index 09ad0c4..1695c6a 100644 --- a/aten/plena/__init__.py +++ b/aten/plena/__init__.py @@ -2,49 +2,19 @@ from compiler.aten.plena.compiler import PlenaCompiler from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN -from compiler.aten.plena.isa_compiler import DeveloperCompiler, IsaCompiler -from compiler.aten.plena.memory import ( - FPRAMAllocator, - FPRAMObjectLayout, - MRAMAllocator, - MatrixBlockLayout, - MemoryBlock, - MemoryObjectInfo, - SubMatrixInfo, - VRAMAllocator, - VRAMMatrixBlockLayout, - VRAMSubMatrixInfo, - VirtualMemoryManager, -) -from compiler.aten.plena.registers import RegisterAllocator +from compiler.aten.plena.isa_compiler import IsaCompiler from compiler.aten.plena.tile_compiler import TileCompiler -from compiler.aten.plena.vars import FPVar, InputVar, Tensor, TensorKind, TensorVar, VRAMMatrixVar, tensor_kind +from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar __all__ = [ "BLEN", "IMM2_BOUND", "MLEN", - "DeveloperCompiler", - "FPRAMAllocator", - "FPRAMObjectLayout", "FPVar", "InputVar", "IsaCompiler", - "MRAMAllocator", - "MatrixBlockLayout", - "MemoryBlock", - "MemoryObjectInfo", "PlenaCompiler", - "RegisterAllocator", - "SubMatrixInfo", - "Tensor", - "TensorKind", "TensorVar", "TileCompiler", - "VRAMAllocator", - "VRAMMatrixBlockLayout", "VRAMMatrixVar", - "VRAMSubMatrixInfo", - "VirtualMemoryManager", - "tensor_kind", ] diff --git a/aten/plena/compiler.py b/aten/plena/compiler.py index 395b637..91e006b 100644 --- a/aten/plena/compiler.py +++ b/aten/plena/compiler.py @@ -1,45 +1,15 @@ -"""User-facing PLENA compiler DSL.""" +"""User-facing PLENA compiler program builder.""" from __future__ import annotations import os -from collections.abc import Callable -from functools import wraps from compiler.aten.plena.isa_compiler import IsaCompiler -from compiler.aten.plena.dsl_attention import DslAttentionMixin -from compiler.aten.plena.dsl_fp_tile_ops import DslFPTileOpsMixin -from compiler.aten.plena.dsl_matrix_ops import DslMatrixOpsMixin -from compiler.aten.plena.dsl_tensors import DslTensorMixin -from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar - - -class _IsaCompilerView: - """ - Back-compat proxy for legacy ``prog._compiler.X(...)`` call sites. - - PlenaCompiler now inherits IsaCompiler rather than composing it, so - for call sites that still expect to reach the low-level IsaCompiler - API (e.g., ``allocate_fpram(name=..., size=...)`` returning an int), we - expose a proxy whose attribute lookup resolves callables on - IsaCompiler directly (bypassing any PlenaCompiler overrides with - colliding names). Non-callable attributes (e.g., ``generated_code``, - ``vram_allocator``) fall through to the underlying instance unchanged. - """ - - __slots__ = ("_inst",) - - def __init__(self, inst: PlenaCompiler): - object.__setattr__(self, "_inst", inst) - - def __getattr__(self, name: str): - cls_attr = getattr(IsaCompiler, name, None) - if cls_attr is not None and callable(cls_attr): - return cls_attr.__get__(self._inst, IsaCompiler) - return getattr(self._inst, name) - - def __setattr__(self, name: str, value): - setattr(self._inst, name, value) +from compiler.aten.plena.program_attention import ProgramAttentionMixin +from compiler.aten.plena.program_fp_tile_ops import ProgramFPTileOpsMixin +from compiler.aten.plena.program_matrix_ops import ProgramMatrixOpsMixin +from compiler.aten.plena.program_tensors import ProgramTensorMixin +from compiler.aten.plena.vars import FPVar, InputVar, TensorVar # ============================================================================ @@ -48,18 +18,17 @@ def __setattr__(self, name: str, value): class PlenaCompiler( - DslTensorMixin, - DslFPTileOpsMixin, - DslMatrixOpsMixin, - DslAttentionMixin, + ProgramTensorMixin, + ProgramFPTileOpsMixin, + ProgramMatrixOpsMixin, + ProgramAttentionMixin, IsaCompiler, ): """ PLENA High-level Compiler Interface. - Inherits the ISA-emission machinery from IsaCompiler and layers a - Pythonic DSL on top. All operations are eagerly evaluated — ISA code is - generated immediately upon call. + Inherits the ISA-emission machinery from IsaCompiler and layers typed + program-builder helpers on top. Operations eagerly emit ISA text. """ def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): @@ -87,135 +56,9 @@ def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125 self._inputs: dict[str, InputVar] = {} self._tensors: dict[str, TensorVar] = {} self._fp_vars: dict[str, FPVar] = {} - self._functions: dict[str, Callable] = {} self._registered_hbm_sub_matrices: dict[str, bool] = {} self._registered_vram_sub_matrices: dict[str, bool] = {} - self._result_tensor: TensorVar | None = None - - # Auto-generated name counter - self._auto_name_counter: int = 0 - - # Function scope namespace - # Push a prefix on each function call (e.g., "linear_0/"), pop on exit - # _auto_name will automatically add current scope prefix, avoiding name conflicts when calling the same function multiple times - self._scope_stack: list[str] = [] - self._function_call_counters: dict[str, int] = {} # func_name -> call count - - # ======================================================================== - # Property Access - # ======================================================================== - - # mlen / blen are instance attributes inherited from TileCompiler.__init__. - - @property - def compiler(self) -> PlenaCompiler: - """Legacy accessor — returns self now that PlenaCompiler is the compiler.""" - return self - - @property - def _compiler(self) -> _IsaCompilerView: - """Back-compat shim for legacy ``prog._compiler.X(...)`` call sites. - Returns a proxy that resolves callables against IsaCompiler - directly so callers reach the low-level API regardless of any - PlenaCompiler override with the same name.""" - return _IsaCompilerView(self) - - @property - def symbol_table(self): - """Access symbol table.""" - return self.get_symbol_table() - - # ======================================================================== - # Function Decorator - # ======================================================================== - - def function(self, func: Callable) -> Callable: - """ - Decorator: Define reusable functions. - - Each invocation generates fresh ISA code (eager evaluation). - Internally allocated tensors are auto-freed on exit unless returned. - - Scoping: intermediate tensors get a call-index prefix to avoid name - collisions across repeated calls (e.g., "linear_0/proj_1", "linear_1/proj_1"). - Nested functions compose prefixes: "two_layer_0/linear_0/proj_1". - """ - func_name = func.__name__ - - @wraps(func) - def wrapper(*args, **kwargs): - call_idx = self._function_call_counters.get(func_name, 0) - self._function_call_counters[func_name] = call_idx + 1 - - scope = f"{func_name}_{call_idx}/" - self._scope_stack.append(scope) - - self.emit_comment(f"=== Enter {func_name} (call #{call_idx}) ===") - - # Snapshot: record existing tensors before function execution - tensors_before = set(self._tensors.keys()) - inputs_before = set(self._inputs.keys()) - fp_vars_before = set(self._fp_vars.keys()) - - try: - result = func(*args, **kwargs) - - # Auto-free: free locally allocated tensors that are not returned - return_names = set() - return_fp_names = set() - if isinstance(result, TensorVar): - return_names.add(result.internal_name) - elif isinstance(result, FPVar): - return_fp_names.add(result.internal_name) - elif isinstance(result, (tuple, list)): - for r in result: - if isinstance(r, TensorVar): - return_names.add(r.internal_name) - elif isinstance(r, FPVar): - return_fp_names.add(r.internal_name) - - for name in set(self._tensors.keys()) - tensors_before: - if name not in return_names: - tensor = self._tensors[name] - if isinstance(tensor, VRAMMatrixVar): - self.free_tensor(tensor) - self._registered_vram_sub_matrices[tensor.name] = False - - for name in set(self._inputs.keys()) - inputs_before: - if name not in return_names: - self.free_input(self._inputs[name]) - - local_fp_names = sorted( - set(self._fp_vars.keys()) - fp_vars_before, - key=lambda n: self._fp_vars[n].address, - reverse=True, - ) - for name in local_fp_names: - if name in return_fp_names: - continue - fp_var = self._fp_vars.get(name) - if fp_var is not None: - self.free_fp_var(fp_var) - finally: - self._scope_stack.pop() - self.emit_comment(f"=== Exit {func_name} (call #{call_idx}) ===") - - return result - - self._functions[func_name] = wrapper - wrapper._plena_function = True - wrapper._plena_name = func_name - return wrapper - - # ======================================================================== - # Result Marking - # ======================================================================== - - def result(self, tensor_var: TensorVar): - """Mark output result tensor.""" - self._result_tensor = tensor_var - # ======================================================================== # Compilation # ======================================================================== @@ -224,37 +67,17 @@ def compile(self) -> str: """Get generated ISA code string.""" return super().get_code() - def print_symbol_table(self): - """Print symbol table""" - super().print_symbol_table() - - def get_symbol_table(self): - """Get symbol table""" - return super().get_symbol_table() - - # ======================================================================== - # Operator Dispatch (internal) - # ======================================================================== - - def _dispatch_matmul(self, left: TensorVar, right) -> TensorVar: - raise TypeError("@ operator is no longer supported in PlenaCompiler. Use explicit program APIs instead.") + @property + def _compiler(self) -> PlenaCompiler: + """Compatibility alias for simulator testbench callers.""" + return self # ======================================================================== # Utility Methods # ======================================================================== def _scoped_name(self, name: str) -> str: - """ - Apply current scope prefix to a name. - - - Top-level alloc("temp"): -> "temp" - - Inside linear call 0, alloc("temp"): -> "linear_0/temp" - - Nested two_layer->linear, alloc("temp"): -> "two_layer_0/linear_0/temp" - """ - if not self._scope_stack: - return name - scope_prefix = "".join(self._scope_stack) - return f"{scope_prefix}{name}" + return name def _allocate_hbm(self, hbm_size: int) -> int: """Allocate HBM range, preferring previously freed blocks.""" @@ -286,28 +109,5 @@ def _recycle_hbm(self, hbm_addr: int, hbm_size: int): return self._hbm_free_blocks.append((hbm_addr, hbm_size)) - def _auto_name(self, prefix: str = "t") -> str: - """ - Generate a unique scoped name. - - - Top-level: "__proj_1" - - linear call 0: "linear_0/__proj_1" - - nested: "two_layer_0/linear_0/__proj_1" - """ - self._auto_name_counter += 1 - scope_prefix = "".join(self._scope_stack) - return f"{scope_prefix}__{prefix}_{self._auto_name_counter}" - - def __repr__(self): - num_inputs = len(self._inputs) - num_tensors = len(self._tensors) - num_functions = len(self._functions) - code_len = len(super().get_code().splitlines()) - return ( - f"PlenaCompiler(mlen={self.mlen}, blen={self.blen}, " - f"inputs={num_inputs}, tensors={num_tensors}, " - f"functions={num_functions}, isa_lines={code_len})" - ) - __all__ = ["PlenaCompiler"] diff --git a/aten/plena/dsl_fp_tile_ops.py b/aten/plena/dsl_fp_tile_ops.py deleted file mode 100644 index 0c40411..0000000 --- a/aten/plena/dsl_fp_tile_ops.py +++ /dev/null @@ -1,568 +0,0 @@ -"""FPRAM, FPVar, and tile-row operations for the PLENA DSL.""" - -from __future__ import annotations - -from compiler.aten.plena.vars import FPVar, VRAMMatrixVar - - -class DslFPTileOpsMixin: - # ======================================================================== - # FP Variable (FPRAM) - # ======================================================================== - - def allocate_fpram( - self, - internal_name: str, - size: int = 1, - display_name: str | None = None, - ) -> FPVar: - """ - Allocate FPRAM with explicit internal name and return FPVar proxy. - """ - if size <= 0: - raise ValueError(f"FPRAM allocation size must be positive, got {size}") - - address = super().allocate_fpram(internal_name, size) - var = FPVar( - self, - internal_name, - address, - size, - display_name=display_name if display_name is not None else internal_name, - ) - self._fp_vars[internal_name] = var - return var - - def free_fpram(self, internal_name: str, strict: bool = True): - """ - Free FPRAM allocation by internal name. - """ - super().free_fpram(internal_name, strict=strict) - self._fp_vars.pop(internal_name, None) - - def fp_var(self, name: str, size: int = 1) -> FPVar: - """ - Declare an FP variable in FPRAM. - - Allocates a contiguous region in FPRAM and returns an FPVar proxy. - Within function scope, names are automatically prefixed. - - Args: - name: variable name - size: number of f16 elements to allocate (default 1) - - Returns: - FPVar proxy object (use .address for ISA generation) - - Example: - scale = prog.fp_var("scale") # 1 element - m_old = prog.fp_var("m_old", size=64) # 64 elements - prog.compiler # access compiler for ISA if needed - """ - display_name = name - internal_name = self._scoped_name(name) - - return self.allocate_fpram( - internal_name=internal_name, - size=size, - display_name=display_name, - ) - - def save_fpram_state(self) -> int: - """Save FPRAM allocator snapshot""" - return super().save_fpram_state() - - def restore_fpram_state(self, snapshot: int): - """Restore FPRAM allocator snapshot""" - super().restore_fpram_state(snapshot) - # Remove FPVar proxies that are no longer allocated in allocator. - allocations = set(super().list_fpram_allocations()) - to_remove = [n for n in self._fp_vars if n not in allocations] - for n in to_remove: - del self._fp_vars[n] - - # ======================================================================== - # FPRAM Tile Operations - # ======================================================================== - - def _resolve_fpram_addr(self, addr_or_var: int | FPVar, offset: int = 0) -> int: - if isinstance(addr_or_var, FPVar): - if offset < 0 or offset >= addr_or_var.size: - raise ValueError( - f"FPVar offset out of range: offset={offset}, size={addr_or_var.size}, var={addr_or_var.name}" - ) - return addr_or_var.address + offset - if not isinstance(addr_or_var, int): - raise TypeError(f"Expected int or FPVar, got {type(addr_or_var)}") - return addr_or_var + offset - - def _resolve_rows( - self, - row_idx: int | None = None, - rows: list[int] | None = None, - ) -> list[int]: - if row_idx is not None and rows is not None: - raise ValueError("Provide either row_idx or rows, not both") - if rows is not None: - return rows - if row_idx is not None: - return [row_idx] - return list(range(self.mlen)) - - def tile_row_max( - self, - target_fpram_addr: int | FPVar, - source: VRAMMatrixVar, - row_idx: int | None = None, - rows: list[int] | None = None, - target_offset: int = 0, - target_base_offset: int = 0, - ): - """ - Tile Row Max: reduce a single row to max, store to FPRAM address. - - Args: - target_fpram_addr: FPRAM address or FPVar to write result - source: VRAM tile (mlen x mlen) - row_idx: single row index (legacy path) - rows: multiple row indices - target_offset: element offset when target_fpram_addr is FPVar - target_base_offset: base offset for multi-row writes (contiguous) - - Example: - m = prog.fp_var("m", size=1) - S = prog.alloc("S", 64, 64) - for row in range(64): - prog.tile_row_max(m, S, rows=list(range(64))) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [target_offset] - else: - offsets = [target_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(target_fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_max( - source_matrix=source.name, - row_map=row_map, - ) - - def tile_row_sum( - self, - target_fpram_addr: int | FPVar, - source: VRAMMatrixVar, - row_idx: int | None = None, - rows: list[int] | None = None, - target_offset: int = 0, - target_base_offset: int = 0, - ): - """ - Tile Row Sum: reduce a single row to sum, store to FPRAM address. - - Args: - target_fpram_addr: FPRAM address or FPVar to write result - source: VRAM tile (mlen x mlen) - row_idx: single row index (legacy path) - rows: multiple row indices - target_offset: element offset when target_fpram_addr is FPVar - target_base_offset: base offset for multi-row writes (contiguous) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [target_offset] - else: - offsets = [target_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(target_fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_sum(source.name, row_map) - - def tile_row_exp( - self, - source: VRAMMatrixVar, - row_idx: int | None = None, - rows: list[int] | None = None, - ): - """ - Tile Row Exp: in-place exp on specified rows. - - For each row i: source[i, :] = exp(source[i, :]) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - super().tile_row_exp(source.name, resolved_rows) - - def tile_row_reci( - self, - source: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Reciprocal: in-place 1/x on specified rows. - - For each row i: source[i, :] = 1.0 / source[i, :] - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_reci(source.name, rows) - - def tile_row_sub_fp( - self, - source: VRAMMatrixVar, - fpram_addr: int | FPVar, - row_idx: int | None = None, - rows: list[int] | None = None, - fpram_offset: int = 0, - fpram_base_offset: int = 0, - ): - """ - Tile Row Sub FP: subtract FPRAM scalar from a single row. - - For row i: source[i, :] = source[i, :] - FPRAM[fpram_addr] - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [fpram_offset] - else: - offsets = [fpram_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_sub_fp(source.name, row_map) - - def tile_row_mul_fp( - self, - source: VRAMMatrixVar, - fpram_addr: int | FPVar, - row_idx: int | None = None, - rows: list[int] | None = None, - fpram_offset: int = 0, - fpram_base_offset: int = 0, - ): - """ - Tile Row Mul FP: multiply a single row by FPRAM scalar. - - For row i: source[i, :] = source[i, :] * FPRAM[fpram_addr] - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [fpram_offset] - else: - offsets = [fpram_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_mul_fp(source.name, row_map) - - def tile_row_add_fp( - self, - source: VRAMMatrixVar, - fp_var: FPVar, - rows: list[int] | None = None, - ): - """ - Tile Row Add FP: add FPRAM scalar to each specified row. - - For each row i: source[i, :] = source[i, :] + fp_var[i] - """ - if rows is None: - rows = list(range(self.mlen)) - row_map = [(r, fp_var[r]) for r in rows] - super().tile_row_add_fp(source.name, row_map) - - def tile_row_add( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Add: dst[i, :] += src[i, :] for specified rows. - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_add(dst.name, src.name, rows) - - def tile_row_sub( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Sub: dst[i, :] -= src[i, :] for specified rows. - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_sub(dst.name, src.name, rows) - - def tile_row_mul( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Mul: dst[i, :] *= src[i, :] for specified rows. - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_mul(dst.name, src.name, rows) - - def fpvar_reci( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Reciprocal: compute 1/x for FPRAM scalar array. - - For each element i: dst[i] = 1.0 / src[i] - - Args: - src: source FPVar - dst: destination FPVar - count: number of elements (default: min(src.size, dst.size)) - - Example: - l = prog.fp_var("l", size=64) - inv_l = prog.fp_var("inv_l", size=64) - prog.fpvar_reci(l, inv_l) # inv_l = 1/l - """ - if count is None: - count = min(src.size, dst.size) - if count > src.size or count > dst.size: - raise ValueError(f"count={count} exceeds FPVar size: src.size={src.size}, dst.size={dst.size}") - super().fpram_reci(src.name, dst.name, count) - - def fpvar_max( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Max: element-wise max for FPRAM scalar arrays. - - For each element i: dst[i] = max(src1[i], src2[i]) - - Example: - m_new = prog.fp_var("m_new", size=64) - prog.fpvar_max(m_old, row_max, m_new) # m_new = max(m_old, row_max) - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_max(src1.name, src2.name, dst.name, count) - - def fpvar_sub( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Subtract: element-wise subtraction for FPRAM scalar arrays. - - For each element i: dst[i] = src1[i] - src2[i] - - Example: - diff = prog.fp_var("diff", size=64) - prog.fpvar_sub(m_old, m_new, diff) # diff = m_old - m_new - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_sub(src1.name, src2.name, dst.name, count) - - def fpvar_exp( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Exp: element-wise exp for FPRAM scalar array. - - For each element i: dst[i] = exp(src[i]) - - Example: - m_res = prog.fp_var("m_res", size=64) - prog.fpvar_exp(diff, m_res) # m_res = exp(diff) - """ - if count is None: - count = min(src.size, dst.size) - super().fpram_exp(src.name, dst.name, count) - - def fpvar_mul( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Multiply: element-wise multiplication for FPRAM scalar arrays. - - For each element i: dst[i] = src1[i] * src2[i] - - Example: - result = prog.fp_var("result", size=64) - prog.fpvar_mul(l_old, m_res, result) # result = l_old * m_res - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_mul(src1.name, src2.name, dst.name, count) - - def fpvar_add( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Add: element-wise addition for FPRAM scalar arrays. - - For each element i: dst[i] = src1[i] + src2[i] - - Example: - l_new = prog.fp_var("l_new", size=64) - prog.fpvar_add(l_old, sum_p, l_new) # l_new = l_old + sum_p - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_add(src1.name, src2.name, dst.name, count) - - def fpvar_copy( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Copy: copy FPRAM scalar array. - - For each element i: dst[i] = src[i] - - Example: - m_old_saved = prog.fp_var("m_old_saved", size=64) - prog.fpvar_copy(m_old, m_old_saved) # backup m_old - """ - if count is None: - count = min(src.size, dst.size) - super().fpram_copy(src.name, dst.name, count) - - def fpvar_sum( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Sum: reduction sum of src into dst[0] (via compiler FPRAM op). - """ - if count is None: - count = src.size - super().fpram_sum(src.name, dst.name, count) - - def fpvar_shift( - self, - src: FPVar, - dst: FPVar, - shift: int, - count: int | None = None, - fill: FPVar | None = None, - ): - """ - FPVar Shift: shift src into dst, filling out-of-range slots with fill (default FPRAM zero). - """ - if count is None: - count = min(src.size, dst.size) - fill_name = None if fill is None else fill.name - super().fpram_shift( - src_name=src.name, - dst_name=dst.name, - shift=shift, - count=count, - fill_fpram_name=fill_name, - ) - - def tile_row_mul_fp_broadcast( - self, - source: VRAMMatrixVar, - fpram_scalar_addr: int | FPVar, - row_idx: int | None = None, - rows: list[int] | None = None, - fpram_offset: int = 0, - ): - """ - Tile Row Mul FP Broadcast: multiply a single row by a FPRAM scalar. - - For row i: source[i, :] = source[i, :] * FPRAM[fpram_scalar_addr] - - Args: - source: VRAM tile (mlen x mlen) - fpram_scalar_addr: FPRAM address or FPVar of the scalar - row_idx: single row index (legacy path) - rows: multiple row indices - - Example: - scale_fp = prog.fp_var("scale", size=1) - for row in range(64): - prog.tile_row_mul_fp_broadcast(S, scale_fp, rows=list(range(64))) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - scalar_addr = self._resolve_fpram_addr(fpram_scalar_addr, fpram_offset) - super().tile_row_mul_fp_broadcast(source.name, scalar_addr, resolved_rows) - - def fpvar_fill_from_fpram( - self, - dst: FPVar, - src_fpram_addr: int, - count: int | None = None, - ): - """ - FPVar Fill from FPRAM: fill all elements with a value from FPRAM. - - For each element i: dst[i] = FPRAM[src_fpram_addr] - - Args: - dst: destination FPVar - src_fpram_addr: source FPRAM address (e.g., 0 for 0.0, 2 for -inf) - count: number of elements (default: dst.size) - - Example: - m_old = prog.fp_var("m_old", size=64) - prog.fpvar_fill_from_fpram(m_old, 2) # fill with -inf from address 2 - """ - if count is None: - count = dst.size - super().fpram_fill_from_fpram(dst.name, src_fpram_addr, count) - - def vram_fill_zero( - self, - matrix: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - VRAM Fill Zero: fill specified rows with 0. - - Args: - matrix: VRAM matrix - rows: which rows to fill (default: all rows) - - Example: - O = prog.alloc("O", 128, 128) - prog.vram_fill_zero(O, rows=range(64, 128)) # zero out second half - """ - if rows is None: - rows = list(range(matrix.shape[0])) - else: - rows = list(rows) - - total_rows, cols = matrix.shape - if any(row < 0 or row >= total_rows for row in rows): - raise ValueError(f"vram_fill_zero rows out of bounds for {matrix.name}: shape={matrix.shape}, rows={rows}") - - # VRAM matrices are column-block-major. The low-level tile helper zeros - # one 64-column tile, so walk every column block for wide matrices. - num_col_blocks = (cols + self.mlen - 1) // self.mlen - for col_block in range(num_col_blocks): - super().vram_fill_zero(matrix.name, rows, tile_col_idx=col_block) - - -__all__ = ["DslFPTileOpsMixin"] diff --git a/aten/plena/isa_compiler.py b/aten/plena/isa_compiler.py index 35c6896..c5d7f4b 100644 --- a/aten/plena/isa_compiler.py +++ b/aten/plena/isa_compiler.py @@ -31,156 +31,18 @@ class IsaCompiler( """ ISA Compiler: lowers PLENA compiler operations to assembly text. - Owns symbol_table, register_allocator, and the InterruptManager. - Sub-matrix / memory management is inherited from TileCompiler; the - legacy ``self.tile_compiler`` accessor is preserved as a property - returning ``self`` for a handful of remaining external callers. + Owns register allocation, generated assembly, and tiled memory metadata. """ _ONLINE_SOFTMAX_FPSRAM_BASE = 10 - class InterruptManager: - """ - Interrupt Manager — manages execution timing only. - Actual handlers live on IsaCompiler as ``_handle_k_start``, - ``_handle_k_prefetch_done``, ``_handle_s_tile_done``, ``_handle_k_end``. - """ - - def __init__(self, compiler: IsaCompiler): - self.compiler = compiler - self.enabled = False - - self._k_count = 0 - self._tile_count = 0 - - self.current_matrix: str = "" - self.current_activation: str = "" - self._mlen = compiler.mlen - self._blen = compiler.blen - self._batch = compiler.mlen - - self._q_block_idx = 0 - self._k_block_idx = 0 - self._s_tile_address = 0 - - @property - def tile_compiler(self): - return self.compiler.tile_compiler - - @property - def k_count(self) -> int: - return self._k_count - - @property - def tile_count(self) -> int: - return self._tile_count - - @property - def batch(self) -> int: - return self._batch - - @property - def out_features(self) -> int: - if self.current_matrix and self.current_matrix in self: - info = self[self.current_matrix] - return info.shape[0] - return self._mlen - - @property - def hidden_size(self) -> int: - if self.current_matrix and self.current_matrix in self: - info = self[self.current_matrix] - return info.shape[1] - return self._mlen - - @property - def k_block(self) -> int: - return self._k_block_idx - - @property - def q_block(self) -> int: - return self._q_block_idx - - @property - def s_tile_address(self) -> int: - return self._s_tile_address - - @property - def mlen(self) -> int: - return self._mlen - - @property - def blen(self) -> int: - return self._blen - - def reset(self): - """Reset counters (does not clear current_matrix).""" - self._k_count = 0 - self._tile_count = 0 - self._q_block_idx = 0 - self._k_block_idx = 0 - self._s_tile_address = 0 - - def enable(self): - self.enabled = True - - def disable(self): - self.enabled = False - - def trigger_k_start(self) -> str: - if not self.enabled: - return "" - return self.compiler._handle_k_start() - - def trigger_k_prefetch_done(self) -> str: - if not self.enabled: - return "" - result = self.compiler._handle_k_prefetch_done() - self._k_count += 1 - return result - - def trigger_s_tile_done(self) -> str: - if not self.enabled: - return "" - result = self.compiler._handle_s_tile_done() - self._tile_count += 1 - return result - - def trigger_k_end(self) -> str: - if not self.enabled: - return "" - return self.compiler._handle_k_end() - def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): - # TileCompiler.__init__ sets mlen, blen, unroll_loops, the HBM/VRAM/FPRAM - # matrices and allocators, loaded_sub_blocks, and _address_cache. + # TileCompiler.__init__ sets dimensions, layout tables, memory allocators, + # and the currently loaded MRAM sub-block table. super().__init__(mlen=mlen, blen=blen, unroll_loops=unroll_loops) self.real_data_ratio = real_data_ratio self.register_allocator = RegisterAllocator() self.generated_code = "" - self.interrupt = self.InterruptManager(self) - - # Back-compat shim: older callers (and a couple of external ops modules) - # reach into ``compiler.tile_compiler`` directly. After the merge, - # IsaCompiler *is* the TileCompiler, so the property just returns - # ``self``. - @property - def tile_compiler(self) -> IsaCompiler: - return self - - # Interrupt handler placeholders (overridden by flash-attention passes). - - def _handle_k_start(self) -> str: - return "" - - def _handle_k_prefetch_done(self) -> str: - return "" - - def _handle_s_tile_done(self) -> str: - return "" - - def _handle_k_end(self) -> str: - return "" def load_batch( self, @@ -473,14 +335,6 @@ def reset(self): # Call TileCompiler.reset() explicitly since the merged class shadows it. TileCompiler.reset(self) - def print_symbol_table(self): - """Print symbol table""" - self.print_table() - - def get_symbol_table(self): - """Get managed object table view.""" - return self - def get_tensor_info(self, name: str): """Get unified tensor/object info by name.""" return self[name] @@ -568,9 +422,4 @@ def free_vram_object(self, name: str, strict: bool = False): """Free a VRAM object by name (defaults to non-strict).""" return TileCompiler.free_vram_object(self, name, strict=strict) - -# Compatibility alias for callers that imported the old low-level layer name. -DeveloperCompiler = IsaCompiler - - -__all__ = ["DeveloperCompiler", "IsaCompiler"] +__all__ = ["IsaCompiler"] diff --git a/aten/plena/isa_emit.py b/aten/plena/isa_emit.py index c0014a1..adad6be 100644 --- a/aten/plena/isa_emit.py +++ b/aten/plena/isa_emit.py @@ -63,18 +63,6 @@ def free_fpram(self, name: str, strict: bool = True): """Free FPRAM object by name.""" return self.free_fpram_object(name, strict=strict) - def save_fpram_state(self) -> int: - """Save FPRAM allocator snapshot.""" - return self.fpram_allocator.save_state() - - def restore_fpram_state(self, snapshot: int): - """Restore FPRAM allocator snapshot.""" - self.fpram_allocator.restore_state(snapshot) - - def list_fpram_allocations(self) -> list[str]: - """List currently allocated FPRAM object names.""" - return list(self.fpram_allocator.allocations.keys()) - def get_fpram_addr(self, name: str) -> int: """Get FPRAM base address from object name.""" return self.get_fpram_layout(name).fpram_addr diff --git a/aten/plena/isa_fp_ops.py b/aten/plena/isa_fp_ops.py index 66ad248..dae941f 100644 --- a/aten/plena/isa_fp_ops.py +++ b/aten/plena/isa_fp_ops.py @@ -232,48 +232,50 @@ def fpvar_shift_asm( # FPVar helpers (name-based wrappers over the address-based ISA generators) # ========================================================================= + def _fpram_unary(self, asm_method: str, src_name: str, dst_name: str, count: int | None = None) -> str: + count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) if count is None else count + return getattr(self, asm_method)(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) + + def _fpram_binary( + self, + asm_method: str, + src1_name: str, + src2_name: str, + dst_name: str, + count: int | None = None, + ) -> str: + count = ( + min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) + if count is None + else count + ) + return getattr(self, asm_method)( + self.get_fpram_addr(src1_name), + self.get_fpram_addr(src2_name), + self.get_fpram_addr(dst_name), + count, + ) + def fpram_copy(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - return self.fpvar_copy_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) + return self._fpram_unary("fpvar_copy_asm", src_name, dst_name, count) def fpram_reci(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - return self.fpvar_reci_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) + return self._fpram_unary("fpvar_reci_asm", src_name, dst_name, count) def fpram_exp(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - return self.fpvar_exp_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) + return self._fpram_unary("fpvar_exp_asm", src_name, dst_name, count) def fpram_add(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_add_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) + return self._fpram_binary("fpvar_add_asm", src1_name, src2_name, dst_name, count) def fpram_sub(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_sub_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) + return self._fpram_binary("fpvar_sub_asm", src1_name, src2_name, dst_name, count) def fpram_mul(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_mul_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) + return self._fpram_binary("fpvar_mul_asm", src1_name, src2_name, dst_name, count) def fpram_max(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_max_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) + return self._fpram_binary("fpvar_max_asm", src1_name, src2_name, dst_name, count) def fpram_sum(self, src_name: str, dst_name: str, count: int | None = None) -> str: if count is None: diff --git a/aten/plena/isa_matrix.py b/aten/plena/isa_matrix.py index c7e9abe..ef8c10a 100644 --- a/aten/plena/isa_matrix.py +++ b/aten/plena/isa_matrix.py @@ -7,6 +7,23 @@ class IsaMatrixMixin: + def _emit_hbm_matrix_load(self, layout, gp_count: int, build_body) -> str: + gp_regs = self.register_allocator.allocate_gp(gp_count) + gp_for_addr = self.register_allocator.allocate_gp(2) + addr_reg = self.register_allocator.allocate_addr(1)[0] + try: + isa_code = preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], + available_registers=gp_for_addr, + addr_reg_val=[layout.hbm_base_addr], + ) + isa_code += build_body(addr_reg, gp_regs) + return self._emit(isa_code) + finally: + self.register_allocator.free_gp(gp_regs) + self.register_allocator.free_gp(gp_for_addr) + self.register_allocator.free_addr([addr_reg]) + def reset_mram(self) -> str: """ Reset MRAM allocator, free all allocated space @@ -32,24 +49,18 @@ def load_sub_matrix_row( total_size = num_col_blocks * block_size mram_start_addr = self.mram_allocator.allocate(f"{name}[{row_idx}][:]", total_size) - gp_regs = self.register_allocator.allocate_gp(4) - gp_for_addr = self.register_allocator.allocate_gp(2) - addr_reg = self.register_allocator.allocate_addr(1)[0] - - isa_code = preload_addr_reg_asm( - addr_reg_to_set=[addr_reg], available_registers=gp_for_addr, addr_reg_val=[layout.hbm_base_addr] + return self._emit_hbm_matrix_load( + layout, + 3, + lambda addr_reg, gp_regs: self.load_row_sub_matrices_asm( + name=name, + row_idx=row_idx, + mram_start_addr=mram_start_addr, + hbm_addr_reg=addr_reg, + gp_regs=gp_regs, + ), ) - isa_code += self.load_row_sub_matrices_asm( - name=name, row_idx=row_idx, mram_start_addr=mram_start_addr, hbm_addr_reg=addr_reg, gp_regs=gp_regs - ) - - self.register_allocator.free_gp(gp_regs) - self.register_allocator.free_gp(gp_for_addr) - self.register_allocator.free_addr([addr_reg]) - - return self._emit(isa_code) - def load_sub_matrix_col( self, name: str, @@ -71,30 +82,20 @@ def load_sub_matrix_col( total_size = effective_count * block_size mram_start_addr = self.mram_allocator.allocate(f"{name}[:][{col_idx}]", total_size) - gp_regs = self.register_allocator.allocate_gp(3) - gp_for_addr = self.register_allocator.allocate_gp(2) - addr_reg = self.register_allocator.allocate_addr(1)[0] - - isa_code = preload_addr_reg_asm( - addr_reg_to_set=[addr_reg], available_registers=gp_for_addr, addr_reg_val=[layout.hbm_base_addr] + return self._emit_hbm_matrix_load( + layout, + 3, + lambda addr_reg, gp_regs: self.load_col_sub_matrices_asm( + name=name, + col_idx=col_idx, + mram_start_addr=mram_start_addr, + hbm_addr_reg=addr_reg, + gp_regs=gp_regs, + k_block_start=k_block_start, + k_block_count=k_block_count, + ), ) - isa_code += self.load_col_sub_matrices_asm( - name=name, - col_idx=col_idx, - mram_start_addr=mram_start_addr, - hbm_addr_reg=addr_reg, - gp_regs=gp_regs, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - - self.register_allocator.free_gp(gp_regs) - self.register_allocator.free_gp(gp_for_addr) - self.register_allocator.free_addr([addr_reg]) - - return self._emit(isa_code) - def allocate_vram_matrix( self, name: str, @@ -282,6 +283,65 @@ def vram_matrix_add( isa_code = "\n".join(lines) + "\n" return self._emit(isa_code) + def _target_tile_addr(self, target_matrix: str, target_row_idx: int, target_col_idx: int) -> tuple[int, int, int]: + if target_matrix not in self: + raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") + + target_info = self[target_matrix] + target_rows, _target_cols = target_info.shape + target_base_addr = target_info.vram_addr + result_vram_addr = ( + target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen + ) + return result_vram_addr, target_base_addr, target_rows + + def _emit_vram_sub_projection_to( + self, + *, + transposed: bool, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_idx: int, + target_matrix: str, + target_row_idx: int, + target_col_idx: int, + k_block_start: int = 0, + k_block_count: int | None = None, + ) -> str: + result_vram_addr, target_base_addr, target_rows = self._target_tile_addr( + target_matrix, target_row_idx, target_col_idx + ) + gp_regs = self.register_allocator.allocate_gp(9) + + if transposed: + isa_code = f"; VRAM Sub Projection T To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_idx}][:]^T -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" + asm = self.vram_sub_projection_T_asm( + vram_mat_name=vram_mat_name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_mat_name, + mram_row_idx=mram_idx, + result_vram_addr=result_vram_addr, + gp_regs=gp_regs, + ) + else: + isa_code = f"; VRAM Sub Projection To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_idx}] -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" + asm = self.vram_sub_projection_asm( + vram_mat_name=vram_mat_name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_mat_name, + mram_col_idx=mram_idx, + result_vram_addr=result_vram_addr, + gp_regs=gp_regs, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" + isa_code += asm + + self.register_allocator.free_gp(gp_regs) + return self._emit(isa_code) + def vram_sub_projection_to( self, vram_mat_name: str, @@ -299,38 +359,19 @@ def vram_sub_projection_to( target[target_row_idx][target_col_idx] = VRAM_A[vram_row_idx][:] @ MRAM_W[:][mram_col_idx]. Target matrix must have been allocated via allocate_vram_matrix. """ - if target_matrix not in self: - raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") - - target_info = self[target_matrix] - target_rows, _target_cols = target_info.shape - target_base_addr = target_info.vram_addr - - # VRAM layout: [batch, mlen, hidden/mlen], column-block major. - # Sub-block (r, c) addr = base + c * rows * mlen + r * mlen * mlen. - result_vram_addr = ( - target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen - ) - - gp_regs = self.register_allocator.allocate_gp(9) - - isa_code = f"; VRAM Sub Projection To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}] -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" - isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" - isa_code += self.vram_sub_projection_asm( + return self._emit_vram_sub_projection_to( + transposed=False, vram_mat_name=vram_mat_name, vram_row_idx=vram_row_idx, mram_mat_name=mram_mat_name, - mram_col_idx=mram_col_idx, - result_vram_addr=result_vram_addr, - gp_regs=gp_regs, + mram_idx=mram_col_idx, + target_matrix=target_matrix, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, k_block_start=k_block_start, k_block_count=k_block_count, ) - self.register_allocator.free_gp(gp_regs) - - return self._emit(isa_code) - def vram_sub_projection_T_to( self, vram_mat_name: str, @@ -350,35 +391,16 @@ def vram_sub_projection_T_to( K[j][:]: (mlen, hidden_size) row sub-block, transposed to (hidden_size, mlen) S[i][j]: (mlen, mlen) """ - if target_matrix not in self: - raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") - - target_info = self[target_matrix] - target_rows, _target_cols = target_info.shape - target_base_addr = target_info.vram_addr - - # VRAM layout: [batch, mlen, hidden/mlen], column-block major. - # Sub-block (r, c) addr = base + c * rows * mlen + r * mlen * mlen. - result_vram_addr = ( - target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen - ) - - gp_regs = self.register_allocator.allocate_gp(9) - - isa_code = f"; VRAM Sub Projection T To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" - isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" - isa_code += self.vram_sub_projection_T_asm( + return self._emit_vram_sub_projection_to( + transposed=True, vram_mat_name=vram_mat_name, vram_row_idx=vram_row_idx, mram_mat_name=mram_mat_name, - mram_row_idx=mram_row_idx, - result_vram_addr=result_vram_addr, - gp_regs=gp_regs, + mram_idx=mram_row_idx, + target_matrix=target_matrix, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, ) - self.register_allocator.free_gp(gp_regs) - - return self._emit(isa_code) - __all__ = ["IsaMatrixMixin"] diff --git a/aten/plena/isa_tile_rows.py b/aten/plena/isa_tile_rows.py index 8d2e247..a225209 100644 --- a/aten/plena/isa_tile_rows.py +++ b/aten/plena/isa_tile_rows.py @@ -10,6 +10,36 @@ class IsaTileRowMixin: # Tile-row helpers (name-based) # ========================================================================= + def _tile_addr(self, matrix_name: str, tile_row_idx: int = 0, tile_col_idx: int = 0) -> int: + return self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) + + def _tile_row_single_matrix_op( + self, + asm_method: str, + matrix_name: str, + arg, + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return getattr(self, asm_method)(self._tile_addr(matrix_name, tile_row_idx, tile_col_idx), arg) + + def _tile_row_binary_matrix_op( + self, + asm_method: str, + dst_matrix: str, + src_matrix: str, + rows: list[int], + dst_tile_row_idx: int = 0, + dst_tile_col_idx: int = 0, + src_tile_row_idx: int = 0, + src_tile_col_idx: int = 0, + ) -> str: + return getattr(self, asm_method)( + self._tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx), + self._tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx), + rows, + ) + def tile_row_max( self, source_matrix: str, @@ -17,8 +47,7 @@ def tile_row_max( tile_row_idx: int = 0, tile_col_idx: int = 0, ) -> str: - source_addr = self.get_vram_tile_addr(source_matrix, tile_row_idx, tile_col_idx) - return self.tile_row_max_asm(source_addr, row_map) + return self._tile_row_single_matrix_op("tile_row_max_asm", source_matrix, row_map, tile_row_idx, tile_col_idx) def tile_row_sum( self, @@ -27,8 +56,7 @@ def tile_row_sum( tile_row_idx: int = 0, tile_col_idx: int = 0, ) -> str: - source_addr = self.get_vram_tile_addr(source_matrix, tile_row_idx, tile_col_idx) - return self.tile_row_sum_asm(source_addr, row_map) + return self._tile_row_single_matrix_op("tile_row_sum_asm", source_matrix, row_map, tile_row_idx, tile_col_idx) def tile_row_exp( self, @@ -37,8 +65,7 @@ def tile_row_exp( tile_row_idx: int = 0, tile_col_idx: int = 0, ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_exp_asm(matrix_addr, rows) + return self._tile_row_single_matrix_op("tile_row_exp_asm", matrix_name, rows, tile_row_idx, tile_col_idx) def tile_row_reci( self, @@ -47,8 +74,7 @@ def tile_row_reci( tile_row_idx: int = 0, tile_col_idx: int = 0, ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_reci_asm(matrix_addr, rows) + return self._tile_row_single_matrix_op("tile_row_reci_asm", matrix_name, rows, tile_row_idx, tile_col_idx) def tile_row_sub_fp( self, @@ -57,8 +83,7 @@ def tile_row_sub_fp( tile_row_idx: int = 0, tile_col_idx: int = 0, ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_sub_fp_asm(matrix_addr, row_map) + return self._tile_row_single_matrix_op("tile_row_sub_fp_asm", matrix_name, row_map, tile_row_idx, tile_col_idx) def tile_row_mul_fp( self, @@ -67,8 +92,7 @@ def tile_row_mul_fp( tile_row_idx: int = 0, tile_col_idx: int = 0, ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_mul_fp_asm(matrix_addr, row_map) + return self._tile_row_single_matrix_op("tile_row_mul_fp_asm", matrix_name, row_map, tile_row_idx, tile_col_idx) def tile_row_add_fp( self, @@ -77,8 +101,7 @@ def tile_row_add_fp( tile_row_idx: int = 0, tile_col_idx: int = 0, ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_add_fp_asm(matrix_addr, row_map) + return self._tile_row_single_matrix_op("tile_row_add_fp_asm", matrix_name, row_map, tile_row_idx, tile_col_idx) def tile_row_add( self, @@ -90,9 +113,16 @@ def tile_row_add( src_tile_row_idx: int = 0, src_tile_col_idx: int = 0, ) -> str: - dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) - src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) - return self.tile_row_add_asm(dst_addr, src_addr, rows) + return self._tile_row_binary_matrix_op( + "tile_row_add_asm", + dst_matrix, + src_matrix, + rows, + dst_tile_row_idx, + dst_tile_col_idx, + src_tile_row_idx, + src_tile_col_idx, + ) def tile_row_sub( self, @@ -104,9 +134,16 @@ def tile_row_sub( src_tile_row_idx: int = 0, src_tile_col_idx: int = 0, ) -> str: - dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) - src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) - return self.tile_row_sub_asm(dst_addr, src_addr, rows) + return self._tile_row_binary_matrix_op( + "tile_row_sub_asm", + dst_matrix, + src_matrix, + rows, + dst_tile_row_idx, + dst_tile_col_idx, + src_tile_row_idx, + src_tile_col_idx, + ) def tile_row_mul( self, @@ -118,9 +155,16 @@ def tile_row_mul( src_tile_row_idx: int = 0, src_tile_col_idx: int = 0, ) -> str: - dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) - src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) - return self.tile_row_mul_asm(dst_addr, src_addr, rows) + return self._tile_row_binary_matrix_op( + "tile_row_mul_asm", + dst_matrix, + src_matrix, + rows, + dst_tile_row_idx, + dst_tile_col_idx, + src_tile_row_idx, + src_tile_col_idx, + ) def tile_row_mul_fp_broadcast( self, @@ -130,8 +174,11 @@ def tile_row_mul_fp_broadcast( tile_row_idx: int = 0, tile_col_idx: int = 0, ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_mul_fp_broadcast_asm(matrix_addr, fpram_scalar_addr, rows) + return self.tile_row_mul_fp_broadcast_asm( + self._tile_addr(matrix_name, tile_row_idx, tile_col_idx), + fpram_scalar_addr, + rows, + ) def vram_fill_zero( self, @@ -140,8 +187,7 @@ def vram_fill_zero( tile_row_idx: int = 0, tile_col_idx: int = 0, ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.vram_fill_zero_asm(matrix_addr, rows) + return self._tile_row_single_matrix_op("vram_fill_zero_asm", matrix_name, rows, tile_row_idx, tile_col_idx) # ========================================================================= # Tile-row ISA helpers (address-based) diff --git a/aten/plena/memory.py b/aten/plena/memory.py index 5d51187..e09bf54 100644 --- a/aten/plena/memory.py +++ b/aten/plena/memory.py @@ -15,44 +15,15 @@ @dataclass class MemoryBlock: - """Memory Block Information""" - name: str # Allocation name (e.g., "W[0][1]" or "activation_A") addr: int # Starting address size: int # Block size (number of elements) - def __repr__(self) -> str: - return f"MemBlock({self.name}, addr={self.addr}, size={self.size})" - class VirtualMemoryManager: - """ - Virtual Memory Manager - - Core Design: - - used_stack: Allocated and in-use memory blocks - - free_stack: Freed and reusable memory blocks - - Workflow: - 1. allocate(name, size): Allocate memory - - Prioritize best-fit search for reusable blocks in free_stack - - If not found, use bump allocation from the end - 2. free(name): Free memory - - Move block from used_stack to free_stack - - Address can be reused by subsequent allocate calls - - Throws exception if not found when strict=True, returns None when strict=False - - VRAM/MRAM storage format: (batch_size, mlen, hidden_size/mlen), column-block major. - mlen=64, blen=4. Alignment depends on storage hierarchy. - """ + """Best-fit reuse plus bump allocation for PLENA virtual memories.""" def __init__(self, total_size: int, alignment: int = MLEN, mem_name: str = "Memory"): - """ - Args: - total_size: Total memory size (number of elements) - alignment: alignment granularity (VRAM uses MLEN=64, MRAM uses MLEN*MLEN=4096) - mem_name: Memory name, for debugging information (e.g., "VRAM" or "MRAM") - """ self.total_size = total_size self.alignment = alignment self.mem_name = mem_name @@ -67,33 +38,16 @@ def _align(self, value: int) -> int: return ((value + self.alignment - 1) // self.alignment) * self.alignment def allocate(self, name: str, size: int) -> int: - """ - Allocate memory. - - Strategy: - 1. Best-fit from free_stack (reusable block with least waste). - 2. If no suitable block, bump allocation. - - Returns: - Allocated starting address. - - Raises: - MemoryError: Insufficient memory. - """ + """Allocate by best-fit reuse first, then bump allocation.""" aligned_size = self._align(size) - # Strategy 1: best-fit from free_stack - best_idx = None - best_waste = float("inf") - - for i, block in enumerate(self.free_stack): - if block.size >= aligned_size: - waste = block.size - aligned_size - if waste < best_waste: - best_waste = waste - best_idx = i + best = min( + ((block.size - aligned_size, i) for i, block in enumerate(self.free_stack) if block.size >= aligned_size), + default=None, + ) - if best_idx is not None: + if best is not None: + _, best_idx = best reused_block = self.free_stack.pop(best_idx) # If block is larger than needed, split remaining part and return to free_stack @@ -107,7 +61,6 @@ def allocate(self, name: str, size: int) -> int: self.used_stack.append(new_block) return new_block.addr - # Strategy 2: Bump allocation aligned_addr = self._align(self.next_bump) if self.total_size > 0 and aligned_addr + aligned_size > self.total_size: @@ -124,16 +77,7 @@ def allocate(self, name: str, size: int) -> int: return aligned_addr def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - """ - Free memory: move block from used_stack to free_stack. - - Args: - name: Name of allocation to free - strict: Throws KeyError if not found when strict=True, returns None when strict=False - - Returns: - Freed memory block, returns None if strict=False and not found - """ + """Move an allocation from used_stack to reusable free_stack.""" for i, block in enumerate(self.used_stack): if block.name == name: freed = self.used_stack.pop(i) @@ -149,18 +93,7 @@ def free(self, name: str, strict: bool = True) -> MemoryBlock | None: return None def mark_used(self, addr: int, size: int, name: str) -> None: - """ - Register a pre-known address range as occupied without bump allocation. - - Used for prestaged VRAM tensors that are already in VRAM before program - execution (e.g. Q pre-loaded by the test harness). Advances next_bump - past the end of this region so subsequent bump allocations do not collide. - - Args: - addr: Start address of the pre-occupied region. - size: Number of elements in the region. - name: Name for tracking/free. - """ + """Register a pre-known occupied range and advance bump past it.""" aligned_size = self._align(size) block = MemoryBlock(name=name, addr=addr, size=aligned_size) self.used_stack.append(block) @@ -170,9 +103,7 @@ def mark_used(self, addr: int, size: int, name: str) -> None: self.next_bump = end def _coalesce_free_stack(self): - """ - Merge adjacent free blocks by address to reduce long-term fragmentation. - """ + """Merge adjacent free blocks by address.""" if len(self.free_stack) <= 1: return @@ -191,60 +122,12 @@ def _coalesce_free_stack(self): self.free_stack = merged - def is_allocated(self, name: str) -> bool: - """Check if a name is in used_stack""" - return any(b.name == name for b in self.used_stack) - - def get_block(self, name: str) -> MemoryBlock | None: - """Get memory block with specified name from used_stack""" - for block in self.used_stack: - if block.name == name: - return block - return None - - def get_used_size(self) -> int: - """Get total used size""" - return sum(b.size for b in self.used_stack) - - def get_free_size(self) -> int: - """Get total reusable size""" - return sum(b.size for b in self.free_stack) - def reset(self): """Reset manager""" self.next_bump = 0 self.used_stack.clear() self.free_stack.clear() - def print_status(self): - """Print memory status""" - print(f"=== {self.mem_name} Virtual Memory Status ===") - print(f"Total size: {self.total_size}") - print(f"Bump pointer: {self.next_bump}") - print(f"Used blocks ({len(self.used_stack)}):") - for b in self.used_stack: - print(f" {b}") - print(f"Free blocks ({len(self.free_stack)}):") - for b in self.free_stack: - print(f" {b}") - total_used = self.get_used_size() - total_free = self.get_free_size() - if self.total_size > 0: - available = self.total_size - self.next_bump + total_free - print( - f"Summary: used={total_used}, free={total_free}, " - f"bump={self.next_bump}, available={available}/{self.total_size}" - ) - else: - print(f"Summary: used={total_used}, free={total_free}, bump={self.next_bump} (unlimited mode)") - - def __repr__(self) -> str: - return ( - f"VirtualMemoryManager({self.mem_name}, " - f"used={len(self.used_stack)}, free={len(self.free_stack)}, " - f"bump={self.next_bump}/{self.total_size})" - ) - # ============================================================================== # Sub-matrix Information @@ -264,13 +147,6 @@ class SubMatrixInfo: hbm_offset: int = 0 # Offset in HBM (in elements) mram_addr: int | None = None # Address in MRAM (if loaded) - def __repr__(self) -> str: - mram_str = f"{self.mram_addr}" if self.mram_addr is not None else "None" - return ( - f"SubMatrix({self.parent_name}[{self.row_idx}][{self.col_idx}], " - f"shape={self.shape}, hbm_off={self.hbm_offset}, mram={mram_str})" - ) - @dataclass class MatrixBlockLayout: @@ -349,12 +225,6 @@ class VRAMSubMatrixInfo: # Pre-calculated VRAM address vram_addr: int = 0 - def __repr__(self) -> str: - return ( - f"VRAMSubMatrix({self.parent_name}[{self.row_idx}][{self.col_idx}], " - f"shape={self.shape}, vram={self.vram_addr})" - ) - @dataclass class VRAMMatrixBlockLayout: @@ -418,10 +288,6 @@ def get_col_blocks(self, col_idx: int) -> list[VRAMSubMatrixInfo]: """Get all sub-blocks in a column""" return [self.sub_blocks[(r, col_idx)] for r in range(self.num_row_blocks)] - def get_row_vram_addrs(self, row_idx: int) -> list[int]: - """Get list of VRAM addresses for all sub-blocks in a row""" - return [block.vram_addr for block in self.get_row_blocks(row_idx)] - # ============================================================================== # Unified Memory Object Metadata @@ -456,118 +322,65 @@ class FPRAMObjectLayout: # ============================================================================== -# MRAM Allocator +# Allocators # ============================================================================== -class MRAMAllocator: - """ - Matrix RAM address allocator (based on VirtualMemoryManager). +class MemoryAllocatorBase: + """Shared wrapper over VirtualMemoryManager for compiler address spaces.""" - Each sub-block is mlen x mlen = 4096 elements; aligned to mlen*mlen. - Default total_size=16384 holds 4 sub-blocks (MAX_K_TILES=4). - Supports virtual free/reuse: freed blocks move to free_stack and are - preferred by subsequent allocations. - """ - - def __init__(self, total_size: int = MLEN * MLEN * 4): - """ - Args: - total_size: Total MRAM size (default 16384, can hold 4 64x64 matrix blocks) - """ + def __init__(self, total_size: int, alignment: int, mem_name: str): self.total_size = total_size - self._vmm = VirtualMemoryManager( - total_size=total_size, - alignment=MLEN * MLEN, # aligned to one sub-block size - mem_name="MRAM", - ) + self.alignment = alignment + self._vmm = VirtualMemoryManager(total_size=total_size, alignment=alignment, mem_name=mem_name) @property def next_free(self) -> int: return self._vmm.next_bump - @property - def used_stack(self) -> list[MemoryBlock]: - return self._vmm.used_stack - - @property - def free_stack(self) -> list[MemoryBlock]: - return self._vmm.free_stack + @next_free.setter + def next_free(self, value: int): + self._validate_next_free(value) + self._vmm.next_bump = value - def allocate(self, name: str, size: int) -> int: - """Allocate MRAM space (prioritize reusing freed blocks).""" - return self._vmm.allocate(name, size) + def _validate_next_free(self, value: int) -> None: + del value def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - """Free specified allocation: move from used_stack to free_stack.""" return self._vmm.free(name, strict=strict) - def is_allocated(self, name: str) -> bool: - """Check if a name is allocated""" - return self._vmm.is_allocated(name) - def reset(self): - """Reset allocator""" self._vmm.reset() - def print_status(self): - """Print memory status""" - self._vmm.print_status() - -class VRAMAllocator: +class MRAMAllocator(MemoryAllocatorBase): """ - VRAM address allocator (based on VirtualMemoryManager). + Matrix RAM address allocator. - VRAM supports best-fit reuse + bump allocation, same as MRAM/FPRAM allocators. - Alignment defaults to MLEN to match VRAM storage format requirements. + Each sub-block is mlen x mlen = 4096 elements; aligned to mlen*mlen. + Default total_size=16384 holds 4 sub-blocks (MAX_K_TILES=4). """ - def __init__(self, alignment: int = MLEN, total_size: int = 0): - self.alignment = alignment - self._vmm = VirtualMemoryManager(total_size=total_size, alignment=alignment, mem_name="VRAM") + def __init__(self, total_size: int = MLEN * MLEN * 4): + super().__init__(total_size=total_size, alignment=MLEN * MLEN, mem_name="MRAM") - @property - def next_free(self) -> int: - return self._vmm.next_bump + def allocate(self, name: str, size: int) -> int: + return self._vmm.allocate(name, size) - @next_free.setter - def next_free(self, value: int): - self._vmm.next_bump = value - @property - def used_stack(self) -> list[MemoryBlock]: - return self._vmm.used_stack +class VRAMAllocator(MemoryAllocatorBase): + """VRAM address allocator with MLEN-aligned best-fit reuse + bump allocation.""" - @property - def free_stack(self) -> list[MemoryBlock]: - return self._vmm.free_stack + def __init__(self, alignment: int = MLEN, total_size: int = 0): + super().__init__(total_size=total_size, alignment=alignment, mem_name="VRAM") def allocate(self, size: int, name: str = "") -> int: if not name: raise ValueError("VRAMAllocator.allocate() requires name for subsequent free.") return self._vmm.allocate(name, size) - def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - return self._vmm.free(name, strict=strict) - - def is_allocated(self, name: str) -> bool: - return self._vmm.is_allocated(name) - def reset(self): - self._vmm.reset() - - def print_status(self): - self._vmm.print_status() - - def __repr__(self) -> str: - return ( - f"VRAMAllocator(next_free={self.next_free}, alignment={self.alignment}, " - f"used={len(self.used_stack)}, free={len(self.free_stack)})" - ) - - -class FPRAMAllocator: +class FPRAMAllocator(MemoryAllocatorBase): """ Floating Point RAM Allocator (based on VirtualMemoryManager). @@ -593,34 +406,12 @@ def __init__(self, total_size: int = 1024): Args: total_size: Total FP RAM size (default 1024, matching hardware fpsram) """ - self.total_size = total_size - self._vmm = VirtualMemoryManager( - total_size=total_size, - alignment=1, - mem_name="FPRAM", - ) + super().__init__(total_size=total_size, alignment=1, mem_name="FPRAM") self.allocations: dict[str, tuple[int, int]] = {} - self._snapshots: dict[int, tuple[int, list[MemoryBlock], list[MemoryBlock], dict[str, tuple[int, int]]]] = {} - self._next_snapshot_id = 1 - - @property - def next_free(self) -> int: - """Compatibility alias: next bump pointer.""" - return self._vmm.next_bump - @next_free.setter - def next_free(self, value: int): + def _validate_next_free(self, value: int) -> None: if value < 0 or value > self.total_size: raise ValueError(f"next_free out of range: {value}, expected [0, {self.total_size}]") - self._vmm.next_bump = value - - @property - def used_stack(self) -> list[MemoryBlock]: - return self._vmm.used_stack - - @property - def free_stack(self) -> list[MemoryBlock]: - return self._vmm.free_stack def allocate(self, name: str, size: int) -> int: """Allocate FP RAM space (best-fit + bump).""" @@ -640,35 +431,7 @@ def free(self, name: str, strict: bool = True) -> MemoryBlock | None: self.allocations.pop(name, None) return freed - def save_state(self) -> int: - """ - Save current allocator state and return a snapshot token. - """ - sid = self._next_snapshot_id - self._next_snapshot_id += 1 - self._snapshots[sid] = ( - self._vmm.next_bump, - [MemoryBlock(b.name, b.addr, b.size) for b in self._vmm.used_stack], - [MemoryBlock(b.name, b.addr, b.size) for b in self._vmm.free_stack], - dict(self.allocations), - ) - return sid - - def restore_state(self, snapshot: int): - """Restore allocator state from snapshot token.""" - if snapshot not in self._snapshots: - raise KeyError(f"Unknown FPRAM snapshot id: {snapshot}") - next_bump, used_stack, free_stack, allocations = self._snapshots[snapshot] - self._vmm.next_bump = next_bump - self._vmm.used_stack = [MemoryBlock(b.name, b.addr, b.size) for b in used_stack] - self._vmm.free_stack = [MemoryBlock(b.name, b.addr, b.size) for b in free_stack] - self.allocations = dict(allocations) - def reset(self): """Reset allocator""" self._vmm.reset() self.allocations.clear() - self._snapshots.clear() - self._next_snapshot_id = 1 - - diff --git a/aten/plena/dsl_attention.py b/aten/plena/program_attention.py similarity index 95% rename from aten/plena/dsl_attention.py rename to aten/plena/program_attention.py index b1ef4d7..77ce389 100644 --- a/aten/plena/dsl_attention.py +++ b/aten/plena/program_attention.py @@ -1,11 +1,11 @@ -"""Flash-attention operations for the PLENA DSL.""" +"""Flash-attention operations for the PLENA program builder.""" from __future__ import annotations from compiler.aten.plena.vars import InputVar, VRAMMatrixVar -class DslAttentionMixin: +class ProgramAttentionMixin: # ======================================================================== # Flash Attention Operations # ======================================================================== @@ -79,4 +79,4 @@ def final_scale_o(self, q_idx: int, o_matrix: VRAMMatrixVar): ) -__all__ = ["DslAttentionMixin"] +__all__ = ["ProgramAttentionMixin"] diff --git a/aten/plena/program_fp_tile_ops.py b/aten/plena/program_fp_tile_ops.py new file mode 100644 index 0000000..86a8d16 --- /dev/null +++ b/aten/plena/program_fp_tile_ops.py @@ -0,0 +1,347 @@ +"""FPRAM, FPVar, and tile-row operations for the PLENA program builder.""" + +from __future__ import annotations + +from collections.abc import Iterable + +from compiler.aten.plena.vars import FPVar, VRAMMatrixVar + + +class ProgramFPTileOpsMixin: + # ======================================================================== + # FP Variable (FPRAM) + # ======================================================================== + + def allocate_fpram( + self, + internal_name: str, + size: int = 1, + display_name: str | None = None, + ) -> FPVar: + """Allocate FPRAM with an explicit internal name and return an FPVar.""" + if size <= 0: + raise ValueError(f"FPRAM allocation size must be positive, got {size}") + + address = super().allocate_fpram(internal_name, size) + var = FPVar( + self, + internal_name, + address, + size, + display_name=display_name if display_name is not None else internal_name, + ) + self._fp_vars[internal_name] = var + return var + + def free_fpram(self, internal_name: str, strict: bool = True): + super().free_fpram(internal_name, strict=strict) + self._fp_vars.pop(internal_name, None) + + def fp_var(self, name: str, size: int = 1) -> FPVar: + return self.allocate_fpram( + internal_name=self._scoped_name(name), + size=size, + display_name=name, + ) + + # ======================================================================== + # Shared argument normalization + # ======================================================================== + + def _resolve_fpram_addr(self, addr_or_var: int | FPVar, offset: int = 0) -> int: + if isinstance(addr_or_var, FPVar): + if offset < 0 or offset >= addr_or_var.size: + raise ValueError( + f"FPVar offset out of range: offset={offset}, size={addr_or_var.size}, var={addr_or_var.name}" + ) + return addr_or_var.address + offset + if not isinstance(addr_or_var, int): + raise TypeError(f"Expected int or FPVar, got {type(addr_or_var)}") + return addr_or_var + offset + + def _resolve_rows( + self, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + ) -> list[int]: + if row_idx is not None and rows is not None: + raise ValueError("Provide either row_idx or rows, not both") + if rows is not None: + return list(rows) + if row_idx is not None: + return [row_idx] + return list(range(self.mlen)) + + def _default_rows(self, rows: Iterable[int] | None, *, total_rows: int | None = None) -> list[int]: + return list(range(self.mlen if total_rows is None else total_rows)) if rows is None else list(rows) + + def _fpram_row_map( + self, + fpram_addr: int | FPVar, + *, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + single_offset: int = 0, + base_offset: int = 0, + ) -> list[tuple[int, int]]: + resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) + offsets = ( + [single_offset] + if len(resolved_rows) == 1 + else [base_offset + i for i in range(len(resolved_rows))] + ) + return [(row, self._resolve_fpram_addr(fpram_addr, offset)) for row, offset in zip(resolved_rows, offsets)] + + def _fp_count(self, vars_: Iterable[FPVar], count: int | None, *, default: int | None = None) -> int: + fp_vars = list(vars_) + resolved_count = default if count is None and default is not None else count + if resolved_count is None: + resolved_count = min(var.size for var in fp_vars) + if any(resolved_count > var.size for var in fp_vars): + sizes = ", ".join(f"{var.name}.size={var.size}" for var in fp_vars) + raise ValueError(f"count={resolved_count} exceeds FPVar size: {sizes}") + return resolved_count + + def _fpvar_unary(self, isa_method: str, src: FPVar, dst: FPVar, count: int | None = None): + count = self._fp_count([src, dst], count) + return getattr(super(), isa_method)(src.name, dst.name, count) + + def _fpvar_binary(self, isa_method: str, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + count = self._fp_count([src1, src2, dst], count) + return getattr(super(), isa_method)(src1.name, src2.name, dst.name, count) + + # ======================================================================== + # FPRAM tile-row operations + # ======================================================================== + + def _tile_row_reduce_to_fpram( + self, + isa_method: str, + target_fpram_addr: int | FPVar, + source: VRAMMatrixVar, + row_idx: int | None, + rows: Iterable[int] | None, + target_offset: int, + target_base_offset: int, + ): + return getattr(super(), isa_method)( + source.name, + self._fpram_row_map( + target_fpram_addr, + row_idx=row_idx, + rows=rows, + single_offset=target_offset, + base_offset=target_base_offset, + ), + ) + + def tile_row_max( + self, + target_fpram_addr: int | FPVar, + source: VRAMMatrixVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + target_offset: int = 0, + target_base_offset: int = 0, + ): + return self._tile_row_reduce_to_fpram( + "tile_row_max", target_fpram_addr, source, row_idx, rows, target_offset, target_base_offset + ) + + def tile_row_sum( + self, + target_fpram_addr: int | FPVar, + source: VRAMMatrixVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + target_offset: int = 0, + target_base_offset: int = 0, + ): + return self._tile_row_reduce_to_fpram( + "tile_row_sum", target_fpram_addr, source, row_idx, rows, target_offset, target_base_offset + ) + + def tile_row_exp( + self, + source: VRAMMatrixVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + ): + super().tile_row_exp(source.name, self._resolve_rows(row_idx=row_idx, rows=rows)) + + def tile_row_reci( + self, + source: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + super().tile_row_reci(source.name, self._default_rows(rows)) + + def tile_row_sub_fp( + self, + source: VRAMMatrixVar, + fpram_addr: int | FPVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + fpram_offset: int = 0, + fpram_base_offset: int = 0, + ): + return self._tile_row_fp_scalar( + "tile_row_sub_fp", source, fpram_addr, row_idx, rows, fpram_offset, fpram_base_offset + ) + + def tile_row_mul_fp( + self, + source: VRAMMatrixVar, + fpram_addr: int | FPVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + fpram_offset: int = 0, + fpram_base_offset: int = 0, + ): + return self._tile_row_fp_scalar( + "tile_row_mul_fp", source, fpram_addr, row_idx, rows, fpram_offset, fpram_base_offset + ) + + def tile_row_add_fp( + self, + source: VRAMMatrixVar, + fp_var: FPVar, + rows: Iterable[int] | None = None, + ): + resolved_rows = self._default_rows(rows) + super().tile_row_add_fp(source.name, [(row, fp_var[row]) for row in resolved_rows]) + + def _tile_row_binary(self, isa_method: str, dst: VRAMMatrixVar, src: VRAMMatrixVar, rows: Iterable[int] | None): + return getattr(super(), isa_method)(dst.name, src.name, self._default_rows(rows)) + + def _tile_row_fp_scalar( + self, + isa_method: str, + source: VRAMMatrixVar, + fpram_addr: int | FPVar, + row_idx: int | None, + rows: Iterable[int] | None, + fpram_offset: int, + fpram_base_offset: int, + ): + return getattr(super(), isa_method)( + source.name, + self._fpram_row_map( + fpram_addr, + row_idx=row_idx, + rows=rows, + single_offset=fpram_offset, + base_offset=fpram_base_offset, + ), + ) + + def tile_row_add( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + return self._tile_row_binary("tile_row_add", dst, src, rows) + + def tile_row_sub( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + return self._tile_row_binary("tile_row_sub", dst, src, rows) + + def tile_row_mul( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + return self._tile_row_binary("tile_row_mul", dst, src, rows) + + def tile_row_mul_fp_broadcast( + self, + source: VRAMMatrixVar, + fpram_scalar_addr: int | FPVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + fpram_offset: int = 0, + ): + scalar_addr = self._resolve_fpram_addr(fpram_scalar_addr, fpram_offset) + super().tile_row_mul_fp_broadcast(source.name, scalar_addr, self._resolve_rows(row_idx=row_idx, rows=rows)) + + # ======================================================================== + # FPVar operations + # ======================================================================== + + def fpvar_reci(self, src: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_unary("fpram_reci", src, dst, count) + + def fpvar_exp(self, src: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_unary("fpram_exp", src, dst, count) + + def fpvar_copy(self, src: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_unary("fpram_copy", src, dst, count) + + def fpvar_max(self, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_binary("fpram_max", src1, src2, dst, count) + + def fpvar_sub(self, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_binary("fpram_sub", src1, src2, dst, count) + + def fpvar_mul(self, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_binary("fpram_mul", src1, src2, dst, count) + + def fpvar_add(self, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_binary("fpram_add", src1, src2, dst, count) + + def fpvar_sum(self, src: FPVar, dst: FPVar, count: int | None = None): + count = self._fp_count([src], count, default=src.size) + return super().fpram_sum(src.name, dst.name, count) + + def fpvar_shift( + self, + src: FPVar, + dst: FPVar, + shift: int, + count: int | None = None, + fill: FPVar | None = None, + ): + count = self._fp_count([src, dst], count) + return super().fpram_shift( + src_name=src.name, + dst_name=dst.name, + shift=shift, + count=count, + fill_fpram_name=None if fill is None else fill.name, + ) + + def fpvar_fill_from_fpram( + self, + dst: FPVar, + src_fpram_addr: int, + count: int | None = None, + ): + count = self._fp_count([dst], count, default=dst.size) + return super().fpram_fill_from_fpram(dst.name, src_fpram_addr, count) + + def vram_fill_zero( + self, + matrix: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + resolved_rows = self._default_rows(rows, total_rows=matrix.shape[0]) + total_rows, cols = matrix.shape + if any(row < 0 or row >= total_rows for row in resolved_rows): + raise ValueError( + f"vram_fill_zero rows out of bounds for {matrix.name}: shape={matrix.shape}, rows={resolved_rows}" + ) + + # VRAM matrices are column-block-major. The low-level helper zeros one + # tile column, so walk every column block for wide matrices. + num_col_blocks = (cols + self.mlen - 1) // self.mlen + for col_block in range(num_col_blocks): + super().vram_fill_zero(matrix.name, resolved_rows, tile_col_idx=col_block) + + +__all__ = ["ProgramFPTileOpsMixin"] diff --git a/aten/plena/dsl_matrix_ops.py b/aten/plena/program_matrix_ops.py similarity index 75% rename from aten/plena/dsl_matrix_ops.py rename to aten/plena/program_matrix_ops.py index 7cf9ae6..840da57 100644 --- a/aten/plena/dsl_matrix_ops.py +++ b/aten/plena/program_matrix_ops.py @@ -1,21 +1,23 @@ -"""Matrix projection, RoPE, and VRAM operations for the PLENA DSL.""" +"""Matrix projection, RoPE, and VRAM operations for the PLENA program builder.""" from __future__ import annotations from compiler.aten.plena.vars import InputVar, TensorVar, VRAMMatrixVar -class DslMatrixOpsMixin: +class ProgramMatrixOpsMixin: # ======================================================================== # Matrix Projection and VRAM Operations # ======================================================================== + def _require_var(self, value, expected_type, label: str): + if not isinstance(value, expected_type): + raise TypeError(f"{label} must be {expected_type.__name__}, got {type(value)}") + return value + def _ensure_hbm_sub_matrix_registered(self, input_var: InputVar): """Ensure an HBM input is registered in compiler sub-matrix manager.""" - if ( - input_var.name in self._registered_hbm_sub_matrices - and self._registered_hbm_sub_matrices[input_var.name] is True - ): + if self._registered_hbm_sub_matrices.get(input_var.name): return h, w = input_var.shape super().ensure_hbm_sub_matrix( @@ -28,10 +30,7 @@ def _ensure_hbm_sub_matrix_registered(self, input_var: InputVar): def _ensure_vram_sub_matrix_registered(self, matrix_var: VRAMMatrixVar): """Ensure a VRAM matrix is registered in compiler sub-matrix manager.""" - if ( - matrix_var.name in self._registered_vram_sub_matrices - and self._registered_vram_sub_matrices[matrix_var.name] is True - ): + if self._registered_vram_sub_matrices.get(matrix_var.name): return super().ensure_vram_matrix_layout( name=matrix_var.name, @@ -39,6 +38,16 @@ def _ensure_vram_sub_matrix_registered(self, matrix_var: VRAMMatrixVar): ) self._registered_vram_sub_matrices[matrix_var.name] = True + def _prepare_projection(self, vram_matrix, mram_input, target, auto_reset_mram: bool): + vram_matrix = self._require_var(vram_matrix, VRAMMatrixVar, "vram_matrix") + mram_input = self._require_var(mram_input, InputVar, "mram_input") + target = self._require_var(target, VRAMMatrixVar, "target") + self._ensure_vram_sub_matrix_registered(vram_matrix) + self._ensure_hbm_sub_matrix_registered(mram_input) + if auto_reset_mram: + super().reset_mram() + return vram_matrix, mram_input, target + def vram_sub_projection_to( self, vram_matrix: VRAMMatrixVar, @@ -56,17 +65,9 @@ def vram_sub_projection_to( target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[:][mram_col_idx] Supports K-split: k_block_start/k_block_count select a subset of K tiles. """ - if not isinstance(vram_matrix, VRAMMatrixVar): - raise TypeError(f"vram_matrix must be VRAMMatrixVar, got {type(vram_matrix)}") - if not isinstance(mram_input, InputVar): - raise TypeError(f"mram_input must be InputVar, got {type(mram_input)}") - if not isinstance(target, VRAMMatrixVar): - raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") - - self._ensure_vram_sub_matrix_registered(vram_matrix) - self._ensure_hbm_sub_matrix_registered(mram_input) - if auto_reset_mram: - super().reset_mram() + vram_matrix, mram_input, target = self._prepare_projection( + vram_matrix, mram_input, target, auto_reset_mram + ) super().load_sub_matrix_col( name=mram_input.name, col_idx=mram_col_idx, @@ -99,17 +100,9 @@ def vram_sub_projection_T_to( """ target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[mram_row_idx][:]^T """ - if not isinstance(vram_matrix, VRAMMatrixVar): - raise TypeError(f"vram_matrix must be VRAMMatrixVar, got {type(vram_matrix)}") - if not isinstance(mram_input, InputVar): - raise TypeError(f"mram_input must be InputVar, got {type(mram_input)}") - if not isinstance(target, VRAMMatrixVar): - raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") - - self._ensure_vram_sub_matrix_registered(vram_matrix) - self._ensure_hbm_sub_matrix_registered(mram_input) - if auto_reset_mram: - super().reset_mram() + vram_matrix, mram_input, target = self._prepare_projection( + vram_matrix, mram_input, target, auto_reset_mram + ) super().load_sub_matrix_row(name=mram_input.name, row_idx=mram_row_idx) super().vram_sub_projection_T_to( vram_mat_name=vram_matrix.name, @@ -185,13 +178,9 @@ def vram_block_add_to( Supports writing back to the same matrix/block (in-place overwrite). """ - allowed = (VRAMMatrixVar,) - if not isinstance(src1, allowed): - raise TypeError(f"src1 must be VRAMMatrixVar, got {type(src1)}") - if not isinstance(src2, allowed): - raise TypeError(f"src2 must be VRAMMatrixVar, got {type(src2)}") - if not isinstance(target, allowed): - raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") + src1 = self._require_var(src1, VRAMMatrixVar, "src1") + src2 = self._require_var(src2, VRAMMatrixVar, "src2") + target = self._require_var(target, VRAMMatrixVar, "target") super().vram_block_add_to( src1_matrix=src1.name, @@ -206,4 +195,4 @@ def vram_block_add_to( ) -__all__ = ["DslMatrixOpsMixin"] +__all__ = ["ProgramMatrixOpsMixin"] diff --git a/aten/plena/dsl_tensors.py b/aten/plena/program_tensors.py similarity index 98% rename from aten/plena/dsl_tensors.py rename to aten/plena/program_tensors.py index 9f480dc..9184c66 100644 --- a/aten/plena/dsl_tensors.py +++ b/aten/plena/program_tensors.py @@ -1,11 +1,11 @@ -"""Tensor, memory, and normalization operations for the PLENA DSL.""" +"""Tensor, memory, and normalization operations for the PLENA program builder.""" from __future__ import annotations from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar -class DslTensorMixin: +class ProgramTensorMixin: # ======================================================================== # Input Declaration # ======================================================================== @@ -190,14 +190,14 @@ def alloc_at(self, name: str, rows: int, cols: int, vram_addr: int) -> VRAMMatri """ display_name = name internal_name = self._scoped_name(name) - self._compiler.add_vram_object( + self.add_vram_object( name=internal_name, shape=(rows, cols), vram_addr=vram_addr, allocate_if_none=False, ) isa_code = f"; VRAM View {name}: ({rows}, {cols}) at VRAM[{vram_addr}]\n" - self._compiler.emit(isa_code) + self.emit(isa_code) var = VRAMMatrixVar(self, internal_name, (rows, cols), display_name=display_name) self._tensors[internal_name] = var return var @@ -316,4 +316,4 @@ def layer_norm( ) -__all__ = ["DslTensorMixin"] +__all__ = ["ProgramTensorMixin"] diff --git a/aten/plena/tile_compiler.py b/aten/plena/tile_compiler.py index ad6b0ed..8f4ef3a 100644 --- a/aten/plena/tile_compiler.py +++ b/aten/plena/tile_compiler.py @@ -2,8 +2,6 @@ from __future__ import annotations -import math - from compiler.asm_templates.vram_sub_projection_asm import vram_sub_projection_asm_impl from compiler.aten.isa_builder import IsaBuilder, addr as areg, gp from compiler.aten.plena.constants import BLEN, MLEN @@ -21,20 +19,7 @@ class TileCompiler: - """ - Tile compiler / sub-matrix manager. - - Core Functions: - 1. Register large matrices as blocked matrices. - 2. Support sub-block indexing: matrix[row_idx][col_idx] or matrix[row_idx][:]. - 3. Pre-calculate all addresses at compiler phase. - 4. Generate ISA code for load_sub_matrix and sub_projection. - - Key Constraints: - - Minimum block size is 64x64 (MLEN). - - Matrix must be loaded into MRAM before participating in computation. - - HBM and VRAM/MRAM use different storage formats (row-major vs column-block major). - """ + """Sub-matrix layout manager and ISA emitter for tiled PLENA memory ops.""" def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = False): self.mlen = mlen @@ -53,9 +38,6 @@ def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = Fals # Currently loaded sub-blocks in MRAM self.loaded_sub_blocks: dict[str, SubMatrixInfo] = {} - # Pre-calculated address cache - self._address_cache: dict[str, int] = {} - def __contains__(self, name: str) -> bool: return name in self.hbm_matrices or name in self.vram_matrices or name in self.fpram_matrices @@ -91,12 +73,6 @@ def __getitem__(self, name: str) -> MemoryObjectInfo: return info - def get(self, name: str, default: MemoryObjectInfo | None = None) -> MemoryObjectInfo | None: - try: - return self[name] - except KeyError: - return default - def get_hbm_layout(self, name: str) -> MatrixBlockLayout: """Read HBM matrix layout by name.""" if name not in self.hbm_matrices: @@ -218,19 +194,7 @@ def register_matrix( real_data_ratio: float = 1.125, strict: bool = True, ) -> MatrixBlockLayout: - """ - Register a large matrix and automatically block it. - - Args: - name: matrix name - shape: full shape (rows, cols) - hbm_base_addr: HBM base address - real_data_ratio: HBM storage ratio (MXFP8 = 1.125) - strict: if False, skip mlen-alignment check (for raw HBM access) - - Returns: - MatrixBlockLayout object - """ + """Register an HBM matrix and derive its mlen block layout.""" rows, cols = shape if strict: @@ -249,24 +213,6 @@ def register_matrix( self.hbm_matrices[name] = layout return layout - def get_sub_block(self, name: str, row_idx: int, col_idx: int) -> SubMatrixInfo: - """Get sub-block information.""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - return self.hbm_matrices[name].get_sub_block(row_idx, col_idx) - - def get_row_blocks(self, name: str, row_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a row: matrix[row_idx][:]""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - return self.hbm_matrices[name].get_row_blocks(row_idx) - - def get_col_blocks(self, name: str, col_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a column: matrix[:][col_idx]""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - return self.hbm_matrices[name].get_col_blocks(col_idx) - # ========================================================================== # VRAM Sub-matrix Management # ========================================================================== @@ -308,46 +254,62 @@ def get_vram_sub_block(self, name: str, row_idx: int, col_idx: int) -> VRAMSubMa raise KeyError(f"VRAM matrix '{name}' not registered") return self.vram_matrices[name].get_sub_block(row_idx, col_idx) - def get_vram_row_blocks(self, name: str, row_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a row of VRAM matrix: matrix[row_idx][:]""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not registered") - return self.vram_matrices[name].get_row_blocks(row_idx) - - def get_vram_col_blocks(self, name: str, col_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a column of VRAM matrix: matrix[:][col_idx]""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not registered") - return self.vram_matrices[name].get_col_blocks(col_idx) - # ========================================================================== - # Address Calculation (Core - Pre-calculated during compiler phase) + # ISA Generation: Load Sub Matrix # ========================================================================== - def compute_hbm_offset(self, name: str, row_idx: int, col_idx: int) -> int: - """ - Compute HBM offset for sub-block (in elements, not bytes). + def _default_hbm_gp_regs(self, gp_regs: list[int] | None) -> list[int]: + return [1, 2, 3] if gp_regs is None else gp_regs - HBM row-major: sub-block (r, c) starts at r*block_size*full_cols + c*block_size. - """ - layout = self.hbm_matrices[name] + def _emit_hbm_prefetch_setup(self, asm: IsaBuilder, layout: MatrixBlockLayout, gp_scale: int, gp_stride: int) -> None: + rows, cols = layout.full_shape + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), rows * cols) + asm.instr("C_SET_SCALE_REG", gp(gp_scale)) + asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), cols) + asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) + + def _emit_hbm_subblock_prefetch( + self, + asm: IsaBuilder, + name: str, + layout: MatrixBlockLayout, + row_idx: int, + col_idx: int, + mram_addr: int, + hbm_addr_reg: int, + gp_scale: int, + gp_mram: int, + comment: str | None = None, + ) -> None: sub_block = layout.get_sub_block(row_idx, col_idx) - return sub_block.hbm_offset + hbm_offset = sub_block.hbm_offset + sub_block.mram_addr = mram_addr - def compute_absolute_hbm_addr(self, name: str, row_idx: int, col_idx: int) -> int: - """ - Calculate absolute HBM address of sub-block (in elements). + asm.comment(comment if comment is not None else f"SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") + asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_addr) + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) + asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) - Returns: - Absolute HBM address = base + offset - """ - layout = self.hbm_matrices[name] - offset = self.compute_hbm_offset(name, row_idx, col_idx) - return layout.hbm_base_addr + offset + self.loaded_sub_blocks[f"{name}[{row_idx}][{col_idx}]"] = sub_block - # ========================================================================== - # ISA Generation: Load Sub Matrix - # ========================================================================== + def _emit_hbm_subblock_sequence( + self, + asm: IsaBuilder, + name: str, + layout: MatrixBlockLayout, + block_coords, + mram_start_addr: int, + hbm_addr_reg: int, + gp_scale: int, + gp_mram: int, + ) -> None: + mram_addr = mram_start_addr + block_size = self.mlen * self.mlen + for row_idx, col_idx in block_coords: + self._emit_hbm_subblock_prefetch( + asm, name, layout, row_idx, col_idx, mram_addr, hbm_addr_reg, gp_scale, gp_mram + ) + mram_addr += block_size def load_sub_matrix_asm( self, @@ -358,44 +320,29 @@ def load_sub_matrix_asm( hbm_addr_reg: int = 1, gp_regs: list[int] | None = None, ) -> str: - """ - Generate ISA code for loading sub-matrix from HBM to MRAM. - - HBM is row-major; H_PREFETCH_M loads mlen x mlen blocks into MRAM. - SCALE = full matrix element count; STRIDE = full column width (row stride). - """ - if gp_regs is None: - gp_regs = [1, 2, 3] - + """Emit HBM->MRAM prefetch for one mlen x mlen sub-block.""" + gp_regs = self._default_hbm_gp_regs(gp_regs) layout = self.hbm_matrices[name] - sub_block = layout.get_sub_block(row_idx, col_idx) - - hbm_offset = sub_block.hbm_offset - sub_block.mram_addr = mram_dest_addr asm = IsaBuilder() asm.comment(f"Load SubMatrix {name}[{row_idx}][{col_idx}] -> MRAM[{mram_dest_addr}]") - asm.comment(f"HBM offset: {hbm_offset} (precomputed)") - - full_size = layout.full_shape[0] * layout.full_shape[1] - full_cols = layout.full_shape[1] - gp_scale = gp_regs[0] gp_stride = gp_regs[1] gp_mram = gp_regs[2] - - asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), full_size) - asm.instr("C_SET_SCALE_REG", gp(gp_scale)) - asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), full_cols) - asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) - - asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_dest_addr) - asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) - - asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) - - block_key = f"{name}[{row_idx}][{col_idx}]" - self.loaded_sub_blocks[block_key] = sub_block + self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) + hbm_offset = layout.get_sub_block(row_idx, col_idx).hbm_offset + self._emit_hbm_subblock_prefetch( + asm, + name, + layout, + row_idx, + col_idx, + mram_dest_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + comment=f"HBM offset: {hbm_offset} (precomputed)", + ) return asm.render() @@ -407,51 +354,28 @@ def load_row_sub_matrices_asm( hbm_addr_reg: int = 1, gp_regs: list[int] | None = None, ) -> str: - """ - Generate ISA code for loading all sub-blocks in a row: matrix[row_idx][:] - - SCALE and STRIDE set once for all sub-blocks in the row. - """ - if gp_regs is None: - gp_regs = [1, 2, 3] - + """Emit HBM->MRAM prefetches for one block row.""" + gp_regs = self._default_hbm_gp_regs(gp_regs) layout = self.hbm_matrices[name] - num_col_blocks = layout.num_col_blocks asm = IsaBuilder() asm.comment(f"Load SubMatrix Row {name}[{row_idx}][:] -> MRAM[{mram_start_addr}]") - # Set SCALE and STRIDE once for all sub-blocks - full_size = layout.full_shape[0] * layout.full_shape[1] - full_cols = layout.full_shape[1] - gp_scale = gp_regs[0] gp_stride = gp_regs[1] gp_mram = gp_regs[2] - - asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), full_size) - asm.instr("C_SET_SCALE_REG", gp(gp_scale)) - asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), full_cols) - asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) - - mram_addr = mram_start_addr - block_size = self.mlen * self.mlen - - for col_idx in range(num_col_blocks): - sub_block = layout.get_sub_block(row_idx, col_idx) - hbm_offset = sub_block.hbm_offset - - sub_block.mram_addr = mram_addr - - asm.comment(f"SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") - asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_addr) - asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) - asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) - - block_key = f"{name}[{row_idx}][{col_idx}]" - self.loaded_sub_blocks[block_key] = sub_block - - mram_addr += block_size + self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) + + self._emit_hbm_subblock_sequence( + asm, + name, + layout, + ((row_idx, col_idx) for col_idx in range(layout.num_col_blocks)), + mram_start_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + ) return asm.render() @@ -465,53 +389,30 @@ def load_col_sub_matrices_asm( k_block_start: int = 0, k_block_count: int | None = None, ) -> str: - """ - Generate ISA code for loading all sub-blocks in a column: matrix[:][col_idx]. - - Used for sub_projection: A @ W[:, col_idx*mlen:(col_idx+1)*mlen]. - k_block_start/k_block_count select a K-split slice of the column. - """ - if gp_regs is None: - gp_regs = [1, 2, 3] - + """Emit HBM->MRAM prefetches for one block column or K-split slice.""" + gp_regs = self._default_hbm_gp_regs(gp_regs) layout = self.hbm_matrices[name] num_row_blocks = layout.num_row_blocks asm = IsaBuilder() asm.comment(f"Load SubMatrix Col {name}[:][{col_idx}] -> MRAM[{mram_start_addr}]") - # Set SCALE and STRIDE once for all sub-blocks - full_size = layout.full_shape[0] * layout.full_shape[1] - full_cols = layout.full_shape[1] - gp_scale = gp_regs[0] gp_stride = gp_regs[1] gp_mram = gp_regs[2] - - asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), full_size) - asm.instr("C_SET_SCALE_REG", gp(gp_scale)) - asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), full_cols) - asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) - - mram_addr = mram_start_addr - block_size = self.mlen * self.mlen + self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) effective_count = k_block_count if k_block_count is not None else num_row_blocks - for row_idx in range(k_block_start, k_block_start + effective_count): - sub_block = layout.get_sub_block(row_idx, col_idx) - hbm_offset = sub_block.hbm_offset - - sub_block.mram_addr = mram_addr - - asm.comment(f"SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") - asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_addr) - asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) - asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) - - block_key = f"{name}[{row_idx}][{col_idx}]" - self.loaded_sub_blocks[block_key] = sub_block - - mram_addr += block_size + self._emit_hbm_subblock_sequence( + asm, + name, + layout, + ((row_idx, col_idx) for row_idx in range(k_block_start, k_block_start + effective_count)), + mram_start_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + ) return asm.render() @@ -519,6 +420,21 @@ def load_col_sub_matrices_asm( # ISA Generation: Sub Projection # ========================================================================== + def _default_projection_gp_regs(self, gp_regs: list[int] | None) -> list[int]: + return [1, 2, 3, 4, 5, 6, 7, 8, 9] if gp_regs is None else gp_regs + + def _projection_context(self, vram_mat_name: str, vram_row_idx: int, mram_mat_name: str): + if vram_mat_name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") + vram_layout = self.vram_matrices[vram_mat_name] + return vram_layout, self.hbm_matrices[mram_mat_name], vram_layout.get_row_blocks(vram_row_idx) + + def _loaded_mram_start(self, blocks, missing_label) -> int: + for sub_block in blocks: + if sub_block.mram_addr is None: + raise RuntimeError(f"SubBlock {missing_label(sub_block)} not loaded to MRAM") + return blocks[0].mram_addr + def _vram_sub_projection_asm_impl( self, header_lines: list[str], @@ -533,27 +449,7 @@ def _vram_sub_projection_asm_impl( caller_name: str, unroll: bool | None = None, ) -> str: - """ - Shared implementation kernel for vram_sub_projection_asm and - vram_sub_projection_T_asm. - - Parameters resolved by the caller before this point: - header_lines -- comment lines already assembled by the caller - vram_row_start_addr -- VRAM address of the first activation block - mram_start_addr -- MRAM address of the first weight block - result_vram_addr -- VRAM destination address for the (mlen, mlen) result - full_batch -- full batch dimension of the activation VRAM matrix - num_hidden_blocks -- number of K-blocks to accumulate over - mat_col_stride -- MRAM outer-column stride (blen for M_MM, blen*mlen for M_TMM) - transposed -- True → emit M_TMM with (act, mat) operand order; - False → emit M_MM with (mat, act) operand order - gp_regs -- list of at least 9 GP register indices - caller_name -- used in error messages only - unroll -- override instance unroll_loops flag (None = use instance default) - - Returns: - ISA code string. - """ + """Emit the shared projection loop after callers resolve operands.""" do_unroll = self.unroll_loops if unroll is None else unroll return vram_sub_projection_asm_impl( mlen=self.mlen, @@ -583,28 +479,11 @@ def vram_sub_projection_asm( k_block_count: int | None = None, unroll: bool | None = None, ) -> str: - """ - Generate ISA for VRAM sub-block × MRAM sub-matrix multiply. - - Computes: result = VRAM_A[row_idx][:] @ MRAM_W[:][col_idx] - - VRAM_A[row_idx][:] is (mlen, hidden_size) spread across multiple (mlen, mlen) column blocks. - MRAM_W[:][col_idx] is (hidden_size, mlen) already loaded into MRAM. - Result is (mlen, mlen) written to VRAM. - - Loop structure (outer=output cols by blen, middle=output rows by blen, inner=K accumulation). - k_block_start/k_block_count select a K-split chunk. - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5, 6, 7, 8, 9] - - if vram_mat_name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") - vram_layout = self.vram_matrices[vram_mat_name] - - mram_layout = self.hbm_matrices[mram_mat_name] - - vram_row_blocks = vram_layout.get_row_blocks(vram_row_idx) + """Emit VRAM[row][:] @ MRAM[:][col] projection.""" + gp_regs = self._default_projection_gp_regs(gp_regs) + vram_layout, mram_layout, vram_row_blocks = self._projection_context( + vram_mat_name, vram_row_idx, mram_mat_name + ) mram_col_blocks = mram_layout.get_col_blocks(mram_col_idx) # K-split: slice to only the loaded k-chunk if k_block_count is not None: @@ -617,13 +496,12 @@ def vram_sub_projection_asm( f"got {num_hidden_blocks}" ) - for sub_block in mram_col_blocks: - if sub_block.mram_addr is None: - raise RuntimeError(f"SubBlock {mram_mat_name}[{sub_block.row_idx}][{mram_col_idx}] not loaded to MRAM") - full_batch = vram_layout.full_shape[0] vram_row_start_addr = vram_row_blocks[k_block_start].vram_addr - mram_col_start_addr = mram_col_blocks[0].mram_addr + mram_col_start_addr = self._loaded_mram_start( + mram_col_blocks, + lambda block: f"{mram_mat_name}[{block.row_idx}][{mram_col_idx}]", + ) header_lines = [ f"; VRAM Sub Projection: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}]", @@ -656,27 +534,11 @@ def vram_sub_projection_T_asm( gp_regs: list[int] | None = None, unroll: bool | None = None, ) -> str: - """ - Generate ISA for VRAM sub-block × MRAM sub-matrix transposed multiply. - - Computes: result = VRAM_A[row_idx][:] @ MRAM_W[row_idx][:]^T - - VRAM_A[row_idx][:] is (mlen, hidden_size). - MRAM_W[row_idx][:] is (mlen, hidden_size); transposed to (hidden_size, mlen). - Result is (mlen, mlen) written to VRAM. - - Uses M_TMM instruction. mat_col_stride = blen*mlen (transposed addressing). - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5, 6, 7, 8, 9] - - if vram_mat_name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") - vram_layout = self.vram_matrices[vram_mat_name] - - mram_layout = self.hbm_matrices[mram_mat_name] - - vram_row_blocks = vram_layout.get_row_blocks(vram_row_idx) + """Emit VRAM[row][:] @ MRAM[row][:]^T projection.""" + gp_regs = self._default_projection_gp_regs(gp_regs) + vram_layout, mram_layout, vram_row_blocks = self._projection_context( + vram_mat_name, vram_row_idx, mram_mat_name + ) mram_row_blocks = mram_layout.get_row_blocks(mram_row_idx) if len(vram_row_blocks) != len(mram_row_blocks): @@ -686,13 +548,12 @@ def vram_sub_projection_T_asm( num_hidden_blocks = len(vram_row_blocks) - for sub_block in mram_row_blocks: - if sub_block.mram_addr is None: - raise RuntimeError(f"SubBlock {mram_mat_name}[{mram_row_idx}][{sub_block.col_idx}] not loaded to MRAM") - full_batch = vram_layout.full_shape[0] vram_row_start_addr = vram_row_blocks[0].vram_addr - mram_row_start_addr = mram_row_blocks[0].mram_addr + mram_row_start_addr = self._loaded_mram_start( + mram_row_blocks, + lambda block: f"{mram_mat_name}[{mram_row_idx}][{block.col_idx}]", + ) # NOTE: For M_TMM (transposed matmul), the MRAM outer-column stride is # blen * mlen (full sub-block size) because M_TMM reads the weight in # transposed layout. This differs from the non-transposed path which uses @@ -780,189 +641,6 @@ def vram_block_add_asm( return "\n".join(lines) + "\n" - # ========================================================================== - # High-level Interface: Complete Sub-block Computation - # ========================================================================== - - def compute_sub_matmul( - self, - a_name: str, - a_row_idx: int | slice, - b_name: str, - b_col_idx: int | slice, - result_name: str, - transpose_b: bool = False, - ) -> tuple[str, int]: - """Not yet implemented. Use vram_sub_projection_asm or vram_sub_projection_T_asm.""" - raise NotImplementedError( - "compute_sub_matmul is not yet implemented. Use vram_sub_projection_asm " - "or vram_sub_projection_T_asm for matrix multiplication." - ) - - # ========================================================================== - # Format Conversion: HBM <-> VRAM - # ========================================================================== - - def load_activation_with_format_convert_asm( - self, - name: str, - hbm_base_addr: int, - batch: int, - hidden_size: int, - vram_dest_addr: int, - hbm_addr_reg: int = 0, - gp_regs: list[int] | None = None, - ) -> str: - """ - Load activation from HBM to VRAM with format conversion. - - HBM layout: [batch, hidden_size] row-major (element[b,h] at hbm_base + b*hidden_size + h). - VRAM layout: [batch, mlen, hidden/mlen] column-block major. - element[b,h]: vram_base + (h//mlen)*batch*mlen + b*mlen + (h%mlen). - - H_PREFETCH_V loads mlen elements per call; columns loaded in blocks of mlen. - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5] - - asm = IsaBuilder() - asm.comment(f"Load Activation with Format Convert: {name}") - asm.comment(f"HBM[{hbm_base_addr}]: [batch={batch}, hidden={hidden_size}] row-major") - asm.comment(f"VRAM[{vram_dest_addr}]: [batch, mlen, hidden/mlen] column-block major") - - gp_hbm_offset = gp_regs[0] - gp_stride = gp_regs[1] - gp_vram = gp_regs[2] - _gp_outer = gp_regs[3] - _gp_inner = gp_regs[4] - - num_col_blocks = hidden_size // self.mlen - preload_len = 4 # load 4 rows per H_PREFETCH_V call - - total_size = batch * hidden_size - asm.instr("S_ADDI_INT", gp(gp_hbm_offset), gp(0), total_size) - asm.instr("C_SET_SCALE_REG", gp(gp_hbm_offset)) - - asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), hidden_size) - asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) - - for col_block in range(num_col_blocks): - asm.comment(f"Column block {col_block}") - - hbm_offset = col_block * self.mlen - vram_addr = vram_dest_addr + col_block * batch * self.mlen - - asm.instr("S_ADDI_INT", gp(gp_hbm_offset), gp(0), hbm_offset) - asm.instr("S_ADDI_INT", gp(gp_vram), gp(0), vram_addr) - - for batch_block in range(math.ceil(batch / preload_len)): - actual_batch_offset = batch_block * preload_len * hidden_size - actual_vram_offset = batch_block * preload_len * self.mlen - - asm.instr("S_ADDI_INT", gp(gp_hbm_offset), gp(0), hbm_offset + actual_batch_offset) - asm.instr("S_ADDI_INT", gp(gp_vram), gp(0), vram_addr + actual_vram_offset) - asm.instr("H_PREFETCH_V", gp(gp_vram), gp(gp_hbm_offset), areg(hbm_addr_reg), 1, 0) - - return asm.render() - - def store_activation_with_format_convert_asm( - self, - name: str, - vram_src_addr: int, - batch: int, - hidden_size: int, - hbm_dest_addr: int, - hbm_addr_reg: int = 0, - gp_regs: list[int] | None = None, - ) -> str: - """ - Store activation from VRAM to HBM with format conversion. - - VRAM layout: [batch, mlen, hidden/mlen] column-block major. - HBM layout: [batch, hidden_size] row-major. - - H_STORE_V stores mlen elements per call; columns stored in blocks of mlen. - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5] - - asm = IsaBuilder() - asm.comment(f"Store Activation with Format Convert: {name}") - asm.comment(f"VRAM[{vram_src_addr}]: [batch, mlen, hidden/mlen] column-block major") - asm.comment(f"HBM[{hbm_dest_addr}]: [batch={batch}, hidden={hidden_size}] row-major") - - gp_hbm_offset = gp_regs[0] - gp_stride = gp_regs[1] - gp_vram = gp_regs[2] - _gp_outer = gp_regs[3] - _gp_inner = gp_regs[4] - - num_col_blocks = hidden_size // self.mlen - store_amount = 4 # store 4 rows per H_STORE_V call - - asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), hidden_size) - asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) - - for col_block in range(num_col_blocks): - asm.comment(f"Column block {col_block}") - - hbm_offset = col_block * self.mlen - vram_addr = vram_src_addr + col_block * batch * self.mlen - - for batch_block in range(math.ceil(batch / store_amount)): - actual_batch_offset = batch_block * store_amount * hidden_size - actual_vram_offset = batch_block * store_amount * self.mlen - - asm.instr("S_ADDI_INT", gp(gp_hbm_offset), gp(0), hbm_offset + actual_batch_offset) - asm.instr("S_ADDI_INT", gp(gp_vram), gp(0), vram_addr + actual_vram_offset) - asm.instr("H_STORE_V", gp(gp_vram), gp(gp_hbm_offset), areg(hbm_addr_reg), 0) - - return asm.render() - - # ========================================================================== - # Pre-calculated Address Table Generation - # ========================================================================== - - def generate_address_table(self, name: str) -> dict[str, int]: - """Generate complete address table for a matrix (for debugging and verification).""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - - layout = self.hbm_matrices[name] - addr_table = {} - - for (r, c), sub_block in layout.sub_blocks.items(): - key = f"{name}[{r}][{c}]" - addr_table[f"{key}_hbm_offset"] = sub_block.hbm_offset - addr_table[f"{key}_hbm_abs"] = layout.hbm_base_addr + sub_block.hbm_offset - if sub_block.mram_addr is not None: - addr_table[f"{key}_mram"] = sub_block.mram_addr - - return addr_table - - def print_address_table(self, name: str): - """Print address table for a matrix.""" - addr_table = self.generate_address_table(name) - print(f"Address Table for {name}:") - for key, addr in sorted(addr_table.items()): - print(f" {key}: {addr}") - - # ========================================================================== - # Helper Methods - # ========================================================================== - - def get_loaded_block_addr(self, name: str, row_idx: int, col_idx: int) -> int: - """Get MRAM address of a loaded sub-block.""" - block_key = f"{name}[{row_idx}][{col_idx}]" - if block_key not in self.loaded_sub_blocks: - raise KeyError(f"SubBlock {block_key} not loaded") - return self.loaded_sub_blocks[block_key].mram_addr - - def is_block_loaded(self, name: str, row_idx: int, col_idx: int) -> bool: - """Check whether a sub-block has been loaded into MRAM.""" - block_key = f"{name}[{row_idx}][{col_idx}]" - return block_key in self.loaded_sub_blocks - def reset(self): """Reset manager state.""" self.hbm_matrices.clear() @@ -972,44 +650,5 @@ def reset(self): self.mram_allocator.reset() self.fpram_allocator.reset() self.loaded_sub_blocks.clear() - self._address_cache.clear() - - def print_table(self): - """Print unified managed object table.""" - print("=" * 95) - print("Managed Object Table") - print("=" * 95) - print( - f"{'Name':<20} {'Kind':<12} {'Shape':<15} {'HBM Addr':<10} {'HBM Size':<10} {'VRAM Addr':<12} {'FPRAM':<12} {'Size':<8}" - ) - print("-" * 95) - names = sorted(set(self.hbm_matrices) | set(self.vram_matrices) | set(self.fpram_matrices)) - for name in names: - info = self[name] - shape_str = f"({info.shape[0]}, {info.shape[1]})" - vram_str = f"{info.vram_addr}" if info.vram_addr is not None else "None" - fpram_str = f"{info.fpram_addr}/{info.fpram_size}" if info.fpram_addr is not None else "None" - print( - f"{name:<20} {info.kind:<12} {shape_str:<15} " - f"{info.hbm_addr:<10} {info.hbm_size:<10} {vram_str:<12} {fpram_str:<12} {info.size:<8}" - ) - print("=" * 95) - - def print_layout(self, name: str): - """Print block layout for a matrix.""" - if name not in self.hbm_matrices: - print(f"Matrix '{name}' not registered") - return - - layout = self.hbm_matrices[name] - print(f"Matrix: {name}") - print(f" Full shape: {layout.full_shape}") - print(f" Block size: {layout.block_size}") - print(f" Blocks: {layout.num_row_blocks} x {layout.num_col_blocks}") - print(f" HBM base: {layout.hbm_base_addr}") - print(" Sub blocks:") - for (r, c), sub in layout.sub_blocks.items(): - loaded = "LOADED" if sub.mram_addr is not None else "" - print(f" [{r}][{c}]: hbm_off={sub.hbm_offset}, mram={sub.mram_addr} {loaded}") __all__ = ["TileCompiler"] diff --git a/aten/plena/vars.py b/aten/plena/vars.py index b513165..d6dc65b 100644 --- a/aten/plena/vars.py +++ b/aten/plena/vars.py @@ -2,11 +2,10 @@ from __future__ import annotations -from enum import Enum from typing import TYPE_CHECKING if TYPE_CHECKING: - from compiler.aten.plena_compiler import PlenaCompiler + from compiler.aten.plena.compiler import PlenaCompiler class TensorVar: @@ -14,11 +13,9 @@ class TensorVar: Tensor proxy object base class All tensor variables inherit from this class. - Supports __matmul__ (`@`) operator, which automatically dispatches to appropriate PlenaCompiler methods. - Dual naming: - display_name: User-visible name (e.g., "temp", "Q", "S") - - internal_name: System internal name (e.g., "my_func_0/temp"), used for symbol table and ISA generation + - internal_name: System internal name, used for symbol table and ISA generation """ def __init__( @@ -40,10 +37,6 @@ def name(self) -> str: """Compatibility property: returns internal_name for internal system use""" return self.internal_name - def __matmul__(self, other): - """A @ B: Dispatch to appropriate computation based on operand types""" - return self._program._dispatch_matmul(self, other) - def __repr__(self): if self.display_name != self.internal_name: return ( @@ -129,35 +122,3 @@ class VRAMMatrixVar(TensorVar): def __init__(self, program: PlenaCompiler, name: str, shape: tuple[int, int], display_name: str | None = None): super().__init__(program, name, "vram_matrix", shape, display_name=display_name) - - -class TensorKind(Enum): - """Identifies which memory the tensor lives in / which proxy backs it.""" - - HBM = "hbm" # legacy: InputVar - VRAM = "vram" # legacy: VRAMMatrixVar - FPRAM = "fpram" # legacy: FPVar - - -# Type alias: a "Tensor" is any of the legacy proxy objects. Callers can use -# this annotation without worrying about which specific backing allocator -# a given tensor lives in. -Tensor = TensorVar | InputVar | VRAMMatrixVar | FPVar - - -def tensor_kind(tensor: Tensor) -> TensorKind: - """Return the backing storage of a tensor proxy.""" - if isinstance(tensor, FPVar): - return TensorKind.FPRAM - if isinstance(tensor, VRAMMatrixVar): - return TensorKind.VRAM - if isinstance(tensor, InputVar): - return TensorKind.HBM - if isinstance(tensor, TensorVar): - # Generic TensorVar without a specific backing — classify by ``kind``. - kind = getattr(tensor, "kind", "") - if kind in ("vram_matrix", "batch", "matrix"): - return TensorKind.VRAM - if kind == "input": - return TensorKind.HBM - raise TypeError(f"Unknown tensor type: {type(tensor).__name__}") diff --git a/aten/plena_compiler.py b/aten/plena_compiler.py index fa8f0d0..8e473c6 100644 --- a/aten/plena_compiler.py +++ b/aten/plena_compiler.py @@ -1,70 +1,22 @@ -"""Compatibility facade for the ATen PLENA compiler path. - -The implementation is split under ``compiler.aten.plena``. This module keeps -legacy imports such as ``from compiler.aten.plena_compiler import PlenaCompiler`` -working while re-exporting the public compiler, memory, register, and tensor -proxy types. -""" +"""Compatibility facade for the active ATen PLENA compiler path.""" from __future__ import annotations from compiler.aten.plena.compiler import PlenaCompiler from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN -from compiler.aten.plena.isa_compiler import DeveloperCompiler, IsaCompiler -from compiler.aten.plena.memory import ( - FPRAMAllocator, - FPRAMObjectLayout, - MRAMAllocator, - MatrixBlockLayout, - MemoryBlock, - MemoryObjectInfo, - SubMatrixInfo, - VRAMAllocator, - VRAMMatrixBlockLayout, - VRAMSubMatrixInfo, - VirtualMemoryManager, -) -from compiler.aten.plena.registers import RegisterAllocator +from compiler.aten.plena.isa_compiler import IsaCompiler from compiler.aten.plena.tile_compiler import TileCompiler -from compiler.aten.plena.vars import FPVar, InputVar, Tensor, TensorKind, TensorVar, VRAMMatrixVar, tensor_kind - - -# ``TensorInfo`` is the union of the three Info dataclasses. Callers can import -# it as an annotation; at runtime the object is whichever specific Info subtype -# ``TileCompiler`` constructed. -TensorInfo = MemoryObjectInfo | SubMatrixInfo | VRAMSubMatrixInfo - -# ``TileLayout`` is the union of the three Layout dataclasses. -TileLayout = MatrixBlockLayout | VRAMMatrixBlockLayout | FPRAMObjectLayout - +from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar __all__ = [ "BLEN", "IMM2_BOUND", "MLEN", - "DeveloperCompiler", - "FPRAMAllocator", - "FPRAMObjectLayout", "FPVar", "InputVar", "IsaCompiler", - "MRAMAllocator", - "MatrixBlockLayout", - "MemoryBlock", - "MemoryObjectInfo", "PlenaCompiler", - "RegisterAllocator", - "SubMatrixInfo", - "Tensor", - "TensorInfo", - "TensorKind", "TensorVar", "TileCompiler", - "TileLayout", - "VRAMAllocator", - "VRAMMatrixBlockLayout", "VRAMMatrixVar", - "VRAMSubMatrixInfo", - "VirtualMemoryManager", - "tensor_kind", ] diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index aba35a8..b77c679 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -1,32 +1,14 @@ -""" -Automatic HuggingFace model -> PLENA ISA compiler. +"""HuggingFace decoder model to PLENA ISA compiler.""" -Walks an HF nn.Module tree, extracts weights, and generates ISA with -proper residual connections for multi-layer decoder pipelines. - -Usage: - from transformers import AutoModelForCausalLM - from compiler.aten.plena_frontend import compile_hf_model, compile_and_run - - model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", - torch_dtype=torch.float32) - # Just compile (no emulator run): - result = compile_hf_model(model, seq_len=64, hidden_size=64, inter_dim=128, num_layers=2) - - # Compile + run emulator + compare: - result = compile_and_run(model, "/tmp/build", seq_len=64, hidden_size=64, inter_dim=128, num_layers=2) -""" - -import json import math import re -from pathlib import Path import torch import torch.nn.functional as F from compiler.aten.plena_compiler import PlenaCompiler from compiler.aten.ops.registry import OpRegistry, Backend +from compiler.aten.ops.plena.linear_ops import linear_projection_plena as _linear_projection import compiler.aten.ops as ops from quant.quantizer.hardware_quantizer.mxfp import _mx_fp_quantize_hardware @@ -35,7 +17,7 @@ def _fix_large_immediates(isa_code: str) -> str: """Post-process ISA: replace S_ADDI_INT gp{r}, gp0, {large} with S_LUI_INT + S_ADDI_INT. - PlenaCompiler emits raw S_ADDI_INT for VRAM/HBM addresses. At native + PlenaCompiler emits raw S_ADDI_INT for VRAM/HBM addresses. At large model dimensions these can exceed the 18-bit immediate limit. This pass splits them into S_LUI_INT (upper 22 bits, shifted <<12) + S_ADDI_INT (lower 12 bits), matching asm_templates._imm.load_large_int. @@ -63,7 +45,7 @@ def _fix_large_immediates(isa_code: str) -> str: # --------------------------------------------------------------------------- REAL_DATA_RATIO = (8 * 8 + 8) / (8 * 8) -# Hardware K-split tile limit (matches _linear_projection MAX_K_TILES) +# Hardware K-split tile limit (matches the PLENA linear backend) _HW_MAX_K_TILES = 4 @@ -162,6 +144,15 @@ def _make_rotate_half_matrix(head_dim: int) -> torch.Tensor: # --------------------------------------------------------------------------- # Model structure helpers # --------------------------------------------------------------------------- +def _linear_weight(module, rows, cols): + """HF Linear stores (out, in); PLENA uses (in, out).""" + return module.weight.detach().T.contiguous()[:rows, :cols] + + +def _split_heads(weight, head_dim, num_heads): + return [weight[:, h * head_dim:(h + 1) * head_dim].contiguous() for h in range(num_heads)] + + def _find_model_root(model): """Find the transformer backbone (model.model or model.model.text_model). @@ -180,20 +171,20 @@ def _find_model_root(model): def _extract_config(model): """Extract config dimensions from the model, resolving text_config for VLMs.""" config = getattr(model.config, "text_config", model.config) - native_hidden = config.hidden_size - native_inter = getattr(config, "intermediate_size", 4 * native_hidden) - native_heads = config.num_attention_heads - native_kv_heads = getattr(config, "num_key_value_heads", native_heads) - native_head_dim = native_hidden // native_heads + hidden = config.hidden_size + inter = getattr(config, "intermediate_size", 4 * hidden) + num_heads = config.num_attention_heads + num_kv_heads = getattr(config, "num_key_value_heads", num_heads) + head_dim = hidden // num_heads eps = getattr(config, "rms_norm_eps", 1e-5) rope_theta = getattr(config, "rope_theta", 10000.0) vocab_size = getattr(config, "vocab_size", None) return { - "hidden_size": native_hidden, - "inter_dim": native_inter, - "num_heads": native_heads, - "num_kv_heads": native_kv_heads, - "head_dim": native_head_dim, + "hidden_size": hidden, + "inter_dim": inter, + "num_heads": num_heads, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, "eps": eps, "rope_theta": rope_theta, "vocab_size": vocab_size, @@ -201,90 +192,25 @@ def _extract_config(model): } -def _extract_layer_weights(layer, hidden_slice, inter_slice, head_dim_slice, num_heads, head_dim, - num_kv_heads=1, native_mode=False): - """Extract and slice weights from a single decoder layer. - - Transposes from HF's (out_features, in_features) to PLENA's (in, out) convention. - - In native mode (hidden_slice == native hidden, multi-head): - - W_q: (hidden, num_heads * head_dim) — full Q projection - - W_o: (num_heads * head_dim, hidden) — full O projection - - W_k_heads: list of (hidden, head_dim) per KV head - - W_v_heads: list of (hidden, head_dim) per KV head - - In sliced mode (legacy single-head): - - W_q: (hidden_slice, head_dim_slice) - - W_o: (head_dim_slice, hidden_slice) - - W_k: (hidden_slice, head_dim_slice) - - W_v: (hidden_slice, head_dim_slice) - - Args: - layer: nn.Module for a single decoder layer - hidden_slice: target hidden dimension - inter_slice: target intermediate dimension - head_dim_slice: target head dimension (min of native head_dim, hidden_slice) - num_heads: number of attention heads (native config) - head_dim: native head dimension - num_kv_heads: number of KV heads (native config) - native_mode: if True, extract full multi-head weights - - Returns: - dict with W_q, W_o, W_gate, W_up, W_down, W_k/W_k_heads, W_v/W_v_heads, eps - """ - # FFN weights: HF stores (out, in) -> transpose to (in, out) -> slice - W_gate = layer.mlp.gate_proj.weight.detach().T.contiguous()[:hidden_slice, :inter_slice] - W_up = layer.mlp.up_proj.weight.detach().T.contiguous()[:hidden_slice, :inter_slice] - W_down = layer.mlp.down_proj.weight.detach().T.contiguous()[:inter_slice, :hidden_slice] - - # eps from input_layernorm +def _extract_layer_weights(layer, hidden, inter, num_heads, head_dim, num_kv_heads=1): + """Extract one decoder layer in PLENA's (in, out) linear-weight convention.""" norm = layer.input_layernorm - eps = getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-5)) - - if native_mode: - # Full multi-head weights (no slicing) - total_q_dim = num_heads * head_dim - total_kv_dim = num_kv_heads * head_dim - - W_q = layer.self_attn.q_proj.weight.detach().T.contiguous()[:hidden_slice, :total_q_dim] - W_o = layer.self_attn.o_proj.weight.detach().T.contiguous()[:total_q_dim, :hidden_slice] - - # Per-KV-head K and V weights - W_k_full = layer.self_attn.k_proj.weight.detach().T.contiguous()[:hidden_slice, :total_kv_dim] - W_v_full = layer.self_attn.v_proj.weight.detach().T.contiguous()[:hidden_slice, :total_kv_dim] - - W_k_heads = [W_k_full[:, h * head_dim:(h + 1) * head_dim].contiguous() - for h in range(num_kv_heads)] - W_v_heads = [W_v_full[:, h * head_dim:(h + 1) * head_dim].contiguous() - for h in range(num_kv_heads)] - - return { - "W_q": W_q, - "W_o": W_o, - "W_gate": W_gate, - "W_up": W_up, - "W_down": W_down, - "W_k_heads": W_k_heads, - "W_v_heads": W_v_heads, - "eps": eps, - } - else: - # Legacy single-head sliced mode - W_q = layer.self_attn.q_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] - W_o = layer.self_attn.o_proj.weight.detach().T.contiguous()[:head_dim_slice, :hidden_slice] - W_k = layer.self_attn.k_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] - W_v = layer.self_attn.v_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] - - return { - "W_q": W_q, - "W_o": W_o, - "W_gate": W_gate, - "W_up": W_up, - "W_down": W_down, - "W_k": W_k, - "W_v": W_v, - "eps": eps, - } + total_q_dim = num_heads * head_dim + total_kv_dim = num_kv_heads * head_dim + W_k_full = _linear_weight(layer.self_attn.k_proj, hidden, total_kv_dim) + W_v_full = _linear_weight(layer.self_attn.v_proj, hidden, total_kv_dim) + + weights = { + "W_q": _linear_weight(layer.self_attn.q_proj, hidden, total_q_dim), + "W_o": _linear_weight(layer.self_attn.o_proj, total_q_dim, hidden), + "W_k_heads": _split_heads(W_k_full, head_dim, num_kv_heads), + "W_v_heads": _split_heads(W_v_full, head_dim, num_kv_heads), + "W_gate": _linear_weight(layer.mlp.gate_proj, hidden, inter), + "W_up": _linear_weight(layer.mlp.up_proj, hidden, inter), + "W_down": _linear_weight(layer.mlp.down_proj, inter, hidden), + "eps": getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-5)), + } + return weights # --------------------------------------------------------------------------- @@ -314,101 +240,241 @@ def _flash_attn_ref(Q, K, V, scale, causal=False): return (attn @ V).to(torch.bfloat16).float() -def _rms_norm_ref(x, eps): - """CPU reference: RMS normalization matching PLENA hardware. +def _inter_round(x, to_inter, from_inter): + return from_inter(to_inter(x)) - Hardware path: V_RED_SUM accumulates sum-of-squares into f32 scalar register, - S_MUL_FP / S_ADD_FP / S_SQRT_FP / S_RECI_FP all operate in f32, - then V_MUL_VF multiplies BF16 vector by f32 scalar -> BF16 result. - The scalar rms factor stays in f32 throughout; only the vector data is BF16. - """ - x_bf = x.to(torch.bfloat16) - # Compute rms in f32 (matching hardware scalar register precision) - rms = torch.rsqrt(x_bf.float().pow(2).mean(-1, keepdim=True) + eps) - # V_MUL_VF: BF16 vector * f32 scalar -> quantized back to BF16 - return (x_bf.float() * rms).to(torch.bfloat16).float() +def _hardware_rms_norm_ref(x, eps, to_inter, from_inter): + x_inter = to_inter(x) + rms = torch.rsqrt(from_inter(x_inter).pow(2).mean(-1, keepdim=True) + eps) + return _inter_round(from_inter(x_inter) * rms, to_inter, from_inter) -# --------------------------------------------------------------------------- -# PLENA ISA helper: named linear projection (avoids name conflicts) -# --------------------------------------------------------------------------- -def _linear_projection(prog, input_var, weight_var, name): - """Emit a linear projection with a custom VRAM output name. - Equivalent to ops.linear but uses *name* for the output allocation so - that multiple projections in the same scope don't collide on the - default "linear_out" name. +def _hardware_linear_ref(x, weight, mlen, to_inter, from_inter): + return _ksplit_matmul(x, weight, mlen, _HW_MAX_K_TILES, to_inter, from_inter) - Supports K-split: when K tiles exceed MRAM capacity (4 tiles), splits - into chunks and accumulates partial sums via a temporary buffer. - """ - import math as _math - mlen = prog.mlen - MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements +def _hardware_rope_ref(x, rope_matrix, cos_table, sin_table, to_inter, from_inter): + x_inter = _inter_round(x, to_inter, from_inter) + x_rot = _inter_round( + torch.matmul(x_inter, _inter_round(rope_matrix, to_inter, from_inter)), + to_inter, + from_inter, + ) + x_cos = _inter_round(x_inter * _inter_round(cos_table, to_inter, from_inter), to_inter, from_inter) + x_rot_sin = _inter_round(x_rot * _inter_round(sin_table, to_inter, from_inter), to_inter, from_inter) + return _inter_round(x_cos + x_rot_sin, to_inter, from_inter) + - rows, k_total = input_var.shape - _, out_features = weight_var.shape - num_row_blocks = _math.ceil(rows / mlen) - assert out_features % mlen == 0, ( - f"out_features ({out_features}) must be a multiple of mlen ({mlen})" +def _hardware_hbm_round_ref(x, quantize_weight, to_inter, from_inter): + return _inter_round(quantize_weight(x), to_inter, from_inter) + + +def _hardware_residual_add_ref(x, residual, to_inter, from_inter): + return _inter_round(x + residual, to_inter, from_inter) + + +def _hardware_ffn_ref(x, W_gate, W_up, W_down, eps, mlen, to_inter, from_inter): + residual = x.clone() + x = _hardware_rms_norm_ref(x, eps, to_inter, from_inter) + up_out = _hardware_linear_ref(x, W_up, mlen, to_inter, from_inter) + gate_out = _hardware_linear_ref(x, W_gate, mlen, to_inter, from_inter) + silu_gate = to_inter( + F.silu(_inter_round(up_out, to_inter, from_inter)) * _inter_round(gate_out, to_inter, from_inter) ) - num_col_blocks = out_features // mlen - num_k_tiles = _math.ceil(k_total / mlen) - - output_strict = rows % mlen == 0 - output = prog.alloc(name, rows, out_features, strict=output_strict) - - if num_k_tiles <= MAX_K_TILES: - # Single pass: all K tiles fit in MRAM - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - output, - row_idx, - col_idx, - ) - else: - # K-split: chunk K tiles into groups of MAX_K_TILES - k_chunks = [] - k_start = 0 - while k_start < num_k_tiles: - k_end = min(k_start + MAX_K_TILES, num_k_tiles) - k_chunks.append((k_start, k_end - k_start)) - k_start = k_end - - temp = prog.alloc(f"{name}_ksplit_tmp", rows, out_features, strict=output_strict) - - for k_chunk_idx, (k_block_start, k_block_count) in enumerate(k_chunks): - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - if k_chunk_idx == 0: - prog.vram_sub_projection_to( - input_var, row_idx, weight_var, col_idx, - output, row_idx, col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - else: - prog.vram_sub_projection_to( - input_var, row_idx, weight_var, col_idx, - temp, row_idx, col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - prog.vram_block_add_to( - output, row_idx, col_idx, - temp, row_idx, col_idx, - output, row_idx, col_idx, - ) - - prog.free_tensor(temp) - - return output + x = _hardware_linear_ref(from_inter(silu_gate), W_down, mlen, to_inter, from_inter) + return _hardware_residual_add_ref(_inter_round(x, to_inter, from_inter), residual, to_inter, from_inter) + + +def _fp32_rms_norm_ref(x, eps): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def _fp32_linear_ref(x, weight): + return x @ weight.float() + + +def _fp32_rope_ref(x, rope_matrix, cos_table, sin_table): + return x * cos_table.float() + (x @ rope_matrix.float()) * sin_table.float() + + +def _fp32_ffn_ref(x, weights, eps): + residual = x.clone() + x_normed = _fp32_rms_norm_ref(x, eps) + up_out = F.silu(_fp32_linear_ref(x_normed, weights["W_up"])) + gate_out = _fp32_linear_ref(x_normed, weights["W_gate"]) + return _fp32_linear_ref(up_out * gate_out, weights["W_down"]) + residual + + +def _save_residual_and_norm(prog, source, scratch): + """Emit the common decoder pre-norm residual prologue.""" + prog.vram_fill_zero(scratch) + prog.vram_add(scratch, source) + prog.rms_norm(source, eps_offset=3, reci_hid_offset=4) + + +def _add_residual(prog, target, scratch): + prog.vram_add(target, scratch) + return target + + +def _apply_rope_projection(prog, x_var, rope_matrix, cos_var, sin_var, name): + x_rot = _linear_projection(prog, x_var, rope_matrix, name) + prog.rope(x_var, x_rot, cos_var, sin_var) + prog.free_tensor(x_rot) + return x_var + + +def _copy_into_vram_view(prog, source, name, rows, cols, vram_addr): + target = prog.alloc_at(name, rows, cols, vram_addr) + prog.vram_fill_zero(target) + prog.vram_add(target, source) + return target + + +def _free_named_tensors(prog, names): + for name in names: + tensor = prog._tensors.get(name) + if tensor is not None: + prog.free_tensor(tensor) + prog._tensors.pop(name, None) + + +def _emit_kv_stores(prog, current, layer_inputs, rope_inputs, layer_idx, num_kv_heads): + rope_matrix, cos_var, sin_var = rope_inputs + kv_stored = [] + for kv_h in range(num_kv_heads): + K_h = _linear_projection( + prog, + current, + layer_inputs["W_k_heads"][kv_h], + f"K_{layer_idx}_h{kv_h}", + ) + V_h = _linear_projection( + prog, + current, + layer_inputs["W_v_heads"][kv_h], + f"V_{layer_idx}_h{kv_h}", + ) + + _apply_rope_projection( + prog, + K_h, + rope_matrix, + cos_var, + sin_var, + f"K_rot_{layer_idx}_h{kv_h}", + ) + + K_stored = prog.store(K_h, name=f"K_stored_{layer_idx}_h{kv_h}") + V_stored = prog.store(V_h, name=f"V_stored_{layer_idx}_h{kv_h}") + kv_stored.append((K_stored, V_stored)) + + prog.free_tensor(K_h) + prog.free_tensor(V_h) + + return kv_stored + + +def _emit_attention_block( + prog, + current, + layer_inputs, + rope_inputs, + causal_mask, + scratch, + scale, + layer_idx, + seq_len, + head_dim, + total_q_dim, + num_heads, + num_kv_heads, + ratio, +): + _save_residual_and_norm(prog, current, scratch) + + Q = _linear_projection(prog, current, layer_inputs["W_q"], f"Q_{layer_idx}") + q_full_addr = prog.get_vram_addr(Q.name) + + O_full = prog.alloc(f"O_full_{layer_idx}", seq_len, total_q_dim) + o_full_addr = prog.get_vram_addr(O_full.name) + + kv_stored = _emit_kv_stores( + prog, + current, + layer_inputs, + rope_inputs, + layer_idx, + num_kv_heads, + ) + + rope_matrix, cos_var, sin_var = rope_inputs + for h in range(num_heads): + kv_h = h // ratio + K_stored, V_stored = kv_stored[kv_h] + + q_h_addr = q_full_addr + h * seq_len * prog.mlen + Q_h = prog.alloc_at(f"Q_h{h}_{layer_idx}", seq_len, head_dim, q_h_addr) + _apply_rope_projection( + prog, + Q_h, + rope_matrix, + cos_var, + sin_var, + f"Q_rot_{layer_idx}_h{h}", + ) + + O_h = ops.flash_attention( + prog, + Q_h, + K_stored, + V_stored, + scale, + causal_mask=causal_mask, + ) + + o_h_dest_addr = o_full_addr + h * seq_len * prog.mlen + _copy_into_vram_view( + prog, + O_h, + f"O_dest_h{h}_{layer_idx}", + seq_len, + head_dim, + o_h_dest_addr, + ) + _free_named_tensors(prog, ("O", "S", "PV")) + + O_proj = _linear_projection(prog, O_full, layer_inputs["W_o"], f"O_proj_{layer_idx}") + return _add_residual(prog, O_proj, scratch) + + +def _emit_ffn_block(prog, current, layer_inputs, scratch): + _save_residual_and_norm(prog, current, scratch) + ops.ffn(prog, current, layer_inputs["W_gate"], layer_inputs["W_up"], layer_inputs["W_down"]) + return _add_residual(prog, current, scratch) + + +def _layer_tensor_entries(layer_idx, weights, num_kv_heads): + entries = [ + (f"W_q_{layer_idx}", weights["W_q"]), + (f"W_o_{layer_idx}", weights["W_o"]), + ] + for kv_h in range(num_kv_heads): + entries.extend( + [ + (f"W_k_{layer_idx}_h{kv_h}", weights["W_k_heads"][kv_h]), + (f"W_v_{layer_idx}_h{kv_h}", weights["W_v_heads"][kv_h]), + ] + ) + entries.extend( + [ + (f"W_gate_{layer_idx}", weights["W_gate"]), + (f"W_up_{layer_idx}", weights["W_up"]), + (f"W_down_{layer_idx}", weights["W_down"]), + ] + ) + return entries # --------------------------------------------------------------------------- @@ -417,96 +483,27 @@ def _linear_projection(prog, input_var, weight_var, name): def compile_hf_model( model, seq_len: int = 64, - hidden_size: int | None = None, - inter_dim: int | None = None, num_layers: int | None = None, layer_idx_start: int = 0, mlen: int = 64, blen: int = 4, seed: int = 42, - include_lm_head: bool = False, golden_precision: str = "hardware", + verbose: bool = False, ) -> dict: - """Compile a HuggingFace model to PLENA ISA via PlenaCompiler. - - Walks the nn.Module tree, extracts weights, and generates ISA with - proper residual connections for multi-layer decoders. - - The pipeline implemented (per-layer, with pre-norm + residual): - X = embed_tokens(input_ids) # real HF embedding lookup - X = embedding_add(X, zeros) # pos_weight=0 for Llama (RoPE handles position) - for each layer: - residual = X - X = rms_norm(X) - Q = linear(X, W_q) # Q projection - K_h = linear(X, W_k_h) # K projection (per KV head, on-chip) - V_h = linear(X, W_v_h) # V projection (per KV head, on-chip) - # RoPE (native mode only): - Q_rot_h = linear(Q_h, R_rope) # rotate_half via matmul - Q_h = Q_h * cos + Q_rot_h * sin - K_rot_h = linear(K_h, R_rope) - K_h = K_h * cos + K_rot_h * sin - store K_h, V_h to HBM - O = flash_attention(Q, K, V, scale) - O = linear(O, W_o) # O projection - X = O + residual - residual = X - X = rms_norm(X) - X = ffn(X, gate, up, down) - X = X + residual - X = rms_norm(X) # final norm - if include_lm_head: - logits = linear(X, W_lm_head) # (seq, vocab_size) - - RoPE is applied in native mode via rotate_half matrix multiplication. - Q, K, V, and O linear projections are all computed on-chip. - - Args: - model: nn.Module (HF CausalLM model, already loaded) - seq_len: Sequence length (default 64) - hidden_size: Target hidden dimension (None = use model's native) - inter_dim: Target intermediate dimension (None = use model's native) - num_layers: Number of layers to compile (None = all layers) - layer_idx_start: First layer index to use (default 0) - mlen: Matrix tile length (default 64) - blen: Batch tile length (default 4) - seed: Random seed for test data generation - include_lm_head: If True, add lm_head projection after final norm (default False) - golden_precision: Precision mode for golden reference computation. - "hardware" — MXFP8 weights + BF16 intermediates (default, matches HW) - "no_weight_quant" — float32 weights + BF16 intermediates (isolates MXFP8 effect) - "no_bf16" — MXFP8 weights + float32 intermediates (isolates BF16 effect) - "fp32" — float32 weights + float32 intermediates (should match HF exactly) - - Returns: - dict with: - isa: str - generated ISA code - golden_output: torch.Tensor - CPU golden reference output - input_tensors: dict - {name: tensor} for sim env setup - data_order: list[str] - HBM tensor ordering - fp_preload: list[float] - FPRAM constants - comparison_params: dict - for emulator comparison - info: dict - model dims, VRAM usage, etc. - """ + """Compile a HuggingFace decoder model to PLENA ISA and simulation metadata.""" + def _verbose(message: str = ""): + if verbose: + print(message) + # -------------------------------------------------------------- config - native_cfg = _extract_config(model) - hidden = hidden_size if hidden_size is not None else native_cfg["hidden_size"] - inter = inter_dim if inter_dim is not None else native_cfg["inter_dim"] - - num_heads = native_cfg["num_heads"] - num_kv_heads = native_cfg["num_kv_heads"] - native_head_dim = native_cfg["head_dim"] - - # Native mode: use all heads at full dimension when hidden_size is not overridden - native_mode = (hidden_size is None) and (num_heads > 1) - if native_mode: - head_dim = native_head_dim - total_q_dim = num_heads * head_dim # e.g. 6*64 = 384 - total_kv_dim = num_kv_heads * head_dim # e.g. 2*64 = 128 - else: - head_dim = min(native_head_dim, hidden) - total_q_dim = head_dim - total_kv_dim = head_dim + model_cfg = _extract_config(model) + hidden = model_cfg["hidden_size"] + inter = model_cfg["inter_dim"] + num_heads = model_cfg["num_heads"] + num_kv_heads = model_cfg["num_kv_heads"] + head_dim = model_cfg["head_dim"] + total_q_dim = num_heads * head_dim root = _find_model_root(model) layers = root.layers @@ -521,42 +518,13 @@ def compile_hf_model( # ----------------------------------------------------------- embedding embed = getattr(root, "embed_tokens", getattr(root, "wte", None)) - # ----------------------------------------------------------- lm_head - lm_head_weight = None - vocab_size = native_cfg.get("vocab_size") - if include_lm_head: - lm_head_mod = getattr(model, "lm_head", None) - if lm_head_mod is None: - lm_head_mod = getattr( - getattr(model, "language_model", model), "lm_head", None - ) - if lm_head_mod is not None and hasattr(lm_head_mod, "weight"): - # nn.Linear stores (vocab, hidden) -> transpose to (hidden, vocab) - lm_head_weight_raw = lm_head_mod.weight.detach().T.contiguous() - # Slice to hidden if we're in sliced mode - lm_head_weight = lm_head_weight_raw[:hidden, :] - vocab_size = lm_head_weight.shape[1] - # Ensure vocab_size is a multiple of mlen (pad if needed) - if vocab_size % mlen != 0: - pad_cols = mlen - (vocab_size % mlen) - lm_head_weight = F.pad(lm_head_weight, (0, pad_cols)) - vocab_size = lm_head_weight.shape[1] - else: - print("WARNING: include_lm_head=True but no lm_head module found; skipping") - include_lm_head = False - print("=" * 80) - print(f"Model Compiler - {native_cfg['model_type']} ({n_layers} layers)") - print(f" native: hidden={native_cfg['hidden_size']}, inter={native_cfg['inter_dim']}, " - f"heads={native_cfg['num_heads']}/{native_cfg['num_kv_heads']}, " - f"head_dim={native_cfg['head_dim']}") - print(f" sim: hidden={hidden}, inter={inter}, head_dim={head_dim}, " - f"seq_len={seq_len}, mlen={mlen}, native_mode={native_mode}") - if native_mode: - print(f" MHA: num_heads={num_heads}, num_kv_heads={num_kv_heads}, " - f"total_q_dim={total_q_dim}, total_kv_dim={total_kv_dim}") - if include_lm_head: - print(f" lm_head: W_lm_head={lm_head_weight.shape}, vocab_size={vocab_size}") + print(f"Model Compiler - {model_cfg['model_type']} ({n_layers} layer{'s' if n_layers != 1 else ''})") + print( + f" decoder: hidden={hidden}, inter={inter}, heads={num_heads}/{num_kv_heads}, " + f"head_dim={head_dim}" + ) + print(f" compile: seq_len={seq_len}, mlen={mlen}, blen={blen}, total_q_dim={total_q_dim}") print("=" * 80) # ----------------------------------------------------------- weights @@ -565,19 +533,19 @@ def compile_hf_model( for i in range(n_layers): layer_module = layers[layer_idx_start + i] w = _extract_layer_weights( - layer_module, hidden, inter, head_dim, - num_heads, native_head_dim, + layer_module, + hidden, + inter, + num_heads, + head_dim, num_kv_heads=num_kv_heads, - native_mode=native_mode, ) all_weights.append(w) - if native_mode: - print(f" Layer {i}: W_q={w['W_q'].shape}, W_o={w['W_o'].shape}, " - f"W_gate={w['W_gate'].shape}, " - f"K_heads={len(w['W_k_heads'])}x{w['W_k_heads'][0].shape}, eps={w['eps']}") - else: - print(f" Layer {i}: W_q={w['W_q'].shape}, W_o={w['W_o'].shape}, " - f"W_gate={w['W_gate'].shape}, W_k={w['W_k'].shape}, eps={w['eps']}") + _verbose( + f" Layer {i}: W_q={w['W_q'].shape}, W_o={w['W_o'].shape}, " + f"W_gate={w['W_gate'].shape}, " + f"K_heads={len(w['W_k_heads'])}x{w['W_k_heads'][0].shape}, eps={w['eps']}" + ) eps = all_weights[0]["eps"] @@ -586,15 +554,17 @@ def compile_hf_model( # Embedding lookup: use real HF embedding table if available if embed is not None: - input_ids = torch.randint(0, native_cfg["vocab_size"] or 32000, (seq_len,)) + input_ids = torch.randint(0, model_cfg["vocab_size"] or 32000, (seq_len,)) with torch.no_grad(): token_embeds = embed(input_ids).float() if token_embeds.dim() == 3: token_embeds = token_embeds.squeeze(0) - # Slice to target hidden dim (native mode uses full hidden) + # Slice to the model hidden width. token_embeds = token_embeds[:, :hidden] - print(f"\nEmbedding lookup: input_ids shape={input_ids.shape}, " - f"token_embeds={token_embeds.shape}") + _verbose( + f"\nEmbedding lookup: input_ids shape={input_ids.shape}, " + f"token_embeds={token_embeds.shape}" + ) else: token_embeds = torch.randn(seq_len, hidden) print(f"\nNo embed_tokens found; using random token_embeds: {token_embeds.shape}") @@ -603,25 +573,13 @@ def compile_hf_model( # Set pos_weight to zeros so embedding_add is a no-op for position. pos_weight = torch.zeros(seq_len, hidden) - # K/V test data: native mode computes K/V on-chip, legacy mode uses precomputed - if native_mode: - # K/V are computed on-chip from X_normed @ W_k/W_v — no precomputed K/V needed - print(f"pos_weight: zeros {pos_weight.shape} (Llama uses RoPE, not learned pos embed)") - for i in range(n_layers): - for kv_h in range(num_kv_heads): - print(f" W_k_{i}_h{kv_h}: {all_weights[i]['W_k_heads'][kv_h].shape}, " - f"W_v_{i}_h{kv_h}: {all_weights[i]['W_v_heads'][kv_h].shape}") - else: - K_mats = [] - V_mats = [] - for i in range(n_layers): - X_ctx = torch.randn(seq_len, hidden) - K_mats.append(X_ctx @ all_weights[i]["W_k"]) - V_mats.append(X_ctx @ all_weights[i]["W_v"]) - - print(f"pos_weight: zeros {pos_weight.shape} (Llama uses RoPE, not learned pos embed)") - for i in range(n_layers): - print(f" K_{i}: {K_mats[i].shape}, V_{i}: {V_mats[i].shape}") + _verbose(f"pos_weight: zeros {pos_weight.shape} (RoPE model; learned position add is a no-op)") + for i in range(n_layers): + for kv_h in range(num_kv_heads): + _verbose( + f" W_k_{i}_h{kv_h}: {all_weights[i]['W_k_heads'][kv_h].shape}, " + f"W_v_{i}_h{kv_h}: {all_weights[i]['W_v_heads'][kv_h].shape}" + ) print(f"attn_scale: {scale:.6f}") # ----------------------------------------------------------- golden ref @@ -634,23 +592,13 @@ def compile_hf_model( "no_weight_quant": "float32 weights + BF16 intermediates", "no_bf16": "MXFP8 weights + float32 intermediates", "fp32": "float32 weights + float32 intermediates"}[golden_precision] - print(f"\n--- CPU Golden Reference ({_prec_label}) ---") - - if native_mode: - W_k_q_heads = [[_qw(all_weights[i]["W_k_heads"][h]) - for h in range(num_kv_heads)] - for i in range(n_layers)] - W_v_q_heads = [[_qw(all_weights[i]["W_v_heads"][h]) - for h in range(num_kv_heads)] - for i in range(n_layers)] - if native_mode: - R_matrix = _make_rotate_half_matrix(head_dim) - R_rope_q = _qw(R_matrix) - cos_table, sin_table = _make_rope_tables(seq_len, head_dim, native_cfg["rope_theta"]) + print(f"\nComputing CPU golden reference ({_prec_label})") - else: - K_q_list = [_qw(K_mats[i]) for i in range(n_layers)] - V_q_list = [_qw(V_mats[i]) for i in range(n_layers)] + W_k_q_heads = [[_qw(all_weights[i]["W_k_heads"][h]) for h in range(num_kv_heads)] for i in range(n_layers)] + W_v_q_heads = [[_qw(all_weights[i]["W_v_heads"][h]) for h in range(num_kv_heads)] for i in range(n_layers)] + R_matrix = _make_rotate_half_matrix(head_dim) + R_rope_q = _qw(R_matrix) + cos_table, sin_table = _make_rope_tables(seq_len, head_dim, model_cfg["rope_theta"]) X_gold = _qw(token_embeds.clone()) + _qw(pos_weight) # embedding_add (MXFP8-quantized, matching HBM) ratio = num_heads // num_kv_heads @@ -665,87 +613,45 @@ def compile_hf_model( # --- Attention block --- residual = X_gold.clone() - X_bf = _to_inter(X_gold) - # Hardware: scalar rms stays in f32, V_MUL_VF multiplies BF16 * f32 -> BF16 - rms = torch.rsqrt(_from_inter(X_bf).pow(2).mean(-1, keepdim=True) + eps) - X_gold = (_from_inter(X_bf) * rms).to(torch.bfloat16).float() - if native_mode: - Q_gold = _ksplit_matmul(X_gold, W_q_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - Q_gold = _from_inter(_to_inter(Q_gold)) # VRAM write: BF16 - - K_q_heads_i = [] - V_q_heads_i = [] - for kv_h in range(num_kv_heads): - K_h = _ksplit_matmul(X_gold, W_k_q_heads[i][kv_h], mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - V_h = _ksplit_matmul(X_gold, W_v_q_heads[i][kv_h], mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - K_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(K_h)), _from_inter(_to_inter(R_rope_q))))) - # Hardware RoPE: V_MUL_VV (BF16*BF16->BF16), V_ADD_VV (BF16+BF16->BF16) - K_cos = _from_inter(_to_inter(K_h.to(torch.bfloat16).float() * cos_table.to(torch.bfloat16).float())) - K_rot_sin = _from_inter(_to_inter(K_rot_h.to(torch.bfloat16).float() * sin_table.to(torch.bfloat16).float())) - K_h = _from_inter(_to_inter(K_cos) + _to_inter(K_rot_sin)) - # Hardware stores K/V to HBM as MXFP8, loads back as BF16 for attention - K_q_heads_i.append(_from_inter(_to_inter(_qw(K_h)))) - V_q_heads_i.append(_from_inter(_to_inter(_qw(V_h)))) - - O_heads = [] - for h in range(num_heads): - kv_h = h // ratio - Q_h = Q_gold[:, h * head_dim:(h + 1) * head_dim] - Q_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(Q_h)), _from_inter(_to_inter(R_rope_q))))) - # Hardware RoPE: V_MUL_VV (BF16*BF16->BF16), V_ADD_VV (BF16+BF16->BF16) - Q_cos = _from_inter(_to_inter(Q_h.to(torch.bfloat16).float() * cos_table.to(torch.bfloat16).float())) - Q_rot_sin = _from_inter(_to_inter(Q_rot_h.to(torch.bfloat16).float() * sin_table.to(torch.bfloat16).float())) - Q_h = _from_inter(_to_inter(Q_cos) + _to_inter(Q_rot_sin)) - O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale, causal=True) - O_heads.append(O_h) - attn_out = _from_inter(_to_inter(torch.cat(O_heads, dim=1))) # VRAM write: BF16 - O_gold = _ksplit_matmul(attn_out, W_o_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - O_gold = _from_inter(_to_inter(O_gold)) # VRAM write: BF16 - X_gold = _from_inter(_to_inter(O_gold + residual)) # residual add -> VRAM write: BF16 - else: - attn_out = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale, causal=True) - X_gold = _from_inter(_to_inter(attn_out + residual)) # residual add -> VRAM write: BF16 + X_gold = _hardware_rms_norm_ref(X_gold, eps, _to_inter, _from_inter) + Q_gold = _inter_round(_hardware_linear_ref(X_gold, W_q_q, mlen, _to_inter, _from_inter), _to_inter, _from_inter) + + K_q_heads_i = [] + V_q_heads_i = [] + for kv_h in range(num_kv_heads): + K_h = _hardware_linear_ref(X_gold, W_k_q_heads[i][kv_h], mlen, _to_inter, _from_inter) + V_h = _hardware_linear_ref(X_gold, W_v_q_heads[i][kv_h], mlen, _to_inter, _from_inter) + K_h = _hardware_rope_ref(K_h, R_rope_q, cos_table, sin_table, _to_inter, _from_inter) + K_q_heads_i.append(_hardware_hbm_round_ref(K_h, _qw, _to_inter, _from_inter)) + V_q_heads_i.append(_hardware_hbm_round_ref(V_h, _qw, _to_inter, _from_inter)) + + O_heads = [] + for h in range(num_heads): + kv_h = h // ratio + Q_h = Q_gold[:, h * head_dim:(h + 1) * head_dim] + Q_h = _hardware_rope_ref(Q_h, R_rope_q, cos_table, sin_table, _to_inter, _from_inter) + O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale, causal=True) + O_heads.append(O_h) + attn_out = _inter_round(torch.cat(O_heads, dim=1), _to_inter, _from_inter) + O_gold = _inter_round(_hardware_linear_ref(attn_out, W_o_q, mlen, _to_inter, _from_inter), _to_inter, _from_inter) + X_gold = _hardware_residual_add_ref(O_gold, residual, _to_inter, _from_inter) # --- FFN block --- - residual = X_gold.clone() - X_bf = _to_inter(X_gold) - # Hardware: scalar rms stays in f32, V_MUL_VF multiplies BF16 * f32 -> BF16 - rms = torch.rsqrt(_from_inter(X_bf).pow(2).mean(-1, keepdim=True) + eps) - X_gold = (_from_inter(X_bf) * rms).to(torch.bfloat16).float() - # Hardware path: MXFP8 (HBM) → BF16 (MRAM) → float32 (M_MM) - # Use _ksplit_matmul to match hardware K-split BF16 precision loss for - # projections that exceed MAX_K_TILES (e.g., hidden=576 → 9 tiles, inter=1536 → 24 tiles) - up_out = _ksplit_matmul(X_gold, W_up_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - gate_out = _ksplit_matmul(X_gold, W_gate_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - # Hardware: SiLU(up) * gate -> BF16 VRAM write before down projection - silu_gate = _to_inter(F.silu(_from_inter(_to_inter(up_out))) * _from_inter(_to_inter(gate_out))) - X_gold = _ksplit_matmul(_from_inter(silu_gate), W_down_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - X_gold = _from_inter(_to_inter(X_gold)) # VRAM write: BF16 after down proj - X_gold = _from_inter(_to_inter(X_gold + residual)) # residual add -> VRAM write: BF16 - - print(f" After layer {i}: X_gold[0,:4] = {X_gold[0, :4].tolist()}") + X_gold = _hardware_ffn_ref(X_gold, W_gate_q, W_up_q, W_down_q, eps, mlen, _to_inter, _from_inter) + + _verbose(f" After layer {i}: X_gold[0,:4] = {X_gold[0, :4].tolist()}") # Final norm - X_gold = _rms_norm_ref(X_gold, eps) - - # lm_head projection (optional) - if include_lm_head and lm_head_weight is not None: - W_lm_q = quantize_to_mxfp(lm_head_weight) - # Hardware: MXFP8 → BF16 (MRAM) → float32 (M_MM) - logits_gold = torch.matmul( - X_gold.to(torch.bfloat16).float(), W_lm_q.to(torch.bfloat16).float() - ).to(torch.bfloat16).float() - golden_out = logits_gold - print(f" logits_gold: {golden_out.shape}") - else: - golden_out = X_gold + X_gold = _hardware_rms_norm_ref(X_gold, eps, _to_inter, _from_inter) + + golden_out = X_gold print(f" golden_out: {golden_out.shape}") - print(f" golden_out[0,:4]: {golden_out[0, :4].tolist()}") + _verbose(f" golden_out[0,:4]: {golden_out[0, :4].tolist()}") # ----------------------------------------------------------- HF ground truth # Same pipeline as golden but pure float32: no MXFP8 quantization, no BF16 casting. - # This is the "best possible" reference for the sliced weight dimensions being tested. - print(f"\n--- HF Ground Truth (float32, {n_layers} layers, no quantization) ---") + # This is the float32 reference for the same decoder blocks being compiled. + print(f"\nComputing HF reference (float32, {n_layers} layer{'s' if n_layers != 1 else ''}, no quantization)") with torch.no_grad(): X_hf = token_embeds.clone() + pos_weight # embedding_add @@ -753,65 +659,45 @@ def compile_hf_model( w = all_weights[i] eps_i = w["eps"] - # --- Attention block (float32) --- residual = X_hf.clone() - rms = torch.rsqrt(X_hf.pow(2).mean(-1, keepdim=True) + eps_i) - X_normed = X_hf * rms - - if native_mode: - Q_hf = X_normed @ w["W_q"].float() - - K_hf_heads = [] - V_hf_heads = [] - R_mat_f32 = R_matrix.float() - cos_f32 = cos_table.float() - sin_f32 = sin_table.float() - for kv_h in range(num_kv_heads): - K_h = X_normed @ w["W_k_heads"][kv_h].float() - V_h = X_normed @ w["W_v_heads"][kv_h].float() - # RoPE on K_h (float32) - K_rot = K_h @ R_mat_f32 - K_h = K_h * cos_f32 + K_rot * sin_f32 - K_hf_heads.append(K_h) - V_hf_heads.append(V_h) - - O_heads_hf = [] - for h in range(num_heads): - kv_h = h // ratio - Q_h = Q_hf[:, h * head_dim:(h + 1) * head_dim] - # RoPE on Q_h (float32) - Q_rot = Q_h @ R_mat_f32 - Q_h = Q_h * cos_f32 + Q_rot * sin_f32 - O_h = _flash_attn_ref(Q_h, K_hf_heads[kv_h], V_hf_heads[kv_h], scale, causal=True) - O_heads_hf.append(O_h) - attn_out_hf = torch.cat(O_heads_hf, dim=1) - O_hf = attn_out_hf @ w["W_o"].float() - X_hf = O_hf + residual - else: - attn_out_hf = _flash_attn_ref(X_normed, K_mats[i].float(), V_mats[i].float(), scale, causal=True) - X_hf = attn_out_hf + residual + X_normed = _fp32_rms_norm_ref(X_hf, eps_i) - # --- FFN block (float32) --- - residual = X_hf.clone() - rms = torch.rsqrt(X_hf.pow(2).mean(-1, keepdim=True) + eps_i) - X_normed = X_hf * rms - up_out = F.silu(X_normed @ w["W_up"].float()) - gate_out = X_normed @ w["W_gate"].float() - X_hf = (up_out * gate_out) @ w["W_down"].float() + residual + Q_hf = _fp32_linear_ref(X_normed, w["W_q"]) + + K_hf_heads = [] + V_hf_heads = [] + for kv_h in range(num_kv_heads): + K_hf_heads.append( + _fp32_rope_ref( + _fp32_linear_ref(X_normed, w["W_k_heads"][kv_h]), + R_matrix, + cos_table, + sin_table, + ) + ) + V_hf_heads.append(_fp32_linear_ref(X_normed, w["W_v_heads"][kv_h])) - print(f" After layer {i}: X_hf[0,:4] = {X_hf[0, :4].tolist()}") + O_heads_hf = [] + for h in range(num_heads): + kv_h = h // ratio + Q_h = Q_hf[:, h * head_dim:(h + 1) * head_dim] + Q_h = _fp32_rope_ref(Q_h, R_matrix, cos_table, sin_table) + O_h = _flash_attn_ref(Q_h, K_hf_heads[kv_h], V_hf_heads[kv_h], scale, causal=True) + O_heads_hf.append(O_h) + attn_out_hf = torch.cat(O_heads_hf, dim=1) + O_hf = _fp32_linear_ref(attn_out_hf, w["W_o"]) + X_hf = O_hf + residual - # Final norm (float32) - rms = torch.rsqrt(X_hf.pow(2).mean(-1, keepdim=True) + eps) - X_hf = X_hf * rms + X_hf = _fp32_ffn_ref(X_hf, w, eps_i) - if include_lm_head and lm_head_weight is not None: - hf_ground_truth = (X_hf @ lm_head_weight.float()) - else: - hf_ground_truth = X_hf + _verbose(f" After layer {i}: X_hf[0,:4] = {X_hf[0, :4].tolist()}") + + X_hf = _fp32_rms_norm_ref(X_hf, eps) + + hf_ground_truth = X_hf print(f" hf_ground_truth: {hf_ground_truth.shape}") - print(f" hf_ground_truth[0,:4]: {hf_ground_truth[0, :4].tolist()}") + _verbose(f" hf_ground_truth[0,:4]: {hf_ground_truth[0, :4].tolist()}") # ----------------------------------------------------------- PLENA ISA print("\n--- PLENA Backend (ISA generation) ---") @@ -824,17 +710,11 @@ def compile_hf_model( x_input = prog.input("X", shape=(seq_len, hidden)) pos_input = prog.input("POS", shape=(seq_len, hidden)) - # RoPE inputs (native mode only: R_rope matrix + cos/sin tables) - if native_mode: - rope_theta = native_cfg["rope_theta"] - R_matrix = _make_rotate_half_matrix(head_dim) - cos_table, sin_table = _make_rope_tables(seq_len, head_dim, rope_theta) - - r_input = prog.input("R_rope", shape=(head_dim, head_dim)) - cos_input = prog.input("COS", shape=(seq_len, head_dim)) - sin_input = prog.input("SIN", shape=(seq_len, head_dim)) - COS = prog.load_batch(cos_input, name="COS") - SIN = prog.load_batch(sin_input, name="SIN") + r_input = prog.input("R_rope", shape=(head_dim, head_dim)) + cos_input = prog.input("COS", shape=(seq_len, head_dim)) + sin_input = prog.input("SIN", shape=(seq_len, head_dim)) + COS = prog.load_batch(cos_input, name="COS") + SIN = prog.load_batch(sin_input, name="SIN") # Causal mask: (mlen, mlen) with 0 on/below diagonal, -inf above causal_mask_data = torch.zeros(mlen, mlen) @@ -847,36 +727,17 @@ def compile_hf_model( # Per-layer weight inputs (order determines HBM layout) layer_inputs = [] for i in range(n_layers): - wq = prog.input(f"W_q_{i}", shape=(hidden, total_q_dim)) - wo = prog.input(f"W_o_{i}", shape=(total_q_dim, hidden)) - if native_mode: - wk_heads = [] - wv_heads = [] - for kv_h in range(num_kv_heads): - wk_heads.append(prog.input(f"W_k_{i}_h{kv_h}", shape=(hidden, head_dim))) - wv_heads.append(prog.input(f"W_v_{i}_h{kv_h}", shape=(hidden, head_dim))) - li_entry = { - "W_q": wq, "W_o": wo, - "W_k_heads": wk_heads, "W_v_heads": wv_heads, - } - else: - ki = prog.input(f"K_{i}", shape=(seq_len, head_dim)) - vi = prog.input(f"V_{i}", shape=(seq_len, head_dim)) - li_entry = { - "W_q": wq, "W_o": wo, - "K": ki, "V": vi, - } - wg = prog.input(f"W_gate_{i}", shape=(hidden, inter)) - wu = prog.input(f"W_up_{i}", shape=(hidden, inter)) - wd = prog.input(f"W_down_{i}", shape=(inter, hidden)) - li_entry.update({"W_gate": wg, "W_up": wu, "W_down": wd}) + li_entry = {"W_k_heads": [], "W_v_heads": []} + for tensor_name, tensor in _layer_tensor_entries(i, all_weights[i], num_kv_heads): + var = prog.input(tensor_name, shape=tuple(tensor.shape)) + if tensor_name.startswith(f"W_k_{i}_h"): + li_entry["W_k_heads"].append(var) + elif tensor_name.startswith(f"W_v_{i}_h"): + li_entry["W_v_heads"].append(var) + else: + li_entry[tensor_name[: tensor_name.rfind(f"_{i}")]] = var layer_inputs.append(li_entry) - # lm_head weight input (after all layer weights in HBM layout) - lm_head_input = None - if include_lm_head and lm_head_weight is not None: - lm_head_input = prog.input("W_lm_head", shape=(hidden, vocab_size)) - # Load activations to VRAM X_batch = prog.load_batch(x_input, name="X") POS_batch = prog.load_batch(pos_input, name="POS") @@ -887,8 +748,7 @@ def compile_hf_model( # The residual scratch buffer must be placed ABOVE this region. _ffn_intermediate_end = seq_len * hidden + 2 * inter * seq_len _current_bump = 2 * seq_len * hidden # X + POS already allocated - if native_mode: - _current_bump += 2 * seq_len * head_dim # COS + SIN loaded to VRAM + _current_bump += 2 * seq_len * head_dim # COS + SIN loaded to VRAM _current_bump += mlen * mlen # CAUSAL_MASK loaded to VRAM if _current_bump < _ffn_intermediate_end: _pad_elems = _ffn_intermediate_end - _current_bump @@ -908,163 +768,45 @@ def compile_hf_model( li = layer_inputs[i] # Layer progress marker (visible in non-quiet emulator output) - prog._compiler.emit_comment(f"=== LAYER {i}/{n_layers} START ===") - - # --- Attention block --- - # Save residual: scratch = current (zero then add) - prog.vram_fill_zero(scratch) - prog.vram_add(scratch, current) - - # Norm (in-place on current) - prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) - - if native_mode: - # Q projection: current (seq, hidden) @ W_q (hidden, total_q_dim) - Q = _linear_projection(prog, current, li["W_q"], f"Q_{i}") - - # Per-KV-head: compute K/V on-chip, store to HBM, then run Q-heads - q_full_addr = prog._compiler.get_vram_addr(Q.name) - - # Allocate O_full for concatenated head outputs - O_full = prog.alloc(f"O_full_{i}", seq_len, total_q_dim) - o_full_addr = prog._compiler.get_vram_addr(O_full.name) - - # Store K/V InputVars for each KV head (filled during loop below) - kv_stored = [] - - for kv_h in range(num_kv_heads): - # K projection: current (seq, hidden) @ W_k_h (hidden, head_dim) - K_h = _linear_projection(prog, current, li["W_k_heads"][kv_h], f"K_{i}_h{kv_h}") - # V projection: current (seq, hidden) @ W_v_h (hidden, head_dim) - V_h = _linear_projection(prog, current, li["W_v_heads"][kv_h], f"V_{i}_h{kv_h}") - - # RoPE on K_h: K_rot = linear(K_h, R_rope), then K_h = K_h * cos + K_rot * sin - K_rot = _linear_projection(prog, K_h, r_input, f"K_rot_{i}_h{kv_h}") - prog.rope(K_h, K_rot, COS, SIN) - prog.free_tensor(K_rot) - - # Store K/V from VRAM to HBM (auto-allocates HBM space) - K_stored = prog.store(K_h, name=f"K_stored_{i}_h{kv_h}") - V_stored = prog.store(V_h, name=f"V_stored_{i}_h{kv_h}") - kv_stored.append((K_stored, V_stored)) - - # Free K/V VRAM — data is in HBM now - prog.free_tensor(K_h) - prog.free_tensor(V_h) - - for h in range(num_heads): - kv_h = h // ratio - K_stored, V_stored = kv_stored[kv_h] - - # View for this head's Q slice: column block h in Q_full - q_h_addr = q_full_addr + h * seq_len * mlen - Q_h = prog.alloc_at(f"Q_h{h}_{i}", seq_len, head_dim, q_h_addr) - - # RoPE on Q_h: Q_rot = linear(Q_h, R_rope), then Q_h = Q_h * cos + Q_rot * sin - Q_rot = _linear_projection(prog, Q_h, r_input, f"Q_rot_{i}_h{h}") - prog.rope(Q_h, Q_rot, COS, SIN) - prog.free_tensor(Q_rot) - - # Single-head flash attention (K/V read from HBM) with causal mask - O_h = ops.flash_attention(prog, Q_h, K_stored, V_stored, scale, - causal_mask=CAUSAL_MASK) - - # Copy O_h to the right column block of O_full - o_h_dest_addr = o_full_addr + h * seq_len * mlen - O_h_dest = prog.alloc_at(f"O_dest_h{h}_{i}", seq_len, head_dim, o_h_dest_addr) - prog.vram_fill_zero(O_h_dest) - prog.vram_add(O_h_dest, O_h) - - # Free flash_attention intermediates (S, PV, O) to reclaim VRAM - for _tmp_name in ("O", "S", "PV"): - _tmp_var = prog._tensors.get(_tmp_name) - if _tmp_var is not None: - prog.free_tensor(_tmp_var) - prog._tensors.pop(_tmp_name, None) - - # O projection: O_full @ W_o -> O_proj (seq, hidden) - O_proj = _linear_projection(prog, O_full, li["W_o"], f"O_proj_{i}") - - # Attention residual: O_proj += scratch - prog.vram_add(O_proj, scratch) - current_after_attn = O_proj - else: - # Legacy: X is Q directly (no projections), with causal mask - O = ops.flash_attention(prog, current, li["K"], li["V"], scale, - causal_mask=CAUSAL_MASK) - prog.vram_add(O, scratch) - current_after_attn = O - - # --- FFN block --- - # Save residual: scratch = current_after_attn (zero then add) - prog.vram_fill_zero(scratch) - prog.vram_add(scratch, current_after_attn) - - # Norm (in-place) - prog.rms_norm(current_after_attn, eps_offset=3, reci_hid_offset=4) - - # FFN (in-place) - ops.ffn(prog, current_after_attn, li["W_gate"], li["W_up"], li["W_down"]) - - # FFN residual - prog.vram_add(current_after_attn, scratch) + prog.emit_comment(f"=== LAYER {i}/{n_layers} START ===") + + current_after_attn = _emit_attention_block( + prog, + current, + li, + (r_input, COS, SIN), + CAUSAL_MASK, + scratch, + scale, + i, + seq_len, + head_dim, + total_q_dim, + num_heads, + num_kv_heads, + ratio, + ) - current = current_after_attn # carry forward - prog._compiler.emit_comment(f"=== LAYER {i}/{n_layers} COMPLETE ===") + current = _emit_ffn_block(prog, current_after_attn, li, scratch) + prog.emit_comment(f"=== LAYER {i}/{n_layers} COMPLETE ===") # Final norm prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) - # lm_head projection (optional): logits = linear(current, W_lm_head) - if include_lm_head and lm_head_input is not None: - logits = _linear_projection(prog, current, lm_head_input, "logits") - current = logits - isa_code = prog.compile() isa_code = _fix_large_immediates(isa_code) lines = isa_code.splitlines() print(f"\nGenerated {len(lines)} lines of ISA code") # ----------------------------------------------------------- build return - input_tensors = {"X": token_embeds, "POS": pos_weight} - data_order = ["X", "POS"] - if native_mode: - input_tensors["R_rope"] = R_matrix - input_tensors["COS"] = cos_table - input_tensors["SIN"] = sin_table - data_order.extend(["R_rope", "COS", "SIN"]) + input_tensors = {"X": token_embeds, "POS": pos_weight, "R_rope": R_matrix, "COS": cos_table, "SIN": sin_table} + data_order = ["X", "POS", "R_rope", "COS", "SIN"] input_tensors["causal_mask"] = causal_mask_data data_order.append("causal_mask") for i in range(n_layers): - input_tensors[f"W_q_{i}"] = all_weights[i]["W_q"] - input_tensors[f"W_o_{i}"] = all_weights[i]["W_o"] - if native_mode: - for kv_h in range(num_kv_heads): - input_tensors[f"W_k_{i}_h{kv_h}"] = all_weights[i]["W_k_heads"][kv_h] - input_tensors[f"W_v_{i}_h{kv_h}"] = all_weights[i]["W_v_heads"][kv_h] - else: - input_tensors[f"K_{i}"] = K_mats[i] - input_tensors[f"V_{i}"] = V_mats[i] - input_tensors[f"W_gate_{i}"] = all_weights[i]["W_gate"] - input_tensors[f"W_up_{i}"] = all_weights[i]["W_up"] - input_tensors[f"W_down_{i}"] = all_weights[i]["W_down"] - - kv_keys = [] - if native_mode: - for kv_h in range(num_kv_heads): - kv_keys.extend([f"W_k_{i}_h{kv_h}", f"W_v_{i}_h{kv_h}"]) - else: - kv_keys = [f"K_{i}", f"V_{i}"] - data_order.extend([ - f"W_q_{i}", f"W_o_{i}", - *kv_keys, - f"W_gate_{i}", f"W_up_{i}", f"W_down_{i}", - ]) - - # lm_head weight in input_tensors + data_order - if include_lm_head and lm_head_weight is not None: - input_tensors["W_lm_head"] = lm_head_weight - data_order.append("W_lm_head") + for name, tensor in _layer_tensor_entries(i, all_weights[i], num_kv_heads): + input_tensors[name] = tensor + data_order.append(name) # FPRAM layout (same as single-layer decoder): # slot 0 = 0.0 (reserved) @@ -1077,13 +819,9 @@ def compile_hf_model( fp_preload = [0.0, scale, float("-inf"), eps, 1.0 / hidden, 1.0] + [0.0] * 4 # Result is at current's VRAM location - o_vram_addr = prog._compiler.get_vram_addr(current.name) + o_vram_addr = prog.get_vram_addr(current.name) - # Output dimensions depend on whether lm_head is included - if include_lm_head and lm_head_weight is not None: - out_features = vocab_size - else: - out_features = hidden + out_features = hidden comparison_params = { "start_row_idx": o_vram_addr // mlen, @@ -1095,17 +833,14 @@ def compile_hf_model( } info = { - "model_type": native_cfg["model_type"], + "model_type": model_cfg["model_type"], "hidden_size": hidden, "inter_dim": inter, "num_layers": n_layers, "seq_len": seq_len, "head_dim": head_dim, - "num_heads": num_heads if native_mode else 1, - "num_kv_heads": num_kv_heads if native_mode else 1, - "native_mode": native_mode, - "include_lm_head": include_lm_head and lm_head_weight is not None, - "vocab_size": vocab_size if (include_lm_head and lm_head_weight is not None) else None, + "num_heads": num_heads, + "num_kv_heads": num_kv_heads, "mlen": mlen, "blen": blen, "isa_lines": len(lines), @@ -1113,9 +848,6 @@ def compile_hf_model( print(f"\nCompilation complete: {info['isa_lines']} ISA lines, " f"{n_layers} layers, output at VRAM row {o_vram_addr // mlen}") - if include_lm_head and lm_head_weight is not None: - print(f" lm_head: output shape=({seq_len}, {vocab_size})") - return { "isa": isa_code, "golden_output": golden_out, @@ -1127,145 +859,3 @@ def compile_hf_model( "info": info, "golden_precision": golden_precision, } - - -# --------------------------------------------------------------------------- -# Convenience: compile + run emulator + compare -# --------------------------------------------------------------------------- -def compile_and_run( - model, - build_dir, - **kwargs, -) -> dict: - """Compile, run emulator, and compare against golden. - - Convenience wrapper that calls compile_hf_model, sets up simulation - environment, runs the Rust transactional emulator, and compares output. - - Args: - model: nn.Module (HF CausalLM model, already loaded) - build_dir: Directory for simulation artifacts - **kwargs: Forwarded to compile_hf_model (seq_len, hidden_size, etc.) - - Returns: - dict with compilation info + comparison results including - 'allclose_match_rate' percentage. - """ - from transactional_emulator.tools.create_sim_env import create_sim_env - from compiler.sim_env_utils.build_env import create_mem_for_sim - from transactional_emulator.testbench.emulator_runner import ( - compare_emulator_output, - ) - - result = compile_hf_model(model, **kwargs) - build_dir = Path(build_dir) - build_dir.mkdir(parents=True, exist_ok=True) - - mlen = kwargs.get("mlen", 64) - blen = kwargs.get("blen", 4) - asm_name = f"model_{result['info']['model_type']}_{result['info']['num_layers']}L" - - # Write sim env files - create_sim_env( - result["input_tensors"], - result["isa"], - {"original_output": result["golden_output"]}, - result["fp_preload"], - build_dir=str(build_dir), - ) - - create_mem_for_sim( - data_size=256, - mode="behave_sim", - asm=asm_name, - data=None, - specified_data_order=result["data_order"], - build_path=build_dir, - ) - - with open(build_dir / "comparison_params.json", "w") as f: - json.dump(result["comparison_params"], f, indent=2) - - with open(build_dir / "generated_asm_code.asm", "w") as f: - f.write(result["isa"]) - - print(f"\nSimulation environment created: {build_dir}") - print(f" Result location: VRAM row {result['comparison_params']['start_row_idx']}") - print(f" Layers: {result['info']['num_layers']}, data_order: {result['data_order']}") - - # Run emulator and compare (don't exit on failure — VRAM stage comparison follows) - from transactional_emulator.testbench.emulator_runner import update_plena_config, run_emulator - update_plena_config(vlen=mlen, mlen=mlen, blen=blen, verbose=False) - print("\n--- Running Rust transactional emulator ---") - run_emulator(build_dir) - - print("\n--- Comparing emulator output vs golden ---") - comp_results, _params = compare_emulator_output(build_dir) - from transactional_emulator.tools.check_mem import print_comparison_results - print_comparison_results(comp_results, verbose=True, comparison_params=_params) - - if comp_results["allclose_pass"]: - print(f"\n[ATen-style {asm_name} test PASSED - ISA generated + emulator verified]") - else: - print(f"\n[ATen-style {asm_name} test FAILED - emulator numerical check failed]") - - # Three-way comparison - golden = result["golden_output"] - hf_gt = result["hf_ground_truth"] - print("\n--- Three-way comparison ---") - if hf_gt is not None and golden is not None: - # HF float32 vs golden (MXFP8 + BF16) - n = min(hf_gt.numel(), golden.numel()) - allclose_hf_vs_gold = ( - torch.isclose(hf_gt.float().flatten()[:n], - golden.float().flatten()[:n], atol=1e-2) - .float().mean().item() * 100 - ) - print(f" HF float32 vs golden (MXFP8+BF16): {allclose_hf_vs_gold:.2f}% allclose") - # Emulator vs golden: reported by compare_emulator_output - emu_match = comp_results.get("allclose_match_rate", None) - if emu_match is not None: - print(f" Emulator vs golden (MXFP8+BF16): {emu_match:.2f}% allclose") - - # VRAM stage comparison: validates each pipeline segment using - # emulator's own intermediates as golden input (immune to accumulation drift) - try: - from compiler.aten.vram_stage_compare import compare_stages - emulator_dir = Path(__file__).parent.parent.parent / "transactional_emulator" - vram_path = str(emulator_dir / "vram_dump.bin") - print("\n--- VRAM stage comparison (authoritative) ---") - stage_results = compare_stages( - vram_path=vram_path, - build_dir=str(build_dir), - hidden=result["info"]["hidden_size"], - inter=result["info"].get("inter_dim", result["info"]["hidden_size"] * 4), - num_heads=result["info"]["num_heads"], - num_kv_heads=result["info"]["num_kv_heads"], - ) - stage_pass = stage_results.get("norm+FFN+norm", 0) >= 99.0 - comp_results["vram_stage_allclose"] = stage_results.get("norm+FFN+norm", None) - comp_results["vram_stage_pass"] = stage_pass - except Exception as e: - print(f" (skipped: {e})") - - comp_results, _params = compare_emulator_output(build_dir) - - # Three-way comparison - golden = result["golden_output"] - hf_gt = result["hf_ground_truth"] - print("\n--- Three-way comparison ---") - if hf_gt is not None and golden is not None: - # HF float32 vs golden (MXFP8 + BF16) - n = min(hf_gt.numel(), golden.numel()) - allclose_hf_vs_gold = ( - torch.isclose(hf_gt.float().flatten()[:n], - golden.float().flatten()[:n], atol=1e-2) - .float().mean().item() * 100 - ) - print(f" HF float32 vs golden (MXFP8+BF16): {allclose_hf_vs_gold:.2f}% allclose") - # Emulator vs golden: reported by compare_emulator_output - emu_match = comp_results.get("allclose_match_rate", None) - if emu_match is not None: - print(f" Emulator vs golden (MXFP8+BF16): {emu_match:.2f}% allclose") - - return {**result["info"], **comp_results} diff --git a/aten/tests/test_plena_compiler.py b/aten/tests/test_plena_compiler.py index 43b7520..1e7efd5 100644 --- a/aten/tests/test_plena_compiler.py +++ b/aten/tests/test_plena_compiler.py @@ -131,7 +131,7 @@ def test_alloc_at_correct_address(): view_addr = x_addr + 2 * 64 * 64 view = prog.alloc_at("X_cb2_view", 64, 64, view_addr) - actual_addr = prog._compiler.get_vram_addr(view.name) + actual_addr = prog.get_vram_addr(view.name) assert actual_addr == view_addr, ( f"alloc_at address mismatch: expected {view_addr}, got {actual_addr}" ) @@ -214,7 +214,7 @@ def test_compile_hf_model_golden_vs_hf(): ) model.eval() - r = compile_hf_model(model, seq_len=64, hidden_size=None, inter_dim=None, num_layers=1) + r = compile_hf_model(model, seq_len=64, num_layers=1) golden = r["golden_output"] hf = r["hf_ground_truth"] @@ -242,7 +242,7 @@ def test_native_compile_assembles(): ) model.eval() - r = compile_hf_model(model, seq_len=64, hidden_size=None, inter_dim=None, num_layers=1) + r = compile_hf_model(model, seq_len=64, num_layers=1) isa = r["isa"] # Assemble — should not raise ValueError (u32 overflow) diff --git a/aten/tests/test_quantization_ablation.py b/aten/tests/test_quantization_ablation.py index 2142df3..58fdb94 100644 --- a/aten/tests/test_quantization_ablation.py +++ b/aten/tests/test_quantization_ablation.py @@ -84,8 +84,8 @@ def test_mxfp8_is_sole_gap_source(): for mode in MODES: r = results[mode] print(f" {mode:<20} {r['allclose']:>11.2f}% {r['mse']:>15.6e}") - print(f"\n MXFP8 weight quantization = 100% of the gap") - print(f" BF16 intermediate precision = 0% of the gap") + print("\n MXFP8 weight quantization = 100% of the gap") + print(" BF16 intermediate precision = 0% of the gap") if __name__ == "__main__": @@ -103,5 +103,5 @@ def test_mxfp8_is_sole_gap_source(): for mode in MODES: r = results[mode] print(f" {mode:<20} {r['allclose']:>11.2f}% {r['mse']:>15.6e}") - print(f"\n Conclusion: MXFP8 weight quantization = 100% of the gap") - print(f" BF16 intermediate precision = 0% of the gap") + print("\n Conclusion: MXFP8 weight quantization = 100% of the gap") + print(" BF16 intermediate precision = 0% of the gap") diff --git a/compiler/__init__.py b/compiler/__init__.py new file mode 100644 index 0000000..a01b45f --- /dev/null +++ b/compiler/__init__.py @@ -0,0 +1,12 @@ +"""Compatibility namespace for legacy ``compiler.*`` imports. + +PLENA_Compiler packages live at the repository root (``aten``, ``generator``, +``asm_templates``, ...), but existing code imports them through ``compiler``. +Keep that import path local to this submodule instead of resolving to the +simulator sibling directory. +""" + +from pathlib import Path + +__path__ = [str(Path(__file__).resolve().parent.parent)] + diff --git a/pyproject.toml b/pyproject.toml index a5a400a..dc610c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = [ + "compiler", "assembler", "asm_templates", "aten", From 5e4bed63e013667fcd6400ed7ffb1d31d69768cd Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 14:03:27 +0100 Subject: [PATCH 22/32] refactor: split ATen frontend references --- aten/model_extract.py | 130 ++++++++++ aten/plena_frontend.py | 544 +++++++++-------------------------------- aten/reference.py | 291 ++++++++++++++++++++++ 3 files changed, 532 insertions(+), 433 deletions(-) create mode 100644 aten/model_extract.py create mode 100644 aten/reference.py diff --git a/aten/model_extract.py b/aten/model_extract.py new file mode 100644 index 0000000..25a07eb --- /dev/null +++ b/aten/model_extract.py @@ -0,0 +1,130 @@ +"""HuggingFace decoder model extraction helpers for the PLENA frontend.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + + +@dataclass(frozen=True) +class ModelConfig: + """Decoder dimensions needed by the PLENA ATen frontend.""" + + hidden_size: int + inter_dim: int + num_heads: int + num_kv_heads: int + head_dim: int + eps: float + rope_theta: float + vocab_size: int | None + model_type: str + + @property + def total_q_dim(self) -> int: + return self.num_heads * self.head_dim + + @property + def head_ratio(self) -> int: + return self.num_heads // self.num_kv_heads + + +@dataclass(frozen=True) +class LayerWeights: + """One decoder layer in PLENA's (in, out) linear-weight convention.""" + + w_q: torch.Tensor + w_o: torch.Tensor + w_k_heads: list[torch.Tensor] + w_v_heads: list[torch.Tensor] + w_gate: torch.Tensor + w_up: torch.Tensor + w_down: torch.Tensor + eps: float + + def tensor_entries(self, layer_idx: int) -> list[tuple[str, torch.Tensor]]: + entries = [ + (f"W_q_{layer_idx}", self.w_q), + (f"W_o_{layer_idx}", self.w_o), + ] + for kv_h, (w_k, w_v) in enumerate(zip(self.w_k_heads, self.w_v_heads, strict=True)): + entries.extend( + [ + (f"W_k_{layer_idx}_h{kv_h}", w_k), + (f"W_v_{layer_idx}_h{kv_h}", w_v), + ] + ) + entries.extend( + [ + (f"W_gate_{layer_idx}", self.w_gate), + (f"W_up_{layer_idx}", self.w_up), + (f"W_down_{layer_idx}", self.w_down), + ] + ) + return entries + + +def find_model_root(model: Any) -> Any: + """Find the transformer backbone (model.model or model.model.text_model).""" + for candidate in [ + getattr(model, "model", None), + getattr(getattr(model, "model", None), "text_model", None), + getattr(model, "language_model", getattr(model, "text_model", None)), + ]: + if candidate is not None and hasattr(candidate, "layers"): + return candidate + raise ValueError(f"Cannot find decoder layers on {type(model).__name__}") + + +def embedding_module(root: Any) -> Any | None: + """Return the token embedding module when the backbone exposes one.""" + return getattr(root, "embed_tokens", getattr(root, "wte", None)) + + +def extract_model_config(model: Any) -> ModelConfig: + """Extract decoder dimensions, resolving text_config for VLM wrappers.""" + config = getattr(model.config, "text_config", model.config) + hidden = config.hidden_size + num_heads = config.num_attention_heads + return ModelConfig( + hidden_size=hidden, + inter_dim=getattr(config, "intermediate_size", 4 * hidden), + num_heads=num_heads, + num_kv_heads=getattr(config, "num_key_value_heads", num_heads), + head_dim=hidden // num_heads, + eps=getattr(config, "rms_norm_eps", 1e-5), + rope_theta=getattr(config, "rope_theta", 10000.0), + vocab_size=getattr(config, "vocab_size", None), + model_type=getattr(config, "model_type", "unknown"), + ) + + +def extract_layer_weights(layer: Any, config: ModelConfig) -> LayerWeights: + """Extract one decoder layer in PLENA's (in, out) convention.""" + hidden = config.hidden_size + total_kv_dim = config.num_kv_heads * config.head_dim + w_k_full = _linear_weight(layer.self_attn.k_proj, hidden, total_kv_dim) + w_v_full = _linear_weight(layer.self_attn.v_proj, hidden, total_kv_dim) + norm = layer.input_layernorm + + return LayerWeights( + w_q=_linear_weight(layer.self_attn.q_proj, hidden, config.total_q_dim), + w_o=_linear_weight(layer.self_attn.o_proj, config.total_q_dim, hidden), + w_k_heads=_split_heads(w_k_full, config.head_dim, config.num_kv_heads), + w_v_heads=_split_heads(w_v_full, config.head_dim, config.num_kv_heads), + w_gate=_linear_weight(layer.mlp.gate_proj, hidden, config.inter_dim), + w_up=_linear_weight(layer.mlp.up_proj, hidden, config.inter_dim), + w_down=_linear_weight(layer.mlp.down_proj, config.inter_dim, hidden), + eps=getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-5)), + ) + + +def _linear_weight(module: Any, rows: int, cols: int) -> torch.Tensor: + """HF Linear stores (out, in); PLENA uses (in, out).""" + return module.weight.detach().T.contiguous()[:rows, :cols] + + +def _split_heads(weight: torch.Tensor, head_dim: int, num_heads: int) -> list[torch.Tensor]: + return [weight[:, h * head_dim:(h + 1) * head_dim].contiguous() for h in range(num_heads)] diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index b77c679..47576fb 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -2,15 +2,38 @@ import math import re +from dataclasses import dataclass +from typing import Any import torch -import torch.nn.functional as F +from compiler.aten.model_extract import ( + LayerWeights, + embedding_module, + extract_layer_weights, + extract_model_config, + find_model_root, +) from compiler.aten.plena_compiler import PlenaCompiler from compiler.aten.ops.registry import OpRegistry, Backend from compiler.aten.ops.plena.linear_ops import linear_projection_plena as _linear_projection +from compiler.aten.reference import ( + ReferencePrecision, + _ksplit_matmul, + _make_rotate_half_matrix, + make_rope_inputs, + quantize_to_mxfp, + run_decoder_reference, +) import compiler.aten.ops as ops -from quant.quantizer.hardware_quantizer.mxfp import _mx_fp_quantize_hardware + +__all__ = [ + "_fix_large_immediates", + "_ksplit_matmul", + "_make_rotate_half_matrix", + "compile_hf_model", + "quantize_to_mxfp", +] _IMM2_BOUND = 1 << 18 # S_ADDI_INT max immediate @@ -45,265 +68,18 @@ def _fix_large_immediates(isa_code: str) -> str: # --------------------------------------------------------------------------- REAL_DATA_RATIO = (8 * 8 + 8) / (8 * 8) -# Hardware K-split tile limit (matches the PLENA linear backend) -_HW_MAX_K_TILES = 4 - - -def quantize_to_mxfp(tensor: torch.Tensor) -> torch.Tensor: - """Quantize tensor to MXFP8 matching HBM hardware format; return dequantized result.""" - orig_shape = tensor.shape - tensor_2d = tensor.float().reshape(-1, tensor.shape[-1]) - bm_x, _, _, _ = _mx_fp_quantize_hardware( - tensor_2d, - width=8, - exponent_width=4, - exponent_bias_width=8, - block_size=[1, 8], - ) - return bm_x.reshape(orig_shape) - - -def _make_rope_tables(seq_len: int, head_dim: int, theta: float = 10000.0): - """Compute RoPE cos/sin tables, shape (seq_len, head_dim).""" - half = head_dim // 2 - freqs = 1.0 / (theta ** (torch.arange(0, half).float() / half)) - positions = torch.arange(seq_len).float() - angles = torch.outer(positions, freqs) - cos_half = torch.cos(angles) - sin_half = torch.sin(angles) - cos = torch.cat([cos_half, cos_half], dim=-1) - sin = torch.cat([sin_half, sin_half], dim=-1) - return cos, sin - - -def _ksplit_matmul(A, B, mlen=64, max_k_tiles=_HW_MAX_K_TILES, to_inter=None, from_inter=None): - """Matrix multiply matching hardware K-split BF16 precision. - - When the K dimension exceeds max_k_tiles * mlen, the hardware splits the - inner product into chunks, writing each partial sum to BF16 VRAM and - accumulating via BF16 add. This function replicates that precision loss - so the golden reference matches the emulator output. - - If the K dimension fits in a single pass (num_k_tiles <= max_k_tiles), - this is equivalent to a single matmul with BF16 cast. - """ - if to_inter is None: - def to_inter(x): - return x.to(torch.bfloat16) - - if from_inter is None: - def from_inter(x): - return x.float() - - k_total = A.shape[1] - num_k_tiles = math.ceil(k_total / mlen) - - if num_k_tiles <= max_k_tiles: - # Single pass — no K-split precision loss - # Hardware path: MXFP8 (HBM) → BF16 (MRAM) → float32 (M_MM) - # Cast B to BF16 then float32 to match MRAM precision loss - return from_inter(to_inter(torch.matmul(from_inter(to_inter(A)), from_inter(to_inter(B))))) - - # K-split: chunk K tiles into groups of max_k_tiles - result = None - k_start = 0 - while k_start < k_total: - k_end = min(k_start + max_k_tiles * mlen, k_total) - A_chunk = A[:, k_start:k_end] - B_chunk = B[k_start:k_end, :] - # Cast B_chunk to BF16 then float32 to match MRAM precision loss - partial = from_inter(to_inter(torch.matmul(from_inter(to_inter(A_chunk)), from_inter(to_inter(B_chunk))))) - if result is None: - result = partial - else: - # Hardware accumulates in BF16 VRAM (vram_block_add_to) - result = from_inter(to_inter(result) + to_inter(partial)) - k_start = k_end - - return result - - -# --------------------------------------------------------------------------- -# RoPE helpers -# --------------------------------------------------------------------------- -def _make_rotate_half_matrix(head_dim: int) -> torch.Tensor: - """Build the (head_dim, head_dim) matrix that computes rotate_half. - - rotate_half(x) = x @ R, where R permutes the halves with a sign flip: - output[:d//2] = -input[d//2:] - output[d//2:] = +input[:d//2] - """ - R = torch.zeros(head_dim, head_dim) - half = head_dim // 2 - for i in range(half): - R[i + half, i] = -1.0 # first half of output = negated second half of input - R[i, i + half] = 1.0 # second half of output = first half of input - return R - - -# --------------------------------------------------------------------------- -# Model structure helpers -# --------------------------------------------------------------------------- -def _linear_weight(module, rows, cols): - """HF Linear stores (out, in); PLENA uses (in, out).""" - return module.weight.detach().T.contiguous()[:rows, :cols] - - -def _split_heads(weight, head_dim, num_heads): - return [weight[:, h * head_dim:(h + 1) * head_dim].contiguous() for h in range(num_heads)] - - -def _find_model_root(model): - """Find the transformer backbone (model.model or model.model.text_model). - - Handles standard CausalLM models and VLMs like SmolVLM2. - """ - for candidate in [ - getattr(model, "model", None), - getattr(getattr(model, "model", None), "text_model", None), - getattr(model, "language_model", getattr(model, "text_model", None)), - ]: - if candidate is not None and hasattr(candidate, "layers"): - return candidate - raise ValueError(f"Cannot find decoder layers on {type(model).__name__}") - - -def _extract_config(model): - """Extract config dimensions from the model, resolving text_config for VLMs.""" - config = getattr(model.config, "text_config", model.config) - hidden = config.hidden_size - inter = getattr(config, "intermediate_size", 4 * hidden) - num_heads = config.num_attention_heads - num_kv_heads = getattr(config, "num_key_value_heads", num_heads) - head_dim = hidden // num_heads - eps = getattr(config, "rms_norm_eps", 1e-5) - rope_theta = getattr(config, "rope_theta", 10000.0) - vocab_size = getattr(config, "vocab_size", None) - return { - "hidden_size": hidden, - "inter_dim": inter, - "num_heads": num_heads, - "num_kv_heads": num_kv_heads, - "head_dim": head_dim, - "eps": eps, - "rope_theta": rope_theta, - "vocab_size": vocab_size, - "model_type": getattr(config, "model_type", "unknown"), - } - - -def _extract_layer_weights(layer, hidden, inter, num_heads, head_dim, num_kv_heads=1): - """Extract one decoder layer in PLENA's (in, out) linear-weight convention.""" - norm = layer.input_layernorm - total_q_dim = num_heads * head_dim - total_kv_dim = num_kv_heads * head_dim - W_k_full = _linear_weight(layer.self_attn.k_proj, hidden, total_kv_dim) - W_v_full = _linear_weight(layer.self_attn.v_proj, hidden, total_kv_dim) - - weights = { - "W_q": _linear_weight(layer.self_attn.q_proj, hidden, total_q_dim), - "W_o": _linear_weight(layer.self_attn.o_proj, total_q_dim, hidden), - "W_k_heads": _split_heads(W_k_full, head_dim, num_kv_heads), - "W_v_heads": _split_heads(W_v_full, head_dim, num_kv_heads), - "W_gate": _linear_weight(layer.mlp.gate_proj, hidden, inter), - "W_up": _linear_weight(layer.mlp.up_proj, hidden, inter), - "W_down": _linear_weight(layer.mlp.down_proj, inter, hidden), - "eps": getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-5)), - } - return weights - - -# --------------------------------------------------------------------------- -# Golden reference helpers (match hardware: MXFP8 HBM + BF16 intermediates) -# --------------------------------------------------------------------------- -def _flash_attn_ref(Q, K, V, scale, causal=False): - """CPU reference: scaled dot-product attention matching hardware BF16 precision. - - Hardware path: - S = Q @ K^T (M_TMM in f32, written to VRAM as BF16 via M_MM_WO) - S *= scale (V_MUL_VF: BF16 * f32 -> BF16) - P = softmax(S) (online softmax on BF16 data) - O = P @ V (M_MM in f32, written to VRAM as BF16 via M_MM_WO) - - We model the key BF16 truncation points to match hardware precision. - """ - # S = Q @ K^T, then truncate to BF16 (M_MM_WO writes BF16) - scores = (Q @ K.T).to(torch.bfloat16).float() * scale - scores = scores.to(torch.bfloat16).float() # V_MUL_VF result is BF16 - if causal: - mask = torch.triu(torch.ones(scores.shape[-2], scores.shape[-1], - device=scores.device), diagonal=1).bool() - scores.masked_fill_(mask, float('-inf')) - # Softmax output written to BF16 VRAM - attn = F.softmax(scores, dim=-1).to(torch.bfloat16).float() - # O = P @ V, result written to BF16 VRAM - return (attn @ V).to(torch.bfloat16).float() - - -def _inter_round(x, to_inter, from_inter): - return from_inter(to_inter(x)) - - -def _hardware_rms_norm_ref(x, eps, to_inter, from_inter): - x_inter = to_inter(x) - rms = torch.rsqrt(from_inter(x_inter).pow(2).mean(-1, keepdim=True) + eps) - return _inter_round(from_inter(x_inter) * rms, to_inter, from_inter) - - -def _hardware_linear_ref(x, weight, mlen, to_inter, from_inter): - return _ksplit_matmul(x, weight, mlen, _HW_MAX_K_TILES, to_inter, from_inter) +@dataclass(frozen=True) +class LayerInputVars: + """PLENA input variables for one extracted decoder layer.""" -def _hardware_rope_ref(x, rope_matrix, cos_table, sin_table, to_inter, from_inter): - x_inter = _inter_round(x, to_inter, from_inter) - x_rot = _inter_round( - torch.matmul(x_inter, _inter_round(rope_matrix, to_inter, from_inter)), - to_inter, - from_inter, - ) - x_cos = _inter_round(x_inter * _inter_round(cos_table, to_inter, from_inter), to_inter, from_inter) - x_rot_sin = _inter_round(x_rot * _inter_round(sin_table, to_inter, from_inter), to_inter, from_inter) - return _inter_round(x_cos + x_rot_sin, to_inter, from_inter) - - -def _hardware_hbm_round_ref(x, quantize_weight, to_inter, from_inter): - return _inter_round(quantize_weight(x), to_inter, from_inter) - - -def _hardware_residual_add_ref(x, residual, to_inter, from_inter): - return _inter_round(x + residual, to_inter, from_inter) - - -def _hardware_ffn_ref(x, W_gate, W_up, W_down, eps, mlen, to_inter, from_inter): - residual = x.clone() - x = _hardware_rms_norm_ref(x, eps, to_inter, from_inter) - up_out = _hardware_linear_ref(x, W_up, mlen, to_inter, from_inter) - gate_out = _hardware_linear_ref(x, W_gate, mlen, to_inter, from_inter) - silu_gate = to_inter( - F.silu(_inter_round(up_out, to_inter, from_inter)) * _inter_round(gate_out, to_inter, from_inter) - ) - x = _hardware_linear_ref(from_inter(silu_gate), W_down, mlen, to_inter, from_inter) - return _hardware_residual_add_ref(_inter_round(x, to_inter, from_inter), residual, to_inter, from_inter) - - -def _fp32_rms_norm_ref(x, eps): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - -def _fp32_linear_ref(x, weight): - return x @ weight.float() - - -def _fp32_rope_ref(x, rope_matrix, cos_table, sin_table): - return x * cos_table.float() + (x @ rope_matrix.float()) * sin_table.float() - - -def _fp32_ffn_ref(x, weights, eps): - residual = x.clone() - x_normed = _fp32_rms_norm_ref(x, eps) - up_out = F.silu(_fp32_linear_ref(x_normed, weights["W_up"])) - gate_out = _fp32_linear_ref(x_normed, weights["W_gate"]) - return _fp32_linear_ref(up_out * gate_out, weights["W_down"]) + residual + w_q: Any + w_o: Any + w_k_heads: list[Any] + w_v_heads: list[Any] + w_gate: Any + w_up: Any + w_down: Any def _save_residual_and_norm(prog, source, scratch): @@ -347,13 +123,13 @@ def _emit_kv_stores(prog, current, layer_inputs, rope_inputs, layer_idx, num_kv_ K_h = _linear_projection( prog, current, - layer_inputs["W_k_heads"][kv_h], + layer_inputs.w_k_heads[kv_h], f"K_{layer_idx}_h{kv_h}", ) V_h = _linear_projection( prog, current, - layer_inputs["W_v_heads"][kv_h], + layer_inputs.w_v_heads[kv_h], f"V_{layer_idx}_h{kv_h}", ) @@ -394,7 +170,7 @@ def _emit_attention_block( ): _save_residual_and_norm(prog, current, scratch) - Q = _linear_projection(prog, current, layer_inputs["W_q"], f"Q_{layer_idx}") + Q = _linear_projection(prog, current, layer_inputs.w_q, f"Q_{layer_idx}") q_full_addr = prog.get_vram_addr(Q.name) O_full = prog.alloc(f"O_full_{layer_idx}", seq_len, total_q_dim) @@ -445,36 +221,38 @@ def _emit_attention_block( ) _free_named_tensors(prog, ("O", "S", "PV")) - O_proj = _linear_projection(prog, O_full, layer_inputs["W_o"], f"O_proj_{layer_idx}") + O_proj = _linear_projection(prog, O_full, layer_inputs.w_o, f"O_proj_{layer_idx}") return _add_residual(prog, O_proj, scratch) def _emit_ffn_block(prog, current, layer_inputs, scratch): _save_residual_and_norm(prog, current, scratch) - ops.ffn(prog, current, layer_inputs["W_gate"], layer_inputs["W_up"], layer_inputs["W_down"]) + ops.ffn(prog, current, layer_inputs.w_gate, layer_inputs.w_up, layer_inputs.w_down) return _add_residual(prog, current, scratch) -def _layer_tensor_entries(layer_idx, weights, num_kv_heads): - entries = [ - (f"W_q_{layer_idx}", weights["W_q"]), - (f"W_o_{layer_idx}", weights["W_o"]), - ] - for kv_h in range(num_kv_heads): - entries.extend( - [ - (f"W_k_{layer_idx}_h{kv_h}", weights["W_k_heads"][kv_h]), - (f"W_v_{layer_idx}_h{kv_h}", weights["W_v_heads"][kv_h]), - ] - ) - entries.extend( - [ - (f"W_gate_{layer_idx}", weights["W_gate"]), - (f"W_up_{layer_idx}", weights["W_up"]), - (f"W_down_{layer_idx}", weights["W_down"]), - ] +def _register_layer_inputs(prog, layer_idx: int, weights: LayerWeights) -> LayerInputVars: + named_vars = {} + w_k_heads = [] + w_v_heads = [] + for tensor_name, tensor in weights.tensor_entries(layer_idx): + var = prog.input(tensor_name, shape=tuple(tensor.shape)) + if tensor_name.startswith(f"W_k_{layer_idx}_h"): + w_k_heads.append(var) + elif tensor_name.startswith(f"W_v_{layer_idx}_h"): + w_v_heads.append(var) + else: + named_vars[tensor_name[: tensor_name.rfind(f"_{layer_idx}")]] = var + + return LayerInputVars( + w_q=named_vars["W_q"], + w_o=named_vars["W_o"], + w_k_heads=w_k_heads, + w_v_heads=w_v_heads, + w_gate=named_vars["W_gate"], + w_up=named_vars["W_up"], + w_down=named_vars["W_down"], ) - return entries # --------------------------------------------------------------------------- @@ -496,16 +274,16 @@ def _verbose(message: str = ""): if verbose: print(message) - # -------------------------------------------------------------- config - model_cfg = _extract_config(model) - hidden = model_cfg["hidden_size"] - inter = model_cfg["inter_dim"] - num_heads = model_cfg["num_heads"] - num_kv_heads = model_cfg["num_kv_heads"] - head_dim = model_cfg["head_dim"] - total_q_dim = num_heads * head_dim + model_cfg = extract_model_config(model) + hidden = model_cfg.hidden_size + inter = model_cfg.inter_dim + num_heads = model_cfg.num_heads + num_kv_heads = model_cfg.num_kv_heads + head_dim = model_cfg.head_dim + total_q_dim = model_cfg.total_q_dim + ratio = model_cfg.head_ratio - root = _find_model_root(model) + root = find_model_root(model) layers = root.layers n_layers = num_layers if num_layers is not None else len(layers) assert layer_idx_start + n_layers <= len(layers), ( @@ -514,12 +292,10 @@ def _verbose(message: str = ""): ) scale = 1.0 / math.sqrt(head_dim) - - # ----------------------------------------------------------- embedding - embed = getattr(root, "embed_tokens", getattr(root, "wte", None)) + embed = embedding_module(root) print("=" * 80) - print(f"Model Compiler - {model_cfg['model_type']} ({n_layers} layer{'s' if n_layers != 1 else ''})") + print(f"Model Compiler - {model_cfg.model_type} ({n_layers} layer{'s' if n_layers != 1 else ''})") print( f" decoder: hidden={hidden}, inter={inter}, heads={num_heads}/{num_kv_heads}, " f"head_dim={head_dim}" @@ -532,29 +308,20 @@ def _verbose(message: str = ""): all_weights = [] for i in range(n_layers): layer_module = layers[layer_idx_start + i] - w = _extract_layer_weights( - layer_module, - hidden, - inter, - num_heads, - head_dim, - num_kv_heads=num_kv_heads, - ) + w = extract_layer_weights(layer_module, model_cfg) all_weights.append(w) _verbose( - f" Layer {i}: W_q={w['W_q'].shape}, W_o={w['W_o'].shape}, " - f"W_gate={w['W_gate'].shape}, " - f"K_heads={len(w['W_k_heads'])}x{w['W_k_heads'][0].shape}, eps={w['eps']}" + f" Layer {i}: W_q={w.w_q.shape}, W_o={w.w_o.shape}, " + f"W_gate={w.w_gate.shape}, " + f"K_heads={len(w.w_k_heads)}x{w.w_k_heads[0].shape}, eps={w.eps}" ) - eps = all_weights[0]["eps"] + eps = all_weights[0].eps - # ----------------------------------------------------------- test data torch.manual_seed(seed) - # Embedding lookup: use real HF embedding table if available if embed is not None: - input_ids = torch.randint(0, model_cfg["vocab_size"] or 32000, (seq_len,)) + input_ids = torch.randint(0, model_cfg.vocab_size or 32000, (seq_len,)) with torch.no_grad(): token_embeds = embed(input_ids).float() if token_embeds.dim() == 3: @@ -577,124 +344,44 @@ def _verbose(message: str = ""): for i in range(n_layers): for kv_h in range(num_kv_heads): _verbose( - f" W_k_{i}_h{kv_h}: {all_weights[i]['W_k_heads'][kv_h].shape}, " - f"W_v_{i}_h{kv_h}: {all_weights[i]['W_v_heads'][kv_h].shape}" + f" W_k_{i}_h{kv_h}: {all_weights[i].w_k_heads[kv_h].shape}, " + f"W_v_{i}_h{kv_h}: {all_weights[i].w_v_heads[kv_h].shape}" ) print(f"attn_scale: {scale:.6f}") - # ----------------------------------------------------------- golden ref - _do_quant = golden_precision in ("hardware", "no_bf16") - _do_bf16 = golden_precision in ("hardware", "no_weight_quant") - _qw = quantize_to_mxfp if _do_quant else (lambda x: x) - _to_inter = (lambda x: x.to(torch.bfloat16)) if _do_bf16 else (lambda x: x) - _from_inter = (lambda x: x.float()) if _do_bf16 else (lambda x: x) - _prec_label = {"hardware": "MXFP8 weights + BF16 intermediates", - "no_weight_quant": "float32 weights + BF16 intermediates", - "no_bf16": "MXFP8 weights + float32 intermediates", - "fp32": "float32 weights + float32 intermediates"}[golden_precision] - print(f"\nComputing CPU golden reference ({_prec_label})") - - W_k_q_heads = [[_qw(all_weights[i]["W_k_heads"][h]) for h in range(num_kv_heads)] for i in range(n_layers)] - W_v_q_heads = [[_qw(all_weights[i]["W_v_heads"][h]) for h in range(num_kv_heads)] for i in range(n_layers)] - R_matrix = _make_rotate_half_matrix(head_dim) - R_rope_q = _qw(R_matrix) - cos_table, sin_table = _make_rope_tables(seq_len, head_dim, model_cfg["rope_theta"]) - - X_gold = _qw(token_embeds.clone()) + _qw(pos_weight) # embedding_add (MXFP8-quantized, matching HBM) - ratio = num_heads // num_kv_heads - - for i in range(n_layers): - w = all_weights[i] - W_q_q = _qw(w["W_q"]) - W_o_q = _qw(w["W_o"]) - W_gate_q = _qw(w["W_gate"]) - W_up_q = _qw(w["W_up"]) - W_down_q = _qw(w["W_down"]) - - # --- Attention block --- - residual = X_gold.clone() - X_gold = _hardware_rms_norm_ref(X_gold, eps, _to_inter, _from_inter) - Q_gold = _inter_round(_hardware_linear_ref(X_gold, W_q_q, mlen, _to_inter, _from_inter), _to_inter, _from_inter) - - K_q_heads_i = [] - V_q_heads_i = [] - for kv_h in range(num_kv_heads): - K_h = _hardware_linear_ref(X_gold, W_k_q_heads[i][kv_h], mlen, _to_inter, _from_inter) - V_h = _hardware_linear_ref(X_gold, W_v_q_heads[i][kv_h], mlen, _to_inter, _from_inter) - K_h = _hardware_rope_ref(K_h, R_rope_q, cos_table, sin_table, _to_inter, _from_inter) - K_q_heads_i.append(_hardware_hbm_round_ref(K_h, _qw, _to_inter, _from_inter)) - V_q_heads_i.append(_hardware_hbm_round_ref(V_h, _qw, _to_inter, _from_inter)) - - O_heads = [] - for h in range(num_heads): - kv_h = h // ratio - Q_h = Q_gold[:, h * head_dim:(h + 1) * head_dim] - Q_h = _hardware_rope_ref(Q_h, R_rope_q, cos_table, sin_table, _to_inter, _from_inter) - O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale, causal=True) - O_heads.append(O_h) - attn_out = _inter_round(torch.cat(O_heads, dim=1), _to_inter, _from_inter) - O_gold = _inter_round(_hardware_linear_ref(attn_out, W_o_q, mlen, _to_inter, _from_inter), _to_inter, _from_inter) - X_gold = _hardware_residual_add_ref(O_gold, residual, _to_inter, _from_inter) - - # --- FFN block --- - X_gold = _hardware_ffn_ref(X_gold, W_gate_q, W_up_q, W_down_q, eps, mlen, _to_inter, _from_inter) - - _verbose(f" After layer {i}: X_gold[0,:4] = {X_gold[0, :4].tolist()}") - - # Final norm - X_gold = _hardware_rms_norm_ref(X_gold, eps, _to_inter, _from_inter) - - golden_out = X_gold + R_matrix, cos_table, sin_table = make_rope_inputs(seq_len, model_cfg) + + golden_policy = ReferencePrecision.from_mode(golden_precision) + print(f"\nComputing CPU golden reference ({golden_policy.label})") + golden_out = run_decoder_reference( + token_embeds, + pos_weight, + all_weights, + model_cfg, + R_matrix, + cos_table, + sin_table, + mlen=mlen, + precision=golden_policy, + trace=lambda i, x: _verbose(f" After layer {i}: X_gold[0,:4] = {x[0, :4].tolist()}"), + ) print(f" golden_out: {golden_out.shape}") _verbose(f" golden_out[0,:4]: {golden_out[0, :4].tolist()}") - # ----------------------------------------------------------- HF ground truth - # Same pipeline as golden but pure float32: no MXFP8 quantization, no BF16 casting. - # This is the float32 reference for the same decoder blocks being compiled. print(f"\nComputing HF reference (float32, {n_layers} layer{'s' if n_layers != 1 else ''}, no quantization)") with torch.no_grad(): - X_hf = token_embeds.clone() + pos_weight # embedding_add - - for i in range(n_layers): - w = all_weights[i] - eps_i = w["eps"] - - residual = X_hf.clone() - X_normed = _fp32_rms_norm_ref(X_hf, eps_i) - - Q_hf = _fp32_linear_ref(X_normed, w["W_q"]) - - K_hf_heads = [] - V_hf_heads = [] - for kv_h in range(num_kv_heads): - K_hf_heads.append( - _fp32_rope_ref( - _fp32_linear_ref(X_normed, w["W_k_heads"][kv_h]), - R_matrix, - cos_table, - sin_table, - ) - ) - V_hf_heads.append(_fp32_linear_ref(X_normed, w["W_v_heads"][kv_h])) - - O_heads_hf = [] - for h in range(num_heads): - kv_h = h // ratio - Q_h = Q_hf[:, h * head_dim:(h + 1) * head_dim] - Q_h = _fp32_rope_ref(Q_h, R_matrix, cos_table, sin_table) - O_h = _flash_attn_ref(Q_h, K_hf_heads[kv_h], V_hf_heads[kv_h], scale, causal=True) - O_heads_hf.append(O_h) - attn_out_hf = torch.cat(O_heads_hf, dim=1) - O_hf = _fp32_linear_ref(attn_out_hf, w["W_o"]) - X_hf = O_hf + residual - - X_hf = _fp32_ffn_ref(X_hf, w, eps_i) - - _verbose(f" After layer {i}: X_hf[0,:4] = {X_hf[0, :4].tolist()}") - - X_hf = _fp32_rms_norm_ref(X_hf, eps) - - hf_ground_truth = X_hf + hf_ground_truth = run_decoder_reference( + token_embeds, + pos_weight, + all_weights, + model_cfg, + R_matrix, + cos_table, + sin_table, + mlen=mlen, + precision=ReferencePrecision.from_mode("hf_fp32"), + trace=lambda i, x: _verbose(f" After layer {i}: X_hf[0,:4] = {x[0, :4].tolist()}"), + ) print(f" hf_ground_truth: {hf_ground_truth.shape}") _verbose(f" hf_ground_truth[0,:4]: {hf_ground_truth[0, :4].tolist()}") @@ -727,16 +414,7 @@ def _verbose(message: str = ""): # Per-layer weight inputs (order determines HBM layout) layer_inputs = [] for i in range(n_layers): - li_entry = {"W_k_heads": [], "W_v_heads": []} - for tensor_name, tensor in _layer_tensor_entries(i, all_weights[i], num_kv_heads): - var = prog.input(tensor_name, shape=tuple(tensor.shape)) - if tensor_name.startswith(f"W_k_{i}_h"): - li_entry["W_k_heads"].append(var) - elif tensor_name.startswith(f"W_v_{i}_h"): - li_entry["W_v_heads"].append(var) - else: - li_entry[tensor_name[: tensor_name.rfind(f"_{i}")]] = var - layer_inputs.append(li_entry) + layer_inputs.append(_register_layer_inputs(prog, i, all_weights[i])) # Load activations to VRAM X_batch = prog.load_batch(x_input, name="X") @@ -804,7 +482,7 @@ def _verbose(message: str = ""): input_tensors["causal_mask"] = causal_mask_data data_order.append("causal_mask") for i in range(n_layers): - for name, tensor in _layer_tensor_entries(i, all_weights[i], num_kv_heads): + for name, tensor in all_weights[i].tensor_entries(i): input_tensors[name] = tensor data_order.append(name) @@ -833,7 +511,7 @@ def _verbose(message: str = ""): } info = { - "model_type": model_cfg["model_type"], + "model_type": model_cfg.model_type, "hidden_size": hidden, "inter_dim": inter, "num_layers": n_layers, diff --git a/aten/reference.py b/aten/reference.py new file mode 100644 index 0000000..bde2020 --- /dev/null +++ b/aten/reference.py @@ -0,0 +1,291 @@ +"""CPU decoder references used by the PLENA ATen frontend.""" + +from __future__ import annotations + +import math +from collections.abc import Callable +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from compiler.aten.model_extract import LayerWeights, ModelConfig +from quant.quantizer.hardware_quantizer.mxfp import _mx_fp_quantize_hardware + + +_HW_MAX_K_TILES = 4 + + +@dataclass(frozen=True) +class ReferencePrecision: + """Precision policy for the shared decoder reference runner.""" + + name: str + label: str + quantize_hbm: bool + bf16_intermediates: bool + use_ksplit: bool + + @classmethod + def from_mode(cls, mode: str) -> ReferencePrecision: + modes = { + "hardware": cls("hardware", "MXFP8 weights + BF16 intermediates", True, True, True), + "no_weight_quant": cls( + "no_weight_quant", + "float32 weights + BF16 intermediates", + False, + True, + True, + ), + "no_bf16": cls("no_bf16", "MXFP8 weights + float32 intermediates", True, False, True), + "fp32": cls("fp32", "float32 weights + float32 intermediates", False, False, True), + "hf_fp32": cls("hf_fp32", "float32, no quantization", False, False, False), + } + try: + return modes[mode] + except KeyError as exc: + valid = ", ".join(modes) + raise ValueError(f"Unknown reference precision '{mode}'. Expected one of: {valid}") from exc + + def quantize(self, tensor: torch.Tensor) -> torch.Tensor: + return quantize_to_mxfp(tensor) if self.quantize_hbm else tensor + + def to_inter(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(torch.bfloat16) if self.bf16_intermediates else tensor + + def from_inter(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.float() if self.bf16_intermediates else tensor + + +def quantize_to_mxfp(tensor: torch.Tensor) -> torch.Tensor: + """Quantize tensor to MXFP8 matching HBM hardware format; return dequantized result.""" + orig_shape = tensor.shape + tensor_2d = tensor.float().reshape(-1, tensor.shape[-1]) + bm_x, _, _, _ = _mx_fp_quantize_hardware( + tensor_2d, + width=8, + exponent_width=4, + exponent_bias_width=8, + block_size=[1, 8], + ) + return bm_x.reshape(orig_shape) + + +def _make_rope_tables(seq_len: int, head_dim: int, theta: float = 10000.0): + """Compute RoPE cos/sin tables, shape (seq_len, head_dim).""" + half = head_dim // 2 + freqs = 1.0 / (theta ** (torch.arange(0, half).float() / half)) + positions = torch.arange(seq_len).float() + angles = torch.outer(positions, freqs) + cos_half = torch.cos(angles) + sin_half = torch.sin(angles) + cos = torch.cat([cos_half, cos_half], dim=-1) + sin = torch.cat([sin_half, sin_half], dim=-1) + return cos, sin + + +def _make_rotate_half_matrix(head_dim: int) -> torch.Tensor: + """Build the (head_dim, head_dim) matrix that computes rotate_half.""" + rotate = torch.zeros(head_dim, head_dim) + half = head_dim // 2 + for i in range(half): + rotate[i + half, i] = -1.0 + rotate[i, i + half] = 1.0 + return rotate + + +def make_rope_inputs(seq_len: int, config: ModelConfig) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build the rotate_half matrix and RoPE tables used by CPU and PLENA paths.""" + rotate = _make_rotate_half_matrix(config.head_dim) + cos_table, sin_table = _make_rope_tables(seq_len, config.head_dim, config.rope_theta) + return rotate, cos_table, sin_table + + +def run_decoder_reference( + token_embeds: torch.Tensor, + pos_weight: torch.Tensor, + weights: list[LayerWeights], + config: ModelConfig, + rope_matrix: torch.Tensor, + cos_table: torch.Tensor, + sin_table: torch.Tensor, + *, + mlen: int, + precision: ReferencePrecision, + trace: Callable[[int, torch.Tensor], None] | None = None, +) -> torch.Tensor: + """Run the compiled decoder blocks under a given precision policy.""" + quantize = precision.quantize + x = quantize(token_embeds.clone()) + quantize(pos_weight) + rope_ref = quantize(rope_matrix) + + for layer_idx, layer in enumerate(weights): + x = _attention_block_ref( + x, + layer, + config, + rope_ref, + cos_table, + sin_table, + mlen, + precision, + ) + x = _ffn_block_ref(x, layer, mlen, precision) + + if trace is not None: + trace(layer_idx, x) + + return _rms_norm_ref(x, weights[0].eps, precision) + + +def _attention_block_ref( + x: torch.Tensor, + layer: LayerWeights, + config: ModelConfig, + rope_matrix: torch.Tensor, + cos_table: torch.Tensor, + sin_table: torch.Tensor, + mlen: int, + precision: ReferencePrecision, +) -> torch.Tensor: + quantize = precision.quantize + residual = x.clone() + x_normed = _rms_norm_ref(x, layer.eps, precision) + q_full = _round(_linear_ref(x_normed, quantize(layer.w_q), mlen, precision), precision) + + k_heads = [] + v_heads = [] + for kv_h in range(config.num_kv_heads): + k_h = _linear_ref(x_normed, quantize(layer.w_k_heads[kv_h]), mlen, precision) + v_h = _linear_ref(x_normed, quantize(layer.w_v_heads[kv_h]), mlen, precision) + k_h = _rope_ref(k_h, rope_matrix, cos_table, sin_table, precision) + k_heads.append(_hbm_round_ref(k_h, precision)) + v_heads.append(_hbm_round_ref(v_h, precision)) + + scale = 1.0 / math.sqrt(config.head_dim) + o_heads = [] + for h in range(config.num_heads): + kv_h = h // config.head_ratio + q_h = q_full[:, h * config.head_dim:(h + 1) * config.head_dim] + q_h = _rope_ref(q_h, rope_matrix, cos_table, sin_table, precision) + o_heads.append(_flash_attn_ref(q_h, k_heads[kv_h], v_heads[kv_h], scale, causal=True)) + + attn_out = _round(torch.cat(o_heads, dim=1), precision) + o_proj = _round(_linear_ref(attn_out, quantize(layer.w_o), mlen, precision), precision) + return _residual_add_ref(o_proj, residual, precision) + + +def _ffn_block_ref( + x: torch.Tensor, + layer: LayerWeights, + mlen: int, + precision: ReferencePrecision, +) -> torch.Tensor: + quantize = precision.quantize + residual = x.clone() + x_normed = _rms_norm_ref(x, layer.eps, precision) + up_out = _linear_ref(x_normed, quantize(layer.w_up), mlen, precision) + gate_out = _linear_ref(x_normed, quantize(layer.w_gate), mlen, precision) + silu_gate = precision.to_inter(F.silu(_round(up_out, precision)) * _round(gate_out, precision)) + x = _linear_ref(precision.from_inter(silu_gate), quantize(layer.w_down), mlen, precision) + return _residual_add_ref(_round(x, precision), residual, precision) + + +def _round(x: torch.Tensor, precision: ReferencePrecision) -> torch.Tensor: + return precision.from_inter(precision.to_inter(x)) + + +def _rms_norm_ref(x: torch.Tensor, eps: float, precision: ReferencePrecision) -> torch.Tensor: + x_inter = precision.to_inter(x) + rms = torch.rsqrt(precision.from_inter(x_inter).pow(2).mean(-1, keepdim=True) + eps) + return _round(precision.from_inter(x_inter) * rms, precision) + + +def _linear_ref( + x: torch.Tensor, + weight: torch.Tensor, + mlen: int, + precision: ReferencePrecision, +) -> torch.Tensor: + if not precision.use_ksplit: + return x @ weight.float() + return _ksplit_matmul( + x, + weight, + mlen=mlen, + max_k_tiles=_HW_MAX_K_TILES, + to_inter=precision.to_inter, + from_inter=precision.from_inter, + ) + + +def _rope_ref( + x: torch.Tensor, + rope_matrix: torch.Tensor, + cos_table: torch.Tensor, + sin_table: torch.Tensor, + precision: ReferencePrecision, +) -> torch.Tensor: + x_inter = _round(x, precision) + x_rot = _round( + torch.matmul(x_inter, _round(rope_matrix, precision)), + precision, + ) + x_cos = _round(x_inter * _round(cos_table, precision), precision) + x_rot_sin = _round(x_rot * _round(sin_table, precision), precision) + return _round(x_cos + x_rot_sin, precision) + + +def _hbm_round_ref(x: torch.Tensor, precision: ReferencePrecision) -> torch.Tensor: + return _round(precision.quantize(x), precision) + + +def _residual_add_ref( + x: torch.Tensor, + residual: torch.Tensor, + precision: ReferencePrecision, +) -> torch.Tensor: + return _round(x + residual, precision) + + +def _flash_attn_ref(Q, K, V, scale, causal=False): + """CPU reference: scaled dot-product attention matching hardware BF16 precision.""" + scores = (Q @ K.T).to(torch.bfloat16).float() * scale + scores = scores.to(torch.bfloat16).float() + if causal: + mask = torch.triu(torch.ones(scores.shape[-2], scores.shape[-1], device=scores.device), diagonal=1).bool() + scores.masked_fill_(mask, float("-inf")) + attn = F.softmax(scores, dim=-1).to(torch.bfloat16).float() + return (attn @ V).to(torch.bfloat16).float() + + +def _ksplit_matmul(A, B, mlen=64, max_k_tiles=_HW_MAX_K_TILES, to_inter=None, from_inter=None): + """Matrix multiply matching hardware K-split BF16 precision.""" + if to_inter is None: + def to_inter(x): + return x.to(torch.bfloat16) + + if from_inter is None: + def from_inter(x): + return x.float() + + k_total = A.shape[1] + num_k_tiles = math.ceil(k_total / mlen) + + if num_k_tiles <= max_k_tiles: + return from_inter(to_inter(torch.matmul(from_inter(to_inter(A)), from_inter(to_inter(B))))) + + result = None + k_start = 0 + while k_start < k_total: + k_end = min(k_start + max_k_tiles * mlen, k_total) + a_chunk = A[:, k_start:k_end] + b_chunk = B[k_start:k_end, :] + partial = from_inter(to_inter(torch.matmul(from_inter(to_inter(a_chunk)), from_inter(to_inter(b_chunk))))) + if result is None: + result = partial + else: + result = from_inter(to_inter(result) + to_inter(partial)) + k_start = k_end + + return result From f989e1c5e3a5fa5990409195c8e30735009aca8e Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 14:28:46 +0100 Subject: [PATCH 23/32] refactor: call PLENA compiler directly from ATen frontend --- aten/ops/plena/attention_ops.py | 200 +------------------------------ aten/ops/plena/conv_ops.py | 7 +- aten/ops/plena/embedding_ops.py | 20 +--- aten/ops/plena/ffn_ops.py | 60 +--------- aten/ops/plena/linear_ops.py | 85 +------------ aten/plena/program_attention.py | 134 +++++++++++++++++++++ aten/plena/program_matrix_ops.py | 83 +++++++++++++ aten/plena/program_tensors.py | 39 ++++++ aten/plena_frontend.py | 26 ++-- generator/aten_runner.py | 30 ++--- generator/runner.py | 2 +- 11 files changed, 290 insertions(+), 396 deletions(-) diff --git a/aten/ops/plena/attention_ops.py b/aten/ops/plena/attention_ops.py index 54678a1..3f07564 100644 --- a/aten/ops/plena/attention_ops.py +++ b/aten/ops/plena/attention_ops.py @@ -1,201 +1,5 @@ -"""PLENA backend implementation for Flash Attention operator.""" +"""PLENA backend compatibility shim for Flash Attention.""" def flash_attention_plena(prog, Q, K, V, scale=None, hq=1, hkv=1, h_qkv=None, causal_mask=None): - """PLENA backend: Flash Attention via PlenaCompiler. - - Dispatches to one of two codegen paths based on shape: - * MHA (hq == hkv == 1) — online-softmax loop using PlenaCompiler primitives - * GQA (hq // hkv > 1) — fused codegen via `flash_attn_asm` template, packs - `hq/hkv` Q heads into blen of M_BTMM (hardware GQA fusion, matches main) - - Args: - prog: PlenaCompiler instance - Q: VRAMMatrixVar — Q in VRAM, shape (seq_len, hq*h_qkv) - K: InputVar — K in HBM, shape (seq_len, hkv*h_qkv_padded) - V: InputVar — V in HBM, shape (seq_len, hkv*h_qkv_padded) - scale: Attention scale (default 1/sqrt(head_dim)) - hq, hkv, h_qkv: GQA params. Defaults treat input as single-head MHA. - causal_mask: VRAMMatrixVar or None — (mlen, mlen) mask with 0 on/below - diagonal and -inf above. Added to S before softmax. - - Returns: - VRAMMatrixVar for O, shape matching Q. - """ - - # Detect MHA vs GQA - if hq == 1 and hkv == 1: - return _flash_attention_mha(prog, Q, K, V, scale, causal_mask=causal_mask) - - # GQA: dispatch to fused codegen - if h_qkv is None: - raise ValueError("GQA mode requires h_qkv to be specified") - if causal_mask is not None: - raise NotImplementedError("causal_mask is not yet supported for GQA flash attention") - return _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv) - - -def _flash_attention_mha(prog, Q, K, V, scale, causal_mask=None): - """Single-head flash attention using PlenaCompiler primitives (online softmax). - - Args: - prog: PlenaCompiler instance - Q: VRAMMatrixVar — query matrix in VRAM - K: InputVar — key matrix in HBM - V: InputVar — value matrix in HBM - scale: Attention scale factor - causal_mask: VRAMMatrixVar or None — (mlen, mlen) mask with 0 on/below - diagonal and -inf above. When provided, added to S_block - after QK^T and before online softmax. - """ - import math - - seq_len, head_dim = Q.shape - mlen = prog.mlen - - if scale is None: - scale = 1.0 / math.sqrt(head_dim) - - num_q_blocks = seq_len // mlen - num_k_blocks = seq_len // mlen - - S_block = prog.alloc("S", mlen, mlen) - PV = prog.alloc("PV", mlen, head_dim) - O = prog.alloc("O", seq_len, head_dim) - - for q_idx in range(num_q_blocks): - prog.init_online_softmax(q_idx, O) - - for k_idx in range(num_k_blocks): - prog.vram_sub_projection_T_to( - Q, - q_idx, - K, - k_idx, - S_block, - target_row_idx=0, - target_col_idx=0, - ) - # Apply causal mask before softmax: S += mask (-inf above diagonal) - if causal_mask is not None: - prog.vram_add(S_block, causal_mask) - prog.online_softmax_block(S_block, scale) - prog.compute_pv(S_block, V, k_idx, PV, head_dim) - prog.scale_o_row(O, q_idx) - prog.vram_add(O, PV, dst_row_offset=q_idx * mlen) - - prog.final_scale_o(q_idx, O) - - return O - - -def _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv): - """GQA flash attention using fused M_BTMM codegen (main branch template). - - Packs `hq/hkv` Q heads into the `blen` systolic dimension — one M_BTMM - produces `ratio` head outputs in parallel. Requires: - (hq // hkv) == prog.blen, and (hq // hkv) * h_qkv == prog.mlen. - """ - import math - from compiler.asm_templates.flashattn import flash_attn_asm - from compiler.asm_templates import preload_addr_reg_asm - - ratio = hq // hkv - mlen = prog.mlen - blen = prog.blen - vlen = mlen - - if ratio != blen: - raise ValueError( - f"GQA ratio hq/hkv={ratio} must equal blen={blen} (hardware packs " - f"heads into blen). Use ratio=4 with default blen." - ) - if ratio * h_qkv != mlen: - raise ValueError( - f"GQA constraint: (hq/hkv)*h_qkv = {ratio * h_qkv} must equal " - f"mlen={mlen}. E.g. hq=4, hkv=1, h_qkv=16 with mlen=64." - ) - - s_q, _q_total_dim = Q.shape - s_kv, _k_total_dim = K.shape - - if scale is None: - scale = 1.0 / math.sqrt(h_qkv) - - # Allocate HBM addr registers for K and V (C_SET_ADDR_REG aN) - # Make sure K/V HBM subst-matrix registry knows about them - prog._ensure_hbm_sub_matrix_registered(K) - prog._ensure_hbm_sub_matrix_registered(V) - alloc = prog.register_allocator - k_addr, v_addr = alloc.allocate_addr(2) - gp_for_preload = alloc.allocate_gp(2) - setup = preload_addr_reg_asm( - addr_reg_to_set=[k_addr, v_addr], - available_registers=gp_for_preload, - addr_reg_val=[K.hbm_addr, V.hbm_addr], - ) - alloc.free_gp(gp_for_preload) - prog.emit(setup) - - # Allocate VRAM buffers mirroring main's layout. - # S, PV each require mlen*mlen*ratio elements; O is s_q * (hq*h_qkv). - # We let PlenaCompiler's allocator handle placement — main's template uses - # vector_sram_base_address + computed offsets for S/PV/O, so they must be - # allocated contiguously starting right after Q. allocate_vram_matrix - # bump-allocates, which preserves that contiguity. - q_vram_base = prog.get_vram_addr(Q.name) - s_name = prog._scoped_name("_gqa_S") - pv_name = prog._scoped_name("_gqa_PV") - o_name = prog._scoped_name("O") - - # Express sizes as (rows, cols) that multiply to the required counts. - prog.allocate_vram_matrix(name=s_name, rows=mlen * ratio, cols=mlen, strict=False) - prog.allocate_vram_matrix(name=pv_name, rows=mlen * ratio, cols=mlen, strict=False) - prog.allocate_vram_matrix(name=o_name, rows=s_q, cols=hq * h_qkv, strict=False) - - # Reserve FPRAM for multi-head softmax state. main's flash_attn_asm - # assumes slots 0..2 hold constants (0.0, scale, -inf, preloaded by the - # test harness) and softmax state starts at fp_sram_start_address, with - # 3 triples (m_old, m_res, l_old) per head, strided 3*br apart. - # Reserve slots 0..2 first so our bump-allocated state lives at offset 3+. - br = min(mlen, s_q) - fp_allocs = prog.fpram_allocator - if "_gqa_fp_const_zero" not in fp_allocs.allocations: - fp_allocs.allocate(name="_gqa_fp_const_zero", size=1) - fp_allocs.allocate(name="_gqa_fp_const_scale", size=1) - fp_allocs.allocate(name="_gqa_fp_const_neg_inf", size=1) - fp_state_size = 3 * br * ratio - fp_info = prog.add_fpram_object(name="_gqa_softmax_state", size=fp_state_size) - if fp_info.fpram_addr is None: - raise RuntimeError("Failed to allocate FPRAM for GQA softmax state") - fp_start = fp_info.fpram_addr - - # Call main's fused GQA template - asm = flash_attn_asm( - mlen=mlen, - vlen=vlen, - blen=blen, - batch=1, - hq=hq, - hkv=hkv, - d=h_qkv, - q_len=s_q, - kv_len=s_kv, - alive_registers_int=list(range(1, 16)), - alive_registers_fp=list(range(1, 8)), - vector_sram_base_address=q_vram_base, - fp_sram_start_address=fp_start, - k_base_hbm_offset_reg=k_addr, - v_base_hbm_offset_reg=v_addr, - ) - prog.emit(asm) - - # Release HBM addr regs (they're only needed during the call) - alloc.free_addr([k_addr, v_addr]) - - # Return O as a VRAMMatrixVar the caller can consume - from compiler.aten.plena.vars import VRAMMatrixVar - - O = VRAMMatrixVar(prog, o_name, (s_q, hq * h_qkv), display_name="O") - prog._tensors[o_name] = O - return O + return prog.flash_attention(Q, K, V, scale=scale, hq=hq, hkv=hkv, h_qkv=h_qkv, causal_mask=causal_mask) diff --git a/aten/ops/plena/conv_ops.py b/aten/ops/plena/conv_ops.py index 314671e..3fecc13 100644 --- a/aten/ops/plena/conv_ops.py +++ b/aten/ops/plena/conv_ops.py @@ -9,7 +9,7 @@ patches. Requires the emulator to support V_SHIFT_V (opcode 0x31, LSB-first right-shift per main.rs fix by George Wu, commit 24eb011). -After im2col the systolic matmul uses the standard linear_plena path. +After im2col the systolic matmul uses the compiler's standard linear path. HBM layout convention (caller must arrange data accordingly): input_raw shape = (C_in * H, W_padded) — each row is one spatial row of one channel @@ -20,9 +20,6 @@ With W_padded=64 and ow=0 (OW=1): offset = (c*H + oh+kr) * 64 → always aligned. """ -from compiler.aten.ops.plena.linear_ops import linear_plena - - _PREFETCH_V_AMOUNT = 4 # H_PREFETCH_V always loads this many VRAM rows @@ -196,4 +193,4 @@ def conv2d_plena( # ------------------------------------------------------------------ # Systolic matmul: im2col_out @ weight_2d -> (M, C_out) # ------------------------------------------------------------------ - return linear_plena(prog, output_mat, weight_2d_var) + return prog.linear(output_mat, weight_2d_var) diff --git a/aten/ops/plena/embedding_ops.py b/aten/ops/plena/embedding_ops.py index 495f876..f4ab13b 100644 --- a/aten/ops/plena/embedding_ops.py +++ b/aten/ops/plena/embedding_ops.py @@ -1,23 +1,9 @@ -"""PLENA backend implementations for positional encoding operators.""" +"""PLENA backend compatibility shims for positional encoding operators.""" def embedding_add_plena(prog, input_var, pos_weight_var): - """PLENA backend: add learned position embeddings to input in-place. - - Both input_var and pos_weight_var must be VRAMMatrixVar with the same shape. - Uses prog.vram_add() which emits V_ADD_VV row-by-row. - """ - prog.vram_add(input_var, pos_weight_var) - return input_var + return prog.embedding_add(input_var, pos_weight_var) def rope_plena(prog, x_var, x_rot_var, cos_var, sin_var): - """PLENA backend: apply RoPE in-place. - - x_var is updated in-place: x = x * cos + rotate_half(x) * sin - - x_rot_var must be a VRAMMatrixVar holding rotate_half(x), preloaded from HBM - by the caller before dispatching this op. - """ - prog.rope(x_var, x_rot_var, cos_var, sin_var) - return x_var + return prog.rope(x_var, x_rot_var, cos_var, sin_var) diff --git a/aten/ops/plena/ffn_ops.py b/aten/ops/plena/ffn_ops.py index d695242..726fc47 100644 --- a/aten/ops/plena/ffn_ops.py +++ b/aten/ops/plena/ffn_ops.py @@ -1,61 +1,5 @@ -"""PLENA backend implementation for FFN operator.""" - -from compiler.asm_templates import ffn_asm, preload_addr_reg_asm, reset_reg_asm +"""PLENA backend compatibility shim for FFN.""" def ffn_plena(prog, input_var, w_gate, w_up, w_down): - """PLENA backend: FFN with SiLU gate via ffn_asm. - - Generates ISA code for: w_down @ (silu(w_gate @ x) * (w_up @ x)) - - Args: - prog: PlenaCompiler instance - input_var: BatchVar — activation in VRAM, shape (batch, hidden) - w_gate: InputVar — gate weight in HBM, shape (hidden, inter_dim) - w_up: InputVar — up-projection weight in HBM, shape (hidden, inter_dim) - w_down: InputVar — down-projection weight in HBM, shape (inter_dim, hidden) - - Returns: - VRAMMatrixVar for the FFN output (stored at activation_base_address in VRAM). - """ - batch_size, hidden_size = input_var.shape - _, inter_dim = w_up.shape - mlen = prog.mlen - blen = prog.blen - vlen = prog.mlen - - # Retrieve VRAM address of the loaded activation - activation_base_address = prog.get_vram_addr(input_var.name) - - # Set HBM address registers for each weight matrix - isa_code = preload_addr_reg_asm( - addr_reg_to_set=[1, 2, 3], - available_registers=[1, 2, 3], - addr_reg_val=[w_gate.hbm_addr, w_up.hbm_addr, w_down.hbm_addr], - ) - - # Reset registers before FFN kernel - isa_code += reset_reg_asm(alive_registers=[1, 2, 3]) - - # Generate FFN ISA kernel - isa_code += ffn_asm( - mlen=mlen, - vlen=vlen, - blen=blen, - batch=batch_size, - seq_len=1, - hidden_size=hidden_size, - intermediate_size=inter_dim, - alive_registers=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - gate_weight_hbm_offset_reg=1, - up_weight_hbm_offset_reg=2, - down_weight_hbm_offset_reg=3, - const_one_fp_address=5, - activation_base_address=activation_base_address, - use_loop_instructions=True, - ) - - prog.emit(isa_code) - - # FFN result is written back to the activation area in VRAM (in-place overwrite) - return input_var + return prog.ffn(input_var, w_gate, w_up, w_down) diff --git a/aten/ops/plena/linear_ops.py b/aten/ops/plena/linear_ops.py index 9f39cf6..4619529 100644 --- a/aten/ops/plena/linear_ops.py +++ b/aten/ops/plena/linear_ops.py @@ -1,88 +1,9 @@ -"""PLENA backend stubs for linear projection operators.""" - -import math - - -MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements - - -def _iter_k_chunks(num_k_tiles: int): - k_start = 0 - while k_start < num_k_tiles: - k_end = min(k_start + MAX_K_TILES, num_k_tiles) - yield k_start, k_end - k_start - k_start = k_end +"""PLENA backend compatibility shims for linear operators.""" def linear_projection_plena(prog, input_var, weight_var, name: str = "linear_out"): - """Emit tiled PLENA linear projection, including K-split accumulation.""" - mlen = prog.mlen - - rows, k_total = input_var.shape - _, out_features = weight_var.shape - num_row_blocks = math.ceil(rows / mlen) - assert out_features % mlen == 0, ( - f"out_features ({out_features}) must be a multiple of mlen ({mlen})" - ) - num_col_blocks = out_features // mlen - num_k_tiles = math.ceil(k_total / mlen) - - # Allocate output matrix. When batch is not a multiple of mlen we pass - # strict=False so the allocator doesn't reject the shape; the hardware will - # operate on full mlen-wide tiles (HBM zero-pads unused rows) and only the - # first `rows` rows of the output contain valid results. - output_strict = rows % mlen == 0 - output = prog.alloc(name, rows, out_features, strict=output_strict) - - def emit_projection(row_idx, col_idx, target, target_row_idx, target_col_idx, **k_split): - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - target, - target_row_idx, - target_col_idx, - **k_split, - ) - - if num_k_tiles <= MAX_K_TILES: - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - emit_projection(row_idx, col_idx, output, row_idx, col_idx) - else: - # Temp buffer for partial sums — only needs one (mlen, mlen) tile since - # each sub_projection_to writes a single tile before accumulating. - # Using the full output shape would cause VRAM overlap with the output - # when out_features > mlen (column-block-major layout). - temp = prog.alloc(f"{name}_temp", mlen, mlen) - - for k_chunk_idx, (k_block_start, k_block_count) in enumerate(_iter_k_chunks(num_k_tiles)): - k_split = { - "k_block_start": k_block_start, - "k_block_count": k_block_count, - } - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - if k_chunk_idx == 0: - emit_projection(row_idx, col_idx, output, row_idx, col_idx, **k_split) - else: - emit_projection(row_idx, col_idx, temp, 0, 0, **k_split) - prog.vram_block_add_to( - output, - row_idx, - col_idx, - temp, - 0, - 0, - output, - row_idx, - col_idx, - ) - prog.free_tensor(temp) - - return output + return prog.linear_projection(input_var, weight_var, name) def linear_plena(prog, input_var, weight_var): - return linear_projection_plena(prog, input_var, weight_var) + return prog.linear(input_var, weight_var) diff --git a/aten/plena/program_attention.py b/aten/plena/program_attention.py index 77ce389..0e22c9b 100644 --- a/aten/plena/program_attention.py +++ b/aten/plena/program_attention.py @@ -2,6 +2,10 @@ from __future__ import annotations +import math + +from compiler.asm_templates import preload_addr_reg_asm +from compiler.asm_templates.flashattn import flash_attn_asm from compiler.aten.plena.vars import InputVar, VRAMMatrixVar @@ -10,6 +14,136 @@ class ProgramAttentionMixin: # Flash Attention Operations # ======================================================================== + def flash_attention(self, Q, K, V, scale=None, hq=1, hkv=1, h_qkv=None, causal_mask=None): + """Emit flash attention, dispatching to MHA or fused GQA codegen by shape.""" + if hq == 1 and hkv == 1: + return self._flash_attention_mha(Q, K, V, scale, causal_mask=causal_mask) + + if h_qkv is None: + raise ValueError("GQA mode requires h_qkv to be specified") + if causal_mask is not None: + raise NotImplementedError("causal_mask is not yet supported for GQA flash attention") + return self._flash_attention_gqa_fused(Q, K, V, scale, hq, hkv, h_qkv) + + def _flash_attention_mha(self, Q, K, V, scale, causal_mask=None): + """Single-head online-softmax flash attention using compiler primitives.""" + seq_len, head_dim = Q.shape + mlen = self.mlen + + if scale is None: + scale = 1.0 / math.sqrt(head_dim) + + num_q_blocks = seq_len // mlen + num_k_blocks = seq_len // mlen + + S_block = self.alloc("S", mlen, mlen) + PV = self.alloc("PV", mlen, head_dim) + O = self.alloc("O", seq_len, head_dim) + + for q_idx in range(num_q_blocks): + self.init_online_softmax(q_idx, O) + + for k_idx in range(num_k_blocks): + self.vram_sub_projection_T_to( + Q, + q_idx, + K, + k_idx, + S_block, + target_row_idx=0, + target_col_idx=0, + ) + if causal_mask is not None: + self.vram_add(S_block, causal_mask) + self.online_softmax_block(S_block, scale) + self.compute_pv(S_block, V, k_idx, PV, head_dim) + self.scale_o_row(O, q_idx) + self.vram_add(O, PV, dst_row_offset=q_idx * mlen) + + self.final_scale_o(q_idx, O) + + return O + + def _flash_attention_gqa_fused(self, Q, K, V, scale, hq, hkv, h_qkv): + """GQA flash attention using the fused M_BTMM template.""" + ratio = hq // hkv + mlen = self.mlen + blen = self.blen + vlen = mlen + + if ratio != blen: + raise ValueError( + f"GQA ratio hq/hkv={ratio} must equal blen={blen} " + "(hardware packs heads into blen)." + ) + if ratio * h_qkv != mlen: + raise ValueError( + f"GQA constraint: (hq/hkv)*h_qkv = {ratio * h_qkv} must equal mlen={mlen}." + ) + + s_q, _q_total_dim = Q.shape + s_kv, _k_total_dim = K.shape + + if scale is None: + scale = 1.0 / math.sqrt(h_qkv) + + self._ensure_hbm_sub_matrix_registered(K) + self._ensure_hbm_sub_matrix_registered(V) + alloc = self.register_allocator + k_addr, v_addr = alloc.allocate_addr(2) + gp_for_preload = alloc.allocate_gp(2) + setup = preload_addr_reg_asm( + addr_reg_to_set=[k_addr, v_addr], + available_registers=gp_for_preload, + addr_reg_val=[K.hbm_addr, V.hbm_addr], + ) + alloc.free_gp(gp_for_preload) + self.emit(setup) + + q_vram_base = self.get_vram_addr(Q.name) + s_name = self._scoped_name("_gqa_S") + pv_name = self._scoped_name("_gqa_PV") + o_name = self._scoped_name("O") + + self.allocate_vram_matrix(name=s_name, rows=mlen * ratio, cols=mlen, strict=False) + self.allocate_vram_matrix(name=pv_name, rows=mlen * ratio, cols=mlen, strict=False) + self.allocate_vram_matrix(name=o_name, rows=s_q, cols=hq * h_qkv, strict=False) + + br = min(mlen, s_q) + fp_allocs = self.fpram_allocator + if "_gqa_fp_const_zero" not in fp_allocs.allocations: + fp_allocs.allocate(name="_gqa_fp_const_zero", size=1) + fp_allocs.allocate(name="_gqa_fp_const_scale", size=1) + fp_allocs.allocate(name="_gqa_fp_const_neg_inf", size=1) + fp_info = self.add_fpram_object(name="_gqa_softmax_state", size=3 * br * ratio) + if fp_info.fpram_addr is None: + raise RuntimeError("Failed to allocate FPRAM for GQA softmax state") + + self.emit( + flash_attn_asm( + mlen=mlen, + vlen=vlen, + blen=blen, + batch=1, + hq=hq, + hkv=hkv, + d=h_qkv, + q_len=s_q, + kv_len=s_kv, + alive_registers_int=list(range(1, 16)), + alive_registers_fp=list(range(1, 8)), + vector_sram_base_address=q_vram_base, + fp_sram_start_address=fp_info.fpram_addr, + k_base_hbm_offset_reg=k_addr, + v_base_hbm_offset_reg=v_addr, + ) + ) + + alloc.free_addr([k_addr, v_addr]) + O = VRAMMatrixVar(self, o_name, (s_q, hq * h_qkv), display_name="O") + self._tensors[o_name] = O + return O + def init_online_softmax(self, q_idx: int, o_matrix: VRAMMatrixVar): """Initialize Online Softmax state: m=-inf, l=0, O_row=0""" o_info = super().get_tensor_info(o_matrix.name) diff --git a/aten/plena/program_matrix_ops.py b/aten/plena/program_matrix_ops.py index 840da57..22224c6 100644 --- a/aten/plena/program_matrix_ops.py +++ b/aten/plena/program_matrix_ops.py @@ -2,8 +2,20 @@ from __future__ import annotations +import math + from compiler.aten.plena.vars import InputVar, TensorVar, VRAMMatrixVar +MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements + + +def _iter_k_chunks(num_k_tiles: int): + k_start = 0 + while k_start < num_k_tiles: + k_end = min(k_start + MAX_K_TILES, num_k_tiles) + yield k_start, k_end - k_start + k_start = k_end + class ProgramMatrixOpsMixin: # ======================================================================== @@ -114,6 +126,72 @@ def vram_sub_projection_T_to( target_col_idx=target_col_idx, ) + def linear_projection(self, input_var: VRAMMatrixVar, weight_var: InputVar, name: str = "linear_out"): + """Emit tiled PLENA linear projection, including K-split accumulation.""" + mlen = self.mlen + + rows, k_total = input_var.shape + _, out_features = weight_var.shape + num_row_blocks = math.ceil(rows / mlen) + if out_features % mlen != 0: + raise ValueError(f"out_features ({out_features}) must be a multiple of mlen ({mlen})") + num_col_blocks = out_features // mlen + num_k_tiles = math.ceil(k_total / mlen) + + # When rows is not a multiple of mlen the hardware still operates on + # full tiles; only the first `rows` rows contain valid output. + output = self.alloc(name, rows, out_features, strict=rows % mlen == 0) + + def emit_projection(row_idx, col_idx, target, target_row_idx, target_col_idx, **k_split): + self.vram_sub_projection_to( + input_var, + row_idx, + weight_var, + col_idx, + target, + target_row_idx, + target_col_idx, + **k_split, + ) + + if num_k_tiles <= MAX_K_TILES: + for col_idx in range(num_col_blocks): + for row_idx in range(num_row_blocks): + emit_projection(row_idx, col_idx, output, row_idx, col_idx) + return output + + # Temp buffer for one partial-sum tile. Allocating the full output shape + # here can overlap with the real output for wide projections. + temp = self.alloc(f"{name}_temp", mlen, mlen) + for k_chunk_idx, (k_block_start, k_block_count) in enumerate(_iter_k_chunks(num_k_tiles)): + k_split = { + "k_block_start": k_block_start, + "k_block_count": k_block_count, + } + for col_idx in range(num_col_blocks): + for row_idx in range(num_row_blocks): + if k_chunk_idx == 0: + emit_projection(row_idx, col_idx, output, row_idx, col_idx, **k_split) + else: + emit_projection(row_idx, col_idx, temp, 0, 0, **k_split) + self.vram_block_add_to( + output, + row_idx, + col_idx, + temp, + 0, + 0, + output, + row_idx, + col_idx, + ) + self.free_tensor(temp) + return output + + def linear(self, input_var: VRAMMatrixVar, weight_var: InputVar): + """Default linear op compatibility surface.""" + return self.linear_projection(input_var, weight_var) + # ======================================================================== # RoPE (1D Positional Encoding) # ======================================================================== @@ -159,6 +237,11 @@ def vram_add( num_rows=num_rows, ) + def embedding_add(self, input_var: VRAMMatrixVar, pos_weight_var: VRAMMatrixVar): + """Add learned/positional embedding weights to input in-place.""" + self.vram_add(input_var, pos_weight_var) + return input_var + def vram_block_add_to( self, src1: TensorVar, diff --git a/aten/plena/program_tensors.py b/aten/plena/program_tensors.py index 9184c66..b8b946b 100644 --- a/aten/plena/program_tensors.py +++ b/aten/plena/program_tensors.py @@ -2,6 +2,7 @@ from __future__ import annotations +from compiler.asm_templates import ffn_asm, preload_addr_reg_asm, reset_reg_asm from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar @@ -315,5 +316,43 @@ def layer_norm( scratchpad_vram_addr=scratchpad_vram_addr, ) + # ======================================================================== + # Composite Decoder Operations + # ======================================================================== + + def ffn(self, input_var: VRAMMatrixVar, w_gate: InputVar, w_up: InputVar, w_down: InputVar): + """Emit the fused FFN kernel and return the in-place activation var.""" + batch_size, hidden_size = input_var.shape + _, inter_dim = w_up.shape + mlen = self.mlen + blen = self.blen + activation_base_address = self.get_vram_addr(input_var.name) + + isa_code = preload_addr_reg_asm( + addr_reg_to_set=[1, 2, 3], + available_registers=[1, 2, 3], + addr_reg_val=[w_gate.hbm_addr, w_up.hbm_addr, w_down.hbm_addr], + ) + isa_code += reset_reg_asm(alive_registers=[1, 2, 3]) + isa_code += ffn_asm( + mlen=mlen, + vlen=mlen, + blen=blen, + batch=batch_size, + seq_len=1, + hidden_size=hidden_size, + intermediate_size=inter_dim, + alive_registers=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + gate_weight_hbm_offset_reg=1, + up_weight_hbm_offset_reg=2, + down_weight_hbm_offset_reg=3, + const_one_fp_address=5, + activation_base_address=activation_base_address, + use_loop_instructions=True, + ) + + self.emit(isa_code) + return input_var + __all__ = ["ProgramTensorMixin"] diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 47576fb..fdd4355 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -15,8 +15,6 @@ find_model_root, ) from compiler.aten.plena_compiler import PlenaCompiler -from compiler.aten.ops.registry import OpRegistry, Backend -from compiler.aten.ops.plena.linear_ops import linear_projection_plena as _linear_projection from compiler.aten.reference import ( ReferencePrecision, _ksplit_matmul, @@ -25,7 +23,6 @@ quantize_to_mxfp, run_decoder_reference, ) -import compiler.aten.ops as ops __all__ = [ "_fix_large_immediates", @@ -37,6 +34,7 @@ _IMM2_BOUND = 1 << 18 # S_ADDI_INT max immediate + def _fix_large_immediates(isa_code: str) -> str: """Post-process ISA: replace S_ADDI_INT gp{r}, gp0, {large} with S_LUI_INT + S_ADDI_INT. @@ -95,7 +93,7 @@ def _add_residual(prog, target, scratch): def _apply_rope_projection(prog, x_var, rope_matrix, cos_var, sin_var, name): - x_rot = _linear_projection(prog, x_var, rope_matrix, name) + x_rot = prog.linear_projection(x_var, rope_matrix, name) prog.rope(x_var, x_rot, cos_var, sin_var) prog.free_tensor(x_rot) return x_var @@ -120,14 +118,12 @@ def _emit_kv_stores(prog, current, layer_inputs, rope_inputs, layer_idx, num_kv_ rope_matrix, cos_var, sin_var = rope_inputs kv_stored = [] for kv_h in range(num_kv_heads): - K_h = _linear_projection( - prog, + K_h = prog.linear_projection( current, layer_inputs.w_k_heads[kv_h], f"K_{layer_idx}_h{kv_h}", ) - V_h = _linear_projection( - prog, + V_h = prog.linear_projection( current, layer_inputs.w_v_heads[kv_h], f"V_{layer_idx}_h{kv_h}", @@ -170,7 +166,7 @@ def _emit_attention_block( ): _save_residual_and_norm(prog, current, scratch) - Q = _linear_projection(prog, current, layer_inputs.w_q, f"Q_{layer_idx}") + Q = prog.linear_projection(current, layer_inputs.w_q, f"Q_{layer_idx}") q_full_addr = prog.get_vram_addr(Q.name) O_full = prog.alloc(f"O_full_{layer_idx}", seq_len, total_q_dim) @@ -201,8 +197,7 @@ def _emit_attention_block( f"Q_rot_{layer_idx}_h{h}", ) - O_h = ops.flash_attention( - prog, + O_h = prog.flash_attention( Q_h, K_stored, V_stored, @@ -221,13 +216,13 @@ def _emit_attention_block( ) _free_named_tensors(prog, ("O", "S", "PV")) - O_proj = _linear_projection(prog, O_full, layer_inputs.w_o, f"O_proj_{layer_idx}") + O_proj = prog.linear_projection(O_full, layer_inputs.w_o, f"O_proj_{layer_idx}") return _add_residual(prog, O_proj, scratch) def _emit_ffn_block(prog, current, layer_inputs, scratch): _save_residual_and_norm(prog, current, scratch) - ops.ffn(prog, current, layer_inputs.w_gate, layer_inputs.w_up, layer_inputs.w_down) + prog.ffn(current, layer_inputs.w_gate, layer_inputs.w_up, layer_inputs.w_down) return _add_residual(prog, current, scratch) @@ -388,9 +383,6 @@ def _verbose(message: str = ""): # ----------------------------------------------------------- PLENA ISA print("\n--- PLENA Backend (ISA generation) ---") - registry = OpRegistry.load() - registry.set_backend(Backend.PLENA) - prog = PlenaCompiler(mlen=mlen, blen=blen, real_data_ratio=REAL_DATA_RATIO) # Shared inputs @@ -419,7 +411,7 @@ def _verbose(message: str = ""): # Load activations to VRAM X_batch = prog.load_batch(x_input, name="X") POS_batch = prog.load_batch(pos_input, name="POS") - ops.embedding_add(prog, X_batch, POS_batch) # X += POS in-place + prog.embedding_add(X_batch, POS_batch) # X += POS in-place # VRAM layout hazard: ffn_asm writes gate/up intermediates at absolute # address batch*hidden spanning up to batch*hidden + 2*inter*batch. diff --git a/generator/aten_runner.py b/generator/aten_runner.py index 72f2ddf..a1b8602 100644 --- a/generator/aten_runner.py +++ b/generator/aten_runner.py @@ -1,7 +1,7 @@ """ ATen-backend runner for the generator. -Wraps the proven ATen compilation path (PlenaCompiler + ops.*) to provide +Wraps the proven ATen compilation path (PlenaCompiler direct codegen) to provide end-to-end model compilation + emulation + numerical verification. This is the same pipeline that passes 18/18 tests in the ATen test suite. @@ -45,7 +45,7 @@ def run_aten_e2e( Steps: 1. Load model config + layer weights from HuggingFace - 2. Build ISA via PlenaCompiler + ops.* (numerically verified path) + 2. Build ISA via PlenaCompiler direct codegen (numerically verified path) 3. Set up sim environment (ASM + HBM weights + FPRAM constants) 4. Run Rust emulator 5. Compare VRAM output against golden PyTorch reference @@ -65,19 +65,13 @@ def run_aten_e2e( inter_dim: int build_dir: str """ + from transactional_emulator.testbench.emulator_runner import compare_emulator_output from transactional_emulator.testbench.model_layer_test_builder import ( build_and_run_decoder_test, build_and_run_multi_layer_test, get_model_dims, slice_dims_for_sim, - MLEN, ) - from transactional_emulator.testbench.emulator_runner import ( - compare_emulator_output, - run_emulator, - ) - from transactional_emulator.testbench.config_utils import update_plena_config - from transactional_emulator.tools.check_mem import print_comparison_results t0 = time.time() @@ -127,9 +121,9 @@ def run_aten_e2e( asm_name = f"aten_{model_id.split('/')[-1]}_l{current_layer}" layer_build = build_path / f"layer_{current_layer}" - print(f"\n[2/5] Building ISA for layer {current_layer} via PlenaCompiler + ops.*") + print(f"\n[2/5] Building ISA for layer {current_layer} via PlenaCompiler direct codegen") print(f"[3/5] Setting up sim environment: {layer_build}") - print(f"[4/5] Running Rust transactional emulator") + print("[4/5] Running Rust transactional emulator") extra_kwargs = {} if trust_remote_code: @@ -148,7 +142,7 @@ def run_aten_e2e( inter_dim=inter_dim, **extra_kwargs, ) - comp_results, comp_params = compare_emulator_output(layer_build) + comp_results, _comp_params = compare_emulator_output(layer_build) results_per_layer.append({ "layer": current_layer, "passed": True, @@ -165,7 +159,7 @@ def run_aten_e2e( "model_id": model_id, } try: - comp_results, comp_params = compare_emulator_output(layer_build) + comp_results, _comp_params = compare_emulator_output(layer_build) results_per_layer.append({ "layer": current_layer, "passed": False, @@ -185,9 +179,9 @@ def run_aten_e2e( asm_name = f"aten_{model_id.split('/')[-1]}_chain{num_layers}" chain_build = build_path / f"chain_{num_layers}layers" - print(f"\n[2/5] Building chained {num_layers}-layer ISA via PlenaCompiler + ops.*") + print(f"\n[2/5] Building chained {num_layers}-layer ISA via PlenaCompiler direct codegen") print(f"[3/5] Setting up sim environment: {chain_build}") - print(f"[4/5] Running Rust transactional emulator") + print("[4/5] Running Rust transactional emulator") extra_kwargs = {} if trust_remote_code: @@ -207,7 +201,7 @@ def run_aten_e2e( inter_dim=inter_dim, **extra_kwargs, ) - comp_results, comp_params = compare_emulator_output(chain_build) + comp_results, _comp_params = compare_emulator_output(chain_build) results_per_layer.append({ "layer": f"chain_{num_layers}", "passed": True, @@ -224,7 +218,7 @@ def run_aten_e2e( "model_id": model_id, } try: - comp_results, comp_params = compare_emulator_output(chain_build) + comp_results, _comp_params = compare_emulator_output(chain_build) results_per_layer.append({ "layer": f"chain_{num_layers}", "passed": False, @@ -291,7 +285,7 @@ def main(): import argparse parser = argparse.ArgumentParser( - description="Run HF model through ATen compilation path (PlenaCompiler + ops.*)", + description="Run HF model through ATen compilation path (PlenaCompiler direct codegen)", prog="python -m generator.aten_runner", ) parser.add_argument("model_id", help="HuggingFace model ID (e.g. AICrossSim/clm-60m)") diff --git a/generator/runner.py b/generator/runner.py index 123a835..ca12c94 100644 --- a/generator/runner.py +++ b/generator/runner.py @@ -28,7 +28,7 @@ def _run_aten(args) -> int: - """ATen-backed end-to-end: PlenaCompiler + ops.* → emulator → numerical check.""" + """ATen-backed end-to-end: PlenaCompiler direct codegen -> emulator -> numerical check.""" from generator.aten_runner import run_aten_e2e result = run_aten_e2e( From a0ae6713f93ef78c485f6f283d5bc35a64dfb842 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 15:33:51 +0100 Subject: [PATCH 24/32] refactor: route HF frontend through aten ops --- aten/native_ops.yaml | 2 +- aten/ops/cpu/linear_ops.py | 2 ++ aten/ops/cpu/norm_ops.py | 6 ++++++ aten/ops/plena/linear_ops.py | 4 ++-- aten/plena_frontend.py | 34 +++++++++++++++++++++++----------- generator/aten_runner.py | 10 +++++----- generator/runner.py | 2 +- 7 files changed, 40 insertions(+), 20 deletions(-) diff --git a/aten/native_ops.yaml b/aten/native_ops.yaml index 9cf8aaa..3a828d4 100644 --- a/aten/native_ops.yaml +++ b/aten/native_ops.yaml @@ -31,7 +31,7 @@ uses_mram: false doc: "Online softmax: row-wise numerically stable softmax (in-place on S tile)" -- func: "linear(Tensor input, Tensor weight) -> Tensor" +- func: "linear(Tensor input, Tensor weight, str name='linear_out') -> Tensor" category: primitive in_place: false dispatch: diff --git a/aten/ops/cpu/linear_ops.py b/aten/ops/cpu/linear_ops.py index 58c7cda..69cbe05 100644 --- a/aten/ops/cpu/linear_ops.py +++ b/aten/ops/cpu/linear_ops.py @@ -6,6 +6,8 @@ def linear_cpu( input: torch.Tensor, weight: torch.Tensor, + name: str = "linear_out", ) -> torch.Tensor: """CPU reference: input @ weight (float32 accumulation).""" + del name return torch.matmul(input.float(), weight.float()) diff --git a/aten/ops/cpu/norm_ops.py b/aten/ops/cpu/norm_ops.py index d633d71..8f0362b 100644 --- a/aten/ops/cpu/norm_ops.py +++ b/aten/ops/cpu/norm_ops.py @@ -6,8 +6,11 @@ def rms_norm_cpu( input: torch.Tensor, eps: float = 1e-6, + eps_offset: int = 1, + reci_hid_offset: int = 2, ) -> torch.Tensor: """CPU reference: RMS normalization.""" + del eps_offset, reci_hid_offset x = input.float() rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) return x / rms @@ -16,8 +19,11 @@ def rms_norm_cpu( def layer_norm_cpu( input: torch.Tensor, eps: float = 1e-6, + eps_offset: int = 1, + reci_hid_offset: int = 2, ) -> torch.Tensor: """CPU reference: Layer normalization (zero-mean, unit-variance per row).""" + del eps_offset, reci_hid_offset x = input.float() mean = x.mean(-1, keepdim=True) var = x.var(-1, keepdim=True, unbiased=False) diff --git a/aten/ops/plena/linear_ops.py b/aten/ops/plena/linear_ops.py index 4619529..56d801c 100644 --- a/aten/ops/plena/linear_ops.py +++ b/aten/ops/plena/linear_ops.py @@ -5,5 +5,5 @@ def linear_projection_plena(prog, input_var, weight_var, name: str = "linear_out return prog.linear_projection(input_var, weight_var, name) -def linear_plena(prog, input_var, weight_var): - return prog.linear(input_var, weight_var) +def linear_plena(prog, input_var, weight_var, name: str = "linear_out"): + return prog.linear_projection(input_var, weight_var, name) diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index fdd4355..f888273 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -14,6 +14,8 @@ extract_model_config, find_model_root, ) +import compiler.aten.ops as ops +from compiler.aten.ops.registry import Backend, OpRegistry from compiler.aten.plena_compiler import PlenaCompiler from compiler.aten.reference import ( ReferencePrecision, @@ -84,7 +86,7 @@ def _save_residual_and_norm(prog, source, scratch): """Emit the common decoder pre-norm residual prologue.""" prog.vram_fill_zero(scratch) prog.vram_add(scratch, source) - prog.rms_norm(source, eps_offset=3, reci_hid_offset=4) + ops.rms_norm(prog, source, eps_offset=3, reci_hid_offset=4) def _add_residual(prog, target, scratch): @@ -92,9 +94,13 @@ def _add_residual(prog, target, scratch): return target +def _linear_projection(prog, input_var, weight_var, name: str): + return ops.linear(prog, input_var, weight_var, name=name) + + def _apply_rope_projection(prog, x_var, rope_matrix, cos_var, sin_var, name): - x_rot = prog.linear_projection(x_var, rope_matrix, name) - prog.rope(x_var, x_rot, cos_var, sin_var) + x_rot = _linear_projection(prog, x_var, rope_matrix, name) + ops.rope(prog, x_var, x_rot, cos_var, sin_var) prog.free_tensor(x_rot) return x_var @@ -118,12 +124,14 @@ def _emit_kv_stores(prog, current, layer_inputs, rope_inputs, layer_idx, num_kv_ rope_matrix, cos_var, sin_var = rope_inputs kv_stored = [] for kv_h in range(num_kv_heads): - K_h = prog.linear_projection( + K_h = _linear_projection( + prog, current, layer_inputs.w_k_heads[kv_h], f"K_{layer_idx}_h{kv_h}", ) - V_h = prog.linear_projection( + V_h = _linear_projection( + prog, current, layer_inputs.w_v_heads[kv_h], f"V_{layer_idx}_h{kv_h}", @@ -166,7 +174,7 @@ def _emit_attention_block( ): _save_residual_and_norm(prog, current, scratch) - Q = prog.linear_projection(current, layer_inputs.w_q, f"Q_{layer_idx}") + Q = _linear_projection(prog, current, layer_inputs.w_q, f"Q_{layer_idx}") q_full_addr = prog.get_vram_addr(Q.name) O_full = prog.alloc(f"O_full_{layer_idx}", seq_len, total_q_dim) @@ -197,7 +205,8 @@ def _emit_attention_block( f"Q_rot_{layer_idx}_h{h}", ) - O_h = prog.flash_attention( + O_h = ops.flash_attention( + prog, Q_h, K_stored, V_stored, @@ -216,13 +225,13 @@ def _emit_attention_block( ) _free_named_tensors(prog, ("O", "S", "PV")) - O_proj = prog.linear_projection(O_full, layer_inputs.w_o, f"O_proj_{layer_idx}") + O_proj = _linear_projection(prog, O_full, layer_inputs.w_o, f"O_proj_{layer_idx}") return _add_residual(prog, O_proj, scratch) def _emit_ffn_block(prog, current, layer_inputs, scratch): _save_residual_and_norm(prog, current, scratch) - prog.ffn(current, layer_inputs.w_gate, layer_inputs.w_up, layer_inputs.w_down) + ops.ffn(prog, current, layer_inputs.w_gate, layer_inputs.w_up, layer_inputs.w_down) return _add_residual(prog, current, scratch) @@ -383,6 +392,9 @@ def _verbose(message: str = ""): # ----------------------------------------------------------- PLENA ISA print("\n--- PLENA Backend (ISA generation) ---") + registry = OpRegistry.load() + registry.set_backend(Backend.PLENA) + prog = PlenaCompiler(mlen=mlen, blen=blen, real_data_ratio=REAL_DATA_RATIO) # Shared inputs @@ -411,7 +423,7 @@ def _verbose(message: str = ""): # Load activations to VRAM X_batch = prog.load_batch(x_input, name="X") POS_batch = prog.load_batch(pos_input, name="POS") - prog.embedding_add(X_batch, POS_batch) # X += POS in-place + ops.embedding_add(prog, X_batch, POS_batch) # X += POS in-place # VRAM layout hazard: ffn_asm writes gate/up intermediates at absolute # address batch*hidden spanning up to batch*hidden + 2*inter*batch. @@ -461,7 +473,7 @@ def _verbose(message: str = ""): prog.emit_comment(f"=== LAYER {i}/{n_layers} COMPLETE ===") # Final norm - prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) + ops.rms_norm(prog, current, eps_offset=3, reci_hid_offset=4) isa_code = prog.compile() isa_code = _fix_large_immediates(isa_code) diff --git a/generator/aten_runner.py b/generator/aten_runner.py index a1b8602..62bb670 100644 --- a/generator/aten_runner.py +++ b/generator/aten_runner.py @@ -1,7 +1,7 @@ """ ATen-backend runner for the generator. -Wraps the proven ATen compilation path (PlenaCompiler direct codegen) to provide +Wraps the proven ATen compilation path (PlenaCompiler + ops.*) to provide end-to-end model compilation + emulation + numerical verification. This is the same pipeline that passes 18/18 tests in the ATen test suite. @@ -45,7 +45,7 @@ def run_aten_e2e( Steps: 1. Load model config + layer weights from HuggingFace - 2. Build ISA via PlenaCompiler direct codegen (numerically verified path) + 2. Build ISA via PlenaCompiler + ops.* (numerically verified path) 3. Set up sim environment (ASM + HBM weights + FPRAM constants) 4. Run Rust emulator 5. Compare VRAM output against golden PyTorch reference @@ -121,7 +121,7 @@ def run_aten_e2e( asm_name = f"aten_{model_id.split('/')[-1]}_l{current_layer}" layer_build = build_path / f"layer_{current_layer}" - print(f"\n[2/5] Building ISA for layer {current_layer} via PlenaCompiler direct codegen") + print(f"\n[2/5] Building ISA for layer {current_layer} via PlenaCompiler + ops.*") print(f"[3/5] Setting up sim environment: {layer_build}") print("[4/5] Running Rust transactional emulator") @@ -179,7 +179,7 @@ def run_aten_e2e( asm_name = f"aten_{model_id.split('/')[-1]}_chain{num_layers}" chain_build = build_path / f"chain_{num_layers}layers" - print(f"\n[2/5] Building chained {num_layers}-layer ISA via PlenaCompiler direct codegen") + print(f"\n[2/5] Building chained {num_layers}-layer ISA via PlenaCompiler + ops.*") print(f"[3/5] Setting up sim environment: {chain_build}") print("[4/5] Running Rust transactional emulator") @@ -285,7 +285,7 @@ def main(): import argparse parser = argparse.ArgumentParser( - description="Run HF model through ATen compilation path (PlenaCompiler direct codegen)", + description="Run HF model through ATen compilation path (PlenaCompiler + ops.*)", prog="python -m generator.aten_runner", ) parser.add_argument("model_id", help="HuggingFace model ID (e.g. AICrossSim/clm-60m)") diff --git a/generator/runner.py b/generator/runner.py index ca12c94..1dacbd1 100644 --- a/generator/runner.py +++ b/generator/runner.py @@ -28,7 +28,7 @@ def _run_aten(args) -> int: - """ATen-backed end-to-end: PlenaCompiler direct codegen -> emulator -> numerical check.""" + """ATen-backed end-to-end: PlenaCompiler + ops.* -> emulator -> numerical check.""" from generator.aten_runner import run_aten_e2e result = run_aten_e2e( From 50cd8b4318b6d07e326c50de726a18c50fd76993 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 16:03:14 +0100 Subject: [PATCH 25/32] refactor: split memory state from matrix ISA --- .../tests/test_vram_sub_projection.py | 10 +- asm_templates/vram_sub_projection_asm.py | 4 +- aten/__init__.py | 1 + aten/plena/__init__.py | 2 + aten/plena/isa_compiler.py | 19 +- aten/plena/isa_matrix.py | 380 +++++++++- aten/plena/memory_state.py | 268 +++++++ aten/plena/tile_compiler.py | 653 +----------------- aten/plena_compiler.py | 2 + aten/tests/test_plena_compiler.py | 4 +- 10 files changed, 671 insertions(+), 672 deletions(-) create mode 100644 aten/plena/memory_state.py diff --git a/asm_templates/tests/test_vram_sub_projection.py b/asm_templates/tests/test_vram_sub_projection.py index 197f178..27ba85e 100644 --- a/asm_templates/tests/test_vram_sub_projection.py +++ b/asm_templates/tests/test_vram_sub_projection.py @@ -3,7 +3,7 @@ Verifies the extracted free function produces the expected ISA output for looped/unrolled and transposed/non-transposed variants, and asserts byte-identical parity with the delegating -``TileCompiler._vram_sub_projection_asm_impl`` method. +``IsaCompiler._vram_sub_projection_asm_impl`` method. """ import sys @@ -95,10 +95,10 @@ def test_unrolled_no_loops(self): self.assertIn("M_MM_WO", asm) def test_output_byte_identical_to_method(self): - """The free function must produce byte-identical output to TileCompiler's method.""" - from compiler.aten.plena_compiler import TileCompiler + """The free function must produce byte-identical output to IsaCompiler's method.""" + from compiler.aten.plena_compiler import IsaCompiler - tc = TileCompiler(mlen=64, blen=4, unroll_loops=False) + compiler = IsaCompiler(mlen=64, blen=4, unroll_loops=False) method_kwargs = dict( header_lines=["; header"], @@ -113,7 +113,7 @@ def test_output_byte_identical_to_method(self): caller_name="test", ) - method_out = tc._vram_sub_projection_asm_impl(**method_kwargs) + method_out = compiler._vram_sub_projection_asm_impl(**method_kwargs) free_out = vram_sub_projection_asm_impl( mlen=64, blen=4, diff --git a/asm_templates/vram_sub_projection_asm.py b/asm_templates/vram_sub_projection_asm.py index 5851f5e..ebc68b7 100644 --- a/asm_templates/vram_sub_projection_asm.py +++ b/asm_templates/vram_sub_projection_asm.py @@ -1,7 +1,7 @@ """Pure emitter for VRAM sub-projection ISA. -Shared implementation kernel used by ``TileCompiler.vram_sub_projection_asm`` -and ``TileCompiler.vram_sub_projection_T_asm``. The caller resolves all +Shared implementation kernel used by ``IsaCompiler.vram_sub_projection_asm`` +and ``IsaCompiler.vram_sub_projection_T_asm``. The caller resolves all instance-dependent state (register allocator, tile layouts, MRAM addresses, ``unroll_loops`` default) and passes it in as plain parameters, so this emitter can be unit-tested in isolation. diff --git a/aten/__init__.py b/aten/__init__.py index a9068ef..8d4db3a 100644 --- a/aten/__init__.py +++ b/aten/__init__.py @@ -22,6 +22,7 @@ FPVar, InputVar, IsaCompiler, + MemoryStateMixin, PlenaCompiler, TensorVar, TileCompiler, diff --git a/aten/plena/__init__.py b/aten/plena/__init__.py index 1695c6a..d812ba6 100644 --- a/aten/plena/__init__.py +++ b/aten/plena/__init__.py @@ -3,6 +3,7 @@ from compiler.aten.plena.compiler import PlenaCompiler from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN from compiler.aten.plena.isa_compiler import IsaCompiler +from compiler.aten.plena.memory_state import MemoryStateMixin from compiler.aten.plena.tile_compiler import TileCompiler from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar @@ -13,6 +14,7 @@ "FPVar", "InputVar", "IsaCompiler", + "MemoryStateMixin", "PlenaCompiler", "TensorVar", "TileCompiler", diff --git a/aten/plena/isa_compiler.py b/aten/plena/isa_compiler.py index c5d7f4b..6deb05b 100644 --- a/aten/plena/isa_compiler.py +++ b/aten/plena/isa_compiler.py @@ -16,8 +16,8 @@ from compiler.aten.plena.isa_fp_ops import IsaFPOpsMixin from compiler.aten.plena.isa_matrix import IsaMatrixMixin from compiler.aten.plena.isa_tile_rows import IsaTileRowMixin +from compiler.aten.plena.memory_state import MemoryStateMixin from compiler.aten.plena.registers import RegisterAllocator -from compiler.aten.plena.tile_compiler import TileCompiler class IsaCompiler( @@ -26,7 +26,7 @@ class IsaCompiler( IsaTileRowMixin, IsaFPOpsMixin, IsaEmitMixin, - TileCompiler, + MemoryStateMixin, ): """ ISA Compiler: lowers PLENA compiler operations to assembly text. @@ -37,8 +37,7 @@ class IsaCompiler( _ONLINE_SOFTMAX_FPSRAM_BASE = 10 def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): - # TileCompiler.__init__ sets dimensions, layout tables, memory allocators, - # and the currently loaded MRAM sub-block table. + # MemoryStateMixin.__init__ sets dimensions, layout tables, and memory allocators. super().__init__(mlen=mlen, blen=blen, unroll_loops=unroll_loops) self.real_data_ratio = real_data_ratio self.register_allocator = RegisterAllocator() @@ -332,8 +331,8 @@ def reset(self): """Reset compiler state (clear code, but retain symbol table)""" self.generated_code = "" self.register_allocator = RegisterAllocator() - # Call TileCompiler.reset() explicitly since the merged class shadows it. - TileCompiler.reset(self) + # Call MemoryStateMixin.reset() explicitly since the merged class shadows it. + MemoryStateMixin.reset(self) def get_tensor_info(self, name: str): """Get unified tensor/object info by name.""" @@ -348,11 +347,11 @@ def add_hbm_object( ): """Register an HBM object and build its HBM layout. - Wraps ``TileCompiler.add_hbm_object`` with a different positional + Wraps the memory-layout ``add_hbm_object`` with a different positional parameter order ``(name, hbm_addr, shape, ...)`` that all IsaCompiler callers use. """ - return TileCompiler.add_hbm_object( + return MemoryStateMixin.add_hbm_object( self, name=name, shape=shape, @@ -362,7 +361,7 @@ def add_hbm_object( def free_hbm_object(self, name: str, strict: bool = False): """Free an HBM object by name (defaults to non-strict).""" - return TileCompiler.free_hbm_object(self, name, strict=strict) + return MemoryStateMixin.free_hbm_object(self, name, strict=strict) def get_vram_addr(self, name: str) -> int: """Get VRAM base address of an object.""" @@ -420,6 +419,6 @@ def ensure_vram_matrix_layout(self, name: str, shape: tuple[int, int]): def free_vram_object(self, name: str, strict: bool = False): """Free a VRAM object by name (defaults to non-strict).""" - return TileCompiler.free_vram_object(self, name, strict=strict) + return MemoryStateMixin.free_vram_object(self, name, strict=strict) __all__ = ["IsaCompiler"] diff --git a/aten/plena/isa_matrix.py b/aten/plena/isa_matrix.py index ef8c10a..8056552 100644 --- a/aten/plena/isa_matrix.py +++ b/aten/plena/isa_matrix.py @@ -3,7 +3,8 @@ from __future__ import annotations from compiler.asm_templates import preload_addr_reg_asm -from compiler.aten.isa_builder import IsaBuilder +from compiler.asm_templates.vram_sub_projection_asm import vram_sub_projection_asm_impl +from compiler.aten.isa_builder import IsaBuilder, addr as areg, gp class IsaMatrixMixin: @@ -30,10 +31,381 @@ def reset_mram(self) -> str: Used in scenarios where sub-blocks need to be reloaded within a for loop """ self.mram_allocator.reset() - self.loaded_sub_blocks.clear() + self.clear_mram_bindings() return self._emit(IsaBuilder().comment("=== Reset MRAM ===")) + def _default_hbm_gp_regs(self, gp_regs: list[int] | None) -> list[int]: + return [1, 2, 3] if gp_regs is None else gp_regs + + def _emit_hbm_prefetch_setup(self, asm: IsaBuilder, layout, gp_scale: int, gp_stride: int) -> None: + rows, cols = layout.full_shape + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), rows * cols) + asm.instr("C_SET_SCALE_REG", gp(gp_scale)) + asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), cols) + asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) + + def _emit_hbm_subblock_prefetch( + self, + asm: IsaBuilder, + layout, + row_idx: int, + col_idx: int, + mram_addr: int, + hbm_addr_reg: int, + gp_scale: int, + gp_mram: int, + comment: str | None = None, + ) -> None: + sub_block = layout.get_sub_block(row_idx, col_idx) + hbm_offset = sub_block.hbm_offset + sub_block.mram_addr = mram_addr + + asm.comment(comment if comment is not None else f"SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") + asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_addr) + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) + asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) + + def _emit_hbm_subblock_sequence( + self, + asm: IsaBuilder, + layout, + block_coords, + mram_start_addr: int, + hbm_addr_reg: int, + gp_scale: int, + gp_mram: int, + ) -> None: + mram_addr = mram_start_addr + block_size = self.mlen * self.mlen + for row_idx, col_idx in block_coords: + self._emit_hbm_subblock_prefetch( + asm, + layout, + row_idx, + col_idx, + mram_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + ) + mram_addr += block_size + + def load_sub_matrix_asm( + self, + name: str, + row_idx: int, + col_idx: int, + mram_dest_addr: int, + hbm_addr_reg: int = 1, + gp_regs: list[int] | None = None, + ) -> str: + """Emit HBM->MRAM prefetch for one mlen x mlen sub-block.""" + gp_regs = self._default_hbm_gp_regs(gp_regs) + layout = self.hbm_matrices[name] + + asm = IsaBuilder() + asm.comment(f"Load SubMatrix {name}[{row_idx}][{col_idx}] -> MRAM[{mram_dest_addr}]") + gp_scale = gp_regs[0] + gp_stride = gp_regs[1] + gp_mram = gp_regs[2] + self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) + hbm_offset = layout.get_sub_block(row_idx, col_idx).hbm_offset + self._emit_hbm_subblock_prefetch( + asm, + layout, + row_idx, + col_idx, + mram_dest_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + comment=f"HBM offset: {hbm_offset} (precomputed)", + ) + + return asm.render() + + def load_row_sub_matrices_asm( + self, + name: str, + row_idx: int, + mram_start_addr: int, + hbm_addr_reg: int = 1, + gp_regs: list[int] | None = None, + ) -> str: + """Emit HBM->MRAM prefetches for one block row.""" + gp_regs = self._default_hbm_gp_regs(gp_regs) + layout = self.hbm_matrices[name] + + asm = IsaBuilder() + asm.comment(f"Load SubMatrix Row {name}[{row_idx}][:] -> MRAM[{mram_start_addr}]") + + gp_scale = gp_regs[0] + gp_stride = gp_regs[1] + gp_mram = gp_regs[2] + self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) + + self._emit_hbm_subblock_sequence( + asm, + layout, + ((row_idx, col_idx) for col_idx in range(layout.num_col_blocks)), + mram_start_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + ) + + return asm.render() + + def load_col_sub_matrices_asm( + self, + name: str, + col_idx: int, + mram_start_addr: int, + hbm_addr_reg: int = 1, + gp_regs: list[int] | None = None, + k_block_start: int = 0, + k_block_count: int | None = None, + ) -> str: + """Emit HBM->MRAM prefetches for one block column or K-split slice.""" + gp_regs = self._default_hbm_gp_regs(gp_regs) + layout = self.hbm_matrices[name] + num_row_blocks = layout.num_row_blocks + + asm = IsaBuilder() + asm.comment(f"Load SubMatrix Col {name}[:][{col_idx}] -> MRAM[{mram_start_addr}]") + + gp_scale = gp_regs[0] + gp_stride = gp_regs[1] + gp_mram = gp_regs[2] + self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) + + effective_count = k_block_count if k_block_count is not None else num_row_blocks + self._emit_hbm_subblock_sequence( + asm, + layout, + ((row_idx, col_idx) for row_idx in range(k_block_start, k_block_start + effective_count)), + mram_start_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + ) + + return asm.render() + + def _default_projection_gp_regs(self, gp_regs: list[int] | None) -> list[int]: + return [1, 2, 3, 4, 5, 6, 7, 8, 9] if gp_regs is None else gp_regs + + def _projection_context(self, vram_mat_name: str, vram_row_idx: int, mram_mat_name: str): + if vram_mat_name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") + vram_layout = self.vram_matrices[vram_mat_name] + return vram_layout, self.hbm_matrices[mram_mat_name], vram_layout.get_row_blocks(vram_row_idx) + + def _loaded_mram_start(self, blocks, missing_label) -> int: + for sub_block in blocks: + if sub_block.mram_addr is None: + raise RuntimeError(f"SubBlock {missing_label(sub_block)} not loaded to MRAM") + return blocks[0].mram_addr + + def _vram_sub_projection_asm_impl( + self, + header_lines: list[str], + vram_row_start_addr: int, + mram_start_addr: int, + result_vram_addr: int, + full_batch: int, + num_hidden_blocks: int, + mat_col_stride: int, + transposed: bool, + gp_regs: list[int], + caller_name: str, + unroll: bool | None = None, + ) -> str: + """Emit the shared projection loop after callers resolve operands.""" + do_unroll = self.unroll_loops if unroll is None else unroll + return vram_sub_projection_asm_impl( + mlen=self.mlen, + blen=self.blen, + unroll_loops=do_unroll, + header_lines=header_lines, + vram_row_start_addr=vram_row_start_addr, + mram_start_addr=mram_start_addr, + result_vram_addr=result_vram_addr, + full_batch=full_batch, + num_hidden_blocks=num_hidden_blocks, + mat_col_stride=mat_col_stride, + transposed=transposed, + gp_regs=gp_regs, + caller_name=caller_name, + ) + + def vram_sub_projection_asm( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_col_idx: int, + result_vram_addr: int, + gp_regs: list[int] | None = None, + k_block_start: int = 0, + k_block_count: int | None = None, + unroll: bool | None = None, + ) -> str: + """Emit VRAM[row][:] @ MRAM[:][col] projection.""" + gp_regs = self._default_projection_gp_regs(gp_regs) + vram_layout, mram_layout, vram_row_blocks = self._projection_context( + vram_mat_name, vram_row_idx, mram_mat_name + ) + mram_col_blocks = mram_layout.get_col_blocks(mram_col_idx) + if k_block_count is not None: + mram_col_blocks = mram_col_blocks[k_block_start : k_block_start + k_block_count] + + num_hidden_blocks = len(mram_col_blocks) + if num_hidden_blocks != (k_block_count if k_block_count is not None else len(vram_row_blocks)): + raise ValueError( + f"Dimension mismatch: expected {k_block_count or len(vram_row_blocks)} MRAM blocks, " + f"got {num_hidden_blocks}" + ) + + full_batch = vram_layout.full_shape[0] + vram_row_start_addr = vram_row_blocks[k_block_start].vram_addr + mram_col_start_addr = self._loaded_mram_start( + mram_col_blocks, + lambda block: f"{mram_mat_name}[{block.row_idx}][{mram_col_idx}]", + ) + + header_lines = [ + f"; VRAM Sub Projection: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}]", + f"; VRAM A[row_idx][:]: ({self.mlen}, hidden) spread across {num_hidden_blocks} column blocks", + f"; MRAM W[:][col_idx]: (hidden, {self.mlen}) with {num_hidden_blocks} sub-blocks", + f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", + ] + + return self._vram_sub_projection_asm_impl( + header_lines=header_lines, + vram_row_start_addr=vram_row_start_addr, + mram_start_addr=mram_col_start_addr, + result_vram_addr=result_vram_addr, + full_batch=full_batch, + num_hidden_blocks=num_hidden_blocks, + mat_col_stride=self.blen, + transposed=False, + gp_regs=gp_regs, + caller_name="vram_sub_projection_asm", + unroll=unroll, + ) + + def vram_sub_projection_T_asm( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_row_idx: int, + result_vram_addr: int, + gp_regs: list[int] | None = None, + unroll: bool | None = None, + ) -> str: + """Emit VRAM[row][:] @ MRAM[row][:]^T projection.""" + gp_regs = self._default_projection_gp_regs(gp_regs) + vram_layout, mram_layout, vram_row_blocks = self._projection_context( + vram_mat_name, vram_row_idx, mram_mat_name + ) + mram_row_blocks = mram_layout.get_row_blocks(mram_row_idx) + + if len(vram_row_blocks) != len(mram_row_blocks): + raise ValueError( + f"Dimension mismatch: VRAM has {len(vram_row_blocks)} blocks, MRAM has {len(mram_row_blocks)} blocks" + ) + + num_hidden_blocks = len(vram_row_blocks) + full_batch = vram_layout.full_shape[0] + vram_row_start_addr = vram_row_blocks[0].vram_addr + mram_row_start_addr = self._loaded_mram_start( + mram_row_blocks, + lambda block: f"{mram_mat_name}[{mram_row_idx}][{block.col_idx}]", + ) + # M_TMM reads the weight in transposed layout, so the outer-column + # stride is one full sub-block instead of the non-transposed blen stride. + mat_col_stride = self.blen * self.mlen + + header_lines = [ + f"; VRAM Sub Projection T: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T", + f"; VRAM A[row_idx][:]: ({self.mlen}, hidden)", + f"; MRAM W[row_idx][:]^T: (hidden, {self.mlen})", + f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", + ] + + return self._vram_sub_projection_asm_impl( + header_lines=header_lines, + vram_row_start_addr=vram_row_start_addr, + mram_start_addr=mram_row_start_addr, + result_vram_addr=result_vram_addr, + full_batch=full_batch, + num_hidden_blocks=num_hidden_blocks, + mat_col_stride=mat_col_stride, + transposed=True, + gp_regs=gp_regs, + caller_name="vram_sub_projection_T_asm", + unroll=unroll, + ) + + def vram_block_add_asm( + self, + src1_name: str, + src1_row_idx: int, + src1_col_idx: int, + src2_name: str, + src2_row_idx: int, + src2_col_idx: int, + target_name: str, + target_row_idx: int, + target_col_idx: int, + gp_regs: list[int] | None = None, + ) -> str: + """ + Add two mlen x mlen blocks and write to any target block: + target[target_row_idx][target_col_idx] = + src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] + """ + if gp_regs is None: + gp_regs = [1, 2, 3, 4] + if len(gp_regs) < 4: + raise ValueError(f"Need at least 4 GP regs, got {len(gp_regs)}") + + src1_block = self.get_vram_sub_block(src1_name, src1_row_idx, src1_col_idx) + src2_block = self.get_vram_sub_block(src2_name, src2_row_idx, src2_col_idx) + target_block = self.get_vram_sub_block(target_name, target_row_idx, target_col_idx) + + gp_dst = gp_regs[0] + gp_src1 = gp_regs[1] + gp_src2 = gp_regs[2] + gp_loop = gp_regs[3] + + lines = [ + f"; VRAM Block Add: {target_name}[{target_row_idx}][{target_col_idx}] = " + f"{src1_name}[{src1_row_idx}][{src1_col_idx}] + {src2_name}[{src2_row_idx}][{src2_col_idx}]" + ] + + if self.unroll_loops: + for i in range(self.mlen): + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr + i * self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr + i * self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr + i * self.mlen}") + lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") + else: + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr}") + lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr}") + lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {self.mlen}") + lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src1}, gp{gp_src1}, {self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src2}, gp{gp_src2}, {self.mlen}") + lines.append(f"C_LOOP_END gp{gp_loop}") + + return "\n".join(lines) + "\n" + def load_sub_matrix_row( self, name: str, @@ -123,7 +495,7 @@ def allocate_vram_matrix( return vram_addr def _ensure_vram_matrix_layout(self, matrix_name: str): - """Ensure a VRAM-resident tensor has a block layout in TileCompiler.""" + """Ensure a VRAM-resident tensor has a block layout.""" if matrix_name not in self: raise KeyError(f"Matrix '{matrix_name}' not found in symbol table") @@ -196,7 +568,7 @@ def vram_matrix_add( dst_info = self[dst_matrix] src_info = self[src_matrix] - # Block-add path depends on TileCompiler VRAM layouts. + # Block-add path depends on registered VRAM block layouts. self._ensure_vram_matrix_layout(dst_matrix) self._ensure_vram_matrix_layout(src_matrix) diff --git a/aten/plena/memory_state.py b/aten/plena/memory_state.py new file mode 100644 index 0000000..6e4482e --- /dev/null +++ b/aten/plena/memory_state.py @@ -0,0 +1,268 @@ +"""Memory layout state for the ATen PLENA compiler.""" + +from __future__ import annotations + +from compiler.aten.plena.constants import BLEN, MLEN +from compiler.aten.plena.memory import ( + FPRAMAllocator, + FPRAMObjectLayout, + MRAMAllocator, + MatrixBlockLayout, + MemoryObjectInfo, + VRAMAllocator, + VRAMMatrixBlockLayout, + VRAMSubMatrixInfo, +) + + +class MemoryStateMixin: + """Sub-matrix layout manager for PLENA HBM, VRAM, MRAM, and FPRAM state.""" + + def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = False): + self.mlen = mlen + self.blen = blen + self.unroll_loops = unroll_loops + + # Layout tables + self.hbm_matrices: dict[str, MatrixBlockLayout] = {} + self.vram_matrices: dict[str, VRAMMatrixBlockLayout] = {} + self.fpram_matrices: dict[str, FPRAMObjectLayout] = {} + # Memory Allocators + self.vram_allocator = VRAMAllocator() + self.mram_allocator = MRAMAllocator() + self.fpram_allocator = FPRAMAllocator() + + def __contains__(self, name: str) -> bool: + return name in self.hbm_matrices or name in self.vram_matrices or name in self.fpram_matrices + + def __getitem__(self, name: str) -> MemoryObjectInfo: + if name not in self: + raise KeyError(f"Object '{name}' not found") + info = MemoryObjectInfo(name=name, kind="Unknown") + hbm_layout = self.hbm_matrices.get(name) + vram_layout = self.vram_matrices.get(name) + fpram_layout = self.fpram_matrices.get(name) + + if hbm_layout is not None: + rows, cols = hbm_layout.full_shape + info.shape = hbm_layout.full_shape + info.size = rows * cols + info.hbm_addr = hbm_layout.hbm_base_addr + info.hbm_size = hbm_layout.hbm_size + info.kind = "Matrix" + + if vram_layout is not None: + rows, cols = vram_layout.full_shape + info.shape = vram_layout.full_shape + info.size = rows * cols + info.vram_addr = vram_layout.vram_base_addr + info.kind = "VRAMMatrix" if hbm_layout is None else "Batch" + + if fpram_layout is not None: + info.shape = (1, fpram_layout.size) + info.size = fpram_layout.size + info.fpram_addr = fpram_layout.fpram_addr + info.fpram_size = fpram_layout.size + info.kind = "FPRAMObject" + + return info + + def get_hbm_layout(self, name: str) -> MatrixBlockLayout: + """Read HBM matrix layout by name.""" + if name not in self.hbm_matrices: + raise KeyError(f"HBM matrix '{name}' not found") + return self.hbm_matrices[name] + + def get_vram_layout(self, name: str) -> VRAMMatrixBlockLayout: + """Read VRAM matrix layout by name.""" + if name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{name}' not found") + return self.vram_matrices[name] + + def get_fpram_layout(self, name: str) -> FPRAMObjectLayout: + """Read FPRAM object layout by name.""" + if name not in self.fpram_matrices: + raise KeyError(f"FPRAM object '{name}' not found") + return self.fpram_matrices[name] + + # ========================================================================== + # Unified Object Management APIs + # ========================================================================== + + def add_hbm_object( + self, + name: str, + shape: tuple[int, int], + hbm_addr: int, + dtype: str = "fp16", + kind: str = "HBMObject", + real_data_ratio: float = 1.125, + strict: bool = False, + ) -> MemoryObjectInfo: + del dtype, kind + self.register_matrix( + name=name, + shape=shape, + hbm_base_addr=hbm_addr, + real_data_ratio=real_data_ratio, + strict=strict, + ) + return self[name] + + def free_hbm_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: + if name not in self.hbm_matrices: + if strict: + raise KeyError(f"HBM object '{name}' not found") + return None + info = self[name] + self.hbm_matrices.pop(name, None) + return info + + def add_vram_object( + self, + name: str, + shape: tuple[int, int], + vram_addr: int | None = None, + dtype: str = "fp16", + kind: str = "VRAMObject", + allocate_if_none: bool = True, + strict: bool = True, + ) -> MemoryObjectInfo: + rows, cols = shape + size = rows * cols + if vram_addr is None: + if not allocate_if_none: + raise ValueError("vram_addr is None and allocate_if_none is False") + vram_addr = self.vram_allocator.allocate(size=size, name=name) + del dtype, kind + self.register_vram_matrix( + name=name, + shape=shape, + vram_base_addr=vram_addr, + strict=strict, + ) + return self[name] + + def free_vram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: + if name not in self.vram_matrices: + if strict: + raise KeyError(f"VRAM object '{name}' not found") + return None + info = self[name] + self.vram_allocator.free(name, strict=strict) + self.vram_matrices.pop(name, None) + return info + + def add_fpram_object( + self, + name: str, + size: int, + dtype: str = "fp16", + kind: str = "FPRAMObject", + ) -> MemoryObjectInfo: + fpram_addr = self.fpram_allocator.allocate(name, size) + self.fpram_matrices[name] = FPRAMObjectLayout( + name=name, + fpram_addr=fpram_addr, + size=size, + dtype=dtype, + kind=kind, + ) + return self[name] + + def free_fpram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: + if name not in self.fpram_matrices: + if strict: + raise KeyError(f"FPRAM object '{name}' not found") + return None + info = self[name] + self.fpram_allocator.free(name, strict=strict) + self.fpram_matrices.pop(name, None) + return info + + def register_matrix( + self, + name: str, + shape: tuple[int, int], + hbm_base_addr: int, + real_data_ratio: float = 1.125, + strict: bool = True, + ) -> MatrixBlockLayout: + """Register an HBM matrix and derive its mlen block layout.""" + rows, cols = shape + + if strict: + if rows % self.mlen != 0: + raise ValueError(f"Matrix rows ({rows}) must be multiple of mlen ({self.mlen})") + if cols % self.mlen != 0: + raise ValueError(f"Matrix cols ({cols}) must be multiple of mlen ({self.mlen})") + + size = rows * cols + hbm_size = int(size * real_data_ratio) + + layout = MatrixBlockLayout( + name=name, full_shape=shape, block_size=self.mlen, hbm_base_addr=hbm_base_addr, hbm_size=hbm_size + ) + + self.hbm_matrices[name] = layout + return layout + + # ========================================================================== + # VRAM Sub-matrix Management + # ========================================================================== + + def register_vram_matrix( + self, + name: str, + shape: tuple[int, int], + vram_base_addr: int, + strict: bool = True, + ) -> VRAMMatrixBlockLayout: + """ + Register a large matrix in VRAM and automatically block it. + + Args: + name: matrix name + shape: full shape (batch, hidden_size) + vram_base_addr: VRAM base address + + Returns: + VRAMMatrixBlockLayout object + """ + batch, hidden = shape + + if strict: + if batch % self.mlen != 0: + raise ValueError(f"VRAM matrix batch ({batch}) must be multiple of mlen ({self.mlen})") + if hidden % self.mlen != 0: + raise ValueError(f"VRAM matrix hidden ({hidden}) must be multiple of mlen ({self.mlen})") + + layout = VRAMMatrixBlockLayout(name=name, full_shape=shape, vram_base_addr=vram_base_addr, block_size=self.mlen) + + self.vram_matrices[name] = layout + return layout + + def get_vram_sub_block(self, name: str, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: + """Get VRAM sub-block information""" + if name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{name}' not registered") + return self.vram_matrices[name].get_sub_block(row_idx, col_idx) + + def clear_mram_bindings(self) -> None: + """Clear cached MRAM addresses on all HBM sub-blocks.""" + for layout in self.hbm_matrices.values(): + for sub_block in layout.sub_blocks.values(): + sub_block.mram_addr = None + + def reset(self): + """Reset manager state.""" + self.clear_mram_bindings() + self.hbm_matrices.clear() + self.vram_matrices.clear() + self.fpram_matrices.clear() + self.vram_allocator.reset() + self.mram_allocator.reset() + self.fpram_allocator.reset() + + +__all__ = ["MemoryStateMixin"] diff --git a/aten/plena/tile_compiler.py b/aten/plena/tile_compiler.py index 8f4ef3a..72091a4 100644 --- a/aten/plena/tile_compiler.py +++ b/aten/plena/tile_compiler.py @@ -1,654 +1,9 @@ -"""Tile/block lowering helpers for the ATen PLENA compiler.""" +"""Compatibility alias for the former TileCompiler memory-state class.""" from __future__ import annotations -from compiler.asm_templates.vram_sub_projection_asm import vram_sub_projection_asm_impl -from compiler.aten.isa_builder import IsaBuilder, addr as areg, gp -from compiler.aten.plena.constants import BLEN, MLEN -from compiler.aten.plena.memory import ( - FPRAMAllocator, - FPRAMObjectLayout, - MRAMAllocator, - MatrixBlockLayout, - MemoryObjectInfo, - SubMatrixInfo, - VRAMAllocator, - VRAMMatrixBlockLayout, - VRAMSubMatrixInfo, -) +from compiler.aten.plena.memory_state import MemoryStateMixin +TileCompiler = MemoryStateMixin -class TileCompiler: - """Sub-matrix layout manager and ISA emitter for tiled PLENA memory ops.""" - - def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = False): - self.mlen = mlen - self.blen = blen - self.unroll_loops = unroll_loops - - # Layout tables - self.hbm_matrices: dict[str, MatrixBlockLayout] = {} - self.vram_matrices: dict[str, VRAMMatrixBlockLayout] = {} - self.fpram_matrices: dict[str, FPRAMObjectLayout] = {} - # Memory Allocators - self.vram_allocator = VRAMAllocator() - self.mram_allocator = MRAMAllocator() - self.fpram_allocator = FPRAMAllocator() - - # Currently loaded sub-blocks in MRAM - self.loaded_sub_blocks: dict[str, SubMatrixInfo] = {} - - def __contains__(self, name: str) -> bool: - return name in self.hbm_matrices or name in self.vram_matrices or name in self.fpram_matrices - - def __getitem__(self, name: str) -> MemoryObjectInfo: - if name not in self: - raise KeyError(f"Object '{name}' not found") - info = MemoryObjectInfo(name=name, kind="Unknown") - hbm_layout = self.hbm_matrices.get(name) - vram_layout = self.vram_matrices.get(name) - fpram_layout = self.fpram_matrices.get(name) - - if hbm_layout is not None: - rows, cols = hbm_layout.full_shape - info.shape = hbm_layout.full_shape - info.size = rows * cols - info.hbm_addr = hbm_layout.hbm_base_addr - info.hbm_size = hbm_layout.hbm_size - info.kind = "Matrix" - - if vram_layout is not None: - rows, cols = vram_layout.full_shape - info.shape = vram_layout.full_shape - info.size = rows * cols - info.vram_addr = vram_layout.vram_base_addr - info.kind = "VRAMMatrix" if hbm_layout is None else "Batch" - - if fpram_layout is not None: - info.shape = (1, fpram_layout.size) - info.size = fpram_layout.size - info.fpram_addr = fpram_layout.fpram_addr - info.fpram_size = fpram_layout.size - info.kind = "FPRAMObject" - - return info - - def get_hbm_layout(self, name: str) -> MatrixBlockLayout: - """Read HBM matrix layout by name.""" - if name not in self.hbm_matrices: - raise KeyError(f"HBM matrix '{name}' not found") - return self.hbm_matrices[name] - - def get_vram_layout(self, name: str) -> VRAMMatrixBlockLayout: - """Read VRAM matrix layout by name.""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not found") - return self.vram_matrices[name] - - def get_fpram_layout(self, name: str) -> FPRAMObjectLayout: - """Read FPRAM object layout by name.""" - if name not in self.fpram_matrices: - raise KeyError(f"FPRAM object '{name}' not found") - return self.fpram_matrices[name] - - # ========================================================================== - # Unified Object Management APIs - # ========================================================================== - - def add_hbm_object( - self, - name: str, - shape: tuple[int, int], - hbm_addr: int, - dtype: str = "fp16", - kind: str = "HBMObject", - real_data_ratio: float = 1.125, - strict: bool = False, - ) -> MemoryObjectInfo: - del dtype, kind - self.register_matrix( - name=name, - shape=shape, - hbm_base_addr=hbm_addr, - real_data_ratio=real_data_ratio, - strict=strict, - ) - return self[name] - - def free_hbm_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: - if name not in self.hbm_matrices: - if strict: - raise KeyError(f"HBM object '{name}' not found") - return None - info = self[name] - self.hbm_matrices.pop(name, None) - return info - - def add_vram_object( - self, - name: str, - shape: tuple[int, int], - vram_addr: int | None = None, - dtype: str = "fp16", - kind: str = "VRAMObject", - allocate_if_none: bool = True, - strict: bool = True, - ) -> MemoryObjectInfo: - rows, cols = shape - size = rows * cols - if vram_addr is None: - if not allocate_if_none: - raise ValueError("vram_addr is None and allocate_if_none is False") - vram_addr = self.vram_allocator.allocate(size=size, name=name) - del dtype, kind - self.register_vram_matrix( - name=name, - shape=shape, - vram_base_addr=vram_addr, - strict=strict, - ) - return self[name] - - def free_vram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: - if name not in self.vram_matrices: - if strict: - raise KeyError(f"VRAM object '{name}' not found") - return None - info = self[name] - self.vram_allocator.free(name, strict=strict) - self.vram_matrices.pop(name, None) - return info - - def add_fpram_object( - self, - name: str, - size: int, - dtype: str = "fp16", - kind: str = "FPRAMObject", - ) -> MemoryObjectInfo: - fpram_addr = self.fpram_allocator.allocate(name, size) - self.fpram_matrices[name] = FPRAMObjectLayout( - name=name, - fpram_addr=fpram_addr, - size=size, - dtype=dtype, - kind=kind, - ) - return self[name] - - def free_fpram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: - if name not in self.fpram_matrices: - if strict: - raise KeyError(f"FPRAM object '{name}' not found") - return None - info = self[name] - self.fpram_allocator.free(name, strict=strict) - self.fpram_matrices.pop(name, None) - return info - - def register_matrix( - self, - name: str, - shape: tuple[int, int], - hbm_base_addr: int, - real_data_ratio: float = 1.125, - strict: bool = True, - ) -> MatrixBlockLayout: - """Register an HBM matrix and derive its mlen block layout.""" - rows, cols = shape - - if strict: - if rows % self.mlen != 0: - raise ValueError(f"Matrix rows ({rows}) must be multiple of mlen ({self.mlen})") - if cols % self.mlen != 0: - raise ValueError(f"Matrix cols ({cols}) must be multiple of mlen ({self.mlen})") - - size = rows * cols - hbm_size = int(size * real_data_ratio) - - layout = MatrixBlockLayout( - name=name, full_shape=shape, block_size=self.mlen, hbm_base_addr=hbm_base_addr, hbm_size=hbm_size - ) - - self.hbm_matrices[name] = layout - return layout - - # ========================================================================== - # VRAM Sub-matrix Management - # ========================================================================== - - def register_vram_matrix( - self, - name: str, - shape: tuple[int, int], - vram_base_addr: int, - strict: bool = True, - ) -> VRAMMatrixBlockLayout: - """ - Register a large matrix in VRAM and automatically block it. - - Args: - name: matrix name - shape: full shape (batch, hidden_size) - vram_base_addr: VRAM base address - - Returns: - VRAMMatrixBlockLayout object - """ - batch, hidden = shape - - if strict: - if batch % self.mlen != 0: - raise ValueError(f"VRAM matrix batch ({batch}) must be multiple of mlen ({self.mlen})") - if hidden % self.mlen != 0: - raise ValueError(f"VRAM matrix hidden ({hidden}) must be multiple of mlen ({self.mlen})") - - layout = VRAMMatrixBlockLayout(name=name, full_shape=shape, vram_base_addr=vram_base_addr, block_size=self.mlen) - - self.vram_matrices[name] = layout - return layout - - def get_vram_sub_block(self, name: str, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: - """Get VRAM sub-block information""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not registered") - return self.vram_matrices[name].get_sub_block(row_idx, col_idx) - - # ========================================================================== - # ISA Generation: Load Sub Matrix - # ========================================================================== - - def _default_hbm_gp_regs(self, gp_regs: list[int] | None) -> list[int]: - return [1, 2, 3] if gp_regs is None else gp_regs - - def _emit_hbm_prefetch_setup(self, asm: IsaBuilder, layout: MatrixBlockLayout, gp_scale: int, gp_stride: int) -> None: - rows, cols = layout.full_shape - asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), rows * cols) - asm.instr("C_SET_SCALE_REG", gp(gp_scale)) - asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), cols) - asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) - - def _emit_hbm_subblock_prefetch( - self, - asm: IsaBuilder, - name: str, - layout: MatrixBlockLayout, - row_idx: int, - col_idx: int, - mram_addr: int, - hbm_addr_reg: int, - gp_scale: int, - gp_mram: int, - comment: str | None = None, - ) -> None: - sub_block = layout.get_sub_block(row_idx, col_idx) - hbm_offset = sub_block.hbm_offset - sub_block.mram_addr = mram_addr - - asm.comment(comment if comment is not None else f"SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") - asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_addr) - asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) - asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) - - self.loaded_sub_blocks[f"{name}[{row_idx}][{col_idx}]"] = sub_block - - def _emit_hbm_subblock_sequence( - self, - asm: IsaBuilder, - name: str, - layout: MatrixBlockLayout, - block_coords, - mram_start_addr: int, - hbm_addr_reg: int, - gp_scale: int, - gp_mram: int, - ) -> None: - mram_addr = mram_start_addr - block_size = self.mlen * self.mlen - for row_idx, col_idx in block_coords: - self._emit_hbm_subblock_prefetch( - asm, name, layout, row_idx, col_idx, mram_addr, hbm_addr_reg, gp_scale, gp_mram - ) - mram_addr += block_size - - def load_sub_matrix_asm( - self, - name: str, - row_idx: int, - col_idx: int, - mram_dest_addr: int, - hbm_addr_reg: int = 1, - gp_regs: list[int] | None = None, - ) -> str: - """Emit HBM->MRAM prefetch for one mlen x mlen sub-block.""" - gp_regs = self._default_hbm_gp_regs(gp_regs) - layout = self.hbm_matrices[name] - - asm = IsaBuilder() - asm.comment(f"Load SubMatrix {name}[{row_idx}][{col_idx}] -> MRAM[{mram_dest_addr}]") - gp_scale = gp_regs[0] - gp_stride = gp_regs[1] - gp_mram = gp_regs[2] - self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) - hbm_offset = layout.get_sub_block(row_idx, col_idx).hbm_offset - self._emit_hbm_subblock_prefetch( - asm, - name, - layout, - row_idx, - col_idx, - mram_dest_addr, - hbm_addr_reg, - gp_scale, - gp_mram, - comment=f"HBM offset: {hbm_offset} (precomputed)", - ) - - return asm.render() - - def load_row_sub_matrices_asm( - self, - name: str, - row_idx: int, - mram_start_addr: int, - hbm_addr_reg: int = 1, - gp_regs: list[int] | None = None, - ) -> str: - """Emit HBM->MRAM prefetches for one block row.""" - gp_regs = self._default_hbm_gp_regs(gp_regs) - layout = self.hbm_matrices[name] - - asm = IsaBuilder() - asm.comment(f"Load SubMatrix Row {name}[{row_idx}][:] -> MRAM[{mram_start_addr}]") - - gp_scale = gp_regs[0] - gp_stride = gp_regs[1] - gp_mram = gp_regs[2] - self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) - - self._emit_hbm_subblock_sequence( - asm, - name, - layout, - ((row_idx, col_idx) for col_idx in range(layout.num_col_blocks)), - mram_start_addr, - hbm_addr_reg, - gp_scale, - gp_mram, - ) - - return asm.render() - - def load_col_sub_matrices_asm( - self, - name: str, - col_idx: int, - mram_start_addr: int, - hbm_addr_reg: int = 1, - gp_regs: list[int] | None = None, - k_block_start: int = 0, - k_block_count: int | None = None, - ) -> str: - """Emit HBM->MRAM prefetches for one block column or K-split slice.""" - gp_regs = self._default_hbm_gp_regs(gp_regs) - layout = self.hbm_matrices[name] - num_row_blocks = layout.num_row_blocks - - asm = IsaBuilder() - asm.comment(f"Load SubMatrix Col {name}[:][{col_idx}] -> MRAM[{mram_start_addr}]") - - gp_scale = gp_regs[0] - gp_stride = gp_regs[1] - gp_mram = gp_regs[2] - self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) - - effective_count = k_block_count if k_block_count is not None else num_row_blocks - self._emit_hbm_subblock_sequence( - asm, - name, - layout, - ((row_idx, col_idx) for row_idx in range(k_block_start, k_block_start + effective_count)), - mram_start_addr, - hbm_addr_reg, - gp_scale, - gp_mram, - ) - - return asm.render() - - # ========================================================================== - # ISA Generation: Sub Projection - # ========================================================================== - - def _default_projection_gp_regs(self, gp_regs: list[int] | None) -> list[int]: - return [1, 2, 3, 4, 5, 6, 7, 8, 9] if gp_regs is None else gp_regs - - def _projection_context(self, vram_mat_name: str, vram_row_idx: int, mram_mat_name: str): - if vram_mat_name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") - vram_layout = self.vram_matrices[vram_mat_name] - return vram_layout, self.hbm_matrices[mram_mat_name], vram_layout.get_row_blocks(vram_row_idx) - - def _loaded_mram_start(self, blocks, missing_label) -> int: - for sub_block in blocks: - if sub_block.mram_addr is None: - raise RuntimeError(f"SubBlock {missing_label(sub_block)} not loaded to MRAM") - return blocks[0].mram_addr - - def _vram_sub_projection_asm_impl( - self, - header_lines: list[str], - vram_row_start_addr: int, - mram_start_addr: int, - result_vram_addr: int, - full_batch: int, - num_hidden_blocks: int, - mat_col_stride: int, - transposed: bool, - gp_regs: list[int], - caller_name: str, - unroll: bool | None = None, - ) -> str: - """Emit the shared projection loop after callers resolve operands.""" - do_unroll = self.unroll_loops if unroll is None else unroll - return vram_sub_projection_asm_impl( - mlen=self.mlen, - blen=self.blen, - unroll_loops=do_unroll, - header_lines=header_lines, - vram_row_start_addr=vram_row_start_addr, - mram_start_addr=mram_start_addr, - result_vram_addr=result_vram_addr, - full_batch=full_batch, - num_hidden_blocks=num_hidden_blocks, - mat_col_stride=mat_col_stride, - transposed=transposed, - gp_regs=gp_regs, - caller_name=caller_name, - ) - - def vram_sub_projection_asm( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_col_idx: int, - result_vram_addr: int, - gp_regs: list[int] | None = None, - k_block_start: int = 0, - k_block_count: int | None = None, - unroll: bool | None = None, - ) -> str: - """Emit VRAM[row][:] @ MRAM[:][col] projection.""" - gp_regs = self._default_projection_gp_regs(gp_regs) - vram_layout, mram_layout, vram_row_blocks = self._projection_context( - vram_mat_name, vram_row_idx, mram_mat_name - ) - mram_col_blocks = mram_layout.get_col_blocks(mram_col_idx) - # K-split: slice to only the loaded k-chunk - if k_block_count is not None: - mram_col_blocks = mram_col_blocks[k_block_start : k_block_start + k_block_count] - - num_hidden_blocks = len(mram_col_blocks) - if num_hidden_blocks != (k_block_count if k_block_count is not None else len(vram_row_blocks)): - raise ValueError( - f"Dimension mismatch: expected {k_block_count or len(vram_row_blocks)} MRAM blocks, " - f"got {num_hidden_blocks}" - ) - - full_batch = vram_layout.full_shape[0] - vram_row_start_addr = vram_row_blocks[k_block_start].vram_addr - mram_col_start_addr = self._loaded_mram_start( - mram_col_blocks, - lambda block: f"{mram_mat_name}[{block.row_idx}][{mram_col_idx}]", - ) - - header_lines = [ - f"; VRAM Sub Projection: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}]", - f"; VRAM A[row_idx][:]: ({self.mlen}, hidden) spread across {num_hidden_blocks} column blocks", - f"; MRAM W[:][col_idx]: (hidden, {self.mlen}) with {num_hidden_blocks} sub-blocks", - f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", - ] - - return self._vram_sub_projection_asm_impl( - header_lines=header_lines, - vram_row_start_addr=vram_row_start_addr, - mram_start_addr=mram_col_start_addr, - result_vram_addr=result_vram_addr, - full_batch=full_batch, - num_hidden_blocks=num_hidden_blocks, - mat_col_stride=self.blen, - transposed=False, - gp_regs=gp_regs, - caller_name="vram_sub_projection_asm", - unroll=unroll, - ) - - def vram_sub_projection_T_asm( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_row_idx: int, - result_vram_addr: int, - gp_regs: list[int] | None = None, - unroll: bool | None = None, - ) -> str: - """Emit VRAM[row][:] @ MRAM[row][:]^T projection.""" - gp_regs = self._default_projection_gp_regs(gp_regs) - vram_layout, mram_layout, vram_row_blocks = self._projection_context( - vram_mat_name, vram_row_idx, mram_mat_name - ) - mram_row_blocks = mram_layout.get_row_blocks(mram_row_idx) - - if len(vram_row_blocks) != len(mram_row_blocks): - raise ValueError( - f"Dimension mismatch: VRAM has {len(vram_row_blocks)} blocks, MRAM has {len(mram_row_blocks)} blocks" - ) - - num_hidden_blocks = len(vram_row_blocks) - - full_batch = vram_layout.full_shape[0] - vram_row_start_addr = vram_row_blocks[0].vram_addr - mram_row_start_addr = self._loaded_mram_start( - mram_row_blocks, - lambda block: f"{mram_mat_name}[{mram_row_idx}][{block.col_idx}]", - ) - # NOTE: For M_TMM (transposed matmul), the MRAM outer-column stride is - # blen * mlen (full sub-block size) because M_TMM reads the weight in - # transposed layout. This differs from the non-transposed path which uses - # blen. This is intentional for the M_TMM addressing contract. - mat_col_stride = self.blen * self.mlen - - header_lines = [ - f"; VRAM Sub Projection T: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T", - f"; VRAM A[row_idx][:]: ({self.mlen}, hidden)", - f"; MRAM W[row_idx][:]^T: (hidden, {self.mlen})", - f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", - ] - - return self._vram_sub_projection_asm_impl( - header_lines=header_lines, - vram_row_start_addr=vram_row_start_addr, - mram_start_addr=mram_row_start_addr, - result_vram_addr=result_vram_addr, - full_batch=full_batch, - num_hidden_blocks=num_hidden_blocks, - mat_col_stride=mat_col_stride, - transposed=True, - gp_regs=gp_regs, - caller_name="vram_sub_projection_T_asm", - unroll=unroll, - ) - - def vram_block_add_asm( - self, - src1_name: str, - src1_row_idx: int, - src1_col_idx: int, - src2_name: str, - src2_row_idx: int, - src2_col_idx: int, - target_name: str, - target_row_idx: int, - target_col_idx: int, - gp_regs: list[int] | None = None, - ) -> str: - """ - Add two mlen x mlen blocks and write to any target block: - target[target_row_idx][target_col_idx] = - src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] - - Source/target can be the same matrix (supports in-place overwrite). - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4] - if len(gp_regs) < 4: - raise ValueError(f"Need at least 4 GP regs, got {len(gp_regs)}") - - src1_block = self.get_vram_sub_block(src1_name, src1_row_idx, src1_col_idx) - src2_block = self.get_vram_sub_block(src2_name, src2_row_idx, src2_col_idx) - target_block = self.get_vram_sub_block(target_name, target_row_idx, target_col_idx) - - gp_dst = gp_regs[0] - gp_src1 = gp_regs[1] - gp_src2 = gp_regs[2] - gp_loop = gp_regs[3] - - lines = [] - lines.append( - f"; VRAM Block Add: {target_name}[{target_row_idx}][{target_col_idx}] = " - f"{src1_name}[{src1_row_idx}][{src1_col_idx}] + {src2_name}[{src2_row_idx}][{src2_col_idx}]" - ) - - # One V_ADD_VV processes one row (mlen elements). Use C_LOOP to reduce ISA size. - if self.unroll_loops: - for i in range(self.mlen): - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr + i * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr + i * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr + i * self.mlen}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") - else: - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr}") - lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr}") - lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr}") - lines.append(f"C_LOOP_START gp{gp_loop}, {self.mlen}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src1}, gp{gp_src1}, {self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src2}, gp{gp_src2}, {self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - - return "\n".join(lines) + "\n" - - def reset(self): - """Reset manager state.""" - self.hbm_matrices.clear() - self.vram_matrices.clear() - self.fpram_matrices.clear() - self.vram_allocator.reset() - self.mram_allocator.reset() - self.fpram_allocator.reset() - self.loaded_sub_blocks.clear() - -__all__ = ["TileCompiler"] +__all__ = ["MemoryStateMixin", "TileCompiler"] diff --git a/aten/plena_compiler.py b/aten/plena_compiler.py index 8e473c6..1e63ce7 100644 --- a/aten/plena_compiler.py +++ b/aten/plena_compiler.py @@ -5,6 +5,7 @@ from compiler.aten.plena.compiler import PlenaCompiler from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN from compiler.aten.plena.isa_compiler import IsaCompiler +from compiler.aten.plena.memory_state import MemoryStateMixin from compiler.aten.plena.tile_compiler import TileCompiler from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar @@ -15,6 +16,7 @@ "FPVar", "InputVar", "IsaCompiler", + "MemoryStateMixin", "PlenaCompiler", "TensorVar", "TileCompiler", diff --git a/aten/tests/test_plena_compiler.py b/aten/tests/test_plena_compiler.py index 1e7efd5..52eb650 100644 --- a/aten/tests/test_plena_compiler.py +++ b/aten/tests/test_plena_compiler.py @@ -73,9 +73,9 @@ def test_fpvar_helper_uses_canonical_emit_path(): def test_hbm_load_helper_uses_typed_legalization(): """Converted HBM load helpers should legalize typed large immediates.""" - from compiler.aten.plena_compiler import TileCompiler + from compiler.aten.plena_compiler import IsaCompiler - compiler = TileCompiler() + compiler = IsaCompiler() compiler.register_matrix("W", (512, 512), hbm_base_addr=0) code = compiler.load_sub_matrix_asm("W", row_idx=0, col_idx=0, mram_dest_addr=0) From ad2983ed35f3158d8ef72f02a7aa17cbc3033531 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 16:14:20 +0100 Subject: [PATCH 26/32] refactor: remove tile compiler compatibility alias --- aten/__init__.py | 1 - aten/plena/__init__.py | 2 -- aten/plena/tile_compiler.py | 9 --------- aten/plena_compiler.py | 2 -- 4 files changed, 14 deletions(-) delete mode 100644 aten/plena/tile_compiler.py diff --git a/aten/__init__.py b/aten/__init__.py index 8d4db3a..f95ea37 100644 --- a/aten/__init__.py +++ b/aten/__init__.py @@ -25,7 +25,6 @@ MemoryStateMixin, PlenaCompiler, TensorVar, - TileCompiler, VRAMMatrixVar, ) from compiler.aten.ops.registry import OpRegistry, Backend # noqa: E402, F401 diff --git a/aten/plena/__init__.py b/aten/plena/__init__.py index d812ba6..4d8604f 100644 --- a/aten/plena/__init__.py +++ b/aten/plena/__init__.py @@ -4,7 +4,6 @@ from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN from compiler.aten.plena.isa_compiler import IsaCompiler from compiler.aten.plena.memory_state import MemoryStateMixin -from compiler.aten.plena.tile_compiler import TileCompiler from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar __all__ = [ @@ -17,6 +16,5 @@ "MemoryStateMixin", "PlenaCompiler", "TensorVar", - "TileCompiler", "VRAMMatrixVar", ] diff --git a/aten/plena/tile_compiler.py b/aten/plena/tile_compiler.py deleted file mode 100644 index 72091a4..0000000 --- a/aten/plena/tile_compiler.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Compatibility alias for the former TileCompiler memory-state class.""" - -from __future__ import annotations - -from compiler.aten.plena.memory_state import MemoryStateMixin - -TileCompiler = MemoryStateMixin - -__all__ = ["MemoryStateMixin", "TileCompiler"] diff --git a/aten/plena_compiler.py b/aten/plena_compiler.py index 1e63ce7..8b2a001 100644 --- a/aten/plena_compiler.py +++ b/aten/plena_compiler.py @@ -6,7 +6,6 @@ from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN from compiler.aten.plena.isa_compiler import IsaCompiler from compiler.aten.plena.memory_state import MemoryStateMixin -from compiler.aten.plena.tile_compiler import TileCompiler from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar __all__ = [ @@ -19,6 +18,5 @@ "MemoryStateMixin", "PlenaCompiler", "TensorVar", - "TileCompiler", "VRAMMatrixVar", ] From 01c0a32bbc430c3c751b89640f1c00a163d99ff5 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 16:40:36 +0100 Subject: [PATCH 27/32] refactor: remove plena compiler compatibility facade --- asm_templates/load_int.py | 0 .../tests/test_vram_sub_projection.py | 2 +- aten/__init__.py | 2 +- aten/plena_compiler.py | 22 ------------------- aten/plena_frontend.py | 2 +- aten/tests/test_plena_compiler.py | 10 ++++----- 6 files changed, 8 insertions(+), 30 deletions(-) delete mode 100644 asm_templates/load_int.py delete mode 100644 aten/plena_compiler.py diff --git a/asm_templates/load_int.py b/asm_templates/load_int.py deleted file mode 100644 index e69de29..0000000 diff --git a/asm_templates/tests/test_vram_sub_projection.py b/asm_templates/tests/test_vram_sub_projection.py index 27ba85e..912bee0 100644 --- a/asm_templates/tests/test_vram_sub_projection.py +++ b/asm_templates/tests/test_vram_sub_projection.py @@ -96,7 +96,7 @@ def test_unrolled_no_loops(self): def test_output_byte_identical_to_method(self): """The free function must produce byte-identical output to IsaCompiler's method.""" - from compiler.aten.plena_compiler import IsaCompiler + from compiler.aten.plena import IsaCompiler compiler = IsaCompiler(mlen=64, blen=4, unroll_loops=False) diff --git a/aten/__init__.py b/aten/__init__.py index f95ea37..d237319 100644 --- a/aten/__init__.py +++ b/aten/__init__.py @@ -18,7 +18,7 @@ fp, gp, ) -from compiler.aten.plena_compiler import ( # noqa: E402, F401 +from compiler.aten.plena import ( # noqa: E402, F401 FPVar, InputVar, IsaCompiler, diff --git a/aten/plena_compiler.py b/aten/plena_compiler.py deleted file mode 100644 index 8b2a001..0000000 --- a/aten/plena_compiler.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Compatibility facade for the active ATen PLENA compiler path.""" - -from __future__ import annotations - -from compiler.aten.plena.compiler import PlenaCompiler -from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN -from compiler.aten.plena.isa_compiler import IsaCompiler -from compiler.aten.plena.memory_state import MemoryStateMixin -from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar - -__all__ = [ - "BLEN", - "IMM2_BOUND", - "MLEN", - "FPVar", - "InputVar", - "IsaCompiler", - "MemoryStateMixin", - "PlenaCompiler", - "TensorVar", - "VRAMMatrixVar", -] diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index f888273..10abb25 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -16,7 +16,7 @@ ) import compiler.aten.ops as ops from compiler.aten.ops.registry import Backend, OpRegistry -from compiler.aten.plena_compiler import PlenaCompiler +from compiler.aten.plena import PlenaCompiler from compiler.aten.reference import ( ReferencePrecision, _ksplit_matmul, diff --git a/aten/tests/test_plena_compiler.py b/aten/tests/test_plena_compiler.py index 52eb650..8985138 100644 --- a/aten/tests/test_plena_compiler.py +++ b/aten/tests/test_plena_compiler.py @@ -60,7 +60,7 @@ def test_isa_builder_preserves_relative_large_immediates(): def test_fpvar_helper_uses_canonical_emit_path(): """Converted FPVar helpers should still append and return asm text.""" - from compiler.aten.plena_compiler import PlenaCompiler + from compiler.aten.plena import PlenaCompiler prog = PlenaCompiler() code = prog.fpvar_add_asm(src1_addr=0, src2_addr=4, dst_addr=8, count=2) @@ -73,7 +73,7 @@ def test_fpvar_helper_uses_canonical_emit_path(): def test_hbm_load_helper_uses_typed_legalization(): """Converted HBM load helpers should legalize typed large immediates.""" - from compiler.aten.plena_compiler import IsaCompiler + from compiler.aten.plena import IsaCompiler compiler = IsaCompiler() compiler.register_matrix("W", (512, 512), hbm_base_addr=0) @@ -87,7 +87,7 @@ def test_hbm_load_helper_uses_typed_legalization(): def test_vram_fill_zero_all_column_blocks(): """vram_fill_zero must zero ALL column blocks of a wide matrix.""" - from compiler.aten.plena_compiler import PlenaCompiler + from compiler.aten.plena import PlenaCompiler prog = PlenaCompiler() x = prog.alloc("X", 64, 384) @@ -103,7 +103,7 @@ def test_vram_fill_zero_all_column_blocks(): def test_vram_add_all_column_blocks(): """vram_add must add ALL column blocks of wide matrices.""" - from compiler.aten.plena_compiler import PlenaCompiler + from compiler.aten.plena import PlenaCompiler prog = PlenaCompiler() x = prog.alloc("X", 64, 384) @@ -120,7 +120,7 @@ def test_vram_add_all_column_blocks(): def test_alloc_at_correct_address(): """alloc_at must create a VRAM view at the specified address.""" - from compiler.aten.plena_compiler import PlenaCompiler + from compiler.aten.plena import PlenaCompiler prog = PlenaCompiler() # Allocate a matrix, then create a view into its second column block From 274a1b02ec42750f2fbb10f70305f3f24c74f9b2 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 17:22:34 +0100 Subject: [PATCH 28/32] docs: add aten package tree --- docs/ATEN_TREE.md | 74 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 docs/ATEN_TREE.md diff --git a/docs/ATEN_TREE.md b/docs/ATEN_TREE.md new file mode 100644 index 0000000..c95fae9 --- /dev/null +++ b/docs/ATEN_TREE.md @@ -0,0 +1,74 @@ +# ATen Compiler Tree + +This summarizes the current `aten/` package layout. Generated `__pycache__/` +directories are intentionally omitted. + +```text +aten/ +|-- __init__.py # Public ATen package exports +|-- native_ops.yaml # Operator registry spec: signatures and dispatch targets +|-- isa_builder.py # Typed ISA instruction/register builder and legalization +|-- model_extract.py # HuggingFace model config/layer/embedding extraction helpers +|-- plena_frontend.py # HF decoder model -> PLENA program -> ISA text +|-- reference.py # CPU golden/reference math and MXFP/BF16 helpers +|-- vram_stage_compare.py # Debug tooling for VRAM stage comparisons +| +|-- ops/ # ATen-style operator dispatch layer +| |-- __init__.py # User-facing ops.* dispatch functions +| |-- registry.py # Backend registry: CPU vs PLENA implementation lookup +| | +| |-- cpu/ # PyTorch reference backend +| | |-- __init__.py +| | |-- attention_ops.py # CPU flash attention reference +| | |-- conv_ops.py # CPU conv reference +| | |-- embedding_ops.py # CPU embedding/RoPE reference +| | |-- ffn_ops.py # CPU FFN reference +| | |-- linear_ops.py # CPU linear reference +| | |-- norm_ops.py # CPU RMS/layer norm reference +| | +-- softmax_ops.py # CPU softmax reference +| | +| +-- plena/ # PLENA backend operator wrappers +| |-- __init__.py +| |-- attention_ops.py # ops.flash_attention -> prog.flash_attention +| |-- conv_ops.py # conv lowering / PLENA conv codegen helper +| |-- embedding_ops.py # embedding_add / rope wrappers +| |-- ffn_ops.py # ops.ffn -> prog.ffn +| |-- linear_ops.py # ops.linear -> prog.linear_projection +| |-- norm_ops.py # rms_norm / layer_norm wrappers +| +-- softmax_ops.py # PLENA softmax lowering +| +|-- plena/ # Canonical PLENA compiler implementation package +| |-- __init__.py # Canonical exports: PlenaCompiler, IsaCompiler, vars, constants +| |-- compiler.py # Top-level PlenaCompiler composition class +| |-- constants.py # BLEN/MLEN/immediate constants +| |-- vars.py # Tensor/Input/VRAM/FP variable descriptors +| |-- registers.py # GP/ADDR/FP register allocation helpers +| |-- memory.py # Memory layout/address helpers +| |-- memory_state.py # Compiler-owned tensor/input/fp memory state +| | +| |-- program_tensors.py # Program-level tensor allocation/load/store helpers +| |-- program_fp_tile_ops.py # Program-level FP/tile scalar-vector operations +| |-- program_matrix_ops.py # Program-level matrix/projection/FFN operations +| |-- program_attention.py # Program-level flash attention operation +| | +| |-- isa_compiler.py # Low-level ISA emitter base and typed emit path +| |-- isa_emit.py # Generic emit helpers +| |-- isa_fp_ops.py # Scalar FP/FPRAM/tile FP ISA helpers +| |-- isa_tile_rows.py # Tile-row unary/binary loop emitters +| |-- isa_matrix.py # Matrix/projection/load/store ISA emitters +| +-- isa_attention.py # Attention-specific ISA helpers +| ++-- tests/ + |-- __init__.py + |-- test_bf16_numerical_stability.py + |-- test_plena_compiler.py + +-- test_quantization_ablation.py +``` + +Key points: + +- `aten/plena/` is the canonical compiler implementation package. +- `aten/ops/` is the ATen-style dispatcher surface. +- `aten/plena_frontend.py` is the HuggingFace/ATen frontend that drives model + compilation. +- The old `aten/plena_compiler.py` compatibility facade has been removed. From f110f05f5167fa9c27454fe73634454548644795 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 17:44:03 +0100 Subject: [PATCH 29/32] refactor: remove remaining compatibility shims --- asm_templates/__init__.py | 2 +- asm_templates/flash_attn_asm.py | 50 ------------------------ asm_templates/flashattn/memory_layout.md | 0 aten/ops/plena/attention_ops.py | 2 +- aten/ops/plena/embedding_ops.py | 2 +- aten/ops/plena/ffn_ops.py | 2 +- aten/ops/plena/linear_ops.py | 2 +- aten/tests/test_plena_compiler.py | 4 +- aten/tests/test_quantization_ablation.py | 4 +- doc/precision.svh | 2 +- docs/ARCHITECTURE.md | 15 ++++--- docs/COMPILATION_PIPELINES.md | 18 ++++----- generator/tests/test_generator_e2e.py | 2 +- sim_env_utils/__init__.py | 5 +-- 14 files changed, 32 insertions(+), 78 deletions(-) delete mode 100644 asm_templates/flash_attn_asm.py delete mode 100644 asm_templates/flashattn/memory_layout.md diff --git a/asm_templates/__init__.py b/asm_templates/__init__.py index d6e77e8..8c5c02f 100644 --- a/asm_templates/__init__.py +++ b/asm_templates/__init__.py @@ -18,7 +18,7 @@ from .elementwise_add_asm import elementwise_add_asm from .embedding_asm import embedding_asm from .ffn_asm import ffn_asm, ffn_intermediate_asm, ffn_up_silu_asm -from .flash_attn_asm import flash_attn_asm +from .flashattn import flash_attn_asm from .gelu_asm import gelu_asm from .im2col_asm import im2col_asm from .im2col_asm_no_shift import im2col_asm_no_shift diff --git a/asm_templates/flash_attn_asm.py b/asm_templates/flash_attn_asm.py deleted file mode 100644 index 6d507bf..0000000 --- a/asm_templates/flash_attn_asm.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Flash Attention assembly code generation. - -This module has been refactored into the flashattn package. -This file re-exports all functions for backward compatibility. - -For new code, prefer importing from compiler.asm_templates.flashattn directly. -""" - -# Re-export all functions from the flashattn package for backward compatibility -from .flashattn import ( - computing_o_code as _computing_o_code, -) -from .flashattn import ( - computing_pv_code as _computing_pv_code, -) -from .flashattn import ( - computing_row_wise_scaling_code as _computing_row_wise_scaling_code, -) -from .flashattn import ( - flash_attn_asm, - qkt_multiply, -) -from .flashattn import ( - online_softmax_code as _online_softmax_code, -) -from .flashattn import ( - reset_fpsram_code as _reset_fpsram_code, -) -from .flashattn import ( - reset_kv_prefetch as _reset_kv_prefetch, -) -from .flashattn import ( - reset_vssram_code as _reset_vssram_code, -) - -# Also export IMM2_BOUND for backward compatibility -IMM2_BOUND = 2**18 - 1 - -__all__ = [ - "IMM2_BOUND", - "_computing_o_code", - "_computing_pv_code", - "_computing_row_wise_scaling_code", - "_online_softmax_code", - "_reset_fpsram_code", - "_reset_kv_prefetch", - "_reset_vssram_code", - "flash_attn_asm", - "qkt_multiply", -] diff --git a/asm_templates/flashattn/memory_layout.md b/asm_templates/flashattn/memory_layout.md deleted file mode 100644 index e69de29..0000000 diff --git a/aten/ops/plena/attention_ops.py b/aten/ops/plena/attention_ops.py index 3f07564..e1433ce 100644 --- a/aten/ops/plena/attention_ops.py +++ b/aten/ops/plena/attention_ops.py @@ -1,4 +1,4 @@ -"""PLENA backend compatibility shim for Flash Attention.""" +"""PLENA backend wrapper for Flash Attention.""" def flash_attention_plena(prog, Q, K, V, scale=None, hq=1, hkv=1, h_qkv=None, causal_mask=None): diff --git a/aten/ops/plena/embedding_ops.py b/aten/ops/plena/embedding_ops.py index f4ab13b..5497f88 100644 --- a/aten/ops/plena/embedding_ops.py +++ b/aten/ops/plena/embedding_ops.py @@ -1,4 +1,4 @@ -"""PLENA backend compatibility shims for positional encoding operators.""" +"""PLENA backend wrappers for positional encoding operators.""" def embedding_add_plena(prog, input_var, pos_weight_var): diff --git a/aten/ops/plena/ffn_ops.py b/aten/ops/plena/ffn_ops.py index 726fc47..3fdcab5 100644 --- a/aten/ops/plena/ffn_ops.py +++ b/aten/ops/plena/ffn_ops.py @@ -1,4 +1,4 @@ -"""PLENA backend compatibility shim for FFN.""" +"""PLENA backend wrapper for FFN.""" def ffn_plena(prog, input_var, w_gate, w_up, w_down): diff --git a/aten/ops/plena/linear_ops.py b/aten/ops/plena/linear_ops.py index 56d801c..7319e9f 100644 --- a/aten/ops/plena/linear_ops.py +++ b/aten/ops/plena/linear_ops.py @@ -1,4 +1,4 @@ -"""PLENA backend compatibility shims for linear operators.""" +"""PLENA backend wrappers for linear operators.""" def linear_projection_plena(prog, input_var, weight_var, name: str = "linear_out"): diff --git a/aten/tests/test_plena_compiler.py b/aten/tests/test_plena_compiler.py index 8985138..a10932e 100644 --- a/aten/tests/test_plena_compiler.py +++ b/aten/tests/test_plena_compiler.py @@ -1,13 +1,13 @@ """Unit tests for PlenaCompiler ATen path (no emulator needed). -Run: PYTHONPATH=.:tools:compiler python3 compiler/aten/tests/test_plena_compiler.py +Run: PYTHONPATH=.:tools:.. python3 aten/tests/test_plena_compiler.py """ import sys import os # Insert PLENA_Simulator root and tools/ so imports resolve correctly regardless -# of how the test is invoked (direct python3 or via PYTHONPATH=.:tools:compiler). +# of how the test is invoked (direct python3 or via PYTHONPATH=.:tools:..). _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _SIM_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR)))) _SIM_ROOT = os.path.join(_SIM_ROOT, "PLENA_Simulator") diff --git a/aten/tests/test_quantization_ablation.py b/aten/tests/test_quantization_ablation.py index 58fdb94..13762f1 100644 --- a/aten/tests/test_quantization_ablation.py +++ b/aten/tests/test_quantization_ablation.py @@ -9,8 +9,8 @@ fp32 (fp32 + fp32) ~99% allclose ← confirms quantization is sole cause Usage: - pytest compiler/aten/tests/test_quantization_ablation.py -v -s - python3 compiler/aten/tests/test_quantization_ablation.py [--layers N] + pytest aten/tests/test_quantization_ablation.py -v -s + python3 aten/tests/test_quantization_ablation.py [--layers N] """ import argparse diff --git a/doc/precision.svh b/doc/precision.svh index a8a1f1c..2652c7d 100644 --- a/doc/precision.svh +++ b/doc/precision.svh @@ -1,6 +1,6 @@ // precision.svh - PLENA numeric precision parameters. // -// Intentionally minimal: compiler/generator/parser/hardware_parser.py (and the +// Intentionally minimal: generator/parser/hardware_parser.py (and the // helpers under tools/) tolerate an empty precision file and fall back to the // following per-parameter defaults when the corresponding `parameter` line is // absent: diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 6fe9cd5..483d994 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -3,7 +3,7 @@ ## Directory Structure ``` -compiler/ +PLENA_Compiler/ |-- asm_templates/ # Shared ISA instruction emitters (used by BOTH pipelines) | |-- flashattn/ # Flash attention (qkt, pv, online_softmax, output, reset) | |-- ffn_asm.py # FFN (gate/up/down with K-split) @@ -26,10 +26,15 @@ compiler/ | +-- reset_reg_asm.py # Register reset helpers | |-- aten/ # Pipeline 1: ATen compilation backend -| |-- plena_compiler.py # PlenaCompiler class (VRAM/MRAM/FPRAM management) +| |-- plena_frontend.py # HF model -> PLENA program -> ISA text +| |-- plena/ # Canonical PlenaCompiler implementation package +| | |-- compiler.py # PlenaCompiler composition class +| | |-- memory_state.py # Tensor/input/FP memory state +| | |-- program_*.py # High-level program operations +| | +-- isa_*.py # Low-level ISA emitters | +-- ops/ # Registered ATen op implementations | |-- registry.py # Op dispatch registry -| |-- plena/ # PLENA backend (linear, attention, ffn, norm, conv, softmax, embedding) +| |-- plena/ # PLENA backend wrappers | +-- cpu/ # CPU reference implementations | |-- generator/ # Pipeline 2: Config-driven code generation @@ -82,7 +87,7 @@ compiler/ nn.Module -> torch.export (FX graph of ATen ops) -> op dispatch (aten/ops/registry.py) - -> PlenaCompiler (aten/plena_compiler.py) + -> PlenaCompiler (aten/plena/) - VRAM/MRAM/FPRAM allocation - HBM weight layout + address register init - calls asm_templates/* for ISA emission @@ -100,7 +105,7 @@ HF config.json -> code_gen_pass (generator/passes/code_gen.py) - dispatches each node to asm_templates/* -> assembler/ (ASM -> .mem binary) - -> emulator (runs but numerically incorrect -- no addr reg init) + -> emulator smoke / utilization analysis ``` See `docs/COMPILATION_PIPELINES.md` for detailed comparison, known gaps, diff --git a/docs/COMPILATION_PIPELINES.md b/docs/COMPILATION_PIPELINES.md index 38dc7d2..04d93a0 100644 --- a/docs/COMPILATION_PIPELINES.md +++ b/docs/COMPILATION_PIPELINES.md @@ -49,7 +49,7 @@ backends, and weight-handling strategies. embedding). A CPU fallback registry (`aten/ops/cpu/`) provides reference implementations for ops not yet hardware-mapped. -3. **Backend**: `PlenaCompiler` (`aten/plena_compiler.py`) manages all +3. **Backend**: `PlenaCompiler` (`aten/plena/`) manages all hardware state -- VRAM allocation, MRAM tile scheduling, FPRAM slot assignment, HBM weight layout, and address register initialization (`C_SET_ADDR_REG`). It calls into `asm_templates/` to emit ISA strings. @@ -65,7 +65,8 @@ backends, and weight-handling strategies. | File | Role | |------|------| -| `aten/plena_compiler.py` | PlenaCompiler class (VRAM/MRAM/FPRAM management, ISA emission) | +| `aten/plena/` | Canonical PlenaCompiler implementation package | +| `aten/plena_frontend.py` | HuggingFace model frontend that drives ATen compilation | | `aten/ops/plena/*.py` | Registered ATen op implementations (linear, attention, ffn, norm, conv, softmax, embedding) | | `aten/ops/cpu/*.py` | CPU reference fallbacks | | `aten/ops/registry.py` | Op dispatch registry | @@ -88,8 +89,7 @@ pass with 98-100% allclose. ## Pipeline 2: Generator Path -**Status**: Generates valid ISA for analysis. Numerically incomplete -(HBM address registers uninitialized). +**Status**: Generates valid ISA for structural analysis and smoke tests. ### How it works @@ -104,16 +104,16 @@ pass with 98-100% allclose. 3. **Backend**: `code_gen_pass` (`generator/passes/code_gen.py`) walks the symbolic graph and dispatches each node to the appropriate `asm_templates/` function, passing scheduler-derived register and - address parameters. + address parameters. It emits address-register initialization for HBM-backed + weights before the generated compute body. 4. **Weight loading**: For E2E smoke tests, `test_generator_e2e.py` has a `_build_hbm_from_hf_weights` helper that loads real weights. The standard codegen path does not touch weights at all. -5. **Output**: A `.asm` file that assembles cleanly and runs on the emulator - (the instructions are structurally valid), but produces numerically - incorrect results because HBM address registers (`C_SET_ADDR_REG`) are - not initialized. +5. **Output**: A `.asm` file that assembles cleanly and runs on the emulator. + The generator path is still primarily used for structural codegen and + utilization work; the ATen path remains the numerically verified flow. ### Key files diff --git a/generator/tests/test_generator_e2e.py b/generator/tests/test_generator_e2e.py index 6efd40b..82e61e8 100644 --- a/generator/tests/test_generator_e2e.py +++ b/generator/tests/test_generator_e2e.py @@ -308,7 +308,7 @@ def _build_fp_sram_preload( slot (or the harness needs to refresh slot 5 between text and vision runs) to fix this. Tracking issue: TODO. - Slot map (from compiler/generator/scheduler/mem_layout_lib.json): + Slot map (from generator/scheduler/mem_layout_lib.json): 0: infinity — softmax masking sentinel (use a large fp16 negative) 1: eps — RMSNorm epsilon 2: hid_reciprocal — 1.0 / hidden_size diff --git a/sim_env_utils/__init__.py b/sim_env_utils/__init__.py index a40e6d3..b2e4ac8 100644 --- a/sim_env_utils/__init__.py +++ b/sim_env_utils/__init__.py @@ -1,4 +1,3 @@ -from .build_env import create_mem_for_sim +from .build_env import create_mem_for_sim as create_mem_for_sim -# Legacy alias -build_sim_env = create_mem_for_sim +__all__ = ["create_mem_for_sim"] From a344f1d69e931b3cfac27fb2845073338a3b0729 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 17:51:27 +0100 Subject: [PATCH 30/32] docs: refresh generator path comment --- generator/tests/test_generator_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generator/tests/test_generator_e2e.py b/generator/tests/test_generator_e2e.py index 82e61e8..897ceef 100644 --- a/generator/tests/test_generator_e2e.py +++ b/generator/tests/test_generator_e2e.py @@ -89,7 +89,7 @@ def _build_hbm_from_hf_weights( ) -> dict: """Populate hbm_for_behave_sim.bin with real HF model weights. - Mirrors compiler/sim_env_utils/build_env.py::create_mem_for_sim but + Mirrors sim_env_utils/build_env.py::create_mem_for_sim but operates directly on HF tensors (no intermediate .pt files) and writes each weight block at the scheduler-assigned HBM offset. From a74454d37843b4b9a5012bfc1c14404e5b5faa66 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Tue, 12 May 2026 17:54:16 +0100 Subject: [PATCH 31/32] docs: refresh sim env path comments --- sim_env_utils/build_env.py | 2 +- sim_env_utils/build_sys_tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sim_env_utils/build_env.py b/sim_env_utils/build_env.py index 9b8696c..6c99cc0 100644 --- a/sim_env_utils/build_env.py +++ b/sim_env_utils/build_env.py @@ -16,7 +16,7 @@ from .build_sys_tools import env_setup, init_mem # noqa: E402 -# Project root is 3 levels up from compiler/sim_env_utils/ +# Project root is 3 levels up from PLENA_Compiler/sim_env_utils/ _PROJECT_ROOT = Path(__file__).parent.parent.parent logger = get_logger("testbench") diff --git a/sim_env_utils/build_sys_tools.py b/sim_env_utils/build_sys_tools.py index a134b40..bd39558 100644 --- a/sim_env_utils/build_sys_tools.py +++ b/sim_env_utils/build_sys_tools.py @@ -12,7 +12,7 @@ from compiler.assembler.assembly_to_binary import AssemblyToBinary -# Project root is 3 levels up from compiler/sim_env_utils/ +# Project root is 3 levels up from PLENA_Compiler/sim_env_utils/ _PROJECT_ROOT = Path(__file__).parent.parent.parent From 2af064ff454b1db57390c12f945ab6e3d49cf650 Mon Sep 17 00:00:00 2001 From: booth-algo Date: Wed, 13 May 2026 00:54:39 +0100 Subject: [PATCH 32/32] refactor: move ATen e2e runner out of generator --- aten/e2e_runner.py | 323 +++++++++++++++++++++++++ docs/ARCHITECTURE.md | 4 +- docs/ATEN_TREE.md | 3 + docs/COMPILATION_PIPELINES.md | 6 +- generator/aten_runner.py | 324 +------------------------- generator/runner.py | 6 +- generator/tests/test_generator_e2e.py | 2 +- 7 files changed, 340 insertions(+), 328 deletions(-) create mode 100644 aten/e2e_runner.py diff --git a/aten/e2e_runner.py b/aten/e2e_runner.py new file mode 100644 index 0000000..70ffc48 --- /dev/null +++ b/aten/e2e_runner.py @@ -0,0 +1,323 @@ +"""ATen-backed end-to-end runner. + +This wraps the verified ATen compilation path: + + HuggingFace model -> PlenaCompiler + ops.* -> ISA -> emulator -> golden check + +The symbolic generator path is separate and remains under ``generator.runner +codegen``. + +Usage: + python -m compiler.aten.e2e_runner AICrossSim/clm-60m --seq-len 32 +""" + +import sys +import time +from pathlib import Path + +# --------------------------------------------------------------------------- +# Repo root bootstrap — mirror the same sys.path setup used by the existing +# test infrastructure so imports resolve regardless of cwd. +# --------------------------------------------------------------------------- +_COMPILER_ROOT = Path(__file__).resolve().parents[1] # PLENA_Compiler/ +_REPO_ROOT = _COMPILER_ROOT.parent +for _p in [str(_REPO_ROOT), str(_REPO_ROOT / "tools"), str(_COMPILER_ROOT)]: + if _p not in sys.path: + sys.path.insert(0, _p) + + +def run_aten_e2e( + model_id: str, + seq_len: int = 64, + num_layers: int = 1, + build_dir: str | None = None, + layer_idx: int = 0, + hidden_size: int = 64, + inter_dim: int = 128, + trust_remote_code: bool = False, + partial_load: bool = False, +) -> dict: + """Run a HF model through the ATen compilation path end-to-end. + + Steps: + 1. Load model config + layer weights from HuggingFace + 2. Build ISA via PlenaCompiler + ops.* (numerically verified path) + 3. Set up sim environment (ASM + HBM weights + FPRAM constants) + 4. Run Rust emulator + 5. Compare VRAM output against golden PyTorch reference + + Returns dict with: + passed: bool + allclose_match_rate: float (percentage) + max_error: float + mae: float + mse: float + elapsed_s: float (wall-clock seconds) + model_id: str + layer_idx: int + num_layers: int + seq_len: int + hidden_size: int + inter_dim: int + build_dir: str + """ + from transactional_emulator.testbench.emulator_runner import compare_emulator_output + from transactional_emulator.testbench.model_layer_test_builder import ( + build_and_run_decoder_test, + build_and_run_multi_layer_test, + get_model_dims, + slice_dims_for_sim, + ) + + t0 = time.time() + + # Resolve build directory + if build_dir is None: + safe_name = model_id.replace("/", "_") + build_dir = str( + Path("/tmp") / f"aten_e2e_{safe_name}_sl{seq_len}_l{layer_idx}" + ) + build_path = Path(build_dir) + + # ------------------------------------------------------------------ + # [1/5] Probe model config + # ------------------------------------------------------------------ + print(f"[1/5] Probing model config: {model_id}") + try: + full_dims = get_model_dims(model_id) + except (OSError, ConnectionError) as exc: + print(f"[SKIP] HuggingFace model '{model_id}' unavailable: {exc}") + return { + "passed": False, + "error": str(exc), + "model_id": model_id, + } + sim_dims = slice_dims_for_sim(full_dims, hidden_slice=hidden_size, inter_slice=inter_dim) + print(f" Full dims: hidden={full_dims.hidden_size}, inter={full_dims.inter_dim}, " + f"heads={full_dims.num_heads}, kv_heads={full_dims.num_kv_heads}, head_dim={full_dims.head_dim}") + print(f" Sim dims: hidden={sim_dims.hidden_size}, inter={sim_dims.inter_dim}") + + # ------------------------------------------------------------------ + # [2/5] Build ISA + golden reference + sim env via build_and_run_decoder_test + # + # We call the proven function directly — it handles: + # - Weight loading + slicing + # - PlenaCompiler ISA generation + # - create_sim_env + create_mem_for_sim + # - Golden reference computation + # - Emulator execution + comparison + # + # For multi-layer: iterate layers (each is independent at sim scale). + # ------------------------------------------------------------------ + results_per_layer = [] + + if num_layers == 1: + # Single layer: use proven single-layer path (with RoPE) + current_layer = layer_idx + asm_name = f"aten_{model_id.split('/')[-1]}_l{current_layer}" + layer_build = build_path / f"layer_{current_layer}" + + print(f"\n[2/5] Building ISA for layer {current_layer} via PlenaCompiler + ops.*") + print(f"[3/5] Setting up sim environment: {layer_build}") + print("[4/5] Running Rust transactional emulator") + + extra_kwargs = {} + if trust_remote_code: + extra_kwargs["trust_remote_code"] = True + if partial_load: + extra_kwargs["partial_load"] = True + + try: + build_and_run_decoder_test( + model_id=model_id, + asm_name=asm_name, + build_dir=layer_build, + layer_idx=current_layer, + seq_len=seq_len, + hidden_size=hidden_size, + inter_dim=inter_dim, + **extra_kwargs, + ) + comp_results, _comp_params = compare_emulator_output(layer_build) + results_per_layer.append({ + "layer": current_layer, + "passed": True, + "allclose_match_rate": comp_results["allclose_match_rate"], + "max_error": comp_results["max_error"], + "mae": comp_results["mae"], + "mse": comp_results["mse"], + }) + except SystemExit as e: + if e.code == 0: + return { + "passed": False, + "error": "HuggingFace model unavailable (skipped)", + "model_id": model_id, + } + try: + comp_results, _comp_params = compare_emulator_output(layer_build) + results_per_layer.append({ + "layer": current_layer, + "passed": False, + "allclose_match_rate": comp_results["allclose_match_rate"], + "max_error": comp_results["max_error"], + "mae": comp_results["mae"], + "mse": comp_results["mse"], + }) + except Exception: + results_per_layer.append({ + "layer": current_layer, + "passed": False, + "error": f"Emulator comparison failed after exit code {e.code}", + }) + else: + # Multi-layer: chain N layers with residual connections (no RoPE) + asm_name = f"aten_{model_id.split('/')[-1]}_chain{num_layers}" + chain_build = build_path / f"chain_{num_layers}layers" + + print(f"\n[2/5] Building chained {num_layers}-layer ISA via PlenaCompiler + ops.*") + print(f"[3/5] Setting up sim environment: {chain_build}") + print("[4/5] Running Rust transactional emulator") + + extra_kwargs = {} + if trust_remote_code: + extra_kwargs["trust_remote_code"] = True + if partial_load: + extra_kwargs["partial_load"] = True + + try: + build_and_run_multi_layer_test( + model_id=model_id, + asm_name=asm_name, + build_dir=chain_build, + num_layers=num_layers, + layer_idx_start=layer_idx, + seq_len=seq_len, + hidden_size=hidden_size, + inter_dim=inter_dim, + **extra_kwargs, + ) + comp_results, _comp_params = compare_emulator_output(chain_build) + results_per_layer.append({ + "layer": f"chain_{num_layers}", + "passed": True, + "allclose_match_rate": comp_results["allclose_match_rate"], + "max_error": comp_results["max_error"], + "mae": comp_results["mae"], + "mse": comp_results["mse"], + }) + except SystemExit as e: + if e.code == 0: + return { + "passed": False, + "error": "HuggingFace model unavailable (skipped)", + "model_id": model_id, + } + try: + comp_results, _comp_params = compare_emulator_output(chain_build) + results_per_layer.append({ + "layer": f"chain_{num_layers}", + "passed": False, + "allclose_match_rate": comp_results["allclose_match_rate"], + "max_error": comp_results["max_error"], + "mae": comp_results["mae"], + "mse": comp_results["mse"], + }) + except Exception: + results_per_layer.append({ + "layer": f"chain_{num_layers}", + "passed": False, + "error": f"Emulator comparison failed after exit code {e.code}", + }) + + elapsed = time.time() - t0 + + # ------------------------------------------------------------------ + # [5/5] Aggregate results + # ------------------------------------------------------------------ + print(f"\n[5/5] Results summary ({elapsed:.1f}s elapsed)") + all_passed = all(r.get("passed", False) for r in results_per_layer) + + # Use first layer's metrics for the top-level result + first = results_per_layer[0] if results_per_layer else {} + + summary = { + "passed": all_passed, + "allclose_match_rate": first.get("allclose_match_rate", 0.0), + "max_error": first.get("max_error", float("inf")), + "mae": first.get("mae", float("inf")), + "mse": first.get("mse", float("inf")), + "elapsed_s": elapsed, + "model_id": model_id, + "layer_idx": layer_idx, + "num_layers": num_layers, + "seq_len": seq_len, + "hidden_size": hidden_size, + "inter_dim": inter_dim, + "build_dir": str(build_path), + "layers": results_per_layer, + } + + for r in results_per_layer: + status = "PASS" if r.get("passed") else "FAIL" + match = r.get("allclose_match_rate", "N/A") + if isinstance(match, float): + match = f"{match:.2f}%" + print(f" Layer {r.get('layer', '?')}: [{status}] allclose={match}") + + if all_passed: + print(f"\n[ATen e2e PASSED] {model_id} — {num_layers} layer(s), " + f"allclose={first.get('allclose_match_rate', 0):.2f}%") + else: + print(f"\n[ATen e2e FAILED] {model_id} — see per-layer results above") + + return summary + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Run HF model through ATen compilation path (PlenaCompiler + ops.*)", + prog="python -m compiler.aten.e2e_runner", + ) + parser.add_argument("model_id", help="HuggingFace model ID (e.g. AICrossSim/clm-60m)") + parser.add_argument("--seq-len", type=int, default=64, + help="Sequence length (default: 64)") + parser.add_argument("--num-layers", type=int, default=1, + help="Number of decoder layers to test (default: 1)") + parser.add_argument("--layer-idx", type=int, default=0, + help="Starting layer index (default: 0)") + parser.add_argument("--hidden-size", type=int, default=64, + help="Hidden dimension clipped to sim limits (default: 64)") + parser.add_argument("--inter-dim", type=int, default=128, + help="FFN intermediate dimension clipped to sim limits (default: 128)") + parser.add_argument("--build-dir", type=str, default=None, + help="Build directory for sim artifacts (default: /tmp/aten_e2e_...)") + parser.add_argument("--trust-remote-code", action="store_true", + help="Trust remote code for HF model loading") + parser.add_argument("--partial-load", action="store_true", + help="Load only needed weight shards (for large models)") + + args = parser.parse_args() + + result = run_aten_e2e( + model_id=args.model_id, + seq_len=args.seq_len, + num_layers=args.num_layers, + build_dir=args.build_dir, + layer_idx=args.layer_idx, + hidden_size=args.hidden_size, + inter_dim=args.inter_dim, + trust_remote_code=args.trust_remote_code, + partial_load=args.partial_load, + ) + + sys.exit(0 if result["passed"] else 1) + + +if __name__ == "__main__": + main() diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 483d994..a0ff701 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -27,6 +27,7 @@ PLENA_Compiler/ | |-- aten/ # Pipeline 1: ATen compilation backend | |-- plena_frontend.py # HF model -> PLENA program -> ISA text +| |-- e2e_runner.py # ATen e2e runner: PlenaCompiler -> emulator -> golden | |-- plena/ # Canonical PlenaCompiler implementation package | | |-- compiler.py # PlenaCompiler composition class | | |-- memory_state.py # Tensor/input/FP memory state @@ -38,8 +39,7 @@ PLENA_Compiler/ | +-- cpu/ # CPU reference implementations | |-- generator/ # Pipeline 2: Config-driven code generation -| |-- runner.py # CLI entry point (codegen, aten, utilization modes) -| |-- aten_runner.py # ATen backend bridge (wraps PlenaCompiler for E2E) +| |-- runner.py # CLI entry point for codegen/utilization | |-- parser/ # HF config -> symbolic graph | | |-- llm_parser.py # LLMModelParser (text decoder graphs) | | +-- hardware_parser.py # configuration.svh / precision.svh reader diff --git a/docs/ATEN_TREE.md b/docs/ATEN_TREE.md index c95fae9..e5800da 100644 --- a/docs/ATEN_TREE.md +++ b/docs/ATEN_TREE.md @@ -10,6 +10,7 @@ aten/ |-- isa_builder.py # Typed ISA instruction/register builder and legalization |-- model_extract.py # HuggingFace model config/layer/embedding extraction helpers |-- plena_frontend.py # HF decoder model -> PLENA program -> ISA text +|-- e2e_runner.py # HF model -> ATen compiler -> emulator -> golden check |-- reference.py # CPU golden/reference math and MXFP/BF16 helpers |-- vram_stage_compare.py # Debug tooling for VRAM stage comparisons | @@ -71,4 +72,6 @@ Key points: - `aten/ops/` is the ATen-style dispatcher surface. - `aten/plena_frontend.py` is the HuggingFace/ATen frontend that drives model compilation. +- `aten/e2e_runner.py` runs the ATen compiler path through the emulator and + golden comparison. - The old `aten/plena_compiler.py` compatibility facade has been removed. diff --git a/docs/COMPILATION_PIPELINES.md b/docs/COMPILATION_PIPELINES.md index 04d93a0..69e2ccc 100644 --- a/docs/COMPILATION_PIPELINES.md +++ b/docs/COMPILATION_PIPELINES.md @@ -70,14 +70,14 @@ backends, and weight-handling strategies. | `aten/ops/plena/*.py` | Registered ATen op implementations (linear, attention, ffn, norm, conv, softmax, embedding) | | `aten/ops/cpu/*.py` | CPU reference fallbacks | | `aten/ops/registry.py` | Op dispatch registry | -| `generator/aten_runner.py` | E2E harness: model load -> compile -> emulate -> verify | +| `aten/e2e_runner.py` | E2E harness: model load -> compile -> emulate -> verify | | `sim_env_utils/build_env.py` | Simulation environment builder | ### Entry points - **Single-layer tests**: `model_layer_test_builder.py::build_and_run_decoder_test` -- **Full-model E2E**: `generator/aten_runner.py::run_aten_e2e` -- **CLI**: `python -m generator.runner aten --seq-len 32 --num-layers 1` +- **Full-model E2E**: `aten/e2e_runner.py::run_aten_e2e` +- **CLI**: `python -m compiler.aten.e2e_runner --seq-len 32 --num-layers 1` ### Test suite diff --git a/generator/aten_runner.py b/generator/aten_runner.py index 62bb670..cef8d18 100644 --- a/generator/aten_runner.py +++ b/generator/aten_runner.py @@ -1,326 +1,12 @@ -""" -ATen-backend runner for the generator. - -Wraps the proven ATen compilation path (PlenaCompiler + ops.*) to provide -end-to-end model compilation + emulation + numerical verification. This -is the same pipeline that passes 18/18 tests in the ATen test suite. - -The generator's value is in config parsing and symbolic graph analysis; -actual ISA compilation is best left to PlenaCompiler. - -Usage (standalone): - python -m generator.aten_runner AICrossSim/clm-60m --seq-len 32 +"""Deprecated alias for the ATen e2e runner. -Usage (from generator.runner): - python -m generator.runner aten AICrossSim/clm-60m --seq-len 32 --num-layers 1 +Use ``python -m compiler.aten.e2e_runner`` or import +``compiler.aten.e2e_runner.run_aten_e2e`` directly. """ -import sys -import time -from pathlib import Path - -# --------------------------------------------------------------------------- -# Repo root bootstrap — mirror the same sys.path setup used by the existing -# test infrastructure so imports resolve regardless of cwd. -# --------------------------------------------------------------------------- -_COMPILER_ROOT = Path(__file__).resolve().parents[1] # compiler/ -_REPO_ROOT = _COMPILER_ROOT.parent -for _p in [str(_REPO_ROOT), str(_REPO_ROOT / "tools"), str(_COMPILER_ROOT)]: - if _p not in sys.path: - sys.path.insert(0, _p) - - -def run_aten_e2e( - model_id: str, - seq_len: int = 64, - num_layers: int = 1, - build_dir: str | None = None, - layer_idx: int = 0, - hidden_size: int = 64, - inter_dim: int = 128, - trust_remote_code: bool = False, - partial_load: bool = False, -) -> dict: - """Run a HF model through the ATen compilation path end-to-end. - - Steps: - 1. Load model config + layer weights from HuggingFace - 2. Build ISA via PlenaCompiler + ops.* (numerically verified path) - 3. Set up sim environment (ASM + HBM weights + FPRAM constants) - 4. Run Rust emulator - 5. Compare VRAM output against golden PyTorch reference - - Returns dict with: - passed: bool - allclose_match_rate: float (percentage) - max_error: float - mae: float - mse: float - elapsed_s: float (wall-clock seconds) - model_id: str - layer_idx: int - num_layers: int - seq_len: int - hidden_size: int - inter_dim: int - build_dir: str - """ - from transactional_emulator.testbench.emulator_runner import compare_emulator_output - from transactional_emulator.testbench.model_layer_test_builder import ( - build_and_run_decoder_test, - build_and_run_multi_layer_test, - get_model_dims, - slice_dims_for_sim, - ) - - t0 = time.time() - - # Resolve build directory - if build_dir is None: - safe_name = model_id.replace("/", "_") - build_dir = str( - Path("/tmp") / f"aten_e2e_{safe_name}_sl{seq_len}_l{layer_idx}" - ) - build_path = Path(build_dir) - - # ------------------------------------------------------------------ - # [1/5] Probe model config - # ------------------------------------------------------------------ - print(f"[1/5] Probing model config: {model_id}") - try: - full_dims = get_model_dims(model_id) - except (OSError, ConnectionError) as exc: - print(f"[SKIP] HuggingFace model '{model_id}' unavailable: {exc}") - return { - "passed": False, - "error": str(exc), - "model_id": model_id, - } - sim_dims = slice_dims_for_sim(full_dims, hidden_slice=hidden_size, inter_slice=inter_dim) - print(f" Full dims: hidden={full_dims.hidden_size}, inter={full_dims.inter_dim}, " - f"heads={full_dims.num_heads}, kv_heads={full_dims.num_kv_heads}, head_dim={full_dims.head_dim}") - print(f" Sim dims: hidden={sim_dims.hidden_size}, inter={sim_dims.inter_dim}") - - # ------------------------------------------------------------------ - # [2/5] Build ISA + golden reference + sim env via build_and_run_decoder_test - # - # We call the proven function directly — it handles: - # - Weight loading + slicing - # - PlenaCompiler ISA generation - # - create_sim_env + create_mem_for_sim - # - Golden reference computation - # - Emulator execution + comparison - # - # For multi-layer: iterate layers (each is independent at sim scale). - # ------------------------------------------------------------------ - results_per_layer = [] - - if num_layers == 1: - # Single layer: use proven single-layer path (with RoPE) - current_layer = layer_idx - asm_name = f"aten_{model_id.split('/')[-1]}_l{current_layer}" - layer_build = build_path / f"layer_{current_layer}" - - print(f"\n[2/5] Building ISA for layer {current_layer} via PlenaCompiler + ops.*") - print(f"[3/5] Setting up sim environment: {layer_build}") - print("[4/5] Running Rust transactional emulator") - - extra_kwargs = {} - if trust_remote_code: - extra_kwargs["trust_remote_code"] = True - if partial_load: - extra_kwargs["partial_load"] = True - - try: - build_and_run_decoder_test( - model_id=model_id, - asm_name=asm_name, - build_dir=layer_build, - layer_idx=current_layer, - seq_len=seq_len, - hidden_size=hidden_size, - inter_dim=inter_dim, - **extra_kwargs, - ) - comp_results, _comp_params = compare_emulator_output(layer_build) - results_per_layer.append({ - "layer": current_layer, - "passed": True, - "allclose_match_rate": comp_results["allclose_match_rate"], - "max_error": comp_results["max_error"], - "mae": comp_results["mae"], - "mse": comp_results["mse"], - }) - except SystemExit as e: - if e.code == 0: - return { - "passed": False, - "error": "HuggingFace model unavailable (skipped)", - "model_id": model_id, - } - try: - comp_results, _comp_params = compare_emulator_output(layer_build) - results_per_layer.append({ - "layer": current_layer, - "passed": False, - "allclose_match_rate": comp_results["allclose_match_rate"], - "max_error": comp_results["max_error"], - "mae": comp_results["mae"], - "mse": comp_results["mse"], - }) - except Exception: - results_per_layer.append({ - "layer": current_layer, - "passed": False, - "error": f"Emulator comparison failed after exit code {e.code}", - }) - else: - # Multi-layer: chain N layers with residual connections (no RoPE) - asm_name = f"aten_{model_id.split('/')[-1]}_chain{num_layers}" - chain_build = build_path / f"chain_{num_layers}layers" - - print(f"\n[2/5] Building chained {num_layers}-layer ISA via PlenaCompiler + ops.*") - print(f"[3/5] Setting up sim environment: {chain_build}") - print("[4/5] Running Rust transactional emulator") - - extra_kwargs = {} - if trust_remote_code: - extra_kwargs["trust_remote_code"] = True - if partial_load: - extra_kwargs["partial_load"] = True - - try: - build_and_run_multi_layer_test( - model_id=model_id, - asm_name=asm_name, - build_dir=chain_build, - num_layers=num_layers, - layer_idx_start=layer_idx, - seq_len=seq_len, - hidden_size=hidden_size, - inter_dim=inter_dim, - **extra_kwargs, - ) - comp_results, _comp_params = compare_emulator_output(chain_build) - results_per_layer.append({ - "layer": f"chain_{num_layers}", - "passed": True, - "allclose_match_rate": comp_results["allclose_match_rate"], - "max_error": comp_results["max_error"], - "mae": comp_results["mae"], - "mse": comp_results["mse"], - }) - except SystemExit as e: - if e.code == 0: - return { - "passed": False, - "error": "HuggingFace model unavailable (skipped)", - "model_id": model_id, - } - try: - comp_results, _comp_params = compare_emulator_output(chain_build) - results_per_layer.append({ - "layer": f"chain_{num_layers}", - "passed": False, - "allclose_match_rate": comp_results["allclose_match_rate"], - "max_error": comp_results["max_error"], - "mae": comp_results["mae"], - "mse": comp_results["mse"], - }) - except Exception: - results_per_layer.append({ - "layer": f"chain_{num_layers}", - "passed": False, - "error": f"Emulator comparison failed after exit code {e.code}", - }) - - elapsed = time.time() - t0 - - # ------------------------------------------------------------------ - # [5/5] Aggregate results - # ------------------------------------------------------------------ - print(f"\n[5/5] Results summary ({elapsed:.1f}s elapsed)") - all_passed = all(r.get("passed", False) for r in results_per_layer) - - # Use first layer's metrics for the top-level result - first = results_per_layer[0] if results_per_layer else {} - - summary = { - "passed": all_passed, - "allclose_match_rate": first.get("allclose_match_rate", 0.0), - "max_error": first.get("max_error", float("inf")), - "mae": first.get("mae", float("inf")), - "mse": first.get("mse", float("inf")), - "elapsed_s": elapsed, - "model_id": model_id, - "layer_idx": layer_idx, - "num_layers": num_layers, - "seq_len": seq_len, - "hidden_size": hidden_size, - "inter_dim": inter_dim, - "build_dir": str(build_path), - "layers": results_per_layer, - } - - for r in results_per_layer: - status = "PASS" if r.get("passed") else "FAIL" - match = r.get("allclose_match_rate", "N/A") - if isinstance(match, float): - match = f"{match:.2f}%" - print(f" Layer {r.get('layer', '?')}: [{status}] allclose={match}") - - if all_passed: - print(f"\n[ATen e2e PASSED] {model_id} — {num_layers} layer(s), " - f"allclose={first.get('allclose_match_rate', 0):.2f}%") - else: - print(f"\n[ATen e2e FAILED] {model_id} — see per-layer results above") - - return summary - - -# --------------------------------------------------------------------------- -# CLI entry point -# --------------------------------------------------------------------------- -def main(): - import argparse - - parser = argparse.ArgumentParser( - description="Run HF model through ATen compilation path (PlenaCompiler + ops.*)", - prog="python -m generator.aten_runner", - ) - parser.add_argument("model_id", help="HuggingFace model ID (e.g. AICrossSim/clm-60m)") - parser.add_argument("--seq-len", type=int, default=64, - help="Sequence length (default: 64)") - parser.add_argument("--num-layers", type=int, default=1, - help="Number of decoder layers to test (default: 1)") - parser.add_argument("--layer-idx", type=int, default=0, - help="Starting layer index (default: 0)") - parser.add_argument("--hidden-size", type=int, default=64, - help="Hidden dimension clipped to sim limits (default: 64)") - parser.add_argument("--inter-dim", type=int, default=128, - help="FFN intermediate dimension clipped to sim limits (default: 128)") - parser.add_argument("--build-dir", type=str, default=None, - help="Build directory for sim artifacts (default: /tmp/aten_e2e_...)") - parser.add_argument("--trust-remote-code", action="store_true", - help="Trust remote code for HF model loading") - parser.add_argument("--partial-load", action="store_true", - help="Load only needed weight shards (for large models)") - - args = parser.parse_args() - - result = run_aten_e2e( - model_id=args.model_id, - seq_len=args.seq_len, - num_layers=args.num_layers, - build_dir=args.build_dir, - layer_idx=args.layer_idx, - hidden_size=args.hidden_size, - inter_dim=args.inter_dim, - trust_remote_code=args.trust_remote_code, - partial_load=args.partial_load, - ) +from compiler.aten.e2e_runner import main, run_aten_e2e - sys.exit(0 if result["passed"] else 1) +__all__ = ["main", "run_aten_e2e"] if __name__ == "__main__": diff --git a/generator/runner.py b/generator/runner.py index 1dacbd1..e1c7c0e 100644 --- a/generator/runner.py +++ b/generator/runner.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -PLENA Generator Runner -- entry point for both compilation pipelines. +PLENA generator runner -- symbolic codegen and ATen e2e convenience entry. Modes: codegen -- Pipeline 2 (Generator): HF config -> symbolic graph -> @@ -12,7 +12,7 @@ Examples: python -m generator.runner codegen AICrossSim/clm-60m output.asm --seq-len 512 - python -m generator.runner aten AICrossSim/clm-60m --seq-len 32 --num-layers 1 + python -m compiler.aten.e2e_runner AICrossSim/clm-60m --seq-len 32 --num-layers 1 See docs/COMPILATION_PIPELINES.md for the full architecture overview. """ @@ -29,7 +29,7 @@ def _run_aten(args) -> int: """ATen-backed end-to-end: PlenaCompiler + ops.* -> emulator -> numerical check.""" - from generator.aten_runner import run_aten_e2e + from compiler.aten.e2e_runner import run_aten_e2e result = run_aten_e2e( model_id=args.model_path, diff --git a/generator/tests/test_generator_e2e.py b/generator/tests/test_generator_e2e.py index 897ceef..2ef75aa 100644 --- a/generator/tests/test_generator_e2e.py +++ b/generator/tests/test_generator_e2e.py @@ -518,7 +518,7 @@ def run_test_aten( numerical verification deferred, this immediately gets full numerical correctness via the mature ATen compilation backend. """ - from generator.aten_runner import run_aten_e2e + from compiler.aten.e2e_runner import run_aten_e2e print("=" * 80) print(f"Generator e2e harness (ATen backend) — {model_id} — "