Skip to content

refactor: slim PLENA compiler frontend and codegen#38

Merged
booth-algo merged 32 commits into
mainfrom
feat/codegen-addr-reg-init
May 13, 2026
Merged

refactor: slim PLENA compiler frontend and codegen#38
booth-algo merged 32 commits into
mainfrom
feat/codegen-addr-reg-init

Conversation

@booth-algo
Copy link
Copy Markdown
Collaborator

@booth-algo booth-algo commented May 11, 2026

Summary

  • Splits the former monolithic ATen PLENA compiler into focused aten/plena/ modules for model/program state, memory/register handling, program-level operations, and low-level ISA emission.
  • Keeps the frontend shape simple: parse HF/ATen model inputs -> call the ATen ops dispatcher -> emit through PlenaCompiler -> produce ISA text.
  • Preserves the ATen dispatcher surface by routing PLENA op implementations through thin wrappers over PlenaCompiler methods.
  • Removes stale compatibility layers and aliases:
    • removed aten/plena_compiler.py
    • removed tile_compiler compatibility alias
    • removed asm_templates/flash_attn_asm.py
    • removed build_sim_env alias in favor of create_mem_for_sim
  • Adds docs/ATEN_TREE.md and refreshes compiler architecture/pipeline docs for the current package layout.

Validation

  • git diff --check
  • tracked-file audit for stale compatibility names/imports
  • compile checks for touched Python files
  • focused ruff checks for touched compiler/test files
  • asm_templates/tests/test_vram_sub_projection.py - 5 passed
  • aten/tests/test_plena_compiler.py - 13 passed
  • asm_templates/tests/test_large_immediate.py - 23 passed

booth-algo added 30 commits May 11, 2026 14:16
Add model_compiler.py that walks an HF nn.Module tree and compiles to
PLENA ISA via PlenaCompiler with pre-norm + residual connections.

Two public functions:
  compile_hf_model(model, ...) → dict with ISA, golden, tensors
  compile_and_run(model, build_dir, ...) → compile + emulate + compare

Verified: SmolLM2-135M 2-layer at 93.65% allclose.
…tention

- Add alloc_at() to PlenaCompiler for VRAM view creation
- Native mode: Q/O linear projections with K-split, per-head MHA (6 heads, 2 KV heads)
- Legacy mode: sliced dims without projections (backward compatible)
- _fix_large_immediates post-processor for S_ADDI_INT overflow at native dims
- K-split support in _linear_projection for hidden > 4*mlen

Verified: clm-60m native dims (hidden=384, inter=1408) at 92.68% allclose.
- HIGH-2: assert out_features % mlen == 0 in _linear_projection
- HIGH-3: assert cols % block_dim == 0 in _weight_hbm_bytes
- HIGH-4: fix VRAM padding ceiling division
- LOW-2: fix docstring import path (model_compiler → plena_frontend)
True transformer forward pass:
- K/V computed on-chip via linear projection, stored to HBM
- RoPE via rotate_half matrix: Q_rot = linear(Q, R)
- Applied to both Q and K per head
- 37K ISA lines at native dims (clm-60m hidden=384)

Verified: native compile OK, legacy path 100% allclose.
- Real embedding lookup via embed_tokens(input_ids) replacing random input
- Zero pos_weight for Llama (RoPE handles positioning)
- Optional lm_head projection (include_lm_head=True, default False)
- Tied-weight detection for shared embed/lm_head
- Multi-layer verified at 2 layers native dims (163K ISA lines)

Full pipeline: embed → [norm → Q/K/V proj → RoPE → MHA → O proj → res → norm → FFN → res] × N → norm → [lm_head]
Verified: 92.65% allclose at native dims (hidden=384, 1 layer, with RoPE).
- Causal mask: upper-triangular -inf added to QK^T scores before softmax
- Real embedding lookup via embed_tokens(input_ids)
- Optional lm_head projection (include_lm_head flag)
- Three-way comparison: HF float32 vs golden (MXFP8+BF16) vs emulator
- pos_weight=zeros for Llama (RoPE handles positioning)

Verified: 99.93% allclose with causal mask (sliced dims).
HF vs golden at native dims: 100% allclose, cosine=0.9995.
The compiler/utils/ package shadows tools/utils/ when both are on
sys.path, causing ImportError for load_toml_config which only existed
in tools/utils/load_config.py. Add the function here so imports resolve
regardless of PYTHONPATH ordering.
Add golden_precision parameter to compile_hf_model with four modes:
- hardware: MXFP8 weights + BF16 intermediates (default, matches HW)
- no_weight_quant: fp32 weights + BF16 intermediates
- no_bf16: MXFP8 weights + fp32 intermediates
- fp32: float32 everything

Ablation on 5-layer clm-60m proves MXFP8 weight quantization accounts
for 100% of the HF-vs-golden gap (51.95% → 98.64% when removed).
BF16 intermediate precision contributes 0%.

Add test_quantization_ablation.py with @pytest.mark.slow.
Reads emulator's actual VRAM intermediates at each stage boundary,
uses them as golden input for the next stage. Validates each stage
independently, proving emulator correctness regardless of golden
chain precision drift.

SmolVLM2 result: 100% allclose, MSE=1.45e-05.
@booth-algo booth-algo changed the title refactor: split PLENA compiler DSL and ISA layers refactor: slim PLENA compiler frontend and codegen May 12, 2026
@booth-algo booth-algo merged commit d656186 into main May 13, 2026
3 checks passed
@booth-algo booth-algo deleted the feat/codegen-addr-reg-init branch May 13, 2026 00:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant