Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions tests/kernels/test_moe_blockscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ def block_quant_expert(w_fp32, blk_n, blk_k):
w1_shuf = w1_bq_shuf.view(-1)
w2_shuf = w2_bq_shuf.view(-1)

# Input block quantize (from x_q fp8 -> re-block)
x_f32_for_blk = x_q.float()
# Input block quantize: dequantize x_q first (raw fp8 codes overflow the f16 pipeline; #642).
x_f32_for_blk = x_q.float() * x_scale
a1_bq, a1_bscale = pertoken_quant(
x_f32_for_blk.view(-1, model_dim // scale_blk_k, scale_blk_k), quant_dtype=DTYPE_FP8
)
Expand Down Expand Up @@ -357,7 +357,10 @@ def block_quant_expert(w_fp32, blk_n, blk_k):
if HAS_AITER:
# quant kernel writes transposed scale layout directly
a1_bq, a1_scale_fly = per_group_quant_hip(
x_q.to(torch.bfloat16), quant_dtype=DTYPE_FP8, group_size=scale_blk_k, transpose_scale=True
(x_q.float() * x_scale).to(torch.bfloat16),
quant_dtype=DTYPE_FP8,
group_size=scale_blk_k,
transpose_scale=True,
)
a1_scale_fly = a1_scale_fly.view(-1) # [nblk_k_w1 * token] flat
else:
Expand Down Expand Up @@ -525,6 +528,9 @@ def launch_stage2():
# Verify flydsl full pipeline vs blockscale torch reference
err_fly = checkAllclose(out_ref.to(out2.dtype), out2, rtol=0.1, atol=0.1, msg="flydsl vs ref", printLog=False)
Comment thread
coderfeli marked this conversation as resolved.
print(f" flydsl: pipeline err_ratio vs ref = {err_fly:.4f}")
# Assert instead of returning timings, so a wrong pipeline can't pass silently (#642).
assert torch.isfinite(out2).all(), "FlyDSL block-scale pipeline produced non-finite output"
Comment thread
coderfeli marked this conversation as resolved.
assert err_fly <= 0.05, f"FlyDSL block-scale pipeline diverges from reference: err_ratio={err_fly:.4f}"

# ---- aiter comparisons ----
us_aiter_fused = 0
Expand Down
Loading