[gfx1250][gemm] Optimize FP8 GEMM deep pipeline and add scale loading experiment #608

Open
aoli26 wants to merge 19 commits into main from gfx1250/gemm_fp8_opt

@aoli26 aoli26 commented Jun 2, 2026

Motivation

Improve gfx1250 MXFP8 GEMM performance and robustness by adding a dedicated FP8 deep-pipeline schedule, tuning LDS/TDM scheduling behavior, and adding an extra buffer-load scale-loading path to the main optimized FP8 kernel (experimental, off by default). This PR also fixes a split-K accuracy bug.

Technical Details

This PR adds an FP8 deep-pipeline compute schedule for the 256x256x128 gfx1250 MXFP8 configuration, including tuned wait placement, panel scheduling, instruction prefetch behavior, TDM handling, multicast early timeout, and cluster fence/signal overlap. It also introduces a segmented LDS layout for the optimized FP8 path (to avoid LDS segment conflicts), adds direct buffer-load-to-VGPR scale loading modes, removes older experimental scale staging paths, enables the TDM multicast early-timeout descriptor bit for loads, and improves benchmark timing and graph verification coverage for gfx1250 GEMM.
It also fixes split-K accuracy by switching the cross-workgroup atomic accumulation to device-scoped llvm.atomicrmw fadd.
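
As a rough sizing aid for the 256x256x128 tile, the sketch below computes the operand footprint of one pipeline stage. It assumes 1-byte FP8 elements and one E8M0 scale byte per 32 K-elements (the MX block size). The function name and dictionary keys are illustrative and are not taken from the kernel, and the sketch says nothing about how the segmented LDS layout actually places these buffers.

```python
MX_BLOCK = 32  # MX formats share one E8M0 scale across 32 elements along K


def stage_footprint_bytes(tile_m: int, tile_n: int, tile_k: int) -> dict:
    """Bytes needed to hold one stage of A, B and their E8M0 scales.

    Hypothetical helper for illustration only; assumes 1 byte per FP8 element.
    """
    if tile_k % MX_BLOCK != 0:
        raise ValueError("tile_k must be a multiple of the MX block size")
    a_bytes = tile_m * tile_k                     # A tile: M x K FP8 values
    b_bytes = tile_n * tile_k                     # B tile: N x K FP8 values
    a_scale = tile_m * (tile_k // MX_BLOCK)       # one scale byte per 32 K-elements
    b_scale = tile_n * (tile_k // MX_BLOCK)
    return {
        "a": a_bytes,
        "b": b_bytes,
        "a_scale": a_scale,
        "b_scale": b_scale,
        "total": a_bytes + b_bytes + a_scale + b_scale,
    }


if __name__ == "__main__":
    print(stage_footprint_bytes(256, 256, 128))
```

Multiplying `total` by the number of in-flight stages gives a first estimate of the LDS budget a deep pipeline needs when scales are staged through LDS; the VGPR scale-load modes would remove the two scale terms from that budget.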

Test Plan

pytest tests/kernels/test_gemm_fp8fp4_gfx1250.py

Adds test_mxfp8_gemm_splitk covering split_k ∈ {2, 4, 6, 8} for f32 and bf16 outputs.

Test Result

All tests passed.

Submission Checklist

Copilot AI review requested due to automatic review settings June 2, 2026 08:05
@aoli26 aoli26 added the enhancement New feature or request label Jun 2, 2026

Copilot AI left a comment

Pull request overview

This PR targets gfx1250 MXFP8 GEMM performance/robustness by adding a dedicated FP8 deep-pipeline schedule, introducing a segmented LDS layout for the tuned 256×256×128 configuration, and adding an experimental scale loading path that can bypass TDM/LDS (buffer_load→VGPR). It also updates benchmark/verification utilities and tweaks TDM descriptor behavior.

Changes:

  • Add an FP8 deep-pipeline compute schedule (auto-selected when eligible) and a reference segmented LDS layout for the optimized gfx1250 FP8 256×256×128 kernel.
  • Introduce experimental scale load modes (vgpr, vgpr_ab_split) that load E8M0 scales directly to VGPRs (plus test/CLI support and new preshuffle layout).
  • Improve benchmark timing utilities (single-event fast path, hipGraph capture/replay sanity checks) and enable TDM multicast early-timeout for loads.
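
The single-event fast path mentioned in the last bullet can be pictured with a host-only sketch. It uses `time.perf_counter` as a stand-in for GPU events, and `time_kernel` with its `prep` parameter is a hypothetical name that does not mirror the code in benchmark_common.py.

```python
import time
from typing import Callable, Optional


def time_kernel(fn: Callable[[], None],
                iters: int = 100,
                prep: Optional[Callable[[], None]] = None) -> float:
    """Return the mean seconds per call of fn.

    Hypothetical sketch: perf_counter stands in for a GPU start/stop event pair.
    """
    if iters <= 0:
        raise ValueError("iters must be positive")
    if prep is None:
        # Fast path: a single start/stop pair brackets the whole loop.
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        return (time.perf_counter() - start) / iters
    # Slow path: prep (for example a cache flush) must not be timed,
    # so each call gets its own start/stop pair.
    total = 0.0
    for _ in range(iters):
        prep()
        start = time.perf_counter()
        fn()
        total += time.perf_counter() - start
    return total / iters


if __name__ == "__main__":
    print(time_kernel(lambda: sum(range(1000)), iters=50))
```

The fast path trades per-iteration resolution for lower measurement overhead, which matters most when a single launch is short relative to the cost of recording an event.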

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| kernels/gemm_fp8fp4_gfx1250.py | Adds FP8 deep-pipeline schedule, segmented LDS layout, VGPR scale-load experiment, and related fencing/scheduling changes. |
| python/flydsl/expr/rocdl/tdm_ops.py | Enables the TDM descriptor early_timeout bit for load descriptors (kept off for stores). |
| python/flydsl/expr/rocdl/inline_asm.py | Modifies gfx1250 instruction prefetch wrapper implementation/defaults. |
| tests/kernels/test_gemm_fp8fp4_gfx1250.py | Adds coalesced scale preshuffle for VGPR scale path, updates tests/bench CLI, adds hipGraph verification and new fill modes. |
| tests/kernels/benchmark_common.py | Adds a fast timing path using a single event pair when no flush/prep is requested. |

Comment thread python/flydsl/expr/rocdl/inline_asm.py Outdated
Comment thread kernels/gemm_fp8fp4_gfx1250.py Outdated
Comment thread tests/kernels/benchmark_common.py Outdated
Comment thread tests/kernels/test_gemm_fp8fp4_gfx1250.py Outdated
@aoli26 aoli26 force-pushed the gfx1250/gemm_fp8_opt branch 2 times, most recently from bd2990c to a38b444 Compare June 2, 2026 08:14
Comment thread python/flydsl/expr/rocdl/inline_asm.py
Comment thread python/flydsl/expr/rocdl/tdm_ops.py
aoli26 added 19 commits June 3, 2026 00:23
Replace the buffer_lds_stage scale paths with vgpr/vgpr_ab_split, loading
scale global->VGPR via buffer_load (off the LDS/TDM/barrier path) with a
coalesced lane-major host layout. Add opt-in overlay-chunks and ab-half-fence
schedules, and trim verbose comments.

…ed_group_barrier)

Opt-in hot-loop scheduling for the FP8 deep-pipeline (default unchanged):
- hot_loop_sched_mode = default | iglp | manual
- iglp:   clean region (internal sched_barrier(0) suppressed) + iglp_opt(0)
          (LLVM MFMASmallGemmOpt) -> SP3-style DS/MFMA interleave.
- manual: same clean region + hand-emitted sched_group_barrier template,
          granularity hot_loop_manual_ds:hot_loop_manual_mfma (default 8:8).
default mode byte-identical to baseline. All modes cosine=1.0; fp8 wave_spec
tests pass. Net win over default unconfirmed (AM underestimates ds latency) -
to be ranked on silicon ATT.
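
To make the 8:8 granularity of the manual mode concrete, the sketch below builds an alternating DS/MFMA group template. It is an illustration under assumptions: `build_template` is a hypothetical helper, the operation counts are made up, and the real kernel emits sched_group_barrier calls rather than Python tuples.

```python
from typing import List, Tuple


def build_template(n_ds: int, n_mfma: int,
                   ds_per_group: int = 8,
                   mfma_per_group: int = 8) -> List[Tuple[str, int]]:
    """Alternate groups of DS and MFMA ops until both counts are exhausted.

    Each tuple is (kind, count); a real schedule would turn every tuple into
    one scheduling-group hint. Hypothetical helper, for illustration only.
    """
    if ds_per_group <= 0 or mfma_per_group <= 0:
        raise ValueError("group sizes must be positive")
    groups: List[Tuple[str, int]] = []
    while n_ds > 0 or n_mfma > 0:
        take_ds = min(ds_per_group, n_ds)
        if take_ds:
            groups.append(("ds", take_ds))
            n_ds -= take_ds
        take_mfma = min(mfma_per_group, n_mfma)
        if take_mfma:
            groups.append(("mfma", take_mfma))
            n_mfma -= take_mfma
    return groups


if __name__ == "__main__":
    for kind, count in build_template(16, 24):
        print(kind, count)
```

With 16 DS reads and 24 matrix ops, the template interleaves two DS/MFMA pairs and then drains the remaining 8 matrix ops; changing the ratio shifts how early the LDS reads are issued relative to the math.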
Skip only the scale TDM vmem->LDS load path: drop the scale loader waves
(2,3) from the active TDM the same way the vgpr path does, and pre-fill the
scale LDS stages once with a constant E8M0=1.0 (0x7F) byte. The downstream
scale LDS read path and the scaled-WMMA op are left byte-for-byte unchanged,
so the variant isolates the cost of scale TDM delivery alone (unlike
FLYDSL_SCALE_DISABLED, which also removes the LDS reads).
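
The constant 0x7F is the E8M0 encoding of a scale of 1.0. A minimal decoder, assuming the standard MX convention of an 8-bit exponent with bias 127 and 0xFF reserved for NaN, shows why; `e8m0_to_float` is an illustrative name, not a function from this repository.

```python
import math

E8M0_BIAS = 127  # exponent bias of the E8M0 scale format


def e8m0_to_float(byte: int) -> float:
    """Decode one E8M0 scale byte to a Python float.

    E8M0 has no sign or mantissa bits: the value is 2**(byte - 127),
    and 0xFF encodes NaN.
    """
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 scale must fit in one byte")
    if byte == 0xFF:
        return math.nan
    return math.ldexp(1.0, byte - E8M0_BIAS)


if __name__ == "__main__":
    print(hex(0x7F), e8m0_to_float(0x7F))
```

Pre-filling with this byte multiplies every block by 1.0, so the scaled matrix op still executes its full scale read path while the numerical result matches an unscaled product.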

Gated on scale_load_path='tdm' + wave_specialized_tdm; mutually exclusive
with FLYDSL_SCALE_DISABLED. Also adds test_mxfp8_hot_loop_sched_modes.

Split-K accumulates partial K-sums across workgroups via atomic add into C.
Two fixes:

- Correctness: route the atomic accumulation through llvm.atomicrmw fadd on a
  global (addrspace 1) pointer with syncscope("agent") instead of buffer atomics.

- Precision: tests run split-K with f32 accumulation and convert to the requested
  bf16/f16 on the host, avoiding compounded rounding from half-precision atomics.
  Adds test_mxfp8_gemm_splitk over split_k in {2,4,6,8}.
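
The precision point can be reproduced on the host with a small emulation. The sketch below compares accumulating split-K partial sums directly in half precision against accumulating in f32 and converting once at the end. It is an assumption-laden model: it uses IEEE f16 through the standard `struct` module because Python has no built-in bf16, the partial sums are random stand-ins rather than real GEMM outputs, and `compare` is a hypothetical helper.

```python
import math
import random
import struct


def to_f16(x: float) -> float:
    """Round a Python float to IEEE half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round a Python float to IEEE single precision and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def accumulate(partials, rnd) -> float:
    """Sequentially add partial sums, rounding operand and result with rnd."""
    acc = 0.0
    for p in partials:
        acc = rnd(acc + rnd(p))
    return acc


def compare(split_k: int, n_outputs: int = 2000, seed: int = 0):
    """Return (max error of f16 accumulation, max error of f32-then-convert)."""
    rng = random.Random(seed)
    err_half = 0.0
    err_f32 = 0.0
    for _ in range(n_outputs):
        # One output element: split_k partial sums from different workgroups.
        partials = [rng.uniform(-8.0, 8.0) for _ in range(split_k)]
        exact = math.fsum(partials)
        half = accumulate(partials, to_f16)
        f32_then_half = to_f16(accumulate(partials, to_f32))
        err_half = max(err_half, abs(half - exact))
        err_f32 = max(err_f32, abs(f32_then_half - exact))
    return err_half, err_f32


if __name__ == "__main__":
    for split_k in (2, 4, 6, 8):
        print(split_k, compare(split_k))
```

In this model the half-precision path rounds every partial sum and every intermediate total, while the f32 path rounds to half precision only once, which is the effect the test change is meant to avoid.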
@aoli26 aoli26 force-pushed the gfx1250/gemm_fp8_opt branch from 8b24891 to 1a90fd8 Compare June 2, 2026 16:24
Labels

enhancement New feature or request

3 participants