Summary
The frontend's tile-selection SPAD budget and the actual SPAD footprint of the emitted
kernel can diverge, in two distinct ways. Both lead to a kernel whose real .spad
section is larger than what tile selection assumed:
- Fused epilogue: tile-selection budget is naive about the epilogue output buffer.
- Broadcast operands: the SPAD buffer is over-allocated to the full output-tile shape.
This was surfaced when the SPAD-overflow guard budget was tightened to spad/2
(test_transformer_fusion now raises SpadOverflowError), but both issues are
pre-existing and independent of that change.
1. Fused epilogue buffer estimate != actual SPAD
What happens
MLIRGemmTemplate.select_tile feeds the epilogue count to the SPAD budget as
max(n_extra_read - 2, 0) (PyTorchSimFrontend/mlir/mlir_gemm_template.py:332):
tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), ...)
n_extra_read counts the extra input tensors an epilogue reads (the GEMM output itself
is excluded). For a pure elementwise epilogue such as ReLU, n_extra_read == 0, so the
4th arg (n_extra_node) becomes 0. The -2 encodes the assumption that the input
(X) and weight (W) DMA buffers are already dead by the epilogue, so their SPAD can be
reused by up to two epilogue buffers, which are therefore left out of the budget.
But the codegen still emits the epilogue output as a separate, non-overlapping
memref.global that never reuses the freed X/W region. So the budget and the real
binary disagree.
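As a minimal sketch of the clamping above (the helper name n_extra_node_for is hypothetical; only the max(n_extra_read - 2, 0) expression comes from the source), this shows how many epilogue buffers the budget ends up counting for small values of n_extra_read:

```python
def n_extra_node_for(n_extra_read: int) -> int:
    """Hypothetical helper mirroring the 4th argument built in select_tile."""
    # Up to two epilogue buffers are assumed to reuse the dead X/W space,
    # so they contribute nothing to the tile-selection budget.
    return max(n_extra_read - 2, 0)


for n_extra_read in range(5):
    print(f"n_extra_read={n_extra_read} -> n_extra_node={n_extra_node_for(n_extra_read)}")
```

For n_extra_read of 0, 1, or 2 the budget counts zero extra buffers, which is exactly the ReLU case described here: the epilogue output exists in the emitted kernel but not in the budget.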
Concrete numbers (test_transformer_fusion, ffn1 = addmm + ReLU)
Kernel: origins={'relu','addmm_3'}, M=512 N=3072 K=768, chosen tile
TILE_M=256 TILE_N=3072 TILE_K=384. Per-lane .spad globals in the emitted kernel:
| buffer | per-lane floats | bytes |
|---|---|---|
| X_spad | 768 | 3,072 |
| W_spad | 9,216 | 36,864 |
| Y_spad | 6,144 | 24,576 |
| buf3_spad (ReLU epilogue out) | 6,144 | 24,576 |
| total | 22,272 | 89,088 |
- Tile-selection budget (with n_extra_node = 0, i.e. bare GEMM):
  X 768 + W 9,216 + Y 6,144 = 16,128 floats = 64,512 B/lane -> fits spad/2 (65,536)
- Actual emitted .spad (incl. buf3): 89,088 B/lane -> exceeds spad/2, fits full spad (131,072)
Under the old full-spad guard this was harmless (89,088 < 131,072); the spad/2 guard
exposes it.
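The arithmetic above can be re-checked with a few lines of Python. This is only a sketch: the dictionary and constant names are made up for illustration, the sizes are copied from the table, and the 4 bytes per float is the ratio implied by the table's floats and bytes columns.

```python
BYTES_PER_FLOAT = 4                 # implied by the table (768 floats -> 3,072 B)
SPAD_BYTES_PER_LANE = 131_072       # full per-lane SPAD quoted above
GUARD_BYTES = SPAD_BYTES_PER_LANE // 2   # the tightened spad/2 guard

# Per-lane float counts of the .spad globals in the emitted kernel.
per_lane_floats = {"X_spad": 768, "W_spad": 9_216, "Y_spad": 6_144, "buf3_spad": 6_144}

# What tile selection budgets for: bare GEMM, no epilogue output buffer.
budget_bytes = BYTES_PER_FLOAT * sum(
    per_lane_floats[name] for name in ("X_spad", "W_spad", "Y_spad")
)
# What the emitted kernel statically allocates: all four globals.
emitted_bytes = BYTES_PER_FLOAT * sum(per_lane_floats.values())

print(f"budget : {budget_bytes:>7,} B/lane, fits spad/2: {budget_bytes <= GUARD_BYTES}")
print(f"emitted: {emitted_bytes:>7,} B/lane, fits spad/2: {emitted_bytes <= GUARD_BYTES}")
print(f"emitted fits full spad: {emitted_bytes <= SPAD_BYTES_PER_LANE}")
```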
Note: the real peak-live footprint does fit spad/2
The -2 liveness assumption is correct in principle. Buffer liveness:
matmul (k-loop): X(768) + W(9216) + Y(6144) = 16,128 <- peak
epilogue: X,W dead -> Y(6144) + buf3(6144) = 12,288
Peak simultaneously-live = 16,128 floats = 64,512 B < spad/2. buf3 (6,144) fits in
the dead X+W region (9,984). So the overflow is an allocation artifact: the four
buffers are laid out as disjoint static globals that never reuse freed space. The
liveness is static (X/W die at a fixed compile-time point, buf3 is born at the
epilogue) -- this is a compile-time SPAD offset-assignment problem, not a runtime
dynamic-free problem.
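To make the peak-live versus static-total distinction concrete, here is a small sketch that models the two phases described above. The phase names and data layout are assumptions for illustration, not the frontend's actual representation.

```python
# Per-lane float counts, taken from the table above.
sizes = {"X": 768, "W": 9_216, "Y": 6_144, "buf3": 6_144}

# Which buffers are live in each phase, per the liveness description above.
live_in_phase = {
    "matmul": ["X", "W", "Y"],
    "epilogue": ["Y", "buf3"],
}

live_floats = {
    phase: sum(sizes[name] for name in names)
    for phase, names in live_in_phase.items()
}
peak_live = max(live_floats.values())   # what a reuse-aware layout needs
static_total = sum(sizes.values())      # what disjoint globals need

print(live_floats)
print("peak-live floats   :", peak_live)
print("static total floats:", static_total)
print("dead X+W region    :", sizes["X"] + sizes["W"], ">= buf3", sizes["buf3"])
```

The gap between static_total and peak_live is exactly the size of buf3, the one buffer that could have been placed in already-dead space.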
Possible directions
- Make SPAD allocation liveness-aware (offset reuse / arena) so buf3 overlaps the dead
  X/W region -> static .spad == peak-live, budget and guard agree, no spad wasted,
  large tile (TILE_N=3072) kept.
- Cheap partial fix for elementwise epilogues: write the epilogue in-place into
  Y_buffer and drop buf3 entirely.
- Fallback: make the select_tile budget honestly include the epilogue output buffer
  (so tile selection picks a smaller tile) -- correct but costs reuse/perf (smaller
  tiles -> weight reload, more DMA).
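The first direction can be illustrated with a toy first-fit offset assigner. This is a sketch under stated assumptions, not the codegen's allocator: the Buf class, the assign_offsets function, and the integer phase numbering (0 = matmul, 1 = epilogue) are all hypothetical, and real SPAD allocation would also need to respect alignment and any double-buffering.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Buf:
    name: str
    size: int    # per-lane floats
    first: int   # first phase in which the buffer is live
    last: int    # last phase in which the buffer is live


def assign_offsets(bufs):
    """Greedy first-fit: buffers may share addresses if lifetimes are disjoint."""
    placed = {}  # name -> (offset, Buf)
    for b in bufs:
        # Only buffers whose lifetimes overlap b's constrain where b can go.
        busy = sorted(
            (off, off + other.size)
            for off, other in placed.values()
            if not (other.last < b.first or b.last < other.first)
        )
        offset = 0
        for lo, hi in busy:
            if offset + b.size <= lo:
                break                 # b fits in the gap before this range
            offset = max(offset, hi)  # otherwise skip past the busy range
        placed[b.name] = (offset, b)
    return {name: off for name, (off, _) in placed.items()}


bufs = [
    Buf("X", 768, 0, 0),
    Buf("W", 9_216, 0, 0),
    Buf("Y", 6_144, 0, 1),
    Buf("buf3", 6_144, 1, 1),
]
offsets = assign_offsets(bufs)
high_water = max(offsets[b.name] + b.size for b in bufs)
print(offsets, "high-water floats:", high_water)
```

With these inputs buf3 lands at offset 0, inside the region X and W occupied during the matmul phase, and the high-water mark equals the 16,128-float peak-live figure rather than the 22,272-float static total.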
2. Broadcast operands over-allocate SPAD
When a fused operand is broadcast, the load-side tile is expanded to the full output
tile shape rather than the operand's logical (pre-broadcast) shape
(PyTorchSimFrontend/mlir/mlir_codegen_backend.py:1181-1182):
if broadcast and (total_dims != local_dims or ...):
    local_dims = total_dims # Broadcast tile shape
So a [N] / [1, N] operand (e.g. a bias or per-channel scale used in an elementwise
epilogue) gets a [TILE_M, TILE_N] SPAD buffer instead of [TILE_N], materializing the
broadcast in SPAD and wasting space proportional to the broadcast factor (here TILE_M).
Example: in the kernel above the bias [3072] is MVIN'd with dram_stride=[0,1],
subtile_size=[256, 3072] -- a 256x materialization of a 3072-element vector. The
operand could instead be stored once ([1, TILE_N]) and broadcast at compute time.
Possible direction
Keep broadcast operands at their logical reduced shape in SPAD and broadcast at the
vector-compute step, instead of expanding local_dims to the full output tile.
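A rough sketch of the cost and of the proposed alternative, in plain Python. The tile sizes come from the example kernel; the function add_bias_rowwise is a hypothetical stand-in for the vector-compute step and says nothing about how the backend would actually implement it.

```python
from math import prod

TILE_M, TILE_N = 256, 3072

# Current behaviour: the bias is materialized at the full output-tile shape.
materialized_elems = prod([TILE_M, TILE_N])
# Proposed behaviour: keep the logical [1, TILE_N] shape in SPAD.
logical_elems = prod([1, TILE_N])

print("materialized elements:", materialized_elems)
print("logical elements     :", logical_elems)
print("waste factor         :", materialized_elems // logical_elems)


def add_bias_rowwise(tile, bias):
    """Broadcast at compute time: reuse the single stored bias row for every tile row."""
    return [[y + b for y, b in zip(row, bias)] for row in tile]


# Tiny stand-in tile to show the result matches a materialized broadcast.
small_tile = [[1.0, 2.0], [3.0, 4.0]]
small_bias = [10.0, 20.0]
print(add_bias_rowwise(small_tile, small_bias))
```

The waste factor equals TILE_M, matching the 256x figure quoted for the bias in the example kernel.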
Repro
Run python tests/ops/fusion/test_transformer_fusion.py with the spad/2 guard in place
(the commit that runs the spad-overflow check in timing-only mode at spad/2).
It fails with PyTorchSimFrontend.extension_codecache.SpadOverflowError on mlir_kernel_6.