diff --git a/tests/kernels/test_moe_blockscale.py b/tests/kernels/test_moe_blockscale.py index 0d7f2b577..219e98968 100644 --- a/tests/kernels/test_moe_blockscale.py +++ b/tests/kernels/test_moe_blockscale.py @@ -303,8 +303,8 @@ def block_quant_expert(w_fp32, blk_n, blk_k): w1_shuf = w1_bq_shuf.view(-1) w2_shuf = w2_bq_shuf.view(-1) - # Input block quantize (from x_q fp8 -> re-block) - x_f32_for_blk = x_q.float() + # Input block quantize: dequantize x_q first (raw fp8 codes overflow the f16 pipeline; #642). + x_f32_for_blk = x_q.float() * x_scale a1_bq, a1_bscale = pertoken_quant( x_f32_for_blk.view(-1, model_dim // scale_blk_k, scale_blk_k), quant_dtype=DTYPE_FP8 ) @@ -357,7 +357,10 @@ def block_quant_expert(w_fp32, blk_n, blk_k): if HAS_AITER: # quant kernel writes transposed scale layout directly a1_bq, a1_scale_fly = per_group_quant_hip( - x_q.to(torch.bfloat16), quant_dtype=DTYPE_FP8, group_size=scale_blk_k, transpose_scale=True + (x_q.float() * x_scale).to(torch.bfloat16), + quant_dtype=DTYPE_FP8, + group_size=scale_blk_k, + transpose_scale=True, ) a1_scale_fly = a1_scale_fly.view(-1) # [nblk_k_w1 * token] flat else: @@ -525,6 +528,9 @@ def launch_stage2(): # Verify flydsl full pipeline vs blockscale torch reference err_fly = checkAllclose(out_ref.to(out2.dtype), out2, rtol=0.1, atol=0.1, msg="flydsl vs ref", printLog=False) print(f" flydsl: pipeline err_ratio vs ref = {err_fly:.4f}") + # Assert instead of returning timings, so a wrong pipeline can't pass silently (#642). + assert torch.isfinite(out2).all(), "FlyDSL block-scale pipeline produced non-finite output" + assert err_fly <= 0.05, f"FlyDSL block-scale pipeline diverges from reference: err_ratio={err_fly:.4f}" # ---- aiter comparisons ---- us_aiter_fused = 0