diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1ced32e1a5..df607bd6dc 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -44,7 +44,6 @@ is_bf16_available, ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor -from transformer_engine.pytorch.cpp_extensions.gemm import general_grouped_gemm_for_grouped_tensor import transformer_engine_torch as tex # Import utility functions @@ -80,6 +79,9 @@ if nvfp4_available: _quantization_list.append("nvfp4") _quantization_list.append("nvfp4_4over6") +_grouped_mlp_quantization_list = list(_quantization_list) +if nvfp4_available: + _grouped_mlp_quantization_list.append("nvfp4_rht") @pytest.fixture(autouse=True, scope="function") @@ -109,7 +111,10 @@ def maybe_skip_quantization( pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and not nvfp4_available: + if ( + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") + and not nvfp4_available + ): pytest.skip(reason_for_no_nvfp4) # Check dims @@ -122,14 +127,14 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") # Check dtype if dtype is not None: if ( - quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") and dtype != torch.bfloat16 ): pytest.skip("NVFP4 quantization is only supported with BF16 data") @@ -187,10 +192,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization in ("nvfp4", "nvfp4_row_scaled"): + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_rht"): + tensor_type = "input" + if quantizer_role is not None: + tensor_type = quantizer_role.tensor_type + with_rht = quantization == "nvfp4_rht" and tensor_type != "weight" test = NVFP4Quantizer( - with_rht=False, - with_post_rht_amax=False, + with_rht=with_rht, + with_post_rht_amax=with_rht, with_2d_quantization=False, stochastic_rounding=False, with_random_sign_mask=False, @@ -3685,7 +3694,7 @@ def test_layernorm_mlp( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantization", _grouped_mlp_quantization_list) @pytest.mark.parametrize("single_grouped_weight", (False, True)) @pytest.mark.parametrize("single_grouped_bias", (False, True)) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) @@ -3753,16 +3762,19 @@ def test_grouped_mlp( pytest.skip("Unary activations do not use GLU interleaving") if quantization == "nvfp4_4over6": pytest.skip("NVFP4 4over6 grouped quantization is not supported") + if quantization == "nvfp4_rht" and ( + activation != "scaled_swiglu" or bias or glu_interleave_size != 32 + ): + pytest.skip("NVFP4 RHT grouped MLP coverage is limited to fused no-bias SwiGLU") if ( with_quantization - and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht") and activation.startswith("scaled_clamped_qgeglu") and bias ): # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size - # Activation parameters for clamped QGeGLU variants if activation == "scaled_clamped_qgeglu_custom": geglu_limit = 5.0 @@ -3845,13 +3857,7 @@ def test_grouped_mlp( fc2_ws_test.append(fc2_w_test) fc2_bs_test.append(fc2_b_test) - # Reference implementation - xs = torch.split(x_ref, split_sizes.tolist()) - probs = torch.split(probs_ref, split_sizes.tolist()) - ys = [] - for group_idx in range(group_size): - x = xs[group_idx] - x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) + def _apply_activation(x: torch.Tensor) -> torch.Tensor: if activation_is_glu and glu_interleave_size is not None: x = x.reshape( -1, @@ -3863,66 +3869,85 @@ def test_grouped_mlp( x = x.reshape(-1, 2 * hidden_size) if activation == "scaled_swiglu": x1, x2 = x.chunk(2, dim=-1) - x = torch.nn.functional.silu(x1) * x2 - elif activation.startswith("scaled_clamped_qgeglu"): + return torch.nn.functional.silu(x1) * x2 + if activation.startswith("scaled_clamped_qgeglu"): x1, x2 = x.chunk(2, dim=-1) lim = torch.tensor(geglu_limit, device=x1.device, dtype=x1.dtype) x1c = torch.minimum(x1, lim) x2c = torch.clamp(x2, -lim, lim) - x = (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c)) - elif activation == "scaled_srelu": - x = torch.nn.functional.relu(x).square() - else: - raise ValueError(f"Unexpected grouped MLP activation ({activation})") - x = x * probs[group_idx].unsqueeze(-1) - x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx]) + return (x2c + geglu_offset) * (x1c * torch.sigmoid(geglu_alpha * x1c)) + if activation == "scaled_srelu": + return torch.nn.functional.relu(x).square() + raise ValueError(f"Unexpected grouped MLP activation ({activation})") + + # Reference implementation + xs = torch.split(x_ref, split_sizes.tolist()) + probs = torch.split(probs_ref, split_sizes.tolist()) + ys = [] + for group_idx in range(group_size): + x = xs[group_idx] + fc1_out = torch.nn.functional.linear( + x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx] + ) + fc2_in = _apply_activation(fc1_out) + fc2_in = fc2_in * probs[group_idx].unsqueeze(-1) + y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) if bias: - x = x + fc2_bs_ref[group_idx] * probs[group_idx].unsqueeze(-1) - ys.append(x) + y = y + fc2_bs_ref[group_idx] * probs[group_idx].unsqueeze(-1) + ys.append(y) y_ref = torch.cat(ys) y_ref.backward(dy_ref) # Construct operations recipe = make_recipe(quantization) - if activation == "scaled_clamped_qgeglu_custom": - scaled_act = te_ops.ScaledClampedQGeGLU( - glu_interleave_size=glu_interleave_size, - limit=geglu_limit, - alpha=geglu_alpha, - glu_linear_offset=geglu_offset, - ) - with te.quantized_model_init(enabled=with_quantization, recipe=recipe): - fc1 = te_ops.GroupedLinear( - group_size, - hidden_size, - fc1_out_features, - bias=bias, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - ) - fc2 = te_ops.GroupedLinear( - group_size, - hidden_size, - hidden_size, - bias=bias, - device=device, - dtype=dtype, - single_grouped_weight=single_grouped_weight, - single_grouped_bias=single_grouped_bias, - accumulate_into_main_grad=accumulate_into_main_grad, - delay_wgrad_compute=delay_wgrad_compute, - scale_bias=bias, - ) - module = te_ops.Sequential( - fc1, - scaled_act, - fc2, - ) + def _make_scaled_act(): + if activation == "scaled_swiglu": + return te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_clamped_qgeglu_custom": + return te_ops.ScaledClampedQGeGLU( + glu_interleave_size=glu_interleave_size, + limit=geglu_limit, + alpha=geglu_alpha, + glu_linear_offset=geglu_offset, + ) + if activation.startswith("scaled_clamped_qgeglu"): + return te_ops.ScaledClampedQGeGLU(glu_interleave_size=glu_interleave_size) + if activation == "scaled_srelu": + return te_ops.ScaledSReLU() + raise ValueError(f"Unexpected grouped MLP activation ({activation})") + + def _make_module(): + with te.quantized_model_init(enabled=with_quantization, recipe=recipe): + fc1_op = te_ops.GroupedLinear( + group_size, + hidden_size, + fc1_out_features, + bias=bias, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + ) + + fc2_op = te_ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + accumulate_into_main_grad=accumulate_into_main_grad, + delay_wgrad_compute=delay_wgrad_compute, + scale_bias=bias, + ) + return te_ops.Sequential(fc1_op, _make_scaled_act(), fc2_op), fc1_op, fc2_op + + module, fc1, fc2 = _make_module() # Copy weights with torch.no_grad(): @@ -3976,7 +4001,7 @@ def test_grouped_mlp( fc2.backward_dw() # Check for expected fusions - if ( + expected_grouped_mlp_fusion = ( quantization == "mxfp8" and dtype in (torch.bfloat16, torch.float16) and ( @@ -3984,13 +4009,14 @@ def test_grouped_mlp( or (activation_is_glu and glu_interleave_size == 32) ) and _cudnn_frontend_version_supported() - ): + ) + if expected_grouped_mlp_fusion: if activation_is_glu: - forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8 - backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8 + forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU + backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU else: - forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary_MXFP8 - backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8 + forward_cls = te_ops.fused.ForwardGroupedMLP_CuTeGEMMUnary + backward_cls = te_ops.fused.BackwardGroupedMLP_CuTeGEMMDUnary if forward_cls.is_supported(): forward_ops = module._module_groups[0]._forward_ops assert len(forward_ops) == 1 @@ -4008,7 +4034,7 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): tols = {"rtol": 0.25, "atol": 0.5} # Check values @@ -4088,9 +4114,9 @@ def test_grouped_mlp_single_weight_numerics( ) -> None: """single_grouped_weight=True/False should match exactly for fused MXFP8 grouped MLP.""" - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported(): + if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") split_sizes = [split_alignment * (i + 1) for i in range(group_size)] @@ -4192,12 +4218,12 @@ def _run_case(single_grouped_weight: bool) -> tuple[torch.Tensor, ...]: assert len(forward_ops) == 1 assert isinstance( forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, + te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, ) assert len(backward_ops) == 1 assert isinstance( backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, ) if single_grouped_weight: @@ -4310,9 +4336,9 @@ def test_grouped_mlp_overwrite_main_grad( that read ``.grad`` don't see stale bytes from the cached dummy). """ - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP forward is not supported on this system") - if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported(): + if not te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP backward is not supported on this system") recipe = make_recipe("mxfp8") @@ -4444,7 +4470,7 @@ def test_grouped_mlp_cuda_graph_safe_mxfp8( ) -> None: """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" - if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): pytest.skip("MXFP8 fused grouped MLP is not supported on this system") if dtype not in (torch.bfloat16, torch.float16): pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") @@ -4586,12 +4612,12 @@ def train_step( assert len(forward_ops) == 1 assert isinstance( forward_ops[0][0], - te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, + te_ops.fused.ForwardGroupedMLP_CuTeGEMMGLU, ) assert len(backward_ops) == 1 assert isinstance( backward_ops[0][0], - te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDGLU, ) fresh_x = torch.randn_like(static_x) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 19cc118a90..84489f30c1 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -118,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -145,10 +145,10 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) - if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6", "nvfp4_rht"): use_4over6 = name == "nvfp4_4over6" kwargs = { - "disable_rht": True, + "disable_rht": name != "nvfp4_rht", "disable_stochastic_rounding": True, "disable_2d_quantization": not use_4over6, "row_scaled_activation": name == "nvfp4_row_scaled", @@ -163,12 +163,16 @@ def recipe_id(recipe: Optional[Recipe]) -> str: """Readable pytest id for a quantization recipe.""" if not isinstance(recipe, Recipe): return "None" - if recipe.nvfp4() and recipe.row_scaled_activation and recipe.nvfp4_4over6 != "none": - return "NVFP4RowScaled4Over6BlockScaling" - if recipe.nvfp4() and recipe.nvfp4_4over6 != "none": - return "NVFP44Over6BlockScaling" - if recipe.nvfp4() and recipe.row_scaled_activation: - return "NVFP4RowScaledBlockScaling" + if recipe.nvfp4(): + nvfp4_features = [] + if recipe.row_scaled_activation: + nvfp4_features.append("RowScaled") + if recipe.nvfp4_4over6 != "none": + nvfp4_features.append("4Over6") + if not recipe.disable_rht: + nvfp4_features.append("RHT") + if nvfp4_features: + return f"NVFP4{''.join(nvfp4_features)}BlockScaling" return type(recipe).__name__ diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 717d872010..87911d76f4 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -5,6 +5,7 @@ """Helper functions used in fusible operations.""" from __future__ import annotations +from collections.abc import Iterable import functools import math from importlib.metadata import PackageNotFoundError, version as get_pkg_version @@ -13,10 +14,13 @@ import torch from packaging.version import Version as PkgVersion +import transformer_engine_torch as tex from transformer_engine_torch import FP8TensorMeta from ..torch_version import torch_version from ..quantization import FP8GlobalStateManager +from ..tensor import NVFP4Quantizer, NVFP4Tensor, NVFP4TensorStorage, Quantizer from ..tensor.float8_tensor import Float8Tensor +from ..tensor.grouped_tensor import GroupedTensor from ..quantized_tensor import QuantizedTensorStorage from ..utils import canonicalize_dtype @@ -57,6 +61,146 @@ def _nvidia_cudnn_frontend_supports_wgrad() -> bool: return _cudnn_frontend_version_supported() +def _group_quantize_for_grouped_mlp( + tensor: torch.Tensor, + quantizer: Quantizer, + num_groups: int, + split_sizes: Optional[torch.Tensor], + *, + tensor_offsets: Optional[torch.Tensor] = None, +) -> GroupedTensor: + """Quantize into grouped storage.""" + + # Typical case: group-quantize + if num_groups != 1 or not isinstance(quantizer, NVFP4Quantizer): + return tex.group_quantize(tensor, quantizer, num_groups, split_sizes) + + # -------------------------------------------------- + # Special case: single-tensor NVFP4 quantize + # -------------------------------------------------- + + quantized = tex.quantize(tensor, quantizer) + with_gemm_swizzled_scales = quantized._with_gemm_swizzled_scales + if quantizer.optimize_for_gemm: + tex.swizzle_scales_for_gemm_(quantized) + with_gemm_swizzled_scales = True + + rowwise_data = quantized._rowwise_data + rowwise_scale = quantized._rowwise_scale_inv + columnwise_data = quantized._columnwise_data + columnwise_scale = quantized._columnwise_scale_inv + amax = quantized._amax_rowwise + columnwise_amax = quantized._amax_columnwise + + if split_sizes is None: + split_sizes = torch.full((1,), tensor.shape[0], dtype=torch.int64, device=tensor.device) + else: + split_sizes = split_sizes.to(dtype=torch.int64, device=tensor.device) + + m_dim = tensor.shape[0] + if rowwise_data is not None: + k_dim = rowwise_data.shape[-1] * 2 + elif columnwise_data is not None: + k_dim = columnwise_data.shape[0] + else: + k_dim = tensor.shape[-1] + + if tensor_offsets is None: + tensor_offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.int64, device=tensor.device), + torch.cumsum(split_sizes * k_dim, dim=0), + ], + ) + + return GroupedTensor( + shape=(m_dim, k_dim), + dtype=tensor.dtype, + quantizer=quantizer, + num_tensors=1, + data=rowwise_data.reshape(-1) if rowwise_data is not None else None, + columnwise_data=columnwise_data.reshape(-1) if columnwise_data is not None else None, + scale_inv=rowwise_scale.reshape(-1) if rowwise_scale is not None else None, + columnwise_scale_inv=columnwise_scale.reshape(-1) if columnwise_scale is not None else None, + amax=amax, + columnwise_amax=columnwise_amax, + first_dims=split_sizes, + tensor_offsets=tensor_offsets, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + + +def _nvfp4_amax( + tensors: GroupedTensor | Iterable[NVFP4TensorStorage], + *, + columnwise: bool, +) -> torch.Tensor: + """Get one NVFP4 amax value per group.""" + grouped_attr = "columnwise_amax" if columnwise else "amax" + tensor_attr = "_amax_columnwise" if columnwise else "_amax_rowwise" + + if hasattr(tensors, grouped_attr): + amax = getattr(tensors, grouped_attr) + if amax is None: + raise RuntimeError(f"NVFP4 GroupedTensor is missing {grouped_attr}.") + return amax.view(-1) + + amaxes = [getattr(tensor, tensor_attr) for tensor in tensors] + if any(amax is None for amax in amaxes): + raise RuntimeError(f"NVFP4 tensor list is missing {tensor_attr}.") + return torch.cat([amax.view(-1) for amax in amaxes], dim=0) + + +def _nvfp4_single_tensor_from_grouped( + grouped: GroupedTensor, + quantizer: Optional[NVFP4Quantizer] = None, + *, + fp4_dtype: Optional[torch.dtype] = None, +) -> NVFP4Tensor: + """Build a single NVFP4Tensor view over a one-member grouped storage.""" + if quantizer is None: + quantizer = grouped.quantizer + if not isinstance(quantizer, NVFP4Quantizer): + raise TypeError("Expected an NVFP4 GroupedTensor.") + + shape = tuple(grouped.logical_shape) + rowwise_data = None + if grouped.rowwise_data is not None: + rowwise_data = grouped.rowwise_data.view(quantizer.convert_shape_for_fp4(shape)) + + rowwise_scale_inv = None + if grouped.scale_inv is not None: + rowwise_scale_inv = grouped.scale_inv.view(quantizer.get_scale_shape(shape, False)) + + columnwise_data = None + if grouped.columnwise_data is not None: + columnwise_shape = quantizer.get_columnwise_shape(shape) + columnwise_data = grouped.columnwise_data.view( + quantizer.convert_shape_for_fp4(columnwise_shape) + ) + + columnwise_scale_inv = None + if grouped.columnwise_scale_inv is not None: + columnwise_scale_inv = grouped.columnwise_scale_inv.view( + quantizer.get_scale_shape(shape, True) + ) + + return NVFP4Tensor( + shape=shape, + dtype=grouped.get_dtype(), + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=grouped.amax, + amax_columnwise=grouped.columnwise_amax, + fp4_dtype=fp4_dtype or quantizer.dtype, + quantizer=quantizer, + requires_grad=False, + with_gemm_swizzled_scales=grouped._with_gemm_swizzled_scales, + ) + + def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: """Check if tensor is a quantized tensor""" return isinstance(tensor, QuantizedTensorStorage) @@ -285,7 +429,10 @@ def fuse_grouped_mlp_ops( if not fused_op_cls.is_supported(): return ops - if recipe is None or not recipe.mxfp8(): + if recipe is None or not (recipe.mxfp8() or recipe.nvfp4()): + return ops + # NVFP4 fused grouped MLP uses graph-safe grouped quantize, which currently requires RHT. + if recipe.nvfp4() and recipe.disable_rht: return ops if activation_op_types is None: activation_op_types = (ScaledSwiGLU, ScaledClampedQGeGLU) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index dc15bc63b8..e9787f96b2 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -22,7 +22,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ...quantization import FP8GlobalStateManager, Recipe +from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...quantized_tensor import QuantizedTensorStorage from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer from ...utils import ( @@ -291,6 +291,25 @@ def num_quantizers(self, mode: str) -> int: return self.num_groups return 0 + def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]: + name = getattr(self, "name", "") or "" + if mode == "forward": + roles = [] + for _ in range(self.num_groups): + roles.extend( + [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + ] + ) + return roles + if mode == "backward": + return [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name) + for _ in range(self.num_groups) + ] + return None + @property def has_bias(self) -> bool: """Whether an additive bias is being applied""" diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index b29e35814d..78f9d880ba 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -32,10 +32,10 @@ # Import experimental fusions # Note: Registration logic is non-trivial, so submodule handles it internally. from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position - ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, - ForwardGroupedMLP_CuTeGEMMUnary_MXFP8, + ForwardGroupedMLP_CuTeGEMMGLU, + ForwardGroupedMLP_CuTeGEMMUnary, ) from .backward_grouped_mlp import ( # pylint: disable=wrong-import-position - BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, - BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8, + BackwardGroupedMLP_CuTeGEMMDGLU, + BackwardGroupedMLP_CuTeGEMMDUnary, ) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 25ccad1377..792b6d7811 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -14,10 +14,17 @@ import transformer_engine_torch as tex from ...quantization import Recipe +from ...tensor import NVFP4Quantizer, NVFP4Tensor from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer -from ...utils import clear_tensor_data, get_cached_ones_tensor, get_device_compute_capability -from ...constants import MXFP8_BLOCK_SCALING_SIZE +from ...utils import ( + ceil_div, + clear_tensor_data, + get_cached_ones_tensor, + get_device_compute_capability, + round_up_to_nearest_multiple, +) +from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU from ..fuser import register_backward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext @@ -25,6 +32,9 @@ _cudnn_frontend_geglu_runtime_params, _cudnn_frontend_version_supported, _cudnn_frontend_supports_grouped_gemm_srelu, + _group_quantize_for_grouped_mlp, + _nvfp4_amax, + _nvfp4_single_tensor_from_grouped, fuse_grouped_mlp_ops, get_accumulate_flag_in_param, get_dummy_wgrads_for_params, @@ -34,11 +44,41 @@ view_main_grad_as_grouped_buffer, validate_grouped_mlp_dims, ) -from ...cpp_extensions import general_grouped_gemm_for_grouped_tensor +from ...cpp_extensions import ( + general_gemm, + general_grouped_gemm_for_grouped_tensor, +) from ...module.base import _2X_ACC_WGRAD from ...triton.grouped_dbias_dscales import compute_grouped_dbias_dscales +def _nvfp4_single_group_wgrad_gemm( + grouped_x: GroupedTensor, + grouped_dy: GroupedTensor, + wgrad_output, + *, + weight_shape: tuple[int, int], + accumulate: bool, +) -> None: + """Run one-group NVFP4 wgrad with regular GEMM instead of grouped GEMM.""" + x_single = _nvfp4_single_tensor_from_grouped(grouped_x) + dy_single = _nvfp4_single_tensor_from_grouped(grouped_dy) + if isinstance(wgrad_output, GroupedTensor): + out = wgrad_output.rowwise_data.view(1, *weight_shape)[0] + else: + out = wgrad_output[0] + + general_gemm( + x_single, + dy_single, + out_dtype=out.dtype, + out=out, + layout="NT", + accumulate=accumulate, + use_split_accumulator=_2X_ACC_WGRAD, + ) + + def _cudnn_compute_wgrad( grouped_x: GroupedTensor, grouped_dy: GroupedTensor, @@ -62,8 +102,8 @@ def _cudnn_compute_wgrad( fp8_dtype = torch.float8_e4m3fn - sfa_leading_dim = ((out_features + 127) // 128) * 128 - sfb_leading_dim = ((in_features + 127) // 128) * 128 + sfa_leading_dim = round_up_to_nearest_multiple(out_features, 128) + sfb_leading_dim = round_up_to_nearest_multiple(in_features, 128) if total_tokens == 0: # A workaround for the case with zero-token experts. @@ -220,6 +260,18 @@ def _compute_grad_params( single_grouped_weight=fc_op.single_grouped_weight, current_stream=torch.cuda.current_stream().cuda_stream, ) + elif ( + num_groups == 1 + and isinstance(grouped_x, GroupedTensor) + and isinstance(grouped_dy, GroupedTensor) + and isinstance(grouped_x.quantizer, NVFP4Quantizer) + and isinstance(grouped_dy.quantizer, NVFP4Quantizer) + ): + gemm_fn = functools.partial( + _nvfp4_single_group_wgrad_gemm, + weight_shape=weight_shape, + accumulate=accumulate_into_main_grad, + ) else: gemm_fn = functools.partial( general_grouped_gemm_for_grouped_tensor, @@ -252,8 +304,8 @@ def _compute_grad_params( return w_list + bias_list -class _BackwardGroupedMLP_CuTeGEMMDBase_MXFP8(FusedOperation): - """Base fused backward op for MXFP8 GroupedLinear + activation + GroupedLinear. +class _BackwardGroupedMLP_CuTeGEMMDBase(FusedOperation): + """Base fused backward op for block-scaled GroupedLinear + activation + GroupedLinear. Uses experimental CuTe DSL kernel from cuDNN front-end. @@ -360,7 +412,9 @@ def fuser_backward( grad_output = grad_output.reshape(-1, fc2_weight_shape[0]) out_shape = list(grad_output.size()) num_groups = fc1_op.num_groups - device = fc1_op._get_weight_tensors()[0].device + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 + fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 + device = fc1_weight_param.device dtype = fc1_ctx.dtype # Saved tensors from FC1 forward. @@ -419,10 +473,18 @@ def fuser_backward( output_fc2_dbias = fc2_op.has_bias fc2_dbias_packed = None fc2_dy = None + grad_output_quantizer = getattr(grad_output, "quantizer", None) + fc2_grad_output_quantizer_matches = ( + isinstance(fc2_grad_output_quantizer, MXFP8Quantizer) + and isinstance(grad_output_quantizer, MXFP8Quantizer) + ) or ( + isinstance(fc2_grad_output_quantizer, NVFP4Quantizer) + and isinstance(grad_output_quantizer, NVFP4Quantizer) + ) if ( not output_fc2_dbias and isinstance(grad_output, GroupedTensor) - and isinstance(getattr(grad_output, "quantizer", None), MXFP8Quantizer) + and fc2_grad_output_quantizer_matches ): grouped_fc2_dy = grad_output else: @@ -435,13 +497,26 @@ def fuser_backward( split_sizes, ) else: - grouped_fc2_dy = tex.group_quantize( + grouped_fc2_dy = _group_quantize_for_grouped_mlp( fc2_dy, fc2_grad_output_quantizer, num_groups, split_sizes, + tensor_offsets=base_split_offsets * fc2_weight_shape[0], ) + use_nvfp4 = ( + isinstance(fc2_grad_output_quantizer, NVFP4Quantizer) + or isinstance(fc1_weight_param, NVFP4Tensor) + or isinstance(fc2_weight_param, NVFP4Tensor) + ) + data_dtype = torch.float4_e2m1fn_x2 if use_nvfp4 else torch.float8_e4m3fn + scale_view_dtype = torch.float8_e4m3fn if use_nvfp4 else torch.float8_e8m0fnu + sf_vec_size = NVFP4_BLOCK_SCALING_SIZE if use_nvfp4 else MXFP8_BLOCK_SCALING_SIZE + data_k = out_shape[1] // 2 if use_nvfp4 else out_shape[1] + fc2_weight_k = fc2_weight_shape[1] // 2 if use_nvfp4 else fc2_weight_shape[1] + k_sf_divisor = 2 * sf_vec_size if use_nvfp4 else 4 * sf_vec_size + # Pack data tensors # Note: Fused kernel expects tensor with non-contiguous # logical dims. @@ -451,20 +526,42 @@ def fuser_backward( # Data logical shape: (sum(m), k, 1) # Scale logical shape: (32 (block row), 4 (block row), # sum(m)/128, 4 (block col), k/128, 1) - fc2_dy_data = grouped_fc2_dy.rowwise_data.view(out_shape[0], out_shape[1]) - fc2_dy_data = fc2_dy_data.view(dtype=torch.float8_e4m3fn) + fc2_dy_data = grouped_fc2_dy.rowwise_data.view(dtype=data_dtype) + fc2_dy_data = fc2_dy_data.view(out_shape[0], data_k) fc2_dy_data = fc2_dy_data.unsqueeze(0).permute(1, 2, 0) fc2_dy_scales = grouped_fc2_dy.scale_inv - fc2_dy_scales = fc2_dy_scales.view(dtype=torch.float8_e8m0fnu) - fc2_dy_scales = fc2_dy_scales.view( - 1, - (out_shape[0] + 127) // 128, - (out_shape[1] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, - 4, - 4, - ) - fc2_dy_scales = fc2_dy_scales.permute(3, 4, 1, 5, 2, 0) + fc2_dy_scales = fc2_dy_scales.view(dtype=scale_view_dtype) + with_gemm_swizzled_scales = grouped_fc2_dy._with_gemm_swizzled_scales + if use_nvfp4 and with_gemm_swizzled_scales: + fc2_dy_scales = fc2_dy_scales.view( + 1, + ceil_div(out_shape[0], 128), + ceil_div(data_k, k_sf_divisor), + 32, + 4, + 4, + ) + fc2_dy_scales = fc2_dy_scales.permute(3, 4, 1, 5, 2, 0) + elif use_nvfp4: + fc2_dy_scales = fc2_dy_scales.view( + 1, + ceil_div(out_shape[0], 128), + 4, + 32, + ceil_div(data_k, k_sf_divisor), + 4, + ) + fc2_dy_scales = fc2_dy_scales.permute(3, 2, 1, 5, 4, 0) + else: + fc2_dy_scales = fc2_dy_scales.view( + 1, + ceil_div(out_shape[0], 128), + ceil_div(out_shape[1], k_sf_divisor), + 32, + 4, + 4, + ) + fc2_dy_scales = fc2_dy_scales.permute(3, 4, 1, 5, 2, 0) # Kernel scaling factors alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) @@ -475,25 +572,43 @@ def fuser_backward( scales_tensor = scales_f32.reshape(-1, 1, 1) dscales_tensor = torch.zeros_like(scales_tensor) + fc2_d_dtype = torch.bfloat16 if use_nvfp4 else torch.float8_e4m3fn + if use_nvfp4: + nvfp4_fp4_max = 6.0 + nvfp4_fp8_max = 448.0 + fc2_alpha_tensor = ( + torch.sqrt( + _nvfp4_amax(grouped_fc2_dy, columnwise=False) + * _nvfp4_amax(grouped_fc2_weight, columnwise=True) + ) + / (nvfp4_fp8_max * nvfp4_fp4_max) + ).expand(num_groups) + fc2_beta_tensor = get_cached_ones_tensor(num_groups, torch.float32, device) + fc2_norm_const_tensor = None + else: + fc2_alpha_tensor = alpha_tensor + fc2_beta_tensor = alpha_tensor + fc2_norm_const_tensor = norm_const_tensor + fc2_dactivation_kwargs = { "a_tensor": fc2_dy_data, "c_tensor": activation_in.unsqueeze(0).permute(1, 2, 0), "sfa_tensor": fc2_dy_scales, "padded_offsets": split_points, - "alpha_tensor": alpha_tensor, + "alpha_tensor": fc2_alpha_tensor, "prob_tensor": scales_tensor, "dprob_tensor": dscales_tensor, "generate_dbias": fc1_op.has_bias, - "norm_const_tensor": norm_const_tensor, - "d_dtype": torch.float8_e4m3fn, + "norm_const_tensor": fc2_norm_const_tensor, + "d_dtype": fc2_d_dtype, "cd_major": "n", - "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "sf_vec_size": sf_vec_size, "current_stream": current_stream, - "discrete_col_sfd": True, + "discrete_col_sfd": not use_nvfp4, "use_dynamic_sched": True, } if self._cudnn_dact_func is not None: - fc2_dactivation_kwargs["beta_tensor"] = alpha_tensor + fc2_dactivation_kwargs["beta_tensor"] = fc2_beta_tensor fc2_dactivation_kwargs["act_func"] = self._cudnn_dact_func else: fc2_dactivation_kwargs["use_dsrelu_reuse"] = recompute_fc2_x_from_dsrelu @@ -513,19 +628,23 @@ def fuser_backward( # Data actual shape: (num_groups, k, n) # Data logical shape: (n, k, num_groups) fc2_w_data = fc2_weight_for_gemm.columnwise_data - fc2_w_data = fc2_w_data.view(dtype=torch.float8_e4m3fn) - fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1]) - fc2_w_data = fc2_w_data.permute(2, 1, 0) - fc2_w_scales = fc2_weight_for_gemm.columnwise_scale_inv.view(dtype=torch.float8_e8m0fnu) + fc2_w_data = fc2_w_data.view(dtype=data_dtype) + fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_k) + fc2_w_data = fc2_w_data.permute(1, 2, 0) if use_nvfp4 else fc2_w_data.permute(2, 1, 0) + fc2_w_scales = fc2_weight_for_gemm.columnwise_scale_inv.view(dtype=scale_view_dtype) fc2_w_scales = fc2_w_scales.view( num_groups, - (fc2_weight_shape[1] + 127) // 128, - (fc2_weight_shape[0] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, + ceil_div(fc2_weight_shape[1], k_sf_divisor), + ceil_div(fc2_weight_shape[0], 128), + 32, 4, 4, ) - fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + fc2_w_scales = ( + fc2_w_scales.permute(3, 4, 2, 5, 1, 0) + if use_nvfp4 + else fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + ) fc2_dactivation_kwargs["b_tensor"] = fc2_w_data fc2_dactivation_kwargs["sfb_tensor"] = fc2_w_scales @@ -534,27 +653,43 @@ def fuser_backward( [w._columnwise_data for w in grouped_fc2_weight], device, ) + swizzle_type = ( + "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" + ) fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - "uniform_mxfp8_columnwise_swizzle", + swizzle_type, [w._columnwise_scale_inv for w in grouped_fc2_weight], device, ) fc2_dactivation_kwargs["b_ptrs"] = fc2_b_ptrs fc2_dactivation_kwargs["sfb_ptrs"] = fc2_sfb_ptrs fc2_dactivation_kwargs["n"] = fc2_weight_shape[1] - fc2_dactivation_kwargs["b_dtype"] = torch.float8_e4m3fn - fc2_dactivation_kwargs["b_major"] = "n" + fc2_dactivation_kwargs["b_dtype"] = data_dtype + fc2_dactivation_kwargs["b_major"] = "k" if use_nvfp4 else "n" fc2_dgrad_kernel_out = self.grouped_gemm_dactivation_kernel()(**fc2_dactivation_kwargs) - fc1_dy_row_data = fc2_dgrad_kernel_out["d_row_tensor"] - fc1_dy_row_data = fc1_dy_row_data.view(out_shape[0], fc1_weight_shape[0]) - # View scale in their actual swizzled shape - fc1_dy_row_scale = fc2_dgrad_kernel_out["sfd_row_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) - fc1_dy_col_data = fc2_dgrad_kernel_out["d_col_tensor"] - fc1_dy_col_data = fc1_dy_col_data.view(out_shape[0], fc1_weight_shape[0]) - # View scale in their actual swizzled shape - fc1_dy_col_scale = fc2_dgrad_kernel_out["sfd_col_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) + if use_nvfp4: + fc1_dy_bf16 = fc2_dgrad_kernel_out["d_row_tensor"] + fc1_dy_bf16 = fc1_dy_bf16.view(out_shape[0], fc1_weight_shape[0]).contiguous() + fc1_dy_row_data = None + fc1_dy_row_scale = None + fc1_dy_col_data = None + fc1_dy_col_scale = None + else: + fc1_dy_bf16 = None + fc1_dy_row_data = fc2_dgrad_kernel_out["d_row_tensor"] + fc1_dy_row_data = fc1_dy_row_data.view(out_shape[0], fc1_weight_shape[0]) + # View scale in their actual swizzled shape + fc1_dy_row_scale = ( + fc2_dgrad_kernel_out["sfd_row_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) + ) + fc1_dy_col_data = fc2_dgrad_kernel_out["d_col_tensor"] + fc1_dy_col_data = fc1_dy_col_data.view(out_shape[0], fc1_weight_shape[0]) + # View scale in their actual swizzled shape + fc1_dy_col_scale = ( + fc2_dgrad_kernel_out["sfd_col_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) + ) grad_scales = fc2_dgrad_kernel_out["dprob_tensor"].view(-1) if recompute_fc2_x_from_dsrelu: @@ -628,21 +763,37 @@ def fuser_backward( # FC1 grad output for dgrad and wgrad GEMMs fc1_dy_tensor_offsets = base_split_offsets * fc1_weight_shape[0] - grouped_fc1_dy = GroupedTensor( - shape=(out_shape[0], fc1_weight_shape[0]), - dtype=dtype, - num_tensors=num_groups, - quantizer=fc1_ctx.grad_output_quantizers[0], - data=fc1_dy_row_data, - columnwise_data=fc1_dy_col_data, - scale_inv=fc1_dy_row_scale, - columnwise_scale_inv=fc1_dy_col_scale, - first_dims=split_sizes, - tensor_offsets=fc1_dy_tensor_offsets, - with_gemm_swizzled_scales=True, - ) + fc1_grad_output_quantizer = fc1_ctx.grad_output_quantizers[0] + if use_nvfp4: + fc1_grad_output_quantizer.set_usage( + rowwise=True, + columnwise=fc1_ctx.weight_requires_grad, + ) + fc1_grad_output_quantizer.optimize_for_gemm = True + grouped_fc1_dy = _group_quantize_for_grouped_mlp( + fc1_dy_bf16, + fc1_grad_output_quantizer, + num_groups, + split_sizes, + tensor_offsets=fc1_dy_tensor_offsets, + ) + else: + grouped_fc1_dy = GroupedTensor( + shape=(out_shape[0], fc1_weight_shape[0]), + dtype=dtype, + num_tensors=num_groups, + quantizer=fc1_grad_output_quantizer, + data=fc1_dy_row_data, + columnwise_data=fc1_dy_col_data, + scale_inv=fc1_dy_row_scale, + columnwise_scale_inv=fc1_dy_col_scale, + first_dims=split_sizes, + tensor_offsets=fc1_dy_tensor_offsets, + with_gemm_swizzled_scales=True, + ) # FC2 wgrad GEMM + wgrad_kernel_fn = None if use_nvfp4 else self.grouped_gemm_wgrad_kernel() fc2_grad_params = _compute_grad_params( fc_op=fc2_op, ctx=fc2_ctx, @@ -655,7 +806,7 @@ def fuser_backward( bias_grads=fc2_bias_grads, bias_grad_packed=fc2_bias_grad_packed, label="FC2", - cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel(), + cudnn_wgrad_kernel_fn=wgrad_kernel_fn, offsets=split_points, ) @@ -677,67 +828,110 @@ def fuser_backward( if fc1_ctx.input_requires_grad: in_shape = out_shape[:-1] + [fc1_weight_shape[1]] - fc1_dgrad_a_data = fc2_dgrad_kernel_out["d_row_tensor"] - fc1_dgrad_a_scales = fc2_dgrad_kernel_out["sfd_row_tensor"] - - fc1_dgrad_kwargs = { - "a_tensor": fc1_dgrad_a_data, - "sfa_tensor": fc1_dgrad_a_scales, - "padded_offsets": split_points, - "alpha_tensor": alpha_tensor, - "norm_const_tensor": None, - "prob_tensor": torch.ones((out_shape[0], 1, 1), dtype=torch.float32, device=device), - "acc_dtype": torch.float32, - "d_dtype": dtype, - "cd_major": "n", - "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, - "current_stream": current_stream, - "discrete_col_sfd": True, - "use_dynamic_sched": True, - } - - if fc1_op.single_grouped_weight: - # Clone and swizzle scales for GEMM - fc1_weight_for_gemm = grouped_fc1_weight.copy() - tex.grouped_swizzle_for_gemm(fc1_weight_for_gemm, rowwise=False, columnwise=True) - - fc1_w_data = fc1_weight_for_gemm.columnwise_data - fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) - fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) - fc1_w_data = fc1_w_data.permute(2, 1, 0) - fc1_w_scales = fc1_weight_for_gemm.columnwise_scale_inv.view( - dtype=torch.float8_e8m0fnu - ) - fc1_w_scales = fc1_w_scales.view( - num_groups, - (fc1_weight_shape[1] + 127) // 128, - (fc1_weight_shape[0] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, - 4, - 4, - ) - fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) - - fc1_dgrad_kwargs["b_tensor"] = fc1_w_data - fc1_dgrad_kwargs["sfb_tensor"] = fc1_w_scales + if use_nvfp4: + grad_input = torch.empty(in_shape, dtype=dtype, device=device) + if num_groups == 1: + if fc1_op.single_grouped_weight: + fc1_w_single = grouped_fc1_weight.split_into_quantized_tensors()[0] + else: + fc1_w_single = grouped_fc1_weight[0] + fc1_dy_single = _nvfp4_single_tensor_from_grouped(grouped_fc1_dy) + general_gemm( + fc1_w_single, + fc1_dy_single, + out_dtype=dtype, + out=grad_input, + layout="NN", + ) + else: + fc1_x_tensor_offsets = base_split_offsets * fc1_weight_shape[1] + grouped_grad_input = GroupedTensor( + shape=(out_shape[0], fc1_weight_shape[1]), + dtype=dtype, + num_tensors=num_groups, + quantizer=None, + data=grad_input.view(-1), + first_dims=split_sizes, + tensor_offsets=fc1_x_tensor_offsets, + ) + general_grouped_gemm_for_grouped_tensor( + grouped_fc1_weight, + grouped_fc1_dy, + grouped_grad_input, + layout="NN", + ) else: - fc1_b_ptrs = tex.copy_data_ptrs_to_device( - [w._columnwise_data for w in grouped_fc1_weight], - device, - ) - fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - "uniform_mxfp8_columnwise_swizzle", - [w._columnwise_scale_inv for w in grouped_fc1_weight], - device, - ) - fc1_dgrad_kwargs["b_ptrs"] = fc1_b_ptrs - fc1_dgrad_kwargs["sfb_ptrs"] = fc1_sfb_ptrs - fc1_dgrad_kwargs["n"] = fc1_weight_shape[1] - fc1_dgrad_kwargs["b_dtype"] = torch.float8_e4m3fn - fc1_dgrad_kwargs["b_major"] = "n" - - fc1_dgrad_kernel_out = self.grouped_gemm_quant_kernel()(**fc1_dgrad_kwargs) - grad_input = fc1_dgrad_kernel_out["d_tensor"].view(in_shape) + fc1_dgrad_a_data = fc2_dgrad_kernel_out["d_row_tensor"] + fc1_dgrad_a_scales = fc2_dgrad_kernel_out["sfd_row_tensor"] + + fc1_dgrad_kwargs = { + "a_tensor": fc1_dgrad_a_data, + "sfa_tensor": fc1_dgrad_a_scales, + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor, + "norm_const_tensor": None, + "prob_tensor": torch.ones( + (out_shape[0], 1, 1), dtype=torch.float32, device=device + ), + "acc_dtype": torch.float32, + "d_dtype": dtype, + "cd_major": "n", + "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "discrete_col_sfd": True, + "use_dynamic_sched": True, + } + + if fc1_op.single_grouped_weight: + # Clone and swizzle scales for GEMM + fc1_weight_for_gemm = grouped_fc1_weight.copy() + tex.grouped_swizzle_for_gemm( + fc1_weight_for_gemm, rowwise=False, columnwise=True + ) + + fc1_w_data = fc1_weight_for_gemm.columnwise_data + fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) + fc1_w_data = fc1_w_data.view( + num_groups, fc1_weight_shape[0], fc1_weight_shape[1] + ) + fc1_w_data = fc1_w_data.permute(2, 1, 0) + fc1_w_scales = fc1_weight_for_gemm.columnwise_scale_inv.view( + dtype=torch.float8_e8m0fnu + ) + fc1_w_scales = fc1_w_scales.view( + num_groups, + ceil_div(fc1_weight_shape[1], 128), + ceil_div(fc1_weight_shape[0], 128), + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) + + fc1_dgrad_kwargs["b_tensor"] = fc1_w_data + fc1_dgrad_kwargs["sfb_tensor"] = fc1_w_scales + else: + fc1_b_ptrs = tex.copy_data_ptrs_to_device( + [w._columnwise_data for w in grouped_fc1_weight], + device, + ) + swizzle_type = ( + "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_columnwise_swizzle" + ) + fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( + swizzle_type, + [w._columnwise_scale_inv for w in grouped_fc1_weight], + device, + ) + + fc1_dgrad_kwargs["b_ptrs"] = fc1_b_ptrs + fc1_dgrad_kwargs["sfb_ptrs"] = fc1_sfb_ptrs + fc1_dgrad_kwargs["n"] = fc1_weight_shape[1] + fc1_dgrad_kwargs["b_dtype"] = torch.float8_e4m3fn + fc1_dgrad_kwargs["b_major"] = "n" + + fc1_dgrad_kernel_out = self.grouped_gemm_quant_kernel()(**fc1_dgrad_kwargs) + grad_input = fc1_dgrad_kernel_out["d_tensor"].view(in_shape) # FC1 wgrad GEMM fc1_grad_params = _compute_grad_params( @@ -752,7 +946,7 @@ def fuser_backward( bias_grads=fc1_bias_grads, bias_grad_packed=fc1_bias_grad_packed, label="FC1", - cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel(), + cudnn_wgrad_kernel_fn=wgrad_kernel_fn, offsets=split_points, ) @@ -778,8 +972,8 @@ def fuser_backward( ) -class BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase_MXFP8): - """Fused backward op for GroupedLinear + scaled GLU + GroupedLinear.""" +class BackwardGroupedMLP_CuTeGEMMDGLU(_BackwardGroupedMLP_CuTeGEMMDBase): + """Fused backward op for block-scaled GroupedLinear + scaled GLU + GroupedLinear.""" @classmethod @functools.lru_cache(maxsize=None) @@ -790,8 +984,8 @@ def grouped_gemm_dactivation_kernel(cls) -> Callable: return grouped_gemm_dglu_wrapper_sm100 -class BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8(_BackwardGroupedMLP_CuTeGEMMDBase_MXFP8): - """Fused backward op for GroupedLinear + scaled unary activation + GroupedLinear.""" +class BackwardGroupedMLP_CuTeGEMMDUnary(_BackwardGroupedMLP_CuTeGEMMDBase): + """Fused backward op for block-scaled GroupedLinear + scaled unary activation + GroupedLinear.""" @classmethod @functools.lru_cache(maxsize=None) @@ -833,7 +1027,7 @@ def fuse_backward_ops( return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8, + fused_op_cls=BackwardGroupedMLP_CuTeGEMMDGLU, ) @@ -845,16 +1039,18 @@ def fuse_backward_srelu_ops( ) -> list[FusibleOperation]: """Apply GroupedLinear + ScaledSReLU + GroupedLinear fusion for backward pass.""" + if recipe is None or not recipe.mxfp8(): + return ops return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8, + fused_op_cls=BackwardGroupedMLP_CuTeGEMMDUnary, activation_op_types=(ScaledSReLU,), ) # Register fusion if available -if BackwardGroupedMLP_CuTeGEMMDGLU_MXFP8.is_supported(): +if BackwardGroupedMLP_CuTeGEMMDGLU.is_supported(): register_backward_fusion(fuse_backward_ops, prepend=True) -if BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8.is_supported(): +if BackwardGroupedMLP_CuTeGEMMDUnary.is_supported(): register_backward_fusion(fuse_backward_srelu_ops, prepend=True) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index a0c5f766c5..f4f2108578 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,12 +13,18 @@ import torch import transformer_engine_torch as tex +from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe -from ...tensor import Quantizer -from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor +from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer +from ...utils import ( + ceil_div, + get_cached_ones_tensor, + get_device_compute_capability, + mark_grouped_tensor, +) from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer -from ...constants import MXFP8_BLOCK_SCALING_SIZE +from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext @@ -26,7 +32,10 @@ _cudnn_frontend_geglu_runtime_params, _cudnn_frontend_version_supported, _cudnn_frontend_supports_grouped_gemm_srelu, + _group_quantize_for_grouped_mlp, _nvidia_cudnn_frontend_supports_wgrad, + _nvfp4_amax, + _nvfp4_single_tensor_from_grouped, fuse_grouped_mlp_ops, is_glu_activation, is_quantized_tensor, @@ -67,8 +76,8 @@ def _grouped_gemm_dsrelu_backward_supported() -> bool: return grouped_gemm_dsrelu_wrapper_sm100 is not None -class _ForwardGroupedMLP_CuTeGEMMBase_MXFP8(FusedOperation): - """Base fused op for MXFP8 GroupedLinear + activation + GroupedLinear. +class _ForwardGroupedMLP_CuTeGEMMBase(FusedOperation): + """Base fused op for block-scaled GroupedLinear + activation + GroupedLinear. Uses experimental CuTe DSL kernel from cuDNN front-end. @@ -202,6 +211,7 @@ def fuser_forward( split_sizes = split_sizes.to(dtype=torch.int64, device=device) base_split_offsets = tex.splits_to_offsets(split_sizes, 1) split_points = base_split_offsets[1:].to(dtype=torch.int) + fc1_x_tensor_offsets = base_split_offsets * fc1_weight_shape[1] fc2_x_tensor_offsets = base_split_offsets * fc2_weight_shape[1] # Extract per-row activation probabilities from the middle op. @@ -224,7 +234,7 @@ def fuser_forward( if fc1_op.weight.rowwise_data is None: raise RuntimeError("FC1 grouped weight has no rowwise_data to quantize.") fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - grouped_fc1_weight = tex.group_quantize( + grouped_fc1_weight = _group_quantize_for_grouped_mlp( fc1_op.weight.rowwise_data.view(fc1_op.weight.logical_shape), fc1_weight_quantizer, num_groups, @@ -256,7 +266,7 @@ def fuser_forward( if fc2_op.weight.rowwise_data is None: raise RuntimeError("FC2 grouped weight has no rowwise_data to quantize.") fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - grouped_fc2_weight = tex.group_quantize( + grouped_fc2_weight = _group_quantize_for_grouped_mlp( fc2_op.weight.rowwise_data.view(fc2_op.weight.logical_shape), fc2_weight_quantizer, num_groups, @@ -276,25 +286,45 @@ def fuser_forward( grouped_fc2_weight = quantized_fc2_weights # Some wrapper-copy paths may drop grouped storage metadata; enforce defaults. - if getattr(grouped_fc1_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( - grouped_fc1_weight, GroupedTensor + if isinstance(grouped_fc1_weight, GroupedTensor) and not hasattr( + grouped_fc1_weight, "_with_gemm_swizzled_scales" ): grouped_fc1_weight._with_gemm_swizzled_scales = False - if getattr(grouped_fc2_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( - grouped_fc2_weight, GroupedTensor + if isinstance(grouped_fc2_weight, GroupedTensor) and not hasattr( + grouped_fc2_weight, "_with_gemm_swizzled_scales" ): grouped_fc2_weight._with_gemm_swizzled_scales = False # Group-quantize input tensor and convert dtypes if needed fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc1_input_quantizer.optimize_for_gemm = True - if isinstance(input_, GroupedTensor) and isinstance( - getattr(input_, "quantizer", None), MXFP8Quantizer + input_quantizer = getattr(input_, "quantizer", None) + if isinstance(input_, GroupedTensor) and ( + isinstance(fc1_input_quantizer, MXFP8Quantizer) + and isinstance(input_quantizer, MXFP8Quantizer) + or isinstance(fc1_input_quantizer, NVFP4Quantizer) + and isinstance(input_quantizer, NVFP4Quantizer) ): grouped_fc1_x = input_ else: fc1_x = maybe_dequantize(input_, dtype) - grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes) + grouped_fc1_x = _group_quantize_for_grouped_mlp( + fc1_x, + fc1_input_quantizer, + num_groups, + split_sizes, + tensor_offsets=fc1_x_tensor_offsets, + ) + + use_nvfp4 = isinstance(fc1_input_quantizer, NVFP4Quantizer) or isinstance( + fc1_weight_param, NVFP4Tensor + ) + data_dtype = torch.float4_e2m1fn_x2 if use_nvfp4 else torch.float8_e4m3fn + scale_view_dtype = torch.float8_e4m3fn if use_nvfp4 else torch.float8_e8m0fnu + sf_vec_size = NVFP4_BLOCK_SCALING_SIZE if use_nvfp4 else MXFP8_BLOCK_SCALING_SIZE + data_in_k = in_shape[1] // 2 if use_nvfp4 else in_shape[1] + fc1_weight_k = fc1_weight_shape[1] // 2 if use_nvfp4 else fc1_weight_shape[1] + k_sf_divisor = 2 * sf_vec_size if use_nvfp4 else 4 * sf_vec_size # Pack data tensors # Note: Fused kernel expects tensor with non-contiguous @@ -305,20 +335,42 @@ def fuser_forward( # Data logical shape: (sum(m), k, 1) # Scale logical shape: (32 (block row), 4 (block row), # sum(m)/128, 4 (block col), k/128, 1) - fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1]) - fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) + fc1_x_data = grouped_fc1_x.rowwise_data.view(dtype=data_dtype) + fc1_x_data = fc1_x_data.view(in_shape[0], data_in_k) fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) fc1_x_scales = grouped_fc1_x.scale_inv - fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) - fc1_x_scales = fc1_x_scales.view( - 1, - (in_shape[0] + 127) // 128, - (in_shape[1] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, - 4, - 4, - ) - fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + fc1_x_scales = fc1_x_scales.view(dtype=scale_view_dtype) + with_gemm_swizzled_scales = grouped_fc1_x._with_gemm_swizzled_scales + if use_nvfp4 and with_gemm_swizzled_scales: + fc1_x_scales = fc1_x_scales.view( + 1, + ceil_div(in_shape[0], 128), + ceil_div(data_in_k, k_sf_divisor), + 32, + 4, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + elif use_nvfp4: + fc1_x_scales = fc1_x_scales.view( + 1, + ceil_div(in_shape[0], 128), + 4, + 32, + ceil_div(data_in_k, k_sf_divisor), + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 2, 1, 5, 4, 0) + else: + fc1_x_scales = fc1_x_scales.view( + 1, + ceil_div(in_shape[0], 128), + ceil_div(in_shape[1], k_sf_divisor), + 32, + 4, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) norm_const_tensor = get_cached_ones_tensor(1, torch.float32, device) @@ -327,21 +379,37 @@ def fuser_forward( fc1_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc1_op) fc2_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc2_op) + fc1_d_dtype = torch.bfloat16 if use_nvfp4 else torch.float8_e4m3fn + fc1_prob_tensor = ( + scales.detach().to(dtype=torch.float32 if use_nvfp4 else dtype).reshape(-1, 1, 1) + ) + fc1_norm_const_tensor = None if use_nvfp4 else norm_const_tensor + if use_nvfp4: + nvfp4_fp4_max = 6.0 + nvfp4_fp8_max = 448.0 + fc1_alpha_tensor = ( + _nvfp4_amax(grouped_fc1_x, columnwise=False) + * _nvfp4_amax(grouped_fc1_weight, columnwise=False) + / (nvfp4_fp4_max**2 * nvfp4_fp8_max**2) + ).to(torch.float32) + else: + fc1_alpha_tensor = alpha_tensor + fc1_activation_kwargs = { "a_tensor": fc1_x_data, "sfa_tensor": fc1_x_scales, "padded_offsets": split_points, - "alpha_tensor": alpha_tensor, + "alpha_tensor": fc1_alpha_tensor, "bias_tensor": fc1_bias_packed, - "norm_const_tensor": norm_const_tensor, - "prob_tensor": scales.detach().to(dtype=dtype).reshape(-1, 1, 1), + "norm_const_tensor": fc1_norm_const_tensor, + "prob_tensor": fc1_prob_tensor, "acc_dtype": torch.float32, "c_dtype": torch.bfloat16, - "d_dtype": torch.float8_e4m3fn, + "d_dtype": fc1_d_dtype, "cd_major": "n", - "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "sf_vec_size": sf_vec_size, "current_stream": current_stream, - "discrete_col_sfd": True, + "discrete_col_sfd": not use_nvfp4, "use_dynamic_sched": True, } if self._cudnn_act_func is not None: @@ -363,15 +431,15 @@ def fuser_forward( # Data actual shape: (num_groups, n, k) # Data logical shape: (n, k, num_groups) fc1_w_data = fc1_weight_for_gemm.rowwise_data - fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) - fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.view(dtype=data_dtype) + fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_k) fc1_w_data = fc1_w_data.permute(1, 2, 0) - fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) + fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=scale_view_dtype) fc1_w_scales = fc1_w_scales.view( num_groups, - (fc1_weight_shape[0] + 127) // 128, - (fc1_weight_shape[1] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, + ceil_div(fc1_weight_shape[0], 128), + ceil_div(fc1_weight_shape[1], k_sf_divisor), + 32, 4, 4, ) @@ -385,15 +453,16 @@ def fuser_forward( [w._rowwise_data for w in grouped_fc1_weight], device, ) + swizzle_type = "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" fc1_sfb_ptrs, _fc1_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - "uniform_mxfp8_rowwise_swizzle", + swizzle_type, [w._rowwise_scale_inv for w in grouped_fc1_weight], device, ) fc1_activation_kwargs["b_ptrs"] = fc1_b_ptrs fc1_activation_kwargs["sfb_ptrs"] = fc1_sfb_ptrs fc1_activation_kwargs["n"] = fc1_weight_shape[0] - fc1_activation_kwargs["b_dtype"] = torch.float8_e4m3fn + fc1_activation_kwargs["b_dtype"] = data_dtype fc1_activation_kwargs["b_major"] = "k" fc1_kernel_out = self.grouped_gemm_activation_kernel()(**fc1_activation_kwargs) @@ -409,96 +478,173 @@ def fuser_forward( # k/128, 4 (block row), sum(m_splits)/128, 1) activation_in = fc1_kernel_out["c_tensor"] activation_in = activation_in.view(in_shape[0], fc1_weight_shape[0]) - fc2_in_row_data = fc1_kernel_out["d_tensor"] - fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1]) - fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] - fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 4, 0, 1, 3) - - fc2_in_col_data = fc1_kernel_out["d_col_tensor"] - fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1]) - fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] - fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) - # Repack columnwise scales on GPU to preserve group ordering. - - # FC2 inputs scales are already swizzled/optimized for GEMM - grouped_fc2_x = GroupedTensor( - shape=(in_shape[0], fc2_weight_shape[1]), - dtype=dtype, - num_tensors=num_groups, - quantizer=fc2_input_quantizer, - data=fc2_in_row_data.reshape(-1), - columnwise_data=fc2_in_col_data.reshape(-1), - scale_inv=fc2_in_row_scale.reshape(-1), - columnwise_scale_inv=fc2_in_col_scale.reshape(-1), - first_dims=split_sizes, - tensor_offsets=fc2_x_tensor_offsets, - with_gemm_swizzled_scales=True, - ) # FC2 GEMM fc2_out_shape = in_shape[:-1] + [fc2_weight_shape[0]] fc2_scales = basic_op_extra_inputs[2][1] if fc2_op._scale_bias else None - fc2_scales_tensor = ( - fc2_scales.detach().to(dtype=torch.float32).reshape(-1, 1, 1) - if fc2_scales is not None - else torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device) - ) - fc2_quant_kwargs = { - "a_tensor": fc1_kernel_out["d_tensor"], - "sfa_tensor": fc1_kernel_out["sfd_row_tensor"], - "padded_offsets": split_points, - "alpha_tensor": alpha_tensor, - "bias_tensor": fc2_bias_packed, - "norm_const_tensor": None, - "prob_tensor": fc2_scales_tensor, - "acc_dtype": torch.float32, - "d_dtype": dtype, - "cd_major": "n", - "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, - "current_stream": current_stream, - "use_dynamic_sched": True, - } - if fc2_op.single_grouped_weight: - # Clone and swizzle scales for GEMM (original stays unmodified for save_for_backward) - fc2_weight_for_gemm = grouped_fc2_weight.copy() - tex.grouped_swizzle_for_gemm(fc2_weight_for_gemm, rowwise=True, columnwise=False) - - fc2_w_data = fc2_weight_for_gemm.rowwise_data - fc2_w_data = fc2_w_data.view(dtype=torch.float8_e4m3fn) - fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1]) - fc2_w_data = fc2_w_data.permute(1, 2, 0) - - fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) - fc2_w_scales = fc2_w_scales.view( + if use_nvfp4: + fc2_bias_for_gemm = None + fc2_bias_scale = None + if fc2_bias_packed is not None: + fc2_bias_for_gemm = fc2_op._get_grouped_bias_for_gemm(dtype) + if fc2_scales is not None: + fc2_bias_scale = fc2_scales.reshape(-1) + if fc2_bias_scale.dtype != torch.float32: + fc2_bias_scale = fc2_bias_scale.to(dtype=torch.float32) + + fc2_in = fc1_kernel_out["d_tensor"] + fc2_in = fc2_in.view(in_shape[0], fc2_weight_shape[1]).contiguous() + fc2_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + fc2_input_quantizer.optimize_for_gemm = True + grouped_fc2_x = _group_quantize_for_grouped_mlp( + fc2_in, + fc2_input_quantizer, num_groups, - (fc2_weight_shape[0] + 127) // 128, - (fc2_weight_shape[1] + 127) // 128, - MXFP8_BLOCK_SCALING_SIZE, - 4, - 4, + split_sizes, + tensor_offsets=fc2_x_tensor_offsets, ) - fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) - fc2_quant_kwargs["b_tensor"] = fc2_w_data - fc2_quant_kwargs["sfb_tensor"] = fc2_w_scales + + fc2_out_buf = torch.empty(fc2_out_shape, dtype=dtype, device=device) + if ( + num_groups == 1 + and grouped_fc2_x.columnwise_data is not None + and grouped_fc2_x.columnwise_scale_inv is not None + ): + if fc2_op.single_grouped_weight: + fc2_w_single = grouped_fc2_weight.split_into_quantized_tensors()[0] + else: + fc2_w_single = grouped_fc2_weight[0] + fc2_x_single = _nvfp4_single_tensor_from_grouped( + grouped_fc2_x, + fc2_input_quantizer, + fp4_dtype=fc2_w_single._fp4_dtype, + ) + general_gemm( + fc2_w_single, + fc2_x_single, + out_dtype=dtype, + out=fc2_out_buf, + layout="TN", + use_split_accumulator=False, + ) + if fc2_bias_packed is not None: + token_bias = ( + fc2_bias_packed.transpose(0, 1).contiguous().expand(in_shape[0], -1) + ) + if fc2_scales is not None: + fc2_out_buf = fc2_out_buf + token_bias * fc2_scales.view(-1, 1) + else: + fc2_out_buf = fc2_out_buf + token_bias + else: + fc2_out_offsets = base_split_offsets * fc2_weight_shape[0] + fc2_out_grouped = GroupedTensor( + shape=(in_shape[0], fc2_weight_shape[0]), + dtype=dtype, + num_tensors=num_groups, + quantizer=None, + data=fc2_out_buf.view(-1), + first_dims=split_sizes, + tensor_offsets=fc2_out_offsets, + ) + general_grouped_gemm_for_grouped_tensor( + grouped_fc2_weight, + grouped_fc2_x, + fc2_out_grouped, + layout="TN", + bias=fc2_bias_for_gemm, + bias_scale=fc2_bias_scale, + ) + fc2_out = fc2_out_buf else: - fc2_b_ptrs = tex.copy_data_ptrs_to_device( - [w._rowwise_data for w in grouped_fc2_weight], - device, + fc2_in_row_data = fc1_kernel_out["d_tensor"] + fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1]) + fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] + fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 4, 0, 1, 3) + + fc2_in_col_data = fc1_kernel_out["d_col_tensor"] + fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1]) + fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] + fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) + + grouped_fc2_x = GroupedTensor( + shape=(in_shape[0], fc2_weight_shape[1]), + dtype=dtype, + num_tensors=num_groups, + quantizer=fc2_input_quantizer, + data=fc2_in_row_data.reshape(-1), + columnwise_data=fc2_in_col_data.reshape(-1), + scale_inv=fc2_in_row_scale.reshape(-1), + columnwise_scale_inv=fc2_in_col_scale.reshape(-1), + first_dims=split_sizes, + tensor_offsets=fc2_x_tensor_offsets, + with_gemm_swizzled_scales=True, ) - fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( - "uniform_mxfp8_rowwise_swizzle", - [w._rowwise_scale_inv for w in grouped_fc2_weight], - device, + + fc2_scales_tensor = ( + fc2_scales.detach().to(dtype=torch.float32).reshape(-1, 1, 1) + if fc2_scales is not None + else torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device) ) - fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs - fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs - fc2_quant_kwargs["n"] = fc2_weight_shape[0] - fc2_quant_kwargs["b_dtype"] = torch.float8_e4m3fn - fc2_quant_kwargs["b_major"] = "k" + fc2_quant_kwargs = { + "a_tensor": fc1_kernel_out["d_tensor"], + "sfa_tensor": fc1_kernel_out["sfd_row_tensor"], + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor, + "bias_tensor": fc2_bias_packed, + "norm_const_tensor": None, + "prob_tensor": fc2_scales_tensor, + "acc_dtype": torch.float32, + "d_dtype": dtype, + "cd_major": "n", + "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "use_dynamic_sched": True, + } + + if fc2_op.single_grouped_weight: + # Clone and swizzle scales for GEMM (original stays unmodified for save_for_backward) + fc2_weight_for_gemm = grouped_fc2_weight.copy() + tex.grouped_swizzle_for_gemm(fc2_weight_for_gemm, rowwise=True, columnwise=False) + + fc2_w_data = fc2_weight_for_gemm.rowwise_data + fc2_w_data = fc2_w_data.view(dtype=torch.float8_e4m3fn) + fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1]) + fc2_w_data = fc2_w_data.permute(1, 2, 0) + + fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu) + fc2_w_scales = fc2_w_scales.view( + num_groups, + ceil_div(fc2_weight_shape[0], 128), + ceil_div(fc2_weight_shape[1], 128), + MXFP8_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + fc2_quant_kwargs["b_tensor"] = fc2_w_data + fc2_quant_kwargs["sfb_tensor"] = fc2_w_scales + else: + fc2_b_ptrs = tex.copy_data_ptrs_to_device( + [w._rowwise_data for w in grouped_fc2_weight], + device, + ) + swizzle_type = ( + "uniform_nvfp4_swizzle" if use_nvfp4 else "uniform_mxfp8_rowwise_swizzle" + ) + fc2_sfb_ptrs, _fc2_sfb_buffer = tex.transform_and_copy_data_ptrs_to_device( + swizzle_type, + [w._rowwise_scale_inv for w in grouped_fc2_weight], + device, + ) + fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs + fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs + fc2_quant_kwargs["n"] = fc2_weight_shape[0] + fc2_quant_kwargs["b_dtype"] = torch.float8_e4m3fn + fc2_quant_kwargs["b_major"] = "k" - fc2_kernel_out = self.grouped_gemm_quant_kernel()(**fc2_quant_kwargs) - fc2_out = fc2_kernel_out["d_tensor"].permute(2, 0, 1).view(fc2_out_shape).contiguous() + fc2_kernel_out = self.grouped_gemm_quant_kernel()(**fc2_quant_kwargs) + fc2_out = fc2_kernel_out["d_tensor"].permute(2, 0, 1).view(fc2_out_shape).contiguous() # Save state for backward pass if requires_grad: @@ -517,11 +663,13 @@ def fuser_forward( ) saved_grouped_fc2_x = None if recompute_srelu_fc2_x else grouped_fc2_x - # Save the input ``GroupedTensor``s themselves for the activations. - for grouped_fc_x in (grouped_fc1_x, saved_grouped_fc2_x): - if grouped_fc_x is not None: - grouped_fc_x.rowwise_data = None - grouped_fc_x.scale_inv = None + # MXFP8 wgrad only needs columnwise tiles. NVFP4 generic GEMM fallbacks + # need the full grouped tensor state, including rowwise data and amax. + if not use_nvfp4: + for grouped_fc_x in (grouped_fc1_x, saved_grouped_fc2_x): + if grouped_fc_x is not None: + grouped_fc_x.rowwise_data = None + grouped_fc_x.scale_inv = None # FC1 saved-tensor layout. # [split_sizes, base_split_offsets, split_points, @@ -586,8 +734,8 @@ def fuser_forward( return fc2_out, [(), (), ()] -class ForwardGroupedMLP_CuTeGEMMGLU_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase_MXFP8): - """Fused op for MXFP8 GroupedLinear + scaled GLU + GroupedLinear.""" +class ForwardGroupedMLP_CuTeGEMMGLU(_ForwardGroupedMLP_CuTeGEMMBase): + """Fused op for block-scaled GroupedLinear + scaled GLU + GroupedLinear.""" @classmethod @functools.lru_cache(maxsize=None) @@ -598,8 +746,8 @@ def grouped_gemm_activation_kernel(cls) -> Callable: return grouped_gemm_glu_wrapper_sm100 -class ForwardGroupedMLP_CuTeGEMMUnary_MXFP8(_ForwardGroupedMLP_CuTeGEMMBase_MXFP8): - """Fused op for MXFP8 GroupedLinear + scaled unary activation + GroupedLinear.""" +class ForwardGroupedMLP_CuTeGEMMUnary(_ForwardGroupedMLP_CuTeGEMMBase): + """Fused op for block-scaled GroupedLinear + scaled unary activation + GroupedLinear.""" @classmethod @functools.lru_cache(maxsize=None) @@ -641,7 +789,7 @@ def fuse_forward_ops( return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=ForwardGroupedMLP_CuTeGEMMGLU_MXFP8, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMGLU, ) @@ -653,16 +801,18 @@ def fuse_forward_srelu_ops( ) -> list[FusibleOperation]: """Apply GroupedLinear + ScaledSReLU + GroupedLinear fusion for forward pass.""" + if recipe is None or not recipe.mxfp8(): + return ops return fuse_grouped_mlp_ops( ops, recipe=recipe, - fused_op_cls=ForwardGroupedMLP_CuTeGEMMUnary_MXFP8, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMUnary, activation_op_types=(ScaledSReLU,), ) # Register fusion if available -if ForwardGroupedMLP_CuTeGEMMGLU_MXFP8.is_supported(): +if ForwardGroupedMLP_CuTeGEMMGLU.is_supported(): register_forward_fusion(fuse_forward_ops, prepend=True) -if ForwardGroupedMLP_CuTeGEMMUnary_MXFP8.is_supported(): +if ForwardGroupedMLP_CuTeGEMMUnary.is_supported(): register_forward_fusion(fuse_forward_srelu_ops, prepend=True) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 250daec67f..fd8f817b33 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -626,8 +626,15 @@ def get_sm_count() -> int: return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count +def ceil_div(numerator, denominator): + """Integer ceiling division: ``ceil(numerator / denominator)``.""" + if denominator == 0: + raise ValueError("denominator cannot be zero.") + return (numerator + denominator - 1) // denominator + + def round_up_to_nearest_multiple(value, multiple): - """Round up `value` to the next mutiple of `multiple`""" + """Round up `value` to the next multiple of `multiple`""" if multiple == 0: raise ValueError("multiple cannot be zero.") return ((value + multiple - 1) // multiple) * multiple