diff --git a/aten/vram_stage_compare.py b/aten/vram_stage_compare.py index f169fb3..a206868 100644 --- a/aten/vram_stage_compare.py +++ b/aten/vram_stage_compare.py @@ -11,8 +11,10 @@ vram_path="transactional_emulator/vram_dump.bin", build_dir="/tmp/smolvlm2_1layer_f32regs", hidden=576, inter=1536, num_heads=9, num_kv_heads=3, + layer_idx=0, ) """ +import re import struct import numpy as np import torch @@ -47,14 +49,36 @@ def _mse(a, b): return ((a.float() - b.float()) ** 2).mean().item() +def _infer_final_layer_idx(build: Path) -> int: + indices = [] + for path in build.glob("W_o_*.pt"): + match = re.fullmatch(r"W_o_(\d+)\.pt", path.name) + if match: + indices.append(int(match.group(1))) + if not indices: + raise FileNotFoundError(f"No W_o_.pt files found in {build}") + return max(indices) + + +def _read_alloc_addr(asm: str, name: str) -> int | None: + match = re.search( + rf"Allocate VRAM Matrix {re.escape(name)}: .*?VRAM\[(\d+)\]", + asm, + ) + return int(match.group(1)) if match else None + + def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads, - seq_len=64, mlen=64, head_dim=64, eps=1e-5, verbose=True): + seq_len=64, mlen=64, head_dim=64, eps=1e-5, verbose=True, + layer_idx=None): """Compare each pipeline stage using emulator's own VRAM intermediates. Args: vram_path: path to the emulator's vram_dump.bin build_dir: path to the build directory with weight .pt files hidden, inter, num_heads, num_kv_heads: model dimensions + layer_idx: decoder layer to validate. Defaults to the last layer found + in build_dir, which is the layer that feeds the final output. Returns: dict of stage results with allclose percentages @@ -63,40 +87,39 @@ def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads, _to_inter = lambda x: x.to(torch.bfloat16) _from_inter = lambda x: x.float() - # VRAM addresses (from ISA comments — these are model-dependent) - # For SmolVLM2 1-layer: X=12288, scratch=233472, Q=270400, O_full=307264, O_proj=356416 - # For clm-60m: different addresses. We compute from VRAM layout. - tiles = hidden // mlen - x_addr = 3 * mlen * mlen + 0 # after COS, SIN, mask (for native mode) + if layer_idx is None: + layer_idx = _infer_final_layer_idx(build) + # Read final output address from comparison_params import json params = json.load(open(build / "comparison_params.json")) final_addr = params["start_row_idx"] * mlen - results = {} + results = {"layer_idx": layer_idx} # --- Load weights --- - W_o = quantize_to_mxfp(torch.load(build / "W_o_0.pt", weights_only=True)) - W_gate = quantize_to_mxfp(torch.load(build / "W_gate_0.pt", weights_only=True)) - W_up = quantize_to_mxfp(torch.load(build / "W_up_0.pt", weights_only=True)) - W_down = quantize_to_mxfp(torch.load(build / "W_down_0.pt", weights_only=True)) + W_o = quantize_to_mxfp(torch.load(build / f"W_o_{layer_idx}.pt", weights_only=True)) + W_gate = quantize_to_mxfp(torch.load(build / f"W_gate_{layer_idx}.pt", weights_only=True)) + W_up = quantize_to_mxfp(torch.load(build / f"W_up_{layer_idx}.pt", weights_only=True)) + W_down = quantize_to_mxfp(torch.load(build / f"W_down_{layer_idx}.pt", weights_only=True)) # --- Stage 1: O_full (attention output) --- # Find O_full address from ISA comments o_full_addr = final_addr - 2 * seq_len * hidden - import re asm_path = build / "generated_asm_code.asm" if asm_path.exists(): with open(asm_path) as f: asm = f.read() - m = re.search(r'Allocate VRAM Matrix O_full_0.*?VRAM\[(\d+)\]', asm) - if m: - o_full_addr = int(m.group(1)) - m2 = re.search(r'Allocate VRAM Matrix residual_scratch.*?VRAM\[(\d+)\]', asm) - scratch_addr = int(m2.group(1)) if m2 else None + parsed_o_full_addr = _read_alloc_addr(asm, f"O_full_{layer_idx}") + if parsed_o_full_addr is not None: + o_full_addr = parsed_o_full_addr + scratch_addr = _read_alloc_addr(asm, "residual_scratch") else: scratch_addr = None + if verbose: + print(f" Validating layer {layer_idx}") + O_full = _read_bf16_matrix(vram_path, o_full_addr, seq_len, hidden) # --- Stage 2: O_proj = O_full @ W_o --- @@ -163,11 +186,13 @@ def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads, import sys vram = sys.argv[1] if len(sys.argv) > 1 else "transactional_emulator/vram_dump.bin" build = sys.argv[2] if len(sys.argv) > 2 else "/tmp/smolvlm2_1layer_f32regs" + layer_idx = int(sys.argv[3]) if len(sys.argv) > 3 else None print("=== VRAM Stage Comparison ===") results = compare_stages( vram_path=vram, build_dir=build, hidden=576, inter=1536, num_heads=9, num_kv_heads=3, + layer_idx=layer_idx, ) print(f"\nOverall: {'PASS' if results.get('norm+FFN+norm', 0) >= 99.0 else 'FAIL'}")