refactor: slim PLENA compiler frontend and codegen#38
Merged
Conversation
Add model_compiler.py that walks an HF nn.Module tree and compiles to PLENA ISA via PlenaCompiler with pre-norm + residual connections. Two public functions: compile_hf_model(model, ...) → dict with ISA, golden, tensors compile_and_run(model, build_dir, ...) → compile + emulate + compare Verified: SmolLM2-135M 2-layer at 93.65% allclose.
…tention - Add alloc_at() to PlenaCompiler for VRAM view creation - Native mode: Q/O linear projections with K-split, per-head MHA (6 heads, 2 KV heads) - Legacy mode: sliced dims without projections (backward compatible) - _fix_large_immediates post-processor for S_ADDI_INT overflow at native dims - K-split support in _linear_projection for hidden > 4*mlen Verified: clm-60m native dims (hidden=384, inter=1408) at 92.68% allclose.
- HIGH-2: assert out_features % mlen == 0 in _linear_projection - HIGH-3: assert cols % block_dim == 0 in _weight_hbm_bytes - HIGH-4: fix VRAM padding ceiling division - LOW-2: fix docstring import path (model_compiler → plena_frontend)
True transformer forward pass: - K/V computed on-chip via linear projection, stored to HBM - RoPE via rotate_half matrix: Q_rot = linear(Q, R) - Applied to both Q and K per head - 37K ISA lines at native dims (clm-60m hidden=384) Verified: native compile OK, legacy path 100% allclose.
- Real embedding lookup via embed_tokens(input_ids) replacing random input - Zero pos_weight for Llama (RoPE handles positioning) - Optional lm_head projection (include_lm_head=True, default False) - Tied-weight detection for shared embed/lm_head - Multi-layer verified at 2 layers native dims (163K ISA lines) Full pipeline: embed → [norm → Q/K/V proj → RoPE → MHA → O proj → res → norm → FFN → res] × N → norm → [lm_head] Verified: 92.65% allclose at native dims (hidden=384, 1 layer, with RoPE).
- Causal mask: upper-triangular -inf added to QK^T scores before softmax - Real embedding lookup via embed_tokens(input_ids) - Optional lm_head projection (include_lm_head flag) - Three-way comparison: HF float32 vs golden (MXFP8+BF16) vs emulator - pos_weight=zeros for Llama (RoPE handles positioning) Verified: 99.93% allclose with causal mask (sliced dims). HF vs golden at native dims: 100% allclose, cosine=0.9995.
The compiler/utils/ package shadows tools/utils/ when both are on sys.path, causing ImportError for load_toml_config which only existed in tools/utils/load_config.py. Add the function here so imports resolve regardless of PYTHONPATH ordering.
Add golden_precision parameter to compile_hf_model with four modes: - hardware: MXFP8 weights + BF16 intermediates (default, matches HW) - no_weight_quant: fp32 weights + BF16 intermediates - no_bf16: MXFP8 weights + fp32 intermediates - fp32: float32 everything Ablation on 5-layer clm-60m proves MXFP8 weight quantization accounts for 100% of the HF-vs-golden gap (51.95% → 98.64% when removed). BF16 intermediate precision contributes 0%. Add test_quantization_ablation.py with @pytest.mark.slow.
Reads emulator's actual VRAM intermediates at each stage boundary, uses them as golden input for the next stage. Validates each stage independently, proving emulator correctness regardless of golden chain precision drift. SmolVLM2 result: 100% allclose, MSE=1.45e-05.
11 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
aten/plena/modules for model/program state, memory/register handling, program-level operations, and low-level ISA emission.PlenaCompiler-> produce ISA text.PlenaCompilermethods.aten/plena_compiler.pytile_compilercompatibility aliasasm_templates/flash_attn_asm.pybuild_sim_envalias in favor ofcreate_mem_for_simdocs/ATEN_TREE.mdand refreshes compiler architecture/pipeline docs for the current package layout.Validation
git diff --checkasm_templates/tests/test_vram_sub_projection.py- 5 passedaten/tests/test_plena_compiler.py- 13 passedasm_templates/tests/test_large_immediate.py- 23 passed