diff --git a/SPMD_REWRITE.md b/SPMD_REWRITE.md new file mode 100644 index 0000000..3f7f1fd --- /dev/null +++ b/SPMD_REWRITE.md @@ -0,0 +1,371 @@ +# SPMD Lane-Group Rewrite — Design + +Status: design-only, 2026-05-11. Not yet implemented. + +Replaces the four lane-fusion graph passes (`split_lane_groups`, +`lift_lane_groups`, `allocate_group_memory`, `expand_buffers`) with +**three early TIR passes** that resolve lane fusion before +`lift_from_raw_primfunc` ever runs: + +1. `classify_lane_use` — scan op annotations + how `by` is used, + tag each buffer with its lane-fusion role. +2. `expand_lane_grid` — for tagged buffers only, add a LANE outer + dim and wrap per-lane work in a serial loop. +3. `infer_lane_layout` — pick where the lane axis sits in each + buffer (BSHD vs BHSD) and rewrite shape + indices accordingly. + +The split is intentional: each pass touches a different aspect of the +IR and can be unit-tested independently. They share state through +buffer attributes set by step 1. + +--- + +## 1. Why + +Today's `allocate_group_memory` + `expand_buffers` encode lane fusion +as a **buffer-shape decision**: a kernel-author 2D buffer +`alloc_shared((rows, hlen))` is silently rewritten into a 4D +`(B, rows, lane, hlen)` (COL_PACK) or `(B, lane, rows, hlen)` +(ROW_STACK), with the choice driven by ~8 heuristics. Every downstream +op then has to be lane-aware. + +Three problems: + +1. **Source code lies.** The kernel writer sees 2D, the compiler sees + 4D, the ASM consumer sees yet another shape. Every layer needs to + re-derive what's true. +2. **Heuristics are opaque.** `_resolve_row_at_coords`'s if-else chain + in `isa_pass.py` is a faithful read of where the buffer ended up — + but you have to trace through 4 passes to know why. +3. **Lane fusion bugs are silent.** Wrong COL_PACK vs ROW_STACK pick + = mis-laid buffer = numerically wrong ASM, no compile error. + +The new model treats lane fusion as **buffer dimensionality + a serial +loop**: every `T.alloc_shared((rows, hlen))` becomes a `(LANE, rows, +hlen)` buffer (one extra outermost dim), the grid `by` axis is +replaced by an explicit `for lane in serial(LANE)`, and a separate +small pass decides where the lane dimension actually sits in each +buffer (BSHD vs BHSD) by inspecting how the buffer is used. + +--- + +## 2. Where the new passes live + +``` +T.prim_func (tilelang DSL output) + ↓ inline_let_stmts ← unchanged + ↓ lower_compound_fp_stores ← unchanged + ↓ classify_lane_use ★ NEW ← tag each buffer with lane-fusion role + ↓ expand_lane_grid ★ NEW ← tagged buffers gain a LANE outer dim; lane loop wraps per-lane work + ↓ infer_lane_layout ★ NEW ← move lane dim to its physical position; rewrite indices + ↓ lift_from_raw_primfunc ← unchanged +[Graph] ← graph_passes never see lane fusion + ↓ graph_annotate_grid ← simplified (no lane handling) + ↓ graph_annotate_sync ← deleted (no async/sync distinction left) + ↓ split_lane_groups ← deleted + ↓ lift_lane_groups ← deleted + ↓ fuse_elementwise ← kept, simplified + ↓ scope_inference ← kept, simplified +[graph_pipeline.materialize] + ↓ allocate_group_memory.analyze ← deleted + ↓ expand_buffers.expand ← deleted + ↓ lower_fp_row_patterns ← kept +``` + +**Net:** 4 graph passes deleted, 2 graph passes simplified, 3 new TIR +passes added. Estimated line delta: −1500, +600. + +--- + +## 3. The three new passes + +### 3.0 `classify_lane_use` — tag buffers by lane-fusion role + +**Why first.** `expand_lane_grid` can't blindly add a LANE dim to +every buffer in the kernel — only buffers that participate in lane +fusion need it. Whether a buffer participates depends on **how the +ops that touch it are annotated**, not on the buffer's shape. So we +need one walk over the function body first, before any rewriting, +to label each buffer. + +**Inputs the classifier looks at:** + +| Op site | Role assigned to operand buffers | +|---------------------------------------------------------------|----------------------------------| +| `T.gemm(A, B, C)` under `T.attr(0, KIND, "btmm")` | A: btmm_lhs, B: btmm_rhs, C: btmm_out | +| `T.gemm(A, B, C)` with no KIND attr | A: per_head_lhs, B: per_head_rhs, C: per_head_out | +| `T.copy(hbm_slice, dst)` where the HBM slice indexes `by` | dst: lane_dma_dst | +| `T.copy(hbm_slice, dst)` with no `by` in the slice | dst: single_lane (no LANE dim) | +| `T.serial(N) + T.Parallel(M)` over a tagged buffer | propagates the tag to the body's loads/stores | +| Anything else | scalar — keep on outer attribute table | + +**Output:** every `tir.Buffer` (param or alloc) gains one of: + +- `lane_aware = True` (gets LANE outer dim in step 3.1) + - sub-tag picks layout in step 3.2: `col_pack` / `bhsd` / `single` +- `lane_aware = False` (untouched in steps 3.1 and 3.2) + +Stored as `buffer.var`-keyed attributes the next two passes read. + +**Implementation:** one `stmt_functor.post_order_visit` walk. The +classification rules above each become one `isinstance` check + one +attribute write. ~80 lines. + +**This is where we keep "implicit conventions inferred from `by` +use"** rather than forcing kernel authors to rewrite all 7 kernels +with explicit `T.async_copy` / `T.lane_parallel` macros. Today's 8 +buffer-shape heuristics are gone; what stays is "if the op site says +btmm or uses `by`, lane fusion applies". + +### 3.1 `expand_lane_grid` — pure structural rewrite + +**Input:** a `tir.PrimFunc` post-`classify_lane_use`. Every buffer +has its `lane_aware` flag set. `LANE = MLEN / btmm_hlen` (= 4 today). +The kernel author marks the lane axis with +`T.func_attr({"plena.lane_axis": "by"})`. + +**Output:** the same `tir.PrimFunc` with: + +- The lane-axis grid var **erased**. If extent == LANE, no surrounding + loop. If extent is a multiple of LANE, wrap with + `for by_outer in T.serial(extent // LANE)`. +- Every buffer with `lane_aware = True` rewritten from + `T.alloc_*((shape...))` to `T.alloc_*((LANE,) + shape)` — one extra + outermost dim. Buffers with `lane_aware = False` are untouched. + No new macro, no name mangling: `Q_sh` stays `Q_sh`, just shape + `(LANE, rows, hlen)` instead of `(rows, hlen)`. +- Every reference to a lane-aware buffer indexed accordingly: + - **Sync ops** (DMA, BTMM, V_*, etc.) that consume the buffer + whole: the surface call stays `T.copy(...)` / `T.gemm(...)`. + Codegen later sees a 3D buffer where dim 0 == LANE and emits + the multi-lane HW instruction. No `*_multi` op kind. + - **Per-lane work** (row_*, fp_*, per-head matmul on a + `per_head_lhs`-tagged buffer): wrapped in + `for lane in T.serial(LANE)` and indexed `Q_sh[lane, ...]`. The + `lane` var is a regular serial loop var; no lane fusion semantics + attached. + +**That's the entire pass.** It does not pick layouts, does not insert +reshape ops, does not touch what scope buffers live in. Pure +structural: grid → buffer dim + loop. + +**Worked example.** Input slice of `flash_attention_min` (LANE=4, +head_count=4): + +```python +with T.Kernel(num_q_blocks, head_count) as (q_block, by): + Q_sh = T.alloc_shared((rows, hlen), "f16") + K_sh = T.alloc_shared((rows, hlen), "f16") + S_loc = T.alloc_fragment((rows, MLEN), "f16") + M_OLD = T.alloc_fragment((rows,), "f16") + + T.copy(Q_hbm[0, q_block*rows, by, 0], Q_sh) + T.copy(K_hbm[0, kv_block*rows, by, 0], K_sh) + with T.attr(0, KIND, "btmm"): + T.gemm(Q_sh, K_sh, S_loc, transpose_B=True) + + for row in T.serial(rows): + M_OLD[row] = M_INIT[row] +``` + +After `expand_lane_grid`: + +```python +# `by` erased; head_count == LANE so no by_outer loop +Q_sh = T.alloc_shared((4, rows, hlen), "f16") # +outer dim +K_sh = T.alloc_shared((4, rows, hlen), "f16") +S_loc = T.alloc_fragment((4, rows, MLEN), "f16") +M_OLD = T.alloc_fragment((4, rows), "f16") + +T.copy(Q_hbm[0, q_block*rows, 0:4, 0], Q_sh) # 3D dst +T.copy(K_hbm[0, kv_block*rows, 0:4, 0], K_sh) +with T.attr(0, KIND, "btmm"): + T.gemm(Q_sh, K_sh, S_loc, transpose_B=True) # all 3D + +for lane in T.serial(4): # ← was implicit lane fusion + for row in T.serial(rows): + M_OLD[lane, row] = M_INIT[lane, row] +``` + +### 3.2 `infer_lane_layout` — choose where the lane dim sits + +After `expand_lane_grid`, every lane-aware buffer has its lane dim at +position 0. But the **physical** layout (which 7D slot the lane axis +occupies in VRAM/MRAM) depends on how the buffer is consumed. Two +flavors today: + +- **COL_PACK:** lane axis in the H slot of BSHD. Used for VRAM tiles + that BTMM reads as LHS or that per-row VRAM↔FPRAM ops walk. +- **ROW_STACK / BHSD:** lane axis in the H slot but ahead of S. Used + for BTMM outputs (`S_loc`) where each lane writes a full + (rows, MLEN) slab and per-head matmul consumes one lane's slab as + its LHS. + +In the new model, "BHSD vs BSHD" reduces to **which dim of the +buffer's shape carries the lane index**. Always lane-dim-at-0 for +COL_PACK; lane-dim-at-1 for BHSD; etc. + +**This pass:** + +1. For each `lane_aware = True` buffer, **read its role tag from + step 3.0** and map role → layout: + - `btmm_lhs`, `btmm_rhs`, `lane_dma_dst`, `per_head_out`, + fp/row state → COL_PACK (lane at outer dim 0) + - `btmm_out`, `per_head_lhs` → BHSD (lane at dim 1) + - Conflicting roles → error (a buffer can't be both BTMM output + and BTMM LHS; if classification disagrees it's a kernel-source + issue, not a pass issue) +2. **Permute the buffer's shape** so the lane axis sits where the + chosen layout wants it. e.g. `S_loc: (4, rows, MLEN) → (rows, 4, + MLEN)`. +3. **Update every load / store** of that buffer to permute its index + tuple correspondingly. e.g. `S_loc[lane, row, col] → + S_loc[row, lane, col]`. + +No reshape ops in the IR — we rewrite shapes and indices in place. +The pass is one walk + one rewrite, no fixpoint, no cross-buffer +flow analysis. + +**Worked example continuing from §3.1.** Suppose `S_loc` is BTMM output ++ per-head matmul LHS → both votes for BHSD. Then: + +```python +S_loc = T.alloc_fragment((rows, 4, MLEN), "f16") # ← lane moved to dim 1 + +with T.attr(0, KIND, "btmm"): + T.gemm(Q_sh, K_sh, S_loc, transpose_B=True) # codegen reads buf shape + +for lane in T.serial(4): + for row in T.serial(rows): + M_OLD[lane, row] = ... S_loc[row, lane, ...] ... +``` + +`Q_sh`, `K_sh` stay `(4, rows, hlen)` — no consumer wanted them +anywhere else. + +--- + +## 4. What graph_passes / codegen need to know + +Graph IR sees normal 3D buffers. The lane axis is **just a dim**; it +doesn't get a special label. Codegen distinguishes single-lane from +multi-lane purely by buffer shape: + +- DMA / BTMM / V_* called on a buffer whose shape matches a known + 3D-multi-lane pattern (dim that equals LANE is at the layout's H + slot per `plena.layout`) → emit the multi-lane HW instruction + with `lane_count=LANE`. +- DMA / BTMM / V_* called on a 2D buffer (or a 3D buffer whose + outer dim is *not* LANE — single-tile broadcast case) → emit the + single-lane HW instruction. + +This is one if-else in each codegen handler, replacing the entire +`_resolve_row_at_coords` chain. + +**Per-lane row/fp ops** stay exactly as they are — they receive a +single-lane address (the result of indexing a 3D buffer with the +`lane` loop var) and emit one per-lane HW instruction per `lane` +iteration. The legacy `lower_fp_row_patterns` pass handles the +`for lane in serial(4)` exactly like any other serial loop. + +--- + +## 5. What gets deleted, what gets kept + +### 5.1 Deleted + +- `frontend/passes/graph_passes/split_lane_groups.py` +- `frontend/passes/graph_passes/lift_lane_groups.py` +- `frontend/passes/graph_passes/allocate_group_memory.py` +- `frontend/passes/graph_passes/expand_buffers.py` +- `frontend/passes/graph_passes/annotate_sync.py` +- The lane-handling branches in `_resolve_row_at_coords` (in + `isa_pass.py`). + +### 5.2 Simplified + +- `frontend/passes/graph_passes/annotate_grid.py` — only annotates + non-lane grid axes (`q_block` etc). +- `frontend/passes/graph_passes/fuse_elementwise.py` — no longer + needs to special-case lane-group regions. +- `frontend/passes/graph_passes/scope_inference.py` — buffer scope + comes straight from the kernel's `alloc_*` call; no cross-buffer + inference. +- `codegen.py` — single-lane vs multi-lane is one shape check per + intrinsic, no lookup table. +- All 7 kernels in `tilelang_tvm_compiler/kernels/` — one new line + each: `T.func_attr({"plena.lane_axis": "by"})`. Buffer alloc + shapes stay literally the same — `expand_lane_grid` adds the + outer dim. + +### 5.3 Untouched + +- `intrinsics.py` (no new op kinds — multi-lane is implicit in + buffer shape) +- `isa_emitter.py` / `isa_pass.py` (modulo the simplification above) +- `expr_materializer.py`, `register_alloc.py`, `address_alloc.py` +- Everything under `transactional_emulator/` + +--- + +## 6. Risks / open questions + +1. **`expand_lane_grid` and `T.copy`'s slice arg.** Today + `T.copy(K_hbm[0, kv*rows, by, 0], K_sh)` uses `by` as an HBM + index. After erasure we need to turn it into a **range** indexing + `0:LANE` over that axis. That's the easy case (lane axis indexes + directly into a tensor dim). If the kernel does arithmetic on + `by` beyond a bare reference, the pass should refuse and ask the + author to refactor — not try to be clever. + +2. **`infer_lane_layout` conflict resolution.** A buffer used as + both BTMM LHS (wants COL_PACK) and BTMM output (wants BHSD) is + structurally impossible — it's not the same buffer. But what + about a buffer used as BTMM input AND per-head matmul input? This + does happen (S_loc → P @ V). The classification table in §3.2 + needs to be exhaustive against the 7 kernels; we'll discover any + gap during step-by-step migration. + +3. **Layout permutation cost.** Rewriting every load/store to + permute its index tuple touches many sites in long kernels. It's + mechanical but easy to bug. Mitigation: write a small + `permute_index(buf_var, perm)` helper and use it everywhere; one + permutation table per buffer. + +4. **`q_block` is not a lane axis.** It's a sequential block index + over Q tiles. The new pass only erases the var marked + `plena.lane_axis`; everything else (including unmarked grid axes) + stays as-is, lowered to `T.serial` by the existing TVM pipeline. + +5. **head_count > LANE.** When `head_count = 8, LANE = 4`: + `for by_outer in T.serial(2)` wraps the body. The lane buffers + are reused across by_outer iterations — matches today. + +6. **migration order.** Old and new pipelines coexist on + `compile_func` per kernel, gated by + `T.func_attr({"plena.use_spmd": True})`. flash_attention_min + first; the rest follow once HLIR diff is byte-clean. + +--- + +## 7. Implementation order + +| Step | Deliverable | Estimate | +|---|---|---| +| 1 | `classify_lane_use.py` — role-tagging walk with the 6 rules in §3.0 | 0.5 day | +| 2 | `expand_lane_grid.py` for flash_attention_min only — reads tags from step 1, adds LANE dim + lane loop | 1.5 days | +| 3 | `infer_lane_layout.py` — reads tags from step 1, permutes shapes + indices | 1 day | +| 4 | Codegen: shape-driven multi-lane vs single-lane dispatch in `_resolve_row_at_coords` (and equivalents) | 1 day | +| 5 | flash_attention_min source: add `plena.lane_axis` + `plena.use_spmd` attrs; verify HLIR matches today's output byte-for-byte | 1 day | +| 6 | Simplify `annotate_grid` / `fuse_elementwise` / `scope_inference` to remove lane handling | 1 day | +| 7 | Delete the four graph passes + their tests + dead branches in `_resolve_row_at_coords` | 0.5 day | +| 8 | Port the other 6 kernels (mostly: add the func attrs; extend the role table in step 1 if any new pattern shows up) | 2 days | +| 9 | Delete the legacy `compile_func` branch (`expand_lane_buffers=True`) | 0.5 day | + +Total: **~9 days** of focused work. Three small single-purpose +passes beat one big multi-walk pass for readability and +testability: classify alone can be unit-tested by checking the role +tags it sets, expand alone can be tested by feeding it a hand-tagged +IR, and infer alone can be tested similarly. The IR stays in vanilla +TIR throughout — no new macros, no `*_multi` op kinds, no +contiguous-backing tricks. diff --git a/asm_templates/__init__.py b/asm_templates/__init__.py index 8c5c02f..5375b5b 100644 --- a/asm_templates/__init__.py +++ b/asm_templates/__init__.py @@ -30,8 +30,8 @@ from .reset_reg_asm import reset_fpreg_asm, reset_reg_asm from .rope_asm import rope_asm from .silu_asm import silu_asm +from .gelu_asm import gelu_asm from .store_act_asm import store_act_asm -from .gemv_asm import gemv_asm __all__ = [ "batched_matmul_asm", @@ -56,5 +56,6 @@ "rms_norm_asm", "rope_asm", "silu_asm", + "gelu_asm", "store_act_asm", ] diff --git a/asm_templates/preload_act.py b/asm_templates/preload_act.py index 8c783e4..e190855 100644 --- a/asm_templates/preload_act.py +++ b/asm_templates/preload_act.py @@ -23,6 +23,7 @@ def preload_act_asm( inner_loop_register = alive_registers[4] stride_len = vlen if stride_size is None else stride_size + scale_len = hidden_size * batch if scale_size is None else scale_size # Set scale offset generated_code += _load_large_int(a_actual_register, hidden_size * batch) diff --git a/assembler/assembly_to_binary.py b/assembler/assembly_to_binary.py index bb3c578..76d1839 100644 --- a/assembler/assembly_to_binary.py +++ b/assembler/assembly_to_binary.py @@ -38,30 +38,43 @@ def _convert_to_binary(self, instruction): imm = instruction.imm rmask = instruction.rmask binary_instruction = 0 + + imm_mask = (1 << self.imm_width) - 1 if self.imm_width > 0 else 0 + if imm_mask and isinstance(imm, int) and (imm < 0 or imm > imm_mask): + print( + f"[assembler] WARN: imm overflow on {instruction.opcode}: " + f"raw imm={imm} (0x{imm & 0xFFFFFFFFFFFFFFFF:X}), " + f"IMM_WIDTH={self.imm_width}, masking to 0x{imm & imm_mask:X}" + ) + imm = imm & imm_mask + # print(f"Converting instruction: {instruction.opcode} with opcode={hex(opcode)}, rd={rd}, rs1={rs1}, rs2={rs2}, rstride={rstride}, funct1={funct1}, funct2={funct2}, imm={imm}") ow = self.operands_width opw = self.opcode_width - if instruction.opcode in [ - "S_ADDI_INT", - "M_MM_WO", - "S_LD_FP", - "S_ST_FP", - "S_LD_INT", - "S_ST_INT", - "S_MAP_V_FP", - "V_RED_MAX", - "V_RECI_V", - "V_EXP_V", - ]: - binary_instruction = (imm << (opw + 2 * ow)) + (rs1 << (opw + ow)) + (rd << opw) + opcode + if instruction.opcode in ["S_ADDI_INT", "S_SLLI_INT", "S_SRLI_INT", "M_MM_WO", "S_LD_FP", "S_ST_FP", "S_LD_INT", "S_ST_INT", "S_MAP_V_FP", "S_MAP_FP_V"]: + binary_instruction = ( + (imm << (opw + 2 * ow)) + + (rs1 << (opw + ow)) + + (rd << opw) + + opcode + ) elif instruction.opcode in ["S_LUI_INT", "M_MV_WO", "M_BMM_WO", "M_BMV_WO"]: - binary_instruction = (imm << (opw + ow)) + (rd << opw) + opcode - elif instruction.opcode in ["S_MV_FP", "S_RECI_FP", "S_EXP_FP", "S_SQRT_FP", "V_EXP_V", "V_RED_SUM"]: - binary_instruction = (rs1 << (opw + ow)) + (rd << opw) + opcode - elif instruction.opcode in ["C_BREAK"]: - binary_instruction = opcode - elif instruction.opcode in ["C_SET_SCALE_REG", "C_SET_STRIDE_REG", "C_SET_V_MASK_REG", "C_LOOP_END"]: - binary_instruction = (rd << opw) + opcode + binary_instruction = ( + (imm << (opw + ow)) + + (rd << opw) + + opcode + ) + elif instruction.opcode in [ "S_MV_FP", "S_RECI_FP", "S_EXP_FP", "S_SQRT_FP"]: + binary_instruction = ( + (rs1 << (opw + ow)) + + (rd << opw) + + opcode + ) + elif instruction.opcode in [ "C_SET_SCALE_REG", "C_SET_STRIDE_REG", "C_SET_V_MASK_REG", "C_LOOP_END"]: + binary_instruction = ( + (rd << opw) + + opcode + ) elif instruction.opcode in ["C_LOOP_START"]: # C_LOOP_START rd, imm - uses 22-bit immediate like S_LUI_INT binary_instruction = (imm << (opw + ow)) + (rd << opw) + opcode @@ -74,17 +87,15 @@ def _convert_to_binary(self, instruction): + (rd << opw) + opcode ) - elif instruction.opcode in [ - "V_ADD_VV", - "V_ADD_VF", - "V_MUL_VV", - "V_SUB_VV", - "V_MUL_VF", - "V_EXP_V", - "V_RECI_V", - "V_RED_SUM", - "V_RED_MAX", - ]: + elif instruction.opcode in ["V_EXP_V", "V_RECI_V", "V_RED_SUM", "V_RED_MAX"]: + binary_instruction = ( + (rmask << (opw + 3 * ow)) + + (0 << (opw + 2 * ow)) + + (rs1 << (opw + ow)) + + (rd << opw) + + opcode + ) + elif instruction.opcode in ["V_ADD_VV", "V_ADD_VF", "V_MUL_VV", "V_SUB_VV", "V_MUL_VF"]: binary_instruction = ( (rmask << (opw + 3 * ow)) + (rs2 << (opw + 2 * ow)) + (rs1 << (opw + ow)) + (rd << opw) + opcode ) @@ -122,10 +133,17 @@ def _convert_to_binary(self, instruction): return binary_instruction def write_binary_to_file(self, binary_instructions, output_file: str): - with open(output_file, "w") as file: - for instruction in binary_instructions: - file.write(f"0x{instruction:08X}\n") - + instr_mask = (1 << self.instruction_length) - 1 if self.instruction_length > 0 else 0xFFFFFFFF + with open(output_file, 'w') as file: + for idx, instruction in enumerate(binary_instructions): + if instruction & ~instr_mask: + print( + f"[assembler] WARN: instruction #{idx} overflows " + f"INSTRUCTION_LENGTH={self.instruction_length}: " + f"raw=0x{instruction:X}, truncating to 0x{instruction & instr_mask:08X}" + ) + file.write(f"0x{instruction & instr_mask:08X}\n") + def generate_binary(self, asm_file: str, output_file: str): """ Generate binary instructions from the assembled instructions. @@ -141,3 +159,10 @@ def generate_binary(self, asm_file: str, output_file: str): return binary_instructions +# isa_file_path = '../../src/definitions/operation.svh' +# config_file_path = '../../src/definitions/configuration.svh' +# asm_file_path = f'../../test/{args.test_type}/{args.layer}.asm' +# print(f'Assembling {asm_file_path} to {args.layer}.mem') +# output_file_path = f'../../test/{args.test_type}/{args.layer}.mem' +# assembler = AssemblyToBinary(isa_file_path, config_file_path) +# assembler.generate_binary(asm_file_path, output_file_path) diff --git a/assembler/parser.py b/assembler/parser.py index 12306b2..de8f07c 100644 --- a/assembler/parser.py +++ b/assembler/parser.py @@ -182,14 +182,22 @@ def parse_reg_or_int(operand): imm = int(operand_1) except ValueError: imm = None - # If it looks like register, rs2; else, imm (overwrites imm if rs1 not present) - if operand_2.strip().startswith(("gp", "f", "a")): - rs2 = parse_reg_or_int(operand_2) - else: + # Some vector ops use a 3-operand form where the last field is + # rmask, not rs2/imm. + if opcode in {"V_EXP_V", "V_RECI_V", "V_RED_SUM", "V_RED_MAX"}: try: - imm = int(operand_2) + rstride = int(operand_2) except ValueError: - pass + rstride = None + else: + # If it looks like register, rs2; else, imm (overwrites imm if rs1 not present) + if operand_2.strip().startswith(('gp','f','a')): + rs2 = parse_reg_or_int(operand_2) + else: + try: + imm = int(operand_2) + except ValueError: + pass elif len(operands) == 4: operand_0, operand_1, operand_2, operand_3 = operands rd = parse_reg_or_int(operand_0) diff --git a/doc/operation.svh b/doc/operation.svh index 7e5808d..334971e 100644 --- a/doc/operation.svh +++ b/doc/operation.svh @@ -146,6 +146,7 @@ typedef enum logic [instruction_pkg::OPCODE_WIDTH - 1:0] { S_LD_FP = 6'h1E, S_ST_FP = 6'h1F, S_MAP_V_FP = 6'h20, + S_MAP_FP_V = 6'h35, // Scalar Operations (INT) S_ADD_INT = 6'h21, @@ -155,6 +156,13 @@ typedef enum logic [instruction_pkg::OPCODE_WIDTH - 1:0] { S_LUI_INT = 6'h25, S_LD_INT = 6'h26, S_ST_INT = 6'h27, + // Logical shifts. Arithmetic-right (SRA) intentionally omitted: PLENA's + // integer domain is unsigned address/index arithmetic, no negative values + // ever flow through the GP pool, so sign-extension on shift is a no-op. + S_SLL_INT = 6'h36, + S_SLLI_INT = 6'h37, + S_SRL_INT = 6'h38, + S_SRLI_INT = 6'h39, // Memory Operations H_PREFETCH_M = 6'h28, diff --git a/doc/plena_isa_spec.md b/doc/plena_isa_spec.md index 9d16666..44b49f8 100644 --- a/doc/plena_isa_spec.md +++ b/doc/plena_isa_spec.md @@ -407,6 +407,60 @@ S_ADDI_INT gp2, gp1, 64 ; gp2 = 128 + 64 = 192 Load upper immediate value into the integer register. +#### S_SLL_INT + +**Format:** `S_SLL_INT rd, rs1, rs2` + +**Operation:** `gp_reg = gp_reg << (gp_reg & 0x1F)` + +**Description:** + +Logical shift left, shift amount taken from the lower 5 bits of `gp_reg`. + +#### S_SLLI_INT + +**Format:** `S_SLLI_INT rd, rs1, imm` + +**Operation:** `gp_reg = gp_reg << (imm & 0x1F)` + +**Description:** + +Logical shift left by an immediate. The lower 5 bits of `imm` are used as the shift amount; upper bits are ignored. + +**Example:** +```asm +S_ADDI_INT gp1, gp0, 5 ; gp1 = 5 +S_SLLI_INT gp2, gp1, 3 ; gp2 = 5 << 3 = 40 +``` + +#### S_SRL_INT + +**Format:** `S_SRL_INT rd, rs1, rs2` + +**Operation:** `gp_reg = gp_reg >> (gp_reg & 0x1F)` (logical, zero-fill) + +**Description:** + +Logical shift right, shift amount taken from the lower 5 bits of `gp_reg`. + +#### S_SRLI_INT + +**Format:** `S_SRLI_INT rd, rs1, imm` + +**Operation:** `gp_reg = gp_reg >> (imm & 0x1F)` (logical, zero-fill) + +**Description:** + +Logical shift right by an immediate. The lower 5 bits of `imm` are used as the shift amount. + +**Example:** +```asm +S_ADDI_INT gp1, gp0, 200 ; gp1 = 200 +S_SRLI_INT gp2, gp1, 3 ; gp2 = 200 >> 3 = 25 +``` + +**Note on omissions:** PLENA does **not** provide arithmetic right shift (SRA) -- the integer pool is treated as unsigned for address arithmetic. There is also no AND/OR/XOR; bit-level masking is not part of the kernel programming model. + #### S_LD_INT **Format:** `S_LD_INT rd, rs1, imm` diff --git a/doc/simulator_cost_model.md b/doc/simulator_cost_model.md new file mode 100644 index 0000000..e1a9327 --- /dev/null +++ b/doc/simulator_cost_model.md @@ -0,0 +1,291 @@ +# PLENA Simulator Cost & Concurrency Model + +What every PLENA opcode actually costs when run through +[`transactional_emulator`](../../transactional_emulator/src/main.rs), +and what (does / doesn't) run in parallel. Everything below was +verified from the Rust source on 2026-05-21. The active config is +analytic mode, dc_lib_en = 1, `mlen=1024`, `blen=8`, `hlen=128`, +`VLEN=1024`. + +## TL;DR for compiler authors + +1. The simulator runs **one sequential main task**. Every opcode + handler `.await`s its `cycle!(N)` cost before returning. There is + **no instruction-level parallelism**: while `M_MM` blocks for 1024 + cycles, no scalar / vector / scalar-FP instr makes progress. +2. The **only background task** is HBM DMA. Inside that task there + are **no `cycle!()` calls** — HBM transfers are modelled as + instantaneous, infinite-bandwidth, zero-latency. +3. **IntRAM access is 1 cycle**, identical to a scalar ADD. Spill + to IntRAM is not a 5–6× penalty (that was an unjustified + assumption); it's a 1× penalty. Register allocation should treat + spill as essentially free. +4. **Belady's algorithm (farthest last_use) is the optimal spill + picker** under this model — reload is a flat 1 cycle, so + minimising the count of reloads equals minimising spill cost. +5. **LICM should be aggressive always** — hoisting one invariant + out of an `extent=N` loop saves N×ALU cycles; the worst case + (the hoisted value gets spilled and reloaded N times) costs + N×LD_INT cycles = same total. Net is always non-negative. +6. **Strength reduction and rematerialisation give 0 benefit.** + `S_SLLI_INT` and `S_MUL_INT` are both 1 cycle; `S_LUI_INT` and + `S_LD_INT` are both 1 cycle. Trading one for the other never + wins. +7. **C_LOOP is good for heavy bodies, bad for trivial bodies.** The + `C_LOOP_END` adds 1 cycle per iteration on top of the body, so + for a 1-cycle body the loop is 2× the unrolled cost. For a + 1024-cycle `M_MM` body the overhead is 0.1%. + +## Concurrency: one task, sequential dispatch + +`transactional_emulator/src/main.rs:2515-2521`: + +```rust +#[tokio::main] +async fn main() { + let executor = Executor::new(); + executor.spawn(start()); + executor.enter(Instant::ETERNITY).await; +} +``` + +`start()` builds the M/V/HBM machines and calls `do_ops` exactly +once. `do_ops` is a flat `while pc < ops.len()` loop: + +```rust +async fn do_ops(&mut self, ops: &[Opcode]) { + let mut pc = 0; + while pc < ops.len() { + match ops[pc] { + M_MM { .. } => self.m_machine.mm(...).await, // cycle!(1024) + V_ADD_VV { .. } => self.v_machine.add(...).await, // cycle!(1) + S_ADD_INT { .. } => { ...; cycle!(1); } + // ... + } + pc += 1; + } +} +``` + +`cycle!(N)` expands to `Executor::current().resolve_at(now + N).await`. +That schedules a timer and suspends the current task until the +executor advances simulated time to `now + N`. Because only one +task ever issues opcodes, suspending it means *no further opcode +can issue* until the timer fires. + +### Time only moves forward in `executor.enter()` + +The scheduler loop (`lib/runtime/src/executor.rs:185-252`) is: + +1. Run every ready task to completion (or until it `.await`s). +2. Pop the earliest pending timer. Set `now = timer.resolve_at`. +3. Wake that timer. Go to 1. + +Crucially: **simulated time never advances except in step 2**. +Anything in step 1 — including HBM reads, channel sends, mutex +locks — completes "instantly" from the simulated-clock POV. + +### `tokio::join!` is in-task, not cross-instr + +The V_*_VV handlers use `tokio::join!(self.vram.read(vs1), +self.vram.read(vs2))` to read both operands "concurrently". This +join is *inside* one opcode handler in the main task; it does not +overlap with any other instruction. + +### The lone exception: HBM DMA + +The H_PREFETCH_V / H_PREFETCH_M / H_STORE_V handlers +(`main.rs:2039-2165`) all delegate to +`transfer_mx_from_hbm` (or its store counterpart). Inside, +`Executor::current().spawn(async move { ... })` fires off a +background task that: + +* reads HBM 64 bytes at a time via `hbm_clone.read(addr).await`, +* sends the assembled tensor through a oneshot channel, +* the main task's `continous_write_delayed(...).await` waits on + that channel and writes the destination tile. + +But `hbm.read()` itself contains *no* `cycle!()` calls. The +background task therefore runs to completion in step 1 of the +scheduler loop — instantaneously, in simulated time. The main +task's `.await` on the channel completes at the same `now`. + +This means: the HBM model is **infinite bandwidth, zero latency, +unbounded outstanding transactions**. Issuing an `H_PREFETCH_V` +costs zero cycles on the main timeline, and the prefetched data is +available immediately to any downstream consumer. + +## Per-opcode cycle cost + +All costs are for analytic mode, `dc_lib_en=1`. dc_lib_dis varies +slightly (mostly 2× scalar FP exp/sqrt/reci) but `SCALAR_INT_BASIC` +is 1 in both modes. + +### Scalar integer (`S_*_INT`) + +| Opcode | Cycles | +|--------|-------:| +| `S_ADD_INT` | 1 | +| `S_ADDI_INT` | 1 | +| `S_SUB_INT` | 1 | +| `S_MUL_INT` | 1 | +| `S_LUI_INT` | 1 | +| `S_SLL_INT` / `S_SLLI_INT` | 1 | +| `S_SRL_INT` / `S_SRLI_INT` | 1 | +| **`S_LD_INT`** | **1** | +| **`S_ST_INT`** | **1** | + +Every entry calls `cycle!(*SCALAR_INT_BASIC_CYCLES)` and that +constant is `1` (see `load_config.rs:339`). IntRAM access is the +same cost as ADD. + +### Scalar FP (`S_*_FP`) + +| Opcode | Cycles (analytic) | +|--------|-------:| +| `S_ADD_FP` / `S_SUB_FP` / `S_MUL_FP` / `S_MAX_FP` | 1 | +| `S_EXP_FP` | 1 | +| `S_RECI_FP` | 1 | +| `S_SQRT_FP` | 1 | +| `S_LD_FP` / `S_ST_FP` | 1 | +| **`S_MAP_V_FP`** | **1024 (`VLEN`)** | +| **`S_MAP_FP_V`** | **1024 (`VLEN`)** | + +`S_MAP_V_FP` / `S_MAP_FP_V` broadcast a scalar across a vector; +they cost a full vector pass. Avoid them in inner loops. + +### Vector (`V_*`, 1024-wide) + +| Opcode | Cycles | +|--------|-------:| +| `V_ADD_VV` / `V_ADD_VF` | 1 | +| `V_SUB_VV` / `V_SUB_VF` | 1 | +| `V_MUL_VV` / `V_MUL_VF` | 1 | +| `V_EXP_V` | 1 | +| `V_RECI_V` | 2 | +| `V_RED_MAX` | 4 | +| `V_RED_SUM` | 8 | + +`V_RED_SUM` is the most expensive scalar-to-from-vector op. Used +once per softmax row, so flash-attention pays +`8 × num_softmax_rows` for it. + +### Systolic (`M_*`) + +| Opcode | Cycles | +|--------|-------:| +| `M_MM` / `M_TMM` | **mlen = 1024** | +| `M_BMM` / `M_BTMM` | 1024 | +| `M_MV` / `M_TMV` | 1024 | +| `M_BMV` / `M_BTMV` | **1** | +| `M_MM_WO` / `M_BMM_WO` / `M_MV_WO` / `M_BMV_WO` | 1 | + +`M_BMV` / `M_BTMV` cost just 1 cycle — they're the broadcast +matrix-vector op, the lane-parallel sibling of `M_MV`. flash-decode +relies on this. + +`M_MM` dominates everything in flash-attention: at 1024 cycles +each, 2048 of them = 2,097,152 cycles, which is ~99.5% of the +kernel's simulated runtime. Optimising the surrounding scalar +arithmetic to save dozens of cycles is statistical noise. + +### HBM (`H_*`) + +| Opcode | Cycles | +|--------|-------:| +| `H_PREFETCH_V` | **0** | +| `H_PREFETCH_M` | **0** | +| `H_STORE_V` | **0** | +| `H_LOAD_V` | **0** | + +All HBM ops dispatch via `Executor::spawn` and rely on tile-future +machinery; the main task isn't charged. See "Concurrency" above +for why. + +### Control (`C_*`) + +| Opcode | Cycles | +|--------|-------:| +| `C_SET_ADDR_REG` | 1 | +| `C_SET_SCALE_REG` | 1 | +| `C_SET_STRIDE_REG` | 1 | +| `C_SET_V_MASK_REG` | 1 | +| `C_LOOP_START` | 1 | +| `C_LOOP_END` | 1 *per iteration that loops back* | + +`C_LOOP_END` runs N times for an `extent=N` loop (once at the end +of each iteration's body), each costing 1 cycle. `C_LOOP_START` +runs only once. + +## C_LOOP vs unroll + +A loop `for i in [0, N): body` costs: + +* Unrolled: `N × body_cycles` +* C_LOOP: `1 + N × (body_cycles + 1)` + +Per-iteration overhead is 2 extra cycles (the START at the top, +the END at the bottom — but START is amortised over N iters so +effectively just 1 extra per iter from C_LOOP_END). + +Break-even point: + +| body_cycles | unroll | C_LOOP | C_LOOP / unroll | +|---:|---:|---:|---:| +| 1 (scalar) | N | 1 + 2N | ~2× | +| 5 (small vec chain) | 5N | 1 + 6N | ~1.2× | +| 50 | 50N | 1 + 51N | 1.02× | +| 1024 (M_MM) | 1024N | 1 + 1025N | 1.001× | + +* **Trivial bodies**: unroll is ~2× faster. ISA bloat per iter is + small enough that unrolling is also acceptable for size. +* **Heavy bodies (M_MM)**: C_LOOP overhead is negligible. Use + `loop_kind="serial"` to keep ISA compact. + +## How this maps to compiler decisions + +The v2 compiler is documented at +[`plena_v2_pipeline.md`](./plena_v2_pipeline.md). Specific +choices justified by this cost model: + +- **Spill picker uses Belady (farthest last_use).** At a uniform + 1-cycle reload, the total reload count = the total spill cost, + and Belady provably minimises miss/reload count. +- **LICM is enabled by default.** Hoist always non-negative even + under worst-case spill thrashing. +- **`const_fold` peephole keeps `S_ADDI_INT %x, 0 → %x`**. + Eliminating an ADDI saves 1 cycle. +- **`reassociate` canonicalises ADD chains** and folds matching + ±terms. Lets CSE find shared partial sums (e.g. + `result_addr = mat_addr + orow_term`). +- **No strength reduction pass.** SLLI / MUL / LUI / LD_INT are + all 1 cycle. +- **No rematerialisation pass.** Same reason. +- **PreIsaPassV2 emits `loop_kind="unroll"` for small extents and + should switch to `"serial"` for M_MM-dominated bodies.** (Open + item — see v2 pipeline doc.) +- **DMA placement is "issue early, no scheduling needed".** Any + H_PREFETCH that hoists out of an inner loop is free. There is no + reason to limit prefetches by count. + +## Where this model deviates from real hardware + +This is a *behavioural* simulator. Real PLENA almost certainly has: + +* finite HBM bandwidth (model says 0) +* separate M, V, S issue ports that can overlap (model has 0 + parallelism) +* a non-trivial pipeline depth before issued instructions retire + +A compiler tuned aggressively against this simulator may make +choices that look bad on silicon. Concretely: + +* aggressive prefetch may saturate HBM bandwidth on real hardware +* aggressive LICM increases register pressure that the simulator + resolves for free via IntRAM spill at 1 cycle, but real spill + costs depend on the IntRAM port latency +* there's no benefit modelled for issuing scalar arithmetic in + parallel with M_MM, so the compiler won't schedule for it + +If/when the simulator gains a more realistic timing model, revisit +this doc and the spill/LICM thresholds. diff --git a/tilelang_runtime_compier/doc/TILE_TENSOR_COMPILER_PRINCIPLES.md b/tilelang_runtime_compier/doc/TILE_TENSOR_COMPILER_PRINCIPLES.md new file mode 100644 index 0000000..4fd3cf1 --- /dev/null +++ b/tilelang_runtime_compier/doc/TILE_TENSOR_COMPILER_PRINCIPLES.md @@ -0,0 +1,333 @@ +# TileTensor Compiler Principles + +This note gives a short conceptual introduction to the current +`tile_tensor_program.py` compiler/runtime structure. + +It is intentionally high-level. The goal is to answer: + +- what the current compiler/runtime is +- what layers it is made of +- what the main units are +- what design principles it follows + +Related docs: + +- `TILE_TENSOR_PROGRAM_USAGE.md` +- `TILE_TENSOR_RUNTIME_NOTES.md` +- `TILE_TENSOR_KERNEL_PROGRAMS.md` + +## 1. Overall Positioning + +The current `TileTensorProgram` system is not a fully general tensor compiler. +It is better understood as: + +- a program-building API for TileTensor testbench kernels +- a runtime that maps logical tensor objects onto backing values +- a lowering pipeline that emits emulator-oriented ISA + +At a very high level, the flow is: + +1. the user describes logical tensors and logical compute +2. the runtime decides which backing values those logical objects currently see +3. the compute layer lowers the result into emulator instructions + +## 2. Main Layers + +The current structure can be understood as three core layers, two supporting +layers, and one user-facing facade. + +### 2.1 `TensorManager` + +`TensorManager` is the logical layer. + +It owns: + +- logical `Input`, `Tensor`, and `Vector` objects +- tile creation +- slice resolution +- `mapt` grouping + +Its job is to answer questions like: + +- what logical tensor exists +- how it is tiled +- which logical tile or slice an operation refers to + +It does not decide backing residency or physical storage allocation. + +### 2.2 `ValueManager` + +`ValueManager` is the backing-value and residency layer. + +It owns: + +- `tile -> ValueTile` bindings +- `ValueTileView` +- residency transitions across `VRAM`, `MRAM`, `HBM`, and `FPRAM` +- write preparation and rebinding + +Its job is to answer questions like: + +- which real backing value a logical tile currently points to +- which window of that value the tile currently sees +- whether a write may reuse the old backing or must create a new one + +### 2.3 `ComputeManager` + +`ComputeManager` is the last-mile lowering layer. + +It owns: + +- operand validation +- ensure-at-use placement checks +- ISA emission for compute operations + +Its job is to turn prepared operands into actual emulator-side compute +instructions such as: + +- matmul-like operations +- tile binary operations +- FP kernel operations +- row operations + +### 2.4 `ThreadManager` + +`ThreadManager` is the symbolic parallel layer. + +It owns: + +- `parallel_region3d(...)` +- `parallel_region2d(...)` +- parallel expression capture +- `ParallelRegionGraph` +- execution-plan derivation and lowering + +Its job is to support a "describe first, lower later" programming model for +parallel regions. + +### 2.5 `HardwareManager` + +`HardwareManager` is the hardware-object registry. + +It owns metadata for: + +- HBM objects +- VRAM objects +- MRAM objects + +It is not responsible for tensor semantics. It acts more like a hardware-side +registry of visible objects and addresses. + +### 2.6 `TileTensorProgram` + +`TileTensorProgram` is the user-facing facade. + +This is the API layer users work with directly: + +- `input(...)` +- `tensor(...)` +- `copy(...)` +- `matmul(...)` +- `row_op(...)` +- `parallel_region3d(...)` +- `compile()` + +It orchestrates the lower layers rather than replacing them. + +## 3. Main Units + +The current system is organized around a few important units. + +### 3.1 Logical objects + +- `Input` +- `Tensor` +- `Vector` + +These are the logical objects users author against. + +### 3.2 Logical tiles + +- `InputTile` +- `TensorTile` +- `VectorTile` + +These are the main execution-side logical units in the tensor path. + +For the current runtime, tile is the main tensor-world unit, not individual +element IR. + +### 3.3 Backing values + +- `ValueTile` + +This represents one concrete backing value version. + +The key idea is that a logical tile and its backing value are not the same +thing. One logical tile points to one current backing value, but that binding +may change over time. + +### 3.4 Views + +- `ValueTileView` + +This represents the logical window a tile currently sees on a backing value. + +This is central to the current write model, because many writes are not simply +"replace one whole tensor", but rather "update one logical view of a backing +value". + +### 3.5 FP-domain units + +- `FPVar` +- `FPFragment` + +These belong to the FP domain and are intentionally separate from the main +tensor value/view path. + +### 3.6 Parallel symbolic units + +- `ParallelAccess` +- `ParallelExpr` +- `ParallelRegionGraph` +- `ParallelExecutionPlan` + +These are used by the symbolic parallel path to represent access patterns, +expressions, captured regions, and derived execution plans. + +## 4. Core Runtime Law + +The most important runtime law in the current system is: + +`logical tile -> ValueTileView -> compute -> bind/writeback` + +This means: + +1. start from the logical tile the user refers to +2. resolve the view that tile currently sees +3. run compute on the appropriate backing value(s) +4. bind the result back or write it out + +This is more accurate than thinking in terms of "the tensor directly owns the +physical data". + +## 5. Main Design Principles + +### 5.1 Logical objects and physical backing are separated + +The system intentionally separates: + +- logical tensor identity +- physical backing value identity + +This allows rebinding, alias-safe updates, partial views, and residency control +without pretending that one logical tensor always corresponds to one immutable +piece of storage. + +### 5.2 Tile is the main tensor execution unit + +The current compiler/runtime is built around tile-level execution. + +That means: + +- tensor lowering is primarily organized around tiles +- placement and residency are tracked per value/tile relationship +- many compute paths assume tile-granular movement and tile-granular writes + +### 5.3 Writes are view-aware + +A destination write is not treated as a blind overwrite by default. + +Instead, the runtime asks: + +- what view is being updated +- whether the old backing can be safely reused +- whether a new backing must be created +- whether old contents must be preserved + +This is why `ValueTileView` and `PreparedWrite` exist. + +### 5.4 Alias safety is more important than naive in-place update + +If the destination aliases a live source, the runtime does not assume that +overwriting in place is safe. + +It prefers to preserve correct read/write semantics first, then optimize the +physical path second. + +### 5.5 Preserve-copy is the last resort + +For partial updates, full physical copy in VRAM is intentionally treated as the +slow fallback path. + +The preferred order is: + +1. reuse old backing in place when safe +2. replace whole logical tile without preserve copy when possible +3. create a partial-update successor without physical copy when possible +4. use physical preserve copy only as a last resort + +### 5.6 FP domain is separate from the tensor value/view domain + +The system intentionally keeps: + +- tensor path +- FP-var / FP-fragment path + +as two related but distinct worlds. + +This is important because FP-oriented scalar/vector logic often has different +requirements from tile-backed tensor writes. + +### 5.7 Parallel regions are symbolic first, executable second + +`parallel_region3d(...)` and `parallel_region2d(...)` are not immediate +execution blocks. + +They work in two stages: + +1. capture symbolic accesses and expressions +2. finalize and lower them into execution steps later + +So the parallel path is fundamentally a symbolic programming model, not just a +Python loop shortcut. + +### 5.8 Prefer structured layouts over ad hoc 2D authoring + +Although rank-2 shapes exist in the current runtime, the most mature and +recommended authoring path is still BSHD-style structured layouts: + +- `(B, S, H, D)` +- `(B, S, 1, hidden)` + +In practice, new kernels should usually prefer these layouts over building new +flows around plain 2D matrices. + +## 6. How To Think About The Current Compiler + +One useful mental model is: + +- `TensorManager` decides what the user meant logically +- `ValueManager` decides what backing values exist and where they live +- `ComputeManager` decides how to execute the operation +- `ThreadManager` handles symbolic parallel capture and lowering +- `TileTensorProgram` ties the whole flow together + +Another useful summary is: + +"The current compiler/runtime is a tile-centric, view-aware lowering system +that separates logical tensors from physical backing values and lowers both +normal tensor compute and symbolic parallel regions into emulator ISA." + +## 7. Short Summary + +If this document needs to be reduced to one paragraph, the most accurate short +description is: + +The current `TileTensorProgram` compiler/runtime is a TileTensor testbench +program builder organized around logical tensors, tile-based execution, +backing-value rebinding, and explicit lowering. Its core idea is that logical +tensors do not directly own physical storage; instead, logical tiles resolve to +views on backing values, compute runs on those values with alias-safe update +rules, and the final result is lowered into emulator-oriented ISA. Parallel +regions add a symbolic capture layer on top of that model. diff --git a/tilelang_runtime_compier/doc/TILE_TENSOR_PROGRAM_USAGE.md b/tilelang_runtime_compier/doc/TILE_TENSOR_PROGRAM_USAGE.md new file mode 100644 index 0000000..fc6f632 --- /dev/null +++ b/tilelang_runtime_compier/doc/TILE_TENSOR_PROGRAM_USAGE.md @@ -0,0 +1,815 @@ +# TileTensorProgram Usage Guide + +This document is a user-facing guide for +`transactional_emulator/testbench/tile_tensor_program.py`. + +It focuses on how to author programs with `TileTensorProgram`, which public +APIs are available, how the common workflows fit together, and what the +current implementation constraints are. + +For runtime internals and design notes, see: + +- `TILE_TENSOR_RUNTIME_NOTES.md` +- `TILE_TENSOR_KERNEL_PROGRAMS.md` + +## 1. What This File Is + +`TileTensorProgram` is the main authoring API for building TileTensor testbench +programs. + +At a high level it lets you: + +- declare logical inputs and working tensors +- express tile-level tensor movement and compute +- express FP-domain scalar / fragment compute +- describe symbolic parallel regions +- lower all of that into emulator-oriented ISA text through `compile()` + +The most common workflow is: + +1. create `TileTensorProgram` +2. declare `input(...)` and `tensor(...)` +3. move data with `copy(...)` +4. run compute with `matmul(...)`, `atomic_*`, `row_op(...)`, `pure_fp_compute(...)`, or parallel regions +5. copy the final tensor back to an output buffer +6. call `compile()` + +## 2. Construction + +Typical construction: + +```python +from tile_tensor_program import TileTensorProgram + +prog = TileTensorProgram( + mlen=64, + blen=4, + btmm_hlen=16, + real_data_ratio=1.125, + vram_tile_capacity=16, + mram_tile_capacity=4, + fpram_capacity=1024, +) +``` + +Main constructor parameters: + +- `mlen` + Tile width / height in logical elements. Many vectorized operations assume + rows of width `mlen`. +- `blen` + Block width used by the underlying matmul lowering. +- `btmm_hlen` + Head width for BTMM-style paths. Must be a positive divisor of `mlen`. +- `real_data_ratio` + Scaling factor used when allocating HBM addresses. +- `vram_tile_capacity`, `mram_tile_capacity`, `fpram_capacity` + Resource sizing hints for the emulator/runtime. +- `hbm_base_addr` + Initial HBM allocation base. + +## 3. Logical Shapes + +The runtime currently works with logical shapes of rank 2, 3, or 4: + +- 2D: `(rows, cols)` +- 3D: `(x, y, z)` and internally treated as `rows=x`, `cols=y*z` +- 4D: `(B, S, H, D)` and internally treated as `rows=B*S`, `cols=H*D` + +Common patterns: + +- plain matrix: `(rows, cols)` +- sequence-hidden tensor: `(batch, seq, 1, hidden)` +- attention layout: `(batch, seq, heads, head_dim)` + +Important recommendation: + +- Although rank-2 `(rows, cols)` plain matrices are supported in several basic + paths, the 2D path is not yet the most mature authoring path in the current + runtime. +- For new kernels and new program authoring, prefer writing tensors in BSHD + form, even when the computation could be expressed as a plain 2D matrix. +- In practice, the most stable and best-covered authoring style today is to + use `(B, S, H, D)` or `(B, S, 1, hidden)` layouts rather than building new + flows around rank-2 tensors. + +## 4. Main Public APIs + +The most important user-facing methods are: + +- declaration + - `input(name, logical_shape, hbm_addr=None)` + - `tensor(name, logical_shape)` + - `vector(name, logical_shape)` + - `alloc_fragment(name, logical_shape, init_zero=False, dtype="fp32")` + - `alloc_shared(name, logical_shape, init_zero=False, dtype="fp32")` + - `fp_var(name, value=0.0, size=1)` + - `fp_fragment(name, shape, init=0.0)` + - `constant(name, value, size=1)` + +- tensor movement / compute + - `copy(src, dst)` + - `matmul(src1, src2, dst)` + - `atomic_add(src1, src2, dst)` + - `atomic_sub(src1, src2, dst)` + - `atomic_mul(src1, src2, dst)` + - `row_op(src, rhs=None, op=..., out=None, dim=-1)` + - `clear(tensor)` + - `clear_tensor(operand, weak=None)` + +- FP-domain compute + - `fp_copy`, `fp_fill`, `fp_add`, `fp_sub`, `fp_mul`, `fp_max` + - `fp_exp`, `fp_reci`, `fp_sqrt` + - `fill(dst, src)` for FP-domain destinations + +- symbolic parallel programming + - `parallel_region3d((S, H, D), name=None)` + - `parallel_region2d((X, Y), name=None)` + - `where(predicate, on_true, on_false)` + - `if_then_else(predicate, on_true, on_false)` + - `max(lhs, rhs)`, `exp(x)`, `reci(x)`, `sqrt(x)` + - `pair(axis)`, `half_index(axis)` + - `parallel_execution_plans()` + - `lower_parallel_execution_plans()` + +- loop hints and planning helpers + - `parallel(extent)` + - `pipelined(extent, num_stages=1)` + +- reporting / output + - `write_operation_report(output_path)` + - `build_fp_preload(min_size=0)` + - `compile()` + +- advanced / low-level APIs + - `pure_fp_compute(src1, dst, src2=None, control=...)` + - `fp_kernel(src1, dst, src2=None, control=...)` + - `mapf(...)`, `mapf_t(...)` + - `btmm(...)`, `btmm_write(...)` + - `alloc_hbm_addr(...)`, `add_hbm_object(...)` + - `emit_*` family for direct ISA emission + +## 5. Minimal End-To-End Example + +This is the simplest practical pattern: declare input and output buffers, copy +through a working tensor, then compile. + +```python +from tile_tensor_program import TileTensorProgram + +prog = TileTensorProgram( + mlen=64, + blen=4, + btmm_hlen=16, + vram_tile_capacity=16, + mram_tile_capacity=4, + fpram_capacity=1024, +) + +x_in = prog.input("X_IN", (1, 64)) +out_buf = prog.input("OUT", (1, 64)) +x = prog.tensor("X", (1, 64)) + +prog.copy(x_in, x) +prog.copy(x, out_buf) + +asm = prog.compile() +print(asm) +``` + +Typical real kernels insert additional compute between the two `copy(...)` +operations. + +## 6. Declaring Operands + +### 6.1 `input(...)` + +Use `input(...)` for logical tensors backed by HBM input/output objects. + +```python +x_in = prog.input("X_IN", (batch, seq, 1, hidden)) +out_buf = prog.input("OUT", (batch, seq, 1, hidden)) +``` + +Notes: + +- Inputs are usually sources, but an `Input` may also be used as a final + writeback target. +- You may provide `hbm_addr=...` if the buffer must live at a fixed HBM + address. + +### 6.2 `tensor(...)` + +Use `tensor(...)` for normal working tensors managed by the runtime. + +```python +x = prog.tensor("X", (batch, seq, 1, hidden)) +y = prog.tensor("Y", (batch, seq, 1, hidden)) +``` + +These are the standard temporary / internal compute operands. + +### 6.3 `vector(...)` + +Use `vector(...)` when you want an FP-backed vector-style object. Vector tiles +are associated with FP fragments rather than the normal tensor value/view path. + +This is mainly relevant for: + +- explicit FP-domain authoring +- `parallel_region2d`, which currently lowers only FP-backed `Vector` + destinations + +### 6.4 `alloc_fragment(...)` + +Use `alloc_fragment(...)` for default scratch temporaries. Depending on the +shape, the runtime may return a normal `Tensor` or a `Vector`. + +Typical examples: + +```python +centered = prog.alloc_fragment("CENTERED", (1, seq_len, 1, hidden_size)) +mean = prog.alloc_fragment("MEAN", (1, 1, seq_len)) +``` + +This is the most common way to allocate internal working buffers. + +Current VRAM intent: + +- `alloc_fragment(...)` marks the allocated tensor/vector as `l0` +- vector-shaped scratch such as `(1, H, M)` or `(1, 1, M)` should usually stay + on this path + +### 6.5 `alloc_shared(...)` + +Use `alloc_shared(...)` when the tensor is intended to behave like shared +scratch instead of ordinary `l0` scratch. + +Typical examples: + +```python +mask = prog.alloc_shared("MASK", (1, mlen, 1, mlen)) +q_group = prog.alloc_shared("Q_GROUP", (1, mlen, group_heads, hlen)) +``` + +Current VRAM intent: + +- `alloc_shared(...)` marks the allocated tensor/vector as `shared` +- shared-protected values are harder to evict from VRAM than ordinary `l0` + scratch +- use this path deliberately; it is not the default + +### 6.6 FP scalar / fragment declarations + +```python +scale = prog.fp_var("scale", value=0.125) +eps = prog.constant("eps", 1.0e-6, size=seq_len) +frag = prog.fp_fragment("tmp_fp", (seq_len,), init=0.0) +``` + +Use these for scalar values and small FP-domain arrays that should live in +FP_MEM / FP fragments. + +## 7. Indexing, Slicing, and Element Access + +The API supports Python indexing and slicing: + +```python +x[batch_index, :, :, :] +x[:, :, 0:1, :] +x[0, 0, :] +``` + +Common patterns: + +- whole tensor: `x` +- slice view: `x[:, :, :, :]`, `x[0, :, :, :]`, `x[:, :, 0:1, :]` +- element-like FP access: `scores_max[0, h, s]` + +Important distinction: + +- tensor / input slices participate in the tile/value runtime +- element-style accesses are used heavily by FP-domain and parallel-expression + APIs + +## 8. Data Movement + +### 8.1 `copy(src, dst)` + +`copy(...)` is the basic logical movement operator. + +```python +prog.copy(x_in, x) +prog.copy(y, out_buf) +prog.copy(x[0, :, :, :], tmp[0, :, :, :]) +``` + +Behavior: + +- tensor/input to tensor: rebinds or prepares backing values as needed +- tensor to input: performs logical writeback +- FP-domain operands: routes to `fp_copy` + +## 9. Tensor Compute APIs + +### 9.1 `matmul(src1, src2, dst)` + +`matmul(...)` is the main matrix multiply entrypoint. Internally it may choose +one of several paths: + +- default tilewise matmul +- view-based matmul for grouped narrow-head layouts +- BTMM/QKT path when `src2` is explicitly transposed and shapes match + +Example: + +```python +prog.matmul(x, w, y) +prog.matmul(q_group, k_group.T, score_group) +``` + +Notes: + +- explicit transpose syntax on the RHS is currently reserved for the BTMM/QKT + route +- not every transposed case is supported + +### 9.2 `atomic_add`, `atomic_sub`, `atomic_mul` + +These implement elementwise tile ops with alias-safe destination updates. + +```python +prog.atomic_add(a, b, out) +prog.atomic_mul(centered, centered, sq) +prog.atomic_add(score_head, mask_head, score_head) +``` + +Use these when: + +- the operation is tilewise / elementwise +- the destination may alias one input +- you want runtime-managed preservation and rebinding behavior + +### 9.3 `clear(tensor)` + +Zeroes all current value tiles of a tensor in VRAM. + +```python +prog.clear(accumulator) +``` + +### 9.4 `clear_tensor(...)` + +Clears runtime bindings for one tensor, slice, or tile. + +```python +prog.clear_tensor(tmp) +prog.clear_tensor(score_group) +``` + +This is commonly used for scratch fragments once their values are no longer +needed. + +`free_tensor_tile(...)` still exists as a compatibility alias, but new code +should use `clear_tensor(...)`. + +## 10. Row Operations + +`row_op(...)` is the main API for row-wise vector math and reductions along the +last logical dimension. + +Supported operations: + +- unary row ops + - `exp` + - `reci` +- binary row ops + - `mul` + - `add` + - `sub` +- reductions + - `reduce_sum` + - `reduce_max` + +Examples: + +```python +prog.row_op(x_head, op="reduce_sum", out=mean[0, 0, :], dim=-1) +prog.row_op(x_head, mean[0, 0, :], "sub", dim=-1) +prog.row_op(work_head, inv_rms[0, 0, :], "mul", dim=-1) +prog.row_op(score_head, op="exp", dim=-1) +``` + +Current contract: + +- `dim` must be `-1` +- reductions require `out=...` +- binary row ops require `rhs` +- the current lowering is most natural when each logical row has width `mlen` + +## 11. FP-Domain APIs + +The FP domain is intentionally separate from the tensor value/view pipeline. +Use it for FP-variable / FP-fragment compute, including: + +- scalar FP values +- short FP vectors +- fragment-backed row data +- elementwise FP math over one mapped FP-var list +- reduction-associated post-processing + +Important recommendation: + +- `pure_fp_compute(...)` and related FP mapping APIs are still supported, but + they are better treated as lower-level runtime-facing interfaces +- for new FP-heavy authoring, the preferred direction is usually + `parallel_region2d(...)` when the computation naturally fits lane-wise FP + vector work +- in other words, new user-facing FP logic should generally prefer symbolic + `parallel_region2d(...)` over building more code around `pure_fp_compute(...)` + +### 11.1 Recommended direction for new FP logic: `parallel_region2d(...)` + +For new FP-oriented logic, the preferred high-level style is usually +`parallel_region2d(...)`. + +Examples: + +```python +with prog.parallel_region2d((group_heads, mlen)) as (h, s): + scores_max[0, h, s] = prog.max(scores_max[0, h, s], scores_max_prev[0, h, s]) + +with prog.parallel_region2d((1, mlen)) as (_, s): + scores_scale[0, 0, s] = prog.exp(scores_scale[0, 0, s]) +``` + +Why this is the preferred direction: + +- it reads more like the intended lane-wise FP computation +- it avoids exposing as much FP-var plumbing in user-facing kernels +- it is closer to the current symbolic parallel authoring direction + +Current limitation: + +- `parallel_region2d(...)` is still narrower than a fully general FP compiler +- today it is mainly for FP-backed `Vector`-style destinations and supported + FP expression forms + +### 11.2 Low-level FP compute APIs: `pure_fp_compute(...)` and `fp_kernel(...)` + +These are the generic FP compute entrypoints. + +```python +prog.pure_fp_compute(mean[0, 0, :], mean[0, 0, :], src2=recip_hidden, control="mul") +prog.pure_fp_compute(var[0, 0, :], var[0, 0, :], src2=eps_vec, control="add") +``` + +Important clarification: + +- these are not limited to one scalar element +- in normal use, they operate over the full FP-var list returned by `mapf(...)` +- that means calls like `mean[0, 0, :]` or `var[0, 0, :]` are vector-style + elementwise FP operations over the whole slice, not one single-element op +- the runtime applies the requested FP operation across the mapped destination + FP vars + +How to think about them: + +- these APIs are still useful +- but they are better considered low-level or transitional interfaces +- they expose more of the FP-mapping model than we usually want in the main + user-facing programming style + +Supported `control` values include: + +- `copy` +- `add` +- `sub` +- `mul` +- `max` +- `exp` +- `reci` +- `sqrt` + +In existing code they can be convenient, but for new authoring they should not +be treated as the primary recommended FP style. + +### 11.3 Convenience wrappers + +These all dispatch into the FP kernel path: + +- `fp_copy(src, dst)` +- `fp_fill(dst, src)` +- `fp_fill_from_addr(dst, src_fpram_addr)` +- `fp_add(src1, src2, dst)` +- `fp_sub(src1, src2, dst)` +- `fp_mul(src1, src2, dst)` +- `fp_max(src1, src2, dst)` +- `fp_exp(src, dst)` +- `fp_reci(src, dst)` +- `fp_sqrt(src, dst)` + +### 11.4 `fill(dst, src)` + +`fill(...)` currently supports FP-domain destinations only. + +```python +prog.fill(mean[0, 0, :], 0.0) +prog.fill(var[0, 0, :], 0.0) +``` + +### 11.5 Low-level mapping helpers: `mapf(...)` and `mapf_t(...)` + +These are lower-level APIs for mapping operands into FP variables. + +- `mapf(operand)` + Returns the FP-var list for an operand. +- `mapf_t(tensor_operand, fp_operand, control="mixed")` + Mixed tensor-to-FP mapping helper. + +Most kernel authors do not need to call these directly unless they are doing +custom FP-domain orchestration. + +### 11.6 `build_fp_preload(...)` + +Returns the FP_MEM initialization array in address order. + +```python +fp_init = prog.build_fp_preload(min_size=32) +``` + +This is typically passed into a testbench artifact writer. + +## 12. Symbolic Parallel Programming + +The runtime supports symbolic parallel authoring through `parallel_region3d` +and `parallel_region2d`. + +Inside these scopes: + +- indexing with region axes produces symbolic loads +- assignments register symbolic compute instead of running immediately +- region finalization builds an execution plan and lowers it later + +### 12.1 `parallel_region3d((S, H, D))` + +This is the main parallel API for tensor-backed parallel compute. + +Example from RoPE-style code: + +```python +with prog.parallel_region3d((seq_len, head_count, full_dim), name="rope_q") as (s, h, d): + q_out[0, s, h, d] = prog.if_then_else( + d % 2 == 0, + xq[0, s, h, d] * cos_t[0, s, h, d] + + xq[0, s, h, prog.pair(d)] * neg_sin_t[0, s, h, d], + xq[0, s, h, prog.pair(d)] * sin_t[0, s, h, d] + + xq[0, s, h, d] * cos_t[0, s, h, d], + ) +``` + +Useful helpers: + +- `if_then_else(...)` / `where(...)` +- arithmetic operators on symbolic expressions +- comparisons like `<`, `<=`, `==`, `>=`, `>` +- `pair(d)` for even/odd lane partner selection +- `half_index(d)` for half-width grouping logic +- unary helpers `exp(...)`, `reci(...)`, `sqrt(...)` +- binary helper `max(...)` + +### 12.2 `parallel_region2d((X, Y))` + +This is a narrower parallel path used for FP-backed vector destinations. + +Example: + +```python +with prog.parallel_region2d((group_heads, mlen)) as (h, s): + scores_max[0, h, s] = prog.max(scores_max[0, h, s], scores_max_prev[0, h, s]) +``` + +Another example: + +```python +with prog.parallel_region2d((1, mlen)) as (_, s): + scores_scale[0, 0, s] = prog.exp(scores_scale[0, 0, s]) +``` + +Current 2D contract is much narrower than the 3D path: + +- destinations must be FP-backed `Vector`-style objects +- lowering supports FP expression kernels, not the full tensor write path + +### 12.3 Plan inspection + +You can inspect or force lowering of captured parallel regions: + +```python +plans = prog.parallel_execution_plans() +prog.lower_parallel_execution_plans() +``` + +`compile()` automatically lowers deferred parallel plans if needed. + +## 13. Loop Hints + +### 13.1 `parallel(extent)` + +Returns a range-like object and records a parallel loop hint. + +```python +for local_head in prog.parallel(group_heads): + ... +``` + +This is commonly used in kernel authoring to express per-head or per-lane +structure around other runtime ops. + +### 13.2 `pipelined(extent, num_stages=1)` + +Returns a range-like object and records a pipelining hint. + +```python +for i in prog.pipelined(tile_count, num_stages=2): + ... +``` + +This is mainly a planning hint for future/lower layers. + +## 14. Reporting and Debugging + +### 14.1 `write_operation_report(...)` + +Writes a human-readable trace of recorded operations and a delta report. + +```python +prog.write_operation_report("build/my_operation_report.txt") +``` + +The report includes: + +- operation kind and details +- VRAM / MRAM / HBM-resident value tiles +- active FP fragments +- value-tile-to-slice references +- FP-fragment-to-value references + +### 14.2 `compile()` + +`compile()` lowers any remaining parallel execution plans, normalizes large +immediates, and returns the generated ISA text. + +```python +asm = prog.compile() +``` + +## 15. HBM and Direct ISA Helpers + +These are advanced APIs for users who need explicit control over memory +objects or direct instruction emission. + +### 15.1 HBM helpers + +- `alloc_hbm_addr(elems)` +- `add_hbm_object(name, shape, hbm_addr=None)` + +Example: + +```python +base = prog.alloc_hbm_addr(64 * 64) +prog.add_hbm_object("X_BUF", (64, 64), hbm_addr=base) +``` + +### 15.2 Direct emit helpers + +Available direct emit APIs include: + +- `emit_hbm_tile_to_mram(...)` +- `emit_load_tile_from_hbm(...)` +- `emit_store_tile_to_hbm(...)` +- `emit_zero_vram_tile(...)` +- `emit_map_v_fp_tile(...)` +- `emit_map_fp_v_tile(...)` +- `emit_btmm(...)` +- `emit_btmm_wo(...)` +- `emit_matmul(...)` +- `emit_slot_matmul(...)` +- `emit_tile_binary(...)` +- `emit_tile_add(...)` +- `emit_fp_kernel(...)` +- `emit_row_operation(...)` + +These bypass much of the higher-level logical runtime. Use them only when: + +- building custom lowering paths +- debugging ISA generation +- prototyping a new runtime feature + +Most kernel authors should prefer the higher-level APIs first. + +## 16. Advanced BTMM APIs + +`matmul(...)` already routes into BTMM when the pattern matches, but the lower +level entrypoints also exist: + +- `btmm(lhs_packed_value=..., rhs_value=..., task_id="btmm")` +- `btmm_write(btmm_state=..., tile_count=..., ...)` + +These are advanced runtime hooks for explicit BTMM orchestration and are +normally used by the internal specialized matmul path. + +## 17. Common Authoring Patterns + +### 17.1 Input -> work tensor -> compute -> output + +```python +prog.copy(x_in, x) +prog.matmul(x, w, y) +prog.copy(y, out_buf) +``` + +### 17.2 LayerNorm-style flow + +```python +prog.copy(x_in, x) +prog.copy(x, centered) +prog.fill(mean[0, 0, :], 0.0) +prog.row_op(centered[0, :, 0:1, :], op="reduce_sum", out=mean[0, 0, :], dim=-1) +prog.pure_fp_compute(mean[0, 0, :], mean[0, 0, :], src2=recip_hidden, control="mul") +prog.row_op(centered[0, :, 0:1, :], mean[0, 0, :], "sub", dim=-1) +``` + +### 17.3 Parallel symbolic transform + +```python +with prog.parallel_region3d((seq_len, heads, dim), name="transform") as (s, h, d): + y[0, s, h, d] = prog.if_then_else( + d % 2 == 0, + x[0, s, h, d], + x[0, s, h, prog.pair(d)], + ) +``` + +## 18. Current Constraints + +The current implementation is intentionally narrower than a general tensor +compiler. The main constraints to document clearly are: + +- logical shapes are currently rank 2, 3, or 4 +- rank-2 plain matrix support exists, but it is not yet the most mature or + recommended primary authoring path +- for new development, prefer expressing tensors in BSHD-style layouts +- many high-level tensor ops currently require each BSHD operation to address + exactly one batch at a time +- `row_op(...)` currently supports `dim=-1` only +- `fill(...)` currently supports FP-domain destinations only +- explicit transposed RHS matmul is currently reserved for the BTMM/QKT route +- `parallel_region3d` lowering supports a focused subset of symbolic tensor + expressions +- `parallel_region2d` currently supports FP-backed vector destinations only +- current parallel lowering assumes structured, row-oriented execution and is + not a general arbitrary-index compiler +- several low-level emit helpers assume widths tied to `mlen` + +## 19. Practical Advice + +- Prefer `copy`, `matmul`, `atomic_*`, and `row_op` first. +- Prefer BSHD-style logical layouts for new code, even if a problem looks like + a plain 2D matrix at first glance. +- Use `alloc_fragment(...)` for normal scratch temporaries and + `alloc_shared(...)` when the tensor should be treated as shared scratch. +- Use `clear_tensor(...)` when a scratch tensor is no longer needed. +- Use `parallel_region3d(...)` when the computation is naturally lane-wise over + `(S, H, D)`. +- Prefer `parallel_region2d(...)` over `pure_fp_compute(...)` when writing new + FP-vector-style logic. +- Use `write_operation_report(...)` when debugging residency, rebinding, or + unexpected tile reuse. +- Use direct `emit_*` APIs only when you intentionally want ISA-level control. + +## 20. Where To Look For Examples + +Good in-repo examples: + +- `tile_tensor_kernel_programs/linear.py` + Basic input -> matmul -> output flow +- `tile_tensor_kernel_programs/layernorm.py` + `row_op`, `pure_fp_compute`, fragments, and FP helpers +- `tile_tensor_kernel_programs/rmsnorm.py` + Reduction-heavy FP/tensor mixed flow +- `tile_tensor_kernel_programs/rope.py` + `parallel_region3d`, `if_then_else`, `pair` +- `tile_tensor_kernel_programs/attention.py` + mixed `matmul`, `row_op`, `parallel_region2d`, and scratch-fragment usage + +## 21. Quick Reference + +If you only remember one short checklist, use this: + +1. declare `input(...)` and `tensor(...)` +2. `copy(...)` data into working tensors +3. use `matmul`, `atomic_*`, `row_op`, or parallel regions for compute +4. use `pure_fp_compute` / `fp_*` for scalar or FP-fragment math +5. `copy(...)` final tensors into output buffers +6. call `compile()` diff --git a/tilelang_runtime_compier/doc/TILE_TENSOR_RUNTIME_NOTES.md b/tilelang_runtime_compier/doc/TILE_TENSOR_RUNTIME_NOTES.md new file mode 100644 index 0000000..c86cc7b --- /dev/null +++ b/tilelang_runtime_compier/doc/TILE_TENSOR_RUNTIME_NOTES.md @@ -0,0 +1,139 @@ +# TileTensor Runtime Notes + +This note reflects the current runtime architecture in `tile_tensor_program.py`. + +## Main layers + +- `TensorManager` + Owns logical tensors, tiles, slices, and `mapt` grouping. + +- `ValueManager` + Owns `tile -> ValueTile` bindings, `ValueTileView` resolution, residency, and + write preparation. + +- `ComputeManager` + Owns last-mile ensure-at-use, operand validation, and ISA emission. + +- `ThreadManager` + Owns the symbolic `parallel_region3d` flow: region capture, expression + validation, graph finalization, cache planning, and execution-plan lowering. + +## Core concepts + +- `ValueTile` + Persistent backing object. Residency and addresses live here. + +- `ValueTileView` + Ephemeral logical window over one `ValueTile`. Views are computed from the + current tile binding and metadata; they are not stored as long-lived state. + +- `PreparedWrite` + Explicit write-preparation result returned by + `prepare_updated_view_value(...)`. + +- `ParallelRegionGraph` + Captured representation of one symbolic 3D parallel region. It stores axes, + assignments, cache metadata, and the derived execution plan. + +- `ParallelExecutionPlan` + Cycle-structured lowering plan produced from one `ParallelRegionGraph`. + +- `ParallelAccess` / `ParallelExpr` + Symbolic load and expression nodes used while building a parallel graph. + +## Core functions + +- `resolve_value_tile(tile)` + Input: `TensorTile | InputTile` + Output: `ValueTile` + Meaning: return the current backing value for one logical tile. + +- `resolve_value_tile_view(tile)` + Input: `TensorTile | InputTile` + Output: `ValueTileView` + Meaning: return the logical window that this tile currently sees on its + backing value. + +- `prepare_updated_view_value(tile, view, ...)` + Input: one destination tile plus the view being updated + Output: `PreparedWrite` + Meaning: main tensor write-preparation API. Decides: + - whether the write may reuse the old backing + - whether it must switch to a fresh backing + - whether a partial-update preserve copy is still required + +- `prepare_vram_backing_value(value, ...)` + Lower-level helper that prepares a VRAM-backed `ValueTile`. It does not by + itself define write semantics. + +- `parallel_region3d((S, H, D), name=...)` + Enter a symbolic 3-axis parallel capture scope. + +- `where(...)`, `if_then_else(...)` + Build symbolic selection expressions for masked parallel compute. + +- `pair(axis)`, `half_index(axis)` + Parallel indexing helpers used by RoPE-like lane pairing and coefficient + addressing. + +- `parallel_execution_plans()` + Inspect finalized parallel execution plans. + +- `lower_parallel_execution_plans()` + Force emission of deferred parallel plans if they were not lowered on region + exit. + +## Preferred tensor write path + +For tensor destinations, the preferred internal flow is: + +1. resolve the target view +2. call `prepare_updated_view_value(...)` +3. run compute +4. bind/write back the result + +## Parallel Execution Flow + +The `parallel` feature is now a major part of the runtime, not a side helper. +The intended flow is: + +1. enter `parallel_region3d((S, H, D))` +2. use symbolic axes to describe loads and assignments +3. finalize the region into `ParallelAssignment` records +4. derive cache and cycle plans +5. lower each cycle into load / compute / writeback steps +6. bind the region outputs back to tensor tiles + +This gives kernel authors a higher-level way to describe structured SIMD-style +tile work while still preserving explicit lowering behavior. + +## Current Parallel Contract + +The current implementation is intentionally narrower than a fully general tensor +compiler, and the docs should state that clearly: + +- parallel destinations must resolve to tensor-backed writes +- destination selectors must use exactly the active 3D axes +- expression lowering currently supports binary `add`, `sub`, and `mul` +- predicate lowering currently supports binary comparisons +- present lowering expects one full-width contiguous row per cycle + +Within those limits, `parallel_region3d` is already powerful enough to express +important kernels such as lane-wise elementwise flows, RoPE remapping, and +parallel attention-style data movement. + +## FP domain + +The FP domain is intentionally separate from the tensor value/view path. + +- `FPVar` + Scalar FP storage + +- `FPFragment` + Small structured FP storage + +- `pure_fp_compute(...)`, `fp_kernel(...)` + FP-domain execution helpers + +The FP domain often interacts with tensor ops through `row_op(...)`, reductions, +and scalar broadcast cases, but it is not modeled as `ValueTileView`. diff --git a/tilelang_runtime_compier/tile_tensor_program/__init__.py b/tilelang_runtime_compier/tile_tensor_program/__init__.py new file mode 100644 index 0000000..b6f2721 --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/__init__.py @@ -0,0 +1,43 @@ +"""TileTensor program package. + +Re-exports the public API so existing callers can continue to use +`from tile_tensor_program import TileTensorProgram` etc. unchanged. + +The package is split as: +- _types: FP/Parallel/Tile dataclasses and module-level type aliases +- _helpers: module-level `_xxx` helper functions +- _hardware_manager: HardwareManager +- _thread_manager: ThreadManager (+ _LoopHintRange) +- _value_manager: ValueManager +- _tensor_manager: TensorManager +- _vector_manager: VectorManager +- _compute_manager: ComputeManager (ISA emitter) +- _program: TileTensorProgram (top-level builder) +""" + +from ._types import * # noqa: F401,F403 +from ._helpers import * # noqa: F401,F403 +from ._hardware_manager import HardwareManager +from ._thread_manager import ThreadManager +from ._value_manager import ValueManager +from ._tensor_manager import TensorManager +from ._vector_manager import VectorManager +from ._compute_manager import ComputeManager +from ._program import TileTensorProgram + +# Re-export commonly imported names from outside the package: +# from tile_tensor_program import Input, TileTensorProgram, _logical_shape_to_physical_shape +from ._types import Input +from ._helpers import _logical_shape_to_physical_shape + +__all__ = [ + "TileTensorProgram", + "HardwareManager", + "ThreadManager", + "ValueManager", + "TensorManager", + "VectorManager", + "ComputeManager", + "Input", + "_logical_shape_to_physical_shape", +] diff --git a/tilelang_runtime_compier/tile_tensor_program/_compute_manager.py b/tilelang_runtime_compier/tile_tensor_program/_compute_manager.py new file mode 100644 index 0000000..cafb626 --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_compute_manager.py @@ -0,0 +1,380 @@ +"""ComputeManager: validates operands, ensures residency, emits ISA.""" + +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple + +from ._types import * # noqa: F401,F403 +from ._helpers import * # noqa: F401,F403 + + +class ComputeManager: + """Execute already-prepared tensor/FP operations and emit ISA. + + ComputeManager should not invent binding policy. It assumes the write path + has already been prepared by ValueManager and mainly does: + + - ensure operands in the correct place close to use + - validate lane/view layout for execution kernels + - emit ISA + """ + + def __init__(self, program: "TileTensorProgram") -> None: + self.program = program + self.isa_emitter = program.isa_emitter + self.ops: List[Dict[str, object]] = [] + + def execute(self, signal: List[object]) -> Dict[str, object]: + operands, op_kind = signal + record = {"op_kind": op_kind, "operands": operands} + self.ops.append(record) + if op_kind == "matmul": + return self._execute_matmul(operands) + return { + "op_kind": op_kind, + "inputs": operands, + "outputs": operands.get("outputs", []) if isinstance(operands, dict) else operands, + } + + def _execute_matmul(self, operands: object) -> Dict[str, object]: + if not isinstance(operands, tuple) or len(operands) != 4 or operands[0] != "matmul": + raise RuntimeError("matmul execute expects ('matmul', src_pairs, dst_value, dst_tile)") + _, src_pairs, dst_value, _ = operands + if not isinstance(dst_value, ValueTile): + raise RuntimeError("matmul execute expects one destination ValueTile") + + lhs_vram_addrs: List[int] = [] + rhs_mram_addrs: List[int] = [] + for pair in src_pairs: + if not isinstance(pair, list) or len(pair) != 2: + continue + lhs_value, rhs_value = pair + if not isinstance(lhs_value, ValueTile) or not isinstance(rhs_value, ValueTile): + raise RuntimeError("matmul execute expects ValueTile sources") + self.program.value_manager.ensure_value_tile_in_place(lhs_value, "vram") + self.program.value_manager.ensure_value_tile_in_place(rhs_value, "mram") + lhs_vram_addr = lhs_value.residency.get("vram_addr") + rhs_mram_addr = rhs_value.residency.get("mram_addr") + if lhs_vram_addr is None: + raise RuntimeError(f"matmul execute requires lhs value in VRAM: {lhs_value.value_tile_id}") + if rhs_mram_addr is None: + raise RuntimeError(f"matmul execute requires rhs value in MRAM: {rhs_value.value_tile_id}") + lhs_vram_addrs.append(int(lhs_vram_addr)) + rhs_mram_addrs.append(int(rhs_mram_addr)) + + self.program.value_manager.ensure_value_tile_in_place(dst_value, "vram") + dst_vram_addr = dst_value.residency.get("vram_addr") + if dst_vram_addr is None: + raise RuntimeError(f"matmul execute requires dst value in VRAM: {dst_value.value_tile_id}") + + task_id = self._matmul_task_id_from_value(dst_value) + self.isa_emitter.emit_matmul( + lhs_vram_addrs=lhs_vram_addrs, + rhs_mram_addrs=rhs_mram_addrs, + dst_vram_addr=int(dst_vram_addr), + task_id=task_id, + zero_dst=True, + ) + return { + "op_kind": "matmul", + "inputs": operands, + "outputs": [dst_value], + "dst": dst_value, + "task_id": task_id, + } + + def view_matmul( + self, + lhs_values: List[ValueTile], + rhs_tile: TensorTile | InputTile, + dst_tile: TensorTile | InputTile, + dst_value: ValueTile, + *, + task_id: str, + zero_dst: bool, + ) -> Dict[str, object]: + if not lhs_values: + raise RuntimeError("view_matmul expects one non-empty lhs ValueTile list") + if not all(isinstance(value, ValueTile) for value in lhs_values): + raise RuntimeError("view_matmul expects lhs_values to contain ValueTile objects only") + rhs_views = self.program.value_manager._tile_compute_views(rhs_tile) + dst_views = self.program.value_manager._tile_compute_views(dst_tile) + if not rhs_views or not dst_views: + raise RuntimeError("view_matmul expects non-empty rhs/dst view lanes") + rhs_value = self.program.value_manager.value_tiles.get(rhs_views[0].backing_value_tile_id) + if not isinstance(rhs_value, ValueTile): + raise RuntimeError("view_matmul requires rhs backing value") + + for lhs_value in lhs_values: + self.program.value_manager.ensure_value_tile_in_place(lhs_value, "vram") + self.program.value_manager.ensure_value_tile_in_place(rhs_value, "mram") + self.program.value_manager.ensure_value_tile_in_place(dst_value, "vram") + + rhs_mram_addr = rhs_value.residency.get("mram_addr") + dst_vram_addr = dst_value.residency.get("vram_addr") + lhs_vram_addrs = [value.residency.get("vram_addr") for value in lhs_values] + if rhs_mram_addr is None or dst_vram_addr is None or any(addr is None for addr in lhs_vram_addrs): + raise RuntimeError("view_matmul requires lhs in VRAM, rhs in MRAM, dst in VRAM") + if len(rhs_views) != len(dst_views): + raise RuntimeError( + f"view_matmul requires matching rhs/dst slot counts, got rhs={len(rhs_views)} dst={len(dst_views)}" + ) + if len(lhs_values) != len(rhs_views): + raise RuntimeError( + f"view_matmul requires lhs_values to align with lanes, got lhs={len(lhs_values)} slots={len(rhs_views)}" + ) + + lane_logs: List[Dict[str, object]] = [] + for lane_index, (lhs_addr, rhs_view, dst_view, lhs_value) in enumerate( + zip(lhs_vram_addrs, rhs_views, dst_views, lhs_values) + ): + if lhs_addr is None: + raise RuntimeError(f"view_matmul lane {lane_index} is missing one lhs VRAM address") + if rhs_view.col_count != dst_view.col_count: + raise RuntimeError( + f"view_matmul lane {lane_index} slot width mismatch rhs={rhs_view.col_count} dst={dst_view.col_count}" + ) + self.isa_emitter.emit_slot_matmul( + lhs_vram_addr=int(lhs_addr), + rhs_mram_addr=int(rhs_mram_addr), + rhs_col_offset=int(rhs_view.col_offset), + dst_vram_addr=int(dst_vram_addr), + dst_col_offset=int(dst_view.col_offset), + col_count=int(rhs_view.col_count), + task_id=f"{task_id}.lane{lane_index}", + zero_dst=(zero_dst and lane_index == 0), + ) + lane_logs.append( + { + "lane_index": lane_index, + "lhs_value": lhs_value.value_tile_id, + "lhs_vram_addr": int(lhs_addr), + "rhs_view": rhs_view.view_id, + "rhs_col_offset": int(rhs_view.col_offset), + "dst_view": dst_view.view_id, + "dst_col_offset": int(dst_view.col_offset), + "col_count": int(rhs_view.col_count), + } + ) + return { + "op_kind": "view_matmul", + "inputs": [lhs_values, rhs_tile, dst_tile], + "outputs": [dst_value], + "dst": dst_value, + "task_id": task_id, + "lane_logs": lane_logs, + } + + def btmm( + self, + *, + lhs_packed_value: ValueTile, + rhs_value: ValueTile, + task_id: str = "btmm", + ) -> Dict[str, object]: + self.program.value_manager.ensure_value_tile_in_place(lhs_packed_value, "vram") + self.program.value_manager.ensure_value_tile_in_place(rhs_value, "mram") + + lhs_vram_addr = lhs_packed_value.residency.get("vram_addr") + rhs_mram_addr = rhs_value.residency.get("mram_addr") + if lhs_vram_addr is None or rhs_mram_addr is None: + raise RuntimeError("btmm requires lhs_packed_value in VRAM and rhs_value in MRAM") + + self.isa_emitter.emit_btmm( + lhs_packed_vram_addr=int(lhs_vram_addr), + rhs_mram_addr=int(rhs_mram_addr), + task_id=task_id, + ) + return { + "op_kind": "btmm", + "lhs": lhs_packed_value, + "rhs": rhs_value, + "btmm_finished": True, + "task_id": task_id, + } + + def btmm_write( + self, + *, + btmm_state: Dict[str, object], + tile_count: Optional[int] = None, + reason: str = "btmm_write", + logical_shape: Optional[Tuple[int, int]] = None, + metadata: Optional[Dict[str, object]] = None, + task_id: str = "btmm_wo", + ) -> Dict[str, object]: + if not btmm_state.get("btmm_finished"): + raise RuntimeError("btmm_write requires btmm_state.btmm_finished == True") + + resolved_tile_count = self.program.btmm_lane_count if tile_count is None else int(tile_count) + if resolved_tile_count <= 0: + raise ValueError(f"btmm_write requires one positive tile_count, got {resolved_tile_count}") + + out_values, base_addr = self.program.value_manager.allocate_contiguous_vram_value_tiles( + tile_count=resolved_tile_count, + logical_shape=logical_shape if logical_shape is not None else (self.program.mlen, self.program.mlen), + metadata=metadata, + reason=reason, + ) + self.isa_emitter.emit_btmm_wo( + base_addr=base_addr, + tile_count=resolved_tile_count, + task_id=task_id, + ) + return { + "op_kind": "btmm_wo", + "btmm_state": btmm_state, + "dst_values": out_values, + "base_addr": base_addr, + "tile_count": resolved_tile_count, + "task_id": task_id, + } + + def _matmul_task_id_from_value(self, value: ValueTile) -> str: + source_tile_id = value.metadata.get("source_tile_id") + if not isinstance(source_tile_id, str): + return f"matmul.{value.value_tile_id}" + dst_tile = self.program.tensor_manager.tensor_tiles.get(source_tile_id) + if dst_tile is None: + return f"matmul.{value.value_tile_id}" + row_block, col_block = dst_tile.coord + return f"matmul.r{row_block}.c{col_block}" + + def fp_kernel( + self, + src1: Sequence[FPVar], + dst: Sequence[FPVar], + *, + src2: Optional[Sequence[FPVar]] = None, + op: str = "add", + task_id: str = "fp_kernel", + ) -> Dict[str, object]: + unary_ops = {"copy", "fill", "exp", "reci", "sqrt"} + binary_ops = {"add", "sub", "mul", "max"} + valid_ops = unary_ops | binary_ops + if op not in valid_ops: + raise ValueError(f"Unsupported fp_kernel op {op!r}; expected one of {sorted(valid_ops)}") + if op in binary_ops and src2 is None: + raise ValueError(f"Binary fp_kernel op {op!r} requires src2") + if op in unary_ops and src2 is not None: + raise ValueError(f"Unary fp_kernel op {op!r} does not accept src2") + + src1_vars = list(src1) + dst_vars = list(dst) + src2_vars = list(src2) if src2 is not None else None + if len(src1_vars) != len(dst_vars): + if op in {"copy", "fill"} and len(src1_vars) == 1 and len(dst_vars) > 1: + src1_vars = src1_vars * len(dst_vars) + else: + raise ValueError(f"fp_kernel expects matched src1/dst lengths, got {len(src1_vars)} vs {len(dst_vars)}") + if src2_vars is not None and len(src2_vars) != len(dst_vars): + raise ValueError(f"fp_kernel expects matched src2/dst lengths, got {len(src2_vars)} vs {len(dst_vars)}") + + self.isa_emitter.emit_fp_kernel( + src1_addrs=[_require_fp_addr(var) for var in src1_vars], + dst_addrs=[_require_fp_addr(var) for var in dst_vars], + src2_addrs=[_require_fp_addr(var) for var in src2_vars] if src2_vars is not None else None, + op=op, + task_id=task_id, + ) + record = { + "op_kind": "fp_kernel", + "task_id": task_id, + "op": op, + "src1": [var.name for var in src1_vars], + "src2": [var.name for var in src2_vars] if src2_vars is not None else None, + "dst": [var.name for var in dst_vars], + } + self.ops.append(record) + return record + + def pure_fp_compute( + self, + src1: Sequence[FPVar], + dst: Sequence[FPVar], + *, + src2: Optional[Sequence[FPVar]] = None, + op: str = "add", + task_id: str = "pure_fp_compute", + ) -> Dict[str, object]: + return self.fp_kernel(src1, dst, src2=src2, op=op, task_id=task_id) + + def row_operations( + self, + src: RowOperandLike, + *, + dst_operand: Optional[RowOperandLike] = None, + dst: Optional[Sequence[FPVar]] = None, + rhs: Optional[Sequence[FPVar]] = None, + op: str, + task_id: str = "row_operations", + ) -> Dict[str, object]: + if isinstance(src, ValueTileView): + backing_value = self.program.value_manager.value_tiles.get(src.backing_value_tile_id) + if not isinstance(backing_value, ValueTile): + raise RuntimeError(f"row_operations view source is missing backing value: {src.view_id}") + self.program.value_manager.ensure_value_tile_in_place(backing_value, "vram") + src_vram_addr = backing_value.residency.get("vram_addr") + row_count = int(src.row_count) + mask_unit = int(self.program.btmm_hlen) + col_offset = int(src.col_offset) + col_count = int(src.col_count) + if mask_unit <= 0: + raise RuntimeError(f"row_operations requires positive mask_unit, got {mask_unit}") + if col_offset % mask_unit != 0 or col_count % mask_unit != 0: + raise RuntimeError( + f"row_operations view mask expects col_offset/col_count aligned to mask_unit={mask_unit}, " + f"got col_offset={col_offset} col_count={col_count}" + ) + lane_start = col_offset // mask_unit + lane_count = col_count // mask_unit + mask_val = ((1 << lane_count) - 1) << lane_start + src_name = src.view_id + else: + self.program.value_manager.ensure_value_tile_in_place(src, "vram") + src_vram_addr = src.residency.get("vram_addr") + row_count = int(src.logical_shape[0]) + mask_val = None + src_name = src.value_tile_id + if src_vram_addr is None: + raise RuntimeError(f"row_operations requires src in VRAM: {src_name}") + if dst_operand is None: + dst_operand = src + if isinstance(dst_operand, ValueTileView): + dst_backing_value = self.program.value_manager.value_tiles.get(dst_operand.backing_value_tile_id) + if not isinstance(dst_backing_value, ValueTile): + raise RuntimeError(f"row_operations view destination is missing backing value: {dst_operand.view_id}") + self.program.value_manager.ensure_value_tile_in_place(dst_backing_value, "vram") + dst_vram_addr = dst_backing_value.residency.get("vram_addr") + else: + self.program.value_manager.ensure_value_tile_in_place(dst_operand, "vram") + dst_vram_addr = dst_operand.residency.get("vram_addr") + if dst_vram_addr is None: + raise RuntimeError(f"row_operations requires dst in VRAM: {task_id}") + + dst_addrs = [_require_fp_addr(var) for var in dst] if dst is not None else None + rhs_addrs = [_require_fp_addr(var) for var in rhs] if rhs is not None else None + self.isa_emitter.emit_row_operation( + src_vram_addr=int(src_vram_addr), + dst_vram_addr=int(dst_vram_addr), + dst_addrs=dst_addrs, + rhs_addrs=rhs_addrs, + row_count=row_count, + mask_val=mask_val, + op=op, + task_id=task_id, + ) + record = { + "op_kind": "row_operations", + "task_id": task_id, + "op": op, + "src": src_name, + "dst_operand": getattr(dst_operand, "view_id", getattr(dst_operand, "value_tile_id", None)), + "dst": [var.name for var in dst] if dst is not None else None, + "rhs": [var.name for var in rhs] if rhs is not None else None, + "mask_val": mask_val, + } + self.ops.append(record) + return record + + diff --git a/tilelang_runtime_compier/tile_tensor_program/_hardware_manager.py b/tilelang_runtime_compier/tile_tensor_program/_hardware_manager.py new file mode 100644 index 0000000..62d16cd --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_hardware_manager.py @@ -0,0 +1,24 @@ +"""HardwareManager: registry for simulated HBM/VRAM/MRAM objects.""" + +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple + +from ._types import * # noqa: F401,F403 +from ._helpers import * # noqa: F401,F403 + + +class HardwareManager: + """Registry for simulated HBM/VRAM/MRAM objects and placement metadata. + + This layer tracks hardware-visible objects only. It does not own tensor + grouping, value/scatter binding policy, or compute semantics. + """ + + def __init__(self, program: "TileTensorProgram") -> None: + self.program = program + self.hbm_objects: Dict[str, Dict[str, object]] = {} + self.vram_objects: Dict[str, Dict[str, object]] = {} + self.mram_objects: Dict[str, Dict[str, object]] = {} + + diff --git a/tilelang_runtime_compier/tile_tensor_program/_helpers.py b/tilelang_runtime_compier/tile_tensor_program/_helpers.py new file mode 100644 index 0000000..c2b8779 --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_helpers.py @@ -0,0 +1,728 @@ +"""Module-level helper functions used by managers and TileTensorProgram.""" + +from __future__ import annotations + +from typing import Dict, List, Tuple + +from ._types import * # noqa: F401,F403 + + +__all__ = [ + "_logical_shape_to_physical_shape", + "_logical_selectors_to_physical_ranges", + "_slice_item_to_range", + "_ranges_overlap", + "_tiles_in_grid_order", + "_bshd_tile_batch_index", + "_bshd_tile_seq_block", + "_is_tile_object", + "_is_full_element_index", + "_contains_parallel_selector", + "_normalize_index", + "_tile_owner_name", + "_logical_shape_to_hbm_stride", + "_tile_coord_to_hbm_offset", + "_logical_3d_selectors_to_flat_col_range", + "_logical_indices_to_physical_coord", + "_physical_tile_coord_to_fp_index", + "_vector_tile_row_fp_groups", + "_unwrap_transposed_operand", + "_is_transposed_operand", + "_is_narrow_tile", + "_is_fp_domain_operand", + "_is_parallel_graph_operand", + "_coerce_parallel_expr", + "_collect_parallel_accesses", + "_collect_parallel_predicates", + "_parallel_access_identity", + "_parallel_expr_identity", + "_infer_parallel_load_metadata", + "_infer_parallel_predicate_kind", + "_build_parallel_execution_plan", + "_infer_parallel_elem_width", + "_build_parallel_cycle_plan", + "_iter_fp_indices", + "_iter_logical_indices", + "_iter_selected_logical_indices", + "_format_fp_index", + "_require_fp_addr", + "_fp_fragment_shape_to_tile_shape", + "_fp_fragment_row_fp_vars", +] + + +def _logical_shape_to_physical_shape(logical_shape: LogicalShape) -> Tuple[int, int]: + if len(logical_shape) == 4: + b, s, h, d = logical_shape + return b * s, h * d + if len(logical_shape) == 3: + x, y, z = logical_shape + return x, y * z + if len(logical_shape) == 2: + return logical_shape[0], logical_shape[1] + raise NotImplementedError(f"Unsupported logical shape: {logical_shape}") + + +def _logical_selectors_to_physical_ranges( + logical_shape: LogicalShape, + selectors: Tuple[SliceItem, ...], +) -> Tuple[Tuple[int, int], Tuple[int, int]]: + normalized = list(selectors) + [slice(None)] * max(0, len(logical_shape) - len(selectors)) + if len(logical_shape) == 4: + b, s, h, d = logical_shape + b_sel, s_sel, h_sel, d_sel = normalized[:4] + b_range = _slice_item_to_range(b_sel, b) + s_range = _slice_item_to_range(s_sel, s) + h_range = _slice_item_to_range(h_sel, h) + d_range = _slice_item_to_range(d_sel, d) + row_range = (b_range[0] * s + s_range[0], (b_range[1] - 1) * s + s_range[1]) + col_range = (h_range[0] * d + d_range[0], (h_range[1] - 1) * d + d_range[1]) + return row_range, col_range + if len(logical_shape) == 3: + rows, outer, inner = logical_shape + row_sel, outer_sel, inner_sel = normalized[:3] + row_range = _slice_item_to_range(row_sel, rows) + col_range = _logical_3d_selectors_to_flat_col_range( + outer_extent=outer, + inner_extent=inner, + outer_selector=outer_sel, + inner_selector=inner_sel, + ) + return row_range, col_range + if len(logical_shape) == 2: + rows, cols = logical_shape + row_sel, col_sel = normalized[:2] + return _slice_item_to_range(row_sel, rows), _slice_item_to_range(col_sel, cols) + raise NotImplementedError(f"Unsupported logical shape for selectors: {logical_shape}") + + +def _slice_item_to_range(selector: SliceItem, extent: int) -> Tuple[int, int]: + if isinstance(selector, int): + index = selector if selector >= 0 else extent + selector + return index, index + 1 + start = 0 if selector.start is None else selector.start + stop = extent if selector.stop is None else selector.stop + return start, stop + + +def _ranges_overlap(lhs: Tuple[int, int], rhs: Tuple[int, int]) -> bool: + return lhs[0] < rhs[1] and rhs[0] < lhs[1] + + +def _tiles_in_grid_order(tiles: Dict[TileCoord, object]) -> List[object]: + return [tile for _, tile in sorted(tiles.items(), key=lambda item: item[0])] + + +def _bshd_tile_batch_index(tile: object) -> int: + return int(getattr(tile, "metadata", {}).get("batch_index", 0)) + + +def _bshd_tile_seq_block(tile: object) -> int: + return int(getattr(tile, "metadata", {}).get("seq_block", getattr(tile, "coord", (0, 0))[0])) + + +def _is_tile_object(tile: object) -> bool: + return isinstance(tile, (TensorTile, InputTile, VectorTile)) + + +def _is_full_element_index(item: Tuple[SliceItem, ...], rank: int) -> bool: + return len(item) == rank and all(isinstance(index, int) for index in item) + + +def _contains_parallel_selector(item: Tuple[object, ...]) -> bool: + return any(isinstance(selector, (ParallelAxis, ParallelExpr, ParallelAccess)) for selector in item) + + +def _normalize_index(index: int, extent: int) -> int: + normalized = int(index) + if normalized < 0: + normalized += int(extent) + if normalized < 0 or normalized >= int(extent): + raise IndexError(f"Index {index} is out of range for extent {extent}") + return normalized + + +def _tile_owner_name(tile: TileLike) -> str: + if isinstance(tile, InputTile): + return tile.input_name + return tile.tensor_name + + +def _logical_shape_to_hbm_stride(logical_shape: LogicalShape) -> int: + if len(logical_shape) == 4: + _, _, heads, head_dim = logical_shape + return int(heads) * int(head_dim) + rows, cols = _logical_shape_to_physical_shape(logical_shape) + return int(cols if cols > 0 else rows) + + +def _tile_coord_to_hbm_offset(coord: TileCoord, logical_shape: LogicalShape, mlen: int) -> int: + _, stride = _logical_shape_to_physical_shape(logical_shape)[0], _logical_shape_to_hbm_stride(logical_shape) + return int(coord[0]) * int(mlen) * int(stride) + int(coord[1]) * int(mlen) + + +def _logical_3d_selectors_to_flat_col_range( + *, + outer_extent: int, + inner_extent: int, + outer_selector: SliceItem, + inner_selector: SliceItem, +) -> Tuple[int, int]: + outer_range = _slice_item_to_range(outer_selector, outer_extent) + inner_range = _slice_item_to_range(inner_selector, inner_extent) + outer_full = outer_range == (0, outer_extent) + inner_full = inner_range == (0, inner_extent) + if isinstance(outer_selector, int): + base = outer_range[0] * inner_extent + return base + inner_range[0], base + inner_range[1] + if inner_full: + return outer_range[0] * inner_extent, outer_range[1] * inner_extent + if outer_range[1] - outer_range[0] == 1: + base = outer_range[0] * inner_extent + return base + inner_range[0], base + inner_range[1] + if outer_full and inner_range[1] - inner_range[0] == inner_extent: + return 0, outer_extent * inner_extent + raise NotImplementedError( + "3D vector slicing currently supports full-inner slices or one selected outer lane; " + f"got outer={outer_selector!r} inner={inner_selector!r}" + ) + + +def _logical_indices_to_physical_coord( + logical_shape: LogicalShape, + indices: Tuple[int, ...], +) -> Tuple[int, int]: + if len(logical_shape) != len(indices): + raise ValueError(f"logical_indices rank mismatch: shape={logical_shape} indices={indices}") + if len(logical_shape) == 4: + b, s, h, d = logical_shape + bi, si, hi, di = indices + return bi * s + si, hi * d + di + if len(logical_shape) == 3: + x, y, z = logical_shape + xi, yi, zi = indices + return xi, yi * z + zi + if len(logical_shape) == 2: + ri, ci = indices + return ri, ci + raise NotImplementedError(f"Unsupported logical shape for element indices: {logical_shape}") + + +def _physical_tile_coord_to_fp_index( + fragment_shape: Tuple[int, ...], + *, + local_row: int, + local_col: int, + mlen: int, + btmm_hlen: int, +) -> FPIndex: + normalized = tuple(int(dim) for dim in fragment_shape) + if len(normalized) == 2: + rows, cols = normalized + if local_row < 0 or local_row >= rows or local_col < 0 or local_col >= cols: + raise IndexError( + f"Local tile coord ({local_row}, {local_col}) is out of range for FP fragment shape {fragment_shape}" + ) + return int(local_row), int(local_col) + if normalized == (mlen, mlen): + return int(local_row), int(local_col) + if btmm_hlen > 0 and normalized == (mlen, mlen // btmm_hlen, btmm_hlen): + return int(local_row), int(local_col // btmm_hlen), int(local_col % btmm_hlen) + raise ValueError( + f"Unsupported fp fragment shape for tile-element mapping: {fragment_shape}; " + f"expected ({mlen}, {mlen}) or ({mlen}, {mlen // btmm_hlen if btmm_hlen > 0 else 'invalid'}, {btmm_hlen})" + ) + + +def _vector_tile_row_fp_groups( + *, + src_tile: VectorTile, + fragment: FPFragment, + mlen: int, + btmm_hlen: int, + src_slice_ranges: Optional[Tuple[Tuple[int, int], Tuple[int, int]]], +) -> List[List[FPVar]]: + row_block, col_block = src_tile.coord + row_start = row_block * mlen + row_end = row_start + int(src_tile.tile_shape[0]) + col_start = col_block * mlen + col_end = col_start + int(src_tile.tile_shape[1]) + + if src_slice_ranges is None: + use_row_start, use_row_end = row_start, row_end + use_col_start, use_col_end = col_start, col_end + else: + req_row_range, req_col_range = src_slice_ranges + use_row_start = max(row_start, int(req_row_range[0])) + use_row_end = min(row_end, int(req_row_range[1])) + use_col_start = max(col_start, int(req_col_range[0])) + use_col_end = min(col_end, int(req_col_range[1])) + if use_row_start >= use_row_end or use_col_start >= use_col_end: + return [] + + groups: List[List[FPVar]] = [] + for physical_row in range(use_row_start, use_row_end): + local_row = physical_row - row_start + row_vars: List[FPVar] = [] + for physical_col in range(use_col_start, use_col_end): + local_col = physical_col - col_start + fp_index = _physical_tile_coord_to_fp_index( + fragment.shape, + local_row=local_row, + local_col=local_col, + mlen=mlen, + btmm_hlen=btmm_hlen, + ) + fp_var = fragment.vars.get(fp_index) + if not isinstance(fp_var, FPVar): + raise RuntimeError( + f"VectorTile {src_tile.tile_id} bound to fragment {fragment.name!r} is missing fp cell {fp_index}" + ) + row_vars.append(fp_var) + groups.append(row_vars) + return groups + + +def _unwrap_transposed_operand(operand: object) -> object: + if isinstance(operand, (TensorTranspose, InputTranspose, VectorTranspose)): + return operand.base + return operand + + +def _is_transposed_operand(operand: object) -> bool: + return isinstance(operand, (TensorTranspose, InputTranspose, VectorTranspose)) + + +def _is_narrow_tile(tile: TileLike) -> bool: + mlen = int(tile.metadata.get("mlen", tile.tile_shape[0])) + return tile.tile_shape[0] != mlen or tile.tile_shape[1] != mlen + + +def _is_fp_domain_operand(operand: object) -> bool: + return isinstance( + operand, + ( + FPVar, + FPFragment, + FPFragmentSlice, + Vector, + VectorSlice, + VectorTile, + ElementRef, + ), + ) + + +def _is_parallel_graph_operand(operand: object) -> bool: + return isinstance(operand, (ParallelAxis, ParallelAccess, ParallelExpr)) + + +def _coerce_parallel_expr(value: object) -> ParallelExpr: + if isinstance(value, ParallelExpr): + return value + if isinstance(value, ParallelAxis): + return value._as_expr() + if isinstance(value, ParallelAccess): + return value._as_expr() + if isinstance(value, FPVar): + return ParallelExpr(kind="fpvar", value=value) + if isinstance(value, (int, float)): + return ParallelExpr(kind="literal", value=float(value)) + raise TypeError(f"Unsupported parallel expression operand: {type(value).__name__}") + + +def _collect_parallel_accesses(expr: ParallelExpr) -> List[ParallelAccess]: + accesses: List[ParallelAccess] = [] + if expr.kind == "load": + access = expr.value + if isinstance(access, ParallelAccess): + accesses.append(access) + return accesses + for arg in expr.args: + accesses.extend(_collect_parallel_accesses(arg)) + return accesses + + +def _collect_parallel_predicates(expr: ParallelExpr) -> List[str]: + predicates: List[str] = [] + if expr.kind == "select" and expr.args: + predicate = expr.args[0] + predicates.append(_infer_parallel_predicate_kind(predicate)) + for arg in expr.args: + predicates.extend(_collect_parallel_predicates(arg)) + return predicates + + +def _parallel_access_identity(access: ParallelAccess) -> str: + base_name = getattr(access.base, "name", type(access.base).__name__) + return f"{base_name}{tuple(access.selectors)!r}" + + +def _parallel_expr_identity(expr: ParallelExpr) -> str: + if expr.kind == "literal": + return f"literal({expr.value})" + if expr.kind == "fpvar": + fp_var = expr.value + if isinstance(fp_var, FPVar): + return f"fpvar({fp_var.name})" + return "fpvar(?)" + if expr.kind == "axis": + axis = expr.value + if isinstance(axis, ParallelAxis): + return f"axis({axis.name})" + return "axis(?)" + if expr.kind == "load": + access = expr.value + if isinstance(access, ParallelAccess): + return f"load({_parallel_access_identity(access)})" + return "load(?)" + if expr.kind in {"pair_index", "half_index"}: + args = ",".join(_parallel_expr_identity(arg) for arg in expr.args) + return f"{expr.kind}({args})" + if expr.kind == "unary_op": + args = ",".join(_parallel_expr_identity(arg) for arg in expr.args) + return f"{expr.op}({args})" + if expr.kind == "op": + args = ",".join(_parallel_expr_identity(arg) for arg in expr.args) + return f"{expr.op}({args})" + if expr.kind == "select": + args = ",".join(_parallel_expr_identity(arg) for arg in expr.args) + return f"select({args})" + return expr.kind + + +def _infer_parallel_load_metadata(access: ParallelAccess) -> Dict[str, object]: + metadata: Dict[str, object] = {} + selectors = tuple(access.selectors) + if selectors and isinstance(selectors[-1], ParallelExpr): + lane_expr = selectors[-1] + if lane_expr.kind == "pair_index": + metadata["companion_kind"] = "pair_swap" + elif lane_expr.kind == "half_index": + metadata["coefficient_layout"] = "preexpanded_full_lane" + return metadata + + +def _infer_parallel_predicate_kind(expr: ParallelExpr) -> str: + if ( + expr.kind == "op" + and expr.op == "eq" + and len(expr.args) == 2 + and ((expr.args[0].kind == "op" and expr.args[0].op == "mod") or (expr.args[1].kind == "op" and expr.args[1].op == "mod")) + ): + return "even_mask" + return "generic" + + +def _build_parallel_execution_plan( + region: ParallelRegionGraph, + *, + program: "TileTensorProgram", +) -> ParallelExecutionPlan: + ext_i, ext_j, ext_k = (int(dim) for dim in region.extents) + elem_width = _infer_parallel_elem_width(region=region, program=program) + if elem_width <= 0 or program.mlen % elem_width != 0: + raise ValueError( + f"parallel execution plan requires elem_width to divide mlen: elem_width={elem_width}, mlen={program.mlen}" + ) + if elem_width == int(program.mlen): + k_step = int(program.mlen) + k_count_per_cycle = 1 + else: + k_step = max(1, program.mlen // elem_width) + k_count_per_cycle = k_step + cycle_groups: List[ParallelCycleGroup] = [] + cycle_plans: List[ParallelCyclePlan] = [] + pack_axis1_lanes = elem_width < int(program.mlen) and ext_k == elem_width + if pack_axis1_lanes: + if ext_j % k_count_per_cycle != 0: + raise NotImplementedError( + "parallel execution plan currently requires axis-1 lane packing to divide ext_j evenly; " + f"ext_j={ext_j}, lanes_per_cycle={k_count_per_cycle}, elem_width={elem_width}, mlen={program.mlen}" + ) + for i_index in range(ext_i): + for j_index in range(0, ext_j, k_count_per_cycle): + group = ParallelCycleGroup( + i_index=int(i_index), + j_index=int(j_index), + k_base=0, + k_count=int(k_count_per_cycle), + elem_width=int(elem_width), + element_count=int(k_count_per_cycle * elem_width), + ) + cycle_groups.append(group) + cycle_plans.append(_build_parallel_cycle_plan(region=region, group=group)) + else: + for i_index in range(ext_i): + for j_index in range(ext_j): + for k_base in range(0, ext_k, k_step): + if elem_width == int(program.mlen): + k_count = 1 + element_count = int(program.mlen) + else: + k_count = min(k_count_per_cycle, ext_k - k_base) + element_count = int(k_count * elem_width) + group = ParallelCycleGroup( + i_index=int(i_index), + j_index=int(j_index), + k_base=int(k_base), + k_count=int(k_count), + elem_width=int(elem_width), + element_count=element_count, + ) + cycle_groups.append(group) + cycle_plans.append(_build_parallel_cycle_plan(region=region, group=group)) + return ParallelExecutionPlan( + region_name=region.name, + cycle_groups=cycle_groups, + cycle_plans=cycle_plans, + metadata={ + "elem_width": int(elem_width), + "k_count_per_cycle": int(k_count_per_cycle), + "cycle_element_budget": int(program.mlen), + }, + ) + + +def _infer_parallel_elem_width( + *, + region: ParallelRegionGraph, + program: "TileTensorProgram", +) -> int: + candidate_widths: set[int] = set() + for assignment in region.assignments: + dst_shape = tuple(getattr(assignment.dst.base, "logical_shape", ())) + if len(dst_shape) < 1: + continue + candidate_widths.add(int(dst_shape[-1])) + if not candidate_widths: + return int(program.mlen) + if len(candidate_widths) != 1: + raise ValueError( + f"parallel execution plan currently expects one unified innermost width, got {sorted(candidate_widths)}" + ) + innermost_width = int(next(iter(candidate_widths))) + if innermost_width % int(program.mlen) == 0: + return int(program.mlen) + if innermost_width == int(program.btmm_hlen): + return int(program.btmm_hlen) + raise ValueError( + "parallel execution plan only supports innermost width that is either " + f"btmm_hlen ({int(program.btmm_hlen)}) or a multiple of mlen ({int(program.mlen)}); " + f"got {innermost_width}" + ) + + +def _build_parallel_cycle_plan( + *, + region: ParallelRegionGraph, + group: ParallelCycleGroup, +) -> ParallelCyclePlan: + input_slot_map: Dict[str, int] = {} + input_slots: List[ParallelInputCacheSlotPlan] = [] + output_slot_map: Dict[str, int] = {} + output_slots: List[ParallelOutputCacheSlotPlan] = [] + load_ops: List[ParallelLoadOp] = [] + compute_ops: List[ParallelComputeOp] = [] + writeback_ops: List[ParallelWritebackOp] = [] + + for assignment in region.assignments: + dst_access = assignment.dst + dst_identity = _parallel_access_identity(dst_access) + dst_slot_id = output_slot_map.get(dst_identity) + if dst_slot_id is None: + dst_slot_id = len(output_slots) + output_slot_map[dst_identity] = dst_slot_id + output_slots.append( + ParallelOutputCacheSlotPlan( + slot_id=dst_slot_id, + access=dst_access, + metadata={ + "group_i": group.i_index, + "group_j": group.j_index, + "group_k_base": group.k_base, + "group_k_count": group.k_count, + }, + ) + ) + writeback_ops.append( + ParallelWritebackOp( + slot_id=dst_slot_id, + access=dst_access, + metadata={ + "writeback_kind": "value_view_update", + "group_i": group.i_index, + "group_j": group.j_index, + "group_k_base": group.k_base, + "group_k_count": group.k_count, + }, + ) + ) + + input_slot_ids: List[int] = [] + for access in assignment.sources: + access_identity = _parallel_access_identity(access) + slot_id = input_slot_map.get(access_identity) + if slot_id is None: + slot_id = len(input_slots) + input_slot_map[access_identity] = slot_id + load_metadata = _infer_parallel_load_metadata(access) + source_kind = "direct_fpfragment" if isinstance(access.base, Vector) else "mapv_to_fpram" + input_slots.append( + ParallelInputCacheSlotPlan( + slot_id=slot_id, + access=access, + pattern_kind="uniform", + metadata={ + **load_metadata, + "source_kind": source_kind, + "group_i": group.i_index, + "group_j": group.j_index, + "group_k_base": group.k_base, + "group_k_count": group.k_count, + }, + ) + ) + if not isinstance(access.base, Vector): + load_ops.append( + ParallelLoadOp( + slot_id=slot_id, + access=access, + metadata={ + **load_metadata, + "group_i": group.i_index, + "group_j": group.j_index, + "group_k_base": group.k_base, + "group_k_count": group.k_count, + }, + ) + ) + input_slot_ids.append(slot_id) + + predicate_kinds = _collect_parallel_predicates(assignment.expr) + compute_ops.append( + ParallelComputeOp( + task_id=assignment.task_id, + dst_slot_id=dst_slot_id, + expr=assignment.expr, + input_slot_ids=input_slot_ids, + metadata={ + "predicate_kinds": predicate_kinds, + "processing_kind": "per_output_element_ordered", + }, + ) + ) + + return ParallelCyclePlan( + group=group, + input_slots=input_slots, + output_slots=output_slots, + load_ops=load_ops, + compute_ops=compute_ops, + writeback_ops=writeback_ops, + metadata={ + "writeback_at_cycle_end": True, + "processing_mode": "per_output_element_ordered", + }, + ) + + +def _iter_fp_indices(shape: Tuple[int, ...]) -> List[FPIndex]: + if not shape: + return [()] + indices: List[FPIndex] = [()] + for dim in shape: + next_indices: List[FPIndex] = [] + for prefix in indices: + for value in range(int(dim)): + next_indices.append(prefix + (value,)) + indices = next_indices + return indices + + +def _iter_logical_indices(shape: Tuple[int, ...]) -> List[Tuple[int, ...]]: + if not shape: + return [()] + indices: List[Tuple[int, ...]] = [()] + for dim in shape: + next_indices: List[Tuple[int, ...]] = [] + for prefix in indices: + for value in range(int(dim)): + next_indices.append(prefix + (value,)) + indices = next_indices + return indices + + +def _iter_selected_logical_indices( + shape: Tuple[int, ...], + selectors: Tuple[SliceItem, ...], +) -> List[Tuple[int, ...]]: + normalized = list(selectors) + [slice(None)] * max(0, len(shape) - len(selectors)) + selected: List[Tuple[int, ...]] = [] + for logical_index in _iter_logical_indices(shape): + keep = True + for dim_idx, selector in enumerate(normalized[: len(shape)]): + start, stop = _slice_item_to_range(selector, int(shape[dim_idx])) + if logical_index[dim_idx] < start or logical_index[dim_idx] >= stop: + keep = False + break + if keep: + selected.append(logical_index) + return selected + + +def _format_fp_index(index: FPIndex) -> str: + return "".join(f"[{value}]" for value in index) + + +def _require_fp_addr(fp_var: FPVar) -> int: + if fp_var.fp_mem_addr is None: + raise RuntimeError(f"FPVar {fp_var.name!r} has no fp_mem_addr") + return int(fp_var.fp_mem_addr) + + +def _fp_fragment_shape_to_tile_shape( + shape: Tuple[int, ...], + *, + mlen: int, + btmm_hlen: int, +) -> Tuple[int, int]: + normalized = tuple(int(dim) for dim in shape) + if len(normalized) == 2 and 0 < normalized[0] <= mlen and 0 < normalized[1] <= mlen: + return normalized[0], normalized[1] + if normalized == (mlen, mlen): + return mlen, mlen + if btmm_hlen > 0: + expected_lane_count = mlen // btmm_hlen + if normalized == (mlen, expected_lane_count, btmm_hlen): + return mlen, mlen + raise ValueError( + "fpram-interacting FPFragment must have shape " + f"({mlen}, {mlen}) or ({mlen}, {mlen // btmm_hlen if btmm_hlen > 0 else 'invalid'}, {btmm_hlen}), " + f"got {shape}" + ) + + +def _fp_fragment_row_fp_vars( + fragment: FPFragment, + *, + row_index: int, + row_width: int, + btmm_hlen: int, +) -> List[FPVar]: + shape = tuple(int(dim) for dim in fragment.shape) + if row_index < 0 or row_index >= int(shape[0]): + raise IndexError(f"row_index {row_index} out of range for FPFragment {fragment.name!r} with shape {shape}") + if shape == (row_width, row_width): + return [fragment.vars[(row_index, col_index)] for col_index in range(int(row_width))] + if btmm_hlen > 0: + packed_head_count = row_width // btmm_hlen + if shape == (row_width, packed_head_count, btmm_hlen): + row_vars: List[FPVar] = [] + for head_index in range(packed_head_count): + for col_index in range(btmm_hlen): + row_vars.append(fragment.vars[(row_index, head_index, col_index)]) + return row_vars + raise ValueError( + f"FPFragment {fragment.name!r} with shape {shape} cannot be materialized as one {row_width}-wide row" + ) diff --git a/tilelang_runtime_compier/tile_tensor_program/_isa_emitter.py b/tilelang_runtime_compier/tile_tensor_program/_isa_emitter.py new file mode 100644 index 0000000..f1f802b --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_isa_emitter.py @@ -0,0 +1,802 @@ +"""ISAEmitter: turns prepared tile/FP operations into ISA strings. + +Owns all `emit_*` methods (HBM/VRAM transfer, BTMM, matmul, FP kernels, +row operations, etc.). Managers and TileTensorProgram hold a reference +to an ISAEmitter and call its methods directly rather than going through +the program object. +""" + +from __future__ import annotations + +import math +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +from compiler.asm_templates import preload_addr_reg_asm, reset_reg_asm + +from ._types import * # noqa: F401,F403 +from ._helpers import * # noqa: F401,F403 + + +class ISAEmitter: + """Emit ISA strings for already-prepared tensor/FP operations.""" + + def __init__(self, program: "TileTensorProgram") -> None: + self.program = program + + def _emit_preload_tile_isa( + self, + *, + vlen: int, + preload_len: int, + batch: int, + hidden_size: int, + act_vram_offset: int, + alive_registers: List[int], + activation_offset_reg: int, + stride_size: Optional[int] = None, + scale_size: Optional[int] = None, + hbm_start_offset: int = 0, + ) -> str: + generated_code = "; Preload Activation Generation \n" + a_actual_register = alive_registers[0] + set_stride_register = alive_registers[1] + result_register = alive_registers[2] + outer_loop_register = alive_registers[3] + inner_loop_register = alive_registers[4] + + stride_len = vlen if stride_size is None else int(stride_size) + scale_len = hidden_size * batch if scale_size is None else int(scale_size) + load_amount_per_hidden = math.ceil(hidden_size / vlen) + + generated_code += f"S_ADDI_INT gp{a_actual_register}, gp0, {scale_len} \n" + generated_code += f"C_SET_SCALE_REG gp{a_actual_register} \n" + generated_code += f"S_ADDI_INT gp{a_actual_register}, gp0, {int(hbm_start_offset)} \n" + generated_code += f"S_ADDI_INT gp{result_register}, gp0, {act_vram_offset} \n" + + if batch == 1: + elements_per_prefetch = vlen * preload_len + for _ in range(math.ceil(hidden_size / elements_per_prefetch)): + generated_code += ( + f"H_PREFETCH_V gp{result_register}, gp{a_actual_register}, " + f"a{activation_offset_reg}, 0, 0, 0 \n" + ) + generated_code += ( + f"S_ADDI_INT gp{result_register}, gp{result_register}, {elements_per_prefetch} \n" + ) + generated_code += ( + f"S_ADDI_INT gp{a_actual_register}, gp{a_actual_register}, {elements_per_prefetch} \n" + ) + return generated_code + + generated_code += f"S_ADDI_INT gp{set_stride_register}, gp0, {stride_len} \n" + generated_code += f"C_SET_STRIDE_REG gp{set_stride_register} \n" + a_offset_register = set_stride_register + generated_code += f"C_LOOP_START gp{outer_loop_register}, {load_amount_per_hidden} \n" + generated_code += f"S_ADDI_INT gp{a_offset_register}, gp{a_actual_register}, 0 \n" + if batch > preload_len: + generated_code += f"C_LOOP_START gp{inner_loop_register}, {math.ceil(batch / preload_len)} \n" + generated_code += f"H_PREFETCH_V gp{result_register}, gp{a_offset_register}, a{activation_offset_reg}, 1, 0 \n" + generated_code += f"S_ADDI_INT gp{result_register}, gp{result_register}, {vlen * preload_len} \n" + if batch > preload_len: + generated_code += ( + f"S_ADDI_INT gp{a_offset_register}, gp{a_offset_register}, {stride_len * preload_len} \n" + ) + generated_code += f"C_LOOP_END gp{inner_loop_register} \n" + generated_code += f"S_ADDI_INT gp{a_actual_register}, gp{a_actual_register}, {vlen} \n" + generated_code += f"C_LOOP_END gp{outer_loop_register} \n" + return generated_code + + def _emit_store_tile_isa( + self, + *, + vlen: int, + batch: int, + hidden_size: int, + alive_registers: List[int], + act_vram_offset: int, + hbm_addr_reg: int, + stride_size: Optional[int] = None, + scale_size: Optional[int] = None, + hbm_start_offset: int = 0, + store_amount: int = 4, + ) -> str: + generated_code = "; Store Activation Generation\n" + + hbm_offset_reg = alive_registers[0] + set_stride_register = alive_registers[1] + vram_reg = alive_registers[2] + outer_loop_register = alive_registers[3] + inner_loop_register = alive_registers[4] + + stride_len = hidden_size if stride_size is None else int(stride_size) + scale_len = hidden_size * batch if scale_size is None else int(scale_size) + store_amount_per_hidden = math.ceil(hidden_size / vlen) + + generated_code += f"S_ADDI_INT gp{vram_reg}, gp0, {act_vram_offset}\n" + generated_code += f"S_ADDI_INT gp{hbm_offset_reg}, gp0, {scale_len}\n" + generated_code += f"C_SET_SCALE_REG gp{hbm_offset_reg}\n" + generated_code += f"S_ADDI_INT gp{hbm_offset_reg}, gp0, {int(hbm_start_offset)}\n" + + if batch == 1: + elements_per_store = vlen * store_amount + for _ in range(math.ceil(hidden_size / elements_per_store)): + generated_code += f"H_STORE_V gp{vram_reg}, gp{hbm_offset_reg}, a{hbm_addr_reg}, 0, 0\n" + generated_code += f"S_ADDI_INT gp{vram_reg}, gp{vram_reg}, {elements_per_store}\n" + generated_code += f"S_ADDI_INT gp{hbm_offset_reg}, gp{hbm_offset_reg}, {elements_per_store}\n" + return generated_code + + generated_code += f"S_ADDI_INT gp{set_stride_register}, gp0, {stride_len}\n" + generated_code += f"C_SET_STRIDE_REG gp{set_stride_register}\n" + hbm_base_reg = set_stride_register + generated_code += f"C_LOOP_START gp{outer_loop_register}, {store_amount_per_hidden}\n" + generated_code += f"S_ADDI_INT gp{hbm_base_reg}, gp{hbm_offset_reg}, 0\n" + if batch > store_amount: + generated_code += f"C_LOOP_START gp{inner_loop_register}, {math.ceil(batch / store_amount)}\n" + generated_code += f"H_STORE_V gp{vram_reg}, gp{hbm_base_reg}, a{hbm_addr_reg}, 1, 0\n" + generated_code += f"S_ADDI_INT gp{vram_reg}, gp{vram_reg}, {vlen * store_amount}\n" + if batch > store_amount: + generated_code += f"S_ADDI_INT gp{hbm_base_reg}, gp{hbm_base_reg}, {stride_len * store_amount}\n" + generated_code += f"C_LOOP_END gp{inner_loop_register}\n" + generated_code += f"S_ADDI_INT gp{hbm_offset_reg}, gp{hbm_offset_reg}, {vlen}\n" + generated_code += f"C_LOOP_END gp{outer_loop_register}\n" + return generated_code + + def emit_hbm_tile_to_mram( + self, + *, + hbm_addr: int, + mram_addr: int, + hbm_offset: int = 0, + hbm_scale: Optional[int] = None, + hbm_stride: Optional[int] = None, + ) -> None: + addr_reg = self.program.compiler.register_allocator.allocate_addr(1)[0] + gp_addr = self.program.compiler.register_allocator.allocate_gp(2) + gp_exec = self.program.compiler.register_allocator.allocate_gp(3) + gp_scale, gp_stride, gp_mram = gp_exec + scale_val = self.program.tile_elems if hbm_scale is None else int(hbm_scale) + stride_val = self.program.mlen if hbm_stride is None else int(hbm_stride) + + isa = "" + isa += preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], + available_registers=gp_addr, + addr_reg_val=[hbm_addr], + ) + isa += f"S_ADDI_INT gp{gp_scale}, gp0, {scale_val}\n" + isa += f"C_SET_SCALE_REG gp{gp_scale}\n" + isa += f"S_ADDI_INT gp{gp_stride}, gp0, {stride_val}\n" + isa += f"C_SET_STRIDE_REG gp{gp_stride}\n" + isa += f"S_ADDI_INT gp{gp_mram}, gp0, {mram_addr}\n" + isa += f"S_ADDI_INT gp{gp_scale}, gp0, {hbm_offset}\n" + isa += f"H_PREFETCH_M gp{gp_mram}, gp{gp_scale}, a{addr_reg}, 1, 0\n" + isa += f"S_ADDI_INT gp{gp_scale}, gp0, {self.program.tile_elems}\n" + isa += f"C_SET_SCALE_REG gp{gp_scale}\n" + isa += f"S_ADDI_INT gp{gp_stride}, gp0, {self.program.mlen}\n" + isa += f"C_SET_STRIDE_REG gp{gp_stride}\n" + self.program.compiler.generated_code += isa + + self.program.compiler.register_allocator.free_gp(gp_addr) + self.program.compiler.register_allocator.free_gp(gp_exec) + self.program.compiler.register_allocator.free_addr([addr_reg]) + + def emit_load_tile_from_hbm( + self, + *, + hbm_addr: int, + vram_addr: int, + hbm_stride: Optional[int] = None, + hbm_scale_size: Optional[int] = None, + hbm_start_offset: int = 0, + ) -> None: + addr_reg = self.program.compiler.register_allocator.allocate_addr(1)[0] + gp_addr = self.program.compiler.register_allocator.allocate_gp(1) + gp_preload = self.program.compiler.register_allocator.allocate_gp(5) + + isa = "" + isa += preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], + available_registers=gp_addr, + addr_reg_val=[int(hbm_addr)], + ) + isa += reset_reg_asm(alive_registers=gp_preload) + isa += self._emit_preload_tile_isa( + vlen=self.program.mlen, + preload_len=self.program.blen, + batch=self.program.mlen, + hidden_size=self.program.mlen, + act_vram_offset=vram_addr, + alive_registers=gp_preload, + activation_offset_reg=addr_reg, + stride_size=self.program.mlen if hbm_stride is None else int(hbm_stride), + scale_size=self.program.tile_elems if hbm_scale_size is None else int(hbm_scale_size), + hbm_start_offset=int(hbm_start_offset), + ) + self.program.compiler.generated_code += isa + + self.program.compiler.register_allocator.free_gp(gp_addr) + self.program.compiler.register_allocator.free_gp(gp_preload) + self.program.compiler.register_allocator.free_addr([addr_reg]) + + def emit_store_tile_to_hbm( + self, + *, + vram_addr: int, + hbm_addr: int, + hbm_stride: Optional[int] = None, + hbm_scale_size: Optional[int] = None, + hbm_start_offset: int = 0, + ) -> None: + addr_reg = self.program.compiler.register_allocator.allocate_addr(1)[0] + gp_addr = self.program.compiler.register_allocator.allocate_gp(1) + gp_store = self.program.compiler.register_allocator.allocate_gp(5) + + isa = "" + isa += preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], + available_registers=gp_addr, + addr_reg_val=[int(hbm_addr)], + ) + isa += self._emit_store_tile_isa( + vlen=self.program.mlen, + batch=self.program.mlen, + hidden_size=self.program.mlen, + alive_registers=gp_store, + act_vram_offset=vram_addr, + hbm_addr_reg=addr_reg, + stride_size=self.program.mlen if hbm_stride is None else int(hbm_stride), + scale_size=self.program.tile_elems if hbm_scale_size is None else int(hbm_scale_size), + hbm_start_offset=int(hbm_start_offset), + store_amount=self.program.blen, + ) + self.program.compiler.generated_code += isa + + self.program.compiler.register_allocator.free_gp(gp_addr) + self.program.compiler.register_allocator.free_gp(gp_store) + self.program.compiler.register_allocator.free_addr([addr_reg]) + + def emit_zero_vram_tile(self, vram_addr: int) -> None: + gp_regs = self.program.compiler.register_allocator.allocate_gp(2) + gp, gp_loop = gp_regs + lines = [f"; zero tile vram[{vram_addr}]"] + lines.append(f"S_ADDI_INT gp{gp}, gp0, {vram_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {self.program.mlen}") + lines.append(f"V_MUL_VF gp{gp}, gp{gp}, f0, 0") + lines.append(f"S_ADDI_INT gp{gp}, gp{gp}, {self.program.mlen}") + lines.append(f"C_LOOP_END gp{gp_loop}") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_map_v_fp_tile( + self, + *, + vram_addr: int, + fpram_addr: int, + row_count: int, + row_width: int, + task_id: str = "map_v_fp_tile", + ) -> None: + if row_count <= 0 or row_width <= 0: + raise ValueError(f"emit_map_v_fp_tile expects positive row_count/row_width, got {row_count}/{row_width}") + if row_width != self.program.mlen: + raise ValueError( + f"emit_map_v_fp_tile currently requires row_width == mlen == {self.program.mlen}, got {row_width}" + ) + gp_regs = self.program.compiler.register_allocator.allocate_gp(3) + gp_dst, gp_src, gp_loop = gp_regs + lines = [f"; map fp tile task {task_id} fpram[{fpram_addr}] -> vram[{vram_addr}]"] + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {vram_addr}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {fpram_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") + lines.append(f"S_MAP_V_FP gp{gp_dst}, gp{gp_src}, 0") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_width}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_width}") + lines.append(f"C_LOOP_END gp{gp_loop}") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_map_fp_v_tile( + self, + *, + fpram_addr: int, + vram_addr: int, + row_count: int, + row_width: int, + task_id: str = "map_fp_v_tile", + ) -> None: + if row_count <= 0 or row_width <= 0: + raise ValueError(f"emit_map_fp_v_tile expects positive row_count/row_width, got {row_count}/{row_width}") + if row_width != self.program.mlen: + raise ValueError( + f"emit_map_fp_v_tile currently requires row_width == mlen == {self.program.mlen}, got {row_width}" + ) + gp_regs = self.program.compiler.register_allocator.allocate_gp(3) + gp_dst, gp_src, gp_loop = gp_regs + lines = [f"; map fp tile task {task_id} vram[{vram_addr}] -> fpram[{fpram_addr}]"] + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fpram_addr}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {row_count}") + lines.append(f"S_MAP_FP_V gp{gp_dst}, gp{gp_src}, 0") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {row_width}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {row_width}") + lines.append(f"C_LOOP_END gp{gp_loop}") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_btmm( + self, + *, + lhs_packed_vram_addr: int, + rhs_mram_addr: int, + task_id: str = "btmm", + ) -> None: + gp_regs = self.program.compiler.register_allocator.allocate_gp(2) + gp_mram_base, gp_lhs_base = gp_regs + lines = [ + ( + f"; btmm task {task_id} lhs_packed=vram[{lhs_packed_vram_addr}] " + f"rhs_mram={rhs_mram_addr} lanes={self.program.btmm_lane_count} head_width={self.program.btmm_hlen}" + ), + f"S_ADDI_INT gp{gp_mram_base}, gp0, {rhs_mram_addr}", + f"S_ADDI_INT gp{gp_lhs_base}, gp0, {lhs_packed_vram_addr}", + f"M_BTMM gp0, gp{gp_mram_base}, gp{gp_lhs_base}", + ] + self.program.compiler.generated_code += "\n".join(lines) + "\n" + self.program.compiler.register_allocator.free_gp(gp_regs) + + def emit_btmm_wo( + self, + *, + base_addr: int, + tile_count: int, + task_id: str = "btmm_wo", + ) -> None: + gp_out = self.program.compiler.register_allocator.allocate_gp(1)[0] + lines = [ + ( + f"; btmm write-only task {task_id} out=vram[{base_addr}] " + f"tiles={tile_count} lanes={self.program.btmm_lane_count} head_width={self.program.btmm_hlen}" + ), + f"S_ADDI_INT gp{gp_out}, gp0, {base_addr}", + f"M_BMM_WO gp{gp_out}, 0", + ] + self.program.compiler.generated_code += "\n".join(lines) + "\n" + self.program.compiler.register_allocator.free_gp([gp_out]) + + def emit_matmul( + self, + *, + lhs_vram_addrs: Sequence[int], + rhs_mram_addrs: Sequence[int], + dst_vram_addr: int, + task_id: str = "matmul", + zero_dst: bool = False, + ) -> None: + if len(lhs_vram_addrs) != len(rhs_mram_addrs): + raise ValueError("lhs_vram_addrs and rhs_mram_addrs must have equal lengths") + if zero_dst: + self.emit_zero_vram_tile(dst_vram_addr) + + gp_regs = self.program.compiler.register_allocator.allocate_gp(5) + gp_act, gp_mat, gp_out, gp_stride, gp_loop = gp_regs + tiles_per_mlen = self.program.mlen // self.program.blen + lines = [f"; matmul task {task_id}"] + lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") + lhs_prog = self.program._arith_progression([int(addr) for addr in lhs_vram_addrs]) + rhs_prog = self.program._arith_progression([int(addr) for addr in rhs_mram_addrs]) + + for oc in range(tiles_per_mlen): + for orow in range(tiles_per_mlen): + if lhs_prog is not None and rhs_prog is not None: + lhs_start, pair_count, lhs_step = lhs_prog + rhs_start, _, rhs_step = rhs_prog + act_addr = lhs_start + orow * self.program.blen * self.program.mlen + mat_addr = rhs_start + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_act}, gp0, {act_addr}") + lines.append(f"S_ADDI_INT gp{gp_mat}, gp0, {mat_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {pair_count}") + lines.append(f"M_MM 0, gp{gp_mat}, gp{gp_act}") + lines.append(f"S_ADDI_INT gp{gp_act}, gp{gp_act}, {lhs_step}") + lines.append(f"S_ADDI_INT gp{gp_mat}, gp{gp_mat}, {rhs_step}") + lines.append(f"C_LOOP_END gp{gp_loop}") + else: + for lhs_addr, rhs_addr in zip(lhs_vram_addrs, rhs_mram_addrs): + act_addr = lhs_addr + orow * self.program.blen * self.program.mlen + mat_addr = rhs_addr + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_act}, gp0, {act_addr}") + lines.append(f"S_ADDI_INT gp{gp_mat}, gp0, {mat_addr}") + lines.append(f"M_MM 0, gp{gp_mat}, gp{gp_act}") + out_addr = dst_vram_addr + orow * self.program.blen * self.program.mlen + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_out}, gp0, {out_addr}") + lines.append(f"M_MM_WO gp{gp_out}, gp0, 0") + + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_slot_matmul( + self, + *, + lhs_vram_addr: int, + rhs_mram_addr: int, + rhs_col_offset: int, + dst_vram_addr: int, + dst_col_offset: int, + col_count: int, + task_id: str = "slot_matmul", + zero_dst: bool = False, + ) -> None: + if col_count <= 0: + raise ValueError("emit_slot_matmul requires one positive col_count") + if col_count % self.program.blen != 0: + raise ValueError( + f"emit_slot_matmul requires col_count divisible by blen={self.program.blen}, got {col_count}" + ) + if zero_dst: + self.emit_zero_vram_tile(dst_vram_addr) + + gp_regs = self.program.compiler.register_allocator.allocate_gp(5) + gp_act, gp_mat, gp_out, gp_stride, gp_loop = gp_regs + tiles_per_mlen = self.program.mlen // self.program.blen + tiles_per_slot = col_count // self.program.blen + lines = [f"; slot matmul task {task_id}"] + lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") + + for oc in range(tiles_per_slot): + act_addr = lhs_vram_addr + mat_addr = rhs_mram_addr + rhs_col_offset + oc * self.program.blen + out_addr = dst_vram_addr + dst_col_offset + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_act}, gp0, {act_addr}") + lines.append(f"S_ADDI_INT gp{gp_mat}, gp0, {mat_addr}") + lines.append(f"S_ADDI_INT gp{gp_out}, gp0, {out_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {tiles_per_mlen}") + lines.append(f"M_MM 0, gp{gp_mat}, gp{gp_act}") + lines.append(f"M_MM_WO gp{gp_out}, gp0, 0") + lines.append(f"S_ADDI_INT gp{gp_act}, gp{gp_act}, {self.program.blen * self.program.mlen}") + lines.append(f"S_ADDI_INT gp{gp_out}, gp{gp_out}, {self.program.blen * self.program.mlen}") + lines.append(f"C_LOOP_END gp{gp_loop}") + + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_matmul_narrow_tile_hwloop( + self, + *, + lhs_vram_addr: int, + rhs_mram_addr: int, + dst_vram_addr: int, + hlen: int, + rhs_col_offset: int = 0, + dst_col_offset: int = 0, + dst_row_stride: Optional[int] = None, + task_id: str = "matmul_narrow_hwloop", + zero_dst: bool = False, + ) -> None: + """Emit `mlen x mlen @ mlen x hlen` through M_MM/M_MM_WO.""" + if hlen <= 0: + raise ValueError("emit_matmul_narrow_tile_hwloop requires positive hlen") + if hlen > self.program.mlen: + raise ValueError( + f"emit_matmul_narrow_tile_hwloop requires hlen <= mlen={self.program.mlen}, got {hlen}" + ) + if hlen % self.program.blen != 0: + raise ValueError( + f"emit_matmul_narrow_tile_hwloop requires hlen divisible by blen={self.program.blen}, got {hlen}" + ) + if dst_row_stride is None: + dst_row_stride = int(hlen) + if dst_row_stride < hlen: + raise ValueError( + f"emit_matmul_narrow_tile_hwloop requires dst_row_stride >= hlen ({hlen}), got {dst_row_stride}" + ) + if zero_dst: + self.emit_zero_vram_tile(dst_vram_addr) + + gp_regs = self.program.compiler.register_allocator.allocate_gp(5) + gp_act, gp_mat, gp_out, gp_stride, gp_loop = gp_regs + tiles_per_mlen = self.program.mlen // self.program.blen + tiles_per_slot = hlen // self.program.blen + output_row_stride = self.program.blen * int(dst_row_stride) + lines = [ + f"; narrow matmul task {task_id} lhs=vram[{lhs_vram_addr}] " + f"rhs=mram[{rhs_mram_addr}] rhs_col_offset={rhs_col_offset} " + f"dst=vram[{dst_vram_addr}] dst_col_offset={dst_col_offset} " + f"hlen={hlen} dst_row_stride={dst_row_stride}" + ] + lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") + + for oc in range(tiles_per_slot): + act_addr = lhs_vram_addr + mat_addr = rhs_mram_addr + rhs_col_offset + oc * self.program.blen + out_addr = dst_vram_addr + dst_col_offset + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_act}, gp0, {act_addr}") + lines.append(f"S_ADDI_INT gp{gp_mat}, gp0, {mat_addr}") + lines.append(f"S_ADDI_INT gp{gp_out}, gp0, {out_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {tiles_per_mlen}") + lines.append(f"M_MM 0, gp{gp_mat}, gp{gp_act}") + lines.append(f"M_MM_WO gp{gp_out}, gp0, 0") + lines.append(f"S_ADDI_INT gp{gp_act}, gp{gp_act}, {self.program.blen * self.program.mlen}") + lines.append(f"S_ADDI_INT gp{gp_out}, gp{gp_out}, {output_row_stride}") + lines.append(f"C_LOOP_END gp{gp_loop}") + + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_tile_binary( + self, + *, + lhs_vram_addr: int, + rhs_vram_addr: int, + dst_vram_addr: int, + op: str = "add", + task_id: str = "tile_binary", + ) -> None: + op_to_insn = { + "add": "V_ADD_VV", + "sub": "V_SUB_VV", + "mul": "V_MUL_VV", + } + if op not in op_to_insn: + raise ValueError(f"Unsupported tile binary op={op!r}") + gp_regs = self.program.compiler.register_allocator.allocate_gp(4) + gp_dst, gp_lhs, gp_rhs, gp_loop = gp_regs + lines = [f"; tile binary task {task_id} op={op}"] + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_vram_addr}") + lines.append(f"S_ADDI_INT gp{gp_lhs}, gp0, {lhs_vram_addr}") + lines.append(f"S_ADDI_INT gp{gp_rhs}, gp0, {rhs_vram_addr}") + lines.append(f"C_LOOP_START gp{gp_loop}, {self.program.mlen}") + if op == "sub": + lines.append(f"{op_to_insn[op]} gp{gp_dst}, gp{gp_rhs}, gp{gp_lhs}, 0") + else: + lines.append(f"{op_to_insn[op]} gp{gp_dst}, gp{gp_lhs}, gp{gp_rhs}, 0") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.program.mlen}") + lines.append(f"S_ADDI_INT gp{gp_lhs}, gp{gp_lhs}, {self.program.mlen}") + lines.append(f"S_ADDI_INT gp{gp_rhs}, gp{gp_rhs}, {self.program.mlen}") + lines.append(f"C_LOOP_END gp{gp_loop}") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_tile_add( + self, + *, + lhs_vram_addr: int, + rhs_vram_addr: int, + dst_vram_addr: int, + task_id: str = "tile_add", + ) -> None: + self.emit_tile_binary( + lhs_vram_addr=lhs_vram_addr, + rhs_vram_addr=rhs_vram_addr, + dst_vram_addr=dst_vram_addr, + op="add", + task_id=task_id, + ) + + def emit_fp_kernel( + self, + *, + src1_addrs: Sequence[int], + dst_addrs: Sequence[int], + src2_addrs: Optional[Sequence[int]] = None, + op: str, + task_id: str = "fp_kernel", + ) -> None: + unary_copy = {"copy", "fill"} + unary_math = {"exp": "S_EXP_FP", "reci": "S_RECI_FP", "sqrt": "S_SQRT_FP"} + binary_math = {"add": "S_ADD_FP", "sub": "S_SUB_FP", "mul": "S_MUL_FP", "max": "S_MAX_FP"} + if len(src1_addrs) != len(dst_addrs): + raise ValueError("emit_fp_kernel expects matched src1/dst lengths") + if src2_addrs is not None and len(src2_addrs) != len(dst_addrs): + raise ValueError("emit_fp_kernel expects matched src2/dst lengths") + if op in unary_copy: + gp_regs = self.program.compiler.register_allocator.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + lines = [f"; fp kernel task {task_id} op={op}"] + src_prog = self.program._arith_progression([int(addr) for addr in src1_addrs]) + dst_prog = self.program._arith_progression([int(addr) for addr in dst_addrs]) + if src_prog is not None and dst_prog is not None: + src_start, count, src_step = src_prog + dst_start, _, dst_step = dst_prog + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_start}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_start}") + lines.append(f"C_LOOP_START gp{gp_loop}, {count}") + lines.append(f"S_LD_FP f1, gp{gp_src}, 0") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {src_step}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {dst_step}") + lines.append(f"C_LOOP_END gp{gp_loop}") + else: + for src_addr, dst_addr in zip(src1_addrs, dst_addrs): + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {int(src_addr)}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_src}, 0") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + return + if op in unary_math: + gp_regs = self.program.compiler.register_allocator.allocate_gp(3) + gp_src, gp_dst, gp_loop = gp_regs + lines = [f"; fp kernel task {task_id} op={op}"] + src_prog = self.program._arith_progression([int(addr) for addr in src1_addrs]) + dst_prog = self.program._arith_progression([int(addr) for addr in dst_addrs]) + if src_prog is not None and dst_prog is not None: + src_start, count, src_step = src_prog + dst_start, _, dst_step = dst_prog + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {src_start}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_start}") + lines.append(f"C_LOOP_START gp{gp_loop}, {count}") + lines.append(f"S_LD_FP f1, gp{gp_src}, 0") + if op in {"exp", "reci"}: + lines.append(f"{unary_math[op]} f1, f1, 0") + else: + lines.append(f"{unary_math[op]} f1, f1") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {src_step}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {dst_step}") + lines.append(f"C_LOOP_END gp{gp_loop}") + else: + for src_addr, dst_addr in zip(src1_addrs, dst_addrs): + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {int(src_addr)}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_src}, 0") + if op in {"exp", "reci"}: + lines.append(f"{unary_math[op]} f1, f1, 0") + else: + lines.append(f"{unary_math[op]} f1, f1") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + return + if op in binary_math: + if src2_addrs is None: + raise ValueError(f"emit_fp_kernel op={op!r} requires src2_addrs") + gp_regs = self.program.compiler.register_allocator.allocate_gp(4) + gp_a, gp_b, gp_dst, gp_loop = gp_regs + lines = [f"; fp kernel task {task_id} op={op}"] + src1_prog = self.program._arith_progression([int(addr) for addr in src1_addrs]) + src2_prog = self.program._arith_progression([int(addr) for addr in src2_addrs]) + dst_prog = self.program._arith_progression([int(addr) for addr in dst_addrs]) + if src1_prog is not None and src2_prog is not None and dst_prog is not None: + src1_start, count, src1_step = src1_prog + src2_start, _, src2_step = src2_prog + dst_start, _, dst_step = dst_prog + lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {src1_start}") + lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {src2_start}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_start}") + lines.append(f"C_LOOP_START gp{gp_loop}, {count}") + lines.append(f"S_LD_FP f1, gp{gp_a}, 0") + lines.append(f"S_LD_FP f2, gp{gp_b}, 0") + lines.append(f"{binary_math[op]} f1, f1, f2") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + lines.append(f"S_ADDI_INT gp{gp_a}, gp{gp_a}, {src1_step}") + lines.append(f"S_ADDI_INT gp{gp_b}, gp{gp_b}, {src2_step}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {dst_step}") + lines.append(f"C_LOOP_END gp{gp_loop}") + else: + for src1_addr, src2_addr, dst_addr in zip(src1_addrs, src2_addrs, dst_addrs): + lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {int(src1_addr)}") + lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {int(src2_addr)}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_a}, 0") + lines.append(f"S_LD_FP f2, gp{gp_b}, 0") + lines.append(f"{binary_math[op]} f1, f1, f2") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + return + raise ValueError(f"Unsupported emit_fp_kernel op={op!r}") + + def emit_row_operation( + self, + *, + src_vram_addr: int, + dst_vram_addr: Optional[int] = None, + op: str, + row_count: int, + dst_addrs: Optional[Sequence[int]] = None, + rhs_addrs: Optional[Sequence[int]] = None, + mask_val: Optional[int] = None, + task_id: str = "row_operations", + ) -> None: + if row_count <= 0: + return + unary_ops = {"exp", "reci"} + reduce_ops = {"reduce_max": "V_RED_MAX", "reduce_sum": "V_RED_SUM"} + binary_ops = {"mul": "V_MUL_VF", "add": "V_ADD_VF", "sub": "V_SUB_VF"} + if op not in unary_ops | set(reduce_ops) | set(binary_ops): + raise ValueError(f"Unsupported emit_row_operation op={op!r}") + + gp_regs = self.program.compiler.register_allocator.allocate_gp(5) + gp_src, gp_fp, gp_dst, gp_loop, gp_mask = gp_regs + lines = [f"; row operation task {task_id} op={op} rows={row_count}"] + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {int(src_vram_addr)}") + dst_vram_addr = int(src_vram_addr if dst_vram_addr is None else dst_vram_addr) + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_vram_addr}") + use_mask = mask_val is not None + if use_mask: + lines.append(f"; row operation mask {int(mask_val)}") + lines.append(f"S_ADDI_INT gp{gp_mask}, gp0, {int(mask_val)}") + lines.append(f"C_SET_V_MASK_REG gp{gp_mask}") + + if op in unary_ops: + lines.append(f"C_LOOP_START gp{gp_loop}, {int(row_count)}") + if op == "exp": + lines.append(f"V_EXP_V gp{gp_dst}, gp{gp_src}, {1 if use_mask else 0}") + else: + lines.append(f"V_RECI_V gp{gp_dst}, gp{gp_src}, {1 if use_mask else 0}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {self.program.mlen}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.program.mlen}") + lines.append(f"C_LOOP_END gp{gp_loop}") + elif op in reduce_ops: + if dst_addrs is None or len(dst_addrs) != row_count: + raise ValueError(f"emit_row_operation op={op!r} expects one dst fp addr per row") + dst_prog = self.program._arith_progression([int(addr) for addr in dst_addrs]) + if dst_prog is None: + for row_index, dst_addr in enumerate(dst_addrs): + row_addr = int(src_vram_addr) + row_index * self.program.mlen + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_dst}, 0") + lines.append(f"{reduce_ops[op]} f1, gp{gp_src}, {1 if use_mask else 0}") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + else: + dst_start, count, dst_step = dst_prog + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_start}") + lines.append(f"C_LOOP_START gp{gp_loop}, {count}") + lines.append(f"S_LD_FP f1, gp{gp_dst}, 0") + lines.append(f"{reduce_ops[op]} f1, gp{gp_src}, {1 if use_mask else 0}") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {self.program.mlen}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {dst_step}") + lines.append(f"C_LOOP_END gp{gp_loop}") + else: + if rhs_addrs is None or len(rhs_addrs) not in (1, row_count): + raise ValueError(f"emit_row_operation op={op!r} expects one rhs fp addr or one per row") + rhs_prog = self.program._arith_progression([int(addr) for addr in rhs_addrs]) if len(rhs_addrs) > 1 else None + if len(rhs_addrs) == 1: + lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {int(rhs_addrs[0])}") + lines.append(f"C_LOOP_START gp{gp_loop}, {int(row_count)}") + lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") + if op == "sub": + lines.append(f"V_SUB_VF gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}, 0") + else: + lines.append(f"{binary_ops[op]} gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {self.program.mlen}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.program.mlen}") + lines.append(f"C_LOOP_END gp{gp_loop}") + elif rhs_prog is not None: + rhs_start, count, rhs_step = rhs_prog + lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {rhs_start}") + lines.append(f"C_LOOP_START gp{gp_loop}, {count}") + lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") + if op == "sub": + lines.append(f"V_SUB_VF gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}, 0") + else: + lines.append(f"{binary_ops[op]} gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp{gp_src}, {self.program.mlen}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, {self.program.mlen}") + lines.append(f"S_ADDI_INT gp{gp_fp}, gp{gp_fp}, {rhs_step}") + lines.append(f"C_LOOP_END gp{gp_loop}") + else: + for row_index, rhs_addr in enumerate(rhs_addrs): + row_addr = int(src_vram_addr) + row_index * self.program.mlen + dst_row_addr = dst_vram_addr + row_index * self.program.mlen + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_row_addr}") + lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {int(rhs_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") + if op == "sub": + lines.append(f"V_SUB_VF gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}, 0") + else: + lines.append(f"{binary_ops[op]} gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}") + + if use_mask: + lines.append("S_ADDI_INT gp{0}, gp0, 0".format(gp_mask)) + lines.append(f"C_SET_V_MASK_REG gp{gp_mask}") + + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" diff --git a/tilelang_runtime_compier/tile_tensor_program/_program.py b/tilelang_runtime_compier/tile_tensor_program/_program.py new file mode 100644 index 0000000..5091b7b --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_program.py @@ -0,0 +1,1573 @@ +"""TileTensorProgram: top-level user-facing program builder.""" + +from __future__ import annotations + +import inspect +import sys +from math import ceil +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +from compiler.asm_templates import preload_addr_reg_asm +from tiled_developer_compiler import TiledDeveloperCompiler +from operation_report_delta import build_delta_report, parse_operation_report + +from ._types import * # noqa: F401,F403 +from ._helpers import * # noqa: F401,F403 +from ._hardware_manager import HardwareManager +from ._isa_emitter import ISAEmitter +from ._thread_manager import ThreadManager, _LoopHintRange +from ._value_manager import ValueManager +from ._tensor_manager import TensorManager +from ._vector_manager import VectorManager +from ._compute_manager import ComputeManager + + +class TileTensorProgram: + """User-facing program builder over the logical/value/compute pipeline. + + This class exposes the testbench authoring API (`input`, `tensor`, `copy`, + `matmul`, `atomic_add`, FP helpers, reporting, and compile hooks) while + delegating the real work to TensorManager, ValueManager, and ComputeManager. + + In practice it acts as the orchestration layer for the current runtime law: + + mapt -> mapv -> compute -> mapv_back -> mapt_back + + The important modern write-path rule is: + + resolve view -> prepare_updated_view_value -> compute -> bind/writeback + + FP-domain operations are intentionally separate from the tensor + value/view pipeline. + """ + + def __init__( + self, + *, + mlen: int, + blen: int, + btmm_hlen: Optional[int] = None, + real_data_ratio: float = 1.0, + vram_tile_capacity: int = 0, + mram_tile_capacity: int = 0, + fpram_capacity: int = 0, + hbm_base_addr: int = 0, + ) -> None: + self.mlen = int(mlen) + self.blen = int(blen) + self.btmm_hlen = int(btmm_hlen) if btmm_hlen is not None else (self.mlen // self.blen) + if self.btmm_hlen <= 0 or self.mlen % self.btmm_hlen != 0: + raise ValueError( + f"Invalid btmm_hlen={self.btmm_hlen}; require positive divisor of mlen={self.mlen}" + ) + self.btmm_lane_count = self.mlen // self.btmm_hlen + self.real_data_ratio = float(real_data_ratio) + self.vram_tile_capacity = int(vram_tile_capacity) + self.mram_tile_capacity = int(mram_tile_capacity) + self.fpram_capacity = int(fpram_capacity) + self.tile_elems = self.mlen * self.mlen + self._next_hbm_addr = int(hbm_base_addr) + + self.compiler = TiledDeveloperCompiler( + mlen=self.mlen, + blen=self.blen, + fpram_total_size=(self.fpram_capacity or 1024), + ) + self.hardware = HardwareManager(self) + self.isa_emitter = ISAEmitter(self) + self.thread_manager = ThreadManager(self) + self.value_manager = ValueManager(self) + self.tensor_manager = TensorManager(self) + self.vector_manager = VectorManager(self) + self.compute_manager = ComputeManager(self) + self._auto_name_counters: Dict[str, int] = {} + self.loop_hints: List[Dict[str, int | str]] = [] + self.operation_snapshots: List[Dict[str, object]] = [] + self.eviction_warnings: List[Dict[str, object]] = [] + self._active_parallel_region_ids: List[int] = [] + self._parallel_region_counter = 0 + self._parallel_snapshot_keys: set[Tuple[int, str, int, str]] = set() + self._parallel_execution_lowered = False + + def input(self, name: str, logical_shape: LogicalShape, *, hbm_addr: Optional[int] = None) -> Input: + return self.tensor_manager.input(name, logical_shape, hbm_addr=hbm_addr) + + def tensor(self, name: str, logical_shape: LogicalShape) -> Tensor | Vector: + tensor = self.tensor_manager.tensor(name, logical_shape) + if isinstance(tensor, Vector): + self.vector_manager.initialize_vector_backing(tensor) + return tensor + + def vector(self, name: str, logical_shape: LogicalShape) -> Vector: + vector = self.vector_manager.vector(name, logical_shape) + self.vector_manager.initialize_vector_backing(vector) + return vector + + def parallel_region3d( + self, + extents: Tuple[int, int, int] | List[int], + *, + name: Optional[str] = None, + ) -> _ParallelRegionScope: + return self.thread_manager.parallel_region3d(extents, name=name) + + def parallel_region2d( + self, + extents: Tuple[int, int] | List[int], + *, + name: Optional[str] = None, + ) -> _ParallelRegion2DScope: + return self.thread_manager.parallel_region2d(extents, name=name) + + def where(self, predicate: object, on_true: object, on_false: object) -> ParallelExpr: + return self.thread_manager.where(predicate, on_true, on_false) + + def if_then_else(self, predicate: object, on_true: object, on_false: object) -> ParallelExpr: + return self.thread_manager.if_then_else(predicate, on_true, on_false) + + def max(self, lhs: object, rhs: object) -> ParallelExpr: + return ParallelExpr( + kind="op", + op="max", + args=(_coerce_parallel_expr(lhs), _coerce_parallel_expr(rhs)), + ) + + def exp(self, operand: object) -> ParallelExpr: + return ParallelExpr(kind="unary_op", op="exp", args=(_coerce_parallel_expr(operand),)) + + def reci(self, operand: object) -> ParallelExpr: + return ParallelExpr(kind="unary_op", op="reci", args=(_coerce_parallel_expr(operand),)) + + def sqrt(self, operand: object) -> ParallelExpr: + return ParallelExpr(kind="unary_op", op="sqrt", args=(_coerce_parallel_expr(operand),)) + + def pair(self, axis: object) -> ParallelExpr: + return self.thread_manager.pair(axis) + + def half_index(self, axis: object) -> ParallelExpr: + return self.thread_manager.half_index(axis) + + def parallel_execution_plans(self) -> List[ParallelExecutionPlan]: + return self.thread_manager.parallel_execution_plans() + + def lower_parallel_execution_plans(self) -> None: + self.thread_manager.lower_parallel_execution_plans() + + + def _auto_name(self, prefix: str) -> str: + count = self._auto_name_counters.get(prefix, 0) + self._auto_name_counters[prefix] = count + 1 + return f"{prefix}.{count}" + + def fp_var(self, name: str, value: float = 0.0, size: int = 1) -> FPVar | FPFragment: + return self.tensor_manager.fp_var(name, value=value, size=size) + + def fp_fragment(self, name: str, shape: Tuple[int, ...] | int, *, init: float = 0.0) -> FPFragment: + return self.tensor_manager.fp_fragment(name=name, shape=shape, init=init) + + def alloc_fragment( + self, + name: str, + logical_shape: LogicalShape, + *, + init_zero: bool = False, + dtype: str = "fp32", + ) -> Tensor | Vector: + fragment = self.tensor_manager.alloc_fragment( + name=name, + logical_shape=logical_shape, + init_zero=init_zero, + dtype=dtype, + ) + if isinstance(fragment, Vector): + self.vector_manager.initialize_vector_backing(fragment, init_zero=init_zero) + if init_zero and isinstance(fragment, Tensor): + self.clear(fragment) + return fragment + + def alloc_shared( + self, + name: str, + logical_shape: LogicalShape, + *, + init_zero: bool = False, + dtype: str = "fp32", + ) -> Tensor | Vector: + shared = self.tensor_manager.alloc_shared( + name=name, + logical_shape=logical_shape, + init_zero=init_zero, + dtype=dtype, + ) + if isinstance(shared, Vector): + self.vector_manager.initialize_vector_backing(shared, init_zero=init_zero) + if init_zero and isinstance(shared, Tensor): + self.clear(shared) + return shared + + def constant(self, name: str, value: float, size: int = 1) -> FPVar | FPFragment: + return self.fp_var(name, value=value, size=size) + + def create_value_tile_in_fpram( + self, + tile: TensorTile | InputTile, + fragment: FPFragment, + *, + bind: bool = True, + metadata: Optional[Dict[str, object]] = None, + ) -> ValueTile: + return self.value_manager.create_value_tile_in_fpram_for_tile( + tile, + fragment, + bind=bind, + metadata=metadata, + ) + + def set_vram_class(self, operand: object, vram_class: str) -> None: + resolved_class = "shared" if str(vram_class) == "shared" else "l0" + if hasattr(operand, "metadata") and isinstance(getattr(operand, "metadata"), dict): + operand.metadata["vram_class"] = resolved_class + if isinstance(operand, (TensorTile, InputTile, VectorTile)): + tiles = [operand] + else: + tiles = [ + tile + for tile in self.tensor_manager._resolve_tiles_from_operand(operand) + if isinstance(tile, (TensorTile, InputTile, VectorTile)) + ] + for tile in tiles: + tile.metadata["vram_class"] = resolved_class + if isinstance(tile, VectorTile): + continue + value = self.value_manager.full_tile_bindings.get(tile.tile_id) + if value is None: + continue + bound_value = self.value_manager.value_tiles.get(value) + if bound_value is not None: + bound_value.metadata["vram_class"] = resolved_class + + def set_vram_priority(self, operand: object, priority: int) -> None: + resolved_priority = int(priority) + self.set_vram_class(operand, "shared" if resolved_priority > 0 else "l0") + + def clear_tensor(self, operand: object, *, weak: Optional[bool] = None) -> Optional[str] | List[str]: + if isinstance(operand, TensorTile) and not isinstance(operand, VectorTile): + return self.value_manager.free_tensor_tile(operand, weak=weak) + tiles = self.tensor_manager._resolve_tiles_from_operand(operand) + value_tile_ids: List[str] = [] + for tile in tiles: + if not isinstance(tile, TensorTile) or isinstance(tile, VectorTile): + raise TypeError( + f"clear_tensor expects Tensor/TensorSlice/TensorTile operands, got {type(tile).__name__}" + ) + value_tile_id = self.value_manager.free_tensor_tile(tile, weak=weak) + if value_tile_id is not None: + value_tile_ids.append(value_tile_id) + return value_tile_ids + + def free_tensor_tile(self, operand: object, *, weak: Optional[bool] = None) -> Optional[str] | List[str]: + return self.clear_tensor(operand, weak=weak) + + def map_tile_to_fp_fragment( + self, + tile: VectorTile, + fragment: FPFragment, + ) -> FPFragment: + return self.vector_manager.bind_tile_to_fp_fragment(tile, fragment) + + def _initialize_vector_backing(self, vector: Vector, *, init_zero: bool = False) -> None: + self.vector_manager.initialize_vector_backing(vector, init_zero=init_zero) + + def pipelined(self, extent: int, num_stages: int = 1) -> range: + self.loop_hints.append( + { + "kind": "pipelined", + "extent": int(extent), + "num_stages": int(num_stages), + } + ) + return _LoopHintRange(self, kind="pipelined", extent=int(extent)) + + def parallel(self, extent: int) -> range: + region_id = self._parallel_region_counter + self._parallel_region_counter += 1 + self.loop_hints.append( + { + "kind": "parallel", + "extent": int(extent), + "region_id": region_id, + } + ) + return _LoopHintRange(self, kind="parallel", extent=int(extent), region_id=region_id) + + def _sorted_value_tile_ids(self, place: str) -> List[str]: + if place == "vram": + items = sorted(self.value_manager._value_tiles_in_vram.items(), key=lambda item: (int(item[1]), item[0])) + return [value_tile_id for value_tile_id, _ in items] + if place == "mram": + items = sorted(self.value_manager._value_tiles_in_mram.items(), key=lambda item: (int(item[1]), item[0])) + return [value_tile_id for value_tile_id, _ in items] + if place == "hbm": + return sorted(self.value_manager._value_tiles_in_hbm.keys()) + raise ValueError(f"Unsupported value-tile residency place: {place}") + + def _tile_by_id(self, tile_id: str) -> Optional[TileLike]: + return ( + self.tensor_manager.tensor_tiles.get(tile_id) + or self.tensor_manager.input_tiles.get(tile_id) + or self.tensor_manager.vector_tiles.get(tile_id) + ) + + def _logical_row_segment_labels( + self, + logical_shape: LogicalShape, + row_start: int, + row_end: int, + ) -> List[str]: + if len(logical_shape) == 4: + batch, seq, _, _ = logical_shape + labels: List[str] = [] + for batch_index in range(int(batch)): + batch_row_start = batch_index * int(seq) + batch_row_end = batch_row_start + int(seq) + overlap_start = max(int(row_start), batch_row_start) + overlap_end = min(int(row_end), batch_row_end) + if overlap_start >= overlap_end: + continue + labels.append( + f"batch={batch_index},seq={overlap_start - batch_row_start}:{overlap_end - batch_row_start}" + ) + return labels + if len(logical_shape) == 3: + return [f"x={int(row_start)}:{int(row_end)}"] + return [f"row={int(row_start)}:{int(row_end)}"] + + def _tile_slice_labels(self, tile: TileLike) -> List[str]: + logical_shape = tuple(tile.metadata.get("logical_shape", ())) + owner_name = _tile_owner_name(tile) + row_start = int(tile.coord[0]) * int(self.mlen) + row_end = row_start + int(tile.tile_shape[0]) + row_labels = self._logical_row_segment_labels(logical_shape, row_start, row_end) + + if len(logical_shape) == 4: + head_dim = int(tile.metadata.get("head_dim", self.mlen)) + grouped_narrow = bool(tile.metadata.get("grouped_narrow")) + if grouped_narrow: + group_head_start = int(tile.metadata.get("group_head_start", 0)) + packed_head_count = int(tile.metadata.get("packed_head_count", 1)) + labels: List[str] = [] + for row_label in row_labels: + for head_offset in range(packed_head_count): + labels.append(f"{owner_name}[{row_label},head={group_head_start + head_offset},d=0:{head_dim}]") + return labels + + col_start = int(tile.coord[1]) * int(self.mlen) + col_end = col_start + int(tile.tile_shape[1]) + head_start = col_start // head_dim if head_dim > 0 else 0 + head_end = (col_end - 1) // head_dim + 1 if head_dim > 0 and col_end > col_start else head_start + d_start = col_start % head_dim if head_dim > 0 else 0 + d_end = col_end % head_dim if head_dim > 0 else 0 + if d_end == 0 and col_end > col_start: + d_end = head_dim + if head_end - head_start == 1: + col_label = f"head={head_start},d={d_start}:{d_end}" + elif d_start == 0 and d_end == head_dim: + col_label = f"head={head_start}:{head_end},d=0:{head_dim}" + else: + col_label = f"flat_col={col_start}:{col_end}" + return [f"{owner_name}[{row_label},{col_label}]" for row_label in row_labels] + + col_start = int(tile.coord[1]) * int(self.mlen) + col_end = col_start + int(tile.tile_shape[1]) + col_label = f"col={col_start}:{col_end}" + return [f"{owner_name}[{row_label},{col_label}]" for row_label in row_labels] + + def _value_tile_slice_refs_snapshot(self) -> Dict[str, List[str]]: + refs: Dict[str, List[str]] = {} + for value_tile_id in sorted(self.value_manager.value_tiles.keys()): + tile_ids = sorted(self.value_manager.value_tile_tensor_refs.get(value_tile_id, set())) + if not tile_ids: + continue + labels: set[str] = set() + for tile_id in tile_ids: + tile = self._tile_by_id(tile_id) + if tile is None: + labels.add(tile_id) + continue + labels.update(self._tile_slice_labels(tile)) + if labels: + refs[value_tile_id] = sorted(labels) + return refs + + def _fp_fragment_value_refs_snapshot(self) -> Dict[str, List[str]]: + refs: Dict[str, List[str]] = {} + for value_tile_id, value_tile in sorted(self.value_manager.value_tiles.items()): + fragment_name = value_tile.metadata.get("fp_fragment_name") + if not isinstance(fragment_name, str): + continue + refs.setdefault(fragment_name, []).append(value_tile_id) + for fragment_name in refs: + refs[fragment_name].sort() + return refs + + def _active_fp_fragments_snapshot(self) -> List[str]: + fragment_names = { + fragment_name + for fragment_name in self.vector_manager.fp_fragment_bindings.values() + if isinstance(fragment_name, str) + } + fragment_names.update(self._fp_fragment_value_refs_snapshot().keys()) + return sorted(fragment_names) + + def _should_skip_parallel_snapshot(self, op_kind: str) -> bool: + if not self._active_parallel_region_ids: + return False + frame = inspect.currentframe() + caller = frame.f_back.f_back if frame is not None and frame.f_back is not None and frame.f_back.f_back is not None else None + filename = caller.f_code.co_filename if caller is not None else "" + lineno = caller.f_lineno if caller is not None else -1 + region_id = self._active_parallel_region_ids[-1] + dedupe_key = (region_id, filename, lineno, op_kind) + if dedupe_key in self._parallel_snapshot_keys: + return True + self._parallel_snapshot_keys.add(dedupe_key) + return False + + def _record_operation_snapshot(self, op_kind: str, **details: object) -> None: + if self._should_skip_parallel_snapshot(op_kind): + return + self.operation_snapshots.append( + { + "index": len(self.operation_snapshots), + "op_kind": op_kind, + "details": {key: value for key, value in details.items() if value is not None}, + "vram_value_tiles": self._sorted_value_tile_ids("vram"), + "mram_value_tiles": self._sorted_value_tile_ids("mram"), + "hbm_value_tiles": self._sorted_value_tile_ids("hbm"), + "fpram_fp_fragments": self._active_fp_fragments_snapshot(), + "value_tile_slice_refs": self._value_tile_slice_refs_snapshot(), + "fp_fragment_value_refs": self._fp_fragment_value_refs_snapshot(), + } + ) + + def _format_operation_label(self, snapshot: Dict[str, object]) -> str: + op_kind = str(snapshot.get("op_kind", "unknown")) + details = snapshot.get("details", {}) + if not isinstance(details, dict): + return op_kind + if op_kind == "matmul": + path = details.get("path") + src1 = details.get("src1") + src2 = details.get("src2") + dst = details.get("dst") + extras = [item for item in (f"path={path}" if path else None, f"src1={src1}" if src1 else None, f"src2={src2}" if src2 else None, f"dst={dst}" if dst else None) if item] + return f"{op_kind} ({', '.join(extras)})" if extras else op_kind + if op_kind == "atomic_ops": + op = details.get("op") + src1 = details.get("src1") + src2 = details.get("src2") + dst = details.get("dst") + extras = [item for item in (f"op={op}" if op else None, f"src1={src1}" if src1 else None, f"src2={src2}" if src2 else None, f"dst={dst}" if dst else None) if item] + return f"{op_kind} ({', '.join(extras)})" if extras else op_kind + if op_kind == "row_op": + op = details.get("op") + src = details.get("src") + rhs = details.get("rhs") + out = details.get("out") + task_id = details.get("task_id") + extras = [item for item in (f"op={op}" if op else None, f"src={src}" if src else None, f"rhs={rhs}" if rhs else None, f"out={out}" if out else None, f"task_id={task_id}" if task_id else None) if item] + return f"{op_kind} ({', '.join(extras)})" if extras else op_kind + if op_kind in {"pure_fp_compute", "fill"}: + control = details.get("control") + src = details.get("src") + dst = details.get("dst") + task_id = details.get("task_id") + extras = [item for item in (f"control={control}" if control else None, f"src={src}" if src else None, f"dst={dst}" if dst else None, f"task_id={task_id}" if task_id else None) if item] + return f"{op_kind} ({', '.join(extras)})" if extras else op_kind + return op_kind + + def write_operation_report(self, output_path: str | Path) -> None: + output_path = Path(output_path) + lines: List[str] = [] + for snapshot in self.operation_snapshots: + op_index = int(snapshot.get("index", 0)) + op_kind = self._format_operation_label(snapshot) + details = snapshot.get("details", {}) + detail_text = "" + if isinstance(details, dict) and details: + detail_text = " | " + ", ".join(f"{key}={value}" for key, value in details.items()) + lines.append(f"op[{op_index}]: {op_kind}{detail_text}") + lines.append(f" vram_value_tiles: {', '.join(snapshot.get('vram_value_tiles', [])) or '(empty)'}") + lines.append(f" mram_value_tiles: {', '.join(snapshot.get('mram_value_tiles', [])) or '(empty)'}") + lines.append(f" hbm_value_tiles: {', '.join(snapshot.get('hbm_value_tiles', [])) or '(empty)'}") + lines.append(f" fpram_fp_fragments: {', '.join(snapshot.get('fpram_fp_fragments', [])) or '(empty)'}") + + value_refs = snapshot.get("value_tile_slice_refs", {}) + lines.append(" value_tile_slice_refs:") + if isinstance(value_refs, dict) and value_refs: + for value_tile_id, tile_ids in value_refs.items(): + lines.append(f" {value_tile_id}: {', '.join(tile_ids)}") + else: + lines.append(" (empty)") + + fragment_refs = snapshot.get("fp_fragment_value_refs", {}) + lines.append(" fp_fragment_value_refs:") + if isinstance(fragment_refs, dict) and fragment_refs: + for fragment_name, value_tile_ids in fragment_refs.items(): + lines.append(f" {fragment_name}: {', '.join(value_tile_ids)}") + else: + lines.append(" (empty)") + lines.append("") + report_text = "\n".join(lines) + output_path.write_text(report_text, encoding="utf-8") + delta_path = output_path.with_name(f"{output_path.stem}_delta.txt") + delta_text = build_delta_report(parse_operation_report(report_text)) + delta_path.write_text(delta_text, encoding="utf-8") + + def write_eviction_warnings(self, output_path: str | Path) -> None: + output_path = Path(output_path) + if not self.eviction_warnings: + output_path.write_text("(no FIFO VRAM evictions recorded)\n", encoding="utf-8") + return + lines: List[str] = [f"FIFO VRAM evictions: {len(self.eviction_warnings)}", ""] + for entry in self.eviction_warnings: + fields = ", ".join(f"{key}={value}" for key, value in entry.items()) + lines.append(f"WARN evict {fields}") + output_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + def mapf(self, operand: object) -> List[FPVar]: + return self.tensor_manager.mapf(operand) + + def mapf_t(self, tensor_operand: object, fp_operand: object, *, control: str = "mixed") -> Dict[str, object]: + self._require_single_batch_tensor_op("mapf_t", tensor_operand) + return self.tensor_manager.mapf_t(tensor_operand, fp_operand, control=control) + + def _logical_batch_extent_for_selectors( + self, + logical_shape: LogicalShape, + selectors: Tuple[object, ...], + ) -> Optional[int]: + if len(logical_shape) != 4: + return None + if not selectors: + return int(logical_shape[0]) + batch_selector = selectors[0] + if isinstance(batch_selector, int): + return 1 + if isinstance(batch_selector, slice): + start, stop = _slice_item_to_range(batch_selector, int(logical_shape[0])) + return max(0, int(stop) - int(start)) + return int(logical_shape[0]) + + def _operand_batch_extent(self, operand: object) -> Optional[int]: + if operand is None: + return None + if isinstance(operand, (Input, Tensor, Vector, InputTranspose, TensorTranspose, VectorTranspose)): + shape = tuple(operand.logical_shape) + return int(shape[0]) if len(shape) == 4 else None + if isinstance(operand, ParallelAccess): + return self._logical_batch_extent_for_selectors(tuple(operand.logical_shape), tuple(operand.selectors)) + if isinstance(operand, (InputSlice, TensorSlice, VectorSlice)): + return self._logical_batch_extent_for_selectors(tuple(operand.base.logical_shape), tuple(operand.selectors)) + if isinstance(operand, ElementRef): + shape = tuple(getattr(operand.base, "logical_shape", ())) + return 1 if len(shape) == 4 else None + if isinstance(operand, (InputTile, TensorTile, VectorTile)): + shape = tuple(operand.metadata.get("logical_shape", ())) + return 1 if len(shape) == 4 else None + return None + + def _operand_debug_name(self, operand: object) -> str: + if operand is None: + return "None" + if isinstance(operand, (InputSlice, TensorSlice, VectorSlice)): + return getattr(operand.base, "name", type(operand.base).__name__) + if isinstance(operand, ElementRef): + return getattr(operand.base, "name", type(operand.base).__name__) + if isinstance(operand, (InputTile, TensorTile, VectorTile)): + return _tile_owner_name(operand) + return str(getattr(operand, "name", type(operand).__name__)) + + def _require_single_batch_tensor_op(self, op_name: str, *operands: object) -> None: + for operand in operands: + batch_extent = self._operand_batch_extent(operand) + if batch_extent is None or int(batch_extent) <= 1: + continue + raise NotImplementedError( + f"{op_name} currently requires each BSHD tensor operation to address exactly one batch; " + f"got {self._operand_debug_name(operand)} with batch_extent={batch_extent}" + ) + + def fp_kernel( + self, + src1: object, + dst: object, + *, + src2: Optional[object] = None, + control: str = "add", + task_id: str = "fp_kernel", + ) -> Dict[str, object]: + self._require_single_batch_tensor_op(task_id, src1, src2, dst) + src1_vars = self.mapf(src1) + if isinstance(dst, ElementRef): + if control in {"copy", "fill"}: + if len(src1_vars) != 1: + raise ValueError(f"ElementRef dst with control={control!r} expects one source FPVar") + bound = self.tensor_manager.bind_element_pointer(dst, src1_vars[0], mode="alias") + record = { + "op_kind": "fp_kernel_bind", + "task_id": task_id, + "op": control, + "src1": [src1_vars[0].name], + "dst": [bound.name], + } + self.compute_manager.ops.append(record) + return record + dst_var = self.tensor_manager.allocate_element_result_fpvar(dst) + record = self.compute_manager.fp_kernel( + src1_vars, + [dst_var], + src2=self.mapf(src2) if src2 is not None else None, + op=control, + task_id=task_id, + ) + self.tensor_manager.bind_element_pointer(dst, dst_var, mode="result") + return record + return self.compute_manager.fp_kernel( + src1_vars, + self.tensor_manager.mapf_dst(dst, control=control, src1_vars=src1_vars), + src2=self.mapf(src2) if src2 is not None else None, + op=control, + task_id=task_id, + ) + + def pure_fp_compute( + self, + src1: object, + dst: object, + *, + src2: Optional[object] = None, + control: str = "add", + task_id: str = "pure_fp_compute", + ) -> Dict[str, object]: + self._require_single_batch_tensor_op(task_id, src1, src2, dst) + src1_vars = self.mapf(src1) + if isinstance(dst, ElementRef): + if control in {"copy", "fill"}: + if len(src1_vars) != 1: + raise ValueError(f"ElementRef dst with control={control!r} expects one source FPVar") + bound = self.tensor_manager.bind_element_pointer(dst, src1_vars[0], mode="alias") + record = { + "op_kind": "pure_fp_bind", + "task_id": task_id, + "op": control, + "src1": [src1_vars[0].name], + "dst": [bound.name], + } + self.compute_manager.ops.append(record) + return record + dst_var = self.tensor_manager.allocate_element_result_fpvar(dst) + record = self.compute_manager.pure_fp_compute( + src1_vars, + [dst_var], + src2=self.mapf(src2) if src2 is not None else None, + op=control, + task_id=task_id, + ) + self.tensor_manager.bind_element_pointer(dst, dst_var, mode="result") + return record + record = self.compute_manager.pure_fp_compute( + src1_vars, + self.tensor_manager.mapf_dst(dst, control=control, src1_vars=src1_vars), + src2=self.mapf(src2) if src2 is not None else None, + op=control, + task_id=task_id, + ) + return record + + def copy(self, src: object, dst: object) -> object: + if _is_fp_domain_operand(src) or _is_fp_domain_operand(dst): + return self.fp_copy(src, dst) + self._require_single_batch_tensor_op("copy", src, dst) + src_groups = self.tensor_manager.mapt([src, 0]) + dst_groups = self.tensor_manager.mapt([dst, 0]) + if len(src_groups) != len(dst_groups): + raise RuntimeError( + f"copy expects matching tile counts, got src={len(src_groups)} dst={len(dst_groups)}" + ) + signal_4 = [] + for src_group, dst_group in zip(src_groups, dst_groups): + if len(src_group) != 1 or len(dst_group) != 1: + raise RuntimeError("copy currently expects mapt(control=0) groups with one tile each") + src_tile = src_group[0] + dst_tile = dst_group[0] + if not isinstance(src_tile, (TensorTile, InputTile)) or not isinstance(dst_tile, (TensorTile, InputTile)): + raise RuntimeError("copy expects tensor/input tile groups only") + src_value = self.value_manager.resolve_value_tile(src_tile) + if isinstance(dst_tile, InputTile): + self.value_manager._write_value_back_to_input_tile(src_value, dst_tile) + else: + self.value_manager._bind_value_to_tensor_tile(src_value, dst_tile) + signal_4.append( + { + "control": "copy_bind" if not isinstance(dst_tile, InputTile) else "copy_writeback", + "dst_tile": dst_tile, + "dst_value_id": src_value.value_tile_id, + } + ) + out = self.tensor_manager.mapt_back(signal_4, dst_groups) + if signal_4: + copy_trace: List[Dict[str, object]] = [] + for src_group, dst_group in zip(src_groups, dst_groups): + src_tile = src_group[0] + dst_tile = dst_group[0] + src_value = self.value_manager.resolve_value_tile(src_tile) + dst_value = self.value_manager.resolve_value_tile(dst_tile) + copy_trace.append( + { + "src_tile": self.value_manager._tile_debug_state(src_tile), + "dst_tile": self.value_manager._tile_debug_state(dst_tile), + "src_value": self.value_manager._value_debug_state(src_value), + "dst_value": self.value_manager._value_debug_state(dst_value), + } + ) + self._record_operation_snapshot( + "copy", + src=getattr(src, "name", type(src).__name__), + dst=getattr(dst, "name", type(dst).__name__), + tile_copies=copy_trace, + ) + return out + + def atomic_ops( + self, + src1: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + src2: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + dst: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + *, + op: str = "add", + ) -> object: + """Run elementwise tile ops with alias-safe destination updates. + + The public API looks like one simple tilewise binary op, but the runtime + has two materially different execution paths: + + - wide/full-tile path using direct ValueTile operands in VRAM + - narrow/grouped path using ValueTileView-aware compute and rebinding + + When the destination aliases one source (for example `A + B -> B`), the + wide-tile path first detaches the old destination binding and + materializes a fresh writable value so reads remain stable during the + update. + """ + if op not in {"add", "sub", "mul"}: + raise ValueError(f"atomic_ops only supports add/sub/mul, got op={op!r}") + self._require_single_batch_tensor_op(f"atomic_{op}", src1, src2, dst) + + src1_groups = self.tensor_manager.mapt([src1, 0]) + src2_groups = self.tensor_manager.mapt([src2, 0]) + dst_groups = self.tensor_manager.mapt([dst, 0]) + if len(src1_groups) != len(src2_groups) or len(src1_groups) != len(dst_groups): + raise RuntimeError( + f"atomic_ops expects matching tile counts, got src1={len(src1_groups)} src2={len(src2_groups)} dst={len(dst_groups)}" + ) + + signal_4: List[Dict[str, object]] = [] + for group_index, (src1_group, src2_group, dst_group_tiles) in enumerate(zip(src1_groups, src2_groups, dst_groups)): + if len(src1_group) != 1 or len(src2_group) != 1 or len(dst_group_tiles) != 1: + raise RuntimeError("atomic_ops currently expects mapt(control=0) groups with one tile each") + lhs_tile = src1_group[0] + rhs_tile = src2_group[0] + dst_tile = dst_group_tiles[0] + if not isinstance(lhs_tile, (TensorTile, InputTile)) or not isinstance(rhs_tile, (TensorTile, InputTile)) or not isinstance(dst_tile, (TensorTile, InputTile)): + raise RuntimeError("atomic_ops expects tensor/input tile groups only") + + lhs_value = self.value_manager.resolve_value_tile(lhs_tile) + rhs_value = self.value_manager.resolve_value_tile(rhs_tile) + dst_aliases_source = lhs_tile.tile_id == dst_tile.tile_id or rhs_tile.tile_id == dst_tile.tile_id + if isinstance(dst_tile, TensorTile): + dst_view = self.value_manager.resolve_value_tile_view(dst_tile) + prepared_write = self.value_manager.prepare_updated_view_value( + dst_tile, + dst_view, + ensure_old_place="vram" if dst_aliases_source else None, + new_place="vram", + ) + dst_value = prepared_write.new_value + if lhs_tile.tile_id == dst_tile.tile_id: + lhs_value = prepared_write.old_value + if rhs_tile.tile_id == dst_tile.tile_id: + rhs_value = prepared_write.old_value + if prepared_write.requires_preserve_copy: + old_vram_addr = prepared_write.old_value.residency.get("vram_addr") + new_vram_addr = prepared_write.new_value.residency.get("vram_addr") + if old_vram_addr is None or new_vram_addr is None: + raise RuntimeError( + "atomic_ops preserve copy requires old/new values resident in VRAM" + ) + self.isa_emitter.emit_zero_vram_tile(int(new_vram_addr)) + self.isa_emitter.emit_tile_binary( + lhs_vram_addr=int(new_vram_addr), + rhs_vram_addr=int(old_vram_addr), + dst_vram_addr=int(new_vram_addr), + op="add", + task_id=f"atomic_preserve_copy.{dst_tile.tile_id}.{group_index}", + ) + else: + dst_value = self.value_manager._prepare_mapv_destination_value(dst_tile, "vram") + self.value_manager.ensure_value_tile_in_place(lhs_value, "vram") + self.value_manager.ensure_value_tile_in_place(rhs_value, "vram") + self.value_manager.ensure_value_tile_in_place(dst_value, "vram") + + lhs_vram_addr = lhs_value.residency.get("vram_addr") + rhs_vram_addr = rhs_value.residency.get("vram_addr") + dst_vram_addr = dst_value.residency.get("vram_addr") + if lhs_vram_addr is None or rhs_vram_addr is None or dst_vram_addr is None: + raise RuntimeError("atomic_ops wide-tile path requires all operands in VRAM") + self.isa_emitter.emit_tile_binary( + lhs_vram_addr=int(lhs_vram_addr), + rhs_vram_addr=int(rhs_vram_addr), + dst_vram_addr=int(dst_vram_addr), + op=op, + task_id=f"atomic_{op}.{group_index}", + ) + if isinstance(dst_tile, TensorTile) and not prepared_write.reuse_old: + self.value_manager._release_unreferenced_value_tile(prepared_write.old_value.value_tile_id) + signal_4.append( + { + "control": f"atomic_{op}_tile", + "dst_tile": dst_tile, + "dst_value_id": dst_value.value_tile_id, + } + ) + + out = self.tensor_manager.mapt_back(signal_4, dst_groups) + self._record_operation_snapshot( + "atomic_ops", + op=op, + src1=getattr(src1, "name", type(src1).__name__), + src2=getattr(src2, "name", type(src2).__name__), + dst=getattr(dst, "name", type(dst).__name__), + ) + return out + + def atomic_add( + self, + src1: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + src2: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + dst: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + ) -> object: + return self.atomic_ops(src1, src2, dst, op="add") + + def atomic_sub( + self, + src1: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + src2: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + dst: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + ) -> object: + return self.atomic_ops(src1, src2, dst, op="sub") + + def atomic_mul( + self, + src1: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + src2: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + dst: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + ) -> object: + return self.atomic_ops(src1, src2, dst, op="mul") + + def fill(self, dst: object, src: object) -> object: + if isinstance(dst, (FPVar, FPFragment, FPFragmentSlice, Vector, VectorSlice, VectorTile, ElementRef)): + out = self.fp_fill(dst, src) + self._record_operation_snapshot( + "fill", + control="copy", + src=getattr(src, "name", src if isinstance(src, (int, float, str)) else type(src).__name__), + dst=getattr(dst, "name", type(dst).__name__), + ) + return out + raise NotImplementedError(f"fill currently supports FP-domain destinations only, got {type(dst).__name__}") + + def matmul(self, src1: Tensor | Input, src2: Tensor | Input | TensorTranspose | InputTranspose, dst: Tensor | Input) -> object: + """Route one matmul request to the correct execution strategy. + + The current runtime supports multiple matmul families behind one API: + + - default tilewise matmul using `mapt -> mapv -> compute` + - view-based lane matmul for grouped narrow-head layouts + - BTMM/QKT path when the RHS is explicitly transposed and shapes match + + This function is therefore both an entry point and a router. The exact + path is selected from logical shape/layout information before compute + packets are materialized. + """ + self._require_single_batch_tensor_op("matmul", src1, src2, dst) + if self._should_use_btmm_qkt_matmul(src1, src2, dst): + out = self._matmul_btmm_qkt_path(src1, _unwrap_transposed_operand(src2), dst) + self._record_operation_snapshot( + "matmul", + path="btmm_qkt", + src1=getattr(src1, "name", type(src1).__name__), + src2=getattr(_unwrap_transposed_operand(src2), "name", type(_unwrap_transposed_operand(src2)).__name__), + dst=getattr(dst, "name", type(dst).__name__), + ) + return out + if _is_transposed_operand(src2): + raise RuntimeError("BTMM/QKT matmul only supports explicit transpose syntax as prog.matmul(q, k.T, p)") + if self._should_use_view_matmul(src1, src2, dst): + out = self._matmul_view_path(src1, src2, dst) + self._record_operation_snapshot( + "matmul", + path="view", + src1=getattr(src1, "name", type(src1).__name__), + src2=getattr(src2, "name", type(src2).__name__), + dst=getattr(dst, "name", type(dst).__name__), + ) + return out + signal_0 = [src1, src2, dst, 0] + signal_1 = self.tensor_manager.mapt(signal_0) + signal_4 = [] + for a in signal_1: + a.append(["vram", "mram", "vram"]) + tmp = self.value_manager.mapv(a) + signal_2 = self.compute_manager.execute([tmp, "matmul"]) + signal_3 = self.value_manager.mapv_back([signal_2, tmp]) + signal_4.append(signal_3) + out = self.tensor_manager.mapt_back(signal_4, signal_1) + self._record_operation_snapshot( + "matmul", + path="default", + src1=getattr(src1, "name", type(src1).__name__), + src2=getattr(src2, "name", type(src2).__name__), + dst=getattr(dst, "name", type(dst).__name__), + ) + return out + + def _should_use_btmm_qkt_matmul( + self, + src1: Tensor | Input, + src2: Tensor | Input | TensorTranspose | InputTranspose, + dst: Tensor | Input, + ) -> bool: + if not _is_transposed_operand(src2): + return False + src2_base = _unwrap_transposed_operand(src2) + logical_shapes = [getattr(src, "logical_shape", ()) for src in (src1, src2_base, dst)] + if not all(len(shape) == 4 for shape in logical_shapes): + return False + _, src1_seq, src1_heads, src1_dim = logical_shapes[0] + _, src2_seq, src2_heads, src2_dim = logical_shapes[1] + _, dst_seq, dst_heads, dst_dim = logical_shapes[2] + if src1_heads != src2_heads or src1_heads != dst_heads: + return False + if src1_dim != self.btmm_hlen or src2_dim != self.btmm_hlen: + return False + if dst_seq != src1_seq or dst_dim != src2_seq: + return False + if dst_dim % self.mlen != 0: + return False + return True + + def _should_use_view_matmul(self, src1: Tensor | Input, src2: Tensor | Input, dst: Tensor | Input) -> bool: + logical_shapes = [getattr(src, "logical_shape", ()) for src in (src1, src2, dst)] + if not all(len(shape) == 4 for shape in logical_shapes): + return False + _, _, _, src2_head_dim = logical_shapes[1] + _, _, _, dst_head_dim = logical_shapes[2] + if src2_head_dim <= 0 or self.mlen % src2_head_dim != 0: + return False + if dst_head_dim != src2_head_dim: + return False + return True + + def _matmul_view_path(self, src1: Tensor | Input, src2: Tensor | Input, dst: Tensor | Input) -> object: + signal_1 = self.tensor_manager.mapt_view_matmul(src1, src2, dst) + signal_4: List[Dict[str, object]] = [] + + for dst_tile, terms, group_start in signal_1: + if not terms: + continue + + dst_view = self.value_manager.resolve_value_tile_view(dst_tile) + prepared_write = self.value_manager.prepare_updated_view_value( + dst_tile, + dst_view, + ensure_old_place="vram", + new_place="vram", + ) + dst_value = prepared_write.new_value + + for term_index, (lhs_tiles, rhs_tile) in enumerate(terms): + if not lhs_tiles: + continue + lhs_values = [self.value_manager.resolve_value_tile(tile) for tile in lhs_tiles] + self.compute_manager.view_matmul( + lhs_values=lhs_values, + rhs_tile=rhs_tile, + dst_tile=dst_tile, + dst_value=dst_value, + task_id=f"view_matmul.{dst_tile.tile_id}.term{term_index}", + zero_dst=(term_index == 0), + ) + if not prepared_write.reuse_old: + self.value_manager._release_unreferenced_value_tile(prepared_write.old_value.value_tile_id) + + signal_4.append( + { + "control": "view_matmul", + "dst_tile_id": dst_tile.tile_id, + "dst_value_id": dst_value.value_tile_id, + "dst_tile": dst_tile, + } + ) + + out = self.tensor_manager.mapt_back(signal_4, signal_1) + return out + + def _matmul_btmm_qkt_path(self, src1: Tensor | Input, src2: Tensor | Input, dst: Tensor | Input) -> object: + signal_1 = self.tensor_manager.mapt([src1, src2, dst, 1]) + signal_4: List[Dict[str, object]] = [] + + for thread_index, thread in enumerate(signal_1): + if not isinstance(thread, dict): + raise RuntimeError(f"BTMM QKT matmul expected one dict thread, got {type(thread).__name__}") + lhs_tiles = thread.get("lhs_tiles") + rhs_tiles = thread.get("rhs_tiles") + dst_tiles = thread.get("dst_tiles") + if not isinstance(lhs_tiles, list) or not isinstance(rhs_tiles, list) or not isinstance(dst_tiles, list): + raise RuntimeError("BTMM QKT matmul thread is missing lhs_tiles/rhs_tiles/dst_tiles lists") + if len(lhs_tiles) != 1 or len(rhs_tiles) != 1: + raise RuntimeError( + f"BTMM QKT matmul currently expects one lhs tile and one rhs tile per thread, " + f"got lhs={len(lhs_tiles)} rhs={len(rhs_tiles)}" + ) + if not dst_tiles: + continue + + lhs_tile = lhs_tiles[0] + rhs_tile = rhs_tiles[0] + if not isinstance(lhs_tile, (TensorTile, InputTile)) or not isinstance(rhs_tile, (TensorTile, InputTile)): + raise RuntimeError("BTMM QKT matmul thread tiles must be tensor/input tiles") + if not all(isinstance(tile, (TensorTile, InputTile)) for tile in dst_tiles): + raise RuntimeError("BTMM QKT matmul destination group must contain tensor/input tiles only") + + lhs_value = self.value_manager._resolve_mapv_source_value(lhs_tile, "vram") + rhs_value = self.value_manager._resolve_mapv_source_value(rhs_tile, "mram") + if not isinstance(lhs_value, ValueTile) or not isinstance(rhs_value, ValueTile): + raise RuntimeError("BTMM QKT matmul currently expects full-tile source values") + + task_id = ( + f"btmm_qkt.r{thread.get('lhs_row_block', 0)}" + f".k{thread.get('rhs_row_block', 0)}" + f".g{thread.get('group_start', 0)}" + f".t{thread_index}" + ) + btmm_state = self.btmm( + lhs_packed_value=lhs_value, + rhs_value=rhs_value, + task_id=task_id, + ) + write_state = self.btmm_write( + btmm_state=btmm_state, + tile_count=len(dst_tiles), + reason=task_id, + logical_shape=(self.mlen, self.mlen), + metadata={ + "source_thread": task_id, + "group_start": thread.get("group_start"), + "lhs_row_block": thread.get("lhs_row_block"), + "rhs_row_block": thread.get("rhs_row_block"), + "vram_class": dst_tiles[0].metadata.get("vram_class", "l0"), + }, + task_id=f"{task_id}.wo", + ) + + out_values = write_state.get("dst_values") + if not isinstance(out_values, list) or len(out_values) != len(dst_tiles): + raise RuntimeError( + f"BTMM QKT writeback expected {len(dst_tiles)} output value tiles, got {len(out_values) if isinstance(out_values, list) else 'invalid'}" + ) + + for dst_tile, dst_value in zip(dst_tiles, out_values): + if not isinstance(dst_value, ValueTile): + raise RuntimeError("BTMM QKT writeback produced one non-ValueTile output") + if isinstance(dst_tile, InputTile): + self.value_manager._write_value_back_to_input_tile(dst_value, dst_tile) + else: + self.value_manager._bind_value_to_tensor_tile(dst_value, dst_tile) + + signal_4.append( + { + "control": "btmm_qkt_matmul", + "dst_tiles": dst_tiles, + "dst_tile": dst_tiles[0], + "task_id": task_id, + "thread_index": thread_index, + "base_addr": write_state.get("base_addr"), + } + ) + + out = self.tensor_manager.mapt_back(signal_4, signal_1) + return out + + def btmm( + self, + *, + lhs_packed_value: ValueTile, + rhs_value: ValueTile, + task_id: str = "btmm", + ) -> Dict[str, object]: + return self.compute_manager.btmm( + lhs_packed_value=lhs_packed_value, + rhs_value=rhs_value, + task_id=task_id, + ) + + def btmm_write( + self, + *, + btmm_state: Dict[str, object], + tile_count: Optional[int] = None, + reason: str = "btmm_write", + logical_shape: Optional[Tuple[int, int]] = None, + metadata: Optional[Dict[str, object]] = None, + task_id: str = "btmm_wo", + ) -> Dict[str, object]: + return self.compute_manager.btmm_write( + btmm_state=btmm_state, + tile_count=tile_count, + reason=reason, + logical_shape=logical_shape, + metadata=metadata, + task_id=task_id, + ) + + def fp_copy(self, src: object, dst: object) -> Dict[str, object]: + return self.fp_kernel(src, dst, control="copy", task_id="fp_copy") + + def fp_fill(self, dst: object, src: object) -> Dict[str, object]: + return self.fp_kernel(src, dst, control="copy", task_id="fp_fill") + + def fp_fill_from_addr(self, dst: object, src_fpram_addr: int) -> Dict[str, object]: + src_var = self._fp_var_from_addr(int(src_fpram_addr)) + return self.fp_fill(dst, src_var) + + def fp_add(self, src1: object, src2: object, dst: object) -> Dict[str, object]: + return self.fp_kernel(src1, dst, src2=src2, control="add", task_id="fp_add") + + def fp_sub(self, src1: object, src2: object, dst: object) -> Dict[str, object]: + return self.fp_kernel(src1, dst, src2=src2, control="sub", task_id="fp_sub") + + def fp_mul(self, src1: object, src2: object, dst: object) -> Dict[str, object]: + return self.fp_kernel(src1, dst, src2=src2, control="mul", task_id="fp_mul") + + def fp_max(self, src1: object, src2: object, dst: object) -> Dict[str, object]: + return self.fp_kernel(src1, dst, src2=src2, control="max", task_id="fp_max") + + def fp_exp(self, src: object, dst: object) -> Dict[str, object]: + return self.fp_kernel(src, dst, control="exp", task_id="fp_exp") + + def fp_reci(self, src: object, dst: object) -> Dict[str, object]: + return self.fp_kernel(src, dst, control="reci", task_id="fp_reci") + + def fp_sqrt(self, src: object, dst: object) -> Dict[str, object]: + return self.fp_kernel(src, dst, control="sqrt", task_id="fp_sqrt") + + def row_op( + self, + src: Tensor | Input | Vector | TensorSlice | InputSlice | VectorSlice, + rhs: Optional[object] = None, + op: str = "exp", + *, + out: Optional[object] = None, + dim: int = -1, + task_id: Optional[str] = None, + ) -> List[Dict[str, object]]: + self._require_single_batch_tensor_op(task_id or f"row_op.{op}", src, rhs, out) + if dim != -1: + raise NotImplementedError(f"row_op currently supports dim=-1 only, got {dim}") + src_slice_ranges: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None + if isinstance(src, (TensorSlice, InputSlice, VectorSlice)): + src_slice_ranges = _logical_selectors_to_physical_ranges(src.base.logical_shape, src.selectors) + src_groups = self.tensor_manager.mapt([src, 0]) + if not src_groups: + return [] + records: List[Dict[str, object]] = [] + mutates_src = op in {"exp", "reci", "mul", "add", "sub"} + rhs_vars = self.mapf(rhs) if rhs is not None and op in {"mul", "add", "sub"} else None + out_vars = self.mapf(out) if out is not None else None + rhs_cursor = 0 + out_cursor = 0 + for group_index, src_group in enumerate(src_groups): + if len(src_group) != 1 or not isinstance(src_group[0], (TensorTile, InputTile, VectorTile)): + raise RuntimeError("row_op currently expects one full tile per mapt group") + src_tile = src_group[0] + if isinstance(src_tile, VectorTile): + record = self._row_op_vector_tile( + src_tile, + src_slice_ranges=src_slice_ranges, + rhs_vars=rhs_vars, + out_vars=out_vars, + op=op, + rhs_cursor=rhs_cursor, + out_cursor=out_cursor, + task_id=task_id or f"row_op.{op}.{group_index}", + ) + rhs_cursor = int(record.get("rhs_cursor", rhs_cursor)) + out_cursor = int(record.get("out_cursor", out_cursor)) + records.append(record) + continue + if src_slice_ranges is not None: + src_operand = self.value_manager.resolve_row_operand_for_ranges( + src_tile, + src_slice_ranges[0], + src_slice_ranges[1], + "vram", + ) + else: + src_operand = self.value_manager.resolve_row_operand(src_tile, "vram") + if src_slice_ranges is not None: + target_view = src_operand if isinstance(src_operand, ValueTileView) else None + else: + target_view = None + dst_operand: RowOperandLike = src_operand + if mutates_src: + if isinstance(src_tile, TensorTile): + if target_view is None: + target_view = self.value_manager.resolve_value_tile_view(src_tile) + prepared_write = self.value_manager.prepare_updated_view_value( + src_tile, + target_view, + ensure_old_place="vram", + new_place="vram", + ) + if not prepared_write.reuse_old: + if prepared_write.requires_preserve_copy: + old_vram_addr = prepared_write.old_value.residency.get("vram_addr") + new_vram_addr = prepared_write.new_value.residency.get("vram_addr") + if old_vram_addr is None or new_vram_addr is None: + raise RuntimeError( + "row_op preserve copy requires old/new values resident in VRAM" + ) + self.isa_emitter.emit_zero_vram_tile(int(new_vram_addr)) + self.isa_emitter.emit_tile_binary( + lhs_vram_addr=int(new_vram_addr), + rhs_vram_addr=int(old_vram_addr), + dst_vram_addr=int(new_vram_addr), + op="add", + task_id=f"row_op_preserve_copy.{src_tile.tile_id}.{group_index}", + ) + if isinstance(src_operand, ValueTileView): + dst_operand = prepared_write.target_view + else: + dst_operand = prepared_write.new_value + else: + dst_operand = self.value_manager._prepare_mapv_destination_value(src_tile, "vram") + row_count = int(src_operand.row_count if isinstance(src_operand, ValueTileView) else src_operand.logical_shape[0]) + + group_out: Optional[List[FPVar]] = None + if op in {"reduce_max", "reduce_sum"}: + if out_vars is None: + raise ValueError(f"row_op op={op!r} requires out") + if out_cursor + row_count > len(out_vars): + raise ValueError(f"row_op op={op!r} out size is smaller than required rows") + group_out = out_vars[out_cursor : out_cursor + row_count] + out_cursor += row_count + + group_rhs: Optional[List[FPVar]] = None + if op in {"mul", "add", "sub"}: + if rhs_vars is None: + raise ValueError(f"row_op op={op!r} requires rhs") + if len(rhs_vars) == 1: + group_rhs = list(rhs_vars) + else: + if rhs_cursor + row_count > len(rhs_vars): + raise ValueError(f"row_op op={op!r} rhs size is smaller than required rows") + group_rhs = rhs_vars[rhs_cursor : rhs_cursor + row_count] + rhs_cursor += row_count + + record = self.compute_manager.row_operations( + src_operand, + dst_operand=dst_operand, + dst=group_out, + rhs=group_rhs, + op=op, + task_id=task_id or f"row_op.{op}.{group_index}", + ) + if mutates_src and isinstance(src_tile, TensorTile) and not prepared_write.reuse_old: + self.value_manager._release_unreferenced_value_tile(prepared_write.old_value.value_tile_id) + records.append(record) + self._record_operation_snapshot( + "row_op", + op=op, + src=getattr(src, "name", type(src).__name__), + rhs=getattr(rhs, "name", rhs if isinstance(rhs, (int, float, str)) else type(rhs).__name__) if rhs is not None else None, + out=getattr(out, "name", type(out).__name__) if out is not None else None, + task_id=task_id or "row_op", + ) + return records + + def _row_op_vector_tile( + self, + src_tile: VectorTile, + *, + src_slice_ranges: Optional[Tuple[Tuple[int, int], Tuple[int, int]]], + rhs_vars: Optional[List[FPVar]], + out_vars: Optional[List[FPVar]], + op: str, + rhs_cursor: int, + out_cursor: int, + task_id: str, + ) -> Dict[str, object]: + fragment = self.vector_manager.resolve_fp_fragment(src_tile) + row_groups = _vector_tile_row_fp_groups( + src_tile=src_tile, + fragment=fragment, + mlen=self.mlen, + btmm_hlen=self.btmm_hlen, + src_slice_ranges=src_slice_ranges, + ) + if not row_groups: + return { + "op_kind": "row_op_vector", + "task_id": task_id, + "tile": src_tile.tile_id, + "rows": 0, + "rhs_cursor": rhs_cursor, + "out_cursor": out_cursor, + } + + if op in {"exp", "reci"}: + for row_index, row_vars in enumerate(row_groups): + self.compute_manager.fp_kernel( + row_vars, + row_vars, + op=op, + task_id=f"{task_id}.row{row_index}", + ) + elif op in {"add", "sub", "mul"}: + if rhs_vars is None: + raise ValueError(f"row_op op={op!r} requires rhs") + if len(rhs_vars) == 1: + row_rhs_vars = [rhs_vars[0] for _ in row_groups] + else: + if rhs_cursor + len(row_groups) > len(rhs_vars): + raise ValueError(f"row_op op={op!r} rhs size is smaller than required rows") + row_rhs_vars = rhs_vars[rhs_cursor : rhs_cursor + len(row_groups)] + rhs_cursor += len(row_groups) + for row_index, (row_vars, rhs_var) in enumerate(zip(row_groups, row_rhs_vars)): + self.compute_manager.fp_kernel( + row_vars, + row_vars, + src2=[rhs_var] * len(row_vars), + op=op, + task_id=f"{task_id}.row{row_index}", + ) + elif op in {"reduce_sum", "reduce_max"}: + if out_vars is None: + raise ValueError(f"row_op op={op!r} requires out") + if out_cursor + len(row_groups) > len(out_vars): + raise ValueError(f"row_op op={op!r} out size is smaller than required rows") + row_out_vars = out_vars[out_cursor : out_cursor + len(row_groups)] + out_cursor += len(row_groups) + for row_index, (row_vars, out_var) in enumerate(zip(row_groups, row_out_vars)): + if not row_vars: + continue + if op == "reduce_sum": + self.compute_manager.fp_kernel( + [self.mapf(0.0)[0]], + [out_var], + op="copy", + task_id=f"{task_id}.row{row_index}.init", + ) + for cell_index, cell_var in enumerate(row_vars): + self.compute_manager.fp_kernel( + [out_var], + [out_var], + src2=[cell_var], + op="add", + task_id=f"{task_id}.row{row_index}.cell{cell_index}", + ) + else: + self.compute_manager.fp_kernel( + [row_vars[0]], + [out_var], + op="copy", + task_id=f"{task_id}.row{row_index}.init", + ) + for cell_index, cell_var in enumerate(row_vars[1:], start=1): + self.compute_manager.fp_kernel( + [out_var], + [out_var], + src2=[cell_var], + op="max", + task_id=f"{task_id}.row{row_index}.cell{cell_index}", + ) + else: + raise NotImplementedError(f"row_op vector path does not support op={op!r}") + + return { + "op_kind": "row_op_vector", + "task_id": task_id, + "tile": src_tile.tile_id, + "fragment": fragment.name, + "rows": len(row_groups), + "op": op, + "rhs_cursor": rhs_cursor, + "out_cursor": out_cursor, + } + + def elementwise( + self, + src1: object, + dst: object, + *, + src2: Optional[object] = None, + op: str = "add", + task_id: Optional[str] = None, + ) -> Dict[str, object]: + self._require_single_batch_tensor_op(task_id or f"elementwise.{op}", src1, src2, dst) + if _is_parallel_graph_operand(src1) or _is_parallel_graph_operand(dst) or _is_parallel_graph_operand(src2): + if not isinstance(dst, ParallelAccess): + raise ValueError("parallel elementwise dst must be one ParallelAccess target") + expr = _coerce_parallel_expr(src1) + if op == "copy": + self.thread_manager.record_parallel_assignment_from_access(dst, expr) + elif op == "add": + self.thread_manager.record_parallel_assignment_from_access(dst, expr + src2) + elif op == "sub": + self.thread_manager.record_parallel_assignment_from_access(dst, expr - src2) + elif op == "mul": + self.thread_manager.record_parallel_assignment_from_access(dst, expr * src2) + else: + raise ValueError(f"parallel elementwise does not support op={op!r}") + region = self.thread_manager.current_parallel_graph() + return { + "op_kind": "parallel_graph_elementwise", + "task_id": task_id or f"elementwise.{op}", + "region": region.name, + "op": op, + "dst": _parallel_access_identity(dst), + } + return self.pure_fp_compute( + src1, + dst, + src2=src2, + control=op, + task_id=task_id or f"elementwise.{op}", + ) + + def clear(self, tensor: Tensor) -> None: + self._require_single_batch_tensor_op("clear", tensor) + cleared_values: set[str] = set() + for tile in _tiles_in_grid_order(tensor.tiles): + value = self.value_manager.resolve_value_tile(tile) + self.value_manager.ensure_value_tile_in_place(value, "vram") + if value.value_tile_id in cleared_values: + continue + vram_addr = value.residency.get("vram_addr") + if vram_addr is None: + raise RuntimeError(f"clear expected VRAM residency for {value.value_tile_id}") + self.isa_emitter.emit_zero_vram_tile(int(vram_addr)) + cleared_values.add(value.value_tile_id) + + def _fp_var_from_addr(self, fp_mem_addr: int) -> FPVar: + for fp_var in self.tensor_manager.fp_vars.values(): + if fp_var.fp_mem_addr == fp_mem_addr: + return fp_var + raise KeyError(f"No FPVar found at fp_mem_addr={fp_mem_addr}") + + def _arith_progression(self, values: Sequence[int]) -> Optional[Tuple[int, int, int]]: + if not values: + return None + if len(values) == 1: + return int(values[0]), 1, 0 + first = int(values[0]) + step = int(values[1]) - first + for idx, value in enumerate(values[1:], start=1): + if int(value) != first + idx * step: + return None + return first, len(values), step + + def alloc_hbm_addr(self, elems: int) -> int: + size = int(elems * self.real_data_ratio) + base = self._next_hbm_addr + self._next_hbm_addr += size + return base + + def add_hbm_object(self, name: str, shape: Tuple[int, int], *, hbm_addr: Optional[int] = None) -> int: + base_addr = self.alloc_hbm_addr(shape[0] * shape[1]) if hbm_addr is None else int(hbm_addr) + self.compiler.add_hbm_object( + name=name, + shape=shape, + hbm_addr=base_addr, + real_data_ratio=self.real_data_ratio, + ) + self.hardware.hbm_objects[name] = { + "name": name, + "shape": shape, + "base_addr": base_addr, + } + return base_addr + + def build_fp_preload(self, min_size: int = 0) -> List[float]: + """Return the FP_MEM initialisation array ordered by address. + + Entries come from fp_var() declarations; any slots beyond the + declared range up to min_size are zero-padded. + """ + values = list(self.tensor_manager._fp_mem_values) + size = max(len(values), int(min_size)) + values.extend([0.0] * (size - len(values))) + return values + + def _normalize_large_addi_immediates(self, asm_code: str) -> str: + lines: List[str] = [] + for raw_line in asm_code.splitlines(): + line = raw_line.rstrip("\n") + stripped = line.strip() + if not stripped or stripped.startswith(";"): + lines.append(line) + continue + + parts = stripped.split(None, 1) + if len(parts) != 2 or parts[0] != "S_ADDI_INT": + lines.append(line) + continue + + operands = [item.strip() for item in parts[1].split(",")] + if len(operands) != 3: + lines.append(line) + continue + + rd, rs1, imm_text = operands + try: + imm_value = int(imm_text) + except ValueError: + lines.append(line) + continue + + if 0 <= imm_value <= 262143: + lines.append(line) + continue + + if rs1 != "gp0": + lines.append(line) + continue + + upper = imm_value >> 12 + lower = imm_value & 0xFFF + lines.append(f"S_LUI_INT {rd}, {upper}") + lines.append(f"S_ADDI_INT {rd}, {rd}, {lower}") + normalized = "\n".join(lines) + if asm_code.endswith("\n"): + normalized += "\n" + return normalized + + def compile(self) -> str: + if not self._parallel_execution_lowered: + self.lower_parallel_execution_plans() + self.compiler.generated_code = self._normalize_large_addi_immediates(self.compiler.generated_code) + return self.compiler.generated_code diff --git a/tilelang_runtime_compier/tile_tensor_program/_tensor_manager.py b/tilelang_runtime_compier/tile_tensor_program/_tensor_manager.py new file mode 100644 index 0000000..9191039 --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_tensor_manager.py @@ -0,0 +1,976 @@ +"""TensorManager: logical tensor/input objects, tile creation, slice resolution.""" + +from __future__ import annotations + +from math import ceil +from typing import Dict, List, Optional, Sequence, Tuple + +from ._types import * # noqa: F401,F403 +from ._helpers import * # noqa: F401,F403 + + +class TensorManager: + """Manage logical tensors, tiles, slices, and tensor-thread grouping. + + TensorManager operates on logical objects only. It owns shape flattening, + tile metadata, slice resolution, and `mapt` grouping. It deliberately does + not create ValueTile / ValueTileView objects and does not decide + residency placement; that work stays in ValueManager. + """ + + def __init__(self, program: "TileTensorProgram") -> None: + self.program = program + self.isa_emitter = program.isa_emitter + self.inputs: Dict[str, Input] = {} + self.tensors: Dict[str, Tensor] = {} + self.vectors: Dict[str, Vector] = {} + self.fp_fragments: Dict[str, FPFragment] = {} + self.input_tiles: Dict[str, InputTile] = {} + self.tensor_tiles: Dict[str, TensorTile] = {} + self.vector_tiles: Dict[str, VectorTile] = {} + self._input_tile_counter = 0 + self._tensor_tile_counter = 0 + # FPVar management: one FP_MEM slot per scalar constant. + # _fp_mem_values is ordered by address so build_fp_preload can + # return the initialisation array directly. + # Addresses [0, 32) are reserved for system/hardware constants; + # user fp_var() declarations start at address 32. + self.fp_vars: Dict[str, FPVar] = {} + self._fp_mem_values: List[float] = [0.0] * 32 + self._next_fp_mem_addr: int = 32 + self._literal_fp_vars: Dict[Tuple[str, float], FPVar] = {} + + def fp_var(self, name: str, value: float = 0.0, size: int = 1) -> FPVar | FPFragment: + """Allocate FP-domain storage. + + The new default rule is that one FPVar represents one scalar slot. + For compatibility, requesting size > 1 returns one FPFragment whose + cells are backed by one scalar FPVar each. + + Usage: + scale = program.fp_var("scale", value=1.0 / math.sqrt(dim)) + """ + if size <= 0: + raise ValueError(f"FP allocation size must be positive, got {size}") + if size != 1: + fragment = self.fp_fragment(name=name, shape=(int(size),), init=value) + return fragment + if name in self.fp_vars: + raise ValueError(f"FPVar {name!r} already declared") + addr = self._next_fp_mem_addr + self._next_fp_mem_addr += 1 + var = FPVar(name=name, fp_mem_addr=addr) + self.fp_vars[name] = var + self._fp_mem_values.append(float(value)) + return var + + def fp_fragment( + self, + name: str, + shape: Tuple[int, ...] | int, + *, + init: float = 0.0, + dtype: str = "fp32", + ) -> FPFragment: + if isinstance(shape, int): + shape = (shape,) + normalized_shape = tuple(int(dim) for dim in shape) + if not normalized_shape or any(dim <= 0 for dim in normalized_shape): + raise ValueError(f"FPFragment shape must contain positive extents, got {shape}") + if name in self.fp_fragments or name in self.fp_vars: + raise ValueError(f"FPFragment {name!r} already declared") + + fragment = FPFragment(program=self.program, name=name, shape=normalized_shape, dtype=dtype) + for index in _iter_fp_indices(normalized_shape): + cell_name = f"{name}{_format_fp_index(index)}" + fragment.vars[index] = self.fp_var(cell_name, value=init, size=1) # type: ignore[assignment] + + self.fp_fragments[name] = fragment + return fragment + + def alloc_fragment( + self, + name: str, + logical_shape: LogicalShape, + *, + init_zero: bool = False, + dtype: str = "fp32", + ) -> Tensor | Vector: + return self._alloc_tensor_storage( + name=name, + logical_shape=logical_shape, + init_zero=init_zero, + dtype=dtype, + vram_class="l0", + ) + + def alloc_shared( + self, + name: str, + logical_shape: LogicalShape, + *, + init_zero: bool = False, + dtype: str = "fp32", + ) -> Tensor | Vector: + return self._alloc_tensor_storage( + name=name, + logical_shape=logical_shape, + init_zero=init_zero, + dtype=dtype, + vram_class="shared", + ) + + def _alloc_tensor_storage( + self, + *, + name: str, + logical_shape: LogicalShape, + init_zero: bool, + dtype: str, + vram_class: str, + ) -> Tensor | Vector: + if len(logical_shape) == 4: + tensor = self.tensor(name, logical_shape) + tensor.metadata["fragment_kind"] = "tensor" + tensor.metadata["dtype"] = dtype + tensor.metadata["init_zero"] = bool(init_zero) + tensor.metadata["vram_class"] = vram_class + for tile in tensor.tiles.values(): + tile.metadata["vram_class"] = vram_class + return tensor + if len(logical_shape) == 3: + vector = self.vector(name, logical_shape) + vector.metadata["fragment_kind"] = "vector" + vector.metadata["dtype"] = dtype + vector.metadata["init_zero"] = bool(init_zero) + vector.metadata["vram_class"] = vram_class + for tile in vector.tiles.values(): + tile.metadata["vram_class"] = vram_class + return vector + raise NotImplementedError( + f"tensor storage allocation supports 4D tensor fragments and 3D vector fragments only, got {logical_shape}" + ) + + def mapf(self, operand: object) -> List[FPVar]: + if isinstance(operand, (int, float)): + literal_value = float(operand) + key = ("fp32", literal_value) + literal_var = self._literal_fp_vars.get(key) + if literal_var is None: + literal_name = self.program._auto_name("fp_literal") + created = self.fp_var(literal_name, value=literal_value, size=1) + if not isinstance(created, FPVar): + raise RuntimeError("literal fp allocation expected one FPVar") + literal_var = created + self._literal_fp_vars[key] = literal_var + return [literal_var] + if isinstance(operand, FPVar): + return [operand] + if isinstance(operand, FPFragment): + return [operand.vars[index] for index in _iter_fp_indices(operand.shape)] + if isinstance(operand, FPFragmentSlice): + return self._resolve_fp_fragment_slice(operand.base, operand.selectors) + if isinstance(operand, Vector): + return self.program.vector_manager.resolve_vector_fp_vars(operand) + if isinstance(operand, VectorSlice): + return self.program.vector_manager.resolve_vector_slice_fp_vars(operand) + if isinstance(operand, VectorTile): + return self.program.vector_manager.resolve_vector_tile_fp_vars(operand) + if isinstance(operand, ElementRef): + return [self._resolve_element_fpvar(operand)] + if isinstance(operand, (list, tuple)): + resolved: List[FPVar] = [] + for item in operand: + resolved.extend(self.mapf(item)) + return resolved + raise NotImplementedError(f"Unsupported operand for mapf: {type(operand).__name__}") + + def mapf_dst(self, operand: object, *, control: str, src1_vars: Optional[Sequence[FPVar]] = None) -> List[FPVar]: + if isinstance(operand, (list, tuple)): + resolved: List[FPVar] = [] + for item in operand: + resolved.extend(self.mapf_dst(item, control=control, src1_vars=src1_vars)) + return resolved + return self.mapf(operand) + + def _resolve_element_operand_context( + self, + operand: ElementRef, + ) -> Tuple[object, Tuple[int, ...], TileLike, int, int]: + base = operand.base + logical_shape = tuple(getattr(base, "logical_shape", ())) + if not logical_shape: + raise RuntimeError(f"ElementRef base {type(base).__name__} does not expose logical_shape") + if len(operand.indices) != len(logical_shape): + raise RuntimeError( + f"ElementRef expected {len(logical_shape)} indices for {type(base).__name__}, got {len(operand.indices)}" + ) + + normalized_indices = tuple(_normalize_index(index, extent) for index, extent in zip(operand.indices, logical_shape)) + physical_row, physical_col = _logical_indices_to_physical_coord(logical_shape, normalized_indices) + tile_coord = (physical_row // self.program.mlen, physical_col // self.program.mlen) + tile_col_start = tile_coord[1] * self.program.mlen + tile_row_start = tile_coord[0] * self.program.mlen + + tiles = getattr(base, "tiles", None) + if not isinstance(tiles, dict): + raise RuntimeError(f"ElementRef base {type(base).__name__} does not expose tiles") + tile = tiles.get(tile_coord) + if not isinstance(tile, (TensorTile, InputTile, VectorTile)): + raise RuntimeError( + f"ElementRef {getattr(base, 'name', type(base).__name__)}{normalized_indices} " + f"did not resolve to one tile at coord={tile_coord}" + ) + return ( + base, + normalized_indices, + tile, + int(physical_row - tile_row_start), + int(physical_col - tile_col_start), + ) + + def _ensure_element_tile_fp_fragment( + self, + *, + base: object, + normalized_indices: Tuple[int, ...], + tile: TensorTile | InputTile, + ) -> FPFragment: + backing_value = self.program.value_manager.resolve_value_tile(tile) + if backing_value.residency.get("fpram_ready"): + return self.program.value_manager._resolve_value_fp_fragment(backing_value) + + has_materialized_storage = any( + backing_value.residency.get(key) is not None + for key in ("vram_addr", "mram_addr", "hbm_addr") + ) or bool(backing_value.residency.get("hbm_ready")) + if has_materialized_storage: + raise RuntimeError( + "ElementRef write requires one FP-backed tile before mutating materialized tensor storage; " + f"tile={tile.tile_id} base={getattr(base, 'name', type(base).__name__)} indices={normalized_indices}" + ) + + fragment_name = self.program._auto_name(f"{getattr(base, 'name', 'tensor')}.element_fp_tile") + zero_var = self.mapf(0.0)[0] + fragment = FPFragment( + program=self.program, + name=fragment_name, + shape=tile.tile_shape, + dtype="fp32", + ) + for fp_index in _iter_fp_indices(tile.tile_shape): + fragment.vars[fp_index] = zero_var + self.fp_fragments[fragment_name] = fragment + self.program.create_value_tile_in_fpram( + tile, + fragment, + bind=True, + metadata={ + "element_ref_direct_backing": True, + "source_tensor": getattr(base, "name", type(base).__name__), + "source_tile_id": tile.tile_id, + }, + ) + return fragment + + def _element_fragment_and_index( + self, + operand: ElementRef, + *, + ensure_write_backing: bool = False, + ) -> Tuple[FPFragment, FPIndex, object, Tuple[int, ...], TileLike]: + base, normalized_indices, tile, local_row, local_col = self._resolve_element_operand_context(operand) + if isinstance(tile, VectorTile): + fragment = self.program.vector_manager.resolve_fp_fragment(tile) + else: + backing_value = self.program.value_manager.resolve_value_tile(tile) + if ensure_write_backing: + fragment = self._ensure_element_tile_fp_fragment( + base=base, + normalized_indices=normalized_indices, + tile=tile, + ) + elif not backing_value.residency.get("fpram_ready"): + raise RuntimeError( + f"ElementRef {getattr(base, 'name', type(base).__name__)}{normalized_indices} requires one fpram-backed " + f"value tile; backing value {backing_value.value_tile_id} is no longer resident in fpram" + ) + else: + fragment = self.program.value_manager._resolve_value_fp_fragment(backing_value) + fp_index = _physical_tile_coord_to_fp_index( + fragment.shape, + local_row=local_row, + local_col=local_col, + mlen=self.program.mlen, + btmm_hlen=self.program.btmm_hlen, + ) + return fragment, fp_index, base, normalized_indices, tile + + def _resolve_element_fpvar(self, operand: ElementRef, *, create_for_write: bool = False) -> FPVar: + fragment, fp_index, base, normalized_indices, _tile = self._element_fragment_and_index( + operand, + ensure_write_backing=create_for_write, + ) + fp_var = fragment.vars.get(fp_index) + if not isinstance(fp_var, FPVar): + raise RuntimeError( + f"ElementRef {getattr(base, 'name', type(base).__name__)}{normalized_indices} resolved to missing fp cell {fp_index}" + ) + return fp_var + + def bind_element_pointer(self, operand: ElementRef, fp_var: FPVar, *, mode: str = "alias") -> FPVar: + fragment, fp_index, base, normalized_indices, tile = self._element_fragment_and_index( + operand, + ensure_write_backing=True, + ) + fragment.vars[fp_index] = fp_var + return fp_var + + def allocate_element_result_fpvar(self, operand: ElementRef) -> FPVar: + _fragment, _fp_index, base, normalized_indices, tile = self._element_fragment_and_index( + operand, + ensure_write_backing=True, + ) + created = self.fp_var( + self.program._auto_name(f"{getattr(base, 'name', 'tensor')}.element_fp"), + value=0.0, + size=1, + ) + if not isinstance(created, FPVar): + raise RuntimeError("ElementRef result allocation expected one scalar FPVar") + return created + + def mapf_t(self, tensor_operand: object, fp_operand: object, *, control: str = "mixed") -> Dict[str, object]: + tensor_tiles = self.mapt([tensor_operand, 0]) if tensor_operand is not None else [] + fp_vars = self.mapf(fp_operand) + packet = { + "control": control, + "tensor_operand": tensor_operand, + "tensor_groups": tensor_tiles, + "fp_operand": fp_operand, + "fp_vars": fp_vars, + } + return packet + + def _resolve_fp_fragment_slice( + self, + fragment: FPFragment, + selectors: Tuple[SliceItem, ...], + ) -> List[FPVar]: + normalized = list(selectors) + [slice(None)] * max(0, len(fragment.shape) - len(selectors)) + selected_indices: List[FPIndex] = [] + for index in _iter_fp_indices(fragment.shape): + keep = True + for dim_idx, selector in enumerate(normalized[: len(fragment.shape)]): + start, stop = _slice_item_to_range(selector, fragment.shape[dim_idx]) + if index[dim_idx] < start or index[dim_idx] >= stop: + keep = False + break + if keep: + selected_indices.append(index) + return [fragment.vars[index] for index in selected_indices] + + def _next_input_tile_id(self) -> str: + tile_id = f"input_tile.{self._input_tile_counter}" + self._input_tile_counter += 1 + return tile_id + + def _next_tensor_tile_id(self) -> str: + tile_id = f"tensor_tile.{self._tensor_tile_counter}" + self._tensor_tile_counter += 1 + return tile_id + + def create_input_tiles(self, input_name: str, logical_shape: LogicalShape) -> Dict[TileCoord, InputTile]: + rows, cols = _logical_shape_to_physical_shape(logical_shape) + row_blocks = ceil(rows / self.program.mlen) + col_blocks = ceil(cols / self.program.mlen) + tiles: Dict[TileCoord, InputTile] = {} + for row_block in range(row_blocks): + for col_block in range(col_blocks): + row_count = min(self.program.mlen, rows - row_block * self.program.mlen) + col_count = min(self.program.mlen, cols - col_block * self.program.mlen) + input_tile = InputTile( + tile_id=self._next_input_tile_id(), + input_name=input_name, + coord=(row_block, col_block), + tile_shape=(row_count, col_count), + metadata=self._build_tile_metadata(logical_shape, row_block, col_block, row_count, col_count), + ) + tiles[(row_block, col_block)] = input_tile + self.input_tiles[input_tile.tile_id] = input_tile + return tiles + + def create_tensor_tiles(self, tensor_name: str, logical_shape: LogicalShape) -> Dict[TileCoord, TensorTile]: + rows, cols = _logical_shape_to_physical_shape(logical_shape) + row_blocks = ceil(rows / self.program.mlen) + col_blocks = ceil(cols / self.program.mlen) + tiles: Dict[TileCoord, TensorTile] = {} + for row_block in range(row_blocks): + for col_block in range(col_blocks): + row_count = min(self.program.mlen, rows - row_block * self.program.mlen) + col_count = min(self.program.mlen, cols - col_block * self.program.mlen) + tensor_tile = TensorTile( + tile_id=self._next_tensor_tile_id(), + tensor_name=tensor_name, + coord=(row_block, col_block), + tile_shape=(row_count, col_count), + metadata=self._build_tile_metadata(logical_shape, row_block, col_block, row_count, col_count), + ) + tiles[(row_block, col_block)] = tensor_tile + self.tensor_tiles[tensor_tile.tile_id] = tensor_tile + return tiles + + def _build_tile_metadata( + self, + logical_shape: LogicalShape, + row_block: int, + col_block: int, + row_count: int, + col_count: int, + ) -> Dict[str, object]: + """Build per-tile logical metadata used by later grouping/mapping stages. + + For 4D BSHD tensors, the current convention treats one physical tile as + one logical window over flattened `(seq, head * head_dim)` storage. + When `head_dim < mlen`, one physical tile may pack multiple adjacent + heads. The metadata below records both views: + + - per-head view: `head_index`, `head_col_offset`, `d_tile_index` + - packed-group view: `group_head_start`, `packed_head_count` + - scatter layout view: `grouped_narrow`, `packed_head_group`, + `scatter_slot_width` + + Downstream `mapt_head_group`, scatter-group matmul, and group-head + elementwise paths all rely on these fields instead of re-deriving the + packing rules independently. + """ + metadata: Dict[str, object] = { + "mlen": self.program.mlen, + "logical_shape": logical_shape, + "row_block": row_block, + "col_block": col_block, + "row_count": row_count, + "col_count": col_count, + "tile_width_class": "narrow" if int(col_count) < int(self.program.mlen) else "full", + } + if len(logical_shape) == 4: + b, s, h, d = logical_shape + if int(b) > 1 and int(s) % int(self.program.mlen) != 0: + raise ValueError( + f"BSHD tensors with batch>1 require S to be a multiple of mlen={self.program.mlen}; " + f"got shape={logical_shape}" + ) + row_blocks_per_batch = ( + max(1, int(s) // int(self.program.mlen)) + if int(s) % int(self.program.mlen) == 0 + else max(1, ceil(int(s) / int(self.program.mlen))) + ) + batch_index = int(row_block) // row_blocks_per_batch + seq_block = int(row_block) % row_blocks_per_batch + seq_start = seq_block * int(self.program.mlen) + seq_end = min(int(s), seq_start + int(row_count)) + physical_col_start = col_block * self.program.mlen + head_index = physical_col_start // d if d > 0 else 0 + head_col_offset = physical_col_start % d if d > 0 else 0 + grouped_narrow = d > 0 and d < self.program.mlen + packed_head_count = min(max(self.program.mlen // d, 1), max(h - head_index, 0)) if grouped_narrow else 1 + metadata.update( + { + "layout": "bshd", + "batch": b, + "seq": s, + "heads": h, + "head_dim": d, + "batch_index": batch_index, + "seq_block": seq_block, + "seq_start": seq_start, + "seq_end": seq_end, + "row_blocks_per_batch": row_blocks_per_batch, + "head_index": head_index, + "head_col_offset": head_col_offset, + "d_tile_index": head_col_offset // self.program.mlen if self.program.mlen > 0 else 0, + "grouped_narrow": grouped_narrow, + "packed_head_group": grouped_narrow, + "tile_width_class": "narrow" if grouped_narrow or int(col_count) < int(self.program.mlen) else "full", + "group_head_start": head_index, + "packed_head_count": packed_head_count, + "scatter_slot_width": d if grouped_narrow else col_count, + } + ) + elif len(logical_shape) == 3: + x, y, z = logical_shape + metadata.update( + { + "layout": "vector3d", + "vector_extents": (x, y, z), + "vector_row_dim": x, + "vector_col_dims": (y, z), + } + ) + else: + metadata["layout"] = "2d" + return metadata + + def input(self, name: str, logical_shape: LogicalShape, *, hbm_addr: Optional[int] = None) -> Input: + physical_shape = _logical_shape_to_physical_shape(logical_shape) + hbm_group_name = f"{name}.hbm" + if hbm_group_name not in self.program.hardware.hbm_objects: + self.program.add_hbm_object(hbm_group_name, physical_shape, hbm_addr=hbm_addr) + input_obj = Input(program=self.program, name=name, logical_shape=logical_shape) + input_obj.metadata["hbm_group_obj"] = hbm_group_name + self.inputs[name] = input_obj + return input_obj + + def tensor(self, name: str, logical_shape: LogicalShape) -> Tensor | Vector: + if len(logical_shape) == 3: + return self.vector(name, logical_shape) + tensor = Tensor(program=self.program, name=name, logical_shape=logical_shape) + self.tensors[name] = tensor + return tensor + + def vector(self, name: str, logical_shape: LogicalShape) -> Vector: + return self.program.vector_manager.vector(name, logical_shape) + + def mapt(self, signal: List[object]) -> List[object]: + """Group logical tensor tiles into per-thread compute packets. + + `mapt` is the logical staging step before value resolution. Depending + on the control mode, it can: + + - enumerate tiles directly for copy / elementwise paths + - build BSHD matmul groups + - build head-group packets for grouped-narrow tensors + - build BTMM/QKT-specific thread packets + + The output is intentionally still a tensor-layer structure. Value/scatter + objects are resolved later by `mapv`, not here. + """ + if len(signal) == 2: + operand, control = signal + if control == 0: + resolved_tiles = self._resolve_tiles_from_operand(operand) + return [[tile] for tile in resolved_tiles] + if control == "head_group": + return self.mapt_head_group(operand) + resolved_tiles = self._resolve_tiles_from_operand(operand) + raise NotImplementedError(f"Basic mapt resolve does not support control={control!r}") + + src1, src2, dst, control = signal + if control not in (0, 1): + raise NotImplementedError(f"Unsupported mapt control: {control}") + if ( + len(getattr(src1, "logical_shape", ())) == 4 + and len(getattr(src2, "logical_shape", ())) == 4 + and len(getattr(dst, "logical_shape", ())) == 4 + ): + if control == 1: + return self.mapt_btmm_head_group_qkt(src1, src2, dst) # type: ignore[return-value] + return self._mapt_bshd_matmul_groups(src1, src2, dst) + + src1_tiles = _tiles_in_grid_order(src1.tiles) + src2_tiles = _tiles_in_grid_order(src2.tiles) + dst_tiles = _tiles_in_grid_order(dst.tiles) + groups: List[List[object]] = [] + for dst_tile in dst_tiles: + lhs_group = [tile for tile in src1_tiles if tile.coord[0] == dst_tile.coord[0]] + rhs_group = [tile for tile in src2_tiles if tile.coord[1] == dst_tile.coord[1]] + groups.append([*lhs_group, *rhs_group, dst_tile]) + return groups + + def mapt_head_group(self, operand: object) -> List[Dict[str, object]]: + resolved_tiles = self._resolve_tiles_from_operand(operand) + if not resolved_tiles: + return [] + if not all(isinstance(tile, (TensorTile, InputTile, VectorTile)) for tile in resolved_tiles): + raise RuntimeError("mapt_head_group expects tile operands only") + + first_tile = resolved_tiles[0] + logical_shape = getattr(getattr(operand, "base", operand), "logical_shape", ()) + if len(logical_shape) != 4: + return [ + { + "control": "head_group", + "tiles": [tile], + "row_block": int(tile.metadata.get("row_block", tile.coord[0])), + "group_start": int(tile.metadata.get("head_index", 0)), + "group_heads": 1, + "lane_heads": [int(tile.metadata.get("head_index", 0))], + "group_key": ( + int(tile.metadata.get("row_block", tile.coord[0])), + int(tile.metadata.get("head_index", 0)), + ), + } + for tile in resolved_tiles + ] + + groups: Dict[Tuple[int, int], Dict[str, object]] = {} + for tile in resolved_tiles: + row_block = int(tile.metadata.get("row_block", tile.coord[0])) + group_start = int(tile.metadata.get("group_head_start", tile.metadata.get("head_index", 0))) + packed_head_count = int(tile.metadata.get("packed_head_count", 1)) + lane_heads = [int(tile.metadata.get("head_index", 0))] + if packed_head_count > 1: + lane_heads = [group_start + lane for lane in range(packed_head_count)] + group_key = (row_block, group_start) + packet = groups.get(group_key) + if packet is None: + packet = { + "control": "head_group", + "tiles": [], + "row_block": row_block, + "group_start": group_start, + "group_heads": 0, + "lane_heads": [], + "group_key": group_key, + } + groups[group_key] = packet + packet["tiles"].append(tile) + existing_heads = set(packet["lane_heads"]) + for head in lane_heads: + if head not in existing_heads: + packet["lane_heads"].append(head) + existing_heads.add(head) + packet["group_heads"] = len(packet["lane_heads"]) + + packets = list(groups.values()) + packets.sort(key=lambda item: (int(item["row_block"]), int(item["group_start"]))) + return packets + + def _mapt_bshd_matmul_groups(self, src1: object, src2: object, dst: object) -> List[List[object]]: + src1_shape = tuple(getattr(src1, "logical_shape", ())) + src2_shape = tuple(getattr(src2, "logical_shape", ())) + dst_shape = tuple(getattr(dst, "logical_shape", ())) + if src1_shape[0] != src2_shape[0] or src1_shape[0] != dst_shape[0]: + raise ValueError( + f"BSHD matmul requires matched batch size, got src1={src1_shape[0]} " + f"src2={src2_shape[0]} dst={dst_shape[0]}" + ) + src1_tiles = _tiles_in_grid_order(src1.tiles) + src2_tiles = _tiles_in_grid_order(src2.tiles) + dst_tiles = _tiles_in_grid_order(dst.tiles) + src1_by_batch_head_seq_k: Dict[Tuple[int, int, int, int], object] = {} + src2_by_batch_head_k_col: Dict[Tuple[int, int, int, int], object] = {} + groups: List[List[object]] = [] + + for tile in src1_tiles: + batch_index = _bshd_tile_batch_index(tile) + head_index = int(tile.metadata.get("head_index", 0)) + seq_block = _bshd_tile_seq_block(tile) + k_index = int(tile.metadata.get("d_tile_index", tile.coord[1])) + src1_by_batch_head_seq_k[(batch_index, head_index, seq_block, k_index)] = tile + + for tile in src2_tiles: + batch_index = _bshd_tile_batch_index(tile) + head_index = int(tile.metadata.get("head_index", 0)) + k_index = _bshd_tile_seq_block(tile) + d_tile_index = int(tile.metadata.get("d_tile_index", 0)) + src2_by_batch_head_k_col[(batch_index, head_index, k_index, d_tile_index)] = tile + + for dst_tile in dst_tiles: + batch_index = _bshd_tile_batch_index(dst_tile) + head_index = int(dst_tile.metadata.get("head_index", 0)) + seq_block = _bshd_tile_seq_block(dst_tile) + d_tile_index = int(dst_tile.metadata.get("d_tile_index", 0)) + lhs_candidates = [ + key + for key in src1_by_batch_head_seq_k.keys() + if key[0] == batch_index and key[1] == head_index and key[2] == seq_block + ] + k_values = sorted(key[3] for key in lhs_candidates) + group: List[object] = [] + for k_index in k_values: + lhs_tile = src1_by_batch_head_seq_k.get((batch_index, head_index, seq_block, k_index)) + rhs_tile = src2_by_batch_head_k_col.get((batch_index, head_index, k_index, d_tile_index)) + if lhs_tile is None or rhs_tile is None: + continue + group.append([lhs_tile, rhs_tile]) + group.append([dst_tile]) + groups.append(group) + return groups + + def mapt_btmm_head_group_qkt( + self, + src1: object, + src2: object, + dst: object, + ) -> List[BTMMHeadGroupThread]: + if not ( + len(getattr(src1, "logical_shape", ())) == 4 + and len(getattr(src2, "logical_shape", ())) == 4 + and len(getattr(dst, "logical_shape", ())) == 4 + ): + raise NotImplementedError("mapt_btmm_head_group_qkt currently supports BSHD tensors only") + + src1_batch, src1_seq, src1_heads, src1_dim = getattr(src1, "logical_shape") + src2_batch, src2_seq, src2_heads, src2_dim = getattr(src2, "logical_shape") + dst_batch, dst_seq, dst_heads, dst_dim = getattr(dst, "logical_shape") + if src1_batch != src2_batch or src1_batch != dst_batch: + raise ValueError( + f"BTMM QKT mapt requires matched batch size, got src1={src1_batch} " + f"src2={src2_batch} dst={dst_batch}" + ) + if src1_heads != src2_heads or src1_heads != dst_heads: + raise ValueError( + f"BTMM QKT mapt requires matched head count, got src1={src1_heads} src2={src2_heads} dst={dst_heads}" + ) + if src1_dim != self.program.btmm_hlen or src2_dim != self.program.btmm_hlen: + raise ValueError( + f"BTMM QKT mapt requires src1/src2 head_dim == btmm_hlen == {self.program.btmm_hlen}, " + f"got src1={src1_dim} src2={src2_dim}" + ) + if src1_seq != dst_seq: + raise ValueError(f"BTMM QKT mapt requires dst seq to match src1 seq, got dst={dst_seq} src1={src1_seq}") + if src2_seq != dst_dim: + raise ValueError( + f"BTMM QKT mapt requires dst last dim to match src2 seq, got dst={dst_dim} src2={src2_seq}" + ) + if dst_dim % self.program.mlen != 0: + raise ValueError( + f"BTMM QKT mapt requires dst last dim multiple of mlen={self.program.mlen}, got {dst_dim}" + ) + + lhs_tiles = _tiles_in_grid_order(src1.tiles) + rhs_tiles = _tiles_in_grid_order(src2.tiles) + dst_tiles = _tiles_in_grid_order(dst.tiles) + lhs_groups: Dict[Tuple[int, int, int], TileLike] = {} + rhs_groups: Dict[Tuple[int, int, int], TileLike] = {} + dst_by_key: Dict[Tuple[int, int, int, int], TileLike] = {} + threads: List[BTMMHeadGroupThread] = [] + + for tile in lhs_tiles: + batch_index = _bshd_tile_batch_index(tile) + seq_block = _bshd_tile_seq_block(tile) + group_block = int(tile.coord[1]) + lhs_groups[(batch_index, seq_block, group_block)] = tile + + for tile in rhs_tiles: + batch_index = _bshd_tile_batch_index(tile) + seq_block = _bshd_tile_seq_block(tile) + group_block = int(tile.coord[1]) + rhs_groups[(batch_index, seq_block, group_block)] = tile + + dst_col_blocks_per_head = dst_dim // self.program.mlen + for tile in dst_tiles: + batch_index = _bshd_tile_batch_index(tile) + seq_block = _bshd_tile_seq_block(tile) + head_index = int(tile.metadata.get("head_index", 0)) + rhs_row_block = int(tile.coord[1]) - head_index * dst_col_blocks_per_head + dst_by_key[(batch_index, seq_block, rhs_row_block, head_index)] = tile + + group_heads = self.program.btmm_lane_count + q_row_blocks = max(1, ceil(src1_seq / self.program.mlen)) + k_row_blocks = max(1, ceil(src2_seq / self.program.mlen)) + group_blocks = max(1, ceil(src1_heads / group_heads)) + + for batch_index in range(int(src1_batch)): + for lhs_row_block in range(q_row_blocks): + for rhs_row_block in range(k_row_blocks): + for group_block in range(group_blocks): + lhs_tile = lhs_groups.get((batch_index, lhs_row_block, group_block)) + rhs_tile = rhs_groups.get((batch_index, rhs_row_block, group_block)) + if lhs_tile is None or rhs_tile is None: + continue + + head_start = group_block * group_heads + dst_group_tiles: List[TileLike] = [] + lane_heads: List[int] = [] + for lane in range(group_heads): + head_index = head_start + lane + if head_index >= dst_heads: + break + dst_tile = dst_by_key.get((batch_index, lhs_row_block, rhs_row_block, head_index)) + if dst_tile is None: + continue + lane_heads.append(head_index) + dst_group_tiles.append(dst_tile) + + if not dst_group_tiles: + continue + + threads.append( + { + "control": "tensor_tile_group", + "lhs_tiles": [lhs_tile], + "rhs_tiles": [rhs_tile], + "dst_tiles": dst_group_tiles, + "batch_index": batch_index, + "group_block": group_block, + "group_start": head_start, + "group_heads": len(dst_group_tiles), + "lane_heads": lane_heads, + "lhs_row_block": lhs_row_block, + "rhs_row_block": rhs_row_block, + } + ) + return threads + + def mapt_view_matmul( + self, + src1: object, + src2: object, + dst: object, + ) -> List[ViewMatmulThread]: + if not ( + len(getattr(src1, "logical_shape", ())) == 4 + and len(getattr(src2, "logical_shape", ())) == 4 + and len(getattr(dst, "logical_shape", ())) == 4 + ): + raise NotImplementedError("mapt_view_matmul currently supports BSHD tensors only") + + src1_batch = int(getattr(src1, "logical_shape", ())[0]) + src2_batch = int(getattr(src2, "logical_shape", ())[0]) + dst_batch = int(getattr(dst, "logical_shape", ())[0]) + if src1_batch != src2_batch or src1_batch != dst_batch: + raise ValueError( + f"scatter-group mapt requires matched batch size, got src1={src1_batch} " + f"src2={src2_batch} dst={dst_batch}" + ) + src1_head_dim = int(getattr(src1, "logical_shape", ())[-1]) + src2_head_dim = int(getattr(src2, "logical_shape", ())[-1]) + dst_head_dim = int(getattr(dst, "logical_shape", ())[-1]) + if src2_head_dim <= 0 or self.program.mlen % src2_head_dim != 0: + raise ValueError( + f"scatter-group mapt requires src2 head_dim to divide mlen={self.program.mlen}, got {src2_head_dim}" + ) + if dst_head_dim != src2_head_dim: + raise ValueError( + f"scatter-group mapt expects dst head_dim == src2 head_dim, got dst={dst_head_dim} src2={src2_head_dim}" + ) + + group_heads = self.program.mlen // src2_head_dim + src1_by_batch_head_seq_k: Dict[Tuple[int, int, int, int], object] = {} + src2_by_batch_seq_group: Dict[Tuple[int, int, int], object] = {} + threads: List[ViewMatmulThread] = [] + + for tile in _tiles_in_grid_order(src1.tiles): + batch_index = _bshd_tile_batch_index(tile) + head_index = int(tile.metadata.get("head_index", 0)) + seq_block = _bshd_tile_seq_block(tile) + k_index = int(tile.metadata.get("d_tile_index", tile.coord[1])) + src1_by_batch_head_seq_k[(batch_index, head_index, seq_block, k_index)] = tile + + for tile in _tiles_in_grid_order(src2.tiles): + batch_index = _bshd_tile_batch_index(tile) + seq_block = _bshd_tile_seq_block(tile) + group_block = int(tile.coord[1]) + src2_by_batch_seq_group[(batch_index, seq_block, group_block)] = tile + + for dst_tile in _tiles_in_grid_order(dst.tiles): + batch_index = _bshd_tile_batch_index(dst_tile) + seq_block = _bshd_tile_seq_block(dst_tile) + group_block = int(dst_tile.coord[1]) + group_start = group_block * group_heads + lane_heads: List[int] = [] + lhs_candidates: List[List[object]] = [] + + for lane in range(group_heads): + head_index = group_start + lane + lane_k_tiles = [ + tile + for (tile_batch, tile_head, tile_seq, _), tile in src1_by_batch_head_seq_k.items() + if tile_batch == batch_index and tile_head == head_index and tile_seq == seq_block + ] + if not lane_k_tiles: + continue + lane_heads.append(head_index) + lhs_candidates.append(sorted(lane_k_tiles, key=lambda tile: int(tile.metadata.get("d_tile_index", 0)))) + + rhs_terms: List[ViewMatmulTerm] = [] + rhs_row_blocks = sorted( + row + for (tile_batch, row, col_group) in src2_by_batch_seq_group.keys() + if tile_batch == batch_index and col_group == group_block + ) + for rhs_row_block in rhs_row_blocks: + rhs_tile = src2_by_batch_seq_group.get((batch_index, rhs_row_block, group_block)) + if rhs_tile is None: + continue + term_lhs_tiles: List[object] = [] + for lane_tiles in lhs_candidates: + if rhs_row_block >= len(lane_tiles): + term_lhs_tiles = [] + break + term_lhs_tiles.append(lane_tiles[rhs_row_block]) + if not term_lhs_tiles: + continue + rhs_terms.append((term_lhs_tiles, rhs_tile)) + + threads.append((dst_tile, rhs_terms, group_start)) + return threads + + def mapt_back(self, signal_4: List[object], signal_1: List[object]) -> object: + if not signal_1: + return None + if signal_4: + controls = { + item.get("control") + for item in signal_4 + if isinstance(item, dict) and item.get("control") is not None + } + if len(controls) > 1: + raise RuntimeError(f"mapt_back received mixed map controls: {sorted(controls)}") + dst_tile = self._extract_dst_tile_from_group(signal_1[0]) + if dst_tile is None: + return None + if isinstance(dst_tile, TensorTile): + return self.tensors.get(dst_tile.tensor_name) or self.vectors.get(dst_tile.tensor_name) or self.inputs.get(dst_tile.tensor_name) + if isinstance(dst_tile, InputTile): + return self.inputs.get(dst_tile.input_name) + return None + + def _extract_dst_tile_from_group(self, group: object) -> Optional[object]: + if isinstance(group, dict): + dst_tile = group.get("dst_tile") + if isinstance(dst_tile, (TensorTile, InputTile, VectorTile)): + return dst_tile + dst_tiles = group.get("dst_tiles") + if isinstance(dst_tiles, list): + for item in dst_tiles: + if isinstance(item, (TensorTile, InputTile, VectorTile)): + return item + return None + if not isinstance(group, list) or not group: + return None + tail = group[-1] + if isinstance(tail, list) and len(tail) == 1 and isinstance(tail[0], (TensorTile, InputTile, VectorTile)): + return tail[0] + if isinstance(tail, (TensorTile, InputTile, VectorTile)): + return tail + return None + + def _resolve_tiles_from_operand(self, operand: object) -> List[object]: + if isinstance(operand, Input): + return _tiles_in_grid_order(operand.tiles) + if isinstance(operand, Tensor): + return _tiles_in_grid_order(operand.tiles) + if isinstance(operand, Vector): + return _tiles_in_grid_order(operand.tiles) + if isinstance(operand, InputSlice): + return self._resolve_slice_tiles(operand.base.tiles, operand.base.logical_shape, operand.selectors) + if isinstance(operand, TensorSlice): + return self._resolve_slice_tiles(operand.base.tiles, operand.base.logical_shape, operand.selectors) + if isinstance(operand, VectorSlice): + return self._resolve_slice_tiles(operand.base.tiles, operand.base.logical_shape, operand.selectors) + if isinstance(operand, (InputTile, TensorTile, VectorTile)): + return [operand] + raise NotImplementedError(f"Unsupported operand for mapt(control=0): {type(operand).__name__}") + + def _resolve_slice_tiles( + self, + tiles: Dict[TileCoord, object], + logical_shape: LogicalShape, + selectors: Tuple[SliceItem, ...], + ) -> List[object]: + row_range, col_range = _logical_selectors_to_physical_ranges(logical_shape, selectors) + resolved: List[object] = [] + for tile in _tiles_in_grid_order(tiles): + row_block, col_block = tile.coord + row_start = row_block * self.program.mlen + row_end = row_start + tile.tile_shape[0] + col_start = col_block * self.program.mlen + col_end = col_start + tile.tile_shape[1] + if _ranges_overlap((row_start, row_end), row_range) and _ranges_overlap((col_start, col_end), col_range): + # Views are aliases: return the owner tile directly instead of + # materializing a derived tile object with independent identity. + resolved.append(tile) + return resolved diff --git a/tilelang_runtime_compier/tile_tensor_program/_thread_manager.py b/tilelang_runtime_compier/tile_tensor_program/_thread_manager.py new file mode 100644 index 0000000..0562958 --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_thread_manager.py @@ -0,0 +1,1388 @@ +"""ThreadManager: parallel thread regions, expression graphs, cache planning.""" + +from __future__ import annotations + +from math import ceil +from typing import Dict, List, Optional, Tuple + +from ._types import * # noqa: F401,F403 +from ._helpers import * # noqa: F401,F403 + + +class ThreadManager: + """Manage parallel thread regions, expression graphs, and cache planning. + + This layer is intentionally FP-first for now. It owns the symbolic + `parallel_region3d` flow and keeps the graph/cache planning state out of + `TileTensorProgram` so later lowering can evolve independently. + """ + + def __init__(self, program: "TileTensorProgram") -> None: + self.program = program + self.isa_emitter = program.isa_emitter + self._active_parallel_graphs: List[ParallelRegionGraph] = [] + self.parallel_regions: List[ParallelRegionGraph] = [] + self._parallel2d_scratch_var_cache: Dict[Tuple[int, int], List[FPVar]] = {} + + def parallel_region3d( + self, + extents: Tuple[int, int, int] | List[int], + *, + name: Optional[str] = None, + ) -> _ParallelRegionScope: + normalized = tuple(int(extent) for extent in extents) + if len(normalized) != 3 or any(extent <= 0 for extent in normalized): + raise ValueError(f"parallel_region3d expects three positive extents, got {extents}") + return _ParallelRegionScope(self.program, extents=normalized, name=name) + + def parallel_region2d( + self, + extents: Tuple[int, int] | List[int], + *, + name: Optional[str] = None, + ) -> _ParallelRegion2DScope: + normalized = tuple(int(extent) for extent in extents) + if len(normalized) != 2 or any(extent <= 0 for extent in normalized): + raise ValueError(f"parallel_region2d expects two positive extents, got {extents}") + return _ParallelRegion2DScope(self.program, extents=normalized, name=name) + + def where(self, predicate: object, on_true: object, on_false: object) -> ParallelExpr: + return ParallelExpr( + kind="select", + args=( + _coerce_parallel_expr(predicate), + _coerce_parallel_expr(on_true), + _coerce_parallel_expr(on_false), + ), + ) + + def if_then_else(self, predicate: object, on_true: object, on_false: object) -> ParallelExpr: + return self.where(predicate, on_true, on_false) + + def pair(self, axis: object) -> ParallelExpr: + # RoPE-style lane pairing helper: pair(2k)=2k+1, pair(2k+1)=2k. + return ParallelExpr(kind="pair_index", args=(_coerce_parallel_expr(axis),)) + + def half_index(self, axis: object) -> ParallelExpr: + # RoPE coefficient group helper: half_index(d)=d//2. + # Runtime planning assumes coefficients may be pre-expanded to full-lane layout. + return ParallelExpr(kind="half_index", args=(_coerce_parallel_expr(axis),)) + + def current_parallel_graph(self) -> ParallelRegionGraph: + if not self._active_parallel_graphs: + raise RuntimeError("parallel graph write requested outside active parallel_region3d") + return self._active_parallel_graphs[-1] + + def record_parallel_assignment_from_index( + self, + base: object, + item: SliceItem | Tuple[SliceItem, ...], + value: object, + ) -> None: + if not isinstance(item, tuple): + item = (item,) + self.record_parallel_assignment_from_access(ParallelAccess(base=base, selectors=tuple(item)), value) + + def record_parallel_assignment_from_access(self, dst_access: ParallelAccess, value: object) -> None: + region = self.current_parallel_graph() + if not isinstance(dst_access.base, Tensor): + raise TypeError( + "parallel assignment destination must be a Tensor-backed access; " + f"got {type(dst_access.base).__name__}" + ) + if len(dst_access.selectors) != len(dst_access.logical_shape): + raise ValueError( + f"parallel assignment target must fully index rank-{len(dst_access.logical_shape)} tensor, " + f"got selectors={dst_access.selectors}" + ) + expr = _coerce_parallel_expr(value) + if region.cache_plan.get("lowering_kind") == "fp2d": + self._validate_parallel2d_fp_assignment(region, dst_access, expr) + else: + self._validate_parallel_assignment(region, dst_access, expr) + assignment = ParallelAssignment( + dst=dst_access, + expr=expr, + task_id=self.program._auto_name(f"{region.name}.assign"), + sources=_collect_parallel_accesses(expr), + ) + region.assignments.append(assignment) + + def _validate_parallel2d_fp_assignment( + self, + region: ParallelRegionGraph, + dst_access: ParallelAccess, + expr: ParallelExpr, + ) -> None: + if not isinstance(dst_access.base, Vector): + raise TypeError( + "parallel_region2d currently supports FP-backed Vector destinations only; " + f"got {type(dst_access.base).__name__}" + ) + axis_refs = [selector for selector in dst_access.selectors if isinstance(selector, ParallelAxis)] + axis_ids = {int(axis.axis) for axis in axis_refs} + expected_axis_ids = {1, 2} + if int(region.extents[1]) == 1: + expected_axis_ids = {2} + if axis_ids != expected_axis_ids: + raise ValueError( + "parallel_region2d destination must index with its active axes; " + f"got selectors={dst_access.selectors}" + ) + axis_region_ids = {axis.region_id for axis in axis_refs} + if axis_region_ids != {region.region_id}: + raise ValueError("parallel_region2d destination mixes axes from another parallel region") + self._validate_parallel_expr(expr, region=region) + + def _validate_parallel_assignment( + self, + region: ParallelRegionGraph, + dst_access: ParallelAccess, + expr: ParallelExpr, + ) -> None: + axis_refs = [selector for selector in dst_access.selectors if isinstance(selector, ParallelAxis)] + if len(axis_refs) != 3: + raise ValueError( + "parallel assignment destination must index with exactly the active 3D parallel axes; " + f"got selectors={dst_access.selectors}" + ) + axis_region_ids = {axis.region_id for axis in axis_refs} + if axis_region_ids != {region.region_id}: + raise ValueError("parallel assignment destination mixes axes from another parallel region") + self._validate_parallel_expr(expr, region=region) + + def _validate_parallel_expr( + self, + expr: ParallelExpr, + *, + region: ParallelRegionGraph, + ) -> None: + if expr.kind in {"literal", "axis", "fpvar"}: + return + if expr.kind == "load": + access = expr.value + if not isinstance(access, ParallelAccess): + raise TypeError(f"parallel load expected ParallelAccess, got {type(access).__name__}") + axis_region_ids = { + selector.region_id + for selector in access.selectors + if isinstance(selector, ParallelAxis) + } + if axis_region_ids and axis_region_ids != {region.region_id}: + raise ValueError("parallel expression mixes axes from another parallel region") + return + if expr.kind == "select": + if len(expr.args) != 3: + raise ValueError("parallel select expression expects exactly three arguments") + self._validate_parallel_predicate(expr.args[0], region=region) + for arg in expr.args[1:]: + self._validate_parallel_expr(arg, region=region) + return + if expr.kind == "op": + if expr.op not in {"add", "sub", "mul", "max"} or len(expr.args) != 2: + raise NotImplementedError( + f"parallel expression currently supports only binary add/sub/mul/max, got {expr.op!r}" + ) + for arg in expr.args: + self._validate_parallel_expr(arg, region=region) + return + if expr.kind == "unary_op": + if expr.op not in {"exp", "reci", "sqrt"} or len(expr.args) != 1: + raise NotImplementedError( + f"parallel expression currently supports unary exp/reci/sqrt, got {expr.op!r}" + ) + self._validate_parallel_expr(expr.args[0], region=region) + return + if expr.kind in {"pair_index", "half_index"}: + for arg in expr.args: + self._validate_parallel_expr(arg, region=region) + return + raise NotImplementedError(f"Unsupported parallel expression kind: {expr.kind}") + + def _validate_parallel_predicate( + self, + expr: ParallelExpr, + *, + region: ParallelRegionGraph, + ) -> None: + if expr.kind == "op" and expr.op in {"lt", "le", "gt", "ge", "eq"} and len(expr.args) == 2: + for arg in expr.args: + self._validate_parallel_index_expr(arg, region=region) + return + raise NotImplementedError( + "parallel predicates currently support only binary comparisons " + f"(lt/le/gt/ge/eq), got {expr.kind}:{getattr(expr, 'op', None)!r}" + ) + + def _validate_parallel_index_expr( + self, + expr: ParallelExpr, + *, + region: ParallelRegionGraph, + ) -> None: + if expr.kind in {"literal", "axis"}: + return + if expr.kind == "op" and expr.op in {"add", "sub", "mul", "mod"} and len(expr.args) == 2: + for arg in expr.args: + self._validate_parallel_index_expr(arg, region=region) + return + raise NotImplementedError( + f"parallel predicate index expression currently supports only axis/literal/add/sub/mul/mod, got {expr.kind}:{getattr(expr, 'op', None)!r}" + ) + + def parallel_execution_plans(self) -> List[ParallelExecutionPlan]: + plans: List[ParallelExecutionPlan] = [] + for region in self.parallel_regions: + if region.execution_plan is not None: + plans.append(region.execution_plan) + return plans + + def lower_parallel_execution_plans(self) -> None: + if self.program._parallel_execution_lowered: + return + for region in self.parallel_regions: + if region.execution_plan is None: + continue + self._emit_parallel_execution_plan(region, region.execution_plan) + self.program._parallel_execution_lowered = True + + def _emit_parallel_execution_plan( + self, + region: ParallelRegionGraph, + execution_plan: ParallelExecutionPlan, + ) -> None: + if not execution_plan.cycle_plans: + raise RuntimeError(f"parallel region {region.name} finalized without any cycle plans") + self._prepare_parallel_region_output_bindings(region) + region_output_values: Dict[str, Tuple[TensorTile, ValueTile]] = {} + for cycle_plan in execution_plan.cycle_plans: + self._emit_parallel_cycle_plan(region, cycle_plan, region_output_values=region_output_values) + for dst_tile, output_value in region_output_values.values(): + self.program.value_manager._bind_value_to_tensor_tile(output_value, dst_tile) + + def _emit_parallel2d_fp_region(self, region: ParallelRegionGraph) -> None: + if region.cache_plan.get("lowering_kind") != "fp2d": + raise RuntimeError(f"parallel2d fp lowering got non-fp2d region {region.name}") + _, head_count, lane_count = (int(extent) for extent in region.extents) + for assignment in region.assignments: + if not isinstance(assignment.dst.base, Vector): + raise TypeError( + "parallel_region2d lowering supports only FP-backed Vector destinations; " + f"got {type(assignment.dst.base).__name__}" + ) + for head_index in range(head_count): + dst_vars = self._parallel2d_access_vars( + assignment.dst, + head_index=head_index, + lane_count=lane_count, + ) + self._emit_parallel2d_fp_expr_kernel( + assignment.expr, + dst_vars=dst_vars, + head_index=head_index, + lane_count=lane_count, + task_id=f"{assignment.task_id}.h{head_index}", + ) + + def _emit_parallel2d_fp_expr_kernel( + self, + expr: ParallelExpr, + *, + dst_vars: Sequence[FPVar], + head_index: int, + lane_count: int, + task_id: str, + ) -> None: + self._parallel2d_materialize_expr_into( + expr, + head_index=head_index, + lane_count=lane_count, + task_id=task_id, + dst_vars=list(dst_vars), + temp_depth=0, + ) + + def _parallel2d_materialize_expr_into( + self, + expr: ParallelExpr, + *, + head_index: int, + lane_count: int, + task_id: str, + dst_vars: Sequence[FPVar], + temp_depth: int, + ) -> None: + if expr.kind in {"load", "fpvar", "literal"}: + leaf_vars = self._parallel2d_expr_vars(expr, head_index=head_index, lane_count=lane_count) + self.isa_emitter.emit_fp_kernel( + src1_addrs=[_require_fp_addr(var) for var in leaf_vars], + dst_addrs=[_require_fp_addr(var) for var in dst_vars], + op="copy", + task_id=task_id, + ) + return + if expr.kind == "op": + if expr.op not in {"add", "sub", "mul", "max"} or len(expr.args) != 2: + raise NotImplementedError( + f"parallel_region2d FP lowering supports binary add/sub/mul/max, got {expr.op!r}" + ) + lhs_expr, rhs_expr = expr.args + if lhs_expr.kind in {"load", "fpvar", "literal"}: + src1_vars = self._parallel2d_expr_vars(lhs_expr, head_index=head_index, lane_count=lane_count) + else: + self._parallel2d_materialize_expr_into( + lhs_expr, + head_index=head_index, + lane_count=lane_count, + task_id=f"{task_id}.lhs", + dst_vars=dst_vars, + temp_depth=temp_depth, + ) + src1_vars = list(dst_vars) + + if rhs_expr.kind in {"load", "fpvar", "literal"}: + src2_vars = self._parallel2d_expr_vars(rhs_expr, head_index=head_index, lane_count=lane_count) + else: + src2_vars = self._parallel2d_scratch_vars( + lane_count=lane_count, + slot_index=temp_depth, + ) + self._parallel2d_materialize_expr_into( + rhs_expr, + head_index=head_index, + lane_count=lane_count, + task_id=f"{task_id}.rhs", + dst_vars=src2_vars, + temp_depth=temp_depth + 1, + ) + self.isa_emitter.emit_fp_kernel( + src1_addrs=[_require_fp_addr(var) for var in src1_vars], + src2_addrs=[_require_fp_addr(var) for var in src2_vars], + dst_addrs=[_require_fp_addr(var) for var in dst_vars], + op=str(expr.op), + task_id=task_id, + ) + return + if expr.kind == "unary_op": + if expr.op not in {"exp", "reci", "sqrt"} or len(expr.args) != 1: + raise NotImplementedError( + f"parallel_region2d FP lowering supports unary exp/reci/sqrt, got {expr.op!r}" + ) + arg_expr = expr.args[0] + if arg_expr.kind in {"load", "fpvar", "literal"}: + src_vars = self._parallel2d_expr_vars(arg_expr, head_index=head_index, lane_count=lane_count) + else: + self._parallel2d_materialize_expr_into( + arg_expr, + head_index=head_index, + lane_count=lane_count, + task_id=f"{task_id}.src", + dst_vars=dst_vars, + temp_depth=temp_depth, + ) + src_vars = list(dst_vars) + self.isa_emitter.emit_fp_kernel( + src1_addrs=[_require_fp_addr(var) for var in src_vars], + dst_addrs=[_require_fp_addr(var) for var in dst_vars], + op=str(expr.op), + task_id=task_id, + ) + return + raise NotImplementedError( + f"parallel_region2d FP lowering does not support expr kind {expr.kind!r}" + ) + + def _parallel2d_scratch_vars( + self, + *, + lane_count: int, + slot_index: int, + ) -> List[FPVar]: + key = (int(lane_count), int(slot_index)) + cached = self._parallel2d_scratch_var_cache.get(key) + if cached is not None: + return cached + fragment = self.program.fp_fragment( + self.program._auto_name(f"parallel2d.scratch{slot_index}"), + (int(lane_count),), + init=0.0, + ) + scratch_vars = self.program.mapf(fragment) + self._parallel2d_scratch_var_cache[key] = scratch_vars + return scratch_vars + + def _parallel2d_expr_vars( + self, + expr: ParallelExpr, + *, + head_index: int, + lane_count: int, + ) -> List[FPVar]: + if expr.kind == "load": + access = expr.value + if not isinstance(access, ParallelAccess): + raise RuntimeError(f"parallel_region2d load expr missing ParallelAccess: {expr}") + return self._parallel2d_access_vars(access, head_index=head_index, lane_count=lane_count) + if expr.kind == "fpvar": + fp_var = expr.value + if not isinstance(fp_var, FPVar): + raise RuntimeError(f"parallel_region2d fpvar expr missing FPVar payload: {expr}") + return [fp_var] * int(lane_count) + if expr.kind == "literal": + literal_var = self.program.mapf(float(expr.value))[0] + return [literal_var] * int(lane_count) + raise NotImplementedError( + f"parallel_region2d FP kernel operands currently support load/fpvar/literal, got {expr.kind}" + ) + + def _parallel2d_access_vars( + self, + access: ParallelAccess, + *, + head_index: int, + lane_count: int, + ) -> List[FPVar]: + if not isinstance(access.base, Vector): + raise TypeError( + "parallel_region2d FP access supports only Vector operands; " + f"got {type(access.base).__name__}" + ) + resolved: List[FPVar] = [] + for lane_index in range(int(lane_count)): + logical_index = self._parallel2d_access_logical_index( + access, + head_index=head_index, + lane_index=lane_index, + ) + resolved.append( + self.program.tensor_manager._resolve_element_fpvar( + ElementRef(base=access.base, indices=logical_index) + ) + ) + return resolved + + def _parallel2d_access_logical_index( + self, + access: ParallelAccess, + *, + head_index: int, + lane_index: int, + ) -> Tuple[int, ...]: + logical_index: List[int] = [] + for selector in access.selectors: + if isinstance(selector, ParallelAxis): + if int(selector.axis) == 1: + logical_index.append(int(head_index)) + elif int(selector.axis) == 2: + logical_index.append(int(lane_index)) + else: + raise RuntimeError(f"parallel_region2d does not expose axis {selector.axis}") + elif isinstance(selector, ParallelExpr): + logical_index.append( + self._parallel2d_index_expr_value( + selector, + head_index=head_index, + lane_index=lane_index, + ) + ) + elif isinstance(selector, int): + logical_index.append(int(selector)) + else: + raise NotImplementedError( + f"parallel_region2d does not support selector {selector!r} of type {type(selector).__name__}" + ) + return tuple(logical_index) + + def _parallel2d_index_expr_value( + self, + expr: ParallelExpr, + *, + head_index: int, + lane_index: int, + ) -> int: + if expr.kind == "literal": + return int(expr.value) + if expr.kind == "axis": + axis = expr.value + if not isinstance(axis, ParallelAxis): + raise RuntimeError("parallel_region2d axis expression missing axis metadata") + if int(axis.axis) == 1: + return int(head_index) + if int(axis.axis) == 2: + return int(lane_index) + raise RuntimeError(f"parallel_region2d does not expose axis {axis.axis}") + if expr.kind == "op" and len(expr.args) == 2: + lhs = self._parallel2d_index_expr_value(expr.args[0], head_index=head_index, lane_index=lane_index) + rhs = self._parallel2d_index_expr_value(expr.args[1], head_index=head_index, lane_index=lane_index) + if expr.op == "add": + return lhs + rhs + if expr.op == "sub": + return lhs - rhs + if expr.op == "mul": + return lhs * rhs + if expr.op == "mod": + return lhs % rhs + raise NotImplementedError( + f"Unsupported parallel_region2d index expression: {expr.kind}:{getattr(expr, 'op', None)!r}" + ) + + def _emit_parallel_cycle_plan( + self, + region: ParallelRegionGraph, + cycle_plan: ParallelCyclePlan, + *, + region_output_values: Dict[str, Tuple[TensorTile, ValueTile]], + ) -> None: + group = cycle_plan.group + if not cycle_plan.output_slots: + raise RuntimeError(f"parallel cycle {region.name} has no output slots") + if not cycle_plan.compute_ops: + raise RuntimeError(f"parallel cycle {region.name} has no compute ops") + if int(group.element_count) != int(self.program.mlen): + raise NotImplementedError( + "parallel lowering requires element_count == mlen per cycle " + f"(elem_width={group.elem_width}, k_count={group.k_count}, " + f"element_count={group.element_count}, mlen={self.program.mlen})" + ) + + cache_tag = f"{region.name}.i{group.i_index}.j{group.j_index}.k{group.k_base}" + ( + input_slot_bases, + output_slot_bases, + input_slot_names, + output_slot_names, + ) = self._allocate_parallel_cycle_cache_slots(cycle_plan, cache_tag) + output_slot_values = self._resolve_parallel_cycle_output_values( + cycle_plan, + group, + region_output_values=region_output_values, + ) + + try: + self._emit_parallel_cycle_loads(cycle_plan, group, cache_tag, input_slot_bases) + self._emit_parallel_cycle_compute(cycle_plan, group, input_slot_bases, output_slot_bases) + self._emit_parallel_cycle_writebacks( + cycle_plan, + group, + cache_tag, + output_slot_bases, + output_slot_values, + ) + finally: + self._free_parallel_cycle_cache_slots(input_slot_names, output_slot_names) + + def _allocate_parallel_cycle_cache_slots( + self, + cycle_plan: ParallelCyclePlan, + cache_tag: str, + ) -> Tuple[Dict[int, int], Dict[int, int], List[str], List[str]]: + allocator = self.program.compiler.sub_matrix_manager.fpram_allocator + cache_floor = int(self.program.tensor_manager._next_fp_mem_addr) + if allocator.next_free < cache_floor: + allocator.next_free = cache_floor + allocator.free_stack[:] = [block for block in allocator.free_stack if int(block.addr) >= cache_floor] + + input_slot_bases: Dict[int, int] = {} + output_slot_bases: Dict[int, int] = {} + input_slot_names: List[str] = [] + output_slot_names: List[str] = [] + for input_slot in cycle_plan.input_slots: + slot_name = f"__parallel_input_cache__.{cache_tag}.slot{input_slot.slot_id}" + input_slot_bases[input_slot.slot_id] = int(allocator.allocate(slot_name, self.program.mlen)) + input_slot_names.append(slot_name) + for output_slot in cycle_plan.output_slots: + slot_name = f"__parallel_output_cache__.{cache_tag}.slot{output_slot.slot_id}" + output_slot_bases[output_slot.slot_id] = int(allocator.allocate(slot_name, self.program.mlen)) + output_slot_names.append(slot_name) + return input_slot_bases, output_slot_bases, input_slot_names, output_slot_names + + def _resolve_parallel_cycle_output_values( + self, + cycle_plan: ParallelCyclePlan, + group: ParallelCycleGroup, + *, + region_output_values: Dict[str, Tuple[TensorTile, ValueTile]], + ) -> Dict[int, ValueTile]: + output_slot_values: Dict[int, ValueTile] = {} + for output_slot in cycle_plan.output_slots: + output_tile = self._parallel_access_cycle_dst_tile(output_slot.access, group) + output_slot_values[output_slot.slot_id] = self._get_or_create_parallel_region_output_value( + output_tile, + region_output_values=region_output_values, + ) + return output_slot_values + + def _emit_parallel_cycle_loads( + self, + cycle_plan: ParallelCyclePlan, + group: ParallelCycleGroup, + cache_tag: str, + input_slot_bases: Dict[int, int], + ) -> None: + for load_op in cycle_plan.load_ops: + src_vram_addr = self._parallel_access_cycle_src_vram_row_addr(load_op.access, group) + self.isa_emitter.emit_map_fp_v_tile( + fpram_addr=input_slot_bases[load_op.slot_id], + vram_addr=src_vram_addr, + row_count=1, + row_width=self.program.mlen, + task_id=f"parallel_load.{cache_tag}.slot{load_op.slot_id}", + ) + + def _emit_parallel_cycle_compute( + self, + cycle_plan: ParallelCyclePlan, + group: ParallelCycleGroup, + input_slot_bases: Dict[int, int], + output_slot_bases: Dict[int, int], + ) -> None: + for compute_op in cycle_plan.compute_ops: + access_order = _collect_parallel_accesses(compute_op.expr) + access_slot_map = { + _parallel_access_identity(access): slot_id + for access, slot_id in zip(access_order, compute_op.input_slot_ids) + } + dst_base = output_slot_bases[compute_op.dst_slot_id] + if self._try_emit_parallel_pairwise_cloop_compute( + compute_op=compute_op, + group=group, + access_slot_map=access_slot_map, + input_slot_bases=input_slot_bases, + dst_base=int(dst_base), + ): + continue + for lane_offset in range(self.program.mlen): + dst_addr = int(dst_base + lane_offset) + self._emit_parallel_expr_to_addr( + expr=compute_op.expr, + dst_addr=dst_addr, + lane_offset=lane_offset, + group=group, + access_slot_map=access_slot_map, + input_slot_bases=input_slot_bases, + task_id=f"{compute_op.task_id}.lane{lane_offset}", + ) + + def _emit_parallel_cycle_writebacks( + self, + cycle_plan: ParallelCyclePlan, + group: ParallelCycleGroup, + cache_tag: str, + output_slot_bases: Dict[int, int], + output_slot_values: Dict[int, ValueTile], + ) -> None: + for writeback_op in cycle_plan.writeback_ops: + output_value = output_slot_values[writeback_op.slot_id] + dst_vram_addr = output_value.residency.get("vram_addr") + if dst_vram_addr is None: + raise RuntimeError("parallel output writeback expected new value tile in VRAM") + dst_row = self._parallel_access_cycle_row(writeback_op.access, group) + dst_vram_row_addr = int(dst_vram_addr) + (int(dst_row) % self.program.mlen) * self.program.mlen + self.isa_emitter.emit_map_v_fp_tile( + vram_addr=dst_vram_row_addr, + fpram_addr=output_slot_bases[writeback_op.slot_id], + row_count=1, + row_width=self.program.mlen, + task_id=f"parallel_writeback.{cache_tag}.slot{writeback_op.slot_id}", + ) + + def _free_parallel_cycle_cache_slots( + self, + input_slot_names: List[str], + output_slot_names: List[str], + ) -> None: + allocator = self.program.compiler.sub_matrix_manager.fpram_allocator + for slot_name in input_slot_names: + allocator.free(slot_name, strict=False) + for slot_name in output_slot_names: + allocator.free(slot_name, strict=False) + + def _prepare_parallel_region_output_bindings(self, region: ParallelRegionGraph) -> None: + detached_tile_ids: set[str] = set() + if region.execution_plan is None: + return + for cycle_plan in region.execution_plan.cycle_plans: + for output_slot in cycle_plan.output_slots: + dst_tile = self._parallel_access_cycle_dst_tile(output_slot.access, cycle_plan.group) + if dst_tile.tile_id in detached_tile_ids: + continue + self.program.value_manager._unbind_tile_value_pointer(dst_tile.tile_id) + detached_tile_ids.add(dst_tile.tile_id) + + def _create_parallel_region_output_value(self, dst_tile: TensorTile) -> ValueTile: + value = ValueTile( + value_tile_id=self.program.value_manager._next_value_tile_id(), + logical_shape=dst_tile.tile_shape, + metadata={"source_tile_id": dst_tile.tile_id, "parallel_region_output": True}, + ) + vram_name = f"{value.value_tile_id}.vram" + vram_addr = self.program.value_manager.allocate_value_tile_address( + size=self.program.tile_elems, + name=vram_name, + place="vram", + value_tile=value, + ) + value.residency["vram_addr"] = vram_addr + value.residency["vram_name"] = vram_name + self.program.value_manager.value_tiles[value.value_tile_id] = value + self.program.value_manager._value_tiles_in_vram[value.value_tile_id] = int(vram_addr) + return value + + def _get_or_create_parallel_region_output_value( + self, + dst_tile: TensorTile, + *, + region_output_values: Dict[str, Tuple[TensorTile, ValueTile]], + ) -> ValueTile: + existing = region_output_values.get(dst_tile.tile_id) + if existing is not None: + return existing[1] + created = self._create_parallel_region_output_value(dst_tile) + region_output_values[dst_tile.tile_id] = (dst_tile, created) + return created + + def _try_emit_parallel_pairwise_cloop_compute( + self, + *, + compute_op: ParallelComputeOp, + group: ParallelCycleGroup, + access_slot_map: Dict[str, int], + input_slot_bases: Dict[int, int], + dst_base: int, + ) -> bool: + rope_inputs = self._match_parallel_pairwise_rope_expr(compute_op.expr) + if rope_inputs is None: + return False + if self.program.mlen % 2 != 0: + return False + + x_slot = access_slot_map.get(_parallel_access_identity(rope_inputs["x_direct"])) + cos_slot = access_slot_map.get(_parallel_access_identity(rope_inputs["cos"])) + sin_slot = access_slot_map.get(_parallel_access_identity(rope_inputs["sin"])) + neg_sin_slot = access_slot_map.get(_parallel_access_identity(rope_inputs["neg_sin"])) + if any(slot is None for slot in (x_slot, cos_slot, sin_slot, neg_sin_slot)): + return False + + gp_regs = self.program.compiler.register_allocator.allocate_gp(6) + gp_x, gp_cos, gp_sin, gp_neg_sin, gp_dst, gp_loop = gp_regs + try: + lines = [f"; parallel pairwise cloop compute {compute_op.task_id}"] + lines.append(f"S_ADDI_INT gp{gp_x}, gp0, {int(input_slot_bases[int(x_slot)])}") + lines.append(f"S_ADDI_INT gp{gp_cos}, gp0, {int(input_slot_bases[int(cos_slot)])}") + lines.append(f"S_ADDI_INT gp{gp_sin}, gp0, {int(input_slot_bases[int(sin_slot)])}") + lines.append(f"S_ADDI_INT gp{gp_neg_sin}, gp0, {int(input_slot_bases[int(neg_sin_slot)])}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_base)}") + lines.append(f"C_LOOP_START gp{gp_loop}, {self.program.mlen // 2}") + lines.append(f"S_LD_FP f1, gp{gp_x}, 0") + lines.append(f"S_LD_FP f2, gp{gp_x}, 1") + lines.append(f"S_LD_FP f3, gp{gp_cos}, 0") + lines.append(f"S_LD_FP f4, gp{gp_neg_sin}, 0") + lines.append(f"S_MUL_FP f5, f1, f3") + lines.append(f"S_MUL_FP f6, f2, f4") + lines.append(f"S_ADD_FP f5, f5, f6") + lines.append(f"S_ST_FP f5, gp{gp_dst}, 0") + lines.append(f"S_LD_FP f4, gp{gp_sin}, 0") + lines.append(f"S_MUL_FP f5, f1, f4") + lines.append(f"S_MUL_FP f6, f2, f3") + lines.append(f"S_ADD_FP f5, f5, f6") + lines.append(f"S_ST_FP f5, gp{gp_dst}, 1") + lines.append(f"S_ADDI_INT gp{gp_x}, gp{gp_x}, 2") + lines.append(f"S_ADDI_INT gp{gp_cos}, gp{gp_cos}, 2") + lines.append(f"S_ADDI_INT gp{gp_sin}, gp{gp_sin}, 2") + lines.append(f"S_ADDI_INT gp{gp_neg_sin}, gp{gp_neg_sin}, 2") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp{gp_dst}, 2") + lines.append(f"C_LOOP_END gp{gp_loop}") + self.program.compiler.generated_code += "\n".join(lines) + "\n" + finally: + self.program.compiler.register_allocator.free_gp(gp_regs) + return True + + def _match_parallel_pairwise_rope_expr( + self, + expr: ParallelExpr, + ) -> Optional[Dict[str, ParallelAccess]]: + if expr.kind != "select" or len(expr.args) != 3: + return None + predicate, even_expr, odd_expr = expr.args + if not self._parallel_expr_matches_even_mod2(predicate): + return None + if even_expr.kind != "op" or even_expr.op != "add" or len(even_expr.args) != 2: + return None + if odd_expr.kind != "op" or odd_expr.op != "add" or len(odd_expr.args) != 2: + return None + + even_terms = self._collect_parallel_mul_terms(even_expr) + odd_terms = self._collect_parallel_mul_terms(odd_expr) + if even_terms is None or odd_terms is None: + return None + + x_base_id = self._resolve_parallel_pairwise_data_base_id(even_terms + odd_terms) + if x_base_id is None: + return None + + even_direct = self._find_parallel_term(even_terms, data_base_id=x_base_id, pair=False) + even_pair = self._find_parallel_term(even_terms, data_base_id=x_base_id, pair=True) + odd_direct = self._find_parallel_term(odd_terms, data_base_id=x_base_id, pair=False) + odd_pair = self._find_parallel_term(odd_terms, data_base_id=x_base_id, pair=True) + if any(item is None for item in (even_direct, even_pair, odd_direct, odd_pair)): + return None + + if id(even_direct["data"].base) != x_base_id or id(even_pair["data"].base) != x_base_id: + return None + if id(odd_direct["data"].base) != x_base_id or id(odd_pair["data"].base) != x_base_id: + return None + if _parallel_access_identity(even_direct["coeff"]) != _parallel_access_identity(odd_direct["coeff"]): + return None + + return { + "x_direct": even_direct["data"], + "x_pair": even_pair["data"], + "cos": even_direct["coeff"], + "neg_sin": even_pair["coeff"], + "sin": odd_pair["coeff"], + } + + def _analyze_parallel_mul_term( + self, + expr: ParallelExpr, + ) -> Optional[Dict[str, ParallelAccess]]: + if expr.kind != "op" or expr.op != "mul" or len(expr.args) != 2: + return None + lhs, rhs = expr.args + if lhs.kind != "load" or rhs.kind != "load": + return None + lhs_access = lhs.value + rhs_access = rhs.value + if not isinstance(lhs_access, ParallelAccess) or not isinstance(rhs_access, ParallelAccess): + return None + return {"lhs": lhs_access, "rhs": rhs_access} + + def _collect_parallel_mul_terms( + self, + expr: ParallelExpr, + ) -> Optional[List[Dict[str, ParallelAccess]]]: + if expr.kind != "op" or expr.op != "add" or len(expr.args) != 2: + return None + terms = [self._analyze_parallel_mul_term(term) for term in expr.args] + if any(term is None for term in terms): + return None + return [term for term in terms if term is not None] + + def _resolve_parallel_pairwise_data_base_id( + self, + terms: List[Dict[str, ParallelAccess]], + ) -> Optional[int]: + direct_bases = set() + pair_bases = set() + for term in terms: + for access in (term["lhs"], term["rhs"]): + if self._parallel_access_is_pair(access): + pair_bases.add(id(access.base)) + else: + direct_bases.add(id(access.base)) + candidate_bases = direct_bases & pair_bases + if len(candidate_bases) != 1: + return None + return next(iter(candidate_bases)) + + def _find_parallel_term( + self, + terms: List[Optional[Dict[str, ParallelAccess]]], + *, + data_base_id: int, + pair: bool, + ) -> Optional[Dict[str, ParallelAccess]]: + for term in terms: + if term is None: + continue + lhs_access = term["lhs"] + rhs_access = term["rhs"] + lhs_is_data = id(lhs_access.base) == data_base_id and self._parallel_access_is_pair(lhs_access) == pair + rhs_is_data = id(rhs_access.base) == data_base_id and self._parallel_access_is_pair(rhs_access) == pair + if lhs_is_data == rhs_is_data: + continue + if lhs_is_data: + return {"data": lhs_access, "coeff": rhs_access} + return {"data": rhs_access, "coeff": lhs_access} + return None + + def _parallel_access_is_pair(self, access: ParallelAccess) -> bool: + selectors = tuple(access.selectors) + return bool( + selectors + and isinstance(selectors[-1], ParallelExpr) + and selectors[-1].kind == "pair_index" + ) + + def _parallel_expr_matches_even_mod2(self, expr: ParallelExpr) -> bool: + if expr.kind != "op" or expr.op != "eq" or len(expr.args) != 2: + return False + lhs, rhs = expr.args + if rhs.kind == "literal" and int(rhs.value) == 0: + return self._parallel_expr_is_mod2(lhs) + if lhs.kind == "literal" and int(lhs.value) == 0: + return self._parallel_expr_is_mod2(rhs) + return False + + def _parallel_expr_is_mod2(self, expr: ParallelExpr) -> bool: + return ( + expr.kind == "op" + and expr.op == "mod" + and len(expr.args) == 2 + and expr.args[1].kind == "literal" + and int(expr.args[1].value) == 2 + ) + + def _emit_parallel_expr_to_addr( + self, + *, + expr: ParallelExpr, + dst_addr: int, + lane_offset: int, + group: ParallelCycleGroup, + access_slot_map: Dict[str, int], + input_slot_bases: Dict[int, int], + task_id: str, + ) -> None: + gp_regs = self.program.compiler.register_allocator.allocate_gp(1) + gp_dst = gp_regs[0] + fp_reg = self._emit_parallel_expr_to_fp_reg( + expr=expr, + lane_offset=lane_offset, + group=group, + access_slot_map=access_slot_map, + input_slot_bases=input_slot_bases, + task_id=task_id, + preferred_fp_reg=1, + scratch_fp_regs=(2, 3, 4, 5, 6, 7), + ) + try: + lines = [ + f"; parallel expr store {task_id}", + f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_addr)}", + f"S_ST_FP f{fp_reg}, gp{gp_dst}, 0", + ] + self.program.compiler.generated_code += "\n".join(lines) + "\n" + finally: + self.program.compiler.register_allocator.free_gp(gp_regs) + + def _emit_parallel_expr_to_fp_reg( + self, + *, + expr: ParallelExpr, + lane_offset: int, + group: ParallelCycleGroup, + access_slot_map: Dict[str, int], + input_slot_bases: Dict[int, int], + task_id: str, + preferred_fp_reg: int, + scratch_fp_regs: Tuple[int, ...], + ) -> int: + if expr.kind == "select": + predicate = expr.args[0] + branch_expr = expr.args[1] if self._parallel_predicate_value(predicate, lane_offset, group) else expr.args[2] + return self._emit_parallel_expr_to_fp_reg( + expr=branch_expr, + lane_offset=lane_offset, + group=group, + access_slot_map=access_slot_map, + input_slot_bases=input_slot_bases, + task_id=task_id, + preferred_fp_reg=preferred_fp_reg, + scratch_fp_regs=scratch_fp_regs, + ) + if expr.kind == "literal": + literal_var = self.program.mapf(float(expr.value))[0] + return self._emit_parallel_load_addr_to_fp_reg( + int(_require_fp_addr(literal_var)), + task_id=f"{task_id}.literal", + fp_dst=preferred_fp_reg, + ) + if expr.kind == "fpvar": + fp_var = expr.value + if not isinstance(fp_var, FPVar): + raise RuntimeError(f"parallel fpvar expr missing FPVar payload: {expr}") + return self._emit_parallel_load_addr_to_fp_reg( + int(_require_fp_addr(fp_var)), + task_id=f"{task_id}.fpvar", + fp_dst=preferred_fp_reg, + ) + if expr.kind == "load": + access = expr.value + if not isinstance(access, ParallelAccess): + raise RuntimeError(f"parallel load expr missing ParallelAccess: {expr}") + if isinstance(access.base, Vector): + return self._emit_parallel_vector_access_to_fp_reg( + access=access, + lane_offset=lane_offset, + group=group, + task_id=f"{task_id}.vector_load", + fp_dst=preferred_fp_reg, + ) + slot_id = access_slot_map[_parallel_access_identity(access)] + lane_index = self._parallel_access_lane_index(access, lane_offset) + return self._emit_parallel_load_addr_to_fp_reg( + int(input_slot_bases[slot_id] + lane_index), + task_id=f"{task_id}.load", + fp_dst=preferred_fp_reg, + ) + if expr.kind == "op": + if len(expr.args) != 2 or expr.op not in {"add", "sub", "mul", "max"}: + raise NotImplementedError(f"parallel expr op lowering supports add/sub/mul/max only, got {expr.op!r}") + if not scratch_fp_regs: + raise RuntimeError(f"parallel expr {task_id} ran out of hard-coded FP scratch registers") + rhs_preferred = scratch_fp_regs[0] + rhs_scratch = tuple(reg for reg in scratch_fp_regs[1:] if reg != preferred_fp_reg) + lhs_reg = self._emit_parallel_expr_to_fp_reg( + expr=expr.args[0], + lane_offset=lane_offset, + group=group, + access_slot_map=access_slot_map, + input_slot_bases=input_slot_bases, + task_id=f"{task_id}.lhs", + preferred_fp_reg=preferred_fp_reg, + scratch_fp_regs=scratch_fp_regs, + ) + rhs_reg = self._emit_parallel_expr_to_fp_reg( + expr=expr.args[1], + lane_offset=lane_offset, + group=group, + access_slot_map=access_slot_map, + input_slot_bases=input_slot_bases, + task_id=f"{task_id}.rhs", + preferred_fp_reg=rhs_preferred, + scratch_fp_regs=rhs_scratch, + ) + op_to_insn = {"add": "S_ADD_FP", "sub": "S_SUB_FP", "mul": "S_MUL_FP", "max": "S_MAX_FP"} + self.program.compiler.generated_code += ( + f"; parallel expr {task_id}.{expr.op}\n" + f"{op_to_insn[str(expr.op)]} f{lhs_reg}, f{lhs_reg}, f{rhs_reg}\n" + ) + return lhs_reg + raise NotImplementedError(f"Unsupported parallel expr kind for lowering: {expr.kind}") + + def _emit_parallel_load_addr_to_fp_reg( + self, + addr: int, + *, + task_id: str, + fp_dst: int, + ) -> int: + gp_regs = self.program.compiler.register_allocator.allocate_gp(1) + gp_src = gp_regs[0] + lines = [ + f"; parallel load scalar {task_id}", + f"S_ADDI_INT gp{gp_src}, gp0, {int(addr)}", + f"S_LD_FP f{fp_dst}, gp{gp_src}, 0", + ] + self.program.compiler.generated_code += "\n".join(lines) + "\n" + self.program.compiler.register_allocator.free_gp(gp_regs) + return fp_dst + + def _emit_parallel_vector_access_to_fp_reg( + self, + *, + access: ParallelAccess, + lane_offset: int, + group: ParallelCycleGroup, + task_id: str, + fp_dst: int, + ) -> int: + fp_addr = self._parallel_vector_access_fp_addr(access, lane_offset=lane_offset, group=group) + return self._emit_parallel_load_addr_to_fp_reg(fp_addr, task_id=task_id, fp_dst=fp_dst) + + def _parallel_vector_access_fp_addr( + self, + access: ParallelAccess, + *, + lane_offset: int, + group: ParallelCycleGroup, + ) -> int: + if not isinstance(access.base, Vector): + raise RuntimeError(f"parallel vector fp addr expected Vector base, got {type(access.base).__name__}") + logical_index = self._parallel_access_lane_logical_index(access, lane_offset=lane_offset, group=group) + fp_var = self.program.tensor_manager._resolve_element_fpvar(ElementRef(base=access.base, indices=logical_index)) + return int(_require_fp_addr(fp_var)) + + def _parallel_predicate_value( + self, + expr: ParallelExpr, + lane_offset: int, + group: ParallelCycleGroup, + ) -> bool: + if expr.kind == "op" and len(expr.args) == 2: + lhs = self._parallel_index_expr_value(expr.args[0], lane_offset=lane_offset, group=group) + rhs = self._parallel_index_expr_value(expr.args[1], lane_offset=lane_offset, group=group) + if expr.op == "lt": + return lhs < rhs + if expr.op == "le": + return lhs <= rhs + if expr.op == "gt": + return lhs > rhs + if expr.op == "ge": + return lhs >= rhs + if expr.op == "eq": + return lhs == rhs + raise NotImplementedError(f"Unsupported parallel predicate lowering: {expr.kind}") + + def _parallel_index_expr_value( + self, + expr: ParallelExpr, + *, + lane_offset: int, + group: ParallelCycleGroup, + ) -> int: + if expr.kind == "literal": + return int(expr.value) + if expr.kind == "axis": + axis = expr.value + if axis is None: + raise RuntimeError("parallel axis expression missing axis metadata") + axis_id = int(axis.axis) + if axis_id == 0: + return int(group.i_index) + if axis_id == 1: + return int(group.j_index) + if axis_id == 2: + if int(group.k_count) > 1 and int(group.elem_width) < int(self.program.mlen): + return int(group.k_base) + (int(lane_offset) % int(group.elem_width)) + return int(group.k_base) + int(lane_offset) + raise RuntimeError(f"Unsupported parallel axis id: {axis_id}") + if expr.kind == "op" and len(expr.args) == 2: + lhs = self._parallel_index_expr_value(expr.args[0], lane_offset=lane_offset, group=group) + rhs = self._parallel_index_expr_value(expr.args[1], lane_offset=lane_offset, group=group) + if expr.op == "add": + return lhs + rhs + if expr.op == "sub": + return lhs - rhs + if expr.op == "mul": + return lhs * rhs + if expr.op == "mod": + return lhs % rhs + raise NotImplementedError(f"Unsupported parallel index expression lowering: {expr.kind}:{getattr(expr, 'op', None)!r}") + + def _parallel_access_lane_index(self, access: ParallelAccess, lane_offset: int) -> int: + selectors = tuple(access.selectors) + if not selectors: + return int(lane_offset) + last = selectors[-1] + if isinstance(last, ParallelExpr) and last.kind == "pair_index": + return int(lane_offset) ^ 1 + return int(lane_offset) + + def _parallel_access_packs_axis1_lanes( + self, + access: ParallelAccess, + group: ParallelCycleGroup, + ) -> bool: + if int(group.k_count) <= 1 or int(group.elem_width) >= int(self.program.mlen): + return False + selectors = tuple(access.selectors) + if not selectors: + return False + lane_selector = selectors[-1] + if isinstance(lane_selector, ParallelAxis): + lane_axis_ok = int(lane_selector.axis) == 2 + elif isinstance(lane_selector, ParallelExpr): + lane_axis_ok = lane_selector.kind in {"pair_index", "half_index"} + else: + lane_axis_ok = False + has_axis1 = any( + isinstance(selector, ParallelAxis) and int(selector.axis) == 1 + for selector in selectors[:-1] + ) + return bool(lane_axis_ok and has_axis1) + + def _parallel_access_lane_logical_index( + self, + access: ParallelAccess, + *, + lane_offset: int, + group: ParallelCycleGroup, + ) -> Tuple[int, ...]: + logical_index: List[int] = [] + lane_axis_index = len(access.logical_shape) - 1 + multi_lane = int(group.k_count) > 1 + elem_width = int(group.elem_width) + packed_axis1 = self._parallel_access_packs_axis1_lanes(access, group) + for axis_pos, selector in enumerate(access.selectors): + if isinstance(selector, ParallelAxis): + if int(selector.axis) == 0: + logical_index.append(int(group.i_index)) + elif int(selector.axis) == 1: + if multi_lane and packed_axis1: + logical_index.append(int(group.j_index) + int(lane_offset) // elem_width) + else: + logical_index.append(int(group.j_index)) + elif int(selector.axis) == 2: + if multi_lane: + if packed_axis1 or axis_pos == lane_axis_index: + # Innermost: element position within lane + logical_index.append(int(self._parallel_access_lane_index(access, lane_offset)) % elem_width) + else: + # Non-innermost: head/group index + logical_index.append(int(group.k_base) + int(lane_offset) // elem_width) + else: + logical_index.append(int(group.k_base) + (self._parallel_access_lane_index(access, lane_offset) if axis_pos == lane_axis_index else int(lane_offset))) + else: + raise RuntimeError(f"Unsupported parallel axis id: {selector.axis}") + elif isinstance(selector, ParallelExpr): + if axis_pos != lane_axis_index: + raise NotImplementedError( + f"parallel lane logical index only supports selector expr on innermost axis, got axis_pos={axis_pos}" + ) + if multi_lane: + # Innermost expr (pair_index etc.): element position within lane + logical_index.append(int(self._parallel_access_lane_index(access, lane_offset)) % elem_width) + else: + logical_index.append(int(group.k_base) + int(self._parallel_access_lane_index(access, lane_offset))) + elif isinstance(selector, int): + logical_index.append(int(selector)) + else: + raise NotImplementedError( + f"parallel lane logical index does not support selector {selector!r} of type {type(selector).__name__}" + ) + return tuple(logical_index) + + def _parallel_access_cycle_src_vram_row_addr( + self, + access: ParallelAccess, + group: ParallelCycleGroup, + ) -> int: + tile = self._parallel_access_cycle_src_tile(access, group) + row = self._parallel_access_cycle_row(access, group) + local_row = row % self.program.mlen + value = self.program.value_manager.resolve_value_tile(tile) + self.program.value_manager.ensure_value_tile_in_place(value, "vram") + vram_addr = value.residency.get("vram_addr") + if vram_addr is None: + raise RuntimeError(f"parallel lowering expected VRAM residency for {value.value_tile_id}") + return int(vram_addr) + local_row * self.program.mlen + + def _parallel_access_cycle_row( + self, + access: ParallelAccess, + group: ParallelCycleGroup, + ) -> int: + concrete_selectors = self._parallel_access_concrete_selectors(access, group) + row_range, col_range = _logical_selectors_to_physical_ranges(access.base.logical_shape, concrete_selectors) + expected_width = int(group.element_count) + if (row_range[1] - row_range[0]) != 1 or (col_range[1] - col_range[0]) != expected_width: + raise NotImplementedError( + "parallel lowering currently supports one full-width contiguous row per cycle; " + f"got row_range={row_range}, col_range={col_range}, expected_width={expected_width}, mlen={self.program.mlen}" + ) + return int(row_range[0]) + + def _parallel_access_cycle_src_tile( + self, + access: ParallelAccess, + group: ParallelCycleGroup, + ) -> TileLike: + tile = self._parallel_access_cycle_tile(access, group) + if not isinstance(tile, (TensorTile, InputTile)): + raise RuntimeError("parallel source access did not resolve to TensorTile/InputTile") + return tile + + def _parallel_access_cycle_dst_tile( + self, + access: ParallelAccess, + group: ParallelCycleGroup, + ) -> TensorTile: + tile = self._parallel_access_cycle_tile(access, group) + if not isinstance(tile, TensorTile): + raise RuntimeError( + f"parallel destination access must resolve to TensorTile, got {type(tile).__name__}" + ) + return tile + + def _parallel_access_cycle_tile( + self, + access: ParallelAccess, + group: ParallelCycleGroup, + ) -> TileLike: + concrete_selectors = self._parallel_access_concrete_selectors(access, group) + row_range, col_range = _logical_selectors_to_physical_ranges(access.base.logical_shape, concrete_selectors) + expected_width = int(group.element_count) + if (row_range[1] - row_range[0]) != 1 or (col_range[1] - col_range[0]) != expected_width: + raise NotImplementedError( + "parallel lowering currently supports one full-width contiguous row per cycle; " + f"got row_range={row_range}, col_range={col_range}, expected_width={expected_width}, mlen={self.program.mlen}" + ) + row = int(row_range[0]) + col = int(col_range[0]) + tile_coord = (row // self.program.mlen, col // self.program.mlen) + tiles = getattr(access.base, "tiles", None) + if not isinstance(tiles, dict): + raise RuntimeError(f"parallel access base {type(access.base).__name__} does not expose tiles") + tile = tiles.get(tile_coord) + if not isinstance(tile, (TensorTile, InputTile)): + raise RuntimeError(f"parallel access did not resolve to TensorTile/InputTile at coord={tile_coord}") + return tile + + def _parallel_access_concrete_selectors( + self, + access: ParallelAccess, + group: ParallelCycleGroup, + ) -> Tuple[SliceItem, ...]: + concrete: List[SliceItem] = [] + multi_lane = int(group.k_count) > 1 + packed_axis1 = self._parallel_access_packs_axis1_lanes(access, group) + if multi_lane and packed_axis1: + axis_to_value = { + 0: int(group.i_index), + 1: slice(int(group.j_index), int(group.j_index) + int(group.k_count)), + 2: slice(0, int(group.elem_width)), + } + expr_range = slice(0, int(group.elem_width)) + elif multi_lane: + # Multi-lane: axis 2 selects k_count head/group indices; + # pair_index/half_index cover the per-head elem_width range. + k_axis_range = slice(int(group.k_base), int(group.k_base) + int(group.k_count)) + expr_range = slice(0, int(group.elem_width)) + axis_to_value = { + 0: int(group.i_index), + 1: int(group.j_index), + 2: k_axis_range, + } + else: + k_axis_range = slice(int(group.k_base), int(group.k_base) + int(self.program.mlen)) + expr_range = slice(int(group.k_base), int(group.k_base) + int(self.program.mlen)) + axis_to_value = { + 0: int(group.i_index), + 1: int(group.j_index), + 2: k_axis_range, + } + for selector in access.selectors: + if isinstance(selector, ParallelAxis): + concrete.append(axis_to_value[int(selector.axis)]) + elif isinstance(selector, ParallelExpr): + if selector.kind == "pair_index": + concrete.append(expr_range) + elif selector.kind == "half_index": + concrete.append(expr_range) + else: + raise NotImplementedError(f"Unsupported selector expr for concrete access lowering: {selector.kind}") + else: + concrete.append(selector) + return tuple(concrete) + + +class _LoopHintRange: + def __init__(self, program: "TileTensorProgram", *, kind: str, extent: int, region_id: Optional[int] = None) -> None: + self.program = program + self.kind = kind + self.extent = int(extent) + self.region_id = region_id + + def __iter__(self): + for index in range(self.extent): + if self.kind == "parallel" and self.region_id is not None: + self.program._active_parallel_region_ids.append(self.region_id) + try: + yield index + finally: + if self.kind == "parallel" and self.region_id is not None: + self.program._active_parallel_region_ids.pop() + + diff --git a/tilelang_runtime_compier/tile_tensor_program/_types.py b/tilelang_runtime_compier/tile_tensor_program/_types.py new file mode 100644 index 0000000..e48e0ba --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_types.py @@ -0,0 +1,746 @@ +"""TileTensor data classes: FP types, parallel types, tile/tensor types. + +Includes the dataclasses, scope helpers, and module-level type aliases used +by the rest of the `tile_tensor_program` package. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from math import ceil +from typing import Dict, List, Optional, Sequence, Tuple + + +__all__ = [ + "TileCoord", + "LogicalShape", + "SliceItem", + "FPIndex", + "FPVar", + "FPFragment", + "FPFragmentSlice", + "ElementRef", + "ParallelAxis", + "ParallelAccess", + "ParallelExpr", + "ParallelAssignment", + "ParallelCycleGroup", + "ParallelInputCacheSlotPlan", + "ParallelOutputCacheSlotPlan", + "ParallelLoadOp", + "ParallelComputeOp", + "ParallelWritebackOp", + "ParallelCyclePlan", + "ParallelExecutionPlan", + "ParallelRegionGraph", + "_ParallelRegionScope", + "_ParallelRegion2DScope", + "InputTile", + "TensorTile", + "VectorTile", + "ValueTile", + "ValueTileView", + "PreparedWrite", + "Input", + "Tensor", + "Vector", + "InputSlice", + "TensorSlice", + "VectorSlice", + "InputTranspose", + "TensorTranspose", + "VectorTranspose", + "TileLike", + "TensorLike", + "TransposedTensorLike", + "SourceValueLike", + "RowOperandLike", + "ViewMatmulTerm", + "ViewMatmulThread", + "BTMMHeadGroupThread", + "CopyMapvPacket", + "MatmulMapvPacket", + "GemmMapvPacket", + "MapvPacket", +] + + +TileCoord = Tuple[int, int] +LogicalShape = Tuple[int, ...] +SliceItem = int | slice +FPIndex = Tuple[int, ...] + + +TileCoord = Tuple[int, int] +LogicalShape = Tuple[int, ...] +SliceItem = int | slice +FPIndex = Tuple[int, ...] + + +@dataclass +class FPVar: + name: str + dtype: str = "fp32" + size: int = 1 + storage: str = "fpram" + fp_mem_addr: Optional[int] = None # Address in FP_MEM; loaded via S_LD_FP before VF ops + + +@dataclass +class FPFragment: + program: "TileTensorProgram" + name: str + shape: Tuple[int, ...] + vars: Dict[FPIndex, FPVar] = field(default_factory=dict) + dtype: str = "fp32" + storage: str = "fpram" + metadata: Dict[str, object] = field(default_factory=dict) + + def __getitem__(self, item: SliceItem | Tuple[SliceItem, ...]) -> "FPFragmentSlice": + if not isinstance(item, tuple): + item = (item,) + return FPFragmentSlice(base=self, selectors=item) + + +@dataclass +class FPFragmentSlice: + base: FPFragment + selectors: Tuple[SliceItem, ...] + + +@dataclass(frozen=True) +class ElementRef: + base: object + indices: Tuple[int, ...] + + +@dataclass(frozen=True) +class ParallelAxis: + program: "TileTensorProgram" + region_id: int + axis: int + name: str + extent: int + + def _as_expr(self) -> "ParallelExpr": + return ParallelExpr(kind="axis", value=self) + + def __add__(self, other: object) -> "ParallelExpr": + return self._as_expr().__add__(other) + + def __radd__(self, other: object) -> "ParallelExpr": + return self._as_expr().__radd__(other) + + def __sub__(self, other: object) -> "ParallelExpr": + return self._as_expr().__sub__(other) + + def __rsub__(self, other: object) -> "ParallelExpr": + return self._as_expr().__rsub__(other) + + def __mul__(self, other: object) -> "ParallelExpr": + return self._as_expr().__mul__(other) + + def __rmul__(self, other: object) -> "ParallelExpr": + return self._as_expr().__rmul__(other) + + def __mod__(self, other: object) -> "ParallelExpr": + return self._as_expr().__mod__(other) + + def __rmod__(self, other: object) -> "ParallelExpr": + return self._as_expr().__rmod__(other) + + def __lt__(self, other: object) -> "ParallelExpr": + return self._as_expr().__lt__(other) + + def __le__(self, other: object) -> "ParallelExpr": + return self._as_expr().__le__(other) + + def __gt__(self, other: object) -> "ParallelExpr": + return self._as_expr().__gt__(other) + + def __ge__(self, other: object) -> "ParallelExpr": + return self._as_expr().__ge__(other) + + def __eq__(self, other: object) -> "ParallelExpr": # type: ignore[override] + return self._as_expr().__eq__(other) + + +@dataclass(frozen=True) +class ParallelAccess: + base: object + selectors: Tuple[object, ...] + + @property + def program(self) -> "TileTensorProgram": + return self.base.program + + @property + def logical_shape(self) -> LogicalShape: + shape = tuple(getattr(self.base, "logical_shape", ())) + if not shape: + raise RuntimeError(f"ParallelAccess base {type(self.base).__name__} does not expose logical_shape") + return shape + + def append_selectors(self, item: SliceItem | Tuple[SliceItem, ...]) -> "ParallelAccess": + if not isinstance(item, tuple): + item = (item,) + return ParallelAccess(base=self.base, selectors=self.selectors + tuple(item)) + + def __getitem__(self, item: SliceItem | Tuple[SliceItem, ...]) -> "ParallelAccess": + return self.append_selectors(item) + + def __setitem__(self, item: SliceItem | Tuple[SliceItem, ...], value: object) -> None: + self.program.thread_manager.record_parallel_assignment_from_access(self.append_selectors(item), value) + + def _as_expr(self) -> "ParallelExpr": + return ParallelExpr(kind="load", value=self) + + def __add__(self, other: object) -> "ParallelExpr": + return self._as_expr().__add__(other) + + def __radd__(self, other: object) -> "ParallelExpr": + return self._as_expr().__radd__(other) + + def __sub__(self, other: object) -> "ParallelExpr": + return self._as_expr().__sub__(other) + + def __rsub__(self, other: object) -> "ParallelExpr": + return self._as_expr().__rsub__(other) + + def __mul__(self, other: object) -> "ParallelExpr": + return self._as_expr().__mul__(other) + + def __rmul__(self, other: object) -> "ParallelExpr": + return self._as_expr().__rmul__(other) + + def __mod__(self, other: object) -> "ParallelExpr": + return self._as_expr().__mod__(other) + + def __rmod__(self, other: object) -> "ParallelExpr": + return self._as_expr().__rmod__(other) + + def __lt__(self, other: object) -> "ParallelExpr": + return self._as_expr().__lt__(other) + + def __le__(self, other: object) -> "ParallelExpr": + return self._as_expr().__le__(other) + + def __gt__(self, other: object) -> "ParallelExpr": + return self._as_expr().__gt__(other) + + def __ge__(self, other: object) -> "ParallelExpr": + return self._as_expr().__ge__(other) + + def __eq__(self, other: object) -> "ParallelExpr": # type: ignore[override] + return self._as_expr().__eq__(other) + + +@dataclass(frozen=True) +class ParallelExpr: + kind: str + value: object = None + args: Tuple["ParallelExpr", ...] = () + op: Optional[str] = None + + def _binary(self, other: object, *, op: str) -> "ParallelExpr": + return ParallelExpr( + kind="op", + op=op, + args=(self, _coerce_parallel_expr(other)), + ) + + def __add__(self, other: object) -> "ParallelExpr": + return self._binary(other, op="add") + + def __radd__(self, other: object) -> "ParallelExpr": + return _coerce_parallel_expr(other)._binary(self, op="add") + + def __sub__(self, other: object) -> "ParallelExpr": + return self._binary(other, op="sub") + + def __rsub__(self, other: object) -> "ParallelExpr": + return _coerce_parallel_expr(other)._binary(self, op="sub") + + def __mul__(self, other: object) -> "ParallelExpr": + return self._binary(other, op="mul") + + def __rmul__(self, other: object) -> "ParallelExpr": + return _coerce_parallel_expr(other)._binary(self, op="mul") + + def __mod__(self, other: object) -> "ParallelExpr": + return self._binary(other, op="mod") + + def __rmod__(self, other: object) -> "ParallelExpr": + return _coerce_parallel_expr(other)._binary(self, op="mod") + + def __lt__(self, other: object) -> "ParallelExpr": + return self._binary(other, op="lt") + + def __le__(self, other: object) -> "ParallelExpr": + return self._binary(other, op="le") + + def __gt__(self, other: object) -> "ParallelExpr": + return self._binary(other, op="gt") + + def __ge__(self, other: object) -> "ParallelExpr": + return self._binary(other, op="ge") + + def __eq__(self, other: object) -> "ParallelExpr": # type: ignore[override] + return self._binary(other, op="eq") + + +@dataclass +class ParallelAssignment: + dst: ParallelAccess + expr: ParallelExpr + task_id: str + sources: List[ParallelAccess] = field(default_factory=list) + + +@dataclass(frozen=True) +class ParallelCycleGroup: + i_index: int + j_index: int + k_base: int + k_count: int + elem_width: int + element_count: int + + +@dataclass(frozen=True) +class ParallelInputCacheSlotPlan: + slot_id: int + access: ParallelAccess + pattern_kind: str + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ParallelOutputCacheSlotPlan: + slot_id: int + access: ParallelAccess + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ParallelLoadOp: + slot_id: int + access: ParallelAccess + ensure_place: str = "vram" + load_kind: str = "mapv_to_fpram" + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ParallelComputeOp: + task_id: str + dst_slot_id: int + expr: ParallelExpr + input_slot_ids: List[int] = field(default_factory=list) + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ParallelWritebackOp: + slot_id: int + access: ParallelAccess + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass +class ParallelCyclePlan: + group: ParallelCycleGroup + input_slots: List[ParallelInputCacheSlotPlan] = field(default_factory=list) + output_slots: List[ParallelOutputCacheSlotPlan] = field(default_factory=list) + load_ops: List[ParallelLoadOp] = field(default_factory=list) + compute_ops: List[ParallelComputeOp] = field(default_factory=list) + writeback_ops: List[ParallelWritebackOp] = field(default_factory=list) + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass +class ParallelExecutionPlan: + region_name: str + cycle_groups: List[ParallelCycleGroup] = field(default_factory=list) + cycle_plans: List[ParallelCyclePlan] = field(default_factory=list) + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass +class ParallelRegionGraph: + region_id: int + name: str + extents: Tuple[int, int, int] + axes: Tuple[ParallelAxis, ParallelAxis, ParallelAxis] + assignments: List[ParallelAssignment] = field(default_factory=list) + cache_plan: Dict[str, object] = field(default_factory=dict) + execution_plan: Optional[ParallelExecutionPlan] = None + + def finalize(self, program: "TileTensorProgram") -> None: + unique_input_identities: Dict[str, ParallelAccess] = {} + predicate_kinds: set[str] = set() + output_cache_count = 0 + for assignment in self.assignments: + for access in assignment.sources: + unique_input_identities.setdefault(_parallel_access_identity(access), access) + for predicate in _collect_parallel_predicates(assignment.expr): + predicate_kinds.add(predicate) + output_cache_count = max(output_cache_count, 1) + self.cache_plan = { + "thread_group_size": int(program.mlen), + "cache_shape": (1, int(program.mlen)), + "cache_cycle_model": "uniform", + "input_cache_count": int(len(unique_input_identities)), + "output_cache_count": int(output_cache_count), + "d_axis_groups": int(ceil(self.extents[2] / program.mlen)), + "input_accesses": sorted(unique_input_identities.keys()), + "predicate_kinds": sorted(predicate_kinds), + } + self.execution_plan = _build_parallel_execution_plan(self, program=program) + + +class _ParallelRegionScope: + def __init__( + self, + program: "TileTensorProgram", + *, + extents: Tuple[int, int, int], + name: Optional[str] = None, + ) -> None: + self.program = program + self.extents = tuple(int(extent) for extent in extents) + self.name = name or self.program._auto_name("parallel_region") + self.region_id = self.program._parallel_region_counter + self.program._parallel_region_counter += 1 + self.region: Optional[ParallelRegionGraph] = None + + def __enter__(self) -> Tuple[ParallelAxis, ParallelAxis, ParallelAxis]: + axes = ( + ParallelAxis(self.program, self.region_id, 0, "s", self.extents[0]), + ParallelAxis(self.program, self.region_id, 1, "h", self.extents[1]), + ParallelAxis(self.program, self.region_id, 2, "d", self.extents[2]), + ) + self.region = ParallelRegionGraph( + region_id=self.region_id, + name=self.name, + extents=self.extents, + axes=axes, + ) + self.program.thread_manager._active_parallel_graphs.append(self.region) + return axes + + def __exit__(self, exc_type, exc, tb) -> None: + region = self.region + if region is None: + return + popped = self.program.thread_manager._active_parallel_graphs.pop() + if popped is not region: + raise RuntimeError("parallel region stack became inconsistent") + if exc_type is None: + region.finalize(self.program) + self.program.thread_manager.parallel_regions.append(region) + if region.execution_plan is not None: + self.program.thread_manager._emit_parallel_execution_plan(region, region.execution_plan) + self.program._parallel_execution_lowered = True + + +class _ParallelRegion2DScope: + def __init__( + self, + program: "TileTensorProgram", + *, + extents: Tuple[int, int], + name: Optional[str] = None, + ) -> None: + self.program = program + self.extents = tuple(int(extent) for extent in extents) + self.name = name or self.program._auto_name("parallel_region2d") + self.region_id = self.program._parallel_region_counter + self.program._parallel_region_counter += 1 + self.region: Optional[ParallelRegionGraph] = None + + def __enter__(self) -> Tuple[ParallelAxis, ParallelAxis]: + axes = ( + ParallelAxis(self.program, self.region_id, 0, "_", 1), + ParallelAxis(self.program, self.region_id, 1, "h", self.extents[0]), + ParallelAxis(self.program, self.region_id, 2, "s", self.extents[1]), + ) + self.region = ParallelRegionGraph( + region_id=self.region_id, + name=self.name, + extents=(1, self.extents[0], self.extents[1]), + axes=axes, + ) + self.region.cache_plan["lowering_kind"] = "fp2d" + self.program.thread_manager._active_parallel_graphs.append(self.region) + return axes[1], axes[2] + + def __exit__(self, exc_type, exc, tb) -> None: + region = self.region + if region is None: + return + popped = self.program.thread_manager._active_parallel_graphs.pop() + if popped is not region: + raise RuntimeError("parallel region stack became inconsistent") + if exc_type is None: + self.program.thread_manager.parallel_regions.append(region) + self.program.thread_manager._emit_parallel2d_fp_region(region) + + +@dataclass +class InputTile: + tile_id: str + input_name: str + coord: TileCoord + tile_shape: Tuple[int, int] + binding: Optional[str] = None + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass +class TensorTile: + tile_id: str + tensor_name: str + coord: TileCoord + tile_shape: Tuple[int, int] + binding: Optional[str] = None + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass +class VectorTile(TensorTile): + pass + + +@dataclass +class ValueTile: + value_tile_id: str + logical_shape: Tuple[int, int] + from_input_tile: bool = False + source_input_tile_id: Optional[str] = None + residency: Dict[str, Optional[int | str]] = field(default_factory=dict) + metadata: Dict[str, object] = field(default_factory=dict) + + +@dataclass +class ValueTileView: + backing_value_tile_id: str + owner_tile_id: str + row_offset: int + row_count: int + col_offset: int + col_count: int + metadata: Dict[str, object] = field(default_factory=dict) + + @property + def view_id(self) -> str: + lane = self.metadata.get("lane_index") + if isinstance(lane, int): + return f"{self.owner_tile_id}.lane{lane}" + return self.owner_tile_id + + +@dataclass +class PreparedWrite: + """Explicit write-preparation result for one tensor view update. + + `prepare_updated_view_value(...)` returns this object so callers do not need + to reverse-engineer write semantics from scattered booleans. + """ + old_value: ValueTile + new_value: ValueTile + target_view: ValueTileView + reuse_old: bool + requires_preserve_copy: bool = False + + +@dataclass +class Input: + program: "TileTensorProgram" + name: str + logical_shape: LogicalShape + tiles: Dict[TileCoord, InputTile] = field(default_factory=dict) + metadata: Dict[str, object] = field(default_factory=dict) + + def __post_init__(self) -> None: + self.tiles = self.program.tensor_manager.create_input_tiles(self.name, self.logical_shape) + + def __getitem__(self, item: SliceItem | Tuple[SliceItem, ...]) -> "InputSlice": + if not isinstance(item, tuple): + item = (item,) + if _contains_parallel_selector(item): + return ParallelAccess(base=self, selectors=item) + if _is_full_element_index(item, len(self.logical_shape)): + return ElementRef(base=self, indices=tuple(int(index) for index in item)) + return InputSlice(base=self, selectors=item) + + def __setitem__(self, item: SliceItem | Tuple[SliceItem, ...], value: object) -> None: + self.program.thread_manager.record_parallel_assignment_from_index(self, item, value) + + @property + def T(self) -> "InputTranspose": + return InputTranspose(base=self) + + +@dataclass +class Tensor: + program: "TileTensorProgram" + name: str + logical_shape: LogicalShape + tiles: Dict[TileCoord, TensorTile] = field(default_factory=dict) + metadata: Dict[str, object] = field(default_factory=dict) + + def __post_init__(self) -> None: + self.tiles = self.program.tensor_manager.create_tensor_tiles(self.name, self.logical_shape) + + def __getitem__(self, item: SliceItem | Tuple[SliceItem, ...]) -> "TensorSlice": + if not isinstance(item, tuple): + item = (item,) + if _contains_parallel_selector(item): + return ParallelAccess(base=self, selectors=item) + if _is_full_element_index(item, len(self.logical_shape)): + return ElementRef(base=self, indices=tuple(int(index) for index in item)) + return TensorSlice(base=self, selectors=item) + + def __setitem__(self, item: SliceItem | Tuple[SliceItem, ...], value: object) -> None: + self.program.thread_manager.record_parallel_assignment_from_index(self, item, value) + + @property + def T(self) -> "TensorTranspose": + return TensorTranspose(base=self) + + +@dataclass +class Vector(Tensor): + tiles: Dict[TileCoord, VectorTile] = field(default_factory=dict) + + def __post_init__(self) -> None: + self.tiles = self.program.vector_manager.create_vector_tiles(self.name, self.logical_shape) + + def __getitem__(self, item: SliceItem | Tuple[SliceItem, ...]) -> "VectorSlice": + if not isinstance(item, tuple): + item = (item,) + if _contains_parallel_selector(item): + return ParallelAccess(base=self, selectors=item) + if _is_full_element_index(item, len(self.logical_shape)): + return ElementRef(base=self, indices=tuple(int(index) for index in item)) + return VectorSlice(base=self, selectors=item) + + def __setitem__(self, item: SliceItem | Tuple[SliceItem, ...], value: object) -> None: + self.program.thread_manager.record_parallel_assignment_from_index(self, item, value) + + @property + def T(self) -> "VectorTranspose": + return VectorTranspose(base=self) + + +@dataclass +class InputSlice: + base: Input + selectors: Tuple[SliceItem, ...] + + +@dataclass +class TensorSlice: + base: Tensor + selectors: Tuple[SliceItem, ...] + + +@dataclass +class VectorSlice: + base: Vector + selectors: Tuple[SliceItem, ...] + + +@dataclass(frozen=True) +class InputTranspose: + base: Input + + @property + def program(self) -> "TileTensorProgram": + return self.base.program + + @property + def name(self) -> str: + return f"{self.base.name}.T" + + @property + def logical_shape(self) -> LogicalShape: + return self.base.logical_shape + + @property + def tiles(self) -> Dict[TileCoord, InputTile]: + return self.base.tiles + + +@dataclass(frozen=True) +class TensorTranspose: + base: Tensor + + @property + def program(self) -> "TileTensorProgram": + return self.base.program + + @property + def name(self) -> str: + return f"{self.base.name}.T" + + @property + def logical_shape(self) -> LogicalShape: + return self.base.logical_shape + + @property + def tiles(self) -> Dict[TileCoord, TensorTile]: + return self.base.tiles + + +@dataclass(frozen=True) +class VectorTranspose: + base: Vector + + @property + def program(self) -> "TileTensorProgram": + return self.base.program + + @property + def name(self) -> str: + return f"{self.base.name}.T" + + @property + def logical_shape(self) -> LogicalShape: + return self.base.logical_shape + + @property + def tiles(self) -> Dict[TileCoord, VectorTile]: + return self.base.tiles + + +TileLike = TensorTile | InputTile | VectorTile +TensorLike = Tensor | Input | Vector +TransposedTensorLike = TensorTranspose | InputTranspose | VectorTranspose +SourceValueLike = ValueTile +RowOperandLike = ValueTile | ValueTileView +ViewMatmulTerm = Tuple[List[TileLike], TileLike] +ViewMatmulThread = Tuple[TileLike, List[ViewMatmulTerm], int] +BTMMHeadGroupThread = Dict[str, object] +CopyMapvPacket = Tuple[str, ValueTile, TileLike] +MatmulMapvPacket = Tuple[str, List[List[SourceValueLike]], ValueTile, TileLike] +GemmMapvPacket = Tuple[str, List[List[SourceValueLike]], ValueTile, ValueTile, TileLike] +MapvPacket = CopyMapvPacket | MatmulMapvPacket | GemmMapvPacket + + + + + +# Bottom-of-file import: helpers used by dataclass method bodies. +# Placed here (not at top) to avoid a circular import — `_helpers` does +# `from ._types import *` at its top, which would fail before the classes +# defined above were registered in this module's namespace. +from ._helpers import ( + _build_parallel_execution_plan, + _coerce_parallel_expr, + _collect_parallel_predicates, + _contains_parallel_selector, + _is_full_element_index, + _parallel_access_identity, +) diff --git a/tilelang_runtime_compier/tile_tensor_program/_value_manager.py b/tilelang_runtime_compier/tile_tensor_program/_value_manager.py new file mode 100644 index 0000000..677d688 --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_value_manager.py @@ -0,0 +1,1942 @@ +"""ValueManager: backing-value bindings, view resolution, residency, write prep.""" + +from __future__ import annotations + +from math import ceil +from typing import Dict, List, Optional, Sequence, Tuple + +from ._types import * # noqa: F401,F403 +from ._helpers import * # noqa: F401,F403 + + +class ValueManager: + """Resolve logical tiles into backing values/views and manage residency. + + The value layer is responsible for: + + - direct `tile -> ValueTile` bindings + - `ValueTileView` resolution over shared backing values + - write preparation for mutating tensor destinations + - HBM/VRAM/MRAM residency transitions + - rebinding and release when compute produces updated values + + This class is the main implementation of the runtime's value layer. The + preferred write-preparation entrypoint is `prepare_updated_view_value(...)`. + """ + + def __init__(self, program: "TileTensorProgram") -> None: + self.program = program + self.isa_emitter = program.isa_emitter + self.value_tiles: Dict[str, ValueTile] = {} + self.full_tile_bindings: Dict[str, str] = {} + self.value_tile_tensor_refs: Dict[str, set[str]] = {} + self.narrow_group_bindings: Dict[Tuple[object, ...], str] = {} + self._value_tiles_in_vram: Dict[str, int] = {} + self._value_tiles_in_mram: Dict[str, int] = {} + self._value_tiles_in_hbm: Dict[str, object] = {} + self._mram_fifo: List[str] = [] + self._protected_vram_value_tile_ids: set[str] = set() + self._value_tile_counter = 0 + + @property + def bindings(self) -> Dict[str, str]: + # Compatibility alias for older scaffold/debug helpers. + return self.full_tile_bindings + + def _next_value_tile_id(self) -> str: + value_tile_id = f"value_tile.{self._value_tile_counter}" + self._value_tile_counter += 1 + return value_tile_id + + def mapv(self, signal: List[object]) -> MapvPacket: + """Resolve one mapped logical packet into concrete value-layer operands. + + Input packets come from TensorManager's `mapt` stage plus residency + targets and, optionally, one control tag. The function performs late + source resolution so compute sees the correct runtime object type: + + - wide/full tiles -> ValueTile + - narrow/grouped tiles -> shared backing ValueTile + + Destination resolution is also late here so updates can detach old + bindings and materialize one fresh writable value only when compute is + ready to run. + """ + control = None + if signal and isinstance(signal[-1], str): + control = signal[-1] + residency_targets = signal[-2] + signal_items = signal[:-2] + else: + residency_targets = signal[-1] + signal_items = signal[:-1] + + if control == "copy_tile_pair": + if len(signal_items) != 2 or not all(_is_tile_object(item) for item in signal_items): + raise RuntimeError("copy_tile_pair mapv expects [src_tile, dst_tile, residency_targets, control]") + src_tile, dst_tile = signal_items + src_value = self._resolve_mapv_source_value(src_tile, residency_targets[0]) + if not isinstance(src_value, ValueTile): + raise RuntimeError("copy mapv expects one full source ValueTile") + return ("copy", src_value, dst_tile) + + pair_groups, dst_tile = self._split_mapv_signal(signal_items) + mapped_pairs: List[List[object]] = [] + for pair in pair_groups: + if len(pair) != 2: + continue + src1_tile, src2_tile = pair + v1 = self._resolve_mapv_source_value(src1_tile, residency_targets[0]) + v2 = self._resolve_mapv_source_value(src2_tile, residency_targets[1]) + mapped_pairs.append([v1, v2]) + + if dst_tile is None: + raise RuntimeError("mapv expects one destination tensor tile") + if isinstance(dst_tile, TensorTile): + dst_view = self.resolve_value_tile_view(dst_tile) + prepared_write = self.prepare_updated_view_value( + dst_tile, + dst_view, + ensure_old_place=None, + new_place=residency_targets[2], + ) + v3 = prepared_write.new_value + else: + v3 = self._prepare_mapv_destination_value(dst_tile, residency_targets[2]) + return ("matmul", mapped_pairs, v3, dst_tile) + + def _resolve_mapv_source_value(self, tile: TensorTile | InputTile | VectorTile, place: str) -> SourceValueLike: + if isinstance(tile, VectorTile): + raise RuntimeError( + f"VectorTile {tile.tile_id} maps to FPFragment rather than ValueTile; " + "use mapf or ElementRef-based FP kernels" + ) + value = self._resolve_tile_backing_value(tile) + return value + + def _resolve_alias_owner_tile(self, tile: TileLike) -> TileLike: + if not bool(tile.metadata.get("slice_materialized", False)): + return tile + source_tile_id = tile.metadata.get("source_tile_id") + if not isinstance(source_tile_id, str): + return tile + owner_tile = self.program.tensor_manager.tensor_tiles.get(source_tile_id) + if owner_tile is None: + owner_tile = self.program.tensor_manager.input_tiles.get(source_tile_id) + if not isinstance(owner_tile, (TensorTile, InputTile, VectorTile)): + return tile + return owner_tile + + def _prepare_mapv_destination_value(self, tile: TensorTile | InputTile | VectorTile, place: str) -> ValueTile: + if isinstance(tile, VectorTile): + raise RuntimeError( + f"VectorTile {tile.tile_id} does not prepare one destination ValueTile; " + "bind it to FPFragment through ValueManager" + ) + canonical_tile = self._resolve_alias_owner_tile(tile) + if canonical_tile is not tile and not self._is_narrow_tensor_tile(tile): + tile = canonical_tile + if isinstance(tile, TensorTile) and not self._is_narrow_tensor_tile(tile): + old_value = self.resolve_value_tile(tile) + old_value_tile_id = self._detach_tile_value_pointer(tile.tile_id, auto_free=False) + if old_value_tile_id is None: + raise RuntimeError(f"Wide destination tile {tile.tile_id} had no bound value to detach") + new_value = self.prepare_vram_backing_value(old_value) + self._attach_tile_value_pointer(tile.tile_id, new_value.value_tile_id) + self.free_value_tile(old_value_tile_id) + return new_value + dst_source_value = self.resolve_value_tile(tile) + value = self.prepare_vram_backing_value(dst_source_value) + return value + + def _is_packed_narrow_tile(self, tile: TileLike) -> bool: + return int(tile.metadata.get("packed_head_count", 1)) > 1 or bool(tile.metadata.get("packed_head_group", False)) + + def _is_grouped_narrow_backing_tile(self, tile: TileLike) -> bool: + return self._is_packed_narrow_tile(tile) + + def _is_narrow_tensor_tile(self, tile: TileLike) -> bool: + width_class = tile.metadata.get("tile_width_class") + if width_class == "narrow": + return True + if width_class == "full": + return False + return int(tile.tile_shape[1]) < int(self.program.mlen) + + def _view_group_key_for_tile(self, tile: TileLike) -> Tuple[object, ...]: + owner_name = _tile_owner_name(tile) + if bool(tile.metadata.get("packed_head_group", False)): + head_index = int(tile.metadata.get("group_head_start", tile.metadata.get("head_index", 0))) + else: + head_index = int(tile.metadata.get("head_index", 0)) + row_block = int(tile.metadata.get("row_block", tile.coord[0])) + return (owner_name, head_index, row_block) + + def _view_slot_key_for_tile(self, tile: TileLike) -> Tuple[object, ...]: + owner_name = _tile_owner_name(tile) + head_index = int(tile.metadata.get("slot_head_index", tile.metadata.get("head_index", 0))) + row_block = int(tile.metadata.get("row_block", tile.coord[0])) + col_offset = int(tile.metadata.get("scatter_col_offset", tile.coord[1] * self.program.mlen)) + col_count = int(tile.metadata.get("scatter_col_count", tile.tile_shape[1])) + return (owner_name, head_index, row_block, col_offset, col_count) + + def _tiles_sharing_backing(self, tile: TensorTile | InputTile) -> List[TensorTile | InputTile]: + if not self._is_narrow_tensor_tile(tile): + return [tile] + if self._is_packed_narrow_tile(tile): + return [tile] + return self._iter_group_tiles(tile) + + def _bind_tiles_to_value(self, tiles: Sequence[TensorTile | InputTile], value_tile_id: str) -> List[str]: + detached_ids: List[str] = [] + for tile in tiles: + old_value_tile_id = self._detach_tile_value_pointer(tile.tile_id) + if old_value_tile_id is not None and old_value_tile_id != value_tile_id: + detached_ids.append(old_value_tile_id) + self._attach_tile_value_pointer(tile.tile_id, value_tile_id) + return detached_ids + + def _rebind_view_group_value(self, tile: TensorTile | InputTile, new_value: ValueTile) -> None: + group_tiles = self._tiles_sharing_backing(tile) + if self._is_narrow_tensor_tile(tile): + self.narrow_group_bindings[self._view_group_key_for_tile(tile)] = new_value.value_tile_id + detached_ids = self._bind_tiles_to_value(group_tiles, new_value.value_tile_id) + for old_value_tile_id in sorted(set(detached_ids)): + self.free_value_tile(old_value_tile_id) + + def _iter_value_tile_views(self, value_tile_id: str) -> List[ValueTileView]: + tile_ids = sorted(self.value_tile_tensor_refs.get(value_tile_id, set())) + views: List[ValueTileView] = [] + for tile_id in tile_ids: + tile = self.program.tensor_manager.tensor_tiles.get(tile_id) + if tile is None: + tile = self.program.tensor_manager.input_tiles.get(tile_id) + if not isinstance(tile, (TensorTile, InputTile)): + continue + for view in self._tile_compute_views(tile): + if view.backing_value_tile_id == value_tile_id: + views.append(view) + return views + + def _views_overlap(self, lhs: ValueTileView, rhs: ValueTileView) -> bool: + lhs_row_end = int(lhs.row_offset) + int(lhs.row_count) + rhs_row_end = int(rhs.row_offset) + int(rhs.row_count) + lhs_col_end = int(lhs.col_offset) + int(lhs.col_count) + rhs_col_end = int(rhs.col_offset) + int(rhs.col_count) + return not ( + lhs_row_end <= int(rhs.row_offset) + or rhs_row_end <= int(lhs.row_offset) + or lhs_col_end <= int(rhs.col_offset) + or rhs_col_end <= int(lhs.col_offset) + ) + + def _same_view_identity(self, lhs: ValueTileView, rhs: ValueTileView) -> bool: + return ( + lhs.backing_value_tile_id == rhs.backing_value_tile_id + and lhs.owner_tile_id == rhs.owner_tile_id + and int(lhs.row_offset) == int(rhs.row_offset) + and int(lhs.row_count) == int(rhs.row_count) + and int(lhs.col_offset) == int(rhs.col_offset) + and int(lhs.col_count) == int(rhs.col_count) + ) + + def view_has_conflicting_refs(self, view: ValueTileView) -> bool: + for other_view in self._iter_value_tile_views(view.backing_value_tile_id): + if self._same_view_identity(view, other_view): + continue + if self._views_overlap(view, other_view): + return True + return False + + def prepare_updated_view_value( + self, + tile: TensorTile | InputTile, + view: ValueTileView, + *, + ensure_old_place: Optional[str] = None, + new_place: str = "vram", + ) -> PreparedWrite: + """Prepare one mutating tensor-view write. + + This is the main write-path helper for tensor destinations. + + Returned `PreparedWrite` tells the caller: + - whether the write is in-place (`reuse_old`) + - which backing value should receive the write (`new_value`) + - which view on the new backing should be targeted (`target_view`) + - whether a partial-update preserve copy is still required + (`requires_preserve_copy`) + """ + old_value = self.value_tiles.get(view.backing_value_tile_id) + if not isinstance(old_value, ValueTile): + raise RuntimeError(f"View {view.view_id} is missing backing value {view.backing_value_tile_id}") + if ensure_old_place is not None: + self.ensure_value_tile_in_place(old_value, ensure_old_place) + if not self.view_has_conflicting_refs(view): + self.ensure_value_tile_in_place(old_value, new_place) + if new_place == "vram": + self._drop_stale_non_vram_residency(old_value) + return PreparedWrite( + old_value=old_value, + new_value=old_value, + target_view=view, + reuse_old=True, + requires_preserve_copy=False, + ) + requires_preserve_copy = False + self.protect_value_tile(old_value, "vram") + try: + if self._view_covers_logical_tile(tile, view): + new_value = self.prepare_vram_backing_value(old_value, preserve_existing=True) + else: + new_value = self._prepare_partial_update_vram_successor(old_value) + if new_value is None: + new_value = self.prepare_vram_backing_value(old_value, preserve_existing=True) + requires_preserve_copy = True + self._rebind_view_group_value(tile, new_value) + finally: + self.stop_protect_value_tile(old_value, "vram") + self.ensure_value_tile_in_place(new_value, new_place) + return PreparedWrite( + old_value=old_value, + new_value=new_value, + target_view=self.rebind_view(view, new_value), + reuse_old=False, + requires_preserve_copy=requires_preserve_copy, + ) + + def resolve_value_tile_view(self, tile: TensorTile | InputTile) -> ValueTileView: + backing_value = self.resolve_value_tile(tile) + if self._is_packed_narrow_tile(tile): + return ValueTileView( + backing_value_tile_id=backing_value.value_tile_id, + owner_tile_id=tile.tile_id, + row_offset=0, + row_count=int(tile.tile_shape[0]), + col_offset=0, + col_count=int(tile.tile_shape[1]), + metadata={"slot_key": self._view_group_key_for_tile(tile), "kind": "packed_tile"}, + ) + if self._is_narrow_tensor_tile(tile): + slot_key = self._view_slot_key_for_tile(tile) + return ValueTileView( + backing_value_tile_id=backing_value.value_tile_id, + owner_tile_id=tile.tile_id, + row_offset=0, + row_count=int(tile.tile_shape[0]), + col_offset=int(slot_key[3]), + col_count=int(slot_key[4]), + metadata={"slot_key": slot_key, "kind": "narrow_tile"}, + ) + return ValueTileView( + backing_value_tile_id=backing_value.value_tile_id, + owner_tile_id=tile.tile_id, + row_offset=0, + row_count=int(tile.tile_shape[0]), + col_offset=0, + col_count=int(tile.tile_shape[1]), + metadata={"kind": "full_tile"}, + ) + + def _tile_compute_views(self, tile: TensorTile | InputTile) -> List[ValueTileView]: + if not self._is_packed_narrow_tile(tile): + return [self.resolve_value_tile_view(tile)] + backing_value = self.resolve_value_tile(tile) + packed_heads = int(tile.metadata.get("packed_head_count", 1)) + slot_width = int(tile.metadata.get("scatter_slot_width", tile.tile_shape[1])) + views: List[ValueTileView] = [] + for lane_index in range(packed_heads): + views.append( + ValueTileView( + backing_value_tile_id=backing_value.value_tile_id, + owner_tile_id=tile.tile_id, + row_offset=0, + row_count=int(tile.tile_shape[0]), + col_offset=lane_index * slot_width, + col_count=slot_width, + metadata={"lane_index": lane_index, "kind": "packed_lane"}, + ) + ) + return views + + def resolve_row_operand(self, tile: TensorTile | InputTile, place: str = "vram") -> RowOperandLike: + if self._is_narrow_tensor_tile(tile): + view = self.resolve_value_tile_view(tile) + return view + value = self.resolve_value_tile(tile) + return value + + def resolve_row_operand_for_ranges( + self, + tile: TensorTile | InputTile, + row_range: Tuple[int, int], + col_range: Tuple[int, int], + place: str = "vram", + ) -> RowOperandLike: + if not self._is_narrow_tensor_tile(tile): + return self.resolve_row_operand(tile, place) + + row_block, col_block = tile.coord + row_start = row_block * self.program.mlen + row_end = row_start + int(tile.tile_shape[0]) + col_start = col_block * self.program.mlen + col_end = col_start + int(tile.tile_shape[1]) + if not _ranges_overlap((row_start, row_end), row_range) or not _ranges_overlap((col_start, col_end), col_range): + raise RuntimeError( + f"Requested row operand slice row_range={row_range} col_range={col_range} does not overlap tile {tile.tile_id}" + ) + + overlap_col_start = max(col_start, col_range[0]) + overlap_col_end = min(col_end, col_range[1]) + overlap_col_offset = int(overlap_col_start - col_start) + overlap_col_count = int(overlap_col_end - overlap_col_start) + if overlap_col_count <= 0: + raise RuntimeError(f"Resolved empty column overlap for tile {tile.tile_id}") + + if overlap_col_offset == 0 and overlap_col_count == int(tile.tile_shape[1]): + return self.resolve_row_operand(tile, place) + + slot_width = int(tile.metadata.get("scatter_slot_width", overlap_col_count)) + if overlap_col_offset % slot_width != 0 or overlap_col_count % slot_width != 0: + raise RuntimeError( + f"Slice overlap for tile {tile.tile_id} is not aligned to slot width {slot_width}: " + f"offset={overlap_col_offset} count={overlap_col_count}" + ) + backing_value = self.resolve_value_tile(tile) + return ValueTileView( + backing_value_tile_id=backing_value.value_tile_id, + owner_tile_id=tile.tile_id, + row_offset=0, + row_count=int(tile.tile_shape[0]), + col_offset=int(overlap_col_offset), + col_count=int(overlap_col_count), + metadata={ + "slot_width": slot_width, + "lane_index": overlap_col_offset // slot_width, + "source": "slice_range", + }, + ) + + def rebind_view(self, view: ValueTileView, new_value: ValueTile) -> ValueTileView: + return ValueTileView( + backing_value_tile_id=new_value.value_tile_id, + owner_tile_id=view.owner_tile_id, + row_offset=int(view.row_offset), + row_count=int(view.row_count), + col_offset=int(view.col_offset), + col_count=int(view.col_count), + metadata=dict(view.metadata), + ) + + def _drop_stale_non_vram_residency(self, value: ValueTile) -> None: + mram_name = value.residency.pop("mram_name", None) + if mram_name is not None: + self.program.compiler.sub_matrix_manager.mram_allocator.free(str(mram_name), strict=False) + value.residency.pop("mram_addr", None) + self._value_tiles_in_mram.pop(value.value_tile_id, None) + self._mram_fifo[:] = [item for item in self._mram_fifo if item != value.value_tile_id] + + # HBM residency is preserved: the tile remains valid in HBM while also + # resident in VRAM. Only MRAM is evicted on HBM→VRAM moves. + + def _view_covers_logical_tile(self, tile: TensorTile | InputTile, view: ValueTileView) -> bool: + return ( + int(view.row_offset) == 0 + and int(view.col_offset) == 0 + and int(view.row_count) == int(tile.tile_shape[0]) + and int(view.col_count) == int(tile.tile_shape[1]) + ) + + def _prepare_partial_update_vram_successor(self, old_value: ValueTile) -> Optional[ValueTile]: + has_hbm_backing = ( + old_value.residency.get("hbm_addr") is not None + and old_value.residency.get("hbm_name") is not None + and bool(old_value.residency.get("hbm_ready")) + ) + old_vram_addr = old_value.residency.get("vram_addr") + if not has_hbm_backing or old_vram_addr is None: + return None + + new_value = ValueTile( + value_tile_id=self._next_value_tile_id(), + logical_shape=old_value.logical_shape, + metadata=dict(old_value.metadata), + ) + new_value.from_input_tile = old_value.from_input_tile + new_value.source_input_tile_id = old_value.source_input_tile_id + new_value.residency["vram_addr"] = old_value.residency.pop("vram_addr") + new_value.residency["vram_name"] = old_value.residency.pop("vram_name", None) + new_value.residency["vram_owner_from"] = old_value.value_tile_id + self._value_tiles_in_vram.pop(old_value.value_tile_id, None) + self._value_tiles_in_vram[new_value.value_tile_id] = int(new_value.residency["vram_addr"]) + self.value_tiles[new_value.value_tile_id] = new_value + return new_value + + def protect_value_tile(self, value: ValueTile, place: str = "vram") -> None: + if place != "vram": + raise ValueError(f"Unsupported protect place: {place}") + already_protected = value.value_tile_id in self._protected_vram_value_tile_ids + self._protected_vram_value_tile_ids.add(value.value_tile_id) + + def stop_protect_value_tile(self, value: Optional[ValueTile] = None, place: str = "vram") -> None: + if place != "vram": + raise ValueError(f"Unsupported protect place: {place}") + if value is None: + if not self._protected_vram_value_tile_ids: + return + self._protected_vram_value_tile_ids.clear() + return + if value.value_tile_id not in self._protected_vram_value_tile_ids: + return + self._protected_vram_value_tile_ids.remove(value.value_tile_id) + + def _is_protected_value_tile(self, value_tile_id: str, place: str = "vram") -> bool: + if place != "vram": + return False + return value_tile_id in self._protected_vram_value_tile_ids + + def _create_value_tile_for_tile(self, tile: TensorTile | InputTile, *, bind_tile_pointer: bool = True) -> ValueTile: + if bind_tile_pointer: + existing_id = self.full_tile_bindings.get(tile.tile_id) + if existing_id is not None: + existing = self.value_tiles.get(existing_id) + if existing is not None: + return existing + value_tile = ValueTile( + value_tile_id=self._next_value_tile_id(), + logical_shape=tile.tile_shape, + from_input_tile=isinstance(tile, InputTile), + source_input_tile_id=tile.tile_id if isinstance(tile, InputTile) else None, + metadata={ + **dict(tile.metadata), + "source_tile_id": tile.tile_id, + }, + ) + if isinstance(tile, InputTile): + hbm_name = f"{tile.input_name}.hbm" + logical_shape = tuple(tile.metadata.get("logical_shape", ())) + hbm_stride = _logical_shape_to_hbm_stride(logical_shape) + hbm_offset = _tile_coord_to_hbm_offset(tile.coord, logical_shape, self.program.mlen) + hbm_addr = self.allocate_value_tile_address( + size=self.program.tile_elems, + name=f"{value_tile.value_tile_id}.hbm", + place="hbm", + value_tile=value_tile, + hbm_name=hbm_name, + hbm_offset=hbm_offset, + hbm_stride=hbm_stride if hbm_stride > 0 else self.program.mlen, + ) + value_tile.residency["hbm_addr"] = hbm_addr + value_tile.residency["hbm_name"] = hbm_name + value_tile.residency["hbm_offset"] = hbm_offset + value_tile.residency["hbm_stride"] = hbm_stride if hbm_stride > 0 else self.program.mlen + value_tile.residency["hbm_ready"] = True + self.value_tiles[value_tile.value_tile_id] = value_tile + if bind_tile_pointer: + self._bind_tile_pointer(tile.tile_id, value_tile.value_tile_id) + return value_tile + + def create_value_tile_in_fpram_for_tile( + self, + tile: TensorTile | InputTile, + fragment: FPFragment, + *, + bind: bool = True, + metadata: Optional[Dict[str, object]] = None, + ) -> ValueTile: + value = self.create_value_tile_in_fpram_from_fp_fragment( + fragment, + logical_shape=tile.tile_shape, + metadata={ + **dict(tile.metadata), + **(dict(metadata) if metadata is not None else {}), + "source_tile_id": tile.tile_id, + "source_fragment_name": fragment.name, + }, + ) + if bind: + if isinstance(tile, InputTile): + self._write_value_back_to_input_tile(value, tile) + else: + self._bind_value_to_tensor_tile(value, tile) + return value + + def _iter_group_tiles(self, tile: TensorTile | InputTile) -> List[TensorTile | InputTile]: + owner_tiles = self._owner_tiles_for_tile(tile) + group_key = self._view_group_key_for_tile(tile) + candidates: List[TensorTile | InputTile] = [] + for candidate in _tiles_in_grid_order(owner_tiles): + if not isinstance(candidate, (TensorTile, InputTile)): + continue + if not self._is_narrow_tensor_tile(candidate): + continue + if self._view_group_key_for_tile(candidate) != group_key: + continue + candidates.append(candidate) + return candidates + + def _owner_tiles_for_tile(self, tile: TensorTile | InputTile) -> Dict[TileCoord, TensorTile | InputTile]: + if isinstance(tile, TensorTile): + owner = self.program.tensor_manager.tensors.get(tile.tensor_name) + if owner is None: + raise RuntimeError(f"Unknown tensor owner for tile {tile.tile_id}: {tile.tensor_name}") + return owner.tiles + owner = self.program.tensor_manager.inputs.get(tile.input_name) + if owner is None: + raise RuntimeError(f"Unknown input owner for tile {tile.tile_id}: {tile.input_name}") + return owner.tiles + + def _split_mapv_signal(self, items: List[object]) -> Tuple[List[List[object]], Optional[TileLike]]: + pair_groups: List[List[object]] = [] + dst_tile: Optional[TileLike] = None + for item in items: + if isinstance(item, list) and len(item) == 2 and all(_is_tile_object(part) for part in item): + pair_groups.append(item) + continue + if isinstance(item, list) and len(item) == 1 and isinstance(item[0], (TensorTile, InputTile, VectorTile)): + dst_tile = item[0] + continue + return pair_groups, dst_tile + + def _resolve_tile_backing_value(self, tile: TensorTile | InputTile) -> ValueTile: + canonical_tile = self._resolve_alias_owner_tile(tile) + if canonical_tile is not tile and not self._is_narrow_tensor_tile(tile): + tile = canonical_tile + if self._is_narrow_tensor_tile(tile): + existing_id = self.full_tile_bindings.get(tile.tile_id) + if existing_id is not None: + existing = self.value_tiles.get(existing_id) + if existing is not None: + return existing + group_key = self._view_group_key_for_tile(tile) + group_value_id = self.narrow_group_bindings.get(group_key) + if group_value_id is not None: + existing = self.value_tiles.get(group_value_id) + if existing is not None: + self._bind_tiles_to_value(self._tiles_sharing_backing(tile), existing.value_tile_id) + return existing + value = self._create_value_tile_for_tile(tile, bind_tile_pointer=False) + self.narrow_group_bindings[group_key] = value.value_tile_id + self._bind_tiles_to_value(self._tiles_sharing_backing(tile), value.value_tile_id) + return value + existing_id = self.full_tile_bindings.get(tile.tile_id) + if existing_id is not None: + existing = self.value_tiles.get(existing_id) + if existing is not None: + return existing + return self._create_value_tile_for_tile(tile, bind_tile_pointer=True) + + def resolve_value_tile(self, tile: TensorTile | InputTile) -> ValueTile: + return self._resolve_tile_backing_value(tile) + + def get_value_tile(self, tile: TensorTile | InputTile) -> ValueTile: + # Compatibility wrapper around resolve_value_tile(). + return self.resolve_value_tile(tile) + + def bind_tile_to_fp_fragment(self, tile: VectorTile, fragment: FPFragment) -> FPFragment: + return self.program.vector_manager.bind_tile_to_fp_fragment(tile, fragment) + + def resolve_fp_fragment(self, tile: VectorTile) -> FPFragment: + return self.program.vector_manager.resolve_fp_fragment(tile) + + def _value_tile_has_live_refs(self, value_tile_id: str) -> bool: + if self.value_tile_tensor_refs.get(value_tile_id): + return True + return False + + def _value_debug_state(self, value: ValueTile) -> Dict[str, object]: + tensor_refs = sorted(self.value_tile_tensor_refs.get(value.value_tile_id, set())) + residency = value.residency + return { + "value_tile_id": value.value_tile_id, + "from_input_tile": bool(value.from_input_tile), + "source_input_tile_id": value.source_input_tile_id, + "vram_addr": residency.get("vram_addr"), + "mram_addr": residency.get("mram_addr"), + "hbm_name": residency.get("hbm_name"), + "hbm_addr": residency.get("hbm_addr"), + "hbm_offset": residency.get("hbm_offset"), + "hbm_stride": residency.get("hbm_stride"), + "hbm_scale_size": residency.get("hbm_scale_size"), + "hbm_ready": residency.get("hbm_ready"), + "tensor_refs": tensor_refs, + "last_move": value.metadata.get("last_move"), + } + + def _tile_debug_state(self, tile: TensorTile | InputTile) -> Dict[str, object]: + state: Dict[str, object] = { + "tile_id": tile.tile_id, + "coord": tile.coord, + "tile_shape": tile.tile_shape, + "kind": type(tile).__name__, + } + if isinstance(tile, InputTile): + state["owner"] = tile.input_name + elif isinstance(tile, TensorTile): + state["owner"] = tile.tensor_name + logical_shape = tile.metadata.get("logical_shape") + if logical_shape is not None: + state["logical_shape"] = logical_shape + return state + + def prepare_vram_backing_value( + self, + value: Optional[ValueTile] = None, + *, + preserve_existing: bool = False, + ) -> ValueTile: + if value is not None and not preserve_existing and not self._value_tile_has_live_refs(value.value_tile_id): + self.ensure_value_tile_in_place(value, "vram") + return value + new_value_tile = ValueTile( + value_tile_id=self._next_value_tile_id(), + logical_shape=value.logical_shape if value is not None else (self.program.mlen, self.program.mlen), + metadata=dict(value.metadata) if value is not None else {}, + ) + if value is not None: + new_value_tile.from_input_tile = value.from_input_tile + new_value_tile.source_input_tile_id = value.source_input_tile_id + has_live_refs = self._value_tile_has_live_refs(value.value_tile_id) + can_transfer_vram = ( + value.residency.get("vram_addr") is not None + and not self._is_protected_value_tile(value.value_tile_id, "vram") + and not has_live_refs + ) + if can_transfer_vram: + new_value_tile.residency["vram_addr"] = value.residency.pop("vram_addr") + new_value_tile.residency["vram_name"] = value.residency.pop("vram_name", None) + new_value_tile.residency["vram_owner_from"] = value.value_tile_id + old_addr = self._value_tiles_in_vram.pop(value.value_tile_id, None) + if old_addr is not None: + self._value_tiles_in_vram[new_value_tile.value_tile_id] = old_addr + elif ( + has_live_refs + and ( + value.residency.get("vram_addr") is not None + or value.residency.get("hbm_addr") is not None + or value.residency.get("hbm_ready") + ) + ): + self.ensure_value_tile_in_place(value, "hbm") + if new_value_tile.residency.get("vram_addr") is None: + vram_name = f"{new_value_tile.value_tile_id}.vram" + vram_addr = self.allocate_value_tile_address( + size=self.program.tile_elems, + name=vram_name, + place="vram", + value_tile=new_value_tile, + ) + new_value_tile.residency["vram_addr"] = vram_addr + new_value_tile.residency["vram_name"] = vram_name + self._value_tiles_in_vram[new_value_tile.value_tile_id] = vram_addr + self.value_tiles[new_value_tile.value_tile_id] = new_value_tile + return new_value_tile + + def create_value_tile_in_fpram( + self, + *, + logical_shape: Tuple[int, int], + fpram_addr: int, + fpram_size: int, + fpram_name: str, + metadata: Optional[Dict[str, object]] = None, + ) -> ValueTile: + value_tile = ValueTile( + value_tile_id=self._next_value_tile_id(), + logical_shape=tuple(int(dim) for dim in logical_shape), + metadata=dict(metadata) if metadata is not None else {}, + ) + value_tile.residency["fpram_addr"] = int(fpram_addr) + value_tile.residency["fpram_name"] = str(fpram_name) + value_tile.residency["fpram_size"] = int(fpram_size) + value_tile.residency["fpram_ready"] = True + self.value_tiles[value_tile.value_tile_id] = value_tile + return value_tile + + def create_value_tile_in_fpram_from_fp_fragment( + self, + fragment: FPFragment, + *, + logical_shape: Optional[Tuple[int, int]] = None, + metadata: Optional[Dict[str, object]] = None, + ) -> ValueTile: + fragment_shape = tuple(int(dim) for dim in fragment.shape) + tile_rows, tile_cols = _fp_fragment_shape_to_tile_shape( + fragment_shape, + mlen=self.program.mlen, + btmm_hlen=self.program.btmm_hlen, + ) + fp_vars = [fragment.vars[index] for index in _iter_fp_indices(fragment_shape)] + fp_addrs = [_require_fp_addr(fp_var) for fp_var in fp_vars] + fp_prog = self.program._arith_progression(fp_addrs) + expected_cells = tile_rows * tile_cols + if len(fp_addrs) != expected_cells: + raise RuntimeError( + f"FPFragment {fragment.name!r} expected {expected_cells} FP cells for one tile, got {len(fp_addrs)}" + ) + fp_base_addr = int(fp_addrs[0]) if fp_addrs else 0 + fp_dense = bool(fp_prog is not None and fp_prog[1] == expected_cells and fp_prog[2] == 1) + + return self.create_value_tile_in_fpram( + logical_shape=logical_shape if logical_shape is not None else (tile_rows, tile_cols), + fpram_addr=int(fp_base_addr), + fpram_size=int(expected_cells), + fpram_name=fragment.name, + metadata={ + **(dict(metadata) if metadata is not None else {}), + "fp_fragment_name": fragment.name, + "fp_fragment_shape": fragment_shape, + "fp_materialized_tile_shape": (tile_rows, tile_cols), + "fp_fragment_dense": fp_dense, + }, + ) + + def _resolve_value_fp_fragment(self, value: ValueTile) -> FPFragment: + fragment_name = value.metadata.get("fp_fragment_name") + if not isinstance(fragment_name, str): + raise RuntimeError( + f"fpram-backed value tile {value.value_tile_id} is missing fp_fragment_name metadata" + ) + fragment = self.program.tensor_manager.fp_fragments.get(fragment_name) + if not isinstance(fragment, FPFragment): + raise RuntimeError( + f"fpram-backed value tile {value.value_tile_id} references missing FPFragment {fragment_name!r}" + ) + return fragment + + def _temporary_fpram_row_scratch(self, row_width: int, *, value_tile_id: str, row_index: int) -> Tuple[str, int]: + allocator = self.program.compiler.sub_matrix_manager.fpram_allocator + floor = int(self.program.tensor_manager._next_fp_mem_addr) + if allocator.next_free < floor: + allocator.next_free = floor + allocator.free_stack[:] = [ + block for block in allocator.free_stack + if int(block.addr) >= floor + ] + scratch_name = f"__fpram_row_scratch__.{value_tile_id}.row{row_index}" + scratch_addr = allocator.allocate(scratch_name, row_width) + return scratch_name, int(scratch_addr) + + def _tile_vram_class(self, tile_id: str) -> str: + tile = self.program.tensor_manager.tensor_tiles.get(tile_id) + if tile is None: + tile = self.program.tensor_manager.input_tiles.get(tile_id) + if not isinstance(tile, (TensorTile, InputTile, VectorTile)): + return "l0" + raw_class = tile.metadata.get("vram_class", "l0") + if raw_class == "shared": + return "shared" + return "l0" + + def _value_tile_ref_counts(self, value_tile_id: str) -> Dict[str, int]: + ref_counts = {"shared": 0, "l0": 0} + for tile_id in self.value_tile_tensor_refs.get(value_tile_id, set()): + ref_counts[self._tile_vram_class(tile_id)] = ref_counts.get(self._tile_vram_class(tile_id), 0) + 1 + return ref_counts + + def _metadata_vram_class(self, metadata: Optional[Dict[str, object]]) -> str: + if metadata is None: + return "l0" + raw_class = metadata.get("vram_class", "l0") + if raw_class == "shared": + return "shared" + return "l0" + + def _value_tile_vram_protection_score(self, value: Optional[ValueTile]) -> int: + if value is None: + return 0 + ref_counts = self._value_tile_ref_counts(value.value_tile_id) + shared_refs = int(ref_counts.get("shared", 0)) + total_refs = sum(int(count) for count in ref_counts.values()) + return shared_refs * 1000 + total_refs * 10 + + def _request_vram_protection_score( + self, + *, + value: Optional[ValueTile] = None, + metadata: Optional[Dict[str, object]] = None, + ) -> int: + if value is not None: + live_score = self._value_tile_vram_protection_score(value) + if live_score > 0: + return live_score + source_tile_id = value.metadata.get("source_tile_id") + if isinstance(source_tile_id, str) and self._tile_vram_class(source_tile_id) == "shared": + return 1010 + if self._metadata_vram_class(value.metadata) == "shared": + return 1010 + return 10 + if self._metadata_vram_class(metadata) == "shared": + return 1010 + return 10 + + def _vram_eviction_candidates(self, *, max_score: Optional[int] = None) -> List[str]: + arrival_order = {value_tile_id: index for index, value_tile_id in enumerate(self._value_tiles_in_vram.keys())} + candidates: List[str] = [] + for value_tile_id in self._value_tiles_in_vram.keys(): + if self._is_protected_value_tile(value_tile_id, "vram"): + continue + value = self.value_tiles.get(value_tile_id) + if value is None: + continue + score = self._value_tile_vram_protection_score(value) + if max_score is not None and score > int(max_score): + continue + candidates.append(value_tile_id) + candidates.sort( + key=lambda value_tile_id: ( + self._value_tile_vram_protection_score(self.value_tiles.get(value_tile_id)), + arrival_order.get(value_tile_id, 0), + ) + ) + return candidates + + def _eviction_count_needed_for_vram_request(self, tile_count: int) -> int: + capacity = getattr(self.program, "vram_tile_capacity", 0) + if capacity <= 0: + return 0 + return max(0, len(self._value_tiles_in_vram) + int(tile_count) - capacity) + + def evaluate_contiguous_vram_value_tile_window( + self, + *, + tile_count: int, + metadata: Optional[Dict[str, object]] = None, + reason: str = "contiguous_vram_window", + ) -> Dict[str, object]: + if tile_count <= 0: + raise ValueError(f"tile_count must be positive, got {tile_count}") + + allocator = self.program.compiler.sub_matrix_manager.vram_allocator + tile_size = self.program.tile_elems + window_size = tile_count * tile_size + candidates: List[Dict[str, object]] = [] + request_score = self._request_vram_protection_score(metadata=metadata) + eviction_count = self._eviction_count_needed_for_vram_request(tile_count) + eviction_candidates = self._vram_eviction_candidates(max_score=request_score) + eviction_penalty = 0 + if eviction_count > 0: + if len(eviction_candidates) < eviction_count: + eviction_penalty = 10**12 + else: + eviction_penalty = sum( + 1 + max(0, self._value_tile_vram_protection_score(self.value_tiles.get(value_tile_id))) + for value_tile_id in eviction_candidates[:eviction_count] + ) + + for block in sorted(allocator.free_stack, key=lambda item: (item.size, item.addr)): + if block.size < window_size: + continue + waste = int(block.size - window_size) + candidates.append( + { + "kind": "free_stack", + "addr": int(block.addr), + "size": int(block.size), + "cost": waste + eviction_penalty, + "waste": waste, + "block_name": block.name, + "eviction_count": eviction_count, + "eviction_penalty": eviction_penalty, + } + ) + + aligned_bump_addr = ((int(allocator.next_free) + tile_size - 1) // tile_size) * tile_size + candidates.append( + { + "kind": "bump", + "addr": aligned_bump_addr, + "size": window_size, + "cost": window_size + eviction_penalty, + "waste": 0, + "block_name": "", + "eviction_count": eviction_count, + "eviction_penalty": eviction_penalty, + } + ) + + candidates.sort(key=lambda item: (int(item["cost"]), int(item["waste"]), int(item["addr"]))) + chosen = dict(candidates[0]) + plan = { + "reason": reason, + "tile_count": tile_count, + "tile_size": tile_size, + "window_size": window_size, + "chosen": chosen, + "candidates": candidates, + } + return plan + + def allocate_contiguous_vram_value_tiles( + self, + *, + tile_count: int, + logical_shape: Optional[Tuple[int, int]] = None, + metadata: Optional[Dict[str, object]] = None, + reason: str = "contiguous_vram_window", + ) -> Tuple[List[ValueTile], int]: + plan = self.evaluate_contiguous_vram_value_tile_window( + tile_count=tile_count, + metadata=metadata, + reason=reason, + ) + request_score = self._request_vram_protection_score(metadata=metadata) + self._evict_fifo_if_needed("vram", required_tiles=tile_count, max_score=request_score) + alloc_name = f"contiguous_values.{self._next_value_tile_id()}.vram" + window_size = int(plan["window_size"]) + tile_size = int(plan["tile_size"]) + base_addr = self.program.compiler.sub_matrix_manager.vram_allocator.allocate(size=window_size, name=alloc_name) + + template_metadata = dict(metadata) if metadata is not None else {} + reserved_values: List[ValueTile] = [] + for lane in range(tile_count): + value = ValueTile( + value_tile_id=self._next_value_tile_id(), + logical_shape=logical_shape if logical_shape is not None else (self.program.mlen, self.program.mlen), + metadata={ + **template_metadata, + "contiguous_lane_index": lane, + }, + ) + vram_addr = base_addr + lane * tile_size + value.residency["vram_addr"] = vram_addr + value.residency["vram_name"] = alloc_name + value.residency["vram_lane_index"] = lane + self.value_tiles[value.value_tile_id] = value + self._value_tiles_in_vram[value.value_tile_id] = vram_addr + self._touch_fifo("vram", value.value_tile_id) + reserved_values.append(value) + + return reserved_values, base_addr + + def ensure_value_tile_in_place(self, value: ValueTile, place: str) -> ValueTile: + if place == "vram": + if value.residency.get("vram_addr") is not None: + return value + if value.residency.get("fpram_ready"): + vram_name = value.residency.get("vram_name") or f"{value.value_tile_id}.vram" + vram_addr = self.allocate_value_tile_address( + size=self.program.tile_elems, + name=str(vram_name), + place="vram", + value_tile=value, + ) + value.residency["vram_addr"] = vram_addr + value.residency["vram_name"] = vram_name + self.move_tile(value, "fpram", "vram") + self._value_tiles_in_vram[value.value_tile_id] = vram_addr + return value + # Fresh output/scratch values may not have any HBM provenance yet. + # For those, materialize directly in VRAM instead of forcing an HBM round-trip. + if value.residency.get("hbm_addr") is None and not value.residency.get("hbm_ready"): + vram_name = value.residency.get("vram_name") or f"{value.value_tile_id}.vram" + vram_addr = self.allocate_value_tile_address( + size=self.program.tile_elems, + name=str(vram_name), + place="vram", + value_tile=value, + ) + value.residency["vram_addr"] = vram_addr + value.residency["vram_name"] = vram_name + self._value_tiles_in_vram[value.value_tile_id] = vram_addr + return value + self.ensure_value_tile_in_place(value, "hbm") + if value.residency.get("vram_addr") is None: + vram_name = value.residency.get("vram_name") or f"{value.value_tile_id}.vram" + vram_addr = self.allocate_value_tile_address( + size=self.program.tile_elems, + name=str(vram_name), + place="vram", + value_tile=value, + ) + value.residency["vram_addr"] = vram_addr + value.residency["vram_name"] = vram_name + self.move_tile(value, "hbm", "vram") + if value.residency.get("vram_addr") is not None: + self._value_tiles_in_vram[value.value_tile_id] = value.residency["vram_addr"] + self.program._record_operation_snapshot( + "value_residency", + stage="ensure", + target_place="vram", + value=self._value_debug_state(value), + ) + return value + if place == "mram": + if value.residency.get("mram_addr") is not None: + return value + if value.residency.get("mram_addr") is None: + mram_name = f"{value.value_tile_id}.mram" + mram_addr = self.allocate_value_tile_address( + name=mram_name, + size=self.program.tile_elems, + place="mram", + value_tile=value, + ) + value.residency["mram_addr"] = mram_addr + value.residency["mram_name"] = mram_name + self.ensure_value_tile_in_place(value, "hbm") + self.move_tile(value, "hbm", "mram") + self._value_tiles_in_mram[value.value_tile_id] = value.residency["mram_addr"] + return value + if place == "fpram": + if value.residency.get("fpram_ready"): + return value + if value.residency.get("vram_addr") is not None and value.metadata.get("fp_fragment_name") is not None: + self.move_tile(value, "vram", "fpram") + return value + raise RuntimeError( + f"Value tile {value.value_tile_id} is not fpram-backed; current implementation only " + "supports values created initially in fpram" + ) + if place == "hbm": + if value.residency.get("hbm_ready"): + self._value_tiles_in_hbm[value.value_tile_id] = True + return value + if value.residency.get("fpram_ready"): + self.ensure_value_tile_in_place(value, "vram") + self.move_tile(value, "vram", "hbm") + value.residency["hbm_ready"] = True + self._value_tiles_in_hbm[value.value_tile_id] = { + "addr": value.residency.get("hbm_addr"), + "name": value.residency.get("hbm_name"), + "offset": value.residency.get("hbm_offset"), + "stride": value.residency.get("hbm_stride"), + } + return value + if value.residency.get("vram_addr") is None: + if value.residency.get("hbm_addr") is not None: + value.residency["hbm_ready"] = True + self._value_tiles_in_hbm[value.value_tile_id] = { + "addr": value.residency.get("hbm_addr"), + "name": value.residency.get("hbm_name"), + "offset": value.residency.get("hbm_offset"), + "stride": value.residency.get("hbm_stride"), + } + return value + raise RuntimeError( + f"Value tile {value.value_tile_id} is neither in HBM nor VRAM; refusing to ensure HBM to avoid loops" + ) + if value.residency.get("hbm_addr") is None: + hbm_name = f"{value.value_tile_id}.hbm" + hbm_addr = self.allocate_value_tile_address( + size=self.program.tile_elems, + name=hbm_name, + place="hbm", + value_tile=value, + ) + value.residency["hbm_addr"] = hbm_addr + value.residency["hbm_name"] = hbm_name + value.residency["hbm_offset"] = 0 + value.residency["hbm_stride"] = self.program.mlen + self.move_tile(value, "vram", "hbm") + value.residency["hbm_ready"] = True + self._value_tiles_in_hbm[value.value_tile_id] = { + "addr": value.residency.get("hbm_addr"), + "name": value.residency.get("hbm_name"), + "offset": value.residency.get("hbm_offset"), + "stride": value.residency.get("hbm_stride"), + } + self.program._record_operation_snapshot( + "value_residency", + stage="ensure", + target_place="hbm", + value=self._value_debug_state(value), + ) + return value + raise ValueError(f"Unsupported place for ensure_value_tile_in_place: {place}") + + def move_tile(self, value: ValueTile, src_place: str, dst_place: str) -> None: + if src_place == "fpram" and dst_place == "vram": + fpram_addr = value.residency.get("fpram_addr") + vram_addr = value.residency.get("vram_addr") + fragment_shape = value.metadata.get("fp_fragment_shape") + if vram_addr is None: + raise RuntimeError( + f"move_tile fpram->vram requires vram_addr for {value.value_tile_id}" + ) + if not isinstance(fragment_shape, tuple): + raise RuntimeError( + f"fpram-backed value tile {value.value_tile_id} is missing fp_fragment_shape metadata" + ) + fragment = self._resolve_value_fp_fragment(value) + row_count, row_width = _fp_fragment_shape_to_tile_shape( + tuple(int(dim) for dim in fragment_shape), + mlen=self.program.mlen, + btmm_hlen=self.program.btmm_hlen, + ) + slow_rows = 0 + for row_index in range(int(row_count)): + row_fp_vars = _fp_fragment_row_fp_vars( + fragment, + row_index=row_index, + row_width=int(row_width), + btmm_hlen=self.program.btmm_hlen, + ) + row_addrs = [_require_fp_addr(fp_var) for fp_var in row_fp_vars] + row_prog = self.program._arith_progression(row_addrs) + row_vram_addr = int(vram_addr) + row_index * int(row_width) + if row_prog is not None and row_prog[1] == int(row_width) and row_prog[2] == 1: + self.isa_emitter.emit_map_v_fp_tile( + vram_addr=row_vram_addr, + fpram_addr=int(row_prog[0]), + row_count=1, + row_width=int(row_width), + task_id=f"fpram_to_vram.{value.value_tile_id}.row{row_index}", + ) + continue + + slow_rows += 1 + scratch_name, scratch_addr = self._temporary_fpram_row_scratch( + int(row_width), + value_tile_id=value.value_tile_id, + row_index=row_index, + ) + scratch_addrs = [scratch_addr + offset for offset in range(int(row_width))] + try: + self.isa_emitter.emit_fp_kernel( + src1_addrs=row_addrs, + dst_addrs=scratch_addrs, + op="copy", + task_id=f"fpram_row_gather.{value.value_tile_id}.row{row_index}", + ) + self.isa_emitter.emit_map_v_fp_tile( + vram_addr=row_vram_addr, + fpram_addr=int(scratch_addr), + row_count=1, + row_width=int(row_width), + task_id=f"fpram_to_vram.{value.value_tile_id}.row{row_index}.scratch", + ) + finally: + self.program.compiler.sub_matrix_manager.fpram_allocator.free(scratch_name, strict=False) + value.metadata["last_move"] = ("fpram", "vram") + value.residency.pop("fpram_addr", None) + value.residency.pop("fpram_name", None) + value.residency.pop("fpram_size", None) + value.residency.pop("fpram_ready", None) + value.residency.pop("hbm_addr", None) + value.residency.pop("hbm_name", None) + value.residency.pop("hbm_offset", None) + value.residency.pop("hbm_stride", None) + value.residency.pop("hbm_scale_size", None) + value.residency.pop("hbm_ready", None) + value.residency.pop("mram_addr", None) + value.residency.pop("mram_name", None) + self._value_tiles_in_hbm.pop(value.value_tile_id, None) + self._value_tiles_in_mram.pop(value.value_tile_id, None) + self._mram_fifo[:] = [item for item in self._mram_fifo if item != value.value_tile_id] + return + if src_place == "vram" and dst_place == "hbm": + vram_addr = value.residency.get("vram_addr") + hbm_params = self._hbm_base_offset_scale_for_value(value) + hbm_addr = hbm_params["hbm_addr"] + hbm_name = hbm_params["hbm_name"] + if vram_addr is None or hbm_addr is None or hbm_name is None: + raise RuntimeError( + f"move_tile vram->hbm requires vram_addr/hbm_addr/hbm_name for {value.value_tile_id}" + ) + self.isa_emitter.emit_store_tile_to_hbm( + vram_addr=int(vram_addr), + hbm_addr=int(hbm_params["hbm_base_addr"]), + hbm_stride=int(hbm_params["hbm_stride"]), + hbm_scale_size=int(hbm_params["hbm_scale_size"]), + hbm_start_offset=int(hbm_params["hbm_offset"]), + ) + value.metadata["last_move"] = ("vram", "hbm") + self.program._record_operation_snapshot( + "value_residency", + stage="move_tile", + src_place="vram", + dst_place="hbm", + hbm_params=dict(hbm_params), + value=self._value_debug_state(value), + ) + return + if src_place == "hbm" and dst_place == "vram": + hbm_params = self._hbm_base_offset_scale_for_value(value) + hbm_addr = hbm_params["hbm_addr"] + vram_addr = value.residency.get("vram_addr") + hbm_name = hbm_params["hbm_name"] + if hbm_addr is None or vram_addr is None or hbm_name is None: + raise RuntimeError(f"move_tile hbm->vram requires both hbm_addr and vram_addr for {value.value_tile_id}") + self.isa_emitter.emit_load_tile_from_hbm( + hbm_addr=int(hbm_params["hbm_base_addr"]), + vram_addr=int(vram_addr), + hbm_stride=int(hbm_params["hbm_stride"]), + hbm_scale_size=int(hbm_params["hbm_scale_size"]), + hbm_start_offset=int(hbm_params["hbm_offset"]), + ) + value.metadata["last_move"] = ("hbm", "vram") + self._drop_stale_non_vram_residency(value) + self.program._record_operation_snapshot( + "value_residency", + stage="move_tile", + src_place="hbm", + dst_place="vram", + hbm_params=dict(hbm_params), + value=self._value_debug_state(value), + ) + return + if src_place == "hbm" and dst_place == "mram": + hbm_params = self._hbm_base_offset_scale_for_value(value) + hbm_addr = hbm_params["hbm_addr"] + mram_addr = value.residency.get("mram_addr") + if hbm_addr is None or mram_addr is None: + raise RuntimeError(f"move_tile hbm->mram requires both hbm_addr and mram_addr for {value.value_tile_id}") + self.isa_emitter.emit_hbm_tile_to_mram( + hbm_addr=int(hbm_params["hbm_base_addr"]), + mram_addr=int(mram_addr), + hbm_offset=int(hbm_params["hbm_offset"]), + hbm_scale=int(hbm_params["hbm_scale_size"]), + hbm_stride=int(hbm_params["hbm_stride"]), + ) + value.metadata["last_move"] = ("hbm", "mram") + return + if src_place == "vram" and dst_place == "fpram": + vram_addr = value.residency.get("vram_addr") + if vram_addr is None: + raise RuntimeError( + f"move_tile vram->fpram requires vram_addr for {value.value_tile_id}" + ) + fragment = self._resolve_value_fp_fragment(value) + fragment_shape = tuple(int(dim) for dim in fragment.shape) + row_count, row_width = _fp_fragment_shape_to_tile_shape( + fragment_shape, + mlen=self.program.mlen, + btmm_hlen=self.program.btmm_hlen, + ) + for row_index in range(int(row_count)): + row_fp_vars = _fp_fragment_row_fp_vars( + fragment, + row_index=row_index, + row_width=int(row_width), + btmm_hlen=self.program.btmm_hlen, + ) + row_addrs = [_require_fp_addr(fp_var) for fp_var in row_fp_vars] + row_prog = self.program._arith_progression(row_addrs) + row_vram_addr = int(vram_addr) + row_index * int(row_width) + if row_prog is not None and row_prog[1] == int(row_width) and row_prog[2] == 1: + self.isa_emitter.emit_map_fp_v_tile( + fpram_addr=int(row_prog[0]), + vram_addr=row_vram_addr, + row_count=1, + row_width=int(row_width), + task_id=f"vram_to_fpram.{value.value_tile_id}.row{row_index}", + ) + continue + + scratch_name, scratch_addr = self._temporary_fpram_row_scratch( + int(row_width), + value_tile_id=value.value_tile_id, + row_index=row_index, + ) + scratch_addrs = [scratch_addr + offset for offset in range(int(row_width))] + try: + self.isa_emitter.emit_map_fp_v_tile( + fpram_addr=int(scratch_addr), + vram_addr=row_vram_addr, + row_count=1, + row_width=int(row_width), + task_id=f"vram_to_fpram.{value.value_tile_id}.row{row_index}.scratch", + ) + self.isa_emitter.emit_fp_kernel( + src1_addrs=scratch_addrs, + dst_addrs=row_addrs, + op="copy", + task_id=f"fpram_row_scatter.{value.value_tile_id}.row{row_index}", + ) + finally: + self.program.compiler.sub_matrix_manager.fpram_allocator.free(scratch_name, strict=False) + + fp_vars = [fragment.vars[index] for index in _iter_fp_indices(fragment_shape)] + fp_addrs = [_require_fp_addr(fp_var) for fp_var in fp_vars] + value.residency["fpram_name"] = fragment.name + value.residency["fpram_size"] = len(fp_addrs) + value.residency["fpram_ready"] = True + if fp_addrs: + value.residency["fpram_addr"] = int(fp_addrs[0]) + value.metadata["last_move"] = ("vram", "fpram") + return + raise ValueError(f"Unsupported move_tile path: {src_place} -> {dst_place}") + + def _hbm_base_offset_scale_for_value(self, value: ValueTile) -> Dict[str, object]: + explicit_hbm_name = value.residency.get("hbm_name") + explicit_hbm_addr = value.residency.get("hbm_addr") + explicit_hbm_offset = value.residency.get("hbm_offset") + explicit_hbm_stride = value.residency.get("hbm_stride") + if ( + explicit_hbm_name is not None + and explicit_hbm_addr is not None + and explicit_hbm_offset is not None + and explicit_hbm_stride is not None + ): + hbm_object = self.program.hardware.hbm_objects.get(str(explicit_hbm_name)) + if hbm_object is None: + raise RuntimeError( + f"Value tile {value.value_tile_id} references missing explicit HBM object {explicit_hbm_name}" + ) + hbm_shape = tuple(hbm_object.get("shape", (self.program.mlen, self.program.mlen))) + hbm_scale_size = int(value.residency.get("hbm_scale_size", int(hbm_shape[0]) * int(hbm_shape[1]))) + hbm_base_addr = int(hbm_object["base_addr"]) + return { + "hbm_name": str(explicit_hbm_name), + "hbm_addr": int(explicit_hbm_addr), + "hbm_base_addr": hbm_base_addr, + "hbm_offset": int(explicit_hbm_offset), + "hbm_stride": int(explicit_hbm_stride), + "hbm_scale_size": hbm_scale_size, + } + if value.from_input_tile and value.source_input_tile_id is not None: + input_tile = self.program.tensor_manager.input_tiles.get(value.source_input_tile_id) + if input_tile is not None: + input_obj = self.program.tensor_manager.inputs.get(input_tile.input_name) + hbm_name = ( + input_obj.metadata.get("hbm_group_obj", f"{input_tile.input_name}.hbm") + if input_obj is not None + else f"{input_tile.input_name}.hbm" + ) + logical_shape = tuple(input_tile.metadata.get("logical_shape", ())) + hbm_stride = _logical_shape_to_hbm_stride(logical_shape) + hbm_offset = _tile_coord_to_hbm_offset(input_tile.coord, logical_shape, self.program.mlen) + hbm_object = self.program.hardware.hbm_objects.get(str(hbm_name)) + if hbm_object is None: + raise RuntimeError( + f"Input-backed value tile {value.value_tile_id} references missing HBM object {hbm_name}" + ) + hbm_shape = tuple(hbm_object.get("shape", (self.program.mlen, self.program.mlen))) + hbm_scale_size = int(hbm_shape[0]) * int(hbm_shape[1]) + hbm_base_addr = int(hbm_object["base_addr"]) + hbm_addr = hbm_base_addr + int(hbm_offset) + value.residency["hbm_name"] = str(hbm_name) + value.residency["hbm_addr"] = hbm_addr + value.residency["hbm_offset"] = int(hbm_offset) + value.residency["hbm_stride"] = int(hbm_stride) + value.residency["hbm_scale_size"] = int(hbm_scale_size) + return { + "hbm_name": str(hbm_name), + "hbm_addr": hbm_addr, + "hbm_base_addr": hbm_base_addr, + "hbm_offset": int(hbm_offset), + "hbm_stride": int(hbm_stride), + "hbm_scale_size": int(hbm_scale_size), + } + + hbm_name = value.residency.get("hbm_name") + hbm_addr = value.residency.get("hbm_addr") + if hbm_name is None or hbm_addr is None: + raise RuntimeError(f"Value tile {value.value_tile_id} is missing HBM metadata") + hbm_object = self.program.hardware.hbm_objects.get(str(hbm_name)) + if hbm_object is None: + raise RuntimeError(f"Unknown HBM object for value tile {value.value_tile_id}: {hbm_name}") + hbm_base_addr = int(hbm_object["base_addr"]) + hbm_shape = tuple(hbm_object.get("shape", (self.program.mlen, self.program.mlen))) + explicit_hbm_scale_size = value.residency.get("hbm_scale_size") + hbm_scale_size = int(explicit_hbm_scale_size) if explicit_hbm_scale_size is not None else int(hbm_shape[0]) * int(hbm_shape[1]) + hbm_offset = int(value.residency.get("hbm_offset", int(hbm_addr) - hbm_base_addr)) + hbm_stride = int(value.residency.get("hbm_stride", self.program.mlen)) + if explicit_hbm_scale_size is None: + value.residency["hbm_scale_size"] = int(hbm_scale_size) + return { + "hbm_name": str(hbm_name), + "hbm_addr": int(hbm_addr), + "hbm_base_addr": hbm_base_addr, + "hbm_offset": int(hbm_offset), + "hbm_stride": int(hbm_stride), + "hbm_scale_size": int(hbm_scale_size), + } + + def allocate_value_tile_address( + self, + *, + size: int, + name: str, + place: str, + value_tile: Optional[ValueTile] = None, + hbm_name: Optional[str] = None, + hbm_offset: int = 0, + hbm_stride: Optional[int] = None, + ) -> int: + if place == "vram": + request_score = self._request_vram_protection_score(value=value_tile) + self._evict_fifo_if_needed("vram", required_tiles=1, max_score=request_score) + if value_tile is not None: + self._touch_fifo("vram", value_tile.value_tile_id) + addr = self.program.compiler.sub_matrix_manager.vram_allocator.allocate(size=size, name=name) + return addr + if place == "mram": + self._evict_fifo_if_needed("mram") + if value_tile is not None: + self._touch_fifo("mram", value_tile.value_tile_id) + addr = self.program.compiler.sub_matrix_manager.mram_allocator.allocate(name=name, size=size) + return addr + if place == "hbm": + resolved_name = hbm_name or name + if resolved_name not in self.program.hardware.hbm_objects: + base_addr = self.program.add_hbm_object( + resolved_name, + (self.program.mlen, self.program.mlen), + ) + else: + base_addr = self.program.hardware.hbm_objects[resolved_name]["base_addr"] + hbm_object = self.program.hardware.hbm_objects[resolved_name] + hbm_shape = tuple(hbm_object.get("shape", (self.program.mlen, self.program.mlen))) + hbm_scale_size = int(hbm_shape[0]) * int(hbm_shape[1]) + addr = base_addr + int(hbm_offset) + if value_tile is not None: + scale_size = int(value_tile.residency.get("hbm_scale_size", hbm_scale_size)) + self._value_tiles_in_hbm[value_tile.value_tile_id] = { + "addr": addr, + "name": resolved_name, + "offset": int(hbm_offset), + "stride": self.program.mlen if hbm_stride is None else int(hbm_stride), + "scale_size": scale_size, + } + if "hbm_scale_size" not in value_tile.residency: + value_tile.residency["hbm_scale_size"] = hbm_scale_size + return addr + raise ValueError(f"Unsupported place for allocate_value_tile_address: {place}") + + def _touch_fifo(self, place: str, value_tile_id: str) -> None: + if place == "vram": + if value_tile_id in self._value_tiles_in_vram: + addr = self._value_tiles_in_vram.pop(value_tile_id) + self._value_tiles_in_vram[value_tile_id] = addr + return + fifo = self._mram_fifo + fifo[:] = [item for item in fifo if item != value_tile_id] + fifo.append(value_tile_id) + + def _evict_fifo_if_needed( + self, + place: str, + *, + required_tiles: int = 1, + max_score: Optional[int] = None, + ) -> None: + if place == "vram": + capacity = getattr(self.program, "vram_tile_capacity", 0) + if capacity > 0: + while len(self._value_tiles_in_vram) + int(required_tiles) - 1 >= capacity: + self._evict_one_value_tile("vram", max_score=max_score) + return + if place == "mram": + capacity = getattr(self.program, "mram_tile_capacity", 0) + if capacity > 0 and len(self._value_tiles_in_mram) >= capacity: + self._evict_one_value_tile("mram") + return + + def _evict_one_value_tile(self, place: str, *, max_score: Optional[int] = None) -> None: + residency_table = self._value_tiles_in_vram if place == "vram" else self._value_tiles_in_mram + addr_key = "vram_addr" if place == "vram" else "mram_addr" + name_key = "vram_name" if place == "vram" else "mram_name" + allocator = ( + self.program.compiler.sub_matrix_manager.vram_allocator + if place == "vram" + else self.program.compiler.sub_matrix_manager.mram_allocator + ) + if place == "vram": + resident_ids = self._vram_eviction_candidates(max_score=max_score) + if not resident_ids: + detail = "" if max_score is None else f" with max_score={int(max_score)}" + raise RuntimeError( + f"{place.upper()} allocation requested but no resident value tile was available for FIFO eviction{detail}" + ) + skipped_protected = 0 + while resident_ids: + evict_id = resident_ids.pop(0) + if self._is_protected_value_tile(evict_id, "vram"): + addr = residency_table.pop(evict_id) + residency_table[evict_id] = addr + skipped_protected += 1 + if skipped_protected >= len(residency_table): + raise RuntimeError( + f"VRAM eviction stalled because all resident value tiles are currently protected" + ) + continue + evict_value = self.value_tiles.get(evict_id) + if evict_value is None: + raise RuntimeError( + f"{place.upper()} residency table references missing value tile {evict_id}; internal residency state is inconsistent" + ) + self.ensure_value_tile_in_place(evict_value, "hbm") + alloc_name = evict_value.residency.get(name_key) + if alloc_name is not None: + allocator.free(str(alloc_name), strict=False) + vram_addr_before = evict_value.residency.get(addr_key) + hbm_addr_after = evict_value.residency.get("hbm_addr") + evict_value.residency.pop(addr_key, None) + evict_value.residency.pop(name_key, None) + residency_table.pop(evict_id, None) + self.program.eviction_warnings.append({ + "op_index": len(self.program.operation_snapshots), + "value_tile_id": evict_id, + "tensor_refs": list(self.value_tile_tensor_refs.get(evict_id, set())), + "vram_addr": vram_addr_before, + "hbm_addr": hbm_addr_after, + "score": self._value_tile_vram_protection_score(evict_value), + "max_score_filter": max_score, + }) + return + raise RuntimeError(f"{place.upper()} allocation requested but no resident value tile was available for FIFO eviction") + + fifo = self._mram_fifo + while fifo: + evict_id = fifo.pop(0) + evict_value = self.value_tiles.get(evict_id) + if evict_value is None: + raise RuntimeError( + f"{place.upper()} FIFO references missing value tile {evict_id}; internal residency state is inconsistent" + ) + alloc_name = evict_value.residency.get(name_key) + if alloc_name is not None: + allocator.free(str(alloc_name), strict=False) + evict_value.residency.pop(addr_key, None) + evict_value.residency.pop(name_key, None) + residency_table.pop(evict_id, None) + return + raise RuntimeError(f"{place.upper()} allocation requested but no resident value tile was available for FIFO eviction") + + def mapv_back(self, signal: List[object]) -> Dict[str, object]: + compute_output, mapv_input = signal + dst_value = compute_output.get("dst") if isinstance(compute_output, dict) else None + if not isinstance(mapv_input, tuple) or not mapv_input: + raise RuntimeError("mapv_back expects one tuple mapv packet") + control = mapv_input[0] + if control == "copy": + if len(mapv_input) != 3: + raise RuntimeError("copy mapv_back expects ('copy', src_value, dst_tile)") + _, src_value, dst_tile = mapv_input + if not isinstance(src_value, ValueTile): + raise RuntimeError("copy mapv_back expects one source ValueTile") + if not isinstance(dst_tile, (TensorTile, InputTile)): + raise RuntimeError("copy mapv_back expects one destination tile") + if isinstance(dst_tile, InputTile): + self._write_value_back_to_input_tile(src_value, dst_tile) + else: + self._bind_value_to_tensor_tile(src_value, dst_tile) + return { + "mapped_values": compute_output, + "mapv_input": mapv_input, + "dst_tile_id": dst_tile.tile_id, + "dst_value_tile_id": src_value.value_tile_id, + "control": control, + } + + if len(mapv_input) != 4: + raise RuntimeError("matmul mapv_back expects ('matmul', src_pairs, dst_value, dst_tile)") + _, _, _, dst_tile = mapv_input + if not isinstance(dst_value, ValueTile): + raise RuntimeError("mapv_back expects compute output to contain one destination ValueTile") + if not isinstance(dst_tile, (TensorTile, InputTile)): + raise RuntimeError("mapv_back expects mapv input to contain one destination tile") + if isinstance(dst_tile, InputTile): + self._write_value_back_to_input_tile(dst_value, dst_tile) + else: + self._bind_value_to_tensor_tile(dst_value, dst_tile) + return { + "mapped_values": compute_output, + "mapv_input": mapv_input, + "dst_tile_id": dst_tile.tile_id, + "dst_value_tile_id": dst_value.value_tile_id, + "control": control, + } + + def _write_value_back_to_input_tile(self, value: ValueTile, dst_tile: InputTile) -> None: + original_value = value + input_obj = self.program.tensor_manager.inputs.get(dst_tile.input_name) + if input_obj is None: + raise RuntimeError(f"Unknown input owner for input tile {dst_tile.tile_id}: {dst_tile.input_name}") + hbm_name = input_obj.metadata.get("hbm_group_obj", f"{dst_tile.input_name}.hbm") + logical_shape = tuple(dst_tile.metadata.get("logical_shape", ())) + hbm_stride = _logical_shape_to_hbm_stride(logical_shape) + hbm_offset = _tile_coord_to_hbm_offset(dst_tile.coord, logical_shape, self.program.mlen) + hbm_object = self.program.hardware.hbm_objects.get(str(hbm_name)) + if hbm_object is None: + raise RuntimeError(f"Unknown HBM object for input writeback: {hbm_name}") + hbm_shape = tuple(hbm_object.get("shape", (self.program.mlen, self.program.mlen))) + hbm_addr = int(hbm_object["base_addr"]) + int(hbm_offset) + + prev_hbm_name = value.residency.get("hbm_name") + prev_hbm_addr = value.residency.get("hbm_addr") + prev_hbm_offset = value.residency.get("hbm_offset") + prev_hbm_stride = value.residency.get("hbm_stride") + target_changed = ( + prev_hbm_name != str(hbm_name) + or prev_hbm_addr != hbm_addr + or prev_hbm_offset != hbm_offset + or prev_hbm_stride != hbm_stride + ) + + # Preserve the current value contents before retargeting its HBM identity + # to the destination input/output object. Otherwise a non-VRAM resident + # value could be reloaded from the destination HBM slot instead of its + # original backing. + self.ensure_value_tile_in_place(value, "vram") + writeback_value = value + shared_tensor_refs = bool(self.value_tile_tensor_refs.get(value.value_tile_id)) + if shared_tensor_refs and target_changed: + old_vram_addr = value.residency.pop("vram_addr", None) + old_vram_name = value.residency.pop("vram_name", None) + if old_vram_addr is None: + raise RuntimeError( + f"shared writeback split requires VRAM residency, got {value.value_tile_id}" + ) + writeback_value = ValueTile( + value_tile_id=self._next_value_tile_id(), + logical_shape=value.logical_shape, + metadata=dict(value.metadata), + ) + writeback_value.residency["vram_addr"] = int(old_vram_addr) + if old_vram_name is not None: + writeback_value.residency["vram_name"] = old_vram_name + self.value_tiles[writeback_value.value_tile_id] = writeback_value + self._value_tiles_in_vram.pop(value.value_tile_id, None) + self._value_tiles_in_vram[writeback_value.value_tile_id] = int(old_vram_addr) + + writeback_value.residency["hbm_addr"] = hbm_addr + writeback_value.residency["hbm_name"] = str(hbm_name) + writeback_value.residency["hbm_offset"] = hbm_offset + writeback_value.residency["hbm_stride"] = hbm_stride + writeback_value.residency["hbm_scale_size"] = int(hbm_shape[0]) * int(hbm_shape[1]) + if target_changed: + # A value may already be "hbm_ready" in a temporary spill object. + # Final output writeback must retarget and actually store into the + # destination input/output HBM object instead of early-returning. + writeback_value.residency["hbm_ready"] = False + self.move_tile(writeback_value, "vram", "hbm") + writeback_value.residency["hbm_ready"] = True + self._value_tiles_in_hbm[writeback_value.value_tile_id] = { + "addr": writeback_value.residency.get("hbm_addr"), + "name": writeback_value.residency.get("hbm_name"), + "offset": writeback_value.residency.get("hbm_offset"), + "stride": writeback_value.residency.get("hbm_stride"), + "scale_size": writeback_value.residency.get("hbm_scale_size"), + } + if self._is_narrow_tensor_tile(dst_tile): + self._rebind_view_group_value(dst_tile, writeback_value) + else: + self._bind_tile_pointer(dst_tile.tile_id, writeback_value.value_tile_id) + writeback_value.metadata["input_writeback_tile_id"] = dst_tile.tile_id + writeback_value.metadata["input_writeback_name"] = dst_tile.input_name + self.program._record_operation_snapshot( + "value_writeback", + src_value=self._value_debug_state(original_value), + writeback_value=self._value_debug_state(writeback_value), + dst_tile=self._tile_debug_state(dst_tile), + target_hbm={ + "hbm_name": str(hbm_name), + "hbm_addr": hbm_addr, + "hbm_offset": hbm_offset, + "hbm_stride": hbm_stride, + "hbm_scale_size": int(hbm_shape[0]) * int(hbm_shape[1]), + }, + target_changed=target_changed, + shared_tensor_refs=shared_tensor_refs, + ) + self._release_vram_if_only_input_refs(writeback_value.value_tile_id) + + def _detach_input_backing_identity(self, value: ValueTile) -> None: + if not value.from_input_tile and value.source_input_tile_id is None: + return + # Keep the explicit HBM residency fields intact, but stop treating this + # value as one logical alias of its original input tile in later fallback + # HBM reconstruction paths. + value.from_input_tile = False + value.source_input_tile_id = None + + def _bind_value_to_tensor_tile(self, value: ValueTile, dst_tile: TensorTile) -> None: + canonical_tile = self._resolve_alias_owner_tile(dst_tile) + if isinstance(canonical_tile, TensorTile) and canonical_tile is not dst_tile and not self._is_narrow_tensor_tile(dst_tile): + dst_tile = canonical_tile + self._detach_input_backing_identity(value) + if self._is_narrow_tensor_tile(dst_tile): + self._rebind_view_group_value(dst_tile, value) + return + self._bind_tile_pointer(dst_tile.tile_id, value.value_tile_id) + + def _bind_tile_pointer(self, tile_id: str, value_tile_id: str) -> None: + old_value_tile_id = self.full_tile_bindings.get(tile_id) + if old_value_tile_id == value_tile_id: + self.value_tile_tensor_refs.setdefault(value_tile_id, set()).add(tile_id) + return + if old_value_tile_id is not None: + detached_old_value_tile_id = self._detach_tile_value_pointer(tile_id) + self._attach_tile_value_pointer(tile_id, value_tile_id) + if detached_old_value_tile_id is not None: + self.free_value_tile(detached_old_value_tile_id) + return + self._attach_tile_value_pointer(tile_id, value_tile_id) + + def _attach_tile_value_pointer(self, tile_id: str, value_tile_id: str) -> None: + self.full_tile_bindings[tile_id] = value_tile_id + self.value_tile_tensor_refs.setdefault(value_tile_id, set()).add(tile_id) + + def _release_unreferenced_value_tile(self, value_tile_id: str) -> None: + if self.value_tile_tensor_refs.get(value_tile_id): + return + if self._is_input_backed_value_tile(value_tile_id): + self._free_value_tile_vram_residency(value_tile_id) + return + self.free_value_tile(value_tile_id) + + def _release_vram_if_only_input_refs(self, value_tile_id: str) -> bool: + value = self.value_tiles.get(value_tile_id) + if value is None: + return False + refs = self.value_tile_tensor_refs.get(value_tile_id, set()) + if not refs: + return False + if any(ref_id not in self.program.tensor_manager.input_tiles for ref_id in refs): + return False + if not bool(value.residency.get("hbm_ready")): + return False + return self._free_value_tile_vram_residency(value_tile_id) + + def _detach_tile_value_pointer(self, tile_id: str, *, auto_free: bool = True) -> Optional[str]: + old_value_tile_id = self.full_tile_bindings.pop(tile_id, None) + if old_value_tile_id is None: + return None + old_refs = self.value_tile_tensor_refs.get(old_value_tile_id) + if old_refs is not None: + old_refs.discard(tile_id) + if not old_refs: + self.value_tile_tensor_refs.pop(old_value_tile_id, None) + if auto_free: + self._release_unreferenced_value_tile(old_value_tile_id) + self._release_vram_if_only_input_refs(old_value_tile_id) + return old_value_tile_id + + def _unbind_tile_value_pointer(self, tile_id: str) -> None: + old_value_tile_id = self._detach_tile_value_pointer(tile_id) + if old_value_tile_id is None: + return + self.free_value_tile(old_value_tile_id) + + def _is_input_backed_value_tile(self, value_tile_id: str) -> bool: + value = self.value_tiles.get(value_tile_id) + if value is not None and (value.from_input_tile or value.source_input_tile_id is not None): + return True + return any(ref_id in self.program.tensor_manager.input_tiles for ref_id in self.value_tile_tensor_refs.get(value_tile_id, set())) + + def _free_value_tile_vram_residency(self, value_tile_id: str) -> bool: + value = self.value_tiles.get(value_tile_id) + if value is None or self._is_protected_value_tile(value_tile_id, "vram"): + return False + vram_name = value.residency.pop("vram_name", None) + value.residency.pop("vram_addr", None) + self._value_tiles_in_vram.pop(value_tile_id, None) + if vram_name is None: + return False + has_other_live_owner = any( + other_id != value_tile_id and other.residency.get("vram_name") == vram_name + for other_id, other in self.value_tiles.items() + ) + if not has_other_live_owner: + self.program.compiler.sub_matrix_manager.vram_allocator.free(str(vram_name), strict=False) + return True + + def _non_input_value_refs(self, value_tile_id: str) -> List[str]: + return sorted( + ref_id + for ref_id in self.value_tile_tensor_refs.get(value_tile_id, set()) + if ref_id not in self.program.tensor_manager.input_tiles + ) + + def free_tensor_tile(self, tile: TensorTile, *, weak: Optional[bool] = None) -> Optional[str]: + if isinstance(tile, VectorTile): + raise TypeError("free_tensor_tile only supports TensorTile; VectorTile uses FPFragment backing") + value_tile_id = self.full_tile_bindings.get(tile.tile_id) + if value_tile_id is None: + return None + if weak: + self._detach_tile_value_pointer(tile.tile_id) + self.program._record_operation_snapshot( + "free_tensor_tile", + mode="weak", + tile=self._tile_debug_state(tile), + value_tile_id=value_tile_id, + ) + return value_tile_id + + if weak is None: + detached_tile_ids = [tile.tile_id] + self._detach_tile_value_pointer(tile.tile_id) + released_vram = False + if not self._non_input_value_refs(value_tile_id): + if self._is_input_backed_value_tile(value_tile_id): + released_vram = value_tile_id not in self._value_tiles_in_vram + else: + released_vram = value_tile_id not in self.value_tiles + self.program._record_operation_snapshot( + "free_tensor_tile", + mode="auto", + tile=self._tile_debug_state(tile), + value_tile_id=value_tile_id, + detached_tile_ids=detached_tile_ids, + released_vram=released_vram, + ) + return value_tile_id + + ref_tile_ids = sorted(self.value_tile_tensor_refs.get(value_tile_id, set())) + detach_tile_ids = ref_tile_ids + input_backed = self._is_input_backed_value_tile(value_tile_id) + if input_backed: + detach_tile_ids = [ + ref_tile_id + for ref_tile_id in ref_tile_ids + if ref_tile_id not in self.program.tensor_manager.input_tiles + ] + for ref_tile_id in detach_tile_ids: + self._detach_tile_value_pointer(ref_tile_id) + self.narrow_group_bindings = { + group_key: bound_value_tile_id + for group_key, bound_value_tile_id in self.narrow_group_bindings.items() + if bound_value_tile_id != value_tile_id + } + released_vram = False + if input_backed: + released_vram = self._free_value_tile_vram_residency(value_tile_id) + else: + self.free_value_tile(value_tile_id) + released_vram = True + self.program._record_operation_snapshot( + "free_tensor_tile", + mode="strong", + tile=self._tile_debug_state(tile), + value_tile_id=value_tile_id, + detached_tile_ids=detach_tile_ids, + preserved_input_tile_ids=[ref_id for ref_id in ref_tile_ids if ref_id not in detach_tile_ids], + released_vram=released_vram, + ) + return value_tile_id + + def free_value_tile(self, value_tile_id: str) -> None: + value = self.value_tiles.get(value_tile_id) + if value is None: + return + if self.value_tile_tensor_refs.get(value_tile_id): + return + if self._is_protected_value_tile(value_tile_id, "vram"): + return + vram_name = value.residency.pop("vram_name", None) + if vram_name is not None: + has_other_live_owner = any( + other_id != value_tile_id and other.residency.get("vram_name") == vram_name + for other_id, other in self.value_tiles.items() + ) + if not has_other_live_owner: + self.program.compiler.sub_matrix_manager.vram_allocator.free(str(vram_name), strict=False) + mram_name = value.residency.pop("mram_name", None) + if mram_name is not None: + self.program.compiler.sub_matrix_manager.mram_allocator.free(str(mram_name), strict=False) + value.residency.pop("vram_addr", None) + value.residency.pop("mram_addr", None) + self._value_tiles_in_vram.pop(value_tile_id, None) + self._value_tiles_in_mram.pop(value_tile_id, None) + self._value_tiles_in_hbm.pop(value_tile_id, None) + self._mram_fifo[:] = [item for item in self._mram_fifo if item != value_tile_id] + self.narrow_group_bindings = { + group_key: bound_value_tile_id + for group_key, bound_value_tile_id in self.narrow_group_bindings.items() + if bound_value_tile_id != value_tile_id + } + self.value_tiles.pop(value_tile_id, None) diff --git a/tilelang_runtime_compier/tile_tensor_program/_vector_manager.py b/tilelang_runtime_compier/tile_tensor_program/_vector_manager.py new file mode 100644 index 0000000..ae9a4a2 --- /dev/null +++ b/tilelang_runtime_compier/tile_tensor_program/_vector_manager.py @@ -0,0 +1,120 @@ +"""VectorManager: vector objects, vector tiles, and FP-fragment backing.""" + +from __future__ import annotations + +from math import ceil +from typing import Dict, List, Optional + +from ._types import * # noqa: F401,F403 +from ._helpers import * # noqa: F401,F403 + + +class VectorManager: + """Own vector-specific logical/runtime behavior. + + Today vectors are FP-fragment-backed rather than ValueTile-backed. This + manager centralizes: + + - logical `Vector` creation and registration + - `VectorTile` creation + - `VectorTile -> FPFragment` binding and lookup + - eager vector backing initialization + - vector FP-var resolution helpers used by `mapf` + """ + + def __init__(self, program: "TileTensorProgram") -> None: + self.program = program + self.fp_fragment_bindings: Dict[str, str] = {} + + def vector(self, name: str, logical_shape: LogicalShape) -> Vector: + if len(logical_shape) != 3: + raise ValueError(f"vector expects one 3D logical shape, got {logical_shape}") + vector = Vector(program=self.program, name=name, logical_shape=logical_shape) + self.program.tensor_manager.vectors[name] = vector + return vector + + def create_vector_tiles(self, vector_name: str, logical_shape: LogicalShape) -> Dict[TileCoord, VectorTile]: + rows, cols = _logical_shape_to_physical_shape(logical_shape) + row_blocks = ceil(rows / self.program.mlen) + col_blocks = ceil(cols / self.program.mlen) + tiles: Dict[TileCoord, VectorTile] = {} + tensor_manager = self.program.tensor_manager + for row_block in range(row_blocks): + for col_block in range(col_blocks): + row_count = min(self.program.mlen, rows - row_block * self.program.mlen) + col_count = min(self.program.mlen, cols - col_block * self.program.mlen) + vector_tile = VectorTile( + tile_id=tensor_manager._next_tensor_tile_id(), + tensor_name=vector_name, + coord=(row_block, col_block), + tile_shape=(row_count, col_count), + metadata=tensor_manager._build_tile_metadata( + logical_shape, + row_block, + col_block, + row_count, + col_count, + ), + ) + tiles[(row_block, col_block)] = vector_tile + tensor_manager.vector_tiles[vector_tile.tile_id] = vector_tile + tensor_manager.tensor_tiles[vector_tile.tile_id] = vector_tile + return tiles + + def bind_tile_to_fp_fragment(self, tile: VectorTile, fragment: FPFragment) -> FPFragment: + self.fp_fragment_bindings[tile.tile_id] = fragment.name + return fragment + + def resolve_fp_fragment(self, tile: VectorTile) -> FPFragment: + fragment_name = self.fp_fragment_bindings.get(tile.tile_id) + if not isinstance(fragment_name, str): + raise RuntimeError(f"VectorTile {tile.tile_id} is not bound to one FPFragment") + fragment = self.program.tensor_manager.fp_fragments.get(fragment_name) + if not isinstance(fragment, FPFragment): + raise RuntimeError( + f"VectorTile {tile.tile_id} binding points to missing FPFragment {fragment_name!r}" + ) + return fragment + + def initialize_vector_backing(self, vector: Vector, *, init_zero: bool = False) -> None: + for tile in _tiles_in_grid_order(vector.tiles): + if self.fp_fragment_bindings.get(tile.tile_id): + continue + fragment_name = self.program._auto_name(f"{vector.name}.fp_tile") + fragment = self.program.tensor_manager.fp_fragment( + name=fragment_name, + shape=tile.tile_shape, + init=0.0, + ) + self.bind_tile_to_fp_fragment(tile, fragment) + tile.metadata["fp_fragment_name"] = fragment.name + if init_zero: + self.program.fp_fill(fragment, 0.0) + + def resolve_vector_fp_vars(self, vector: Vector) -> List[FPVar]: + resolved: List[FPVar] = [] + for logical_index in _iter_logical_indices(vector.logical_shape): + resolved.append(self.program.tensor_manager._resolve_element_fpvar(ElementRef(base=vector, indices=logical_index))) + return resolved + + def resolve_vector_slice_fp_vars(self, vector_slice: VectorSlice) -> List[FPVar]: + resolved: List[FPVar] = [] + for logical_index in _iter_selected_logical_indices(vector_slice.base.logical_shape, vector_slice.selectors): + resolved.append( + self.program.tensor_manager._resolve_element_fpvar( + ElementRef(base=vector_slice.base, indices=logical_index) + ) + ) + return resolved + + def resolve_vector_tile_fp_vars(self, tile: VectorTile) -> List[FPVar]: + fragment = self.resolve_fp_fragment(tile) + row_groups = _vector_tile_row_fp_groups( + src_tile=tile, + fragment=fragment, + mlen=self.program.mlen, + btmm_hlen=self.program.btmm_hlen, + src_slice_ranges=None, + ) + return [fp_var for row in row_groups for fp_var in row] + diff --git a/tilelang_tvm_compiler/MIGRATION_PLAN.md b/tilelang_tvm_compiler/MIGRATION_PLAN.md new file mode 100644 index 0000000..015b7f3 --- /dev/null +++ b/tilelang_tvm_compiler/MIGRATION_PLAN.md @@ -0,0 +1,371 @@ +# Migration Plan: All-Graph-Layer Frontend + +This document captured the target architecture for the frontend after +fully migrating from the original "stmt-walker chain + thin graph layer" +to "minimal stmt prep + full graph layer". Each section below describes +the target form, the migration steps, and the rationale for the design +choices. + +**Status: Phases A / B / C.1 / C.2 complete.** The all-graph-layer +pipeline is the only path; legacy stmt-walker passes and +`frontend_legacy/` have been deleted. See `PIPELINE_ARCHITECTURE.md` +for the architecture as it stands. **Phase D (loop scheduling — DMA +merge, prefetch / double-buffer) is the next milestone, not started**; +it depends on hardware capability data that is still TBD (see § Open +questions). + +> **Reading guide.** This document is now mostly retrospective — the +> "current pipeline" / "stmt walker" passages below describe the +> pre-migration state for context. For what the pipeline *currently* +> does, read `PIPELINE_ARCHITECTURE.md`. The Phase-D plan and the +> still-relevant Open Questions at the bottom are the only forward- +> looking parts. + +--- + +## Why we're doing this + +Current frontend is a chain of stmt-walker passes that communicate via +`T.attr(0, "plena.*", ...)` AttrStmts. Each pass re-walks the IR and +mutates it. Adding a new analysis means another walker; adding a new +fusion rule means coordinating attr ordering; adding new ops requires +multiple touchpoints scattered across passes. + +In the graph layer, each op is a `GraphNode` with `attrs` — passes read +and write attrs directly. `reads` / `writes` are filled at lift time +and live on the node, so any pass can do data-flow analysis without +re-walking stmt trees. New analyses are pure functions on `Graph`. New +fusion rules are pattern match + node replace. + +Architectural insight (key): **buffer layout decisions belong AT the +end of graph optimization, not as a separate stmt pass run before it.** +If `allocate_group_memory` runs before the graph layer, every graph +optimization that wants to change buffer shape (double-buffering for +prefetch, eliminating dead temps, etc) is locked out. The plan moves +buffer-shape allocation into `materialize`, where it has full +visibility of the post-optimization graph. + +--- + +## Target pipeline shape + +``` + [STMT STAGE — minimal] +@T.prim_func + │ + ▼ +inline_let_stmts # TIR housekeeping (LetStmt → subst) +lower_compound_fp_stores # arr[i]=a*b+c*d → temp → temp → out + │ + ▼ + [LIFT TO GRAPH] +lift_from_raw_primfunc # raw PrimFunc → Graph + # (one shot — no longer post-stmt-walker + # preprocessing required) + │ + ▼ + [GRAPH STAGE — analysis / annotation, no side effects] +graph.annotate_grid # grid bindings → ForNode attrs + # (replaces stmt-walker annotate_group) +graph.annotate_sync # ATTR_IS_SYNC on each GraphNode + # (current stub already runs) +graph.scope_inference # BufferNode.physical_scope + # (current implementation already exists, + # not wired) +graph.annotate_gemm_kind # ATTR_GEMM_KIND on gemm nodes + # (already done by lift_from_raw via + # KIND AttrStmt absorption) + │ + ▼ + [GRAPH STAGE — pattern fusion / canonicalisation] +graph.fuse_elementwise # for+BufferStore patterns → plena.v_* + # (replaces stmt fuse_elementwise) +graph.lower_fp_row_patterns # FPRAM row-element idioms → plena.fp_*_at + # (replaces stmt lower_fp_row_patterns) + │ + ▼ + [GRAPH STAGE — schedule transforms] +graph.split_lane_groups # head_count > lane_count axis split + # (replaces stmt split_lane_groups; now + # operates on ForNode + lane-attribute + # rather than for-rewriting + var subst) + │ + ▼ (future) + [GRAPH STAGE — real fusion (Phase D)] +graph.dma_merge # cross-iteration DMA combining + # (depends on HW capability data: + # max single-DMA size, etc) +graph.prefetch # double-buffer K/V across kv_block + # (depends on HW: NO async DMA in PLENA, + # so this would be reorder-only, not + # overlap) + │ + ▼ + [MATERIALIZE] +graph.materialize: + 1. Resolve each BufferNode's final physical layout: + - apply ATTR_LANE_LAYOUT (col_pack / row_stack / fp_lane) → + expand shape with lane dimension; + - check `global.*` scopes (no lane expansion); + - account for any layout decisions from real fusion passes + (e.g. double-buffer doubles the lane-axis extent). + 2. Build new tir.Buffer objects with finalized shapes. + 3. Walk graph items, lower each GraphNode to a tir.Stmt + (delegates to existing lower_to_hlir helpers _lower_copy / + _lower_gemm; their input is per-op call info, no IR walking). + 4. Wrap per-lane runs in `for(lane_var)` (the current + graph_pipeline._partition_and_materialize logic moves here + intact — it's already operating on the graph, no rewrite needed). + 5. Emit final tir.PrimFunc. + + │ + ▼ + [BACKEND — unchanged from today] +PlenaCodegen.lower_to_hlir() (TIR → HLIR) +AddressAllocationPass +IsaEmitterPass + │ + ▼ +.asm +``` + +--- + +## graph_ir extensions needed + +**Already done (Phase A)**: +* `BufferNode(name, shape, dtype, declared_scope, physical_scope, data_var, attrs)` +* `ForNode(loop_var, min, extent, kind, thread_binding, body_items, attrs)` +* Attr keys: `ATTR_IS_SYNC`, `ATTR_GEMM_KIND`, `ATTR_GROUP_EXTENT`, + `ATTR_IS_LANE_FOR`, `ATTR_LANE_LAYOUT`, `LAYOUT_COL_PACK`, + `LAYOUT_ROW_STACK`, `LAYOUT_FP_LANE`. + +**To add (Phase C.1)**: +* `Graph.buffer_nodes: dict[str, BufferNode]` — every alloc'd buffer + AND every param buffer becomes a BufferNode. GraphNode.reads/writes + reference these by name (not by tir.BufferRegion directly), so layout + changes propagate without rewriting reads/writes. +* `GraphNode.reads/writes` semantics shift: today `tir.BufferRegion` + carrying a `tir.Buffer` reference. Tomorrow: `BufferAccess(buffer_name, + starts, extents)` referencing `Graph.buffer_nodes[buffer_name]`. + +**Open**: do we keep the current `LaneGroup` / `NestedForGroup` / +`NodeRoot` / `ForRoot` distinction, or unify into a single recursive +`ForNode + items` representation? Probably unify — once +`graph.split_lane_groups` operates on ForNode directly, the +distinction adds little value. + +--- + +## Per-pass migration notes + +### inline_let_stmts (stmt stage, unchanged) +Pure TIR housekeeping. No reason to move. + +### lower_compound_fp_stores (stmt stage, unchanged for now) +Decomposes compound FP store RHS into sequential single-op stores plus +auto-allocated `__tmp_fp_*` buffers. Operates on `tir.BufferStore` +trees with `.value` arithmetic — best done while we still have +expression-level IR, before lift collapses ops into op-call form. +Future improvement: this could happen in graph layer too, but the +benefit is minor. + +### annotate_gemm_kind (DONE — removed from pipeline) +User writes `with T.attr(0, KIND, "btmm"): T.gemm(...)`; the AttrStmt +sits in raw IR. `lift_from_raw_primfunc._items_from_stmt` peels KIND +AttrStmts and writes `ATTR_GEMM_KIND` directly on the gemm GraphNode. +No-KIND gemm sites default to `"overwrite"` at lower time (in +`graph_pipeline._lower_node`). + +### annotate_group → graph.annotate_grid +Stmt walker today identifies grid bindings (`thread_extent` AttrStmts ++ `T.Parallel` for-loops) and wraps them in +`for v: T.attr(0, "plena.group", N): ...`. + +In the graph layer this becomes: walk every `ForNode`, set +`ATTR_GROUP_EXTENT` based on whether the for came from a grid binding +or a `T.Parallel` (lift_from_raw already records this). No IR rewrite +— just attr setting. + +`lift_from_raw` currently produces `ForRoot` for grid bindings. Phase +C.1 should record grid-axis extent on the `ForRoot` directly, and +graph passes consume that. + +### annotate_sync → graph.annotate_sync (DONE — `graph_passes/annotate_sync.py`) +Already exists. Sets `ATTR_IS_SYNC` on each GraphNode by inspecting +op kind, op-call args, ATTR_GEMM_KIND. Currently runs alongside the +stmt-walker version (which is still required because +`split_lane_groups` reads stmt-walker's `plena.sync` AttrStmts). + +In Phase C, after `split_lane_groups` migrates, the stmt-walker +version goes away; `graph.annotate_sync` is the only path. + +### split_lane_groups → graph.split_lane_groups +Stmt walker today: when a grid axis has extent > lane_count and is +sync-eligible, splits the for into outer × inner with var subst. + +In graph layer: +1. Find ForRoots/ForNodes whose extent > lane_count and whose body + contains a sync GraphNode (via ATTR_IS_SYNC, walking the items list). +2. Replace that ForNode with two new ForNodes: outer (extent = + original / lane_count) wrapping inner (extent = lane_count, with + `ATTR_IS_LANE_FOR=True`). +3. Walk the body items, substituting every reference to the original + loop_var with `outer_var * lane_count + inner_var`. This walks + `op_call.args` (region starts) and any RawStmt expressions. + +The var subst step is the hardest part — affects every reads/writes +region and every op_call argument. Solution: a graph-wide expr +rewriter, similar to `_VarSubst` in the stmt walker. + +### fuse_elementwise → graph.fuse_elementwise +Stmt walker today: matches `for i: AttrStmt(plena.group): BufferStore` +patterns and rewrites the whole for to a single `plena.v_*` extern +call. + +In graph layer, the equivalent pattern is `NestedForGroup(items=[ +RawStmt(BufferStore)])` (because lift_from_raw wraps BufferStores in +RawStmt) inside a LaneGroup or another NestedForGroup with matching +extent. Pass walks items, finds the pattern, replaces the whole +NestedForGroup with a single GraphNode(plena.v_*). + +The "nested fold" rule (outer T.serial(R) wrapping a fuse target → +single whole-buffer op) becomes: when the outer NestedForGroup's body +is a single fused GraphNode with whole-buffer semantics, drop the +outer for entirely. + +### lower_fp_row_patterns → graph.lower_fp_row_patterns +Largely the same as fuse_elementwise but with more pattern variants +(plena.fp_copy_at / fp_add_at / fp_exp_at / row_reduce_max_at / etc). +Uses BufferNode.physical_scope to determine FPRAM-residency of operands. + +### scope_inference → graph.scope_inference (DONE — `graph_passes/scope_inference.py`) +Already exists, equivalent to stmt-walker version. Will start being +used once allocate_group_memory and the row-pattern lower also live in +the graph layer (so the dict is consumed by graph code, not stmt code). + +### allocate_group_memory → INTEGRATED INTO MATERIALIZE +**Key architectural shift**: `allocate_group_memory` is no longer a +pass that runs in the middle of the pipeline. Its work — assigning +each buffer a lane layout (col_pack / row_stack / fp_lane) and +expanding shape — happens in `materialize`, after all graph +optimization is done. + +Why: graph optimizations may want to change buffer shape (e.g. a +double-buffering pass doubles the lane-axis extent for K_sh; a +dead-temp-elim pass removes a temp entirely). If shape is baked in +before optimization, those transforms are blocked. Move shape decisions +to materialize, after optimization stabilizes. + +Mechanism in materialize: +1. Walk graph nodes; for each op, infer the lane-layout role of each + operand (from op kind + ATTR_GEMM_KIND, same rules as today's + stmt-walker). Set `ATTR_LANE_LAYOUT` on each BufferNode. +2. Apply layout to BufferNode.shape: col_pack → + `(..., orig_last) → (1, ..., lane_count, orig_last)`; row_stack → + `(orig_first, ...) → (1, lane_count, orig_first, ...)`; fp_lane → + `(N,) → (lane_count, N)`. +3. Build tir.Buffer objects from final BufferNode state. +4. Lowering each op call uses the final shapes for offset computation + (existing `_auto_lane_offset` / `_dst_row_stride` logic in + lower_to_hlir applies, just driven by BufferNode instead of + tir.Buffer). + +### lower_to_hlir helpers (kept as op-level lowering library) +`_lower_copy` / `_lower_gemm` / `_rewrite_buffer_scopes` and their +expr / region helpers are pure op-level translation — they take a Call +and emit a tir.Call for plena.* extern. They don't need to migrate; +materialize calls them per node. Already wired this way today. + +--- + +## Verification strategy + +Single source of truth for correctness: HLIR diff against `backup_20260509`. + +For each migration step in Phase C: +1. Run all 7 working kernels through both old and new pipelines. +2. Compare HLIR ops + buffer tables. +3. Byte-identical → step lands. +4. Diverges → fix before merging. + +Unit tests (`tests/`) provide a faster signal but cover less. They are +necessary but not sufficient — HLIR diff is the real verifier. + +`/tmp/hlir_diff_tool.py` and `/tmp/hlir_diff_e2e_args.py` are existing +scripts that do the diff. Keep them through migration. + +--- + +## Phases + +* **A** ✅ — graph_ir extended, lift_from_raw exists, walker helpers exist. +* **B** ✅ — annotate_gemm_kind absorbed by lift; graph scope_inference and + graph annotate_sync written and used. +* **C.1** ✅ — graph passes for annotate_grid / split_lane_groups / + lift_lane_groups / fuse_elementwise / lower_fp_row_patterns written. + allocate_group_memory split into `analyze` (graph pass) + + `expand_buffers.expand` (run inside materialize). Pipeline rewired + to graph-only path with a temporary fallback flag. +* **C.2** ✅ — fallback flag and legacy stmt-walker pipeline deleted. + All 9 stmt-walker passes (annotate_group, annotate_sync, + annotate_gemm_kind, split_lane_groups, fuse_elementwise, + scope_inference, allocate_group_memory, lower_fp_row_patterns) + + lift_to_blocks + lift_to_graph removed. `frontend_legacy/` + directory removed. 6 stmt-walker test files removed. +* **D** ⏳ — loop-scheduling layer (NOT STARTED). Real new fusion: + * **DMA merge** across iterations: combine consecutive K/V DMAs into + one bigger transfer. + * **Prefetch / double-buffer**: alloc `K_sh_alt`, prefetch tile k+1 + while computing on k. + + Requires confirmed answers to the HW capability questions below. + +--- + +## Open questions + +### Still relevant + +**HW capabilities for Phase D**: +* Maximum single H_PREFETCH_V / H_PREFETCH_M element count? +* Total MRAM / VRAM capacity (so we know how much we can grow K_sh / + V_sh for cross-kv_block merge)? +* Are H_PREFETCH variants stride-aware? (Affects whether N rows of + K from `K_hbm[kv*64..(kv+1)*64]` can fold into one DMA.) + +### Resolved during C.1 / C.2 + +* ~~Should `Graph.buffer_nodes` index by `tir.Var` (data) or by + string name?~~ — **By name.** `BufferAccess(buffer_name, starts, + extents)` references it; works across pass boundaries even if + underlying `tir.Var` identity churns. +* ~~NestedForGroup vs ForNode unification?~~ — **Not done; not needed.** + We added `attrs: dict` to NestedForGroup and ForRoot directly. The + separate ForNode dataclass in graph_ir.py is now unused + forward-looking infrastructure (no consumer). Can be removed in a + later cleanup or repurposed if Phase D wants a different + NestedForGroup shape. + +--- + +## Phase C.2 cleanup notes (for future archaeologists) + +What was deleted vs. inlined into graph_passes: + +| Old stmt-walker | Replaced by | +|---|---| +| `annotate_group.py` | `graph_passes/annotate_grid.py` (+ `_VarSubst` inlined into `graph_passes/split_lane_groups.py` as `_StmtVarSubst`) | +| `annotate_sync.py` | `graph_passes/annotate_sync.py` | +| `annotate_gemm_kind.py` | absorbed by `lift_from_raw` (KIND_KEY constant inlined there) | +| `split_lane_groups.py` | `graph_passes/split_lane_groups.py` (operates on graph items + does the var substitution there) | +| `fuse_elementwise.py` | `graph_passes/fuse_elementwise.py` (sets ATTR_IS_SYNC on created nodes — see PIPELINE_ARCHITECTURE.md § Bug history) | +| `scope_inference.py` | `graph_passes/scope_inference.py` (also owns `BufferScopeMap` / `ScopeInferenceError` types) | +| `allocate_group_memory.py` | split into `graph_passes/allocate_group_memory.py:analyze` (pure analysis) + `graph_passes/expand_buffers.py:expand` (the rewriter; runs inside materialize). `_expand_buffer` and `_Rewriter` inlined into expand_buffers as `_StmtRewriter`. | +| `lower_fp_row_patterns.py` | `graph_passes/lower_fp_row_patterns.py` (runs inside materialize, AFTER expand_buffers, because pattern matchers need the 4D-expanded shape) | +| `lift_to_blocks.py` + `lift_to_graph.py` | `lift_from_raw.py` (single-shot; no intermediate block form needed) | + +`graph_pipeline.run()` (the backwards-compat wrapper) deleted; only +`materialize_to_primfunc(graph, scopes, expand_lane_buffers=True)` +remains. `frontend/pipeline.py:compile_func` is now a single +straight-line call sequence with no fallback. diff --git a/tilelang_tvm_compiler/PIPELINE_ARCHITECTURE.md b/tilelang_tvm_compiler/PIPELINE_ARCHITECTURE.md new file mode 100644 index 0000000..a06c96c --- /dev/null +++ b/tilelang_tvm_compiler/PIPELINE_ARCHITECTURE.md @@ -0,0 +1,479 @@ +# Pipeline Architecture + +End-to-end walkthrough of how `tilelang_tvm_compiler` lowers a user-written +`@T.prim_func` to PLENA ISA, with notes on each pass's responsibilities, +inter-pass dependencies, and known gaps. + +--- + +## 1. Overview + +``` +@T.prim_func (user's tilelang DSL kernel) + │ + │ Frontend (frontend/pipeline.py) + │ 1. Stmt prep: inline_let_stmts, lower_compound_fp_stores + │ 2. lift_from_raw_primfunc (TIR → graph_ir.Graph) + │ 3. Graph-IR passes (annotate, lower, fuse, ...) + │ 4. materialize_to_primfunc (Graph → TIR) + │ 5. _rewrite_buffer_scopes (shared.dyn → vram, etc.) + ▼ +TIR with plena.* extern calls only +(plena.matmul / plena.mv / plena.btmm / plena.zero_v / plena.v_add / + plena.dma_h2v_slice / plena.copy_v_to_v / plena.row_load_v_to_fp / …) + │ + │ Backend (pipeline.compile_kernel) + │ PlenaCodegen.lower_to_hlir() + ▼ +HLIRModule (buffers + linear ops list) + │ + │ AddressAllocationPass + ▼ +HLIR with concrete addresses on every buffer + │ + │ IsaEmitterPass + ▼ +ISA text (the final .asm) +``` + +**Architectural principles:** + +1. **All semantic structure work lives on the Graph IR.** Per-op metadata + (sync, gemm kind, lane layout) is stored as `attrs` on `GraphNode` / + `BufferNode` / `NestedForGroup` / `ForRoot` — not as `T.attr(...)` + AttrStmts in a TIR tree. Passes are pure `Graph → Graph` functions. +2. **User-facing surface is tilelang DSL only** — `T.gemm` / `T.copy` / + `T.Parallel` / `T.alloc_*` / `T.attr(0, "plena.gemm_kind", ...)`. + `plena.*` is a compiler-internal IR namespace; kernel authors must + not write it directly. +3. **Per-head offsets are auto-injected.** The user writes + `T.gemm(buf, buf, buf)`; the compiler infers each operand's lane-axis + stride from its post-expansion shape (`expand_buffers`). +4. **Buffer-shape decisions happen at materialize time, not mid-pipeline.** + This is the key architectural shift versus the old stmt-walker + pipeline: graph optimizations run on un-expanded (logical 2D) shapes; + the lane axis is added once, at the boundary into TIR. Future + optimizations that want to change buffer shape (double-buffering, + dead-temp elim, etc.) are unblocked. + +--- + +## 2. Frontend pipeline + +The full chain from `frontend/pipeline.py:compile_func`: + +``` +TIR (user) + │ inline_let_stmts (stmt walker, IR cleanup) + │ lower_compound_fp_stores (stmt walker, expression-level) + │ lift_from_raw_primfunc ← into the Graph IR + ▼ +Graph + │ graph_passes.annotate_grid (ATTR_GROUP_EXTENT) + │ graph_passes.annotate_sync (ATTR_IS_SYNC) + │ graph_passes.split_lane_groups (extent>lane → outer × inner) + │ graph_passes.lift_lane_groups (ForRoot → LaneGroup) + │ graph_passes.fuse_elementwise (T.Parallel → plena.v_*) + │ graph_passes.scope_inference (BufferScopeMap) + │ graph_pipeline.materialize_to_primfunc(expand_lane_buffers=True): + │ graph_passes.allocate_group_memory.analyze (ATTR_LANE_LAYOUT) + │ graph_passes.expand_buffers.expand (rebuild tir.Buffer) + │ graph_passes.lower_fp_row_patterns (fp_*_at / row_*_at) + │ _partition_and_materialize (curtain bundle) + ▼ +TIR (with plena.* externs, lane-expanded buffers, tilelang scopes) + │ _rewrite_buffer_scopes (shared.dyn → vram, etc.) + ▼ +TIR (fully lowered, physical scopes — backend input) +``` + +### 2.1 Stmt prep (pre-graph) + +#### `inline_let_stmts` +Inlines `let x = expr in body` LetStmts. Pure IR cleanup, no semantic +change. Kept on stmt level because it's pre-everything-else and trivial. + +#### `lower_compound_fp_stores` +Rewrites `arr[i] = a*b + c*d` (compound RHS) into a sequence of single-op +stores using auto-allocated `__tmp_fp_*` temporaries. **Must run before +lift** because it operates on TIR expression trees; once lift folds ops +into call form, the RHS expression structure is no longer accessible. + +### 2.2 Lift to Graph IR (`lift_from_raw.py`) + +`lift_from_raw_primfunc(func) → Graph`. Single shot. Translates: + +| TIR construct | Graph IR | +|---|---| +| `T.launch_thread("blockIdx.x", N>1)` | `ForRoot(loop_var, extent=N)` | +| `threadIdx.*` or `blockIdx.*` extent==1 | dropped (degenerate) | +| `tilelang_root` BlockRealize body | `NodeRoot(items=[...])` | +| `Evaluate(Call(tl.tileop.copy/gemm_py/reduce))` | `GraphNode(op_call, reads, writes)` | +| `Evaluate(Call(tir.call_extern("plena.*")))` | `GraphNode(op_call)` (already-lowered passthrough) | +| `tir.For` (any kind) | `NestedForGroup(loop_var, kind, items)` | +| `BufferStore` / `LetStmt` / `IfThenElse` | `RawStmt(stmt)` (escape hatch) | +| `with T.attr(0, "plena.gemm_kind", "btmm"): T.gemm(...)` | `GraphNode(attrs={ATTR_GEMM_KIND: "btmm"})` (the AttrStmt is absorbed) | + +Concurrently, `_collect_buffers` walks every `Block.alloc_buffers` and +`func.buffer_map` and builds `Graph.buffer_nodes: dict[str, BufferNode]` +(name → node with `shape`, `dtype`, `declared_scope`, `data_var`). + +Each `GraphNode.reads / .writes` is a list of `BufferAccess(buffer_name, +starts, extents)` — references the BufferNode by name, not by direct +`tir.Buffer` reference. Layout rewrites in `expand_buffers` flow through +the BufferNode without per-region mutation. + +### 2.3 Graph-IR passes (shape-agnostic phase) + +Each pass takes a `Graph` and returns a `Graph`. None of them change +buffer shapes — those decisions are deferred to materialize. + +#### `annotate_grid` (`graph_passes/annotate_grid.py`) +Sets `ATTR_GROUP_EXTENT` on every `ForRoot` (which came from a +`blockIdx.* > 1` binding) and on every `NestedForGroup` whose `kind == +PARALLEL` (came from `T.Parallel`). Also rewrites the parallel kind to +SERIAL (PLENA hardware is single-threaded; the group annotation is what +signals "iterations are lane-fusion-eligible" to downstream passes). + +#### `annotate_sync` (`graph_passes/annotate_sync.py`) +Sets `ATTR_IS_SYNC` on every GraphNode. True iff: +- HBM ↔ local-buffer DMA copy +- VRAM ↔ FPRAM rank-1 copy (S_MAP_*_*) +- VRAM ↔ VRAM copy (V_ADD_VF f0=0) +- gemm with `ATTR_GEMM_KIND == "btmm"` +- already-lowered `plena.*` extern in `INHERENTLY_SYNC_EXTERNS` + +The classification table is in [graph_pipeline.py](frontend/passes/graph_pipeline.py#L52) +(`INHERENTLY_SYNC_EXTERNS` / `PER_LANE_UNROLLED_EXTERNS`). + +> **Important invariant:** any pass that *creates* new GraphNodes must +> set `ATTR_IS_SYNC` itself if the new node is in +> `INHERENTLY_SYNC_EXTERNS`. `annotate_sync` runs once and never sees +> later-created nodes. `fuse_elementwise` sets it on the +> `plena.zero_v` / `plena.v_*` it creates — see [bug fix](#bug-history) +> for the consequence of forgetting. + +#### `split_lane_groups` (`graph_passes/split_lane_groups.py`) +For every `ForRoot` / `NestedForGroup` carrying `ATTR_GROUP_EXTENT > lane_count` +where the body recursively contains a sync GraphNode that references the +loop var: rewrite as `outer(extent/lane_count) × inner(lane_count, +ATTR_IS_LANE_FOR=True)`. The body's references to `v` get substituted +with `v_outer * lane_count + v_inner` via `_GraphVarSubst`, which walks: +- every `GraphNode.op_call` (TIR Call, recursively) +- every `BufferAccess.starts` / `.extents` +- every `RawStmt.stmt` (recursively, via inlined `_StmtVarSubst`) +- every nested `NestedForGroup.min` / `.extent` + +#### `lift_lane_groups` (`graph_passes/lift_lane_groups.py`) +ForRoots / inner-of-pair NestedForGroups carrying `ATTR_GROUP_EXTENT == +lane_count` (or `ATTR_IS_LANE_FOR=True`) get upgraded to `LaneGroup` +nodes — the explicit container for "one lane fusion bundle" that the +materialize-time partitioner identifies as the curtain-bundle scope. + +Without this upgrade, materialize would emit each lane group as a plain +`tir.For` with no per-lane partitioning — all ops would run inside the +for-by, including the sync ones (wrong). + +#### `fuse_elementwise` (`graph_passes/fuse_elementwise.py`) +Pattern-matches `NestedForGroup(ATTR_GROUP_EXTENT, items=[RawStmt(BufferStore)])` +forms and replaces them with single GraphNodes: + +| Pattern | Replacement | +|---|---| +| `for i: dst[..., i] = a[..., i] + b[..., i]` | `plena.v_add(a, b, dst)` | +| `for i: dst[..., i] = a[..., i] - b[..., i]` | `plena.v_sub(a, b, dst)` | +| `for i: dst[..., i] = a[..., i] * b[..., i]` | `plena.v_mul(a, b, dst)` | +| `for i: dst[..., i] = 0` | `plena.zero_v(dst)` | + +Plus the **nested fold**: outer serial-for wrapping a single fused +whole-buffer op → drops the outer for entirely (since `plena.v_*` and +`plena.zero_v` are inherently whole-buffer; running them N times is +wrong). + +The created GraphNodes are tagged `ATTR_IS_SYNC=True` because they're +all in `INHERENTLY_SYNC_EXTERNS`. + +#### `scope_inference` (`graph_passes/scope_inference.py`) +Owns `BufferScopeMap = dict[str, str]` and `ScopeInferenceError`. Walks +every GraphNode and infers each buffer's physical scope: + +| Declared scope + usage | Resolved physical scope | +|---|---| +| `func.buffer_map` (param) | `hbm` | +| `global.` | as-declared | +| `shared.dyn`, used as gemm RHS / matmul-arg[1] | `mram` | +| Other `shared.dyn` | `vram` | +| `local.fragment`, used as `plena.fp_*_at` / `row_*_at` operand, OR rank-1, OR T.reduce dst | `fpram` | +| Other `local.fragment` | `vram` | + +### 2.4 Materialize (shape-aware phase + lowering) + +`materialize_to_primfunc(graph, scopes, expand_lane_buffers=True)` — +defined in [graph_pipeline.py](frontend/passes/graph_pipeline.py). +Runs three more graph passes that need final shape decisions, then +walks the graph to emit TIR. + +#### `allocate_group_memory.analyze` +Classifies every buffer touched by lane-fused ops. Sets +`BufferNode.attrs[ATTR_LANE_LAYOUT]` to one of: + +| Layout | Pre-expansion shape | Post-expansion shape | Triggered by | +|---|---|---|---| +| `col_pack` | `(rows, last)` | `(1, rows, lane_count, last)` | BTMM args[0,1]; non-btmm matmul args[1,2]; HBM↔local DMA local side; plena.v_*/matmul/row_* trailing Var args | +| `row_stack` | `(rows, last)` | `(1, lane_count, rows, last)` | BTMM args[2]; non-btmm matmul args[0] | +| `fp_lane` | `(N,)` | `(lane_count, N)` | VRAM↔FPRAM rank-1 copy fpram side; plena.fp_*_at / row_*_at FP operands; BufferStore to FPRAM | + +Conflict rules: `row_stack` wins over `col_pack` (BTMM output's BHSD +layout dictates); `fp_lane` doesn't mix with the other two. +`global.*`-scoped buffers are skipped (user already wrote the physical +shape). + +Also writes `BufferNode.attrs[ATTR_LANE_VAR]` (string name of the lane +var that this buffer's lane axis substitutes for). + +#### `expand_buffers.expand` +For every BufferNode with `ATTR_LANE_LAYOUT`, rebuilds a fresh +`tir.Buffer` with the expanded shape, then walks the whole graph +rewriting: +- every `GraphNode.op_call` arg referencing the old buffer (BufferLoad + inside `tl.tileop.region`, trailing `buf.data` Var args, etc.) → + swap to new buffer + fold lane index into the indices via + `_StmtRewriter._fold_lane` +- every `BufferAccess(name, starts, extents)` → starts get `_fold_lane` + applied, extents get a unit slot inserted at the lane axis +- every `RawStmt(stmt)` → `_StmtRewriter.visit(stmt)` does the same + swap+fold for any BufferLoad/BufferStore/Var inside + +Index folding rules (mirrors the buffer-shape changes): + +| Mode | Pre-fold indices | Post-fold indices | +|---|---|---| +| `col_pack` | `[r, c]` | `[0, r, lane_var, c]` | +| `row_stack` | `[r, c]` | `[0, lane_var, r, c]` | +| `fp_lane` | `[r]` | `[lane_var, r]` | + +Critical detail: `lane_var` is the **actual** `tir.Var` from the +surrounding ForRoot/LaneGroup, not a fresh same-named Var. TIR resolves +vars by identity; using a synthetic var produces "unbound symbol" +errors at codegen time. + +#### `lower_fp_row_patterns` +Must run **after** expand because the row-parallel pattern matcher +requires the 4D-expanded buffer shape (matches legacy stmt-walker +ordering). Three pattern families: + +| Source | Replacement | +|---|---| +| `RawStmt(BufferStore)` to FPRAM buffer | GraphNode `plena.fp_zero_at` / `fp_copy_at` / `fp_add_at` / `fp_sub_at` / `fp_mul_at` / `fp_exp_at` / `fp_reci_at` | +| `NestedForGroup(ATTR_GROUP_EXTENT, items=[RawStmt(BufferStore)])` on VRAM | GraphNode `plena.row_exp_at` / `row_sub_fp_at` / `row_mul_fp_at` | +| `GraphNode(tl.tileop.reduce)` with VRAM src + FPRAM dst | `RawStmt(For row in N: Evaluate(plena.row_reduce_max_at / row_reduce_sum_at))` (escape hatch — the per-row for has no graph-IR analogue) | + +#### `_partition_and_materialize` (curtain bundle) +Walks the (now expanded + lowered) graph and emits TIR. + +For a `LaneGroup`: scan items, partition at sync boundaries: +- **Sync GraphNode**: flush any accumulated per-lane run, emit op once + with `in_sync=True` (no surrounding for-by — it's a multi-lane HW + instruction). +- **Per-lane GraphNode**: accumulate into the current per-lane run. +- **NestedForGroup with no inner sync**: accumulate as opaque per-lane + block. +- **NestedForGroup with inner sync**: flush per-lane run, recurse into + body, wrap result in `tir.For(loop_var)`. +- **RawStmt**: accumulate as per-lane. + +When the per-lane run flushes, it gets wrapped in +`for(lane_var, range(lane_count))` — `UNROLLED` kind if any item is in +`PER_LANE_UNROLLED_EXTERNS` (currently just `plena.matmul`), else +`SERIAL`. + +For a `NodeRoot` (no lane fusion, e.g. mm64-style): items emit as a +plain stmt sequence with `lane_var=None`. + +For `ForRoot`: recursively materialize the body, then wrap in `tir.For`. + +Each `GraphNode._lower_node` delegates to +[`lower_to_hlir._lower_copy / _lower_gemm`](frontend/passes/lower_to_hlir.py) +for the actual `tl.tileop.copy → plena.dma_*` and `tl.tileop.gemm_py → +plena.btmm/matmul/mv` translation. Already-lowered `tir.call_extern` +nodes pass through unchanged. + +### 2.5 Final scope rewrite (post-graph) + +`_rewrite_buffer_scopes` (in `lower_to_hlir.py`, despite the misleading +filename — see § 5.4): substitutes every `shared.dyn` / +`local.fragment` declared scope with the resolved physical scope from +`scopes`. Rebuilds `tir.Buffer` objects so backend codegen reads +`buf.scope() ∈ {hbm, vram, mram, fpram, global.*}` directly. + +This step is intentionally outside the graph layer — graph passes use +declared (tilelang) scopes; the codegen-facing physical scope rename +is the boundary into backend. + +--- + +## 3. Backend — three stages + +(Not part of `frontend/pipeline.py`; same `compile_kernel` flow; see +[`tilelang_tvm_compiler/pipeline.py`](pipeline.py).) + +### 3.1 `PlenaCodegen.lower_to_hlir()` ([codegen.py](codegen.py)) +TIR → HLIR data structure. +- `_collect_param_buffers` + `_collect_alloc_buffers` walk every buffer + into `_buffers` (keyed by `tir.Var`, i.e. `buffer.data`). +- `_walk_stmt` / `_walk_evaluate` rewrite each `plena.*` extern call to + `_hlir.Op(kind="", buffer_args=[...], + scalar_args=[...])`. +- For-loops become `_hlir.Op(kind="for", body=[...])` nests. +- Output is `_hlir.HLIRModule(name, buffers, ops)`. + +### 3.2 `AddressAllocationPass` ([address_alloc.py](address_alloc.py)) +Assigns each buffer a concrete address in declaration order: +- HBM: from `0`, advance by buffer size. +- VRAM: from `0`, advance by buffer size (each row is MLEN-wide). +- MRAM: tile-aligned allocation. +- FPRAM: from `FPRAM_USER_BASE = 32`, advance by buffer size. + +### 3.3 `IsaEmitterPass` ([isa_pass.py](isa_pass.py)) +HLIR → ISA text. Each op kind has a `_emit_*` method. A +`symbol_table: Dict[tir.Var, int]` tracks loop var → GP register +bindings; `ExprMaterializer` lowers dynamic `PrimExpr`s into chains of +ISA arithmetic instructions. + +--- + +## 4. End-to-end trace: a P @ V `T.gemm` + +```python +# User writes: +with T.Kernel(1, head_count) as (_, by): + ... + T.gemm(S_loc, V_sh, PV_loc) # default KIND="overwrite" +``` + +| Step | Pass | What changes | +|------|------|---| +| 1 | `inline_let_stmts`, `lower_compound_fp_stores` | TIR cleanup; no change to this gemm. | +| 2 | `lift_from_raw_primfunc` | `T.gemm(...)` → `GraphNode(op_call=tl.tileop.gemm_py(...), reads, writes)`. The blockIdx.y binding becomes a `ForRoot(by, extent=head_count)`. | +| 3 | `annotate_grid` | `ForRoot(by).attrs[ATTR_GROUP_EXTENT] = head_count`. | +| 4 | `annotate_sync` | gemm has no `ATTR_GEMM_KIND="btmm"`, so `attrs[ATTR_IS_SYNC] = False`. | +| 5 | `split_lane_groups` | If `head_count > lane_count`, split the ForRoot into `by_outer × by_inner`; vars in op_call args / BufferAccess get rewritten. | +| 6 | `lift_lane_groups` | Inner ForRoot (extent=lane_count, ATTR_IS_LANE_FOR=True) → `LaneGroup(lane_var=by_inner, lane_count=4, items=[...])`. | +| 7 | `scope_inference` | `S_loc`/`V_sh`/`PV_loc` resolve (S_loc → vram, V_sh → mram, PV_loc → vram). | +| 8 | `materialize`: `allocate_group_memory.analyze` | Non-btmm gemm: `S_loc → row_stack`, `V_sh → col_pack`, `PV_loc → col_pack` (only if untouched by other ops). | +| 9 | `materialize`: `expand_buffers.expand` | `S_loc.shape: (64, 64) → (1, lane_count, 64, 64)`; `V_sh / PV_loc: (64, 16) → (1, 64, lane_count, 16)`. Op_call indices get `_fold_lane` applied. | +| 10 | `materialize`: `_partition_and_materialize` | gemm is per-lane (not sync). Accumulates into per-lane run wrapped in `for(by_inner)`. | +| 11 | `_lower_node → _lower_gemm` | KIND=overwrite + LHS rows=1 ⇒ `plena.mv`; per-lane offsets auto from `by_inner * stride`. | +| 12 | `_rewrite_buffer_scopes` | `shared.dyn` → `mram` for V_sh; `local.fragment` → `vram` for S_loc / PV_loc. | +| 13 | `PlenaCodegen` | `plena.mv` → `Op(kind="mv", scalar_args=[by*64, by*16, by*16])`. | +| 14 | `AddressAllocationPass` | Concrete addresses for `S_loc` / `V_sh` / `PV_loc`. | +| 15 | `IsaEmitterPass` | Emit `M_MV` × tile_count + `M_MV_WO` writeback. | + +--- + +## 5. Known gaps + +### 5.1 Loop-scheduling layer not started +Graph IR has structure for it (`ForRoot.attrs`, `NestedForGroup.attrs`, +`Graph.buffer_nodes` whose shape can be mutated mid-pipeline) but no +pass actually does loop optimization. The migration plan §"Phase D" +covers what's planned: + +- **Cross-iter DMA merge**: combine DMAs of consecutive K/V tiles into + a single bigger DMA. Reduces per-tile DMA setup cost. +- **Double-buffering / prefetch**: alloc `K_sh_alt`, prefetch tile `k+1` + while computing on tile `k`. Hides DMA latency. + +Both depend on hardware capability info we haven't pinned down (max +single-DMA element count, MRAM/VRAM capacity headroom, DMA stride +support — see `MIGRATION_PLAN.md` §"Open questions"). + +### 5.2 RawStmt is the graph layer's blind spot +`RawStmt` wraps a TIR subtree the lift can't classify (`IfThenElse`, +LetStmt-leftovers, BufferStore inside non-lane-eligible serial fors, +T.reduce that lower_fp_row_patterns turned into `for row: row_reduce_*`). +Graph passes can do mechanical var-subst and buffer-replace inside it +(`_StmtVarSubst`, `_StmtRewriter.visit`) but cannot reason about its +control flow. + +If/when we add proper graph-IR nodes for `IfThenElse` and `Reduce`, +RawStmt usage shrinks; until then it's an escape hatch for "this shape +is rare, just pass it through". + +### 5.3 `fuse_elementwise` op set +Currently: `+` `-` `*` `0`-fill. No `/`, no `exp`, no `relu`, no +non-zero const fill. Adding new ones requires a backend intrinsic plus +extending `fuse_elementwise._OP_TO_INTRIN`. ~20 LoC each. + +### 5.4 `KIND="add"` reserved but not wired +`C += A @ B` in one gemm. Recognised by the kind parser (no error) but +`_lower_gemm` raises `NotImplementedError`. Workaround: + +```python +scratch = T.alloc_fragment((rows, hlen), "float16") +T.gemm(A, B, scratch) # KIND=overwrite +for r in T.serial(rows): + for c in T.Parallel(C): + dst[r, c] = dst[r, c] + scratch[r, c] # auto-fuses to plena.v_add +``` + +### 5.5 `lower_to_hlir.py` is a misleading filename +Despite the name, [frontend/passes/lower_to_hlir.py](frontend/passes/lower_to_hlir.py) +is **not** the TIR → HLIR backend — that's `PlenaCodegen.lower_to_hlir()` +in `codegen.py`, despite sharing the method name. The file holds three +unrelated frontend op-level helpers: + +- `_lower_copy`: `tl.tileop.copy → plena.dma_*` +- `_lower_gemm`: `tl.tileop.gemm_py → plena.btmm/matmul/mv` +- `_rewrite_buffer_scopes`: tilelang scope → physical scope (used as + the very last frontend step) + +Renaming this file to e.g. `op_lowering.py` or `dsl_lowering.py` would +remove a recurring confusion source. Not done yet because it's a +renames-everywhere change. + +### 5.6 `forbid_plena_extern` is opt-in +A sanity check that asserts a kernel uses only tilelang DSL (no +`plena.*` extern calls). Currently kernel authors must invoke it +manually before calling `compile_func`. Could be wired in by default +under a flag. + +--- + +## 6. Bug history (lessons learned the hard way) {#bug-history} + +### `fuse_elementwise` missed `ATTR_IS_SYNC` on created nodes +**Symptom**: flash_attention_min e2e numerics off by exactly `lane_count` +(simulated ≈ 4× golden) for the largest output magnitudes. + +**Root cause**: `fuse_elementwise` runs **after** `annotate_sync`. When +it created new `plena.zero_v` / `plena.v_add` GraphNodes, it left +`attrs={}`. `annotate_sync` had already finished and didn't see them. +The materialize-time partitioner saw `ATTR_IS_SYNC=False` (default) and +emitted these `INHERENTLY_SYNC_EXTERNS` ops *inside* the per-lane +`for(by_i)`, running each `O += PV` four times instead of once. + +**Fix**: `fuse_elementwise` now sets `attrs={ATTR_IS_SYNC: True}` on +every new node. Same invariant applies to any future pass that creates +inherently-sync ops. + +**General principle**: any pass that creates new GraphNodes is +responsible for setting their attrs correctly — `annotate_*` passes +don't re-run. + +--- + +## 7. Phase status + +- ✅ **Phase A**: graph IR types, lift_from_raw, walker helpers +- ✅ **Phase B**: incremental graph-layer passes (annotate_gemm_kind, + annotate_sync, scope_inference initially as proof of concept) +- ✅ **Phase C.1**: write all graph-layer passes; switch pipeline to + graph path +- ✅ **Phase C.2**: delete legacy stmt-walker passes, frontend_legacy/, + 6 stmt-walker test files; only graph path remains +- ⏳ **Phase D**: loop-scheduling layer (DMA merge, prefetch / double-buffer) + — not started; depends on HW capability info + +See [`MIGRATION_PLAN.md`](MIGRATION_PLAN.md) for the original migration +plan and open questions. diff --git a/tilelang_tvm_compiler/REGALLOC_B_DESIGN.md b/tilelang_tvm_compiler/REGALLOC_B_DESIGN.md new file mode 100644 index 0000000..4b291a4 --- /dev/null +++ b/tilelang_tvm_compiler/REGALLOC_B_DESIGN.md @@ -0,0 +1,111 @@ +# Region-recursive GP allocation (design B) + +Replaces the linear-scan-on-flattened-IR allocator in `mir_to_isa.py`. +The MIR is already clean SCF (blocks of interleaved instrs + nested +`MirLoop` regions). This design allocates *over that tree* instead of +flattening it, so loop-carried values are handled by construction — no +`_open_loop_ends` / `_operand_lock` / `_no_release` / pin patches. + +## The one idea + +> A region's **live-in** values (used inside, defined outside) occupy +> fixed GPs for the WHOLE region. Only values *born and killed inside* +> the region are spill candidates within it. + +Why this kills the current bug: emit is single-pass but a loop body runs +N times. A spill emitted in the body's second half is not replayed when +control jumps back to the head, so spilling a value the head re-reads +corrupts iteration 2. Live-in values are exactly the values the head +re-reads. If they are reserved before the body and never offered to the +in-body spill picker, the corruption is structurally impossible. + +A value born and killed inside the body is safe to spill: its spill and +reload both lie within one iteration's emitted code, so the round-trip +re-executes intact every iteration. + +## Liveness the SCF way (no flat index) + +For each value `v`, `def_region(v)` = the innermost region containing its +def. `v` is **live-in** to region `R` iff it has a use inside `R` and +`def_region(v)` is a strict ancestor of `R` (i.e. `v` is defined outside +`R`). + +Per region `R` we need: +- `live_in(R)` — values defined outside `R`, used inside `R`. +- `local(R)` — values defined inside `R` (directly, not in a child + region). These are the spillable ones for `R`. + +Both are computable in one post-order walk: +``` +live_in(R) = (union over children C of live_in(C) + ∪ {operands of R's own instrs}) + minus {values defined in R} +local(R) = {results of R's own instrs} # children's locals stay theirs +``` +(`live_in` of the function's top region is just the function consts / +block args — typically only gp0.) + +## Allocation walk + +`assign_region(R, free)`: +- `free` = set of GP numbers usable inside `R` (caller already subtracted + what it reserved for `R`'s live-ins + the loop counter/lvar). +- Walk `R`'s items in order. For each: + - **MirInstr**: operands already hold GPs (live-in reserved, or a local + assigned earlier in this walk). Assign the result a GP from `free` + (or spill a *local* whose last in-region use has passed). Free locals + whose last use is this instr. + - **MirLoop child `L`**: + 1. `carried = live_in(L) ∩ (currently-held values)` — these must + stay put across `L`. They are *already* in GPs; we simply do NOT + lend those GPs into `L`. + 2. Reserve a GP for `L`'s counter + lvar. + 3. `inner_free = free − {carried GPs} − {counter,lvar GPs} − + {GPs held by R-locals still live across L}`. + 4. `assign_region(L, inner_free)`. + 5. On return, reclaim counter/lvar GPs. + +The spill picker, when invoked inside `R`, can only see `local(R)` values +not yet dead — never a live-in (those GPs were never in `inner_free`). + +## What stays identical +- `compute_emit_order` — still used, only to order last-use *within a + region* (a region is straight-line in emit order, so a plain index + inside the region is exact; no cross-loop extension needed anymore). +- `assign_addr_reg` (a-regs), fp_reg tokens, gp0 = const-zero, IntRAM + slot reservation, the S_ST_INT/S_LD_INT spill/reload emission, the + serial-loop prologue/epilogue, `_free_gp`/`_take_gp` invariants. +- The `_emit_instr` operand-formatting + result-prepend logic. + +## What is deleted +- `_compute_live_intervals` (the flat `end` table) → replaced by + `live_in`/`local` per region. +- `release_dead_at(cur_idx)` global sweep → per-region last-use frees. +- `_open_loop_ends`, `_operand_lock`, `_no_release`, `pin`/`unpin` for + data values. (Counter/lvar still get a GP reserved for the region, + but via reservation, not a pin set.) + +## Spill correctness (now provable) +- A spilled value is always a `local(R)` of the region currently being + walked. Its spill point and every reload point are emitted inside the + same straight-line region body. A loop re-entry re-executes that whole + body, so the spill/reload pair re-runs together each iteration. No + cross-iteration leak. +- Live-ins are never spilled inside a child loop (their GPs aren't lent + in), so the head's re-read always sees the correct GP. + +## Register-pressure note +Reserving all live-ins of a deep loop nest could exhaust 15 GPs. If +`inner_free` is empty when a child needs registers, that's genuine +pressure: spill an *outer* `local` (safe — it's straight-line at that +level) before descending, or, as a later refinement, allow spilling a +live-in to a **dedicated per-region IntRAM slot reloaded at the loop +head** (the "loop-head fixup" that makes spilling a carried value safe). +Start without that; the min kernels fit. + +## Migration +Single file (`mir_to_isa.py`). `MirToIsa.run()` calls +`assign_region(top, all_user_gps)` once; `_emit_instr` and the serial +loop emitter call into the region allocator instead of the linear-scan +one. Keep the old class around behind a flag for one commit to A/B the +ISA, then delete. diff --git a/tilelang_tvm_compiler/REGALLOC_SCOPE_DESIGN.md b/tilelang_tvm_compiler/REGALLOC_SCOPE_DESIGN.md new file mode 100644 index 0000000..8901256 --- /dev/null +++ b/tilelang_tvm_compiler/REGALLOC_SCOPE_DESIGN.md @@ -0,0 +1,108 @@ +# Scope-recursive GP allocation (the real fix) + +Replaces the linear-scan + patch pile in `mir_to_isa.py`. The MIR is +clean SCF (blocks of instrs + nested `MirLoop` regions). Allocate **over +that tree, with spill/reload happening only at scope boundaries** — never +inside a loop body. That boundary discipline is what every previous bug +violated (a spill/reload instruction landing in a loop body re-executes +every iteration with stale register state). + +## 0. Why all the previous bugs were the same bug + +A `S_ST_INT`/`S_LD_INT` for a value that lives ACROSS a loop must not sit +in the loop body. The body is emitted once but runs N times: +- store-in-body → each iteration overwrites the slot with whatever the GP + now holds (the `intram[7]=2048` bug); +- reload-in-body-before-its-store → reads last iteration's clobbered GP + (the `%74=800` bug). + +Fix structurally: **spill at scope entry, reload at scope exit, both +OUTSIDE the body.** Then a body never contains a cross-boundary +spill/reload, so it can't be corrupted by re-execution. + +## 1. Pre-pass: classify every value (one walk) + +Walk the SCF tree once. For each `MirValue` record: + +- **kind**: + - `counter` — a serial loop's hardware counter (minted by emit). Needs + a GP pinned for the whole loop; HW reads/decrements it. Never spill. + - `loop_idx` — a loop_var (block argument). Has an IntRAM idx slot; + reloaded at each iteration header, ++ stored at footer. Never needs a + persistent GP. + - `normal` — everything else (address arithmetic, temporaries). +- **def_scope**: the innermost region (loop) containing its def. Top-level + = the function scope. +- **last_use / live scopes**: the set of scopes in which it is read. + +From this derive, per scope S (a loop region or the top region): +- `live_in(S)` — values defined OUTSIDE S, read INSIDE S (carried). +- `local(S)` — values defined directly in S. + +(`gp0` is the const-zero fixture; never allocated/spilled.) + +## 2. Per-scope register budget (decide BEFORE entering) + +Recurse the tree. Entering scope S we know: +- `regs_free` = GPs available from the parent. +- S needs: 1 GP for its counter (if serial loop), 1 for its loop_idx + while live, plus enough for `local(S)` peak + the `live_in(S)` values + it actually touches. + +If `regs_free` can't cover S's needs, we **demote the lowest-priority +carried values to IntRAM at S's entry** (see §3). This is the decision +point — made structurally per scope, not reactively mid-body. + +## 3. Scope entry / exit protocol (the heart of it) + +Entering scope S (right after `C_LOOP_START`, before the body): +1. Counter: assign + pin its GP. +2. loop_idx: header `S_LD_INT idx_gp, slot` (already the case). +3. For each `live_in(S)` value V the body will read: + - if a GP is available → keep V in its GP (no memory traffic); + - else → it is **resident in IntRAM**: ensure V's value is in its slot + (stored at S's ENTRY, outside the body), and the body reads V by + reloading from the slot **at each use** — but those reloads are + body-local and re-execute fine because the SLOT WAS WRITTEN AT ENTRY + (outside the body) and V is loop-invariant. + Key: the STORE is emitted at entry (outside body). Only LOADs are in + the body. Loads-in-body are safe; stores-in-body are not. + +Exiting scope S: +4. Any value that S's body left in IntRAM but the PARENT scope still needs + in a GP → reload it once here (outside S, after `C_LOOP_END`). +5. Free S's counter/idx GPs. + +So the invariant holds by construction: **every `S_ST_INT` is at a scope +entry, every cross-scope `S_LD_INT` is at a scope entry/exit — none in a +body that loops.** Body-internal spill/reload only ever touches `local(S)` +temporaries whose store+load are in the same straight-line body (one +iteration), which is correct. + +## 4. Inside a scope body (straight-line) + +Within S's body, between child loops, it's straight-line: ordinary +linear assignment from `regs_free`, and `local(S)` temporaries may +spill/reload freely (same-iteration, safe). Child `MirLoop` → recurse +(§2/§3) with the GPs not reserved by S's carried set + counter/idx. + +## 5. What this kills / keeps +- Kills: `_carried_now` spill-prohibition hacks, volatile/remat tiers, + the "spill landed in body" corruption, the four-dict state churn. +- Keeps: `compute_emit_order`, `_free_gp`/`_take_gp` invariants, IntRAM + monotonic slot allocator, operand-lock, addr_reg/fp/gp0 handling, the + serial-loop prologue/epilogue ISA, `_check_no_iter_args` guard. + +## 6. Register pressure / failure +If even after demoting all `live_in(S)` to IntRAM a scope still can't fit +`local(S)` peak + counter + idx in the GP file, that's genuine pressure → +fail loud (or spill a `local` to IntRAM within the body, which is safe). +For fa_min: counter+idx per level (≤2/level) + a few locals; the carried +constants/address-terms go to IntRAM, reloaded per use. Fits. + +## 7. Migration +Single file. New `ScopeAllocator` driven by a recursive `emit_scope` +replacing the flat `_emit_block`/`_emit_loop_serial` walk's allocation +decisions. Keep old class behind a flag for one commit to A/B the ISA, +then delete. Build the classification pre-pass first; assert it matches +the `_check_no_iter_args` expectations (only loop_idx as block arg). diff --git a/tilelang_tvm_compiler/__init__.py b/tilelang_tvm_compiler/__init__.py new file mode 100644 index 0000000..a202f92 --- /dev/null +++ b/tilelang_tvm_compiler/__init__.py @@ -0,0 +1,82 @@ +"""TVM-based PLENA compiler. + +Pipeline (3 passes): + + TIR PrimFunc (with PLENA scopes + plena.* extern calls) + | + v PASS 1 PlenaCodegen.lower_to_hlir() + HLIR (Buffer + Op stream, no addresses) + | + v PASS 2 AddressAllocationPass + HLIR with HBM/VRAM/MRAM addresses + stride/scale + tile annotations + | + v PASS 3 IsaEmitterPass (re-uses runtime's ISAEmitter via shim) + Real PLENA ISA text -> assembler -> .mem -> emulator + +============================================================================== +HARD CONVENTIONS -- READ THIS FIRST WHEN ADDING A KERNEL +============================================================================== + +These three rules are load-bearing. Violating any of them silently produces +emulator output that does not match golden, with no immediate error. + +1) HBM IS ALWAYS BSHD. + Every HBM-resident T.Buffer must declare shape as (Batch, Seq, Heads, Dim). + `create_mem_for_sim` -> `map_mx_data_to_hbm_for_behave_sim` packs tensors + into hbm_for_behave_sim.bin assuming this layout. The address_alloc pass + computes per-tile stride from `H*D` (last two dims merged). If you + declare an HBM buffer in any other order, the H_PREFETCH_*/H_STORE_V + addresses will be wrong and the emulator will read/write the wrong cells. + +2) VRAM/MRAM REFLECTS PHYSICAL HARDWARE LAYOUT. + Different ops produce different physical layouts in VRAM/MRAM: + - DMA from HBM (H_PREFETCH_V/M) preserves BSHD. + - BTMM/BMM_WO writes BHSD (head is the outermost stored dimension -- + see transactional_emulator/src/main.rs:bmm_wo()). + Declare the buffer's shape to match what the hardware actually produces. + The dma_v2h pass is what reconciles "VRAM=BHSD" with "HBM=BSHD" via + tile-level reorder during the store -- it walks col-block-major over the + BSHD HBM dst, which lands vram_off = idx * tile_elems exactly on each + head's tile boundary in BHSD VRAM, transparently transposing. + Lying about the VRAM shape (e.g. labelling a BHSD VRAM buffer as BSHD) + may "work" by coincidence for one kernel but breaks as soon as another + op tries to index it. + +3) COMPARISON IS ALWAYS BSHD. + At the testbench boundary, both golden and the post-staging VRAM dump + must be in BSHD-flat form (B*S rows x H*D cols): + - golden: `golden_4d.reshape(B*S, H*D)` before passing to + `create_sim_env(golden_result=...)`. + - simulated: the `--stage-output BUFFER` flag emits per-tile DMAs + that lay out HBM-BSHD into VRAM[0..] in a stride-mode-compatible + arrangement. `view_mem.py` reassembles via `chunks_per_batch=H`, + producing a BSHD-flat (B*S, H*D) tensor for diff against golden. + +============================================================================== +""" + +# Bootstrap: tilelang ships its own TVM 0.23 in `tilelang/3rdparty/tvm/`. +# Importing tilelang first injects that bundled TVM onto sys.path so the +# subsequent `from tvm import ...` statements throughout this package +# pick up 0.23 (which is what we target). When this package is consumed +# from a venv that does not have tilelang installed, `import tilelang` +# raises ImportError; we tolerate that case so unit tests of pure +# codegen logic can still run as long as some TVM is on sys.path. +try: + import tilelang as _tilelang_for_tvm_bootstrap # noqa: F401 +except ImportError: + pass + +from .codegen import PlenaCodegen, compile_module +from .test_helper import TvmTestbenchSpec, run as run_testbench +from . import scope +from . import intrinsics + +__all__ = [ + "PlenaCodegen", + "compile_module", + "TvmTestbenchSpec", + "run_testbench", + "scope", + "intrinsics", +] diff --git a/tilelang_tvm_compiler/__main__.py b/tilelang_tvm_compiler/__main__.py new file mode 100644 index 0000000..55abc7b --- /dev/null +++ b/tilelang_tvm_compiler/__main__.py @@ -0,0 +1,423 @@ +"""CLI entry for cross-venv invocation. + +The TVM compiler runs in a Python 3.11 venv (.venv-tvm) because no +apache-tvm wheel exists for 3.12. The rest of the project (torch, sim +env utilities, golden compute) lives in the main 3.12 venv. + +Driver scripts in the main venv subprocess into here to get the ISA +text without dragging TVM into their own interpreter: + + out = subprocess.check_output([ + ".venv-tvm/bin/python", "-m", "tilelang_tvm_compiler", + "compile", + "--kernel", "tilelang_tvm_compiler.kernels.minimal_btmm:minimal_btmm", + "--asm-name", "tvm_btmm_kernel", + ], env={"LD_LIBRARY_PATH": "", ...}).decode() + +`out` is the full ISA text, ready to drop into create_sim_env's +`generated_code` argument. +""" + +from __future__ import annotations + +import argparse +import importlib +import json +import sys +from pathlib import Path + +from .pipeline import compile_kernel, PlenaTarget +from .program_shim import make_shim +from .register_alloc import RegisterAllocator +from .isa_emitter import ISAEmitter +from .hlir import format_hlir +from . import scope as _scope + + +def _parse_kernel_kwargs(spec: str | None) -> dict: + """Parse a comma-separated `k=v,k=v` string into a kwargs dict. + Values are coerced to int if possible (the only case we currently + need for shape parameters).""" + if not spec: + return {} + out: dict = {} + for pair in spec.split(","): + if not pair.strip(): + continue + if "=" not in pair: + raise SystemExit( + f"--kernel-kwargs entry must be key=value, got {pair!r}" + ) + k, v = pair.split("=", 1) + k = k.strip() + v = v.strip() + try: + out[k] = int(v) + except ValueError: + out[k] = v + return out + + +def _resolve_kernel(spec: str, kwargs: dict | None = None): + """Resolve a `module.path:symbol` string into a TIR PrimFunc. + + `symbol` can be either: + * a `tir.PrimFunc` directly (no kwargs accepted) + * a callable factory; we call it with `kwargs` and accept either a + PrimFunc or a `(PrimFunc, ...)` tuple as the return value + """ + from tvm import tir + if ":" not in spec: + raise SystemExit( + f"--kernel must be of the form module:funcname, got {spec!r}" + ) + mod_path, func_name = spec.split(":", 1) + mod = importlib.import_module(mod_path) + print(f"[kernel] {mod_path} -> {mod.__file__}", file=sys.stderr) + if not hasattr(mod, func_name): + raise SystemExit(f"{mod_path!r} has no attribute {func_name!r}") + obj = getattr(mod, func_name) + if isinstance(obj, tir.PrimFunc): + if kwargs: + raise SystemExit( + f"{func_name!r} is already a PrimFunc; --kernel-kwargs not allowed" + ) + return obj + if callable(obj): + result = obj(**(kwargs or {})) + if isinstance(result, tuple): + # Factories like make_tiled_btmm return (PrimFunc, constants). + result = result[0] + if not isinstance(result, tir.PrimFunc): + raise SystemExit( + f"factory {func_name!r} returned {type(result).__name__}, " + f"expected tir.PrimFunc" + ) + return result + raise SystemExit( + f"{func_name!r} is neither PrimFunc nor callable: {type(obj).__name__}" + ) + + +def _emit_output_staging( + compiled, + target: PlenaTarget, + out_buffer_name: str, +) -> str: + """Append "load output back to VRAM[0..]" ISA so view_mem can compare. + + The output buffer ends up in HBM after the main kernel. To check it + against the golden, the runtime convention is to drop tile-by-tile + DMAs at the end of the program that re-load HBM into VRAM[0..], + laid out tile-major. view_mem.py then reads that VRAM region and + compares against golden_result.txt. + + For a 2D logical view (rows, cols) of the output, we walk col-blocks + first, then row-blocks (matches the runtime helper's "stage_order: + col_major"), and emit one emit_load_tile_from_hbm per tile. + """ + from tilelang_tvm_compiler.hlir import ( + hbm_strides_for_layout, + make_tile_layout, + ) + + buf = compiled.hlir.get_buffer(out_buffer_name) + if buf.scope != _scope.HBM: + raise SystemExit( + f"--stage-output buffer {out_buffer_name!r} must be in HBM, " + f"got {buf.scope!r}" + ) + rows, cols = _logical_2d(buf.shape, buf.layout) + mlen = target.mlen + tile_elems = mlen * mlen + full_tensor_size = rows * cols + + from .plena_settings import ( + v_prefetch_amount as _v_prefetch_amount, + v_writeback_amount as _v_writeback_amount, + ) + shim = make_shim( + mlen=target.mlen, + blen=target.blen, + btmm_lane_count=target.btmm_lane_count, + btmm_hlen=target.btmm_hlen, + v_prefetch_amount=_v_prefetch_amount(), + v_writeback_amount=_v_writeback_amount(), + register_allocator=RegisterAllocator(), + ) + emitter = ISAEmitter(shim) + + # Per-tile DMA stride between successive rows of one inner tile. + # For BSHD this equals cols (legacy behaviour); for NCHW it is + # the row-axis HBM stride (= W), NOT the cross-channel cols. + # ``buf.hbm_stride`` was already set to the right value by + # AddressAllocationPass (via _row_stride_for_layout). + inner_tile_stride = buf.hbm_stride if buf.hbm_stride is not None else cols + + # ----- Multi-tile path: when the buffer's 4D logical shape needs + # the 7D physical layout (e.g. NCHW with C_OUT > 1), iterate the + # tile grid in canonical (D_TILES, S_TILES, H_GROUPS, B) order + # and emit one DMA per inner tile. HBM offsets per tile come from + # ``hbm_strides_for_layout`` so they correctly account for + # NCHW's channel-major HBM layout. + if len(buf.shape) == 4: + layout = make_tile_layout( + shape=tuple(int(x) for x in buf.shape), layout=buf.layout, + mlen=mlen, hlen=target.btmm_hlen, + ) + else: + layout = None + + shim.compiler.generated_code = ( + "\n; ============================================================\n" + f"; compare staging: {out_buffer_name} (HBM @ {buf.address}) -> VRAM[0..]\n" + ) + + if layout is not None: + hbm_b, hbm_s, hbm_h, _hbm_d = hbm_strides_for_layout( + buf.shape, buf.layout, + ) + shim.compiler.generated_code += ( + f"; tile_layout: d_tiles={layout.d_tiles} s_tiles={layout.s_tiles} " + f"h_groups={layout.h_groups} b={layout.logical_b}\n" + f"; ({mlen}x{mlen} per inner tile, layout={buf.layout})\n" + "; ============================================================\n" + ) + # Iteration order matches the legacy col-major-block-major + # ``for j in col_blocks: for i in row_blocks`` walk: outer is + # the col-axis tile (d_tile, then h_grp), inner is the + # row-axis tile (s_tile, then b). This keeps the per-tile + # VRAM landing position byte-identical to what the + # comparator's stride-mode reassembler assumes. + vram_addr = 0 + for d_tile in range(layout.d_tiles): + for h_grp in range(layout.h_groups): + for s_tile in range(layout.s_tiles): + for b in range(layout.logical_b): + hbm_off = ( + b * hbm_b + + s_tile * mlen * hbm_s + + h_grp * layout.lane_count * hbm_h + + d_tile * mlen + ) + shim.compiler.generated_code += ( + f"; stage tile (d={d_tile}, h={h_grp}, " + f"s={s_tile}, b={b}) hbm_off={hbm_off} " + f"-> vram[{vram_addr}]\n" + ) + emitter.emit_load_tile_from_hbm( + hbm_addr=buf.address, + vram_addr=vram_addr, + hbm_stride=inner_tile_stride, + hbm_scale_size=full_tensor_size, + hbm_start_offset=hbm_off, + ) + vram_addr += tile_elems + return shim.compiler.generated_code + + # ----- Single-tile fast path (non-4D, or 4D buffers that fit + # exactly one MLEN×MLEN tile) — same iteration as before. + # The MLEN-alignment check only applies here; the 4D path above + # walks a per-(inner-tile) grid where alignment is enforced + # tile-by-tile via the TileLayout, not on the logical-2D + # projection (which can be misleading for multi-channel NCHW / + # multi-head BSHD shapes that genuinely need the 7D physical + # layout to stage correctly). + if rows % mlen or cols % mlen: + raise SystemExit( + f"staging only supports mlen-aligned shapes for now, got " + f"rows={rows} cols={cols} mlen={mlen}" + ) + row_blocks = rows // mlen + col_blocks = cols // mlen + shim.compiler.generated_code += ( + f"; layout: rows={rows} cols={cols} -> {row_blocks}x{col_blocks} tiles " + f"({mlen}x{mlen} each), col-block-major\n" + "; ============================================================\n" + ) + vram_addr = 0 + for j in range(col_blocks): + for i in range(row_blocks): + hbm_offset_elems = i * mlen * cols + j * mlen + shim.compiler.generated_code += ( + f"; stage tile [{i},{j}] hbm_offset(elems)={hbm_offset_elems} " + f"-> vram[{vram_addr}]\n" + ) + emitter.emit_load_tile_from_hbm( + hbm_addr=buf.address, + vram_addr=vram_addr, + hbm_stride=inner_tile_stride, + hbm_scale_size=full_tensor_size, + hbm_start_offset=hbm_offset_elems, + ) + vram_addr += tile_elems + + return shim.compiler.generated_code + + +def _logical_2d(shape, layout: str = "BSHD") -> tuple[int, int]: + """Layout-aware (rows, cols) projection. Delegates to the shared + helper so __main__'s stage-output staging picks the same axes as + address_alloc / isa_pass.""" + from tilelang_tvm_compiler.hlir import logical_2d_extents + return logical_2d_extents(shape, layout) + + +def _cmd_compile(args: argparse.Namespace) -> int: + kernel_kwargs = _parse_kernel_kwargs(args.kernel_kwargs) + func = _resolve_kernel(args.kernel, kernel_kwargs) + target = PlenaTarget( + mlen=args.mlen, + blen=args.blen, + btmm_lane_count=args.btmm_lane_count, + btmm_hlen=args.btmm_hlen, + ) + midir_dump_dir = Path(args.dump_hlir).parent if args.dump_hlir else None + compiled = compile_kernel( + func, target=target, name=args.asm_name, + midir_dump_dir=midir_dump_dir, + use_v2=bool(getattr(args, "use_v2", False)), + ) + isa_text = compiled.isa_text + if args.stage_output: + isa_text = isa_text.rstrip() + _emit_output_staging( + compiled, target, args.stage_output, + ) + + if args.dump_hlir: + Path(args.dump_hlir).write_text(format_hlir(compiled.hlir)) + # Companion lowir report: per-op physical address expressions + # in their "last variable-form" (pre-gp). Op indices line up + # with the .hlir.txt just written. + from tilelang_tvm_compiler.hlir import format_lowir + Path(args.dump_hlir).with_name( + f"{args.asm_name}.lowir.txt" + ).write_text( + format_lowir(compiled.hlir, compiled.lowir_log or []) + ) + + # GP allocator trace: side-by-side TSV that any reader can align with + # the ASM dump via the ``asm_line`` column. Column order is fixed so + # consumers (humans, scripts) don't have to discover it. Always + # written into the same dir as ``--dump-hlir`` so step kernels' + # traces stay next to their ASM. + if args.dump_hlir and compiled.gp_trace: + trace_path = Path(args.dump_hlir).with_name( + f"{args.asm_name}.gp_trace.tsv" + ) + cols = [ + "asm_line", "event", "site", + "regs", "n", "slot", "addr", "spilled", + "free", "in_use", "pinned", + ] + lines = ["\t".join(cols)] + for row in compiled.gp_trace: + lines.append("\t".join(str(row.get(c, "")) for c in cols)) + trace_path.write_text("\n".join(lines) + "\n") + + if args.dump_buffer_addrs: + # Single source of truth for buffer addresses: dump the post + # AddressAllocationPass HLIR addresses as JSON for testbenches / + # external tooling to consume. Avoids the constants-dict-vs-actual + # drift that bit us in flash_decode_min (the FPRAM SCALE/M_INIT/ + # L_INIT addresses the testbench used were a hand-rolled mirror + # of `_slot_addresses`, off by 64 words from what TVM actually + # allocated, leading to head-1/2 numerical drift). + def _buf_entry(buf): + entry = { + "scope": buf.scope, + "address": buf.address, + "shape": [int(s) for s in buf.shape], + "dtype": str(buf.dtype), + } + # Auto-hoisted FP constants carry their compile-time value + # so the testbench harness can preload it without per-kernel + # boilerplate. See frontend/passes/hoist_float_constants.py. + if buf.constant_value is not None: + entry["value"] = float(buf.constant_value) + return entry + + addr_table = { + buf.name: _buf_entry(buf) + for buf in compiled.hlir.buffers.values() + } + Path(args.dump_buffer_addrs).write_text( + json.dumps(addr_table, indent=2) + ) + + if args.output: + Path(args.output).write_text(isa_text) + else: + sys.stdout.write(isa_text) + return 0 + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(prog="tilelang_tvm_compiler") + sub = parser.add_subparsers(dest="cmd", required=True) + + p_compile = sub.add_parser( + "compile", + help="Compile a TIR PrimFunc to PLENA ISA text.", + ) + p_compile.add_argument( + "--kernel", + required=True, + help='Kernel spec, e.g. "tilelang_tvm_compiler.kernels.minimal_btmm:minimal_btmm"; ' + 'may also point at a factory function used together with --kernel-kwargs', + ) + p_compile.add_argument( + "--kernel-kwargs", + default=None, + help="Comma-separated k=v pairs to pass when --kernel resolves to a " + "factory (e.g. `seq_q=128,seq_k=128`). Values are coerced to int " + "when possible.", + ) + p_compile.add_argument("--asm-name", default="kernel") + p_compile.add_argument("--output", default=None, + help="If given, write ISA to this path; else stdout.") + # Hardware sizes default to plena_settings.toml's active mode; + # an explicit flag still overrides for a one-off non-default run. + from .plena_settings import load_sizes as _load_sizes + _hw = _load_sizes() + p_compile.add_argument("--mlen", type=int, default=_hw.mlen) + p_compile.add_argument("--blen", type=int, default=_hw.blen) + p_compile.add_argument("--btmm-lane-count", type=int, + default=_hw.hardware_lane_count) + p_compile.add_argument("--btmm-hlen", type=int, default=_hw.hlen) + p_compile.add_argument( + "--stage-output", + default=None, + help="If given, append per-tile DMA HBM->VRAM[0..] for this output " + "buffer so view_mem.py can compare against golden.", + ) + p_compile.add_argument( + "--dump-hlir", + default=None, + help="If given, write a human-readable HLIR dump to this path " + "(after address allocation; final form fed into ISA emit).", + ) + p_compile.add_argument( + "--dump-buffer-addrs", + default=None, + help="If given, write a JSON dict {buffer_name: {scope, address, " + "shape, dtype}} so testbenches can read the *actual* allocated " + "addresses instead of mirroring them by hand. Single source of " + "truth for FPRAM / VRAM / MRAM / HBM offsets.", + ) + p_compile.add_argument( + "--use-v2", + action="store_true", + help="Route through the PreIsaPassV2 → MIR → ISA pipeline " + "instead of the legacy IsaEmitterPass single-pass. Same " + "HW op stream; tighter register allocation; MIR dump " + "available for debugging.", + ) + p_compile.set_defaults(func=_cmd_compile) + + args = parser.parse_args(argv) + return args.func(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tilelang_tvm_compiler/address_alloc.py b/tilelang_tvm_compiler/address_alloc.py new file mode 100644 index 0000000..b68c4af --- /dev/null +++ b/tilelang_tvm_compiler/address_alloc.py @@ -0,0 +1,250 @@ +"""Pass 2: assign physical addresses to every HLIR buffer. + +Three independent bump allocators (one per memory space): + - HBM : starts at HBM_BASE, advances by buffer.byte_size + - VRAM : starts at 0, advances by buffer.num_elements + - MRAM : starts at 0, advances by buffer.num_elements + - FPRAM : starts at FPRAM_USER_BASE, advances by buffer.num_elements + +Bump-only is sufficient for the kernels we care about right now (no +buffer is reused after its last op). When we start emitting kernels with +long-lived staging buffers we'll swap this for a liveness-aware allocator; +the pass interface won't change. + +We also fill in stride/scale defaults for every HBM buffer: + hbm_stride <- mlen + hbm_scale_size <- mlen * mlen (i.e. tile_elems) +The runtime emitter applies the same defaults internally; setting them +here makes the values explicit in HLIR so a debug dump shows what the +emitter will actually use. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from typing import Dict, Tuple + +from . import hlir as _hlir +from . import scope as _scope + + +def _row_stride_for_layout( + shape: Tuple[int, ...], layout: str, *, fallback: int, +) -> int: + """Element distance between row r and row r+1 *within the same + channel* of a 4D buffer laid out row-major in HBM under ``layout``. + + For BSHD (B, S, H, D): S → S+1 advances H*D elements (same as cols + in the logical 2D collapse). + For NCHW (N, C, H, W): H → H+1 advances W elements (NOT C*W). + + Falls back to ``fallback`` for non-4D shapes (where there's no + layout-specific notion of "row stride within a channel"). + """ + if len(shape) != 4: + return fallback + # The row dim's stride is just the product of every dim that lies + # AFTER it in the source layout's row-major order. + bi, ri, _ci, _di = _hlir.LAYOUT_AXES[layout] + stride = 1 + for i in range(ri + 1, len(shape)): + stride *= int(shape[i]) + return stride + + +def _logical_2d(shape: Tuple[int, ...], layout: str = "BSHD") -> Tuple[int, int]: + """Collapse N-D shape -> (rows, cols) using ``layout``. + + Thin wrapper around ``hlir.logical_2d_extents``. For BSHD the legacy + "merge last two as cols, fold the rest into rows" heuristic matches + (and is used directly for non-4D shapes). For 4D NCHW the row dim + is axis 2 (not 1), so we permute via ``LAYOUT_AXES``. + + For BTMM, GROUP_HEADS narrow heads of width HLEN pack into one + mlen-wide tile (GROUP_HEADS*HLEN == mlen). HBM has them contiguous + on the innermost dims so the merge is a free reinterpretation — + no data movement. + """ + return _hlir.logical_2d_extents(shape, layout) + + +# Conservative defaults. Pick non-zero HBM base so address-zero stays +# reserved (handy when debugging null-pointer-style bugs in emitted ISA). +_HBM_BASE = 0x0000 +_VRAM_BASE = 0 +_MRAM_BASE = 0 +# Runtime compiler reserves the first 32 FP slots for system/hardware +# constants and expects them to stay zero-initialized. TVM-generated +# kernels must honor the same contract or FPSRAM preloads/results end up +# shifted relative to the emulator/runtime view. +FPRAM_USER_BASE = 32 +_FPRAM_BASE = FPRAM_USER_BASE + + +@dataclass +class AddressAllocConfig: + mlen: int + blen: int + hlen: int = 16 # narrow head dim — typically MLEN/4. Used for + # tile_layout detection on 4D BSHD-shaped local + # buffers. Default matches PlenaTarget.btmm_hlen. + hbm_base: int = _HBM_BASE + vram_base: int = _VRAM_BASE + mram_base: int = _MRAM_BASE + fpram_base: int = _FPRAM_BASE + + # HBM packing parameters for the BEHAVIOR sim. These mirror the + # plena_settings.toml [BEHAVIOR.PRECISION.HBM_*_TYPE] entries that + # `create_mem_for_sim` -> `map_mx_data_to_hbm_for_behave_sim` use to + # lay out tensors in `hbm_for_behave_sim.bin`. We must match them + # exactly here, otherwise our ISA's HBM addresses point into the wrong + # tensor (or padding) and emulator reads garbage. + hbm_row_width: int = 512 # bytes -- BEHAVIOR.CONFIG.HBM_WIDTH + hbm_elem_bits: int = 8 # FP4: sign(1)+exp(4)+mant(3) = 8 bits per element + hbm_scale_bits: int = 8 # 1 byte per scale (Fp(s=0,e=8,m=0)) + hbm_block_size: int = 8 # 1 scale per 8 elements + + # Per-buffer address pins. Used by multi-kernel drivers (e.g. + # tvm_single_stream_block_test) that pre-plan a global HBM layout + # and need each kernel's HBM tensors to land on specific bytes so + # producer.output_addr == consumer.input_addr. Buffer names not in + # the dict fall back to the bump allocator. Pinned addresses do NOT + # advance the bump cursor — the driver is responsible for not + # double-booking bytes. + hbm_address_overrides: Dict[str, int] = field(default_factory=dict) + fpram_address_overrides: Dict[str, int] = field(default_factory=dict) + + @property + def tile_elems(self) -> int: + return self.mlen * self.mlen + + +def _align_up(n: int, mul: int) -> int: + return ((n + mul - 1) // mul) * mul + + +def _hbm_packed_byte_size(num_elements: int, cfg: "AddressAllocConfig") -> int: + """Bytes a single tensor occupies in hbm_for_behave_sim.bin after + `map_mx_data_to_hbm_for_behave_sim` packs it. + + Layout: [element bytes, padded to hbm_row_width][scale bytes, padded to + hbm_row_width][total padded to 64 bytes]. + """ + elem_bytes = num_elements * cfg.hbm_elem_bits // 8 + elem_bytes = _align_up(elem_bytes, cfg.hbm_row_width) + + num_scales = num_elements // cfg.hbm_block_size + scale_bytes = num_scales * cfg.hbm_scale_bits // 8 + if scale_bytes: + scale_bytes = _align_up(scale_bytes, cfg.hbm_row_width) + + return _align_up(elem_bytes + scale_bytes, 64) + + +class AddressAllocationPass: + def __init__(self, cfg: AddressAllocConfig) -> None: + self.cfg = cfg + + def run(self, mod: _hlir.HLIRModule) -> _hlir.HLIRModule: + hbm_cur = self.cfg.hbm_base + vram_cur = self.cfg.vram_base + mram_cur = self.cfg.mram_base + fpram_cur = self.cfg.fpram_base + + for buf in mod.buffers.values(): + # Collapse `global.` to `` for residency decisions — + # a `global.vram` buffer allocates from the same VRAM pool as a + # regular `vram` buffer; the user-declared global flag only + # affects lane-fusion expansion (in allocate_group_memory). + phys = _scope.physical_scope(buf.scope) + if phys == _scope.HBM: + override = self.cfg.hbm_address_overrides.get(buf.name) + if override is not None: + # Pinned by the driver — don't advance the bump cursor. + buf.address = int(override) + else: + buf.address = hbm_cur + # IMPORTANT: increment by the MXFP-packed byte size, not by + # the raw fp16 buf.byte_size. `create_mem_for_sim` packs + # tensors into hbm_for_behave_sim.bin using FP4 elements + # (1 byte each) plus 1/8 byte scales, padded to row width. + # If we use buf.byte_size here our HBM addresses won't match + # what's actually on disk and H_PREFETCH_M reads garbage. + hbm_cur += _hbm_packed_byte_size(buf.num_elements, self.cfg) + rows, cols = _logical_2d(buf.shape, buf.layout) + # stride = HBM-row-major distance from canonical row r + # to row r+1 of the same channel (NOT cols, when those + # differ). + # + # For BSHD (B, S, H, D) the row dim S is the outer of + # the row/col pair, so stride = H*D = cols. ✓ + # + # For NCHW (N, C, H, W) the row dim H sits BETWEEN the + # channel C and the col W. Going H → H+1 within the + # same channel is W elements; cols = C*W is the + # cross-channel collapse. Using cols here would jump a + # full channel-width worth of elements per row → wrong + # data on every per-tile DMA stride between rows. + if buf.hbm_stride is None: + buf.hbm_stride = _row_stride_for_layout( + buf.shape, buf.layout, fallback=cols, + ) + # scale_size = total elements of the HBM region (rows*cols), + # NOT one tile. The runtime ValueManager always uses the HBM + # backing object's full shape product; defaulting to + # mlen*mlen only happens to be correct when the buffer is + # exactly one mlen-square tile. For our multi-tile buffers + # (e.g. C_hbm = 64x256) this MUST be 16384, not 4096, or the + # emulator's HBM addressing wraps wrong. + if buf.hbm_scale_size is None: + buf.hbm_scale_size = rows * cols + # Stash the logical 2D dims as annotations the ISA pass + # can read to decide per-tile decomposition. + buf.annotations["logical_rows"] = rows + buf.annotations["logical_cols"] = cols + buf.annotations["row_blocks"] = max(1, rows // self.cfg.mlen) + buf.annotations["col_blocks"] = max(1, cols // self.cfg.mlen) + elif phys == _scope.VRAM: + buf.address = vram_cur + vram_cur += buf.num_elements + # Detect 4D buffers that need multi-tile physical + # storage. None for 2D/1D shapes or single-tile-fitting + # shapes — caller falls back to row-major (existing + # kernels' single-tile case is unaffected). The buffer's + # ``layout`` attribute (set by codegen from + # ``T.func_attr({"plena.layout": ...})``) controls how + # the 4D axes map to the canonical (B, S, H, D) tile + # roles. + if len(buf.shape) == 4 and not buf.is_pinned_global: + buf.tile_layout = _hlir.make_tile_layout( + shape=tuple(int(x) for x in buf.shape), + layout=buf.layout, + mlen=self.cfg.mlen, hlen=self.cfg.hlen, + ) + elif phys == _scope.MRAM: + buf.address = mram_cur + mram_cur += buf.num_elements + if len(buf.shape) == 4 and not buf.is_pinned_global: + buf.tile_layout = _hlir.make_tile_layout( + shape=tuple(int(x) for x in buf.shape), + layout=buf.layout, + mlen=self.cfg.mlen, hlen=self.cfg.hlen, + ) + elif phys == _scope.FPRAM: + # FPRAM stores scalar FP values; address them in element units + # to match S_LD_FP / S_ST_FP and the emulator's fpsram indexing. + override = self.cfg.fpram_address_overrides.get(buf.name) + if override is not None: + buf.address = int(override) + else: + buf.address = fpram_cur + fpram_cur += buf.num_elements + else: + raise ValueError(f"buffer {buf.name!r}: unknown scope {buf.scope!r}") + + _hlir.assert_addresses_resolved(mod) + return mod + + +__all__ = ["AddressAllocConfig", "AddressAllocationPass", "FPRAM_USER_BASE"] diff --git a/tilelang_tvm_compiler/backend_emit.py b/tilelang_tvm_compiler/backend_emit.py new file mode 100644 index 0000000..e8fc1f2 --- /dev/null +++ b/tilelang_tvm_compiler/backend_emit.py @@ -0,0 +1,1245 @@ +"""BackendEmit — consume PreIsaIR, produce final ISA text. + +This is the second half of the IsaPass split: + + Old: HLIR -> IsaEmitterPass (algebra + materialise + emit) -> ISA + + New: HLIR -> PreIsaPass (algebra; produce PreIsaIR) + -> pre_isa_optimize (arith.simplify; CSE; LICM) + -> BackendEmit (materialise each operand; emit ISA) + +BackendEmit owns the GP register / ISA-text wiring. For each PreIsaOp +in the input stream: + * Materialise every ``tir.PrimExpr`` operand via the existing + ``ExprMaterializer`` (which handles symbol_table lookup, GP alloc, + eager auto-spill, constant folding, the lot). + * Plug the resulting GP register numbers into a per-opcode ISA + template string and append one ISA line to + ``shim.compiler.generated_code``. + * Release the operand GPs after the emit. + +The opcode dispatch table lives in ``_TEMPLATES`` below. Each entry +declares the ISA mnemonic's operand layout — which operand slots are +materialised (PrimExpr -> GP) vs dropped in verbatim (already a +literal int / a hard-coded ``f0`` / ``a3`` token / a flag). +Adding a new HW instruction to the BackendEmit means adding one row +to ``_TEMPLATES`` and (if needed) the mnemonic to ``KNOWN_OPCODES`` +in ``pre_isa_ir.py``. + +Loop handling: ``LOOP_START`` / ``LOOP_END`` are the PreIsaIR +control markers. The PreIsaOp on ``LOOP_START`` carries +``annotations["loop_kind"]`` which selects the codegen strategy: + + * ``"serial"`` (default) — emit a hardware loop. Claims an IntRAM + idx slot, binds the loop_var into ``symbol_table`` as + ``("ram", idx_addr)``, and emits the counter init + the literal + ISA mnemonic ``C_LOOP_START gp_loop, extent``. The matching + ``LOOP_END`` emits the idx-increment + ``C_LOOP_END gp_loop`` + and releases the slot. Byte-equal to legacy ``_emit_for``'s + serial branch. + + * ``"unroll"`` — bind loop_var to a per-iteration ``tir.IntImm`` + and recursively emit the body PreIsaOps N times. No idx slot, + no loop_gp use, no hardware C_LOOP_* lines. Byte-equal to legacy + ``_emit_for``'s unrolled branch. + +Switching the strategy is a single annotation edit at any PreIsaIR +optimisation pass — the IR shape doesn't change. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple + +from tvm import tir + +from .expr_materializer import ExprMaterializer, MaterializedExpr +from .pre_isa_ir import PreIsaModule, PreIsaOp +from .program_shim import ProgramShim + + +class BackendEmitError(RuntimeError): + pass + + +# Sentinel used by ``_invoke_slot`` to signal that a slot's GP register +# came from the group cache, NOT from a fresh materialise(). The +# caller's per-emit cleanup must skip these (they're released at +# group close, not per emit). +_CACHED_SENTINEL = object() + + +# Operand-slot descriptor: a Python callable that, given the slot +# value (PrimExpr / int / str) and the materialiser, returns the +# token to drop into the ISA template ("gpN" / literal / "f0" / ...) +# AND an optional ``MaterializedExpr`` to release after the emit. +def _slot_expr( + val: Any, mat: ExprMaterializer, +) -> Tuple[str, Optional[MaterializedExpr]]: + """A ``PrimExpr`` operand → materialise to gpN, return (``"gpN"``, + handle). A plain ``int`` also goes through materialise so loop-vars + and large literals get the proper S_ADDI/S_LUI sequence the + existing emitter expects.""" + if isinstance(val, (int, tir.PrimExpr)): + m = mat.materialize(val) + return f"gp{m.register}", m + raise BackendEmitError( + f"slot_expr: expected PrimExpr / int, got {type(val).__name__} {val!r}" + ) + + +def _slot_literal_int( + val: Any, mat: ExprMaterializer, +) -> Tuple[str, Optional[MaterializedExpr]]: + """An immediate field of the ISA encoding — must be a compile-time + int literal, dropped verbatim into the template.""" + if isinstance(val, int): + return str(val), None + if isinstance(val, tir.IntImm): + return str(int(val.value)), None + raise BackendEmitError( + f"slot_literal_int: ISA literal must be a compile-time int; " + f"got {type(val).__name__} {val!r}" + ) + + +def _slot_verbatim( + val: Any, mat: ExprMaterializer, +) -> Tuple[str, Optional[MaterializedExpr]]: + """A hard-coded token (e.g. ``"f0"`` for the zero FPRAM register, + ``"a3"`` for an addr-reg slot). Caller already wrote the string + they want in the ISA; we just drop it in.""" + if isinstance(val, str): + return val, None + raise BackendEmitError( + f"slot_verbatim: expected str token; got {type(val).__name__} {val!r}" + ) + + +# Marker slot kind: the BackendEmit handler treats it specially in +# _invoke_slot. Looks up the operand PrimExpr in the current group +# cache and returns the cached GP token; never materialises fresh. +# Used by row_*_at's destructive in-place stride bump pattern: the +# V_*_VF / V_RED_* / V_EXP_V iterations all read the same GP that an +# earlier _PRELOAD_ADDR established, and _BUMP_CACHED_GP mutates that +# GP between iterations. +def _slot_expr_cached( + val: Any, mat: ExprMaterializer, +) -> Tuple[str, Optional[MaterializedExpr]]: + raise BackendEmitError( + "_slot_expr_cached must be intercepted by BackendEmit._invoke_slot" + ) + + +# Marker slot kind: the BackendEmit handler looks up the operand +# PrimExpr in the ADDR-REG cache (populated by ``_PRELOAD_ADDR_REG``) +# and returns the cached ``aN`` token. Used by DMA HW instructions +# (H_PREFETCH_V / H_STORE_V / H_PREFETCH_M). +def _slot_addr_reg_cached( + val: Any, mat: ExprMaterializer, +) -> Tuple[str, Optional[MaterializedExpr]]: + raise BackendEmitError( + "_slot_addr_reg_cached must be intercepted by BackendEmit._invoke_slot" + ) + + +@dataclass +class _Template: + """One opcode's emit template. + + ``slots`` is a list of (slot-kind-callable) entries, one per + operand on the PreIsaOp. ``fmt`` is the ISA-line format string + with positional placeholders ``{0}`` ... ``{N-1}`` corresponding + to the slots' rendered tokens. + """ + slots: List[Callable[[Any, ExprMaterializer], Tuple[str, Optional[MaterializedExpr]]]] + fmt: str + + +# Per-opcode dispatch table. **Only opcodes whose handler has been +# migrated to the PreIsaPass producer appear here.** As more handlers +# migrate (one PR-or-commit at a time, byte-equal verified per op), +# new rows land in this table. +_TEMPLATES: Dict[str, _Template] = { + # FP scalar load/store: ``S_LD_FP fX, gp{addr}, 0`` / + # ``S_ST_FP fX, gp{addr}, 0``. First operand is the FP register + # token (verbatim), second is the address (PrimExpr), third is the + # element offset literal (always 0 in current emit, kept as + # operand for future flexibility). + "S_LD_FP": _Template( + slots=[_slot_verbatim, _slot_expr, _slot_literal_int], + fmt="S_LD_FP {0}, {1}, {2}", + ), + "S_ST_FP": _Template( + slots=[_slot_verbatim, _slot_expr, _slot_literal_int], + fmt="S_ST_FP {0}, {1}, {2}", + ), + # FP scalar binary ops: ``OP f_dst, f_lhs, f_rhs`` — all three + # operands are FP register tokens. + "S_ADD_FP": _Template( + slots=[_slot_verbatim, _slot_verbatim, _slot_verbatim], + fmt="S_ADD_FP {0}, {1}, {2}", + ), + "S_SUB_FP": _Template( + slots=[_slot_verbatim, _slot_verbatim, _slot_verbatim], + fmt="S_SUB_FP {0}, {1}, {2}", + ), + "S_MUL_FP": _Template( + slots=[_slot_verbatim, _slot_verbatim, _slot_verbatim], + fmt="S_MUL_FP {0}, {1}, {2}", + ), + "S_MAX_FP": _Template( + slots=[_slot_verbatim, _slot_verbatim, _slot_verbatim], + fmt="S_MAX_FP {0}, {1}, {2}", + ), + # FP scalar unary ops. Note legacy emitter is inconsistent: S_EXP_FP + # takes a trailing ``0`` flag operand, S_RECI_FP / S_SQRT_FP do not. + "S_EXP_FP": _Template( + slots=[_slot_verbatim, _slot_verbatim, _slot_literal_int], + fmt="S_EXP_FP {0}, {1}, {2}", + ), + "S_RECI_FP": _Template( + slots=[_slot_verbatim, _slot_verbatim], + fmt="S_RECI_FP {0}, {1}", + ), + "S_SQRT_FP": _Template( + slots=[_slot_verbatim, _slot_verbatim], + fmt="S_SQRT_FP {0}, {1}", + ), + # Vector ops. ``V_*_VV`` / ``V_*_VF`` all take a trailing literal + # flag (almost always 0 in current emit). The "VV" variant takes + # gp-gp-gp (dst, lhs, rhs); the "VF" variant takes gp-gp-fpram_reg. + "V_ADD_VV": _Template( + slots=[_slot_expr, _slot_expr, _slot_expr, _slot_literal_int], + fmt="V_ADD_VV {0}, {1}, {2}, {3}", + ), + "V_SUB_VV": _Template( + slots=[_slot_expr, _slot_expr, _slot_expr, _slot_literal_int], + fmt="V_SUB_VV {0}, {1}, {2}, {3}", + ), + "V_MUL_VV": _Template( + slots=[_slot_expr, _slot_expr, _slot_expr, _slot_literal_int], + fmt="V_MUL_VV {0}, {1}, {2}, {3}", + ), + "V_MUL_VF": _Template( + slots=[_slot_expr, _slot_expr, _slot_verbatim, _slot_literal_int], + fmt="V_MUL_VF {0}, {1}, {2}, {3}", + ), + "V_ADD_VF": _Template( + slots=[_slot_expr, _slot_expr, _slot_verbatim, _slot_literal_int], + fmt="V_ADD_VF {0}, {1}, {2}, {3}", + ), + "V_SUB_VF": _Template( + slots=[_slot_expr, _slot_expr, _slot_verbatim, _slot_literal_int], + fmt="V_SUB_VF {0}, {1}, {2}, {3}", + ), + "V_EXP_V": _Template( + slots=[_slot_expr, _slot_expr, _slot_literal_int], + fmt="V_EXP_V {0}, {1}, {2}", + ), + "V_RECI_V": _Template( + slots=[_slot_expr, _slot_expr, _slot_literal_int], + fmt="V_RECI_V {0}, {1}, {2}", + ), + "V_SQRT_V": _Template( + slots=[_slot_expr, _slot_expr, _slot_literal_int], + fmt="V_SQRT_V {0}, {1}, {2}", + ), + # VRAM <-> FPRAM transfer (S_MAP_*_*). Both directions take + # (gp_dst_addr, gp_src_addr, 0) — the gp values are FPRAM / + # VRAM addresses, so both slots are PrimExpr. + "S_MAP_FP_V": _Template( + slots=[_slot_expr, _slot_expr, _slot_literal_int], + fmt="S_MAP_FP_V {0}, {1}, {2}", + ), + "S_MAP_V_FP": _Template( + slots=[_slot_expr, _slot_expr, _slot_literal_int], + fmt="S_MAP_V_FP {0}, {1}, {2}", + ), + # Mask register control. The operand is a GP holding the new mask + # value. legacy emits this twice per masked row op: once before to + # arm, once after to reset to 0. + "C_SET_V_MASK_REG": _Template( + slots=[_slot_expr_cached], + fmt="C_SET_V_MASK_REG {0}", + ), + # Vector reduce — accumulates into f1. Legacy ``V_RED_*`` takes + # (f_acc, gp_src_vec, mask_flag). + "V_RED_MAX": _Template( + slots=[_slot_verbatim, _slot_expr_cached, _slot_literal_int], + fmt="V_RED_MAX {0}, {1}, {2}", + ), + "V_RED_SUM": _Template( + slots=[_slot_verbatim, _slot_expr_cached, _slot_literal_int], + fmt="V_RED_SUM {0}, {1}, {2}", + ), + # S_SUB_VF for row_sub_fp: legacy emits a 5-operand form + # ``V_SUB_VF gp{dst}, gp{src}, f1, mask_flag, 0``. The trailing 0 + # is a literal that doesn't appear in V_ADD_VF / V_MUL_VF (those + # are 4-operand). + "_V_SUB_VF_ROW": _Template( + slots=[_slot_expr_cached, _slot_expr_cached, _slot_verbatim, + _slot_literal_int, _slot_literal_int], + fmt="V_SUB_VF {0}, {1}, {2}, {3}, {4}", + ), + # Same as _V_SUB_VF_ROW but for V_ADD_VF / V_MUL_VF with the cached + # GP operand pattern (used by row_*_at's d_tile unroll loop). + "_V_ADD_VF_ROW": _Template( + slots=[_slot_expr_cached, _slot_expr_cached, _slot_verbatim, + _slot_literal_int], + fmt="V_ADD_VF {0}, {1}, {2}, {3}", + ), + "_V_MUL_VF_ROW": _Template( + slots=[_slot_expr_cached, _slot_expr_cached, _slot_verbatim, + _slot_literal_int], + fmt="V_MUL_VF {0}, {1}, {2}, {3}", + ), + # V_EXP_V / V_RECI_V with cached GPs (no fresh materialise). + "_V_EXP_V_ROW": _Template( + slots=[_slot_expr_cached, _slot_expr_cached, _slot_literal_int], + fmt="V_EXP_V {0}, {1}, {2}", + ), + "_V_RECI_V_ROW": _Template( + slots=[_slot_expr_cached, _slot_expr_cached, _slot_literal_int], + fmt="V_RECI_V {0}, {1}, {2}", + ), + # S_ADDI_INT for setting up mask reset (gp_mask, gp0, 0). Cached + # form (gp_mask is in the group cache from an earlier + # _PRELOAD_ADDR / materialise). Prefix avoids colliding with future + # uses of S_ADDI_INT that take freshly-materialised GPs. + "_S_ADDI_INT_RESET_MASK": _Template( + slots=[_slot_expr_cached, _slot_verbatim, _slot_literal_int], + fmt="S_ADDI_INT {0}, {1}, {2}", + ), + # S_LD_FP / S_ST_FP variants where the FP-side GP is cached (an + # FPRAM destination address staying live across multiple ops). + # ``f1, gp_cached, 0`` and ``gp_cached, f1, 0`` patterns. + "_S_LD_FP_CACHED": _Template( + slots=[_slot_verbatim, _slot_expr_cached, _slot_literal_int], + fmt="S_LD_FP {0}, {1}, {2}", + ), + "_S_ST_FP_CACHED": _Template( + slots=[_slot_verbatim, _slot_expr_cached, _slot_literal_int], + fmt="S_ST_FP {0}, {1}, {2}", + ), + # Matrix ops. Legacy emit_btmm form: + # M_BTMM gp0, gp{rhs_mram_base}, gp{lhs_packed_vram_base} + # All three operand slots use cached-GP (the two base addresses + # were just preloaded via S_ADDI_INT; gp0 is the constant-zero + # verbatim token used as a dummy result accumulator). + "M_BTMM": _Template( + slots=[_slot_verbatim, _slot_expr_cached, _slot_expr_cached], + fmt="M_BTMM {0}, {1}, {2}", + ), + # Legacy emit_btmm_wo form: M_BMM_WO gp{out_base}, 0 + "M_BMM_WO": _Template( + slots=[_slot_expr_cached, _slot_literal_int], + fmt="M_BMM_WO {0}, {1}", + ), + # Legacy emit_btmv form (clone of emit_btmm with M_BTMV mnemonic): + # M_BTMV gp0, gp{rhs_mram_base}, gp{lhs_packed_vram_base} + "M_BTMV": _Template( + slots=[_slot_verbatim, _slot_expr_cached, _slot_expr_cached], + fmt="M_BTMV {0}, {1}, {2}", + ), + # Legacy emit_bmv_wo form: M_BMV_WO gp{out}, 0 + "M_BMV_WO": _Template( + slots=[_slot_expr_cached, _slot_literal_int], + fmt="M_BMV_WO {0}, {1}", + ), + # M_MV / M_MV_WO — per-iteration HW ops in emit_mv. + # M_MV gp0, gp{rhs_base}, gp{lhs_base} -- both bases cached + # M_MV_WO gp{dst_base}, 0 + "M_MV": _Template( + slots=[_slot_verbatim, _slot_expr_cached, _slot_expr_cached], + fmt="M_MV {0}, {1}, {2}", + ), + "M_MV_WO": _Template( + slots=[_slot_expr_cached, _slot_literal_int], + fmt="M_MV_WO {0}, {1}", + ), + # M_MM / M_MM_WO — emitted by mm / matmul handlers. + # M_MM 0, gp{mat_col_base}, gp{act_row_base} + # M_MM_WO gp{result}, gp0, 0 + "M_MM": _Template( + slots=[_slot_literal_int, _slot_expr_cached, _slot_expr_cached], + fmt="M_MM {0}, {1}, {2}", + ), + "M_MM_WO": _Template( + slots=[_slot_expr_cached, _slot_verbatim, _slot_literal_int], + fmt="M_MM_WO {0}, {1}, {2}", + ), + # M_TMM — transposed matmul: + # M_TMM 0, gp{act_vram}, gp{mat_mram} + # (note: operand order swapped vs M_MM — rs1 is lhs, rs2 is rhs). + "M_TMM": _Template( + slots=[_slot_literal_int, _slot_expr_cached, _slot_expr_cached], + fmt="M_TMM {0}, {1}, {2}", + ), + # DMA control / data-movement instructions. + # + # C_SET_SCALE_REG gp{r} — set scale length (gp{r} holds it) + # C_SET_STRIDE_REG gp{r} — set stride length + # C_SET_ADDR_REG aN, gp0, gp{r} — load addr-reg ``aN`` from gp{r} + # + # The ``aN`` (PLENA addr register) appears as a verbatim string + # operand on the PreIsaOp (chosen by the producer at allocation + # time). The legacy ``gp0`` constant-zero source in C_SET_ADDR_REG + # stays hardcoded in the template (matches PLENA HW encoding). + "C_SET_SCALE_REG": _Template( + slots=[_slot_expr_cached], + fmt="C_SET_SCALE_REG {0}", + ), + "C_SET_STRIDE_REG": _Template( + slots=[_slot_expr_cached], + fmt="C_SET_STRIDE_REG {0}", + ), + # C_SET_ADDR_REG is handled internally by _PRELOAD_ADDR_REG; no + # separate template needed (the meta-op emits the C_SET_ADDR_REG + # text directly during its handler). + # + # H_PREFETCH_V — HBM→VRAM tile prefetch. Two operand-count forms + # in legacy emit (5-op batch>1 / 6-op batch=1): + # 5-op: H_PREFETCH_V gp{result}, gp{a_off}, aN, 1, 0 + # 6-op: H_PREFETCH_V gp{result}, gp{a_actual}, aN, 0, 0, 0 + # ``aN`` (3rd slot) is the addr-reg token resolved from the + # addr-reg cache (populated by an earlier _PRELOAD_ADDR_REG of the + # same PrimExpr object). + "H_PREFETCH_V": _Template( + slots=[_slot_expr_cached, _slot_expr_cached, _slot_addr_reg_cached, + _slot_literal_int, _slot_literal_int], + fmt="H_PREFETCH_V {0}, {1}, {2}, {3}, {4}", + ), + "_H_PREFETCH_V_6OP": _Template( + slots=[_slot_expr_cached, _slot_expr_cached, _slot_addr_reg_cached, + _slot_literal_int, _slot_literal_int, _slot_literal_int], + fmt="H_PREFETCH_V {0}, {1}, {2}, {3}, {4}, {5}", + ), + # H_PREFETCH_M — HBM→MRAM tile prefetch: + # H_PREFETCH_M gp{mram}, gp{scale}, aN, 1, 0 + "H_PREFETCH_M": _Template( + slots=[_slot_expr_cached, _slot_expr_cached, _slot_addr_reg_cached, + _slot_literal_int, _slot_literal_int], + fmt="H_PREFETCH_M {0}, {1}, {2}, {3}, {4}", + ), + # H_STORE_V — VRAM→HBM tile store: + # H_STORE_V gp{vram}, gp{hbm_off}, aN, , 0 + "H_STORE_V": _Template( + slots=[_slot_expr_cached, _slot_expr_cached, _slot_addr_reg_cached, + _slot_literal_int, _slot_literal_int], + fmt="H_STORE_V {0}, {1}, {2}, {3}, {4}", + ), + # More entries land here as handlers migrate. +} + + +class BackendEmit: + """Walks a ``PreIsaModule`` and produces final ISA text. + + Construction takes a fully wired ``ProgramShim`` (same one the old + ``IsaEmitterPass`` uses) so the materialiser sees the same register + allocator + generated_code sink. + + Call ``run(pre_isa_mod)`` to drive emission; the resulting ISA text + is read from ``shim.compiler.generated_code`` (a string under the + new architecture — no CapturingCode proxy). + + Materialisation grouping + ------------------------ + The legacy ``IsaEmitterPass`` would materialise each HLIR op's + address operands ONCE per ``begin_op`` / ``end_op`` scope and reuse + the resulting GP registers across every ISA line that op emitted + (e.g. an FP binary ``_at`` op reuses 3 GPs across 5 ISA lines + via ``ra.pin_gp``). + + In PreIsaIR each of those 5 ISA lines is its OWN PreIsaOp, so a + naive per-PreIsaOp ``begin_op`` / ``end_op`` would re-materialise + the same address 5 times — emitting 5x as much address-setup ISA + and breaking byte-equality with the legacy path. + + BackendEmit therefore groups consecutive PreIsaOps that share a + materialisation scope. The grouping is driven by an integer + ``annotations["group_id"]`` stamped by the PreIsaPass producer: + every PreIsaOp produced by ONE call to a legacy-style handler + (i.e. ONE HLIR op) shares one ``group_id``. BackendEmit opens a + fresh materialiser scope on a ``group_id`` transition and closes + it on the next transition; PreIsaOps with no ``group_id`` (e.g. + free-standing comments) inherit the current scope without + transitioning. + + Within a group, repeated occurrences of the SAME Python expression + object (id()) are materialised once and the resulting GP register + cached — this is how an FP-binary's 3 addresses survive across its + 5 ISA lines with one S_ADDI_INT each rather than five. + """ + + def __init__(self, shim: ProgramShim) -> None: + self.shim = shim + self.symbol_table: Dict[tir.Var, Any] = {} + # Bind the hw-shape constants (mlen, blen, btmm_hlen, ...) into + # the symbol_table as IntImms taken from the shim. PreIsaIR + # producers use the symbolic tir.Vars from ``hw_consts``; + # ExprMaterializer's ``_peephole_const_fold`` substitutes them + # at materialise time so the final ISA has the hardware's + # current numeric values folded in — but PreIsaIR itself stays + # algebraic. See ``hw_consts.py`` for the design. + from .hw_consts import HW_CONST_ATTRS + for var, attr in HW_CONST_ATTRS.items(): + self.symbol_table[var] = tir.IntImm( + "int32", int(getattr(shim, attr)), + ) + self.materializer = ExprMaterializer(shim, self.symbol_table) + # Group materialisation cache STACK. Each entry represents one + # open materialisation scope: a dict keyed by ``id(prim_expr)`` + # with values ``(gp_reg, MaterializedExpr)``. + # + # The stack supports nested scopes — outer scopes' cached + # entries are visible to lookups while inner scopes are open + # (lookup walks the stack from top to bottom). This is what + # lets a producer preload an address in an outer iteration + # (e.g. matmul's per-oc ``mat_addr``) and reference it from + # PreIsaOps inside an inner unroll body without the inner + # ``per_iter`` close clobbering the outer entry. + # + # When a group is opened with ``_open_group`` a fresh empty + # dict is pushed; ``_close_group`` pops the top dict and frees + # its entries (NOT entries in any outer scope). ``_slot_expr`` + # caches into the TOP scope; ``_slot_expr_cached`` looks up + # through the whole stack. + self._group_stack: List[Dict[int, Tuple[int, MaterializedExpr]]] = [] + self._group_id_stack: List[Optional[Any]] = [] + # Per-scope close order. Indexed in parallel with _group_stack. + # Default "reverse" matches the ``for m in reversed(mats): + # m.release()`` pattern; PreIsaPass producers may set + # "insertion" via annotations on the group's first PreIsaOp. + self._group_close_order_stack: List[str] = [] + # Stack of open hardware loops. Each entry is a dict produced + # by _emit_loop_start_serial and consumed by the matching + # _emit_loop_end_serial; tracks the loop_var, loop_gp, + # idx_addr so nested loops compose correctly. + self._loop_stack: List[Dict[str, Any]] = [] + # Addr-register cache: id(addr_value_expr) -> (a_reg_int, token) + # — populated by ``_PRELOAD_ADDR_REG`` and looked up by DMA + # PreIsaOps that need to reference an ``aN`` token. Lives in a + # per-scope stack mirroring ``_group_stack`` so nested DMA + # contexts compose. Each entry's addr reg is released on + # scope close. + self._addr_reg_stack: List[Dict[int, Tuple[int, str]]] = [] + # Scope-floor stack — minimum allowed scope depth that + # ``_enter_group_for`` may close down to. Each entry is an + # int. Pushed by ``_emit_unroll`` (so sibling-gid transitions + # inside an unroll iter body can't clobber the outer scope + # that holds e.g. an addr-reg binding). Default floor is 0 + # — no constraint at the top level. + self._scope_floor: List[int] = [] + + # Back-compat properties used by legacy methods (now read top of + # stack instead of the old flat single-cache state). + @property + def _group_cache(self) -> Dict[int, Tuple[int, MaterializedExpr]]: + if not self._group_stack: + # No scope open — return an empty dict to make cache hit + # checks always fail. Writes via this property still go + # somewhere, but no scope means nothing to write to; we + # require a scope to be open before any _slot_expr. + return {} + return self._group_stack[-1] + + @property + def _group_open(self) -> bool: + return bool(self._group_stack) + + @property + def _current_group_id(self) -> Optional[Any]: + if not self._group_id_stack: + return None + return self._group_id_stack[-1] + + @property + def _group_close_order(self) -> str: + if not self._group_close_order_stack: + return "reverse" + return self._group_close_order_stack[-1] + + @_group_close_order.setter + def _group_close_order(self, value: str) -> None: + if not self._group_close_order_stack: + return + self._group_close_order_stack[-1] = value + + def run(self, mod: PreIsaModule) -> str: + """Emit ``mod`` into ``shim.compiler.generated_code``. Returns + the final ISA text (str). + + The walker has to handle ``loop_kind="unroll"`` LOOP_STARTs + specially: when it sees one, it locates the matching LOOP_END + (respecting nesting) and recursively re-emits the body N + times with ``loop_var`` rebound to ``tir.IntImm(init + i)``. + ``loop_kind="serial"`` LOOP_STARTs go straight through + ``_emit_one`` which calls the hardware-loop path. + """ + self._run_ops(mod.ops) + return self.shim.compiler.generated_code + + def _run_ops(self, ops: List[PreIsaOp], _close_at_end: bool = True) -> None: + """Walk a (sub)sequence of PreIsaOps, handling unroll + LOOP_START specially. Called recursively for each unrolled + iteration's body, so nested unrolls compose. + + ``_close_at_end`` controls whether scopes opened DURING this + call's body are flushed at the end. Snapshots stack depth on + entry; only scopes pushed since are popped (outer scopes are + not touched). The outer ``run`` call leaves it True. + ``_emit_unroll`` in ``"shared"`` mode passes False — the + unroll body's scope must stay open across iterations. + """ + entry_depth = len(self._group_stack) + try: + i = 0 + n = len(ops) + while i < n: + op = ops[i] + if ( + op.opcode == "LOOP_START" + and op.annotations.get("loop_kind", "serial") == "unroll" + ): + j, body, init_imm, extent_imm, loop_var = ( + self._locate_unroll_body(ops, i) + ) + self._emit_unroll( + body, init_imm, extent_imm, loop_var, + annotations=op.annotations, + ) + i = j + 1 + continue + self._enter_group_for(op, i) + self._emit_one(op) + i += 1 + finally: + if _close_at_end: + while len(self._group_stack) > entry_depth: + self._close_group() + + def _locate_unroll_body( + self, + ops: List[PreIsaOp], + start_idx: int, + ) -> Tuple[int, List[PreIsaOp], int, int, tir.Var]: + """Find the LOOP_END matching the LOOP_START at ``start_idx`` + (must be loop_kind="unroll"). Returns + ``(end_idx, body_ops, init_imm, extent_imm, loop_var)``. + Body is the slice ``ops[start_idx+1 .. end_idx-1]``. + """ + start_op = ops[start_idx] + if len(start_op.operands) != 2: + raise BackendEmitError( + f"LOOP_START at [{start_idx}] expects " + f"[init_imm, extent_imm]; got {start_op.operands!r}" + ) + init_imm = int(start_op.operands[0]) + extent_imm = int(start_op.operands[1]) + loop_var = start_op.binds + if loop_var is None: + raise BackendEmitError( + f"LOOP_START at [{start_idx}] has no binds " + f"(loop iteration var)" + ) + + depth = 1 + j = start_idx + 1 + n = len(ops) + while j < n: + opc = ops[j].opcode + if opc == "LOOP_START": + depth += 1 + elif opc == "LOOP_END": + depth -= 1 + if depth == 0: + break + j += 1 + if j == n: + raise BackendEmitError( + f"LOOP_START at [{start_idx}] has no matching LOOP_END" + ) + body = ops[start_idx + 1:j] + return j, body, init_imm, extent_imm, loop_var + + def _emit_unroll( + self, + body: List[PreIsaOp], + init_imm: int, + extent_imm: int, + loop_var: tir.Var, + annotations: Dict[str, Any], + ) -> None: + """Emit ``extent_imm`` copies of ``body`` with ``loop_var`` + bound to ``IntImm(init_imm + iter)`` in the materialiser's + symbol_table. + + Two scope-management modes, selected by + ``annotations["unroll_scope"]``: + + * ``"per_iter"`` (default) — each iteration is a fresh + materialiser scope; the group cache resets between iters. + Matches legacy ``_emit_for``'s unrolled branch where + ``begin_op`` / ``end_op`` runs once per body sub-op per + iter (no GP carry-over across iters). + * ``"shared"`` — the materialiser scope OPEN at unroll-loop + entry is kept open across all iterations. The body's + id()-keyed cache hits the same PrimExpr objects across + iters and reuses GPs. Used by handlers that need + destructive in-place state (e.g. mv's per-tile bump of + gp_m / gp_o across iters): the bumps mutate the cached + GP value, the next iter's body picks up the bumped value + via the SAME cached entry. + """ + scope_mode = annotations.get("unroll_scope", "per_iter") + if scope_mode not in ("per_iter", "shared"): + raise BackendEmitError( + f"unknown unroll_scope {scope_mode!r}; " + f"expected 'per_iter' or 'shared'" + ) + if loop_var in self.symbol_table: + raise BackendEmitError( + f"loop_var {loop_var.name!r} already bound — nested " + f"unroll reusing the same Var is unsupported" + ) + # Header comment matching legacy. + self.shim.compiler.generated_code += ( + f"; unroll for {loop_var.name} in " + f"[{init_imm}, {init_imm + extent_imm}) -- idx is a literal\n" + ) + # Snapshot scope-stack depth at entry. ``per_iter`` mode resets + # the stack to THIS depth between iters (closing only scopes + # OPENED inside the body), preserving outer scopes (e.g. + # matmul narrow's per-oc mat_addr scope spans the inner t + # unroll body even though inner uses per_iter). + entry_depth = len(self._group_stack) + # Push a scope-floor watermark so ``_enter_group_for``'s + # sibling-gid-transition close is bounded by the entry depth. + # Without this, the first body PreIsaOp's gid != outer gid + # triggers a close of the outer scope itself, killing any + # addr-reg / GP caches the inner body relies on. + self._scope_floor.append(entry_depth) + try: + for k in range(extent_imm): + iter_val = init_imm + k + self.symbol_table[loop_var] = tir.IntImm( + "int32", iter_val, + ) + self.shim.compiler.generated_code += ( + f"; ... unroll iter {k} -> " + f"{loop_var.name}={iter_val}\n" + ) + if scope_mode == "per_iter": + # Close any inner scopes opened during the previous + # iter's body. We must NOT close scopes that were + # open at unroll entry. + while len(self._group_stack) > entry_depth: + self._close_group() + # In shared mode we leave the stack alone. + self._run_ops(body, _close_at_end=(scope_mode != "shared")) + finally: + self.symbol_table.pop(loop_var, None) + # On exit close down to entry depth. + while len(self._group_stack) > entry_depth: + self._close_group() + # Pop the floor. + self._scope_floor.pop() + + # ------------------------------------------------------------------ + # group management + # ------------------------------------------------------------------ + def _enter_group_for(self, op: PreIsaOp, default_idx: int) -> None: + """Open / transition / close the materialisation scope around + ``op`` based on its ``annotations["group_id"]``. + + Rules: + * No group_id: behaves like a singleton group keyed on the + op's index. Used for _COMMENT and any leaf op the producer + didn't bother tagging. + * Same group_id as the open scope: keep the scope open. + * Different group_id: close the open scope, open a fresh one. + """ + gid = op.annotations.get("group_id", None) + # _COMMENT ops never disturb the open scope — they're not real + # HW instructions and never materialise operands. Lets a + # comment land in the middle of a multi-line group (e.g. the + # ``; fp scalar task ... op=mul`` header inside an + # fp_mul_at's 5-line burst) without forcing a scope flush. + if op.opcode == "_COMMENT": + return + if gid is None: + # Singleton — close any open scope, then open a fresh + # one keyed on the default index so each "ungrouped" op + # gets its own materialiser.begin_op cycle (matches the + # pre-grouping behaviour for handlers that emit a single + # ISA line per op). + gid = ("_singleton", default_idx) + if self._group_open and gid == self._current_group_id: + return + # Close the top scope only if it sits ABOVE the current + # scope-floor watermark. Inside an unroll iter body, the + # outer scope (which holds e.g. an _PRELOAD_ADDR_REG'd addr + # register the inner H_PREFETCH_V needs) is protected by a + # floor pushed by _emit_unroll — sibling-gid transitions in + # the body PUSH new scopes on top instead of replacing. + floor = self._scope_floor[-1] if self._scope_floor else 0 + if len(self._group_stack) > floor: + self._close_group() + self._open_group(gid, default_idx) + # First PreIsaOp of a group may carry close_order in its + # annotations. Honour it for the rest of the group's lifetime. + order = op.annotations.get("close_order") + if order is not None: + if order not in ("reverse", "insertion"): + raise BackendEmitError( + f"close_order must be 'reverse' or 'insertion'; " + f"got {order!r}" + ) + self._group_close_order = order + + def _open_group(self, gid: Any, idx: int) -> None: + """Push a fresh materialisation scope onto the stack.""" + self.materializer.set_lowir_op_idx(idx) + self.materializer.begin_op() + self._group_stack.append({}) + self._group_id_stack.append(gid) + self._group_close_order_stack.append("reverse") + self._addr_reg_stack.append({}) + + def _close_group(self) -> None: + """Pop the TOP materialisation scope, freeing its cached GPs. + Outer scopes (deeper in the stack) remain open and their + cached GPs visible to subsequent lookups. + + Legacy emitters use two different release patterns: + * fp_*_at, v_*, row_*_at: ``for m in reversed(mats): + m.release()`` — releases in REVERSE insertion order, so + the FIRST-inserted reg ends up on top. + * emit_btmm, emit_btmm_wo, emit_mv: ``ra.free_gp(gp_regs)`` + — passes a list in insertion order; free_gp iterates that + list, so the LAST-inserted reg ends up on top. + + The PreIsaPass producer stamps each group with the desired + order via ``annotations["close_order"]`` on the group's FIRST + PreIsaOp. The setting was snapshotted into + ``_group_close_order_stack`` at open time. + """ + if not self._group_stack: + return + top_cache = self._group_stack.pop() + self._group_id_stack.pop() + close_order = self._group_close_order_stack.pop() + # Release any addr-regs allocated in this scope BEFORE popping + # so the allocator's free_addr happens before the gp release. + addr_top = ( + self._addr_reg_stack.pop() + if self._addr_reg_stack + else {} + ) + ra = self.shim.compiler.register_allocator + for _key, (a_reg, _tok) in addr_top.items(): + ra.free_addr([a_reg]) + items = list(top_cache.values()) + if close_order == "insertion": + iter_order = items + else: + iter_order = list(reversed(items)) + for _gp, m in iter_order: + if m.owns_register: + ra.unpin_gp(m.register) + m.release() + self.materializer.end_op() + + # ------------------------------------------------------------------ + # per-op emit + # ------------------------------------------------------------------ + def _emit_one(self, op: PreIsaOp) -> None: + if op.opcode == "_COMMENT": + text = op.operands[0] if op.operands else "" + self.shim.compiler.generated_code += f"; {text}\n" + return + + if op.opcode == "LOOP_START": + # Strategy is on the PreIsaOp; default = "serial". + kind = op.annotations.get("loop_kind", "serial") + if kind == "serial": + self._emit_loop_start_serial(op) + elif kind == "unroll": + # Unrolled loops are handled by ``run`` because they + # need to drive the body-replay loop themselves. The + # main ``run`` walker detects this opcode + kind and + # never calls ``_emit_one`` on it directly. + raise BackendEmitError( + "LOOP_START with loop_kind='unroll' must be " + "expanded by run()'s outer walker, not reach " + "_emit_one — internal invariant violated" + ) + else: + raise BackendEmitError( + f"LOOP_START: unknown loop_kind {kind!r} " + f"(expected 'serial' or 'unroll')" + ) + return + + if op.opcode == "LOOP_END": + # Mirror of LOOP_START dispatch — serial closes the HW + # loop. Unroll's LOOP_END is consumed by run() outside. + kind = op.annotations.get("loop_kind", "serial") + if kind == "serial": + self._emit_loop_end_serial(op) + elif kind == "unroll": + raise BackendEmitError( + "LOOP_END with loop_kind='unroll' must be " + "skipped by run()'s outer walker" + ) + else: + raise BackendEmitError( + f"LOOP_END: unknown loop_kind {kind!r}" + ) + return + + if op.opcode == "_PRELOAD_ADDR_REG": + # Allocate a PLENA addr register, materialise the operand + # value into a scratch GP, emit the C_SET_ADDR_REG bind, + # and cache (id(operand) -> ("aN", a_reg_int)) for any + # later DMA PreIsaOp that references the same operand. + if len(op.operands) != 1: + raise BackendEmitError( + f"_PRELOAD_ADDR_REG expects 1 operand (the addr " + f"value expr); got {len(op.operands)}" + ) + val = op.operands[0] + if not self._addr_reg_stack: + raise BackendEmitError( + "_PRELOAD_ADDR_REG: no open scope to cache the " + "addr-register binding. Open a group_id first." + ) + top = self._addr_reg_stack[-1] + key = id(val) + if key in top: + # Already bound — no-op (legacy would emit a redundant + # C_SET_ADDR_REG; producer should not duplicate). + return + ra = self.shim.compiler.register_allocator + a_reg = ra.allocate_addr(1)[0] + tok = f"a{a_reg}" + # Materialise the value into a GP first (this puts the + # GP into the current group cache; subsequent ops in the + # same scope can reference ``val`` via _slot_expr_cached + # to get the same GP back, or use the cached addr-reg + # token through the addr-reg cache). + m_val_tok, _h = self._invoke_slot(_slot_expr, val) + # Emit C_SET_ADDR_REG aN, gp0, gp{r} + self.shim.compiler.generated_code += ( + f"C_SET_ADDR_REG {tok}, gp0, {m_val_tok}\n" + ) + top[key] = (a_reg, tok) + return + + if op.opcode == "_PRELOAD_ADDR": + # Materialise the operand PrimExpr into a GP now and stash + # it in the group cache. Mirrors legacy + # _emit_fp_scalar_op_at's upfront materialisation loop — + # without this, addresses materialise lazily on first ISA + # use, which interleaves S_ADDI_INTs with FP ops and breaks + # byte-equality with the legacy emitter. + if len(op.operands) != 1: + raise BackendEmitError( + f"_PRELOAD_ADDR expects 1 operand (the address " + f"PrimExpr); got {len(op.operands)}" + ) + val = op.operands[0] + # Drive through _invoke_slot so the cache machinery runs + # exactly the way it would for a real ISA op. + self._invoke_slot(_slot_expr, val) + return + + if op.opcode == "_BUMP_CACHED_GP": + # ``S_ADDI_INT gp{N}, gp{N}, stride`` where gp{N} is the + # cached GP for the operand PrimExpr. Mutates the cached + # value — subsequent _slot_expr / _slot_expr_cached + # lookups of the same expr return the same GP but its + # value is now ``orig + stride``. Producer must arrange + # iteration semantics accordingly (mirrors legacy's + # destructive in-place stride bump in row_*_at). + if len(op.operands) != 2: + raise BackendEmitError( + f"_BUMP_CACHED_GP expects [cached_expr, stride]; " + f"got {op.operands!r}" + ) + cached_expr, stride = op.operands + # Stride may be a compile-time int OR a PrimExpr that + # references hw-shape consts (BLEN_VAR / MLEN_VAR etc). + # The latter goes through ``_peephole_const_fold`` to + # substitute the symbol_table-bound IntImms and simplify + # to a literal integer at emit time. + if isinstance(stride, int): + stride_int = stride + elif isinstance(stride, tir.IntImm): + stride_int = int(stride.value) + elif isinstance(stride, tir.PrimExpr): + folded = self.materializer._peephole_const_fold(stride) + if isinstance(folded, tir.IntImm): + stride_int = int(folded.value) + else: + raise BackendEmitError( + f"_BUMP_CACHED_GP stride PrimExpr did not fold " + f"to an IntImm at emit time: {stride!r} -> " + f"{folded!r}. All free vars must be bound in " + f"symbol_table." + ) + else: + raise BackendEmitError( + f"_BUMP_CACHED_GP stride must be int / IntImm / " + f"PrimExpr; got {type(stride).__name__} {stride!r}" + ) + key = id(cached_expr) + hit = self._lookup_cache(key) + if hit is None: + raise BackendEmitError( + f"_BUMP_CACHED_GP: PrimExpr not in any open scope " + f"(no preceding _PRELOAD_ADDR for the same Python " + f"PrimExpr object)" + ) + gp, _m = hit + self.shim.compiler.generated_code += ( + f"S_ADDI_INT gp{gp}, gp{gp}, {stride_int}\n" + ) + return + + tmpl = _TEMPLATES.get(op.opcode) + if tmpl is None: + raise BackendEmitError( + f"no BackendEmit template for opcode {op.opcode!r}. " + f"The handler that produced this PreIsaOp must be " + f"matched by a row in pre_isa_emit._TEMPLATES." + ) + + if len(op.operands) != len(tmpl.slots): + raise BackendEmitError( + f"{op.opcode}: operand count {len(op.operands)} does " + f"not match template arity {len(tmpl.slots)}" + ) + + tokens: List[str] = [] + # Per-emit list of "fresh" handles to release at the end. + # Cached handles (group-shared) are NOT released here — they + # live until _close_group. + fresh_handles: List[MaterializedExpr] = [] + try: + ra = self.shim.compiler.register_allocator + for slot_fn, val in zip(tmpl.slots, op.operands): + tok, handle = self._invoke_slot(slot_fn, val) + tokens.append(tok) + if handle is not None and handle is not _CACHED_SENTINEL: + fresh_handles.append(handle) + # Pin until end of emit so the next operand's + # materialise() can't auto-spill it (mirrors the + # legacy ``ra.pin_gp(m.register)`` pattern in + # _emit_fp_scalar_op_at). + if handle.owns_register: + ra.pin_gp(handle.register) + self.shim.compiler.generated_code += tmpl.fmt.format(*tokens) + "\n" + finally: + ra = self.shim.compiler.register_allocator + for h in reversed(fresh_handles): + if h.owns_register: + ra.unpin_gp(h.register) + # Fresh (non-cached) handles release immediately; cached + # ones stay alive in self._group_cache. + + # ------------------------------------------------------------------ + # loop handling — LOOP_START / LOOP_END + # ------------------------------------------------------------------ + def _emit_loop_start_serial(self, op: PreIsaOp) -> None: + """Open a hardware (serial) loop. Emits the literal PLENA ISA + ``C_LOOP_START gp_loop, extent`` along with the idx-init. + + operands = [init_imm:int, extent_imm:int] + binds = the loop iteration tir.Var (body PreIsaOps reference + it via PrimExpr operands) + annotations["loop_gp"] = the GP reserved by loop_register_alloc + for this loop's HW counter + + Mirrors legacy ``_emit_for``'s serial branch byte-for-byte. + """ + if len(op.operands) != 2: + raise BackendEmitError( + f"LOOP_START expects [init_imm, extent_imm]; " + f"got {op.operands!r}" + ) + init_imm = int(op.operands[0]) + extent_imm = int(op.operands[1]) + loop_var = op.binds + loop_gp = op.annotations.get("loop_gp") + if loop_var is None: + raise BackendEmitError( + f"LOOP_START has no binds (loop iteration tir.Var)" + ) + if loop_gp is None: + raise BackendEmitError( + f"LOOP_START (serial) has no 'loop_gp' annotation " + f"(must be set by PreIsaPass from the HLIR op's " + f"loop_register_alloc stamp)" + ) + + self._close_group() + + ra = self.shim.compiler.register_allocator + if loop_var in self.symbol_table: + raise BackendEmitError( + f"loop_var {loop_var.name!r} already bound — nested " + f"loops reusing the same Var aren't supported" + ) + ra.pin_gp(loop_gp) + idx_addr = ra.claim_idx_slot() + + if init_imm == 0: + self.shim.compiler.generated_code += ( + f"; for {loop_var.name} in [{init_imm}, " + f"{init_imm + extent_imm}) -- hw counter gp{loop_gp}, " + f"idx ram[{idx_addr}]\n" + f"S_ST_INT gp0, gp0, {idx_addr}\n" + f"C_LOOP_START gp{loop_gp}, {extent_imm}\n" + ) + else: + init_gp = ra.allocate_gp(1)[0] + self.shim.compiler.generated_code += ( + f"; for {loop_var.name} in [{init_imm}, " + f"{init_imm + extent_imm}) -- hw counter gp{loop_gp}, " + f"idx ram[{idx_addr}]\n" + f"S_ADDI_INT gp{init_gp}, gp0, {init_imm}\n" + f"S_ST_INT gp{init_gp}, gp0, {idx_addr}\n" + f"C_LOOP_START gp{loop_gp}, {extent_imm}\n" + ) + ra.free_gp([init_gp]) + + self.symbol_table[loop_var] = ("ram", idx_addr) + # Stash state for the matching LOOP_END. + self._loop_stack.append({ + "loop_var": loop_var, + "loop_gp": loop_gp, + "idx_addr": idx_addr, + }) + + def _emit_loop_end_serial(self, op: PreIsaOp) -> None: + """Close the most-recent serial loop opened by + ``_emit_loop_start_serial``. Emits the literal PLENA ISA + ``C_LOOP_END gp_loop`` along with the idx-increment epilogue. + Matches legacy ``_emit_for`` byte-for-byte. + """ + if not self._loop_stack: + raise BackendEmitError( + f"LOOP_END with no matching LOOP_START on the stack" + ) + self._close_group() + + st = self._loop_stack.pop() + loop_var = st["loop_var"] + loop_gp = st["loop_gp"] + idx_addr = st["idx_addr"] + + ra = self.shim.compiler.register_allocator + inc_gp = ra.allocate_gp(1)[0] + self.shim.compiler.generated_code += ( + f"; idx {loop_var.name} += 1 (ram[{idx_addr}])\n" + f"S_LD_INT gp{inc_gp}, gp0, {idx_addr}\n" + f"S_ADDI_INT gp{inc_gp}, gp{inc_gp}, 1\n" + f"S_ST_INT gp{inc_gp}, gp0, {idx_addr}\n" + f"C_LOOP_END gp{loop_gp}\n" + ) + ra.free_gp([inc_gp]) + + self.symbol_table.pop(loop_var, None) + ra.unpin_gp(loop_gp) + ra.release_idx_slot(idx_addr) + + def _lookup_cache( + self, key: int, + ) -> Optional[Tuple[int, MaterializedExpr]]: + """Walk the scope stack top-to-bottom looking for ``key``. The + first hit wins — inner scope shadows outer. Returns ``None`` + when ``key`` is in no open scope.""" + for scope in reversed(self._group_stack): + hit = scope.get(key) + if hit is not None: + return hit + return None + + def _invoke_slot( + self, + slot_fn: Callable[[Any, ExprMaterializer], Tuple[str, Optional[MaterializedExpr]]], + val: Any, + ) -> Tuple[str, Optional[MaterializedExpr]]: + """Resolve a slot. + + ``_slot_expr`` caches into the TOP open scope on first sight; + subsequent occurrences of the same Python expression object + (``id`` match) anywhere in the scope stack return the cached + GP. Other slot kinds (literal int, verbatim string) never + cache. + + ``_slot_expr_cached`` is the explicit "must already be in + SOME open scope" form — looks up the PrimExpr through the + whole scope stack. Used by destructive in-place patterns + (row_*_at's d_tile stride bump) and nested-scope patterns + (matmul narrow's outer-scope ``mat_addr`` referenced from the + inner unroll body). + """ + if slot_fn is _slot_expr_cached: + key = id(val) + hit = self._lookup_cache(key) + if hit is None: + raise BackendEmitError( + f"_slot_expr_cached: PrimExpr {val!r} not in any " + f"open group cache. The producer must have emitted " + f"a _PRELOAD_ADDR for this expr earlier in an open " + f"scope." + ) + gp, _m = hit + return f"gp{gp}", _CACHED_SENTINEL + if slot_fn is _slot_addr_reg_cached: + key = id(val) + # Walk the addr-reg stack top-to-bottom. + for scope in reversed(self._addr_reg_stack): + hit = scope.get(key) + if hit is not None: + _a_reg, tok = hit + return tok, _CACHED_SENTINEL + raise BackendEmitError( + f"_slot_addr_reg_cached: PrimExpr {val!r} not in any " + f"open addr-reg cache. The producer must have emitted " + f"a _PRELOAD_ADDR_REG for this expr earlier in an open " + f"scope." + ) + if slot_fn is _slot_expr: + key = id(val) + hit = self._lookup_cache(key) + if hit is not None: + gp, _m = hit + return f"gp{gp}", _CACHED_SENTINEL + # First sight — materialise and cache into TOP scope. + if not self._group_stack: + # Defensive: an _PRELOAD_ADDR / _slot_expr outside any + # scope is a producer bug. + raise BackendEmitError( + f"_slot_expr: no open scope to cache materialised " + f"PrimExpr {val!r} into. The producer must open a " + f"group (via group_id annotation) before its first " + f"PrimExpr operand." + ) + tok, handle = slot_fn(val, self.materializer) + if handle is not None: + ra = self.shim.compiler.register_allocator + if handle.owns_register: + ra.pin_gp(handle.register) + self._group_stack[-1][key] = (handle.register, handle) + return tok, _CACHED_SENTINEL + return tok, handle + return slot_fn(val, self.materializer) + + +__all__ = ["BackendEmit", "BackendEmitError"] diff --git a/tilelang_tvm_compiler/codegen.py b/tilelang_tvm_compiler/codegen.py new file mode 100644 index 0000000..06d1143 --- /dev/null +++ b/tilelang_tvm_compiler/codegen.py @@ -0,0 +1,580 @@ +"""TIR -> PLENA pseudo-ISA codegen. + +Walks a `tvm.tir.PrimFunc`: + + 1. Collects every buffer (params + alloc_buffer inside Blocks) and its scope. + 2. Walks the body and finds `T.call_extern("handle", "plena.*", ...)` sites. + 3. Looks up the intrinsic spec, type-checks operand scopes, emits ISA text. + +This is the equivalent of an MLIR "convert-plena-to-isa" pass, written +imperatively in Python because we are using TVM (no dialect machinery). +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import tvm +from tvm import tir +from tvm.tir import stmt_functor + +from . import intrinsics as _intrin +from . import scope as _scope +from . import hlir as _hlir + + +class CodegenError(RuntimeError): + pass + + +class _BufferInfo: + """What we remember per buffer for ISA emission.""" + + __slots__ = ("name", "scope", "shape", "dtype") + + def __init__(self, name: str, scope: str, shape, dtype: str): + self.name = name + self.scope = scope + self.shape = tuple(int(s) if isinstance(s, (int, tir.IntImm)) else s for s in shape) + self.dtype = dtype + + def __repr__(self) -> str: + return f"{self.name}<{self.scope}>" + + +def _normalize_scope(s: str) -> str: + """Map TVM's default empty/"global" scope to our HBM.""" + if s in ("", "global"): + return _scope.HBM + return s + + +class PlenaCodegen: + """One instance per PrimFunc compile.""" + + def __init__(self, func: tir.PrimFunc, name: str = "kernel"): + self.func = func + self.name = name + # data-handle Var -> _BufferInfo + self._buffers: Dict[tir.Var, _BufferInfo] = {} + # name-keyed lookup for diagnostic messages + self._buffers_by_name: Dict[str, _BufferInfo] = {} + self._isa_lines: List[str] = [] + + # ------------------------------------------------------------------ + # public API + # ------------------------------------------------------------------ + def buffers_by_name(self) -> Dict[str, "_BufferInfo"]: + """Read-only view of {buffer_name -> info}. Populated after run().""" + return dict(self._buffers_by_name) + + def lower_to_hlir(self) -> _hlir.HLIRModule: + """Pass 1: walk TIR -> HLIR module (buffers + ordered op stream). + + Replaces the text-based `run()` for the new pipeline. We keep + both paths because `run()` is convenient for quick eyeballing + of a kernel during development. + """ + # Reuse the buffer collection logic from run(). + self._buffers.clear() + self._buffers_by_name.clear() + self._collect_param_buffers() + self._collect_alloc_buffers() + + # Read kernel-wide layout from ``T.func_attr({"plena.layout": + # "NCHW"})``. Defaults to BSHD — every kernel written before + # the layout migration relied on that. Stamp it onto every 4D + # HLIR Buffer so address_alloc can pick the right axis for + # row-tile / channel-group / col-tile. + kernel_layout = self._kernel_layout() + + # Construct HLIR buffers (preserving param order). + hlir_buffers: Dict[str, _hlir.Buffer] = {} + param_names: List[str] = [] + for var in self.func.params: + buf = self.func.buffer_map.get(var, None) + if buf is None: + continue + info = self._buffers_by_name[buf.name] + hlir_buffers[info.name] = self._buf_info_to_hlir(info, kernel_layout) + param_names.append(info.name) + for name, info in self._buffers_by_name.items(): + if name not in hlir_buffers: + hlir_buffers[name] = self._buf_info_to_hlir(info, kernel_layout) + + # Walk the body and collect Op stream. + ops: List[_hlir.Op] = [] + self._collect_ops(self.func.body, ops) + + return _hlir.HLIRModule( + name=self.name, + buffers=hlir_buffers, + ops=ops, + param_names=param_names, + ) + + def _kernel_layout(self) -> str: + """Read ``T.func_attr({"plena.layout": ...})``. Defaults to BSHD.""" + attrs = self.func.attrs + if attrs is None or "plena.layout" not in attrs: + return "BSHD" + val = attrs["plena.layout"] + if isinstance(val, tir.StringImm): + return str(val.value) + return str(val) + + @staticmethod + def _buf_info_to_hlir( + info: "_BufferInfo", kernel_layout: str = "BSHD", + ) -> _hlir.Buffer: + return _hlir.Buffer( + name=info.name, + scope=info.scope, + shape=tuple(int(s) for s in info.shape), + dtype=info.dtype, + layout=kernel_layout, + ) + + def _collect_ops(self, stmt, ops: List[_hlir.Op]) -> None: + if isinstance(stmt, tir.SeqStmt): + for s in stmt: + self._collect_ops(s, ops) + elif isinstance(stmt, tir.BlockRealize): + self._collect_ops(stmt.block, ops) + elif isinstance(stmt, tir.Block): + self._collect_ops(stmt.body, ops) + elif isinstance(stmt, tir.LetStmt): + self._collect_ops(stmt.body, ops) + elif isinstance(stmt, tir.IfThenElse): + # PLENA's ISA has no scalar branch instructions, and the previous + # "literal True/False only" handling was misleading -- it covered + # essentially nothing that a Python-level `if` at kernel-build + # time can't already express more clearly. Reject all TIR ifs. + raise CodegenError( + "tir.IfThenElse is not supported. Use a Python-level `if` in " + "the kernel factory to specialize at build time, or T.unroll " + "+ Python branching for per-iteration variants. PLENA has no " + "branch ISA, so dynamic conditions cannot be lowered." + ) + elif isinstance(stmt, tir.For): + # Recursively collect the body into a fresh op list, then wrap + # it in a structured ForOp. Pass 3 walks `body` while binding + # `loop_var` to a GP register so PrimExprs that reference it + # can be materialised by ExprMaterializer. + body_ops: List[_hlir.Op] = [] + self._collect_ops(stmt.body, body_ops) + extent = ( + int(stmt.extent.value) if isinstance(stmt.extent, tir.IntImm) + else stmt.extent + ) + init = ( + int(stmt.min.value) if isinstance(stmt.min, tir.IntImm) + else stmt.min + ) + # Capture loop kind so isa_pass can pick between hardware-loop + # emission (C_LOOP_START/END) and full unrolling. T.unroll(...) + # in the kernel maps to ForKind.UNROLLED here; isa_pass uses + # this to escape the emulator's MAX_LOOP_INSTRUCTIONS-per-iter + # cap when one iteration of an outer loop dispatches a body + # too large to fit (e.g. the 16x16 emit_matmul expansion). + # tir.For.kind is an int-valued enum member; resolve to a name. + raw_kind = getattr(stmt, "kind", None) + try: + kind_str = tir.ForKind(int(raw_kind)).name.lower() + except (TypeError, ValueError): + kind_str = "serial" + for_op = _hlir.make_for_op( + loop_var=stmt.loop_var, + extent=extent, + body=body_ops, + init=init, + ) + for_op.annotations["loop_kind"] = kind_str + ops.append(for_op) + elif isinstance(stmt, tir.Evaluate): + self._collect_op_from_evaluate(stmt, ops) + elif isinstance(stmt, tir.AttrStmt): + self._collect_ops(stmt.body, ops) + + def _collect_slice_op( + self, + val: tir.Call, + name: str, + kind: str, + ops: List[_hlir.Op], + ) -> None: + """Parse `plena.dma_*_slice` calls. + + Layout: + args[1] src_buf.data (Var) + args[2] dst_buf.data (Var) + args[3] ndim (IntImm) + args[4..4+ndim-1] starts (PrimExpr / IntImm) + args[4+ndim..4+2*ndim-1] extents (IntImm) + + The src OR dst is the sliced one, depending on direction: + h2v / h2m -> src is sliced (HBM tensor) + v2h -> dst is sliced (writing to a sub-region of HBM) + """ + raw = list(val.args[1:]) + if len(raw) < 4: + raise CodegenError( + f"{name}: expected at least 4 args (src, dst, ndim, ...), got {len(raw)}" + ) + src_var, dst_var, ndim_imm = raw[0], raw[1], raw[2] + if not isinstance(ndim_imm, tir.IntImm): + raise CodegenError( + f"{name}: ndim must be a compile-time int, got {type(ndim_imm).__name__}" + ) + ndim = int(ndim_imm.value) + if len(raw) != 3 + 2 * ndim: + raise CodegenError( + f"{name}: with ndim={ndim} expected exactly {3 + 2 * ndim} args, " + f"got {len(raw)}" + ) + starts_raw = raw[3 : 3 + ndim] + extents_raw = raw[3 + ndim : 3 + 2 * ndim] + + # Each start may be int / IntImm (static) or arbitrary PrimExpr + # (dynamic). Pass 3 will dispatch on type. + starts: List[Any] = [] + for s in starts_raw: + if isinstance(s, tir.IntImm): + starts.append(int(s.value)) + elif isinstance(s, tir.PrimExpr): + starts.append(s) + else: + raise CodegenError( + f"{name}: start must be IntImm or PrimExpr, got {type(s).__name__}" + ) + extents: List[int] = [] + for e in extents_raw: + if not isinstance(e, tir.IntImm): + raise CodegenError( + f"{name}: extent must be a compile-time int, got " + f"{type(e).__name__}={e!r}" + ) + extents.append(int(e.value)) + + # Look up parent buffers from the data-handle Vars. + if not (isinstance(src_var, tir.Var) and src_var in self._buffers): + raise CodegenError(f"{name}: src is not a known buffer handle") + if not (isinstance(dst_var, tir.Var) and dst_var in self._buffers): + raise CodegenError(f"{name}: dst is not a known buffer handle") + src_info = self._buffers[src_var] + dst_info = self._buffers[dst_var] + + # Decide which side is sliced based on the intrinsic. + if name in ("plena.dma_h2v_slice", "plena.dma_h2m_slice"): + sliced = _hlir.BufferSlice( + parent=src_info.name, starts=tuple(starts), extents=tuple(extents), + ) + buffer_args: List[Any] = [sliced, dst_info.name] + elif name == "plena.dma_v2h_slice": + sliced = _hlir.BufferSlice( + parent=dst_info.name, starts=tuple(starts), extents=tuple(extents), + ) + buffer_args = [src_info.name, sliced] + else: + raise CodegenError(f"unhandled slice intrinsic: {name}") + + ops.append(_hlir.Op( + kind=kind, + buffer_args=buffer_args, + scalar_args=[], + annotations={"intrinsic": name}, + )) + + def _collect_op_from_evaluate(self, ev: tir.Evaluate, ops: List[_hlir.Op]) -> None: + val = ev.value + if not isinstance(val, tir.Call): + return + name = self._call_extern_name(val) + if name is None or not name.startswith("plena."): + return + spec = _intrin.lookup(name) # validates that the op is known + kind = name[len("plena."):] + + # Slice variants have a structured arg pack: src, dst, ndim, + # *starts, *extents. Pack the variadic suffix into a BufferSlice + # and produce an HLIR Op whose `buffer_args[0]` is the slice (or + # for v2h_slice, `buffer_args[1]`). + if name.endswith("_slice"): + self._collect_slice_op(val, name, kind, ops) + return + + # Arg resolution. Buffer-handle Vars (those that map to a Buffer + # we've already collected) become buffer_args by name. Everything + # else is a scalar argument: + # - IntImm / FloatImm / StringImm -> native Python int/float/str + # (cheaper for downstream passes than carrying the IR node) + # - any other PrimExpr (loop var, compound expression like + # kv_block * mlen + offset) -> kept as-is so ExprMaterializer + # can lower it at ISA emit time + raw_args = list(val.args[1:]) + buffer_args: List[str] = [] + scalar_args: List[Any] = [] + scopes: List[Optional[str]] = [] + for a in raw_args: + if isinstance(a, tir.Var) and a in self._buffers: + info = self._buffers[a] + buffer_args.append(info.name) + scopes.append(info.scope) + continue + if isinstance(a, tir.BufferLoad) and a.buffer.data in self._buffers: + info = self._buffers[a.buffer.data] + if _scope.physical_scope(info.scope) == _scope.FPRAM: + scalar_args.append(_hlir.BufferElement( + buffer=info.name, + indices=tuple(self._normalize_scalar_expr(i) for i in a.indices), + )) + scopes.append(None) + continue + scopes.append(None) + scalar_args.append(self._normalize_scalar_expr(a)) + # Verify scopes against the registered intrinsic spec. We collapse + # scopes from buffer/scalar args back into the original positional + # order so verification matches op signatures. + ordered_scopes: List[Optional[str]] = [] + bi = 0 + si = 0 + for a in raw_args: + if isinstance(a, tir.Var) and a in self._buffers: + ordered_scopes.append(self._buffers[a].scope) + bi += 1 + else: + ordered_scopes.append(None) + si += 1 + self._verify_scopes(spec, name, ordered_scopes) + + ops.append(_hlir.Op( + kind=kind, + buffer_args=buffer_args, + scalar_args=scalar_args, + annotations={"intrinsic": name}, + )) + + def run(self) -> str: + self._collect_param_buffers() + self._collect_alloc_buffers() + self._emit_header() + self._emit_buffer_directives() + self._isa_lines.append("") + self._emit_body() + return "\n".join(self._isa_lines) + "\n" + + # ------------------------------------------------------------------ + # buffer collection + # ------------------------------------------------------------------ + def _collect_param_buffers(self) -> None: + for var in self.func.params: + buf = self.func.buffer_map.get(var, None) + if buf is None: + # opaque handle / scalar param -- skip for now + continue + self._record_buffer(buf, default_scope=_scope.HBM) + + def _collect_alloc_buffers(self) -> None: + def visitor(node): + if isinstance(node, tir.Block): + for buf in node.alloc_buffers: + self._record_buffer(buf, default_scope=_scope.HBM) + elif isinstance(node, tir.Allocate): + # post-block-flattening form -- not used in our entry IR but + # cheap to support so that lowering passes don't break us. + pass + + stmt_functor.post_order_visit(self.func.body, visitor) + + def _record_buffer(self, buf: tir.Buffer, default_scope: str) -> None: + scope = _normalize_scope(buf.scope() or default_scope) + if not _scope.is_known(scope): + raise CodegenError( + f"buffer {buf.name!r} has unknown scope {scope!r}; " + f"expected one of {_scope.ALL_SCOPES}" + ) + info = _BufferInfo(buf.name, scope, buf.shape, str(buf.dtype)) + self._buffers[buf.data] = info + self._buffers_by_name[buf.name] = info + + # ------------------------------------------------------------------ + # body walk + # ------------------------------------------------------------------ + def _emit_body(self) -> None: + # Use a manual recursive walk so we can preserve emission order. + # post_order_visit reverses statements, which would scramble the ISA. + self._walk_stmt(self.func.body) + + def _walk_stmt(self, stmt) -> None: + if isinstance(stmt, tir.SeqStmt): + for s in stmt: + self._walk_stmt(s) + elif isinstance(stmt, tir.BlockRealize): + self._walk_stmt(stmt.block) + elif isinstance(stmt, tir.Block): + self._walk_stmt(stmt.body) + elif isinstance(stmt, tir.For): + # We don't emit loop control yet -- just unroll-by-walking. + # Real PLENA would lower this to C_LOOP_START/END. For the + # skeleton kernel there are no loops. + self._isa_lines.append( + f"; for {stmt.loop_var.name_hint} in [{stmt.min}, {stmt.min} + {stmt.extent})" + ) + self._walk_stmt(stmt.body) + self._isa_lines.append(f"; end for {stmt.loop_var.name_hint}") + elif isinstance(stmt, tir.Evaluate): + self._walk_evaluate(stmt) + elif isinstance(stmt, tir.AttrStmt): + self._walk_stmt(stmt.body) + else: + # Unknown stmt -- emit a comment so we can spot it during dev. + self._isa_lines.append(f"; ") + + def _walk_evaluate(self, ev: tir.Evaluate) -> None: + val = ev.value + if not isinstance(val, tir.Call): + return + name = self._call_extern_name(val) + if name is None or not name.startswith("plena."): + return + spec = _intrin.lookup(name) + # call_extern args are: [StringImm(name), op1, op2, ...] + raw_args = list(val.args[1:]) + resolved, scopes = self._resolve_args(raw_args) + self._verify_scopes(spec, name, scopes) + self._isa_lines.append(spec.emit(resolved)) + + @staticmethod + def _call_extern_name(call: tir.Call) -> Optional[str]: + op = call.op + # tvm.ir.Op for builtins like "tir.call_extern" + op_name = getattr(op, "name", None) + if op_name != "tir.call_extern": + return None + if not call.args: + return None + head = call.args[0] + if isinstance(head, tir.StringImm): + return head.value + return None + + def _resolve_args(self, args) -> tuple[list[str], list[Optional[str]]]: + resolved: list[str] = [] + scopes: list[Optional[str]] = [] + for a in args: + if isinstance(a, tir.Var) and a in self._buffers: + info = self._buffers[a] + resolved.append(info.name) + scopes.append(info.scope) + elif isinstance(a, tir.BufferLoad) and a.buffer.data in self._buffers: + info = self._buffers[a.buffer.data] + if _scope.physical_scope(info.scope) == _scope.FPRAM: + idx = ", ".join(str(self._normalize_scalar_expr(i)) for i in a.indices) + resolved.append(f"{info.name}[{idx}]") + scopes.append(None) + else: + resolved.append(str(a)) + scopes.append(None) + elif isinstance(a, (tir.IntImm, tir.FloatImm)): + resolved.append(str(a.value)) + scopes.append(None) + elif isinstance(a, tir.StringImm): + resolved.append(repr(a.value)) + scopes.append(None) + else: + # Could be a buffer .data we missed, or a complex expr. + # Fall back to a textual rendering and no scope. + resolved.append(str(a)) + scopes.append(None) + return resolved, scopes + + @staticmethod + def _normalize_scalar_expr(a): + if isinstance(a, tir.IntImm): + return int(a.value) + if isinstance(a, tir.FloatImm): + return float(a.value) + if isinstance(a, tir.StringImm): + return str(a.value) + if isinstance(a, tir.PrimExpr): + return a + return str(a) + + def _verify_scopes( + self, spec: _intrin.IntrinsicSpec, name: str, scopes: list[Optional[str]] + ) -> None: + expected = list(spec.operand_scopes) + if len(scopes) != len(expected): + raise CodegenError( + f"{name}: expected {len(expected)} operands, got {len(scopes)}" + ) + for i, (want, got) in enumerate(zip(expected, scopes)): + if want is None: + continue + if got is None: + raise CodegenError( + f"{name}: operand {i} must be a buffer in scope {want!r}, " + f"got non-buffer value" + ) + # `global.` operands satisfy a `` operand spec — + # the user-declared global flag only changes lane-fusion + # behaviour, not which physical RAM the operand reads/writes. + if _scope.physical_scope(got) != want: + raise CodegenError( + f"{name}: operand {i} must be in scope {want!r}, " + f"but found {got!r}" + ) + + # ------------------------------------------------------------------ + # header / buffer directives + # ------------------------------------------------------------------ + def _emit_header(self) -> None: + self._isa_lines.append(f"; ============================================") + self._isa_lines.append(f"; PLENA pseudo-ISA -- kernel: {self.name}") + self._isa_lines.append(f"; generated by tilelang_tvm_compiler (skeleton)") + self._isa_lines.append(f"; ============================================") + + def _emit_buffer_directives(self) -> None: + if not self._buffers: + return + self._isa_lines.append("") + self._isa_lines.append("; ---- buffers ----") + # Stable order: params first (by appearance), then allocs (by name). + seen = set() + order: list[_BufferInfo] = [] + for var in self.func.params: + buf = self.func.buffer_map.get(var, None) + if buf is not None and buf.name in self._buffers_by_name: + info = self._buffers_by_name[buf.name] + if info.name not in seen: + order.append(info) + seen.add(info.name) + for name, info in sorted(self._buffers_by_name.items()): + if name not in seen: + order.append(info) + seen.add(name) + for info in order: + shape_str = "x".join(str(s) for s in info.shape) + scope_token = { + _scope.HBM: "ALLOC_HBM ", + _scope.VRAM: "ALLOC_VRAM", + _scope.MRAM: "ALLOC_MRAM", + _scope.FPRAM: "ALLOC_FPRAM", + }[_scope.physical_scope(info.scope)] + self._isa_lines.append( + f"{scope_token} {info.name} shape={shape_str} dtype={info.dtype}" + ) + + +def compile_module(mod: tvm.IRModule) -> Dict[str, str]: + """Compile every PrimFunc in `mod` to PLENA pseudo-ISA. + + Returns a {global_symbol -> isa_text} mapping. + """ + out: Dict[str, str] = {} + for gv, func in mod.functions.items(): + if not isinstance(func, tir.PrimFunc): + continue + cg = PlenaCodegen(func, name=gv.name_hint) + out[gv.name_hint] = cg.run() + return out diff --git a/tilelang_tvm_compiler/dead_buffer_elim.py b/tilelang_tvm_compiler/dead_buffer_elim.py new file mode 100644 index 0000000..38f25b4 --- /dev/null +++ b/tilelang_tvm_compiler/dead_buffer_elim.py @@ -0,0 +1,87 @@ +"""Dead-buffer elimination pass. + +Walks every op in the HLIRModule (recursing into structured ``for`` op +bodies), collects every buffer name referenced from ``buffer_args`` +(strings and ``BufferSlice.parent``) and from ``scalar_args`` +(``BufferElement.buffer`` plus any name picked up from a PrimExpr tree). +Buffer names that don't appear in that reachable set get dropped from +``mod.buffers`` — except param buffers, which stay (they're the kernel's +public interface and the HBM staging code assumes they exist). + +Why bother: kernels that allocate FP-slot fragments for one algorithm +variant (M_OLD, L_NEW, P_SUM, …) but only end up using a subset still +declare every fragment, and the post-expansion shape-checker can reject +fragments whose shapes don't match the in-use lane mode (see the +``PV_loc shape=(1, 4, 1, 16)`` case). Removing unreachable buffers up +front keeps HLIR honest and avoids spending FPRAM / VRAM on slots no op +will touch. +""" + +from __future__ import annotations + +from typing import Iterable, Set + +from . import hlir as _hlir + + +def _collect_from_primexpr(expr, out: Set[str]) -> None: + """Best-effort: walk a PrimExpr tree looking for BufferElement / + BufferLoad / Var-backed buffer references. We don't import tir here + so the pass stays usable in pure-HLIR builds; isinstance against the + BufferElement dataclass and a duck-typed ``.indices``/``.buffer`` + walk is enough for everything to_plena produces.""" + if isinstance(expr, _hlir.BufferElement): + out.add(expr.buffer) + for i in expr.indices: + _collect_from_primexpr(i, out) + return + # tir.BufferLoad: .buffer.name + recurse into .indices + buf_attr = getattr(expr, "buffer", None) + if buf_attr is not None and hasattr(buf_attr, "name"): + out.add(str(buf_attr.name)) + indices = getattr(expr, "indices", None) + if indices is not None: + for i in indices: + _collect_from_primexpr(i, out) + # Generic binop / call recursion via the .a/.b or .args fields tir + # nodes expose. + for attr in ("a", "b", "value"): + sub = getattr(expr, attr, None) + if sub is not None and not isinstance(sub, (int, float, str, bool)): + _collect_from_primexpr(sub, out) + args = getattr(expr, "args", None) + if args is not None: + for a in args: + _collect_from_primexpr(a, out) + + +def _collect_op_refs(op: _hlir.Op, out: Set[str]) -> None: + for ba in op.buffer_args: + if isinstance(ba, str): + out.add(ba) + elif isinstance(ba, (_hlir.BufferSlice, _hlir.VramRegion, + _hlir.MramRegion)): + out.add(ba.parent) + for sa in op.scalar_args: + _collect_from_primexpr(sa, out) + if op.body is not None: + for inner in op.body: + _collect_op_refs(inner, out) + + +def _collect_reachable(ops: Iterable[_hlir.Op]) -> Set[str]: + out: Set[str] = set() + for op in ops: + _collect_op_refs(op, out) + return out + + +def run(mod: _hlir.HLIRModule) -> _hlir.HLIRModule: + """Drop buffers that no op references. Param buffers (kernel + interface) are always kept.""" + reachable = _collect_reachable(mod.ops) + keep = set(reachable) | set(mod.param_names) + dropped = [name for name in mod.buffers if name not in keep] + for name in dropped: + del mod.buffers[name] + return mod diff --git a/tilelang_tvm_compiler/doc/AI_AGENT_GUIDE.md b/tilelang_tvm_compiler/doc/AI_AGENT_GUIDE.md new file mode 100644 index 0000000..61f800f --- /dev/null +++ b/tilelang_tvm_compiler/doc/AI_AGENT_GUIDE.md @@ -0,0 +1,641 @@ +# tilelang_tvm_compiler — AI Agent Knowledge Base + +This file is intentionally written for AI agents (and humans) starting a new +session. It captures the non-obvious lessons accumulated while building the +TVM-frontend path from TIR → HLIR → PLENA ISA. **Read this before editing +kernels under `kernels/`, intrinsics, the lowering pass, or testbenches under +`transactional_emulator/testbench/`.** + +If you discover something the next agent will trip on, append it here. + +--- + +## 1. Pipeline at a glance + +> **Architecture note.** The old graph-IR frontend +> (`graph_ir.Graph`, `lift_from_raw_primfunc`, `split_lane_groups`, +> `materialize_to_primfunc`, `PlenaCodegen`) **has been deleted**. +> `frontend/pipeline.py` is now a stub that raises if called. The only +> active lowering chain is the **mid_ir pipeline**. Don't trust any +> doc/comment that still mentions `graph_ir`. + +The real driver is **`tilelang_tvm_compiler/pipeline.py : compile_kernel`**: + +``` + @T.prim_func (tilelang DSL) + │ 0. stmt prep + │ inline_let_stmts → lower_compound_fp_stores → hoist_float_constants + ▼ + raw tir.PrimFunc (FP literals hoisted to global.fpram 1-slot buffers) + │ 1. mid_ir pipeline (frontend/mid_ir/passes/, 9 passes) + │ infer_lane_axis — pick the lane axis (see §11) + │ fold — raw TIR → mid_ir.MidFunc dataclass tree + │ mark — tag each op with its lane-fusion role + │ split — lane blockIdx → (number, phase); grow + │ non-global buffers by a cluster outer dim + │ distribute_cluster — push CLUSTER axes inside unroll/pipeline loops + │ async_wrap — wrap can-async ops in Async regions + │ view — assign view_perm to every BufferRef; + │ substitute lane var on HBM refs + │ fuse — collapse each Async region → one MultiLaneOp + │ burn_view — bake view_perm into physical shape + indices + │ to_plena — MidFunc → HLIRModule (exits mid_ir domain) + ▼ + HLIRModule ← buffers + Op stream, no addresses + │ 1.5 dead_buffer_elim — drop buffers no HLIR op references + │ 2. AddressAllocationPass (address_alloc.py) + ▼ + HLIRModule + addresses ← per-buffer base address resolved + │ 3. IsaEmitterPass.run (isa_pass.py) — RegisterAllocator + shim + ▼ + ISA text (`*_generated_asm_code.asm`) +``` + +- **The mid_ir is a typed dataclass tree** (`mid_ir/ir.py : MidFunc`), + not TIR. Passes 1–8 are `MidFunc → MidFunc`; `to_plena` (pass 9) is the + single bridge out to `HLIRModule`. The only TIR-level walkers left are + the pass-0 stmt-prep steps and `infer_lane_axis` (which still inspects + raw TIR — see §11). +- `compile_kernel(prim_func, *, target, name, midir_dump_dir=..., + addr_config_override=...)` is the entry point. `midir_dump_dir` makes + `to_plena` write `.midir.txt` **and** `compile_kernel` write + `post_to_plena.hlir.txt` right after `to_plena` — both survive a later + pass failure, so they are the go-to debugging artefacts. +- `addr_config_override` lets a multi-kernel driver pin FPRAM/HBM bases + per kernel — used by `tvm_single_stream_block_test` to stitch the + MMDiT chain into one continuous ASM run. + +--- + +## 2. Hardware mental model (essential) + +### Memories + +| Storage | Layout | dtype (per `plena_settings.toml` ANALYTIC mode) | +|---|---|---| +| HBM | DRAM, large | **MX-FP8** (block=8, elem e4m3, scale e8m0) | +| MRAM | matrix-side SRAM | **bf16** (Plain Fp, exp=8, mantissa=7) | +| VRAM | vector-side SRAM | **bf16** | +| FPRAM | scalar FP scratch | bf16 | + +The MX-FP8 quantization on HBM is real and intentional. Outputs from a +fp32 reference will have ~5–15% relative error vs. simulator output once +they've round-tripped through HBM at magnitudes like 5–17. **This is not +a kernel bug.** If you suspect quantization noise, switch the +`plena_settings.toml` HBM types from `format = "Mx"` to `format = "Plain"` +bf16 and rerun — error should drop to ~0. + +### Tile sizes + +- `MLEN = 64` — full vector tile width. A "VRAM row" is MLEN elements wide. +- `HLEN = 16` (typical) — narrow head dim. Used for BTMM and per-head MM. +- `BLEN = 4` — systolic-array tile size (BLEN × BLEN per `M_MM_WO`). +- `LANE_COUNT = MLEN / HLEN = 4` — number of hardware lanes packed into + one VRAM row. + +### Matrix engines + +| Op | Hardware | Reduction dim | Output shape per head | Use | +|---|---|---|---|---| +| `plena.btmm` | `M_BTMM`/`M_BMM_WO` | **HLEN** | (MLEN, MLEN) | Q @ K^T (lhs (B,S,H,D), rhs in MRAM) | +| `plena.mm` | `M_MM`/`M_MM_WO` | **MLEN** | (MLEN, MLEN) | regular MM, single head | +| `plena.mm_slot` | `M_MM`/`M_MM_WO` (column-slot loop) | **MLEN** | (MLEN, hlen) | per-head narrow MM (P @ V, etc.) | + +**Crucial rule**: `M_MM 0, rs1, rs2` reads vector tile from VRAM (rs2) and +matrix tile from MRAM (rs1). So **A @ B → A in VRAM, B in MRAM** regardless +of "narrow vs wide". The runtime compiler enforces this in +`_compute_manager.py:54-55` (`ensure_value_tile_in_place(rhs, "mram")`). + +### Buffer layouts you will see + +| Layout name | Shape pattern | Where | +|---|---|---| +| BSHD | `(B, S, H, D)` | HBM tensors (canonical), DMA preserves it | +| BHSD | `(B, H, S, D)` | BTMM #1 output (head-major: head h's tile starts at `h * S * D`) | + +`BHSD` and `BSHD` differ in **whether heads are interleaved within a row +(BSHD packed-narrow)** or **stacked as separate row groups (BHSD)**. The +`_logical_2d` helper flattens shapes to (rows, cols) — for BHSD `(1, 4, +64, 64)` you get `(4*64, 64*64/64) = (256, 64)`. Picking head h's tile +out of a BHSD VRAM buffer means addressing at `base + h * MLEN * MLEN`. + +### Where quantization actually happens (verified in the sim) + +Common misconception: that the sim quantizes everywhere. It does not. + +- **`QuantTensor::quantize` is a TODO no-op** — it does not quantize. + Any VRAM tile written via `QuantTensor::quantize(t, ty)` keeps its + fp32 values; the `ty` is just a label. +- The real MX-E4M3 quantization happens in **`into_bytes`** (the + VRAM→HBM serialise path, `H_STORE_V`). HBM stores MX-E4M3. +- The HBM→VRAM path (`H_PREFETCH_V` / `transfer_mx_from_hbm`) *decodes* + MX back to fp32; it does not add a second rounding. +- Net effect: an intermediate tensor that a kernel writes to an HBM + scratch and the next kernel reads back is **MX-quantized exactly once** + (at the store). +- On-chip matmul / vector ops run in **fp32** in the sim (`to_kind(Float)`). +- The Rust `into_bytes` MX quantizer is line-for-line equivalent to the + Python `_mx_fp_quantize_hardware` (same `floor(log2(max)+1e-9)`, same + bias, same block grouping over the last dim). So a golden modelled with + `_mx_fp_quantize_hardware` matches the sim's HBM bytes — *provided* the + golden quantizes the tensor in its real staged 4D shape (block grouping + is over the last dim; quantizing a reshaped 2D view changes the block + boundaries). + +--- + +## 3. Intrinsics convention + +Every `plena.*` intrinsic has a fixed `operand_scopes` tuple in +`intrinsics.py`. The trailing `None` slots are *scalar* slots (length must +match exactly at codegen time, see `_verify_scopes` in `codegen.py:474`). + +### The `_at` family — per-row scalar tasks + +The "row scalar" ops come in two physical shapes: + +**FP-touching, signature `(vram_row, lane, mask)`** (3 scalars): +- `plena.row_reduce_max_at` / `plena.row_reduce_sum_at` +- `plena.row_sub_fp_at` / `plena.row_add_fp_at` / `plena.row_mul_fp_at` + +**VRAM-only, signature `(vram_row, mask)`** (2 scalars): +- `plena.row_exp_at` + +**Mask semantics (KEY):** +- `mask = 0` → unmasked path: lowering emits ` ..., 0` (no V_MASK + setup, no tail clear). VRAM addressing uses `vram_row`. FP addressing + uses `lane * fp_buf.shape[-1] + vram_row`. +- `mask != 0` (literal or PrimExpr) → masked: emits `C_SET_V_MASK_REG mask`, + uses `..., 1` flag, **always emits a tail clear `C_SET_V_MASK_REG 0`** + so subsequent ops see V_MASK=0. + +The unification is implemented by `mask_static_zero` in +`isa_pass.py:_emit_row_scalar_op_at`. When mask is the literal 0, V_MASK +emission is skipped — but it MUST be a literal int / IntImm. A PrimExpr +that happens to evaluate to zero will not trigger the optimisation. + +**When to use mask=0 vs mask=1<` per its own iter; if that exceeds 10 000, **unroll the outer + loop** with `T.unroll(N)`. Unrolled iters don't appear as a hardware + loop, so they don't accumulate. +- Optimisations that lowered the row body include + `_emit_fp_scalar_op_at`'s row-expr CSE (one materialise, S_ADDI per + buffer) and the materializer's Add/Mul folds. + +--- + +## 6. FP buffer layout convention (FlashAttention kernel) + +FP buffers in `flash_attention_min.py` are declared as 1D per-lane +fragments `(rows,)` and the compiler expands each to `(lane_count, rows)` +inside the lane group. The address allocator places them sequentially +starting at `FPRAM_USER_BASE = 32`, each slot `lane_count * rows` wide +(= `4 * 64 = 256` for the typical config). + +Current FP buffers (per the actual HLIR, in declaration order): + +``` +M_OLD addr = 32 + 0 * 256 = 32 +M_CURR addr = 32 + 1 * 256 = 288 +M_RES addr = 32 + 2 * 256 = 544 +L_OLD addr = 32 + 3 * 256 = 800 +L_NEW addr = 32 + 4 * 256 = 1056 +P_SUM addr = 32 + 5 * 256 = 1312 +L_INV addr = 32 + 6 * 256 = 1568 +``` + +Per-lane addressing within an FP buffer: element `[lane, row]` is at +offset `base + lane * rows + row`. + +**Scale / M_init / L_init are no longer declared FP buffers.** The +kernel embeds the literals directly as `T.float16(...)` +(`scale_val = 1/sqrt(d_k)`, `-1.0e4` for the M init, `0` for L). The +`hoist_float_constants` pre-pass synthesises a 1-slot `global.fpram` +buffer per unique value (e.g. `__const_f16_0p25`, +`__const_f16_neg10000`), and `test_helper` auto-preloads them from the +`--dump-buffer-addrs` JSON. So the kernel no longer needs the testbench +to preload an `active_lane` segment of any user FP buffer — every FP +buffer is written before it is read (M_OLD/L_OLD reset from the hoisted +consts at the top of each q_block). + +--- + +## 7. FlashAttention kernel structure (current state) + +`flash_attention_min.py` produces this op nest (HLIR view, with +`head_count=8, num_q_blocks=2, num_kv_blocks=2`). All heads are run — +the per-`by_phase` loops below cover the full `lane_count` (= MLEN/HLEN): + +``` +for q_block in [0, 2): ; outer Q loop + dma Q[q_block] -> Q_sh + for row in [0, 64): ; zero running output + v_zero O_loc[row] + for row in [0, 64): ; reset FP state (ALL lanes) + for by_phase in [0, 4): + fp_copy_at __const_f16_neg10000 -> M_OLD + for by_phase in [0, 4): + fp_zero_at L_OLD + for kv_block in [0, 2): ; KV loop + dma K[kv_block] -> K_sh + dma V[kv_block] -> V_sh + btmm Q_sh @ K_sh -> S_loc ; per-head Q @ K^T + for row in [0, 64): + for by_phase in [0, 4): + row_mul_fp S_loc *= __const_f16_0p25 ; 1/sqrt(d_k) + for by_phase in [0, 4): + fp_copy_at M_OLD -> M_CURR + for row in [0, 64): + for by_phase in [0, 4): + row_reduce_max_at S_loc -> M_CURR + for row in [0, 64): + for by_phase in [0, 4): + fp_sub_at M_OLD - M_CURR -> M_RES + for by_phase in [0, 4): + fp_exp_at M_RES -> M_RES + for by_phase in [0, 4): + row_sub_fp S_loc -= M_CURR + for by_phase in [0, 4): + row_exp S_loc = exp(S_loc) + for by_phase in [0, 4): + fp_zero_at P_SUM + for row in [0, 64): + for by_phase in [0, 4): + row_reduce_sum_at S_loc -> P_SUM + for row in [0, 64): + for by_phase in [0, 4): + fp_mul_at L_NEW = L_OLD * M_RES + for by_phase in [0, 4): + fp_add_at L_NEW += P_SUM + for by_phase in [0, 4): + row_mul_fp O_loc *= M_RES + for by_phase in [0, 4): + fp_copy_at M_CURR -> M_OLD + for by_phase in [0, 4): + fp_copy_at L_NEW -> L_OLD + for by_phase in [0, 4): ; per-head P @ V + matmul S_loc[by_phase] @ V_sh[..by_phase..] -> PV_loc[..by_phase..] + for row in [0, 64): + v_add O_loc += PV_loc + for row in [0, 64): ; finalize: O /= L_new + for by_phase in [0, 4): + fp_reci_at L_NEW -> L_INV + for by_phase in [0, 4): + row_mul_fp O_loc *= L_INV + dma O_loc -> O_hbm[q_block] +``` + +- **Softmax is run on every head** — each `for by_phase in [0, 4)` walks + the full lane group. (An earlier version ran only a single + `active_lane`; that is no longer the case.) +- **P @ V uses `plena.matmul`** (`M_MM` / `M_MM_WO`), one issuance per + head — *not* `mm_slot`. +- `M_OLD` / `L_OLD` are **re-initialised inside the q_block loop** from + the hoisted consts `__const_f16_neg10000` / a zero store, so a + multi-q-block run resets cleanly per tile. + +### Layouts + +- `S_loc` is **BHSD** (BTMM #1's natural output): each VRAM row is one + head's full mlen-wide score row. +- `O_loc` / `PV_loc` are **BSHD**: heads occupy column slots within a + row. `matmul` writes head h's hlen columns at the matching slot. + +### What's intentionally NOT done yet + +- **Causal mask** — needs a preloaded VRAM `mask` buffer + `v_add` + before softmax. Mirror `attention.py`'s approach. +- **Batch > 1**. + +--- + +## 8. Testbench conventions +(`transactional_emulator/testbench/tvm_flash_attention_min_test.py` and +similar `tvm_*_test.py` files at the testbench root.) + +- The build recipe `just build-emulator-debug ` looks for the script + at `transactional_emulator/testbench/_test.py` (top level), or for + a hard-coded list of names like `tvm_online_softmax_min` it routes to + `transactional_emulator/testbench/tile_tensor_kernel_programs/.py`. +- Use a robust repo-root finder (walk up `_THIS_FILE.parents` for + `.venv-tvm` and `compiler/`) — depth depends on which subdir the script + ends up in. +- The compile happens via `subprocess.run(VENV_TVM_PYTHON, "-m", + "tilelang_tvm_compiler", "compile", ...)` because TVM only lives in + `.venv-tvm`. The test script itself runs in the main `.venv` (Python 3.12). +- `create_sim_env` lays down `Q_hbm.pt`, `K_hbm.pt`, `V_hbm.pt`, `O_hbm.pt`, + `golden_result.txt`, `fp_sram.bin`, `int_sram.bin`. `create_mem_for_sim` + produces `generated_machine_code.mem` and `hbm_for_behave_sim.bin`. +- `comparison_params.json` controls the post-run diff — set `num_rows`, + `num_batches`, `elements_per_batch`, `row_dim` so they tile the flat + golden correctly. + +### Golden gotchas + +- The kernel runs softmax on **all heads** now, so the golden is plain + per-head `softmax(scaled_score) @ V` for every head. (An earlier + version ran a single `active_lane` and the golden had to mirror that + with `score @ V` for non-active heads — that is no longer needed.) +- `torch.softmax(x, dim=-1)` is mathematically equivalent to the kernel's + online `max → sub → exp → sum → divide` chain. Either works for the + golden; the online form is only needed if you want to model the f16 + truncation of each FPRAM scalar step. + +### Golden comparison — the biggest time sink (read this) + +The golden *comparison* path has bitten us harder than any kernel bug. +Two distinct bugs, both in how the golden reaches `check_mem.py`: + +1. **`golden_result.txt` was written with `%.2f`** (2 decimal places). + `check_mem.parse_golden_output` then parsed that text back as the + golden — so a true golden of `-0.0083` became `-0.01`, and small + values showed a fake ~0.8 relative error. Fix: `create_sim_env.py` + now also writes a lossless `golden_output.pt`; `parse_golden_output` + prefers the `.pt` and only falls back to the text. + +2. **`check_mem.py` down-cast the golden to `bfloat16`** before + comparing ("for fair comparison with hardware"). bf16 has 7 mantissa + bits — rounding the golden first inflates `|err|/|golden|` for small + values. Fix: keep the golden `.float()`; only the *simulated* side + stays bf16 (that IS the VRAM storage). + +If a comparison shows a wall of fake error on small magnitudes, suspect +the comparison path before the kernel. Verify `golden_output.pt` exists +in `build/` and that `parse_golden_output` is reading it. + +### Diagnosing a chained-kernel (SSB / MMDiT) failure + +When a multi-kernel chain's match rate collapses but each kernel passes +its own standalone test, the bug is almost always a **kernel-to-kernel +hand-off**, not a kernel. The proven isolation technique: + +- Give the suspect kernel's input an **independent HBM tensor** (its own + address, role `"input"`, preloaded) instead of aliasing the upstream + kernel's output `"scratch"` buffer. The upstream kernel still runs but + can no longer overwrite the isolated input. +- Feed that independent input either a clean random tensor *or* the + upstream kernel's own golden output. +- If the kernel now passes → the bug is the upstream→this hand-off (the + sim writing the shared HBM wrong, or a layout mismatch). If it still + fails → the kernel itself, in the chain environment, is wrong. +- **Pitfall**: do NOT just preload the shared `scratch` buffer — the + upstream kernel will overwrite it at runtime. The input must be a + *separate* address the upstream kernel never writes. + +--- + +## 9. Lessons from previous failure modes + +These are real bugs that cost time during development. If you reintroduce +any of these, the test will fail in confusing ways: + +- **Don't reuse one scalar for two different addressings**. The pre-fix + `_at` ops took a single `row` scalar and used it as both VRAM row index + and FP element offset. For multi-lane FP state with single-lane VRAM + data (BSHD), these need to differ. Hence the current `(vram_row, lane, + mask)` triple. + +- **Don't put all kv_blocks in one hw loop body**. With softmax body ~145 + instr × 64 row iters = 9 280 per kv iter, an outer `T.serial(num_kv)` + loop would multiply that into one hw-loop iter and hit the 10 000 cap. + Use `T.unroll`. Same applies to q_block. + +- **`M_OLD` / `L_OLD` must be reset *inside* the q_block loop**. After + the first q_block runs, `M_OLD` is overwritten by `fp_copy(M_curr → + M_old)` at the end of every row, so the next q_block would start from + stale state. The current kernel handles this correctly: it re-inits + `M_OLD` / `L_OLD` from the hoisted consts (`__const_f16_neg10000` / + zero) at the top of each q_block. Do not move that reset outside the + loop or back to a one-time preload. + +- **Don't use `from __future__ import annotations`** in kernel files. (See + §4.) + +- **`mm_slot` LHS check used to require single-tile**. Was relaxed to allow + multi-head LHS via `lhs_row_offset` scalar. Existing callers (tiled_mm) + pass `0` as the new first scalar. + +--- + +## 10. Useful one-liners + +Recompile a kernel + dump HLIR (from repo root): + +``` +PYTHONPATH=compiler LD_LIBRARY_PATH= .venv-tvm/bin/python -m tilelang_tvm_compiler \ + compile \ + --kernel "tilelang_tvm_compiler.kernels.flash_attention_min:make_flash_attention_min" \ + --kernel-kwargs "rows=64,hlen=16,lane_count=4,active_lane=2,num_kv_blocks=2,num_q_blocks=2" \ + --asm-name flash_attention_min --mlen 64 --btmm-hlen 16 --stage-output O_hbm \ + --dump-hlir transactional_emulator/testbench/build/flash_attention_min.hlir.txt \ + > transactional_emulator/testbench/build/flash_attention_min_generated_asm_code.asm +``` + +Run all relevant unit tests: + +``` +cd compiler && PYTHONPATH=. LD_LIBRARY_PATH= ../.venv-tvm/bin/python \ + tilelang_tvm_compiler/tests/test_expr_materializer.py +# then test_narrow_mm_emitter.py, test_fpram_ops.py, test_loop_dma.py, ... +``` + +Measure inner-loop body size (in lines, ≈ static instr) to verify the +10 000-cap budget: + +``` +awk '/; for row in/{f++} f==1 && /C_LOOP_START.*64/{n=NR+1; next} \ + f==1 && /C_LOOP_END/ && n {print "row body:", NR-n; exit}' \ + transactional_emulator/testbench/build/.asm +``` + +Show the high-level ISA structure (loops + matmul ops) without flooding +the terminal: + +``` +grep -nE '^; for |^C_LOOP_(START|END)|^M_BTMM|^M_BMM_WO|^M_MM\b|^V_ADD_VV' \ + transactional_emulator/testbench/build/.asm +``` + +--- + +## 11. Lane fusion — the core mechanism + +Multi-lane fusion is the heart of the frontend: `MLEN / HLEN` hardware +lanes are packed into one VRAM row, and one multi-lane HW op fires once +for all lanes instead of looping. Getting the **lane axis** right is +what makes this work. + +### Lane axis is picked by IR-node analysis, NOT string matching + +`infer_lane_axis.py` decides which `blockIdx.*` grid var is the lane +axis. The judgment is done on the **TIR AST**, not on text: + +```python +# infer_lane_axis._collect_bare_index_var_names +def visit(node): + if isinstance(node, tir.BufferLoad): + for idx in node.indices: + if isinstance(idx, tir.Var): # ← the actual test + found.add(idx.name) +stmt_functor.post_order_visit(func.body, visit) +``` + +The rule: a grid var is a **lane candidate** iff it appears as a +**bare index slot** — `BufferLoad.indices[i]` is *exactly* a `tir.Var` +node — somewhere in the body, AND its extent is divisible by LANE. + +- `Q_hbm[0, q_block*rows, by, 0]` — the `by` slot is a naked `tir.Var` + node → `by` is a lane candidate. +- `q_block * rows` — that's a `tir.Mul` node; `q_block` is *inside* it, + not bare → `q_block` is an outer control loop, **not** a lane axis. + +A plain string search for `"by"` could never tell those apart — both +texts contain `by`/`q_block`. Only inspecting the IR node *type* at the +index position (`isinstance(idx, tir.Var)`) distinguishes a per-lane +index from an arithmetic offset. This is the precise sense in which +lane fusion is **value/ref-based, not string-based**. + +Resolution: 0 candidates → no lane axis (cluster pipeline skipped); +1 → picked; 2+ → `InferLaneAxisError`, author must set +`T.func_attr({"plena.lane_axis": ""})`. A manual attr always wins. + +### How the lane axis flows through the rest of the pipeline + +- **split** turns the lane blockIdx into `(number, phase)` and grows + every non-global buffer by a cluster outer dim, so per-lane data has + somewhere to live. +- **mark** tags each op with its lane-fusion role; **async_wrap** groups + can-async ops; **fuse** collapses each Async region into a single + `MultiLaneOp` — that is the actual "fire once for all lanes" step. +- **view** assigns a `view_perm` to every `BufferRef` and substitutes + the lane var on HBM refs; **burn_view** bakes that permutation into + the physical shape + index tuples. + +### Bare-Var detection vs. lane-var substitution — two different things + +`infer_lane_axis`'s **bare-`tir.Var`** test only *picks which axis is +the lane*. It is deliberately strict: `by + 8` (a `tir.Add`) is NOT a +lane candidate, only a naked `by` is. This keeps the axis choice +unambiguous. + +Once the axis is chosen, **`view._subst_lane_var` is recursive** and +handles the lane var wherever it sits inside an index expression — not +just bare: + +```python +def _subst_lane_var(idx, ctx): + if isinstance(idx, VarRef) and idx == ctx.original_var: + return phase + number * cluster_count # the substitution + if isinstance(idx, dict): # compound node (add/mul/…) + return {"op": idx["op"], + "args": [_subst_lane_var(a, ctx) for a in idx["args"]]} + return idx +``` + +So an HBM index like `O_hbm[..., by + o_head_offset, ...]` works: the +recursion descends into the `add`, finds the `VarRef(by)` inside, and +rewrites just that leaf — the `+ o_head_offset` is preserved. This is +what `flash_attention_min.py` relies on to write its output into a +head-slice of a wider tensor. + +Summary: **lane-axis selection** = strict bare-Var only; **lane-var +substitution** = recursive, accepts the lane var combined with offsets +(`by + c`, `by * c`, …). + +So "multi-lane fuse" is not one pass — it's the chain +`infer_lane_axis → split → mark → async_wrap → fuse → view → burn_view`, +all operating on typed mid_ir nodes. diff --git a/tilelang_tvm_compiler/doc/LOOP_REGISTER_ALLOC.md b/tilelang_tvm_compiler/doc/LOOP_REGISTER_ALLOC.md new file mode 100644 index 0000000..5f6e42d --- /dev/null +++ b/tilelang_tvm_compiler/doc/LOOP_REGISTER_ALLOC.md @@ -0,0 +1,130 @@ +# 循环变量寄存器分配 — HLIR liveness pass 设计 note + +> 状态:设计定稿,未实施。 + +## 1. 背景与问题 + +当前 GP 寄存器分配**全部在 emit 阶段边走边分**,散在三处: + +* `isa_pass._emit_for` — 循环的 `gp_loop`(硬件计数器)+ idx。 +* `expr_materializer` — 物化 `tir.PrimExpr` 的临时 GP。 +* `isa_emitter.emit_*`(40+ 个)— 每条指令的 scratch GP。 + +三者**抢同一个 16-GP 池**(gp0 保留),靠 `pin_gp` / auto-spill 临时协调。后果: + +* **Bug 2 类**:caller(如 `_resolve_offset`)对 materializer 返回的寄存器 + 做 `pin/unpin/free`,而该寄存器可能是 per-op idx 缓存拥有的 → + 双重 pin、提前 unpin、潜在 double-free。根因是「长生命周期的循环 + 变量寄存器」和「短生命周期的临时值」在同一个裸池里没有边界。 +* **深嵌套 GP 耗尽**:每层 C_LOOP 占 `gp_loop`(+idx),五层嵌套 + + matmul 的 7 个 scratch → 15 个 GP 不够,`RegisterExhausted`。 + +## 2. 为什么不做「一个 pass 全包」 + +HLIR liveness **看不到** emit 阶段的临时 GP 需求: + +* `expr_materializer` 物化一个 PrimExpr 树要几个临时 GP,取决于树形 + 和遍历顺序 —— 是 emit 时才定的。 +* `emit_matmul_general` 内部 7 个 scratch、`emit_row_operation` 5 个 … + 这些是「一个 HLIR leaf op 降成几十条 ISA」时内部的事,HLIR 上不可见。 + +所以 HLIR 上能精确分析的,只有**显式写出来的循环变量**(`for` op 的 +`loop_var`)—— 它的 def / use / kill 在 HLIR 上一清二楚。 + +## 3. 方案 — 分层 + +| 层 | 谁分配 | 何时 | +|----|--------|------| +| **循环变量寄存器**(`gp_loop` + idx,每层 C_LOOP 一组) | 新的 **HLIR liveness pass** —— 全局算活跃区间,提前定死 GP 号 | address_alloc 之后、isa_pass 之前 | +| **op 内部临时值**(matmul scratch、materializer 中间值) | emit 阶段局部分配器 —— **不变**,但 GP 池被缩小 | isa_pass 内,照旧 | + +**关键机制**:`RegisterAllocator.__init__` 已有 `gp_reserved` 参数 +(register_alloc.py:76)。liveness pass 算出「同一时刻最多 N 个循环 +变量寄存器活跃」,并给每个循环变量定一个具体 GP 号;这些 GP 号进 +`gp_reserved`。emit 阶段的 `allocate_gp` 从此**拿不到**这些 GP —— 循环 +变量寄存器和临时值物理隔离,不再同池抢。 + +`_emit_for` 不再 `allocate_gp` 循环寄存器,改成**读 pass 标在 op 上的 +GP 号**。emit 阶段的临时分配只在「剩余 GP」里跑。 + +## 4. liveness 分析 + +HLIR 是线性 op 流(`HLIRModule.ops: List[Op]`),`for` op 带 `body` +子列表。循环变量的生命周期由结构决定,不需要数据流不动点迭代: + +* **def**:`for` op 进入处。 +* **use**:body(含嵌套)里任何 `scalar_args` / region `starts` / + `BufferElement.indices` 的 PrimExpr 树中出现该 `loop_var`。 +* **kill**:`for` op 的 body 结束。 + +所以「活跃区间」= 循环的词法嵌套范围。**同一时刻活跃的循环变量数 += 当前嵌套深度**。算法: + +1. 递归遍历 `mod.ops`,维护一个「当前嵌套的 `for` 栈」。 +2. 进入一个 `for` → 它的循环变量在栈上,需要: + * 1 个 `gp_loop`(C_LOOP 硬件计数器,恒占 GP)。 + * idx:若该 `for` 是 `unroll` → idx 是编译期常量,**0 GP** + (已实现,`_emit_for` unroll 分支绑 `IntImm`); + 若是 `serial`(C_LOOP)→ idx 需要 1 个寄存器位 + (GP 或 IntRAM slot —— 见第 6 节)。 +3. 退出 `for` → 释放。 +4. 全程峰值 = 需要预留的循环寄存器总数。 +5. 按嵌套深度给每层一组固定 GP 号(外层先分,线性扫描即可 —— 区间 + 是严格嵌套的,没有部分重叠,不需要图着色)。 + +「严格嵌套、无部分重叠」是这个分析能简单的根本原因:循环区间要么 +包含、要么不相交,永远不会交叉。线性扫描 / 栈深度就够,不需要通用 +寄存器分配器。 + +## 5. 改动清单 + +| 文件 | 改动 | +|------|------| +| 新建 `loop_register_alloc.py` | liveness pass:递归遍历,算每层 `for` 的循环寄存器,把 GP 号写到 `op.annotations["loop_gp"]`(`gp_loop` + 可选 idx_gp)。算出预留集合。 | +| `hlir.py` | `Op.annotations` 约定新键 `loop_gp` —— 无需改结构。 | +| `pipeline.py` | address_alloc 之后插入 `loop_register_alloc.run(mod)`;它返回预留 GP 集合,传给 `IsaEmitterPass` 构造的 `RegisterAllocator(gp_reserved=...)`。 | +| `isa_pass._emit_for` | 不再 `allocate_gp`/`claim_idx_slot` 循环寄存器,改读 `op.annotations["loop_gp"]`。`pin` 仍做(防 emit 临时分配器误用),但 GP 号是 pass 给的。 | +| `expr_materializer` / `emit_*` | **不动** —— 它们的 `allocate_gp` 照旧,只是池子小了。 | + +不动:`register_alloc.py` 的核心(只多用 `gp_reserved`)、所有 mid_ir +pass、`address_alloc`、`dead_buffer_elim`、`loop_interchange` / +`fuse_adjacent_loops`。 + +## 6. 待定 — idx 放 GP 还是 IntRAM + +`serial` 循环的 idx,两个选择: + +* **idx 进预留 GP**:body 里用 idx 零 `S_LD_INT`(materializer 走 + `binding is int` 快路径)。代价:每层 serial 多预留 1 个 GP。 +* **idx 进 IntRAM**:每层只预留 `gp_loop` 1 个 GP,idx 走 `S_LD_INT` + 重载(已有 per-op 缓存把同一 op 内的重载压成 1 次)。 + +liveness pass 算出峰值嵌套深度后,可以**动态决定**:深度浅 → idx 全 +进 GP(最快);深度深到 GP 不够 → 外层 idx 落 IntRAM、内层进 GP。这 +正是「全局视野」带来的好处 —— emit 阶段做不到,pass 能。 + +第一版可以先简单:idx 一律 IntRAM(保守、不会 GP 耗尽),把「按深度 +混合」留作 pass 内的后续优化。 + +## 7. 这样为什么根治 Bug 2 + +Bug 2 的根源:循环变量寄存器和 materializer 临时值在同一个裸 GP 池, +caller 不知道手里的寄存器是不是别人(idx 缓存)拥有的,乱 `pin/free`。 + +分层后:循环变量寄存器在 `gp_reserved` 里,emit 阶段的 `allocate_gp` +**物理上拿不到它们**。materializer / `_resolve_offset` 操作的永远是 +「临时池」里的寄存器,那些确实是 caller 拥有、可以自由 pin/free 的。 +两类寄存器物理隔离 → 不存在「谁拥有」的歧义 → Bug 2 不可能发生。 + +per-op idx 缓存(已实现)那个 `owns_register=False` 的别扭设计,分层 +后也可以简化 —— idx 寄存器是预留的,本来就不该被 caller 的 `release` +碰。 + +## 8. 风险 + +* idx 全进 IntRAM(第一版保守选择)→ loadback 仍在,但 per-op 缓存 + 已把同-op 重复压掉。可接受,后续按第 6 节优化。 +* emit 临时池变小:预留掉循环寄存器后,深嵌套 kernel 留给 emit 的 GP + 可能不足 → `emit_matmul_general` 的 7 个 scratch 可能不够。liveness + pass 应在算出预留集合时**检查** `15 - 预留数 >= 单个 op 最大 scratch + 需求`,不满足就报错(明确,而非 emit 阶段 auto-spill 到崩)。 diff --git a/tilelang_tvm_compiler/doc/NESTED_CLUSTER_GQA.md b/tilelang_tvm_compiler/doc/NESTED_CLUSTER_GQA.md new file mode 100644 index 0000000..9b7c727 --- /dev/null +++ b/tilelang_tvm_compiler/doc/NESTED_CLUSTER_GQA.md @@ -0,0 +1,210 @@ +# 嵌套 Cluster — GQA 支持设计 note + +> 状态:设计已定稿,未实施。GQA kernel 已写好 +> (`kernels/flash_attention_gqa_min.py`),作为端到端验证用例。 + +## 1. 背景与目标 + +GQA(grouped-query attention)里 KV head 数 < Q head 数:`group_size` +个 Q head 共享一个 KV head,`kv_head_count = head_count // group_size`。 + +Q head `by` 映射到 KV head 用 **`by % kv_head_count`**(取模,不是 +floordiv): + +* PLENA ISA **没有整数除法 op**,只有取模 → `by // group_size` 无法 + lower 到硬件。 +* `%` 给出 *interleaved* 的 group 布局:Q head + `h, h+kv_head_count, h+2*kv_head_count, ...` 共享 KV head `h`。 +* 调用方需把 Q head 在 HBM 里按 interleaved-by-KV-head 排布。 + +这要求 lane 轴 `by` 被切成 **两层嵌套 cluster**,而当前体系只支持 +**单层 cluster**。 + +### 目标层级 + +`by`(head_count 个)被嵌套切分,产出三个轴: + +| 轴 | 角色 | kind | +|----------|-----------------------------------|-----------| +| `by_o` | 外层 cluster 的 grid(number) | BLOCK_IDX | +| `by_i_o` | 外层 cluster 的 phase;同时是内层 cluster 的 grid(number) | CLUSTER | +| `by_i_i` | 内层 cluster 的 phase | CLUSTER | + +`by_i_o` 的"双重身份"是**循环嵌套位置性**的(它在 `by_i_o`-cluster +体内、`by_i_i`-cluster 体外),不编码进 kind —— 它和 `by_i_i` 都是 +CLUSTER。 + +### op 的层级归属(GQA kernel) + +| op | 嵌套在 | 原因 | +|-----------------------------|---------------|------| +| K/V DMA(`by % kv_head_count`) | 只被 `by_o`(外层)| KV head 只随外层变,group 内共享 | +| Q DMA / BTMM(`Q@K^T`) | 被 `by_i_o`(外层 cluster)| Q 随完整 `by`,要覆盖整个 group | +| per-lane softmax / 标量 op | 被 `by_i_i`(内层)| 每个 Q head 各自的 softmax 状态 | + +`Q_sh` 的 s 维:`4 → 8`(hardware_lane_count × 第二层 cluster)。 + +## 2. 核心设计决定 + +| 决定 | 选择 | 理由 | +|------|------|------| +| `BufferDef.cluster_dim` | **不动**,保持单值 `int` | 走"乘积单维",砍掉 `to_plena` 等 40+/70+ 处级联改动 — 最大的难度来源被绕过 | +| buffer growth | **A:乘积单维** | `Q_sh [64,16] → [8,64,16]`,`cluster_dim=0`。两层之分不进 buffer 形状 | +| `by_i_o` kind | CLUSTER | 双重身份靠循环位置表达,不进对象状态 | +| op 层级归属判定 | **async_wrap** 算,不是 fuse | async_wrap 跑在 view 之前,看到的索引是干净的 `by % N`;它本就管同步域 | +| fuse 角色 | 只做结构折叠 + 提升 | 不做语义推断,只读 async_wrap 标好的位图 | + +**"cluster 列表"是这个方案的本质** —— 把到处隐含的单层 cluster 变成 +一等公民的有序列表: + +* 编译期配置:`cluster_counts` 每条目变成 `List[int]`(外→内)。 +* 运行结构:fuse 的 `cluster_stack` / `MultiLaneOp.cluster_axis_names`。 +* op 归属位图:`Async.trivial_levels` / `MultiLaneOp` 上的同名字段。 + +只有 `cluster_dim` 不列表化(故意)。 + +## 3. 改动方案 — 四节 + +依赖顺序:IR 字段 → split → async_wrap → view → fuse。 + +### 节 0:IR 字段 + +`ir.py`: + +* `Async` 新增 `trivial_levels: List[bool]`(外→内,每层 cluster 一个 + bool;`True` = 该层对这个 op trivial)。 +* `MultiLaneOp` 新增同名字段(从 `Async` 透传)。 +* `MidFunc.cluster_counts` 语义变更:`List[int]` → `List[List[int]]` + (每个 lane 轴一个内层列表)。 + +### 节 1:split — 嵌套切分 + +文件:`passes/split.py` + +* **`cluster_counts` 嵌套化**:每条目 `int → List[int]`(外→内)。 + GQA:`lane_axes=["by"]`, `cluster_counts=[[kv_head_count, group_size]]`。 + 单层 kernel 写 `[[4]]`(或保留 `int` 自动包成单元素 list 兼容)。 +* **`_split_once(axis, count)`**:抽出一个纯函数,把一个轴切成 + `(number, phase)` 对,**不依赖** `lane_axes` / kind 检查。 +* **递归切**:`_split_or_walk_parallel` 命中 lane 轴时,拿到该轴的 + count 列表,对列表 `fold`:第一次输入用户的 BLOCK_IDX 轴,后续每次 + 输入上一次产出的 phase 轴。`splittable_kinds` 闸门**只在最外层入口 + 判一次**;内层对 phase 轴的递归调用直接走 `_split_once`,不过 kind + 检查(因为 phase 轴是 CLUSTER,不在 splittable_kinds 里)。 + 嵌套结构:`by_o.body=[by_i_o]`, `by_i_o.body=[by_i_i]`, + `by_i_i.body=inner_body`。 +* **buffer growth**:现有 line 376-383 的乘积逻辑**几乎不动**,只需 + flatten 一层嵌套列表: + + ```python + cluster_total = 1 + for axis_counts in cluster_counts: # axis_counts = [kv_head_count, group_size] + for c in axis_counts: + cluster_total *= c + grown[buf.name] = _grow_buffer(buf, cluster_total) + ``` + + `_grow_buffer` 本身一行不改:仍 `shape=[cluster]+shape`, + `cluster_dim=0`。`Q_sh → [head_count,64,16]`。 + +### 节 2:fuse — 嵌套 + DMA 提升 + +文件:`passes/fuse.py` + +* **`cluster_stack` 递归 push — 零改动**。`_walk` 每进一层 CLUSTER 就 + `cluster_stack + [_ClusterAxis(...)]`;嵌套后走到内层体内自然是 + `[by_i_o_axis, by_i_i_axis]` 两项。 +* **`_fuse_async` 读 `trivial_levels`**:不再自己扫索引反推 trivial。 + 直接读 `Async.trivial_levels`,透传到 `MultiLaneOp`。`MultiLaneOp` + 仍带**全部层** `cluster_axis_names` + `trivial_levels` 位图。 + `dim_map` 维持 `[cdim]*n_axes`(buffer 走 A,两层物理维都是 0)。 +* **提升**:`_walk` 处理 CLUSTER 节点时,收集 body 里"内层全 trivial" + 的 MultiLaneOp,移到该 cluster 轴的**兄弟位置(前面)**。K/V DMA + 因此只被 `by_i_o` 包着。 + **关键约束**:K/V DMA 在 `for kv_block` 内,提升出内层 cluster 但 + **不能**提出 `kv_block` 循环(每个 kv block 的 K/V 不同)。提升落点 + 是「最内的非 trivial 包裹」—— 内层 cluster 之外、`kv_block` 之内。 +* **纪律**:fuse **不碰索引**。识别 `by % N` 只在 async_wrap 做一次 + (趁索引干净)。`_match_lane_composite` 维持原状只认 lane 形式。 + 不要在 fuse 里 match view 改写过的脏 mod 表达式。 + +### 节 3:async_wrap — 标每层 trivial 位图 + +文件:`passes/async_wrap.py` + +* **`_walk` 携带 cluster 轴列表**:`in_cluster: bool` → + `cluster_axes: List[ParallelAxis]`(外→内的 CLUSTER 轴)。进 CLUSTER + 节点时 append。 +* **包 Async 时算位图**:async_wrap **不碰索引**(现状 line 25-27), + 所以此刻 K/V DMA 的索引还是干净的 `by % kv_head_count`,Q/BTMM 是 + 裸 `by`。对每个 can_async op、对 `cluster_axes` 每一层 `ax`: + * 索引里 `by` 出现在 `mod(by, N)` 内 → 只依赖外层 → 内层 trivial。 + * 索引里 `by` 裸用 / 其它形态 → 两层都非 trivial。 + * 产出 `trivial_levels`:K/V DMA = `[False, True]`, + BTMM/softmax = `[False, False]`。 +* `trivial_levels` 写到 `Async` 上,随后透传到 `MultiLaneOp`(fuse)。 + +### 节 4:view — ctx 栈化 + +文件:`passes/view.py` + +**这是收尾里最实的一块** —— view 不是零改动。 + +* 现状 `_walk`(line 451)进 CLUSTER 节点造 `new_ctx` **替换**传入的 + `ctx`(line 505)。嵌套两层时外层 ctx 被丢弃 → 出错。 +* **`_ClusterCtx` 单值 → ctx 栈**:`List[_ClusterCtx]`(外→内)。进 + CLUSTER 节点时 append 而非替换。 +* **`_rewrite_lane_ref`**(line 211):`new_indices = [ctx.phase_var] + + indices` 只 prepend 一个 phase。改成 prepend 栈里**所有层**的 + phase(non-global buffer `Q_sh` 已 grow 成 `[8,64,16]`,需两个 phase + 索引才对得上 rank)。 +* **`_subst_lane_var`**(line 122):只认 `ctx.original_var` 一个,把 + `by` 换成单层 `phase+number*count`。改成把 `by` 换成**两层合成式**。 + 递归不挑 op(line 134-138),`by % N` 的 `mod` 节点会被下钻,里面的 + `by` 被替换 —— 这点本就成立。 +* `trivial_levels` 位图在 view **之前**就由 async_wrap 定死,view 改的 + 是索引的值、不碰位图 —— 位图安全。 + +## 4. 改动总览表 + +| 节 | pass / 文件 | 改动量 | 说明 | +|----|-------------|--------|------| +| 0 | `ir.py` | 小 | 2 个新字段 + `cluster_counts` 类型 | +| 1 | `split.py` | 中 | `_split_once` + fold 递归;乘积 flatten 一层 | +| 3 | `async_wrap.py` | 中 | `_walk` 带轴列表;算 `trivial_levels` | +| 2 | `fuse.py` | 中 | `cluster_stack` 零改;`_fuse_async` 读位图 + 提升落点判断 | +| 4 | `view.py` | 中 | `_ClusterCtx` 栈化;多层 prepend;两层合成替换 | + +**不动**:`BufferDef.cluster_dim`(保持单值 int),`to_plena.py` +(70+ 处读 `cluster_dim`),`_grow_buffer`(一行不改)。 + +## 5. 风险点 + +* **空 cluster body**:提升 K/V DMA 出内层后,若某 cluster body 所有 op + 都被提走 → body 空。`distribute_cluster` / `to_plena` 应 graceful + 容忍空 `ParallelAxis.body`(生成空循环或跳过),不报 unhandled。 + **GQA kernel 实际跑不到**(softmax 链一定留在最内层),标记为风险 + 但不阻塞主线。 +* **`trivial_levels` 透传**:位图从 async_wrap 生成、穿过 view、到 fuse + 消费 —— 每个 pass 重建节点时都要记得透传,中间不能丢。 +* **fuse 提升落点**:「提出内层 cluster 但不出 `kv_block` 循环」是方案 + 里最容易写错的一处,需要细心。 + +## 6. 难度评估 + +中等偏上,不是大重构。约 **3–5 天实现 + 2–3 天调试**(嵌套 cluster +第一次跑,大概率有 rank / 索引对不齐要排)。 + +最值钱的判断:走"乘积单维"绕开 `cluster_dim` 级联,把"40+ 处改动" +降到"4 个 pass 局部改 + 1 个 IR 字段"。 + +GQA kernel 已就绪,是现成的端到端验证用例。 + +## 7. 验证 + +`kernels/flash_attention_gqa_min.py`: + +* `make_flash_attention_gqa_min(head_count=8, group_size=2, ...)` → + `kv_head_count=4`。 +* `group_size=1` 退化成纯 MHA,应与 `flash_attention_min` 输出一致 — + 回归保护。 diff --git a/tilelang_tvm_compiler/expr_materializer.py b/tilelang_tvm_compiler/expr_materializer.py new file mode 100644 index 0000000..6b937f0 --- /dev/null +++ b/tilelang_tvm_compiler/expr_materializer.py @@ -0,0 +1,634 @@ +"""Lower a `tir.PrimExpr` tree into ISA + a register holding the value. + +This is the Phase-1 foundation for handling dynamic / symbolic values +(loop vars, slice offsets, tensor strides that depend on shape vars). +It is intentionally NOT yet wired into the main pipeline -- the existing +BTMM end-to-end test does not exercise any PrimExpr arg, so adding this +module is purely additive and cannot regress that path. + +Typical use, once wired in: + + sym_table: dict[tir.Var, int] = {} # var -> currently-bound GP reg + mat = ExprMaterializer(shim, sym_table) + m = mat.materialize(my_expr) + shim.compiler.generated_code += m.isa + # ... emit an instruction that uses gp{m.register} ... + m.release() # frees register + intermediates + +Design notes: + - We do NOT try to be a peephole optimizer. Constant folding for + pure-literal subtrees and a handful of trivial identities (mul by + 1, add 0) are the only "smarts". Everything else compiles to the + obvious ADD/SUB/MUL chain. + - Every materialised value lives in exactly one GP register at the + end. Intermediate registers are freed eagerly to keep pressure low + on the small (16-entry) GP pool. + - PrimExpr nodes we don't handle yet raise loudly. Better to fail + visibly than silently produce wrong ISA. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +from tvm import tir + +from .program_shim import ProgramShim + + +# Maximum unsigned literal that fits in a single S_ADDI_INT immediate. +# (preload_addr_reg.py uses the same bound.) +_S_ADDI_MAX = 262143 + + +class ExprMaterializeError(RuntimeError): + pass + + +@dataclass +class MaterializedExpr: + """Result of materialising a PrimExpr. + + `register` holds the value AFTER `isa` is emitted. The caller is + responsible for emitting `isa` into the ISA stream and then either + consuming `register` (using it in a subsequent instruction) and + calling `release()` when done, or copying its value elsewhere first. + """ + register: int + isa: str + owns_register: bool # caller may free `register` via release() + intermediates: List[int] = field(default_factory=list) + _materializer: "ExprMaterializer | None" = None + + def release(self) -> None: + """Free `register` (if owned) and any intermediates we held on to.""" + if self._materializer is None: + return + ra = self._materializer.shim.compiler.register_allocator + for r in self.intermediates: + ra.free_gp([r]) + if self.owns_register: + ra.free_gp([self.register]) + self.intermediates = [] + self.owns_register = False + self._materializer = None + + +class ExprMaterializer: + """Lowers `tir.PrimExpr` -> ISA text + GP register. + + The `symbol_table` maps already-bound `tir.Var` instances (typically + loop indices) to the GP register currently holding their value. The + ForOp emit code in the ISA pass is responsible for installing / + removing entries when entering / leaving a loop body. + """ + + def __init__(self, shim: ProgramShim, symbol_table: Dict[tir.Var, int]) -> None: + self.shim = shim + self.symbol_table = symbol_table + # Per-op IntRAM-idx load cache. An IntRAM-backed loop var + # (binding ``("ram", addr)``) would otherwise emit a fresh + # ``S_LD_INT`` on every single use inside one op's lowering. + # While an op is being lowered we cache ``ram_addr -> gp`` so the + # idx is loaded once and the register reused. The cache is a + # STACK of scopes (op lowering nests: a ``for`` op's handler + # dispatches child-op handlers); ``begin_op`` pushes a scope, + # ``end_op`` pops it and frees the GPs that scope loaded. + self._idx_cache_stack: List[Dict[int, int]] = [] + # Optional lowir recorder. When non-None, every top-level + # ``materialize`` call appends the pre-register symbolic + # expression here, tagged with the current op index. This is + # the "last variable-form" snapshot the lowir report dumps — + # the exact expressions the ISA actually consumes, captured at + # the single var->gp chokepoint so the report can never drift + # from real codegen. None (default) = zero overhead. + self._lowir_log: Optional[List[tuple]] = None + self._lowir_op_idx: int = -1 + + # ------------------------------------------------------------------ + # lowir recording — see _lowir_log + # ------------------------------------------------------------------ + def enable_lowir_log(self) -> None: + """Start recording materialized expressions for the lowir report.""" + self._lowir_log = [] + + def lowir_log(self) -> List[tuple]: + """Recorded ``(op_idx, expr_str)`` entries; empty if disabled.""" + return self._lowir_log or [] + + def set_lowir_op_idx(self, idx: int) -> None: + """Tag subsequent recordings with this HLIR op index.""" + self._lowir_op_idx = idx + + # ------------------------------------------------------------------ + # per-op lifetime — see _idx_cache_stack + # ------------------------------------------------------------------ + def begin_op(self) -> None: + """Open a fresh idx-load cache scope for one op's lowering.""" + self._idx_cache_stack.append({}) + + def end_op(self) -> None: + """Close the current scope: unpin + free every idx GP it loaded.""" + if not self._idx_cache_stack: + return + scope = self._idx_cache_stack.pop() + ra = self.shim.compiler.register_allocator + for gp in scope.values(): + ra.unpin_gp(gp) + ra.free_gp([gp]) + + # ------------------------------------------------------------------ + # public API + # ------------------------------------------------------------------ + def materialize(self, expr) -> MaterializedExpr: + """Top-level entry. Always returns a MaterializedExpr. + + Before dispatch we apply a single peephole optimisation on the + incoming PrimExpr: if every ``tir.Var`` referenced inside ``expr`` + has an ``IntImm`` binding in ``symbol_table`` (the canonical case + for fully-unrolled loop iterations), we substitute the Vars to + their IntImm values and run ``arith.Analyzer().simplify`` over + the result. This collapses expressions like + ``Add(IntImm(base), Mul(IntImm(iter), IntImm(stride)))`` to a + single ``IntImm(base + iter * stride)`` so ``_materialize`` emits + a single ``S_ADDI_INT`` instead of a 3-instruction chain. + + The optimisation is a no-op when: + * ``expr`` is not a tir.PrimExpr (int / etc.) — bypassed + * any referenced Var has a non-IntImm binding (e.g. ``("ram", + addr)`` for serial-loop idx vars) — falling through preserves + the legacy materialise path for those. + """ + if isinstance(expr, tir.PrimExpr): + expr = self._peephole_const_fold(expr) + if self._lowir_log is not None: + # Record the symbolic expression BEFORE register lowering — + # tir.Var loop indices (head_phase, row, ...) survive in the + # string, which is exactly the "last variable-form" the + # lowir report wants. str() of a tir node is side-effect free. + self._lowir_log.append((self._lowir_op_idx, str(expr))) + return self._materialize(expr) + + def _peephole_const_fold(self, expr): + """Substitute every ``tir.Var`` in ``expr`` that has an + ``IntImm`` binding in ``self.symbol_table``, then run + ``arith.Analyzer().simplify`` over the result. + + If any referenced Var has a non-IntImm binding (or no binding), + the original ``expr`` is returned unchanged — the legacy + materialise path then handles it (loading from IntRAM / using a + live GP / emitting an ADD chain). + """ + # Lazy imports — arith / stmt_functor are heavy and not always + # needed (the legacy callers without symbol_table go through + # the unchanged path). + from tvm import arith + from tvm.tir import stmt_functor + + # Collect every free Var. If any isn't an IntImm-bound entry in + # symbol_table, don't try to substitute (the materialise path + # is the safe fallback). + free_vars: List[tir.Var] = [] + + def _collect(node): + if isinstance(node, tir.Var): + if node not in free_vars: + free_vars.append(node) + + stmt_functor.post_order_visit(expr, _collect) + if not free_vars: + # Pure-literal expr — substitution would do nothing; just + # simplify in case there's still arithmetic to fold. + try: + return arith.Analyzer().simplify(expr) + except Exception: + return expr + + var_map: Dict[tir.Var, tir.PrimExpr] = {} + for v in free_vars: + binding = self.symbol_table.get(v) + if isinstance(binding, tir.IntImm): + var_map[v] = binding + else: + # Any Var without an IntImm binding — bail. The + # materialiser handles "live GP" / "ram-backed idx" + # bindings natively. + return expr + substituted = stmt_functor.substitute(expr, var_map) + try: + return arith.Analyzer().simplify(substituted) + except Exception: + return substituted + + # ------------------------------------------------------------------ + # core dispatch + # ------------------------------------------------------------------ + def _materialize(self, expr) -> MaterializedExpr: + # Plain Python ints sneak in via address-allocation; treat them + # as IntImm. + if isinstance(expr, int): + return self._materialize_int(int(expr)) + + if isinstance(expr, tir.IntImm): + return self._materialize_int(int(expr.value)) + + if isinstance(expr, tir.Var): + return self._materialize_var(expr) + + if isinstance(expr, tir.Add): + # Flatten one nested Add of constants: Add(c1, Add(c2, x)) -> + # Add(c1+c2, x). Saves a load+add when the inner constant comes + # from a buffer base and the outer from a lane offset (or vice + # versa). Also helps the `Add(imm, var)` fast-path below kick in. + a, b = expr.a, expr.b + if _is_intlike(a) and isinstance(b, tir.Add): + if _is_intlike(b.a): + a, b = tir.IntImm("int32", _int_value(a) + _int_value(b.a)), b.b + elif _is_intlike(b.b): + a, b = tir.IntImm("int32", _int_value(a) + _int_value(b.b)), b.a + elif _is_intlike(b) and isinstance(a, tir.Add): + if _is_intlike(a.a): + a, b = tir.IntImm("int32", _int_value(b) + _int_value(a.a)), a.b + elif _is_intlike(a.b): + a, b = tir.IntImm("int32", _int_value(b) + _int_value(a.b)), a.a + # Fast path: Add(IntImm, X) -> S_ADDI_INT (one instr) when the + # immediate fits and the OTHER side isn't itself a literal (the + # both-literal case should constant-fold via _materialize_binop). + if _is_intlike(a) and not _is_intlike(b) and 0 <= _int_value(a) <= _S_ADDI_MAX: + return self._materialize_unary_imm(b, "S_ADDI_INT", _int_value(a)) + if _is_intlike(b) and not _is_intlike(a) and 0 <= _int_value(b) <= _S_ADDI_MAX: + return self._materialize_unary_imm(a, "S_ADDI_INT", _int_value(b)) + return self._materialize_binop( + a, b, "S_ADD_INT", lambda x, y: x + y, identity_const=0 + ) + + if isinstance(expr, tir.Sub): + # x - 0 -> x (but NOT 0 - x, which would be negation). + if _is_intlike(expr.b) and _int_value(expr.b) == 0: + return self._materialize(expr.a) + return self._materialize_binop(expr.a, expr.b, "S_SUB_INT", lambda x, y: x - y) + + if isinstance(expr, tir.Mul): + # Fold both-literal subtrees FIRST (before strength reduction), + # so e.g. 4*64 collapses to a single LI of 256 rather than to + # an S_SLLI_INT we don't actually need. + if _is_intlike(expr.a) and _is_intlike(expr.b): + return self._materialize_int(_int_value(expr.a) * _int_value(expr.b)) + # Strength reduce `x * 2^k` -> S_SLLI_INT. One instr instead + # of two (avoids the LI for the literal multiplier and avoids + # using the multiplier itself). + shift = _try_pow2_shift_amount(expr.b) + if shift is not None: + return self._materialize_unary_imm(expr.a, "S_SLLI_INT", shift) + shift = _try_pow2_shift_amount(expr.a) + if shift is not None: + return self._materialize_unary_imm(expr.b, "S_SLLI_INT", shift) + return self._materialize_binop( + expr.a, expr.b, "S_MUL_INT", lambda x, y: x * y, identity_const=1 + ) + + # FloorDiv / FloorMod: PLENA ISA has no integer divide and no + # shift, so we can ONLY handle the case where both operands are + # literals (compile-time fold) or where the divisor is 1. + # Anything else surfaces as an error -- the kernel author has to + # restructure their code to avoid runtime division. + if isinstance(expr, tir.FloorDiv): + return self._materialize_floordivmod(expr.a, expr.b, "//", lambda x, y: x // y) + + if isinstance(expr, tir.FloorMod): + return self._materialize_floordivmod(expr.a, expr.b, "%", lambda x, y: x % y) + + # tir.shift_left / shift_right surface as Call nodes (Op("tir.shift_*")). + # Constant-fold both-literal cases; lower the rest to PLENA shift + # instructions (S_SLLI_INT for literal RHS, S_SLL_INT for register RHS). + if isinstance(expr, tir.Call): + op_name = getattr(expr.op, "name", None) if hasattr(expr, "op") else None + if op_name in ("tir.shift_left", "tir.shift_right"): + lhs, rhs = expr.args[0], expr.args[1] + py = (lambda a, b: a << b) if op_name == "tir.shift_left" else (lambda a, b: a >> b) + if _is_intlike(lhs) and _is_intlike(rhs): + return self._materialize_int(py(_int_value(lhs), _int_value(rhs))) + imm_op = "S_SLLI_INT" if op_name == "tir.shift_left" else "S_SRLI_INT" + reg_op = "S_SLL_INT" if op_name == "tir.shift_left" else "S_SRL_INT" + if _is_intlike(rhs): + return self._materialize_unary_imm(lhs, imm_op, _int_value(rhs)) + # Variable shift amount: S_SLL_INT rd, rs1, rs2. + return self._materialize_binop(lhs, rhs, reg_op, py) + + raise ExprMaterializeError( + f"unsupported PrimExpr node: {type(expr).__name__} ({expr!r})" + ) + + # ------------------------------------------------------------------ + # leaf cases + # ------------------------------------------------------------------ + def _materialize_int(self, n: int) -> MaterializedExpr: + """Produce a register holding integer literal `n`.""" + ra = self.shim.compiler.register_allocator + r = ra.allocate_gp(1)[0] + if 0 <= n <= _S_ADDI_MAX: + isa = f"S_ADDI_INT gp{r}, gp0, {n}\n" + elif n > _S_ADDI_MAX: + # Two-instruction form: load upper 20 bits, then add lower 12. + upper = n >> 12 + lower = n & 0xFFF + isa = ( + f"S_LUI_INT gp{r}, {upper}\n" + f"S_ADDI_INT gp{r}, gp{r}, {lower}\n" + ) + else: + # Negative immediates aren't part of typical PLENA use cases + # (offsets, sizes are >=0). Surface this loudly when it + # eventually comes up so we can decide on a proper encoding. + raise ExprMaterializeError( + f"negative int literal not supported yet: {n}" + ) + # Eager flush: write the ISA to generated_code immediately so + # call-order matches emit-order. Lazy isa strings caused + # cross-call register clobbering (a later allocate_gp triggered + # auto_spill or reuse that interleaved with an earlier lazy + # "S_LD_INT gp{r}, ..." -- the same physical reg ended up being + # loaded twice with different values, second one winning). + self.shim.compiler.generated_code += isa + return MaterializedExpr( + register=r, isa="", owns_register=True, _materializer=self + ) + + def _materialize_var(self, v: tir.Var) -> MaterializedExpr: + """Look up a bound var in the symbol table. + + Two binding forms: + * ``int`` — GP reg already holding the value + (legacy / unroll loop idx). No alloc, no ISA. + * ``("ram", intram_addr)`` — value lives in IntRAM (serial + loop idx; see ``_emit_for``). Borrow a fresh GP, emit a + ``S_LD_INT`` to load it, and return as ``owns_register=True`` + so the caller's release() returns the GP to the pool. Every + use re-loads (cheap; avoids pinning a permanent GP for the + idx in deeply nested kernels). + """ + if v not in self.symbol_table: + raise ExprMaterializeError( + f"unbound tir.Var {v.name!r}; not in symbol_table " + f"(known: {[x.name for x in self.symbol_table]!r})" + ) + binding = self.symbol_table[v] + # Unrolled loop var: bound to a constant. Materialise the literal + # into a register (the materializer constant-folds any enclosing + # arithmetic before this point, so reaching here means the bare + # var itself is needed as a value). + if isinstance(binding, tir.IntImm): + return self._materialize_int(int(binding.value)) + if isinstance(binding, int): + return MaterializedExpr( + register=binding, isa="", owns_register=False, _materializer=self + ) + if isinstance(binding, tuple) and len(binding) == 2 and binding[0] == "ram": + ram_addr = int(binding[1]) + ra = self.shim.compiler.register_allocator + # Per-op cache: if this op's lowering already loaded this + # IntRAM idx, reuse the register — skip the redundant + # S_LD_INT. The cached GP is owned by the op scope (freed by + # ``end_op``), so it is handed out as ``owns_register=False`` + # — the caller's ``release()`` must NOT free it, or a later + # use in the same op would read a dead register. + scope = self._idx_cache_stack[-1] if self._idx_cache_stack else None + if scope is not None and ram_addr in scope: + return MaterializedExpr( + register=scope[ram_addr], isa="", + owns_register=False, _materializer=self, + ) + reg = ra.allocate_gp(1)[0] + # IMPORTANT: write the load ISA directly to ``generated_code`` + # rather than the lazy ``isa`` field. Auto-spill (triggered by + # later allocs) emits S_ST/LD_INT eagerly into generated_code, + # so a later isa-string concatenation would reorder relative + # to those spill instructions and silently corrupt the value + # this load was supposed to deliver. + self.shim.compiler.generated_code += ( + f"; load ram-backed idx {v.name} <- intram[{ram_addr}]\n" + f"S_LD_INT gp{reg}, gp0, {ram_addr}\n" + ) + if scope is not None: + # Cache + pin for the rest of this op's lowering so a + # later auto-spill can't evict the value out from under + # a subsequent reuse. ``end_op`` unpins + frees it. + ra.pin_gp(reg) + scope[ram_addr] = reg + return MaterializedExpr( + register=reg, isa="", owns_register=False, + _materializer=self, + ) + # No active op scope (defensive): fall back to caller-owned. + return MaterializedExpr( + register=reg, isa="", owns_register=True, _materializer=self + ) + raise ExprMaterializeError( + f"symbol_table[{v.name!r}] has unsupported binding {binding!r}" + ) + + # ------------------------------------------------------------------ + # binary ops (Add / Sub / Mul share this skeleton) + # ------------------------------------------------------------------ + def _materialize_binop( + self, + lhs, + rhs, + opcode: str, + py_op, + identity_const: int | None = None, + ) -> MaterializedExpr: + # Constant fold both-literal subtrees so we don't burn a register + # on something the compiler already knows. + if _is_intlike(lhs) and _is_intlike(rhs): + return self._materialize_int(py_op(_int_value(lhs), _int_value(rhs))) + + # Trivial identity: x * 1 / 1 * x -- skip the multiplication. + if identity_const is not None: + if _is_intlike(rhs) and _int_value(rhs) == identity_const: + return self._materialize(lhs) + if _is_intlike(lhs) and _int_value(lhs) == identity_const: + return self._materialize(rhs) + + m_lhs = self._materialize(lhs) + # Pin m_lhs.register across both the m_rhs materialise AND the + # out_reg alloc: any allocate_gp in between may trigger auto_spill, + # which picks non-pinned in-use regs as victims. Without the pin, + # m_lhs's value could be silently displaced to IntRAM and the same + # physical register handed out as m_rhs's or out_reg's, aliasing + # operands. + ra = self.shim.compiler.register_allocator + lhs_was_owned = m_lhs.owns_register + if lhs_was_owned: + ra.pin_gp(m_lhs.register) + m_rhs = None + try: + m_rhs = self._materialize(rhs) + # Same protection for m_rhs while we alloc out_reg. + rhs_was_owned = m_rhs.owns_register + if rhs_was_owned: + ra.pin_gp(m_rhs.register) + try: + out_reg = ra.allocate_gp(1)[0] + finally: + if rhs_was_owned: + ra.unpin_gp(m_rhs.register) + finally: + if lhs_was_owned: + ra.unpin_gp(m_lhs.register) + + # Eager flush. m_lhs.isa and m_rhs.isa are already empty under + # the eager-emit invariant (their ISA was written to + # generated_code at construction time); we keep the `+ m.isa` + # bits for any legacy MaterializedExpr that might still carry a + # non-empty isa string. + self.shim.compiler.generated_code += m_lhs.isa + m_rhs.isa + ( + f"{opcode} gp{out_reg}, gp{m_lhs.register}, gp{m_rhs.register}\n" + ) + isa = "" + + # Eagerly free operand registers we own; the result is in out_reg. + if m_lhs.owns_register: + ra.free_gp([m_lhs.register]) + if m_rhs.owns_register: + ra.free_gp([m_rhs.register]) + # Inherit any intermediates the operands collected (so the caller + # can release them transitively if they want). + intermediates = list(m_lhs.intermediates) + list(m_rhs.intermediates) + + return MaterializedExpr( + register=out_reg, + isa=isa, + owns_register=True, + intermediates=intermediates, + _materializer=self, + ) + + + def _materialize_floordivmod(self, lhs, rhs, op_str: str, py_op) -> MaterializedExpr: + """FloorDiv / FloorMod: only fold-or-identity-or-shift, never an + actual hardware divide (PLENA has no integer divide instruction). + """ + if _is_intlike(lhs) and _is_intlike(rhs): + b = _int_value(rhs) + if b == 0: + raise ExprMaterializeError(f"division by zero in expr ({lhs} {op_str} {rhs})") + return self._materialize_int(py_op(_int_value(lhs), b)) + + # x // 1 -> x ; x % 1 -> 0 + if _is_intlike(rhs) and _int_value(rhs) == 1: + if op_str == "//": + return self._materialize(lhs) + else: # mod 1 is always 0 + return self._materialize_int(0) + + # x // 2^k -> S_SRLI_INT x, k. Unblocks the most common + # "runtime divide" case (block index from element index). + if op_str == "//": + shift = _try_pow2_shift_amount(rhs) + if shift is not None: + return self._materialize_unary_imm(lhs, "S_SRLI_INT", shift) + + # x % 2^k = x - (x // 2^k) * 2^k. PLENA has no bitwise AND, but + # the shift + multiply + subtract sequence uses only ops we + # already support. Lower by rewriting the PrimExpr and + # re-entering the materializer — FloorDiv hits the S_SRLI_INT + # branch above and Mul-by-pow2 hits the S_SLLI_INT path the + # binop dispatcher already handles. + if op_str == "%": + shift = _try_pow2_shift_amount(rhs) + if shift is not None: + m = 1 << shift + shifted = tir.FloorDiv(lhs, tir.IntImm("int32", m)) + scaled = tir.Mul(shifted, tir.IntImm("int32", m)) + return self._materialize(tir.Sub(lhs, scaled)) + + raise ExprMaterializeError( + f"cannot lower runtime {op_str}: PLENA ISA has no integer divide and " + f"no bitwise-AND. The only supported runtime forms are `x // 2^k` " + f"(via S_SRLI_INT) and `x % 2^k` (lowered as `x - (x // 2^k) * 2^k`). " + f"Got `{lhs} {op_str} {rhs}`. Restructure the kernel so this is " + f"computed at compile time, or use a power-of-2 divisor." + ) + + + def _materialize_unary_imm( + self, + operand, + opcode: str, + imm: int, + _identity_when_zero: bool = True, + ) -> MaterializedExpr: + """Common shape for ` rd, rs1, imm` where `rs1` comes from + materialising `operand` and `imm` is an integer baked into the ISA + text. + + Used for shifts (S_SLLI_INT / S_SRLI_INT) where the shift amount + is a compile-time literal. + """ + # Shift by zero is a no-op -- skip the instruction entirely. + if _identity_when_zero and imm == 0: + return self._materialize(operand) + + m_operand = self._materialize(operand) + ra = self.shim.compiler.register_allocator + # Pin m_operand's register across the out_reg alloc -- same race + # as _materialize_binop: an auto_spill triggered here could + # otherwise displace m_operand's value to IntRAM and reuse the + # same physical register for out_reg, causing reads from + # m_operand.register to silently return out_reg's content. + if m_operand.owns_register: + ra.pin_gp(m_operand.register) + try: + out_reg = ra.allocate_gp(1)[0] + finally: + if m_operand.owns_register: + ra.unpin_gp(m_operand.register) + # Eager flush (see _materialize_binop / _materialize_var for + # the rationale -- lazy isa strings interleave incorrectly with + # eager auto-spill / ram-idx loads). + self.shim.compiler.generated_code += m_operand.isa + ( + f"{opcode} gp{out_reg}, gp{m_operand.register}, {imm}\n" + ) + isa = "" + if m_operand.owns_register: + ra.free_gp([m_operand.register]) + return MaterializedExpr( + register=out_reg, + isa=isa, + owns_register=True, + intermediates=list(m_operand.intermediates), + _materializer=self, + ) + + +def _is_intlike(x) -> bool: + return isinstance(x, int) or isinstance(x, tir.IntImm) + + +def _int_value(x) -> int: + return int(x.value) if isinstance(x, tir.IntImm) else int(x) + + +def _try_pow2_shift_amount(x) -> int | None: + """If `x` is a positive int literal that is a power of two, return its + log2 (i.e. the shift amount). Otherwise None. + + Caps at 31 because the PLENA shift instructions take the shift amount + mod 32, so anything >= 32 would silently misbehave; we'd rather force + the caller down the regular MUL/DIV path (which still folds at compile + time if both sides are literal). + """ + if not _is_intlike(x): + return None + n = _int_value(x) + if n <= 1 or (n & (n - 1)) != 0: + return None + k = n.bit_length() - 1 + if k > 31: + return None + return k + + +__all__ = ["ExprMaterializer", "MaterializedExpr", "ExprMaterializeError"] diff --git a/tilelang_tvm_compiler/frontend/__init__.py b/tilelang_tvm_compiler/frontend/__init__.py new file mode 100644 index 0000000..5903103 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/__init__.py @@ -0,0 +1,13 @@ +"""tilelang -> PLENA-flavored TIR frontend (legacy). + +The legacy graph-IR pipeline lives in ``frontend/passes/`` and used to +expose ``compile_func`` as the public entry. The active pipeline now +runs through ``frontend/mid_ir/`` instead, so this package's __init__ +intentionally does NOT eagerly import ``compile_func`` — that would +cause a circular import when ``..pipeline`` (the new top-level +compile_kernel) imports ``frontend/passes/inline_let_stmts``, +``lower_compound_fp_stores``, and ``frontend/mid_ir/passes/...``. + +Callers that still need ``compile_func`` for some legacy reason can +import it directly: ``from .pipeline import compile_func``. +""" diff --git a/tilelang_tvm_compiler/frontend/gemm_macros.py b/tilelang_tvm_compiler/frontend/gemm_macros.py new file mode 100644 index 0000000..9d689ee --- /dev/null +++ b/tilelang_tvm_compiler/frontend/gemm_macros.py @@ -0,0 +1,104 @@ +"""User-facing helpers for tagging a `T.gemm` with an explicit PLENA kind. + +The user-facing API has **two** kinds today (plus one reserved): + + * ``"overwrite"`` (the default — used when no ``T.attr`` wraps the + gemm) — every gemm that is **not** head-fused. Internally the + compiler decides between two HW lowerings based on the LHS shape: + + * LHS rows == 1 → ``plena.mv`` (M_MV / M_MV_WO, per-head 1D LHS) + * otherwise → ``plena.matmul`` (M_MM / M_MM_WO, per-head 2D) + + Lane-axis layout of the operands is also handled automatically: + + * If the surrounding DMA / btmm / extern call already marked an + operand's lane axis, the gemm leaves it alone (idempotent — this + preserves the "matmul is neutral" semantics for the whole-buffer + DMA-driven case). + * If an operand has no surrounding marking (typical for fragment + outputs like PV_loc in flash-attention P @ V), the gemm marks + LHS=ROW_STACK and RHS / DST=COL_PACK, so each lane addresses + its own head slice. + + Sliced operands are supported: starts on any of A / B / C are folded + into ``lhs_offset / rhs_offset / dst_offset``. Whole-buffer gemms in + a lane group get the per-lane offset auto-injected from each + buffer's lane-axis stride — kernel authors never have to spell out a + ``by * stride`` literal. + + * ``"btmm"`` — head-fused matmul (Q @ K^T style: same Q broadcast + across all lanes, K split per lane, one per-head score row out per + lane). Lowers to ``plena.btmm`` or ``plena.btmv`` (auto-dispatched + on LHS rows the same way ``"overwrite"`` picks matmul vs mv). The + kernel must already have set up a head-fused grid (``T.Kernel`` with + a ``head_like`` axis at extent ``btmm_lane_count``); + ``transpose_B=True`` is the typical Q@K^T case. + + * ``"add"`` (**reserved, not yet implemented**) — additive + ``C += A @ B``. The planned interface: kernel author allocates a + scratch buffer and passes it via ``T.attr`` around the gemm: + + scratch = T.alloc_fragment((rows, hlen), "float16") + with T.attr(scratch.data, "plena.gemm_scratch", 0): + with T.attr(0, KIND, "add"): + T.gemm(A, B, C) # C += A @ B + + The lowering would emit ``plena.matmul → scratch`` then + ``plena.v_add(C, scratch, C)``. Currently the lowering raises + ``NotImplementedError``; for now write the two ops manually + (``T.gemm(A, B, scratch)`` + an inline T.Parallel add that + fuse_elementwise auto-folds to ``plena.v_add``). + +Usage (inline form — REQUIRED inside a ``@T.prim_func`` body, since +tilelang's eager TVMScript builder does AST analysis and cannot trace +into a helper function call):: + + from tilelang_tvm_compiler.frontend.gemm_macros import KIND + + @T.prim_func + def k(...): + ... + # Default — no T.attr needed: + T.gemm(A_sh, B_sh, C_loc) # whole-buffer or per-head, compiler picks + T.gemm(S_loc, V_sh, PV_loc) # per-head P @ V (decode if rows=1, prefill else) + + # Head-fused Q @ K^T: + with T.attr(0, KIND, "btmm"): + T.gemm(Q_sh, K_sh, S_loc, transpose_B=True) + +The ``T.attr`` wraps the next statement (the ``T.gemm``) in a +``tir.AttrStmt`` carrying ``attr_key="plena.gemm_kind"`` and +``value=StringImm()``. The ``annotate_gemm_path`` pass picks it +up and overrides any shape-driven auto-detection. +""" + +from __future__ import annotations + + +# Attribute key used by `annotate_gemm_path` to read the explicit kind. +# Kernel authors should pass this constant to `T.attr(0, KIND, "")` +# (rather than typing the literal string) so refactors of the key name +# stay consistent. +KIND = "plena.gemm_kind" + + +# Valid kind values (mirrors the lookup in `annotate_gemm_path`). +# OVERWRITE is the default — applied when no ``T.attr(KIND, ...)`` wraps +# the gemm. BTMM is the explicit head-fused path. ADD is reserved (the +# annotate pass accepts it; the lowering raises NotImplementedError). +OVERWRITE = "overwrite" +BTMM = "btmm" +ADD = "add" + + +VALID_KINDS = (OVERWRITE, BTMM, ADD) + + +# AttrStmt key the kernel author would use to attach a scratch buffer +# to a kind="add" gemm (since ``T.gemm`` itself has no slot for it). +# Reserved for the future kind="add" lowering — see gemm_macros +# docstring above and PIPELINE_ARCHITECTURE.md § 5.4. +GEMM_SCRATCH_KEY = "plena.gemm_scratch" + + +__all__ = ["KIND", "OVERWRITE", "BTMM", "ADD", "VALID_KINDS", "GEMM_SCRATCH_KEY"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/__init__.py b/tilelang_tvm_compiler/frontend/mid_ir/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tilelang_tvm_compiler/frontend/mid_ir/cluster_guard.py b/tilelang_tvm_compiler/frontend/mid_ir/cluster_guard.py new file mode 100644 index 0000000..ad6f825 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/cluster_guard.py @@ -0,0 +1,45 @@ +"""Helper: should the cluster pipeline (pass_3 onwards) run on this MidFunc? + +Two skip conditions, either alone is enough: + + 1. ``MidFunc.lane_axes`` is empty — kernel didn't declare any axis + for cluster fusion. Treat as "this kernel doesn't need cluster", + no error. + 2. Every non-global buffer's last dim is already >= MLEN. + A HW vector op is MLEN-wide; if every on-chip buffer already + covers a full vector along its trailing axis, one lane fills + a single instruction by itself — cluster fusion buys nothing. + +The cluster passes (split / distribute_cluster / async_wrap / view / +fuse) all check this guard at entry and no-op if either condition +holds. +""" + +from .ir import MidFunc + + +# MLEN: hardware vector width — read from plena_settings.toml's active +# mode, the single source of truth shared with the simulator. +from ...plena_settings import mlen as _mlen + +MLEN = _mlen() + + +def _last_dim(buf) -> int: + if not buf.shape: + return 0 + last = buf.shape[-1] + return int(last) if isinstance(last, int) else 0 + + +def should_skip_cluster(func: MidFunc) -> bool: + """True if cluster fusion is unnecessary for this MidFunc.""" + if not func.lane_axes: + return True + non_global = [b for b in func.allocs if b.scope != "global"] + if non_global and all(_last_dim(b) >= MLEN for b in non_global): + return True + return False + + +__all__ = ["should_skip_cluster", "MLEN"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/ir.py b/tilelang_tvm_compiler/frontend/mid_ir/ir.py new file mode 100644 index 0000000..c149090 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/ir.py @@ -0,0 +1,801 @@ +"""Mid-IR node definitions. + +A small dataclass IR sitting between raw TIR and PLENA HLIR. Designed +to make the **lane-fusion rewrite a sequence of mechanical, one-thing-each +passes** instead of the layered Graph-IR-everything approach. + +Pipeline (each step is one pass; the IR shape only changes in the +specific ways noted): + + raw tir.PrimFunc + │ + │ pass_1_fold: nested for + BufferStore → Elementwise / Reduce / Broadcast + │ pass_2_mark: tag dma / btmm gemm / elementwise sites with .marker + │ (blockIdx still alive throughout) + │ pass_3_split: pick a blockIdx axis; split it into (number, phase); + │ grow every non-global buffer by one outer `cluster` dim + │ — only *add*, never permute. + │ pass_4_async: wrap each marked op in Async(...) + │ pass_5_loop: introduce `for phase in [0, cluster_count)` but break + │ it into multiple fors at every Async boundary + │ pass_6_fuse: collapse `for phase: ` into + │ MultiLaneOp(op, dim_map = {cluster_axis: (buf, dim)}) + │ pass_7_perm: use dim_map to permute the cluster axis into the + │ physical-layout slot per buffer + │ + mid_ir ready + │ + │ pass_8_to_plena: lower to HLIR + │ + HLIRModule + +This file ONLY defines the nodes + a tiny printer. No pass logic here. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Tuple, Union + + +# --------------------------------------------------------------------------- +# Op kinds +# --------------------------------------------------------------------------- + + +class BinOp(Enum): + ADD = "add" + SUB = "sub" + MUL = "mul" + MAX = "max" + MIN = "min" + + +class UnaryOp(Enum): + EXP = "exp" + RECI = "reci" + SQRT = "sqrt" + COPY = "copy" + NEG = "neg" + + +class ReduceOp(Enum): + SUM = "sum" + MAX = "max" + MIN = "min" + + +class Marker(Enum): + """Tag attached to op sites by pass_2_mark. Drives async/cluster handling + downstream. Untagged ops stay sequential, never get lane-fused.""" + DMA = "dma" # HBM ↔ on-chip transfer + BTMM = "btmm" # head-fused matmul + LANE_OP = "lane_op" # an elementwise/reduce that lives inside the cluster + + +class AxisRole(Enum): + """The role a single axis plays for one op operand. + + Every op carries, for every operand BufferRef, a per-axis tag + saying *why* that axis is there. Downstream passes (lower, + codegen) read these directly instead of inferring from buffer + shape and cluster_dim — that "infer" path is the source of every + silent off-by-vlen / off-by-lane bug we've hit. + + BATCH outer fan-out: lower wraps the op in a ``for`` over + this axis (one HLIR issue per index). + SIMD inner vector axis: one HW vector instruction covers + a contiguous run along it; extent is per-issue length. + REDUCE axis collapsed by a Reduce; appears only on the src. + BROADCAST axis only on the dst — the src has no corresponding + dim and is replayed across all index values here. + CLUSTER HW lane axis: a single multi-lane instruction covers + every index along it; not wrapped in a for at lower. + GEMM_M / GEMM_N / GEMM_K + Gemm operand roles, the matmul HW knows how to walk + them natively; one HW instruction covers their full + extents (modulo BATCH outer fan-out wrap, if any). + """ + BATCH = "batch" + SIMD = "simd" + REDUCE = "reduce" + BROADCAST = "broadcast" + CLUSTER = "cluster" + GEMM_M = "gemm_m" + GEMM_N = "gemm_n" + GEMM_K = "gemm_k" + + +@dataclass +class AxisInfo: + """Per-axis metadata for a BufferRef on one op operand.""" + role: AxisRole + extent: int + + +# --------------------------------------------------------------------------- +# Buffer references +# --------------------------------------------------------------------------- + + +@dataclass +class BufferDef: + """A buffer the kernel allocates or receives. + + ``shape`` is the *logical* shape (what the kernel author wrote). + Passes that grow the buffer (e.g. pass_3) update this in place by + prepending dims; pass_7 may permute. ``scope`` is one of + ``"global"`` / ``"shared"`` / ``"fragment"`` etc — same string the + kernel uses with T.alloc_shared / T.alloc_fragment. + + ``cluster_dim`` is the index into ``shape`` of the lane/cluster + axis (the dim pass_3_split prepended when growing the buffer for + cluster fusion). pass_4b_view / pass_5b_burn_view track it through + any axis permutation. ``None`` for buffers that don't have a + cluster axis (HBM params, user-declared ``global.*`` caches, etc). + Downstream addressing reads this directly instead of guessing the + lane position from shape values. + """ + name: str + shape: List[int] # logical extents (int-only; symbolic later) + dtype: str + scope: str = "global" + cluster_dim: Optional[int] = None + + def with_outer_dim(self, extent: int) -> "BufferDef": + # Prepending a dim shifts every existing axis by +1, including + # the cluster_dim marker if set. + new_cluster = None if self.cluster_dim is None else self.cluster_dim + 1 + return BufferDef( + name=self.name, + shape=[extent] + list(self.shape), + dtype=self.dtype, + scope=self.scope, + cluster_dim=new_cluster, + ) + + +@dataclass +class BufferRef: + """A read/write of a buffer at a given index tuple. + + ``indices`` lists one IndexExpr per dim of ``buffer.shape``. Each + is either an int (static), a string (symbolic var name — e.g. + ``"by"``, ``"by_phase"``, ``"row"``), or a Slice (``":"``-style + whole-axis access). + + By convention pass_1_fold produces refs where every dim that the + fused loop spanned is a ``Slice``, and remaining dims are concrete + indices. + + ``view_perm`` is set by pass_4b_view to express the op-local view + permutation: ``view_perm[i]`` says which physical dim the op sees + as logical dim ``i``. Identity (None) means logical = physical. + Example: physical buffer shape ``[lane=4, S=64, D=16]``; + ``view_perm=[1, 0, 2]`` means the op sees ``[S=64, lane=4, D=16]`` + (BSHD view of a BHSD-shell buffer). + """ + buffer: BufferDef + indices: List["IndexExpr"] + view_perm: Optional[List[int]] = None + + +@dataclass +class Slice: + """``:`` — whole-axis access. May carry an explicit range later + if needed; for now whole-axis is the only kind we care about.""" + pass + + +class VarRef: + """Identity-based reference to a ``tir.Var`` used as a mid_ir index. + + Equality is ``var.same_as(other.var)`` — two ``VarRef`` instances + compare equal iff they wrap the same underlying TIR object. Hash is + ``id(var)``. The wrapped Var's ``name_hint`` is kept for dump and + debugging only; it has no role in comparison. + + Why this exists: prior versions stored bare ``str`` names inside + ``BufferRef.indices`` and downstream passes did string-equality to + decide whether an index referenced the cluster lane axis. Two + distinct ``tir.Var`` objects sharing a ``name_hint`` would silently + collide. ``VarRef`` makes identity the contract so a name collision + can't masquerade as a real reference. + + Constructed with a bare ``tir.Var`` object: + + VarRef(some_tir_var) + + We intentionally do *not* import ``tvm`` at module scope — mid_ir + is supposed to stay TVM-agnostic so unit tests can build refs by + hand. The wrapped object only needs ``same_as(other)`` and a + ``name`` attribute (real ``tir.Var`` already satisfies this; tests + can pass any duck-type). + """ + __slots__ = ("var",) + + def __init__(self, var): + # tir.Var supports same_as; duck-typed test fakes need to too. + self.var = var + + def __eq__(self, other): + if not isinstance(other, VarRef): + return False + # Prefer same_as (TIR's identity check); fall back to ``is`` for + # plain test doubles that don't implement it. + same_as = getattr(self.var, "same_as", None) + if same_as is not None: + return bool(same_as(other.var)) + return self.var is other.var + + def __hash__(self): + return id(self.var) + + def __repr__(self): + return f"VarRef({self.name!r})" + + @property + def name(self) -> str: + # ``tir.Var.name`` is the ``name_hint`` — fine for dump. + return str(getattr(self.var, "name", self.var)) + + +# An IndexExpr is one of: int (concrete index / extent literal), +# VarRef (identity-typed variable reference), Slice (whole axis), or a +# compound dict {"op": "add", "args": [...]} for things like +# ``by_phase + by_number*C``. +# +# ``str`` was previously allowed as a stand-in for "variable named X", +# but two distinct ``tir.Var`` objects sharing a name then silently +# collided in passes that compared by name. ``VarRef`` is the +# replacement — it carries the wrapped ``tir.Var`` and compares by +# identity (``var.same_as``). ``str`` is no longer permitted. +IndexExpr = Union[int, VarRef, Slice, dict] + + +# --------------------------------------------------------------------------- +# Op nodes — the three fold targets +# --------------------------------------------------------------------------- + + +@dataclass +class Elementwise: + """``dst[idx] = op(src_0[idx], src_1[idx], ...)`` over matching axes. + + ``op`` is BinOp for 2+ srcs, UnaryOp for 1 src. + + Per-axis roles (the only source of truth for lower): + * ``dst_axes`` — one ``AxisInfo`` per dst dim, in dst-axis order. + * ``src_axes`` — list parallel to ``srcs``; each entry is per-axis + info for that src in src-axis order. A src wrapped in + Broadcast has fewer axes than dst — its ``src_axes`` entry is + shorter and the corresponding dst-axes carry ``BROADCAST`` role. + + Legacy fields ``axis``, ``size``, ``outer_extents`` are kept + transitionally for code that hasn't migrated; new lowering code + must read ``dst_axes`` / ``src_axes`` only. + + ``can_async`` is True when the HW lowering is a single multi-lane + vector instruction (``v_add`` / ``v_exp_v`` etc.). False when the + lowering is per-row (``row_sub_fp_at`` and friends — typically + the case when one src is a Broadcast wrapping a smaller-rank fp + scalar). pass_2_mark sets this; pass_4_async only wraps ops with + can_async=True in Async regions. + """ + dst: BufferRef + srcs: List[Union[BufferRef, "Broadcast"]] + op: Union[BinOp, UnaryOp] + # Per-axis roles + extents (authoritative source for lower). + dst_axes: List[AxisInfo] = field(default_factory=list) + src_axes: List[List[AxisInfo]] = field(default_factory=list) + # Legacy fields kept transitionally for code that hasn't migrated + # to axes; new lowering paths must read dst_axes / src_axes only. + axis: Optional[Union[int, List[int]]] = None + size: int = 1 + marker: Optional[Marker] = None + can_async: bool = False + + +@dataclass +class Broadcast: + """Wrap a smaller-rank src to match a larger-rank dst. + + ``broadcast_dims`` is the list of dst-dim indices along which + ``src`` repeats. E.g. dst is rank 3, src is rank 1 with values per + last-axis position → broadcast_dims = [0, 1]. + """ + src: BufferRef + broadcast_dims: List[int] + + +@dataclass +class Reduce: + """``dst[idx_without_axis] = reduce(src[idx], op, axis)``. + + Per-axis roles (authoritative): + * ``dst_axes`` — one per dst dim (the collapsed axis is gone). + * ``src_axes`` — one per src dim; the collapsed dim is tagged + ``REDUCE``; the others mirror their dst counterpart. + + Legacy ``axis`` carried the collapsed-axis index; transitionally + kept for old code paths. + + ``can_async`` is always False — reduce on PLENA is per-row + (``row_reduce_max_at`` / ``row_reduce_sum_at``), one row at a time + into a per-row fp scalar slot. No multi-lane reduce HW op exists. + """ + dst: BufferRef + src: BufferRef + op: ReduceOp + axis: int + dst_axes: List[AxisInfo] = field(default_factory=list) + src_axes: List[AxisInfo] = field(default_factory=list) + marker: Optional[Marker] = None + can_async: bool = False + + +# --------------------------------------------------------------------------- +# Op nodes — gemm and dma stay as their own kinds (we don't decompose +# them into elementwise/reduce; the HW has dedicated instructions). +# --------------------------------------------------------------------------- + + +@dataclass +class Gemm: + """``c = a @ b`` (transpose flags carried). + + Per-axis roles: + * ``a_axes`` / ``b_axes`` / ``c_axes`` tag every axis on each + operand with one of {BATCH, GEMM_M, GEMM_N, GEMM_K, CLUSTER}. + + The ``kind`` matches the kernel's ``T.attr(0, KIND, ...)`` — most + importantly ``"btmm"`` for head-fused vs the default per-head form. + Pass_2_mark sets the marker to BTMM when kind == "btmm". + + ``can_async`` is True for kind=="btmm" (one multi-lane M_BTMM + instruction); False for kind=="overwrite" (per-head matmul that + runs inside the lane loop, one matmul per lane). + """ + a: BufferRef + b: BufferRef + c: BufferRef + transpose_a: bool = False + transpose_b: bool = False + kind: str = "overwrite" + a_axes: List[AxisInfo] = field(default_factory=list) + b_axes: List[AxisInfo] = field(default_factory=list) + c_axes: List[AxisInfo] = field(default_factory=list) + marker: Optional[Marker] = None + can_async: bool = False + + +@dataclass +class Dma: + """``dst = src`` across a memory scope boundary (HBM ↔ on-chip). + + Both src and dst are BufferRefs whose ``indices`` describe the + slice being transferred. Direction is implicit from src.scope / + dst.scope. + + Per-axis roles: ``src_axes`` / ``dst_axes`` tag each axis with + ``BATCH`` / ``SIMD`` / ``CLUSTER`` per the same conventions as + Elementwise. The slice indices themselves are still in + ``BufferRef.indices``; the axes table is the lower-time view of + "what role does this dim play for this transfer", independent of + static-vs-dynamic-slice rendering. + + ``can_async`` is always True — DMA is always a single multi-lane + HW instruction (``H_LOAD_V`` / ``H_STORE_V`` etc.). + """ + src: BufferRef + dst: BufferRef + src_axes: List[AxisInfo] = field(default_factory=list) + dst_axes: List[AxisInfo] = field(default_factory=list) + marker: Optional[Marker] = None + can_async: bool = False + + +@dataclass +class RawStore: + """A BufferStore that the fold pass couldn't recognize as one of + Elementwise / Broadcast / Reduce. + + Lives inside an enclosing ``For`` loop (or a chain of them) and + represents an opaque per-iteration scalar update. Examples that + end up as RawStore today: + + * ``in_FP_padded[MLEN + k] = 0`` (compound dst index) + * ``shift_FP[m] = in_FP_padded[m + kw]`` (shifted copy) + * any RHS the fold pass doesn't decompose + + Downstream passes treat RawStore as opaque: they don't peek at + ``value``, don't apply cluster/permute rewrites to its indices. + The lowering pass (mid_ir → plena_ir / HLIR) is responsible for + pattern-matching specific RawStore shapes and erroring on + unrecognized ones. + + ``value`` is held as an opaque object (typically a tir.PrimExpr + captured from raw TIR). Mid-IR walkers don't touch it. + """ + dst: BufferRef + value: object # tir.PrimExpr or similar — opaque to mid_ir passes + + +# --------------------------------------------------------------------------- +# Structure nodes +# --------------------------------------------------------------------------- + + +class ParallelKind(Enum): + """Kinds of parallel axes mid-IR represents. Three concepts, never + interchangeable: + + * ``BLOCK_IDX`` — a HW grid axis. N independent program instances + run; no lockstep guarantee. Comes from ``T.Kernel(...)``'s grid + bindings. Has a ``thread_tag`` ("blockIdx.x" / .y / .z). + * ``LOGICAL_GRID`` — a parallel axis the kernel author marked with + ``T.Parallel(...)`` that fold couldn't collapse into an + Elementwise/Reduce. Semantically still N independent instances + (kernel author asserts iteration order doesn't matter), but + not bound to a HW grid dim — it's an inner / kernel-body axis. + No ``thread_tag``. Pass_3 may split a LOGICAL_GRID just like a + BLOCK_IDX. + * ``CLUSTER`` — a lockstep multi-lane axis. ``cluster_count`` lanes + execute the same instruction stream in lockstep, one HW + instruction = one operation across all lanes. Created by pass_3 + when it splits a (BLOCK_IDX | LOGICAL_GRID) lane axis. Carries + ``parent_grid_axis_name`` pointing at the matching grid number + axis. + + Never converted to a For during mid-IR. pass_8_to_plena flattens + every kind into the appropriate HLIR form. + """ + BLOCK_IDX = "blockIdx" + LOGICAL_GRID = "logical_grid" + CLUSTER = "cluster" + + +@dataclass +class ParallelAxis: + """SPMD parallel axis. ``extent`` independent threads execute + ``body`` concurrently. NOT a sequential loop — pass_8 is the only + place where parallelism collapses into a for. + + ``axis_name`` is the user-visible name (e.g. ``"by"``, + ``"by_phase"``, ``"q_block"``). + + ``thread_tag`` carries the underlying ``blockIdx.*`` tag for + BLOCK_IDX axes only. None for LOGICAL_GRID and CLUSTER. + + ``parent_grid_axis_name`` is the cluster→grid back-link set by + pass_3 when it splits a lane axis. A CLUSTER axis carries the + name of the grid-number axis it was split out of (e.g. cluster + ``by_phase`` has parent ``by_number``). None for grid kinds. + + ``original_axis_name`` is the user-visible name the kernel + originally bound this axis to *before* pass_3 split it into + phase + number (e.g. CLUSTER ``by_phase`` originated from ``by``). + Set by pass_3 on the CLUSTER side and on the matching grid number + side too so consumers can look up either half without parsing + name suffixes like ``"_phase"`` / ``"_number"``. + + Identity channels (separate from the string-named channels): + + ``axis_var`` carries the same axis as a :class:`VarRef` — the + identity used by ``BufferRef.indices`` consumers to recognise + references to this axis. Optional during the transitional period + while passes are still being migrated off bare-string indices; + will become required. + + ``original_axis_var`` mirrors ``original_axis_name`` but as a + ``VarRef``: the identity of the kernel-author lane var ``by`` + before pass_3 split it. The CLUSTER and matching number axes + both carry it so HBM-lane-substitution / cluster-collapse can + identify the pre-split var without name matching. + """ + axis_name: str + extent: int + body: List["Stmt"] + kind: ParallelKind + thread_tag: Optional[str] = None + parent_grid_axis_name: Optional[str] = None + original_axis_name: Optional[str] = None + axis_var: Optional[VarRef] = None + original_axis_var: Optional[VarRef] = None + + +@dataclass +class For: + """A truly-sequential loop. Each iteration runs after the previous. + + Comes from the kernel's ``T.serial(...)`` / ``T.unroll(...)`` — + e.g. flash_attention's ``for kv_block`` or conv2d's ``for oc / for + oh``. NEVER carries a ``thread_tag`` and NEVER represents a + parallel axis (that's ``ParallelAxis``). + + ``kind`` is one of ``"serial"`` (default) or ``"unroll"`` (pass_8 + asks the lowering to fully unroll the loop body — used for tiny + KW/KH loops in conv2d). + + ``loop_var_var`` is the same loop var as a :class:`VarRef` — + identity-based handle that downstream passes can match against + ``BufferRef.indices`` entries. Optional during transitional period. + """ + loop_var: str + extent: int + body: List["Stmt"] + kind: str = "serial" # "serial" | "unroll" + loop_var_var: Optional[VarRef] = None + + +@dataclass +class Async: + """Wrap one or more ops in an async region. Pass_5_loop uses Async + boundaries to *split* an enclosing ``for phase`` into multiple fors + (each Async ends up in its own for, fused into a multi-lane HW op + by pass_6_fuse). + """ + body: List["Stmt"] + scope_id: int # unique per async region + + +@dataclass +class MultiLaneOp: + """Output of pass_6_fuse: a single op that the HW executes across + all cluster lanes in one instruction. + + Multi-axis clusters are supported from day one: a kernel that + cluster-fuses both ``by`` and ``q_block`` would produce a + MultiLaneOp with two entries in ``cluster_axis_names`` and + matching-length lists in ``dim_map`` per buffer. + + Fields: + * ``inner`` the underlying op (Dma / Gemm / + Elementwise / Reduce) with + cluster-dim indices pre-resolved. + * ``cluster_axis_names`` the list of axis names this op fuses + across, e.g. ``["by_phase"]`` or + ``["by_phase", "qb_phase"]``. Order + matches the entries in ``dim_map``. + * ``cluster_axis_vars`` parallel list of :class:`VarRef` for + each entry in ``cluster_axis_names``. + Lowering uses these to identify + phase-shorthand indices by identity + (vs ``cluster_axis_names`` which is + kept for HLIR dump / loop_var minting). + * ``dim_map`` ``buf_name -> [dim_idx_for_axis_0, + dim_idx_for_axis_1, ...]``. Length of + each list matches ``len(cluster_axis_names)``. + Pass_7_perm reads this when deciding + where each cluster axis sits physically. + """ + inner: "Op" + cluster_axis_names: List[str] + dim_map: Dict[str, List[int]] + cluster_axis_vars: List[VarRef] = field(default_factory=list) + + +# A "statement" is anything that appears in a body list. +Stmt = Union[ParallelAxis, For, Async, Dma, Gemm, Elementwise, Broadcast, + Reduce, RawStore, MultiLaneOp] + +# An "Op" is the leaf-level op kinds (no structure). +Op = Union[Dma, Gemm, Elementwise, Reduce] + + +# --------------------------------------------------------------------------- +# Function +# --------------------------------------------------------------------------- + + +@dataclass +class MidFunc: + """Top-level: a kernel function in mid-IR form. + + ``params`` are param buffers (always global / HBM). ``allocs`` are + on-chip buffers the kernel allocates. ``body`` is the sequence of + statements at the kernel's grid scope. + + ``lane_axes`` records the kernel author's + ``T.func_attr({"plena.lane_axis": ["by"]})`` declaration so pass_3 + knows which blockIdx(es) to split. Multi-axis cluster is supported + from day one: a kernel can declare ``["by", "q_block"]`` to + cluster-fuse both. Each axis gets its own cluster_count entry, in + the same order. ``cluster_counts`` is filled in by pass_3 + (typically = ``[lane_count]`` = ``[MLEN / btmm_hlen]``). + + Mid-IR output preserves blockIdx — see ``For`` docstring. The + body's outermost layer is typically a chain of ``For(thread_tag= + "blockIdx.*")`` for both untouched grid axes and the *_number + halves of split lane axes. + """ + name: str + params: List[BufferDef] + allocs: List[BufferDef] + body: List[Stmt] + lane_axes: List[str] = field(default_factory=list) + cluster_counts: List[int] = field(default_factory=list) + attrs: Dict[str, str] = field(default_factory=dict) + + +# --------------------------------------------------------------------------- +# Tiny printer (debug only — not stable, format may change freely) +# --------------------------------------------------------------------------- + + +def _fmt_idx(i: IndexExpr) -> str: + if isinstance(i, Slice): + return ":" + if isinstance(i, VarRef): + return i.name + if isinstance(i, dict): + op = i.get("op", "?") + args = i.get("args", []) + return f"({op} {' '.join(_fmt_idx(a) for a in args)})" + return str(i) + + +def _fmt_ref(r: BufferRef) -> str: + body = r.buffer.name + if r.view_perm is not None and list(r.view_perm) != list(range(len(r.indices))): + body += f"" + if not r.indices: + return body + return f"{body}[{', '.join(_fmt_idx(i) for i in r.indices)}]" + + +def _fmt_src(s) -> str: + if isinstance(s, Broadcast): + return f"bcast({_fmt_ref(s.src)}, dims={s.broadcast_dims})" + return _fmt_ref(s) + + +def _fmt_marker(m: Optional[Marker], can_async: bool = False) -> str: + if m is None: + return "" + suffix = " async" if can_async else "" + return f" #{m.value}{suffix}" + + +def _fmt_axes(axes: List[AxisInfo]) -> str: + """Compact ``[role:extent, role:extent, ...]`` dump for a per-op + axes list. Empty list returns ``[]`` (so a missing-axes case is + visually obvious).""" + if not axes: + return "[]" + parts = [f"{a.role.value}:{int(a.extent)}" for a in axes] + return "[" + ", ".join(parts) + "]" + + +def _print_stmt(s: Stmt, indent: int, out: List[str]) -> None: + pad = " " * indent + if isinstance(s, ParallelAxis): + # Display kind directly as the keyword: + # BLOCK_IDX → "grid" (HW grid axis: blockIdx.* binding) + # LOGICAL_GRID → "logical_grid" (kernel-body parallel, no HW binding) + # CLUSTER → "cluster" (lockstep lane axis) + # Suffixes: + # * BLOCK_IDX shows ``[blockIdx.y]`` (its thread_tag) + # * CLUSTER shows ``← `` so the + # cluster→grid back-link is readable at a glance + if s.kind == ParallelKind.BLOCK_IDX: + keyword = "grid" + suffix = f" [{s.thread_tag}]" if s.thread_tag else "" + elif s.kind == ParallelKind.LOGICAL_GRID: + keyword = "logical_grid" + suffix = "" + else: + keyword = "cluster" + suffix = (f" ← {s.parent_grid_axis_name}" + if s.parent_grid_axis_name else "") + out.append( + f"{pad}{keyword} {s.axis_name} in 0..{s.extent}{suffix}:" + ) + for b in s.body: + _print_stmt(b, indent + 1, out) + return + if isinstance(s, For): + kind = "" if s.kind == "serial" else f" ({s.kind})" + out.append(f"{pad}for {s.loop_var} in 0..{s.extent}{kind}:") + for b in s.body: + _print_stmt(b, indent + 1, out) + return + if isinstance(s, Async): + out.append(f"{pad}async #{s.scope_id} {{") + for b in s.body: + _print_stmt(b, indent + 1, out) + out.append(f"{pad}}}") + return + if isinstance(s, Dma): + out.append( + f"{pad}dma {_fmt_ref(s.src)} -> {_fmt_ref(s.dst)}" + f"{_fmt_marker(s.marker, s.can_async)}" + ) + out.append( + f"{pad} src_axes={_fmt_axes(s.src_axes)} " + f"dst_axes={_fmt_axes(s.dst_axes)}" + ) + return + if isinstance(s, Gemm): + ta = "ᵀ" if s.transpose_a else "" + tb = "ᵀ" if s.transpose_b else "" + out.append( + f"{pad}gemm[{s.kind}] {_fmt_ref(s.c)} = " + f"{_fmt_ref(s.a)}{ta} @ {_fmt_ref(s.b)}{tb}" + f"{_fmt_marker(s.marker, s.can_async)}" + ) + out.append( + f"{pad} a_axes={_fmt_axes(s.a_axes)} " + f"b_axes={_fmt_axes(s.b_axes)} c_axes={_fmt_axes(s.c_axes)}" + ) + return + if isinstance(s, Elementwise): + srcs = ", ".join(_fmt_src(x) for x in s.srcs) + axis = f" axis={s.axis}" if s.axis is not None else "" + out.append( + f"{pad}elementwise[{s.op.value}] {_fmt_ref(s.dst)} = " + f"f({srcs}){axis}{_fmt_marker(s.marker, s.can_async)}" + ) + src_axes_strs = [_fmt_axes(sa) for sa in (s.src_axes or [])] + out.append( + f"{pad} dst_axes={_fmt_axes(s.dst_axes)} " + f"src_axes=[{', '.join(src_axes_strs)}]" + ) + return + if isinstance(s, Reduce): + out.append( + f"{pad}reduce[{s.op.value} axis={s.axis}] " + f"{_fmt_ref(s.dst)} = R({_fmt_ref(s.src)})" + f"{_fmt_marker(s.marker, s.can_async)}" + ) + out.append( + f"{pad} src_axes={_fmt_axes(s.src_axes)} " + f"dst_axes={_fmt_axes(s.dst_axes)}" + ) + return + if isinstance(s, RawStore): + out.append(f"{pad}raw_store {_fmt_ref(s.dst)} = ") + return + if isinstance(s, MultiLaneOp): + out.append( + f"{pad}multi_lane (cluster_axes={s.cluster_axis_names}, " + f"dim_map={s.dim_map}) {{" + ) + _print_stmt(s.inner, indent + 1, out) + out.append(f"{pad}}}") + return + out.append(f"{pad}") + + +def format_func(fn: MidFunc) -> str: + """Return a multi-line text dump of ``fn`` for eyeballing.""" + out: List[str] = [] + params = ", ".join( + f"{p.name}: {p.dtype}{tuple(p.shape)} @{p.scope}" for p in fn.params + ) + out.append(f"func @{fn.name}({params})") + if fn.lane_axes: + out.append(f" // lane_axes = {fn.lane_axes}, " + f"cluster_counts = {fn.cluster_counts}") + if fn.allocs: + out.append(" allocs:") + for a in fn.allocs: + out.append(f" {a.name}: {a.dtype}{tuple(a.shape)} @{a.scope}") + out.append(" body:") + for s in fn.body: + _print_stmt(s, 2, out) + return "\n".join(out) + + +__all__ = [ + "BinOp", "UnaryOp", "ReduceOp", "Marker", + "AxisRole", "AxisInfo", + "BufferDef", "BufferRef", "Slice", "VarRef", "IndexExpr", + "Elementwise", "Broadcast", "Reduce", + "Gemm", "Dma", "RawStore", + "ParallelKind", "ParallelAxis", + "For", "Async", "MultiLaneOp", + "MidFunc", + "format_func", +] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/__init__.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/async_wrap.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/async_wrap.py new file mode 100644 index 0000000..54a04a8 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/async_wrap.py @@ -0,0 +1,173 @@ +"""pass_4_async: wrap can_async ops in Async regions. + +Why this pass exists +-------------------- + +After ``split`` + ``distribute_cluster``, every cluster body holds a +mix of: + + * ``can_async=True`` ops (Dma, Gemm[btmm], pure Elementwise) — each + lowers to a single multi-lane HW instruction (M_BTMM / H_LOAD_V / + V_ADD ...). These get wrapped one-per-Async (strict 'one async + one op' rule from SPMD_REWRITE.md); the next pass picks the + physical buffer view per Async, and pass_5 then fuses each + Async into a MultiLaneOp. + + * ``can_async=False`` ops (Reduce, Elementwise containing a + Broadcast src) — these lower to per-row HW instructions + (row_reduce_max_at, row_sub_fp_at, ...) that need a fresh fp + scalar address per lane. They stay in the cluster body bare; + pass_8 emits ``for lane in range(cluster)`` around them. + +What this pass does NOT touch +----------------------------- + + * BufferRef indices / view_perm — that's the next pass. + The buffer-rank vs ref-rank mismatch introduced by ``split`` + persists past this pass; it gets resolved by the view pass. + * Buffer shapes — those were set by pass_3. + * MultiLaneOp synthesis — pass_5. + * Anything outside a cluster body — Fors, RawStore, grid headers + etc. are walked structurally but not rewritten. +""" + +from __future__ import annotations + +from typing import List + +from ..cluster_guard import should_skip_cluster +from ..ir import ( + Dma, Gemm, Elementwise, Reduce, + For, Async, MultiLaneOp, + ParallelAxis, ParallelKind, + MidFunc, Stmt, +) + + +class AsyncWrapError(RuntimeError): + pass + + +class _IdCounter: + """Module-global counter for generating fresh Async scope IDs.""" + def __init__(self) -> None: + self.next = 0 + + def fresh(self) -> int: + n = self.next + self.next += 1 + return n + + +# --------------------------------------------------------------------------- +# Walker +# --------------------------------------------------------------------------- + + +def _walk(stmt: Stmt, in_cluster: bool, ids: _IdCounter) -> Stmt: + if isinstance(stmt, ParallelAxis): + if stmt.kind == ParallelKind.CLUSTER: + if stmt.parent_grid_axis_name is None: + raise AsyncWrapError( + f"cluster {stmt.axis_name!r} has no parent_grid_axis_name; " + f"split should have set it" + ) + new_body = [_walk(s, in_cluster=True, ids=ids) for s in stmt.body] + new_body = _wrap_async_runs(new_body, ids) + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=new_body, + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + # grid / logical_grid: pass through, but if we're already inside + # a cluster the leaf-op bodies here still need async wrapping. + new_body = [_walk(s, in_cluster=in_cluster, ids=ids) for s in stmt.body] + if in_cluster: + new_body = _wrap_async_runs(new_body, ids) + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=new_body, + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + if isinstance(stmt, For): + new_body = [_walk(s, in_cluster=in_cluster, ids=ids) for s in stmt.body] + if in_cluster: + new_body = _wrap_async_runs(new_body, ids) + return For( + loop_var=stmt.loop_var, + extent=stmt.extent, + body=new_body, + kind=stmt.kind, + loop_var_var=stmt.loop_var_var, + ) + if isinstance(stmt, Async): + # Already wrapped; preserve and recurse (idempotency). + return Async( + body=[_walk(s, in_cluster=in_cluster, ids=ids) for s in stmt.body], + scope_id=stmt.scope_id, + ) + # Leaf op or MultiLaneOp: pass through unchanged. The async wrap + # decision is made one level up by _wrap_async_runs (which only + # fires inside cluster bodies). + return stmt + + +def _wrap_async_runs(stmts: List[Stmt], ids: _IdCounter) -> List[Stmt]: + """For each ``can_async=True`` leaf op in the cluster body, wrap it + in its own Async region (strict one-async-one-op). can_async=False + ops stay unwrapped. Already-wrapped Async / MultiLaneOp / nested + structure stays as-is.""" + out: List[Stmt] = [] + for s in stmts: + if _is_async_eligible(s): + out.append(Async(body=[s], scope_id=ids.fresh())) + else: + out.append(s) + return out + + +def _is_async_eligible(s: Stmt) -> bool: + """True if ``s`` is a leaf op with ``can_async=True``.""" + return isinstance(s, (Dma, Gemm, Elementwise, Reduce)) \ + and getattr(s, "can_async", False) + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: MidFunc) -> MidFunc: + """Wrap can_async ops in Async regions. Index/view rewriting is a + separate downstream pass. + + No-op if ``should_skip_cluster(func)`` — without cluster fusion + there's no notion of multi-lane dispatch to mark.""" + if should_skip_cluster(func): + return func + ids = _IdCounter() + new_body = [_walk(s, in_cluster=False, ids=ids) for s in func.body] + return MidFunc( + name=func.name, + params=list(func.params), + allocs=list(func.allocs), + body=new_body, + lane_axes=list(func.lane_axes), + cluster_counts=list(func.cluster_counts), + attrs=dict(func.attrs), + ) + + +__all__ = ["run", "AsyncWrapError"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/burn_view.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/burn_view.py new file mode 100644 index 0000000..b4ca4cd --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/burn_view.py @@ -0,0 +1,378 @@ +"""pass_5b_burn_view: bake each BufferRef.view_perm into the buffer's +physical shape and into every ref's index tuple. + +Why this pass exists +-------------------- + +In the cluster pipeline (pass_4b), each non-global BufferRef carries +a ``view_perm`` describing the op-local view of an underlying physical +buffer. The view is "soft" — the BufferDef.shape is the storage shape +(BHSD shell with lane at physical dim 0), but ops see it permuted +(BSHD with lane at logical dim 1, etc). + +By the time we lower to HLIR, the soft view has to become hard: + + * Buffer.shape needs to be a single physical layout (HLIR has no + notion of per-ref view). + * Indices that referenced the buffer have to match the new physical + dim order. + +This pass does that bake. For each lane-aware buffer: + + 1. Collect all refs of that buffer; gather their view_perm. + 2. Verify they're all identical (pass_4b's consistency check + guarantees this; we re-verify for safety). + 3. If non-identity: permute Buffer.shape and every ref's indices + via that perm. + 4. Reset every ref's view_perm to None — the perm is now baked. + +After this pass: ``view_perm == None`` on every ref. HLIR lowering +treats Buffer.shape as authoritative. + +What's left untouched +--------------------- + + * HBM (global) buffers — view_perm was always None on those. + * Buffers in a MidFunc skipped by cluster_guard. + * BufferDef objects in MidFunc.params / .allocs — replaced with + the permuted version. ``BufferRef.buffer`` references get + swapped to point at the permuted def. + * RawStore.dst — opaque, no view rewriting. + * broadcast_dims — index *into* dst's logical shape; since the + bake just relabels physical dims to match the existing logical + view, broadcast_dims stay the same (they always referred to the + logical view). +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +from ..cluster_guard import should_skip_cluster +from ..ir import ( + AxisRole, AxisInfo, + BufferDef, BufferRef, Slice, + Dma, Gemm, Elementwise, Broadcast, Reduce, RawStore, + For, Async, MultiLaneOp, + ParallelAxis, ParallelKind, + MidFunc, Stmt, +) + + +class BurnViewError(RuntimeError): + pass + + +# --------------------------------------------------------------------------- +# Phase 1: collect per-buffer view_perm +# --------------------------------------------------------------------------- + + +def _identity_perm(rank: int) -> List[int]: + return list(range(rank)) + + +def _is_identity(perm: List[int]) -> bool: + return perm == _identity_perm(len(perm)) + + +def _collect_views(stmt: Stmt, table: Dict[str, List[Tuple[int, ...]]]) -> None: + """Walk; for every BufferRef record its view_perm under buffer name. + Skips global buffers and refs with view_perm=None.""" + + def visit_ref(ref: BufferRef) -> None: + # User-declared globals (HBM + on-chip ``global.*`` caches) + # keep their as-written shape — burn_view never touches them. + if ref.buffer.scope == "global" or ref.buffer.scope.startswith("global."): + return + if ref.view_perm is None: + return + table.setdefault(ref.buffer.name, []).append(tuple(ref.view_perm)) + + def visit_src(src) -> None: + if isinstance(src, Broadcast): + visit_ref(src.src) + else: + visit_ref(src) + + if isinstance(stmt, Dma): + visit_ref(stmt.src) + visit_ref(stmt.dst) + elif isinstance(stmt, Gemm): + visit_ref(stmt.a) + visit_ref(stmt.b) + visit_ref(stmt.c) + elif isinstance(stmt, Elementwise): + visit_ref(stmt.dst) + for s in stmt.srcs: + visit_src(s) + elif isinstance(stmt, Reduce): + visit_ref(stmt.dst) + visit_ref(stmt.src) + elif isinstance(stmt, ParallelAxis): + for s in stmt.body: + _collect_views(s, table) + elif isinstance(stmt, For): + for s in stmt.body: + _collect_views(s, table) + elif isinstance(stmt, Async): + for s in stmt.body: + _collect_views(s, table) + elif isinstance(stmt, MultiLaneOp): + _collect_views(stmt.inner, table) + # RawStore: opaque, skip. + + +def _agreed_perms(table: Dict[str, List[Tuple[int, ...]]] + ) -> Dict[str, Tuple[int, ...]]: + """For each buffer, verify all collected perms agree. Returns + name → single perm. Raises on mismatch.""" + out: Dict[str, Tuple[int, ...]] = {} + for name, perms in table.items(): + first = perms[0] + for p in perms[1:]: + if p != first: + raise BurnViewError( + f"buffer {name!r} has inconsistent view perms after " + f"pass_4b: {set(perms)}. pass_4b should have caught this." + ) + out[name] = first + return out + + +# --------------------------------------------------------------------------- +# Phase 2: build permuted BufferDefs +# --------------------------------------------------------------------------- + + +def _permute_buffer(buf: BufferDef, perm: Tuple[int, ...]) -> BufferDef: + if len(perm) != len(buf.shape): + raise BurnViewError( + f"buffer {buf.name!r} rank {len(buf.shape)} doesn't match " + f"perm rank {len(perm)}" + ) + # Track the cluster axis through the permutation: it lands at the + # new position whose ``perm[i]`` equals the old cluster_dim. + new_cluster: Optional[int] = None + if buf.cluster_dim is not None: + for new_i, old_i in enumerate(perm): + if old_i == buf.cluster_dim: + new_cluster = new_i + break + return BufferDef( + name=buf.name, + shape=[buf.shape[i] for i in perm], + dtype=buf.dtype, + scope=buf.scope, + cluster_dim=new_cluster, + ) + + +def _build_permuted_defs(func: MidFunc, + perms: Dict[str, Tuple[int, ...]]) -> Dict[str, BufferDef]: + """Return name → permuted BufferDef for every lane-aware buffer that + needs a non-identity perm. Identity perms still build a fresh def + for uniformity (so callers swap to a single canonical def).""" + out: Dict[str, BufferDef] = {} + for buf in list(func.params) + list(func.allocs): + if buf.scope == "global" or buf.scope.startswith("global."): + continue + if buf.name not in perms: + # No ref carried a view (e.g. unused buffer). Leave alone. + continue + out[buf.name] = _permute_buffer(buf, perms[buf.name]) + return out + + +# --------------------------------------------------------------------------- +# Phase 3: rewrite refs (indices + buffer pointer + clear view_perm) +# --------------------------------------------------------------------------- + + +def _rewrite_ref(ref: BufferRef, + new_defs: Dict[str, BufferDef]) -> BufferRef: + if ref.buffer.scope == "global" or ref.buffer.scope.startswith("global."): + return ref + new_def = new_defs.get(ref.buffer.name) + if new_def is None: + # Buffer not in the table: nothing to bake. + return ref + perm = ref.view_perm + if perm is None: + # View was never set (e.g. ref outside cluster). Just swap to + # the new def to keep the buffer-pointer single-source-of-truth. + return BufferRef(buffer=new_def, indices=list(ref.indices)) + new_indices = [ref.indices[i] for i in perm] + return BufferRef( + buffer=new_def, + indices=new_indices, + view_perm=None, # baked + ) + + +def _permute_axes(axes: List[AxisInfo], ref: BufferRef) -> List[AxisInfo]: + """Apply the same view_perm to the per-axis info that burn_view + applies to ref.indices. ``ref.view_perm`` is the perm; ``axes`` + is in pre-permute (view-pass) order. + + Returns the post-permute list. Length must equal ``len(ref.indices)``; + we accept and pass through a zero-length / mismatched list as + a transitional courtesy (so passes not yet axes-aware don't crash + the bake). + """ + perm = ref.view_perm + if perm is None or not axes: + return list(axes) + if len(axes) != len(perm): + # Length mismatch usually indicates a producer pass hasn't been + # updated yet (axes still empty). Don't permute; downstream + # consumers will fall back to legacy guess and we'll catch the + # gap during verification. + return list(axes) + return [axes[i] for i in perm] + + +def _rewrite_src(src, new_defs): + if isinstance(src, Broadcast): + return Broadcast( + src=_rewrite_ref(src.src, new_defs), + broadcast_dims=list(src.broadcast_dims), + ) + return _rewrite_ref(src, new_defs) + + +def _rewrite_op(op, new_defs): + if isinstance(op, Dma): + return Dma( + src=_rewrite_ref(op.src, new_defs), + dst=_rewrite_ref(op.dst, new_defs), + src_axes=_permute_axes(op.src_axes, op.src), + dst_axes=_permute_axes(op.dst_axes, op.dst), + marker=op.marker, + can_async=op.can_async, + ) + if isinstance(op, Gemm): + return Gemm( + a=_rewrite_ref(op.a, new_defs), + b=_rewrite_ref(op.b, new_defs), + c=_rewrite_ref(op.c, new_defs), + transpose_a=op.transpose_a, + transpose_b=op.transpose_b, + kind=op.kind, + a_axes=_permute_axes(op.a_axes, op.a), + b_axes=_permute_axes(op.b_axes, op.b), + c_axes=_permute_axes(op.c_axes, op.c), + marker=op.marker, + can_async=op.can_async, + ) + if isinstance(op, Elementwise): + # Permute per-src axes using each src's own view_perm. + new_src_axes = [] + for s, sa in zip(op.srcs, op.src_axes or [[]] * len(op.srcs)): + inner = s.src if isinstance(s, Broadcast) else s + new_src_axes.append(_permute_axes(sa, inner)) + return Elementwise( + dst=_rewrite_ref(op.dst, new_defs), + srcs=[_rewrite_src(s, new_defs) for s in op.srcs], + op=op.op, + dst_axes=_permute_axes(op.dst_axes, op.dst), + src_axes=new_src_axes, + axis=op.axis, + size=op.size, + marker=op.marker, + can_async=op.can_async, + ) + if isinstance(op, Reduce): + return Reduce( + dst=_rewrite_ref(op.dst, new_defs), + src=_rewrite_ref(op.src, new_defs), + op=op.op, + axis=op.axis, + dst_axes=_permute_axes(op.dst_axes, op.dst), + src_axes=_permute_axes(op.src_axes, op.src), + marker=op.marker, + can_async=op.can_async, + ) + if isinstance(op, RawStore): + return RawStore( + dst=_rewrite_ref(op.dst, new_defs), + value=op.value, + ) + raise BurnViewError(f"unhandled op type {type(op).__name__}") + + +def _walk(stmt: Stmt, new_defs: Dict[str, BufferDef]) -> Stmt: + if isinstance(stmt, ParallelAxis): + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=[_walk(s, new_defs) for s in stmt.body], + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + if isinstance(stmt, For): + return For( + loop_var=stmt.loop_var, + extent=stmt.extent, + body=[_walk(s, new_defs) for s in stmt.body], + kind=stmt.kind, + loop_var_var=stmt.loop_var_var, + ) + if isinstance(stmt, Async): + return Async( + body=[_walk(s, new_defs) for s in stmt.body], + scope_id=stmt.scope_id, + ) + if isinstance(stmt, MultiLaneOp): + return MultiLaneOp( + inner=_rewrite_op(stmt.inner, new_defs), + cluster_axis_names=list(stmt.cluster_axis_names), + cluster_axis_vars=list(stmt.cluster_axis_vars), + dim_map=dict(stmt.dim_map), + ) + return _rewrite_op(stmt, new_defs) + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: MidFunc) -> MidFunc: + """Bake view_perm into Buffer shapes + ref indices.""" + if should_skip_cluster(func): + return func + + # Phase 1+2: gather views, verify, build new defs. + table: Dict[str, List[Tuple[int, ...]]] = {} + for s in func.body: + _collect_views(s, table) + perms = _agreed_perms(table) + new_defs = _build_permuted_defs(func, perms) + + if not new_defs: + # Nothing to bake (e.g. all views were identity / no view set). + return func + + # Phase 3: rewrite body + replace BufferDefs in params/allocs. + new_body = [_walk(s, new_defs) for s in func.body] + new_params = [new_defs.get(b.name, b) for b in func.params] + new_allocs = [new_defs.get(b.name, b) for b in func.allocs] + + return MidFunc( + name=func.name, + params=new_params, + allocs=new_allocs, + body=new_body, + lane_axes=list(func.lane_axes), + cluster_counts=list(func.cluster_counts), + attrs=dict(func.attrs), + ) + + +__all__ = ["run", "BurnViewError"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/distribute_cluster.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/distribute_cluster.py new file mode 100644 index 0000000..85b9620 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/distribute_cluster.py @@ -0,0 +1,267 @@ +"""pass_3b_distribute_cluster: push CLUSTER axes inside enclosing +unroll/pipeline For loops. + +Why this pass exists +-------------------- + +After ``split``, the IR can have a CLUSTER ParallelAxis whose body +contains a ``For(kind="unroll")``:: + + cluster c [...]: + for kh in unroll(KH): + op_a + op_b + +When pass_4_async wraps the can_async ops in Async regions, an Async +that lives outside the unroll loop logically spans **all KH iters**: + + cluster c [...]: + async { dma_a, dma_b } ← single Async covers ALL iters + for kh in unroll(KH): + ... + +That's wrong for the HW: each unroll iter should fire its own Async +dispatch (own DMA, own gemm, etc.) and complete independently. Mixing +"one Async / KH iters" semantics with the per-iter HW model creates +ambiguity about when the Async actually waits. + +This pass rewrites the nesting so each unroll iter has its OWN cluster +body — pass_4 then naturally produces one Async per iter:: + + for kh in unroll(KH): + cluster c [...]: ← cluster repeats inside each iter + op_a + op_b + +Mixed cluster bodies +-------------------- + +When a cluster body holds an unroll For interleaved with other ops, +the cluster splits into multiple cluster instances around each unroll +For:: + + cluster c [...]: + op_pre + for kh in unroll(KH): + op_inner + op_post + ↓ + cluster c [...]: + op_pre + for kh in unroll(KH): + cluster c [...]: + op_inner + cluster c [...]: + op_post + +This preserves the original execution order and keeps every op inside +some cluster body (so pass_4 can still see lane fusion context for it). + +What this pass does NOT touch +----------------------------- + + * ``For(kind="serial")`` — sequential loops; cluster sits OUTSIDE + just fine. Sequential iters don't have the "concurrent dispatch" + problem unroll has. + * Nested clusters (cluster inside cluster) — kernels don't produce + them today; if one shows up the inner one passes through. + * Async / MultiLaneOp — neither exists yet at this point in the + pipeline (pass_4 hasn't run). +""" + +from __future__ import annotations + +from typing import List + +from ..cluster_guard import should_skip_cluster +from ..ir import ( + BufferDef, BufferRef, + Dma, Gemm, Elementwise, Broadcast, Reduce, RawStore, + For, Async, MultiLaneOp, + ParallelAxis, ParallelKind, + MidFunc, Stmt, +) + + +class DistributeClusterError(RuntimeError): + pass + + +def _is_unroll_for(s: Stmt) -> bool: + return isinstance(s, For) and s.kind == "unroll" + + +def _clone_cluster_with_body(template: ParallelAxis, + body: List[Stmt]) -> ParallelAxis: + """Make a fresh CLUSTER axis with the same axis_name / extent / + parent_grid_axis_name as ``template`` but a different body.""" + return ParallelAxis( + axis_name=template.axis_name, + extent=template.extent, + body=body, + kind=ParallelKind.CLUSTER, + thread_tag=template.thread_tag, # always None for CLUSTER, but copy anyway + parent_grid_axis_name=template.parent_grid_axis_name, + original_axis_name=template.original_axis_name, + axis_var=template.axis_var, + original_axis_var=template.original_axis_var, + ) + + +# --------------------------------------------------------------------------- +# Cluster distribution +# --------------------------------------------------------------------------- + + +def _distribute_one_cluster(cluster: ParallelAxis) -> List[Stmt]: + """Take a single CLUSTER axis whose body may contain unroll Fors, + return a list of stmts equivalent to the original but with each + unroll For lifted out and the cluster pushed inside. + + See module docstring for the rewrite rule. + """ + out: List[Stmt] = [] + pending: List[Stmt] = [] # ops accumulated since last unroll-for boundary + + def flush_pending() -> None: + nonlocal pending + if pending: + out.append(_clone_cluster_with_body(cluster, pending)) + pending = [] + + for child in cluster.body: + if _is_unroll_for(child): + # Boundary: emit any pending ops as a cluster, then emit + # the unroll For with cluster pushed into its body. + flush_pending() + inner_body = _walk_stmts(child.body) + new_for = For( + loop_var=child.loop_var, + extent=child.extent, + body=[_clone_cluster_with_body(cluster, inner_body)], + kind=child.kind, + loop_var_var=child.loop_var_var, + ) + out.append(new_for) + else: + # Recurse into nested structure first, then accumulate. + pending.append(_walk_stmt(child)) + + flush_pending() + return out + + +def _walk_stmt(stmt: Stmt) -> Stmt: + if isinstance(stmt, ParallelAxis): + if stmt.kind == ParallelKind.CLUSTER: + distributed = _distribute_one_cluster(stmt) + # If nothing changed (no unroll for inside), wrap back. + # Otherwise the cluster has been distributed across + # multiple stmts; we can't return a list from this point — + # callers expect one stmt. Wrap multiple in a synthetic + # cluster… no, better: surface a SeqStmt-like via the + # parent. Easiest path: parent walker collects stmt lists, + # not single stmts. See _walk_stmts below. + if (len(distributed) == 1 + and isinstance(distributed[0], ParallelAxis) + and distributed[0] is not stmt): + return distributed[0] + # Single-stmt no-change case + if (len(distributed) == 1 + and isinstance(distributed[0], ParallelAxis) + and distributed[0].axis_name == stmt.axis_name + and distributed[0].body == stmt.body): + return stmt + # Multi-stmt distribution — caller has to flatten via _walk_stmts. + # Mark by returning a "marker" sentinel? We avoid that by + # not exposing single-stmt callers in this pass — _walk_stmts + # is the canonical entrypoint for body lists. + # Fallthrough for safety: re-wrap into a plain cluster. + return _clone_cluster_with_body(stmt, distributed) if len(distributed) > 1 else distributed[0] + # Non-CLUSTER ParallelAxis: walk body recursively, no rewrite. + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=_walk_stmts(stmt.body), + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + if isinstance(stmt, For): + return For( + loop_var=stmt.loop_var, + extent=stmt.extent, + body=_walk_stmts(stmt.body), + kind=stmt.kind, + loop_var_var=stmt.loop_var_var, + ) + if isinstance(stmt, Async): + return Async(body=_walk_stmts(stmt.body), scope_id=stmt.scope_id) + if isinstance(stmt, MultiLaneOp): + return MultiLaneOp( + inner=_walk_stmt(stmt.inner), + cluster_axis_names=list(stmt.cluster_axis_names), + cluster_axis_vars=list(stmt.cluster_axis_vars), + dim_map=dict(stmt.dim_map), + ) + # Leaf op: pass through unchanged. + return stmt + + +def _walk_stmts(stmts: List[Stmt]) -> List[Stmt]: + """Walk a body list. CLUSTER axes that distribute into multiple + stmts are flattened in here (the only place that handles a 1→N + rewrite cleanly).""" + out: List[Stmt] = [] + for s in stmts: + if (isinstance(s, ParallelAxis) + and s.kind == ParallelKind.CLUSTER + and any(_is_unroll_for(c) for c in s.body)): + # Recurse into the cluster body's children first so any + # nested clusters / fors are handled, then distribute. + child = ParallelAxis( + axis_name=s.axis_name, + extent=s.extent, + body=_walk_stmts(s.body), + kind=s.kind, + thread_tag=s.thread_tag, + parent_grid_axis_name=s.parent_grid_axis_name, + original_axis_name=s.original_axis_name, + axis_var=s.axis_var, + original_axis_var=s.original_axis_var, + ) + out.extend(_distribute_one_cluster(child)) + else: + out.append(_walk_stmt(s)) + return out + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: MidFunc) -> MidFunc: + """Push every CLUSTER axis inside any unroll For it currently wraps + around. ``serial`` Fors and non-CLUSTER ParallelAxes are left + alone. + + No-op if ``should_skip_cluster(func)`` — there's no cluster to + distribute.""" + if should_skip_cluster(func): + return func + return MidFunc( + name=func.name, + params=list(func.params), + allocs=list(func.allocs), + body=_walk_stmts(func.body), + lane_axes=list(func.lane_axes), + cluster_counts=list(func.cluster_counts), + attrs=dict(func.attrs), + ) + + +__all__ = ["run", "DistributeClusterError"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/fold.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/fold.py new file mode 100644 index 0000000..4206ba8 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/fold.py @@ -0,0 +1,1409 @@ +"""pass_1_fold: raw tir.PrimFunc → mid_ir.MidFunc. + +What this pass folds +-------------------- + +Five raw-TIR shapes get collapsed into one mid_ir node each: + + tl.tileop.copy(region(src), region(dst)) + → Dma(src, dst) + + tl.tileop.gemm_py(region(A), region(B), region(C), ...) + [optionally wrapped by ``T.attr(0, "plena.gemm_kind", "btmm")``] + → Gemm(A, B, C, kind=) + + tl.tileop.reduce(region(src), region(dst), dim, clear) + → Reduce(dst, src, op, axis) + + for i in T.Parallel(N): + dst[..., i] = A[..., i] OP B[..., i] + → Elementwise(dst, [A, B], BinOp.) # whole-buffer + for i in T.Parallel(N): + dst[..., i] = T.exp(A[..., i]) + → Elementwise(dst, [A], UnaryOp.EXP, axis=-1) + + for i in T.serial(N): + dst[i] = scalar_expr_of(A[i], B[i]) # 1D fp scalar update + → Elementwise(dst, [A, B], BinOp.) or Reduce / Broadcast + +Anything that doesn't match one of the above is preserved as a raw +``RawStmt`` wrapper for the next pass to look at — but for the +flash_attention_min op set everything is expected to fold. + +Structure-preserving wrappers (For with thread_tag, AttrStmt for +KIND, SeqStmt, BlockRealize) are translated to mid_ir's For + body +list as appropriate. Raw structure isn't carried over verbatim — the +output is purely mid_ir nodes. + +Scope +----- + +Only handles the rounded ops the kernel test set exercises today: +add / sub / mul / max / exp / reci / copy / 0-fill (zero_v) / sum-reduce +/ max-reduce. Anything else (FloatImm in store other than 0, DivNode +RHS, Cast in expr) raises ``FoldError`` so we notice early — better +than silently emitting a malformed mid_ir node. + +Limitations / explicit gaps for later +------------------------------------- + + * Compound RHS (a*b + c*d) is rejected — relies on + ``lower_compound_fp_stores`` running first. + * IfThenElse: kernels don't use it; raises FoldError. + * Match-buffers / non-trivial Block.alloc_buffers: passed through + via best-effort BufferDef synthesis. + * Reduce's ``clear`` flag is read but mid_ir doesn't represent it + yet — we store it on Reduce.attrs for the lowering pass to read. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union + +import tvm +from tvm import tir + +from ..ir import ( + BinOp, UnaryOp, ReduceOp, + AxisRole, AxisInfo, + BufferDef, BufferRef, Slice, VarRef, + Elementwise, Broadcast, Reduce, + Gemm, Dma, RawStore, For, MidFunc, + ParallelAxis, ParallelKind, +) + + +_TILEOP_COPY = "tl.tileop.copy" +_TILEOP_GEMM = "tl.tileop.gemm_py" +_TILEOP_REDUCE = "tl.tileop.reduce" +_TILEOP_REGION = "tl.tileop.region" +_KIND_KEY = "plena.gemm_kind" +_LANE_AXIS_FUNC_ATTR = "plena.lane_axis" + + +class FoldError(RuntimeError): + pass + + +# --------------------------------------------------------------------------- +# Per-fold Var registry +# --------------------------------------------------------------------------- +# +# Canonicalises ``tir.Var`` -> ``VarRef`` within one fold call. +# +# Keyed by ``id(var)`` so a given ``tir.Var`` object always yields the +# same ``VarRef`` instance — identity comparisons on ``VarRef`` are +# stable across multiple visits of the same underlying var. +# +# Same-name different-object vars are explicitly *allowed*. That's the +# whole point of moving off bare-string indices: each ``tir.Var`` keeps +# its own identity even when ``name_hint`` collides (e.g. two unrelated +# ``row`` vars in the same PrimFunc). Two ``VarRef``s wrapping +# different ``tir.Var``s with the same name will compare unequal via +# ``var.same_as`` — exactly the contract we want. +# +# Reset at every call to :func:`run`. + + +class _VarRegistry: + def __init__(self) -> None: + self._by_id: Dict[int, VarRef] = {} + # Keep a strong reference to each Var so its id() can't be + # recycled inside this fold call. + self._anchor: List[object] = [] + + def ref(self, var) -> VarRef: + existing = self._by_id.get(id(var)) + if existing is not None: + return existing + new_ref = VarRef(var) + self._by_id[id(var)] = new_ref + self._anchor.append(var) + return new_ref + + +# Active registry for the current ``run`` invocation. Module-level +# (vs threaded through every helper) because the recursion already +# fans out through many small functions; passing it would force a +# wide signature change for no benefit. Reset at the top of ``run``. +_active_registry: Optional[_VarRegistry] = None + + +def _vref(var) -> VarRef: + """Canonical ``VarRef`` for ``var`` in the active fold call.""" + if _active_registry is None: + raise FoldError( + "_vref called outside a fold ``run`` — registry not active" + ) + return _active_registry.ref(var) + + +def _assert_no_str_in_indices(stmts) -> None: + """Walk ``stmts`` and assert no BufferRef.indices entry is a bare + ``str``. Bare-string indices were the pre-VarRef cheat; fold output + must be VarRef-only.""" + def visit_ref(ref: BufferRef) -> None: + for i, idx in enumerate(ref.indices): + _check_idx(idx, ref.buffer.name, i) + + def _check_idx(idx, buf_name, pos) -> None: + if isinstance(idx, str): + raise FoldError( + f"fold produced a bare-string index in BufferRef " + f"{buf_name}[..pos {pos}..] = {idx!r}. mid_ir now requires " + f"VarRef; investigate the producer." + ) + if isinstance(idx, dict): + for a in idx.get("args", []): + _check_idx(a, buf_name, pos) + + def visit_src(src) -> None: + if isinstance(src, Broadcast): + visit_ref(src.src) + else: + visit_ref(src) + + def walk(s) -> None: + if isinstance(s, Elementwise): + visit_ref(s.dst) + for x in s.srcs: + visit_src(x) + elif isinstance(s, Reduce): + visit_ref(s.dst) + visit_ref(s.src) + elif isinstance(s, Dma): + visit_ref(s.src) + visit_ref(s.dst) + elif isinstance(s, Gemm): + visit_ref(s.a) + visit_ref(s.b) + visit_ref(s.c) + elif isinstance(s, RawStore): + visit_ref(s.dst) + # Recurse into structural nodes. + body = getattr(s, "body", None) + if isinstance(body, list): + for c in body: + walk(c) + + for s in stmts: + walk(s) + + +# --------------------------------------------------------------------------- +# Call kind helpers (mirror prior passes; kept local for self-containment) +# --------------------------------------------------------------------------- + + +def _call_kind(call: tir.Call) -> Optional[str]: + if not isinstance(call, tir.Call): + return None + op_name = getattr(call.op, "name", "") + if op_name and not op_name.startswith("tir."): + return op_name + if op_name == "tir.call_extern" and call.args: + head = call.args[0] + if isinstance(head, tir.StringImm): + return str(head.value) + return None + + +def _call_args(call: tir.Call) -> List: + op_name = getattr(call.op, "name", "") + if op_name == "tir.call_extern" and call.args: + return list(call.args[1:]) + return list(call.args) + + +# --------------------------------------------------------------------------- +# Buffer table +# --------------------------------------------------------------------------- + + +def _scope_string(buf: tir.Buffer, default: str) -> str: + s = getattr(buf, "scope", None) + if callable(s): + try: + return str(s()) + except Exception: + return default + if isinstance(s, str): + return s + return default + + +def _shape_ints(buf: tir.Buffer) -> List[int]: + out = [] + for d in buf.shape: + if isinstance(d, tir.IntImm): + out.append(int(d.value)) + elif isinstance(d, int): + out.append(int(d)) + else: + raise FoldError( + f"buffer {buf.name!r} has symbolic dim {d!r}; mid_ir is " + f"int-shape-only at this stage" + ) + return out + + +def _buffer_def(buf: tir.Buffer, default_scope: str = "global") -> BufferDef: + shape = _shape_ints(buf) + scope = _scope_string(buf, default_scope) + # Hard rule: 1D ``shared`` (VRAM tile) buffers are rejected. + # Reason: fold's broadcast detection in ``_wrap_src`` only flips a + # same-rank operand into ``Broadcast`` when the src ref's rank is + # strictly shorter than dst's. A 1D shared dst paired with a 1D + # ``fp[0]``-style scalar src therefore fails to fold (same rank, + # but the src is logically a scalar broadcast). Force authors to + # write ``(1, N)`` instead; downstream tile/row machinery is built + # around ≥2D shared anyway. + if scope.startswith("shared") and len(shape) == 1: + raise FoldError( + f"1D shared buffer {buf.name!r} (shape={shape}) is not " + f"supported. Use ``T.alloc_shared((1, N), ...)`` instead — " + f"1D shared loses the broadcast-axis fold needed for " + f"`vram[i] op fp_scalar[0]` to lower to a vector op." + ) + return BufferDef( + name=buf.name, + shape=shape, + dtype=str(buf.dtype), + scope=scope, + ) + + +# --------------------------------------------------------------------------- +# Raw-TIR → mid_ir IndexExpr conversion +# --------------------------------------------------------------------------- + + +def _index_expr(expr) -> Union[int, VarRef, dict]: + """Convert a TIR PrimExpr appearing as an index into a mid_ir + IndexExpr (int / VarRef / dict). Compound arithmetic becomes a + ``{"op": "", "args": [...]}`` dict; passes that need + to manipulate it can parse the dict. + + A vectorized index (``tir.Ramp(base, stride, lanes)``) is treated + as a contiguous slice and encoded as + ``{"op": "ramp", "args": [base, stride, lanes]}``. Callers that + can collapse a full-range ramp (base=0, stride=1, lanes=buffer_dim) + into a ``Slice`` should do so before calling here, but the encoding + survives if they don't. + + ``tir.Var`` -> ``VarRef`` (identity-typed). The active fold-call + registry canonicalises so visits of the same Var object always + yield the same VarRef, and two distinct Vars sharing a name raise. + """ + if isinstance(expr, (int,)): + return int(expr) + if isinstance(expr, tir.IntImm): + return int(expr.value) + if isinstance(expr, tir.Var): + return _vref(expr) + if isinstance(expr, tir.Add): + return {"op": "add", "args": [_index_expr(expr.a), _index_expr(expr.b)]} + if isinstance(expr, tir.Sub): + return {"op": "sub", "args": [_index_expr(expr.a), _index_expr(expr.b)]} + if isinstance(expr, tir.Mul): + return {"op": "mul", "args": [_index_expr(expr.a), _index_expr(expr.b)]} + if isinstance(expr, tir.FloorDiv): + return {"op": "fdiv", "args": [_index_expr(expr.a), _index_expr(expr.b)]} + if isinstance(expr, tir.FloorMod): + return {"op": "fmod", "args": [_index_expr(expr.a), _index_expr(expr.b)]} + if isinstance(expr, tir.Ramp): + return { + "op": "ramp", + "args": [_index_expr(expr.base), _index_expr(expr.stride), + int(expr.lanes)], + } + raise FoldError( + f"unsupported index expression type {type(expr).__name__}: {expr!r}" + ) + + +def _is_full_range_ramp_for_dim(idx, dim_extent: int) -> bool: + """True if ``idx`` is a Ramp(0, 1, dim_extent) — equivalent to a + whole-axis Slice on a buffer dim of size ``dim_extent``.""" + if isinstance(idx, tir.Ramp): + if (isinstance(idx.base, tir.IntImm) and int(idx.base.value) == 0 + and isinstance(idx.stride, tir.IntImm) and int(idx.stride.value) == 1 + and int(idx.lanes) == dim_extent): + return True + return False + + +# --------------------------------------------------------------------------- +# Region call → BufferRef +# --------------------------------------------------------------------------- + + +def _region_to_ref(call: tir.Call, + buf_table: Dict[str, BufferDef]) -> BufferRef: + """Convert ``tl.tileop.region(BufferLoad(buf, [starts]), mode, *extents)`` + into a BufferRef. + + Indexing convention for mid_ir: + * Where extent equals the buffer's full extent on that axis, + the index is a Slice (whole-axis access). + * Otherwise the index is the BufferLoad's start (a literal int, + a var name, or a compound expression). + + This matches the way the kernel author wrote the access: a sliced + access ``Q_hbm[0, q*rows, by, 0]`` has whole-axis on dims 1 and 3 + (not really — but on our flash_attention_min slice with extents + [1, rows, 1, hlen] only dims 1 and 3 cover the full HBM axis, so + mark only those as Slice). The point: mid_ir's BufferRef tells + later passes whether a dim is "fully consumed" so cluster-dim + rewrites know what to leave alone vs index into. + """ + # tilelang's reduce ABI sometimes passes a bare BufferLoad as the + # src/dst arg instead of wrapping it in a tl.tileop.region call. + # Treat that as a whole-buffer reference: any dim whose index is + # 0-IntImm OR a full-range Ramp(0, 1, dim_extent) becomes Slice; + # everything else is the literal index. (Ramp shows up because + # tilelang's reduce passes vectorized loads — Ramp(0,1,N) means + # "the whole N-wide range".) + if isinstance(call, tir.BufferLoad): + load = call + buf_def = buf_table.get(load.buffer.name) + if buf_def is None: + buf_def = _buffer_def(load.buffer, default_scope="shared") + buf_table[load.buffer.name] = buf_def + indices: List = [] + for axis, idx in enumerate(load.indices): + dim_extent = int(buf_def.shape[axis]) if axis < len(buf_def.shape) else None + if isinstance(idx, tir.IntImm) and int(idx.value) == 0: + indices.append(Slice()) + elif (dim_extent is not None + and _is_full_range_ramp_for_dim(idx, dim_extent)): + indices.append(Slice()) + else: + indices.append(_index_expr(idx)) + return BufferRef(buffer=buf_def, indices=indices) + + args = _call_args(call) + if not args: + raise FoldError("empty region call") + load = args[0] + if not isinstance(load, tir.BufferLoad): + raise FoldError(f"region first arg is not a BufferLoad: {load!r}") + starts = list(load.indices) + extents = list(args[2:]) # args[1] is mode + if len(starts) != len(extents): + raise FoldError( + f"region rank mismatch: starts={len(starts)} extents={len(extents)}" + ) + # Reconcile to a BufferDef (the buffer may have been seen via + # decl_buffer here for the first time). + buf_def = buf_table.get(load.buffer.name) + if buf_def is None: + buf_def = _buffer_def(load.buffer, default_scope="shared") + buf_table[load.buffer.name] = buf_def + indices: List = [] + for axis, (s, e) in enumerate(zip(starts, extents)): + if not isinstance(e, tir.IntImm): + raise FoldError( + f"region extent on axis {axis} of {buf_def.name!r} is " + f"non-static: {e!r}" + ) + e_int = int(e.value) + # Slice when extent matches buffer dim AND the start is 0 + # OR start is a full-range Ramp (vectorized whole-axis load). + is_zero_start = isinstance(s, tir.IntImm) and int(s.value) == 0 + is_full_ramp = _is_full_range_ramp_for_dim(s, e_int) + if e_int == buf_def.shape[axis] and (is_zero_start or is_full_ramp): + indices.append(Slice()) + elif e_int > 1: + # Partial range with extent > 1: preserve both start and extent + # via a compound "ranged_slice" expression so downstream passes + # (to_plena _ref_extents / _render_idx) can recover the tile + # width. Keeps the mid_ir IndexExpr taxonomy unchanged. + indices.append({ + "op": "ranged_slice", + "args": [_index_expr(s), e_int], + }) + else: + indices.append(_index_expr(s)) + return BufferRef(buffer=buf_def, indices=indices) + + +# --------------------------------------------------------------------------- +# BufferStore RHS → mid_ir Op recogniser +# --------------------------------------------------------------------------- + + +_BIN_NODE_TO_OP = { + tir.Add: BinOp.ADD, + tir.Sub: BinOp.SUB, + tir.Mul: BinOp.MUL, + # Max / Min are tir.Max / tir.Min; tested below. +} + + +def _peel_cast_roundtrip(expr, target_dtype: Optional[str] = None): + """Strip TVM-inserted fp16↔fp32 cast roundtrips. + + TVM lowers ``fp16 = 1.0 / fp16`` to + ``Cast(fp16, Cast(fp32, 1.0) / Cast(fp32, x_fp16))``. PLENA HW does + the reciprocal in fp16 natively, so for fold-pattern matching we + want to see the inner expression as if no widening happened. + + The strategy: + 1. If ``expr`` is ``Cast(T, x)`` where ``x.dtype == T`` → return + ``x`` (no-op cast). + 2. If ``expr`` is ``Cast(T, Cast(_, x))`` where ``x.dtype == T`` + → return ``x`` (widen-then-narrow roundtrip). + 3. If ``expr`` is ``Cast(T, arith_expr)`` where ``arith_expr`` is + a Div / Add / Sub / Mul / Max / Min / unary Call whose operands + are themselves ``Cast(_, leaf)`` widening originals from dtype + ``T`` → return a rebuilt ``arith_expr`` whose operands are the + peeled leaves (so the whole thing is now in dtype ``T``). + 4. Otherwise return ``expr`` unchanged. + """ + if not isinstance(expr, tir.Cast): + return expr + target = str(expr.dtype) if target_dtype is None else target_dtype + inner = expr.value + # Rule 2: nested cast roundtrip. + if isinstance(inner, tir.Cast): + innermost = inner.value + if str(innermost.dtype) == target: + return innermost + return expr + # Rule 1: redundant same-dtype cast. + inner_dtype = getattr(inner, "dtype", None) + if inner_dtype is not None and str(inner_dtype) == target: + return inner + # Rule 3: widen-op-narrow over arithmetic. Peel each operand and, + # if every operand is either a leaf already in ``target`` or a + # ``Cast(_, x)`` from ``target``, rebuild the arith expression at + # ``target`` dtype. + def _peel_to_target(e): + if isinstance(e, tir.Cast): + x = e.value + if str(getattr(x, "dtype", "")) == target: + return x + # Constant under cast (e.g. Cast(fp32, FloatImm(1.0))) — + # re-emit the constant at target dtype directly. + if isinstance(x, tir.IntImm): + return tir.IntImm(target, int(x.value)) + if isinstance(x, tir.FloatImm): + return tir.FloatImm(target, float(x.value)) + return None + # Bare literals: re-emit at target dtype so binop dtypes match. + if isinstance(e, tir.IntImm): + return tir.IntImm(target, int(e.value)) + if isinstance(e, tir.FloatImm): + return tir.FloatImm(target, float(e.value)) + if str(getattr(e, "dtype", "")) == target: + return e + return None + + cls = type(inner) + if cls in (tir.Add, tir.Sub, tir.Mul, tir.Div, tir.Max, tir.Min): + a = _peel_to_target(inner.a) + b = _peel_to_target(inner.b) + if a is not None and b is not None: + return cls(a, b) + return expr + if isinstance(inner, tir.Call) and len(inner.args) == 1: + a = _peel_to_target(inner.args[0]) + if a is not None: + return tir.Call(target, inner.op, [a]) + return expr + return expr + + +def _try_bin_op(node) -> Optional[BinOp]: + cls = type(node) + if cls in _BIN_NODE_TO_OP: + return _BIN_NODE_TO_OP[cls] + if cls is tir.Max: + return BinOp.MAX + if cls is tir.Min: + return BinOp.MIN + return None + + +def _try_unary_call(node) -> Optional[UnaryOp]: + """Recognise T.exp(x), T.sqrt(x). Reciprocal shows up as + ``1.0 / x`` (a tir.Div), not as a Call — handled separately.""" + if not isinstance(node, tir.Call): + return None + op_name = getattr(node.op, "name", "") + if op_name == "tir.exp": + return UnaryOp.EXP + if op_name == "tir.sqrt": + return UnaryOp.SQRT + return None + + +# --------------------------------------------------------------------------- +# Per-loop folder (T.Parallel / T.serial elementwise) +# --------------------------------------------------------------------------- + + +def _store_to_ref(store: tir.BufferStore, + buf_table: Dict[str, BufferDef]) -> BufferRef: + name = store.buffer.name + if name not in buf_table: + buf_table[name] = _buffer_def(store.buffer, default_scope="shared") + return BufferRef( + buffer=buf_table[name], + indices=[_index_expr(i) for i in store.indices], + ) + + +def _load_to_ref(load: tir.BufferLoad, + buf_table: Dict[str, BufferDef]) -> BufferRef: + name = load.buffer.name + if name not in buf_table: + buf_table[name] = _buffer_def(load.buffer, default_scope="shared") + return BufferRef( + buffer=buf_table[name], + indices=[_index_expr(i) for i in load.indices], + ) + + +def _try_fold_parallel(for_stmt: tir.For, + buf_table: Dict[str, BufferDef]) -> Optional[Elementwise]: + """``for i in T.Parallel(N): dst[..., i] = expr(loads_at_..._i)`` + → Elementwise. + + Produces ``axis=-1`` (acts on the last axis only — other axes are + independent / fired once per element) and ``size=N`` (one HW + vector instruction covers ``N`` elements in one issue).""" + if for_stmt.kind != tir.ForKind.PARALLEL: + return None + body = for_stmt.body + if not isinstance(body, tir.BufferStore): + return None + extent = (int(for_stmt.extent.value) + if isinstance(for_stmt.extent, tir.IntImm) else 1) + return _try_fold_store( + body, for_stmt.loop_var, buf_table, axis=-1, size=extent, + ) + + +def _index_exprs_equal(a, b) -> bool: + """Cheap structural equality on two mid_ir IndexExpr values. Used + to decide whether a src's index tuple is a prefix of dst's (which + is how we detect a broadcast).""" + if isinstance(a, Slice) and isinstance(b, Slice): + return True + if isinstance(a, int) and isinstance(b, int): + return a == b + if isinstance(a, VarRef) and isinstance(b, VarRef): + return a == b # identity-based equality + if isinstance(a, dict) and isinstance(b, dict): + if a.get("op") != b.get("op"): + return False + aa, bb = a.get("args", []), b.get("args", []) + if len(aa) != len(bb): + return False + return all(_index_exprs_equal(x, y) for x, y in zip(aa, bb)) + return False + + +def _wrap_src(load: tir.BufferLoad, + dst_indices: List, + buf_table: Dict[str, BufferDef], + dst_buf: Optional[BufferDef] = None, + ) -> Optional[Union[BufferRef, Broadcast]]: + """Convert a BufferLoad src into either a plain BufferRef (when + its index tuple matches dst's exactly) or a Broadcast (when it's + a prefix — fewer trailing dims). + + Broadcast detection rule: src.indices == dst.indices[:len(src)] + AND len(src) < len(dst). The missing trailing dims become + ``broadcast_dims``. + + Special case: FPRAM scalar fragment stores (rank-1 + ``local.fragment`` dst) lower to one ``S_*_FP`` per call, with + every operand carrying its own independent scalar address. Index + matching is irrelevant — return the BufferRef as-is so the FPRAM + scalar emitter can use ``src.indices`` directly. Used by kernels + like RoPE that compute pair-swap addresses ``X_FP[2*i+1]`` against + ``__tmp_fp_0[2*i]``. + + Anything else (mismatched non-prefix shapes on SIMD-style refs) + returns None so fold fails loudly. + """ + src_ref = _load_to_ref(load, buf_table) + src_idx = src_ref.indices + is_fpram_scalar_dst = ( + dst_buf is not None + and dst_buf.scope in ("local.fragment", "fragment", "fragment.fpram") + and len(dst_buf.shape) == 1 + ) + if is_fpram_scalar_dst: + # FPRAM scalar lower-cycle: each operand is just a scalar + # address, no SIMD axis to align. + return src_ref + if len(src_idx) == len(dst_indices): + # Same rank — every position must be index-equal to dst's + # for this to be a valid per-element op. + if all(_index_exprs_equal(s, d) for s, d in zip(src_idx, dst_indices)): + return src_ref + elif len(src_idx) < len(dst_indices): + # Broadcast: src must equal a prefix of dst. + prefix = dst_indices[:len(src_idx)] + if all(_index_exprs_equal(s, p) for s, p in zip(src_idx, prefix)): + broadcast_dims = list(range(len(src_idx), len(dst_indices))) + return Broadcast(src=src_ref, broadcast_dims=broadcast_dims) + return None + + +def _to_raw_store(store: tir.BufferStore, + buf_table: Dict[str, BufferDef]) -> RawStore: + """Wrap a BufferStore as an opaque RawStore. + + The ``dst`` BufferRef is computed from the store's indices via the + same ``_index_expr`` machinery as elsewhere — so e.g. ``buf[MLEN+k]`` + becomes ``BufferRef(buf, [{op:add, args:[64, "k"]}])`` and the + ``value`` payload is the raw ``store.value`` PrimExpr (opaque). + """ + return RawStore( + dst=_store_to_ref(store, buf_table), + value=store.value, + ) + + +def _axes_for_ref(ref: BufferRef, + simd_axis: Optional[int], + simd_size: int) -> List[AxisInfo]: + """Build per-axis AxisInfo for a BufferRef given the op's SIMD context. + + Each axis carries its **buffer-declared extent** plus a role: + * SIMD with extent ``simd_size`` for the axis the op vectorises + along (``simd_axis`` index, negative wraps). + * BATCH with the buffer's declared extent for every other axis. + The op fan-outs that many times along that dim; whether or + not an outer for-loop is currently in scope at fold time is + a fold detail (the for may even get absorbed) — the axis's + fan-out cardinality stays the same regardless. Lower wraps + outer fors based on this extent. + + ``simd_axis is None`` → whole-buffer SIMD: every dim is SIMD at + its full extent. + + CLUSTER role is never assigned at this point — the cluster axis + is prepended later by pass_3_split / pass_4b_view alongside the + indices change. + """ + out: List[AxisInfo] = [] + rank = len(ref.indices) + normalised_simd = ( + simd_axis + rank if (simd_axis is not None and simd_axis < 0) + else simd_axis + ) + for dim, extent_decl in enumerate(ref.buffer.shape): + full = int(extent_decl) + if normalised_simd is not None and dim == normalised_simd: + out.append(AxisInfo(role=AxisRole.SIMD, extent=int(simd_size))) + elif simd_axis is None: + out.append(AxisInfo(role=AxisRole.SIMD, extent=full)) + else: + out.append(AxisInfo(role=AxisRole.BATCH, extent=full)) + return out + + +def _axes_for_broadcast_src( + bc: Broadcast, + simd_axis: Optional[int], + simd_size: int, +) -> List[AxisInfo]: + """Per-axis AxisInfo for the inner ref of a Broadcast. + + The dst rank exceeds the src rank by ``len(bc.broadcast_dims)``; + those dst-side axes are tagged BROADCAST and don't appear in the + src's axes list at all. The src's own axes use the same rules as + ``_axes_for_ref`` but with the SIMD axis index translated to its + position inside the src's (shorter) shape, if applicable. + """ + # Dst-side SIMD axis index. If it lives in a broadcast_dim, the + # src side has no corresponding axis to be SIMD; we leave all of + # the src's axes as BATCH (its values are broadcast onto the SIMD + # dim from outside). + rank_src = len(bc.src.indices) + if simd_axis is None: + return [ + AxisInfo(role=AxisRole.SIMD, extent=int(d)) + for d in bc.src.buffer.shape + ] + # Map dst-side simd index → src-side index by counting how many + # broadcast_dims sit at or before it. + bd_set = set(bc.broadcast_dims) + if simd_axis in bd_set: + src_simd: Optional[int] = None + else: + # Number of broadcast_dims with smaller index — drop them from + # the dst index to land on the matching src dim. + offset = sum(1 for b in bd_set if b < simd_axis) + src_simd = simd_axis - offset + if not (0 <= src_simd < rank_src): + src_simd = None + return _axes_for_ref(bc.src, src_simd, simd_size) + + +def _try_fold_store(store: tir.BufferStore, + parallel_var: Optional[tir.Var], + buf_table: Dict[str, BufferDef], + axis: Optional[int] = None, + size: int = 1) -> Optional[Elementwise]: + """Recognise the RHS of ``store`` as a mid_ir-expressible + Elementwise. Returns None on no-match — caller is responsible for + falling back to ``RawStore`` rather than raising. Never raises. + + ``size`` is the per-issue element count for the resulting + Elementwise (see :class:`Elementwise.size`): 1 for a scalar store + (SISD), N for a folded ``T.Parallel(N)`` vector store (SIMD). + """ + if not store.indices: + return None + if parallel_var is not None: + last = store.indices[-1] + if not (isinstance(last, tir.Var) and last.same_as(parallel_var)): + return None + # Compound indices (e.g. ``buf[2 * i + 1] = ...``) are kept as + # affine PrimExpr in the resulting Elementwise/BufferRef. Used by + # kernels like RoPE that compute pair offsets ``e = 2*i, + # o = 2*i + 1`` and write into a per-pair fragment. Downstream + # ``_render_idx_as_primexpr`` materialises them. + dst = _store_to_ref(store, buf_table) + # Peel TVM's fp16↔fp32 cast roundtrip so reciprocal / binop / unary + # matchers below see a clean fp16 expression tree. + expr = _peel_cast_roundtrip(store.value, target_dtype=str(store.buffer.dtype)) + + def _build_axes(srcs): + return ( + _axes_for_ref(dst, axis, size), + [ + _axes_for_broadcast_src(s, axis, size) + if isinstance(s, Broadcast) + else _axes_for_ref(s, axis, size) + for s in srcs + ], + ) + + # Constant fill: only 0 maps cleanly to a HW vector op. Anything + # else falls back to RawStore (the caller wraps it). + if isinstance(expr, (tir.IntImm, tir.FloatImm)): + if float(expr.value) != 0.0: + return None + dst_axes, src_axes = _build_axes([]) + return Elementwise( + dst=dst, srcs=[], op=UnaryOp.COPY, + axis=axis, size=size, + dst_axes=dst_axes, src_axes=src_axes, + ) + + # Unary: T.exp(x), T.sqrt(x). + unary = _try_unary_call(expr) + if unary is not None: + if len(expr.args) != 1: + return None + a = _peel_cast_roundtrip(expr.args[0]) + if not isinstance(a, tir.BufferLoad): + return None + wrapped = _wrap_src(a, dst.indices, buf_table, dst_buf=dst.buffer) + if wrapped is None: + return None + dst_axes, src_axes = _build_axes([wrapped]) + return Elementwise( + dst=dst, srcs=[wrapped], op=unary, + axis=axis, size=size, + dst_axes=dst_axes, src_axes=src_axes, + ) + + # Reciprocal: 1.0 / x. + if isinstance(expr, tir.Div): + a = _peel_cast_roundtrip(expr.a) + b = _peel_cast_roundtrip(expr.b) + if (isinstance(a, (tir.IntImm, tir.FloatImm)) + and float(a.value) == 1.0 + and isinstance(b, tir.BufferLoad)): + wrapped = _wrap_src(b, dst.indices, buf_table, dst_buf=dst.buffer) + if wrapped is None: + return None + dst_axes, src_axes = _build_axes([wrapped]) + return Elementwise( + dst=dst, srcs=[wrapped], op=UnaryOp.RECI, + axis=axis, size=size, + dst_axes=dst_axes, src_axes=src_axes, + ) + return None + + # Pure copy: dst[idx] = src[idx]. + if isinstance(expr, tir.BufferLoad): + wrapped = _wrap_src(expr, dst.indices, buf_table, dst_buf=dst.buffer) + if wrapped is None: + return None + dst_axes, src_axes = _build_axes([wrapped]) + return Elementwise( + dst=dst, srcs=[wrapped], op=UnaryOp.COPY, + axis=axis, size=size, + dst_axes=dst_axes, src_axes=src_axes, + ) + + # Binary: A op B (each a BufferLoad — may broadcast independently). + binop = _try_bin_op(expr) + if binop is not None: + srcs: List[Union[BufferRef, Broadcast]] = [] + for arg in (expr.a, expr.b): + arg = _peel_cast_roundtrip(arg) + if isinstance(arg, tir.BufferLoad): + wrapped = _wrap_src(arg, dst.indices, buf_table, dst_buf=dst.buffer) + if wrapped is None: + return None + srcs.append(wrapped) + else: + # Scalar literal / compound expr in binop → not foldable. + return None + dst_axes, src_axes = _build_axes(srcs) + return Elementwise( + dst=dst, srcs=srcs, op=binop, + axis=axis, size=size, + dst_axes=dst_axes, src_axes=src_axes, + ) + + return None + + +# --------------------------------------------------------------------------- +# Reduce / Gemm / Dma extern recognisers +# --------------------------------------------------------------------------- + + +_REDUCE_OPS_BY_NAME = { + "max": ReduceOp.MAX, + "sum": ReduceOp.SUM, + "min": ReduceOp.MIN, +} + + +def _fold_reduce(call: tir.Call, + buf_table: Dict[str, BufferDef]) -> Reduce: + """``tl.tileop.reduce(src, dst, op_name, dim, clear)``. + + Tilelang's reduce ABI varies — args[0] / args[1] are always + src/dst (either a region call or a bare BufferLoad). The op-name + StringImm and the dim IntImm can sit in different positions + depending on tilelang version (we've seen op at arg[2] / dim at + arg[3], and dim at arg[2] / op at arg[4]). Scan args[2:] to + pick them out by type. + """ + args = _call_args(call) + if len(args) < 4: + raise FoldError(f"tl.tileop.reduce: expected ≥4 args, got {len(args)}") + src_ref = _region_to_ref(args[0], buf_table) + dst_ref = _region_to_ref(args[1], buf_table) + + op_name: Optional[str] = None + axis: Optional[int] = None + for cand in args[2:]: + if op_name is None and isinstance(cand, tir.StringImm): + op_name = str(cand.value).lower() + elif axis is None and isinstance(cand, tir.IntImm): + # First IntImm after the regions is the dim. (clear=0/1 + # also IntImm, but we only need one.) + axis = int(cand.value) + if op_name is not None and axis is not None: + break + if op_name is None: + raise FoldError( + f"tl.tileop.reduce: cannot determine op kind from args={args!r}" + ) + if axis is None: + raise FoldError( + f"tl.tileop.reduce: cannot determine dim from args={args!r}" + ) + op = _REDUCE_OPS_BY_NAME.get(op_name) + if op is None: + raise FoldError(f"unknown reduce op {op_name!r}") + # Build axes: dst is one rank lower than src; the collapsed + # axis is tagged REDUCE on src, every other axis is BATCH. + src_rank = len(src_ref.indices) + normalised_axis = axis + src_rank if axis < 0 else axis + src_axes: List[AxisInfo] = [] + for dim, ext in enumerate(src_ref.buffer.shape): + full = int(ext) + if dim == normalised_axis: + src_axes.append(AxisInfo(role=AxisRole.REDUCE, extent=full)) + else: + src_axes.append(_axes_for_ref(src_ref, None, 0)[dim]) + # ``_axes_for_ref(..., None, 0)`` returned SIMD across every dim; + # we only want SIMD treatment when dim is REDUCE's neighbor on the + # SIMD axis, which Reduce doesn't have. Force BATCH instead. + src_axes[-1] = AxisInfo(role=AxisRole.BATCH, extent=full) + dst_axes: List[AxisInfo] = [] + for dim, ext in enumerate(dst_ref.buffer.shape): + dst_axes.append(AxisInfo(role=AxisRole.BATCH, extent=int(ext))) + return Reduce( + dst=dst_ref, src=src_ref, op=op, axis=axis, + dst_axes=dst_axes, src_axes=src_axes, + ) + + +def _fold_dma(call: tir.Call, + buf_table: Dict[str, BufferDef]) -> Dma: + args = _call_args(call) + if len(args) < 2: + raise FoldError(f"tl.tileop.copy: expected 2 args, got {len(args)}") + src_ref = _region_to_ref(args[0], buf_table) + dst_ref = _region_to_ref(args[1], buf_table) + # Default DMA axis tagging: the innermost dim is SIMD (one HW + # vector load/store moves a contiguous mlen-aligned run along it), + # every other dim is BATCH (the kernel fans out one HW issue per + # index along it). View pass prepends a CLUSTER axis when the + # buffer is lane-aware. This matches the per-axis story Elementwise + # uses for default ``axis=-1, size=last_dim``. + def _default_axes(ref: BufferRef) -> List[AxisInfo]: + shape = ref.buffer.shape + out: List[AxisInfo] = [] + for i, d in enumerate(shape): + role = AxisRole.SIMD if i == len(shape) - 1 else AxisRole.BATCH + out.append(AxisInfo(role=role, extent=int(d))) + return out + return Dma( + src=src_ref, dst=dst_ref, + src_axes=_default_axes(src_ref), + dst_axes=_default_axes(dst_ref), + ) + + +def _fold_gemm(call: tir.Call, + kind: str, + buf_table: Dict[str, BufferDef]) -> Gemm: + args = _call_args(call) + if len(args) < 3: + raise FoldError(f"tl.tileop.gemm_py: expected ≥3 args, got {len(args)}") + a = _region_to_ref(args[0], buf_table) + b = _region_to_ref(args[1], buf_table) + c = _region_to_ref(args[2], buf_table) + # tilelang's gemm extern ABI: args[3..] include transpose flags as + # IntImm 0/1. Order is (transpose_a, transpose_b) per gemm_macros + # docstring. Accept either position; default both False. + ta, tb = False, False + flags = [a for a in args[3:] if isinstance(a, tir.IntImm)] + if len(flags) >= 1: + ta = bool(int(flags[0].value)) + if len(flags) >= 2: + tb = bool(int(flags[1].value)) + # Tag Gemm operand axes with their algebra roles. At fold time the + # refs are rank-2 (pre-split lane prepend), so the labelling is + # unambiguous from the matmul algebra: + # + # c = a @ b -> c is [M, N] + # a is [M, K] (transpose_a flips to [K, M]) + # b is [K, N] (transpose_b flips to [N, K]) + # + # split prepends an extra CLUSTER axis on lane-aware operands; + # view/burn_view permute the axes alongside indices. Downstream + # lowering (e.g. ``_lower_bare_per_head_gemm``) reads ``c_axes`` to + # locate GEMM_M without scanning shape extents. + def _pair(ref, roles): + rank = len(ref.buffer.shape) + if rank != 2: + # Fold sees pre-split rank-2 operands. If a kernel author + # ever hands us a rank-3 tile (unlikely but possible), + # leave axes empty — downstream paths will need to handle + # it explicitly. + return [] + return [ + AxisInfo(role=roles[i], extent=int(ref.buffer.shape[i])) + for i in range(rank) + ] + + a_roles = (AxisRole.GEMM_K, AxisRole.GEMM_M) if ta else (AxisRole.GEMM_M, AxisRole.GEMM_K) + b_roles = (AxisRole.GEMM_N, AxisRole.GEMM_K) if tb else (AxisRole.GEMM_K, AxisRole.GEMM_N) + c_roles = (AxisRole.GEMM_M, AxisRole.GEMM_N) + return Gemm( + a=a, b=b, c=c, + transpose_a=ta, transpose_b=tb, kind=kind, + a_axes=_pair(a, a_roles), + b_axes=_pair(b, b_roles), + c_axes=_pair(c, c_roles), + ) + + +# --------------------------------------------------------------------------- +# Walker — produces a flat list of mid_ir Stmt +# --------------------------------------------------------------------------- + + +def _tir_for_kind_name(stmt: tir.For) -> str: + """Return the lowercase tilelang ForKind name (``"serial"`` / + ``"parallel"`` / ``"unrolled"`` / ...). Used to pick between For + and ParallelAxis(CLUSTER) when a T.Parallel doesn't fold into an + Elementwise.""" + try: + return tir.ForKind(int(stmt.kind)).name.lower() + except Exception: + return "serial" + + +def _mid_for_kind(name: str) -> str: + """Map a tilelang for-kind name to the mid-IR For.kind string. + + Values: + * ``"unroll"`` — fully unrolled at codegen + * ``"parallel"`` — ``T.Parallel`` that didn't fold into a vector + op and isn't lifted to a ParallelAxis; preserved here as a + hint that the body is order-independent (no cross-iter + dependency). Downstream passes treat this exactly like + ``"serial"`` except the HLIR for-op carries an + ``annotations["order_independent"] = True`` flag, which the + v2 backend uses to drop the IntRAM idx slot + per-iter + LD/ADDI/ST overhead in favor of running the hw counter as + the loop var directly. + * ``"serial"`` — default, strict-order loop + """ + if name == "unrolled" or name == "unroll": + return "unroll" + if name == "parallel": + return "parallel" + return "serial" + + +def _outer_loop_matches_buffer_axis(dst: BufferRef, + loop_var: tir.Var, + extent: int) -> bool: + """True when ``dst.indices`` references ``loop_var`` (by identity) + on a non-last axis whose buffer extent equals ``extent``. Used to + decide whether an outer ``for row`` is redundant on top of an + already-whole-buffer Elementwise.""" + target = _vref(loop_var) + shape = dst.buffer.shape + if len(dst.indices) != len(shape): + return False + for axis, idx in enumerate(dst.indices): + if axis == len(dst.indices) - 1: + continue # inner axis = the one Elementwise(axis=-1) already covers + if (isinstance(idx, VarRef) and idx == target + and int(shape[axis]) == extent): + return True + return False + + +def _index_expr_uses_varref(idx, target: VarRef) -> bool: + """Recursively look for ``target`` (by identity) inside an IndexExpr. + + Used to decide whether an outer ``for row`` can be absorbed into an + already-folded inner Elementwise: if any Broadcast src still + references the outer var, absorbing the for would leave the var + unbound. + """ + if isinstance(idx, VarRef): + return idx == target + if isinstance(idx, dict): + return any(_index_expr_uses_varref(a, target) + for a in idx.get("args", [])) + return False + + +def _elementwise_refs_var(ew, target: VarRef) -> bool: + """True if any ``Broadcast`` src of ``ew`` references ``target`` + (by identity) in its indices. + + Only the broadcast case is problematic: ``_wrap_src`` keeps the + Broadcast's indices as a prefix of dst's, so the outer var + is preserved literally in the source. Absorbing the outer for then + leaves the var referenced but unbound. + + Same-rank BufferRef srcs already have indices that match dst's one + for one, so absorbing the outer for is symmetric — both dst and + src lose their reference to the var simultaneously and the + whole-buffer Elementwise stands on its own. + """ + for src in ew.srcs: + if not isinstance(src, Broadcast): + continue + for idx in src.src.indices: + if _index_expr_uses_varref(idx, target): + return True + return False + + +def _is_serial_for(stmt: tir.For) -> bool: + return stmt.kind == tir.ForKind.SERIAL + + +def _walk_stmt(stmt, + buf_table: Dict[str, BufferDef], + current_kind: Optional[str]) -> List: + """Walk one TIR Stmt, return a list of mid_ir Stmt items. + + A single TIR construct may unfold into 0, 1, or more mid_ir items + (e.g. a SeqStmt becomes its concatenated children; an AttrStmt for + KIND becomes nothing of its own — its body inherits ``current_kind``). + """ + if stmt is None: + return [] + if isinstance(stmt, tir.SeqStmt): + out = [] + for c in stmt.seq: + out.extend(_walk_stmt(c, buf_table, current_kind)) + return out + if isinstance(stmt, tir.BlockRealize): + for b in stmt.block.alloc_buffers: + if b.name not in buf_table: + buf_table[b.name] = _buffer_def(b, default_scope="shared") + return _walk_stmt(stmt.block.body, buf_table, current_kind) + if isinstance(stmt, tir.AttrStmt): + if stmt.attr_key == _KIND_KEY: + v = stmt.value + kind = v.value if isinstance(v, tir.StringImm) else str(v) + return _walk_stmt(stmt.body, buf_table, current_kind=kind) + if (stmt.attr_key == "thread_extent" + and isinstance(stmt.node, tir.IterVar)): + iv = stmt.node + inner = _walk_stmt(stmt.body, buf_table, current_kind) + ext_val = stmt.value + if not isinstance(ext_val, tir.IntImm): + raise FoldError( + f"thread_extent {iv.var.name!r} non-static: {ext_val!r}" + ) + # blockIdx grid binding from T.Kernel → ParallelAxis(BLOCK_IDX). + # Mid-IR keeps multi-thread semantics: this is NOT a serial + # for, it's an SPMD parallel axis. Pass_8_to_plena flattens + # to a serial outer for at HLIR generation time. + return [ParallelAxis( + axis_name=iv.var.name, + extent=int(ext_val.value), + body=inner, + kind=ParallelKind.BLOCK_IDX, + thread_tag=str(iv.thread_tag) if iv.thread_tag else None, + axis_var=_vref(iv.var), + )] + # Unknown attr — pass through. + return _walk_stmt(stmt.body, buf_table, current_kind) + if isinstance(stmt, tir.For): + # Try fold first. + if stmt.kind == tir.ForKind.PARALLEL: + ew = _try_fold_parallel(stmt, buf_table) + if ew is not None: + return [ew] + # Serial outer wrapping a single fold-able store (1D fp scalar + # update like ``L_INV[row] = 1 / L_NEW[row]``): fold the inner + # body. The serial for is *absorbed* by the Elementwise — its + # extent becomes ``size`` so downstream lowering knows the op + # covers that many elements along ``axis=-1`` (otherwise the + # default ``size=1`` mis-types it as a single scalar and FPRAM + # unrolling won't kick in). + if _is_serial_for(stmt) and isinstance(stmt.body, tir.BufferStore): + extent = (int(stmt.extent.value) + if isinstance(stmt.extent, tir.IntImm) else 1) + ew = _try_fold_store( + stmt.body, parallel_var=stmt.loop_var, + buf_table=buf_table, axis=-1, size=extent, + ) + if ew is not None: + return [ew] + # Serial / unroll outer wrapping an inner T.Parallel that + # already folds to a whole-buffer Elementwise: the inner fold + # produced an op covering every element along ``axis=-1``; the + # outer ``for row`` re-iterates over a dim the op already + # covers. Absorb the outer loop if its extent matches the dim + # the inner Elementwise iterates over (typically the row axis + # of dst). Without this we'd emit the same whole-buffer op + # ``rows`` times. + if (_is_serial_for(stmt) or _tir_for_kind_name(stmt) == "unrolled"): + inner_body = stmt.body + if (isinstance(inner_body, tir.For) + and inner_body.kind == tir.ForKind.PARALLEL): + inner_ew = _try_fold_parallel(inner_body, buf_table) + if (inner_ew is not None + and isinstance(stmt.extent, tir.IntImm) + and _outer_loop_matches_buffer_axis( + inner_ew.dst, stmt.loop_var, int(stmt.extent.value), + ) + # Don't absorb if any Broadcast src still uses + # the outer loop var — e.g. ``dst[row,col] = + # a[row,col] * b[row]`` folds the parallel ``col`` + # away but the ``b[row]`` Broadcast still needs + # ``row`` bound by an enclosing for. Absorbing it + # would leave ``row`` referenced but unbound, + # crashing ExprMaterializer later. Pass through + # as a regular For instead so the outer loop + # keeps its scope. + and not _elementwise_refs_var( + inner_ew, _vref(stmt.loop_var), + )): + return [inner_ew] + # Pass through as a regular For. Body is recursively walked; + # any nested BufferStore that doesn't fold becomes a RawStore. + if not isinstance(stmt.extent, tir.IntImm): + raise FoldError( + f"non-static loop extent on {stmt.loop_var.name!r}: " + f"{stmt.extent!r}" + ) + body = _walk_stmt(stmt.body, buf_table, current_kind) + kind_name = _tir_for_kind_name(stmt) + if kind_name == "parallel": + # T.Parallel that didn't fold into an Elementwise (because + # the inner store had a non-elementwise pattern). Surface + # as a LOGICAL_GRID parallel axis — semantically still N + # concurrent program instances. Pass_3 may later split it + # into (LOGICAL_GRID number, CLUSTER phase) just like a + # blockIdx-bound axis. + return [ParallelAxis( + axis_name=stmt.loop_var.name, + extent=int(stmt.extent.value), + body=body, + kind=ParallelKind.LOGICAL_GRID, + thread_tag=None, + axis_var=_vref(stmt.loop_var), + )] + return [For( + loop_var=stmt.loop_var.name, + extent=int(stmt.extent.value), + body=body, + kind=_mid_for_kind(kind_name), + loop_var_var=_vref(stmt.loop_var), + )] + if isinstance(stmt, tir.LetStmt): + # Should be eliminated by inline_let_stmts; if it lingers, + # walk through and warn implicitly by losing the binding. + return _walk_stmt(stmt.body, buf_table, current_kind) + if isinstance(stmt, tir.IfThenElse): + raise FoldError("tir.IfThenElse not supported by mid_ir") + if isinstance(stmt, tir.Allocate): + # Create a BufferDef from the Allocate (raw form has only + # buffer_var). Best-effort: name from var, shape from extents. + # The body's BufferLoad/Store will see the real name if there's + # a corresponding decl_buffer; if not, this synth def stands in. + name = stmt.buffer_var.name + if name not in buf_table: + try: + buf_table[name] = BufferDef( + name=name, + shape=[int(e.value) for e in stmt.extents], + dtype=str(stmt.dtype), + scope="shared", + ) + except Exception as e: + raise FoldError( + f"could not build BufferDef from raw Allocate {name!r}: {e}" + ) + return _walk_stmt(stmt.body, buf_table, current_kind) + if isinstance(stmt, tir.Evaluate): + val = stmt.value + if not isinstance(val, tir.Call): + return [] + kind = _call_kind(val) + if kind == _TILEOP_COPY: + return [_fold_dma(val, buf_table)] + if kind == _TILEOP_GEMM: + return [_fold_gemm(val, kind=current_kind or "overwrite", + buf_table=buf_table)] + if kind == _TILEOP_REDUCE: + return [_fold_reduce(val, buf_table)] + # Unknown extern: drop with a deliberate marker. Production + # could accumulate these into a side list for diagnostics. + return [] + if isinstance(stmt, tir.BufferStore): + ew = _try_fold_store(stmt, parallel_var=None, buf_table=buf_table) + if ew is not None: + return [ew] + raise FoldError( + f"unrecognised BufferStore — every store must lower to a " + f"single elementwise / reduce / broadcast pattern. " + f"dst={stmt.buffer.name}{list(stmt.indices)} := {stmt.value!r}" + ) + raise FoldError(f"unhandled stmt type {type(stmt).__name__}") + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: tir.PrimFunc, name: str = "kernel") -> MidFunc: + """Fold a raw tir.PrimFunc into mid_ir.""" + global _active_registry + _active_registry = _VarRegistry() + try: + return _run_locked(func, name) + finally: + _active_registry = None + + +def _run_locked(func: tir.PrimFunc, name: str) -> MidFunc: + buf_table: Dict[str, BufferDef] = {} + + # Seed param buffers (always global by convention). + params: List[BufferDef] = [] + for var in func.params: + buf = func.buffer_map.get(var) + if buf is None: + continue + bd = _buffer_def(buf, default_scope="global") + # Force "global" — tilelang doesn't tag params with a scope. + bd = BufferDef(name=bd.name, shape=bd.shape, dtype=bd.dtype, scope="global") + buf_table[bd.name] = bd + params.append(bd) + + body = _walk_stmt(func.body, buf_table, current_kind=None) + + # Fold output invariant: no ``str`` may appear in any BufferRef + # indices. Bare-string indices were the cheat the VarRef rewrite + # exists to remove. If something slipped through, fail loudly here + # rather than letting fuse/to_plena silently mishandle it. + _assert_no_str_in_indices(body) + + # Allocs are everything in buf_table that isn't a param. + param_names = {p.name for p in params} + allocs = [b for n, b in buf_table.items() if n not in param_names] + + # Lane axes from func attr (T.func_attr({"plena.lane_axis": "by"}) or list). + lane_axes: List[str] = [] + if func.attrs is not None and _LANE_AXIS_FUNC_ATTR in func.attrs: + raw = func.attrs[_LANE_AXIS_FUNC_ATTR] + if isinstance(raw, tir.StringImm): + lane_axes = [str(raw.value)] + elif isinstance(raw, str): + lane_axes = [raw] + elif hasattr(raw, "__iter__"): + lane_axes = [ + str(s.value) if isinstance(s, tir.StringImm) else str(s) + for s in raw + ] + + # Carry select prim_func attrs forward so downstream passes (e.g. + # to_plena reading ``plena.layout``) can find them. We unwrap TVM + # ObjectRef strings to plain Python so dict access works uniformly. + attrs_out: Dict[str, object] = {} + if func.attrs is not None: + for k in ("plena.layout",): + if k in func.attrs: + v = func.attrs[k] + if isinstance(v, tir.StringImm): + attrs_out[k] = str(v.value) + else: + attrs_out[k] = str(v) + # ``plena.hoisted_constants`` is a {buffer_name: value} map + # stamped by the ``hoist_float_constants`` pre-pass. Unwrap to + # a plain ``Dict[str, float]`` so to_plena can iterate it + # without TVM-side type acrobatics. + if "plena.hoisted_constants" in func.attrs: + raw = func.attrs["plena.hoisted_constants"] + attrs_out["plena.hoisted_constants"] = { + str(name): float(val.value if hasattr(val, "value") else val) + for name, val in raw.items() + } + + return MidFunc( + name=name, + params=params, + allocs=allocs, + body=body, + lane_axes=lane_axes, + cluster_counts=[], # filled by pass_3 + attrs=attrs_out, + ) + + +__all__ = ["run", "FoldError"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/fuse.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/fuse.py new file mode 100644 index 0000000..8c83aec --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/fuse.py @@ -0,0 +1,461 @@ +"""pass_5_fuse: collapse each Async region into a single MultiLaneOp. + +Why this pass exists +-------------------- + +After ``async_wrap`` + ``view``, each cluster body holds: + + * ``Async(body=[one_can_async_op])`` — one async per op (strict) + * bare ``can_async=False`` ops (Reduce, Elementwise w/ Broadcast) + * possibly nested For / cluster / grid + +We collapse each Async into one ``MultiLaneOp`` that: + + * carries the underlying op as ``inner`` + * names the enclosing cluster axes via ``cluster_axis_names`` + (outermost-to-innermost order) + * exposes ``dim_map`` — for each lane-aware (non-global) buffer the + op references, the list of physical dims that correspond to each + cluster axis. Today every buffer's cluster dim is physical dim 0 + (pass_4b prepends the phase index there), so dim_map values are + always ``[0]``. Multi-axis cluster fusion would put extra entries + in this list. + +What stays untouched +-------------------- + + * Bare can_async=False ops in the cluster body — these are per-row + ops that lower to ``for lane in range(cluster)`` at pass_8. + pass_5 doesn't wrap them; they keep BufferRefs with view_perm. + * RawStore, For, ParallelAxis structure + * Buffer shapes +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List, Optional + +from ..cluster_guard import should_skip_cluster +from ..ir import ( + BufferRef, Broadcast, VarRef, + Dma, Gemm, Elementwise, Reduce, RawStore, + For, Async, MultiLaneOp, + ParallelAxis, ParallelKind, + MidFunc, Stmt, +) + + +class FuseError(RuntimeError): + pass + + +# Per-``run`` lookup: name -> VarRef for non-cluster ParallelAxes. The +# CLUSTER walker reads it to find its sibling number axis's identity. +_NUMBER_VAR_BY_NAME: dict = {} + + +@dataclass +class _ClusterAxis: + """One enclosing cluster axis as seen by fuse. + + Names are kept for HLIR dump / API surfaces (``cluster_axis_names`` + is a list of strings). Identity comparisons live on the VarRef + fields; ``_collapse_lane_axis`` only compares by identity. + + Used by ``_collapse_lane_axis`` to recognise both: + * Per-lane indices written as + ``add(phase_var, mul(number_var, count))`` (produced by + pass_4b_view for non-global buffers). + * Bare ``VarRef`` matching the original lane var + (``by``-equivalent), kept verbatim for global / global.* refs + whose indices view skipped. + Both forms collapse to ``ranged_slice(mul(number_var, count), + count)`` so multi-lane sync ops read the full cluster's chunk in + one go. + """ + phase_name: str + number_name: str + count: int + original_name: str + phase_var: VarRef + number_var: VarRef + original_var: VarRef + + +# --------------------------------------------------------------------------- +# Cluster stack — track enclosing cluster axes outermost → innermost +# --------------------------------------------------------------------------- + + +def _collect_op_refs(op) -> List[BufferRef]: + """Return every BufferRef the op directly references. Used to + build dim_map.""" + refs: List[BufferRef] = [] + if isinstance(op, Dma): + refs.extend([op.src, op.dst]) + elif isinstance(op, Gemm): + refs.extend([op.a, op.b, op.c]) + elif isinstance(op, Elementwise): + refs.append(op.dst) + for s in op.srcs: + if isinstance(s, Broadcast): + refs.append(s.src) + else: + refs.append(s) + elif isinstance(op, Reduce): + refs.extend([op.dst, op.src]) + elif isinstance(op, RawStore): + refs.append(op.dst) + return refs + + +def _build_dim_map(op, cluster_axis_names: List[str]) -> Dict[str, List[int]]: + """For each non-global buffer the op touches, record the physical + dims that map to each cluster axis (in cluster_axis_names order). + + Reads ``ref.buffer.cluster_dim`` directly — split sets it on + cluster-expanded buffers and view/burn_view permute it along with + any axis reshuffle. For ``global.*`` user caches (not + cluster-expanded), ``cluster_dim is None`` and the buffer is + excluded from the map. + """ + n_axes = len(cluster_axis_names) + out: Dict[str, List[int]] = {} + refs = op.list_refs() if hasattr(op, "list_refs") else _collect_op_refs(op) + for ref in refs: + if ref.buffer.scope == "global": + continue + cdim = ref.buffer.cluster_dim + if cdim is None: + continue + # Single-axis cluster today: emit ``[cdim]`` for every entry + # in cluster_axis_names. Multi-axis support would carry an + # ordered list on the BufferDef. + out[ref.buffer.name] = [cdim] * n_axes + return out + + +# --------------------------------------------------------------------------- +# Walker +# --------------------------------------------------------------------------- + + +def _walk(stmt: Stmt, cluster_stack: List[_ClusterAxis]) -> Stmt: + if isinstance(stmt, ParallelAxis): + if stmt.kind == ParallelKind.CLUSTER: + if stmt.parent_grid_axis_name is None: + raise FuseError( + f"cluster axis {stmt.axis_name!r} missing " + f"parent_grid_axis_name; pass_3_split should have set it" + ) + # Read the user-visible original axis name straight off the + # ParallelAxis (set by pass_3_split). Parsing string + # suffixes (``"_phase"`` / ``"_number"``) used to work but + # made the contract fragile against any future renaming + # scheme; ``original_axis_name`` is the explicit channel. + phase = stmt.axis_name + original = stmt.original_axis_name + if original is None: + raise FuseError( + f"cluster axis {phase!r} missing original_axis_name; " + f"pass_3_split should have set it" + ) + if stmt.axis_var is None or stmt.original_axis_var is None: + raise FuseError( + f"cluster axis {phase!r}: identity fields " + f"(axis_var / original_axis_var) must be set by split" + ) + number_var = _NUMBER_VAR_BY_NAME.get(stmt.parent_grid_axis_name) + if number_var is None: + raise FuseError( + f"cluster {phase!r}: number axis " + f"{stmt.parent_grid_axis_name!r} VarRef not recorded; " + f"split must emit ``number -> phase`` nesting" + ) + new_stack = cluster_stack + [_ClusterAxis( + phase_name=phase, + number_name=stmt.parent_grid_axis_name, + count=stmt.extent, + original_name=original, + phase_var=stmt.axis_var, + number_var=number_var, + original_var=stmt.original_axis_var, + )] + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=[_walk(s, new_stack) for s in stmt.body], + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + # Non-cluster ParallelAxis. Record axis_var by name so a nested + # CLUSTER can pick up the matching number VarRef. + if stmt.axis_var is not None: + _NUMBER_VAR_BY_NAME[stmt.axis_name] = stmt.axis_var + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=[_walk(s, cluster_stack) for s in stmt.body], + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + if isinstance(stmt, For): + return For( + loop_var=stmt.loop_var, + extent=stmt.extent, + body=[_walk(s, cluster_stack) for s in stmt.body], + kind=stmt.kind, + loop_var_var=stmt.loop_var_var, + ) + if isinstance(stmt, Async): + return _fuse_async(stmt, cluster_stack) + if isinstance(stmt, MultiLaneOp): + # Already fused (idempotency). Recurse into inner just in case + # — though typically inner is a leaf. + return stmt + # Leaf: pass through. + return stmt + + +def _match_lane_composite(idx, axes: List[_ClusterAxis]): + """If ``idx`` is exactly a lane index — either the bare original + lane var, or the ``add(phase, mul(number, count))`` split form + pass_4b_view produces — return its matching ``_ClusterAxis``. + Otherwise return ``None``. + + This is the recogniser for "this expression IS one lane axis"; + ``_collapse_lane_axis`` uses it both for whole-axis indices and as + the kernel of the ``lane ± const`` head-offset case. + """ + if isinstance(idx, VarRef): + for ax in axes: + if idx == ax.original_var: + return ax + return None + if isinstance(idx, dict) and idx.get("op") == "add": + args = idx.get("args", []) + if len(args) == 2 and isinstance(args[0], VarRef): + phase = args[0] + inner = args[1] + if isinstance(inner, dict) and inner.get("op") == "mul": + m_args = inner.get("args", []) + if (len(m_args) == 2 and isinstance(m_args[0], VarRef) + and isinstance(m_args[1], int)): + number, count = m_args[0], m_args[1] + for ax in axes: + if (phase == ax.phase_var + and number == ax.number_var + and count == ax.count): + return ax + return None + + +def _ranged_slice_for_axis(ax: "_ClusterAxis", extra_offset=None): + """Build ``ranged_slice(mul(number, count) [+ extra_offset], count)`` + for one cluster axis. ``extra_offset`` (a constant head offset, or + None) is folded into the slice's START so the ranged_slice stays at + the TOP of the index expression — downstream ``_ref_extents`` and + ``_render_idx_as_primexpr`` only recognise a top-level ranged_slice. + """ + base = {"op": "mul", "args": [ax.number_var, ax.count]} + if extra_offset is not None: + base = {"op": "add", "args": [base, extra_offset]} + return {"op": "ranged_slice", "args": [base, ax.count]} + + +def _collapse_lane_axis(idx, axes: List[_ClusterAxis]): + """Fold a per-lane index expression back into a cluster-wide + ``ranged_slice``. + + pass_4b_view turns the original lane var (e.g. ``by``) into + ``add(phase_var, mul(number_var, count))`` — that's the correct + per-lane expression for op kinds that fire once per lane (Reduce, + Broadcast Elementwise). For multi-lane ops (Async-wrapped DMA / + btmm / pure Elementwise) the cluster fires the op exactly once + across all lanes, so the same axis position should describe a span + of ``count`` consecutive lane indices starting at the cluster's + base — encoded as ``ranged_slice(mul(number_var, count), count)``. + + Three cases handled: + 1. The bare original lane var, or the ``add(phase, mul(number, + count))`` split form — collapse straight to a ranged_slice. + 2. ``lane_composite ± const`` (a head-offset write, e.g. + ``Y_hbm[..., by + 8, ...]``) — the constant is folded INTO + the ranged_slice's start so the ranged_slice stays top-level + and keeps ``extent == count``. Without this the constant add + buries the ranged_slice one level down and ``_ref_extents`` + falls back to extent 1, writing only one lane's worth. + 3. Anything else — recurse into children (the lane composite may + live deep inside a compound, e.g. ``mul(by_expr, stride)``). + """ + # Case 1: idx is itself a lane axis. + ax = _match_lane_composite(idx, axes) + if ax is not None: + return _ranged_slice_for_axis(ax) + + if not isinstance(idx, dict): + return idx + + # Case 2: ``lane_composite + const`` or ``const + lane_composite``. + # (Subtraction of a const is normalised by upstream IR builders to + # an add of a negative IntImm, so matching ``add`` covers both.) + if idx.get("op") == "add": + args = idx.get("args", []) + if len(args) == 2: + a0, a1 = args + for lane_arg, other in ((a0, a1), (a1, a0)): + lane_ax = _match_lane_composite(lane_arg, axes) + if lane_ax is not None and isinstance(other, int): + return _ranged_slice_for_axis(lane_ax, extra_offset=other) + + # Case 3: recurse into children. + return { + "op": idx.get("op"), + "args": [_collapse_lane_axis(a, axes) for a in idx.get("args", [])], + } + + +def _collapse_ref(ref: BufferRef, axes: List[_ClusterAxis]) -> BufferRef: + """Apply ``_collapse_lane_axis`` to every index of a user-declared + global ref (``global`` HBM or ``global.vram`` / ``global.mram`` / + ``global.fpram`` on-chip caches). + + Non-global refs already had their lane axis baked into the + prepended phase by pass_4b_view, so we only rewrite globals. For + on-chip globals (``global.*``), the kernel author still indexes + them with the un-split logical lane var ``by`` — fuse_pass widens + that to a cluster-wide ``ranged_slice`` so the multi-lane sync + wrap reads the full cluster's chunk (by_phase drops out).""" + if not (ref.buffer.scope == "global" + or ref.buffer.scope.startswith("global.")): + return ref + return BufferRef( + buffer=ref.buffer, + indices=[_collapse_lane_axis(i, axes) for i in ref.indices], + view_perm=ref.view_perm, + ) + + +def _collapse_src(src, axes: List[_ClusterAxis]): + if isinstance(src, Broadcast): + return Broadcast( + src=_collapse_ref(src.src, axes), + broadcast_dims=list(src.broadcast_dims), + ) + return _collapse_ref(src, axes) + + +def _collapse_lane_in_op(op, axes: List[_ClusterAxis]): + """Rebuild ``op`` with HBM refs widened to cluster-wide ranged + slices. Only HBM refs are touched; on-chip refs are unchanged. + + axes (per-axis info on each operand) are passed through verbatim + — collapsing HBM lane slices only changes ``indices``/extents at + the *value* level, not the axis-role tagging. axis-aware passes + that need the new extent should consult ``ref.buffer.shape``. + """ + if isinstance(op, Dma): + return Dma( + src=_collapse_ref(op.src, axes), + dst=_collapse_ref(op.dst, axes), + src_axes=list(op.src_axes), + dst_axes=list(op.dst_axes), + marker=op.marker, + can_async=op.can_async, + ) + if isinstance(op, Gemm): + return Gemm( + a=_collapse_ref(op.a, axes), + b=_collapse_ref(op.b, axes), + c=_collapse_ref(op.c, axes), + transpose_a=op.transpose_a, + transpose_b=op.transpose_b, + kind=op.kind, + a_axes=list(op.a_axes), + b_axes=list(op.b_axes), + c_axes=list(op.c_axes), + marker=op.marker, + can_async=op.can_async, + ) + if isinstance(op, Elementwise): + return Elementwise( + dst=_collapse_ref(op.dst, axes), + srcs=[_collapse_src(s, axes) for s in op.srcs], + op=op.op, + dst_axes=list(op.dst_axes), + src_axes=[list(s) for s in op.src_axes], + axis=op.axis, + size=op.size, + marker=op.marker, + can_async=op.can_async, + ) + return op + + +def _fuse_async(stmt: Async, cluster_stack: List[_ClusterAxis]) -> Stmt: + """One Async → one MultiLaneOp. Async body must hold exactly one + op (the strict one-async-one-op invariant from pass_4). + + During fusion we also collapse the per-lane index expressions + inside the inner op's HBM refs back into cluster-wide ranged + slices, since the resulting MultiLaneOp fires once across all + lanes (not once per lane). + """ + if not cluster_stack: + raise FuseError( + f"Async #{stmt.scope_id} found outside any cluster — " + f"this shouldn't happen if pass_4_async ran first" + ) + if len(stmt.body) != 1: + raise FuseError( + f"Async #{stmt.scope_id} must hold exactly one op " + f"(got {len(stmt.body)}); pass_4 enforces one-async-one-op" + ) + inner = stmt.body[0] + if isinstance(inner, (Async, MultiLaneOp, ParallelAxis, For)): + raise FuseError( + f"Async #{stmt.scope_id} body must be a leaf op, got " + f"{type(inner).__name__}" + ) + inner = _collapse_lane_in_op(inner, cluster_stack) + axis_names = [ax.phase_name for ax in cluster_stack] + axis_vars = [ax.phase_var for ax in cluster_stack] + return MultiLaneOp( + inner=inner, + cluster_axis_names=axis_names, + cluster_axis_vars=axis_vars, + dim_map=_build_dim_map(inner, axis_names), + ) + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: MidFunc) -> MidFunc: + """Collapse Async regions into MultiLaneOp nodes.""" + if should_skip_cluster(func): + return func + _NUMBER_VAR_BY_NAME.clear() + return MidFunc( + name=func.name, + params=list(func.params), + allocs=list(func.allocs), + body=[_walk(s, cluster_stack=[]) for s in func.body], + lane_axes=list(func.lane_axes), + cluster_counts=list(func.cluster_counts), + attrs=dict(func.attrs), + ) + + +__all__ = ["run", "FuseError"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/infer_lane_axis.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/infer_lane_axis.py new file mode 100644 index 0000000..77e1524 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/infer_lane_axis.py @@ -0,0 +1,184 @@ +"""pass_0_infer_lane_axis: pick the lane axis from a raw PrimFunc. + +Why this pass exists +-------------------- + +Kernel authors used to declare the lane axis explicitly via +``T.func_attr({"plena.lane_axis": "by"})``. That's annoying boilerplate +the compiler can deduce by looking at how each grid var is *used* in +the kernel body — not at its extent. + +The judgment principle: a lane axis is a blockIdx grid var that +appears as a **bare** index into some buffer access (e.g. ``T.copy( +Q_hbm[0, q_block*rows, by, 0], Q_sh)`` — ``by`` sits at index slot +2 directly, naked). A grid var that only appears wrapped in +arithmetic (``q_block * rows`` for an offset computation) is acting +as an outer control loop, not as a per-lane indexing dim. + +Algorithm: + + * Walk every ``AttrStmt(thread_extent, IterVar(thread_tag="blockIdx.*"))`` + to enumerate grid vars + their extents. + * Walk every ``BufferLoad`` and every ``tl.tileop.region`` extern + call. For each grid var, check if it appears as a **bare** + index slot somewhere (``BufferLoad.indices[i] is the same Var``, + not a compound expression containing it). + * Lane candidates = grid vars that appear bare AT LEAST ONCE, + AND whose extent is divisible by LANE. + * If the user manually set ``plena.lane_axis``, respect it. + * If 0 candidates → leave func.attrs as-is; cluster_guard will + skip the cluster pipeline. + * If 1 candidate → pick it. + * If 2+ candidates → ambiguous; raise InferLaneAxisError and ask + the kernel author to disambiguate via ``T.func_attr``. + +Runs BEFORE pass_1_fold — it operates on raw TIR. +""" + +from __future__ import annotations + +from typing import List, Optional, Tuple + +import tvm +from tvm import tir + + +# Same constant as cluster_guard.MLEN; we re-derive LANE here so this +# pass doesn't have to depend on the mid_ir scope vocabulary. +_DEFAULT_LANE = 4 + +_LANE_AXIS_FUNC_ATTR = "plena.lane_axis" + + +class InferLaneAxisError(RuntimeError): + pass + + +# --------------------------------------------------------------------------- +# Candidate collection +# --------------------------------------------------------------------------- + + +def _collect_block_idx_bindings(func: tir.PrimFunc + ) -> List[Tuple[str, int]]: + """Walk the body, collect ``(var_name, extent)`` for every + ``thread_extent`` AttrStmt whose IterVar is bound to ``blockIdx.*`` + and has a static integer extent.""" + out: List[Tuple[str, int]] = [] + + def visit(stmt) -> None: + if stmt is None: + return + if isinstance(stmt, tir.AttrStmt): + if (stmt.attr_key == "thread_extent" + and isinstance(stmt.node, tir.IterVar) + and stmt.node.thread_tag is not None + and stmt.node.thread_tag.startswith("blockIdx") + and isinstance(stmt.value, tir.IntImm)): + out.append((stmt.node.var.name, int(stmt.value.value))) + visit(stmt.body) + return + if isinstance(stmt, tir.SeqStmt): + for c in stmt.seq: + visit(c) + return + if isinstance(stmt, tir.BlockRealize): + visit(stmt.block.body) + if stmt.block.init is not None: + visit(stmt.block.init) + return + if isinstance(stmt, (tir.For, tir.LetStmt, tir.Allocate)): + visit(stmt.body) + return + if isinstance(stmt, tir.IfThenElse): + visit(stmt.then_case) + if stmt.else_case is not None: + visit(stmt.else_case) + return + + visit(func.body) + return out + + +def _existing_lane_axis(func: tir.PrimFunc) -> Optional[str]: + if func.attrs is None: + return None + if _LANE_AXIS_FUNC_ATTR not in func.attrs: + return None + raw = func.attrs[_LANE_AXIS_FUNC_ATTR] + if isinstance(raw, tir.StringImm): + return str(raw.value) + return str(raw) + + +# --------------------------------------------------------------------------- +# Bare-index detection +# --------------------------------------------------------------------------- + + +def _collect_bare_index_var_names(func: tir.PrimFunc) -> set: + """Return the set of var names that appear as a *bare* index slot + in some BufferLoad anywhere in the body. + + "Bare" means: ``BufferLoad.indices[i]`` is exactly a ``tir.Var``, + not a compound expression. ``q_block * 64`` doesn't qualify; + ``by`` does. + """ + found: set = set() + from tvm.tir import stmt_functor + + def visit(node) -> None: + if isinstance(node, tir.BufferLoad): + for idx in node.indices: + if isinstance(idx, tir.Var): + found.add(idx.name) + # Region-extern calls (tl.tileop.region(BufferLoad, ...)) are + # already covered by the BufferLoad inside their first arg — + # post_order_visit walks down into args. + + stmt_functor.post_order_visit(func.body, visit) + return found + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: tir.PrimFunc, lane: int = _DEFAULT_LANE) -> tir.PrimFunc: + """Return ``func`` with ``plena.lane_axis`` set on attrs. + + Picks the unique grid var that: + * is bound to ``blockIdx.*`` with a static integer extent, + * has extent divisible by ``lane``, + * appears as a bare index slot in some BufferLoad. + + Manual override (``T.func_attr({"plena.lane_axis": ...})``) wins + over the auto-pick. Zero candidates → no attr (cluster_guard + later skips). Multiple candidates → ambiguous, raises. + """ + if _existing_lane_axis(func) is not None: + return func + + grid_bindings = _collect_block_idx_bindings(func) + bare_names = _collect_bare_index_var_names(func) + + candidates = [ + (name, ext) for (name, ext) in grid_bindings + if ext % lane == 0 and name in bare_names + ] + + if not candidates: + return func + if len(candidates) == 1: + return func.with_attr(_LANE_AXIS_FUNC_ATTR, candidates[0][0]) + + raise InferLaneAxisError( + f"ambiguous lane axis: more than one grid var qualifies as a " + f"lane candidate ({[n for n, _ in candidates]!r}). Disambiguate " + f"by writing T.func_attr({{'plena.lane_axis': ''}}) in " + f"the kernel." + ) + + +__all__ = ["run", "InferLaneAxisError"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/mark.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/mark.py new file mode 100644 index 0000000..98fb7c1 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/mark.py @@ -0,0 +1,229 @@ +"""pass_2_mark: tag op sites with their lane-fusion role marker. + +Why this pass exists +-------------------- + +After ``fold``, the body is a clean sequence of mid_ir op nodes +(``Dma`` / ``Gemm`` / ``Elementwise`` / ``Reduce`` / ``RawStore``) plus +structure (``For`` for any preserved loops). But there's no per-op +hint about *which* of those sites care about lane fusion. + +This pass walks the function and sets ``.marker`` on each op according +to a small fixed table: + + Dma → Marker.DMA + Gemm(kind="btmm") → Marker.BTMM + Gemm(kind="overwrite" / other) → no marker (per-head, runs inside the lane loop) + Elementwise → Marker.LANE_OP + Reduce → Marker.LANE_OP + RawStore → no marker (pass_3/4 leave it alone) + + Broadcast → no marker — it's not a top-level + stmt. Broadcast appears only as + an entry inside ``Elementwise.srcs`` + (e.g. ``S[r,c] - M_CURR[r]`` folds + to ``Elementwise(S, [S, + Broadcast(M_CURR, dims=[1])], SUB)``). + The enclosing Elementwise already + carries Marker.LANE_OP, which + covers the whole expression incl. + its broadcast srcs. + +That's the entire pass. It does NOT decide which buffers are lane-aware +(pass_3 does that). It does NOT split the grid (pass_3 does that +either). It does NOT wrap anything in Async (pass_4 does that). It +just sets a per-op flag the later passes consult. + +Why no LANE_OP exclusion rules +------------------------------ + +It's tempting to skip Elementwise that operates on already-known +"non-lane" buffers (e.g. an FP scalar update on a buffer the kernel +will keep at full extent). But: + + * we don't know "non-lane" yet — that's pass_3's call + * conservative marking is safe — pass_4 only wraps marked ops in + Async, but the wrapping is harmless if the underlying buffers + happen to not need cluster expansion + * the wrapping decision lives in pass_4 anyway, where it has access + to pass_3's grown-buffer info + +So mark stays dumb-and-uniform. + +Output +------ + +Returns a new MidFunc with the same shape, but with ``.marker`` set on +every Dma / Gemm[btmm] / Elementwise / Reduce node. The pass is +idempotent — calling it twice yields the same markers. +""" + +from __future__ import annotations + +from typing import List + +from ..ir import ( + Dma, Gemm, Elementwise, Reduce, RawStore, For, Async, MultiLaneOp, + Broadcast, Marker, MidFunc, Stmt, + ParallelAxis, +) + + +class MarkError(RuntimeError): + pass + + +# --------------------------------------------------------------------------- +# Per-node marker assignment +# --------------------------------------------------------------------------- + + +def _mark_dma(op: Dma) -> Dma: + # DMA is always a single multi-lane HW instruction. + return Dma( + src=op.src, dst=op.dst, + src_axes=list(op.src_axes), dst_axes=list(op.dst_axes), + marker=Marker.DMA, can_async=True, + ) + + +def _mark_gemm(op: Gemm) -> Gemm: + # btmm: one multi-lane M_BTMM instruction → async. + # overwrite (per-head): one matmul per lane, runs inside the lane + # loop → not async. + is_btmm = op.kind == "btmm" + return Gemm( + a=op.a, b=op.b, c=op.c, + transpose_a=op.transpose_a, transpose_b=op.transpose_b, + kind=op.kind, + a_axes=list(op.a_axes), + b_axes=list(op.b_axes), + c_axes=list(op.c_axes), + marker=Marker.BTMM if is_btmm else None, + can_async=is_btmm, + ) + + +def _has_broadcast_src(op: Elementwise) -> bool: + return any(isinstance(s, Broadcast) for s in op.srcs) + + +def _mark_elementwise(op: Elementwise) -> Elementwise: + # ``can_async`` marks ops that lower to a single heavyweight HW + # instruction the control thread can fire and walk past — DMA, + # systolic matmul, and vector-engine ops over a full tile. Only + # those genuinely benefit from async dispatch; tagging every + # scalar / per-row op as async just clutters the IR. + # + # Eligible elementwise lowerings: + # * VRAM dst, no broadcast src → tile-wide ``V_*_VV`` / + # ``V_EXP_V`` / etc. — async-eligible. + # Excluded: + # * FPRAM-scalar dst (``fragment`` allocated at rank 1) → lowers + # to a sequence of ``S_*_FP`` per slot. Per-row, fp scalar + # instruction, not async-eligible. + # * Any elementwise with a Broadcast src (``S[r,c] - M_CURR[r]``) + # → lowers to ``row_*_fp`` per row, also not async-eligible. + is_fpram_dst = ( + op.dst.buffer.scope in ("fragment", "local.fragment", "fragment.fpram") + and len(op.dst.buffer.shape) == 1 + ) + can_async = ( + not _has_broadcast_src(op) + and not is_fpram_dst + ) + return Elementwise( + dst=op.dst, srcs=op.srcs, op=op.op, axis=op.axis, size=op.size, + dst_axes=list(op.dst_axes), + src_axes=[list(s) for s in op.src_axes], + marker=Marker.LANE_OP, + can_async=can_async, + ) + + +def _mark_reduce(op: Reduce) -> Reduce: + # Reduce on PLENA is ``row_reduce_max_at`` / ``row_reduce_sum_at`` — + # per-row, never async. + return Reduce( + dst=op.dst, src=op.src, op=op.op, axis=op.axis, + dst_axes=list(op.dst_axes), + src_axes=list(op.src_axes), + marker=Marker.LANE_OP, + can_async=False, + ) + + +# --------------------------------------------------------------------------- +# Walker +# --------------------------------------------------------------------------- + + +def _walk(stmt: Stmt) -> Stmt: + if isinstance(stmt, Dma): + return _mark_dma(stmt) + if isinstance(stmt, Gemm): + return _mark_gemm(stmt) + if isinstance(stmt, Elementwise): + return _mark_elementwise(stmt) + if isinstance(stmt, Reduce): + return _mark_reduce(stmt) + if isinstance(stmt, RawStore): + return stmt # pass-through; never gets a marker + if isinstance(stmt, For): + return For( + loop_var=stmt.loop_var, + extent=stmt.extent, + body=[_walk(s) for s in stmt.body], + kind=stmt.kind, + loop_var_var=stmt.loop_var_var, + ) + if isinstance(stmt, ParallelAxis): + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=[_walk(s) for s in stmt.body], + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + if isinstance(stmt, Async): + # mark runs before pass_4_async, so we don't expect Async here. + # But if a caller runs mark twice (idempotency), preserve the + # wrapper and re-mark its body. + return Async(body=[_walk(s) for s in stmt.body], scope_id=stmt.scope_id) + if isinstance(stmt, MultiLaneOp): + # Likewise: shouldn't show up before pass_6, but be defensive. + return MultiLaneOp( + inner=_walk(stmt.inner), + cluster_axis_names=list(stmt.cluster_axis_names), + cluster_axis_vars=list(stmt.cluster_axis_vars), + dim_map=dict(stmt.dim_map), + ) + raise MarkError(f"unhandled stmt type {type(stmt).__name__}") + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: MidFunc) -> MidFunc: + """Set ``.marker`` on every Dma / Gemm[btmm] / Elementwise / + Reduce in ``func.body``. Returns a new MidFunc; original is not + mutated.""" + new_body: List[Stmt] = [_walk(s) for s in func.body] + return MidFunc( + name=func.name, + params=list(func.params), + allocs=list(func.allocs), + body=new_body, + lane_axes=list(func.lane_axes), + cluster_counts=list(func.cluster_counts), + attrs=dict(func.attrs), + ) + + +__all__ = ["run", "MarkError"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/split.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/split.py new file mode 100644 index 0000000..f40b404 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/split.py @@ -0,0 +1,435 @@ +"""pass_3_split: split lane-axis blockIdx into (number, phase); grow +non-global buffers by one cluster outer dim. + +What this pass does +------------------- + +Two structural changes, both purely additive: + + 1. **Split the lane-axis blockIdx ParallelAxis** into a (number, phase) + pair, both still parallel axes (NOT for loops): + + parallel by in 0..head_count [blockIdx.y, BLOCK_IDX] + body + ↓ + parallel by_number in 0..head_count/cluster_count [blockIdx.y, BLOCK_IDX] + parallel by_phase in 0..cluster_count [CLUSTER] + body + + ``by_number`` keeps the BLOCK_IDX kind + ``blockIdx.*`` thread_tag. + The HW grid dim shrinks but blockIdx stays a HW grid axis. + + ``by_phase`` is the new CLUSTER axis — ``cluster_count`` lanes + execute the body in lockstep. Mid-IR keeps multi-thread semantics + here; pass_8_to_plena is the only place we ever flatten parallel + to a serial for. + + 2. **Grow every non-global buffer by one outermost dim** of size + ``cluster_count``: + + BufferDef(name="Q_sh", shape=[64, 16], scope="shared") + ↓ + BufferDef(name="Q_sh", shape=[4, 64, 16], scope="shared") + + Every BufferDef referenced from this point forward — params with + ``scope != "global"`` (rare), allocs, and the ``buffer`` field of + every BufferRef — is replaced by the grown version. + + The pass does **not** touch ``BufferRef.indices``: an existing + ``Q_sh[r, c]`` (rank 2) now references a rank-3 buffer with + mismatched index rank. That's intentional. ``pass_4_async`` / + ``pass_5_loop`` introduce ``by_phase`` and prepend it to indices + when wrapping ops in Async regions. + +What this pass DOES NOT do +-------------------------- + + * No async wrapping (pass_4) + * No cluster loop introduction inside the body (pass_5) + * No layout permute / reshape (pass_7) + * No Stmt-rewriting except for the outer-loop split + +Inputs / outputs +---------------- + +Reads ``MidFunc.lane_axes`` (set by fold from the kernel's +``T.func_attr({"plena.lane_axis": ...})``). Writes +``MidFunc.cluster_counts`` (one entry per lane axis; defaults to LANE). + +LANE defaults to 4 (= MLEN/btmm_hlen for the current target). Callable +override via the ``cluster_counts`` argument to ``run`` for tests / +non-default targets. + +Multi-axis lane fusion is supported by passing multiple entries in +``lane_axes`` — each gets its own split + matching cluster_count entry. +The split happens outside-in: for ``lane_axes=["q_block", "by"]`` the +final body shape is ``For(q_block_number) → For(q_block_phase) → +For(by_number) → For(by_phase) → ...``. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List, Optional + +from tvm import tir as _tir + +from ..cluster_guard import should_skip_cluster +from ..ir import ( + BufferDef, BufferRef, Slice, VarRef, + Dma, Gemm, Elementwise, Broadcast, Reduce, RawStore, + For, Async, MultiLaneOp, + ParallelAxis, ParallelKind, + MidFunc, Stmt, +) + + +# MLEN / HLEN for the active target — read from plena_settings.toml +# (the single source of truth shared with the simulator). +from ....plena_settings import load_sizes as _load_sizes + +_DEFAULT_LANE = _load_sizes().hardware_lane_count + + +class SplitError(RuntimeError): + pass + + +# --------------------------------------------------------------------------- +# Buffer growth +# --------------------------------------------------------------------------- + + +def _grow_buffer(buf: BufferDef, cluster: int) -> BufferDef: + """Return a new BufferDef with shape (cluster,) + buf.shape. + + The prepended cluster dim is marked with ``cluster_dim=0`` so any + later permutation (view / burn_view) can track which axis is the + lane axis without re-deriving it from shape values. + + Preserves the kernel-author's logical rank in the scope string when + it carries semantics: a ``fragment`` allocated with a 1D shape is + per-lane scalar state (M_OLD, P_SUM etc.) and must end up in FPRAM + even after the lane dim is prepended (post-grow rank is 2). We + rename its scope to ``"fragment.fpram"`` so ``to_plena._map_scope`` + can route it without having to re-derive the original rank. + """ + new_scope = buf.scope + if buf.scope in ("fragment", "local.fragment") and len(buf.shape) == 1: + new_scope = "fragment.fpram" + return BufferDef( + name=buf.name, + shape=[cluster] + list(buf.shape), + dtype=buf.dtype, + scope=new_scope, + cluster_dim=0, + ) + + +def _is_lane_aware_buffer(buf: BufferDef) -> bool: + """Anything not user-declared global gets cluster-grown. ``"global"`` + is HBM and ``"global.vram"`` / ``"global.mram"`` / ``"global.fpram"`` + are on-chip user-managed caches — both keep their as-written shape.""" + return not (buf.scope == "global" or buf.scope.startswith("global.")) + + +# --------------------------------------------------------------------------- +# Buffer remapping helpers +# --------------------------------------------------------------------------- + + +@dataclass +class _Ctx: + cluster_counts: List[int] + lane_axes: List[str] + # name -> grown BufferDef. Only contains entries for buffers that + # were grown (lane-aware). Other buffers are left alone. + grown: Dict[str, BufferDef] + + +def _swap_buf(buf: BufferDef, ctx: _Ctx) -> BufferDef: + return ctx.grown.get(buf.name, buf) + + +def _swap_ref(ref: BufferRef, ctx: _Ctx) -> BufferRef: + """Replace the Buffer in a BufferRef. Indices are NOT touched — + a rank-2 ref into a now-rank-3 buffer is intentional; pass_4 + fixes the rank mismatch by prepending ``by_phase`` when it wraps + an op in Async.""" + new_buf = _swap_buf(ref.buffer, ctx) + if new_buf is ref.buffer: + return ref + return BufferRef(buffer=new_buf, indices=list(ref.indices)) + + +def _swap_src(src, ctx: _Ctx): + if isinstance(src, Broadcast): + return Broadcast( + src=_swap_ref(src.src, ctx), + broadcast_dims=list(src.broadcast_dims), + ) + return _swap_ref(src, ctx) + + +# --------------------------------------------------------------------------- +# Stmt walker — only swaps BufferRefs; structure unchanged +# --------------------------------------------------------------------------- + + +def _walk_stmt(stmt: Stmt, ctx: _Ctx) -> Stmt: + # NOTE: split rewrites BufferDef shapes (prepends cluster dim) but + # intentionally leaves BufferRef.indices alone — async_wrap is the + # pass that adds the corresponding ``by_phase`` index later. Since + # ``axes`` are aligned to ``indices`` (one entry per index), they + # transparently survive split. We deep-copy them so downstream + # mutations don't alias the originals. + if isinstance(stmt, Dma): + return Dma( + src=_swap_ref(stmt.src, ctx), + dst=_swap_ref(stmt.dst, ctx), + src_axes=list(stmt.src_axes), + dst_axes=list(stmt.dst_axes), + marker=stmt.marker, + can_async=stmt.can_async, + ) + if isinstance(stmt, Gemm): + return Gemm( + a=_swap_ref(stmt.a, ctx), + b=_swap_ref(stmt.b, ctx), + c=_swap_ref(stmt.c, ctx), + transpose_a=stmt.transpose_a, + transpose_b=stmt.transpose_b, + kind=stmt.kind, + a_axes=list(stmt.a_axes), + b_axes=list(stmt.b_axes), + c_axes=list(stmt.c_axes), + marker=stmt.marker, + can_async=stmt.can_async, + ) + if isinstance(stmt, Elementwise): + return Elementwise( + dst=_swap_ref(stmt.dst, ctx), + srcs=[_swap_src(s, ctx) for s in stmt.srcs], + op=stmt.op, + dst_axes=list(stmt.dst_axes), + src_axes=[list(s) for s in stmt.src_axes], + axis=stmt.axis, + size=stmt.size, + marker=stmt.marker, + can_async=stmt.can_async, + ) + if isinstance(stmt, Reduce): + return Reduce( + dst=_swap_ref(stmt.dst, ctx), + src=_swap_ref(stmt.src, ctx), + op=stmt.op, + axis=stmt.axis, + dst_axes=list(stmt.dst_axes), + src_axes=list(stmt.src_axes), + marker=stmt.marker, + can_async=stmt.can_async, + ) + if isinstance(stmt, RawStore): + return RawStore( + dst=_swap_ref(stmt.dst, ctx), + value=stmt.value, + ) + if isinstance(stmt, ParallelAxis): + return _split_or_walk_parallel(stmt, ctx) + if isinstance(stmt, For): + return For( + loop_var=stmt.loop_var, + extent=stmt.extent, + body=[_walk_stmt(s, ctx) for s in stmt.body], + kind=stmt.kind, + loop_var_var=stmt.loop_var_var, + ) + if isinstance(stmt, Async): + return Async( + body=[_walk_stmt(s, ctx) for s in stmt.body], + scope_id=stmt.scope_id, + ) + if isinstance(stmt, MultiLaneOp): + return MultiLaneOp( + inner=_walk_stmt(stmt.inner, ctx), + cluster_axis_names=list(stmt.cluster_axis_names), + cluster_axis_vars=list(stmt.cluster_axis_vars), + dim_map=dict(stmt.dim_map), + ) + raise SplitError(f"unhandled stmt type {type(stmt).__name__}") + + +def _split_or_walk_parallel(stmt: ParallelAxis, ctx: _Ctx) -> Stmt: + """If ``stmt`` is a (BLOCK_IDX | LOGICAL_GRID) axis whose name is in + ``lane_axes``, split it into (number axis, CLUSTER phase axis). + The number axis preserves the source kind (BLOCK_IDX stays + BLOCK_IDX, LOGICAL_GRID stays LOGICAL_GRID); the phase axis + becomes CLUSTER and back-references the number axis name via + ``parent_grid_axis_name``. + + Identity channel: we mint *fresh* ``tir.Var`` objects for the phase + and number axes (no existing TIR var corresponds to them — they + didn't exist pre-split). ``original_axis_var`` on both new axes + points at the pre-split user var's :class:`VarRef`, taken from + ``stmt.axis_var`` so consumers can recognise references to the + original lane variable by identity. + """ + splittable_kinds = (ParallelKind.BLOCK_IDX, ParallelKind.LOGICAL_GRID) + if stmt.kind in splittable_kinds and stmt.axis_name in ctx.lane_axes: + idx = ctx.lane_axes.index(stmt.axis_name) + cluster = ctx.cluster_counts[idx] + if stmt.extent % cluster != 0: + raise SplitError( + f"lane axis {stmt.axis_name!r} extent={stmt.extent} is " + f"not a multiple of cluster_count={cluster}" + ) + outer_extent = stmt.extent // cluster + inner_body = [_walk_stmt(s, ctx) for s in stmt.body] + original_name = stmt.axis_name + number_name = f"{original_name}_number" + phase_name = f"{original_name}_phase" + # Identity propagation: the user-written lane var (e.g. ``by``) + # came from fold and lives on ``stmt.axis_var``. Both halves of + # the split point back at it via ``original_axis_var`` so + # downstream identity checks can recognise references to the + # pre-split var even after we've replaced the axis with two new + # ones. + original_var_ref = stmt.axis_var + if original_var_ref is None: + raise SplitError( + f"split: lane-axis ParallelAxis {original_name!r} has no " + f"axis_var; fold must have populated it" + ) + # Mint fresh tir.Vars for the phase / number axes; they didn't + # exist in the input TIR. + phase_var_obj = _tir.Var(phase_name, "int32") + number_var_obj = _tir.Var(number_name, "int32") + phase_axis = ParallelAxis( + axis_name=phase_name, + extent=cluster, + body=inner_body, + kind=ParallelKind.CLUSTER, + thread_tag=None, + parent_grid_axis_name=number_name, + original_axis_name=original_name, + axis_var=VarRef(phase_var_obj), + original_axis_var=original_var_ref, + ) + number_axis = ParallelAxis( + axis_name=number_name, + extent=outer_extent, + body=[phase_axis], + kind=stmt.kind, # BLOCK_IDX or LOGICAL_GRID + thread_tag=stmt.thread_tag, # only set for BLOCK_IDX + parent_grid_axis_name=None, + original_axis_name=original_name, + axis_var=VarRef(number_var_obj), + original_axis_var=original_var_ref, + ) + return number_axis + + # Not a lane axis: pass through, recurse into body. + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=[_walk_stmt(s, ctx) for s in stmt.body], + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: MidFunc, + cluster_counts: Optional[List[int]] = None) -> MidFunc: + """Split each declared lane-axis blockIdx into (number, phase) and + grow every non-global buffer by one outer cluster dim. + + ``cluster_counts`` defaults to ``[_DEFAULT_LANE] * len(lane_axes)``. + Pass an explicit list to override (e.g. for non-default targets or + multi-axis cluster fusion with different per-axis sizes). + + No-op if ``should_skip_cluster(func)`` (kernel didn't declare any + lane axis, OR every on-chip buffer's last dim already covers a + full HW vector). + """ + if should_skip_cluster(func): + return func + if cluster_counts is None: + cluster_counts = [_DEFAULT_LANE] * len(func.lane_axes) + if len(cluster_counts) != len(func.lane_axes): + raise SplitError( + f"cluster_counts has {len(cluster_counts)} entries but " + f"lane_axes has {len(func.lane_axes)}; must match" + ) + + # Build the grown-buffer map up front. + # NOTE: with multiple lane axes we currently apply the **product** + # of all cluster sizes as a single outermost dim. This matches + # the only multi-axis case we've seen on paper. If a kernel needs + # a different layout (e.g. separate dims per axis) the BufferDef + # growth would change shape; pass_7_perm decides physical placement + # afterwards. + cluster_total = 1 + for c in cluster_counts: + cluster_total *= c + + grown: Dict[str, BufferDef] = {} + for buf in list(func.params) + list(func.allocs): + if _is_lane_aware_buffer(buf) and buf.name not in grown: + grown[buf.name] = _grow_buffer(buf, cluster_total) + + # Sanity check: every name in lane_axes must correspond to a + # ParallelAxis in the body. Without this, a typo silently leaves + # the body unchanged — every downstream "this axis is the cluster + # axis" assumption then fails far away from the source. + found_axis_names: set = set() + + def _collect_axis_names(stmt): + if isinstance(stmt, ParallelAxis): + found_axis_names.add(stmt.axis_name) + for c in stmt.body: + _collect_axis_names(c) + elif hasattr(stmt, "body") and isinstance(getattr(stmt, "body"), list): + for c in stmt.body: + _collect_axis_names(c) + for s in func.body: + _collect_axis_names(s) + missing = [n for n in func.lane_axes if n not in found_axis_names] + if missing: + raise SplitError( + f"lane_axes {missing!r} not found among the kernel's ParallelAxis " + f"names {sorted(found_axis_names)!r}. Did the kernel actually bind " + f"these axes via T.Kernel(...)? Mismatch would otherwise silently " + f"skip cluster split." + ) + + ctx = _Ctx( + cluster_counts=list(cluster_counts), + lane_axes=list(func.lane_axes), + grown=grown, + ) + + new_body: List[Stmt] = [_walk_stmt(s, ctx) for s in func.body] + new_params = [_swap_buf(b, ctx) for b in func.params] + new_allocs = [_swap_buf(b, ctx) for b in func.allocs] + + return MidFunc( + name=func.name, + params=new_params, + allocs=new_allocs, + body=new_body, + lane_axes=list(func.lane_axes), + cluster_counts=list(cluster_counts), + attrs=dict(func.attrs), + ) + + +__all__ = ["run", "SplitError"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/to_plena.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/to_plena.py new file mode 100644 index 0000000..c177009 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/to_plena.py @@ -0,0 +1,2823 @@ +"""pass_6_to_plena: lower MidFunc → HLIRModule. + +This is the only pass that exits the mid-IR domain. Everything before +it stays in mid_ir-native dataclasses; this one walks the (now-baked) +mid_ir and produces the HLIR Buffer/Op/HLIRModule that the legacy +backend (AddressAllocationPass + ISAEmitterPass) consumes unchanged. + +What gets lowered +----------------- + +* Buffers + BufferDef.scope → HLIR Buffer.scope + "global" → _scope.HBM + "shared" / "shared.dyn" → _scope.VRAM + "fragment" / "local.fragment" → _scope.VRAM (rank ≥ 2) + or _scope.FPRAM (rank == 1) + "global." → strip prefix + already a physical scope → identity + BufferDef.shape / .dtype → identity copy + BufferDef.name → identity + +* Op nodes + MultiLaneOp(inner=Dma) → Op(kind="dma_h2v" / "dma_v2h" + / "dma_h2m", + buffer_args=[src_or_slice, dst_or_slice], + scalar_args=[lane_count]) + MultiLaneOp(inner=Gemm[btmm]) → Op(kind="btmm", + buffer_args=[a, b, c], + scalar_args=[group_heads]) + MultiLaneOp(inner=Elementwise pure) → Op(kind="v_add" / "v_sub" / + "v_mul" / "v_exp" / + "v_zero" / ...) + (1-D vector op; multi-row + ops are wrapped in an + explicit for-row in HLIR) + Gemm(kind="overwrite") [bare in cluster] → + for-loop over lane that wraps + one Op(kind="matmul"|"mv") per iter + Reduce [bare] → for lane: for row: Op("row_reduce__at") + Elementwise(broadcast) [bare] → for lane: for row: Op("row__fp_at") + ParallelAxis(BLOCK_IDX|LOGICAL_GRID) → Op(kind="for", body=[...]) + ParallelAxis(CLUSTER) → unwrapped (its body becomes + flat ops in the enclosing scope; + cluster info already burned into + MultiLaneOp scalar_args / dim_map) + For(serial / unroll) → Op(kind="for") + loop_kind annotation + RawStore → *not handled here yet*; raises + +Auto-dump +--------- + +Pass 6 also writes a human-readable mid_ir snapshot to disk before +lowering, per the convention used for HLIR (``.hlir.txt``). +File name: ``.midir.txt`` under the supplied ``build_dir``. +Dumping is opt-in via ``build_dir`` argument; pass None to skip +(handy for tests). +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from tvm import tir as _tir + +from .... import hlir as _hlir +from .... import scope as _scope +from ..cluster_guard import should_skip_cluster +from ..ir import ( + BinOp, UnaryOp, ReduceOp, + AxisRole, AxisInfo, + BufferDef, BufferRef, Slice, VarRef, + Dma, Gemm, Elementwise, Broadcast, Reduce, RawStore, + For, Async, MultiLaneOp, + ParallelAxis, ParallelKind, + MidFunc, Stmt, format_func, +) + + +class ToPlenaError(RuntimeError): + pass + + +def _make_loop_var(name: str) -> _tir.Var: + """Build a tir.Var for use as an HLIR ``for`` loop_var annotation + from a bare name. Used when no mid_ir VarRef carries the identity + (e.g. synthetic for-rows that to_plena introduces itself). + + Shares ``_VAR_CACHE`` keyed by name so two calls with the same name + produce the same tir.Var object (the ISA pass keys ``symbol_table`` + by identity). + + Prefer ``_axis_loop_var(stmt)`` / ``_for_loop_var(stmt)`` when + lowering a mid_ir ParallelAxis / For: those reuse the identity + captured during fold so inner BufferRef indices keyed off the same + var resolve to the same tir.Var object. + """ + return _get_var(name) + + +def _axis_loop_var(axis: ParallelAxis) -> _tir.Var: + """HLIR for-loop_var for ``axis``. Routed through the name cache so + every reference to this axis (via ``_render_idx_as_primexpr`` on + matching VarRefs) resolves to the same ``tir.Var`` object the ISA + pass binds in its ``symbol_table``.""" + return _make_loop_var(axis.axis_name) + + +def _for_loop_var(for_stmt: For) -> _tir.Var: + """HLIR for-loop_var for a mid_ir ``For``. Same name-cache routing + as :func:`_axis_loop_var`.""" + return _make_loop_var(for_stmt.loop_var) + + +# --------------------------------------------------------------------------- +# Scope mapping +# --------------------------------------------------------------------------- + + +def _map_scope(scope: str, rank: int, + override: Optional[str] = None) -> str: + """mid_ir scope string → HLIR scope string. + + ``override`` is set by the use-driven inference (e.g. a buffer used + as a Gemm B operand needs MRAM regardless of its declared shared/ + fragment scope). Override wins over the rank-based default. + """ + if scope == "global": + return _scope.HBM + if scope.startswith("global."): + return _scope.physical_scope(scope) + if override is not None: + return override + if scope in ("shared", "shared.dyn"): + return _scope.VRAM + if scope == "fragment.fpram": + # split pass marks per-lane scalar-state fragments (M_OLD, L_NEW + # etc., originally rank-1) with this scope so we route them to + # FPRAM even after the lane dim prepend bumped rank to 2. + return _scope.FPRAM + if scope in ("fragment", "local.fragment"): + # Rank-1 fragments are FPRAM scalar-state (M_OLD, L_NEW etc.); + # higher-rank fragments stay in VRAM (S_loc, PV_loc). + return _scope.FPRAM if rank == 1 else _scope.VRAM + if scope in _scope.PHYSICAL_SCOPES: + return scope + raise ToPlenaError(f"unknown mid_ir scope {scope!r}") + + +# --------------------------------------------------------------------------- +# Buffer construction +# --------------------------------------------------------------------------- + + +def _pad_to_4d_shape( + shape: Tuple[int, ...], heads_at_h: bool = False, +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """Pad a 1D / 2D shape up to canonical 4D for downstream uniformity. + + Distinct from cluster expansion: this is a pure rank-normalisation + step. It carries no lane / cluster semantics — the inserted axes + are extent-1 placeholders so that address_alloc / isa_emit see one + rank everywhere and don't need rank-conditional branches. + + Returns ``(new_4d_shape, inserted_positions)``. Positions index + into the OUTPUT 4D shape; callers must pad every same-rank + reference (VramRegion starts / extents) at exactly these + positions, with ``start=0`` / ``extent=1``. + + Default rule (on-chip scratch): + * 1D ``(n,)`` -> ``(1, 1, 1, n)`` inserts at (0, 1, 2) + * 2D ``(a, b)`` -> ``(1, a, 1, b)`` inserts at (0, 2) + * 4D -> unchanged inserts == () + + ``heads_at_h=True`` (author-pinned ``global.vram``/``global.mram`` + tensor caches whose first axis is the head dim): + * 2D ``(a, b)`` -> ``(1, 1, a, b)`` inserts at (0, 1) + """ + rank = len(shape) + if rank == 4: + return tuple(int(d) for d in shape), () + if rank == 2: + a, b = int(shape[0]), int(shape[1]) + if heads_at_h: + return (1, 1, a, b), (0, 1) + return (1, a, 1, b), (0, 2) + if rank == 1: + n = int(shape[0]) + return (1, 1, 1, n), (0, 1, 2) + raise ToPlenaError( + f"_pad_to_4d_shape: only 1D/2D/4D supported; got " + f"rank-{rank} shape={tuple(shape)}" + ) + + +def _make_hlir_buffer( + buf: BufferDef, + override: Optional[str] = None, + lane_count: Optional[int] = None, + mode: Optional[str] = None, + kernel_layout: str = "BSHD", +) -> Tuple[_hlir.Buffer, Tuple[int, ...]]: + """Build an HLIR ``Buffer`` from a mid_ir ``BufferDef``. + + Returns ``(buffer, inserted_positions)`` — positions in the OUTPUT + 4D shape that ``_pad_to_4d_shape`` synthesised (extent-1). Empty + tuple when no padding happened (cluster-expanded or already 4D). + + Two routes, both producing 4D on VRAM/MRAM: + + * Cluster-fusion route (``mode != None`` and ``lane_count >= 1``) + — ``_expand_buffer_shape_with_cluster`` picks BSHD axes per + lane mode, carries a ``cluster_dim`` to track the lane axis. + * Pad-to-4D route (``mode == None``, cluster skipped) — pure + rank normalisation; no cluster semantics, ``cluster_dim`` + stays as whatever the BufferDef had (typically None). + + HBM (``global*``) and FPRAM buffers keep their author-declared + rank. FPRAM is scalar-addressed (no tile layout); HBM keeps its + kernel-surface shape for parent-stride math. + """ + # Resolve the destination physical scope first so the pad-to-4D + # decision uses the same source of truth the HLIR ``Buffer.scope`` + # field ends up with — rather than guessing from the spelling of + # the mid_ir scope string (``"local.fragment"`` vs + # ``"fragment.fpram"`` vs ``"global.fpram"``). + physical = _map_scope(buf.scope, len(buf.shape), override) + is_global = _is_global_scope(buf.scope) + inserts: Tuple[int, ...] = () + # Cluster expand only applies to allocatable on-chip buffers. HBM + # (``"global"``) and author-pinned on-chip globals + # (``"global.vram"`` / ``"global.mram"`` / ``"global.fpram"``) + # keep the shape the kernel author wrote — they're explicit + # tensor-cache regions, not lane-aware scratch — so neither + # cluster grow nor lane-mode expand may touch them. + if mode is not None and not is_global and lane_count is not None: + shape_list, cluster_dim = _expand_buffer_shape_with_cluster(buf, lane_count, mode) + shape = tuple(shape_list) + else: + shape = tuple(int(d) for d in buf.shape) + cluster_dim = buf.cluster_dim + # Pad-to-4D for on-chip VRAM / MRAM. Author-pinned globals + # (``global.vram`` / ``global.mram``) also get padded so every + # downstream pass sees rank-4, but their pad rule puts heads at + # H (slot 2) — matches how kernels actually use these tensor + # caches (head-major (head_count, hlen)). Crucially we DO NOT + # stamp a ``cluster_dim`` on globals: they aren't cluster- + # expanded, so a cluster tag would confuse the sync-wrap + # iterator. HBM (plain ``"global"``) keeps its author rank for + # parent-stride math; FPRAM is scalar-addressed. + is_onchip_global = is_global and physical in (_scope.VRAM, _scope.MRAM) + if physical in (_scope.VRAM, _scope.MRAM) and len(shape) != 4: + if is_onchip_global: + shape, inserts = _pad_to_4d_shape(shape, heads_at_h=True) + elif not is_global: + shape, inserts = _pad_to_4d_shape(shape) + # ``plena.layout`` describes the HBM-side physical layout of the + # kernel's tensor params. On-chip buffers (VRAM/MRAM/FPRAM allocs, + # and on-chip pad-to-4D synthetic axes) keep the default BSHD — + # tile_layout machinery interprets their (B,S,H,D) by position. + # Stamping NCHW onto e.g. ``in_stage`` would mis-assign its synthetic + # 4D axes (axis-1 isn't a channel dim, it's the row dim by + # construction) and inflate ``h_groups`` to the row extent. + buf_layout = kernel_layout if is_global else "BSHD" + is_pinned = is_global and physical in (_scope.VRAM, _scope.MRAM) + return ( + _hlir.Buffer( + name=buf.name, + scope=physical, + shape=shape, + dtype=buf.dtype, + cluster_dim=cluster_dim, + layout=buf_layout, + is_pinned_global=is_pinned, + ), + inserts, + ) + + +# --------------------------------------------------------------------------- +# Use-driven scope overrides +# --------------------------------------------------------------------------- + + +# Lane-expansion modes for the 4D-BSHD rewrite below. Strings match the +# graph_passes/expand_buffers vocabulary so anyone reading both layers +# sees the same names. +_MODE_COL_PACK = "col_pack" # H carries lane: (1, S, lane, D_narrow) +_MODE_ROW_STACK = "row_stack" # B carries lane: (lane, S, 1, MLEN) +_MODE_FP_LANE = "fp_lane" # FPRAM: (lane, N) +_MODE_BSHD_LIFT = "bshd_lift" # No lane fusion: (1, S, 1, D) + + +def _is_global_scope(scope: str) -> bool: + """True for any author-declared global buffer. + + Matches ``"global"`` (HBM) plus the ``global.`` family + (``global.vram`` / ``global.mram`` / ``global.fpram``) used by + kernels to pin a buffer to a specific on-chip cache without + letting downstream cluster / lane logic re-expand its shape. + + Centralises the check so every pass that needs to skip globals + uses the same predicate (the bug from flash_decode_min was + ``_infer_lane_modes`` checking ``scope == "global"`` only, which + miscategorised ``global.vram`` as lane-aware). + """ + return scope == "global" or scope.startswith("global.") + + +def _infer_lane_modes(func: MidFunc) -> Dict[str, str]: + """For every non-global buffer, decide its lane-expansion mode by + inspecting how mid_ir ops use it. Mirrors + ``graph_passes.allocate_group_memory`` but works on mid_ir nodes. + + Priority (first match wins) — ``ROW_STACK`` takes precedence over + ``COL_PACK`` if a buffer is used as a BTMM dst somewhere. + """ + modes: Dict[str, str] = {} + + def record(name: str, mode: str) -> None: + prev = modes.get(name) + if prev is None: + modes[name] = mode + return + if prev == _MODE_ROW_STACK or mode == _MODE_ROW_STACK: + modes[name] = _MODE_ROW_STACK + return + if prev == _MODE_FP_LANE or mode == _MODE_FP_LANE: + modes[name] = _MODE_FP_LANE + + def visit_op(op) -> None: + if isinstance(op, Gemm): + if op.kind == "btmm": + if not _is_global_scope(op.a.buffer.scope): + record(op.a.buffer.name, _MODE_COL_PACK) + if not _is_global_scope(op.b.buffer.scope): + record(op.b.buffer.name, _MODE_COL_PACK) + if not _is_global_scope(op.c.buffer.scope): + record(op.c.buffer.name, _MODE_ROW_STACK) + else: + if not _is_global_scope(op.a.buffer.scope): + record(op.a.buffer.name, _MODE_ROW_STACK) + if not _is_global_scope(op.b.buffer.scope): + record(op.b.buffer.name, _MODE_COL_PACK) + if not _is_global_scope(op.c.buffer.scope): + record(op.c.buffer.name, _MODE_COL_PACK) + return + if isinstance(op, Dma): + for ref in (op.src, op.dst): + if _is_global_scope(ref.buffer.scope): + continue + if ref.buffer.scope == "fragment.fpram": + record(ref.buffer.name, _MODE_FP_LANE) + else: + record(ref.buffer.name, _MODE_COL_PACK) + return + if isinstance(op, (Elementwise, Reduce)): + refs = [] + if isinstance(op, Elementwise): + refs.append(op.dst) + for s in op.srcs: + refs.append(s.src if isinstance(s, Broadcast) else s) + else: + refs.extend([op.dst, op.src]) + for ref in refs: + if _is_global_scope(ref.buffer.scope): + continue + if ref.buffer.scope == "fragment.fpram": + record(ref.buffer.name, _MODE_FP_LANE) + else: + record(ref.buffer.name, _MODE_COL_PACK) + return + + def visit_stmt(s) -> None: + if isinstance(s, (ParallelAxis, For, Async)): + for c in s.body: + visit_stmt(c) + return + if isinstance(s, MultiLaneOp): + visit_op(s.inner) + return + visit_op(s) + + for s in func.body: + visit_stmt(s) + + # Anything left without a mode (allocated but unused by any + # tracked op) gets the no-lane-fusion catch-all so it still ends + # up as 4D BSHD downstream. + for buf in func.allocs: + if _is_global_scope(buf.scope): + continue + if buf.name not in modes: + modes[buf.name] = ( + _MODE_FP_LANE if buf.scope == "fragment.fpram" + else _MODE_BSHD_LIFT + ) + return modes + + +def _expand_buffer_shape_with_cluster( + buf: BufferDef, lane_count: int, mode: str, +) -> Tuple[List[int], Optional[int]]: + """Reshape to canonical 4D BSHD (or 2D for FPRAM) and report which + axis holds the cluster (lane) dim in the new shape. + + Modes: + * COL_PACK → ``(1, rows, lane, D_narrow)`` — cluster_dim = 2 + * ROW_STACK → ``(lane, rows, 1, MLEN)`` — cluster_dim = 0 + * FP_LANE → ``(lane, N)`` — cluster_dim = 0 + * BSHD_LIFT → ``(1, rows, 1, D)`` — no cluster axis + + Mid-IR pre-expansion shapes (per ``view``/``burn_view`` placement + of the lane axis recorded in ``buf.cluster_dim``): + COL_PACK : ``(rows=shape[0], lane=shape[1], D=shape[2])`` + ROW_STACK : ``(lane=shape[0], rows=shape[1], D=shape[2])`` + BSHD_LIFT : ``(?, rows=shape[1], D=shape[2])`` + """ + if mode == _MODE_FP_LANE: + if len(buf.shape) != 2: + raise ToPlenaError( + f"{buf.name!r}: FP_LANE expansion expects post-grow rank 2 " + f"(lane, N); got shape={list(buf.shape)}" + ) + return [int(buf.shape[0]), int(buf.shape[1])], 0 + if len(buf.shape) != 3: + raise ToPlenaError( + f"{buf.name!r}: 2D-lane expansion expects post-grow rank 3; " + f"got shape={list(buf.shape)} mode={mode}" + ) + if mode == _MODE_ROW_STACK: + rows = int(buf.shape[1]) + last = int(buf.shape[2]) + return [int(lane_count), rows, 1, last], 0 + if mode == _MODE_COL_PACK: + # view leaves lane at axis 1; rows is the leading axis. + rows = int(buf.shape[0]) + last = int(buf.shape[2]) + return [1, rows, int(lane_count), last], 2 + if mode == _MODE_BSHD_LIFT: + # No lane axis in scope; rows still sits at axis 1 per view. + rows = int(buf.shape[1]) + last = int(buf.shape[2]) + return [1, rows, 1, last], None + raise ToPlenaError(f"unknown lane mode {mode!r} for {buf.name!r}") + + +# Where each cluster mode lands the lane axis in the post-expansion +# 4D BSHD shape. Used to drive ref rewriting from rank 3 to rank 4 +# without enumerating per-mode permutations. +# +# This must stay consistent with the new ``cluster_dim`` value +# returned by :func:`_expand_buffer_shape_with_cluster`. +_CLUSTER_MODE_NEW_LANE_DIM: Dict[str, Optional[int]] = { + _MODE_COL_PACK: 2, # lane → H + _MODE_ROW_STACK: 0, # lane → B + _MODE_BSHD_LIFT: None, # no lane in scope +} + + +def _rewrite_ref_for_cluster_mode( + starts_or_extents: Tuple[Any, ...], + mode: str, + old_cluster_dim: Optional[int], + new_shape: Tuple[int, ...], + *, + is_extent: bool, +) -> Tuple[Any, ...]: + """Map a rank-3 ref to its rank-4 BSHD equivalent. + + Inputs: + * ``starts_or_extents`` — the rank-3 source tuple from mid_ir. + For starts the lane-axis slot has already been zeroed under + the sync-wrap convention (so the value at + ``old_cluster_dim`` is ``0``, not a Var). + * ``old_cluster_dim`` — where lane sat in the rank-3 source + (from ``buf.cluster_dim``). + * ``new_shape`` — the post-expansion 4D buffer shape; used to + recover the lane extent (BSHD H or B dim) under sync wrap. + + Sync-wrap semantics on the lane axis (matches + ``_ref_per_dim_starts`` zero-ing the phase): + * lane-axis start -> ``0`` + * lane-axis extent -> ``new_shape[new_lane_dim]`` (the full + lane span, not the per-lane ``1``) + + Non-lane source axes keep their relative order and fill the + remaining non-lane output slots; the leftover output slot is + inserted with ``0`` / ``1``. + """ + if mode == _MODE_FP_LANE: + return starts_or_extents + if len(starts_or_extents) != 3: + raise ToPlenaError( + f"cluster mode {mode!r} expects rank-3 ref tuple, got " + f"{tuple(starts_or_extents)}" + ) + if mode not in _CLUSTER_MODE_NEW_LANE_DIM: + raise ToPlenaError(f"no ref-rewrite rule for cluster mode {mode!r}") + new_lane_dim = _CLUSTER_MODE_NEW_LANE_DIM[mode] + fill: Any = 1 if is_extent else 0 + + if mode == _MODE_BSHD_LIFT or old_cluster_dim is None or new_lane_dim is None: + # No lane axis to relocate. mid_ir source is rank-3 + # ``(?, S, D)`` and we lift it to ``(?, S, 1, D)`` — + # i.e. insert an extent-1 H at position 2. + a, b, c = starts_or_extents + return (a, b, fill, c) + + # Anchor non-lane axes to canonical BSHD positions: + # * source last axis (always D in mid_ir convention) → out[3] (D) + # * source first non-lane axis (S in mid_ir convention) → out[1] (S) + # The remaining BSHD slot is whichever of B (0) / H (2) is NOT the + # lane position; it stays as ``fill``. + non_lane_sources = [i for i in range(3) if i != old_cluster_dim] + if len(non_lane_sources) != 2: + raise ToPlenaError( + f"unexpected non-lane source count {len(non_lane_sources)}" + ) + s_src, d_src = non_lane_sources[0], non_lane_sources[1] + + out: List[Any] = [fill] * 4 + # Lane slot: sync-wrap convention (start=0, extent=lane span). + if is_extent: + out[new_lane_dim] = int(new_shape[new_lane_dim]) + else: + out[new_lane_dim] = 0 + out[1] = starts_or_extents[s_src] # S + out[3] = starts_or_extents[d_src] # D + return tuple(out) + + +def _expand_buffer_shape( + buf: BufferDef, lane_count: int, mode: str, +) -> List[int]: + """Back-compat shape-only wrapper. New code should use + ``_expand_buffer_shape_with_cluster``.""" + shape, _ = _expand_buffer_shape_with_cluster(buf, lane_count, mode) + return shape + + +def _infer_scope_overrides(func: MidFunc) -> Dict[str, str]: + """Scan body; figure out per-buffer scope overrides driven by op + usage. + + Today the only override: any buffer used as a Gemm B operand + (BTMM RHS or per-head matmul RHS) has to live in MRAM. PLENA's + BTMM/MM hardware reads RHS from MRAM only. + + This propagates through DMA destinations: ``dma_h2v`` lowering is + chosen by ``_dma_kind_from_scopes`` from the dst's scope, so once + the dst buffer is tagged MRAM, ``dma_h2m`` will be picked + automatically. + """ + overrides: Dict[str, str] = {} + + def visit_op(op) -> None: + if isinstance(op, Gemm): + # B operand → MRAM, regardless of how shared/fragment + # would otherwise default it. Never overrides author-pinned + # globals (HBM / global.vram / global.mram / global.fpram): + # those carry an explicit user-side placement contract. + if not _is_global_scope(op.b.buffer.scope): + overrides[op.b.buffer.name] = _scope.MRAM + + def visit_stmt(s) -> None: + if isinstance(s, (ParallelAxis, For, Async)): + for c in s.body: + visit_stmt(c) + return + if isinstance(s, MultiLaneOp): + visit_op(s.inner) + return + visit_op(s) + + for s in func.body: + visit_stmt(s) + return overrides + + +# --------------------------------------------------------------------------- +# Index expression rendering +# --------------------------------------------------------------------------- + + +def _render_idx(idx) -> Any: + """Convert mid_ir IndexExpr into an HLIR-friendly form. Slice maps + to 0 (whole-axis access starts at 0). Compound dict expressions are + flattened to a string for the legacy ExprMaterializer to handle — + pre-existing convention is to pass non-static offsets as PrimExpr + or string in scalar_args, but here for HBM we only need int-or-str.""" + if isinstance(idx, Slice): + return 0 + if isinstance(idx, int): + return int(idx) + if isinstance(idx, VarRef): + return idx.name + if isinstance(idx, str): + return idx + if isinstance(idx, dict): + op = idx.get("op", "?") + args = idx.get("args", []) + # ranged_slice carries (start_expr, extent_int): for the HLIR + # BufferSlice 'starts' field we want the start expression; the + # extent is recovered separately by _ref_extents. + if op == "ranged_slice": + return _render_idx(args[0]) + rendered = [_render_idx(a) for a in args] + # Render compound exprs as readable Python syntax for now. + # Legacy ExprMaterializer is what actually materializes them. + if op == "add": + return f"({rendered[0]} + {rendered[1]})" + if op == "sub": + return f"({rendered[0]} - {rendered[1]})" + if op == "mul": + return f"({rendered[0]} * {rendered[1]})" + if op == "fdiv": + return f"({rendered[0]} // {rendered[1]})" + if op == "fmod": + return f"({rendered[0]} % {rendered[1]})" + return f"" + return idx + + +def _ref_extents(ref: BufferRef) -> Tuple[int, ...]: + """Extents of the slice this ref describes — full-axis Slices give + the buffer's full size on that dim; a ranged_slice compound carries + its own extent; everything else (single scalar index) is 1.""" + out: List[int] = [] + for i, dim in enumerate(ref.buffer.shape): + idx = ref.indices[i] + if isinstance(idx, Slice): + out.append(int(dim)) + elif isinstance(idx, dict) and idx.get("op") == "ranged_slice": + out.append(int(idx["args"][1])) + else: + out.append(1) + return tuple(out) + + +def _is_whole_buffer_ref(ref: BufferRef) -> bool: + """All indices are Slice (no narrowing). + + The view pass prepends a cluster-phase index (a bare string) on + every non-global ref that went through cluster fusion, and + burn_view may permute it to a non-zero position. For DMA / btmm / + pure-elementwise ops that fire across all lanes (wrapped in + MultiLaneOp) the phase index is a whole-lane-axis access — + equivalent to Slice for whole-buffer purposes. + + Use the buffer's own ``cluster_dim`` (set by ``split._grow_buffer`` + and propagated through view / burn_view) as the source of truth: + a bare-string index sitting at that physical position is the + phase shorthand. When ``cluster_dim`` is None (cluster was + skipped — kernel has no lane axis, or every on-chip buffer was + already mlen-wide), no axis is treated as a phase shorthand, and + we fall back to "all-Slice only". This avoids silently swallowing + ordinary loop-var narrowings (e.g. ``oc`` in + ``Output[:, oc, :, :]``) when no cluster is in scope. + """ + if not ref.indices: + return True + if all(isinstance(i, Slice) for i in ref.indices): + return True + cdim = getattr(ref.buffer, "cluster_dim", None) + if cdim is None or not (0 <= cdim < len(ref.indices)): + return False + # The cluster-dim slot must be a VarRef phase shorthand; every + # other slot must be a Slice (real narrowing on a non-phase axis + # disqualifies whole-buffer treatment). + cluster_idx = ref.indices[cdim] + if not isinstance(cluster_idx, VarRef): + return False + for i, idx in enumerate(ref.indices): + if i == cdim: + continue + if not isinstance(idx, Slice): + return False + return True + + +# --------------------------------------------------------------------------- +# Op-arg construction +# --------------------------------------------------------------------------- + + +_INT32 = "int32" + +# Cache (name → tir.Var) so HLIR ``for`` ops constructed by to_plena +# (loop_var on synthetic for-rows etc.) and any name-keyed lookups +# resolve to the same Python object. ISA pass identifies bindings by +# object identity in its symbol_table. +# +# This cache services only ``_make_loop_var(name)`` — paths that +# synthesise *new* loop vars (e.g. ``"row0"``, the HLIR for created +# from a mid_ir For). Index expressions from mid_ir come through +# VarRef and bypass this cache entirely (``_render_idx_as_primexpr`` +# unwraps the wrapped ``tir.Var`` directly). +_VAR_CACHE: Dict[str, "_tir.Var"] = {} + +# Module-global lane modes table, populated by ``run`` at the start of +# each compile and read by per-op lowering helpers. Cleared by ``run``. +_LANE_MODES: Dict[str, str] = {} + +# Per-buffer pad-to-4D insert positions. Populated by ``run`` after +# ``_make_hlir_buffer``. Used by ``_axes_for_ref`` to align mid_ir +# per-op axes (rank == mid_ir buffer rank) with the post-pad HLIR +# shape so ``hlir.Op.buffer_axes`` agrees with ``buf.shape``. +_PAD_INSERTS: Dict[str, Tuple[int, ...]] = {} + +# Per-buffer cluster-expansion records ``(mode, mid_ir_cluster_dim, +# new_4d_shape)``. The pair to ``_PAD_INSERTS``: any on-chip buffer +# that wasn't pad-to-4D'd was instead grown by a cluster expansion +# (col_pack / row_stack / bshd_lift / fp_lane), and that's recorded +# here so axes-aware helpers can translate mid_ir-side (rank-3) axes +# tables onto the post-expansion 4D HLIR shape. +_CLUSTER_MODES: Dict[ + str, Tuple[str, Optional[int], Tuple[int, ...]], +] = {} + +# Per-buffer HLIR-side ``buffer_axes`` (post pad-to-4D / cluster-expand). +# Each entry is ``Tuple[(role_name, extent), ...]`` aligned with the +# HLIR buffer's shape — one role tag per physical dim. Populated by +# ``run`` once the per-buffer mode is known; read by every lowering +# helper that needs to stamp ``buffer_axes`` on its emitted hlir.Op. +_BUFFER_HLIR_AXES: Dict[str, Tuple[Tuple[str, int], ...]] = {} + + +def _axes_of(buf_name: str) -> Optional[Tuple[Tuple[str, int], ...]]: + """Look up the per-dim role tuple for ``buf_name`` from the + ``_BUFFER_HLIR_AXES`` table populated at ``run`` start. Returns + ``None`` for unknown buffers (e.g. transient names that never + landed in the HLIR buffer table) so callers can store ``None`` in + ``buffer_axes`` without crashing.""" + return _BUFFER_HLIR_AXES.get(buf_name) + + +def _hlir_axes_for_buffer(buf: "_hlir.Buffer") -> Tuple[Tuple[str, int], ...]: + """Synthesise the ``buffer_axes`` tuple for ``buf`` from its + post-expansion shape + ``cluster_dim``. + + Role assignment per physical dim: + * the innermost dim (axis ``rank-1``) is the SIMD / D / N axis + — tagged ``"simd"``. + * the cluster dim (``buf.cluster_dim``, if set) is the lane + axis — tagged ``"cluster"``. + * every other dim is a row-fanout axis — tagged ``"batch"``. + Both the leading B=1 placeholder (pad-to-4D) and the real S + (rows) end up as ``"batch"``; downstream callers can pick the + non-degenerate one via ``extent != 1`` if needed, but most + row-at consumers just want "which dim is rows" and that's the + ``"batch"`` dim with the largest extent. + """ + shape = [int(d) for d in buf.shape] + rank = len(shape) + if rank == 0: + return () + d_axis = rank - 1 + cdim = buf.cluster_dim + out: List[Tuple[str, int]] = [] + for i in range(rank): + if i == d_axis: + role = "simd" + elif cdim is not None and i == cdim: + role = "cluster" + else: + role = "batch" + out.append((role, shape[i])) + return tuple(out) + + +# Logical lane var (identity-keyed) → (phase VarRef, number VarRef, count). +# Populated by ``run()`` by scanning the body for CLUSTER ParallelAxes +# that carry ``original_axis_var`` matching one of ``func.lane_axes``. +# ``_render_idx_as_primexpr`` consults this to expand a bare lane-var +# index (e.g. ``by``) into ``by_phase + by_number * lane_count`` so the +# ISA materializer only sees axes bound by enclosing HLIR for-ops. +_LANE_AXIS_INFO: "Dict[VarRef, Tuple[VarRef, VarRef, int]]" = {} + + +# Hardware vector lane width (MLEN). Set by ``run()`` from the compile +# target, consumed by lowerings that emit per-row HW vector ops where +# the row stride is the HW vlen rather than anything derivable from +# the buffer's logical shape (e.g. vram→vram copy lowers each iteration +# to one V_ADD_VF that strides by mlen). +_HW_MLEN: int = 0 + + +def _get_var(name: str) -> "_tir.Var": + v = _VAR_CACHE.get(name) + if v is None: + v = _tir.Var(name, _INT32) + _VAR_CACHE[name] = v + return v + + +def _populate_lane_axis_info( + func: MidFunc, + lane_axes: List[str], + cluster_counts: List[int], +) -> None: + """Walk ``func.body`` looking for each named lane axis's pair of + (CLUSTER phase ParallelAxis, sibling number ParallelAxis). Record + ``original_var -> (phase_var, number_var, count)`` in + ``_LANE_AXIS_INFO`` for later VarRef-keyed lookup in + ``_render_idx_as_primexpr``. + + The matching CLUSTER axis is the one whose + ``original_axis_name == lane_axes[i]`` and whose + ``parent_grid_axis_name`` names a sibling axis with the same + ``original_axis_name``. Both axes are produced by split as a + ``number -> phase`` nest. + """ + # First gather every ParallelAxis in the body, keyed by axis_name. + axes_by_name: Dict[str, ParallelAxis] = {} + + def collect(s) -> None: + if isinstance(s, ParallelAxis): + axes_by_name[s.axis_name] = s + for c in s.body: + collect(c) + elif isinstance(s, (For, Async)): + for c in s.body: + collect(c) + elif isinstance(s, MultiLaneOp): + # Inner is a leaf op; no nested ParallelAxis. + return + + for s in func.body: + collect(s) + + for axis_name, count in zip(lane_axes, cluster_counts): + phase_axis = None + for ax in axes_by_name.values(): + if (ax.kind == ParallelKind.CLUSTER + and ax.original_axis_name == axis_name): + phase_axis = ax + break + if phase_axis is None: + # Either cluster was skipped for this kernel, or the kernel + # has no CLUSTER for this lane. Nothing to expand. + continue + number_axis = axes_by_name.get(phase_axis.parent_grid_axis_name) + if number_axis is None: + raise ToPlenaError( + f"lane axis {axis_name!r}: CLUSTER " + f"{phase_axis.axis_name!r} references unknown number " + f"axis {phase_axis.parent_grid_axis_name!r}" + ) + if (phase_axis.axis_var is None + or phase_axis.original_axis_var is None + or number_axis.axis_var is None): + raise ToPlenaError( + f"lane axis {axis_name!r}: identity (axis_var) fields " + f"missing on CLUSTER {phase_axis.axis_name!r} or " + f"number axis {number_axis.axis_name!r}. Split should " + f"have populated them." + ) + _LANE_AXIS_INFO[phase_axis.original_axis_var] = ( + phase_axis.axis_var, + number_axis.axis_var, + int(count), + ) + + +def _render_idx_as_primexpr(idx): + """Like ``_render_idx`` but returns a value suitable for + ``hlir.BufferSlice.starts``: ints stay ints; VarRefs become a + ``tir.Var`` (or, for logical lane vars, the split-form composite); + compound dicts become real ``tir.PrimExpr`` trees so the ISA pass's + ``_build_slice_offset_expr`` can multiply them by a stride directly. + + NOTE: VarRefs are unwrapped via the *name cache* (``_get_var``) — + not via ``idx.var`` directly. The downstream ISA materialiser binds + by ``tir.Var`` *identity* through its ``symbol_table``, and the + matching HLIR ``for`` loop_var is also minted from the same name + cache. Routing through the cache here makes the two halves resolve + to the same Python object, so ``symbol_table[var]`` finds the + binding. The earlier in-pipeline identity discipline (provided by + ``VarRef.same_as``) is independent of this rendering step. + """ + if isinstance(idx, Slice): + return 0 + if isinstance(idx, int): + return int(idx) + if isinstance(idx, VarRef): + # Logical lane vars (e.g. the user-written ``by``) get expanded + # to their split form ``by_phase + by_number * lane_count`` so + # the ISA layer, which only binds the split axes, can + # materialise the index. The lookup is by VarRef identity. + info = _LANE_AXIS_INFO.get(idx) + if info is not None: + phase, number, count = info + return _get_var(phase.name) + _get_var(number.name) * _tir.IntImm(_INT32, count) + return _get_var(idx.name) + if isinstance(idx, dict): + op = idx.get("op", "?") + args = idx.get("args", []) + if op == "ranged_slice": + # For HBM starts: the slice begins at args[0] (the cluster + # base expression). The extent (args[1]) is recovered by + # _ref_extents elsewhere. + return _render_idx_as_primexpr(args[0]) + rendered = [_render_idx_as_primexpr(a) for a in args] + if op == "add": + return rendered[0] + rendered[1] + if op == "sub": + return rendered[0] - rendered[1] + if op == "mul": + return rendered[0] * rendered[1] + if op == "fdiv": + return _tir.floordiv(rendered[0], rendered[1]) + if op == "fmod": + return _tir.floormod(rendered[0], rendered[1]) + raise ToPlenaError( + f"unhandled compound index op {op!r} in _render_idx_as_primexpr" + ) + return idx + + +def _make_buffer_arg(ref: BufferRef) -> Union[str, _hlir.BufferSlice]: + """Build an HLIR buffer_arg from a mid_ir BufferRef. Whole-buffer + refs become bare buffer-name strings; partial refs become BufferSlice.""" + if _is_whole_buffer_ref(ref): + return ref.buffer.name + starts = tuple(_render_idx_as_primexpr(i) for i in ref.indices) + extents = _ref_extents(ref) + return _hlir.BufferSlice( + parent=ref.buffer.name, + starts=starts, + extents=extents, + ) + + +# --------------------------------------------------------------------------- +# Op lowering +# --------------------------------------------------------------------------- + + +_BINOP_TO_INTRIN = { + BinOp.ADD: "v_add", + BinOp.SUB: "v_sub", + BinOp.MUL: "v_mul", +} + + +_UNARY_TO_INTRIN = { + UnaryOp.EXP: "v_exp", + UnaryOp.RECI: "v_reci", + UnaryOp.SQRT: "v_sqrt", + UnaryOp.COPY: "copy_v_to_v", +} + + +_REDUCE_TO_ROW_AT = { + ReduceOp.MAX: "row_reduce_max_at", + ReduceOp.SUM: "row_reduce_sum_at", +} + + +_ROW_FP_BINOP_TO_INTRIN = { + # Per-row VRAM × FPRAM-scalar op. One HLIR op = one HW instruction + # over a single row; ``_lower_bare_broadcast_elementwise`` wraps + # this in a ``for row`` so multi-row callers don't need to. + BinOp.ADD: "row_add_fp", + BinOp.SUB: "row_sub_fp", + BinOp.MUL: "row_mul_fp", +} + + +# FPRAM-resident per-lane scalar Elementwise: lower to a ``for row: +# fp__at`` loop. Used for things like ``M_OLD[row] = M_INIT[row]`` +# where every operand is a rank-1 FPRAM buffer. +_FP_AT_BINOP_TO_INTRIN = { + BinOp.ADD: "fp_add_at", + BinOp.SUB: "fp_sub_at", + BinOp.MUL: "fp_mul_at", + BinOp.MAX: "fp_max_at", +} + +_FP_AT_UNARY_TO_INTRIN = { + UnaryOp.COPY: "fp_copy_at", + UnaryOp.EXP: "fp_exp_at", + UnaryOp.RECI: "fp_reci_at", + UnaryOp.SQRT: "fp_sqrt_at", +} + + +def _dma_kind_from_scopes(src_scope: str, dst_scope: str) -> str: + """Pick the dma_* op kind from the (src, dst) scope pair. Mirrors + the legacy intrinsics registry: H↔V/M only, V↔H whole-buffer.""" + if src_scope == _scope.HBM and dst_scope == _scope.VRAM: + return "dma_h2v" + if src_scope == _scope.HBM and dst_scope == _scope.MRAM: + return "dma_h2m" + if src_scope == _scope.VRAM and dst_scope == _scope.HBM: + return "dma_v2h" + raise ToPlenaError( + f"unsupported DMA src→dst: {src_scope}→{dst_scope}" + ) + + +def _dma_kind_slice_variant(base: str) -> str: + """``dma_h2v`` → ``dma_h2v_slice`` etc. Used when one of the refs + isn't whole-buffer.""" + return f"{base}_slice" + + +def _lower_multi_lane_dma(op: Dma, lane_count: int, + buf_name_to_hlir: Dict[str, _hlir.Buffer], + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, + ) -> _hlir.Op: + src_scope = buf_name_to_hlir[op.src.buffer.name].scope + dst_scope = buf_name_to_hlir[op.dst.buffer.name].scope + # VRAM → VRAM "copy" isn't a HW DMA — emit ``copy_v_to_v`` + # (V_ADD_VF dst, src, f0=0) instead. + if src_scope == _scope.VRAM and dst_scope == _scope.VRAM: + return _lower_vram_to_vram_copy( + op, buf_name_to_hlir, cluster_axis_name=cluster_axis_name, + cluster_axis_var=cluster_axis_var, + ) + # VRAM ↔ FPRAM: not a DMA either — single S_MAP_FP_V / S_MAP_V_FP + # per mlen-wide row. tilelang authors write these as T.copy and + # rely on us to route them to the right HW path. + if src_scope == _scope.VRAM and dst_scope == _scope.FPRAM: + return _lower_v_fp_transfer( + op, "v_to_fp", buf_name_to_hlir, cluster_axis_name, + cluster_axis_var=cluster_axis_var, + ) + if src_scope == _scope.FPRAM and dst_scope == _scope.VRAM: + return _lower_v_fp_transfer( + op, "fp_to_v", buf_name_to_hlir, cluster_axis_name, + cluster_axis_var=cluster_axis_var, + ) + base = _dma_kind_from_scopes(src_scope, dst_scope) + src_arg = _make_buffer_arg(op.src) + dst_arg = _make_buffer_arg(op.dst) + has_slice = isinstance(src_arg, _hlir.BufferSlice) \ + or isinstance(dst_arg, _hlir.BufferSlice) + kind = _dma_kind_slice_variant(base) if has_slice else base + return _hlir.Op( + kind=kind, + buffer_args=[src_arg, dst_arg], + scalar_args=[lane_count], + annotations={"source": "MultiLaneOp(Dma)"}, + ) + + +def _ref_flat_offset(ref: BufferRef, + phase_var_zero: "Optional[VarRef]" = None) -> _tir.PrimExpr: + """Compute ``ref``'s starting element offset in row-major flat layout. + + Iterates buffer.shape backwards accumulating stride; concrete indices + contribute ``idx * stride``. ``Slice`` is whole-axis (start = 0), + contributes nothing. ``ranged_slice(start_expr, extent)`` contributes + ``start_expr * stride``. When ``phase_var_zero`` is set (the cluster + phase axis's VarRef), VarRef occurrences equal to it (by identity) + are treated as 0 — mirrors the ``_is_whole_buffer_ref`` convention + for sync-wrap multi-lane ops where the phase index just marks + "this op covers every lane in lockstep".""" + offset: _tir.PrimExpr = _tir.IntImm(_INT32, 0) + stride = 1 + for dim, idx in zip(reversed(ref.buffer.shape), reversed(ref.indices)): + if isinstance(idx, Slice): + pass # whole-axis access — start is 0 + elif (phase_var_zero is not None + and isinstance(idx, VarRef) and idx == phase_var_zero): + pass # cluster phase axis under sync wrap — contributes 0 + elif isinstance(idx, dict) and idx.get("op") == "ranged_slice": + start_expr = _render_idx_as_primexpr(idx["args"][0]) + scaled = start_expr if stride == 1 else _tir.Mul( + start_expr, _tir.IntImm(_INT32, int(stride)), + ) + offset = scaled if ( + isinstance(offset, _tir.IntImm) and int(offset.value) == 0 + ) else _tir.Add(offset, scaled) + else: + term = _render_idx_as_primexpr(idx) + if isinstance(term, _tir.IntImm) and int(term.value) == 0: + pass + else: + scaled = term if stride == 1 else _tir.Mul( + term, _tir.IntImm(_INT32, int(stride)), + ) + offset = scaled if ( + isinstance(offset, _tir.IntImm) and int(offset.value) == 0 + ) else _tir.Add(offset, scaled) + stride *= int(dim) + return offset + + +def _lower_vram_to_vram_copy(op: Dma, + buf_name_to_hlir: Dict[str, _hlir.Buffer], + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, + ) -> _hlir.Op: + """``T.copy(vram_src, vram_dst)`` → region-schema ``copy_v_to_v``. + + Emits one VramRegion per side, at each buffer's PRE-expansion + (native) rank using the ref's own indices: + * Slice / phase-var -> start=0, extent=full + * ranged_slice(s, ext) -> start=s, extent=ext + * VarRef / concrete idx -> start=expr, extent=1 + + The post-walk ``_rewrite_refs_to_4d`` pass then lifts each region + to 4D using the buffer's own ``_PAD_INSERTS`` or ``_CLUSTER_MODES`` + entry — which is exactly the right thing for cluster-asymmetric + pairs (one side cluster-expanded, the other a pinned ``global.vram`` + that only got pad-to-4D). Each side gets lifted on its own terms; + the lifted 4D extents end up matching because mid_ir guarantees + the logical region is the same on both sides (it's a sync-wrap + multi-lane copy). + """ + def _ref_region(ref: BufferRef) -> _hlir.VramRegion: + starts: List[Any] = [] + extents: List[int] = [] + for dim, idx in zip(ref.buffer.shape, ref.indices): + if isinstance(idx, Slice): + starts.append(_tir.IntImm(_INT32, 0)) + extents.append(int(dim)) + elif (cluster_axis_var is not None + and isinstance(idx, VarRef) and idx == cluster_axis_var): + starts.append(_tir.IntImm(_INT32, 0)) + extents.append(int(dim)) + elif isinstance(idx, dict) and idx.get("op") == "ranged_slice": + s_expr = _render_idx_as_primexpr(idx["args"][0]) + ext = int(idx["args"][1]) + starts.append(s_expr) + extents.append(ext) + else: + starts.append(_render_idx_as_primexpr(idx)) + extents.append(1) + return _hlir.VramRegion( + parent=ref.buffer.name, + starts=tuple(starts), extents=tuple(extents), + ) + + src_region = _ref_region(op.src) + dst_region = _ref_region(op.dst) + for_specs: List[Tuple[Any, int]] = [] + leaf = _hlir.Op( + kind="copy_v_to_v", + buffer_args=[src_region, dst_region], + scalar_args=[], + annotations={"source": "vram→vram copy"}, + ) + if not for_specs: + return leaf + body: List[_hlir.Op] = [leaf] + for v, ext in reversed(for_specs): + body = [_hlir.make_for_op(loop_var=v, extent=int(ext), body=body)] + return body[0] + + +def _is_zero_imm(expr) -> bool: + return isinstance(expr, _tir.IntImm) and int(expr.value) == 0 + + +def _ref_per_dim_starts( + ref: BufferRef, phase_var_zero: "Optional[VarRef]" = None, +) -> Tuple[Any, ...]: + """Per-dim start indices for a BufferRef, mirroring ``_ref_extents``. + + Slice → 0 (whole axis); ranged_slice → start_expr (rendered); + anything else → the index rendered as a PrimExpr. The + ``phase_var_zero`` convention matches ``_ref_flat_offset`` — a + VarRef equal (by identity) to ``phase_var_zero`` (the + cluster-phase axis VarRef) is treated as 0 under sync wrap. + """ + out: List[Any] = [] + for idx in ref.indices: + if isinstance(idx, Slice): + out.append(0) + elif (phase_var_zero is not None + and isinstance(idx, VarRef) and idx == phase_var_zero): + out.append(0) + elif isinstance(idx, dict) and idx.get("op") == "ranged_slice": + out.append(_render_idx_as_primexpr(idx["args"][0])) + else: + out.append(_render_idx_as_primexpr(idx)) + return tuple(out) + + +def _lower_v_fp_transfer( + op: Dma, + direction: str, # "v_to_fp" or "fp_to_v" + buf_name_to_hlir: Dict[str, _hlir.Buffer], + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, +) -> _hlir.Op: + """``T.copy(vram, fpram)`` / ``T.copy(fpram, vram)`` → one HLIR slice + op carrying the full logical region. + + Splitting the logical transfer into HW-MLEN-wide ``S_MAP_*_FP/V`` + issues (and computing per-issue physical VRAM offsets through the + parent's 7D tile layout) is the ISA emitter's job — HLIR stays at + the logical-region level. + + HLIR ops emitted: + * ``v_fp_transfer_slice_v_to_fp`` + buffer_args=[VramRegion] scalars=[fp_addr] + * ``v_fp_transfer_slice_fp_to_v`` + buffer_args=[VramRegion] scalars=[fp_addr] + """ + if direction == "v_to_fp": + vram_ref, fp_ref = op.src, op.dst + kind = "v_fp_transfer_slice_v_to_fp" + else: + vram_ref, fp_ref = op.dst, op.src + kind = "v_fp_transfer_slice_fp_to_v" + vram_buf = buf_name_to_hlir[vram_ref.buffer.name] + fp_buf = buf_name_to_hlir[fp_ref.buffer.name] + + starts = _ref_per_dim_starts(vram_ref, phase_var_zero=cluster_axis_var) + extents = _ref_extents(vram_ref) + region = _hlir.VramRegion( + parent=vram_buf.name, + starts=starts, + extents=extents, + ) + + fp_indices = _zero_cluster_axis_in_fp_indices(fp_ref, cluster_axis_name) + fp_addr = _hlir.BufferElement(buffer=fp_buf.name, indices=fp_indices) + + return _hlir.Op( + kind=kind, + buffer_args=[region], + scalar_args=[fp_addr], + annotations={"source": f"T.copy vram↔fp ({direction})"}, + ) + + +def _zero_cluster_axis_in_fp_indices( + ref: BufferRef, cluster_axis_name: Optional[str], +) -> "tuple": + """Render ``ref.indices`` to PrimExprs, replacing the cluster-phase + axis with 0 when the enclosing MultiLaneOp covers all lanes in one + issue. ``ref.buffer.cluster_dim`` (set by split / propagated by + burn_view) tells us which physical axis carries the phase index. + """ + indices = list(ref.indices) + cluster_dim = getattr(ref.buffer, "cluster_dim", None) + if (cluster_axis_name is not None + and cluster_dim is not None + and 0 <= cluster_dim < len(indices)): + indices[cluster_dim] = 0 + return tuple(_render_idx_as_primexpr(i) for i in indices) + + +def _ref_touch_count(ref: BufferRef) -> int: + """How many elements does ``ref`` actually touch? + + Slice → buffer dim; ranged_slice → its declared extent; concrete + index → 1. Multiplied across all axes.""" + count = 1 + for dim, idx in zip(ref.buffer.shape, ref.indices): + if isinstance(idx, Slice): + count *= int(dim) + elif isinstance(idx, dict) and idx.get("op") == "ranged_slice": + count *= int(idx["args"][1]) + # concrete index: contributes 1 + return count + + +def _lower_multi_lane_btmm(op: Gemm, lane_count: int, + buf_name_to_hlir: Optional[Dict[str, _hlir.Buffer]] = None, + ) -> _hlir.Op: + """MultiLaneOp(Gemm) → region-schema btmm / btmv op. + + Like the bare per-head gemm path, but the cluster axis is folded + into a single HW multi-lane issue (no ``for lane`` synthesised). + Region.starts on the lane axis are 0 with extent == lane_count + so the emitter knows it fires across every lane natively. + """ + rows = _logical_rows_from_buf(op.a) + kind = "btmv" if rows == 1 else "btmm" + if buf_name_to_hlir is None: + raise ToPlenaError( + f"multi-lane Gemm[{kind}]: buf_name_to_hlir is required for " + f"region-schema lowering" + ) + a_buf = buf_name_to_hlir[op.a.buffer.name] + b_buf = buf_name_to_hlir[op.b.buffer.name] + c_buf = buf_name_to_hlir[op.c.buffer.name] + + # Region: lane_axis_name=None ⇒ start=0 on every axis (whole-buffer + # region). The emitter sees the cluster axis covered by its full + # lane_count extent and issues a single multi-lane instruction. + a_region = _gemm_full_region(op.a, a_buf, lane_axis_name=None) + b_region = _gemm_full_region(op.b, b_buf, lane_axis_name=None) + c_region = _gemm_full_region(op.c, c_buf, lane_axis_name=None) + a_roles = _align_dim_roles_to_4d(op.a.buffer.name, op.a_axes) + b_roles = _align_dim_roles_to_4d(op.b.buffer.name, op.b_axes) + c_roles = _align_dim_roles_to_4d(op.c.buffer.name, op.c_axes) + + return _hlir.Op( + kind=kind, + buffer_args=[a_region, b_region, c_region], + scalar_args=[a_roles, b_roles, c_roles], + annotations={"source": f"MultiLaneOp(Gemm[{kind}])"}, + ) + + +def _find_role(axes: List[AxisInfo], role: AxisRole + ) -> Tuple[Optional[int], Optional[AxisInfo]]: + """Return ``(dim_index, AxisInfo)`` of the first entry tagged + ``role``, or ``(None, None)`` if not found.""" + for i, a in enumerate(axes): + if a.role == role: + return i, a + return None, None + + +def _has_cluster_role(axes: Optional[List[AxisInfo]]) -> bool: + """True if any dim in ``axes`` is tagged ``CLUSTER``.""" + if not axes: + return False + return any(a.role == AxisRole.CLUSTER for a in axes) + + +def _axes_to_hlir_tuple( + axes: List[AxisInfo], + inserts: Tuple[int, ...] = (), +) -> Tuple[Tuple[str, int], ...]: + """Convert mid_ir per-axis ``AxisInfo`` list into the + ``Tuple[(role_name, extent), ...]`` form ``hlir.Op.buffer_axes`` + expects. + + ``inserts`` mirrors the per-buffer pad-to-4D rule: each position + in this list adds a ``("batch", 1)`` placeholder dim so the + returned tuple aligns with the post-pad HLIR shape. Positions + follow the same convention as ``_pad_tuple_at`` — ascending + indices in OUTPUT coords. + """ + out: List[Tuple[str, int]] = [ + (a.role.value, int(a.extent)) for a in axes + ] + for pos in inserts: + out.insert(pos, ("batch", 1)) + return tuple(out) + + +def _split_axes_by_role(axes: List[AxisInfo]): + """Group an op's axes table into ``(batch_extents, inner_extent)``. + + ``batch_extents`` are the BATCH-role extents in axis order — each + one becomes an outer ``for`` loop wrapped around the leaf op. + ``inner_extent`` is the product of every CLUSTER + SIMD axis — + that is, the contiguous 1-D vector each leaf op processes. + + REDUCE / BROADCAST axes are op-specific and not handled here (only + Elementwise / Dma / Reduce-leaf paths use this helper). + """ + batch_extents: List[int] = [] + inner = 1 + for a in axes: + if a.role == AxisRole.BATCH: + batch_extents.append(int(a.extent)) + elif a.role in (AxisRole.SIMD, AxisRole.CLUSTER): + inner *= int(a.extent) + else: + raise ToPlenaError( + f"_split_axes_by_role: unsupported role {a.role!r} in axes " + f"{axes!r}" + ) + return batch_extents, inner + + +def _lower_multi_lane_elementwise( + op: Elementwise, lane_count: int, + buf_name_to_hlir: Optional[Dict[str, _hlir.Buffer]] = None, + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, +) -> _hlir.Op: + """Pure elementwise (no Broadcast srcs) → ``v_add`` / ``v_exp`` / + ``v_zero`` / etc., wrapped in explicit ``for`` loops for each + BATCH axis. + + Reads the per-axis ``dst_axes`` table directly — every BATCH axis + becomes an outer ``for``; the contiguous CLUSTER + SIMD extents + multiply into the leaf ``v_*`` op's ``n_elem``. No buffer-shape + guessing, no row-geometry heuristics. + + Falls back to the FPRAM and per-row unary paths when the dst scope + or axes shape demand them. + """ + if (buf_name_to_hlir is not None + and buf_name_to_hlir[op.dst.buffer.name].scope == _scope.FPRAM): + return _lower_bare_fp_scalar_elementwise( + op, lane_count, cluster_axis_name, + cluster_axis_var=cluster_axis_var, + ) + # COPY has a multi-lane ``copy_v_to_v`` intrin (handled below) and + # no per-row variant — keep it on the v_copy path even when the + # dst's row footprint is 1 (an artifact of fold absorbing + # T.Parallel rather than a request for a row_ emission). + if (op.op in _UNARY_TO_INTRIN + and op.op != UnaryOp.COPY + and len(op.srcs) == 1 + and _row_footprint(op.dst) == 1): + return _lower_per_row_unary( + op, cluster_axis_name, cluster_axis_var=cluster_axis_var, + ) + if op.op in _BINOP_TO_INTRIN: + kind = _BINOP_TO_INTRIN[op.op] + elif op.op in _UNARY_TO_INTRIN: + kind = _UNARY_TO_INTRIN[op.op] + if op.op == UnaryOp.COPY and not op.srcs: + kind = "v_zero" + else: + raise ToPlenaError(f"unsupported elementwise op {op.op!r}") + buffer_args: List[Any] = [] + for s in op.srcs: + if isinstance(s, Broadcast): + raise ToPlenaError( + f"MultiLaneOp Elementwise with Broadcast src — pass_2_mark " + f"should have set can_async=False" + ) + buffer_args.append(s.buffer.name) + buffer_args.append(op.dst.buffer.name) + + # Read fan-out structure straight off the axes table. + if not op.dst_axes: + raise ToPlenaError( + f"Elementwise on {op.dst.buffer.name!r} has empty dst_axes — " + f"fold/split/view must populate the axes table for every op " + f"so lower can stop guessing geometry from buffer shape." + ) + # Binop family (v_add / v_sub / v_mul) uses the new logical-coord + # schema: scalar_args = [idx0, idx1, idx2] picks one mlen-wide row + # per non-SIMD axis of the dst's 4D shape (B, S, H), and the D axis + # is implicit — the emitter walks all d_tiles itself. Extent-1 axes + # (e.g. the pad-to-4D B=1 placeholder) collapse to ``IntImm(0)`` + # rather than wrapping ``for _ in range(1)``. + # + # Unary / copy family (v_exp / v_reci / v_sqrt / v_zero / + # copy_v_to_v) still uses the legacy flat-offset schema until its + # emitters are migrated; that path stays on the older + # ``[off, off, n_elem]`` form below. + # Region schema (unified for v_add/v_sub/v_mul/v_exp/v_reci/v_sqrt/v_zero): + # each buffer_arg becomes a VramRegion with 4D BSHD (starts, + # extents). The HLIR axes table is the source of truth — it's + # computed from the post-pad-to-4D buffer shape and is always 4 + # entries in canonical (B, S, H, D) order. + # + # For each non-SIMD axis (slot 0..2): + # * extent == 1 -> start=0, extent=1 (no for, no idx) + # * cluster axis (lane span) -> start=0, extent=full + # (whole packed-head group lives in one mlen-row; emitter + # folds this axis out of the walk via parent.cluster_dim) + # * other batch axis ext > 1 -> fresh row var, outer for-op, + # region.start=var, region.extent=1 + # Slot 3 (SIMD/D) is always start=0, extent=D_full. + if kind in ("v_add", "v_sub", "v_mul", + "v_exp", "v_reci", "v_sqrt", + "v_zero", "copy_v_to_v"): + dst_hlir_axes = _axes_of(op.dst.buffer.name) + if dst_hlir_axes is None or len(dst_hlir_axes) != 4: + raise ToPlenaError( + f"v_* lowering: dst {op.dst.buffer.name!r} has no 4D " + f"hlir axes table; got {dst_hlir_axes!r}" + ) + starts: List[Any] = [] + extents: List[int] = [] + for_specs: List[Tuple[Any, int]] = [] + for slot, (role_name, extent) in enumerate(dst_hlir_axes[:3]): + if int(extent) == 1: + starts.append(_tir.IntImm(_INT32, 0)) + extents.append(1) + elif role_name == "cluster": + starts.append(_tir.IntImm(_INT32, 0)) + extents.append(int(extent)) + else: + v = _fresh_var(f"row{slot}") + starts.append(v) + extents.append(1) + for_specs.append((v, int(extent))) + starts.append(_tir.IntImm(_INT32, 0)) + extents.append(int(dst_hlir_axes[3][1])) + + region_args = [ + _hlir.VramRegion( + parent=name, starts=tuple(starts), extents=tuple(extents), + ) + for name in buffer_args + ] + leaf = _hlir.Op( + kind=kind, + buffer_args=region_args, + scalar_args=[], + annotations={"source": f"MultiLaneOp(Elementwise {op.op.value})"}, + ) + if not for_specs: + return leaf + body: List[_hlir.Op] = [leaf] + for v, ext in reversed(for_specs): + body = [_hlir.make_for_op(loop_var=v, extent=ext, body=body)] + return body[0] + + # Legacy flat-offset fallback (nothing currently routes here; kept + # in case a future v_* kind shows up before the migration cleanup). + batch_extents, inner_extent = _split_axes_by_role(op.dst_axes) + + def _build_leaf_flat(row_offset): + off = row_offset if row_offset is not None else _tir.IntImm(_INT32, 0) + raise ToPlenaError(f"unhandled v_* kind {kind!r}") + + if not batch_extents: + return _build_leaf_flat(row_offset=None) + + strides: List[int] = [] + running = inner_extent + for ext in reversed(batch_extents): + strides.append(running) + running *= int(ext) + strides.reverse() + row_vars = [_fresh_var(f"row{i}") for i in range(len(batch_extents))] + total_offset = None + for v, st in zip(row_vars, strides): + term = _tir.Mul(v, _tir.IntImm(_INT32, st)) if st != 1 else v + total_offset = term if total_offset is None else _tir.Add(total_offset, term) + body: List[_hlir.Op] = [_build_leaf_flat(row_offset=total_offset)] + for v, ext in reversed(list(zip(row_vars, batch_extents))): + body = [_hlir.make_for_op(loop_var=v, extent=int(ext), body=body)] + return body[0] + + +_UNARY_TO_ROW_INTRIN = { + UnaryOp.EXP: "row_exp", +} + + +def _lower_per_row_unary( + op: Elementwise, + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, +) -> _hlir.Op: + """Per-row unary VRAM op → single-row ``row_`` leaf. + + Contract matches ``_emit_row_scalar_op_at`` with no fp operand: + buffer_args = [vram_src, vram_dst] + scalar_args = [row_var, lane_var] + Caller's ``for row`` (kernel-written, preserved by mid_ir) is + rendered by the walker. + """ + if op.op not in _UNARY_TO_ROW_INTRIN: + raise ToPlenaError( + f"per-row unary op {op.op!r} has no row_ intrinsic" + ) + intrin = _UNARY_TO_ROW_INTRIN[op.op] + (src,) = op.srcs + if isinstance(src, Broadcast): + raise ToPlenaError( + "per-row unary expects a direct BufferRef src, got Broadcast" + ) + # Pick row_var from the dst's BATCH axis index (mid_ir source of + # truth via op.dst_axes). The kernel-written ``for row`` / + # ``for oh`` / ... wraps this op; the dst index at that axis is + # exactly the bound loop var we need. + row_axis = _row_axis_index_from_axes( + op.dst_axes, ctx=f"per-row unary {intrin} dst", + ) + row_var = _render_idx_as_primexpr(op.dst.indices[row_axis]) + if cluster_axis_name is not None: + lane_var = _make_loop_var(cluster_axis_name) + elif cluster_axis_var is not None: + lane_var = _make_loop_var(cluster_axis_var.name) + else: + lane_var = _fresh_var("lane") + src_region = _build_row_at_region( + src.buffer.name, row_var=row_var, lane_var=lane_var, + ctx=f"per-row unary {intrin} src", + ) + dst_region = _build_row_at_region( + op.dst.buffer.name, row_var=row_var, lane_var=lane_var, + ctx=f"per-row unary {intrin} dst", + ) + return _hlir.Op( + kind=intrin, + buffer_args=[src_region, dst_region], + scalar_args=[], + annotations={"source": f"per-row Elementwise[{op.op.value}]"}, + ) + + +def _lower_multi_lane(mlo: MultiLaneOp, + buf_name_to_hlir: Dict[str, _hlir.Buffer], + lane_count: int) -> _hlir.Op: + """Single MultiLaneOp → single HLIR Op with lane_count scalar. + + ``lane_count`` is the enclosing cluster's extent (e.g. 4 for a + typical MLEN/hlen split). It controls both (a) how many lanes the + HW-side multi-lane op spans (passed via scalar_args[0]) and (b) + the synthetic ``for lane`` extent when the inner op falls back to + a per-lane FPRAM template. + """ + axis_name = mlo.cluster_axis_names[0] if mlo.cluster_axis_names else None + axis_var = mlo.cluster_axis_vars[0] if mlo.cluster_axis_vars else None + inner = mlo.inner + if isinstance(inner, Dma): + return _lower_multi_lane_dma( + inner, lane_count, buf_name_to_hlir, + cluster_axis_name=axis_name, + cluster_axis_var=axis_var, + ) + if isinstance(inner, Gemm): + return _lower_multi_lane_btmm(inner, lane_count, buf_name_to_hlir) + if isinstance(inner, Elementwise): + return _lower_multi_lane_elementwise( + inner, lane_count, buf_name_to_hlir, axis_name, + cluster_axis_var=axis_var, + ) + raise ToPlenaError( + f"unsupported MultiLaneOp inner: {type(inner).__name__}" + ) + + +def _lane_loop_var(cluster_axis_name: Optional[str], + cluster_axis_var: "Optional[VarRef]" = None) -> _tir.Var: + """Pick a loop_var for the synthetic ``for lane`` that wraps a + bare op inside a cluster. Prefer the cluster axis's VarRef (so the + loop_var identity matches the in-buffer phase var); fall back to + the name-cached var, or finally ``"lane"`` for bare ops outside + any cluster.""" + if cluster_axis_var is not None: + return cluster_axis_var.var + if cluster_axis_name is not None: + return _make_loop_var(cluster_axis_name) + return _make_loop_var("lane") + + +def _fresh_var(name: str) -> _tir.Var: + """Allocate a brand-new ``tir.Var`` that is NOT cached. Use for + synthetic loops (row/col templates) whose Var must not collide + with an enclosing same-named loop in the kernel.""" + return _tir.Var(name, _INT32) + + +def _per_lane_stride(buf: _hlir.Buffer, mode: str) -> int: + """Stride (in elements) between consecutive lanes for a buffer. + + Reads ``buf.cluster_dim`` directly: the lane stride is the product + of every shape axis strictly to the right of the cluster dim. + Earlier versions had a ``mode``-keyed fallback that hard-coded + ``shape[1] * shape[2] * shape[3]`` etc.; that legacy path masked + a missing ``cluster_dim``. Any buffer reaching codegen without a + ``cluster_dim`` is a real bug now — raise loudly instead of + silently miscomputing. + """ + shape = [int(d) for d in buf.shape] + if buf.cluster_dim is None: + raise ToPlenaError( + f"_per_lane_stride: buffer {buf.name!r} (mode={mode!r}) has no " + f"cluster_dim; split / view passes must have populated it before " + f"codegen. shape={shape}" + ) + stride = 1 + for axis in range(buf.cluster_dim + 1, len(shape)): + stride *= shape[axis] + return stride + + +_GEMM_ROLE_TO_LABEL: Dict[AxisRole, str] = { + AxisRole.GEMM_M: "M", + AxisRole.GEMM_K: "K", + AxisRole.GEMM_N: "N", +} + + +def _axes_to_dim_roles(axes: List[AxisInfo]) -> Tuple[str, ...]: + """Project a mid_ir per-axis ``AxisInfo`` table onto the gemm + dim-role labels emitters consume. + + Only the matmul-specific roles (M / K / N) keep a distinct label; + everything else (BATCH, CLUSTER, SIMD, BROADCAST, ...) collapses + to ``"_"`` — the emitter only cares about M/K/N positions to drive + instruction selection (M_MM vs M_TMM, M_MV vs M_BTMV) and to look + up extents; other axes contribute via region.starts/extents in + the usual way (lane idx, batch fan-out). + """ + return tuple( + _GEMM_ROLE_TO_LABEL.get(a.role, "_") for a in axes + ) + + +def _align_dim_roles_to_4d(buf_name: str, + mid_axes: List[AxisInfo]) -> Tuple[str, ...]: + """Align a mid_ir per-axis roles table onto the HLIR buffer's + post-expansion 4D shape, returning a 4-tuple of dim-role labels + ("M"/"K"/"N"/"_") suitable for the gemm Region+roles schema. + + Two cases: + + * ``_PAD_INSERTS`` entry → the buffer was rank-padded to 4D by + inserting extent-1 axes at recorded positions. We pad the + roles list at the same positions with ``"_"``. + + * ``_CLUSTER_MODES`` entry → the buffer was cluster-expanded. + We mirror the exact axis placement that + ``_rewrite_ref_for_cluster_mode`` does for starts/extents, + so roles end up at the right *physical* axis after the + lane → BSHD anchoring (row_stack puts lane at axis 0; + col_pack puts it at axis 2). The mid_ir axis at the source + cluster_dim collapses to ``"_"`` (cluster never carries a + gemm role); the leftover BSHD slot is also ``"_"``. + + Returns a 4-tuple. If the mid_ir axes are already rank-4 + (HBM buffer, or a no-op) returns the projection directly. + """ + base = _axes_to_dim_roles(mid_axes) + if len(base) == 4: + return base + if buf_name in _PAD_INSERTS: + inserts = _PAD_INSERTS[buf_name] + out = list(base) + for pos in inserts: + out.insert(pos, "_") + if len(out) != 4: + raise ToPlenaError( + f"_align_dim_roles_to_4d: pad applied to {buf_name!r} but " + f"result rank {len(out)} != 4 (mid_axes={mid_axes!r}, " + f"inserts={inserts!r})" + ) + return tuple(out) + if buf_name in _CLUSTER_MODES: + mode, old_cluster_dim, new_shape = _CLUSTER_MODES[buf_name] + if mode == _MODE_FP_LANE: + return base + if len(base) != 3: + raise ToPlenaError( + f"_align_dim_roles_to_4d: cluster-expanded {buf_name!r} " + f"expects rank-3 mid axes, got {len(base)} ({mid_axes!r})" + ) + if mode == _MODE_BSHD_LIFT or old_cluster_dim is None: + # mid_ir (?, S, D) -> BSHD (?, S, 1, D) + return (base[0], base[1], "_", base[2]) + new_lane_dim = _CLUSTER_MODE_NEW_LANE_DIM[mode] + if new_lane_dim is None: + return (base[0], base[1], "_", base[2]) + # Same anchoring as _rewrite_ref_for_cluster_mode: + # * non-lane sources keep order: first non-lane mid axis -> out[1] (S), + # second non-lane mid axis -> out[3] (D) + # * lane slot at new_lane_dim is "_" (cluster never carries a gemm role) + # * remaining BSHD slot is "_" + non_lane_sources = [i for i in range(3) if i != old_cluster_dim] + if len(non_lane_sources) != 2: + raise ToPlenaError( + f"_align_dim_roles_to_4d: unexpected non-lane source count " + f"{len(non_lane_sources)} for {buf_name!r}" + ) + s_src, d_src = non_lane_sources + out = ["_"] * 4 + out[1] = base[s_src] + out[3] = base[d_src] + # new_lane_dim already "_" from initialisation; leftover slot too. + return tuple(out) + # Unknown buffer (or no rank change needed): pad with leading "_" + # placeholders defensively. + deficit = 4 - len(base) + if deficit < 0: + raise ToPlenaError( + f"_align_dim_roles_to_4d: {buf_name!r} mid axes rank " + f"{len(base)} > 4 with no pad/cluster record" + ) + return tuple(["_"] * deficit) + base + + +def _make_onchip_region( + name: str, + buf: "_hlir.Buffer", + starts: Tuple[Any, ...], + extents: Tuple[int, ...], +): + """Build a Vram/MramRegion based on buffer scope.""" + if buf.scope == _scope.MRAM: + return _hlir.MramRegion(parent=name, starts=starts, extents=extents) + return _hlir.VramRegion(parent=name, starts=starts, extents=extents) + + +def _gemm_full_region( + ref: "BufferRef", + buf: "_hlir.Buffer", + *, + lane_axis_name: Optional[str] = None, +) -> "Any": + """Build a 4D Vram/MramRegion that covers the *whole* buffer. + + ``starts`` are all zero (or ``lane_var`` on the cluster axis when + the caller's gemm sits inside a CLUSTER and the buffer's + ``cluster_dim`` marks that axis); ``extents`` are the buffer's + full per-dim extents from the mid_ir-side shape. The emitter walks + M_tiles / K_tiles / N from those extents using the parallel + ``dim_roles`` tuple, so this region uniformly describes the gemm + workspace whether it's BTMM-shaped (single mlen-tile per axis) or + multi-tile linear. + + Cluster handling: when ``lane_axis_name`` is provided (the gemm + is wrapped in a CLUSTER → ``for lane:`` in mid_ir), each per-lane + issue lives in ``starts[cluster_dim] = lane_var`` with the lane + axis's extent reduced to 1 (a single lane per leaf op). Without + a cluster context, every start is 0. + """ + cluster_dim = getattr(buf, "cluster_dim", None) + shape = [int(d) for d in buf.shape] + starts: List[Any] = [] + extents: List[int] = [] + for i, dim_extent in enumerate(shape): + if (cluster_dim is not None + and i == cluster_dim + and lane_axis_name is not None): + starts.append(_make_loop_var(lane_axis_name)) + extents.append(1) + else: + starts.append(_tir.IntImm(_INT32, 0)) + extents.append(int(dim_extent)) + return _make_onchip_region( + ref.buffer.name, buf, tuple(starts), tuple(extents), + ) + + +def _lower_bare_per_head_gemm( + op: Gemm, + cluster_extent: Optional[int], + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, + buf_name_to_hlir: Optional[Dict[str, _hlir.Buffer]] = None, + lane_modes: Optional[Dict[str, str]] = None, +) -> _hlir.Op: + """Bare (non-async) per-head matmul / mv → region-based HLIR op. + + New region schema: + buffer_args = [a_region, b_region, c_region] # Vram/MramRegion + scalar_args = [a_dim_roles, b_dim_roles, c_dim_roles] + each is a 4-tuple of "M"/"K"/"N"/"_" labels aligned with + the parent buffer's 4D physical shape. + + Region.starts encodes the per-lane gemm position: when this op + is wrapped in a CLUSTER (``cluster_axis_name`` set), the + cluster-marked axis of each region gets ``lane_var`` at start + (extent=1 → one lane per leaf op). The emitter walks the lane + axis using its tile_layout stride; the mid_ir layer just hands + over the logical (b, s, h, d) position, without doing any + physical stride math. + + transpose_b is dropped as a flag — b's dim_roles tuple encodes + "K before N" (K-inner, standard) vs "N before K" (transpose_b); + the emitter decides M_MM vs M_TMM from that ordering. + + Falls back to ``kind="mv"`` (instead of "matmul") when the LHS + has only one M-row (decode-style P @ V). + """ + a_buf = buf_name_to_hlir[op.a.buffer.name] if buf_name_to_hlir else None + b_buf = buf_name_to_hlir[op.b.buffer.name] if buf_name_to_hlir else None + c_buf = buf_name_to_hlir[op.c.buffer.name] if buf_name_to_hlir else None + if a_buf is None or b_buf is None or c_buf is None: + raise ToPlenaError( + f"per-head Gemm: missing HLIR buffer for one of a/b/c " + f"(a={op.a.buffer.name!r}, b={op.b.buffer.name!r}, " + f"c={op.c.buffer.name!r})" + ) + + # LHS rows == 1 → matrix-vector. Read off ``op.a_axes`` GEMM_M + # extent (authoritative); fall back to buffer shape only if axes + # are missing. + a_M_extent: Optional[int] = None + if op.a_axes: + _, a_m_info = _find_role(op.a_axes, AxisRole.GEMM_M) + if a_m_info is not None: + a_M_extent = int(a_m_info.extent) + if a_M_extent is not None: + use_mv = a_M_extent == 1 + else: + use_mv = _logical_rows_from_buf(op.a) == 1 + + inside_cluster = ( + cluster_extent is not None and cluster_axis_name is not None + ) + lane_axis_name = cluster_axis_name if inside_cluster else None + + a_region = _gemm_full_region(op.a, a_buf, lane_axis_name=lane_axis_name) + b_region = _gemm_full_region(op.b, b_buf, lane_axis_name=lane_axis_name) + c_region = _gemm_full_region(op.c, c_buf, lane_axis_name=lane_axis_name) + a_roles = _align_dim_roles_to_4d(op.a.buffer.name, op.a_axes) + b_roles = _align_dim_roles_to_4d(op.b.buffer.name, op.b_axes) + c_roles = _align_dim_roles_to_4d(op.c.buffer.name, op.c_axes) + + annotations: Dict[str, Any] = { + "source": ( + "per-head Gemm(rows=1) inside cluster" if (use_mv and inside_cluster) + else "per-head Gemm(rows=1)" if use_mv + else "per-head Gemm(overwrite) inside cluster" if inside_cluster + else "per-head Gemm(overwrite)" + ), + } + + return _hlir.Op( + kind="mv" if use_mv else "matmul", + buffer_args=[a_region, b_region, c_region], + scalar_args=[a_roles, b_roles, c_roles], + annotations=annotations, + ) + + +def _row_axis_index_of_buf(name: str, shape, cluster_dim: Optional[int]) -> int: + """Index of the logical row axis in a buffer's shape. + + Skips the innermost axis (column / D / hlen) and the cluster axis + (lane); the first remaining axis is rows. ``cluster_dim`` is the + explicit marker propagated from pass_3_split. + + NOTE: this still derives the rows axis from buffer shape + cluster + dim rather than the op's per-axis ``axes`` table. Equivalent in + every shape layout we ship today, but the cleanest replacement is + ``next(i for i, a in enumerate(op_axes) if a.role == AxisRole.BATCH)`` + once every caller has an op handle in scope. Left as-is to limit + blast radius — flag if a future buffer layout pushes rows past a + non-cluster, non-innermost axis the per-axis table sees but the + cluster_dim heuristic doesn't.""" + shape = [int(d) for d in shape] + if len(shape) < 2: + raise ToPlenaError( + f"buffer {name!r} rank {len(shape)} has no row axis" + ) + inner = len(shape) - 1 + for axis in range(len(shape)): + if axis == inner: + continue + if cluster_dim is not None and axis == cluster_dim: + continue + return axis + raise ToPlenaError( + f"buffer {name!r} shape={shape} cluster_dim={cluster_dim}: " + "no row axis available (only inner + cluster?)" + ) + + +def _row_axis_index(ref: BufferRef) -> int: + """Row axis index for a mid_ir BufferRef. Delegates to the + buffer-only form below.""" + return _row_axis_index_of_buf( + ref.buffer.name, ref.buffer.shape, ref.buffer.cluster_dim, + ) + + +def _row_axis_index_from_axes( + axes: List[AxisInfo], + *, + ctx: str, +) -> int: + """Pick the rows axis from an op's per-axis role table. + + The rows axis is the (last) ``BATCH`` axis with the largest + extent. SIMD / CLUSTER / REDUCE / BROADCAST / GEMM_* are + skipped. Caller passes an ``op.dst_axes`` / ``op.src_axes[i]`` + so the answer comes from mid_ir's authoritative role table, + not from buffer shape + cluster_dim heuristics. + """ + rows_axis = -1 + rows_extent = -1 + for i, a in enumerate(axes): + if a.role != AxisRole.BATCH: + continue + if int(a.extent) > rows_extent: + rows_extent = int(a.extent) + rows_axis = i + if rows_axis < 0: + raise ToPlenaError( + f"{ctx}: no BATCH axis in axes {axes!r}; cannot locate " + f"rows dimension" + ) + return rows_axis + + +def _row_footprint(ref: BufferRef) -> int: + """How many rows does this op-local ref actually touch? + + ``Slice`` on the row axis → full buffer row extent (op covers the + whole row stack). Concrete index (var / int) → 1 (op acts on a + single row picked by the enclosing for-row). + """ + axis = _row_axis_index(ref) + idx = ref.indices[axis] + if isinstance(idx, Slice): + return int(ref.buffer.shape[axis]) + if isinstance(idx, dict) and idx.get("op") == "ranged_slice": + return int(idx["args"][1]) + return 1 + + +def _logical_rows_from_buf(ref: BufferRef) -> int: + """Recover the kernel's logical row count from a mid_ir BufferRef. + + Uses the explicit ``cluster_dim`` marker on the BufferDef to skip + the lane axis; the rows axis is the remaining non-innermost dim.""" + shape = [int(d) for d in ref.buffer.shape] + if len(shape) < 2: + return 1 + try: + return int(shape[_row_axis_index(ref)]) + except ToPlenaError: + # Defensive fallback (shouldn't trigger in practice). + return shape[-2] + + +def _render_ref_with_role_axes( + ref: "BufferRef", + axes: List[AxisInfo], + *, + row_var, + lane_var, + ctx: str, +) -> Tuple[Any, ...]: + """Render a mid_ir ``BufferRef`` as an HLIR index tuple, using its + per-axis role table to resolve any ``Slice`` ("the op covers this + whole axis") into the right loop var. + + Used for Reduce dst, whose axes (per mid_ir) are exactly the dst + fragment's surviving dims after the REDUCE collapse — so each axis + is either: + * ``CLUSTER`` -> ``lane_var`` (per-lane fan-out) + * ``BATCH`` -> ``row_var`` (per-row fan-out) + + User-pinned indices (``MEAN_SUM[r]`` with concrete ``r``) bypass + the Slice rule and are rendered as-is — kernels that explicitly + thread a loop var still work. + + SIMD / REDUCE roles never appear on a Reduce dst by construction + (REDUCE is collapsed; SIMD is what made the src vectorisable); + seeing either here means an upstream pass produced inconsistent + axes, so we raise. + """ + if len(ref.indices) != len(axes): + raise ToPlenaError( + f"{ctx}: rank mismatch — ref {ref.buffer.name!r} has " + f"{len(ref.indices)} indices but axes table has {len(axes)} " + f"entries (indices={list(ref.indices)!r}, axes={axes!r})" + ) + out: List[Any] = [] + for axis_idx, (raw_idx, axis) in enumerate(zip(ref.indices, axes)): + if not isinstance(raw_idx, Slice): + out.append(_render_idx_as_primexpr(raw_idx)) + continue + if axis.role == AxisRole.CLUSTER: + out.append(lane_var) + elif axis.role == AxisRole.BATCH: + out.append(row_var) + else: + raise ToPlenaError( + f"{ctx}: axis {axis_idx} on ref {ref.buffer.name!r} is a " + f"Slice with role {axis.role!r}; expected CLUSTER or " + f"BATCH (SIMD / REDUCE shouldn't survive on a reduce dst)." + ) + return tuple(out) + + +def _build_row_at_region( + buf_name: str, + *, + row_var, + lane_var, + ctx: str, +) -> "_hlir.VramRegion": + """Build a ``VramRegion`` for a single-row row_*_at op. + + ``row_*_at`` ops touch exactly one logical row (one (b, s, h) + cell), so non-D extents are always 1. The starts are placed per + axis ROLE so the buffer's actual physical layout is respected: + * SIMD axis (innermost, always D) -> start = 0, extent = D_full + * CLUSTER axis (lane carrier) -> start = lane_var, extent = 1 + * BATCH axis with the largest extent (the "rows" axis) + -> start = row_var, extent = 1 + * any other BATCH axis (extent-1 placeholders, e.g. pad-to-4D + B=1 or row_stack-spare H=1) -> start = 0, extent = 1 + + This is needed because different lane-fusion modes put the + cluster axis at different physical positions: + * col_pack: (B=1, S, lane=H, D_narrow) cluster_dim=2 + * row_stack: (lane=B, S, H=1, MLEN) cluster_dim=0 + A schema that hard-codes ``starts=(0, row, lane, 0)`` would + misplace the lane var in row_stack mode and corrupt all reads. + """ + axes = _axes_of(buf_name) + if axes is None or len(axes) != 4: + raise ToPlenaError( + f"{ctx}: buffer {buf_name!r} has no 4D hlir axes; got {axes!r}" + ) + starts: List[Any] = [] + extents: List[int] = [] + rows_slot: Optional[int] = None + rows_extent: int = -1 + # First pass: find the rows axis (largest batch). + for i, (role, extent) in enumerate(axes): + if role == "batch" and int(extent) > rows_extent: + rows_extent = int(extent) + rows_slot = i + for i, (role, extent) in enumerate(axes): + if role == "simd": + starts.append(_tir.IntImm(_INT32, 0)) + extents.append(int(extent)) + elif role == "cluster": + starts.append(lane_var) + extents.append(1) + elif i == rows_slot and int(extent) > 1: + starts.append(row_var) + extents.append(1) + else: + # Degenerate batch placeholder (extent 1 or smaller batch). + starts.append(_tir.IntImm(_INT32, 0)) + extents.append(1) + return _hlir.VramRegion( + parent=buf_name, + starts=tuple(starts), + extents=tuple(extents), + ) + + +def _lower_bare_reduce(op: Reduce, + cluster_extent: Optional[int], + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, + ) -> _hlir.Op: + """Bare reduce → single-row ``row_reduce_*_at`` leaf op, + optionally wrapped in a synthesised ``for row``. + + Contract: one HLIR leaf op = one HW instruction. Decision uses + the src ref's row footprint (reduce collapses the inner axis; + iterating over rows yields one HW reduce per row): + + * Slice on the row axis → reduce covers every row; synthesise + ``for row``. + * Concrete index → reduce acts on a single row picked by the + kernel's outer ``for row``, which the walker renders. + + The enclosing CLUSTER ``ParallelAxis`` (lowered to ``for lane`` + by the walker) binds the lane var. + """ + if op.op not in _REDUCE_TO_ROW_AT: + raise ToPlenaError(f"unsupported reduce op {op.op!r}") + intrin = _REDUCE_TO_ROW_AT[op.op] + # Lane axis only exists when a cluster wraps the op; otherwise + # there is no lane dim and the head index in VRAM addressing is 0. + if cluster_axis_name is not None: + lane_var: _tir.PrimExpr = _make_loop_var(cluster_axis_name) + else: + lane_var = _tir.IntImm(_INT32, 0) + row_footprint = _row_footprint(op.src) + if row_footprint > 1: + row_var: _tir.PrimExpr = _fresh_var("row") + wrap_rows = row_footprint + else: + # Single-row reduce: pick row var from the src's row-axis + # index via the op's per-axis role table (mid_ir's + # authoritative info). The enclosing kernel-written for + # already bound the index; we just thread its VarRef / + # IntImm through the rendered tree. + row_axis = _row_axis_index_from_axes( + op.src_axes, ctx=f"reduce[{op.op.value}] src", + ) + row_var = _render_idx_as_primexpr(op.src.indices[row_axis]) + wrap_rows = None + fp_addr = _hlir.BufferElement( + buffer=op.dst.buffer.name, + indices=_render_ref_with_role_axes( + op.dst, op.dst_axes, + row_var=row_var, lane_var=lane_var, + ctx=f"reduce[{op.op.value}] dst", + ), + ) + # Region schema: starts placed per axis ROLE (cluster slot gets + # the lane var, the largest non-cluster batch slot gets the row + # var, everything else is 0). One mlen-row covers all of D. + src_region = _build_row_at_region( + op.src.buffer.name, row_var=row_var, lane_var=lane_var, + ctx="reduce src", + ) + leaf = _hlir.Op( + kind=intrin, + buffer_args=[src_region], + scalar_args=[fp_addr], + annotations={"source": f"bare Reduce[{op.op.value}]"}, + ) + if wrap_rows is None: + return leaf + return _hlir.make_for_op(loop_var=row_var, extent=wrap_rows, body=[leaf]) + + +def _lower_bare_broadcast_elementwise( + op: Elementwise, + cluster_extent: Optional[int], + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, +) -> _hlir.Op: + """Bare elementwise with a per-row FPRAM Broadcast src + → single-row ``row__fp`` leaf op. + + Contract: one HLIR leaf op = one HW V_*_VF instruction over one + row. Whether we wrap the leaf in a synthesised ``for row`` is + decided by the op's row-axis footprint on its dst ref: + + * Slice (op covers the whole row stack) → fold absorbed the + kernel's outer ``for row``; we re-synthesise the loop here so + the HW still gets one issue per row. + * Concrete index (var / int) → the kernel's outer + ``for row`` is still in mid_ir as a ``For`` and the walker + will render it; we just emit the leaf. + """ + if op.op not in _ROW_FP_BINOP_TO_INTRIN: + raise ToPlenaError( + f"unsupported broadcast Elementwise op {op.op!r}" + ) + intrin = _ROW_FP_BINOP_TO_INTRIN[op.op] + bcast_src = None + bcast_axes = None + direct_src = None + for s, s_axes in zip(op.srcs, op.src_axes): + if isinstance(s, Broadcast): + bcast_src = s + bcast_axes = s_axes + else: + direct_src = s + if bcast_src is None or direct_src is None: + raise ToPlenaError( + "broadcast Elementwise expected one BufferRef + one Broadcast src" + ) + row_footprint = _row_footprint(op.dst) + if row_footprint > 1: + row_var = _fresh_var("row") + wrap_rows = row_footprint + else: + # Single-row leaf: pick the row var from the dst's row-axis + # index via the op's per-axis role table (mid_ir source of + # truth). The enclosing kernel-written for already bound it + # — int IntImm or VarRef rendered with name-cached identity + # so symbol_table lookups match. + row_axis = _row_axis_index_from_axes( + op.dst_axes, ctx=f"row_*_fp[{op.op.value}] dst", + ) + row_var = _render_idx_as_primexpr(op.dst.indices[row_axis]) + wrap_rows = None + # Lane axis only exists when a cluster wraps the op; otherwise the + # FP buffer has no lane dim and the VRAM dst's head index is 0. + if cluster_axis_name is not None: + lane_var: _tir.PrimExpr = _make_loop_var(cluster_axis_name) + else: + lane_var = _tir.IntImm(_INT32, 0) + # Resolve fp_addr indices via the mid_ir src_axes role table — + # same approach as ``_lower_bare_reduce``: a Slice on a CLUSTER + # axis gets ``lane_var``, on a BATCH axis ``row_var``; concrete + # indices (the kernel pinned them) are rendered as-is. This + # replaces the old packed / non-packed branch that hard-coded + # ``(lane_var, row_var)`` on packed-head and silently rendered + # Slice -> 0 on non-packed paths. + fp_addr = _hlir.BufferElement( + buffer=bcast_src.src.buffer.name, + indices=_render_ref_with_role_axes( + bcast_src.src, bcast_axes, + row_var=row_var, lane_var=lane_var, + ctx=f"row_*_fp[{op.op.value}] bcast src", + ), + ) + src_region = _build_row_at_region( + direct_src.buffer.name, row_var=row_var, lane_var=lane_var, + ctx="row_*_fp src", + ) + dst_region = _build_row_at_region( + op.dst.buffer.name, row_var=row_var, lane_var=lane_var, + ctx="row_*_fp dst", + ) + leaf = _hlir.Op( + kind=intrin, + buffer_args=[src_region, dst_region], + scalar_args=[fp_addr], + annotations={"source": f"bare Elementwise[{op.op.value}] w/ broadcast"}, + ) + if wrap_rows is None: + return leaf + return _hlir.make_for_op(loop_var=row_var, extent=wrap_rows, body=[leaf]) + + +def _fp_buffer_element_from_ref(ref: BufferRef) -> _hlir.BufferElement: + """Build an HLIR ``BufferElement`` from a mid_ir ``BufferRef`` whose + indices already address a single scalar position (i.e. the enclosing + mid_ir loops bind every index var). All indices are passed through + ``_render_idx_as_primexpr`` so loop-var strings become the same + ``tir.Var`` objects the surrounding HLIR ``for`` ops use.""" + return _hlir.BufferElement( + buffer=ref.buffer.name, + indices=tuple(_render_idx_as_primexpr(i) for i in ref.indices), + ) + + +def _lower_bare_fp_scalar_elementwise( + op: Elementwise, + cluster_extent: Optional[int], + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, +) -> _hlir.Op: + """Bare elementwise on FPRAM rank-1 per-lane state → ``for lane: + fp__at()``. + + Two shapes show up here: + + * ``M_OLD[row] = M_INIT[row]`` — single-scalar update, already + nested inside a kernel-written ``for row`` (rendered to a HLIR + for op by the walker). Emit a single ``fp__at`` leaf. + + * ``for m in T.Parallel(MLEN): A_sh_acc[m] = ...`` — fold + absorbed the parallel loop into ``axis=-1, size=N``. FPRAM has + no vector op, so we re-emit a ``for`` over ``size`` issuing one + ``S_*_FP`` per element. The loop var name comes from the + mid_ir Elementwise's dst index (preserving Var identity with + the rendered indices below). + """ + if op.op in _FP_AT_BINOP_TO_INTRIN: + intrin = _FP_AT_BINOP_TO_INTRIN[op.op] + elif op.op in _FP_AT_UNARY_TO_INTRIN: + intrin = _FP_AT_UNARY_TO_INTRIN[op.op] + # COPY with srcs=[] is fold's zero-fill sentinel — route to the + # FPRAM-side zero op so the emitter sees only the dst address. + if op.op == UnaryOp.COPY and not op.srcs: + intrin = "fp_zero_at" + else: + raise ToPlenaError( + f"unsupported FPRAM Elementwise op {op.op!r}" + ) + # _emit_fp_scalar_op_at signature: for unary/copy, scalar_args = + # (src_addr, dst_addr); for binary, (lhs_addr, rhs_addr, dst_addr). + src_elements = [_fp_buffer_element_from_ref(s) for s in op.srcs] + dst_element = _fp_buffer_element_from_ref(op.dst) + leaf = _hlir.Op( + kind=intrin, + buffer_args=[], + scalar_args=src_elements + [dst_element], + annotations={"source": f"bare FPRAM Elementwise[{op.op.value}]"}, + ) + # Unroll the SIMD axis when fold absorbed a T.Parallel into the op: + # FPRAM has no vector ISA, so emit one S_*_FP per element. + if op.axis == -1 and op.size > 1: + # Identify the SIMD-axis loop var: the VarRef idx in dst that + # isn't the cluster phase axis. With a cluster, FP_LANE expansion + # gives dst rank 2 — (lane_var, row_var); without a cluster it's + # rank 1 — (row_var,). Skip the cluster axis (if present) and the + # remaining single VarRef is the SIMD axis. + candidates = [ + i for i in op.dst.indices + if isinstance(i, VarRef) + and (cluster_axis_name is None or i.name != cluster_axis_name) + ] + if len(candidates) != 1: + raise ToPlenaError( + f"FPRAM elementwise with axis=-1 size={op.size} expects " + f"exactly one non-cluster VarRef in dst; got dst " + f"{op.dst.buffer.name!r} indices {list(op.dst.indices)!r} " + f"(cluster_axis={cluster_axis_name!r})" + ) + idx = candidates[0] + loop_var = _make_loop_var(idx.name) + return _hlir.make_for_op(loop_var=loop_var, extent=op.size, body=[leaf]) + return leaf + + +# --------------------------------------------------------------------------- +# Walker +# --------------------------------------------------------------------------- + + +_MULTI_LANE_OP_KINDS = frozenset({ + "dma_h2v", "dma_h2m", "dma_v2h", + "dma_h2v_slice", "dma_h2m_slice", "dma_v2h_slice", + "btmm", "btmv", + "v_add", "v_sub", "v_mul", "v_exp", "v_reci", + "v_sqrt", "v_zero", + # vram→vram copy: one V_ADD_VF (f0=0) spans a full MLEN-wide row; + # under sync wrap the by_phase index is already folded to 0, so the + # op covers all cluster lanes in one issue — don't re-wrap in a + # synthetic for-lane loop. + "copy_v_to_v", + # vram↔fpram transfer: the ``v_fp_transfer_slice_*`` ops carry + # the whole logical region; isa_emit splits it into per-MLEN + # S_MAP_FP_V / S_MAP_V_FP issues. Each issue natively spans all + # cluster lanes (sync wrap zeroes the cluster phase axis on both + # sides), so the op is multi-lane in hardware and must not be + # re-wrapped in a synthetic for-lane loop. + "v_fp_transfer_slice_v_to_fp", "v_fp_transfer_slice_fp_to_v", +}) + + +def _op_fires_once_across_lanes(op: _hlir.Op) -> bool: + """True if ``op`` is a single HW instruction that already spans + every lane natively (multi-lane in hardware). Such ops must NOT + be wrapped in a synthetic ``for lane`` loop — that would re-issue + the same multi-lane instruction lane_count times.""" + return op.kind in _MULTI_LANE_OP_KINDS + + +def _wrap_per_lane_ops_with_for_lane( + body: List[_hlir.Op], + lane_var: "_tir.Var", + lane_extent: int, +) -> List[_hlir.Op]: + """Emit the body of a CLUSTER ParallelAxis. Each per-lane op gets + its OWN ``for lane in cluster_extent`` wrapper so the lane axis is + threaded one instruction at a time (matching the kernel's program + order). Multi-lane HW ops stay un-wrapped: they fire once and + cover every lane natively. + + Structural ops (``for`` nodes — typically the kernel's ``for row`` + or ``for kv_block``) are recursed into: their body is rewritten + with the same rule so per-lane ops inside still get their own + for-lane wrapper. + """ + out: List[_hlir.Op] = [] + for op in body: + if op.kind == "for": + # Recurse: inner body may contain per-lane ops that need + # their own per-op for-lane wrapper. + inner = _wrap_per_lane_ops_with_for_lane( + op.body or [], lane_var, lane_extent, + ) + out.append(_hlir.Op( + kind="for", + buffer_args=list(op.buffer_args), + scalar_args=list(op.scalar_args), + annotations=dict(op.annotations), + body=inner, + )) + elif _op_fires_once_across_lanes(op): + out.append(op) + else: + out.append(_hlir.make_for_op( + loop_var=lane_var, extent=lane_extent, body=[op], + )) + return out + + +def _walk_stmts(stmts: List[Stmt], + buf_name_to_hlir: Dict[str, _hlir.Buffer], + cluster_extent: Optional[int], + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, + ) -> List[_hlir.Op]: + out: List[_hlir.Op] = [] + for s in stmts: + out.extend(_walk_stmt(s, buf_name_to_hlir, cluster_extent, + cluster_axis_name, cluster_axis_var)) + return out + + +def _walk_stmt(stmt: Stmt, + buf_name_to_hlir: Dict[str, _hlir.Buffer], + cluster_extent: Optional[int], + cluster_axis_name: Optional[str] = None, + cluster_axis_var: "Optional[VarRef]" = None, + ) -> List[_hlir.Op]: + if isinstance(stmt, ParallelAxis): + if stmt.kind == ParallelKind.CLUSTER: + # CLUSTER: each per-lane op gets its own ``for lane in + # cluster_extent`` wrapper (matching the kernel's program + # order); multi-lane HW ops (DMA / BTMM / V-machine whole- + # tile) stay un-wrapped because they fire across all lanes + # natively. We recurse into structural ``for`` nodes so + # per-lane ops nested inside (e.g. inside a kernel + # ``for row``) still get a per-op for-lane wrapper. + body = _walk_stmts(stmt.body, buf_name_to_hlir, stmt.extent, + stmt.axis_name, stmt.axis_var) + lane_var = _axis_loop_var(stmt) + return _wrap_per_lane_ops_with_for_lane( + body, lane_var, stmt.extent, + ) + # threadIdx.* axes are a GPU abstraction: PLENA HW has no + # thread-level dispatch, every instruction is implicitly + # broadcast across the SIMD width. Unwrap them so the body + # runs once, not threads-many times. + if (stmt.thread_tag is not None + and stmt.thread_tag.startswith("threadIdx.")): + return _walk_stmts(stmt.body, buf_name_to_hlir, cluster_extent, + cluster_axis_name, cluster_axis_var) + # BLOCK_IDX or LOGICAL_GRID → flatten to a serial for. + body = _walk_stmts(stmt.body, buf_name_to_hlir, cluster_extent, + cluster_axis_name, cluster_axis_var) + return [_hlir.make_for_op( + loop_var=_axis_loop_var(stmt), + extent=stmt.extent, body=body, + )] + if isinstance(stmt, For): + body = _walk_stmts(stmt.body, buf_name_to_hlir, cluster_extent, + cluster_axis_name, cluster_axis_var) + for_op = _hlir.make_for_op( + _for_loop_var(stmt), stmt.extent, body=body, + ) + # ``parallel`` mid-IR kind = user wrote T.Parallel but the + # body didn't fold into a vector op. Treat it downstream + # like a serial loop (HW C_LOOP_START/END) but mark + # ``order_independent`` so the backend can skip the per-iter + # idx-slot read/write and run the hw counter as the lvar + # directly. See mir_to_isa._emit_loop_serial. + if stmt.kind == "parallel": + for_op.annotations["loop_kind"] = "serial" + for_op.annotations["order_independent"] = True + else: + for_op.annotations["loop_kind"] = stmt.kind + return [for_op] + if isinstance(stmt, Async): + # By pass_5 every Async should be MultiLaneOp; if it lingers, + # walk through. + return _walk_stmts(stmt.body, buf_name_to_hlir, cluster_extent, + cluster_axis_name, cluster_axis_var) + if isinstance(stmt, MultiLaneOp): + # The actual cluster extent comes from the enclosing CLUSTER + # ParallelAxis (forwarded as ``cluster_extent``). Pass it down + # so per-op lowering helpers can use it directly — both for + # the HW-side lane_count scalar and for any synthetic + # ``for lane`` they wrap around per-lane FPRAM ops. + actual_lane = cluster_extent or 1 + return [_lower_multi_lane(stmt, buf_name_to_hlir, actual_lane)] + if isinstance(stmt, Dma): + # Bare Dma — shouldn't normally happen post-pipeline, but + # support it as a single-lane DMA. + return [_lower_multi_lane_dma(stmt, 1, buf_name_to_hlir)] + if isinstance(stmt, Gemm): + if stmt.kind == "btmm": + # Shouldn't be bare; treat as single-lane btmm. + return [_lower_multi_lane_btmm(stmt, 1, buf_name_to_hlir)] + return [_lower_bare_per_head_gemm( + stmt, cluster_extent, cluster_axis_name, + cluster_axis_var=cluster_axis_var, + buf_name_to_hlir=buf_name_to_hlir, lane_modes=_LANE_MODES, + )] + if isinstance(stmt, Reduce): + return [_lower_bare_reduce(stmt, cluster_extent, cluster_axis_name, + cluster_axis_var=cluster_axis_var)] + if isinstance(stmt, Elementwise): + has_broadcast = any(isinstance(s, Broadcast) for s in stmt.srcs) + if has_broadcast: + return [_lower_bare_broadcast_elementwise( + stmt, cluster_extent, cluster_axis_name, + cluster_axis_var=cluster_axis_var, + )] + dst_scope = buf_name_to_hlir[stmt.dst.buffer.name].scope + if dst_scope == _scope.FPRAM: + # Per-lane FPRAM scalar update (M_OLD[row] = M_INIT[row] + # etc.). Lower to a ``for lane: fp__at`` loop — the + # ``row`` loop is the enclosing mid_ir For (already a HLIR + # for op by now). + return [_lower_bare_fp_scalar_elementwise( + stmt, cluster_extent, cluster_axis_name, + cluster_axis_var=cluster_axis_var, + )] + # Pure elementwise that wasn't wrapped (shouldn't happen if + # pass_4 ran). Treat as single-lane multi_lane. + return [_lower_multi_lane_elementwise(stmt, cluster_extent or 1, buf_name_to_hlir)] + if isinstance(stmt, RawStore): + raise ToPlenaError( + f"RawStore lowering is not implemented yet — fold did not " + f"recognise the op pattern. dst={stmt.dst.buffer.name}" + f"{list(stmt.dst.indices)} := {stmt.value!r}" + ) + raise ToPlenaError(f"unhandled mid_ir stmt {type(stmt).__name__}") + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: MidFunc, + build_dir: Optional[Path] = None, + mlen: int = 64) -> _hlir.HLIRModule: + """Lower a MidFunc to HLIRModule. + + ``mlen`` is the hardware vector lane width (V_*_V row width). It is + stashed in ``_HW_MLEN`` so lowerings that emit per-row HW vector + ops (vram→vram copy etc.) can stride by it without having to + reverse-engineer it from buffer shapes. + + If ``build_dir`` is given, write a ``.midir.txt`` snapshot + there before lowering — useful for diff against the legacy pipeline + or for post-mortem when HLIR looks wrong. + """ + global _HW_MLEN + _HW_MLEN = int(mlen) + _VAR_CACHE.clear() + _LANE_MODES.clear() + _LANE_AXIS_INFO.clear() + _BUFFER_HLIR_AXES.clear() + # Register the split-axis form for each logical lane axis. mid_ir + # carries ``lane_axes`` (one per cluster) and ``cluster_counts``; + # the original lane var (as a ``VarRef``) appears in BufferRef + # indices as the un-split logical view (kept on global refs whose + # view pass skipped), and must expand to ``phase_var + number_var + # * count`` for ISA materialisation. + # + # Look up each pair (original VarRef, phase VarRef, number VarRef, + # count) by walking the body for matching CLUSTER ParallelAxes. + # Split records ``original_axis_var`` on both the CLUSTER (phase) + # and its enclosing number axis, with parent_grid_axis_name on the + # CLUSTER pointing at the number axis by name. + _populate_lane_axis_info( + func, + lane_axes=getattr(func, "lane_axes", []) or [], + cluster_counts=getattr(func, "cluster_counts", []) or [], + ) + if build_dir is not None: + build_dir = Path(build_dir) + build_dir.mkdir(parents=True, exist_ok=True) + dump_path = build_dir / f"{func.name}.midir.txt" + dump_path.write_text(format_func(func)) + + # Use-driven scope overrides (e.g. Gemm B operands → MRAM). + overrides = _infer_scope_overrides(func) + # Lane-expansion modes for each non-global buffer (COL_PACK / + # ROW_STACK / FP_LANE / BSHD_LIFT). Used to reshape buffers and + # fold ref indices into canonical 4D BSHD form. + # + # If the kernel didn't go through cluster fusion (no lane axes, or + # every non-global buffer already covers a full MLEN-wide vector), + # ``split`` and friends were no-ops and buffers still have their + # as-written shapes. Lane expansion in that case would force an + # extra cluster dim onto already-correct shapes and crash + # ``_expand_buffer_shape_with_cluster``'s rank check. Skip it. + cluster_skipped = should_skip_cluster(func) + if cluster_skipped: + lane_modes: Dict[str, str] = {} + lane_count = 1 + else: + lane_modes = _infer_lane_modes(func) + lane_count = func.cluster_counts[0] if func.cluster_counts else 1 + _LANE_MODES.update(lane_modes) + + # Build buffer table. Track per-buffer ref-rewrite recipe so refs + # in the op stream can be transformed into the buffer's post- + # expansion 4D coordinate system after the walk. + # + # Two sources of rank growth, both needing to keep refs in sync + # with their buffers: + # * pad-to-4D path: 1D/2D author-declared on-chip buffers got + # extent-1 axes inserted; refs need the same fills (start=0, + # extent=1) at those positions. + # * cluster-expand path: rank-3 ``(LANE, S, D)`` from + # ``split._grow_buffer`` got permuted to 4D BSHD; refs are + # still rank 3 (view pass prepended phase) and need the + # cluster-mode-specific permutation in + # ``_CLUSTER_REF_REWRITE``. + # Kernel-wide layout hint (``T.func_attr({"plena.layout": "NCHW"})``) + # stamped onto every HLIR Buffer so downstream stride math picks the + # right axis. Default ``BSHD`` matches what the rest of the pipeline + # assumes when no hint is provided. + kernel_layout = "BSHD" + if func.attrs is not None and "plena.layout" in func.attrs: + kernel_layout = str(func.attrs["plena.layout"]) + + buf_name_to_hlir: Dict[str, _hlir.Buffer] = {} + pad_inserts: Dict[str, Tuple[int, ...]] = {} + cluster_modes_for_refs: Dict[ + str, Tuple[str, Optional[int], Tuple[int, ...]], + ] = {} + # Reset module-global mid_ir-axes ↔ hlir-axes alignment tables so a + # fresh compile doesn't see stale entries from a previous one. + _PAD_INSERTS.clear() + _CLUSTER_MODES.clear() + for buf in list(func.params) + list(func.allocs): + if buf.name in buf_name_to_hlir: + continue + hlir_buf, inserts = _make_hlir_buffer( + buf, + override=overrides.get(buf.name), + lane_count=lane_count, + mode=lane_modes.get(buf.name), + kernel_layout=kernel_layout, + ) + buf_name_to_hlir[buf.name] = hlir_buf + # Stamp HLIR-side per-dim role tags for every buffer, derived + # directly from the post-expansion shape + cluster_dim. Ops + # that need to identify dims by role (row_*_at family) read + # this through ``_BUFFER_HLIR_AXES``. + _BUFFER_HLIR_AXES[buf.name] = _hlir_axes_for_buffer(hlir_buf) + if inserts: + pad_inserts[buf.name] = inserts + _PAD_INSERTS[buf.name] = inserts + else: + # Cluster-expand route: track ``(mode, mid_ir_cluster_dim, + # new_4d_shape)`` so the walker can apply + # ``_rewrite_ref_for_cluster_mode``. ``cluster_dim`` on the + # mid_ir BufferDef tells us where the lane axis sits in + # the rank-3 source (set by ``split._grow_buffer`` and + # tracked through view/burn_view); the new 4D shape is + # needed to recover the lane extent under sync-wrap. + mode = lane_modes.get(buf.name) + if (mode is not None + and mode != _MODE_FP_LANE + and not _is_global_scope(buf.scope)): + cluster_modes_for_refs[buf.name] = ( + mode, buf.cluster_dim, tuple(hlir_buf.shape), + ) + _CLUSTER_MODES[buf.name] = ( + mode, buf.cluster_dim, tuple(hlir_buf.shape), + ) + + # Copy hoisted-constant values onto the matching HLIR buffers so + # ``--dump-buffer-addrs`` can emit them and the test harness can + # auto-preload. The pre-pass (frontend/passes/hoist_float_constants.py) + # stashes ``{buffer_name: value}`` on PrimFunc.attrs; mid_ir passes + # carry that attr forward to ``func.attrs`` unchanged. + if func.attrs is not None and "plena.hoisted_constants" in func.attrs: + for name, value in func.attrs["plena.hoisted_constants"].items(): + hlir_buf = buf_name_to_hlir.get(str(name)) + if hlir_buf is not None: + hlir_buf.constant_value = float(value) + + # Walk the body. + ops = _walk_stmts(func.body, buf_name_to_hlir, cluster_extent=None) + + # Synchronise refs with their buffers' post-expansion 4D rank. + # One walker, two rewrite mechanisms. In-place on op.buffer_args. + if pad_inserts or cluster_modes_for_refs: + _rewrite_refs_to_4d(ops, pad_inserts, cluster_modes_for_refs) + + # Tag every cluster-axis (per-lane) ``for`` so the loop_interchange + # pass can recognise it by structure, not by name string. A cluster + # ``for``'s loop_var is the phase axis var registered in + # ``_LANE_AXIS_INFO`` (value tuple = (phase, number, count)); both + # go through ``_get_var``'s name cache so identity == name match. + _tag_cluster_for_ops(ops) + + return _hlir.HLIRModule( + name=func.name, + buffers=buf_name_to_hlir, + ops=ops, + param_names=[b.name for b in func.params], + ) + + +def _tag_cluster_for_ops(ops: "List[_hlir.Op]") -> None: + """Stamp ``annotations["is_cluster_axis"] = True`` on every ``for`` + op whose loop variable is a cluster-phase axis. Recurses into nested + ``for`` bodies. In place. + + The set of phase-axis vars is read from ``_LANE_AXIS_INFO`` — its + values are ``(phase_var, number_var, count)`` and the phase var's + ``.name`` is exactly what a per-lane ``for`` carries as ``loop_var`` + (both minted via ``_get_var``).""" + phase_names = {phase.name for (phase, _num, _cnt) in _LANE_AXIS_INFO.values()} + if not phase_names: + return + + def _walk(op_list): + for op in op_list: + if op.kind == "for": + lv = op.annotations.get("loop_var") + lv_name = getattr(lv, "name", lv) + if lv_name in phase_names: + op.annotations["is_cluster_axis"] = True + if op.body is not None: + _walk(op.body) + elif op.body is not None: + _walk(op.body) + + _walk(ops) + + +def _pad_tuple_at(values, inserts: Tuple[int, ...], fill): + """Insert ``fill`` at each position in ``inserts`` (positions in + OUTPUT coords, ascending). E.g. inserting (0, 2) into ('a', 'b') + with fill 0 yields (0, 'a', 0, 'b').""" + out = list(values) + for pos in inserts: + out.insert(pos, fill) + return tuple(out) + + +def _rewrite_refs_to_4d( + ops, + pad_inserts: Dict[str, Tuple[int, ...]], + cluster_modes: Dict[str, Tuple[str, Optional[int], Tuple[int, ...]]], +) -> None: + """Bring every ``VramRegion`` ref in line with its buffer's post- + expansion 4D rank, in place. + + Picks the rewrite per buffer: pad-to-4D buffers get extent-1 + fills inserted at recorded positions; cluster-expanded buffers + use the lane-position-driven rewrite in + :func:`_rewrite_ref_for_cluster_mode`. BufferSlice (HBM-parent + slices) is not affected — HBM never gets rank-expanded. Recurses + into structured ops' bodies. + """ + for op in ops: + new_bargs = [] + for a in op.buffer_args: + if isinstance(a, (_hlir.VramRegion, _hlir.MramRegion)): + # Idempotent guard: lower paths that already emit 4D + # regions don't need any extra padding here. + if len(a.starts) == 4 and len(a.extents) == 4: + new_bargs.append(a) + continue + ctor = type(a) + if a.parent in pad_inserts: + inserts = pad_inserts[a.parent] + a = ctor( + parent=a.parent, + starts=_pad_tuple_at(a.starts, inserts, 0), + extents=_pad_tuple_at(a.extents, inserts, 1), + ) + elif a.parent in cluster_modes: + mode, old_cluster_dim, new_shape = cluster_modes[a.parent] + a = ctor( + parent=a.parent, + starts=_rewrite_ref_for_cluster_mode( + tuple(a.starts), mode, old_cluster_dim, + new_shape, is_extent=False, + ), + extents=_rewrite_ref_for_cluster_mode( + tuple(a.extents), mode, old_cluster_dim, + new_shape, is_extent=True, + ), + ) + new_bargs.append(a) + op.buffer_args = new_bargs + if op.body: + _rewrite_refs_to_4d(op.body, pad_inserts, cluster_modes) + + +__all__ = ["run", "ToPlenaError"] diff --git a/tilelang_tvm_compiler/frontend/mid_ir/passes/view.py b/tilelang_tvm_compiler/frontend/mid_ir/passes/view.py new file mode 100644 index 0000000..c5eab48 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/mid_ir/passes/view.py @@ -0,0 +1,649 @@ +"""pass_4b_view: assign view_perm to every BufferRef in cluster scope; +substitute the lane axis var on HBM refs; check global view consistency. + +Why this pass exists +-------------------- + +After ``split``, on-chip buffers got a LANE outer dim (physical +shape: ``(LANE, ..., D)``), but BufferRef.indices weren't updated — +they still address the pre-grow rank. After ``async_wrap``, can_async +ops are wrapped in Async regions but indices are still untouched. + +This pass does three things, all op-local + read-only on structure: + + 1. **Assign view_perm to each non-global BufferRef** by looking up + a static rule table keyed on (op-kind, ref-position). + * BTMM output (S_loc) and per-head LHS (S_loc as P @ V LHS) → + BHSD = lane stays at physical dim 0 (identity perm). + * Everything else → BSHD = lane permuted to logical dim 1 + (perm ``[1, 0, ..., N-1]``). + 2. **Prepend the cluster phase index** to each non-global ref's + indices, growing the ref from rank N to rank N+1 to match the + buffer's new physical rank. The prepended slot is ``cluster.axis_name`` + (e.g. ``"by_phase"``). NB: this addresses the *physical* dim 0, + not the logical view dim 0 — view_perm is applied on top. + 3. **Substitute the lane axis var on HBM refs** — any IndexExpr + equal to the original lane axis name (e.g. ``"by"``) becomes + the composite ``cluster_phase + grid_number * cluster_count``. + HBM buffers are NOT rank-grown (they're global) and don't get + view_perm. + +Global view consistency +----------------------- + +After per-ref assignment, the pass walks every BufferRef once more +and checks: for each buffer, all view_perms must be identical. If +two ops disagree on a buffer's view, raise ``ViewConflictError`` — +the kernel author has to refactor (no auto-reshape today). + +What this pass does NOT touch +----------------------------- + + * RawStore (lives outside cluster contexts; left alone) + * For / Async / ParallelAxis structure + * Buffer shapes + * Op kind / marker / can_async + * Anything outside a cluster body +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +from ..cluster_guard import should_skip_cluster, MLEN +from ..ir import ( + AxisRole, AxisInfo, + BufferDef, BufferRef, Slice, VarRef, + Dma, Gemm, Elementwise, Broadcast, Reduce, RawStore, + For, Async, MultiLaneOp, + ParallelAxis, ParallelKind, + MidFunc, Stmt, +) + + +class ViewError(RuntimeError): + pass + + +class ViewConflictError(ViewError): + pass + + +# Per-``run`` lookup so a CLUSTER ParallelAxis can find its enclosing +# number axis's VarRef by name. Split emits ``number -> phase`` nesting, +# so the number axis is always visited first. +_NUMBER_VAR_BY_NAME: dict = {} + + +# --------------------------------------------------------------------------- +# Roles → view kind +# --------------------------------------------------------------------------- + +# Two view shapes for now (rank 3, lane in dim 0 physically): +# BHSD: lane stays at logical dim 0 (identity perm) +# BSHD: lane swaps with the next dim → perm [1, 0, 2, ...] + + +def _identity_perm(rank: int) -> List[int]: + return list(range(rank)) + + +def _bshd_perm(rank: int) -> List[int]: + """Swap logical dim 0 (lane) and logical dim 1 (S). Other dims + untouched. Requires rank >= 2.""" + if rank < 2: + raise ViewError(f"BSHD requires rank >= 2, got {rank}") + return [1, 0] + list(range(2, rank)) + + +# --------------------------------------------------------------------------- +# Cluster context (active while we're inside a cluster body) +# --------------------------------------------------------------------------- + + +@dataclass +class _ClusterCtx: + phase_name: str # "by_phase" + number_name: str # "by_number" + cluster_count: int + original_axis_name: str # "by" (for HBM lane-var substitution) + # Identity channels — set by ``_walk`` from the CLUSTER ParallelAxis. + phase_var: VarRef + number_var: VarRef + original_var: VarRef + + +# --------------------------------------------------------------------------- +# Index rewriting (HBM lane-var substitution) +# --------------------------------------------------------------------------- + + +def _subst_lane_var(idx, ctx: _ClusterCtx): + """Recursively rewrite an IndexExpr: any ``VarRef`` matching the + original lane axis (by identity) becomes the composite + ``phase + number * cluster_count`` expression.""" + if isinstance(idx, VarRef) and idx == ctx.original_var: + return { + "op": "add", + "args": [ + ctx.phase_var, + {"op": "mul", "args": [ctx.number_var, ctx.cluster_count]}, + ], + } + if isinstance(idx, dict): + return { + "op": idx.get("op"), + "args": [_subst_lane_var(a, ctx) for a in idx.get("args", [])], + } + return idx + + +# --------------------------------------------------------------------------- +# Per-ref rewrite +# --------------------------------------------------------------------------- + + +def _rewrite_global_ref(ref: BufferRef, ctx: _ClusterCtx) -> BufferRef: + """User-declared global tensor. + + * Bare ``global`` (HBM-resident params): substitute the lane var + in indices so kernel-side ``[..., by, ...]`` accesses the + correct head slice. + * ``global.vram`` / ``global.mram`` / ``global.fpram`` (on-chip + global caches): leave indices verbatim — the kernel author + indexes them with raw scope-local coordinates; no lane-axis + substitution and no rank growth. + + Either way, ``view_perm`` stays as the user's declaration.""" + if ref.buffer.scope.startswith("global."): + return ref + return BufferRef( + buffer=ref.buffer, + indices=[_subst_lane_var(i, ctx) for i in ref.indices], + view_perm=ref.view_perm, + ) + + +def _forced_view_kind(buf: BufferDef) -> Optional[str]: + """Decide a view_kind that the buffer's physical shape demands, + overriding whatever the op-role table proposes. + + Rules keyed on the innermost dim D (= ``shape[-1]``), with ``shape`` + already post-split (LANE prepended): + + * rank == 2 (``[LANE, rows]`` — no H/D axis at all) → identity + perm regardless of op role. Row-only state buffers (M_OLD, + P_SUM, etc.) live here. + * D % MLEN == 0 → no force; op-role choice is fine (S dim spans + a full lane width, so head-pack vs head-stack are equivalent). + * 0 < D < MLEN → must be BSHD; the innermost dim is sub-lane + (typically hlen) so heads have to be col-packed. + + Invariant: this is called inside the view pass, between split + (which prepends the lane axis at position 0 and sets cluster_dim=0) + and burn_view (which is what actually permutes physical shape). + So ``shape[-1]`` is the original innermost D axis at this point. + The lookup would have to switch to ``cluster_dim``-aware indexing + if this helper ever moved downstream of burn_view. + """ + rank = len(buf.shape) + if rank <= 2: + return "IDENTITY" + d = buf.shape[-1] + if isinstance(d, int): + if d % MLEN == 0: + return None + if 0 < d < MLEN: + return "BSHD" + return None + + +def _rewrite_lane_ref(ref: BufferRef, ctx: _ClusterCtx, + view_kind: str) -> BufferRef: + """Non-global ref: prepend cluster phase to indices, set view_perm. + + ``view_kind`` is "BHSD" (identity) or "BSHD" (swap dim 0/1) from the + op-role table. The buffer's physical shape may override this (e.g. + row-only buffers force identity, sub-MLEN D forces BSHD) — see + ``_forced_view_kind``. + """ + new_indices = [ctx.phase_var] + list(ref.indices) + rank = len(new_indices) + forced = _forced_view_kind(ref.buffer) + effective = forced if forced is not None else view_kind + if effective == "BHSD" or effective == "IDENTITY": + perm = _identity_perm(rank) + elif effective == "BSHD": + perm = _bshd_perm(rank) + else: + raise ViewError(f"unknown view_kind {effective!r}") + return BufferRef( + buffer=ref.buffer, + indices=new_indices, + view_perm=perm, + ) + + +def _axes_after_lane_rewrite( + pre_axes: List[AxisInfo], + ref: BufferRef, + ctx: _ClusterCtx, + view_kind: str, +) -> List[AxisInfo]: + """Mirror ``_rewrite_lane_ref`` on the per-axis info table. + + ``_rewrite_lane_ref`` only *prepends* the cluster phase to + ``indices`` and stamps ``view_perm`` — it does NOT permute + indices. burn_view is the pass that bakes view_perm into both + Buffer.shape and ref.indices; the axes list must follow the same + schedule, so this helper only prepends the CLUSTER entry. The + permute on axes happens in burn_view alongside the index permute. + """ + return [ + AxisInfo(role=AxisRole.CLUSTER, extent=int(ctx.cluster_count)) + ] + list(pre_axes) + + +def _axes_after_global_rewrite( + pre_axes: List[AxisInfo], +) -> List[AxisInfo]: + """Global refs are unchanged by view (no phase prepend, no view_perm). + Just return a shallow copy so downstream mutations don't alias.""" + return list(pre_axes) + + +def _rewrite_ref(ref: BufferRef, ctx: _ClusterCtx, + view_kind: str) -> BufferRef: + # ``global`` / ``global.vram`` / ``global.mram`` / ``global.fpram``: + # user-declared global tensors keep their declared rank verbatim. + # No phase axis prepended, no view_perm, no lane-axis substitution + # of any kind (the kernel author indexes them with raw scope-local + # coordinates). + if ref.buffer.scope == "global" or ref.buffer.scope.startswith("global."): + return _rewrite_global_ref(ref, ctx) + return _rewrite_lane_ref(ref, ctx, view_kind) + + +def _rewrite_ref_with_axes( + ref: BufferRef, pre_axes: List[AxisInfo], + ctx: _ClusterCtx, view_kind: str, +) -> Tuple[BufferRef, List[AxisInfo]]: + """Run the same rewrite the ref undergoes on its axes table. + + Keeps the two channels (physical buffer view + per-op axis roles) + in lock-step: a cluster phase prepended to the ref's indices also + prepends a CLUSTER ``AxisInfo`` to its axes; a permute on + ``view_perm`` permutes axes the same way. + """ + is_global = ( + ref.buffer.scope == "global" or ref.buffer.scope.startswith("global.") + ) + if is_global: + return _rewrite_global_ref(ref, ctx), _axes_after_global_rewrite(pre_axes) + return ( + _rewrite_lane_ref(ref, ctx, view_kind), + _axes_after_lane_rewrite(pre_axes, ref, ctx, view_kind), + ) + + +def _rewrite_src(src, ctx: _ClusterCtx, view_kind: str): + """Wrap-aware rewrite: Broadcast carries a BufferRef inside + + broadcast_dims that point into dst's logical rank. Since dst rank + grows by 1 (the prepended phase index), broadcast_dims need to + shift by 1 too — but only if the original dim index was at or + after dim 0 in dst's logical (post-prepend, post-view) shape. + Simpler: bump every dim by 1.""" + if isinstance(src, Broadcast): + new_dims = [d + 1 for d in src.broadcast_dims] + return Broadcast( + src=_rewrite_ref(src.src, ctx, view_kind), + broadcast_dims=new_dims, + ) + return _rewrite_ref(src, ctx, view_kind) + + +def _rewrite_src_with_axes( + src, pre_axes: List[AxisInfo], + ctx: _ClusterCtx, view_kind: str, +): + """Axes-aware src rewrite that mirrors ``_rewrite_src``. + + Returns ``(new_src, new_axes)``. For Broadcast srcs, the wrapped + BufferRef's axes are rewritten the same way as a plain ref's. + """ + if isinstance(src, Broadcast): + new_dims = [d + 1 for d in src.broadcast_dims] + new_inner, new_axes = _rewrite_ref_with_axes( + src.src, pre_axes, ctx, view_kind, + ) + return Broadcast(src=new_inner, broadcast_dims=new_dims), new_axes + return _rewrite_ref_with_axes(src, pre_axes, ctx, view_kind) + + +# --------------------------------------------------------------------------- +# BHSD buffer set — pre-scan +# --------------------------------------------------------------------------- + + +def _collect_bhsd_buffers(stmts) -> set: + """Return the set of buffer names that MUST be BHSD because some + op produces or consumes them in BHSD form. + + Today the BHSD-anchored ops are: + * Gemm[btmm].c — BTMM output buffer + * Gemm[overwrite].a — per-head matmul LHS + + Once a buffer is anchored to BHSD by one of those ops, every other + op touching the same buffer (Reduce, Elementwise w/ Broadcast, + even pure Elementwise) must also use BHSD on that buffer to keep + the global view consistent. + """ + out: set = set() + + def visit(s): + if isinstance(s, Gemm): + if s.kind == "btmm" and s.c.buffer.scope != "global": + out.add(s.c.buffer.name) + if s.kind == "overwrite" and s.a.buffer.scope != "global": + out.add(s.a.buffer.name) + if isinstance(s, (ParallelAxis, For, Async)): + for c in s.body: + visit(c) + return + if isinstance(s, MultiLaneOp): + visit(s.inner) + return + + for s in stmts: + visit(s) + return out + + +# --------------------------------------------------------------------------- +# Op rewrite — picks per-ref view_kind from a static rule table +# --------------------------------------------------------------------------- + + +# Convention: every non-global ref is BSHD by default, except the +# specific (op, position) pairs listed here (which want BHSD). +_BHSD_POSITIONS = { + # BTMM output + ("Gemm[btmm]", "c"), + # per-head matmul LHS + ("Gemm[overwrite]", "a"), +} + + +def _gemm_kind_key(op: Gemm) -> str: + return f"Gemm[{op.kind}]" + + +def _view_kind_for(op_key: str, position: str) -> str: + return "BHSD" if (op_key, position) in _BHSD_POSITIONS else "BSHD" + + +def _rewrite_op(op, ctx: _ClusterCtx, bhsd_buffers: set): + if isinstance(op, Dma): + kv_src = _view_kind_for("Dma", "src") + kv_dst = _view_kind_for("Dma", "dst") + new_src, new_src_axes = _rewrite_ref_with_axes(op.src, op.src_axes, ctx, kv_src) + new_dst, new_dst_axes = _rewrite_ref_with_axes(op.dst, op.dst_axes, ctx, kv_dst) + return Dma( + src=new_src, dst=new_dst, + src_axes=new_src_axes, dst_axes=new_dst_axes, + marker=op.marker, can_async=op.can_async, + ) + if isinstance(op, Gemm): + key = _gemm_kind_key(op) + new_a, new_a_axes = _rewrite_ref_with_axes(op.a, op.a_axes, ctx, _view_kind_for(key, "a")) + new_b, new_b_axes = _rewrite_ref_with_axes(op.b, op.b_axes, ctx, _view_kind_for(key, "b")) + new_c, new_c_axes = _rewrite_ref_with_axes(op.c, op.c_axes, ctx, _view_kind_for(key, "c")) + return Gemm( + a=new_a, b=new_b, c=new_c, + a_axes=new_a_axes, b_axes=new_b_axes, c_axes=new_c_axes, + transpose_a=op.transpose_a, transpose_b=op.transpose_b, + kind=op.kind, marker=op.marker, can_async=op.can_async, + ) + if isinstance(op, Elementwise): + # View follows the dst buffer's anchor: if dst was anchored to + # BHSD by a BTMM output (or per-head LHS), elementwise must + # honor that. Otherwise default to BSHD — both pure ew (v_add) + # and broadcast ew (row_*_fp_at) accept BSHD freely. + view = "BHSD" if op.dst.buffer.name in bhsd_buffers else "BSHD" + new_dst, new_dst_axes = _rewrite_ref_with_axes(op.dst, op.dst_axes, ctx, view) + new_srcs: list = [] + new_src_axes: list = [] + for s, sa in zip(op.srcs, op.src_axes or [[]] * len(op.srcs)): + new_s, new_sa = _rewrite_src_with_axes(s, sa, ctx, view) + new_srcs.append(new_s) + new_src_axes.append(new_sa) + return Elementwise( + dst=new_dst, srcs=new_srcs, op=op.op, + dst_axes=new_dst_axes, src_axes=new_src_axes, + axis=op.axis, size=op.size, + marker=op.marker, can_async=op.can_async, + ) + if isinstance(op, Reduce): + # Reduce src is what determines layout (the row-reducible + # buffer). If src is BHSD-anchored (BTMM output) → BHSD; else + # BSHD. + view = "BHSD" if op.src.buffer.name in bhsd_buffers else "BSHD" + new_dst, new_dst_axes = _rewrite_ref_with_axes(op.dst, op.dst_axes, ctx, view) + new_src, new_src_axes = _rewrite_ref_with_axes(op.src, op.src_axes, ctx, view) + return Reduce( + dst=new_dst, src=new_src, op=op.op, axis=op.axis, + dst_axes=new_dst_axes, src_axes=new_src_axes, + marker=op.marker, can_async=op.can_async, + ) + if isinstance(op, RawStore): + # RawStore is opaque; don't rewrite. (And it shouldn't appear + # inside a cluster anyway; pass_4 doesn't wrap it.) + return op + raise ViewError(f"unhandled op type {type(op).__name__}") + + +# --------------------------------------------------------------------------- +# Walker +# --------------------------------------------------------------------------- + + +def _walk(stmt: Stmt, ctx: Optional[_ClusterCtx], bhsd_buffers: set) -> Stmt: + if isinstance(stmt, ParallelAxis): + if stmt.kind == ParallelKind.CLUSTER: + if stmt.parent_grid_axis_name is None: + raise ViewError( + f"cluster {stmt.axis_name!r} has no parent_grid_axis_name" + ) + original = stmt.original_axis_name + if original is None: + raise ViewError( + f"cluster {stmt.axis_name!r} missing original_axis_name; " + f"pass_3_split should have set it" + ) + # Identity channel from pass_3_split: phase axis carries + # axis_var (phase var) + original_axis_var (pre-split user + # var). Number var must be looked up by name from the + # surrounding scope (the matching ``*_number`` axis), but + # we only need the VarRef — we don't have it in scope here, + # so we capture it from a sibling lookup during walking. + # In practice the number ParallelAxis is the *parent* of + # this CLUSTER node (split emits ``number -> phase`` nesting), + # so by the time we hit the CLUSTER node we passed through + # the number axis. We grab its axis_var via a small + # side-channel: walk-time companion below. + if stmt.axis_var is None: + raise ViewError( + f"cluster {stmt.axis_name!r} missing axis_var; " + f"pass_3_split should have set it" + ) + if stmt.original_axis_var is None: + raise ViewError( + f"cluster {stmt.axis_name!r} missing original_axis_var; " + f"pass_3_split should have set it" + ) + number_var = _NUMBER_VAR_BY_NAME.get(stmt.parent_grid_axis_name) + if number_var is None: + raise ViewError( + f"cluster {stmt.axis_name!r}: parent number axis " + f"{stmt.parent_grid_axis_name!r} not visited before " + f"this CLUSTER (axis_var lookup failed). Did split " + f"break the number-outside / phase-inside nesting?" + ) + new_ctx = _ClusterCtx( + phase_name=stmt.axis_name, + number_name=stmt.parent_grid_axis_name, + cluster_count=stmt.extent, + original_axis_name=original, + phase_var=stmt.axis_var, + number_var=number_var, + original_var=stmt.original_axis_var, + ) + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=[_walk(s, new_ctx, bhsd_buffers) for s in stmt.body], + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + # Non-cluster ParallelAxis. Record axis_var by name so a nested + # CLUSTER can find its number-axis identity later. + if stmt.axis_var is not None: + _NUMBER_VAR_BY_NAME[stmt.axis_name] = stmt.axis_var + return ParallelAxis( + axis_name=stmt.axis_name, + extent=stmt.extent, + body=[_walk(s, ctx, bhsd_buffers) for s in stmt.body], + kind=stmt.kind, + thread_tag=stmt.thread_tag, + parent_grid_axis_name=stmt.parent_grid_axis_name, + original_axis_name=stmt.original_axis_name, + axis_var=stmt.axis_var, + original_axis_var=stmt.original_axis_var, + ) + if isinstance(stmt, For): + return For( + loop_var=stmt.loop_var, + extent=stmt.extent, + body=[_walk(s, ctx, bhsd_buffers) for s in stmt.body], + kind=stmt.kind, + loop_var_var=stmt.loop_var_var, + ) + if isinstance(stmt, Async): + return Async( + body=[_walk(s, ctx, bhsd_buffers) for s in stmt.body], + scope_id=stmt.scope_id, + ) + if isinstance(stmt, MultiLaneOp): + return stmt + # Leaf op. + if ctx is None: + return stmt + return _rewrite_op(stmt, ctx, bhsd_buffers) + + +# --------------------------------------------------------------------------- +# Global view consistency check +# --------------------------------------------------------------------------- + + +def _collect_views(stmt: Stmt, + table: Dict[str, List[Tuple[Optional[List[int]], str]]] + ) -> None: + """Walk; for every BufferRef record (view_perm, op_label) under the + buffer's name. Only refs to non-global buffers tracked.""" + def visit_ref(ref: BufferRef, label: str) -> None: + if ref.buffer.scope == "global": + return + table.setdefault(ref.buffer.name, []).append( + (tuple(ref.view_perm) if ref.view_perm is not None else None, label) + ) + + def visit_src(src, label: str) -> None: + if isinstance(src, Broadcast): + visit_ref(src.src, label) + else: + visit_ref(src, label) + + if isinstance(stmt, Dma): + visit_ref(stmt.src, "Dma.src") + visit_ref(stmt.dst, "Dma.dst") + elif isinstance(stmt, Gemm): + visit_ref(stmt.a, f"Gemm[{stmt.kind}].a") + visit_ref(stmt.b, f"Gemm[{stmt.kind}].b") + visit_ref(stmt.c, f"Gemm[{stmt.kind}].c") + elif isinstance(stmt, Elementwise): + visit_ref(stmt.dst, "Elementwise.dst") + for i, s in enumerate(stmt.srcs): + visit_src(s, f"Elementwise.src[{i}]") + elif isinstance(stmt, Reduce): + visit_ref(stmt.dst, "Reduce.dst") + visit_ref(stmt.src, "Reduce.src") + elif isinstance(stmt, ParallelAxis): + for s in stmt.body: + _collect_views(s, table) + elif isinstance(stmt, For): + for s in stmt.body: + _collect_views(s, table) + elif isinstance(stmt, Async): + for s in stmt.body: + _collect_views(s, table) + elif isinstance(stmt, MultiLaneOp): + _collect_views(stmt.inner, table) + # RawStore not inspected — its refs are opaque to view rules. + + +def _check_global_consistency(func: MidFunc) -> None: + table: Dict[str, List[Tuple[Optional[List[int]], str]]] = {} + for s in func.body: + _collect_views(s, table) + for buf_name, entries in table.items(): + # Drop entries with None perm (they came from outside cluster + # — never rewritten by this pass; not relevant to consistency + # within cluster scope). + in_cluster = [(p, l) for (p, l) in entries if p is not None] + if not in_cluster: + continue + first_perm, first_label = in_cluster[0] + for perm, label in in_cluster[1:]: + if perm != first_perm: + raise ViewConflictError( + f"buffer {buf_name!r} has conflicting view perms: " + f"{list(first_perm)} (from {first_label}) vs " + f"{list(perm)} (from {label}). " + f"Refactor the kernel to use two separate buffers, " + f"or change the op's role table." + ) + + +# --------------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------------- + + +def run(func: MidFunc) -> MidFunc: + """Assign view_perm + align ranks + substitute HBM lane var. + Errors on globally-inconsistent views.""" + if should_skip_cluster(func): + return func + _NUMBER_VAR_BY_NAME.clear() + bhsd_buffers = _collect_bhsd_buffers(func.body) + new_body = [_walk(s, ctx=None, bhsd_buffers=bhsd_buffers) for s in func.body] + new_func = MidFunc( + name=func.name, + params=list(func.params), + allocs=list(func.allocs), + body=new_body, + lane_axes=list(func.lane_axes), + cluster_counts=list(func.cluster_counts), + attrs=dict(func.attrs), + ) + _check_global_consistency(new_func) + return new_func + + +__all__ = ["run", "ViewError", "ViewConflictError"] diff --git a/tilelang_tvm_compiler/frontend/passes/__init__.py b/tilelang_tvm_compiler/frontend/passes/__init__.py new file mode 100644 index 0000000..f25959a --- /dev/null +++ b/tilelang_tvm_compiler/frontend/passes/__init__.py @@ -0,0 +1,6 @@ +"""Compiler passes for tilelang_plena_compiler. + +Each pass module exposes a single `run(mod) -> mod` function. Passes are +intentionally independent so they can be unit-tested in isolation; the +main `pipeline.py` strings them together. +""" diff --git a/tilelang_tvm_compiler/frontend/passes/hoist_float_constants.py b/tilelang_tvm_compiler/frontend/passes/hoist_float_constants.py new file mode 100644 index 0000000..412ff34 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/passes/hoist_float_constants.py @@ -0,0 +1,405 @@ +"""Hoist FP literals out of kernel bodies into ``global.fpram`` buffers. + +Why this pass exists +-------------------- + +A kernel author writes ``a[i,j] = a[i,j] * T.float16(0.0884)``. The +literal is a per-issue scalar; the compiler has no FP-immediate slot in +its scalar-vector ops, so historically the kernel had to declare a +dedicated ``alloc_fragment`` scalar (``SCALE``), wire up a testbench +preload to write the value, and reference ``SCALE[0]``. That's a lot +of boilerplate for one number. + +This pass automates that boilerplate. It scans the PrimFunc body, finds +every ``tir.FloatImm`` that appears as a value (RHS of a BufferStore or +inside a Call argument — *not* loop extents / buffer indices, which are +ints anyway), and: + + 1. De-duplicates by ``(dtype, value)``. + 2. Synthesises one ``alloc_shared(scope="global.fpram")`` buffer per + unique entry, shape ``(1,)``, name ``__const__``. + 3. Adds the new buffers to the kernel's ``tilelang_root`` block's + ``alloc_buffers`` list — i.e. exactly what the author would've + written by hand. + 4. Rewrites every targeted FloatImm into ``BufferLoad(synth, [0])``. + 5. Stamps ``{"plena.hoisted_constants": {name: value}}`` onto + ``PrimFunc.attrs`` so to_plena → HLIR can copy the values onto + ``hlir.Buffer.constant_value`` for the buffer-addrs JSON dump. + +Downstream passes (fold, mid_ir, address_alloc, isa_emit) see a plain +``global.fpram`` scalar buffer and lower it normally — exactly the path +the hand-written ``SCALE`` / ``M_INIT`` / ``L_INIT`` buffers in +``flash_decode_min`` already exercise. No other pass needs to know +about hoisting. + +What's NOT hoisted +------------------ + + * ``T.float16(0)`` as a sole RHS — fold.py already lowers this to a + multi-lane vector fill (``plena.zero_v``) which is far cheaper than + an FPRAM round-trip. Hoisting it would be a regression. + * FloatImms inside a ``T.float16(1.0) / x`` pattern that fold.py + recognises as ``UnaryOp.RECI`` — the literal is consumed by the + operator, not stored anywhere. + * Anything in loop extents / buffer indices (those are IntImm). + +The pass keeps a deny-list keyed on parent-expression shape so these +already-handled cases continue to flow through unchanged. +""" + +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple + +import tvm +from tvm import tir + + +# Attribute key on the rewritten PrimFunc; to_plena reads it to populate +# ``hlir.Buffer.constant_value``. +ATTR_KEY = "plena.hoisted_constants" + + +def _safe_value_token(value: float) -> str: + """Render a float so it round-trips into a Python identifier piece. + + Examples: ``0.0884`` → ``0p0884``, ``-10000.0`` → ``neg10000``, + ``1.5e-3`` → ``0p0015``. We avoid scientific notation in the name + because ``e`` is ambiguous (could be misread as a hex digit); the + name is only used for diagnostics anyway. + """ + if value == int(value): + s = str(int(value)) + else: + s = f"{value:.6g}" + if s.startswith("-"): + return "neg" + s[1:].replace(".", "p") + return s.replace(".", "p") + + +def _dtype_token(dtype: str) -> str: + """``float16`` → ``f16``, ``float32`` → ``f32``.""" + if dtype.startswith("float"): + return "f" + dtype[len("float"):] + return dtype + + +class _ConstantTable: + """Dedup by ``(dtype, value)``. Synthesises one ``tir.Buffer`` per + unique entry and remembers everything needed by the downstream + rewrite + the to_plena attr handoff.""" + + def __init__(self) -> None: + # (dtype, value) → (buffer, name) + self._table: Dict[Tuple[str, float], Tuple[tir.Buffer, str]] = {} + # Insertion order, so generated ASM is reproducible across runs. + self._order: List[Tuple[str, float]] = [] + + def get_or_create(self, dtype: str, value: float) -> tir.Buffer: + key = (dtype, float(value)) + existing = self._table.get(key) + if existing is not None: + return existing[0] + name = f"__const_{_dtype_token(dtype)}_{_safe_value_token(value)}" + # Disambiguate if two distinct values somehow rendered to the + # same name (e.g. 0.000001 vs 0.0000011 both → "1e-06"-ish). + # Append a counter until unique. + existing_names = {n for _, n in self._table.values()} + unique = name + suffix = 1 + while unique in existing_names: + suffix += 1 + unique = f"{name}__{suffix}" + # 0-D scalar buffer. The semantically correct shape — a hoisted + # constant is a single value, broadcast to whatever rank the + # consumer needs. fold.py's existing broadcast-prefix matcher + # handles ``len(src_idx)=0 < len(dst_indices)`` by wrapping the + # load in a ``Broadcast(broadcast_dims=[0..n-1])`` automatically, + # so we don't need a special case. address_alloc walks + # ``num_elements`` which is 1 for any shape with empty tuple + # (product of empty seq = 1) — single slot, exactly what we + # want. + buf = tir.decl_buffer( + shape=(), dtype=dtype, name=unique, scope="global.fpram", + ) + self._table[key] = (buf, unique) + self._order.append(key) + return buf + + def alloc_buffers(self) -> List[tir.Buffer]: + return [self._table[k][0] for k in self._order] + + def name_value_map(self) -> Dict[str, float]: + return {self._table[k][1]: k[1] for k in self._order} + + +def _is_hoistable_dtype(dtype: str) -> bool: + """We only hoist fp16 / fp32 / bf16 literals — the only types the + FPRAM can store. Other dtypes pass through unchanged. + """ + return dtype in ("float16", "float32", "bfloat16") + + +# Expressions we deliberately skip — these patterns are absorbed by +# downstream passes more efficiently than an FPRAM round-trip would be. +def _is_skip_expr(expr) -> bool: + """``T.float16(0)`` as the entire RHS lowers to a multi-lane vector + fill in fold.py — far cheaper than going through FPRAM. Likewise a + ``T.float16(1.0) / x`` reciprocal is folded into a unary op. We + only test the top-level expr; nested cases (0 deep inside a binop) + still get hoisted, but those don't actually occur in real kernels. + """ + if isinstance(expr, tir.FloatImm) and float(expr.value) == 0.0: + return True + if isinstance(expr, tir.Div): + a = expr.a + # Peel TVM's fp16↔fp32 cast roundtrip — matches fold.py:822. + if isinstance(a, tir.Cast): + a = a.value + if isinstance(a, tir.FloatImm) and float(a.value) == 1.0: + return True + return False + + +def _rewrite_expr(expr, table: _ConstantTable, skip_top: bool, + broadcast_index=None): + """Recursively rewrite FloatImms inside ``expr``. + + ``skip_top`` policy: + * If ``expr`` is a ``FloatImm(0.0)`` and ``skip_top`` is True, we + return it unchanged — fold.py recognises this as a zero-fill + store and lowers it to ``plena.zero_v`` (cheaper than going + through FPRAM). + * If ``expr`` is ``Div(FloatImm(1.0), x)`` and ``skip_top`` is + True, we leave the literal ``1.0`` alone (fold.py absorbs it + into ``UnaryOp.RECI``) but still recurse into ``x`` — a + denominator can carry its own hoistable literals. + * Everything else: recurse as normal with ``skip_top=False``. + + ``broadcast_index`` is the leading index expression of the + enclosing BufferStore's dst — see :func:`_make_synth_load`. + """ + if expr is None: + return None + if skip_top: + # Zero-fill: leave the literal alone (fold turns it into + # plena.zero_v with no FPRAM round-trip). + if isinstance(expr, tir.FloatImm) and float(expr.value) == 0.0: + return expr + # Reciprocal: leave the leading 1.0 alone, recurse into the + # denominator. ``a`` might be wrapped in a Cast(fp32/fp16); + # peel one layer for the check, but pass the original ``expr.a`` + # through unchanged so we don't accidentally rewrite the dtype. + if isinstance(expr, tir.Div): + a = expr.a.value if isinstance(expr.a, tir.Cast) else expr.a + if isinstance(a, tir.FloatImm) and float(a.value) == 1.0: + return type(expr)( + expr.a, + _rewrite_expr(expr.b, table, False, broadcast_index), + ) + # Any other top-level shape: fall through to the normal walk. + if isinstance(expr, tir.FloatImm): + if not _is_hoistable_dtype(str(expr.dtype)): + return expr + buf = table.get_or_create(str(expr.dtype), float(expr.value)) + return _make_synth_load(buf, broadcast_index) + if isinstance(expr, tir.Cast): + return tir.Cast( + expr.dtype, _rewrite_expr(expr.value, table, False, broadcast_index), + ) + if isinstance(expr, tir.Call): + return tir.Call( + expr.dtype, expr.op, + [_rewrite_expr(a, table, False, broadcast_index) for a in expr.args], + ) + if isinstance(expr, tir.BufferLoad): + return tir.BufferLoad( + expr.buffer, + [_rewrite_expr(i, table, False, broadcast_index) for i in expr.indices], + ) + if isinstance(expr, tir.Select): + return tir.Select( + _rewrite_expr(expr.condition, table, False, broadcast_index), + _rewrite_expr(expr.true_value, table, False, broadcast_index), + _rewrite_expr(expr.false_value, table, False, broadcast_index), + ) + if isinstance(expr, tir.Ramp): + return tir.Ramp( + _rewrite_expr(expr.base, table, False, broadcast_index), + _rewrite_expr(expr.stride, table, False, broadcast_index), + expr.lanes, + ) + if isinstance(expr, tir.Broadcast): + return tir.Broadcast( + _rewrite_expr(expr.value, table, False, broadcast_index), + expr.lanes, + ) + # Generic binary op (Add/Sub/Mul/Div/Max/Min/...). + if hasattr(expr, "a") and hasattr(expr, "b"): + return type(expr)( + _rewrite_expr(expr.a, table, False, broadcast_index), + _rewrite_expr(expr.b, table, False, broadcast_index), + ) + # Pass-through for IntImm / Var / StringImm / etc. + return expr + + +def _make_synth_load(buf, broadcast_index): + """Build the ``BufferLoad`` that replaces a hoisted FloatImm. + + The synthesised buffer is 0-D (``shape=()``) in ``global.fpram`` + — semantically a scalar, broadcast to whatever rank the consumer + needs. We emit a ``BufferLoad`` with *no indices*; fold.py's + existing broadcast-prefix matcher picks this up as + ``len(src_idx) = 0 < len(dst_indices)`` and wraps the load in + ``Broadcast(broadcast_dims=[0..n-1])`` automatically, fanning the + value across every dst axis with no special-case logic on the + fold side. + + ``broadcast_index`` is no longer used (the load carries no index + at all) but is kept in the signature so callers don't break. + """ + del broadcast_index + return tir.BufferLoad(buf, []) + + +def _walk(stmt, table: _ConstantTable, root_block_name: str, + extra_allocs: List[tir.Buffer]): + if stmt is None: + return None + if isinstance(stmt, tir.SeqStmt): + return tir.SeqStmt( + [_walk(c, table, root_block_name, extra_allocs) for c in stmt.seq] + ) + if isinstance(stmt, tir.BlockRealize): + return tir.BlockRealize( + iter_values=list(stmt.iter_values), + predicate=stmt.predicate, + block=_walk(stmt.block, table, root_block_name, extra_allocs), + ) + if isinstance(stmt, tir.Block): + new_body = _walk(stmt.body, table, root_block_name, extra_allocs) + new_init = _walk(stmt.init, table, root_block_name, extra_allocs) \ + if stmt.init is not None else None + # Attach the synthesized buffers to the kernel-root Block — + # tilelang's `T.Kernel(...)` macro emits a Block named + # ``"tilelang_root"`` that owns every user alloc_buffer. Adding + # ours there keeps them in the same scoping bracket so all + # passes treat them like hand-written allocs. + if stmt.name_hint == root_block_name and extra_allocs: + new_allocs = list(stmt.alloc_buffers) + list(extra_allocs) + else: + new_allocs = stmt.alloc_buffers + return tir.Block( + iter_vars=stmt.iter_vars, reads=stmt.reads, writes=stmt.writes, + name_hint=stmt.name_hint, + body=new_body, + init=new_init, + alloc_buffers=new_allocs, + match_buffers=stmt.match_buffers, + annotations=stmt.annotations, + ) + if isinstance(stmt, tir.AttrStmt): + return tir.AttrStmt( + stmt.node, stmt.attr_key, stmt.value, + _walk(stmt.body, table, root_block_name, extra_allocs), + ) + if isinstance(stmt, tir.For): + return tir.For( + stmt.loop_var, stmt.min, stmt.extent, stmt.kind, + _walk(stmt.body, table, root_block_name, extra_allocs), + stmt.thread_binding, stmt.annotations, + ) + if isinstance(stmt, tir.IfThenElse): + return tir.IfThenElse( + stmt.condition, + _walk(stmt.then_case, table, root_block_name, extra_allocs), + _walk(stmt.else_case, table, root_block_name, extra_allocs) + if stmt.else_case is not None else None, + ) + if isinstance(stmt, tir.LetStmt): + # inline_let_stmts runs before us, so this shouldn't appear in + # practice — but if it does, we recurse for robustness. + return tir.LetStmt( + stmt.var, stmt.value, + _walk(stmt.body, table, root_block_name, extra_allocs), + ) + if isinstance(stmt, tir.BufferStore): + skip_top = _is_skip_expr(stmt.value) + # Pass the dst's leading index expression to the rewriter so + # synthesised loads index their (1,)-shape buffer with the + # same expression. Fold's broadcast-prefix matcher then + # accepts the load as a scalar broadcast across that axis. + # ``stmt.indices`` is non-empty in practice for every store + # we'd rewrite into (a scalar store with no indices wouldn't + # have a vector RHS to broadcast against). + broadcast_index = stmt.indices[0] if stmt.indices else None + return tir.BufferStore( + stmt.buffer, + _rewrite_expr(stmt.value, table, skip_top, broadcast_index), + # Indices stay as-is — they're IntImm / Var / affine expr, + # never hoistable FloatImms. + list(stmt.indices), + ) + if isinstance(stmt, tir.Evaluate): + return tir.Evaluate(_rewrite_expr(stmt.value, table, False)) + if isinstance(stmt, tir.Allocate): + return tir.Allocate( + stmt.buffer_var, stmt.dtype, list(stmt.extents), + stmt.condition, + _walk(stmt.body, table, root_block_name, extra_allocs), + stmt.annotations, + ) + # Fall through: unknown stmt types pass unchanged so the pass + # doesn't get in the way of features added to TIR later. + return stmt + + +# tilelang names the kernel-body Block ``tilelang_root``. If a future +# tilelang version renames it, callers can pass an override via the +# ``root_block_name`` kwarg on :func:`run`. +DEFAULT_ROOT_BLOCK_NAME = "tilelang_root" + + +def run(func: tir.PrimFunc, + *, + root_block_name: str = DEFAULT_ROOT_BLOCK_NAME) -> tir.PrimFunc: + """Hoist FP literals to ``global.fpram`` buffers. See module docstring. + + No-op on kernels with no hoistable FloatImms (most kernels today). + """ + table = _ConstantTable() + # First sweep: collect + rewrite in one pass (the table is + # populated as a side effect of ``_rewrite_expr``). Buffers come + # out in the order they were first encountered. + new_body = _walk(func.body, table, root_block_name, extra_allocs=[]) + + new_allocs = table.alloc_buffers() + if not new_allocs: + return func + + # Second sweep: re-walk only to inject the new alloc_buffers into + # the root block. We can't do it in the first sweep because we + # don't know what to inject until the first sweep finishes. + new_body = _walk(new_body, _ConstantTable(), root_block_name, new_allocs) + + # Stash the {name: value} table on PrimFunc.attrs so to_plena can + # propagate it onto HLIR Buffer.constant_value (for the dump). We + # rebuild the PrimFunc with the new body / allocs first, then ride + # ``with_attr`` to set the attr — synthesising a ``DictAttrs`` + # directly is fragile across TVM versions. + name_value = table.name_value_map() + new_func = tir.PrimFunc( + params=func.params, + body=new_body, + ret_type=func.ret_type, + buffer_map=func.buffer_map, + attrs=func.attrs, + ) + return new_func.with_attr( + ATTR_KEY, + tvm.runtime.convert({name: float(v) for name, v in name_value.items()}), + ) + + +__all__ = ["run", "ATTR_KEY", "DEFAULT_ROOT_BLOCK_NAME"] diff --git a/tilelang_tvm_compiler/frontend/passes/inline_let_stmts.py b/tilelang_tvm_compiler/frontend/passes/inline_let_stmts.py new file mode 100644 index 0000000..cf53e83 --- /dev/null +++ b/tilelang_tvm_compiler/frontend/passes/inline_let_stmts.py @@ -0,0 +1,167 @@ +"""Inline every ``tir.LetStmt`` by substituting the bound var into its body. + +Why this pass exists +-------------------- + +When a kernel writes ``e = 2 * i`` and then references ``e`` multiple times, +TVMScript / tilelang's tracer is free to materialize the binding as a +``tir.LetStmt`` (typed ``e: T.int32 = 2 * i``). Several downstream passes +in this compiler walk the IR with hand-rolled visitors that don't enumerate +``tir.LetStmt`` (they fall through to a default ``return stmt`` branch), +which silently skips the body. Symptoms range from "BufferStore not +lowered" (lower_fp_row_patterns) to "unbound tir.Var" at isa-emit time. + +Rather than teach every visitor about LetStmt, we run this pass first and +make LetStmts disappear entirely: every ``Let(var, value, body)`` is +replaced by ``substitute(body, {var: value})``. Downstream passes then +only have to handle the canonical Stmt set. + +Substitution is recursive: nested LetStmts compose, and a LetStmt whose +``value`` itself references a previously-bound var has its value +substituted too. + +This pass is a no-op for kernels without LetStmts. +""" + +from __future__ import annotations + +from typing import Dict + +import tvm +from tvm import tir + + +class InlineLetStmtsError(RuntimeError): + pass + + +def _subst_expr(expr, mapping: Dict[tir.Var, tir.PrimExpr]): + if expr is None: + return None + if isinstance(expr, tir.Var): + repl = mapping.get(expr) + return repl if repl is not None else expr + if isinstance(expr, (tir.IntImm, tir.FloatImm, tir.StringImm)): + return expr + if isinstance(expr, tir.Cast): + return tir.Cast(expr.dtype, _subst_expr(expr.value, mapping)) + if isinstance(expr, tir.Call): + return tir.Call( + expr.dtype, expr.op, + [_subst_expr(a, mapping) for a in expr.args], + ) + if isinstance(expr, tir.BufferLoad): + return tir.BufferLoad( + expr.buffer, + [_subst_expr(i, mapping) for i in expr.indices], + ) + if isinstance(expr, tir.Select): + return tir.Select( + _subst_expr(expr.condition, mapping), + _subst_expr(expr.true_value, mapping), + _subst_expr(expr.false_value, mapping), + ) + if isinstance(expr, tir.Ramp): + return tir.Ramp( + _subst_expr(expr.base, mapping), + _subst_expr(expr.stride, mapping), + expr.lanes, + ) + if isinstance(expr, tir.Broadcast): + return tir.Broadcast(_subst_expr(expr.value, mapping), expr.lanes) + if hasattr(expr, "a") and hasattr(expr, "b"): + return type(expr)( + _subst_expr(expr.a, mapping), + _subst_expr(expr.b, mapping), + ) + if hasattr(expr, "value"): + # Catches tir.Not and friends. + return type(expr)(_subst_expr(expr.value, mapping)) + return expr + + +def _walk(stmt, mapping: Dict[tir.Var, tir.PrimExpr]): + if stmt is None: + return None + if isinstance(stmt, tir.SeqStmt): + return tir.SeqStmt([_walk(c, mapping) for c in stmt.seq]) + if isinstance(stmt, tir.BlockRealize): + return tir.BlockRealize( + iter_values=[_subst_expr(v, mapping) for v in stmt.iter_values], + predicate=_subst_expr(stmt.predicate, mapping), + block=_walk(stmt.block, mapping), + ) + if isinstance(stmt, tir.Block): + return tir.Block( + iter_vars=stmt.iter_vars, reads=stmt.reads, writes=stmt.writes, + name_hint=stmt.name_hint, + body=_walk(stmt.body, mapping), + init=_walk(stmt.init, mapping) if stmt.init is not None else None, + alloc_buffers=stmt.alloc_buffers, + match_buffers=stmt.match_buffers, + annotations=stmt.annotations, + ) + if isinstance(stmt, tir.AttrStmt): + return tir.AttrStmt( + stmt.node, + stmt.attr_key, + _subst_expr(stmt.value, mapping), + _walk(stmt.body, mapping), + ) + if isinstance(stmt, tir.For): + return tir.For( + stmt.loop_var, + _subst_expr(stmt.min, mapping), + _subst_expr(stmt.extent, mapping), + stmt.kind, + _walk(stmt.body, mapping), + stmt.thread_binding, + stmt.annotations, + ) + if isinstance(stmt, tir.IfThenElse): + return tir.IfThenElse( + _subst_expr(stmt.condition, mapping), + _walk(stmt.then_case, mapping), + _walk(stmt.else_case, mapping) if stmt.else_case is not None else None, + ) + if isinstance(stmt, tir.LetStmt): + # Substitute previously-seen vars into the new value, then bind. + # The original LetStmt is dropped — the body is rewritten with + # ``var -> value`` and walked. + new_value = _subst_expr(stmt.value, mapping) + new_mapping = dict(mapping) + new_mapping[stmt.var] = new_value + return _walk(stmt.body, new_mapping) + if isinstance(stmt, tir.BufferStore): + return tir.BufferStore( + stmt.buffer, + _subst_expr(stmt.value, mapping), + [_subst_expr(i, mapping) for i in stmt.indices], + ) + if isinstance(stmt, tir.Evaluate): + return tir.Evaluate(_subst_expr(stmt.value, mapping)) + if isinstance(stmt, tir.Allocate): + return tir.Allocate( + stmt.buffer_var, + stmt.dtype, + [_subst_expr(e, mapping) for e in stmt.extents], + _subst_expr(stmt.condition, mapping), + _walk(stmt.body, mapping), + stmt.annotations, + ) + raise InlineLetStmtsError( + f"unhandled stmt type {type(stmt).__name__}: {stmt!r}" + ) + + +def run(func: tir.PrimFunc) -> tir.PrimFunc: + return tir.PrimFunc( + params=func.params, + body=_walk(func.body, {}), + ret_type=func.ret_type, + buffer_map=func.buffer_map, + attrs=func.attrs, + ) + + +__all__ = ["run", "InlineLetStmtsError"] diff --git a/tilelang_tvm_compiler/frontend/passes/lower_compound_fp_stores.py b/tilelang_tvm_compiler/frontend/passes/lower_compound_fp_stores.py new file mode 100644 index 0000000..cb10feb --- /dev/null +++ b/tilelang_tvm_compiler/frontend/passes/lower_compound_fp_stores.py @@ -0,0 +1,385 @@ +"""Decompose compound FPRAM ``BufferStore`` RHS into single-op stores. + +Why this pass exists +-------------------- + +The downstream pass :mod:`lower_fp_row_patterns` only recognises *flat* +single-op assignments on FPRAM fragments:: + + OUT_FP[i] = X_FP[i] # fp_copy_at + OUT_FP[i] = X_FP[i] +/- /* Y_FP[i] # fp_add/sub/mul_at + OUT_FP[i] = T.exp(X_FP[i]) # fp_exp_at + OUT_FP[i] = 1 / X_FP[i] # fp_reci_at + +If a kernel writes a compound expression like:: + + OUT_FP[e] = X_FP[e] * C_FP[e] + X_FP[o] * NS_FP[e] + +the pattern matcher returns ``None`` and the store falls through unlowered, +producing silently-wrong ISA (the compiler emits an empty ``for`` body). + +This pass walks the IR before ``scope_inference`` runs and rewrites such +compound stores into a sequence of single-op stores using auto-allocated +temporary FPRAM fragments (``__tmp_fp_``):: + + __tmp_fp_0[e] = X_FP[e] * C_FP[e] + __tmp_fp_1[e] = X_FP[o] * NS_FP[e] + OUT_FP[e] = __tmp_fp_0[e] + __tmp_fp_1[e] + +Each generated temp matches the same shape / dtype / declared-scope +(``local.fragment``) as the original destination, so ``scope_inference`` +auto-promotes them to FPRAM (rank-1 fragment used in FP scalar context), +``allocate_group_memory`` auto-expands them to ``(lane_count, ...)``, and +``lower_fp_row_patterns`` lowers each single-op store as usual. + +The new buffers are appended to the *enclosing* ``tir.Block``'s +``alloc_buffers`` so they share the same scope as the user-declared +fragments. Each compound store gets its own fresh temps; address allocation +happens later (in HLIR construction) and is not lifetime-aware here, so a +deeply-nested expression may produce more temps than strictly necessary. + +This pass is a no-op for stores whose RHS already fits a recognised +single-op pattern. +""" + +from __future__ import annotations + +from typing import List, Optional + +import tvm +from tvm import tir + + +class LowerCompoundFpStoresError(RuntimeError): + pass + + +_BINOPS = (tir.Add, tir.Sub, tir.Mul) + + +def _is_fragment_buffer(buf: tir.Buffer) -> bool: + declared = buf.scope() if callable(getattr(buf, "scope", None)) else "global" + return declared == "local.fragment" + + +_PEEL_BINOPS = (tir.Add, tir.Sub, tir.Mul, tir.Div, tir.Max, tir.Min) + + +def _peel_cast(expr, target_dtype: str): + """Recursively strip TVM's fp16↔fp32 widening Casts so the whole + subtree is rebuilt at ``target_dtype``. + + TVM lowers ``fp16_a op fp16_b`` to + ``Cast(fp16, Cast(fp32, fp16_a) op Cast(fp32, fp16_b))`` and the + same widening propagates through nested calls (``T.exp``, + reciprocal, …). For decomposition we want to see the math as the + kernel author wrote it — purely at the dst's dtype. This walker + descends through both layers (outer narrow Cast and inner widen + Casts) and reconstructs binops / unary calls / leaves at the target + dtype. Anything it can't normalise is returned unchanged. + + A subtree returned by this function is invariant: it does not + contain any Cast nodes that change dtype, all literals are at + ``target_dtype``, and every BufferLoad already at ``target_dtype`` + is exposed as a leaf for ``_is_leaf`` to pick up. + """ + + def _rebuild(e): + # Drop redundant Cast wrappers regardless of nesting depth. + if isinstance(e, tir.Cast): + return _rebuild(e.value) + if isinstance(e, tir.IntImm): + return tir.IntImm(target_dtype, int(e.value)) + if isinstance(e, tir.FloatImm): + return tir.FloatImm(target_dtype, float(e.value)) + if isinstance(e, tir.BufferLoad): + return e + cls = type(e) + if cls in _PEEL_BINOPS: + return cls(_rebuild(e.a), _rebuild(e.b)) + if isinstance(e, tir.Call) and len(e.args) == 1: + return tir.Call(target_dtype, e.op, [_rebuild(e.args[0])]) + # Unknown node — bail out by returning as-is. The caller falls + # back to leaving the store untouched, surfacing the unknown + # shape downstream rather than silently lowering it wrong. + return e + + return _rebuild(expr) + + +def _is_leaf(expr) -> bool: + """A leaf expression doesn't need decomposition: it can sit directly + inside the recognised single-op pattern.""" + if isinstance(expr, (tir.BufferLoad, tir.IntImm, tir.FloatImm)): + return True + return False + + +def _is_one(expr) -> bool: + if isinstance(expr, tir.IntImm): + return int(expr.value) == 1 + if isinstance(expr, tir.FloatImm): + return float(expr.value) == 1.0 + return False + + +def _is_reci_pattern(expr) -> Optional[tir.PrimExpr]: + """Return the denominator of ``1 / x``, else None.""" + if isinstance(expr, tir.Div) and _is_one(expr.a): + return expr.b + return None + + +def _is_exp_call(expr) -> bool: + return ( + isinstance(expr, tir.Call) + and getattr(expr.op, "name", None) == "tir.exp" + and len(expr.args) == 1 + ) + + +def _is_already_single_op(value) -> bool: + """True iff `value` already matches a pattern recognised by + `lower_fp_row_patterns._try_lower_fp_store`.""" + if isinstance(value, tir.BufferLoad): + return True + if isinstance(value, _BINOPS): + return _is_leaf(value.a) and _is_leaf(value.b) + if _is_exp_call(value): + return _is_leaf(value.args[0]) + if _is_reci_pattern(value) is not None: + return _is_leaf(_is_reci_pattern(value)) + return False + + +class _Ctx: + """Allocator + accumulator state shared across the recursive walk.""" + + def __init__(self) -> None: + self.next_id = 0 + self.new_buffers: List[tir.Buffer] = [] + + def fresh_tmp(self, template: tir.Buffer) -> tir.Buffer: + name = f"__tmp_fp_{self.next_id}" + self.next_id += 1 + data = tir.Var( + name, + tvm.ir.PointerType(tvm.ir.PrimType(template.dtype), "local.fragment"), + ) + buf = tir.decl_buffer( + shape=list(template.shape), + dtype=template.dtype, + name=name, + data=data, + scope="local.fragment", + ) + self.new_buffers.append(buf) + return buf + + +def _to_leaf(expr, dst: tir.Buffer, indices, pre: List[tir.Stmt], + ctx: _Ctx) -> tir.PrimExpr: + """Ensure ``expr`` is a leaf (BufferLoad or constant); if not, evaluate + it into a fresh fragment and return a BufferLoad of that fragment. + + ``indices`` is reused as the storage index inside the temporary — every + auto-allocated fragment has the same shape as ``dst`` so it accepts the + same indexing. + """ + expr = _peel_cast(expr, str(dst.dtype)) + if _is_leaf(expr): + return expr + if isinstance(expr, _BINOPS): + lhs = _to_leaf(expr.a, dst, indices, pre, ctx) + rhs = _to_leaf(expr.b, dst, indices, pre, ctx) + tmp = ctx.fresh_tmp(dst) + pre.append(tir.BufferStore(tmp, type(expr)(lhs, rhs), list(indices))) + return tir.BufferLoad(tmp, list(indices)) + if _is_exp_call(expr): + inner = _to_leaf(expr.args[0], dst, indices, pre, ctx) + tmp = ctx.fresh_tmp(dst) + pre.append(tir.BufferStore( + tmp, + tir.Call(expr.dtype, expr.op, [inner]), + list(indices), + )) + return tir.BufferLoad(tmp, list(indices)) + denom = _is_reci_pattern(expr) + if denom is not None: + inner = _to_leaf(denom, dst, indices, pre, ctx) + tmp = ctx.fresh_tmp(dst) + pre.append(tir.BufferStore( + tmp, + tir.Div(tir.FloatImm(expr.dtype, 1.0), inner), + list(indices), + )) + return tir.BufferLoad(tmp, list(indices)) + raise LowerCompoundFpStoresError( + f"unsupported subexpression in compound FP store RHS: " + f"{type(expr).__name__}: {expr!r}" + ) + + +def _decompose_store(store: tir.BufferStore, ctx: _Ctx) -> tir.Stmt: + if not _is_fragment_buffer(store.buffer): + return store + if len(store.buffer.shape) != 1: + # FPRAM fragments are declared rank-1 by convention; anything else is + # left to the existing passes. + return store + + pre: List[tir.Stmt] = [] + target_dtype = str(store.buffer.dtype) + # Peel fp16↔fp32 cast roundtrips so the dispatch below matches the + # actual op shape regardless of TVM's widening artifacts. + value = _peel_cast(store.value, target_dtype) + + if _is_already_single_op(value): + # Rebuild the store so the RHS reflects the peeled form even when + # no decomposition is required. + if value is store.value: + return store + return tir.BufferStore(store.buffer, value, list(store.indices)) + + if isinstance(value, _BINOPS): + lhs = _to_leaf(value.a, store.buffer, store.indices, pre, ctx) + rhs = _to_leaf(value.b, store.buffer, store.indices, pre, ctx) + new_value = type(value)(lhs, rhs) + elif _is_exp_call(value): + inner = _to_leaf(value.args[0], store.buffer, store.indices, pre, ctx) + new_value = tir.Call(value.dtype, value.op, [inner]) + else: + denom = _is_reci_pattern(value) + if denom is not None: + inner = _to_leaf(denom, store.buffer, store.indices, pre, ctx) + new_value = tir.Div(tir.FloatImm(value.dtype, 1.0), inner) + else: + # Unknown shape — leave for downstream to flag. + return store + + final = tir.BufferStore(store.buffer, new_value, list(store.indices)) + if not pre: + return final + return tir.SeqStmt([*pre, final]) + + +def _walk(stmt, ctx: _Ctx): + if stmt is None: + return None + if isinstance(stmt, tir.SeqStmt): + return tir.SeqStmt([_walk(c, ctx) for c in stmt.seq]) + if isinstance(stmt, tir.BlockRealize): + return tir.BlockRealize( + iter_values=stmt.iter_values, + predicate=stmt.predicate, + block=_walk(stmt.block, ctx), + ) + if isinstance(stmt, tir.Block): + return tir.Block( + iter_vars=stmt.iter_vars, reads=stmt.reads, writes=stmt.writes, + name_hint=stmt.name_hint, + body=_walk(stmt.body, ctx), + init=_walk(stmt.init, ctx) if stmt.init is not None else None, + alloc_buffers=stmt.alloc_buffers, + match_buffers=stmt.match_buffers, + annotations=stmt.annotations, + ) + if isinstance(stmt, tir.AttrStmt): + return tir.AttrStmt( + stmt.node, stmt.attr_key, stmt.value, _walk(stmt.body, ctx), + ) + if isinstance(stmt, tir.For): + return tir.For( + stmt.loop_var, stmt.min, stmt.extent, stmt.kind, + _walk(stmt.body, ctx), + stmt.thread_binding, stmt.annotations, + ) + if isinstance(stmt, tir.IfThenElse): + return tir.IfThenElse( + stmt.condition, + _walk(stmt.then_case, ctx), + _walk(stmt.else_case, ctx) if stmt.else_case is not None else None, + ) + if isinstance(stmt, tir.LetStmt): + # inline_let_stmts is supposed to have removed these, but be defensive. + return tir.LetStmt(stmt.var, stmt.value, _walk(stmt.body, ctx)) + if isinstance(stmt, tir.BufferStore): + return _decompose_store(stmt, ctx) + if isinstance(stmt, tir.Evaluate): + return stmt + if isinstance(stmt, tir.Allocate): + return tir.Allocate( + stmt.buffer_var, stmt.dtype, list(stmt.extents), + stmt.condition, _walk(stmt.body, ctx), stmt.annotations, + ) + return stmt + + +def _inject_alloc_buffers(stmt, new_buffers: List[tir.Buffer]): + """Append ``new_buffers`` to the alloc_buffers of the *first* tir.Block + we encounter (the kernel root block under T.Kernel). + + A simple top-down search is fine because there is exactly one root + block in the kernels we lower; extending the inner scopes wouldn't help + because every FP fragment needs to be visible across the whole kernel + body anyway. + """ + if not new_buffers: + return stmt + if isinstance(stmt, tir.SeqStmt): + out = [] + injected = False + for c in stmt.seq: + if injected: + out.append(c) + else: + new_c = _inject_alloc_buffers(c, new_buffers) + if new_c is not c: + injected = True + out.append(new_c) + return tir.SeqStmt(out) + if isinstance(stmt, tir.BlockRealize): + return tir.BlockRealize( + iter_values=stmt.iter_values, + predicate=stmt.predicate, + block=_inject_alloc_buffers(stmt.block, new_buffers), + ) + if isinstance(stmt, tir.Block): + return tir.Block( + iter_vars=stmt.iter_vars, reads=stmt.reads, writes=stmt.writes, + name_hint=stmt.name_hint, body=stmt.body, init=stmt.init, + alloc_buffers=list(stmt.alloc_buffers) + list(new_buffers), + match_buffers=stmt.match_buffers, + annotations=stmt.annotations, + ) + if isinstance(stmt, tir.AttrStmt): + return tir.AttrStmt( + stmt.node, stmt.attr_key, stmt.value, + _inject_alloc_buffers(stmt.body, new_buffers), + ) + if isinstance(stmt, tir.For): + return tir.For( + stmt.loop_var, stmt.min, stmt.extent, stmt.kind, + _inject_alloc_buffers(stmt.body, new_buffers), + stmt.thread_binding, stmt.annotations, + ) + if isinstance(stmt, tir.LetStmt): + return tir.LetStmt(stmt.var, stmt.value, + _inject_alloc_buffers(stmt.body, new_buffers)) + return stmt # no Block found in this branch + + +def run(func: tir.PrimFunc) -> tir.PrimFunc: + ctx = _Ctx() + new_body = _walk(func.body, ctx) + new_body = _inject_alloc_buffers(new_body, ctx.new_buffers) + return tir.PrimFunc( + params=func.params, + body=new_body, + ret_type=func.ret_type, + buffer_map=func.buffer_map, + attrs=func.attrs, + ) + + +__all__ = ["run", "LowerCompoundFpStoresError"] diff --git a/tilelang_tvm_compiler/frontend/pipeline.py b/tilelang_tvm_compiler/frontend/pipeline.py new file mode 100644 index 0000000..4c8559d --- /dev/null +++ b/tilelang_tvm_compiler/frontend/pipeline.py @@ -0,0 +1,20 @@ +"""Legacy graph-IR frontend pipeline — removed. + +The mid_ir pipeline in ``frontend/mid_ir/passes/`` is the only active +lowering chain now. The graph-IR layer (``frontend/passes/graph_*``, +``classify_lane_use``, ``expand_lane_grid``, ``infer_lane_layout``, +``fuse_elementwise``, ``lower_to_hlir``, ``lift_from_raw``, +``forbid_plena_extern``) has been deleted in full. + +This stub stays so that ``import``-error sites surface a clear message +instead of ``ModuleNotFoundError`` for callers that haven't been +migrated yet (e.g. ``kernels/conv2d_min.py``). Migrate them to the +mid_ir pipeline (``tilelang_tvm_compiler.pipeline.compile_kernel``). +""" + + +def compile_func(*_args, **_kwargs): + raise RuntimeError( + "frontend.pipeline.compile_func has been removed. Use " + "tilelang_tvm_compiler.pipeline.compile_kernel instead." + ) diff --git a/tilelang_tvm_compiler/fuse_adjacent_loops.py b/tilelang_tvm_compiler/fuse_adjacent_loops.py new file mode 100644 index 0000000..6cf6357 --- /dev/null +++ b/tilelang_tvm_compiler/fuse_adjacent_loops.py @@ -0,0 +1,286 @@ +"""Adjacent-loop fusion pass (HLIR post-processing). + +Why this pass exists +-------------------- + +``to_plena`` lowers every per-lane scalar / per-row op independently, +and each one ends up wrapped in its own ``for`` loop. A run of +consecutive per-lane ops therefore produces a run of consecutive +identical loops:: + + for by_phase in [0, 4): fp_mul_at(...) + for by_phase in [0, 4): fp_add_at(...) + for by_phase in [0, 4): row_mul_fp(...) + for by_phase in [0, 4): fp_copy_at(...) + +That is needless loop overhead — the four loops have the same variable +and the same extent, with nothing (no DMA / btmm / multi-lane op) in +between. They are equivalent to a single loop over the four bodies:: + + for by_phase in [0, 4): + fp_mul_at(...); fp_add_at(...); row_mul_fp(...); fp_copy_at(...) + +This pass merges such adjacent loops. + +What it does +------------ + +For every body list in the module (top level and inside every ``for``): + + 1. **Recurse first** — fuse inside each ``for`` body bottom-up, so an + arbitrarily deep nest collapses in one pass. + 2. **Then merge adjacent siblings** — walk the list; whenever two + neighbouring ops are both ``for`` with the *same loop-variable + name*, the *same extent*, and the *same init*, concatenate their + bodies into one loop. Chains of N collapse into one. + +Correctness +----------- + + * Per-lane scalar ops carry no cross-lane dependency (each lane owns + its own FPRAM scalar slot), and same-lane order is preserved + because body B is appended after body A *inside the same iteration*. + So merging never reorders dependent work. + * The two loops use different ``loop_var`` objects that merely share a + name (``to_plena`` mints a fresh var per loop). After merging we + keep loop A's var and **substitute** every reference to loop B's var + in B's body with loop A's var, recursively, through ``scalar_args`` + and ``buffer_args`` PrimExpr / region trees. + +This pass is generic: it is not keyed on ``by_phase`` or any particular +loop name — any two adjacent same-shape loops fuse. +""" + +from __future__ import annotations + +from typing import Any, List + +from . import hlir as _hlir + + +def _loop_meta(op: _hlir.Op): + """Return ``(loop_var, extent, init)`` for a ``for`` op.""" + a = op.annotations + return a.get("loop_var"), a.get("extent"), a.get("init", 0) + + +def _var_name(v) -> Any: + """A loop var's identifying name. ``to_plena`` stores either a + ``tir.Var`` (``.name``) or a plain string. Two loops fuse when these + names match — the underlying objects differ per loop.""" + return getattr(v, "name", v) + + +# --------------------------------------------------------------------------- +# Variable substitution — rewrite refs to loop B's var with loop A's var +# --------------------------------------------------------------------------- + + +def _subst_in_value(value: Any, old_name: Any, new_var: Any) -> Any: + """Recursively replace a loop variable inside a scalar_arg / + buffer_arg value. Handles tir PrimExpr trees, HLIR BufferElement, + VramRegion / MramRegion / BufferSlice, plain containers. Anything it + doesn't recognise is returned untouched. + + Identity is by *name*: a tir.Var (or string) whose name equals + ``old_name`` becomes ``new_var``. + """ + # Bare loop var — a tir.Var or a plain string. Match on these + # EXACT types only: a generic tir PrimExpr has no `.name`, and + # comparing it with `== old_name` would invoke tir's overloaded + # `__eq__` (which builds an int32-vs-string equality node and + # crashes). A compound PrimExpr instead falls through to the + # `_is_tir_expr` branch below and is handled by tir's substitute. + if isinstance(value, str): + return new_var if value == old_name else value + if _is_tir_var(value): + return new_var if value.name == old_name else value + + # HLIR BufferElement: substitute inside its index expressions. + if isinstance(value, _hlir.BufferElement): + return _hlir.BufferElement( + buffer=value.buffer, + indices=tuple( + _subst_in_value(i, old_name, new_var) for i in value.indices + ), + ) + + # VramRegion / MramRegion: substitute inside `starts` (extents are + # ints, never carry a loop var). + if isinstance(value, _hlir.VramRegion): + return _hlir.VramRegion( + parent=value.parent, + starts=tuple( + _subst_in_value(s, old_name, new_var) for s in value.starts + ), + extents=value.extents, + ) + if isinstance(value, _hlir.MramRegion): + return _hlir.MramRegion( + parent=value.parent, + starts=tuple( + _subst_in_value(s, old_name, new_var) for s in value.starts + ), + extents=value.extents, + ) + + # tir PrimExpr tree — duck-typed walk over the usual child fields. + # Rebuilding tir nodes generically is fragile, so we rely on tir's + # own substitute when the value is a PrimExpr. + if _is_tir_expr(value): + return _tir_substitute(value, old_name, new_var) + + # Containers. + if isinstance(value, list): + return [_subst_in_value(v, old_name, new_var) for v in value] + if isinstance(value, tuple): + return tuple(_subst_in_value(v, old_name, new_var) for v in value) + + return value + + +def _is_tir_var(value: Any) -> bool: + """True iff ``value`` is exactly a ``tvm.tir.Var`` — the only tir + node that carries a ``.name`` and stands for a loop variable. + Compound PrimExprs (Add / Mul / …) are deliberately excluded.""" + try: + from tvm import tir as _t + return isinstance(value, _t.Var) + except Exception: + return False + + +def _is_tir_expr(value: Any) -> bool: + """True if ``value`` looks like a tvm.tir PrimExpr (has a dtype and + is not one of our own HLIR dataclasses / a plain scalar).""" + if isinstance(value, (int, float, str, bool)): + return False + return hasattr(value, "dtype") and value.__class__.__module__.startswith( + "tvm" + ) + + +def _tir_substitute(expr: Any, old_name: Any, new_var: Any) -> Any: + """Substitute a variable inside a tir PrimExpr using tir's own + ``substitute``. Falls back to returning the expr unchanged if tir is + unavailable or the substitution can't be expressed.""" + try: + from tvm import tir as _t + + def _mapper(v): + if getattr(v, "name", None) == old_name: + return new_var if _is_tir_expr(new_var) else None + return None + + return _t.stmt_functor.substitute(expr, _mapper) \ + if hasattr(_t, "stmt_functor") else _t.substitute(expr, _mapper) + except Exception: + return expr + + +def _subst_op(op: _hlir.Op, old_name: Any, new_var: Any) -> _hlir.Op: + """Return a copy of ``op`` with loop var ``old_name`` rewritten to + ``new_var`` throughout its args and (recursively) its body.""" + new_body = None + if op.body is not None: + new_body = [_subst_op(b, old_name, new_var) for b in op.body] + new_anno = dict(op.annotations) + # A nested for that happens to reuse the name keeps its own var — + # but its loop_var object is distinct, so only rewrite if the names + # collide AND it's not this op redefining it. Safe path: rewrite the + # stored loop_var only when it is not itself the shadowing var. In + # practice to_plena never shadows, so a plain rewrite is fine. + if "loop_var" in new_anno and _var_name(new_anno["loop_var"]) == old_name: + # This inner loop shadows old_name — stop substitution here for + # its body by NOT rewriting (its body refers to its own var). + return _hlir.Op( + kind=op.kind, buffer_args=list(op.buffer_args), + scalar_args=list(op.scalar_args), annotations=new_anno, + body=op.body, buffer_axes=list(op.buffer_axes), + ) + return _hlir.Op( + kind=op.kind, + buffer_args=[_subst_in_value(b, old_name, new_var) + for b in op.buffer_args], + scalar_args=[_subst_in_value(s, old_name, new_var) + for s in op.scalar_args], + annotations=new_anno, + body=new_body, + buffer_axes=list(op.buffer_axes), + ) + + +# --------------------------------------------------------------------------- +# Fusion +# --------------------------------------------------------------------------- + + +def _can_fuse(a: _hlir.Op, b: _hlir.Op) -> bool: + """Two ops fuse iff both are ``for`` loops with matching loop-var + name, extent and init.""" + if a.kind != "for" or b.kind != "for": + return False + a_var, a_ext, a_init = _loop_meta(a) + b_var, b_ext, b_init = _loop_meta(b) + return ( + _var_name(a_var) == _var_name(b_var) + and a_ext == b_ext + and a_init == b_init + ) + + +def _fuse_body(ops: List[_hlir.Op], changed: List[bool]) -> List[_hlir.Op]: + """Fuse adjacent loops in one body list. Recurses into every ``for`` + body first (bottom-up) so nested runs collapse too. ``changed`` is a + one-element mutable cell flipped to True on any merge.""" + # 1) recurse bottom-up + recursed: List[_hlir.Op] = [] + for op in ops: + if op.kind == "for" and op.body is not None: + recursed.append(_hlir.Op( + kind=op.kind, + buffer_args=list(op.buffer_args), + scalar_args=list(op.scalar_args), + annotations=dict(op.annotations), + body=_fuse_body(op.body, changed), + buffer_axes=list(op.buffer_axes), + )) + else: + recursed.append(op) + + # 2) merge adjacent siblings + out: List[_hlir.Op] = [] + for op in recursed: + if out and _can_fuse(out[-1], op): + prev = out[-1] + changed[0] = True + keep_var = _loop_meta(prev)[0] + drop_name = _var_name(_loop_meta(op)[0]) + # Rewrite op's body to use the kept loop var, then append. + rebased = [_subst_op(b, drop_name, keep_var) for b in op.body] + merged_body = list(prev.body) + rebased + # The merged loop body itself may now expose new adjacent + # loops — fuse it again so chains fully collapse. + out[-1] = _hlir.Op( + kind="for", + buffer_args=list(prev.buffer_args), + scalar_args=list(prev.scalar_args), + annotations=dict(prev.annotations), + body=_fuse_body(merged_body, changed), + buffer_axes=list(prev.buffer_axes), + ) + else: + out.append(op) + return out + + +def run(mod: _hlir.HLIRModule): + """Merge adjacent same-shape ``for`` loops throughout the module. + Mutates ``mod.ops`` in place. Returns ``(mod, changed)``; ``changed`` + is True iff at least one merge happened (fixed-point signal).""" + changed: List[bool] = [False] + mod.ops = _fuse_body(mod.ops, changed) + return mod, changed[0] + + +__all__ = ["run"] diff --git a/tilelang_tvm_compiler/hlir.py b/tilelang_tvm_compiler/hlir.py new file mode 100644 index 0000000..ba49ba3 --- /dev/null +++ b/tilelang_tvm_compiler/hlir.py @@ -0,0 +1,742 @@ +"""HLIR -- the small, pass-friendly IR that flows between PLENA passes. + +Pipeline overview: + + TIR PrimFunc + | + v PASS 1 PlenaCodegen.lower_to_hlir() + HLIR (Buffer + Op stream, no addresses) + | + v PASS 2 AddressAllocationPass + HLIR (each Buffer now has hbm/vram/mram address; Op args reference + buffers, not raw addresses; stride/scale defaults filled in) + | + v PASS 3 ISAEmitterPass + Real ISA text + +Why a tiny custom IR (rather than re-walking TIR each pass): + - we need to attach pass-specific state (resolved addresses, register + hints, scheduling info) to ops and buffers + - TIR doesn't easily carry that without dialect machinery + - keeping the IR small (fewer than ten op kinds for now) means each + pass is a single-file Python function, easy to read and test + +Op kinds intentionally mirror our `intrinsics` registry one-to-one: +each `plena.` extern call in TIR becomes one HLIR Op with kind=. +That keeps the codegen pass mechanical -- no clever rewrites yet. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +from . import scope as _scope + + +@dataclass(frozen=True) +class TileLayout: + """Physical multi-tile VRAM/MRAM layout for a 4D BSHD-shaped buffer. + + A buffer with logical shape ``(B, S, H, D)`` whose ``S`` and/or ``D`` + overflow MLEN gets stored physically as a 7D layout: + + (D_TILES, S_TILES, H_GROUPS, B, MLEN, LANE_COUNT, D_INNER) + └─────── outer tile index ───────┘ └──── inner per-tile ──┘ + + where: + S_TILES = ceildiv(S, MLEN) + D_TILES = ceildiv(D, MLEN) if D > MLEN else 1 + D_INNER = MLEN if D > MLEN else min(D, MLEN) + H_GROUPS = ceildiv(H, LANE_COUNT) + LANE_COUNT determined by D_INNER: + * D_INNER == MLEN → LANE_COUNT = 1 + * D_INNER < MLEN → LANE_COUNT = MLEN // D_INNER + (typically MLEN // HLEN, hardware-dependent) + + Each per-tile inner block ``(MLEN, LANE_COUNT, D_INNER)`` is exactly + one ``H_LOAD_V`` worth of data. The outer ``(D_TILES, S_TILES, + H_GROUPS, B)`` is the tile grid; multi-tile DMAs walk it in + row-major (D_TILE outermost so D_TILES > 1 cases stay contiguous in + the natural way). + + Logical (b, s, h, d) → physical 7D: + d_tile = d // MLEN d_inner = d % MLEN + s_tile = s // MLEN s_inner = s % MLEN + h_grp = h // LANE_COUNT lane = h % LANE_COUNT + Physical flat offset: + d_tile * (S_TILES * H_GROUPS * B * MLEN * LANE_COUNT * D_INNER) + + s_tile * (H_GROUPS * B * MLEN * LANE_COUNT * D_INNER) + + h_grp * (B * MLEN * LANE_COUNT * D_INNER) + + b * (MLEN * LANE_COUNT * D_INNER) + + s_inner * (LANE_COUNT * D_INNER) + + lane * D_INNER + + d_inner + + Total physical element count equals logical numel — the layout is + just a permutation; AddressAllocationPass uses ``numel`` regardless. + """ + # Logical 4D shape (B, S, H, D). + logical_b: int + logical_s: int + logical_h: int + logical_d: int + # Tile grid sizes. + d_tiles: int + s_tiles: int + h_groups: int + # Inner tile dims. + mlen: int + lane_count: int + d_inner: int + + @property + def tile_elems(self) -> int: + """Element count of one inner tile = MLEN * LANE_COUNT * D_INNER.""" + return self.mlen * self.lane_count * self.d_inner + + @property + def num_tiles(self) -> int: + """Total number of inner tiles in the buffer.""" + return self.d_tiles * self.s_tiles * self.h_groups * self.logical_b + + +# Layout name -> (batch_idx, row_idx, channel_idx, col_idx) into a 4D shape. +# The TileLayout / 7D physical layout / multi-tile DMA logic is all written in +# canonical BSHD terms (batch / row-tiled / channel-grouped / col-tiled). +# A buffer declared in another layout (NCHW etc.) just has its axes permuted +# at the boundary — every downstream pass keeps thinking BSHD. +LAYOUT_AXES = { + "BSHD": (0, 1, 2, 3), # B=axes[0], S=axes[1], H=axes[2], D=axes[3] + "NCHW": (0, 2, 1, 3), # N=axes[0], H=axes[2], C=axes[1], W=axes[3] +} + +DEFAULT_LAYOUT = "BSHD" + + +def _select_axes(values, layout: str): + """Pick (batch, row, channel, col) values from a 4D ``values`` sequence + according to ``layout``. ``values`` is either a shape tuple or a starts + PrimExpr tuple — same indexing rule applies to both.""" + if layout not in LAYOUT_AXES: + raise ValueError( + f"unknown layout {layout!r}; known: {sorted(LAYOUT_AXES)}" + ) + bi, ri, ci, di = LAYOUT_AXES[layout] + return values[bi], values[ri], values[ci], values[di] + + +def logical_2d_extents(shape_or_extents, layout: str = DEFAULT_LAYOUT): + """Project a 4D shape (or slice extents) to ``(rows, cols)``. + + ``rows`` = batch * row-dim (the dims that get s-tiled / batched); + ``cols`` = channel * col-dim (the dims that fit inside one MLEN-wide + chunk per logical row). + + For BSHD the row-dim is axes[1] and channel-dim is axes[2], so the + legacy "merge last two as cols, fold the rest into rows" heuristic + happens to match. For NCHW the row-dim is axes[2] and the channel + is axes[1], so we have to look up axes per ``LAYOUT_AXES`` instead + of going by position. + + Lower-rank shapes (3D / 2D / 1D) keep the old per-rank heuristic + since they're not multi-tile-eligible anyway. + """ + n = len(shape_or_extents) + if n == 0: + return (1, 1) + if n == 1: + return (1, int(shape_or_extents[0])) + if n == 2: + return (int(shape_or_extents[0]), int(shape_or_extents[1])) + if n == 3: + # Keep the legacy "fold leading dims into rows, cols = D" rule. + return (int(shape_or_extents[0]), int(shape_or_extents[1]) * int(shape_or_extents[2])) + if n != 4: + raise ValueError( + f"logical_2d_extents only handles up to 4D; got {n}-D" + ) + bi, ri, ci, di = LAYOUT_AXES[layout] + rows = int(shape_or_extents[bi]) * int(shape_or_extents[ri]) + cols = int(shape_or_extents[ci]) * int(shape_or_extents[di]) + return (rows, cols) + + +def hbm_strides_for_layout(shape, layout: str = DEFAULT_LAYOUT): + """Return ``(b_stride, s_stride, h_stride, d_stride)`` in *canonical* + (B, S, H, D) order for a 4D shape laid out row-major in HBM under + ``layout``. + + Each stride is in element units (HBM is row-major in source-layout + order). For BSHD the strides come out in trivial order; for NCHW + they get permuted because the row-dim (H) and channel-dim (C) swap + relative to canonical positions. + + Used by ``_emit_dma_h2v_slice_multi_tile`` and friends to compute + the per-tile HBM offset. + """ + if len(shape) != 4: + raise ValueError(f"hbm_strides_for_layout needs 4D shape; got {tuple(shape)}") + src_strides = [1, 1, 1, 1] + for i in range(2, -1, -1): + src_strides[i] = src_strides[i + 1] * int(shape[i + 1]) + bi, ri, ci, di = LAYOUT_AXES[layout] + return (src_strides[bi], src_strides[ri], src_strides[ci], src_strides[di]) + + +def make_tile_layout( + *, shape=None, layout: str = DEFAULT_LAYOUT, mlen: int, hlen: int, + cluster_dim: Optional[int] = None, + # Legacy keyword form (b/s/h/d) kept for back-compat with any older + # caller. New callers should pass ``shape=...`` plus ``layout=...``. + b: Optional[int] = None, s: Optional[int] = None, + h: Optional[int] = None, d: Optional[int] = None, +) -> Optional["TileLayout"]: + """Build a TileLayout for a 4D buffer if (and only if) it needs + multi-tile storage. + + Two calling forms (kept compatible during the layout migration): + make_tile_layout(shape=(...), layout="NCHW", mlen=..., hlen=...) + make_tile_layout(b=..., s=..., h=..., d=..., mlen=..., hlen=...) + + The ``shape``/``layout`` form picks (b, s, h, d) per ``LAYOUT_AXES``; + everything downstream still works in canonical BSHD terms. Returns + None for buffers that fit a single inner tile (caller treats them + as plain row-major just like before). + + Tiling is required when any of: + * S > MLEN (need S-direction tiling) + * D > MLEN (need D-direction tiling, the new "outer D_TILES" dim) + * H > MLEN // HLEN (need H-direction lane grouping past one group) + + For now we only handle ``b == 1`` cleanly; multi-batch tiling can be + added later by a caller wanting it. + """ + if shape is not None: + if any(v is not None for v in (b, s, h, d)): + raise ValueError( + "make_tile_layout: pass either ``shape``/``layout`` OR the " + "legacy ``b``/``s``/``h``/``d`` kwargs, not both" + ) + if len(shape) != 4: + raise ValueError( + f"make_tile_layout: shape must be 4D; got {tuple(shape)}" + ) + b, s, h, d = (int(x) for x in _select_axes(shape, layout)) + else: + if any(v is None for v in (b, s, h, d)): + raise ValueError( + "make_tile_layout: legacy form requires all four of " + "b/s/h/d" + ) + if mlen <= 0 or hlen <= 0 or hlen > mlen or mlen % hlen != 0: + raise ValueError( + f"invalid mlen={mlen}, hlen={hlen}: require 0 < hlen <= mlen and " + f"mlen % hlen == 0" + ) + + # Inner-tile shape derivation. D == mlen → LANE_COUNT = 1; + # D < mlen (typically D == hlen) → LANE_COUNT = mlen // D. + if d >= mlen: + d_inner = mlen + lane_count = 1 + else: + if mlen % d != 0: + raise ValueError( + f"narrow-D tile requires mlen % d == 0; got mlen={mlen}, d={d}" + ) + d_inner = d + lane_count = mlen // d + + d_tiles = (d + mlen - 1) // mlen if d > mlen else 1 + s_tiles = (s + mlen - 1) // mlen + if h % lane_count != 0: + raise ValueError( + f"H ({h}) must be a multiple of LANE_COUNT ({lane_count})" + ) + h_groups = h // lane_count + + # Every 4D BSHD/NCHW buffer gets a TileLayout — even trivially + # 1×1×1×1-tile-grid ones. Downstream code (isa_emit) walks the + # same 7D physical-offset formula uniformly; the "single-tile + # fast path" the old code returned None for is now expressed as + # the degenerate ``d_tiles=s_tiles=h_groups=b=1`` case of the + # same formula. This removes a category of "is the buffer a + # single tile or multi tile?" branches from every pass that + # consumes TileLayout. + return TileLayout( + logical_b=b, logical_s=s, logical_h=h, logical_d=d, + d_tiles=d_tiles, s_tiles=s_tiles, h_groups=h_groups, + mlen=mlen, lane_count=lane_count, d_inner=d_inner, + ) + + +@dataclass +class Buffer: + """One tile/tensor with shape, scope, dtype, and (after Pass 2) address.""" + + name: str + scope: str # one of scope.{HBM,VRAM,MRAM,FPRAM} + shape: Tuple[int, ...] + dtype: str + + # Filled by AddressAllocationPass. None until then. + address: Optional[int] = None # base address in the buffer's scope + hbm_offset: int = 0 # for HBM tiles that are sub-regions + hbm_stride: Optional[int] = None # row stride in HBM (defaults to mlen) + hbm_scale_size: Optional[int] = None # tile elem count (defaults to tile_elems) + + # Multi-tile physical layout descriptor for VRAM/MRAM buffers whose + # logical shape overflows one inner tile. None means "single tile, + # plain row-major" — the existing case all kernels written before + # this change relied on. See ``TileLayout`` docstring for the + # logical→physical mapping. + tile_layout: Optional[TileLayout] = None + + # Author-pinned ``global.vram`` / ``global.mram`` tensor caches. The + # testbench (or a pre-kernel stub) loads these row-major-contiguous, + # NOT in the 7D mlen-tile-padded layout the compiler uses for its + # own VRAM allocations. AddressAllocationPass therefore skips + # ``make_tile_layout`` for them, and the offset-walking iterators + # branch on this flag to compute addresses as flat row-major. + # + # Lane-shared FPRAM scalars also ride this path: declaring a scalar + # with ``scope="global.fpram"`` bypasses cluster-fusion lane + # expansion, so a ``(1,)`` buffer occupies exactly 1 FPRAM slot + # (every lane reads the same address) instead of lane_count slots. + is_pinned_global: bool = False + + # Compile-time-known value for a constant scalar buffer synthesised + # by the ``hoist_float_constants`` TIR pre-pass (frontend/passes/ + # hoist_float_constants.py). The pre-pass scans the PrimFunc for + # ``T.float16(c)`` (FloatImm) uses, allocates one ``global.fpram`` + # buffer per unique (dtype, value) pair, and rewrites the uses to + # BufferLoads of that buffer. The compiler proper sees a plain + # ``alloc_shared(global.fpram)`` and lowers it normally. This field + # only exists so ``--dump-buffer-addrs`` can emit the value the + # test harness needs to preload — nothing else in the compiler + # reads it. + constant_value: Optional[float] = None + + # 4D-buffer layout hint, used to resolve which axis is the row / + # channel / col dim. ``BSHD`` (the default) means axes are already + # in canonical batch-row-channel-col order; ``NCHW`` means + # axes[1] is the channel and axes[2] is the row, so callers must + # permute before computing tile offsets / lane groups. Ignored for + # non-4D shapes. See ``LAYOUT_AXES`` for the mapping. + layout: str = "BSHD" + + # Index into ``shape`` of the cluster (lane) axis. Set by the + # expand step in mid_ir when a buffer is grown for cluster-fusion; + # tracked through view / burn_view so any axis permutation moves + # the marker along with the dim. ``None`` for buffers that aren't + # cluster-aware (HBM globals, user-declared ``global.*`` caches, + # pure scalar fpram slots). Downstream addressing reads this + # directly instead of guessing from shape. + cluster_dim: Optional[int] = None + + # Pass-attached scratch (e.g. logical_rows, logical_cols, row_blocks, + # col_blocks for HBM buffers). Free-form so passes can stash hints + # without growing this dataclass. + annotations: Dict[str, Any] = field(default_factory=dict) + + @property + def num_elements(self) -> int: + n = 1 + for s in self.shape: + n *= int(s) + return n + + @property + def byte_size(self) -> int: + bits = {"float16": 16, "bfloat16": 16, "float32": 32, "int8": 8, "int32": 32}.get( + self.dtype, 32 + ) + return self.num_elements * bits // 8 + + +@dataclass +class BufferSlice: + """A logical sub-region of a parent HBM buffer. + + Conventions (BSHD-aware, mirroring the layout rules in the package + docstring): + * `parent`: name of the parent Buffer in the same HLIRModule. + Must be an HBM buffer (slicing VRAM/MRAM is not a + thing PLENA exposes natively). + * `starts`: per-dim start indices in the parent's logical shape. + Each entry is either: + - a Python int (compile-time-known) + - a tir.PrimExpr (loop-derived; lowered by + ExprMaterializer at ISA emit time) + * `extents`: per-dim element counts of the slice in each parent + dim. Currently restricted to compile-time ints. + + Address resolution conventions: + * The slice inherits the parent's `hbm_addr`, `hbm_stride`, + `hbm_scale_size`. Pass 3 computes the additional `hbm_offset` + from `starts` and adds it to the parent's `hbm_offset`. + * Pass 3 iterates tiles within the slice using `extents` to + decide row_blocks / col_blocks (BSHD-aware H*D merge same as + for whole buffers). + """ + parent: str # name of parent Buffer + starts: Tuple[Any, ...] # int | tir.PrimExpr per dim + extents: Tuple[int, ...] # int per dim + + +@dataclass +class VramRegion: + """A logical sub-region of a VRAM on-chip buffer. + + Mirrors :class:`BufferSlice` but for on-chip buffers — used by ops + whose lowering needs to know the multi-dim shape of the region (to + split it into multiple HW-MLEN-wide transfers and to compute + per-chunk physical addresses against the parent's 7D tile layout). + + starts / extents are per-dim and refer to the parent's logical + shape. Each ``starts`` entry is either a Python int or a + ``tir.PrimExpr`` (loop-derived; materialised at ISA emit time). + Each ``extents`` entry is a Python int. + """ + parent: str + starts: Tuple[Any, ...] + extents: Tuple[int, ...] + + +@dataclass +class MramRegion: + """A logical sub-region of an MRAM on-chip buffer. + + Same shape contract as :class:`VramRegion` — ``starts`` / + ``extents`` are per-dim against the parent buffer's logical 4D + BSHD shape, and the same 7D tile-layout addressing applies. Kept + as a distinct type so emitters can statically tell which on-chip + backing store an operand lives in (matmul's LHS is VRAM, RHS is + MRAM; mixing them up would target the wrong HW unit). + """ + parent: str + starts: Tuple[Any, ...] + extents: Tuple[int, ...] + + +@dataclass(frozen=True) +class BufferElement: + """One scalar element reference within a buffer. + + Used for FPRAM-backed `_at` operands where the frontend keeps + tilelang-style indexing (`buf[row]`, later expanded to `buf[by,row]`) + but the ISA ultimately expects a flat scalar address. + """ + + buffer: str + indices: Tuple[Any, ...] + + +@dataclass +class Op: + """One HLIR op. + + Two flavours: + * Leaf op (most ops): `kind` matches a `plena.` intrinsic. + `buffer_args` are buffer names; `scalar_args` holds Python ints + or `tir.PrimExpr` (compound expressions involving loop vars). + * Structured op (only `for` for now): `kind == "for"`, `body` is + a non-empty list of nested ops, and `annotations` holds the + loop metadata (`loop_var`, `extent`, `init`). Pass 3 recurses + on `body` while binding `loop_var` to a GP register. + + After Pass 2 every buffer arg has a resolved address on its Buffer. + Pass 3 reads those addresses (and any PrimExpr scalar args via + ExprMaterializer) and emits ISA. + """ + + kind: str + # Each entry is either a buffer name (whole-buffer reference) or a + # BufferSlice (a sub-region of a parent buffer). + buffer_args: List[Any] # List[str | BufferSlice] + scalar_args: List[Any] = field(default_factory=list) # int | float | str | tir.PrimExpr + annotations: Dict[str, Any] = field(default_factory=dict) # debug/passes + # Only set for structured ops (currently just "for"). Leaves leave it None. + body: Optional[List["Op"]] = None + # Per-buffer-arg axes. Parallel to ``buffer_args``: the i-th entry + # is the per-dim ``(role_name, extent)`` tuple for ``buffer_args[i]``. + # ``role_name`` is a string identifying the dim's algebra role + # (``"batch"`` / ``"simd"`` / ``"cluster"`` / ``"reduce"`` / + # ``"broadcast"`` / ``"gemm_m"`` / ``"gemm_n"`` / ``"gemm_k"``); + # the values mirror mid_ir's ``AxisRole`` enum-value strings so a + # mid_ir → HLIR translator can pass them through verbatim. + # + # Filled by ``mid_ir.to_plena`` for ops whose ISA emit needs to + # locate dims by role (e.g. ``row_*_at`` family — which dim is the + # rows axis, which is the cluster lane). Leaf ops that don't need + # this leave the slot ``None``. Slot count must equal + # ``len(buffer_args)``; default empty for back-compat with + # constructors that haven't migrated. + buffer_axes: List[Optional[Tuple[Tuple[str, int], ...]]] = field( + default_factory=list, + ) + + +def make_for_op( + loop_var, + extent, + body: List[Op], + init: int = 0, +) -> Op: + """Helper: build a structured For op.""" + return Op( + kind="for", + buffer_args=[], + scalar_args=[], + annotations={"loop_var": loop_var, "extent": extent, "init": init}, + body=body, + ) + + +@dataclass +class HLIRModule: + """One PrimFunc lowered to HLIR. Buffers + linear op stream.""" + + name: str + buffers: Dict[str, Buffer] # name -> Buffer + ops: List[Op] + # PrimFunc parameter names in their declaration order. Useful so + # downstream passes know which buffers come in as inputs/outputs vs + # internally-allocated scratch. + param_names: List[str] = field(default_factory=list) + + def get_buffer(self, name: str) -> Buffer: + if name not in self.buffers: + raise KeyError( + f"buffer {name!r} not found in HLIR. " + f"Known: {sorted(self.buffers.keys())}" + ) + return self.buffers[name] + + def buffers_in_scope(self, scope: str) -> List[Buffer]: + return [b for b in self.buffers.values() if b.scope == scope] + + def __repr__(self) -> str: + bs = ", ".join(f"{b.name}<{b.scope}>" for b in self.buffers.values()) + return f"HLIRModule({self.name!r}, buffers=[{bs}], ops={len(self.ops)})" + + +def format_hlir(mod: HLIRModule) -> str: + """Pretty-print HLIR. Used for `--dump-hlir`.""" + lines = [f"HLIRModule(name={mod.name!r})", ""] + lines.append(f"Params (in declaration order):") + for p in mod.param_names: + lines.append(f" - {p}") + lines.append("") + + lines.append("Buffers:") + name_w = max((len(b.name) for b in mod.buffers.values()), default=4) + for b in mod.buffers.values(): + addr = "" if b.address is None else str(b.address) + shape_s = "x".join(str(s) for s in b.shape) + extras = "" + if b.scope == "hbm": + extras = ( + f" stride={b.hbm_stride}" + f" scale={b.hbm_scale_size}" + f" hbm_offset={b.hbm_offset}" + ) + lines.append( + f" {b.name:<{name_w}} scope={b.scope:<5} addr={addr:<8} " + f"shape={shape_s} dtype={b.dtype}{extras}" + ) + lines.append("") + + lines.append("Ops:") + _format_ops(mod.ops, lines, indent=2, prefix=[0]) + return "\n".join(lines) + "\n" + + +def _format_ops(ops: List[Op], lines: List[str], indent: int, prefix: List[int]) -> None: + """Recursive op pretty-printer; handles structured ops (for) with nesting.""" + for op in ops: + idx = prefix[0] + prefix[0] += 1 + ind = " " * indent + if op.kind == "for": + lv = op.annotations.get("loop_var") + ext = op.annotations.get("extent") + init = op.annotations.get("init", 0) + lines.append( + f"{ind}[{idx:2d}] for {getattr(lv, 'name', lv)} " + f"in [{init}, {ext}):" + ) + _format_ops(op.body or [], lines, indent + 4, prefix) + else: + bs = ", ".join(_fmt_buf_arg(a) for a in op.buffer_args) if op.buffer_args else "-" + ss = ", ".join(_fmt_scalar(a) for a in op.scalar_args) if op.scalar_args else "-" + lines.append(f"{ind}[{idx:2d}] {op.kind:<14} bufs=({bs}) scalars=({ss})") + + +def _fmt_idx_item(item) -> str: + """Render one index entry. Tries hard to expose ``tir.Var`` names + (e.g. ``by_phase``, ``row``) so HLIR dumps are readable instead of + showing opaque ```` placeholders.""" + if isinstance(item, (int, float, str)): + return str(item) + # tir.Var / similar IR nodes carry a ``.name`` field. + name = getattr(item, "name", None) + if isinstance(name, str) and name: + return name + name_hint = getattr(item, "name_hint", None) + if isinstance(name_hint, str) and name_hint: + return name_hint + return str(item) + + +def _fmt_buf_arg(a) -> str: + """Render a buffer ref or slice for HLIR dump.""" + if isinstance(a, str): + return a + if isinstance(a, BufferSlice): + starts = ",".join(_fmt_idx_item(s) for s in a.starts) + extents = ",".join(str(e) for e in a.extents) + return f"{a.parent}[starts=({starts}), ext=({extents})]" + if isinstance(a, VramRegion): + starts = ",".join(_fmt_idx_item(s) for s in a.starts) + extents = ",".join(str(e) for e in a.extents) + return f"{a.parent}[starts=({starts}), ext=({extents})]" + if isinstance(a, MramRegion): + starts = ",".join(_fmt_idx_item(s) for s in a.starts) + extents = ",".join(str(e) for e in a.extents) + return f"{a.parent}[starts=({starts}), ext=({extents})]" + return str(a) + + +def _fmt_scalar(x) -> str: + """Compact display for ints / strs / PrimExprs.""" + if isinstance(x, BufferElement): + idx = ", ".join(_fmt_idx_item(i) for i in x.indices) + return f"{x.buffer}[{idx}]" + if isinstance(x, (int, float, str)): + return str(x) + name = getattr(x, "name", None) + if isinstance(name, str) and name: + return name + return str(x) + + +def _count_ops(ops: List[Op]) -> int: + """Total op count including nested ``for`` bodies — matches the + flat index space ``_format_ops`` walks.""" + n = 0 + for op in ops: + n += 1 + if op.kind == "for": + n += _count_ops(op.body or []) + return n + + +def format_lowir(mod: HLIRModule, lowir_log: List[tuple]) -> str: + """Render the low-level "last variable-form" report. + + This is the layer between HLIR and ISA: each HLIR op, after the + isa_pass has lowered its buffer refs into physical address + expressions but BEFORE those expressions are bound to ``gp`` + registers. ``lowir_log`` is the recording the ``ExprMaterializer`` + captures at that exact chokepoint — a list of + ``(top_level_op_idx, expr_str)`` where ``expr_str`` still contains + the live ``tir.Var`` loop indices (``head_phase``, ``row``, ...). + + Because the recording is taken from real codegen (not a re-derived + copy), this report can never drift from what the ISA actually + emits. The op indices line up with ``.hlir.txt``. + + Use it to verify, e.g., that ``row_mul_fp`` on a packed-head + buffer produces ``mask = 1 << head_phase`` (the loop var survives) + and not ``mask = 1 << 2`` (the var got constant-folded away). + """ + # Group recorded expressions by depth-first op index — the SAME + # [NN] space _format_ops (and isa_pass's _lowir_idx counter) walks, + # so every op, nested or not, owns its own bucket. + by_op: Dict[int, List[str]] = {} + for op_idx, expr_str in lowir_log: + by_op.setdefault(op_idx, []).append(expr_str) + + lines = [ + f"LowIR report -- kernel: {mod.name!r}", + "", + "Per-op physical address expressions, captured at the " + "var->gp boundary.", + "Each expr below is what the ISA materializes into a gp " + "register next;", + "live tir.Var loop indices (head_phase, row, ...) are still " + "symbolic here.", + "Op indices match .hlir.txt.", + "", + "=" * 64, + "", + ] + + def _collapse(exprs: List[str]) -> List[tuple]: + """Run-length collapse consecutive identical exprs (a loop body + re-emits the same symbolic form every unrolled iteration).""" + out: List[tuple] = [] + prev, run = None, 0 + for e in exprs: + if e == prev: + run += 1 + else: + if prev is not None: + out.append((prev, run)) + prev, run = e, 1 + if prev is not None: + out.append((prev, run)) + return out + + def _emit(ops: List[Op], indent: int, idx: List[int]) -> None: + for op in ops: + cur = idx[0] + idx[0] += 1 + ind = " " * indent + if op.kind == "for": + lv = op.annotations.get("loop_var") + ext = op.annotations.get("extent") + init = op.annotations.get("init", 0) + lines.append( + f"{ind}[{cur:2d}] for " + f"{getattr(lv, 'name', lv)} in [{init}, {ext}):" + ) + _emit(op.body or [], indent + 4, idx) + continue + bs = ", ".join(_fmt_buf_arg(a) for a in op.buffer_args) \ + if op.buffer_args else "-" + ss = ", ".join(_fmt_scalar(a) for a in op.scalar_args) \ + if op.scalar_args else "-" + lines.append( + f"{ind}[{cur:2d}] {op.kind} bufs=({bs}) scalars=({ss})" + ) + exprs = by_op.get(cur, []) + if not exprs: + lines.append(f"{ind} (no materialized address exprs)") + else: + for e, n in _collapse(exprs): + suffix = f" (x{n})" if n > 1 else "" + lines.append(f"{ind} -> {e}{suffix}") + + _emit(mod.ops, indent=2, idx=[0]) + return "\n".join(lines) + "\n" + + +# Sanity helper used by passes to assert progress. +def assert_addresses_resolved(mod: HLIRModule) -> None: + missing = [b.name for b in mod.buffers.values() if b.address is None] + if missing: + raise RuntimeError( + f"address allocation incomplete; buffers without address: {missing}" + ) + + +__all__ = [ + "Buffer", "BufferSlice", + "VramRegion", "MramRegion", + "BufferElement", "Op", "HLIRModule", + "make_for_op", + "assert_addresses_resolved", "format_hlir", "format_lowir", +] diff --git a/tilelang_tvm_compiler/hw_consts.py b/tilelang_tvm_compiler/hw_consts.py new file mode 100644 index 0000000..a7fe39e --- /dev/null +++ b/tilelang_tvm_compiler/hw_consts.py @@ -0,0 +1,102 @@ +"""Hardware-shape constants as symbolic ``tir.Var``s. + +PreIsaPass producers write address expressions in terms of these +symbolic vars rather than baking the current hardware values +(``shim.mlen``, ``shim.blen``, etc.) into ``tir.IntImm``s. BackendEmit +binds each symbolic var to the shim's current numeric value via +``symbol_table`` at run start; the materialiser's +``_peephole_const_fold`` then substitutes the IntImms back in and +folds the address algebra at emit time. + +Why symbolic +------------ +The hardware shape parameters are NOT compile-time constants: the +chip variant being targeted changes them, parameterised tests +sweep them, and the ``plena_settings.toml`` active mode dictates +their current values. PreIsaIR keeps them symbolic so optimisations +that operate on PreIsaIR (LICM, CSE, stride-detection) see the +algebra structure — e.g. ``Mul(oc, BLEN_VAR)`` — rather than a +post-substitution ``Mul(oc, IntImm(4))`` where the ``4`` is +indistinguishable from any other unrelated literal. + +The vars are MODULE-LEVEL SINGLETONS so every producer + every +BackendEmit consumer references the same Python object (``id()`` +match). This lets BackendEmit's group-cache + lookup machinery do +its job — ``id(BLEN_VAR)`` is the symbol_table key on both sides. + +Usage in producers +------------------ +Import the relevant vars (or grab the dict via ``hw_const_vars()``) +and use them in PrimExpr operands: + + from .hw_consts import BLEN_VAR + mat_col_expr = tir.Add( + tir.IntImm("int32", int(rhs.address)), + tir.Mul(oc_var, BLEN_VAR), + ) + +For values derived from the shape constants +(``output_row_stride = blen * mlen``) write the derivation +explicitly as a PrimExpr — let the materialiser simplify: + + output_row_stride_expr = tir.Mul(BLEN_VAR, MLEN_VAR) + Mul(orow_var, output_row_stride_expr) + +Usage in BackendEmit +-------------------- +``BackendEmit.__init__`` binds every hw_const var into +``symbol_table`` using the values from its shim: + + for var, attr in HW_CONST_ATTRS.items(): + self.symbol_table[var] = tir.IntImm("int32", int(getattr(shim, attr))) +""" + +from __future__ import annotations + +from typing import Dict + +from tvm import tir + + +# Singletons. Same Python objects in every module that imports. +MLEN_VAR = tir.Var("mlen", "int32") +BLEN_VAR = tir.Var("blen", "int32") +BTMM_HLEN_VAR = tir.Var("btmm_hlen", "int32") +BTMM_LANE_COUNT_VAR = tir.Var("btmm_lane_count", "int32") +# Rows transferred per H_PREFETCH_V / H_STORE_V instruction — +# the emulator's PREFETCH_V_AMOUNT / STORE_V_AMOUNT. The DMA +# helpers use these as the per-instruction VLEN-row count. +# Different chip variants have different values. +V_PREFETCH_AMOUNT_VAR = tir.Var("v_prefetch_amount", "int32") +V_WRITEBACK_AMOUNT_VAR = tir.Var("v_writeback_amount", "int32") + + +# Map of hw-const tir.Var -> ProgramShim attribute name. BackendEmit +# iterates this at startup to populate symbol_table. +HW_CONST_ATTRS: Dict[tir.Var, str] = { + MLEN_VAR: "mlen", + BLEN_VAR: "blen", + BTMM_HLEN_VAR: "btmm_hlen", + BTMM_LANE_COUNT_VAR: "btmm_lane_count", + V_PREFETCH_AMOUNT_VAR: "v_prefetch_amount", + V_WRITEBACK_AMOUNT_VAR: "v_writeback_amount", +} + + +def hw_const_vars() -> Dict[str, tir.Var]: + """Convenience dict by name (for producers that prefer string keys).""" + return { + "mlen": MLEN_VAR, + "blen": BLEN_VAR, + "btmm_hlen": BTMM_HLEN_VAR, + "btmm_lane_count": BTMM_LANE_COUNT_VAR, + "v_prefetch_amount": V_PREFETCH_AMOUNT_VAR, + "v_writeback_amount": V_WRITEBACK_AMOUNT_VAR, + } + + +__all__ = [ + "MLEN_VAR", "BLEN_VAR", "BTMM_HLEN_VAR", "BTMM_LANE_COUNT_VAR", + "V_PREFETCH_AMOUNT_VAR", "V_WRITEBACK_AMOUNT_VAR", + "HW_CONST_ATTRS", "hw_const_vars", +] diff --git a/tilelang_tvm_compiler/intrinsics.py b/tilelang_tvm_compiler/intrinsics.py new file mode 100644 index 0000000..80acdac --- /dev/null +++ b/tilelang_tvm_compiler/intrinsics.py @@ -0,0 +1,352 @@ +"""PLENA intrinsic descriptors. + +Each intrinsic in PLENA's ISA gets a Python descriptor here: + - canonical name used in T.call_extern("handle", "plena.", ...) + - operand scope constraints (which RAM each operand must live in) + - simple printer used by the codegen pass + +This is the place to add new ops as the compiler grows. The codegen +walks the TIR, finds plena.* extern calls, looks them up here, verifies +scopes, and emits ISA text. + +FPRAM operand convention: + FPRAM is treated as a flat scalar register file, not a buffered + region. Every FP operand position is a SCALAR address (PrimExpr or + int), counted in element units from address 0. The kernel is + responsible for adding any per-slot base offset before passing the + value in. There are no FPRAM buffer handles in TIR anymore. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Dict, Sequence + +from . import scope as _scope + + +@dataclass(frozen=True) +class IntrinsicSpec: + name: str + # Required scope per buffer-typed operand position. + # `None` means "scalar / immediate / FP address, no scope check". + operand_scopes: Sequence[str | None] + # Friendly printer: takes a list of resolved operand strings and + # any trailing scalar args, returns one ISA line. + emit: Callable[[list[str]], str] + + +_REGISTRY: Dict[str, IntrinsicSpec] = {} + + +def register(spec: IntrinsicSpec) -> None: + if spec.name in _REGISTRY: + raise ValueError(f"duplicate intrinsic: {spec.name}") + _REGISTRY[spec.name] = spec + + +def lookup(name: str) -> IntrinsicSpec: + if name not in _REGISTRY: + raise KeyError( + f"unknown PLENA intrinsic: {name!r}. " + f"Known: {sorted(_REGISTRY.keys())}" + ) + return _REGISTRY[name] + + +def all_names() -> list[str]: + return sorted(_REGISTRY.keys()) + + +# --------------------------------------------------------------------------- +# DMA / matmul / vector ops +# --------------------------------------------------------------------------- + +register(IntrinsicSpec( + name="plena.dma_h2v", + operand_scopes=(_scope.HBM, _scope.VRAM, None), # src, dst, size + emit=lambda a: f"DMA_H2V src={a[0]} dst={a[1]} size={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.dma_h2m", + operand_scopes=(_scope.HBM, _scope.MRAM, None), + emit=lambda a: f"DMA_H2M src={a[0]} dst={a[1]} size={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.dma_v2h", + operand_scopes=(_scope.VRAM, _scope.HBM, None), + emit=lambda a: f"DMA_V2H src={a[0]} dst={a[1]} size={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.btmm", + operand_scopes=(_scope.VRAM, _scope.MRAM, _scope.VRAM, None), + emit=lambda a: f"BTMM A={a[0]} B={a[1]} C={a[2]} group_heads={a[3]}", +)) + +register(IntrinsicSpec( + # Lane-fused matrix-vector. LHS is a 1-D vector (lane-packed across + # heads, MLEN-wide); RHS is a (mlen, lane_count, hlen) MRAM matrix + # (same layout as BTMM's RHS). DST is a row-stacked 1-D vector that + # M_BMV_WO writes out as `lane_count` MLEN-wide rows. + # Maps to one M_BTMV + M_BMV_WO pair, parallel to plena.btmm's + # M_BTMM + M_BMM_WO. + name="plena.btmv", + operand_scopes=(_scope.VRAM, _scope.MRAM, _scope.VRAM, None), + emit=lambda a: f"BTMV A={a[0]} B={a[1]} C={a[2]} group_heads={a[3]}", +)) + +register(IntrinsicSpec( + name="plena.tile_add", + operand_scopes=(_scope.VRAM, _scope.VRAM, _scope.VRAM), + emit=lambda a: f"TILE_ADD lhs={a[0]} rhs={a[1]} dst={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.tile_sub", + operand_scopes=(_scope.VRAM, _scope.VRAM, _scope.VRAM), + emit=lambda a: f"TILE_SUB lhs={a[0]} rhs={a[1]} dst={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.tile_mul", + operand_scopes=(_scope.VRAM, _scope.VRAM, _scope.VRAM), + emit=lambda a: f"TILE_MUL lhs={a[0]} rhs={a[1]} dst={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.mm", + operand_scopes=(_scope.VRAM, _scope.MRAM, _scope.VRAM), + emit=lambda a: f"MM A={a[0]} B={a[1]} C={a[2]}", +)) + +register(IntrinsicSpec( + # Unified `(M, K) @ (K, N) -> (M, N)`. Replaces plena.mm + plena.mm_slot. + # K reduction is folded into the matmul op itself (M_MM accumulate + + # M_MM_WO drain), so no caller-side scratch + v_add is needed for K. + # N may exceed mlen and walks across mlen-wide B-tile blocks internally. + # + # Trailing scalar args (offsets) let the same op address sub-regions + # of larger buffers without buffer slicing in HLIR: + # lhs_offset : element offset added to A_v's base (int or PrimExpr) + # rhs_offset : element offset added to B_m's base (int) + # dst_offset : element offset added to C_v's base (int) + # dst_row_stride : C row stride in elements (int) -- defaults to N + # when callers pass 0 here. + name="plena.matmul", + operand_scopes=( + _scope.VRAM, _scope.MRAM, _scope.VRAM, + None, None, None, # M_tiles, K_tiles, N + None, None, None, None, # lhs_offset, rhs_offset, dst_offset, dst_row_stride + ), + emit=lambda a: ( + f"MATMUL A={a[0]} B={a[1]} C={a[2]} " + f"M_tiles={a[3]} K_tiles={a[4]} N={a[5]} " + f"lhs_off={a[6]} rhs_off={a[7]} dst_off={a[8]} dst_row_stride={a[9]}" + ), +)) + +register(IntrinsicSpec( + # Per-head matrix-vector, single-lane M_MV + M_MV_WO. + # Used for the P @ V step of decode where the LHS is one row of a + # row-stacked S_loc fragment (one head's score vector). Each call + # handles ONE head; the kernel author wraps it in a per-lane loop + # (T.serial(lane_count) or T.unroll), exactly mirroring how + # flash_attention_min uses plena.matmul (M_MM) per head. + # + # Trailing offsets are element offsets added to each buffer's base. + # All three may be int OR PrimExpr (materialized to gp registers). + name="plena.mv", + operand_scopes=( + _scope.VRAM, _scope.MRAM, _scope.VRAM, + None, None, None, # lhs_offset, rhs_offset, dst_offset + ), + emit=lambda a: ( + f"MV A={a[0]} B={a[1]} C={a[2]} " + f"lhs_off={a[3]} rhs_off={a[4]} dst_off={a[5]}" + ), +)) + +register(IntrinsicSpec( + name="plena.mm_slot", + operand_scopes=(_scope.VRAM, _scope.MRAM, _scope.VRAM, None, None, None, None), + emit=lambda a: ( + f"MM_SLOT A={a[0]} B={a[1]} C={a[2]} " + f"lhs_row_offset={a[3]} rhs_col_offset={a[4]} " + f"dst_col_offset={a[5]} col_count={a[6]}" + ), +)) + +register(IntrinsicSpec( + name="plena.tile_zero", + operand_scopes=(_scope.VRAM,), + emit=lambda a: f"TILE_ZERO dst={a[0]}", +)) + + +# --------------------------------------------------------------------------- +# FP scalar ops (`_at` only). FP operands are SCALAR addresses, counted +# in element units. Kernel computes `slot_base + h*rows + row` and passes +# the result; codegen does no per-slot offset math of its own. +# --------------------------------------------------------------------------- + +register(IntrinsicSpec( + name="plena.fp_copy_at", + operand_scopes=(None, None), # src_addr, dst_addr + emit=lambda a: f"FP_COPY_AT src={a[0]} dst={a[1]}", +)) + +register(IntrinsicSpec( + name="plena.fp_zero_at", + operand_scopes=(None,), # dst_addr (single FPRAM scalar) + emit=lambda a: f"FP_ZERO_AT dst={a[0]}", +)) + +register(IntrinsicSpec( + name="plena.fp_add_at", + operand_scopes=(None, None, None), # lhs_addr, rhs_addr, dst_addr + emit=lambda a: f"FP_ADD_AT lhs={a[0]} rhs={a[1]} dst={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.fp_sub_at", + operand_scopes=(None, None, None), + emit=lambda a: f"FP_SUB_AT lhs={a[0]} rhs={a[1]} dst={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.fp_mul_at", + operand_scopes=(None, None, None), + emit=lambda a: f"FP_MUL_AT lhs={a[0]} rhs={a[1]} dst={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.fp_max_at", + operand_scopes=(None, None, None), + emit=lambda a: f"FP_MAX_AT lhs={a[0]} rhs={a[1]} dst={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.fp_exp_at", + operand_scopes=(None, None), + emit=lambda a: f"FP_EXP_AT src={a[0]} dst={a[1]}", +)) + +register(IntrinsicSpec( + name="plena.fp_reci_at", + operand_scopes=(None, None), + emit=lambda a: f"FP_RECI_AT src={a[0]} dst={a[1]}", +)) + +register(IntrinsicSpec( + name="plena.fp_sqrt_at", + operand_scopes=(None, None), + emit=lambda a: f"FP_SQRT_AT src={a[0]} dst={a[1]}", +)) + + +# --------------------------------------------------------------------------- +# Row ops (`_at` only). VRAM-side trailing scalars (row_idx, head_idx) are +# the buffer's *logical* (S, H) BSHD coordinates of the row to operate on: +# +# * row_idx — index along the logical S axis (which mlen-row). +# * head_idx — index along the logical H axis (which head/lane). +# +# These two are layout-agnostic at the graph-IR layer (i.e. the same +# whether the buffer is COL_PACK, ROW_STACK, or single-tile). isa_pass's +# ``_resolve_row_at_coords`` is the single point that consults +# ``buf.layout`` + ``buf.tile_layout`` to turn (row, head) into the +# physical (B, S, H, D) 7D coordinates and finally a VRAM row address +# plus optional lane V_MASK. FP-side operand is a SCALAR address, +# identical to the FP `_at` family above. +# --------------------------------------------------------------------------- + +register(IntrinsicSpec( + name="plena.row_reduce_max_at", + # vram_src, fp_dst_addr, row, head + operand_scopes=(_scope.VRAM, None, None, None), + emit=lambda a: f"ROW_REDUCE_MAX_AT src={a[0]} dst={a[1]} row={a[2]} head={a[3]}", +)) + +register(IntrinsicSpec( + name="plena.row_reduce_sum_at", + operand_scopes=(_scope.VRAM, None, None, None), + emit=lambda a: f"ROW_REDUCE_SUM_AT src={a[0]} dst={a[1]} row={a[2]} head={a[3]}", +)) + +register(IntrinsicSpec( + name="plena.row_exp_at", + # vram_src, vram_dst, row, head (no FP operand) + operand_scopes=(_scope.VRAM, _scope.VRAM, None, None), + emit=lambda a: f"ROW_EXP_AT src={a[0]} dst={a[1]} row={a[2]} head={a[3]}", +)) + +register(IntrinsicSpec( + name="plena.row_sub_fp_at", + # vram_src, fp_addr, vram_dst, row, head + operand_scopes=(_scope.VRAM, None, _scope.VRAM, None, None), + emit=lambda a: f"ROW_SUB_FP_AT src={a[0]} rhs={a[1]} dst={a[2]} row={a[3]} head={a[4]}", +)) + +register(IntrinsicSpec( + name="plena.row_mul_fp_at", + operand_scopes=(_scope.VRAM, None, _scope.VRAM, None, None), + emit=lambda a: f"ROW_MUL_FP_AT src={a[0]} rhs={a[1]} dst={a[2]} row={a[3]} head={a[4]}", +)) + +register(IntrinsicSpec( + name="plena.row_add_fp_at", + operand_scopes=(_scope.VRAM, None, _scope.VRAM, None, None), + emit=lambda a: f"ROW_ADD_FP_AT src={a[0]} rhs={a[1]} dst={a[2]} row={a[3]} head={a[4]}", +)) + + +# --------------------------------------------------------------------------- +# VRAM <-> VRAM and slice-form VRAM <-> FPRAM transfers. +# --------------------------------------------------------------------------- + +register(IntrinsicSpec( + # Single MLEN-wide row copy in VRAM, lane-fused. Lowers to + # ``V_ADD_VF dst, src, f0, 0`` which (with f0 reserved == 0) computes + # ``dst[i] = src[i] + 0 = src[i]`` for the full HW vector. Used by + # the "tensor cache" path: a small VRAM region pre-populated by a + # testbench-side FPRAM->VRAM stub feeds the kernel via this op, + # avoiding HBM DMA for vector-shape (seq=1) tensors. + name="plena.copy_v_to_v", + # src_buf, src_offset, dst_buf, dst_offset + operand_scopes=(_scope.VRAM, None, _scope.VRAM, None), + emit=lambda a: f"COPY_V_TO_V src={a[0]}+{a[1]} dst={a[2]}+{a[3]}", +)) + +# NOTE: ``plena.row_load_v_to_fp`` / ``plena.row_store_fp_to_v`` are +# retired. The single-row contract they enforced was too narrow: a +# logical ``T.copy(vram_slice, fpram_buf)`` may span multiple MLEN +# rows, span lane-grouped (narrow-D) layouts, or land on a buffer +# whose VRAM placement is the 7D physical layout. The +# ``v_fp_transfer_slice_{v_to_fp,fp_to_v}`` ops carry the whole +# logical region; isa_emit splits it into per-MLEN issues internally. + + +# --------------------------------------------------------------------------- +# Slice DMA variants (variadic args; only the first two operands are +# scope-checked). +# --------------------------------------------------------------------------- +register(IntrinsicSpec( + name="plena.dma_h2v_slice", + operand_scopes=(_scope.HBM, _scope.VRAM, None), + emit=lambda a: f"DMA_H2V_SLICE src={a[0]} dst={a[1]} ndim={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.dma_h2m_slice", + operand_scopes=(_scope.HBM, _scope.MRAM, None), + emit=lambda a: f"DMA_H2M_SLICE src={a[0]} dst={a[1]} ndim={a[2]}", +)) + +register(IntrinsicSpec( + name="plena.dma_v2h_slice", + operand_scopes=(_scope.VRAM, _scope.HBM, None), + emit=lambda a: f"DMA_V2H_SLICE src={a[0]} dst={a[1]} ndim={a[2]}", +)) diff --git a/tilelang_tvm_compiler/isa_emitter.py b/tilelang_tvm_compiler/isa_emitter.py new file mode 100644 index 0000000..9a19e1d --- /dev/null +++ b/tilelang_tvm_compiler/isa_emitter.py @@ -0,0 +1,1365 @@ +"""ISAEmitter: turns prepared tile/FP operations into ISA strings. + +Owns all `emit_*` methods (HBM/VRAM transfer, BTMM, matmul, FP kernels, +row operations, etc.). Managers and TileTensorProgram hold a reference +to an ISAEmitter and call its methods directly rather than going through +the program object. +""" + +from __future__ import annotations + +import math +import sys +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +# This file is at .../compiler/tilelang_tvm_compiler/isa_emitter.py. +# Walking three parents lands at the project root so that +# `compiler.asm_templates` resolves regardless of CWD/PYTHONPATH. +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from compiler.asm_templates import preload_addr_reg_asm, reset_reg_asm + +# NOTE on stripped imports: +# The original runtime version did `from ._types import *` and +# `from ._helpers import *` to pull in tile/tensor types and small +# helpers used by the higher-order methods (emit_matmul / emit_fp_kernel +# / emit_row_operation, etc.). +# +# For the TVM port we ONLY use the simple methods that take physical +# addresses and produce ISA: emit_load_tile_from_hbm, +# emit_store_tile_to_hbm, emit_hbm_tile_to_mram, emit_btmm, +# emit_btmm_wo, emit_zero_vram_tile, emit_map_v_fp_tile, +# emit_map_fp_v_tile. None of those reference _types/_helpers symbols, +# so we drop the deep imports and only pull `Sequence` from typing. +# +# Calling the heavier methods (emit_matmul / emit_fp_kernel / row ops) +# will raise NameError until those types are ported. That's intentional: +# we want the failure to be loud when we try to use a method whose +# contract we haven't validated yet. + + +class ISAEmitter: + """Emit ISA strings for already-prepared tensor/FP operations.""" + + def __init__(self, program: "TileTensorProgram") -> None: + self.program = program + + def _emit_preload_tile_isa( + self, + *, + vlen: int, + preload_len: int, + batch: int, + hidden_size: int, + act_vram_offset: int, + alive_registers: List[int], + activation_offset_reg: int, + stride_size: Optional[int] = None, + scale_size: Optional[int] = None, + hbm_start_offset: int = 0, + # PLENA TVM extension: when supplied, the offset is COPIED from + # `hbm_start_offset_reg` instead of loaded as a literal. Used by + # the slice-aware DMA dispatcher when the slice has a runtime- + # computed start (e.g. derived from a loop var). + hbm_start_offset_reg: Optional[int] = None, + ) -> str: + generated_code = "; Preload Activation Generation \n" + a_actual_register = alive_registers[0] + set_stride_register = alive_registers[1] + result_register = alive_registers[2] + outer_loop_register = alive_registers[3] + inner_loop_register = alive_registers[4] + + stride_len = vlen if stride_size is None else int(stride_size) + scale_len = hidden_size * batch if scale_size is None else int(scale_size) + load_amount_per_hidden = math.ceil(hidden_size / vlen) + + generated_code += f"S_ADDI_INT gp{a_actual_register}, gp0, {scale_len} \n" + generated_code += f"C_SET_SCALE_REG gp{a_actual_register} \n" + if hbm_start_offset_reg is not None: + # Dynamic base + static residual: ``a = dyn_reg + static``. + # The static piece carries per-tile constant offsets (one + # invocation per inner tile in the multi-tile DMA grid). + generated_code += ( + f"S_ADDI_INT gp{a_actual_register}, gp{hbm_start_offset_reg}, " + f"{int(hbm_start_offset)} \n" + ) + else: + generated_code += f"S_ADDI_INT gp{a_actual_register}, gp0, {int(hbm_start_offset)} \n" + generated_code += f"S_ADDI_INT gp{result_register}, gp0, {act_vram_offset} \n" + + if batch == 1: + elements_per_prefetch = vlen * preload_len + for _ in range(math.ceil(hidden_size / elements_per_prefetch)): + generated_code += ( + f"H_PREFETCH_V gp{result_register}, gp{a_actual_register}, " + f"a{activation_offset_reg}, 0, 0, 0 \n" + ) + generated_code += ( + f"S_ADDI_INT gp{result_register}, gp{result_register}, {elements_per_prefetch} \n" + ) + generated_code += ( + f"S_ADDI_INT gp{a_actual_register}, gp{a_actual_register}, {elements_per_prefetch} \n" + ) + return generated_code + + generated_code += f"S_ADDI_INT gp{set_stride_register}, gp0, {stride_len} \n" + generated_code += f"C_SET_STRIDE_REG gp{set_stride_register} \n" + a_offset_register = set_stride_register + # Compile-time unrolled twin C_LOOP. ``result`` and ``a_offset`` + # were running registers advanced by S_ADDI_INT each iter; we + # now bake every (outer, inner) address in as a literal. No + # C_LOOP, no per-iter advance. + inner_count = math.ceil(batch / preload_len) if batch > preload_len else 1 + for outer in range(load_amount_per_hidden): + for inner in range(inner_count): + result_addr = act_vram_offset + ( + outer * inner_count + inner) * vlen * preload_len + a_off = outer * vlen + ( + inner * stride_len * preload_len if batch > preload_len else 0) + generated_code += f"S_ADDI_INT gp{result_register}, gp0, {result_addr} \n" + generated_code += f"S_ADDI_INT gp{a_offset_register}, gp{a_actual_register}, {a_off} \n" + generated_code += ( + f"H_PREFETCH_V gp{result_register}, gp{a_offset_register}, " + f"a{activation_offset_reg}, 1, 0 \n" + ) + return generated_code + + def _emit_store_tile_isa( + self, + *, + vlen: int, + batch: int, + hidden_size: int, + alive_registers: List[int], + act_vram_offset: int, + hbm_addr_reg: int, + stride_size: Optional[int] = None, + scale_size: Optional[int] = None, + hbm_start_offset: int = 0, + store_amount: int = 4, + # PLENA TVM extension (see emit_preload_tile_isa for rationale). + hbm_start_offset_reg: Optional[int] = None, + ) -> str: + generated_code = "; Store Activation Generation\n" + + hbm_offset_reg = alive_registers[0] + set_stride_register = alive_registers[1] + vram_reg = alive_registers[2] + outer_loop_register = alive_registers[3] + inner_loop_register = alive_registers[4] + + stride_len = hidden_size if stride_size is None else int(stride_size) + scale_len = hidden_size * batch if scale_size is None else int(scale_size) + store_amount_per_hidden = math.ceil(hidden_size / vlen) + + generated_code += f"S_ADDI_INT gp{vram_reg}, gp0, {act_vram_offset}\n" + generated_code += f"S_ADDI_INT gp{hbm_offset_reg}, gp0, {scale_len}\n" + generated_code += f"C_SET_SCALE_REG gp{hbm_offset_reg}\n" + if hbm_start_offset_reg is not None: + # Dynamic base + static residual (see _emit_preload_tile_isa). + generated_code += ( + f"S_ADDI_INT gp{hbm_offset_reg}, gp{hbm_start_offset_reg}, " + f"{int(hbm_start_offset)}\n" + ) + else: + generated_code += f"S_ADDI_INT gp{hbm_offset_reg}, gp0, {int(hbm_start_offset)}\n" + + if batch == 1: + elements_per_store = vlen * store_amount + for _ in range(math.ceil(hidden_size / elements_per_store)): + generated_code += f"H_STORE_V gp{vram_reg}, gp{hbm_offset_reg}, a{hbm_addr_reg}, 0, 0\n" + generated_code += f"S_ADDI_INT gp{vram_reg}, gp{vram_reg}, {elements_per_store}\n" + generated_code += f"S_ADDI_INT gp{hbm_offset_reg}, gp{hbm_offset_reg}, {elements_per_store}\n" + return generated_code + + generated_code += f"S_ADDI_INT gp{set_stride_register}, gp0, {stride_len}\n" + generated_code += f"C_SET_STRIDE_REG gp{set_stride_register}\n" + hbm_base_reg = set_stride_register + # Compile-time unrolled twin C_LOOP. ``vram_reg`` ran across + # both loops; ``hbm_base`` was reset to ``hbm_offset + outer*vlen`` + # each outer iter and advanced inner. Bake all addresses in as + # literals — no C_LOOP, no per-iter advance. + inner_count = math.ceil(batch / store_amount) if batch > store_amount else 1 + for outer in range(store_amount_per_hidden): + for inner in range(inner_count): + vram_off = (outer * inner_count + inner) * vlen * store_amount + hbm_off = outer * vlen + ( + inner * stride_len * store_amount if batch > store_amount else 0) + generated_code += f"S_ADDI_INT gp{vram_reg}, gp0, {act_vram_offset + vram_off}\n" + generated_code += f"S_ADDI_INT gp{hbm_base_reg}, gp{hbm_offset_reg}, {hbm_off}\n" + generated_code += ( + f"H_STORE_V gp{vram_reg}, gp{hbm_base_reg}, a{hbm_addr_reg}, 1, 0\n" + ) + return generated_code + + def emit_hbm_tile_to_mram( + self, + *, + hbm_addr: int, + mram_addr: int, + hbm_offset: int = 0, + hbm_scale: Optional[int] = None, + hbm_stride: Optional[int] = None, + # PLENA TVM extension: when set, the offset is sourced from + # this GP register (caller owns it). `hbm_offset` is ignored. + hbm_offset_reg: Optional[int] = None, + ) -> None: + addr_reg = self.program.compiler.register_allocator.allocate_addr(1)[0] + gp_addr = self.program.compiler.register_allocator.allocate_gp(2) + gp_exec = self.program.compiler.register_allocator.allocate_gp(3) + gp_scale, gp_stride, gp_mram = gp_exec + scale_val = self.program.tile_elems if hbm_scale is None else int(hbm_scale) + stride_val = self.program.mlen if hbm_stride is None else int(hbm_stride) + + isa = "" + isa += preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], + available_registers=gp_addr, + addr_reg_val=[hbm_addr], + ) + isa += f"S_ADDI_INT gp{gp_scale}, gp0, {scale_val}\n" + isa += f"C_SET_SCALE_REG gp{gp_scale}\n" + isa += f"S_ADDI_INT gp{gp_stride}, gp0, {stride_val}\n" + isa += f"C_SET_STRIDE_REG gp{gp_stride}\n" + isa += f"S_ADDI_INT gp{gp_mram}, gp0, {mram_addr}\n" + if hbm_offset_reg is not None: + # Dynamic base + static residual. + isa += f"S_ADDI_INT gp{gp_scale}, gp{hbm_offset_reg}, {hbm_offset}\n" + else: + isa += f"S_ADDI_INT gp{gp_scale}, gp0, {hbm_offset}\n" + isa += f"H_PREFETCH_M gp{gp_mram}, gp{gp_scale}, a{addr_reg}, 1, 0\n" + isa += f"S_ADDI_INT gp{gp_scale}, gp0, {self.program.tile_elems}\n" + isa += f"C_SET_SCALE_REG gp{gp_scale}\n" + isa += f"S_ADDI_INT gp{gp_stride}, gp0, {self.program.mlen}\n" + isa += f"C_SET_STRIDE_REG gp{gp_stride}\n" + self.program.compiler.generated_code += isa + + self.program.compiler.register_allocator.free_gp(gp_addr) + self.program.compiler.register_allocator.free_gp(gp_exec) + self.program.compiler.register_allocator.free_addr([addr_reg]) + + def emit_load_tile_from_hbm( + self, + *, + hbm_addr: int, + vram_addr: int, + hbm_stride: Optional[int] = None, + hbm_scale_size: Optional[int] = None, + hbm_start_offset: int = 0, + # PLENA TVM extension: when set, the runtime-computed offset + # comes from this GP register (caller owns it; emitter just + # reads). `hbm_start_offset` is ignored in that case. + hbm_start_offset_reg: Optional[int] = None, + ) -> None: + ra = self.program.compiler.register_allocator + addr_reg = ra.allocate_addr(1)[0] + # Need 1 (addr-init scratch) + 5 (preload scratch). Use + # spill_borrow so the allocator can move long-lived outer GPs + # (loop counters / indices) to IntRAM temporarily. The + # caller-supplied ``hbm_start_offset_reg`` is read inside the + # emit body, so protect it from being spilled. + protect = [hbm_start_offset_reg] if hbm_start_offset_reg is not None else [] + borrowed, token = ra.spill_borrow( + 6, compiler=self.program.compiler, protect=protect, + ) + gp_addr = [borrowed[0]] + gp_preload = borrowed[1:6] + + isa = "" + isa += preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], + available_registers=gp_addr, + addr_reg_val=[int(hbm_addr)], + ) + isa += reset_reg_asm(alive_registers=gp_preload) + isa += self._emit_preload_tile_isa( + vlen=self.program.mlen, + # Rows per H_PREFETCH_V instruction = the emulator's + # PREFETCH_V_AMOUNT. Using blen here emitted AMOUNT/blen× + # too many prefetches with wrong strides (data landed in + # the wrong VRAM rows -> downstream all-zero). + preload_len=self.program.v_prefetch_amount, + batch=self.program.mlen, + hidden_size=self.program.mlen, + act_vram_offset=vram_addr, + alive_registers=gp_preload, + activation_offset_reg=addr_reg, + stride_size=self.program.mlen if hbm_stride is None else int(hbm_stride), + scale_size=self.program.tile_elems if hbm_scale_size is None else int(hbm_scale_size), + hbm_start_offset=int(hbm_start_offset), + hbm_start_offset_reg=hbm_start_offset_reg, + ) + self.program.compiler.generated_code += isa + + ra.spill_return(token, compiler=self.program.compiler) + ra.free_addr([addr_reg]) + + def emit_store_tile_to_hbm( + self, + *, + vram_addr: int, + hbm_addr: int, + hbm_stride: Optional[int] = None, + hbm_scale_size: Optional[int] = None, + hbm_start_offset: int = 0, + # PLENA TVM extension; see emit_load_tile_from_hbm. + hbm_start_offset_reg: Optional[int] = None, + ) -> None: + ra = self.program.compiler.register_allocator + addr_reg = ra.allocate_addr(1)[0] + protect = [hbm_start_offset_reg] if hbm_start_offset_reg is not None else [] + borrowed, token = ra.spill_borrow( + 6, compiler=self.program.compiler, protect=protect, + ) + gp_addr = [borrowed[0]] + gp_store = borrowed[1:6] + + isa = "" + isa += preload_addr_reg_asm( + addr_reg_to_set=[addr_reg], + available_registers=gp_addr, + addr_reg_val=[int(hbm_addr)], + ) + isa += self._emit_store_tile_isa( + vlen=self.program.mlen, + batch=self.program.mlen, + hidden_size=self.program.mlen, + alive_registers=gp_store, + act_vram_offset=vram_addr, + hbm_addr_reg=addr_reg, + stride_size=self.program.mlen if hbm_stride is None else int(hbm_stride), + scale_size=self.program.tile_elems if hbm_scale_size is None else int(hbm_scale_size), + hbm_start_offset=int(hbm_start_offset), + # Rows per H_STORE_V instruction = the emulator's + # STORE_V_AMOUNT (see preload_len note above). + store_amount=self.program.v_writeback_amount, + hbm_start_offset_reg=hbm_start_offset_reg, + ) + self.program.compiler.generated_code += isa + + ra.spill_return(token, compiler=self.program.compiler) + ra.free_addr([addr_reg]) + + def emit_zero_vram_tile(self, vram_addr: int, num_rows: Optional[int] = None) -> None: + # `num_rows` is how many MLEN-wide rows to zero. Defaults to MLEN + # for legacy callers that always zero a full MLEN*MLEN tile. + # Buffers smaller than that (e.g. a (1, MLEN) accumulator) MUST + # pass the actual row count or the loop will write past the + # buffer's end into adjacent VRAM (silent corruption of whatever + # follows in the address map). + loop_count = self.program.mlen if num_rows is None else int(num_rows) + if loop_count < 1: + raise ValueError(f"num_rows must be >= 1, got {loop_count}") + gp_regs = self.program.compiler.register_allocator.allocate_gp(1) + (gp,) = gp_regs + # Compile-time unrolled: emit one V_MUL_VF per row with the row + # address baked in as a literal. No C_LOOP, no per-iter + # S_ADDI_INT address advance — kills the gp loop-maintenance + # overhead the hardware loop incurs. + lines = [f"; zero tile vram[{vram_addr}] rows={loop_count} (unrolled)"] + for i in range(loop_count): + row_addr = vram_addr + i * self.program.mlen + lines.append(f"S_ADDI_INT gp{gp}, gp0, {row_addr}") + lines.append(f"V_MUL_VF gp{gp}, gp{gp}, f0, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_map_v_fp_tile( + self, + *, + vram_addr: int, + fpram_addr: int, + row_count: int, + row_width: int, + task_id: str = "map_v_fp_tile", + ) -> None: + if row_count <= 0 or row_width <= 0: + raise ValueError(f"emit_map_v_fp_tile expects positive row_count/row_width, got {row_count}/{row_width}") + if row_width != self.program.mlen: + raise ValueError( + f"emit_map_v_fp_tile currently requires row_width == mlen == {self.program.mlen}, got {row_width}" + ) + gp_regs = self.program.compiler.register_allocator.allocate_gp(2) + gp_dst, gp_src = gp_regs + # Compile-time unrolled: row addresses baked in as literals, one + # S_MAP_V_FP per row, no C_LOOP / per-iter address advance. + lines = [f"; map fp tile task {task_id} fpram[{fpram_addr}] -> " + f"vram[{vram_addr}] (unrolled, {row_count} rows)"] + for i in range(row_count): + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {vram_addr + i * row_width}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {fpram_addr + i * row_width}") + lines.append(f"S_MAP_V_FP gp{gp_dst}, gp{gp_src}, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_map_fp_v_tile( + self, + *, + fpram_addr: int, + vram_addr: int, + row_count: int, + row_width: int, + task_id: str = "map_fp_v_tile", + ) -> None: + if row_count <= 0 or row_width <= 0: + raise ValueError(f"emit_map_fp_v_tile expects positive row_count/row_width, got {row_count}/{row_width}") + if row_width != self.program.mlen: + raise ValueError( + f"emit_map_fp_v_tile currently requires row_width == mlen == {self.program.mlen}, got {row_width}" + ) + gp_regs = self.program.compiler.register_allocator.allocate_gp(2) + gp_dst, gp_src = gp_regs + # Compile-time unrolled: row addresses baked in as literals, one + # S_MAP_FP_V per row, no C_LOOP / per-iter address advance. + lines = [f"; map fp tile task {task_id} vram[{vram_addr}] -> " + f"fpram[{fpram_addr}] (unrolled, {row_count} rows)"] + for i in range(row_count): + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {fpram_addr + i * row_width}") + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {vram_addr + i * row_width}") + lines.append(f"S_MAP_FP_V gp{gp_dst}, gp{gp_src}, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_btmm( + self, + *, + lhs_packed_vram_addr: int, + rhs_mram_addr: int, + task_id: str = "btmm", + ) -> None: + gp_regs = self.program.compiler.register_allocator.allocate_gp(2) + gp_mram_base, gp_lhs_base = gp_regs + lines = [ + ( + f"; btmm task {task_id} lhs_packed=vram[{lhs_packed_vram_addr}] " + f"rhs_mram={rhs_mram_addr} lanes={self.program.btmm_lane_count} head_width={self.program.btmm_hlen}" + ), + f"S_ADDI_INT gp{gp_mram_base}, gp0, {rhs_mram_addr}", + f"S_ADDI_INT gp{gp_lhs_base}, gp0, {lhs_packed_vram_addr}", + f"M_BTMM gp0, gp{gp_mram_base}, gp{gp_lhs_base}", + ] + self.program.compiler.generated_code += "\n".join(lines) + "\n" + self.program.compiler.register_allocator.free_gp(gp_regs) + + def emit_btmm_wo( + self, + *, + base_addr: int, + tile_count: int, + task_id: str = "btmm_wo", + ) -> None: + gp_out = self.program.compiler.register_allocator.allocate_gp(1)[0] + lines = [ + ( + f"; btmm write-only task {task_id} out=vram[{base_addr}] " + f"tiles={tile_count} lanes={self.program.btmm_lane_count} head_width={self.program.btmm_hlen}" + ), + f"S_ADDI_INT gp{gp_out}, gp0, {base_addr}", + f"M_BMM_WO gp{gp_out}, 0", + ] + self.program.compiler.generated_code += "\n".join(lines) + "\n" + self.program.compiler.register_allocator.free_gp([gp_out]) + + def emit_mv( + self, + *, + lhs_vram_addr, + rhs_mram_addr, + dst_vram_addr, + task_id: str = "mv", + lhs_offset_reg=None, + rhs_offset_reg=None, + dst_offset_reg=None, + n: int | None = None, + ) -> None: + """Per-head M_MV + M_MV_WO (single-lane matrix-vector). + + Each (M_MV, M_MV_WO) pair processes BLEN-wide column blocks: M_MV + accumulates ``vec[mlen] @ mat[mlen, blen]`` into the systolic array + first row, M_MV_WO drains those blen elements to VRAM. To cover + ``n`` columns total (defaults to ``btmm_hlen`` -- one full head), + we loop ``n / blen`` times, advancing both the matrix column + offset and the destination offset by blen each iteration. + + Mirrors emit_matmul's blen-loop (used by plena.matmul / M_MM) but + with the LHS being a single row and the writeback being M_MV_WO. + """ + if n is None: + n = int(self.program.btmm_hlen) + blen = int(self.program.blen) + if n % blen != 0: + raise ValueError( + f"emit_mv: column extent n={n} must be a multiple of blen={blen}" + ) + tiles = n // blen + + gp_regs = self.program.compiler.register_allocator.allocate_gp(3) + gp_v, gp_m, gp_o = gp_regs + lines = [ + ( + f"; mv task {task_id} v=vram[{lhs_vram_addr}]" + f" m=mram[{rhs_mram_addr}] dst=vram[{dst_vram_addr}]" + f" tiles={tiles} blen={blen}" + ), + # Set up vector base (lhs). + f"S_ADDI_INT gp{gp_v}, gp0, {lhs_vram_addr}", + ] + if lhs_offset_reg is not None: + lines.append(f"S_ADD_INT gp{gp_v}, gp{gp_v}, gp{lhs_offset_reg}") + + # Each iteration walks the matrix and dst by blen elements. Set + # up the per-iteration starting m_addr / dst_addr (head base) once, + # then bump them by blen inside the loop body. + lines.append(f"S_ADDI_INT gp{gp_m}, gp0, {rhs_mram_addr}") + if rhs_offset_reg is not None: + lines.append(f"S_ADD_INT gp{gp_m}, gp{gp_m}, gp{rhs_offset_reg}") + lines.append(f"S_ADDI_INT gp{gp_o}, gp0, {dst_vram_addr}") + if dst_offset_reg is not None: + lines.append(f"S_ADD_INT gp{gp_o}, gp{gp_o}, gp{dst_offset_reg}") + + for t in range(tiles): + lines.append(f"M_MV gp0, gp{gp_m}, gp{gp_v}") + lines.append(f"M_MV_WO gp{gp_o}, 0") + if t < tiles - 1: + lines.append(f"S_ADDI_INT gp{gp_m}, gp{gp_m}, {blen}") + lines.append(f"S_ADDI_INT gp{gp_o}, gp{gp_o}, {blen}") + + self.program.compiler.generated_code += "\n".join(lines) + "\n" + self.program.compiler.register_allocator.free_gp(gp_regs) + + def emit_btmv( + self, + *, + lhs_packed_vram_addr: int, + rhs_mram_addr: int, + task_id: str = "btmv", + ) -> None: + """Lane-fused vector × matrix^T (M_BTMV). + + Mirrors emit_btmm — same MRAM/VRAM register setup, same operand + order, just the M_BTMM opcode swapped for M_BTMV. The hardware + consumes a 1-row vector LHS instead of an mlen-row matrix LHS. + """ + gp_regs = self.program.compiler.register_allocator.allocate_gp(2) + gp_mram_base, gp_lhs_base = gp_regs + lines = [ + ( + f"; btmv task {task_id} lhs_packed=vram[{lhs_packed_vram_addr}] " + f"rhs_mram={rhs_mram_addr} lanes={self.program.btmm_lane_count} head_width={self.program.btmm_hlen}" + ), + f"S_ADDI_INT gp{gp_mram_base}, gp0, {rhs_mram_addr}", + f"S_ADDI_INT gp{gp_lhs_base}, gp0, {lhs_packed_vram_addr}", + f"M_BTMV gp0, gp{gp_mram_base}, gp{gp_lhs_base}", + ] + self.program.compiler.generated_code += "\n".join(lines) + "\n" + self.program.compiler.register_allocator.free_gp(gp_regs) + + def emit_bmv_wo( + self, + *, + base_addr: int, + task_id: str = "bmv_wo", + ) -> None: + """Drain accumulator from systolic-array first row to VRAM + (M_BMV_WO). Writes lane_count MLEN-wide rows starting at base_addr. + """ + gp_out = self.program.compiler.register_allocator.allocate_gp(1)[0] + lines = [ + ( + f"; bmv write-only task {task_id} out=vram[{base_addr}] " + f"lanes={self.program.btmm_lane_count} head_width={self.program.btmm_hlen}" + ), + f"S_ADDI_INT gp{gp_out}, gp0, {base_addr}", + f"M_BMV_WO gp{gp_out}, 0", + ] + self.program.compiler.generated_code += "\n".join(lines) + "\n" + self.program.compiler.register_allocator.free_gp([gp_out]) + + def emit_matmul( + self, + *, + lhs_vram_addrs: Sequence[int], + rhs_mram_addrs: Sequence[int], + dst_vram_addr: int, + task_id: str = "matmul", + zero_dst: bool = False, + ) -> None: + if len(lhs_vram_addrs) != len(rhs_mram_addrs): + raise ValueError("lhs_vram_addrs and rhs_mram_addrs must have equal lengths") + if zero_dst: + self.emit_zero_vram_tile(dst_vram_addr) + + gp_regs = self.program.compiler.register_allocator.allocate_gp(5) + gp_act, gp_mat, gp_out, gp_stride, gp_loop = gp_regs + tiles_per_mlen = self.program.mlen // self.program.blen + lines = [f"; matmul task {task_id}"] + lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") + lhs_prog = self.program._arith_progression([int(addr) for addr in lhs_vram_addrs]) + rhs_prog = self.program._arith_progression([int(addr) for addr in rhs_mram_addrs]) + + for oc in range(tiles_per_mlen): + for orow in range(tiles_per_mlen): + if lhs_prog is not None and rhs_prog is not None: + lhs_start, pair_count, lhs_step = lhs_prog + rhs_start, _, rhs_step = rhs_prog + act_addr = lhs_start + orow * self.program.blen * self.program.mlen + mat_addr = rhs_start + oc * self.program.blen + # Compile-time unrolled: one M_MM per (act,mat) pair + # with both addresses baked in as literals. No + # C_LOOP, no per-iter S_ADDI_INT advance. + for p in range(pair_count): + lines.append( + f"S_ADDI_INT gp{gp_act}, gp0, {act_addr + p * lhs_step}") + lines.append( + f"S_ADDI_INT gp{gp_mat}, gp0, {mat_addr + p * rhs_step}") + lines.append(f"M_MM 0, gp{gp_mat}, gp{gp_act}") + else: + for lhs_addr, rhs_addr in zip(lhs_vram_addrs, rhs_mram_addrs): + act_addr = lhs_addr + orow * self.program.blen * self.program.mlen + mat_addr = rhs_addr + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_act}, gp0, {act_addr}") + lines.append(f"S_ADDI_INT gp{gp_mat}, gp0, {mat_addr}") + lines.append(f"M_MM 0, gp{gp_mat}, gp{gp_act}") + out_addr = dst_vram_addr + orow * self.program.blen * self.program.mlen + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_out}, gp0, {out_addr}") + lines.append(f"M_MM_WO gp{gp_out}, gp0, 0") + + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_matmul_single_tile_hwloop( + self, + *, + lhs_vram_addr: int, + rhs_mram_addr: int, + dst_vram_addr: int, + task_id: str = "matmul_single_hwloop", + ) -> None: + """Single-tile (mlen*mlen) MM emitted with hardware loops over the + blen-tiled output (oc, orow), instead of the Python-unrolled form + used by `emit_matmul`. + + Generates O((mlen/blen)^2) M_MM/M_MM_WO pairs *dynamically* but + only ~15 lines of *static* ISA — vs. ~7*256 ≈ 1800 lines for the + unrolled `emit_matmul` on a 64/4 (mlen/blen) configuration. + Identical dynamic instruction count, so loop-instruction caps + in the emulator behave the same. + + Loop structure (mirrors sub_matrix_manager.vram_sub_projection_asm + with num_hidden_blocks == 1, so the innermost accumulation loop + collapses to a single M_MM): + + for oc in tiles_per_mlen: # output blen-cols + for orow in tiles_per_mlen: # output blen-rows + M_MM 0, gp_mat, gp_act + M_MM_WO gp_result, gp0, 0 + """ + ra = self.program.compiler.register_allocator + # Single allocate_gp call -> single free_gp at the end. The loop + # counters (gp_loop_outer, gp_loop_middle) stay marked in-use for + # the entire emit, so nested ISAEmitter calls (none today, but + # future-proof against a body sub-emit) cannot collide with them. + gp_regs = ra.allocate_gp(6) + (gp_act_row_base, gp_mat_col_base, gp_result_col_base, gp_result, + gp_loop_outer, gp_loop_middle) = gp_regs + + tiles_per_mlen = self.program.mlen // self.program.blen + output_row_stride = self.program.blen * self.program.mlen + blen = self.program.blen + + # Single accumulation pair -> drop the inner accum C_LOOP and the + # gp_act/gp_mat copies that the multi-pair runtime version needs. + # M_MM reads its operand regs and does not mutate them, so we + # pass gp_act_row_base / gp_mat_col_base directly. gp_act_row_base + # is advanced inside the orow loop (output_row_stride per iter) + # and re-loaded with lhs_vram_addr at the top of each oc iter. + # Compile-time unrolled twin C_LOOP over (oc, orow). Original: + # oc loop: mat_col_base += blen, result_col_base += blen + # orow loop: act_row_base += output_row_stride, + # result += output_row_stride + # Every address is now a literal — no C_LOOP, no per-iter advance. + lines = [ + f"; matmul (single-tile, unrolled) task {task_id} " + f"lhs=vram[{lhs_vram_addr}] rhs=mram[{rhs_mram_addr}] " + f"dst=vram[{dst_vram_addr}] " + f"regs: act=gp{gp_act_row_base} mat=gp{gp_mat_col_base} " + f"result=gp{gp_result}", + ] + for oc in range(tiles_per_mlen): + mat_col = rhs_mram_addr + oc * blen + result_col = dst_vram_addr + oc * blen + for orow in range(tiles_per_mlen): + act_row = lhs_vram_addr + orow * output_row_stride + result = result_col + orow * output_row_stride + lines.append(f"S_ADDI_INT gp{gp_mat_col_base}, gp0, {mat_col}") + lines.append(f"S_ADDI_INT gp{gp_act_row_base}, gp0, {act_row}") + lines.append(f"S_ADDI_INT gp{gp_result}, gp0, {result}") + lines.append(f"M_MM 0, gp{gp_mat_col_base}, gp{gp_act_row_base}") + lines.append(f"M_MM_WO gp{gp_result}, gp0, 0") + self.program.compiler.generated_code += "\n".join(lines) + "\n" + ra.free_gp(gp_regs) + + def emit_slot_matmul( + self, + *, + lhs_vram_addr: int, + lhs_vram_addr_reg: Optional[int] = None, + rhs_mram_addr: int, + rhs_col_offset: int = 0, + rhs_col_offset_reg: Optional[int] = None, + dst_vram_addr: int, + dst_col_offset: int = 0, + dst_col_offset_reg: Optional[int] = None, + col_count: int, + task_id: str = "slot_matmul", + zero_dst: bool = False, + ) -> None: + if col_count <= 0: + raise ValueError("emit_slot_matmul requires one positive col_count") + if col_count % self.program.blen != 0: + raise ValueError( + f"emit_slot_matmul requires col_count divisible by blen={self.program.blen}, got {col_count}" + ) + if zero_dst: + self.emit_zero_vram_tile(dst_vram_addr) + + gp_regs = self.program.compiler.register_allocator.allocate_gp(5) + gp_act, gp_mat, gp_out, gp_stride, gp_loop = gp_regs + tiles_per_mlen = self.program.mlen // self.program.blen + tiles_per_slot = col_count // self.program.blen + lines = [ + f"; slot matmul task {task_id}" + f" rhs_col_offset=" + f"{'gp' + str(rhs_col_offset_reg) if rhs_col_offset_reg is not None else rhs_col_offset}" + f" dst_col_offset=" + f"{'gp' + str(dst_col_offset_reg) if dst_col_offset_reg is not None else dst_col_offset}" + ] + lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") + + for oc in range(tiles_per_slot): + if lhs_vram_addr_reg is not None: + # base = register-held lhs address (already includes any + # dynamic lhs_row_offset). Reset gp_act to base each oc tile. + lines.append(f"S_ADDI_INT gp{gp_act}, gp{lhs_vram_addr_reg}, 0") + else: + act_addr = lhs_vram_addr + lines.append(f"S_ADDI_INT gp{gp_act}, gp0, {act_addr}") + if rhs_col_offset_reg is not None: + lines.append(f"S_ADDI_INT gp{gp_mat}, gp{rhs_col_offset_reg}, {rhs_mram_addr + oc * self.program.blen}") + else: + mat_addr = rhs_mram_addr + rhs_col_offset + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_mat}, gp0, {mat_addr}") + if dst_col_offset_reg is not None: + lines.append(f"S_ADDI_INT gp{gp_out}, gp{dst_col_offset_reg}, {dst_vram_addr + oc * self.program.blen}") + else: + out_addr = dst_vram_addr + dst_col_offset + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_out}, gp0, {out_addr}") + # Compile-time unrolled inner C_LOOP. gp_act / gp_out were + # set above (possibly off a register base); for the first + # iter reuse them as-is, for the rest re-derive off the same + # base + literal offset. gp_act may sit on lhs_vram_addr_reg, + # so advance relative to its current value with S_ADDI_INT. + row_stride = self.program.blen * self.program.mlen + for t in range(tiles_per_mlen): + if t > 0: + lines.append( + f"S_ADDI_INT gp{gp_act}, gp{gp_act}, {row_stride}") + lines.append( + f"S_ADDI_INT gp{gp_out}, gp{gp_out}, {row_stride}") + lines.append(f"M_MM 0, gp{gp_mat}, gp{gp_act}") + lines.append(f"M_MM_WO gp{gp_out}, gp0, 0") + + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_matmul_narrow_tile_hwloop( + self, + *, + lhs_vram_addr: int, + rhs_mram_addr: int, + dst_vram_addr: int, + hlen: int, + rhs_col_offset: int = 0, + dst_col_offset: int = 0, + dst_row_stride: Optional[int] = None, + task_id: str = "matmul_narrow_hwloop", + zero_dst: bool = False, + ) -> None: + """Emit `mlen x mlen @ mlen x hlen` via the regular M_MM path.""" + if hlen <= 0: + raise ValueError("emit_matmul_narrow_tile_hwloop requires positive hlen") + if hlen > self.program.mlen: + raise ValueError( + f"emit_matmul_narrow_tile_hwloop requires hlen <= mlen={self.program.mlen}, got {hlen}" + ) + if hlen % self.program.blen != 0: + raise ValueError( + f"emit_matmul_narrow_tile_hwloop requires hlen divisible by blen={self.program.blen}, got {hlen}" + ) + if dst_row_stride is None: + dst_row_stride = int(hlen) + if dst_row_stride < hlen: + raise ValueError( + f"emit_matmul_narrow_tile_hwloop requires dst_row_stride >= hlen ({hlen}), got {dst_row_stride}" + ) + if zero_dst: + self.emit_zero_vram_tile(dst_vram_addr) + + ra = self.program.compiler.register_allocator + gp_regs = ra.allocate_gp(5) + gp_act, gp_mat, gp_out, gp_stride, gp_loop = gp_regs + tiles_per_mlen = self.program.mlen // self.program.blen + tiles_per_slot = hlen // self.program.blen + output_row_stride = self.program.blen * int(dst_row_stride) + lines = [ + f"; narrow matmul task {task_id} lhs=vram[{lhs_vram_addr}] " + f"rhs=mram[{rhs_mram_addr}] rhs_col_offset={rhs_col_offset} " + f"dst=vram[{dst_vram_addr}] dst_col_offset={dst_col_offset} " + f"hlen={hlen} dst_row_stride={dst_row_stride}" + ] + lines.append(f"S_ADDI_INT gp{gp_stride}, gp0, 1") + + act_row_stride = self.program.blen * self.program.mlen + for oc in range(tiles_per_slot): + act_addr = lhs_vram_addr + mat_addr = rhs_mram_addr + rhs_col_offset + oc * self.program.blen + out_addr = dst_vram_addr + dst_col_offset + oc * self.program.blen + lines.append(f"S_ADDI_INT gp{gp_mat}, gp0, {mat_addr}") + # Compile-time unrolled inner C_LOOP — act/out addresses baked + # in as literals (act steps by blen*mlen, out by the dst row + # stride). No C_LOOP, no per-iter advance. + for t in range(tiles_per_mlen): + lines.append( + f"S_ADDI_INT gp{gp_act}, gp0, {act_addr + t * act_row_stride}") + lines.append( + f"S_ADDI_INT gp{gp_out}, gp0, {out_addr + t * output_row_stride}") + lines.append(f"M_MM 0, gp{gp_mat}, gp{gp_act}") + lines.append(f"M_MM_WO gp{gp_out}, gp0, 0") + + self.program.compiler.generated_code += "\n".join(lines) + "\n" + ra.free_gp(gp_regs) + + def emit_matmul_general( + self, + *, + M_tiles: int, + K_tiles: int, + N: int, + lhs_vram_base: int, + lhs_offset: int = 0, + lhs_offset_reg: Optional[int] = None, + lhs_m_tile_stride: Optional[int] = None, + lhs_k_tile_stride: Optional[int] = None, + rhs_mram_base: int, + rhs_offset: int = 0, + rhs_offset_reg: Optional[int] = None, + rhs_k_tile_stride: Optional[int] = None, + rhs_n_mlen_tile_stride: Optional[int] = None, + dst_vram_base: int, + dst_offset: int = 0, + dst_offset_reg: Optional[int] = None, + dst_m_tile_stride: Optional[int] = None, + dst_row_stride: Optional[int] = None, + task_id: str = "matmul", + scratch_regs: Optional[List[int]] = None, + transpose_b: bool = False, + unroll_loops: bool = False, + ) -> None: + """Unified `(M, K) @ (K, N) -> (M, N)` matmul. + + K is folded into the systolic-array accumulator: each output + BLEN×BLEN sub-tile is produced by K_tiles `M_MM` issuances followed + by one `M_MM_WO`. No software scratch / v_add is needed for K + accumulation. + + When ``transpose_b=True``, B is expected in MRAM as ``(N, K)`` + row-major (matches the nn.Linear weight convention). The inner + op switches from ``M_MM`` to ``M_TMM`` which transposes the + (mlen, mlen) MRAM tile on the fly, and the per-N column + sub-tile step changes from ``blen`` (columns of (K, N)) to + ``blen * mlen`` (rows of (N, K)) — sim enforces the latter via + ``mat_offset.assert_multiple_of(mlen)`` inside ``M_TMM``. + + Shape constraints: + M : multiple of mlen, M_tiles = M / mlen + K : multiple of mlen, K_tiles = K / mlen + N : multiple of hlen (hlen = btmm_hlen on the shim) + + N may exceed mlen — the emitter walks B in (K_tiles × N_mlen_tiles) + mlen-wide blocks, where N_mlen_tiles = ceil(N / mlen). The trailing + N-mlen block is allowed to carry < mlen valid columns (down to the + nearest hlen boundary). + + Layout assumptions (defaults match a packed tile-grid layout): + A in VRAM : (M_tiles × K_tiles) grid of (mlen, mlen) tiles, + packed: A_tile(m, k) = base + m*K_tiles*mlen² + k*mlen². + B in MRAM : (K_tiles × N_mlen_tiles) grid of (mlen, mlen) tiles, + packed K-major: + B_tile(k, nm) = base + k*N_mlen_tiles*mlen² + nm*mlen². + C in VRAM : row-major (M, N) with `dst_row_stride` elements + between consecutive output rows (defaults to N). + M-tile spacing defaults to `mlen * dst_row_stride`. + + Offsets: + `lhs_offset` / `rhs_offset` / `dst_offset` are static element + offsets added to the corresponding base addresses. Useful when + A/B/C are sub-regions of larger packed buffers (mm_slot pattern). + For each side, pass the ``*_offset_reg`` form instead to use a + dynamic (PrimExpr-derived) offset already materialised to a gp + register; when ``*_offset_reg`` is set, the matching static + ``*_offset`` is ignored and the per-iteration pointer is formed + via ``S_ADDI_INT gp_dst, gp{*_offset_reg}, ``. + """ + mlen = self.program.mlen + blen = self.program.blen + hlen = int(self.program.btmm_hlen) + if M_tiles <= 0 or K_tiles <= 0 or N <= 0: + raise ValueError(f"M_tiles, K_tiles, N must be positive; got {M_tiles}, {K_tiles}, {N}") + if N % hlen != 0: + raise ValueError(f"N must be divisible by hlen={hlen}; got N={N}") + + N_mlen_tiles = (N + mlen - 1) // mlen + + if lhs_k_tile_stride is None: + lhs_k_tile_stride = mlen * mlen + if lhs_m_tile_stride is None: + lhs_m_tile_stride = K_tiles * mlen * mlen + # When ``transpose_b`` is set, B is laid out as ``(N, K)`` — + # tiles are packed N-major (one full K-row of mlen-tiles per + # N-mlen step). When unset, B is ``(K, N)`` and tiles are + # packed K-major. The inner-tile layout stays row-major in both + # cases; M_TMM transposes the (mlen, mlen) tile on the fly. + if rhs_n_mlen_tile_stride is None: + if transpose_b: + rhs_n_mlen_tile_stride = K_tiles * mlen * mlen + else: + rhs_n_mlen_tile_stride = mlen * mlen + if rhs_k_tile_stride is None: + if transpose_b: + rhs_k_tile_stride = mlen * mlen + else: + rhs_k_tile_stride = N_mlen_tiles * mlen * mlen + if dst_row_stride is None: + dst_row_stride = N + if dst_m_tile_stride is None: + dst_m_tile_stride = mlen * int(dst_row_stride) + + tiles_per_mlen = mlen // blen + a_orow_step = blen * mlen + # M_MM_WO writes 4 rows (i=0..blen-1) at vram[vec_base + i*mlen] + # — physical row stride is always mlen, regardless of how dense + # the dst's logical N maps inside each mlen-row (e.g. N=16 with + # 4 lanes packed: 4 cols per lane stored together inside one + # mlen-row). The outer orow advance must therefore step by + # ``blen * mlen`` to jump 4 physical mlen-rows; the previous + # ``blen * dst_row_stride`` formula collapsed to a 1-mlen-row + # step for narrow N (≤ mlen) and made the kernel re-write the + # same mlen-rows repeatedly. + c_orow_step = blen * mlen + + ra = self.program.compiler.register_allocator + # Caller can pre-allocate the 7 scratch GPs (and pin them) so they're + # disjoint from any offset registers it materialised. When caller + # passes them in we don't free here either — caller owns the lifetime. + if scratch_regs is not None: + if len(scratch_regs) != 7: + raise ValueError( + f"emit_matmul_general expects 7 scratch_regs, got {len(scratch_regs)}" + ) + gp_regs = list(scratch_regs) + caller_owns_scratch = True + else: + gp_regs = ra.allocate_gp(7) + caller_owns_scratch = False + (gp_act_orow, gp_out_orow, gp_act, gp_mat, gp_out, + gp_loop_orow, gp_loop_k) = gp_regs + + lines = [ + f"; matmul (general) task {task_id} " + f"M={M_tiles * mlen} K={K_tiles * mlen} N={N} " + f"(M_tiles={M_tiles} K_tiles={K_tiles} N_mlen_tiles={N_mlen_tiles})" + ] + + for m in range(M_tiles): + # Static-residual addresses (everything that doesn't depend on + # a dynamic offset register). When the matching `*_offset_reg` + # is set we issue `S_ADDI_INT gp_X, gp{reg}, ` to fold + # the runtime offset in; otherwise we just load the absolute + # static value. + lhs_static_full = int(lhs_vram_base) + int(lhs_offset) + m * int(lhs_m_tile_stride) + lhs_static_dyn = int(lhs_vram_base) + m * int(lhs_m_tile_stride) + dst_m_base_static_full = int(dst_vram_base) + int(dst_offset) + m * int(dst_m_tile_stride) + dst_m_base_static_dyn = int(dst_vram_base) + m * int(dst_m_tile_stride) + for n_mlen in range(N_mlen_tiles): + rhs_n_mlen_static_full = ( + int(rhs_mram_base) + int(rhs_offset) + + n_mlen * int(rhs_n_mlen_tile_stride) + ) + rhs_n_mlen_static_dyn = ( + int(rhs_mram_base) + n_mlen * int(rhs_n_mlen_tile_stride) + ) + cols_here = min(mlen, N - n_mlen * mlen) + tiles_per_n_mlen = cols_here // blen + # Per-oc B offset within the current (mlen, mlen) tile: + # M_MM picks (mlen, blen) columns at byte stride blen. + # M_TMM picks blen ROWS (transposed -> the same blen + # columns of B^T) at byte stride mlen*blen — sim asserts + # ``mat_offset / mlen ∈ [0, mlen)`` after the inner + # ``assert_multiple_of(mlen)``, so the per-row scale is + # exactly mlen. + oc_b_step = blen * mlen if transpose_b else blen + # Matmul opcode: M_TMM transposes the (mlen, mlen) MRAM + # tile on the fly; its (rs1, rs2) order is also swapped + # vs M_MM (rs1 = vram_lhs, rs2 = mram_rhs). + mm_opcode = "M_TMM" if transpose_b else "M_MM" + for oc in range(tiles_per_n_mlen): + dst_col = n_mlen * mlen + oc * blen + if lhs_offset_reg is not None: + lines.append( + f"S_ADDI_INT gp{gp_act_orow}, gp{lhs_offset_reg}, " + f"{lhs_static_dyn}" + ) + else: + lines.append(f"S_ADDI_INT gp{gp_act_orow}, gp0, {lhs_static_full}") + if dst_offset_reg is not None: + lines.append( + f"S_ADDI_INT gp{gp_out_orow}, gp{dst_offset_reg}, " + f"{dst_m_base_static_dyn + dst_col}" + ) + else: + lines.append( + f"S_ADDI_INT gp{gp_out_orow}, gp0, " + f"{dst_m_base_static_full + dst_col}" + ) + + if unroll_loops: + # Fully unrolled body: emit ``tiles_per_mlen`` + # copies of the (K_tiles M_MM + M_MM_WO) cell, + # each with its own static lhs_act / mat base. + # No C_LOOP nesting — diagnostic mode for the + # debugger to read straight through. + for orow in range(tiles_per_mlen): + act_static = lhs_static_full + orow * a_orow_step + dst_static = dst_m_base_static_full + dst_col + orow * c_orow_step + if lhs_offset_reg is not None: + act_dyn = lhs_static_dyn + orow * a_orow_step + lines.append( + f"S_ADDI_INT gp{gp_act}, gp{lhs_offset_reg}, " + f"{act_dyn}" + ) + else: + lines.append( + f"S_ADDI_INT gp{gp_act}, gp0, {act_static}" + ) + for k in range(K_tiles): + act_k = act_static + k * int(lhs_k_tile_stride) - (orow * a_orow_step + lhs_static_full - lhs_static_full) + # Recompute act/mat per k explicitly so + # there is no incremental S_ADDI between + # M_MMs (matches unroll-only style). + if k > 0: + if lhs_offset_reg is not None: + lines.append( + f"S_ADDI_INT gp{gp_act}, " + f"gp{lhs_offset_reg}, " + f"{lhs_static_dyn + orow * a_orow_step + k * int(lhs_k_tile_stride)}" + ) + else: + lines.append( + f"S_ADDI_INT gp{gp_act}, gp0, " + f"{act_static + k * int(lhs_k_tile_stride)}" + ) + mat_static = ( + rhs_n_mlen_static_full + + oc * oc_b_step + + k * int(rhs_k_tile_stride) + ) + if rhs_offset_reg is not None: + mat_dyn = ( + rhs_n_mlen_static_dyn + + oc * oc_b_step + + k * int(rhs_k_tile_stride) + ) + lines.append( + f"S_ADDI_INT gp{gp_mat}, " + f"gp{rhs_offset_reg}, {mat_dyn}" + ) + else: + lines.append( + f"S_ADDI_INT gp{gp_mat}, gp0, {mat_static}" + ) + if transpose_b: + lines.append( + f"M_TMM 0, gp{gp_act}, gp{gp_mat}" + ) + else: + lines.append( + f"M_MM 0, gp{gp_mat}, gp{gp_act}" + ) + if dst_offset_reg is not None: + lines.append( + f"S_ADDI_INT gp{gp_out_orow}, " + f"gp{dst_offset_reg}, " + f"{dst_m_base_static_dyn + dst_col + orow * c_orow_step}" + ) + else: + lines.append( + f"S_ADDI_INT gp{gp_out_orow}, gp0, " + f"{dst_static}" + ) + lines.append( + f"M_MM_WO gp{gp_out_orow}, gp0, 0" + ) + continue + + lines.append(f"C_LOOP_START gp{gp_loop_orow}, {tiles_per_mlen}") + lines.append(f"S_ADDI_INT gp{gp_act}, gp{gp_act_orow}, 0") + if rhs_offset_reg is not None: + lines.append( + f"S_ADDI_INT gp{gp_mat}, gp{rhs_offset_reg}, " + f"{rhs_n_mlen_static_dyn + oc * oc_b_step}" + ) + else: + lines.append( + f"S_ADDI_INT gp{gp_mat}, gp0, " + f"{rhs_n_mlen_static_full + oc * oc_b_step}" + ) + lines.append(f"C_LOOP_START gp{gp_loop_k}, {K_tiles}") + if transpose_b: + lines.append(f"M_TMM 0, gp{gp_act}, gp{gp_mat}") + else: + lines.append(f"M_MM 0, gp{gp_mat}, gp{gp_act}") + lines.append(f"S_ADDI_INT gp{gp_act}, gp{gp_act}, {int(lhs_k_tile_stride)}") + lines.append(f"S_ADDI_INT gp{gp_mat}, gp{gp_mat}, {int(rhs_k_tile_stride)}") + lines.append(f"C_LOOP_END gp{gp_loop_k}") + lines.append(f"M_MM_WO gp{gp_out_orow}, gp0, 0") + lines.append(f"S_ADDI_INT gp{gp_act_orow}, gp{gp_act_orow}, {a_orow_step}") + lines.append(f"S_ADDI_INT gp{gp_out_orow}, gp{gp_out_orow}, {c_orow_step}") + lines.append(f"C_LOOP_END gp{gp_loop_orow}") + + if not caller_owns_scratch: + ra.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_tile_binary( + self, + *, + lhs_vram_addr: int, + rhs_vram_addr: int, + dst_vram_addr: int, + op: str = "add", + task_id: str = "tile_binary", + num_rows: Optional[int] = None, + ) -> None: + """One ``V_*_VV`` per MLEN-wide row, looped ``num_rows`` times. + + ``num_rows`` defaults to MLEN (legacy behavior — assumes a full + MLEN×MLEN tile per operand, which is what flash_attention / + BTMM-style kernels with lane-fused (rows, hlen) post-expansion + buffers want). Callers with smaller operands (e.g. one + MLEN-wide row, or any (1, …, MLEN) BSHD buffer where the + flattened element count is below MLEN²) must pass the actual + row count or the loop will over-iterate past the operand's end + and corrupt whatever VRAM follows it. + """ + op_to_insn = { + "add": "V_ADD_VV", + "sub": "V_SUB_VV", + "mul": "V_MUL_VV", + } + if op not in op_to_insn: + raise ValueError(f"Unsupported tile binary op={op!r}") + loop_count = self.program.mlen if num_rows is None else int(num_rows) + if loop_count < 1: + raise ValueError(f"num_rows must be >= 1, got {loop_count}") + gp_regs = self.program.compiler.register_allocator.allocate_gp(3) + gp_dst, gp_lhs, gp_rhs = gp_regs + # Compile-time unrolled: one V_*_VV per row, all three operand + # row addresses baked in as literals. No C_LOOP, no per-iter + # S_ADDI_INT advance. + lines = [ + f"; tile binary task {task_id} op={op} rows={loop_count} (unrolled)", + ] + mlen = self.program.mlen + for i in range(loop_count): + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_vram_addr + i * mlen}") + lines.append(f"S_ADDI_INT gp{gp_lhs}, gp0, {lhs_vram_addr + i * mlen}") + lines.append(f"S_ADDI_INT gp{gp_rhs}, gp0, {rhs_vram_addr + i * mlen}") + lines.append(f"{op_to_insn[op]} gp{gp_dst}, gp{gp_lhs}, gp{gp_rhs}, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + + def emit_tile_add( + self, + *, + lhs_vram_addr: int, + rhs_vram_addr: int, + dst_vram_addr: int, + task_id: str = "tile_add", + ) -> None: + self.emit_tile_binary( + lhs_vram_addr=lhs_vram_addr, + rhs_vram_addr=rhs_vram_addr, + dst_vram_addr=dst_vram_addr, + op="add", + task_id=task_id, + ) + + def emit_fp_kernel( + self, + *, + src1_addrs: Sequence[int], + dst_addrs: Sequence[int], + src2_addrs: Optional[Sequence[int]] = None, + op: str, + task_id: str = "fp_kernel", + ) -> None: + unary_copy = {"copy", "fill"} + unary_math = {"exp": "S_EXP_FP", "reci": "S_RECI_FP", "sqrt": "S_SQRT_FP"} + binary_math = {"add": "S_ADD_FP", "sub": "S_SUB_FP", "mul": "S_MUL_FP", "max": "S_MAX_FP"} + if len(src1_addrs) != len(dst_addrs): + raise ValueError("emit_fp_kernel expects matched src1/dst lengths") + if src2_addrs is not None and len(src2_addrs) != len(dst_addrs): + raise ValueError("emit_fp_kernel expects matched src2/dst lengths") + if op in unary_copy: + gp_regs = self.program.compiler.register_allocator.allocate_gp(2) + gp_src, gp_dst = gp_regs + # Compile-time unrolled — one S_LD_FP/S_ST_FP pair per slot + # with literal addresses. No C_LOOP / arith-progression + # address-advance loop. + lines = [f"; fp kernel task {task_id} op={op} (unrolled)"] + for src_addr, dst_addr in zip(src1_addrs, dst_addrs): + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {int(src_addr)}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_src}, 0") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + return + if op in unary_math: + gp_regs = self.program.compiler.register_allocator.allocate_gp(2) + gp_src, gp_dst = gp_regs + # Compile-time unrolled — one load / math / store per slot, + # literal addresses, no C_LOOP. + lines = [f"; fp kernel task {task_id} op={op} (unrolled)"] + for src_addr, dst_addr in zip(src1_addrs, dst_addrs): + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {int(src_addr)}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_src}, 0") + if op in {"exp", "reci"}: + lines.append(f"{unary_math[op]} f1, f1, 0") + else: + lines.append(f"{unary_math[op]} f1, f1") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + return + if op in binary_math: + if src2_addrs is None: + raise ValueError(f"emit_fp_kernel op={op!r} requires src2_addrs") + gp_regs = self.program.compiler.register_allocator.allocate_gp(3) + gp_a, gp_b, gp_dst = gp_regs + # Compile-time unrolled — one load/load/math/store per slot, + # literal addresses, no C_LOOP. + lines = [f"; fp kernel task {task_id} op={op} (unrolled)"] + for src1_addr, src2_addr, dst_addr in zip(src1_addrs, src2_addrs, dst_addrs): + lines.append(f"S_ADDI_INT gp{gp_a}, gp0, {int(src1_addr)}") + lines.append(f"S_ADDI_INT gp{gp_b}, gp0, {int(src2_addr)}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_a}, 0") + lines.append(f"S_LD_FP f2, gp{gp_b}, 0") + lines.append(f"{binary_math[op]} f1, f1, f2") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" + return + raise ValueError(f"Unsupported emit_fp_kernel op={op!r}") + + def emit_row_operation( + self, + *, + src_vram_addr: int, + dst_vram_addr: Optional[int] = None, + op: str, + row_count: int, + dst_addrs: Optional[Sequence[int]] = None, + rhs_addrs: Optional[Sequence[int]] = None, + mask_val: Optional[int] = None, + task_id: str = "row_operations", + ) -> None: + if row_count <= 0: + return + unary_ops = {"exp", "reci"} + reduce_ops = {"reduce_max": "V_RED_MAX", "reduce_sum": "V_RED_SUM"} + binary_ops = {"mul": "V_MUL_VF", "add": "V_ADD_VF", "sub": "V_SUB_VF"} + if op not in unary_ops | set(reduce_ops) | set(binary_ops): + raise ValueError(f"Unsupported emit_row_operation op={op!r}") + + gp_regs = self.program.compiler.register_allocator.allocate_gp(5) + gp_src, gp_fp, gp_dst, gp_loop, gp_mask = gp_regs + lines = [f"; row operation task {task_id} op={op} rows={row_count}"] + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {int(src_vram_addr)}") + dst_vram_addr = int(src_vram_addr if dst_vram_addr is None else dst_vram_addr) + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_vram_addr}") + use_mask = mask_val is not None + if use_mask: + lines.append(f"; row operation mask {int(mask_val)}") + lines.append(f"S_ADDI_INT gp{gp_mask}, gp0, {int(mask_val)}") + lines.append(f"C_SET_V_MASK_REG gp{gp_mask}") + + if op in unary_ops: + for row_index in range(int(row_count)): + row_addr = int(src_vram_addr) + row_index * self.program.mlen + dst_row_addr = dst_vram_addr + row_index * self.program.mlen + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_row_addr}") + if op == "exp": + lines.append(f"V_EXP_V gp{gp_dst}, gp{gp_src}, {1 if use_mask else 0}") + else: + lines.append(f"V_RECI_V gp{gp_dst}, gp{gp_src}, {1 if use_mask else 0}") + elif op in reduce_ops: + if dst_addrs is None or len(dst_addrs) != row_count: + raise ValueError(f"emit_row_operation op={op!r} expects one dst fp addr per row") + for row_index, dst_addr in enumerate(dst_addrs): + row_addr = int(src_vram_addr) + row_index * self.program.mlen + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {int(dst_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_dst}, 0") + lines.append(f"{reduce_ops[op]} f1, gp{gp_src}, {1 if use_mask else 0}") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + else: + if rhs_addrs is None or len(rhs_addrs) not in (1, row_count): + raise ValueError(f"emit_row_operation op={op!r} expects one rhs fp addr or one per row") + if len(rhs_addrs) == 1: + rhs_addr = int(rhs_addrs[0]) + for row_index in range(int(row_count)): + row_addr = int(src_vram_addr) + row_index * self.program.mlen + dst_row_addr = dst_vram_addr + row_index * self.program.mlen + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_row_addr}") + lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {rhs_addr}") + lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") + if op == "sub": + lines.append(f"V_SUB_VF gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}, 0") + else: + lines.append(f"{binary_ops[op]} gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}") + else: + for row_index, rhs_addr in enumerate(rhs_addrs): + row_addr = int(src_vram_addr) + row_index * self.program.mlen + dst_row_addr = dst_vram_addr + row_index * self.program.mlen + lines.append(f"S_ADDI_INT gp{gp_src}, gp0, {row_addr}") + lines.append(f"S_ADDI_INT gp{gp_dst}, gp0, {dst_row_addr}") + lines.append(f"S_ADDI_INT gp{gp_fp}, gp0, {int(rhs_addr)}") + lines.append(f"S_LD_FP f1, gp{gp_fp}, 0") + if op == "sub": + lines.append(f"V_SUB_VF gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}, 0") + else: + lines.append(f"{binary_ops[op]} gp{gp_dst}, gp{gp_src}, f1, {1 if use_mask else 0}") + + if use_mask: + lines.append("S_ADDI_INT gp{0}, gp0, 0".format(gp_mask)) + lines.append(f"C_SET_V_MASK_REG gp{gp_mask}") + + self.program.compiler.register_allocator.free_gp(gp_regs) + self.program.compiler.generated_code += "\n".join(lines) + "\n" diff --git a/tilelang_tvm_compiler/isa_pass.py b/tilelang_tvm_compiler/isa_pass.py new file mode 100644 index 0000000..05b9423 --- /dev/null +++ b/tilelang_tvm_compiler/isa_pass.py @@ -0,0 +1,3329 @@ +"""Pass 3: turn HLIR (with addresses) into real PLENA ISA text. + +Each HLIR op kind has a small dispatcher that pulls the right addresses +off the buffers and forwards them to ISAEmitter. This is intentionally +mechanical -- if you add a new op kind to `intrinsics.py`, you add one +case here too. + +For BTMM specifically, the runtime convention is: + - the actual `M_BTMM` instruction takes packed lhs (vram) + rhs (mram) + - it does NOT itself write the result; the result is committed to + VRAM by a paired `M_BMM_WO` instruction +The HLIR view collapses this to one `btmm` op that names lhs/rhs/dst; +this pass expands it into the `emit_btmm` + `emit_btmm_wo` pair so the +emitter contract is honoured. +""" + +from __future__ import annotations + +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple + +import tvm +from tvm import tir + +from . import hlir as _hlir +from . import scope as _scope +from .expr_materializer import ExprMaterializer, MaterializedExpr +from .frontend.mid_ir.cluster_guard import MLEN as _HW_MLEN +from .isa_emitter import ISAEmitter +from .program_shim import ProgramShim + + +class IsaEmissionError(RuntimeError): + pass + + +# Maximum unsigned literal that fits in the S_ADDI_INT three-operand +# immediate slot. opcode(6) + 2*operand(4) = 14 bits taken by other +# fields, leaving 32 - 14 = 18 bits for imm. Mirrors _S_ADDI_MAX in +# expr_materializer.py and _normalize_large_addi_immediates in +# tilelang_runtime_compier/tile_tensor_program/_program.py. +_S_ADDI_IMM_MAX = (1 << 18) - 1 # 262143 + + +def _normalize_large_addi_immediates(asm_code: str) -> str: + """Rewrite `S_ADDI_INT rd, rs1, imm` when imm overflows the 18-bit slot. + + Strategy: + - rs1 == gp0: LUI rd, hi ; ADDI rd, rd, lo + - rs1 != gp0, rd != rs1: LUI rd, hi ; ADDI rd, rd, lo ; S_ADD_INT rd, rd, rs1 + - rs1 != gp0, rd == rs1: cannot expand text-only without a scratch reg. + Emit a warning and leave the line untouched (the binary-stage + instruction mask will then truncate it, producing wrong code that + the user can spot via the warning). + """ + lines: List[str] = [] + for raw_line in asm_code.splitlines(): + line = raw_line.rstrip("\n") + stripped = line.strip() + if not stripped or stripped.startswith(";"): + lines.append(line) + continue + + parts = stripped.split(None, 1) + if len(parts) != 2 or parts[0] != "S_ADDI_INT": + lines.append(line) + continue + + operands = [item.strip() for item in parts[1].split(",")] + if len(operands) != 3: + lines.append(line) + continue + + rd, rs1, imm_text = operands + try: + imm_value = int(imm_text) + except ValueError: + lines.append(line) + continue + + if 0 <= imm_value <= _S_ADDI_IMM_MAX: + lines.append(line) + continue + + if imm_value < 0: + print( + f"[isa_pass] WARN: negative imm in {stripped!r}; " + f"normalize pass only handles unsigned overflow." + ) + lines.append(line) + continue + + upper = imm_value >> 12 + lower = imm_value & 0xFFF + + if rs1 == "gp0": + lines.append(f"S_LUI_INT {rd}, {upper}") + lines.append(f"S_ADDI_INT {rd}, {rd}, {lower}") + continue + + if rd != rs1: + lines.append(f"S_LUI_INT {rd}, {upper}") + lines.append(f"S_ADDI_INT {rd}, {rd}, {lower}") + lines.append(f"S_ADD_INT {rd}, {rd}, {rs1}") + continue + + print( + f"[isa_pass] WARN: cannot expand large-imm S_ADDI_INT in-place " + f"(rd==rs1=={rd}, imm={imm_value}): {stripped!r}. " + f"Need a scratch register — fix the emitter to use a separate rd." + ) + lines.append(line) + + normalized = "\n".join(lines) + if asm_code.endswith("\n"): + normalized += "\n" + return normalized + + +class IsaEmitterPass: + def __init__(self, shim: ProgramShim) -> None: + self.shim = shim + self.emitter = ISAEmitter(shim) + # Symbol table: tir.Var -> currently-bound GP register id. Loop + # bodies push entries on enter and pop on exit. ExprMaterializer + # consults this table to resolve Var references in scalar args. + self.symbol_table: Dict[tir.Var, int] = {} + self.materializer = ExprMaterializer(shim, self.symbol_table) + # Global op counter for the lowir report. Advanced once per op + # (including nested ``for`` bodies) so the recorded op index + # matches the depth-first ``[NN]`` numbering ``format_hlir``'s + # ``_format_ops`` produces — the lowir report and hlir.txt then + # line up entry-for-entry. + self._lowir_idx: int = -1 + self._dispatch: Dict[str, Callable[[_hlir.HLIRModule, _hlir.Op], None]] = { + "dma_h2v": self._emit_dma_h2v, + "dma_h2m": self._emit_dma_h2m, + "dma_v2h": self._emit_dma_v2h, + "dma_h2v_slice": self._emit_dma_h2v_slice, + "dma_h2m_slice": self._emit_dma_h2m_slice, + "dma_v2h_slice": self._emit_dma_v2h_slice, + "btmm": self._emit_btmm, + "btmv": self._emit_btmv, + "mm": self._emit_mm, + "mm_slot": self._emit_mm_slot, + "matmul": self._emit_matmul, + "mv": self._emit_mv, + # 1-D vector VRAM ops — each op handles ONE logical 1D + # vector of length ``n_elem``. The emitter unrolls it into + # ``ceil(n_elem / mlen)`` HW issues; multi-row tiles are + # expressed by the lowering wrapping ``v_*`` inside a + # ``for row`` so a 2-D tile becomes M explicit per-row + # vector ops in the HLIR. + "v_zero": self._emit_v_zero, + "v_add": self._emit_v_add, + "v_sub": self._emit_v_sub, + "v_mul": self._emit_v_mul, + "v_exp": self._emit_v_exp, + "v_reci": self._emit_v_reci, + "v_sqrt": self._emit_v_sqrt, + "fp_copy_at": self._emit_fp_copy_at, + "fp_zero_at": self._emit_fp_zero_at, + "fp_add_at": self._emit_fp_add_at, + "fp_sub_at": self._emit_fp_sub_at, + "fp_mul_at": self._emit_fp_mul_at, + "fp_max_at": self._emit_fp_max_at, + "fp_exp_at": self._emit_fp_exp_at, + "fp_reci_at": self._emit_fp_reci_at, + "fp_sqrt_at": self._emit_fp_sqrt_at, + "v_fp_transfer_slice_v_to_fp": self._emit_v_fp_transfer_slice_v_to_fp, + "v_fp_transfer_slice_fp_to_v": self._emit_v_fp_transfer_slice_fp_to_v, + "copy_v_to_v": self._emit_copy_v_to_v, + # Row-level VRAM/FP ops. Contract: one HLIR op = one HW + # instruction over a SINGLE row. Multi-row callers must wrap + # in an outer HLIR ``for row``. + "row_reduce_max_at": self._emit_row_reduce_max_at, + "row_reduce_sum_at": self._emit_row_reduce_sum_at, + "row_exp": self._emit_row_exp, + "row_sub_fp": self._emit_row_sub_fp, + "row_mul_fp": self._emit_row_mul_fp, + "row_add_fp": self._emit_row_add_fp, + "for": self._emit_for, + } + + def run(self, mod: _hlir.HLIRModule) -> str: + _hlir.assert_addresses_resolved(mod) + # Emit a header so generated dumps are easy to identify in build/. + self.shim.compiler.generated_code = ( + f"; PLENA ISA -- kernel: {mod.name}\n" + f"; generated by tilelang_tvm_compiler (real-ISA path)\n" + f"; ============================================================\n" + f"; buffer layout:\n" + ) + for buf in mod.buffers.values(): + self.shim.compiler.generated_code += ( + f"; {buf.name:<10s} scope={buf.scope:<5s} addr={buf.address} " + f"shape={'x'.join(str(s) for s in buf.shape)}\n" + ) + self.shim.compiler.generated_code += ( + "; ============================================================\n\n" + ) + + ra = self.shim.compiler.register_allocator + self._lowir_idx = -1 + for op in mod.ops: + handler = self._dispatch.get(op.kind) + if handler is None: + raise IsaEmissionError( + f"no ISA dispatcher for HLIR op kind {op.kind!r}. " + f"Either add it to isa_pass dispatch table, or guard " + f"the op out of HLIR earlier." + ) + # Advance the depth-first op counter and tag the materializer + # so every address expr this op lowers is recorded under the + # same [NN] index format_hlir prints. _emit_for advances the + # counter again for each nested body op. + self._lowir_idx += 1 + i = self._lowir_idx + ra.push_site(f"op[{i}] {op.kind}") + self.materializer.set_lowir_op_idx(i) + self.materializer.begin_op() + try: + handler(mod, op) + finally: + self.materializer.end_op() + ra.pop_site() + self.shim.compiler.generated_code = _normalize_large_addi_immediates( + self.shim.compiler.generated_code + ) + return self.shim.compiler.generated_code + + @staticmethod + def _logical_2d( + shape: Tuple[int, ...], layout: str = "BSHD", + ) -> Tuple[int, int]: + return _hlir.logical_2d_extents(shape, layout) + + def _vram_row_shape(self, buf: _hlir.Buffer, op_kind: str, role: str) -> Tuple[int, int]: + _check_scope(buf, _scope.VRAM, op_kind, role) + rows, cols = self._logical_2d(buf.shape, buf.layout) + if cols != self.shim.mlen: + raise IsaEmissionError( + f"{op_kind} {role} buffer {buf.name!r} must have logical row width " + f"mlen={self.shim.mlen}; got logical 2D ({rows}, {cols})" + ) + return rows, cols + + def _resolve_row_at_coords( + self, + buf: _hlir.Buffer, + op_kind: str, + role: str, + row_expr, + head_expr, + op_axes: Optional[Tuple[Tuple[str, int], ...]] = None, + ) -> Tuple[tir.PrimExpr, tir.PrimExpr | None]: + """Translate logical ``(head=H-idx, row=S-idx)`` coords on a + VRAM buffer into a physical vram-row index + optional V_MASK. + + Driven by the per-op ``op_axes`` table that mid_ir → HLIR + lowering stamps on ``hlir.Op.buffer_axes``. Each entry pairs + a role string with the dim's extent: + + ``"simd"`` — innermost D / vector axis + ``"cluster"`` — lane / head axis + ``"batch"`` — row-fanout axis (rows OR degenerate + leading B placeholder; pick the rows + one by largest extent) + + Computation: + + flat_row = row * row_stride + head * head_stride + + where ``row_stride`` / ``head_stride`` is the product of every + physical dim's extent strictly inside that role's position + (i.e. between the role and the innermost dim). + + Then: + + * ``D >= MLEN`` and ``D % MLEN == 0``: each flat_row is one + full mlen vector. ``vram_row = flat_row``; no mask. + * ``D < MLEN`` (``lane_count = MLEN/D``): ``lane_count`` + consecutive flat_rows pack into one mlen-row. + ``vram_row = flat_row // lane_count``, + ``mask = 1 << (flat_row % lane_count)``. + """ + _check_scope(buf, _scope.VRAM, op_kind, role) + if not buf.shape: + raise IsaEmissionError( + f"{op_kind} {role} buffer {buf.name!r}: empty shape" + ) + mlen = int(self.shim.mlen) + rank = len(buf.shape) + + if op_axes is None: + raise IsaEmissionError( + f"{op_kind} {role} buffer {buf.name!r}: op_axes required " + f"(callers must thread per-op axes from hlir.Op.buffer_axes " + f"so this resolver can locate the rows / cluster / D dim " + f"by role tag instead of guessing from shape)." + ) + if len(op_axes) != rank: + raise IsaEmissionError( + f"{op_kind} {role} buffer {buf.name!r}: op_axes rank " + f"{len(op_axes)} doesn't match buffer rank {rank} " + f"(shape={list(buf.shape)} op_axes={list(op_axes)})" + ) + + # Locate dims by role. ``simd`` must exist and is the innermost + # vector axis. ``cluster`` is optional (rank-2 fragments don't + # have one). ``batch`` may appear multiple times — pick the one + # with the largest extent as the rows axis; extent-1 batch + # entries are pad-to-4D / cluster-expand placeholders. + d_axis: Optional[int] = None + cluster_dim: Optional[int] = None + rows_axis: Optional[int] = None + rows_extent = -1 + for i, (role_name, _extent) in enumerate(op_axes): + if role_name == "simd": + d_axis = i + elif role_name == "cluster": + cluster_dim = i + elif role_name == "batch": + if int(_extent) > rows_extent: + rows_extent = int(_extent) + rows_axis = i + if d_axis is None: + raise IsaEmissionError( + f"{op_kind} {role} buffer {buf.name!r}: no ``simd`` axis " + f"in op_axes {list(op_axes)}" + ) + d_dim = int(buf.shape[d_axis]) + + if rank >= 3 and rows_axis is not None: + head_stride = 1 + if cluster_dim is not None: + lo = min(cluster_dim, d_axis) + 1 + hi = max(cluster_dim, d_axis) + for axis in range(lo, hi): + head_stride *= int(buf.shape[axis]) + row_stride = 1 + lo = min(rows_axis, d_axis) + 1 + hi = max(rows_axis, d_axis) + for axis in range(lo, hi): + row_stride *= int(buf.shape[axis]) + terms = [] + if head_stride == 1: + terms.append(head_expr) + elif head_stride > 1: + terms.append( + tir.Mul(head_expr, tir.IntImm("int32", head_stride)) + ) + if row_stride == 1: + terms.append(row_expr) + else: + terms.append( + tir.Mul(row_expr, tir.IntImm("int32", row_stride)) + ) + flat_row = terms[0] + for t in terms[1:]: + flat_row = tir.Add(flat_row, t) + else: + flat_row = row_expr + + # Case 1: D-wide row is at least a full mlen vector. Each + # flat_row IS one full mlen-row; no mask. + if d_dim >= mlen and d_dim % mlen == 0: + return flat_row, None + + # Case 2: D < MLEN — ``lane_count`` flat_rows pack into one + # mlen-row. vram_row = flat_row // lane_count; + # mask = 1 << (flat_row % lane_count). + if mlen % d_dim == 0: + lane_count = mlen // d_dim + log2_lc = (lane_count - 1).bit_length() + if (1 << log2_lc) != lane_count: + raise IsaEmissionError( + f"{op_kind} {role} buffer {buf.name!r}: lane_count " + f"{lane_count} (=MLEN/{d_dim}) is not a power of two; " + f"shift / mask shortcut for the narrow-D path requires it." + ) + vram_row_expr = tir.shift_right( + flat_row, tir.IntImm("int32", log2_lc), + ) + # PLENA has no bitwise-AND; ``flat_row % lane_count`` is + # ``flat_row - ((flat_row >> k) << k)``. + quotient_shifted_back = tir.shift_left( + vram_row_expr, tir.IntImm("int32", log2_lc), + ) + col_in_row = tir.Sub(flat_row, quotient_shifted_back) + mask_expr = tir.shift_left(tir.IntImm("int32", 1), col_in_row) + return vram_row_expr, mask_expr + + raise IsaEmissionError( + f"{op_kind} {role} buffer {buf.name!r}: innermost dim {d_dim} " + f"is neither a multiple of MLEN ({mlen}) nor a divisor — no " + f"unified row_*_at addressing path for it." + ) + + def _resolve_fp_scalar_addr_arg( + self, + mod: _hlir.HLIRModule, + arg, + op_kind: str, + role: str, + ): + if isinstance(arg, _hlir.BufferElement): + buf = mod.get_buffer(arg.buffer) + _check_scope(buf, _scope.FPRAM, op_kind, role) + if len(arg.indices) != len(buf.shape): + raise IsaEmissionError( + f"{op_kind} {role} buffer element {buf.name!r} has index rank {len(arg.indices)} " + f"but buffer shape rank {len(buf.shape)}" + ) + offset = tir.IntImm("int32", 0) + stride = 1 + for dim, idx in zip(reversed(buf.shape), reversed(arg.indices)): + idx_expr = tir.IntImm("int32", int(idx)) if isinstance(idx, int) else idx + term = idx_expr if stride == 1 else tir.Mul( + idx_expr, tir.IntImm("int32", int(stride)), + ) + offset = term if stride == 1 and isinstance(offset, tir.IntImm) and int(offset.value) == 0 else tir.Add(term, offset) + stride *= int(dim) + return tir.Add(tir.IntImm("int32", int(buf.address)), offset) + if isinstance(arg, (int, tir.PrimExpr)): + return arg + raise IsaEmissionError( + f"{op_kind} {role} expects an FPRAM address or buffer element ref; " + f"got {type(arg).__name__}: {arg!r}" + ) + + def _emit_fp_scalar_op_at( + self, + mod: _hlir.HLIRModule, + op: _hlir.Op, + *, + kernel_op: str, + ) -> None: + # FP `_at` operands are scalar fpram addresses (PrimExpr or int), + # already including any per-slot base offset. We materialize each + # address into its own GP register and emit S_LD_FP / S_ST_FP. + if kernel_op in {"copy", "exp", "reci", "sqrt"}: + expected = 2 + else: + expected = 3 + if len(op.scalar_args) != expected: + raise IsaEmissionError( + f"{op.kind} expects {expected} scalar address args, got {len(op.scalar_args)}" + ) + + addr_exprs = [ + self._resolve_fp_scalar_addr_arg(mod, a, op.kind, f"arg{i}") + for i, a in enumerate(op.scalar_args) + ] + # Materialize one address expression at a time, commit its ISA to + # ``generated_code`` immediately, and ``pin_gp`` the result reg so + # the next materialize() call cannot auto-spill it. + # + # Why pinning is required: + # ExprMaterializer eagerly frees operand registers after writing a + # binop's ISA text. ``allocate_gp`` then auto-spills the + # most-recently-allocated in-use reg when pressure is high — and + # that "most-recently-allocated" reg can be the previous mats[i]'s + # final reg itself. The spill stores the value to IntRAM but the + # MaterializedExpr's ``register`` field still names the same reg, + # which the next materialize() then overwrites. Net effect: by + # the time we emit S_LD_FP/S_MUL_FP, mats[0]/mats[1] both point + # at a reg holding mats[2]'s value. Pinning blocks that path. + ra = self.shim.compiler.register_allocator + mats: List[MaterializedExpr] = [] + for a in addr_exprs: + m = self.materializer.materialize(a) + self.shim.compiler.generated_code += m.isa + ra.pin_gp(m.register) + mats.append(m) + + try: + lines = [f"; fp scalar task {op.annotations.get('intrinsic', op.kind)} op={kernel_op}"] + if kernel_op in {"copy", "exp", "reci", "sqrt"}: + gp_src, gp_dst = mats[0].register, mats[1].register + lines.append(f"S_LD_FP f1, gp{gp_src}, 0") + if kernel_op == "exp": + lines.append("S_EXP_FP f1, f1, 0") + elif kernel_op == "reci": + lines.append("S_RECI_FP f1, f1") + elif kernel_op == "sqrt": + lines.append("S_SQRT_FP f1, f1") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + else: + gp_lhs, gp_rhs, gp_dst = mats[0].register, mats[1].register, mats[2].register + opcode = { + "add": "S_ADD_FP", + "sub": "S_SUB_FP", + "mul": "S_MUL_FP", + "max": "S_MAX_FP", + }[kernel_op] + lines.append(f"S_LD_FP f1, gp{gp_lhs}, 0") + lines.append(f"S_LD_FP f2, gp{gp_rhs}, 0") + lines.append(f"{opcode} f1, f1, f2") + lines.append(f"S_ST_FP f1, gp{gp_dst}, 0") + self.shim.compiler.generated_code += "\n".join(lines) + "\n" + finally: + for m in reversed(mats): + ra.unpin_gp(m.register) + m.release() + + def _tile_layout_strides(self, buf: _hlir.Buffer): + """Element-strides for the 7D physical layout of a VRAM/MRAM buffer. + + Mirrors the stride math ``_slice_tile_grid`` uses for the + ``tile_layout is not None`` branch but factored out so non-DMA + emitters can read it. Returns a dict:: + + { + "d_tiles": outer d-tile count, + "s_tiles": outer s-tile count, + "h_groups": outer h-group count, + "logical_b": batch count (almost always 1), + "mlen": inner s-tile height, + "lane_count": inner lane count (1 when D >= mlen), + "d_inner": inner d width (mlen when D >= mlen), + "s_inner_stride": elements between consecutive s_inner rows + in the same tile (= lane_count * d_inner), + "h_grp_stride": elements between consecutive h_groups, + "s_tile_stride": elements between consecutive s_tiles, + "d_tile_stride": elements between consecutive d_tiles, + } + + Buffers without a ``tile_layout`` (rank < 4, or 4D with + ``d_tiles == s_tiles == h_groups == logical_b == 1`` flattened + away upstream) get a ``None`` return — callers fall back to + their pre-7D row-major-flat path. + """ + tl = getattr(buf, "tile_layout", None) + if tl is None: + return None + inner_d = int(tl.d_inner) + inner_lane = int(tl.lane_count) * inner_d + inner_s = int(tl.mlen) * inner_lane + b_stride = inner_s + inner_b = int(tl.logical_b) * inner_s + h_grp_stride = inner_b + s_tile_stride = int(tl.h_groups) * inner_b + d_tile_stride = int(tl.s_tiles) * s_tile_stride + return { + "d_tiles": int(tl.d_tiles), + "s_tiles": int(tl.s_tiles), + "h_groups": int(tl.h_groups), + "logical_b": int(tl.logical_b), + "mlen": int(tl.mlen), + "lane_count": int(tl.lane_count), + "d_inner": int(tl.d_inner), + "s_inner_stride": inner_lane, + "h_grp_stride": h_grp_stride, + "s_tile_stride": s_tile_stride, + "d_tile_stride": d_tile_stride, + "b_stride": b_stride, + } + + def _buffer_tile_grid_iter(self, buf: _hlir.Buffer): + """Yield every mlen-row of ``buf`` in physical address order. + + For each ``(d_tile, s_tile, h_grp, b, s_inner)`` cell of the + buffer's 7D physical layout, yields + ``(d_tile, s_tile, h_grp, b, s_inner, phys_offset)`` where + ``phys_offset`` is in *elements* relative to ``buf.address``. + One yielded entry == one HW mlen-wide vector op + (``V_*_VV`` / ``V_*_VF`` / ``V_RED_*`` / ``V_EXP_V`` / etc.). + + Walks the outer tile grid in physical-address order + (d_tile slowest, b fastest at the tile level), then steps + through s_inner inside each tile. ``s_inner`` covers + ``tl.mlen`` rows because each s_inner row is one HW vector + (its width is ``lane_count * d_inner == mlen`` by tile-layout + invariant). Use this in any emitter that wants to issue one + HW vector op per mlen-row covering the whole buffer (v_zero, + tile_add, tile_mul, …) — it hides the 7D layout behind one + loop and stays in lock-step with what ``_slice_tile_grid`` + walks for DMAs. + + Buffers that aren't 4D (no ``tile_layout``) raise — callers + with 1D / 2D scratch must use the legacy flat-offset + emitters; everything pad-to-4D'd by to_plena hits this path. + """ + info = self._tile_layout_strides(buf) + if info is None: + raise IsaEmissionError( + f"_buffer_tile_grid_iter: buffer {buf.name!r} has no " + f"tile_layout (rank={len(buf.shape)}, shape={tuple(buf.shape)}) " + f"— this helper only handles 4D buffers; 1D/2D callers must " + f"stay on the flat-offset path." + ) + s_inner_stride = info["s_inner_stride"] + for d_tile in range(info["d_tiles"]): + for s_tile in range(info["s_tiles"]): + for h_grp in range(info["h_groups"]): + for b in range(info["logical_b"]): + tile_base = ( + d_tile * info["d_tile_stride"] + + s_tile * info["s_tile_stride"] + + h_grp * info["h_grp_stride"] + + b * info["b_stride"] + ) + for s_inner in range(info["mlen"]): + phys = tile_base + s_inner * s_inner_stride + yield (d_tile, s_tile, h_grp, b, + s_inner, phys) + + def _logical_to_phys_row_offset( + self, + buf: _hlir.Buffer, + region: _hlir.VramRegion, + ): + """Translate a single-row ``VramRegion`` into the physical + 7D *mlen-row base* offset (in elements, relative to + ``buf.address``) plus an optional packed-head mask expression. + + Returns ``(phys_offset_expr, mask_expr_or_None, info)``. + + ``region`` must describe exactly one logical row: + ``extents = (..., 1, ..., 1, D_full)`` with non-D extents + equal to 1. ``starts`` is 4 entries in physical-axis order + (matching ``buf.shape``); the helper routes each entry to + the right stride **by the buffer's role tags**, so it works + for any lane-fusion mode: + + * col_pack: shape=(B=1, S, H=lane, D_narrow), cluster_dim=2 + → starts[2] is the lane index, gets split into + (h_grp, lane); mask = 1 << lane. + * row_stack: shape=(B=lane, S, H=1, MLEN), cluster_dim=0 + → starts[0] is the lane index, same split logic. + * bshd_lift / no cluster: cluster_dim is None, every + non-D, non-rows axis is either a B placeholder or H + placeholder; their starts contribute via the matching + stride (b_stride / h_grp_stride). lane_count == 1 here, + so mask_expr is None. + """ + info = self._tile_layout_strides(buf) + if info is None: + raise IsaEmissionError( + f"_logical_to_phys_row_offset: buffer {buf.name!r} has no " + f"tile_layout (rank={len(buf.shape)}, shape={tuple(buf.shape)})" + ) + mlen = info["mlen"] + lane_count = info["lane_count"] + if len(region.starts) != 4 or len(region.extents) != 4: + raise IsaEmissionError( + f"_logical_to_phys_row_offset: region {region.parent!r} " + f"must be 4D; got starts={tuple(region.starts)} " + f"extents={tuple(region.extents)}" + ) + + cluster_dim = getattr(buf, "cluster_dim", None) + rank = len(buf.shape) + d_axis = rank - 1 + + # Per-axis role assignment (matches the layout + # ``_hlir_axes_for_buffer`` produces in mid_ir): the innermost + # axis is "d", the cluster_dim (if set) is "lane", everything + # else is "batch". Among batch axes the one with the largest + # extent acts as "rows"; the rest are pad placeholders that + # carry a B=1 or H=1 stride term. + roles: List[str] = [] + for i in range(rank): + if i == d_axis: + roles.append("d") + elif cluster_dim is not None and i == cluster_dim: + roles.append("lane") + else: + roles.append("batch") + + def _to_expr(x): + if isinstance(x, int): + return tir.IntImm("int32", int(x)) + return x + + # Per-physical-axis "step one unit along this axis" stride + # table (in elements). Indexed by axis position the same way + # ``roles`` is, so any axis-handling branch (rows / lane / + # batch placeholder) can look up its stride without caring + # about which role label happens to live there. The S axis + # gets its s_inner_stride here; the s_tile/s_inner split + # for the multi-s_tile case is handled in the i==1 branch. + axis_stride = [ + info["b_stride"], + info["s_inner_stride"], + info["h_grp_stride"], + 1, + ] + + # Per-axis stride contribution. We iterate physical axes + # in order so any 7D nuance (s_tile / s_inner split when + # s_tiles > 1) lives next to its axis's stride choice. + terms: List = [] + mask_expr = None + + for i, role in enumerate(roles): + start = _to_expr(region.starts[i]) + if role == "d": + # D start is folded into d_tile bump by the caller's + # outer loop; the per-issue base is always at the + # logical row's d-tile=0 chunk. Non-zero d-starts are + # not supported here. + if not (isinstance(region.starts[i], (int, tir.IntImm)) + and (int(region.starts[i]) + if isinstance(region.starts[i], int) + else int(region.starts[i].value)) == 0): + raise IsaEmissionError( + f"row_*_at on {buf.name!r}: d-axis start must be " + f"0; got {region.starts[i]!r}" + ) + continue + if role == "lane": + # col_pack puts cluster_dim at axis 2 (H) with + # lane_count>1 (within-mlen packing); axis_stride[2] = + # h_grp_stride is correct there. + # row_stack puts cluster_dim at axis 0 (B) with + # lane_count==1 (no packed-head — d already fills mlen). + # Per the M_BMM_WO / M_BMV_WO hardware writeback in + # ``transactional_emulator/src/main.rs``, lane j's data + # lands at ``vec_base + j * (per_lane_elems)`` where + # per_lane_elems = product(shape[1:]) — i.e. the flat + # row-major stride for axis 0. For S=mlen this equals + # ``mlen*inner_lane = b_stride`` (default works); for + # S 1: + h_grp = tir.floordiv( + start, tir.IntImm("int32", lane_count), + ) + lane = tir.floormod( + start, tir.IntImm("int32", lane_count), + ) + if lane_stride: + terms.append( + tir.Mul(h_grp, + tir.IntImm("int32", lane_stride)) + ) + mask_expr = tir.shift_left( + tir.IntImm("int32", 1), lane, + ) + else: + if lane_stride: + terms.append( + tir.Mul(start, + tir.IntImm("int32", lane_stride)) + ) + continue + # role == "batch": could be the rows axis or a degenerate + # placeholder. The stride is determined by where this axis + # sits in the physical layout. + # + # Two physical positions matter: + # * S row (s_inner inside an s_tile, multiplied by + # ``s_inner_stride``). When s_tiles > 1 the value is + # also split into (s_tile, s_inner). + # * B placeholder (a leading batch dim). It rides on + # ``b_stride``. + # * H placeholder (when cluster_dim is at a different + # position than this batch axis, e.g. row_stack puts + # H=1 at axis 2 with role "batch"). It rides on + # ``h_grp_stride``. + # + # We disambiguate by axis index: the axis index immediately + # before d_axis (or cluster_dim, whichever is later) is + # treated as H if it differs from the rows position. The + # leading axis is B. The S axis is the only batch axis + # whose extent in ``buf.shape`` matches a "real" rows + # dimension (not 1). + # + # In practice we route by axis index: + # i == 0 and cluster_dim != 0 → B placeholder (b_stride) + # i == 1 → S (s_inner_stride / s_tile_stride split) + # i == 2 and cluster_dim != 2 → H placeholder (h_grp_stride) + if i == 0: + if info["b_stride"]: + terms.append( + tir.Mul(start, tir.IntImm("int32", info["b_stride"])) + ) + elif i == 1: + if info["s_tiles"] > 1: + s_tile = tir.floordiv(start, tir.IntImm("int32", mlen)) + s_inner = tir.floormod(start, tir.IntImm("int32", mlen)) + terms.append( + tir.Mul(s_tile, + tir.IntImm("int32", info["s_tile_stride"])) + ) + terms.append( + tir.Mul(s_inner, + tir.IntImm("int32", info["s_inner_stride"])) + ) + else: + if info["s_inner_stride"] == 1: + terms.append(start) + else: + terms.append( + tir.Mul(start, + tir.IntImm("int32", info["s_inner_stride"])) + ) + elif i == 2: + if info["h_grp_stride"]: + terms.append( + tir.Mul(start, + tir.IntImm("int32", info["h_grp_stride"])) + ) + else: + raise IsaEmissionError( + f"row_*_at on {buf.name!r}: unexpected batch axis at " + f"physical index {i}" + ) + + if not terms: + return tir.IntImm("int32", 0), mask_expr, info + expr = terms[0] + for t in terms[1:]: + expr = tir.Add(expr, t) + return expr, mask_expr, info + + def _region_origin_offset(self, buf: _hlir.Buffer, + region) -> "tir.PrimExpr | int": + """Translate a Region's ``starts`` into a physical element + offset against ``buf.address``. + + Handles the cluster axis specially: when the cluster axis is + packed-head (``lane_count > 1``), an index value < lane_count + is a within-mlen lane segment whose stride is ``d_inner`` + (not ``h_grp_stride``). Larger indices split into + ``h_grp = idx // lane_count`` (walks ``h_grp_stride``) and + ``lane = idx % lane_count`` (walks ``d_inner``). For non- + packed buffers (``lane_count == 1``) the cluster axis stride + is ``h_grp_stride`` directly. + """ + tl_info = self._tile_layout_strides(buf) + if tl_info is None: + for i, s in enumerate(region.starts): + if isinstance(s, int) and s == 0: + continue + if isinstance(s, tir.IntImm) and int(s.value) == 0: + continue + raise IsaEmissionError( + f"_region_origin_offset: {region.parent!r} non-zero " + f"start at axis {i} but parent has no tile_layout" + ) + return 0 + cluster_dim = getattr(buf, "cluster_dim", None) + lane_count = int(tl_info["lane_count"]) + d_inner = int(tl_info["d_inner"]) + h_grp_stride = int(tl_info["h_grp_stride"]) + s_inner_stride = int(tl_info["s_inner_stride"]) + b_stride = int(tl_info["b_stride"]) + + def _is_zero(x) -> bool: + if isinstance(x, int) and x == 0: + return True + if isinstance(x, tir.IntImm) and int(x.value) == 0: + return True + return False + + def _mul(expr, k: int): + if k == 0: + return tir.IntImm("int32", 0) + if isinstance(expr, int): + return tir.IntImm("int32", expr * k) + if isinstance(expr, tir.IntImm): + return tir.IntImm("int32", int(expr.value) * k) + if k == 1: + return expr + return tir.Mul(expr, tir.IntImm("int32", k)) + + terms = [] + for i, s in enumerate(region.starts): + if _is_zero(s): + continue + if (cluster_dim is not None and i == cluster_dim + and lane_count == 1 and i == 0): + # row_stack lane axis: B carries lane stacking, no + # packed-head (lane_count == 1). Each lane occupies + # ``shape[1] * shape[2] * shape[3]`` elements + # (logical_s * h_groups_etc * inner_lane in the 7D + # picture), NOT the per-axis-table ``b_stride = + # mlen * inner_lane`` (which assumes B is outer batch + # over a full-mlen s_tile). For S_loc with S=mlen the + # two values coincide; for S 1: + # Packed-head lane axis. For a static int we split + # into (h_grp, lane) the usual way. For a PrimExpr + # (typically the cluster phase var, which mid_ir + # guarantees is < lane_count) we skip the floordiv / + # floormod split and emit ``s * d_inner`` directly — + # the floordiv/floormod expansion materialises into + # an SRLI+SLLI+SUB chain that easily exhausts the + # 16-GP budget when nested inside outer loops, and + # the result simplifies to the same expression. + if isinstance(s, int): + h_grp = s // lane_count + lane = s % lane_count + if h_grp: + terms.append( + tir.IntImm("int32", h_grp * h_grp_stride) + ) + if lane: + terms.append( + tir.IntImm("int32", lane * d_inner) + ) + elif isinstance(s, tir.IntImm): + val = int(s.value) + h_grp = val // lane_count + lane = val % lane_count + if h_grp: + terms.append( + tir.IntImm("int32", h_grp * h_grp_stride) + ) + if lane: + terms.append( + tir.IntImm("int32", lane * d_inner) + ) + else: + terms.append(_mul(s, d_inner)) + continue + # Regular axis: pick stride from per-axis table. + if i == 0: + stride = b_stride + elif i == 1: + stride = s_inner_stride + elif i == 2: + stride = h_grp_stride + else: + stride = 1 + terms.append(_mul(s, stride)) + + terms = [t for t in terms + if not (isinstance(t, tir.IntImm) and int(t.value) == 0)] + if not terms: + return 0 + acc = terms[0] + for t in terms[1:]: + acc = tir.Add(acc, t) + return acc + + def _d_tile_info(self, buf: _hlir.Buffer) -> Tuple[int, int]: + """Return ``(n_d_tiles, d_tile_stride_elems)`` for a VRAM buffer. + + Thin convenience wrapper over ``_tile_layout_strides`` — + ``row_*_at`` emitters only ever walk the d-tile axis (the row / + head coords pick a specific s_inner + h_grp + b within one + d-tile plane), so they just need the outer d-tile count and + the bump amount. + """ + info = self._tile_layout_strides(buf) + if info is None or info["d_tiles"] <= 1: + return 1, 0 + return info["d_tiles"], info["d_tile_stride"] + + def _emit_row_scalar_op_at( + self, + mod: _hlir.HLIRModule, + op: _hlir.Op, + *, + row_op: str, + reduce: bool = False, + masked: bool = False, + has_fp: bool = False, + ) -> None: + """Row-scalar HW vector op on a single logical row of VRAM. + + Region schema (every variant): + * reduce_max / reduce_sum: + buffer_args = [src_region] + scalar_args = [fp_addr (BufferElement)] + * exp / reci (no FP operand): + buffer_args = [src_region, dst_region] + scalar_args = [] + * add / sub / mul fp: + buffer_args = [src_region, dst_region] + scalar_args = [fp_addr (BufferElement)] + + ``src_region`` / ``dst_region`` are ``VramRegion`` with 4D + BSHD starts/extents picking exactly one (b, s, h) logical + row. ``extents`` must be ``(1, 1, 1, D_full)`` — the emitter + walks d_tiles itself. ``starts[2]`` (the H index) is *not* + clipped to a single lane in packed-head buffers: it carries + the actual head idx (0..head_count-1), and the emitter splits + it into (h_grp, lane) — h_grp picks the mlen-row, lane drives + the ``V_MASK`` bit so the V_*_VF only updates the target + head's data. + """ + has_fp = has_fp or reduce + if reduce: + if len(op.buffer_args) != 1: + raise IsaEmissionError( + f"{op.kind} expects 1 buffer_arg (src region); " + f"got {len(op.buffer_args)}" + ) + expected_scalar = 1 + elif has_fp: + if len(op.buffer_args) != 2: + raise IsaEmissionError( + f"{op.kind} expects 2 buffer_args (src/dst regions); " + f"got {len(op.buffer_args)}" + ) + expected_scalar = 1 + else: + if len(op.buffer_args) != 2: + raise IsaEmissionError( + f"{op.kind} expects 2 buffer_args (src/dst regions); " + f"got {len(op.buffer_args)}" + ) + expected_scalar = 0 + if len(op.scalar_args) != expected_scalar: + raise IsaEmissionError( + f"{op.kind} expects {expected_scalar} scalar_args; " + f"got {len(op.scalar_args)}" + ) + for slot, name in enumerate( + ("src",) if reduce else ("src", "dst") + ): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise IsaEmissionError( + f"{op.kind} {name}: expected VramRegion, got " + f"{type(op.buffer_args[slot]).__name__}" + ) + + src_region: _hlir.VramRegion = op.buffer_args[0] + src = mod.get_buffer(src_region.parent) + _check_scope(src, _scope.VRAM, op.kind, "src") + if len(src_region.extents) != 4: + raise IsaEmissionError( + f"{op.kind} src: region must be 4D; got " + f"extents={tuple(src_region.extents)}" + ) + # All non-D extents must be 1 (one logical row per op). + if any(int(e) != 1 for e in src_region.extents[:3]): + raise IsaEmissionError( + f"{op.kind} src: row_*_at processes one logical row, " + f"non-D extents must be 1; got " + f"{tuple(src_region.extents[:3])}" + ) + + fp_addr_expr = None + if has_fp: + fp_addr_expr = self._resolve_fp_scalar_addr_arg( + mod, op.scalar_args[0], op.kind, "fp", + ) + + src_base_off, src_mask_expr, src_info = self._logical_to_phys_row_offset( + src, src_region, + ) + emit_v_mask = masked and src_mask_expr is not None + use_mask_flag = 1 if emit_v_mask else 0 + + mats = [] + m_src = self.materializer.materialize( + tir.Add(tir.IntImm("int32", int(src.address)), src_base_off) + ) + self.shim.compiler.generated_code += m_src.isa + mats.append(m_src) + gp_src = m_src.register + + gp_mask = None + try: + lines = [ + f"; row scalar task {op.annotations.get('intrinsic', op.kind)} " + f"op={row_op} " + f"src.parent={src_region.parent} " + f"starts={list(src_region.starts)!r}" + ] + if emit_v_mask: + m_mask = self.materializer.materialize(src_mask_expr) + self.shim.compiler.generated_code += m_mask.isa + mats.append(m_mask) + gp_mask = m_mask.register + lines.append(f"C_SET_V_MASK_REG gp{gp_mask}") + + n_d_tiles = src_info["d_tiles"] + d_tile_stride_s = src_info["d_tile_stride"] + + if reduce: + # buffer_args=[src_region]; FP destination is scalar_args[0]. + m_dst = self.materializer.materialize(fp_addr_expr) + self.shim.compiler.generated_code += m_dst.isa + mats.append(m_dst) + opcode = {"reduce_max": "V_RED_MAX", + "reduce_sum": "V_RED_SUM"}[row_op] + # V_RED_* accumulate into f1; load the FPRAM slot + # first so kernels that pre-seeded it see the seed. + # Across d_tiles, accumulate into the same f1. + lines.append(f"S_LD_FP f1, gp{m_dst.register}, 0") + for t in range(n_d_tiles): + lines.append(f"{opcode} f1, gp{gp_src}, {use_mask_flag}") + if t < n_d_tiles - 1: + lines.append( + f"S_ADDI_INT gp{gp_src}, gp{gp_src}, " + f"{d_tile_stride_s}" + ) + lines.append(f"S_ST_FP f1, gp{m_dst.register}, 0") + else: + dst_region: _hlir.VramRegion = op.buffer_args[1] + dst = mod.get_buffer(dst_region.parent) + _check_scope(dst, _scope.VRAM, op.kind, "dst") + if len(dst_region.extents) != 4: + raise IsaEmissionError( + f"{op.kind} dst: region must be 4D; got " + f"extents={tuple(dst_region.extents)}" + ) + if any(int(e) != 1 for e in dst_region.extents[:3]): + raise IsaEmissionError( + f"{op.kind} dst: non-D extents must be 1; " + f"got {tuple(dst_region.extents[:3])}" + ) + dst_base_off, dst_mask_expr, dst_info = ( + self._logical_to_phys_row_offset(dst, dst_region) + ) + if emit_v_mask and dst_mask_expr is None: + raise IsaEmissionError( + f"{op.kind} src requires packed-head mask but dst " + f"{dst.name!r} does not" + ) + if emit_v_mask and dst_region.parent != src_region.parent: + warnings.warn( + f"{op.kind}: masked V_*_V with dst " + f"{dst_region.parent!r} != src " + f"{src_region.parent!r} — unmasked heads will " + f"overwrite dst with src; previous cross-lane " + f"writes to dst will be lost. Use in-place " + f"(dst == src) or insert an explicit copy_v_to_v.", + RuntimeWarning, + stacklevel=2, + ) + if dst_info["d_tiles"] != n_d_tiles: + raise IsaEmissionError( + f"{op.kind}: src/dst d_tiles mismatch " + f"({n_d_tiles} vs {dst_info['d_tiles']})" + ) + d_tile_stride_d = dst_info["d_tile_stride"] + m_dst = self.materializer.materialize( + tir.Add(tir.IntImm("int32", int(dst.address)), dst_base_off) + ) + self.shim.compiler.generated_code += m_dst.isa + mats.append(m_dst) + + if fp_addr_expr is None: + # exp / reci + opcode = {"exp": "V_EXP_V", "reci": "V_RECI_V"}[row_op] + for t in range(n_d_tiles): + lines.append( + f"{opcode} gp{m_dst.register}, gp{gp_src}, " + f"{use_mask_flag}" + ) + if t < n_d_tiles - 1: + lines.append( + f"S_ADDI_INT gp{gp_src}, gp{gp_src}, " + f"{d_tile_stride_s}" + ) + lines.append( + f"S_ADDI_INT gp{m_dst.register}, " + f"gp{m_dst.register}, {d_tile_stride_d}" + ) + else: + # add / sub / mul with FP scalar + m_rhs = self.materializer.materialize(fp_addr_expr) + self.shim.compiler.generated_code += m_rhs.isa + mats.append(m_rhs) + lines.append(f"S_LD_FP f1, gp{m_rhs.register}, 0") + for t in range(n_d_tiles): + if row_op == "sub": + lines.append( + f"V_SUB_VF gp{m_dst.register}, gp{gp_src}, " + f"f1, {use_mask_flag}, 0" + ) + else: + opcode = {"add": "V_ADD_VF", + "mul": "V_MUL_VF"}[row_op] + lines.append( + f"{opcode} gp{m_dst.register}, gp{gp_src}, " + f"f1, {use_mask_flag}" + ) + if t < n_d_tiles - 1: + lines.append( + f"S_ADDI_INT gp{gp_src}, gp{gp_src}, " + f"{d_tile_stride_s}" + ) + lines.append( + f"S_ADDI_INT gp{m_dst.register}, " + f"gp{m_dst.register}, {d_tile_stride_d}" + ) + + if emit_v_mask: + lines.append(f"S_ADDI_INT gp{gp_mask}, gp0, 0") + lines.append(f"C_SET_V_MASK_REG gp{gp_mask}") + self.shim.compiler.generated_code += "\n".join(lines) + "\n" + finally: + for m in reversed(mats): + m.release() + + # ------------------------------------------------------------------ + # Per-op dispatchers. Each one is a thin glue between HLIR buffer + # references and ISAEmitter's positional/keyword API. + # ------------------------------------------------------------------ + # ---- DMA decomposition -------------------------------------------- + # Each emit_load/store_tile_from_hbm transfers exactly ONE mlen x mlen + # tile. For HBM buffers whose logical 2D shape spans multiple tiles + # we issue one ISA emit call per tile, walking the buffer in + # col-block-major order (matches --stage-output / runtime helper). + def _iter_tile_offsets(self, hbm_buf: _hlir.Buffer): + """Yield (vram_offset_elems, hbm_offset_elems) for each tile of buf.""" + mlen = self.shim.mlen + ann = hbm_buf.annotations + rows = ann.get("logical_rows", mlen) + cols = ann.get("logical_cols", mlen) + row_blocks = ann.get("row_blocks", 1) + col_blocks = ann.get("col_blocks", 1) + tile_elems = mlen * mlen + idx = 0 + for j in range(col_blocks): + for i in range(row_blocks): + # HBM offset (in elements) of the (i,j) tile within this + # logical 2D buffer. Logical layout is row-major so column + # j contributes j*mlen, row i contributes i*mlen*cols. + hbm_off = i * mlen * cols + j * mlen + vram_off = idx * tile_elems + yield vram_off, hbm_off + idx += 1 + + def _emit_dma_h2v(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + src = mod.get_buffer(op.buffer_args[0]) + dst = mod.get_buffer(op.buffer_args[1]) + _check_scope(src, _scope.HBM, op.kind, "src") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + for vram_off, hbm_off in self._iter_tile_offsets(src): + self.shim.compiler.generated_code += ( + f"; dma_h2v tile {src.name}[hbm+{hbm_off}] -> " + f"{dst.name}[vram+{vram_off}]\n" + ) + self.emitter.emit_load_tile_from_hbm( + hbm_addr=src.address, + vram_addr=dst.address + vram_off, + hbm_stride=src.hbm_stride, + hbm_scale_size=src.hbm_scale_size, + hbm_start_offset=src.hbm_offset + hbm_off, + ) + + def _emit_dma_h2m(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + src = mod.get_buffer(op.buffer_args[0]) + dst = mod.get_buffer(op.buffer_args[1]) + _check_scope(src, _scope.HBM, op.kind, "src") + _check_scope(dst, _scope.MRAM, op.kind, "dst") + for vram_off, hbm_off in self._iter_tile_offsets(src): + self.shim.compiler.generated_code += ( + f"; dma_h2m tile {src.name}[hbm+{hbm_off}] -> " + f"{dst.name}[mram+{vram_off}]\n" + ) + self.emitter.emit_hbm_tile_to_mram( + hbm_addr=src.address, + mram_addr=dst.address + vram_off, + hbm_offset=src.hbm_offset + hbm_off, + hbm_scale=src.hbm_scale_size, + hbm_stride=src.hbm_stride, + ) + + def _emit_dma_v2h(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + # Convention: HBM buffers are BSHD; VRAM buffers may be BHSD if + # they came out of BTMM/BMM_WO (head-major physical layout). This + # iteration walks the HBM (dst) in col-block-major order, which + # happens to land vram_off = idx * tile_elems on each head's tile + # boundary -- exactly matching BMM_WO's BHSD VRAM packing. The + # store thus reorders BHSD -> BSHD as a side-effect of the walk. + src = mod.get_buffer(op.buffer_args[0]) + dst = mod.get_buffer(op.buffer_args[1]) + _check_scope(src, _scope.VRAM, op.kind, "src") + _check_scope(dst, _scope.HBM, op.kind, "dst") + if src.num_elements != dst.num_elements: + raise IsaEmissionError( + f"dma_v2h: src ({src.name}, {src.num_elements} elems) and dst " + f"({dst.name}, {dst.num_elements} elems) must have the same total " + f"size, but their layouts may differ (VRAM=BHSD, HBM=BSHD)." + ) + for vram_off, hbm_off in self._iter_tile_offsets(dst): + self.shim.compiler.generated_code += ( + f"; dma_v2h tile {src.name}[vram+{vram_off}] -> " + f"{dst.name}[hbm+{hbm_off}]\n" + ) + self.emitter.emit_store_tile_to_hbm( + vram_addr=src.address + vram_off, + hbm_addr=dst.address, + hbm_stride=dst.hbm_stride, + hbm_scale_size=dst.hbm_scale_size, + hbm_start_offset=dst.hbm_offset + hbm_off, + ) + + # ------------------------------------------------------------------ + # Sliced DMA dispatchers. The slice is one of buffer_args (src or + # dst depending on direction). For now we restrict to STATIC starts + # (all Python ints / IntImm); dynamic starts (PrimExpr) raise a + # clear error pointing at the next phase. + # ------------------------------------------------------------------ + def _slice_offset_static( + self, parent: _hlir.Buffer, sl: _hlir.BufferSlice, + ) -> int: + """Same math as `_build_slice_offset_expr`, restricted to all-int + starts. Used in the static fast-path (avoids the extra + S_ADDI_INT...mov that the dynamic path inserts).""" + offset = 0 + shape = parent.shape + for i, s in enumerate(sl.starts): + stride_below = 1 + for d in shape[i + 1:]: + stride_below *= int(d) + offset += int(s) * stride_below + return offset + + @staticmethod + def _slice_has_dynamic_start(sl: _hlir.BufferSlice) -> bool: + return any(not isinstance(s, int) for s in sl.starts) + + def _build_slice_offset_expr( + self, parent: _hlir.Buffer, sl: _hlir.BufferSlice, + ): + """Build a PrimExpr for the slice's element offset in `parent`'s + HBM region. Mixes static (Python int / IntImm) and dynamic + (PrimExpr) starts uniformly. ExprMaterializer's constant + folding will collapse static sub-trees automatically. + """ + offset = tir.IntImm("int32", 0) + shape = parent.shape + for i, s in enumerate(sl.starts): + stride_below = 1 + for d in shape[i + 1:]: + stride_below *= int(d) + if isinstance(s, int): + term = tir.IntImm("int32", s * stride_below) + else: + # `s` is a PrimExpr (loop var or compound); multiply by + # stride at the IR level so the materialiser can apply + # strength reduction (e.g. SLLI when stride is 2^k). + term = s * tir.IntImm("int32", stride_below) + offset = offset + term + return offset + + def _check_slice_single_tile( + self, parent: _hlir.Buffer, sl: _hlir.BufferSlice, + ) -> None: + """For input slices (h2v / h2m): must fit in exactly one mlen*mlen + tile after H*D logical-2D collapse, since the destination + VRAM/MRAM buffer is a single-tile staging area. + """ + mlen = self.shim.mlen + ext = sl.extents + if len(ext) != len(parent.shape): + raise IsaEmissionError( + f"slice on {parent.name!r}: extents length {len(ext)} != " + f"parent ndim {len(parent.shape)}" + ) + rows, cols = _hlir.logical_2d_extents(ext, parent.layout) + if rows != mlen or cols != mlen: + raise IsaEmissionError( + f"slice on {parent.name!r} extents={ext} (layout={parent.layout!r}) " + f"maps to logical 2D ({rows}, {cols}); h2v/h2m input slices " + f"must fit a single mlen*mlen tile." + ) + + def _iter_slice_tiles_per_head(self, parent: _hlir.Buffer, sl: _hlir.BufferSlice): + """Per-head multi-tile iterator for v2h-slice writeback. + + Used when the slice covers `eh` heads but a single mlen-aligned + block in seq and dim dims. Each tile is one head's contribution. + Yields tuples `(h_idx, vram_off_in_src, tile_const_in_parent_elems)`: + * `h_idx` -- which head within the slice (0..eh-1) + * `vram_off_in_src` -- offset in elements from the VRAM source + buffer's base. Assumes BHSD physical + layout (BMM_WO output convention), + where head h's tile sits at + `h * tile_elems`. + * `tile_const_in_parent_elems` -- additive offset to combine + with the slice's BASE offset to land + at this tile's element 0 in the parent. + + Constraints (matches the v2h_slice dispatcher's expectations): + * parent is 4D BSHD + * eb == 1 (single batch) + * es == mlen (single seq tile per head) + * ed == mlen (single dim tile per head) + * eh >= 1 (any number of heads; eh > 1 is the multi-tile case) + """ + mlen = self.shim.mlen + if len(parent.shape) != 4: + raise IsaEmissionError( + f"per-head slice tiling requires 4D parent; got " + f"shape {parent.shape}" + ) + # Permute parent shape and slice extents into canonical (B, S, H, D) + # order per parent.layout. Downstream math works in BSHD. + B, S, H, D = _hlir._select_axes(parent.shape, parent.layout) + eb, es, eh, ed = _hlir._select_axes(sl.extents, parent.layout) + if eb != 1: + raise IsaEmissionError( + f"per-head slice tiling does not support batch slicing " + f"(eb={eb})" + ) + if es != mlen: + raise IsaEmissionError( + f"per-head slice tiling requires es == mlen ({mlen}); " + f"got es={es}" + ) + if ed != mlen: + raise IsaEmissionError( + f"per-head slice tiling requires ed == mlen ({mlen}); " + f"got ed={ed}" + ) + if eh < 1: + raise IsaEmissionError(f"slice has no heads to iterate (eh={eh})") + # Per-head HBM offset is h_idx * h_stride, where h_stride is + # the canonical channel-axis stride in HBM-row-major order. + # For BSHD (head outer of D): h_stride = D — what the legacy + # ``h_idx * D`` formula assumed. For NCHW (channel outer of + # H*W): h_stride = H*W — different by a row-count factor. + # ``hbm_strides_for_layout`` returns the right number for any + # layout we register. + _hb, _hs, h_stride, _hd = _hlir.hbm_strides_for_layout( + parent.shape, parent.layout, + ) + tile_elems = mlen * mlen + for h_idx in range(eh): + yield h_idx, h_idx * tile_elems, h_idx * int(h_stride) + + def _slice_is_single_logical_tile( + self, parent: _hlir.Buffer, sl: _hlir.BufferSlice, + ) -> bool: + ext = sl.extents + if len(ext) != len(parent.shape): + return False + rows, cols = _hlir.logical_2d_extents(ext, parent.layout) + return rows == self.shim.mlen and cols == self.shim.mlen + + def _materialise_slice_offset( + self, parent: _hlir.Buffer, sl: _hlir.BufferSlice, + ): + """Returns either: + (None, int_offset) -- all starts static; caller uses + existing int-offset emit path. + (MaterializedExpr, None) -- dynamic; caller uses *_reg emit + path and must release the result. + """ + if not self._slice_has_dynamic_start(sl): + return None, parent.hbm_offset + self._slice_offset_static(parent, sl) + # Dynamic path: build expr (includes parent.hbm_offset for safety) + # and lower via ExprMaterializer. + expr = self._build_slice_offset_expr(parent, sl) + if parent.hbm_offset: + expr = expr + tir.IntImm("int32", parent.hbm_offset) + m = self.materializer.materialize(expr) + self.shim.compiler.generated_code += m.isa + return m, None + + @staticmethod + def _format_starts(sl: _hlir.BufferSlice) -> str: + return ",".join( + str(s) if isinstance(s, int) else f"<{type(s).__name__}>" + for s in sl.starts + ) + + def _slice_tile_grid( + self, + parent: _hlir.Buffer, + sl: _hlir.BufferSlice, + on_chip: _hlir.Buffer, + ): + """Compute the inner-tile grid for one HBM↔on-chip slice copy. + + Symmetric between h2v load and v2h store: the iteration grid + only depends on the on-chip buffer's physical layout and the + slice's footprint, not on direction. The returned strides feed + either ``emit_load_tile_from_hbm`` or ``emit_store_tile_to_hbm``. + + Returns a tuple ``(d_tiles, s_tiles, h_groups, logical_b, + inner_mlen, lane_count, hbm_strides, vram_strides)`` where + ``hbm_strides`` is ``(b, s, h)`` and ``vram_strides`` is + ``(d_tile, s_tile, h_grp, b)`` — all in elements. + + Single tile (one mlen×mlen tile, or smaller) is the degenerate + case where every tile count is 1 — the caller runs one loop + iteration and emits a single H_LOAD_V / H_STORE_V. + """ + dst = on_chip + mlen = self.shim.mlen + + if dst.tile_layout is not None: + # 4D BSHD path. Parent must be 4D — we read its layout- + # aware per-axis HBM strides (b/s/h/d). + if len(parent.shape) != 4: + raise IsaEmissionError( + f"dma_h2v_slice: HBM parent {parent.name!r} must " + f"be 4D for tile_layout-driven dst; got shape " + f"{tuple(parent.shape)}" + ) + hbm_stride_b, hbm_stride_s, hbm_stride_h, _hbm_stride_d = ( + _hlir.hbm_strides_for_layout(parent.shape, parent.layout) + ) + tl = dst.tile_layout + inner_d = tl.d_inner + inner_lane = tl.lane_count * inner_d + inner_s = tl.mlen * inner_lane + b_stride = inner_s + inner_b = tl.logical_b * inner_s + h_grp_stride = inner_b + s_tile_stride = tl.h_groups * inner_b + d_tile_stride = tl.s_tiles * s_tile_stride + return ( + tl.d_tiles, tl.s_tiles, tl.h_groups, tl.logical_b, + tl.mlen, tl.lane_count, + (hbm_stride_b, hbm_stride_s, hbm_stride_h), + (d_tile_stride, s_tile_stride, h_grp_stride, b_stride), + ) + + # No tile_layout: row-major flat view. dst dictates the tile + # grid; HBM strides come from parent's row-major flat layout + # treating axes[-2] as row and axes[-1] as col regardless of + # layout-name semantics. This is the right model for any + # plain ``T.alloc_shared((rows, cols))`` staging buffer. + if len(dst.shape) == 2: + dst_rows = int(dst.shape[0]) + dst_cols = int(dst.shape[1]) + elif len(dst.shape) == 1: + dst_rows = 1 + dst_cols = int(dst.shape[0]) + else: + raise IsaEmissionError( + f"dma_h2v_slice on {parent.name!r}: dst {dst.name!r} " + f"has rank {len(dst.shape)} with no tile_layout; " + f"only 1D / 2D dst supported on the row-major-flat path" + ) + if dst_rows % mlen != 0 or dst_cols % mlen != 0: + raise IsaEmissionError( + f"dma_h2v_slice on {parent.name!r}: dst {dst.name!r} " + f"shape ({dst_rows}, {dst_cols}) is not (MLEN={mlen})-" + f"aligned on both axes; partial-tile loads not supported" + ) + s_tiles = dst_rows // mlen + d_tiles = dst_cols // mlen + d_tile_stride = mlen + s_tile_stride = mlen * dst_cols + + # Parent HBM row stride = product of axes after the row axis. + # For 4D ``(N, C, H, W)`` and a per-channel slice, row = + # axes[-2] (H), col = axes[-1] (W); per-row HBM stride = W. + # We map this to the existing (b/s/h) stride triple by routing + # the row stride through ``hbm_stride_s`` (s_tile multiplier in + # the emitter); h-stride / b-stride aren't iterated in this + # path (h_groups == logical_b == 1). + pshape = [int(x) for x in parent.shape] + if len(pshape) < 2: + # 1D HBM parent — degenerate row stride = 1. + hbm_row_stride = 1 + else: + hbm_row_stride = 1 + for d in pshape[-1:]: + hbm_row_stride = int(d) + return ( + d_tiles, s_tiles, 1, 1, + mlen, 1, + (0, hbm_row_stride, 0), + (d_tile_stride, s_tile_stride, 0, 0), + ) + + def _emit_dma_h2v_slice(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Emit H_LOAD_V instructions for one HBM→VRAM slice copy. + + Single path for all dst shapes: compute an inner-tile grid via + :meth:`_h2v_tile_grid` and walk it. The grid is ``1×1×1×1`` for + a slice that fits one mlen×mlen tile; larger dst (2D row-major + or 4D BSHD with ``tile_layout``) expand into multiple issues. + """ + sl = op.buffer_args[0] + _arg1 = op.buffer_args[1] + if isinstance(_arg1, _hlir.BufferSlice): + raise IsaEmissionError( + f"dma_h2v_slice: dst (buffer_args[1]) must be a whole-buffer " + f"name; got BufferSlice(parent={_arg1.parent!r}, " + f"starts={list(_arg1.starts)}, extents={list(_arg1.extents)})" + ) + dst = mod.get_buffer(_arg1) + if not isinstance(sl, _hlir.BufferSlice): + raise IsaEmissionError( + f"dma_h2v_slice: buffer_args[0] must be BufferSlice, got " + f"{type(sl).__name__}" + ) + parent = mod.get_buffer(sl.parent) + _check_scope(parent, _scope.HBM, op.kind, "src.parent") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + + (d_tiles, s_tiles, h_groups, logical_b, + inner_mlen, lane_count, + (hbm_stride_b, hbm_stride_s, hbm_stride_h), + (d_tile_stride, s_tile_stride, h_grp_stride, b_stride)) = ( + self._slice_tile_grid(parent, sl, dst) + ) + + m_off, slice_static = self._materialise_slice_offset(parent, sl) + base_static = parent.hbm_offset + ( + slice_static if slice_static is not None else 0 + ) + + starts_s = self._format_starts(sl) + self.shim.compiler.generated_code += ( + f"; dma_h2v_slice {parent.name}[{starts_s}]+{list(sl.extents)} " + f"-> {dst.name} " + f"(grid d_tiles={d_tiles}, s_tiles={s_tiles}, " + f"h_groups={h_groups}, b={logical_b}" + f"{', dyn base gp' + str(m_off.register) if m_off is not None else ''})\n" + ) + for d_tile in range(d_tiles): + for s_tile in range(s_tiles): + for h_grp in range(h_groups): + for b in range(logical_b): + hbm_off = ( + base_static + + b * hbm_stride_b + + s_tile * inner_mlen * hbm_stride_s + + h_grp * lane_count * hbm_stride_h + + d_tile * inner_mlen + ) + vram_off = ( + d_tile * d_tile_stride + + s_tile * s_tile_stride + + h_grp * h_grp_stride + + b * b_stride + ) + self.shim.compiler.generated_code += ( + f"; tile (d={d_tile}, s={s_tile}, h={h_grp}, " + f"b={b}): hbm_off={hbm_off} vram_off={vram_off}\n" + ) + if m_off is not None: + self.emitter.emit_load_tile_from_hbm( + hbm_addr=parent.address, + vram_addr=dst.address + vram_off, + hbm_stride=parent.hbm_stride, + hbm_scale_size=parent.hbm_scale_size, + hbm_start_offset=hbm_off, + hbm_start_offset_reg=m_off.register, + ) + else: + self.emitter.emit_load_tile_from_hbm( + hbm_addr=parent.address, + vram_addr=dst.address + vram_off, + hbm_stride=parent.hbm_stride, + hbm_scale_size=parent.hbm_scale_size, + hbm_start_offset=hbm_off, + ) + if m_off is not None: + m_off.release() + + def _emit_dma_h2m_slice(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + sl = op.buffer_args[0] + dst = mod.get_buffer(op.buffer_args[1]) + if not isinstance(sl, _hlir.BufferSlice): + raise IsaEmissionError( + f"dma_h2m_slice: buffer_args[0] must be BufferSlice" + ) + parent = mod.get_buffer(sl.parent) + _check_scope(parent, _scope.HBM, op.kind, "src.parent") + _check_scope(dst, _scope.MRAM, op.kind, "dst") + self._check_slice_single_tile(parent, sl) + + m_off, static_off = self._materialise_slice_offset(parent, sl) + starts_s = self._format_starts(sl) + if m_off is None: + self.shim.compiler.generated_code += ( + f"; dma_h2m_slice {parent.name}[{starts_s}]+{list(sl.extents)} " + f"-> {dst.name} (parent_off={static_off} elems)\n" + ) + self.emitter.emit_hbm_tile_to_mram( + hbm_addr=parent.address, mram_addr=dst.address, + hbm_offset=static_off, + hbm_scale=parent.hbm_scale_size, hbm_stride=parent.hbm_stride, + ) + else: + self.shim.compiler.generated_code += ( + f"; dma_h2m_slice {parent.name}[{starts_s}]+{list(sl.extents)} " + f"-> {dst.name} (parent_off=gp{m_off.register} dyn)\n" + ) + self.emitter.emit_hbm_tile_to_mram( + hbm_addr=parent.address, mram_addr=dst.address, + hbm_offset_reg=m_off.register, + hbm_scale=parent.hbm_scale_size, hbm_stride=parent.hbm_stride, + ) + m_off.release() + + def _emit_dma_v2h_slice(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Emit H_STORE_V instructions for one VRAM→HBM slice copy. + + Mirror of :meth:`_emit_dma_h2v_slice`: same tile-grid model + from :meth:`_slice_tile_grid`, same per-tile offset math, just + store-direction emit. Single tile = ``1×1×1×1`` grid → one + H_STORE_V. Multi-tile (4D with tile_layout, or 2D row-major + larger than one mlen tile) expands naturally. + """ + src = mod.get_buffer(op.buffer_args[0]) + sl = op.buffer_args[1] + if not isinstance(sl, _hlir.BufferSlice): + raise IsaEmissionError( + f"dma_v2h_slice: buffer_args[1] must be BufferSlice" + ) + parent = mod.get_buffer(sl.parent) + _check_scope(src, _scope.VRAM, op.kind, "src") + _check_scope(parent, _scope.HBM, op.kind, "dst.parent") + + ra = self.shim.compiler.register_allocator + + (d_tiles, s_tiles, h_groups, logical_b, + inner_mlen, lane_count, + (hbm_stride_b, hbm_stride_s, hbm_stride_h), + (d_tile_stride, s_tile_stride, h_grp_stride, b_stride)) = ( + self._slice_tile_grid(parent, sl, src) + ) + + m_base, static_base = self._materialise_slice_offset(parent, sl) + is_dyn = m_base is not None + base_static = static_base if static_base is not None else 0 + + starts_s = self._format_starts(sl) + self.shim.compiler.generated_code += ( + f"; dma_v2h_slice {src.name} -> " + f"{parent.name}[{starts_s}]+{list(sl.extents)} " + f"(grid d_tiles={d_tiles}, s_tiles={s_tiles}, " + f"h_groups={h_groups}, b={logical_b}" + f"{', dyn base gp' + str(m_base.register) if is_dyn else ''})\n" + ) + for d_tile in range(d_tiles): + for s_tile in range(s_tiles): + for h_grp in range(h_groups): + for b in range(logical_b): + tile_const = ( + b * hbm_stride_b + + s_tile * inner_mlen * hbm_stride_s + + h_grp * lane_count * hbm_stride_h + + d_tile * inner_mlen + ) + vram_off = ( + d_tile * d_tile_stride + + s_tile * s_tile_stride + + h_grp * h_grp_stride + + b * b_stride + ) + tile_vram = src.address + vram_off + self.shim.compiler.generated_code += ( + f"; tile (d={d_tile}, s={s_tile}, h={h_grp}, " + f"b={b}): vram[+{vram_off}] -> " + f"hbm[base+{tile_const}]\n" + ) + if is_dyn: + if tile_const == 0: + tile_off_reg = m_base.register + tile_off_owned = False + else: + tile_off_reg = ra.allocate_gp(1)[0] + tile_off_owned = True + self.shim.compiler.generated_code += ( + f"S_ADDI_INT gp{tile_off_reg}, " + f"gp{m_base.register}, {tile_const}\n" + ) + self.emitter.emit_store_tile_to_hbm( + vram_addr=tile_vram, hbm_addr=parent.address, + hbm_stride=parent.hbm_stride, + hbm_scale_size=parent.hbm_scale_size, + hbm_start_offset_reg=tile_off_reg, + ) + if tile_off_owned: + ra.free_gp([tile_off_reg]) + else: + self.emitter.emit_store_tile_to_hbm( + vram_addr=tile_vram, hbm_addr=parent.address, + hbm_stride=parent.hbm_stride, + hbm_scale_size=parent.hbm_scale_size, + hbm_start_offset=base_static + tile_const, + ) + if is_dyn: + m_base.release() + + if is_dyn: + m_base.release() + + def _emit_btmm(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Lane-fused (packed-head) Q @ K^T. + + Region schema: + buffer_args = [a_region (VramRegion), b_region (MramRegion), + c_region (VramRegion)] + scalar_args = [a_dim_roles, b_dim_roles, c_dim_roles] + + BTMM HW takes the whole packed-head tile in one issue; the + regions describe the full operands and the emitter doesn't + need to walk M/K/N internally — just hand off the three + physical base addresses. Per-lane offsets are zero here + (multi-lane HW instruction spans every lane natively). + """ + if len(op.buffer_args) != 3: + raise IsaEmissionError( + f"plena.btmm expects 3 buffer_args (regions); " + f"got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise IsaEmissionError( + f"plena.btmm a: expected VramRegion, got " + f"{type(a_reg).__name__}" + ) + if not isinstance(b_reg, _hlir.MramRegion): + raise IsaEmissionError( + f"plena.btmm b: expected MramRegion, got " + f"{type(b_reg).__name__}" + ) + if not isinstance(c_reg, _hlir.VramRegion): + raise IsaEmissionError( + f"plena.btmm c: expected VramRegion, got " + f"{type(c_reg).__name__}" + ) + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + _check_scope(lhs, _scope.VRAM, op.kind, "lhs") + _check_scope(rhs, _scope.MRAM, op.kind, "rhs") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + + # Result-tile-count for the writeback. For our minimal_btmm: + # lhs (mlen, gh, hlen) x rhs (mlen, gh, hlen) -> dst (mlen, gh, mlen) + # so dst has gh tiles per group, total tile_count = + # (mlen*gh*mlen)/tile_elems. + tile_count = max(1, dst.num_elements // self.shim.tile_elems) + + self.emitter.emit_btmm( + lhs_packed_vram_addr=lhs.address, + rhs_mram_addr=rhs.address, + task_id=op.annotations.get("intrinsic", "btmm"), + ) + self.emitter.emit_btmm_wo( + base_addr=dst.address, + tile_count=tile_count, + task_id=op.annotations.get("intrinsic", "btmm") + ".wo", + ) + + def _emit_mv(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Per-head matrix-vector: M_MV + M_MV_WO. + + Region schema (same shape as matmul): + buffer_args = [a_region, b_region, c_region] + a_region: VramRegion (rank-1 LHS row, M extent == 1) + b_region: MramRegion + c_region: VramRegion + scalar_args = [a_dim_roles, b_dim_roles, c_dim_roles] + 4-tuples of "M"/"K"/"N"/"_". + + Origin offsets come from the regions' starts (lane_var on the + cluster axis when wrapped in CLUSTER), translated to physical + offsets via ``_tile_layout_strides``. + """ + if len(op.buffer_args) != 3: + raise IsaEmissionError( + f"plena.mv expects 3 buffer_args (a/b/c regions); " + f"got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise IsaEmissionError( + f"plena.mv a: expected VramRegion, got " + f"{type(a_reg).__name__}" + ) + if not isinstance(b_reg, _hlir.MramRegion): + raise IsaEmissionError( + f"plena.mv b: expected MramRegion, got " + f"{type(b_reg).__name__}" + ) + if not isinstance(c_reg, _hlir.VramRegion): + raise IsaEmissionError( + f"plena.mv c: expected VramRegion, got " + f"{type(c_reg).__name__}" + ) + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + _check_scope(lhs, _scope.VRAM, op.kind, "lhs") + _check_scope(rhs, _scope.MRAM, op.kind, "rhs") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + + lhs_raw_off = self._region_origin_offset(lhs, a_reg) + rhs_raw_off = self._region_origin_offset(rhs, b_reg) + dst_raw_off = self._region_origin_offset(dst, c_reg) + + def _resolve(expr, name): + """Returns (static_int_or_None, gp_register_or_None, handle_or_None).""" + if isinstance(expr, tir.IntImm): + return int(expr.value), None, None + if isinstance(expr, int): + return int(expr), None, None + m = self.materializer.materialize(expr) + self.shim.compiler.generated_code += m.isa + return None, m.register, m + + lhs_static, lhs_reg, lhs_h = _resolve(lhs_raw_off, "lhs_offset") + rhs_static, rhs_reg, rhs_h = _resolve(rhs_raw_off, "rhs_offset") + dst_static, dst_reg, dst_h = _resolve(dst_raw_off, "dst_offset") + + try: + self.emitter.emit_mv( + lhs_vram_addr=lhs.address + (lhs_static or 0), + rhs_mram_addr=rhs.address + (rhs_static or 0), + dst_vram_addr=dst.address + (dst_static or 0), + lhs_offset_reg=lhs_reg, + rhs_offset_reg=rhs_reg, + dst_offset_reg=dst_reg, + task_id=op.annotations.get("intrinsic", "mv"), + ) + finally: + for h in (lhs_h, rhs_h, dst_h): + if h is not None: + h.release() + + def _emit_btmv(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Lane-fused matrix-vector (decode-style btmm with M=1). + + Region schema same as _emit_btmm; differs only in op kind. + """ + if len(op.buffer_args) != 3: + raise IsaEmissionError( + f"plena.btmv expects 3 buffer_args (regions); " + f"got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise IsaEmissionError( + f"plena.btmv a: expected VramRegion, got " + f"{type(a_reg).__name__}" + ) + if not isinstance(b_reg, _hlir.MramRegion): + raise IsaEmissionError( + f"plena.btmv b: expected MramRegion, got " + f"{type(b_reg).__name__}" + ) + if not isinstance(c_reg, _hlir.VramRegion): + raise IsaEmissionError( + f"plena.btmv c: expected VramRegion, got " + f"{type(c_reg).__name__}" + ) + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + _check_scope(lhs, _scope.VRAM, op.kind, "lhs") + _check_scope(rhs, _scope.MRAM, op.kind, "rhs") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + + self.emitter.emit_btmv( + lhs_packed_vram_addr=lhs.address, + rhs_mram_addr=rhs.address, + task_id=op.annotations.get("intrinsic", "btmv"), + ) + self.emitter.emit_bmv_wo( + base_addr=dst.address, + task_id=op.annotations.get("intrinsic", "btmv") + ".wo", + ) + + def _emit_mm(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Single-tile, single-head matrix multiply. + + Maps `plena.mm(lhs_vram, rhs_mram, dst_vram)` to one M_MM / + M_MM_WO sequence (via ISAEmitter.emit_matmul with a single + lhs/rhs pair). The dst tile is fully overwritten — no implicit + accumulation across calls. Streaming-style accumulation is the + kernel author's job (tile_zero + mm + tile_add into a separate + accumulator tile, see kernels/tiled_mm.py). + """ + lhs = mod.get_buffer(op.buffer_args[0]) + rhs = mod.get_buffer(op.buffer_args[1]) + dst = mod.get_buffer(op.buffer_args[2]) + _check_scope(lhs, _scope.VRAM, op.kind, "lhs") + _check_scope(rhs, _scope.MRAM, op.kind, "rhs") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + lhs_rows, lhs_cols = self._logical_2d(lhs.shape) + rhs_rows, rhs_cols = self._logical_2d(rhs.shape) + dst_rows, dst_cols = self._logical_2d(dst.shape) + if lhs_rows != self.shim.mlen or lhs_cols != self.shim.mlen: + raise IsaEmissionError( + f"plena.mm lhs must be one full mlen*mlen tile; got logical 2D " + f"({lhs_rows}, {lhs_cols}) for buffer {lhs.name}" + ) + if rhs_rows != self.shim.mlen: + raise IsaEmissionError( + f"plena.mm rhs must have mlen rows; got logical 2D " + f"({rhs_rows}, {rhs_cols}) for buffer {rhs.name}" + ) + if dst_rows != self.shim.mlen: + raise IsaEmissionError( + f"plena.mm dst must have mlen rows; got logical 2D " + f"({dst_rows}, {dst_cols}) for buffer {dst.name}" + ) + if rhs_cols != dst_cols: + raise IsaEmissionError( + f"plena.mm rhs/dst logical widths must match; got rhs={rhs_cols} dst={dst_cols}" + ) + # Use the hw-loop emitter (tens of static lines) instead of the + # Python-unrolled emit_matmul (~2k lines per call). Dynamic + # instruction count is identical; hw loops just shrink the ISA + # text. Important for kernels that invoke plena.mm under several + # unrolled outer levels (q*h*d*kv) where ASM size scales with + # the product. + if rhs_cols == self.shim.mlen and dst_cols == self.shim.mlen: + self.emitter.emit_matmul_single_tile_hwloop( + lhs_vram_addr=lhs.address, + rhs_mram_addr=rhs.address, + dst_vram_addr=dst.address, + task_id=op.annotations.get("intrinsic", "mm"), + ) + return + self.emitter.emit_matmul_narrow_tile_hwloop( + lhs_vram_addr=lhs.address, + rhs_mram_addr=rhs.address, + dst_vram_addr=dst.address, + hlen=rhs_cols, + dst_row_stride=dst_cols, + task_id=op.annotations.get("intrinsic", "mm"), + ) + + def _emit_matmul(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Unified `(M, K) @ (K, N) -> (M, N)` matmul. + + Region schema (Einstein-style): + buffer_args = [a_region, b_region, c_region] + Vram/MramRegion (4D start+extent on the parent buffer's + physical shape). a_region is VRAM, b_region is MRAM, + c_region is VRAM. starts encode per-lane / per-batch + origin; extents are the per-axis logical span. + scalar_args = [a_dim_roles, b_dim_roles, c_dim_roles] + each a 4-tuple of "M"/"K"/"N"/"_" labels aligned with + the matching region. K appears in a and b but not in + c (contracted); M in a and c; N in b and c. The + ordering of K vs N inside b's roles tells the emitter + whether B is K-inner (standard, M_MM) or K-outer + (transpose_B, M_TMM). + """ + if len(op.buffer_args) != 3: + raise IsaEmissionError( + f"plena.matmul expects 3 buffer_args (a/b/c regions); " + f"got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise IsaEmissionError( + f"plena.matmul a: expected VramRegion, got " + f"{type(a_reg).__name__}" + ) + if not isinstance(b_reg, _hlir.MramRegion): + raise IsaEmissionError( + f"plena.matmul b: expected MramRegion, got " + f"{type(b_reg).__name__}" + ) + if not isinstance(c_reg, _hlir.VramRegion): + raise IsaEmissionError( + f"plena.matmul c: expected VramRegion, got " + f"{type(c_reg).__name__}" + ) + if len(op.scalar_args) != 3: + raise IsaEmissionError( + f"plena.matmul expects 3 scalar_args (a/b/c dim_roles); " + f"got {len(op.scalar_args)}" + ) + a_roles, b_roles, c_roles = op.scalar_args + if len(a_roles) != 4 or len(b_roles) != 4 or len(c_roles) != 4: + raise IsaEmissionError( + f"plena.matmul dim_roles must each be 4-tuples; got " + f"a={a_roles!r} b={b_roles!r} c={c_roles!r}" + ) + + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + _check_scope(lhs, _scope.VRAM, op.kind, "lhs") + _check_scope(rhs, _scope.MRAM, op.kind, "rhs") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + + mlen = int(self.shim.mlen) + + def _find_role_axis(roles: Tuple[str, ...], role: str, + operand: str) -> Optional[int]: + hits = [i for i, r in enumerate(roles) if r == role] + if not hits: + return None + if len(hits) > 1: + raise IsaEmissionError( + f"plena.matmul {operand}: role {role!r} appears at " + f"multiple axes {hits} in roles {roles!r}" + ) + return hits[0] + + c_M_axis = _find_role_axis(c_roles, "M", "c") + c_N_axis = _find_role_axis(c_roles, "N", "c") + a_M_axis = _find_role_axis(a_roles, "M", "a") + a_K_axis = _find_role_axis(a_roles, "K", "a") + b_K_axis = _find_role_axis(b_roles, "K", "b") + b_N_axis = _find_role_axis(b_roles, "N", "b") + for axis, name in ( + (c_M_axis, "c.M"), (c_N_axis, "c.N"), + (a_M_axis, "a.M"), (a_K_axis, "a.K"), + (b_K_axis, "b.K"), (b_N_axis, "b.N"), + ): + if axis is None: + raise IsaEmissionError( + f"plena.matmul: missing {name} axis in dim_roles; " + f"a={a_roles!r} b={b_roles!r} c={c_roles!r}" + ) + + M = int(a_reg.extents[a_M_axis]) + K = int(a_reg.extents[a_K_axis]) + N = int(b_reg.extents[b_N_axis]) + if int(b_reg.extents[b_K_axis]) != K: + raise IsaEmissionError( + f"plena.matmul: a.K extent {K} != b.K extent " + f"{int(b_reg.extents[b_K_axis])}" + ) + if int(c_reg.extents[c_M_axis]) != M: + raise IsaEmissionError( + f"plena.matmul: c.M extent {int(c_reg.extents[c_M_axis])} " + f"!= a.M extent {M}" + ) + if int(c_reg.extents[c_N_axis]) != N: + raise IsaEmissionError( + f"plena.matmul: c.N extent {int(c_reg.extents[c_N_axis])} " + f"!= b.N extent {N}" + ) + + if M % mlen != 0 or K % mlen != 0: + raise IsaEmissionError( + f"plena.matmul: M ({M}) and K ({K}) must be multiples of " + f"MLEN ({mlen})" + ) + M_tiles = M // mlen + K_tiles = K // mlen + # transpose_b: standard layout has B = (K, N) row-major, i.e. + # K is the outer (slower-varying) dim and N is inner — in + # physical axis indices, ``K_axis < N_axis``. When the kernel + # author intends ``B = (N, K)`` (nn.Linear weight convention) + # the order flips: ``N_axis < K_axis``, and the emitter must + # swap M_MM for M_TMM so the systolic array sees the right + # operand orientation. + transpose_b = b_N_axis < b_K_axis + + # dst_row_stride = elements between consecutive rows of C in + # its physical VRAM layout (emit_matmul_general walks C as a + # row-major (M, N) block stepping by this much per row). + # + # The right value depends on where C's cluster axis sits: + # + # * packed-head dst (cluster_dim on the H axis, lane_count>1, + # e.g. PV_loc 1x1024x8x128): one physical mlen row holds + # LANE_COUNT heads side by side. The M (row) axis is S = + # axis 1, and consecutive S rows are S_INNER_STRIDE + # (= LANE_COUNT*D_INNER) apart. Deriving the stride from + # c_reg.extents (the logical region, H-extent 1) yields + # only N and makes every row-(r+1) write land in row-r's + # head-1 slot. + # + # * cluster axis NOT on H (e.g. S_loc 8x1024x1x1024 has the + # cluster on B/axis 0, lane_count==1), or no tile_layout: + # the per-row spacing is just the extents product after the + # M axis — the legacy path is correct there. + # + # So only switch to the physical s_inner stride when the dst is + # genuinely packed-head AND M is the S axis the s_inner stride + # describes. + dst_cluster_dim = getattr(dst, "cluster_dim", None) + tl_info = self._tile_layout_strides(dst) + packed_head_dst = ( + tl_info is not None + and dst_cluster_dim == 2 # cluster on the H axis + and int(tl_info["lane_count"]) > 1 + and c_M_axis == 1 # M is the S axis + ) + if packed_head_dst: + dst_row_stride = int(tl_info["s_inner_stride"]) + else: + dst_row_stride = 1 + for ax in range(c_M_axis + 1, len(c_reg.extents)): + dst_row_stride *= int(c_reg.extents[ax]) + if dst_row_stride <= 0: + dst_row_stride = None + + lhs_raw_off = self._region_origin_offset(lhs, a_reg) + rhs_raw_off = self._region_origin_offset(rhs, b_reg) + dst_raw_off = self._region_origin_offset(dst, c_reg) + + # Each of lhs / rhs / dst offsets supports either a compile-time + # int (folded into the emitter's static residual) or an arbitrary + # PrimExpr (materialised to a gp register here, passed in via the + # matching `*_offset_reg`). Two offsets that are structurally the + # same PrimExpr (e.g. ``rhs = by*hlen``, ``dst = by*hlen``) share + # one materialised register so we don't run into the 16-GP cap. + # Materialised registers are released after the emit returns. + materialised_handles: List = [] + cached: List = [] # list of (raw_expr, register) for CSE lookup + + def _resolve_offset(raw, name: str): + if isinstance(raw, tir.IntImm): + return int(raw.value), None + if isinstance(raw, int): + return int(raw), None + if isinstance(raw, tir.PrimExpr): + for prev_raw, prev_reg in cached: + if tvm.ir.structural_equal(prev_raw, raw): + return 0, prev_reg + m = self.materializer.materialize(raw) + self.shim.compiler.generated_code += m.isa + cached.append((raw, m.register)) + # Pin so emit_matmul_general's ``allocate_gp(7)`` can't + # pick this register as a spill candidate and silently + # corrupt the offset. + # + # ONLY for caller-owned registers. A register that the + # materializer handed out with ``owns_register=False`` + # belongs to the per-op idx cache: it is already pinned + # by that cache and will be unpinned/freed by + # ``end_op``. Pinning / unpinning / releasing it here + # would race that ownership (double-free, premature + # unpin). So we record it for cleanup ONLY when we own + # it, and leave cache-owned registers untouched. + if m.owns_register: + materialised_handles.append(m) + self.shim.compiler.register_allocator.pin_gp(m.register) + return 0, m.register + raise IsaEmissionError( + f"plena.matmul {name} must be int or PrimExpr; got {raw!r}" + ) + + # Pre-allocate the 7 scratch GPs the emitter needs and pin them + # BEFORE materialising the dynamic offsets. Order matters: if we + # materialised first, _auto_spill triggered by allocate_gp(7) could + # spill the offset regs (despite pin_gp) and then hand the same + # physical registers back as scratch — silently aliasing + # `lhs_offset_reg`/`rhs_offset_reg`/`dst_offset_reg` with + # `gp_act_orow`/`gp_out_orow`/`gp_mat`. By taking scratch first we + # guarantee offset regs are disjoint from scratch regs. + ra = self.shim.compiler.register_allocator + scratch_regs = ra.allocate_gp(7) + for r in scratch_regs: + ra.pin_gp(r) + + try: + lhs_off_static, lhs_off_reg = _resolve_offset(lhs_raw_off, "lhs_offset") + rhs_off_static, rhs_off_reg = _resolve_offset(rhs_raw_off, "rhs_offset") + dst_off_static, dst_off_reg = _resolve_offset(dst_raw_off, "dst_offset") + + self.emitter.emit_matmul_general( + M_tiles=M_tiles, + K_tiles=K_tiles, + N=N, + lhs_vram_base=int(lhs.address), + lhs_offset=lhs_off_static, + lhs_offset_reg=lhs_off_reg, + rhs_mram_base=int(rhs.address), + rhs_offset=rhs_off_static, + rhs_offset_reg=rhs_off_reg, + dst_vram_base=int(dst.address), + dst_offset=dst_off_static, + dst_offset_reg=dst_off_reg, + dst_row_stride=dst_row_stride, + task_id=op.annotations.get("intrinsic", "matmul"), + scratch_regs=scratch_regs, + transpose_b=transpose_b, + # Fully unroll — emit static M_MM/M_MM_WO instead of + # nested C_LOOP, eliminating the gp loop-counter and + # per-iter address-advance overhead. + unroll_loops=True, + ) + finally: + for m in materialised_handles: + ra.unpin_gp(m.register) + m.release() + for r in scratch_regs: + ra.unpin_gp(r) + ra.free_gp(scratch_regs) + + def _emit_mm_slot(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + lhs = mod.get_buffer(op.buffer_args[0]) + rhs = mod.get_buffer(op.buffer_args[1]) + dst = mod.get_buffer(op.buffer_args[2]) + _check_scope(lhs, _scope.VRAM, op.kind, "lhs") + _check_scope(rhs, _scope.MRAM, op.kind, "rhs") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + if len(op.scalar_args) != 4: + raise IsaEmissionError( + f"plena.mm_slot expects exactly 4 scalar args " + f"(lhs_row_offset, rhs_col_offset, dst_col_offset, col_count); " + f"got {len(op.scalar_args)}" + ) + lhs_row_offset_raw = op.scalar_args[0] + rhs_col_offset_raw = op.scalar_args[1] + dst_col_offset_raw = op.scalar_args[2] + col_count_raw = op.scalar_args[3] + # lhs_row_offset can be either a compile-time int (literal / IntImm) + # or a dynamic PrimExpr (e.g. `h * mlen * mlen` from a TIR loop). + # Static case: fold into the lhs_vram_addr literal. + # Dynamic case: materialize `lhs.address + offset` to a register and + # pass that as lhs_vram_addr_reg. + lhs_addr_m = None + if isinstance(lhs_row_offset_raw, tir.IntImm): + lhs_row_offset = int(lhs_row_offset_raw.value) + elif isinstance(lhs_row_offset_raw, int): + lhs_row_offset = int(lhs_row_offset_raw) + elif isinstance(lhs_row_offset_raw, tir.PrimExpr): + lhs_row_offset = None + full_addr_expr = tir.Add( + tir.IntImm("int32", int(lhs.address)), + lhs_row_offset_raw, + ) + lhs_addr_m = self.materializer.materialize(full_addr_expr) + self.shim.compiler.generated_code += lhs_addr_m.isa + else: + raise IsaEmissionError( + f"plena.mm_slot lhs_row_offset must be int or PrimExpr; " + f"got {type(lhs_row_offset_raw).__name__}: {lhs_row_offset_raw!r}" + ) + if lhs_row_offset is not None and lhs_row_offset < 0: + raise IsaEmissionError( + f"plena.mm_slot lhs_row_offset must be >= 0; got {lhs_row_offset}" + ) + if isinstance(rhs_col_offset_raw, tir.PrimExpr) and not isinstance(rhs_col_offset_raw, tir.IntImm): + rhs_col_offset = None + rhs_off_m = self.materializer.materialize(rhs_col_offset_raw) + self.shim.compiler.generated_code += rhs_off_m.isa + else: + rhs_col_offset = int(rhs_col_offset_raw) + rhs_off_m = None + if isinstance(dst_col_offset_raw, tir.PrimExpr) and not isinstance(dst_col_offset_raw, tir.IntImm): + dst_col_offset = None + dst_off_m = self.materializer.materialize(dst_col_offset_raw) + self.shim.compiler.generated_code += dst_off_m.isa + else: + dst_col_offset = int(dst_col_offset_raw) + dst_off_m = None + try: + col_count = int(col_count_raw) + except TypeError as exc: + raise IsaEmissionError( + f"plena.mm_slot col_count must be a compile-time integer; got " + f"{type(col_count_raw).__name__}: {col_count_raw!r}" + ) from exc + lhs_rows, lhs_cols = self._logical_2d(lhs.shape) + rhs_rows, rhs_cols = self._logical_2d(rhs.shape) + dst_rows, dst_cols = self._logical_2d(dst.shape) + # LHS must contain at least one mlen*mlen tile. For static offsets + # we can range-check at compile time; for dynamic offsets the kernel + # author is responsible for keeping the offset in range. + tile_elems = self.shim.mlen * self.shim.mlen + if lhs_row_offset is not None and lhs_row_offset + tile_elems > lhs.num_elements: + raise IsaEmissionError( + f"plena.mm_slot lhs tile out of range; " + f"lhs_row_offset={lhs_row_offset} + mlen*mlen={tile_elems} " + f"exceeds buffer {lhs.name} num_elements={lhs.num_elements}" + ) + if rhs_rows != self.shim.mlen or dst_rows != self.shim.mlen: + raise IsaEmissionError( + f"plena.mm_slot rhs/dst must have mlen rows; got rhs=({rhs_rows}, {rhs_cols}) " + f"dst=({dst_rows}, {dst_cols})" + ) + rhs_col_offset_check = 0 if rhs_col_offset is None else rhs_col_offset + dst_col_offset_check = 0 if dst_col_offset is None else dst_col_offset + if rhs_col_offset_check < 0 or dst_col_offset_check < 0 or col_count <= 0: + raise IsaEmissionError( + f"plena.mm_slot requires non-negative offsets and positive col_count; got " + f"rhs_col_offset={rhs_col_offset_raw} dst_col_offset={dst_col_offset_raw} col_count={col_count}" + ) + if rhs_col_offset is not None and rhs_col_offset + col_count > rhs_cols: + raise IsaEmissionError( + f"plena.mm_slot rhs slot exceeds rhs width; rhs_width={rhs_cols} " + f"rhs_col_offset={rhs_col_offset} col_count={col_count}" + ) + if dst_col_offset is not None and dst_col_offset + col_count > dst_cols: + raise IsaEmissionError( + f"plena.mm_slot dst slot exceeds dst width; dst_width={dst_cols} " + f"dst_col_offset={dst_col_offset} col_count={col_count}" + ) + try: + if lhs_addr_m is not None: + lhs_vram_addr_arg = 0 # ignored when reg form is used + lhs_vram_addr_reg = lhs_addr_m.register + else: + lhs_vram_addr_arg = lhs.address + lhs_row_offset + lhs_vram_addr_reg = None + self.emitter.emit_slot_matmul( + lhs_vram_addr=lhs_vram_addr_arg, + lhs_vram_addr_reg=lhs_vram_addr_reg, + rhs_mram_addr=rhs.address, + rhs_col_offset=0 if rhs_col_offset is None else rhs_col_offset, + rhs_col_offset_reg=None if rhs_off_m is None else rhs_off_m.register, + dst_vram_addr=dst.address, + dst_col_offset=0 if dst_col_offset is None else dst_col_offset, + dst_col_offset_reg=None if dst_off_m is None else dst_off_m.register, + col_count=col_count, + task_id=op.annotations.get("intrinsic", "mm_slot"), + ) + finally: + if rhs_off_m is not None: + rhs_off_m.release() + if dst_off_m is not None: + dst_off_m.release() + if lhs_addr_m is not None: + lhs_addr_m.release() + + def _emit_v_zero(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Region-based zero-fill on VRAM: ``dst[region] = 0``. + + Schema (region layer): + buffer_args = [dst_region] (VramRegion with 4D BSHD) + scalar_args = [] + + Lowers to ``V_MUL_VF dst, dst, f0, 0`` per mlen-wide chunk + (f0 == 0 by convention). + """ + if len(op.buffer_args) != 1: + raise IsaEmissionError( + f"v_zero expects 1 buffer_arg (dst region); " + f"got {len(op.buffer_args)}" + ) + if not isinstance(op.buffer_args[0], _hlir.VramRegion): + raise IsaEmissionError( + f"v_zero dst: expected VramRegion, got " + f"{type(op.buffer_args[0]).__name__}" + ) + if op.scalar_args: + raise IsaEmissionError( + f"v_zero expects 0 scalar_args; got {len(op.scalar_args)}" + ) + dst_region: _hlir.VramRegion = op.buffer_args[0] + dst = mod.get_buffer(dst_region.parent) + _check_scope(dst, _scope.VRAM, op.kind, "dst") + self.shim.compiler.generated_code += ( + f"; v_zero dst.parent={dst_region.parent} " + f"starts={list(dst_region.starts)!r} " + f"extents={list(dst_region.extents)!r}\n" + ) + for d_off, _ in self._vram_region_iter_chunks(dst, dst_region): + dst_addr = tir.Add( + tir.IntImm("int32", int(dst.address)), d_off, + ) + m_dst = self.materializer.materialize(dst_addr) + self.shim.compiler.generated_code += m_dst.isa + try: + self.shim.compiler.generated_code += ( + f"V_MUL_VF gp{m_dst.register}, gp{m_dst.register}, " + f"f0, 0\n" + ) + finally: + m_dst.release() + + def _emit_v_binary(self, mod: _hlir.HLIRModule, op: _hlir.Op, + *, binary_op: str) -> None: + """Region-based vector binary op: + ``dst[region] = lhs[region] rhs[region]`` elementwise. + + Schema (region layer): + buffer_args = [lhs_region, rhs_region, dst_region] + each is a ``VramRegion(parent, starts, extents)`` with + ``starts`` / ``extents`` length 4 in canonical BSHD + order. Every (b, s, h, d) cell within ``extents`` of + the three regions is paired up for the elementwise + op. The three regions must agree on ``extents``. + scalar_args = [] (region carries everything) + + Emission walks each region with ``_vram_region_iter_chunks``, + which folds the cluster (packed-head) axis and unrolls the + d_tile axis automatically. One HLIR op may therefore emit + N V_*_VV instructions (N = product of non-cluster outer + extents × d_chunks). + """ + op_to_insn = { + "add": "V_ADD_VV", + "sub": "V_SUB_VV", + "mul": "V_MUL_VV", + } + opcode = op_to_insn[binary_op] + if len(op.buffer_args) != 3: + raise IsaEmissionError( + f"{op.kind} expects 3 buffer_args (lhs/rhs/dst regions); " + f"got {len(op.buffer_args)}" + ) + for slot, name in enumerate(("lhs", "rhs", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise IsaEmissionError( + f"{op.kind} {name}: expected VramRegion, got " + f"{type(op.buffer_args[slot]).__name__}" + ) + if op.scalar_args: + raise IsaEmissionError( + f"{op.kind} expects 0 scalar_args (region carries shape); " + f"got {len(op.scalar_args)}" + ) + lhs_region: _hlir.VramRegion = op.buffer_args[0] + rhs_region: _hlir.VramRegion = op.buffer_args[1] + dst_region: _hlir.VramRegion = op.buffer_args[2] + lhs = mod.get_buffer(lhs_region.parent) + rhs = mod.get_buffer(rhs_region.parent) + dst = mod.get_buffer(dst_region.parent) + _check_scope(lhs, _scope.VRAM, op.kind, "lhs") + _check_scope(rhs, _scope.VRAM, op.kind, "rhs") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + if (tuple(lhs_region.extents) != tuple(dst_region.extents) + or tuple(rhs_region.extents) != tuple(dst_region.extents)): + raise IsaEmissionError( + f"{op.kind}: lhs/rhs/dst region extents must match; " + f"lhs={tuple(lhs_region.extents)} " + f"rhs={tuple(rhs_region.extents)} " + f"dst={tuple(dst_region.extents)}" + ) + + self.shim.compiler.generated_code += ( + f"; v binary {op.kind} {opcode} " + f"dst.parent={dst_region.parent} " + f"starts={list(dst_region.starts)!r} " + f"extents={list(dst_region.extents)!r}\n" + ) + + # Walk all three regions in lock-step. Each yield gives the + # ``(vram_offset_expr, fp_step_elems)`` for one mlen-wide + # chunk; we discard fp_step (no FPRAM here) and materialise + # the three per-operand absolute addresses. + lhs_iter = self._vram_region_iter_chunks(lhs, lhs_region) + rhs_iter = self._vram_region_iter_chunks(rhs, rhs_region) + dst_iter = self._vram_region_iter_chunks(dst, dst_region) + for (l_off, _), (r_off, _), (d_off, _) in zip( + lhs_iter, rhs_iter, dst_iter + ): + lhs_addr = tir.Add( + tir.IntImm("int32", int(lhs.address)), l_off, + ) + rhs_addr = tir.Add( + tir.IntImm("int32", int(rhs.address)), r_off, + ) + dst_addr = tir.Add( + tir.IntImm("int32", int(dst.address)), d_off, + ) + m_lhs = self.materializer.materialize(lhs_addr) + self.shim.compiler.generated_code += m_lhs.isa + m_rhs = self.materializer.materialize(rhs_addr) + self.shim.compiler.generated_code += m_rhs.isa + m_dst = self.materializer.materialize(dst_addr) + self.shim.compiler.generated_code += m_dst.isa + try: + self.shim.compiler.generated_code += ( + f"{opcode} gp{m_dst.register}, gp{m_lhs.register}, " + f"gp{m_rhs.register}, 0\n" + ) + finally: + m_dst.release() + m_rhs.release() + m_lhs.release() + return + + def _emit_v_add(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_v_binary(mod, op, binary_op="add") + + def _emit_v_sub(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_v_binary(mod, op, binary_op="sub") + + def _emit_v_mul(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_v_binary(mod, op, binary_op="mul") + + def _emit_v_unary(self, mod: _hlir.HLIRModule, op: _hlir.Op, + *, opcode: str) -> None: + """Region-based vector unary op: ``dst[region] = op(src[region])``. + + Schema (region layer): + buffer_args = [src_region, dst_region] + each is a ``VramRegion`` with 4D BSHD (starts, extents). + The two regions must agree on ``extents``. + scalar_args = [] + """ + if len(op.buffer_args) != 2: + raise IsaEmissionError( + f"{op.kind} expects 2 buffer_args (src/dst regions); " + f"got {len(op.buffer_args)}" + ) + for slot, name in enumerate(("src", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise IsaEmissionError( + f"{op.kind} {name}: expected VramRegion, got " + f"{type(op.buffer_args[slot]).__name__}" + ) + if op.scalar_args: + raise IsaEmissionError( + f"{op.kind} expects 0 scalar_args (region carries shape); " + f"got {len(op.scalar_args)}" + ) + src_region: _hlir.VramRegion = op.buffer_args[0] + dst_region: _hlir.VramRegion = op.buffer_args[1] + src = mod.get_buffer(src_region.parent) + dst = mod.get_buffer(dst_region.parent) + _check_scope(src, _scope.VRAM, op.kind, "src") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + if tuple(src_region.extents) != tuple(dst_region.extents): + raise IsaEmissionError( + f"{op.kind}: src/dst region extents must match; " + f"src={tuple(src_region.extents)} dst={tuple(dst_region.extents)}" + ) + + self.shim.compiler.generated_code += ( + f"; v unary {op.kind} {opcode} " + f"dst.parent={dst_region.parent} " + f"starts={list(dst_region.starts)!r} " + f"extents={list(dst_region.extents)!r}\n" + ) + src_iter = self._vram_region_iter_chunks(src, src_region) + dst_iter = self._vram_region_iter_chunks(dst, dst_region) + for (s_off, _), (d_off, _) in zip(src_iter, dst_iter): + src_addr = tir.Add( + tir.IntImm("int32", int(src.address)), s_off, + ) + dst_addr = tir.Add( + tir.IntImm("int32", int(dst.address)), d_off, + ) + m_src = self.materializer.materialize(src_addr) + self.shim.compiler.generated_code += m_src.isa + m_dst = self.materializer.materialize(dst_addr) + self.shim.compiler.generated_code += m_dst.isa + try: + self.shim.compiler.generated_code += ( + f"{opcode} gp{m_dst.register}, gp{m_src.register}, 0\n" + ) + finally: + m_dst.release() + m_src.release() + + def _emit_v_exp(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_v_unary(mod, op, opcode="V_EXP_V") + + def _emit_v_reci(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_v_unary(mod, op, opcode="V_RECI_V") + + def _emit_v_sqrt(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_v_unary(mod, op, opcode="V_SQRT_V") + + def _emit_fp_copy_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_fp_scalar_op_at(mod, op, kernel_op="copy") + + def _emit_fp_zero_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Store FP zero to one FPRAM slot via ``S_ST_FP f0, gp{dst}, 0``. + + Relies on the same ``f0 == 0`` convention plena.tile_zero and + plena.copy_v_to_v already depend on. Single scalar arg = the + FPRAM destination address (allowed to be a PrimExpr — the + materialiser folds in the fragment's allocated FPRAM base).""" + if len(op.scalar_args) != 1: + raise IsaEmissionError( + f"{op.kind} expects 1 scalar address arg, got {len(op.scalar_args)}" + ) + dst_addr_expr = self._resolve_fp_scalar_addr_arg( + mod, op.scalar_args[0], op.kind, "dst", + ) + m_dst = self.materializer.materialize(dst_addr_expr) + self.shim.compiler.generated_code += m_dst.isa + try: + lines = [ + f"; fp scalar task {op.annotations.get('intrinsic', op.kind)} op=zero", + f"S_ST_FP f0, gp{m_dst.register}, 0", + ] + self.shim.compiler.generated_code += "\n".join(lines) + "\n" + finally: + m_dst.release() + + def _emit_fp_add_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_fp_scalar_op_at(mod, op, kernel_op="add") + + def _emit_fp_sub_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_fp_scalar_op_at(mod, op, kernel_op="sub") + + def _emit_fp_mul_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_fp_scalar_op_at(mod, op, kernel_op="mul") + + def _emit_fp_max_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_fp_scalar_op_at(mod, op, kernel_op="max") + + def _emit_fp_exp_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_fp_scalar_op_at(mod, op, kernel_op="exp") + + def _emit_fp_reci_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_fp_scalar_op_at(mod, op, kernel_op="reci") + + def _emit_fp_sqrt_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_fp_scalar_op_at(mod, op, kernel_op="sqrt") + + # `_at` row ops: scalars are (FP scalar address, dim2, dim3) for the + # variants that touch fpram, or just (dim2, dim3) for exp. The emitter + # maps (dim2, dim3) to a physical VRAM row and synthesizes a V_MASK + # for narrow packed D tiles. + def _emit_row_reduce_max_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_row_scalar_op_at( + mod, op, row_op="reduce_max", reduce=True, masked=True, + ) + + def _emit_row_reduce_sum_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_row_scalar_op_at( + mod, op, row_op="reduce_sum", reduce=True, masked=True, + ) + + # Single-row VRAM × FPRAM-scalar ops. One HLIR op = one HW + # instruction. Multi-row callers wrap in outer ``for row``. + def _emit_row_exp(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_row_scalar_op_at(mod, op, row_op="exp", masked=True) + + def _emit_row_sub_fp(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_row_scalar_op_at(mod, op, row_op="sub", masked=True, has_fp=True) + + def _emit_row_mul_fp(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_row_scalar_op_at(mod, op, row_op="mul", masked=True, has_fp=True) + + def _emit_row_add_fp(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + self._emit_row_scalar_op_at(mod, op, row_op="add", masked=True, has_fp=True) + + # ------------------------------------------------------------------ + # Slice-level VRAM <-> FPRAM transfer. HLIR carries the whole logical + # region (VramRegion: starts + extents on the parent buffer); this + # emitter splits it into HW-MLEN-wide chunks, computes each chunk's + # physical VRAM offset via the parent's 7D tile layout, and emits + # one S_MAP_*_FP/V per chunk. + # + # The parent is always a 4D ``(B, S, H, D)`` BSHD buffer (the + # pad-to-4D step in to_plena guarantees rank==4 on every VRAM/MRAM + # buffer). Its physical placement in VRAM is the 7D tile layout + # described in ``hlir.TileLayout``: + # + # (D_TILES, S_TILES, H_GROUPS, B, MLEN, LANE_COUNT, D_INNER) + # + # A logical position ``(b, s, h, d)`` decomposes as: + # d_tile = d // MLEN d_inner_off = d % MLEN + # s_tile = s // MLEN s_inner_off = s % MLEN + # h_grp = h // LANE_COUNT lane = h % LANE_COUNT + # ------------------------------------------------------------------ + def _vram_region_iter_chunks( + self, + parent: _hlir.Buffer, + region: _hlir.VramRegion, + ): + """Yield ``(vram_offset_expr, fp_step_elems)`` for each HW-MLEN + chunk inside ``region``. ``fp_step_elems`` is the cumulative + element count consumed by all chunks so far — callers add it + to the base fp address. + + Region semantics (post pad-to-4D): every parent is rank 4 + BSHD; ``starts`` / ``extents`` are 4-tuples. The region's + last-axis extent (``ed``) drives the chunking — one S_MAP per + ``D_INNER`` slots along D. The (b, s, h, d_tile) outer + coordinates are walked once each; the chunk's physical VRAM + offset is the 7D inner-tile address. + """ + starts = region.starts + extents = region.extents + if len(parent.shape) != 4: + raise IsaEmissionError( + f"VramRegion(parent={region.parent!r}) expects 4D BSHD " + f"parent; got shape {tuple(parent.shape)}. pad-to-4D " + f"in to_plena should have normalised this." + ) + if len(starts) != 4 or len(extents) != 4: + raise IsaEmissionError( + f"VramRegion(parent={region.parent!r}) rank mismatch: " + f"starts={tuple(starts)} extents={tuple(extents)}; " + f"both must be 4-tuples" + ) + # Row-major-flat path: parent has no tile_layout AND no + # cluster_dim. This covers (a) author-pinned global.vram / + # global.mram tensor caches (testbench-loaded contiguous, + # not 7D-tile-padded) and (b) any small buffer that fits a + # single tile (logical extent ≤ mlen on every dim). Each + # mlen-wide region chunk maps directly to a flat row-major + # slice. Buffers with a non-None tile_layout keep the 7D path + # below — their physical layout walks mlen-row tiles. + cluster_dim_pre = getattr(parent, "cluster_dim", None) + if cluster_dim_pre is None and parent.tile_layout is None: + mlen = self.shim.mlen + shape = [int(d) for d in parent.shape] + # Row-major strides on the BSHD shape (rank-4). + row_strides = [1] * 4 + for i in range(2, -1, -1): + row_strides[i] = row_strides[i + 1] * shape[i + 1] + eb, es, eh, ed = (int(x) for x in extents) + total_elems = eb * es * eh * ed + if total_elems % mlen != 0: + raise IsaEmissionError( + f"VramRegion(parent={region.parent!r}, cluster-less): " + f"total region elems={total_elems} not a multiple of " + f"MLEN={mlen}" + ) + chunks = total_elems // mlen + + def _start_plus_simple(axis: int): + s = starts[axis] + if isinstance(s, int): + return tir.IntImm("int32", int(s)) + return s + + def _mul_s(expr, k: int): + if k == 0: + return tir.IntImm("int32", 0) + if k == 1: + return expr + if isinstance(expr, tir.IntImm): + return tir.IntImm("int32", int(expr.value) * k) + return tir.Mul(expr, tir.IntImm("int32", int(k))) + + def _sum_s(terms): + nz = [t for t in terms if not (isinstance(t, tir.IntImm) and int(t.value) == 0)] + if not nz: + return tir.IntImm("int32", 0) + acc = nz[0] + for t in nz[1:]: + acc = tir.Add(acc, t) + return acc + + base_off = _sum_s([ + _mul_s(_start_plus_simple(0), row_strides[0]), + _mul_s(_start_plus_simple(1), row_strides[1]), + _mul_s(_start_plus_simple(2), row_strides[2]), + _start_plus_simple(3), + ]) + fp_elems = 0 + for c in range(chunks): + if c == 0: + yield base_off, fp_elems + else: + yield tir.Add(base_off, tir.IntImm("int32", c * mlen)), fp_elems + fp_elems += mlen + return + + tl = parent.tile_layout + if tl is None: + # ``make_tile_layout`` returns None for buffers that fit a + # single inner tile (s ≤ mlen ∧ d ≤ mlen on BSHD). Synthesise + # the trivial 1×1×1×1 layout so the offset math below works + # uniformly without a separate code path. + b_sz, s_sz, h_sz, d_sz = (int(x) for x in parent.shape) + tl = _hlir.TileLayout( + logical_b=b_sz, logical_s=s_sz, logical_h=h_sz, logical_d=d_sz, + d_tiles=1, s_tiles=1, h_groups=1, + mlen=self.shim.mlen, lane_count=1, + d_inner=d_sz if d_sz > 0 else self.shim.mlen, + ) + + eb, es, eh, ed = (int(x) for x in extents) + if ed % tl.d_inner != 0: + raise IsaEmissionError( + f"VramRegion(parent={region.parent!r}): innermost extent " + f"ed={ed} not a multiple of D_INNER={tl.d_inner}" + ) + d_chunks = ed // tl.d_inner + + # Lane axis is sync-wrap-folded by mid_ir: a single + # ``S_MAP_*_FP/V`` instruction covers every lane in one issue, + # so the emitter must NOT iterate the lane axis (doing so would + # re-issue the same multi-lane instruction lane_count times at + # offsets that no longer align to mlen). Assert the region + # covers the full lane span on whatever axis the parent's + # ``cluster_dim`` marks, then fold that axis out of the walk. + cluster_dim = getattr(parent, "cluster_dim", None) + outer_iter = {"b": eb, "s": es, "h": eh} + if cluster_dim is not None: + # BSHD positions: 0=B, 1=S, 2=H, 3=D. Lane never lands at D. + lane_key = {0: "b", 1: "s", 2: "h"}.get(cluster_dim) + if lane_key is None: + raise IsaEmissionError( + f"VramRegion(parent={region.parent!r}): cluster_dim " + f"={cluster_dim} is not a recognised BSHD lane " + f"position (0=B / 1=S / 2=H)" + ) + lane_span = int(parent.shape[cluster_dim]) + lane_ext = outer_iter[lane_key] + if lane_ext != lane_span: + raise IsaEmissionError( + f"VramRegion(parent={region.parent!r}): lane axis " + f"({lane_key.upper()}, cluster_dim={cluster_dim}) " + f"must cover the full lane span " + f"({lane_span}) under sync wrap; got extent " + f"{lane_ext}" + ) + # Fold lane axis out — one S_MAP per (b, s, h_grp, d_chunk) + # except along the lane direction itself. + outer_iter[lane_key] = 1 + + # 7D physical strides (in elements). One inner tile holds + # MLEN * LANE_COUNT * D_INNER contiguous values; outer tiles + # walk (B, H_GROUPS, S_TILES, D_TILES) with the standard 7D + # stride pattern. + tile_elems = tl.mlen * tl.lane_count * tl.d_inner + b_stride = tile_elems + h_grp_stride = tl.logical_b * tile_elems + s_tile_stride = tl.h_groups * h_grp_stride + d_tile_stride = tl.s_tiles * s_tile_stride + + # Helper: render ``starts[axis]`` + extra as a PrimExpr. + def _start_plus(axis: int, extra: int): + s = starts[axis] + if isinstance(s, int): + v = s + extra + return tir.IntImm("int32", int(v)) + if extra == 0: + return s + return tir.Add(s, tir.IntImm("int32", int(extra))) + + def _floordiv(expr, divisor: int): + if divisor == 1: + return expr + if isinstance(expr, tir.IntImm): + return tir.IntImm("int32", int(expr.value) // divisor) + return tir.FloorDiv(expr, tir.IntImm("int32", int(divisor))) + + def _floormod(expr, divisor: int): + if divisor == 1: + return tir.IntImm("int32", 0) + if isinstance(expr, tir.IntImm): + return tir.IntImm("int32", int(expr.value) % divisor) + return tir.FloorMod(expr, tir.IntImm("int32", int(divisor))) + + def _mul(expr, k: int): + if k == 0: + return tir.IntImm("int32", 0) + if k == 1: + return expr + if isinstance(expr, tir.IntImm): + return tir.IntImm("int32", int(expr.value) * k) + return tir.Mul(expr, tir.IntImm("int32", int(k))) + + def _sum(terms): + non_zero = [ + t for t in terms + if not (isinstance(t, tir.IntImm) and int(t.value) == 0) + ] + if not non_zero: + return tir.IntImm("int32", 0) + acc = non_zero[0] + for t in non_zero[1:]: + acc = tir.Add(acc, t) + return acc + + fp_elems_so_far = 0 + # Cartesian walk over (b_off, h_off, s_off, d_chunk). The lane + # axis among (B, S, H) was folded to extent 1 above so a + # single S_MAP per chunk covers every lane. + for d_chunk in range(d_chunks): + for s_off in range(outer_iter["s"]): + for h_off in range(outer_iter["h"]): + for b_off in range(outer_iter["b"]): + b_expr = _start_plus(0, b_off) + s_expr = _start_plus(1, s_off) + h_expr = _start_plus(2, h_off) + # d_start covers the current D_INNER-wide slot + # within the region; it indexes into parent's D. + d_expr = _start_plus(3, d_chunk * tl.d_inner) + + d_tile = _floordiv(d_expr, tl.mlen) + d_inner = _floormod(d_expr, tl.mlen) + s_tile = _floordiv(s_expr, tl.mlen) + s_inner = _floormod(s_expr, tl.mlen) + h_grp = _floordiv(h_expr, tl.lane_count) + lane = _floormod(h_expr, tl.lane_count) + + # 7D physical flat offset. + terms = [ + _mul(d_tile, d_tile_stride), + _mul(s_tile, s_tile_stride), + _mul(h_grp, h_grp_stride), + _mul(b_expr, b_stride), + _mul(s_inner, tl.lane_count * tl.d_inner), + _mul(lane, tl.d_inner), + d_inner, + ] + vram_off = _sum(terms) + yield vram_off, fp_elems_so_far + # One S_MAP transfers ``lane_count * d_inner`` + # (= MLEN) contiguous FPRAM slots — the whole + # lane group's slice in one issue. + fp_elems_so_far += tl.lane_count * tl.d_inner + + def _emit_v_fp_transfer_slice( + self, + mod: _hlir.HLIRModule, + op: _hlir.Op, + *, + direction: str, # "v_to_fp" or "fp_to_v" + ) -> None: + if len(op.buffer_args) != 1 or not isinstance(op.buffer_args[0], _hlir.VramRegion): + raise IsaEmissionError( + f"{op.kind}: buffer_args[0] must be VramRegion; " + f"got {op.buffer_args!r}" + ) + if len(op.scalar_args) != 1: + raise IsaEmissionError( + f"{op.kind}: expected 1 scalar arg (fp_addr); " + f"got {len(op.scalar_args)}" + ) + region: _hlir.VramRegion = op.buffer_args[0] + vram = mod.get_buffer(region.parent) + _check_scope(vram, _scope.VRAM, op.kind, "vram") + + fp_addr_base = self._resolve_fp_scalar_addr_arg( + mod, op.scalar_args[0], op.kind, "fp", + ) + opcode = "S_MAP_FP_V" if direction == "v_to_fp" else "S_MAP_V_FP" + + self.shim.compiler.generated_code += ( + f"; v↔fp transfer slice {op.kind} parent={region.parent} " + f"starts={list(region.starts)!r} extents={list(region.extents)!r}\n" + ) + + for vram_off_expr, fp_step in self._vram_region_iter_chunks(vram, region): + vram_addr_expr = tir.Add( + tir.IntImm("int32", int(vram.address)), + vram_off_expr, + ) + fp_chunk_addr = ( + fp_addr_base if fp_step == 0 + else tir.Add(fp_addr_base, tir.IntImm("int32", int(fp_step))) + ) + m_vram = self.materializer.materialize(vram_addr_expr) + self.shim.compiler.generated_code += m_vram.isa + m_fp = self.materializer.materialize(fp_chunk_addr) + self.shim.compiler.generated_code += m_fp.isa + try: + if direction == "v_to_fp": + self.shim.compiler.generated_code += ( + f"{opcode} gp{m_fp.register}, gp{m_vram.register}, 0\n" + ) + else: + self.shim.compiler.generated_code += ( + f"{opcode} gp{m_vram.register}, gp{m_fp.register}, 0\n" + ) + finally: + m_fp.release() + m_vram.release() + + def _emit_v_fp_transfer_slice_v_to_fp( + self, mod: _hlir.HLIRModule, op: _hlir.Op, + ) -> None: + self._emit_v_fp_transfer_slice(mod, op, direction="v_to_fp") + + def _emit_v_fp_transfer_slice_fp_to_v( + self, mod: _hlir.HLIRModule, op: _hlir.Op, + ) -> None: + self._emit_v_fp_transfer_slice(mod, op, direction="fp_to_v") + + def _emit_copy_v_to_v(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Region-based VRAM→VRAM copy: ``dst[region] = src[region]``. + + Schema (region layer): + buffer_args = [src_region, dst_region] (VramRegion 4D BSHD) + scalar_args = [] + + Each mlen-wide chunk emits one ``V_ADD_VF dst, src, f0, 0`` — + f0 == 0 by convention so ``src + 0`` is just src. + """ + if len(op.buffer_args) != 2: + raise IsaEmissionError( + f"copy_v_to_v expects 2 buffer_args (src/dst regions); " + f"got {len(op.buffer_args)}" + ) + for slot, name in enumerate(("src", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise IsaEmissionError( + f"copy_v_to_v {name}: expected VramRegion, got " + f"{type(op.buffer_args[slot]).__name__}" + ) + if op.scalar_args: + raise IsaEmissionError( + f"copy_v_to_v expects 0 scalar_args; " + f"got {len(op.scalar_args)}" + ) + src_region: _hlir.VramRegion = op.buffer_args[0] + dst_region: _hlir.VramRegion = op.buffer_args[1] + src = mod.get_buffer(src_region.parent) + dst = mod.get_buffer(dst_region.parent) + _check_scope(src, _scope.VRAM, op.kind, "src") + _check_scope(dst, _scope.VRAM, op.kind, "dst") + if tuple(src_region.extents) != tuple(dst_region.extents): + raise IsaEmissionError( + f"copy_v_to_v: src/dst region extents must match; " + f"src={tuple(src_region.extents)} " + f"dst={tuple(dst_region.extents)}" + ) + self.shim.compiler.generated_code += ( + f"; copy_v_to_v src.parent={src_region.parent} -> " + f"dst.parent={dst_region.parent} " + f"extents={list(dst_region.extents)!r}\n" + ) + src_iter = self._vram_region_iter_chunks(src, src_region) + dst_iter = self._vram_region_iter_chunks(dst, dst_region) + for (s_off, _), (d_off, _) in zip(src_iter, dst_iter): + src_addr = tir.Add( + tir.IntImm("int32", int(src.address)), s_off, + ) + dst_addr = tir.Add( + tir.IntImm("int32", int(dst.address)), d_off, + ) + m_src = self.materializer.materialize(src_addr) + self.shim.compiler.generated_code += m_src.isa + m_dst = self.materializer.materialize(dst_addr) + self.shim.compiler.generated_code += m_dst.isa + try: + self.shim.compiler.generated_code += ( + f"V_ADD_VF gp{m_dst.register}, gp{m_src.register}, " + f"f0, 0\n" + ) + finally: + m_dst.release() + m_src.release() + + # ------------------------------------------------------------------ + # Structured ops: For + # ------------------------------------------------------------------ + def _emit_for(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Emit `C_LOOP_START / body / inc / C_LOOP_END` for a structured + For op. + + PLENA's hardware loop is: + C_LOOP_START gp_loop, IMM_count + ... + C_LOOP_END gp_loop + where IMM_count is a literal iteration count and gp_loop is an + internal counter the hardware decrements -- it is NOT the + iteration index. So we need TWO registers: + * `gp_loop` -- hardware counter, opaque to body + * `gp_idx` -- body-visible iteration variable, bound to + the TIR loop_var via symbol_table; manually + initialised before C_LOOP_START and + incremented by 1 at the end of every + iteration. + + Constraints: + * extent must be a Python int / IntImm (immediate field of + C_LOOP_START is a literal). PrimExpr extents are not + supported -- they would need a compile-time evaluation + pass or a different lowering (no native loop-with-runtime- + count instruction in PLENA's ISA). + * init must be int (typically 0). PrimExpr inits are + unsupported for the same reason: would force runtime + loop-bound recomputation. + """ + loop_var = op.annotations.get("loop_var") + extent = op.annotations.get("extent") + init = op.annotations.get("init", 0) + if loop_var is None or extent is None: + raise IsaEmissionError( + f"for-op missing loop_var or extent annotation: {op!r}" + ) + if not isinstance(extent, (int, tir.IntImm)): + raise IsaEmissionError( + f"for-op extent must be a compile-time integer (PLENA's " + f"C_LOOP_START takes an immediate). Got {type(extent).__name__}: " + f"{extent!r}. Restructure the kernel so the loop bound is known " + f"at TIR-construction time." + ) + if not isinstance(init, (int, tir.IntImm)): + raise IsaEmissionError( + f"for-op init must be a compile-time integer. Got " + f"{type(init).__name__}: {init!r}." + ) + extent_imm = int(extent.value) if isinstance(extent, tir.IntImm) else int(extent) + init_imm = int(init.value) if isinstance(init, tir.IntImm) else int(init) + if loop_var in self.symbol_table: + raise IsaEmissionError( + f"loop_var {loop_var.name!r} (id={id(loop_var)}) already " + f"bound; nested loops reusing the same Var aren't supported. " + f"Active bindings: " + f"{[(v.name, id(v)) for v in self.symbol_table]!r}" + ) + + ra = self.shim.compiler.register_allocator + loop_kind = op.annotations.get("loop_kind", "serial") + + # Compile-time unroll: emit the body N times back-to-back with + # loop_var rebound to a literal each iteration. Use this to break + # out of MAX_LOOP_INSTRUCTIONS-per-iter when one outer iteration's + # body would otherwise dispatch too many dynamic instructions + # (e.g. an inner kv_block accumulation containing a 16x16 unrolled + # emit_matmul). Costs one S_ADDI_INT per iter to re-init gp_idx; + # the hardware loop overhead disappears entirely. + if loop_kind in ("unroll", "unrolled"): + # Unrolled: the loop variable takes a known constant value in + # each iteration, so bind it to a tir.IntImm rather than a + # GP. The materializer constant-folds every use — no GP is + # pinned, no per-iter ``S_ADDI_INT`` is emitted. This is what + # keeps a deep unrolled nest from exhausting the GP file. + self.shim.compiler.generated_code += ( + f"; unroll for {loop_var.name} in " + f"[{init_imm}, {init_imm + extent_imm}) -- idx is a literal\n" + ) + # All N unrolled iterations share the SAME body [NN] indices + # (format_hlir prints the body once). Snapshot the counter + # before the first iter and rewind to it at the start of + # every iter; let the last iter leave it at the body end so + # sibling ops after the for keep numbering correctly. + body_start_idx = self._lowir_idx + try: + for i in range(extent_imm): + iter_val = init_imm + i + self.symbol_table[loop_var] = tir.IntImm("int32", iter_val) + self.shim.compiler.generated_code += ( + f"; ... unroll iter {i} -> {loop_var.name}={iter_val}\n" + ) + self._lowir_idx = body_start_idx + for j, sub_op in enumerate(op.body or []): + handler = self._dispatch.get(sub_op.kind) + if handler is None: + raise IsaEmissionError( + f"no ISA dispatcher for nested op kind " + f"{sub_op.kind!r} inside unrolled for-loop" + ) + self._lowir_idx += 1 + self.materializer.set_lowir_op_idx(self._lowir_idx) + ra.push_site(f"unroll[{i}].body[{j}] {sub_op.kind}") + self.materializer.begin_op() + try: + handler(mod, sub_op) + finally: + self.materializer.end_op() + ra.pop_site() + finally: + self.symbol_table.pop(loop_var, None) + return + + # gp_loop is the PLENA hw counter — C_LOOP_END decrements it. + # The loop-register allocation pass picked the GP by HLIR + # liveness and stamped it on the op; that GP is in the emit + # allocator's ``gp_reserved`` set, so it is physically disjoint + # from any temporary. We pin it here only as a defensive marker. + gp_loop = op.annotations.get("loop_gp") + if gp_loop is None: + raise IsaEmissionError( + f"serial for-op {loop_var.name!r} has no 'loop_gp' " + f"annotation; loop_register_alloc must run before isa_pass" + ) + ra.pin_gp(gp_loop) + + # idx lives in IntRAM, not a GP. Deep nests (flash_attention with + # an inner matmul, conv2d's 6-level grid) would exhaust the GP + # file if every loop pinned two GPs. Storing the idx in IntRAM + # keeps it to 1 GP per loop — the materializer re-loads the idx + # on every use via S_LD_INT. (Re-load cost to be optimised later + # with a per-op materialisation cache.) + idx_addr = ra.claim_idx_slot() + if init_imm == 0: + # gp0 is constant zero — store it straight to the idx slot. + self.shim.compiler.generated_code += ( + f"; for {loop_var.name} in [{init_imm}, {init_imm + extent_imm}) " + f"-- hw counter gp{gp_loop}, idx ram[{idx_addr}]\n" + f"S_ST_INT gp0, gp0, {idx_addr}\n" + f"C_LOOP_START gp{gp_loop}, {extent_imm}\n" + ) + else: + init_gp = ra.allocate_gp(1)[0] + self.shim.compiler.generated_code += ( + f"; for {loop_var.name} in [{init_imm}, {init_imm + extent_imm}) " + f"-- hw counter gp{gp_loop}, idx ram[{idx_addr}]\n" + f"S_ADDI_INT gp{init_gp}, gp0, {init_imm}\n" + f"S_ST_INT gp{init_gp}, gp0, {idx_addr}\n" + f"C_LOOP_START gp{gp_loop}, {extent_imm}\n" + ) + ra.free_gp([init_gp]) + + self.symbol_table[loop_var] = ("ram", idx_addr) + try: + for j, sub_op in enumerate(op.body or []): + handler = self._dispatch.get(sub_op.kind) + if handler is None: + raise IsaEmissionError( + f"no ISA dispatcher for nested op kind {sub_op.kind!r} " + f"inside for-loop" + ) + # Depth-first: each body op gets the next [NN] index, + # right after the enclosing for. Nested fors recurse and + # advance the counter further, matching _format_ops. + self._lowir_idx += 1 + self.materializer.set_lowir_op_idx(self._lowir_idx) + ra.push_site(f"for[{loop_var.name}].body[{j}] {sub_op.kind}") + self.materializer.begin_op() + try: + handler(mod, sub_op) + finally: + self.materializer.end_op() + ra.pop_site() + + # idx += 1: load -> addi -> store. Borrow one GP for the + # round-trip. Inside the try so a body failure still hits the + # finally's GP / idx-slot release. + inc_gp = ra.allocate_gp(1)[0] + self.shim.compiler.generated_code += ( + f"; idx {loop_var.name} += 1 (ram[{idx_addr}])\n" + f"S_LD_INT gp{inc_gp}, gp0, {idx_addr}\n" + f"S_ADDI_INT gp{inc_gp}, gp{inc_gp}, 1\n" + f"S_ST_INT gp{inc_gp}, gp0, {idx_addr}\n" + f"C_LOOP_END gp{gp_loop}\n" + ) + ra.free_gp([inc_gp]) + finally: + # Release loop-owned state on EVERY exit path — including a + # body exception. ``gp_loop`` is a reserved GP (assigned by + # loop_register_alloc, never taken from the free pool), so we + # only unpin it — NOT free_gp, which would corrupt the pool. + self.symbol_table.pop(loop_var, None) + ra.unpin_gp(gp_loop) + ra.release_idx_slot(idx_addr) + + +def _check_scope(buf: _hlir.Buffer, expected: str, op_kind: str, role: str) -> None: + # `global.` is treated as `` for ISA-level scope checks — + # the user-declared global flag changes lane-fusion behaviour but the + # buffer's physical residency (and therefore the legal operand-scope + # rules for each instruction) is identical. Keep `buf.scope` as the + # original string so JSON dumps / debug output retain the global flag. + if _scope.physical_scope(buf.scope) != expected: + raise IsaEmissionError( + f"{op_kind} {role} buffer {buf.name!r} must be in scope {expected!r}, " + f"got {buf.scope!r}" + ) + + +__all__ = ["IsaEmitterPass", "IsaEmissionError"] diff --git a/tilelang_tvm_compiler/kernels/__init__.py b/tilelang_tvm_compiler/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tilelang_tvm_compiler/kernels/_head_layout.py b/tilelang_tvm_compiler/kernels/_head_layout.py new file mode 100644 index 0000000..c567ef0 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/_head_layout.py @@ -0,0 +1,62 @@ +"""HBM head-layout view helpers — BSHD <-> B,S,1,H*D. + +Both layouts share the *same* row-major fp16 bytes in HBM. The only +difference is how the next kernel's ``T.Tensor((...), ...)`` signature +declares the logical shape. Use these helpers when a producer kernel +emits BSHD but the consumer wants B,S,1,H*D (or vice versa). + +Pure ``torch.Tensor.view`` — zero copy, contiguous-preserving. +""" + +from __future__ import annotations + +import torch + + +def pack_heads(x_bshd: torch.Tensor) -> torch.Tensor: + """[B, S, H, D] -> [B, S, 1, H*D]. + + Same memory, just a different logical view. The producing kernel + wrote H*D contiguous fp16 elements per (B,S) row; the consuming + kernel declares them as a single "head" of width H*D. + """ + if x_bshd.dim() != 4: + raise ValueError(f"pack_heads expects 4D BSHD; got shape {tuple(x_bshd.shape)}") + if not x_bshd.is_contiguous(): + raise ValueError( + "pack_heads requires a contiguous tensor (view-only op). " + "Call .contiguous() upstream if the producer's output was permuted." + ) + B, S, H, D = x_bshd.shape + return x_bshd.view(B, S, 1, H * D) + + +def unpack_heads(x_packed: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor: + """[B, S, 1, H*D] -> [B, S, H, D]. + + Inverse of ``pack_heads``. ``num_heads * head_dim`` must equal the + last-dim of the packed tensor. + """ + if x_packed.dim() != 4: + raise ValueError( + f"unpack_heads expects 4D B,S,1,H*D; got shape {tuple(x_packed.shape)}" + ) + if x_packed.shape[2] != 1: + raise ValueError( + f"unpack_heads expects head-axis == 1 (packed); got shape " + f"{tuple(x_packed.shape)}" + ) + if not x_packed.is_contiguous(): + raise ValueError( + "unpack_heads requires a contiguous tensor (view-only op)." + ) + B, S, _, HD = x_packed.shape + if HD != num_heads * head_dim: + raise ValueError( + f"unpack_heads: packed last-dim {HD} != num_heads*head_dim " + f"{num_heads}*{head_dim} = {num_heads * head_dim}" + ) + return x_packed.view(B, S, num_heads, head_dim) + + +__all__ = ["pack_heads", "unpack_heads"] diff --git a/tilelang_tvm_compiler/kernels/concat_min.py b/tilelang_tvm_compiler/kernels/concat_min.py new file mode 100644 index 0000000..1e01457 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/concat_min.py @@ -0,0 +1,128 @@ +"""Concat-min kernel — feature-axis concatenation of two head-packed +tensors. + +Both inputs are taken in the HEAD-PACKED view ``[B, S, 1, dim]`` (see +_head_layout.py: BSHD ``[B,S,H,D]`` and packed ``[B,S,1,H*D]`` share +the same row-major fp16 bytes — heads folded into the feature axis). +Concatenating along that feature axis: + + Y[:, :, 0, 0:Adim] = A + Y[:, :, 0, Adim:Adim+Bdim] = B + +The copy is done in MLEN-wide blocks (the hardware tile granularity), +NOT per-head: each (B,S) row is ``dim`` contiguous fp16 elements = +``dim / MLEN`` blocks, and the kernel walks those blocks. A's blocks +land in Y's first ``Adim`` columns, B's in the next ``Bdim``. + +Plain VRAM->VRAM copy — no FPRAM, no compute. + +Why a dedicated kernel: the single-stream-block chain needs +``concat([attn_out, mlp_out])`` as linear2's input. Letting each +producer (flash_attention, gelu) write its OWN compact tensor and +joining them here keeps attention/gelu on their plain compact-output +path (no o_head_offset writeback) and keeps every step independently +verifiable. +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_concat_min( + *, + rows: int | None = None, + a_dim: int = 128, + b_dim: int = 128, + num_s_blocks: int = 2, + batch: int = 1, +): + """Feature-axis concat of two head-packed tensors. + + A_hbm : [batch, seq, 1, a_dim] + B_hbm : [batch, seq, 1, b_dim] + Y_hbm : [batch, seq, 1, a_dim + b_dim] + Y[..., 0, 0:a_dim] = A + Y[..., 0, a_dim:a_dim+b_dim] = B + + ``a_dim`` and ``b_dim`` must each be a multiple of MLEN (the copy + walks MLEN-wide blocks). Inputs are the head-packed [B,S,1,dim] + view; a BSHD producer output aliases this byte-for-byte. + """ + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError(f"concat_min requires rows == MLEN ({MLEN}), got {rows}") + if a_dim % MLEN != 0: + raise ValueError(f"a_dim must be a multiple of MLEN ({MLEN}); got {a_dim}") + if b_dim % MLEN != 0: + raise ValueError(f"b_dim must be a multiple of MLEN ({MLEN}); got {b_dim}") + if num_s_blocks < 1: + raise ValueError(f"num_s_blocks must be >= 1, got {num_s_blocks}") + + seq_len = num_s_blocks * rows + total_dim = a_dim + b_dim + a_blocks = a_dim // MLEN + b_blocks = b_dim // MLEN + + @T.prim_func + def concat_min( + A_hbm: T.Tensor((batch, seq_len, 1, a_dim), "float16"), + B_hbm: T.Tensor((batch, seq_len, 1, b_dim), "float16"), + Y_hbm: T.Tensor((batch, seq_len, 1, total_dim), "float16"), + ): + # Grid iterates seq blocks. Each program copies one (rows x dim) + # tile per source, MLEN-wide block at a time, in a single + # uniform body. A's blocks fill Y[..., 0:a_dim]; B's blocks fill + # Y[..., a_dim:total_dim]. All block offsets are compile-time + # constants — no grid-variable branch. + with T.Kernel(num_s_blocks, threads=128) as s_block: + A_sh = T.alloc_shared((rows, MLEN), "float16") + B_sh = T.alloc_shared((rows, MLEN), "float16") + + # A -> first a_dim columns of Y. + for blk in T.serial(a_blocks): + T.copy( + A_hbm[0, s_block * rows : (s_block + 1) * rows, + 0, blk * MLEN : (blk + 1) * MLEN], + A_sh, + ) + T.copy( + A_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, + 0, blk * MLEN : (blk + 1) * MLEN], + ) + + # B -> next b_dim columns of Y (shifted by a_dim). + for blk in T.serial(b_blocks): + T.copy( + B_hbm[0, s_block * rows : (s_block + 1) * rows, + 0, blk * MLEN : (blk + 1) * MLEN], + B_sh, + ) + T.copy( + B_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, + 0, a_dim + blk * MLEN : a_dim + (blk + 1) * MLEN], + ) + + lowered = concat_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "A_DIM": a_dim, + "B_DIM": b_dim, + "TOTAL_DIM": total_dim, + "A_BLOCKS": a_blocks, + "B_BLOCKS": b_blocks, + "BATCH": batch, + "NUM_S_BLOCKS": num_s_blocks, + } + return lowered, constants + + +__all__ = ["make_concat_min"] diff --git a/tilelang_tvm_compiler/kernels/conv2d_min.py b/tilelang_tvm_compiler/kernels/conv2d_min.py new file mode 100644 index 0000000..dc10b9a --- /dev/null +++ b/tilelang_tvm_compiler/kernels/conv2d_min.py @@ -0,0 +1,227 @@ +"""Conv2D-min — 2D convolution with optional multi-channel support. + +Shapes (NCHW): + + Input : (1, C_IN, H_PAD, W_PAD) pre-padded right/bottom for the kernel + Weight : (C_OUT, C_IN, KH, KW) flattened to FPRAM in (oc, ic, kh*KW+kw) order + Output : (1, C_OUT, H, W) same spatial as logical input + +For the simplest case (C_IN=C_OUT=1), this reduces to the original +single-channel kernel. + +Why no GEMM here: + + The natural per-output-row work is a (MLEN_w, C_IN*KH*KW) gather × + (C_IN*KH*KW,) weight = (MLEN_w,) row — a GEMV (matrix × vector). + PLENA has no GEMV instruction; the smallest matmul tile is + (MLEN, MLEN) × (MLEN, MLEN), which would be very sparse for typical + conv shapes (especially when C_IN*KH*KW < MLEN). Instead we lower + the whole thing to vector-scalar FMAs: + + for c_in: for kh: for kw: # C_IN*KH*KW unrolled + for m in T.Parallel(MLEN): # one HW vector op + C_loc[m] += in_shifted[m] * weight[oc, c_in, kh, kw] + + Each iter is one ``plena.v_mul + plena.v_add`` (or fused FMA) on a + 64-wide vector. + +Construction (per (oc, oh) output row): + + 1. **kw-shift via FPRAM padded fragment**: per (ic, kh), copy one + MLEN-wide input row into ``in_FP_padded`` (size MLEN+KW-1, last + KW-1 slots zero-init). For each kw, read shifted slice + ``in_FP_padded[m + kw_idx]`` into shift_FP, then back to VRAM + A_sh. + + 2. **Vector-scalar FMA**: ``A_sh *= B_FP[oc, ic, kh, kw]``, then + ``A_sh_acc += A_sh``. + + 3. **Writeback**: one ``T.copy`` writes A_sh_acc into + ``C_loc[0, oc, oh, :]``. After all (oc, oh) iterations, one + ``T.copy(C_loc, Output[0, 0, 0, 0])`` dumps everything to HBM. + +Constraints: + * KH * KW == HLEN (= 16) + * W == MLEN (single output W tile per row) + * C_OUT * C_IN * MLEN must fit in FPRAM (B_FP holds the entire + weight tensor MLEN-padded). For C_OUT=4, C_IN=4: 1024 elements. + * Output staging in __main__._emit_output_staging currently only + handles C_OUT == 1 cleanly (it walks logical 2D as rows × cols + with constant stride; for C_OUT > 1 NCHW the cross-channel stride + differs from the cross-row stride). Multi-C_OUT staging is + queued for a follow-up. + +Layout contract for ``B_cache`` (testbench-side preload): + B_cache[oc * C_IN + ic, k_tap] = Weight[oc, ic, kh, kw] + where k_tap = kh*KW + kw (same row-major weight ordering the + original single-channel kernel used; just extended by the + (oc, ic) outer pair). +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_conv2d_min( + *, + h_in: int = 64, + w_in: int = 64, + kh: int = 4, + kw: int = 4, + c_in: int = 1, + c_out: int = 1, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + HLEN = _hw.hlen + + if kh * kw != HLEN: + raise ValueError( + f"first-cut conv2d_min requires kh*kw == HLEN ({HLEN}); " + f"got kh={kh}, kw={kw}, kh*kw={kh*kw}" + ) + if w_in != MLEN: + raise ValueError( + f"first-cut conv2d_min requires w_in == MLEN ({MLEN}); got w_in={w_in}" + ) + if h_in <= 0: + raise ValueError(f"h_in must be positive; got h_in={h_in}") + if c_in <= 0 or c_out <= 0: + raise ValueError(f"c_in and c_out must be positive; got c_in={c_in}, c_out={c_out}") + + H = h_in + W = w_in + KH = kh + KW = kw + K_FLAT = KH * KW # = HLEN, the unrolled-1D tap count + C_IN = c_in + C_OUT = c_out + OC_IC = C_OUT * C_IN # number of (oc, ic) weight rows in B_cache + + def _round_up_to_mlen(x: int) -> int: + return (x + MLEN - 1) // MLEN * MLEN + H_PAD = _round_up_to_mlen(H + KH - 1) + W_PAD = _round_up_to_mlen(W + KW - 1) + + @T.prim_func + def conv2d_min( + Input: T.Tensor((1, C_IN, H_PAD, W_PAD), "float16"), + Output: T.Tensor((1, C_OUT, H, W), "float16"), + ): + T.func_attr({"plena.layout": "NCHW"}) + if False: + _ = (H_PAD, W_PAD, H, W, C_IN, C_OUT, OC_IC) + + with T.Kernel(1, threads=1) as _bx: + # Single-channel padded input tile, re-staged per ic. + in_stage = T.alloc_shared((H_PAD, W_PAD), "float16") + + # ``A_sh`` and ``A_sh_acc`` live in VRAM so the per-tap + # multiply + accumulate lower to vector instructions + # (``V_MUL_VF`` / ``V_ADD_VV``) instead of the per-element + # FPRAM scalar loop. The kw-shift chain stays in FPRAM — + # only the multiply and the accumulate move to vram. + # + # Shape ``(1, MLEN)`` (not ``(MLEN,)``) on purpose: fold's + # broadcast detection only kicks in when ``len(src.indices) + # < len(dst.indices)``, so a 1D shared dst with a scalar + # fp src ``w_aux[0]`` fails to fold (same rank, no + # broadcast path). 2D dst + 1D scalar src matches the + # path flash_attention already uses. + A_sh = T.alloc_shared((1, MLEN), "float16") + A_sh_acc = T.alloc_shared((1, MLEN), "float16") + + # ``B_FP`` holds the full weight tensor after MLEN-padding: + # OC_IC rows of MLEN slots each. The testbench's + # ``fp_preload`` writes weights into FPRAM at this buffer's + # allocated address before the kernel runs. + B_FP = T.alloc_fragment((OC_IC * MLEN,), "float16", + scope="global.fpram") + # Per-tap weight scalar — 1D so the FPRAM-scalar fold path + # accepts the multiply (`A_sh[m] = A_sh[m] * w_aux[0]`). + w_aux = T.alloc_fragment((1,), "float16") + in_FP_aux = T.alloc_fragment((MLEN,), "float16") + in_FP_padded = T.alloc_fragment((MLEN + KW - 1,), "float16") + shift_FP = T.alloc_fragment((MLEN,), "float16") + + # Single-channel output tile, drained to HBM per oc. + C_loc = T.alloc_shared((MLEN, MLEN), "float16") + + for k in T.serial(KW - 1): + in_FP_padded[MLEN + k] = T.float16(0) + + for oc in T.serial(C_OUT): + for oh in T.serial(H): + for m in T.Parallel(MLEN): + A_sh_acc[0, m] = T.float16(0) + + for ic in T.serial(C_IN): + # (Re-)stage this input channel's full padded + # tile into VRAM. Inner kh_idx reads from it + # row-wise. + T.copy( + Input[0, ic, 0:H_PAD, 0:W_PAD], + in_stage[0:H_PAD, 0:W_PAD], + ) + for kh_idx in T.unroll(KH): + T.copy( + in_stage[oh + kh_idx, 0:MLEN], + in_FP_aux[0:MLEN], + ) + + for i in T.serial(MLEN): + in_FP_padded[i] = in_FP_aux[i] + + for kw_idx in T.unroll(KW): + k_tap = kh_idx * KW + kw_idx + + for m in T.serial(MLEN): + shift_FP[m] = in_FP_padded[m + kw_idx] + + # fpram → vram row copy: lowers to a + # single ``S_MAP_V_FP`` (whole MLEN row + # in one issue) instead of MLEN scalar + # loads + stores. + T.copy(shift_FP[0:MLEN], A_sh[0, 0:MLEN]) + + w_aux[0] = B_FP[(oc * C_IN + ic) * MLEN + k_tap] + + # vram × fp_scalar broadcast → one + # ``V_MUL_VF`` per row; with A_sh being + # 1×MLEN that's a single instruction. + for m in T.Parallel(MLEN): + A_sh[0, m] = A_sh[0, m] * w_aux[0] + # vram + vram → one ``V_ADD_VV``. + for m in T.Parallel(MLEN): + A_sh_acc[0, m] = A_sh_acc[0, m] + A_sh[0, m] + + T.copy( + A_sh_acc[0, 0:MLEN], + C_loc[oh, 0:MLEN], + ) + + # Drain this oc's full (MLEN, MLEN) tile to HBM. + T.copy( + C_loc[0:MLEN, 0:MLEN], + Output[0, oc, 0:MLEN, 0:MLEN], + ) + + # Return the raw PrimFunc. ``compile_kernel`` runs stmt prep + the + # mid_ir pipeline itself, so factories no longer need to call into + # the legacy compile_func. + lowered = conv2d_min + + constants = { + "H": H, "W": W, + "H_PAD": H_PAD, "W_PAD": W_PAD, + "KH": KH, "KW": KW, + "K_FLAT": K_FLAT, + "MLEN": MLEN, "HLEN": HLEN, + "C_IN": C_IN, "C_OUT": C_OUT, + } + return lowered, constants + + +__all__ = ["make_conv2d_min"] diff --git a/tilelang_tvm_compiler/kernels/copy_offset_min.py b/tilelang_tvm_compiler/kernels/copy_offset_min.py new file mode 100644 index 0000000..b17b291 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/copy_offset_min.py @@ -0,0 +1,254 @@ +"""Copy-offset-min kernel — staged probe for the o_head_offset path. + +Reads a compact ``[B, S, H, D]`` input and writes it into a head-slice +``[o_head_offset : o_head_offset + H]`` of a wider ``[B, S, o_head_count, +D]`` output. The DMA structure matches gelu_min's offset variant; the +``compute`` knob dials how much per-element FPRAM math runs in between, +so a binary search over compute stages can pin down which operator's +interaction with the offset writeback is broken. + +``compute`` stages (each adds one GELU-relevant operator): + "copy" VRAM->VRAM verbatim, no FPRAM at all. + "id" per-row VRAM->FPRAM->VRAM, identity (Y_FP = X_FP). + "mul" Y_FP[i] = X_FP[i] * X_FP[i] (plain fp mul) + "const_mul" Y_FP[i] = 0.5 * X_FP[i] (hoisted-const mul) + "exp" Y_FP[i] = exp(X_FP[i]) (fp exp) + "reci" Y_FP[i] = 1.0 / X_FP[i] (fp reciprocal) +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +_COMPUTE_STAGES = ("copy", "id", "mul", "const_mul", "exp", "reci") + + +def make_copy_offset_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int = 8, + num_s_blocks: int = 2, + batch: int = 1, + o_head_count: int | None = None, + o_head_offset: int = 0, + compute: str = "copy", + # Back-compat: ``fp_roundtrip=True`` is the old name for compute="id". + fp_roundtrip: bool | None = None, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError(f"copy_offset_min requires rows == MLEN ({MLEN}), got {rows}") + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN ({MLEN}); got hlen={hlen}") + hardware_lane_count = MLEN // hlen + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of MLEN/hlen={hardware_lane_count}; " + f"got {head_count}" + ) + if num_s_blocks < 1: + raise ValueError(f"num_s_blocks must be >= 1, got {num_s_blocks}") + if o_head_count is None: + o_head_count = head_count + if o_head_count < head_count: + raise ValueError( + f"o_head_count ({o_head_count}) must be >= head_count ({head_count})" + ) + if not (0 <= o_head_offset <= o_head_count - head_count): + raise ValueError( + f"o_head_offset ({o_head_offset}) + head_count ({head_count}) " + f"must fit within o_head_count ({o_head_count})" + ) + if fp_roundtrip is not None: + compute = "id" if fp_roundtrip else "copy" + if compute not in _COMPUTE_STAGES: + raise ValueError( + f"compute must be one of {_COMPUTE_STAGES}; got {compute!r}" + ) + + seq_len = num_s_blocks * rows + + # ----- compute == "copy": plain VRAM->VRAM, no FPRAM ----- + @T.prim_func + def copy_offset_copy( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, o_head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + T.copy(X_sh, Y_sh) + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, + by + o_head_offset, 0:hlen], + ) + + # ----- compute == "id": per-row FPRAM round-trip, identity ----- + @T.prim_func + def copy_offset_id( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, o_head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + X_FP = T.alloc_fragment((hlen,), "float16") + Y_FP = T.alloc_fragment((hlen,), "float16") + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + for row in T.serial(rows): + T.copy(X_sh[row, 0], X_FP) + for i in T.unroll(hlen): + Y_FP[i] = X_FP[i] + T.copy(Y_FP, Y_sh[row, 0]) + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, + by + o_head_offset, 0:hlen], + ) + + # ----- compute == "mul": Y = X * X ----- + @T.prim_func + def copy_offset_mul( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, o_head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + X_FP = T.alloc_fragment((hlen,), "float16") + Y_FP = T.alloc_fragment((hlen,), "float16") + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + for row in T.serial(rows): + T.copy(X_sh[row, 0], X_FP) + for i in T.unroll(hlen): + Y_FP[i] = X_FP[i] * X_FP[i] + T.copy(Y_FP, Y_sh[row, 0]) + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, + by + o_head_offset, 0:hlen], + ) + + # ----- compute == "const_mul": Y = 0.5 * X (hoisted-const mul) ----- + @T.prim_func + def copy_offset_const_mul( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, o_head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + X_FP = T.alloc_fragment((hlen,), "float16") + Y_FP = T.alloc_fragment((hlen,), "float16") + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + for row in T.serial(rows): + T.copy(X_sh[row, 0], X_FP) + for i in T.unroll(hlen): + Y_FP[i] = T.float16(0.5) * X_FP[i] + T.copy(Y_FP, Y_sh[row, 0]) + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, + by + o_head_offset, 0:hlen], + ) + + # ----- compute == "exp": Y = exp(X) ----- + @T.prim_func + def copy_offset_exp( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, o_head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + X_FP = T.alloc_fragment((hlen,), "float16") + Y_FP = T.alloc_fragment((hlen,), "float16") + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + for row in T.serial(rows): + T.copy(X_sh[row, 0], X_FP) + for i in T.unroll(hlen): + Y_FP[i] = T.exp(X_FP[i]) + T.copy(Y_FP, Y_sh[row, 0]) + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, + by + o_head_offset, 0:hlen], + ) + + # ----- compute == "reci": Y = 1.0 / X ----- + @T.prim_func + def copy_offset_reci( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, o_head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + X_FP = T.alloc_fragment((hlen,), "float16") + Y_FP = T.alloc_fragment((hlen,), "float16") + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + for row in T.serial(rows): + T.copy(X_sh[row, 0], X_FP) + for i in T.unroll(hlen): + Y_FP[i] = T.float16(1.0) / X_FP[i] + T.copy(Y_FP, Y_sh[row, 0]) + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, + by + o_head_offset, 0:hlen], + ) + + _STAGE_FUNCS = { + "copy": copy_offset_copy, + "id": copy_offset_id, + "mul": copy_offset_mul, + "const_mul": copy_offset_const_mul, + "exp": copy_offset_exp, + "reci": copy_offset_reci, + } + lowered = _STAGE_FUNCS[compute] + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "O_HEAD_COUNT": o_head_count, + "O_HEAD_OFFSET": o_head_offset, + "COMPUTE": compute, + "BATCH": batch, + "NUM_S_BLOCKS": num_s_blocks, + "HARDWARE_LANE_COUNT": hardware_lane_count, + } + return lowered, constants + + +__all__ = ["make_copy_offset_min"] diff --git a/tilelang_tvm_compiler/kernels/flash_attention_gemm_only.py b/tilelang_tvm_compiler/kernels/flash_attention_gemm_only.py new file mode 100644 index 0000000..f02b1c4 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/flash_attention_gemm_only.py @@ -0,0 +1,109 @@ +"""flash_attention gemm-only debug kernel — with a step dial. + +Base: BTMM(Q@K^T) + matmul(S@V), no softmax. ``fd_steps`` (0..6) adds +the softmax of flash_attention_min.py back ONE BLOCK AT A TIME, where +each block is copied VERBATIM from flash_attention_min (so every level +is a strict subset of that kernel and is guaranteed to compile): + + 0 O = (Q@K^T) @ V (gemm-only base) + 1 + block A : S *= scale ; M_CURR = M_OLD + 2 + reduce_max -> M_CURR + 3 + block B : M_RES = exp(M_OLD-M_CURR) ; S = exp(S-M_CURR) ; P_SUM=0 + 4 + reduce_sum -> P_SUM + 5 + block C : L_NEW = L_OLD*M_RES+P_SUM ; O *= M_RES ; advance state + 6 + block D : L_INV = 1/L_NEW ; O *= L_INV (== full flash_attention) + +These are flash_attention_min.py's own natural code blocks (lines +190-229), not an arbitrary split — so each level always compiles. +The testbench golden mirrors the same fd_steps. + +NOTE the O at intermediate levels: + * 1/2 : O == (scale * Q@K^T) @ V (M_CURR computed, unused for O) + * 3/4 : O == exp(scale*S - M_CURR) @ V + * 5 : O == M_RES * (exp(...) @ V) + * 6 : O == (M_RES * (exp(...) @ V)) / L_NEW +""" + +import math + +import tilelang.language as T + +from ..frontend.gemm_macros import KIND +from ..plena_settings import load_sizes as _load_sizes + + +def make_flash_attention_gemm_only( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int | None = None, + num_kv_blocks: int = 1, + num_q_blocks: int = 1, + fd_steps: int = 0, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError( + f"flash_attention_gemm_only requires rows == MLEN ({MLEN}), got {rows}" + ) + if MLEN % hlen != 0: + raise ValueError( + f"hlen must divide MLEN ({MLEN}); got hlen={hlen}" + ) + if not (0 <= fd_steps <= 6): + raise ValueError(f"fd_steps must be in [0, 6], got {fd_steps}") + hardware_lane_count = MLEN // hlen + if head_count is None: + head_count = hardware_lane_count + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of MLEN/hlen={hardware_lane_count}; " + f"got {head_count}" + ) + + kv_seq = num_kv_blocks * rows + q_seq = num_q_blocks * rows + scale_val = 1.0 / math.sqrt(hlen) + + # DMA-IN-ONLY probe: HBM -> VRAM, nothing else. No writeback, no + # gemm, no FPRAM. Isolates a single dma_h2v_slice (H_PREFETCH_V + # chain). The result lives in Q_sh (VRAM); compare it directly, + # no O_hbm, no compare-staging. ``fd_steps`` and K/V/O are kept + # for signature compatibility but unused. + @T.prim_func + def flash_attention_gemm_only( + Q_hbm: T.Tensor((1, q_seq, head_count, hlen), "float16"), + K_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + V_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + O_hbm: T.Tensor((1, q_seq, head_count, hlen), "float16"), + ): + with T.Kernel(num_q_blocks, head_count, threads=128) as (q_block, by): + Q_sh = T.alloc_shared((rows, hlen), "float16") + + # HBM -> VRAM. That's it. + T.copy( + Q_hbm[0, q_block * rows : (q_block + 1) * rows, by, 0:hlen], + Q_sh, + ) + + lowered = flash_attention_gemm_only + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "HARDWARE_LANE_COUNT": hardware_lane_count, + "NUM_KV_BLOCKS": num_kv_blocks, + "NUM_Q_BLOCKS": num_q_blocks, + "FD_STEPS": fd_steps, + } + return lowered, constants + + +__all__ = ["make_flash_attention_gemm_only"] diff --git a/tilelang_tvm_compiler/kernels/flash_attention_gqa_min.py b/tilelang_tvm_compiler/kernels/flash_attention_gqa_min.py new file mode 100644 index 0000000..bd4fe42 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/flash_attention_gqa_min.py @@ -0,0 +1,269 @@ +"""Flash-attention-GQA-min kernel — grouped-query attention. + +Variant of ``flash_attention_min`` where the KV head count is smaller +than the Q head count: ``kv_head_count = head_count // group_size``. +Each Q head ``by`` shares a KV head with ``group_size - 1`` siblings. + +KV-head index mapping +--------------------- +Q head ``by`` reads KV head ``by % kv_head_count``. + +We use ``%`` (modulo), NOT ``//`` (floordiv), on purpose: + + * The PLENA ISA implements a modulo op but has **no integer-divide + op**, so ``by // group_size`` cannot be lowered to hardware. + * ``%`` gives an *interleaved* group layout: Q heads + ``h, h+kv_head_count, h+2*kv_head_count, ...`` all map to KV head + ``h``. (A ``//``-based layout would be contiguous groups + ``0..G-1 -> 0`` — that layout is not expressible on this HW.) + +So callers must lay Q heads out interleaved-by-KV-head in HBM. + +KV tensors are declared with ``kv_head_count`` heads — this is the real +GQA memory saving, K/V genuinely store fewer heads. The grid still +iterates ``head_count`` Q heads; the ``by % kv_head_count`` index picks +the shared KV head. + +Everything else (online softmax, BTMM Q@K^T, per-head P@V, lane fusion) +is identical to ``flash_attention_min`` — see that file's docstring. +The frontend's ``_subst_lane_var`` recurses through arbitrary index +``op`` nodes, so ``by % kv_head_count`` survives the lane-axis split +(``by`` -> ``phase + number*cluster_count``) with no pass changes. +""" + +import math + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + +from ..address_alloc import FPRAM_USER_BASE +from ..frontend.gemm_macros import KIND + + +def make_flash_attention_gqa_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int | None = None, + group_size: int = 2, + lane_count: int | None = None, + active_lane: int = 0, + num_kv_blocks: int = 1, + num_q_blocks: int = 2, + o_head_count: int | None = None, + o_head_offset: int = 0, +): + """Grouped-query flash attention with online softmax. + + ``group_size`` Q heads share one KV head; ``kv_head_count = + head_count // group_size``. ``group_size == 1`` degenerates to plain + MHA (identical to ``flash_attention_min``). + + ``o_head_count`` / ``o_head_offset`` behave exactly as in + ``flash_attention_min`` — they let the kernel drop its output into a + head-slice of a wider output tensor. + """ + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError( + f"flash_attention_gqa_min requires rows == MLEN ({MLEN}), got {rows}" + ) + if MLEN % hlen != 0: + raise ValueError( + f"hlen must divide MLEN ({MLEN}); got hlen={hlen}" + ) + hardware_lane_count = MLEN // hlen + if head_count is None: + head_count = lane_count if lane_count is not None else hardware_lane_count + elif lane_count is not None and lane_count != head_count: + raise ValueError( + f"head_count and legacy lane_count disagree: {head_count} vs {lane_count}" + ) + if head_count < 1: + raise ValueError(f"head_count must be >= 1, got {head_count}") + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of hardware lane width " + f"MLEN/hlen={hardware_lane_count}; got {head_count}" + ) + if group_size < 1: + raise ValueError(f"group_size must be >= 1, got {group_size}") + if head_count % group_size != 0: + raise ValueError( + f"head_count ({head_count}) must be a multiple of group_size " + f"({group_size})" + ) + kv_head_count = head_count // group_size + if not (0 <= active_lane < hardware_lane_count): + raise ValueError( + f"active_lane out of hardware lane range [0, {hardware_lane_count}): " + f"{active_lane}" + ) + if num_kv_blocks < 1: + raise ValueError(f"num_kv_blocks must be >= 1, got {num_kv_blocks}") + if num_q_blocks < 1: + raise ValueError(f"num_q_blocks must be >= 1, got {num_q_blocks}") + + if o_head_count is None: + o_head_count = head_count + if o_head_count < head_count: + raise ValueError( + f"o_head_count ({o_head_count}) must be >= head_count " + f"({head_count})" + ) + if not (0 <= o_head_offset <= o_head_count - head_count): + raise ValueError( + f"o_head_offset ({o_head_offset}) + head_count ({head_count}) " + f"must fit within o_head_count ({o_head_count})" + ) + + grouped = hlen < MLEN + kv_seq = num_kv_blocks * rows + q_seq = num_q_blocks * rows + + fp_state_elems = hardware_lane_count * rows + scale_val = 1.0 / math.sqrt(hlen) + + @T.prim_func + def flash_attention_gqa_min( + Q_hbm: T.Tensor((1, q_seq, head_count, hlen), "float16"), + K_hbm: T.Tensor((1, kv_seq, kv_head_count, hlen), "float16"), + V_hbm: T.Tensor((1, kv_seq, kv_head_count, hlen), "float16"), + O_hbm: T.Tensor((1, q_seq, o_head_count, hlen), "float16"), + ): + with T.Kernel(num_q_blocks, head_count, threads=128) as (q_block, by): + # KV head shared across the group. ``%`` (modulo) — the ISA + # has no integer divide, so an interleaved group layout is + # the only HW-expressible GQA mapping. + kv_by = by % kv_head_count + + # Per-lane (rows, hlen) — col-pack expanded to 4D BSHD-packed. + Q_sh = T.alloc_shared((rows, hlen), "float16") + K_sh = T.alloc_shared((rows, hlen), "float16") # gemm RHS → mram + V_sh = T.alloc_shared((rows, hlen), "float16") # matmul RHS → mram (via DMA + gemm) + PV_loc = T.alloc_fragment((rows, hlen), "float16") + O_loc = T.alloc_fragment((rows, hlen), "float16") + # BTMM output: per-lane (rows, MLEN), row-stack expanded to 4D BHSD. + S_loc = T.alloc_fragment((rows, MLEN), "float16") + # Per-lane FP softmax state — expanded to (lane_count, rows). + M_OLD = T.alloc_fragment((rows,), "float16") + M_CURR = T.alloc_fragment((rows,), "float16") + M_RES = T.alloc_fragment((rows,), "float16") + L_OLD = T.alloc_fragment((rows,), "float16") + L_NEW = T.alloc_fragment((rows,), "float16") + P_SUM = T.alloc_fragment((rows,), "float16") + L_INV = T.alloc_fragment((rows,), "float16") + + # Q DMA — sync, fires once per q_block (multi-lane). Q is + # indexed by the full Q head ``by``. + T.copy( + Q_hbm[0, q_block * rows : (q_block + 1) * rows, by, 0:hlen], + Q_sh, + ) + + # Zero running output. + for row in T.serial(rows): + for col in T.Parallel(hlen): + O_loc[row, col] = T.float16(0) + + # Reset per-lane FP softmax state for this q tile. + for row in T.serial(rows): + M_OLD[row] = T.float16(-1.0e4) + L_OLD[row] = T.float16(0) + + for kv_block in T.serial(num_kv_blocks): + # K, V DMAs — sync, multi-lane. Indexed by the SHARED + # KV head ``kv_by`` (= by % kv_head_count). + T.copy( + K_hbm[0, kv_block * rows : (kv_block + 1) * rows, kv_by, 0:hlen], + K_sh, + ) + T.copy( + V_hbm[0, kv_block * rows : (kv_block + 1) * rows, kv_by, 0:hlen], + V_sh, + ) + + # BTMM Q @ K^T → S_loc. + with T.attr(0, KIND, "btmm"): + T.gemm(Q_sh, K_sh, S_loc, transpose_B=True) + + # Scale S_loc by 1/sqrt(d_k) per row. + for row in T.serial(rows): + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] * T.float16(scale_val) + M_CURR[row] = M_OLD[row] + + # M_CURR = max(M_OLD, rowmax(S_loc)). + T.reduce_max(S_loc, M_CURR, dim=1, clear=False) + + for row in T.serial(rows): + M_RES[row] = M_OLD[row] - M_CURR[row] + M_RES[row] = T.exp(M_RES[row]) + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] - M_CURR[row] + for col in T.Parallel(MLEN): + S_loc[row, col] = T.exp(S_loc[row, col]) + P_SUM[row] = T.float16(0) + + # P_SUM = rowsum(exp(S - M_CURR)). + T.reduce_sum(S_loc, P_SUM, dim=1, clear=False) + + for row in T.serial(rows): + L_NEW[row] = L_OLD[row] * M_RES[row] + L_NEW[row] = L_NEW[row] + P_SUM[row] + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * M_RES[row] + M_OLD[row] = M_CURR[row] + L_OLD[row] = L_NEW[row] + + # Per-head P @ V → PV_loc, then O += PV_loc. + T.gemm(S_loc, V_sh, PV_loc) + + for row in T.serial(rows): + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] + PV_loc[row, col] + + # Final O = O / L_new for this q_block. + for row in T.serial(rows): + L_INV[row] = 1.0 / L_NEW[row] + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * L_INV[row] + + # Write O back to HBM at this q_block slot. O is indexed by + # the full Q head ``by`` (+ o_head_offset) — every Q head + # produces its own output, only K/V are shared. + T.copy( + O_loc, + O_hbm[0, q_block * rows : (q_block + 1) * rows, + by + o_head_offset, 0:hlen], + ) + + lowered = flash_attention_gqa_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "KV_HEAD_COUNT": kv_head_count, + "GROUP_SIZE": group_size, + "LANE_COUNT": hardware_lane_count, + "HARDWARE_LANE_COUNT": hardware_lane_count, + "ACTIVE_LANE": active_lane, + "GROUPED": grouped, + "FPRAM_USER_BASE": FPRAM_USER_BASE, + "FP_STATE_ELEMS": fp_state_elems, + "NUM_KV_BLOCKS": num_kv_blocks, + "NUM_Q_BLOCKS": num_q_blocks, + } + return lowered, constants + + +__all__ = ["make_flash_attention_gqa_min"] diff --git a/tilelang_tvm_compiler/kernels/flash_attention_min.py b/tilelang_tvm_compiler/kernels/flash_attention_min.py new file mode 100644 index 0000000..fdb3268 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/flash_attention_min.py @@ -0,0 +1,262 @@ +"""Flash-attention-min kernel — written in tilelang style. + +Multi-q-block × multi-kv-block flash attention with online softmax, +head-fused via ``T.Kernel(num_q_blocks, head_count)``. + +Tilelang-DSL parts: + * ``T.Kernel(num_q_blocks, head_count) as (q_block, by)`` — grid axes; + ``by`` is the logical head axis. The frontend splits it into hardware + sync domains of width ``MLEN / hlen`` when DMAs / BTMM need fusion. + * ``T.copy`` for HBM↔VRAM/MRAM transfers. + * ``T.gemm(..., transpose_B=True)`` under ``T.attr(0, KIND, "btmm")`` + for Q@K^T with head fusion. + * Per-lane buffers declared as 2D shapes get auto-expanded by the + ``allocate_group_memory`` pass into 4D BSHD-packed (column-direction + head packing) or BHSD-stacked (row-direction head stacking). + +Direct ``T.call_extern("plena.*")`` parts (no tilelang DSL equivalent +yet): + * Per-head ``plena.matmul`` for ``P @ V``. + * ``plena.v_add`` / ``plena.zero_v`` for output accumulation. + +The frontend pipeline handles lane-fusion segmentation automatically: +each sync point (DMA / BTMM / vector op) fires once as a multi-lane HW +op outside the per-lane for-by loop; per-lane FP / matmul / row ops run +inside their own for-by loop. + +Single-buffered: each program handles exactly one head (``by``). The +head double-buffering variant is preserved in +``flash_attention_min.py.doublebuf.bak``. + +FP slot layout (1 flat FPRAM region starting at FPRAM_USER_BASE; each +slot ``hardware_lane_count*rows`` wide). Users declare each slot as a +1D per-lane fragment ``(rows,)`` and the compiler expands it to +``(hardware_lane_count, rows)`` inside the lane group. The testbench +preloads the read-only slots: + Scale[h, :] = 1 / sqrt(d_k) + M_init[h, :] = -inf surrogate + L_init[h, :] = 0 +""" + +import math + +import tilelang.language as T + +from ..address_alloc import FPRAM_USER_BASE +from ..frontend.gemm_macros import KIND +from ..plena_settings import load_sizes as _load_sizes + + +def make_flash_attention_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int | None = None, + lane_count: int | None = None, + active_lane: int = 0, + num_kv_blocks: int = 1, + num_q_blocks: int = 2, + o_head_count: int | None = None, + o_head_offset: int = 0, +): + """Flash attention with online softmax. + + ``o_head_count`` / ``o_head_offset`` let the kernel write its output + into a head-slice of a WIDER output tensor — used by the + single-stream-block chain to drop attention's result straight into + the left half of ``concat([attn, mlp])`` with no separate concat + kernel. ``O_hbm`` is declared with ``o_head_count`` heads (default: + same as ``head_count``, the standalone-kernel behaviour), and each + program writes head ``by + o_head_offset``. The grid still iterates + ``head_count`` heads; only the destination head index shifts. + """ + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError( + f"flash_attention_min requires rows == MLEN ({MLEN}), got {rows}" + ) + if MLEN % hlen != 0: + raise ValueError( + f"hlen must divide MLEN ({MLEN}); got hlen={hlen}" + ) + hardware_lane_count = MLEN // hlen + # Backward compatibility for older scripts: `lane_count` used to mean + # logical head count. New callers should pass `head_count`. + if head_count is None: + head_count = lane_count if lane_count is not None else hardware_lane_count + elif lane_count is not None and lane_count != head_count: + raise ValueError( + f"head_count and legacy lane_count disagree: {head_count} vs {lane_count}" + ) + if head_count < 1: + raise ValueError(f"head_count must be >= 1, got {head_count}") + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of hardware lane width " + f"MLEN/hlen={hardware_lane_count}; got {head_count}" + ) + if not (0 <= active_lane < hardware_lane_count): + raise ValueError( + f"active_lane out of hardware lane range [0, {hardware_lane_count}): " + f"{active_lane}" + ) + if num_kv_blocks < 1: + raise ValueError(f"num_kv_blocks must be >= 1, got {num_kv_blocks}") + if num_q_blocks < 1: + raise ValueError(f"num_q_blocks must be >= 1, got {num_q_blocks}") + + if o_head_count is None: + o_head_count = head_count + if o_head_count < head_count: + raise ValueError( + f"o_head_count ({o_head_count}) must be >= head_count " + f"({head_count})" + ) + if not (0 <= o_head_offset <= o_head_count - head_count): + raise ValueError( + f"o_head_offset ({o_head_offset}) + head_count ({head_count}) " + f"must fit within o_head_count ({o_head_count})" + ) + + grouped = hlen < MLEN + kv_seq = num_kv_blocks * rows + q_seq = num_q_blocks * rows + + fp_state_elems = hardware_lane_count * rows + # Softmax scale 1/sqrt(d_k). Embedded directly as a FloatImm via + # ``T.float16(...)``; ``hoist_float_constants`` turns it into a + # 1-slot ``global.fpram`` buffer at compile time. + scale_val = 1.0 / math.sqrt(hlen) + + @T.prim_func + def flash_attention_min( + Q_hbm: T.Tensor((1, q_seq, head_count, hlen), "float16"), + K_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + V_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + O_hbm: T.Tensor((1, q_seq, o_head_count, hlen), "float16"), + ): + # Single-buffered: the grid head axis is the FULL head_count. + # Each program handles exactly one head (``by``) with its own + # set of buffers. + with T.Kernel(num_q_blocks, head_count, threads=128) as (q_block, by): + head = by + + # ---- per-head buffers ----------------------------------- + Q_sh = T.alloc_shared((rows, hlen), "float16") + K_sh = T.alloc_shared((rows, hlen), "float16") # gemm RHS → mram + V_sh = T.alloc_shared((rows, hlen), "float16") # matmul RHS → mram + PV_loc = T.alloc_fragment((rows, hlen), "float16") + O_loc = T.alloc_fragment((rows, hlen), "float16") + S_loc = T.alloc_fragment((rows, MLEN), "float16") + M_OLD = T.alloc_fragment((rows,), "float16") + M_CURR = T.alloc_fragment((rows,), "float16") + M_RES = T.alloc_fragment((rows,), "float16") + L_OLD = T.alloc_fragment((rows,), "float16") + L_NEW = T.alloc_fragment((rows,), "float16") + P_SUM = T.alloc_fragment((rows,), "float16") + L_INV = T.alloc_fragment((rows,), "float16") + + # Q DMA. + T.copy( + Q_hbm[0, q_block * rows : (q_block + 1) * rows, head, 0:hlen], + Q_sh, + ) + + # Zero running output. + for row in T.serial(rows): + for col in T.Parallel(hlen): + O_loc[row, col] = T.float16(0) + + # Reset per-lane FP softmax state. + for row in T.serial(rows): + M_OLD[row] = T.float16(-1.0e4) + L_OLD[row] = T.float16(0) + + for kv_block in T.serial(num_kv_blocks): + # --- copy copy -> dma (BTMM) --- + T.copy( + K_hbm[0, kv_block * rows : (kv_block + 1) * rows, head, 0:hlen], + K_sh, + ) + T.copy( + V_hbm[0, kv_block * rows : (kv_block + 1) * rows, head, 0:hlen], + V_sh, + ) + with T.attr(0, KIND, "btmm"): + T.gemm(Q_sh, K_sh, S_loc, transpose_B=True) + + # --- online softmax + SV, one group --- + for row in T.serial(rows): + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] * T.float16(scale_val) + M_CURR[row] = M_OLD[row] + T.reduce_max(S_loc, M_CURR, dim=1, clear=False) + for row in T.serial(rows): + M_RES[row] = M_OLD[row] - M_CURR[row] + M_RES[row] = T.exp(M_RES[row]) + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] - M_CURR[row] + for col in T.Parallel(MLEN): + S_loc[row, col] = T.exp(S_loc[row, col]) + P_SUM[row] = T.float16(0) + T.reduce_sum(S_loc, P_SUM, dim=1, clear=False) + for row in T.serial(rows): + L_NEW[row] = L_OLD[row] * M_RES[row] + L_NEW[row] = L_NEW[row] + P_SUM[row] + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * M_RES[row] + M_OLD[row] = M_CURR[row] + L_OLD[row] = L_NEW[row] + T.gemm(S_loc, V_sh, PV_loc) + for row in T.serial(rows): + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] + PV_loc[row, col] + + # Final O = O / L_new. + for row in T.serial(rows): + L_INV[row] = 1.0 / L_NEW[row] + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * L_INV[row] + + # Write O back to HBM — head ``by`` to head + offset. + T.copy( + O_loc, + O_hbm[0, q_block * rows : (q_block + 1) * rows, + head + o_head_offset, 0:hlen], + ) + + # Return the raw PrimFunc. ``compile_kernel`` runs stmt prep + the + # mid_ir pipeline itself, so factories no longer need to call into + # the legacy compile_func. + lowered = flash_attention_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "LANE_COUNT": hardware_lane_count, + "HARDWARE_LANE_COUNT": hardware_lane_count, + "ACTIVE_LANE": active_lane, + "GROUPED": grouped, + "FPRAM_USER_BASE": FPRAM_USER_BASE, + "FP_STATE_ELEMS": fp_state_elems, + # FPRAM scalar-slot addresses are exposed via the compiler's + # --dump-buffer-addrs JSON (single source of truth — see + # PIPELINE_ARCHITECTURE.md § 5.6). Don't add ``*_ADDR`` keys + # back here; they were a hand-rolled mirror of + # AddressAllocationPass and were the root cause of the + # flash_decode_min FPRAM bug when they drifted. + "NUM_KV_BLOCKS": num_kv_blocks, + "NUM_Q_BLOCKS": num_q_blocks, + } + return lowered, constants + + +__all__ = ["make_flash_attention_min"] diff --git a/tilelang_tvm_compiler/kernels/flash_decode_min.py b/tilelang_tvm_compiler/kernels/flash_decode_min.py new file mode 100644 index 0000000..2fdd63c --- /dev/null +++ b/tilelang_tvm_compiler/kernels/flash_decode_min.py @@ -0,0 +1,255 @@ +"""Flash-attention decode kernel — single Q token, lane-fused multi-head. + +Mirrors `flash_attention_min`'s structure with three key differences: + + 1. Q is a **single token** (rows=1). The kernel doesn't loop over + q_blocks; only over kv_blocks for the online-softmax accumulation. + 2. Q does NOT come from HBM. It lives in a **VRAM "tensor cache"** + region (``Q_cache``) sized ``(head_count, hlen)``. The testbench + preloads Q values to FPRAM and a pre-kernel ASM stub copies them + to the VRAM cache via ``S_MAP_V_FP`` before the kernel proper + starts. From the kernel's perspective this is just a normal shared + buffer it reads from. Per-by_o iteration the kernel pulls one + MLEN-wide chunk via ``T.copy(Q_cache[by, 0], Q_sh[0, 0])``, which + lowers to a single ``V_ADD_VF`` with f0=0 (vram→vram row copy). + 3. Q @ K^T uses **BTMV** (rows=1 LHS triggers the dispatch in + `_lower_gemm`). P @ V uses **plena.mv** (per-head M_MV), since + the row-stacked S_loc layout is incompatible with M_BMV's + lane-packed input — exactly the same reason flash_attention_min + uses per-head M_MM for its P @ V. + +Supports any ``head_count`` that is a multiple of +``hardware_lane_count`` (single by_o iter when equal; multi-by_o +otherwise). Q_cache holds all heads' Q values laid out head-major +(head 0's hlen elements, then head 1's, etc.); the by-indexed read +naturally selects the right MLEN-wide chunk per by_o. + +FP slot layout (1 flat FPRAM region starting at FPRAM_USER_BASE): + + Q_FP_STAGE (head_count, hlen) — staging area, preloaded by testbench; + pre-kernel stub copies to Q_cache + M_OLD (lane_count, 1) + M_CURR (lane_count, 1) + M_RES (lane_count, 1) + L_OLD (lane_count, 1) + L_NEW (lane_count, 1) + P_SUM (lane_count, 1) + SCALE (lane_count, 1) — preloaded: 1 / sqrt(d_k) + L_INV (lane_count, 1) + M_INIT (lane_count, 1) — preloaded: -inf surrogate + L_INIT (lane_count, 1) — preloaded: 0 + O_FP (lane_count, hlen) — final per-head output, drained from VRAM + +The kernel does NOT write back to HBM. The output ends up in FPRAM at +``O_FP``; the testbench reads FPRAM directly to compare against golden +(``compare_fpsram_output=True`` in comparison_params). +""" + +import math + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + +from ..address_alloc import FPRAM_USER_BASE +from ..frontend.gemm_macros import KIND + + +def make_flash_decode_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int | None = None, + num_kv_blocks: int = 2, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError( + f"flash_decode_min requires rows == MLEN ({MLEN}), got {rows}" + ) + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN ({MLEN}); got hlen={hlen}") + hardware_lane_count = MLEN // hlen + if head_count is None: + head_count = hardware_lane_count + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of hardware_lane_count " + f"({hardware_lane_count}); got head_count={head_count}" + ) + if num_kv_blocks < 1: + raise ValueError(f"num_kv_blocks must be >= 1, got {num_kv_blocks}") + + kv_seq = num_kv_blocks * rows + # Softmax scale 1/sqrt(d_k). Embedded directly as a FloatImm via + # ``T.float16(...)`` in the kernel body — the ``hoist_float_constants`` + # pre-pass turns it into a 1-slot global.fpram buffer at compile + # time, no SCALE alloc / SCALE preload required. + scale_val = 1.0 / math.sqrt(hlen) + + @T.prim_func + def flash_decode_min( + K_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + V_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + ): + with T.Kernel(1, head_count, threads=128) as (_, by): + # Q lives in a VRAM "tensor cache" region — a global tensor + # populated by the testbench pre-kernel stub via S_MAP_V_FP + # from FPRAM. Layout is head-major (head_count rows, hlen + # cols), so by-indexed reads naturally select per-by_o slices + # of MLEN = lane_count * hlen. Marked global.vram so + # allocate_group_memory does not try to re-expand its head + # axis (the head dim is already explicit in the shape). + Q_cache = T.alloc_shared((head_count, hlen), "float16", + scope="global.vram") + # Symmetric output cache. Kernel writes O_loc -> O_cache[by, 0] + # via vram→vram T.copy (V_ADD_VF f0=0). Testbench compares the + # VRAM region directly — no FPRAM round-trip needed. Same + # global.vram rationale as Q_cache. + O_cache = T.alloc_shared((head_count, hlen), "float16", + scope="global.vram") + # VRAM staging so BTMV can read Q from VRAM. + # 2D rows=1 → col-packed to (1, 1, lane_count, hlen). + Q_sh = T.alloc_shared((1, hlen), "float16") + # MRAM tiles for K and V (gemm RHS). + K_sh = T.alloc_shared((rows, hlen), "float16") + V_sh = T.alloc_shared((rows, hlen), "float16") + # BTMV output: 2D rows=1 → row-stacked to (1, lane_count, 1, MLEN). + S_loc = T.alloc_fragment((1, MLEN), "float16") + # P @ V partial output and running accumulator: 2D rows=1. + PV_loc = T.alloc_fragment((1, hlen), "float16") + O_loc = T.alloc_fragment((1, hlen), "float16") + # Online softmax state: rank-1 → lane-stacked (lane_count, 1). + M_OLD = T.alloc_fragment((1,), "float16") + M_CURR = T.alloc_fragment((1,), "float16") + M_RES = T.alloc_fragment((1,), "float16") + L_OLD = T.alloc_fragment((1,), "float16") + L_NEW = T.alloc_fragment((1,), "float16") + P_SUM = T.alloc_fragment((1,), "float16") + L_INV = T.alloc_fragment((1,), "float16") + # SCALE / M_INIT / L_INIT are no longer declared buffers — + # the kernel body embeds the literals directly as + # ``T.float16(...)`` and the ``hoist_float_constants`` + # pre-pass synthesises an equivalent ``global.fpram`` + # 1-slot buffer per unique constant at compile time. + # ``test_helper`` auto-preloads the values from the + # buffer-addrs dump. + + # VRAM cache → VRAM staging: pull this by_o's MLEN-wide chunk + # of Q into Q_sh. Lowers to one V_ADD_VF (f0=0) row copy. + # ``by`` after split_lane_groups is by_o*lane_count + by_i; sync + # wrap substitutes by_i -> 0, so the source offset becomes + # by_o*lane_count*hlen = by_o*MLEN — exactly the per-by_o chunk. + # NOTE: dst is the whole Q_sh buffer (NOT Q_sh[0, 0]) so tilelang's + # copy_op doesn't degenerate to a scalar BufferStore. + T.copy(Q_cache[by, 0], Q_sh) + + # Zero output accumulator. T.Parallel + constant fill is + # picked up by fuse_elementwise → plena.zero_v (multi-lane, + # sync) — kernel never sees the plena op directly. + for col in T.Parallel(hlen): + O_loc[0, col] = T.float16(0) + + # Init online softmax state from -inf / 0 literals; the + # pre-pass hoists -1e4 into a shared global.fpram slot. + for row in T.serial(1): + M_OLD[row] = T.float16(-1.0e4) + L_OLD[row] = T.float16(0) + + for kv_block in T.unroll(num_kv_blocks): + # K, V DMAs — sync, multi-lane. Explicit slice form so + # mid_ir's ranged_slice inference produces clean + # (extent=rows, extent=hlen) tile shapes. + T.copy( + K_hbm[0, kv_block * rows : (kv_block + 1) * rows, by, 0:hlen], + K_sh, + ) + T.copy( + V_hbm[0, kv_block * rows : (kv_block + 1) * rows, by, 0:hlen], + V_sh, + ) + + # Q @ K^T → BTMV (rows=1 LHS auto-routes to plena.btmv). + with T.attr(0, KIND, "btmm"): + T.gemm(Q_sh, K_sh, S_loc, transpose_B=True) + + # Scale + grab current max baseline. + for row in T.serial(1): + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] * T.float16(scale_val) + M_CURR[row] = M_OLD[row] + + T.reduce_max(S_loc, M_CURR, dim=1, clear=False) + + for row in T.serial(1): + M_RES[row] = M_OLD[row] - M_CURR[row] + M_RES[row] = T.exp(M_RES[row]) + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] - M_CURR[row] + for col in T.Parallel(MLEN): + S_loc[row, col] = T.exp(S_loc[row, col]) + P_SUM[row] = T.float16(0) + + T.reduce_sum(S_loc, P_SUM, dim=1, clear=False) + + for row in T.serial(1): + L_NEW[row] = L_OLD[row] * M_RES[row] + L_NEW[row] = L_NEW[row] + P_SUM[row] + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * M_RES[row] + M_OLD[row] = M_CURR[row] + L_OLD[row] = L_NEW[row] + + # P @ V — default kind. Compiler picks plena.mv (M_MV) + # because S_loc has rows=1; per-head lane offset + # (S_loc row-stacked at by*MLEN, V_sh / PV_loc + # col-packed at by*hlen) is auto-injected from each + # buffer's lane-axis stride. + T.gemm(S_loc, V_sh, PV_loc) + + # O += PV. T.Parallel + add is picked up by + # fuse_elementwise → plena.v_add (multi-lane, sync). + for col in T.Parallel(hlen): + O_loc[0, col] = O_loc[0, col] + PV_loc[0, col] + + # Final O = O / L_new. + for row in T.serial(1): + L_INV[row] = 1.0 / L_NEW[row] + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * L_INV[row] + + # Write this by_o's MLEN-wide chunk of O into O_cache[by, 0]. + # vram→vram copy (V_ADD_VF f0=0); after lane fusion sync wrap, + # by_i drops to 0 so the dst offset becomes by_o*MLEN — exactly + # the per-by_o slice in the head-major O_cache layout. + # NOTE: src is the whole O_loc buffer (NOT O_loc[0, 0]) so + # tilelang's copy_op doesn't degenerate to a scalar BufferStore. + T.copy(O_loc, O_cache[by, 0]) + + # Return the raw PrimFunc — ``compile_kernel`` runs stmt prep + the + # mid_ir pipeline itself, so factories don't pre-lower anymore. + lowered = flash_decode_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "HARDWARE_LANE_COUNT": hardware_lane_count, + "FPRAM_USER_BASE": FPRAM_USER_BASE, + "NUM_KV_BLOCKS": num_kv_blocks, + "CACHE_NUM_MLEN_ROWS": (head_count * hlen) // MLEN, + # Buffer addresses are exposed via the compiler's + # --dump-buffer-addrs JSON (single source of truth — see + # PIPELINE_ARCHITECTURE.md § 5.6). The previous ``*_ADDR`` + # entries here were a hand-rolled mirror of + # AddressAllocationPass / `_slot_addresses` and were the root + # cause of the flash_decode_min FPRAM bug when they drifted. + } + return lowered, constants diff --git a/tilelang_tvm_compiler/kernels/flash_decode_min_gemm_only.py b/tilelang_tvm_compiler/kernels/flash_decode_min_gemm_only.py new file mode 100644 index 0000000..acd9622 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/flash_decode_min_gemm_only.py @@ -0,0 +1,201 @@ +"""flash_decode gemm-only debug kernel — with a step dial. + +Base: BTMV(Q@K^T) + MV(S@V), no softmax. ``fd_steps`` (0..8) adds the +stripped softmax ops back ONE AT A TIME so the testbench can bisect +which step introduces the sim/golden mismatch. + +Each level produces a MATHEMATICALLY WELL-DEFINED output (the testbench +golden mirrors it exactly), so there are no half-computed states to +guess at: + + 0 O = (Q@K^T) @ V (gemm-only base) + 1 O = (scale * Q@K^T) @ V (+ S *= scale) + 2 same O as 1, but reduce_max -> M_CURR is issued (result unused) + 3 same O as 1, but M_RES = exp(M_OLD-M_CURR) is issued (unused) + 4 O = exp(scale*S - M_CURR) @ V (+ S exp/sub) + 5 same O as 4, but reduce_sum -> P_SUM is issued (unused) + 6 same O as 4, but L_NEW = L_OLD*M_RES+P_SUM is issued (unused) + 7 O = M_RES * (exp(...) @ V) (+ O *= M_RES) + 8 O = (M_RES * (exp(...) @ V)) / L_NEW (== full flash_decode) + +Levels 2/3/5/6 add an op whose result does NOT change O — so if the +correctness drops at one of those, that op's hardware (reduce_max / +scalar exp / reduce_sum / scalar L chain) is itself the error source. +Levels 1/4/7/8 genuinely change O and the golden tracks them. + +NOTE: this is the single-q-block decode shape (rows of S used = 1). +""" + +import math + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + +from ..frontend.gemm_macros import KIND + + +def make_flash_decode_min_gemm_only( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int | None = None, + num_kv_blocks: int = 2, + fd_steps: int = 0, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError(f"requires rows == MLEN ({MLEN}), got {rows}") + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN ({MLEN}); got hlen={hlen}") + if not (0 <= fd_steps <= 8): + raise ValueError(f"fd_steps must be in [0, 8], got {fd_steps}") + hardware_lane_count = MLEN // hlen + if head_count is None: + head_count = hardware_lane_count + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of hardware_lane_count " + f"({hardware_lane_count}); got head_count={head_count}" + ) + if num_kv_blocks < 1: + raise ValueError(f"num_kv_blocks must be >= 1, got {num_kv_blocks}") + + kv_seq = num_kv_blocks * rows + scale_val = 1.0 / math.sqrt(hlen) + + @T.prim_func + def flash_decode_min_gemm_only( + K_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + V_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + ): + with T.Kernel(1, head_count, threads=128) as (_, by): + Q_cache = T.alloc_shared((head_count, hlen), "float16", + scope="global.vram") + O_cache = T.alloc_shared((head_count, hlen), "float16", + scope="global.vram") + Q_sh = T.alloc_shared((1, hlen), "float16") + K_sh = T.alloc_shared((rows, hlen), "float16") + V_sh = T.alloc_shared((rows, hlen), "float16") + S_loc = T.alloc_fragment((1, MLEN), "float16") + PV_loc = T.alloc_fragment((1, hlen), "float16") + O_loc = T.alloc_fragment((1, hlen), "float16") + # Online-softmax FPRAM scalars. Allocated for fd_steps >= 2; + # cheap to always declare, only written when the step needs it. + M_OLD = T.alloc_fragment((1,), "float16") + M_CURR = T.alloc_fragment((1,), "float16") + M_RES = T.alloc_fragment((1,), "float16") + L_OLD = T.alloc_fragment((1,), "float16") + L_NEW = T.alloc_fragment((1,), "float16") + P_SUM = T.alloc_fragment((1,), "float16") + L_INV = T.alloc_fragment((1,), "float16") + + T.copy(Q_cache[by, 0], Q_sh) + + for col in T.Parallel(hlen): + O_loc[0, col] = T.float16(0) + + # Init softmax state (needed from step 2 onward). + if fd_steps >= 2: + for row in T.serial(1): + M_OLD[row] = T.float16(-1.0e4) + L_OLD[row] = T.float16(0) + + for kv_block in T.unroll(num_kv_blocks): + T.copy( + K_hbm[0, kv_block * rows : (kv_block + 1) * rows, by, 0:hlen], + K_sh, + ) + T.copy( + V_hbm[0, kv_block * rows : (kv_block + 1) * rows, by, 0:hlen], + V_sh, + ) + + with T.attr(0, KIND, "btmm"): + T.gemm(Q_sh, K_sh, S_loc, transpose_B=True) + + # STEP 1: S *= scale + if fd_steps >= 1: + for row in T.serial(1): + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] * T.float16(scale_val) + + # STEP 2: reduce_max -> M_CURR (result unused for O at <4). + if fd_steps >= 2: + for row in T.serial(1): + M_CURR[row] = M_OLD[row] + T.reduce_max(S_loc, M_CURR, dim=1, clear=False) + + # STEP 3: M_RES = exp(M_OLD - M_CURR) (unused for O at <7). + if fd_steps >= 3: + for row in T.serial(1): + M_RES[row] = M_OLD[row] - M_CURR[row] + M_RES[row] = T.exp(M_RES[row]) + + # STEP 4: S = exp(S - M_CURR) + if fd_steps >= 4: + for row in T.serial(1): + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] - M_CURR[row] + for col in T.Parallel(MLEN): + S_loc[row, col] = T.exp(S_loc[row, col]) + + # STEP 5: reduce_sum -> P_SUM (unused for O at <6/8). + if fd_steps >= 5: + for row in T.serial(1): + P_SUM[row] = T.float16(0) + T.reduce_sum(S_loc, P_SUM, dim=1, clear=False) + + # STEP 6: L_NEW = L_OLD*M_RES + P_SUM (unused for O at <8). + if fd_steps >= 6: + for row in T.serial(1): + L_NEW[row] = L_OLD[row] * M_RES[row] + L_NEW[row] = L_NEW[row] + P_SUM[row] + + # STEP 7: O_loc *= M_RES (rescale running output). + if fd_steps >= 7: + for row in T.serial(1): + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * M_RES[row] + + # Advance online state (only meaningful once it is used). + if fd_steps >= 6: + for row in T.serial(1): + M_OLD[row] = M_CURR[row] + L_OLD[row] = L_NEW[row] + + T.gemm(S_loc, V_sh, PV_loc) + + for col in T.Parallel(hlen): + O_loc[0, col] = O_loc[0, col] + PV_loc[0, col] + + # STEP 8: O = O / L_NEW + if fd_steps >= 8: + for row in T.serial(1): + L_INV[row] = 1.0 / L_NEW[row] + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * L_INV[row] + + T.copy(O_loc, O_cache[by, 0]) + + lowered = flash_decode_min_gemm_only + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "HARDWARE_LANE_COUNT": hardware_lane_count, + "NUM_KV_BLOCKS": num_kv_blocks, + "CACHE_NUM_MLEN_ROWS": (head_count * hlen) // MLEN, + "FD_STEPS": fd_steps, + } + return lowered, constants + + +__all__ = ["make_flash_decode_min_gemm_only"] diff --git a/tilelang_tvm_compiler/kernels/gelu_min.py b/tilelang_tvm_compiler/kernels/gelu_min.py new file mode 100644 index 0000000..b066022 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/gelu_min.py @@ -0,0 +1,174 @@ +"""GELU-min kernel — exercises FP scalar compound-store decomposition. + +GELU (tanh approximation): + + GELU(x) = 0.5 * x * (1 + tanh(u)) + u = sqrt(2/pi) * (x + 0.044715 * x^3) + +PLENA has no native tanh. We expand it inline using only exp / reci / +add / sub / mul (all of which are PLENA FP scalar primitives): + + tanh(u) = 1 - 2 / (exp(2u) + 1) + +The five scalar constants (0.5, 1.0, 2.0, sqrt(2/pi), 0.044715) are +embedded directly as ``T.float16(...)`` literals; the +``hoist_float_constants`` pre-pass synthesises one 1-slot +``global.fpram`` buffer per unique value, and ``test_helper`` auto- +preloads the values from the buffer-addrs dump. +""" + +import math + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_gelu_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int = 8, + num_s_blocks: int = 2, + batch: int = 1, + o_head_count: int | None = None, + o_head_offset: int = 0, +): + """GELU (tanh approximation). + + ``o_head_count`` / ``o_head_offset`` let GELU write into a + head-slice of a WIDER output tensor — the single-stream-block chain + uses this to drop GELU(mlp) into the right half of + ``concat([attn, mlp])``. + """ + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError(f"gelu_min requires rows == MLEN ({MLEN}), got {rows}") + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN ({MLEN}); got hlen={hlen}") + hardware_lane_count = MLEN // hlen + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of MLEN/hlen={hardware_lane_count}; " + f"got {head_count}" + ) + if num_s_blocks < 1: + raise ValueError(f"num_s_blocks must be >= 1, got {num_s_blocks}") + if o_head_count is None: + o_head_count = head_count + if o_head_count < head_count: + raise ValueError( + f"o_head_count ({o_head_count}) must be >= head_count ({head_count})" + ) + if not (0 <= o_head_offset <= o_head_count - head_count): + raise ValueError( + f"o_head_offset ({o_head_offset}) + head_count ({head_count}) " + f"must fit within o_head_count ({o_head_count})" + ) + + seq_len = num_s_blocks * rows + # GELU tanh-approximation constants. Embedded directly as + # ``T.float16(...)`` in the kernel body; the + # ``hoist_float_constants`` pre-pass turns each unique value into a + # 1-slot global.fpram buffer at compile time. + sqrt_2_over_pi_val = math.sqrt(2.0 / math.pi) + + @T.prim_func + def gelu_min( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, o_head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + + # Per-row FPRAM scratch for the input and output of GELU. + X_FP = T.alloc_fragment((hlen,), "float16") + Y_FP = T.alloc_fragment((hlen,), "float16") + + # The five GELU scalar constants (0.5, 1.0, 2.0, sqrt(2/pi), + # 0.044715) are inlined as ``T.float16(...)`` below. The + # ``hoist_float_constants`` pre-pass auto-allocates a 1-slot + # global.fpram buffer per unique value. + + # Intermediate FPRAM scratch fragments. Allocating them + # explicitly (instead of letting lower_compound_fp_stores + # auto-allocate ``__tmp_fp_*``) keeps the chain readable and + # gives every subop a foldable ``dst[i] = a[i] op b[i]`` + # shape — no FloatImm leaves anywhere in the RHS. + x3 = T.alloc_fragment((hlen,), "float16") # x*x*x + cx3 = T.alloc_fragment((hlen,), "float16") # 0.044715 * x^3 + inner_raw = T.alloc_fragment((hlen,), "float16") # x + cx3 + u = T.alloc_fragment((hlen,), "float16") # sqrt(2/pi) * inner_raw + two_u = T.alloc_fragment((hlen,), "float16") # 2 * u + e2u = T.alloc_fragment((hlen,), "float16") # exp(2u) + denom = T.alloc_fragment((hlen,), "float16") # exp(2u) + 1 + reci_d = T.alloc_fragment((hlen,), "float16") # 1 / denom + two_recid = T.alloc_fragment((hlen,), "float16") # 2 * reci_d + tanh_u = T.alloc_fragment((hlen,), "float16") # 1 - two_recid + one_p = T.alloc_fragment((hlen,), "float16") # 1 + tanh_u + hx = T.alloc_fragment((hlen,), "float16") # 0.5 * x + + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + + for row in T.serial(rows): + T.copy(X_sh[row, 0], X_FP) + + for i in T.unroll(hlen): + # u = sqrt(2/pi) * (x + 0.044715 * x^3) + x3[i] = X_FP[i] * X_FP[i] * X_FP[i] + cx3[i] = T.float16(0.044715) * x3[i] + inner_raw[i] = X_FP[i] + cx3[i] + u[i] = T.float16(sqrt_2_over_pi_val) * inner_raw[i] + + # tanh(u) = 1 - 2 * (1 / (exp(2u) + 1)) + two_u[i] = T.float16(2.0) * u[i] + e2u[i] = T.exp(two_u[i]) + denom[i] = e2u[i] + T.float16(1.0) + # ``1.0 / x`` is the only div form fold recognises + # (it picks the FloatImm-1 literal numerator and + # lowers to ``fp_reci_at``). A ``BufferLoad / BufferLoad`` + # would fall through to fold's binop arm and fail. + reci_d[i] = T.float16(1.0) / denom[i] + two_recid[i] = T.float16(2.0) * reci_d[i] + tanh_u[i] = T.float16(1.0) - two_recid[i] + + # GELU(x) = 0.5 * x * (1 + tanh(u)) + one_p[i] = T.float16(1.0) + tanh_u[i] + hx[i] = T.float16(0.5) * X_FP[i] + Y_FP[i] = hx[i] * one_p[i] + + T.copy(Y_FP, Y_sh[row, 0]) + + # Destination head shifted by o_head_offset so GELU's output + # can land in a head-slice of a wider tensor (concat). + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, + by + o_head_offset, 0:hlen], + ) + + lowered = gelu_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "BATCH": batch, + "NUM_S_BLOCKS": num_s_blocks, + "HARDWARE_LANE_COUNT": hardware_lane_count, + } + return lowered, constants + + +__all__ = ["make_gelu_min"] diff --git a/tilelang_tvm_compiler/kernels/layernorm_min.py b/tilelang_tvm_compiler/kernels/layernorm_min.py new file mode 100644 index 0000000..f66bf89 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/layernorm_min.py @@ -0,0 +1,192 @@ +"""LayerNorm-min kernel — ``out = (x - mean(x)) * rsqrt(var(x) + eps) * scale + bias``. + +Differs from ``rmsnorm_min`` in two key ways: + + 1. **Reduction span**: LayerNorm normalises over the entire + ``hidden_size = H*D`` dimension of a ``(B, S, H*D)`` tensor (no + per-head split). HBM layout is ``(1, S, 1, H*D)`` and D > MLEN + is the common case (``hidden_size`` is typically 128, 256, …). + The reduce / row_*_fp / tile_* emitters handle the multi-d_tile + unroll themselves via the 7D ``tile_layout``. + + 2. **Extra subtraction + bias**: LayerNorm subtracts the mean before + squaring (RMSNorm doesn't), and adds a learnable bias at the end + (RMSNorm's affine has scale only). + +Decomposed into PLENA single-op stores: + + mean_sum[i] = sum_j(X[i,j]) row_reduce_sum_at + mu[i] = mean_sum[i] * INV_N[i] fp_mul + XC[i,j] = X[i,j] - mu[i] row_sub_fp_at + SQ[i,j] = XC[i,j] * XC[i,j] tile_mul + var_sum[i] = sum_j(SQ[i,j]) row_reduce_sum_at + var[i] = var_sum[i] * INV_N[i] fp_mul + var_eps[i] = var[i] + EPS[i] fp_add + norm[i] = sqrt(var_eps[i]) fp_sqrt + inv[i] = 1 / norm[i] fp_reci + Y[i,j] = XC[i,j] * SCALE[i,j] tile_mul (host broadcasts scale) + Y[i,j] *= inv[i] row_mul_fp_at (in-place) + Y[i,j] += BIAS[i,j] tile_add (host broadcasts bias) + +Like ``rmsnorm_min``: + * ``INV_N = 1/hidden_size`` and ``EPS`` are FPRAM-preloaded scalars. + * The accumulating ``V_RED_SUM`` is seeded from a preloaded zero + fragment (``SS_INIT``) before each reduce — same pattern flash + attention uses for ``L_INIT``. +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_layernorm_min( + *, + rows: int | None = None, + hidden_size: int = 128, + num_s_blocks: int = 2, + batch: int = 1, + eps: float = 1e-6, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError( + f"layernorm_min requires rows == MLEN ({MLEN}), got {rows}" + ) + if hidden_size % MLEN != 0: + raise ValueError( + f"hidden_size must be a multiple of MLEN ({MLEN}); " + f"got hidden_size={hidden_size}" + ) + if num_s_blocks < 1: + raise ValueError(f"num_s_blocks must be >= 1, got {num_s_blocks}") + + seq_len = num_s_blocks * rows + H = hidden_size + # INV_N (= 1/hidden_size) and eps are inlined as T.float16(...) + # literals; auto-hoisted into 1-slot global.fpram buffers. + inv_n_val = 1.0 / hidden_size + + @T.prim_func + def layernorm_min( + X_hbm: T.Tensor((batch, seq_len, 1, H), "float16"), + SCALE_hbm: T.Tensor((batch, seq_len, 1, H), "float16"), + BIAS_hbm: T.Tensor((batch, seq_len, 1, H), "float16"), + Y_hbm: T.Tensor((batch, seq_len, 1, H), "float16"), + ): + with T.Kernel(num_s_blocks, threads=128) as s_block: + # HBM <-> on-chip staging (shared). No head packing (H=1 in + # the BSHD layout); ``hidden_size > MLEN`` triggers the 7D + # tile_layout with d_tiles = hidden_size/MLEN. + X_sh = T.alloc_shared((rows, H), "float16") + SCALE_sh = T.alloc_shared((rows, H), "float16") + BIAS_sh = T.alloc_shared((rows, H), "float16") + Y_sh = T.alloc_shared((rows, H), "float16") + + X_loc = T.alloc_fragment((rows, H), "float16") + SC_loc = T.alloc_fragment((rows, H), "float16") + BI_loc = T.alloc_fragment((rows, H), "float16") + SQ_loc = T.alloc_fragment((rows, H), "float16") + Y_loc = T.alloc_fragment((rows, H), "float16") + + # Per-row FP scratch (rank-1 -> FPRAM scalar slots). + MEAN_SUM = T.alloc_fragment((rows,), "float16") + MU = T.alloc_fragment((rows,), "float16") + VAR_SUM = T.alloc_fragment((rows,), "float16") + VAR = T.alloc_fragment((rows,), "float16") + VAR_EPS = T.alloc_fragment((rows,), "float16") + NORM = T.alloc_fragment((rows,), "float16") + INV = T.alloc_fragment((rows,), "float16") + + # INV_N (1/hidden_size) and eps are inlined as + # T.float16(...) literals below; the zero seed is + # T.float16(0) which takes fold's zero-fill path. + + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, 0, 0:H], + X_sh, + ) + T.copy( + SCALE_hbm[0, s_block * rows : (s_block + 1) * rows, 0, 0:H], + SCALE_sh, + ) + T.copy( + BIAS_hbm[0, s_block * rows : (s_block + 1) * rows, 0, 0:H], + BIAS_sh, + ) + T.copy(X_sh, X_loc) + T.copy(SCALE_sh, SC_loc) + T.copy(BIAS_sh, BI_loc) + + # Seed mean accumulator from zero before reduce — + # V_RED_SUM accumulates into its FPRAM slot. + for row in T.serial(rows): + MEAN_SUM[row] = T.float16(0) + + # mean_sum[i] = sum_j(X[i, j]) + T.reduce_sum(X_loc, MEAN_SUM, dim=1) + + # mu[i] = mean_sum[i] * INV_N[i] + for row in T.serial(rows): + MU[row] = MEAN_SUM[row] * T.float16(inv_n_val) + + # XC = X - mu (in-place on X_loc to avoid extra fragment). + for row in T.serial(rows): + for col in T.Parallel(H): + X_loc[row, col] = X_loc[row, col] - MU[row] + + # SQ = XC * XC + for row in T.serial(rows): + for col in T.Parallel(H): + SQ_loc[row, col] = X_loc[row, col] * X_loc[row, col] + VAR_SUM[row] = T.float16(0) + + T.reduce_sum(SQ_loc, VAR_SUM, dim=1) + + for row in T.serial(rows): + VAR[row] = VAR_SUM[row] * T.float16(inv_n_val) + VAR_EPS[row] = VAR[row] + T.float16(eps) + NORM[row] = T.sqrt(VAR_EPS[row]) + INV[row] = T.float16(1.0) / NORM[row] + + # Y = XC * scale (host-broadcast SCALE into (rows, H)). + # Write directly into Y_loc so the in-place row_mul_fp_at + # below has dst == src (no cross-lane pollution; not strictly + # needed here because there is no packed-head mask in this + # path, but it keeps the kernel parallel to rmsnorm_min). + for row in T.serial(rows): + for col in T.Parallel(H): + Y_loc[row, col] = X_loc[row, col] * SC_loc[row, col] + + # Y *= inv (row_mul_fp_at; D > MLEN unrolls inside the emitter). + for row in T.serial(rows): + for col in T.Parallel(H): + Y_loc[row, col] = Y_loc[row, col] * INV[row] + + # Y += bias (host-broadcast BIAS into (rows, H)). + for row in T.serial(rows): + for col in T.Parallel(H): + Y_loc[row, col] = Y_loc[row, col] + BI_loc[row, col] + + T.copy(Y_loc, Y_sh) + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, 0, 0:H], + ) + + lowered = layernorm_min + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HIDDEN_SIZE": hidden_size, + "BATCH": batch, + "NUM_S_BLOCKS": num_s_blocks, + } + return lowered, constants + + +__all__ = ["make_layernorm_min"] diff --git a/tilelang_tvm_compiler/kernels/linear_min.py b/tilelang_tvm_compiler/kernels/linear_min.py new file mode 100644 index 0000000..6609298 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/linear_min.py @@ -0,0 +1,254 @@ +"""Linear-min kernel — multi-tile GEMM (+ optional bias). + +PLENA-flavored counterpart to GPU-tilelang's ``_build_gemm_kernel`` / +``_build_gemm_bias_kernel``. Shape constraints: M, N, K must each be a +multiple of MLEN (= 64) — PLENA's matmul tile granularity. The natural +mapping mirrors the GPU version's (block_M, block_N, block_K) where each +block equals MLEN: + + block_M = block_N = block_K = MLEN = 64 + m_blocks = M / MLEN + n_blocks = N / MLEN + k_blocks = K / MLEN + +What it computes (output per (m_block, n_block) tile): + + C[m_block, n_block] = sum_{k_block} A[m_block, k_block] + @ B[n_block, k_block]^T + + bias[m_block, n_block] (optional) + +B is K-inner ``(N, K)`` to match ``nn.Linear.weight``; the gemm uses +``transpose_B=True``. + +Lowering path (kind="overwrite" default, NOT btmm): + PLENA's BTMM is the head-fused Q@K^T path with packed-head layout + (btmm_hlen < MLEN). Plain MLEN×MLEN×MLEN matmul without head packing + lowers through ``plena.matmul`` (``M_MM_WO`` drain) — the same path + flash_attention's second gemm (P @ V) takes. So no ``T.attr(KIND, + "btmm")`` wrap here. + + ``transpose_B=True`` is honoured by the lowering: emit_matmul_general + swaps ``M_MM`` for ``M_TMM`` (which transposes the (mlen, mlen) MRAM + tile inside the systolic array), so the kernel takes ``B`` in + ``(N, K)`` row-major layout — the standard nn.Linear weight format, + no host-side transpose required. + +K accumulation: + ``kind="add"`` is reserved but not implemented yet (see + gemm_macros.py docstring). For now do the documented workaround + manually: + + T.gemm(A_blk, B_blk, SCR_loc) # overwrite into scratch + for r, c: C_loc[r, c] += SCR_loc[r, c] # fuse_elementwise → v_add + + Before the K-loop starts, ``C_loc`` is zeroed by an inline parallel + loop so the first k_block's add behaves like a clear+write. + +Bias (optional): + ``bias`` broadcasts along N (cols) in nn.Linear, which is the wrong + axis for ``row_*_fp_at``'s "one FP scalar per row" semantics. So the + testbench host-broadcasts ``bias[N]`` → ``(M, N)`` and the kernel + consumes it as a full-shape VRAM tile (one tile_add per (m, n) block). + Same trick rmsnorm_min uses for its ``(hlen,)`` scale weight. +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_linear_min( + *, + m_blocks: int = 1, + n_blocks: int = 1, + k_blocks: int = 1, + with_bias: bool = False, + c_wide_n: int | None = None, + c_col_offset: int = 0, +): + """``c_wide_n`` / ``c_col_offset`` let the kernel write its result + into a column-slice of a WIDER output tensor — used by the + single-stream-block chain to drop the MLP-in projection straight + into the left part of ``concat([mlp, attn])`` with no separate + concat kernel (mirrors flash_attention_min's ``o_head_offset``). + + ``C_hbm`` is declared with ``c_wide_n`` columns (default: ``N``, the + standalone-kernel behaviour); each tile writes columns + ``bx*MLEN + c_col_offset``. ``c_col_offset`` must be a multiple of + MLEN so the 64-wide tile lands on an aligned column block. + """ + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if m_blocks < 1 or n_blocks < 1 or k_blocks < 1: + raise ValueError( + f"m_blocks/n_blocks/k_blocks must be >= 1; " + f"got m={m_blocks}, n={n_blocks}, k={k_blocks}" + ) + M = m_blocks * MLEN + N = n_blocks * MLEN + K = k_blocks * MLEN + + if c_wide_n is None: + c_wide_n = N + C_WIDE_N = c_wide_n + if C_WIDE_N < N: + raise ValueError( + f"c_wide_n ({C_WIDE_N}) must be >= N ({N})" + ) + if c_col_offset % MLEN != 0: + raise ValueError( + f"c_col_offset ({c_col_offset}) must be a multiple of MLEN " + f"({MLEN})" + ) + if c_col_offset + N > C_WIDE_N: + raise ValueError( + f"c_col_offset ({c_col_offset}) + N ({N}) must fit within " + f"c_wide_n ({C_WIDE_N})" + ) + + # PLENA's DMA-slice lowering expects HBM tensors to carry the full + # 4D BSHD shape (batch, seq, head, hlen). Linear has no real head + # axis, so we degenerate head=1 and lay (M, K) / (N, K) / (M, N) + # along the (seq, hlen) pair: A_hbm[1, M, 1, K], B_hbm[1, N, 1, K], + # C_hbm[1, M, 1, N], BIAS_hbm[1, M, 1, N]. + if with_bias: + @T.prim_func + def linear_min( + A_hbm: T.Tensor((1, M, 1, K), "float16"), + B_hbm: T.Tensor((1, N, 1, K), "float16"), + BIAS_hbm: T.Tensor((1, M, 1, N), "float16"), + C_hbm: T.Tensor((1, M, 1, C_WIDE_N), "float16"), + ): + # Grid: one program per (n_block, m_block) tile — same axis + # order tilelang_kernels/linear.py uses (bx along N, by along + # M) so cache-line / coalescing intuition carries over. + with T.Kernel(n_blocks, m_blocks, threads=128) as (bx, by): + A_sh = T.alloc_shared((MLEN, MLEN), "float16") + B_sh = T.alloc_shared((MLEN, MLEN), "float16") + BIAS_sh = T.alloc_shared((MLEN, MLEN), "float16") + C_sh = T.alloc_shared((MLEN, MLEN), "float16") + + C_loc = T.alloc_fragment((MLEN, MLEN), "float16") + SCR_loc = T.alloc_fragment((MLEN, MLEN), "float16") + + # Zero C_loc so the first K iteration's add behaves as + # clear+write. + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = T.float16(0) + + for k_block in T.serial(k_blocks): + T.copy( + A_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + k_block * MLEN : (k_block + 1) * MLEN], + A_sh, + ) + # B is (N, K) row-major — same convention as + # nn.Linear.weight. The lowering issues M_TMM when + # transpose_B is set, which transposes the (mlen, + # mlen) MRAM tile on the fly inside the systolic + # array. The slice walks N along the seq axis and K + # along the hlen axis. + T.copy( + B_hbm[0, + bx * MLEN : (bx + 1) * MLEN, + 0, + k_block * MLEN : (k_block + 1) * MLEN], + B_sh, + ) + + T.gemm(A_sh, B_sh, SCR_loc, transpose_B=True) + + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = C_loc[row, col] + SCR_loc[row, col] + + T.copy( + BIAS_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + bx * MLEN : (bx + 1) * MLEN], + BIAS_sh, + ) + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = C_loc[row, col] + BIAS_sh[row, col] + + T.copy(C_loc, C_sh) + # Write into a column-slice of the (possibly wider) + # C_hbm: shift the col block by c_col_offset. + T.copy( + C_sh, + C_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + bx * MLEN + c_col_offset + : bx * MLEN + c_col_offset + MLEN], + ) + else: + @T.prim_func + def linear_min( + A_hbm: T.Tensor((1, M, 1, K), "float16"), + B_hbm: T.Tensor((1, N, 1, K), "float16"), + C_hbm: T.Tensor((1, M, 1, C_WIDE_N), "float16"), + ): + with T.Kernel(n_blocks, m_blocks, threads=128) as (bx, by): + A_sh = T.alloc_shared((MLEN, MLEN), "float16") + B_sh = T.alloc_shared((MLEN, MLEN), "float16") + C_sh = T.alloc_shared((MLEN, MLEN), "float16") + + C_loc = T.alloc_fragment((MLEN, MLEN), "float16") + SCR_loc = T.alloc_fragment((MLEN, MLEN), "float16") + + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = T.float16(0) + + for k_block in T.serial(k_blocks): + T.copy( + A_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + k_block * MLEN : (k_block + 1) * MLEN], + A_sh, + ) + T.copy( + B_hbm[0, + bx * MLEN : (bx + 1) * MLEN, + 0, + k_block * MLEN : (k_block + 1) * MLEN], + B_sh, + ) + + T.gemm(A_sh, B_sh, SCR_loc, transpose_B=True) + + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = C_loc[row, col] + SCR_loc[row, col] + + T.copy(C_loc, C_sh) + # Write into a column-slice of the (possibly wider) + # C_hbm: shift the col block by c_col_offset. + T.copy( + C_sh, + C_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + bx * MLEN + c_col_offset + : bx * MLEN + c_col_offset + MLEN], + ) + + lowered = linear_min + constants = { + "M": M, "N": N, "K": K, "MLEN": MLEN, + "M_BLOCKS": m_blocks, "N_BLOCKS": n_blocks, "K_BLOCKS": k_blocks, + "WITH_BIAS": with_bias, + "C_WIDE_N": C_WIDE_N, "C_COL_OFFSET": c_col_offset, + } + return lowered, constants + + +__all__ = ["make_linear_min"] diff --git a/tilelang_tvm_compiler/kernels/linear_min_no_transpose.py b/tilelang_tvm_compiler/kernels/linear_min_no_transpose.py new file mode 100644 index 0000000..ae0adcf --- /dev/null +++ b/tilelang_tvm_compiler/kernels/linear_min_no_transpose.py @@ -0,0 +1,175 @@ +"""Linear-min kernel — multi-tile GEMM (no-transpose-B variant). + +Same shape semantics as ``linear_min`` but ``B`` is laid out as +``(K, N)`` row-major rather than ``(N, K)`` — i.e. the host has already +transposed the weight. The lowering uses plain ``M_MM`` (no ``M_TMM``), +exercising the matmul path that does NOT need the transpose flag. + +What it computes (output per (m_block, n_block) tile): + + C[m_block, n_block] = sum_{k_block} A[m_block, k_block] + @ B[k_block, n_block] + + bias[m_block, n_block] (optional) + +Differences vs ``linear_min`` (``transpose_B=True``): + * ``B_hbm`` shape: ``(1, K, 1, N)`` (was ``(1, N, 1, K)``) + * Slice over B: ``[k_block * MLEN : ., bx * MLEN : .]`` + (was ``[bx * MLEN : ., k_block * MLEN : .]``) + * ``T.gemm`` call: no ``transpose_B`` argument (default False) + * Inner ISA: ``M_MM`` instead of ``M_TMM``; per-oc B step + is ``blen`` (cols of (K, N)) instead of + ``blen * mlen`` (rows of (N, K)). + +Everything else (K-acc via SCR_loc + tile_add, grid layout, bias as +host-broadcast (M, N) tile) is identical to ``linear_min``. +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_linear_min_no_transpose( + *, + m_blocks: int = 1, + n_blocks: int = 1, + k_blocks: int = 1, + with_bias: bool = False, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if m_blocks < 1 or n_blocks < 1 or k_blocks < 1: + raise ValueError( + f"m_blocks/n_blocks/k_blocks must be >= 1; " + f"got m={m_blocks}, n={n_blocks}, k={k_blocks}" + ) + M = m_blocks * MLEN + N = n_blocks * MLEN + K = k_blocks * MLEN + + if with_bias: + @T.prim_func + def linear_min_no_transpose( + A_hbm: T.Tensor((1, M, 1, K), "float16"), + B_hbm: T.Tensor((1, K, 1, N), "float16"), + BIAS_hbm: T.Tensor((1, M, 1, N), "float16"), + C_hbm: T.Tensor((1, M, 1, N), "float16"), + ): + with T.Kernel(n_blocks, m_blocks, threads=128) as (bx, by): + A_sh = T.alloc_shared((MLEN, MLEN), "float16") + B_sh = T.alloc_shared((MLEN, MLEN), "float16") + BIAS_sh = T.alloc_shared((MLEN, MLEN), "float16") + C_sh = T.alloc_shared((MLEN, MLEN), "float16") + + C_loc = T.alloc_fragment((MLEN, MLEN), "float16") + SCR_loc = T.alloc_fragment((MLEN, MLEN), "float16") + + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = T.float16(0) + + for k_block in T.serial(k_blocks): + T.copy( + A_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + k_block * MLEN : (k_block + 1) * MLEN], + A_sh, + ) + # B is (K, N) row-major: walk K along the seq axis and + # N along the hlen axis. No transpose needed in the + # matmul itself. + T.copy( + B_hbm[0, + k_block * MLEN : (k_block + 1) * MLEN, + 0, + bx * MLEN : (bx + 1) * MLEN], + B_sh, + ) + + T.gemm(A_sh, B_sh, SCR_loc) + + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = C_loc[row, col] + SCR_loc[row, col] + + T.copy( + BIAS_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + bx * MLEN : (bx + 1) * MLEN], + BIAS_sh, + ) + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = C_loc[row, col] + BIAS_sh[row, col] + + T.copy(C_loc, C_sh) + T.copy( + C_sh, + C_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + bx * MLEN : (bx + 1) * MLEN], + ) + else: + @T.prim_func + def linear_min_no_transpose( + A_hbm: T.Tensor((1, M, 1, K), "float16"), + B_hbm: T.Tensor((1, K, 1, N), "float16"), + C_hbm: T.Tensor((1, M, 1, N), "float16"), + ): + with T.Kernel(n_blocks, m_blocks, threads=128) as (bx, by): + A_sh = T.alloc_shared((MLEN, MLEN), "float16") + B_sh = T.alloc_shared((MLEN, MLEN), "float16") + C_sh = T.alloc_shared((MLEN, MLEN), "float16") + + C_loc = T.alloc_fragment((MLEN, MLEN), "float16") + SCR_loc = T.alloc_fragment((MLEN, MLEN), "float16") + + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = T.float16(0) + + for k_block in T.serial(k_blocks): + T.copy( + A_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + k_block * MLEN : (k_block + 1) * MLEN], + A_sh, + ) + T.copy( + B_hbm[0, + k_block * MLEN : (k_block + 1) * MLEN, + 0, + bx * MLEN : (bx + 1) * MLEN], + B_sh, + ) + + T.gemm(A_sh, B_sh, SCR_loc) + + for row in T.serial(MLEN): + for col in T.Parallel(MLEN): + C_loc[row, col] = C_loc[row, col] + SCR_loc[row, col] + + T.copy(C_loc, C_sh) + T.copy( + C_sh, + C_hbm[0, + by * MLEN : (by + 1) * MLEN, + 0, + bx * MLEN : (bx + 1) * MLEN], + ) + + lowered = linear_min_no_transpose + constants = { + "M": M, "N": N, "K": K, "MLEN": MLEN, + "M_BLOCKS": m_blocks, "N_BLOCKS": n_blocks, "K_BLOCKS": k_blocks, + "WITH_BIAS": with_bias, + } + return lowered, constants + + +__all__ = ["make_linear_min_no_transpose"] diff --git a/tilelang_tvm_compiler/kernels/modulate_min.py b/tilelang_tvm_compiler/kernels/modulate_min.py new file mode 100644 index 0000000..0d8ec67 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/modulate_min.py @@ -0,0 +1,108 @@ +"""Modulate-min kernel — adaLN ``out = (1 + scale) * x + shift``. + +Designed for the tile_mul / tile_add VRAM elementwise path (no FPRAM +involvement). The literal ``1 + scale`` term is hoisted out of the +kernel: the testbench passes ``scale_plus_one = scale + 1`` directly, +so the kernel reduces to one tile_mul + one tile_add: + + tmp = scale_plus_one * x + out = tmp + shift + +Both stores are single-op binops over same-shape VRAM tiles — exactly +what fold + plena.tile_* expects. + +Layout: HBM -> VRAM tiles -> tile_mul -> tile_add -> VRAM -> HBM. +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_modulate_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int = 8, + num_s_blocks: int = 2, + batch: int = 1, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError(f"modulate_min requires rows == MLEN ({MLEN}), got {rows}") + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN ({MLEN}); got hlen={hlen}") + hardware_lane_count = MLEN // hlen + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of MLEN/hlen={hardware_lane_count}; " + f"got {head_count}" + ) + if num_s_blocks < 1: + raise ValueError(f"num_s_blocks must be >= 1, got {num_s_blocks}") + + seq_len = num_s_blocks * rows + + @T.prim_func + def modulate_min( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + SCALE1P_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + SHIFT_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + SCALE1P_sh = T.alloc_shared((rows, hlen), "float16") + SHIFT_sh = T.alloc_shared((rows, hlen), "float16") + TMP_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + T.copy( + SCALE1P_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + SCALE1P_sh, + ) + T.copy( + SHIFT_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + SHIFT_sh, + ) + + # tmp = (1 + scale) * x + for row in T.serial(rows): + for col in T.Parallel(hlen): + TMP_sh[row, col] = SCALE1P_sh[row, col] * X_sh[row, col] + + # out = tmp + shift + for row in T.serial(rows): + for col in T.Parallel(hlen): + Y_sh[row, col] = TMP_sh[row, col] + SHIFT_sh[row, col] + + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + ) + + lowered = modulate_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "BATCH": batch, + "NUM_S_BLOCKS": num_s_blocks, + "HARDWARE_LANE_COUNT": hardware_lane_count, + } + return lowered, constants + + +__all__ = ["make_modulate_min"] diff --git a/tilelang_tvm_compiler/kernels/online_softmax_min.py b/tilelang_tvm_compiler/kernels/online_softmax_min.py new file mode 100644 index 0000000..d7eef53 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/online_softmax_min.py @@ -0,0 +1,166 @@ +"""Minimal online-softmax kernel over one Score tile (HBM round-trip). + + m_curr = max(score_row) + m_new = max(m_old, m_curr) + m_res = exp(m_old - m_new) + score = exp(score - m_new) + l_new = l_old * m_res + sum(score) + m_old/l_old updated in-place + +FPRAM layout: one flat region starting at FPRAM_USER_BASE; each slot +takes lane_count*rows elements; addresses passed directly to the FP / +row `_at` intrinsics. +""" + +import tvm +from tvm.script import tir as T + +from ..address_alloc import FPRAM_USER_BASE +from ..plena_settings import load_sizes as _load_sizes + + +_SLOTS = ("M_OLD", "M_CURR", "M_RES", "L_OLD", "L_NEW", "P_SUM") + + +def _slot_bases(fp_state_elems: int) -> dict[str, int]: + return {name: FPRAM_USER_BASE + i * fp_state_elems for i, name in enumerate(_SLOTS)} + + +def make_online_softmax_hbm( + *, + rows: int | None = None, + hlen: int | None = None, + lane_count: int = 4, + active_lane: int = 0, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError(f"online_softmax_hbm currently requires rows == MLEN ({MLEN}), got {rows}") + if hlen <= 0 or hlen > MLEN or MLEN % hlen != 0: + raise ValueError(f"hlen must be a positive divisor of MLEN={MLEN}, got {hlen}") + if lane_count * hlen != MLEN: + raise ValueError( + f"lane_count * hlen must equal MLEN ({lane_count} * {hlen} == {MLEN})" + ) + if not (0 <= active_lane < lane_count): + raise ValueError(f"active_lane must be in [0, {lane_count}), got {active_lane}") + + grouped = hlen < MLEN + mask_val = 1 << active_lane + SCORE_SHAPE = (1, rows, lane_count, hlen) + + fp_state_elems = lane_count * rows + bases = _slot_bases(fp_state_elems) + M_OLD = bases["M_OLD"] + M_CURR = bases["M_CURR"] + M_RES = bases["M_RES"] + L_OLD = bases["L_OLD"] + L_NEW = bases["L_NEW"] + P_SUM = bases["P_SUM"] + + @T.prim_func + def online_softmax_hbm( + Score_hbm: T.Buffer(SCORE_SHAPE, "float16"), + Score_out_hbm: T.Buffer(SCORE_SHAPE, "float16"), + ): + Score_v = T.alloc_buffer(SCORE_SHAPE, "float16", scope="vram") + + T.evaluate(T.call_extern( + "handle", "plena.dma_h2v_slice", + Score_hbm.data, Score_v.data, + 4, + 0, 0, 0, 0, + 1, rows, lane_count, hlen, + )) + for lane in T.serial(lane_count): + for row in T.serial(rows): + T.evaluate(T.call_extern( + "handle", "plena.fp_copy_at", + M_OLD + lane * rows + row, M_CURR + lane * rows + row, + )) + T.evaluate(T.call_extern( + "handle", "plena.row_reduce_max_at", + Score_v.data, M_CURR + lane * rows + row, row, lane, + )) + T.evaluate(T.call_extern( + "handle", "plena.fp_sub_at", + M_OLD + lane * rows + row, M_CURR + lane * rows + row, M_RES + lane * rows + row, + )) + T.evaluate(T.call_extern( + "handle", "plena.fp_exp_at", + M_RES + lane * rows + row, M_RES + lane * rows + row, + )) + T.evaluate(T.call_extern( + "handle", "plena.row_sub_fp_at", + Score_v.data, M_CURR + lane * rows + row, Score_v.data, row, lane, + )) + T.evaluate(T.call_extern( + "handle", "plena.row_exp_at", + Score_v.data, Score_v.data, row, lane, + )) + T.evaluate(T.call_extern( + "handle", "plena.fp_sub_at", + P_SUM + lane * rows + row, P_SUM + lane * rows + row, P_SUM + lane * rows + row, + )) + T.evaluate(T.call_extern( + "handle", "plena.row_reduce_sum_at", + Score_v.data, P_SUM + lane * rows + row, row, lane, + )) + T.evaluate(T.call_extern( + "handle", "plena.fp_mul_at", + L_OLD + lane * rows + row, M_RES + lane * rows + row, L_NEW + lane * rows + row, + )) + T.evaluate(T.call_extern( + "handle", "plena.fp_add_at", + L_NEW + lane * rows + row, P_SUM + lane * rows + row, L_NEW + lane * rows + row, + )) + T.evaluate(T.call_extern( + "handle", "plena.fp_copy_at", + M_CURR + lane * rows + row, M_OLD + lane * rows + row, + )) + T.evaluate(T.call_extern( + "handle", "plena.fp_copy_at", + L_NEW + lane * rows + row, L_OLD + lane * rows + row, + )) + T.evaluate(T.call_extern( + "handle", "plena.dma_v2h_slice", + Score_v.data, Score_out_hbm.data, + 4, + 0, 0, 0, 0, + 1, rows, lane_count, hlen, + )) + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "LANE_COUNT": lane_count, + "ACTIVE_LANE": active_lane, + "MASK_VAL": mask_val, + "GROUPED": grouped, + "FPRAM_USER_BASE": FPRAM_USER_BASE, + "FP_STATE_ELEMS": fp_state_elems, + "M_OLD_ADDR": M_OLD, + "M_CURR_ADDR": M_CURR, + "M_RES_ADDR": M_RES, + "L_OLD_ADDR": L_OLD, + "L_NEW_ADDR": L_NEW, + "P_SUM_ADDR": P_SUM, + } + return online_softmax_hbm, constants + + +def build_hbm_module( + *, rows: int | None = None, hlen: int | None = None, + lane_count: int = 4, active_lane: int = 0, +) -> tvm.IRModule: + func, _ = make_online_softmax_hbm( + rows=rows, hlen=hlen, lane_count=lane_count, active_lane=active_lane, + ) + return tvm.IRModule({"online_softmax_hbm": func}) diff --git a/tilelang_tvm_compiler/kernels/residual_gate_min.py b/tilelang_tvm_compiler/kernels/residual_gate_min.py new file mode 100644 index 0000000..a803012 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/residual_gate_min.py @@ -0,0 +1,101 @@ +"""Residual-gate-min kernel — ``out = x + gate * y`` on VRAM tiles. + +Two single-op stores: one tile_mul (``tmp = gate * y``) then one +tile_add (``out = x + tmp``). All operands are same-shape VRAM +tiles — exactly the pattern fold + plena.tile_* expects. + +Layout: HBM -> VRAM tiles -> tile_mul -> tile_add -> VRAM -> HBM. +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_residual_gate_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int = 8, + num_s_blocks: int = 2, + batch: int = 1, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError(f"residual_gate_min requires rows == MLEN ({MLEN}), got {rows}") + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN ({MLEN}); got hlen={hlen}") + hardware_lane_count = MLEN // hlen + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of MLEN/hlen={hardware_lane_count}; " + f"got {head_count}" + ) + if num_s_blocks < 1: + raise ValueError(f"num_s_blocks must be >= 1, got {num_s_blocks}") + + seq_len = num_s_blocks * rows + + @T.prim_func + def residual_gate_min( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + GATE_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + OUT_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + GATE_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + TMP_sh = T.alloc_shared((rows, hlen), "float16") + OUT_sh = T.alloc_shared((rows, hlen), "float16") + + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + T.copy( + GATE_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + GATE_sh, + ) + T.copy( + Y_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + Y_sh, + ) + + # tmp = gate * y + for row in T.serial(rows): + for col in T.Parallel(hlen): + TMP_sh[row, col] = GATE_sh[row, col] * Y_sh[row, col] + + # out = x + tmp + for row in T.serial(rows): + for col in T.Parallel(hlen): + OUT_sh[row, col] = X_sh[row, col] + TMP_sh[row, col] + + T.copy( + OUT_sh, + OUT_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + ) + + lowered = residual_gate_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "BATCH": batch, + "NUM_S_BLOCKS": num_s_blocks, + "HARDWARE_LANE_COUNT": hardware_lane_count, + } + return lowered, constants + + +__all__ = ["make_residual_gate_min"] diff --git a/tilelang_tvm_compiler/kernels/rmsnorm_min.py b/tilelang_tvm_compiler/kernels/rmsnorm_min.py new file mode 100644 index 0000000..70c953c --- /dev/null +++ b/tilelang_tvm_compiler/kernels/rmsnorm_min.py @@ -0,0 +1,159 @@ +"""RMSNorm-min kernel — ``out = x * scale * rsqrt(mean(x^2) + eps)``. + +Decomposed into PLENA's single-op stores: + + sq[i,j] = x[i,j] * x[i,j] tile_mul + ss[i] = sum_j sq[i,j] row_reduce_sum_at + ss_n[i] = ss[i] * INV_N[i] fp_mul (INV_N preloaded = 1/N) + ss_eps[i] = ss_n[i] + EPS[i] fp_add (EPS preloaded) + norm[i] = sqrt(ss_eps[i]) fp_sqrt + inv[i] = 1 / norm[i] fp_reci (literal-1 numerator) + xs[i,j] = x[i,j] * SCALE[i,j] tile_mul (host broadcasts scale) + out[i,j] = xs[i,j] * inv[i] row_mul_fp_at + +The ``1/N`` and ``eps`` scalars come from preloaded FPRAM fragments +(mirroring gelu_min). The learnable ``scale`` weight ``(hlen,)`` is +pre-broadcast on the host into a ``(rows, hlen)`` tile so we can use +``tile_mul`` instead of a row-broadcast op (none exists for VRAM x VRAM +broadcast yet). +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_rmsnorm_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int = 8, + num_s_blocks: int = 2, + batch: int = 1, + eps: float = 1e-6, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError(f"rmsnorm_min requires rows == MLEN ({MLEN}), got {rows}") + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN ({MLEN}); got hlen={hlen}") + hardware_lane_count = MLEN // hlen + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of MLEN/hlen={hardware_lane_count}; " + f"got {head_count}" + ) + if num_s_blocks < 1: + raise ValueError(f"num_s_blocks must be >= 1, got {num_s_blocks}") + + seq_len = num_s_blocks * rows + # Constants inlined as T.float16(...) literals below; the + # hoist_float_constants pre-pass synthesises 1-slot global.fpram + # buffers for each unique value. + inv_n_val = 1.0 / hlen + + @T.prim_func + def rmsnorm_min( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + SCALE_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + # HBM ↔ on-chip staging (shared). + X_sh = T.alloc_shared((rows, hlen), "float16") + SCALE_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + + # Rank-2 VRAM work fragments — row_*_fp_at / tile_mul take + # their dst/src from these. (rank-1 fragments would land on + # FPRAM scalars, rank-2 stays on VRAM.) + X_loc = T.alloc_fragment((rows, hlen), "float16") + SC_loc = T.alloc_fragment((rows, hlen), "float16") + SQ_loc = T.alloc_fragment((rows, hlen), "float16") + Y_loc = T.alloc_fragment((rows, hlen), "float16") + + # Per-row FP scratch (rank-1 → FPRAM scalar slots). + SS = T.alloc_fragment((rows,), "float16") + SS_N = T.alloc_fragment((rows,), "float16") + SS_EPS = T.alloc_fragment((rows,), "float16") + NORM = T.alloc_fragment((rows,), "float16") + INV = T.alloc_fragment((rows,), "float16") + + # 1/hlen and eps are inlined as T.float16(...) literals + # below; auto-hoisted into 1-slot global.fpram buffers. + # The zero seed for SS (``SS = SS_INIT[row]``) is also + # inlined; T.float16(0) takes fold's zero-fill path and + # doesn't go through FPRAM at all. + + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + T.copy( + SCALE_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + SCALE_sh, + ) + T.copy(X_sh, X_loc) + T.copy(SCALE_sh, SC_loc) + + # sq = x * x, and seed SS from preloaded zero in the same + # row loop. V_RED_SUM accumulates into the FPRAM slot, so + # SS must be pre-zeroed before the reduce. Mirrors + # flash_attention_min's pattern of folding + # ``P_SUM[row] = L_INIT[row]`` into the row loop that + # precedes ``T.reduce_sum(..., clear=False)``. + for row in T.serial(rows): + for col in T.Parallel(hlen): + SQ_loc[row, col] = X_loc[row, col] * X_loc[row, col] + SS[row] = T.float16(0) + + # ss[i] = sum_j sq[i,j] + T.reduce_sum(SQ_loc, SS, dim=1) + + for row in T.serial(rows): + SS_N[row] = SS[row] * T.float16(inv_n_val) + SS_EPS[row] = SS_N[row] + T.float16(eps) + NORM[row] = T.sqrt(SS_EPS[row]) + # literal-1 numerator so fold picks the reci pattern + INV[row] = T.float16(1.0) / NORM[row] + + # y = x * scale (host has broadcast SCALE into (rows, hlen)). + # Write directly into Y_loc — packed-head row_mul_fp_at below + # requires dst == src (its unmasked heads otherwise overwrite + # dst with src verbatim, destroying cross-by_phase writes). + for row in T.serial(rows): + for col in T.Parallel(hlen): + Y_loc[row, col] = X_loc[row, col] * SC_loc[row, col] + + # out[i,j] = y[i,j] * inv[i] (in-place on Y_loc) + for row in T.serial(rows): + for col in T.Parallel(hlen): + Y_loc[row, col] = Y_loc[row, col] * INV[row] + + T.copy(Y_loc, Y_sh) + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + ) + + lowered = rmsnorm_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "BATCH": batch, + "NUM_S_BLOCKS": num_s_blocks, + "HARDWARE_LANE_COUNT": hardware_lane_count, + } + return lowered, constants + + +__all__ = ["make_rmsnorm_min"] diff --git a/tilelang_tvm_compiler/kernels/rope_min.py b/tilelang_tvm_compiler/kernels/rope_min.py new file mode 100644 index 0000000..038b833 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/rope_min.py @@ -0,0 +1,152 @@ +"""RoPE-min kernel — written in tilelang style. + +Multi-S × multi-head RoPE with branchless pair-swap and FPRAM-scalar +lowering. The pair-swap (output element d depends on input elements +d and d^1) is expressed as loop fission over half_dim — no +``T.if_then_else``, no per-element predicate. Each iteration writes both +the even (``2*i``) and odd (``2*i+1``) output slots in straight-line +code. + +Lowering path: + * ``T.copy`` for HBM↔VRAM tile transfer (existing). + * Per-row inner: ``T.copy(shared_2d[row, 0], frag_1d)`` lowers to a + contiguous DMA from VRAM into an FPRAM-resident 1D fragment (v2f). + Same fragments are reused across rows — one FPRAM region per slot, + bound once by ``allocate_group_memory``. + * FPRAM-only scalar FMA over half_dim pairs. ``X_FP[e]`` and + ``X_FP[o]`` are different elements of the same fragment addressed + as scalars; no cross-element hardware shuffle needed. + * ``T.copy(frag_1d, shared_2d[row, 0])`` lowers to f2v symmetrically. + +What this kernel needs from the compiler that isn't already in place: + * ``T.copy(shared, fragment)`` and ``T.copy(fragment, shared)`` lowered + to ``plena.dma_v2f`` / ``plena.dma_f2v`` (length=hlen, contiguous, + dynamic row index). + * Scalar FPRAM lvalue stores (``OUT_FP[e] = ...``) where ``e`` is an + affine expression of an enclosing loop var. flash_attention_min + already lvalue-stores into a 1D fragment with a loop-var index + (``M_RES[row] = ...``); this just reuses that with ``e = 2*i``. + +K-side RoPE has identical structure (XK→K, SIN↔NEG_SIN role swap). +Kept out of this minimal kernel. +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_rope_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int = 8, + half_dim: int = 8, + num_s_blocks: int = 2, + batch: int = 1, +): + full_dim = half_dim * 2 + if full_dim != hlen: + raise ValueError( + f"full_dim (= 2*half_dim = {full_dim}) must equal hlen ({hlen})" + ) + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError( + f"rope_min requires rows == MLEN ({MLEN}), got {rows}" + ) + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN ({MLEN}); got hlen={hlen}") + hardware_lane_count = MLEN // hlen + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of MLEN/hlen={hardware_lane_count}; " + f"got {head_count}" + ) + if num_s_blocks < 1: + raise ValueError(f"num_s_blocks must be >= 1, got {num_s_blocks}") + + seq_len = num_s_blocks * rows + + @T.prim_func + def rope_min( + XQ_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + COS_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + SIN_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + NEG_SIN_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Q_OUT_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + XQ_sh = T.alloc_shared((rows, hlen), "float16") + COS_sh = T.alloc_shared((rows, hlen), "float16") + SIN_sh = T.alloc_shared((rows, hlen), "float16") + NEG_SIN_sh = T.alloc_shared((rows, hlen), "float16") + Q_OUT_sh = T.alloc_shared((rows, hlen), "float16") + + # FPRAM scratch — one (hlen,) fragment per source. + # Allocated at kernel scope; same FPRAM offsets are reused + # across every row of every (s_block, head) tile. + X_FP = T.alloc_fragment((hlen,), "float16") + C_FP = T.alloc_fragment((hlen,), "float16") + S_FP = T.alloc_fragment((hlen,), "float16") + NS_FP = T.alloc_fragment((hlen,), "float16") + OUT_FP = T.alloc_fragment((hlen,), "float16") + + T.copy( + XQ_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + XQ_sh, + ) + T.copy( + COS_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + COS_sh, + ) + T.copy( + SIN_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + SIN_sh, + ) + T.copy( + NEG_SIN_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + NEG_SIN_sh, + ) + + for row in T.serial(rows): + T.copy(XQ_sh [row, 0], X_FP) + T.copy(COS_sh [row, 0], C_FP) + T.copy(SIN_sh [row, 0], S_FP) + T.copy(NEG_SIN_sh[row, 0], NS_FP) + + for i in T.unroll(half_dim): + e = 2 * i + o = 2 * i + 1 + OUT_FP[e] = X_FP[e] * C_FP[e] + X_FP[o] * NS_FP[e] + OUT_FP[o] = X_FP[o] * C_FP[o] + X_FP[e] * S_FP[o] + + T.copy(OUT_FP, Q_OUT_sh[row, 0]) + + T.copy( + Q_OUT_sh, + Q_OUT_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + ) + + # Return the raw PrimFunc — ``compile_kernel`` runs stmt prep + the + # mid_ir pipeline itself, so factories don't pre-lower anymore. + lowered = rope_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "HALF_DIM": half_dim, + "FULL_DIM": full_dim, + "BATCH": batch, + "NUM_S_BLOCKS": num_s_blocks, + "HARDWARE_LANE_COUNT": hardware_lane_count, + } + return lowered, constants diff --git a/tilelang_tvm_compiler/kernels/silu_min.py b/tilelang_tvm_compiler/kernels/silu_min.py new file mode 100644 index 0000000..0a84a01 --- /dev/null +++ b/tilelang_tvm_compiler/kernels/silu_min.py @@ -0,0 +1,111 @@ +"""SiLU-min kernel — sigmoid linear unit on FP scalar pipeline. + + SiLU(x) = x * sigmoid(x) + sigmoid(x) = 1 / (1 + exp(-x)) + +PLENA has no sigmoid ISA; the kernel composes it from +``exp / add / reci / mul``. Only two scalar constants are needed — +``1.0`` (for the reci numerator and the denominator add) and ``-1.0`` +(to express ``exp(-x)`` as ``exp(NEG_ONE * x)``, since there is no +unary negate in the FP scalar set). Both come from preloaded rank-1 +``local.fragment`` slots, mirroring gelu_min and flash_attention_min. + +Layout: HBM -> VRAM (shared) -> per-row FPRAM scratch -> VRAM -> HBM. +""" + +import tilelang.language as T + +from ..plena_settings import load_sizes as _load_sizes + + +def make_silu_min( + *, + rows: int | None = None, + hlen: int | None = None, + head_count: int = 8, + num_s_blocks: int = 2, + batch: int = 1, +): + # Hardware sizes default to plena_settings.toml's active mode. + _hw = _load_sizes() + MLEN = _hw.mlen + if hlen is None: + hlen = _hw.hlen + if rows is None: + rows = MLEN + if rows != MLEN: + raise ValueError(f"silu_min requires rows == MLEN ({MLEN}), got {rows}") + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN ({MLEN}); got hlen={hlen}") + hardware_lane_count = MLEN // hlen + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be a multiple of MLEN/hlen={hardware_lane_count}; " + f"got {head_count}" + ) + if num_s_blocks < 1: + raise ValueError(f"num_s_blocks must be >= 1, got {num_s_blocks}") + + seq_len = num_s_blocks * rows + + @T.prim_func + def silu_min( + X_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + Y_hbm: T.Tensor((batch, seq_len, head_count, hlen), "float16"), + ): + with T.Kernel(num_s_blocks, head_count, threads=128) as (s_block, by): + X_sh = T.alloc_shared((rows, hlen), "float16") + Y_sh = T.alloc_shared((rows, hlen), "float16") + + X_FP = T.alloc_fragment((hlen,), "float16") + Y_FP = T.alloc_fragment((hlen,), "float16") + + # 1.0 / -1.0 are inlined as T.float16(...) literals below. + # The hoist_float_constants pre-pass synthesises one 1-slot + # global.fpram buffer per unique value. + + neg_x = T.alloc_fragment((hlen,), "float16") # -x + e_negx = T.alloc_fragment((hlen,), "float16") # exp(-x) + denom = T.alloc_fragment((hlen,), "float16") # 1 + exp(-x) + sig = T.alloc_fragment((hlen,), "float16") # sigmoid(x) + + T.copy( + X_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + X_sh, + ) + + for row in T.serial(rows): + T.copy(X_sh[row, 0], X_FP) + + for i in T.unroll(hlen): + neg_x[i] = T.float16(-1.0) * X_FP[i] + e_negx[i] = T.exp(neg_x[i]) + denom[i] = T.float16(1.0) + e_negx[i] + # ``1.0 / x`` literal numerator — fold lowers this + # to fp_reci_at. A BufferLoad/BufferLoad div would + # not match the reci pattern. + sig[i] = T.float16(1.0) / denom[i] + Y_FP[i] = X_FP[i] * sig[i] + + T.copy(Y_FP, Y_sh[row, 0]) + + T.copy( + Y_sh, + Y_hbm[0, s_block * rows : (s_block + 1) * rows, by, 0:hlen], + ) + + lowered = silu_min + + constants = { + "ROWS": rows, + "MLEN": MLEN, + "HLEN": hlen, + "HEAD_COUNT": head_count, + "BATCH": batch, + "NUM_S_BLOCKS": num_s_blocks, + "HARDWARE_LANE_COUNT": hardware_lane_count, + } + return lowered, constants + + +__all__ = ["make_silu_min"] diff --git a/tilelang_tvm_compiler/loop_interchange.py b/tilelang_tvm_compiler/loop_interchange.py new file mode 100644 index 0000000..e147be9 --- /dev/null +++ b/tilelang_tvm_compiler/loop_interchange.py @@ -0,0 +1,112 @@ +"""Loop interchange pass (HLIR post-processing). + +Why this pass exists +-------------------- + +``to_plena`` lowers per-lane work as ``for X { for C { ... } }`` where +``C`` is the cluster (per-lane) axis and ``X`` is some enclosing loop +(typically ``for row``). A sibling cluster loop lowered separately — +e.g. ``for C { matmul }`` — sits next to ``for X``, NOT next to the +inner ``for C``, so :mod:`fuse_adjacent_loops` cannot reach the two. + +Interchanging the nest:: + + for X { for C { body } } -> for C { for X { body } } + +lifts the cluster loop to the same level as its sibling, after which +the fusion pass merges them. Run the two passes alternately to a fixed +point (see :func:`pipeline`) so a chain collapses fully. + +Legality +-------- + +This pass interchanges **only** when the inner loop is a cluster axis +(tagged ``is_cluster_axis`` by ``to_plena``) and the outer loop is not. +A cluster axis is per-lane: every lane owns its own buffer slots, so +there is no cross-iteration dependency between a cluster axis and any +non-cluster axis. That single structural condition *is* the legality +proof — no dependency analysis is needed. + +Scope +----- + +Conservative on purpose: only the clean case is handled — an outer +``for`` whose body is *exactly one* statement, that statement being a +cluster ``for``. A mixed outer body (cluster loop interleaved with +other ops) would need those other ops dragged inside the cluster loop, +which is not always legal; that case is left untouched. +""" + +from __future__ import annotations + +from typing import List, Tuple + +from . import hlir as _hlir + + +def _is_for(op: _hlir.Op) -> bool: + return op.kind == "for" + + +def _is_cluster_for(op: _hlir.Op) -> bool: + return _is_for(op) and bool(op.annotations.get("is_cluster_axis")) + + +def _clone_for(op: _hlir.Op, body: List[_hlir.Op]) -> _hlir.Op: + """A copy of a ``for`` op with a different body, annotations and all + loop metadata preserved.""" + return _hlir.Op( + kind="for", + buffer_args=list(op.buffer_args), + scalar_args=list(op.scalar_args), + annotations=dict(op.annotations), + body=body, + buffer_axes=list(op.buffer_axes), + ) + + +def _interchange_body(ops: List[_hlir.Op]) -> Tuple[List[_hlir.Op], bool]: + """Interchange eligible nests in one body list. Recurses first so a + deep nest is handled bottom-up. Returns ``(new_ops, changed)``.""" + changed = False + + # 1) recurse into every for body first. + recursed: List[_hlir.Op] = [] + for op in ops: + if _is_for(op) and op.body is not None: + new_body, sub_changed = _interchange_body(op.body) + changed = changed or sub_changed + recursed.append(_clone_for(op, new_body)) + else: + recursed.append(op) + + # 2) at this level, interchange ``for X { for C {...} }``. + out: List[_hlir.Op] = [] + for op in recursed: + if (_is_for(op) + and not _is_cluster_for(op) + and op.body is not None + and len(op.body) == 1 + and _is_cluster_for(op.body[0])): + outer = op # for X + inner = op.body[0] # for C (cluster) + # for X { for C { body } } -> for C { for X { body } } + new_inner = _clone_for(outer, list(inner.body)) # for X { body } + new_outer = _clone_for(inner, [new_inner]) # for C { for X } + out.append(new_outer) + changed = True + else: + out.append(op) + return out, changed + + +def run(mod: _hlir.HLIRModule) -> Tuple[_hlir.HLIRModule, bool]: + """Interchange cluster-inner loops outward throughout the module. + Returns ``(mod, changed)``; ``changed`` is False once no nest is + eligible (the fixed-point signal). Mutates ``mod.ops`` in place.""" + new_ops, changed = _interchange_body(mod.ops) + mod.ops = new_ops + return mod, changed + + +__all__ = ["run"] diff --git a/tilelang_tvm_compiler/loop_register_alloc.py b/tilelang_tvm_compiler/loop_register_alloc.py new file mode 100644 index 0000000..628c796 --- /dev/null +++ b/tilelang_tvm_compiler/loop_register_alloc.py @@ -0,0 +1,150 @@ +"""Loop-register allocation pass (HLIR liveness). + +Why this pass exists +-------------------- + +GP registers used to be allocated entirely during ISA emission, from a +single 16-register pool shared by three unrelated consumers: + + * ``_emit_for`` — one ``gp_loop`` (C_LOOP hardware counter) per + serial loop, live for the whole loop body. + * ``expr_materializer``— short-lived temporaries for address algebra. + * ``emit_*`` — per-instruction scratch. + +Mixing long-lived loop registers with short-lived temporaries in one +bare pool meant emit-stage code could not tell which registers it was +allowed to pin / free, and deep loop nests exhausted the pool. + +This pass separates the two classes. It walks the HLIR, computes — by +liveness — which GP each serial ``for`` loop's ``gp_loop`` should use, +stamps that onto the ``for`` op, and returns the set of GPs it claimed. +The caller hands that set to the emit-stage ``RegisterAllocator`` as +``gp_reserved``, so emit-stage temporary allocation physically cannot +touch a loop register. See ``doc/LOOP_REGISTER_ALLOC.md``. + +Why liveness here is trivial +---------------------------- + +A loop variable is live exactly over its loop's lexical body. Loop +bodies are *strictly nested* — any two are either nested or disjoint, +never partially overlapping. So "how many loop registers are live at +once" is just the loop-nesting depth, and assigning registers is a +linear walk with a stack: depth-0 loop takes the first reserved GP, +depth-1 the next, and so on. No interference graph, no colouring. + +What this pass does NOT touch +----------------------------- + + * ``unroll`` loops — their index is a compile-time constant + (``_emit_for`` binds it to a ``tir.IntImm``), so they need no + ``gp_loop`` and no reservation. + * Loop index storage — the index itself stays in IntRAM + (``claim_idx_slot``); only the C_LOOP hardware counter is a GP. + * Emit-stage temporaries — still allocated during emission, just from + the now-smaller un-reserved pool. +""" + +from __future__ import annotations + +from typing import Dict, List, Set + +from . import hlir as _hlir + + +class LoopRegisterAllocError(RuntimeError): + pass + + +# GP file is 16 registers; gp0 is the constant-zero register. The emit +# stage still needs a workable pool for op temporaries after loop +# registers are reserved — if a nest is so deep that too few GPs are +# left, fail here with a clear message rather than crashing mid-emit. +_GP_TOTAL = 16 +_GP0_RESERVED = 1 # gp0 +_MIN_EMIT_POOL = 8 # heaviest emit_* needs ~7 scratch + + +def _is_for(op: _hlir.Op) -> bool: + return op.kind == "for" + + +def _loop_kind(op: _hlir.Op) -> str: + return op.annotations.get("loop_kind", "serial") + + +def _is_serial_for(op: _hlir.Op) -> bool: + """A serial ``for`` lowers to a hardware C_LOOP and therefore needs + a ``gp_loop`` register. An ``unroll`` ``for`` does not.""" + return _is_for(op) and _loop_kind(op) not in ("unroll", "unrolled") + + +def _max_serial_depth(ops: List[_hlir.Op]) -> int: + """Deepest chain of *serial* ``for`` nesting in a body list.""" + best = 0 + for op in ops: + if op.body is None: + continue + inner = _max_serial_depth(op.body) + if _is_serial_for(op): + inner += 1 + best = max(best, inner) + return best + + +def _assign(ops: List[_hlir.Op], depth: int, reserved: List[int]) -> None: + """Walk ``ops``; stamp each serial ``for`` with the ``gp_loop`` GP + for its nesting depth. ``reserved`` is the depth-indexed list of GP + numbers (reserved[d] == the GP used by a serial loop at depth d).""" + for op in ops: + if _is_serial_for(op): + # ``depth`` can never exceed what _max_serial_depth measured + # — both count "+1 only on a serial for" identically — so + # this index is always valid. Assert rather than risk a bare + # IndexError if the two ever drift apart. + assert depth < len(reserved), ( + f"loop depth {depth} exceeds reserved register count " + f"{len(reserved)} — _max_serial_depth / _assign drifted" + ) + gp_loop = reserved[depth] + op.annotations["loop_gp"] = gp_loop + if op.body is not None: + _assign(op.body, depth + 1, reserved) + else: + # Non-serial-for (unroll for, leaf op): depth unchanged. + if op.body is not None: + _assign(op.body, depth, reserved) + + +def run(mod: _hlir.HLIRModule) -> Set[int]: + """Assign a ``gp_loop`` GP to every serial ``for`` in ``mod`` and + stamp it onto the op's ``annotations['loop_gp']``. + + Returns the set of GP numbers reserved for loop counters — the + caller passes this to the emit-stage ``RegisterAllocator`` as + ``gp_reserved`` so temporary allocation cannot collide with a loop + register. Mutates ``mod`` in place. + """ + depth = _max_serial_depth(mod.ops) + if depth == 0: + return set() + + # Reserve from the TOP of the GP file (gp15, gp14, …) so the + # low-numbered registers — which emit-stage code and dumps tend to + # use first — stay with the temporary pool. + reserved_list = [(_GP_TOTAL - 1) - d for d in range(depth)] + reserved = set(reserved_list) + + free_for_emit = _GP_TOTAL - _GP0_RESERVED - len(reserved) + if free_for_emit < _MIN_EMIT_POOL: + raise LoopRegisterAllocError( + f"serial loop nesting depth {depth} reserves {len(reserved)} " + f"GP(s); only {free_for_emit} left for emit-stage temporaries " + f"(need >= {_MIN_EMIT_POOL}). Convert an outer loop to " + f"T.unroll(...) — an unrolled loop needs no gp_loop." + ) + + _assign(mod.ops, 0, reserved_list) + return reserved + + +__all__ = ["run", "LoopRegisterAllocError"] diff --git a/tilelang_tvm_compiler/mir.py b/tilelang_tvm_compiler/mir.py new file mode 100644 index 0000000..30082a5 --- /dev/null +++ b/tilelang_tvm_compiler/mir.py @@ -0,0 +1,1152 @@ +"""MIR — machine IR for the PLENA backend. + +This sits between :mod:`pre_isa_ir` (PrimExpr-operand form, machine- +neutral) and the final ISA text. It is the "explicit conversion" layer +the user wanted: where the abstract address algebra of PreIsaIR becomes +named SSA values, and where loop structure / def-use chains are +explicit enough to support standard machine-IR optimisations +(LICM, CSE, DCE, register allocation, spill-to-IntRAM). + +Design summary +-------------- +* **SSA**. Every value computed by an instruction is a unique + :class:`MirValue` with a string name (``%0``, ``%1``, ...). + Operands of subsequent instructions reference values by identity + (the Python ``MirValue`` object), NOT by name; the name is a debug + string only. This mirrors LLVM's ``llvm::Value*`` model. + +* **def/use chain**. Every ``MirValue`` knows the single instruction + that produced it (``defined_by``) and every instruction that + currently uses it (``used_by`` — kept up-to-date by operand mutators + on ``MirInstr``). This is what makes LICM cheap (free-vars of an + expression are the ``defined_by`` set of its operand values), and + what makes register allocation a graph problem on the live-range + intervals. + +* **Block + terminators**. Instructions live in :class:`MirBlock`s. + Each block ends in exactly ONE terminator (a branch / loop-back / + loop-end / return). Non-terminator instructions in the middle are + straight-line. + +* **Loop is a region**. PLENA has a hardware loop primitive + (``C_LOOP_START`` / ``C_LOOP_END``), and unrolled loops are a + separate codegen choice. We model both as :class:`MirLoop` + regions tagged with ``loop_kind`` ∈ ``{"serial", "unroll"}`` — the + loop IS the region; backend chooses how to lower it. + +* **Types**. Three concrete types for now: + - ``"i32"`` — an integer (occupies one GP register at lowering) + - ``"addr_reg"`` — a PLENA address-register value (``aN``) + - ``"fp_reg"`` — a PLENA FPU-register value (``fN``); these are + a hardware-fixed file (f0=0, f1, f2, ...) and + almost never live more than one instruction, + so we encode them as pinned named tokens + rather than real SSA values + ``"void"`` is used as a result type for instructions that don't + produce a value (``M_BTMM`` for example writes to the systolic + array, not a GP). + +The verifier (``verify``) catches: + - operand type / count mismatches per opcode + - def-before-use (every use must come after its def, modulo loop + back-edges that the loop-region structure resolves) + - dangling uses (used_by entries pointing to deleted instructions) + - terminator-in-the-middle / non-terminator-at-end + +Lowering passes (PreIsaPass → MIR, MIR optimise, MIR → ISA) live in +separate modules; this file is the data model + dump + verifier only. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from tvm import tir + + +# ---------------------------------------------------------------------- +# Opcode table +# ---------------------------------------------------------------------- + +# Each opcode declares: +# - result_type: "i32" / "addr_reg" / "fp_reg" / "void" +# - operand_kinds: tuple of expected operand kinds, where each kind is +# "i32" — an i32 MirValue (or an int / IntImm immediate +# that will fold; the verifier accepts both) +# "addr_reg" — an addr_reg MirValue +# "fp_reg" — a verbatim FP register token (``"f0"`` / ``"f1"``) +# held as a Python str on the operand list +# "literal_int" — a Python int / IntImm bound at construction +# (loop bounds, mask flags, immediate fields) +# "verbatim_str" — a Python str dropped into the ISA template +# The same opcode may appear in two forms (different operand +# counts) — those are encoded as separate _OpcodeSpec entries +# with disambiguating internal names ('_'-prefixed for variants). +@dataclass(frozen=True) +class _OpcodeSpec: + result_type: str + operand_kinds: Tuple[str, ...] + isa_mnemonic: str # what the backend writes into the ASM text + # Per-operand position in the emitted ISA, given as a list of + # operand-list indices. Default = identity (0, 1, 2, ...). Used by + # opcodes whose emit order differs from the construction order. + isa_operand_order: Optional[Tuple[int, ...]] = None + + +# Note: this table is the SOURCE OF TRUTH for what a valid MIR +# instruction looks like. Anything not here will fail verify(). +OPCODES: Dict[str, _OpcodeSpec] = { + # ---- integer scalar ---- + "S_ADDI_INT": _OpcodeSpec( + result_type="i32", + operand_kinds=("i32", "literal_int"), + isa_mnemonic="S_ADDI_INT", + ), + "S_ADD_INT": _OpcodeSpec( + result_type="i32", + operand_kinds=("i32", "i32"), + isa_mnemonic="S_ADD_INT", + ), + "S_SUB_INT": _OpcodeSpec( + result_type="i32", + operand_kinds=("i32", "i32"), + isa_mnemonic="S_SUB_INT", + ), + "S_MUL_INT": _OpcodeSpec( + result_type="i32", + operand_kinds=("i32", "i32"), + isa_mnemonic="S_MUL_INT", + ), + "S_LUI_INT": _OpcodeSpec( + result_type="i32", + operand_kinds=("literal_int",), + isa_mnemonic="S_LUI_INT", + ), + "S_LD_INT": _OpcodeSpec( + # gp_dst = intram[gp_base + imm] + result_type="i32", + operand_kinds=("i32", "literal_int"), + isa_mnemonic="S_LD_INT", + ), + "S_ST_INT": _OpcodeSpec( + # intram[gp_base + imm] = gp_value (no result) + result_type="void", + operand_kinds=("i32", "i32", "literal_int"), + isa_mnemonic="S_ST_INT", + ), + "S_SLLI_INT": _OpcodeSpec( + result_type="i32", + operand_kinds=("i32", "literal_int"), + isa_mnemonic="S_SLLI_INT", + ), + "S_SRLI_INT": _OpcodeSpec( + result_type="i32", + operand_kinds=("i32", "literal_int"), + isa_mnemonic="S_SRLI_INT", + ), + # Reg-amount shifts — the shift count comes from another GP rather + # than an immediate. Used by packed-head mask expressions like + # ``1 << (lane_var % lane_count)``. + "S_SLL_INT": _OpcodeSpec( + result_type="i32", + operand_kinds=("i32", "i32"), + isa_mnemonic="S_SLL_INT", + ), + "S_SRL_INT": _OpcodeSpec( + result_type="i32", + operand_kinds=("i32", "i32"), + isa_mnemonic="S_SRL_INT", + ), + # ---- FP scalar ---- + "S_LD_FP": _OpcodeSpec( + # f_dst = fpram[gp_addr + 0]; no MIR-level SSA result for fpregs + # (they're a fixed file). The producer carries the fp register + # as a verbatim_str on operand 0. + result_type="void", + operand_kinds=("fp_reg", "i32", "literal_int"), + isa_mnemonic="S_LD_FP", + ), + "S_ST_FP": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "i32", "literal_int"), + isa_mnemonic="S_ST_FP", + ), + "S_ADD_FP": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "fp_reg", "fp_reg"), + isa_mnemonic="S_ADD_FP", + ), + "S_SUB_FP": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "fp_reg", "fp_reg"), + isa_mnemonic="S_SUB_FP", + ), + "S_MUL_FP": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "fp_reg", "fp_reg"), + isa_mnemonic="S_MUL_FP", + ), + "S_MAX_FP": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "fp_reg", "fp_reg"), + isa_mnemonic="S_MAX_FP", + ), + "S_EXP_FP": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "fp_reg", "literal_int"), + isa_mnemonic="S_EXP_FP", + ), + "S_RECI_FP": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "fp_reg"), + isa_mnemonic="S_RECI_FP", + ), + "S_SQRT_FP": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "fp_reg"), + isa_mnemonic="S_SQRT_FP", + ), + "S_MAP_FP_V": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "literal_int"), + isa_mnemonic="S_MAP_FP_V", + ), + "S_MAP_V_FP": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "literal_int"), + isa_mnemonic="S_MAP_V_FP", + ), + # ---- vector ---- + "V_ADD_VV": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "i32", "literal_int"), + isa_mnemonic="V_ADD_VV", + ), + "V_SUB_VV": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "i32", "literal_int"), + isa_mnemonic="V_SUB_VV", + ), + "V_MUL_VV": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "i32", "literal_int"), + isa_mnemonic="V_MUL_VV", + ), + "V_ADD_VF": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "fp_reg", "literal_int"), + isa_mnemonic="V_ADD_VF", + ), + "V_SUB_VF": _OpcodeSpec( + # PLENA V_SUB_VF takes 5 operands: dst, src, fp_scalar, + # mask_flag, reverse_flag (legacy quirk — always 0). The + # extra trailing flag distinguishes it from V_ADD_VF / + # V_MUL_VF which are 4-operand. + result_type="void", + operand_kinds=( + "i32", "i32", "fp_reg", "literal_int", "literal_int", + ), + isa_mnemonic="V_SUB_VF", + ), + "V_MUL_VF": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "fp_reg", "literal_int"), + isa_mnemonic="V_MUL_VF", + ), + "V_EXP_V": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "literal_int"), + isa_mnemonic="V_EXP_V", + ), + "V_RECI_V": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "literal_int"), + isa_mnemonic="V_RECI_V", + ), + "V_SQRT_V": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "i32", "literal_int"), + isa_mnemonic="V_SQRT_V", + ), + "V_RED_MAX": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "i32", "literal_int"), + isa_mnemonic="V_RED_MAX", + ), + "V_RED_SUM": _OpcodeSpec( + result_type="void", + operand_kinds=("fp_reg", "i32", "literal_int"), + isa_mnemonic="V_RED_SUM", + ), + # ---- matrix ---- + "M_BTMM": _OpcodeSpec( + result_type="void", + operand_kinds=("verbatim_str", "i32", "i32"), + isa_mnemonic="M_BTMM", + ), + "M_BMM_WO": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "literal_int"), + isa_mnemonic="M_BMM_WO", + ), + "M_BTMV": _OpcodeSpec( + result_type="void", + operand_kinds=("verbatim_str", "i32", "i32"), + isa_mnemonic="M_BTMV", + ), + "M_BMV_WO": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "literal_int"), + isa_mnemonic="M_BMV_WO", + ), + "M_MV": _OpcodeSpec( + result_type="void", + operand_kinds=("verbatim_str", "i32", "i32"), + isa_mnemonic="M_MV", + ), + "M_MV_WO": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "literal_int"), + isa_mnemonic="M_MV_WO", + ), + "M_MM": _OpcodeSpec( + result_type="void", + operand_kinds=("literal_int", "i32", "i32"), + isa_mnemonic="M_MM", + ), + "M_MM_WO": _OpcodeSpec( + result_type="void", + operand_kinds=("i32", "verbatim_str", "literal_int"), + isa_mnemonic="M_MM_WO", + ), + "M_TMM": _OpcodeSpec( + result_type="void", + operand_kinds=("literal_int", "i32", "i32"), + isa_mnemonic="M_TMM", + ), + # ---- control / setup ---- + "C_SET_SCALE_REG": _OpcodeSpec( + result_type="void", + operand_kinds=("i32",), + isa_mnemonic="C_SET_SCALE_REG", + ), + "C_SET_STRIDE_REG": _OpcodeSpec( + result_type="void", + operand_kinds=("i32",), + isa_mnemonic="C_SET_STRIDE_REG", + ), + "C_SET_ADDR_REG": _OpcodeSpec( + # Bind a PLENA addr-reg slot. HW builds the 64-bit address as + # ``(gp[rs1] << 32) | gp[rs2]`` (see main.rs C_SET_ADDR_REG), so + # the ISA needs TWO gp sources: high word then low word. Our + # addresses are 32-bit, so the high word is the hardwired-zero + # ``gp0`` (matches the legacy backend_emit form + # ``C_SET_ADDR_REG aN, gp0, gp{addr}``). + # operand 0: verbatim "gp0" — the constant-zero high word. + # operand 1: the source i32 SSA value (the low/address word). + # result : an addr_reg SSA value (the bound aN). + result_type="addr_reg", + operand_kinds=("verbatim_str", "i32"), + isa_mnemonic="C_SET_ADDR_REG", + ), + "C_SET_V_MASK_REG": _OpcodeSpec( + result_type="void", + operand_kinds=("i32",), + isa_mnemonic="C_SET_V_MASK_REG", + ), + # ---- HBM ---- + "H_PREFETCH_V": _OpcodeSpec( + result_type="void", + operand_kinds=( + "i32", "i32", "addr_reg", + "literal_int", "literal_int", + ), + isa_mnemonic="H_PREFETCH_V", + ), + "H_PREFETCH_M": _OpcodeSpec( + result_type="void", + operand_kinds=( + "i32", "i32", "addr_reg", + "literal_int", "literal_int", + ), + isa_mnemonic="H_PREFETCH_M", + ), + "H_STORE_V": _OpcodeSpec( + result_type="void", + operand_kinds=( + "i32", "i32", "addr_reg", + "literal_int", "literal_int", + ), + isa_mnemonic="H_STORE_V", + ), + "H_LOAD_V": _OpcodeSpec( + result_type="void", + operand_kinds=( + "i32", "i32", "addr_reg", + "literal_int", "literal_int", + ), + isa_mnemonic="H_LOAD_V", + ), + # ---- pseudo / meta ---- + # ``_COMMENT`` does not emit a real instruction; backend prints it + # as a ``; ...`` line. operands = (one verbatim_str). + "_COMMENT": _OpcodeSpec( + result_type="void", + operand_kinds=("verbatim_str",), + isa_mnemonic=";", + ), +} + + +# ---------------------------------------------------------------------- +# Core data structures +# ---------------------------------------------------------------------- + +class MirValue: + """An SSA value. + + Mirrors MLIR-style SSA: each value has exactly one definition site. + There are three legal definition sites: + + 1. ``defined_by`` set to a :class:`MirInstr` — the standard case: + a producer instruction computes this value. + 2. ``defined_by is None`` AND ``block_arg_of`` set to a + :class:`MirBlock` — a block argument, supplied by the + enclosing region (loop header) on each entry to that block. + loop_var values are of this kind. + 3. ``defined_by is None`` AND ``is_function_const`` is True — a + function-level constant value (today: ``gp0``, the + hardware-fixed zero). Treated by passes as a "value that just + exists" — no def site to schedule. + + ``dtype`` ∈ ``{"i32", "addr_reg", "fp_reg"}``. void instructions + don't produce a MirValue at all (their ``MirInstr.result`` is + None). + """ + + __slots__ = ( + "name", "dtype", "defined_by", "used_by", + "block_arg_of", "is_function_const", + ) + + def __init__(self, name: str, dtype: str) -> None: + self.name: str = name + self.dtype: str = dtype + self.defined_by: Optional["MirInstr"] = None + self.used_by: List["MirInstr"] = [] + # Set when this value is the block argument of a particular + # MirBlock. ``defined_by`` stays None in that case. + self.block_arg_of: Optional["MirBlock"] = None + # True for function-level constants (e.g. gp0). ``defined_by`` + # and ``block_arg_of`` both stay None. + self.is_function_const: bool = False + + def __repr__(self) -> str: + return f"%{self.name}:{self.dtype}" + + +# An operand can be: +# * MirValue — reference to an SSA value produced earlier +# * int — compile-time literal int (for literal_int kinds) +# * tir.IntImm — TVM integer immediate; treated as int by passes +# * str — verbatim token (e.g. "f0", "f1", "gp0" for the +# constant-zero source on instructions that hard- +# code it) +MirOperand = Union[MirValue, int, "tir.IntImm", str] + + +class MirInstr: + """One MIR instruction. + + ``opcode`` is a key in ``OPCODES`` above. ``operands`` is a list of + ``MirOperand``s whose kinds match ``OPCODES[opcode].operand_kinds``. + ``result`` is a ``MirValue`` for non-void opcodes, else None. + + Use ``set_operand(i, val)`` to mutate operands so the def-use chain + stays consistent (the old operand's ``used_by`` loses this instr, + the new operand's gains it). + """ + + __slots__ = ( + "opcode", "operands", "result", "parent", + "annotations", + ) + + def __init__( + self, + opcode: str, + operands: List[MirOperand], + result: Optional[MirValue] = None, + ) -> None: + if opcode not in OPCODES: + raise ValueError( + f"MirInstr: unknown opcode {opcode!r}. Add an entry to " + f"mir.OPCODES." + ) + self.opcode = opcode + self.operands: List[MirOperand] = list(operands) + self.result = result + self.parent: Optional["MirBlock"] = None + # Free-form per-pass scratch (debug source-PreIsaOp index, + # optimisation hints, etc.). + self.annotations: Dict[str, Any] = {} + # Wire result.defined_by + each MirValue operand's used_by. + if result is not None: + if result.defined_by is not None: + raise ValueError( + f"MirInstr: result {result!r} is already defined " + f"by another instruction {result.defined_by!r}; " + f"each SSA value must have exactly one def." + ) + result.defined_by = self + for op in self.operands: + if isinstance(op, MirValue): + op.used_by.append(self) + + def set_operand(self, i: int, new: MirOperand) -> None: + """Replace the i-th operand, updating def-use chains.""" + old = self.operands[i] + if isinstance(old, MirValue): + try: + old.used_by.remove(self) + except ValueError: + pass + self.operands[i] = new + if isinstance(new, MirValue): + new.used_by.append(self) + + def replace_all_uses_of( + self, old: MirValue, new: MirOperand, + ) -> None: + """Replace every occurrence of ``old`` in this instr's operands + with ``new``. Updates def-use chains. Common subroutine for + the same-named LLVM API used by CSE / value-replacement + rewrites.""" + for i, op in enumerate(self.operands): + if op is old: + self.set_operand(i, new) + + def __repr__(self) -> str: + op_strs = [_fmt_operand(o) for o in self.operands] + if self.result is not None: + return ( + f"%{self.result.name} = {self.opcode} " + f"{', '.join(op_strs)}" + ) + return f"{self.opcode} {', '.join(op_strs)}" + + +@dataclass +class MirBlock: + """A basic block — a straight-line interleaved sequence of + :class:`MirInstr`s and nested :class:`MirLoop` regions, prefixed + by an optional list of block arguments (MLIR-style). + + Block arguments are SSA values that the enclosing region + (function entry, or a loop header) supplies on each entry. A + MirLoop's body block has exactly one argument — the loop_var + SSA value — and that argument's ``block_arg_of`` points back at + this block. From the body's perspective the argument is a + normal SSA value with no in-block def: it just "is there". + + ``items`` preserves source order so loops can appear interleaved + with instructions; the dump walks them in sequence. + + PLENA kernels are loop-nested DAGs only — we don't model + arbitrary branching (no MirIf / phi); a block has no terminator + instruction. Control flow is implicit in the loop-region + nesting. + """ + + name: str + items: List[Union["MirInstr", "MirLoop"]] = field(default_factory=list) + # Block arguments — SSA values supplied by the enclosing region + # at entry. For loop body blocks, this is ``[loop_var]``. + arguments: List[MirValue] = field(default_factory=list) + # Parent loop region (None = function-level top scope). Set when + # this block is added to a MirLoop.body. + parent_loop: Optional["MirLoop"] = None + + def append(self, item: Union["MirInstr", "MirLoop"]) -> Union["MirInstr", "MirLoop"]: + if isinstance(item, MirInstr): + item.parent = self + elif isinstance(item, MirLoop): + item.parent_block = self + else: + raise TypeError( + f"MirBlock.append: expected MirInstr or MirLoop, " + f"got {type(item).__name__}" + ) + self.items.append(item) + return item + + def add_argument(self, v: MirValue) -> MirValue: + """Register ``v`` as a block argument of this block. The + value's ``block_arg_of`` is set; its ``defined_by`` stays + None (block arguments have no in-block def).""" + if v.defined_by is not None: + raise ValueError( + f"add_argument: {v!r} is already defined by " + f"{v.defined_by!r}; cannot also be a block argument" + ) + if v.block_arg_of is not None and v.block_arg_of is not self: + raise ValueError( + f"add_argument: {v!r} is already an argument of " + f"block {v.block_arg_of.name!r}" + ) + v.block_arg_of = self + if v not in self.arguments: + self.arguments.append(v) + return v + + # Back-compat helper for code that iterates only MirInstrs. + @property + def instrs(self) -> List["MirInstr"]: + return [it for it in self.items if isinstance(it, MirInstr)] + + +@dataclass +class MirLoop: + """A loop region. + + A loop's body is a list of MirBlocks (currently always exactly + one in the lower passes, but the structure is here so a future + "if-then" inside a loop body can land cleanly). + + ``loop_var`` is the SSA value supplied as the BLOCK ARGUMENT of + the first body block — exactly like MLIR's ``scf.for`` induction + variable. From the body's perspective ``loop_var`` is a normal + SSA value with no in-block def; the region (this MirLoop) is the + notional "definer", injecting a fresh value each iteration. + + Whether the lowering physically unrolls the body (binding + loop_var to a literal int per iter) or emits a hardware loop + (binding loop_var to an IntRAM-backed counter) is the + ``loop_kind`` switch — a backend decision, not an IR-level + distinction. Optimisation passes may REWRITE ``loop_kind`` to + flip the strategy; the IR structure is unchanged. + + ``init`` / ``extent`` are compile-time ints (matching PLENA's + ``C_LOOP_START`` immediate-only iteration count). Runtime-bounded + loops would need a different lowering and are not modelled here. + + ``loop_kind`` ∈ ``{"serial", "unroll"}``. + """ + + name: str + loop_var: MirValue + init: int + extent: int + body: List[MirBlock] = field(default_factory=list) + loop_kind: str = "serial" + # Set by ``MirBlock.append`` when this loop is added to a block's + # items list. None at top-of-function (the function's top block + # is the parent in that case). + parent_block: Optional[MirBlock] = None + # Free-form scratch (loop_gp choice if hand-pinned, source + # PreIsaOp idx, etc.). + annotations: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + # Wire reverse parent pointer on each body block. Producers + # build a MirLoop with ``body=[blk]``; we make sure + # ``blk.parent_loop`` points back, so scope-walks (verify + # dominance, last-use loop-stack, etc.) can climb out of a + # body block to find its enclosing loop region. The same + # invariant applies to any blocks appended later — see + # ``add_body_block``. + for blk in self.body: + blk.parent_loop = self + + def add_body_block(self, blk: MirBlock) -> MirBlock: + """Append a block to this loop's body, setting its + ``parent_loop`` back-pointer.""" + blk.parent_loop = self + self.body.append(blk) + return blk + + +@dataclass +class MirFunction: + """Top-level MIR container — one kernel. + + The function has a sequence of top-level :class:`MirBlock`s. + Loops live inside blocks (via ``MirBlock.append(MirLoop(...))``); + a loop's ``body`` is a list of nested blocks. Walking + ``walk_loops()`` yields all loops in pre-order for passes that + care. + """ + + name: str + blocks: List[MirBlock] = field(default_factory=list) + # The SSA name counter. Auto-incremented by ``mint_value``. + _next_id: int = 0 + # Free-form metadata forwarded from HLIR / PreIsaIR (buffer table + # for the dump, etc.). + metadata: Dict[str, Any] = field(default_factory=dict) + # Function-level constant: hardware-fixed gp0 zero. Available to + # passes as a "value that just exists" — its + # ``is_function_const`` flag is True, ``defined_by`` is None. + # Created by the converter via ``make_gp0_const``. + gp0_value: Optional[MirValue] = None + + def make_gp0_const(self) -> MirValue: + """Mint (or return existing) the function-level gp0 constant.""" + if self.gp0_value is not None: + return self.gp0_value + v = self.mint_value("i32", hint="gp0") + v.is_function_const = True + self.gp0_value = v + return v + + def mint_value(self, dtype: str, hint: str = "") -> MirValue: + """Create a fresh, undefined MirValue. Caller must wire it as + the ``result`` of exactly one MirInstr (the constructor does + the binding). + + ``hint`` is appended to the auto-generated name for readability + (e.g. ``mint_value("i32", "lhs_addr")`` → ``%5_lhs_addr``). + """ + nm = f"{self._next_id}" + if hint: + nm = f"{nm}_{hint}" + self._next_id += 1 + return MirValue(nm, dtype) + + def walk_loops(self): + """Yield every :class:`MirLoop` in the function in pre-order + (outer loops before their nested children).""" + def _recurse(blocks: List[MirBlock]): + for blk in blocks: + for item in blk.items: + if isinstance(item, MirLoop): + yield item + yield from _recurse(item.body) + yield from _recurse(self.blocks) + + def walk_instrs(self): + """Yield every :class:`MirInstr` in the function in source + order, recursing into nested loop bodies.""" + def _recurse(blocks: List[MirBlock]): + for blk in blocks: + for item in blk.items: + if isinstance(item, MirInstr): + yield item + elif isinstance(item, MirLoop): + yield from _recurse(item.body) + yield from _recurse(self.blocks) + + +# ---------------------------------------------------------------------- +# Dump +# ---------------------------------------------------------------------- + +def _fmt_operand(op: MirOperand) -> str: + if isinstance(op, MirValue): + return f"%{op.name}" + if isinstance(op, tir.IntImm): + return str(int(op.value)) + if isinstance(op, int): + return str(op) + if isinstance(op, str): + return op + return repr(op) + + +def format_mir(fn: MirFunction) -> str: + """Pretty-print one MirFunction. Used for ``.mir.txt``.""" + lines = [f"MirFunction({fn.name!r}):"] + if fn.metadata.get("buffers"): + lines.append(" Buffers:") + bufs = fn.metadata["buffers"] + name_w = max((len(n) for n in bufs), default=4) + for nm, b in bufs.items(): + scope = getattr(b, "scope", "?") + shape = getattr(b, "shape", ()) + addr = getattr(b, "address", None) + shape_s = "x".join(str(s) for s in shape) if shape else "()" + addr_s = "?" if addr is None else str(addr) + lines.append( + f" {nm:<{name_w}} scope={scope:<5} addr={addr_s} " + f"shape={shape_s}" + ) + if fn.gp0_value is not None: + lines.append( + f" Function constants: %{fn.gp0_value.name}:i32 = " + f"" + ) + lines.append(" Body:") + for blk in fn.blocks: + _format_block(blk, lines, indent=4) + return "\n".join(lines) + "\n" + + +def _format_block_header(blk: MirBlock) -> str: + """``^body(%2: i32, %3: i32):`` MLIR-style block header.""" + if blk.arguments: + args = ", ".join( + f"%{a.name}: {a.dtype}" for a in blk.arguments + ) + return f"^{blk.name}({args}):" + return f"^{blk.name}:" + + +def _format_block(blk: MirBlock, lines: List[str], indent: int) -> None: + ind = " " * indent + lines.append(f"{ind}{_format_block_header(blk)}") + body_ind = " " * (indent + 2) + for item in blk.items: + if isinstance(item, MirInstr): + lines.append(f"{body_ind}{_format_instr(item)}") + else: + _format_loop(item, lines, indent + 2) + + +def _format_loop(lp: MirLoop, lines: List[str], indent: int) -> None: + ind = " " * indent + lines.append( + f"{ind}loop {lp.name} in [{lp.init}, {lp.init + lp.extent}) " + f"[kind={lp.loop_kind}]" + ) + for body_blk in lp.body: + _format_block(body_blk, lines, indent + 2) + + +def _format_instr(instr: MirInstr) -> str: + op_strs = [_fmt_operand(o) for o in instr.operands] + body = f"{instr.opcode} {', '.join(op_strs)}".rstrip() + if instr.result is not None: + prefix = f"%{instr.result.name}:{instr.result.dtype} = " + return prefix + body + return body + + +# ---------------------------------------------------------------------- +# Verifier +# ---------------------------------------------------------------------- + +class MirVerifyError(RuntimeError): + pass + + +def verify(fn: MirFunction) -> None: + """Sanity-check the MIR. Raises :class:`MirVerifyError` on the + first problem found. + + Checks: + * every instruction's opcode is in ``OPCODES`` + * operand counts + kinds match the opcode spec + * non-void instructions have a non-None result of the right type + * void instructions have ``result is None`` + * every MirValue has exactly one defining instruction + * every operand-as-MirValue reference is reflected in the + defining value's ``used_by`` + * loop_var SSA values are defined by ``_LOOP_VAR_DEF`` instrs + sitting in the loop's body + """ + # Collect every MirValue that is the result of some MirInstr, + # and every MirValue that appears as an operand. The set of + # operand-referenced MirValues that AREN'T defined anywhere in + # the function is a hard error. + defined: Set[int] = set() # id() of MirValues we've seen as results + declared_vals: List[MirValue] = [] + + def _walk_instrs(instrs: List[MirInstr]) -> None: + for instr in instrs: + spec = OPCODES.get(instr.opcode) + if spec is None: + raise MirVerifyError( + f"unknown opcode {instr.opcode!r} on instr {instr!r}" + ) + if len(instr.operands) != len(spec.operand_kinds): + raise MirVerifyError( + f"{instr.opcode}: operand count " + f"{len(instr.operands)} != spec arity " + f"{len(spec.operand_kinds)} ({spec.operand_kinds})" + ) + for i, (op, kind) in enumerate( + zip(instr.operands, spec.operand_kinds), + ): + _check_operand_kind(instr, i, op, kind) + if spec.result_type == "void": + if instr.result is not None: + raise MirVerifyError( + f"{instr.opcode}: void opcode but result " + f"{instr.result!r} is not None" + ) + else: + if instr.result is None: + raise MirVerifyError( + f"{instr.opcode}: non-void opcode but result " + f"is None" + ) + if instr.result.dtype != spec.result_type: + raise MirVerifyError( + f"{instr.opcode}: result dtype " + f"{instr.result.dtype!r} != spec result " + f"{spec.result_type!r}" + ) + if instr.result.defined_by is not instr: + raise MirVerifyError( + f"{instr.opcode}: result {instr.result!r} " + f"defined_by mismatch (claims " + f"{instr.result.defined_by!r}, real {instr!r})" + ) + if id(instr.result) in defined: + raise MirVerifyError( + f"{instr.opcode}: result {instr.result!r} is " + f"already defined by a previous instr " + f"(double-def)" + ) + defined.add(id(instr.result)) + declared_vals.append(instr.result) + + def _walk_block(blk: MirBlock) -> None: + # Block arguments count as "defined" — they have no in-block + # def site but the enclosing region supplies them. + for arg in blk.arguments: + if arg.defined_by is not None: + raise MirVerifyError( + f"block {blk.name!r}: argument {arg!r} also has " + f"defined_by {arg.defined_by!r} — block arguments " + f"must have defined_by=None" + ) + if arg.block_arg_of is not blk: + raise MirVerifyError( + f"block {blk.name!r}: argument {arg!r} has " + f"block_arg_of={arg.block_arg_of!r}, not this block" + ) + if id(arg) in defined: + raise MirVerifyError( + f"block argument {arg!r} double-defined" + ) + defined.add(id(arg)) + declared_vals.append(arg) + for item in blk.items: + if isinstance(item, MirInstr): + _walk_instrs([item]) + elif isinstance(item, MirLoop): + # loop_var is the body's first block argument; verify + # the binding before recursing into the body. + if not item.body: + raise MirVerifyError( + f"loop {item.name}: empty body" + ) + first_body = item.body[0] + if not first_body.arguments or \ + first_body.arguments[0] is not item.loop_var: + raise MirVerifyError( + f"loop {item.name}: loop_var must be the first " + f"argument of body block {first_body.name!r}; " + f"got arguments={first_body.arguments!r}" + ) + for body_blk in item.body: + _walk_block(body_blk) + else: + raise MirVerifyError( + f"block {blk.name}: unexpected item type " + f"{type(item).__name__}" + ) + + # Function-level constants count as defined too. + if fn.gp0_value is not None: + if not fn.gp0_value.is_function_const: + raise MirVerifyError( + f"function gp0 {fn.gp0_value!r} missing " + f"is_function_const flag" + ) + defined.add(id(fn.gp0_value)) + declared_vals.append(fn.gp0_value) + + for blk in fn.blocks: + _walk_block(blk) + + # Cross-check: every MirValue used as an operand is defined, + # used_by chains consistent, AND the def site is an ancestor + # scope of the use site (SCF-style dominance — a value defined + # inside a loop body cannot be referenced from outside the + # body, since the loop has no yield/iter_args mechanism in + # our MIR today). + def _block_chain(blk: MirBlock) -> List[MirBlock]: + """Ancestor chain from ``blk`` outward, crossing one MirLoop + boundary per step.""" + chain = [] + cur = blk + while cur is not None: + chain.append(cur) + if cur.parent_loop is None: + break + cur = cur.parent_loop.parent_block + return chain + + def _def_block(v: MirValue) -> Optional[MirBlock]: + if v.is_function_const: + return None # always in scope, no specific block + if v.block_arg_of is not None: + return v.block_arg_of + if v.defined_by is not None: + return v.defined_by.parent + return None + + def _check_uses_block(blk: MirBlock) -> None: + ancestors = _block_chain(blk) + ancestor_set = {id(b) for b in ancestors} + for item in blk.items: + if isinstance(item, MirInstr): + for i, op in enumerate(item.operands): + if isinstance(op, MirValue): + if id(op) not in defined: + raise MirVerifyError( + f"{item.opcode} operand[{i}] uses " + f"undefined SSA value {op!r}" + ) + if item not in op.used_by: + raise MirVerifyError( + f"{item.opcode} operand[{i}]: SSA " + f"value {op!r}'s used_by chain doesn't " + f"include this instr" + ) + if not op.is_function_const: + db = _def_block(op) + if db is not None and id(db) not in ancestor_set: + raise MirVerifyError( + f"{item.opcode} operand[{i}]: SSA " + f"dominance violation. Value " + f"{op!r} is defined in block " + f"{db.name!r}, which is NOT an " + f"ancestor of the use site block " + f"{blk.name!r}. Cross-scope use " + f"is illegal: a value defined " + f"inside a loop body can only be " + f"referenced within that body." + ) + elif isinstance(item, MirLoop): + for body_blk in item.body: + _check_uses_block(body_blk) + + for blk in fn.blocks: + _check_uses_block(blk) + + # SCF-style scope check: every operand's defining-site scope must + # be an ANCESTOR of the using instruction's scope. A scope is the + # block (or its enclosing loop-region chain) containing a value's + # def. Concretely: + # * a function-level constant (gp0) is in scope everywhere + # * a value defined by an instr in block B is in scope inside + # B and inside any block nested under B (via MirLoop body) + # * a block argument of block B is in scope inside B and below + # + # Equivalently: walking outward from the use site through + # parent_loop / parent_block pointers must eventually reach the + # block where the value is defined / is an argument of. + def _block_chain(blk: MirBlock) -> List[MirBlock]: + """List of blocks from ``blk`` outward to the function root. + Each step crosses one MirLoop boundary (the block's + parent_loop's parent_block is the next outer block).""" + chain = [] + cur = blk + while cur is not None: + chain.append(cur) + if cur.parent_loop is None: + break + cur = cur.parent_loop.parent_block + return chain + + def _def_block(v: MirValue) -> Optional[MirBlock]: + if v.is_function_const: + return None # always in scope + if v.block_arg_of is not None: + return v.block_arg_of + if v.defined_by is not None: + return v.defined_by.parent + return None + + def _check_scope_block(blk: MirBlock) -> None: + for item in blk.items: + if isinstance(item, MirInstr): + ancestors = _block_chain(blk) + ancestor_set = {id(b) for b in ancestors} + for i, op in enumerate(item.operands): + if not isinstance(op, MirValue): + continue + if op.is_function_const: + continue + db = _def_block(op) + if db is None: + continue # already handled + if id(db) not in ancestor_set: + raise MirVerifyError( + f"{item.opcode} operand[{i}]: SSA value " + f"{op!r} is defined in block " + f"{db.name!r} which is NOT an ancestor " + f"of the use site block {blk.name!r}. " + f"Cross-scope use violates SSA/SCF " + f"dominance. Producer should yield this " + f"value via a (TODO) loop-result mechanism " + f"or hoist the def to a common ancestor." + ) + elif isinstance(item, MirLoop): + for body_blk in item.body: + _check_scope_block(body_blk) + + for blk in fn.blocks: + _check_scope_block(blk) + + +def _check_operand_kind( + instr: MirInstr, i: int, op: MirOperand, kind: str, +) -> None: + if kind == "i32": + if isinstance(op, MirValue): + if op.dtype != "i32": + raise MirVerifyError( + f"{instr.opcode} operand[{i}]: expected i32 SSA " + f"value, got {op!r} (dtype={op.dtype!r})" + ) + return + if isinstance(op, (int, tir.IntImm)): + return + raise MirVerifyError( + f"{instr.opcode} operand[{i}]: expected i32 SSA value or " + f"int literal; got {type(op).__name__} {op!r}" + ) + if kind == "addr_reg": + if not isinstance(op, MirValue) or op.dtype != "addr_reg": + raise MirVerifyError( + f"{instr.opcode} operand[{i}]: expected addr_reg SSA " + f"value; got {op!r}" + ) + return + if kind == "fp_reg": + if not isinstance(op, str): + raise MirVerifyError( + f"{instr.opcode} operand[{i}]: expected fp_reg verbatim " + f"string (e.g. 'f1'); got {type(op).__name__} {op!r}" + ) + return + if kind == "literal_int": + if not isinstance(op, (int, tir.IntImm)): + raise MirVerifyError( + f"{instr.opcode} operand[{i}]: expected literal_int; " + f"got {type(op).__name__} {op!r}" + ) + return + if kind == "verbatim_str": + if not isinstance(op, str): + raise MirVerifyError( + f"{instr.opcode} operand[{i}]: expected verbatim_str; " + f"got {type(op).__name__} {op!r}" + ) + return + raise MirVerifyError( + f"{instr.opcode} operand[{i}]: unknown operand kind {kind!r}" + ) + + +__all__ = [ + "MirValue", "MirInstr", "MirBlock", "MirLoop", "MirFunction", + "MirOperand", "MirVerifyError", + "OPCODES", "_OpcodeSpec", + "format_mir", "verify", +] diff --git a/tilelang_tvm_compiler/mir_passes.py b/tilelang_tvm_compiler/mir_passes.py new file mode 100644 index 0000000..2990397 --- /dev/null +++ b/tilelang_tvm_compiler/mir_passes.py @@ -0,0 +1,1344 @@ +"""MIR optimisation passes. + +Each pass takes a :class:`mir.MirFunction` and mutates it in place; +returns ``True`` if it changed anything (so a driver can iterate to +a fixed point). Passes are independent; ``run_default_pipeline`` ties +them together in a sensible order. + +Implemented passes +------------------ + +* :func:`dead_loop_elim` — eliminate ``extent <= 1`` loops. ``extent + == 0`` deletes the whole loop; ``extent == 1`` peels the body up + one level, replacing ``loop_var`` uses with ``IntImm(init)``. + +* :func:`const_fold` — fold instructions whose i32 inputs are all + compile-time constants (``IntImm``). Substitutes the resulting + constant value into every use, removes the now-dead instruction. + Handles the same set the runtime ALU ops cover (S_ADDI/SLLI/SRLI/ + ADD/SUB/MUL). + +* :func:`dce` — dead-code elimination. Drops any non-side-effecting + MirInstr whose result has no users. Side-effecting opcodes + (memory writes, control-register sets, HW ops) are kept regardless. + +* :func:`cse` — common subexpression elimination within a block. + Two MirInstrs with the same opcode and operand identities collapse + to one; the duplicate's result has its uses redirected to the first. + +* :func:`reassociate` — flatten chains of ``S_ADD_INT`` / + ``S_ADDI_INT`` into multi-term sums, canonicalise the term lists, + fold IntImm terms, then rebuild as left-associative chains that + share the longest common prefix with any already-existing chain. + This is what lets two address PrimExprs like + ``mat = head*hlen + oc*blen + base`` and + ``result = orow*S + head*hlen + oc*blen + base`` collapse so that + ``result`` literally becomes ``S_ADD_INT %mat, %orow_term`` + instead of recomputing four operands from scratch. + +* :func:`run_default_pipeline` — runs DLE → const_fold → DCE → + reassociate → CSE → LICM(optional) to a fixed point (or until + ``max_iters`` is reached). +""" + +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple + +from tvm import tir + +from . import mir + + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- + +def _replace_all_uses(old: mir.MirValue, new: "mir.MirOperand") -> None: + """Replace every use of ``old`` with ``new`` across the function. + + ``new`` can be any MirOperand (MirValue, int, IntImm, str). We + walk ``old.used_by`` (a snapshot to avoid mutation-during-iter) + and ask each using instr to swap. Each ``set_operand`` call + updates ``new.used_by`` (if new is a MirValue) and removes the + using instr from ``old.used_by`` automatically. + """ + for user in list(old.used_by): + for i, op in enumerate(list(user.operands)): + if op is old: + user.set_operand(i, new) + + +def _is_int_const(op: "mir.MirOperand") -> bool: + """True iff ``op`` represents a compile-time integer constant. + + Three forms count: + * Python ``int`` + * ``tir.IntImm`` + * function-level constants (``gp0`` = 0) + * a MirValue produced by ``S_ADDI_INT gp0, K`` (the canonical + "materialise constant" idiom). This makes const-prop + transitive: the result of a previous fold becomes a const + input for the next. + """ + if isinstance(op, int): + return True + if isinstance(op, tir.IntImm): + return True + if isinstance(op, mir.MirValue): + if op.is_function_const: + return True + d = op.defined_by + if (d is not None and d.opcode == "S_ADDI_INT" + and isinstance(d.operands[0], mir.MirValue) + and d.operands[0].is_function_const + and isinstance(d.operands[1], int)): + return True + return False + + +def _int_value(op: "mir.MirOperand") -> int: + if isinstance(op, int): + return int(op) + if isinstance(op, tir.IntImm): + return int(op.value) + if isinstance(op, mir.MirValue): + if op.is_function_const: + return 0 + d = op.defined_by + if (d is not None and d.opcode == "S_ADDI_INT" + and isinstance(d.operands[0], mir.MirValue) + and d.operands[0].is_function_const + and isinstance(d.operands[1], int)): + return int(d.operands[1]) + raise TypeError(f"_int_value: not a constant: {op!r}") + + +def _is_side_effecting(instr: mir.MirInstr) -> bool: + """True iff this instr has effects beyond its SSA result. + + Anything writing to memory (S_ST_*, M_*_WO, H_STORE_V), binding a + control register (C_SET_*), launching a HW kernel (M_MM, M_BTMM, + H_PREFETCH_*, V_RED_*, M_TMM …) — none of those are safe to drop + by DCE no matter how unused their (often void) result is. We + err on the side of caution: keep anything NOT in the known-pure + set. + + Pure: scalar ALU ops that compute one i32 from inputs (no I/O, + no register-file bind). Their result is the *only* observable + effect. + """ + pure = { + "S_ADDI_INT", "S_SLLI_INT", "S_SRLI_INT", "S_LUI_INT", + "S_SLL_INT", "S_SRL_INT", + "S_ADD_INT", "S_SUB_INT", "S_MUL_INT", + } + return instr.opcode not in pure + + +def _walk_blocks(fn: mir.MirFunction): + """Yield every MirBlock in the function, recursing into loop + bodies. Order: outer block first, then nested bodies in source + order. Sufficient for passes that just need "every block".""" + def _recurse(blk: mir.MirBlock): + yield blk + for item in blk.items: + if isinstance(item, mir.MirLoop): + for b in item.body: + yield from _recurse(b) + for blk in fn.blocks: + yield from _recurse(blk) + + +def _walk_blocks_with_parent_item_lists(fn: mir.MirFunction): + """Yield every (block, items_list_owning_the_block, block_index) + triple. Used by DLE to splice a loop's body items back into the + parent items list. For top-level blocks, ``items_list_owning_the_block`` + is ``fn.blocks`` and ``block_index`` is its index there.""" + # Top-level blocks live in fn.blocks. + for i, blk in enumerate(fn.blocks): + yield blk, fn.blocks, i + # Recurse into loops inside this block. + yield from _recurse_loops_in_block(blk) + + +def _recurse_loops_in_block(blk: mir.MirBlock): + for item in blk.items: + if isinstance(item, mir.MirLoop): + for i, body_blk in enumerate(item.body): + yield body_blk, item.body, i + yield from _recurse_loops_in_block(body_blk) + + +# --------------------------------------------------------------------- +# Pass 1: dead loop elimination +# --------------------------------------------------------------------- + +def dead_loop_elim(fn: mir.MirFunction) -> bool: + """Eliminate ``extent <= 1`` loops. + + Strategy: walk every block; for each loop child whose extent is + 0 (delete) or 1 (peel), rewrite in place. Peeling moves the + loop body's items into the parent block at the loop's old + position, substituting ``IntImm(init)`` for every use of + ``loop_var``. + + Caveat: the body block's own ``arguments`` list (which contained + ``loop_var``) is discarded as the items splice up — the parent + block is the new owner, and the loop_var has been RAUW'd away + so nobody references it any more. + + Returns True if any loop was removed/peeled. + """ + changed = False + + def _process_block(parent_blk: mir.MirBlock) -> None: + """Rewrite ``parent_blk.items`` in place. We pass the OWNING + block (not just the items list) so we can re-wire spliced + items' ``parent``/``parent_block`` to the new home.""" + nonlocal changed + items = parent_blk.items + i = 0 + while i < len(items): + it = items[i] + if not isinstance(it, mir.MirLoop): + i += 1 + continue + lp = it + # Recurse inside-out — peeling an inner extent-1 loop + # exposes the items in its parent (= this lp's body), + # which may itself become peelable in a later iteration. + for body_blk in lp.body: + _process_block(body_blk) + + if lp.extent <= 0: + for use in list(lp.loop_var.used_by): + raise mir.MirVerifyError( + f"DLE: cannot delete loop {lp.name!r} " + f"(extent=0): loop_var still has user " + f"{use!r} outside the body" + ) + del items[i] + changed = True + continue + + if lp.extent == 1: + if len(lp.body) != 1: + i += 1 + continue + body = lp.body[0] + init_const = tir.IntImm("int32", int(lp.init)) + _replace_all_uses(lp.loop_var, init_const) + lp.loop_var.block_arg_of = None + # Splice body items into parent_blk at position + # ``i`` (replacing the loop). Set each spliced + # item's parent pointer to ``parent_blk`` — they + # now live there, not in the discarded body + # block. Without this rewire, a child loop's + # ``parent_block`` would point at the dead body + # block and any later scope walk would see a + # broken chain (the dominance check would fail + # to recognise valid uses). + spliced = list(body.items) + for sub in spliced: + if isinstance(sub, mir.MirInstr): + sub.parent = parent_blk + elif isinstance(sub, mir.MirLoop): + sub.parent_block = parent_blk + items[i:i + 1] = spliced + changed = True + continue + + i += 1 + + for blk in fn.blocks: + _process_block(blk) + return changed + + +# --------------------------------------------------------------------- +# Pass 2: constant folding +# --------------------------------------------------------------------- + +def const_fold(fn: mir.MirFunction) -> bool: + """Fold pure ALU instructions whose i32 inputs are all integer + constants. The instruction's operands are rewritten in place: + when both inputs are const, the instr is collapsed into the + canonical "materialise constant" form ``S_ADDI_INT gp0, K`` + (or rewritten to be a ``S_LUI_INT`` + ``S_ADDI_INT`` chain when + K exceeds the 18-bit immediate range — but for now we just emit + a single S_ADDI_INT and rely on a later pass to handle wide + immediates). + + Why rewrite the SAME instr instead of RAUW'ing an IntImm? The + MIR i32 operand kind requires a MirValue at every use site + (the emit layer turns it into ``gp{N}``). A bare IntImm in an + i32 slot has no GP and cannot be emitted. By rewriting the + folding instr to ``S_ADDI_INT gp0, K``, the value's identity + is preserved (same MirValue, same uses), but its definition + becomes a cheap "load constant into GP" — and any number of + folded ADDIs producing the same K collapse to one via CSE. + + Recognised opcodes (the pure set): + + * ``S_ADDI_INT (x, k)`` → x + k (k is literal_int) + * ``S_SLLI_INT (x, k)`` → x << k + * ``S_SRLI_INT (x, k)`` → x >> k + * ``S_LUI_INT (k)`` → k << 12 + * ``S_SLL_INT (x, y)`` → x << y + * ``S_SRL_INT (x, y)`` → x >> y + * ``S_ADD_INT (x, y)`` → x + y + * ``S_SUB_INT (x, y)`` → x - y + * ``S_MUL_INT (x, y)`` → x * y + + Returns True if any instruction was rewritten. + """ + changed = False + + foldable = {"S_ADDI_INT", "S_SLLI_INT", "S_SRLI_INT", "S_LUI_INT", + "S_SLL_INT", "S_SRL_INT", + "S_ADD_INT", "S_SUB_INT", "S_MUL_INT"} + + def _try_fold(instr: mir.MirInstr) -> Optional[int]: + op = instr.opcode + if op not in foldable: + return None + ops = instr.operands + if op == "S_LUI_INT": + if not _is_int_const(ops[0]): + return None + return _int_value(ops[0]) << 12 + # All others: 2 operands. The first is always i32, the + # second varies (literal_int for immediates, i32 for reg). + if not (_is_int_const(ops[0]) and _is_int_const(ops[1])): + return None + a, b = _int_value(ops[0]), _int_value(ops[1]) + if op == "S_ADDI_INT" or op == "S_ADD_INT": + return a + b + if op == "S_SUB_INT": + return a - b + if op == "S_MUL_INT": + return a * b + if op == "S_SLLI_INT" or op == "S_SLL_INT": + return a << b + if op == "S_SRLI_INT" or op == "S_SRL_INT": + return a >> b + return None + + # Iterate to a local fixed point. Each round we try BOTH + # constant folding and algebraic identity peepholes: + # + # * ``S_ADD x, gp0`` / ``S_ADD gp0, x`` → x + # * ``S_SUB x, gp0`` → x + # * ``S_MUL gp0, _`` / ``S_MUL _, gp0`` → 0 → gp0 + # * ``S_SLLI gp0, _`` / ``S_SRLI gp0, _`` → gp0 + # * folded result == 0 → RAUW to gp0 (don't materialise zero) + # + # These identity hits are common after the per-PreIsaOp address + # PrimExpr lowering: many sub-expressions reduce to additions + # against gp0 (e.g. ``0 * stride + offset``). + def _identity_rewrite(instr: mir.MirInstr) -> bool: + """Return True if we RAUW'd the result to a simpler operand + (caller marks ``changed`` and continues — the instr's now + dead; DCE sweeps).""" + op = instr.opcode + ops = instr.operands + gp0 = fn.gp0_value + + def _is_gp0(o): + return isinstance(o, mir.MirValue) and o.is_function_const + + if op == "S_ADD_INT": + if _is_gp0(ops[0]): + _replace_all_uses(instr.result, ops[1]) + return True + if _is_gp0(ops[1]): + _replace_all_uses(instr.result, ops[0]) + return True + elif op == "S_SUB_INT": + if _is_gp0(ops[1]): + _replace_all_uses(instr.result, ops[0]) + return True + elif op == "S_MUL_INT": + if _is_gp0(ops[0]) or _is_gp0(ops[1]): + _replace_all_uses(instr.result, gp0) + return True + elif op in ("S_SLLI_INT", "S_SRLI_INT", "S_SLL_INT", "S_SRL_INT"): + if _is_gp0(ops[0]): + _replace_all_uses(instr.result, gp0) + return True + elif op == "S_ADDI_INT": + # ``ADDI gp0, 0`` is gp0 itself. + if _is_gp0(ops[0]) and isinstance(ops[1], int) and ops[1] == 0: + _replace_all_uses(instr.result, gp0) + return True + # ``ADDI x, 0`` is x. + if isinstance(ops[1], int) and ops[1] == 0: + _replace_all_uses(instr.result, ops[0]) + return True + return False + + for _iteration in range(1000): + any_this_round = False + for instr in fn.walk_instrs(): + if instr.result is None: + continue + # Identity peephole first — cheaper, and tends to + # expose more folding opportunities (a gp0 propagating + # through ADD chains). + if _identity_rewrite(instr): + any_this_round = True + changed = True + continue + folded = _try_fold(instr) + if folded is None: + continue + # If folded value is 0, prefer RAUW to gp0 over + # materialising an ``ADDI gp0, 0``. + if folded == 0: + _replace_all_uses(instr.result, fn.gp0_value) + # detach operands + for op in list(instr.operands): + if isinstance(op, mir.MirValue): + try: + op.used_by.remove(instr) + except ValueError: + pass + any_this_round = True + changed = True + continue + # Wide immediates (>= 2^17 abs) need S_LUI+ADDI; leave + # them as-is for now (the original PrimExpr lowering + # already handled wide imms; if we hit one here it + # came from a fold and we don't want to over-eagerly + # rewrite). + if not (-(1 << 17) <= folded < (1 << 17)): + continue + # Detach old operands. + for op in list(instr.operands): + if isinstance(op, mir.MirValue): + try: + op.used_by.remove(instr) + except ValueError: + pass + # Rewrite this instr to ``S_ADDI_INT gp0, folded``. + instr.opcode = "S_ADDI_INT" + instr.operands = [fn.gp0_value, int(folded)] + fn.gp0_value.used_by.append(instr) + any_this_round = True + changed = True + if not any_this_round: + break + return changed + + +# --------------------------------------------------------------------- +# Pass 3: dead code elimination +# --------------------------------------------------------------------- + +def dce(fn: mir.MirFunction) -> bool: + """Drop pure MirInstrs whose result has no users. + + A "pure" instr is one whose only effect is producing its SSA + result (no memory write, no control-register bind, no HW kernel + issue). See ``_is_side_effecting`` for the keep-set. + + Iterates to fixed point — removing one instr may make its + operands' producers dead too. + + Returns True if any instruction was removed. + """ + changed = False + + for _iteration in range(1000): + any_this_round = False + for blk in _walk_blocks(fn): + i = 0 + while i < len(blk.items): + it = blk.items[i] + if not isinstance(it, mir.MirInstr): + i += 1 + continue + if it.result is None: + i += 1 + continue + if _is_side_effecting(it): + i += 1 + continue + if it.result.used_by: + i += 1 + continue + # Drop. First sever this instr's uses on its + # operands (so their used_by lists don't hold + # dangling pointers). + for op in list(it.operands): + if isinstance(op, mir.MirValue): + try: + op.used_by.remove(it) + except ValueError: + pass + # Drop the result's defined_by binding too — + # the MirValue becomes orphaned and unreachable. + it.result.defined_by = None + del blk.items[i] + any_this_round = True + changed = True + # don't advance i — items shifted left + if not any_this_round: + break + return changed + + +# --------------------------------------------------------------------- +# Pass 4: common subexpression elimination (intra-block) +# --------------------------------------------------------------------- + +def cse(fn: mir.MirFunction) -> bool: + """Within each block, collapse duplicate pure expressions. + + Two MirInstrs are duplicates when they have: + * the same opcode, + * the same number of operands, and + * each operand pair is either identical Python int / IntImm + (by value), the same str token, or the same MirValue (by + identity). + + The FIRST occurrence wins; later occurrences have their result's + uses redirected to the first occurrence's result, then the + duplicate is dropped. + + Only pure instructions are eligible (CSE'ing a memory write or + control-register set would change program behaviour). + + Block-local only: a value defined in an outer block can be CSE'd + against another defined in an inner block, but we don't yet + move definitions across block boundaries (that's LICM). The + common case — repeated address-base computation inside a single + loop body — is fully handled here. + + Returns True if any instruction was eliminated. + """ + changed = False + + def _operand_key(op): + if isinstance(op, mir.MirValue): + return ("v", id(op)) + if isinstance(op, tir.IntImm): + return ("i", int(op.value)) + if isinstance(op, int): + return ("i", int(op)) + if isinstance(op, str): + return ("s", op) + return ("o", repr(op)) + + def _process_block(blk: mir.MirBlock): + nonlocal changed + # opcode + operand-keys-tuple → first MirInstr that + # produced it. Reset per block (no cross-block CSE + # without dominator info). + seen = {} + i = 0 + while i < len(blk.items): + it = blk.items[i] + if not isinstance(it, mir.MirInstr): + i += 1 + # Loops are entered separately below. + continue + if it.result is None or _is_side_effecting(it): + i += 1 + continue + key = ( + it.opcode, + tuple(_operand_key(o) for o in it.operands), + ) + if key in seen: + first = seen[key] + # Redirect uses of this duplicate to the first + # producer's result. + _replace_all_uses(it.result, first.result) + # Drop this instr. + for op in list(it.operands): + if isinstance(op, mir.MirValue): + try: + op.used_by.remove(it) + except ValueError: + pass + it.result.defined_by = None + del blk.items[i] + changed = True + continue + seen[key] = it + i += 1 + # Recurse into nested loops. + for it in blk.items: + if isinstance(it, mir.MirLoop): + for body_blk in it.body: + _process_block(body_blk) + + for blk in fn.blocks: + _process_block(blk) + return changed + + +# --------------------------------------------------------------------- +# Pass 5: loop-invariant code motion (LICM) +# --------------------------------------------------------------------- + +def licm(fn: mir.MirFunction) -> bool: + """Hoist loop-invariant pure instructions out of their enclosing + loop. + + An instruction is invariant in loop L if every MirValue operand + is defined OUTSIDE L's body (i.e. in an ancestor block) or is a + block argument of an ancestor block / function-level constant. + Only pure instructions (see ``_is_side_effecting``) are eligible; + hoisting a HW kernel issue or a control-register set would + change program semantics. + + Hoisting strategy: move the invariant instr from its body + position to the parent block, immediately before the loop itself. + The MIR scope rules guarantee this is safe — the parent block is + an ancestor of every body block, so all SSA uses still see a + valid def. (PLENA MIR has no MirIf, so there are no branch + paths to confuse this.) + + Iterates to fixed point — hoisting an inner-loop instr can + expose an outer-loop invariant. + """ + changed = False + + def _is_invariant(instr: mir.MirInstr, + forbidden_defs: set) -> bool: + """``instr`` is invariant w.r.t. some loop iff none of its + operand MirValues were defined inside that loop (i.e. + ``forbidden_defs`` contains every MirValue defined inside + the loop's body, including the loop_var).""" + if _is_side_effecting(instr): + return False + for op in instr.operands: + if not isinstance(op, mir.MirValue): + continue # literals are always invariant + if id(op) in forbidden_defs: + return False + return True + + def _defs_inside(blk: mir.MirBlock, out: set) -> None: + """Collect ``id()`` of every MirValue defined inside this + block subtree (block args + instr results + nested loop + body args + nested instr results).""" + for arg in blk.arguments: + out.add(id(arg)) + for item in blk.items: + if isinstance(item, mir.MirInstr): + if item.result is not None: + out.add(id(item.result)) + elif isinstance(item, mir.MirLoop): + for b in item.body: + _defs_inside(b, out) + + def _process_loop_in(parent_items: List, loop_idx: int) -> None: + """Hoist any invariant instrs from ``parent_items[loop_idx]`` + (a MirLoop) up into ``parent_items`` just before it.""" + nonlocal changed + lp = parent_items[loop_idx] + # Collect defs inside the loop. We re-collect after each + # hoist (a hoisted instr's result no longer counts as + # inside). + while True: + forbidden = set() + for b in lp.body: + _defs_inside(b, forbidden) + # Try to find ONE invariant instr in this iteration; + # if found, hoist it and recompute forbidden. + hoisted = False + for b in lp.body: + for i, it in enumerate(b.items): + if not isinstance(it, mir.MirInstr): + continue + if _is_invariant(it, forbidden): + # Remove from body. + del b.items[i] + # Insert in parent_items just before lp. + # parent_blk is the block that owns ``lp``; + # the hoisted instr now lives there, so wire + # its ``parent`` accordingly. Without this, + # verify's dominance check (which looks at + # ``defined_by.parent``) silently sees None + # and skips the check — a false-negative + # rather than a real OK. + parent_items.insert(loop_idx, it) + it.parent = lp.parent_block + loop_idx += 1 # lp shifted right by 1 + hoisted = True + changed = True + break + if hoisted: + break + if not hoisted: + break + # Recurse into nested loops AFTER the outer is stable — + # an inner-loop hoist that puts a def into THIS loop's + # body doesn't make it invariant for THIS loop (the inner + # is still inside), so order doesn't matter much; we + # process inside-out by recursing here. + for b in lp.body: + i = 0 + while i < len(b.items): + it = b.items[i] + if isinstance(it, mir.MirLoop): + _process_loop_in(b.items, i) + i += 1 + + def _process_block(blk: mir.MirBlock) -> None: + i = 0 + while i < len(blk.items): + it = blk.items[i] + if isinstance(it, mir.MirLoop): + _process_loop_in(blk.items, i) + i += 1 + + for blk in fn.blocks: + _process_block(blk) + return changed + + +# --------------------------------------------------------------------- +# Pass 6: reassociate ADD chains for sharing +# --------------------------------------------------------------------- + +def reassociate(fn: mir.MirFunction) -> bool: + """Canonicalise ``S_ADD_INT`` / ``S_ADDI_INT`` chains so that + chains with overlapping term sets share their common prefix. + + A "chain" is the transitive closure of ``S_ADD_INT`` / + ``S_ADDI_INT`` instrs whose operands are themselves chain + results. A *leaf* is any operand that is NOT a chain result: a + block argument, an instruction with a non-additive opcode, an + IntImm, gp0, etc. By recursively unfolding chain operands we + rewrite each chain instr's "view" as a multiset of leaf terms + + one folded IntImm constant. + + Sharing rule: any two chains with the same canonical leaf list + collapse to the same MirValue (RAUW on duplicate result). Two + chains where one's leaf list is a **prefix** of the other's + share the partial-sum value: the longer chain's result becomes + ``S_ADD_INT shorter_result, extra_leaf``. + + Caveats: + * An interior chain instr with multiple users is treated as + an opaque leaf — splitting it would force its other users + to also restructure (and would lose CSE if they were a + chain too). It can still ANCHOR other chains as a leaf, + which is what enables the "result = mat + orow_term" case + without breaking ``mat`` itself. + * IntImm folding here covers the same arithmetic + ``const_fold`` does on a per-instr basis, but with the + full multiset visible — so ``(((x+0)+3)+5)+x`` reduces to + ``2*x+8`` (well, we don't synthesise multiplication, but + we DO collapse the two ``x`` and the ``0+3+5`` into a + canonical form that CSE then dedups across chains). + + This is one outer iteration: rebuild chains greedily, + largest-first (so longer chains get a chance to anchor the + "shared prefix" slot before shorter chains decide what to + point at). Outer driver re-runs the rest of the pipeline + + this pass to a fixed point. + + Returns True if any chain was rewritten. + """ + changed = False + + def _is_add_instr(instr) -> bool: + return (instr is not None + and instr.opcode in ( + "S_ADD_INT", "S_ADDI_INT", "S_SUB_INT", + )) + + def _is_chain_result(op) -> bool: + """``op`` is a MirValue defined by an ADD/ADDI/SUB instr + with EXACTLY ONE user. Multi-user chain instrs are + barriers — absorbing them would orphan the other users.""" + if not isinstance(op, mir.MirValue): + return False + d = op.defined_by + if d is None: + return False + if not _is_add_instr(d): + return False + if len(op.used_by) != 1: + return False + return True + + # ---- Phase 1: flatten ---- + # + # Each chain instr collapses into ``(signed_leaves, const)`` where + # ``signed_leaves`` is a list of ``(sign, MirValue)`` and + # ``const`` is the (signed) integer sum. ``S_SUB_INT a, b`` is + # treated as ``+a + (-b)`` so its leaves get split into a + # positive group (from a) and a negative group (from b). This + # turns SUB into just-another-ADD with negated terms, letting + # the rest of the algorithm (canonical sort + prefix cache) + # work unchanged. Reconstruction (Phase 2) re-emits SUBs only + # if any leaf has sign = -1. + # + # When absorbing a sub-chain whose head sign is ``s``, we visit + # its operands with sign ``s`` for the first operand and ``s`` + # for the second (ADD/ADDI) or ``-s`` (SUB second operand). + + def _flatten(instr) -> Tuple[List[Tuple[int, "mir.MirValue"]], int]: + """Return ``(signed_leaves, signed_const)``. ``signed_leaves`` + is a list of ``(sign ∈ {+1,-1}, MirValue)`` sorted by + (sign, name) for a stable canonical form.""" + leaves: List[Tuple[int, "mir.MirValue"]] = [] + const = 0 + # Worklist of (sign, operand, opcode-of-parent, operand-index) + # — opcode + index lets us flip the sign on SUB's second + # operand. We seed with ``instr``'s operands directly. + work: List[Tuple[int, "mir.MirOperand"]] = [] + # For the top-level instr: each operand inherits sign +1, + # except the second operand of S_SUB_INT flips to -1. + for i, op in enumerate(instr.operands): + sign = 1 + if instr.opcode == "S_SUB_INT" and i == 1: + sign = -1 + work.append((sign, op)) + + while work: + sign, op = work.pop() + if isinstance(op, int): + const += sign * op + continue + if isinstance(op, tir.IntImm): + const += sign * int(op.value) + continue + if isinstance(op, mir.MirValue): + if op.is_function_const: + continue + if _is_chain_result(op): + d = op.defined_by + # Absorb: push d's operands with the right + # signs given the outer ``sign``. + for j, sub_op in enumerate(d.operands): + sub_sign = sign + if d.opcode == "S_SUB_INT" and j == 1: + sub_sign = -sign + work.append((sub_sign, sub_op)) + continue + leaves.append((sign, op)) + continue + # Unknown — drop silently. + # Cancel matching +/- pairs of the SAME MirValue. + # ``%x + %x - %x`` reduces to ``%x``. + canon: Dict[int, int] = {} # id(v) -> net sign sum + canon_val: Dict[int, "mir.MirValue"] = {} + for s, v in leaves: + canon[id(v)] = canon.get(id(v), 0) + s + canon_val[id(v)] = v + out: List[Tuple[int, "mir.MirValue"]] = [] + for vid, net in canon.items(): + v = canon_val[vid] + # net of +2 / -2 / etc. could be expressed as + # multiplication; we don't synth muls. Keep |net| as + # repeated additions/subtractions if the magnitude is + # small (1 is by far the common case). For >|1| we + # just expand into |net| copies — rare in address code. + if net == 0: + continue + mag = abs(net) + sub_sign = 1 if net > 0 else -1 + for _ in range(mag): + out.append((sub_sign, v)) + out.sort(key=lambda sv: (sv[0], sv[1].name)) + return out, const + + # ---- Phase 2: per-block rebuild ---- + def _process_block(blk: mir.MirBlock) -> None: + nonlocal changed + # Cache: (tuple of (sign, leaf-id), const) → MirValue. + prefix_cache: Dict[Tuple, "mir.MirValue"] = {} + + chain_entries = [] + for idx, it in enumerate(blk.items): + if isinstance(it, mir.MirInstr) and _is_add_instr(it): + leaves, const = _flatten(it) + chain_entries.append((idx, it, leaves, const)) + + if not chain_entries: + for it in blk.items: + if isinstance(it, mir.MirLoop): + for b in it.body: + _process_block(b) + return + + def _key(signed_leaves, const): + return (tuple((s, id(v)) for s, v in signed_leaves), const) + + for _, instr, leaves, const in chain_entries: + if instr not in blk.items: + continue + + k = _key(leaves, const) + if k in prefix_cache: + cached = prefix_cache[k] + if cached is instr.result: + continue + _replace_all_uses(instr.result, cached) + changed = True + continue + + # Reconstruction strategy. + # + # signed_leaves is sorted (sign, name). Sign +1 first, + # then -1. We: + # 1. Build the positive sub-sum first using the same + # prefix-cache trick (longest matching pos prefix + # reuses an existing partial sum). + # 2. For each negative leaf, emit S_SUB_INT acc, leaf. + # 3. Finally fold ``const`` into the accumulator via + # S_ADDI_INT acc, const (PLENA ADDI takes a + # signed 18-bit immediate, so negative const works + # too). + pos_leaves = [v for s, v in leaves if s > 0] + neg_leaves = [v for s, v in leaves if s < 0] + + # ---- find prefix cache hit for pos_leaves ---- + best_prefix = None + best_prefix_value = None + best_prefix_const_in = False + for n in range(len(pos_leaves), 0, -1): + # Pos prefix means signed leaves of the prefix are + # all (+1, v). The cache key for a positive-only + # prefix uses sign tag +1. + pref_signed = [(1, v) for v in pos_leaves[:n]] + sub_key_full = _key(pref_signed, const) + if sub_key_full in prefix_cache: + best_prefix = pos_leaves[:n] + best_prefix_value = prefix_cache[sub_key_full] + best_prefix_const_in = True + break + sub_key_no_c = _key(pref_signed, 0) + if sub_key_no_c in prefix_cache: + best_prefix = pos_leaves[:n] + best_prefix_value = prefix_cache[sub_key_no_c] + best_prefix_const_in = False + break + + if best_prefix is not None: + start_value = best_prefix_value + remaining_pos = pos_leaves[len(best_prefix):] + pending_const = 0 if best_prefix_const_in else const + else: + # No pos-prefix match; start from scratch. + if not pos_leaves: + if not neg_leaves: + # All-const sum — const_fold should have + # handled it; skip. + continue + # Pure negative sum: ``0 - n1 - n2 ...``. Start + # with the negative of the first neg_leaf? PLENA + # has no NEG; emit as ``S_SUB_INT gp0, neg[0]``, + # then SUB the rest, then ADDI const. + start_value = fn.gp0_value + remaining_pos = [] + pending_const = const + else: + start_value = pos_leaves[0] + remaining_pos = pos_leaves[1:] + pending_const = const + + # ---- build tail ---- + # Each tail element is (opcode, second_operand) where: + # ("S_ADD_INT", MirValue) — pos leaf + # ("S_SUB_INT", MirValue) — neg leaf + # ("S_ADDI_INT", int) — pending const fold + tail = [] + for v in remaining_pos: + tail.append(("S_ADD_INT", v)) + for v in neg_leaves: + tail.append(("S_SUB_INT", v)) + if pending_const != 0: + # Check range; out of range → don't fold here, drop. + if -(1 << 17) <= pending_const < (1 << 17): + tail.append(("S_ADDI_INT", int(pending_const))) + else: + # Wide const — leave the original instr alone. + # (Rare; const_fold handles ordinary 18-bit + # range. Pathological case skip.) + continue + + insert_idx = blk.items.index(instr) + accum = start_value + + def _emit_partial(opcode, src, second): + nonlocal insert_idx + dst = fn.mint_value("i32") + new_instr = mir.MirInstr( + opcode=opcode, operands=[src, second], result=dst, + ) + new_instr.parent = blk + blk.items.insert(insert_idx, new_instr) + insert_idx += 1 + return dst + + if not tail: + _replace_all_uses(instr.result, start_value) + prefix_cache[k] = start_value + changed = True + continue + + # Track running canonical key for caching partial sums. + # We mirror the same canonical form _flatten produces: + # sorted by (sign, name). So the partial-sum at each + # tail step is the input prefix + leaves consumed by + # tail so far. Easiest: just rebuild the canonical + # signed-leaf list for the current accum. + cur_signed = [(1, v) for v in (best_prefix + if best_prefix is not None + else ( + [pos_leaves[0]] + if pos_leaves else [] + ))] + cur_const = const if best_prefix_const_in else 0 + + for step_i, (step_op, step_arg) in enumerate(tail): + last = (step_i == len(tail) - 1) + if last: + # Reuse instr. + for op in list(instr.operands): + if isinstance(op, mir.MirValue): + try: + op.used_by.remove(instr) + except ValueError: + pass + instr.opcode = step_op + instr.operands = [accum, step_arg] + if isinstance(accum, mir.MirValue): + accum.used_by.append(instr) + if isinstance(step_arg, mir.MirValue): + step_arg.used_by.append(instr) + new_val = instr.result + else: + new_val = _emit_partial(step_op, accum, step_arg) + accum = new_val + + # Update the running canonical state + cache the + # partial sum so later chains find it. + if step_op == "S_ADD_INT": + cur_signed = sorted( + cur_signed + [(1, step_arg)], + key=lambda sv: (sv[0], sv[1].name), + ) + elif step_op == "S_SUB_INT": + cur_signed = sorted( + cur_signed + [(-1, step_arg)], + key=lambda sv: (sv[0], sv[1].name), + ) + elif step_op == "S_ADDI_INT": + cur_const += int(step_arg) + # Cache this partial sum's key for downstream reuse. + prefix_cache[(tuple((s, id(v)) for s, v in cur_signed), + cur_const)] = accum + + # Final partial-sum cache entry already wrote the full + # canonical key on the last loop iteration. Belt-and- + # braces: re-cache the original (k) form too. By the + # construction above (last tail step reuses ``instr``), + # ``accum is instr.result``. + assert accum is instr.result, ( + f"reassoc: tail's last step should reuse instr; " + f"got accum={accum!r} vs instr.result={instr.result!r}" + ) + prefix_cache[k] = accum + changed = True + + # Recurse into nested loops. + for it in blk.items: + if isinstance(it, mir.MirLoop): + for b in it.body: + _process_block(b) + + for blk in fn.blocks: + _process_block(blk) + return changed + + +# --------------------------------------------------------------------- +# Pass: unroll +# --------------------------------------------------------------------- + +def _clone_operand( + op: "mir.MirOperand", + value_map: Dict["mir.MirValue", "mir.MirOperand"], +) -> "mir.MirOperand": + """Translate one operand under a SSA-value mapping. + + Plain ints / IntImms / verbatim strings are pass-through. MirValue + operands look up ``value_map``; values not in the map (defined + outside the cloned region — outer loop_vars, gp0, function args, + buffer-base SSA values, …) reuse the original. + """ + if isinstance(op, mir.MirValue): + return value_map.get(op, op) + return op + + +def _clone_instr( + instr: mir.MirInstr, + fn: mir.MirFunction, + value_map: Dict["mir.MirValue", "mir.MirOperand"], +) -> mir.MirInstr: + """Clone a MirInstr. The clone gets fresh result MirValues; the + ``value_map`` is updated so later instrs in the same cloned region + pick up the new defs.""" + new_operands = [_clone_operand(o, value_map) for o in instr.operands] + new_result: Optional[mir.MirValue] = None + if instr.result is not None: + old = instr.result + # Mint a fresh SSA value of the same dtype. Preserve a hint + # so the dump stays human-readable across the unroll. + hint = "" + if "_" in old.name: + hint = old.name.split("_", 1)[1] + new_result = fn.mint_value(old.dtype, hint=hint) + value_map[old] = new_result + new_instr = mir.MirInstr(instr.opcode, new_operands, result=new_result) + # Carry over any optimisation-hint annotations (cheap shallow copy). + if instr.annotations: + new_instr.annotations = dict(instr.annotations) + return new_instr + + +def _clone_block( + blk: mir.MirBlock, + fn: mir.MirFunction, + value_map: Dict["mir.MirValue", "mir.MirOperand"], + name_suffix: str, +) -> mir.MirBlock: + """Deep-clone a MirBlock, including any nested MirLoop regions. + Block arguments get fresh MirValues mapped in ``value_map``.""" + new_blk = mir.MirBlock(name=f"{blk.name}{name_suffix}") + for arg in blk.arguments: + # If the caller (e.g. _clone_loop) already minted a fresh + # value for this argument and put it in the map, reuse it — + # the loop_var's MirValue identity must be the same one that + # the enclosing MirLoop holds in ``loop_var``. + if arg in value_map: + existing = value_map[arg] + if not isinstance(existing, mir.MirValue): + raise mir.MirVerifyError( + f"_clone_block: block-arg {arg!r} was pre-mapped " + f"to a non-MirValue {existing!r}" + ) + new_blk.add_argument(existing) + continue + hint = arg.name.split("_", 1)[1] if "_" in arg.name else "" + new_arg = fn.mint_value(arg.dtype, hint=hint) + value_map[arg] = new_arg + new_blk.add_argument(new_arg) + for item in blk.items: + if isinstance(item, mir.MirInstr): + new_blk.append(_clone_instr(item, fn, value_map)) + elif isinstance(item, mir.MirLoop): + new_blk.append(_clone_loop(item, fn, value_map, name_suffix)) + else: + raise TypeError( + f"_clone_block: unexpected item {type(item).__name__}" + ) + return new_blk + + +def _clone_loop( + lp: mir.MirLoop, + fn: mir.MirFunction, + value_map: Dict["mir.MirValue", "mir.MirOperand"], + name_suffix: str, +) -> mir.MirLoop: + """Deep-clone a MirLoop with a fresh loop_var. Used when we clone + an outer unroll body whose body contains a nested loop (typically + a serial inner loop — nested unroll loops are already flattened + by the innermost-first walk in :func:`unroll_loops`).""" + hint = lp.loop_var.name.split("_", 1)[1] if "_" in lp.loop_var.name else "" + new_lvar = fn.mint_value("i32", hint=hint) + value_map[lp.loop_var] = new_lvar + new_lp = mir.MirLoop( + name=f"{lp.name}{name_suffix}", + loop_var=new_lvar, + init=lp.init, + extent=lp.extent, + body=[], + loop_kind=lp.loop_kind, + annotations=dict(lp.annotations), + ) + for body_blk in lp.body: + cloned = _clone_block(body_blk, fn, value_map, name_suffix) + new_lp.add_body_block(cloned) + return new_lp + + +def unroll_loops(fn: mir.MirFunction) -> bool: + """Physically unroll every ``loop_kind == "unroll"`` MirLoop. + + For each unroll loop, the body is cloned ``extent`` times and the + clones are spliced into the parent items list at the loop's old + position. In each clone, every reference to the loop_var is + replaced by the integer iteration value, and every body-local SSA + def is replaced by a fresh MirValue (preserving the SSA single-def + invariant). The loop region itself is then removed. + + Walks innermost-first so that when an outer unroll is processed, + its body is already flat (any inner unroll has been expanded and + constant-folded by the surrounding pipeline run). Nested serial + loops inside an unroll body are deep-cloned per iteration. + + Why integers, not IntImms, for the iter value: a plain ``int`` + travels safely through every i32 operand slot — :func:`const_fold` + recognises it, :func:`mir.verify` accepts it, and the + ``mir_to_isa`` emit layer formats it directly when it lands in an + operand slot. IntImm would force the emit layer to special-case. + + Returns True iff any loop was unrolled. + """ + changed = False + + def _process_block(parent_blk: mir.MirBlock) -> None: + """Innermost-first rewrite of ``parent_blk.items``. Recurses + into every nested loop body before considering whether to + unroll the current loop itself.""" + nonlocal changed + items = parent_blk.items + i = 0 + while i < len(items): + it = items[i] + if not isinstance(it, mir.MirLoop): + i += 1 + continue + lp = it + # Recurse first so any inner unroll loops are flattened + # before we clone the current body. + for body_blk in lp.body: + _process_block(body_blk) + + if lp.loop_kind != "unroll": + i += 1 + continue + + if lp.extent < 0: + raise mir.MirVerifyError( + f"unroll_loops: loop {lp.name!r} has negative " + f"extent {lp.extent}" + ) + if lp.extent == 0: + # Drop the whole loop. Its loop_var must be + # unreferenced outside the body (block argument). + del items[i] + changed = True + continue + if len(lp.body) != 1: + raise mir.MirVerifyError( + f"unroll_loops: loop {lp.name!r} has " + f"{len(lp.body)} body blocks; expected exactly 1" + ) + body = lp.body[0] + spliced: List[mir.MirInstr | mir.MirLoop] = [] + for k in range(lp.extent): + iter_val = int(lp.init) + k + # Fresh per-iter value map. The loop_var maps to a + # plain int; every body-local def will be remapped to + # a fresh MirValue as instrs are cloned. + vmap: Dict[mir.MirValue, mir.MirOperand] = { + lp.loop_var: iter_val, + } + suffix = f"__u{k}" + for sub in body.items: + if isinstance(sub, mir.MirInstr): + new_instr = _clone_instr(sub, fn, vmap) + new_instr.parent = parent_blk + spliced.append(new_instr) + elif isinstance(sub, mir.MirLoop): + new_lp = _clone_loop(sub, fn, vmap, suffix) + new_lp.parent_block = parent_blk + spliced.append(new_lp) + else: + raise TypeError( + f"unroll_loops: unexpected body item " + f"{type(sub).__name__}" + ) + # Splice in place of the loop region. + items[i:i + 1] = spliced + # The old loop_var is no longer referenced (every clone + # used a plain int). Detach the block_arg pointer so + # downstream verify doesn't see a stale arg. + lp.loop_var.block_arg_of = None + changed = True + # Skip over the freshly-spliced items — they're already + # flat and don't contain any further unroll loops to + # process (inner unrolls were handled by the recursion + # at the top of this iteration). + i += len(spliced) + + for blk in fn.blocks: + _process_block(blk) + return changed + + +# --------------------------------------------------------------------- +# Default pipeline +# --------------------------------------------------------------------- + +def run_default_pipeline( + fn: mir.MirFunction, *, max_iters: int = 16, + enable_licm: bool = False, + dump_dir=None, +) -> List[Tuple[str, int]]: + """Run DLE → const_fold → DCE → CSE (→ LICM) to a fixed point. + + LICM is disabled by default. It reduces instruction count but + can increase register pressure beyond what the current + linear-scan allocator can serve (no IntRAM spill yet). Enable + explicitly once spill support lands. + + ``dump_dir`` (when set): after every pass that reports a change, + write the full ``format_mir`` snapshot to + ``/NN_iterM_passname.mir`` (NN is a global step counter). + The pre-opt and final states are dumped too. Lets you watch each + address PrimExpr fold step by step. + + Returns a list of ``(pass_name, run_count)`` tuples for + diagnostics — the number of outer iterations in which each + pass returned True. + """ + step = [0] + + def _dump(tag: str) -> None: + if dump_dir is None: + return + from pathlib import Path + d = Path(dump_dir) + d.mkdir(parents=True, exist_ok=True) + path = d / f"{step[0]:02d}_{tag}.mir" + path.write_text(mir.format_mir(fn)) + step[0] += 1 + + _dump("input") + passes = [ + ("dead_loop_elim", dead_loop_elim), + ("const_fold", const_fold), + ("dce", dce), + ("reassociate", reassociate), + ("cse", cse), + ] + if enable_licm: + passes.append(("licm", licm)) + + counts = {name: 0 for name, _ in passes} + for it in range(max_iters): + any_change = False + for name, fn_pass in passes: + if fn_pass(fn): + counts[name] += 1 + any_change = True + _dump(f"iter{it}_{name}") + if not any_change: + break + _dump("final") + return list(counts.items()) diff --git a/tilelang_tvm_compiler/mir_to_isa.py b/tilelang_tvm_compiler/mir_to_isa.py new file mode 100644 index 0000000..fc3a686 --- /dev/null +++ b/tilelang_tvm_compiler/mir_to_isa.py @@ -0,0 +1,1053 @@ +"""MIR → ISA text emitter. + +A mechanical lowering pass: walk the MirFunction, allocate physical +GP / addr registers to each SSA MirValue, then emit one line of +PLENA ISA text per non-meta MirInstr. + +This is the FIRST version — it uses a TRIVIAL register allocator: +every i32 MirValue gets a fresh GP from the free pool; addr_reg +values get a fresh ``aN`` slot. Values are never released → if a +kernel uses more than 16 GP-resident values the emit fails. That's +fine for the POC; the real register allocator (with live-range +analysis + IntRAM spill) is a follow-up pass that decorates each +MirValue with a "physical home" annotation before this emit runs. + +The emit dispatches on ``mir.OPCODES[opcode].isa_mnemonic`` and +formats the operands per opcode-specific rules. +""" + +from __future__ import annotations + +import warnings +from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Tuple, Union + +from . import mir + + +class MirToIsaError(RuntimeError): + pass + + +# --------------------------------------------------------------------- +# Physical register state +# --------------------------------------------------------------------- + +# GP file: gp0 is the constant-zero source (hardware-fixed); gp15 is +# reserved for the serial loop counter (matching the legacy +# convention); gp1..gp15 are user-allocatable. Serial loop counters +# borrow from this same pool — the emit layer mints a fresh MirValue +# per serial loop and pins its GP for the loop's lifetime, then +# releases it on exit. Nested serial loops therefore consume one +# extra GP per nesting level (the loop counter itself is dead inside +# the body — only the IntRAM-backed lvar gets read/written there). +GP_TOTAL = 16 +GP_USER_FIRST = 1 +GP_USER_LAST = 15 + + +def compute_emit_order(fn: "mir.MirFunction") -> Dict[int, int]: + """Assign every ``MirInstr`` a stable emit index by a DFS walk that + visits items in the SAME order ``MirToIsa`` emits them. + + Returns ``{id(MirInstr): emit_idx}``. The emit walk + (:meth:`MirToIsa._emit_instr`) reads each instr's pre-assigned index + from this map instead of a running counter, so the interval table + (keyed off these indices) and the release driver can never drift — + that drift was the root of the register-corruption bugs. + """ + order: Dict[int, int] = {} + counter = [0] + + def _walk(items): + for item in items: + if isinstance(item, mir.MirInstr): + order[id(item)] = counter[0] + counter[0] += 1 + elif isinstance(item, mir.MirLoop): + for blk in item.body: + _walk(blk.items) + + for blk in fn.blocks: + _walk(blk.items) + return order + + +def _check_no_iter_args(fn: "mir.MirFunction") -> None: + """Guard: this allocator handles only the induction variable as a + loop-carried SSA value. A loop body block with MORE THAN ONE block + argument means real ``iter_args`` (loop-carried accumulators kept in + SSA/GP form), which need phi-congruence register allocation (the + block-arg and the corresponding yield value must share a register, + pinned across the loop). That is NOT implemented — and today's MIR + can't even express it (MirLoop has no yield). Fail loud if it ever + appears, so the unsupported case can never silently miscompile. + """ + def _walk(items): + for item in items: + if isinstance(item, mir.MirLoop): + for blk in item.body: + if len(blk.arguments) > 1: + raise MirToIsaError( + f"loop {item.name!r} body block " + f"{blk.name!r} has {len(blk.arguments)} block " + f"arguments {[a.name for a in blk.arguments]}; " + f"only the induction variable is supported. " + f"Extra args are iter_args (loop-carried SSA " + f"values) requiring phi-congruence register " + f"allocation, which is not implemented." + ) + _walk(blk.items) + for blk in fn.blocks: + _walk(blk.items) + + +def _compute_live_intervals( + fn: "mir.MirFunction", emit_order: Dict[int, int], + loop_last_idx_out: Optional[Dict[int, int]] = None, + carried_by_loop_out: Optional[Dict[int, set]] = None, +) -> Dict[int, int]: + """Structured (SCF-aware) live-interval ends for every i32 MirValue. + + A value's interval is ``[def_point, end]`` in ``emit_order`` index + space. ``end`` is the point past which the value is provably dead, + so the allocator may recycle its GP once ``cur_idx > end``. + + SCF rule — the whole point of this pass: + + ``end[v] = max over each use u of v: + emit_idx[u], lifted OUTWARD through every loop that + encloses u but NOT v's definition, up to that loop's + LAST body emit index.`` + + Because a serial loop's single emitted body runs N times at runtime, + any value defined outside a loop but read inside it must stay live + across the whole loop region — hence the outward lift. A value + defined and used entirely within one scope gets the plain textual + last use. The loop_var / loop counter need no special-casing: their + def sits at the loop head and their uses are inside, so the rule + extends them to the loop's end automatically. + + ``gp0`` (function const) is omitted — the allocator pins it to + register 0 and never recycles it. + + ``carried_by_loop_out`` (when given): filled with + ``{id(MirLoop): {id(MirValue) live-in to that loop}}`` — values + defined OUTSIDE a loop but used INSIDE it. These are LOOP-CARRIED: + they must hold the same GP across every runtime re-entry, so the + allocator must never spill them while that loop is open (a spill + emitted in the single-pass body is not replayed on the back-edge, + so the next iteration's head would read a clobbered GP). This is + the structural fact that the region-recursive model expresses as + "live-in GPs are not lent into the child region". + + Returns ``{id(MirValue): end_idx}``. + """ + # Pass 1: record each instr's enclosing-loop id stack, each loop's + # last body emit index, and each value's defining-loop stack. + instr_loops: Dict[int, Tuple[int, ...]] = {} + def_loops: Dict[int, Tuple[int, ...]] = {} + loop_last_idx: Dict[int, int] = {} + + def _walk(items, loop_stack: Tuple[int, ...]): + for item in items: + if isinstance(item, mir.MirInstr): + idx = emit_order[id(item)] + instr_loops[id(item)] = loop_stack + for lid in loop_stack: + cur = loop_last_idx.get(lid, idx) + loop_last_idx[lid] = max(cur, idx) + if item.result is not None: + def_loops[id(item.result)] = loop_stack + elif isinstance(item, mir.MirLoop): + lid = id(item) + inner = loop_stack + (lid,) + for blk in item.body: + # Block arguments (loop_var) are defined inside. + for arg in blk.arguments: + def_loops[id(arg)] = inner + _walk(blk.items, inner) + + for blk in fn.blocks: + for arg in blk.arguments: + def_loops[id(arg)] = () + _walk(blk.items, ()) + + # Pass 2: for every use, extend the value's end through enclosing + # loops the def is not in, and record those loops as carriers. + end: Dict[int, int] = {} + carried: Dict[int, set] = {} + + def _visit_instr(instr: "mir.MirInstr"): + use_idx = emit_order[id(instr)] + use_loops = instr_loops[id(instr)] + for op in instr.operands: + if not isinstance(op, mir.MirValue) or op.is_function_const: + continue + def_set = def_loops.get(id(op), ()) + candidate = use_idx + for lid in use_loops: + if lid in def_set: + continue + # ``op`` is defined outside ``lid`` but used inside it → + # loop-carried into ``lid``. + candidate = max(candidate, loop_last_idx.get(lid, use_idx)) + carried.setdefault(lid, set()).add(id(op)) + if candidate > end.get(id(op), -1): + end[id(op)] = candidate + + def _walk2(items): + for item in items: + if isinstance(item, mir.MirInstr): + _visit_instr(item) + elif isinstance(item, mir.MirLoop): + for blk in item.body: + _walk2(blk.items) + + for blk in fn.blocks: + _walk2(blk.items) + if loop_last_idx_out is not None: + loop_last_idx_out.update(loop_last_idx) + if carried_by_loop_out is not None: + carried_by_loop_out.update(carried) + return end + + +class _LinearScanAllocator: + """Linear-scan register allocator over STRUCTURED live intervals, + with IntRAM spill-on-pressure. + + Two pre-passes (``compute_emit_order`` + ``_compute_live_intervals``) + assign every instr a stable emit index and every value an interval + END in that index space. The end is SCF-aware: a value defined + outside a loop but used inside it has its end lifted to the loop's + last body index, so it stays live across every runtime re-entry of + the serially-emitted body. The emit walker reads each instr's + pre-assigned index (NOT a running counter — that drifted whenever + emit appended ISA the pre-pass never walked, which was the root of + the register-corruption bugs) and calls ``release_dead_at(idx)``; + values whose end ≤ idx have their GP returned to the free pool. + + No data-value pinning: loop-carried values survive purely via their + extended interval. The ONLY pins are the serial-loop counter and + lvar, whose GPs are named raw in the C_LOOP prologue/epilogue and so + must not move while the loop is open (``pin`` / ``unpin``). + + When ``assign_i32`` is called with the free pool empty, the + allocator spills the live value with the farthest interval end to + a fresh IntRAM slot. The spilled value's GP is freed and handed + to the new request; emit-layer ``_fmt_operand`` calls into the + allocator to reload the spilled value back into a GP on demand. + + The allocator emits ``S_ST_INT`` (spill) / ``S_LD_INT`` (reload) + instrs by appending to a caller-owned list passed at + construction time (``isa_lines``), so the surrounding emit + layer doesn't need to know spill happened — except that we + warn via ``warnings.warn`` once per function so the user knows + register pressure was high. + + IntRAM slots: we use slot 0..MAX_INTRAM_SLOTS-1, allocated + monotonically as values get spilled. The first + :class:`MAX_LOOP_INTRAM_SLOTS` are reserved by emit for + serial-loop idx slots — see ``_emit_loop_serial``. + """ + + def __init__(self, fn: "mir.MirFunction") -> None: + self.fn = fn + self.emit_order = compute_emit_order(fn) + # Structured live-interval ENDS: ``{id(MirValue): end_idx}``. + # A value is dead (its GP recyclable) once ``cur_idx > end``. + # SCF-aware: loop-carried values already extend to their loop's + # end, so NO manual pinning is needed. + # ``loop_last_idx`` maps id(MirLoop) → its last body emit index; + # emit uses it to give serial-loop counters a correct interval. + # Reject loop-carried SSA values (iter_args) we can't handle. + _check_no_iter_args(fn) + self.loop_last_idx: Dict[int, int] = {} + # ``carried_by_loop[id(MirLoop)]`` = set of value-ids live-in to + # that loop (defined outside, used inside). While that loop is + # open during emit, these must not be spilled — see + # ``_pick_spill_victim`` / ``enter_loop`` / ``exit_loop``. + self.carried_by_loop: Dict[int, set] = {} + self.end = _compute_live_intervals( + fn, self.emit_order, self.loop_last_idx, + self.carried_by_loop, + ) + # id(MirValue) -> the MirValue, for human-readable diagnostics + # (the gp maps are keyed by id only). + self._vid_to_value: Dict[int, "mir.MirValue"] = {} + # Free GP pool — LIFO so freshly-freed GPs come back first + # (best for live-locality + dump readability). + self.free_gp: List[int] = list( + range(GP_USER_FIRST, GP_USER_LAST + 1), + ) + self.value_to_gp: Dict[int, int] = {} + self.value_to_areg: Dict[int, int] = {} + self._next_areg: int = 0 + # Manual interval ends for values minted DURING emit (the serial + # loop counter), which the pre-pass never walked. emit calls + # ``set_end`` for these. Merged into ``end`` lookups. + self._manual_end: Dict[int, int] = {} + # Values whose physical GP is referenced by RAW NUMBER in emitted + # ISA outside the allocator's operand path (serial-loop counter + + # lvar, used directly in the C_LOOP_START/END + idx prologue/ + # epilogue). These must never be spilled or recycled while the + # loop is open, or the raw GP number would point at stale data. + # This is the ONLY pin concept left; loop-carried *data* values + # need no pin — their structured interval keeps them live. + self._pinned: set = set() + # GPs holding the CURRENT instruction's input operands. While set, + # ``_pick_spill_victim`` will not steal them: their tokens have + # already been emitted into this instruction's ISA text, so + # spilling them (e.g. when assign_i32 allocates the result under + # register pressure) would silently invalidate the operand the + # instruction is about to read. Set by emit before assigning the + # result GP, cleared after. Holds id(MirValue)s. + self._operand_lock: set = set() + # Stack of id(MirLoop) for serial loops currently OPEN during + # emit (pushed in enter_loop, popped in exit_loop). The UNION of + # their ``carried_by_loop`` sets is the set of values that must + # not be spilled right now: each is live across a loop body whose + # single-pass emission won't replay a spill on the back-edge, so + # spilling it corrupts the next iteration. Maintained incrementally + # in ``_carried_now`` for O(1) victim checks. + self._open_loops: List[int] = [] + self._carried_now: set = set() + # IntRAM spill state. ``value_to_slot`` maps id(MirValue) → + # IntRAM slot index for values whose GP was reclaimed via + # spill. They're not in ``value_to_gp`` while spilled; a + # subsequent operand reference reloads them via + # ``ensure_in_gp``. ``intram_next_slot`` is the next free + # IntRAM slot — emit reserves slots 0..N-1 for serial-loop + # counters (see ``reserve_intram_slot`` / serial loop + # epilogue logic), so this counter starts above the + # reservation. + # IntRAM spill state — the SINGLE spill mechanism. ``value_to_slot`` + # maps id(MirValue) → its IntRAM slot. A spilled value lives in + # IntRAM; on each use it is reloaded into a throwaway GP and then + # promptly returned to IntRAM (its slot entry is RETAINED across + # the reload). This "reload-per-use, never resident" discipline is + # what makes spilling correct across a loop back-edge — every use + # carries its own S_LD_INT, which re-executes each iteration — and + # lets many spilled values time-share a few GPs. The store to a + # value's slot happens once at spill; the value is loop-invariant + # OR dies within the body, so the slot stays valid for re-reads. + self.value_to_slot: Dict[int, int] = {} + self._intram_next_slot: int = 0 + # Hook for emit to receive spill/reload ISA lines + to + # capture cur_idx for reload sequencing. emit binds these + # at construction time. + self._isa_lines: Optional[List[str]] = None + # Track whether we've already warned for this function + # (warn at most once per kernel; per-spill noise is + # excessive). + self._spill_warned: bool = False + # Total spill / reload counts for diagnostics. + self.spill_count: int = 0 + self.reload_count: int = 0 + + # ---- emit hookups (set by MirToIsa.__init__) ---- + def bind_emit(self, isa_lines: List[str], get_cur_idx) -> None: + """Hand the allocator the live ISA-line list + a callback + to fetch the current instr idx (used when picking spill + victims — we want the one whose interval end is farthest from + ``cur_idx``).""" + self._isa_lines = isa_lines + self._get_cur_idx = get_cur_idx + + def reserve_intram_slot(self) -> int: + """Reserve and return the next IntRAM slot. Used by emit + for serial-loop idx slots (see ``_emit_loop_serial``).""" + slot = self._intram_next_slot + self._intram_next_slot += 1 + return slot + + # ---- free-pool invariants ---- + # Two invariants the allocator must never violate: + # (1) a GP appears in ``free_gp`` AT MOST ONCE — otherwise two + # distinct ``pop`` calls hand the same physical register to two + # live values, and one silently clobbers the other; + # (2) a GP in ``free_gp`` is owned by NO value in ``value_to_gp``. + # Every return-to-pool routes through ``_free_gp`` and every take + # through ``_take_gp`` so these hold by construction. A violation is a + # real allocator bug; we fail loud rather than emit corrupt ISA. + def _free_gp(self, gp: int) -> None: + if gp == 0: + return # gp0 is never pooled + if gp in self.free_gp: + raise MirToIsaError( + f"alloc invariant: gp{gp} freed twice (already in pool " + f"{self.free_gp}). Double-free corrupts allocation." + ) + owner = next((v for v, g in self.value_to_gp.items() if g == gp), + None) + if owner is not None: + raise MirToIsaError( + f"alloc invariant: gp{gp} freed while still owned by " + f"value@{owner}." + ) + self.free_gp.insert(0, gp) + + def _take_gp(self) -> int: + gp = self.free_gp.pop(0) + owner = next((v for v, g in self.value_to_gp.items() if g == gp), + None) + if owner is not None: + raise MirToIsaError( + f"alloc invariant: gp{gp} taken from pool but still owned " + f"by value@{owner} — pool/owner state diverged." + ) + return gp + + def set_end(self, v: "mir.MirValue", end_idx: int) -> None: + """Register an interval end for a value minted during emit (the + serial loop counter). The pre-pass never saw it, so emit must + declare how long it lives — typically the loop's last body + emit index.""" + self._manual_end[id(v)] = end_idx + + def _make_room(self) -> bool: + """Free exactly one GP so a subsequent ``_take_gp`` succeeds, by + spilling one value to IntRAM. Returns True if a GP was freed, + False if nothing is spillable (genuine, unrecoverable pressure). + + There is ONE spill mechanism now: store to IntRAM, reload per use. + Loop-carried values are spillable too — correctness across the + loop back-edge comes from reloading at every use (each S_LD_INT + re-executes each iteration), not from keeping the value resident.""" + victim = self._pick_spill_victim() + if victim is None: + return False + self._spill_value(victim) + return True + + def spill_carried_at_entry(self, lp: "mir.MirLoop") -> None: + """SCOPE-ENTRY spill (design §3). Called BEFORE ``C_LOOP_START`` + — i.e. outside the loop body — to demote loop-carried values to + IntRAM whose ``S_ST_INT`` must therefore land OUTSIDE the body. + + We spill every carried value of ``lp`` that is currently resident + in a GP. The store is emitted here (outside the body, once). The + body then reads each via reload-from-slot (loads only, which + re-execute correctly every iteration because the slot was written + outside the body and the value is loop-invariant). This is the + structural fix for the "spill-store landed in the body and got + overwritten across iterations" bug: stores never sit in a body. + + After this, ``_pick_spill_victim`` forbids spilling carried + values mid-body, so no new carried store can land in the body. + Spilling ALL carried values here is conservative (may add reload + traffic) but always correct; keeping some in GP is a later + optimisation.""" + carried = self.carried_by_loop.get(id(lp), set()) + for vid in sorted(carried): + if vid in self.value_to_gp and self.value_to_gp[vid] != 0 \ + and vid not in self._pinned: + self._spill_value(vid) + + def enter_loop(self, lp: "mir.MirLoop") -> None: + """Mark a serial loop as open during emit: its loop-carried + values become unspillable mid-body until ``exit_loop`` (their + stores were already done at scope entry — see + ``spill_carried_at_entry``). Must be called right before emitting + the loop body, paired with exit_loop.""" + lid = id(lp) + self._open_loops.append(lid) + self._carried_now |= self.carried_by_loop.get(lid, set()) + + def exit_loop(self, lp: "mir.MirLoop") -> None: + """End the open-loop scope started by ``enter_loop``. Recompute + the carried set from the loops still open (a value carried by an + outer loop must stay protected even after the inner one closes).""" + lid = id(lp) + assert self._open_loops and self._open_loops[-1] == lid, ( + "enter_loop/exit_loop mismatch" + ) + self._open_loops.pop() + self._carried_now = set() + for open_lid in self._open_loops: + self._carried_now |= self.carried_by_loop.get(open_lid, set()) + + def _end_of(self, vid: int) -> Optional[int]: + """Interval end for ``vid`` (manual override wins). None means + the value has no recorded use at all.""" + if vid in self._manual_end: + return self._manual_end[vid] + return self.end.get(vid) + + def pin(self, v: "mir.MirValue") -> None: + """Pin ``v``'s GP against spill AND recycle for the duration of + an open serial loop. Only for emit-managed values whose GP is + named raw in the prologue/epilogue (counter, lvar).""" + self._pinned.add(id(v)) + + def unpin(self, v: "mir.MirValue") -> None: + self._pinned.discard(id(v)) + # Release now if dead: at loop exit the counter/lvar are done. + gp = self.value_to_gp.get(id(v)) + if gp is not None and gp != 0: + del self.value_to_gp[id(v)] + self._free_gp(gp) + + def lock_operands(self, vids) -> None: + """Protect these values' GPs from being chosen as spill victims. + Emit calls this with the current instruction's input-operand + ids before allocating the result GP, so a result allocation + under register pressure can't spill an operand whose token was + already emitted for this instruction.""" + self._operand_lock = set(vids) + + def clear_operand_lock(self) -> None: + self._operand_lock = set() + + # ---- spill / reload ---- + def _pick_spill_victim(self) -> Optional[int]: + """Return id(MirValue) of the value whose GP we'll steal, or None + if no candidate. Among SPILLABLE values, picks the one with the + FARTHEST interval end (least likely to be referenced soon). + + A value is NOT spillable MID-BODY if it is: + * gp0 (the const-zero fixture), + * pinned (serial-loop counter/lvar, named raw in ISA), + * an operand of the current instruction (token already emitted), + * loop-carried by any currently-open loop (``_carried_now``): + its store must stay OUTSIDE the body (see + ``spill_carried_at_entry``); spilling it here would emit an + ``S_ST_INT`` inside the body that re-executes every iteration + with a clobbered GP. Carried values are demoted at scope + ENTRY instead, where the store is outside the body. + So mid-body spill only ever targets ``local`` temporaries whose + store+reload sit in the same straight-line iteration — safe.""" + cur = self._get_cur_idx() if self._get_cur_idx else -1 + best_vid = None + best_end = -1 + for vid, gp in self.value_to_gp.items(): + if gp == 0: + continue + if vid in self._pinned: + continue # raw-GP-named (counter/lvar) — never move it + if vid in self._operand_lock: + continue # current instr's live operand — would corrupt it + if vid in self._carried_now: + continue # carried — demoted at scope entry, not mid-body + e = self._end_of(vid) + e = cur if e is None else e + if e > best_end: + best_end = e + best_vid = vid + return best_vid + + def _spill_value(self, vid: int) -> int: + """Spill the value with id ``vid``: free its GP back to the pool, + leaving the value in IntRAM (``value_to_slot``). Returns the freed + GP number. + + If the value ALREADY has a slot (it was spilled before and is + only transiently reloaded), we DON'T store again — the slot's + contents are still valid (the value is loop-invariant or unchanged + since the original store). We just free the GP. This is what makes + reload-per-use cheap: one store ever, many reloads.""" + gp = self.value_to_gp.pop(vid) + existing = self.value_to_slot.get(vid) + if existing is not None: + self._isa_lines.append( + f"; RE-EVICT value@{vid} (slot {existing} still valid) " + f"freeing gp{gp}" + ) + self._free_gp(gp) + return gp + slot = self.reserve_intram_slot() + self.value_to_slot[vid] = slot + self._isa_lines.append( + f"; SPILL value@{vid} -> intram[{slot}] (freeing gp{gp})" + ) + self._isa_lines.append(f"S_ST_INT gp{gp}, gp0, {slot}") + # Return the GP to the free pool so the caller (assign_i32 / + # ensure_in_gp) can grab it. + self._free_gp(gp) + self.spill_count += 1 + if not self._spill_warned: + self._spill_warned = True + warnings.warn( + f"mir_to_isa: IntRAM spill triggered in function " + f"{self.fn.name!r} — register pressure exceeded " + f"{GP_USER_LAST - GP_USER_FIRST + 1} GPs. " + f"Each spill incurs an S_ST_INT + S_LD_INT per use. " + f"Consider reducing LICM aggressiveness, hoisting " + f"fewer invariants, or accepting the IntRAM round-trip.", + RuntimeWarning, + stacklevel=3, + ) + return gp + + def ensure_in_gp(self, v: "mir.MirValue") -> int: + """Like ``assign_i32`` but for an OPERAND read. If ``v`` is + currently spilled, reload it from its IntRAM slot into a + GP (possibly triggering another spill to free that GP). + Returns the GP holding ``v``.""" + if v.is_function_const: + return 0 + if id(v) in self.value_to_gp: + return self.value_to_gp[id(v)] + if id(v) in self.value_to_slot: + # Reload from IntRAM into a GP. The slot entry is RETAINED + # (not popped): the value stays spill-resident, so after this + # use ``release_dead_at`` frees the GP again and the next use + # reloads afresh. Reload-per-use keeps it correct across the + # loop back-edge and lets spilled values time-share GPs. + slot = self.value_to_slot[id(v)] + if not self.free_gp and not self._make_room(): + raise MirToIsaError( + f"reload: no free GP and nothing to spill " + f"for value {v!r}" + ) + gp = self._take_gp() + self.value_to_gp[id(v)] = gp + self._isa_lines.append( + f"; RELOAD value@{id(v)} from intram[{slot}] -> gp{gp}" + ) + self._isa_lines.append(f"S_LD_INT gp{gp}, gp0, {slot}") + self.reload_count += 1 + return gp + # Reading an operand that is in NEITHER a GP nor a spill slot + # means its value was dropped (release_dead_at deleted it from + # value_to_gp) before this use — i.e. its computed last_use is + # earlier than this actual use. Allocating a fresh GP here would + # silently hand back garbage (no reload), corrupting the operand. + # This is always a live-range bug, never legitimate. Fail loud. + cur = self._get_cur_idx() if self._get_cur_idx else -1 + raise MirToIsaError( + f"ensure_in_gp: operand value@{id(v)} ({v!r}) read at " + f"cur_idx={cur} is in neither a GP nor a spill slot — it was " + f"released before this use. Its interval end=" + f"{self._end_of(id(v))!r}. This is a live-interval bug: the " + f"value's recorded end is earlier than this actual use." + ) + + def assign_i32(self, v: "mir.MirValue") -> int: + """Return the gp number for ``v``, allocating a new one if + first sight. The function-level gp0 constant is fixed at + register 0. On free-pool exhaustion, spills a live value.""" + self._vid_to_value[id(v)] = v + if v.is_function_const: + self.value_to_gp[id(v)] = 0 + return 0 + if id(v) in self.value_to_gp: + return self.value_to_gp[id(v)] + if id(v) in self.value_to_slot: + # Defining a value that we've spilled? Shouldn't + # happen — spill is on USES, but the def itself sets + # the initial value. Reload via ensure_in_gp instead. + return self.ensure_in_gp(v) + if not self.free_gp: + # Free one GP by spilling a value to IntRAM. Only fires when + # the pool is empty, so the common path is untouched. + if not self._make_room(): + # Nothing spillable. Break down WHICH value holds each GP + # and why, so we can see what LICM hoisted / what's + # carried across the open loops. + def _nm(vid): + val = self._vid_to_value.get(vid) + return val.name if val is not None else f"id{vid}" + lines = [] + for vid, g in sorted(self.value_to_gp.items(), + key=lambda kv: kv[1]): + why = [] + if vid in self._pinned: + why.append("pinned(ctr/lvar)") + if vid in self._carried_now: + why.append("carried") + if vid in self._operand_lock: + why.append("operand") + lines.append( + f" gp{g} = %{_nm(vid)} end={self._end_of(vid)}" + f" [{','.join(why) or 'free?'}]" + ) + raise MirToIsaError( + f"GP file exhausted ({len(self.value_to_gp)} of " + f"{GP_USER_LAST - GP_USER_FIRST + 1} GPs held, none " + f"spillable) while defining %{v.name}.\n" + f" open loops: {len(self._open_loops)}\n" + + "\n".join(lines) + ) + gp = self._take_gp() + self.value_to_gp[id(v)] = gp + return gp + + def assign_addr_reg(self, v: "mir.MirValue") -> int: + if id(v) in self.value_to_areg: + return self.value_to_areg[id(v)] + areg = self._next_areg + self._next_areg += 1 + self.value_to_areg[id(v)] = areg + return areg + + def release_dead_at(self, cur_idx: int) -> None: + """Return GPs to the free pool. Two cases: + + * A SPILLED value (has a ``value_to_slot`` entry) that is + currently reloaded into a GP is freed PROMPTLY — its slot keeps + the value, the next use reloads again. This "never resident" + discipline is the time-sharing + back-edge correctness of the + single spill mechanism. + * A non-spilled value is freed once its structured interval end + is ``<= cur_idx`` (provably dead, including across serial-loop + re-entries — the interval extends carried values to the loop + end). Values with no recorded end are dead on definition. + + Convention: ``_emit_instr`` calls with ``cur-1`` before operand + formatting and ``cur`` after emitting (so ``end <= cur_idx``).""" + for vid in list(self.value_to_gp.keys()): + gp = self.value_to_gp[vid] + if gp == 0: + continue # gp0 is a permanent hardware fixture + if vid in self._operand_lock: + continue # token in the current line — don't reclaim yet + if vid in self._pinned: + continue # emit owns the counter/lvar lifecycle + if vid in self.value_to_slot: + # Spilled value, transiently reloaded: free its GP now; + # the slot keeps the value for the next use's reload. + del self.value_to_gp[vid] + self._free_gp(gp) + continue + if vid in self._carried_now: + continue # live across an open loop — keep its GP put + e = self._end_of(vid) + if e is None or e <= cur_idx: + # Drop ownership first, then return to pool (the pool + # invariant check requires the GP be unowned). + del self.value_to_gp[vid] + self._free_gp(gp) + + +# --------------------------------------------------------------------- +# Emit +# --------------------------------------------------------------------- + +class MirToIsa: + """Walk a MirFunction; produce ISA text.""" + + def __init__(self, fn: mir.MirFunction, shim) -> None: + self.fn = fn + self.shim = shim + self.alloc = _LinearScanAllocator(fn) + self.lines: List[str] = [] + # For serial loops we need to claim an IntRAM idx slot and a + # GP-loop counter; pin/release like legacy ``_emit_for``. + # Track open serial loops to emit the matching epilogue. + self._serial_loop_stack: List[Dict] = [] + # Current emit index, set per instr from the pre-assigned + # ``compute_emit_order`` map (NOT a running counter). Drives the + # allocator's release_dead_at / spill-victim choice. + self._cur_idx: int = -1 + # Wire allocator so spill/reload can emit S_ST_INT/S_LD_INT + # into our ``lines`` list and ask us for the current idx + # when picking a spill victim. + self.alloc.bind_emit(self.lines, lambda: self._cur_idx) + + def run(self) -> str: + # Header. + self.lines.append(f"; PLENA ISA -- kernel: {self.fn.name}") + self.lines.append( + "; generated by tilelang_tvm_compiler (PreIsaIR v2 → MIR path)" + ) + self.lines.append("; " + "=" * 60) + for blk in self.fn.blocks: + self._emit_block(blk) + return "\n".join(self.lines) + "\n" + + def _emit_block(self, blk: mir.MirBlock) -> None: + for item in blk.items: + if isinstance(item, mir.MirInstr): + self._emit_instr(item) + elif isinstance(item, mir.MirLoop): + self._emit_loop(item) + else: + raise MirToIsaError( + f"unexpected block item: {type(item).__name__}" + ) + + def _emit_loop(self, lp: mir.MirLoop) -> None: + if lp.loop_kind == "serial": + self._emit_loop_serial(lp) + return + if lp.loop_kind == "unroll": + # Emit-time unrolling was removed: it cloned the body into a + # scratch block (minting MirValues the precomputed last_use + # table never saw) and corrupted register allocation. All + # loops now lower to hardware C_LOOPs — pre_isa_ir_v2's + # FORCE_SERIAL_LOOPS downgrades unroll→serial at construction, + # so this branch should be unreachable. If it fires, a loop + # was built after that switch was bypassed. + raise MirToIsaError( + f"loop {lp.name!r} reached emit with loop_kind='unroll'; " + f"emit-time unrolling is removed (see " + f"pre_isa_ir_v2.FORCE_SERIAL_LOOPS). All loops must be " + f"serial by the time MIR is emitted." + ) + raise MirToIsaError( + f"unknown loop_kind {lp.loop_kind!r} on loop {lp.name}" + ) + + def _emit_loop_serial(self, lp: mir.MirLoop) -> None: + """Emit a hardware-backed serial loop. + + The loop_var is materialised the way the PLENA ISA spec + mandates: a SEPARATE software-maintained index in an IntRAM + slot, NOT the hardware loop counter register. Per the spec + (``doc/plena_isa_spec.md`` C_LOOP_START): + + "The loop counter register ``rd`` does NOT contain the + current iteration index. You must maintain your own + index variable and increment it manually inside the + loop." + + So: + * a **counter GP** holds ``C_LOOP_START``'s remaining-iter + count (hardware-managed; we never read it as data), + * a **lvar GP** + **IntRAM idx slot** hold the real index; + the body reads it at entry, the epilogue increments and + stores it back. + + A prior experiment derived the loop_var as ``counter - 1`` to + save the idx slot + 3 instrs/iter. It passed in the simulator + (whose ``gp[rd]`` happens to expose ``extent..1``) but the + hardware spec explicitly forbids reading ``rd`` as an index, + so it was undefined behaviour on real silicon and has been + removed. ``order_independent`` annotations still flow through + the IR (kernel → HLIR → MIR) for any future, spec-safe + order-independence optimisation, but the backend treats every + serial loop identically here. + + Resources are emit-time-only physical concerns (no upper + layer should know about them): + + * **counter GP** — ``C_LOOP_START gpN, K`` operand. Minted + fresh per loop. + * **lvar GP** — holds the current index in the body. + * **IntRAM idx slot** — software index backing store. + + Both the counter and lvar GPs are named RAW (by number) in the + prologue/epilogue ISA, outside the allocator's operand path, so + they are ``pin``-ned: never spilled or recycled while the loop is + open. They also get an interval end at the loop's last body index + so the structured allocator sees them as live across the body + (the body is emitted once but runs N times). Data values carried + from outside the loop need NO pin — their structured interval + already extends across the loop. + + Non-zero ``lp.init`` is handled in the prelude by storing the + init value into the idx slot instead of zero. + """ + loop_end = self.alloc.loop_last_idx.get(id(lp)) + + # 0) SCOPE ENTRY: demote this loop's carried values to IntRAM + # NOW — before C_LOOP_START — so their S_ST_INT lands OUTSIDE the + # body (design §3). Inside the body they are reloaded on use + # (loads only). This is the structural fix for spill-stores + # landing in a loop body and getting clobbered across iterations. + # It also frees the GPs the counter/lvar need below. + self.alloc.spill_carried_at_entry(lp) + + # 1) Counter GP. Pinned: C_LOOP_START/END name its GP raw. + counter_val = self.fn.mint_value("i32", hint=f"loop_ctr_{lp.name}") + counter_gp = self.alloc.assign_i32(counter_val) + self.alloc.pin(counter_val) + if loop_end is not None: + self.alloc.set_end(counter_val, loop_end) + + # 2) lvar GP — body uses it; pinned because the idx prologue/ + # epilogue (S_LD/S_ADDI/S_ST) name its GP raw too. + lvar_gp = self.alloc.assign_i32(lp.loop_var) + self.alloc.pin(lp.loop_var) + if loop_end is not None: + self.alloc.set_end(lp.loop_var, loop_end) + + # 3) IntRAM idx slot + LD entry + LD/ADDI/ST exit. + idx_addr = self.alloc.reserve_intram_slot() + self.lines.append( + f"; for {lp.loop_var.name} in [{lp.init}, " + f"{lp.init + lp.extent}) -- hw counter gp{counter_gp}, " + f"idx ram[{idx_addr}]" + ) + if lp.init == 0: + self.lines.append(f"S_ST_INT gp0, gp0, {idx_addr}") + else: + init_imm = int(lp.init) + if 0 <= init_imm <= 262143: + self.lines.append( + f"S_ADDI_INT gp{lvar_gp}, gp0, {init_imm}" + ) + else: + raise MirToIsaError( + f"serial loop init {init_imm} exceeds S_ADDI_INT " + f"immediate range; S_LUI fallback not yet wired" + ) + self.lines.append( + f"S_ST_INT gp{lvar_gp}, gp0, {idx_addr}" + ) + self.lines.append( + f"C_LOOP_START gp{counter_gp}, {lp.extent}" + ) + self.lines.append( + f"S_LD_INT gp{lvar_gp}, gp0, {idx_addr}" + ) + self._serial_loop_stack.append({ + "counter_gp": counter_gp, + "counter_val": counter_val, + "idx_addr": idx_addr, + "lvar_gp": lvar_gp, + "lvar_name": lp.loop_var.name, + "loop_var": lp.loop_var, + }) + # Open the loop scope: its loop-carried values become unspillable + # for the whole body (a spill here wouldn't replay on the + # back-edge — see _pick_spill_victim). + self.alloc.enter_loop(lp) + try: + for body_blk in lp.body: + self._emit_block(body_blk) + finally: + self.alloc.exit_loop(lp) + # Emit owns counter/lvar lifecycle: unpin AFTER the body so + # their raw GPs survive the whole region, then the epilogue + # below still references them by the same number. + self.alloc.unpin(lp.loop_var) + self.alloc.unpin(counter_val) + st = self._serial_loop_stack.pop() + self.lines.append( + f"; idx {st['lvar_name']} += 1 (ram[{st['idx_addr']}])" + ) + self.lines.append( + f"S_LD_INT gp{st['lvar_gp']}, gp0, {st['idx_addr']}" + ) + self.lines.append( + f"S_ADDI_INT gp{st['lvar_gp']}, gp{st['lvar_gp']}, 1" + ) + self.lines.append( + f"S_ST_INT gp{st['lvar_gp']}, gp0, {st['idx_addr']}" + ) + self.lines.append(f"C_LOOP_END gp{st['counter_gp']}") + + def _emit_instr(self, instr: mir.MirInstr) -> None: + # Use the PRE-ASSIGNED emit index (compute_emit_order), not a + # running counter — this is the same index space the interval + # table is keyed on, so cur_idx and interval ends can never + # drift (the old running-counter approach drifted whenever emit + # appended ISA lines the interval pre-pass never walked). + cur = self.alloc.emit_order[id(instr)] + self._cur_idx = cur + op = instr.opcode + if op == "_COMMENT": + text = instr.operands[0] if instr.operands else "" + self.lines.append(f"; {text}") + self.alloc.release_dead_at(cur) + return + + spec = mir.OPCODES.get(op) + if spec is None: + raise MirToIsaError(f"unknown opcode {op!r}") + + # FIRST: release any GPs whose interval END was before this + # instruction (end <= cur-1). This lets ``assign_i32(instr.result)`` + # (below) see a free pool that already excludes values that died + # at end == cur-1. Without it, assign_i32 grabs a fresh GP for the + # result while the previous instr's now-dead operand GPs are still + # parked — peak GP usage inflates by one slot at every link in an + # address-arithmetic chain. + self.alloc.release_dead_at(cur - 1) + + # Format operands per opcode spec. Lock each i32 operand's GP + # INCREMENTALLY — as soon as its token is emitted, it must not be + # spilled/remat-evicted to make room for a LATER operand or the + # result (its gpN is already in this line's text). Locking only + # after the whole loop would leave earlier operands stealable + # while formatting later ones. + tokens: List[str] = [] + operand_value_ids = [] + self.alloc.lock_operands(operand_value_ids) + try: + for i, (val, kind) in enumerate( + zip(instr.operands, spec.operand_kinds), + ): + tokens.append(self._fmt_operand(val, kind)) + if kind == "i32" and isinstance(val, mir.MirValue) \ + and not val.is_function_const: + operand_value_ids.append(id(val)) + self.alloc.lock_operands(operand_value_ids) + + # Format result prefix (if non-void). The dst GP goes FIRST + # in the ISA arg list. Operand locks (above) keep the inputs + # from being evicted to make room for the result. + result_tok: Optional[str] = None + if instr.result is not None: + if spec.result_type == "i32": + gp = self.alloc.assign_i32(instr.result) + result_tok = f"gp{gp}" + elif spec.result_type == "addr_reg": + areg = self.alloc.assign_addr_reg(instr.result) + result_tok = f"a{areg}" + else: + raise MirToIsaError( + f"{op}: don't know how to assign physical reg for " + f"result_type {spec.result_type!r}" + ) + + # PLENA ISA convention: the destination GP goes FIRST in the + # operand list (just like RISC-V). The MIR operand list + # carries SOURCES only; we prepend dst here. + if result_tok is not None: + arg_list = ", ".join([result_tok] + tokens) + else: + arg_list = ", ".join(tokens) + self.lines.append(f"{spec.isa_mnemonic} {arg_list}") + finally: + self.alloc.clear_operand_lock() + # Post-emit: release any value whose interval END == cur. These + # are operands whose final use was this very instruction. They + # can't be released before assign_i32 above (their GP was needed + # to format the operand text), but they're dead now and the next + # instr can reuse them. + self.alloc.release_dead_at(cur) + + def _fmt_operand(self, val, kind: str) -> str: + if kind == "i32": + if isinstance(val, mir.MirValue): + # ensure_in_gp reloads from IntRAM if val was + # spilled; transparent to the caller. + return f"gp{self.alloc.ensure_in_gp(val)}" + if isinstance(val, int): + return str(val) + raise MirToIsaError( + f"i32 operand expects MirValue or int; got {val!r}" + ) + if kind == "literal_int": + return str(int(val)) + if kind == "fp_reg": + if not isinstance(val, str): + raise MirToIsaError( + f"fp_reg operand expects str; got {val!r}" + ) + return val + if kind == "verbatim_str": + if not isinstance(val, str): + raise MirToIsaError( + f"verbatim_str expects str; got {val!r}" + ) + return val + if kind == "addr_reg": + if isinstance(val, mir.MirValue): + return f"a{self.alloc.assign_addr_reg(val)}" + raise MirToIsaError( + f"addr_reg operand expects MirValue; got {val!r}" + ) + raise MirToIsaError(f"unknown operand kind {kind!r}") + + +def emit(fn: mir.MirFunction, shim) -> str: + """Convert a MirFunction to ISA text.""" + return MirToIsa(fn, shim).run() + + +__all__ = ["emit", "MirToIsa", "MirToIsaError"] diff --git a/tilelang_tvm_compiler/pipeline.py b/tilelang_tvm_compiler/pipeline.py new file mode 100644 index 0000000..f48ab34 --- /dev/null +++ b/tilelang_tvm_compiler/pipeline.py @@ -0,0 +1,283 @@ +"""End-to-end driver: raw TIR PrimFunc -> real PLENA ISA text. + +Orchestrates: + 0. inline_let_stmts + lower_compound_fp_stores (stmt prep) + 1. mid_ir pipeline (10 passes, see frontend/mid_ir/passes/) + 2. AddressAllocationPass (HLIR + addresses) + 3. IsaEmitterPass (HLIR -> ISA text) + +The legacy ``frontend/`` graph-IR pipeline + ``codegen.PlenaCodegen`` +are no longer in the call path. They're still on disk for reference +but aren't imported here. + +Hardware constants for the program shim are passed in via PlenaTarget, +which we keep deliberately small for now -- mlen/blen/btmm shape are +fixed per chip variant. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import tvm +from tvm import tir + +from .address_alloc import AddressAllocationPass, AddressAllocConfig +from . import dead_buffer_elim as _dead_buffer_elim +from . import fuse_adjacent_loops as _fuse_adjacent_loops +from . import loop_interchange as _loop_interchange +from . import loop_register_alloc as _loop_register_alloc +from . import plena_settings as _plena_settings +# Direct submodule imports to avoid the legacy frontend package's +# __init__ (which imports compile_func → frontend/pipeline.py → +# ..pipeline.PlenaTarget, a circular import once we land here). +from .frontend.passes import inline_let_stmts as _stmt_inline_let +from .frontend.passes import lower_compound_fp_stores as _stmt_lower_compound +from .frontend.passes import hoist_float_constants as _stmt_hoist_consts +from .frontend.mid_ir.passes import infer_lane_axis as _mid_infer_lane_axis +from .frontend.mid_ir.passes import fold as _mid_fold +from .frontend.mid_ir.passes import mark as _mid_mark +from .frontend.mid_ir.passes import split as _mid_split +from .frontend.mid_ir.passes import distribute_cluster as _mid_distribute +from .frontend.mid_ir.passes import async_wrap as _mid_async +from .frontend.mid_ir.passes import view as _mid_view +from .frontend.mid_ir.passes import fuse as _mid_fuse +from .frontend.mid_ir.passes import burn_view as _mid_burn +from .frontend.mid_ir.passes import to_plena as _mid_to_plena +from .hlir import HLIRModule +from .isa_pass import IsaEmitterPass +from .program_shim import make_shim +from .register_alloc import RegisterAllocator + + +@dataclass +class PlenaTarget: + """Hardware-shape constants. Equivalent to TileTensorProgram() ctor. + + Defaults are read from ``plena_settings.toml`` (the active mode's + MLEN / HLEN / BLEN) so the compiler and the simulator never drift. + Pass explicit values to override for a non-default target / test. + """ + + mlen: int = field(default_factory=_plena_settings.mlen) + blen: int = field(default_factory=_plena_settings.blen) + # group_heads — how many narrow heads pack into one MLEN vector. + btmm_lane_count: int = field( + default_factory=lambda: _plena_settings.load_sizes().hardware_lane_count + ) + btmm_hlen: int = field(default_factory=_plena_settings.hlen) + + +@dataclass +class CompiledKernel: + name: str + hlir: HLIRModule + isa_text: str + # GP allocator trace captured during ISA emit. ``None`` if the + # compile path didn't expose it. Each entry is a dict with keys + # ``asm_line``/``site``/``event``/``free``/``in_use``/``pinned`` + # plus event-specific fields (regs, slot, addr, n, ...). + gp_trace: list = None + # lowir recording: ``(op_idx, expr_str)`` pairs captured at the + # var->gp materialization chokepoint during ISA emit. Feeds the + # ``.lowir.txt`` report — the symbolic "last variable-form" + # of every address expression the ISA actually consumes. Empty list + # if not recorded. + lowir_log: list = None + + def __repr__(self) -> str: + return ( + f"CompiledKernel(name={self.name!r}, " + f"buffers={len(self.hlir.buffers)}, " + f"ops={len(self.hlir.ops)}, " + f"isa_lines={self.isa_text.count(chr(10))})" + ) + + +def compile_kernel( + prim_func: tir.PrimFunc, + *, + target: PlenaTarget, + name: str = "kernel", + midir_dump_dir: Optional[Path] = None, + addr_config_override: Optional[AddressAllocConfig] = None, + use_v2: bool = False, +) -> CompiledKernel: + """Lower a raw TIR PrimFunc through the mid_ir pipeline + downstream + address-alloc + ISA-emit passes. + + ``midir_dump_dir`` (when set): pass_6_to_plena will write a + human-readable ``.midir.txt`` snapshot there for debugging. + + ``addr_config_override`` (when set): use this AddressAllocConfig + verbatim for the address-alloc pass instead of building a default + one from ``target``. Used by multi-kernel drivers that stitch + several kernels into one continuous ASM run and need to control + the FPRAM / HBM bases per kernel. + + ``use_v2`` (default False): when True, route the post-address HLIR + through the PreIsaPassV2 → MIR → ISA pipeline instead of the legacy + single-pass ``IsaEmitterPass``. The v2 path is fully op-coverage- + complete (all 38 HLIR op kinds) and produces structurally + identical HW-op streams (same M_MM/M_BTMM/H_PREFETCH_V/etc. + count and order); GP numbers can differ. Set this when you want + the v2 path's tighter register allocation (~9 GPs on full matmul + vs 14+ in legacy) and the MIR-level dump for debugging. + + Unroll loops in v2: the MIR is kept compact (no physical + expansion); ``mir_to_isa._emit_loop_unroll`` clones each iter's + body into a throwaway scratch block, substitutes the loop_var + to a plain int, runs scratch-local constant folding, then emits + the folded result. This collapses per-iter address arithmetic + without inflating the MIR, and is consistently shorter ISA than + the previous "physically unroll then re-fold" strategy because + LICM's hoisted invariants survive (they aren't duplicated per + unroll iter). + """ + # ---------- 0. stmt prep ---------- + func = _stmt_inline_let.run(prim_func) + func = _stmt_lower_compound.run(func) + # Hoist FP literals (T.float16(c) etc.) into auto-synthesised + # ``global.fpram`` 1-slot buffers so the kernel author doesn't have + # to declare a SCALE / NEG_INF / etc. fragment + a testbench + # preload by hand. See hoist_float_constants.py for the contract. + func = _stmt_hoist_consts.run(func) + + # ---------- 1. mid_ir pipeline ---------- + func = _mid_infer_lane_axis.run(func) + midfn = _mid_fold.run(func, name=name) + midfn = _mid_mark.run(midfn) + midfn = _mid_split.run(midfn) + midfn = _mid_distribute.run(midfn) + midfn = _mid_async.run(midfn) + midfn = _mid_view.run(midfn) + midfn = _mid_fuse.run(midfn) + midfn = _mid_burn.run(midfn) + mod = _mid_to_plena.run(midfn, build_dir=midir_dump_dir, mlen=target.mlen) + + # DEBUG: dump HLIR immediately after to_plena so we can inspect it + # even when later passes fail. + if midir_dump_dir is not None: + from .hlir import format_hlir as _fmt + (midir_dump_dir / "post_to_plena.hlir.txt").write_text(_fmt(mod)) + + # ---------- 1.25. loop interchange + fusion to a fixed point ---------- + # to_plena lowers each per-lane op into its own for-loop. Two + # structural passes alternate until the IR stops changing: + # * loop_interchange — lifts a cluster ``for`` out of an enclosing + # loop so it becomes a sibling of other cluster loops; + # * fuse_adjacent_loops — merges adjacent same-shape loops. + # Alternating both to a fixed point lets interchange expose a fusion + # opportunity, fusion expose a further interchange, and so on. + # Structural-only — runs before address allocation. The iteration + # cap is a safety net; convergence is monotone (each step strictly + # reduces loop count or nesting) so it terminates well before it. + for _ in range(64): + mod, _ic_changed = _loop_interchange.run(mod) + mod, _fu_changed = _fuse_adjacent_loops.run(mod) + if not (_ic_changed or _fu_changed): + break + + # ---------- 1.5. drop unreachable buffers ---------- + # Buffers declared in the kernel but not referenced by any HLIR op + # (e.g. softmax-state fragments in a stub kernel that bypasses + # softmax) would otherwise waste FPRAM/VRAM and can also crash + # downstream shape checks if their post-expansion layout doesn't + # match the lane mode that was never inferred. + _dead_buffer_elim.run(mod) + + # ---------- 2. address alloc ---------- + if addr_config_override is not None: + addr_cfg = addr_config_override + else: + addr_cfg = AddressAllocConfig( + mlen=target.mlen, + blen=target.blen, + hlen=target.btmm_hlen, + ) + addr_pass = AddressAllocationPass(addr_cfg) + addr_pass.run(mod) + + # ---------- 2.5. loop-register allocation ---------- + # Assign each serial ``for`` loop's C_LOOP counter (gp_loop) a GP by + # HLIR liveness, stamping it on the op. The returned set is reserved + # away from the emit-stage allocator so per-op temporaries can never + # collide with a loop counter. See doc/LOOP_REGISTER_ALLOC.md. + loop_reserved_gp = _loop_register_alloc.run(mod) + + # ---------- 3. ISA emit ---------- + allocator = RegisterAllocator( + gp_reserved=(0, *sorted(loop_reserved_gp)), + ) + shim = make_shim( + mlen=target.mlen, + blen=target.blen, + btmm_lane_count=target.btmm_lane_count, + btmm_hlen=target.btmm_hlen, + v_prefetch_amount=_plena_settings.v_prefetch_amount(), + v_writeback_amount=_plena_settings.v_writeback_amount(), + register_allocator=allocator, + ) + if use_v2: + # v2 path: HLIR → PreIsaIR v2 → MIR → opt pipeline → ISA. + # PreIsaPassV2 still delegates layout/offset helpers to a + # legacy IsaEmitterPass instance, but the *visible* ISA + # output here comes from the v2 MIR emit. + from .pre_isa_pass_v2 import PreIsaPassV2 + from . import pre_isa_to_mir as _p2m + from . import mir as _mir + from . import mir_to_isa as _m2i + from . import mir_passes as _mp + pre = PreIsaPassV2(shim).run(mod) + mir_fn = _p2m.convert(pre, shim) + _mir.verify(mir_fn) + # Per-pass MIR dump (debugging): one .mir file per pass under + # ``/mir_passes/`` so each address-PrimExpr fold + # is visible step by step. + _mir_dump_dir = ( + (midir_dump_dir / "mir_passes") if midir_dump_dir else None + ) + # DLE + const-fold + DCE + CSE to a fixed point. Reduces + # GP pressure (peels extent-1 loops, folds static + # address arithmetic, deduplicates repeated bases). + _mp.run_default_pipeline( + mir_fn, enable_licm=True, dump_dir=_mir_dump_dir, + ) + _mir.verify(mir_fn) + isa_text = _m2i.emit(mir_fn, shim) + return CompiledKernel( + name=name, hlir=mod, isa_text=isa_text, + gp_trace=allocator.trace_rows(), + lowir_log=[], + ) + + isa_pass = IsaEmitterPass(shim) + # Record symbolic address expressions for the lowir report. Enabled + # before run() so the recorder captures the real emit pass — no + # second codegen pass, no drift from the actual ISA. + isa_pass.materializer.enable_lowir_log() + isa_text = isa_pass.run(mod) + + return CompiledKernel( + name=name, hlir=mod, isa_text=isa_text, + gp_trace=allocator.trace_rows(), + lowir_log=list(isa_pass.materializer.lowir_log()), + ) + + +def compile_module( + mod: tvm.IRModule, + *, + target: PlenaTarget, +) -> dict: + out = {} + for gv, func in mod.functions.items(): + if not isinstance(func, tir.PrimFunc): + continue + out[gv.name_hint] = compile_kernel(func, target=target, name=gv.name_hint) + return out + + +__all__ = ["PlenaTarget", "CompiledKernel", "compile_kernel", "compile_module"] diff --git a/tilelang_tvm_compiler/plena_settings.py b/tilelang_tvm_compiler/plena_settings.py new file mode 100644 index 0000000..307cef6 --- /dev/null +++ b/tilelang_tvm_compiler/plena_settings.py @@ -0,0 +1,155 @@ +"""Single source of truth for hardware sizes — reads ``plena_settings.toml``. + +The simulator config (``PLENA_Simulator/plena_settings.toml``) already +holds the hardware geometry: MLEN / HLEN / BLEN / VLEN, separately for +the ``analytic`` and ``behavior`` modes, with ``[MODE].active`` picking +one. Previously the compiler hard-coded its own copies (PlenaTarget +defaults, CLI defaults, ``cluster_guard.MLEN``, ``split._DEFAULT_LANE``, +per-kernel ``MLEN`` / ``hlen`` args) — so changing target geometry meant +editing several files by hand and keeping them in sync. + +This module reads the toml once and exposes the active mode's sizes, so +every compiler component can derive geometry from one place. + +Override the toml path with the ``PLENA_SETTINGS`` environment variable +(useful for tests / non-default targets). +""" + +from __future__ import annotations + +import os +import tomllib +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + + +class PlenaSettingsError(RuntimeError): + pass + + +def _default_settings_path() -> Path: + """Locate ``plena_settings.toml``. + + Honours ``$PLENA_SETTINGS`` first; otherwise walks up from this file + to ``PLENA_Simulator/`` (this module lives at + ``PLENA_Simulator/compiler/tilelang_tvm_compiler/plena_settings.py``, + so the toml is three parents up).""" + env = os.environ.get("PLENA_SETTINGS") + if env: + return Path(env) + return Path(__file__).resolve().parents[2] / "plena_settings.toml" + + +@dataclass(frozen=True) +class HardwareSizes: + """The active mode's hardware geometry, read from plena_settings.toml. + + ``mlen`` — matrix/vector lane width (full HW vector). + ``hlen`` — narrow head dim per BTMM lane. + ``blen`` — block tile width. + ``vlen`` — vector SRAM row width. + ``v_prefetch_amount`` — number of VLEN-wide rows one H_PREFETCH_V + instruction transfers (HBM_V_Prefetch_Amount). + ``v_writeback_amount`` — number of VLEN-wide rows one H_STORE_V + instruction transfers (HBM_V_Writeback_Amount). + """ + mode: str + mlen: int + hlen: int + blen: int + vlen: int + v_prefetch_amount: int + v_writeback_amount: int + + @property + def hardware_lane_count(self) -> int: + """Number of BTMM head lanes packed into one MLEN vector.""" + return self.mlen // self.hlen + + +@lru_cache(maxsize=None) +def load_sizes(path: str | None = None) -> HardwareSizes: + """Parse ``plena_settings.toml`` and return the active mode's sizes. + + ``path`` defaults to :func:`_default_settings_path`. Cached — the + toml is read once per process. + """ + toml_path = Path(path) if path is not None else _default_settings_path() + if not toml_path.is_file(): + raise PlenaSettingsError( + f"plena_settings.toml not found at {toml_path}. Set " + f"$PLENA_SETTINGS to point at it." + ) + with open(toml_path, "rb") as f: + cfg = tomllib.load(f) + + # The transactional (behavior) emulator hard-codes reading the + # [BEHAVIOR] section (see transactional_emulator load_config.rs — + # ``#[serde(rename = "BEHAVIOR")]``). The compiler must read the + # SAME section, or its geometry drifts from the emulator's: an + # analytic-mode MLEN=512 against a behavior-mode MLEN=64 emulator + # produces addresses too large for the 32-bit instruction immediate + # field. So we ignore [MODE].active and always use [BEHAVIOR]. + active = "behavior" + section = "BEHAVIOR" + if section not in cfg: + raise PlenaSettingsError( + f"{toml_path}: no [{section}] section exists" + ) + config = cfg[section].get("CONFIG", {}) + + def _val(key: str) -> int: + try: + return int(config[key]["value"]) + except KeyError as e: + raise PlenaSettingsError( + f"{toml_path}: [{section}.CONFIG.{key}] missing or has no " + f"'value' field" + ) from e + + return HardwareSizes( + mode=active, + mlen=_val("MLEN"), + hlen=_val("HLEN"), + blen=_val("BLEN"), + vlen=_val("VLEN"), + v_prefetch_amount=_val("HBM_V_Prefetch_Amount"), + v_writeback_amount=_val("HBM_V_Writeback_Amount"), + ) + + +# Convenience accessors — every call routes through the cached loader. +def mlen() -> int: + return load_sizes().mlen + + +def hlen() -> int: + return load_sizes().hlen + + +def blen() -> int: + return load_sizes().blen + + +def vlen() -> int: + return load_sizes().vlen + + +def v_prefetch_amount() -> int: + return load_sizes().v_prefetch_amount + + +def v_writeback_amount() -> int: + return load_sizes().v_writeback_amount + + +__all__ = [ + "HardwareSizes", + "PlenaSettingsError", + "load_sizes", + "mlen", + "hlen", + "blen", + "vlen", +] diff --git a/tilelang_tvm_compiler/pre_isa_ir.py b/tilelang_tvm_compiler/pre_isa_ir.py new file mode 100644 index 0000000..0d0fda6 --- /dev/null +++ b/tilelang_tvm_compiler/pre_isa_ir.py @@ -0,0 +1,274 @@ +"""PreIsaIR — the var-ref IR layer between IsaPass and the ISA emitter. + +Pipeline position: + + HLIR + | + v IsaPass — every handler appends PreIsaOps. Operand math + | (the PrimExpr addresses) is identical to the + | pre-PreIsaIR code path; what changes is the SINK: + | instead of materialising the expr to a GP register + | and emitting an ISA line, the handler records a + | PreIsaOp(opcode, operands=[PrimExpr, ...]). + v + PreIsaIR -- one PreIsaOp == one PLENA ISA instruction. + | operands are tir.PrimExpr / int / str (loop vars are + | still ``tir.Var``; nothing is bound to a GP yet). + | + v pre_isa_optimize — arith.simplify on every operand; CSE + | across PreIsaOps within a loop region; + | LICM hoisting subexprs that don't depend + | on the enclosing C_LOOP_START's binds var. + v + PreIsaIR (optimised) + | + v BackendEmit — walks the PreIsaOp stream linearly: + | for each op, materialises each PrimExpr operand + | to a GP register via the existing + | ExprMaterializer, then emits one ISA text line + | per the opcode's template. + v + ISA text + +Iron rule: every PreIsaOp emits exactly one PLENA HW instruction. +C_LOOP_START / C_LOOP_END are themselves PreIsaOps — a "for-loop" in +PreIsaIR is two HW-loop PreIsaOps with body PreIsaOps between them, +sitting in one flat stream. There is no PreIsaFor container — structure +is by C_LOOP_START / C_LOOP_END matching the way the HW executes. + +Operand semantics: + * tir.PrimExpr — symbolic address / value. Loop vars are unresolved + tir.Var; BackendEmit calls ExprMaterializer at lowering time. + * int — compile-time-known immediate (loop count, mask bit, flag). + * str — hard-coded token (e.g. "f0", "a3") dropped verbatim into + the ISA template by BackendEmit. + +This layer carries NO register decisions. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +from tvm import tir + + +# All PLENA hardware mnemonics that may appear as a PreIsaOp.opcode. +# Sourced from the inventory of every single-ISA-line mnemonic the +# pre-PreIsaIR isa_emitter.py / isa_pass.py write into generated_code. +KNOWN_OPCODES = frozenset({ + # control — UNIFIED loop pair. The kind of loop is carried by + # ``annotations["loop_kind"]`` on the LOOP_START PreIsaOp: + # * "serial" (default) → BackendEmit emits a hardware loop: + # S_ST_INT (idx init) + C_LOOP_START gp_loop, extent + # + body once + idx-inc + C_LOOP_END gp_loop. + # * "unroll" → BackendEmit expands the body N times inline, + # binding loop_var to tir.IntImm(init+i) per iteration (no + # hardware C_LOOP, no idx slot, no loop_gp use). + # Optimisation passes treat both kinds uniformly (they're the + # same data shape); a "switch" pass can change ``loop_kind`` on + # any LOOP_START to flip the codegen strategy without touching + # the surrounding IR. + # + # IMPORTANT: PreIsaIR opcodes ``LOOP_START`` / ``LOOP_END`` are + # PreIsaIR-level control markers. The PLENA ISA *mnemonics* + # ``C_LOOP_START`` and ``C_LOOP_END`` that the hardware actually + # executes are emitted as plain text by BackendEmit; they don't + # appear as PreIsaIR opcodes (and need no entry here — they're + # just strings inside emit templates). + "LOOP_START", "LOOP_END", "C_BREAK", + "C_SET_V_MASK_REG", "C_SET_ADDR_REG", + "C_SET_SCALE_REG", "C_SET_STRIDE_REG", + "C_SET_FP_REG", "C_RUN_FP_KERNEL", + # scalar int + "S_ADDI_INT", "S_ADD_INT", "S_SUB_INT", "S_MUL_INT", + "S_LUI_INT", "S_LD_INT", "S_ST_INT", + "S_SLLI_INT", "S_SRLI_INT", "S_SLL_INT", "S_SRL_INT", + "S_MV_INT", + # scalar fp + "S_LD_FP", "S_ST_FP", + "S_ADD_FP", "S_SUB_FP", "S_MUL_FP", "S_MAX_FP", + "S_EXP_FP", "S_RECI_FP", "S_SQRT_FP", + "S_MAP_FP_V", "S_MAP_V_FP", + "S_MV_FP", + # vector + "V_ADD_VV", "V_SUB_VV", "V_MUL_VV", + "V_ADD_VF", "V_SUB_VF", "V_MUL_VF", + "V_EXP_V", "V_RECI_V", "V_SQRT_V", + "V_RED_MAX", "V_RED_SUM", + "V_AND_VV", "V_OR_VV", "V_XOR_VV", "V_NOT_V", + "V_MAX_VV", "V_MIN_VV", "V_MAX_VF", "V_MIN_VF", + "V_SHIFT_V", "V_SHFTL_V", + # matrix + "M_BTMM", "M_BTMM_WO", "M_BMM_WO", + "M_BTMV", "M_BMV_WO", + "M_MV", "M_MV_WO", + "M_MM", "M_MM_WO", + "M_TMM", + # HBM + "H_LOAD_V", "H_STORE_V", "H_PREFETCH_V", "H_PREFETCH_M", + # meta — translated by BackendEmit into a "; ..." comment line. + "_COMMENT", + # meta — forces BackendEmit to materialise the operand PrimExpr + # into a GP register NOW (and cache it in the current group scope), + # without emitting its own ISA line. The S_ADDI_INT / S_LUI_INT + # that the materialiser writes is the actual on-disk evidence; + # this meta-op exists so the PreIsaPass producer can dictate the + # ORDER of address materialisations relative to the HW ops that + # use them, matching the legacy "materialise all addresses first, + # then emit all HW ops" pattern (see _emit_fp_scalar_op_at). + "_PRELOAD_ADDR", + # meta — emit ``S_ADDI_INT gp{N}, gp{N}, stride`` where ``gp{N}`` is + # the register the operand PrimExpr currently lives in (must already + # be in the BackendEmit group cache; produced earlier via a + # _PRELOAD_ADDR or implicit materialise). After the bump the cached + # GP holds ``expr + stride``, NOT the original value of ``expr``. + # This mirrors legacy's destructive in-place stride bump pattern in + # _emit_row_scalar_op_at (and emit_matmul / similar) where a single + # GP is walked across d_tile iterations via repeated S_ADDI_INTs. + # Operands: [cached_addr_expr (PrimExpr), stride (int)]. + "_BUMP_CACHED_GP", + # meta — allocate a PLENA addr register, load its value from the + # operand PrimExpr, and cache the binding under the operand's + # ``id()`` so subsequent PreIsaOps referencing the same Python + # object get the same ``aN`` token. Side-effect emit: + # * ``_load_large_int`` style S_ADDI_INT / S_LUI_INT sequence to + # materialise the value into a scratch GP + # * ``C_SET_ADDR_REG aN, gp0, gp{scratch}`` + # Operand: [addr_value_expr] (tir.PrimExpr). + # The producer must use the SAME Python PrimExpr object across + # all DMA PreIsaOps that reference the same addr register. + "_PRELOAD_ADDR_REG", + # Variant opcodes — same emitted ISA mnemonic as the corresponding + # non-prefixed entry but with cached-GP operand semantics for + # patterns that read the SAME GP across multiple PreIsaOps + # (row_*_at's d_tile unroll). The leading underscore avoids + # colliding with the canonical form's slot signature. + "_V_SUB_VF_ROW", "_V_ADD_VF_ROW", "_V_MUL_VF_ROW", + "_V_EXP_V_ROW", "_V_RECI_V_ROW", + "_S_ADDI_INT_RESET_MASK", + "_S_LD_FP_CACHED", "_S_ST_FP_CACHED", + # H_PREFETCH_V 6-operand variant (batch=1 path in legacy + # _emit_preload_tile_isa); see backend_emit._TEMPLATES. + "_H_PREFETCH_V_6OP", +}) + + +@dataclass +class PreIsaOp: + opcode: str + operands: List[Any] = field(default_factory=list) + binds: Optional[tir.Var] = None + annotations: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if self.opcode not in KNOWN_OPCODES: + raise ValueError( + f"PreIsaOp opcode {self.opcode!r} is not a known PLENA " + f"mnemonic. If this is a real HW instruction add it to " + f"KNOWN_OPCODES (pre_isa_ir.py) after confirming it is " + f"a single-ISA-line atom (one PreIsaOp == one ISA line)." + ) + + +@dataclass +class PreIsaModule: + name: str + ops: List[PreIsaOp] = field(default_factory=list) + buffers: Dict[str, Any] = field(default_factory=dict) + + def append(self, op: PreIsaOp) -> None: + self.ops.append(op) + + def comment(self, text: str) -> None: + self.ops.append(PreIsaOp(opcode="_COMMENT", operands=[text])) + + +def _fmt_operand(x: Any) -> str: + if isinstance(x, (int, float)): + return str(x) + if isinstance(x, str): + return x + if isinstance(x, tir.Var): + return x.name + return str(x) + + +def format_pre_isa(mod: PreIsaModule) -> str: + lines = [f"PreIsaModule(name={mod.name!r})", ""] + if mod.buffers: + lines.append("Buffers:") + name_w = max((len(n) for n in mod.buffers), default=4) + for name, b in mod.buffers.items(): + addr = getattr(b, "address", None) + scope = getattr(b, "scope", "?") + shape = getattr(b, "shape", ()) + shape_s = "x".join(str(s) for s in shape) if shape else "()" + addr_s = "" if addr is None else str(addr) + lines.append( + f" {name:<{name_w}} scope={scope:<5} addr={addr_s:<8} " + f"shape={shape_s}" + ) + lines.append("") + lines.append("Ops:") + indent = 2 + for idx, op in enumerate(mod.ops): + if op.opcode in _LOOP_END_OPCODES: + indent = max(2, indent - 4) + ind = " " * indent + if op.opcode == "_COMMENT": + text = op.operands[0] if op.operands else "" + lines.append(f"{ind}[{idx:4d}] ; {text}") + else: + ops_s = ", ".join(_fmt_operand(o) for o in op.operands) + binds_s = ( + f" binds={op.binds.name}" + if op.binds is not None + else "" + ) + note = op.annotations.get("comment", "") + note_s = f" ; {note}" if note else "" + lines.append( + f"{ind}[{idx:4d}] {op.opcode:<18} {ops_s}{binds_s}{note_s}" + ) + if op.opcode in _LOOP_START_OPCODES: + indent += 4 + return "\n".join(lines) + "\n" + + +_LOOP_START_OPCODES = ("LOOP_START",) +_LOOP_END_OPCODES = ("LOOP_END",) + + +def loop_regions( + ops: List[PreIsaOp], +) -> List[Tuple[int, int, Optional[tir.Var]]]: + """Yield ``(start_idx, end_idx, loop_var)`` for every + ``LOOP_START`` / ``LOOP_END`` pair. Both ``loop_kind="serial"`` + and ``loop_kind="unroll"`` loops use the same opcode pair, so + LICM / CSE / other passes can iterate this uniformly.""" + out: List[Tuple[int, int, Optional[tir.Var]]] = [] + stack: List[Tuple[int, Optional[tir.Var]]] = [] + for i, op in enumerate(ops): + if op.opcode == "LOOP_START": + stack.append((i, op.binds)) + elif op.opcode == "LOOP_END": + if not stack: + raise ValueError( + f"LOOP_END at [{i}] with no matching loop-start" + ) + start_idx, var = stack.pop() + out.append((start_idx, i, var)) + if stack: + raise ValueError( + f"unclosed loop-start(s) at indices " + f"{[s for s, _ in stack]} — missing end marker" + ) + return out + + +__all__ = [ + "PreIsaOp", "PreIsaModule", + "KNOWN_OPCODES", "format_pre_isa", "loop_regions", +] diff --git a/tilelang_tvm_compiler/pre_isa_ir_v2.py b/tilelang_tvm_compiler/pre_isa_ir_v2.py new file mode 100644 index 0000000..e919ee5 --- /dev/null +++ b/tilelang_tvm_compiler/pre_isa_ir_v2.py @@ -0,0 +1,328 @@ +"""PreIsaIR v2 — the clean rewrite. + +The original :mod:`pre_isa_ir` accumulated PreIsaOps for *register +materialisation control* (``_PRELOAD_ADDR``, ``_BUMP_CACHED_GP``, +``_PRELOAD_ADDR_REG``, ``_slot_expr_cached`` family, ``group_id``, +``unroll_scope``, ``scope_floor``, etc.). Those concerns were forced +into this layer because the original "BackendEmit" did +register-allocation and ISA emission in one mixed step. + +The new architecture splits responsibilities: + + PreIsaIR (this file) — what HW instruction, with what symbolic + operand. PrimExpr operands stay in their + most abstract form (full address algebra). + NO register materialisation concept. + MIR (mir.py) — explicit conversion of PrimExpr → SSA value + chains, def/use, loops with loop_kind, ready + for LICM / CSE / DCE / register allocation. + Backend (mir_to_isa) — mechanical SSA → physical-register dispatch. + +What's IN this PreIsaIR: + * PreIsaOp(opcode, operands) where: + - opcode is a literal PLENA ISA mnemonic ("M_MM", "H_PREFETCH_V", + "S_ADDI_INT", "S_LD_FP", ...). NO ``_*`` prefixed variants. + - operands is a list of: + ``tir.PrimExpr`` — any address algebra; PreIsaIR doesn't + fold or evaluate; the next pass does. + ``int`` — compile-time literal immediate. + ``str`` — verbatim token like "f0" / "f1" / "f2" + (FPU register names) or "gp0" (the + hardware-fixed constant-zero source on + instructions where it's encoded as part + of the ISA, e.g. ``S_ADDI_INT _, gp0, _`` + — and only those cases). + * LoopRegion(start, end, loop_var, init_imm, extent_imm, loop_kind) + where ``loop_kind`` is ``"serial"`` (HW C_LOOP) or ``"unroll"`` + (compile-time unrolled). The loop_var is a ``tir.Var`` that body + PreIsaOps reference in their operand PrimExprs. + * Comment lines (``_COMMENT`` opcode) for the human-readable dump. + +What's NOT in this PreIsaIR: + * NO ``_PRELOAD_ADDR`` / ``_PRELOAD_ADDR_REG`` / ``_BUMP_CACHED_GP`` + * NO ``group_id`` / ``close_order`` / ``unroll_scope`` annotations + * NO ``_slot_*_cached`` / ``_*_ROW`` opcode variants + * NO ``addr_reg`` / ``gp_reg`` numbers — those don't exist yet + +Producer contract: + For an HLIR op, the PreIsaPass producer emits a sequence of + PreIsaOps + LoopRegions that, taken together, semantically realise + the legacy ISA emission BUT with addresses left as PrimExprs. + Whether one M_MM ends up with a fresh ``S_ADDI_INT`` ahead of it, + or shares a hoisted register with another M_MM — none of that is + the producer's concern. PreIsaIR → MIR conversion + MIR LICM + decide. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +from tvm import tir + + +# --------------------------------------------------------------------- +# Loop-strategy switch +# --------------------------------------------------------------------- +# When True, every LoopRegion built with ``loop_kind="unroll"`` is +# downgraded to ``"serial"`` (a hardware C_LOOP) at construction time. +# Emit-time unrolling was unsound: it cloned each iter's body into a +# scratch block (minting MirValues the precomputed last_use table never +# saw) and the linear-scan allocator then read released / never-tracked +# operands. Until a real MIR-level unroll pass exists, run everything as +# hardware loops. Flip to False to restore the (broken) emit-time +# unroll path for A/B debugging. +FORCE_SERIAL_LOOPS = True + + +# --------------------------------------------------------------------- +# Opcode set +# --------------------------------------------------------------------- + +# PreIsaIR opcodes are PLENA ISA mnemonics in their *atomic* form — +# one PreIsaOp per single PLENA instruction. No internal variants. +# The PreIsaIR → MIR pass turns each into a MIR instruction with +# operands lowered to SSA values; mir.OPCODES is the source of truth +# for what arguments each PLENA instruction takes. +# +# This set must stay a SUBSET of mir.OPCODES (every PreIsaIR opcode +# must have a corresponding MIR opcode). We don't import mir.OPCODES +# here to avoid a circular import; the conversion pass cross-checks. +KNOWN_OPCODES = frozenset({ + # control + "C_SET_V_MASK_REG", "C_SET_ADDR_REG", + "C_SET_SCALE_REG", "C_SET_STRIDE_REG", + # scalar int + "S_ADDI_INT", "S_ADD_INT", "S_SUB_INT", "S_MUL_INT", + "S_LUI_INT", "S_LD_INT", "S_ST_INT", + "S_SLLI_INT", "S_SRLI_INT", + # scalar fp + "S_LD_FP", "S_ST_FP", + "S_ADD_FP", "S_SUB_FP", "S_MUL_FP", "S_MAX_FP", + "S_EXP_FP", "S_RECI_FP", "S_SQRT_FP", + "S_MAP_FP_V", "S_MAP_V_FP", + # vector + "V_ADD_VV", "V_SUB_VV", "V_MUL_VV", + "V_ADD_VF", "V_SUB_VF", "V_MUL_VF", + "V_EXP_V", "V_RECI_V", "V_SQRT_V", + "V_RED_MAX", "V_RED_SUM", + # matrix + "M_BTMM", "M_BMM_WO", + "M_BTMV", "M_BMV_WO", + "M_MV", "M_MV_WO", + "M_MM", "M_MM_WO", "M_TMM", + # HBM + "H_LOAD_V", "H_STORE_V", "H_PREFETCH_V", "H_PREFETCH_M", + # meta — emitted as ``; ...`` comment, not a real instruction. + "_COMMENT", +}) + + +# --------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------- + +# Operand of a PreIsaOp. +# * tir.PrimExpr — address algebra / value expression; loop_vars +# stay as tir.Var, hw constants as MLEN_VAR / BLEN_VAR +# (defined in hw_consts), IntImm for literals already known. +# * int — compile-time integer immediate (flag bits, loop +# counts on legacy callsites we haven't migrated yet). +# * str — verbatim token. The ONLY allowed verbatim +# strings are PLENA FPU register names ("f0", "f1", "f2") and +# the hardware-encoded constant-zero source "gp0" — these are +# part of the ISA encoding, not register-allocation decisions. +# * PreIsaOp — reference to the result of a previously emitted +# op. Used for ``addr_reg`` operands (the only currently-supported +# result-producing opcode is ``C_SET_ADDR_REG``). The MIR +# converter substitutes the producer's MirValue at lowering +# time. Producer must emit the referenced op BEFORE the +# consumer in the same module body. +PreIsaOperand = Union[tir.PrimExpr, int, str, "PreIsaOp"] + + +@dataclass +class PreIsaOp: + """One PLENA ISA instruction with symbolic operands. + + ``opcode`` is a PLENA mnemonic from :data:`KNOWN_OPCODES`. NO + ``_*``-prefixed variants — those were a leak of register- + materialisation concerns from PreIsaIR v1; in v2 the conversion + pass handles those concerns instead. + """ + + opcode: str + operands: List[PreIsaOperand] = field(default_factory=list) + # Free-form debug annotations (source HLIR op index, intrinsic + # name, etc.). No semantic meaning to any pass — feel free to + # add fields without coordinating. + annotations: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if self.opcode not in KNOWN_OPCODES: + raise ValueError( + f"PreIsaOp opcode {self.opcode!r} is not a known PLENA " + f"mnemonic in pre_isa_ir_v2. If this is a real HW " + f"instruction, add it to KNOWN_OPCODES (and ensure " + f"mir.OPCODES has a matching entry)." + ) + + +@dataclass +class LoopRegion: + """A loop block in the PreIsaIR program structure. + + ``loop_var`` is a ``tir.Var``. Body PreIsaOps reference it via + PrimExpr operands; the PreIsaIR → MIR conversion will lower it + to a MirValue defined by a ``_LOOP_VAR_DEF`` MirInstr at the top + of the loop body block. + + ``loop_kind`` is ``"serial"`` (HW C_LOOP_START / C_LOOP_END) or + ``"unroll"`` (compile-time body replay with loop_var bound to + IntImm per iteration). Optimisation passes may rewrite this + attribute — it's a strategy hint, not a structural property. + + ``body`` is a flat sequence of PreIsaOps + LoopRegions (nested + loops). Order is source order; the conversion pass walks it + verbatim. + + Producer contract for ``loop_var`` choice: every LoopRegion's + loop_var must be a FRESH ``tir.Var`` instance (don't reuse a Var + across LoopRegions; the conversion pass relies on identity). + """ + + loop_var: tir.Var + init_imm: int + extent_imm: int + body: List[Union["PreIsaOp", "LoopRegion"]] = field(default_factory=list) + loop_kind: str = "serial" # "serial" or "unroll" + annotations: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if self.loop_kind not in ("serial", "unroll"): + raise ValueError( + f"LoopRegion.loop_kind must be 'serial' or 'unroll'; " + f"got {self.loop_kind!r}" + ) + # Global force-serial: emit-time unrolling (clone body into a + # scratch block, fold loop_var→const) created emit-time MirValues + # the single-walk last_use table never saw, corrupting register + # allocation. Until a proper MIR-level unroll pass exists, every + # LoopRegion lowers to a hardware C_LOOP. The address arithmetic + # (base + loop_var*stride) that unroll used to const-fold is + # computed at runtime instead — mathematically equivalent. + # Python-side static unrolls (plain ``for`` in the generator, + # e.g. matmul's partial-tile n_mlen sweep) are NOT LoopRegions + # and are unaffected. + if FORCE_SERIAL_LOOPS and self.loop_kind == "unroll": + self.loop_kind = "serial" + + +@dataclass +class PreIsaModule: + """One kernel's PreIsaIR — a flat sequence of PreIsaOps and + LoopRegions at the top level. Loops can nest arbitrarily. + """ + + name: str + body: List[Union[PreIsaOp, LoopRegion]] = field(default_factory=list) + # Buffer table forwarded from HLIR. Used by the dump + the + # MIR conversion pass for buffer-address resolution. + buffers: Dict[str, Any] = field(default_factory=dict) + + def append(self, item: Union[PreIsaOp, LoopRegion]) -> None: + self.body.append(item) + + def comment(self, text: str) -> None: + self.body.append(PreIsaOp(opcode="_COMMENT", operands=[text])) + + +# --------------------------------------------------------------------- +# Dump +# --------------------------------------------------------------------- + +def _fmt_operand(op: PreIsaOperand) -> str: + if isinstance(op, str): + return op + if isinstance(op, int): + return str(op) + if isinstance(op, tir.IntImm): + return str(int(op.value)) + if isinstance(op, tir.Var): + return op.name + # Generic PrimExpr — str() preserves loop var names. + return str(op) + + +def format_pre_isa_v2(mod: PreIsaModule) -> str: + """Pretty-print PreIsaIR v2 — used for ``.pre_isa.txt``.""" + lines = [f"PreIsaModule(name={mod.name!r}):"] + if mod.buffers: + lines.append(" Buffers:") + name_w = max((len(n) for n in mod.buffers), default=4) + for nm, b in mod.buffers.items(): + scope = getattr(b, "scope", "?") + shape = getattr(b, "shape", ()) + addr = getattr(b, "address", None) + shape_s = "x".join(str(s) for s in shape) if shape else "()" + addr_s = "?" if addr is None else str(addr) + lines.append( + f" {nm:<{name_w}} scope={scope:<5} addr={addr_s} " + f"shape={shape_s}" + ) + lines.append(" Body:") + for item in mod.body: + _fmt_item(item, lines, indent=4) + return "\n".join(lines) + "\n" + + +def _fmt_item(item, lines, indent): + ind = " " * indent + if isinstance(item, PreIsaOp): + if item.opcode == "_COMMENT": + text = item.operands[0] if item.operands else "" + lines.append(f"{ind}; {text}") + return + ops_s = ", ".join(_fmt_operand(o) for o in item.operands) + lines.append(f"{ind}{item.opcode:<18} {ops_s}") + return + if isinstance(item, LoopRegion): + lines.append( + f"{ind}loop {item.loop_var.name} in " + f"[{item.init_imm}, {item.init_imm + item.extent_imm}) " + f"[kind={item.loop_kind}]" + ) + for inner in item.body: + _fmt_item(inner, lines, indent + 2) + return + raise TypeError( + f"_fmt_item: expected PreIsaOp or LoopRegion, got " + f"{type(item).__name__}" + ) + + +# --------------------------------------------------------------------- +# Sanity / structural helpers +# --------------------------------------------------------------------- + +def loop_regions( + body: List[Union[PreIsaOp, LoopRegion]], +) -> List[Tuple[LoopRegion, int]]: + """Return ``(loop, depth)`` for every LoopRegion in pre-order.""" + out: List[Tuple[LoopRegion, int]] = [] + + def _walk(items, depth): + for it in items: + if isinstance(it, LoopRegion): + out.append((it, depth)) + _walk(it.body, depth + 1) + _walk(body, 0) + return out + + +__all__ = [ + "PreIsaOp", "LoopRegion", "PreIsaModule", + "KNOWN_OPCODES", "PreIsaOperand", + "format_pre_isa_v2", "loop_regions", +] diff --git a/tilelang_tvm_compiler/pre_isa_pass.py b/tilelang_tvm_compiler/pre_isa_pass.py new file mode 100644 index 0000000..c7412e0 --- /dev/null +++ b/tilelang_tvm_compiler/pre_isa_pass.py @@ -0,0 +1,3257 @@ +"""PreIsaPass — lower HLIR to PreIsaIR. + +Replaces the emit half of the legacy ``IsaEmitterPass``. For each HLIR +op: + * Same address-math code as before (reuses helpers on ``IsaEmitterPass`` + by composition; this pass holds an instance to delegate to for now, + while the migration is in progress). + * Instead of calling ``materializer.materialize`` + ``ISAEmitter.emit_*`` + + ``generated_code +=``, the handler appends one or more PreIsaOps + to ``self.pre_isa`` (a ``PreIsaModule``). Operands are kept as + ``tir.PrimExpr`` (var-ref form); register allocation happens later + in :class:`backend_emit.BackendEmit`. + +Iron rule: one PreIsaOp == one HW ISA instruction. ``C_LOOP_START`` / +``C_LOOP_END`` are themselves PreIsaOps in the same flat stream. + +Migration status: + Migrated handlers (produce PreIsaIR): + * fp_zero_at + All other op kinds delegate to the legacy ``IsaEmitterPass`` and run + through the byte-equal old path. As each handler migrates, the + delegate fallback is removed. +""" + +from __future__ import annotations + +import warnings +from typing import Dict, Callable + +from tvm import tir + +from . import hlir as _hlir +from . import scope as _scope +from .hw_consts import ( + BLEN_VAR, BTMM_HLEN_VAR, MLEN_VAR, + V_PREFETCH_AMOUNT_VAR, V_WRITEBACK_AMOUNT_VAR, +) +from .pre_isa_ir import PreIsaModule, PreIsaOp + + +class PreIsaPassError(RuntimeError): + pass + + +class PreIsaPass: + """Lower an HLIRModule into a PreIsaModule. + + Construction takes the same shim as the legacy IsaEmitterPass — + we reuse its address-resolution helpers (``_resolve_fp_scalar_addr_arg`` + in particular) by composing the legacy pass instance. + """ + + def __init__(self, shim) -> None: + # Lazy import to avoid the dependency cycle isa_pass ↔ pre_isa_pass + # at module-import time. + from .isa_pass import IsaEmitterPass + self.shim = shim + self._legacy = IsaEmitterPass(shim) + self.pre_isa = PreIsaModule(name="") + + # Counter for ``annotations["group_id"]`` — BackendEmit groups + # consecutive PreIsaOps sharing this id into one materialisation + # scope (so e.g. an FP-binary's 3 addresses get materialised + # once across its 5 ISA lines, mirroring the legacy + # ra.pin_gp pattern in IsaEmitterPass). + self._next_group_id: int = 0 + + self._dispatch: Dict[str, Callable[[_hlir.HLIRModule, _hlir.Op], None]] = { + "fp_zero_at": self._emit_fp_zero_at, + "fp_copy_at": lambda m, o: self._emit_fp_scalar_op_at(m, o, kernel_op="copy"), + "fp_exp_at": lambda m, o: self._emit_fp_scalar_op_at(m, o, kernel_op="exp"), + "fp_reci_at": lambda m, o: self._emit_fp_scalar_op_at(m, o, kernel_op="reci"), + "fp_sqrt_at": lambda m, o: self._emit_fp_scalar_op_at(m, o, kernel_op="sqrt"), + "fp_add_at": lambda m, o: self._emit_fp_scalar_op_at(m, o, kernel_op="add"), + "fp_sub_at": lambda m, o: self._emit_fp_scalar_op_at(m, o, kernel_op="sub"), + "fp_mul_at": lambda m, o: self._emit_fp_scalar_op_at(m, o, kernel_op="mul"), + "fp_max_at": lambda m, o: self._emit_fp_scalar_op_at(m, o, kernel_op="max"), + "v_zero": self._emit_v_zero, + "v_add": lambda m, o: self._emit_v_binary(m, o, binary_op="add"), + "v_sub": lambda m, o: self._emit_v_binary(m, o, binary_op="sub"), + "v_mul": lambda m, o: self._emit_v_binary(m, o, binary_op="mul"), + "v_exp": lambda m, o: self._emit_v_unary(m, o, opcode="V_EXP_V"), + "v_reci": lambda m, o: self._emit_v_unary(m, o, opcode="V_RECI_V"), + "v_sqrt": lambda m, o: self._emit_v_unary(m, o, opcode="V_SQRT_V"), + "copy_v_to_v": self._emit_copy_v_to_v, + "v_fp_transfer_slice_v_to_fp": + lambda m, o: self._emit_v_fp_transfer_slice(m, o, direction="v_to_fp"), + "v_fp_transfer_slice_fp_to_v": + lambda m, o: self._emit_v_fp_transfer_slice(m, o, direction="fp_to_v"), + "for": self._emit_for, + "row_reduce_max_at": + lambda m, o: self._emit_row_scalar_op_at( + m, o, row_op="reduce_max", reduce=True, masked=True, + ), + "row_reduce_sum_at": + lambda m, o: self._emit_row_scalar_op_at( + m, o, row_op="reduce_sum", reduce=True, masked=True, + ), + "row_exp": + lambda m, o: self._emit_row_scalar_op_at( + m, o, row_op="exp", masked=True, + ), + "row_sub_fp": + lambda m, o: self._emit_row_scalar_op_at( + m, o, row_op="sub", masked=True, has_fp=True, + ), + "row_mul_fp": + lambda m, o: self._emit_row_scalar_op_at( + m, o, row_op="mul", masked=True, has_fp=True, + ), + "row_add_fp": + lambda m, o: self._emit_row_scalar_op_at( + m, o, row_op="add", masked=True, has_fp=True, + ), + "btmm": self._emit_btmm, + "btmv": self._emit_btmv, + "mv": self._emit_mv, + "mm": self._emit_mm, + "mm_slot": self._emit_mm_slot, + "matmul": self._emit_matmul, + "dma_h2v": self._emit_dma_h2v, + "dma_h2m": self._emit_dma_h2m, + "dma_v2h": self._emit_dma_v2h, + "dma_h2v_slice": self._emit_dma_h2v_slice, + "dma_h2m_slice": self._emit_dma_h2m_slice, + "dma_v2h_slice": self._emit_dma_v2h_slice, + } + + def _new_group(self) -> int: + gid = self._next_group_id + self._next_group_id += 1 + return gid + + def run(self, mod: _hlir.HLIRModule) -> PreIsaModule: + """Produce a PreIsaModule for ``mod``. Buffers / addresses are + forwarded verbatim onto the PreIsaModule for BackendEmit / + the dump.""" + _hlir.assert_addresses_resolved(mod) + self.pre_isa = PreIsaModule(name=mod.name, buffers=dict(mod.buffers)) + self.pre_isa.comment(f"PLENA ISA -- kernel: {mod.name}") + self.pre_isa.comment("generated by tilelang_tvm_compiler (PreIsaIR path)") + self.pre_isa.comment("=" * 60) + self.pre_isa.comment("buffer layout:") + for buf in mod.buffers.values(): + shape_s = "x".join(str(s) for s in buf.shape) + self.pre_isa.comment( + f" {buf.name:<10s} scope={buf.scope:<5s} addr={buf.address} " + f"shape={shape_s}" + ) + self.pre_isa.comment("=" * 60) + self.pre_isa.comment("") + for op in mod.ops: + handler = self._dispatch.get(op.kind) + if handler is None: + raise PreIsaPassError( + f"PreIsaPass: no handler migrated yet for HLIR op " + f"kind {op.kind!r}. While migration is in progress " + f"callers must dispatch to the legacy IsaEmitterPass " + f"for non-migrated kinds." + ) + handler(mod, op) + return self.pre_isa + + # ================================================================== + # migrated handlers + # ================================================================== + def _emit_fp_zero_at(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Store FP zero to one FPRAM slot. + + Legacy emission (isa_pass._emit_fp_zero_at): + ; fp scalar task op=zero + S_ST_FP f0, gp{dst}, 0 + + PreIsaIR emission (var-ref; one PreIsaOp per ISA line): + _COMMENT "fp scalar task op=zero" + S_ST_FP ["f0", , 0] + """ + if len(op.scalar_args) != 1: + raise PreIsaPassError( + f"{op.kind} expects 1 scalar address arg, got {len(op.scalar_args)}" + ) + dst_addr_expr = self._legacy._resolve_fp_scalar_addr_arg( + mod, op.scalar_args[0], op.kind, "dst", + ) + intrinsic = op.annotations.get("intrinsic", op.kind) + gid = self._new_group() + # Order matches legacy isa_pass._emit_fp_zero_at: there is no + # upfront materialisation loop in legacy fp_zero_at (only one + # address) — the S_ADDI_INT comes from materialise() right + # before the S_ST_FP. So we DON'T emit a _PRELOAD_ADDR here. + self.pre_isa.append( + PreIsaOp( + opcode="_COMMENT", + operands=[f"fp scalar task {intrinsic} op=zero"], + annotations={"group_id": gid}, + ) + ) + self.pre_isa.append( + PreIsaOp( + opcode="S_ST_FP", + operands=["f0", dst_addr_expr, 0], + annotations={"group_id": gid}, + ) + ) + + def _emit_fp_scalar_op_at( + self, mod: _hlir.HLIRModule, op: _hlir.Op, *, kernel_op: str, + ) -> None: + """Mirror of legacy isa_pass._emit_fp_scalar_op_at — FP scalar + load/op/store sequence on FPRAM operands. + + For copy/exp/reci/sqrt (2 scalar_args = src, dst): + ; fp scalar task op= + S_LD_FP f1, gp{src}, 0 + [S_EXP_FP f1, f1, 0 | S_RECI_FP f1, f1 | S_SQRT_FP f1, f1] (only for exp/reci/sqrt) + S_ST_FP f1, gp{dst}, 0 + + For add/sub/mul/max (3 scalar_args = lhs, rhs, dst): + ; fp scalar task op= + S_LD_FP f1, gp{lhs}, 0 + S_LD_FP f2, gp{rhs}, 0 + f1, f1, f2 + S_ST_FP f1, gp{dst}, 0 + + All addresses go through ``_resolve_fp_scalar_addr_arg`` so a + ``BufferElement`` resolves to ``buf.address + flat-offset`` — + identical to legacy. The address PrimExpr objects are CREATED + ONCE and reused across the multiple PreIsaOps in this op so + BackendEmit's group cache materialises each address one time. + """ + if kernel_op in {"copy", "exp", "reci", "sqrt"}: + expected = 2 + else: + expected = 3 + if len(op.scalar_args) != expected: + raise PreIsaPassError( + f"{op.kind} expects {expected} scalar address args, got {len(op.scalar_args)}" + ) + + addr_exprs = [ + self._legacy._resolve_fp_scalar_addr_arg( + mod, a, op.kind, f"arg{i}", + ) + for i, a in enumerate(op.scalar_args) + ] + intrinsic = op.annotations.get("intrinsic", op.kind) + gid = self._new_group() + + def _stamp(op_): + op_.annotations["group_id"] = gid + return op_ + + # Order matches legacy isa_pass._emit_fp_scalar_op_at: + # 1. for each address, materialise (S_ADDI_INT / S_LUI_INT) + # 2. emit the ``; fp scalar task ...`` comment + # 3. emit the S_LD_FP / OP / S_ST_FP burst + # PreIsaIR encodes step 1 as _PRELOAD_ADDR meta-ops so the + # group cache pre-populates before the FP ops materialise + # operands lazily. + for addr in addr_exprs: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[addr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_COMMENT", + operands=[f"fp scalar task {intrinsic} op={kernel_op}"], + ))) + if kernel_op in {"copy", "exp", "reci", "sqrt"}: + src, dst = addr_exprs + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_LD_FP", operands=["f1", src, 0], + ))) + if kernel_op == "exp": + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_EXP_FP", operands=["f1", "f1", 0], + ))) + elif kernel_op == "reci": + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_RECI_FP", operands=["f1", "f1"], + ))) + elif kernel_op == "sqrt": + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_SQRT_FP", operands=["f1", "f1"], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_ST_FP", operands=["f1", dst, 0], + ))) + else: + lhs, rhs, dst = addr_exprs + opcode_map = { + "add": "S_ADD_FP", + "sub": "S_SUB_FP", + "mul": "S_MUL_FP", + "max": "S_MAX_FP", + } + opcode = opcode_map[kernel_op] + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_LD_FP", operands=["f1", lhs, 0], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_LD_FP", operands=["f2", rhs, 0], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode=opcode, operands=["f1", "f1", "f2"], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_ST_FP", operands=["f1", dst, 0], + ))) + + # ------------------------------------------------------------------ + # vector ops — VRAM region-based, walks chunks via the legacy + # ``_vram_region_iter_chunks`` helper. One HLIR op may emit N + # PreIsaOps (N = total chunk count across non-cluster outer dims). + # ------------------------------------------------------------------ + def _emit_v_zero(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_v_zero``: + ; v_zero dst.parent=... starts=... extents=... + for each chunk: + S_ADDI_INT gp{r}, gp0, dst.address + d_off + V_MUL_VF gp{r}, gp{r}, f0, 0 + """ + if len(op.buffer_args) != 1: + raise PreIsaPassError( + f"v_zero expects 1 buffer_arg; got {len(op.buffer_args)}" + ) + if not isinstance(op.buffer_args[0], _hlir.VramRegion): + raise PreIsaPassError( + f"v_zero dst: expected VramRegion, got " + f"{type(op.buffer_args[0]).__name__}" + ) + if op.scalar_args: + raise PreIsaPassError( + f"v_zero expects 0 scalar_args; got {len(op.scalar_args)}" + ) + dst_region: _hlir.VramRegion = op.buffer_args[0] + dst = mod.get_buffer(dst_region.parent) + self.pre_isa.comment( + f"v_zero dst.parent={dst_region.parent} " + f"starts={list(dst_region.starts)!r} " + f"extents={list(dst_region.extents)!r}" + ) + for d_off, _ in self._legacy._vram_region_iter_chunks(dst, dst_region): + dst_addr = tir.Add( + tir.IntImm("int32", int(dst.address)), d_off, + ) + gid = self._new_group() + # V_MUL_VF dst, dst, f0, 0 — both vector operand slots + # point at the SAME ``dst_addr`` PrimExpr object so + # BackendEmit's id()-keyed group cache materialises it + # once and reuses the GP for both gpN positions. + self.pre_isa.append(PreIsaOp( + opcode="V_MUL_VF", + operands=[dst_addr, dst_addr, "f0", 0], + annotations={"group_id": gid}, + )) + + def _emit_v_binary(self, mod: _hlir.HLIRModule, op: _hlir.Op, + *, binary_op: str) -> None: + """Mirror of legacy ``isa_pass._emit_v_binary``. + ; v binary dst.parent=... starts=... extents=... + for each (l_off, r_off, d_off): + S_ADDI_INT gp{lhs}, ... + S_ADDI_INT gp{rhs}, ... + S_ADDI_INT gp{dst}, ... + gp{dst}, gp{lhs}, gp{rhs}, 0 + """ + op_to_insn = {"add": "V_ADD_VV", "sub": "V_SUB_VV", "mul": "V_MUL_VV"} + opcode = op_to_insn[binary_op] + if len(op.buffer_args) != 3: + raise PreIsaPassError( + f"{op.kind} expects 3 buffer_args; got {len(op.buffer_args)}" + ) + for slot, name in enumerate(("lhs", "rhs", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise PreIsaPassError( + f"{op.kind} {name}: expected VramRegion, got " + f"{type(op.buffer_args[slot]).__name__}" + ) + if op.scalar_args: + raise PreIsaPassError( + f"{op.kind} expects 0 scalar_args; got {len(op.scalar_args)}" + ) + lhs_region: _hlir.VramRegion = op.buffer_args[0] + rhs_region: _hlir.VramRegion = op.buffer_args[1] + dst_region: _hlir.VramRegion = op.buffer_args[2] + lhs = mod.get_buffer(lhs_region.parent) + rhs = mod.get_buffer(rhs_region.parent) + dst = mod.get_buffer(dst_region.parent) + self.pre_isa.comment( + f"v binary {op.kind} {opcode} " + f"dst.parent={dst_region.parent} " + f"starts={list(dst_region.starts)!r} " + f"extents={list(dst_region.extents)!r}" + ) + lhs_iter = self._legacy._vram_region_iter_chunks(lhs, lhs_region) + rhs_iter = self._legacy._vram_region_iter_chunks(rhs, rhs_region) + dst_iter = self._legacy._vram_region_iter_chunks(dst, dst_region) + for (l_off, _), (r_off, _), (d_off, _) in zip( + lhs_iter, rhs_iter, dst_iter, + ): + lhs_addr = tir.Add(tir.IntImm("int32", int(lhs.address)), l_off) + rhs_addr = tir.Add(tir.IntImm("int32", int(rhs.address)), r_off) + dst_addr = tir.Add(tir.IntImm("int32", int(dst.address)), d_off) + gid = self._new_group() + + def _stamp(o): + o.annotations["group_id"] = gid + return o + + # Preload order = legacy materialise order = lhs, rhs, dst. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[lhs_addr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[rhs_addr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[dst_addr], + ))) + # ISA op operand order in template = dst, lhs, rhs, 0 + # (matches legacy + # ``f"{opcode} gp{m_dst}, gp{m_lhs}, gp{m_rhs}, 0\n"``). + self.pre_isa.append(_stamp(PreIsaOp( + opcode=opcode, + operands=[dst_addr, lhs_addr, rhs_addr, 0], + ))) + + def _emit_v_unary(self, mod: _hlir.HLIRModule, op: _hlir.Op, + *, opcode: str) -> None: + """Mirror of legacy ``isa_pass._emit_v_unary``: + ; v unary dst.parent=... starts=... extents=... + for each (s_off, d_off): + S_ADDI_INT gp{src}, ... + S_ADDI_INT gp{dst}, ... + gp{dst}, gp{src}, 0 + """ + if len(op.buffer_args) != 2: + raise PreIsaPassError( + f"{op.kind} expects 2 buffer_args; got {len(op.buffer_args)}" + ) + for slot, name in enumerate(("src", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise PreIsaPassError( + f"{op.kind} {name}: expected VramRegion, got " + f"{type(op.buffer_args[slot]).__name__}" + ) + if op.scalar_args: + raise PreIsaPassError( + f"{op.kind} expects 0 scalar_args; got {len(op.scalar_args)}" + ) + src_region: _hlir.VramRegion = op.buffer_args[0] + dst_region: _hlir.VramRegion = op.buffer_args[1] + src = mod.get_buffer(src_region.parent) + dst = mod.get_buffer(dst_region.parent) + self.pre_isa.comment( + f"v unary {op.kind} {opcode} " + f"dst.parent={dst_region.parent} " + f"starts={list(dst_region.starts)!r} " + f"extents={list(dst_region.extents)!r}" + ) + src_iter = self._legacy._vram_region_iter_chunks(src, src_region) + dst_iter = self._legacy._vram_region_iter_chunks(dst, dst_region) + for (s_off, _), (d_off, _) in zip(src_iter, dst_iter): + src_addr = tir.Add(tir.IntImm("int32", int(src.address)), s_off) + dst_addr = tir.Add(tir.IntImm("int32", int(dst.address)), d_off) + gid = self._new_group() + + def _stamp(o): + o.annotations["group_id"] = gid + return o + + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[src_addr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[dst_addr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode=opcode, + operands=[dst_addr, src_addr, 0], + ))) + + # ------------------------------------------------------------------ + # VRAM <-> VRAM copy and VRAM <-> FPRAM transfer + # ------------------------------------------------------------------ + def _emit_copy_v_to_v(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_copy_v_to_v``: + ; copy_v_to_v ... + for each chunk: + S_ADDI_INT gp{src}, ... + S_ADDI_INT gp{dst}, ... + V_ADD_VF gp{dst}, gp{src}, f0, 0 + """ + if len(op.buffer_args) != 2: + raise PreIsaPassError( + f"copy_v_to_v expects 2 buffer_args; got {len(op.buffer_args)}" + ) + for slot, name in enumerate(("src", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise PreIsaPassError( + f"copy_v_to_v {name}: expected VramRegion, got " + f"{type(op.buffer_args[slot]).__name__}" + ) + if op.scalar_args: + raise PreIsaPassError( + f"copy_v_to_v expects 0 scalar_args; got {len(op.scalar_args)}" + ) + src_region: _hlir.VramRegion = op.buffer_args[0] + dst_region: _hlir.VramRegion = op.buffer_args[1] + src = mod.get_buffer(src_region.parent) + dst = mod.get_buffer(dst_region.parent) + self.pre_isa.comment( + f"copy_v_to_v src.parent={src_region.parent} -> " + f"dst.parent={dst_region.parent} " + f"extents={list(dst_region.extents)!r}" + ) + src_iter = self._legacy._vram_region_iter_chunks(src, src_region) + dst_iter = self._legacy._vram_region_iter_chunks(dst, dst_region) + for (s_off, _), (d_off, _) in zip(src_iter, dst_iter): + src_addr = tir.Add(tir.IntImm("int32", int(src.address)), s_off) + dst_addr = tir.Add(tir.IntImm("int32", int(dst.address)), d_off) + gid = self._new_group() + + def _stamp(o): + o.annotations["group_id"] = gid + return o + + # Legacy materialise order = src, dst. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[src_addr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[dst_addr], + ))) + # V_ADD_VF gp{dst}, gp{src}, f0, 0 + self.pre_isa.append(_stamp(PreIsaOp( + opcode="V_ADD_VF", + operands=[dst_addr, src_addr, "f0", 0], + ))) + + def _emit_v_fp_transfer_slice( + self, mod: _hlir.HLIRModule, op: _hlir.Op, *, direction: str, + ) -> None: + """Mirror of legacy ``isa_pass._emit_v_fp_transfer_slice``. + + direction == "v_to_fp": + ; v↔fp transfer slice ... + for each chunk: + S_ADDI_INT gp{vram}, ... + S_ADDI_INT gp{fp}, ... + S_MAP_FP_V gp{fp}, gp{vram}, 0 + direction == "fp_to_v": + same prelude, then S_MAP_V_FP gp{vram}, gp{fp}, 0 + """ + if len(op.buffer_args) != 1 or not isinstance(op.buffer_args[0], _hlir.VramRegion): + raise PreIsaPassError( + f"{op.kind}: buffer_args[0] must be VramRegion" + ) + if len(op.scalar_args) != 1: + raise PreIsaPassError( + f"{op.kind}: expected 1 scalar arg (fp_addr); got {len(op.scalar_args)}" + ) + region: _hlir.VramRegion = op.buffer_args[0] + vram = mod.get_buffer(region.parent) + fp_addr_base = self._legacy._resolve_fp_scalar_addr_arg( + mod, op.scalar_args[0], op.kind, "fp", + ) + opcode = "S_MAP_FP_V" if direction == "v_to_fp" else "S_MAP_V_FP" + self.pre_isa.comment( + f"v↔fp transfer slice {op.kind} parent={region.parent} " + f"starts={list(region.starts)!r} extents={list(region.extents)!r}" + ) + for vram_off_expr, fp_step in self._legacy._vram_region_iter_chunks(vram, region): + vram_addr = tir.Add( + tir.IntImm("int32", int(vram.address)), vram_off_expr, + ) + fp_chunk_addr = ( + fp_addr_base if fp_step == 0 + else tir.Add(fp_addr_base, tir.IntImm("int32", int(fp_step))) + ) + gid = self._new_group() + + def _stamp(o): + o.annotations["group_id"] = gid + return o + + # Legacy materialise order = vram, fp. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[vram_addr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[fp_chunk_addr], + ))) + if direction == "v_to_fp": + # S_MAP_FP_V gp{fp}, gp{vram}, 0 + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_MAP_FP_V", + operands=[fp_chunk_addr, vram_addr, 0], + ))) + else: + # S_MAP_V_FP gp{vram}, gp{fp}, 0 + self.pre_isa.append(_stamp(PreIsaOp( + opcode="S_MAP_V_FP", + operands=[vram_addr, fp_chunk_addr, 0], + ))) + + + # ------------------------------------------------------------------ + # for-loop: hardware C_LOOP_START / C_LOOP_END + # ------------------------------------------------------------------ + def _emit_for(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_for``. + + Serial form: + ; for {var} in [init, init+extent) -- hw counter gp{loop_gp}, idx ram[{idx}] + [S_ADDI_INT gp{init}, gp0, init_imm] (only when init != 0) + S_ST_INT gp{init or 0}, gp0, {idx_addr} + C_LOOP_START gp{loop_gp}, extent + ... body PreIsaOps ... + ; idx {var} += 1 (ram[{idx_addr}]) + S_LD_INT gp{inc}, gp0, {idx_addr} + S_ADDI_INT gp{inc}, gp{inc}, 1 + S_ST_INT gp{inc}, gp0, {idx_addr} + C_LOOP_END gp{loop_gp} + + Unrolled form: bind loop_var to an IntImm per iteration; emit + body N times back-to-back (no C_LOOP_*, no idx slot). + + BackendEmit's _emit_c_loop_start / _emit_c_loop_end handle + the prelude/epilogue ISA emission so the iron-rule + (one PreIsaOp == one HW instruction) is preserved at the HW + level — the multi-instruction prelude is "address materialise" + side-effect of the C_LOOP_START PreIsaOp (mirroring how + S_ADDI_INT setup for any address operand is side-effect of + the using PreIsaOp). + """ + loop_var = op.annotations.get("loop_var") + extent = op.annotations.get("extent") + init = op.annotations.get("init", 0) + if loop_var is None or extent is None: + raise PreIsaPassError( + f"for-op missing loop_var or extent annotation: {op!r}" + ) + if not isinstance(extent, (int, tir.IntImm)): + raise PreIsaPassError( + f"for-op extent must be a compile-time integer; got " + f"{type(extent).__name__}: {extent!r}" + ) + if not isinstance(init, (int, tir.IntImm)): + raise PreIsaPassError( + f"for-op init must be a compile-time integer; got " + f"{type(init).__name__}: {init!r}" + ) + extent_imm = int(extent.value) if isinstance(extent, tir.IntImm) else int(extent) + init_imm = int(init.value) if isinstance(init, tir.IntImm) else int(init) + loop_kind = op.annotations.get("loop_kind", "serial") + + # ----- compile-time unrolled loop ----- + if loop_kind in ("unroll", "unrolled"): + # Produce a single LOOP_START / LOOP_END pair with + # loop_kind="unroll". BackendEmit's run() walker detects + # this and re-emits the body N times, binding loop_var to + # IntImm(init+i) per iteration. Mirrors legacy emit_for's + # unrolled branch (no idx slot, no loop_gp use, no + # hardware C_LOOP_* lines). + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[init_imm, extent_imm], + binds=loop_var, + annotations={"loop_kind": "unroll"}, + )) + for sub_op in op.body or []: + handler = self._dispatch.get(sub_op.kind) + if handler is None: + raise PreIsaPassError( + f"PreIsaPass: no handler migrated for body op " + f"kind {sub_op.kind!r} inside unrolled for-loop" + ) + handler(mod, sub_op) + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + return + + # ----- serial hardware loop ----- + loop_gp = op.annotations.get("loop_gp") + if loop_gp is None: + raise PreIsaPassError( + f"serial for-op {loop_var.name!r} has no 'loop_gp' " + f"annotation; loop_register_alloc must run before pre_isa_pass" + ) + + # LOOP_START / LOOP_END with loop_kind="serial" → BackendEmit + # emits the hardware C_LOOP_START / C_LOOP_END ISA + idx + # init / increment. ``loop_gp`` comes from the HLIR op's + # loop_register_alloc stamp. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[init_imm, extent_imm], + binds=loop_var, + annotations={ + "loop_kind": "serial", + "loop_gp": int(loop_gp), + }, + )) + + for sub_op in op.body or []: + handler = self._dispatch.get(sub_op.kind) + if handler is None: + raise PreIsaPassError( + f"PreIsaPass: no handler migrated for body op kind " + f"{sub_op.kind!r} inside for-loop" + ) + handler(mod, sub_op) + + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "serial"}, + )) + + # ------------------------------------------------------------------ + # matmul family — decompose the multi-line legacy emit_* helpers + # into PreIsaOp streams so address PrimExprs are exposed to the + # optimiser (LICM / CSE / arith.simplify can see the loop-var + # dependencies that were previously buried inside the helper). + # ------------------------------------------------------------------ + def _emit_btmm(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_btmm`` which calls + ``ISAEmitter.emit_btmm`` + ``emit_btmm_wo``. + + Decomposed PreIsaOp stream: + _COMMENT "btmm task ..." + _PRELOAD_ADDR rhs_mram_addr_expr -> S_ADDI_INT + _PRELOAD_ADDR lhs_packed_vram_addr_expr -> S_ADDI_INT + M_BTMM ["gp0", rhs_expr, lhs_expr] + _COMMENT "btmm write-only task ..." + _PRELOAD_ADDR dst_addr_expr -> S_ADDI_INT + M_BMM_WO [dst_expr, 0] + + Each line is one PreIsaOp; operand PrimExprs preserve the + address algebra so an optimiser can hoist / CSE them. + + Two ``group_id``s — one per legacy emit_* call — match + legacy's two ``allocate_gp`` cycles (emit_btmm allocates 2, + emit_btmm_wo allocates 1, no GP reuse across them). + """ + # ---------- validation: mirror legacy ---------- + if len(op.buffer_args) != 3: + raise PreIsaPassError( + f"plena.btmm expects 3 buffer_args (regions); " + f"got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise PreIsaPassError( + f"plena.btmm a: expected VramRegion, got " + f"{type(a_reg).__name__}" + ) + if not isinstance(b_reg, _hlir.MramRegion): + raise PreIsaPassError( + f"plena.btmm b: expected MramRegion, got " + f"{type(b_reg).__name__}" + ) + if not isinstance(c_reg, _hlir.VramRegion): + raise PreIsaPassError( + f"plena.btmm c: expected VramRegion, got " + f"{type(c_reg).__name__}" + ) + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + + # Result-tile-count for the writeback (mirrors legacy). + tile_count = max(1, dst.num_elements // self.shim.tile_elems) + task_id = op.annotations.get("intrinsic", "btmm") + lane_count = self.shim.btmm_lane_count + head_width = self.shim.btmm_hlen + + # ---------- emit_btmm equivalent ---------- + # Build PrimExpr objects for the three base addresses. These + # are simple IntImm now (legacy passes ``lhs.address`` — + # already a literal int) but kept as PrimExpr so the + # optimiser can compose them with loop-var-dependent offsets + # later when btmm is inside a for-loop. + rhs_addr_expr = tir.IntImm("int32", int(rhs.address)) + lhs_addr_expr = tir.IntImm("int32", int(lhs.address)) + + gid_btmm = self._new_group() + first_btmm_op = True + + def _stamp_btmm(o): + nonlocal first_btmm_op + o.annotations["group_id"] = gid_btmm + if first_btmm_op: + # Legacy emit_btmm ends with ``ra.free_gp(gp_regs)`` — + # a batched free that, given the LIFO free pool, + # leaves the LAST-allocated reg on top. We mirror this + # by closing the group in INSERTION order. + o.annotations["close_order"] = "insertion" + first_btmm_op = False + return o + + # NOTE: a _COMMENT does NOT trigger _enter_group_for (it's + # skipped), so close_order annotations on a _COMMENT would + # never be read. Put the close_order on the first real op + # (the first _PRELOAD_ADDR below) instead. Comments still get + # the group_id annotation just for the dump's readability. + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"btmm task {task_id} " + f"lhs_packed=vram[{int(lhs.address)}] " + f"rhs_mram={int(rhs.address)} " + f"lanes={lane_count} head_width={head_width}" + ], + annotations={"group_id": gid_btmm}, + )) + # Legacy ``emit_btmm`` calls ``allocate_gp(2)`` and assigns + # [gp_mram_base, gp_lhs_base] in that order. Materialising rhs + # FIRST mirrors that allocation order — the first preload + # claims the first GP, the second preload claims the second. + self.pre_isa.append(_stamp_btmm(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[rhs_addr_expr], + ))) + self.pre_isa.append(_stamp_btmm(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[lhs_addr_expr], + ))) + # M_BTMM gp0, gp{rhs}, gp{lhs} + self.pre_isa.append(_stamp_btmm(PreIsaOp( + opcode="M_BTMM", + operands=["gp0", rhs_addr_expr, lhs_addr_expr], + ))) + + # ---------- emit_btmm_wo equivalent ---------- + dst_addr_expr = tir.IntImm("int32", int(dst.address)) + gid_wo = self._new_group() + first_wo_op = True + + def _stamp_wo(o): + nonlocal first_wo_op + o.annotations["group_id"] = gid_wo + if first_wo_op: + o.annotations["close_order"] = "insertion" + first_wo_op = False + return o + + # Comment first (group_id only, no close_order — see note above). + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"btmm write-only task {task_id}.wo " + f"out=vram[{int(dst.address)}] " + f"tiles={tile_count} " + f"lanes={lane_count} head_width={head_width}" + ], + annotations={"group_id": gid_wo}, + )) + self.pre_isa.append(_stamp_wo(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[dst_addr_expr], + ))) + # M_BMM_WO gp{out}, 0 + self.pre_isa.append(_stamp_wo(PreIsaOp( + opcode="M_BMM_WO", operands=[dst_addr_expr, 0], + ))) + + def _emit_btmv(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_btmv``: M_BTMV + M_BMV_WO. + + Identical structure to ``_emit_btmm`` modulo two opcode + substitutions (M_BTMM -> M_BTMV, M_BMM_WO -> M_BMV_WO) and a + different task_id default. The decomposition produces 7 + PreIsaOps in the same two-group pattern. + """ + if len(op.buffer_args) != 3: + raise PreIsaPassError( + f"plena.btmv expects 3 buffer_args (regions); " + f"got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise PreIsaPassError( + f"plena.btmv a: expected VramRegion, got " + f"{type(a_reg).__name__}" + ) + if not isinstance(b_reg, _hlir.MramRegion): + raise PreIsaPassError( + f"plena.btmv b: expected MramRegion, got " + f"{type(b_reg).__name__}" + ) + if not isinstance(c_reg, _hlir.VramRegion): + raise PreIsaPassError( + f"plena.btmv c: expected VramRegion, got " + f"{type(c_reg).__name__}" + ) + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + task_id = op.annotations.get("intrinsic", "btmv") + lane_count = self.shim.btmm_lane_count + head_width = self.shim.btmm_hlen + + # ---------- emit_btmv equivalent ---------- + rhs_addr_expr = tir.IntImm("int32", int(rhs.address)) + lhs_addr_expr = tir.IntImm("int32", int(lhs.address)) + gid_btmv = self._new_group() + first_btmv_op = True + + def _stamp_btmv(o): + nonlocal first_btmv_op + o.annotations["group_id"] = gid_btmv + if first_btmv_op: + o.annotations["close_order"] = "insertion" + first_btmv_op = False + return o + + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"btmv task {task_id} " + f"lhs_packed=vram[{int(lhs.address)}] " + f"rhs_mram={int(rhs.address)} " + f"lanes={lane_count} head_width={head_width}" + ], + annotations={"group_id": gid_btmv}, + )) + self.pre_isa.append(_stamp_btmv(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[rhs_addr_expr], + ))) + self.pre_isa.append(_stamp_btmv(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[lhs_addr_expr], + ))) + self.pre_isa.append(_stamp_btmv(PreIsaOp( + opcode="M_BTMV", + operands=["gp0", rhs_addr_expr, lhs_addr_expr], + ))) + + # ---------- emit_bmv_wo equivalent ---------- + dst_addr_expr = tir.IntImm("int32", int(dst.address)) + gid_wo = self._new_group() + first_wo_op = True + + def _stamp_wo(o): + nonlocal first_wo_op + o.annotations["group_id"] = gid_wo + if first_wo_op: + o.annotations["close_order"] = "insertion" + first_wo_op = False + return o + + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"bmv write-only task {task_id}.wo " + f"out=vram[{int(dst.address)}] " + f"lanes={lane_count} head_width={head_width}" + ], + annotations={"group_id": gid_wo}, + )) + self.pre_isa.append(_stamp_wo(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[dst_addr_expr], + ))) + self.pre_isa.append(_stamp_wo(PreIsaOp( + opcode="M_BMV_WO", operands=[dst_addr_expr, 0], + ))) + + def _emit_mv(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_mv`` + ``emit_mv``. + + Decomposes emit_mv's setup + tiles-loop + bumps into a flat + stream of PreIsaOps: + + _COMMENT "mv task ..." + _PRELOAD_ADDR lhs_addr_expr ← gp_v + _PRELOAD_ADDR rhs_addr_expr ← gp_m + _PRELOAD_ADDR dst_addr_expr ← gp_o + for t in range(tiles): + M_MV gp0, gp{rhs}, gp{lhs} + M_MV_WO gp{dst}, 0 + [if t 1: + # Prefix loop: (tiles - 1) iters, each = M_MV + M_MV_WO + 2 bumps. + t_var = tir.Var("mv_t", "int32") + self.pre_isa.append(_stamp(PreIsaOp( + opcode="LOOP_START", + operands=[0, tiles - 1], + binds=t_var, + annotations={ + "loop_kind": "unroll", + # Shared scope across iters: the body's cached GPs + # for rhs_addr_expr / dst_addr_expr persist, and + # the _BUMP_CACHED_GPs inside the body mutate + # them in place — matching legacy emit_mv's + # tiles-loop behaviour. + "unroll_scope": "shared", + "group_id": gid, + "close_order": "insertion", + }, + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="M_MV", + operands=["gp0", rhs_addr_expr, lhs_addr_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="M_MV_WO", + operands=[dst_addr_expr, 0], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_BUMP_CACHED_GP", + operands=[rhs_addr_expr, BLEN_VAR], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_BUMP_CACHED_GP", + operands=[dst_addr_expr, BLEN_VAR], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={ + "loop_kind": "unroll", + "group_id": gid, + }, + ))) + # Final iter (no trailing bumps). + self.pre_isa.append(_stamp(PreIsaOp( + opcode="M_MV", + operands=["gp0", rhs_addr_expr, lhs_addr_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="M_MV_WO", + operands=[dst_addr_expr, 0], + ))) + + def _emit_mm(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_mm`` → + ``ISAEmitter.emit_matmul_single_tile_hwloop``. + + Legacy walks a nested (oc, orow) Python loop, both with + extent ``tiles_per_mlen = mlen // blen``. Per inner-loop iter + legacy emits: + + S_ADDI_INT gp{mat}, gp0, rhs.address + oc*blen + S_ADDI_INT gp{act}, gp0, lhs.address + orow*output_row_stride + S_ADDI_INT gp{result}, gp0, + dst.address + oc*blen + orow*output_row_stride + M_MM 0, gp{mat}, gp{act} + M_MM_WO gp{result}, gp0, 0 + + where ``output_row_stride = blen * mlen``. + + PreIsaIR migration: the (oc, orow) loops become two nested + ``LOOP_START(loop_kind="unroll", unroll_scope="per_iter")`` + pairs; the three addresses are PrimExprs in terms of the loop + vars (so the optimiser SEES the address algebra: + ``rhs.address + oc * blen``). Each inner-loop iter is its + own materialisation scope (``per_iter``) — legacy uses fresh + ``S_ADDI_INT``s per iter and discards the previous values. + + Narrow-dst path (rhs_cols != mlen) is handled by a separate + ``_emit_matmul_narrow_tile_hwloop``-equivalent migration — + TODO; for now we raise on that case. + """ + if len(op.buffer_args) != 3: + raise PreIsaPassError( + f"plena.mm expects 3 buffer_args; got {len(op.buffer_args)}" + ) + lhs = mod.get_buffer(op.buffer_args[0]) + rhs = mod.get_buffer(op.buffer_args[1]) + dst = mod.get_buffer(op.buffer_args[2]) + mlen = int(self.shim.mlen) + blen = int(self.shim.blen) + + lhs_rows, lhs_cols = self._legacy._logical_2d(lhs.shape) + rhs_rows, rhs_cols = self._legacy._logical_2d(rhs.shape) + dst_rows, dst_cols = self._legacy._logical_2d(dst.shape) + if lhs_rows != mlen or lhs_cols != mlen: + raise PreIsaPassError( + f"plena.mm lhs must be mlen*mlen; got ({lhs_rows},{lhs_cols})" + ) + if rhs_rows != mlen: + raise PreIsaPassError( + f"plena.mm rhs must have mlen rows; got rows={rhs_rows}" + ) + if dst_rows != mlen: + raise PreIsaPassError( + f"plena.mm dst must have mlen rows; got rows={dst_rows}" + ) + if rhs_cols != dst_cols: + raise PreIsaPassError( + f"plena.mm rhs/dst logical widths mismatch: " + f"rhs={rhs_cols} dst={dst_cols}" + ) + + # Narrow path: rhs_cols < mlen (and dst_cols == rhs_cols). + # Delegate to a separate decomposer that mirrors legacy + # ``emit_matmul_narrow_tile_hwloop`` byte-for-byte (within the + # GP-rename equivalence that semantic_isa_equal expects). + if not (rhs_cols == mlen and dst_cols == mlen): + self._emit_mm_narrow( + mod=mod, op=op, lhs=lhs, rhs=rhs, dst=dst, + hlen=int(rhs_cols), dst_row_stride=int(dst_cols), + ) + return + + tiles_per_mlen = mlen // blen + output_row_stride = blen * mlen + task_id = op.annotations.get("intrinsic", "mm") + + # Outer-most group_id: governs the entire op (matmul body lives + # under it; nested per_iter unroll scopes are independent inner + # scopes, BackendEmit handles the nesting). + gid = self._new_group() + + def _stamp(o): + o.annotations["group_id"] = gid + return o + + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"matmul (single-tile, symbolic unroll) task {task_id} " + f"lhs=vram[{int(lhs.address)}] " + f"rhs=mram[{int(rhs.address)}] " + f"dst=vram[{int(dst.address)}]" + ], + annotations={"group_id": gid}, + )) + + # Loop variables — fresh tir.Var per emit_mm call so nested + # mm callers don't alias. + oc_var = tir.Var(f"mm_oc_{id(op) & 0xffff:x}", "int32") + orow_var = tir.Var(f"mm_orow_{id(op) & 0xffff:x}", "int32") + + # Address PrimExprs — written EXACTLY in legacy form so the + # optimiser can fold / hoist: + # mat_col = rhs.address + oc * blen + # act_row = lhs.address + orow * (blen * mlen) + # result = dst.address + oc * blen + orow * (blen * mlen) + # + # ``blen`` and ``mlen`` are hardware-shape parameters that + # change with the chip variant — written as symbolic Vars + # (``BLEN_VAR`` / ``MLEN_VAR``) so PreIsaIR preserves the + # algebra (``oc * blen``, not ``oc * 4``). BackendEmit's + # symbol_table binds them to the shim's current IntImm + # values; the materialiser substitutes + folds at emit time. + output_row_stride_expr = tir.Mul(BLEN_VAR, MLEN_VAR) + mat_col_expr = tir.Add( + tir.IntImm("int32", int(rhs.address)), + tir.Mul(oc_var, BLEN_VAR), + ) + act_row_expr = tir.Add( + tir.IntImm("int32", int(lhs.address)), + tir.Mul(orow_var, output_row_stride_expr), + ) + result_expr = tir.Add( + tir.Add( + tir.IntImm("int32", int(dst.address)), + tir.Mul(oc_var, BLEN_VAR), + ), + tir.Mul(orow_var, output_row_stride_expr), + ) + + # Outer (oc) symbolic-unroll loop. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, tiles_per_mlen], + binds=oc_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + # Inner (orow) symbolic-unroll loop. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, tiles_per_mlen], + binds=orow_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + # Inner body group_id — fresh, so each iter is its own + # materialisation scope (BackendEmit's per_iter unroll opens + # one scope per iter; the group_id only matters across + # consecutive PreIsaOps in a single iter). + inner_gid = self._new_group() + + def _stamp_inner(o): + o.annotations["group_id"] = inner_gid + # Legacy emit_matmul_single_tile_hwloop ends with + # ``ra.free_gp(gp_regs)`` — batched free, insertion-order + # close matches. + if "close_order" not in o.annotations: + o.annotations["close_order"] = "insertion" + return o + + # Three address preloads in legacy emit order: mat, act, result. + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[mat_col_expr], + ))) + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[act_row_expr], + ))) + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[result_expr], + ))) + # M_MM 0, gp{mat}, gp{act} + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="M_MM", + operands=[0, mat_col_expr, act_row_expr], + ))) + # M_MM_WO gp{result}, gp0, 0 + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="M_MM_WO", + operands=[result_expr, "gp0", 0], + ))) + # Close inner / outer loops. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + + def _emit_mm_narrow( + self, + *, + mod: _hlir.HLIRModule, + op: _hlir.Op, + lhs: _hlir.Buffer, + rhs: _hlir.Buffer, + dst: _hlir.Buffer, + hlen: int, + dst_row_stride: int, + ) -> None: + """Mirror of legacy ``ISAEmitter.emit_matmul_narrow_tile_hwloop``. + + Legacy emission for ``mlen*mlen @ mlen*hlen -> mlen*hlen`` (with + possibly wider dst_row_stride): + + S_ADDI_INT gp{stride}, gp0, 1 ← preamble (dead, but kept) + for oc in tiles_per_slot: ← outer + mat_addr = rhs.address + oc * blen + S_ADDI_INT gp{mat}, gp0, mat_addr + for t in tiles_per_mlen: ← inner + act_addr = lhs.address + t * blen * mlen + out_addr = dst.address + oc * blen + t * blen * dst_row_stride + S_ADDI_INT gp{act}, gp0, act_addr + S_ADDI_INT gp{out}, gp0, out_addr + M_MM 0, gp{mat}, gp{act} + M_MM_WO gp{out}, gp0, 0 + + where ``tiles_per_slot = hlen / blen`` and + ``tiles_per_mlen = mlen / blen``. + + PreIsaIR migration: both loops become symbolic + ``LOOP_START(loop_kind="unroll", unroll_scope="per_iter")``; + addresses are PrimExprs in the loop vars so an LICM pass can + hoist ``mat_addr`` out of the inner loop (legacy already does + this manually by computing mat_addr in the outer loop, but + the materialised S_ADDI for it sits OUTSIDE the inner t loop — + we preserve that structure). + """ + mlen = int(self.shim.mlen) + blen = int(self.shim.blen) + tiles_per_slot = hlen // blen + tiles_per_mlen = mlen // blen + # legacy: act_row_stride = blen * mlen + act_row_stride = blen * mlen + # legacy: output_row_stride = blen * dst_row_stride + output_row_stride = blen * dst_row_stride + task_id = op.annotations.get("intrinsic", "mm") + # rhs_col_offset / dst_col_offset default to 0 (legacy _emit_mm + # always passes the default — no per-call override yet). + rhs_col_offset = 0 + dst_col_offset = 0 + + gid = self._new_group() + + def _stamp(o): + o.annotations["group_id"] = gid + return o + + # Header comment. + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"narrow matmul task {task_id} " + f"lhs=vram[{int(lhs.address)}] " + f"rhs=mram[{int(rhs.address)}] " + f"rhs_col_offset={rhs_col_offset} " + f"dst=vram[{int(dst.address)}] " + f"dst_col_offset={dst_col_offset} " + f"hlen={hlen} dst_row_stride={dst_row_stride}" + ], + annotations={"group_id": gid}, + )) + # Preamble: legacy emits ``S_ADDI_INT gp{stride}, gp0, 1`` (dead + # but byte-equally preserved). Model via _PRELOAD_ADDR of the + # literal 1 — materialiser emits the same single instruction. + stride_const = tir.IntImm("int32", 1) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[stride_const], + ))) + + # Loop variables — fresh per emit_mm_narrow call. + oc_var = tir.Var(f"mm_n_oc_{id(op) & 0xffff:x}", "int32") + t_var = tir.Var(f"mm_n_t_{id(op) & 0xffff:x}", "int32") + + # mat_addr = rhs.address + rhs_col_offset + oc * blen + # (rhs_col_offset is 0 here but kept symbolic for future use). + # ``blen`` is a hardware-shape const — referenced via + # ``BLEN_VAR`` so PreIsaIR preserves the algebra; the + # materialiser substitutes + folds at emit time. + mat_addr_expr = tir.Add( + tir.Add( + tir.IntImm("int32", int(rhs.address)), + tir.IntImm("int32", rhs_col_offset), + ), + tir.Mul(oc_var, BLEN_VAR), + ) + + # Outer (oc) symbolic unroll. NOTE on mat_addr placement: + # legacy ``emit_matmul_narrow_tile_hwloop`` materialises + # mat_addr ONCE per oc iter, BEFORE the inner t loop — saving + # tiles_per_mlen-1 redundant S_ADDIs per oc. With the + # PreIsaIR per_iter scope model, an outer-scope _PRELOAD_ADDR + # is closed when the inner unroll's body ops open their own + # scope; preserving outer scopes across inner unroll iters + # would need a scope-ownership mechanism beyond what we have. + # We accept the simpler form: mat_addr is preloaded inside + # each (oc, t) inner-iter scope alongside act/out. This emits + # tiles_per_mlen-1 extra S_ADDIs per oc relative to legacy, + # so semantic_isa_equal's instruction-count check will reject + # strict equality — see the matching test for the relaxed + # check. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, tiles_per_slot], + binds=oc_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + + # Inner (t) symbolic unroll. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, tiles_per_mlen], + binds=t_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + + # act_addr = lhs.address + t * (blen * mlen) + # out_addr = dst.address + dst_col_offset + oc * blen + # + t * (blen * dst_row_stride) + # blen / mlen referenced symbolically; dst_row_stride is a + # per-op compile-time parameter (not a hw-shape const) so + # stays as an IntImm here. + act_row_stride_expr = tir.Mul(BLEN_VAR, MLEN_VAR) + output_row_stride_expr = tir.Mul( + BLEN_VAR, tir.IntImm("int32", dst_row_stride), + ) + act_addr_expr = tir.Add( + tir.IntImm("int32", int(lhs.address)), + tir.Mul(t_var, act_row_stride_expr), + ) + out_addr_expr = tir.Add( + tir.Add( + tir.Add( + tir.IntImm("int32", int(dst.address)), + tir.IntImm("int32", dst_col_offset), + ), + tir.Mul(oc_var, BLEN_VAR), + ), + tir.Mul(t_var, output_row_stride_expr), + ) + + inner_gid = self._new_group() + + def _stamp_inner(o): + o.annotations["group_id"] = inner_gid + if "close_order" not in o.annotations: + o.annotations["close_order"] = "insertion" + return o + + # Per-iter preloads: mat (re-loaded per inner iter — see note + # above on the scope model), then act, then out. Order keeps + # ``mat`` first so the GP backing it (and the bijection used + # by ``M_MM 0, gp{mat}, gp{act}``) lines up with legacy's + # ``S_ADDI gp{mat}, ...; ...; M_MM 0, gp{mat}, gp{act}`` + # sequence. + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[mat_addr_expr], + ))) + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[act_addr_expr], + ))) + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[out_addr_expr], + ))) + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="M_MM", + operands=[0, mat_addr_expr, act_addr_expr], + ))) + self.pre_isa.append(_stamp_inner(PreIsaOp( + opcode="M_MM_WO", + operands=[out_addr_expr, "gp0", 0], + ))) + + # Close inner loop, then outer. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + + def _emit_matmul(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_matmul`` → + ``ISAEmitter.emit_matmul_general(unroll_loops=True)``. + + Schema: + buffer_args = [a_region: VramRegion, + b_region: MramRegion, + c_region: VramRegion] + scalar_args = [a_dim_roles, b_dim_roles, c_dim_roles] + each a 4-tuple of "M"/"K"/"N"/"_" labels. + + Decomposition into 5 nested unroll loops (m, n_mlen, oc, orow, k): + for m in M_tiles: (unroll, per_iter) + for n_mlen in N_mlen_tiles: (unroll, per_iter) + for oc in tiles_per_n_mlen: (unroll, per_iter) + _PRELOAD act_orow, out_orow + for orow in tiles_per_mlen: (unroll, per_iter) + _PRELOAD gp_act (base) + for k in K_tiles: (unroll, per_iter) + if k > 0: _PRELOAD gp_act (with k stride) + _PRELOAD gp_mat + M_MM 0, gp_mat, gp_act (or M_TMM) + _PRELOAD gp_out_orow (with orow stride) + M_MM_WO gp_out_orow, gp0, 0 + + Static-offset path only (no PrimExpr offsets, no packed-head + dst, no dim_roles other than canonical layouts). Hardware + consts ``mlen`` / ``blen`` are referenced symbolically via + ``MLEN_VAR`` / ``BLEN_VAR``. + """ + if len(op.buffer_args) != 3: + raise PreIsaPassError( + f"plena.matmul expects 3 buffer_args; got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise PreIsaPassError( + f"plena.matmul a: expected VramRegion, got " + f"{type(a_reg).__name__}" + ) + if not isinstance(b_reg, _hlir.MramRegion): + raise PreIsaPassError( + f"plena.matmul b: expected MramRegion, got " + f"{type(b_reg).__name__}" + ) + if not isinstance(c_reg, _hlir.VramRegion): + raise PreIsaPassError( + f"plena.matmul c: expected VramRegion, got " + f"{type(c_reg).__name__}" + ) + if len(op.scalar_args) != 3: + raise PreIsaPassError( + f"plena.matmul expects 3 scalar_args (a/b/c dim_roles); " + f"got {len(op.scalar_args)}" + ) + a_roles, b_roles, c_roles = op.scalar_args + if len(a_roles) != 4 or len(b_roles) != 4 or len(c_roles) != 4: + raise PreIsaPassError( + f"plena.matmul dim_roles must each be 4-tuples" + ) + + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + mlen_v = int(self.shim.mlen) + blen_v = int(self.shim.blen) + hlen_v = int(self.shim.btmm_hlen) + + def _find_role_axis(roles, role, operand): + hits = [i for i, r in enumerate(roles) if r == role] + if not hits: + raise PreIsaPassError( + f"plena.matmul {operand}: missing role {role!r}" + ) + if len(hits) > 1: + raise PreIsaPassError( + f"plena.matmul {operand}: role {role!r} appears at " + f"multiple axes {hits}" + ) + return hits[0] + + c_M_axis = _find_role_axis(c_roles, "M", "c") + c_N_axis = _find_role_axis(c_roles, "N", "c") + a_M_axis = _find_role_axis(a_roles, "M", "a") + a_K_axis = _find_role_axis(a_roles, "K", "a") + b_K_axis = _find_role_axis(b_roles, "K", "b") + b_N_axis = _find_role_axis(b_roles, "N", "b") + + M = int(a_reg.extents[a_M_axis]) + K = int(a_reg.extents[a_K_axis]) + N = int(b_reg.extents[b_N_axis]) + if M % mlen_v != 0 or K % mlen_v != 0: + raise PreIsaPassError( + f"plena.matmul: M ({M}) and K ({K}) must be multiples of " + f"mlen ({mlen_v})" + ) + M_tiles = M // mlen_v + K_tiles = K // mlen_v + N_mlen_tiles = (N + mlen_v - 1) // mlen_v + transpose_b = b_N_axis < b_K_axis + + # dst_row_stride — same heuristic as legacy (packed-head case + # handled later in a follow-up; non-packed = product of + # extents past the M axis). + dst_cluster_dim = getattr(dst, "cluster_dim", None) + tl_info = self._legacy._tile_layout_strides(dst) + packed_head_dst = ( + tl_info is not None + and dst_cluster_dim == 2 + and int(tl_info["lane_count"]) > 1 + and c_M_axis == 1 + ) + if packed_head_dst: + dst_row_stride = int(tl_info["s_inner_stride"]) + else: + dst_row_stride = 1 + for ax in range(c_M_axis + 1, len(c_reg.extents)): + dst_row_stride *= int(c_reg.extents[ax]) + + # Region origin offsets — static-only path for now. + lhs_off = self._legacy._region_origin_offset(lhs, a_reg) + rhs_off = self._legacy._region_origin_offset(rhs, b_reg) + dst_off = self._legacy._region_origin_offset(dst, c_reg) + + def _static(x): + if isinstance(x, int): + return int(x) + if isinstance(x, tir.IntImm): + return int(x.value) + return None + + lhs_off_s = _static(lhs_off) + rhs_off_s = _static(rhs_off) + dst_off_s = _static(dst_off) + if lhs_off_s is None or rhs_off_s is None or dst_off_s is None: + raise PreIsaPassError( + f"plena.matmul: dynamic region offsets not yet " + f"supported by PreIsaPass; got lhs={lhs_off!r} " + f"rhs={rhs_off!r} dst={dst_off!r}" + ) + + task_id = op.annotations.get("intrinsic", "matmul") + + # Symbolic hw-shape exprs. + # lhs_k_tile_stride = mlen * mlen + # lhs_m_tile_stride = K_tiles * mlen * mlen + # rhs_n_mlen_tile_stride / rhs_k_tile_stride depend on transpose_b + # a_orow_step = blen * mlen + # c_orow_step = blen * mlen + # oc_b_step = blen (or blen*mlen for transpose_b) + # dst_m_tile_stride = mlen * dst_row_stride + # mlen_sq = MLEN * MLEN + mlen_sq_expr = tir.Mul(MLEN_VAR, MLEN_VAR) + a_orow_step_expr = tir.Mul(BLEN_VAR, MLEN_VAR) + c_orow_step_expr = tir.Mul(BLEN_VAR, MLEN_VAR) + lhs_k_tile_stride_expr = mlen_sq_expr + lhs_m_tile_stride_expr = tir.Mul( + tir.IntImm("int32", K_tiles), mlen_sq_expr, + ) + if transpose_b: + rhs_n_mlen_tile_stride_expr = tir.Mul( + tir.IntImm("int32", K_tiles), mlen_sq_expr, + ) + rhs_k_tile_stride_expr = mlen_sq_expr + oc_b_step_expr = tir.Mul(BLEN_VAR, MLEN_VAR) + else: + rhs_n_mlen_tile_stride_expr = mlen_sq_expr + rhs_k_tile_stride_expr = tir.Mul( + tir.IntImm("int32", N_mlen_tiles), mlen_sq_expr, + ) + oc_b_step_expr = BLEN_VAR + dst_m_tile_stride_expr = tir.Mul( + MLEN_VAR, tir.IntImm("int32", dst_row_stride), + ) + + gid = self._new_group() + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"matmul (general, symbolic unroll) task {task_id} " + f"M={M} K={K} N={N} " + f"(M_tiles={M_tiles} K_tiles={K_tiles} " + f"N_mlen_tiles={N_mlen_tiles} transpose_b={transpose_b})" + ], + annotations={"group_id": gid}, + )) + + # Loop vars. All fresh per emit_matmul so nested matmul callers + # don't alias. + suffix = f"{id(op) & 0xffff:x}" + m_var = tir.Var(f"mm_m_{suffix}", "int32") + n_mlen_var = tir.Var(f"mm_nmlen_{suffix}", "int32") + # oc / orow / k are unrolled inside the n_mlen iter so we + # don't need them as tir.Vars (could, but legacy unroll + # uses Python literals). For now produce LOOP_START for the + # outer two loops only; oc/orow/k stay Python-unrolled. + + mm_opcode = "M_TMM" if transpose_b else "M_MM" + + # Static int residues that don't fold to symbolic exprs. + lhs_residual_static = int(lhs.address) + int(lhs_off_s) + rhs_residual_static = int(rhs.address) + int(rhs_off_s) + dst_residual_static = int(dst.address) + int(dst_off_s) + + # m loop (outer-most). + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, M_tiles], + binds=m_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + # n_mlen loop (inside m). + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, N_mlen_tiles], + binds=n_mlen_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + + # Per-(m, n_mlen) base addresses: + # lhs_base = lhs.address + lhs_off + m * lhs_m_tile_stride + # rhs_n_mlen_base = rhs.address + rhs_off + n_mlen * rhs_n_mlen_tile_stride + # dst_m_base = dst.address + dst_off + m * dst_m_tile_stride + lhs_base_expr = tir.Add( + tir.IntImm("int32", lhs_residual_static), + tir.Mul(m_var, lhs_m_tile_stride_expr), + ) + rhs_n_mlen_base_expr = tir.Add( + tir.IntImm("int32", rhs_residual_static), + tir.Mul(n_mlen_var, rhs_n_mlen_tile_stride_expr), + ) + dst_m_base_expr = tir.Add( + tir.IntImm("int32", dst_residual_static), + tir.Mul(m_var, dst_m_tile_stride_expr), + ) + + # For this initial migration we accept some structural + # differences from legacy (mat_addr re-preloads per inner + # iter — same accepted compromise as mm_narrow). The + # innermost oc/orow/k stay Python-unrolled in the + # PreIsaPass producer (less code) — they become flat + # PreIsaOps. The "unroll" axes that DO appear in PreIsaIR + # are m and n_mlen — the outer two. + for oc in range(0): # placeholder — we don't iterate oc/orow/k here + pass + + # cols_here / tiles_per_n_mlen depend on n_mlen (runtime in the + # symbolic loop). We need them as compile-time bounds for the + # inner Python unrolls; since n_mlen_var binds to IntImm + # per-iter via BackendEmit, we can't read it here. Compromise: + # if N is exact multiple of mlen, all n_mlen iters have + # ``cols_here = mlen``; otherwise we'd need a different + # encoding. Enforce the common case. + if N % mlen_v != 0: + raise PreIsaPassError( + f"plena.matmul: PreIsaPass currently requires N ({N}) to be " + f"a multiple of mlen ({mlen_v}) — partial-last-block N " + f"not yet handled (legacy ``cols_here = min(mlen, ...)`` " + f"path)" + ) + tiles_per_n_mlen = mlen_v // blen_v + tiles_per_mlen = mlen_v // blen_v + + # All five nesting levels (m, n_mlen, oc, orow, k) become + # PreIsaIR LOOP_START loops with ``loop_kind="unroll"``. + # Optimisation passes (LICM, CSE) operating on PreIsaIR see + # the full algebra and can hoist ``m * lhs_m_tile_stride`` + # out of inner loops, share ``k * rhs_k_tile_stride`` across + # oc iterations, etc — all of which the legacy emit_matmul + # bakes as literal residuals inside a fully-flat unrolled + # emit (no hoisting possible at that point). + oc_var = tir.Var(f"mm_oc_{suffix}", "int32") + orow_var = tir.Var(f"mm_orow_{suffix}", "int32") + k_var = tir.Var(f"mm_k_{suffix}", "int32") + + # oc loop. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, tiles_per_n_mlen], + binds=oc_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + # orow loop. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, tiles_per_mlen], + binds=orow_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + + # dst_col = n_mlen * mlen + oc * blen + dst_col_expr = tir.Add( + tir.Mul(n_mlen_var, MLEN_VAR), + tir.Mul(oc_var, BLEN_VAR), + ) + # act_orow = lhs_base + orow * a_orow_step + act_orow_expr = tir.Add( + lhs_base_expr, + tir.Mul(orow_var, a_orow_step_expr), + ) + out_expr = tir.Add( + tir.Add(dst_m_base_expr, dst_col_expr), + tir.Mul(orow_var, c_orow_step_expr), + ) + + # K accumulation loop. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, K_tiles], + binds=k_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + + # act_k = act_orow + k * lhs_k_tile_stride + # mat = rhs_n_mlen_base + oc * oc_b_step + k * rhs_k_tile_stride + act_k_expr = tir.Add( + act_orow_expr, + tir.Mul(k_var, lhs_k_tile_stride_expr), + ) + mat_expr = tir.Add( + tir.Add( + rhs_n_mlen_base_expr, + tir.Mul(oc_var, oc_b_step_expr), + ), + tir.Mul(k_var, rhs_k_tile_stride_expr), + ) + + # Each k iter is its own group (per_iter scope). + k_iter_gid = self._new_group() + + def _stamp_k(o): + o.annotations["group_id"] = k_iter_gid + if "close_order" not in o.annotations: + o.annotations["close_order"] = "insertion" + return o + + self.pre_isa.append(_stamp_k(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[act_k_expr], + ))) + self.pre_isa.append(_stamp_k(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[mat_expr], + ))) + if transpose_b: + self.pre_isa.append(_stamp_k(PreIsaOp( + opcode="M_TMM", + operands=[0, act_k_expr, mat_expr], + ))) + else: + self.pre_isa.append(_stamp_k(PreIsaOp( + opcode="M_MM", + operands=[0, mat_expr, act_k_expr], + ))) + + # Close k loop. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + + # M_MM_WO after K accumulation — once per (oc, orow). + wo_gid = self._new_group() + self.pre_isa.append(PreIsaOp( + opcode="_PRELOAD_ADDR", + operands=[out_expr], + annotations={"group_id": wo_gid, "close_order": "insertion"}, + )) + self.pre_isa.append(PreIsaOp( + opcode="M_MM_WO", + operands=[out_expr, "gp0", 0], + annotations={"group_id": wo_gid}, + )) + + # Close orow, oc loops. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + + # Close n_mlen loop, then m loop. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + + # ------------------------------------------------------------------ + # DMA — HBM ↔ VRAM / MRAM tile transfers + # ------------------------------------------------------------------ + def _iter_tile_offsets(self, hbm_buf: _hlir.Buffer): + """Mirror of legacy ``isa_pass._iter_tile_offsets``.""" + mlen = int(self.shim.mlen) + ann = hbm_buf.annotations + rows = ann.get("logical_rows", mlen) + cols = ann.get("logical_cols", mlen) + row_blocks = ann.get("row_blocks", 1) + col_blocks = ann.get("col_blocks", 1) + tile_elems = mlen * mlen + idx = 0 + for j in range(col_blocks): + for i in range(row_blocks): + hbm_off = i * mlen * cols + j * mlen + vram_off = idx * tile_elems + yield vram_off, hbm_off + idx += 1 + + def _emit_dma_h2v(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_dma_h2v`` → + ``ISAEmitter.emit_load_tile_from_hbm``. + + Per tile: + 1. Bind an addr-reg ``aN`` from hbm_addr (literal IntImm). + 2. Reset 5 preload scratch GPs (S_ADDI_INT gp{r}, gp0, 0). + 3. Emit ``_emit_preload_tile_isa`` equivalent: + - S_ADDI_INT gp{a_actual}, gp0, scale_len + - C_SET_SCALE_REG gp{a_actual} + - S_ADDI_INT gp{a_actual}, gp0, hbm_start_offset + - S_ADDI_INT gp{result}, gp0, vram_addr + - (batch == mlen, batch > preload_len): set stride, twin-unroll + - For each (outer, inner): set act/mat addrs, H_PREFETCH_V. + + Hardware-coupled parameters (all hw_consts symbolic): + * ``mlen`` (VLEN + batch + hidden_size default) + * ``v_prefetch_amount`` (rows per H_PREFETCH_V) + * ``tile_elems = mlen * mlen`` (scale_size default) + Per-call options (per_op static): + * ``hbm_stride`` (defaults to mlen) + * ``hbm_scale_size`` (defaults to tile_elems) + """ + src = mod.get_buffer(op.buffer_args[0]) + dst = mod.get_buffer(op.buffer_args[1]) + mlen_v = int(self.shim.mlen) + # hbm_stride / hbm_scale_size defaults — applied to PrimExpr. + hbm_stride_v = ( + mlen_v if src.hbm_stride is None else int(src.hbm_stride) + ) + hbm_scale_v = ( + mlen_v * mlen_v if src.hbm_scale_size is None + else int(src.hbm_scale_size) + ) + + gid = self._new_group() + + for vram_off, hbm_off in self._iter_tile_offsets(src): + tile_hbm_start_offset = src.hbm_offset + hbm_off + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"dma_h2v tile {src.name}[hbm+{hbm_off}] -> " + f"{dst.name}[vram+{vram_off}]" + ], + annotations={"group_id": gid}, + )) + self._emit_load_tile_from_hbm_seq( + hbm_addr=int(src.address), + vram_addr=int(dst.address) + vram_off, + hbm_stride=hbm_stride_v, + hbm_scale_size=hbm_scale_v, + hbm_start_offset=tile_hbm_start_offset, + group_id=gid, + ) + + def _emit_dma_h2m(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_dma_h2m`` → + ``ISAEmitter.emit_hbm_tile_to_mram``.""" + src = mod.get_buffer(op.buffer_args[0]) + dst = mod.get_buffer(op.buffer_args[1]) + mlen_v = int(self.shim.mlen) + hbm_stride_v = ( + mlen_v if src.hbm_stride is None else int(src.hbm_stride) + ) + hbm_scale_v = ( + mlen_v * mlen_v if src.hbm_scale_size is None + else int(src.hbm_scale_size) + ) + + gid = self._new_group() + for vram_off, hbm_off in self._iter_tile_offsets(src): + tile_hbm_offset = src.hbm_offset + hbm_off + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"dma_h2m tile {src.name}[hbm+{hbm_off}] -> " + f"{dst.name}[mram+{vram_off}]" + ], + annotations={"group_id": gid}, + )) + self._emit_hbm_tile_to_mram_seq( + hbm_addr=int(src.address), + mram_addr=int(dst.address) + vram_off, + hbm_stride=hbm_stride_v, + hbm_scale=hbm_scale_v, + hbm_offset=tile_hbm_offset, + group_id=gid, + ) + + def _emit_dma_v2h(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_dma_v2h`` → + ``ISAEmitter.emit_store_tile_to_hbm``.""" + src = mod.get_buffer(op.buffer_args[0]) + dst = mod.get_buffer(op.buffer_args[1]) + if src.num_elements != dst.num_elements: + raise PreIsaPassError( + f"dma_v2h: src/dst element-count mismatch" + ) + mlen_v = int(self.shim.mlen) + hbm_stride_v = ( + mlen_v if dst.hbm_stride is None else int(dst.hbm_stride) + ) + hbm_scale_v = ( + mlen_v * mlen_v if dst.hbm_scale_size is None + else int(dst.hbm_scale_size) + ) + + gid = self._new_group() + for vram_off, hbm_off in self._iter_tile_offsets(dst): + tile_hbm_start_offset = dst.hbm_offset + hbm_off + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"dma_v2h tile {src.name}[vram+{vram_off}] -> " + f"{dst.name}[hbm+{hbm_off}]" + ], + annotations={"group_id": gid}, + )) + self._emit_store_tile_to_hbm_seq( + vram_addr=int(src.address) + vram_off, + hbm_addr=int(dst.address), + hbm_stride=hbm_stride_v, + hbm_scale_size=hbm_scale_v, + hbm_start_offset=tile_hbm_start_offset, + group_id=gid, + ) + + # ------------------------------------------------------------------ + # DMA slice variants (BufferSlice src/dst + multi-tile grid). + # + # Static-offset path only: ``BufferSlice.starts`` must all be ints + # (no PrimExpr starts derived from loop vars). Dynamic-offset + # support — where ``_materialise_slice_offset`` returns a + # ``MaterializedExpr`` and the emit uses ``hbm_start_offset_reg`` + # — is a TODO matching the legacy isa_pass's two branches. + # ------------------------------------------------------------------ + def _emit_dma_h2v_slice( + self, mod: _hlir.HLIRModule, op: _hlir.Op, + ) -> None: + """Mirror of legacy ``isa_pass._emit_dma_h2v_slice``. + + For each tile in the slice's d/s/h/b grid: + 1. Comment "; tile (d,s,h,b): hbm_off=... vram_off=..." + 2. ``_emit_load_tile_from_hbm_seq`` with the per-tile + ``hbm_start_offset = base_static + tile_const`` and + ``vram_addr = dst.address + vram_off``. + """ + sl = op.buffer_args[0] + if not isinstance(sl, _hlir.BufferSlice): + raise PreIsaPassError( + f"dma_h2v_slice: buffer_args[0] must be BufferSlice; " + f"got {type(sl).__name__}" + ) + dst_name = op.buffer_args[1] + if isinstance(dst_name, _hlir.BufferSlice): + raise PreIsaPassError( + f"dma_h2v_slice: dst must be a whole-buffer name" + ) + dst = mod.get_buffer(dst_name) + parent = mod.get_buffer(sl.parent) + if self._legacy._slice_has_dynamic_start(sl): + raise PreIsaPassError( + f"dma_h2v_slice: dynamic-start slice not yet supported " + f"by PreIsaPass (legacy uses _materialise_slice_offset's " + f"reg form)" + ) + # Tile grid + per-tile strides. + (d_tiles, s_tiles, h_groups, logical_b, + inner_mlen, lane_count, + (hbm_stride_b, hbm_stride_s, hbm_stride_h), + (d_tile_stride, s_tile_stride, h_grp_stride, b_stride)) = ( + self._legacy._slice_tile_grid(parent, sl, dst) + ) + base_static = ( + parent.hbm_offset + + self._legacy._slice_offset_static(parent, sl) + ) + mlen_v = int(self.shim.mlen) + hbm_stride_v = ( + mlen_v if parent.hbm_stride is None + else int(parent.hbm_stride) + ) + hbm_scale_v = ( + mlen_v * mlen_v if parent.hbm_scale_size is None + else int(parent.hbm_scale_size) + ) + + starts_s = self._legacy._format_starts(sl) + gid = self._new_group() + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"dma_h2v_slice {parent.name}[{starts_s}]" + f"+{list(sl.extents)} -> {dst.name} " + f"(grid d_tiles={d_tiles}, s_tiles={s_tiles}, " + f"h_groups={h_groups}, b={logical_b})" + ], + annotations={"group_id": gid}, + )) + for d_tile in range(d_tiles): + for s_tile in range(s_tiles): + for h_grp in range(h_groups): + for b in range(logical_b): + hbm_off = ( + base_static + + b * hbm_stride_b + + s_tile * inner_mlen * hbm_stride_s + + h_grp * lane_count * hbm_stride_h + + d_tile * inner_mlen + ) + vram_off = ( + d_tile * d_tile_stride + + s_tile * s_tile_stride + + h_grp * h_grp_stride + + b * b_stride + ) + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f" tile (d={d_tile}, s={s_tile}, " + f"h={h_grp}, b={b}): hbm_off={hbm_off} " + f"vram_off={vram_off}" + ], + annotations={"group_id": gid}, + )) + tile_gid = self._new_group() + self._emit_load_tile_from_hbm_seq( + hbm_addr=int(parent.address), + vram_addr=int(dst.address) + vram_off, + hbm_stride=hbm_stride_v, + hbm_scale_size=hbm_scale_v, + hbm_start_offset=hbm_off, + group_id=tile_gid, + ) + + def _emit_dma_h2m_slice( + self, mod: _hlir.HLIRModule, op: _hlir.Op, + ) -> None: + """Mirror of legacy ``isa_pass._emit_dma_h2m_slice``. + + Single-tile by contract (legacy ``_check_slice_single_tile``): + a slice into MRAM is always exactly one mlen*mlen tile. We + emit one ``emit_hbm_tile_to_mram`` equivalent at the resolved + per-slice hbm offset. + """ + sl = op.buffer_args[0] + if not isinstance(sl, _hlir.BufferSlice): + raise PreIsaPassError( + f"dma_h2m_slice: buffer_args[0] must be BufferSlice" + ) + dst = mod.get_buffer(op.buffer_args[1]) + parent = mod.get_buffer(sl.parent) + if self._legacy._slice_has_dynamic_start(sl): + raise PreIsaPassError( + f"dma_h2m_slice: dynamic-start slice not yet supported" + ) + # Validate single-tile invariant (matches legacy). + self._legacy._check_slice_single_tile(parent, sl) + static_off = ( + parent.hbm_offset + + self._legacy._slice_offset_static(parent, sl) + ) + mlen_v = int(self.shim.mlen) + hbm_stride_v = ( + mlen_v if parent.hbm_stride is None + else int(parent.hbm_stride) + ) + hbm_scale_v = ( + mlen_v * mlen_v if parent.hbm_scale_size is None + else int(parent.hbm_scale_size) + ) + + starts_s = self._legacy._format_starts(sl) + gid = self._new_group() + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"dma_h2m_slice {parent.name}[{starts_s}]" + f"+{list(sl.extents)} -> {dst.name} " + f"(parent_off={static_off} elems)" + ], + annotations={"group_id": gid}, + )) + self._emit_hbm_tile_to_mram_seq( + hbm_addr=int(parent.address), + mram_addr=int(dst.address), + hbm_stride=hbm_stride_v, + hbm_scale=hbm_scale_v, + hbm_offset=static_off, + group_id=gid, + ) + + def _emit_dma_v2h_slice( + self, mod: _hlir.HLIRModule, op: _hlir.Op, + ) -> None: + """Mirror of legacy ``isa_pass._emit_dma_v2h_slice``. + + Per tile in the slice grid (same shape as ``_emit_dma_h2v_slice``): + emit ``_emit_store_tile_to_hbm_seq`` with the per-tile + ``hbm_start_offset = base_static + tile_const``. + """ + src = mod.get_buffer(op.buffer_args[0]) + sl = op.buffer_args[1] + if not isinstance(sl, _hlir.BufferSlice): + raise PreIsaPassError( + f"dma_v2h_slice: buffer_args[1] must be BufferSlice" + ) + parent = mod.get_buffer(sl.parent) + if self._legacy._slice_has_dynamic_start(sl): + raise PreIsaPassError( + f"dma_v2h_slice: dynamic-start slice not yet supported" + ) + (d_tiles, s_tiles, h_groups, logical_b, + inner_mlen, lane_count, + (hbm_stride_b, hbm_stride_s, hbm_stride_h), + (d_tile_stride, s_tile_stride, h_grp_stride, b_stride)) = ( + self._legacy._slice_tile_grid(parent, sl, src) + ) + base_static = ( + parent.hbm_offset + + self._legacy._slice_offset_static(parent, sl) + ) + mlen_v = int(self.shim.mlen) + hbm_stride_v = ( + mlen_v if parent.hbm_stride is None + else int(parent.hbm_stride) + ) + hbm_scale_v = ( + mlen_v * mlen_v if parent.hbm_scale_size is None + else int(parent.hbm_scale_size) + ) + + starts_s = self._legacy._format_starts(sl) + gid = self._new_group() + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"dma_v2h_slice {src.name} -> " + f"{parent.name}[{starts_s}]+{list(sl.extents)} " + f"(grid d_tiles={d_tiles}, s_tiles={s_tiles}, " + f"h_groups={h_groups}, b={logical_b})" + ], + annotations={"group_id": gid}, + )) + for d_tile in range(d_tiles): + for s_tile in range(s_tiles): + for h_grp in range(h_groups): + for b in range(logical_b): + tile_const = ( + b * hbm_stride_b + + s_tile * inner_mlen * hbm_stride_s + + h_grp * lane_count * hbm_stride_h + + d_tile * inner_mlen + ) + vram_off = ( + d_tile * d_tile_stride + + s_tile * s_tile_stride + + h_grp * h_grp_stride + + b * b_stride + ) + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f" tile (d={d_tile}, s={s_tile}, " + f"h={h_grp}, b={b}): vram[+{vram_off}] -> " + f"hbm[base+{tile_const}]" + ], + annotations={"group_id": gid}, + )) + tile_gid = self._new_group() + self._emit_store_tile_to_hbm_seq( + vram_addr=int(src.address) + vram_off, + hbm_addr=int(parent.address), + hbm_stride=hbm_stride_v, + hbm_scale_size=hbm_scale_v, + hbm_start_offset=base_static + tile_const, + group_id=tile_gid, + ) + + # ------------------------------------------------------------------ + # DMA emit helpers (decompose legacy emit_load/store/_emit_*_tile_isa + # into PreIsaOp sequences). + # ------------------------------------------------------------------ + def _emit_hbm_tile_to_mram_seq( + self, *, hbm_addr: int, mram_addr: int, + hbm_stride: int, hbm_scale: int, hbm_offset: int, + group_id: int, + ) -> None: + """PreIsaOp decomposition of legacy ``emit_hbm_tile_to_mram``. + + Emits (per legacy): + _PRELOAD_ADDR_REG hbm_addr ; C_SET_ADDR_REG aN ... + S_ADDI_INT + C_SET_SCALE_REG (hbm_scale) + S_ADDI_INT + C_SET_STRIDE_REG (hbm_stride) + S_ADDI_INT (mram base addr) + S_ADDI_INT (hbm_offset) ; same gp reused as scale + H_PREFETCH_M gp{mram}, gp{offset}, aN, 1, 0 + S_ADDI_INT + C_SET_SCALE_REG (tile_elems) ← restore default + S_ADDI_INT + C_SET_STRIDE_REG (mlen) ← restore default + """ + # Distinct Python PrimExpr objects per slot — id()-keyed cache + # routes the right value to the right HW operand. + addr_expr = tir.IntImm("int32", int(hbm_addr)) + scale_expr = tir.IntImm("int32", int(hbm_scale)) + stride_expr = tir.IntImm("int32", int(hbm_stride)) + mram_expr = tir.IntImm("int32", int(mram_addr)) + offset_expr = tir.IntImm("int32", int(hbm_offset)) + # Restore defaults at the end (hw-shape symbolic): + default_scale_expr = tir.Mul(MLEN_VAR, MLEN_VAR) # tile_elems + default_stride_expr = MLEN_VAR + + def _stamp(o): + o.annotations["group_id"] = group_id + return o + + # Bind the HBM addr-reg from hbm_addr. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR_REG", operands=[addr_expr], + ))) + # Set scale. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[scale_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_SCALE_REG", operands=[scale_expr], + ))) + # Set stride. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[stride_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_STRIDE_REG", operands=[stride_expr], + ))) + # MRAM base. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[mram_expr], + ))) + # HBM start offset. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[offset_expr], + ))) + # H_PREFETCH_M gp{mram}, gp{offset}, aN, 1, 0. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="H_PREFETCH_M", + operands=[mram_expr, offset_expr, addr_expr, 1, 0], + ))) + # Restore defaults (legacy emits scale = tile_elems, stride = mlen). + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[default_scale_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_SCALE_REG", operands=[default_scale_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[default_stride_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_STRIDE_REG", operands=[default_stride_expr], + ))) + + def _emit_load_tile_from_hbm_seq( + self, *, hbm_addr: int, vram_addr: int, + hbm_stride: int, hbm_scale_size: int, hbm_start_offset: int, + group_id: int, + ) -> None: + """PreIsaOp decomposition of legacy ``emit_load_tile_from_hbm`` + + ``_emit_preload_tile_isa`` (batch>1 path). + + Mirrors the static-unroll body the legacy emits when + ``batch == mlen`` and ``batch > v_prefetch_amount``: a + per-(outer, inner) sequence of two S_ADDIs + one + H_PREFETCH_V. Outer/inner are statically unrolled here (could + become LOOP_START in a future optimisation pass). + """ + mlen_v = int(self.shim.mlen) + vpref_v = int(self.shim.v_prefetch_amount) + # All hardware-coupled values referenced via PrimExpr so the + # optimiser sees the algebra; folded by materialiser at emit. + addr_expr = tir.IntImm("int32", int(hbm_addr)) + scale_expr = tir.IntImm("int32", int(hbm_scale_size)) + stride_expr = tir.IntImm("int32", int(hbm_stride)) + offset_expr = tir.IntImm("int32", int(hbm_start_offset)) + vram_base_expr = tir.IntImm("int32", int(vram_addr)) + + def _stamp(o): + o.annotations["group_id"] = group_id + return o + + # Bind addr-reg. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR_REG", operands=[addr_expr], + ))) + # NOTE: legacy emits 5 ``reset_reg_asm`` ``S_ADDI_INT gp{r}, + # gp0, 0`` lines as a sanity-reset of scratch GPs that the body + # then immediately overwrites — semantically a no-op. PreIsaIR + # skips them to keep the cache from holding 5 pinned GPs for + # no reason; this means the new path emits 5 fewer ISA lines + # (semantically equivalent — those zeros are dead writes). + # Preload tile body (mirrors _emit_preload_tile_isa, batch>1). + # S_ADDI_INT gp{a_actual}, gp0, scale_len ; + # C_SET_SCALE_REG gp{a_actual} ; + # S_ADDI_INT gp{a_actual}, gp0, hbm_start_offset (re-uses GP) + # S_ADDI_INT gp{result}, gp0, vram_addr + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[scale_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_SCALE_REG", operands=[scale_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[offset_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[vram_base_expr], + ))) + # S_ADDI_INT gp{stride}, gp0, stride_len ; + # C_SET_STRIDE_REG gp{stride} + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[stride_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_STRIDE_REG", operands=[stride_expr], + ))) + + # Static-unroll dimensions: batch=mlen, hidden_size=mlen, + # vlen=mlen. load_amount_per_hidden = 1 always under these + # assumptions; inner_count = mlen/v_prefetch_amount when + # batch>preload_len. + load_amount_per_hidden = 1 + # Symbolic inner_count = MLEN/V_PREFETCH_AMOUNT — but the + # *bound* of the LOOP_START must be a compile-time int, so + # we compute it from the shim here. The strides INSIDE the + # body stay symbolic (referencing MLEN_VAR, etc). + if mlen_v > vpref_v: + inner_count_n = (mlen_v + vpref_v - 1) // vpref_v + else: + inner_count_n = 1 + + # Wrap the (outer × inner) loop in a single per_iter + # LOOP_START. The body uses the iter var ``it`` = + # outer * inner_count + inner. Since + # load_amount_per_hidden == 1, outer == 0 always and the + # full count is just inner_count_n. + step_elems_expr = tir.Mul(MLEN_VAR, V_PREFETCH_AMOUNT_VAR) + # ``it`` ∈ [0, inner_count_n). + it_var = tir.Var( + f"dma_h2v_it_{id(self) & 0xffff:x}_" + f"{hbm_start_offset & 0xffff:x}", "int32", + ) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="LOOP_START", + operands=[0, load_amount_per_hidden * inner_count_n], + binds=it_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + ))) + result_addr_expr = tir.Add( + vram_base_expr, + tir.Mul(it_var, step_elems_expr), + ) + if mlen_v > vpref_v: + # a_off (within HBM block) = it * stride_len * preload_len + # (outer term is 0 because load_amount_per_hidden=1). + a_off_inner_expr = tir.Add( + offset_expr, + tir.Mul( + it_var, + tir.Mul( + tir.IntImm("int32", int(hbm_stride)), + V_PREFETCH_AMOUNT_VAR, + ), + ), + ) + else: + # batch <= preload_len: legacy a_off = outer*vlen = 0. + a_off_inner_expr = offset_expr + + iter_gid = self._new_group() + + def _stamp_iter(o): + o.annotations["group_id"] = iter_gid + return o + + self.pre_isa.append(_stamp_iter(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[result_addr_expr], + ))) + self.pre_isa.append(_stamp_iter(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[a_off_inner_expr], + ))) + self.pre_isa.append(_stamp_iter(PreIsaOp( + opcode="H_PREFETCH_V", + operands=[result_addr_expr, a_off_inner_expr, + addr_expr, 1, 0], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + ))) + + def _emit_store_tile_to_hbm_seq( + self, *, vram_addr: int, hbm_addr: int, + hbm_stride: int, hbm_scale_size: int, hbm_start_offset: int, + group_id: int, + ) -> None: + """PreIsaOp decomposition of legacy ``emit_store_tile_to_hbm`` + + ``_emit_store_tile_isa`` (batch>1 path). + """ + mlen_v = int(self.shim.mlen) + vwb_v = int(self.shim.v_writeback_amount) + addr_expr = tir.IntImm("int32", int(hbm_addr)) + scale_expr = tir.IntImm("int32", int(hbm_scale_size)) + stride_expr = tir.IntImm("int32", int(hbm_stride)) + offset_expr = tir.IntImm("int32", int(hbm_start_offset)) + vram_base_expr = tir.IntImm("int32", int(vram_addr)) + + def _stamp(o): + o.annotations["group_id"] = group_id + return o + + # Bind addr-reg. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR_REG", operands=[addr_expr], + ))) + # Setup: S_ADDI vram, S_ADDI scale; C_SET_SCALE; S_ADDI offset. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[vram_base_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[scale_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_SCALE_REG", operands=[scale_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[offset_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[stride_expr], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_STRIDE_REG", operands=[stride_expr], + ))) + + store_amount_per_hidden = 1 + if mlen_v > vwb_v: + inner_count_n = (mlen_v + vwb_v - 1) // vwb_v + else: + inner_count_n = 1 + step_elems_expr = tir.Mul(MLEN_VAR, V_WRITEBACK_AMOUNT_VAR) + + it_var = tir.Var( + f"dma_v2h_it_{id(self) & 0xffff:x}_" + f"{hbm_start_offset & 0xffff:x}", "int32", + ) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="LOOP_START", + operands=[0, store_amount_per_hidden * inner_count_n], + binds=it_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + ))) + vram_off_expr = tir.Add( + vram_base_expr, + tir.Mul(it_var, step_elems_expr), + ) + if mlen_v > vwb_v: + hbm_off_inner_expr = tir.Add( + offset_expr, + tir.Mul( + it_var, + tir.Mul( + tir.IntImm("int32", int(hbm_stride)), + V_WRITEBACK_AMOUNT_VAR, + ), + ), + ) + else: + hbm_off_inner_expr = offset_expr + + iter_gid = self._new_group() + + def _stamp_iter(o): + o.annotations["group_id"] = iter_gid + return o + + self.pre_isa.append(_stamp_iter(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[vram_off_expr], + ))) + self.pre_isa.append(_stamp_iter(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[hbm_off_inner_expr], + ))) + self.pre_isa.append(_stamp_iter(PreIsaOp( + opcode="H_STORE_V", + operands=[vram_off_expr, hbm_off_inner_expr, + addr_expr, 1, 0], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + ))) + + def _emit_mm_slot(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Mirror of legacy ``isa_pass._emit_mm_slot`` → + ``ISAEmitter.emit_slot_matmul``. + + Schema: + buffer_args = [lhs_name, rhs_name, dst_name] + scalar_args = [lhs_row_offset, rhs_col_offset, + dst_col_offset, col_count] + (offsets may be ints OR tir.PrimExprs; col_count must be int) + + Legacy emission for each outer oc iter: + S_ADDI gp{act}, ..., act_addr(oc) + S_ADDI gp{mat}, ..., mat_addr(oc) + S_ADDI gp{out}, ..., out_addr(oc) + for t in range(tiles_per_mlen): + if t > 0: + S_ADDI gp{act}, gp{act}, row_stride (bump) + S_ADDI gp{out}, gp{out}, row_stride (bump) + M_MM 0, gp{mat}, gp{act} + M_MM_WO gp{out}, gp0, 0 + + PreIsaIR decomposition: + * outer oc loop → LOOP_START(unroll, per_iter) + * per oc iter: 3 _PRELOAD_ADDRs (act, mat, out) + * inner t loop → LOOP_START(unroll, shared) so the per-iter + BUMPs on cached act/out GPs persist across iters. + Each inner iter (except first) bumps act + out before + M_MM + M_MM_WO. + + This initial migration handles the STATIC-offset path + (lhs_row_offset / rhs_col_offset / dst_col_offset are all + compile-time ints). Dynamic-offset support is a TODO. + """ + if len(op.buffer_args) != 3: + raise PreIsaPassError( + f"plena.mm_slot expects 3 buffer_args; got {len(op.buffer_args)}" + ) + lhs = mod.get_buffer(op.buffer_args[0]) + rhs = mod.get_buffer(op.buffer_args[1]) + dst = mod.get_buffer(op.buffer_args[2]) + if len(op.scalar_args) != 4: + raise PreIsaPassError( + f"plena.mm_slot expects 4 scalar_args; got {len(op.scalar_args)}" + ) + lhs_row_offset_raw = op.scalar_args[0] + rhs_col_offset_raw = op.scalar_args[1] + dst_col_offset_raw = op.scalar_args[2] + col_count_raw = op.scalar_args[3] + + def _static(x): + if isinstance(x, int): + return int(x) + if isinstance(x, tir.IntImm): + return int(x.value) + return None + + lhs_row_offset = _static(lhs_row_offset_raw) + rhs_col_offset = _static(rhs_col_offset_raw) + dst_col_offset = _static(dst_col_offset_raw) + if (lhs_row_offset is None or rhs_col_offset is None + or dst_col_offset is None): + raise PreIsaPassError( + f"plena.mm_slot: dynamic offsets not yet supported by " + f"PreIsaPass; got lhs_row_offset={lhs_row_offset_raw!r}, " + f"rhs_col_offset={rhs_col_offset_raw!r}, " + f"dst_col_offset={dst_col_offset_raw!r}" + ) + col_count = _static(col_count_raw) + if col_count is None or col_count <= 0: + raise PreIsaPassError( + f"plena.mm_slot col_count must be a positive compile-time int; " + f"got {col_count_raw!r}" + ) + + mlen = int(self.shim.mlen) + blen = int(self.shim.blen) + if col_count % blen != 0: + raise PreIsaPassError( + f"plena.mm_slot: col_count={col_count} must be divisible by blen={blen}" + ) + tiles_per_slot = col_count // blen + tiles_per_mlen = mlen // blen + # row_stride = blen * mlen — symbolic so PreIsaIR preserves + # the hw-shape algebra. Materialiser folds to literal at emit. + row_stride_expr = tir.Mul(BLEN_VAR, MLEN_VAR) + task_id = op.annotations.get("intrinsic", "mm_slot") + + gid = self._new_group() + self.pre_isa.append(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"slot matmul task {task_id} " + f"rhs_col_offset={rhs_col_offset} " + f"dst_col_offset={dst_col_offset}" + ], + annotations={"group_id": gid}, + )) + # Preamble: legacy emits ``S_ADDI gp{stride}, gp0, 1`` (dead). + stride_const = tir.IntImm("int32", 1) + self.pre_isa.append(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[stride_const], + annotations={"group_id": gid}, + )) + + oc_var = tir.Var(f"mm_slot_oc_{id(op) & 0xffff:x}", "int32") + t_var = tir.Var(f"mm_slot_t_{id(op) & 0xffff:x}", "int32") + + # Outer (oc) per_iter unroll. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_START", + operands=[0, tiles_per_slot], + binds=oc_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "per_iter", + }, + )) + + # Per-oc address PrimExprs. + # act_addr = lhs.address + lhs_row_offset + # (lhs_row_offset is the static offset into a multi-tile lhs). + # mat_addr = rhs.address + rhs_col_offset + oc * blen + # out_addr = dst.address + dst_col_offset + oc * blen + # ``blen`` referenced symbolically via BLEN_VAR. + act_base_expr = tir.Add( + tir.IntImm("int32", int(lhs.address)), + tir.IntImm("int32", lhs_row_offset), + ) + mat_addr_expr = tir.Add( + tir.Add( + tir.IntImm("int32", int(rhs.address)), + tir.IntImm("int32", rhs_col_offset), + ), + tir.Mul(oc_var, BLEN_VAR), + ) + out_addr_expr = tir.Add( + tir.Add( + tir.IntImm("int32", int(dst.address)), + tir.IntImm("int32", dst_col_offset), + ), + tir.Mul(oc_var, BLEN_VAR), + ) + + oc_iter_gid = self._new_group() + + def _stamp_oc(o): + o.annotations["group_id"] = oc_iter_gid + if "close_order" not in o.annotations: + o.annotations["close_order"] = "insertion" + return o + + # Per-oc preloads (legacy emit order: act, mat, out). + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[act_base_expr], + ))) + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[mat_addr_expr], + ))) + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[out_addr_expr], + ))) + + # Inner (t) shared-scope unroll — the per-iter bumps on act/out + # cached GPs persist across iters. + # For byte-equal we follow legacy's "if t > 0 do bumps; then + # M_MM/M_MM_WO" pattern with the prefix-loop + final-iter + # split mv already uses. + if tiles_per_mlen > 1: + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="LOOP_START", + operands=[0, tiles_per_mlen - 1], + binds=t_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "shared", + "group_id": oc_iter_gid, + "close_order": "insertion", + }, + ))) + # First M_MM (legacy: skip bump on iter 0). + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="M_MM", + operands=[0, mat_addr_expr, act_base_expr], + ))) + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="M_MM_WO", + operands=[out_addr_expr, "gp0", 0], + ))) + # Bump act + out for the NEXT iter. ``row_stride_expr`` is + # the symbolic ``BLEN_VAR * MLEN_VAR``; BackendEmit folds + # it to a literal stride at emit time. + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="_BUMP_CACHED_GP", + operands=[act_base_expr, row_stride_expr], + ))) + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="_BUMP_CACHED_GP", + operands=[out_addr_expr, row_stride_expr], + ))) + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + ))) + # Final iter — no trailing bumps. + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="M_MM", + operands=[0, mat_addr_expr, act_base_expr], + ))) + self.pre_isa.append(_stamp_oc(PreIsaOp( + opcode="M_MM_WO", + operands=[out_addr_expr, "gp0", 0], + ))) + + # Close outer oc loop. + self.pre_isa.append(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={"loop_kind": "unroll"}, + )) + + # ------------------------------------------------------------------ + # row_*_at — row-scalar VRAM ops with d_tile unroll + # ------------------------------------------------------------------ + def _emit_row_scalar_op_at( + self, + mod: _hlir.HLIRModule, + op: _hlir.Op, + *, + row_op: str, + reduce: bool = False, + masked: bool = False, + has_fp: bool = False, + ) -> None: + """Mirror of legacy ``isa_pass._emit_row_scalar_op_at``. + + Legacy emit order: + 1. Eager: ``materialise(src_addr)`` → S_ADDI_INT gp{src}, ... + 2. Eager: if masked, ``materialise(mask_expr)`` → S_ADDI_INT gp{mask} + 3. Eager: ``materialise(dst_addr or fp_addr)`` → S_ADDI_INT gp{dst} + 4. Eager: if binary-fp, ``materialise(fp_rhs_addr)`` → S_ADDI_INT gp{rhs} + 5. Flush `lines` (header comment + body ISA + mask reset) + + PreIsaIR encodes this as: + _PRELOAD_ADDR src_addr + [if masked] _PRELOAD_ADDR mask_expr + _PRELOAD_ADDR dst_addr_or_fp_addr + [if binary-fp] _PRELOAD_ADDR fp_rhs_addr + _COMMENT "row scalar task ..." + [if masked] C_SET_V_MASK_REG (cached mask GP) + [if reduce] S_LD_FP f1, gp{dst}, 0 -- _S_LD_FP_CACHED + per d_tile: + HW op (V_RED_* / V_EXP_V / V_*_VF) using cached gp{src}, gp{dst} + [if not last d_tile] _BUMP_CACHED_GP src_addr, d_tile_stride_s + [if dst exists] _BUMP_CACHED_GP dst_addr, d_tile_stride_d + [if reduce] S_ST_FP f1, gp{dst}, 0 -- _S_ST_FP_CACHED + [if masked] _S_ADDI_INT_RESET_MASK + C_SET_V_MASK_REG + """ + has_fp = has_fp or reduce + if reduce: + if len(op.buffer_args) != 1: + raise PreIsaPassError( + f"{op.kind} expects 1 buffer_arg (src region); " + f"got {len(op.buffer_args)}" + ) + expected_scalar = 1 + elif has_fp: + if len(op.buffer_args) != 2: + raise PreIsaPassError( + f"{op.kind} expects 2 buffer_args (src/dst regions); " + f"got {len(op.buffer_args)}" + ) + expected_scalar = 1 + else: + if len(op.buffer_args) != 2: + raise PreIsaPassError( + f"{op.kind} expects 2 buffer_args (src/dst regions); " + f"got {len(op.buffer_args)}" + ) + expected_scalar = 0 + if len(op.scalar_args) != expected_scalar: + raise PreIsaPassError( + f"{op.kind} expects {expected_scalar} scalar_args; " + f"got {len(op.scalar_args)}" + ) + for slot, name in enumerate(("src",) if reduce else ("src", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise PreIsaPassError( + f"{op.kind} {name}: expected VramRegion, got " + f"{type(op.buffer_args[slot]).__name__}" + ) + + src_region: _hlir.VramRegion = op.buffer_args[0] + src = mod.get_buffer(src_region.parent) + # All non-D extents must be 1 (one logical row per op). + if any(int(e) != 1 for e in src_region.extents[:3]): + raise PreIsaPassError( + f"{op.kind} src: row_*_at processes one logical row, " + f"non-D extents must be 1; got " + f"{tuple(src_region.extents[:3])}" + ) + + fp_addr_expr = None + if has_fp: + fp_addr_expr = self._legacy._resolve_fp_scalar_addr_arg( + mod, op.scalar_args[0], op.kind, "fp", + ) + + src_base_off, src_mask_expr, src_info = ( + self._legacy._logical_to_phys_row_offset(src, src_region) + ) + emit_v_mask = masked and src_mask_expr is not None + use_mask_flag = 1 if emit_v_mask else 0 + + # Build the address PrimExpr objects ONCE (id()-keyed cache). + src_addr = tir.Add(tir.IntImm("int32", int(src.address)), src_base_off) + mask_expr = src_mask_expr if emit_v_mask else None + + # Reduce: scalar_args[0] is the FP destination + # Binary-fp: same — fp_addr_expr is the FP rhs operand. The dst + # buffer is buffer_args[1] (VramRegion). + dst_addr = None + d_tile_stride_d = 0 + n_d_tiles = src_info["d_tiles"] + d_tile_stride_s = src_info["d_tile_stride"] + if not reduce: + dst_region: _hlir.VramRegion = op.buffer_args[1] + dst = mod.get_buffer(dst_region.parent) + if len(dst_region.extents) != 4: + raise PreIsaPassError( + f"{op.kind} dst: region must be 4D; got " + f"extents={tuple(dst_region.extents)}" + ) + if any(int(e) != 1 for e in dst_region.extents[:3]): + raise PreIsaPassError( + f"{op.kind} dst: non-D extents must be 1; " + f"got {tuple(dst_region.extents[:3])}" + ) + dst_base_off, dst_mask_expr, dst_info = ( + self._legacy._logical_to_phys_row_offset(dst, dst_region) + ) + if emit_v_mask and dst_mask_expr is None: + raise PreIsaPassError( + f"{op.kind} src requires packed-head mask but dst " + f"{dst.name!r} does not" + ) + if emit_v_mask and dst_region.parent != src_region.parent: + warnings.warn( + f"{op.kind}: masked V_*_V with dst " + f"{dst_region.parent!r} != src " + f"{src_region.parent!r} — unmasked heads will " + f"overwrite dst with src", + RuntimeWarning, + stacklevel=2, + ) + if dst_info["d_tiles"] != n_d_tiles: + raise PreIsaPassError( + f"{op.kind}: src/dst d_tiles mismatch" + ) + d_tile_stride_d = dst_info["d_tile_stride"] + dst_addr = tir.Add(tir.IntImm("int32", int(dst.address)), dst_base_off) + + # ----- one group_id for the whole row_*_at op ----- + gid = self._new_group() + + def _stamp(o): + o.annotations["group_id"] = gid + return o + + # 1) Preload src first (matches legacy m_src materialise). + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[src_addr], + ))) + # 2) Preload mask if masked. + if emit_v_mask: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[mask_expr], + ))) + # 3) Preload dst / fp_dst. + if reduce: + # The FP destination address (scalar_args[0]). + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[fp_addr_expr], + ))) + else: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[dst_addr], + ))) + # 4) For binary-fp variants, ALSO preload the FP RHS. + if fp_addr_expr is not None: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_PRELOAD_ADDR", operands=[fp_addr_expr], + ))) + + # Header comment. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_COMMENT", + operands=[ + f"row scalar task {op.annotations.get('intrinsic', op.kind)} " + f"op={row_op} " + f"src.parent={src_region.parent} " + f"starts={list(src_region.starts)!r}" + ], + ))) + + # C_SET_V_MASK_REG (if masked). + if emit_v_mask: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_V_MASK_REG", + operands=[mask_expr], + ))) + + # ----- the body ----- + # The d_tile unroll for row_*_at is encoded as a prefix + # LOOP_START(unroll, shared) of length (n_d_tiles - 1) — each + # iter emits the HW op AND a trailing _BUMP_CACHED_GP for + # src (and dst when applicable) — followed by one more + # explicit body emission for the final iter WITHOUT trailing + # bumps. This mirrors legacy's ``if t < n_d_tiles - 1`` guard + # while keeping the loop symbolic in PreIsaIR; the + # ``unroll_scope="shared"`` annotation tells BackendEmit to + # carry the cached GPs (and their bumped values) across + # iterations. + def _emit_one_iter(emit_bumps: bool) -> None: + """Emit one row-op body (single d_tile). When emit_bumps is + True, append the bumps that ready the cached GPs for the + next iter.""" + if reduce: + op_str = {"reduce_max": "V_RED_MAX", + "reduce_sum": "V_RED_SUM"}[row_op] + self.pre_isa.append(_stamp(PreIsaOp( + opcode=op_str, + operands=["f1", src_addr, use_mask_flag], + ))) + elif fp_addr_expr is None: + op_str = {"exp": "_V_EXP_V_ROW", + "reci": "_V_RECI_V_ROW"}[row_op] + self.pre_isa.append(_stamp(PreIsaOp( + opcode=op_str, + operands=[dst_addr, src_addr, use_mask_flag], + ))) + else: + if row_op == "sub": + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_V_SUB_VF_ROW", + operands=[dst_addr, src_addr, "f1", + use_mask_flag, 0], + ))) + else: + op_str = {"add": "_V_ADD_VF_ROW", + "mul": "_V_MUL_VF_ROW"}[row_op] + self.pre_isa.append(_stamp(PreIsaOp( + opcode=op_str, + operands=[dst_addr, src_addr, "f1", + use_mask_flag], + ))) + if emit_bumps: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_BUMP_CACHED_GP", + operands=[src_addr, d_tile_stride_s], + ))) + if not reduce: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_BUMP_CACHED_GP", + operands=[dst_addr, d_tile_stride_d], + ))) + + # Reduce + binary-fp variants seed/finalise an f1 accumulator + # around the d_tile sweep. + if reduce: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_S_LD_FP_CACHED", + operands=["f1", fp_addr_expr, 0], + ))) + elif fp_addr_expr is not None: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_S_LD_FP_CACHED", + operands=["f1", fp_addr_expr, 0], + ))) + + # Prefix unroll loop: (n_d_tiles - 1) iters, each with + # trailing bumps. Final iter is emitted explicitly after, + # without bumps. When n_d_tiles == 1 the loop is empty. + if n_d_tiles > 1: + t_var = tir.Var(f"row_d_tile_{id(op) & 0xffff:x}", "int32") + self.pre_isa.append(_stamp(PreIsaOp( + opcode="LOOP_START", + operands=[0, n_d_tiles - 1], + binds=t_var, + annotations={ + "loop_kind": "unroll", + "unroll_scope": "shared", + "group_id": gid, + }, + ))) + _emit_one_iter(emit_bumps=True) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="LOOP_END", + operands=[], + annotations={ + "loop_kind": "unroll", + "group_id": gid, + }, + ))) + # Final iter (always present). + _emit_one_iter(emit_bumps=False) + + if reduce: + # S_ST_FP f1, gp{fp_dst}, 0 — flush accumulator. + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_S_ST_FP_CACHED", + operands=["f1", fp_addr_expr, 0], + ))) + + # Mask reset. + if emit_v_mask: + self.pre_isa.append(_stamp(PreIsaOp( + opcode="_S_ADDI_INT_RESET_MASK", + operands=[mask_expr, "gp0", 0], + ))) + self.pre_isa.append(_stamp(PreIsaOp( + opcode="C_SET_V_MASK_REG", + operands=[mask_expr], + ))) + + +__all__ = ["PreIsaPass", "PreIsaPassError"] diff --git a/tilelang_tvm_compiler/pre_isa_pass_v2.py b/tilelang_tvm_compiler/pre_isa_pass_v2.py new file mode 100644 index 0000000..6b2f0d0 --- /dev/null +++ b/tilelang_tvm_compiler/pre_isa_pass_v2.py @@ -0,0 +1,2145 @@ +"""PreIsaPass v2 — the clean producer. + +Each HLIR op handler emits a sequence of PreIsaOp + LoopRegion items +into a PreIsaModule. The producer does NOT think about registers, +materialisation, GP cache, addr-reg cache, or scope nesting. Its +single job is "what PLENA ISA instructions does this HLIR op +correspond to, with what symbolic operand expressions". + +PrimExpr expansion (turning ``a + b * c`` into S_ADD_INT / S_MUL_INT +chains), simplification (folding hw consts), SSA naming, register +allocation, and ISA text emission all live in later passes +(:mod:`pre_isa_to_mir`, MIR optimise passes, :mod:`mir_to_isa`). + +Currently migrated handlers (everything else raises): + * fp_zero_at + +The pre_isa_pass_v2 dispatch table only contains handlers we've +already migrated. Callers fall back to the legacy ``isa_pass`` for +non-migrated ops. +""" + +from __future__ import annotations + +from typing import Callable, Dict, List + +from tvm import tir + +from . import hlir as _hlir +from . import pre_isa_ir_v2 as pi +from . import scope as _scope +from .program_shim import ProgramShim + + +# Helper to combine a base address + an offset PrimExpr into a single +# operand expression. Returns a PrimExpr; the PreIsa→MIR conversion +# pass simplifies/folds. +def _addr(base: int, offset) -> tir.PrimExpr: + base_imm = tir.IntImm("int32", int(base)) + if isinstance(offset, int) and offset == 0: + return base_imm + if isinstance(offset, tir.IntImm) and int(offset.value) == 0: + return base_imm + return tir.Add(base_imm, offset) + + +def _as_expr(v) -> tir.PrimExpr: + """Coerce ``v`` (int / IntImm / PrimExpr) to a PrimExpr.""" + if isinstance(v, int): + return tir.IntImm("int32", v) + if isinstance(v, tir.IntImm): + return v + if isinstance(v, tir.PrimExpr): + return v + raise TypeError( + f"_as_expr: expected int/IntImm/PrimExpr; got " + f"{type(v).__name__}: {v!r}" + ) + + +class PreIsaPassV2Error(RuntimeError): + pass + + +class PreIsaPassV2: + """Lower an HLIRModule to a clean PreIsaModule. + + Construction takes a ProgramShim — handlers read hardware-shape + constants (mlen, blen, btmm_hlen) via the shim and may also + delegate to legacy helpers (notably ``_resolve_fp_scalar_addr_arg`` + on the legacy IsaEmitterPass) for buffer-address resolution. + """ + + def __init__(self, shim: ProgramShim) -> None: + from .isa_pass import IsaEmitterPass + self.shim = shim + self._legacy = IsaEmitterPass(shim) + self.pre_isa = pi.PreIsaModule(name="") + # Stack of "current append targets" — handlers write into the + # top of this stack. The bottom is the module's top-level body; + # entering a LoopRegion pushes its body list on top so nested + # ops land inside the loop region naturally. Mirrors a scope + # stack — loops ARE scopes. + self._cursor_stack: List[List] = [] + self._dispatch: Dict[ + str, Callable[[_hlir.HLIRModule, _hlir.Op], None], + ] = { + "fp_zero_at": self._emit_fp_zero_at, + "fp_copy_at": lambda m, o: self._emit_fp_scalar_op(m, o, kernel_op="copy"), + "fp_exp_at": lambda m, o: self._emit_fp_scalar_op(m, o, kernel_op="exp"), + "fp_reci_at": lambda m, o: self._emit_fp_scalar_op(m, o, kernel_op="reci"), + "fp_sqrt_at": lambda m, o: self._emit_fp_scalar_op(m, o, kernel_op="sqrt"), + "fp_add_at": lambda m, o: self._emit_fp_scalar_op(m, o, kernel_op="add"), + "fp_sub_at": lambda m, o: self._emit_fp_scalar_op(m, o, kernel_op="sub"), + "fp_mul_at": lambda m, o: self._emit_fp_scalar_op(m, o, kernel_op="mul"), + "fp_max_at": lambda m, o: self._emit_fp_scalar_op(m, o, kernel_op="max"), + "v_zero": self._emit_v_zero, + "v_add": lambda m, o: self._emit_v_binary(m, o, opcode="V_ADD_VV"), + "v_sub": lambda m, o: self._emit_v_binary(m, o, opcode="V_SUB_VV"), + "v_mul": lambda m, o: self._emit_v_binary(m, o, opcode="V_MUL_VV"), + "v_exp": lambda m, o: self._emit_v_unary(m, o, opcode="V_EXP_V"), + "v_reci": lambda m, o: self._emit_v_unary(m, o, opcode="V_RECI_V"), + "v_sqrt": lambda m, o: self._emit_v_unary(m, o, opcode="V_SQRT_V"), + "copy_v_to_v": self._emit_copy_v_to_v, + "v_fp_transfer_slice_v_to_fp": lambda m, o: + self._emit_v_fp_transfer_slice(m, o, direction="v_to_fp"), + "v_fp_transfer_slice_fp_to_v": lambda m, o: + self._emit_v_fp_transfer_slice(m, o, direction="fp_to_v"), + "for": self._emit_for, + "row_reduce_max_at": lambda m, o: self._emit_row_scalar( + m, o, row_op="reduce_max", reduce=True, masked=True, + ), + "row_reduce_sum_at": lambda m, o: self._emit_row_scalar( + m, o, row_op="reduce_sum", reduce=True, masked=True, + ), + "row_exp": lambda m, o: self._emit_row_scalar( + m, o, row_op="exp", masked=True, + ), + "row_sub_fp": lambda m, o: self._emit_row_scalar( + m, o, row_op="sub", masked=True, has_fp=True, + ), + "row_mul_fp": lambda m, o: self._emit_row_scalar( + m, o, row_op="mul", masked=True, has_fp=True, + ), + "row_add_fp": lambda m, o: self._emit_row_scalar( + m, o, row_op="add", masked=True, has_fp=True, + ), + "btmm": self._emit_btmm, + "btmv": self._emit_btmv, + "mv": self._emit_mv, + "mm": self._emit_mm, + "mm_slot": self._emit_mm_slot, + "matmul": self._emit_matmul, + "dma_h2v": self._emit_dma_h2v, + "dma_h2m": self._emit_dma_h2m, + "dma_v2h": self._emit_dma_v2h, + "dma_h2v_slice": self._emit_dma_h2v_slice, + "dma_h2m_slice": self._emit_dma_h2m_slice, + "dma_v2h_slice": self._emit_dma_v2h_slice, + } + + # ---- "cursor" scope helpers — handlers append via _append ---- + def _append(self, item) -> None: + """Add a PreIsaOp / LoopRegion to the current scope (top of + ``_cursor_stack``, or the module's top-level body if stack is + empty).""" + if self._cursor_stack: + self._cursor_stack[-1].append(item) + else: + self.pre_isa.append(item) + + def _comment(self, text: str) -> None: + self._append(pi.PreIsaOp(opcode="_COMMENT", operands=[text])) + + def _push_scope(self, body_list: List) -> None: + """Enter a nested scope (LoopRegion body). Subsequent appends + land in ``body_list``.""" + self._cursor_stack.append(body_list) + + def _pop_scope(self) -> None: + self._cursor_stack.pop() + + def run(self, mod: _hlir.HLIRModule) -> pi.PreIsaModule: + _hlir.assert_addresses_resolved(mod) + self.pre_isa = pi.PreIsaModule( + name=mod.name, buffers=dict(mod.buffers), + ) + self._cursor_stack = [] + for hlir_op in mod.ops: + self._dispatch_op(mod, hlir_op) + return self.pre_isa + + def _dispatch_op( + self, mod: _hlir.HLIRModule, hlir_op: _hlir.Op, + ) -> None: + """Find + run the v2 handler for one HLIR op. Used by the + top-level run loop AND recursively by the ``for`` handler + for body sub-ops.""" + handler = self._dispatch.get(hlir_op.kind) + if handler is None: + raise PreIsaPassV2Error( + f"PreIsaPassV2: no handler migrated for HLIR op kind " + f"{hlir_op.kind!r}. Migrate it or dispatch to legacy " + f"isa_pass for this op." + ) + handler(mod, hlir_op) + + # ================================================================== + # migrated handlers + # ================================================================== + def _emit_fp_zero_at( + self, mod: _hlir.HLIRModule, op: _hlir.Op, + ) -> None: + """HLIR ``fp_zero_at`` — store FP zero to one FPRAM slot. + + Legacy ISA: + ; fp scalar task op=zero + S_ADDI_INT gp{r}, gp0, + S_ST_FP f0, gp{r}, 0 + + PreIsaIR v2 form — no GPs, no materialisation: + _COMMENT "fp scalar task ..." + S_ST_FP ["f0", , 0] + The conversion pass turns dst_addr_expr into a SSA value chain + producing the right physical GP at MIR-emit time. + """ + if len(op.scalar_args) != 1: + raise PreIsaPassV2Error( + f"fp_zero_at expects 1 scalar address arg; got " + f"{len(op.scalar_args)}" + ) + # Reuse legacy resolver — returns a tir.PrimExpr already + # carrying buf.address + buffer-element offset. + dst_addr_expr = self._legacy._resolve_fp_scalar_addr_arg( + mod, op.scalar_args[0], op.kind, "dst", + ) + intrinsic = op.annotations.get("intrinsic", op.kind) + self._comment(f"fp scalar task {intrinsic} op=zero") + self._append(pi.PreIsaOp( + opcode="S_ST_FP", + operands=["f0", dst_addr_expr, 0], + )) + + def _emit_fp_scalar_op( + self, mod: _hlir.HLIRModule, op: _hlir.Op, *, kernel_op: str, + ) -> None: + """``fp__at`` (2 scalar args: src, dst) or + ``fp__at`` (3 scalar args: lhs, rhs, dst). + + Legacy ISA(unary, e.g. exp): + S_ADDI_INT gp{src}, gp0, + S_ADDI_INT gp{dst}, gp0, + ; fp scalar task ... op=exp + S_LD_FP f1, gp{src}, 0 + S_EXP_FP f1, f1, 0 + S_ST_FP f1, gp{dst}, 0 + + Legacy ISA(binary, e.g. mul): + S_ADDI gp{lhs}, ...; S_ADDI gp{rhs}, ...; S_ADDI gp{dst}, ... + ; fp scalar task ... op=mul + S_LD_FP f1, gp{lhs}, 0 + S_LD_FP f2, gp{rhs}, 0 + S_MUL_FP f1, f1, f2 + S_ST_FP f1, gp{dst}, 0 + + PreIsaIR v2 form — no GPs, no materialisation: + _COMMENT "fp scalar task ..." + S_LD_FP ["f1", , 0] + [S_LD_FP ["f2", , 0]] ; binary only + [...fpreg arguments...] + S_ST_FP ["f1", , 0] + """ + if kernel_op in ("copy", "exp", "reci", "sqrt"): + expected = 2 + else: + expected = 3 + if len(op.scalar_args) != expected: + raise PreIsaPassV2Error( + f"{op.kind} expects {expected} scalar address args; " + f"got {len(op.scalar_args)}" + ) + addr_exprs = [ + self._legacy._resolve_fp_scalar_addr_arg( + mod, a, op.kind, f"arg{i}", + ) + for i, a in enumerate(op.scalar_args) + ] + intrinsic = op.annotations.get("intrinsic", op.kind) + self._comment( + f"fp scalar task {intrinsic} op={kernel_op}" + ) + if kernel_op in ("copy", "exp", "reci", "sqrt"): + src_expr, dst_expr = addr_exprs + # Load src into f1. + self._append(pi.PreIsaOp( + opcode="S_LD_FP", operands=["f1", src_expr, 0], + )) + # Per-op compute (copy is the no-op variant). + if kernel_op == "exp": + self._append(pi.PreIsaOp( + opcode="S_EXP_FP", operands=["f1", "f1", 0], + )) + elif kernel_op == "reci": + self._append(pi.PreIsaOp( + opcode="S_RECI_FP", operands=["f1", "f1"], + )) + elif kernel_op == "sqrt": + self._append(pi.PreIsaOp( + opcode="S_SQRT_FP", operands=["f1", "f1"], + )) + # ``copy`` has no compute — f1 is already src; just store. + # Store f1 into dst. + self._append(pi.PreIsaOp( + opcode="S_ST_FP", operands=["f1", dst_expr, 0], + )) + else: + lhs_expr, rhs_expr, dst_expr = addr_exprs + self._append(pi.PreIsaOp( + opcode="S_LD_FP", operands=["f1", lhs_expr, 0], + )) + self._append(pi.PreIsaOp( + opcode="S_LD_FP", operands=["f2", rhs_expr, 0], + )) + opcode_map = { + "add": "S_ADD_FP", + "sub": "S_SUB_FP", + "mul": "S_MUL_FP", + "max": "S_MAX_FP", + } + self._append(pi.PreIsaOp( + opcode=opcode_map[kernel_op], + operands=["f1", "f1", "f2"], + )) + self._append(pi.PreIsaOp( + opcode="S_ST_FP", operands=["f1", dst_expr, 0], + )) + + # ------------------------------------------------------------------ + # vector ops — VRAM region elementwise + # ------------------------------------------------------------------ + def _emit_v_zero(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """``v_zero`` — dst[region] = 0, lowered to one + ``V_MUL_VF dst, dst, f0, 0`` per MLEN-wide chunk. + + Iterates the legacy ``_vram_region_iter_chunks`` to walk the + region; each chunk's VRAM offset becomes a PrimExpr operand. + """ + if len(op.buffer_args) != 1 or not isinstance( + op.buffer_args[0], _hlir.VramRegion, + ): + raise PreIsaPassV2Error( + f"v_zero expects 1 VramRegion buffer_arg; got " + f"{op.buffer_args!r}" + ) + dst_region: _hlir.VramRegion = op.buffer_args[0] + dst = mod.get_buffer(dst_region.parent) + self._comment( + f"v_zero dst.parent={dst_region.parent} " + f"starts={list(dst_region.starts)!r} " + f"extents={list(dst_region.extents)!r}" + ) + for d_off, _fp_step in self._legacy._vram_region_iter_chunks( + dst, dst_region, + ): + dst_addr = _addr(int(dst.address), d_off) + # V_MUL_VF dst, dst, f0, 0 (the conversion pass will + # CSE the two ``dst_addr`` operands into a single SSA + # value since they're the same Python object.) + self._append(pi.PreIsaOp( + opcode="V_MUL_VF", + operands=[dst_addr, dst_addr, "f0", 0], + )) + + def _emit_v_binary( + self, mod: _hlir.HLIRModule, op: _hlir.Op, *, opcode: str, + ) -> None: + """``v_add`` / ``v_sub`` / ``v_mul`` — elementwise VV. + Per chunk: dst_addr, lhs_addr, rhs_addr, 0 + """ + if len(op.buffer_args) != 3: + raise PreIsaPassV2Error( + f"{op.kind} expects 3 buffer_args (lhs, rhs, dst regions); " + f"got {len(op.buffer_args)}" + ) + lhs_region, rhs_region, dst_region = op.buffer_args + for slot, name in enumerate(("lhs", "rhs", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise PreIsaPassV2Error( + f"{op.kind} {name}: expected VramRegion, got " + f"{type(op.buffer_args[slot]).__name__}" + ) + lhs = mod.get_buffer(lhs_region.parent) + rhs = mod.get_buffer(rhs_region.parent) + dst = mod.get_buffer(dst_region.parent) + self._comment( + f"v binary {op.kind} {opcode} " + f"dst.parent={dst_region.parent} " + f"starts={list(dst_region.starts)!r} " + f"extents={list(dst_region.extents)!r}" + ) + lhs_iter = self._legacy._vram_region_iter_chunks(lhs, lhs_region) + rhs_iter = self._legacy._vram_region_iter_chunks(rhs, rhs_region) + dst_iter = self._legacy._vram_region_iter_chunks(dst, dst_region) + for (l_off, _), (r_off, _), (d_off, _) in zip( + lhs_iter, rhs_iter, dst_iter, + ): + lhs_addr = _addr(int(lhs.address), l_off) + rhs_addr = _addr(int(rhs.address), r_off) + dst_addr = _addr(int(dst.address), d_off) + self._append(pi.PreIsaOp( + opcode=opcode, + operands=[dst_addr, lhs_addr, rhs_addr, 0], + )) + + def _emit_v_unary( + self, mod: _hlir.HLIRModule, op: _hlir.Op, *, opcode: str, + ) -> None: + """``v_exp`` / ``v_reci`` / ``v_sqrt`` — elementwise unary. + Per chunk: dst_addr, src_addr, 0 + """ + if len(op.buffer_args) != 2: + raise PreIsaPassV2Error( + f"{op.kind} expects 2 buffer_args (src, dst regions); " + f"got {len(op.buffer_args)}" + ) + src_region, dst_region = op.buffer_args + for slot, name in enumerate(("src", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise PreIsaPassV2Error( + f"{op.kind} {name}: expected VramRegion" + ) + src = mod.get_buffer(src_region.parent) + dst = mod.get_buffer(dst_region.parent) + self._comment( + f"v unary {op.kind} {opcode} " + f"dst.parent={dst_region.parent} " + f"starts={list(dst_region.starts)!r} " + f"extents={list(dst_region.extents)!r}" + ) + src_iter = self._legacy._vram_region_iter_chunks(src, src_region) + dst_iter = self._legacy._vram_region_iter_chunks(dst, dst_region) + for (s_off, _), (d_off, _) in zip(src_iter, dst_iter): + src_addr = _addr(int(src.address), s_off) + dst_addr = _addr(int(dst.address), d_off) + self._append(pi.PreIsaOp( + opcode=opcode, + operands=[dst_addr, src_addr, 0], + )) + + # ------------------------------------------------------------------ + # VRAM-to-VRAM copy and VRAM ↔ FPRAM slice transfers + # ------------------------------------------------------------------ + def _emit_copy_v_to_v( + self, mod: _hlir.HLIRModule, op: _hlir.Op, + ) -> None: + """``copy_v_to_v`` — dst[region] = src[region]. Each chunk + emits ``V_ADD_VF dst, src, f0, 0`` (f0 == 0 so dst = src + 0).""" + if len(op.buffer_args) != 2: + raise PreIsaPassV2Error( + f"copy_v_to_v expects 2 buffer_args (src, dst); " + f"got {len(op.buffer_args)}" + ) + src_region, dst_region = op.buffer_args + for slot, name in enumerate(("src", "dst")): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise PreIsaPassV2Error( + f"copy_v_to_v {name}: expected VramRegion" + ) + src = mod.get_buffer(src_region.parent) + dst = mod.get_buffer(dst_region.parent) + self._comment( + f"copy_v_to_v src.parent={src_region.parent} -> " + f"dst.parent={dst_region.parent} " + f"extents={list(dst_region.extents)!r}" + ) + src_iter = self._legacy._vram_region_iter_chunks(src, src_region) + dst_iter = self._legacy._vram_region_iter_chunks(dst, dst_region) + for (s_off, _), (d_off, _) in zip(src_iter, dst_iter): + src_addr = _addr(int(src.address), s_off) + dst_addr = _addr(int(dst.address), d_off) + # V_ADD_VF dst, src, f0, 0 — adds the zero FP register + # to ``src`` and writes the result to ``dst`` (= copy). + self._append(pi.PreIsaOp( + opcode="V_ADD_VF", + operands=[dst_addr, src_addr, "f0", 0], + )) + + def _emit_v_fp_transfer_slice( + self, mod: _hlir.HLIRModule, op: _hlir.Op, *, direction: str, + ) -> None: + """``v_fp_transfer_slice_``. Each chunk + emits one ``S_MAP_FP_V`` (vram→fpram) or ``S_MAP_V_FP`` + (fpram→vram). The FPRAM address steps by the cumulative + ``fp_step_elems`` returned by ``_vram_region_iter_chunks``. + """ + if len(op.buffer_args) != 1 or not isinstance( + op.buffer_args[0], _hlir.VramRegion, + ): + raise PreIsaPassV2Error( + f"{op.kind}: buffer_args[0] must be VramRegion" + ) + if len(op.scalar_args) != 1: + raise PreIsaPassV2Error( + f"{op.kind} expects 1 scalar arg (fp address); got " + f"{len(op.scalar_args)}" + ) + region = op.buffer_args[0] + vram = mod.get_buffer(region.parent) + fp_base_expr = self._legacy._resolve_fp_scalar_addr_arg( + mod, op.scalar_args[0], op.kind, "fp", + ) + opcode = "S_MAP_FP_V" if direction == "v_to_fp" else "S_MAP_V_FP" + self._comment( + f"v↔fp transfer slice {op.kind} parent={region.parent} " + f"starts={list(region.starts)!r} " + f"extents={list(region.extents)!r}" + ) + for vram_off, fp_step in self._legacy._vram_region_iter_chunks( + vram, region, + ): + vram_addr = _addr(int(vram.address), vram_off) + if fp_step == 0: + fp_addr = fp_base_expr + else: + fp_addr = tir.Add( + fp_base_expr, tir.IntImm("int32", int(fp_step)), + ) + if direction == "v_to_fp": + # S_MAP_FP_V fp_dst, vram_src, 0 + self._append(pi.PreIsaOp( + opcode="S_MAP_FP_V", + operands=[fp_addr, vram_addr, 0], + )) + else: + # S_MAP_V_FP vram_dst, fp_src, 0 + self._append(pi.PreIsaOp( + opcode="S_MAP_V_FP", + operands=[vram_addr, fp_addr, 0], + )) + + # ------------------------------------------------------------------ + # HLIR for-loop — emits a single LoopRegion holding the body + # sub-ops. THIS IS THE WHOLE HANDLER: no GP pinning, no + # symbol_table push/pop, no idx_addr management. The MIR loop's + # body IS a scope (MirBlock), and the pre_isa_to_mir conversion + # binds the loop_var to a fresh ``_LOOP_VAR_DEF`` MirValue at + # the top of that block — body ops referencing the loop_var via + # PrimExpr operands resolve through the converter's symbol table + # automatically. + # ------------------------------------------------------------------ + def _emit_for(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + loop_var = op.annotations.get("loop_var") + extent = op.annotations.get("extent") + init = op.annotations.get("init", 0) + loop_kind = op.annotations.get("loop_kind", "serial") + if loop_var is None or extent is None: + raise PreIsaPassV2Error( + f"for-op missing loop_var or extent annotation: {op!r}" + ) + if not isinstance(extent, (int, tir.IntImm)): + raise PreIsaPassV2Error( + f"for-op extent must be a compile-time integer; got " + f"{extent!r}" + ) + if not isinstance(init, (int, tir.IntImm)): + raise PreIsaPassV2Error( + f"for-op init must be a compile-time integer; got " + f"{init!r}" + ) + if loop_kind == "unrolled": + loop_kind = "unroll" + if loop_kind not in ("serial", "unroll"): + raise PreIsaPassV2Error( + f"for-op unknown loop_kind {loop_kind!r}" + ) + ext_imm = int(extent.value) if isinstance(extent, tir.IntImm) else int(extent) + init_imm = int(init.value) if isinstance(init, tir.IntImm) else int(init) + + # Build the empty LoopRegion, push its body as the current + # scope, dispatch sub-ops, pop. The append-into-current-scope + # discipline means every sub-op handler's _append lands in + # this loop's body without any further coordination. + loop = pi.LoopRegion( + loop_var=loop_var, + init_imm=init_imm, + extent_imm=ext_imm, + loop_kind=loop_kind, + body=[], + ) + # Forward the ``order_independent`` hint if the kernel + # author marked the source for-op as such. Backend uses + # this to drop the IntRAM idx slot + per-iter LD/ADDI/ST + # by running the hw counter directly as the loop_var, + # iterating N..1 instead of 0..N-1. Safe iff body is + # genuinely independent of iteration order. + if "order_independent" in op.annotations: + loop.annotations["order_independent"] = bool( + op.annotations["order_independent"] + ) + self._append(loop) + self._push_scope(loop.body) + try: + for sub_op in op.body or []: + self._dispatch_op(mod, sub_op) + finally: + self._pop_scope() + + # ------------------------------------------------------------------ + # row_*_at — one logical VRAM row across n_d_tiles d-tiles. + # ------------------------------------------------------------------ + def _emit_row_scalar( + self, + mod: _hlir.HLIRModule, + op: _hlir.Op, + *, + row_op: str, + reduce: bool = False, + masked: bool = False, + has_fp: bool = False, + ) -> None: + """Migrated form of legacy ``_emit_row_scalar_op_at``. + + Three flavours: + * reduce_max / reduce_sum (buffer_args=[src], scalar=[fp_addr]) + ``V_RED_MAX/SUM f1, src, mask`` accumulated across d_tiles, + with ``S_LD_FP f1, fp_dst, 0`` before and ``S_ST_FP f1, + fp_dst, 0`` after. + * exp / reci (buffer_args=[src, dst], scalar=[]) + ``V_EXP_V / V_RECI_V dst, src, mask`` per d_tile. + * add / sub / mul (buffer_args=[src, dst], scalar=[fp_rhs]) + ``S_LD_FP f1, fp_rhs, 0`` once, then ``V_*_VF dst, src, + f1, mask`` per d_tile. + + In PreIsaIR v2 form, d-tile iteration is a ``LoopRegion(unroll)`` + with the body's src/dst PrimExpr operands referencing the + d_tile loop var. MIR conversion expands them into SSA chains; + MIR→ISA unrolls. + + Masking: when the source buffer is packed-head, we emit + ``C_SET_V_MASK_REG `` before the loop and + ``C_SET_V_MASK_REG 0`` after (mask reset). + """ + has_fp = has_fp or reduce + if reduce: + if len(op.buffer_args) != 1: + raise PreIsaPassV2Error( + f"{op.kind} expects 1 buffer_arg (src region)" + ) + expected_scalar = 1 + elif has_fp: + if len(op.buffer_args) != 2: + raise PreIsaPassV2Error( + f"{op.kind} expects 2 buffer_args (src, dst regions)" + ) + expected_scalar = 1 + else: + if len(op.buffer_args) != 2: + raise PreIsaPassV2Error( + f"{op.kind} expects 2 buffer_args (src, dst regions)" + ) + expected_scalar = 0 + if len(op.scalar_args) != expected_scalar: + raise PreIsaPassV2Error( + f"{op.kind} expects {expected_scalar} scalar args" + ) + for slot, name in enumerate( + ("src",) if reduce else ("src", "dst"), + ): + if not isinstance(op.buffer_args[slot], _hlir.VramRegion): + raise PreIsaPassV2Error( + f"{op.kind} {name}: expected VramRegion" + ) + + src_region = op.buffer_args[0] + src = mod.get_buffer(src_region.parent) + + fp_addr_expr = None + if has_fp: + fp_addr_expr = self._legacy._resolve_fp_scalar_addr_arg( + mod, op.scalar_args[0], op.kind, "fp", + ) + + src_base_off, src_mask_expr, src_info = ( + self._legacy._logical_to_phys_row_offset(src, src_region) + ) + emit_v_mask = masked and src_mask_expr is not None + mask_flag = 1 if emit_v_mask else 0 + n_d_tiles = int(src_info["d_tiles"]) + d_tile_stride_s = int(src_info["d_tile_stride"]) + + # dst region only for non-reduce. + dst = None + dst_base_off = None + d_tile_stride_d = 0 + if not reduce: + dst_region = op.buffer_args[1] + dst = mod.get_buffer(dst_region.parent) + dst_base_off, _dm, dst_info = ( + self._legacy._logical_to_phys_row_offset(dst, dst_region) + ) + d_tile_stride_d = int(dst_info["d_tile_stride"]) + + intrinsic = op.annotations.get("intrinsic", op.kind) + self._comment( + f"row scalar task {intrinsic} op={row_op} " + f"src.parent={src_region.parent} " + f"starts={list(src_region.starts)!r}" + ) + + # Mask arm — legacy emits ``C_SET_V_MASK_REG gp{mask_expr}`` + # ONCE before the body loop, and ``C_SET_V_MASK_REG gp0`` to + # reset afterwards. + if emit_v_mask: + self._append(pi.PreIsaOp( + opcode="C_SET_V_MASK_REG", + operands=[src_mask_expr], + )) + + # F1 seed for reduce/binary-fp. + if reduce: + self._append(pi.PreIsaOp( + opcode="S_LD_FP", + operands=["f1", fp_addr_expr, 0], + )) + elif fp_addr_expr is not None: + self._append(pi.PreIsaOp( + opcode="S_LD_FP", + operands=["f1", fp_addr_expr, 0], + )) + + # d_tile sweep — wrap in an unroll LoopRegion when n_d_tiles + # > 1; for n_d_tiles == 1, emit the body straight into the + # current scope. Otherwise an extent=1 loop would emit a + # vestigial ``loop_var * stride`` SSA chain that the current + # arith.simplify doesn't fold (loop_var stays symbolic in + # PreIsaIR → MIR conversion). + if n_d_tiles > 1: + t_var = tir.Var(f"d_tile_{id(op) & 0xffff:x}", "int32") + loop = pi.LoopRegion( + loop_var=t_var, init_imm=0, extent_imm=n_d_tiles, + loop_kind="unroll", body=[], + ) + self._append(loop) + self._push_scope(loop.body) + t_var_term_s = ( + tir.Mul(t_var, tir.IntImm("int32", d_tile_stride_s)) + if d_tile_stride_s != 0 else None + ) + t_var_term_d = ( + tir.Mul(t_var, tir.IntImm("int32", d_tile_stride_d)) + if d_tile_stride_d != 0 else None + ) + else: + t_var_term_s = None + t_var_term_d = None + + try: + # Per d_tile addresses. + src_off_expr = src_base_off + if t_var_term_s is not None: + src_off_expr = tir.Add(src_base_off, t_var_term_s) + src_addr = _addr(int(src.address), src_off_expr) + if dst is not None: + dst_off_expr = dst_base_off + if t_var_term_d is not None: + dst_off_expr = tir.Add(dst_base_off, t_var_term_d) + dst_addr = _addr(int(dst.address), dst_off_expr) + + if reduce: + opcode = { + "reduce_max": "V_RED_MAX", + "reduce_sum": "V_RED_SUM", + }[row_op] + # V_RED_* f1, gp_src, mask_flag — accumulates into f1. + self._append(pi.PreIsaOp( + opcode=opcode, + operands=["f1", src_addr, mask_flag], + )) + elif fp_addr_expr is None: + # exp / reci. + opcode = { + "exp": "V_EXP_V", + "reci": "V_RECI_V", + }[row_op] + self._append(pi.PreIsaOp( + opcode=opcode, + operands=[dst_addr, src_addr, mask_flag], + )) + else: + # add / sub / mul with FP scalar f1. PLENA quirk: + # V_SUB_VF takes 5 operands (extra trailing 0 flag); + # V_ADD_VF / V_MUL_VF take 4. + opcode = { + "add": "V_ADD_VF", + "sub": "V_SUB_VF", + "mul": "V_MUL_VF", + }[row_op] + if row_op == "sub": + self._append(pi.PreIsaOp( + opcode=opcode, + operands=[dst_addr, src_addr, "f1", mask_flag, 0], + )) + else: + self._append(pi.PreIsaOp( + opcode=opcode, + operands=[dst_addr, src_addr, "f1", mask_flag], + )) + finally: + if n_d_tiles > 1: + self._pop_scope() + + # F1 flush (reduce only). + if reduce: + self._append(pi.PreIsaOp( + opcode="S_ST_FP", + operands=["f1", fp_addr_expr, 0], + )) + + # Mask reset. + if emit_v_mask: + zero = tir.IntImm("int32", 0) + self._append(pi.PreIsaOp( + opcode="C_SET_V_MASK_REG", + operands=[zero], + )) + + # ------------------------------------------------------------------ + # btmm / btmv — lane-fused matrix × matrix / matrix × vector. + # Each op emits one ``M_BTMM`` / ``M_BTMV`` and a paired + # write-only ``M_BMM_WO`` / ``M_BMV_WO``. No tile loop. + # ------------------------------------------------------------------ + def _emit_btmm_like( + self, mod: _hlir.HLIRModule, op: _hlir.Op, + *, op_mnemonic: str, wo_mnemonic: str, task_default: str, + ) -> None: + if len(op.buffer_args) != 3: + raise PreIsaPassV2Error( + f"{op.kind} expects 3 buffer_args (a/b/c regions); " + f"got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise PreIsaPassV2Error( + f"{op.kind} a: expected VramRegion" + ) + if not isinstance(b_reg, _hlir.MramRegion): + raise PreIsaPassV2Error( + f"{op.kind} b: expected MramRegion" + ) + if not isinstance(c_reg, _hlir.VramRegion): + raise PreIsaPassV2Error( + f"{op.kind} c: expected VramRegion" + ) + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + task_id = op.annotations.get("intrinsic", task_default) + # rhs / lhs addresses are just the buffer.address ints + # (legacy uses ``rhs.address`` literally — region.starts are + # always zero on this op path). + rhs_addr = tir.IntImm("int32", int(rhs.address)) + lhs_addr = tir.IntImm("int32", int(lhs.address)) + dst_addr = tir.IntImm("int32", int(dst.address)) + # Matching tile_count for the write-back. + tile_count = max(1, dst.num_elements // self.shim.tile_elems) + # Header + main op. + self._comment( + f"{task_default} task {task_id} " + f"lhs_packed=vram[{int(lhs.address)}] " + f"rhs_mram={int(rhs.address)} " + f"lanes={self.shim.btmm_lane_count} " + f"head_width={self.shim.btmm_hlen}" + ) + self._append(pi.PreIsaOp( + opcode=op_mnemonic, + operands=["gp0", rhs_addr, lhs_addr], + )) + # Write-back header + op. + self._comment( + f"{task_default} write-only task {task_id}.wo " + f"out=vram[{int(dst.address)}] " + f"tiles={tile_count} " + f"lanes={self.shim.btmm_lane_count} " + f"head_width={self.shim.btmm_hlen}" + ) + self._append(pi.PreIsaOp( + opcode=wo_mnemonic, + operands=[dst_addr, 0], + )) + + def _emit_btmm(self, mod, op): + self._emit_btmm_like( + mod, op, + op_mnemonic="M_BTMM", wo_mnemonic="M_BMM_WO", + task_default="btmm", + ) + + def _emit_btmv(self, mod, op): + self._emit_btmm_like( + mod, op, + op_mnemonic="M_BTMV", wo_mnemonic="M_BMV_WO", + task_default="btmv", + ) + + # ------------------------------------------------------------------ + # mm — single-tile (mlen*mlen) matmul. Walks (oc, orow) grid; + # each iter emits one M_MM + M_MM_WO pair. K stays at 1 here + # (single-tile). Narrow path (rhs_cols < mlen) is a TODO. + # ------------------------------------------------------------------ + def _emit_mm(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + if len(op.buffer_args) != 3: + raise PreIsaPassV2Error( + f"plena.mm expects 3 buffer_args; got {len(op.buffer_args)}" + ) + lhs = mod.get_buffer(op.buffer_args[0]) + rhs = mod.get_buffer(op.buffer_args[1]) + dst = mod.get_buffer(op.buffer_args[2]) + mlen = int(self.shim.mlen) + blen = int(self.shim.blen) + lhs_rows, lhs_cols = self._legacy._logical_2d(lhs.shape) + rhs_rows, rhs_cols = self._legacy._logical_2d(rhs.shape) + dst_rows, dst_cols = self._legacy._logical_2d(dst.shape) + if lhs_rows != mlen or lhs_cols != mlen: + raise PreIsaPassV2Error( + f"plena.mm lhs must be mlen*mlen; got ({lhs_rows},{lhs_cols})" + ) + if rhs_rows != mlen or dst_rows != mlen: + raise PreIsaPassV2Error( + f"plena.mm rhs/dst must have mlen rows" + ) + if not (rhs_cols == mlen and dst_cols == mlen): + raise PreIsaPassV2Error( + f"plena.mm narrow-tile path not yet migrated to v2" + ) + + tiles_per_mlen = mlen // blen + output_row_stride = blen * mlen + task_id = op.annotations.get("intrinsic", "mm") + self._comment( + f"matmul (single-tile, symbolic unroll) task {task_id} " + f"lhs=vram[{int(lhs.address)}] rhs=mram[{int(rhs.address)}] " + f"dst=vram[{int(dst.address)}]" + ) + + oc_var = tir.Var(f"mm_oc_{id(op) & 0xffff:x}", "int32") + orow_var = tir.Var(f"mm_orow_{id(op) & 0xffff:x}", "int32") + + # Outer oc loop. + oc_loop = pi.LoopRegion( + loop_var=oc_var, init_imm=0, extent_imm=tiles_per_mlen, + loop_kind="unroll", body=[], + ) + self._append(oc_loop) + self._push_scope(oc_loop.body) + try: + orow_loop = pi.LoopRegion( + loop_var=orow_var, init_imm=0, extent_imm=tiles_per_mlen, + loop_kind="unroll", body=[], + ) + self._append(orow_loop) + self._push_scope(orow_loop.body) + try: + # Address PrimExprs. + mat_col = tir.Add( + tir.IntImm("int32", int(rhs.address)), + tir.Mul(oc_var, tir.IntImm("int32", blen)), + ) + act_row = tir.Add( + tir.IntImm("int32", int(lhs.address)), + tir.Mul(orow_var, + tir.IntImm("int32", output_row_stride)), + ) + result_addr = tir.Add( + tir.Add( + tir.IntImm("int32", int(dst.address)), + tir.Mul(oc_var, tir.IntImm("int32", blen)), + ), + tir.Mul(orow_var, + tir.IntImm("int32", output_row_stride)), + ) + # M_MM 0, gp{mat_col}, gp{act_row} + self._append(pi.PreIsaOp( + opcode="M_MM", + operands=[0, mat_col, act_row], + )) + # M_MM_WO gp{result_addr}, gp0, 0 + self._append(pi.PreIsaOp( + opcode="M_MM_WO", + operands=[result_addr, "gp0", 0], + )) + finally: + self._pop_scope() + finally: + self._pop_scope() + + # ------------------------------------------------------------------ + # mm_slot — apply M_MM/M_MM_WO over a col-slot of rhs/dst, with + # optional dynamic LHS row, RHS col, and DST col offsets carried as + # PrimExprs through scalar_args. Mirrors legacy + # ``ISAEmitter.emit_slot_matmul``: nested (oc, t) loops, oc walking + # by ``blen`` columns, t walking by ``blen*mlen`` rows. + # ------------------------------------------------------------------ + def _emit_mm_slot(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + if len(op.buffer_args) != 3: + raise PreIsaPassV2Error( + f"plena.mm_slot expects 3 buffer_args; got {len(op.buffer_args)}" + ) + if len(op.scalar_args) != 4: + raise PreIsaPassV2Error( + f"plena.mm_slot expects 4 scalar args " + f"(lhs_row_off, rhs_col_off, dst_col_off, col_count)" + ) + lhs = mod.get_buffer(op.buffer_args[0]) + rhs = mod.get_buffer(op.buffer_args[1]) + dst = mod.get_buffer(op.buffer_args[2]) + lhs_row_raw, rhs_col_raw, dst_col_raw, col_count_raw = op.scalar_args + try: + col_count = int(col_count_raw) + except TypeError as exc: + raise PreIsaPassV2Error( + f"plena.mm_slot col_count must be compile-time int; " + f"got {col_count_raw!r}" + ) from exc + + mlen = int(self.shim.mlen) + blen = int(self.shim.blen) + if col_count <= 0 or col_count % blen != 0: + raise PreIsaPassV2Error( + f"plena.mm_slot col_count must be positive multiple of " + f"blen={blen}; got {col_count}" + ) + tiles_per_slot = col_count // blen + tiles_per_mlen = mlen // blen + row_stride = blen * mlen + + # Coerce raw scalars to PrimExpr-or-int. Both kinds end up + # combined into a tir.Add chain; arith.simplify in + # pre_isa_to_mir folds the static parts. + def _expr_or_int(x) -> tir.PrimExpr: + if isinstance(x, int): + return tir.IntImm("int32", x) + if isinstance(x, tir.IntImm): + return x + if isinstance(x, tir.PrimExpr): + return x + raise PreIsaPassV2Error( + f"plena.mm_slot scalar arg must be int or PrimExpr; " + f"got {type(x).__name__}: {x!r}" + ) + + lhs_off_e = _expr_or_int(lhs_row_raw) + rhs_off_e = _expr_or_int(rhs_col_raw) + dst_off_e = _expr_or_int(dst_col_raw) + + task_id = op.annotations.get("intrinsic", "mm_slot") + self._comment( + f"slot matmul task {task_id} " + f"lhs=vram[{int(lhs.address)}] rhs=mram[{int(rhs.address)}] " + f"dst=vram[{int(dst.address)}] col_count={col_count}" + ) + + oc_var = tir.Var(f"slot_oc_{id(op) & 0xffff:x}", "int32") + t_var = tir.Var(f"slot_t_{id(op) & 0xffff:x}", "int32") + + oc_loop = pi.LoopRegion( + loop_var=oc_var, init_imm=0, extent_imm=tiles_per_slot, + loop_kind="unroll", body=[], + ) + self._append(oc_loop) + self._push_scope(oc_loop.body) + try: + t_loop = pi.LoopRegion( + loop_var=t_var, init_imm=0, extent_imm=tiles_per_mlen, + loop_kind="unroll", body=[], + ) + self._append(t_loop) + self._push_scope(t_loop.body) + try: + # act = lhs.base + lhs_row_off + t * row_stride + act_addr = tir.Add( + tir.Add(tir.IntImm("int32", int(lhs.address)), + lhs_off_e), + tir.Mul(t_var, tir.IntImm("int32", row_stride)), + ) + # mat = rhs.base + rhs_col_off + oc * blen + mat_addr = tir.Add( + tir.Add(tir.IntImm("int32", int(rhs.address)), + rhs_off_e), + tir.Mul(oc_var, tir.IntImm("int32", blen)), + ) + # out = dst.base + dst_col_off + oc * blen + t * row_stride + out_addr = tir.Add( + tir.Add( + tir.Add(tir.IntImm("int32", int(dst.address)), + dst_off_e), + tir.Mul(oc_var, tir.IntImm("int32", blen)), + ), + tir.Mul(t_var, tir.IntImm("int32", row_stride)), + ) + self._append(pi.PreIsaOp( + opcode="M_MM", + operands=[0, mat_addr, act_addr], + )) + self._append(pi.PreIsaOp( + opcode="M_MM_WO", + operands=[out_addr, "gp0", 0], + )) + finally: + self._pop_scope() + finally: + self._pop_scope() + + # ------------------------------------------------------------------ + # matmul — unified (M, K) @ (K, N) -> (M, N). + # Region-aware: each operand is a Vram/MramRegion with dim_roles + # ("M"/"K"/"N"/"_") scalar arg. transpose_b inferred from B-region + # axis order. K is folded into the systolic-array accumulator: + # K_tiles M_MM/M_TMM issuances feed one M_MM_WO. 5-level unroll + # over (m, n_mlen, oc, orow, k). + # ------------------------------------------------------------------ + def _emit_matmul(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + if len(op.buffer_args) != 3: + raise PreIsaPassV2Error( + f"plena.matmul expects 3 buffer_args; got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise PreIsaPassV2Error( + f"plena.matmul a: expected VramRegion, got {type(a_reg).__name__}" + ) + if not isinstance(b_reg, _hlir.MramRegion): + raise PreIsaPassV2Error( + f"plena.matmul b: expected MramRegion, got {type(b_reg).__name__}" + ) + if not isinstance(c_reg, _hlir.VramRegion): + raise PreIsaPassV2Error( + f"plena.matmul c: expected VramRegion, got {type(c_reg).__name__}" + ) + if len(op.scalar_args) != 3: + raise PreIsaPassV2Error( + f"plena.matmul expects 3 scalar_args (a/b/c dim_roles)" + ) + a_roles, b_roles, c_roles = op.scalar_args + if len(a_roles) != 4 or len(b_roles) != 4 or len(c_roles) != 4: + raise PreIsaPassV2Error( + f"plena.matmul dim_roles must be 4-tuples" + ) + + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + + def _find_role(roles, role, operand): + hits = [i for i, r in enumerate(roles) if r == role] + if not hits: + raise PreIsaPassV2Error( + f"plena.matmul {operand}: missing role {role!r}" + ) + if len(hits) > 1: + raise PreIsaPassV2Error( + f"plena.matmul {operand}: role {role!r} at multiple axes" + ) + return hits[0] + + c_M_axis = _find_role(c_roles, "M", "c") + a_M_axis = _find_role(a_roles, "M", "a") + a_K_axis = _find_role(a_roles, "K", "a") + b_K_axis = _find_role(b_roles, "K", "b") + b_N_axis = _find_role(b_roles, "N", "b") + + mlen = int(self.shim.mlen) + blen = int(self.shim.blen) + hlen = int(self.shim.btmm_hlen) + + M = int(a_reg.extents[a_M_axis]) + K = int(a_reg.extents[a_K_axis]) + N = int(b_reg.extents[b_N_axis]) + if M % mlen != 0 or K % mlen != 0: + raise PreIsaPassV2Error( + f"plena.matmul: M ({M}) and K ({K}) must be multiples of MLEN ({mlen})" + ) + if N % hlen != 0: + raise PreIsaPassV2Error( + f"plena.matmul: N ({N}) must be multiple of hlen ({hlen})" + ) + M_tiles = M // mlen + K_tiles = K // mlen + N_mlen_tiles = (N + mlen - 1) // mlen + transpose_b = b_N_axis < b_K_axis + + # dst_row_stride — packed-head (cluster_dim==2, lane_count>1, M + # on S axis) uses physical s_inner_stride; otherwise extents + # product after M. + dst_cluster_dim = getattr(dst, "cluster_dim", None) + tl_info = self._legacy._tile_layout_strides(dst) + packed_head_dst = ( + tl_info is not None + and dst_cluster_dim == 2 + and int(tl_info["lane_count"]) > 1 + and c_M_axis == 1 + ) + if packed_head_dst: + dst_row_stride = int(tl_info["s_inner_stride"]) + else: + dst_row_stride = 1 + for ax in range(c_M_axis + 1, len(c_reg.extents)): + dst_row_stride *= int(c_reg.extents[ax]) + if dst_row_stride <= 0: + dst_row_stride = 1 + + # Per-side raw origin offsets — int or PrimExpr. arith.simplify + # in pre_isa_to_mir folds the static parts. + lhs_raw_off = self._legacy._region_origin_offset(lhs, a_reg) + rhs_raw_off = self._legacy._region_origin_offset(rhs, b_reg) + dst_raw_off = self._legacy._region_origin_offset(dst, c_reg) + + def _as_expr(x) -> tir.PrimExpr: + if isinstance(x, int): + return tir.IntImm("int32", x) + return x + + lhs_off_e = _as_expr(lhs_raw_off) + rhs_off_e = _as_expr(rhs_raw_off) + dst_off_e = _as_expr(dst_raw_off) + + # Strides matching emit_matmul_general defaults. + lhs_k_tile_stride = mlen * mlen + lhs_m_tile_stride = K_tiles * mlen * mlen + if transpose_b: + rhs_n_mlen_tile_stride = K_tiles * mlen * mlen + rhs_k_tile_stride = mlen * mlen + oc_b_step = blen * mlen + mm_opcode = "M_TMM" + else: + rhs_n_mlen_tile_stride = mlen * mlen + rhs_k_tile_stride = N_mlen_tiles * mlen * mlen + oc_b_step = blen + mm_opcode = "M_MM" + dst_m_tile_stride = mlen * int(dst_row_stride) + + tiles_per_mlen = mlen // blen + a_orow_step = blen * mlen + c_orow_step = blen * mlen + + task_id = op.annotations.get("intrinsic", "matmul") + self._comment( + f"matmul (general) task {task_id} M={M_tiles*mlen} K={K_tiles*mlen} N={N} " + f"(M_tiles={M_tiles} K_tiles={K_tiles} N_mlen_tiles={N_mlen_tiles}" + f"{', transpose_b' if transpose_b else ''})" + ) + + # Loop vars for the 5-level nest. + op_id = f"{id(op) & 0xffff:x}" + m_var = tir.Var(f"mm_m_{op_id}", "int32") + nm_var = tir.Var(f"mm_nm_{op_id}", "int32") + oc_var = tir.Var(f"mm_oc_{op_id}", "int32") + orow_var = tir.Var(f"mm_orow_{op_id}", "int32") + k_var = tir.Var(f"mm_k_{op_id}", "int32") + + # Build nested LoopRegions. We enter scopes from outer to inner; + # innermost emits two PreIsaOps: a K-loop with M_MM/M_TMM in its + # body, then one M_MM_WO after the K loop in the orow scope. + m_loop = pi.LoopRegion( + loop_var=m_var, init_imm=0, extent_imm=M_tiles, + loop_kind="unroll", body=[], + ) + self._append(m_loop) + self._push_scope(m_loop.body) + try: + for n_mlen_static in range(N_mlen_tiles): + # N_mlen_tiles may have a partial trailing block (cols < mlen); + # tiles_per_n_mlen varies per iter, so we materialise the + # n_mlen loop as a Python-side static unroll over n_mlen. + # The remaining 3 inner levels stay as LoopRegions. + cols_here = min(mlen, N - n_mlen_static * mlen) + tiles_per_n_mlen = cols_here // blen + if tiles_per_n_mlen <= 0: + continue + oc_loop = pi.LoopRegion( + loop_var=oc_var, init_imm=0, + extent_imm=tiles_per_n_mlen, + loop_kind="unroll", body=[], + ) + self._append(oc_loop) + self._push_scope(oc_loop.body) + try: + orow_loop = pi.LoopRegion( + loop_var=orow_var, init_imm=0, + extent_imm=tiles_per_mlen, + loop_kind="unroll", body=[], + ) + self._append(orow_loop) + self._push_scope(orow_loop.body) + try: + # Inner K loop — emits one M_MM/M_TMM per iter, + # accumulating into the systolic-array + # accumulator. M_MM_WO sits AFTER the K loop, + # at the orow-scope level. + k_loop = pi.LoopRegion( + loop_var=k_var, init_imm=0, extent_imm=K_tiles, + loop_kind="unroll", body=[], + ) + self._append(k_loop) + self._push_scope(k_loop.body) + try: + # act = lhs.base + lhs_off + m*lhs_m_tile_stride + # + orow*a_orow_step + k*lhs_k_tile_stride + act_addr = tir.Add( + tir.Add( + tir.Add( + tir.Add( + tir.IntImm("int32", int(lhs.address)), + lhs_off_e, + ), + tir.Mul(m_var, tir.IntImm("int32", lhs_m_tile_stride)), + ), + tir.Mul(orow_var, tir.IntImm("int32", a_orow_step)), + ), + tir.Mul(k_var, tir.IntImm("int32", lhs_k_tile_stride)), + ) + # mat = rhs.base + rhs_off + n_mlen_static*rhs_n_mlen_tile_stride + # + oc*oc_b_step + k*rhs_k_tile_stride + mat_addr = tir.Add( + tir.Add( + tir.Add( + tir.Add( + tir.IntImm("int32", int(rhs.address)), + rhs_off_e, + ), + tir.IntImm("int32", + n_mlen_static * rhs_n_mlen_tile_stride), + ), + tir.Mul(oc_var, tir.IntImm("int32", oc_b_step)), + ), + tir.Mul(k_var, tir.IntImm("int32", rhs_k_tile_stride)), + ) + if transpose_b: + # M_TMM 0, act, mat + self._append(pi.PreIsaOp( + opcode="M_TMM", + operands=[0, act_addr, mat_addr], + )) + else: + # M_MM 0, mat, act + self._append(pi.PreIsaOp( + opcode="M_MM", + operands=[0, mat_addr, act_addr], + )) + finally: + self._pop_scope() + # M_MM_WO: at orow scope. dst_col within the + # output tile = n_mlen*mlen + oc*blen. + dst_col_static = n_mlen_static * mlen + out_addr = tir.Add( + tir.Add( + tir.Add( + tir.Add( + tir.IntImm("int32", int(dst.address)), + dst_off_e, + ), + tir.Mul(m_var, tir.IntImm("int32", dst_m_tile_stride)), + ), + tir.Mul(orow_var, tir.IntImm("int32", c_orow_step)), + ), + tir.Add( + tir.IntImm("int32", dst_col_static), + tir.Mul(oc_var, tir.IntImm("int32", blen)), + ), + ) + self._append(pi.PreIsaOp( + opcode="M_MM_WO", + operands=[out_addr, "gp0", 0], + )) + finally: + self._pop_scope() + finally: + self._pop_scope() + finally: + self._pop_scope() + + # ------------------------------------------------------------------ + # DMA — H↔V / H↔M / V↔H tile-wise transfers. Each HBM buffer carries + # a (row_blocks × col_blocks) tile grid annotation; we walk it via + # legacy's ``_iter_tile_offsets`` and emit one preload/store body + # per tile. Per-tile body is the canonical + # ``C_SET_ADDR_REG`` + ``C_SET_SCALE_REG`` + ``C_SET_STRIDE_REG`` + + # (vlen/preload-stripe of) ``H_PREFETCH_V`` / ``H_PREFETCH_M`` / + # ``H_STORE_V`` sequence. + # ------------------------------------------------------------------ + def _emit_dma_h2v(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + src = mod.get_buffer(op.buffer_args[0]) + dst = mod.get_buffer(op.buffer_args[1]) + if src.scope != _scope.HBM: + raise PreIsaPassV2Error( + f"dma_h2v src must be HBM; got {src.scope}" + ) + if dst.scope != _scope.VRAM: + raise PreIsaPassV2Error( + f"dma_h2v dst must be VRAM; got {dst.scope}" + ) + for vram_off, hbm_off in self._legacy._iter_tile_offsets(src): + self._comment( + f"dma_h2v tile {src.name}[hbm+{hbm_off}] -> " + f"{dst.name}[vram+{vram_off}]" + ) + self._emit_h2v_tile_body( + hbm_addr=int(src.address), + vram_addr=int(dst.address) + int(vram_off), + hbm_stride=src.hbm_stride, + hbm_scale_size=src.hbm_scale_size, + hbm_start_offset=int(src.hbm_offset) + int(hbm_off), + ) + + def _emit_dma_h2m(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + src = mod.get_buffer(op.buffer_args[0]) + dst = mod.get_buffer(op.buffer_args[1]) + if src.scope != _scope.HBM: + raise PreIsaPassV2Error( + f"dma_h2m src must be HBM; got {src.scope}" + ) + if dst.scope != _scope.MRAM: + raise PreIsaPassV2Error( + f"dma_h2m dst must be MRAM; got {dst.scope}" + ) + for mram_off, hbm_off in self._legacy._iter_tile_offsets(src): + self._comment( + f"dma_h2m tile {src.name}[hbm+{hbm_off}] -> " + f"{dst.name}[mram+{mram_off}]" + ) + self._emit_h2m_tile_body( + hbm_addr=int(src.address), + mram_addr=int(dst.address) + int(mram_off), + hbm_offset=int(src.hbm_offset) + int(hbm_off), + hbm_scale=src.hbm_scale_size, + hbm_stride=src.hbm_stride, + ) + + def _emit_dma_v2h(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + src = mod.get_buffer(op.buffer_args[0]) + dst = mod.get_buffer(op.buffer_args[1]) + if src.scope != _scope.VRAM: + raise PreIsaPassV2Error( + f"dma_v2h src must be VRAM; got {src.scope}" + ) + if dst.scope != _scope.HBM: + raise PreIsaPassV2Error( + f"dma_v2h dst must be HBM; got {dst.scope}" + ) + if src.num_elements != dst.num_elements: + raise PreIsaPassV2Error( + f"dma_v2h: src ({src.name}, {src.num_elements} elems) " + f"and dst ({dst.name}, {dst.num_elements} elems) must " + f"have the same total size" + ) + # Walk the HBM (dst) grid; vram_off = idx * tile_elems lines + # up with BMM_WO's BHSD VRAM packing — same convention as + # legacy ``_emit_dma_v2h``. + for vram_off, hbm_off in self._legacy._iter_tile_offsets(dst): + self._comment( + f"dma_v2h tile {src.name}[vram+{vram_off}] -> " + f"{dst.name}[hbm+{hbm_off}]" + ) + self._emit_v2h_tile_body( + vram_addr=int(src.address) + int(vram_off), + hbm_addr=int(dst.address), + hbm_stride=dst.hbm_stride, + hbm_scale_size=dst.hbm_scale_size, + hbm_start_offset=int(dst.hbm_offset) + int(hbm_off), + ) + + # ------------------------------------------------------------------ + # DMA slice variants — same as dma_h2v / _h2m / _v2h but with a + # BufferSlice describing a sub-region of the HBM buffer. Walks a + # 4-level (d_tile, s_tile, h_grp, b) tile grid as nested + # LoopRegions; per-tile body is the same preload/store helper as + # the whole-buffer DMA path, with hbm_off / vram_off addresses + # expressed as PrimExprs referencing the four loop_vars. Slice + # starts that are themselves PrimExprs (e.g. derived from an + # outer kernel-loop var) flow through arith.simplify and only + # crystallise into S_ADDI_INT/S_MUL_INT chains at MIR lowering. + # ------------------------------------------------------------------ + def _slice_offset_expr( + self, parent: _hlir.Buffer, sl: _hlir.BufferSlice, + ) -> tir.PrimExpr: + """Build a unified PrimExpr for ``slice.starts``'s element + offset in ``parent``. Mixes static and dynamic starts; + arith.simplify folds the static sub-trees.""" + offset = tir.IntImm("int32", 0) + shape = parent.shape + for i, s in enumerate(sl.starts): + stride_below = 1 + for d in shape[i + 1:]: + stride_below *= int(d) + if isinstance(s, int): + term = tir.IntImm("int32", s * stride_below) + elif isinstance(s, tir.IntImm): + term = tir.IntImm("int32", int(s.value) * stride_below) + else: + term = tir.Mul(s, tir.IntImm("int32", stride_below)) + offset = tir.Add(offset, term) + if int(parent.hbm_offset): + offset = tir.Add( + offset, tir.IntImm("int32", int(parent.hbm_offset)) + ) + return offset + + def _emit_dma_h2v_slice(self, mod, op) -> None: + sl = op.buffer_args[0] + arg1 = op.buffer_args[1] + if not isinstance(sl, _hlir.BufferSlice): + raise PreIsaPassV2Error( + f"dma_h2v_slice: buffer_args[0] must be BufferSlice" + ) + if isinstance(arg1, _hlir.BufferSlice): + raise PreIsaPassV2Error( + f"dma_h2v_slice: dst must be a whole-buffer name" + ) + dst = mod.get_buffer(arg1) + parent = mod.get_buffer(sl.parent) + if parent.scope != _scope.HBM: + raise PreIsaPassV2Error( + f"dma_h2v_slice: src.parent must be HBM" + ) + if dst.scope != _scope.VRAM: + raise PreIsaPassV2Error( + f"dma_h2v_slice: dst must be VRAM" + ) + (d_tiles, s_tiles, h_groups, logical_b, + inner_mlen, lane_count, + (hbm_stride_b, hbm_stride_s, hbm_stride_h), + (d_tile_stride, s_tile_stride, h_grp_stride, b_stride)) = ( + self._legacy._slice_tile_grid(parent, sl, dst) + ) + base_off_expr = self._slice_offset_expr(parent, sl) + + self._comment( + f"dma_h2v_slice {parent.name} -> {dst.name} " + f"(grid d={d_tiles} s={s_tiles} h={h_groups} b={logical_b})" + ) + self._emit_slice_grid_h2v( + parent=parent, dst=dst, base_off_expr=base_off_expr, + d_tiles=d_tiles, s_tiles=s_tiles, h_groups=h_groups, + logical_b=logical_b, inner_mlen=inner_mlen, + lane_count=lane_count, + hbm_stride_b=hbm_stride_b, hbm_stride_s=hbm_stride_s, + hbm_stride_h=hbm_stride_h, + d_tile_stride=d_tile_stride, s_tile_stride=s_tile_stride, + h_grp_stride=h_grp_stride, b_stride=b_stride, + ) + + def _emit_dma_h2m_slice(self, mod, op) -> None: + sl = op.buffer_args[0] + if not isinstance(sl, _hlir.BufferSlice): + raise PreIsaPassV2Error( + f"dma_h2m_slice: buffer_args[0] must be BufferSlice" + ) + dst = mod.get_buffer(op.buffer_args[1]) + parent = mod.get_buffer(sl.parent) + if parent.scope != _scope.HBM: + raise PreIsaPassV2Error( + f"dma_h2m_slice: src.parent must be HBM" + ) + if dst.scope != _scope.MRAM: + raise PreIsaPassV2Error( + f"dma_h2m_slice: dst must be MRAM" + ) + # h2m_slice is single-tile (legacy enforces via + # ``_check_slice_single_tile``); we just emit one + # H_PREFETCH_M with the slice offset folded in. + self._legacy._check_slice_single_tile(parent, sl) + base_off_expr = self._slice_offset_expr(parent, sl) + + self._comment( + f"dma_h2m_slice {parent.name} -> {dst.name}" + ) + self._emit_h2m_tile_body( + hbm_addr=int(parent.address), + mram_addr=int(dst.address), + hbm_offset=base_off_expr, + hbm_scale=parent.hbm_scale_size, + hbm_stride=parent.hbm_stride, + ) + + def _emit_dma_v2h_slice(self, mod, op) -> None: + src = mod.get_buffer(op.buffer_args[0]) + sl = op.buffer_args[1] + if not isinstance(sl, _hlir.BufferSlice): + raise PreIsaPassV2Error( + f"dma_v2h_slice: buffer_args[1] must be BufferSlice" + ) + parent = mod.get_buffer(sl.parent) + if src.scope != _scope.VRAM: + raise PreIsaPassV2Error( + f"dma_v2h_slice: src must be VRAM" + ) + if parent.scope != _scope.HBM: + raise PreIsaPassV2Error( + f"dma_v2h_slice: dst.parent must be HBM" + ) + (d_tiles, s_tiles, h_groups, logical_b, + inner_mlen, lane_count, + (hbm_stride_b, hbm_stride_s, hbm_stride_h), + (d_tile_stride, s_tile_stride, h_grp_stride, b_stride)) = ( + self._legacy._slice_tile_grid(parent, sl, src) + ) + base_off_expr = self._slice_offset_expr(parent, sl) + + self._comment( + f"dma_v2h_slice {src.name} -> {parent.name} " + f"(grid d={d_tiles} s={s_tiles} h={h_groups} b={logical_b})" + ) + self._emit_slice_grid_v2h( + src=src, parent=parent, base_off_expr=base_off_expr, + d_tiles=d_tiles, s_tiles=s_tiles, h_groups=h_groups, + logical_b=logical_b, inner_mlen=inner_mlen, + lane_count=lane_count, + hbm_stride_b=hbm_stride_b, hbm_stride_s=hbm_stride_s, + hbm_stride_h=hbm_stride_h, + d_tile_stride=d_tile_stride, s_tile_stride=s_tile_stride, + h_grp_stride=h_grp_stride, b_stride=b_stride, + ) + + # ----- Slice tile-grid walkers (shared between h2v / v2h slice) ----- + + def _slice_per_tile_addresses( + self, *, base_off_expr, inner_mlen, lane_count, + hbm_stride_b, hbm_stride_s, hbm_stride_h, + d_tile_stride, s_tile_stride, h_grp_stride, b_stride, + d_var, s_var, h_var, b_var, + ): + """Per-tile (hbm_off, vram_off) PrimExprs given the four + loop_vars and the layout strides. Mirrors legacy's per-tile + offset math in ``_emit_dma_h2v_slice`` / + ``_emit_dma_v2h_slice``.""" + hbm_off = tir.Add( + tir.Add( + tir.Add( + tir.Add( + base_off_expr, + tir.Mul(b_var, tir.IntImm("int32", hbm_stride_b)), + ), + tir.Mul(s_var, + tir.IntImm("int32", inner_mlen * hbm_stride_s)), + ), + tir.Mul(h_var, + tir.IntImm("int32", lane_count * hbm_stride_h)), + ), + tir.Mul(d_var, tir.IntImm("int32", inner_mlen)), + ) + vram_off = tir.Add( + tir.Add( + tir.Add( + tir.Mul(d_var, tir.IntImm("int32", d_tile_stride)), + tir.Mul(s_var, tir.IntImm("int32", s_tile_stride)), + ), + tir.Mul(h_var, tir.IntImm("int32", h_grp_stride)), + ), + tir.Mul(b_var, tir.IntImm("int32", b_stride)), + ) + return hbm_off, vram_off + + def _emit_slice_grid_h2v( + self, *, parent, dst, base_off_expr, + d_tiles, s_tiles, h_groups, logical_b, + inner_mlen, lane_count, + hbm_stride_b, hbm_stride_s, hbm_stride_h, + d_tile_stride, s_tile_stride, h_grp_stride, b_stride, + ) -> None: + op_id = f"{id(parent) & 0xffff:x}_{id(dst) & 0xffff:x}" + d_var = tir.Var(f"sl_d_{op_id}", "int32") + s_var = tir.Var(f"sl_s_{op_id}", "int32") + h_var = tir.Var(f"sl_h_{op_id}", "int32") + b_var = tir.Var(f"sl_b_{op_id}", "int32") + + d_loop = pi.LoopRegion( + loop_var=d_var, init_imm=0, extent_imm=d_tiles, + loop_kind="unroll", body=[], + ) + self._append(d_loop) + self._push_scope(d_loop.body) + try: + s_loop = pi.LoopRegion( + loop_var=s_var, init_imm=0, extent_imm=s_tiles, + loop_kind="unroll", body=[], + ) + self._append(s_loop) + self._push_scope(s_loop.body) + try: + h_loop = pi.LoopRegion( + loop_var=h_var, init_imm=0, extent_imm=h_groups, + loop_kind="unroll", body=[], + ) + self._append(h_loop) + self._push_scope(h_loop.body) + try: + b_loop = pi.LoopRegion( + loop_var=b_var, init_imm=0, extent_imm=logical_b, + loop_kind="unroll", body=[], + ) + self._append(b_loop) + self._push_scope(b_loop.body) + try: + hbm_off, vram_off = self._slice_per_tile_addresses( + base_off_expr=base_off_expr, + inner_mlen=inner_mlen, lane_count=lane_count, + hbm_stride_b=hbm_stride_b, + hbm_stride_s=hbm_stride_s, + hbm_stride_h=hbm_stride_h, + d_tile_stride=d_tile_stride, + s_tile_stride=s_tile_stride, + h_grp_stride=h_grp_stride, + b_stride=b_stride, + d_var=d_var, s_var=s_var, + h_var=h_var, b_var=b_var, + ) + vram_addr_expr = tir.Add( + tir.IntImm("int32", int(dst.address)), + vram_off, + ) + self._emit_h2v_tile_body( + hbm_addr=int(parent.address), + vram_addr=vram_addr_expr, + hbm_stride=parent.hbm_stride, + hbm_scale_size=parent.hbm_scale_size, + hbm_start_offset=hbm_off, + ) + finally: + self._pop_scope() + finally: + self._pop_scope() + finally: + self._pop_scope() + finally: + self._pop_scope() + + def _emit_slice_grid_v2h( + self, *, src, parent, base_off_expr, + d_tiles, s_tiles, h_groups, logical_b, + inner_mlen, lane_count, + hbm_stride_b, hbm_stride_s, hbm_stride_h, + d_tile_stride, s_tile_stride, h_grp_stride, b_stride, + ) -> None: + op_id = f"{id(parent) & 0xffff:x}_{id(src) & 0xffff:x}" + d_var = tir.Var(f"sl_d_{op_id}", "int32") + s_var = tir.Var(f"sl_s_{op_id}", "int32") + h_var = tir.Var(f"sl_h_{op_id}", "int32") + b_var = tir.Var(f"sl_b_{op_id}", "int32") + + d_loop = pi.LoopRegion( + loop_var=d_var, init_imm=0, extent_imm=d_tiles, + loop_kind="unroll", body=[], + ) + self._append(d_loop) + self._push_scope(d_loop.body) + try: + s_loop = pi.LoopRegion( + loop_var=s_var, init_imm=0, extent_imm=s_tiles, + loop_kind="unroll", body=[], + ) + self._append(s_loop) + self._push_scope(s_loop.body) + try: + h_loop = pi.LoopRegion( + loop_var=h_var, init_imm=0, extent_imm=h_groups, + loop_kind="unroll", body=[], + ) + self._append(h_loop) + self._push_scope(h_loop.body) + try: + b_loop = pi.LoopRegion( + loop_var=b_var, init_imm=0, extent_imm=logical_b, + loop_kind="unroll", body=[], + ) + self._append(b_loop) + self._push_scope(b_loop.body) + try: + hbm_off, vram_off = self._slice_per_tile_addresses( + base_off_expr=base_off_expr, + inner_mlen=inner_mlen, lane_count=lane_count, + hbm_stride_b=hbm_stride_b, + hbm_stride_s=hbm_stride_s, + hbm_stride_h=hbm_stride_h, + d_tile_stride=d_tile_stride, + s_tile_stride=s_tile_stride, + h_grp_stride=h_grp_stride, + b_stride=b_stride, + d_var=d_var, s_var=s_var, + h_var=h_var, b_var=b_var, + ) + vram_addr_expr = tir.Add( + tir.IntImm("int32", int(src.address)), + vram_off, + ) + self._emit_v2h_tile_body( + vram_addr=vram_addr_expr, + hbm_addr=int(parent.address), + hbm_stride=parent.hbm_stride, + hbm_scale_size=parent.hbm_scale_size, + hbm_start_offset=hbm_off, + ) + finally: + self._pop_scope() + finally: + self._pop_scope() + finally: + self._pop_scope() + finally: + self._pop_scope() + + # ----- DMA tile bodies (one HBM tile worth of preload/store) ----- + # + # These mirror ``ISAEmitter._emit_preload_tile_isa`` / + # ``_emit_store_tile_isa`` for the ``batch = mlen`` (multi-row) + # case, which is the only one v2 supports today. The ``batch = 1`` + # narrow-row path exists in legacy but is unreachable from our + # shim's DMA dispatcher (it always sets ``batch = mlen``). + + def _emit_h2v_tile_body( + self, *, hbm_addr: int, vram_addr, + hbm_stride, hbm_scale_size, hbm_start_offset, + ) -> None: + """One HBM-tile's worth of preload, as PreIsaIR. + + ``vram_addr`` and ``hbm_start_offset`` may be int OR + ``tir.PrimExpr`` (slice-aware DMA passes in PrimExprs that + reference outer LoopRegion loop_vars). + + Structure: ``C_SET_ADDR_REG`` + scale/stride + a nested + (outer, inner) LoopRegion whose body issues one + ``H_PREFETCH_V`` with PrimExpr addresses referencing both + loop vars. + """ + mlen = int(self.shim.mlen) + v_prefetch = int(self.shim.v_prefetch_amount) + tile_elems = mlen * mlen + stride_len = mlen if hbm_stride is None else int(hbm_stride) + scale_len = tile_elems if hbm_scale_size is None else int(hbm_scale_size) + vram_addr_e = _as_expr(vram_addr) + hbm_start_e = _as_expr(hbm_start_offset) + + # batch = hidden = mlen for our shim's DMA path, so + # load_amount_per_hidden = ceil(mlen / mlen) = 1 and + # inner_count = ceil(mlen / v_prefetch). When batch <= + # preload, the inner stride term collapses to 0 (no per-inner + # advance) — matches legacy's ``if batch > preload_len`` + # branch. + load_amount_per_hidden = (mlen + mlen - 1) // mlen + if mlen > v_prefetch: + inner_count = (mlen + v_prefetch - 1) // v_prefetch + inner_stride_per_iter = stride_len * v_prefetch + else: + inner_count = 1 + inner_stride_per_iter = 0 + + # 1) Bind a fresh addr_reg to ``hbm_addr``. + addr = pi.PreIsaOp( + opcode="C_SET_ADDR_REG", + # high word = constant-zero gp0; low word = address. + operands=["gp0", int(hbm_addr)], + ) + self._append(addr) + # 2) Scale + stride for this DMA. + self._append(pi.PreIsaOp( + opcode="C_SET_SCALE_REG", operands=[int(scale_len)], + )) + self._append(pi.PreIsaOp( + opcode="C_SET_STRIDE_REG", operands=[int(stride_len)], + )) + + # 3) Nested LoopRegions over (outer, inner). Addresses become + # PrimExprs in loop_var; arith.simplify in pre_isa_to_mir + # folds the static constants. + op_id = f"{id(addr) & 0xffff:x}" + outer_var = tir.Var(f"h2v_outer_{op_id}", "int32") + inner_var = tir.Var(f"h2v_inner_{op_id}", "int32") + + outer_loop = pi.LoopRegion( + loop_var=outer_var, init_imm=0, + extent_imm=load_amount_per_hidden, + loop_kind="unroll", body=[], + ) + self._append(outer_loop) + self._push_scope(outer_loop.body) + try: + inner_loop = pi.LoopRegion( + loop_var=inner_var, init_imm=0, extent_imm=inner_count, + loop_kind="unroll", body=[], + ) + self._append(inner_loop) + self._push_scope(inner_loop.body) + try: + # result_addr = vram_addr + # + (outer * inner_count + inner) + # * (mlen * v_prefetch) + row_idx = tir.Add( + tir.Mul(outer_var, + tir.IntImm("int32", inner_count)), + inner_var, + ) + result_addr = tir.Add( + vram_addr_e, + tir.Mul(row_idx, + tir.IntImm("int32", mlen * v_prefetch)), + ) + # hbm_off = hbm_start_offset + # + outer * mlen + # + inner * inner_stride_per_iter + hbm_off = tir.Add( + tir.Add( + hbm_start_e, + tir.Mul(outer_var, + tir.IntImm("int32", mlen)), + ), + tir.Mul(inner_var, + tir.IntImm("int32", + int(inner_stride_per_iter))), + ) + self._append(pi.PreIsaOp( + opcode="H_PREFETCH_V", + operands=[result_addr, hbm_off, addr, 1, 0], + )) + finally: + self._pop_scope() + finally: + self._pop_scope() + + def _emit_h2m_tile_body( + self, *, hbm_addr: int, mram_addr, + hbm_offset, hbm_scale, hbm_stride, + ) -> None: + """One HBM tile → MRAM. ``mram_addr`` and ``hbm_offset`` may + be int OR PrimExpr (slice path with dynamic starts).""" + mlen = int(self.shim.mlen) + tile_elems = mlen * mlen + scale_val = tile_elems if hbm_scale is None else int(hbm_scale) + stride_val = mlen if hbm_stride is None else int(hbm_stride) + + addr = pi.PreIsaOp( + opcode="C_SET_ADDR_REG", + # high word = constant-zero gp0; low word = address. + operands=["gp0", int(hbm_addr)], + + ) + self._append(addr) + self._append(pi.PreIsaOp( + opcode="C_SET_SCALE_REG", operands=[int(scale_val)], + )) + self._append(pi.PreIsaOp( + opcode="C_SET_STRIDE_REG", operands=[int(stride_val)], + )) + # One H_PREFETCH_M per tile (mram_addr, hbm_offset both as + # PrimExpr — converter folds static cases via arith.simplify). + self._append(pi.PreIsaOp( + opcode="H_PREFETCH_M", + operands=[ + _as_expr(mram_addr), + _as_expr(hbm_offset), + addr, + 1, 0, + ], + )) + # Reset scale/stride to canonical (tile_elems / mlen) so subsequent + # non-DMA ops see the defaults the rest of the kernel expects — + # matches legacy reset after H_PREFETCH_M. + self._append(pi.PreIsaOp( + opcode="C_SET_SCALE_REG", operands=[int(tile_elems)], + )) + self._append(pi.PreIsaOp( + opcode="C_SET_STRIDE_REG", operands=[int(mlen)], + )) + + def _emit_v2h_tile_body( + self, *, vram_addr, hbm_addr: int, + hbm_stride, hbm_scale_size, hbm_start_offset, + ) -> None: + """One HBM tile's worth of writeback. Mirror of + ``_emit_h2v_tile_body``: nested (outer, inner) LoopRegions + with PrimExpr addresses; emits ``H_STORE_V`` instead of + ``H_PREFETCH_V``. ``vram_addr`` / ``hbm_start_offset`` may + be int or PrimExpr. + """ + mlen = int(self.shim.mlen) + v_writeback = int(self.shim.v_writeback_amount) + tile_elems = mlen * mlen + stride_len = mlen if hbm_stride is None else int(hbm_stride) + scale_len = tile_elems if hbm_scale_size is None else int(hbm_scale_size) + vram_addr_e = _as_expr(vram_addr) + hbm_start_e = _as_expr(hbm_start_offset) + + store_amount_per_hidden = (mlen + mlen - 1) // mlen + if mlen > v_writeback: + inner_count = (mlen + v_writeback - 1) // v_writeback + inner_stride_per_iter = stride_len * v_writeback + else: + inner_count = 1 + inner_stride_per_iter = 0 + + addr = pi.PreIsaOp( + opcode="C_SET_ADDR_REG", + # high word = constant-zero gp0; low word = address. + operands=["gp0", int(hbm_addr)], + ) + self._append(addr) + self._append(pi.PreIsaOp( + opcode="C_SET_SCALE_REG", operands=[int(scale_len)], + )) + self._append(pi.PreIsaOp( + opcode="C_SET_STRIDE_REG", operands=[int(stride_len)], + )) + + op_id = f"{id(addr) & 0xffff:x}" + outer_var = tir.Var(f"v2h_outer_{op_id}", "int32") + inner_var = tir.Var(f"v2h_inner_{op_id}", "int32") + + outer_loop = pi.LoopRegion( + loop_var=outer_var, init_imm=0, + extent_imm=store_amount_per_hidden, + loop_kind="unroll", body=[], + ) + self._append(outer_loop) + self._push_scope(outer_loop.body) + try: + inner_loop = pi.LoopRegion( + loop_var=inner_var, init_imm=0, extent_imm=inner_count, + loop_kind="unroll", body=[], + ) + self._append(inner_loop) + self._push_scope(inner_loop.body) + try: + row_idx = tir.Add( + tir.Mul(outer_var, + tir.IntImm("int32", inner_count)), + inner_var, + ) + vram_chunk = tir.Add( + vram_addr_e, + tir.Mul(row_idx, + tir.IntImm("int32", mlen * v_writeback)), + ) + hbm_off = tir.Add( + tir.Add( + hbm_start_e, + tir.Mul(outer_var, + tir.IntImm("int32", mlen)), + ), + tir.Mul(inner_var, + tir.IntImm("int32", + int(inner_stride_per_iter))), + ) + self._append(pi.PreIsaOp( + opcode="H_STORE_V", + operands=[vram_chunk, hbm_off, addr, 1, 0], + )) + finally: + self._pop_scope() + finally: + self._pop_scope() + + # ------------------------------------------------------------------ + # mv — per-head matrix × vector with tile loop over n/blen. + # ------------------------------------------------------------------ + def _emit_mv(self, mod: _hlir.HLIRModule, op: _hlir.Op) -> None: + """Legacy emits ``tiles = n // blen`` (M_MV, M_MV_WO) pairs, + each iter advancing the mat/dst pointers by ``blen``. In v2 + we model the tile loop as an unroll LoopRegion with the + per-tile addresses as ``base + t * blen``. + + Static-offset path only — dynamic region offsets are a TODO. + """ + if len(op.buffer_args) != 3: + raise PreIsaPassV2Error( + f"plena.mv expects 3 buffer_args; got {len(op.buffer_args)}" + ) + a_reg, b_reg, c_reg = op.buffer_args + if not isinstance(a_reg, _hlir.VramRegion): + raise PreIsaPassV2Error(f"plena.mv a: expected VramRegion") + if not isinstance(b_reg, _hlir.MramRegion): + raise PreIsaPassV2Error(f"plena.mv b: expected MramRegion") + if not isinstance(c_reg, _hlir.VramRegion): + raise PreIsaPassV2Error(f"plena.mv c: expected VramRegion") + lhs = mod.get_buffer(a_reg.parent) + rhs = mod.get_buffer(b_reg.parent) + dst = mod.get_buffer(c_reg.parent) + # Region origin offsets — static-only. + lhs_off = self._legacy._region_origin_offset(lhs, a_reg) + rhs_off = self._legacy._region_origin_offset(rhs, b_reg) + dst_off = self._legacy._region_origin_offset(dst, c_reg) + + def _static(x): + if isinstance(x, int): + return int(x) + if isinstance(x, tir.IntImm): + return int(x.value) + return None + + lhs_off_s = _static(lhs_off) + rhs_off_s = _static(rhs_off) + dst_off_s = _static(dst_off) + if (lhs_off_s is None or rhs_off_s is None or dst_off_s is None): + raise PreIsaPassV2Error( + f"plena.mv: dynamic region offsets not yet supported on v2" + ) + task_id = op.annotations.get("intrinsic", "mv") + n = int(self.shim.btmm_hlen) + blen = int(self.shim.blen) + if n % blen != 0: + raise PreIsaPassV2Error( + f"plena.mv: n={n} must be a multiple of blen={blen}" + ) + tiles = n // blen + lhs_vram_addr = int(lhs.address) + lhs_off_s + rhs_mram_addr = int(rhs.address) + rhs_off_s + dst_vram_addr = int(dst.address) + dst_off_s + self._comment( + f"mv task {task_id} " + f"v=vram[{lhs_vram_addr}] " + f"m=mram[{rhs_mram_addr}] " + f"dst=vram[{dst_vram_addr}] " + f"tiles={tiles} blen={blen}" + ) + # Per-tile address PrimExprs. lhs is fixed across tiles (the + # vector base); rhs and dst advance by ``blen`` per tile. + lhs_addr = tir.IntImm("int32", lhs_vram_addr) + if tiles > 1: + t_var = tir.Var(f"mv_t_{id(op) & 0xffff:x}", "int32") + loop = pi.LoopRegion( + loop_var=t_var, init_imm=0, extent_imm=tiles, + loop_kind="unroll", body=[], + ) + self._append(loop) + self._push_scope(loop.body) + try: + rhs_addr = tir.Add( + tir.IntImm("int32", rhs_mram_addr), + tir.Mul(t_var, tir.IntImm("int32", blen)), + ) + dst_addr = tir.Add( + tir.IntImm("int32", dst_vram_addr), + tir.Mul(t_var, tir.IntImm("int32", blen)), + ) + # M_MV gp0, gp{rhs}, gp{lhs} + self._append(pi.PreIsaOp( + opcode="M_MV", + operands=["gp0", rhs_addr, lhs_addr], + )) + # M_MV_WO gp{dst}, 0 + self._append(pi.PreIsaOp( + opcode="M_MV_WO", + operands=[dst_addr, 0], + )) + finally: + self._pop_scope() + else: + rhs_addr = tir.IntImm("int32", rhs_mram_addr) + dst_addr = tir.IntImm("int32", dst_vram_addr) + self._append(pi.PreIsaOp( + opcode="M_MV", + operands=["gp0", rhs_addr, lhs_addr], + )) + self._append(pi.PreIsaOp( + opcode="M_MV_WO", + operands=[dst_addr, 0], + )) + + +__all__ = ["PreIsaPassV2", "PreIsaPassV2Error"] diff --git a/tilelang_tvm_compiler/pre_isa_to_mir.py b/tilelang_tvm_compiler/pre_isa_to_mir.py new file mode 100644 index 0000000..fe4c437 --- /dev/null +++ b/tilelang_tvm_compiler/pre_isa_to_mir.py @@ -0,0 +1,657 @@ +"""PreIsaIR (v2) → MIR conversion pass. + +This is the "explicit conversion" layer the user asked for: every +``tir.PrimExpr`` operand of a PreIsaOp is expanded into a chain of +MIR instructions producing SSA values, then the PreIsaOp itself +becomes a MirInstr referencing those SSA values. + +Pipeline contract: + + PreIsaIR (v2) — opcode + PrimExpr/int/str operands, no registers + yet, address algebra fully symbolic. + │ + │ this pass + ▼ + MIR — SSA values, def/use chains, loops with loop_kind. + Ready for LICM / CSE / DCE / regalloc. + +Three things this pass does: + + 1. **PrimExpr lowering**. Each PrimExpr operand is recursively + turned into a sequence of MirInstrs producing one MirValue per + subexpression. ``tir.Add`` → ``S_ADD_INT`` / ``S_ADDI_INT`` + (folded if the constant fits), ``tir.Mul`` → ``S_MUL_INT`` / + ``S_SLLI_INT`` (power-of-2 strength reduction), ``tir.Var`` + looked up in the symbol table. + + 2. **Symbol table threading**. The hw_consts (MLEN_VAR / BLEN_VAR + / ...) are bound to constant-int MirValues at pass start. + Loop variables (``tir.Var`` introduced by LoopRegion) are bound + to body-block-argument MirValues at loop entry, unbound on exit. + + 3. **arith.Analyzer simplification**. Before lowering each + PrimExpr we run ``arith.Analyzer().simplify`` to fold constants, + normalise ``Add`` order, etc. Hw consts are pre-substituted + to their IntImm bindings so e.g. ``BLEN_VAR * MLEN_VAR`` folds + to a literal IntImm when the shim says blen=4, mlen=64. Loop + vars are NOT substituted — they stay symbolic so MIR sees the + loop-dependent algebra structurally. + +Caching: + Within a single PreIsaOp lowering we cache subexpression results + by ``tvm.ir.structural_equal`` — two operands that are the same + expression structurally share one MirValue chain. This is a free + cheap CSE; the bigger CSE pass later does it across PreIsaOps. +""" + +from __future__ import annotations + +from typing import Dict, List, Optional, Union # noqa: F401 + +from tvm import arith, tir +from tvm.tir import stmt_functor + +from . import mir +from .pre_isa_ir_v2 import ( + LoopRegion, PreIsaOp, PreIsaModule, KNOWN_OPCODES, +) + + +class PreIsaToMirError(RuntimeError): + pass + + +# --------------------------------------------------------------------- +# Operand classifier — what MIR operand kind a PreIsaOp operand maps to. +# --------------------------------------------------------------------- + +# For each PreIsaIR opcode, we'd ideally cross-check against +# ``mir.OPCODES``'s operand kinds. We do that on the fly in +# ``_lower_one_preop`` — see ``_kind_at``. + +def _kind_at(opcode: str, i: int) -> str: + """Look up the MIR operand kind expected at slot ``i`` of + ``opcode``.""" + spec = mir.OPCODES.get(opcode) + if spec is None: + raise PreIsaToMirError( + f"PreIsaIR opcode {opcode!r} not in mir.OPCODES — add a " + f"matching MIR entry." + ) + if i >= len(spec.operand_kinds): + raise PreIsaToMirError( + f"{opcode}: PreIsaOp gave operand[{i}] but mir.OPCODES " + f"arity is {len(spec.operand_kinds)}" + ) + return spec.operand_kinds[i] + + +# --------------------------------------------------------------------- +# Conversion class +# --------------------------------------------------------------------- + +class PreIsaToMir: + """State for one PreIsaModule → MirFunction conversion.""" + + def __init__(self, mod: PreIsaModule, shim) -> None: + self.preisa = mod + self.shim = shim + self.fn = mir.MirFunction(name=mod.name) + self.fn.metadata["buffers"] = dict(mod.buffers) + # Current block where new MirInstrs / MirLoops are appended. + self.cur_block: Optional[mir.MirBlock] = None + # Symbol tables. + # tir.Var (loop vars + hw_consts) → MirValue (or IntImm for + # hw_consts that get pre-substituted). + self.var_to_value: Dict[tir.Var, mir.MirValue] = {} + # hw_const tir.Var → constant int (the shim's current value). + # Used by arith.Analyzer to pre-substitute when simplifying. + self.hw_const_values: Dict[tir.Var, int] = {} + self._init_hw_consts() + # ``gp0`` is a hardware-fixed constant-zero register used as a + # source on instructions like ``S_ADDI_INT %dst, gp0, imm``. + # We model it as a function-level constant MirValue (its + # ``is_function_const`` flag is True; no instr produces it, + # no block argument). + self.gp0_value: Optional[mir.MirValue] = None + # Per-PreIsaOp expression cache (structural_equal → MirValue). + # Reset by ``_begin_preop``. + self._expr_cache: List = [] + # PreIsaOp identity → its produced MirValue (for ops that + # define an addr_reg or other typed result that downstream + # PreIsaOps reference via ``PreIsaOp`` as an operand). + self._preop_result: Dict[int, mir.MirValue] = {} + # Analyzer reused across all simplifies. + self._analyzer = arith.Analyzer() + + def _init_hw_consts(self) -> None: + # Lazy import to avoid cycle. + from .hw_consts import HW_CONST_ATTRS + for var, attr in HW_CONST_ATTRS.items(): + self.hw_const_values[var] = int(getattr(self.shim, attr)) + + def run(self) -> mir.MirFunction: + # Set up the top block. + top = mir.MirBlock(name="entry") + self.fn.blocks.append(top) + self.cur_block = top + # gp0 is a function-level constant (MLIR-style: a value that + # just "exists"; no instr produces it; no block argument). It + # represents the hardware-fixed zero register. + self.gp0_value = self.fn.make_gp0_const() + + # Walk top-level body. + self._lower_items(self.preisa.body) + return self.fn + + # ----------------------------------------------------------------- + # Body walk + # ----------------------------------------------------------------- + def _lower_items(self, items) -> None: + for it in items: + if isinstance(it, PreIsaOp): + self._lower_one_preop(it) + elif isinstance(it, LoopRegion): + self._lower_loop(it) + else: + raise PreIsaToMirError( + f"unexpected body item type: {type(it).__name__}" + ) + + def _lower_loop(self, lp: LoopRegion) -> None: + # Create body block with loop_var as its block argument + # (MLIR-style: the region supplies the induction var on each + # entry; from the body's perspective it's just an SSA value). + body_blk = mir.MirBlock(name=f"loop.body.{lp.loop_var.name}") + lvar = self.fn.mint_value("i32", hint=lp.loop_var.name) + body_blk.add_argument(lvar) + # Build MirLoop + attach to current block. + loop_obj = mir.MirLoop( + name=f"L_{lp.loop_var.name}", + loop_var=lvar, + init=lp.init_imm, + extent=lp.extent_imm, + body=[body_blk], + loop_kind=lp.loop_kind, + # Forward any PreIsaIR LoopRegion annotations the + # producer set (e.g. ``order_independent`` for the + # reverse-iter optimisation in mir_to_isa). + annotations=dict(lp.annotations), + ) + self.cur_block.append(loop_obj) + # Push symbol table. + if lp.loop_var in self.var_to_value: + raise PreIsaToMirError( + f"loop_var {lp.loop_var.name!r} already bound — nested " + f"loops using the same tir.Var aren't supported (producer " + f"must mint a fresh tir.Var per LoopRegion)" + ) + self.var_to_value[lp.loop_var] = lvar + prev_block = self.cur_block + self.cur_block = body_blk + try: + self._lower_items(lp.body) + finally: + self.cur_block = prev_block + self.var_to_value.pop(lp.loop_var, None) + + # ----------------------------------------------------------------- + # PreIsaOp lowering + # ----------------------------------------------------------------- + def _lower_one_preop(self, op: PreIsaOp) -> None: + # Comment passes through as a meta MirInstr. + if op.opcode == "_COMMENT": + text = op.operands[0] if op.operands else "" + self.cur_block.append(mir.MirInstr( + opcode="_COMMENT", + operands=[text], + result=None, + )) + return + + # Per-op expr cache. + self._expr_cache = [] + spec = mir.OPCODES.get(op.opcode) + if spec is None: + raise PreIsaToMirError( + f"PreIsaOp opcode {op.opcode!r} not in mir.OPCODES" + ) + if len(op.operands) != len(spec.operand_kinds): + raise PreIsaToMirError( + f"{op.opcode}: PreIsaOp has {len(op.operands)} operands " + f"but mir.OPCODES expects {len(spec.operand_kinds)}" + ) + + mir_operands: List = [] + for i, (val, kind) in enumerate( + zip(op.operands, spec.operand_kinds), + ): + mir_operands.append(self._lower_operand(val, kind)) + + # Allocate result MirValue if non-void. + result: Optional[mir.MirValue] = None + if spec.result_type != "void": + result = self.fn.mint_value(spec.result_type) + self.cur_block.append(mir.MirInstr( + opcode=op.opcode, + operands=mir_operands, + result=result, + )) + # Record the result so downstream PreIsaOps referencing this + # op via ``PreIsaOp`` operand can resolve it. + if result is not None: + self._preop_result[id(op)] = result + + # ----------------------------------------------------------------- + # Operand kind dispatch + # ----------------------------------------------------------------- + def _lower_operand(self, val, kind: str): + """Return the MIR-form operand for this PreIsaOp operand.""" + if kind == "i32": + return self._lower_i32_operand(val) + if kind == "literal_int": + return self._lower_literal_int(val) + if kind == "fp_reg": + return self._lower_verbatim_str(val, kind) + if kind == "verbatim_str": + return self._lower_verbatim_str(val, kind) + if kind == "addr_reg": + return self._lower_addr_reg_operand(val) + raise PreIsaToMirError( + f"_lower_operand: unknown operand kind {kind!r}" + ) + + def _lower_i32_operand(self, val) -> mir.MirValue: + """Turn a PrimExpr / int into an i32 MirValue.""" + if isinstance(val, int): + return self._lower_primexpr(tir.IntImm("int32", int(val))) + if isinstance(val, tir.PrimExpr): + return self._lower_primexpr(val) + raise PreIsaToMirError( + f"i32 operand expects PrimExpr / int; got " + f"{type(val).__name__} {val!r}" + ) + + def _lower_literal_int(self, val): + """Pass-through compile-time int literal.""" + if isinstance(val, int): + return int(val) + if isinstance(val, tir.IntImm): + return int(val.value) + raise PreIsaToMirError( + f"literal_int operand expects int / IntImm; got " + f"{type(val).__name__} {val!r}" + ) + + def _lower_verbatim_str(self, val, kind: str) -> str: + if isinstance(val, str): + return val + raise PreIsaToMirError( + f"{kind} operand expects str; got {type(val).__name__} {val!r}" + ) + + def _lower_addr_reg_operand(self, val): + # ``PreIsaOp`` operand → the MirValue produced by that op's + # earlier lowering. The producer must have emitted the + # referenced op before this consumer; we look it up in + # ``_preop_result``. + if isinstance(val, PreIsaOp): + mv = self._preop_result.get(id(val)) + if mv is None: + raise PreIsaToMirError( + f"addr_reg operand: PreIsaOp {val.opcode!r} was not " + f"lowered before this consumer (or it does not " + f"produce an addr_reg result)" + ) + if mv.dtype != "addr_reg": + raise PreIsaToMirError( + f"addr_reg operand: referenced op {val.opcode!r} " + f"produces dtype {mv.dtype!r}, not addr_reg" + ) + return mv + raise PreIsaToMirError( + f"addr_reg operand expects a PreIsaOp reference to a " + f"producer (typically C_SET_ADDR_REG); got " + f"{type(val).__name__} {val!r}" + ) + + # ----------------------------------------------------------------- + # PrimExpr → SSA chain (the heart of the conversion) + # ----------------------------------------------------------------- + def _lower_primexpr(self, expr) -> mir.MirValue: + """Lower a PrimExpr into a chain of MirInstrs ending in an + i32 MirValue.""" + # Simplify first (substitute hw_consts to IntImm + arith.Analyzer + # simplify). Loop vars stay symbolic. + expr = self._simplify(expr) + + # Structural-equal cache lookup. + for cached_expr, cached_val in self._expr_cache: + try: + from tvm import ir as _ir + if _ir.structural_equal(expr, cached_expr): + return cached_val + except Exception: + pass + + val = self._emit_primexpr(expr) + self._expr_cache.append((expr, val)) + return val + + def _simplify(self, expr): + """Substitute hw_const Vars to their IntImm values, then run + arith.Analyzer().simplify. Loop var Vars are intentionally + NOT substituted — they must stay symbolic in MIR so loop- + invariant analysis can spot them.""" + if not isinstance(expr, tir.PrimExpr): + return expr + # Substitute only hw consts. + var_map = { + v: tir.IntImm("int32", n) + for v, n in self.hw_const_values.items() + } + if var_map: + try: + expr = stmt_functor.substitute(expr, var_map) + except Exception: + pass + try: + expr = self._analyzer.simplify(expr) + except Exception: + pass + return expr + + def _emit_primexpr(self, expr) -> mir.MirValue: + """Recursive emit; assumes ``expr`` has been simplified.""" + if isinstance(expr, tir.IntImm): + return self._emit_intimm(int(expr.value)) + if isinstance(expr, tir.Var): + # Loop var lookup. + mv = self.var_to_value.get(expr) + if mv is None: + # Hw const that escaped substitution? Shouldn't happen + # after _simplify. + if expr in self.hw_const_values: + return self._emit_intimm(self.hw_const_values[expr]) + raise PreIsaToMirError( + f"unbound tir.Var {expr.name!r} in PrimExpr; not " + f"a loop var and not in hw_consts" + ) + return mv + if isinstance(expr, tir.Add): + return self._emit_add(expr.a, expr.b) + if isinstance(expr, tir.Sub): + return self._emit_sub(expr.a, expr.b) + if isinstance(expr, tir.Mul): + return self._emit_mul(expr.a, expr.b) + if isinstance(expr, tir.FloorDiv): + return self._emit_floordiv(expr.a, expr.b) + if isinstance(expr, tir.FloorMod): + return self._emit_floormod(expr.a, expr.b) + if isinstance(expr, tir.Call): + return self._emit_call(expr) + # tir.Cast / Min / Max etc. — extend as we hit them. + raise PreIsaToMirError( + f"unsupported PrimExpr node: {type(expr).__name__} ({expr!r})" + ) + + def _emit_call(self, expr: "tir.Call") -> mir.MirValue: + """Lower a TIR Call to a MIR instruction. Currently supports: + * ``tir.shift_left(x, k)`` → ``S_SLLI_INT %x, k`` (k literal) + or ``S_SLL_INT %x, %k`` (k reg) + * ``tir.shift_right(x, k)`` → ``S_SRLI_INT %x, k`` / ``S_SRL_INT %x, %k`` + """ + op_name = expr.op.name if hasattr(expr.op, "name") else str(expr.op) + if op_name in ("tir.shift_left", "tir.shift_right"): + if len(expr.args) != 2: + raise PreIsaToMirError( + f"{op_name}: expected 2 args; got {len(expr.args)}" + ) + x, k = expr.args + is_left = (op_name == "tir.shift_left") + if _is_intlike(k): + # Immediate shift amount. + k_int = _intval(k) + if k_int == 0: + return self._lower_primexpr(x) + lhs = self._lower_primexpr(x) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_SLLI_INT" if is_left else "S_SRLI_INT", + operands=[lhs, k_int], + result=dst, + )) + return dst + # Reg-amount shift. + lhs = self._lower_primexpr(x) + rhs = self._lower_primexpr(k) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_SLL_INT" if is_left else "S_SRL_INT", + operands=[lhs, rhs], + result=dst, + )) + return dst + raise PreIsaToMirError( + f"unsupported PrimExpr Call: {op_name} ({expr!r})" + ) + + # ---- leaves and small helpers ---- + def _emit_intimm(self, n: int) -> mir.MirValue: + if n == 0: + return self.gp0_value + dst = self.fn.mint_value("i32") + # Fits in S_ADDI_INT immediate? legacy bound 262143. + if 0 <= n <= 262143: + self.cur_block.append(mir.MirInstr( + opcode="S_ADDI_INT", + operands=[self.gp0_value, n], + result=dst, + )) + return dst + # Two-instr form: S_LUI_INT upper; S_ADDI_INT lower. + upper = n >> 12 + lower = n & 0xFFF + hi = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_LUI_INT", + operands=[upper], + result=hi, + )) + self.cur_block.append(mir.MirInstr( + opcode="S_ADDI_INT", + operands=[hi, lower], + result=dst, + )) + return dst + + def _emit_add(self, a, b) -> mir.MirValue: + # x + 0 → x / Both intlike → fold. + if _is_intlike(a) and _is_intlike(b): + return self._emit_intimm(_intval(a) + _intval(b)) + if _is_intlike(b) and _intval(b) == 0: + return self._lower_primexpr(a) + if _is_intlike(a) and _intval(a) == 0: + return self._lower_primexpr(b) + # Imm form: S_ADDI_INT %a, IMM (fits in immediate). + if _is_intlike(b) and 0 <= _intval(b) <= 262143: + lhs = self._lower_primexpr(a) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_ADDI_INT", + operands=[lhs, _intval(b)], + result=dst, + )) + return dst + if _is_intlike(a) and 0 <= _intval(a) <= 262143: + rhs = self._lower_primexpr(b) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_ADDI_INT", + operands=[rhs, _intval(a)], + result=dst, + )) + return dst + # General form. + lhs = self._lower_primexpr(a) + rhs = self._lower_primexpr(b) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_ADD_INT", + operands=[lhs, rhs], + result=dst, + )) + return dst + + def _emit_sub(self, a, b) -> mir.MirValue: + if _is_intlike(a) and _is_intlike(b): + return self._emit_intimm(_intval(a) - _intval(b)) + if _is_intlike(b) and _intval(b) == 0: + return self._lower_primexpr(a) + lhs = self._lower_primexpr(a) + rhs = self._lower_primexpr(b) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_SUB_INT", + operands=[lhs, rhs], + result=dst, + )) + return dst + + def _emit_mul(self, a, b) -> mir.MirValue: + if _is_intlike(a) and _is_intlike(b): + return self._emit_intimm(_intval(a) * _intval(b)) + # x * 0 / 0 * x → 0. + if _is_intlike(b) and _intval(b) == 0: + return self.gp0_value + if _is_intlike(a) and _intval(a) == 0: + return self.gp0_value + # x * 1 / 1 * x → x. + if _is_intlike(b) and _intval(b) == 1: + return self._lower_primexpr(a) + if _is_intlike(a) and _intval(a) == 1: + return self._lower_primexpr(b) + # Strength reduce x * 2^k → SLLI_INT %x, k. + if _is_intlike(b): + k = _try_pow2_shift(_intval(b)) + if k is not None: + lhs = self._lower_primexpr(a) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_SLLI_INT", + operands=[lhs, k], + result=dst, + )) + return dst + if _is_intlike(a): + k = _try_pow2_shift(_intval(a)) + if k is not None: + rhs = self._lower_primexpr(b) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_SLLI_INT", + operands=[rhs, k], + result=dst, + )) + return dst + # General S_MUL_INT. + lhs = self._lower_primexpr(a) + rhs = self._lower_primexpr(b) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_MUL_INT", + operands=[lhs, rhs], + result=dst, + )) + return dst + + def _emit_floordiv(self, a, b) -> mir.MirValue: + if _is_intlike(a) and _is_intlike(b): + return self._emit_intimm(_intval(a) // _intval(b)) + # x // 2^k → SRLI_INT. + if _is_intlike(b): + k = _try_pow2_shift(_intval(b)) + if k is not None: + lhs = self._lower_primexpr(a) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_SRLI_INT", + operands=[lhs, k], + result=dst, + )) + return dst + raise PreIsaToMirError( + f"FloorDiv with non-power-of-2 divisor and non-literal LHS " + f"is not lowerable on PLENA (no integer divide); " + f"got {a!r} // {b!r}" + ) + + def _emit_floormod(self, a, b) -> mir.MirValue: + if _is_intlike(a) and _is_intlike(b): + return self._emit_intimm(_intval(a) % _intval(b)) + # x % 2^k = x - (x >> k) << k. Emit directly in MIR + # without re-running arith.simplify (which would fold the + # SRLI+SLLI+Sub chain straight back into FloorMod and + # loop forever). + if _is_intlike(b): + k = _try_pow2_shift(_intval(b)) + if k is not None: + lhs = self._lower_primexpr(a) + shifted = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_SRLI_INT", + operands=[lhs, k], + result=shifted, + )) + scaled = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_SLLI_INT", + operands=[shifted, k], + result=scaled, + )) + dst = self.fn.mint_value("i32") + self.cur_block.append(mir.MirInstr( + opcode="S_SUB_INT", + operands=[lhs, scaled], + result=dst, + )) + return dst + raise PreIsaToMirError( + f"FloorMod by non-power-of-2 not lowerable; got {a!r} % {b!r}" + ) + + +# --------------------------------------------------------------------- +# Small helpers +# --------------------------------------------------------------------- + +def _is_intlike(x) -> bool: + return isinstance(x, (int, tir.IntImm)) + + +def _intval(x) -> int: + if isinstance(x, tir.IntImm): + return int(x.value) + return int(x) + + +def _try_pow2_shift(n: int) -> Optional[int]: + if n <= 1 or (n & (n - 1)) != 0: + return None + k = n.bit_length() - 1 + if k > 31: + return None + return k + + +# --------------------------------------------------------------------- +# Public entry +# --------------------------------------------------------------------- + +def convert(mod: PreIsaModule, shim) -> mir.MirFunction: + """Convert one PreIsaIR v2 module to a MirFunction. ``shim`` is + used for hw_const substitution (mlen / blen / etc.).""" + return PreIsaToMir(mod, shim).run() + + +__all__ = ["convert", "PreIsaToMir", "PreIsaToMirError"] diff --git a/tilelang_tvm_compiler/program_shim.py b/tilelang_tvm_compiler/program_shim.py new file mode 100644 index 0000000..a1d9932 --- /dev/null +++ b/tilelang_tvm_compiler/program_shim.py @@ -0,0 +1,101 @@ +"""Minimal stand-in for the runtime TileTensorProgram + Compiler objects +that ISAEmitter pokes into. + +ISAEmitter (the file we copied wholesale) reads: + self.program.mlen + self.program.blen + self.program.tile_elems + self.program.btmm_lane_count + self.program.btmm_hlen + self.program.compiler.register_allocator + self.program.compiler.generated_code (string, accumulated by += ) + +For methods we don't use yet (emit_matmul, emit_fp_kernel, ...) it also +touches `self.program._arith_progression` and various tile/_helpers/_types +symbols. Those will fail at call time if invoked. We document the +contract here and add fields lazily as we enable more emitter methods. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional + +from .register_alloc import RegisterAllocator + + +@dataclass +class CompilerShim: + """Holds the pieces ISAEmitter expects under `program.compiler`.""" + + register_allocator: RegisterAllocator = field(default_factory=RegisterAllocator) + generated_code: str = "" + + +@dataclass +class ProgramShim: + """Holds the hardware constants ISAEmitter expects under `program`.""" + + mlen: int + blen: int + btmm_lane_count: int + btmm_hlen: int + # Rows transferred per H_PREFETCH_V / H_STORE_V instruction — the + # emulator's PREFETCH_V_AMOUNT / STORE_V_AMOUNT. The DMA emit must + # use these (not blen) as the per-instruction VLEN-row count, or it + # emits AMOUNT/blen times too many instructions with wrong strides. + v_prefetch_amount: int = 1 + v_writeback_amount: int = 1 + compiler: CompilerShim = field(default_factory=CompilerShim) + + @property + def tile_elems(self) -> int: + return self.mlen * self.mlen + + @staticmethod + def _arith_progression(values): + """Detect an arithmetic progression in a list of ints. + + Returns (start, count, step) when the input is a non-empty AP, + else None. Single-element inputs are treated as a degenerate AP + with step=0 so emit_matmul can use its hardware-loop fast path + with pair_count=1 instead of falling through to explicit + unrolling. + """ + if not values: + return None + if len(values) == 1: + return (int(values[0]), 1, 0) + step = int(values[1]) - int(values[0]) + for i in range(2, len(values)): + if int(values[i]) - int(values[i - 1]) != step: + return None + return (int(values[0]), len(values), step) + + +def make_shim( + *, + mlen: int, + blen: int, + btmm_lane_count: int, + btmm_hlen: int, + v_prefetch_amount: int = 1, + v_writeback_amount: int = 1, + register_allocator: Optional[RegisterAllocator] = None, +) -> ProgramShim: + compiler = CompilerShim(register_allocator=register_allocator or RegisterAllocator()) + # Wire the allocator back to the compiler so auto-spill can emit + # S_ST_INT / S_LD_INT into generated_code. + compiler.register_allocator.compiler = compiler + return ProgramShim( + mlen=mlen, + blen=blen, + btmm_lane_count=btmm_lane_count, + btmm_hlen=btmm_hlen, + v_prefetch_amount=v_prefetch_amount, + v_writeback_amount=v_writeback_amount, + compiler=compiler, + ) + + +__all__ = ["ProgramShim", "CompilerShim", "make_shim"] diff --git a/tilelang_tvm_compiler/register_alloc.py b/tilelang_tvm_compiler/register_alloc.py new file mode 100644 index 0000000..576ff23 --- /dev/null +++ b/tilelang_tvm_compiler/register_alloc.py @@ -0,0 +1,443 @@ +"""Tiny free-list register allocator with optional GP spill to IntRAM. + +ISAEmitter calls into us mid-emit to get scratch registers and returns +them when the instruction sequence is finished: + + gp_regs = compiler.register_allocator.allocate_gp(5) # list[int] + ... emit ISA using gp{gp_regs[0]}, gp{gp_regs[1]}, ... + compiler.register_allocator.free_gp(gp_regs) + +When the free GP pool can't satisfy a request, the runtime can fall +back to a *borrow scope*: temporarily move the contents of some +currently-allocated GPs to IntRAM (``S_ST_INT``), reuse those slots +for the request, then restore (``S_LD_INT``) them when the borrow +ends. Use this around a single leaf emit that doesn't read any of the +spilled GPs in its body: + + borrowed, token = ra.spill_borrow( + 6, compiler=program.compiler, protect=[gp_addr_offset_reg] + ) + # ... emit ISA using borrowed GPs ... + ra.spill_return(token, compiler=program.compiler) + +Pool sizes match the PLENA spec (16 GP, 8 addr) minus a few reserved: + - gp0 reserved as the constant-zero register + - addr0..7 all available; runtime convention reserves none + +IntRAM spill region: + - The IntRAM (1024 u32 words) is shared with the user. We reserve + slots at ``[SPILL_BASE, SPILL_BASE + SPILL_SLOTS)`` for GP + saves. ``SPILL_BASE`` defaults to 256 to leave headroom for any + user preload at the start of IntRAM. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Optional, Tuple + + +class RegisterExhausted(RuntimeError): + pass + + +# IntRAM regions (units = u32 words; emulator's intsram is sized 1024). +# [0, 256) user / preload scratch +# [SPILL_BASE, ...) GP auto-spill backing store +# [IDX_BASE, ...) loop idx backing store +# Keeping the regions disjoint means a loop's idx in IntRAM can never +# be clobbered by a GP spill done by the loop body. +SPILL_BASE = 256 +SPILL_SLOTS = 256 +IDX_BASE = 512 +IDX_SLOTS = 256 + + +@dataclass +class _SpillRecord: + orig_reg: int + slot: int + + +@dataclass +class BorrowToken: + """Opaque handle returned from ``spill_borrow``. Pass back to + ``spill_return`` to release the borrow and reload spilled GPs.""" + borrowed: List[int] + spilled: List[_SpillRecord] = field(default_factory=list) + + +class RegisterAllocator: + def __init__( + self, + *, + gp_total: int = 16, + addr_total: int = 8, + gp_reserved: Iterable[int] = (0,), # gp0 = constant zero + addr_reserved: Iterable[int] = (), + ) -> None: + gp_reserved_set = set(gp_reserved) + addr_reserved_set = set(addr_reserved) + self._gp_total = gp_total + self._gp_free: List[int] = [i for i in range(gp_total) if i not in gp_reserved_set] + # In-use GPs in allocation order (LIFO spill candidates pulled + # from the end). Always disjoint from ``_gp_free``. + self._gp_in_use: List[int] = [] + self._addr_free: List[int] = [i for i in range(addr_total) if i not in addr_reserved_set] + # Spill slot bitmap. + self._spill_slots_in_use: List[bool] = [False] * SPILL_SLOTS + # Idx slot bitmap (loop idx values stored in IntRAM instead of + # GPs so deep nests don't pin the whole GP file). + self._idx_slots_in_use: List[bool] = [False] * IDX_SLOTS + # Late-bound by the shim/compiler so allocate_gp can auto-spill + # on demand. Stays None for tests that don't wire it up. + self.compiler = None + # Each successful auto-spill records (orig_reg, slot) so the + # matching ``free_gp`` reload-restores from IntRAM. + self._auto_spills_by_borrow: Dict[int, _SpillRecord] = {} + # Set of GPs that hold long-lived bindings (loop indices etc.) + # — never picked as spill candidates. The ISA pass populates + # this via ``pin_gp`` / ``unpin_gp`` whenever it binds / + # unbinds a loop var, keeping the symbol_table contents safe + # from being trashed by auto-spill. + self._pinned_gp: set = set() + # GP register event trace, one row per state-changing call. + # Filled by ``_record(...)`` from every public mutator below. + # Dumped at end of compile to ``/.gp_trace.tsv`` + # so the ASM can be cross-referenced line-by-line with the + # allocation events that produced it. + self._trace: List[Dict[str, str]] = [] + # Source-site stack pushed by ISA pass / materializer when they + # want events grouped under a logical operation (e.g. ``op[12] + # row_reduce_sum_at``, ``materialize Add``). Whatever sits on + # top here annotates every event until popped. + self._site_stack: List[str] = [] + + # ------------------------------------------------------------------ + # Event trace + # ------------------------------------------------------------------ + def push_site(self, label: str) -> None: + """Annotate every subsequent event until the matching ``pop_site`` + with ``label``. Used by ISA pass / materializer to tag events + with the logical op or expression they came from.""" + self._site_stack.append(label) + + def pop_site(self) -> None: + if self._site_stack: + self._site_stack.pop() + + def _site(self) -> str: + return " > ".join(self._site_stack) + + def _asm_line(self) -> int: + """Current line count of ``generated_code`` — used as a coarse + cursor so trace rows can be aligned with the ASM dump.""" + if self.compiler is None: + return 0 + return self.compiler.generated_code.count("\n") + 1 + + def _record(self, event: str, **fields: object) -> None: + """Append one row to ``self._trace``. Captures the post-mutation + state (free pool / in-use list / pinned set) so a reader can + replay the timeline without reconstructing it from deltas.""" + row: Dict[str, object] = { + "asm_line": self._asm_line(), + "site": self._site(), + "event": event, + } + row.update(fields) + # Snapshot pool state AFTER the event, so a reader can compare + # consecutive rows to spot leaks / double-frees / pinning bugs. + row["free"] = ",".join(str(r) for r in self._gp_free) + row["in_use"] = ",".join(str(r) for r in self._gp_in_use) + row["pinned"] = ",".join(str(r) for r in sorted(self._pinned_gp)) + self._trace.append(row) + + def trace_rows(self) -> List[Dict[str, object]]: + return list(self._trace) + + # ------------------------------------------------------------------ + # GP register pool + # ------------------------------------------------------------------ + def allocate_gp(self, n: int) -> List[int]: + # Auto-spill is allowed only against UNPINNED in-use regs. Loop + # hw-counters and idx regs are always pinned by `_emit_for`, so + # they can never be picked as spill victims -- which was the + # earlier bug (spilled hw counter reloaded too late, after + # C_LOOP_END had already read garbage). If every in-use reg is + # pinned, `_auto_spill` itself raises RegisterExhausted -- that + # is the kernel-author's signal to convert one of the outer + # `for` loops to `T.unroll(...)` (which doesn't pin + # gp_loop+gp_idx) so non-loop work has room to spill. + if n > len(self._gp_free): + self._auto_spill(n - len(self._gp_free)) + out = self._gp_free[:n] + self._gp_free = self._gp_free[n:] + self._gp_in_use.extend(out) + self._record("allocate_gp", n=n, regs=",".join(str(r) for r in out)) + return out + + def pin_gp(self, reg: int) -> None: + """Mark ``reg`` as carrying a long-lived value (loop index, + symbol-table binding etc.) so ``_auto_spill`` never picks it + as a spill candidate. Spilling such a register would silently + corrupt the binding because the materializer reads it via + the symbol table without going through ``free_gp``.""" + self._pinned_gp.add(reg) + self._record("pin_gp", regs=str(reg)) + + def unpin_gp(self, reg: int) -> None: + self._pinned_gp.discard(reg) + self._record("unpin_gp", regs=str(reg)) + + def _auto_spill(self, need: int) -> None: + """Free up ``need`` more GPs by spilling the most-recently + allocated in-use ones to IntRAM. Each spilled GP is recorded + keyed by its (new) register number so the matching ``free_gp`` + triggers a reload. + + Pinned GPs (loop indices etc.) are never spilled because the + symbol table refers to them by register number without going + through ``free_gp`` — spilling them would silently corrupt + the value the materializer reads back.""" + if self.compiler is None: + raise RegisterExhausted( + f"auto-spill required ({need} more GP) but no compiler " + f"bound on the allocator; call shim.compiler.bind_allocator() " + f"or pass `compiler=` to the RegisterAllocator constructor" + ) + candidates: List[int] = [] + for r in reversed(self._gp_in_use): + if r in self._pinned_gp: + continue + candidates.append(r) + if len(candidates) == need: + break + if len(candidates) < need: + raise RegisterExhausted( + f"auto-spill: need {need} GPs but only {len(candidates)} " + f"unpinned in-use to spill; in_use={self._gp_in_use!r} " + f"pinned={sorted(self._pinned_gp)!r}. " + f"Hint: every in-use GP is pinned (typically by nested " + f"`for` loops reserving gp_loop+gp_idx each). Convert one " + f"of the outer loops to `T.unroll(...)` so it doesn't " + f"pin two regs, leaving room for non-loop work to spill." + ) + for r in candidates: + slot = self._claim_spill_slot() + addr = SPILL_BASE + slot + self.compiler.generated_code += ( + f"; auto-spill gp{r} -> intram[{addr}]\n" + f"S_ST_INT gp{r}, gp0, {addr}\n" + ) + self._gp_in_use.remove(r) + self._gp_free.insert(0, r) + # Record the spill keyed by the register number — when the + # caller frees this register later we use the same key to + # reload its contents into the same physical GP. + self._auto_spills_by_borrow[r] = _SpillRecord(orig_reg=r, slot=slot) + self._record( + "auto_spill", regs=str(r), slot=slot, addr=addr, + ) + + def free_gp(self, regs: Iterable[int]) -> None: + # Push back at the front to maximise locality (next alloc reuses + # the same register, keeping the live range short and dump + # human-readable). If this reg was auto-spilled earlier, emit a + # reload first so the original outer-scope content is restored + # before anyone re-allocates the register. + for r in regs: + if r in self._gp_free: + # Idempotent free: some callers (e.g. ExprMaterializer + # tracking both ``register`` and ``intermediates``) may + # release the same register twice when constant-folding + # collapsed an intermediate onto the final result reg. + # Tolerate it instead of crashing. + self._record("free_gp_noop", regs=str(r)) + continue + if r in self._gp_in_use: + self._gp_in_use.remove(r) + rec = self._auto_spills_by_borrow.pop(r, None) + if rec is not None: + addr = SPILL_BASE + rec.slot + if self.compiler is not None: + self.compiler.generated_code += ( + f"; auto-reload gp{r} <- intram[{addr}]\n" + f"S_LD_INT gp{r}, gp0, {addr}\n" + ) + self._release_spill_slot(rec.slot) + # Reg goes back to in-use (its outer-scope owner still + # holds it). Don't push to free pool. + self._gp_in_use.append(r) + self._record( + "auto_reload", regs=str(r), slot=rec.slot, addr=addr, + ) + continue + self._gp_free.insert(0, r) + self._record("free_gp", regs=str(r)) + + # ------------------------------------------------------------------ + # Spill-borrow API + # ------------------------------------------------------------------ + def spill_borrow( + self, + n: int, + *, + compiler, + protect: Optional[Iterable[int]] = None, + ) -> Tuple[List[int], BorrowToken]: + """Borrow ``n`` GP registers, spilling currently-allocated ones + to IntRAM if necessary. Emits ``S_ST_INT`` lines into + ``compiler.generated_code`` for every spilled GP. Returns + ``(borrowed, token)`` — pass ``token`` back to ``spill_return`` + to restore the spilled state. + + ``protect`` is a set of currently-in-use GPs the caller still + needs to read inside the borrow scope — they are excluded from + spill candidates and will trigger ``RegisterExhausted`` if + spilling them is the only way to satisfy the request. + """ + protect_set = set(protect or ()) + protect_set.discard(0) # gp0 reserved-zero is never spillable anyway + + need = n - len(self._gp_free) + spilled: List[_SpillRecord] = [] + if need > 0: + candidates: List[int] = [] + for r in reversed(self._gp_in_use): + if r in protect_set: + continue + # Pinned GPs (loop hw counters, long-lived symbol-table + # bindings) are referenced by register number without + # going through free_gp, so spilling them would silently + # corrupt the value. Matches the _auto_spill filter. + if r in self._pinned_gp: + continue + candidates.append(r) + if len(candidates) == need: + break + if len(candidates) < need: + raise RegisterExhausted( + f"spill_borrow: need to spill {need} GP(s) but only " + f"{len(candidates)} are unprotected; in_use=" + f"{self._gp_in_use!r} protect={sorted(protect_set)!r} " + f"pinned={sorted(self._pinned_gp)!r}" + ) + for r in candidates: + slot = self._claim_spill_slot() + addr = SPILL_BASE + slot + compiler.generated_code += ( + f"; spill gp{r} -> intram[{addr}]\n" + f"S_ST_INT gp{r}, gp0, {addr}\n" + ) + spilled.append(_SpillRecord(orig_reg=r, slot=slot)) + self._gp_in_use.remove(r) + self._gp_free.insert(0, r) + self._record( + "borrow_spill", regs=str(r), slot=slot, addr=addr, + ) + + borrowed = self.allocate_gp(n) + self._record("spill_borrow", n=n, + regs=",".join(str(r) for r in borrowed), + spilled=",".join(str(s.orig_reg) for s in spilled)) + return borrowed, BorrowToken(borrowed=borrowed, spilled=spilled) + + def spill_return(self, token: BorrowToken, *, compiler) -> None: + """End a borrow scope: free the borrowed GPs, re-allocate the + previously spilled GPs at their original register numbers, and + emit ``S_LD_INT`` to restore their contents from IntRAM.""" + self.free_gp(token.borrowed) + for rec in token.spilled: + if rec.orig_reg in self._gp_free: + self._gp_free.remove(rec.orig_reg) + else: + raise RuntimeError( + f"spill_return: expected gp{rec.orig_reg} to be free " + f"(no one else may take it during a borrow scope), but " + f"free pool is {self._gp_free!r}" + ) + self._gp_in_use.append(rec.orig_reg) + addr = SPILL_BASE + rec.slot + compiler.generated_code += ( + f"; reload gp{rec.orig_reg} <- intram[{addr}]\n" + f"S_LD_INT gp{rec.orig_reg}, gp0, {addr}\n" + ) + self._release_spill_slot(rec.slot) + self._record( + "borrow_reload", regs=str(rec.orig_reg), + slot=rec.slot, addr=addr, + ) + + def _claim_spill_slot(self) -> int: + for i, used in enumerate(self._spill_slots_in_use): + if not used: + self._spill_slots_in_use[i] = True + return i + raise RegisterExhausted( + f"spill slots exhausted ({SPILL_SLOTS} used). Bump SPILL_SLOTS " + f"or reduce simultaneous register pressure." + ) + + def _release_spill_slot(self, slot: int) -> None: + if not self._spill_slots_in_use[slot]: + raise RuntimeError(f"double-release of spill slot {slot}") + self._spill_slots_in_use[slot] = False + + # ------------------------------------------------------------------ + # Loop idx slot pool (IntRAM-backed loop indices). Disjoint from + # spill slots so a body's GP spill can't clobber an outer loop's idx. + # ------------------------------------------------------------------ + def claim_idx_slot(self) -> int: + """Allocate an IntRAM word for a loop's idx. Returns the + absolute IntRAM address (suitable for `S_LD_INT gp, gp0, addr`). + """ + for i, used in enumerate(self._idx_slots_in_use): + if not used: + self._idx_slots_in_use[i] = True + self._record("claim_idx_slot", slot=i, addr=IDX_BASE + i) + return IDX_BASE + i + raise RegisterExhausted( + f"idx slots exhausted ({IDX_SLOTS} used). Bump IDX_SLOTS or " + f"reduce simultaneous loop nesting depth." + ) + + def release_idx_slot(self, addr: int) -> None: + slot = addr - IDX_BASE + if not (0 <= slot < IDX_SLOTS): + raise RuntimeError(f"release_idx_slot: addr {addr} out of range") + if not self._idx_slots_in_use[slot]: + raise RuntimeError(f"double-release of idx slot {slot}") + self._idx_slots_in_use[slot] = False + self._record("release_idx_slot", slot=slot, addr=addr) + + # ------------------------------------------------------------------ + # Address register pool + # ------------------------------------------------------------------ + def allocate_addr(self, n: int) -> List[int]: + if n > len(self._addr_free): + raise RegisterExhausted( + f"requested {n} addr registers but only {len(self._addr_free)} free" + ) + out = self._addr_free[:n] + self._addr_free = self._addr_free[n:] + self._record( + "allocate_addr", n=n, regs=",".join(f"a{r}" for r in out), + ) + return out + + def free_addr(self, regs: Iterable[int]) -> None: + for r in regs: + if r in self._addr_free: + raise RuntimeError(f"double-free of a{r}") + self._addr_free.insert(0, r) + self._record("free_addr", regs=f"a{r}") + + +__all__ = [ + "RegisterAllocator", + "RegisterExhausted", + "BorrowToken", + "SPILL_BASE", + "SPILL_SLOTS", +] diff --git a/tilelang_tvm_compiler/scope.py b/tilelang_tvm_compiler/scope.py new file mode 100644 index 0000000..ec337db --- /dev/null +++ b/tilelang_tvm_compiler/scope.py @@ -0,0 +1,67 @@ +"""PLENA storage scopes. + +In TVM TIR, a buffer's "storage scope" is just a string attached to the +buffer. We pick a fixed vocabulary here so different parts of the compiler +agree on which physical memory each buffer lives in. + +Physical scope vocabulary (mirrors PLENA hardware): + HBM -- main DRAM, source/sink for DMA + VRAM -- vector SRAM, LHS operand of BTMM/MM and target of vector ops + MRAM -- matrix SRAM, RHS operand of BTMM/MM + FPRAM -- on-chip FP buffer (small, used for staging) + +Block-private declared scopes (tilelang DSL surface — these participate in +the compiler's lane-fusion expansion via `allocate_group_memory`): + shared.dyn -- T.alloc_shared default; resolved to vram / mram by + scope_inference based on usage + local.fragment -- T.alloc_fragment default; resolved to vram / fpram + +Global declared scopes (user-declared, authoritative — these do NOT +participate in lane-fusion expansion; their shape IS their physical layout): + global.vram, global.fpram, global.mram + +`global.*` is for SRAM tensors that outlive the kernel — typically a buffer +that the testbench preloads before the kernel runs (e.g. weights into FPRAM) +or reads directly after the kernel finishes (e.g. an output cache in VRAM). +The user writes the physical shape they want; the compiler keeps it as-is. + +PrimFunc parameters (function arguments) are treated as HBM by default. +""" + +HBM = "hbm" +VRAM = "vram" +MRAM = "mram" +FPRAM = "fpram" + +PHYSICAL_SCOPES = (HBM, VRAM, MRAM, FPRAM) + +GLOBAL_PREFIX = "global." +GLOBAL_VRAM = GLOBAL_PREFIX + VRAM +GLOBAL_FPRAM = GLOBAL_PREFIX + FPRAM +GLOBAL_MRAM = GLOBAL_PREFIX + MRAM + +GLOBAL_SCOPES = (GLOBAL_VRAM, GLOBAL_FPRAM, GLOBAL_MRAM) + +# All scope strings the compiler treats as final answers (no inference). +ALL_SCOPES = PHYSICAL_SCOPES + GLOBAL_SCOPES + + +def is_known(scope: str) -> bool: + return scope in ALL_SCOPES + + +def is_global_scope(scope: str) -> bool: + """True for user-declared `global.` scopes that bypass lane-fusion + expansion. Buffers with these scopes carry their physical layout in their + declared shape and must not be rewritten by `allocate_group_memory`.""" + return scope.startswith(GLOBAL_PREFIX) and scope in GLOBAL_SCOPES + + +def physical_scope(scope: str) -> str: + """Strip the `global.` prefix if present. Downstream passes that only + care about *where* a buffer lives in hardware (address allocation, + codegen, ISA emit) can call this to collapse `global.vram` and `vram` + to the same answer (`vram`).""" + if scope.startswith(GLOBAL_PREFIX): + return scope[len(GLOBAL_PREFIX):] + return scope diff --git a/tilelang_tvm_compiler/scripts/__init__.py b/tilelang_tvm_compiler/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tilelang_tvm_compiler/scripts/run_flash_attention_midir.py b/tilelang_tvm_compiler/scripts/run_flash_attention_midir.py new file mode 100644 index 0000000..6a504ff --- /dev/null +++ b/tilelang_tvm_compiler/scripts/run_flash_attention_midir.py @@ -0,0 +1,201 @@ +"""Run flash_attention_min through the new mid_ir pipeline end-to-end +and print the resulting HLIR. + +Usage: + nix develop --command bash -c ' + PYTHONPATH=compiler .venv/bin/python -m \\ + tilelang_tvm_compiler.scripts.run_flash_attention_midir + ' + +Or with the .venv-tvm Python — but that one doesn't have tilelang +installed today (env-split issue documented elsewhere). Pick whichever +venv has tilelang + tvm both available. + +Output: + * /.midir.txt (mid_ir snapshot before lowering) + * stdout: formatted HLIR module +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +# Local helper: build a raw PrimFunc straight from the kernel source, +# bypassing the legacy compile_func wrapper (which would run the old +# frontend pipeline). We monkey-import the kernel module and grab its +# T.prim_func before make_*_min calls compile_func on it. + +import tilelang.language as T + +# Pull KIND constant + recreate the kernel literally so we get the raw +# PrimFunc (instead of the post-frontend lowered one make_flash_attention_min +# returns). +from tilelang_tvm_compiler.frontend.gemm_macros import KIND + + +def build_raw_flash_attention_min(*, + rows: int = 64, + hlen: int = 16, + head_count: int = 4, + num_kv_blocks: int = 2, + num_q_blocks: int = 2): + """Mirror of make_flash_attention_min in kernels/flash_attention_min.py + but stops *before* compile_func — returns the raw tir.PrimFunc.""" + + MLEN = 64 + if rows != MLEN: + raise ValueError(f"rows must == MLEN={MLEN}, got {rows}") + if MLEN % hlen != 0: + raise ValueError(f"hlen must divide MLEN={MLEN}, got {hlen}") + hardware_lane_count = MLEN // hlen + if head_count % hardware_lane_count != 0: + raise ValueError( + f"head_count must be multiple of MLEN/hlen={hardware_lane_count}" + ) + + kv_seq = num_kv_blocks * rows + q_seq = num_q_blocks * rows + + @T.prim_func + def flash_attention_min( + Q_hbm: T.Tensor((1, q_seq, head_count, hlen), "float16"), + K_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + V_hbm: T.Tensor((1, kv_seq, head_count, hlen), "float16"), + O_hbm: T.Tensor((1, q_seq, head_count, hlen), "float16"), + ): + with T.Kernel(num_q_blocks, head_count, threads=128) as (q_block, by): + Q_sh = T.alloc_shared((rows, hlen), "float16") + K_sh = T.alloc_shared((rows, hlen), "float16") + V_sh = T.alloc_shared((rows, hlen), "float16") + PV_loc = T.alloc_fragment((rows, hlen), "float16") + O_loc = T.alloc_fragment((rows, hlen), "float16") + S_loc = T.alloc_fragment((rows, MLEN), "float16") + M_OLD = T.alloc_fragment((rows,), "float16") + M_CURR = T.alloc_fragment((rows,), "float16") + M_RES = T.alloc_fragment((rows,), "float16") + L_OLD = T.alloc_fragment((rows,), "float16") + L_NEW = T.alloc_fragment((rows,), "float16") + P_SUM = T.alloc_fragment((rows,), "float16") + SCALE = T.alloc_fragment((rows,), "float16") + L_INV = T.alloc_fragment((rows,), "float16") + M_INIT = T.alloc_fragment((rows,), "float16") + L_INIT = T.alloc_fragment((rows,), "float16") + + T.copy(Q_hbm[0, q_block * rows, by, 0], Q_sh) + + for row in T.serial(rows): + for col in T.Parallel(hlen): + O_loc[row, col] = T.float16(0) + + for row in T.serial(rows): + M_OLD[row] = M_INIT[row] + L_OLD[row] = L_INIT[row] + + for kv_block in T.unroll(num_kv_blocks): + T.copy(K_hbm[0, kv_block * rows, by, 0], K_sh) + T.copy(V_hbm[0, kv_block * rows, by, 0], V_sh) + + with T.attr(0, KIND, "btmm"): + T.gemm(Q_sh, K_sh, S_loc, transpose_B=True) + + for row in T.serial(rows): + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] * SCALE[row] + M_CURR[row] = M_OLD[row] + + T.reduce_max(S_loc, M_CURR, dim=1, clear=False) + + for row in T.serial(rows): + M_RES[row] = M_OLD[row] - M_CURR[row] + M_RES[row] = T.exp(M_RES[row]) + for col in T.Parallel(MLEN): + S_loc[row, col] = S_loc[row, col] - M_CURR[row] + for col in T.Parallel(MLEN): + S_loc[row, col] = T.exp(S_loc[row, col]) + P_SUM[row] = L_INIT[row] + + T.reduce_sum(S_loc, P_SUM, dim=1, clear=False) + + for row in T.serial(rows): + L_NEW[row] = L_OLD[row] * M_RES[row] + L_NEW[row] = L_NEW[row] + P_SUM[row] + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * M_RES[row] + M_OLD[row] = M_CURR[row] + L_OLD[row] = L_NEW[row] + + T.gemm(S_loc, V_sh, PV_loc) + + for row in T.serial(rows): + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] + PV_loc[row, col] + + for row in T.serial(rows): + L_INV[row] = 1.0 / L_NEW[row] + for col in T.Parallel(hlen): + O_loc[row, col] = O_loc[row, col] * L_INV[row] + + T.copy(O_loc, O_hbm[0, q_block * rows, by, 0]) + + return flash_attention_min + + +def main(argv) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--build-dir", type=Path, default=None, + help="Where to dump .midir.txt (default: skip)") + parser.add_argument("--num-q-blocks", type=int, default=2) + parser.add_argument("--num-kv-blocks", type=int, default=2) + parser.add_argument("--head-count", type=int, default=4) + args = parser.parse_args(argv) + + raw = build_raw_flash_attention_min( + num_q_blocks=args.num_q_blocks, + num_kv_blocks=args.num_kv_blocks, + head_count=args.head_count, + ) + + # Run the legacy stmt-prep (inline_let_stmts + lower_compound_fp_stores) + # — these ARE pre-fold steps, not part of the new mid_ir pipeline. + from tilelang_tvm_compiler.frontend.passes import inline_let_stmts + from tilelang_tvm_compiler.frontend.passes import lower_compound_fp_stores + raw = inline_let_stmts.run(raw) + raw = lower_compound_fp_stores.run(raw) + + # Mid_ir pipeline. + from tilelang_tvm_compiler.frontend.mid_ir.passes.infer_lane_axis import run as infer_run + from tilelang_tvm_compiler.frontend.mid_ir.passes.fold import run as fold_run + from tilelang_tvm_compiler.frontend.mid_ir.passes.mark import run as mark_run + from tilelang_tvm_compiler.frontend.mid_ir.passes.split import run as split_run + from tilelang_tvm_compiler.frontend.mid_ir.passes.distribute_cluster import run as dist_run + from tilelang_tvm_compiler.frontend.mid_ir.passes.async_wrap import run as async_run + from tilelang_tvm_compiler.frontend.mid_ir.passes.view import run as view_run + from tilelang_tvm_compiler.frontend.mid_ir.passes.fuse import run as fuse_run + from tilelang_tvm_compiler.frontend.mid_ir.passes.burn_view import run as burn_run + from tilelang_tvm_compiler.frontend.mid_ir.passes.to_plena import run as to_plena_run + + raw = infer_run(raw) + print(f"[infer_lane_axis] picked: " + f"{raw.attrs['plena.lane_axis'] if raw.attrs and 'plena.lane_axis' in raw.attrs else None}", + file=sys.stderr) + + midfn = fold_run(raw, name="flash_attention_min") + midfn = mark_run(midfn) + midfn = split_run(midfn) + midfn = dist_run(midfn) + midfn = async_run(midfn) + midfn = view_run(midfn) + midfn = fuse_run(midfn) + midfn = burn_run(midfn) + + hlir = to_plena_run(midfn, build_dir=args.build_dir) + + from tilelang_tvm_compiler.hlir import format_hlir + print(format_hlir(hlir)) + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/tilelang_tvm_compiler/test_helper.py b/tilelang_tvm_compiler/test_helper.py new file mode 100644 index 0000000..29dd29b --- /dev/null +++ b/tilelang_tvm_compiler/test_helper.py @@ -0,0 +1,618 @@ +"""TVM testbench harness — single entry point for tvm_*_test.py drivers. + +Each ``transactional_emulator/testbench/tvm__test.py`` had grown +into 200-370 lines of mostly-identical boilerplate (subprocess into the +TVM venv, write .pt files, call create_sim_env / create_mem_for_sim, +write comparison_params.json). This helper collapses the shared flow +into a single ``run(spec)`` call. Per-kernel code shrinks to: + + * a few constants (MLEN, HLEN, ...) + * a ``build_inputs_and_golden(seed)`` function + * an optional ``build_fp_preload`` and/or ``build_pre_kernel_stub`` + * a ``build_comparison_params`` function (or static dict) + * a single ``TvmTestbenchSpec(...)`` and ``run(spec)`` call + +The helper itself does not import torch eagerly because this file lives +in the compiler tree, which is loaded under multiple Python venvs. +Torch is imported inside ``run()`` where the testbench's own venv has +already set ``sys.path`` for it. + +Pipeline (in order): + + 1. Subprocess into the TVM venv to compile TIR -> PLENA ISA text, + optionally dumping HLIR and the buffer-address JSON. + 2. If ``parse_buffer_addrs`` was given, parse the JSON into the + address dict the per-kernel hooks expect. + 3. If ``build_pre_kernel_stub`` was given, prepend its output to the + kernel ISA. Used by conv2d_min / flash_decode_min for the FPRAM + staging -> VRAM cache copy that has to happen before the kernel + proper. + 4. ``build_inputs_and_golden(seed)`` produces: + - ``hbm_inputs``: dict[name -> torch.Tensor] for HBM staging + - ``golden_flat``: 2D flat golden the comparator diffs against + - any extras (e.g. ``q_token`` for flash_decode) consumed by + ``build_fp_preload`` + 5. If ``build_fp_preload`` was given, it returns the FP-preload + tensor positioned at the right FPRAM addresses. + 6. ``create_sim_env`` writes .pt / .asm / fp_sram.bin / int_sram.bin + and the golden file. ``create_mem_for_sim`` assembles + packs HBM. + 7. Write ``comparison_params.json`` so view_mem.py knows where to + read the staged output. + +The helper does not call cargo / view_mem itself — those are still +driven by `just build-emulator-debug `. This file only produces +the artefact set in ``build/``. +""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Mapping, Optional + + +# --------------------------------------------------------------------------- +# Repo discovery. The helper lives at: +# /compiler/tilelang_tvm_compiler/test_helper.py +# so the repo root is two ``parent`` hops up. +# --------------------------------------------------------------------------- +_THIS_FILE = Path(__file__).resolve() +REPO_ROOT = _THIS_FILE.parent.parent.parent +TESTBENCH_DIR = REPO_ROOT / "transactional_emulator" / "testbench" + +# Default LD_LIBRARY_PATH for the ``.venv`` (Python 3.12) flow — torch +# loaded from that venv requires this Nix-provided libstdc++. The older +# ``.venv-tvm`` (Python 3.11, TVM-only) flow needs LD_LIBRARY_PATH="". +DEFAULT_LD_LIBRARY_PATH = "/nix/store/si4q3zks5mn5jhzzyri9hhd3cv789vlm-gcc-15.2.0-lib/lib" + + +# --------------------------------------------------------------------------- +# Spec +# --------------------------------------------------------------------------- + +# Input-build hook contract: +# def build_inputs_and_golden(seed: int) -> dict +# Required keys in the returned dict: +# - "hbm_inputs": dict[str, torch.Tensor] # buffers to stage into HBM +# - "golden_flat": torch.Tensor # 2D flat golden (rows, MLEN-aligned cols) +# Any other keys are kernel-specific extras (e.g. ``q_token`` for +# flash_decode_min) and are forwarded to ``build_fp_preload`` unchanged. +InputsBuilder = Callable[[int], dict] + +# Buffer-addresses parser: +# def parse_buffer_addrs(raw_json_dict: dict) -> dict +# Receives the raw output of ``--dump-buffer-addrs`` (each entry has +# ``scope``, ``address``, ``shape``, ``dtype``). Returns whatever shape +# the per-kernel hooks find convenient. +BufferAddrsParser = Callable[[dict], dict] + +# Pre-kernel ASM stub (concatenated BEFORE the compiled kernel ISA): +# def build_pre_kernel_stub(addrs: dict) -> str +PreKernelStubBuilder = Callable[[dict], str] + +# FP preload builder: +# def build_fp_preload(io: dict, addrs: dict) -> torch.Tensor +# ``io`` is the dict returned by ``build_inputs_and_golden``. +FpPreloadBuilder = Callable[[dict, dict], Any] # Any to avoid eager torch import + +# Comparison-params builder: +# def build_comparison_params(io: dict, addrs: dict) -> dict +# Receives the same ``io`` (so it can read shapes off ``hbm_inputs`` / +# ``golden_flat``) and the parsed addrs (some kernels need an O_CACHE +# address for ``start_row_idx``). +ComparisonParamsBuilder = Callable[[dict, dict], dict] + + +@dataclass +class TvmTestbenchSpec: + """Everything one ``tvm__test.py`` needs to declare.""" + + # ---- identity ---- + asm_name: str + """Used for the .asm filename, log messages, and (after the helper + runs) ``{asm_name}_generated_asm_code.asm`` in build/.""" + + kernel: str + """Kernel spec passed to ``tilelang_tvm_compiler compile --kernel``, + e.g. ``"tilelang_tvm_compiler.kernels.conv2d_min:make_conv2d_min"``.""" + + build_inputs_and_golden: InputsBuilder + """See ``InputsBuilder`` above.""" + + build_comparison_params: ComparisonParamsBuilder + """See ``ComparisonParamsBuilder`` above.""" + + # ---- compile-time tuneables ---- + kernel_kwargs: Mapping[str, Any] = field(default_factory=dict) + """k=v pairs forwarded as ``--kernel-kwargs k1=v1,k2=v2,...``.""" + + mlen: int = 64 + btmm_hlen: Optional[int] = None + btmm_lane_count: Optional[int] = None + + stage_output: Optional[str] = None + """Buffer name to re-stage from HBM -> VRAM at the end of the + kernel for view_mem comparison (passed via ``--stage-output``).""" + + use_v2: bool = False + """Route compilation through the PreIsaPassV2 → MIR → ISA path + (passed via ``--use-v2``) instead of the legacy single-pass + IsaEmitterPass. Same HW op stream; v2 also runs the MIR opt + passes (LICM / reassoc / spill). Enable to validate the v2 + backend end-to-end against the simulator golden.""" + + artifact_prefix: Optional[str] = None + """Prefix for ancillary build artefacts. Defaults to ``asm_name``.""" + + # ---- venv / subprocess env ---- + venv_name: str = ".venv" + """Subdir of the repo root containing the Python venv used to invoke + the compiler. ``.venv`` (Python 3.12, the new default) or the legacy + ``.venv-tvm`` (Python 3.11, TVM-wheel-only).""" + + ld_library_path: Optional[str] = DEFAULT_LD_LIBRARY_PATH + """Forwarded as the subprocess's ``LD_LIBRARY_PATH``. Pass ``""`` to + explicitly clear it (the ``.venv-tvm`` convention) or ``None`` to + inherit from the parent process unchanged.""" + + # ---- buffer-addrs JSON ---- + parse_buffer_addrs: Optional[BufferAddrsParser] = None + """If given, the helper passes ``--dump-buffer-addrs`` to the + compiler, then calls this function with the parsed JSON. The result + is forwarded to ``build_pre_kernel_stub`` / + ``build_fp_preload`` / ``build_comparison_params``. If omitted, the + helper still passes an empty dict to the hooks, so kernels that + don't need address introspection don't pay for it.""" + + # ---- optional kernel hooks ---- + build_pre_kernel_stub: Optional[PreKernelStubBuilder] = None + build_fp_preload: Optional[FpPreloadBuilder] = None + patch_isa: Optional[Any] = None + """Optional last-mile rewrite hook over the assembled ASM text. + + Signature: ``(isa_text: str) -> str``. Called after the compile + subprocess emits the kernel ASM and (if any) the pre-kernel stub + is prepended, but before ``create_sim_env`` writes it to disk. + Used by debug step kernels to patch a single instruction's + operand (e.g. flip ``V_RED_SUM ..., 1`` to ``..., 0`` to see + what the unmasked path produces) without touching the compiler. + """ + int_preload: Any = None + """Static int-preload tensor (sim_env_utils takes it through). None + for everything we have today.""" + + # ---- misc ---- + seed: int = 0 + """Forwarded to ``build_inputs_and_golden``.""" + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _format_kwargs(kwargs: Mapping[str, Any]) -> str: + return ",".join(f"{k}={v}" for k, v in kwargs.items()) + + +def _compile_via_subprocess( + spec: TvmTestbenchSpec, + *, + hlir_path: Path, + addrs_path: Optional[Path], +) -> str: + """Subprocess into the TVM venv to compile the kernel. + + Returns the kernel's ISA text (stdout). Raises ``RuntimeError`` + with the captured stderr on failure. + """ + venv_python = REPO_ROOT / spec.venv_name / "bin" / "python" + if not venv_python.exists(): + raise RuntimeError( + f"venv python not found: {venv_python}. Set " + f"TvmTestbenchSpec.venv_name to a venv that exists." + ) + cmd = [ + str(venv_python), "-m", "tilelang_tvm_compiler", "compile", + "--kernel", spec.kernel, + "--asm-name", spec.asm_name, + "--mlen", str(spec.mlen), + ] + if spec.kernel_kwargs: + cmd += ["--kernel-kwargs", _format_kwargs(spec.kernel_kwargs)] + if spec.btmm_lane_count is not None: + cmd += ["--btmm-lane-count", str(spec.btmm_lane_count)] + if spec.btmm_hlen is not None: + cmd += ["--btmm-hlen", str(spec.btmm_hlen)] + if spec.stage_output is not None: + cmd += ["--stage-output", spec.stage_output] + if spec.use_v2: + cmd += ["--use-v2"] + cmd += ["--dump-hlir", str(hlir_path)] + if addrs_path is not None: + cmd += ["--dump-buffer-addrs", str(addrs_path)] + + env = os.environ.copy() + if spec.ld_library_path is not None: + env["LD_LIBRARY_PATH"] = spec.ld_library_path + env["PYTHONPATH"] = str(REPO_ROOT / "compiler") + + res = subprocess.run(cmd, env=env, capture_output=True, text=True) + if res.returncode != 0: + sys.stderr.write(res.stderr) + raise RuntimeError( + f"TVM compile subprocess failed (returncode={res.returncode}). " + f"See stderr above. Command: {' '.join(cmd)}" + ) + return res.stdout + + +def _validate_io(io: dict) -> None: + if not isinstance(io, dict): + raise TypeError( + f"build_inputs_and_golden must return a dict; got " + f"{type(io).__name__}" + ) + missing = {"hbm_inputs", "golden_flat"} - set(io) + if missing: + raise KeyError( + f"build_inputs_and_golden return dict is missing required keys: " + f"{sorted(missing)} (must include 'hbm_inputs' and 'golden_flat')" + ) + if not isinstance(io["hbm_inputs"], dict): + raise TypeError( + f"build_inputs_and_golden['hbm_inputs'] must be a dict, got " + f"{type(io['hbm_inputs']).__name__}" + ) + + +# --------------------------------------------------------------------------- +# Canonical output layout — the SINGLE source of truth for how a kernel's +# logical (B, S, H, D) output maps to its physical VRAM staging layout, +# and therefore how golden must be flattened and how the comparator must +# reorder the staged VRAM dump before diffing. +# +# Why this exists: golden-flatten order, --stage-output VRAM layout, and +# check_mem's reorder were each hand-coded per kernel, each carrying its +# own (often wrong) layout assumption. Every disagreement showed up as a +# "compare fails" mystery. Route ALL THREE through this one function so +# they cannot drift. +# +# The layout facts (must match the compiler's TileLayout / --stage-output +# and check_mem.reorder_stride_mode): +# +# * One mlen-wide VRAM tile packs LANE_COUNT = mlen // hlen heads +# side by side (each head occupies hlen columns). +# * A logical output with H heads therefore spans +# H_GROUPS = ceil(H / LANE_COUNT) head-groups. +# * --stage-output loads O_hbm back into VRAM head-group-major: +# head-group 0 for ALL S rows first, then head-group 1, ... +# i.e. physical order [h_group][s][lane][d]. +# * golden_flat is batch-major [s][h][d] (one row's H heads +# contiguous). These two are the same data in different +# permutations. +# * check_mem.reorder_stride_mode converts head-group-major chunks +# back to batch-major exactly when chunks_per_batch > 1 (i.e. when +# H_GROUPS > 1). For H_GROUPS == 1 there is no interleaving and the +# reorder must stay off. +# --------------------------------------------------------------------------- + +# Every kernel output, whatever its logical rank, ultimately stages +# into VRAM as a 2D grid: num_batches logical rows, each +# elements_per_batch wide. A row wider than MLEN occupies +# CHUNKS_PER_BATCH = ceil(elements_per_batch / mlen) mlen-wide physical +# chunks, and --stage-output writes those chunk-group-major (chunk 0 of +# every row first, then chunk 1 of every row, ...). golden_flat is +# batch-major (each row's full width contiguous). When CHUNKS_PER_BATCH +# > 1 the two orders differ and the comparator must reorder; that is +# precisely what check_mem.reorder_stride_mode does, so use_stride_mode +# is just CHUNKS_PER_BATCH > 1. +# +# BSHD (B,S,H,D) outputs are the common special case: rows = B*S, +# cols = H*D, and chunk groups coincide with head-groups +# (LANE_COUNT = mlen//hlen heads per mlen tile). conv NCHW outputs and +# plain (M,N) matmul outputs are the same 2D grid with a different +# logical-shape story — :func:`resolve_output_layout` accepts any of +# them and they all funnel through the identical 2D math here. + +@dataclass(frozen=True) +class OutputLayout: + """Resolved 2D staging layout for one kernel output. + + ``num_batches`` logical rows, ``elements_per_batch`` values each, + physically chunked into ``mlen``-wide pieces.""" + num_batches: int + elements_per_batch: int + mlen: int + + @property + def chunks_per_batch(self) -> int: + """mlen-wide physical chunks one logical row spans.""" + return (self.elements_per_batch + self.mlen - 1) // self.mlen + + @property + def use_stride_mode(self) -> bool: + """True iff the staged VRAM is chunk-group-major and must be + reordered back to batch-major before diffing against golden.""" + return self.chunks_per_batch > 1 + + def comparison_params(self) -> dict: + """The geometry block check_mem / view_mem need. Merge this into + whatever kernel-specific keys (check_hbm, start_row_idx, ...) + the SPEC still wants to set.""" + return { + "num_rows": self.num_batches * self.chunks_per_batch, + "num_batches": self.num_batches, + "elements_per_batch": self.elements_per_batch, + "row_dim": self.mlen, + "use_stride_mode": self.use_stride_mode, + } + + def flatten_golden(self, out): + """Flatten a golden tensor of any rank into the canonical 2D + ``golden_flat``: ``(num_batches, elements_per_batch)``, + batch-major. The comparator reorders the VRAM side to match + this — never the reverse. Raises if the tensor's element count + doesn't match ``num_batches * elements_per_batch``.""" + want = self.num_batches * self.elements_per_batch + got = 1 + for dim in out.shape: + got *= int(dim) + if got != want: + raise ValueError( + f"flatten_golden: golden has {got} elements but layout " + f"expects num_batches*elements_per_batch = " + f"{self.num_batches}*{self.elements_per_batch} = {want} " + f"(golden shape {tuple(out.shape)})" + ) + return out.reshape(self.num_batches, self.elements_per_batch) + + +def resolve_output_layout( + *, + mlen: int, + # --- form 1: BSHD output --- + b: Optional[int] = None, + s: Optional[int] = None, + h: Optional[int] = None, + d: Optional[int] = None, + hlen: Optional[int] = None, + # --- form 2: explicit 2D output --- + num_batches: Optional[int] = None, + elements_per_batch: Optional[int] = None, +) -> OutputLayout: + """Build the canonical :class:`OutputLayout`. + + Two calling forms — pick the one matching the kernel's output: + + * BSHD: ``resolve_output_layout(b=, s=, h=, d=, mlen=, hlen=)`` + rows = b*s, cols = h*d. ``hlen`` is checked against ``mlen`` + for the lane-packing invariant (``h`` must be a multiple of + ``mlen//hlen``) so a mis-shaped head count fails loudly here. + + * explicit 2D: ``resolve_output_layout(num_batches=, + elements_per_batch=, mlen=)`` — for matmul ``(M,N)``, conv + ``(C_OUT*H, W)`` and anything else that isn't naturally BSHD. + + Whichever form, the result drives BOTH ``flatten_golden`` and + ``comparison_params`` so the golden order and the comparator's + reorder agree by construction. + """ + if mlen <= 0: + raise ValueError(f"resolve_output_layout: mlen must be > 0; got {mlen}") + + bshd_given = any(v is not None for v in (b, s, h, d, hlen)) + twod_given = any(v is not None for v in (num_batches, elements_per_batch)) + if bshd_given and twod_given: + raise ValueError( + "resolve_output_layout: pass EITHER the BSHD form " + "(b/s/h/d/hlen) OR the explicit 2D form " + "(num_batches/elements_per_batch), not both" + ) + if twod_given: + if num_batches is None or elements_per_batch is None: + raise ValueError( + "resolve_output_layout 2D form needs both num_batches " + "and elements_per_batch" + ) + return OutputLayout( + num_batches=int(num_batches), + elements_per_batch=int(elements_per_batch), + mlen=int(mlen), + ) + + # BSHD form. + if any(v is None for v in (b, s, h, d, hlen)): + raise ValueError( + "resolve_output_layout BSHD form needs all of b/s/h/d/hlen" + ) + if hlen <= 0 or mlen % hlen != 0: + raise ValueError( + f"resolve_output_layout: need mlen % hlen == 0; " + f"got mlen={mlen}, hlen={hlen}" + ) + lane_count = mlen // hlen + if h % lane_count != 0: + raise ValueError( + f"resolve_output_layout: H ({h}) must be a multiple of " + f"LANE_COUNT ({lane_count} = mlen//hlen)" + ) + return OutputLayout( + num_batches=int(b) * int(s), + elements_per_batch=int(h) * int(d), + mlen=int(mlen), + ) + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + +def run(spec: TvmTestbenchSpec) -> int: + """Drive the full TVM testbench pipeline for ``spec``. + + Writes everything ``just build-emulator-debug`` expects under + ``transactional_emulator/testbench/build/``. Returns 0 on success. + """ + # Lazy imports — these need the testbench's venv site-packages to be + # on sys.path, which the testbench itself sets up before importing us. + from compiler.sim_env_utils import create_mem_for_sim + from transactional_emulator.tools.create_sim_env import create_sim_env + + artifact_prefix = spec.artifact_prefix or spec.asm_name + + build_dir = TESTBENCH_DIR / "build" + build_dir.mkdir(parents=True, exist_ok=True) + + # ---------- 1. compile ---------- + print(f"[1/4] Compiling TVM {spec.asm_name} kernel ...") + hlir_path = build_dir / f"{spec.asm_name}.hlir.txt" + addrs_path: Optional[Path] = ( + build_dir / f"{spec.asm_name}.buffer_addrs.json" + if spec.parse_buffer_addrs is not None else None + ) + kernel_isa = _compile_via_subprocess( + spec, hlir_path=hlir_path, addrs_path=addrs_path, + ) + + addrs: dict = {} + if addrs_path is not None: + raw_addrs = json.loads(addrs_path.read_text()) + addrs = spec.parse_buffer_addrs(raw_addrs) # type: ignore[misc] + + # Optional pre-kernel stub (FPRAM staging -> VRAM cache, etc.). + stub_isa = "" + if spec.build_pre_kernel_stub is not None: + stub_isa = spec.build_pre_kernel_stub(addrs) + isa_text = stub_isa + kernel_isa + if spec.patch_isa is not None: + patched = spec.patch_isa(isa_text) + if patched != isa_text: + print( + f" NB patch_isa hook rewrote ASM " + f"({isa_text.count(chr(10))} -> " + f"{patched.count(chr(10))} lines)" + ) + isa_text = patched + + # Large-immediate normalisation, on the FULL assembled text. + # ``isa_pass`` runs this on the kernel body, but the pre-kernel + # stub (``build_pre_kernel_stub``) and any ``patch_isa`` hook are + # assembled outside that path — their ``S_ADDI_INT`` immediates + # never got normalised. With MLEN-512 addresses those overflow the + # 18-bit immediate slot. Re-run the pass over the concatenation so + # every section is covered. + from .isa_pass import _normalize_large_addi_immediates + _before = isa_text.count(chr(10)) + isa_text = _normalize_large_addi_immediates(isa_text) + _after = isa_text.count(chr(10)) + if _after != _before: + print( + f" NB large-imm normalise expanded ASM " + f"({_before} -> {_after} lines)" + ) + print( + f" OK ({kernel_isa.count(chr(10))} kernel lines" + + (f" + {stub_isa.count(chr(10))} stub lines" if stub_isa else "") + + f", HLIR -> {hlir_path.name})" + ) + + # ---------- 2. inputs + golden + (optional) FP preload ---------- + print(f"[2/4] Generating inputs + golden{' + FP preload' if spec.build_fp_preload else ''} ...") + io = spec.build_inputs_and_golden(spec.seed) + _validate_io(io) + hbm_inputs: dict = io["hbm_inputs"] + golden_flat = io["golden_flat"] + + fp_preload = None + if spec.build_fp_preload is not None: + fp_preload = spec.build_fp_preload(io, addrs) + + # Auto-preload hoisted FP constants. The compiler's + # ``hoist_float_constants`` pre-pass turns every ``T.float16(c)`` + # use into a 1-slot ``global.fpram`` buffer; the buffer-addrs dump + # carries each one's slot address and value. Write those slots + # here so per-kernel testbenches don't have to enumerate them. + if addrs_path is not None and raw_addrs: + import torch # local — testbench venv has torch on sys.path here + const_entries = [ + (int(entry["address"]), float(entry["value"])) + for entry in raw_addrs.values() + if isinstance(entry, dict) and "value" in entry + ] + if const_entries: + max_const_addr = max(addr for addr, _ in const_entries) + needed = max_const_addr + 1 + if fp_preload is None: + fp_preload = torch.zeros(needed, dtype=torch.float16) + elif fp_preload.numel() < needed: + grown = torch.zeros(needed, dtype=fp_preload.dtype) + grown[: fp_preload.numel()] = fp_preload + fp_preload = grown + for addr, value in const_entries: + fp_preload[addr] = value + print(f" auto-preloaded {len(const_entries)} FP constant(s)") + + input_feed = { + name: t.contiguous().reshape(1, -1) for name, t in hbm_inputs.items() + } + input_order = list(input_feed) + summary = ", ".join( + f"{n}={tuple(t.shape)}" for n, t in hbm_inputs.items() + ) + print(f" OK hbm_inputs: {summary}") + print(f" golden flat: {tuple(golden_flat.shape)}") + if fp_preload is not None: + print(f" fp_preload: {tuple(fp_preload.shape)}") + + # ---------- 3. create_sim_env (.pt + .asm + fp/int sram bins) ---------- + print(f"[3/4] create_sim_env -> .pt + .asm + fp/int sram bins ...") + create_sim_env( + input_tensor=input_feed, + generated_code=isa_text, + golden_result={"original_output": golden_flat}, + fp_preload=fp_preload, + int_preload=spec.int_preload, + build_dir=str(build_dir), + ) + print(f" OK -> {build_dir}") + + # ---------- 4. create_mem_for_sim (assemble + pack HBM) ---------- + print(f"[4/4] create_mem_for_sim -> assemble .asm + pack HBM bin ...") + create_mem_for_sim( + data_size=256, + mode="behave_sim", + asm=spec.asm_name, + data=None, + specified_data_order=input_order, + build_path=build_dir, + ) + print(f" OK -> generated_machine_code.mem + hbm_for_behave_sim.bin") + + # ---------- comparison_params + asm snapshot ---------- + comparison_params = spec.build_comparison_params(io, addrs) + cmp_path = build_dir / "comparison_params.json" + cmp_path.write_text(json.dumps(comparison_params, indent=2)) + print(f" wrote comparison_params.json -> {cmp_path}") + + (build_dir / f"{artifact_prefix}_generated_asm_code.asm").write_text(isa_text) + + print() + print("=" * 60) + print(f"build/ ready for: just build-emulator-debug {artifact_prefix}") + print("=" * 60) + return 0 + + +__all__ = [ + "TvmTestbenchSpec", + "run", + "OutputLayout", + "resolve_output_layout", + "REPO_ROOT", + "TESTBENCH_DIR", + "DEFAULT_LD_LIBRARY_PATH", +] diff --git a/tilelang_tvm_compiler/tests/__init__.py b/tilelang_tvm_compiler/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tilelang_tvm_compiler/tests/_isa_diff.py b/tilelang_tvm_compiler/tests/_isa_diff.py new file mode 100644 index 0000000..b429ee5 --- /dev/null +++ b/tilelang_tvm_compiler/tests/_isa_diff.py @@ -0,0 +1,159 @@ +"""Semantic ISA diff utility for the matmul / DMA family migration. + +Used by byte-equal tests where strict register-number equality is +intractable — legacy ``emit_matmul_*`` helpers pre-allocate +``allocate_gp(6)`` blocks and use only a subset, producing scrambled +GP-number choices (``gp2``/``gp1``/``gp4``) that the PreIsaIR per-iter +materialiser can't reproduce without abandoning the var-ref operand +model entirely. + +``semantic_isa_equal(legacy, new)`` returns True iff: + * Both decode to the same sequence of ``(mnemonic, [operand,...])`` + tuples (comments stripped, blank lines stripped). + * Each operand token is either: + - identical literal (e.g. ``"0"``, ``"f0"``, ``"f1"``, ``"a3"``) + - or a ``gp`` reference, consistently renamed: + the FIRST time ``gpN`` appears on the legacy side at position + (instr_i, operand_j), it gets bijected to whatever ``gpM`` + the new side has at that same position; subsequent appearances + of legacy ``gpN`` MUST map to the same ``gpM``, and the + bijection must be one-to-one (no two legacy GPs map to the + same new GP). + +This catches: + * any opcode mismatch + * any literal-immediate mismatch (e.g. wrong stride / offset) + * any GP-aliasing bug (e.g. using gp_dst where gp_src expected) + +while tolerating: + * GP renumbering (gp2 ↔ gp1 etc.) from different allocation orders +""" + +from __future__ import annotations + +import re +from typing import Dict, List, Optional, Tuple + + +_GP_RE = re.compile(r"^gp(\d+)$") + + +def _instr_decode(text: str) -> List[Tuple[str, List[str]]]: + """``[(mnemonic, [tok, tok, ...]), ...]`` — comments/blanks dropped.""" + out: List[Tuple[str, List[str]]] = [] + for raw in text.split("\n"): + ln = raw.strip() + if not ln or ln.startswith(";"): + continue + head, _, tail = ln.partition(" ") + mnem = head.strip().rstrip(",") + operands = [t.strip() for t in tail.split(",") if t.strip()] + out.append((mnem, operands)) + return out + + +def _is_gp(tok: str) -> bool: + return _GP_RE.match(tok) is not None + + +def semantic_isa_equal( + legacy_text: str, new_text: str, +) -> Tuple[bool, Optional[str]]: + """Return (equal, error_message). + + ``equal`` is True iff the two ISA streams differ only by GP + renumbering. The bijection is reset on each ``M_*_WO`` (matmul + write-back) — those are natural iteration boundaries in unrolled + matmul / mv code where legacy keeps the same GPs across iters but + the PreIsaIR per-iter scope allocates fresh ones each time. We + require GP consistency WITHIN a "block" (between WO boundaries) + but allow re-mapping ACROSS blocks. + + For DMA-style streams with no WO marker, the bijection is whole- + stream (no boundary triggers a reset). + + ``error_message`` is None on success or a human-readable + explanation of the first mismatch. + """ + legacy = _instr_decode(legacy_text) + new = _instr_decode(new_text) + if len(legacy) != len(new): + return False, ( + f"instruction count differs: legacy={len(legacy)} new={len(new)}\n" + f"legacy:\n " + "\n ".join( + f"{m} {', '.join(ops)}" for m, ops in legacy + ) + "\nnew:\n " + "\n ".join( + f"{m} {', '.join(ops)}" for m, ops in new + ) + ) + gp_map_l2n: Dict[str, str] = {} + gp_seen_new: Dict[str, str] = {} + for i, ((lm, lops), (nm, nops)) in enumerate(zip(legacy, new)): + if lm != nm: + return False, ( + f"instr [{i}] mnemonic mismatch: legacy={lm!r} new={nm!r}\n" + f"legacy line: {lm} {', '.join(lops)}\n" + f"new line: {nm} {', '.join(nops)}" + ) + if len(lops) != len(nops): + return False, ( + f"instr [{i}] operand-count mismatch: " + f"legacy={len(lops)} new={len(nops)}\n" + f"legacy: {lm} {', '.join(lops)}\n" + f"new: {nm} {', '.join(nops)}" + ) + for j, (lt, nt) in enumerate(zip(lops, nops)): + lt_is_gp = _is_gp(lt) + nt_is_gp = _is_gp(nt) + if lt_is_gp != nt_is_gp: + return False, ( + f"instr [{i}] operand[{j}]: one is GP, other isn't\n" + f"legacy={lt!r} new={nt!r}" + ) + if not lt_is_gp: + if lt != nt: + return False, ( + f"instr [{i}] operand[{j}] literal mismatch: " + f"legacy={lt!r} new={nt!r}\n" + f"legacy: {lm} {', '.join(lops)}\n" + f"new: {nm} {', '.join(nops)}" + ) + continue + prev_mapped = gp_map_l2n.get(lt) + if prev_mapped is None: + if nt in gp_seen_new and gp_seen_new[nt] != lt: + return False, ( + f"instr [{i}] operand[{j}]: GP bijection " + f"broken — new {nt!r} already mapped to " + f"legacy {gp_seen_new[nt]!r}, now needs to " + f"also represent legacy {lt!r}" + ) + gp_map_l2n[lt] = nt + gp_seen_new[nt] = lt + else: + if prev_mapped != nt: + return False, ( + f"instr [{i}] operand[{j}]: legacy {lt!r} " + f"previously mapped to new {prev_mapped!r}, " + f"now appears as {nt!r}\n" + f"legacy: {lm} {', '.join(lops)}\n" + f"new: {nm} {', '.join(nops)}" + ) + # WO instructions are natural iteration boundaries — legacy + # reuses GPs across iters while PreIsaIR per_iter scope + # allocates fresh ones. Reset the bijection here. + if lm in ("M_MM_WO", "M_MV_WO", "M_BMV_WO", "M_BMM_WO"): + gp_map_l2n = {} + gp_seen_new = {} + return True, None + + +def assert_semantic_isa_equal(legacy_text: str, new_text: str) -> None: + """Convenience assertion for pytest tests.""" + ok, err = semantic_isa_equal(legacy_text, new_text) + if not ok: + raise AssertionError( + f"semantic ISA equality failed:\n{err}\n\n" + f"=== legacy ===\n{legacy_text}\n\n" + f"=== new ===\n{new_text}" + ) diff --git a/tilelang_tvm_compiler/tests/test_mir_passes.py b/tilelang_tvm_compiler/tests/test_mir_passes.py new file mode 100644 index 0000000..0700883 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_mir_passes.py @@ -0,0 +1,546 @@ +"""Tests for MIR optimisation passes. + +Each pass is verified two ways: + 1. **Unit**: build a tiny MIR by hand, run the pass, check + structure + counts. + 2. **Integration**: run on a real PreIsaIR-v2-produced MIR (via + v2 e2e helpers), check that the pass doesn't break MIR + verifier and that resulting ISA matches pre-pass ISA + structurally (HW op set unchanged). +""" + +import re + +import pytest +from tvm import tir + +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import mir_passes as P + + +# --------------------------------------------------------------------- +# Tiny MIR builders for unit tests +# --------------------------------------------------------------------- + +def _mk_fn(name="t"): + fn = mir.MirFunction(name=name) + entry = mir.MirBlock(name="entry") + fn.blocks.append(entry) + fn.make_gp0_const() + return fn, entry + + +def _addi(blk, src, imm, hint=""): + fn = _enclosing_fn(blk) + dst = fn.mint_value("i32", hint=hint) + blk.append(mir.MirInstr("S_ADDI_INT", [src, imm], result=dst)) + return dst + + +def _slli(blk, src, k, hint=""): + fn = _enclosing_fn(blk) + dst = fn.mint_value("i32", hint=hint) + blk.append(mir.MirInstr("S_SLLI_INT", [src, k], result=dst)) + return dst + + +def _add(blk, a, b, hint=""): + fn = _enclosing_fn(blk) + dst = fn.mint_value("i32", hint=hint) + blk.append(mir.MirInstr("S_ADD_INT", [a, b], result=dst)) + return dst + + +def _enclosing_fn(blk): + # Climb out via parent_loop pointers. + # For unit tests the block is always one of fn.blocks or a + # loop body block — both reachable from fn.blocks. We stash + # the fn on the block via a side attr the helpers know. + return blk._test_fn + + +def _attach_fn(blk, fn): + blk._test_fn = fn + + +def _mk_loop(fn, parent_blk, init, extent, kind="unroll", lvar_hint="i"): + lvar = fn.mint_value("i32", hint=lvar_hint) + body = mir.MirBlock(name=f"body.{lvar_hint}") + body.add_argument(lvar) + _attach_fn(body, fn) + lp = mir.MirLoop( + name=f"L_{lvar_hint}", loop_var=lvar, + init=init, extent=extent, body=[body], + loop_kind=kind, + ) + parent_blk.append(lp) + return lp, body, lvar + + +# --------------------------------------------------------------------- +# dead_loop_elim +# --------------------------------------------------------------------- + +def test_dle_extent_one_peels_body(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, lvar = _mk_loop(fn, entry, init=5, extent=1, lvar_hint="i") + # Body: %x = ADDI lvar, 3 + x = _addi(body, lvar, 3, hint="x") + # Outer block uses nothing — pass should peel and replace + # lvar uses with IntImm(5). + + changed = P.dead_loop_elim(fn) + assert changed + mir.verify(fn) + + # After peeling: entry has no loop, has the ADDI instr. + assert all(not isinstance(it, mir.MirLoop) for it in entry.items) + adds = [it for it in entry.items if isinstance(it, mir.MirInstr) + and it.opcode == "S_ADDI_INT"] + assert len(adds) == 1 + # ADDI's first operand is now IntImm(5), not the lvar. + op0 = adds[0].operands[0] + assert isinstance(op0, tir.IntImm) and int(op0.value) == 5 + + +def test_dle_extent_zero_deletes_loop(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, lvar = _mk_loop(fn, entry, init=0, extent=0, lvar_hint="i") + # Body empty — extent=0 means body never runs. + + changed = P.dead_loop_elim(fn) + assert changed + mir.verify(fn) + assert all(not isinstance(it, mir.MirLoop) for it in entry.items) + + +def test_dle_nested_extent_one_collapses_both(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + outer, outer_body, outer_lv = _mk_loop( + fn, entry, init=0, extent=1, lvar_hint="o", + ) + inner, inner_body, inner_lv = _mk_loop( + fn, outer_body, init=0, extent=1, lvar_hint="i", + ) + _addi(inner_body, outer_lv, 0, hint="x") + + P.dead_loop_elim(fn) + mir.verify(fn) + # Both loops gone, one ADDI remains in entry. + assert all(not isinstance(it, mir.MirLoop) for it in entry.items) + + +def test_dle_extent_two_left_alone(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, lvar = _mk_loop(fn, entry, init=0, extent=2, lvar_hint="i") + _addi(body, lvar, 1) + changed = P.dead_loop_elim(fn) + assert not changed + assert sum(1 for it in entry.items if isinstance(it, mir.MirLoop)) == 1 + + +# --------------------------------------------------------------------- +# const_fold +# --------------------------------------------------------------------- + +def test_const_fold_addi_chain(): + """Folded instr is rewritten to ``S_ADDI_INT gp0, K``. The + consumer's MirValue operand is unchanged; what changed is + whose defining instr it points to. Two chained ADDIs both + fold to ``ADDI gp0, K`` where K is the running constant.""" + fn, entry = _mk_fn() + _attach_fn(entry, fn) + # gp0 + 5 → 5 → rewrites to ADDI gp0, 5 + # %a + 10 → 15 → rewrites to ADDI gp0, 15 + a = _addi(entry, fn.gp0_value, 5, hint="a") + b = _addi(entry, a, 10, hint="b") + # Side-effecting consumer to keep b alive. + entry.append(mir.MirInstr( + "C_SET_SCALE_REG", [b], result=None, + )) + changed = P.const_fold(fn) + assert changed + # ``b`` is now produced by ``ADDI gp0, 15``. + assert b.defined_by.opcode == "S_ADDI_INT" + assert b.defined_by.operands[0] is fn.gp0_value + assert b.defined_by.operands[1] == 15 + + +def test_const_fold_slli_with_gp0(): + """``gp0 << K`` is gp0 (identity peephole). The setreg + consumer's operand is now ``fn.gp0_value`` directly; the + SLLI instr has no users and gets swept by DCE.""" + fn, entry = _mk_fn() + _attach_fn(entry, fn) + s = _slli(entry, fn.gp0_value, 5, hint="s") + entry.append(mir.MirInstr( + "C_SET_STRIDE_REG", [s], result=None, + )) + P.const_fold(fn) + setreg = [it for it in entry.items + if isinstance(it, mir.MirInstr) + and it.opcode == "C_SET_STRIDE_REG"][0] + assert setreg.operands[0] is fn.gp0_value + + +# --------------------------------------------------------------------- +# dce +# --------------------------------------------------------------------- + +def test_dce_drops_unused_addi(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + _addi(entry, fn.gp0_value, 5, hint="dead") + # No consumer — DCE should drop the ADDI. + changed = P.dce(fn) + assert changed + assert not any( + isinstance(it, mir.MirInstr) and it.opcode == "S_ADDI_INT" + for it in entry.items + ) + + +def test_dce_keeps_used_addi(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + x = _addi(entry, fn.gp0_value, 5, hint="live") + entry.append(mir.MirInstr("C_SET_SCALE_REG", [x], result=None)) + changed = P.dce(fn) + assert not changed + assert sum( + 1 for it in entry.items + if isinstance(it, mir.MirInstr) and it.opcode == "S_ADDI_INT" + ) == 1 + + +def test_dce_keeps_side_effecting(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + # C_SET_SCALE_REG has no result; even so DCE must keep it. + entry.append(mir.MirInstr( + "C_SET_SCALE_REG", [fn.gp0_value], result=None, + )) + changed = P.dce(fn) + assert not changed + + +# --------------------------------------------------------------------- +# cse +# --------------------------------------------------------------------- + +def test_cse_collapses_identical_addi(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + a = _addi(entry, fn.gp0_value, 100, hint="a") + b = _addi(entry, fn.gp0_value, 100, hint="b") + # Two consumers — one for each result. After CSE both should + # point at ``a``. + entry.append(mir.MirInstr("C_SET_SCALE_REG", [a], result=None)) + entry.append(mir.MirInstr("C_SET_STRIDE_REG", [b], result=None)) + + changed = P.cse(fn) + assert changed + addis = [it for it in entry.items + if isinstance(it, mir.MirInstr) and it.opcode == "S_ADDI_INT"] + assert len(addis) == 1 + setregs = [it for it in entry.items + if isinstance(it, mir.MirInstr) + and it.opcode in ("C_SET_SCALE_REG", "C_SET_STRIDE_REG")] + # Both setregs reference the same surviving ADDI's result. + assert setregs[0].operands[0] is setregs[1].operands[0] + assert setregs[0].operands[0] is addis[0].result + + +def test_cse_respects_operand_differences(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + a = _addi(entry, fn.gp0_value, 100, hint="a") + b = _addi(entry, fn.gp0_value, 200, hint="b") # different imm + entry.append(mir.MirInstr("C_SET_SCALE_REG", [a], result=None)) + entry.append(mir.MirInstr("C_SET_STRIDE_REG", [b], result=None)) + + changed = P.cse(fn) + assert not changed + addis = [it for it in entry.items + if isinstance(it, mir.MirInstr) and it.opcode == "S_ADDI_INT"] + assert len(addis) == 2 + + +# --------------------------------------------------------------------- +# default pipeline +# --------------------------------------------------------------------- + +# --------------------------------------------------------------------- +# licm +# --------------------------------------------------------------------- + +def test_licm_hoists_invariant_addi(): + """``for i in [0, 4): %x = ADDI gp0, 100; ADDI lvar, x`` + — ``x`` is invariant; LICM hoists it out of the loop.""" + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, lvar = _mk_loop(fn, entry, init=0, extent=4, lvar_hint="i") + x = _addi(body, fn.gp0_value, 100, hint="x") # invariant + y = _add(body, lvar, x, hint="y") # depends on lvar — NOT invariant + body.append(mir.MirInstr("C_SET_SCALE_REG", [y], result=None)) + + changed = P.licm(fn) + assert changed + mir.verify(fn) + # ``x``'s defining ADDI should now sit in ``entry`` BEFORE the loop. + entry_addis = [it for it in entry.items + if isinstance(it, mir.MirInstr) + and it.opcode == "S_ADDI_INT"] + assert len(entry_addis) == 1 + assert entry_addis[0].result is x + # ``y`` (lvar-dependent) stays inside. + body_adds = [it for it in body.items + if isinstance(it, mir.MirInstr) + and it.opcode == "S_ADD_INT"] + assert len(body_adds) == 1 + + +def test_licm_doesnt_hoist_side_effecting(): + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, lvar = _mk_loop(fn, entry, init=0, extent=4, lvar_hint="i") + # C_SET_SCALE_REG with gp0 operand is "invariant" in the + # operand sense but side-effecting — must NOT be hoisted. + body.append(mir.MirInstr( + "C_SET_SCALE_REG", [fn.gp0_value], result=None, + )) + changed = P.licm(fn) + assert not changed + # setreg still in body. + assert any( + isinstance(it, mir.MirInstr) and it.opcode == "C_SET_SCALE_REG" + for it in body.items + ) + + +def test_licm_nested_loops(): + """Inner-loop invariant w.r.t. inner but NOT outer should be + hoisted only to inner's parent (= outer's body), not all the + way out.""" + fn, entry = _mk_fn() + _attach_fn(entry, fn) + outer, outer_body, outer_lv = _mk_loop( + fn, entry, init=0, extent=4, lvar_hint="o", + ) + inner, inner_body, inner_lv = _mk_loop( + fn, outer_body, init=0, extent=4, lvar_hint="i", + ) + # x depends on outer_lv — invariant in inner, but not in outer. + x = _addi(inner_body, outer_lv, 5, hint="x") + inner_body.append(mir.MirInstr("C_SET_SCALE_REG", [x], result=None)) + + P.licm(fn) + mir.verify(fn) + # x's ADDI now sits in outer_body, before the inner loop. + outer_body_addis = [it for it in outer_body.items + if isinstance(it, mir.MirInstr) + and it.opcode == "S_ADDI_INT"] + assert len(outer_body_addis) == 1 + assert outer_body_addis[0].result is x + # The setreg consumer is still inside inner_body. + assert any( + isinstance(it, mir.MirInstr) and it.opcode == "C_SET_SCALE_REG" + for it in inner_body.items + ) + + +def test_pipeline_dle_then_fold(): + """``for i in [0, 1): %x = ADDI lvar, 5; setreg %x`` → after + pipeline: + * DLE peels the loop, RAUW'ing ``lvar`` to ``IntImm(3)``; + * const_fold rewrites ``ADDI IntImm(3), 5`` to ``ADDI gp0, 8``; + * DCE has nothing to do (the ADDI is still used); + * setreg's operand still points at the same MirValue, now + produced by ``ADDI gp0, 8``. + """ + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, lvar = _mk_loop(fn, entry, init=3, extent=1, lvar_hint="i") + x = _addi(body, lvar, 5, hint="x") + body.append(mir.MirInstr("C_SET_SCALE_REG", [x], result=None)) + + P.run_default_pipeline(fn) + mir.verify(fn) + # Loop gone. + assert all(not isinstance(it, mir.MirLoop) for it in entry.items) + # Exactly one ADDI (the folded gp0+8) and one setreg referencing it. + addis = [it for it in entry.items + if isinstance(it, mir.MirInstr) and it.opcode == "S_ADDI_INT"] + setregs = [it for it in entry.items + if isinstance(it, mir.MirInstr) + and it.opcode == "C_SET_SCALE_REG"] + assert len(addis) == 1 + assert addis[0].operands[0] is fn.gp0_value + assert addis[0].operands[1] == 8 + assert len(setregs) == 1 + assert setregs[0].operands[0] is addis[0].result + + +# --------------------------------------------------------------------- +# reassociate +# --------------------------------------------------------------------- + +def _setreg(blk, src): + blk.append(mir.MirInstr("C_SET_SCALE_REG", [src], result=None)) + + +def _mk_loop_arg(fn, parent_blk, lvar_hint): + """Helper: mk a serial loop with a non-trivial extent so its + loop_var stays alive (used to give us "named" block-arg leaves + in reassoc tests).""" + lvar = fn.mint_value("i32", hint=lvar_hint) + body = mir.MirBlock(name=f"body.{lvar_hint}") + body.add_argument(lvar) + _attach_fn(body, fn) + lp = mir.MirLoop( + name=f"L_{lvar_hint}", loop_var=lvar, + init=0, extent=4, body=[body], + loop_kind="serial", + ) + parent_blk.append(lp) + return lp, body, lvar + + +def test_reassociate_duplicate_full_chain_collapses(): + """Two chains over the same leaves+const collapse to one.""" + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, _ = _mk_loop_arg(fn, entry, "i") + a = _addi(body, fn.gp0_value, 5, hint="a") + b = _addi(body, fn.gp0_value, 7, hint="b") + # chain1 = a + b + 3 + s1a = _add(body, a, b, hint="s1a") + s1 = _addi(body, s1a, 3, hint="s1") + _setreg(body, s1) + # chain2 = same — different ADD order + s2a = _addi(body, b, 3, hint="s2a") + s2 = _add(body, s2a, a, hint="s2") + _setreg(body, s2) + + P.reassociate(fn) + mir.verify(fn) + # Both setregs reference the SAME MirValue now. + setregs = [it for it in body.items + if isinstance(it, mir.MirInstr) + and it.opcode == "C_SET_SCALE_REG"] + assert len(setregs) == 2 + assert setregs[0].operands[0] is setregs[1].operands[0] + + +def _sub(blk, a, b, hint=""): + fn = _enclosing_fn(blk) + dst = fn.mint_value("i32", hint=hint) + blk.append(mir.MirInstr("S_SUB_INT", [a, b], result=dst)) + return dst + + +def test_reassociate_sub_treated_as_negated_add(): + """``a - b`` and ``-b + a`` and ``a + (b - 2b)`` should all + canonicalise to the same {+a, -b} multiset. Two chains + expressing the same canonical form collapse to one.""" + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, _ = _mk_loop_arg(fn, entry, "i") + a = _addi(body, fn.gp0_value, 100, hint="a") + b = _addi(body, fn.gp0_value, 200, hint="b") + # chain1: a - b + s1 = _sub(body, a, b, hint="s1") + _setreg(body, s1) + # chain2: also a - b, but written as (a + 0) - b via a ADD + # detour; reassoc should still see {+a, -b}, 0. + z = _add(body, a, fn.gp0_value, hint="z") # = a + s2 = _sub(body, z, b, hint="s2") + _setreg(body, s2) + + P.run_default_pipeline(fn) + mir.verify(fn) + # Two setregs should point at the same MirValue. + setregs = [it for it in body.items + if isinstance(it, mir.MirInstr) + and it.opcode == "C_SET_SCALE_REG"] + assert len(setregs) == 2 + assert setregs[0].operands[0] is setregs[1].operands[0] + + +def test_reassociate_sub_cancellation(): + """``a + b - a`` should simplify so the only remaining work in + the body is producing ``b`` (a constant or a copy of ``b``). + + Concretely after pipeline: pure ADDs/SUBs in the body that + survive should produce ``200`` (the value of ``b`` in this + test). We don't insist on a specific MirValue identity — DCE + may have collapsed the original ``b`` definition into the + final sum's instr. + """ + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, _ = _mk_loop_arg(fn, entry, "i") + a = _addi(body, fn.gp0_value, 100, hint="a") + b = _addi(body, fn.gp0_value, 200, hint="b") + s_ab = _add(body, a, b, hint="ab") + s = _sub(body, s_ab, a, hint="s") + _setreg(body, s) + + P.run_default_pipeline(fn) + mir.verify(fn) + # No SUB / no ADD instr should remain — the whole thing + # collapses to ``ADDI gp0, 200`` (after const-fold sees both + # ``a`` and ``a`` cancel via reassoc). + arith_ops = [it for it in body.items + if isinstance(it, mir.MirInstr) + and it.opcode in ("S_ADD_INT", "S_SUB_INT")] + assert len(arith_ops) == 0, ( + f"expected no surviving ADD/SUB; got " + f"{[it.opcode for it in arith_ops]}" + ) + # The setreg's operand value should be 200 (we check via the + # producer instr). + setregs = [it for it in body.items + if isinstance(it, mir.MirInstr) + and it.opcode == "C_SET_SCALE_REG"] + assert len(setregs) == 1 + src = setregs[0].operands[0] + # src is produced by some ADDI gp0, K. + d = src.defined_by + assert d is not None and d.opcode == "S_ADDI_INT" + assert d.operands[0] is fn.gp0_value + assert d.operands[1] == 200 + + +def test_reassociate_prefix_share(): + """Two chains where one's leaves are a prefix of the other's + share the partial sum.""" + fn, entry = _mk_fn() + _attach_fn(entry, fn) + lp, body, _ = _mk_loop_arg(fn, entry, "i") + a = _addi(body, fn.gp0_value, 5, hint="a") # leaf 1 + b = _addi(body, fn.gp0_value, 7, hint="b") # leaf 2 + c = _addi(body, fn.gp0_value, 11, hint="c") # leaf 3 + # chain1 = a + b + s1 = _add(body, a, b, hint="s1") + _setreg(body, s1) + # chain2 = a + b + c (should reuse s1) + s2a = _add(body, a, b, hint="s2a") + s2 = _add(body, s2a, c, hint="s2") + _setreg(body, s2) + + before = sum(1 for it in body.items + if isinstance(it, mir.MirInstr) + and it.opcode in ("S_ADD_INT", "S_ADDI_INT")) + # reassoc alone only RAUWs; DCE in the full pipeline sweeps + # the now-dead instr — so test the pipeline, not the pass. + P.run_default_pipeline(fn) + mir.verify(fn) + after = sum(1 for it in body.items + if isinstance(it, mir.MirInstr) + and it.opcode in ("S_ADD_INT", "S_ADDI_INT")) + assert after < before, f"before={before} after={after}" diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_ir.py b/tilelang_tvm_compiler/tests/test_pre_isa_ir.py new file mode 100644 index 0000000..6b67240 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_ir.py @@ -0,0 +1,136 @@ +"""Tests for the PreIsaIR data structures (pre_isa_ir.py). + +Covers: + * PreIsaOp construction + opcode validation + * PreIsaModule.append / .comment + * format_pre_isa indentation around C_LOOP_START / C_LOOP_END + * loop_regions pairing logic (nested, unmatched) + +These tests use ONLY the pre_isa_ir module — no TVM kernel, no +shim. Cheap and fast. +""" + +import pytest +from tvm import tir + +from tilelang_tvm_compiler.pre_isa_ir import ( + PreIsaOp, PreIsaModule, format_pre_isa, loop_regions, KNOWN_OPCODES, +) + + +def test_known_opcodes_include_core_set(): + """Sanity: every PreIsaIR opcode passes use must appear in + KNOWN_OPCODES. Note ``C_LOOP_START`` / ``C_LOOP_END`` are PLENA + ISA mnemonics (emitted as text by BackendEmit's serial-loop + handler), NOT PreIsaIR opcodes — the PreIsaIR control marker is + the unified ``LOOP_START`` / ``LOOP_END`` (with kind in + ``annotations["loop_kind"]``).""" + must_have = { + "S_ADDI_INT", "S_LD_FP", "S_ST_FP", "S_MUL_FP", + "V_ADD_VV", "V_MUL_VF", "V_EXP_V", + "M_BTMM", "M_MM", + "H_LOAD_V", "H_STORE_V", "H_PREFETCH_V", + "LOOP_START", "LOOP_END", + "C_SET_SCALE_REG", "C_SET_STRIDE_REG", + } + missing = must_have - KNOWN_OPCODES + assert not missing, f"opcodes missing from KNOWN_OPCODES: {missing}" + + +def test_pre_isa_op_rejects_unknown_opcode(): + with pytest.raises(ValueError, match="not a known PLENA mnemonic"): + PreIsaOp(opcode="NOT_A_REAL_INSTR") + + +def test_pre_isa_op_accepts_known_opcode(): + op = PreIsaOp(opcode="S_ADDI_INT", operands=["gp1", "gp0", "256"]) + assert op.opcode == "S_ADDI_INT" + assert op.operands == ["gp1", "gp0", "256"] + assert op.binds is None + assert op.annotations == {} + + +def test_module_append_and_comment(): + mod = PreIsaModule(name="k") + mod.append(PreIsaOp(opcode="S_ADDI_INT", operands=["gp1", "gp0", "1"])) + mod.comment("hello world") + assert len(mod.ops) == 2 + assert mod.ops[0].opcode == "S_ADDI_INT" + assert mod.ops[1].opcode == "_COMMENT" + assert mod.ops[1].operands == ["hello world"] + + +def test_format_pre_isa_indents_inside_loop(): + mod = PreIsaModule(name="k") + var = tir.Var("i", "int32") + mod.append(PreIsaOp(opcode="LOOP_START", operands=["gp2", "8"], binds=var)) + mod.append(PreIsaOp(opcode="V_ADD_VV", operands=["gp3", "gp4", "gp5", "0"])) + mod.append(PreIsaOp(opcode="LOOP_END", operands=["gp2"])) + text = format_pre_isa(mod) + lines = [ln for ln in text.split("\n") if "V_ADD_VV" in ln or "LOOP_" in ln] + # The body line must be indented further than the LOOP_START line. + start_indent = len(lines[0]) - len(lines[0].lstrip()) + body_indent = len(lines[1]) - len(lines[1].lstrip()) + end_indent = len(lines[2]) - len(lines[2].lstrip()) + assert body_indent > start_indent, f"body should indent past START:\n{text}" + assert end_indent == start_indent, f"END should align with START:\n{text}" + + +def test_format_pre_isa_handles_empty(): + mod = PreIsaModule(name="empty") + text = format_pre_isa(mod) + assert "PreIsaModule(name='empty')" in text + assert "Ops:" in text + + +def test_loop_regions_flat_pair(): + var = tir.Var("i", "int32") + ops = [ + PreIsaOp(opcode="LOOP_START", operands=["gp2", "4"], binds=var), + PreIsaOp(opcode="V_ADD_VV", operands=["gp3", "gp4", "gp5", "0"]), + PreIsaOp(opcode="LOOP_END", operands=["gp2"]), + ] + regions = loop_regions(ops) + assert regions == [(0, 2, var)] + + +def test_loop_regions_nested(): + v_outer = tir.Var("outer", "int32") + v_inner = tir.Var("inner", "int32") + ops = [ + PreIsaOp(opcode="LOOP_START", operands=["gp1", "4"], binds=v_outer), + PreIsaOp(opcode="LOOP_START", operands=["gp2", "8"], binds=v_inner), + PreIsaOp(opcode="V_ADD_VV", operands=["gp3", "gp4", "gp5", "0"]), + PreIsaOp(opcode="LOOP_END", operands=["gp2"]), + PreIsaOp(opcode="LOOP_END", operands=["gp1"]), + ] + regions = loop_regions(ops) + # Inner pair closes first; outer second. Both reported. + assert (1, 3, v_inner) in regions + assert (0, 4, v_outer) in regions + + +def test_loop_regions_unmatched_end_raises(): + ops = [PreIsaOp(opcode="LOOP_END", operands=["gp1"])] + with pytest.raises(ValueError, match="no matching loop-start"): + loop_regions(ops) + + +def test_loop_regions_unclosed_start_raises(): + var = tir.Var("i", "int32") + ops = [PreIsaOp(opcode="LOOP_START", operands=["gp1", "4"], binds=var)] + with pytest.raises(ValueError, match="unclosed loop-start"): + loop_regions(ops) + + +def test_loop_regions_tolerates_missing_binds_from_text_capture(): + """When PreIsaIR is built by text capture (CapturingCode), the + iteration var is already lowered away — binds is None. loop_regions + must NOT require it; LICM is the only consumer that needs binds and + it operates on the var-ref construction path only.""" + ops = [ + PreIsaOp(opcode="LOOP_START", operands=["gp1", "4"]), # binds=None + PreIsaOp(opcode="LOOP_END", operands=["gp1"]), + ] + regions = loop_regions(ops) + assert regions == [(0, 1, None)] diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_btmm.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_btmm.py new file mode 100644 index 0000000..3c9527b --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_btmm.py @@ -0,0 +1,88 @@ +"""Byte-equal verification for the migrated ``btmm`` (Q @ K^T packed- +head matmul). + +Legacy path: isa_pass._emit_btmm -> ISAEmitter.emit_btmm + .emit_btmm_wo. +New path: PreIsaPass._emit_btmm decomposes those two emit_* helpers +into a stream of PreIsaOps (preloads + M_BTMM + M_BMM_WO) so the +operand PrimExprs are visible to the optimiser. +""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +HLEN = 16 + + +def _instr_lines(text: str): + return [ + ln.strip() + for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def _vram_buf(name: str, addr: int, shape) -> _hlir.Buffer: + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram_buf(name: str, addr: int, shape) -> _hlir.Buffer: + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def test_btmm_byte_equal(): + """Minimal btmm: lhs(M, gh, hlen) @ rhs(M, gh, hlen) -> dst(M, gh, M). + With M=mlen=64, gh=lane_count=4, hlen=16: + tile_elems = mlen*mlen = 4096 + dst.num_elements = 64*4*64 = 16384 -> tile_count = 16384 / 4096 = 4. + """ + LANE = 4 # btmm_lane_count + lhs = _vram_buf("lhs", 0, (MLEN, LANE, HLEN)) + rhs = _mram_buf("rhs", 4096, (MLEN, LANE, HLEN)) + dst = _vram_buf("dst", 8192, (MLEN, LANE, MLEN)) + + op = _hlir.Op( + kind="btmm", + buffer_args=[ + _hlir.VramRegion( + parent="lhs", starts=(0, 0, 0), extents=(MLEN, LANE, HLEN), + ), + _hlir.MramRegion( + parent="rhs", starts=(0, 0, 0), extents=(MLEN, LANE, HLEN), + ), + _hlir.VramRegion( + parent="dst", starts=(0, 0, 0), extents=(MLEN, LANE, MLEN), + ), + ], + scalar_args=[], + annotations={"intrinsic": "btmm_test"}, + ) + hlir = _hlir.HLIRModule( + name="btmm_smoke", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], + param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=4, btmm_lane_count=LANE, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=4, btmm_lane_count=LANE, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_btmv.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_btmv.py new file mode 100644 index 0000000..9de975e --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_btmv.py @@ -0,0 +1,71 @@ +"""Byte-equal verification for the migrated ``btmv``.""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +HLEN = 16 + + +def _instr_lines(text: str): + return [ + ln.strip() for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def _vram_buf(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram_buf(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def test_btmv_byte_equal(): + LANE = 4 + lhs = _vram_buf("lhs", 0, (1, LANE, HLEN)) + rhs = _mram_buf("rhs", 4096, (MLEN, LANE, HLEN)) + dst = _vram_buf("dst", 8192, (1, LANE, MLEN)) + + op = _hlir.Op( + kind="btmv", + buffer_args=[ + _hlir.VramRegion(parent="lhs", starts=(0, 0, 0), + extents=(1, LANE, HLEN)), + _hlir.MramRegion(parent="rhs", starts=(0, 0, 0), + extents=(MLEN, LANE, HLEN)), + _hlir.VramRegion(parent="dst", starts=(0, 0, 0), + extents=(1, LANE, MLEN)), + ], + scalar_args=[], + annotations={"intrinsic": "btmv_test"}, + ) + hlir = _hlir.HLIRModule( + name="btmv_smoke", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=4, btmm_lane_count=LANE, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=4, btmm_lane_count=LANE, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_dma.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_dma.py new file mode 100644 index 0000000..89e9f71 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_dma.py @@ -0,0 +1,156 @@ +"""Structural verification for the migrated DMA family +(``dma_h2v``, ``dma_h2m``, ``dma_v2h``). + +Legacy emit goes through ISAEmitter.emit_load_tile_from_hbm / +emit_hbm_tile_to_mram / emit_store_tile_to_hbm, each emitting a +multi-line setup sequence (S_ADDI_INT × N + C_SET_*_REG + the actual +H_PREFETCH_V / H_PREFETCH_M / H_STORE_V). + +PreIsaIR migration: addresses + scale/stride literals become PrimExprs +referencing ``MLEN_VAR``, ``V_PREFETCH_AMOUNT_VAR``, +``V_WRITEBACK_AMOUNT_VAR``. Optimiser sees ``vram_base + idx * +(mlen * v_prefetch_amount)`` etc. Backend lowers to the same HW +mnemonics; GP renumbering is expected. + +Coarse structural check: equal counts of H_PREFETCH_V / H_PREFETCH_M +/ H_STORE_V on both sides. +""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +BLEN = 4 +HLEN = 16 + + +def _counts(isa: str): + p_v = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_PREFETCH_V")) + p_m = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_PREFETCH_M")) + s_v = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_STORE_V")) + return p_v, p_m, s_v + + +def _hbm(name, addr, num_elements): + """Minimal HBM buffer.""" + return _hlir.Buffer( + name=name, scope=_scope.HBM, + shape=(num_elements,), dtype="float16", + address=addr, + hbm_offset=0, hbm_stride=MLEN, + hbm_scale_size=MLEN * MLEN, + ) + + +def _vram(name, addr): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(MLEN, MLEN), dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, + shape=(MLEN, MLEN), dtype="float16", address=addr, + ) + + +def test_dma_h2m_structural_equal(): + """Single-tile HBM -> MRAM. One H_PREFETCH_M emitted.""" + src = _hbm("src_hbm", 0, MLEN * MLEN) + dst = _mram("dst_mram", 4096) + op = _hlir.Op( + kind="dma_h2m", + buffer_args=["src_hbm", "dst_mram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2m_smoke", + buffers={"src_hbm": src, "dst_mram": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + legacy_pv, legacy_pm, legacy_sv = _counts(legacy_isa) + new_pv, new_pm, new_sv = _counts(new_isa) + assert legacy_pm == new_pm == 1, ( + f"H_PREFETCH_M count differs: legacy={legacy_pm} new={new_pm}" + ) + assert new_pv == legacy_pv == 0 + assert new_sv == legacy_sv == 0 + + +def test_dma_h2v_structural_equal(): + """Single-tile HBM -> VRAM. Expected H_PREFETCH_V count: + inner_count * load_amount_per_hidden = ceil(mlen/v_prefetch_amount) * 1 + For mlen=64, v_prefetch_amount=1: 64 H_PREFETCH_Vs.""" + src = _hbm("src_hbm", 0, MLEN * MLEN) + dst = _vram("dst_vram", 4096) + op = _hlir.Op( + kind="dma_h2v", + buffer_args=["src_hbm", "dst_vram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2v_smoke", + buffers={"src_hbm": src, "dst_vram": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + legacy_pv, _, _ = _counts(legacy_isa) + new_pv, _, _ = _counts(new_isa) + assert legacy_pv == new_pv, ( + f"H_PREFETCH_V count differs: legacy={legacy_pv} new={new_pv}" + ) + assert legacy_pv > 0 + + +def test_dma_v2h_structural_equal(): + """Single-tile VRAM -> HBM. Expected H_STORE_V count matches + H_PREFETCH_V count under symmetric mlen/v_writeback_amount setup.""" + src = _vram("src_vram", 0) + dst = _hbm("dst_hbm", 4096, MLEN * MLEN) + op = _hlir.Op( + kind="dma_v2h", + buffer_args=["src_vram", "dst_hbm"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_v2h_smoke", + buffers={"src_vram": src, "dst_hbm": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + _, _, legacy_sv = _counts(legacy_isa) + _, _, new_sv = _counts(new_isa) + assert legacy_sv == new_sv, ( + f"H_STORE_V count differs: legacy={legacy_sv} new={new_sv}" + ) + assert legacy_sv > 0 diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_dma_slice.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_dma_slice.py new file mode 100644 index 0000000..f792ddd --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_dma_slice.py @@ -0,0 +1,159 @@ +"""Structural verification for the migrated DMA slice family +(``dma_h2v_slice``, ``dma_h2m_slice``, ``dma_v2h_slice``). + +Static-start slices only (dynamic-start path is a known TODO in the +PreIsaPass migration — legacy uses ``hbm_start_offset_reg``). + +Coarse structural check: equal counts of H_PREFETCH_V / H_PREFETCH_M +/ H_STORE_V on both sides. +""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +BLEN = 4 +HLEN = 16 + + +def _counts(isa: str): + p_v = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_PREFETCH_V")) + p_m = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_PREFETCH_M")) + s_v = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_STORE_V")) + return p_v, p_m, s_v + + +def _hbm_2d(name, addr, rows, cols): + """Minimal 2D HBM buffer (e.g. an activations tensor).""" + return _hlir.Buffer( + name=name, scope=_scope.HBM, + shape=(rows, cols), dtype="float16", + address=addr, + hbm_offset=0, hbm_stride=cols, + hbm_scale_size=MLEN * MLEN, + ) + + +def _vram_2d(name, addr, rows=MLEN, cols=MLEN): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(rows, cols), dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram_2d(name, addr, rows=MLEN, cols=MLEN): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, + shape=(rows, cols), dtype="float16", address=addr, + ) + + +def test_dma_h2v_slice_structural_equal(): + """Single mlen*mlen tile slice from a wider 2*mlen × 2*mlen HBM + parent. Grid = 1×1×1×1 (one tile).""" + parent = _hbm_2d("hbm", 0, 2 * MLEN, 2 * MLEN) + dst = _vram_2d("vram", 4096) + sl = _hlir.BufferSlice( + parent="hbm", + starts=(MLEN, MLEN), + extents=(MLEN, MLEN), + ) + op = _hlir.Op( + kind="dma_h2v_slice", + buffer_args=[sl, "vram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2v_slice_smoke", + buffers={"hbm": parent, "vram": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + legacy_pv, _, _ = _counts(legacy_isa) + new_pv, _, _ = _counts(new_isa) + assert legacy_pv == new_pv, ( + f"H_PREFETCH_V count differs: legacy={legacy_pv} new={new_pv}" + ) + assert legacy_pv > 0 + + +def test_dma_h2m_slice_structural_equal(): + """Single-tile slice into MRAM.""" + parent = _hbm_2d("hbm", 0, 2 * MLEN, 2 * MLEN) + dst = _mram_2d("mram", 4096) + sl = _hlir.BufferSlice( + parent="hbm", + starts=(0, MLEN), + extents=(MLEN, MLEN), + ) + op = _hlir.Op( + kind="dma_h2m_slice", + buffer_args=[sl, "mram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2m_slice_smoke", + buffers={"hbm": parent, "mram": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + _, legacy_pm, _ = _counts(legacy_isa) + _, new_pm, _ = _counts(new_isa) + assert legacy_pm == new_pm == 1, ( + f"H_PREFETCH_M count differs: legacy={legacy_pm} new={new_pm}" + ) + + +def test_dma_v2h_slice_structural_equal(): + """VRAM source → slice into wider HBM dst.""" + src = _vram_2d("vram", 0) + parent = _hbm_2d("hbm", 4096, 2 * MLEN, 2 * MLEN) + sl = _hlir.BufferSlice( + parent="hbm", + starts=(MLEN, 0), + extents=(MLEN, MLEN), + ) + op = _hlir.Op( + kind="dma_v2h_slice", + buffer_args=["vram", sl], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_v2h_slice_smoke", + buffers={"vram": src, "hbm": parent}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + _, _, legacy_sv = _counts(legacy_isa) + _, _, new_sv = _counts(new_isa) + assert legacy_sv == new_sv, ( + f"H_STORE_V count differs: legacy={legacy_sv} new={new_sv}" + ) + assert legacy_sv > 0 diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_for.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_for.py new file mode 100644 index 0000000..a2db6a6 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_for.py @@ -0,0 +1,87 @@ +"""Byte-equal verification for ``for`` (HLIR structured op). + +A serial for-loop wraps one ``fp_zero_at`` body op. Legacy emits a +3-line prelude (init=0 case: S_ST_INT + C_LOOP_START with the +``;`` header), the body, and a 4-line epilogue (S_LD_INT / +S_ADDI_INT / S_ST_INT / C_LOOP_END). The PreIsaIR path must emit +byte-equal. + +The body PrimExpr references the loop variable so this also exercises +the symbol_table binding cycle: PreIsaPass leaves ``loop_var`` symbolic +in the operand; BackendEmit's C_LOOP_START handler binds it via +``symbol_table[loop_var] = ("ram", idx_addr)`` and the body op's +materialise resolves via S_LD_INT. +""" + +from tvm import tir + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +def _instr_lines(text: str): + return [ + ln.strip() + for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def test_for_serial_byte_equal_body_uses_loop_var(): + """``for i in [0, 4): fp_zero_at dst_fp[i]`` — the body references + ``i`` in its scalar_args, so the inner S_LD_INT must come from + BackendEmit's symbol_table binding ``i -> ("ram", idx_addr)`` set + up by C_LOOP_START.""" + i = tir.Var("i", "int32") + # The dst FPRAM has 4 slots; body op writes f0 to slot index ``i``. + buf = _hlir.Buffer( + name="dst_fp", scope=_scope.FPRAM, + shape=(4,), dtype="float16", address=128, + ) + body = _hlir.Op( + kind="fp_zero_at", + buffer_args=[], + scalar_args=[_hlir.BufferElement(buffer="dst_fp", indices=(i,))], + ) + # loop_register_alloc stamps loop_gp; we mimic it here. + for_op = _hlir.Op( + kind="for", + buffer_args=[], + scalar_args=[], + annotations={ + "loop_var": i, + "extent": 4, + "init": 0, + "loop_kind": "serial", + "loop_gp": 15, # legacy loop_register_alloc usually reserves high GPs + }, + body=[body], + ) + hlir = _hlir.HLIRModule( + name="for_smoke", buffers={"dst_fp": buf}, + ops=[for_op], param_names=[], + ) + + shim_legacy = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + # legacy pass needs the loop_gp reserved in its allocator too. Reserve + # by allocating it upfront (mirrors what loop_register_alloc does + # via the RegisterAllocator constructor's gp_reserved). The simplest + # cross-test approach is to pre-pin gp15 — that ensures both paths + # see the same allocator state when they enter the for-op. + shim_legacy.compiler.register_allocator.pin_gp(15) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + shim_legacy.compiler.register_allocator.unpin_gp(15) + + shim_new = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + shim_new.compiler.register_allocator.pin_gp(15) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + shim_new.compiler.register_allocator.unpin_gp(15) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_for_unroll.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_for_unroll.py new file mode 100644 index 0000000..722d9f1 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_for_unroll.py @@ -0,0 +1,131 @@ +"""Byte-equal verification for the HLIR for-loop with +``loop_kind="unroll"``. + +Legacy ``_emit_for`` unrolled branch (isa_pass.py:3189-3231): + * binds loop_var to ``tir.IntImm(init+i)`` for each iter; + * emits one ``; ... unroll iter`` header per iter; + * replays the body sub-ops with the IntImm-bound loop_var so + materialise() constant-folds it. + +PreIsaPass migrated unroll branch: + * emits a single ``LOOP_START`` PreIsaOp with + ``annotations["loop_kind"] = "unroll"``; + * walks the HLIR body sub-ops once (producing one set of body + PreIsaOps inside the LOOP_START/LOOP_END pair); + * BackendEmit's run() detects unroll-kind and replays the body + PreIsaOps N times, binding loop_var to IntImm per iter. + +This test exercises that whole pipeline against legacy on a small +HLIR: ``for i in [0, 3) unroll: fp_zero_at dst[i]`` — i.e. one +unrolled for-loop with a single fp_zero_at body that references the +loop var in its scalar_args. +""" + +from tvm import tir + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +def _instr_lines(text): + return [ + ln.strip() for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def test_for_unroll_byte_equal(): + """Unrolled for-loop body uses ``i`` in its scalar_args. Each iter's + materialise sees ``i`` bound to IntImm(0/1/2) and constant-folds + the address.""" + i = tir.Var("i", "int32") + buf = _hlir.Buffer( + name="dst_fp", scope=_scope.FPRAM, + shape=(4,), dtype="float16", address=128, + ) + body = _hlir.Op( + kind="fp_zero_at", + buffer_args=[], + scalar_args=[_hlir.BufferElement(buffer="dst_fp", indices=(i,))], + ) + for_op = _hlir.Op( + kind="for", + buffer_args=[], + scalar_args=[], + annotations={ + "loop_var": i, + "extent": 3, + "init": 0, + "loop_kind": "unroll", + }, + body=[body], + ) + hlir = _hlir.HLIRModule( + name="for_unroll_smoke", + buffers={"dst_fp": buf}, + ops=[for_op], + param_names=[], + ) + + shim_legacy = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) + + +def test_for_unroll_body_with_loop_var_in_addr(): + """Same as above but with a 4-iter range to make sure the + sequence of materialised addresses (128, 129, 130, 131) shows + up in the expected order on both paths.""" + i = tir.Var("i", "int32") + buf = _hlir.Buffer( + name="dst_fp", scope=_scope.FPRAM, + shape=(8,), dtype="float16", address=256, + ) + body = _hlir.Op( + kind="fp_zero_at", + buffer_args=[], + scalar_args=[_hlir.BufferElement(buffer="dst_fp", indices=(i,))], + ) + for_op = _hlir.Op( + kind="for", + buffer_args=[], + scalar_args=[], + annotations={ + "loop_var": i, + "extent": 4, + "init": 0, + "loop_kind": "unroll", + }, + body=[body], + ) + hlir = _hlir.HLIRModule( + name="for_unroll_4", + buffers={"dst_fp": buf}, + ops=[for_op], + param_names=[], + ) + + shim_legacy = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) + # Sanity: 4 iterations -> 4 S_ST_FP instructions. + instrs = _instr_lines(new_isa) + assert sum(1 for s in instrs if s.startswith("S_ST_FP")) == 4, instrs diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_fp_scalar_at.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_fp_scalar_at.py new file mode 100644 index 0000000..18d96b6 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_fp_scalar_at.py @@ -0,0 +1,99 @@ +"""Byte-equal verification for the migrated ``fp_*_at`` family +(``fp_copy_at`` / ``fp_exp_at`` / ``fp_reci_at`` / ``fp_sqrt_at`` / +``fp_add_at`` / ``fp_sub_at`` / ``fp_mul_at`` / ``fp_max_at``). + +Each variant builds the same minimal HLIRModule, runs it through both +the legacy ``IsaEmitterPass`` and the new ``PreIsaPass`` + ``BackendEmit`` +path, and asserts the emitted HW-instruction lines are byte-equal. + +These tests are the proof that PreIsaIR's "group_id" materialisation +scoping correctly reuses the legacy path's pin/release pattern across +the multi-line burst each ``_at`` op emits. +""" + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +def _instr_lines(text: str): + return [ + ln.strip() + for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def _mk_buf(name: str, addr: int) -> _hlir.Buffer: + return _hlir.Buffer( + name=name, scope=_scope.FPRAM, shape=(1,), dtype="float16", + address=addr, + ) + + +def _hlir_unary(kind: str) -> _hlir.HLIRModule: + buffers = { + "src": _mk_buf("src", 64), + "dst": _mk_buf("dst", 128), + } + op = _hlir.Op( + kind=kind, + buffer_args=[], + scalar_args=[ + _hlir.BufferElement(buffer="src", indices=(0,)), + _hlir.BufferElement(buffer="dst", indices=(0,)), + ], + ) + return _hlir.HLIRModule( + name=kind, buffers=buffers, ops=[op], param_names=[], + ) + + +def _hlir_binary(kind: str) -> _hlir.HLIRModule: + buffers = { + "lhs": _mk_buf("lhs", 32), + "rhs": _mk_buf("rhs", 64), + "dst": _mk_buf("dst", 128), + } + op = _hlir.Op( + kind=kind, + buffer_args=[], + scalar_args=[ + _hlir.BufferElement(buffer="lhs", indices=(0,)), + _hlir.BufferElement(buffer="rhs", indices=(0,)), + _hlir.BufferElement(buffer="dst", indices=(0,)), + ], + ) + return _hlir.HLIRModule( + name=kind, buffers=buffers, ops=[op], param_names=[], + ) + + +def _byte_equal(hlir: _hlir.HLIRModule) -> None: + """Run both paths against ``hlir`` and assert HW-instruction + byte-equality.""" + shim_legacy = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) + + +@pytest.mark.parametrize("kind", ["fp_copy_at", "fp_exp_at", "fp_reci_at", "fp_sqrt_at"]) +def test_fp_unary_at_byte_equal(kind): + _byte_equal(_hlir_unary(kind)) + + +@pytest.mark.parametrize("kind", ["fp_add_at", "fp_sub_at", "fp_mul_at", "fp_max_at"]) +def test_fp_binary_at_byte_equal(kind): + _byte_equal(_hlir_binary(kind)) diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_fp_zero_at.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_fp_zero_at.py new file mode 100644 index 0000000..edb9c47 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_fp_zero_at.py @@ -0,0 +1,112 @@ +"""End-to-end PreIsaIR pipeline verification for ``fp_zero_at`` (the +first migrated handler). + +Two paths produce ISA text for the same minimal HLIRModule: + + legacy: HLIR -> IsaEmitterPass.run -> ISA text + new: HLIR -> PreIsaPass.run -> PreIsaModule + -> BackendEmit.run -> ISA text + +With no optimisation enabled, the new path must produce +**byte-equal** ISA to the legacy path (modulo the header comment block +emitted directly by IsaEmitterPass.run vs the same comments produced +as _COMMENT PreIsaOps; we compare only the non-comment lines, which +are the actual HW instructions). + +This is the proof-of-concept that the migration architecture works: +PreIsaIR + BackendEmit round-trips a real ISA emission with no loss +or drift. +""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +def _build_minimal_hlir(): + """One ``fp_zero_at`` op against an FPRAM buffer at addr=128. + ``fp_zero_at`` is the simplest leaf — one scalar arg, two ISA + lines (comment + S_ST_FP).""" + buf = _hlir.Buffer( + name="dst_fp", + scope=_scope.FPRAM, + shape=(1,), + dtype="float16", + address=128, + ) + op = _hlir.Op( + kind="fp_zero_at", + buffer_args=[], + scalar_args=[_hlir.BufferElement(buffer="dst_fp", indices=(0,))], + ) + return _hlir.HLIRModule( + name="fpz", + buffers={"dst_fp": buf}, + ops=[op], + param_names=[], + ) + + +def _instr_lines(text: str): + """Return the non-comment, non-blank instruction lines from ``text``. + These are the actual HW instructions whose count + form must match + byte-for-byte between the legacy and PreIsaIR paths.""" + return [ + ln.strip() + for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def test_legacy_path_emits_s_st_fp(): + """Legacy path: the materialiser writes one S_ADDI_INT to load the + FPRAM address into a GP, then the handler emits the S_ST_FP.""" + shim = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + isa = IsaEmitterPass(shim).run(_build_minimal_hlir()) + instrs = _instr_lines(isa) + assert instrs == ["S_ADDI_INT gp1, gp0, 128", "S_ST_FP f0, gp1, 0"], ( + f"legacy expected S_ADDI_INT + S_ST_FP; got {instrs!r}" + ) + + +def test_pre_isa_path_produces_pre_isa_module(): + shim = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPass(shim).run(_build_minimal_hlir()) + real_ops = [op for op in pre.ops if op.opcode != "_COMMENT"] + assert len(real_ops) == 1 + assert real_ops[0].opcode == "S_ST_FP" + # Operands: [str "f0", PrimExpr dst_addr, int 0] + operands = real_ops[0].operands + assert operands[0] == "f0" + assert isinstance(operands[2], int) and operands[2] == 0 + # The middle operand is a PrimExpr in var-ref form (the BufferElement + # resolved to ``128 + 0 == 128``, an IntImm — but the materialiser + # will be the one to lower it later). + import tvm.tir as _tir + assert isinstance(operands[1], (_tir.PrimExpr, int)) + + +def test_backend_emit_produces_byte_equal_isa(): + """The whole point of this proof-of-concept: drive PreIsaPass + + BackendEmit and confirm the ISA instruction lines are byte-equal + to the legacy IsaEmitterPass output.""" + hlir = _build_minimal_hlir() + + shim_legacy = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + legacy_instrs = _instr_lines(legacy_isa) + new_instrs = _instr_lines(new_isa) + assert legacy_instrs == new_instrs, ( + "PreIsaIR path must produce byte-equal HW instructions to " + "the legacy path when no optimisation is enabled.\n" + f"legacy: {legacy_instrs!r}\n" + f"new: {new_instrs!r}" + ) diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_matmul.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_matmul.py new file mode 100644 index 0000000..741a6b6 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_matmul.py @@ -0,0 +1,112 @@ +"""Structural verification for the migrated ``matmul`` (the unified +plena.matmul HLIR op that lowers to emit_matmul_general). + +This is the most complex matmul handler in the migration: 5-deep +nested LOOP_START loops (m, n_mlen, oc, orow, k) with PrimExpr +addresses referencing the loop vars + hw-shape consts +(MLEN_VAR / BLEN_VAR). The optimiser sees the full address algebra. + +We use a coarse structural check (M_MM / M_MM_WO counts equal +legacy) rather than strict byte-equal: the PreIsaIR's per_iter +nested unrolls allocate GPs differently than legacy's pre-pinned +7-reg block, and the legacy emit_matmul_general(unroll_loops=True) +bakes literal addresses into S_ADDI_INTs while PreIsaIR keeps them +symbolic — so the S_ADDI count + form will be different. M_MM / +M_MM_WO counts are the right invariant: same algorithm, same +sub-tile structure. +""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +BLEN = 4 +HLEN = 16 + + +def _vram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def _counts(isa: str): + mm = sum(1 for ln in isa.split("\n") if ln.strip().startswith("M_MM ")) + tmm = sum(1 for ln in isa.split("\n") if ln.strip().startswith("M_TMM ")) + wo = sum(1 for ln in isa.split("\n") if ln.strip().startswith("M_MM_WO ")) + return mm, tmm, wo + + +def test_matmul_general_structural_equal(): + """Single-tile (M=K=N=mlen) non-transpose matmul. Legacy + emit_matmul_general(unroll_loops=True) produces 16 (M_MM, M_MM_WO) + pairs for M_tiles=K_tiles=1, N_mlen_tiles=1, tiles_per_n_mlen=16, + tiles_per_mlen=16. Per orow: K_tiles M_MMs then 1 M_MM_WO. + Total: 16 orow * 16 oc * 1 K_tile = 256 M_MMs, 256 M_MM_WOs. + """ + # 4D BSHD shape: (1, M, 1, K) for A, (1, K, 1, N) for B in row-major + # (we use rank-4 with the M / K / N roles tagged below). + M, K, N = MLEN, MLEN, MLEN + lhs = _vram("lhs", 0, (1, M, 1, K)) + rhs = _mram("rhs", 4096, (1, K, 1, N)) + dst = _vram("dst", 8192, (1, M, 1, N)) + + op = _hlir.Op( + kind="matmul", + buffer_args=[ + _hlir.VramRegion(parent="lhs", starts=(0, 0, 0, 0), + extents=(1, M, 1, K)), + _hlir.MramRegion(parent="rhs", starts=(0, 0, 0, 0), + extents=(1, K, 1, N)), + _hlir.VramRegion(parent="dst", starts=(0, 0, 0, 0), + extents=(1, M, 1, N)), + ], + scalar_args=[ + ("_", "M", "_", "K"), + ("_", "K", "_", "N"), + ("_", "M", "_", "N"), + ], + annotations={"intrinsic": "matmul_test"}, + ) + hlir = _hlir.HLIRModule( + name="matmul_smoke", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + legacy_mm, legacy_tmm, legacy_wo = _counts(legacy_isa) + new_mm, new_tmm, new_wo = _counts(new_isa) + # Expected pair count: tiles_per_n_mlen * tiles_per_mlen * K_tiles + # = 16 * 16 * 1 = 256 M_MMs; 16 * 16 = 256 M_MM_WOs. + expected_mm = 16 * 16 * 1 + expected_wo = 16 * 16 + assert legacy_mm == expected_mm, ( + f"legacy M_MM count: {legacy_mm} (expected {expected_mm})" + ) + assert new_mm == expected_mm, ( + f"new M_MM count: {new_mm} (expected {expected_mm})" + ) + assert legacy_wo == expected_wo + assert new_wo == expected_wo + assert legacy_tmm == 0 and new_tmm == 0 diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_mm.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_mm.py new file mode 100644 index 0000000..2904758 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_mm.py @@ -0,0 +1,130 @@ +"""Semantic byte-equal verification for the migrated ``mm`` (M_MM + +M_MM_WO, single-tile mlen*mlen). + +Legacy emit_matmul_single_tile_hwloop pre-allocates ``allocate_gp(6)`` +and uses only 3 of those GPs (with a specific non-sequential +assignment); the per-iter PreIsaIR materialiser can't reproduce that +allocation scheme without abandoning the var-ref operand model. We +use ``semantic_isa_equal`` instead: mnemonics + literals must match, +GP renumbering is allowed. +""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + +from ._isa_diff import assert_semantic_isa_equal + + +MLEN = 64 +BLEN = 4 + + +def _vram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def test_mm_narrow_tile_structural_equal(): + """``mlen*mlen @ mlen*hlen -> mlen*hlen`` matmul. + + Strict semantic equality fails on narrow MM because legacy + materialises ``mat_addr`` ONCE per oc-iter (before the inner t + loop), while PreIsaIR's per_iter inner unroll closes the + surrounding scope and forces a re-preload per (oc, t). That + emits ``tiles_per_mlen - 1 = 15`` extra S_ADDIs per oc — a known + instruction-count divergence. The structural check verifies the + PreIsaIR path produces the right MNEMONIC stream around each + M_MM / M_MM_WO pair: a sequence of ``S_ADDI_INT``s setting up + the operand GPs, then ``M_MM``, then ``M_MM_WO``. Counts are + checked at the M_MM / M_MM_WO level, not per S_ADDI. + + The address-algebra optimisation opportunity exists either way + (mat_addr expr ``rhs.base + oc * blen`` is preserved as a + PrimExpr) — only the ISA emit form differs. + """ + MLEN, BLEN_L, HLEN = 64, 4, 16 + lhs = _vram("lhs", 0, (MLEN, MLEN)) + rhs = _mram("rhs", 4096, (MLEN, HLEN)) + dst = _vram("dst", 8192, (MLEN, HLEN)) + + op = _hlir.Op( + kind="mm", + buffer_args=["lhs", "rhs", "dst"], + scalar_args=[], + annotations={"intrinsic": "mm_narrow_test"}, + ) + hlir = _hlir.HLIRModule( + name="mm_narrow_smoke", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN_L, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN_L, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + def _mm_pairs(isa: str): + """Return the count of (M_MM, M_MM_WO) pairs in stream order.""" + mm = sum(1 for ln in isa.split("\n") + if ln.strip().startswith("M_MM ")) + wo = sum(1 for ln in isa.split("\n") + if ln.strip().startswith("M_MM_WO ")) + return mm, wo + + legacy_mm, legacy_wo = _mm_pairs(legacy_isa) + new_mm, new_wo = _mm_pairs(new_isa) + assert legacy_mm == new_mm, ( + f"M_MM count differs: legacy={legacy_mm} new={new_mm}" + ) + assert legacy_wo == new_wo, ( + f"M_MM_WO count differs: legacy={legacy_wo} new={new_wo}" + ) + # tiles_per_slot * tiles_per_mlen = 4 * 16 = 64 M_MM pairs total. + assert new_mm == 64 + assert new_wo == 64 + + +def test_mm_single_tile_semantic_equal(): + """Single mlen*mlen MM. Legacy emits tiles_per_mlen² = 16² = 256 + (M_MM, M_MM_WO) pairs after 3 S_ADDI_INTs each — 1280 instr total.""" + lhs = _vram("lhs", 0, (MLEN, MLEN)) + rhs = _mram("rhs", 4096, (MLEN, MLEN)) + dst = _vram("dst", 8192, (MLEN, MLEN)) + + op = _hlir.Op( + kind="mm", + buffer_args=["lhs", "rhs", "dst"], + scalar_args=[], + annotations={"intrinsic": "mm_test"}, + ) + hlir = _hlir.HLIRModule( + name="mm_smoke", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=16) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert_semantic_isa_equal(legacy_isa, new_isa) diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_mm_slot.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_mm_slot.py new file mode 100644 index 0000000..7dc15c3 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_mm_slot.py @@ -0,0 +1,88 @@ +"""Structural verification for the migrated ``mm_slot`` (slot matmul). + +Legacy ``emit_slot_matmul`` runs ``tiles_per_slot * tiles_per_mlen`` +(M_MM, M_MM_WO) pairs. PreIsaIR migration: outer oc loop is +per_iter, inner t loop is shared-scope so the destructive +act/out S_ADDI bumps carry across t-iters. + +Static-offset path only — dynamic offset migration is a TODO. +""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +BLEN_L = 4 +HLEN = 16 + + +def _vram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def _mm_pairs(isa: str): + mm = sum(1 for ln in isa.split("\n") if ln.strip().startswith("M_MM ")) + wo = sum(1 for ln in isa.split("\n") if ln.strip().startswith("M_MM_WO ")) + return mm, wo + + +def test_mm_slot_structural_equal(): + """LHS mlen*mlen tile sliced from a 2*mlen*mlen buffer at offset 0; + RHS / DST are 64*256 wide tensors with rhs_col_offset=64, + dst_col_offset=128, col_count=16. + Pairs: tiles_per_slot * tiles_per_mlen = 4 * 16 = 64.""" + lhs = _vram("lhs", 0, (MLEN * 2, MLEN)) + rhs = _mram("rhs", 4096, (MLEN, 256)) + dst = _vram("dst", 8192, (MLEN, 256)) + + op = _hlir.Op( + kind="mm_slot", + buffer_args=["lhs", "rhs", "dst"], + scalar_args=[ + 0, # lhs_row_offset + 64, # rhs_col_offset + 128, # dst_col_offset + HLEN, # col_count = 16, divisible by blen=4 + ], + annotations={"intrinsic": "mm_slot_test"}, + ) + hlir = _hlir.HLIRModule( + name="mm_slot_smoke", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN_L, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN_L, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + legacy_mm, legacy_wo = _mm_pairs(legacy_isa) + new_mm, new_wo = _mm_pairs(new_isa) + assert legacy_mm == new_mm, ( + f"M_MM count differs: legacy={legacy_mm} new={new_mm}" + ) + assert legacy_wo == new_wo, ( + f"M_MM_WO count differs: legacy={legacy_wo} new={new_wo}" + ) + # tiles_per_slot * tiles_per_mlen = 4 * 16 = 64. + assert new_mm == 64 + assert new_wo == 64 diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_mv.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_mv.py new file mode 100644 index 0000000..6189561 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_mv.py @@ -0,0 +1,77 @@ +"""Byte-equal verification for the migrated ``mv`` (M_MV + M_MV_WO). + +Exercises emit_mv's tiles-loop unroll and the destructive in-place +``S_ADDI_INT gp{m}, gp{m}, blen`` stride bump between iterations. +""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +HLEN = 16 +BLEN = 4 + + +def _instr_lines(text): + return [ + ln.strip() for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def _vram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def test_mv_byte_equal(): + """Single-head MV: lhs (1, hlen) vector @ rhs (mlen, hlen) matrix + -> dst (1, hlen). emit_mv runs tiles = hlen/blen = 4 iterations.""" + lhs = _vram("lhs", 0, (1, HLEN)) + rhs = _mram("rhs", 4096, (MLEN, HLEN)) + dst = _vram("dst", 8192, (1, HLEN)) + + op = _hlir.Op( + kind="mv", + buffer_args=[ + _hlir.VramRegion(parent="lhs", starts=(0, 0), + extents=(1, HLEN)), + _hlir.MramRegion(parent="rhs", starts=(0, 0), + extents=(MLEN, HLEN)), + _hlir.VramRegion(parent="dst", starts=(0, 0), + extents=(1, HLEN)), + ], + scalar_args=[], + annotations={"intrinsic": "mv_test"}, + ) + hlir = _hlir.HLIRModule( + name="mv_smoke", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + + shim_legacy = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_row_ops.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_row_ops.py new file mode 100644 index 0000000..9ecfb89 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_row_ops.py @@ -0,0 +1,202 @@ +"""Byte-equal verification for the migrated ``row_*_at`` family: +``row_reduce_max_at`` / ``row_reduce_sum_at`` / +``row_exp`` / ``row_sub_fp`` / ``row_mul_fp`` / ``row_add_fp``. + +These ops: + * walk d_tiles in an unrolled loop (n_d_tiles HW ops per call); + * use a destructive in-place stride-bump pattern on the cached + src/dst GPs; + * optionally arm/reset the V_MASK register for packed-head buffers. + +The PreIsaIR migration exercises _PRELOAD_ADDR, _BUMP_CACHED_GP, +group_id scoping, and the cached-GP slot kind all at once. +""" + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +HLEN = 16 + + +def _instr_lines(text: str): + return [ + ln.strip() + for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def _vram_buf(name: str, addr: int) -> _hlir.Buffer: + """A simple cluster-less 1×MLEN×1×MLEN VRAM buffer with + tile_layout=None (single-tile path inside + _logical_to_phys_row_offset / _tile_layout_strides).""" + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(1, MLEN, 1, MLEN), dtype="float16", + address=addr, + # Single-inner-tile case: d_tiles=1, no packed-head mask. + tile_layout=_hlir.TileLayout( + logical_b=1, logical_s=MLEN, logical_h=1, logical_d=MLEN, + d_tiles=1, s_tiles=1, h_groups=1, + mlen=MLEN, lane_count=1, d_inner=MLEN, + ), + cluster_dim=None, + ) + + +def _fpram_slot(name: str, addr: int) -> _hlir.Buffer: + return _hlir.Buffer( + name=name, scope=_scope.FPRAM, + shape=(1,), dtype="float16", address=addr, + ) + + +def _row_region(name: str, row: int) -> _hlir.VramRegion: + """One logical row of VRAM: extents (1,1,1,MLEN), starts pick the + row index in S.""" + return _hlir.VramRegion( + parent=name, starts=(0, row, 0, 0), extents=(1, 1, 1, MLEN), + ) + + +def _byte_equal(hlir: _hlir.HLIRModule) -> None: + shim_legacy = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) + + +@pytest.mark.parametrize("kind", ["row_reduce_max_at", "row_reduce_sum_at"]) +def test_row_reduce_byte_equal(kind): + """Reduce: src VRAM row → FPRAM scalar accumulator.""" + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "src": _vram_buf("src", 0), + "fp": _fpram_slot("fp", 1024), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_row_region("src", row=5)], + scalar_args=[_hlir.BufferElement(buffer="fp", indices=(0,))], + )], + param_names=[], + ) + _byte_equal(hlir) + + +def test_row_exp_byte_equal(): + hlir = _hlir.HLIRModule( + name="row_exp", + buffers={ + "src": _vram_buf("src", 0), + "dst": _vram_buf("dst", MLEN * MLEN), + }, + ops=[_hlir.Op( + kind="row_exp", + buffer_args=[_row_region("src", row=2), _row_region("dst", row=2)], + scalar_args=[], + )], + param_names=[], + ) + _byte_equal(hlir) + + +@pytest.mark.parametrize("kind", ["row_add_fp", "row_sub_fp", "row_mul_fp"]) +def test_row_binary_fp_byte_equal(kind): + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "src": _vram_buf("src", 0), + "dst": _vram_buf("dst", MLEN * MLEN), + "fp": _fpram_slot("fp", 2048), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_row_region("src", row=3), _row_region("dst", row=3)], + scalar_args=[_hlir.BufferElement(buffer="fp", indices=(0,))], + )], + param_names=[], + ) + _byte_equal(hlir) + + +# ---------- multi-d_tile coverage (exercises _BUMP_CACHED_GP loop) ---------- + +def _vram_buf_wide(name: str, addr: int) -> _hlir.Buffer: + """A 1×MLEN×1×(2*MLEN) VRAM buffer with d_tiles=2 — exercises the + stride-bump unroll loop in row_*_at.""" + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(1, MLEN, 1, 2 * MLEN), dtype="float16", + address=addr, + tile_layout=_hlir.TileLayout( + logical_b=1, logical_s=MLEN, logical_h=1, logical_d=2 * MLEN, + d_tiles=2, s_tiles=1, h_groups=1, + mlen=MLEN, lane_count=1, d_inner=MLEN, + ), + cluster_dim=None, + ) + + +def _row_region_wide(name: str, row: int) -> _hlir.VramRegion: + return _hlir.VramRegion( + parent=name, starts=(0, row, 0, 0), extents=(1, 1, 1, 2 * MLEN), + ) + + +def test_row_exp_multi_d_tile_byte_equal(): + """d_tiles=2 means the d_tile unroll loop fires _BUMP_CACHED_GP + once between the two V_EXP_V ops.""" + hlir = _hlir.HLIRModule( + name="row_exp_wide", + buffers={ + "src": _vram_buf_wide("src", 0), + "dst": _vram_buf_wide("dst", 4096), + }, + ops=[_hlir.Op( + kind="row_exp", + buffer_args=[ + _row_region_wide("src", row=1), + _row_region_wide("dst", row=1), + ], + scalar_args=[], + )], + param_names=[], + ) + _byte_equal(hlir) + + +def test_row_mul_fp_multi_d_tile_byte_equal(): + hlir = _hlir.HLIRModule( + name="row_mul_wide", + buffers={ + "src": _vram_buf_wide("src", 0), + "dst": _vram_buf_wide("dst", 4096), + "fp": _fpram_slot("fp", 8192), + }, + ops=[_hlir.Op( + kind="row_mul_fp", + buffer_args=[ + _row_region_wide("src", row=2), + _row_region_wide("dst", row=2), + ], + scalar_args=[_hlir.BufferElement(buffer="fp", indices=(0,))], + )], + param_names=[], + ) + _byte_equal(hlir) diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_transfer.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_transfer.py new file mode 100644 index 0000000..280a9e2 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_transfer.py @@ -0,0 +1,101 @@ +"""Byte-equal verification for the migrated transfer ops: +``copy_v_to_v`` / ``v_fp_transfer_slice_v_to_fp`` +/ ``v_fp_transfer_slice_fp_to_v``. +""" + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +HLEN = 16 + + +def _instr_lines(text: str): + return [ + ln.strip() + for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def _vram_buf(name: str, addr: int) -> _hlir.Buffer: + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(1, MLEN, 1, MLEN), dtype="float16", + address=addr, cluster_dim=None, tile_layout=None, + ) + + +def _fpram_buf(name: str, addr: int, shape=(MLEN,)) -> _hlir.Buffer: + return _hlir.Buffer( + name=name, scope=_scope.FPRAM, + shape=shape, dtype="float16", address=addr, + ) + + +def _whole(name: str) -> _hlir.VramRegion: + return _hlir.VramRegion( + parent=name, starts=(0, 0, 0, 0), extents=(1, MLEN, 1, MLEN), + ) + + +def _byte_equal(hlir: _hlir.HLIRModule) -> None: + shim_legacy = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) + + +def test_copy_v_to_v_byte_equal(): + hlir = _hlir.HLIRModule( + name="copy_v_to_v", + buffers={ + "src": _vram_buf("src", 0), + "dst": _vram_buf("dst", MLEN * MLEN), + }, + ops=[_hlir.Op( + kind="copy_v_to_v", + buffer_args=[_whole("src"), _whole("dst")], + scalar_args=[], + )], + param_names=[], + ) + _byte_equal(hlir) + + +@pytest.mark.parametrize("kind", [ + "v_fp_transfer_slice_v_to_fp", "v_fp_transfer_slice_fp_to_v", +]) +def test_v_fp_transfer_slice_byte_equal(kind): + """1-row VRAM region + 1-element FPRAM slot. ``_vram_region_iter_chunks`` + yields one chunk → exactly one S_MAP_*.""" + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "v": _vram_buf("v", 0), + "fp": _fpram_buf("fp", 8192, shape=(MLEN,)), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_hlir.VramRegion( + parent="v", starts=(0, 0, 0, 0), + extents=(1, 1, 1, MLEN), + )], + scalar_args=[_hlir.BufferElement(buffer="fp", indices=(0,))], + )], + param_names=[], + ) + _byte_equal(hlir) diff --git a/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_v_ops.py b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_v_ops.py new file mode 100644 index 0000000..7961ca8 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_pre_isa_pipeline_v_ops.py @@ -0,0 +1,122 @@ +"""Byte-equal verification for the migrated vector ops: +``v_zero`` / ``v_add`` / ``v_sub`` / ``v_mul`` / ``v_exp`` / ``v_reci`` +/ ``v_sqrt``. + +These ops all walk VRAM regions chunk-by-chunk (via the legacy +``_vram_region_iter_chunks``) and emit one HW vector instruction per +chunk. The PreIsaIR migration must produce byte-equal HW instruction +sequences to the legacy ``IsaEmitterPass``. +""" + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler.backend_emit import BackendEmit +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass import PreIsaPass +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +HLEN = 16 + + +def _instr_lines(text: str): + return [ + ln.strip() + for ln in text.split("\n") + if ln.strip() and not ln.strip().startswith(";") + ] + + +def _vram_buf(name: str, addr: int) -> _hlir.Buffer: + """A simple cluster-less, tile-layout-less 4D VRAM buffer + (1, MLEN, 1, MLEN). _vram_region_iter_chunks takes the + row-major-flat code path for these (cluster_dim=None, + tile_layout=None) — keeps the test setup minimal.""" + return _hlir.Buffer( + name=name, + scope=_scope.VRAM, + shape=(1, MLEN, 1, MLEN), + dtype="float16", + address=addr, + cluster_dim=None, + tile_layout=None, + ) + + +def _whole_region(name: str) -> _hlir.VramRegion: + return _hlir.VramRegion( + parent=name, + starts=(0, 0, 0, 0), + extents=(1, MLEN, 1, MLEN), + ) + + +def _byte_equal(hlir: _hlir.HLIRModule) -> None: + shim_legacy = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=HLEN) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPass(shim_new).run(hlir) + new_isa = BackendEmit(shim_new).run(pre) + + assert _instr_lines(legacy_isa) == _instr_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nnew:\n{new_isa}" + ) + + +def test_v_zero_byte_equal(): + hlir = _hlir.HLIRModule( + name="v_zero_smoke", + buffers={"dst": _vram_buf("dst", 0)}, + ops=[_hlir.Op( + kind="v_zero", + buffer_args=[_whole_region("dst")], + scalar_args=[], + )], + param_names=[], + ) + _byte_equal(hlir) + + +@pytest.mark.parametrize("kind", ["v_add", "v_sub", "v_mul"]) +def test_v_binary_byte_equal(kind): + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "lhs": _vram_buf("lhs", 0), + "rhs": _vram_buf("rhs", MLEN * MLEN), + "dst": _vram_buf("dst", 2 * MLEN * MLEN), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[ + _whole_region("lhs"), + _whole_region("rhs"), + _whole_region("dst"), + ], + scalar_args=[], + )], + param_names=[], + ) + _byte_equal(hlir) + + +@pytest.mark.parametrize("kind", ["v_exp", "v_reci", "v_sqrt"]) +def test_v_unary_byte_equal(kind): + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "src": _vram_buf("src", 0), + "dst": _vram_buf("dst", MLEN * MLEN), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_whole_region("src"), _whole_region("dst")], + scalar_args=[], + )], + param_names=[], + ) + _byte_equal(hlir) diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_btmm_mv.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_btmm_mv.py new file mode 100644 index 0000000..c26fe5a --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_btmm_mv.py @@ -0,0 +1,149 @@ +"""End-to-end v2 tests for btmm / btmv / mv. + +Structural comparison: same set of HW mnemonics (M_BTMM / M_BMM_WO / +M_BTMV / M_BMV_WO / M_MV / M_MV_WO) in same order. Address-setup +S_ADDI / S_SLLI / etc. skipped since v2 builds addresses from SSA +chains while legacy uses literals. +""" + +import re + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +HLEN = 16 +LANE = 4 +_GP_RE = re.compile(r"\bgp\d+\b") +_ADDR_SETUP = frozenset({"S_ADDI_INT", "S_SLLI_INT", "S_SRLI_INT", "S_LUI_INT"}) + + +def _strip(isa: str): + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head in _ADDR_SETUP: + continue + out.append(_GP_RE.sub("gpX", s)) + return out + + +def _vram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=MLEN, blen=4, btmm_lane_count=LANE, btmm_hlen=HLEN) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=MLEN, blen=4, btmm_lane_count=LANE, btmm_hlen=HLEN) + return IsaEmitterPass(shim).run(hlir) + + +def test_btmm_v2_matches_legacy(): + lhs = _vram("lhs", 0, (MLEN, LANE, HLEN)) + rhs = _mram("rhs", 4096, (MLEN, LANE, HLEN)) + dst = _vram("dst", 8192, (MLEN, LANE, MLEN)) + op = _hlir.Op( + kind="btmm", + buffer_args=[ + _hlir.VramRegion(parent="lhs", starts=(0, 0, 0), + extents=(MLEN, LANE, HLEN)), + _hlir.MramRegion(parent="rhs", starts=(0, 0, 0), + extents=(MLEN, LANE, HLEN)), + _hlir.VramRegion(parent="dst", starts=(0, 0, 0), + extents=(MLEN, LANE, MLEN)), + ], + scalar_args=[], + annotations={"intrinsic": "btmm_test"}, + ) + hlir = _hlir.HLIRModule( + name="btmm", buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + + +def test_btmv_v2_matches_legacy(): + lhs = _vram("lhs", 0, (1, LANE, HLEN)) + rhs = _mram("rhs", 4096, (MLEN, LANE, HLEN)) + dst = _vram("dst", 8192, (1, LANE, MLEN)) + op = _hlir.Op( + kind="btmv", + buffer_args=[ + _hlir.VramRegion(parent="lhs", starts=(0, 0, 0), + extents=(1, LANE, HLEN)), + _hlir.MramRegion(parent="rhs", starts=(0, 0, 0), + extents=(MLEN, LANE, HLEN)), + _hlir.VramRegion(parent="dst", starts=(0, 0, 0), + extents=(1, LANE, MLEN)), + ], + scalar_args=[], + annotations={"intrinsic": "btmv_test"}, + ) + hlir = _hlir.HLIRModule( + name="btmv", buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + + +def test_mv_v2_matches_legacy(): + lhs = _vram("lhs", 0, (1, HLEN)) + rhs = _mram("rhs", 4096, (MLEN, HLEN)) + dst = _vram("dst", 8192, (1, HLEN)) + op = _hlir.Op( + kind="mv", + buffer_args=[ + _hlir.VramRegion(parent="lhs", starts=(0, 0), extents=(1, HLEN)), + _hlir.MramRegion(parent="rhs", starts=(0, 0), extents=(MLEN, HLEN)), + _hlir.VramRegion(parent="dst", starts=(0, 0), extents=(1, HLEN)), + ], + scalar_args=[], + annotations={"intrinsic": "mv_test"}, + ) + hlir = _hlir.HLIRModule( + name="mv", buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_dma.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_dma.py new file mode 100644 index 0000000..fcfbb8d --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_dma.py @@ -0,0 +1,208 @@ +"""End-to-end v2 tests for DMA family — ``dma_h2v``, ``dma_h2m``, +``dma_v2h``. + +Each one walks an HBM buffer's tile grid via the same +``_iter_tile_offsets`` legacy uses. Per-tile body is a fixed sequence +of ``C_SET_ADDR_REG`` + ``C_SET_SCALE_REG`` + ``C_SET_STRIDE_REG`` + +one or more ``H_PREFETCH_V`` / ``H_PREFETCH_M`` / ``H_STORE_V``. + +Structural verification: equal counts of HW HBM ops on both sides +(legacy → v2). Address-setup S_ADDI_INT and similar scalar arithmetic +are skipped (different in-place vs SSA-rebuild styles). +""" + +import re + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +BLEN = 4 +HLEN = 16 + + +def _counts(isa: str): + """Return (H_PREFETCH_V, H_PREFETCH_M, H_STORE_V) line counts.""" + pv = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_PREFETCH_V")) + pm = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_PREFETCH_M")) + sv = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_STORE_V")) + return pv, pm, sv + + +def _set_counts(isa: str): + sa = sum(1 for ln in isa.split("\n") if ln.strip().startswith("C_SET_ADDR_REG")) + sc = sum(1 for ln in isa.split("\n") if ln.strip().startswith("C_SET_SCALE_REG")) + st = sum(1 for ln in isa.split("\n") if ln.strip().startswith("C_SET_STRIDE_REG")) + return sa, sc, st + + +def _hbm(name, addr, num_elements, + hbm_stride=MLEN, hbm_scale_size=MLEN * MLEN, hbm_offset=0): + return _hlir.Buffer( + name=name, scope=_scope.HBM, + shape=(num_elements,), dtype="float16", + address=addr, + hbm_offset=hbm_offset, + hbm_stride=hbm_stride, + hbm_scale_size=hbm_scale_size, + ) + + +def _vram(name, addr): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(MLEN, MLEN), dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, + shape=(MLEN, MLEN), dtype="float16", address=addr, + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + return IsaEmitterPass(shim).run(hlir) + + +def test_dma_h2m_single_tile_v2(): + """Single-tile HBM → MRAM. One H_PREFETCH_M, one addr_reg setup.""" + src = _hbm("src_hbm", 0, MLEN * MLEN) + dst = _mram("dst_mram", 4096) + op = _hlir.Op( + kind="dma_h2m", + buffer_args=["src_hbm", "dst_mram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2m_smoke", + buffers={"src_hbm": src, "dst_mram": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + + l_pv, l_pm, l_sv = _counts(legacy) + n_pv, n_pm, n_sv = _counts(new) + assert l_pm == n_pm == 1, ( + f"H_PREFETCH_M legacy={l_pm} v2={n_pm}\nv2:\n{new}" + ) + assert n_pv == l_pv == 0 + assert n_sv == l_sv == 0 + # v2 also binds the addr_reg once and sets scale + stride once + # (then resets at end). Sanity: at least one set of each. + n_sa, n_sc, n_st = _set_counts(new) + assert n_sa == 1 + assert n_sc >= 1 + assert n_st >= 1 + + +def test_dma_h2v_single_tile_v2(): + """Single-tile HBM → VRAM. H_PREFETCH_V count = inner_count * + load_amount_per_hidden. For mlen=64, v_prefetch=1 → 64 prefetches.""" + src = _hbm("src_hbm", 0, MLEN * MLEN) + dst = _vram("dst_vram", 4096) + op = _hlir.Op( + kind="dma_h2v", + buffer_args=["src_hbm", "dst_vram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2v_smoke", + buffers={"src_hbm": src, "dst_vram": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + l_pv, l_pm, l_sv = _counts(legacy) + n_pv, n_pm, n_sv = _counts(new) + assert l_pv == n_pv, ( + f"H_PREFETCH_V legacy={l_pv} v2={n_pv}\nv2:\n{new[:2000]}" + ) + assert n_pm == l_pm == 0 + assert n_sv == l_sv == 0 + + +def test_dma_v2h_single_tile_v2(): + """Single-tile VRAM → HBM. H_STORE_V count = inner_count * + store_amount_per_hidden. For mlen=64, v_writeback=4 → 16 stores.""" + src = _vram("src_vram", 0) + dst = _hbm("dst_hbm", 4096, MLEN * MLEN) + op = _hlir.Op( + kind="dma_v2h", + buffer_args=["src_vram", "dst_hbm"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_v2h_smoke", + buffers={"src_vram": src, "dst_hbm": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + l_pv, l_pm, l_sv = _counts(legacy) + n_pv, n_pm, n_sv = _counts(new) + assert l_sv == n_sv, ( + f"H_STORE_V legacy={l_sv} v2={n_sv}\nv2:\n{new[:2000]}" + ) + assert n_pv == l_pv == 0 + assert n_pm == l_pm == 0 + + +def test_dma_h2v_multi_tile_v2(): + """4-tile HBM (2 rows × 2 cols of mlen×mlen) → VRAM. Expect 4× + the single-tile H_PREFETCH_V count.""" + rows = 2 + cols = 2 + src = _hbm( + "src_hbm", 0, MLEN * MLEN * rows * cols, + hbm_stride=MLEN * cols, + hbm_scale_size=MLEN * MLEN, + ) + src.annotations["logical_rows"] = MLEN * rows + src.annotations["logical_cols"] = MLEN * cols + src.annotations["row_blocks"] = rows + src.annotations["col_blocks"] = cols + # dst VRAM big enough for all tiles. + dst = _hlir.Buffer( + name="dst_vram", scope=_scope.VRAM, + shape=(rows * cols * MLEN, MLEN), dtype="float16", + address=4096, cluster_dim=None, tile_layout=None, + ) + op = _hlir.Op( + kind="dma_h2v", + buffer_args=["src_hbm", "dst_vram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2v_multi", + buffers={"src_hbm": src, "dst_vram": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + l_pv, l_pm, l_sv = _counts(legacy) + n_pv, n_pm, n_sv = _counts(new) + assert l_pv == n_pv, ( + f"H_PREFETCH_V legacy={l_pv} v2={n_pv}\nv2:\n{new[:2000]}" + ) diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_dma_slice.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_dma_slice.py new file mode 100644 index 0000000..16d0025 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_dma_slice.py @@ -0,0 +1,223 @@ +"""End-to-end v2 tests for DMA slice family — ``dma_h2v_slice``, +``dma_h2m_slice``, ``dma_v2h_slice``. + +Slice is a ``BufferSlice(parent, starts, extents)`` of an HBM +buffer. v2 walks a 4-level (d_tile, s_tile, h_grp, b) tile grid via +nested LoopRegions; the per-tile body is the same preload/store +helper as the whole-buffer DMA path. Slice ``starts`` may be static +(int / IntImm) or dynamic (PrimExpr referencing an outer +LoopRegion's loop_var) — arith.simplify in pre_isa_to_mir folds +static cases away. + +Structural verification: H_PREFETCH_V / H_PREFETCH_M / H_STORE_V +counts match legacy. +""" + +import pytest +from tvm import tir + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +BLEN = 4 +HLEN = 16 + + +def _counts(isa: str): + pv = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_PREFETCH_V")) + pm = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_PREFETCH_M")) + sv = sum(1 for ln in isa.split("\n") if ln.strip().startswith("H_STORE_V")) + return pv, pm, sv + + +def _hbm_2d(name, addr, rows, cols): + return _hlir.Buffer( + name=name, scope=_scope.HBM, + shape=(rows, cols), dtype="float16", + address=addr, + hbm_offset=0, hbm_stride=cols, + hbm_scale_size=MLEN * MLEN, + ) + + +def _vram_2d(name, addr, rows=MLEN, cols=MLEN): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(rows, cols), dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram_2d(name, addr, rows=MLEN, cols=MLEN): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, + shape=(rows, cols), dtype="float16", address=addr, + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + return IsaEmitterPass(shim).run(hlir) + + +def test_dma_h2v_slice_static_v2(): + """Single-tile static slice from a 2*mlen × 2*mlen parent.""" + parent = _hbm_2d("hbm", 0, 2 * MLEN, 2 * MLEN) + dst = _vram_2d("vram", 4096) + sl = _hlir.BufferSlice( + parent="hbm", starts=(MLEN, MLEN), extents=(MLEN, MLEN), + ) + op = _hlir.Op( + kind="dma_h2v_slice", + buffer_args=[sl, "vram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2v_slice_smoke", + buffers={"hbm": parent, "vram": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + l_pv, _, _ = _counts(legacy) + n_pv, _, _ = _counts(new) + assert l_pv == n_pv > 0, ( + f"H_PREFETCH_V legacy={l_pv} v2={n_pv}\nv2:\n{new[:2000]}" + ) + + +def test_dma_h2m_slice_static_v2(): + """Single-tile slice into MRAM.""" + parent = _hbm_2d("hbm", 0, 2 * MLEN, 2 * MLEN) + dst = _mram_2d("mram", 4096) + sl = _hlir.BufferSlice( + parent="hbm", starts=(0, MLEN), extents=(MLEN, MLEN), + ) + op = _hlir.Op( + kind="dma_h2m_slice", + buffer_args=[sl, "mram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2m_slice_smoke", + buffers={"hbm": parent, "mram": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + _, l_pm, _ = _counts(legacy) + _, n_pm, _ = _counts(new) + assert l_pm == n_pm == 1 + + +def test_dma_v2h_slice_static_v2(): + """Static slice into HBM (single tile).""" + src = _vram_2d("vram", 0) + parent = _hbm_2d("hbm", 4096, 2 * MLEN, 2 * MLEN) + sl = _hlir.BufferSlice( + parent="hbm", starts=(MLEN, 0), extents=(MLEN, MLEN), + ) + op = _hlir.Op( + kind="dma_v2h_slice", + buffer_args=["vram", sl], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_v2h_slice_smoke", + buffers={"vram": src, "hbm": parent}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + _, _, l_sv = _counts(legacy) + _, _, n_sv = _counts(new) + assert l_sv == n_sv > 0 + + +def test_dma_h2v_slice_dynamic_start_v2(): + """Slice with a dynamic start (PrimExpr derived from an outer + for-loop var). v2 keeps the offset symbolic; converter folds the + static parts via arith.simplify. We can't byte-compare to legacy + here — legacy materialises an extra GP and counts differently. + Just check the v2 path produces a well-formed MIR with the + expected H_PREFETCH_V count. + """ + parent = _hbm_2d("hbm", 0, 4 * MLEN, MLEN) + dst = _vram_2d("vram", 4096) + i = tir.Var("i", "int32") + sl = _hlir.BufferSlice( + parent="hbm", + starts=(tir.Mul(i, tir.IntImm("int32", MLEN)), 0), + extents=(MLEN, MLEN), + ) + body_op = _hlir.Op( + kind="dma_h2v_slice", + buffer_args=[sl, "vram"], + scalar_args=[], + ) + for_op = _hlir.Op( + kind="for", + buffer_args=[], scalar_args=[], + annotations={ + "loop_var": i, "extent": 4, "init": 0, + "loop_kind": "unroll", + }, + body=[body_op], + ) + hlir = _hlir.HLIRModule( + name="dma_h2v_slice_dyn", + buffers={"hbm": parent, "vram": dst}, + ops=[for_op], param_names=[], + ) + new = _v2_emit(hlir) + n_pv, _, _ = _counts(new) + # 4 outer iters × per-tile H_PREFETCH_V count. + # For mlen=64, v_prefetch=1, single-tile grid: 64 per iter. + assert n_pv > 0 + # Sanity: outer 4 iters compound. + assert n_pv % 4 == 0 + + +def test_dma_h2v_slice_multi_tile_v2(): + """Multi-tile slice — dst is large enough to require a 2-tile + row dimension. Expect 2x single-tile H_PREFETCH_V count.""" + parent = _hbm_2d("hbm", 0, 2 * MLEN, 2 * MLEN) + dst = _vram_2d("vram", 4096, rows=2 * MLEN, cols=MLEN) + sl = _hlir.BufferSlice( + parent="hbm", + starts=(0, 0), + extents=(2 * MLEN, MLEN), + ) + op = _hlir.Op( + kind="dma_h2v_slice", + buffer_args=[sl, "vram"], + scalar_args=[], + ) + hlir = _hlir.HLIRModule( + name="dma_h2v_slice_multi", + buffers={"hbm": parent, "vram": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + l_pv, _, _ = _counts(legacy) + n_pv, _, _ = _counts(new) + assert l_pv == n_pv, ( + f"H_PREFETCH_V legacy={l_pv} v2={n_pv}\nv2:\n{new[:2000]}" + ) diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_for.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_for.py new file mode 100644 index 0000000..a70ef1d --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_for.py @@ -0,0 +1,177 @@ +"""End-to-end v2 tests for the HLIR ``for`` op — serial + unroll. + +Loops in PreIsaIR v2 are LoopRegions; the MIR conversion turns them +into MirLoops whose body is a MirBlock — the SSA scope of the loop +body. No GP pinning, no symbol-table state machine, no scope_floor. +Just nested blocks, the way LLVM IR (and TVM) models loops. +""" + +import re + +import pytest +from tvm import tir + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +_GP_RE = re.compile(r"\bgp\d+\b") + + +def _strip(isa: str): + """Comparable lines: non-S_ADDI, gpN canonicalised.""" + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head == "S_ADDI_INT": + continue + out.append(_GP_RE.sub("gpX", s)) + return out + + +def _v2_emit(hlir): + shim = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + return IsaEmitterPass(shim).run(hlir) + + +def _fpram(name, addr, shape=(4,)): + return _hlir.Buffer( + name=name, scope=_scope.FPRAM, + shape=shape, dtype="float16", address=addr, + ) + + +def test_for_unroll_with_fp_zero_at(): + """``for i in [0, 3) unroll: fp_zero_at dst[i]`` — body op uses + the loop var in its address. Unroll expands 3 body copies.""" + i = tir.Var("i", "int32") + buf = _fpram("dst", 128) + body_op = _hlir.Op( + kind="fp_zero_at", + buffer_args=[], + scalar_args=[_hlir.BufferElement(buffer="dst", indices=(i,))], + ) + for_op = _hlir.Op( + kind="for", + buffer_args=[], + scalar_args=[], + annotations={ + "loop_var": i, "extent": 3, "init": 0, + "loop_kind": "unroll", + }, + body=[body_op], + ) + hlir = _hlir.HLIRModule( + name="for_unroll_fp_zero", + buffers={"dst": buf}, + ops=[for_op], + param_names=[], + ) + + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + # Both paths should emit 3 S_ST_FP lines (one per unrolled iter). + assert _strip(legacy_isa) == _strip(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + # Sanity — 3 stores. + assert sum(1 for l in new_isa.split("\n") if "S_ST_FP" in l) == 3 + + +@pytest.mark.parametrize("ext", [3, 5]) +def test_for_unroll_extent_param(ext): + """Vary the unroll extent — verify v2 matches legacy at multiple + sizes.""" + i = tir.Var(f"i_{ext}", "int32") + buf = _fpram("dst", 256, shape=(ext + 1,)) + body_op = _hlir.Op( + kind="fp_zero_at", + buffer_args=[], + scalar_args=[_hlir.BufferElement(buffer="dst", indices=(i,))], + ) + for_op = _hlir.Op( + kind="for", + buffer_args=[], + scalar_args=[], + annotations={ + "loop_var": i, "extent": ext, "init": 0, + "loop_kind": "unroll", + }, + body=[body_op], + ) + hlir = _hlir.HLIRModule( + name="for_unroll_param", + buffers={"dst": buf}, + ops=[for_op], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _strip(legacy_isa) == _strip(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + + +def test_for_serial_with_fp_zero_at(): + """``for i in [0, 4) serial: fp_zero_at dst[i]`` — serial loop + emits HW C_LOOP_START/C_LOOP_END + idx slot. Body's S_LD_INT + on each iter reads loop_var from IntRAM.""" + i = tir.Var("i_serial", "int32") + buf = _fpram("dst", 512, shape=(8,)) + body_op = _hlir.Op( + kind="fp_zero_at", + buffer_args=[], + scalar_args=[_hlir.BufferElement(buffer="dst", indices=(i,))], + ) + for_op = _hlir.Op( + kind="for", + buffer_args=[], + scalar_args=[], + annotations={ + "loop_var": i, "extent": 4, "init": 0, + "loop_kind": "serial", + "loop_gp": 15, # legacy needs this annotation + }, + body=[body_op], + ) + hlir = _hlir.HLIRModule( + name="for_serial_fp_zero", + buffers={"dst": buf}, + ops=[for_op], + param_names=[], + ) + + # Legacy run: pre-pin gp15 so its allocator state matches. + shim_legacy = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + shim_legacy.compiler.register_allocator.pin_gp(15) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + shim_legacy.compiler.register_allocator.unpin_gp(15) + + new_isa = _v2_emit(hlir) + + # Structural: both should have exactly 1 C_LOOP_START and 1 + # C_LOOP_END, and the same number of S_ST_FP (body executes once + # per source-code body, the HW loop runs it 4 times — so 1 + # S_ST_FP in the static text). + def _count(isa, mnem): + return sum(1 for l in isa.split("\n") if l.strip().startswith(mnem)) + assert _count(legacy_isa, "C_LOOP_START") == _count(new_isa, "C_LOOP_START") == 1 + assert _count(legacy_isa, "C_LOOP_END") == _count(new_isa, "C_LOOP_END") == 1 + assert _count(legacy_isa, "S_ST_FP") == _count(new_isa, "S_ST_FP") == 1 diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_fp_scalar.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_fp_scalar.py new file mode 100644 index 0000000..512d431 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_fp_scalar.py @@ -0,0 +1,130 @@ +"""End-to-end v2 tests for the fp_*_at family +(copy/exp/reci/sqrt/add/sub/mul/max). + +Path: HLIR → PreIsaPassV2 → PreIsaIR v2 + → pre_isa_to_mir → MIR + → mir_to_isa → ISA text + +Structural compare against legacy: same non-S_ADDI mnemonics in +same order. S_ADDI_INT count may differ (legacy mixes addresses +with the body; v2's trivial allocator emits all setup S_ADDIs +up front via the SSA chain). +""" + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +import re + +_GP_RE = re.compile(r"\bgp\d+\b") + + +def _non_addi_mnemonic_skeleton(isa: str): + """Return the list of non-S_ADDI ISA lines with GP numbers + canonicalised away — every ``gpN`` becomes ``gpX``. Allows v2's + aggressive GP reuse to compare equal to legacy's separate GPs + when the algorithm is the same.""" + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head == "S_ADDI_INT": + continue + # Canonicalise gp numbers — both paths get the same skeleton. + out.append(_GP_RE.sub("gpX", s)) + return out + + +def _mk_fpram(name, addr): + return _hlir.Buffer( + name=name, scope=_scope.FPRAM, shape=(1,), dtype="float16", + address=addr, + ) + + +def _hlir_unary(kind: str) -> _hlir.HLIRModule: + return _hlir.HLIRModule( + name=kind, + buffers={ + "src": _mk_fpram("src", 64), + "dst": _mk_fpram("dst", 128), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[], + scalar_args=[ + _hlir.BufferElement(buffer="src", indices=(0,)), + _hlir.BufferElement(buffer="dst", indices=(0,)), + ], + )], + param_names=[], + ) + + +def _hlir_binary(kind: str) -> _hlir.HLIRModule: + return _hlir.HLIRModule( + name=kind, + buffers={ + "lhs": _mk_fpram("lhs", 32), + "rhs": _mk_fpram("rhs", 64), + "dst": _mk_fpram("dst", 128), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[], + scalar_args=[ + _hlir.BufferElement(buffer="lhs", indices=(0,)), + _hlir.BufferElement(buffer="rhs", indices=(0,)), + _hlir.BufferElement(buffer="dst", indices=(0,)), + ], + )], + param_names=[], + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + return IsaEmitterPass(shim).run(hlir) + + +@pytest.mark.parametrize("kind", [ + "fp_copy_at", "fp_exp_at", "fp_reci_at", "fp_sqrt_at", +]) +def test_fp_unary_v2_matches_legacy(kind): + hlir = _hlir_unary(kind) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _non_addi_mnemonic_skeleton(legacy_isa) == _non_addi_mnemonic_skeleton(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + + +@pytest.mark.parametrize("kind", [ + "fp_add_at", "fp_sub_at", "fp_mul_at", "fp_max_at", +]) +def test_fp_binary_v2_matches_legacy(kind): + hlir = _hlir_binary(kind) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _non_addi_mnemonic_skeleton(legacy_isa) == _non_addi_mnemonic_skeleton(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_fp_zero_at.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_fp_zero_at.py new file mode 100644 index 0000000..01ba72c --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_fp_zero_at.py @@ -0,0 +1,127 @@ +"""End-to-end test for the v2 pipeline (the clean rewrite). + +Path under test: + HLIR + ↓ PreIsaPassV2 + PreIsaIR v2 (clean — PrimExpr operands only) + ↓ pre_isa_to_mir + MIR (SSA values, def/use, loops) + ↓ mir_to_isa (trivial regalloc POC) + ISA text + +Compared against legacy: + HLIR → IsaEmitterPass → ISA text + +The check is *structural*, not byte-equal: legacy and v2 must emit +the same set of PLENA mnemonics in the same order, modulo +GP-renumbering and (for v2) a single extra S_ADDI_INT zero-load +that the legacy avoids (because legacy materialises addresses +lazily; v2 always synthesises an explicit ``%addr = S_ADDI_INT +gp0, IMM`` SSA def before the use). + +We use a tolerant equality: the set of HW mnemonics that the two +paths emit must be byte-equal, ignoring lines whose mnemonic is +``S_ADDI_INT`` (those are address-setup boilerplate that +allocator/peephole work changes the count of). +""" + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +def _build_fp_zero_at_hlir(): + """One ``fp_zero_at`` op against an FPRAM scalar buffer at addr=128.""" + buf = _hlir.Buffer( + name="dst_fp", + scope=_scope.FPRAM, + shape=(1,), + dtype="float16", + address=128, + ) + op = _hlir.Op( + kind="fp_zero_at", + buffer_args=[], + scalar_args=[_hlir.BufferElement(buffer="dst_fp", indices=(0,))], + ) + return _hlir.HLIRModule( + name="fpz", + buffers={"dst_fp": buf}, + ops=[op], + param_names=[], + ) + + +import re + +_GP_RE = re.compile(r"\bgp\d+\b") + + +def _non_addi_lines(isa: str): + """Non-S_ADDI ISA lines with gpN → gpX canonicalisation.""" + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head == "S_ADDI_INT": + continue + out.append(_GP_RE.sub("gpX", s)) + return out + + +def test_fp_zero_at_v2_pipeline_runs(): + """Smoke: the v2 pipeline (HLIR → PreIsaIR v2 → MIR → ISA) runs + end-to-end and produces non-empty ISA text.""" + hlir = _build_fp_zero_at_hlir() + shim = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + isa = m2i.emit(fn, shim) + assert "S_ST_FP" in isa + + +def test_fp_zero_at_v2_matches_legacy_hw_mnemonics(): + """The set of non-S_ADDI HW mnemonics in v2 == legacy.""" + hlir = _build_fp_zero_at_hlir() + + shim_legacy = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + legacy_isa = IsaEmitterPass(shim_legacy).run(hlir) + + shim_new = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim_new).run(hlir) + fn = p2m.convert(pre, shim_new) + mir.verify(fn) + new_isa = m2i.emit(fn, shim_new) + + legacy_lines = _non_addi_lines(legacy_isa) + new_lines = _non_addi_lines(new_isa) + assert legacy_lines == new_lines, ( + f"\nlegacy non-S_ADDI:\n " + "\n ".join(legacy_lines) + + f"\nv2 non-S_ADDI:\n " + "\n ".join(new_lines) + ) + + +def test_fp_zero_at_v2_mir_dump(): + """Sanity that the MIR dump has the expected structure: + Function constants: %0_gp0:i32 = + ^entry: + _COMMENT + %1 = S_ADDI_INT %0_gp0, 128 + S_ST_FP f0, %1, 0 + """ + hlir = _build_fp_zero_at_hlir() + shim = make_shim(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + text = mir.format_mir(fn) + # The address 128 should appear as the immediate of an S_ADDI_INT. + assert "S_ADDI_INT" in text and "128" in text, text + assert "S_ST_FP" in text and "f0" in text, text diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_matmul.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_matmul.py new file mode 100644 index 0000000..7dd50db --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_matmul.py @@ -0,0 +1,254 @@ +"""End-to-end v2 tests for the general ``matmul`` op. + +5-level nested unroll: (m, n_mlen, oc, orow, k). K folds into the +systolic-array accumulator → K_tiles M_MM/M_TMM followed by one +M_MM_WO per output BLEN×BLEN tile. transpose_b is inferred from the +B-region axis order (b_N_axis < b_K_axis). + +Structural comparison: same M_MM/M_TMM and M_MM_WO count + order. +Scalar address-arithmetic stripped (legacy reuses 7 GPs with per-iter +S_ADDI bumps; v2 builds fresh PrimExpr chains). +""" + +import re + +import pytest +from tvm import tir + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +BLEN = 4 +HLEN = 16 +_GP_RE = re.compile(r"\bgp\d+\b") +_ADDR_SETUP = frozenset({ + "S_ADDI_INT", "S_SLLI_INT", "S_SRLI_INT", "S_LUI_INT", + "S_ADD_INT", "S_SUB_INT", "S_MUL_INT", +}) + + +def _strip(isa: str): + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head in _ADDR_SETUP: + continue + out.append(_GP_RE.sub("gpX", s)) + return out + + +def _vram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=HLEN) + return IsaEmitterPass(shim).run(hlir) + + +def _count(isa, mnem): + return sum( + 1 for ln in isa.split("\n") if ln.strip().startswith(mnem + " ") + ) + + +def test_matmul_mlen_square_v2_matches_legacy(): + """(MLEN, MLEN) @ (MLEN, MLEN) → (MLEN, MLEN); B as (K, N) row-major. + + M_tiles=K_tiles=N_mlen_tiles=1; tiles_per_n_mlen=16; tiles_per_mlen=16. + Total M_MM pairs = 1*1*16*16*1 = 256. + """ + a = _vram("a", 0, (1, MLEN, 1, MLEN)) + b = _mram("b", 4096, (1, MLEN, 1, MLEN)) + c = _vram("c", 8192, (1, MLEN, 1, MLEN)) + op = _hlir.Op( + kind="matmul", + buffer_args=[ + _hlir.VramRegion(parent="a", starts=(0, 0, 0, 0), + extents=(1, MLEN, 1, MLEN)), + _hlir.MramRegion(parent="b", starts=(0, 0, 0, 0), + extents=(1, MLEN, 1, MLEN)), + _hlir.VramRegion(parent="c", starts=(0, 0, 0, 0), + extents=(1, MLEN, 1, MLEN)), + ], + # a axes (B, M, _, K); b axes (B, K, _, N); c axes (B, M, _, N). + scalar_args=[ + ("_", "M", "_", "K"), + ("_", "K", "_", "N"), + ("_", "M", "_", "N"), + ], + annotations={"intrinsic": "matmul_square"}, + ) + hlir = _hlir.HLIRModule( + name="matmul_square", + buffers={"a": a, "b": b, "c": c}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + # tiles_per_mlen^2 = 256 (M_tiles=K_tiles=N_mlen_tiles=1) + assert _count(legacy, "M_MM") == _count(new, "M_MM") == 256 + assert _count(legacy, "M_MM_WO") == _count(new, "M_MM_WO") == 256 + + +def test_matmul_2mlen_x_2mlen_v2_matches_legacy(): + """(2*MLEN, MLEN) @ (MLEN, 2*MLEN) → (2*MLEN, 2*MLEN). + + M_tiles=2, K_tiles=1, N_mlen_tiles=2 → 4 mlen-tile output blocks + × 16 oc × 16 orow × 1 k = 1024 M_MMs. + """ + M = 2 * MLEN + N = 2 * MLEN + K = MLEN + a = _vram("a", 0, (1, M, 1, K)) + b = _mram("b", 4096, (1, K, 1, N)) + c = _vram("c", 8192, (1, M, 1, N)) + op = _hlir.Op( + kind="matmul", + buffer_args=[ + _hlir.VramRegion(parent="a", starts=(0, 0, 0, 0), + extents=(1, M, 1, K)), + _hlir.MramRegion(parent="b", starts=(0, 0, 0, 0), + extents=(1, K, 1, N)), + _hlir.VramRegion(parent="c", starts=(0, 0, 0, 0), + extents=(1, M, 1, N)), + ], + scalar_args=[ + ("_", "M", "_", "K"), + ("_", "K", "_", "N"), + ("_", "M", "_", "N"), + ], + annotations={"intrinsic": "matmul_2m_x_2n"}, + ) + hlir = _hlir.HLIRModule( + name="matmul_2m_x_2n", + buffers={"a": a, "b": b, "c": c}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + expected = 2 * 2 * 16 * 16 * 1 + assert _count(legacy, "M_MM") == _count(new, "M_MM") == expected + assert _count(legacy, "M_MM_WO") == _count(new, "M_MM_WO") == expected + + +def test_matmul_k_accumulation_v2_matches_legacy(): + """K_tiles=2: each output BLEN×BLEN tile gets 2 M_MM issuances + feeding one M_MM_WO (K folds into systolic accumulator).""" + M = MLEN + K = 2 * MLEN + N = MLEN + a = _vram("a", 0, (1, M, 1, K)) + b = _mram("b", 4096, (1, K, 1, N)) + c = _vram("c", 8192, (1, M, 1, N)) + op = _hlir.Op( + kind="matmul", + buffer_args=[ + _hlir.VramRegion(parent="a", starts=(0, 0, 0, 0), + extents=(1, M, 1, K)), + _hlir.MramRegion(parent="b", starts=(0, 0, 0, 0), + extents=(1, K, 1, N)), + _hlir.VramRegion(parent="c", starts=(0, 0, 0, 0), + extents=(1, M, 1, N)), + ], + scalar_args=[ + ("_", "M", "_", "K"), + ("_", "K", "_", "N"), + ("_", "M", "_", "N"), + ], + annotations={"intrinsic": "matmul_K2"}, + ) + hlir = _hlir.HLIRModule( + name="matmul_K2", + buffers={"a": a, "b": b, "c": c}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + expected_mm = 1 * 1 * 16 * 16 * 2 # K_tiles=2 + expected_wo = 1 * 1 * 16 * 16 * 1 # one per (oc, orow) + assert _count(legacy, "M_MM") == _count(new, "M_MM") == expected_mm + assert _count(legacy, "M_MM_WO") == _count(new, "M_MM_WO") == expected_wo + + +def test_matmul_transpose_b_v2_matches_legacy(): + """B as (N, K) — transpose_b inferred from b_N_axis < b_K_axis. + Should emit M_TMM, not M_MM.""" + M = MLEN + K = MLEN + N = MLEN + a = _vram("a", 0, (1, M, 1, K)) + # B physical shape with N before K (axis 1 = N, axis 3 = K) + b = _mram("b", 4096, (1, N, 1, K)) + c = _vram("c", 8192, (1, M, 1, N)) + op = _hlir.Op( + kind="matmul", + buffer_args=[ + _hlir.VramRegion(parent="a", starts=(0, 0, 0, 0), + extents=(1, M, 1, K)), + _hlir.MramRegion(parent="b", starts=(0, 0, 0, 0), + extents=(1, N, 1, K)), + _hlir.VramRegion(parent="c", starts=(0, 0, 0, 0), + extents=(1, M, 1, N)), + ], + scalar_args=[ + ("_", "M", "_", "K"), + ("_", "N", "_", "K"), + ("_", "M", "_", "N"), + ], + annotations={"intrinsic": "matmul_transpose_b"}, + ) + hlir = _hlir.HLIRModule( + name="matmul_transpose_b", + buffers={"a": a, "b": b, "c": c}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + expected = 1 * 1 * 16 * 16 * 1 + assert _count(legacy, "M_TMM") == _count(new, "M_TMM") == expected + assert _count(legacy, "M_MM_WO") == _count(new, "M_MM_WO") == expected + # No M_MM, only M_TMM. + assert _count(new, "M_MM") == 0 diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_mm.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_mm.py new file mode 100644 index 0000000..e12f8a8 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_mm.py @@ -0,0 +1,149 @@ +"""End-to-end v2 tests for ``mm`` (single-tile mlen*mlen path). + +Structural comparison: same HW mnemonic stream (M_MM / M_MM_WO in +same order, same count). Address-setup S_ADDI_INT etc. skipped since +v2 builds addresses from SSA chains while legacy uses literals. + +The narrow-tile (mlen*hlen) and rectangular path is not covered here +— v2's ``_emit_mm`` currently rejects ``cols != mlen`` with a +"narrow-tile path not yet migrated" error. That path moves into the +general ``matmul`` family handler. +""" + +import re + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +BLEN = 4 +_GP_RE = re.compile(r"\bgp\d+\b") +# Both paths emit very different S_* sequences for the same address +# expression (legacy preserves S_ADDI/SLLI/SRLI/LUI chains; v2 rebuilds +# fresh PrimExpr → SSA chains). We compare only HW core ops. +_ADDR_SETUP = frozenset({ + "S_ADDI_INT", "S_SLLI_INT", "S_SRLI_INT", "S_LUI_INT", + # v2 builds the result_addr = dst.base + oc*blen + orow*row_stride + # as a multi-step PrimExpr chain. The final SSA "two non-const + # operands" reduction emits S_ADD_INT (legacy folds it into the + # materialiser preamble — no S_ADD_INT in legacy mm). Filter the + # whole scalar address-arithmetic family out of the comparison. + "S_ADD_INT", "S_SUB_INT", "S_MUL_INT", +}) + + +def _strip(isa: str): + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head in _ADDR_SETUP: + continue + out.append(_GP_RE.sub("gpX", s)) + return out + + +def _vram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=16) + return IsaEmitterPass(shim).run(hlir) + + +def _count(isa, mnem): + return sum( + 1 for ln in isa.split("\n") if ln.strip().startswith(mnem + " ") + ) + + +def test_mm_single_tile_v2_matches_legacy(): + """mlen*mlen @ mlen*mlen → mlen*mlen. + + tiles_per_mlen = mlen/blen = 16; outer oc loop x inner orow loop + each runs 16 iters → 256 M_MM pairs. + """ + lhs = _vram("lhs", 0, (MLEN, MLEN)) + rhs = _mram("rhs", 4096, (MLEN, MLEN)) + dst = _vram("dst", 8192, (MLEN, MLEN)) + + op = _hlir.Op( + kind="mm", + buffer_args=["lhs", "rhs", "dst"], + scalar_args=[], + annotations={"intrinsic": "mm_single_tile_test"}, + ) + hlir = _hlir.HLIRModule( + name="mm_single_tile", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + # Sanity — both emit 256 pairs. + tiles = MLEN // BLEN + expected = tiles * tiles + assert _count(legacy, "M_MM") == _count(new, "M_MM") == expected + assert _count(legacy, "M_MM_WO") == _count(new, "M_MM_WO") == expected + + +def test_mm_narrow_tile_not_yet_supported(): + """v2 _emit_mm explicitly rejects mlen*hlen narrow-tile path. + + Once the general ``matmul`` handler is migrated, narrow MM will + route there; for now we just assert v2 declines cleanly. + """ + from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2Error + HLEN = 16 + lhs = _vram("lhs", 0, (MLEN, MLEN)) + rhs = _mram("rhs", 4096, (MLEN, HLEN)) + dst = _vram("dst", 8192, (MLEN, HLEN)) + op = _hlir.Op( + kind="mm", + buffer_args=["lhs", "rhs", "dst"], + scalar_args=[], + annotations={"intrinsic": "mm_narrow_test"}, + ) + hlir = _hlir.HLIRModule( + name="mm_narrow_smoke", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + with pytest.raises(PreIsaPassV2Error, match="narrow-tile"): + _v2_emit(hlir) diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_mm_slot.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_mm_slot.py new file mode 100644 index 0000000..28e9835 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_mm_slot.py @@ -0,0 +1,222 @@ +"""End-to-end v2 tests for ``mm_slot``. + +mm_slot applies M_MM/M_MM_WO over a col-slot of rhs/dst with optional +dynamic LHS row, RHS col, DST col offsets. Static-offset case is +covered first; dynamic (PrimExpr) offsets are covered by a second +test that uses a tir.Var offset. + +Comparison is structural — same M_MM / M_MM_WO count + order. Scalar +address-arithmetic mnemonics (S_ADDI_INT / S_SLLI_INT / S_ADD_INT / +S_MUL_INT) are stripped: v2 builds each pair's addresses from fresh +SSA chains while legacy reuses a 5-GP allocation with per-iter +S_ADDI bumps. Same dynamic semantics, different static form. +""" + +import re + +import pytest +from tvm import tir + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +BLEN = 4 +_GP_RE = re.compile(r"\bgp\d+\b") +_ADDR_SETUP = frozenset({ + "S_ADDI_INT", "S_SLLI_INT", "S_SRLI_INT", "S_LUI_INT", + "S_ADD_INT", "S_SUB_INT", "S_MUL_INT", +}) + + +def _strip(isa: str): + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head in _ADDR_SETUP: + continue + out.append(_GP_RE.sub("gpX", s)) + return out + + +def _vram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, shape=shape, + dtype="float16", address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _mram(name, addr, shape): + return _hlir.Buffer( + name=name, scope=_scope.MRAM, shape=shape, + dtype="float16", address=addr, + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=MLEN, blen=BLEN, btmm_lane_count=4, btmm_hlen=16) + return IsaEmitterPass(shim).run(hlir) + + +def _count(isa, mnem): + return sum( + 1 for ln in isa.split("\n") if ln.strip().startswith(mnem + " ") + ) + + +def test_mm_slot_static_offsets_v2_matches_legacy(): + """All 4 scalar args literal: lhs_row=0, rhs_col=0, dst_col=0, + col_count=blen (= 1 oc tile). + + tiles_per_slot=1, tiles_per_mlen=16 → 16 M_MM pairs. + """ + lhs = _vram("lhs", 0, (MLEN, MLEN)) + rhs = _mram("rhs", 4096, (MLEN, MLEN)) + dst = _vram("dst", 8192, (MLEN, MLEN)) + + op = _hlir.Op( + kind="mm_slot", + buffer_args=["lhs", "rhs", "dst"], + scalar_args=[0, 0, 0, BLEN], + annotations={"intrinsic": "mm_slot_static"}, + ) + hlir = _hlir.HLIRModule( + name="mm_slot_static", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + expected = (BLEN // BLEN) * (MLEN // BLEN) # 1 * 16 + assert _count(legacy, "M_MM") == _count(new, "M_MM") == expected + assert _count(legacy, "M_MM_WO") == _count(new, "M_MM_WO") == expected + + +def test_mm_slot_wide_col_count_v2_matches_legacy(): + """col_count = 4*blen (= 4 oc tiles) → 4*16 = 64 M_MM pairs.""" + cc = 4 * BLEN + lhs = _vram("lhs", 0, (MLEN, MLEN)) + rhs = _mram("rhs", 4096, (MLEN, MLEN)) + dst = _vram("dst", 8192, (MLEN, MLEN)) + op = _hlir.Op( + kind="mm_slot", + buffer_args=["lhs", "rhs", "dst"], + scalar_args=[0, 0, 0, cc], + annotations={"intrinsic": "mm_slot_wide"}, + ) + hlir = _hlir.HLIRModule( + name="mm_slot_wide", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + expected = (cc // BLEN) * (MLEN // BLEN) + assert _count(legacy, "M_MM") == _count(new, "M_MM") == expected + assert _count(legacy, "M_MM_WO") == _count(new, "M_MM_WO") == expected + + +def test_mm_slot_static_nonzero_offsets_v2_matches_legacy(): + """Static lhs_row, rhs_col, dst_col offsets all non-zero. + + Sanity: legacy folds the literals into S_ADDI immediates while v2 + builds (base + offset) PrimExprs that arith.simplify reduces to + literals — final HW ops still match structurally. + """ + cc = 2 * BLEN + # Pick offsets that keep us in bounds. + lhs_row_off = MLEN * MLEN # second mlen*mlen tile of a 2-tile lhs + rhs_col_off = 2 * BLEN + dst_col_off = 3 * BLEN + # Resize lhs so lhs_row_off is in range. + lhs = _vram("lhs", 0, (2 * MLEN, MLEN)) + rhs = _mram("rhs", 4096, (MLEN, MLEN)) + dst = _vram("dst", 8192, (MLEN, MLEN)) + op = _hlir.Op( + kind="mm_slot", + buffer_args=["lhs", "rhs", "dst"], + scalar_args=[lhs_row_off, rhs_col_off, dst_col_off, cc], + annotations={"intrinsic": "mm_slot_static_off"}, + ) + hlir = _hlir.HLIRModule( + name="mm_slot_static_off", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[op], param_names=[], + ) + legacy = _legacy_emit(hlir) + new = _v2_emit(hlir) + assert _strip(legacy) == _strip(new), ( + f"\nlegacy:\n{legacy}\nv2:\n{new}" + ) + + +def test_mm_slot_dynamic_lhs_row_offset_v2(): + """lhs_row_offset is a PrimExpr (= h * mlen * mlen). + + Legacy materialises it once to a GP and bases each oc tile off + that GP. v2 keeps it as a symbolic PrimExpr per pair — same M_MM + pair count, different static form (different scalar S_* prefix). + + We don't compare to legacy here (mnemonic counts differ at the + S_* level since legacy hoists, v2 inlines per iter). Instead we + verify the M_MM pair count is correct and the v2 path passes MIR + verify(). + """ + cc = 2 * BLEN + h = tir.Var("h", "int32") + lhs_off_expr = tir.Mul(h, tir.IntImm("int32", MLEN * MLEN)) + lhs = _vram("lhs", 0, (4 * MLEN, MLEN)) # roomy enough + rhs = _mram("rhs", 4096, (MLEN, MLEN)) + dst = _vram("dst", 8192, (MLEN, MLEN)) + op = _hlir.Op( + kind="mm_slot", + buffer_args=["lhs", "rhs", "dst"], + scalar_args=[lhs_off_expr, 0, 0, cc], + annotations={"intrinsic": "mm_slot_dynamic_lhs"}, + ) + # Wrap in a for-h loop so h is in scope as a loop_var. + for_op = _hlir.Op( + kind="for", + buffer_args=[], + scalar_args=[], + annotations={ + "loop_var": h, "extent": 2, "init": 0, + "loop_kind": "unroll", + }, + body=[op], + ) + hlir = _hlir.HLIRModule( + name="mm_slot_dyn", + buffers={"lhs": lhs, "rhs": rhs, "dst": dst}, + ops=[for_op], param_names=[], + ) + new = _v2_emit(hlir) + # 2 (h iters) * (cc/blen) * tiles_per_mlen pairs. + expected = 2 * (cc // BLEN) * (MLEN // BLEN) + assert _count(new, "M_MM") == expected + assert _count(new, "M_MM_WO") == expected diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_row.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_row.py new file mode 100644 index 0000000..89411ca --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_row.py @@ -0,0 +1,317 @@ +"""End-to-end v2 tests for the row_*_at family. + +Six ops: + * row_reduce_max_at / row_reduce_sum_at — reduce VRAM row → FPRAM + * row_exp — unary on VRAM row + * row_add_fp / row_sub_fp / row_mul_fp — binary with FP scalar + +Single-tile-layout setup (d_tiles=1, no packed-head mask) keeps the +test focused on the basic row-op structure. Multi-d_tile variants +exercise the unroll loop in the v2 path. +""" + +import re + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +_GP_RE = re.compile(r"\bgp\d+\b") + +# Address-setup mnemonics that legacy and v2 emit in different +# quantities (legacy uses destructive in-place stride bump; v2 builds +# each iter's address from scratch as a fresh SSA chain). Skipped +# during comparison. +_ADDR_SETUP = frozenset({"S_ADDI_INT", "S_SLLI_INT", "S_SRLI_INT", "S_LUI_INT"}) + + +def _strip(isa: str): + """Strip lines + canonicalise gp numbers. Skip address-setup + instructions so the comparison focuses on HW ops + their + invocation count/order.""" + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head in _ADDR_SETUP: + continue + out.append(_GP_RE.sub("gpX", s)) + return out + + +def _vram(name, addr, d_extent=MLEN): + """1×MLEN×1×d_extent VRAM buffer with a tile_layout that matches + legacy's single-tile case (d_tiles = ceil(d_extent / MLEN)).""" + d_tiles = (d_extent + MLEN - 1) // MLEN + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(1, MLEN, 1, d_extent), dtype="float16", + address=addr, + tile_layout=_hlir.TileLayout( + logical_b=1, logical_s=MLEN, logical_h=1, logical_d=d_extent, + d_tiles=d_tiles, s_tiles=1, h_groups=1, + mlen=MLEN, lane_count=1, d_inner=MLEN, + ), + cluster_dim=None, + ) + + +def _fpram(name, addr): + return _hlir.Buffer( + name=name, scope=_scope.FPRAM, shape=(1,), + dtype="float16", address=addr, + ) + + +def _row(name, row=0, d_extent=MLEN): + return _hlir.VramRegion( + parent=name, starts=(0, row, 0, 0), extents=(1, 1, 1, d_extent), + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=16) + return IsaEmitterPass(shim).run(hlir) + + +@pytest.mark.parametrize("kind", ["row_reduce_max_at", "row_reduce_sum_at"]) +def test_row_reduce_v2_matches_legacy(kind): + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "src": _vram("src", 0), + "fp": _fpram("fp", 1024), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_row("src", row=2)], + scalar_args=[ + _hlir.BufferElement(buffer="fp", indices=(0,)), + ], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _strip(legacy_isa) == _strip(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + + +def test_row_exp_v2_matches_legacy(): + hlir = _hlir.HLIRModule( + name="row_exp", + buffers={ + "src": _vram("src", 0), + "dst": _vram("dst", MLEN * MLEN), + }, + ops=[_hlir.Op( + kind="row_exp", + buffer_args=[_row("src", row=1), _row("dst", row=1)], + scalar_args=[], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _strip(legacy_isa) == _strip(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + + +@pytest.mark.parametrize("kind", ["row_add_fp", "row_sub_fp", "row_mul_fp"]) +def test_row_binary_fp_v2_matches_legacy(kind): + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "src": _vram("src", 0), + "dst": _vram("dst", MLEN * MLEN), + "fp": _fpram("fp", 2048), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_row("src", row=3), _row("dst", row=3)], + scalar_args=[ + _hlir.BufferElement(buffer="fp", indices=(0,)), + ], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _strip(legacy_isa) == _strip(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + + +def test_row_exp_multi_d_tile_v2(): + """d_tiles=2 — exercises the unroll loop.""" + hlir = _hlir.HLIRModule( + name="row_exp_wide", + buffers={ + "src": _vram("src", 0, d_extent=2 * MLEN), + "dst": _vram("dst", 4096, d_extent=2 * MLEN), + }, + ops=[_hlir.Op( + kind="row_exp", + buffer_args=[ + _row("src", row=1, d_extent=2 * MLEN), + _row("dst", row=1, d_extent=2 * MLEN), + ], + scalar_args=[], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _strip(legacy_isa) == _strip(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + + +# ----------------------------------------------------------------- +# Packed-head (col_pack) masked variants +# ----------------------------------------------------------------- +# Layout: shape=(1, MLEN, LANE_COUNT, D_INNER), cluster_dim=2, with +# LANE_COUNT > 1, D_INNER = MLEN // LANE_COUNT. Picking a non-zero lane +# index (= region.starts[2]) forces _logical_to_phys_row_offset to +# emit a real mask_expr (1 << (lane % lane_count)), so the row scalar +# handler must bracket the body with C_SET_V_MASK_REG / 0. + +LANE_COUNT_PACKED = 4 +D_INNER_PACKED = MLEN // LANE_COUNT_PACKED # 16 when MLEN=64 + + +def _vram_packed(name, addr): + """col_pack buffer: 1×MLEN×LANE_COUNT×D_INNER, cluster on H.""" + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(1, MLEN, LANE_COUNT_PACKED, D_INNER_PACKED), + dtype="float16", address=addr, + tile_layout=_hlir.TileLayout( + logical_b=1, logical_s=MLEN, + logical_h=LANE_COUNT_PACKED, logical_d=D_INNER_PACKED, + d_tiles=1, s_tiles=1, h_groups=1, + mlen=MLEN, lane_count=LANE_COUNT_PACKED, + d_inner=D_INNER_PACKED, + ), + cluster_dim=2, + ) + + +def _row_packed(name, row, lane): + """One logical row at (b=0, s=row, h=lane, d=0..D_INNER).""" + return _hlir.VramRegion( + parent=name, + starts=(0, row, lane, 0), + extents=(1, 1, 1, D_INNER_PACKED), + ) + + +def _count_mask_set(isa): + return sum( + 1 for l in isa.split("\n") + if l.strip().startswith("C_SET_V_MASK_REG") + ) + + +@pytest.mark.parametrize("kind", ["row_reduce_max_at", "row_reduce_sum_at"]) +def test_row_reduce_masked_v2_matches_legacy(kind): + """Reduce on a col_pack source with lane=1 → packed-head mask.""" + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "src": _vram_packed("src", 0), + "fp": _fpram("fp", 1024), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_row_packed("src", row=2, lane=1)], + scalar_args=[ + _hlir.BufferElement(buffer="fp", indices=(0,)), + ], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _strip(legacy_isa) == _strip(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + # Sanity: both bracket the body with set/reset. + assert _count_mask_set(legacy_isa) == _count_mask_set(new_isa) == 2 + + +def test_row_exp_masked_v2_matches_legacy(): + """row_exp with packed-head src+dst, lane=3.""" + hlir = _hlir.HLIRModule( + name="row_exp_masked", + buffers={ + "src": _vram_packed("src", 0), + "dst": _vram_packed("dst", 4096), + }, + ops=[_hlir.Op( + kind="row_exp", + buffer_args=[ + _row_packed("src", row=1, lane=3), + _row_packed("dst", row=1, lane=3), + ], + scalar_args=[], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _strip(legacy_isa) == _strip(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + assert _count_mask_set(legacy_isa) == _count_mask_set(new_isa) == 2 + + +@pytest.mark.parametrize("kind", ["row_add_fp", "row_sub_fp", "row_mul_fp"]) +def test_row_binary_fp_masked_v2_matches_legacy(kind): + """V_*_VF with col_pack + lane mask.""" + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "src": _vram_packed("src", 0), + "dst": _vram_packed("dst", 4096), + "fp": _fpram("fp", 2048), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[ + _row_packed("src", row=3, lane=2), + _row_packed("dst", row=3, lane=2), + ], + scalar_args=[ + _hlir.BufferElement(buffer="fp", indices=(0,)), + ], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _strip(legacy_isa) == _strip(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + assert _count_mask_set(legacy_isa) == _count_mask_set(new_isa) == 2 diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_transfer.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_transfer.py new file mode 100644 index 0000000..8d98f39 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_transfer.py @@ -0,0 +1,123 @@ +"""End-to-end v2 tests for transfer ops: +``copy_v_to_v`` / ``v_fp_transfer_slice_v_to_fp`` / +``v_fp_transfer_slice_fp_to_v``. +""" + +import re + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 +_GP_RE = re.compile(r"\bgp\d+\b") + + +def _non_addi_lines(isa: str): + """Non-S_ADDI ISA lines with gpN → gpX canonicalisation.""" + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head == "S_ADDI_INT": + continue + out.append(_GP_RE.sub("gpX", s)) + return out + + +def _vram(name, addr, shape=None): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=shape or (1, MLEN, 1, MLEN), dtype="float16", + address=addr, cluster_dim=None, tile_layout=None, + ) + + +def _fpram(name, addr, shape=(MLEN,)): + return _hlir.Buffer( + name=name, scope=_scope.FPRAM, + shape=shape, dtype="float16", address=addr, + ) + + +def _whole_vram(name): + return _hlir.VramRegion( + parent=name, starts=(0, 0, 0, 0), extents=(1, MLEN, 1, MLEN), + ) + + +def _row_vram(name, row): + return _hlir.VramRegion( + parent=name, starts=(0, row, 0, 0), extents=(1, 1, 1, MLEN), + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=16) + return IsaEmitterPass(shim).run(hlir) + + +def test_copy_v_to_v_v2_matches_legacy(): + hlir = _hlir.HLIRModule( + name="copy_v_to_v", + buffers={ + "src": _vram("src", 0), + "dst": _vram("dst", MLEN * MLEN), + }, + ops=[_hlir.Op( + kind="copy_v_to_v", + buffer_args=[_whole_vram("src"), _whole_vram("dst")], + scalar_args=[], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _non_addi_lines(legacy_isa) == _non_addi_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + + +@pytest.mark.parametrize("kind", [ + "v_fp_transfer_slice_v_to_fp", + "v_fp_transfer_slice_fp_to_v", +]) +def test_v_fp_transfer_slice_v2_matches_legacy(kind): + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "v": _vram("v", 0), + "fp": _fpram("fp", 8192), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_row_vram("v", row=0)], + scalar_args=[ + _hlir.BufferElement(buffer="fp", indices=(0,)), + ], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _non_addi_lines(legacy_isa) == _non_addi_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) diff --git a/tilelang_tvm_compiler/tests/test_v2_end_to_end_vector.py b/tilelang_tvm_compiler/tests/test_v2_end_to_end_vector.py new file mode 100644 index 0000000..8c4eaa4 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_end_to_end_vector.py @@ -0,0 +1,129 @@ +"""End-to-end v2 tests for the vector ops +(v_zero / v_add / v_sub / v_mul / v_exp / v_reci / v_sqrt). +""" + +import pytest + +from tilelang_tvm_compiler import hlir as _hlir +from tilelang_tvm_compiler import scope as _scope +from tilelang_tvm_compiler import mir +from tilelang_tvm_compiler import pre_isa_to_mir as p2m +from tilelang_tvm_compiler import mir_to_isa as m2i +from tilelang_tvm_compiler.isa_pass import IsaEmitterPass +from tilelang_tvm_compiler.pre_isa_pass_v2 import PreIsaPassV2 +from tilelang_tvm_compiler.program_shim import make_shim + + +MLEN = 64 + + +import re + +_GP_RE = re.compile(r"\bgp\d+\b") + + +def _non_addi_lines(isa: str): + """Non-S_ADDI ISA lines with gpN → gpX canonicalisation, so + v2's aggressive GP reuse compares equal to legacy's separate + GP assignments.""" + out = [] + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head == "S_ADDI_INT": + continue + out.append(_GP_RE.sub("gpX", s)) + return out + + +def _vram(name, addr): + return _hlir.Buffer( + name=name, scope=_scope.VRAM, + shape=(1, MLEN, 1, MLEN), dtype="float16", + address=addr, + cluster_dim=None, tile_layout=None, + ) + + +def _whole(name): + return _hlir.VramRegion( + parent=name, starts=(0, 0, 0, 0), extents=(1, MLEN, 1, MLEN), + ) + + +def _v2_emit(hlir): + shim = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=16) + pre = PreIsaPassV2(shim).run(hlir) + fn = p2m.convert(pre, shim) + mir.verify(fn) + return m2i.emit(fn, shim) + + +def _legacy_emit(hlir): + shim = make_shim(mlen=MLEN, blen=4, btmm_lane_count=4, btmm_hlen=16) + return IsaEmitterPass(shim).run(hlir) + + +def test_v_zero_v2_matches_legacy(): + hlir = _hlir.HLIRModule( + name="v_zero", + buffers={"dst": _vram("dst", 0)}, + ops=[_hlir.Op( + kind="v_zero", + buffer_args=[_whole("dst")], + scalar_args=[], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _non_addi_lines(legacy_isa) == _non_addi_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + + +@pytest.mark.parametrize("kind", ["v_add", "v_sub", "v_mul"]) +def test_v_binary_v2_matches_legacy(kind): + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "lhs": _vram("lhs", 0), + "rhs": _vram("rhs", MLEN * MLEN), + "dst": _vram("dst", 2 * MLEN * MLEN), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_whole("lhs"), _whole("rhs"), _whole("dst")], + scalar_args=[], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _non_addi_lines(legacy_isa) == _non_addi_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) + + +@pytest.mark.parametrize("kind", ["v_exp", "v_reci", "v_sqrt"]) +def test_v_unary_v2_matches_legacy(kind): + hlir = _hlir.HLIRModule( + name=kind, + buffers={ + "src": _vram("src", 0), + "dst": _vram("dst", MLEN * MLEN), + }, + ops=[_hlir.Op( + kind=kind, + buffer_args=[_whole("src"), _whole("dst")], + scalar_args=[], + )], + param_names=[], + ) + legacy_isa = _legacy_emit(hlir) + new_isa = _v2_emit(hlir) + assert _non_addi_lines(legacy_isa) == _non_addi_lines(new_isa), ( + f"\nlegacy:\n{legacy_isa}\nv2:\n{new_isa}" + ) diff --git a/tilelang_tvm_compiler/tests/test_v2_flash_attention_min.py b/tilelang_tvm_compiler/tests/test_v2_flash_attention_min.py new file mode 100644 index 0000000..b562935 --- /dev/null +++ b/tilelang_tvm_compiler/tests/test_v2_flash_attention_min.py @@ -0,0 +1,114 @@ +"""End-to-end smoke: compile ``flash_attention_min`` through both +legacy and v2 paths, compare HW op stream structurally. + +This is the integration test for the full v2 pipeline (HLIR → +PreIsaIR v2 → MIR → ISA) on a real, non-toy kernel: multi-block +flash attention with online softmax, head fusion, DMA, BTMM, per- +head matmul, row-vector softmax math. If the v2 path produces the +same set + count of HW mnemonics as legacy, the migration is +materially complete for this kernel. +""" + +import re + +import pytest + +import tilelang.language as T +from tvm import tir + +from tilelang_tvm_compiler.kernels.flash_attention_min import ( + make_flash_attention_min, +) +from tilelang_tvm_compiler.pipeline import compile_kernel +from tilelang_tvm_compiler.target import PlenaTarget + + +_HW_OPCODES = frozenset({ + # matmul family + "M_MM", "M_MM_WO", "M_TMM", + "M_BTMM", "M_BMM_WO", "M_BTMV", "M_BMV_WO", + "M_MV", "M_MV_WO", + # vector + "V_ADD_VV", "V_SUB_VV", "V_MUL_VV", + "V_ADD_VF", "V_SUB_VF", "V_MUL_VF", + "V_EXP_V", "V_RECI_V", "V_SQRT_V", + "V_RED_MAX", "V_RED_SUM", + # FP scalar + "S_LD_FP", "S_ST_FP", + "S_ADD_FP", "S_SUB_FP", "S_MUL_FP", "S_MAX_FP", + "S_EXP_FP", "S_RECI_FP", "S_SQRT_FP", + # HBM + "H_PREFETCH_V", "H_PREFETCH_M", "H_STORE_V", "H_LOAD_V", + # control + "C_LOOP_START", "C_LOOP_END", + "C_SET_V_MASK_REG", "C_SET_ADDR_REG", + "C_SET_SCALE_REG", "C_SET_STRIDE_REG", +}) + + +def _hw_op_counts(isa: str): + """Return {mnemonic: count} for every HW opcode appearing in + ``isa``. Ignores S_ADDI/S_SLLI/S_SRLI/S_LUI/S_ADD/S_MUL_INT + (scalar address-arithmetic — legacy and v2 build addresses + differently).""" + counts = {} + for ln in isa.split("\n"): + s = ln.strip() + if not s or s.startswith(";"): + continue + head = s.split(None, 1)[0] + if head in _HW_OPCODES: + counts[head] = counts.get(head, 0) + 1 + return counts + + +def _build_kernel(): + return make_flash_attention_min( + rows=64, hlen=16, head_count=4, lane_count=4, + num_q_blocks=2, num_kv_blocks=1, + ) + + +def _target(): + return PlenaTarget(mlen=64, blen=4, btmm_lane_count=4, btmm_hlen=16) + + +@pytest.mark.skip(reason="enabled per-run when investigating v2 coverage") +def test_flash_attention_min_v2_structural_equal(): + """Compile flash_attention_min via legacy + v2; compare HW op + histograms. Skipped by default — flip @pytest.mark.skip off to + run; not in the regression set yet because the kernel pulls in + the full mid_ir pipeline + needs the tilelang frontend, both + heavy.""" + prim = _build_kernel() + target = _target() + + legacy = compile_kernel(prim, target=target, name="fa_min") + v2 = compile_kernel(prim, target=target, name="fa_min", use_v2=True) + + l_counts = _hw_op_counts(legacy.isa_text) + v_counts = _hw_op_counts(v2.isa_text) + assert l_counts == v_counts, ( + f"HW op histograms differ.\n" + f"only-in-legacy: {set(l_counts) - set(v_counts)}\n" + f"only-in-v2: {set(v_counts) - set(l_counts)}\n" + f"counts diff: " + + ", ".join( + f"{op}: legacy={l_counts.get(op,0)} v2={v_counts.get(op,0)}" + for op in sorted(set(l_counts) | set(v_counts)) + if l_counts.get(op, 0) != v_counts.get(op, 0) + ) + ) + + +def test_flash_attention_min_v2_compiles(): + """At minimum the v2 path must run to completion and produce a + non-empty ISA text. HW op histogram comparison gated separately + above; this is the keep-it-green sanity check.""" + prim = _build_kernel() + target = _target() + v2 = compile_kernel(prim, target=target, name="fa_min", use_v2=True) + assert v2.isa_text + # Must contain at least one M_BTMM (the Q@K^T head-fused matmul) + # and at least one M_MM (the per-head P@V matmul). + assert "M_BTMM" in v2.isa_text, v2.isa_text[:500]