diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 034d404439..b5b237a398 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,10 +13,12 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_not_offload from ...quantization import Recipe from ...tensor import Quantizer from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor from ...tensor.grouped_tensor import GroupedTensor +from ...tensor.storage.grouped_tensor_storage import GroupedTensorStorage from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...constants import MXFP8_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU @@ -288,10 +290,32 @@ def fuser_forward( # Group-quantize input tensor and convert dtypes if needed fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc1_input_quantizer.optimize_for_gemm = True + fc1_input_quantizer.internal = True if isinstance(input_, GroupedTensor) and isinstance( getattr(input_, "quantizer", None), MXFP8Quantizer ): - grouped_fc1_x = input_ + grouped_fc1_x = GroupedTensorStorage( + shape=input_.logical_shape, + dtype=input_.fake_dtype, + num_tensors=input_.num_tensors, + shapes=input_.tensor_shapes, + quantizer=input_.quantizer, + data=input_.rowwise_data, + columnwise_data=input_.columnwise_data, + scale_inv=input_.scale_inv, + columnwise_scale_inv=input_.columnwise_scale_inv, + amax=input_.amax, + columnwise_amax=input_.columnwise_amax, + scale=input_.scale, + first_dims=input_.first_dims, + last_dims=input_.last_dims, + tensor_offsets=input_.tensor_offsets, + offsets=input_.offsets, + scale_inv_offsets=input_.scale_inv_offsets, + columnwise_scale_inv_offsets=input_.columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=input_._with_gemm_swizzled_scales, + row_scaled_nvfp4=input_.row_scaled_nvfp4, + ) else: fc1_x = maybe_dequantize(input_, dtype) grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes) @@ -419,7 +443,7 @@ def fuser_forward( # Repack columnwise scales on GPU to preserve group ordering. # FC2 inputs scales are already swizzled/optimized for GEMM - grouped_fc2_x = GroupedTensor( + grouped_fc2_x = GroupedTensorStorage( shape=(in_shape[0], fc2_weight_shape[1]), dtype=dtype, num_tensors=num_groups, @@ -500,6 +524,11 @@ def fuser_forward( if requires_grad: mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] + cpu_offloading = is_cpu_offload_enabled() + offload_fc1_input = bool(getattr(fc1_op, "fine_grained_activation_offloading", False)) + offload_activation_input = bool( + getattr(activation_op, "fine_grained_activation_offloading", False) + ) activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) @@ -513,7 +542,7 @@ def fuser_forward( ) saved_grouped_fc2_x = None if recompute_srelu_fc2_x else grouped_fc2_x - # Save the input ``GroupedTensor``s themselves for the activations. + # Save the input grouped tensor storages themselves for the activations. for grouped_fc_x in (grouped_fc1_x, saved_grouped_fc2_x): if grouped_fc_x is not None: grouped_fc_x.rowwise_data = None @@ -525,6 +554,10 @@ def fuser_forward( fc1_weight_tensors = ( [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight ) + if cpu_offloading: + if not offload_fc1_input: + mark_not_offload(grouped_fc1_x) + mark_not_offload(*fc1_weight_tensors) fc1_ctx.save_for_backward( split_sizes, base_split_offsets, @@ -543,6 +576,8 @@ def fuser_forward( fc1_ctx.weight_requires_grad = weight_requires_grad # Activation + if cpu_offloading and not offload_activation_input: + mark_not_offload(activation_in, scales) activation_ctx.save_for_backward(activation_in, scales) activation_ctx.extra_input_requires_grad = True activation_ctx.input_requires_grad = True @@ -558,7 +593,11 @@ def fuser_forward( fc2_weight_tensors = ( [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight ) - fc2_saved: list[Optional[torch.Tensor]] = [ + if cpu_offloading: + if saved_grouped_fc2_x is not None: + mark_not_offload(saved_grouped_fc2_x) + mark_not_offload(*fc2_weight_tensors) + fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [ split_sizes, base_split_offsets, split_points, diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 438e124021..c112634024 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -387,6 +387,21 @@ def restore_from_saved( self.tensor_offsets = tensors[9] return tensors[10:] + def get_data_tensors(self): + """Get tensor fields that may be saved or offloaded.""" + return ( + self.rowwise_data, + self.columnwise_data, + self.scale_inv, + self.columnwise_scale_inv, + self.amax, + self.columnwise_amax, + self.scale, + self.first_dims, + self.last_dims, + self.tensor_offsets, + ) + def clear(self) -> None: """ Reset tensor data and clear all buffers.