Merge plena-compiler: graph-IR pipeline rewrite, mid_ir passes, SSB chain support #42

Open
gaoziqian123 wants to merge 21 commits into main from plena-compiler

@gaoziqian123
Collaborator

Summary

Merges the plena-compiler branch into main. This brings the compiler from the legacy graph-layer pipeline to the new mid_ir-based pipeline, adds SPMD groundwork, and lands SSB (SingleStreamBlock) chained-kernel support.

103 files changed, ~37k insertions.

Highlights

  • Pipeline rewrite — migrate frontend to all-graph-layer pipeline, then drop the graph layer in favor of a mid_ir pipeline with cluster_dim; add BTMV / MV / vram-vram-copy dispatch.
  • Buffer layout — unify VRAM/MRAM ≥2D buffers to BSHD; consolidate row_*_at addressing; pinned globals are row-major-flat; row_stack lane-stride fix.
  • gemm schema — region + dim_roles schema for gemm ops.
  • SPMD groundwork: SPMD_REWRITE.md design doc + step 1 classify_lane_use pass with unit tests (replacing the 4 lane-fusion graph passes).
  • SSB chain support: concat_min kernel, head-layout helpers, mid_ir + kernel fixes.
  • register_alloc: spill_borrow now also filters pinned GPs.
  • rope_min — v↔fp transfer treats cluster phase as 0; multi-lane wrap.
  • Cleanup — remove legacy non-min kernel demos; refresh PIPELINE_ARCHITECTURE / MIGRATION_PLAN / AI_AGENT_GUIDE docs.
  • Tests — extensive new mid_ir pass test suite (fold / fuse / split / mark / view / infer_lane_axis / distribute_cluster / to_plena), narrow-mm emitter, tiled BTMM, online-softmax, reference kernels.

Test plan

  • mid_ir pass unit tests under tilelang_tvm_compiler/tests/ (see diff).
  • SSB staged precision diagnostics validated separately in the simulator testbench.

🤖 Generated with Claude Code

gaoziqian123 and others added 18 commits February 8, 2026 05:51
Mirrors preload_act_asm logic in reverse direction, using stride mode
for hardware-assisted format conversion between VRAM block layout and
HBM row-major layout.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Make PLENA_Compiler match the active PLENA_Simulator/compiler state.

tilelang_tvm_compiler:
- add frontend/ pipeline (allocate_group_memory / annotate_gemm_kind /
  annotate_group / annotate_sync / forbid_plena_extern / fuse_elementwise
  / inline_let_stmts / lower_compound_fp_stores / lower_fp_row_patterns /
  lower_to_hlir / scope_inference / split_lane_groups + gemm_macros +
  pipeline) and frontend_legacy/ snapshot
- add kernels: tiled_conv2d, flash_decode_min, mm64, qk_btmm, rope_min
- update kernels: flash_attention_min, online_softmax_min
- remove deprecated kernels: fpram_smoke, row_mask_smoke, tiled_mm
  (and test_fpram_ops)
- update core: __init__, __main__, codegen, hlir, intrinsics,
  isa_emitter, isa_pass
- add PIPELINE_ARCHITECTURE.md and doc/AI_AGENT_GUIDE.md
- add frontend tests + test_matmul_emitter, test_reference_kernels;
  refresh test_expr_materializer, test_online_softmax_min

assembler/doc/runtime:
- update assembler/{assembly_to_binary,parser}.py
- update doc/operation.svh, doc/plena_isa_spec.md
- update tilelang_runtime_compier _isa_emitter

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the legacy stmt-walker frontend chain with a graph-IR-centric
pipeline. Programs lift once into a Graph (graph_ir.Graph), passes
operate on the graph, and a single materialize step generates final
TIR with plena.* externs.

New infrastructure (frontend/passes/):
  - graph_ir.py: Graph / GraphNode / BufferNode / BufferAccess /
    LaneGroup / NestedForGroup / ForRoot / NodeRoot / RawStmt
  - lift_from_raw.py: raw PrimFunc -> Graph (was lift_to_blocks +
    lift_to_graph two-step)
  - graph_walker.py: shared traversal helpers
  - graph_pipeline.materialize_to_primfunc with expand_lane_buffers=True
  - graph_passes/ subpackage:
    - annotate_grid (was stmt annotate_group)
    - annotate_sync (graph-layer rewrite)
    - split_lane_groups (with inlined _StmtVarSubst)
    - lift_lane_groups (ForRoot -> LaneGroup upgrade)
    - fuse_elementwise (T.Parallel -> plena.v_*)
    - scope_inference (owns BufferScopeMap / ScopeInferenceError)
    - allocate_group_memory.analyze (sets ATTR_LANE_LAYOUT)
    - expand_buffers.expand (rebuilds tir.Buffer + rewrites indices)
    - lower_fp_row_patterns (fp_*_at / row_*_at)

Deleted (replaced by graph_passes/ counterparts):
  - frontend/passes/{annotate_group,annotate_sync,annotate_gemm_kind,
    split_lane_groups,fuse_elementwise,scope_inference,
    allocate_group_memory,lower_fp_row_patterns,lift_to_blocks,
    lift_to_graph}.py
  - frontend_legacy/ (entire orphan tree)
  - 6 stmt-walker test files

frontend/pipeline.py: rewritten to a single graph path; no fallback
flag, no env var.

Bug fix: fuse_elementwise now sets ATTR_IS_SYNC=True on newly created
plena.zero_v / plena.v_add / plena.v_sub / plena.v_mul GraphNodes.
Without this, the materialize-time partitioner emits these
INHERENTLY_SYNC_EXTERNS inside the per-lane for-by, causing
flash_attention_min to compute O *= lane_count (numerically off by 4x).
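
The failure mode can be reproduced in a few lines of plain Python. This is an illustrative sketch only: `tile_add`, `LANE_COUNT`, and the list-based buffers are hypothetical stand-ins, not the compiler's IR or API. A whole-tile op already covers every lane in one issue, so emitting it inside the per-lane loop repeats the accumulation once per lane.

```python
# Illustrative sketch only: every name here is hypothetical.
LANE_COUNT = 4  # assumed lane count, matching the 4x error described above


def tile_add(acc, x):
    """Whole-tile add: a single issue updates every lane of acc in place."""
    for lane in range(len(acc)):
        acc[lane] += x[lane]


def run(sync):
    acc = [0.0] * LANE_COUNT
    x = [1.0, 2.0, 3.0, 4.0]
    if sync:
        # Marked sync: kept outside the per-lane loop and issued once.
        tile_add(acc, x)
    else:
        # Not marked sync: wrongly re-issued once per lane.
        for _by in range(LANE_COUNT):
            tile_add(acc, x)
    return acc


correct = run(sync=True)
buggy = run(sync=False)
```

With these inputs, `buggy` is `correct` scaled by `LANE_COUNT`, which mirrors the off-by-4x result reported for flash_attention_min.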

All 101 frontend tests pass. flash_attention_min e2e numerics now
match golden.

MIGRATION_PLAN.md added with a status writeup.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… for graph-IR pipeline

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ssing

* expand_buffers: ROW_STACK now emits BSHD (B=lane, S=rows, H=1, D=mlen)
  instead of BHSD-shaped (1, lane, rows, mlen) — lane lives in the B axis
  so BTMM-WO output and BSHD attention buffers share one 7D physical
  layout family. Added a BSHD_LIFT mode that promotes the remaining 2D
  VRAM/MRAM allocs (those not touched by lane-fusion) to 4D BSHD so
  downstream passes only see one shape rank. global.* buffers are
  intentionally skipped (preserve user-chosen 2D semantic).

* lower_fp_row_patterns: row_*_at calls now always emit (row, head) as
  the trailing scalar pair, independent of the buffer's lane-pack mode.
  _row_dims_from_indices and _try_lower_reduce infer the lane axis from
  the post-expand 4D BSHD shape and pick the matching index slot.

* isa_pass._resolve_row_at_coords rewritten to dispatch on the 4D BSHD
  shape pattern (COL_PACK / ROW_STACK / wide-D / single-tile) and
  translate (row, head) into a physical VRAM mlen-row + optional V_MASK.
  No more shape-rank or (dim2, dim3) order branches at call sites.

* intrinsics: rename row_*_at scalar args dim2/dim3 → row/head and add
  a docstring describing the layout-agnostic semantic.

* Fixes that fell out:
  - isa_emitter.emit_tile_binary: V_SUB_VV emits (dst, lhs, rhs) — the
    earlier (dst, rhs, lhs) reversal contradicted the simulator's
    vd = vs1 - vs2 semantics.
  - lower_to_hlir._flatten_starts_tiled: B's own stride is inner_s (one
    inner tile), not inner_b (the B-axis total volume). The old formula
    accidentally worked only because every existing kernel had B==1.
  - isa_pass._emit_dma_h2v_slice_multi_tile: same B-stride fix; also
    extended to accept dynamic slice starts (dyn base reg + static
    per-tile residual), matching the single-tile fast path.
  - isa_emitter._emit_preload_tile_isa / _emit_store_tile_isa /
    emit_hbm_tile_to_mram: when hbm_start_offset_reg is provided, fold
    hbm_start_offset into the S_ADDI_INT as a static residual instead
    of overwriting it.

* graph_pipeline: thread the scope map into expand_buffers so the
  BSHD_LIFT pass can pick out VRAM/MRAM buffers before their declared
  scope gets rewritten from shared.dyn / local.fragment.

Verified compile end-to-end: conv2d_min, flash_attention_min,
flash_decode_min, rope_min. flash_attention's S_loc is now 4x64x1x64
BSHD with all row_*_at calls in consistent (row, by_i) order; emitted
ASM offsets match the BMM_WO writeback formula (j*mlen + i)*mlen.
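
As a sanity check on the formula, the sketch below compares `(j*mlen + i)*mlen` against a plain row-major offset into the 4x64x1x64 BSHD shape. It assumes `mlen = 64` (inferred from the D extent), that `j` indexes the B (lane) axis, and that `i` indexes the S (row) axis; the helper names are hypothetical and not part of the compiler.

```python
MLEN = 64               # assumed from the 64-wide D axis of S_loc
SHAPE = (4, 64, 1, 64)  # S_loc as (B, S, H, D), from the text above


def bmm_wo_offset(j, i, mlen=MLEN):
    """Writeback formula quoted above: (j*mlen + i)*mlen."""
    return (j * mlen + i) * mlen


def row_major_offset(index, shape):
    """Flat row-major offset of a multi-dimensional index."""
    offset = 0
    for idx, extent in zip(index, shape):
        offset = offset * extent + idx
    return offset


# Assumption: j indexes B (lane) and i indexes S (row); H and D start at 0.
matches = all(
    bmm_wo_offset(j, i) == row_major_offset((j, i, 0, 0), SHAPE)
    for j in range(SHAPE[0])
    for i in range(SHAPE[1])
)
```

Under those assumptions the two agree for every `(j, i)`, because with `H = 1` and `S = D = mlen` the lane stride is `mlen * mlen` and the row stride is `mlen`.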

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop demo/exploration kernels (loop_dma, loop_slice_dma, minimal_btmm,
mm64, qk_btmm, static_slice_dma, tiled_btmm, tiled_conv2d) that were
holdovers from the early bring-up phase. The supported kernel surface
going forward is the *_min family.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Plan replaces split_lane_groups / lift_lane_groups /
allocate_group_memory / expand_buffers with three small early TIR
passes that resolve lane fusion before lift_from_raw_primfunc:

  classify_lane_use  — tag each buffer with its lane-fusion role
                       from op annotations + `by` use
  expand_lane_grid   — for tagged buffers, add a LANE outer dim and
                       wrap per-lane work in a serial loop
  infer_lane_layout  — pick where the lane axis sits per buffer
                       (BSHD vs BHSD) and rewrite shape + indices

Net change: −1500 / +600 lines, 4 graph passes deleted, 2 simplified.
Buffer model stays vanilla 3D TIR — no new macros, no *_multi op
kinds, no contiguous-backing tricks. The IR graph_passes see is free
of lane-fusion concepts.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Implements the first of three new TIR passes from SPMD_REWRITE.md.
Walks a raw PrimFunc (post inline_let_stmts + lower_compound_fp_stores,
pre lift_from_raw_primfunc), looks at each tl.tileop.gemm_py / copy
site + the surrounding plena.gemm_kind AttrStmt + the kernel's
plena.lane_axis func attr, and assigns one of:

  btmm_lhs / btmm_rhs / btmm_out
  per_head_lhs / per_head_rhs / per_head_out
  lane_dma_dst
  none

per buffer. Layout-compatible re-tags (e.g. lane_dma_dst then btmm_lhs,
both COL_PACK) are silently merged; structurally-incompatible re-tags
raise ClassifyLaneUseError.

The pass is read-only — it returns the original PrimFunc plus a
{buffer_name: BufferRole} dict that the next two passes
(expand_lane_grid, infer_lane_layout) will consume. No IR rewriting
happens here.
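
A minimal sketch of the merge-or-raise rule, under stated assumptions: the input is simplified to `(buffer_name, role)` pairs rather than a PrimFunc, the role-to-layout table contains only the two COL_PACK entries named above plus one assumed ROW_STACK entry to demonstrate a conflict, and a compatible re-tag keeps the latest role (an arbitrary choice here). Only the `ClassifyLaneUseError` name and the role strings come from the commit message.

```python
class ClassifyLaneUseError(Exception):
    """Raised when a buffer is re-tagged with an incompatible role."""


# Only the COL_PACK entries come from the commit message; the btmm_out
# entry is an assumption added to demonstrate a conflict.
ROLE_LAYOUT = {
    "lane_dma_dst": "COL_PACK",
    "btmm_lhs": "COL_PACK",
    "btmm_out": "ROW_STACK",  # assumed
}


def classify(sites):
    """sites: iterable of (buffer_name, role) pairs in program order."""
    roles = {}
    for name, role in sites:
        prev = roles.get(name, "none")
        if role == "none":
            roles.setdefault(name, "none")
        elif prev in ("none", role):
            roles[name] = role
        elif (
            prev in ROLE_LAYOUT
            and role in ROLE_LAYOUT
            and ROLE_LAYOUT[prev] == ROLE_LAYOUT[role]
        ):
            roles[name] = role  # layout-compatible: merge silently
        else:
            raise ClassifyLaneUseError(f"{name}: {prev} vs {role}")
    return roles


merged = classify([("Q_sh", "lane_dma_dst"), ("Q_sh", "btmm_lhs")])
try:
    classify([("S_loc", "btmm_lhs"), ("S_loc", "btmm_out")])
    conflict_raised = False
except ClassifyLaneUseError:
    conflict_raised = True
```

The function returns only the role dict and never mutates its input, which matches the read-only contract described above.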

Tests build raw TIR by hand using tir.call_extern (no tilelang
dependency), exercise the full flash_attention_min op set, the
no-btmm-attr fallback, and the no-lane-axis defensive case.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…-copy dispatch

- delete legacy graph-IR pipeline (graph_passes/, graph_pipeline,
  graph_walker, graph_ir, lift_from_raw, lower_to_hlir,
  classify_lane_use, expand_lane_grid, infer_lane_layout,
  fuse_elementwise, forbid_plena_extern) + their tests
- frontend/pipeline.py becomes a stub that raises on compile_func
- rename plena.v_* / plena.zero_v → plena.tile_* (whole-tile intent;
  row_*_op stays for single-row instructions)
- unify row_*_op to one HLIR op = one HW instruction; multi-row
  callers wrap in HLIR for-row
- buffer.cluster_dim metadata: explicit lane-axis position carried
  through split → view → burn_view → to_plena
- _resolve_row_at_coords uses cluster_dim to compute head/row
  stride; no shape-value heuristics
- BTMV / M_MV dispatch on LHS rows==1 for decode
- T.copy(vram, vram) → copy_v_to_v (V_ADD_VF f0=0)
- T.copy(vram, fpram) / T.copy(fpram, vram) → row_load_v_to_fp /
  row_store_fp_to_v (S_MAP_*_FP/V)
- async marker pruning: per-lane FPRAM scalar ops no longer flagged
  async (only DMA / BTMM / tile_* survive)
- dead_buffer_elim pass strips unused buffers
- kernels (attention/decode/rope) no longer pre-lower; return raw
  PrimFunc for compile_kernel to drive
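
The copy and decode dispatch rules listed above can be sketched as a lookup. The scope pairs and op names are taken from the list; expressing the dispatch as a table, the function names, and the `"mv"` / `"mm"` return labels are assumptions for illustration, not the compiler's actual structure.

```python
# Scope pairs and op names come from the commit message above; the
# table-driven form of the dispatch is an assumption for illustration.
COPY_DISPATCH = {
    ("vram", "vram"): "copy_v_to_v",
    ("vram", "fpram"): "row_load_v_to_fp",
    ("fpram", "vram"): "row_store_fp_to_v",
}


def select_copy_op(src_scope, dst_scope):
    """Map a T.copy(src, dst) scope pair to the op it lowers to."""
    try:
        return COPY_DISPATCH[(src_scope, dst_scope)]
    except KeyError:
        raise NotImplementedError(
            f"no copy lowering for {src_scope} -> {dst_scope}"
        ) from None


def select_gemm_family(lhs_rows):
    """Decode path: an LHS with a single row takes the matrix-vector ops.

    The "mv" / "mm" labels are hypothetical placeholders.
    """
    return "mv" if lhs_rows == 1 else "mm"
```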

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
S_MAP_FP_V / S_MAP_V_FP transfers VLEN=MLEN contiguous fp slots in
one issue spanning all cluster lanes natively — so under sync wrap
the cluster-phase index on both vram and fp sides must collapse to
0. Use buffer.cluster_dim to locate the phase axis on the FP ref
and zero it. Add row_load_v_to_fp / row_store_fp_to_v to the
multi-lane op set so they don't get a synthetic ``for by_phase``
re-issuing the same instruction 4×.
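
The phase-collapse step amounts to zeroing one index position. The sketch below assumes indices are a plain tuple and that `cluster_dim` is either an axis position or `None`; `collapse_cluster_phase` is a hypothetical helper name, not a function from the codebase.

```python
def collapse_cluster_phase(indices, cluster_dim):
    """Return indices with the cluster-phase axis forced to 0.

    cluster_dim is the position of the lane/phase axis, or None when the
    buffer has no cluster axis (indices are then returned unchanged).
    """
    if cluster_dim is None:
        return tuple(indices)
    out = list(indices)
    out[cluster_dim] = 0
    return tuple(out)
```

For example, `collapse_cluster_phase((3, 5, 0, 7), 0)` gives `(0, 5, 0, 7)`: the single issue already spans all cluster lanes, so the per-lane phase no longer selects anything.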

fold: ``_wrap_src`` allows affine-offset srcs (independent indices)
when the dst is an FPRAM scalar slot, so compound-store patterns
like ``OUT[2*i] = X[2*i]*C[2*i] + X[2*i+1]*NS[2*i]`` lower
cleanly. RawStore fallback was removed; fold now errors if it
can't recognise a store.
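
For reference, the compound-store pattern quoted above computes the following in plain Python. This shows only the intended arithmetic, with `n_pairs` as a hypothetical loop bound; it says nothing about how fold represents or lowers the store.

```python
def compound_store(x, c, ns, n_pairs):
    """Reference semantics of OUT[2*i] = X[2*i]*C[2*i] + X[2*i+1]*NS[2*i].

    Only the even slots of the output are written; odd slots stay 0.0.
    """
    out = [0.0] * (2 * n_pairs)
    for i in range(n_pairs):
        out[2 * i] = x[2 * i] * c[2 * i] + x[2 * i + 1] * ns[2 * i]
    return out


example = compound_store(
    x=[1.0, 2.0, 3.0, 4.0],
    c=[10.0, 0.0, 20.0, 0.0],
    ns=[100.0, 0.0, 200.0, 0.0],
    n_pairs=2,
)
```

Note that the second product reads `X` at `2*i+1` while `NS` is read at `2*i`, so the sources are indexed independently of each other and of the destination.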

rope_min: write Q_OUT via explicit slice form to match input copies
so dma_v2h_slice sees the full rows-length s-dim.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Bring spill_borrow's victim selection in line with _auto_spill —
pinned GPs (loop hw counters, long-lived symbol-table bindings) must
never be picked as spill candidates, regardless of which path needs
the spill. Without this, a serial for-loop body that triggers
spill_borrow under register pressure could silently displace gp_loop
to IntRAM and reuse the same physical register inside the borrow
scope, corrupting C_LOOP_END's counter read.
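
A minimal sketch of victim selection with the pinned filter applied. The farthest-next-use heuristic, the function name, and the data shapes are assumptions for illustration; the only behavior taken from the commit message is that pinned GPs are excluded before a victim is chosen.

```python
def pick_spill_victim(in_use, pinned, next_use):
    """Pick the GP to spill: the unpinned register whose next use is farthest.

    in_use:   iterable of register ids currently holding live values
    pinned:   set of register ids that must never be spilled
    next_use: dict mapping register id -> position of its next use
    """
    candidates = [r for r in in_use if r not in pinned]
    if not candidates:
        raise RuntimeError("no spillable GP: every live register is pinned")
    return max(candidates, key=lambda r: next_use.get(r, float("inf")))


# Register 3 plays the role of a pinned loop counter with a distant next use.
victim = pick_spill_victim(
    in_use=[1, 2, 3],
    pinned={3},
    next_use={1: 5, 2: 9, 3: 100},
)
```

Without the filter, the distant next use of register 3 would make it the preferred victim, which is the situation that displaces the loop counter.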

Also rolls in pending in-progress work across the mid_ir / isa_pass
stack and flips flash_attention_min's T.unroll loops to T.serial to
exercise the pinned-GP path end-to-end.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…k lane stride fix

* Einstein-summation-style gemm schema: region (4D BSHD start+extent) per
  operand + per-axis M/K/N dim_role labels. matmul / mv / btmm / btmv all
  read M/K/N positions from the role tables to drive instruction
  selection and extent lookup. Drops the old explicit
  M_tiles/K_tiles/N/transpose_b scalar args.

* MramRegion added as a twin of VramRegion (two distinct dataclasses);
  dead_buffer_elim and to_plena lowering recognise both.

* Author-pinned global.vram / global.mram tensor caches (Q_cache,
  O_cache) now carry is_pinned_global=True. AddressAllocationPass skips
  make_tile_layout for them, and the offset-walking iterators
  (_vram_region_iter_chunks, _region_origin_offset) compute addresses
  as flat row-major instead of 7D mlen-tile-padded — matches how the
  testbench actually loads these buffers.

* global.vram / global.mram also get pad-to-4D, but with heads-at-H
  rule ((1,1,a,b)) instead of the default (1,a,1,b) — matches the
  head-major (head_count, hlen) layout kernels use these caches for.
  cluster_dim stays None on globals so sync-wrap iterators don't fold
  the head axis.

* copy_v_to_v handles cluster-asymmetric src/dst: each side emits its
  own region at native rank using ref indices; _rewrite_refs_to_4d
  lifts both per their own pad_inserts / cluster_modes entry.

* M_MM_WO physical row stride fix: c_orow_step = blen * mlen (not
  blen * dst_row_stride). M_MM_WO writes blen rows at physical pitch
  mlen regardless of how dense the dst's logical N maps inside each
  mlen-row.

* row_stack lane stride fix: when cluster_dim==0 (lane on B axis,
  lane_count==1), lane_stride = product(buf.shape[1:]) — matches the
  M_BMM_WO / M_BMV_WO hardware writeback (lane j -> base + j *
  per_lane_elems). For S=mlen this coincides with the old b_stride =
  mlen*inner_lane; for S<mlen (flash_decode rows=1) the old value
  overshot and lane 1+ wrote past the buffer.

* unroll_loops mode on matmul emitter (avoid sim
  MAX_LOOP_INSTRUCTIONS).
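
The row_stack lane-stride fix above can be checked numerically. This sketch assumes `mlen = 64`, assumes `inner_lane` in the old formula means the product of the H and D extents, and uses two illustrative BSHD shapes with four lanes; the function names are hypothetical.

```python
from math import prod

MLEN = 64  # assumed tile width


def lane_stride_new(shape):
    """Lane stride with the lane on the B axis: product(buf.shape[1:])."""
    return prod(shape[1:])


def lane_stride_old(shape, mlen=MLEN):
    """Old b_stride = mlen * inner_lane; inner_lane is assumed to be H*D."""
    return mlen * prod(shape[2:])


attention = (4, 64, 1, 64)  # S == mlen
decode = (4, 1, 1, 64)      # S == 1, as in flash_decode rows=1

# Offset of lane 1 under the old stride versus the total buffer size.
decode_elems = prod(decode)
old_lane1_offset = 1 * lane_stride_old(decode)
overshoots = old_lane1_offset >= decode_elems
```

For the `S == mlen` shape both strides are 4096, so the old value happened to be right. For the `S == 1` shape the new stride is 64 while the old one stays at 4096, which already exceeds the 256-element buffer at lane 1.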

New debug kernels: flash_attention_gemm_only, flash_decode_min_gemm_only
to bisect gemm schema in isolation. Plus layernorm_min / linear_min /
linear_min_no_transpose / modulate_min / residual_gate_min / rmsnorm_min
/ silu_min / gelu_min staging kernels.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ernel fixes

- concat_min: feature-axis concat of two head-packed tensors
- _head_layout: BSHD <-> B,S,1,H*D view helpers
- copy_offset_min: o_head_offset probe kernel
- hoist_float_constants pass; mid_ir fold/fuse/to_plena fixes
- kernel updates across flash_attention/gelu/layernorm/linear/rmsnorm/silu

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copilot AI left a comment

Copilot wasn't able to review this pull request because it exceeds the maximum number of lines (20,000). Try reducing the number of changed lines and requesting a review from Copilot again.

gaoziqian123 and others added 2 commits May 19, 2026 13:21
…n kernel, plena_settings

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ocation

- New MIR layer (mir.py, mir_passes.py) + PreIsaIR v2 (pre_isa_ir_v2.py,
  pre_isa_pass_v2.py, pre_isa_to_mir.py) + ISA emit (mir_to_isa.py).
- Register allocation: structured live intervals + scope-recursive spill.
  Loop-carried values are stored to IntRAM at scope entry (outside the
  loop body) and reloaded per use inside; the store never lands in a body
  that loops, fixing cross-iteration spill corruption.
- C_SET_ADDR_REG emits 3 operands (aN, gp0, gp{addr}) matching HW.
- FORCE_SERIAL_LOOPS: all loops lower to hardware C_LOOP (emit-time
  unroll removed as unsound).
- Add v2 test suite (test_v2_*, test_pre_isa_*, test_mir_passes); drop
  superseded legacy emitter tests.
- doc/simulator_cost_model.md, REGALLOC design notes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

2 participants