Skip to content

[gfx1250][gemm] Add PTPC FP8/A8W4 support#649

Draft
aoli26 wants to merge 4 commits into
gfx1250/gemm_fp8_optfrom
gfx1250/gemm_ptpc
Draft

[gfx1250][gemm] Add PTPC FP8/A8W4 support#649
aoli26 wants to merge 4 commits into
gfx1250/gemm_fp8_optfrom
gfx1250/gemm_ptpc

Conversation

@aoli26
Copy link
Copy Markdown
Contributor

@aoli26 aoli26 commented Jun 3, 2026

Motivation

Add per-token per-channel (PTPC) scaling to the gfx1250 GEMM kernel, where scales are per-token sa[M] and per-channel sb[N] (constant along K) fp32 data and thus applied once in the epilogue rather than per K-block.

Technical Details

PTPC FP8 runs the unscaled WMMA in the K-loop while A8W4 uses the scaled f8f6f4 op with an identity scale, and sa*sb is applied in fp32 in the epilogue (split-K supported via per-chunk scale + atomic add). All changes are compile-time gated to PTPC so the mxscale path is untouched; PTPC additionally skips scale TDM/LDS (only 2 loader waves needed) and prefetches the epilogue sa/sb loads behind the last WMMAs.

Test Plan

pytest tests/kernels/test_gemm_fp8fp4_gfx1250.py -k ptpc, plus ISA inspection of the PTPC kernels.

Test Result

All 14 PTPC tests pass (FP8 + A8W4 + split-K); ISA confirms scale TDM removal and epilogue prefetch with lower VGPR count and 0 spill.

Submission Checklist

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant