Skip to content

perf(gemm2): 1:1 HIP a4w4 gemm2 port + small-M latency optimization to HIP parity#660

Open
fsx950223 wants to merge 5 commits into
mainfrom
worktree-gemm2-port
Open

perf(gemm2): 1:1 HIP a4w4 gemm2 port + small-M latency optimization to HIP parity#660
fsx950223 wants to merge 5 commits into
mainfrom
worktree-gemm2-port

Conversation

@fsx950223
Copy link
Copy Markdown
Contributor

@fsx950223 fsx950223 commented Jun 5, 2026

Summary

Adds a 1:1 FlyDSL port of aiter PR #3470's gemm2_a4w4 MXFP4 MoE down-proj
kernel (gfx950, BM32 atomic instance), then optimizes its small-M latency so it
matches or beats the HIP kernel across the full size range.

  • Port (kernels/gemm2_a4w4_port.py): mirrors the HIP kernel
    instruction-for-instruction — 4 make.buffer.rsrc, A→LDS via
    raw.ptr.buffer.load.lds, B/scales via buffer.load.v4i32/.i32,
    mfma_scale_f32_16x16x128_f8f6f4 (.v4i32.v4i32), s_waitcnt vmcnt +
    s_barrier fences, atomic bf16 epilog with topk-weight multiply. Output is
    bit-exact vs aiter's HIP gemm2.
  • Small-M latency optimization (two bit-exact, latency-hiding changes;
    the kernels are latency-bound at small M, occupancy LDS-bound at 5 waves/SIMD
    identical to HIP):
    1. Issue A→LDS before the cumsum-gated early-return branch. The
      raw.ptr.buffer.load.lds is side-effecting so the compiler cannot sink it
      back, letting its HBM latency overlap the cumsum load + bound check.
    2. Prefetch sorted_token_ids/sorted_weights (invariant) at epilog
      entry
      , before the cshuffle stores and both LDS barriers, so their global
      latency overlaps the store+barriers instead of being exposed in the
      dependent atomic loop (the epilog is ~48% of small-M stalls — the main lever).

Performance (unit test, CUDA-graph GPU-event timing)

Run via test_performance over test.py's KIMI-K2.5 config
(NE=385, H=7168, INTER=512, TOPK=9; M-list {4…256}) plus a large context
(M=16384). srt = roundup(M*TOPK, BM) is the expanded+padded token count gemm2
processes.

M (tokens) srt HIP µs port µs port/HIP
4 64 9.6 8.8 0.92×
8 96 9.7 8.9 0.92×
16 160 10.1 9.3 0.92×
32 288 11.0 10.5 0.95×
64 576 15.3 15.2 1.00×
128 1152 28.8 26.5 0.92×
256 2304 47.7 41.3 0.87×
16384 147456 2544.8 2228.5 0.88×

Before the optimization the small-M cases were 1.27–1.32× slower than HIP
(latency-bound, exposed prologue/epilog fixed cost); they are now at parity or
faster. Larger sizes were already faster and remain so.

Measurement note: per-process rocprofv3 at 4–6 µs kernels is dominated by GPU
clock state (±10%); the reliable signal is the port's absolute median and the
in-run (cuda-graph or interleaved) ratio. PORT_DESIGN.md documents the method.

Correctness

Bit-exact vs aiter HIP mxfp4_moe_gemm2_a4w4 (BM32 atomic) on identical input
bytes (cosine=1.0, max_abs_diff=0).

Testing

tests/kernels/test_gemm2_a4w4_port.py (gfx950, l2_device + rocm_lower):

  • test_smoke — compile+run, finite/non-zero output (no aiter dep)
  • test_accuracy_vs_hip[256,1024] — bit-exact vs HIP
  • test_performance[...] — CUDA-graph timing over the M-list + large context

All pass on MI355 (gfx950).

e2e gemm2 performance (in fused_moe, BM16+NT & BM32+NT)

The port is wired into aiter's fused_moe mxfp4 path (dispatch hook gated by
FLYDSL_GEMM2_PORT=1; aiter-side, separate PR). Under CUDA graph
(production-realistic), the gemm2 kernel time (rocprofv3, median over graph
replays) is at parity with HIP across the whole M range — fused_moe selects
BM16_ATOMIC_NT for M≤128 and BM32_ATOMIC_NT for M=256, both routed to the port:

M gemm2 instance HIP µs port µs port/HIP
4 BM16_ATOMIC_NT 10.36 10.60 1.023×
8 BM16_ATOMIC_NT 18.72 18.92 1.011×
16 BM16_ATOMIC_NT 30.40 30.84 1.014×
32 BM16_ATOMIC_NT 45.30 46.36 1.023×
64 BM16_ATOMIC_NT 64.78 64.84 1.001×
128 BM16_ATOMIC_NT 79.72 80.60 1.011×
256 BM32_ATOMIC_NT 95.16 95.76 1.006×

Output is cos=1.0 vs the HIP MoE end-to-end (bit-exact standalone). The kernel
supports BM∈{16,32} × {non-NT, NT}; tests cover bm32 / bm16nt / bm32nt (33 cases).

🤖 Generated with Claude Code

