From f25a1f583c71b26db4a6dcf4e65aed5a1e203033 Mon Sep 17 00:00:00 2001 From: hongbinl Date: Tue, 2 Jun 2026 03:58:34 -0700 Subject: [PATCH] Support selective offload for fused grouped MLP Signed-off-by: hongbinl --- .../pytorch/ops/fused/forward_grouped_mlp.py | 47 +++++++++++++++++-- .../tensor/storage/grouped_tensor_storage.py | 15 ++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index f4f2108578..654ca50b6c 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,6 +13,7 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_not_offload from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer @@ -23,6 +24,7 @@ mark_grouped_tensor, ) from ...tensor.grouped_tensor import GroupedTensor +from ...tensor.storage.grouped_tensor_storage import GroupedTensorStorage from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU @@ -298,6 +300,7 @@ def fuser_forward( # Group-quantize input tensor and convert dtypes if needed fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc1_input_quantizer.optimize_for_gemm = True + fc1_input_quantizer.internal = True input_quantizer = getattr(input_, "quantizer", None) if isinstance(input_, GroupedTensor) and ( isinstance(fc1_input_quantizer, MXFP8Quantizer) @@ -305,7 +308,30 @@ def fuser_forward( or isinstance(fc1_input_quantizer, NVFP4Quantizer) and isinstance(input_quantizer, NVFP4Quantizer) ): - grouped_fc1_x = input_ + grouped_fc1_x = GroupedTensorStorage( + shape=input_.logical_shape, + dtype=input_.fake_dtype, + num_tensors=input_.num_tensors, + shapes=input_.tensor_shapes, + quantizer=input_.quantizer, + data=input_.rowwise_data, + columnwise_data=input_.columnwise_data, + scale_inv=input_.scale_inv, + columnwise_scale_inv=input_.columnwise_scale_inv, + amax=input_.amax, + columnwise_amax=input_.columnwise_amax, + scale=input_.scale, + first_dims=input_.first_dims, + last_dims=input_.last_dims, + tensor_offsets=input_.tensor_offsets, + offsets=input_.offsets, + scale_inv_offsets=input_.scale_inv_offsets, + columnwise_scale_inv_offsets=input_.columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=input_._with_gemm_swizzled_scales, + row_scaled_nvfp4=input_.row_scaled_nvfp4, + nvfp4_use_4over6=input_.nvfp4_use_4over6, + nvfp4_e4m3_max=input_.nvfp4_e4m3_max, + ) else: fc1_x = maybe_dequantize(input_, dtype) grouped_fc1_x = _group_quantize_for_grouped_mlp( @@ -567,7 +593,7 @@ def fuser_forward( fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) - grouped_fc2_x = GroupedTensor( + grouped_fc2_x = GroupedTensorStorage( shape=(in_shape[0], fc2_weight_shape[1]), dtype=dtype, num_tensors=num_groups, @@ -650,6 +676,9 @@ def fuser_forward( if requires_grad: mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] + cpu_offloading = is_cpu_offload_enabled() + no_offload_fc1_activation = bool(getattr(fc1_op, "no_offload_activation", False)) + no_offload_moe_activation = bool(getattr(activation_op, "no_offload_activation", False)) activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) @@ -677,6 +706,10 @@ def fuser_forward( fc1_weight_tensors = ( [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight ) + if cpu_offloading: + if no_offload_fc1_activation: + mark_not_offload(grouped_fc1_x) + mark_not_offload(*fc1_weight_tensors) fc1_ctx.save_for_backward( split_sizes, base_split_offsets, @@ -695,6 +728,8 @@ def fuser_forward( fc1_ctx.weight_requires_grad = weight_requires_grad # Activation + if cpu_offloading and no_offload_moe_activation: + mark_not_offload(activation_in, scales) activation_ctx.save_for_backward(activation_in, scales) activation_ctx.extra_input_requires_grad = True activation_ctx.input_requires_grad = True @@ -710,7 +745,13 @@ def fuser_forward( fc2_weight_tensors = ( [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight ) - fc2_saved: list[Optional[torch.Tensor]] = [ + if cpu_offloading: + if saved_grouped_fc2_x is not None: + # FC2 input is saved for FC2 wgrad, but it is not the Megatron moe_act + # activation target controlled above. Keep this extra saved tensor resident. + mark_not_offload(saved_grouped_fc2_x) + mark_not_offload(*fc2_weight_tensors) + fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [ split_sizes, base_split_offsets, split_points, diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 438e124021..c112634024 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -387,6 +387,21 @@ def restore_from_saved( self.tensor_offsets = tensors[9] return tensors[10:] + def get_data_tensors(self): + """Get tensor fields that may be saved or offloaded.""" + return ( + self.rowwise_data, + self.columnwise_data, + self.scale_inv, + self.columnwise_scale_inv, + self.amax, + self.columnwise_amax, + self.scale, + self.first_dims, + self.last_dims, + self.tensor_offsets, + ) + def clear(self) -> None: """ Reset tensor data and clear all buffers.