Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions aten/vram_stage_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
vram_path="transactional_emulator/vram_dump.bin",
build_dir="/tmp/smolvlm2_1layer_f32regs",
hidden=576, inter=1536, num_heads=9, num_kv_heads=3,
layer_idx=0,
)
"""
import re
import struct
import numpy as np
import torch
Expand Down Expand Up @@ -47,14 +49,36 @@ def _mse(a, b):
return ((a.float() - b.float()) ** 2).mean().item()


def _infer_final_layer_idx(build: Path) -> int:
indices = []
for path in build.glob("W_o_*.pt"):
match = re.fullmatch(r"W_o_(\d+)\.pt", path.name)
if match:
indices.append(int(match.group(1)))
if not indices:
raise FileNotFoundError(f"No W_o_<layer>.pt files found in {build}")
return max(indices)


def _read_alloc_addr(asm: str, name: str) -> int | None:
match = re.search(
rf"Allocate VRAM Matrix {re.escape(name)}: .*?VRAM\[(\d+)\]",
asm,
)
return int(match.group(1)) if match else None


def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads,
seq_len=64, mlen=64, head_dim=64, eps=1e-5, verbose=True):
seq_len=64, mlen=64, head_dim=64, eps=1e-5, verbose=True,
layer_idx=None):
"""Compare each pipeline stage using emulator's own VRAM intermediates.

Args:
vram_path: path to the emulator's vram_dump.bin
build_dir: path to the build directory with weight .pt files
hidden, inter, num_heads, num_kv_heads: model dimensions
layer_idx: decoder layer to validate. Defaults to the last layer found
in build_dir, which is the layer that feeds the final output.

Returns:
dict of stage results with allclose percentages
Expand All @@ -63,40 +87,39 @@ def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads,
_to_inter = lambda x: x.to(torch.bfloat16)
_from_inter = lambda x: x.float()

# VRAM addresses (from ISA comments — these are model-dependent)
# For SmolVLM2 1-layer: X=12288, scratch=233472, Q=270400, O_full=307264, O_proj=356416
# For clm-60m: different addresses. We compute from VRAM layout.
tiles = hidden // mlen
x_addr = 3 * mlen * mlen + 0 # after COS, SIN, mask (for native mode)
if layer_idx is None:
layer_idx = _infer_final_layer_idx(build)

# Read final output address from comparison_params
import json
params = json.load(open(build / "comparison_params.json"))
final_addr = params["start_row_idx"] * mlen

results = {}
results = {"layer_idx": layer_idx}

# --- Load weights ---
W_o = quantize_to_mxfp(torch.load(build / "W_o_0.pt", weights_only=True))
W_gate = quantize_to_mxfp(torch.load(build / "W_gate_0.pt", weights_only=True))
W_up = quantize_to_mxfp(torch.load(build / "W_up_0.pt", weights_only=True))
W_down = quantize_to_mxfp(torch.load(build / "W_down_0.pt", weights_only=True))
W_o = quantize_to_mxfp(torch.load(build / f"W_o_{layer_idx}.pt", weights_only=True))
W_gate = quantize_to_mxfp(torch.load(build / f"W_gate_{layer_idx}.pt", weights_only=True))
W_up = quantize_to_mxfp(torch.load(build / f"W_up_{layer_idx}.pt", weights_only=True))
W_down = quantize_to_mxfp(torch.load(build / f"W_down_{layer_idx}.pt", weights_only=True))

# --- Stage 1: O_full (attention output) ---
# Find O_full address from ISA comments
o_full_addr = final_addr - 2 * seq_len * hidden
import re
asm_path = build / "generated_asm_code.asm"
if asm_path.exists():
with open(asm_path) as f:
asm = f.read()
m = re.search(r'Allocate VRAM Matrix O_full_0.*?VRAM\[(\d+)\]', asm)
if m:
o_full_addr = int(m.group(1))
m2 = re.search(r'Allocate VRAM Matrix residual_scratch.*?VRAM\[(\d+)\]', asm)
scratch_addr = int(m2.group(1)) if m2 else None
parsed_o_full_addr = _read_alloc_addr(asm, f"O_full_{layer_idx}")
if parsed_o_full_addr is not None:
o_full_addr = parsed_o_full_addr
scratch_addr = _read_alloc_addr(asm, "residual_scratch")
else:
scratch_addr = None

if verbose:
print(f" Validating layer {layer_idx}")

O_full = _read_bf16_matrix(vram_path, o_full_addr, seq_len, hidden)

# --- Stage 2: O_proj = O_full @ W_o ---
Expand Down Expand Up @@ -163,11 +186,13 @@ def compare_stages(vram_path, build_dir, hidden, inter, num_heads, num_kv_heads,
import sys
vram = sys.argv[1] if len(sys.argv) > 1 else "transactional_emulator/vram_dump.bin"
build = sys.argv[2] if len(sys.argv) > 2 else "/tmp/smolvlm2_1layer_f32regs"
layer_idx = int(sys.argv[3]) if len(sys.argv) > 3 else None

print("=== VRAM Stage Comparison ===")
results = compare_stages(
vram_path=vram,
build_dir=build,
hidden=576, inter=1536, num_heads=9, num_kv_heads=3,
layer_idx=layer_idx,
)
print(f"\nOverall: {'PASS' if results.get('norm+FFN+norm', 0) >= 99.0 else 'FAIL'}")
Loading