Fix both MoE kernels' correctness: aiter shuffle, f16-overflow data-prep, fp8 MoE tolerance (refs #4 #6, ROCm/FlyDSL#642) #5
aiter.fused_moe inferred the output dtype from the fp8 inputs and asserted
("unsupported out dtype torch.float8_e4m3fn"), so the optimized baseline was
`failed` on every fp8 row, leaving only slow PyTorch eager and a meaningless
~157x vs-best. Pass dtype=torch.bfloat16 so it runs.
This is step 1 of #4: aiter now runs but is still `incorrect` vs the 2-stage
reference (end-to-end routing/normalization differs) -- the baseline alignment
is tracked there.
Refs #4
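The failure mode described in this commit message can be sketched in plain Python. This is a hypothetical stand-in, not aiter's implementation: `resolve_out_dtype`, `try_resolve`, and the set of supported dtypes are assumptions made for illustration, and dtypes are represented as strings so the sketch runs without PyTorch.

```python
# Hypothetical stand-in for a fused op's dtype handling; NOT aiter's real code.
SUPPORTED_OUT_DTYPES = {"bfloat16", "float16"}  # assumed for illustration


def resolve_out_dtype(input_dtype, dtype=None):
    """Use the explicit dtype if given, otherwise fall back to the input dtype."""
    out = dtype if dtype is not None else input_dtype
    assert out in SUPPORTED_OUT_DTYPES, f"unsupported out dtype {out}"
    return out


def try_resolve(input_dtype, dtype=None):
    """Report 'failed' instead of raising, the way a benchmark row might."""
    try:
        return resolve_out_dtype(input_dtype, dtype)
    except AssertionError:
        return "failed"


# Inferring from fp8 inputs trips the assertion; an explicit dtype avoids it.
implicit = try_resolve("float8_e4m3fn")
explicit = try_resolve("float8_e4m3fn", dtype="bfloat16")
```

Under these assumptions, the implicit call is the one that ends up `failed`, which mirrors why passing `dtype=torch.bfloat16` lets the baseline run.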
Reviewer's Guide
This PR ensures the aiter fp8 fused MoE baseline runs by explicitly requesting a bfloat16 output dtype from fused_moe and updating the provider detail string and in-code commentary accordingly.
Code Review
This pull request updates the fused_moe benchmark provider to explicitly request torch.bfloat16 output when using FP8 inputs, preventing execution failures caused by unsupported FP8 output types. The reviewer suggested dynamically determining the output data type (either bfloat16 or float16) based on the benchmark configuration instead of hardcoding torch.bfloat16.
  self.provider_detail = "aiter.fused_moe (end-to-end fused; QuantType.per_Token fp8; bf16 out; routing+sorting timed)"
  # Without an explicit `dtype`, fused_moe infers the output dtype from the
  # fp8 inputs and asserts ("unsupported out dtype torch.float8_e4m3fn"),
  # which is why aiter was `failed` on every fp8 row. Request bf16 out so
  # the baseline at least RUNS. (It is still recorded `incorrect` vs our
  # 2-stage reference -- aiter's end-to-end routing/normalization differs;
  # see flydsl-kernel-profiling baseline-alignment issue.)
  return fused_moe(
      x_q, w1_q, w2_q, topk_weights, topk_ids,
      activation=ActivationType.Silu, quant_type=QuantType.per_Token,
-     doweight_stage1=False,
+     doweight_stage1=False, dtype=torch.bfloat16,
      w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale)
Instead of hardcoding torch.bfloat16 and "bf16 out", we should dynamically determine the output dtype based on the dtype variable (similar to how it is done on line 343). This ensures that if the benchmark is run with float16, the correct output dtype is requested and reported.
out_dtype = torch.bfloat16 if dtype in ("bf16", "bfloat16") else torch.float16
out_dtype_str = "bf16" if out_dtype == torch.bfloat16 else "fp16"
self.provider_detail = f"aiter.fused_moe (end-to-end fused; QuantType.per_Token fp8; {out_dtype_str} out; routing+sorting timed)"
# Without an explicit `dtype`, fused_moe infers the output dtype from the
# fp8 inputs and asserts ("unsupported out dtype torch.float8_e4m3fn"),
# which is why aiter was `failed` on every fp8 row. Request the appropriate out dtype so
# the baseline at least RUNS. (It is still recorded `incorrect` vs our
# 2-stage reference -- aiter's end-to-end routing/normalization differs;
# see flydsl-kernel-profiling baseline-alignment issue.)
return fused_moe(
    x_q, w1_q, w2_q, topk_weights, topk_ids,
    activation=ActivationType.Silu, quant_type=QuantType.per_Token,
    doweight_stage1=False, dtype=out_dtype,
    w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale)
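One design note on the suggestion above: the `else torch.float16` branch maps every non-bf16 string to fp16, including unexpected ones. A stricter variant is sketched below with an explicit lookup table. The helper names (`select_out_dtype`, `provider_detail`) and the accepted spellings `"fp16"`/`"float16"` are assumptions, and string stand-ins replace torch dtypes so the sketch runs without PyTorch.

```python
# String stand-ins for torch dtypes; in the benchmark these would be
# torch.bfloat16 / torch.float16. The accepted spellings are assumed.
OUT_DTYPES = {
    "bf16": ("bfloat16", "bf16"),
    "bfloat16": ("bfloat16", "bf16"),
    "fp16": ("float16", "fp16"),
    "float16": ("float16", "fp16"),
}


def select_out_dtype(bench_dtype):
    """Return (out_dtype, label) or raise instead of silently defaulting."""
    try:
        return OUT_DTYPES[bench_dtype]
    except KeyError:
        raise ValueError(
            f"no output dtype mapped for benchmark dtype {bench_dtype!r}"
        ) from None


def provider_detail(bench_dtype):
    """Build the provider detail string with the matching output label."""
    _, label = select_out_dtype(bench_dtype)
    return (
        "aiter.fused_moe (end-to-end fused; QuantType.per_Token fp8; "
        f"{label} out; routing+sorting timed)"
    )
```

Raising on an unmapped dtype keeps the reported label and the requested output dtype from drifting apart if a new benchmark dtype is added later.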
…baseline

The previous commit made aiter run; this makes it CORRECT. aiter.fused_moe (like aiter's own op_tests) expects the expert weights in the shuffled layout. We were passing the raw per-token-quant layout, so the kernel read it as orthogonal garbage (cos ~0.003 vs the reference, max_abs err ~4.4) -- which is why aiter was recorded `incorrect` even where FlyDSL is correct.

Verified: aiter's own fp8-modeled torch_moe matches our reference (err 0.19), so the reference was right; shuffle_weight(w1_q)/shuffle_weight(w2_q) flips aiter to cos 0.998 / err 0.196 (fp8 noise).

The shuffle is cached per shape (one-time setup, out of the timed region) to match the FlyDSL path's preshuffle.

Result: aiter 0 -> 3 correct (same shapes FlyDSL passes); moe_gemm now has a real verdict (geomean 2.93x vs aiter) instead of baseline_blocked. The remaining 6 fp8 rows are double-quant noise > tol(0.15) for BOTH flydsl and aiter (separate tolerance question), and 20 rows are other-dtype (out of scope).

Closes #4
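The two metrics quoted in this commit message (cosine similarity and max absolute error) separate "small quantization noise" from "data read in the wrong layout" quite sharply. The sketch below illustrates that on synthetic data; the vector length, noise level, and the use of a random permutation as a stand-in for a misread layout are all assumptions, not the harness's actual check.

```python
import math
import random


def cosine(a, b):
    """Cosine similarity between two equal-length sequences."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def max_abs_err(a, b):
    """Largest element-wise absolute difference."""
    return max(abs(x - y) for x, y in zip(a, b))


rng = random.Random(0)
reference = [rng.gauss(0.0, 1.0) for _ in range(4096)]

# Small additive noise: a stand-in for quantization error.
noisy = [x + rng.gauss(0.0, 0.05) for x in reference]

# Same values in the wrong order: a stand-in for a misread weight layout.
misread = reference[:]
rng.shuffle(misread)

cos_noisy = cosine(reference, noisy)
cos_misread = cosine(reference, misread)
err_noisy = max_abs_err(reference, noisy)
err_misread = max_abs_err(reference, misread)
```

In this toy setup the noisy copy stays close to cosine 1 while the permuted copy lands near 0, which is the same qualitative signature the commit reports before and after `shuffle_weight`.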
…locked -> promote 2.93x)

After the aiter shuffle_weight fix, moe_gemm has a real fast baseline: it moves from baseline_blocked (meaningless 157x vs eager) to promote (2.93x vs aiter on the 3 fp8-correct shapes). Re-exported docs/data + rebuilt the single-file index.html.

Refs #4.
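A geomean speedup restricted to correct rows, as this verdict describes, could be computed along these lines. The row schema (`baseline_ms`, `candidate_ms`, `correct`) and the timings are made up for illustration; they are not the dashboard's data or field names.

```python
from statistics import geometric_mean


def geomean_speedup(rows):
    """Geometric mean of baseline/candidate time over rows marked correct."""
    ratios = [r["baseline_ms"] / r["candidate_ms"] for r in rows if r["correct"]]
    if not ratios:
        raise ValueError("no correct rows to compare")
    return geometric_mean(ratios)


# Made-up timings; the incorrect row is excluded from the verdict.
rows = [
    {"baseline_ms": 4.0, "candidate_ms": 2.0, "correct": True},
    {"baseline_ms": 9.0, "candidate_ms": 1.0, "correct": True},
    {"baseline_ms": 1.0, "candidate_ms": 2.0, "correct": True},
    {"baseline_ms": 100.0, "candidate_ms": 1.0, "correct": False},
]
speedup = geomean_speedup(rows)
```

Using the geometric rather than arithmetic mean keeps one very large ratio from dominating the summary, and filtering on correctness is what keeps an eager-only comparison from producing an inflated figure.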
Update: the alignment is done; this now fully un-blocks
Root cause of the divergence:
Verified on a FlyDSL-correct shape (E256/M32/Dim1=7168):
The shuffle is cached per shape (one-time setup, out of the timed region), matching the FlyDSL path's preshuffle.
Result: aiter 0 → 3 correct (the same shapes FlyDSL passes), so
Closes #4.
…ernels un-✕

Two MoE corrections from the ROCm/FlyDSL#642 investigation (the moe_blockscale "all-NaN" was a HARNESS f16-overflow, not a kernel bug):

- fp8 MoE tolerance: MoeGemmOp + MoeBlockscaleOp now use (0.4,0.4) for fp8 (bf16 unchanged). fp8 MoE double/triple-quant noise is ~0.2-0.4 at cos ~0.998; a real bug (NaN/orthogonal) is cos~0, far outside. This lands moe_gemm 3->9 correct (all fp8 shapes) -> promote 5.19x vs aiter.
- moe_blockscale data-prep: the FlyDSL AND aiter providers block-quantized the raw fp8 codes (x_q, ~448) instead of the dequantized activation (x_q * scale, ~0.2), inflating ~2000x -> f16 overflow -> all-NaN (ROCm/FlyDSL#642). Use x_q*scale. moe_blockscale 0 -> 4 correct, with a real aiter baseline -> promote 7.09x.

Remaining (tracked in #6): the 6 inter_dim=2048 moe_blockscale shapes are still flagged incorrect, but FlyDSL and aiter's CK kernel AGREE (cos 0.997) and only the fp32 reference diverges (cos 0.56) -> a reference bug at inter_dim=2048, not a kernel bug. The 2 large-E shapes OOM in make_inputs.

Dashboard regenerated: moe_gemm and moe_blockscale both move from ✕ to ▲ promote.
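The arithmetic behind the data-prep overflow can be checked with the standard library alone. The sketch below uses the two magnitudes from the commit message (~448 for raw fp8 codes, ~0.2 for dequantized activations); the intermediate value of 40.0 is made up to show one value that fits in f16 before the inflation and not after. `struct`'s `"e"` format is used only as a convenient range check for IEEE half precision.

```python
import struct

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude
F16_MAX = 65504.0     # largest finite IEEE half-precision value


def fits_in_f16(value):
    """True if the value can be packed as a finite half-precision float."""
    try:
        struct.pack("<e", value)
        return True
    except OverflowError:
        return False


# Magnitudes taken from the commit message.
typical_activation = 0.2
inflation = FP8_E4M3_MAX / typical_activation  # roughly the "~2000x" quoted

# Made-up intermediate that is comfortably inside f16 at the correct scale.
intermediate = 40.0
ok_before = fits_in_f16(intermediate)
ok_after = fits_in_f16(intermediate * inflation)
```

On real hardware an overflowing f16 value typically becomes `inf` rather than raising, and later arithmetic on `inf` (for example `inf - inf`) yields NaN, which is consistent with the all-NaN output described above.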
Re-scoped: this now fixes both MoE ✕ on the dashboard. Three corrections from the ROCm/FlyDSL#642 investigation (the moe_blockscale "all-NaN" was a harness f16-overflow, not a kernel bug):
Result:
Verified throughout with kernel-vs-kernel agreement (the #642 method). The remaining 6
Closes #4. Refs #6, ROCm/FlyDSL#642 (the matching upstream test fix is ROCm/FlyDSL#643).
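The kernel-vs-kernel reasoning used in this thread (two independent kernels agree with each other while both disagree with the reference) can be written down as a small decision rule. The function name, the labels it returns, and the 0.99 agreement threshold are assumptions for illustration; the thread does not state the exact rule used.

```python
def suspect(cos_a_b, cos_a_ref, cos_b_ref, agree=0.99):
    """Classify which party looks wrong from three pairwise cosine similarities.

    cos_a_b   : kernel A vs kernel B
    cos_a_ref : kernel A vs reference
    cos_b_ref : kernel B vs reference
    agree     : assumed threshold above which two outputs count as agreeing
    """
    kernels_agree = cos_a_b >= agree
    a_ok = cos_a_ref >= agree
    b_ok = cos_b_ref >= agree
    if a_ok and b_ok:
        return "all agree"
    if kernels_agree and not a_ok and not b_ok:
        return "reference suspect"
    if a_ok != b_ok:
        return "kernel b suspect" if a_ok else "kernel a suspect"
    return "inconclusive"


# Figures quoted earlier in the thread for the inter_dim=2048 shapes:
# the two kernels agree at cos 0.997, the reference sits at cos 0.56.
verdict = suspect(cos_a_b=0.997, cos_a_ref=0.56, cos_b_ref=0.56)
```

The value of the third comparison is that a two-way check against the reference alone cannot distinguish a kernel bug from a reference bug.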
First step toward un-blocking `moe_gemm` (#4). `aiter.fused_moe` was called without an explicit output `dtype`, so it inferred the dtype from the fp8 inputs and raised `AssertionError: Fused_moe unsupported out dtype: torch.float8_e4m3fn`; the aiter baseline was `failed` on every fp8 row, leaving only slow PyTorch eager (the meaningless ~157× vs-best). Passing `dtype=torch.bfloat16` makes it run.

Scope/honesty: this does not by itself remove the dashboard ✕. aiter now runs but is still recorded `incorrect` vs our 2-stage reference (max_abs ≈ 1–7, even where FlyDSL is correct) because aiter's end-to-end `fused_moe` routing/normalization differs from the composed 2-stage reference. That alignment, the actual fix to give `moe_gemm` a real verdict, is tracked as the open checkbox in #4.

Closes nothing on its own; refs #4.
Summary by Sourcery
Ensure the aiter fp8 fused_moe baseline runs by explicitly requesting a bfloat16 output dtype.