From 0df4673d29bf1e7fa208ef9d3a68b2c72259bf43 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 14 May 2026 18:18:32 +0800 Subject: [PATCH] adapt(norm): adapt tests/norm/ for Paddle compat - test_fused_dit_layernorm.py: add _chunk_strided() helper using torch.as_strided to reconstruct correct stride from 4D temb tensor. Paddle chunk() returns contiguous copies (losing strides); kernel requires gate.stride(1)==6*hidden_dim. Offset uses byte units (Paddle as_strided storage_offset is in bytes, PyTorch in elements). Fix _make_strided_gate to use _chunk_strided instead of chunk(). - test_rmsnorm_fp4_quant_cute_dsl.py, test_add_rmsnorm_fp4_quant_cute_dsl.py: add module-level skip guard for torch.float4_e2m1fn_x2 (NVFP4 packed dtype, PyTorch 2.6+, not proxied in Paddle compat). Use pytest.skip(allow_module_level=True). - scripts/paddle_all_test_cases.sh: add test_fused_dit_layernorm.py; add comments for fp4 tests (skipped, unavailable dtype). Results: test_fused_rmsnorm_silu.py: 102 passed, 50 skipped test_fused_dit_layernorm.py: 35 passed fp4 tests: 2 skipped (dtype unavailable) --- scripts/paddle_all_test_cases.sh | 3 ++ .../test_add_rmsnorm_fp4_quant_cute_dsl.py | 9 ++++ tests/norm/test_fused_dit_layernorm.py | 52 ++++++++++--------- tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py | 9 ++++ 4 files changed, 49 insertions(+), 24 deletions(-) diff --git a/scripts/paddle_all_test_cases.sh b/scripts/paddle_all_test_cases.sh index b8277cdb7d..0d8c9d79d9 100755 --- a/scripts/paddle_all_test_cases.sh +++ b/scripts/paddle_all_test_cases.sh @@ -19,3 +19,6 @@ python -m pytest -rs "tests/comm/test_trtllm_allreduce_fusion.py::test_trtllm_al # python -m pytest -rs tests/moe/test_trtllm_gen_fused_moe.py::test_deepseekv3_routing # python -m pytest -rs tests/moe/test_trtllm_gen_fused_moe.py::test_nvfp4_moe_gemm_bias python -m pytest -rs tests/norm/test_fused_rmsnorm_silu.py +python -m pytest -rs tests/norm/test_fused_dit_layernorm.py +# test_rmsnorm_fp4_quant_cute_dsl.py: SKIP - torch.float4_e2m1fn_x2 not available (requires PyTorch 2.6+, NVFP4 packed dtype) +# test_add_rmsnorm_fp4_quant_cute_dsl.py: SKIP - same reason as above diff --git a/tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py b/tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py index bd6fe87cfa..ab74a3a5f5 100644 --- a/tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py +++ b/tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py @@ -1,3 +1,9 @@ +import pytest as _pytest_fp4 +import torch as _torch_fp4 +if not hasattr(_torch_fp4, "float4_e2m1fn_x2"): + _pytest_fp4.skip("torch.float4_e2m1fn_x2 not available (requires PyTorch 2.6+)", allow_module_level=True) +del _pytest_fp4, _torch_fp4 + # Copyright (c) 2025 by FlashInfer team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,6 +28,9 @@ from tests.test_helpers.utils_fp4 import cast_from_fp4 + + + def get_cc(): """Get CUDA compute capability.""" major, minor = torch.cuda.get_device_capability() diff --git a/tests/norm/test_fused_dit_layernorm.py b/tests/norm/test_fused_dit_layernorm.py index 5950164c5e..63b68402d2 100644 --- a/tests/norm/test_fused_dit_layernorm.py +++ b/tests/norm/test_fused_dit_layernorm.py @@ -42,9 +42,21 @@ def _make_strided_gate(batch_size, seq_len, hidden_dim, device): temb = torch.randn( batch_size, seq_len, 6, hidden_dim, dtype=torch.bfloat16, device=device ) - return temb.chunk(6, dim=2)[0].squeeze(2) + return _chunk_strided(temb, 0) + + +def _chunk_strided(temb, chunk_idx): + batch_size, seq_len, _, hidden_dim = temb.shape + batch_stride, row_stride, _, col_stride = temb.stride() + return torch.as_strided( + temb, + size=(batch_size, seq_len, hidden_dim), + stride=(batch_stride, row_stride, col_stride), + storage_offset=chunk_idx * hidden_dim * temb.element_size(), + ) + def _make_wan_temb_inputs(batch_size, seq_len, hidden_dim, device): """Create gate/scale/shift tensors matching WAN's temb.chunk(6, dim=2) pattern. @@ -158,8 +170,7 @@ def test_gate_residual_gamma_beta_bf16(batch_size, seq_len): ) # Fused kernel — pass strided gate from temb.chunk(6, dim=2) directly - temb_chunks = temb_data["temb"].chunk(6, dim=2) - gate_strided = temb_chunks[2].squeeze(2) # gate_msa position + gate_strided = _chunk_strided(temb_data["temb"], 2) # gate_msa position table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1) gate_bias_from_table = table_chunks[2].squeeze(1) @@ -211,15 +222,14 @@ def test_gate_residual_scale_shift_bf16(batch_size, seq_len): ) # Fused kernel - temb_chunks = temb_data["temb"].chunk(6, dim=2) table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1) residual_fused, norm_fused = fused_dit_gate_residual_layernorm_scale_shift( input_tensor, residual, - temb_chunks[5].squeeze(2), # c_gate_msa - temb_chunks[1].squeeze(2), # scale_msa - temb_chunks[0].squeeze(2), # shift_msa + _chunk_strided(temb_data["temb"], 5), # c_gate_msa + _chunk_strided(temb_data["temb"], 1), # scale_msa + _chunk_strided(temb_data["temb"], 0), # shift_msa gate_bias=table_chunks[5].squeeze(1), scale_bias=table_chunks[1].squeeze(1), shift_bias=table_chunks[0].squeeze(1), @@ -264,13 +274,12 @@ def test_residual_scale_shift_bf16(batch_size, seq_len): ) # Fused kernel - temb_chunks = temb_data["temb"].chunk(6, dim=2) table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1) residual_fused, norm_fused = fused_dit_residual_layernorm_scale_shift( input_tensor, - temb_chunks[4].squeeze(2), # c_scale_msa - temb_chunks[3].squeeze(2), # c_shift_msa + _chunk_strided(temb_data["temb"], 4), # c_scale_msa + _chunk_strided(temb_data["temb"], 3), # c_shift_msa residual=residual, scale_bias=table_chunks[4].squeeze(1), shift_bias=table_chunks[3].squeeze(1), @@ -303,7 +312,6 @@ def test_destination_passing(): beta = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device) temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device) - temb_chunks = temb_data["temb"].chunk(6, dim=2) table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1) residual_out = torch.empty_like(input_tensor) @@ -312,7 +320,7 @@ def test_destination_passing(): r_ret, n_ret = fused_dit_gate_residual_layernorm_gamma_beta( input_tensor, residual, - temb_chunks[2].squeeze(2), + _chunk_strided(temb_data["temb"], 2), gamma, beta, gate_bias=table_chunks[2].squeeze(1), @@ -340,7 +348,6 @@ def test_destination_passing_scale_shift(): ) residual = torch.randn_like(input_tensor) temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device) - temb_chunks = temb_data["temb"].chunk(6, dim=2) table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1) residual_out = torch.empty_like(input_tensor) @@ -348,8 +355,8 @@ def test_destination_passing_scale_shift(): r_ret, n_ret = fused_dit_residual_layernorm_scale_shift( input_tensor, - temb_chunks[4].squeeze(2), - temb_chunks[3].squeeze(2), + _chunk_strided(temb_data["temb"], 4), + _chunk_strided(temb_data["temb"], 3), residual=residual, scale_bias=table_chunks[4].squeeze(1), shift_bias=table_chunks[3].squeeze(1), @@ -376,13 +383,12 @@ def test_residual_scale_shift_no_residual(): batch_size, seq_len, HIDDEN_DIM, dtype=torch.bfloat16, device=device ) temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device) - temb_chunks = temb_data["temb"].chunk(6, dim=2) table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1) residual_fused, norm_fused = fused_dit_residual_layernorm_scale_shift( input_tensor, - temb_chunks[4].squeeze(2), - temb_chunks[3].squeeze(2), + _chunk_strided(temb_data["temb"], 4), + _chunk_strided(temb_data["temb"], 3), residual=None, scale_bias=table_chunks[4].squeeze(1), shift_bias=table_chunks[3].squeeze(1), @@ -424,13 +430,12 @@ def test_odd_num_rows(): beta = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device) temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device) - temb_chunks = temb_data["temb"].chunk(6, dim=2) table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1) residual_fused, norm_fused = fused_dit_gate_residual_layernorm_gamma_beta( input_tensor, residual, - temb_chunks[2].squeeze(2), + _chunk_strided(temb_data["temb"], 2), gamma, beta, gate_bias=table_chunks[2].squeeze(1), @@ -483,7 +488,6 @@ def _run_nvfp4_or_mxfp8_test(mode, output_type, batch_size=1, seq_len=768): gamma = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device) beta = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device) temb_data = _make_wan_temb_inputs(batch_size, seq_len, HIDDEN_DIM, device) - temb_chunks = temb_data["temb"].chunk(6, dim=2) table_chunks = temb_data["scale_shift_table"].chunk(6, dim=1) # Compute BF16 reference for residual and norm @@ -525,7 +529,7 @@ def _run_nvfp4_or_mxfp8_test(mode, output_type, batch_size=1, seq_len=768): residual_out, norm_out = fused_dit_gate_residual_layernorm_gamma_beta( input_tensor, residual, - temb_chunks[2].squeeze(2), + _chunk_strided(temb_data["temb"], 2), gamma, beta, gate_bias=table_chunks[2].squeeze(1), @@ -537,8 +541,8 @@ def _run_nvfp4_or_mxfp8_test(mode, output_type, batch_size=1, seq_len=768): else: residual_out, norm_out = fused_dit_residual_layernorm_scale_shift( input_tensor, - temb_chunks[4].squeeze(2), - temb_chunks[3].squeeze(2), + _chunk_strided(temb_data["temb"], 4), + _chunk_strided(temb_data["temb"], 3), residual=residual, scale_bias=table_chunks[4].squeeze(1), shift_bias=table_chunks[3].squeeze(1), diff --git a/tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py b/tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py index 4a8b439d49..ea0464c616 100644 --- a/tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py +++ b/tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py @@ -1,3 +1,9 @@ +import pytest as _pytest_fp4 +import torch as _torch_fp4 +if not hasattr(_torch_fp4, "float4_e2m1fn_x2"): + _pytest_fp4.skip("torch.float4_e2m1fn_x2 not available (requires PyTorch 2.6+)", allow_module_level=True) +del _pytest_fp4, _torch_fp4 + # Copyright (c) 2025 by FlashInfer team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,6 +28,9 @@ from tests.test_helpers.utils_fp4 import cast_from_fp4 + + + def get_cc(): """Get CUDA compute capability.""" major, minor = torch.cuda.get_device_capability()