[gfx1250][gemm] Optimize FP8 GEMM deep pipeline and add scale loading experiment #608
Open
aoli26 wants to merge 19 commits into
Pull request overview
This PR targets gfx1250 MXFP8 GEMM performance/robustness by adding a dedicated FP8 deep-pipeline schedule, introducing a segmented LDS layout for the tuned 256×256×128 configuration, and adding an experimental scale loading path that can bypass TDM/LDS (buffer_load→VGPR). It also updates benchmark/verification utilities and tweaks TDM descriptor behavior.
Changes:
- Add an FP8 deep-pipeline compute schedule (auto-selected when eligible) and a reference segmented LDS layout for the optimized gfx1250 FP8 256×256×128 kernel.
- Introduce experimental scale load modes (`vgpr`, `vgpr_ab_split`) that load E8M0 scales directly to VGPRs (plus test/CLI support and new preshuffle layout).
- Improve benchmark timing utilities (single-event fast path, hipGraph capture/replay sanity checks) and enable TDM multicast early-timeout for loads.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `kernels/gemm_fp8fp4_gfx1250.py` | Adds FP8 deep-pipeline schedule, segmented LDS layout, VGPR scale-load experiment, and related fencing/scheduling changes. |
| `python/flydsl/expr/rocdl/tdm_ops.py` | Enables the TDM descriptor early_timeout bit for load descriptors (kept off for stores). |
| `python/flydsl/expr/rocdl/inline_asm.py` | Modifies gfx1250 instruction prefetch wrapper implementation/defaults. |
| `tests/kernels/test_gemm_fp8fp4_gfx1250.py` | Adds coalesced scale preshuffle for VGPR scale path, updates tests/bench CLI, adds hipGraph verification and new fill modes. |
| `tests/kernels/benchmark_common.py` | Adds a fast timing path using a single event pair when no flush/prep is requested. |
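As a rough illustration of the single-event-pair fast path described for `tests/kernels/benchmark_common.py`, here is a minimal sketch. It is not the code from this PR: `time_kernel_ms`, `launch`, `prep`, and `clock` are hypothetical names, and a host clock stands in for the GPU start/stop events so the snippet runs without a device.

```python
import time
from typing import Callable, Optional


def time_kernel_ms(
    launch: Callable[[], None],
    iters: int,
    prep: Optional[Callable[[], None]] = None,
    clock: Callable[[], float] = time.perf_counter,
) -> float:
    """Return the mean time per launch in milliseconds.

    Hypothetical helper: `clock` stands in for recording a GPU event.
    """
    if iters <= 0:
        raise ValueError("iters must be positive")

    if prep is None:
        # Fast path: one start/stop pair brackets all launches, so the
        # per-iteration timing overhead is paid only once.
        start = clock()
        for _ in range(iters):
            launch()
        stop = clock()
        return (stop - start) * 1e3 / iters

    # Slow path: prep (for example a cache flush) must stay outside the
    # timed region, so each launch gets its own start/stop pair.
    total = 0.0
    for _ in range(iters):
        prep()
        start = clock()
        launch()
        total += clock() - start
    return total * 1e3 / iters
```

With real GPU events the stop event would also need to be synchronized before reading the elapsed time; that step is omitted here because the stand-in clock is synchronous.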
bd2990c to a38b444
coderfeli reviewed Jun 2, 2026
Replace the buffer_lds_stage scale paths with vgpr/vgpr_ab_split, loading scale global->VGPR via buffer_load (off the LDS/TDM/barrier path) with a coalesced lane-major host layout. Add opt-in overlay-chunks and ab-half-fence schedules, and trim verbose comments.
…ed_group_barrier)
Opt-in hot-loop scheduling for the FP8 deep-pipeline (default unchanged):
- hot_loop_sched_mode = default | iglp | manual
- iglp: clean region (internal sched_barrier(0) suppressed) + iglp_opt(0)
(LLVM MFMASmallGemmOpt) -> SP3-style DS/MFMA interleave.
- manual: same clean region + hand-emitted sched_group_barrier template,
granularity hot_loop_manual_ds:hot_loop_manual_mfma (default 8:8).
default mode byte-identical to baseline. All modes cosine=1.0; fp8 wave_spec
tests pass. Net win over default unconfirmed (AM underestimates ds latency) -
to be ranked on silicon ATT.
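To make the `manual` granularity concrete, the sketch below lists the alternating DS/MFMA groups that an 8:8 template would produce for a given instruction count. It only models the interleave order. The function name and the tuple representation are hypothetical, and the actual `sched_group_barrier` emission in the kernel is not shown here.

```python
from typing import List, Tuple


def manual_sched_template(
    n_ds: int,
    n_mfma: int,
    ds_per_group: int = 8,     # mirrors hot_loop_manual_ds (default 8)
    mfma_per_group: int = 8,   # mirrors hot_loop_manual_mfma (default 8)
) -> List[Tuple[str, int]]:
    """Return alternating ("ds", count) / ("mfma", count) groups.

    Hypothetical model of the interleave order only: each tuple stands
    for one scheduling group of that many instructions.
    """
    if ds_per_group <= 0 or mfma_per_group <= 0:
        raise ValueError("group sizes must be positive")
    if n_ds < 0 or n_mfma < 0:
        raise ValueError("instruction counts must be non-negative")

    groups: List[Tuple[str, int]] = []
    while n_ds > 0 or n_mfma > 0:
        take = min(ds_per_group, n_ds)
        if take:
            groups.append(("ds", take))
            n_ds -= take
        take = min(mfma_per_group, n_mfma)
        if take:
            groups.append(("mfma", take))
            n_mfma -= take
    return groups
```

For example, `manual_sched_template(16, 24)` yields two DS/MFMA pairs of 8 followed by a trailing MFMA group of 8, since the DS instructions run out first.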
Skip only the scale TDM vmem->LDS load path: drop the scale loader waves (2,3) from the active TDM the same way the vgpr path does, and pre-fill the scale LDS stages once with a constant E8M0=1.0 (0x7F) byte. The downstream scale LDS read path and the scaled-WMMA op are left byte-for-byte unchanged, so the variant isolates the cost of scale TDM delivery alone (unlike FLYDSL_SCALE_DISABLED, which also removes the LDS reads). Gated on scale_load_path='tdm' + wave_specialized_tdm; mutually exclusive with FLYDSL_SCALE_DISABLED. Also adds test_mxfp8_hot_loop_sched_modes.
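For readers unfamiliar with the E8M0 encoding, the constant 0x7F mentioned above decodes to a scale of 1.0 because E8M0 is an exponent-only byte with bias 127. The following is a minimal sketch of that mapping, assuming the OCP MX convention in which 0xFF is reserved for NaN; the helper names are illustrative and are not taken from this PR.

```python
import math

E8M0_BIAS = 127   # exponent bias of the E8M0 scale format
E8M0_NAN = 0xFF   # reserved NaN encoding (assumed OCP MX convention)


def e8m0_decode(byte: int) -> float:
    """Decode an E8M0 scale byte to its value 2**(byte - 127)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 scale must fit in one byte")
    if byte == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, byte - E8M0_BIAS)


def e8m0_encode_pow2(scale: float) -> int:
    """Encode an exact power-of-two scale as an E8M0 byte."""
    mantissa, exp = math.frexp(scale)  # scale == mantissa * 2**exp
    if mantissa != 0.5:
        raise ValueError("scale must be a positive power of two")
    byte = (exp - 1) + E8M0_BIAS
    if not 0 <= byte < E8M0_NAN:
        raise ValueError("scale is outside the E8M0 range")
    return byte
```

Pre-filling the scale LDS stages with `e8m0_encode_pow2(1.0)` (that is, 0x7F) therefore leaves the scaled product numerically unscaled while keeping the read path active.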
Split-K accumulates partial K-sums across workgroups via atomic add into C.
Two fixes:
- Correctness: route the atomic accumulation through llvm.atomicrmw fadd on a
global (addrspace 1) pointer with syncscope("agent") instead of buffer atomics.
- Precision: tests run split-K with f32 accumulation and convert to the requested
bf16/f16 on the host, avoiding compounded rounding from half-precision atomics.
Adds test_mxfp8_gemm_splitk over split_k in {2,4,6,8}.
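The precision point can be reproduced on the host without a GPU. The sketch below is an illustration, not the test from this PR: it uses f16 rather than bf16 because the standard library can round to binary16 via `struct`, it accumulates in a fixed order whereas real atomics arrive in arbitrary order, and all function names are hypothetical.

```python
import struct
from typing import List, Sequence


def to_f16(x: float) -> float:
    """Round to IEEE binary16 and back (round-to-nearest-even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round to IEEE binary32 and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def splitk_partials(a: Sequence[float], b: Sequence[float], split_k: int) -> List[float]:
    """Dot-product partial sums over split_k contiguous K-slices."""
    if len(a) != len(b) or len(a) % split_k != 0:
        raise ValueError("K must match and be divisible by split_k")
    step = len(a) // split_k
    return [
        sum(x * y for x, y in zip(a[i:i + step], b[i:i + step]))
        for i in range(0, len(a), step)
    ]


def accumulate_in_half(partials: Sequence[float]) -> float:
    """Model half-precision atomics: the running sum is rounded each add."""
    acc = 0.0
    for p in partials:
        acc = to_f16(acc + to_f16(p))
    return acc


def accumulate_in_f32_then_convert(partials: Sequence[float]) -> float:
    """Model f32 accumulation with a single final conversion to half."""
    acc = 0.0
    for p in partials:
        acc = to_f32(acc + to_f32(p))
    return to_f16(acc)


a = [2048.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0]
b = [1.0] * 8
partials = splitk_partials(a, b, split_k=4)   # [2048.0, 1.0, 1.0, 2.0]
half_result = accumulate_in_half(partials)
f32_result = accumulate_in_f32_then_convert(partials)
```

Here the exact sum is 2052. The half-precision accumulator loses the two small increments because the spacing of binary16 values near 2048 is 2, while the f32 accumulator keeps them and rounds only once at the end.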
8b24891 to 1a90fd8
Motivation
Improve gfx1250 MXFP8 GEMM performance and robustness by adding a dedicated FP8 deep-pipeline schedule, tuning LDS/TDM scheduling behavior, and adding an extra buffer-load scale-load path to the main optimized FP8 kernel (experimental, off by default). It also fixes a split-K accuracy bug.
Technical Details
This PR adds an FP8 deep-pipeline compute schedule for the 256x256x128 gfx1250 MXFP8 configuration, including tuned wait placement, panel scheduling, instruction prefetch behavior, TDM handling, multicast early timeout, and cluster fence/signal overlap. It also introduces a segmented LDS layout for the optimized FP8 path (to avoid LDS segment conflicts), adds direct buffer-load-to-VGPR scale loading modes, removes older experimental scale staging paths, enables the TDM multicast early-timeout descriptor bit for loads, and improves benchmark timing/graph verification coverage for gfx1250 GEMM.
It also fixes split-K accuracy by switching the cross-workgroup atomic accumulation to device-scoped `llvm.atomicrmw fadd`.
Test Plan
Adds `test_mxfp8_gemm_splitk` covering `split_k` ∈ {2, 4, 6, 8} for f32 and bf16 outputs.
Test Result
All tests passed.
Submission Checklist