fsx950223 and others added 2 commits June 5, 2026 03:43
Port aiter PR #3470 mxfp4 a4w4 gemm2 (BM32 atomic) to FlyDSL with
matching LLVM intrinsics and ISA compute instructions; bit-exact vs HIP.
Adds kernels/gemm2_a4w4_port.py, tests/kernels/test_gemm2_a4w4_port.py,
and reference IR/ISA + design notes under kernels/gemm2_port_ref/.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Two bit-exact latency-hiding changes close the small-M (M<=32) gap vs HIP:

- Issue A->LDS before the cumsum-gated early-return branch. The
  raw.ptr.buffer.load.lds is side-effecting so it is not sunk back, letting
  its HBM latency overlap the cumsum/bound check.
- Prefetch sorted_token_ids/sorted_weights (invariant) at epilog entry,
  before the cshuffle stores and both LDS barriers, so their latency overlaps
  the store+barriers instead of being exposed in the dependent atomic loop
  (the epilog is ~48% of small-M stalls).

Result (cuda-graph unit-test timing, port/HIP): M=4 1.32x->0.92x,
M=16 1.30x->0.92x, M=32 1.27x->0.95x; >=M=64 already 0.87-1.00x. Output
stays bit-exact (4/4 accuracy/perf tests pass).

test_performance now uses CUDA-graph GPU-event timing and is parametrized
over test.py's KIMI M-list plus a large context (srt=147456). PORT_DESIGN.md
documents the optimization and the interleaved measurement method.

Signed-off-by: fsx950223 <fsx950223@outlook.com>
Copilot AI review requested due to automatic review settings June 5, 2026 05:04
Remove the port's development reference directory (HIP/FlyDSL IR dumps +
PORT_DESIGN.md) from version control; nothing imports or tests it. Also drop
the two dangling docstring references to it. Files are kept locally (untracked).

Signed-off-by: fsx950223 <fsx950223@outlook.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a high-fidelity FlyDSL implementation of aiter’s gemm2_a4w4 MXFP4 MoE down-projection kernel for gfx950 (BM32 atomic variant), along with targeted latency-hiding tweaks to bring small‑M runtime to HIP parity while remaining bit-exact.

Changes:

  • Introduces kernels/gemm2_a4w4_port.py, a 1:1 instruction-structure port (buffer resources, A→LDS, MFMA scale ops, waitcnt+barrier fencing, and atomic bf16 epilog).
  • Adds a dedicated GPU test suite validating smoke/correctness vs HIP and performance via CUDA-graph event timing.
  • Checks in reference artifacts and design notes (PORT_DESIGN.md + reference .ll files) documenting the equivalence methodology.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
kernels/gemm2_a4w4_port.py New FlyDSL kernel implementation and small‑M latency optimizations while preserving bit-exactness.
tests/kernels/test_gemm2_a4w4_port.py New l2_device test coverage: smoke, bit-exact HIP comparison, and performance guardrails.
kernels/gemm2_port_ref/PORT_DESIGN.md Design/verification notes and reproduction guidance for the HIP↔FlyDSL parity work.
kernels/gemm2_port_ref/hip_gemm2_bm32.ll Checked-in HIP reference LLVM IR snapshot for parity comparisons.
kernels/gemm2_port_ref/flydsl_port_v5.ll Checked-in FlyDSL LLVM IR snapshot (v5) for parity comparisons.
kernels/gemm2_port_ref/flydsl_port_v6.ll Checked-in FlyDSL LLVM IR snapshot (v6) for parity comparisons.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +31 to +40
import flydsl.compiler as flyc

pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower]

_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
for _p in (os.path.join(_REPO_ROOT, "build", "python_packages"), _REPO_ROOT):
if os.path.isdir(_p) and _p not in sys.path:
sys.path.insert(0, _p)

from flydsl.runtime.device import get_rocm_arch # noqa: E402
fsx950223 added 2 commits June 5, 2026 06:30
Parametrize compile_gemm2_a4w4_port by (BM, use_nt) and add the BM16 +
kUseNT path, mirroring aiter's production instance
...TOPK9_BM16_ATOMIC_NT:

- BM16 tiling: kMChunks=1 (i0 only), A->LDS gated to wave<BM/8 (rows
  wave*8), a_scale chunk_base=m_row/BM, epilog M_REPS=2/kMChunksEpi=1,
  16KB LDS (higher occupancy).
- NT: B_q buffer_load cache_modifier=2 (kBQ_AUX, non-temporal hint).
- BM32 path unchanged. Both latency-hiding opts carry over.

test_gemm2_a4w4_port.py is parametrized over both variants (bm32,
bm16nt): smoke + bit-exact accuracy vs HIP + cuda-graph perf. BM16+NT is
bit-exact vs HIP ...BM16_ATOMIC_NT and 0.86-1.02x its kernel time;
22/22 tests pass on gfx950.

Signed-off-by: fsx950223 <fsx950223@outlook.com>
Add the bm32nt specialization (compile_gemm2_a4w4_port(BM=32,
use_nt=True)) to the parametrized variants, validating bit-exact accuracy
and cuda-graph perf vs HIP ...BM32_ATOMIC_NT (the instance fused_moe
selects at larger M). Smoke + accuracy + performance now cover bm32,
bm16nt, and bm32nt (33 cases, all pass on gfx950).

Signed-off-by: fsx950223 <fsx950223@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants