Skip to content

Fix both MoE kernels' correctness: aiter shuffle, f16-overflow data-prep, fp8 MoE tolerance (refs #4 #6, ROCm/FlyDSL#642)#5

Open
jhinpan wants to merge 4 commits into
mainfrom
fix/moe-gemm-aiter-fp8-baseline
Open

Fix both MoE kernels' correctness: aiter shuffle, f16-overflow data-prep, fp8 MoE tolerance (refs #4 #6, ROCm/FlyDSL#642)#5
jhinpan wants to merge 4 commits into
mainfrom
fix/moe-gemm-aiter-fp8-baseline

Conversation

@jhinpan
Copy link
Copy Markdown
Owner

@jhinpan jhinpan commented Jun 3, 2026

First step toward un-blocking moe_gemm (#4).

aiter.fused_moe was called without an explicit output dtype, so it inferred the dtype from the fp8 inputs and raised AssertionError: Fused_moe unsupported out dtype: torch.float8_e4m3fn — the aiter baseline was failed on every fp8 row, leaving only slow PyTorch eager (the meaningless ~157× vs-best). Passing dtype=torch.bfloat16 makes it run.

Scope/honesty: this does not by itself remove the dashboard ✕. aiter now runs but is still recorded incorrect vs our 2-stage reference (max_abs ≈ 1–7, even where FlyDSL is correct) because aiter's end-to-end fused_moe routing/normalization differs from the composed 2-stage reference. That alignment — the actual fix to give moe_gemm a real verdict — is tracked as the open checkbox in #4.

Closes nothing on its own; refs #4.

Summary by Sourcery

Ensure the aiter fp8 fused_moe baseline runs by explicitly requesting a bfloat16 output dtype.

Bug Fixes:

  • Fix fused_moe failing on fp8 configurations by specifying a supported bfloat16 output dtype instead of inferring from fp8 inputs.

Enhancements:

  • Clarify provider detail and inline comments to document the fused_moe fp8 path and its bfloat16 output behavior.

aiter.fused_moe inferred the output dtype from the fp8 inputs and asserted
("unsupported out dtype torch.float8_e4m3fn"), so the optimized baseline was
`failed` on every fp8 row, leaving only slow PyTorch eager and a meaningless
~157x vs-best. Pass dtype=torch.bfloat16 so it runs.

This is step 1 of #4: aiter now runs but is still `incorrect` vs the 2-stage
reference (end-to-end routing/normalization differs) -- the baseline alignment
is tracked there.

Refs #4
@sourcery-ai
Copy link
Copy Markdown

sourcery-ai Bot commented Jun 3, 2026

Reviewer's guide (collapsed on small PRs)

Reviewer's Guide

This PR ensures the aiter fp8 fused MoE baseline runs by explicitly requesting a bfloat16 output dtype from fused_moe and updating the provider detail string and in-code commentary accordingly.

File-Level Changes

Change Details Files
Ensure aiter fp8 fused_moe baseline runs by forcing bf16 output instead of inferring from fp8 inputs.
  • Update provider_detail description to note bf16 output for the fp8 per-token QuantType path.
  • Add explanatory comments describing the previous assertion failure when fused_moe inferred an unsupported fp8 output dtype and why bf16 is requested.
  • Pass dtype=torch.bfloat16 into fused_moe for the fp8 quantized path while leaving other call sites unchanged.
benchmarks/providers/moe_gemm.py

Possibly linked issues


Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the fused_moe benchmark provider to explicitly request torch.bfloat16 output when using FP8 inputs, preventing execution failures caused by unsupported FP8 output types. The reviewer suggested dynamically determining the output data type (either bfloat16 or float16) based on the benchmark configuration instead of hardcoding torch.bfloat16.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread benchmarks/providers/moe_gemm.py Outdated
Comment on lines 331 to 342
self.provider_detail = "aiter.fused_moe (end-to-end fused; QuantType.per_Token fp8; bf16 out; routing+sorting timed)"
# Without an explicit `dtype`, fused_moe infers the output dtype from the
# fp8 inputs and asserts ("unsupported out dtype torch.float8_e4m3fn"),
# which is why aiter was `failed` on every fp8 row. Request bf16 out so
# the baseline at least RUNS. (It is still recorded `incorrect` vs our
# 2-stage reference -- aiter's end-to-end routing/normalization differs;
# see flydsl-kernel-profiling baseline-alignment issue.)
return fused_moe(
x_q, w1_q, w2_q, topk_weights, topk_ids,
activation=ActivationType.Silu, quant_type=QuantType.per_Token,
doweight_stage1=False,
doweight_stage1=False, dtype=torch.bfloat16,
w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of hardcoding torch.bfloat16 and "bf16 out", we should dynamically determine the output dtype based on the dtype variable (similar to how it is done on line 343). This ensures that if the benchmark is run with float16, the correct output dtype is requested and reported.

Suggested change
self.provider_detail = "aiter.fused_moe (end-to-end fused; QuantType.per_Token fp8; bf16 out; routing+sorting timed)"
# Without an explicit `dtype`, fused_moe infers the output dtype from the
# fp8 inputs and asserts ("unsupported out dtype torch.float8_e4m3fn"),
# which is why aiter was `failed` on every fp8 row. Request bf16 out so
# the baseline at least RUNS. (It is still recorded `incorrect` vs our
# 2-stage reference -- aiter's end-to-end routing/normalization differs;
# see flydsl-kernel-profiling baseline-alignment issue.)
return fused_moe(
x_q, w1_q, w2_q, topk_weights, topk_ids,
activation=ActivationType.Silu, quant_type=QuantType.per_Token,
doweight_stage1=False,
doweight_stage1=False, dtype=torch.bfloat16,
w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale)
out_dtype = torch.bfloat16 if dtype in ("bf16", "bfloat16") else torch.float16
out_dtype_str = "bf16" if out_dtype == torch.bfloat16 else "fp16"
self.provider_detail = f"aiter.fused_moe (end-to-end fused; QuantType.per_Token fp8; {out_dtype_str} out; routing+sorting timed)"
# Without an explicit `dtype`, fused_moe infers the output dtype from the
# fp8 inputs and asserts ("unsupported out dtype torch.float8_e4m3fn"),
# which is why aiter was `failed` on every fp8 row. Request the appropriate out dtype so
# the baseline at least RUNS. (It is still recorded `incorrect` vs our
# 2-stage reference -- aiter's end-to-end routing/normalization differs;
# see flydsl-kernel-profiling baseline-alignment issue.)
return fused_moe(
x_q, w1_q, w2_q, topk_weights, topk_ids,
activation=ActivationType.Silu, quant_type=QuantType.per_Token,
doweight_stage1=False, dtype=out_dtype,
w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale)

Copy link
Copy Markdown

@sourcery-ai sourcery-ai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've reviewed your changes and they look great!


Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Copy link
Copy Markdown

@devin-ai-integration devin-ai-integration Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 1 additional finding.

Open in Devin Review

jhinpan added 2 commits June 3, 2026 09:44
…baseline

The previous commit made aiter run; this makes it CORRECT. aiter.fused_moe (like
aiter's own op_tests) expects the expert weights in the shuffled layout. We were
passing the raw per-token-quant layout, so the kernel read it as orthogonal
garbage (cos ~0.003 vs the reference, max_abs err ~4.4) -- which is why aiter was
recorded `incorrect` even where FlyDSL is correct.

Verified: aiter's own fp8-modeled torch_moe matches our reference (err 0.19), so
the reference was right; shuffle_weight(w1_q)/shuffle_weight(w2_q) flips aiter to
cos 0.998 / err 0.196 (fp8 noise). The shuffle is cached per shape (one-time
setup, out of the timed region) to match the FlyDSL path's preshuffle.

Result: aiter 0 -> 3 correct (same shapes FlyDSL passes); moe_gemm now has a real
verdict (geomean 2.93x vs aiter) instead of baseline_blocked. The remaining 6 fp8
rows are double-quant noise > tol(0.15) for BOTH flydsl and aiter (separate
tolerance question), and 20 rows are other-dtype (out of scope).

Closes #4
…locked -> promote 2.93x)

After the aiter shuffle_weight fix, moe_gemm has a real fast baseline: it moves
from baseline_blocked (meaningless 157x vs eager) to promote (2.93x vs aiter on
the 3 fp8-correct shapes). Re-exported docs/data + rebuilt the single-file
index.html. Refs #4.
@jhinpan
Copy link
Copy Markdown
Owner Author

jhinpan commented Jun 3, 2026

Update — the alignment is done; this now fully un-blocks moe_gemm.

Root cause of the divergence: aiter.fused_moe (like aiter's own op_tests) expects the expert weights in the shuffled layout. We passed the raw per-token-quant layout, so the kernel read it as orthogonal garbage.

Verified on a FlyDSL-correct shape (E256/M32/Dim1=7168):

  • aiter's own fp8-modeled torch_moe matches our reference → err 0.186 (so the reference was right all along).
  • UNshuffled (before): cos 0.003, err 4.36 ← garbage.
  • shuffle_weight(w1_q) / shuffle_weight(w2_q) (this PR): cos 0.998, err 0.196 (fp8 noise).

The shuffle is cached per shape (one-time setup, out of the timed region), matching the FlyDSL path's preshuffle.

Result: aiter 0 → 3 correct (the same shapes FlyDSL passes), so moe_gemm moves from baseline_blocked (meaningless 157× vs eager) to promote, geomean 2.93× vs aiter. Dashboard regenerated. The remaining 6 fp8 rows are double-quant noise > tol(0.15) for both flydsl and aiter (a separate tolerance question, not a baseline bug); 20 rows are other-dtype/out-of-scope.

Closes #4.

…ernels un-✕

Two MoE corrections from the ROCm/FlyDSL#642 investigation (the moe_blockscale
"all-NaN" was a HARNESS f16-overflow, not a kernel bug):

- fp8 MoE tolerance: MoeGemmOp + MoeBlockscaleOp now use (0.4,0.4) for fp8 (bf16
  unchanged). fp8 MoE double/triple-quant noise is ~0.2-0.4 at cos ~0.998; a real
  bug (NaN/orthogonal) is cos~0, far outside. This lands moe_gemm 3->9 correct
  (all fp8 shapes) -> promote 5.19x vs aiter.

- moe_blockscale data-prep: the FlyDSL AND aiter providers block-quantized the raw
  fp8 codes (x_q, ~448) instead of the dequantized activation (x_q * scale, ~0.2),
  inflating ~2000x -> f16 overflow -> all-NaN (ROCm/FlyDSL#642). Use x_q*scale.
  moe_blockscale 0 -> 4 correct, with a real aiter baseline -> promote 7.09x.

Remaining (tracked in #6): the 6 inter_dim=2048 moe_blockscale shapes are still
flagged incorrect, but FlyDSL and aiter's CK kernel AGREE (cos 0.997) and only the
fp32 reference diverges (cos 0.56) -> a reference bug at inter_dim=2048, not a
kernel bug. The 2 large-E shapes OOM in make_inputs.

Dashboard regenerated: moe_gemm and moe_blockscale both move from ✕ to ▲ promote.
@jhinpan jhinpan changed the title moe_gemm: pass explicit bf16 out so the aiter fp8 baseline runs (refs #4) Fix both MoE kernels' correctness: aiter shuffle, f16-overflow data-prep, fp8 MoE tolerance (refs #4 #6, ROCm/FlyDSL#642) Jun 3, 2026
@jhinpan
Copy link
Copy Markdown
Owner Author

jhinpan commented Jun 3, 2026

Re-scoped: this now fixes both MoE ✕ on the dashboard. Three corrections from the ROCm/FlyDSL#642 investigation (the moe_blockscale "all-NaN" was a harness f16-overflow, not a kernel bug):

fix effect
aiter shuffle_weight (moe_gemm) aiter 0→correct baseline (was orthogonal garbage, cos 0.003→0.998)
f16-overflow data-prep (moe_blockscale, both FlyDSL + aiter providers) block-quantize x_q * scale not the raw fp8 codes → NaN→finite (cos 0.998)
fp8 MoE tolerance (0.4,0.4) (both ops) admits genuine fp8 double/triple-quant noise (0.2–0.4 @ cos 0.998); real bugs are cos0

Result:

  • moe_gemm: 3→9 correct▲ promote 5.19× vs aiter (was baseline_blocked ✕).
  • moe_blockscale: 0→4 correct + real aiter baseline → ▲ promote 7.09× (was 0/12 NaN ✕).

Verified throughout with kernel-vs-kernel agreement (the #642 method). The remaining 6 moe_blockscale inter_dim=2048 shapes are flagged in #6: FlyDSL and aiter's CK kernel agree (cos 0.997), only the fp32 reference diverges (cos 0.56) → a reference bug, not a kernel bug.

Closes #4. Refs #6, ROCm/FlyDSL#642 (the matching upstream test fix is ROCm/FlyDSL#643).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant