Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions scripts/paddle_all_test_cases.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ python -m pytest -rs "tests/comm/test_trtllm_allreduce_fusion.py::test_trtllm_al
# python -m pytest -rs tests/moe/test_trtllm_gen_fused_moe.py::test_deepseekv3_routing
# python -m pytest -rs tests/moe/test_trtllm_gen_fused_moe.py::test_nvfp4_moe_gemm_bias
python -m pytest -rs tests/norm/test_fused_rmsnorm_silu.py
python -m pytest -rs tests/norm/test_fused_dit_layernorm.py
# test_rmsnorm_fp4_quant_cute_dsl.py: SKIP - torch.float4_e2m1fn_x2 not available (requires PyTorch 2.6+, NVFP4 packed dtype)
# test_add_rmsnorm_fp4_quant_cute_dsl.py: SKIP - same reason as above
9 changes: 9 additions & 0 deletions tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import pytest as _pytest_fp4
import torch as _torch_fp4
if not hasattr(_torch_fp4, "float4_e2m1fn_x2"):
_pytest_fp4.skip("torch.float4_e2m1fn_x2 not available (requires PyTorch 2.6+)", allow_module_level=True)
del _pytest_fp4, _torch_fp4

# Copyright (c) 2025 by FlashInfer team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -22,6 +28,9 @@
from tests.test_helpers.utils_fp4 import cast_from_fp4





def get_cc():
"""Get CUDA compute capability."""
major, minor = torch.cuda.get_device_capability()
Expand Down
52 changes: 28 additions & 24 deletions tests/norm/test_fused_dit_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,21 @@ def _make_strided_gate(batch_size, seq_len, hidden_dim, device):
temb = torch.randn(
batch_size, seq_len, 6, hidden_dim, dtype=torch.bfloat16, device=device
)
return temb.chunk(6, dim=2)[0].squeeze(2)
return _chunk_strided(temb, 0)




def _chunk_strided(temb, chunk_idx):
batch_size, seq_len, _, hidden_dim = temb.shape
batch_stride, row_stride, _, col_stride = temb.stride()
return torch.as_strided(
temb,
size=(batch_size, seq_len, hidden_dim),
stride=(batch_stride, row_stride, col_stride),
storage_offset=chunk_idx * hidden_dim * temb.element_size(),
)

