perf(gemm2): 1:1 HIP a4w4 gemm2 port + small-M latency optimization to HIP parity#660
Open
fsx950223 wants to merge 5 commits into
Open
perf(gemm2): 1:1 HIP a4w4 gemm2 port + small-M latency optimization to HIP parity#660fsx950223 wants to merge 5 commits into
fsx950223 wants to merge 5 commits into
Conversation
Port aiter PR #3470 mxfp4 a4w4 gemm2 (BM32 atomic) to FlyDSL with matching LLVM intrinsics and ISA compute instructions; bit-exact vs HIP. Adds kernels/gemm2_a4w4_port.py, tests/kernels/test_gemm2_a4w4_port.py, and reference IR/ISA + design notes under kernels/gemm2_port_ref/. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: fsx950223 <fsx950223@outlook.com>
Two bit-exact latency-hiding changes close the small-M (M<=32) gap vs HIP: - Issue A->LDS before the cumsum-gated early-return branch. The raw.ptr.buffer.load.lds is side-effecting so it is not sunk back, letting its HBM latency overlap the cumsum/bound check. - Prefetch sorted_token_ids/sorted_weights (invariant) at epilog entry, before the cshuffle stores and both LDS barriers, so their latency overlaps the store+barriers instead of being exposed in the dependent atomic loop (the epilog is ~48% of small-M stalls). Result (cuda-graph unit-test timing, port/HIP): M=4 1.32x->0.92x, M=16 1.30x->0.92x, M=32 1.27x->0.95x; >=M=64 already 0.87-1.00x. Output stays bit-exact (4/4 accuracy/perf tests pass). test_performance now uses CUDA-graph GPU-event timing and is parametrized over test.py's KIMI M-list plus a large context (srt=147456). PORT_DESIGN.md documents the optimization and the interleaved measurement method. Signed-off-by: fsx950223 <fsx950223@outlook.com>
Remove the port's development reference directory (HIP/FlyDSL IR dumps + PORT_DESIGN.md) from version control; nothing imports or tests it. Also drop the two dangling docstring references to it. Files are kept locally (untracked). Signed-off-by: fsx950223 <fsx950223@outlook.com>
Contributor
There was a problem hiding this comment.
Pull request overview
This PR adds a high-fidelity FlyDSL implementation of aiter’s gemm2_a4w4 MXFP4 MoE down-projection kernel for gfx950 (BM32 atomic variant), along with targeted latency-hiding tweaks to bring small‑M runtime to HIP parity while remaining bit-exact.
Changes:
- Introduces
kernels/gemm2_a4w4_port.py, a 1:1 instruction-structure port (buffer resources, A→LDS, MFMA scale ops, waitcnt+barrier fencing, and atomic bf16 epilog). - Adds a dedicated GPU test suite validating smoke/correctness vs HIP and performance via CUDA-graph event timing.
- Checks in reference artifacts and design notes (
PORT_DESIGN.md+ reference.llfiles) documenting the equivalence methodology.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
kernels/gemm2_a4w4_port.py |
New FlyDSL kernel implementation and small‑M latency optimizations while preserving bit-exactness. |
tests/kernels/test_gemm2_a4w4_port.py |
New l2_device test coverage: smoke, bit-exact HIP comparison, and performance guardrails. |
kernels/gemm2_port_ref/PORT_DESIGN.md |
Design/verification notes and reproduction guidance for the HIP↔FlyDSL parity work. |
kernels/gemm2_port_ref/hip_gemm2_bm32.ll |
Checked-in HIP reference LLVM IR snapshot for parity comparisons. |
kernels/gemm2_port_ref/flydsl_port_v5.ll |
Checked-in FlyDSL LLVM IR snapshot (v5) for parity comparisons. |
kernels/gemm2_port_ref/flydsl_port_v6.ll |
Checked-in FlyDSL LLVM IR snapshot (v6) for parity comparisons. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+31
to
+40
| import flydsl.compiler as flyc | ||
|
|
||
| pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] | ||
|
|
||
| _REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) | ||
| for _p in (os.path.join(_REPO_ROOT, "build", "python_packages"), _REPO_ROOT): | ||
| if os.path.isdir(_p) and _p not in sys.path: | ||
| sys.path.insert(0, _p) | ||
|
|
||
| from flydsl.runtime.device import get_rocm_arch # noqa: E402 |
Parametrize compile_gemm2_a4w4_port by (BM, use_nt) and add the BM16 + kUseNT path, mirroring aiter's production instance ...TOPK9_BM16_ATOMIC_NT: - BM16 tiling: kMChunks=1 (i0 only), A->LDS gated to wave<BM/8 (rows wave*8), a_scale chunk_base=m_row/BM, epilog M_REPS=2/kMChunksEpi=1, 16KB LDS (higher occupancy). - NT: B_q buffer_load cache_modifier=2 (kBQ_AUX, non-temporal hint). - BM32 path unchanged. Both latency-hiding opts carry over. test_gemm2_a4w4_port.py is parametrized over both variants (bm32, bm16nt): smoke + bit-exact accuracy vs HIP + cuda-graph perf. BM16+NT is bit-exact vs HIP ...BM16_ATOMIC_NT and 0.86-1.02x its kernel time; 22/22 tests pass on gfx950. Signed-off-by: fsx950223 <fsx950223@outlook.com>
Add the bm32nt specialization (compile_gemm2_a4w4_port(BM=32, use_nt=True)) to the parametrized variants, validating bit-exact accuracy and cuda-graph perf vs HIP ...BM32_ATOMIC_NT (the instance fused_moe selects at larger M). Smoke + accuracy + performance now cover bm32, bm16nt, and bm32nt (33 cases, all pass on gfx950). Signed-off-by: fsx950223 <fsx950223@outlook.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a 1:1 FlyDSL port of aiter PR #3470's
gemm2_a4w4MXFP4 MoE down-projkernel (gfx950, BM32 atomic instance), then optimizes its small-M latency so it
matches or beats the HIP kernel across the full size range.
kernels/gemm2_a4w4_port.py): mirrors the HIP kernelinstruction-for-instruction — 4
make.buffer.rsrc, A→LDS viaraw.ptr.buffer.load.lds, B/scales viabuffer.load.v4i32/.i32,mfma_scale_f32_16x16x128_f8f6f4(.v4i32.v4i32),s_waitcnt vmcnt+s_barrierfences, atomic bf16 epilog with topk-weight multiply. Output isbit-exact vs aiter's HIP gemm2.
the kernels are latency-bound at small M, occupancy LDS-bound at 5 waves/SIMD
identical to HIP):
raw.ptr.buffer.load.ldsis side-effecting so the compiler cannot sink itback, letting its HBM latency overlap the cumsum load + bound check.
sorted_token_ids/sorted_weights(invariant) at epilogentry, before the cshuffle stores and both LDS barriers, so their global
latency overlaps the store+barriers instead of being exposed in the
dependent atomic loop (the epilog is ~48% of small-M stalls — the main lever).
Performance (unit test, CUDA-graph GPU-event timing)
Run via
test_performanceovertest.py's KIMI-K2.5 config(NE=385, H=7168, INTER=512, TOPK=9; M-list {4…256}) plus a large context
(M=16384).
srt = roundup(M*TOPK, BM)is the expanded+padded token count gemm2processes.
Before the optimization the small-M cases were 1.27–1.32× slower than HIP
(latency-bound, exposed prologue/epilog fixed cost); they are now at parity or
faster. Larger sizes were already faster and remain so.
Correctness
Bit-exact vs aiter HIP
mxfp4_moe_gemm2_a4w4(BM32 atomic) on identical inputbytes (cosine=1.0, max_abs_diff=0).
Testing
tests/kernels/test_gemm2_a4w4_port.py(gfx950,l2_device+rocm_lower):test_smoke— compile+run, finite/non-zero output (no aiter dep)test_accuracy_vs_hip[256,1024]— bit-exact vs HIPtest_performance[...]— CUDA-graph timing over the M-list + large contextAll pass on MI355 (gfx950).
e2e gemm2 performance (in
fused_moe, BM16+NT & BM32+NT)The port is wired into aiter's
fused_moemxfp4 path (dispatch hook gated byFLYDSL_GEMM2_PORT=1; aiter-side, separate PR). Under CUDA graph(production-realistic), the gemm2 kernel time (rocprofv3, median over graph
replays) is at parity with HIP across the whole M range —
fused_moeselectsBM16_ATOMIC_NTfor M≤128 andBM32_ATOMIC_NTfor M=256, both routed to the port:Output is cos=1.0 vs the HIP MoE end-to-end (bit-exact standalone). The kernel
supports BM∈{16,32} × {non-NT, NT}; tests cover bm32 / bm16nt / bm32nt (33 cases).
🤖 Generated with Claude Code