diff --git a/asm_templates/__init__.py b/asm_templates/__init__.py index d6e77e8..8c5c02f 100644 --- a/asm_templates/__init__.py +++ b/asm_templates/__init__.py @@ -18,7 +18,7 @@ from .elementwise_add_asm import elementwise_add_asm from .embedding_asm import embedding_asm from .ffn_asm import ffn_asm, ffn_intermediate_asm, ffn_up_silu_asm -from .flash_attn_asm import flash_attn_asm +from .flashattn import flash_attn_asm from .gelu_asm import gelu_asm from .im2col_asm import im2col_asm from .im2col_asm_no_shift import im2col_asm_no_shift diff --git a/asm_templates/flash_attn_asm.py b/asm_templates/flash_attn_asm.py deleted file mode 100644 index 6d507bf..0000000 --- a/asm_templates/flash_attn_asm.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Flash Attention assembly code generation. - -This module has been refactored into the flashattn package. -This file re-exports all functions for backward compatibility. - -For new code, prefer importing from compiler.asm_templates.flashattn directly. -""" - -# Re-export all functions from the flashattn package for backward compatibility -from .flashattn import ( - computing_o_code as _computing_o_code, -) -from .flashattn import ( - computing_pv_code as _computing_pv_code, -) -from .flashattn import ( - computing_row_wise_scaling_code as _computing_row_wise_scaling_code, -) -from .flashattn import ( - flash_attn_asm, - qkt_multiply, -) -from .flashattn import ( - online_softmax_code as _online_softmax_code, -) -from .flashattn import ( - reset_fpsram_code as _reset_fpsram_code, -) -from .flashattn import ( - reset_kv_prefetch as _reset_kv_prefetch, -) -from .flashattn import ( - reset_vssram_code as _reset_vssram_code, -) - -# Also export IMM2_BOUND for backward compatibility -IMM2_BOUND = 2**18 - 1 - -__all__ = [ - "IMM2_BOUND", - "_computing_o_code", - "_computing_pv_code", - "_computing_row_wise_scaling_code", - "_online_softmax_code", - "_reset_fpsram_code", - "_reset_kv_prefetch", - "_reset_vssram_code", - "flash_attn_asm", - "qkt_multiply", -] diff --git a/asm_templates/flashattn/memory_layout.md b/asm_templates/flashattn/memory_layout.md deleted file mode 100644 index e69de29..0000000 diff --git a/asm_templates/load_int.py b/asm_templates/load_int.py deleted file mode 100644 index e69de29..0000000 diff --git a/asm_templates/tests/test_vram_sub_projection.py b/asm_templates/tests/test_vram_sub_projection.py index 197f178..912bee0 100644 --- a/asm_templates/tests/test_vram_sub_projection.py +++ b/asm_templates/tests/test_vram_sub_projection.py @@ -3,7 +3,7 @@ Verifies the extracted free function produces the expected ISA output for looped/unrolled and transposed/non-transposed variants, and asserts byte-identical parity with the delegating -``TileCompiler._vram_sub_projection_asm_impl`` method. +``IsaCompiler._vram_sub_projection_asm_impl`` method. """ import sys @@ -95,10 +95,10 @@ def test_unrolled_no_loops(self): self.assertIn("M_MM_WO", asm) def test_output_byte_identical_to_method(self): - """The free function must produce byte-identical output to TileCompiler's method.""" - from compiler.aten.plena_compiler import TileCompiler + """The free function must produce byte-identical output to IsaCompiler's method.""" + from compiler.aten.plena import IsaCompiler - tc = TileCompiler(mlen=64, blen=4, unroll_loops=False) + compiler = IsaCompiler(mlen=64, blen=4, unroll_loops=False) method_kwargs = dict( header_lines=["; header"], @@ -113,7 +113,7 @@ def test_output_byte_identical_to_method(self): caller_name="test", ) - method_out = tc._vram_sub_projection_asm_impl(**method_kwargs) + method_out = compiler._vram_sub_projection_asm_impl(**method_kwargs) free_out = vram_sub_projection_asm_impl( mlen=64, blen=4, diff --git a/asm_templates/vram_sub_projection_asm.py b/asm_templates/vram_sub_projection_asm.py index 5851f5e..ebc68b7 100644 --- a/asm_templates/vram_sub_projection_asm.py +++ b/asm_templates/vram_sub_projection_asm.py @@ -1,7 +1,7 @@ """Pure emitter for VRAM sub-projection ISA. -Shared implementation kernel used by ``TileCompiler.vram_sub_projection_asm`` -and ``TileCompiler.vram_sub_projection_T_asm``. The caller resolves all +Shared implementation kernel used by ``IsaCompiler.vram_sub_projection_asm`` +and ``IsaCompiler.vram_sub_projection_T_asm``. The caller resolves all instance-dependent state (register allocator, tile layouts, MRAM addresses, ``unroll_loops`` default) and passes it in as plain parameters, so this emitter can be unit-tested in isolation. diff --git a/aten/__init__.py b/aten/__init__.py index 19d6fd8..d237319 100644 --- a/aten/__init__.py +++ b/aten/__init__.py @@ -1,6 +1,6 @@ """compiler.aten — ATen-style PLENA compiler path. -PlenaCompiler DSL + op backend registry. Pairs with compiler.generator +PlenaCompiler program builder + op backend registry. Pairs with compiler.generator for the two-path compiler (template vs. aten). """ @@ -9,30 +9,22 @@ PLENA_PKG_DIR = Path(__file__).parent NATIVE_OPS_YAML = PLENA_PKG_DIR / "native_ops.yaml" -from compiler.aten.plena_compiler import ( # noqa: E402, F401 +from compiler.aten.isa_builder import ( # noqa: E402, F401 + Comment, + Instr, + IsaBuilder, + Register, + addr, + fp, + gp, +) +from compiler.aten.plena import ( # noqa: E402, F401 + FPVar, + InputVar, + IsaCompiler, + MemoryStateMixin, PlenaCompiler, - TileCompiler, - DeveloperCompiler, - RegisterAllocator, TensorVar, - InputVar, VRAMMatrixVar, - FPVar, - TensorKind, - tensor_kind, - Tensor, - TensorInfo, - TileLayout, - MemoryBlock, - VirtualMemoryManager, - MRAMAllocator, - VRAMAllocator, - FPRAMAllocator, - SubMatrixInfo, - MatrixBlockLayout, - VRAMSubMatrixInfo, - VRAMMatrixBlockLayout, - MemoryObjectInfo, - FPRAMObjectLayout, ) from compiler.aten.ops.registry import OpRegistry, Backend # noqa: E402, F401 diff --git a/aten/e2e_runner.py b/aten/e2e_runner.py new file mode 100644 index 0000000..70ffc48 --- /dev/null +++ b/aten/e2e_runner.py @@ -0,0 +1,323 @@ +"""ATen-backed end-to-end runner. + +This wraps the verified ATen compilation path: + + HuggingFace model -> PlenaCompiler + ops.* -> ISA -> emulator -> golden check + +The symbolic generator path is separate and remains under ``generator.runner +codegen``. + +Usage: + python -m compiler.aten.e2e_runner AICrossSim/clm-60m --seq-len 32 +""" + +import sys +import time +from pathlib import Path + +# --------------------------------------------------------------------------- +# Repo root bootstrap — mirror the same sys.path setup used by the existing +# test infrastructure so imports resolve regardless of cwd. +# --------------------------------------------------------------------------- +_COMPILER_ROOT = Path(__file__).resolve().parents[1] # PLENA_Compiler/ +_REPO_ROOT = _COMPILER_ROOT.parent +for _p in [str(_REPO_ROOT), str(_REPO_ROOT / "tools"), str(_COMPILER_ROOT)]: + if _p not in sys.path: + sys.path.insert(0, _p) + + +def run_aten_e2e( + model_id: str, + seq_len: int = 64, + num_layers: int = 1, + build_dir: str | None = None, + layer_idx: int = 0, + hidden_size: int = 64, + inter_dim: int = 128, + trust_remote_code: bool = False, + partial_load: bool = False, +) -> dict: + """Run a HF model through the ATen compilation path end-to-end. + + Steps: + 1. Load model config + layer weights from HuggingFace + 2. Build ISA via PlenaCompiler + ops.* (numerically verified path) + 3. Set up sim environment (ASM + HBM weights + FPRAM constants) + 4. Run Rust emulator + 5. Compare VRAM output against golden PyTorch reference + + Returns dict with: + passed: bool + allclose_match_rate: float (percentage) + max_error: float + mae: float + mse: float + elapsed_s: float (wall-clock seconds) + model_id: str + layer_idx: int + num_layers: int + seq_len: int + hidden_size: int + inter_dim: int + build_dir: str + """ + from transactional_emulator.testbench.emulator_runner import compare_emulator_output + from transactional_emulator.testbench.model_layer_test_builder import ( + build_and_run_decoder_test, + build_and_run_multi_layer_test, + get_model_dims, + slice_dims_for_sim, + ) + + t0 = time.time() + + # Resolve build directory + if build_dir is None: + safe_name = model_id.replace("/", "_") + build_dir = str( + Path("/tmp") / f"aten_e2e_{safe_name}_sl{seq_len}_l{layer_idx}" + ) + build_path = Path(build_dir) + + # ------------------------------------------------------------------ + # [1/5] Probe model config + # ------------------------------------------------------------------ + print(f"[1/5] Probing model config: {model_id}") + try: + full_dims = get_model_dims(model_id) + except (OSError, ConnectionError) as exc: + print(f"[SKIP] HuggingFace model '{model_id}' unavailable: {exc}") + return { + "passed": False, + "error": str(exc), + "model_id": model_id, + } + sim_dims = slice_dims_for_sim(full_dims, hidden_slice=hidden_size, inter_slice=inter_dim) + print(f" Full dims: hidden={full_dims.hidden_size}, inter={full_dims.inter_dim}, " + f"heads={full_dims.num_heads}, kv_heads={full_dims.num_kv_heads}, head_dim={full_dims.head_dim}") + print(f" Sim dims: hidden={sim_dims.hidden_size}, inter={sim_dims.inter_dim}") + + # ------------------------------------------------------------------ + # [2/5] Build ISA + golden reference + sim env via build_and_run_decoder_test + # + # We call the proven function directly — it handles: + # - Weight loading + slicing + # - PlenaCompiler ISA generation + # - create_sim_env + create_mem_for_sim + # - Golden reference computation + # - Emulator execution + comparison + # + # For multi-layer: iterate layers (each is independent at sim scale). + # ------------------------------------------------------------------ + results_per_layer = [] + + if num_layers == 1: + # Single layer: use proven single-layer path (with RoPE) + current_layer = layer_idx + asm_name = f"aten_{model_id.split('/')[-1]}_l{current_layer}" + layer_build = build_path / f"layer_{current_layer}" + + print(f"\n[2/5] Building ISA for layer {current_layer} via PlenaCompiler + ops.*") + print(f"[3/5] Setting up sim environment: {layer_build}") + print("[4/5] Running Rust transactional emulator") + + extra_kwargs = {} + if trust_remote_code: + extra_kwargs["trust_remote_code"] = True + if partial_load: + extra_kwargs["partial_load"] = True + + try: + build_and_run_decoder_test( + model_id=model_id, + asm_name=asm_name, + build_dir=layer_build, + layer_idx=current_layer, + seq_len=seq_len, + hidden_size=hidden_size, + inter_dim=inter_dim, + **extra_kwargs, + ) + comp_results, _comp_params = compare_emulator_output(layer_build) + results_per_layer.append({ + "layer": current_layer, + "passed": True, + "allclose_match_rate": comp_results["allclose_match_rate"], + "max_error": comp_results["max_error"], + "mae": comp_results["mae"], + "mse": comp_results["mse"], + }) + except SystemExit as e: + if e.code == 0: + return { + "passed": False, + "error": "HuggingFace model unavailable (skipped)", + "model_id": model_id, + } + try: + comp_results, _comp_params = compare_emulator_output(layer_build) + results_per_layer.append({ + "layer": current_layer, + "passed": False, + "allclose_match_rate": comp_results["allclose_match_rate"], + "max_error": comp_results["max_error"], + "mae": comp_results["mae"], + "mse": comp_results["mse"], + }) + except Exception: + results_per_layer.append({ + "layer": current_layer, + "passed": False, + "error": f"Emulator comparison failed after exit code {e.code}", + }) + else: + # Multi-layer: chain N layers with residual connections (no RoPE) + asm_name = f"aten_{model_id.split('/')[-1]}_chain{num_layers}" + chain_build = build_path / f"chain_{num_layers}layers" + + print(f"\n[2/5] Building chained {num_layers}-layer ISA via PlenaCompiler + ops.*") + print(f"[3/5] Setting up sim environment: {chain_build}") + print("[4/5] Running Rust transactional emulator") + + extra_kwargs = {} + if trust_remote_code: + extra_kwargs["trust_remote_code"] = True + if partial_load: + extra_kwargs["partial_load"] = True + + try: + build_and_run_multi_layer_test( + model_id=model_id, + asm_name=asm_name, + build_dir=chain_build, + num_layers=num_layers, + layer_idx_start=layer_idx, + seq_len=seq_len, + hidden_size=hidden_size, + inter_dim=inter_dim, + **extra_kwargs, + ) + comp_results, _comp_params = compare_emulator_output(chain_build) + results_per_layer.append({ + "layer": f"chain_{num_layers}", + "passed": True, + "allclose_match_rate": comp_results["allclose_match_rate"], + "max_error": comp_results["max_error"], + "mae": comp_results["mae"], + "mse": comp_results["mse"], + }) + except SystemExit as e: + if e.code == 0: + return { + "passed": False, + "error": "HuggingFace model unavailable (skipped)", + "model_id": model_id, + } + try: + comp_results, _comp_params = compare_emulator_output(chain_build) + results_per_layer.append({ + "layer": f"chain_{num_layers}", + "passed": False, + "allclose_match_rate": comp_results["allclose_match_rate"], + "max_error": comp_results["max_error"], + "mae": comp_results["mae"], + "mse": comp_results["mse"], + }) + except Exception: + results_per_layer.append({ + "layer": f"chain_{num_layers}", + "passed": False, + "error": f"Emulator comparison failed after exit code {e.code}", + }) + + elapsed = time.time() - t0 + + # ------------------------------------------------------------------ + # [5/5] Aggregate results + # ------------------------------------------------------------------ + print(f"\n[5/5] Results summary ({elapsed:.1f}s elapsed)") + all_passed = all(r.get("passed", False) for r in results_per_layer) + + # Use first layer's metrics for the top-level result + first = results_per_layer[0] if results_per_layer else {} + + summary = { + "passed": all_passed, + "allclose_match_rate": first.get("allclose_match_rate", 0.0), + "max_error": first.get("max_error", float("inf")), + "mae": first.get("mae", float("inf")), + "mse": first.get("mse", float("inf")), + "elapsed_s": elapsed, + "model_id": model_id, + "layer_idx": layer_idx, + "num_layers": num_layers, + "seq_len": seq_len, + "hidden_size": hidden_size, + "inter_dim": inter_dim, + "build_dir": str(build_path), + "layers": results_per_layer, + } + + for r in results_per_layer: + status = "PASS" if r.get("passed") else "FAIL" + match = r.get("allclose_match_rate", "N/A") + if isinstance(match, float): + match = f"{match:.2f}%" + print(f" Layer {r.get('layer', '?')}: [{status}] allclose={match}") + + if all_passed: + print(f"\n[ATen e2e PASSED] {model_id} — {num_layers} layer(s), " + f"allclose={first.get('allclose_match_rate', 0):.2f}%") + else: + print(f"\n[ATen e2e FAILED] {model_id} — see per-layer results above") + + return summary + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Run HF model through ATen compilation path (PlenaCompiler + ops.*)", + prog="python -m compiler.aten.e2e_runner", + ) + parser.add_argument("model_id", help="HuggingFace model ID (e.g. AICrossSim/clm-60m)") + parser.add_argument("--seq-len", type=int, default=64, + help="Sequence length (default: 64)") + parser.add_argument("--num-layers", type=int, default=1, + help="Number of decoder layers to test (default: 1)") + parser.add_argument("--layer-idx", type=int, default=0, + help="Starting layer index (default: 0)") + parser.add_argument("--hidden-size", type=int, default=64, + help="Hidden dimension clipped to sim limits (default: 64)") + parser.add_argument("--inter-dim", type=int, default=128, + help="FFN intermediate dimension clipped to sim limits (default: 128)") + parser.add_argument("--build-dir", type=str, default=None, + help="Build directory for sim artifacts (default: /tmp/aten_e2e_...)") + parser.add_argument("--trust-remote-code", action="store_true", + help="Trust remote code for HF model loading") + parser.add_argument("--partial-load", action="store_true", + help="Load only needed weight shards (for large models)") + + args = parser.parse_args() + + result = run_aten_e2e( + model_id=args.model_id, + seq_len=args.seq_len, + num_layers=args.num_layers, + build_dir=args.build_dir, + layer_idx=args.layer_idx, + hidden_size=args.hidden_size, + inter_dim=args.inter_dim, + trust_remote_code=args.trust_remote_code, + partial_load=args.partial_load, + ) + + sys.exit(0 if result["passed"] else 1) + + +if __name__ == "__main__": + main() diff --git a/aten/isa_builder.py b/aten/isa_builder.py new file mode 100644 index 0000000..8a84778 --- /dev/null +++ b/aten/isa_builder.py @@ -0,0 +1,139 @@ +"""Typed ISA builder for the ATen PLENA compiler path. + +This is intentionally small: it models the physical instruction stream and +prints the same assembly syntax the existing compiler emits. +""" + +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Protocol + + +class Renderable(Protocol): + def render(self) -> str: + """Render this object as assembly text.""" + + +@dataclass(frozen=True) +class Register: + prefix: str + index: int + + def render(self) -> str: + return f"{self.prefix}{self.index}" + + +def gp(index: int) -> Register: + return Register("gp", index) + + +def fp(index: int) -> Register: + return Register("f", index) + + +def addr(index: int) -> Register: + return Register("a", index) + + +AsmArg = str | int | Register + + +def render_arg(arg: AsmArg) -> str: + if isinstance(arg, Register): + return arg.render() + return str(arg) + + +@dataclass(frozen=True) +class Instr: + opcode: str + args: tuple[AsmArg, ...] = () + + def render(self) -> str: + if not self.args: + return self.opcode + return f"{self.opcode} {', '.join(render_arg(arg) for arg in self.args)}" + + +@dataclass(frozen=True) +class Comment: + text: str + + def render(self) -> str: + text = self.text.rstrip() + if text.startswith(";"): + return text + return f"; {text}" + + +AsmItem = str | Instr | Comment +IMM2_BOUND = 1 << 18 + + +@dataclass +class IsaBuilder: + items: list[AsmItem] = field(default_factory=list) + + def comment(self, text: str) -> IsaBuilder: + self.items.append(Comment(text)) + return self + + def instr(self, opcode: str, *args: AsmArg) -> IsaBuilder: + self.items.append(Instr(opcode, args)) + return self + + def raw(self, line: str) -> IsaBuilder: + self.items.append(line.rstrip("\n")) + return self + + def extend(self, items: Iterable[AsmItem]) -> IsaBuilder: + self.items.extend(items) + return self + + def render(self) -> str: + if not self.items: + return "" + return "\n".join(render_item(item) for item in legalize_large_immediates(self.items)) + "\n" + + +AsmInput = str | Renderable + + +def render_item(item: AsmItem) -> str: + if isinstance(item, str): + return item.rstrip("\n") + return item.render() + + +def render_asm(value: AsmInput) -> str: + if isinstance(value, str): + return value + return value.render() + + +def is_gp_zero(arg: AsmArg) -> bool: + return isinstance(arg, Register) and arg.prefix == "gp" and arg.index == 0 + + +def legalize_large_immediates(items: Iterable[AsmItem]) -> list[AsmItem]: + """Split typed absolute S_ADDI_INT loads that exceed the immediate field. + + This is the typed equivalent of plena_frontend._fix_large_immediates. + Raw string items are intentionally left alone until those call sites move + onto typed instructions. + """ + legalized: list[AsmItem] = [] + for item in items: + if isinstance(item, Instr) and item.opcode == "S_ADDI_INT" and len(item.args) == 3: + rd, rs, imm = item.args + if isinstance(rd, Register) and is_gp_zero(rs) and isinstance(imm, int) and imm >= IMM2_BOUND: + upper = imm >> 12 + lower = imm & 0xFFF + legalized.append(Instr("S_LUI_INT", (rd, upper))) + if lower: + legalized.append(Instr("S_ADDI_INT", (rd, rd, lower))) + continue + legalized.append(item) + return legalized diff --git a/aten/model_extract.py b/aten/model_extract.py new file mode 100644 index 0000000..25a07eb --- /dev/null +++ b/aten/model_extract.py @@ -0,0 +1,130 @@ +"""HuggingFace decoder model extraction helpers for the PLENA frontend.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + + +@dataclass(frozen=True) +class ModelConfig: + """Decoder dimensions needed by the PLENA ATen frontend.""" + + hidden_size: int + inter_dim: int + num_heads: int + num_kv_heads: int + head_dim: int + eps: float + rope_theta: float + vocab_size: int | None + model_type: str + + @property + def total_q_dim(self) -> int: + return self.num_heads * self.head_dim + + @property + def head_ratio(self) -> int: + return self.num_heads // self.num_kv_heads + + +@dataclass(frozen=True) +class LayerWeights: + """One decoder layer in PLENA's (in, out) linear-weight convention.""" + + w_q: torch.Tensor + w_o: torch.Tensor + w_k_heads: list[torch.Tensor] + w_v_heads: list[torch.Tensor] + w_gate: torch.Tensor + w_up: torch.Tensor + w_down: torch.Tensor + eps: float + + def tensor_entries(self, layer_idx: int) -> list[tuple[str, torch.Tensor]]: + entries = [ + (f"W_q_{layer_idx}", self.w_q), + (f"W_o_{layer_idx}", self.w_o), + ] + for kv_h, (w_k, w_v) in enumerate(zip(self.w_k_heads, self.w_v_heads, strict=True)): + entries.extend( + [ + (f"W_k_{layer_idx}_h{kv_h}", w_k), + (f"W_v_{layer_idx}_h{kv_h}", w_v), + ] + ) + entries.extend( + [ + (f"W_gate_{layer_idx}", self.w_gate), + (f"W_up_{layer_idx}", self.w_up), + (f"W_down_{layer_idx}", self.w_down), + ] + ) + return entries + + +def find_model_root(model: Any) -> Any: + """Find the transformer backbone (model.model or model.model.text_model).""" + for candidate in [ + getattr(model, "model", None), + getattr(getattr(model, "model", None), "text_model", None), + getattr(model, "language_model", getattr(model, "text_model", None)), + ]: + if candidate is not None and hasattr(candidate, "layers"): + return candidate + raise ValueError(f"Cannot find decoder layers on {type(model).__name__}") + + +def embedding_module(root: Any) -> Any | None: + """Return the token embedding module when the backbone exposes one.""" + return getattr(root, "embed_tokens", getattr(root, "wte", None)) + + +def extract_model_config(model: Any) -> ModelConfig: + """Extract decoder dimensions, resolving text_config for VLM wrappers.""" + config = getattr(model.config, "text_config", model.config) + hidden = config.hidden_size + num_heads = config.num_attention_heads + return ModelConfig( + hidden_size=hidden, + inter_dim=getattr(config, "intermediate_size", 4 * hidden), + num_heads=num_heads, + num_kv_heads=getattr(config, "num_key_value_heads", num_heads), + head_dim=hidden // num_heads, + eps=getattr(config, "rms_norm_eps", 1e-5), + rope_theta=getattr(config, "rope_theta", 10000.0), + vocab_size=getattr(config, "vocab_size", None), + model_type=getattr(config, "model_type", "unknown"), + ) + + +def extract_layer_weights(layer: Any, config: ModelConfig) -> LayerWeights: + """Extract one decoder layer in PLENA's (in, out) convention.""" + hidden = config.hidden_size + total_kv_dim = config.num_kv_heads * config.head_dim + w_k_full = _linear_weight(layer.self_attn.k_proj, hidden, total_kv_dim) + w_v_full = _linear_weight(layer.self_attn.v_proj, hidden, total_kv_dim) + norm = layer.input_layernorm + + return LayerWeights( + w_q=_linear_weight(layer.self_attn.q_proj, hidden, config.total_q_dim), + w_o=_linear_weight(layer.self_attn.o_proj, config.total_q_dim, hidden), + w_k_heads=_split_heads(w_k_full, config.head_dim, config.num_kv_heads), + w_v_heads=_split_heads(w_v_full, config.head_dim, config.num_kv_heads), + w_gate=_linear_weight(layer.mlp.gate_proj, hidden, config.inter_dim), + w_up=_linear_weight(layer.mlp.up_proj, hidden, config.inter_dim), + w_down=_linear_weight(layer.mlp.down_proj, config.inter_dim, hidden), + eps=getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-5)), + ) + + +def _linear_weight(module: Any, rows: int, cols: int) -> torch.Tensor: + """HF Linear stores (out, in); PLENA uses (in, out).""" + return module.weight.detach().T.contiguous()[:rows, :cols] + + +def _split_heads(weight: torch.Tensor, head_dim: int, num_heads: int) -> list[torch.Tensor]: + return [weight[:, h * head_dim:(h + 1) * head_dim].contiguous() for h in range(num_heads)] diff --git a/aten/native_ops.yaml b/aten/native_ops.yaml index 9cf8aaa..3a828d4 100644 --- a/aten/native_ops.yaml +++ b/aten/native_ops.yaml @@ -31,7 +31,7 @@ uses_mram: false doc: "Online softmax: row-wise numerically stable softmax (in-place on S tile)" -- func: "linear(Tensor input, Tensor weight) -> Tensor" +- func: "linear(Tensor input, Tensor weight, str name='linear_out') -> Tensor" category: primitive in_place: false dispatch: diff --git a/aten/ops/cpu/linear_ops.py b/aten/ops/cpu/linear_ops.py index 58c7cda..69cbe05 100644 --- a/aten/ops/cpu/linear_ops.py +++ b/aten/ops/cpu/linear_ops.py @@ -6,6 +6,8 @@ def linear_cpu( input: torch.Tensor, weight: torch.Tensor, + name: str = "linear_out", ) -> torch.Tensor: """CPU reference: input @ weight (float32 accumulation).""" + del name return torch.matmul(input.float(), weight.float()) diff --git a/aten/ops/cpu/norm_ops.py b/aten/ops/cpu/norm_ops.py index d633d71..8f0362b 100644 --- a/aten/ops/cpu/norm_ops.py +++ b/aten/ops/cpu/norm_ops.py @@ -6,8 +6,11 @@ def rms_norm_cpu( input: torch.Tensor, eps: float = 1e-6, + eps_offset: int = 1, + reci_hid_offset: int = 2, ) -> torch.Tensor: """CPU reference: RMS normalization.""" + del eps_offset, reci_hid_offset x = input.float() rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) return x / rms @@ -16,8 +19,11 @@ def rms_norm_cpu( def layer_norm_cpu( input: torch.Tensor, eps: float = 1e-6, + eps_offset: int = 1, + reci_hid_offset: int = 2, ) -> torch.Tensor: """CPU reference: Layer normalization (zero-mean, unit-variance per row).""" + del eps_offset, reci_hid_offset x = input.float() mean = x.mean(-1, keepdim=True) var = x.var(-1, keepdim=True, unbiased=False) diff --git a/aten/ops/plena/attention_ops.py b/aten/ops/plena/attention_ops.py index b62eb81..e1433ce 100644 --- a/aten/ops/plena/attention_ops.py +++ b/aten/ops/plena/attention_ops.py @@ -1,198 +1,5 @@ -"""PLENA backend implementation for Flash Attention operator.""" +"""PLENA backend wrapper for Flash Attention.""" def flash_attention_plena(prog, Q, K, V, scale=None, hq=1, hkv=1, h_qkv=None, causal_mask=None): - """PLENA backend: Flash Attention via PlenaCompiler. - - Dispatches to one of two codegen paths based on shape: - * MHA (hq == hkv == 1) — online-softmax loop using PlenaCompiler primitives - * GQA (hq // hkv > 1) — fused codegen via `flash_attn_asm` template, packs - `hq/hkv` Q heads into blen of M_BTMM (hardware GQA fusion, matches main) - - Args: - prog: PlenaCompiler instance - Q: VRAMMatrixVar — Q in VRAM, shape (seq_len, hq*h_qkv) - K: InputVar — K in HBM, shape (seq_len, hkv*h_qkv_padded) - V: InputVar — V in HBM, shape (seq_len, hkv*h_qkv_padded) - scale: Attention scale (default 1/sqrt(head_dim)) - hq, hkv, h_qkv: GQA params. Defaults treat input as single-head MHA. - causal_mask: VRAMMatrixVar or None — (mlen, mlen) mask with 0 on/below - diagonal and -inf above. Added to S before softmax. - - Returns: - VRAMMatrixVar for O, shape matching Q. - """ - - # Detect MHA vs GQA - if hq == 1 and hkv == 1: - return _flash_attention_mha(prog, Q, K, V, scale, causal_mask=causal_mask) - - # GQA: dispatch to fused codegen - if h_qkv is None: - raise ValueError("GQA mode requires h_qkv to be specified") - if causal_mask is not None: - raise NotImplementedError("causal_mask is not yet supported for GQA flash attention") - return _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv) - - -def _flash_attention_mha(prog, Q, K, V, scale, causal_mask=None): - """Single-head flash attention using PlenaCompiler primitives (online softmax). - - Args: - prog: PlenaCompiler instance - Q: VRAMMatrixVar — query matrix in VRAM - K: InputVar — key matrix in HBM - V: InputVar — value matrix in HBM - scale: Attention scale factor - causal_mask: VRAMMatrixVar or None — (mlen, mlen) mask with 0 on/below - diagonal and -inf above. When provided, added to S_block - after QK^T and before online softmax. - """ - import math - - seq_len, head_dim = Q.shape - mlen = prog.mlen - - if scale is None: - scale = 1.0 / math.sqrt(head_dim) - - num_q_blocks = seq_len // mlen - num_k_blocks = seq_len // mlen - - S_block = prog.alloc("S", mlen, mlen) - PV = prog.alloc("PV", mlen, head_dim) - O = prog.alloc("O", seq_len, head_dim) - - for q_idx in range(num_q_blocks): - prog.init_online_softmax(q_idx, O) - - for k_idx in range(num_k_blocks): - prog.vram_sub_projection_T_to( - Q, - q_idx, - K, - k_idx, - S_block, - target_row_idx=0, - target_col_idx=0, - ) - # Apply causal mask before softmax: S += mask (-inf above diagonal) - if causal_mask is not None: - prog.vram_add(S_block, causal_mask) - prog.online_softmax_block(S_block, scale) - prog.compute_pv(S_block, V, k_idx, PV, head_dim) - prog.scale_o_row(O, q_idx) - prog.vram_add(O, PV, dst_row_offset=q_idx * mlen) - - prog.final_scale_o(q_idx, O) - - return O - - -def _flash_attention_gqa_fused(prog, Q, K, V, scale, hq, hkv, h_qkv): - """GQA flash attention using fused M_BTMM codegen (main branch template). - - Packs `hq/hkv` Q heads into the `blen` systolic dimension — one M_BTMM - produces `ratio` head outputs in parallel. Requires: - (hq // hkv) == prog.blen, and (hq // hkv) * h_qkv == prog.mlen. - """ - import math - from compiler.asm_templates.flashattn import flash_attn_asm - from compiler.asm_templates import preload_addr_reg_asm - - ratio = hq // hkv - mlen = prog.mlen - blen = prog.blen - vlen = mlen - - if ratio != blen: - raise ValueError( - f"GQA ratio hq/hkv={ratio} must equal blen={blen} (hardware packs " - f"heads into blen). Use ratio=4 with default blen." - ) - if ratio * h_qkv != mlen: - raise ValueError( - f"GQA constraint: (hq/hkv)*h_qkv = {ratio * h_qkv} must equal " - f"mlen={mlen}. E.g. hq=4, hkv=1, h_qkv=16 with mlen=64." - ) - - s_q, _q_total_dim = Q.shape - s_kv, _k_total_dim = K.shape - - if scale is None: - scale = 1.0 / math.sqrt(h_qkv) - - # Allocate HBM addr registers for K and V (C_SET_ADDR_REG aN) - # Make sure K/V HBM subst-matrix registry knows about them - prog._ensure_hbm_sub_matrix_registered(K) - prog._ensure_hbm_sub_matrix_registered(V) - alloc = prog._compiler.register_allocator - k_addr, v_addr = alloc.allocate_addr(2) - gp_for_preload = alloc.allocate_gp(2) - setup = preload_addr_reg_asm( - addr_reg_to_set=[k_addr, v_addr], - available_registers=gp_for_preload, - addr_reg_val=[K.hbm_addr, V.hbm_addr], - ) - alloc.free_gp(gp_for_preload) - prog._compiler.generated_code += setup - - # Allocate VRAM buffers mirroring main's layout. - # S, PV each require mlen*mlen*ratio elements; O is s_q * (hq*h_qkv). - # We let PlenaCompiler's allocator handle placement — main's template uses - # vector_sram_base_address + computed offsets for S/PV/O, so they must be - # allocated contiguously starting right after Q. allocate_vram_matrix - # bump-allocates, which preserves that contiguity. - q_vram_base = prog._compiler.get_vram_addr(Q.name) - s_name = prog._scoped_name("_gqa_S") - pv_name = prog._scoped_name("_gqa_PV") - o_name = prog._scoped_name("O") - - # Express sizes as (rows, cols) that multiply to the required counts. - prog._compiler.allocate_vram_matrix(name=s_name, rows=mlen * ratio, cols=mlen, strict=False) - prog._compiler.allocate_vram_matrix(name=pv_name, rows=mlen * ratio, cols=mlen, strict=False) - prog._compiler.allocate_vram_matrix(name=o_name, rows=s_q, cols=hq * h_qkv, strict=False) - - # Reserve FPRAM for multi-head softmax state. main's flash_attn_asm - # assumes slots 0..2 hold constants (0.0, scale, -inf, preloaded by the - # test harness) and softmax state starts at fp_sram_start_address, with - # 3 triples (m_old, m_res, l_old) per head, strided 3*br apart. - # Reserve slots 0..2 first so our bump-allocated state lives at offset 3+. - br = min(mlen, s_q) - fp_allocs = prog._compiler.fpram_allocator - if "_gqa_fp_const_zero" not in fp_allocs.allocations: - fp_allocs.allocate(name="_gqa_fp_const_zero", size=1) - fp_allocs.allocate(name="_gqa_fp_const_scale", size=1) - fp_allocs.allocate(name="_gqa_fp_const_neg_inf", size=1) - fp_state_size = 3 * br * ratio - fp_start = prog._compiler.allocate_fpram(name="_gqa_softmax_state", size=fp_state_size) - - # Call main's fused GQA template - asm = flash_attn_asm( - mlen=mlen, - vlen=vlen, - blen=blen, - batch=1, - hq=hq, - hkv=hkv, - d=h_qkv, - q_len=s_q, - kv_len=s_kv, - alive_registers_int=list(range(1, 16)), - alive_registers_fp=list(range(1, 8)), - vector_sram_base_address=q_vram_base, - fp_sram_start_address=fp_start, - k_base_hbm_offset_reg=k_addr, - v_base_hbm_offset_reg=v_addr, - ) - prog._compiler.generated_code += asm - - # Release HBM addr regs (they're only needed during the call) - alloc.free_addr([k_addr, v_addr]) - - # Return O as a VRAMMatrixVar the caller can consume - from compiler.aten.plena_compiler import VRAMMatrixVar - - O = VRAMMatrixVar(prog, o_name, (s_q, hq * h_qkv), display_name="O") - prog._tensors[o_name] = O - return O + return prog.flash_attention(Q, K, V, scale=scale, hq=hq, hkv=hkv, h_qkv=h_qkv, causal_mask=causal_mask) diff --git a/aten/ops/plena/conv_ops.py b/aten/ops/plena/conv_ops.py index 57eb562..3fecc13 100644 --- a/aten/ops/plena/conv_ops.py +++ b/aten/ops/plena/conv_ops.py @@ -9,7 +9,7 @@ patches. Requires the emulator to support V_SHIFT_V (opcode 0x31, LSB-first right-shift per main.rs fix by George Wu, commit 24eb011). -After im2col the systolic matmul uses the standard linear_plena path. +After im2col the systolic matmul uses the compiler's standard linear path. HBM layout convention (caller must arrange data accordingly): input_raw shape = (C_in * H, W_padded) — each row is one spatial row of one channel @@ -20,9 +20,6 @@ With W_padded=64 and ow=0 (OW=1): offset = (c*H + oh+kr) * 64 → always aligned. """ -from compiler.aten.ops.plena.linear_ops import linear_plena - - _PREFETCH_V_AMOUNT = 4 # H_PREFETCH_V always loads this many VRAM rows @@ -114,12 +111,12 @@ def conv2d_plena( # Look up VRAM base addresses from the symbol table # ------------------------------------------------------------------ if use_shift: - mask_vec_vram_addr = prog._compiler.get_vram_addr(mask_mat.name) + mask_vec_vram_addr = prog.get_vram_addr(mask_mat.name) else: - basis_vram_base = prog._compiler.get_vram_addr(basis_mat.name) - scratch_vram_addr = prog._compiler.get_vram_addr(scratch_mat.name) - temp_vram_addr = prog._compiler.get_vram_addr(temp_mat.name) - output_vram_base = prog._compiler.get_vram_addr(output_mat.name) + basis_vram_base = prog.get_vram_addr(basis_mat.name) + scratch_vram_addr = prog.get_vram_addr(scratch_mat.name) + temp_vram_addr = prog.get_vram_addr(temp_mat.name) + output_vram_base = prog.get_vram_addr(output_mat.name) # ------------------------------------------------------------------ # GP register allocation @@ -146,7 +143,7 @@ def conv2d_plena( setup_lines.append(f"S_LUI_INT gp{setup_gp}, {hbm_base >> 12}") setup_lines.append(f"S_ADDI_INT gp{setup_gp}, gp{setup_gp}, {hbm_base & 0xFFF}") setup_lines.append(f"C_SET_ADDR_REG a{addr_reg_idx}, gp0, gp{setup_gp}") - prog._compiler.generated_code += "\n".join(setup_lines) + "\n" + prog.emit("\n".join(setup_lines) + "\n") # ------------------------------------------------------------------ # Emit: im2col assembly @@ -191,9 +188,9 @@ def conv2d_plena( fp_one_reg=fp_one_reg, # f1 = 1.0 by default (must be in fp_preload[fp_one_reg]) fp_ex_reg=2, # f2 = V_RED_SUM accumulator ) - prog._compiler.generated_code += asm_code + prog.emit(asm_code) # ------------------------------------------------------------------ # Systolic matmul: im2col_out @ weight_2d -> (M, C_out) # ------------------------------------------------------------------ - return linear_plena(prog, output_mat, weight_2d_var) + return prog.linear(output_mat, weight_2d_var) diff --git a/aten/ops/plena/embedding_ops.py b/aten/ops/plena/embedding_ops.py index 495f876..5497f88 100644 --- a/aten/ops/plena/embedding_ops.py +++ b/aten/ops/plena/embedding_ops.py @@ -1,23 +1,9 @@ -"""PLENA backend implementations for positional encoding operators.""" +"""PLENA backend wrappers for positional encoding operators.""" def embedding_add_plena(prog, input_var, pos_weight_var): - """PLENA backend: add learned position embeddings to input in-place. - - Both input_var and pos_weight_var must be VRAMMatrixVar with the same shape. - Uses prog.vram_add() which emits V_ADD_VV row-by-row. - """ - prog.vram_add(input_var, pos_weight_var) - return input_var + return prog.embedding_add(input_var, pos_weight_var) def rope_plena(prog, x_var, x_rot_var, cos_var, sin_var): - """PLENA backend: apply RoPE in-place. - - x_var is updated in-place: x = x * cos + rotate_half(x) * sin - - x_rot_var must be a VRAMMatrixVar holding rotate_half(x), preloaded from HBM - by the caller before dispatching this op. - """ - prog.rope(x_var, x_rot_var, cos_var, sin_var) - return x_var + return prog.rope(x_var, x_rot_var, cos_var, sin_var) diff --git a/aten/ops/plena/ffn_ops.py b/aten/ops/plena/ffn_ops.py index 89fe670..3fdcab5 100644 --- a/aten/ops/plena/ffn_ops.py +++ b/aten/ops/plena/ffn_ops.py @@ -1,61 +1,5 @@ -"""PLENA backend implementation for FFN operator.""" - -from compiler.asm_templates import ffn_asm, preload_addr_reg_asm, reset_reg_asm +"""PLENA backend wrapper for FFN.""" def ffn_plena(prog, input_var, w_gate, w_up, w_down): - """PLENA backend: FFN with SiLU gate via ffn_asm. - - Generates ISA code for: w_down @ (silu(w_gate @ x) * (w_up @ x)) - - Args: - prog: PlenaCompiler instance - input_var: BatchVar — activation in VRAM, shape (batch, hidden) - w_gate: InputVar — gate weight in HBM, shape (hidden, inter_dim) - w_up: InputVar — up-projection weight in HBM, shape (hidden, inter_dim) - w_down: InputVar — down-projection weight in HBM, shape (inter_dim, hidden) - - Returns: - VRAMMatrixVar for the FFN output (stored at activation_base_address in VRAM). - """ - batch_size, hidden_size = input_var.shape - _, inter_dim = w_up.shape - mlen = prog.mlen - blen = prog.blen - vlen = prog.mlen - - # Retrieve VRAM address of the loaded activation - activation_base_address = prog._compiler.get_vram_addr(input_var.name) - - # Set HBM address registers for each weight matrix - isa_code = preload_addr_reg_asm( - addr_reg_to_set=[1, 2, 3], - available_registers=[1, 2, 3], - addr_reg_val=[w_gate.hbm_addr, w_up.hbm_addr, w_down.hbm_addr], - ) - - # Reset registers before FFN kernel - isa_code += reset_reg_asm(alive_registers=[1, 2, 3]) - - # Generate FFN ISA kernel - isa_code += ffn_asm( - mlen=mlen, - vlen=vlen, - blen=blen, - batch=batch_size, - seq_len=1, - hidden_size=hidden_size, - intermediate_size=inter_dim, - alive_registers=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - gate_weight_hbm_offset_reg=1, - up_weight_hbm_offset_reg=2, - down_weight_hbm_offset_reg=3, - const_one_fp_address=5, - activation_base_address=activation_base_address, - use_loop_instructions=True, - ) - - prog._compiler.generated_code += isa_code - - # FFN result is written back to the activation area in VRAM (in-place overwrite) - return input_var + return prog.ffn(input_var, w_gate, w_up, w_down) diff --git a/aten/ops/plena/linear_ops.py b/aten/ops/plena/linear_ops.py index c94d292..7319e9f 100644 --- a/aten/ops/plena/linear_ops.py +++ b/aten/ops/plena/linear_ops.py @@ -1,106 +1,9 @@ -"""PLENA backend stubs for linear projection operators.""" +"""PLENA backend wrappers for linear operators.""" -def linear_plena(prog, input_var, weight_var): - """PLENA backend: linear projection via PlenaCompiler sub-matrix operations. +def linear_projection_plena(prog, input_var, weight_var, name: str = "linear_out"): + return prog.linear_projection(input_var, weight_var, name) - Supports M > mlen via row-block iteration and K_col > 4*mlen via K-split - partial sums accumulated in VRAM. - MRAM capacity: 4 tiles × mlen² = 4 × 4096 = 16384 elements (MAX_K_TILES=4). - When K tiles > MAX_K_TILES, we split into chunks and accumulate partial sums. - - output[r][c] = sum_k input[r][k] @ weight[k][c] - for r in 0..num_row_blocks-1, c in 0..num_col_blocks-1, - k split into chunks of at most MAX_K_TILES tiles. - """ - import math - - mlen = prog.mlen - MAX_K_TILES = 4 # MRAM capacity: 4 × mlen² elements - - rows, k_total = input_var.shape - _, out_features = weight_var.shape - num_row_blocks = math.ceil(rows / mlen) - num_col_blocks = out_features // mlen - num_k_tiles = math.ceil(k_total / mlen) - - # Allocate output matrix. When batch is not a multiple of mlen we pass - # strict=False so the allocator doesn't reject the shape; the hardware will - # operate on full mlen-wide tiles (HBM zero-pads unused rows) and only the - # first `rows` rows of the output contain valid results. - output_strict = rows % mlen == 0 - output = prog.alloc("linear_out", rows, out_features, strict=output_strict) - - if num_k_tiles <= MAX_K_TILES: - # Single pass: all K tiles fit in MRAM - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - output, - row_idx, - col_idx, - ) - else: - # K-split: chunk K tiles into groups of MAX_K_TILES, accumulate partial sums - k_chunks = [] - k_start = 0 - while k_start < num_k_tiles: - k_end = min(k_start + MAX_K_TILES, num_k_tiles) - k_chunks.append((k_start, k_end - k_start)) - k_start = k_end - - # Temp buffer for partial sums — only needs one (mlen, mlen) tile since - # each sub_projection_to writes a single tile before accumulating. - # Using the full output shape would cause VRAM overlap with the output - # when out_features > mlen (column-block-major layout). - temp = prog.alloc("linear_out_temp", mlen, mlen) - - for k_chunk_idx, (k_block_start, k_block_count) in enumerate(k_chunks): - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - if k_chunk_idx == 0: - # First chunk: write directly to output - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - output, - row_idx, - col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - else: - # Subsequent chunks: write to temp tile, then accumulate into output. - # temp is a single (mlen, mlen) tile to avoid VRAM overlap - # with the output buffer in column-block-major layout. - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - temp, - 0, - 0, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - prog.vram_block_add_to( - output, - row_idx, - col_idx, - temp, - 0, - 0, - output, - row_idx, - col_idx, - ) - - return output +def linear_plena(prog, input_var, weight_var, name: str = "linear_out"): + return prog.linear_projection(input_var, weight_var, name) diff --git a/aten/ops/plena/softmax_ops.py b/aten/ops/plena/softmax_ops.py index f7b99cd..9570bc1 100644 --- a/aten/ops/plena/softmax_ops.py +++ b/aten/ops/plena/softmax_ops.py @@ -50,7 +50,7 @@ def softmax_plena(prog, input_var, scale: float = 1.0): # Reserve fp_preload region (addresses 0-2): 0=0.0, 1=scale, 2=-inf # This ensures FPRAM variable allocation starts after these reserved slots - prog._compiler.fpram_allocator.next_free = 3 + prog.fpram_allocator.next_free = 3 # Allocate FPRAM variables scale_fp = prog.fp_var("scale_fp", size=1) # scale factor @@ -69,10 +69,9 @@ def softmax_plena(prog, input_var, scale: float = 1.0): prog.fpvar_fill_from_fpram(l_old, src_fpram_addr=0) # load 0.0 # Row-by-row online softmax - compiler = prog._compiler for row in range(mlen): # 1. Save old max for this row - compiler.fpvar_copy_asm(m_old.address + row, m_old_saved.address, 1) + prog.fpvar_copy_asm(m_old.address + row, m_old_saved.address, 1) # 2. Scale: S[row] *= scale prog.tile_row_mul_fp_broadcast(S, scale_fp.address, row) @@ -81,11 +80,11 @@ def softmax_plena(prog, input_var, scale: float = 1.0): prog.tile_row_max(row_max_tmp.address, S, row) # 4. Update running max: m_old[row] = max(m_old[row], row_max) - compiler.fpvar_max_asm(m_old.address + row, row_max_tmp.address, m_old.address + row, 1) + prog.fpvar_max_asm(m_old.address + row, row_max_tmp.address, m_old.address + row, 1) # 5. Decay factor: m_res[row] = exp(m_old_saved - m_curr) - compiler.fpvar_sub_asm(m_old_saved.address, m_old.address + row, m_old_saved.address, 1) - compiler.fpvar_exp_asm(m_old_saved.address, m_res.address + row, 1) + prog.fpvar_sub_asm(m_old_saved.address, m_old.address + row, m_old_saved.address, 1) + prog.fpvar_exp_asm(m_old_saved.address, m_res.address + row, 1) # 6. Subtract max: S[row] -= m_curr prog.tile_row_sub_fp(S, m_old.address + row, row) @@ -97,8 +96,8 @@ def softmax_plena(prog, input_var, scale: float = 1.0): prog.tile_row_sum(sum_p_tmp.address, S, row) # 9. Update accumulated sum: l_old[row] = l_old[row]*m_res[row] + sum_p - compiler.fpvar_mul_asm(l_old.address + row, m_res.address + row, l_old.address + row, 1) - compiler.fpvar_add_asm(l_old.address + row, sum_p_tmp.address, l_old.address + row, 1) + prog.fpvar_mul_asm(l_old.address + row, m_res.address + row, l_old.address + row, 1) + prog.fpvar_add_asm(l_old.address + row, sum_p_tmp.address, l_old.address + row, 1) # Final normalization: P[row] /= l_old[row] prog.fpvar_reci(l_old, inv_l) diff --git a/aten/plena/__init__.py b/aten/plena/__init__.py new file mode 100644 index 0000000..4d8604f --- /dev/null +++ b/aten/plena/__init__.py @@ -0,0 +1,20 @@ +"""Internal modules for the ATen PLENA compiler implementation.""" + +from compiler.aten.plena.compiler import PlenaCompiler +from compiler.aten.plena.constants import BLEN, IMM2_BOUND, MLEN +from compiler.aten.plena.isa_compiler import IsaCompiler +from compiler.aten.plena.memory_state import MemoryStateMixin +from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar + +__all__ = [ + "BLEN", + "IMM2_BOUND", + "MLEN", + "FPVar", + "InputVar", + "IsaCompiler", + "MemoryStateMixin", + "PlenaCompiler", + "TensorVar", + "VRAMMatrixVar", +] diff --git a/aten/plena/compiler.py b/aten/plena/compiler.py new file mode 100644 index 0000000..91e006b --- /dev/null +++ b/aten/plena/compiler.py @@ -0,0 +1,113 @@ +"""User-facing PLENA compiler program builder.""" + +from __future__ import annotations + +import os + +from compiler.aten.plena.isa_compiler import IsaCompiler +from compiler.aten.plena.program_attention import ProgramAttentionMixin +from compiler.aten.plena.program_fp_tile_ops import ProgramFPTileOpsMixin +from compiler.aten.plena.program_matrix_ops import ProgramMatrixOpsMixin +from compiler.aten.plena.program_tensors import ProgramTensorMixin +from compiler.aten.plena.vars import FPVar, InputVar, TensorVar + + +# ============================================================================ +# PlenaCompiler Main Class +# ============================================================================ + + +class PlenaCompiler( + ProgramTensorMixin, + ProgramFPTileOpsMixin, + ProgramMatrixOpsMixin, + ProgramAttentionMixin, + IsaCompiler, +): + """ + PLENA High-level Compiler Interface. + + Inherits the ISA-emission machinery from IsaCompiler and layers typed + program-builder helpers on top. Operations eagerly emit ISA text. + """ + + def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): + """ + Args: + mlen: Matrix tile size (default 64) + blen: Vector tile size (default 4) + real_data_ratio: HBM storage ratio (MXFP8 format = 1.125) + unroll_loops: If True, unroll sub-projection loops at ASM-gen time to + eliminate C_LOOP_START/END overhead. Overridden by the + ATEN_UNROLL env var ("1"=True, "0"=False). + """ + _env_unroll = os.environ.get("ATEN_UNROLL", "") + if _env_unroll == "1": + unroll_loops = True + elif _env_unroll == "0": + unroll_loops = False + super().__init__(mlen=mlen, blen=blen, real_data_ratio=real_data_ratio, unroll_loops=unroll_loops) + + # HBM address auto-allocation + self._next_hbm_addr: int = 0 + self._hbm_free_blocks: list[tuple[int, int]] = [] # (addr, size) + + # Variable registries + self._inputs: dict[str, InputVar] = {} + self._tensors: dict[str, TensorVar] = {} + self._fp_vars: dict[str, FPVar] = {} + self._registered_hbm_sub_matrices: dict[str, bool] = {} + self._registered_vram_sub_matrices: dict[str, bool] = {} + + # ======================================================================== + # Compilation + # ======================================================================== + + def compile(self) -> str: + """Get generated ISA code string.""" + return super().get_code() + + @property + def _compiler(self) -> PlenaCompiler: + """Compatibility alias for simulator testbench callers.""" + return self + + # ======================================================================== + # Utility Methods + # ======================================================================== + + def _scoped_name(self, name: str) -> str: + return name + + def _allocate_hbm(self, hbm_size: int) -> int: + """Allocate HBM range, preferring previously freed blocks.""" + best_idx = None + best_waste = None + for i, (addr, size) in enumerate(self._hbm_free_blocks): + if size >= hbm_size: + waste = size - hbm_size + if best_waste is None or waste < best_waste: + best_idx = i + best_waste = waste + + if best_idx is not None: + addr, block_size = self._hbm_free_blocks.pop(best_idx) + # Return excess fragment to free list + excess = block_size - hbm_size + if excess > 0: + self._hbm_free_blocks.append((addr + hbm_size, excess)) + return addr + + addr = self._next_hbm_addr + m = self.mlen + self._next_hbm_addr = ((addr + hbm_size + m - 1) // m) * m + return addr + + def _recycle_hbm(self, hbm_addr: int, hbm_size: int): + """Recycle an HBM range for future auto-allocation.""" + if hbm_size <= 0: + return + self._hbm_free_blocks.append((hbm_addr, hbm_size)) + + +__all__ = ["PlenaCompiler"] diff --git a/aten/plena/constants.py b/aten/plena/constants.py new file mode 100644 index 0000000..0c72ec5 --- /dev/null +++ b/aten/plena/constants.py @@ -0,0 +1,5 @@ +"""Shared constants for the ATen PLENA compiler path.""" + +MLEN = 64 # Minimum matrix block size +BLEN = 4 # Vector tile size +IMM2_BOUND = 2**18 diff --git a/aten/plena/isa_attention.py b/aten/plena/isa_attention.py new file mode 100644 index 0000000..455a639 --- /dev/null +++ b/aten/plena/isa_attention.py @@ -0,0 +1,531 @@ +"""Flash-attention ISA helpers for IsaCompiler.""" + +from __future__ import annotations + + +class IsaAttentionMixin: + # ========================================================================= + # Flash Attention Implementation + # ========================================================================= + + def _online_softmax_asm( + self, + mlen: int, + s_address: int, + m_start_address: int, + scale: float = 1.0, + ) -> str: + """ + Online Softmax Computation. + + Per row of S: + 1. m_curr = max(S[row], m_old) + 2. m_res = exp(m_old - m_curr) # used to update O downstream + 3. S'[row] = S[row] - m_curr + 4. P[row] = exp(S'[row]) + 5. l_new = l_old * m_res + sum(P[row]) + + FP SRAM layout (from m_start_address): + [0, mlen): m_old / m_curr + [mlen, 2*mlen): m_res = exp(m_old - m_curr) + [2*mlen, 3*mlen): l_old / l_new + """ + gp_regs = self.register_allocator.allocate_gp(4) + gp_s = gp_regs[0] + gp_m_addr = gp_regs[1] + gp_m_res_addr = gp_regs[2] + gp_l_addr = gp_regs[3] + + # Fixed FP register allocation for online softmax pipeline. + # These registers are shared across _online_softmax_asm, _scale_o_asm, + # and _final_scaling_asm — they MUST remain consistent across all three. + # WARNING: Do not use f1-f6 in any code that calls these methods. + fp_m_old = 1 # f1: m_old value + fp_m_res = 2 # f2: exp(m_old - m_curr) + fp_l_old = 3 # f3: l_old value + fp_sum_p = 4 # f4: sum(P) + fp_scale = 5 # f5: scale factor + fp_row_max = 6 # f6: current row max (temporary) + + lines = [] + lines.append("; === Online Softmax ===") + + # Set address registers + lines.append(f"S_ADDI_INT gp{gp_s}, gp0, {s_address}") + lines.append(f"S_ADDI_INT gp{gp_m_addr}, gp0, {m_start_address}") + lines.append(f"S_ADDI_INT gp{gp_m_res_addr}, gp{gp_m_addr}, {mlen}") + lines.append(f"S_ADDI_INT gp{gp_l_addr}, gp{gp_m_res_addr}, {mlen}") + + # scale factor is pre-loaded at FP SRAM addr 1 by the flash-attention driver. + if scale != 1.0: + lines.append(f"S_LD_FP f{fp_scale}, gp0, 1") + + for row in range(mlen): + lines.append(f"; Row {row}") + + lines.append(f"S_LD_FP f{fp_m_old}, gp{gp_m_addr}, {row}") + lines.append(f"S_ADD_FP f{fp_m_res}, f{fp_m_old}, f0") + + if scale != 1.0: + lines.append(f"V_MUL_VF gp{gp_s}, gp{gp_s}, f{fp_scale}, 0") + + lines.append(f"V_RED_MAX f{fp_row_max}, gp{gp_s}, 0") + + # m_curr = max(row_max, m_old) — online softmax must retain the running max. + lines.append(f"S_MAX_FP f{fp_m_old}, f{fp_row_max}, f{fp_m_old}") + + lines.append(f"S_SUB_FP f{fp_m_res}, f{fp_m_res}, f{fp_m_old}") + lines.append(f"S_EXP_FP f{fp_m_res}, f{fp_m_res}, 0") + + lines.append(f"S_ST_FP f{fp_m_res}, gp{gp_m_res_addr}, {row}") + lines.append(f"S_ST_FP f{fp_m_old}, gp{gp_m_addr}, {row}") + + lines.append(f"V_SUB_VF gp{gp_s}, gp{gp_s}, f{fp_m_old}, 0, 0") + lines.append(f"V_EXP_V gp{gp_s}, gp{gp_s}, 0, 0") + + lines.append(f"S_LD_FP f{fp_l_old}, gp{gp_l_addr}, {row}") + + lines.append(f"S_ADD_FP f{fp_sum_p}, f0, f0") + lines.append(f"V_RED_SUM f{fp_sum_p}, gp{gp_s}, 0, 0") + + lines.append(f"S_MUL_FP f{fp_l_old}, f{fp_l_old}, f{fp_m_res}") + lines.append(f"S_ADD_FP f{fp_l_old}, f{fp_l_old}, f{fp_sum_p}") + + lines.append(f"S_ST_FP f{fp_l_old}, gp{gp_l_addr}, {row}") + + lines.append(f"S_ADDI_INT gp{gp_s}, gp{gp_s}, {mlen}") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _pv_multiply_asm( + self, + mlen: int, + blen: int, + head_dim: int, + p_address: int, + v_hbm_offset_reg: int, + v_hbm_offset: int, + pv_address: int, + ) -> str: + """ + Compute PV = P @ V via M_MM. + + P: (mlen, mlen) in VRAM (softmax output) + V: (mlen, head_dim) in HBM (prefetched into MSRAM in mlen-wide column blocks) + PV: (mlen, head_dim) in VRAM + + M_MM computes one (blen, mlen) @ (mlen, blen) -> (blen, blen) in a single op + (K=mlen done in one shot). For head_dim > mlen, V is split into head_dim/mlen + column blocks; the outer loop iterates blocks, middle loop iterates blen-wide + V columns within a block, inner loop iterates blen-wide P rows. + """ + assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + + gp_regs = self.register_allocator.allocate_gp(5) + gp_p = gp_regs[0] + gp_v = gp_regs[1] + gp_pv = gp_regs[2] + gp_hbm = gp_regs[3] + gp_stride = gp_regs[4] + + num_v_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === PV Multiply (P @ V) using M_MM ===") + lines.append(f"; P: ({mlen}, {mlen}) @ V: ({mlen}, {head_dim}) -> PV: ({mlen}, {head_dim})") + lines.append("; M_MM: (blen, mlen) @ (mlen, blen) -> (blen, blen), K=mlen in one shot") + lines.append(f"; V split into {num_v_col_blocks} column blocks of width {mlen}") + lines.append("; Storage layout: (batch, mlen, hidden/mlen), column-block major") + + # STRIDE was set to mlen by the flash-attention driver — do not overwrite it here. + # M_MM_WO requires a nonzero stride reg (gp0=0 would be interpreted as stride=1). + # With column-block-major storage, consecutive rows within a column block are + # adjacent, so the writeback stride = 1. + lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") + + for v_col_block in range(num_v_col_blocks): + lines.append( + f"; --- V column block {v_col_block} (columns {v_col_block * mlen} to {(v_col_block + 1) * mlen - 1}) ---" + ) + + # Prefetch V[:, v_col_block*mlen:(v_col_block+1)*mlen] (mlen × mlen) to MSRAM. + # V is row-major in HBM: V[row, col] at offset row*head_dim + col, so the + # column-block base offset = v_hbm_offset + v_col_block * mlen (elements). + v_block_hbm_offset = v_hbm_offset + v_col_block * mlen + lines.append(f"S_ADDI_INT gp{gp_v}, gp0, 0") + lines.append(f"S_ADDI_INT gp{gp_hbm}, gp0, {v_block_hbm_offset}") + lines.append(f"H_PREFETCH_M gp{gp_v}, gp{gp_hbm}, a{v_hbm_offset_reg}, 1, 1") + + # mat_offset constraint: < mlen and a multiple of blen. + for v_col in range(mlen // blen): + lines.append(f"; V column {v_col_block * mlen + v_col * blen}") + + v_msram_offset = v_col * blen + lines.append(f"S_ADDI_INT gp{gp_v}, gp0, {v_msram_offset}") + + for p_row in range(mlen // blen): + p_row_addr = p_address + p_row * blen * mlen + lines.append(f"S_ADDI_INT gp{gp_p}, gp0, {p_row_addr}") + + lines.append(f"M_MM 0, gp{gp_v}, gp{gp_p}") + + # PV[row, col] addr = base + col_block * mlen * mlen + row * mlen + col_in_block + # with row = p_row * blen and col_in_block = v_col * blen. + pv_offset = v_col_block * mlen * mlen + p_row * blen * mlen + v_col * blen + lines.append(f"S_ADDI_INT gp{gp_pv}, gp0, {pv_address + pv_offset}") + lines.append(f"M_MM_WO gp{gp_pv}, gp{gp_stride}, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _scale_o_asm( + self, + mlen: int, + head_dim: int, + seq_len: int, + m_res_address: int, + o_address: int, + row_offset: int = 0, + ) -> str: + """Scale each row of O by m_res: O[row] *= m_res[row].""" + assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + + gp_regs = self.register_allocator.allocate_gp(2) + gp_m_res = gp_regs[0] + gp_o = gp_regs[1] + fp_m_res = 1 + + num_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === Scale O by m_res ===") + lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") + lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") + + lines.append(f"S_ADDI_INT gp{gp_m_res}, gp0, {m_res_address}") + + for row in range(mlen): + lines.append(f"S_LD_FP f{fp_m_res}, gp{gp_m_res}, {row}") + actual_row = row_offset + row + + for col_block in range(num_col_blocks): + o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen + lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") + lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_m_res}, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _add_pv_to_o_asm( + self, + mlen: int, + head_dim: int, + seq_len: int, + pv_address: int, + o_address: int, + row_offset: int = 0, + ) -> str: + """Accumulate PV into O: O[row] += PV[row].""" + assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + + gp_regs = self.register_allocator.allocate_gp(2) + gp_o = gp_regs[0] + gp_pv = gp_regs[1] + + num_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === Add PV to O ===") + lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") + lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") + + for row in range(mlen): + actual_row = row_offset + row + + for col_block in range(num_col_blocks): + o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen + pv_addr = pv_address + col_block * mlen * mlen + row * mlen + + lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") + lines.append(f"S_ADDI_INT gp{gp_pv}, gp0, {pv_addr}") + lines.append(f"V_ADD_VV gp{gp_o}, gp{gp_o}, gp{gp_pv}, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _final_scaling_asm( + self, + mlen: int, + head_dim: int, + seq_len: int, + l_address: int, + o_address: int, + row_offset: int = 0, + ) -> str: + """ + Final scaling: O[row] /= l[row]. + + V_MUL_VF processes mlen elements at a time; when head_dim > mlen, + each row is split into head_dim // mlen mlen-wide blocks. + """ + assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" + + gp_regs = self.register_allocator.allocate_gp(2) + gp_l = gp_regs[0] + gp_o = gp_regs[1] + fp_l = 1 + + num_col_blocks = head_dim // mlen + + lines = [] + lines.append("; === Final Scaling O = O / l ===") + lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") + lines.append("; Storage layout: (seq_len, mlen, head_dim/mlen), column-block major") + lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") + + lines.append(f"S_ADDI_INT gp{gp_l}, gp0, {l_address}") + + for row in range(mlen): + lines.append(f"S_LD_FP f{fp_l}, gp{gp_l}, {row}") + lines.append(f"S_RECI_FP f{fp_l}, f{fp_l}, 0") + actual_row = row_offset + row + + for col_block in range(num_col_blocks): + o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen + lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") + lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_l}, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _reset_fpsram_asm( + self, + start_address: int, + count: int, + value_address: int, # FP SRAM slot: 0 = zero, 2 = -inf + ) -> str: + """Reset a region of FP SRAM to the value at value_address.""" + gp_regs = self.register_allocator.allocate_gp(1) + gp_addr = gp_regs[0] + + lines = [] + lines.append(f"; Reset FP SRAM [{start_address}, {start_address + count})") + + lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {start_address}") + # Use f1 for FP scalar - FP registers don't go through GP allocator + lines.append(f"S_LD_FP f1, gp0, {value_address}") + + for i in range(count): + lines.append(f"S_ST_FP f1, gp{gp_addr}, {i}") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + def _reset_vram_asm( + self, + start_address: int, + rows: int, + cols: int, + total_rows: int, + mlen: int = 64, + row_offset: int = 0, + ) -> str: + """ + Reset a region of VRAM to zero. + + V_MUL_VF processes mlen elements at a time; when cols > mlen, each + row is split into cols // mlen mlen-wide blocks. + """ + gp_regs = self.register_allocator.allocate_gp(1) + gp_addr = gp_regs[0] + + num_col_blocks = (cols + mlen - 1) // mlen + + lines = [] + lines.append(f"; Reset VRAM rows [{row_offset}, {row_offset + rows}) of matrix at {start_address}") + lines.append(f"; {rows} rows x {cols} cols, {num_col_blocks} blocks per row") + lines.append("; Storage layout: (total_rows, mlen, cols/mlen), column-block major") + lines.append(f"; total_rows = {total_rows}, row_offset = {row_offset}") + + for row in range(rows): + actual_row = row_offset + row + for col_block in range(num_col_blocks): + addr = start_address + col_block * total_rows * mlen + actual_row * mlen + lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {addr}") + lines.append(f"V_MUL_VF gp{gp_addr}, gp{gp_addr}, f0, 0") + + self.register_allocator.free_gp(gp_regs) + return "\n".join(lines) + "\n" + + # ========================================================================= + # Expanded Flash Attention Operations + # ========================================================================= + + def init_online_softmax( + self, + q_idx: int, + o_matrix: str, + seq_len: int, + head_dim: int, + ) -> str: + """ + Initialize Online Softmax state for Q block q_idx: + m_old = -inf (FP SRAM), l = 0 (FP SRAM), O_row = 0 (VRAM). + """ + fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE + m_old_addr = fp_sram_start + l_addr = fp_sram_start + 2 * self.mlen # skip m_res region + + o_info = self[o_matrix] + o_vram_addr = o_info.vram_addr + row_offset = q_idx * self.mlen + + isa_code = f"; === Init Online Softmax for Q block {q_idx} ===\n" + + isa_code += self._reset_fpsram_asm(m_old_addr, self.mlen, 2) # slot 2 = -inf + isa_code += self._reset_fpsram_asm(l_addr, self.mlen, 0) # slot 0 = 0.0 + isa_code += self._reset_vram_asm( + start_address=o_vram_addr, + rows=self.mlen, + cols=head_dim, + total_rows=seq_len, + mlen=self.mlen, + row_offset=row_offset, + ) + + return self._emit(isa_code) + + def online_softmax_block( + self, + s_block_matrix: str, + scale: float, + ) -> str: + """ + Run Online Softmax on one S block. + Input: S_block (mlen × mlen) in VRAM + Output: P (mlen × mlen) in-place in VRAM + Updates: m_old, m_res, l in FP SRAM + ``scale`` is the QK^T scaling factor (typically 1/sqrt(d)). + """ + s_info = self[s_block_matrix] + s_address = s_info.vram_addr + + fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE + m_start_address = fp_sram_start + + isa_code = f"; === Online Softmax Block {s_block_matrix} ===\n" + isa_code += self._online_softmax_asm( + mlen=self.mlen, s_address=s_address, m_start_address=m_start_address, scale=scale + ) + + return self._emit(isa_code) + + def compute_pv( + self, + s_block_matrix: str, + v_sub_matrix: str, + k_idx: int, + pv_matrix: str, + head_dim: int, + ) -> str: + """ + Compute PV = P @ V[k_idx]. + + P lives in s_block_matrix (softmax result); V is prefetched from + HBM; PV is written to VRAM via pv_matrix. + """ + s_info = self[s_block_matrix] + p_address = s_info.vram_addr + + pv_info = self[pv_matrix] + pv_address = pv_info.vram_addr + + v_layout = self.get_hbm_layout(v_sub_matrix) + v_hbm_offset = k_idx * self.mlen * head_dim + + isa_code = f"; === Compute PV = P @ V[k_idx={k_idx}] ===\n" + + addr_regs = self.register_allocator.allocate_addr(1) + v_hbm_reg = addr_regs[0] + gp_regs = self.register_allocator.allocate_gp(2) + + from compiler.asm_templates import preload_addr_reg_asm + + isa_code += preload_addr_reg_asm( + addr_reg_to_set=[v_hbm_reg], available_registers=gp_regs, addr_reg_val=[v_layout.hbm_base_addr] + ) + + isa_code += self._pv_multiply_asm( + mlen=self.mlen, + blen=self.blen, + head_dim=head_dim, + p_address=p_address, + v_hbm_offset_reg=v_hbm_reg, + v_hbm_offset=v_hbm_offset, + pv_address=pv_address, + ) + + self.register_allocator.free_gp(gp_regs) + self.register_allocator.free_addr(addr_regs) + + return self._emit(isa_code) + + def scale_o_row( + self, + o_matrix: str, + q_idx: int, + seq_len: int, + head_dim: int, + ) -> str: + """Scale the current row block of O by m_res: O[q_idx] *= m_res.""" + o_info = self[o_matrix] + o_address = o_info.vram_addr + + fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE + m_res_addr = fp_sram_start + self.mlen + + row_offset = q_idx * self.mlen + + isa_code = f"; === Scale O[q_idx={q_idx}] by m_res ===\n" + isa_code += self._scale_o_asm( + mlen=self.mlen, + head_dim=head_dim, + seq_len=seq_len, + m_res_address=m_res_addr, + o_address=o_address, + row_offset=row_offset, + ) + + return self._emit(isa_code) + + def final_scale_o( + self, + q_idx: int, + o_matrix: str, + seq_len: int, + head_dim: int, + ) -> str: + """Final scaling: O[q_idx] /= l.""" + o_info = self[o_matrix] + o_address = o_info.vram_addr + + fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE + l_addr = fp_sram_start + 2 * self.mlen + + row_offset = q_idx * self.mlen + + isa_code = f"; === Final Scale O for Q block {q_idx} ===\n" + isa_code += self._final_scaling_asm( + mlen=self.mlen, + head_dim=head_dim, + seq_len=seq_len, + l_address=l_addr, + o_address=o_address, + row_offset=row_offset, + ) + + return self._emit(isa_code) + + +__all__ = ["IsaAttentionMixin"] diff --git a/aten/plena/isa_compiler.py b/aten/plena/isa_compiler.py new file mode 100644 index 0000000..6deb05b --- /dev/null +++ b/aten/plena/isa_compiler.py @@ -0,0 +1,424 @@ +"""Low-level ISA emission layer for the ATen PLENA compiler.""" + +from __future__ import annotations + +from compiler.asm_templates import ( + layer_norm_asm, + preload_act_asm, + preload_addr_reg_asm, + reset_reg_asm, + rms_norm_asm, + rope_asm, + store_act_asm, +) +from compiler.aten.plena.isa_attention import IsaAttentionMixin +from compiler.aten.plena.isa_emit import IsaEmitMixin +from compiler.aten.plena.isa_fp_ops import IsaFPOpsMixin +from compiler.aten.plena.isa_matrix import IsaMatrixMixin +from compiler.aten.plena.isa_tile_rows import IsaTileRowMixin +from compiler.aten.plena.memory_state import MemoryStateMixin +from compiler.aten.plena.registers import RegisterAllocator + + +class IsaCompiler( + IsaAttentionMixin, + IsaMatrixMixin, + IsaTileRowMixin, + IsaFPOpsMixin, + IsaEmitMixin, + MemoryStateMixin, +): + """ + ISA Compiler: lowers PLENA compiler operations to assembly text. + + Owns register allocation, generated assembly, and tiled memory metadata. + """ + + _ONLINE_SOFTMAX_FPSRAM_BASE = 10 + + def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): + # MemoryStateMixin.__init__ sets dimensions, layout tables, and memory allocators. + super().__init__(mlen=mlen, blen=blen, unroll_loops=unroll_loops) + self.real_data_ratio = real_data_ratio + self.register_allocator = RegisterAllocator() + self.generated_code = "" + + def load_batch( + self, + hbm_object_name: str, + vram_object_name: str, + vlen: int = 64, + preload_len: int = 4, + ) -> str: + """ + Load a Batch tensor from HBM to VRAM. + + HBM storage is MXFP (1 scale per 8 elements), so HBM actual size = + logical size * real_data_ratio = 1.125. VRAM stores only the vector + data (no scale), so VRAM size = logical size. + + Order (matters): allocate VRAM → register in symbol table → emit ISA. + """ + hbm_layout = self.get_hbm_layout(hbm_object_name) + h, w = hbm_layout.full_shape + hbm_addr = hbm_layout.hbm_base_addr + size = h * w + vram_base = self.vram_allocator.allocate(size, name=vram_object_name) + self.add_vram_object( + name=vram_object_name, + shape=(h, w), + vram_addr=vram_base, + dtype="fp16", + kind="Batch", + allocate_if_none=False, + strict=False, + ) + + addr_reg = self.register_allocator.allocate_addr(1)[0] + gp_regs_for_addr = self.register_allocator.allocate_gp(1) + + isa_code = f"; Load_Batch {hbm_object_name} -> {vram_object_name}\n" + isa_code += f"; HBM[{hbm_addr}] → VRAM[{vram_base}], shape=({h}, {w})\n" + + isa_code += preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], available_registers=gp_regs_for_addr, addr_reg_val=[hbm_addr] + ) + + # preload_act_asm requires 5 GP registers: [a_actual, stride, result, outer_loop, inner_loop]. + gp_regs_for_preload = self.register_allocator.allocate_gp(5) + isa_code += reset_reg_asm(alive_registers=gp_regs_for_preload) + + isa_code += preload_act_asm( + vlen=vlen, + preload_len=preload_len, + batch=h, + hidden_size=w, + alive_registers=gp_regs_for_preload, + act_vram_offset=vram_base, + activation_offset_reg=addr_reg, + stride_size=w, + ) + + self.register_allocator.free_gp(gp_regs_for_addr) + self.register_allocator.free_gp(gp_regs_for_preload) + self.register_allocator.free_addr([addr_reg]) + + return self._emit(isa_code) + + def store_to_hbm( + self, + tensor_name: str, + hbm_addr: int | None = None, + hbm_object_name: str | None = None, + hbm_addr_reg: int | None = None, + vlen: int = 64, + precision: int = 0, # 0 = Activation, 1 = KeyValue + store_amount: int = 4, # HBM_V_Writeback_Amount + ) -> str: + """ + Write tensor from VRAM back to HBM. + + Used to spill computed intermediates (e.g., K) from VRAM to HBM so + downstream ops (e.g., QK^T) can read them from HBM. Emits + ``store_act_asm`` for tensors of any supported size. + """ + if tensor_name not in self: + raise KeyError(f"Tensor '{tensor_name}' not found in symbol table") + + tensor_info = self[tensor_name] + + # Batch and VRAMMatrix share the same VRAM storage layout. + if tensor_info.kind not in ("Batch", "VRAMMatrix"): + raise ValueError( + f"Tensor '{tensor_name}' must be Batch or VRAMMatrix to store from VRAM, got {tensor_info.kind}" + ) + + if tensor_info.vram_addr is None: + raise ValueError(f"Tensor '{tensor_name}' has no VRAM address to store") + + if hbm_addr is None: + if tensor_info.hbm_addr >= 0: + hbm_addr = tensor_info.hbm_addr + else: + raise ValueError(f"Tensor '{tensor_name}' has no HBM address. Please specify hbm_addr.") + + batch_size = tensor_info.shape[0] + hidden_size = tensor_info.shape[1] + + isa_code = f"; Store {tensor_name} from VRAM to HBM\n" + isa_code += f"; VRAM[{tensor_info.vram_addr}] -> HBM[{hbm_addr}], shape=({batch_size}, {hidden_size})\n" + + gp_regs = self.register_allocator.allocate_gp(5) + + if hbm_addr_reg is None: + addr_regs = self.register_allocator.allocate_addr(1) + hbm_addr_reg = addr_regs[0] + need_free_addr = True + else: + addr_regs = [] + need_free_addr = False + + try: + gp_regs_for_addr = self.register_allocator.allocate_gp(2) + isa_code += preload_addr_reg_asm( + addr_reg_to_set=[hbm_addr_reg], available_registers=gp_regs_for_addr, addr_reg_val=[hbm_addr] + ) + self.register_allocator.free_gp(gp_regs_for_addr) + + isa_code += store_act_asm( + vlen=vlen, + batch=batch_size, + hidden_size=hidden_size, + alive_registers=gp_regs, + act_vram_offset=tensor_info.vram_addr, + hbm_addr_reg=hbm_addr_reg, + stride_size=hidden_size, + store_amount=store_amount, + ) + + if tensor_info.hbm_addr < 0 or tensor_info.hbm_addr != hbm_addr: + tensor_info.hbm_addr = hbm_addr + # HBM stores the MXFP-expanded size (logical size × real_data_ratio). + size = batch_size * hidden_size + tensor_info.hbm_size = int(size * self.real_data_ratio) + finally: + self.register_allocator.free_gp(gp_regs) + if need_free_addr: + self.register_allocator.free_addr(addr_regs) + + if hbm_object_name is not None: + self.add_hbm_object( + name=hbm_object_name, + hbm_addr=hbm_addr, + shape=(batch_size, hidden_size), + ) + + return self._emit(isa_code) + + def normalize( + self, + tensor_name: str, + mode: str = "rms", + eps_offset: int = 1, + reci_hid_offset: int = 2, + vlen: int | None = None, + scratchpad_vram_addr: int | None = None, + ) -> str: + """ + Normalize a VRAM tensor in-place. + + Supports: + - mode="rms": RMSNorm + - mode="layer": LayerNorm + + Args: + tensor_name: Tensor name in symbol table (must have VRAM address) + mode: "rms" or "layer" + eps_offset: FPRAM address of epsilon + reci_hid_offset: FPRAM address of 1/hidden_dim + vlen: vector length (default: self.mlen) + scratchpad_vram_addr: scratchpad VRAM address (default: auto-allocate temporary space) + """ + if tensor_name not in self: + raise KeyError(f"Tensor '{tensor_name}' not found in symbol table") + + tensor_info = self[tensor_name] + if tensor_info.vram_addr is None: + raise ValueError(f"Tensor '{tensor_name}' has no VRAM address") + + batch_size, hidden_dim = tensor_info.shape + if vlen is None: + vlen = self.mlen + if hidden_dim % vlen != 0: + raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by vlen ({vlen}) for normalization_asm") + + mode = mode.lower() + if mode not in ("rms", "layer"): + raise ValueError(f"Unsupported normalization mode: {mode}. Expected 'rms' or 'layer'.") + + gp_regs = self.register_allocator.allocate_gp(4) + + temp_scratchpad_name = None + if scratchpad_vram_addr is None: + temp_scratchpad_name = f"__norm_scratch__{tensor_name}__{len(self.generated_code)}" + scratchpad_vram_addr = self.vram_allocator.allocate(vlen, name=temp_scratchpad_name) + + try: + isa_code = f"; Normalize ({mode}) {tensor_name}, shape=({batch_size}, {hidden_dim})\n" + if mode == "rms": + isa_code += rms_norm_asm( + _eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + alive_registers=gp_regs, + activation_base_address=tensor_info.vram_addr, + scratchpad_base_address=scratchpad_vram_addr, + vlen=vlen, + batch_size=batch_size, + hidden_dim=hidden_dim, + ) + else: + isa_code += layer_norm_asm( + _eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + alive_registers=gp_regs, + activation_base_address=tensor_info.vram_addr, + scratchpad_base_address=scratchpad_vram_addr, + vlen=vlen, + batch_size=batch_size, + hidden_dim=hidden_dim, + ) + + return self._emit(isa_code) + finally: + # Always release allocated GP registers used by normalization template. + self.register_allocator.free_gp(gp_regs) + if temp_scratchpad_name is not None: + self.vram_allocator.free(temp_scratchpad_name, strict=False) + + def rope( + self, + x_name: str, + x_rot_name: str, + cos_name: str, + sin_name: str, + ) -> str: + """Apply RoPE in-place: x = x * cos + rotate_half(x) * sin + + All four tensors must already be in VRAM with the same shape (seq_len, head_dim). + x_rot must be preloaded by the caller as rotate_half(x). + """ + x_info = self[x_name] + xrot_info = self[x_rot_name] + cos_info = self[cos_name] + sin_info = self[sin_name] + + if x_info.vram_addr is None: + raise ValueError(f"Tensor '{x_name}' has no VRAM address") + + seq_len, head_dim = x_info.shape + vlen = self.mlen + + if head_dim % vlen != 0: + raise ValueError(f"head_dim ({head_dim}) must be divisible by vlen ({vlen}) for rope") + + gp_regs = self.register_allocator.allocate_gp(5) + + scratch_name = f"__rope_scratch__{x_name}__{len(self.generated_code)}" + scratch_addr = self.vram_allocator.allocate(vlen, name=scratch_name) + + try: + isa_code = rope_asm( + alive_registers=gp_regs, + x_base_address=x_info.vram_addr, + x_rot_base_address=xrot_info.vram_addr, + cos_base_address=cos_info.vram_addr, + sin_base_address=sin_info.vram_addr, + scratchpad_base_address=scratch_addr, + vlen=vlen, + seq_len=seq_len, + head_dim=head_dim, + ) + return self._emit(isa_code) + finally: + self.register_allocator.free_gp(gp_regs) + self.vram_allocator.free(scratch_name, strict=False) + + def get_code(self) -> str: + """Get all accumulated generated ISA code""" + return self.generated_code + + def reset(self): + """Reset compiler state (clear code, but retain symbol table)""" + self.generated_code = "" + self.register_allocator = RegisterAllocator() + # Call MemoryStateMixin.reset() explicitly since the merged class shadows it. + MemoryStateMixin.reset(self) + + def get_tensor_info(self, name: str): + """Get unified tensor/object info by name.""" + return self[name] + + def add_hbm_object( + self, + name: str, + hbm_addr: int, + shape: tuple[int, int], + real_data_ratio: float = 1.125, + ): + """Register an HBM object and build its HBM layout. + + Wraps the memory-layout ``add_hbm_object`` with a different positional + parameter order ``(name, hbm_addr, shape, ...)`` that all IsaCompiler + callers use. + """ + return MemoryStateMixin.add_hbm_object( + self, + name=name, + shape=shape, + hbm_addr=hbm_addr, + real_data_ratio=real_data_ratio, + ) + + def free_hbm_object(self, name: str, strict: bool = False): + """Free an HBM object by name (defaults to non-strict).""" + return MemoryStateMixin.free_hbm_object(self, name, strict=strict) + + def get_vram_addr(self, name: str) -> int: + """Get VRAM base address of an object.""" + info = self.get_tensor_info(name) + if info.vram_addr is None: + raise ValueError(f"Object '{name}' has no VRAM address") + return info.vram_addr + + def get_vram_tile_addr( + self, + name: str, + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> int: + """ + Get VRAM address of a specific tile (sub-block) in a VRAM matrix. + + Args: + name: matrix name + tile_row_idx: tile row index (0-based) + tile_col_idx: tile col index (0-based) + """ + self._ensure_vram_matrix_layout(name) + sub = self.get_vram_sub_block(name, tile_row_idx, tile_col_idx) + return sub.vram_addr + + def ensure_hbm_sub_matrix( + self, + name: str, + hbm_addr: int, + shape: tuple[int, int], + real_data_ratio: float = 1.125, + ): + """Ensure HBM matrix layout exists.""" + if name in self.hbm_matrices: + return + self.add_hbm_object( + name=name, + hbm_addr=hbm_addr, + shape=shape, + real_data_ratio=real_data_ratio, + ) + + def ensure_vram_matrix_layout(self, name: str, shape: tuple[int, int]): + """Ensure VRAM matrix layout exists for an already allocated VRAM object.""" + if name in self.vram_matrices: + return + vram_addr = self.get_vram_addr(name) + self.add_vram_object( + name=name, + shape=shape, + vram_addr=vram_addr, + allocate_if_none=False, + ) + + def free_vram_object(self, name: str, strict: bool = False): + """Free a VRAM object by name (defaults to non-strict).""" + return MemoryStateMixin.free_vram_object(self, name, strict=strict) + +__all__ = ["IsaCompiler"] diff --git a/aten/plena/isa_emit.py b/aten/plena/isa_emit.py new file mode 100644 index 0000000..adad6be --- /dev/null +++ b/aten/plena/isa_emit.py @@ -0,0 +1,75 @@ +"""Emission and FPRAM allocation helpers for IsaCompiler.""" + +from __future__ import annotations + +from compiler.aten.isa_builder import AsmInput, IsaBuilder, render_asm +from compiler.aten.plena.registers import RegisterAllocator + + +class IsaEmitMixin: + # ========================================================================= + # FP Register & FPRAM Management (inlined from former FPRAMCompiler). + # All state lives on self (register_allocator, fpram_allocator, etc.). + # ========================================================================= + + @property + def _reg(self) -> RegisterAllocator: + """Shorthand for self.register_allocator (used by FPVar ISA helpers).""" + return self.register_allocator + + @property + def _unroll(self) -> bool: + """Shorthand for self.unroll_loops.""" + return self.unroll_loops + + def _emit(self, isa_code: AsmInput) -> str: + """Append ISA text to the output buffer and return it.""" + rendered = render_asm(isa_code) + self.generated_code += rendered + return rendered + + def emit(self, isa_code: AsmInput) -> str: + """Public emission hook for code outside IsaCompiler internals.""" + return self._emit(isa_code) + + def emit_comment(self, text: str) -> str: + """Append one assembly comment line.""" + return self._emit(IsaBuilder().comment(text)) + + # ------------------------------------------------------------------ + # FP Register management + # ------------------------------------------------------------------ + + def allocate_fp_reg(self, count: int = 1) -> list[int]: + """Allocate FP registers (f0-f7).""" + return self._reg.allocate_fp(count) + + def free_fp_reg(self, registers: list[int]): + """Free FP registers.""" + self._reg.free_fp(registers) + + # ------------------------------------------------------------------ + # FPRAM address-space management + # ------------------------------------------------------------------ + + def allocate_fpram(self, name: str, size: int) -> int: + """Allocate FPRAM space, returns base address.""" + info = self.add_fpram_object(name=name, size=size) + if info.fpram_addr is None: + raise RuntimeError(f"Failed to allocate FPRAM for '{name}'") + return info.fpram_addr + + def free_fpram(self, name: str, strict: bool = True): + """Free FPRAM object by name.""" + return self.free_fpram_object(name, strict=strict) + + def get_fpram_addr(self, name: str) -> int: + """Get FPRAM base address from object name.""" + return self.get_fpram_layout(name).fpram_addr + + def get_fpram_size(self, name: str) -> int: + """Get FPRAM allocation size from object name.""" + return self.get_fpram_layout(name).size + + +__all__ = ["IsaEmitMixin"] diff --git a/aten/plena/isa_fp_ops.py b/aten/plena/isa_fp_ops.py new file mode 100644 index 0000000..dae941f --- /dev/null +++ b/aten/plena/isa_fp_ops.py @@ -0,0 +1,318 @@ +"""FPVar and FPRAM ISA helpers for IsaCompiler.""" + +from __future__ import annotations + +from compiler.aten.isa_builder import IsaBuilder, fp, gp + + +class IsaFPOpsMixin: + # ========================================================================= + # FPVar ISA helpers (address-based) + # ========================================================================= + + def _emit_fpvar_skip(self, op_name: str, count: int) -> str: + return self._emit(IsaBuilder().comment(f"FPVar {op_name} skipped: count={count}")) + + def _fpvar_unary_asm( + self, + op_name: str, + op_description: str, + opcode: str, + src_addr: int, + dst_addr: int, + count: int, + ) -> str: + if count <= 0: + return self._emit_fpvar_skip(op_name, count) + + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"FPVar {op_name}: {op_description}, count={count}") + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + + if self._unroll: + for i in range(count): + asm.instr("S_LD_FP", fp(1), gp(gp_src), i) + asm.instr(opcode, fp(1), fp(1), 0) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_LD_FP", fp(1), gp(gp_src), 0) + asm.instr(opcode, fp(1), fp(1), 0) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), 1) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def _fpvar_binary_asm( + self, + op_name: str, + op_description: str, + opcode: str, + src1_addr: int, + src2_addr: int, + dst_addr: int, + count: int, + ) -> str: + if count <= 0: + return self._emit_fpvar_skip(op_name, count) + + gp_regs = self._reg.allocate_gp(4) + gp_a, gp_b, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"FPVar {op_name}: {op_description}, count={count}") + asm.instr("S_ADDI_INT", gp(gp_a), gp(0), src1_addr) + asm.instr("S_ADDI_INT", gp(gp_b), gp(0), src2_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + + if self._unroll: + for i in range(count): + asm.instr("S_LD_FP", fp(1), gp(gp_a), i) + asm.instr("S_LD_FP", fp(2), gp(gp_b), i) + asm.instr(opcode, fp(1), fp(1), fp(2)) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_LD_FP", fp(1), gp(gp_a), 0) + asm.instr("S_LD_FP", fp(2), gp(gp_b), 0) + asm.instr(opcode, fp(1), fp(1), fp(2)) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_a), gp(gp_a), 1) + asm.instr("S_ADDI_INT", gp(gp_b), gp(gp_b), 1) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def fpvar_copy_asm(self, src_addr: int, dst_addr: int, count: int) -> str: + if count <= 0: + return self._emit_fpvar_skip("Copy", count) + + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment( + f"FPVar Copy: FPRAM[{dst_addr}:{dst_addr + count}] = FPRAM[{src_addr}:{src_addr + count}]" + ) + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + + if self._unroll: + for i in range(count): + asm.instr("S_LD_FP", fp(1), gp(gp_src), i) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_LD_FP", fp(1), gp(gp_src), 0) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), 1) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def fpvar_fill_from_fpram_asm(self, dst_addr: int, src_fpram_addr: int, count: int) -> str: + if count <= 0: + return self._emit_fpvar_skip("Fill", count) + + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"FPVar Fill: FPRAM[{dst_addr}:{dst_addr + count}] = FPRAM[{src_fpram_addr}]") + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_fpram_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + asm.instr("S_LD_FP", fp(1), gp(gp_src), 0) + + if self._unroll: + for i in range(count): + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def fpvar_reci_asm(self, src_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_unary_asm("Reci", "dst = 1/src", "S_RECI_FP", src_addr, dst_addr, count) + + def fpvar_exp_asm(self, src_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_unary_asm("Exp", "dst = exp(src)", "S_EXP_FP", src_addr, dst_addr, count) + + def fpvar_add_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_binary_asm("Add", "dst = src1 + src2", "S_ADD_FP", src1_addr, src2_addr, dst_addr, count) + + def fpvar_sub_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_binary_asm("Sub", "dst = src1 - src2", "S_SUB_FP", src1_addr, src2_addr, dst_addr, count) + + def fpvar_mul_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_binary_asm("Mul", "dst = src1 * src2", "S_MUL_FP", src1_addr, src2_addr, dst_addr, count) + + def fpvar_max_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: + return self._fpvar_binary_asm("Max", "dst = max(src1, src2)", "S_MAX_FP", src1_addr, src2_addr, dst_addr, count) + + def fpvar_sum_asm(self, src_addr: int, dst_addr: int, count: int) -> str: + if count <= 0: + return self._emit_fpvar_skip("Sum", count) + + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"FPVar Sum: FPRAM[{dst_addr}] = sum(FPRAM[{src_addr}:{src_addr + count}])") + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + asm.instr("S_ADD_FP", fp(1), fp(0), fp(0)) + + if self._unroll: + for i in range(count): + asm.instr("S_LD_FP", fp(2), gp(gp_src), i) + asm.instr("S_ADD_FP", fp(1), fp(1), fp(2)) + else: + asm.instr("C_LOOP_START", gp(gp_loop), count) + asm.instr("S_LD_FP", fp(2), gp(gp_src), 0) + asm.instr("S_ADD_FP", fp(1), fp(1), fp(2)) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), 1) + asm.instr("C_LOOP_END", gp(gp_loop)) + + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def fpvar_shift_asm( + self, + src_addr: int, + dst_addr: int, + count: int, + shift: int, + fill_fpram_addr: int = 0, + ) -> str: + """ + Shift FPVar into dst. + - shift > 0: right shift (leading positions filled) + - shift < 0: left shift (trailing positions filled) + """ + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_fill = gp_regs + try: + asm = IsaBuilder().comment( + f"FPVar Shift: dst=shift(src, shift={shift}), count={count}, fill=FPRAM[{fill_fpram_addr}]" + ) + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr) + asm.instr("S_ADDI_INT", gp(gp_fill), gp(0), fill_fpram_addr) + asm.instr("S_LD_FP", fp(3), gp(gp_fill), 0) + + for i in range(count): + src_idx = i - shift + if 0 <= src_idx < count: + asm.instr("S_LD_FP", fp(1), gp(gp_src), src_idx) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), i) + else: + asm.instr("S_ST_FP", fp(3), gp(gp_dst), i) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + # ========================================================================= + # FPVar helpers (name-based wrappers over the address-based ISA generators) + # ========================================================================= + + def _fpram_unary(self, asm_method: str, src_name: str, dst_name: str, count: int | None = None) -> str: + count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) if count is None else count + return getattr(self, asm_method)(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) + + def _fpram_binary( + self, + asm_method: str, + src1_name: str, + src2_name: str, + dst_name: str, + count: int | None = None, + ) -> str: + count = ( + min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) + if count is None + else count + ) + return getattr(self, asm_method)( + self.get_fpram_addr(src1_name), + self.get_fpram_addr(src2_name), + self.get_fpram_addr(dst_name), + count, + ) + + def fpram_copy(self, src_name: str, dst_name: str, count: int | None = None) -> str: + return self._fpram_unary("fpvar_copy_asm", src_name, dst_name, count) + + def fpram_reci(self, src_name: str, dst_name: str, count: int | None = None) -> str: + return self._fpram_unary("fpvar_reci_asm", src_name, dst_name, count) + + def fpram_exp(self, src_name: str, dst_name: str, count: int | None = None) -> str: + return self._fpram_unary("fpvar_exp_asm", src_name, dst_name, count) + + def fpram_add(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: + return self._fpram_binary("fpvar_add_asm", src1_name, src2_name, dst_name, count) + + def fpram_sub(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: + return self._fpram_binary("fpvar_sub_asm", src1_name, src2_name, dst_name, count) + + def fpram_mul(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: + return self._fpram_binary("fpvar_mul_asm", src1_name, src2_name, dst_name, count) + + def fpram_max(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: + return self._fpram_binary("fpvar_max_asm", src1_name, src2_name, dst_name, count) + + def fpram_sum(self, src_name: str, dst_name: str, count: int | None = None) -> str: + if count is None: + count = self.get_fpram_size(src_name) + return self.fpvar_sum_asm( + self.get_fpram_addr(src_name), + self.get_fpram_addr(dst_name), + count, + ) + + def fpram_shift( + self, + src_name: str, + dst_name: str, + shift: int, + count: int | None = None, + fill_fpram_name: str | None = None, + ) -> str: + if count is None: + count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) + fill_addr = 0 if fill_fpram_name is None else self.get_fpram_addr(fill_fpram_name) + return self.fpvar_shift_asm( + src_addr=self.get_fpram_addr(src_name), + dst_addr=self.get_fpram_addr(dst_name), + count=count, + shift=shift, + fill_fpram_addr=fill_addr, + ) + + def fpram_fill_from_fpram(self, dst_name: str, src_fpram_addr: int, count: int | None = None) -> str: + if count is None: + count = self.get_fpram_size(dst_name) + return self.fpvar_fill_from_fpram_asm( + dst_addr=self.get_fpram_addr(dst_name), + src_fpram_addr=src_fpram_addr, + count=count, + ) + + +__all__ = ["IsaFPOpsMixin"] diff --git a/aten/plena/isa_matrix.py b/aten/plena/isa_matrix.py new file mode 100644 index 0000000..8056552 --- /dev/null +++ b/aten/plena/isa_matrix.py @@ -0,0 +1,778 @@ +"""Matrix movement and VRAM projection helpers for IsaCompiler.""" + +from __future__ import annotations + +from compiler.asm_templates import preload_addr_reg_asm +from compiler.asm_templates.vram_sub_projection_asm import vram_sub_projection_asm_impl +from compiler.aten.isa_builder import IsaBuilder, addr as areg, gp + + +class IsaMatrixMixin: + def _emit_hbm_matrix_load(self, layout, gp_count: int, build_body) -> str: + gp_regs = self.register_allocator.allocate_gp(gp_count) + gp_for_addr = self.register_allocator.allocate_gp(2) + addr_reg = self.register_allocator.allocate_addr(1)[0] + try: + isa_code = preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], + available_registers=gp_for_addr, + addr_reg_val=[layout.hbm_base_addr], + ) + isa_code += build_body(addr_reg, gp_regs) + return self._emit(isa_code) + finally: + self.register_allocator.free_gp(gp_regs) + self.register_allocator.free_gp(gp_for_addr) + self.register_allocator.free_addr([addr_reg]) + + def reset_mram(self) -> str: + """ + Reset MRAM allocator, free all allocated space + Used in scenarios where sub-blocks need to be reloaded within a for loop + """ + self.mram_allocator.reset() + self.clear_mram_bindings() + + return self._emit(IsaBuilder().comment("=== Reset MRAM ===")) + + def _default_hbm_gp_regs(self, gp_regs: list[int] | None) -> list[int]: + return [1, 2, 3] if gp_regs is None else gp_regs + + def _emit_hbm_prefetch_setup(self, asm: IsaBuilder, layout, gp_scale: int, gp_stride: int) -> None: + rows, cols = layout.full_shape + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), rows * cols) + asm.instr("C_SET_SCALE_REG", gp(gp_scale)) + asm.instr("S_ADDI_INT", gp(gp_stride), gp(0), cols) + asm.instr("C_SET_STRIDE_REG", gp(gp_stride)) + + def _emit_hbm_subblock_prefetch( + self, + asm: IsaBuilder, + layout, + row_idx: int, + col_idx: int, + mram_addr: int, + hbm_addr_reg: int, + gp_scale: int, + gp_mram: int, + comment: str | None = None, + ) -> None: + sub_block = layout.get_sub_block(row_idx, col_idx) + hbm_offset = sub_block.hbm_offset + sub_block.mram_addr = mram_addr + + asm.comment(comment if comment is not None else f"SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") + asm.instr("S_ADDI_INT", gp(gp_mram), gp(0), mram_addr) + asm.instr("S_ADDI_INT", gp(gp_scale), gp(0), hbm_offset) + asm.instr("H_PREFETCH_M", gp(gp_mram), gp(gp_scale), areg(hbm_addr_reg), 1, 0) + + def _emit_hbm_subblock_sequence( + self, + asm: IsaBuilder, + layout, + block_coords, + mram_start_addr: int, + hbm_addr_reg: int, + gp_scale: int, + gp_mram: int, + ) -> None: + mram_addr = mram_start_addr + block_size = self.mlen * self.mlen + for row_idx, col_idx in block_coords: + self._emit_hbm_subblock_prefetch( + asm, + layout, + row_idx, + col_idx, + mram_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + ) + mram_addr += block_size + + def load_sub_matrix_asm( + self, + name: str, + row_idx: int, + col_idx: int, + mram_dest_addr: int, + hbm_addr_reg: int = 1, + gp_regs: list[int] | None = None, + ) -> str: + """Emit HBM->MRAM prefetch for one mlen x mlen sub-block.""" + gp_regs = self._default_hbm_gp_regs(gp_regs) + layout = self.hbm_matrices[name] + + asm = IsaBuilder() + asm.comment(f"Load SubMatrix {name}[{row_idx}][{col_idx}] -> MRAM[{mram_dest_addr}]") + gp_scale = gp_regs[0] + gp_stride = gp_regs[1] + gp_mram = gp_regs[2] + self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) + hbm_offset = layout.get_sub_block(row_idx, col_idx).hbm_offset + self._emit_hbm_subblock_prefetch( + asm, + layout, + row_idx, + col_idx, + mram_dest_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + comment=f"HBM offset: {hbm_offset} (precomputed)", + ) + + return asm.render() + + def load_row_sub_matrices_asm( + self, + name: str, + row_idx: int, + mram_start_addr: int, + hbm_addr_reg: int = 1, + gp_regs: list[int] | None = None, + ) -> str: + """Emit HBM->MRAM prefetches for one block row.""" + gp_regs = self._default_hbm_gp_regs(gp_regs) + layout = self.hbm_matrices[name] + + asm = IsaBuilder() + asm.comment(f"Load SubMatrix Row {name}[{row_idx}][:] -> MRAM[{mram_start_addr}]") + + gp_scale = gp_regs[0] + gp_stride = gp_regs[1] + gp_mram = gp_regs[2] + self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) + + self._emit_hbm_subblock_sequence( + asm, + layout, + ((row_idx, col_idx) for col_idx in range(layout.num_col_blocks)), + mram_start_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + ) + + return asm.render() + + def load_col_sub_matrices_asm( + self, + name: str, + col_idx: int, + mram_start_addr: int, + hbm_addr_reg: int = 1, + gp_regs: list[int] | None = None, + k_block_start: int = 0, + k_block_count: int | None = None, + ) -> str: + """Emit HBM->MRAM prefetches for one block column or K-split slice.""" + gp_regs = self._default_hbm_gp_regs(gp_regs) + layout = self.hbm_matrices[name] + num_row_blocks = layout.num_row_blocks + + asm = IsaBuilder() + asm.comment(f"Load SubMatrix Col {name}[:][{col_idx}] -> MRAM[{mram_start_addr}]") + + gp_scale = gp_regs[0] + gp_stride = gp_regs[1] + gp_mram = gp_regs[2] + self._emit_hbm_prefetch_setup(asm, layout, gp_scale, gp_stride) + + effective_count = k_block_count if k_block_count is not None else num_row_blocks + self._emit_hbm_subblock_sequence( + asm, + layout, + ((row_idx, col_idx) for row_idx in range(k_block_start, k_block_start + effective_count)), + mram_start_addr, + hbm_addr_reg, + gp_scale, + gp_mram, + ) + + return asm.render() + + def _default_projection_gp_regs(self, gp_regs: list[int] | None) -> list[int]: + return [1, 2, 3, 4, 5, 6, 7, 8, 9] if gp_regs is None else gp_regs + + def _projection_context(self, vram_mat_name: str, vram_row_idx: int, mram_mat_name: str): + if vram_mat_name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") + vram_layout = self.vram_matrices[vram_mat_name] + return vram_layout, self.hbm_matrices[mram_mat_name], vram_layout.get_row_blocks(vram_row_idx) + + def _loaded_mram_start(self, blocks, missing_label) -> int: + for sub_block in blocks: + if sub_block.mram_addr is None: + raise RuntimeError(f"SubBlock {missing_label(sub_block)} not loaded to MRAM") + return blocks[0].mram_addr + + def _vram_sub_projection_asm_impl( + self, + header_lines: list[str], + vram_row_start_addr: int, + mram_start_addr: int, + result_vram_addr: int, + full_batch: int, + num_hidden_blocks: int, + mat_col_stride: int, + transposed: bool, + gp_regs: list[int], + caller_name: str, + unroll: bool | None = None, + ) -> str: + """Emit the shared projection loop after callers resolve operands.""" + do_unroll = self.unroll_loops if unroll is None else unroll + return vram_sub_projection_asm_impl( + mlen=self.mlen, + blen=self.blen, + unroll_loops=do_unroll, + header_lines=header_lines, + vram_row_start_addr=vram_row_start_addr, + mram_start_addr=mram_start_addr, + result_vram_addr=result_vram_addr, + full_batch=full_batch, + num_hidden_blocks=num_hidden_blocks, + mat_col_stride=mat_col_stride, + transposed=transposed, + gp_regs=gp_regs, + caller_name=caller_name, + ) + + def vram_sub_projection_asm( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_col_idx: int, + result_vram_addr: int, + gp_regs: list[int] | None = None, + k_block_start: int = 0, + k_block_count: int | None = None, + unroll: bool | None = None, + ) -> str: + """Emit VRAM[row][:] @ MRAM[:][col] projection.""" + gp_regs = self._default_projection_gp_regs(gp_regs) + vram_layout, mram_layout, vram_row_blocks = self._projection_context( + vram_mat_name, vram_row_idx, mram_mat_name + ) + mram_col_blocks = mram_layout.get_col_blocks(mram_col_idx) + if k_block_count is not None: + mram_col_blocks = mram_col_blocks[k_block_start : k_block_start + k_block_count] + + num_hidden_blocks = len(mram_col_blocks) + if num_hidden_blocks != (k_block_count if k_block_count is not None else len(vram_row_blocks)): + raise ValueError( + f"Dimension mismatch: expected {k_block_count or len(vram_row_blocks)} MRAM blocks, " + f"got {num_hidden_blocks}" + ) + + full_batch = vram_layout.full_shape[0] + vram_row_start_addr = vram_row_blocks[k_block_start].vram_addr + mram_col_start_addr = self._loaded_mram_start( + mram_col_blocks, + lambda block: f"{mram_mat_name}[{block.row_idx}][{mram_col_idx}]", + ) + + header_lines = [ + f"; VRAM Sub Projection: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}]", + f"; VRAM A[row_idx][:]: ({self.mlen}, hidden) spread across {num_hidden_blocks} column blocks", + f"; MRAM W[:][col_idx]: (hidden, {self.mlen}) with {num_hidden_blocks} sub-blocks", + f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", + ] + + return self._vram_sub_projection_asm_impl( + header_lines=header_lines, + vram_row_start_addr=vram_row_start_addr, + mram_start_addr=mram_col_start_addr, + result_vram_addr=result_vram_addr, + full_batch=full_batch, + num_hidden_blocks=num_hidden_blocks, + mat_col_stride=self.blen, + transposed=False, + gp_regs=gp_regs, + caller_name="vram_sub_projection_asm", + unroll=unroll, + ) + + def vram_sub_projection_T_asm( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_row_idx: int, + result_vram_addr: int, + gp_regs: list[int] | None = None, + unroll: bool | None = None, + ) -> str: + """Emit VRAM[row][:] @ MRAM[row][:]^T projection.""" + gp_regs = self._default_projection_gp_regs(gp_regs) + vram_layout, mram_layout, vram_row_blocks = self._projection_context( + vram_mat_name, vram_row_idx, mram_mat_name + ) + mram_row_blocks = mram_layout.get_row_blocks(mram_row_idx) + + if len(vram_row_blocks) != len(mram_row_blocks): + raise ValueError( + f"Dimension mismatch: VRAM has {len(vram_row_blocks)} blocks, MRAM has {len(mram_row_blocks)} blocks" + ) + + num_hidden_blocks = len(vram_row_blocks) + full_batch = vram_layout.full_shape[0] + vram_row_start_addr = vram_row_blocks[0].vram_addr + mram_row_start_addr = self._loaded_mram_start( + mram_row_blocks, + lambda block: f"{mram_mat_name}[{mram_row_idx}][{block.col_idx}]", + ) + # M_TMM reads the weight in transposed layout, so the outer-column + # stride is one full sub-block instead of the non-transposed blen stride. + mat_col_stride = self.blen * self.mlen + + header_lines = [ + f"; VRAM Sub Projection T: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T", + f"; VRAM A[row_idx][:]: ({self.mlen}, hidden)", + f"; MRAM W[row_idx][:]^T: (hidden, {self.mlen})", + f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", + ] + + return self._vram_sub_projection_asm_impl( + header_lines=header_lines, + vram_row_start_addr=vram_row_start_addr, + mram_start_addr=mram_row_start_addr, + result_vram_addr=result_vram_addr, + full_batch=full_batch, + num_hidden_blocks=num_hidden_blocks, + mat_col_stride=mat_col_stride, + transposed=True, + gp_regs=gp_regs, + caller_name="vram_sub_projection_T_asm", + unroll=unroll, + ) + + def vram_block_add_asm( + self, + src1_name: str, + src1_row_idx: int, + src1_col_idx: int, + src2_name: str, + src2_row_idx: int, + src2_col_idx: int, + target_name: str, + target_row_idx: int, + target_col_idx: int, + gp_regs: list[int] | None = None, + ) -> str: + """ + Add two mlen x mlen blocks and write to any target block: + target[target_row_idx][target_col_idx] = + src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] + """ + if gp_regs is None: + gp_regs = [1, 2, 3, 4] + if len(gp_regs) < 4: + raise ValueError(f"Need at least 4 GP regs, got {len(gp_regs)}") + + src1_block = self.get_vram_sub_block(src1_name, src1_row_idx, src1_col_idx) + src2_block = self.get_vram_sub_block(src2_name, src2_row_idx, src2_col_idx) + target_block = self.get_vram_sub_block(target_name, target_row_idx, target_col_idx) + + gp_dst = gp_regs[0] + gp_src1 = gp_regs[1] + gp_src2 = gp_regs[2] + gp_loop = gp_regs[3] + + lines = [ + f"; VRAM Block Add: {target_name}[{target_row_idx}][{target_col_idx}] = " + f"{src1_name}[{src1_row_idx}][{src1_col_idx}] + {src2_name}[{src2_row_idx}][{src2_col_idx}]" + ] + + if self.unroll_loops: + for i in range(self.mlen): + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr + i * self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr + i * self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr + i * self.mlen}") + lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") + else: + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr}") + lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr}") + lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {self.mlen}") + lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src1}, gp{gp_src1}, {self.mlen}") + lines.append(f"S_ADDI_INT gp{gp_src2}, gp{gp_src2}, {self.mlen}") + lines.append(f"C_LOOP_END gp{gp_loop}") + + return "\n".join(lines) + "\n" + + def load_sub_matrix_row( + self, + name: str, + row_idx: int, + mram_start_addr: int | None = None, + ) -> str: + """Load entire row sub-blocks from HBM to MRAM: matrix[row_idx][:].""" + layout = self.get_hbm_layout(name) + num_col_blocks = layout.num_col_blocks + block_size = self.mlen * self.mlen + + if mram_start_addr is None: + total_size = num_col_blocks * block_size + mram_start_addr = self.mram_allocator.allocate(f"{name}[{row_idx}][:]", total_size) + + return self._emit_hbm_matrix_load( + layout, + 3, + lambda addr_reg, gp_regs: self.load_row_sub_matrices_asm( + name=name, + row_idx=row_idx, + mram_start_addr=mram_start_addr, + hbm_addr_reg=addr_reg, + gp_regs=gp_regs, + ), + ) + + def load_sub_matrix_col( + self, + name: str, + col_idx: int, + mram_start_addr: int | None = None, + k_block_start: int = 0, + k_block_count: int | None = None, + ) -> str: + """ + Load entire column sub-blocks from HBM to MRAM: matrix[:][col_idx]. + Used for sub_projection: A @ W[:, col_idx*mlen:(col_idx+1)*mlen]. + """ + layout = self.get_hbm_layout(name) + num_row_blocks = layout.num_row_blocks + block_size = self.mlen * self.mlen + + if mram_start_addr is None: + effective_count = k_block_count if k_block_count is not None else num_row_blocks + total_size = effective_count * block_size + mram_start_addr = self.mram_allocator.allocate(f"{name}[:][{col_idx}]", total_size) + + return self._emit_hbm_matrix_load( + layout, + 3, + lambda addr_reg, gp_regs: self.load_col_sub_matrices_asm( + name=name, + col_idx=col_idx, + mram_start_addr=mram_start_addr, + hbm_addr_reg=addr_reg, + gp_regs=gp_regs, + k_block_start=k_block_start, + k_block_count=k_block_count, + ), + ) + + def allocate_vram_matrix( + self, + name: str, + rows: int, + cols: int, + strict: bool = True, + ) -> int: + """Allocate a VRAM matrix large enough to hold combined results of multiple sub-blocks. Returns the VRAM base address.""" + size = rows * cols + vram_addr = self.vram_allocator.allocate(size, name=name) + + self.add_vram_object( + name=name, + shape=(rows, cols), + vram_addr=vram_addr, + dtype="fp32", + kind="VRAMMatrix", + allocate_if_none=False, + strict=strict, + ) + + isa_code = f"; Allocate VRAM Matrix {name}: ({rows}, {cols}) at VRAM[{vram_addr}]\n" + self._emit(isa_code) + + return vram_addr + + def _ensure_vram_matrix_layout(self, matrix_name: str): + """Ensure a VRAM-resident tensor has a block layout.""" + if matrix_name not in self: + raise KeyError(f"Matrix '{matrix_name}' not found in symbol table") + + info = self[matrix_name] + if info.vram_addr is None: + raise ValueError(f"Matrix '{matrix_name}' has no VRAM address") + + try: + self.get_vram_layout(matrix_name) + except KeyError: + self.register_vram_matrix( + name=matrix_name, + shape=info.shape, + vram_base_addr=info.vram_addr, + ) + + def vram_block_add_to( + self, + src1_matrix: str, + src1_row_idx: int, + src1_col_idx: int, + src2_matrix: str, + src2_row_idx: int, + src2_col_idx: int, + target_matrix: str, + target_row_idx: int, + target_col_idx: int, + ) -> str: + """ + mlen x mlen block add: + target[rt][ct] = src1[r1][c1] + src2[r2][c2] + + Source/target may be the same matrix (supports in-place overwrite). + """ + self._ensure_vram_matrix_layout(src1_matrix) + self._ensure_vram_matrix_layout(src2_matrix) + self._ensure_vram_matrix_layout(target_matrix) + + gp_regs = self.register_allocator.allocate_gp(4) + isa_code = self.vram_block_add_asm( + src1_name=src1_matrix, + src1_row_idx=src1_row_idx, + src1_col_idx=src1_col_idx, + src2_name=src2_matrix, + src2_row_idx=src2_row_idx, + src2_col_idx=src2_col_idx, + target_name=target_matrix, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + gp_regs=gp_regs, + ) + self.register_allocator.free_gp(gp_regs) + + return self._emit(isa_code) + + def vram_matrix_add( + self, + dst_matrix: str, + src_matrix: str, + dst_row_offset: int = 0, + src_row_offset: int = 0, + num_rows: int | None = None, + ) -> str: + """ + General VRAM Matrix Addition: dst[row_offset:] += src. + + row_offsets are logical rows (not VRAM addresses); num_rows defaults + to the source matrix's row count. + """ + dst_info = self[dst_matrix] + src_info = self[src_matrix] + + # Block-add path depends on registered VRAM block layouts. + self._ensure_vram_matrix_layout(dst_matrix) + self._ensure_vram_matrix_layout(src_matrix) + + dst_addr = dst_info.vram_addr + src_addr = src_info.vram_addr + + dst_rows, dst_cols = dst_info.shape + src_rows, src_cols = src_info.shape + + if num_rows is None: + num_rows = src_rows + + # Ensure column count matches + assert dst_cols == src_cols, f"Column mismatch: dst={dst_cols}, src={src_cols}" + assert dst_row_offset + num_rows <= dst_rows, ( + f"dst row range out of bounds: offset={dst_row_offset}, num_rows={num_rows}, dst_rows={dst_rows}" + ) + assert src_row_offset + num_rows <= src_rows, ( + f"src row range out of bounds: offset={src_row_offset}, num_rows={num_rows}, src_rows={src_rows}" + ) + lines = [] + lines.append( + f"; === VRAM Matrix Add: " + f"{dst_matrix}[{dst_row_offset}:{dst_row_offset + num_rows}] += " + f"{src_matrix}[{src_row_offset}:{src_row_offset + num_rows}] ===" + ) + lines.append(f"; dst shape: {dst_info.shape}, src shape: {src_info.shape}") + + # Prefer block add path so we can reuse the compact C_LOOP-based add kernel. + block_aligned = ( + dst_cols % self.mlen == 0 + and src_cols % self.mlen == 0 + and dst_row_offset % self.mlen == 0 + and src_row_offset % self.mlen == 0 + and num_rows % self.mlen == 0 + ) + + if block_aligned: + num_row_blocks = num_rows // self.mlen + num_col_blocks = dst_cols // self.mlen + dst_row_block_base = dst_row_offset // self.mlen + src_row_block_base = src_row_offset // self.mlen + lines.append(f"; block add path: row_blocks={num_row_blocks}, col_blocks={num_col_blocks}") + + for row_block in range(num_row_blocks): + for col_block in range(num_col_blocks): + gp_regs = self.register_allocator.allocate_gp(4) + lines.append( + self.vram_block_add_asm( + src1_name=dst_matrix, + src1_row_idx=dst_row_block_base + row_block, + src1_col_idx=col_block, + src2_name=src_matrix, + src2_row_idx=src_row_block_base + row_block, + src2_col_idx=col_block, + target_name=dst_matrix, + target_row_idx=dst_row_block_base + row_block, + target_col_idx=col_block, + gp_regs=gp_regs, + ).rstrip("\n") + ) + self.register_allocator.free_gp(gp_regs) + else: + # Fallback for non-mlen-aligned ranges. + gp_regs = self.register_allocator.allocate_gp(2) + gp_dst = gp_regs[0] + gp_src = gp_regs[1] + num_col_blocks = dst_cols // self.mlen + lines.append(f"; fallback row-wise path: num_rows={num_rows}, num_col_blocks={num_col_blocks}") + + for row in range(num_rows): + dst_actual_row = dst_row_offset + row + src_actual_row = src_row_offset + row + + for col_block in range(num_col_blocks): + dst_block_addr = dst_addr + col_block * dst_rows * self.mlen + dst_actual_row * self.mlen + src_block_addr = src_addr + col_block * src_rows * self.mlen + src_actual_row * self.mlen + + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_block_addr}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_block_addr}") + lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") + self.register_allocator.free_gp(gp_regs) + + isa_code = "\n".join(lines) + "\n" + return self._emit(isa_code) + + def _target_tile_addr(self, target_matrix: str, target_row_idx: int, target_col_idx: int) -> tuple[int, int, int]: + if target_matrix not in self: + raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") + + target_info = self[target_matrix] + target_rows, _target_cols = target_info.shape + target_base_addr = target_info.vram_addr + result_vram_addr = ( + target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen + ) + return result_vram_addr, target_base_addr, target_rows + + def _emit_vram_sub_projection_to( + self, + *, + transposed: bool, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_idx: int, + target_matrix: str, + target_row_idx: int, + target_col_idx: int, + k_block_start: int = 0, + k_block_count: int | None = None, + ) -> str: + result_vram_addr, target_base_addr, target_rows = self._target_tile_addr( + target_matrix, target_row_idx, target_col_idx + ) + gp_regs = self.register_allocator.allocate_gp(9) + + if transposed: + isa_code = f"; VRAM Sub Projection T To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_idx}][:]^T -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" + asm = self.vram_sub_projection_T_asm( + vram_mat_name=vram_mat_name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_mat_name, + mram_row_idx=mram_idx, + result_vram_addr=result_vram_addr, + gp_regs=gp_regs, + ) + else: + isa_code = f"; VRAM Sub Projection To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_idx}] -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" + asm = self.vram_sub_projection_asm( + vram_mat_name=vram_mat_name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_mat_name, + mram_col_idx=mram_idx, + result_vram_addr=result_vram_addr, + gp_regs=gp_regs, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" + isa_code += asm + + self.register_allocator.free_gp(gp_regs) + return self._emit(isa_code) + + def vram_sub_projection_to( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_col_idx: int, + target_matrix: str, + target_row_idx: int, + target_col_idx: int, + k_block_start: int = 0, + k_block_count: int | None = None, + ) -> str: + """ + Sub-block multiplication: + target[target_row_idx][target_col_idx] = VRAM_A[vram_row_idx][:] @ MRAM_W[:][mram_col_idx]. + Target matrix must have been allocated via allocate_vram_matrix. + """ + return self._emit_vram_sub_projection_to( + transposed=False, + vram_mat_name=vram_mat_name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_mat_name, + mram_idx=mram_col_idx, + target_matrix=target_matrix, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + + def vram_sub_projection_T_to( + self, + vram_mat_name: str, + vram_row_idx: int, + mram_mat_name: str, + mram_row_idx: int, + target_matrix: str, + target_row_idx: int, + target_col_idx: int, + ) -> str: + """ + Transposed sub-block multiplication: + target[target_row_idx][target_col_idx] = VRAM_A[vram_row_idx][:] @ MRAM_W[mram_row_idx][:]^T. + + Used by Flash Attention for S = Q @ K^T: + Q[i][:]: (mlen, hidden_size) row sub-block + K[j][:]: (mlen, hidden_size) row sub-block, transposed to (hidden_size, mlen) + S[i][j]: (mlen, mlen) + """ + return self._emit_vram_sub_projection_to( + transposed=True, + vram_mat_name=vram_mat_name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_mat_name, + mram_idx=mram_row_idx, + target_matrix=target_matrix, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + ) + + +__all__ = ["IsaMatrixMixin"] diff --git a/aten/plena/isa_tile_rows.py b/aten/plena/isa_tile_rows.py new file mode 100644 index 0000000..a225209 --- /dev/null +++ b/aten/plena/isa_tile_rows.py @@ -0,0 +1,437 @@ +"""Tile-row ISA helpers for IsaCompiler.""" + +from __future__ import annotations + +from compiler.aten.isa_builder import IsaBuilder, fp, gp + + +class IsaTileRowMixin: + # ========================================================================= + # Tile-row helpers (name-based) + # ========================================================================= + + def _tile_addr(self, matrix_name: str, tile_row_idx: int = 0, tile_col_idx: int = 0) -> int: + return self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) + + def _tile_row_single_matrix_op( + self, + asm_method: str, + matrix_name: str, + arg, + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return getattr(self, asm_method)(self._tile_addr(matrix_name, tile_row_idx, tile_col_idx), arg) + + def _tile_row_binary_matrix_op( + self, + asm_method: str, + dst_matrix: str, + src_matrix: str, + rows: list[int], + dst_tile_row_idx: int = 0, + dst_tile_col_idx: int = 0, + src_tile_row_idx: int = 0, + src_tile_col_idx: int = 0, + ) -> str: + return getattr(self, asm_method)( + self._tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx), + self._tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx), + rows, + ) + + def tile_row_max( + self, + source_matrix: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return self._tile_row_single_matrix_op("tile_row_max_asm", source_matrix, row_map, tile_row_idx, tile_col_idx) + + def tile_row_sum( + self, + source_matrix: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return self._tile_row_single_matrix_op("tile_row_sum_asm", source_matrix, row_map, tile_row_idx, tile_col_idx) + + def tile_row_exp( + self, + matrix_name: str, + rows: list[int], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return self._tile_row_single_matrix_op("tile_row_exp_asm", matrix_name, rows, tile_row_idx, tile_col_idx) + + def tile_row_reci( + self, + matrix_name: str, + rows: list[int], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return self._tile_row_single_matrix_op("tile_row_reci_asm", matrix_name, rows, tile_row_idx, tile_col_idx) + + def tile_row_sub_fp( + self, + matrix_name: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return self._tile_row_single_matrix_op("tile_row_sub_fp_asm", matrix_name, row_map, tile_row_idx, tile_col_idx) + + def tile_row_mul_fp( + self, + matrix_name: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return self._tile_row_single_matrix_op("tile_row_mul_fp_asm", matrix_name, row_map, tile_row_idx, tile_col_idx) + + def tile_row_add_fp( + self, + matrix_name: str, + row_map: list[tuple[int, int]], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return self._tile_row_single_matrix_op("tile_row_add_fp_asm", matrix_name, row_map, tile_row_idx, tile_col_idx) + + def tile_row_add( + self, + dst_matrix: str, + src_matrix: str, + rows: list[int], + dst_tile_row_idx: int = 0, + dst_tile_col_idx: int = 0, + src_tile_row_idx: int = 0, + src_tile_col_idx: int = 0, + ) -> str: + return self._tile_row_binary_matrix_op( + "tile_row_add_asm", + dst_matrix, + src_matrix, + rows, + dst_tile_row_idx, + dst_tile_col_idx, + src_tile_row_idx, + src_tile_col_idx, + ) + + def tile_row_sub( + self, + dst_matrix: str, + src_matrix: str, + rows: list[int], + dst_tile_row_idx: int = 0, + dst_tile_col_idx: int = 0, + src_tile_row_idx: int = 0, + src_tile_col_idx: int = 0, + ) -> str: + return self._tile_row_binary_matrix_op( + "tile_row_sub_asm", + dst_matrix, + src_matrix, + rows, + dst_tile_row_idx, + dst_tile_col_idx, + src_tile_row_idx, + src_tile_col_idx, + ) + + def tile_row_mul( + self, + dst_matrix: str, + src_matrix: str, + rows: list[int], + dst_tile_row_idx: int = 0, + dst_tile_col_idx: int = 0, + src_tile_row_idx: int = 0, + src_tile_col_idx: int = 0, + ) -> str: + return self._tile_row_binary_matrix_op( + "tile_row_mul_asm", + dst_matrix, + src_matrix, + rows, + dst_tile_row_idx, + dst_tile_col_idx, + src_tile_row_idx, + src_tile_col_idx, + ) + + def tile_row_mul_fp_broadcast( + self, + matrix_name: str, + fpram_scalar_addr: int, + rows: list[int], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return self.tile_row_mul_fp_broadcast_asm( + self._tile_addr(matrix_name, tile_row_idx, tile_col_idx), + fpram_scalar_addr, + rows, + ) + + def vram_fill_zero( + self, + matrix_name: str, + rows: list[int], + tile_row_idx: int = 0, + tile_col_idx: int = 0, + ) -> str: + return self._tile_row_single_matrix_op("vram_fill_zero_asm", matrix_name, rows, tile_row_idx, tile_col_idx) + + # ========================================================================= + # Tile-row ISA helpers (address-based) + # ========================================================================= + + def _arith_progression(self, values: list[int]) -> tuple[int, int, int] | None: + """Return (start, count, step) if values form an arithmetic progression.""" + if not values: + return None + if len(values) == 1: + return (values[0], 1, 0) + step = values[1] - values[0] + for i in range(2, len(values)): + if values[i] - values[i - 1] != step: + return None + if step == 0: + return None # Constant sequence (step=0, count>1) would cause infinite HW loop + return (values[0], len(values), step) + + def _row_progression(self, rows: list[int]) -> tuple[int, int, int] | None: + return None if self._unroll else self._arith_progression(rows) + + def _emit_tile_row_reduce( + self, + label: str, + source_vram_addr: int, + row_map: list[tuple[int, int]], + opcode: str, + opcode_extra_args: tuple[int, ...] = (), + clear_accumulator: bool = False, + ) -> str: + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"Tile Row {label} from VRAM[{source_vram_addr}]") + rows = [row for row, _ in row_map] + fp_addrs = [addr_ for _, addr_ in row_map] + row_prog = self._row_progression(rows) + fp_prog = self._row_progression(fp_addrs) + + if row_prog is not None and fp_prog is not None: + row_start, row_count, row_step = row_prog + fp_start, _, fp_step = fp_prog + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), source_vram_addr + row_start * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), fp_start) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + if clear_accumulator: + asm.instr("S_ADD_FP", fp(1), fp(0), fp(0)) + asm.instr(opcode, fp(1), gp(gp_src), 0, *opcode_extra_args) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), row_step * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), fp_step) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx, fpram_addr in row_map: + row_addr = source_vram_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), row_addr) + if clear_accumulator: + asm.instr("S_ADD_FP", fp(1), fp(0), fp(0)) + asm.instr(opcode, fp(1), gp(gp_src), 0, *opcode_extra_args) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), fpram_addr) + asm.instr("S_ST_FP", fp(1), gp(gp_dst), 0) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def _emit_tile_row_unary(self, label: str, opcode: str, vram_addr: int, rows: list[int]) -> str: + gp_regs = self._reg.allocate_gp(2) + gp_src, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"Tile Row {label} on VRAM[{vram_addr}]") + prog = self._row_progression(rows) + + if prog is not None: + row_start, row_count, row_step = prog + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), vram_addr + row_start * self.mlen) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + asm.instr(opcode, gp(gp_src), gp(gp_src), 0) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), row_step * self.mlen) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx in rows: + row_addr = vram_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), row_addr) + asm.instr(opcode, gp(gp_src), gp(gp_src), 0) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def _emit_tile_row_fp_scalar( + self, + label: str, + opcode: str, + vram_addr: int, + row_map: list[tuple[int, int]], + opcode_extra_args: tuple[int, ...] = (), + ) -> str: + gp_regs = self._reg.allocate_gp(3) + gp_src, gp_fp, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"Tile Row {label} FP on VRAM[{vram_addr}]") + rows = [row for row, _ in row_map] + fp_addrs = [addr_ for _, addr_ in row_map] + row_prog = self._row_progression(rows) + fp_prog = self._row_progression(fp_addrs) + + if row_prog is not None and fp_prog is not None: + row_start, row_count, row_step = row_prog + fp_start, _, fp_step = fp_prog + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), vram_addr + row_start * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_fp), gp(0), fp_start) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + asm.instr("S_LD_FP", fp(1), gp(gp_fp), 0) + asm.instr(opcode, gp(gp_src), gp(gp_src), fp(1), *opcode_extra_args) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), row_step * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_fp), gp(gp_fp), fp_step) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx, fpram_addr in row_map: + row_addr = vram_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), row_addr) + asm.instr("S_ADDI_INT", gp(gp_fp), gp(0), fpram_addr) + asm.instr("S_LD_FP", fp(1), gp(gp_fp), 0) + asm.instr(opcode, gp(gp_src), gp(gp_src), fp(1), *opcode_extra_args) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def _emit_tile_row_vector_op( + self, + label: str, + opcode: str, + dst_addr: int, + src_addr: int, + rows: list[int], + ) -> str: + gp_regs = self._reg.allocate_gp(3) + gp_dst, gp_src, gp_loop = gp_regs + try: + assignment_op = {"Add": "+", "Sub": "-", "Mul": "*"}.get(label, label) + asm = IsaBuilder().comment(f"Tile Row {label}: VRAM[{dst_addr}] {assignment_op}= VRAM[{src_addr}]") + prog = self._row_progression(rows) + + if prog is not None: + row_start, row_count, row_step = prog + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_addr + row_start * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_addr + row_start * self.mlen) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + asm.instr(opcode, gp(gp_dst), gp(gp_dst), gp(gp_src), 0) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), row_step * self.mlen) + asm.instr("S_ADDI_INT", gp(gp_src), gp(gp_src), row_step * self.mlen) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx in rows: + dst_row_addr = dst_addr + row_idx * self.mlen + src_row_addr = src_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), dst_row_addr) + asm.instr("S_ADDI_INT", gp(gp_src), gp(0), src_row_addr) + asm.instr(opcode, gp(gp_dst), gp(gp_dst), gp(gp_src), 0) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + def tile_row_max_asm(self, source_vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_reduce("Max", source_vram_addr, row_map, "V_RED_MAX") + + def tile_row_sum_asm(self, source_vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_reduce( + "Sum", + source_vram_addr, + row_map, + "V_RED_SUM", + opcode_extra_args=(0,), + clear_accumulator=True, + ) + + def tile_row_exp_asm(self, vram_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_unary("Exp", "V_EXP_V", vram_addr, rows) + + def tile_row_reci_asm(self, vram_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_unary("Reciprocal", "V_RECI_V", vram_addr, rows) + + def tile_row_sub_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_fp_scalar("Sub", "V_SUB_VF", vram_addr, row_map, opcode_extra_args=(0, 0)) + + def tile_row_mul_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_fp_scalar("Mul", "V_MUL_VF", vram_addr, row_map, opcode_extra_args=(0,)) + + def tile_row_add_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: + return self._emit_tile_row_fp_scalar("Add", "V_ADD_VF", vram_addr, row_map, opcode_extra_args=(0,)) + + def tile_row_add_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_vector_op("Add", "V_ADD_VV", dst_addr, src_addr, rows) + + def tile_row_sub_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_vector_op("Sub", "V_SUB_VV", dst_addr, src_addr, rows) + + def tile_row_mul_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: + return self._emit_tile_row_vector_op("Mul", "V_MUL_VV", dst_addr, src_addr, rows) + + def tile_row_mul_fp_broadcast_asm(self, vram_addr: int, fpram_scalar_addr: int, rows: list[int]) -> str: + row_map = [(r, fpram_scalar_addr) for r in rows] + return self.tile_row_mul_fp_asm(vram_addr, row_map) + + def vram_fill_zero_asm( + self, + vram_addr: int, + rows: list[int], + ) -> str: + """ + VRAM Fill Zero: fill specified rows with 0. + + For each row_idx in rows: + VRAM[row] = 0 + """ + if not rows: + return self._emit(IsaBuilder().comment(f"=== VRAM Fill Zero: VRAM[{vram_addr}] rows [] = 0 ===")) + + gp_regs = self._reg.allocate_gp(2) + gp_dst, gp_loop = gp_regs + try: + asm = IsaBuilder().comment(f"=== VRAM Fill Zero: VRAM[{vram_addr}] rows {rows} = 0 ===") + prog = self._row_progression(rows) + + if prog is not None: + row_start, row_count, row_step = prog + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), vram_addr + row_start * self.mlen) + asm.instr("C_LOOP_START", gp(gp_loop), row_count) + asm.instr("V_MUL_VF", gp(gp_dst), gp(gp_dst), fp(0), 0) + asm.instr("S_ADDI_INT", gp(gp_dst), gp(gp_dst), row_step * self.mlen) + asm.instr("C_LOOP_END", gp(gp_loop)) + else: + for row_idx in rows: + row_addr = vram_addr + row_idx * self.mlen + asm.instr("S_ADDI_INT", gp(gp_dst), gp(0), row_addr) + asm.instr("V_MUL_VF", gp(gp_dst), gp(gp_dst), fp(0), 0) + + return self._emit(asm) + finally: + self._reg.free_gp(gp_regs) + + +__all__ = ["IsaTileRowMixin"] diff --git a/aten/plena/memory.py b/aten/plena/memory.py new file mode 100644 index 0000000..e09bf54 --- /dev/null +++ b/aten/plena/memory.py @@ -0,0 +1,437 @@ +"""Memory layouts and allocators for the ATen PLENA compiler path.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import math + +from compiler.aten.plena.constants import MLEN + + +# ============================================================================== +# Virtual Memory Manager +# ============================================================================== + + +@dataclass +class MemoryBlock: + name: str # Allocation name (e.g., "W[0][1]" or "activation_A") + addr: int # Starting address + size: int # Block size (number of elements) + + +class VirtualMemoryManager: + """Best-fit reuse plus bump allocation for PLENA virtual memories.""" + + def __init__(self, total_size: int, alignment: int = MLEN, mem_name: str = "Memory"): + self.total_size = total_size + self.alignment = alignment + self.mem_name = mem_name + self.next_bump = 0 # Bump allocation pointer + + # Two core stacks + self.used_stack: list[MemoryBlock] = [] + self.free_stack: list[MemoryBlock] = [] + + def _align(self, value: int) -> int: + """Align value to alignment""" + return ((value + self.alignment - 1) // self.alignment) * self.alignment + + def allocate(self, name: str, size: int) -> int: + """Allocate by best-fit reuse first, then bump allocation.""" + aligned_size = self._align(size) + + best = min( + ((block.size - aligned_size, i) for i, block in enumerate(self.free_stack) if block.size >= aligned_size), + default=None, + ) + + if best is not None: + _, best_idx = best + reused_block = self.free_stack.pop(best_idx) + + # If block is larger than needed, split remaining part and return to free_stack + if reused_block.size > aligned_size: + remaining = MemoryBlock( + name="", addr=reused_block.addr + aligned_size, size=reused_block.size - aligned_size + ) + self.free_stack.append(remaining) + + new_block = MemoryBlock(name=name, addr=reused_block.addr, size=aligned_size) + self.used_stack.append(new_block) + return new_block.addr + + aligned_addr = self._align(self.next_bump) + + if self.total_size > 0 and aligned_addr + aligned_size > self.total_size: + raise MemoryError( + f"{self.mem_name} overflow: need {aligned_size} at addr {aligned_addr}, " + f"total_size={self.total_size}, " + f"used={len(self.used_stack)} blocks, " + f"free={len(self.free_stack)} blocks" + ) + + new_block = MemoryBlock(name=name, addr=aligned_addr, size=aligned_size) + self.used_stack.append(new_block) + self.next_bump = aligned_addr + aligned_size + return aligned_addr + + def free(self, name: str, strict: bool = True) -> MemoryBlock | None: + """Move an allocation from used_stack to reusable free_stack.""" + for i, block in enumerate(self.used_stack): + if block.name == name: + freed = self.used_stack.pop(i) + self.free_stack.append(freed) + self._coalesce_free_stack() + return freed + + if strict: + raise KeyError( + f"{self.mem_name}: allocation '{name}' not found in used_stack. " + f"Current used: {[b.name for b in self.used_stack]}" + ) + return None + + def mark_used(self, addr: int, size: int, name: str) -> None: + """Register a pre-known occupied range and advance bump past it.""" + aligned_size = self._align(size) + block = MemoryBlock(name=name, addr=addr, size=aligned_size) + self.used_stack.append(block) + # Advance bump pointer past this region if it would otherwise overlap. + end = addr + aligned_size + if self.next_bump < end: + self.next_bump = end + + def _coalesce_free_stack(self): + """Merge adjacent free blocks by address.""" + if len(self.free_stack) <= 1: + return + + blocks = sorted(self.free_stack, key=lambda b: b.addr) + merged: list[MemoryBlock] = [blocks[0]] + for block in blocks[1:]: + prev = merged[-1] + if prev.addr + prev.size == block.addr: + merged[-1] = MemoryBlock( + name="", + addr=prev.addr, + size=prev.size + block.size, + ) + else: + merged.append(block) + + self.free_stack = merged + + def reset(self): + """Reset manager""" + self.next_bump = 0 + self.used_stack.clear() + self.free_stack.clear() + + +# ============================================================================== +# Sub-matrix Information +# ============================================================================== + + +@dataclass +class SubMatrixInfo: + """Metadata for sub-matrices""" + + parent_name: str # Parent matrix name + row_idx: int # Sub-block row index + col_idx: int # Sub-block column index + shape: tuple[int, int] # Sub-block shape (typically 64x64) + + # Pre-calculated addresses (computed during compiler phase, used directly at runtime) + hbm_offset: int = 0 # Offset in HBM (in elements) + mram_addr: int | None = None # Address in MRAM (if loaded) + + +@dataclass +class MatrixBlockLayout: + """ + Block layout information for large matrices. + + HBM storage: [rows, cols] row-major contiguous, stride=cols per row. + MRAM storage: [batch, mlen, hidden/mlen] column-block major. + """ + + name: str + full_shape: tuple[int, int] # Full matrix shape (rows, cols) + block_size: int = MLEN # Sub-block size (default 64) + + num_row_blocks: int = 0 + num_col_blocks: int = 0 + + # HBM Address Information + hbm_base_addr: int = 0 + hbm_size: int = 0 # Size after applying real_data_ratio (MXFP8 = 1.125) + + # Sub-block map: (row_idx, col_idx) -> SubMatrixInfo + sub_blocks: dict[tuple[int, int], SubMatrixInfo] = field(default_factory=dict) + + def __post_init__(self): + """Initialize block information""" + rows, cols = self.full_shape + self.num_row_blocks = math.ceil(rows / self.block_size) + self.num_col_blocks = math.ceil(cols / self.block_size) + + # Create information for all sub-blocks (pre-calculate addresses) + for r in range(self.num_row_blocks): + for c in range(self.num_col_blocks): + # HBM offset (row-major): sub-block (r,c) starts at r*block_size*cols + c*block_size + hbm_offset = r * self.block_size * cols + c * self.block_size + + sub_info = SubMatrixInfo( + parent_name=self.name, + row_idx=r, + col_idx=c, + shape=(self.block_size, self.block_size), + hbm_offset=hbm_offset, + mram_addr=None, + ) + self.sub_blocks[(r, c)] = sub_info + + def get_sub_block(self, row_idx: int, col_idx: int) -> SubMatrixInfo: + """Get specified sub-block""" + if (row_idx, col_idx) not in self.sub_blocks: + raise IndexError(f"Sub block [{row_idx}][{col_idx}] out of range") + return self.sub_blocks[(row_idx, col_idx)] + + def get_row_blocks(self, row_idx: int) -> list[SubMatrixInfo]: + """Get all sub-blocks in a row""" + return [self.sub_blocks[(row_idx, c)] for c in range(self.num_col_blocks)] + + def get_col_blocks(self, col_idx: int) -> list[SubMatrixInfo]: + """Get all sub-blocks in a column""" + return [self.sub_blocks[(r, col_idx)] for r in range(self.num_row_blocks)] + + +# ============================================================================== +# VRAM Sub-matrix Information +# ============================================================================== + + +@dataclass +class VRAMSubMatrixInfo: + """Metadata for sub-matrices in VRAM""" + + parent_name: str # Parent matrix name + row_idx: int # Sub-block row index (batch dimension) + col_idx: int # Sub-block column index (hidden dimension) + shape: tuple[int, int] # Sub-block shape (typically mlen x mlen) + + # Pre-calculated VRAM address + vram_addr: int = 0 + + +@dataclass +class VRAMMatrixBlockLayout: + """ + Block layout for large matrices in VRAM. + + VRAM storage format: [batch, mlen, hidden/mlen] column-block major. + - Batch dimension contiguous within each column block. + - Column blocks laid out sequentially. + + Column block c base address: vram_base + c * batch * mlen. + Sub-block (r, c) offset within column block: r * mlen * mlen. + """ + + name: str + full_shape: tuple[int, int] # Full matrix shape (batch, hidden_size) + vram_base_addr: int # VRAM base address + block_size: int = MLEN # Sub-block size (default 64) + + num_row_blocks: int = 0 # Number of blocks in batch dimension + num_col_blocks: int = 0 # Number of blocks in hidden dimension + + # Sub-block map: (row_idx, col_idx) -> VRAMSubMatrixInfo + sub_blocks: dict[tuple[int, int], VRAMSubMatrixInfo] = field(default_factory=dict) + + def __post_init__(self): + """Initialize block information""" + batch, hidden = self.full_shape + self.num_row_blocks = math.ceil(batch / self.block_size) + self.num_col_blocks = math.ceil(hidden / self.block_size) + + # VRAM column-block major address calculation: + # col block c base = vram_base + c * batch * mlen + # row sub-block r offset within column block = r * mlen * mlen + for r in range(self.num_row_blocks): + for c in range(self.num_col_blocks): + col_block_base = self.vram_base_addr + c * batch * self.block_size + row_offset = r * self.block_size * self.block_size + vram_addr = col_block_base + row_offset + + sub_info = VRAMSubMatrixInfo( + parent_name=self.name, + row_idx=r, + col_idx=c, + shape=(self.block_size, self.block_size), + vram_addr=vram_addr, + ) + self.sub_blocks[(r, c)] = sub_info + + def get_sub_block(self, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: + """Get specified sub-block""" + if (row_idx, col_idx) not in self.sub_blocks: + raise IndexError(f"VRAM sub block [{row_idx}][{col_idx}] out of range") + return self.sub_blocks[(row_idx, col_idx)] + + def get_row_blocks(self, row_idx: int) -> list[VRAMSubMatrixInfo]: + """Get all sub-blocks in a row (A[row_idx][:])""" + return [self.sub_blocks[(row_idx, c)] for c in range(self.num_col_blocks)] + + def get_col_blocks(self, col_idx: int) -> list[VRAMSubMatrixInfo]: + """Get all sub-blocks in a column""" + return [self.sub_blocks[(r, col_idx)] for r in range(self.num_row_blocks)] + + +# ============================================================================== +# Unified Memory Object Metadata +# ============================================================================== + + +@dataclass +class MemoryObjectInfo: + """Unified metadata for objects managed across HBM / VRAM / FPRAM.""" + + name: str + kind: str + dtype: str = "fp16" + shape: tuple[int, int] = (0, 0) + size: int = 0 + hbm_addr: int = -1 + hbm_size: int = 0 + vram_addr: int | None = None + fpram_addr: int | None = None + fpram_size: int = 0 + + +@dataclass +class FPRAMObjectLayout: + """FPRAM object layout.""" + + name: str + fpram_addr: int + size: int + dtype: str = "fp16" + kind: str = "FPRAMObject" + + +# ============================================================================== +# Allocators +# ============================================================================== + + +class MemoryAllocatorBase: + """Shared wrapper over VirtualMemoryManager for compiler address spaces.""" + + def __init__(self, total_size: int, alignment: int, mem_name: str): + self.total_size = total_size + self.alignment = alignment + self._vmm = VirtualMemoryManager(total_size=total_size, alignment=alignment, mem_name=mem_name) + + @property + def next_free(self) -> int: + return self._vmm.next_bump + + @next_free.setter + def next_free(self, value: int): + self._validate_next_free(value) + self._vmm.next_bump = value + + def _validate_next_free(self, value: int) -> None: + del value + + def free(self, name: str, strict: bool = True) -> MemoryBlock | None: + return self._vmm.free(name, strict=strict) + + def reset(self): + self._vmm.reset() + + +class MRAMAllocator(MemoryAllocatorBase): + """ + Matrix RAM address allocator. + + Each sub-block is mlen x mlen = 4096 elements; aligned to mlen*mlen. + Default total_size=16384 holds 4 sub-blocks (MAX_K_TILES=4). + """ + + def __init__(self, total_size: int = MLEN * MLEN * 4): + super().__init__(total_size=total_size, alignment=MLEN * MLEN, mem_name="MRAM") + + def allocate(self, name: str, size: int) -> int: + return self._vmm.allocate(name, size) + + +class VRAMAllocator(MemoryAllocatorBase): + """VRAM address allocator with MLEN-aligned best-fit reuse + bump allocation.""" + + def __init__(self, alignment: int = MLEN, total_size: int = 0): + super().__init__(total_size=total_size, alignment=alignment, mem_name="VRAM") + + def allocate(self, size: int, name: str = "") -> int: + if not name: + raise ValueError("VRAMAllocator.allocate() requires name for subsequent free.") + return self._vmm.allocate(name, size) + + +class FPRAMAllocator(MemoryAllocatorBase): + """ + Floating Point RAM Allocator (based on VirtualMemoryManager). + + FPRAM stores scalar FP values (f16), accessed via S_LD_FP / S_ST_FP. + Uses the same strategy as VRAM/MRAM allocator: + - used_stack + free_stack + - allocate: best-fit reuse first, then bump + - free: move block to free_stack, supports out-of-order free + + Fixed slot conventions (must not be overwritten by dynamic allocations): + slot 0 = 0.0 (gp0/f0 reserved as hardware zero) + slot 1 = attn_scale + slot 2 = -inf (online softmax) + slot 3 = eps (rms_norm/layer_norm) + slot 4 = 1/hidden_size (rms_norm/layer_norm) + slot 5 = 1.0 (FFN SiLU sigmoid denominator; im2col fp_one_reg) + + Hardware: 1024 f16 elements (configurable via total_size). + """ + + def __init__(self, total_size: int = 1024): + """ + Args: + total_size: Total FP RAM size (default 1024, matching hardware fpsram) + """ + super().__init__(total_size=total_size, alignment=1, mem_name="FPRAM") + self.allocations: dict[str, tuple[int, int]] = {} + + def _validate_next_free(self, value: int) -> None: + if value < 0 or value > self.total_size: + raise ValueError(f"next_free out of range: {value}, expected [0, {self.total_size}]") + + def allocate(self, name: str, size: int) -> int: + """Allocate FP RAM space (best-fit + bump).""" + if size <= 0: + raise ValueError(f"FPRAM allocation size must be > 0, got {size}") + if name in self.allocations: + raise KeyError(f"FPRAM name '{name}' already allocated") + + addr = self._vmm.allocate(name, size) + self.allocations[name] = (addr, size) + return addr + + def free(self, name: str, strict: bool = True) -> MemoryBlock | None: + """Free a block and move it to free_stack (same as VirtualMemoryManager).""" + freed = self._vmm.free(name, strict=strict) + if freed is not None: + self.allocations.pop(name, None) + return freed + + def reset(self): + """Reset allocator""" + self._vmm.reset() + self.allocations.clear() diff --git a/aten/plena/memory_state.py b/aten/plena/memory_state.py new file mode 100644 index 0000000..6e4482e --- /dev/null +++ b/aten/plena/memory_state.py @@ -0,0 +1,268 @@ +"""Memory layout state for the ATen PLENA compiler.""" + +from __future__ import annotations + +from compiler.aten.plena.constants import BLEN, MLEN +from compiler.aten.plena.memory import ( + FPRAMAllocator, + FPRAMObjectLayout, + MRAMAllocator, + MatrixBlockLayout, + MemoryObjectInfo, + VRAMAllocator, + VRAMMatrixBlockLayout, + VRAMSubMatrixInfo, +) + + +class MemoryStateMixin: + """Sub-matrix layout manager for PLENA HBM, VRAM, MRAM, and FPRAM state.""" + + def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = False): + self.mlen = mlen + self.blen = blen + self.unroll_loops = unroll_loops + + # Layout tables + self.hbm_matrices: dict[str, MatrixBlockLayout] = {} + self.vram_matrices: dict[str, VRAMMatrixBlockLayout] = {} + self.fpram_matrices: dict[str, FPRAMObjectLayout] = {} + # Memory Allocators + self.vram_allocator = VRAMAllocator() + self.mram_allocator = MRAMAllocator() + self.fpram_allocator = FPRAMAllocator() + + def __contains__(self, name: str) -> bool: + return name in self.hbm_matrices or name in self.vram_matrices or name in self.fpram_matrices + + def __getitem__(self, name: str) -> MemoryObjectInfo: + if name not in self: + raise KeyError(f"Object '{name}' not found") + info = MemoryObjectInfo(name=name, kind="Unknown") + hbm_layout = self.hbm_matrices.get(name) + vram_layout = self.vram_matrices.get(name) + fpram_layout = self.fpram_matrices.get(name) + + if hbm_layout is not None: + rows, cols = hbm_layout.full_shape + info.shape = hbm_layout.full_shape + info.size = rows * cols + info.hbm_addr = hbm_layout.hbm_base_addr + info.hbm_size = hbm_layout.hbm_size + info.kind = "Matrix" + + if vram_layout is not None: + rows, cols = vram_layout.full_shape + info.shape = vram_layout.full_shape + info.size = rows * cols + info.vram_addr = vram_layout.vram_base_addr + info.kind = "VRAMMatrix" if hbm_layout is None else "Batch" + + if fpram_layout is not None: + info.shape = (1, fpram_layout.size) + info.size = fpram_layout.size + info.fpram_addr = fpram_layout.fpram_addr + info.fpram_size = fpram_layout.size + info.kind = "FPRAMObject" + + return info + + def get_hbm_layout(self, name: str) -> MatrixBlockLayout: + """Read HBM matrix layout by name.""" + if name not in self.hbm_matrices: + raise KeyError(f"HBM matrix '{name}' not found") + return self.hbm_matrices[name] + + def get_vram_layout(self, name: str) -> VRAMMatrixBlockLayout: + """Read VRAM matrix layout by name.""" + if name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{name}' not found") + return self.vram_matrices[name] + + def get_fpram_layout(self, name: str) -> FPRAMObjectLayout: + """Read FPRAM object layout by name.""" + if name not in self.fpram_matrices: + raise KeyError(f"FPRAM object '{name}' not found") + return self.fpram_matrices[name] + + # ========================================================================== + # Unified Object Management APIs + # ========================================================================== + + def add_hbm_object( + self, + name: str, + shape: tuple[int, int], + hbm_addr: int, + dtype: str = "fp16", + kind: str = "HBMObject", + real_data_ratio: float = 1.125, + strict: bool = False, + ) -> MemoryObjectInfo: + del dtype, kind + self.register_matrix( + name=name, + shape=shape, + hbm_base_addr=hbm_addr, + real_data_ratio=real_data_ratio, + strict=strict, + ) + return self[name] + + def free_hbm_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: + if name not in self.hbm_matrices: + if strict: + raise KeyError(f"HBM object '{name}' not found") + return None + info = self[name] + self.hbm_matrices.pop(name, None) + return info + + def add_vram_object( + self, + name: str, + shape: tuple[int, int], + vram_addr: int | None = None, + dtype: str = "fp16", + kind: str = "VRAMObject", + allocate_if_none: bool = True, + strict: bool = True, + ) -> MemoryObjectInfo: + rows, cols = shape + size = rows * cols + if vram_addr is None: + if not allocate_if_none: + raise ValueError("vram_addr is None and allocate_if_none is False") + vram_addr = self.vram_allocator.allocate(size=size, name=name) + del dtype, kind + self.register_vram_matrix( + name=name, + shape=shape, + vram_base_addr=vram_addr, + strict=strict, + ) + return self[name] + + def free_vram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: + if name not in self.vram_matrices: + if strict: + raise KeyError(f"VRAM object '{name}' not found") + return None + info = self[name] + self.vram_allocator.free(name, strict=strict) + self.vram_matrices.pop(name, None) + return info + + def add_fpram_object( + self, + name: str, + size: int, + dtype: str = "fp16", + kind: str = "FPRAMObject", + ) -> MemoryObjectInfo: + fpram_addr = self.fpram_allocator.allocate(name, size) + self.fpram_matrices[name] = FPRAMObjectLayout( + name=name, + fpram_addr=fpram_addr, + size=size, + dtype=dtype, + kind=kind, + ) + return self[name] + + def free_fpram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: + if name not in self.fpram_matrices: + if strict: + raise KeyError(f"FPRAM object '{name}' not found") + return None + info = self[name] + self.fpram_allocator.free(name, strict=strict) + self.fpram_matrices.pop(name, None) + return info + + def register_matrix( + self, + name: str, + shape: tuple[int, int], + hbm_base_addr: int, + real_data_ratio: float = 1.125, + strict: bool = True, + ) -> MatrixBlockLayout: + """Register an HBM matrix and derive its mlen block layout.""" + rows, cols = shape + + if strict: + if rows % self.mlen != 0: + raise ValueError(f"Matrix rows ({rows}) must be multiple of mlen ({self.mlen})") + if cols % self.mlen != 0: + raise ValueError(f"Matrix cols ({cols}) must be multiple of mlen ({self.mlen})") + + size = rows * cols + hbm_size = int(size * real_data_ratio) + + layout = MatrixBlockLayout( + name=name, full_shape=shape, block_size=self.mlen, hbm_base_addr=hbm_base_addr, hbm_size=hbm_size + ) + + self.hbm_matrices[name] = layout + return layout + + # ========================================================================== + # VRAM Sub-matrix Management + # ========================================================================== + + def register_vram_matrix( + self, + name: str, + shape: tuple[int, int], + vram_base_addr: int, + strict: bool = True, + ) -> VRAMMatrixBlockLayout: + """ + Register a large matrix in VRAM and automatically block it. + + Args: + name: matrix name + shape: full shape (batch, hidden_size) + vram_base_addr: VRAM base address + + Returns: + VRAMMatrixBlockLayout object + """ + batch, hidden = shape + + if strict: + if batch % self.mlen != 0: + raise ValueError(f"VRAM matrix batch ({batch}) must be multiple of mlen ({self.mlen})") + if hidden % self.mlen != 0: + raise ValueError(f"VRAM matrix hidden ({hidden}) must be multiple of mlen ({self.mlen})") + + layout = VRAMMatrixBlockLayout(name=name, full_shape=shape, vram_base_addr=vram_base_addr, block_size=self.mlen) + + self.vram_matrices[name] = layout + return layout + + def get_vram_sub_block(self, name: str, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: + """Get VRAM sub-block information""" + if name not in self.vram_matrices: + raise KeyError(f"VRAM matrix '{name}' not registered") + return self.vram_matrices[name].get_sub_block(row_idx, col_idx) + + def clear_mram_bindings(self) -> None: + """Clear cached MRAM addresses on all HBM sub-blocks.""" + for layout in self.hbm_matrices.values(): + for sub_block in layout.sub_blocks.values(): + sub_block.mram_addr = None + + def reset(self): + """Reset manager state.""" + self.clear_mram_bindings() + self.hbm_matrices.clear() + self.vram_matrices.clear() + self.fpram_matrices.clear() + self.vram_allocator.reset() + self.mram_allocator.reset() + self.fpram_allocator.reset() + + +__all__ = ["MemoryStateMixin"] diff --git a/aten/plena/program_attention.py b/aten/plena/program_attention.py new file mode 100644 index 0000000..0e22c9b --- /dev/null +++ b/aten/plena/program_attention.py @@ -0,0 +1,216 @@ +"""Flash-attention operations for the PLENA program builder.""" + +from __future__ import annotations + +import math + +from compiler.asm_templates import preload_addr_reg_asm +from compiler.asm_templates.flashattn import flash_attn_asm +from compiler.aten.plena.vars import InputVar, VRAMMatrixVar + + +class ProgramAttentionMixin: + # ======================================================================== + # Flash Attention Operations + # ======================================================================== + + def flash_attention(self, Q, K, V, scale=None, hq=1, hkv=1, h_qkv=None, causal_mask=None): + """Emit flash attention, dispatching to MHA or fused GQA codegen by shape.""" + if hq == 1 and hkv == 1: + return self._flash_attention_mha(Q, K, V, scale, causal_mask=causal_mask) + + if h_qkv is None: + raise ValueError("GQA mode requires h_qkv to be specified") + if causal_mask is not None: + raise NotImplementedError("causal_mask is not yet supported for GQA flash attention") + return self._flash_attention_gqa_fused(Q, K, V, scale, hq, hkv, h_qkv) + + def _flash_attention_mha(self, Q, K, V, scale, causal_mask=None): + """Single-head online-softmax flash attention using compiler primitives.""" + seq_len, head_dim = Q.shape + mlen = self.mlen + + if scale is None: + scale = 1.0 / math.sqrt(head_dim) + + num_q_blocks = seq_len // mlen + num_k_blocks = seq_len // mlen + + S_block = self.alloc("S", mlen, mlen) + PV = self.alloc("PV", mlen, head_dim) + O = self.alloc("O", seq_len, head_dim) + + for q_idx in range(num_q_blocks): + self.init_online_softmax(q_idx, O) + + for k_idx in range(num_k_blocks): + self.vram_sub_projection_T_to( + Q, + q_idx, + K, + k_idx, + S_block, + target_row_idx=0, + target_col_idx=0, + ) + if causal_mask is not None: + self.vram_add(S_block, causal_mask) + self.online_softmax_block(S_block, scale) + self.compute_pv(S_block, V, k_idx, PV, head_dim) + self.scale_o_row(O, q_idx) + self.vram_add(O, PV, dst_row_offset=q_idx * mlen) + + self.final_scale_o(q_idx, O) + + return O + + def _flash_attention_gqa_fused(self, Q, K, V, scale, hq, hkv, h_qkv): + """GQA flash attention using the fused M_BTMM template.""" + ratio = hq // hkv + mlen = self.mlen + blen = self.blen + vlen = mlen + + if ratio != blen: + raise ValueError( + f"GQA ratio hq/hkv={ratio} must equal blen={blen} " + "(hardware packs heads into blen)." + ) + if ratio * h_qkv != mlen: + raise ValueError( + f"GQA constraint: (hq/hkv)*h_qkv = {ratio * h_qkv} must equal mlen={mlen}." + ) + + s_q, _q_total_dim = Q.shape + s_kv, _k_total_dim = K.shape + + if scale is None: + scale = 1.0 / math.sqrt(h_qkv) + + self._ensure_hbm_sub_matrix_registered(K) + self._ensure_hbm_sub_matrix_registered(V) + alloc = self.register_allocator + k_addr, v_addr = alloc.allocate_addr(2) + gp_for_preload = alloc.allocate_gp(2) + setup = preload_addr_reg_asm( + addr_reg_to_set=[k_addr, v_addr], + available_registers=gp_for_preload, + addr_reg_val=[K.hbm_addr, V.hbm_addr], + ) + alloc.free_gp(gp_for_preload) + self.emit(setup) + + q_vram_base = self.get_vram_addr(Q.name) + s_name = self._scoped_name("_gqa_S") + pv_name = self._scoped_name("_gqa_PV") + o_name = self._scoped_name("O") + + self.allocate_vram_matrix(name=s_name, rows=mlen * ratio, cols=mlen, strict=False) + self.allocate_vram_matrix(name=pv_name, rows=mlen * ratio, cols=mlen, strict=False) + self.allocate_vram_matrix(name=o_name, rows=s_q, cols=hq * h_qkv, strict=False) + + br = min(mlen, s_q) + fp_allocs = self.fpram_allocator + if "_gqa_fp_const_zero" not in fp_allocs.allocations: + fp_allocs.allocate(name="_gqa_fp_const_zero", size=1) + fp_allocs.allocate(name="_gqa_fp_const_scale", size=1) + fp_allocs.allocate(name="_gqa_fp_const_neg_inf", size=1) + fp_info = self.add_fpram_object(name="_gqa_softmax_state", size=3 * br * ratio) + if fp_info.fpram_addr is None: + raise RuntimeError("Failed to allocate FPRAM for GQA softmax state") + + self.emit( + flash_attn_asm( + mlen=mlen, + vlen=vlen, + blen=blen, + batch=1, + hq=hq, + hkv=hkv, + d=h_qkv, + q_len=s_q, + kv_len=s_kv, + alive_registers_int=list(range(1, 16)), + alive_registers_fp=list(range(1, 8)), + vector_sram_base_address=q_vram_base, + fp_sram_start_address=fp_info.fpram_addr, + k_base_hbm_offset_reg=k_addr, + v_base_hbm_offset_reg=v_addr, + ) + ) + + alloc.free_addr([k_addr, v_addr]) + O = VRAMMatrixVar(self, o_name, (s_q, hq * h_qkv), display_name="O") + self._tensors[o_name] = O + return O + + def init_online_softmax(self, q_idx: int, o_matrix: VRAMMatrixVar): + """Initialize Online Softmax state: m=-inf, l=0, O_row=0""" + o_info = super().get_tensor_info(o_matrix.name) + seq_len, head_dim = o_info.shape + + super().init_online_softmax( + q_idx=q_idx, + o_matrix=o_matrix.name, + seq_len=seq_len, + head_dim=head_dim, + ) + + def online_softmax_block(self, s_block: VRAMMatrixVar, scale: float): + """Perform Online Softmax on S block""" + super().online_softmax_block( + s_block_matrix=s_block.name, + scale=scale, + ) + + def compute_pv( + self, + s_block: VRAMMatrixVar, + v_input: InputVar, + k_idx: int, + pv_matrix: VRAMMatrixVar, + head_dim: int, + ): + """Compute PV = P @ V[k_idx] where P is stored in s_block.""" + if not isinstance(s_block, VRAMMatrixVar): + raise TypeError(f"s_block must be VRAMMatrixVar, got {type(s_block)}") + if not isinstance(v_input, InputVar): + raise TypeError(f"v_input must be InputVar, got {type(v_input)}") + if not isinstance(pv_matrix, VRAMMatrixVar): + raise TypeError(f"pv_matrix must be VRAMMatrixVar, got {type(pv_matrix)}") + + self._ensure_hbm_sub_matrix_registered(v_input) + super().compute_pv( + s_block_matrix=s_block.name, + v_sub_matrix=v_input.name, + k_idx=k_idx, + pv_matrix=pv_matrix.name, + head_dim=head_dim, + ) + + def scale_o_row(self, o_matrix: VRAMMatrixVar, q_idx: int): + """Scale current row block of O by m_res""" + o_info = super().get_tensor_info(o_matrix.name) + seq_len, head_dim = o_info.shape + + super().scale_o_row( + o_matrix=o_matrix.name, + q_idx=q_idx, + seq_len=seq_len, + head_dim=head_dim, + ) + + def final_scale_o(self, q_idx: int, o_matrix: VRAMMatrixVar): + """Final scaling: O[q_idx] = O[q_idx] / l""" + o_info = super().get_tensor_info(o_matrix.name) + seq_len, head_dim = o_info.shape + + super().final_scale_o( + q_idx=q_idx, + o_matrix=o_matrix.name, + seq_len=seq_len, + head_dim=head_dim, + ) + + +__all__ = ["ProgramAttentionMixin"] diff --git a/aten/plena/program_fp_tile_ops.py b/aten/plena/program_fp_tile_ops.py new file mode 100644 index 0000000..86a8d16 --- /dev/null +++ b/aten/plena/program_fp_tile_ops.py @@ -0,0 +1,347 @@ +"""FPRAM, FPVar, and tile-row operations for the PLENA program builder.""" + +from __future__ import annotations + +from collections.abc import Iterable + +from compiler.aten.plena.vars import FPVar, VRAMMatrixVar + + +class ProgramFPTileOpsMixin: + # ======================================================================== + # FP Variable (FPRAM) + # ======================================================================== + + def allocate_fpram( + self, + internal_name: str, + size: int = 1, + display_name: str | None = None, + ) -> FPVar: + """Allocate FPRAM with an explicit internal name and return an FPVar.""" + if size <= 0: + raise ValueError(f"FPRAM allocation size must be positive, got {size}") + + address = super().allocate_fpram(internal_name, size) + var = FPVar( + self, + internal_name, + address, + size, + display_name=display_name if display_name is not None else internal_name, + ) + self._fp_vars[internal_name] = var + return var + + def free_fpram(self, internal_name: str, strict: bool = True): + super().free_fpram(internal_name, strict=strict) + self._fp_vars.pop(internal_name, None) + + def fp_var(self, name: str, size: int = 1) -> FPVar: + return self.allocate_fpram( + internal_name=self._scoped_name(name), + size=size, + display_name=name, + ) + + # ======================================================================== + # Shared argument normalization + # ======================================================================== + + def _resolve_fpram_addr(self, addr_or_var: int | FPVar, offset: int = 0) -> int: + if isinstance(addr_or_var, FPVar): + if offset < 0 or offset >= addr_or_var.size: + raise ValueError( + f"FPVar offset out of range: offset={offset}, size={addr_or_var.size}, var={addr_or_var.name}" + ) + return addr_or_var.address + offset + if not isinstance(addr_or_var, int): + raise TypeError(f"Expected int or FPVar, got {type(addr_or_var)}") + return addr_or_var + offset + + def _resolve_rows( + self, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + ) -> list[int]: + if row_idx is not None and rows is not None: + raise ValueError("Provide either row_idx or rows, not both") + if rows is not None: + return list(rows) + if row_idx is not None: + return [row_idx] + return list(range(self.mlen)) + + def _default_rows(self, rows: Iterable[int] | None, *, total_rows: int | None = None) -> list[int]: + return list(range(self.mlen if total_rows is None else total_rows)) if rows is None else list(rows) + + def _fpram_row_map( + self, + fpram_addr: int | FPVar, + *, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + single_offset: int = 0, + base_offset: int = 0, + ) -> list[tuple[int, int]]: + resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) + offsets = ( + [single_offset] + if len(resolved_rows) == 1 + else [base_offset + i for i in range(len(resolved_rows))] + ) + return [(row, self._resolve_fpram_addr(fpram_addr, offset)) for row, offset in zip(resolved_rows, offsets)] + + def _fp_count(self, vars_: Iterable[FPVar], count: int | None, *, default: int | None = None) -> int: + fp_vars = list(vars_) + resolved_count = default if count is None and default is not None else count + if resolved_count is None: + resolved_count = min(var.size for var in fp_vars) + if any(resolved_count > var.size for var in fp_vars): + sizes = ", ".join(f"{var.name}.size={var.size}" for var in fp_vars) + raise ValueError(f"count={resolved_count} exceeds FPVar size: {sizes}") + return resolved_count + + def _fpvar_unary(self, isa_method: str, src: FPVar, dst: FPVar, count: int | None = None): + count = self._fp_count([src, dst], count) + return getattr(super(), isa_method)(src.name, dst.name, count) + + def _fpvar_binary(self, isa_method: str, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + count = self._fp_count([src1, src2, dst], count) + return getattr(super(), isa_method)(src1.name, src2.name, dst.name, count) + + # ======================================================================== + # FPRAM tile-row operations + # ======================================================================== + + def _tile_row_reduce_to_fpram( + self, + isa_method: str, + target_fpram_addr: int | FPVar, + source: VRAMMatrixVar, + row_idx: int | None, + rows: Iterable[int] | None, + target_offset: int, + target_base_offset: int, + ): + return getattr(super(), isa_method)( + source.name, + self._fpram_row_map( + target_fpram_addr, + row_idx=row_idx, + rows=rows, + single_offset=target_offset, + base_offset=target_base_offset, + ), + ) + + def tile_row_max( + self, + target_fpram_addr: int | FPVar, + source: VRAMMatrixVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + target_offset: int = 0, + target_base_offset: int = 0, + ): + return self._tile_row_reduce_to_fpram( + "tile_row_max", target_fpram_addr, source, row_idx, rows, target_offset, target_base_offset + ) + + def tile_row_sum( + self, + target_fpram_addr: int | FPVar, + source: VRAMMatrixVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + target_offset: int = 0, + target_base_offset: int = 0, + ): + return self._tile_row_reduce_to_fpram( + "tile_row_sum", target_fpram_addr, source, row_idx, rows, target_offset, target_base_offset + ) + + def tile_row_exp( + self, + source: VRAMMatrixVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + ): + super().tile_row_exp(source.name, self._resolve_rows(row_idx=row_idx, rows=rows)) + + def tile_row_reci( + self, + source: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + super().tile_row_reci(source.name, self._default_rows(rows)) + + def tile_row_sub_fp( + self, + source: VRAMMatrixVar, + fpram_addr: int | FPVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + fpram_offset: int = 0, + fpram_base_offset: int = 0, + ): + return self._tile_row_fp_scalar( + "tile_row_sub_fp", source, fpram_addr, row_idx, rows, fpram_offset, fpram_base_offset + ) + + def tile_row_mul_fp( + self, + source: VRAMMatrixVar, + fpram_addr: int | FPVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + fpram_offset: int = 0, + fpram_base_offset: int = 0, + ): + return self._tile_row_fp_scalar( + "tile_row_mul_fp", source, fpram_addr, row_idx, rows, fpram_offset, fpram_base_offset + ) + + def tile_row_add_fp( + self, + source: VRAMMatrixVar, + fp_var: FPVar, + rows: Iterable[int] | None = None, + ): + resolved_rows = self._default_rows(rows) + super().tile_row_add_fp(source.name, [(row, fp_var[row]) for row in resolved_rows]) + + def _tile_row_binary(self, isa_method: str, dst: VRAMMatrixVar, src: VRAMMatrixVar, rows: Iterable[int] | None): + return getattr(super(), isa_method)(dst.name, src.name, self._default_rows(rows)) + + def _tile_row_fp_scalar( + self, + isa_method: str, + source: VRAMMatrixVar, + fpram_addr: int | FPVar, + row_idx: int | None, + rows: Iterable[int] | None, + fpram_offset: int, + fpram_base_offset: int, + ): + return getattr(super(), isa_method)( + source.name, + self._fpram_row_map( + fpram_addr, + row_idx=row_idx, + rows=rows, + single_offset=fpram_offset, + base_offset=fpram_base_offset, + ), + ) + + def tile_row_add( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + return self._tile_row_binary("tile_row_add", dst, src, rows) + + def tile_row_sub( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + return self._tile_row_binary("tile_row_sub", dst, src, rows) + + def tile_row_mul( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + return self._tile_row_binary("tile_row_mul", dst, src, rows) + + def tile_row_mul_fp_broadcast( + self, + source: VRAMMatrixVar, + fpram_scalar_addr: int | FPVar, + row_idx: int | None = None, + rows: Iterable[int] | None = None, + fpram_offset: int = 0, + ): + scalar_addr = self._resolve_fpram_addr(fpram_scalar_addr, fpram_offset) + super().tile_row_mul_fp_broadcast(source.name, scalar_addr, self._resolve_rows(row_idx=row_idx, rows=rows)) + + # ======================================================================== + # FPVar operations + # ======================================================================== + + def fpvar_reci(self, src: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_unary("fpram_reci", src, dst, count) + + def fpvar_exp(self, src: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_unary("fpram_exp", src, dst, count) + + def fpvar_copy(self, src: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_unary("fpram_copy", src, dst, count) + + def fpvar_max(self, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_binary("fpram_max", src1, src2, dst, count) + + def fpvar_sub(self, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_binary("fpram_sub", src1, src2, dst, count) + + def fpvar_mul(self, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_binary("fpram_mul", src1, src2, dst, count) + + def fpvar_add(self, src1: FPVar, src2: FPVar, dst: FPVar, count: int | None = None): + return self._fpvar_binary("fpram_add", src1, src2, dst, count) + + def fpvar_sum(self, src: FPVar, dst: FPVar, count: int | None = None): + count = self._fp_count([src], count, default=src.size) + return super().fpram_sum(src.name, dst.name, count) + + def fpvar_shift( + self, + src: FPVar, + dst: FPVar, + shift: int, + count: int | None = None, + fill: FPVar | None = None, + ): + count = self._fp_count([src, dst], count) + return super().fpram_shift( + src_name=src.name, + dst_name=dst.name, + shift=shift, + count=count, + fill_fpram_name=None if fill is None else fill.name, + ) + + def fpvar_fill_from_fpram( + self, + dst: FPVar, + src_fpram_addr: int, + count: int | None = None, + ): + count = self._fp_count([dst], count, default=dst.size) + return super().fpram_fill_from_fpram(dst.name, src_fpram_addr, count) + + def vram_fill_zero( + self, + matrix: VRAMMatrixVar, + rows: Iterable[int] | None = None, + ): + resolved_rows = self._default_rows(rows, total_rows=matrix.shape[0]) + total_rows, cols = matrix.shape + if any(row < 0 or row >= total_rows for row in resolved_rows): + raise ValueError( + f"vram_fill_zero rows out of bounds for {matrix.name}: shape={matrix.shape}, rows={resolved_rows}" + ) + + # VRAM matrices are column-block-major. The low-level helper zeros one + # tile column, so walk every column block for wide matrices. + num_col_blocks = (cols + self.mlen - 1) // self.mlen + for col_block in range(num_col_blocks): + super().vram_fill_zero(matrix.name, resolved_rows, tile_col_idx=col_block) + + +__all__ = ["ProgramFPTileOpsMixin"] diff --git a/aten/plena/program_matrix_ops.py b/aten/plena/program_matrix_ops.py new file mode 100644 index 0000000..22224c6 --- /dev/null +++ b/aten/plena/program_matrix_ops.py @@ -0,0 +1,281 @@ +"""Matrix projection, RoPE, and VRAM operations for the PLENA program builder.""" + +from __future__ import annotations + +import math + +from compiler.aten.plena.vars import InputVar, TensorVar, VRAMMatrixVar + +MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements + + +def _iter_k_chunks(num_k_tiles: int): + k_start = 0 + while k_start < num_k_tiles: + k_end = min(k_start + MAX_K_TILES, num_k_tiles) + yield k_start, k_end - k_start + k_start = k_end + + +class ProgramMatrixOpsMixin: + # ======================================================================== + # Matrix Projection and VRAM Operations + # ======================================================================== + + def _require_var(self, value, expected_type, label: str): + if not isinstance(value, expected_type): + raise TypeError(f"{label} must be {expected_type.__name__}, got {type(value)}") + return value + + def _ensure_hbm_sub_matrix_registered(self, input_var: InputVar): + """Ensure an HBM input is registered in compiler sub-matrix manager.""" + if self._registered_hbm_sub_matrices.get(input_var.name): + return + h, w = input_var.shape + super().ensure_hbm_sub_matrix( + name=input_var.name, + hbm_addr=input_var.hbm_addr, + shape=(h, w), + real_data_ratio=self.real_data_ratio, + ) + self._registered_hbm_sub_matrices[input_var.name] = True + + def _ensure_vram_sub_matrix_registered(self, matrix_var: VRAMMatrixVar): + """Ensure a VRAM matrix is registered in compiler sub-matrix manager.""" + if self._registered_vram_sub_matrices.get(matrix_var.name): + return + super().ensure_vram_matrix_layout( + name=matrix_var.name, + shape=matrix_var.shape, + ) + self._registered_vram_sub_matrices[matrix_var.name] = True + + def _prepare_projection(self, vram_matrix, mram_input, target, auto_reset_mram: bool): + vram_matrix = self._require_var(vram_matrix, VRAMMatrixVar, "vram_matrix") + mram_input = self._require_var(mram_input, InputVar, "mram_input") + target = self._require_var(target, VRAMMatrixVar, "target") + self._ensure_vram_sub_matrix_registered(vram_matrix) + self._ensure_hbm_sub_matrix_registered(mram_input) + if auto_reset_mram: + super().reset_mram() + return vram_matrix, mram_input, target + + def vram_sub_projection_to( + self, + vram_matrix: VRAMMatrixVar, + vram_row_idx: int, + mram_input: InputVar, + mram_col_idx: int, + target: VRAMMatrixVar, + target_row_idx: int, + target_col_idx: int, + auto_reset_mram: bool = True, + k_block_start: int = 0, + k_block_count: int | None = None, + ): + """ + target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[:][mram_col_idx] + Supports K-split: k_block_start/k_block_count select a subset of K tiles. + """ + vram_matrix, mram_input, target = self._prepare_projection( + vram_matrix, mram_input, target, auto_reset_mram + ) + super().load_sub_matrix_col( + name=mram_input.name, + col_idx=mram_col_idx, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + super().vram_sub_projection_to( + vram_mat_name=vram_matrix.name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_input.name, + mram_col_idx=mram_col_idx, + target_matrix=target.name, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + k_block_start=k_block_start, + k_block_count=k_block_count, + ) + + def vram_sub_projection_T_to( + self, + vram_matrix: VRAMMatrixVar, + vram_row_idx: int, + mram_input: InputVar, + mram_row_idx: int, + target: VRAMMatrixVar, + target_row_idx: int, + target_col_idx: int, + auto_reset_mram: bool = True, + ): + """ + target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[mram_row_idx][:]^T + """ + vram_matrix, mram_input, target = self._prepare_projection( + vram_matrix, mram_input, target, auto_reset_mram + ) + super().load_sub_matrix_row(name=mram_input.name, row_idx=mram_row_idx) + super().vram_sub_projection_T_to( + vram_mat_name=vram_matrix.name, + vram_row_idx=vram_row_idx, + mram_mat_name=mram_input.name, + mram_row_idx=mram_row_idx, + target_matrix=target.name, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + ) + + def linear_projection(self, input_var: VRAMMatrixVar, weight_var: InputVar, name: str = "linear_out"): + """Emit tiled PLENA linear projection, including K-split accumulation.""" + mlen = self.mlen + + rows, k_total = input_var.shape + _, out_features = weight_var.shape + num_row_blocks = math.ceil(rows / mlen) + if out_features % mlen != 0: + raise ValueError(f"out_features ({out_features}) must be a multiple of mlen ({mlen})") + num_col_blocks = out_features // mlen + num_k_tiles = math.ceil(k_total / mlen) + + # When rows is not a multiple of mlen the hardware still operates on + # full tiles; only the first `rows` rows contain valid output. + output = self.alloc(name, rows, out_features, strict=rows % mlen == 0) + + def emit_projection(row_idx, col_idx, target, target_row_idx, target_col_idx, **k_split): + self.vram_sub_projection_to( + input_var, + row_idx, + weight_var, + col_idx, + target, + target_row_idx, + target_col_idx, + **k_split, + ) + + if num_k_tiles <= MAX_K_TILES: + for col_idx in range(num_col_blocks): + for row_idx in range(num_row_blocks): + emit_projection(row_idx, col_idx, output, row_idx, col_idx) + return output + + # Temp buffer for one partial-sum tile. Allocating the full output shape + # here can overlap with the real output for wide projections. + temp = self.alloc(f"{name}_temp", mlen, mlen) + for k_chunk_idx, (k_block_start, k_block_count) in enumerate(_iter_k_chunks(num_k_tiles)): + k_split = { + "k_block_start": k_block_start, + "k_block_count": k_block_count, + } + for col_idx in range(num_col_blocks): + for row_idx in range(num_row_blocks): + if k_chunk_idx == 0: + emit_projection(row_idx, col_idx, output, row_idx, col_idx, **k_split) + else: + emit_projection(row_idx, col_idx, temp, 0, 0, **k_split) + self.vram_block_add_to( + output, + row_idx, + col_idx, + temp, + 0, + 0, + output, + row_idx, + col_idx, + ) + self.free_tensor(temp) + return output + + def linear(self, input_var: VRAMMatrixVar, weight_var: InputVar): + """Default linear op compatibility surface.""" + return self.linear_projection(input_var, weight_var) + + # ======================================================================== + # RoPE (1D Positional Encoding) + # ======================================================================== + + def rope( + self, + x_var: VRAMMatrixVar, + x_rot_var: VRAMMatrixVar, + cos_var: VRAMMatrixVar, + sin_var: VRAMMatrixVar, + ) -> VRAMMatrixVar: + """Apply Rotary Position Embedding in-place: x = x * cos + rotate_half(x) * sin + + x_rot_var must already be in VRAM as rotate_half(x), preloaded by caller. + Returns x_var (modified in-place). + """ + super().rope( + x_name=x_var.name, + x_rot_name=x_rot_var.name, + cos_name=cos_var.name, + sin_name=sin_var.name, + ) + return x_var + + # ======================================================================== + # VRAM Matrix Addition + # ======================================================================== + + def vram_add( + self, + dst: VRAMMatrixVar, + src: VRAMMatrixVar, + dst_row_offset: int = 0, + src_row_offset: int = 0, + num_rows: int | None = None, + ): + """VRAM matrix add: dst[row_offset:] += src""" + super().vram_matrix_add( + dst_matrix=dst.name, + src_matrix=src.name, + dst_row_offset=dst_row_offset, + src_row_offset=src_row_offset, + num_rows=num_rows, + ) + + def embedding_add(self, input_var: VRAMMatrixVar, pos_weight_var: VRAMMatrixVar): + """Add learned/positional embedding weights to input in-place.""" + self.vram_add(input_var, pos_weight_var) + return input_var + + def vram_block_add_to( + self, + src1: TensorVar, + src1_row_idx: int, + src1_col_idx: int, + src2: TensorVar, + src2_row_idx: int, + src2_col_idx: int, + target: TensorVar, + target_row_idx: int, + target_col_idx: int, + ): + """ + mlen x mlen block add: + target[target_row_idx][target_col_idx] = + src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] + + Supports writing back to the same matrix/block (in-place overwrite). + """ + src1 = self._require_var(src1, VRAMMatrixVar, "src1") + src2 = self._require_var(src2, VRAMMatrixVar, "src2") + target = self._require_var(target, VRAMMatrixVar, "target") + + super().vram_block_add_to( + src1_matrix=src1.name, + src1_row_idx=src1_row_idx, + src1_col_idx=src1_col_idx, + src2_matrix=src2.name, + src2_row_idx=src2_row_idx, + src2_col_idx=src2_col_idx, + target_matrix=target.name, + target_row_idx=target_row_idx, + target_col_idx=target_col_idx, + ) + + +__all__ = ["ProgramMatrixOpsMixin"] diff --git a/aten/plena/program_tensors.py b/aten/plena/program_tensors.py new file mode 100644 index 0000000..b8b946b --- /dev/null +++ b/aten/plena/program_tensors.py @@ -0,0 +1,358 @@ +"""Tensor, memory, and normalization operations for the PLENA program builder.""" + +from __future__ import annotations + +from compiler.asm_templates import ffn_asm, preload_addr_reg_asm, reset_reg_asm +from compiler.aten.plena.vars import FPVar, InputVar, TensorVar, VRAMMatrixVar + + +class ProgramTensorMixin: + # ======================================================================== + # Input Declaration + # ======================================================================== + + def input( + self, + name: str, + shape: tuple[int, int], + hbm_addr: int | None = None, + prestaged_vram_addr: int | None = None, + ) -> InputVar: + """ + Declare an input tensor (in HBM). + + Args: + name: tensor name + shape: (height, width) + hbm_addr: HBM address (None = auto-allocate) + prestaged_vram_addr: If an int, the tensor is assumed to be already + present in VRAM at this byte address. A subsequent call to + ``load_batch`` will register it at that address without emitting + any HBM→VRAM prefetch instructions. If None (default), the + normal HBM→VRAM load path is used. + + Returns: + InputVar proxy object + """ + h, w = shape + size = h * w + hbm_size = int(size * self.real_data_ratio) + + if hbm_addr is None: + hbm_addr = self._allocate_hbm(hbm_size) + + var = InputVar(self, name, shape, hbm_addr, hbm_size, prestaged_vram_addr=prestaged_vram_addr) + self._inputs[name] = var + super().add_hbm_object( + name=name, + hbm_addr=hbm_addr, + shape=shape, + real_data_ratio=self.real_data_ratio, + ) + return var + + # ======================================================================== + # Load Operations + # ======================================================================== + + def load_batch( + self, + input_var: InputVar, + name: str | None = None, + ) -> VRAMMatrixVar: + """ + Load tensor from HBM to VRAM (Batch type). + + When ``input_var.prestaged_vram_addr`` is set the tensor is assumed to + be already resident in VRAM at that address. No HBM→VRAM prefetch + instructions are emitted; the tensor is simply registered in the symbol + table at the given address. + + Args: + input_var: source InputVar + name: result name (None = use input name) + + Returns: + VRAMMatrixVar proxy object + """ + if not isinstance(input_var, InputVar): + raise TypeError(f"Expected InputVar, got {type(input_var)}") + + display_name = name if name is not None else input_var.display_name + internal_name = self._scoped_name(display_name) + + if input_var.prestaged_vram_addr is not None: + # Prestaged path: tensor is already in VRAM — register without ISA. + h, w = input_var.shape + vram_addr = input_var.prestaged_vram_addr + # Tell the VRAM allocator that this region is occupied so subsequent + # allocations don't collide with it. + self.vram_allocator._vmm.mark_used(vram_addr, h * w, name=internal_name) + super().add_vram_object( + name=internal_name, + shape=(h, w), + vram_addr=vram_addr, + dtype="fp16", + kind="Batch", + allocate_if_none=False, + strict=False, + ) + else: + # Normal path: emit HBM → VRAM prefetch ISA. + super().load_batch( + hbm_object_name=input_var.name, vram_object_name=internal_name, vlen=self.mlen, preload_len=4 + ) + + var = VRAMMatrixVar(self, internal_name, input_var.shape, display_name=display_name) + self._tensors[internal_name] = var + return var + + # ======================================================================== + # Store Operations + # ======================================================================== + + def store(self, tensor_var, name: str | None = None, hbm_addr: int | None = None) -> InputVar: + """ + Write tensor from VRAM back to HBM. + + Returns: + InputVar proxy object (can be loaded back later) + """ + if not isinstance(tensor_var, VRAMMatrixVar): + raise TypeError(f"Store requires VRAMMatrixVar, got {type(tensor_var)}") + + display_name = name if name is not None else f"{tensor_var.display_name}_stored" + internal_name = self._scoped_name(display_name) + + if hbm_addr is None: + h, w = tensor_var.shape + size = h * w + hbm_size = int(size * self.real_data_ratio) + hbm_addr = self._allocate_hbm(hbm_size) + else: + h, w = tensor_var.shape + hbm_size = int(h * w * self.real_data_ratio) + + super().store_to_hbm( + tensor_name=tensor_var.name, # internal name for symbol table lookup + hbm_addr=hbm_addr, + hbm_object_name=internal_name, + vlen=self.mlen, + ) + + var = InputVar(self, internal_name, tensor_var.shape, hbm_addr, hbm_size, display_name=display_name) + self._inputs[internal_name] = var + return var + + # ======================================================================== + # VRAM Matrix Allocation + # ======================================================================== + + def alloc(self, name: str, rows: int, cols: int, strict: bool = True) -> VRAMMatrixVar: + """ + Allocate a VRAM matrix. + + Used to store intermediate results (e.g., S block, PV, O). + Within function scope, names are automatically prefixed to avoid conflicts. + + Args: + name: matrix name (user-visible) + rows: number of rows + cols: number of columns + strict: if False, skip mlen-alignment checks (for small scratch matrices) + + Returns: + VRAMMatrixVar proxy object + """ + display_name = name + internal_name = self._scoped_name(name) + super().allocate_vram_matrix(name=internal_name, rows=rows, cols=cols, strict=strict) + + var = VRAMMatrixVar(self, internal_name, (rows, cols), display_name=display_name) + self._tensors[internal_name] = var + return var + + def alloc_at(self, name: str, rows: int, cols: int, vram_addr: int) -> VRAMMatrixVar: + """Allocate a VRAM matrix view at a specific address. + + Used to create views into existing VRAM matrices (e.g., per-head + slices of a multi-head Q projection output). Does NOT bump the + VRAM allocator -- the caller is responsible for ensuring the region + is valid. + + Args: + name: matrix name (user-visible) + rows: number of rows + cols: number of columns + vram_addr: absolute VRAM address for this view + + Returns: + VRAMMatrixVar proxy object + """ + display_name = name + internal_name = self._scoped_name(name) + self.add_vram_object( + name=internal_name, + shape=(rows, cols), + vram_addr=vram_addr, + allocate_if_none=False, + ) + isa_code = f"; VRAM View {name}: ({rows}, {cols}) at VRAM[{vram_addr}]\n" + self.emit(isa_code) + var = VRAMMatrixVar(self, internal_name, (rows, cols), display_name=display_name) + self._tensors[internal_name] = var + return var + + def free_tensor(self, tensor_var: TensorVar): + """ + Free a tensor in VRAM, reclaiming space for subsequent allocations. + + Freed space can be reused by new alloc() or other operations. + """ + if not isinstance(tensor_var, VRAMMatrixVar): + raise TypeError(f"Can only free VRAMMatrixVar, got {type(tensor_var)}") + + super().free_vram_object(tensor_var.name, strict=False) + # Keep sub-matrix registration state consistent after free. + self._registered_vram_sub_matrices[tensor_var.name] = False + + def free_input(self, input_var: InputVar): + """ + Free an InputVar bookkeeping and recycle its HBM range for future auto-allocation. + + Notes: + - This only affects PlenaCompiler's address management state. + - If a freed input is referenced again later, caller is responsible for correctness. + """ + if not isinstance(input_var, InputVar): + raise TypeError(f"Can only free InputVar, got {type(input_var)}") + + super().free_hbm_object(input_var.name, strict=False) + self._registered_hbm_sub_matrices[input_var.name] = False + self._recycle_hbm(input_var.hbm_addr, input_var.hbm_size) + self._inputs.pop(input_var.name, None) + + def free_fp_var(self, fp_var: FPVar): + """ + Free an FPVar and return its block to FPRAM free pool. + """ + if not isinstance(fp_var, FPVar): + raise TypeError(f"Can only free FPVar, got {type(fp_var)}") + self.free_fpram(fp_var.name, strict=True) + + # ======================================================================== + # Normalization Operations + # ======================================================================== + + def norm( + self, + tensor_var: TensorVar, + mode: str = "rms", + eps_offset: int = 1, + reci_hid_offset: int = 2, + vlen: int | None = None, + scratchpad_vram_addr: int | None = None, + ) -> TensorVar: + """ + Normalize tensor in-place. + + Args: + tensor_var: tensor to normalize (must have VRAM backing, e.g., VRAMMatrixVar) + mode: "rms" or "layer" + eps_offset: FPRAM address of epsilon + reci_hid_offset: FPRAM address of 1/hidden_dim + vlen: vector length (default: program mlen) + scratchpad_vram_addr: optional scratchpad VRAM address + + Returns: + The same tensor_var (in-place operation) + """ + if not isinstance(tensor_var, VRAMMatrixVar): + raise TypeError(f"norm requires VRAMMatrixVar, got {type(tensor_var)}") + + super().normalize( + tensor_name=tensor_var.name, + mode=mode, + eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + vlen=vlen, + scratchpad_vram_addr=scratchpad_vram_addr, + ) + return tensor_var + + def rms_norm( + self, + tensor_var: TensorVar, + eps_offset: int = 1, + reci_hid_offset: int = 2, + vlen: int | None = None, + scratchpad_vram_addr: int | None = None, + ) -> TensorVar: + """RMS normalization (in-place).""" + return self.norm( + tensor_var=tensor_var, + mode="rms", + eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + vlen=vlen, + scratchpad_vram_addr=scratchpad_vram_addr, + ) + + def layer_norm( + self, + tensor_var: TensorVar, + eps_offset: int = 1, + reci_hid_offset: int = 2, + vlen: int | None = None, + scratchpad_vram_addr: int | None = None, + ) -> TensorVar: + """Layer normalization (in-place).""" + return self.norm( + tensor_var=tensor_var, + mode="layer", + eps_offset=eps_offset, + reci_hid_offset=reci_hid_offset, + vlen=vlen, + scratchpad_vram_addr=scratchpad_vram_addr, + ) + + # ======================================================================== + # Composite Decoder Operations + # ======================================================================== + + def ffn(self, input_var: VRAMMatrixVar, w_gate: InputVar, w_up: InputVar, w_down: InputVar): + """Emit the fused FFN kernel and return the in-place activation var.""" + batch_size, hidden_size = input_var.shape + _, inter_dim = w_up.shape + mlen = self.mlen + blen = self.blen + activation_base_address = self.get_vram_addr(input_var.name) + + isa_code = preload_addr_reg_asm( + addr_reg_to_set=[1, 2, 3], + available_registers=[1, 2, 3], + addr_reg_val=[w_gate.hbm_addr, w_up.hbm_addr, w_down.hbm_addr], + ) + isa_code += reset_reg_asm(alive_registers=[1, 2, 3]) + isa_code += ffn_asm( + mlen=mlen, + vlen=mlen, + blen=blen, + batch=batch_size, + seq_len=1, + hidden_size=hidden_size, + intermediate_size=inter_dim, + alive_registers=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + gate_weight_hbm_offset_reg=1, + up_weight_hbm_offset_reg=2, + down_weight_hbm_offset_reg=3, + const_one_fp_address=5, + activation_base_address=activation_base_address, + use_loop_instructions=True, + ) + + self.emit(isa_code) + return input_var + + +__all__ = ["ProgramTensorMixin"] diff --git a/aten/plena/registers.py b/aten/plena/registers.py new file mode 100644 index 0000000..029dba4 --- /dev/null +++ b/aten/plena/registers.py @@ -0,0 +1,66 @@ +"""Register allocation helpers for the ATen PLENA compiler path.""" + +class RegisterAllocator: + """Register Allocator: Manages address registers and GP registers""" + + def __init__(self, start_gp: int = 1, start_addr: int = 0, start_fp: int = 1): + # HW OPERAND_WIDTH = 4 bits → gp0-gp15; gp0 reserved as constant 0. + self.gp_registers = list(range(start_gp, 16)) + self.addr_registers = list(range(start_addr, 8)) + # f0 reserved as constant 0 (writing to f0 is a no-op for V_RED_MAX/V_RED_SUM). + self.fp_registers = list(range(start_fp, 8)) + self.used_gp = [] + self.used_addr = [] + self.used_fp = [] + + def allocate_gp(self, count: int = 1) -> list[int]: + if len(self.gp_registers) < count: + raise RuntimeError(f"Not enough GP registers available. Need {count}, have {len(self.gp_registers)}") + + allocated = self.gp_registers[:count] + self.gp_registers = self.gp_registers[count:] + self.used_gp.extend(allocated) + return allocated + + def allocate_addr(self, count: int = 1) -> list[int]: + if len(self.addr_registers) < count: + raise RuntimeError(f"Not enough address registers available. Need {count}, have {len(self.addr_registers)}") + + allocated = self.addr_registers[:count] + self.addr_registers = self.addr_registers[count:] + self.used_addr.extend(allocated) + return allocated + + def free_gp(self, registers: list[int]): + for reg in registers: + if reg in self.used_gp: + self.used_gp.remove(reg) + self.gp_registers.append(reg) + self.gp_registers.sort() + + def free_addr(self, registers: list[int]): + for reg in registers: + if reg in self.used_addr: + self.used_addr.remove(reg) + self.addr_registers.append(reg) + self.addr_registers.sort() + + def allocate_fp(self, count: int = 1) -> list[int]: + if len(self.fp_registers) < count: + raise RuntimeError(f"Not enough FP registers available. Need {count}, have {len(self.fp_registers)}") + + # Reverse allocation: prefer high-numbered regs to avoid conflicts with legacy hardcoded forward-allocation. + allocated = list(reversed(self.fp_registers[-count:])) + self.fp_registers = self.fp_registers[:-count] + self.used_fp.extend(allocated) + return allocated + + def free_fp(self, registers: list[int]): + for reg in registers: + if reg in self.used_fp: + self.used_fp.remove(reg) + self.fp_registers.append(reg) + # Keep sorted so allocate_fp's tail-slice continues to return descending IDs. + self.fp_registers.sort() + + diff --git a/aten/plena/vars.py b/aten/plena/vars.py new file mode 100644 index 0000000..d6dc65b --- /dev/null +++ b/aten/plena/vars.py @@ -0,0 +1,124 @@ +"""Tensor proxy classes for the ATen PLENA compiler path.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from compiler.aten.plena.compiler import PlenaCompiler + + +class TensorVar: + """ + Tensor proxy object base class + + All tensor variables inherit from this class. + Dual naming: + - display_name: User-visible name (e.g., "temp", "Q", "S") + - internal_name: System internal name, used for symbol table and ISA generation + """ + + def __init__( + self, + program: PlenaCompiler, + internal_name: str, + kind: str, + shape: tuple[int, int], + display_name: str | None = None, + ): + self._program = program + self.internal_name = internal_name # System internal name (with scope prefix), used by symbol table + self.display_name = display_name if display_name is not None else internal_name # User-visible name + self.kind = kind # "input", "batch", "matrix", "vram_matrix" + self.shape = shape + + @property + def name(self) -> str: + """Compatibility property: returns internal_name for internal system use""" + return self.internal_name + + def __repr__(self): + if self.display_name != self.internal_name: + return ( + f"{self.__class__.__name__}(display={self.display_name!r}, " + f"internal={self.internal_name!r}, shape={self.shape})" + ) + return f"{self.__class__.__name__}({self.display_name!r}, shape={self.shape})" + + +class InputVar(TensorVar): + """ + Input variable: tensor declared in HBM + + Not yet loaded to VRAM; needs to be loaded via load_batch / load_matrix. + + If ``prestaged_vram_addr`` is not None the tensor is assumed to be already + present in VRAM at that byte address. ``load_batch`` will register it at + that address without emitting any HBM→VRAM prefetch instructions. + """ + + def __init__( + self, + program: PlenaCompiler, + name: str, + shape: tuple[int, int], + hbm_addr: int, + hbm_size: int, + display_name: str | None = None, + prestaged_vram_addr: int | None = None, + ): + super().__init__(program, name, "input", shape, display_name=display_name) + self.hbm_addr = hbm_addr + self.hbm_size = hbm_size + self.prestaged_vram_addr = prestaged_vram_addr + + +class FPVar: + """ + FP variable: maps to a contiguous region in FPRAM + + Declared via prog.fp_var("scale", size=1), automatically allocates FPRAM space. + Provides .address for ISA generation (S_LD_FP / S_ST_FP). + + Usage: + scale = prog.fp_var("scale", size=1) + m_old = prog.fp_var("m_old", size=64) + + scale.address # -> FPRAM address (int) + scale.size # -> number of elements + scale[3] # -> address + 3 (element offset) + """ + + def __init__( + self, program: PlenaCompiler, internal_name: str, address: int, size: int, display_name: str | None = None + ): + self._program = program + self.internal_name = internal_name + self.display_name = display_name if display_name is not None else internal_name + self.address = address + self.size = size + + @property + def name(self) -> str: + return self.internal_name + + def __getitem__(self, idx: int) -> int: + """Element offset: fp_var[i] -> address + i""" + if idx < 0 or idx >= self.size: + raise IndexError(f"FPVar '{self.display_name}' index {idx} out of range [0, {self.size})") + return self.address + idx + + def __repr__(self): + return f"FPVar({self.display_name!r}, addr={self.address}, size={self.size})" + + +class VRAMMatrixVar(TensorVar): + """ + VRAM matrix variable: large matrix allocated via alloc + + Used to store intermediate results (e.g., S block, PV, O). + Supports sub-block indexed writes: `O[r][c] = ...` + """ + + def __init__(self, program: PlenaCompiler, name: str, shape: tuple[int, int], display_name: str | None = None): + super().__init__(program, name, "vram_matrix", shape, display_name=display_name) diff --git a/aten/plena_compiler.py b/aten/plena_compiler.py deleted file mode 100644 index 550c872..0000000 --- a/aten/plena_compiler.py +++ /dev/null @@ -1,5756 +0,0 @@ -"""PlenaCompiler -- ATen Pipeline (Pipeline 1) compilation backend. - -Manages VRAM/MRAM/FPRAM allocation, HBM weight layout, and address -register initialization (C_SET_ADDR_REG). Produces numerically verified -ISA (98-100% allclose against PyTorch golden reference). - -Contains the full inheritance chain: TileCompiler (memory bookkeeping) -> -DeveloperCompiler (ISA emission, FP/FPRAM ops, interrupts) -> PlenaCompiler -(user-facing DSL). Tensor proxy classes (TensorVar, InputVar, VRAMMatrixVar, -FPVar) and the unified Tensor type union / TensorKind enum are re-exported -from the same module. - -Previously aliased as ``PLENAProgram``; that alias has been retired. - -See docs/COMPILATION_PIPELINES.md for the full architecture overview. -""" - -from __future__ import annotations - -from enum import Enum -from dataclasses import dataclass, field -import math - -import os -from collections.abc import Callable -from functools import wraps - -from compiler.asm_templates import ( - preload_act_asm, - reset_reg_asm, - preload_addr_reg_asm, - store_act_asm, - rms_norm_asm, - layer_norm_asm, - rope_asm, -) -from compiler.asm_templates.vram_sub_projection_asm import vram_sub_projection_asm_impl - - -MLEN = 64 # Minimum matrix block size -BLEN = 4 # Vector tile size -IMM2_BOUND = 2**18 - - -# ============================================================================== -# Virtual Memory Manager -# ============================================================================== - - -@dataclass -class MemoryBlock: - """Memory Block Information""" - - name: str # Allocation name (e.g., "W[0][1]" or "activation_A") - addr: int # Starting address - size: int # Block size (number of elements) - - def __repr__(self) -> str: - return f"MemBlock({self.name}, addr={self.addr}, size={self.size})" - - -class VirtualMemoryManager: - """ - Virtual Memory Manager - - Core Design: - - used_stack: Allocated and in-use memory blocks - - free_stack: Freed and reusable memory blocks - - Workflow: - 1. allocate(name, size): Allocate memory - - Prioritize best-fit search for reusable blocks in free_stack - - If not found, use bump allocation from the end - 2. free(name): Free memory - - Move block from used_stack to free_stack - - Address can be reused by subsequent allocate calls - - Throws exception if not found when strict=True, returns None when strict=False - - VRAM/MRAM storage format: (batch_size, mlen, hidden_size/mlen), column-block major. - mlen=64, blen=4. Alignment depends on storage hierarchy. - """ - - def __init__(self, total_size: int, alignment: int = MLEN, mem_name: str = "Memory"): - """ - Args: - total_size: Total memory size (number of elements) - alignment: alignment granularity (VRAM uses MLEN=64, MRAM uses MLEN*MLEN=4096) - mem_name: Memory name, for debugging information (e.g., "VRAM" or "MRAM") - """ - self.total_size = total_size - self.alignment = alignment - self.mem_name = mem_name - self.next_bump = 0 # Bump allocation pointer - - # Two core stacks - self.used_stack: list[MemoryBlock] = [] - self.free_stack: list[MemoryBlock] = [] - - def _align(self, value: int) -> int: - """Align value to alignment""" - return ((value + self.alignment - 1) // self.alignment) * self.alignment - - def allocate(self, name: str, size: int) -> int: - """ - Allocate memory. - - Strategy: - 1. Best-fit from free_stack (reusable block with least waste). - 2. If no suitable block, bump allocation. - - Returns: - Allocated starting address. - - Raises: - MemoryError: Insufficient memory. - """ - aligned_size = self._align(size) - - # Strategy 1: best-fit from free_stack - best_idx = None - best_waste = float("inf") - - for i, block in enumerate(self.free_stack): - if block.size >= aligned_size: - waste = block.size - aligned_size - if waste < best_waste: - best_waste = waste - best_idx = i - - if best_idx is not None: - reused_block = self.free_stack.pop(best_idx) - - # If block is larger than needed, split remaining part and return to free_stack - if reused_block.size > aligned_size: - remaining = MemoryBlock( - name="", addr=reused_block.addr + aligned_size, size=reused_block.size - aligned_size - ) - self.free_stack.append(remaining) - - new_block = MemoryBlock(name=name, addr=reused_block.addr, size=aligned_size) - self.used_stack.append(new_block) - return new_block.addr - - # Strategy 2: Bump allocation - aligned_addr = self._align(self.next_bump) - - if self.total_size > 0 and aligned_addr + aligned_size > self.total_size: - raise MemoryError( - f"{self.mem_name} overflow: need {aligned_size} at addr {aligned_addr}, " - f"total_size={self.total_size}, " - f"used={len(self.used_stack)} blocks, " - f"free={len(self.free_stack)} blocks" - ) - - new_block = MemoryBlock(name=name, addr=aligned_addr, size=aligned_size) - self.used_stack.append(new_block) - self.next_bump = aligned_addr + aligned_size - return aligned_addr - - def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - """ - Free memory: move block from used_stack to free_stack. - - Args: - name: Name of allocation to free - strict: Throws KeyError if not found when strict=True, returns None when strict=False - - Returns: - Freed memory block, returns None if strict=False and not found - """ - for i, block in enumerate(self.used_stack): - if block.name == name: - freed = self.used_stack.pop(i) - self.free_stack.append(freed) - self._coalesce_free_stack() - return freed - - if strict: - raise KeyError( - f"{self.mem_name}: allocation '{name}' not found in used_stack. " - f"Current used: {[b.name for b in self.used_stack]}" - ) - return None - - def mark_used(self, addr: int, size: int, name: str) -> None: - """ - Register a pre-known address range as occupied without bump allocation. - - Used for prestaged VRAM tensors that are already in VRAM before program - execution (e.g. Q pre-loaded by the test harness). Advances next_bump - past the end of this region so subsequent bump allocations do not collide. - - Args: - addr: Start address of the pre-occupied region. - size: Number of elements in the region. - name: Name for tracking/free. - """ - aligned_size = self._align(size) - block = MemoryBlock(name=name, addr=addr, size=aligned_size) - self.used_stack.append(block) - # Advance bump pointer past this region if it would otherwise overlap. - end = addr + aligned_size - if self.next_bump < end: - self.next_bump = end - - def _coalesce_free_stack(self): - """ - Merge adjacent free blocks by address to reduce long-term fragmentation. - """ - if len(self.free_stack) <= 1: - return - - blocks = sorted(self.free_stack, key=lambda b: b.addr) - merged: list[MemoryBlock] = [blocks[0]] - for block in blocks[1:]: - prev = merged[-1] - if prev.addr + prev.size == block.addr: - merged[-1] = MemoryBlock( - name="", - addr=prev.addr, - size=prev.size + block.size, - ) - else: - merged.append(block) - - self.free_stack = merged - - def is_allocated(self, name: str) -> bool: - """Check if a name is in used_stack""" - return any(b.name == name for b in self.used_stack) - - def get_block(self, name: str) -> MemoryBlock | None: - """Get memory block with specified name from used_stack""" - for block in self.used_stack: - if block.name == name: - return block - return None - - def get_used_size(self) -> int: - """Get total used size""" - return sum(b.size for b in self.used_stack) - - def get_free_size(self) -> int: - """Get total reusable size""" - return sum(b.size for b in self.free_stack) - - def reset(self): - """Reset manager""" - self.next_bump = 0 - self.used_stack.clear() - self.free_stack.clear() - - def print_status(self): - """Print memory status""" - print(f"=== {self.mem_name} Virtual Memory Status ===") - print(f"Total size: {self.total_size}") - print(f"Bump pointer: {self.next_bump}") - print(f"Used blocks ({len(self.used_stack)}):") - for b in self.used_stack: - print(f" {b}") - print(f"Free blocks ({len(self.free_stack)}):") - for b in self.free_stack: - print(f" {b}") - total_used = self.get_used_size() - total_free = self.get_free_size() - if self.total_size > 0: - available = self.total_size - self.next_bump + total_free - print( - f"Summary: used={total_used}, free={total_free}, " - f"bump={self.next_bump}, available={available}/{self.total_size}" - ) - else: - print(f"Summary: used={total_used}, free={total_free}, bump={self.next_bump} (unlimited mode)") - - def __repr__(self) -> str: - return ( - f"VirtualMemoryManager({self.mem_name}, " - f"used={len(self.used_stack)}, free={len(self.free_stack)}, " - f"bump={self.next_bump}/{self.total_size})" - ) - - -# ============================================================================== -# Sub-matrix Information -# ============================================================================== - - -@dataclass -class SubMatrixInfo: - """Metadata for sub-matrices""" - - parent_name: str # Parent matrix name - row_idx: int # Sub-block row index - col_idx: int # Sub-block column index - shape: tuple[int, int] # Sub-block shape (typically 64x64) - - # Pre-calculated addresses (computed during compiler phase, used directly at runtime) - hbm_offset: int = 0 # Offset in HBM (in elements) - mram_addr: int | None = None # Address in MRAM (if loaded) - - def __repr__(self) -> str: - mram_str = f"{self.mram_addr}" if self.mram_addr is not None else "None" - return ( - f"SubMatrix({self.parent_name}[{self.row_idx}][{self.col_idx}], " - f"shape={self.shape}, hbm_off={self.hbm_offset}, mram={mram_str})" - ) - - -@dataclass -class MatrixBlockLayout: - """ - Block layout information for large matrices. - - HBM storage: [rows, cols] row-major contiguous, stride=cols per row. - MRAM storage: [batch, mlen, hidden/mlen] column-block major. - """ - - name: str - full_shape: tuple[int, int] # Full matrix shape (rows, cols) - block_size: int = MLEN # Sub-block size (default 64) - - num_row_blocks: int = 0 - num_col_blocks: int = 0 - - # HBM Address Information - hbm_base_addr: int = 0 - hbm_size: int = 0 # Size after applying real_data_ratio (MXFP8 = 1.125) - - # Sub-block map: (row_idx, col_idx) -> SubMatrixInfo - sub_blocks: dict[tuple[int, int], SubMatrixInfo] = field(default_factory=dict) - - def __post_init__(self): - """Initialize block information""" - rows, cols = self.full_shape - self.num_row_blocks = math.ceil(rows / self.block_size) - self.num_col_blocks = math.ceil(cols / self.block_size) - - # Create information for all sub-blocks (pre-calculate addresses) - for r in range(self.num_row_blocks): - for c in range(self.num_col_blocks): - # HBM offset (row-major): sub-block (r,c) starts at r*block_size*cols + c*block_size - hbm_offset = r * self.block_size * cols + c * self.block_size - - sub_info = SubMatrixInfo( - parent_name=self.name, - row_idx=r, - col_idx=c, - shape=(self.block_size, self.block_size), - hbm_offset=hbm_offset, - mram_addr=None, - ) - self.sub_blocks[(r, c)] = sub_info - - def get_sub_block(self, row_idx: int, col_idx: int) -> SubMatrixInfo: - """Get specified sub-block""" - if (row_idx, col_idx) not in self.sub_blocks: - raise IndexError(f"Sub block [{row_idx}][{col_idx}] out of range") - return self.sub_blocks[(row_idx, col_idx)] - - def get_row_blocks(self, row_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a row""" - return [self.sub_blocks[(row_idx, c)] for c in range(self.num_col_blocks)] - - def get_col_blocks(self, col_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a column""" - return [self.sub_blocks[(r, col_idx)] for r in range(self.num_row_blocks)] - - -# ============================================================================== -# VRAM Sub-matrix Information -# ============================================================================== - - -@dataclass -class VRAMSubMatrixInfo: - """Metadata for sub-matrices in VRAM""" - - parent_name: str # Parent matrix name - row_idx: int # Sub-block row index (batch dimension) - col_idx: int # Sub-block column index (hidden dimension) - shape: tuple[int, int] # Sub-block shape (typically mlen x mlen) - - # Pre-calculated VRAM address - vram_addr: int = 0 - - def __repr__(self) -> str: - return ( - f"VRAMSubMatrix({self.parent_name}[{self.row_idx}][{self.col_idx}], " - f"shape={self.shape}, vram={self.vram_addr})" - ) - - -@dataclass -class VRAMMatrixBlockLayout: - """ - Block layout for large matrices in VRAM. - - VRAM storage format: [batch, mlen, hidden/mlen] column-block major. - - Batch dimension contiguous within each column block. - - Column blocks laid out sequentially. - - Column block c base address: vram_base + c * batch * mlen. - Sub-block (r, c) offset within column block: r * mlen * mlen. - """ - - name: str - full_shape: tuple[int, int] # Full matrix shape (batch, hidden_size) - vram_base_addr: int # VRAM base address - block_size: int = MLEN # Sub-block size (default 64) - - num_row_blocks: int = 0 # Number of blocks in batch dimension - num_col_blocks: int = 0 # Number of blocks in hidden dimension - - # Sub-block map: (row_idx, col_idx) -> VRAMSubMatrixInfo - sub_blocks: dict[tuple[int, int], VRAMSubMatrixInfo] = field(default_factory=dict) - - def __post_init__(self): - """Initialize block information""" - batch, hidden = self.full_shape - self.num_row_blocks = math.ceil(batch / self.block_size) - self.num_col_blocks = math.ceil(hidden / self.block_size) - - # VRAM column-block major address calculation: - # col block c base = vram_base + c * batch * mlen - # row sub-block r offset within column block = r * mlen * mlen - for r in range(self.num_row_blocks): - for c in range(self.num_col_blocks): - col_block_base = self.vram_base_addr + c * batch * self.block_size - row_offset = r * self.block_size * self.block_size - vram_addr = col_block_base + row_offset - - sub_info = VRAMSubMatrixInfo( - parent_name=self.name, - row_idx=r, - col_idx=c, - shape=(self.block_size, self.block_size), - vram_addr=vram_addr, - ) - self.sub_blocks[(r, c)] = sub_info - - def get_sub_block(self, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: - """Get specified sub-block""" - if (row_idx, col_idx) not in self.sub_blocks: - raise IndexError(f"VRAM sub block [{row_idx}][{col_idx}] out of range") - return self.sub_blocks[(row_idx, col_idx)] - - def get_row_blocks(self, row_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a row (A[row_idx][:])""" - return [self.sub_blocks[(row_idx, c)] for c in range(self.num_col_blocks)] - - def get_col_blocks(self, col_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a column""" - return [self.sub_blocks[(r, col_idx)] for r in range(self.num_row_blocks)] - - def get_row_vram_addrs(self, row_idx: int) -> list[int]: - """Get list of VRAM addresses for all sub-blocks in a row""" - return [block.vram_addr for block in self.get_row_blocks(row_idx)] - - -# ============================================================================== -# Unified Memory Object Metadata -# ============================================================================== - - -@dataclass -class MemoryObjectInfo: - """Unified metadata for objects managed across HBM / VRAM / FPRAM.""" - - name: str - kind: str - dtype: str = "fp16" - shape: tuple[int, int] = (0, 0) - size: int = 0 - hbm_addr: int = -1 - hbm_size: int = 0 - vram_addr: int | None = None - fpram_addr: int | None = None - fpram_size: int = 0 - - -@dataclass -class FPRAMObjectLayout: - """FPRAM object layout.""" - - name: str - fpram_addr: int - size: int - dtype: str = "fp16" - kind: str = "FPRAMObject" - - -# ============================================================================== -# MRAM Allocator -# ============================================================================== - - -class MRAMAllocator: - """ - Matrix RAM address allocator (based on VirtualMemoryManager). - - Each sub-block is mlen x mlen = 4096 elements; aligned to mlen*mlen. - Default total_size=16384 holds 4 sub-blocks (MAX_K_TILES=4). - Supports virtual free/reuse: freed blocks move to free_stack and are - preferred by subsequent allocations. - """ - - def __init__(self, total_size: int = MLEN * MLEN * 4): - """ - Args: - total_size: Total MRAM size (default 16384, can hold 4 64x64 matrix blocks) - """ - self.total_size = total_size - self._vmm = VirtualMemoryManager( - total_size=total_size, - alignment=MLEN * MLEN, # aligned to one sub-block size - mem_name="MRAM", - ) - - @property - def next_free(self) -> int: - return self._vmm.next_bump - - @property - def used_stack(self) -> list[MemoryBlock]: - return self._vmm.used_stack - - @property - def free_stack(self) -> list[MemoryBlock]: - return self._vmm.free_stack - - def allocate(self, name: str, size: int) -> int: - """Allocate MRAM space (prioritize reusing freed blocks).""" - return self._vmm.allocate(name, size) - - def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - """Free specified allocation: move from used_stack to free_stack.""" - return self._vmm.free(name, strict=strict) - - def is_allocated(self, name: str) -> bool: - """Check if a name is allocated""" - return self._vmm.is_allocated(name) - - def reset(self): - """Reset allocator""" - self._vmm.reset() - - def print_status(self): - """Print memory status""" - self._vmm.print_status() - - -class VRAMAllocator: - """ - VRAM address allocator (based on VirtualMemoryManager). - - VRAM supports best-fit reuse + bump allocation, same as MRAM/FPRAM allocators. - Alignment defaults to MLEN to match VRAM storage format requirements. - """ - - def __init__(self, alignment: int = MLEN, total_size: int = 0): - self.alignment = alignment - self._vmm = VirtualMemoryManager(total_size=total_size, alignment=alignment, mem_name="VRAM") - - @property - def next_free(self) -> int: - return self._vmm.next_bump - - @next_free.setter - def next_free(self, value: int): - self._vmm.next_bump = value - - @property - def used_stack(self) -> list[MemoryBlock]: - return self._vmm.used_stack - - @property - def free_stack(self) -> list[MemoryBlock]: - return self._vmm.free_stack - - def allocate(self, size: int, name: str = "") -> int: - if not name: - raise ValueError("VRAMAllocator.allocate() requires name for subsequent free.") - return self._vmm.allocate(name, size) - - def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - return self._vmm.free(name, strict=strict) - - def is_allocated(self, name: str) -> bool: - return self._vmm.is_allocated(name) - - def reset(self): - self._vmm.reset() - - def print_status(self): - self._vmm.print_status() - - def __repr__(self) -> str: - return ( - f"VRAMAllocator(next_free={self.next_free}, alignment={self.alignment}, " - f"used={len(self.used_stack)}, free={len(self.free_stack)})" - ) - - -class FPRAMAllocator: - """ - Floating Point RAM Allocator (based on VirtualMemoryManager). - - FPRAM stores scalar FP values (f16), accessed via S_LD_FP / S_ST_FP. - Uses the same strategy as VRAM/MRAM allocator: - - used_stack + free_stack - - allocate: best-fit reuse first, then bump - - free: move block to free_stack, supports out-of-order free - - Fixed slot conventions (must not be overwritten by dynamic allocations): - slot 0 = 0.0 (gp0/f0 reserved as hardware zero) - slot 1 = attn_scale - slot 2 = -inf (online softmax) - slot 3 = eps (rms_norm/layer_norm) - slot 4 = 1/hidden_size (rms_norm/layer_norm) - slot 5 = 1.0 (FFN SiLU sigmoid denominator; im2col fp_one_reg) - - Hardware: 1024 f16 elements (configurable via total_size). - """ - - def __init__(self, total_size: int = 1024): - """ - Args: - total_size: Total FP RAM size (default 1024, matching hardware fpsram) - """ - self.total_size = total_size - self._vmm = VirtualMemoryManager( - total_size=total_size, - alignment=1, - mem_name="FPRAM", - ) - self.allocations: dict[str, tuple[int, int]] = {} - self._snapshots: dict[int, tuple[int, list[MemoryBlock], list[MemoryBlock], dict[str, tuple[int, int]]]] = {} - self._next_snapshot_id = 1 - - @property - def next_free(self) -> int: - """Compatibility alias: next bump pointer.""" - return self._vmm.next_bump - - @next_free.setter - def next_free(self, value: int): - if value < 0 or value > self.total_size: - raise ValueError(f"next_free out of range: {value}, expected [0, {self.total_size}]") - self._vmm.next_bump = value - - @property - def used_stack(self) -> list[MemoryBlock]: - return self._vmm.used_stack - - @property - def free_stack(self) -> list[MemoryBlock]: - return self._vmm.free_stack - - def allocate(self, name: str, size: int) -> int: - """Allocate FP RAM space (best-fit + bump).""" - if size <= 0: - raise ValueError(f"FPRAM allocation size must be > 0, got {size}") - if name in self.allocations: - raise KeyError(f"FPRAM name '{name}' already allocated") - - addr = self._vmm.allocate(name, size) - self.allocations[name] = (addr, size) - return addr - - def free(self, name: str, strict: bool = True) -> MemoryBlock | None: - """Free a block and move it to free_stack (same as VirtualMemoryManager).""" - freed = self._vmm.free(name, strict=strict) - if freed is not None: - self.allocations.pop(name, None) - return freed - - def save_state(self) -> int: - """ - Save current allocator state and return a snapshot token. - """ - sid = self._next_snapshot_id - self._next_snapshot_id += 1 - self._snapshots[sid] = ( - self._vmm.next_bump, - [MemoryBlock(b.name, b.addr, b.size) for b in self._vmm.used_stack], - [MemoryBlock(b.name, b.addr, b.size) for b in self._vmm.free_stack], - dict(self.allocations), - ) - return sid - - def restore_state(self, snapshot: int): - """Restore allocator state from snapshot token.""" - if snapshot not in self._snapshots: - raise KeyError(f"Unknown FPRAM snapshot id: {snapshot}") - next_bump, used_stack, free_stack, allocations = self._snapshots[snapshot] - self._vmm.next_bump = next_bump - self._vmm.used_stack = [MemoryBlock(b.name, b.addr, b.size) for b in used_stack] - self._vmm.free_stack = [MemoryBlock(b.name, b.addr, b.size) for b in free_stack] - self.allocations = dict(allocations) - - def reset(self): - """Reset allocator""" - self._vmm.reset() - self.allocations.clear() - self._snapshots.clear() - self._next_snapshot_id = 1 - - -# ============================================================================== -# Sub Matrix Manager -# ============================================================================== - - -class TileCompiler: - """ - Tile compiler / sub-matrix manager. - - Core Functions: - 1. Register large matrices as blocked matrices. - 2. Support sub-block indexing: matrix[row_idx][col_idx] or matrix[row_idx][:]. - 3. Pre-calculate all addresses at compiler phase. - 4. Generate ISA code for load_sub_matrix and sub_projection. - - Key Constraints: - - Minimum block size is 64x64 (MLEN). - - Matrix must be loaded into MRAM before participating in computation. - - HBM and VRAM/MRAM use different storage formats (row-major vs column-block major). - """ - - def __init__(self, mlen: int = MLEN, blen: int = BLEN, unroll_loops: bool = False): - self.mlen = mlen - self.blen = blen - self.unroll_loops = unroll_loops - - # Layout tables - self.hbm_matrices: dict[str, MatrixBlockLayout] = {} - self.vram_matrices: dict[str, VRAMMatrixBlockLayout] = {} - self.fpram_matrices: dict[str, FPRAMObjectLayout] = {} - # Memory Allocators - self.vram_allocator = VRAMAllocator() - self.mram_allocator = MRAMAllocator() - self.fpram_allocator = FPRAMAllocator() - - # Currently loaded sub-blocks in MRAM - self.loaded_sub_blocks: dict[str, SubMatrixInfo] = {} - - # Pre-calculated address cache - self._address_cache: dict[str, int] = {} - - def __contains__(self, name: str) -> bool: - return name in self.hbm_matrices or name in self.vram_matrices or name in self.fpram_matrices - - def __getitem__(self, name: str) -> MemoryObjectInfo: - if name not in self: - raise KeyError(f"Object '{name}' not found") - info = MemoryObjectInfo(name=name, kind="Unknown") - hbm_layout = self.hbm_matrices.get(name) - vram_layout = self.vram_matrices.get(name) - fpram_layout = self.fpram_matrices.get(name) - - if hbm_layout is not None: - rows, cols = hbm_layout.full_shape - info.shape = hbm_layout.full_shape - info.size = rows * cols - info.hbm_addr = hbm_layout.hbm_base_addr - info.hbm_size = hbm_layout.hbm_size - info.kind = "Matrix" - - if vram_layout is not None: - rows, cols = vram_layout.full_shape - info.shape = vram_layout.full_shape - info.size = rows * cols - info.vram_addr = vram_layout.vram_base_addr - info.kind = "VRAMMatrix" if hbm_layout is None else "Batch" - - if fpram_layout is not None: - info.shape = (1, fpram_layout.size) - info.size = fpram_layout.size - info.fpram_addr = fpram_layout.fpram_addr - info.fpram_size = fpram_layout.size - info.kind = "FPRAMObject" - - return info - - def get(self, name: str, default: MemoryObjectInfo | None = None) -> MemoryObjectInfo | None: - try: - return self[name] - except KeyError: - return default - - def get_hbm_layout(self, name: str) -> MatrixBlockLayout: - """Read HBM matrix layout by name.""" - if name not in self.hbm_matrices: - raise KeyError(f"HBM matrix '{name}' not found") - return self.hbm_matrices[name] - - def get_vram_layout(self, name: str) -> VRAMMatrixBlockLayout: - """Read VRAM matrix layout by name.""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not found") - return self.vram_matrices[name] - - def get_fpram_layout(self, name: str) -> FPRAMObjectLayout: - """Read FPRAM object layout by name.""" - if name not in self.fpram_matrices: - raise KeyError(f"FPRAM object '{name}' not found") - return self.fpram_matrices[name] - - # ========================================================================== - # Unified Object Management APIs - # ========================================================================== - - def add_hbm_object( - self, - name: str, - shape: tuple[int, int], - hbm_addr: int, - dtype: str = "fp16", - kind: str = "HBMObject", - real_data_ratio: float = 1.125, - strict: bool = False, - ) -> MemoryObjectInfo: - del dtype, kind - self.register_matrix( - name=name, - shape=shape, - hbm_base_addr=hbm_addr, - real_data_ratio=real_data_ratio, - strict=strict, - ) - return self[name] - - def free_hbm_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: - if name not in self.hbm_matrices: - if strict: - raise KeyError(f"HBM object '{name}' not found") - return None - info = self[name] - self.hbm_matrices.pop(name, None) - return info - - def add_vram_object( - self, - name: str, - shape: tuple[int, int], - vram_addr: int | None = None, - dtype: str = "fp16", - kind: str = "VRAMObject", - allocate_if_none: bool = True, - strict: bool = True, - ) -> MemoryObjectInfo: - rows, cols = shape - size = rows * cols - if vram_addr is None: - if not allocate_if_none: - raise ValueError("vram_addr is None and allocate_if_none is False") - vram_addr = self.vram_allocator.allocate(size=size, name=name) - del dtype, kind - self.register_vram_matrix( - name=name, - shape=shape, - vram_base_addr=vram_addr, - strict=strict, - ) - return self[name] - - def free_vram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: - if name not in self.vram_matrices: - if strict: - raise KeyError(f"VRAM object '{name}' not found") - return None - info = self[name] - self.vram_allocator.free(name, strict=strict) - self.vram_matrices.pop(name, None) - return info - - def add_fpram_object( - self, - name: str, - size: int, - dtype: str = "fp16", - kind: str = "FPRAMObject", - ) -> MemoryObjectInfo: - fpram_addr = self.fpram_allocator.allocate(name, size) - self.fpram_matrices[name] = FPRAMObjectLayout( - name=name, - fpram_addr=fpram_addr, - size=size, - dtype=dtype, - kind=kind, - ) - return self[name] - - def free_fpram_object(self, name: str, strict: bool = True) -> MemoryObjectInfo | None: - if name not in self.fpram_matrices: - if strict: - raise KeyError(f"FPRAM object '{name}' not found") - return None - info = self[name] - self.fpram_allocator.free(name, strict=strict) - self.fpram_matrices.pop(name, None) - return info - - def register_matrix( - self, - name: str, - shape: tuple[int, int], - hbm_base_addr: int, - real_data_ratio: float = 1.125, - strict: bool = True, - ) -> MatrixBlockLayout: - """ - Register a large matrix and automatically block it. - - Args: - name: matrix name - shape: full shape (rows, cols) - hbm_base_addr: HBM base address - real_data_ratio: HBM storage ratio (MXFP8 = 1.125) - strict: if False, skip mlen-alignment check (for raw HBM access) - - Returns: - MatrixBlockLayout object - """ - rows, cols = shape - - if strict: - if rows % self.mlen != 0: - raise ValueError(f"Matrix rows ({rows}) must be multiple of mlen ({self.mlen})") - if cols % self.mlen != 0: - raise ValueError(f"Matrix cols ({cols}) must be multiple of mlen ({self.mlen})") - - size = rows * cols - hbm_size = int(size * real_data_ratio) - - layout = MatrixBlockLayout( - name=name, full_shape=shape, block_size=self.mlen, hbm_base_addr=hbm_base_addr, hbm_size=hbm_size - ) - - self.hbm_matrices[name] = layout - return layout - - def get_sub_block(self, name: str, row_idx: int, col_idx: int) -> SubMatrixInfo: - """Get sub-block information.""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - return self.hbm_matrices[name].get_sub_block(row_idx, col_idx) - - def get_row_blocks(self, name: str, row_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a row: matrix[row_idx][:]""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - return self.hbm_matrices[name].get_row_blocks(row_idx) - - def get_col_blocks(self, name: str, col_idx: int) -> list[SubMatrixInfo]: - """Get all sub-blocks in a column: matrix[:][col_idx]""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - return self.hbm_matrices[name].get_col_blocks(col_idx) - - # ========================================================================== - # VRAM Sub-matrix Management - # ========================================================================== - - def register_vram_matrix( - self, - name: str, - shape: tuple[int, int], - vram_base_addr: int, - strict: bool = True, - ) -> VRAMMatrixBlockLayout: - """ - Register a large matrix in VRAM and automatically block it. - - Args: - name: matrix name - shape: full shape (batch, hidden_size) - vram_base_addr: VRAM base address - - Returns: - VRAMMatrixBlockLayout object - """ - batch, hidden = shape - - if strict: - if batch % self.mlen != 0: - raise ValueError(f"VRAM matrix batch ({batch}) must be multiple of mlen ({self.mlen})") - if hidden % self.mlen != 0: - raise ValueError(f"VRAM matrix hidden ({hidden}) must be multiple of mlen ({self.mlen})") - - layout = VRAMMatrixBlockLayout(name=name, full_shape=shape, vram_base_addr=vram_base_addr, block_size=self.mlen) - - self.vram_matrices[name] = layout - return layout - - def get_vram_sub_block(self, name: str, row_idx: int, col_idx: int) -> VRAMSubMatrixInfo: - """Get VRAM sub-block information""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not registered") - return self.vram_matrices[name].get_sub_block(row_idx, col_idx) - - def get_vram_row_blocks(self, name: str, row_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a row of VRAM matrix: matrix[row_idx][:]""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not registered") - return self.vram_matrices[name].get_row_blocks(row_idx) - - def get_vram_col_blocks(self, name: str, col_idx: int) -> list[VRAMSubMatrixInfo]: - """Get all sub-blocks in a column of VRAM matrix: matrix[:][col_idx]""" - if name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{name}' not registered") - return self.vram_matrices[name].get_col_blocks(col_idx) - - # ========================================================================== - # Address Calculation (Core - Pre-calculated during compiler phase) - # ========================================================================== - - def compute_hbm_offset(self, name: str, row_idx: int, col_idx: int) -> int: - """ - Compute HBM offset for sub-block (in elements, not bytes). - - HBM row-major: sub-block (r, c) starts at r*block_size*full_cols + c*block_size. - """ - layout = self.hbm_matrices[name] - sub_block = layout.get_sub_block(row_idx, col_idx) - return sub_block.hbm_offset - - def compute_absolute_hbm_addr(self, name: str, row_idx: int, col_idx: int) -> int: - """ - Calculate absolute HBM address of sub-block (in elements). - - Returns: - Absolute HBM address = base + offset - """ - layout = self.hbm_matrices[name] - offset = self.compute_hbm_offset(name, row_idx, col_idx) - return layout.hbm_base_addr + offset - - # ========================================================================== - # ISA Generation: Load Sub Matrix - # ========================================================================== - - def load_sub_matrix_asm( - self, - name: str, - row_idx: int, - col_idx: int, - mram_dest_addr: int, - hbm_addr_reg: int = 1, - gp_regs: list[int] | None = None, - ) -> str: - """ - Generate ISA code for loading sub-matrix from HBM to MRAM. - - HBM is row-major; H_PREFETCH_M loads mlen x mlen blocks into MRAM. - SCALE = full matrix element count; STRIDE = full column width (row stride). - """ - if gp_regs is None: - gp_regs = [1, 2, 3] - - layout = self.hbm_matrices[name] - sub_block = layout.get_sub_block(row_idx, col_idx) - - hbm_offset = sub_block.hbm_offset - sub_block.mram_addr = mram_dest_addr - - lines = [] - lines.append(f"; Load SubMatrix {name}[{row_idx}][{col_idx}] -> MRAM[{mram_dest_addr}]") - lines.append(f"; HBM offset: {hbm_offset} (precomputed)") - - full_size = layout.full_shape[0] * layout.full_shape[1] - full_cols = layout.full_shape[1] - - gp_scale = gp_regs[0] - gp_stride = gp_regs[1] - gp_mram = gp_regs[2] - - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {full_size}") - lines.append(f"C_SET_SCALE_REG gp{gp_scale}") - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {full_cols}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - lines.append(f"S_ADDI_INT gp{gp_mram}, gp0, {mram_dest_addr}") - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {hbm_offset}") - - lines.append(f"H_PREFETCH_M gp{gp_mram}, gp{gp_scale}, a{hbm_addr_reg}, 1, 0") - - block_key = f"{name}[{row_idx}][{col_idx}]" - self.loaded_sub_blocks[block_key] = sub_block - - return "\n".join(lines) + "\n" - - def load_row_sub_matrices_asm( - self, - name: str, - row_idx: int, - mram_start_addr: int, - hbm_addr_reg: int = 1, - gp_regs: list[int] | None = None, - ) -> str: - """ - Generate ISA code for loading all sub-blocks in a row: matrix[row_idx][:] - - SCALE and STRIDE set once for all sub-blocks in the row. - """ - if gp_regs is None: - gp_regs = [1, 2, 3] - - layout = self.hbm_matrices[name] - num_col_blocks = layout.num_col_blocks - - lines = [] - lines.append(f"; Load SubMatrix Row {name}[{row_idx}][:] -> MRAM[{mram_start_addr}]") - - # Set SCALE and STRIDE once for all sub-blocks - full_size = layout.full_shape[0] * layout.full_shape[1] - full_cols = layout.full_shape[1] - - gp_scale = gp_regs[0] - gp_stride = gp_regs[1] - gp_mram = gp_regs[2] - - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {full_size}") - lines.append(f"C_SET_SCALE_REG gp{gp_scale}") - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {full_cols}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - mram_addr = mram_start_addr - block_size = self.mlen * self.mlen - - for col_idx in range(num_col_blocks): - sub_block = layout.get_sub_block(row_idx, col_idx) - hbm_offset = sub_block.hbm_offset - - sub_block.mram_addr = mram_addr - - lines.append(f"; SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") - lines.append(f"S_ADDI_INT gp{gp_mram}, gp0, {mram_addr}") - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {hbm_offset}") - lines.append(f"H_PREFETCH_M gp{gp_mram}, gp{gp_scale}, a{hbm_addr_reg}, 1, 0") - - block_key = f"{name}[{row_idx}][{col_idx}]" - self.loaded_sub_blocks[block_key] = sub_block - - mram_addr += block_size - - return "\n".join(lines) + "\n" - - def load_col_sub_matrices_asm( - self, - name: str, - col_idx: int, - mram_start_addr: int, - hbm_addr_reg: int = 1, - gp_regs: list[int] | None = None, - k_block_start: int = 0, - k_block_count: int | None = None, - ) -> str: - """ - Generate ISA code for loading all sub-blocks in a column: matrix[:][col_idx]. - - Used for sub_projection: A @ W[:, col_idx*mlen:(col_idx+1)*mlen]. - k_block_start/k_block_count select a K-split slice of the column. - """ - if gp_regs is None: - gp_regs = [1, 2, 3] - - layout = self.hbm_matrices[name] - num_row_blocks = layout.num_row_blocks - - lines = [] - lines.append(f"; Load SubMatrix Col {name}[:][{col_idx}] -> MRAM[{mram_start_addr}]") - - # Set SCALE and STRIDE once for all sub-blocks - full_size = layout.full_shape[0] * layout.full_shape[1] - full_cols = layout.full_shape[1] - - gp_scale = gp_regs[0] - gp_stride = gp_regs[1] - gp_mram = gp_regs[2] - - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {full_size}") - lines.append(f"C_SET_SCALE_REG gp{gp_scale}") - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {full_cols}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - mram_addr = mram_start_addr - block_size = self.mlen * self.mlen - - effective_count = k_block_count if k_block_count is not None else num_row_blocks - for row_idx in range(k_block_start, k_block_start + effective_count): - sub_block = layout.get_sub_block(row_idx, col_idx) - hbm_offset = sub_block.hbm_offset - - sub_block.mram_addr = mram_addr - - lines.append(f"; SubBlock [{row_idx}][{col_idx}]: HBM offset = {hbm_offset}") - lines.append(f"S_ADDI_INT gp{gp_mram}, gp0, {mram_addr}") - lines.append(f"S_ADDI_INT gp{gp_scale}, gp0, {hbm_offset}") - lines.append(f"H_PREFETCH_M gp{gp_mram}, gp{gp_scale}, a{hbm_addr_reg}, 1, 0") - - block_key = f"{name}[{row_idx}][{col_idx}]" - self.loaded_sub_blocks[block_key] = sub_block - - mram_addr += block_size - - return "\n".join(lines) + "\n" - - # ========================================================================== - # ISA Generation: Sub Projection - # ========================================================================== - - def _vram_sub_projection_asm_impl( - self, - header_lines: list[str], - vram_row_start_addr: int, - mram_start_addr: int, - result_vram_addr: int, - full_batch: int, - num_hidden_blocks: int, - mat_col_stride: int, - transposed: bool, - gp_regs: list[int], - caller_name: str, - unroll: bool | None = None, - ) -> str: - """ - Shared implementation kernel for vram_sub_projection_asm and - vram_sub_projection_T_asm. - - Parameters resolved by the caller before this point: - header_lines -- comment lines already assembled by the caller - vram_row_start_addr -- VRAM address of the first activation block - mram_start_addr -- MRAM address of the first weight block - result_vram_addr -- VRAM destination address for the (mlen, mlen) result - full_batch -- full batch dimension of the activation VRAM matrix - num_hidden_blocks -- number of K-blocks to accumulate over - mat_col_stride -- MRAM outer-column stride (blen for M_MM, blen*mlen for M_TMM) - transposed -- True → emit M_TMM with (act, mat) operand order; - False → emit M_MM with (mat, act) operand order - gp_regs -- list of at least 9 GP register indices - caller_name -- used in error messages only - unroll -- override instance unroll_loops flag (None = use instance default) - - Returns: - ISA code string. - """ - do_unroll = self.unroll_loops if unroll is None else unroll - return vram_sub_projection_asm_impl( - mlen=self.mlen, - blen=self.blen, - unroll_loops=do_unroll, - header_lines=header_lines, - vram_row_start_addr=vram_row_start_addr, - mram_start_addr=mram_start_addr, - result_vram_addr=result_vram_addr, - full_batch=full_batch, - num_hidden_blocks=num_hidden_blocks, - mat_col_stride=mat_col_stride, - transposed=transposed, - gp_regs=gp_regs, - caller_name=caller_name, - ) - - def vram_sub_projection_asm( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_col_idx: int, - result_vram_addr: int, - gp_regs: list[int] | None = None, - k_block_start: int = 0, - k_block_count: int | None = None, - unroll: bool | None = None, - ) -> str: - """ - Generate ISA for VRAM sub-block × MRAM sub-matrix multiply. - - Computes: result = VRAM_A[row_idx][:] @ MRAM_W[:][col_idx] - - VRAM_A[row_idx][:] is (mlen, hidden_size) spread across multiple (mlen, mlen) column blocks. - MRAM_W[:][col_idx] is (hidden_size, mlen) already loaded into MRAM. - Result is (mlen, mlen) written to VRAM. - - Loop structure (outer=output cols by blen, middle=output rows by blen, inner=K accumulation). - k_block_start/k_block_count select a K-split chunk. - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5, 6, 7, 8, 9] - - if vram_mat_name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") - vram_layout = self.vram_matrices[vram_mat_name] - - mram_layout = self.hbm_matrices[mram_mat_name] - - vram_row_blocks = vram_layout.get_row_blocks(vram_row_idx) - mram_col_blocks = mram_layout.get_col_blocks(mram_col_idx) - # K-split: slice to only the loaded k-chunk - if k_block_count is not None: - mram_col_blocks = mram_col_blocks[k_block_start : k_block_start + k_block_count] - - num_hidden_blocks = len(mram_col_blocks) - if num_hidden_blocks != (k_block_count if k_block_count is not None else len(vram_row_blocks)): - raise ValueError( - f"Dimension mismatch: expected {k_block_count or len(vram_row_blocks)} MRAM blocks, " - f"got {num_hidden_blocks}" - ) - - for sub_block in mram_col_blocks: - if sub_block.mram_addr is None: - raise RuntimeError(f"SubBlock {mram_mat_name}[{sub_block.row_idx}][{mram_col_idx}] not loaded to MRAM") - - full_batch = vram_layout.full_shape[0] - vram_row_start_addr = vram_row_blocks[k_block_start].vram_addr - mram_col_start_addr = mram_col_blocks[0].mram_addr - - header_lines = [ - f"; VRAM Sub Projection: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}]", - f"; VRAM A[row_idx][:]: ({self.mlen}, hidden) spread across {num_hidden_blocks} column blocks", - f"; MRAM W[:][col_idx]: (hidden, {self.mlen}) with {num_hidden_blocks} sub-blocks", - f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", - ] - - return self._vram_sub_projection_asm_impl( - header_lines=header_lines, - vram_row_start_addr=vram_row_start_addr, - mram_start_addr=mram_col_start_addr, - result_vram_addr=result_vram_addr, - full_batch=full_batch, - num_hidden_blocks=num_hidden_blocks, - mat_col_stride=self.blen, - transposed=False, - gp_regs=gp_regs, - caller_name="vram_sub_projection_asm", - unroll=unroll, - ) - - def vram_sub_projection_T_asm( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_row_idx: int, - result_vram_addr: int, - gp_regs: list[int] | None = None, - unroll: bool | None = None, - ) -> str: - """ - Generate ISA for VRAM sub-block × MRAM sub-matrix transposed multiply. - - Computes: result = VRAM_A[row_idx][:] @ MRAM_W[row_idx][:]^T - - VRAM_A[row_idx][:] is (mlen, hidden_size). - MRAM_W[row_idx][:] is (mlen, hidden_size); transposed to (hidden_size, mlen). - Result is (mlen, mlen) written to VRAM. - - Uses M_TMM instruction. mat_col_stride = blen*mlen (transposed addressing). - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5, 6, 7, 8, 9] - - if vram_mat_name not in self.vram_matrices: - raise KeyError(f"VRAM matrix '{vram_mat_name}' not registered") - vram_layout = self.vram_matrices[vram_mat_name] - - mram_layout = self.hbm_matrices[mram_mat_name] - - vram_row_blocks = vram_layout.get_row_blocks(vram_row_idx) - mram_row_blocks = mram_layout.get_row_blocks(mram_row_idx) - - if len(vram_row_blocks) != len(mram_row_blocks): - raise ValueError( - f"Dimension mismatch: VRAM has {len(vram_row_blocks)} blocks, MRAM has {len(mram_row_blocks)} blocks" - ) - - num_hidden_blocks = len(vram_row_blocks) - - for sub_block in mram_row_blocks: - if sub_block.mram_addr is None: - raise RuntimeError(f"SubBlock {mram_mat_name}[{mram_row_idx}][{sub_block.col_idx}] not loaded to MRAM") - - full_batch = vram_layout.full_shape[0] - vram_row_start_addr = vram_row_blocks[0].vram_addr - mram_row_start_addr = mram_row_blocks[0].mram_addr - # NOTE: For M_TMM (transposed matmul), the MRAM outer-column stride is - # blen * mlen (full sub-block size) because M_TMM reads the weight in - # transposed layout. This differs from the non-transposed path which uses - # blen. This is intentional for the M_TMM addressing contract. - mat_col_stride = self.blen * self.mlen - - header_lines = [ - f"; VRAM Sub Projection T: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T", - f"; VRAM A[row_idx][:]: ({self.mlen}, hidden)", - f"; MRAM W[row_idx][:]^T: (hidden, {self.mlen})", - f"; Result: ({self.mlen}, {self.mlen}) at VRAM[{result_vram_addr}]", - ] - - return self._vram_sub_projection_asm_impl( - header_lines=header_lines, - vram_row_start_addr=vram_row_start_addr, - mram_start_addr=mram_row_start_addr, - result_vram_addr=result_vram_addr, - full_batch=full_batch, - num_hidden_blocks=num_hidden_blocks, - mat_col_stride=mat_col_stride, - transposed=True, - gp_regs=gp_regs, - caller_name="vram_sub_projection_T_asm", - unroll=unroll, - ) - - def vram_block_add_asm( - self, - src1_name: str, - src1_row_idx: int, - src1_col_idx: int, - src2_name: str, - src2_row_idx: int, - src2_col_idx: int, - target_name: str, - target_row_idx: int, - target_col_idx: int, - gp_regs: list[int] | None = None, - ) -> str: - """ - Add two mlen x mlen blocks and write to any target block: - target[target_row_idx][target_col_idx] = - src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] - - Source/target can be the same matrix (supports in-place overwrite). - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4] - if len(gp_regs) < 4: - raise ValueError(f"Need at least 4 GP regs, got {len(gp_regs)}") - - src1_block = self.get_vram_sub_block(src1_name, src1_row_idx, src1_col_idx) - src2_block = self.get_vram_sub_block(src2_name, src2_row_idx, src2_col_idx) - target_block = self.get_vram_sub_block(target_name, target_row_idx, target_col_idx) - - gp_dst = gp_regs[0] - gp_src1 = gp_regs[1] - gp_src2 = gp_regs[2] - gp_loop = gp_regs[3] - - lines = [] - lines.append( - f"; VRAM Block Add: {target_name}[{target_row_idx}][{target_col_idx}] = " - f"{src1_name}[{src1_row_idx}][{src1_col_idx}] + {src2_name}[{src2_row_idx}][{src2_col_idx}]" - ) - - # One V_ADD_VV processes one row (mlen elements). Use C_LOOP to reduce ISA size. - if self.unroll_loops: - for i in range(self.mlen): - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr + i * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr + i * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr + i * self.mlen}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") - else: - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {target_block.vram_addr}") - lines.append(f"S_ADDI_INT gp{gp_src1}, gp0, {src1_block.vram_addr}") - lines.append(f"S_ADDI_INT gp{gp_src2}, gp0, {src2_block.vram_addr}") - lines.append(f"C_LOOP_START gp{gp_loop}, {self.mlen}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_src1}, gp{gp_src2}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src1}, gp{gp_src1}, {self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src2}, gp{gp_src2}, {self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - - return "\n".join(lines) + "\n" - - # ========================================================================== - # High-level Interface: Complete Sub-block Computation - # ========================================================================== - - def compute_sub_matmul( - self, - a_name: str, - a_row_idx: int | slice, - b_name: str, - b_col_idx: int | slice, - result_name: str, - transpose_b: bool = False, - ) -> tuple[str, int]: - """Not yet implemented. Use vram_sub_projection_asm or vram_sub_projection_T_asm.""" - raise NotImplementedError( - "compute_sub_matmul is not yet implemented. Use vram_sub_projection_asm " - "or vram_sub_projection_T_asm for matrix multiplication." - ) - - # ========================================================================== - # Format Conversion: HBM <-> VRAM - # ========================================================================== - - def load_activation_with_format_convert_asm( - self, - name: str, - hbm_base_addr: int, - batch: int, - hidden_size: int, - vram_dest_addr: int, - hbm_addr_reg: int = 0, - gp_regs: list[int] | None = None, - ) -> str: - """ - Load activation from HBM to VRAM with format conversion. - - HBM layout: [batch, hidden_size] row-major (element[b,h] at hbm_base + b*hidden_size + h). - VRAM layout: [batch, mlen, hidden/mlen] column-block major. - element[b,h]: vram_base + (h//mlen)*batch*mlen + b*mlen + (h%mlen). - - H_PREFETCH_V loads mlen elements per call; columns loaded in blocks of mlen. - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5] - - lines = [] - lines.append(f"; Load Activation with Format Convert: {name}") - lines.append(f"; HBM[{hbm_base_addr}]: [batch={batch}, hidden={hidden_size}] row-major") - lines.append(f"; VRAM[{vram_dest_addr}]: [batch, mlen, hidden/mlen] column-block major") - - gp_hbm_offset = gp_regs[0] - gp_stride = gp_regs[1] - gp_vram = gp_regs[2] - _gp_outer = gp_regs[3] - _gp_inner = gp_regs[4] - - num_col_blocks = hidden_size // self.mlen - preload_len = 4 # load 4 rows per H_PREFETCH_V call - - total_size = batch * hidden_size - lines.append(f"S_ADDI_INT gp{gp_hbm_offset}, gp0, {total_size}") - lines.append(f"C_SET_SCALE_REG gp{gp_hbm_offset}") - - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {hidden_size}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - for col_block in range(num_col_blocks): - lines.append(f"; Column block {col_block}") - - hbm_offset = col_block * self.mlen - vram_addr = vram_dest_addr + col_block * batch * self.mlen - - lines.append(f"S_ADDI_INT gp{gp_hbm_offset}, gp0, {hbm_offset}") - lines.append(f"S_ADDI_INT gp{gp_vram}, gp0, {vram_addr}") - - for batch_block in range(math.ceil(batch / preload_len)): - actual_batch_offset = batch_block * preload_len * hidden_size - actual_vram_offset = batch_block * preload_len * self.mlen - - lines.append(f"S_ADDI_INT gp{gp_hbm_offset}, gp0, {hbm_offset + actual_batch_offset}") - lines.append(f"S_ADDI_INT gp{gp_vram}, gp0, {vram_addr + actual_vram_offset}") - lines.append(f"H_PREFETCH_V gp{gp_vram}, gp{gp_hbm_offset}, a{hbm_addr_reg}, 1, 0") - - return "\n".join(lines) + "\n" - - def store_activation_with_format_convert_asm( - self, - name: str, - vram_src_addr: int, - batch: int, - hidden_size: int, - hbm_dest_addr: int, - hbm_addr_reg: int = 0, - gp_regs: list[int] | None = None, - ) -> str: - """ - Store activation from VRAM to HBM with format conversion. - - VRAM layout: [batch, mlen, hidden/mlen] column-block major. - HBM layout: [batch, hidden_size] row-major. - - H_STORE_V stores mlen elements per call; columns stored in blocks of mlen. - """ - if gp_regs is None: - gp_regs = [1, 2, 3, 4, 5] - - lines = [] - lines.append(f"; Store Activation with Format Convert: {name}") - lines.append(f"; VRAM[{vram_src_addr}]: [batch, mlen, hidden/mlen] column-block major") - lines.append(f"; HBM[{hbm_dest_addr}]: [batch={batch}, hidden={hidden_size}] row-major") - - gp_hbm_offset = gp_regs[0] - gp_stride = gp_regs[1] - gp_vram = gp_regs[2] - _gp_outer = gp_regs[3] - _gp_inner = gp_regs[4] - - num_col_blocks = hidden_size // self.mlen - store_amount = 4 # store 4 rows per H_STORE_V call - - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, {hidden_size}") - lines.append(f"C_SET_STRIDE_REG gp{gp_stride}") - - for col_block in range(num_col_blocks): - lines.append(f"; Column block {col_block}") - - hbm_offset = col_block * self.mlen - vram_addr = vram_src_addr + col_block * batch * self.mlen - - for batch_block in range(math.ceil(batch / store_amount)): - actual_batch_offset = batch_block * store_amount * hidden_size - actual_vram_offset = batch_block * store_amount * self.mlen - - lines.append(f"S_ADDI_INT gp{gp_hbm_offset}, gp0, {hbm_offset + actual_batch_offset}") - lines.append(f"S_ADDI_INT gp{gp_vram}, gp0, {vram_addr + actual_vram_offset}") - lines.append(f"H_STORE_V gp{gp_vram}, gp{gp_hbm_offset}, a{hbm_addr_reg}, 0") - - return "\n".join(lines) + "\n" - - # ========================================================================== - # Pre-calculated Address Table Generation - # ========================================================================== - - def generate_address_table(self, name: str) -> dict[str, int]: - """Generate complete address table for a matrix (for debugging and verification).""" - if name not in self.hbm_matrices: - raise KeyError(f"Matrix '{name}' not registered") - - layout = self.hbm_matrices[name] - addr_table = {} - - for (r, c), sub_block in layout.sub_blocks.items(): - key = f"{name}[{r}][{c}]" - addr_table[f"{key}_hbm_offset"] = sub_block.hbm_offset - addr_table[f"{key}_hbm_abs"] = layout.hbm_base_addr + sub_block.hbm_offset - if sub_block.mram_addr is not None: - addr_table[f"{key}_mram"] = sub_block.mram_addr - - return addr_table - - def print_address_table(self, name: str): - """Print address table for a matrix.""" - addr_table = self.generate_address_table(name) - print(f"Address Table for {name}:") - for key, addr in sorted(addr_table.items()): - print(f" {key}: {addr}") - - # ========================================================================== - # Helper Methods - # ========================================================================== - - def get_loaded_block_addr(self, name: str, row_idx: int, col_idx: int) -> int: - """Get MRAM address of a loaded sub-block.""" - block_key = f"{name}[{row_idx}][{col_idx}]" - if block_key not in self.loaded_sub_blocks: - raise KeyError(f"SubBlock {block_key} not loaded") - return self.loaded_sub_blocks[block_key].mram_addr - - def is_block_loaded(self, name: str, row_idx: int, col_idx: int) -> bool: - """Check whether a sub-block has been loaded into MRAM.""" - block_key = f"{name}[{row_idx}][{col_idx}]" - return block_key in self.loaded_sub_blocks - - def reset(self): - """Reset manager state.""" - self.hbm_matrices.clear() - self.vram_matrices.clear() - self.fpram_matrices.clear() - self.vram_allocator.reset() - self.mram_allocator.reset() - self.fpram_allocator.reset() - self.loaded_sub_blocks.clear() - self._address_cache.clear() - - def print_table(self): - """Print unified managed object table.""" - print("=" * 95) - print("Managed Object Table") - print("=" * 95) - print( - f"{'Name':<20} {'Kind':<12} {'Shape':<15} {'HBM Addr':<10} {'HBM Size':<10} {'VRAM Addr':<12} {'FPRAM':<12} {'Size':<8}" - ) - print("-" * 95) - names = sorted(set(self.hbm_matrices) | set(self.vram_matrices) | set(self.fpram_matrices)) - for name in names: - info = self[name] - shape_str = f"({info.shape[0]}, {info.shape[1]})" - vram_str = f"{info.vram_addr}" if info.vram_addr is not None else "None" - fpram_str = f"{info.fpram_addr}/{info.fpram_size}" if info.fpram_addr is not None else "None" - print( - f"{name:<20} {info.kind:<12} {shape_str:<15} " - f"{info.hbm_addr:<10} {info.hbm_size:<10} {vram_str:<12} {fpram_str:<12} {info.size:<8}" - ) - print("=" * 95) - - def print_layout(self, name: str): - """Print block layout for a matrix.""" - if name not in self.hbm_matrices: - print(f"Matrix '{name}' not registered") - return - - layout = self.hbm_matrices[name] - print(f"Matrix: {name}") - print(f" Full shape: {layout.full_shape}") - print(f" Block size: {layout.block_size}") - print(f" Blocks: {layout.num_row_blocks} x {layout.num_col_blocks}") - print(f" HBM base: {layout.hbm_base_addr}") - print(" Sub blocks:") - for (r, c), sub in layout.sub_blocks.items(): - loaded = "LOADED" if sub.mram_addr is not None else "" - print(f" [{r}][{c}]: hbm_off={sub.hbm_offset}, mram={sub.mram_addr} {loaded}") - - -# ============================================================================== -# Example Usage -# ============================================================================== - - -class RegisterAllocator: - """Register Allocator: Manages address registers and GP registers""" - - def __init__(self, start_gp: int = 1, start_addr: int = 0, start_fp: int = 1): - # HW OPERAND_WIDTH = 4 bits → gp0-gp15; gp0 reserved as constant 0. - self.gp_registers = list(range(start_gp, 16)) - self.addr_registers = list(range(start_addr, 8)) - # f0 reserved as constant 0 (writing to f0 is a no-op for V_RED_MAX/V_RED_SUM). - self.fp_registers = list(range(start_fp, 8)) - self.used_gp = [] - self.used_addr = [] - self.used_fp = [] - - def allocate_gp(self, count: int = 1) -> list[int]: - if len(self.gp_registers) < count: - raise RuntimeError(f"Not enough GP registers available. Need {count}, have {len(self.gp_registers)}") - - allocated = self.gp_registers[:count] - self.gp_registers = self.gp_registers[count:] - self.used_gp.extend(allocated) - return allocated - - def allocate_addr(self, count: int = 1) -> list[int]: - if len(self.addr_registers) < count: - raise RuntimeError(f"Not enough address registers available. Need {count}, have {len(self.addr_registers)}") - - allocated = self.addr_registers[:count] - self.addr_registers = self.addr_registers[count:] - self.used_addr.extend(allocated) - return allocated - - def free_gp(self, registers: list[int]): - for reg in registers: - if reg in self.used_gp: - self.used_gp.remove(reg) - self.gp_registers.append(reg) - self.gp_registers.sort() - - def free_addr(self, registers: list[int]): - for reg in registers: - if reg in self.used_addr: - self.used_addr.remove(reg) - self.addr_registers.append(reg) - self.addr_registers.sort() - - def allocate_fp(self, count: int = 1) -> list[int]: - if len(self.fp_registers) < count: - raise RuntimeError(f"Not enough FP registers available. Need {count}, have {len(self.fp_registers)}") - - # Reverse allocation: prefer high-numbered regs to avoid conflicts with legacy hardcoded forward-allocation. - allocated = list(reversed(self.fp_registers[-count:])) - self.fp_registers = self.fp_registers[:-count] - self.used_fp.extend(allocated) - return allocated - - def free_fp(self, registers: list[int]): - for reg in registers: - if reg in self.used_fp: - self.used_fp.remove(reg) - self.fp_registers.append(reg) - # Keep sorted so allocate_fp's tail-slice continues to return descending IDs. - self.fp_registers.sort() - - -class DeveloperCompiler(TileCompiler): - """ - Developer Compiler: Compiles high-level IR to ISA. - - Owns symbol_table, register_allocator, and the InterruptManager. - Sub-matrix / memory management is inherited from TileCompiler; the - legacy ``self.tile_compiler`` accessor is preserved as a property - returning ``self`` for a handful of remaining external callers. - """ - - _ONLINE_SOFTMAX_FPSRAM_BASE = 10 - - class InterruptManager: - """ - Interrupt Manager — manages execution timing only. - Actual handlers live on DeveloperCompiler as ``_handle_k_start``, - ``_handle_k_prefetch_done``, ``_handle_s_tile_done``, ``_handle_k_end``. - """ - - def __init__(self, compiler: DeveloperCompiler): - self.compiler = compiler - self.enabled = False - - self._k_count = 0 - self._tile_count = 0 - - self.current_matrix: str = "" - self.current_activation: str = "" - self._mlen = compiler.mlen - self._blen = compiler.blen - self._batch = compiler.mlen - - self._q_block_idx = 0 - self._k_block_idx = 0 - self._s_tile_address = 0 - - @property - def tile_compiler(self): - return self.compiler.tile_compiler - - @property - def k_count(self) -> int: - return self._k_count - - @property - def tile_count(self) -> int: - return self._tile_count - - @property - def batch(self) -> int: - return self._batch - - @property - def out_features(self) -> int: - if self.current_matrix and self.current_matrix in self: - info = self[self.current_matrix] - return info.shape[0] - return self._mlen - - @property - def hidden_size(self) -> int: - if self.current_matrix and self.current_matrix in self: - info = self[self.current_matrix] - return info.shape[1] - return self._mlen - - @property - def k_block(self) -> int: - return self._k_block_idx - - @property - def q_block(self) -> int: - return self._q_block_idx - - @property - def s_tile_address(self) -> int: - return self._s_tile_address - - @property - def mlen(self) -> int: - return self._mlen - - @property - def blen(self) -> int: - return self._blen - - def reset(self): - """Reset counters (does not clear current_matrix).""" - self._k_count = 0 - self._tile_count = 0 - self._q_block_idx = 0 - self._k_block_idx = 0 - self._s_tile_address = 0 - - def enable(self): - self.enabled = True - - def disable(self): - self.enabled = False - - def trigger_k_start(self) -> str: - if not self.enabled: - return "" - return self.compiler._handle_k_start() - - def trigger_k_prefetch_done(self) -> str: - if not self.enabled: - return "" - result = self.compiler._handle_k_prefetch_done() - self._k_count += 1 - return result - - def trigger_s_tile_done(self) -> str: - if not self.enabled: - return "" - result = self.compiler._handle_s_tile_done() - self._tile_count += 1 - return result - - def trigger_k_end(self) -> str: - if not self.enabled: - return "" - return self.compiler._handle_k_end() - - def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): - # TileCompiler.__init__ sets mlen, blen, unroll_loops, the HBM/VRAM/FPRAM - # matrices and allocators, loaded_sub_blocks, and _address_cache. - super().__init__(mlen=mlen, blen=blen, unroll_loops=unroll_loops) - self.real_data_ratio = real_data_ratio - self.register_allocator = RegisterAllocator() - self.generated_code = "" - self.interrupt = self.InterruptManager(self) - - # Back-compat shim: older callers (and a couple of external ops modules) - # reach into ``compiler.tile_compiler`` directly. After the merge, - # DeveloperCompiler *is* the TileCompiler, so the property just returns - # ``self``. - @property - def tile_compiler(self) -> DeveloperCompiler: - return self - - # Interrupt handler placeholders (overridden by flash-attention passes). - - def _handle_k_start(self) -> str: - return "" - - def _handle_k_prefetch_done(self) -> str: - return "" - - def _handle_s_tile_done(self) -> str: - return "" - - def _handle_k_end(self) -> str: - return "" - - # ========================================================================= - # Flash Attention Implementation - # ========================================================================= - - def _online_softmax_asm( - self, - mlen: int, - s_address: int, - m_start_address: int, - scale: float = 1.0, - ) -> str: - """ - Online Softmax Computation. - - Per row of S: - 1. m_curr = max(S[row], m_old) - 2. m_res = exp(m_old - m_curr) # used to update O downstream - 3. S'[row] = S[row] - m_curr - 4. P[row] = exp(S'[row]) - 5. l_new = l_old * m_res + sum(P[row]) - - FP SRAM layout (from m_start_address): - [0, mlen): m_old / m_curr - [mlen, 2*mlen): m_res = exp(m_old - m_curr) - [2*mlen, 3*mlen): l_old / l_new - """ - gp_regs = self.register_allocator.allocate_gp(4) - gp_s = gp_regs[0] - gp_m_addr = gp_regs[1] - gp_m_res_addr = gp_regs[2] - gp_l_addr = gp_regs[3] - - # Fixed FP register allocation for online softmax pipeline. - # These registers are shared across _online_softmax_asm, _scale_o_asm, - # and _final_scaling_asm — they MUST remain consistent across all three. - # WARNING: Do not use f1-f6 in any code that calls these methods. - fp_m_old = 1 # f1: m_old value - fp_m_res = 2 # f2: exp(m_old - m_curr) - fp_l_old = 3 # f3: l_old value - fp_sum_p = 4 # f4: sum(P) - fp_scale = 5 # f5: scale factor - fp_row_max = 6 # f6: current row max (temporary) - - lines = [] - lines.append("; === Online Softmax ===") - - # Set address registers - lines.append(f"S_ADDI_INT gp{gp_s}, gp0, {s_address}") - lines.append(f"S_ADDI_INT gp{gp_m_addr}, gp0, {m_start_address}") - lines.append(f"S_ADDI_INT gp{gp_m_res_addr}, gp{gp_m_addr}, {mlen}") - lines.append(f"S_ADDI_INT gp{gp_l_addr}, gp{gp_m_res_addr}, {mlen}") - - # scale factor is pre-loaded at FP SRAM addr 1 by the flash-attention driver. - if scale != 1.0: - lines.append(f"S_LD_FP f{fp_scale}, gp0, 1") - - for row in range(mlen): - lines.append(f"; Row {row}") - - lines.append(f"S_LD_FP f{fp_m_old}, gp{gp_m_addr}, {row}") - lines.append(f"S_ADD_FP f{fp_m_res}, f{fp_m_old}, f0") - - if scale != 1.0: - lines.append(f"V_MUL_VF gp{gp_s}, gp{gp_s}, f{fp_scale}, 0") - - lines.append(f"V_RED_MAX f{fp_row_max}, gp{gp_s}, 0") - - # m_curr = max(row_max, m_old) — online softmax must retain the running max. - lines.append(f"S_MAX_FP f{fp_m_old}, f{fp_row_max}, f{fp_m_old}") - - lines.append(f"S_SUB_FP f{fp_m_res}, f{fp_m_res}, f{fp_m_old}") - lines.append(f"S_EXP_FP f{fp_m_res}, f{fp_m_res}, 0") - - lines.append(f"S_ST_FP f{fp_m_res}, gp{gp_m_res_addr}, {row}") - lines.append(f"S_ST_FP f{fp_m_old}, gp{gp_m_addr}, {row}") - - lines.append(f"V_SUB_VF gp{gp_s}, gp{gp_s}, f{fp_m_old}, 0, 0") - lines.append(f"V_EXP_V gp{gp_s}, gp{gp_s}, 0, 0") - - lines.append(f"S_LD_FP f{fp_l_old}, gp{gp_l_addr}, {row}") - - lines.append(f"S_ADD_FP f{fp_sum_p}, f0, f0") - lines.append(f"V_RED_SUM f{fp_sum_p}, gp{gp_s}, 0, 0") - - lines.append(f"S_MUL_FP f{fp_l_old}, f{fp_l_old}, f{fp_m_res}") - lines.append(f"S_ADD_FP f{fp_l_old}, f{fp_l_old}, f{fp_sum_p}") - - lines.append(f"S_ST_FP f{fp_l_old}, gp{gp_l_addr}, {row}") - - lines.append(f"S_ADDI_INT gp{gp_s}, gp{gp_s}, {mlen}") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _pv_multiply_asm( - self, - mlen: int, - blen: int, - head_dim: int, - p_address: int, - v_hbm_offset_reg: int, - v_hbm_offset: int, - pv_address: int, - ) -> str: - """ - Compute PV = P @ V via M_MM. - - P: (mlen, mlen) in VRAM (softmax output) - V: (mlen, head_dim) in HBM (prefetched into MSRAM in mlen-wide column blocks) - PV: (mlen, head_dim) in VRAM - - M_MM computes one (blen, mlen) @ (mlen, blen) -> (blen, blen) in a single op - (K=mlen done in one shot). For head_dim > mlen, V is split into head_dim/mlen - column blocks; the outer loop iterates blocks, middle loop iterates blen-wide - V columns within a block, inner loop iterates blen-wide P rows. - """ - assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" - - gp_regs = self.register_allocator.allocate_gp(5) - gp_p = gp_regs[0] - gp_v = gp_regs[1] - gp_pv = gp_regs[2] - gp_hbm = gp_regs[3] - gp_stride = gp_regs[4] - - num_v_col_blocks = head_dim // mlen - - lines = [] - lines.append("; === PV Multiply (P @ V) using M_MM ===") - lines.append(f"; P: ({mlen}, {mlen}) @ V: ({mlen}, {head_dim}) -> PV: ({mlen}, {head_dim})") - lines.append("; M_MM: (blen, mlen) @ (mlen, blen) -> (blen, blen), K=mlen in one shot") - lines.append(f"; V split into {num_v_col_blocks} column blocks of width {mlen}") - lines.append("; Storage layout: (batch, mlen, hidden/mlen), column-block major") - - # STRIDE was set to mlen by the flash-attention driver — do not overwrite it here. - # M_MM_WO requires a nonzero stride reg (gp0=0 would be interpreted as stride=1). - # With column-block-major storage, consecutive rows within a column block are - # adjacent, so the writeback stride = 1. - lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") - - for v_col_block in range(num_v_col_blocks): - lines.append( - f"; --- V column block {v_col_block} (columns {v_col_block * mlen} to {(v_col_block + 1) * mlen - 1}) ---" - ) - - # Prefetch V[:, v_col_block*mlen:(v_col_block+1)*mlen] (mlen × mlen) to MSRAM. - # V is row-major in HBM: V[row, col] at offset row*head_dim + col, so the - # column-block base offset = v_hbm_offset + v_col_block * mlen (elements). - v_block_hbm_offset = v_hbm_offset + v_col_block * mlen - lines.append(f"S_ADDI_INT gp{gp_v}, gp0, 0") - lines.append(f"S_ADDI_INT gp{gp_hbm}, gp0, {v_block_hbm_offset}") - lines.append(f"H_PREFETCH_M gp{gp_v}, gp{gp_hbm}, a{v_hbm_offset_reg}, 1, 1") - - # mat_offset constraint: < mlen and a multiple of blen. - for v_col in range(mlen // blen): - lines.append(f"; V column {v_col_block * mlen + v_col * blen}") - - v_msram_offset = v_col * blen - lines.append(f"S_ADDI_INT gp{gp_v}, gp0, {v_msram_offset}") - - for p_row in range(mlen // blen): - p_row_addr = p_address + p_row * blen * mlen - lines.append(f"S_ADDI_INT gp{gp_p}, gp0, {p_row_addr}") - - lines.append(f"M_MM 0, gp{gp_v}, gp{gp_p}") - - # PV[row, col] addr = base + col_block * mlen * mlen + row * mlen + col_in_block - # with row = p_row * blen and col_in_block = v_col * blen. - pv_offset = v_col_block * mlen * mlen + p_row * blen * mlen + v_col * blen - lines.append(f"S_ADDI_INT gp{gp_pv}, gp0, {pv_address + pv_offset}") - lines.append(f"M_MM_WO gp{gp_pv}, gp{gp_stride}, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _scale_o_asm( - self, - mlen: int, - head_dim: int, - seq_len: int, - m_res_address: int, - o_address: int, - row_offset: int = 0, - ) -> str: - """Scale each row of O by m_res: O[row] *= m_res[row].""" - assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" - - gp_regs = self.register_allocator.allocate_gp(2) - gp_m_res = gp_regs[0] - gp_o = gp_regs[1] - fp_m_res = 1 - - num_col_blocks = head_dim // mlen - - lines = [] - lines.append("; === Scale O by m_res ===") - lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") - lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") - - lines.append(f"S_ADDI_INT gp{gp_m_res}, gp0, {m_res_address}") - - for row in range(mlen): - lines.append(f"S_LD_FP f{fp_m_res}, gp{gp_m_res}, {row}") - actual_row = row_offset + row - - for col_block in range(num_col_blocks): - o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen - lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") - lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_m_res}, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _add_pv_to_o_asm( - self, - mlen: int, - head_dim: int, - seq_len: int, - pv_address: int, - o_address: int, - row_offset: int = 0, - ) -> str: - """Accumulate PV into O: O[row] += PV[row].""" - assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" - - gp_regs = self.register_allocator.allocate_gp(2) - gp_o = gp_regs[0] - gp_pv = gp_regs[1] - - num_col_blocks = head_dim // mlen - - lines = [] - lines.append("; === Add PV to O ===") - lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") - lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") - - for row in range(mlen): - actual_row = row_offset + row - - for col_block in range(num_col_blocks): - o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen - pv_addr = pv_address + col_block * mlen * mlen + row * mlen - - lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") - lines.append(f"S_ADDI_INT gp{gp_pv}, gp0, {pv_addr}") - lines.append(f"V_ADD_VV gp{gp_o}, gp{gp_o}, gp{gp_pv}, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _final_scaling_asm( - self, - mlen: int, - head_dim: int, - seq_len: int, - l_address: int, - o_address: int, - row_offset: int = 0, - ) -> str: - """ - Final scaling: O[row] /= l[row]. - - V_MUL_VF processes mlen elements at a time; when head_dim > mlen, - each row is split into head_dim // mlen mlen-wide blocks. - """ - assert head_dim % mlen == 0, f"head_dim ({head_dim}) must be multiple of mlen ({mlen})" - - gp_regs = self.register_allocator.allocate_gp(2) - gp_l = gp_regs[0] - gp_o = gp_regs[1] - fp_l = 1 - - num_col_blocks = head_dim // mlen - - lines = [] - lines.append("; === Final Scaling O = O / l ===") - lines.append(f"; head_dim = {head_dim}, {num_col_blocks} mlen-blocks per row") - lines.append("; Storage layout: (seq_len, mlen, head_dim/mlen), column-block major") - lines.append(f"; seq_len = {seq_len}, row_offset = {row_offset}") - - lines.append(f"S_ADDI_INT gp{gp_l}, gp0, {l_address}") - - for row in range(mlen): - lines.append(f"S_LD_FP f{fp_l}, gp{gp_l}, {row}") - lines.append(f"S_RECI_FP f{fp_l}, f{fp_l}, 0") - actual_row = row_offset + row - - for col_block in range(num_col_blocks): - o_addr = o_address + col_block * seq_len * mlen + actual_row * mlen - lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {o_addr}") - lines.append(f"V_MUL_VF gp{gp_o}, gp{gp_o}, f{fp_l}, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _reset_fpsram_asm( - self, - start_address: int, - count: int, - value_address: int, # FP SRAM slot: 0 = zero, 2 = -inf - ) -> str: - """Reset a region of FP SRAM to the value at value_address.""" - gp_regs = self.register_allocator.allocate_gp(1) - gp_addr = gp_regs[0] - - lines = [] - lines.append(f"; Reset FP SRAM [{start_address}, {start_address + count})") - - lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {start_address}") - # Use f1 for FP scalar - FP registers don't go through GP allocator - lines.append(f"S_LD_FP f1, gp0, {value_address}") - - for i in range(count): - lines.append(f"S_ST_FP f1, gp{gp_addr}, {i}") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def _reset_vram_asm( - self, - start_address: int, - rows: int, - cols: int, - total_rows: int, - mlen: int = 64, - row_offset: int = 0, - ) -> str: - """ - Reset a region of VRAM to zero. - - V_MUL_VF processes mlen elements at a time; when cols > mlen, each - row is split into cols // mlen mlen-wide blocks. - """ - gp_regs = self.register_allocator.allocate_gp(1) - gp_addr = gp_regs[0] - - num_col_blocks = (cols + mlen - 1) // mlen - - lines = [] - lines.append(f"; Reset VRAM rows [{row_offset}, {row_offset + rows}) of matrix at {start_address}") - lines.append(f"; {rows} rows x {cols} cols, {num_col_blocks} blocks per row") - lines.append("; Storage layout: (total_rows, mlen, cols/mlen), column-block major") - lines.append(f"; total_rows = {total_rows}, row_offset = {row_offset}") - - for row in range(rows): - actual_row = row_offset + row - for col_block in range(num_col_blocks): - addr = start_address + col_block * total_rows * mlen + actual_row * mlen - lines.append(f"S_ADDI_INT gp{gp_addr}, gp0, {addr}") - lines.append(f"V_MUL_VF gp{gp_addr}, gp{gp_addr}, f0, 0") - - self.register_allocator.free_gp(gp_regs) - return "\n".join(lines) + "\n" - - def load_batch( - self, - hbm_object_name: str, - vram_object_name: str, - vlen: int = 64, - preload_len: int = 4, - ) -> str: - """ - Load a Batch tensor from HBM to VRAM. - - HBM storage is MXFP (1 scale per 8 elements), so HBM actual size = - logical size * real_data_ratio = 1.125. VRAM stores only the vector - data (no scale), so VRAM size = logical size. - - Order (matters): allocate VRAM → register in symbol table → emit ISA. - """ - hbm_layout = self.get_hbm_layout(hbm_object_name) - h, w = hbm_layout.full_shape - hbm_addr = hbm_layout.hbm_base_addr - size = h * w - vram_base = self.vram_allocator.allocate(size, name=vram_object_name) - self.add_vram_object( - name=vram_object_name, - shape=(h, w), - vram_addr=vram_base, - dtype="fp16", - kind="Batch", - allocate_if_none=False, - strict=False, - ) - - addr_reg = self.register_allocator.allocate_addr(1)[0] - gp_regs_for_addr = self.register_allocator.allocate_gp(1) - - isa_code = f"; Load_Batch {hbm_object_name} -> {vram_object_name}\n" - isa_code += f"; HBM[{hbm_addr}] → VRAM[{vram_base}], shape=({h}, {w})\n" - - isa_code += preload_addr_reg_asm( - addr_reg_to_set=[addr_reg], available_registers=gp_regs_for_addr, addr_reg_val=[hbm_addr] - ) - - # preload_act_asm requires 5 GP registers: [a_actual, stride, result, outer_loop, inner_loop]. - gp_regs_for_preload = self.register_allocator.allocate_gp(5) - isa_code += reset_reg_asm(alive_registers=gp_regs_for_preload) - - isa_code += preload_act_asm( - vlen=vlen, - preload_len=preload_len, - batch=h, - hidden_size=w, - alive_registers=gp_regs_for_preload, - act_vram_offset=vram_base, - activation_offset_reg=addr_reg, - stride_size=w, - ) - - self.register_allocator.free_gp(gp_regs_for_addr) - self.register_allocator.free_gp(gp_regs_for_preload) - self.register_allocator.free_addr([addr_reg]) - - self.generated_code += isa_code - - return isa_code - - def store_to_hbm( - self, - tensor_name: str, - hbm_addr: int | None = None, - hbm_object_name: str | None = None, - hbm_addr_reg: int | None = None, - vlen: int = 64, - precision: int = 0, # 0 = Activation, 1 = KeyValue - store_amount: int = 4, # HBM_V_Writeback_Amount - ) -> str: - """ - Write tensor from VRAM back to HBM. - - Used to spill computed intermediates (e.g., K) from VRAM to HBM so - downstream ops (e.g., QK^T) can read them from HBM. Emits - ``store_act_asm`` for tensors of any supported size. - """ - if tensor_name not in self: - raise KeyError(f"Tensor '{tensor_name}' not found in symbol table") - - tensor_info = self[tensor_name] - - # Batch and VRAMMatrix share the same VRAM storage layout. - if tensor_info.kind not in ("Batch", "VRAMMatrix"): - raise ValueError( - f"Tensor '{tensor_name}' must be Batch or VRAMMatrix to store from VRAM, got {tensor_info.kind}" - ) - - if tensor_info.vram_addr is None: - raise ValueError(f"Tensor '{tensor_name}' has no VRAM address to store") - - if hbm_addr is None: - if tensor_info.hbm_addr >= 0: - hbm_addr = tensor_info.hbm_addr - else: - raise ValueError(f"Tensor '{tensor_name}' has no HBM address. Please specify hbm_addr.") - - batch_size = tensor_info.shape[0] - hidden_size = tensor_info.shape[1] - - isa_code = f"; Store {tensor_name} from VRAM to HBM\n" - isa_code += f"; VRAM[{tensor_info.vram_addr}] -> HBM[{hbm_addr}], shape=({batch_size}, {hidden_size})\n" - - gp_regs = self.register_allocator.allocate_gp(5) - - if hbm_addr_reg is None: - addr_regs = self.register_allocator.allocate_addr(1) - hbm_addr_reg = addr_regs[0] - need_free_addr = True - else: - addr_regs = [] - need_free_addr = False - - try: - gp_regs_for_addr = self.register_allocator.allocate_gp(2) - isa_code += preload_addr_reg_asm( - addr_reg_to_set=[hbm_addr_reg], available_registers=gp_regs_for_addr, addr_reg_val=[hbm_addr] - ) - self.register_allocator.free_gp(gp_regs_for_addr) - - isa_code += store_act_asm( - vlen=vlen, - batch=batch_size, - hidden_size=hidden_size, - alive_registers=gp_regs, - act_vram_offset=tensor_info.vram_addr, - hbm_addr_reg=hbm_addr_reg, - stride_size=hidden_size, - store_amount=store_amount, - ) - - if tensor_info.hbm_addr < 0 or tensor_info.hbm_addr != hbm_addr: - tensor_info.hbm_addr = hbm_addr - # HBM stores the MXFP-expanded size (logical size × real_data_ratio). - size = batch_size * hidden_size - tensor_info.hbm_size = int(size * self.real_data_ratio) - finally: - self.register_allocator.free_gp(gp_regs) - if need_free_addr: - self.register_allocator.free_addr(addr_regs) - - if hbm_object_name is not None: - self.add_hbm_object( - name=hbm_object_name, - hbm_addr=hbm_addr, - shape=(batch_size, hidden_size), - ) - - self.generated_code += isa_code - - return isa_code - - def normalize( - self, - tensor_name: str, - mode: str = "rms", - eps_offset: int = 1, - reci_hid_offset: int = 2, - vlen: int | None = None, - scratchpad_vram_addr: int | None = None, - ) -> str: - """ - Normalize a VRAM tensor in-place. - - Supports: - - mode="rms": RMSNorm - - mode="layer": LayerNorm - - Args: - tensor_name: Tensor name in symbol table (must have VRAM address) - mode: "rms" or "layer" - eps_offset: FPRAM address of epsilon - reci_hid_offset: FPRAM address of 1/hidden_dim - vlen: vector length (default: self.mlen) - scratchpad_vram_addr: scratchpad VRAM address (default: auto-allocate temporary space) - """ - if tensor_name not in self: - raise KeyError(f"Tensor '{tensor_name}' not found in symbol table") - - tensor_info = self[tensor_name] - if tensor_info.vram_addr is None: - raise ValueError(f"Tensor '{tensor_name}' has no VRAM address") - - batch_size, hidden_dim = tensor_info.shape - if vlen is None: - vlen = self.mlen - if hidden_dim % vlen != 0: - raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by vlen ({vlen}) for normalization_asm") - - mode = mode.lower() - if mode not in ("rms", "layer"): - raise ValueError(f"Unsupported normalization mode: {mode}. Expected 'rms' or 'layer'.") - - gp_regs = self.register_allocator.allocate_gp(4) - - temp_scratchpad_name = None - if scratchpad_vram_addr is None: - temp_scratchpad_name = f"__norm_scratch__{tensor_name}__{len(self.generated_code)}" - scratchpad_vram_addr = self.vram_allocator.allocate(vlen, name=temp_scratchpad_name) - - try: - isa_code = f"; Normalize ({mode}) {tensor_name}, shape=({batch_size}, {hidden_dim})\n" - if mode == "rms": - isa_code += rms_norm_asm( - _eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - alive_registers=gp_regs, - activation_base_address=tensor_info.vram_addr, - scratchpad_base_address=scratchpad_vram_addr, - vlen=vlen, - batch_size=batch_size, - hidden_dim=hidden_dim, - ) - else: - isa_code += layer_norm_asm( - _eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - alive_registers=gp_regs, - activation_base_address=tensor_info.vram_addr, - scratchpad_base_address=scratchpad_vram_addr, - vlen=vlen, - batch_size=batch_size, - hidden_dim=hidden_dim, - ) - - self.generated_code += isa_code - return isa_code - finally: - # Always release allocated GP registers used by normalization template. - self.register_allocator.free_gp(gp_regs) - if temp_scratchpad_name is not None: - self.vram_allocator.free(temp_scratchpad_name, strict=False) - - def rope( - self, - x_name: str, - x_rot_name: str, - cos_name: str, - sin_name: str, - ) -> str: - """Apply RoPE in-place: x = x * cos + rotate_half(x) * sin - - All four tensors must already be in VRAM with the same shape (seq_len, head_dim). - x_rot must be preloaded by the caller as rotate_half(x). - """ - x_info = self[x_name] - xrot_info = self[x_rot_name] - cos_info = self[cos_name] - sin_info = self[sin_name] - - if x_info.vram_addr is None: - raise ValueError(f"Tensor '{x_name}' has no VRAM address") - - seq_len, head_dim = x_info.shape - vlen = self.mlen - - if head_dim % vlen != 0: - raise ValueError(f"head_dim ({head_dim}) must be divisible by vlen ({vlen}) for rope") - - gp_regs = self.register_allocator.allocate_gp(5) - - scratch_name = f"__rope_scratch__{x_name}__{len(self.generated_code)}" - scratch_addr = self.vram_allocator.allocate(vlen, name=scratch_name) - - try: - isa_code = rope_asm( - alive_registers=gp_regs, - x_base_address=x_info.vram_addr, - x_rot_base_address=xrot_info.vram_addr, - cos_base_address=cos_info.vram_addr, - sin_base_address=sin_info.vram_addr, - scratchpad_base_address=scratch_addr, - vlen=vlen, - seq_len=seq_len, - head_dim=head_dim, - ) - self.generated_code += isa_code - return isa_code - finally: - self.register_allocator.free_gp(gp_regs) - self.vram_allocator.free(scratch_name, strict=False) - - def get_code(self) -> str: - """Get all accumulated generated ISA code""" - return self.generated_code - - def reset(self): - """Reset compiler state (clear code, but retain symbol table)""" - self.generated_code = "" - self.register_allocator = RegisterAllocator() - # Call TileCompiler.reset() explicitly since the merged class shadows it. - TileCompiler.reset(self) - - def print_symbol_table(self): - """Print symbol table""" - self.print_table() - - def get_symbol_table(self): - """Get managed object table view.""" - return self - - def get_tensor_info(self, name: str): - """Get unified tensor/object info by name.""" - return self[name] - - def add_hbm_object( - self, - name: str, - hbm_addr: int, - shape: tuple[int, int], - real_data_ratio: float = 1.125, - ): - """Register an HBM object and build its HBM layout. - - Wraps ``TileCompiler.add_hbm_object`` with a different positional - parameter order ``(name, hbm_addr, shape, ...)`` that all DeveloperCompiler - callers use. - """ - return TileCompiler.add_hbm_object( - self, - name=name, - shape=shape, - hbm_addr=hbm_addr, - real_data_ratio=real_data_ratio, - ) - - def free_hbm_object(self, name: str, strict: bool = False): - """Free an HBM object by name (defaults to non-strict).""" - return TileCompiler.free_hbm_object(self, name, strict=strict) - - def get_vram_addr(self, name: str) -> int: - """Get VRAM base address of an object.""" - info = self.get_tensor_info(name) - if info.vram_addr is None: - raise ValueError(f"Object '{name}' has no VRAM address") - return info.vram_addr - - def get_vram_tile_addr( - self, - name: str, - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> int: - """ - Get VRAM address of a specific tile (sub-block) in a VRAM matrix. - - Args: - name: matrix name - tile_row_idx: tile row index (0-based) - tile_col_idx: tile col index (0-based) - """ - self._ensure_vram_matrix_layout(name) - sub = self.get_vram_sub_block(name, tile_row_idx, tile_col_idx) - return sub.vram_addr - - def ensure_hbm_sub_matrix( - self, - name: str, - hbm_addr: int, - shape: tuple[int, int], - real_data_ratio: float = 1.125, - ): - """Ensure HBM matrix layout exists.""" - if name in self.hbm_matrices: - return - self.add_hbm_object( - name=name, - hbm_addr=hbm_addr, - shape=shape, - real_data_ratio=real_data_ratio, - ) - - def ensure_vram_matrix_layout(self, name: str, shape: tuple[int, int]): - """Ensure VRAM matrix layout exists for an already allocated VRAM object.""" - if name in self.vram_matrices: - return - vram_addr = self.get_vram_addr(name) - self.add_vram_object( - name=name, - shape=shape, - vram_addr=vram_addr, - allocate_if_none=False, - ) - - def free_vram_object(self, name: str, strict: bool = False): - """Free a VRAM object by name (defaults to non-strict).""" - return TileCompiler.free_vram_object(self, name, strict=strict) - - # ========================================================================= - # FP Register & FPRAM Management (inlined from former FPRAMCompiler). - # All state lives on self (register_allocator, fpram_allocator, etc.). - # ========================================================================= - - @property - def _reg(self) -> RegisterAllocator: - """Shorthand for self.register_allocator (used by FPVar ISA helpers).""" - return self.register_allocator - - @property - def _unroll(self) -> bool: - """Shorthand for self.unroll_loops.""" - return self.unroll_loops - - def _emit(self, isa_code: str) -> str: - """Append ISA text to the output buffer and return it.""" - self.generated_code += isa_code - return isa_code - - # ------------------------------------------------------------------ - # FP Register management - # ------------------------------------------------------------------ - - def allocate_fp_reg(self, count: int = 1) -> list[int]: - """Allocate FP registers (f0-f7).""" - return self._reg.allocate_fp(count) - - def free_fp_reg(self, registers: list[int]): - """Free FP registers.""" - self._reg.free_fp(registers) - - # ------------------------------------------------------------------ - # FPRAM address-space management - # ------------------------------------------------------------------ - - def allocate_fpram(self, name: str, size: int) -> int: - """Allocate FPRAM space, returns base address.""" - info = self.add_fpram_object(name=name, size=size) - if info.fpram_addr is None: - raise RuntimeError(f"Failed to allocate FPRAM for '{name}'") - return info.fpram_addr - - def free_fpram(self, name: str, strict: bool = True): - """Free FPRAM object by name.""" - return self.free_fpram_object(name, strict=strict) - - def save_fpram_state(self) -> int: - """Save FPRAM allocator snapshot.""" - return self.fpram_allocator.save_state() - - def restore_fpram_state(self, snapshot: int): - """Restore FPRAM allocator snapshot.""" - self.fpram_allocator.restore_state(snapshot) - - def list_fpram_allocations(self) -> list[str]: - """List currently allocated FPRAM object names.""" - return list(self.fpram_allocator.allocations.keys()) - - def get_fpram_addr(self, name: str) -> int: - """Get FPRAM base address from object name.""" - return self.get_fpram_layout(name).fpram_addr - - def get_fpram_size(self, name: str) -> int: - """Get FPRAM allocation size from object name.""" - return self.get_fpram_layout(name).size - - # ========================================================================= - # FPVar ISA helpers (address-based) - # ========================================================================= - - def fpvar_copy_asm(self, src_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Copy skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Copy: FPRAM[{dst_addr}:{dst_addr + count}] = FPRAM[{src_addr}:{src_addr + count}]"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_src}, {i}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_src}, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_fill_from_fpram_asm(self, dst_addr: int, src_fpram_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Fill skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Fill: FPRAM[{dst_addr}:{dst_addr + count}] = FPRAM[{src_fpram_addr}]"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_fpram_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - lines.append(f"S_LD_FP f1, gp{gp_src}, 0") - if self._unroll: - for i in range(count): - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_reci_asm(self, src_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Reci skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Reci: dst = 1/src, count={count}"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_src}, {i}") - lines.append("S_RECI_FP f1, f1, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_src}, 0") - lines.append("S_RECI_FP f1, f1, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_exp_asm(self, src_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Exp skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Exp: dst = exp(src), count={count}"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_src}, {i}") - lines.append("S_EXP_FP f1, f1, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_src}, 0") - lines.append("S_EXP_FP f1, f1, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_add_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Add skipped: count={count}\n") - gp = self._reg.allocate_gp(4) - gp_a, gp_b, gp_dst, gp_loop = gp - lines = [f"; FPVar Add: dst = src1 + src2, count={count}"] - lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {src1_addr}") - lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {src2_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_a}, {i}") - lines.append(f"S_LD_FP f2, gp{gp_b}, {i}") - lines.append("S_ADD_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_a}, 0") - lines.append(f"S_LD_FP f2, gp{gp_b}, 0") - lines.append("S_ADD_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_a}, gp{gp_a}, 1") - lines.append(f"S_ADDI_INT gp{gp_b}, gp{gp_b}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_sub_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Sub skipped: count={count}\n") - gp = self._reg.allocate_gp(4) - gp_a, gp_b, gp_dst, gp_loop = gp - lines = [f"; FPVar Sub: dst = src1 - src2, count={count}"] - lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {src1_addr}") - lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {src2_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_a}, {i}") - lines.append(f"S_LD_FP f2, gp{gp_b}, {i}") - lines.append("S_SUB_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_a}, 0") - lines.append(f"S_LD_FP f2, gp{gp_b}, 0") - lines.append("S_SUB_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_a}, gp{gp_a}, 1") - lines.append(f"S_ADDI_INT gp{gp_b}, gp{gp_b}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_mul_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Mul skipped: count={count}\n") - gp = self._reg.allocate_gp(4) - gp_a, gp_b, gp_dst, gp_loop = gp - lines = [f"; FPVar Mul: dst = src1 * src2, count={count}"] - lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {src1_addr}") - lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {src2_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_a}, {i}") - lines.append(f"S_LD_FP f2, gp{gp_b}, {i}") - lines.append("S_MUL_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_a}, 0") - lines.append(f"S_LD_FP f2, gp{gp_b}, 0") - lines.append("S_MUL_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_a}, gp{gp_a}, 1") - lines.append(f"S_ADDI_INT gp{gp_b}, gp{gp_b}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_max_asm(self, src1_addr: int, src2_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Max skipped: count={count}\n") - gp = self._reg.allocate_gp(4) - gp_a, gp_b, gp_dst, gp_loop = gp - lines = [f"; FPVar Max: dst = max(src1, src2), count={count}"] - lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {src1_addr}") - lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {src2_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f1, gp{gp_a}, {i}") - lines.append(f"S_LD_FP f2, gp{gp_b}, {i}") - lines.append("S_MAX_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f1, gp{gp_a}, 0") - lines.append(f"S_LD_FP f2, gp{gp_b}, 0") - lines.append("S_MAX_FP f1, f1, f2") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_a}, gp{gp_a}, 1") - lines.append(f"S_ADDI_INT gp{gp_b}, gp{gp_b}, 1") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_sum_asm(self, src_addr: int, dst_addr: int, count: int) -> str: - if count <= 0: - return self._emit(f"; FPVar Sum skipped: count={count}\n") - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; FPVar Sum: FPRAM[{dst_addr}] = sum(FPRAM[{src_addr}:{src_addr + count}])"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - lines.append("S_ADD_FP f1, f0, f0") - if self._unroll: - for i in range(count): - lines.append(f"S_LD_FP f2, gp{gp_src}, {i}") - lines.append("S_ADD_FP f1, f1, f2") - else: - lines.append(f"C_LOOP_START gp{gp_loop}, {count}") - lines.append(f"S_LD_FP f2, gp{gp_src}, 0") - lines.append("S_ADD_FP f1, f1, f2") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, 1") - lines.append(f"C_LOOP_END gp{gp_loop}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - def fpvar_shift_asm( - self, - src_addr: int, - dst_addr: int, - count: int, - shift: int, - fill_fpram_addr: int = 0, - ) -> str: - """ - Shift FPVar into dst. - - shift > 0: right shift (leading positions filled) - - shift < 0: left shift (trailing positions filled) - """ - gp = self._reg.allocate_gp(3) - gp_src, gp_dst, gp_fill = gp - lines = [f"; FPVar Shift: dst=shift(src, shift={shift}), count={count}, fill=FPRAM[{fill_fpram_addr}]"] - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr}") - lines.append(f"S_ADDI_INT gp{gp_fill}, gp0, {fill_fpram_addr}") - lines.append(f"S_LD_FP f3, gp{gp_fill}, 0") - - for i in range(count): - src_idx = i - shift - if 0 <= src_idx < count: - lines.append(f"S_LD_FP f1, gp{gp_src}, {src_idx}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, {i}") - else: - lines.append(f"S_ST_FP f3, gp{gp_dst}, {i}") - - self._reg.free_gp(gp) - return self._emit("\n".join(lines) + "\n") - - # ========================================================================= - # FPVar helpers (name-based wrappers over the address-based ISA generators) - # ========================================================================= - - def fpram_copy(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - return self.fpvar_copy_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) - - def fpram_reci(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - return self.fpvar_reci_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) - - def fpram_exp(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - return self.fpvar_exp_asm(self.get_fpram_addr(src_name), self.get_fpram_addr(dst_name), count) - - def fpram_add(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_add_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) - - def fpram_sub(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_sub_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) - - def fpram_mul(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_mul_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) - - def fpram_max(self, src1_name: str, src2_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = min(self.get_fpram_size(src1_name), self.get_fpram_size(src2_name), self.get_fpram_size(dst_name)) - return self.fpvar_max_asm( - self.get_fpram_addr(src1_name), self.get_fpram_addr(src2_name), self.get_fpram_addr(dst_name), count - ) - - def fpram_sum(self, src_name: str, dst_name: str, count: int | None = None) -> str: - if count is None: - count = self.get_fpram_size(src_name) - return self.fpvar_sum_asm( - self.get_fpram_addr(src_name), - self.get_fpram_addr(dst_name), - count, - ) - - def fpram_shift( - self, - src_name: str, - dst_name: str, - shift: int, - count: int | None = None, - fill_fpram_name: str | None = None, - ) -> str: - if count is None: - count = min(self.get_fpram_size(src_name), self.get_fpram_size(dst_name)) - fill_addr = 0 if fill_fpram_name is None else self.get_fpram_addr(fill_fpram_name) - return self.fpvar_shift_asm( - src_addr=self.get_fpram_addr(src_name), - dst_addr=self.get_fpram_addr(dst_name), - count=count, - shift=shift, - fill_fpram_addr=fill_addr, - ) - - def fpram_fill_from_fpram(self, dst_name: str, src_fpram_addr: int, count: int | None = None) -> str: - if count is None: - count = self.get_fpram_size(dst_name) - return self.fpvar_fill_from_fpram_asm( - dst_addr=self.get_fpram_addr(dst_name), - src_fpram_addr=src_fpram_addr, - count=count, - ) - - # ========================================================================= - # Tile-row helpers (name-based) - # ========================================================================= - - def tile_row_max( - self, - source_matrix: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - source_addr = self.get_vram_tile_addr(source_matrix, tile_row_idx, tile_col_idx) - return self.tile_row_max_asm(source_addr, row_map) - - def tile_row_sum( - self, - source_matrix: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - source_addr = self.get_vram_tile_addr(source_matrix, tile_row_idx, tile_col_idx) - return self.tile_row_sum_asm(source_addr, row_map) - - def tile_row_exp( - self, - matrix_name: str, - rows: list[int], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_exp_asm(matrix_addr, rows) - - def tile_row_reci( - self, - matrix_name: str, - rows: list[int], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_reci_asm(matrix_addr, rows) - - def tile_row_sub_fp( - self, - matrix_name: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_sub_fp_asm(matrix_addr, row_map) - - def tile_row_mul_fp( - self, - matrix_name: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_mul_fp_asm(matrix_addr, row_map) - - def tile_row_add_fp( - self, - matrix_name: str, - row_map: list[tuple[int, int]], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_add_fp_asm(matrix_addr, row_map) - - def tile_row_add( - self, - dst_matrix: str, - src_matrix: str, - rows: list[int], - dst_tile_row_idx: int = 0, - dst_tile_col_idx: int = 0, - src_tile_row_idx: int = 0, - src_tile_col_idx: int = 0, - ) -> str: - dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) - src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) - return self.tile_row_add_asm(dst_addr, src_addr, rows) - - def tile_row_sub( - self, - dst_matrix: str, - src_matrix: str, - rows: list[int], - dst_tile_row_idx: int = 0, - dst_tile_col_idx: int = 0, - src_tile_row_idx: int = 0, - src_tile_col_idx: int = 0, - ) -> str: - dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) - src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) - return self.tile_row_sub_asm(dst_addr, src_addr, rows) - - def tile_row_mul( - self, - dst_matrix: str, - src_matrix: str, - rows: list[int], - dst_tile_row_idx: int = 0, - dst_tile_col_idx: int = 0, - src_tile_row_idx: int = 0, - src_tile_col_idx: int = 0, - ) -> str: - dst_addr = self.get_vram_tile_addr(dst_matrix, dst_tile_row_idx, dst_tile_col_idx) - src_addr = self.get_vram_tile_addr(src_matrix, src_tile_row_idx, src_tile_col_idx) - return self.tile_row_mul_asm(dst_addr, src_addr, rows) - - def tile_row_mul_fp_broadcast( - self, - matrix_name: str, - fpram_scalar_addr: int, - rows: list[int], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.tile_row_mul_fp_broadcast_asm(matrix_addr, fpram_scalar_addr, rows) - - def vram_fill_zero( - self, - matrix_name: str, - rows: list[int], - tile_row_idx: int = 0, - tile_col_idx: int = 0, - ) -> str: - matrix_addr = self.get_vram_tile_addr(matrix_name, tile_row_idx, tile_col_idx) - return self.vram_fill_zero_asm(matrix_addr, rows) - - # ========================================================================= - # Tile-row ISA helpers (address-based) - # ========================================================================= - - def _arith_progression(self, values: list[int]) -> tuple[int, int, int] | None: - """Return (start, count, step) if values form an arithmetic progression.""" - if not values: - return None - if len(values) == 1: - return (values[0], 1, 0) - step = values[1] - values[0] - for i in range(2, len(values)): - if values[i] - values[i - 1] != step: - return None - if step == 0: - return None # Constant sequence (step=0, count>1) would cause infinite HW loop - return (values[0], len(values), step) - - def tile_row_max_asm(self, source_vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; Tile Row Max from VRAM[{source_vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {source_vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_RED_MAX f1, gp{gp_src}, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = source_vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"V_RED_MAX f1, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fpram_addr}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_sum_asm(self, source_vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_dst, gp_loop = gp - lines = [f"; Tile Row Sum from VRAM[{source_vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {source_vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append("S_ADD_FP f1, f0, f0") - lines.append(f"V_RED_SUM f1, gp{gp_src}, 0, 0") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = source_vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append("S_ADD_FP f1, f0, f0") - lines.append(f"V_RED_SUM f1, gp{gp_src}, 0, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fpram_addr}") - lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_exp_asm(self, vram_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(2) - gp_src, gp_loop = gp - lines = [f"; Tile Row Exp on VRAM[{vram_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_EXP_V gp{gp_src}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"V_EXP_V gp{gp_src}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_reci_asm(self, vram_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(2) - gp_src, gp_loop = gp - lines = [f"; Tile Row Reciprocal on VRAM[{vram_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_RECI_V gp{gp_src}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"V_RECI_V gp{gp_src}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_sub_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_fp, gp_loop = gp - lines = [f"; Tile Row Sub FP on VRAM[{vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_SUB_VF gp{gp_src}, gp{gp_src}, f1, 0, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp{gp_fp}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fpram_addr}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_SUB_VF gp{gp_src}, gp{gp_src}, f1, 0, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_mul_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_fp, gp_loop = gp - lines = [f"; Tile Row Mul FP on VRAM[{vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_MUL_VF gp{gp_src}, gp{gp_src}, f1, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp{gp_fp}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fpram_addr}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_MUL_VF gp{gp_src}, gp{gp_src}, f1, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_add_fp_asm(self, vram_addr: int, row_map: list[tuple[int, int]]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_src, gp_fp, gp_loop = gp - lines = [f"; Tile Row Add FP on VRAM[{vram_addr}]"] - rows = [r for r, _ in row_map] - fp_addrs = [a for _, a in row_map] - row_prog = None if self.unroll_loops else self._arith_progression(rows) - fp_prog = None if self.unroll_loops else self._arith_progression(fp_addrs) - if row_prog is not None and fp_prog is not None: - row_start, row_count, row_step = row_prog - fp_start, _, fp_step = fp_prog - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fp_start}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_ADD_VF gp{gp_src}, gp{gp_src}, f1, 0") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp{gp_fp}, {fp_step}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx, fpram_addr in row_map: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") - lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {fpram_addr}") - lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") - lines.append(f"V_ADD_VF gp{gp_src}, gp{gp_src}, f1, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_add_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_dst, gp_src, gp_loop = gp - lines = [f"; Tile Row Add: VRAM[{dst_addr}] += VRAM[{src_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - d = dst_addr + row_idx * self.mlen - s = src_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {d}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {s}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_sub_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_dst, gp_src, gp_loop = gp - lines = [f"; Tile Row Sub: VRAM[{dst_addr}] -= VRAM[{src_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_SUB_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - d = dst_addr + row_idx * self.mlen - s = src_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {d}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {s}") - lines.append(f"V_SUB_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_mul_asm(self, dst_addr: int, src_addr: int, rows: list[int]) -> str: - gp = self.register_allocator.allocate_gp(3) - gp_dst, gp_src, gp_loop = gp - lines = [f"; Tile Row Mul: VRAM[{dst_addr}] *= VRAM[{src_addr}]"] - prog = None if self.unroll_loops else self._arith_progression(rows) - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_addr + row_start * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_MUL_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_step * self.mlen}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - d = dst_addr + row_idx * self.mlen - s = src_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {d}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {s}") - lines.append(f"V_MUL_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp) - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def tile_row_mul_fp_broadcast_asm(self, vram_addr: int, fpram_scalar_addr: int, rows: list[int]) -> str: - row_map = [(r, fpram_scalar_addr) for r in rows] - return self.tile_row_mul_fp_asm(vram_addr, row_map) - - def vram_fill_zero_asm( - self, - vram_addr: int, - rows: list[int], - ) -> str: - """ - VRAM Fill Zero: fill specified rows with 0. - - For each row_idx in rows: - VRAM[row] = 0 - """ - if not rows: - isa_code = f"; === VRAM Fill Zero: VRAM[{vram_addr}] rows [] = 0 ===\n" - self.generated_code += isa_code - return isa_code - - gp_regs = self.register_allocator.allocate_gp(2) - gp_dst, gp_loop = gp_regs - - lines = [] - lines.append(f"; === VRAM Fill Zero: VRAM[{vram_addr}] rows {rows} = 0 ===") - prog = None if self.unroll_loops else self._arith_progression(rows) - - if prog is not None: - row_start, row_count, row_step = prog - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {vram_addr + row_start * self.mlen}") - lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") - lines.append(f"V_MUL_VF gp{gp_dst}, gp{gp_dst}, f0, 0") - lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_step * self.mlen}") - lines.append(f"C_LOOP_END gp{gp_loop}") - else: - for row_idx in rows: - row_addr = vram_addr + row_idx * self.mlen - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {row_addr}") - lines.append(f"V_MUL_VF gp{gp_dst}, gp{gp_dst}, f0, 0") - - self.register_allocator.free_gp(gp_regs) - - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def reset_mram(self) -> str: - """ - Reset MRAM allocator, free all allocated space - Used in scenarios where sub-blocks need to be reloaded within a for loop - """ - self.mram_allocator.reset() - self.loaded_sub_blocks.clear() - - isa_code = "; === Reset MRAM ===\n" - self.generated_code += isa_code - return isa_code - - def load_sub_matrix_row( - self, - name: str, - row_idx: int, - mram_start_addr: int | None = None, - ) -> str: - """Load entire row sub-blocks from HBM to MRAM: matrix[row_idx][:].""" - layout = self.get_hbm_layout(name) - num_col_blocks = layout.num_col_blocks - block_size = self.mlen * self.mlen - - if mram_start_addr is None: - total_size = num_col_blocks * block_size - mram_start_addr = self.mram_allocator.allocate(f"{name}[{row_idx}][:]", total_size) - - gp_regs = self.register_allocator.allocate_gp(4) - gp_for_addr = self.register_allocator.allocate_gp(2) - addr_reg = self.register_allocator.allocate_addr(1)[0] - - isa_code = preload_addr_reg_asm( - addr_reg_to_set=[addr_reg], available_registers=gp_for_addr, addr_reg_val=[layout.hbm_base_addr] - ) - - isa_code += self.load_row_sub_matrices_asm( - name=name, row_idx=row_idx, mram_start_addr=mram_start_addr, hbm_addr_reg=addr_reg, gp_regs=gp_regs - ) - - self.register_allocator.free_gp(gp_regs) - self.register_allocator.free_gp(gp_for_addr) - self.register_allocator.free_addr([addr_reg]) - - self.generated_code += isa_code - return isa_code - - def load_sub_matrix_col( - self, - name: str, - col_idx: int, - mram_start_addr: int | None = None, - k_block_start: int = 0, - k_block_count: int | None = None, - ) -> str: - """ - Load entire column sub-blocks from HBM to MRAM: matrix[:][col_idx]. - Used for sub_projection: A @ W[:, col_idx*mlen:(col_idx+1)*mlen]. - """ - layout = self.get_hbm_layout(name) - num_row_blocks = layout.num_row_blocks - block_size = self.mlen * self.mlen - - if mram_start_addr is None: - effective_count = k_block_count if k_block_count is not None else num_row_blocks - total_size = effective_count * block_size - mram_start_addr = self.mram_allocator.allocate(f"{name}[:][{col_idx}]", total_size) - - gp_regs = self.register_allocator.allocate_gp(3) - gp_for_addr = self.register_allocator.allocate_gp(2) - addr_reg = self.register_allocator.allocate_addr(1)[0] - - isa_code = preload_addr_reg_asm( - addr_reg_to_set=[addr_reg], available_registers=gp_for_addr, addr_reg_val=[layout.hbm_base_addr] - ) - - isa_code += self.load_col_sub_matrices_asm( - name=name, - col_idx=col_idx, - mram_start_addr=mram_start_addr, - hbm_addr_reg=addr_reg, - gp_regs=gp_regs, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - - self.register_allocator.free_gp(gp_regs) - self.register_allocator.free_gp(gp_for_addr) - self.register_allocator.free_addr([addr_reg]) - - self.generated_code += isa_code - return isa_code - - def allocate_vram_matrix( - self, - name: str, - rows: int, - cols: int, - strict: bool = True, - ) -> int: - """Allocate a VRAM matrix large enough to hold combined results of multiple sub-blocks. Returns the VRAM base address.""" - size = rows * cols - vram_addr = self.vram_allocator.allocate(size, name=name) - - self.add_vram_object( - name=name, - shape=(rows, cols), - vram_addr=vram_addr, - dtype="fp32", - kind="VRAMMatrix", - allocate_if_none=False, - strict=strict, - ) - - isa_code = f"; Allocate VRAM Matrix {name}: ({rows}, {cols}) at VRAM[{vram_addr}]\n" - self.generated_code += isa_code - - return vram_addr - - def _ensure_vram_matrix_layout(self, matrix_name: str): - """Ensure a VRAM-resident tensor has a block layout in TileCompiler.""" - if matrix_name not in self: - raise KeyError(f"Matrix '{matrix_name}' not found in symbol table") - - info = self[matrix_name] - if info.vram_addr is None: - raise ValueError(f"Matrix '{matrix_name}' has no VRAM address") - - try: - self.get_vram_layout(matrix_name) - except KeyError: - self.register_vram_matrix( - name=matrix_name, - shape=info.shape, - vram_base_addr=info.vram_addr, - ) - - def vram_block_add_to( - self, - src1_matrix: str, - src1_row_idx: int, - src1_col_idx: int, - src2_matrix: str, - src2_row_idx: int, - src2_col_idx: int, - target_matrix: str, - target_row_idx: int, - target_col_idx: int, - ) -> str: - """ - mlen x mlen block add: - target[rt][ct] = src1[r1][c1] + src2[r2][c2] - - Source/target may be the same matrix (supports in-place overwrite). - """ - self._ensure_vram_matrix_layout(src1_matrix) - self._ensure_vram_matrix_layout(src2_matrix) - self._ensure_vram_matrix_layout(target_matrix) - - gp_regs = self.register_allocator.allocate_gp(4) - isa_code = self.vram_block_add_asm( - src1_name=src1_matrix, - src1_row_idx=src1_row_idx, - src1_col_idx=src1_col_idx, - src2_name=src2_matrix, - src2_row_idx=src2_row_idx, - src2_col_idx=src2_col_idx, - target_name=target_matrix, - target_row_idx=target_row_idx, - target_col_idx=target_col_idx, - gp_regs=gp_regs, - ) - self.register_allocator.free_gp(gp_regs) - - self.generated_code += isa_code - return isa_code - - def vram_matrix_add( - self, - dst_matrix: str, - src_matrix: str, - dst_row_offset: int = 0, - src_row_offset: int = 0, - num_rows: int | None = None, - ) -> str: - """ - General VRAM Matrix Addition: dst[row_offset:] += src. - - row_offsets are logical rows (not VRAM addresses); num_rows defaults - to the source matrix's row count. - """ - dst_info = self[dst_matrix] - src_info = self[src_matrix] - - # Block-add path depends on TileCompiler VRAM layouts. - self._ensure_vram_matrix_layout(dst_matrix) - self._ensure_vram_matrix_layout(src_matrix) - - dst_addr = dst_info.vram_addr - src_addr = src_info.vram_addr - - dst_rows, dst_cols = dst_info.shape - src_rows, src_cols = src_info.shape - - if num_rows is None: - num_rows = src_rows - - # Ensure column count matches - assert dst_cols == src_cols, f"Column mismatch: dst={dst_cols}, src={src_cols}" - assert dst_row_offset + num_rows <= dst_rows, ( - f"dst row range out of bounds: offset={dst_row_offset}, num_rows={num_rows}, dst_rows={dst_rows}" - ) - assert src_row_offset + num_rows <= src_rows, ( - f"src row range out of bounds: offset={src_row_offset}, num_rows={num_rows}, src_rows={src_rows}" - ) - lines = [] - lines.append( - f"; === VRAM Matrix Add: " - f"{dst_matrix}[{dst_row_offset}:{dst_row_offset + num_rows}] += " - f"{src_matrix}[{src_row_offset}:{src_row_offset + num_rows}] ===" - ) - lines.append(f"; dst shape: {dst_info.shape}, src shape: {src_info.shape}") - - # Prefer block add path so we can reuse the compact C_LOOP-based add kernel. - block_aligned = ( - dst_cols % self.mlen == 0 - and src_cols % self.mlen == 0 - and dst_row_offset % self.mlen == 0 - and src_row_offset % self.mlen == 0 - and num_rows % self.mlen == 0 - ) - - if block_aligned: - num_row_blocks = num_rows // self.mlen - num_col_blocks = dst_cols // self.mlen - dst_row_block_base = dst_row_offset // self.mlen - src_row_block_base = src_row_offset // self.mlen - lines.append(f"; block add path: row_blocks={num_row_blocks}, col_blocks={num_col_blocks}") - - for row_block in range(num_row_blocks): - for col_block in range(num_col_blocks): - gp_regs = self.register_allocator.allocate_gp(4) - lines.append( - self.vram_block_add_asm( - src1_name=dst_matrix, - src1_row_idx=dst_row_block_base + row_block, - src1_col_idx=col_block, - src2_name=src_matrix, - src2_row_idx=src_row_block_base + row_block, - src2_col_idx=col_block, - target_name=dst_matrix, - target_row_idx=dst_row_block_base + row_block, - target_col_idx=col_block, - gp_regs=gp_regs, - ).rstrip("\n") - ) - self.register_allocator.free_gp(gp_regs) - else: - # Fallback for non-mlen-aligned ranges. - gp_regs = self.register_allocator.allocate_gp(2) - gp_dst = gp_regs[0] - gp_src = gp_regs[1] - num_col_blocks = dst_cols // self.mlen - lines.append(f"; fallback row-wise path: num_rows={num_rows}, num_col_blocks={num_col_blocks}") - - for row in range(num_rows): - dst_actual_row = dst_row_offset + row - src_actual_row = src_row_offset + row - - for col_block in range(num_col_blocks): - dst_block_addr = dst_addr + col_block * dst_rows * self.mlen + dst_actual_row * self.mlen - src_block_addr = src_addr + col_block * src_rows * self.mlen + src_actual_row * self.mlen - - lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_block_addr}") - lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_block_addr}") - lines.append(f"V_ADD_VV gp{gp_dst}, gp{gp_dst}, gp{gp_src}, 0") - self.register_allocator.free_gp(gp_regs) - - isa_code = "\n".join(lines) + "\n" - self.generated_code += isa_code - return isa_code - - def vram_sub_projection_to( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_col_idx: int, - target_matrix: str, - target_row_idx: int, - target_col_idx: int, - k_block_start: int = 0, - k_block_count: int | None = None, - ) -> str: - """ - Sub-block multiplication: - target[target_row_idx][target_col_idx] = VRAM_A[vram_row_idx][:] @ MRAM_W[:][mram_col_idx]. - Target matrix must have been allocated via allocate_vram_matrix. - """ - if target_matrix not in self: - raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") - - target_info = self[target_matrix] - target_rows, _target_cols = target_info.shape - target_base_addr = target_info.vram_addr - - # VRAM layout: [batch, mlen, hidden/mlen], column-block major. - # Sub-block (r, c) addr = base + c * rows * mlen + r * mlen * mlen. - result_vram_addr = ( - target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen - ) - - gp_regs = self.register_allocator.allocate_gp(9) - - isa_code = f"; VRAM Sub Projection To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[:][{mram_col_idx}] -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" - isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" - isa_code += self.vram_sub_projection_asm( - vram_mat_name=vram_mat_name, - vram_row_idx=vram_row_idx, - mram_mat_name=mram_mat_name, - mram_col_idx=mram_col_idx, - result_vram_addr=result_vram_addr, - gp_regs=gp_regs, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - - self.register_allocator.free_gp(gp_regs) - - self.generated_code += isa_code - return isa_code - - def vram_sub_projection_T_to( - self, - vram_mat_name: str, - vram_row_idx: int, - mram_mat_name: str, - mram_row_idx: int, - target_matrix: str, - target_row_idx: int, - target_col_idx: int, - ) -> str: - """ - Transposed sub-block multiplication: - target[target_row_idx][target_col_idx] = VRAM_A[vram_row_idx][:] @ MRAM_W[mram_row_idx][:]^T. - - Used by Flash Attention for S = Q @ K^T: - Q[i][:]: (mlen, hidden_size) row sub-block - K[j][:]: (mlen, hidden_size) row sub-block, transposed to (hidden_size, mlen) - S[i][j]: (mlen, mlen) - """ - if target_matrix not in self: - raise KeyError(f"Target matrix '{target_matrix}' not found. Use allocate_vram_matrix first.") - - target_info = self[target_matrix] - target_rows, _target_cols = target_info.shape - target_base_addr = target_info.vram_addr - - # VRAM layout: [batch, mlen, hidden/mlen], column-block major. - # Sub-block (r, c) addr = base + c * rows * mlen + r * mlen * mlen. - result_vram_addr = ( - target_base_addr + target_col_idx * target_rows * self.mlen + target_row_idx * self.mlen * self.mlen - ) - - gp_regs = self.register_allocator.allocate_gp(9) - - isa_code = f"; VRAM Sub Projection T To: {vram_mat_name}[{vram_row_idx}][:] @ {mram_mat_name}[{mram_row_idx}][:]^T -> {target_matrix}[{target_row_idx}][{target_col_idx}]\n" - isa_code += f"; Target VRAM addr: {result_vram_addr} (base={target_base_addr}, offset=col*{target_rows}*{self.mlen} + row*{self.mlen}*{self.mlen})\n" - isa_code += self.vram_sub_projection_T_asm( - vram_mat_name=vram_mat_name, - vram_row_idx=vram_row_idx, - mram_mat_name=mram_mat_name, - mram_row_idx=mram_row_idx, - result_vram_addr=result_vram_addr, - gp_regs=gp_regs, - ) - - self.register_allocator.free_gp(gp_regs) - - self.generated_code += isa_code - return isa_code - - # ========================================================================= - # Expanded Flash Attention Operations - # ========================================================================= - - def init_online_softmax( - self, - q_idx: int, - o_matrix: str, - seq_len: int, - head_dim: int, - ) -> str: - """ - Initialize Online Softmax state for Q block q_idx: - m_old = -inf (FP SRAM), l = 0 (FP SRAM), O_row = 0 (VRAM). - """ - fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE - m_old_addr = fp_sram_start - l_addr = fp_sram_start + 2 * self.mlen # skip m_res region - - o_info = self[o_matrix] - o_vram_addr = o_info.vram_addr - row_offset = q_idx * self.mlen - - isa_code = f"; === Init Online Softmax for Q block {q_idx} ===\n" - - isa_code += self._reset_fpsram_asm(m_old_addr, self.mlen, 2) # slot 2 = -inf - isa_code += self._reset_fpsram_asm(l_addr, self.mlen, 0) # slot 0 = 0.0 - isa_code += self._reset_vram_asm( - start_address=o_vram_addr, - rows=self.mlen, - cols=head_dim, - total_rows=seq_len, - mlen=self.mlen, - row_offset=row_offset, - ) - - self.generated_code += isa_code - return isa_code - - def online_softmax_block( - self, - s_block_matrix: str, - scale: float, - ) -> str: - """ - Run Online Softmax on one S block. - Input: S_block (mlen × mlen) in VRAM - Output: P (mlen × mlen) in-place in VRAM - Updates: m_old, m_res, l in FP SRAM - ``scale`` is the QK^T scaling factor (typically 1/sqrt(d)). - """ - s_info = self[s_block_matrix] - s_address = s_info.vram_addr - - fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE - m_start_address = fp_sram_start - - isa_code = f"; === Online Softmax Block {s_block_matrix} ===\n" - isa_code += self._online_softmax_asm( - mlen=self.mlen, s_address=s_address, m_start_address=m_start_address, scale=scale - ) - - self.generated_code += isa_code - return isa_code - - def compute_pv( - self, - s_block_matrix: str, - v_sub_matrix: str, - k_idx: int, - pv_matrix: str, - head_dim: int, - ) -> str: - """ - Compute PV = P @ V[k_idx]. - - P lives in s_block_matrix (softmax result); V is prefetched from - HBM; PV is written to VRAM via pv_matrix. - """ - s_info = self[s_block_matrix] - p_address = s_info.vram_addr - - pv_info = self[pv_matrix] - pv_address = pv_info.vram_addr - - v_layout = self.get_hbm_layout(v_sub_matrix) - v_hbm_offset = k_idx * self.mlen * head_dim - - isa_code = f"; === Compute PV = P @ V[k_idx={k_idx}] ===\n" - - addr_regs = self.register_allocator.allocate_addr(1) - v_hbm_reg = addr_regs[0] - gp_regs = self.register_allocator.allocate_gp(2) - - from compiler.asm_templates import preload_addr_reg_asm - - isa_code += preload_addr_reg_asm( - addr_reg_to_set=[v_hbm_reg], available_registers=gp_regs, addr_reg_val=[v_layout.hbm_base_addr] - ) - - isa_code += self._pv_multiply_asm( - mlen=self.mlen, - blen=self.blen, - head_dim=head_dim, - p_address=p_address, - v_hbm_offset_reg=v_hbm_reg, - v_hbm_offset=v_hbm_offset, - pv_address=pv_address, - ) - - self.register_allocator.free_gp(gp_regs) - self.register_allocator.free_addr(addr_regs) - - self.generated_code += isa_code - return isa_code - - def scale_o_row( - self, - o_matrix: str, - q_idx: int, - seq_len: int, - head_dim: int, - ) -> str: - """Scale the current row block of O by m_res: O[q_idx] *= m_res.""" - o_info = self[o_matrix] - o_address = o_info.vram_addr - - fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE - m_res_addr = fp_sram_start + self.mlen - - row_offset = q_idx * self.mlen - - isa_code = f"; === Scale O[q_idx={q_idx}] by m_res ===\n" - isa_code += self._scale_o_asm( - mlen=self.mlen, - head_dim=head_dim, - seq_len=seq_len, - m_res_address=m_res_addr, - o_address=o_address, - row_offset=row_offset, - ) - - self.generated_code += isa_code - return isa_code - - def final_scale_o( - self, - q_idx: int, - o_matrix: str, - seq_len: int, - head_dim: int, - ) -> str: - """Final scaling: O[q_idx] /= l.""" - o_info = self[o_matrix] - o_address = o_info.vram_addr - - fp_sram_start = self._ONLINE_SOFTMAX_FPSRAM_BASE - l_addr = fp_sram_start + 2 * self.mlen - - row_offset = q_idx * self.mlen - - isa_code = f"; === Final Scale O for Q block {q_idx} ===\n" - isa_code += self._final_scaling_asm( - mlen=self.mlen, - head_dim=head_dim, - seq_len=seq_len, - l_address=l_addr, - o_address=o_address, - row_offset=row_offset, - ) - - self.generated_code += isa_code - return isa_code - - -# Example Usage - - -class _DeveloperView: - """ - Back-compat proxy for legacy ``prog._compiler.X(...)`` call sites. - - PlenaCompiler now inherits DeveloperCompiler rather than composing it, so - for call sites that still expect to reach the low-level DeveloperCompiler - API (e.g., ``allocate_fpram(name=..., size=...)`` returning an int), we - expose a proxy whose attribute lookup resolves callables on - DeveloperCompiler directly (bypassing any PlenaCompiler overrides with - colliding names). Non-callable attributes (e.g., ``generated_code``, - ``vram_allocator``) fall through to the underlying instance unchanged. - """ - - __slots__ = ("_inst",) - - def __init__(self, inst: PlenaCompiler): - object.__setattr__(self, "_inst", inst) - - def __getattr__(self, name: str): - cls_attr = getattr(DeveloperCompiler, name, None) - if cls_attr is not None and callable(cls_attr): - return cls_attr.__get__(self._inst, DeveloperCompiler) - return getattr(self._inst, name) - - def __setattr__(self, name: str, value): - setattr(self._inst, name, value) - - -# ============================================================================ -# TensorVar Proxy Object Hierarchy -# ============================================================================ - - -class TensorVar: - """ - Tensor proxy object base class - - All tensor variables inherit from this class. - Supports __matmul__ (`@`) operator, which automatically dispatches to appropriate PlenaCompiler methods. - - Dual naming: - - display_name: User-visible name (e.g., "temp", "Q", "S") - - internal_name: System internal name (e.g., "my_func_0/temp"), used for symbol table and ISA generation - """ - - def __init__( - self, - program: PlenaCompiler, - internal_name: str, - kind: str, - shape: tuple[int, int], - display_name: str | None = None, - ): - self._program = program - self.internal_name = internal_name # System internal name (with scope prefix), used by symbol table - self.display_name = display_name if display_name is not None else internal_name # User-visible name - self.kind = kind # "input", "batch", "matrix", "vram_matrix" - self.shape = shape - - @property - def name(self) -> str: - """Compatibility property: returns internal_name for internal system use""" - return self.internal_name - - def __matmul__(self, other): - """A @ B: Dispatch to appropriate computation based on operand types""" - return self._program._dispatch_matmul(self, other) - - def __repr__(self): - if self.display_name != self.internal_name: - return ( - f"{self.__class__.__name__}(display={self.display_name!r}, " - f"internal={self.internal_name!r}, shape={self.shape})" - ) - return f"{self.__class__.__name__}({self.display_name!r}, shape={self.shape})" - - -class InputVar(TensorVar): - """ - Input variable: tensor declared in HBM - - Not yet loaded to VRAM; needs to be loaded via load_batch / load_matrix. - - If ``prestaged_vram_addr`` is not None the tensor is assumed to be already - present in VRAM at that byte address. ``load_batch`` will register it at - that address without emitting any HBM→VRAM prefetch instructions. - """ - - def __init__( - self, - program: PlenaCompiler, - name: str, - shape: tuple[int, int], - hbm_addr: int, - hbm_size: int, - display_name: str | None = None, - prestaged_vram_addr: int | None = None, - ): - super().__init__(program, name, "input", shape, display_name=display_name) - self.hbm_addr = hbm_addr - self.hbm_size = hbm_size - self.prestaged_vram_addr = prestaged_vram_addr - - -class FPVar: - """ - FP variable: maps to a contiguous region in FPRAM - - Declared via prog.fp_var("scale", size=1), automatically allocates FPRAM space. - Provides .address for ISA generation (S_LD_FP / S_ST_FP). - - Usage: - scale = prog.fp_var("scale", size=1) - m_old = prog.fp_var("m_old", size=64) - - scale.address # -> FPRAM address (int) - scale.size # -> number of elements - scale[3] # -> address + 3 (element offset) - """ - - def __init__( - self, program: PlenaCompiler, internal_name: str, address: int, size: int, display_name: str | None = None - ): - self._program = program - self.internal_name = internal_name - self.display_name = display_name if display_name is not None else internal_name - self.address = address - self.size = size - - @property - def name(self) -> str: - return self.internal_name - - def __getitem__(self, idx: int) -> int: - """Element offset: fp_var[i] -> address + i""" - if idx < 0 or idx >= self.size: - raise IndexError(f"FPVar '{self.display_name}' index {idx} out of range [0, {self.size})") - return self.address + idx - - def __repr__(self): - return f"FPVar({self.display_name!r}, addr={self.address}, size={self.size})" - - -class VRAMMatrixVar(TensorVar): - """ - VRAM matrix variable: large matrix allocated via alloc - - Used to store intermediate results (e.g., S block, PV, O). - Supports sub-block indexed writes: `O[r][c] = ...` - """ - - def __init__(self, program: PlenaCompiler, name: str, shape: tuple[int, int], display_name: str | None = None): - super().__init__(program, name, "vram_matrix", shape, display_name=display_name) - - -# ============================================================================ -# PlenaCompiler Main Class -# ============================================================================ - - -class PlenaCompiler(DeveloperCompiler): - """ - PLENA High-level Compiler Interface. - - Inherits the ISA-emission machinery from DeveloperCompiler and layers a - Pythonic DSL on top. All operations are eagerly evaluated — ISA code is - generated immediately upon call. - """ - - def __init__(self, mlen: int = 64, blen: int = 4, real_data_ratio: float = 1.125, unroll_loops: bool = False): - """ - Args: - mlen: Matrix tile size (default 64) - blen: Vector tile size (default 4) - real_data_ratio: HBM storage ratio (MXFP8 format = 1.125) - unroll_loops: If True, unroll sub-projection loops at ASM-gen time to - eliminate C_LOOP_START/END overhead. Overridden by the - ATEN_UNROLL env var ("1"=True, "0"=False). - """ - _env_unroll = os.environ.get("ATEN_UNROLL", "") - if _env_unroll == "1": - unroll_loops = True - elif _env_unroll == "0": - unroll_loops = False - super().__init__(mlen=mlen, blen=blen, real_data_ratio=real_data_ratio, unroll_loops=unroll_loops) - - # HBM address auto-allocation - self._next_hbm_addr: int = 0 - self._hbm_free_blocks: list[tuple[int, int]] = [] # (addr, size) - - # Variable registries - self._inputs: dict[str, InputVar] = {} - self._tensors: dict[str, TensorVar] = {} - self._fp_vars: dict[str, FPVar] = {} - self._functions: dict[str, Callable] = {} - self._registered_hbm_sub_matrices: dict[str, bool] = {} - self._registered_vram_sub_matrices: dict[str, bool] = {} - - self._result_tensor: TensorVar | None = None - - # Auto-generated name counter - self._auto_name_counter: int = 0 - - # Function scope namespace - # Push a prefix on each function call (e.g., "linear_0/"), pop on exit - # _auto_name will automatically add current scope prefix, avoiding name conflicts when calling the same function multiple times - self._scope_stack: list[str] = [] - self._function_call_counters: dict[str, int] = {} # func_name -> call count - - # ======================================================================== - # Property Access - # ======================================================================== - - # mlen / blen are instance attributes inherited from TileCompiler.__init__. - - @property - def compiler(self) -> PlenaCompiler: - """Legacy accessor — returns self now that PlenaCompiler is the compiler.""" - return self - - @property - def _compiler(self) -> _DeveloperView: - """Back-compat shim for legacy ``prog._compiler.X(...)`` call sites. - Returns a proxy that resolves callables against DeveloperCompiler - directly so callers reach the low-level API regardless of any - PlenaCompiler override with the same name.""" - return _DeveloperView(self) - - @property - def symbol_table(self): - """Access symbol table.""" - return self.get_symbol_table() - - # ======================================================================== - # Input Declaration - # ======================================================================== - - def input( - self, - name: str, - shape: tuple[int, int], - hbm_addr: int | None = None, - prestaged_vram_addr: int | None = None, - ) -> InputVar: - """ - Declare an input tensor (in HBM). - - Args: - name: tensor name - shape: (height, width) - hbm_addr: HBM address (None = auto-allocate) - prestaged_vram_addr: If an int, the tensor is assumed to be already - present in VRAM at this byte address. A subsequent call to - ``load_batch`` will register it at that address without emitting - any HBM→VRAM prefetch instructions. If None (default), the - normal HBM→VRAM load path is used. - - Returns: - InputVar proxy object - """ - h, w = shape - size = h * w - hbm_size = int(size * self.real_data_ratio) - - if hbm_addr is None: - hbm_addr = self._allocate_hbm(hbm_size) - - var = InputVar(self, name, shape, hbm_addr, hbm_size, prestaged_vram_addr=prestaged_vram_addr) - self._inputs[name] = var - super().add_hbm_object( - name=name, - hbm_addr=hbm_addr, - shape=shape, - real_data_ratio=self.real_data_ratio, - ) - return var - - # ======================================================================== - # Load Operations - # ======================================================================== - - def load_batch( - self, - input_var: InputVar, - name: str | None = None, - ) -> VRAMMatrixVar: - """ - Load tensor from HBM to VRAM (Batch type). - - When ``input_var.prestaged_vram_addr`` is set the tensor is assumed to - be already resident in VRAM at that address. No HBM→VRAM prefetch - instructions are emitted; the tensor is simply registered in the symbol - table at the given address. - - Args: - input_var: source InputVar - name: result name (None = use input name) - - Returns: - VRAMMatrixVar proxy object - """ - if not isinstance(input_var, InputVar): - raise TypeError(f"Expected InputVar, got {type(input_var)}") - - display_name = name if name is not None else input_var.display_name - internal_name = self._scoped_name(display_name) - - if input_var.prestaged_vram_addr is not None: - # Prestaged path: tensor is already in VRAM — register without ISA. - h, w = input_var.shape - vram_addr = input_var.prestaged_vram_addr - # Tell the VRAM allocator that this region is occupied so subsequent - # allocations don't collide with it. - self.vram_allocator._vmm.mark_used(vram_addr, h * w, name=internal_name) - super().add_vram_object( - name=internal_name, - shape=(h, w), - vram_addr=vram_addr, - dtype="fp16", - kind="Batch", - allocate_if_none=False, - strict=False, - ) - else: - # Normal path: emit HBM → VRAM prefetch ISA. - super().load_batch( - hbm_object_name=input_var.name, vram_object_name=internal_name, vlen=self.mlen, preload_len=4 - ) - - var = VRAMMatrixVar(self, internal_name, input_var.shape, display_name=display_name) - self._tensors[internal_name] = var - return var - - # ======================================================================== - # Store Operations - # ======================================================================== - - def store(self, tensor_var, name: str | None = None, hbm_addr: int | None = None) -> InputVar: - """ - Write tensor from VRAM back to HBM. - - Returns: - InputVar proxy object (can be loaded back later) - """ - if not isinstance(tensor_var, VRAMMatrixVar): - raise TypeError(f"Store requires VRAMMatrixVar, got {type(tensor_var)}") - - display_name = name if name is not None else f"{tensor_var.display_name}_stored" - internal_name = self._scoped_name(display_name) - - if hbm_addr is None: - h, w = tensor_var.shape - size = h * w - hbm_size = int(size * self.real_data_ratio) - hbm_addr = self._allocate_hbm(hbm_size) - else: - h, w = tensor_var.shape - hbm_size = int(h * w * self.real_data_ratio) - - super().store_to_hbm( - tensor_name=tensor_var.name, # internal name for symbol table lookup - hbm_addr=hbm_addr, - hbm_object_name=internal_name, - vlen=self.mlen, - ) - - var = InputVar(self, internal_name, tensor_var.shape, hbm_addr, hbm_size, display_name=display_name) - self._inputs[internal_name] = var - return var - - # ======================================================================== - # VRAM Matrix Allocation - # ======================================================================== - - def alloc(self, name: str, rows: int, cols: int, strict: bool = True) -> VRAMMatrixVar: - """ - Allocate a VRAM matrix. - - Used to store intermediate results (e.g., S block, PV, O). - Within function scope, names are automatically prefixed to avoid conflicts. - - Args: - name: matrix name (user-visible) - rows: number of rows - cols: number of columns - strict: if False, skip mlen-alignment checks (for small scratch matrices) - - Returns: - VRAMMatrixVar proxy object - """ - display_name = name - internal_name = self._scoped_name(name) - super().allocate_vram_matrix(name=internal_name, rows=rows, cols=cols, strict=strict) - - var = VRAMMatrixVar(self, internal_name, (rows, cols), display_name=display_name) - self._tensors[internal_name] = var - return var - - def alloc_at(self, name: str, rows: int, cols: int, vram_addr: int) -> VRAMMatrixVar: - """Allocate a VRAM matrix view at a specific address. - - Used to create views into existing VRAM matrices (e.g., per-head - slices of a multi-head Q projection output). Does NOT bump the - VRAM allocator -- the caller is responsible for ensuring the region - is valid. - - Args: - name: matrix name (user-visible) - rows: number of rows - cols: number of columns - vram_addr: absolute VRAM address for this view - - Returns: - VRAMMatrixVar proxy object - """ - display_name = name - internal_name = self._scoped_name(name) - self._compiler.add_vram_object( - name=internal_name, - shape=(rows, cols), - vram_addr=vram_addr, - allocate_if_none=False, - ) - isa_code = f"; VRAM View {name}: ({rows}, {cols}) at VRAM[{vram_addr}]\n" - self._compiler.generated_code += isa_code - var = VRAMMatrixVar(self, internal_name, (rows, cols), display_name=display_name) - self._tensors[internal_name] = var - return var - - def free_tensor(self, tensor_var: TensorVar): - """ - Free a tensor in VRAM, reclaiming space for subsequent allocations. - - Freed space can be reused by new alloc() or other operations. - """ - if not isinstance(tensor_var, VRAMMatrixVar): - raise TypeError(f"Can only free VRAMMatrixVar, got {type(tensor_var)}") - - super().free_vram_object(tensor_var.name, strict=False) - # Keep sub-matrix registration state consistent after free. - self._registered_vram_sub_matrices[tensor_var.name] = False - - def free_input(self, input_var: InputVar): - """ - Free an InputVar bookkeeping and recycle its HBM range for future auto-allocation. - - Notes: - - This only affects PlenaCompiler's address management state. - - If a freed input is referenced again later, caller is responsible for correctness. - """ - if not isinstance(input_var, InputVar): - raise TypeError(f"Can only free InputVar, got {type(input_var)}") - - super().free_hbm_object(input_var.name, strict=False) - self._registered_hbm_sub_matrices[input_var.name] = False - self._recycle_hbm(input_var.hbm_addr, input_var.hbm_size) - self._inputs.pop(input_var.name, None) - - def free_fp_var(self, fp_var: FPVar): - """ - Free an FPVar and return its block to FPRAM free pool. - """ - if not isinstance(fp_var, FPVar): - raise TypeError(f"Can only free FPVar, got {type(fp_var)}") - self.free_fpram(fp_var.name, strict=True) - - # ======================================================================== - # Normalization Operations - # ======================================================================== - - def norm( - self, - tensor_var: TensorVar, - mode: str = "rms", - eps_offset: int = 1, - reci_hid_offset: int = 2, - vlen: int | None = None, - scratchpad_vram_addr: int | None = None, - ) -> TensorVar: - """ - Normalize tensor in-place. - - Args: - tensor_var: tensor to normalize (must have VRAM backing, e.g., VRAMMatrixVar) - mode: "rms" or "layer" - eps_offset: FPRAM address of epsilon - reci_hid_offset: FPRAM address of 1/hidden_dim - vlen: vector length (default: program mlen) - scratchpad_vram_addr: optional scratchpad VRAM address - - Returns: - The same tensor_var (in-place operation) - """ - if not isinstance(tensor_var, VRAMMatrixVar): - raise TypeError(f"norm requires VRAMMatrixVar, got {type(tensor_var)}") - - super().normalize( - tensor_name=tensor_var.name, - mode=mode, - eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - vlen=vlen, - scratchpad_vram_addr=scratchpad_vram_addr, - ) - return tensor_var - - def rms_norm( - self, - tensor_var: TensorVar, - eps_offset: int = 1, - reci_hid_offset: int = 2, - vlen: int | None = None, - scratchpad_vram_addr: int | None = None, - ) -> TensorVar: - """RMS normalization (in-place).""" - return self.norm( - tensor_var=tensor_var, - mode="rms", - eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - vlen=vlen, - scratchpad_vram_addr=scratchpad_vram_addr, - ) - - def layer_norm( - self, - tensor_var: TensorVar, - eps_offset: int = 1, - reci_hid_offset: int = 2, - vlen: int | None = None, - scratchpad_vram_addr: int | None = None, - ) -> TensorVar: - """Layer normalization (in-place).""" - return self.norm( - tensor_var=tensor_var, - mode="layer", - eps_offset=eps_offset, - reci_hid_offset=reci_hid_offset, - vlen=vlen, - scratchpad_vram_addr=scratchpad_vram_addr, - ) - - # ======================================================================== - # FP Variable (FPRAM) - # ======================================================================== - - def allocate_fpram( - self, - internal_name: str, - size: int = 1, - display_name: str | None = None, - ) -> FPVar: - """ - Allocate FPRAM with explicit internal name and return FPVar proxy. - """ - if size <= 0: - raise ValueError(f"FPRAM allocation size must be positive, got {size}") - - address = super().allocate_fpram(internal_name, size) - var = FPVar( - self, - internal_name, - address, - size, - display_name=display_name if display_name is not None else internal_name, - ) - self._fp_vars[internal_name] = var - return var - - def free_fpram(self, internal_name: str, strict: bool = True): - """ - Free FPRAM allocation by internal name. - """ - super().free_fpram(internal_name, strict=strict) - self._fp_vars.pop(internal_name, None) - - def fp_var(self, name: str, size: int = 1) -> FPVar: - """ - Declare an FP variable in FPRAM. - - Allocates a contiguous region in FPRAM and returns an FPVar proxy. - Within function scope, names are automatically prefixed. - - Args: - name: variable name - size: number of f16 elements to allocate (default 1) - - Returns: - FPVar proxy object (use .address for ISA generation) - - Example: - scale = prog.fp_var("scale") # 1 element - m_old = prog.fp_var("m_old", size=64) # 64 elements - prog.compiler # access compiler for ISA if needed - """ - display_name = name - internal_name = self._scoped_name(name) - - return self.allocate_fpram( - internal_name=internal_name, - size=size, - display_name=display_name, - ) - - def save_fpram_state(self) -> int: - """Save FPRAM allocator snapshot""" - return super().save_fpram_state() - - def restore_fpram_state(self, snapshot: int): - """Restore FPRAM allocator snapshot""" - super().restore_fpram_state(snapshot) - # Remove FPVar proxies that are no longer allocated in allocator. - allocations = set(super().list_fpram_allocations()) - to_remove = [n for n in self._fp_vars if n not in allocations] - for n in to_remove: - del self._fp_vars[n] - - # ======================================================================== - # FPRAM Tile Operations - # ======================================================================== - - def _resolve_fpram_addr(self, addr_or_var: int | FPVar, offset: int = 0) -> int: - if isinstance(addr_or_var, FPVar): - if offset < 0 or offset >= addr_or_var.size: - raise ValueError( - f"FPVar offset out of range: offset={offset}, size={addr_or_var.size}, var={addr_or_var.name}" - ) - return addr_or_var.address + offset - if not isinstance(addr_or_var, int): - raise TypeError(f"Expected int or FPVar, got {type(addr_or_var)}") - return addr_or_var + offset - - def _resolve_rows( - self, - row_idx: int | None = None, - rows: list[int] | None = None, - ) -> list[int]: - if row_idx is not None and rows is not None: - raise ValueError("Provide either row_idx or rows, not both") - if rows is not None: - return rows - if row_idx is not None: - return [row_idx] - return list(range(self.mlen)) - - def tile_row_max( - self, - target_fpram_addr: int | FPVar, - source: VRAMMatrixVar, - row_idx: int | None = None, - rows: list[int] | None = None, - target_offset: int = 0, - target_base_offset: int = 0, - ): - """ - Tile Row Max: reduce a single row to max, store to FPRAM address. - - Args: - target_fpram_addr: FPRAM address or FPVar to write result - source: VRAM tile (mlen x mlen) - row_idx: single row index (legacy path) - rows: multiple row indices - target_offset: element offset when target_fpram_addr is FPVar - target_base_offset: base offset for multi-row writes (contiguous) - - Example: - m = prog.fp_var("m", size=1) - S = prog.alloc("S", 64, 64) - for row in range(64): - prog.tile_row_max(m, S, rows=list(range(64))) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [target_offset] - else: - offsets = [target_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(target_fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_max( - source_matrix=source.name, - row_map=row_map, - ) - - def tile_row_sum( - self, - target_fpram_addr: int | FPVar, - source: VRAMMatrixVar, - row_idx: int | None = None, - rows: list[int] | None = None, - target_offset: int = 0, - target_base_offset: int = 0, - ): - """ - Tile Row Sum: reduce a single row to sum, store to FPRAM address. - - Args: - target_fpram_addr: FPRAM address or FPVar to write result - source: VRAM tile (mlen x mlen) - row_idx: single row index (legacy path) - rows: multiple row indices - target_offset: element offset when target_fpram_addr is FPVar - target_base_offset: base offset for multi-row writes (contiguous) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [target_offset] - else: - offsets = [target_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(target_fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_sum(source.name, row_map) - - def tile_row_exp( - self, - source: VRAMMatrixVar, - row_idx: int | None = None, - rows: list[int] | None = None, - ): - """ - Tile Row Exp: in-place exp on specified rows. - - For each row i: source[i, :] = exp(source[i, :]) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - super().tile_row_exp(source.name, resolved_rows) - - def tile_row_reci( - self, - source: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Reciprocal: in-place 1/x on specified rows. - - For each row i: source[i, :] = 1.0 / source[i, :] - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_reci(source.name, rows) - - def tile_row_sub_fp( - self, - source: VRAMMatrixVar, - fpram_addr: int | FPVar, - row_idx: int | None = None, - rows: list[int] | None = None, - fpram_offset: int = 0, - fpram_base_offset: int = 0, - ): - """ - Tile Row Sub FP: subtract FPRAM scalar from a single row. - - For row i: source[i, :] = source[i, :] - FPRAM[fpram_addr] - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [fpram_offset] - else: - offsets = [fpram_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_sub_fp(source.name, row_map) - - def tile_row_mul_fp( - self, - source: VRAMMatrixVar, - fpram_addr: int | FPVar, - row_idx: int | None = None, - rows: list[int] | None = None, - fpram_offset: int = 0, - fpram_base_offset: int = 0, - ): - """ - Tile Row Mul FP: multiply a single row by FPRAM scalar. - - For row i: source[i, :] = source[i, :] * FPRAM[fpram_addr] - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - if len(resolved_rows) == 1: - offsets = [fpram_offset] - else: - offsets = [fpram_base_offset + i for i in range(len(resolved_rows))] - row_map = [(r, self._resolve_fpram_addr(fpram_addr, off)) for r, off in zip(resolved_rows, offsets)] - super().tile_row_mul_fp(source.name, row_map) - - def tile_row_add_fp( - self, - source: VRAMMatrixVar, - fp_var: FPVar, - rows: list[int] | None = None, - ): - """ - Tile Row Add FP: add FPRAM scalar to each specified row. - - For each row i: source[i, :] = source[i, :] + fp_var[i] - """ - if rows is None: - rows = list(range(self.mlen)) - row_map = [(r, fp_var[r]) for r in rows] - super().tile_row_add_fp(source.name, row_map) - - def tile_row_add( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Add: dst[i, :] += src[i, :] for specified rows. - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_add(dst.name, src.name, rows) - - def tile_row_sub( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Sub: dst[i, :] -= src[i, :] for specified rows. - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_sub(dst.name, src.name, rows) - - def tile_row_mul( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - Tile Row Mul: dst[i, :] *= src[i, :] for specified rows. - """ - if rows is None: - rows = list(range(self.mlen)) - super().tile_row_mul(dst.name, src.name, rows) - - def fpvar_reci( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Reciprocal: compute 1/x for FPRAM scalar array. - - For each element i: dst[i] = 1.0 / src[i] - - Args: - src: source FPVar - dst: destination FPVar - count: number of elements (default: min(src.size, dst.size)) - - Example: - l = prog.fp_var("l", size=64) - inv_l = prog.fp_var("inv_l", size=64) - prog.fpvar_reci(l, inv_l) # inv_l = 1/l - """ - if count is None: - count = min(src.size, dst.size) - if count > src.size or count > dst.size: - raise ValueError(f"count={count} exceeds FPVar size: src.size={src.size}, dst.size={dst.size}") - super().fpram_reci(src.name, dst.name, count) - - def fpvar_max( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Max: element-wise max for FPRAM scalar arrays. - - For each element i: dst[i] = max(src1[i], src2[i]) - - Example: - m_new = prog.fp_var("m_new", size=64) - prog.fpvar_max(m_old, row_max, m_new) # m_new = max(m_old, row_max) - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_max(src1.name, src2.name, dst.name, count) - - def fpvar_sub( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Subtract: element-wise subtraction for FPRAM scalar arrays. - - For each element i: dst[i] = src1[i] - src2[i] - - Example: - diff = prog.fp_var("diff", size=64) - prog.fpvar_sub(m_old, m_new, diff) # diff = m_old - m_new - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_sub(src1.name, src2.name, dst.name, count) - - def fpvar_exp( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Exp: element-wise exp for FPRAM scalar array. - - For each element i: dst[i] = exp(src[i]) - - Example: - m_res = prog.fp_var("m_res", size=64) - prog.fpvar_exp(diff, m_res) # m_res = exp(diff) - """ - if count is None: - count = min(src.size, dst.size) - super().fpram_exp(src.name, dst.name, count) - - def fpvar_mul( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Multiply: element-wise multiplication for FPRAM scalar arrays. - - For each element i: dst[i] = src1[i] * src2[i] - - Example: - result = prog.fp_var("result", size=64) - prog.fpvar_mul(l_old, m_res, result) # result = l_old * m_res - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_mul(src1.name, src2.name, dst.name, count) - - def fpvar_add( - self, - src1: FPVar, - src2: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Add: element-wise addition for FPRAM scalar arrays. - - For each element i: dst[i] = src1[i] + src2[i] - - Example: - l_new = prog.fp_var("l_new", size=64) - prog.fpvar_add(l_old, sum_p, l_new) # l_new = l_old + sum_p - """ - if count is None: - count = min(src1.size, src2.size, dst.size) - super().fpram_add(src1.name, src2.name, dst.name, count) - - def fpvar_copy( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Copy: copy FPRAM scalar array. - - For each element i: dst[i] = src[i] - - Example: - m_old_saved = prog.fp_var("m_old_saved", size=64) - prog.fpvar_copy(m_old, m_old_saved) # backup m_old - """ - if count is None: - count = min(src.size, dst.size) - super().fpram_copy(src.name, dst.name, count) - - def fpvar_sum( - self, - src: FPVar, - dst: FPVar, - count: int | None = None, - ): - """ - FPVar Sum: reduction sum of src into dst[0] (via compiler FPRAM op). - """ - if count is None: - count = src.size - super().fpram_sum(src.name, dst.name, count) - - def fpvar_shift( - self, - src: FPVar, - dst: FPVar, - shift: int, - count: int | None = None, - fill: FPVar | None = None, - ): - """ - FPVar Shift: shift src into dst, filling out-of-range slots with fill (default FPRAM zero). - """ - if count is None: - count = min(src.size, dst.size) - fill_name = None if fill is None else fill.name - super().fpram_shift( - src_name=src.name, - dst_name=dst.name, - shift=shift, - count=count, - fill_fpram_name=fill_name, - ) - - def tile_row_mul_fp_broadcast( - self, - source: VRAMMatrixVar, - fpram_scalar_addr: int | FPVar, - row_idx: int | None = None, - rows: list[int] | None = None, - fpram_offset: int = 0, - ): - """ - Tile Row Mul FP Broadcast: multiply a single row by a FPRAM scalar. - - For row i: source[i, :] = source[i, :] * FPRAM[fpram_scalar_addr] - - Args: - source: VRAM tile (mlen x mlen) - fpram_scalar_addr: FPRAM address or FPVar of the scalar - row_idx: single row index (legacy path) - rows: multiple row indices - - Example: - scale_fp = prog.fp_var("scale", size=1) - for row in range(64): - prog.tile_row_mul_fp_broadcast(S, scale_fp, rows=list(range(64))) - """ - resolved_rows = self._resolve_rows(row_idx=row_idx, rows=rows) - scalar_addr = self._resolve_fpram_addr(fpram_scalar_addr, fpram_offset) - super().tile_row_mul_fp_broadcast(source.name, scalar_addr, resolved_rows) - - def fpvar_fill_from_fpram( - self, - dst: FPVar, - src_fpram_addr: int, - count: int | None = None, - ): - """ - FPVar Fill from FPRAM: fill all elements with a value from FPRAM. - - For each element i: dst[i] = FPRAM[src_fpram_addr] - - Args: - dst: destination FPVar - src_fpram_addr: source FPRAM address (e.g., 0 for 0.0, 2 for -inf) - count: number of elements (default: dst.size) - - Example: - m_old = prog.fp_var("m_old", size=64) - prog.fpvar_fill_from_fpram(m_old, 2) # fill with -inf from address 2 - """ - if count is None: - count = dst.size - super().fpram_fill_from_fpram(dst.name, src_fpram_addr, count) - - def vram_fill_zero( - self, - matrix: VRAMMatrixVar, - rows: list[int] | None = None, - ): - """ - VRAM Fill Zero: fill specified rows with 0. - - Args: - matrix: VRAM matrix - rows: which rows to fill (default: all rows) - - Example: - O = prog.alloc("O", 128, 128) - prog.vram_fill_zero(O, rows=range(64, 128)) # zero out second half - """ - if rows is None: - rows = list(range(matrix.shape[0])) - else: - rows = list(rows) - - total_rows, cols = matrix.shape - if any(row < 0 or row >= total_rows for row in rows): - raise ValueError(f"vram_fill_zero rows out of bounds for {matrix.name}: shape={matrix.shape}, rows={rows}") - - # VRAM matrices are column-block-major. The low-level tile helper zeros - # one 64-column tile, so walk every column block for wide matrices. - num_col_blocks = (cols + self.mlen - 1) // self.mlen - for col_block in range(num_col_blocks): - super().vram_fill_zero(matrix.name, rows, tile_col_idx=col_block) - - def _ensure_hbm_sub_matrix_registered(self, input_var: InputVar): - """Ensure an HBM input is registered in compiler sub-matrix manager.""" - if ( - input_var.name in self._registered_hbm_sub_matrices - and self._registered_hbm_sub_matrices[input_var.name] is True - ): - return - h, w = input_var.shape - super().ensure_hbm_sub_matrix( - name=input_var.name, - hbm_addr=input_var.hbm_addr, - shape=(h, w), - real_data_ratio=self.real_data_ratio, - ) - self._registered_hbm_sub_matrices[input_var.name] = True - - def _ensure_vram_sub_matrix_registered(self, matrix_var: VRAMMatrixVar): - """Ensure a VRAM matrix is registered in compiler sub-matrix manager.""" - if ( - matrix_var.name in self._registered_vram_sub_matrices - and self._registered_vram_sub_matrices[matrix_var.name] is True - ): - return - super().ensure_vram_matrix_layout( - name=matrix_var.name, - shape=matrix_var.shape, - ) - self._registered_vram_sub_matrices[matrix_var.name] = True - - def vram_sub_projection_to( - self, - vram_matrix: VRAMMatrixVar, - vram_row_idx: int, - mram_input: InputVar, - mram_col_idx: int, - target: VRAMMatrixVar, - target_row_idx: int, - target_col_idx: int, - auto_reset_mram: bool = True, - k_block_start: int = 0, - k_block_count: int | None = None, - ): - """ - target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[:][mram_col_idx] - Supports K-split: k_block_start/k_block_count select a subset of K tiles. - """ - if not isinstance(vram_matrix, VRAMMatrixVar): - raise TypeError(f"vram_matrix must be VRAMMatrixVar, got {type(vram_matrix)}") - if not isinstance(mram_input, InputVar): - raise TypeError(f"mram_input must be InputVar, got {type(mram_input)}") - if not isinstance(target, VRAMMatrixVar): - raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") - - self._ensure_vram_sub_matrix_registered(vram_matrix) - self._ensure_hbm_sub_matrix_registered(mram_input) - if auto_reset_mram: - super().reset_mram() - super().load_sub_matrix_col( - name=mram_input.name, - col_idx=mram_col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - super().vram_sub_projection_to( - vram_mat_name=vram_matrix.name, - vram_row_idx=vram_row_idx, - mram_mat_name=mram_input.name, - mram_col_idx=mram_col_idx, - target_matrix=target.name, - target_row_idx=target_row_idx, - target_col_idx=target_col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - - def vram_sub_projection_T_to( - self, - vram_matrix: VRAMMatrixVar, - vram_row_idx: int, - mram_input: InputVar, - mram_row_idx: int, - target: VRAMMatrixVar, - target_row_idx: int, - target_col_idx: int, - auto_reset_mram: bool = True, - ): - """ - target[target_row_idx][target_col_idx] = vram_matrix[vram_row_idx][:] @ mram_input[mram_row_idx][:]^T - """ - if not isinstance(vram_matrix, VRAMMatrixVar): - raise TypeError(f"vram_matrix must be VRAMMatrixVar, got {type(vram_matrix)}") - if not isinstance(mram_input, InputVar): - raise TypeError(f"mram_input must be InputVar, got {type(mram_input)}") - if not isinstance(target, VRAMMatrixVar): - raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") - - self._ensure_vram_sub_matrix_registered(vram_matrix) - self._ensure_hbm_sub_matrix_registered(mram_input) - if auto_reset_mram: - super().reset_mram() - super().load_sub_matrix_row(name=mram_input.name, row_idx=mram_row_idx) - super().vram_sub_projection_T_to( - vram_mat_name=vram_matrix.name, - vram_row_idx=vram_row_idx, - mram_mat_name=mram_input.name, - mram_row_idx=mram_row_idx, - target_matrix=target.name, - target_row_idx=target_row_idx, - target_col_idx=target_col_idx, - ) - - # ======================================================================== - # RoPE (1D Positional Encoding) - # ======================================================================== - - def rope( - self, - x_var: VRAMMatrixVar, - x_rot_var: VRAMMatrixVar, - cos_var: VRAMMatrixVar, - sin_var: VRAMMatrixVar, - ) -> VRAMMatrixVar: - """Apply Rotary Position Embedding in-place: x = x * cos + rotate_half(x) * sin - - x_rot_var must already be in VRAM as rotate_half(x), preloaded by caller. - Returns x_var (modified in-place). - """ - super().rope( - x_name=x_var.name, - x_rot_name=x_rot_var.name, - cos_name=cos_var.name, - sin_name=sin_var.name, - ) - return x_var - - # ======================================================================== - # VRAM Matrix Addition - # ======================================================================== - - def vram_add( - self, - dst: VRAMMatrixVar, - src: VRAMMatrixVar, - dst_row_offset: int = 0, - src_row_offset: int = 0, - num_rows: int | None = None, - ): - """VRAM matrix add: dst[row_offset:] += src""" - super().vram_matrix_add( - dst_matrix=dst.name, - src_matrix=src.name, - dst_row_offset=dst_row_offset, - src_row_offset=src_row_offset, - num_rows=num_rows, - ) - - def vram_block_add_to( - self, - src1: TensorVar, - src1_row_idx: int, - src1_col_idx: int, - src2: TensorVar, - src2_row_idx: int, - src2_col_idx: int, - target: TensorVar, - target_row_idx: int, - target_col_idx: int, - ): - """ - mlen x mlen block add: - target[target_row_idx][target_col_idx] = - src1[src1_row_idx][src1_col_idx] + src2[src2_row_idx][src2_col_idx] - - Supports writing back to the same matrix/block (in-place overwrite). - """ - allowed = (VRAMMatrixVar,) - if not isinstance(src1, allowed): - raise TypeError(f"src1 must be VRAMMatrixVar, got {type(src1)}") - if not isinstance(src2, allowed): - raise TypeError(f"src2 must be VRAMMatrixVar, got {type(src2)}") - if not isinstance(target, allowed): - raise TypeError(f"target must be VRAMMatrixVar, got {type(target)}") - - super().vram_block_add_to( - src1_matrix=src1.name, - src1_row_idx=src1_row_idx, - src1_col_idx=src1_col_idx, - src2_matrix=src2.name, - src2_row_idx=src2_row_idx, - src2_col_idx=src2_col_idx, - target_matrix=target.name, - target_row_idx=target_row_idx, - target_col_idx=target_col_idx, - ) - - # ======================================================================== - # Flash Attention Operations - # ======================================================================== - - def init_online_softmax(self, q_idx: int, o_matrix: VRAMMatrixVar): - """Initialize Online Softmax state: m=-inf, l=0, O_row=0""" - o_info = super().get_tensor_info(o_matrix.name) - seq_len, head_dim = o_info.shape - - super().init_online_softmax( - q_idx=q_idx, - o_matrix=o_matrix.name, - seq_len=seq_len, - head_dim=head_dim, - ) - - def online_softmax_block(self, s_block: VRAMMatrixVar, scale: float): - """Perform Online Softmax on S block""" - super().online_softmax_block( - s_block_matrix=s_block.name, - scale=scale, - ) - - def compute_pv( - self, - s_block: VRAMMatrixVar, - v_input: InputVar, - k_idx: int, - pv_matrix: VRAMMatrixVar, - head_dim: int, - ): - """Compute PV = P @ V[k_idx] where P is stored in s_block.""" - if not isinstance(s_block, VRAMMatrixVar): - raise TypeError(f"s_block must be VRAMMatrixVar, got {type(s_block)}") - if not isinstance(v_input, InputVar): - raise TypeError(f"v_input must be InputVar, got {type(v_input)}") - if not isinstance(pv_matrix, VRAMMatrixVar): - raise TypeError(f"pv_matrix must be VRAMMatrixVar, got {type(pv_matrix)}") - - self._ensure_hbm_sub_matrix_registered(v_input) - super().compute_pv( - s_block_matrix=s_block.name, - v_sub_matrix=v_input.name, - k_idx=k_idx, - pv_matrix=pv_matrix.name, - head_dim=head_dim, - ) - - def scale_o_row(self, o_matrix: VRAMMatrixVar, q_idx: int): - """Scale current row block of O by m_res""" - o_info = super().get_tensor_info(o_matrix.name) - seq_len, head_dim = o_info.shape - - super().scale_o_row( - o_matrix=o_matrix.name, - q_idx=q_idx, - seq_len=seq_len, - head_dim=head_dim, - ) - - def final_scale_o(self, q_idx: int, o_matrix: VRAMMatrixVar): - """Final scaling: O[q_idx] = O[q_idx] / l""" - o_info = super().get_tensor_info(o_matrix.name) - seq_len, head_dim = o_info.shape - - super().final_scale_o( - q_idx=q_idx, - o_matrix=o_matrix.name, - seq_len=seq_len, - head_dim=head_dim, - ) - - # ======================================================================== - # Function Decorator - # ======================================================================== - - def function(self, func: Callable) -> Callable: - """ - Decorator: Define reusable functions. - - Each invocation generates fresh ISA code (eager evaluation). - Internally allocated tensors are auto-freed on exit unless returned. - - Scoping: intermediate tensors get a call-index prefix to avoid name - collisions across repeated calls (e.g., "linear_0/proj_1", "linear_1/proj_1"). - Nested functions compose prefixes: "two_layer_0/linear_0/proj_1". - """ - func_name = func.__name__ - - @wraps(func) - def wrapper(*args, **kwargs): - call_idx = self._function_call_counters.get(func_name, 0) - self._function_call_counters[func_name] = call_idx + 1 - - scope = f"{func_name}_{call_idx}/" - self._scope_stack.append(scope) - - self.generated_code += f"; === Enter {func_name} (call #{call_idx}) ===\n" - - # Snapshot: record existing tensors before function execution - tensors_before = set(self._tensors.keys()) - inputs_before = set(self._inputs.keys()) - fp_vars_before = set(self._fp_vars.keys()) - - try: - result = func(*args, **kwargs) - - # Auto-free: free locally allocated tensors that are not returned - return_names = set() - return_fp_names = set() - if isinstance(result, TensorVar): - return_names.add(result.internal_name) - elif isinstance(result, FPVar): - return_fp_names.add(result.internal_name) - elif isinstance(result, (tuple, list)): - for r in result: - if isinstance(r, TensorVar): - return_names.add(r.internal_name) - elif isinstance(r, FPVar): - return_fp_names.add(r.internal_name) - - for name in set(self._tensors.keys()) - tensors_before: - if name not in return_names: - tensor = self._tensors[name] - if isinstance(tensor, VRAMMatrixVar): - self.free_tensor(tensor) - self._registered_vram_sub_matrices[tensor.name] = False - - for name in set(self._inputs.keys()) - inputs_before: - if name not in return_names: - self.free_input(self._inputs[name]) - - local_fp_names = sorted( - set(self._fp_vars.keys()) - fp_vars_before, - key=lambda n: self._fp_vars[n].address, - reverse=True, - ) - for name in local_fp_names: - if name in return_fp_names: - continue - fp_var = self._fp_vars.get(name) - if fp_var is not None: - self.free_fp_var(fp_var) - finally: - self._scope_stack.pop() - self.generated_code += f"; === Exit {func_name} (call #{call_idx}) ===\n" - - return result - - self._functions[func_name] = wrapper - wrapper._plena_function = True - wrapper._plena_name = func_name - return wrapper - - # ======================================================================== - # Result Marking - # ======================================================================== - - def result(self, tensor_var: TensorVar): - """Mark output result tensor.""" - self._result_tensor = tensor_var - - # ======================================================================== - # Compilation - # ======================================================================== - - def compile(self) -> str: - """Get generated ISA code string.""" - return super().get_code() - - def print_symbol_table(self): - """Print symbol table""" - super().print_symbol_table() - - def get_symbol_table(self): - """Get symbol table""" - return super().get_symbol_table() - - # ======================================================================== - # Operator Dispatch (internal) - # ======================================================================== - - def _dispatch_matmul(self, left: TensorVar, right) -> TensorVar: - raise TypeError("@ operator is no longer supported in PlenaCompiler. Use explicit program APIs instead.") - - # ======================================================================== - # Utility Methods - # ======================================================================== - - def _scoped_name(self, name: str) -> str: - """ - Apply current scope prefix to a name. - - - Top-level alloc("temp"): -> "temp" - - Inside linear call 0, alloc("temp"): -> "linear_0/temp" - - Nested two_layer->linear, alloc("temp"): -> "two_layer_0/linear_0/temp" - """ - if not self._scope_stack: - return name - scope_prefix = "".join(self._scope_stack) - return f"{scope_prefix}{name}" - - def _allocate_hbm(self, hbm_size: int) -> int: - """Allocate HBM range, preferring previously freed blocks.""" - best_idx = None - best_waste = None - for i, (addr, size) in enumerate(self._hbm_free_blocks): - if size >= hbm_size: - waste = size - hbm_size - if best_waste is None or waste < best_waste: - best_idx = i - best_waste = waste - - if best_idx is not None: - addr, block_size = self._hbm_free_blocks.pop(best_idx) - # Return excess fragment to free list - excess = block_size - hbm_size - if excess > 0: - self._hbm_free_blocks.append((addr + hbm_size, excess)) - return addr - - addr = self._next_hbm_addr - m = self.mlen - self._next_hbm_addr = ((addr + hbm_size + m - 1) // m) * m - return addr - - def _recycle_hbm(self, hbm_addr: int, hbm_size: int): - """Recycle an HBM range for future auto-allocation.""" - if hbm_size <= 0: - return - self._hbm_free_blocks.append((hbm_addr, hbm_size)) - - def _auto_name(self, prefix: str = "t") -> str: - """ - Generate a unique scoped name. - - - Top-level: "__proj_1" - - linear call 0: "linear_0/__proj_1" - - nested: "two_layer_0/linear_0/__proj_1" - """ - self._auto_name_counter += 1 - scope_prefix = "".join(self._scope_stack) - return f"{scope_prefix}__{prefix}_{self._auto_name_counter}" - - def __repr__(self): - num_inputs = len(self._inputs) - num_tensors = len(self._tensors) - num_functions = len(self._functions) - code_len = len(super().get_code().splitlines()) - return ( - f"PlenaCompiler(mlen={self.mlen}, blen={self.blen}, " - f"inputs={num_inputs}, tensors={num_tensors}, " - f"functions={num_functions}, isa_lines={code_len})" - ) - - -class TensorKind(Enum): - """Identifies which memory the tensor lives in / which proxy backs it.""" - - HBM = "hbm" # legacy: InputVar - VRAM = "vram" # legacy: VRAMMatrixVar - FPRAM = "fpram" # legacy: FPVar - - -# Type alias: a "Tensor" is any of the legacy proxy objects. Callers can use -# this annotation without worrying about which specific backing allocator -# a given tensor lives in. -Tensor = TensorVar | InputVar | VRAMMatrixVar | FPVar - - -def tensor_kind(tensor: Tensor) -> TensorKind: - """Return the backing storage of a tensor proxy.""" - if isinstance(tensor, FPVar): - return TensorKind.FPRAM - if isinstance(tensor, VRAMMatrixVar): - return TensorKind.VRAM - if isinstance(tensor, InputVar): - return TensorKind.HBM - if isinstance(tensor, TensorVar): - # Generic TensorVar without a specific backing — classify by ``kind``. - kind = getattr(tensor, "kind", "") - if kind in ("vram_matrix", "batch", "matrix"): - return TensorKind.VRAM - if kind == "input": - return TensorKind.HBM - raise TypeError(f"Unknown tensor type: {type(tensor).__name__}") - - -# ============================================================================= -# Unified dataclass aliases for the three overlapping "Info" and three -# overlapping "Layout" types in tile_compiler.py. -# ============================================================================= - - -# ``TensorInfo`` is the union of the three Info dataclasses. Callers can -# import ``TensorInfo`` and use it as an annotation; at runtime the object -# will be whichever specific Info subtype ``TileCompiler`` constructed. -TensorInfo = MemoryObjectInfo | SubMatrixInfo | VRAMSubMatrixInfo - -# ``TileLayout`` is the union of the three Layout dataclasses. -TileLayout = MatrixBlockLayout | VRAMMatrixBlockLayout | FPRAMObjectLayout - - -# ============================================================================= -# Public exports. -# ============================================================================= - - -__all__ = [ - "DeveloperCompiler", - "FPRAMAllocator", - "FPRAMObjectLayout", - "FPVar", - "InputVar", - "MRAMAllocator", - "MatrixBlockLayout", - "MemoryBlock", - "MemoryObjectInfo", - "PlenaCompiler", - "RegisterAllocator", - "SubMatrixInfo", - "Tensor", - "TensorInfo", - "TensorKind", - "TensorVar", - "TileCompiler", - "TileLayout", - "VRAMAllocator", - "VRAMMatrixBlockLayout", - "VRAMMatrixVar", - "VRAMSubMatrixInfo", - "VirtualMemoryManager", - "tensor_kind", -] diff --git a/aten/plena_frontend.py b/aten/plena_frontend.py index 486c552..10abb25 100644 --- a/aten/plena_frontend.py +++ b/aten/plena_frontend.py @@ -1,44 +1,46 @@ -""" -Automatic HuggingFace model -> PLENA ISA compiler. +"""HuggingFace decoder model to PLENA ISA compiler.""" -Walks an HF nn.Module tree, extracts weights, and generates ISA with -proper residual connections for multi-layer decoder pipelines. - -Usage: - from transformers import AutoModelForCausalLM - from compiler.aten.plena_frontend import compile_hf_model, compile_and_run - - model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", - torch_dtype=torch.float32) - # Just compile (no emulator run): - result = compile_hf_model(model, seq_len=64, hidden_size=64, inter_dim=128, num_layers=2) - - # Compile + run emulator + compare: - result = compile_and_run(model, "/tmp/build", seq_len=64, hidden_size=64, inter_dim=128, num_layers=2) -""" - -import json import math -from pathlib import Path +import re +from dataclasses import dataclass +from typing import Any import torch -import torch.nn.functional as F -from compiler.aten.plena_compiler import PlenaCompiler -from compiler.aten.ops.registry import OpRegistry, Backend +from compiler.aten.model_extract import ( + LayerWeights, + embedding_module, + extract_layer_weights, + extract_model_config, + find_model_root, +) import compiler.aten.ops as ops -from transactional_emulator.testbench.model_layer_test_builder import ( +from compiler.aten.ops.registry import Backend, OpRegistry +from compiler.aten.plena import PlenaCompiler +from compiler.aten.reference import ( + ReferencePrecision, + _ksplit_matmul, + _make_rotate_half_matrix, + make_rope_inputs, quantize_to_mxfp, - _make_rope_tables, + run_decoder_reference, ) -import re + +__all__ = [ + "_fix_large_immediates", + "_ksplit_matmul", + "_make_rotate_half_matrix", + "compile_hf_model", + "quantize_to_mxfp", +] _IMM2_BOUND = 1 << 18 # S_ADDI_INT max immediate + def _fix_large_immediates(isa_code: str) -> str: """Post-process ISA: replace S_ADDI_INT gp{r}, gp0, {large} with S_LUI_INT + S_ADDI_INT. - PlenaCompiler emits raw S_ADDI_INT for VRAM/HBM addresses. At native + PlenaCompiler emits raw S_ADDI_INT for VRAM/HBM addresses. At large model dimensions these can exceed the 18-bit immediate limit. This pass splits them into S_LUI_INT (upper 22 bits, shifted <<12) + S_ADDI_INT (lower 12 bits), matching asm_templates._imm.load_large_int. @@ -66,322 +68,195 @@ def _fix_large_immediates(isa_code: str) -> str: # --------------------------------------------------------------------------- REAL_DATA_RATIO = (8 * 8 + 8) / (8 * 8) -# Hardware K-split tile limit (matches _linear_projection MAX_K_TILES) -_HW_MAX_K_TILES = 4 +@dataclass(frozen=True) +class LayerInputVars: + """PLENA input variables for one extracted decoder layer.""" -def _ksplit_matmul(A, B, mlen=64, max_k_tiles=_HW_MAX_K_TILES, to_inter=None, from_inter=None): - """Matrix multiply matching hardware K-split BF16 precision. + w_q: Any + w_o: Any + w_k_heads: list[Any] + w_v_heads: list[Any] + w_gate: Any + w_up: Any + w_down: Any - When the K dimension exceeds max_k_tiles * mlen, the hardware splits the - inner product into chunks, writing each partial sum to BF16 VRAM and - accumulating via BF16 add. This function replicates that precision loss - so the golden reference matches the emulator output. - If the K dimension fits in a single pass (num_k_tiles <= max_k_tiles), - this is equivalent to a single matmul with BF16 cast. - """ - if to_inter is None: - to_inter = lambda x: x.to(torch.bfloat16) - if from_inter is None: - from_inter = lambda x: x.float() - - k_total = A.shape[1] - num_k_tiles = math.ceil(k_total / mlen) - - if num_k_tiles <= max_k_tiles: - # Single pass — no K-split precision loss - # Hardware path: MXFP8 (HBM) → BF16 (MRAM) → float32 (M_MM) - # Cast B to BF16 then float32 to match MRAM precision loss - return from_inter(to_inter(torch.matmul(from_inter(to_inter(A)), from_inter(to_inter(B))))) - - # K-split: chunk K tiles into groups of max_k_tiles - result = None - k_start = 0 - while k_start < k_total: - k_end = min(k_start + max_k_tiles * mlen, k_total) - A_chunk = A[:, k_start:k_end] - B_chunk = B[k_start:k_end, :] - # Cast B_chunk to BF16 then float32 to match MRAM precision loss - partial = from_inter(to_inter(torch.matmul(from_inter(to_inter(A_chunk)), from_inter(to_inter(B_chunk))))) - if result is None: - result = partial - else: - # Hardware accumulates in BF16 VRAM (vram_block_add_to) - result = from_inter(to_inter(result) + to_inter(partial)) - k_start = k_end +def _save_residual_and_norm(prog, source, scratch): + """Emit the common decoder pre-norm residual prologue.""" + prog.vram_fill_zero(scratch) + prog.vram_add(scratch, source) + ops.rms_norm(prog, source, eps_offset=3, reci_hid_offset=4) - return result +def _add_residual(prog, target, scratch): + prog.vram_add(target, scratch) + return target -# --------------------------------------------------------------------------- -# RoPE helpers -# --------------------------------------------------------------------------- -def _make_rotate_half_matrix(head_dim: int) -> torch.Tensor: - """Build the (head_dim, head_dim) matrix that computes rotate_half. - rotate_half(x) = x @ R, where R permutes the halves with a sign flip: - output[:d//2] = -input[d//2:] - output[d//2:] = +input[:d//2] - """ - R = torch.zeros(head_dim, head_dim) - half = head_dim // 2 - for i in range(half): - R[i + half, i] = -1.0 # first half of output = negated second half of input - R[i, i + half] = 1.0 # second half of output = first half of input - return R +def _linear_projection(prog, input_var, weight_var, name: str): + return ops.linear(prog, input_var, weight_var, name=name) -# --------------------------------------------------------------------------- -# Model structure helpers -# --------------------------------------------------------------------------- -def _find_model_root(model): - """Find the transformer backbone (model.model or model.model.text_model). +def _apply_rope_projection(prog, x_var, rope_matrix, cos_var, sin_var, name): + x_rot = _linear_projection(prog, x_var, rope_matrix, name) + ops.rope(prog, x_var, x_rot, cos_var, sin_var) + prog.free_tensor(x_rot) + return x_var - Handles standard CausalLM models and VLMs like SmolVLM2. - """ - for candidate in [ - getattr(model, "model", None), - getattr(getattr(model, "model", None), "text_model", None), - getattr(model, "language_model", getattr(model, "text_model", None)), - ]: - if candidate is not None and hasattr(candidate, "layers"): - return candidate - raise ValueError(f"Cannot find decoder layers on {type(model).__name__}") - - -def _extract_config(model): - """Extract config dimensions from the model, resolving text_config for VLMs.""" - config = getattr(model.config, "text_config", model.config) - native_hidden = config.hidden_size - native_inter = getattr(config, "intermediate_size", 4 * native_hidden) - native_heads = config.num_attention_heads - native_kv_heads = getattr(config, "num_key_value_heads", native_heads) - native_head_dim = native_hidden // native_heads - eps = getattr(config, "rms_norm_eps", 1e-5) - rope_theta = getattr(config, "rope_theta", 10000.0) - vocab_size = getattr(config, "vocab_size", None) - return { - "hidden_size": native_hidden, - "inter_dim": native_inter, - "num_heads": native_heads, - "num_kv_heads": native_kv_heads, - "head_dim": native_head_dim, - "eps": eps, - "rope_theta": rope_theta, - "vocab_size": vocab_size, - "model_type": getattr(config, "model_type", "unknown"), - } +def _copy_into_vram_view(prog, source, name, rows, cols, vram_addr): + target = prog.alloc_at(name, rows, cols, vram_addr) + prog.vram_fill_zero(target) + prog.vram_add(target, source) + return target -def _extract_layer_weights(layer, hidden_slice, inter_slice, head_dim_slice, num_heads, head_dim, - num_kv_heads=1, native_mode=False): - """Extract and slice weights from a single decoder layer. - - Transposes from HF's (out_features, in_features) to PLENA's (in, out) convention. - - In native mode (hidden_slice == native hidden, multi-head): - - W_q: (hidden, num_heads * head_dim) — full Q projection - - W_o: (num_heads * head_dim, hidden) — full O projection - - W_k_heads: list of (hidden, head_dim) per KV head - - W_v_heads: list of (hidden, head_dim) per KV head - - In sliced mode (legacy single-head): - - W_q: (hidden_slice, head_dim_slice) - - W_o: (head_dim_slice, hidden_slice) - - W_k: (hidden_slice, head_dim_slice) - - W_v: (hidden_slice, head_dim_slice) - - Args: - layer: nn.Module for a single decoder layer - hidden_slice: target hidden dimension - inter_slice: target intermediate dimension - head_dim_slice: target head dimension (min of native head_dim, hidden_slice) - num_heads: number of attention heads (native config) - head_dim: native head dimension - num_kv_heads: number of KV heads (native config) - native_mode: if True, extract full multi-head weights - - Returns: - dict with W_q, W_o, W_gate, W_up, W_down, W_k/W_k_heads, W_v/W_v_heads, eps - """ - # FFN weights: HF stores (out, in) -> transpose to (in, out) -> slice - W_gate = layer.mlp.gate_proj.weight.detach().T.contiguous()[:hidden_slice, :inter_slice] - W_up = layer.mlp.up_proj.weight.detach().T.contiguous()[:hidden_slice, :inter_slice] - W_down = layer.mlp.down_proj.weight.detach().T.contiguous()[:inter_slice, :hidden_slice] - - # eps from input_layernorm - norm = layer.input_layernorm - eps = getattr(norm, "variance_epsilon", getattr(norm, "eps", 1e-5)) - - if native_mode: - # Full multi-head weights (no slicing) - total_q_dim = num_heads * head_dim - total_kv_dim = num_kv_heads * head_dim - - W_q = layer.self_attn.q_proj.weight.detach().T.contiguous()[:hidden_slice, :total_q_dim] - W_o = layer.self_attn.o_proj.weight.detach().T.contiguous()[:total_q_dim, :hidden_slice] - - # Per-KV-head K and V weights - W_k_full = layer.self_attn.k_proj.weight.detach().T.contiguous()[:hidden_slice, :total_kv_dim] - W_v_full = layer.self_attn.v_proj.weight.detach().T.contiguous()[:hidden_slice, :total_kv_dim] - - W_k_heads = [W_k_full[:, h * head_dim:(h + 1) * head_dim].contiguous() - for h in range(num_kv_heads)] - W_v_heads = [W_v_full[:, h * head_dim:(h + 1) * head_dim].contiguous() - for h in range(num_kv_heads)] - - return { - "W_q": W_q, - "W_o": W_o, - "W_gate": W_gate, - "W_up": W_up, - "W_down": W_down, - "W_k_heads": W_k_heads, - "W_v_heads": W_v_heads, - "eps": eps, - } - else: - # Legacy single-head sliced mode - W_q = layer.self_attn.q_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] - W_o = layer.self_attn.o_proj.weight.detach().T.contiguous()[:head_dim_slice, :hidden_slice] - W_k = layer.self_attn.k_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] - W_v = layer.self_attn.v_proj.weight.detach().T.contiguous()[:hidden_slice, :head_dim_slice] - - return { - "W_q": W_q, - "W_o": W_o, - "W_gate": W_gate, - "W_up": W_up, - "W_down": W_down, - "W_k": W_k, - "W_v": W_v, - "eps": eps, - } +def _free_named_tensors(prog, names): + for name in names: + tensor = prog._tensors.get(name) + if tensor is not None: + prog.free_tensor(tensor) + prog._tensors.pop(name, None) -# --------------------------------------------------------------------------- -# Golden reference helpers (match hardware: MXFP8 HBM + BF16 intermediates) -# --------------------------------------------------------------------------- -def _flash_attn_ref(Q, K, V, scale, causal=False): - """CPU reference: scaled dot-product attention matching hardware BF16 precision. - Hardware path: - S = Q @ K^T (M_TMM in f32, written to VRAM as BF16 via M_MM_WO) - S *= scale (V_MUL_VF: BF16 * f32 -> BF16) - P = softmax(S) (online softmax on BF16 data) - O = P @ V (M_MM in f32, written to VRAM as BF16 via M_MM_WO) +def _emit_kv_stores(prog, current, layer_inputs, rope_inputs, layer_idx, num_kv_heads): + rope_matrix, cos_var, sin_var = rope_inputs + kv_stored = [] + for kv_h in range(num_kv_heads): + K_h = _linear_projection( + prog, + current, + layer_inputs.w_k_heads[kv_h], + f"K_{layer_idx}_h{kv_h}", + ) + V_h = _linear_projection( + prog, + current, + layer_inputs.w_v_heads[kv_h], + f"V_{layer_idx}_h{kv_h}", + ) - We model the key BF16 truncation points to match hardware precision. - """ - # S = Q @ K^T, then truncate to BF16 (M_MM_WO writes BF16) - scores = (Q @ K.T).to(torch.bfloat16).float() * scale - scores = scores.to(torch.bfloat16).float() # V_MUL_VF result is BF16 - if causal: - mask = torch.triu(torch.ones(scores.shape[-2], scores.shape[-1], - device=scores.device), diagonal=1).bool() - scores.masked_fill_(mask, float('-inf')) - # Softmax output written to BF16 VRAM - attn = F.softmax(scores, dim=-1).to(torch.bfloat16).float() - # O = P @ V, result written to BF16 VRAM - return (attn @ V).to(torch.bfloat16).float() - - -def _rms_norm_ref(x, eps): - """CPU reference: RMS normalization matching PLENA hardware. - - Hardware path: V_RED_SUM accumulates sum-of-squares into f32 scalar register, - S_MUL_FP / S_ADD_FP / S_SQRT_FP / S_RECI_FP all operate in f32, - then V_MUL_VF multiplies BF16 vector by f32 scalar -> BF16 result. - The scalar rms factor stays in f32 throughout; only the vector data is BF16. - """ - x_bf = x.to(torch.bfloat16) - # Compute rms in f32 (matching hardware scalar register precision) - rms = torch.rsqrt(x_bf.float().pow(2).mean(-1, keepdim=True) + eps) - # V_MUL_VF: BF16 vector * f32 scalar -> quantized back to BF16 - return (x_bf.float() * rms).to(torch.bfloat16).float() + _apply_rope_projection( + prog, + K_h, + rope_matrix, + cos_var, + sin_var, + f"K_rot_{layer_idx}_h{kv_h}", + ) + K_stored = prog.store(K_h, name=f"K_stored_{layer_idx}_h{kv_h}") + V_stored = prog.store(V_h, name=f"V_stored_{layer_idx}_h{kv_h}") + kv_stored.append((K_stored, V_stored)) + + prog.free_tensor(K_h) + prog.free_tensor(V_h) + + return kv_stored + + +def _emit_attention_block( + prog, + current, + layer_inputs, + rope_inputs, + causal_mask, + scratch, + scale, + layer_idx, + seq_len, + head_dim, + total_q_dim, + num_heads, + num_kv_heads, + ratio, +): + _save_residual_and_norm(prog, current, scratch) + + Q = _linear_projection(prog, current, layer_inputs.w_q, f"Q_{layer_idx}") + q_full_addr = prog.get_vram_addr(Q.name) + + O_full = prog.alloc(f"O_full_{layer_idx}", seq_len, total_q_dim) + o_full_addr = prog.get_vram_addr(O_full.name) + + kv_stored = _emit_kv_stores( + prog, + current, + layer_inputs, + rope_inputs, + layer_idx, + num_kv_heads, + ) -# --------------------------------------------------------------------------- -# PLENA ISA helper: named linear projection (avoids name conflicts) -# --------------------------------------------------------------------------- -def _linear_projection(prog, input_var, weight_var, name): - """Emit a linear projection with a custom VRAM output name. + rope_matrix, cos_var, sin_var = rope_inputs + for h in range(num_heads): + kv_h = h // ratio + K_stored, V_stored = kv_stored[kv_h] + + q_h_addr = q_full_addr + h * seq_len * prog.mlen + Q_h = prog.alloc_at(f"Q_h{h}_{layer_idx}", seq_len, head_dim, q_h_addr) + _apply_rope_projection( + prog, + Q_h, + rope_matrix, + cos_var, + sin_var, + f"Q_rot_{layer_idx}_h{h}", + ) - Equivalent to ops.linear but uses *name* for the output allocation so - that multiple projections in the same scope don't collide on the - default "linear_out" name. + O_h = ops.flash_attention( + prog, + Q_h, + K_stored, + V_stored, + scale, + causal_mask=causal_mask, + ) - Supports K-split: when K tiles exceed MRAM capacity (4 tiles), splits - into chunks and accumulates partial sums via a temporary buffer. - """ - import math as _math + o_h_dest_addr = o_full_addr + h * seq_len * prog.mlen + _copy_into_vram_view( + prog, + O_h, + f"O_dest_h{h}_{layer_idx}", + seq_len, + head_dim, + o_h_dest_addr, + ) + _free_named_tensors(prog, ("O", "S", "PV")) + + O_proj = _linear_projection(prog, O_full, layer_inputs.w_o, f"O_proj_{layer_idx}") + return _add_residual(prog, O_proj, scratch) - mlen = prog.mlen - MAX_K_TILES = 4 # MRAM capacity: 4 x mlen^2 elements - rows, k_total = input_var.shape - _, out_features = weight_var.shape - num_row_blocks = _math.ceil(rows / mlen) - assert out_features % mlen == 0, ( - f"out_features ({out_features}) must be a multiple of mlen ({mlen})" +def _emit_ffn_block(prog, current, layer_inputs, scratch): + _save_residual_and_norm(prog, current, scratch) + ops.ffn(prog, current, layer_inputs.w_gate, layer_inputs.w_up, layer_inputs.w_down) + return _add_residual(prog, current, scratch) + + +def _register_layer_inputs(prog, layer_idx: int, weights: LayerWeights) -> LayerInputVars: + named_vars = {} + w_k_heads = [] + w_v_heads = [] + for tensor_name, tensor in weights.tensor_entries(layer_idx): + var = prog.input(tensor_name, shape=tuple(tensor.shape)) + if tensor_name.startswith(f"W_k_{layer_idx}_h"): + w_k_heads.append(var) + elif tensor_name.startswith(f"W_v_{layer_idx}_h"): + w_v_heads.append(var) + else: + named_vars[tensor_name[: tensor_name.rfind(f"_{layer_idx}")]] = var + + return LayerInputVars( + w_q=named_vars["W_q"], + w_o=named_vars["W_o"], + w_k_heads=w_k_heads, + w_v_heads=w_v_heads, + w_gate=named_vars["W_gate"], + w_up=named_vars["W_up"], + w_down=named_vars["W_down"], ) - num_col_blocks = out_features // mlen - num_k_tiles = _math.ceil(k_total / mlen) - - output_strict = rows % mlen == 0 - output = prog.alloc(name, rows, out_features, strict=output_strict) - - if num_k_tiles <= MAX_K_TILES: - # Single pass: all K tiles fit in MRAM - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - prog.vram_sub_projection_to( - input_var, - row_idx, - weight_var, - col_idx, - output, - row_idx, - col_idx, - ) - else: - # K-split: chunk K tiles into groups of MAX_K_TILES - k_chunks = [] - k_start = 0 - while k_start < num_k_tiles: - k_end = min(k_start + MAX_K_TILES, num_k_tiles) - k_chunks.append((k_start, k_end - k_start)) - k_start = k_end - - temp = prog.alloc(f"{name}_ksplit_tmp", rows, out_features, strict=output_strict) - - for k_chunk_idx, (k_block_start, k_block_count) in enumerate(k_chunks): - for col_idx in range(num_col_blocks): - for row_idx in range(num_row_blocks): - if k_chunk_idx == 0: - prog.vram_sub_projection_to( - input_var, row_idx, weight_var, col_idx, - output, row_idx, col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - else: - prog.vram_sub_projection_to( - input_var, row_idx, weight_var, col_idx, - temp, row_idx, col_idx, - k_block_start=k_block_start, - k_block_count=k_block_count, - ) - prog.vram_block_add_to( - output, row_idx, col_idx, - temp, row_idx, col_idx, - output, row_idx, col_idx, - ) - - prog.free_tensor(temp) - - return output # --------------------------------------------------------------------------- @@ -390,98 +265,29 @@ def _linear_projection(prog, input_var, weight_var, name): def compile_hf_model( model, seq_len: int = 64, - hidden_size: int | None = None, - inter_dim: int | None = None, num_layers: int | None = None, layer_idx_start: int = 0, mlen: int = 64, blen: int = 4, seed: int = 42, - include_lm_head: bool = False, golden_precision: str = "hardware", + verbose: bool = False, ) -> dict: - """Compile a HuggingFace model to PLENA ISA via PlenaCompiler. - - Walks the nn.Module tree, extracts weights, and generates ISA with - proper residual connections for multi-layer decoders. - - The pipeline implemented (per-layer, with pre-norm + residual): - X = embed_tokens(input_ids) # real HF embedding lookup - X = embedding_add(X, zeros) # pos_weight=0 for Llama (RoPE handles position) - for each layer: - residual = X - X = rms_norm(X) - Q = linear(X, W_q) # Q projection - K_h = linear(X, W_k_h) # K projection (per KV head, on-chip) - V_h = linear(X, W_v_h) # V projection (per KV head, on-chip) - # RoPE (native mode only): - Q_rot_h = linear(Q_h, R_rope) # rotate_half via matmul - Q_h = Q_h * cos + Q_rot_h * sin - K_rot_h = linear(K_h, R_rope) - K_h = K_h * cos + K_rot_h * sin - store K_h, V_h to HBM - O = flash_attention(Q, K, V, scale) - O = linear(O, W_o) # O projection - X = O + residual - residual = X - X = rms_norm(X) - X = ffn(X, gate, up, down) - X = X + residual - X = rms_norm(X) # final norm - if include_lm_head: - logits = linear(X, W_lm_head) # (seq, vocab_size) - - RoPE is applied in native mode via rotate_half matrix multiplication. - Q, K, V, and O linear projections are all computed on-chip. - - Args: - model: nn.Module (HF CausalLM model, already loaded) - seq_len: Sequence length (default 64) - hidden_size: Target hidden dimension (None = use model's native) - inter_dim: Target intermediate dimension (None = use model's native) - num_layers: Number of layers to compile (None = all layers) - layer_idx_start: First layer index to use (default 0) - mlen: Matrix tile length (default 64) - blen: Batch tile length (default 4) - seed: Random seed for test data generation - include_lm_head: If True, add lm_head projection after final norm (default False) - golden_precision: Precision mode for golden reference computation. - "hardware" — MXFP8 weights + BF16 intermediates (default, matches HW) - "no_weight_quant" — float32 weights + BF16 intermediates (isolates MXFP8 effect) - "no_bf16" — MXFP8 weights + float32 intermediates (isolates BF16 effect) - "fp32" — float32 weights + float32 intermediates (should match HF exactly) - - Returns: - dict with: - isa: str - generated ISA code - golden_output: torch.Tensor - CPU golden reference output - input_tensors: dict - {name: tensor} for sim env setup - data_order: list[str] - HBM tensor ordering - fp_preload: list[float] - FPRAM constants - comparison_params: dict - for emulator comparison - info: dict - model dims, VRAM usage, etc. - """ - # -------------------------------------------------------------- config - native_cfg = _extract_config(model) - hidden = hidden_size if hidden_size is not None else native_cfg["hidden_size"] - inter = inter_dim if inter_dim is not None else native_cfg["inter_dim"] - - num_heads = native_cfg["num_heads"] - num_kv_heads = native_cfg["num_kv_heads"] - native_head_dim = native_cfg["head_dim"] - - # Native mode: use all heads at full dimension when hidden_size is not overridden - native_mode = (hidden_size is None) and (num_heads > 1) - if native_mode: - head_dim = native_head_dim - total_q_dim = num_heads * head_dim # e.g. 6*64 = 384 - total_kv_dim = num_kv_heads * head_dim # e.g. 2*64 = 128 - else: - head_dim = min(native_head_dim, hidden) - total_q_dim = head_dim - total_kv_dim = head_dim - - root = _find_model_root(model) + """Compile a HuggingFace decoder model to PLENA ISA and simulation metadata.""" + def _verbose(message: str = ""): + if verbose: + print(message) + + model_cfg = extract_model_config(model) + hidden = model_cfg.hidden_size + inter = model_cfg.inter_dim + num_heads = model_cfg.num_heads + num_kv_heads = model_cfg.num_kv_heads + head_dim = model_cfg.head_dim + total_q_dim = model_cfg.total_q_dim + ratio = model_cfg.head_ratio + + root = find_model_root(model) layers = root.layers n_layers = num_layers if num_layers is not None else len(layers) assert layer_idx_start + n_layers <= len(layers), ( @@ -490,46 +296,15 @@ def compile_hf_model( ) scale = 1.0 / math.sqrt(head_dim) - - # ----------------------------------------------------------- embedding - embed = getattr(root, "embed_tokens", getattr(root, "wte", None)) - - # ----------------------------------------------------------- lm_head - lm_head_weight = None - vocab_size = native_cfg.get("vocab_size") - if include_lm_head: - lm_head_mod = getattr(model, "lm_head", None) - if lm_head_mod is None: - lm_head_mod = getattr( - getattr(model, "language_model", model), "lm_head", None - ) - if lm_head_mod is not None and hasattr(lm_head_mod, "weight"): - # nn.Linear stores (vocab, hidden) -> transpose to (hidden, vocab) - lm_head_weight_raw = lm_head_mod.weight.detach().T.contiguous() - # Slice to hidden if we're in sliced mode - lm_head_weight = lm_head_weight_raw[:hidden, :] - vocab_size = lm_head_weight.shape[1] - # Ensure vocab_size is a multiple of mlen (pad if needed) - if vocab_size % mlen != 0: - pad_cols = mlen - (vocab_size % mlen) - lm_head_weight = F.pad(lm_head_weight, (0, pad_cols)) - vocab_size = lm_head_weight.shape[1] - else: - print("WARNING: include_lm_head=True but no lm_head module found; skipping") - include_lm_head = False + embed = embedding_module(root) print("=" * 80) - print(f"Model Compiler - {native_cfg['model_type']} ({n_layers} layers)") - print(f" native: hidden={native_cfg['hidden_size']}, inter={native_cfg['inter_dim']}, " - f"heads={native_cfg['num_heads']}/{native_cfg['num_kv_heads']}, " - f"head_dim={native_cfg['head_dim']}") - print(f" sim: hidden={hidden}, inter={inter}, head_dim={head_dim}, " - f"seq_len={seq_len}, mlen={mlen}, native_mode={native_mode}") - if native_mode: - print(f" MHA: num_heads={num_heads}, num_kv_heads={num_kv_heads}, " - f"total_q_dim={total_q_dim}, total_kv_dim={total_kv_dim}") - if include_lm_head: - print(f" lm_head: W_lm_head={lm_head_weight.shape}, vocab_size={vocab_size}") + print(f"Model Compiler - {model_cfg.model_type} ({n_layers} layer{'s' if n_layers != 1 else ''})") + print( + f" decoder: hidden={hidden}, inter={inter}, heads={num_heads}/{num_kv_heads}, " + f"head_dim={head_dim}" + ) + print(f" compile: seq_len={seq_len}, mlen={mlen}, blen={blen}, total_q_dim={total_q_dim}") print("=" * 80) # ----------------------------------------------------------- weights @@ -537,37 +312,30 @@ def compile_hf_model( all_weights = [] for i in range(n_layers): layer_module = layers[layer_idx_start + i] - w = _extract_layer_weights( - layer_module, hidden, inter, head_dim, - num_heads, native_head_dim, - num_kv_heads=num_kv_heads, - native_mode=native_mode, - ) + w = extract_layer_weights(layer_module, model_cfg) all_weights.append(w) - if native_mode: - print(f" Layer {i}: W_q={w['W_q'].shape}, W_o={w['W_o'].shape}, " - f"W_gate={w['W_gate'].shape}, " - f"K_heads={len(w['W_k_heads'])}x{w['W_k_heads'][0].shape}, eps={w['eps']}") - else: - print(f" Layer {i}: W_q={w['W_q'].shape}, W_o={w['W_o'].shape}, " - f"W_gate={w['W_gate'].shape}, W_k={w['W_k'].shape}, eps={w['eps']}") + _verbose( + f" Layer {i}: W_q={w.w_q.shape}, W_o={w.w_o.shape}, " + f"W_gate={w.w_gate.shape}, " + f"K_heads={len(w.w_k_heads)}x{w.w_k_heads[0].shape}, eps={w.eps}" + ) - eps = all_weights[0]["eps"] + eps = all_weights[0].eps - # ----------------------------------------------------------- test data torch.manual_seed(seed) - # Embedding lookup: use real HF embedding table if available if embed is not None: - input_ids = torch.randint(0, native_cfg["vocab_size"] or 32000, (seq_len,)) + input_ids = torch.randint(0, model_cfg.vocab_size or 32000, (seq_len,)) with torch.no_grad(): token_embeds = embed(input_ids).float() if token_embeds.dim() == 3: token_embeds = token_embeds.squeeze(0) - # Slice to target hidden dim (native mode uses full hidden) + # Slice to the model hidden width. token_embeds = token_embeds[:, :hidden] - print(f"\nEmbedding lookup: input_ids shape={input_ids.shape}, " - f"token_embeds={token_embeds.shape}") + _verbose( + f"\nEmbedding lookup: input_ids shape={input_ids.shape}, " + f"token_embeds={token_embeds.shape}" + ) else: token_embeds = torch.randn(seq_len, hidden) print(f"\nNo embed_tokens found; using random token_embeds: {token_embeds.shape}") @@ -576,215 +344,51 @@ def compile_hf_model( # Set pos_weight to zeros so embedding_add is a no-op for position. pos_weight = torch.zeros(seq_len, hidden) - # K/V test data: native mode computes K/V on-chip, legacy mode uses precomputed - if native_mode: - # K/V are computed on-chip from X_normed @ W_k/W_v — no precomputed K/V needed - print(f"pos_weight: zeros {pos_weight.shape} (Llama uses RoPE, not learned pos embed)") - for i in range(n_layers): - for kv_h in range(num_kv_heads): - print(f" W_k_{i}_h{kv_h}: {all_weights[i]['W_k_heads'][kv_h].shape}, " - f"W_v_{i}_h{kv_h}: {all_weights[i]['W_v_heads'][kv_h].shape}") - else: - K_mats = [] - V_mats = [] - for i in range(n_layers): - X_ctx = torch.randn(seq_len, hidden) - K_mats.append(X_ctx @ all_weights[i]["W_k"]) - V_mats.append(X_ctx @ all_weights[i]["W_v"]) - - print(f"pos_weight: zeros {pos_weight.shape} (Llama uses RoPE, not learned pos embed)") - for i in range(n_layers): - print(f" K_{i}: {K_mats[i].shape}, V_{i}: {V_mats[i].shape}") - print(f"attn_scale: {scale:.6f}") - - # ----------------------------------------------------------- golden ref - _do_quant = golden_precision in ("hardware", "no_bf16") - _do_bf16 = golden_precision in ("hardware", "no_weight_quant") - _qw = quantize_to_mxfp if _do_quant else (lambda x: x) - _to_inter = (lambda x: x.to(torch.bfloat16)) if _do_bf16 else (lambda x: x) - _from_inter = (lambda x: x.float()) if _do_bf16 else (lambda x: x) - _prec_label = {"hardware": "MXFP8 weights + BF16 intermediates", - "no_weight_quant": "float32 weights + BF16 intermediates", - "no_bf16": "MXFP8 weights + float32 intermediates", - "fp32": "float32 weights + float32 intermediates"}[golden_precision] - print(f"\n--- CPU Golden Reference ({_prec_label}) ---") - - if native_mode: - W_k_q_heads = [[_qw(all_weights[i]["W_k_heads"][h]) - for h in range(num_kv_heads)] - for i in range(n_layers)] - W_v_q_heads = [[_qw(all_weights[i]["W_v_heads"][h]) - for h in range(num_kv_heads)] - for i in range(n_layers)] - if native_mode: - R_matrix = _make_rotate_half_matrix(head_dim) - R_rope_q = _qw(R_matrix) - cos_table, sin_table = _make_rope_tables(seq_len, head_dim, native_cfg["rope_theta"]) - - else: - K_q_list = [_qw(K_mats[i]) for i in range(n_layers)] - V_q_list = [_qw(V_mats[i]) for i in range(n_layers)] - - X_gold = _qw(token_embeds.clone()) + _qw(pos_weight) # embedding_add (MXFP8-quantized, matching HBM) - ratio = num_heads // num_kv_heads - + _verbose(f"pos_weight: zeros {pos_weight.shape} (RoPE model; learned position add is a no-op)") for i in range(n_layers): - w = all_weights[i] - W_q_q = _qw(w["W_q"]) - W_o_q = _qw(w["W_o"]) - W_gate_q = _qw(w["W_gate"]) - W_up_q = _qw(w["W_up"]) - W_down_q = _qw(w["W_down"]) - - # --- Attention block --- - residual = X_gold.clone() - X_bf = _to_inter(X_gold) - # Hardware: scalar rms stays in f32, V_MUL_VF multiplies BF16 * f32 -> BF16 - rms = torch.rsqrt(_from_inter(X_bf).pow(2).mean(-1, keepdim=True) + eps) - X_gold = (_from_inter(X_bf) * rms).to(torch.bfloat16).float() - if native_mode: - Q_gold = _ksplit_matmul(X_gold, W_q_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - Q_gold = _from_inter(_to_inter(Q_gold)) # VRAM write: BF16 - - K_q_heads_i = [] - V_q_heads_i = [] - for kv_h in range(num_kv_heads): - K_h = _ksplit_matmul(X_gold, W_k_q_heads[i][kv_h], mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - V_h = _ksplit_matmul(X_gold, W_v_q_heads[i][kv_h], mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - K_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(K_h)), _from_inter(_to_inter(R_rope_q))))) - # Hardware RoPE: V_MUL_VV (BF16*BF16->BF16), V_ADD_VV (BF16+BF16->BF16) - K_cos = _from_inter(_to_inter(K_h.to(torch.bfloat16).float() * cos_table.to(torch.bfloat16).float())) - K_rot_sin = _from_inter(_to_inter(K_rot_h.to(torch.bfloat16).float() * sin_table.to(torch.bfloat16).float())) - K_h = _from_inter(_to_inter(K_cos) + _to_inter(K_rot_sin)) - # Hardware stores K/V to HBM as MXFP8, loads back as BF16 for attention - K_q_heads_i.append(_from_inter(_to_inter(_qw(K_h)))) - V_q_heads_i.append(_from_inter(_to_inter(_qw(V_h)))) - - O_heads = [] - for h in range(num_heads): - kv_h = h // ratio - Q_h = Q_gold[:, h * head_dim:(h + 1) * head_dim] - Q_rot_h = _from_inter(_to_inter(torch.matmul(_from_inter(_to_inter(Q_h)), _from_inter(_to_inter(R_rope_q))))) - # Hardware RoPE: V_MUL_VV (BF16*BF16->BF16), V_ADD_VV (BF16+BF16->BF16) - Q_cos = _from_inter(_to_inter(Q_h.to(torch.bfloat16).float() * cos_table.to(torch.bfloat16).float())) - Q_rot_sin = _from_inter(_to_inter(Q_rot_h.to(torch.bfloat16).float() * sin_table.to(torch.bfloat16).float())) - Q_h = _from_inter(_to_inter(Q_cos) + _to_inter(Q_rot_sin)) - O_h = _flash_attn_ref(Q_h, K_q_heads_i[kv_h], V_q_heads_i[kv_h], scale, causal=True) - O_heads.append(O_h) - attn_out = _from_inter(_to_inter(torch.cat(O_heads, dim=1))) # VRAM write: BF16 - O_gold = _ksplit_matmul(attn_out, W_o_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - O_gold = _from_inter(_to_inter(O_gold)) # VRAM write: BF16 - X_gold = _from_inter(_to_inter(O_gold + residual)) # residual add -> VRAM write: BF16 - else: - attn_out = _flash_attn_ref(X_gold, K_q_list[i], V_q_list[i], scale, causal=True) - X_gold = _from_inter(_to_inter(attn_out + residual)) # residual add -> VRAM write: BF16 - - # --- FFN block --- - residual = X_gold.clone() - X_bf = _to_inter(X_gold) - # Hardware: scalar rms stays in f32, V_MUL_VF multiplies BF16 * f32 -> BF16 - rms = torch.rsqrt(_from_inter(X_bf).pow(2).mean(-1, keepdim=True) + eps) - X_gold = (_from_inter(X_bf) * rms).to(torch.bfloat16).float() - # Hardware path: MXFP8 (HBM) → BF16 (MRAM) → float32 (M_MM) - # Use _ksplit_matmul to match hardware K-split BF16 precision loss for - # projections that exceed MAX_K_TILES (e.g., hidden=576 → 9 tiles, inter=1536 → 24 tiles) - up_out = _ksplit_matmul(X_gold, W_up_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - gate_out = _ksplit_matmul(X_gold, W_gate_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - # Hardware: SiLU(up) * gate -> BF16 VRAM write before down projection - silu_gate = _to_inter(F.silu(_from_inter(_to_inter(up_out))) * _from_inter(_to_inter(gate_out))) - X_gold = _ksplit_matmul(_from_inter(silu_gate), W_down_q, mlen, _HW_MAX_K_TILES, _to_inter, _from_inter) - X_gold = _from_inter(_to_inter(X_gold)) # VRAM write: BF16 after down proj - X_gold = _from_inter(_to_inter(X_gold + residual)) # residual add -> VRAM write: BF16 - - print(f" After layer {i}: X_gold[0,:4] = {X_gold[0, :4].tolist()}") + for kv_h in range(num_kv_heads): + _verbose( + f" W_k_{i}_h{kv_h}: {all_weights[i].w_k_heads[kv_h].shape}, " + f"W_v_{i}_h{kv_h}: {all_weights[i].w_v_heads[kv_h].shape}" + ) + print(f"attn_scale: {scale:.6f}") - # Final norm - X_gold = _rms_norm_ref(X_gold, eps) - - # lm_head projection (optional) - if include_lm_head and lm_head_weight is not None: - W_lm_q = quantize_to_mxfp(lm_head_weight) - # Hardware: MXFP8 → BF16 (MRAM) → float32 (M_MM) - logits_gold = torch.matmul( - X_gold.to(torch.bfloat16).float(), W_lm_q.to(torch.bfloat16).float() - ).to(torch.bfloat16).float() - golden_out = logits_gold - print(f" logits_gold: {golden_out.shape}") - else: - golden_out = X_gold + R_matrix, cos_table, sin_table = make_rope_inputs(seq_len, model_cfg) + + golden_policy = ReferencePrecision.from_mode(golden_precision) + print(f"\nComputing CPU golden reference ({golden_policy.label})") + golden_out = run_decoder_reference( + token_embeds, + pos_weight, + all_weights, + model_cfg, + R_matrix, + cos_table, + sin_table, + mlen=mlen, + precision=golden_policy, + trace=lambda i, x: _verbose(f" After layer {i}: X_gold[0,:4] = {x[0, :4].tolist()}"), + ) print(f" golden_out: {golden_out.shape}") - print(f" golden_out[0,:4]: {golden_out[0, :4].tolist()}") + _verbose(f" golden_out[0,:4]: {golden_out[0, :4].tolist()}") - # ----------------------------------------------------------- HF ground truth - # Same pipeline as golden but pure float32: no MXFP8 quantization, no BF16 casting. - # This is the "best possible" reference for the sliced weight dimensions being tested. - print(f"\n--- HF Ground Truth (float32, {n_layers} layers, no quantization) ---") + print(f"\nComputing HF reference (float32, {n_layers} layer{'s' if n_layers != 1 else ''}, no quantization)") with torch.no_grad(): - X_hf = token_embeds.clone() + pos_weight # embedding_add - - for i in range(n_layers): - w = all_weights[i] - eps_i = w["eps"] - - # --- Attention block (float32) --- - residual = X_hf.clone() - rms = torch.rsqrt(X_hf.pow(2).mean(-1, keepdim=True) + eps_i) - X_normed = X_hf * rms - - if native_mode: - Q_hf = X_normed @ w["W_q"].float() - - K_hf_heads = [] - V_hf_heads = [] - R_mat_f32 = R_matrix.float() - cos_f32 = cos_table.float() - sin_f32 = sin_table.float() - for kv_h in range(num_kv_heads): - K_h = X_normed @ w["W_k_heads"][kv_h].float() - V_h = X_normed @ w["W_v_heads"][kv_h].float() - # RoPE on K_h (float32) - K_rot = K_h @ R_mat_f32 - K_h = K_h * cos_f32 + K_rot * sin_f32 - K_hf_heads.append(K_h) - V_hf_heads.append(V_h) - - O_heads_hf = [] - for h in range(num_heads): - kv_h = h // ratio - Q_h = Q_hf[:, h * head_dim:(h + 1) * head_dim] - # RoPE on Q_h (float32) - Q_rot = Q_h @ R_mat_f32 - Q_h = Q_h * cos_f32 + Q_rot * sin_f32 - O_h = _flash_attn_ref(Q_h, K_hf_heads[kv_h], V_hf_heads[kv_h], scale, causal=True) - O_heads_hf.append(O_h) - attn_out_hf = torch.cat(O_heads_hf, dim=1) - O_hf = attn_out_hf @ w["W_o"].float() - X_hf = O_hf + residual - else: - attn_out_hf = _flash_attn_ref(X_normed, K_mats[i].float(), V_mats[i].float(), scale, causal=True) - X_hf = attn_out_hf + residual - - # --- FFN block (float32) --- - residual = X_hf.clone() - rms = torch.rsqrt(X_hf.pow(2).mean(-1, keepdim=True) + eps_i) - X_normed = X_hf * rms - up_out = F.silu(X_normed @ w["W_up"].float()) - gate_out = X_normed @ w["W_gate"].float() - X_hf = (up_out * gate_out) @ w["W_down"].float() + residual - - print(f" After layer {i}: X_hf[0,:4] = {X_hf[0, :4].tolist()}") - - # Final norm (float32) - rms = torch.rsqrt(X_hf.pow(2).mean(-1, keepdim=True) + eps) - X_hf = X_hf * rms - - if include_lm_head and lm_head_weight is not None: - hf_ground_truth = (X_hf @ lm_head_weight.float()) - else: - hf_ground_truth = X_hf + hf_ground_truth = run_decoder_reference( + token_embeds, + pos_weight, + all_weights, + model_cfg, + R_matrix, + cos_table, + sin_table, + mlen=mlen, + precision=ReferencePrecision.from_mode("hf_fp32"), + trace=lambda i, x: _verbose(f" After layer {i}: X_hf[0,:4] = {x[0, :4].tolist()}"), + ) print(f" hf_ground_truth: {hf_ground_truth.shape}") - print(f" hf_ground_truth[0,:4]: {hf_ground_truth[0, :4].tolist()}") + _verbose(f" hf_ground_truth[0,:4]: {hf_ground_truth[0, :4].tolist()}") # ----------------------------------------------------------- PLENA ISA print("\n--- PLENA Backend (ISA generation) ---") @@ -797,17 +401,11 @@ def compile_hf_model( x_input = prog.input("X", shape=(seq_len, hidden)) pos_input = prog.input("POS", shape=(seq_len, hidden)) - # RoPE inputs (native mode only: R_rope matrix + cos/sin tables) - if native_mode: - rope_theta = native_cfg["rope_theta"] - R_matrix = _make_rotate_half_matrix(head_dim) - cos_table, sin_table = _make_rope_tables(seq_len, head_dim, rope_theta) - - r_input = prog.input("R_rope", shape=(head_dim, head_dim)) - cos_input = prog.input("COS", shape=(seq_len, head_dim)) - sin_input = prog.input("SIN", shape=(seq_len, head_dim)) - COS = prog.load_batch(cos_input, name="COS") - SIN = prog.load_batch(sin_input, name="SIN") + r_input = prog.input("R_rope", shape=(head_dim, head_dim)) + cos_input = prog.input("COS", shape=(seq_len, head_dim)) + sin_input = prog.input("SIN", shape=(seq_len, head_dim)) + COS = prog.load_batch(cos_input, name="COS") + SIN = prog.load_batch(sin_input, name="SIN") # Causal mask: (mlen, mlen) with 0 on/below diagonal, -inf above causal_mask_data = torch.zeros(mlen, mlen) @@ -820,35 +418,7 @@ def compile_hf_model( # Per-layer weight inputs (order determines HBM layout) layer_inputs = [] for i in range(n_layers): - wq = prog.input(f"W_q_{i}", shape=(hidden, total_q_dim)) - wo = prog.input(f"W_o_{i}", shape=(total_q_dim, hidden)) - if native_mode: - wk_heads = [] - wv_heads = [] - for kv_h in range(num_kv_heads): - wk_heads.append(prog.input(f"W_k_{i}_h{kv_h}", shape=(hidden, head_dim))) - wv_heads.append(prog.input(f"W_v_{i}_h{kv_h}", shape=(hidden, head_dim))) - li_entry = { - "W_q": wq, "W_o": wo, - "W_k_heads": wk_heads, "W_v_heads": wv_heads, - } - else: - ki = prog.input(f"K_{i}", shape=(seq_len, head_dim)) - vi = prog.input(f"V_{i}", shape=(seq_len, head_dim)) - li_entry = { - "W_q": wq, "W_o": wo, - "K": ki, "V": vi, - } - wg = prog.input(f"W_gate_{i}", shape=(hidden, inter)) - wu = prog.input(f"W_up_{i}", shape=(hidden, inter)) - wd = prog.input(f"W_down_{i}", shape=(inter, hidden)) - li_entry.update({"W_gate": wg, "W_up": wu, "W_down": wd}) - layer_inputs.append(li_entry) - - # lm_head weight input (after all layer weights in HBM layout) - lm_head_input = None - if include_lm_head and lm_head_weight is not None: - lm_head_input = prog.input("W_lm_head", shape=(hidden, vocab_size)) + layer_inputs.append(_register_layer_inputs(prog, i, all_weights[i])) # Load activations to VRAM X_batch = prog.load_batch(x_input, name="X") @@ -860,8 +430,7 @@ def compile_hf_model( # The residual scratch buffer must be placed ABOVE this region. _ffn_intermediate_end = seq_len * hidden + 2 * inter * seq_len _current_bump = 2 * seq_len * hidden # X + POS already allocated - if native_mode: - _current_bump += 2 * seq_len * head_dim # COS + SIN loaded to VRAM + _current_bump += 2 * seq_len * head_dim # COS + SIN loaded to VRAM _current_bump += mlen * mlen # CAUSAL_MASK loaded to VRAM if _current_bump < _ffn_intermediate_end: _pad_elems = _ffn_intermediate_end - _current_bump @@ -881,117 +450,30 @@ def compile_hf_model( li = layer_inputs[i] # Layer progress marker (visible in non-quiet emulator output) - prog._compiler.generated_code += f"; === LAYER {i}/{n_layers} START ===\n" - - # --- Attention block --- - # Save residual: scratch = current (zero then add) - prog.vram_fill_zero(scratch) - prog.vram_add(scratch, current) - - # Norm (in-place on current) - prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) - - if native_mode: - # Q projection: current (seq, hidden) @ W_q (hidden, total_q_dim) - Q = _linear_projection(prog, current, li["W_q"], f"Q_{i}") - - # Per-KV-head: compute K/V on-chip, store to HBM, then run Q-heads - q_full_addr = prog._compiler.get_vram_addr(Q.name) - - # Allocate O_full for concatenated head outputs - O_full = prog.alloc(f"O_full_{i}", seq_len, total_q_dim) - o_full_addr = prog._compiler.get_vram_addr(O_full.name) - - # Store K/V InputVars for each KV head (filled during loop below) - kv_stored = [] - - for kv_h in range(num_kv_heads): - # K projection: current (seq, hidden) @ W_k_h (hidden, head_dim) - K_h = _linear_projection(prog, current, li["W_k_heads"][kv_h], f"K_{i}_h{kv_h}") - # V projection: current (seq, hidden) @ W_v_h (hidden, head_dim) - V_h = _linear_projection(prog, current, li["W_v_heads"][kv_h], f"V_{i}_h{kv_h}") - - # RoPE on K_h: K_rot = linear(K_h, R_rope), then K_h = K_h * cos + K_rot * sin - K_rot = _linear_projection(prog, K_h, r_input, f"K_rot_{i}_h{kv_h}") - prog.rope(K_h, K_rot, COS, SIN) - prog.free_tensor(K_rot) - - # Store K/V from VRAM to HBM (auto-allocates HBM space) - K_stored = prog.store(K_h, name=f"K_stored_{i}_h{kv_h}") - V_stored = prog.store(V_h, name=f"V_stored_{i}_h{kv_h}") - kv_stored.append((K_stored, V_stored)) - - # Free K/V VRAM — data is in HBM now - prog.free_tensor(K_h) - prog.free_tensor(V_h) - - for h in range(num_heads): - kv_h = h // ratio - K_stored, V_stored = kv_stored[kv_h] - - # View for this head's Q slice: column block h in Q_full - q_h_addr = q_full_addr + h * seq_len * mlen - Q_h = prog.alloc_at(f"Q_h{h}_{i}", seq_len, head_dim, q_h_addr) - - # RoPE on Q_h: Q_rot = linear(Q_h, R_rope), then Q_h = Q_h * cos + Q_rot * sin - Q_rot = _linear_projection(prog, Q_h, r_input, f"Q_rot_{i}_h{h}") - prog.rope(Q_h, Q_rot, COS, SIN) - prog.free_tensor(Q_rot) - - # Single-head flash attention (K/V read from HBM) with causal mask - O_h = ops.flash_attention(prog, Q_h, K_stored, V_stored, scale, - causal_mask=CAUSAL_MASK) - - # Copy O_h to the right column block of O_full - o_h_dest_addr = o_full_addr + h * seq_len * mlen - O_h_dest = prog.alloc_at(f"O_dest_h{h}_{i}", seq_len, head_dim, o_h_dest_addr) - prog.vram_fill_zero(O_h_dest) - prog.vram_add(O_h_dest, O_h) - - # Free flash_attention intermediates (S, PV, O) to reclaim VRAM - for _tmp_name in ("O", "S", "PV"): - _tmp_var = prog._tensors.get(_tmp_name) - if _tmp_var is not None: - prog.free_tensor(_tmp_var) - prog._tensors.pop(_tmp_name, None) - - # O projection: O_full @ W_o -> O_proj (seq, hidden) - O_proj = _linear_projection(prog, O_full, li["W_o"], f"O_proj_{i}") - - # Attention residual: O_proj += scratch - prog.vram_add(O_proj, scratch) - current_after_attn = O_proj - else: - # Legacy: X is Q directly (no projections), with causal mask - O = ops.flash_attention(prog, current, li["K"], li["V"], scale, - causal_mask=CAUSAL_MASK) - prog.vram_add(O, scratch) - current_after_attn = O - - # --- FFN block --- - # Save residual: scratch = current_after_attn (zero then add) - prog.vram_fill_zero(scratch) - prog.vram_add(scratch, current_after_attn) - - # Norm (in-place) - prog.rms_norm(current_after_attn, eps_offset=3, reci_hid_offset=4) - - # FFN (in-place) - ops.ffn(prog, current_after_attn, li["W_gate"], li["W_up"], li["W_down"]) - - # FFN residual - prog.vram_add(current_after_attn, scratch) + prog.emit_comment(f"=== LAYER {i}/{n_layers} START ===") + + current_after_attn = _emit_attention_block( + prog, + current, + li, + (r_input, COS, SIN), + CAUSAL_MASK, + scratch, + scale, + i, + seq_len, + head_dim, + total_q_dim, + num_heads, + num_kv_heads, + ratio, + ) - current = current_after_attn # carry forward - prog._compiler.generated_code += f"; === LAYER {i}/{n_layers} COMPLETE ===\n" + current = _emit_ffn_block(prog, current_after_attn, li, scratch) + prog.emit_comment(f"=== LAYER {i}/{n_layers} COMPLETE ===") # Final norm - prog.rms_norm(current, eps_offset=3, reci_hid_offset=4) - - # lm_head projection (optional): logits = linear(current, W_lm_head) - if include_lm_head and lm_head_input is not None: - logits = _linear_projection(prog, current, lm_head_input, "logits") - current = logits + ops.rms_norm(prog, current, eps_offset=3, reci_hid_offset=4) isa_code = prog.compile() isa_code = _fix_large_immediates(isa_code) @@ -999,45 +481,14 @@ def compile_hf_model( print(f"\nGenerated {len(lines)} lines of ISA code") # ----------------------------------------------------------- build return - input_tensors = {"X": token_embeds, "POS": pos_weight} - data_order = ["X", "POS"] - if native_mode: - input_tensors["R_rope"] = R_matrix - input_tensors["COS"] = cos_table - input_tensors["SIN"] = sin_table - data_order.extend(["R_rope", "COS", "SIN"]) + input_tensors = {"X": token_embeds, "POS": pos_weight, "R_rope": R_matrix, "COS": cos_table, "SIN": sin_table} + data_order = ["X", "POS", "R_rope", "COS", "SIN"] input_tensors["causal_mask"] = causal_mask_data data_order.append("causal_mask") for i in range(n_layers): - input_tensors[f"W_q_{i}"] = all_weights[i]["W_q"] - input_tensors[f"W_o_{i}"] = all_weights[i]["W_o"] - if native_mode: - for kv_h in range(num_kv_heads): - input_tensors[f"W_k_{i}_h{kv_h}"] = all_weights[i]["W_k_heads"][kv_h] - input_tensors[f"W_v_{i}_h{kv_h}"] = all_weights[i]["W_v_heads"][kv_h] - else: - input_tensors[f"K_{i}"] = K_mats[i] - input_tensors[f"V_{i}"] = V_mats[i] - input_tensors[f"W_gate_{i}"] = all_weights[i]["W_gate"] - input_tensors[f"W_up_{i}"] = all_weights[i]["W_up"] - input_tensors[f"W_down_{i}"] = all_weights[i]["W_down"] - - kv_keys = [] - if native_mode: - for kv_h in range(num_kv_heads): - kv_keys.extend([f"W_k_{i}_h{kv_h}", f"W_v_{i}_h{kv_h}"]) - else: - kv_keys = [f"K_{i}", f"V_{i}"] - data_order.extend([ - f"W_q_{i}", f"W_o_{i}", - *kv_keys, - f"W_gate_{i}", f"W_up_{i}", f"W_down_{i}", - ]) - - # lm_head weight in input_tensors + data_order - if include_lm_head and lm_head_weight is not None: - input_tensors["W_lm_head"] = lm_head_weight - data_order.append("W_lm_head") + for name, tensor in all_weights[i].tensor_entries(i): + input_tensors[name] = tensor + data_order.append(name) # FPRAM layout (same as single-layer decoder): # slot 0 = 0.0 (reserved) @@ -1050,13 +501,9 @@ def compile_hf_model( fp_preload = [0.0, scale, float("-inf"), eps, 1.0 / hidden, 1.0] + [0.0] * 4 # Result is at current's VRAM location - o_vram_addr = prog._compiler.get_vram_addr(current.name) + o_vram_addr = prog.get_vram_addr(current.name) - # Output dimensions depend on whether lm_head is included - if include_lm_head and lm_head_weight is not None: - out_features = vocab_size - else: - out_features = hidden + out_features = hidden comparison_params = { "start_row_idx": o_vram_addr // mlen, @@ -1068,17 +515,14 @@ def compile_hf_model( } info = { - "model_type": native_cfg["model_type"], + "model_type": model_cfg.model_type, "hidden_size": hidden, "inter_dim": inter, "num_layers": n_layers, "seq_len": seq_len, "head_dim": head_dim, - "num_heads": num_heads if native_mode else 1, - "num_kv_heads": num_kv_heads if native_mode else 1, - "native_mode": native_mode, - "include_lm_head": include_lm_head and lm_head_weight is not None, - "vocab_size": vocab_size if (include_lm_head and lm_head_weight is not None) else None, + "num_heads": num_heads, + "num_kv_heads": num_kv_heads, "mlen": mlen, "blen": blen, "isa_lines": len(lines), @@ -1086,9 +530,6 @@ def compile_hf_model( print(f"\nCompilation complete: {info['isa_lines']} ISA lines, " f"{n_layers} layers, output at VRAM row {o_vram_addr // mlen}") - if include_lm_head and lm_head_weight is not None: - print(f" lm_head: output shape=({seq_len}, {vocab_size})") - return { "isa": isa_code, "golden_output": golden_out, @@ -1100,126 +541,3 @@ def compile_hf_model( "info": info, "golden_precision": golden_precision, } - - -# --------------------------------------------------------------------------- -# Convenience: compile + run emulator + compare -# --------------------------------------------------------------------------- -def compile_and_run( - model, - build_dir, - **kwargs, -) -> dict: - """Compile, run emulator, and compare against golden. - - Convenience wrapper that calls compile_hf_model, sets up simulation - environment, runs the Rust transactional emulator, and compares output. - - Args: - model: nn.Module (HF CausalLM model, already loaded) - build_dir: Directory for simulation artifacts - **kwargs: Forwarded to compile_hf_model (seq_len, hidden_size, etc.) - - Returns: - dict with compilation info + comparison results including - 'allclose_match_rate' percentage. - """ - from transactional_emulator.tools.create_sim_env import create_sim_env - from compiler.sim_env_utils.build_env import create_mem_for_sim - from transactional_emulator.testbench.emulator_runner import ( - run_and_assert, - compare_emulator_output, - ) - - result = compile_hf_model(model, **kwargs) - build_dir = Path(build_dir) - build_dir.mkdir(parents=True, exist_ok=True) - - mlen = kwargs.get("mlen", 64) - blen = kwargs.get("blen", 4) - asm_name = f"model_{result['info']['model_type']}_{result['info']['num_layers']}L" - - # Write sim env files - create_sim_env( - result["input_tensors"], - result["isa"], - {"original_output": result["golden_output"]}, - result["fp_preload"], - build_dir=str(build_dir), - ) - - create_mem_for_sim( - data_size=256, - mode="behave_sim", - asm=asm_name, - data=None, - specified_data_order=result["data_order"], - build_path=build_dir, - ) - - with open(build_dir / "comparison_params.json", "w") as f: - json.dump(result["comparison_params"], f, indent=2) - - with open(build_dir / "generated_asm_code.asm", "w") as f: - f.write(result["isa"]) - - print(f"\nSimulation environment created: {build_dir}") - print(f" Result location: VRAM row {result['comparison_params']['start_row_idx']}") - print(f" Layers: {result['info']['num_layers']}, data_order: {result['data_order']}") - - # Run emulator and compare (don't exit on failure — VRAM stage comparison follows) - from transactional_emulator.testbench.emulator_runner import update_plena_config, run_emulator - update_plena_config(vlen=mlen, mlen=mlen, blen=blen, verbose=False) - print("\n--- Running Rust transactional emulator ---") - run_emulator(build_dir) - - print("\n--- Comparing emulator output vs golden ---") - comp_results, _params = compare_emulator_output(build_dir) - from transactional_emulator.tools.check_mem import print_comparison_results - print_comparison_results(comp_results, verbose=True, comparison_params=_params) - - if comp_results["allclose_pass"]: - print(f"\n[ATen-style {asm_name} test PASSED - ISA generated + emulator verified]") - else: - print(f"\n[ATen-style {asm_name} test FAILED - emulator numerical check failed]") - - # Three-way comparison - golden = result["golden_output"] - hf_gt = result["hf_ground_truth"] - print("\n--- Three-way comparison ---") - if hf_gt is not None and golden is not None: - # HF float32 vs golden (MXFP8 + BF16) - n = min(hf_gt.numel(), golden.numel()) - allclose_hf_vs_gold = ( - torch.isclose(hf_gt.float().flatten()[:n], - golden.float().flatten()[:n], atol=1e-2) - .float().mean().item() * 100 - ) - print(f" HF float32 vs golden (MXFP8+BF16): {allclose_hf_vs_gold:.2f}% allclose") - # Emulator vs golden: reported by compare_emulator_output - emu_match = comp_results.get("allclose_match_rate", None) - if emu_match is not None: - print(f" Emulator vs golden (MXFP8+BF16): {emu_match:.2f}% allclose") - - # VRAM stage comparison: validates each pipeline segment using - # emulator's own intermediates as golden input (immune to accumulation drift) - try: - from compiler.aten.vram_stage_compare import compare_stages - emulator_dir = Path(__file__).parent.parent.parent / "transactional_emulator" - vram_path = str(emulator_dir / "vram_dump.bin") - print("\n--- VRAM stage comparison (authoritative) ---") - stage_results = compare_stages( - vram_path=vram_path, - build_dir=str(build_dir), - hidden=result["info"]["hidden_size"], - inter=result["info"].get("inter_dim", result["info"]["hidden_size"] * 4), - num_heads=result["info"]["num_heads"], - num_kv_heads=result["info"]["num_kv_heads"], - ) - stage_pass = stage_results.get("norm+FFN+norm", 0) >= 99.0 - comp_results["vram_stage_allclose"] = stage_results.get("norm+FFN+norm", None) - comp_results["vram_stage_pass"] = stage_pass - except Exception as e: - print(f" (skipped: {e})") - - return {**result["info"], **comp_results} diff --git a/aten/reference.py b/aten/reference.py new file mode 100644 index 0000000..bde2020 --- /dev/null +++ b/aten/reference.py @@ -0,0 +1,291 @@ +"""CPU decoder references used by the PLENA ATen frontend.""" + +from __future__ import annotations + +import math +from collections.abc import Callable +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from compiler.aten.model_extract import LayerWeights, ModelConfig +from quant.quantizer.hardware_quantizer.mxfp import _mx_fp_quantize_hardware + + +_HW_MAX_K_TILES = 4 + + +@dataclass(frozen=True) +class ReferencePrecision: + """Precision policy for the shared decoder reference runner.""" + + name: str + label: str + quantize_hbm: bool + bf16_intermediates: bool + use_ksplit: bool + + @classmethod + def from_mode(cls, mode: str) -> ReferencePrecision: + modes = { + "hardware": cls("hardware", "MXFP8 weights + BF16 intermediates", True, True, True), + "no_weight_quant": cls( + "no_weight_quant", + "float32 weights + BF16 intermediates", + False, + True, + True, + ), + "no_bf16": cls("no_bf16", "MXFP8 weights + float32 intermediates", True, False, True), + "fp32": cls("fp32", "float32 weights + float32 intermediates", False, False, True), + "hf_fp32": cls("hf_fp32", "float32, no quantization", False, False, False), + } + try: + return modes[mode] + except KeyError as exc: + valid = ", ".join(modes) + raise ValueError(f"Unknown reference precision '{mode}'. Expected one of: {valid}") from exc + + def quantize(self, tensor: torch.Tensor) -> torch.Tensor: + return quantize_to_mxfp(tensor) if self.quantize_hbm else tensor + + def to_inter(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(torch.bfloat16) if self.bf16_intermediates else tensor + + def from_inter(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.float() if self.bf16_intermediates else tensor + + +def quantize_to_mxfp(tensor: torch.Tensor) -> torch.Tensor: + """Quantize tensor to MXFP8 matching HBM hardware format; return dequantized result.""" + orig_shape = tensor.shape + tensor_2d = tensor.float().reshape(-1, tensor.shape[-1]) + bm_x, _, _, _ = _mx_fp_quantize_hardware( + tensor_2d, + width=8, + exponent_width=4, + exponent_bias_width=8, + block_size=[1, 8], + ) + return bm_x.reshape(orig_shape) + + +def _make_rope_tables(seq_len: int, head_dim: int, theta: float = 10000.0): + """Compute RoPE cos/sin tables, shape (seq_len, head_dim).""" + half = head_dim // 2 + freqs = 1.0 / (theta ** (torch.arange(0, half).float() / half)) + positions = torch.arange(seq_len).float() + angles = torch.outer(positions, freqs) + cos_half = torch.cos(angles) + sin_half = torch.sin(angles) + cos = torch.cat([cos_half, cos_half], dim=-1) + sin = torch.cat([sin_half, sin_half], dim=-1) + return cos, sin + + +def _make_rotate_half_matrix(head_dim: int) -> torch.Tensor: + """Build the (head_dim, head_dim) matrix that computes rotate_half.""" + rotate = torch.zeros(head_dim, head_dim) + half = head_dim // 2 + for i in range(half): + rotate[i + half, i] = -1.0 + rotate[i, i + half] = 1.0 + return rotate + + +def make_rope_inputs(seq_len: int, config: ModelConfig) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build the rotate_half matrix and RoPE tables used by CPU and PLENA paths.""" + rotate = _make_rotate_half_matrix(config.head_dim) + cos_table, sin_table = _make_rope_tables(seq_len, config.head_dim, config.rope_theta) + return rotate, cos_table, sin_table + + +def run_decoder_reference( + token_embeds: torch.Tensor, + pos_weight: torch.Tensor, + weights: list[LayerWeights], + config: ModelConfig, + rope_matrix: torch.Tensor, + cos_table: torch.Tensor, + sin_table: torch.Tensor, + *, + mlen: int, + precision: ReferencePrecision, + trace: Callable[[int, torch.Tensor], None] | None = None, +) -> torch.Tensor: + """Run the compiled decoder blocks under a given precision policy.""" + quantize = precision.quantize + x = quantize(token_embeds.clone()) + quantize(pos_weight) + rope_ref = quantize(rope_matrix) + + for layer_idx, layer in enumerate(weights): + x = _attention_block_ref( + x, + layer, + config, + rope_ref, + cos_table, + sin_table, + mlen, + precision, + ) + x = _ffn_block_ref(x, layer, mlen, precision) + + if trace is not None: + trace(layer_idx, x) + + return _rms_norm_ref(x, weights[0].eps, precision) + + +def _attention_block_ref( + x: torch.Tensor, + layer: LayerWeights, + config: ModelConfig, + rope_matrix: torch.Tensor, + cos_table: torch.Tensor, + sin_table: torch.Tensor, + mlen: int, + precision: ReferencePrecision, +) -> torch.Tensor: + quantize = precision.quantize + residual = x.clone() + x_normed = _rms_norm_ref(x, layer.eps, precision) + q_full = _round(_linear_ref(x_normed, quantize(layer.w_q), mlen, precision), precision) + + k_heads = [] + v_heads = [] + for kv_h in range(config.num_kv_heads): + k_h = _linear_ref(x_normed, quantize(layer.w_k_heads[kv_h]), mlen, precision) + v_h = _linear_ref(x_normed, quantize(layer.w_v_heads[kv_h]), mlen, precision) + k_h = _rope_ref(k_h, rope_matrix, cos_table, sin_table, precision) + k_heads.append(_hbm_round_ref(k_h, precision)) + v_heads.append(_hbm_round_ref(v_h, precision)) + + scale = 1.0 / math.sqrt(config.head_dim) + o_heads = [] + for h in range(config.num_heads): + kv_h = h // config.head_ratio + q_h = q_full[:, h * config.head_dim:(h + 1) * config.head_dim] + q_h = _rope_ref(q_h, rope_matrix, cos_table, sin_table, precision) + o_heads.append(_flash_attn_ref(q_h, k_heads[kv_h], v_heads[kv_h], scale, causal=True)) + + attn_out = _round(torch.cat(o_heads, dim=1), precision) + o_proj = _round(_linear_ref(attn_out, quantize(layer.w_o), mlen, precision), precision) + return _residual_add_ref(o_proj, residual, precision) + + +def _ffn_block_ref( + x: torch.Tensor, + layer: LayerWeights, + mlen: int, + precision: ReferencePrecision, +) -> torch.Tensor: + quantize = precision.quantize + residual = x.clone() + x_normed = _rms_norm_ref(x, layer.eps, precision) + up_out = _linear_ref(x_normed, quantize(layer.w_up), mlen, precision) + gate_out = _linear_ref(x_normed, quantize(layer.w_gate), mlen, precision) + silu_gate = precision.to_inter(F.silu(_round(up_out, precision)) * _round(gate_out, precision)) + x = _linear_ref(precision.from_inter(silu_gate), quantize(layer.w_down), mlen, precision) + return _residual_add_ref(_round(x, precision), residual, precision) + + +def _round(x: torch.Tensor, precision: ReferencePrecision) -> torch.Tensor: + return precision.from_inter(precision.to_inter(x)) + + +def _rms_norm_ref(x: torch.Tensor, eps: float, precision: ReferencePrecision) -> torch.Tensor: + x_inter = precision.to_inter(x) + rms = torch.rsqrt(precision.from_inter(x_inter).pow(2).mean(-1, keepdim=True) + eps) + return _round(precision.from_inter(x_inter) * rms, precision) + + +def _linear_ref( + x: torch.Tensor, + weight: torch.Tensor, + mlen: int, + precision: ReferencePrecision, +) -> torch.Tensor: + if not precision.use_ksplit: + return x @ weight.float() + return _ksplit_matmul( + x, + weight, + mlen=mlen, + max_k_tiles=_HW_MAX_K_TILES, + to_inter=precision.to_inter, + from_inter=precision.from_inter, + ) + + +def _rope_ref( + x: torch.Tensor, + rope_matrix: torch.Tensor, + cos_table: torch.Tensor, + sin_table: torch.Tensor, + precision: ReferencePrecision, +) -> torch.Tensor: + x_inter = _round(x, precision) + x_rot = _round( + torch.matmul(x_inter, _round(rope_matrix, precision)), + precision, + ) + x_cos = _round(x_inter * _round(cos_table, precision), precision) + x_rot_sin = _round(x_rot * _round(sin_table, precision), precision) + return _round(x_cos + x_rot_sin, precision) + + +def _hbm_round_ref(x: torch.Tensor, precision: ReferencePrecision) -> torch.Tensor: + return _round(precision.quantize(x), precision) + + +def _residual_add_ref( + x: torch.Tensor, + residual: torch.Tensor, + precision: ReferencePrecision, +) -> torch.Tensor: + return _round(x + residual, precision) + + +def _flash_attn_ref(Q, K, V, scale, causal=False): + """CPU reference: scaled dot-product attention matching hardware BF16 precision.""" + scores = (Q @ K.T).to(torch.bfloat16).float() * scale + scores = scores.to(torch.bfloat16).float() + if causal: + mask = torch.triu(torch.ones(scores.shape[-2], scores.shape[-1], device=scores.device), diagonal=1).bool() + scores.masked_fill_(mask, float("-inf")) + attn = F.softmax(scores, dim=-1).to(torch.bfloat16).float() + return (attn @ V).to(torch.bfloat16).float() + + +def _ksplit_matmul(A, B, mlen=64, max_k_tiles=_HW_MAX_K_TILES, to_inter=None, from_inter=None): + """Matrix multiply matching hardware K-split BF16 precision.""" + if to_inter is None: + def to_inter(x): + return x.to(torch.bfloat16) + + if from_inter is None: + def from_inter(x): + return x.float() + + k_total = A.shape[1] + num_k_tiles = math.ceil(k_total / mlen) + + if num_k_tiles <= max_k_tiles: + return from_inter(to_inter(torch.matmul(from_inter(to_inter(A)), from_inter(to_inter(B))))) + + result = None + k_start = 0 + while k_start < k_total: + k_end = min(k_start + max_k_tiles * mlen, k_total) + a_chunk = A[:, k_start:k_end] + b_chunk = B[k_start:k_end, :] + partial = from_inter(to_inter(torch.matmul(from_inter(to_inter(a_chunk)), from_inter(to_inter(b_chunk))))) + if result is None: + result = partial + else: + result = from_inter(to_inter(result) + to_inter(partial)) + k_start = k_end + + return result diff --git a/aten/tests/test_plena_compiler.py b/aten/tests/test_plena_compiler.py index f168504..a10932e 100644 --- a/aten/tests/test_plena_compiler.py +++ b/aten/tests/test_plena_compiler.py @@ -1,13 +1,13 @@ """Unit tests for PlenaCompiler ATen path (no emulator needed). -Run: PYTHONPATH=.:tools:compiler python3 compiler/aten/tests/test_plena_compiler.py +Run: PYTHONPATH=.:tools:.. python3 aten/tests/test_plena_compiler.py """ import sys import os # Insert PLENA_Simulator root and tools/ so imports resolve correctly regardless -# of how the test is invoked (direct python3 or via PYTHONPATH=.:tools:compiler). +# of how the test is invoked (direct python3 or via PYTHONPATH=.:tools:..). _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _SIM_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR)))) _SIM_ROOT = os.path.join(_SIM_ROOT, "PLENA_Simulator") @@ -18,12 +18,76 @@ if _p not in sys.path: sys.path.insert(0, _p) -import torch +import torch # noqa: E402 + + +def test_isa_builder_renders_typed_instruction(): + """Typed ISA builder should render the existing asm syntax.""" + from compiler.aten.isa_builder import IsaBuilder, fp, gp + + asm = IsaBuilder() + asm.comment("typed builder smoke") + asm.instr("S_ADDI_INT", gp(3), gp(0), 4096) + asm.instr("S_ADD_FP", fp(1), fp(1), fp(2)) + + rendered = asm.render() + assert "; typed builder smoke" in rendered + assert "S_ADDI_INT gp3, gp0, 4096" in rendered + assert "S_ADD_FP f1, f1, f2" in rendered + print(" PASS test_isa_builder_renders_typed_instruction") + + +def test_isa_builder_legalizes_large_absolute_immediates(): + """Typed ISA builder should split large absolute S_ADDI_INT loads.""" + from compiler.aten.isa_builder import IsaBuilder, gp + + rendered = IsaBuilder().instr("S_ADDI_INT", gp(5), gp(0), 300000).render() + assert "S_LUI_INT gp5, 73" in rendered + assert "S_ADDI_INT gp5, gp5, 992" in rendered + assert "S_ADDI_INT gp5, gp0, 300000" not in rendered + print(" PASS test_isa_builder_legalizes_large_absolute_immediates") + + +def test_isa_builder_preserves_relative_large_immediates(): + """Typed ISA builder should not split relative S_ADDI_INT instructions.""" + from compiler.aten.isa_builder import IsaBuilder, gp + + rendered = IsaBuilder().instr("S_ADDI_INT", gp(5), gp(3), 300000).render() + assert "S_ADDI_INT gp5, gp3, 300000" in rendered + assert "S_LUI_INT" not in rendered + print(" PASS test_isa_builder_preserves_relative_large_immediates") + + +def test_fpvar_helper_uses_canonical_emit_path(): + """Converted FPVar helpers should still append and return asm text.""" + from compiler.aten.plena import PlenaCompiler + + prog = PlenaCompiler() + code = prog.fpvar_add_asm(src1_addr=0, src2_addr=4, dst_addr=8, count=2) + + assert code == prog.get_code() + assert "S_ADD_FP f1, f1, f2" in code + assert "C_LOOP_START" in code + print(" PASS test_fpvar_helper_uses_canonical_emit_path") + + +def test_hbm_load_helper_uses_typed_legalization(): + """Converted HBM load helpers should legalize typed large immediates.""" + from compiler.aten.plena import IsaCompiler + + compiler = IsaCompiler() + compiler.register_matrix("W", (512, 512), hbm_base_addr=0) + + code = compiler.load_sub_matrix_asm("W", row_idx=0, col_idx=0, mram_dest_addr=0) + assert "S_LUI_INT gp1, 64" in code + assert "S_ADDI_INT gp1, gp0, 262144" not in code + assert "H_PREFETCH_M gp3, gp1, a1, 1, 0" in code + print(" PASS test_hbm_load_helper_uses_typed_legalization") def test_vram_fill_zero_all_column_blocks(): """vram_fill_zero must zero ALL column blocks of a wide matrix.""" - from compiler.aten.plena_compiler import PlenaCompiler + from compiler.aten.plena import PlenaCompiler prog = PlenaCompiler() x = prog.alloc("X", 64, 384) @@ -39,7 +103,7 @@ def test_vram_fill_zero_all_column_blocks(): def test_vram_add_all_column_blocks(): """vram_add must add ALL column blocks of wide matrices.""" - from compiler.aten.plena_compiler import PlenaCompiler + from compiler.aten.plena import PlenaCompiler prog = PlenaCompiler() x = prog.alloc("X", 64, 384) @@ -56,7 +120,7 @@ def test_vram_add_all_column_blocks(): def test_alloc_at_correct_address(): """alloc_at must create a VRAM view at the specified address.""" - from compiler.aten.plena_compiler import PlenaCompiler + from compiler.aten.plena import PlenaCompiler prog = PlenaCompiler() # Allocate a matrix, then create a view into its second column block @@ -67,7 +131,7 @@ def test_alloc_at_correct_address(): view_addr = x_addr + 2 * 64 * 64 view = prog.alloc_at("X_cb2_view", 64, 64, view_addr) - actual_addr = prog._compiler.get_vram_addr(view.name) + actual_addr = prog.get_vram_addr(view.name) assert actual_addr == view_addr, ( f"alloc_at address mismatch: expected {view_addr}, got {actual_addr}" ) @@ -150,7 +214,7 @@ def test_compile_hf_model_golden_vs_hf(): ) model.eval() - r = compile_hf_model(model, seq_len=64, hidden_size=None, inter_dim=None, num_layers=1) + r = compile_hf_model(model, seq_len=64, num_layers=1) golden = r["golden_output"] hf = r["hf_ground_truth"] @@ -167,16 +231,18 @@ def test_compile_hf_model_golden_vs_hf(): def test_native_compile_assembles(): """Native-dim ISA must assemble without overflow.""" - from compiler.aten.plena_frontend import compile_hf_model, _fix_large_immediates + import os + import tempfile + + from compiler.aten.plena_frontend import compile_hf_model from transformers import AutoModelForCausalLM - import tempfile, os model = AutoModelForCausalLM.from_pretrained( "AICrossSim/clm-60m", torch_dtype=torch.float32 ) model.eval() - r = compile_hf_model(model, seq_len=64, hidden_size=None, inter_dim=None, num_layers=1) + r = compile_hf_model(model, seq_len=64, num_layers=1) isa = r["isa"] # Assemble — should not raise ValueError (u32 overflow) @@ -213,6 +279,11 @@ def test_native_compile_assembles(): print("=" * 60) tests = [ + test_isa_builder_renders_typed_instruction, + test_isa_builder_legalizes_large_absolute_immediates, + test_isa_builder_preserves_relative_large_immediates, + test_fpvar_helper_uses_canonical_emit_path, + test_hbm_load_helper_uses_typed_legalization, test_vram_fill_zero_all_column_blocks, test_vram_add_all_column_blocks, test_alloc_at_correct_address, diff --git a/aten/tests/test_quantization_ablation.py b/aten/tests/test_quantization_ablation.py index 2142df3..13762f1 100644 --- a/aten/tests/test_quantization_ablation.py +++ b/aten/tests/test_quantization_ablation.py @@ -9,8 +9,8 @@ fp32 (fp32 + fp32) ~99% allclose ← confirms quantization is sole cause Usage: - pytest compiler/aten/tests/test_quantization_ablation.py -v -s - python3 compiler/aten/tests/test_quantization_ablation.py [--layers N] + pytest aten/tests/test_quantization_ablation.py -v -s + python3 aten/tests/test_quantization_ablation.py [--layers N] """ import argparse @@ -84,8 +84,8 @@ def test_mxfp8_is_sole_gap_source(): for mode in MODES: r = results[mode] print(f" {mode:<20} {r['allclose']:>11.2f}% {r['mse']:>15.6e}") - print(f"\n MXFP8 weight quantization = 100% of the gap") - print(f" BF16 intermediate precision = 0% of the gap") + print("\n MXFP8 weight quantization = 100% of the gap") + print(" BF16 intermediate precision = 0% of the gap") if __name__ == "__main__": @@ -103,5 +103,5 @@ def test_mxfp8_is_sole_gap_source(): for mode in MODES: r = results[mode] print(f" {mode:<20} {r['allclose']:>11.2f}% {r['mse']:>15.6e}") - print(f"\n Conclusion: MXFP8 weight quantization = 100% of the gap") - print(f" BF16 intermediate precision = 0% of the gap") + print("\n Conclusion: MXFP8 weight quantization = 100% of the gap") + print(" BF16 intermediate precision = 0% of the gap") diff --git a/compiler/__init__.py b/compiler/__init__.py new file mode 100644 index 0000000..a01b45f --- /dev/null +++ b/compiler/__init__.py @@ -0,0 +1,12 @@ +"""Compatibility namespace for legacy ``compiler.*`` imports. + +PLENA_Compiler packages live at the repository root (``aten``, ``generator``, +``asm_templates``, ...), but existing code imports them through ``compiler``. +Keep that import path local to this submodule instead of resolving to the +simulator sibling directory. +""" + +from pathlib import Path + +__path__ = [str(Path(__file__).resolve().parent.parent)] + diff --git a/doc/precision.svh b/doc/precision.svh index a8a1f1c..2652c7d 100644 --- a/doc/precision.svh +++ b/doc/precision.svh @@ -1,6 +1,6 @@ // precision.svh - PLENA numeric precision parameters. // -// Intentionally minimal: compiler/generator/parser/hardware_parser.py (and the +// Intentionally minimal: generator/parser/hardware_parser.py (and the // helpers under tools/) tolerate an empty precision file and fall back to the // following per-parameter defaults when the corresponding `parameter` line is // absent: diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 6fe9cd5..a0ff701 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -3,7 +3,7 @@ ## Directory Structure ``` -compiler/ +PLENA_Compiler/ |-- asm_templates/ # Shared ISA instruction emitters (used by BOTH pipelines) | |-- flashattn/ # Flash attention (qkt, pv, online_softmax, output, reset) | |-- ffn_asm.py # FFN (gate/up/down with K-split) @@ -26,15 +26,20 @@ compiler/ | +-- reset_reg_asm.py # Register reset helpers | |-- aten/ # Pipeline 1: ATen compilation backend -| |-- plena_compiler.py # PlenaCompiler class (VRAM/MRAM/FPRAM management) +| |-- plena_frontend.py # HF model -> PLENA program -> ISA text +| |-- e2e_runner.py # ATen e2e runner: PlenaCompiler -> emulator -> golden +| |-- plena/ # Canonical PlenaCompiler implementation package +| | |-- compiler.py # PlenaCompiler composition class +| | |-- memory_state.py # Tensor/input/FP memory state +| | |-- program_*.py # High-level program operations +| | +-- isa_*.py # Low-level ISA emitters | +-- ops/ # Registered ATen op implementations | |-- registry.py # Op dispatch registry -| |-- plena/ # PLENA backend (linear, attention, ffn, norm, conv, softmax, embedding) +| |-- plena/ # PLENA backend wrappers | +-- cpu/ # CPU reference implementations | |-- generator/ # Pipeline 2: Config-driven code generation -| |-- runner.py # CLI entry point (codegen, aten, utilization modes) -| |-- aten_runner.py # ATen backend bridge (wraps PlenaCompiler for E2E) +| |-- runner.py # CLI entry point for codegen/utilization | |-- parser/ # HF config -> symbolic graph | | |-- llm_parser.py # LLMModelParser (text decoder graphs) | | +-- hardware_parser.py # configuration.svh / precision.svh reader @@ -82,7 +87,7 @@ compiler/ nn.Module -> torch.export (FX graph of ATen ops) -> op dispatch (aten/ops/registry.py) - -> PlenaCompiler (aten/plena_compiler.py) + -> PlenaCompiler (aten/plena/) - VRAM/MRAM/FPRAM allocation - HBM weight layout + address register init - calls asm_templates/* for ISA emission @@ -100,7 +105,7 @@ HF config.json -> code_gen_pass (generator/passes/code_gen.py) - dispatches each node to asm_templates/* -> assembler/ (ASM -> .mem binary) - -> emulator (runs but numerically incorrect -- no addr reg init) + -> emulator smoke / utilization analysis ``` See `docs/COMPILATION_PIPELINES.md` for detailed comparison, known gaps, diff --git a/docs/ATEN_TREE.md b/docs/ATEN_TREE.md new file mode 100644 index 0000000..e5800da --- /dev/null +++ b/docs/ATEN_TREE.md @@ -0,0 +1,77 @@ +# ATen Compiler Tree + +This summarizes the current `aten/` package layout. Generated `__pycache__/` +directories are intentionally omitted. + +```text +aten/ +|-- __init__.py # Public ATen package exports +|-- native_ops.yaml # Operator registry spec: signatures and dispatch targets +|-- isa_builder.py # Typed ISA instruction/register builder and legalization +|-- model_extract.py # HuggingFace model config/layer/embedding extraction helpers +|-- plena_frontend.py # HF decoder model -> PLENA program -> ISA text +|-- e2e_runner.py # HF model -> ATen compiler -> emulator -> golden check +|-- reference.py # CPU golden/reference math and MXFP/BF16 helpers +|-- vram_stage_compare.py # Debug tooling for VRAM stage comparisons +| +|-- ops/ # ATen-style operator dispatch layer +| |-- __init__.py # User-facing ops.* dispatch functions +| |-- registry.py # Backend registry: CPU vs PLENA implementation lookup +| | +| |-- cpu/ # PyTorch reference backend +| | |-- __init__.py +| | |-- attention_ops.py # CPU flash attention reference +| | |-- conv_ops.py # CPU conv reference +| | |-- embedding_ops.py # CPU embedding/RoPE reference +| | |-- ffn_ops.py # CPU FFN reference +| | |-- linear_ops.py # CPU linear reference +| | |-- norm_ops.py # CPU RMS/layer norm reference +| | +-- softmax_ops.py # CPU softmax reference +| | +| +-- plena/ # PLENA backend operator wrappers +| |-- __init__.py +| |-- attention_ops.py # ops.flash_attention -> prog.flash_attention +| |-- conv_ops.py # conv lowering / PLENA conv codegen helper +| |-- embedding_ops.py # embedding_add / rope wrappers +| |-- ffn_ops.py # ops.ffn -> prog.ffn +| |-- linear_ops.py # ops.linear -> prog.linear_projection +| |-- norm_ops.py # rms_norm / layer_norm wrappers +| +-- softmax_ops.py # PLENA softmax lowering +| +|-- plena/ # Canonical PLENA compiler implementation package +| |-- __init__.py # Canonical exports: PlenaCompiler, IsaCompiler, vars, constants +| |-- compiler.py # Top-level PlenaCompiler composition class +| |-- constants.py # BLEN/MLEN/immediate constants +| |-- vars.py # Tensor/Input/VRAM/FP variable descriptors +| |-- registers.py # GP/ADDR/FP register allocation helpers +| |-- memory.py # Memory layout/address helpers +| |-- memory_state.py # Compiler-owned tensor/input/fp memory state +| | +| |-- program_tensors.py # Program-level tensor allocation/load/store helpers +| |-- program_fp_tile_ops.py # Program-level FP/tile scalar-vector operations +| |-- program_matrix_ops.py # Program-level matrix/projection/FFN operations +| |-- program_attention.py # Program-level flash attention operation +| | +| |-- isa_compiler.py # Low-level ISA emitter base and typed emit path +| |-- isa_emit.py # Generic emit helpers +| |-- isa_fp_ops.py # Scalar FP/FPRAM/tile FP ISA helpers +| |-- isa_tile_rows.py # Tile-row unary/binary loop emitters +| |-- isa_matrix.py # Matrix/projection/load/store ISA emitters +| +-- isa_attention.py # Attention-specific ISA helpers +| ++-- tests/ + |-- __init__.py + |-- test_bf16_numerical_stability.py + |-- test_plena_compiler.py + +-- test_quantization_ablation.py +``` + +Key points: + +- `aten/plena/` is the canonical compiler implementation package. +- `aten/ops/` is the ATen-style dispatcher surface. +- `aten/plena_frontend.py` is the HuggingFace/ATen frontend that drives model + compilation. +- `aten/e2e_runner.py` runs the ATen compiler path through the emulator and + golden comparison. +- The old `aten/plena_compiler.py` compatibility facade has been removed. diff --git a/docs/COMPILATION_PIPELINES.md b/docs/COMPILATION_PIPELINES.md index 38dc7d2..69e2ccc 100644 --- a/docs/COMPILATION_PIPELINES.md +++ b/docs/COMPILATION_PIPELINES.md @@ -49,7 +49,7 @@ backends, and weight-handling strategies. embedding). A CPU fallback registry (`aten/ops/cpu/`) provides reference implementations for ops not yet hardware-mapped. -3. **Backend**: `PlenaCompiler` (`aten/plena_compiler.py`) manages all +3. **Backend**: `PlenaCompiler` (`aten/plena/`) manages all hardware state -- VRAM allocation, MRAM tile scheduling, FPRAM slot assignment, HBM weight layout, and address register initialization (`C_SET_ADDR_REG`). It calls into `asm_templates/` to emit ISA strings. @@ -65,18 +65,19 @@ backends, and weight-handling strategies. | File | Role | |------|------| -| `aten/plena_compiler.py` | PlenaCompiler class (VRAM/MRAM/FPRAM management, ISA emission) | +| `aten/plena/` | Canonical PlenaCompiler implementation package | +| `aten/plena_frontend.py` | HuggingFace model frontend that drives ATen compilation | | `aten/ops/plena/*.py` | Registered ATen op implementations (linear, attention, ffn, norm, conv, softmax, embedding) | | `aten/ops/cpu/*.py` | CPU reference fallbacks | | `aten/ops/registry.py` | Op dispatch registry | -| `generator/aten_runner.py` | E2E harness: model load -> compile -> emulate -> verify | +| `aten/e2e_runner.py` | E2E harness: model load -> compile -> emulate -> verify | | `sim_env_utils/build_env.py` | Simulation environment builder | ### Entry points - **Single-layer tests**: `model_layer_test_builder.py::build_and_run_decoder_test` -- **Full-model E2E**: `generator/aten_runner.py::run_aten_e2e` -- **CLI**: `python -m generator.runner aten --seq-len 32 --num-layers 1` +- **Full-model E2E**: `aten/e2e_runner.py::run_aten_e2e` +- **CLI**: `python -m compiler.aten.e2e_runner --seq-len 32 --num-layers 1` ### Test suite @@ -88,8 +89,7 @@ pass with 98-100% allclose. ## Pipeline 2: Generator Path -**Status**: Generates valid ISA for analysis. Numerically incomplete -(HBM address registers uninitialized). +**Status**: Generates valid ISA for structural analysis and smoke tests. ### How it works @@ -104,16 +104,16 @@ pass with 98-100% allclose. 3. **Backend**: `code_gen_pass` (`generator/passes/code_gen.py`) walks the symbolic graph and dispatches each node to the appropriate `asm_templates/` function, passing scheduler-derived register and - address parameters. + address parameters. It emits address-register initialization for HBM-backed + weights before the generated compute body. 4. **Weight loading**: For E2E smoke tests, `test_generator_e2e.py` has a `_build_hbm_from_hf_weights` helper that loads real weights. The standard codegen path does not touch weights at all. -5. **Output**: A `.asm` file that assembles cleanly and runs on the emulator - (the instructions are structurally valid), but produces numerically - incorrect results because HBM address registers (`C_SET_ADDR_REG`) are - not initialized. +5. **Output**: A `.asm` file that assembles cleanly and runs on the emulator. + The generator path is still primarily used for structural codegen and + utilization work; the ATen path remains the numerically verified flow. ### Key files diff --git a/generator/aten_runner.py b/generator/aten_runner.py index 72f2ddf..cef8d18 100644 --- a/generator/aten_runner.py +++ b/generator/aten_runner.py @@ -1,332 +1,12 @@ -""" -ATen-backend runner for the generator. - -Wraps the proven ATen compilation path (PlenaCompiler + ops.*) to provide -end-to-end model compilation + emulation + numerical verification. This -is the same pipeline that passes 18/18 tests in the ATen test suite. - -The generator's value is in config parsing and symbolic graph analysis; -actual ISA compilation is best left to PlenaCompiler. - -Usage (standalone): - python -m generator.aten_runner AICrossSim/clm-60m --seq-len 32 +"""Deprecated alias for the ATen e2e runner. -Usage (from generator.runner): - python -m generator.runner aten AICrossSim/clm-60m --seq-len 32 --num-layers 1 +Use ``python -m compiler.aten.e2e_runner`` or import +``compiler.aten.e2e_runner.run_aten_e2e`` directly. """ -import sys -import time -from pathlib import Path - -# --------------------------------------------------------------------------- -# Repo root bootstrap — mirror the same sys.path setup used by the existing -# test infrastructure so imports resolve regardless of cwd. -# --------------------------------------------------------------------------- -_COMPILER_ROOT = Path(__file__).resolve().parents[1] # compiler/ -_REPO_ROOT = _COMPILER_ROOT.parent -for _p in [str(_REPO_ROOT), str(_REPO_ROOT / "tools"), str(_COMPILER_ROOT)]: - if _p not in sys.path: - sys.path.insert(0, _p) - - -def run_aten_e2e( - model_id: str, - seq_len: int = 64, - num_layers: int = 1, - build_dir: str | None = None, - layer_idx: int = 0, - hidden_size: int = 64, - inter_dim: int = 128, - trust_remote_code: bool = False, - partial_load: bool = False, -) -> dict: - """Run a HF model through the ATen compilation path end-to-end. - - Steps: - 1. Load model config + layer weights from HuggingFace - 2. Build ISA via PlenaCompiler + ops.* (numerically verified path) - 3. Set up sim environment (ASM + HBM weights + FPRAM constants) - 4. Run Rust emulator - 5. Compare VRAM output against golden PyTorch reference - - Returns dict with: - passed: bool - allclose_match_rate: float (percentage) - max_error: float - mae: float - mse: float - elapsed_s: float (wall-clock seconds) - model_id: str - layer_idx: int - num_layers: int - seq_len: int - hidden_size: int - inter_dim: int - build_dir: str - """ - from transactional_emulator.testbench.model_layer_test_builder import ( - build_and_run_decoder_test, - build_and_run_multi_layer_test, - get_model_dims, - slice_dims_for_sim, - MLEN, - ) - from transactional_emulator.testbench.emulator_runner import ( - compare_emulator_output, - run_emulator, - ) - from transactional_emulator.testbench.config_utils import update_plena_config - from transactional_emulator.tools.check_mem import print_comparison_results - - t0 = time.time() - - # Resolve build directory - if build_dir is None: - safe_name = model_id.replace("/", "_") - build_dir = str( - Path("/tmp") / f"aten_e2e_{safe_name}_sl{seq_len}_l{layer_idx}" - ) - build_path = Path(build_dir) - - # ------------------------------------------------------------------ - # [1/5] Probe model config - # ------------------------------------------------------------------ - print(f"[1/5] Probing model config: {model_id}") - try: - full_dims = get_model_dims(model_id) - except (OSError, ConnectionError) as exc: - print(f"[SKIP] HuggingFace model '{model_id}' unavailable: {exc}") - return { - "passed": False, - "error": str(exc), - "model_id": model_id, - } - sim_dims = slice_dims_for_sim(full_dims, hidden_slice=hidden_size, inter_slice=inter_dim) - print(f" Full dims: hidden={full_dims.hidden_size}, inter={full_dims.inter_dim}, " - f"heads={full_dims.num_heads}, kv_heads={full_dims.num_kv_heads}, head_dim={full_dims.head_dim}") - print(f" Sim dims: hidden={sim_dims.hidden_size}, inter={sim_dims.inter_dim}") - - # ------------------------------------------------------------------ - # [2/5] Build ISA + golden reference + sim env via build_and_run_decoder_test - # - # We call the proven function directly — it handles: - # - Weight loading + slicing - # - PlenaCompiler ISA generation - # - create_sim_env + create_mem_for_sim - # - Golden reference computation - # - Emulator execution + comparison - # - # For multi-layer: iterate layers (each is independent at sim scale). - # ------------------------------------------------------------------ - results_per_layer = [] - - if num_layers == 1: - # Single layer: use proven single-layer path (with RoPE) - current_layer = layer_idx - asm_name = f"aten_{model_id.split('/')[-1]}_l{current_layer}" - layer_build = build_path / f"layer_{current_layer}" - - print(f"\n[2/5] Building ISA for layer {current_layer} via PlenaCompiler + ops.*") - print(f"[3/5] Setting up sim environment: {layer_build}") - print(f"[4/5] Running Rust transactional emulator") - - extra_kwargs = {} - if trust_remote_code: - extra_kwargs["trust_remote_code"] = True - if partial_load: - extra_kwargs["partial_load"] = True - - try: - build_and_run_decoder_test( - model_id=model_id, - asm_name=asm_name, - build_dir=layer_build, - layer_idx=current_layer, - seq_len=seq_len, - hidden_size=hidden_size, - inter_dim=inter_dim, - **extra_kwargs, - ) - comp_results, comp_params = compare_emulator_output(layer_build) - results_per_layer.append({ - "layer": current_layer, - "passed": True, - "allclose_match_rate": comp_results["allclose_match_rate"], - "max_error": comp_results["max_error"], - "mae": comp_results["mae"], - "mse": comp_results["mse"], - }) - except SystemExit as e: - if e.code == 0: - return { - "passed": False, - "error": "HuggingFace model unavailable (skipped)", - "model_id": model_id, - } - try: - comp_results, comp_params = compare_emulator_output(layer_build) - results_per_layer.append({ - "layer": current_layer, - "passed": False, - "allclose_match_rate": comp_results["allclose_match_rate"], - "max_error": comp_results["max_error"], - "mae": comp_results["mae"], - "mse": comp_results["mse"], - }) - except Exception: - results_per_layer.append({ - "layer": current_layer, - "passed": False, - "error": f"Emulator comparison failed after exit code {e.code}", - }) - else: - # Multi-layer: chain N layers with residual connections (no RoPE) - asm_name = f"aten_{model_id.split('/')[-1]}_chain{num_layers}" - chain_build = build_path / f"chain_{num_layers}layers" - - print(f"\n[2/5] Building chained {num_layers}-layer ISA via PlenaCompiler + ops.*") - print(f"[3/5] Setting up sim environment: {chain_build}") - print(f"[4/5] Running Rust transactional emulator") - - extra_kwargs = {} - if trust_remote_code: - extra_kwargs["trust_remote_code"] = True - if partial_load: - extra_kwargs["partial_load"] = True - - try: - build_and_run_multi_layer_test( - model_id=model_id, - asm_name=asm_name, - build_dir=chain_build, - num_layers=num_layers, - layer_idx_start=layer_idx, - seq_len=seq_len, - hidden_size=hidden_size, - inter_dim=inter_dim, - **extra_kwargs, - ) - comp_results, comp_params = compare_emulator_output(chain_build) - results_per_layer.append({ - "layer": f"chain_{num_layers}", - "passed": True, - "allclose_match_rate": comp_results["allclose_match_rate"], - "max_error": comp_results["max_error"], - "mae": comp_results["mae"], - "mse": comp_results["mse"], - }) - except SystemExit as e: - if e.code == 0: - return { - "passed": False, - "error": "HuggingFace model unavailable (skipped)", - "model_id": model_id, - } - try: - comp_results, comp_params = compare_emulator_output(chain_build) - results_per_layer.append({ - "layer": f"chain_{num_layers}", - "passed": False, - "allclose_match_rate": comp_results["allclose_match_rate"], - "max_error": comp_results["max_error"], - "mae": comp_results["mae"], - "mse": comp_results["mse"], - }) - except Exception: - results_per_layer.append({ - "layer": f"chain_{num_layers}", - "passed": False, - "error": f"Emulator comparison failed after exit code {e.code}", - }) - - elapsed = time.time() - t0 - - # ------------------------------------------------------------------ - # [5/5] Aggregate results - # ------------------------------------------------------------------ - print(f"\n[5/5] Results summary ({elapsed:.1f}s elapsed)") - all_passed = all(r.get("passed", False) for r in results_per_layer) - - # Use first layer's metrics for the top-level result - first = results_per_layer[0] if results_per_layer else {} - - summary = { - "passed": all_passed, - "allclose_match_rate": first.get("allclose_match_rate", 0.0), - "max_error": first.get("max_error", float("inf")), - "mae": first.get("mae", float("inf")), - "mse": first.get("mse", float("inf")), - "elapsed_s": elapsed, - "model_id": model_id, - "layer_idx": layer_idx, - "num_layers": num_layers, - "seq_len": seq_len, - "hidden_size": hidden_size, - "inter_dim": inter_dim, - "build_dir": str(build_path), - "layers": results_per_layer, - } - - for r in results_per_layer: - status = "PASS" if r.get("passed") else "FAIL" - match = r.get("allclose_match_rate", "N/A") - if isinstance(match, float): - match = f"{match:.2f}%" - print(f" Layer {r.get('layer', '?')}: [{status}] allclose={match}") - - if all_passed: - print(f"\n[ATen e2e PASSED] {model_id} — {num_layers} layer(s), " - f"allclose={first.get('allclose_match_rate', 0):.2f}%") - else: - print(f"\n[ATen e2e FAILED] {model_id} — see per-layer results above") - - return summary - - -# --------------------------------------------------------------------------- -# CLI entry point -# --------------------------------------------------------------------------- -def main(): - import argparse - - parser = argparse.ArgumentParser( - description="Run HF model through ATen compilation path (PlenaCompiler + ops.*)", - prog="python -m generator.aten_runner", - ) - parser.add_argument("model_id", help="HuggingFace model ID (e.g. AICrossSim/clm-60m)") - parser.add_argument("--seq-len", type=int, default=64, - help="Sequence length (default: 64)") - parser.add_argument("--num-layers", type=int, default=1, - help="Number of decoder layers to test (default: 1)") - parser.add_argument("--layer-idx", type=int, default=0, - help="Starting layer index (default: 0)") - parser.add_argument("--hidden-size", type=int, default=64, - help="Hidden dimension clipped to sim limits (default: 64)") - parser.add_argument("--inter-dim", type=int, default=128, - help="FFN intermediate dimension clipped to sim limits (default: 128)") - parser.add_argument("--build-dir", type=str, default=None, - help="Build directory for sim artifacts (default: /tmp/aten_e2e_...)") - parser.add_argument("--trust-remote-code", action="store_true", - help="Trust remote code for HF model loading") - parser.add_argument("--partial-load", action="store_true", - help="Load only needed weight shards (for large models)") - - args = parser.parse_args() - - result = run_aten_e2e( - model_id=args.model_id, - seq_len=args.seq_len, - num_layers=args.num_layers, - build_dir=args.build_dir, - layer_idx=args.layer_idx, - hidden_size=args.hidden_size, - inter_dim=args.inter_dim, - trust_remote_code=args.trust_remote_code, - partial_load=args.partial_load, - ) +from compiler.aten.e2e_runner import main, run_aten_e2e - sys.exit(0 if result["passed"] else 1) +__all__ = ["main", "run_aten_e2e"] if __name__ == "__main__": diff --git a/generator/runner.py b/generator/runner.py index 123a835..e1c7c0e 100644 --- a/generator/runner.py +++ b/generator/runner.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -PLENA Generator Runner -- entry point for both compilation pipelines. +PLENA generator runner -- symbolic codegen and ATen e2e convenience entry. Modes: codegen -- Pipeline 2 (Generator): HF config -> symbolic graph -> @@ -12,7 +12,7 @@ Examples: python -m generator.runner codegen AICrossSim/clm-60m output.asm --seq-len 512 - python -m generator.runner aten AICrossSim/clm-60m --seq-len 32 --num-layers 1 + python -m compiler.aten.e2e_runner AICrossSim/clm-60m --seq-len 32 --num-layers 1 See docs/COMPILATION_PIPELINES.md for the full architecture overview. """ @@ -28,8 +28,8 @@ def _run_aten(args) -> int: - """ATen-backed end-to-end: PlenaCompiler + ops.* → emulator → numerical check.""" - from generator.aten_runner import run_aten_e2e + """ATen-backed end-to-end: PlenaCompiler + ops.* -> emulator -> numerical check.""" + from compiler.aten.e2e_runner import run_aten_e2e result = run_aten_e2e( model_id=args.model_path, diff --git a/generator/tests/test_generator_e2e.py b/generator/tests/test_generator_e2e.py index 6efd40b..2ef75aa 100644 --- a/generator/tests/test_generator_e2e.py +++ b/generator/tests/test_generator_e2e.py @@ -89,7 +89,7 @@ def _build_hbm_from_hf_weights( ) -> dict: """Populate hbm_for_behave_sim.bin with real HF model weights. - Mirrors compiler/sim_env_utils/build_env.py::create_mem_for_sim but + Mirrors sim_env_utils/build_env.py::create_mem_for_sim but operates directly on HF tensors (no intermediate .pt files) and writes each weight block at the scheduler-assigned HBM offset. @@ -308,7 +308,7 @@ def _build_fp_sram_preload( slot (or the harness needs to refresh slot 5 between text and vision runs) to fix this. Tracking issue: TODO. - Slot map (from compiler/generator/scheduler/mem_layout_lib.json): + Slot map (from generator/scheduler/mem_layout_lib.json): 0: infinity — softmax masking sentinel (use a large fp16 negative) 1: eps — RMSNorm epsilon 2: hid_reciprocal — 1.0 / hidden_size @@ -518,7 +518,7 @@ def run_test_aten( numerical verification deferred, this immediately gets full numerical correctness via the mature ATen compilation backend. """ - from generator.aten_runner import run_aten_e2e + from compiler.aten.e2e_runner import run_aten_e2e print("=" * 80) print(f"Generator e2e harness (ATen backend) — {model_id} — " diff --git a/pyproject.toml b/pyproject.toml index a5a400a..dc610c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = [ + "compiler", "assembler", "asm_templates", "aten", diff --git a/sim_env_utils/__init__.py b/sim_env_utils/__init__.py index a40e6d3..b2e4ac8 100644 --- a/sim_env_utils/__init__.py +++ b/sim_env_utils/__init__.py @@ -1,4 +1,3 @@ -from .build_env import create_mem_for_sim +from .build_env import create_mem_for_sim as create_mem_for_sim -# Legacy alias -build_sim_env = create_mem_for_sim +__all__ = ["create_mem_for_sim"] diff --git a/sim_env_utils/build_env.py b/sim_env_utils/build_env.py index 9b8696c..6c99cc0 100644 --- a/sim_env_utils/build_env.py +++ b/sim_env_utils/build_env.py @@ -16,7 +16,7 @@ from .build_sys_tools import env_setup, init_mem # noqa: E402 -# Project root is 3 levels up from compiler/sim_env_utils/ +# Project root is 3 levels up from PLENA_Compiler/sim_env_utils/ _PROJECT_ROOT = Path(__file__).parent.parent.parent logger = get_logger("testbench") diff --git a/sim_env_utils/build_sys_tools.py b/sim_env_utils/build_sys_tools.py index a134b40..bd39558 100644 --- a/sim_env_utils/build_sys_tools.py +++ b/sim_env_utils/build_sys_tools.py @@ -12,7 +12,7 @@ from compiler.assembler.assembly_to_binary import AssemblyToBinary -# Project root is 3 levels up from compiler/sim_env_utils/ +# Project root is 3 levels up from PLENA_Compiler/sim_env_utils/ _PROJECT_ROOT = Path(__file__).parent.parent.parent