diff --git a/tests/pytorch/test_grouped_linear.py b/tests/pytorch/test_grouped_linear.py index 0dc253c18c..d51796046c 100644 --- a/tests/pytorch/test_grouped_linear.py +++ b/tests/pytorch/test_grouped_linear.py @@ -866,6 +866,53 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass, monkeypatch torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) +@pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0), + reason="Only enable CUTLASS grouped gemm on Hopper", +) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +def test_grouped_gemm_cutlass_empty_groups(layout, monkeypatch): + dtype = torch.bfloat16 + z, k, n = 1, 2048, 1536 + m_splits = [0] * z + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input + out = [torch.empty(0, n, dtype=dtype, device="cuda")] # output + grad = False + single_output = True + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output + out = [torch.empty(0, k, dtype=dtype, device="cuda")] # dgrad + grad = True + single_output = True + else: # layout == "NT" + A = [torch.empty(0, k, dtype=dtype, device="cuda") for _ in range(z)] # input + B = [torch.empty(0, n, dtype=dtype, device="cuda") for _ in range(z)] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True + single_output = False + + monkeypatch.setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1") + general_grouped_gemm( + A, + B, + out, + [None] * z, + dtype, + m_splits=m_splits, + grad=grad, + layout=layout, + single_output=single_output, + ) + torch.cuda.synchronize() + + for tensor in out: + torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0) + + def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: data = grouped_tensor.rowwise_data if data is None: diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index e59e9c00c9..a0529c80c0 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1056,6 +1056,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_gemm); + if (num_gemms <= 0) { + return; + } + const int current_device = transformer_engine::cuda::current_device(); const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); const bool use_cutlass = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 53ea76d83b..47ece517b4 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -538,6 +538,10 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } + if (te_A_wrappers.empty()) { + return bias; + } + // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list;