Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions flashinfer/xqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,20 +301,20 @@ def xqa(
v_sf_cache = v_sf_cache.transpose(-3, -2)
if (
k_cache.dtype == torch.float8_e4m3fn
and get_compute_capability(torch.device(device="cuda"))[0] == 9
and get_compute_capability(torch.device("cuda"))[0] == 9
):
run_sm90_fp8_mha = True
else:
run_sm90_fp8_mha = False

if k_cache.dtype == torch.uint8:
assert get_compute_capability(torch.device(device="cuda"))[0] in [12], (
assert get_compute_capability(torch.device("cuda"))[0] in [12], (
"XQA NVFP4 KV is only supported on SM120 GPUs"
)
assert k_sf_cache is not None, "K SF cache is required when NVFP4 KV is used"
assert v_sf_cache is not None, "V SF cache is required when NVFP4 KV is used"

if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
if get_compute_capability(torch.device("cuda"))[0] not in [9, 10, 12]:
raise RuntimeError("XQA is only supported on SM90, SM100, SM120/SM121 GPUs")

xqa_module = get_xqa_module(
Expand Down Expand Up @@ -534,7 +534,7 @@ def xqa_mla(

assert k_cache.dtype == v_cache.dtype, "K and V cache must have the same dtype"

if get_compute_capability(torch.device(device="cuda"))[0] not in [12]:
if get_compute_capability(torch.device("cuda"))[0] not in [12]:
raise RuntimeError("XQA MLA is only supported on SM120/SM121 GPUs")

xqa_module = get_xqa_module_mla(
Expand Down
7 changes: 7 additions & 0 deletions scripts/paddle_all_test_cases.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ python -m pytest -rs tests/norm/test_fused_rmsnorm_silu.py
python -m pytest -rs tests/norm/test_fused_dit_layernorm.py
# test_rmsnorm_fp4_quant_cute_dsl.py: SKIP - torch.float4_e2m1fn_x2 not available (requires PyTorch 2.6+, NVFP4 packed dtype)
# test_add_rmsnorm_fp4_quant_cute_dsl.py: SKIP - same reason as above
# gemm tests (2026-05-15)
# test_mm_bf16: cudnn/auto backend FAIL due to §47 env issue (Multiple libcudart), cutlass/tgv/cublaslt/tinygemm pass
python -m pytest -rs 'tests/gemm/test_mm_bf16.py::test_mm_bf16[False-cutlass-False-False-res_dtype0-1024-1024-1]'
# test_bmm_bf16: auto-float32 FAIL due to §47, cutlass/cudnn pass
python -m pytest -rs 'tests/gemm/test_bmm_bf16.py::test_bmm_bf16[cutlass-res_dtype0-64-80-48-1]'
# test_group_gemm: sm80 PASS, sm90 SKIP (device is SM100 Blackwell, sm90 not supported)
python -m pytest -rs 'tests/gemm/test_group_gemm.py::test_segment_gemm[sm80-cuda:0-dtype0-False-False-128-128-3-1]'
4 changes: 2 additions & 2 deletions tests/attention/test_attention_sink_blackwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_blackwell_trtllm_gen_decode_attention_sink(
num_kv_heads,
head_dim,
):
# compute_capability = get_compute_capability(torch.device(device="cuda"))
# compute_capability = get_compute_capability(torch.device("cuda"))
# if compute_capability[0] in [11, 12]:
# pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
# seed = 0
Expand Down Expand Up @@ -145,7 +145,7 @@ def test_blackwell_trtllm_gen_context_attention_sink(
num_kv_heads,
head_dim,
):
# compute_capability = get_compute_capability(torch.device(device="cuda"))
# compute_capability = get_compute_capability(torch.device("cuda"))
# if compute_capability[0] in [11, 12]:
# pytest.skip("trtllm-gen does not support SM110/SM120/SM121 GPUs.")
seed = 0
Expand Down
6 changes: 3 additions & 3 deletions tests/attention/test_batch_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _run_attention(

# ------------------------- PyTest test case ----------------------------- #
@pytest.mark.xfail(
get_compute_capability(torch.device(device="cuda"))[0] == 12,
get_compute_capability(torch.device("cuda"))[0] == 12,
reason="Expected failure for SM120/121 for now since the tile size/number of stages is too large.",
)
def test_batch_attention_with_noncontiguous_q():
Expand Down Expand Up @@ -245,7 +245,7 @@ def test_batch_attention_with_noncontiguous_q():


@pytest.mark.xfail(
get_compute_capability(torch.device(device="cuda"))[0] == 12,
get_compute_capability(torch.device("cuda"))[0] == 12,
reason="Expected failure for SM120/121 for now since the tile size/number of stages is too large.",
)
@pytest.mark.parametrize("seq_len_pairs", _build_seq_len_configs())
Expand Down Expand Up @@ -291,7 +291,7 @@ def test_batch_attention_correctness(


@pytest.mark.xfail(
get_compute_capability(torch.device(device="cuda"))[0] == 12,
get_compute_capability(torch.device("cuda"))[0] == 12,
reason="Expected failure for SM120/121 for now since the tile size/number of stages is too large.",
)
@pytest.mark.parametrize("batch_size", [1, 4])
Expand Down
8 changes: 4 additions & 4 deletions tests/attention/test_trtllm_gen_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def _test_trtllm_batch_prefill(
skips_softmax: bool = False,
uses_shared_paged_kv_idx: bool = True,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] != 10:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
if not causal and window_left >= 0:
Expand Down Expand Up @@ -1042,7 +1042,7 @@ def _test_trtllm_batch_decode(

Combinations of parameters are tested in test_trtllm_batch_decode() and test_trtllm_batch_decode_...()
"""
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))

# Check GPU architecture requirements for different backends
if backend == "trtllm-gen" and compute_capability[0] != 10:
Expand Down Expand Up @@ -1866,7 +1866,7 @@ def test_trtllm_gen_prefill(
skips_softmax: bool,
enable_sink: bool,
) -> None:
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] != 10:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
if s_qo > s_kv:
Expand Down Expand Up @@ -2046,7 +2046,7 @@ def test_trtllm_gen_prefill_fp8(
skips_softmax: bool,
) -> None:
"""Test cute-dsl prefill with FP8 (e4m3) input, bf16 output."""
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] != 10:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")

Expand Down
4 changes: 2 additions & 2 deletions tests/attention/test_trtllm_gen_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def trtllm_batch_decode_mla(
skips_softmax: bool,
uses_shared_paged_kv_idx: bool = True,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if backend == "xqa":
if compute_capability[0] != 12:
pytest.skip("XQA MLA only supports SM120/SM121 GPUs")
Expand Down Expand Up @@ -554,7 +554,7 @@ def trtllm_batch_decode_mla_sparse(
qk_nope_head_dim: int,
num_attn_heads: int,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if backend == "trtllm-gen":
if compute_capability[0] != 10:
pytest.skip("TRTLLM-GEN MLA only supports SM100 and SM103 GPUs")
Expand Down
4 changes: 2 additions & 2 deletions tests/attention/test_xqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def ref_attention(


@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12],
get_compute_capability(torch.device("cuda"))[0] not in [9, 10, 12],
reason="XQA is only supported on SM90, SM100, SM120/SM121 GPUs",
)
@pytest.mark.parametrize("enable_pdl", [True, False])
Expand Down Expand Up @@ -466,7 +466,7 @@ def test_xqa(


@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] not in [12],
get_compute_capability(torch.device("cuda"))[0] not in [12],
reason="XQA mla is only supported on SM120/SM121 GPUs",
)
@pytest.mark.parametrize("kv_scale", [1.0, 0.5])
Expand Down
4 changes: 2 additions & 2 deletions tests/attention/test_xqa_batch_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def generate_causal_mask(


@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12],
get_compute_capability(torch.device("cuda"))[0] not in [9, 10, 12],
reason="XQA is only supported on SM90, SM100, SM120/SM121 GPUs",
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -580,7 +580,7 @@ def test_xqa_batch_decode(


@pytest.mark.skipif(
get_compute_capability(torch.device(device="cuda"))[0] not in [12],
get_compute_capability(torch.device("cuda"))[0] not in [12],
reason="XQA with NVFP4 KV is only supported on SM120 GPUs",
)
@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/attention/test_xqa_mla_batch_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_xqa_mla_batch_decode(
page_size: int,
enable_pdl: bool,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] != 12:
pytest.skip("These tests are only guaranteed to work on SM120/SM121 GPUs.")

Expand Down
2 changes: 1 addition & 1 deletion tests/autotuner/test_autotuner_bmm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
],
)
def test_autotuner_gemm(pre_tune, tune_mode, expected_cache_hit, m, n, k):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
if not bmm_fp8.is_compute_capability_supported(compute_capability_number):
pytest.skip(
Expand Down
2 changes: 1 addition & 1 deletion tests/comm/test_mixed_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _run_worker(
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("local_size", [2, 4, 8])
def test_mixed_comm(local_size, num_nodes, node_id, dtype, dist_init_method):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
if not run_mixed_comm.is_compute_capability_supported(compute_capability_number):
pytest.skip(
Expand Down
40 changes: 40 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,46 @@
import paddle

paddle.enable_compat()

# Paddle compatibility shim for ``Stream.cuda_stream``.
# Once paddle.enable_compat() is active, torch.cuda.Stream is really
# paddle.device.Stream, which exposes the stream through the
# __cuda_stream__() method rather than through a cuda_stream property.
try:
    import torch

    if not hasattr(torch.cuda.Stream, "cuda_stream"):

        def _compat_cuda_stream(stream):
            """Return element 1 of ``stream.__cuda_stream__()``, or None if absent."""
            if not hasattr(stream, "__cuda_stream__"):
                return None
            return stream.__cuda_stream__()[1]

        torch.cuda.Stream.cuda_stream = property(_compat_cuda_stream)
except Exception:
    # Deliberately best-effort: any failure while installing the shim is ignored.
    pass


# Paddle compatibility shim for ``torch.cuda.current_blas_handle``.
# When the attribute is missing, a replacement is built on top of libcublas
# through ctypes: one cuBLAS handle is created lazily per device id and is
# re-bound to the current stream on every call.
try:
    import ctypes

    if not hasattr(torch.cuda, "current_blas_handle"):
        _cublas_lib = ctypes.CDLL("libcublas.so.12")
        _cublas_handles = {}

        def _check_cublas_status(status, api_name):
            """Raise RuntimeError unless ``status`` is CUBLAS_STATUS_SUCCESS (0)."""
            if status != 0:
                raise RuntimeError(f"{api_name} failed with cuBLAS status {status}")

        def _current_blas_handle():
            """Return the integer address of the current device's cuBLAS handle.

            The handle is created on first use for the current device, cached
            by device id, and bound to the current stream before it is returned.

            Raises:
                RuntimeError: if ``cublasCreate_v2`` or ``cublasSetStream_v2``
                    reports a non-success status.
            """
            device_id = torch.cuda.current_device()
            handle = _cublas_handles.get(device_id)
            if handle is None:
                handle = ctypes.c_void_p()
                _check_cublas_status(
                    _cublas_lib.cublasCreate_v2(ctypes.byref(handle)),
                    "cublasCreate_v2",
                )
                # Cache only after a successful create, so a failure is retried
                # on the next call instead of handing out a NULL handle forever.
                _cublas_handles[device_id] = handle
            stream_ptr = torch.cuda.current_stream().cuda_stream
            _check_cublas_status(
                _cublas_lib.cublasSetStream_v2(handle, ctypes.c_void_p(stream_ptr)),
                "cublasSetStream_v2",
            )
            return handle.value

        # NOTE(review): the guard inspects torch.cuda but the function is
        # installed on paddle.cuda; presumably both resolve to the same module
        # under paddle.enable_compat() -- confirm.
        paddle.cuda.current_blas_handle = _current_blas_handle
except Exception:
    # Deliberately best-effort: any failure while installing the shim is ignored.
    pass

import pytest
import torch
# from torch.torch_version import TorchVersion
Expand Down
2 changes: 1 addition & 1 deletion tests/gemm/test_bmm_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "auto"])
def test_bmm_bf16(b, m, n, k, res_dtype, backend):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
if not bmm_bf16.is_compute_capability_supported(compute_capability_number):
pytest.skip(
Expand Down
12 changes: 6 additions & 6 deletions tests/gemm/test_groupwise_scaled_gemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_fp8_blockscale_gemm(
scale_major_mode,
out_dtype,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] not in [10, 11, 12]:
pytest.skip(
"gemm_fp8_nt_blockscaled is only supported on SM100/103, SM110, and SM120/121 GPUs."
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_fp8_groupwise_gemm(
scale_major_mode,
backend,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if backend == "trtllm":
if compute_capability[0] != 10:
pytest.skip(
Expand Down Expand Up @@ -145,7 +145,7 @@ def test_fp8_groupwise_gemm(
@pytest.mark.parametrize("k", [256])
@pytest.mark.parametrize("scale_major_mode", ["MN", "K"])
def test_fp8_groupwise_gemm_small_batch_size(m, n, k, scale_major_mode):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] != 10:
pytest.skip(
"Small-batch gemm_fp8_nt_groupwise dispatch is only relevant on SM100/103."
Expand Down Expand Up @@ -199,7 +199,7 @@ def test_fp8_groupwise_group_gemm(
scale_major_mode,
out_dtype,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if group_size > 1 and compute_capability[0] in [
12,
]:
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_fp8_groupwise_group_deepgemm(
group_size,
out_dtype,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] != 10:
pytest.skip(
"group_deepgemm_fp8_nt_groupwise is only supported on SM100, SM103 in trtllm backend."
Expand Down Expand Up @@ -314,7 +314,7 @@ def test_fp8_groupwise_batch_deepgemm_masked(
group_size,
out_dtype,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] != 10:
pytest.skip(
"batch_deepgemm_fp8_nt_groupwise is only supported on SM100, SM103."
Expand Down
2 changes: 1 addition & 1 deletion tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_mxfp8_mxfp4_groupwise_group_gemm(
fp8_dtype,
out_dtype,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] not in [10, 12]:
pytest.skip(
"gemm_mxfp4_nt_groupwise is only supported on SM100, SM103, and SM120/121 GPUs."
Expand Down
2 changes: 1 addition & 1 deletion tests/gemm/test_mm_bf16.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_mm_bf16(
backend: str,
auto_tuning: bool,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
if not mm_bf16.is_compute_capability_supported(compute_capability_number):
pytest.skip(
Expand Down
2 changes: 1 addition & 1 deletion tests/gemm/test_mm_fp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _test_mm_fp4(
):
use_nvfp4 = fp4_type == "nvfp4"

compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
compute_capability_number = compute_capability[0] * 10 + compute_capability[1]
if not mm_fp4.is_backend_supported(backend, compute_capability_number):
pytest.skip(
Expand Down
2 changes: 1 addition & 1 deletion tests/gemm/test_mm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_mm_fp8(
mat2_dtype: torch.dtype,
res_dtype: torch.dtype,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] not in [10]:
pytest.skip("mm_fp8 is only supported on Blackwell GPUs.")

Expand Down
2 changes: 1 addition & 1 deletion tests/gemm/test_sm_constraint_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def torch_addmm(a, b, c, alpha=1.0, beta=0.0):
"EPILOGUE_SUBTILE", [True, False]
) # only for descriptor persistent
def test_sm_constraint_gemm(M, N, K, alpha, beta, num_sms, dtype, EPILOGUE_SUBTILE):
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
# TODO(P1): Most of these tests pass on Blackwell. We need to triage these at some point.
if compute_capability[0] != 9:
pytest.skip("These tests are only guaranteed to work on Hopper GPUs.")
Expand Down
4 changes: 2 additions & 2 deletions tests/moe/test_trtllm_gen_moe_autotune_tactics.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def test_trtllm_fp4_routed_moe_all_tactics_correctness(
determinism, and approximate match to the heuristic-default tactic's
output.
"""
if get_compute_capability(torch.device(device="cuda"))[0] not in [10]:
if get_compute_capability(torch.device("cuda"))[0] not in [10]:
pytest.skip("Only work on SM100 / SM103.")

AutoTuner.get()._logged_file_hits.discard(_TEST_LOG_KEY_FP4)
Expand Down Expand Up @@ -702,7 +702,7 @@ def test_trtllm_fp8_routed_moe_all_tactics_correctness(
quant_mode: Fp8QuantMode,
):
"""Per-tactic correctness sweep of `trtllm_fp8_block_scale_routed_moe`."""
if get_compute_capability(torch.device(device="cuda"))[0] not in [10]:
if get_compute_capability(torch.device("cuda"))[0] not in [10]:
pytest.skip("Only work on SM100 / SM103.")

AutoTuner.get()._logged_file_hits.discard(_TEST_LOG_KEY_FP8)
Expand Down
2 changes: 1 addition & 1 deletion tests/moe/test_trtllm_gen_per_token_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_routed_fused_moe(
top_k: int,
):
device = torch.device("cuda:0")
compute_capability = get_compute_capability(torch.device(device="cuda"))
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] not in [10]:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
enable_pdl = device_support_pdl(device)
Expand Down
Loading
Loading