def _make_wan_temb_inputs(batch_size, seq_len, hidden_dim, device):
"""Create gate/scale/shift tensors matching WAN's temb.chunk(6, dim=2) pattern.

Expand Down Expand Up @@ -158,8 +170,7 @@ def test_gate_residual_gamma_beta_bf16(batch_size, seq_len):
)

# Fused kernel β€” pass strided gate from temb.chunk(6, dim=2) directly
temb_chunks = temb_data["temb"].chunk(6, dim=2)
gate_strided = temb_chunks[2].squeeze(2) # gate_msa position
gate_strided = _chunk_strided(temb_data["temb"], 2) # gate_msa position
table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1)
gate_bias_from_table = table_chunks[2].squeeze(1)

Expand Down Expand Up @@ -211,15 +222,14 @@ def test_gate_residual_scale_shift_bf16(batch_size, seq_len):
)

# Fused kernel
temb_chunks = temb_data["temb"].chunk(6, dim=2)
table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1)

residual_fused, norm_fused = fused_dit_gate_residual_layernorm_scale_shift(
input_tensor,
residual,
temb_chunks[5].squeeze(2), # c_gate_msa
temb_chunks[1].squeeze(2), # scale_msa
temb_chunks[0].squeeze(2), # shift_msa
_chunk_strided(temb_data["temb"], 5), # c_gate_msa
_chunk_strided(temb_data["temb"], 1), # scale_msa
_chunk_strided(temb_data["temb"], 0), # shift_msa
gate_bias=table_chunks[5].squeeze(1),
scale_bias=table_chunks[1].squeeze(1),
shift_bias=table_chunks[0].squeeze(1),
Expand Down Expand Up @@ -264,13 +274,12 @@ def test_residual_scale_shift_bf16(batch_size, seq_len):
)

# Fused kernel
temb_chunks = temb_data["temb"].chunk(6, dim=2)
table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1)

residual_fused, norm_fused = fused_dit_residual_layernorm_scale_shift(
input_tensor,
temb_chunks[4].squeeze(2), # c_scale_msa
temb_chunks[3].squeeze(2), # c_shift_msa
_chunk_strided(temb_data["temb"], 4), # c_scale_msa
_chunk_strided(temb_data["temb"], 3), # c_shift_msa
residual=residual,
scale_bias=table_chunks[4].squeeze(1),
shift_bias=table_chunks[3].squeeze(1),
Expand Down Expand Up @@ -303,7 +312,6 @@ def test_destination_passing():
beta = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device)

temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device)
temb_chunks = temb_data["temb"].chunk(6, dim=2)
table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1)

residual_out = torch.empty_like(input_tensor)
Expand All @@ -312,7 +320,7 @@ def test_destination_passing():
r_ret, n_ret = fused_dit_gate_residual_layernorm_gamma_beta(
input_tensor,
residual,
temb_chunks[2].squeeze(2),
_chunk_strided(temb_data["temb"], 2),
gamma,
beta,
gate_bias=table_chunks[2].squeeze(1),
Expand Down Expand Up @@ -340,16 +348,15 @@ def test_destination_passing_scale_shift():
)
residual = torch.randn_like(input_tensor)
temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device)
temb_chunks = temb_data["temb"].chunk(6, dim=2)
table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1)

residual_out = torch.empty_like(input_tensor)
norm_out = torch.empty_like(input_tensor)

r_ret, n_ret = fused_dit_residual_layernorm_scale_shift(
input_tensor,
temb_chunks[4].squeeze(2),
temb_chunks[3].squeeze(2),
_chunk_strided(temb_data["temb"], 4),
_chunk_strided(temb_data["temb"], 3),
residual=residual,
scale_bias=table_chunks[4].squeeze(1),
shift_bias=table_chunks[3].squeeze(1),
Expand All @@ -376,13 +383,12 @@ def test_residual_scale_shift_no_residual():
batch_size, seq_len, HIDDEN_DIM, dtype=torch.bfloat16, device=device
)
temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device)
temb_chunks = temb_data["temb"].chunk(6, dim=2)
table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1)

residual_fused, norm_fused = fused_dit_residual_layernorm_scale_shift(
input_tensor,
temb_chunks[4].squeeze(2),
temb_chunks[3].squeeze(2),
_chunk_strided(temb_data["temb"], 4),
_chunk_strided(temb_data["temb"], 3),
residual=None,
scale_bias=table_chunks[4].squeeze(1),
shift_bias=table_chunks[3].squeeze(1),
Expand Down Expand Up @@ -424,13 +430,12 @@ def test_odd_num_rows():
beta = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device)

temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device)
temb_chunks = temb_data["temb"].chunk(6, dim=2)
table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1)

residual_fused, norm_fused = fused_dit_gate_residual_layernorm_gamma_beta(
input_tensor,
residual,
temb_chunks[2].squeeze(2),
_chunk_strided(temb_data["temb"], 2),
gamma,
beta,
gate_bias=table_chunks[2].squeeze(1),
Expand Down Expand Up @@ -483,7 +488,6 @@ def _run_nvfp4_or_mxfp8_test(mode, output_type, batch_size=1, seq_len=768):
gamma = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device)
beta = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device)
temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device)
temb_chunks = temb_data["temb"].chunk(6, dim=2)
table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1)

# Compute BF16 reference for residual and norm
Expand Down Expand Up @@ -525,7 +529,7 @@ def _run_nvfp4_or_mxfp8_test(mode, output_type, batch_size=1, seq_len=768):
residual_out, norm_out = fused_dit_gate_residual_layernorm_gamma_beta(
input_tensor,
residual,
temb_chunks[2].squeeze(2),
_chunk_strided(temb_data["temb"], 2),
gamma,
beta,
gate_bias=table_chunks[2].squeeze(1),
Expand All @@ -537,8 +541,8 @@ def _run_nvfp4_or_mxfp8_test(mode, output_type, batch_size=1, seq_len=768):
else:
residual_out, norm_out = fused_dit_residual_layernorm_scale_shift(
input_tensor,
temb_chunks[4].squeeze(2),
temb_chunks[3].squeeze(2),
_chunk_strided(temb_data["temb"], 4),
_chunk_strided(temb_data["temb"], 3),
residual=residual,
scale_bias=table_chunks[4].squeeze(1),
shift_bias=table_chunks[3].squeeze(1),
Expand Down
9 changes: 9 additions & 0 deletions tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import pytest as _pytest_fp4
import torch as _torch_fp4
if not hasattr(_torch_fp4, "float4_e2m1fn_x2"):
_pytest_fp4.skip("torch.float4_e2m1fn_x2 not available (requires PyTorch 2.6+)", allow_module_level=True)
del _pytest_fp4, _torch_fp4

# Copyright (c) 2025 by FlashInfer team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -22,6 +28,9 @@
from tests.test_helpers.utils_fp4 import cast_from_fp4





def get_cc():
"""Get CUDA compute capability."""
major, minor = torch.cuda.get_device_capability()
Expand Down
Loading