[gfx1250][gemm] Add PTPC FP8/A8W4 support#649
Draft
aoli26 wants to merge 4 commits into
Draft
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Add per-token per-channel (PTPC) scaling to the gfx1250 GEMM kernel, where scales are per-token
sa[M]and per-channelsb[N](constant along K) fp32 data and thus applied once in the epilogue rather than per K-block.Technical Details
PTPC FP8 runs the unscaled WMMA in the K-loop while A8W4 uses the scaled f8f6f4 op with an identity scale, and
sa*sbis applied in fp32 in the epilogue (split-K supported via per-chunk scale + atomic add). All changes are compile-time gated to PTPC so the mxscale path is untouched; PTPC additionally skips scale TDM/LDS (only 2 loader waves needed) and prefetches the epiloguesa/sbloads behind the last WMMAs.Test Plan
pytest tests/kernels/test_gemm_fp8fp4_gfx1250.py -k ptpc, plus ISA inspection of the PTPC kernels.Test Result
All 14 PTPC tests pass (FP8 + A8W4 + split-K); ISA confirms scale TDM removal and epilogue prefetch with lower VGPR count and 0 spill.
Submission Checklist