Merge plena-compiler: graph-IR pipeline rewrite, mid_ir passes, SSB chain support #42
Open
gaoziqian123 wants to merge 21 commits into
Mirrors preload_act_asm logic in reverse direction, using stride mode for hardware-assisted format conversion between VRAM block layout and HBM row-major layout.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Make PLENA_Compiler match the active PLENA_Simulator/compiler state.
tilelang_tvm_compiler:
- add frontend/ pipeline (allocate_group_memory / annotate_gemm_kind /
annotate_group / annotate_sync / forbid_plena_extern / fuse_elementwise
/ inline_let_stmts / lower_compound_fp_stores / lower_fp_row_patterns /
lower_to_hlir / scope_inference / split_lane_groups + gemm_macros +
pipeline) and frontend_legacy/ snapshot
- add kernels: tiled_conv2d, flash_decode_min, mm64, qk_btmm, rope_min
- update kernels: flash_attention_min, online_softmax_min
- remove deprecated kernels: fpram_smoke, row_mask_smoke, tiled_mm
(and test_fpram_ops)
- update core: __init__, __main__, codegen, hlir, intrinsics,
isa_emitter, isa_pass
- add PIPELINE_ARCHITECTURE.md and doc/AI_AGENT_GUIDE.md
- add frontend tests + test_matmul_emitter, test_reference_kernels;
refresh test_expr_materializer, test_online_softmax_min
assembler/doc/runtime:
- update assembler/{assembly_to_binary,parser}.py
- update doc/operation.svh, doc/plena_isa_spec.md
- update tilelang_runtime_compier _isa_emitter
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the legacy stmt-walker frontend chain with a graph-IR-centric
pipeline. Programs lift once into a Graph (graph_ir.Graph), passes
operate on the graph, and a single materialize step generates final
TIR with plena.* externs.
New infrastructure (frontend/passes/):
- graph_ir.py: Graph / GraphNode / BufferNode / BufferAccess /
LaneGroup / NestedForGroup / ForRoot / NodeRoot / RawStmt
- lift_from_raw.py: raw PrimFunc -> Graph (was lift_to_blocks +
lift_to_graph two-step)
- graph_walker.py: shared traversal helpers
- graph_pipeline.materialize_to_primfunc with expand_lane_buffers=True
- graph_passes/ subpackage:
- annotate_grid (was stmt annotate_group)
- annotate_sync (graph-layer rewrite)
- split_lane_groups (with inlined _StmtVarSubst)
- lift_lane_groups (ForRoot -> LaneGroup upgrade)
- fuse_elementwise (T.Parallel -> plena.v_*)
- scope_inference (owns BufferScopeMap / ScopeInferenceError)
- allocate_group_memory.analyze (sets ATTR_LANE_LAYOUT)
- expand_buffers.expand (rebuilds tir.Buffer + rewrites indices)
- lower_fp_row_patterns (fp_*_at / row_*_at)
Deleted (replaced by graph_passes/ counterparts):
- frontend/passes/{annotate_group,annotate_sync,annotate_gemm_kind,
split_lane_groups,fuse_elementwise,scope_inference,
allocate_group_memory,lower_fp_row_patterns,lift_to_blocks,
lift_to_graph}.py
- frontend_legacy/ (entire orphan tree)
- 6 stmt-walker test files
frontend/pipeline.py: rewritten to a single graph path; no fallback
flag, no env var.
Bug fix: fuse_elementwise now sets ATTR_IS_SYNC=True on newly created
plena.zero_v / plena.v_add / plena.v_sub / plena.v_mul GraphNodes.
Without this, the materialize-time partitioner emits these
INHERENTLY_SYNC_EXTERNS inside the per-lane for-by, causing
flash_attention_min to compute O *= lane_count (numerically off by 4x).
All 101 frontend tests pass. flash_attention_min e2e numerics now
match golden.
MIGRATION_PLAN.md added with a status writeup.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
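The 4x error described in the bug fix above can be reproduced with a toy model. This is an illustrative sketch only: `tile_add` and `run` are hypothetical stand-ins rather than compiler APIs, and it assumes a whole-tile op already covers every lane in a single issue and that `lane_count` is 4.

```python
LANE_COUNT = 4  # assumption: matches the 4x error reported in the commit message


def tile_add(dst, src):
    """Hypothetical whole-tile add: one issue updates every lane of dst."""
    for lane in range(len(dst)):
        dst[lane] += src[lane]


def run(is_sync):
    """Accumulate one update, with or without the sync marking."""
    acc = [0.0] * LANE_COUNT
    update = [1.0] * LANE_COUNT
    if is_sync:
        # Marked sync: issued once, outside the per-lane loop.
        tile_add(acc, update)
    else:
        # Not marked sync: the op stays inside the per-lane loop, so the
        # same whole-tile op is issued once per lane.
        for _ in range(LANE_COUNT):
            tile_add(acc, update)
    return acc


print(run(is_sync=True))
print(run(is_sync=False))
```

With the sync marking every lane receives the update once; without it every lane receives it `LANE_COUNT` times.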
… for graph-IR pipeline
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ssing
* expand_buffers: ROW_STACK now emits BSHD (B=lane, S=rows, H=1, D=mlen)
instead of BHSD-shaped (1, lane, rows, mlen) — lane lives in the B axis
so BTMM-WO output and BSHD attention buffers share one 7D physical
layout family. Added a BSHD_LIFT mode that promotes the remaining 2D
VRAM/MRAM allocs (those not touched by lane-fusion) to 4D BSHD so
downstream passes only see one shape rank. global.* buffers are
intentionally skipped (preserve user-chosen 2D semantic).
* lower_fp_row_patterns: row_*_at calls now always emit (row, head) as
the trailing scalar pair, independent of the buffer's lane-pack mode.
_row_dims_from_indices and _try_lower_reduce infer the lane axis from
the post-expand 4D BSHD shape and pick the matching index slot.
* isa_pass._resolve_row_at_coords rewritten to dispatch on the 4D BSHD
shape pattern (COL_PACK / ROW_STACK / wide-D / single-tile) and
translate (row, head) into a physical VRAM mlen-row + optional V_MASK.
No more shape-rank or (dim2, dim3) order branches at call sites.
* intrinsics: rename row_*_at scalar args dim2/dim3 → row/head and add
a docstring describing the layout-agnostic semantic.
* Fixes that fell out:
- isa_emitter.emit_tile_binary: V_SUB_VV emits (dst, lhs, rhs) — the
earlier (dst, rhs, lhs) reversal contradicted the simulator's
vd = vs1 - vs2 semantics.
- lower_to_hlir._flatten_starts_tiled: B's own stride is inner_s (one
inner tile), not inner_b (the B-axis total volume). The old formula
accidentally worked only because every existing kernel had B==1.
- isa_pass._emit_dma_h2v_slice_multi_tile: same B-stride fix; also
extended to accept dynamic slice starts (dyn base reg + static
per-tile residual), matching the single-tile fast path.
- isa_emitter._emit_preload_tile_isa / _emit_store_tile_isa /
emit_hbm_tile_to_mram: when hbm_start_offset_reg is provided, fold
hbm_start_offset into the S_ADDI_INT as a static residual instead
of overwriting it.
* graph_pipeline: thread the scope map into expand_buffers so the
BSHD_LIFT pass can pick out VRAM/MRAM buffers before their declared
scope gets rewritten from shared.dyn / local.fragment.
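Why the old B-stride formula went unnoticed while every kernel had B==1 can be shown with a plain row-major model. This is a sketch under assumptions: it uses ordinary row-major strides rather than the compiler's tiled layout, and `row_major_strides`, `buggy_strides` and `flat_offset` are hypothetical names.

```python
from math import prod


def row_major_strides(shape):
    """Stride of axis k is the product of the extents after k."""
    return [prod(shape[k + 1:]) for k in range(len(shape))]


def buggy_strides(shape):
    """Like row_major_strides, but axis 0 uses the total volume
    (including its own extent), mimicking the old B-stride formula."""
    strides = row_major_strides(shape)
    strides[0] = prod(shape)
    return strides


def flat_offset(index, strides):
    """Flatten a multi-dimensional index with the given strides."""
    return sum(i * s for i, s in zip(index, strides))


# B == 1: the B index is always 0, so the wrong stride is never multiplied in.
shape_b1 = (1, 64, 1, 64)
idx_b1 = (0, 3, 0, 5)
print(flat_offset(idx_b1, row_major_strides(shape_b1)),
      flat_offset(idx_b1, buggy_strides(shape_b1)))

# B > 1: the two formulas diverge as soon as the B index is non-zero.
shape_b4 = (4, 64, 1, 64)
idx_b4 = (1, 0, 0, 0)
print(flat_offset(idx_b4, row_major_strides(shape_b4)),
      flat_offset(idx_b4, buggy_strides(shape_b4)))
```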
Verified compile end-to-end: conv2d_min, flash_attention_min,
flash_decode_min, rope_min. flash_attention's S_loc is now 4x64x1x64
BSHD with all row_*_at calls in consistent (row, by_i) order; emitted
ASM offsets match the BMM_WO writeback formula (j*mlen + i)*mlen.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop demo/exploration kernels (loop_dma, loop_slice_dma, minimal_btmm, mm64, qk_btmm, static_slice_dma, tiled_btmm, tiled_conv2d) that were holdovers from the early bring-up phase. The supported kernel surface going forward is the *_min family.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Plan replaces split_lane_groups / lift_lane_groups /
allocate_group_memory / expand_buffers with three small early TIR
passes that resolve lane fusion before lift_from_raw_primfunc:
classify_lane_use — tag each buffer with its lane-fusion role
from op annotations + `by` use
expand_lane_grid — for tagged buffers, add a LANE outer dim and
wrap per-lane work in a serial loop
infer_lane_layout — pick where the lane axis sits per buffer
(BSHD vs BHSD) and rewrite shape + indices
Net change: −1500 / +600 lines, 4 graph passes deleted, 2 simplified.
Buffer model stays vanilla 3D TIR — no new macros, no *_multi op
kinds, no contiguous-backing tricks. The IR graph_passes see is free
of lane-fusion concepts.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Implements the first of three new TIR passes from SPMD_REWRITE.md.
Walks a raw PrimFunc (post inline_let_stmts + lower_compound_fp_stores,
pre lift_from_raw_primfunc), looks at each tl.tileop.gemm_py / copy
site + the surrounding plena.gemm_kind AttrStmt + the kernel's
plena.lane_axis func attr, and assigns one of:
btmm_lhs / btmm_rhs / btmm_out
per_head_lhs / per_head_rhs / per_head_out
lane_dma_dst
none
per buffer. Layout-compatible re-tags (e.g. lane_dma_dst then btmm_lhs,
both COL_PACK) are silently merged; structurally-incompatible re-tags
raise ClassifyLaneUseError.
The pass is read-only — it returns the original PrimFunc plus a
{buffer_name: BufferRole} dict that the next two passes
(expand_lane_grid, infer_lane_layout) will consume. No IR rewriting
happens here.
Tests build raw TIR by hand using tir.call_extern (no tilelang
dependency), exercise the full flash_attention_min op set, the
no-btmm-attr fallback, and the no-lane-axis defensive case.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
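The re-tag merging rule described above can be sketched in a few lines. This is not the pass itself: roles are plain strings here, only the two COL_PACK entries in `ROLE_LAYOUT` come from the commit message, the ROW_STACK entry is an assumption, and "the later tag wins on a compatible re-tag" is also an assumption.

```python
class ClassifyLaneUseError(Exception):
    """Raised when a buffer is re-tagged with a structurally incompatible role."""


# Role -> layout. lane_dma_dst and btmm_lhs being COL_PACK is stated in the
# commit message; the btmm_out entry is an illustrative assumption.
ROLE_LAYOUT = {
    "lane_dma_dst": "COL_PACK",
    "btmm_lhs": "COL_PACK",
    "btmm_out": "ROW_STACK",
    "none": None,
}


def merge_role(existing, new):
    """Merge a new tag into an existing one, or raise if layouts clash."""
    if existing == "none":
        return new
    if new == "none":
        return existing
    if ROLE_LAYOUT[existing] == ROLE_LAYOUT[new]:
        return new  # assumption: the later, layout-compatible tag wins
    raise ClassifyLaneUseError(f"cannot re-tag {existing!r} as {new!r}")


def classify(sites):
    """sites: iterable of (buffer_name, role) pairs in program order."""
    roles = {}
    for buffer_name, role in sites:
        roles[buffer_name] = merge_role(roles.get(buffer_name, "none"), role)
    return roles


print(classify([("Q_sh", "lane_dma_dst"), ("Q_sh", "btmm_lhs")]))
```

Like the pass described above, the sketch is read-only: it returns a name-to-role dict and leaves the input untouched.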
…-copy dispatch
- delete legacy graph-IR pipeline (graph_passes/, graph_pipeline, graph_walker, graph_ir, lift_from_raw, lower_to_hlir, classify_lane_use, expand_lane_grid, infer_lane_layout, fuse_elementwise, forbid_plena_extern) + their tests
- frontend/pipeline.py becomes a stub that raises on compile_func
- rename plena.v_* / plena.zero_v → plena.tile_* (whole-tile intent; row_*_op stays for single-row instructions)
- unify row_*_op to one HLIR op = one HW instruction; multi-row callers wrap in HLIR for-row
- buffer.cluster_dim metadata: explicit lane-axis position carried through split → view → burn_view → to_plena
- _resolve_row_at_coords uses cluster_dim to compute head/row stride; no shape-value heuristics
- BTMV / M_MV dispatch on LHS rows==1 for decode
- T.copy(vram, vram) → copy_v_to_v (V_ADD_VF f0=0)
- T.copy(vram, fpram) / T.copy(fpram, vram) → row_load_v_to_fp / row_store_fp_to_v (S_MAP_*_FP/V)
- async marker pruning: per-lane FPRAM scalar ops no longer flagged async (only DMA / BTMM / tile_* survive)
- dead_buffer_elim pass strips unused buffers
- kernels (attention/decode/rope) no longer pre-lower; return raw PrimFunc for compile_kernel to drive
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
S_MAP_FP_V / S_MAP_V_FP transfer VLEN=MLEN contiguous fp slots in one issue spanning all cluster lanes natively, so under sync wrap the cluster-phase index on both vram and fp sides must collapse to 0. Use buffer.cluster_dim to locate the phase axis on the FP ref and zero it. Add row_load_v_to_fp / row_store_fp_to_v to the multi-lane op set so they don't get a synthetic ``for by_phase`` re-issuing the same instruction 4×.
fold: ``_wrap_src`` allows affine-offset srcs (independent indices) when the dst is an FPRAM scalar slot, so compound-store patterns like ``OUT[2*i] = X[2*i]*C[2*i] + X[2*i+1]*NS[2*i]`` lower cleanly. RawStore fallback was removed; fold now errors if it can't recognise a store.
rope_min: write Q_OUT via explicit slice form to match input copies so dma_v2h_slice sees the full rows-length s-dim.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
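The "locate the phase axis via cluster_dim and zero it" step can be pictured as below. This is a minimal sketch on plain index tuples: `collapse_cluster_phase` is a hypothetical helper, and the real code works on buffer refs rather than tuples.

```python
def collapse_cluster_phase(indices, cluster_dim):
    """Return indices with the cluster-phase axis forced to 0.

    cluster_dim is the position of the lane/phase axis, or None when the
    buffer carries no cluster axis (in which case nothing changes).
    """
    if cluster_dim is None:
        return tuple(indices)
    out = list(indices)
    out[cluster_dim] = 0
    return tuple(out)


# Phase index 3 on axis 0 collapses to 0; the other indices are untouched.
print(collapse_cluster_phase((3, 7, 0, 5), cluster_dim=0))
```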
Bring spill_borrow's victim selection in line with _auto_spill: pinned GPs (loop hw counters, long-lived symbol-table bindings) must never be picked as spill candidates, regardless of which path needs the spill.
Without this, a serial for-loop body that triggers spill_borrow under register pressure could silently displace gp_loop to IntRAM and reuse the same physical register inside the borrow scope, corrupting C_LOOP_END's counter read.
Also rolls in pending in-progress work across the mid_ir / isa_pass stack and flips flash_attention_min's T.unroll loops to T.serial to exercise the pinned-GP path end-to-end.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
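The pinned-GP rule above amounts to filtering the candidate set before any victim heuristic runs. Below is a minimal sketch, assuming a "farthest next use" heuristic that the commit message does not specify; `pick_spill_victim` and its arguments are hypothetical.

```python
def pick_spill_victim(in_use, pinned, next_use):
    """Pick a register to spill, never choosing a pinned one.

    in_use:   registers currently holding live values
    pinned:   registers that must stay resident (e.g. loop counters)
    next_use: register -> position of its next use (assumed heuristic input)
    """
    candidates = [r for r in in_use if r not in pinned]
    if not candidates:
        raise RuntimeError("no spillable register: all in-use GPs are pinned")
    # Assumption: spill the value whose next use is farthest away.
    return max(candidates, key=lambda r: next_use.get(r, float("inf")))


in_use = ["gp1", "gp2", "gp_loop"]
next_use = {"gp1": 5, "gp2": 9, "gp_loop": 100}

# gp_loop has the farthest next use, but it is pinned and must be skipped.
print(pick_spill_victim(in_use, pinned={"gp_loop"}, next_use=next_use))
```

Without the filter the heuristic would pick `gp_loop`, which is the displacement the commit message describes.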
…k lane stride fix
* Einstein-summation-style gemm schema: region (4D BSHD start+extent) per operand + per-axis M/K/N dim_role labels. matmul / mv / btmm / btmv all read M/K/N positions from the role tables to drive instruction selection and extent lookup. Drops the old explicit M_tiles/K_tiles/N/transpose_b scalar args.
* MramRegion added as a twin of VramRegion (two distinct dataclasses); dead_buffer_elim and to_plena lowering recognise both.
* Author-pinned global.vram / global.mram tensor caches (Q_cache, O_cache) now carry is_pinned_global=True. AddressAllocationPass skips make_tile_layout for them, and the offset-walking iterators (_vram_region_iter_chunks, _region_origin_offset) compute addresses as flat row-major instead of 7D mlen-tile-padded; this matches how the testbench actually loads these buffers.
* global.vram / global.mram also get pad-to-4D, but with heads-at-H rule ((1,1,a,b)) instead of the default (1,a,1,b); this matches the head-major (head_count, hlen) layout kernels use these caches for. cluster_dim stays None on globals so sync-wrap iterators don't fold the head axis.
* copy_v_to_v handles cluster-asymmetric src/dst: each side emits its own region at native rank using ref indices; _rewrite_refs_to_4d lifts both per their own pad_inserts / cluster_modes entry.
* M_MM_WO physical row stride fix: c_orow_step = blen * mlen (not blen * dst_row_stride). M_MM_WO writes blen rows at physical pitch mlen regardless of how dense the dst's logical N maps inside each mlen-row.
* row_stack lane stride fix: when cluster_dim==0 (lane on B axis, lane_count==1), lane_stride = product(buf.shape[1:]), which matches the M_BMM_WO / M_BMV_WO hardware writeback (lane j -> base + j * per_lane_elems). For S=mlen this coincides with the old b_stride = mlen*inner_lane; for S<mlen (flash_decode rows=1) the old value overshot and lane 1+ wrote past the buffer.
* unroll_loops mode on matmul emitter (avoid sim MAX_LOOP_INSTRUCTIONS).
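The row_stack lane-stride bullet can be checked with concrete shapes. The sketch below assumes mlen = 64 and reads `inner_lane` as the product of the axes after S. Both are assumptions made for illustration, and the function names are hypothetical.

```python
from math import prod

MLEN = 64  # assumption: tile row length used for this example


def new_lane_stride(shape):
    """Per-lane element count when the lane axis is axis 0."""
    return prod(shape[1:])


def old_lane_stride(shape, mlen=MLEN):
    """Old formula b_stride = mlen * inner_lane.

    Assumption: inner_lane is the product of the axes after S.
    """
    inner_lane = prod(shape[2:])
    return mlen * inner_lane


for shape in [(4, 64, 1, 64), (4, 1, 1, 64)]:
    total = prod(shape)
    lane1_old = 1 * old_lane_stride(shape)  # base offset of lane 1, old formula
    lane1_new = 1 * new_lane_stride(shape)  # base offset of lane 1, new formula
    print(shape, "total:", total, "old:", lane1_old, "new:", lane1_new,
          "old in bounds:", lane1_old < total)
```

For S = mlen both formulas agree; for S = 1 the old lane-1 base lands beyond the end of the buffer, which is the overshoot described in the bullet.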
New debug kernels: flash_attention_gemm_only, flash_decode_min_gemm_only to bisect gemm schema in isolation.
Plus layernorm_min / linear_min / linear_min_no_transpose / modulate_min / residual_gate_min / rmsnorm_min / silu_min / gelu_min staging kernels.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
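The role-table lookup behind the gemm schema in this commit can be sketched as follows. Everything named here (`Operand`, `role_extent`, `select_op`) is hypothetical. The sketch covers only the matmul / mv split on M == 1 and leaves out btmm / btmv, whose selection rule the commit message does not spell out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Operand:
    start: tuple     # 4D BSHD start of the region
    extent: tuple    # 4D BSHD extent of the region
    dim_role: tuple  # one label per axis: "M", "K", "N", or None


def role_extent(op, role):
    """Product of the extents of all axes labelled with `role` (1 if none)."""
    total = 1
    for extent, label in zip(op.extent, op.dim_role):
        if label == role:
            total *= extent
    return total


def select_op(lhs, rhs):
    """Read M/K/N from the role tables and pick an op kind."""
    m = role_extent(lhs, "M")
    k = role_extent(lhs, "K")
    n = role_extent(rhs, "N")
    if k != role_extent(rhs, "K"):
        raise ValueError("K extents of lhs and rhs disagree")
    kind = "mv" if m == 1 else "matmul"  # decode case: a single LHS row
    return kind, (m, k, n)


rhs = Operand((0, 0, 0, 0), (1, 64, 1, 64), (None, "K", None, "N"))
prefill_lhs = Operand((0, 0, 0, 0), (1, 64, 1, 64), (None, "M", None, "K"))
decode_lhs = Operand((0, 0, 0, 0), (1, 1, 1, 64), (None, "M", None, "K"))

print(select_op(prefill_lhs, rhs))
print(select_op(decode_lhs, rhs))
```

Because the roles are per-axis labels, a transposed operand is expressed by moving the labels, so a separate transpose flag is not needed.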
…ernel fixes
- concat_min: feature-axis concat of two head-packed tensors
- _head_layout: BSHD <-> B,S,1,H*D view helpers
- copy_offset_min: o_head_offset probe kernel
- hoist_float_constants pass; mid_ir fold/fuse/to_plena fixes
- kernel updates across flash_attention/gelu/layernorm/linear/rmsnorm/silu
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…n kernel, plena_settings
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ocation
- New MIR layer (mir.py, mir_passes.py) + PreIsaIR v2 (pre_isa_ir_v2.py,
pre_isa_pass_v2.py, pre_isa_to_mir.py) + ISA emit (mir_to_isa.py).
- Register allocation: structured live intervals + scope-recursive spill.
Loop-carried values are stored to IntRAM at scope entry (outside the
loop body) and reloaded per use inside; the store never lands in a body
that loops, fixing cross-iteration spill corruption.
- C_SET_ADDR_REG emits 3 operands (aN, gp0, gp{addr}) matching HW.
- FORCE_SERIAL_LOOPS: all loops lower to hardware C_LOOP (emit-time
unroll removed as unsound).
- Add v2 test suite (test_v2_*, test_pre_isa_*, test_mir_passes); drop
superseded legacy emitter tests.
- doc/simulator_cost_model.md, REGALLOC design notes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
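The placement rule for loop-carried spills (store once at scope entry, reload per use inside the body) can be illustrated with a toy lowering. The tuple-based pseudo-ops below are invented for the sketch and are not PLENA mnemonics; `lower_loop` is a hypothetical helper.

```python
def lower_loop(value, slot, body_uses, trip_count):
    """Lower a loop whose body reads a spilled, loop-carried value.

    Returns a list of pseudo-ops: the store sits before the loop, and
    every use inside the body is preceded by its own reload.
    """
    pre = [("store", value, slot)]  # once, at scope entry, outside the body
    body = []
    for use in body_uses:
        body.append(("reload", slot))  # reload per use inside the body
        body.append(("use", use))
    return pre + [("loop", trip_count, body)]


program = lower_loop("acc", slot=0, body_uses=["add", "mul"], trip_count=8)
for op in program:
    print(op)
```

Since the store is never part of the loop body, a later iteration cannot overwrite the slot with a stale register value.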
Summary
Merges the `plena-compiler` branch into `main`. This brings the compiler from the legacy graph-layer pipeline to the new mid_ir-based pipeline, adds SPMD groundwork, and lands SSB (SingleStreamBlock) chained-kernel support.
103 files changed, ~37k insertions.
Highlights
- `mid_ir` pipeline with `cluster_dim`; add BTMV / MV / vram-vram-copy dispatch.
- `row_*_at` addressing; pinned globals are row-major-flat; `row_stack` lane-stride fix.
- `dim_roles` schema for gemm ops.
- `SPMD_REWRITE.md` design doc + step 1 `classify_lane_use` pass with unit tests (replacing the 4 lane-fusion graph passes).
- `concat_min` kernel, head-layout helpers, mid_ir + kernel fixes.
- `spill_borrow` now also filters pinned GPs.
Test plan
- `tilelang_tvm_compiler/tests/` (see diff).
🤖 Generated with Claude Code