test(moe_blockscale): fix the broken e2e test (inflated activations + f16 overflow); kernel is correct#643
Merged
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Tightens test_moe_blockscale validation so the FlyDSL block-scale 2-stage pipeline can’t silently pass when it diverges from the torch reference.
Changes:
- Add assertions that FlyDSL output is finite.
- Add a hard failure if the computed
err_flyexceeds a threshold, with a pointer to #642.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
4bd3b24 to
ddf6e2c
Compare
… f16 overflow) test_moe_blockscale_e2e reported the FlyDSL 2-stage block-scale pipeline as 100% wrong (stage2 err_ratio = 1.0, all-NaN), but the kernel is correct. The 1.0 was a test-harness artifact from three compounding bugs in the test: 1. The block-scale activation was built from the raw fp8 codes (x_q.float(), ~448-scale) instead of the dequantized activation (x_q * x_scale, ~0.2), inflating activations ~2000x; the stage1 intermediate reached ~2.4e7. 2. Those magnitudes overflow f16 (max 65504): the f16 stage1/stage2 kernels produced NaN, and the reference's own out1_torch_ref.to(torch.float16) cast overflowed and poisoned the 2-stage torch reference (pure-torch check: err(2-stage ref vs fused ref) = 0.9999, cos = 0.04). aiter's own CK stage2 kernel failed the same comparison identically (err_vs_ref = 1.0). 3. The test return-ed timings instead of asserting, hiding all of it. Fix: dequantize x before re-block-quantizing (for both the torch reference and the FlyDSL kernel input), and assert finiteness + err_fly <= 0.05. With realistic magnitudes the whole pipeline is finite and correct on gfx950 / MI350X across small-E8, medium-E8, and DS-V3 (E=256/topk=8): stage1/stage2 err = 0.0000, pipeline err <= 0.0016, all passed. The FlyDSL 2-stage block-scale stage2 kernel and torch_stage2_blockscale_ref are both correct; no kernel change is needed. Closes ROCm#642 Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ddf6e2c to
1079351
Compare
coderfeli
approved these changes
Jun 4, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes the broken
test_moe_blockscale_e2efrom #642.What #642 actually was
The test reported the FlyDSL 2-stage block-scale pipeline as 100% wrong
(
stage2 err_ratio = 1.0, all-NaN output), but the kernel is correct. The1.0was a test-harness artifact from three compounding bugs in the test:built from the raw fp8 codes (
x_q.float(), ~448-scale) instead of thedequantized activation (
x_q * x_scale, ~0.2). The stage1 intermediate thenreached
absmax ≈ 2.4e7.stage1/stage2 kernels overflowed to NaN, and the torch reference's own
out1_torch_ref.to(torch.float16)cast overflowed too, poisoning the 2-stagereference. Pure-torch check, no kernel involved:
err(2-stage ref vs fused ref) = 0.9999, cosine = 0.04, ‖2stage‖/‖fused‖ = 0.0025.aiter's own CK stage2 kernel fails the same comparison identically
(
err_vs_ref = 1.0) — the tell that the harness, not the kernel, was wrong.returninstead ofasserthid all of it.This PR
xbefore re-block-quantizing the activation (for both the torchreference and the FlyDSL kernel input) so the f16 pipeline stays in range.
err_fly <= 0.05asserts so a real regression can't passsilently.
With the fix, on gfx950 / MI350X (small-E8, medium-E8, and DS-V3 E=256/topk=8):
Notes
needed. In isolation, fed clean in-range inputs and the same
a2_bq, itmatches
torch_stage2_blockscale_refaterr = 0.0000, cos = 1.0000.torch_stage2_blockscale_refis also correct; it reproduces the fused referenceexactly when fed an in-range activation. It was only poisoned by the overflowed
input.
bf16output is not the fix — stage1's cshuffle epilog is f16-only; thedata-prep fix keeps the existing f16 path valid. (A separate hardening option is
to give stage1/stage2 a bf16 output path for genuinely large-magnitude
workloads.)
return (us_fly_total, us_aiter_fused)is kept for the__main__benchmark path and still triggers pytest's return-not-assert warning; splitting
the benchmark out would remove it.
Closes #642.