From 107935155713d4bb9f85a0df630e76d8f7cc8042 Mon Sep 17 00:00:00 2001 From: Jin Pan Date: Wed, 3 Jun 2026 09:50:59 +0000 Subject: [PATCH] test(moe_blockscale): fix the broken e2e test (inflated activations + f16 overflow) test_moe_blockscale_e2e reported the FlyDSL 2-stage block-scale pipeline as 100% wrong (stage2 err_ratio = 1.0, all-NaN), but the kernel is correct. The 1.0 was a test-harness artifact from three compounding bugs in the test: 1. The block-scale activation was built from the raw fp8 codes (x_q.float(), ~448-scale) instead of the dequantized activation (x_q * x_scale, ~0.2), inflating activations ~2000x; the stage1 intermediate reached ~2.4e7. 2. Those magnitudes overflow f16 (max 65504): the f16 stage1/stage2 kernels produced NaN, and the reference's own out1_torch_ref.to(torch.float16) cast overflowed and poisoned the 2-stage torch reference (pure-torch check: err(2-stage ref vs fused ref) = 0.9999, cos = 0.04). aiter's own CK stage2 kernel failed the same comparison identically (err_vs_ref = 1.0). 3. The test returned its timings instead of asserting on the results, which hid all of the above. Fix: dequantize x before re-block-quantizing (for both the torch reference and the FlyDSL kernel input), and assert finiteness + err_fly <= 0.05. With realistic magnitudes the whole pipeline is finite and correct on gfx950 / MI350X across small-E8, medium-E8, and DS-V3 (E=256/topk=8): stage1/stage2 err = 0.0000, pipeline err <= 0.0016, all passed. The FlyDSL 2-stage block-scale stage2 kernel and torch_stage2_blockscale_ref are both correct; no kernel change is needed. 
Closes #642 Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/kernels/test_moe_blockscale.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_moe_blockscale.py b/tests/kernels/test_moe_blockscale.py index 0d7f2b577..219e98968 100644 --- a/tests/kernels/test_moe_blockscale.py +++ b/tests/kernels/test_moe_blockscale.py @@ -303,8 +303,8 @@ def block_quant_expert(w_fp32, blk_n, blk_k): w1_shuf = w1_bq_shuf.view(-1) w2_shuf = w2_bq_shuf.view(-1) - # Input block quantize (from x_q fp8 -> re-block) - x_f32_for_blk = x_q.float() + # Input block quantize: dequantize x_q first (raw fp8 codes overflow the f16 pipeline; #642). + x_f32_for_blk = x_q.float() * x_scale a1_bq, a1_bscale = pertoken_quant( x_f32_for_blk.view(-1, model_dim // scale_blk_k, scale_blk_k), quant_dtype=DTYPE_FP8 ) @@ -357,7 +357,10 @@ def block_quant_expert(w_fp32, blk_n, blk_k): if HAS_AITER: # quant kernel writes transposed scale layout directly a1_bq, a1_scale_fly = per_group_quant_hip( - x_q.to(torch.bfloat16), quant_dtype=DTYPE_FP8, group_size=scale_blk_k, transpose_scale=True + (x_q.float() * x_scale).to(torch.bfloat16), + quant_dtype=DTYPE_FP8, + group_size=scale_blk_k, + transpose_scale=True, ) a1_scale_fly = a1_scale_fly.view(-1) # [nblk_k_w1 * token] flat else: @@ -525,6 +528,9 @@ def launch_stage2(): # Verify flydsl full pipeline vs blockscale torch reference err_fly = checkAllclose(out_ref.to(out2.dtype), out2, rtol=0.1, atol=0.1, msg="flydsl vs ref", printLog=False) print(f" flydsl: pipeline err_ratio vs ref = {err_fly:.4f}") + # Assert instead of returning timings, so a wrong pipeline can't pass silently (#642). + assert torch.isfinite(out2).all(), "FlyDSL block-scale pipeline produced non-finite output" + assert err_fly <= 0.05, f"FlyDSL block-scale pipeline diverges from reference: err_ratio={err_fly:.4f}" # ---- aiter comparisons ---- us_aiter_fused = 0