[Frontend] SPAD budget vs actual footprint diverge: naive fused-epilogue buffer estimate and broadcast over-allocation #275

@YWHyuk

Summary

The frontend's tile-selection SPAD budget and the actual SPAD footprint of the emitted
kernel can diverge, in two distinct ways. Both lead to a kernel whose real .spad
section is larger than what tile selection assumed:

  1. Fused epilogue: tile-selection budget is naive about the epilogue output buffer.
  2. Broadcast operands: the SPAD buffer is over-allocated to the full output-tile shape.

This was surfaced when the SPAD-overflow guard budget was tightened to spad/2
(test_transformer_fusion now raises SpadOverflowError), but both issues are
pre-existing and independent of that change.


1. Fused epilogue buffer estimate != actual SPAD

What happens

MLIRGemmTemplate.select_tile feeds the epilogue count to the SPAD budget as
max(n_extra_read - 2, 0) (PyTorchSimFrontend/mlir/mlir_gemm_template.py:332):

tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), ...)

n_extra_read counts the extra input tensors the epilogue reads, excluding the GEMM
output itself. For a pure elementwise epilogue such as ReLU, n_extra_read == 0, so the
4th argument (n_extra_node) becomes 0. The -2 encodes the assumption that the input
(X) and weight (W) DMA buffers are already dead by the time the epilogue runs, so
their SPAD can be reused; two buffers are therefore subtracted from the epilogue count.
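
To make the effect of the clamp concrete, here is a minimal sketch. The helper name budgeted_extra_nodes is hypothetical (the frontend inlines the expression in the select_tile call); only the expression itself comes from the code quoted above.

```python
def budgeted_extra_nodes(n_extra_read: int) -> int:
    """Hypothetical helper mirroring the 4th argument passed in select_tile."""
    return max(n_extra_read - 2, 0)


# Map each epilogue read count to the number of extra buffers that tile
# selection actually budgets for.
table = {n: budgeted_extra_nodes(n) for n in range(5)}
print(table)
```

Any epilogue with up to two extra reads is budgeted exactly like a bare GEMM, which is only sound if the epilogue buffers really do land in the freed X/W space.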

But the codegen still emits the epilogue output as a separate, non-overlapping
memref.global that never reuses the freed X/W region. So the budget and the real
binary disagree.

Concrete numbers (test_transformer_fusion, ffn1 = addmm + ReLU)

Kernel: origins={'relu','addmm_3'}, M=512 N=3072 K=768, chosen tile
TILE_M=256 TILE_N=3072 TILE_K=384. Per-lane .spad globals in the emitted kernel:

buffer                          per-lane floats    bytes
X_spad                                      768    3,072
W_spad                                    9,216   36,864
Y_spad                                    6,144   24,576
buf3_spad (ReLU epilogue out)             6,144   24,576
total                                    22,272   89,088

  • Tile-selection budget (with n_extra_node = 0, i.e. bare GEMM):
    X 768 + W 9,216 + Y 6,144 = 16,128 floats = 64,512 B/lane -> fits spad/2 (65,536)
  • Actual emitted .spad (incl. buf3): 89,088 B/lane -> exceeds spad/2, fits full spad (131,072)

Under the old full-spad guard this was harmless (89,088 < 131,072); the spad/2 guard
exposes it.
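
The arithmetic above can be re-checked with a short script. This is only a sketch over the numbers quoted in this issue; the variable names are illustrative, and the 4-byte element size is taken from the floats/bytes columns of the table.

```python
BYTES_PER_FLOAT = 4              # 768 floats == 3,072 bytes in the table above
SPAD_BYTES_PER_LANE = 131_072    # full per-lane SPAD quoted above

# Per-lane float counts of the .spad globals in the emitted kernel.
buffers = {"X_spad": 768, "W_spad": 9_216, "Y_spad": 6_144, "buf3_spad": 6_144}

# Tile selection (n_extra_node = 0) only accounts for X, W and Y.
budget_floats = sum(n for name, n in buffers.items() if name != "buf3_spad")
actual_floats = sum(buffers.values())

budget_bytes = budget_floats * BYTES_PER_FLOAT
actual_bytes = actual_floats * BYTES_PER_FLOAT
guard_bytes = SPAD_BYTES_PER_LANE // 2

print(f"budget: {budget_bytes} B/lane, fits spad/2: {budget_bytes <= guard_bytes}")
print(f"actual: {actual_bytes} B/lane, fits spad/2: {actual_bytes <= guard_bytes}")
print(f"actual fits full spad: {actual_bytes <= SPAD_BYTES_PER_LANE}")
```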

Note: the real peak-live footprint does fit spad/2

The -2 liveness assumption is correct in principle. Buffer liveness:

matmul (k-loop):  X(768) + W(9216) + Y(6144) = 16,128  <- peak
epilogue:         X,W dead -> Y(6144) + buf3(6144) = 12,288

Peak simultaneously-live = 16,128 floats = 64,512 B < spad/2. buf3 (6,144) fits in
the dead X+W region (9,984). So the overflow is an allocation artifact: the four
buffers are laid out as disjoint static globals that never reuse freed space. The
liveness is static (X/W die at a fixed compile-time point, buf3 is born at the
epilogue) -- this is a compile-time SPAD offset-assignment problem, not a runtime
dynamic-free problem.
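
The same liveness argument as a small sketch, assuming the two-phase model described above (k-loop, then epilogue). The phase names and dictionary layout are illustrative, not frontend data structures.

```python
# Per-lane float counts from the table above.
sizes = {"X": 768, "W": 9_216, "Y": 6_144, "buf3": 6_144}

# Which buffers are live in each phase, per the liveness listing above.
live_in_phase = {
    "matmul (k-loop)": ("X", "W", "Y"),
    "epilogue": ("Y", "buf3"),
}

live_floats = {
    phase: sum(sizes[b] for b in bufs) for phase, bufs in live_in_phase.items()
}
peak_floats = max(live_floats.values())    # what a liveness-aware layout needs
static_floats = sum(sizes.values())        # what disjoint static globals need
dead_after_matmul = sizes["X"] + sizes["W"]

print(live_floats)
print("peak:", peak_floats, "static:", static_floats)
print("buf3 fits in dead X+W region:", sizes["buf3"] <= dead_after_matmul)
```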

Possible directions

  • Make SPAD allocation liveness-aware (offset reuse / arena) so buf3 overlaps the dead
    X/W region -> static .spad == peak-live, budget and guard agree, no spad wasted,
    large tile (TILE_N=3072) kept.
  • Cheap partial fix for elementwise epilogues: write the epilogue in-place into
    Y_buffer and drop buf3 entirely.
  • Fallback: make select_tile budget honestly include the epilogue output buffer
    (so tile selection picks a smaller tile) -- correct but costs reuse/perf (smaller
    tiles -> weight reload, more DMA).
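
A minimal sketch of the first direction, to show that compile-time offset reuse is enough for this kernel. It is not the frontend's allocator: the Buf type, the assign_offsets function, and the phase-granular lifetimes (0 = k-loop, 1 = epilogue) are all assumptions made for illustration, and the greedy first-fit strategy is just one simple choice.

```python
from typing import NamedTuple


class Buf(NamedTuple):
    name: str
    size: int   # per-lane floats
    start: int  # first phase in which the buffer is live
    end: int    # last phase in which the buffer is live (inclusive)


def assign_offsets(bufs):
    """Greedy first-fit, largest buffer first; lifetimes that never overlap may share space."""
    placed = {}  # name -> (offset, Buf)
    for b in sorted(bufs, key=lambda x: -x.size):
        # Address ranges held by already-placed buffers that are live together with b.
        busy = sorted(
            (off, off + p.size)
            for off, p in placed.values()
            if not (p.end < b.start or b.end < p.start)
        )
        offset = 0
        for lo, hi in busy:
            if offset + b.size <= lo:
                break  # b fits in the gap below this range
            offset = max(offset, hi)
        placed[b.name] = (offset, b)
    return {name: off for name, (off, _) in placed.items()}


bufs = [
    Buf("X", 768, 0, 0),
    Buf("W", 9_216, 0, 0),
    Buf("Y", 6_144, 0, 1),
    Buf("buf3", 6_144, 1, 1),
]
offsets = assign_offsets(bufs)
footprint = max(offsets[b.name] + b.size for b in bufs)
print(offsets, footprint)
```

With these lifetimes buf3 is placed at the offset previously held by W, and the static footprint equals the peak-live 16,128 floats (64,512 B/lane), which fits spad/2 without shrinking the tile.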

2. Broadcast operands over-allocate SPAD

When a fused operand is broadcast, the load-side tile is expanded to the full output
tile shape rather than the operand's logical (pre-broadcast) shape
(PyTorchSimFrontend/mlir/mlir_codegen_backend.py:1181-1182):

if broadcast and (total_dims != local_dims or ...):
    local_dims = total_dims  # Broadcast tile shape

So a [N] / [1, N] operand (e.g. a bias or per-channel scale used in an elementwise
epilogue) gets a [TILE_M, TILE_N] SPAD buffer instead of [TILE_N], materializing the
broadcast in SPAD and wasting space proportional to the broadcast factor (here TILE_M).
Example: in the kernel above the bias [3072] is MVIN'd with dram_stride=[0,1],
subtile_size=[256, 3072] -- a 256x materialization of a 3072-element vector. The
operand could instead be stored once ([1, TILE_N]) and broadcast at compute time.
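
A quick check of the over-allocation factor for the bias in this kernel, counted in whole-tile elements rather than per lane. The variable names are illustrative; the tile sizes are the ones quoted above.

```python
import math

TILE_M, TILE_N = 256, 3072

expanded_shape = (TILE_M, TILE_N)  # what the broadcast path allocates today
logical_shape = (1, TILE_N)        # the operand's pre-broadcast tile

expanded_elems = math.prod(expanded_shape)
logical_elems = math.prod(logical_shape)
factor = expanded_elems // logical_elems

print(expanded_elems, logical_elems, factor)
```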

Possible direction

Keep broadcast operands at their logical reduced shape in SPAD and broadcast at the
vector-compute step, instead of expanding local_dims to the full output tile.
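
A sketch of what that could look like, assuming a stride of 0 in dram_stride marks a broadcast dimension (as in the MVIN example above). Both helpers, logical_tile_shape and add_bias_rowwise, are hypothetical and written in plain Python only to show the idea; they are not codegen functions.

```python
def logical_tile_shape(tile_shape, dram_stride):
    """Keep stride-0 (broadcast) dims at extent 1 instead of the full tile extent."""
    return [1 if s == 0 else d for d, s in zip(tile_shape, dram_stride)]


def add_bias_rowwise(y_tile, bias_row):
    """Broadcast at compute time: reuse the single stored bias row for every output row."""
    return [[y + b for y, b in zip(row, bias_row)] for row in y_tile]


# Shape for the bias operand of the kernel above.
shape = logical_tile_shape([256, 3072], [0, 1])
print(shape)

# Tiny stand-in tile showing the compute-time broadcast.
y_tile = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
bias_row = [10.0, 20.0, 30.0]
out = add_bias_rowwise(y_tile, bias_row)
print(out)
```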


Repro

Run python tests/ops/fusion/test_transformer_fusion.py on a commit with the spad/2
guard (i.e. one that runs the spad-overflow check in timing-only mode at spad/2).
It fails with PyTorchSimFrontend.extension_codecache.SpadOverflowError on mlir_kernel_6.
