diff --git a/benchmarks/benchmark_rht_cast.py b/benchmarks/benchmark_rht_cast.py index badab1d199..46bf342a84 100644 --- a/benchmarks/benchmark_rht_cast.py +++ b/benchmarks/benchmark_rht_cast.py @@ -8,16 +8,16 @@ import torch.utils.benchmark as benchmark import transformer_engine.pytorch as te -import transformer_engine_torch as tex import transformer_engine.pytorch.cpp_extensions as ext +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer scale_padding_to = 1 permute_scale = False TORCH_TO_TE_FLOAT_MAP = { - torch.bfloat16: tex.DType.kBFloat16, + torch.bfloat16: TE_DType.kBFloat16, } @@ -31,7 +31,7 @@ def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16): # Quantize nvfp4_quantizer = NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, + fp4_dtype=TE_DType.kFloat4E2M1, rowwise=True, columnwise=True, with_amax_reduction=False, diff --git a/docs/examples/quickstart_utils.py b/docs/examples/quickstart_utils.py index 9b21807255..df975b9e3a 100644 --- a/docs/examples/quickstart_utils.py +++ b/docs/examples/quickstart_utils.py @@ -204,9 +204,9 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model): def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"): from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer - import transformer_engine_torch as tex + from transformer_engine.pytorch.constants import TE_DType - fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2 + fp8_type = TE_DType.kFloat8E4M3 if fp8_format == "e4m3" else TE_DType.kFloat8E5M2 scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda") quantizer = Float8Quantizer(scale=scale, amax=amax_history, fp8_dtype=fp8_type) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 9f6b4944e6..5370fab7bb 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -14,6 +14,7 @@ ) from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch import ( autocast, @@ -323,34 +324,34 @@ def run_dpa_with_cp( ).cuda() if scaling_mode == "delayed": qkv_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, scale=torch.tensor([1], dtype=torch.float32).cuda(), amax=torch.tensor([0], dtype=torch.float32).cuda(), ) dout_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, + fp8_dtype=TE_DType.kFloat8E5M2, scale=torch.tensor([1], dtype=torch.float32).cuda(), amax=torch.tensor([0], dtype=torch.float32).cuda(), ) if scaling_mode == "current": qkv_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, device="cuda", ) dout_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E5M2, + fp8_dtype=TE_DType.kFloat8E5M2, device="cuda", ) if scaling_mode == "mxfp8": qkv_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=True, ) qkv_quantizer.optimize_for_gemm = True qkv_quantizer.internal = False dout_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, + fp8_dtype=TE_DType.kFloat8E5M2, rowwise=True, columnwise=True, ) diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index 285ec7ba0c..b6e6b64fe5 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -13,7 +13,7 @@ import torch import torch.distributed as dist import transformer_engine -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType import nvdlfw_inspect.api as debug_api from transformer_engine.debug import set_weight_tensor_tp_group_reduce from transformer_engine.pytorch import is_fp8_available @@ -683,7 +683,7 @@ def _run_test_with_combinations( ) # test_fake_quant_fp8 - dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None] + dtype_options = [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2, None] _run_test_with_combinations( test_fake_quant_fp8, dtype_options, diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py index 5387634cb3..fc95171351 100644 --- a/tests/pytorch/debug/test_api_features.py +++ b/tests/pytorch/debug/test_api_features.py @@ -4,12 +4,12 @@ import torch from transformer_engine.pytorch import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch.constants import TE_DType import nvdlfw_inspect.api as debug_api try: import transformer_engine - import transformer_engine_torch as tex except (ImportError, ModuleNotFoundError): print("Could not find TransformerEngine package.") exit(1) @@ -128,12 +128,12 @@ def test_per_tensor_scaling(configs_dir, feature_dirs): default_quantizer1 = Float8Quantizer( scale=torch.tensor([1]).cuda(), amax=torch.tensor([0]).cuda(), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) default_quantizer2 = Float8Quantizer( scale=torch.tensor([1]).cuda(), amax=torch.tensor([0]).cuda(), - fp8_dtype=tex.DType.kFloat8E5M2, + fp8_dtype=TE_DType.kFloat8E5M2, ) output1 = debug_api.transformer_engine.modify_tensor( @@ -145,7 +145,7 @@ def test_per_tensor_scaling(configs_dir, feature_dirs): tensor=tensor, ) assert type(output1) == Float8Tensor - assert output1._fp8_dtype == tex.DType.kFloat8E4M3 + assert output1._fp8_dtype == TE_DType.kFloat8E4M3 output2 = debug_api.transformer_engine.modify_tensor( "decoder.1.mlp.fc1", @@ -156,7 +156,7 @@ def test_per_tensor_scaling(configs_dir, feature_dirs): iteration=0, ) assert type(output2) == Float8Tensor - assert output2._fp8_dtype == tex.DType.kFloat8E5M2 + assert output2._fp8_dtype == TE_DType.kFloat8E5M2 assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", @@ -234,7 +234,7 @@ def test_statistics_collection(configs_dir, feature_dirs): quantizer = Float8Quantizer( scale=torch.full([1], 1.0).cuda(), amax=torch.full([1], 1.0).cuda(), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) tensor_fp8 = quantizer(tensor) @@ -372,7 +372,7 @@ def log_stats(): quantizer = Float8Quantizer( scale=torch.full([1], 1.0).cuda(), amax=torch.full([1], 1.0).cuda(), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) def fp8_tensor(t): diff --git a/tests/pytorch/debug/test_numerics.py b/tests/pytorch/debug/test_numerics.py index ab9a2d054a..26cfaa1d44 100644 --- a/tests/pytorch/debug/test_numerics.py +++ b/tests/pytorch/debug/test_numerics.py @@ -15,7 +15,7 @@ import nvdlfw_inspect.api as debug_api import transformer_engine.debug import transformer_engine.pytorch as tepytorch -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.pytorch.quantization import _default_sf_compute from transformer_engine.pytorch import ( @@ -57,7 +57,7 @@ def _cast_to_fp8(tensor, scale, dtype): def _get_current_scale(tensor, fp8_dtype): - if fp8_dtype == tex.DType.kFloat8E4M3: + if fp8_dtype == TE_DType.kFloat8E4M3: fp8_max = Format.E4M3.value.max_fwd else: fp8_max = Format.E5M2.value.max_fwd @@ -93,19 +93,19 @@ def _emulate_linear( input: torch.Tensor, weight: torch.Tensor, fprop_fp8: bool = False, - fprop_input_fake_quant: tex.DType = None, + fprop_input_fake_quant: TE_DType = None, fprop_input_scale: torch.Tensor = None, - fprop_weight_fake_quant: tex.DType = None, + fprop_weight_fake_quant: TE_DType = None, fprop_weight_scale: torch.Tensor = None, dgrad_fp8: bool = False, - dgrad_gradient_fake_quant: tex.DType = None, + dgrad_gradient_fake_quant: TE_DType = None, dgrad_gradient_scale: torch.Tensor = None, - dgrad_weight_fake_quant: tex.DType = None, + dgrad_weight_fake_quant: TE_DType = None, dgrad_weight_scale: torch.Tensor = None, wgrad_fp8: bool = False, - wgrad_gradient_fake_quant: tex.DType = None, + wgrad_gradient_fake_quant: TE_DType = None, wgrad_gradient_scale: torch.Tensor = None, - wgrad_input_fake_quant: tex.DType = None, + wgrad_input_fake_quant: TE_DType = None, wgrad_input_scale: torch.Tensor = None, loss_multiplier: float = 1.0, activation_sync=None, @@ -116,10 +116,10 @@ def _emulate_linear( activation = _fp8_gemm_kernel( input, _scalar(fprop_input_scale or 1.0), - tex.DType.kFloat8E4M3, + TE_DType.kFloat8E4M3, weight, _scalar(fprop_weight_scale or 1.0), - tex.DType.kFloat8E4M3, + TE_DType.kFloat8E4M3, _2X_ACC_FPROP, ) activation = activation.clone().detach().contiguous().requires_grad_(True) @@ -152,10 +152,10 @@ def _emulate_linear( dgrad = _fp8_gemm_kernel( weight.T, _scalar(dgrad_weight_scale or 1.0), - tex.DType.kFloat8E4M3, + TE_DType.kFloat8E4M3, gradient, _scalar(dgrad_gradient_scale or 1.0), - tex.DType.kFloat8E5M2, + TE_DType.kFloat8E5M2, _2X_ACC_DGRAD, ).T else: @@ -176,10 +176,10 @@ def _emulate_linear( wgrad = _fp8_gemm_kernel( input.T, _scalar(wgrad_input_scale or 1.0), - tex.DType.kFloat8E4M3, + TE_DType.kFloat8E4M3, gradient.T, _scalar(wgrad_gradient_scale or 1.0), - tex.DType.kFloat8E5M2, + TE_DType.kFloat8E5M2, _2X_ACC_WGRAD, ).T else: @@ -470,17 +470,17 @@ def set_scaling_factors(model, input_kwargs, fp8_kwargs): def set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs): # Compute per tensor scaling factor if respective flag in input_kwargs is set. if input_kwargs["fprop_inp"]: - fp8_kwargs["fprop_input_scale"] = tex.DType.kFloat8E4M3 + fp8_kwargs["fprop_input_scale"] = TE_DType.kFloat8E4M3 if input_kwargs["fprop_weight"]: - fp8_kwargs["fprop_weight_scale"] = tex.DType.kFloat8E4M3 + fp8_kwargs["fprop_weight_scale"] = TE_DType.kFloat8E4M3 if input_kwargs["dgrad_grad"]: - fp8_kwargs["dgrad_gradient_scale"] = tex.DType.kFloat8E5M2 + fp8_kwargs["dgrad_gradient_scale"] = TE_DType.kFloat8E5M2 if input_kwargs["dgrad_weight"]: - fp8_kwargs["dgrad_weight_scale"] = tex.DType.kFloat8E4M3 + fp8_kwargs["dgrad_weight_scale"] = TE_DType.kFloat8E4M3 if input_kwargs["wgrad_grad"]: - fp8_kwargs["wgrad_gradient_scale"] = tex.DType.kFloat8E5M2 + fp8_kwargs["wgrad_gradient_scale"] = TE_DType.kFloat8E5M2 if input_kwargs["wgrad_input"]: - fp8_kwargs["wgrad_input_scale"] = tex.DType.kFloat8E4M3 + fp8_kwargs["wgrad_input_scale"] = TE_DType.kFloat8E4M3 @create_config_file @@ -651,7 +651,7 @@ def init_and_warmup(): all_combinations = list( - itertools.product([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None], repeat=6) + itertools.product([TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2, None], repeat=6) ) subset_combinations = random.sample(all_combinations, 10) @@ -687,7 +687,7 @@ def test_fake_quant_fp8( def fake_quant_fp8_create_config( fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file ): - format_to_str = {tex.DType.kFloat8E4M3: "FP8E4M3", tex.DType.kFloat8E5M2: "FP8E5M2"} + format_to_str = {TE_DType.kFloat8E4M3: "FP8E4M3", TE_DType.kFloat8E5M2: "FP8E5M2"} gemms = "" def _add_tensor(quant_format, tensor): diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 96a7e43231..87b6b1309a 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -25,6 +25,7 @@ ) import transformer_engine.pytorch.cpp_extensions as tex from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -473,7 +474,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None if opts.quantization == "fp8": # Structure to maintain amax and scale/scale_inv information for the kernel and input num_gemms = 6 if ub_obj2 is not None else 3 - fp8_dtype = tex.DType.kFloat8E4M3 + fp8_dtype = TE_DType.kFloat8E4M3 fp8_scales = torch.ones(num_gemms, dtype=torch.float, device="cuda") fp8_amaxes = torch.zeros(num_gemms, dtype=torch.float, device="cuda") @@ -516,7 +517,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype ) elif opts.quantization == "mxfp8": - fp8_dtype = tex.DType.kFloat8E4M3 + fp8_dtype = TE_DType.kFloat8E4M3 inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False) ker_quantizer = MXFP8Quantizer(fp8_dtype) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 8e24e636e8..d8b223ae5c 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -15,7 +15,6 @@ import torch from torch import nn import torch.distributed as dist -import transformer_engine_torch as tex from transformer_engine.common.recipe import ( MXFP8BlockScaling, DelayedScaling, @@ -27,7 +26,7 @@ QParams, ) from transformer_engine.pytorch import Float8CurrentScalingQuantizer, NVFP4Quantizer -from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE, TE_DType from transformer_engine.pytorch.distributed import gather_along_first_dim from run_layer_with_overlap import _compare_tensors @@ -399,7 +398,7 @@ def _test_quantizer(input_dtype, fp8_dtype): Args: input_dtype (torch.dtype): The data type of the input. - fp8_dtype (tex.DType): The data type of the fp8. + fp8_dtype (TE_DType): The data type of the fp8. """ M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE @@ -443,7 +442,7 @@ def test_quantizer(): return input_dtypes = [torch.float32, torch.bfloat16] - fp8_dtypes = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + fp8_dtypes = [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2] for input_dtype in input_dtypes: for fp8_dtype in fp8_dtypes: @@ -514,7 +513,7 @@ def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls): Args: input_dtype (torch.dtype): The data type of the input. - low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8. + low_precision_dtype (TE_DType): The data type of the low precision, can be fp4 or fp8. """ M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2 @@ -623,8 +622,8 @@ def test_quantized_all_gather(): return input_dtypes = [torch.bfloat16] - fp4_dtype = [tex.DType.kFloat4E2M1] - fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + fp4_dtype = [TE_DType.kFloat4E2M1] + fp8_dtype = [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2] quantizer_cls_nvfp4 = [NVFP4Quantizer] # add FP8 quantizers if needed quantizer_cls_fp8 = [] diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index c484038938..8903db46b4 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -29,7 +29,7 @@ is_bf16_available, ) import transformer_engine.pytorch.ops as te_ops -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType # Import utility functions _current_file = pathlib.Path(__file__).resolve() @@ -107,17 +107,17 @@ def make_reference_and_test_tensors( quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) test = quantizer(test) elif quantization == "fp8_current_scaling": quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, device=test_device, ) test = quantizer(test) elif quantization == "mxfp8": - test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + test = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3)(test) elif quantization == "nvfp4": test = NVFP4Quantizer( with_rht=False, diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 3dcefd46fd..09ef23ae05 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -22,6 +22,7 @@ import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.ops.fused import ( UserbuffersBackwardLinear, UserbuffersForwardLinear, @@ -156,17 +157,17 @@ def make_reference_and_test_tensors( quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) test = quantizer(test) elif quantization == "fp8_current_scaling": quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, device=test_device, ) test = quantizer(test) elif quantization == "mxfp8": - test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + test = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3)(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") if isinstance(test, QuantizedTensor) and not test_is_quantized: @@ -372,7 +373,7 @@ def _test_linear( tols = dtype_tols( model[0].weight._fp8_dtype if isinstance(model[0].weight, Float8Tensor) - else tex.DType.kFloat8E4M3 + else TE_DType.kFloat8E4M3 ) # Check results diff --git a/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py index c2f8e8de12..bb6b42ee09 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py +++ b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py @@ -6,6 +6,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import MXFP8Quantizer +from transformer_engine.pytorch.constants import TE_DType import pytest import torch @@ -139,7 +140,7 @@ def check_grouped_tensor_mxfp8_versus_reference( optimize_for_gemm: bool = False, ) -> None: - te_dtype = tex.DType.kFloat8E4M3 + te_dtype = TE_DType.kFloat8E4M3 split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") @@ -236,7 +237,7 @@ def check_grouped_tensor_mxfp8_with_paged_stashing( optimize_for_gemm: bool = False, ) -> None: - te_dtype = tex.DType.kFloat8E4M3 + te_dtype = TE_DType.kFloat8E4M3 assert valid_M is not None, "valid_M must be provided when with_paged_stashing is True" assert valid_M < M, "valid_M must be less than M when with_paged_stashing is True" diff --git a/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py index 6f0700809b..1f745bb5af 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py +++ b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py @@ -4,8 +4,8 @@ import transformer_engine.pytorch as te -import transformer_engine_torch as tex from transformer_engine.pytorch import MXFP8Quantizer +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage import pytest @@ -43,7 +43,7 @@ def check_mxfp8_quantize_swizzle_fusion( return_transpose: bool, ) -> None: - te_dtype = tex.DType.kFloat8E4M3 + te_dtype = TE_DType.kFloat8E4M3 # Setup device and random seed device = "cuda" diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index a7ea4f089f..7b670819f3 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -29,7 +29,7 @@ def check_nvfp4_gemm_versus_reference( w_columnwise: bool = False, row_scaled_nvfp4: bool = False, ): - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 # Setup device and random seed device = "cuda" @@ -233,7 +233,7 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( use_bias: bool, single_output: bool, ): - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 device = "cuda" torch.manual_seed(23) torch.cuda.manual_seed(23) @@ -322,7 +322,7 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( K: int, N: int, ): - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 device = "cuda" torch.manual_seed(37) torch.cuda.manual_seed(37) diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 20a91bf6fe..a4cd61f116 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -45,7 +45,7 @@ def check_group_quantization_nvfp4_versus_reference( with_random_sign_mask: bool = True, ) -> None: - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 # Setup device and random seed device = "cuda" diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py index d46a874695..ecdc532cb0 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py @@ -55,7 +55,7 @@ def check_grouped_tensor_nvfp4_versus_reference( optimize_for_gemm: bool = False, ) -> None: - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") @@ -172,7 +172,7 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( optimize_for_gemm: bool = False, ) -> None: - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 assert valid_M is not None, "valid_M must be provided when with_paged_stashing is True" assert valid_M < M, "valid_M must be less than M when with_paged_stashing is True" diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 53569d90d9..9f59f063c3 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -50,7 +50,7 @@ def check_quantization_nvfp4_versus_reference( row_scaled_nvfp4, return_transpose, with_2d_quantization ) - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 # Setup device and random seed device = "cuda" @@ -226,7 +226,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ): maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 device = "cuda" seed = 0 @@ -337,7 +337,7 @@ def test_nvfp4_quantization_boundary_values( """ maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 device = "cuda" seed = 123 @@ -452,7 +452,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( ): maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 device = "cuda" seed = 17 diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 2d159dbf6a..3d0861868b 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -46,7 +46,7 @@ def check_quantization_nvfp4_versus_reference( ) -> None: assert with_rht and with_post_rht_amax, "RHT and post-RHT amax reduction must be enabled." - te_dtype = tex.DType.kFloat4E2M1 + te_dtype = TE_DType.kFloat4E2M1 # Setup device and random seed device = "cuda" diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py index c4a6d73d70..c453b93cd4 100644 --- a/tests/pytorch/references/ref_per_tensor_cs.py +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -3,9 +3,8 @@ # See LICENSE for license information. import torch -import transformer_engine_torch as tex -from transformer_engine.pytorch.constants import TE_DType_To_Torch +from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch from references.quantize_scale_calc import scale_from_amax_tensor @@ -40,7 +39,7 @@ def _multi_dim_transpose(tensor): # current scaling reference quantization def ref_per_tensor_cs_cast( tensor: torch.Tensor, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, return_transpose: bool = False, force_pow_2_scales: bool = False, amax_epsilon: float = 0.0, diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 50196782f2..1c0d782e7e 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -20,7 +20,7 @@ import transformer_engine.pytorch as te from transformer_engine.common import recipe from utils import ModelConfig, skip_unsupported_backward_override -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType # Check supported quantization schemes fp8_available, _ = FP8GlobalStateManager.is_fp8_available() @@ -157,23 +157,23 @@ def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) return tensor elif recipe.delayed(): quantizer = te.tensor.float8_tensor.Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, scale=torch.tensor([1.0], device="cuda"), amax=torch.tensor([1.0], device="cuda"), ) return quantizer(tensor) elif recipe.float8_current_scaling(): quantizer = te.tensor.float8_tensor.Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, device="cuda" + fp8_dtype=TE_DType.kFloat8E4M3, device="cuda" ) return quantizer(tensor) elif recipe.float8_block_scaling(): quantizer = te.tensor.float8_blockwise_tensor.Float8BlockQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=True ) return quantizer(tensor) elif recipe.mxfp8(): - quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3) return quantizer(tensor) elif recipe.nvfp4(): quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer() diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 62a6291797..b91839b522 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -8,7 +8,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.common import recipe -from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx +from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx, TE_DType from transformer_engine.pytorch import ( autocast, Linear, @@ -100,10 +100,10 @@ def test_custom_recipe_sanity(module_type): # Single factory: map roles to quantizers def quantizer_factory(role): if role is None: - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") if role.tensor_type == "grad_output": - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) @@ -137,10 +137,10 @@ def test_custom_recipe_grouped_linear_sanity(): def quantizer_factory(role): if role is None: - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") if role.tensor_type == "grad_output": - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) @@ -183,11 +183,11 @@ def test_custom_recipe_matches_current_scaling(): ref_fwd_out = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] ref_bwd_go = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] ref_bwd_gi = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] - assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3 - assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3 - assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3 - assert ref_bwd_go.dtype == tex.DType.kFloat8E5M2 - assert ref_bwd_gi.dtype == tex.DType.kFloat8E5M2 + assert ref_fwd_in.dtype == TE_DType.kFloat8E4M3 + assert ref_fwd_w.dtype == TE_DType.kFloat8E4M3 + assert ref_fwd_out.dtype == TE_DType.kFloat8E4M3 + assert ref_bwd_go.dtype == TE_DType.kFloat8E5M2 + assert ref_bwd_gi.dtype == TE_DType.kFloat8E5M2 # Stress dynamic range in grad_output scale = torch.ones(out_features, device="cuda", dtype=torch.float32) @@ -199,10 +199,10 @@ def test_custom_recipe_matches_current_scaling(): # Custom: single factory returning quantizers per role to match Float8CurrentScaling def quantizer_factory(role): if role is None: - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") if role.tensor_type == "grad_output": - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) @@ -214,11 +214,11 @@ def quantizer_factory(role): cus_fwd_out = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT] cus_bwd_go = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1] cus_bwd_gi = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] - assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3 - assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 - assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3 - assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2 - assert cus_bwd_gi.dtype == tex.DType.kFloat8E4M3 # role=None fallback + assert cus_fwd_in.dtype == TE_DType.kFloat8E4M3 + assert cus_fwd_w.dtype == TE_DType.kFloat8E4M3 + assert cus_fwd_out.dtype == TE_DType.kFloat8E4M3 + assert cus_bwd_go.dtype == TE_DType.kFloat8E5M2 + assert cus_bwd_gi.dtype == TE_DType.kFloat8E4M3 # role=None fallback loss_custom = (out_custom.float() * scale.view(1, -1)).sum() loss_custom.backward() @@ -256,10 +256,10 @@ def test_custom_recipe_ops_linear_2_1_layout(): def quantizer_factory(role): if role is None: - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") if role.tensor_type == "grad_output": - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") custom = recipe.CustomRecipe(qfactory=quantizer_factory) @@ -300,14 +300,14 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): def quantizer_factory(role): if role is None: counts[None] += 1 - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device=torch.device("cuda")) assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" assert role.module_type == "linear" if role.tensor_type in counts: counts[role.tensor_type] += 1 if role.tensor_type == "grad_output": - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E5M2, device=torch.device("cuda")) + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device=torch.device("cuda")) custom = recipe.CustomRecipe(qfactory=quantizer_factory) @@ -336,7 +336,7 @@ def test_factories_return_distinct_instances_and_buffers(): def factory(): scale = torch.ones(1, dtype=torch.float32, device="cuda") amax = torch.zeros(1, dtype=torch.float32, device="cuda") - return Float8Quantizer(scale=scale, amax=amax, fp8_dtype=tex.DType.kFloat8E4M3) + return Float8Quantizer(scale=scale, amax=amax, fp8_dtype=TE_DType.kFloat8E4M3) q1 = factory() q2 = factory() @@ -723,7 +723,7 @@ def test_grouped_linear_module_type_dispatch(): def recording_factory(role): recorded_roles.append(role) - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") custom_recipe = recipe.CustomRecipe(qfactory=recording_factory) @@ -825,7 +825,7 @@ def mixed_factory(role): # Only weight gets delayed scaling, rest get current scaling if role is not None and role.tensor_type == "weight": return DelayedScalingRequest(fp8_format=Format.HYBRID) - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + return Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, device="cuda") custom_recipe = recipe.CustomRecipe(qfactory=mixed_factory) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 3b964a5af9..169ecfc67f 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -761,7 +761,7 @@ class TestFP8CurrentScalingNativeVsRef: def _make_quantizers(rowwise=True, columnwise=True): # TE native FP8 current scaling quantizer te_quant = te.Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, device=torch.device("cuda"), rowwise=rowwise, columnwise=columnwise, diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 7add4ee5ab..fe2075e45f 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -17,16 +17,17 @@ get_device_compute_capability, ) import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] # TE FP8 dtypes -_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] +_fp8_dtypes: List[TE_DType] = [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2] # Numerical tolerances with FP8 types -_tols: Dict[tex.DType, Dict[str, float]] = { - tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.08), - tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), +_tols: Dict[TE_DType, Dict[str, float]] = { + TE_DType.kFloat8E4M3: dict(rtol=0.125, atol=0.08), + TE_DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), } @@ -59,7 +60,7 @@ def setup_class(cls) -> None: def test_constructor( self, dims: DimsType = 1, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, dtype: torch.dtype = torch.float32, is_2D_scaled: bool = True, ) -> None: @@ -140,7 +141,7 @@ def _test_quantize_dequantize( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) def test_quantize_dequantize_dtypes( - self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + self, fp8_dtype: TE_DType, dtype: torch.dtype, block_scaling_dim: int ) -> None: atol = _tols[fp8_dtype]["atol"] rtol = _tols[fp8_dtype]["rtol"] @@ -156,7 +157,7 @@ def test_quantize_dequantize_dtypes( @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("block_scaling_dim", [1]) def test_quantize_dequantize_columnwise_only( - self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + self, fp8_dtype: TE_DType, dtype: torch.dtype, block_scaling_dim: int ) -> None: atol = _tols[fp8_dtype]["atol"] rtol = _tols[fp8_dtype]["rtol"] @@ -181,10 +182,10 @@ def test_quantize_dequantize_dims( block_scaling_dim: int, dq_columnwise: bool, ) -> None: - atol = _tols[tex.DType.kFloat8E4M3]["atol"] - rtol = _tols[tex.DType.kFloat8E4M3]["rtol"] + atol = _tols[TE_DType.kFloat8E4M3]["atol"] + rtol = _tols[TE_DType.kFloat8E4M3]["rtol"] quantizer = Float8BlockQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=dq_columnwise, block_scaling_dim=block_scaling_dim, @@ -206,10 +207,10 @@ def test_quantize_dequantize_dims( def test_quantize_dequantize_compact_format( self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool ) -> None: - atol = _tols[tex.DType.kFloat8E4M3]["atol"] - rtol = _tols[tex.DType.kFloat8E4M3]["rtol"] + atol = _tols[TE_DType.kFloat8E4M3]["atol"] + rtol = _tols[TE_DType.kFloat8E4M3]["rtol"] quantizer = Float8BlockQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=dq_columnwise, block_scaling_dim=block_scaling_dim, @@ -229,7 +230,7 @@ def test_quantize_dequantize_compact_format( @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) @pytest.mark.parametrize("dq_columnwise", [True, False]) def test_quantize_dequantize_dims_cpp_allocate_output( - self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool + self, dims: DimsType, block_scaling_dim: int, fp8_dtype: TE_DType, dq_columnwise: bool ) -> None: atol = _tols[fp8_dtype]["atol"] rtol = _tols[fp8_dtype]["rtol"] @@ -257,7 +258,7 @@ def test_data_accessors(self, dims: DimsType, block_scaling_dim: int) -> None: x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) - fp8_dtype = tex.DType.kFloat8E4M3 + fp8_dtype = TE_DType.kFloat8E4M3 quantizer = Float8BlockQuantizer( fp8_dtype=fp8_dtype, rowwise=True, @@ -283,7 +284,7 @@ def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None: dtype = torch.bfloat16 x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device) quantizer = Float8BlockQuantizer( - fp8_dtype=tex.DType.kFloat8E5M2, + fp8_dtype=TE_DType.kFloat8E5M2, rowwise=True, columnwise=True, block_scaling_dim=block_scaling_dim, @@ -316,12 +317,12 @@ def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None: x_fp8_loaded_dequant = x_fp8_loaded.dequantize() torch.testing.assert_close(x_fp8_loaded_dequant, x_fp8_dequant) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("fp8_dtype", [TE_DType.kFloat8E4M3], ids=str) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) def test_inplace_ops( - self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + self, fp8_dtype: TE_DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int ) -> None: """Test in-place operations""" device = "cuda" @@ -353,12 +354,12 @@ def test_inplace_ops( x_fp8.mul_(y_fp8) torch.testing.assert_close(x_fp8.dequantize(), x_hp * y_hp, **_tols[fp8_dtype]) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("fp8_dtype", [TE_DType.kFloat8E4M3], ids=str) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) def test_out_of_place_ops( - self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + self, fp8_dtype: TE_DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int ) -> None: """Test out-of-place operations""" device = "cuda" @@ -389,12 +390,12 @@ def test_out_of_place_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8 + y_fp8, x_hp - y_hp, **_tols[fp8_dtype]) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("fp8_dtype", [TE_DType.kFloat8E4M3], ids=str) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) def test_view_same_shape( - self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + self, fp8_dtype: TE_DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int ) -> None: """Test view operations that preserve tensor shape""" device = "cuda" @@ -419,13 +420,13 @@ def test_view_same_shape( with pytest.raises(AssertionError): torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype]) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str) + @pytest.mark.parametrize("fp8_dtype", [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2], ids=str) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str) @pytest.mark.parametrize( "dims", [[16, 16, 512], [16, 16, 512, 16], [12, 7, 11], [13, 14, 16], [2, 3, 5]] ) def test_view_and_reshape_1D( - self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int] + self, fp8_dtype: TE_DType, dtype: torch.dtype, dims: List[int] ) -> None: """Test view operations that preserve tensor shape""" device = "cuda" @@ -472,11 +473,11 @@ def is_bitwise_equal(a, b): assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data) assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str) + @pytest.mark.parametrize("fp8_dtype", [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2], ids=str) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str) @pytest.mark.parametrize("dims", [[16, 16, 512, 16], [2, 512, 512, 128], [3, 13, 14, 16]]) def test_view_and_reshape_2D( - self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int] + self, fp8_dtype: TE_DType, dtype: torch.dtype, dims: List[int] ) -> None: """Test view operations that preserve tensor shape""" device = "cuda" @@ -523,12 +524,12 @@ def is_bitwise_equal(a, b): assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data) assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("fp8_dtype", [TE_DType.kFloat8E4M3], ids=str) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) def test_reshape_same_shape( - self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + self, fp8_dtype: TE_DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int ) -> None: """Test reshape operations that preserve tensor shape""" device = "cuda" @@ -559,12 +560,12 @@ def test_reshape_same_shape( with pytest.raises(AssertionError): torch.testing.assert_close(x_reshape.dequantize(), -x_hp, **_tols[fp8_dtype]) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("fp8_dtype", [TE_DType.kFloat8E4M3], ids=str) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("dims", [[256, 512], [250, 500]]) @pytest.mark.parametrize("block_scaling_dim", [1, 2]) def test_clone_detach( - self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int + self, fp8_dtype: TE_DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int ) -> None: """Test clone and detach operations""" device = "cuda" diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3a3aa8be91..0ac6ee44d3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -44,6 +44,7 @@ ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.pytorch.cpp_extensions.gemm import general_grouped_gemm_for_grouped_tensor +from transformer_engine.pytorch.constants import TE_DType import transformer_engine_torch as tex # Import utility functions @@ -170,17 +171,17 @@ def make_reference_and_test_tensors( quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) test = quantizer(test) elif quantization == "fp8_current_scaling": quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, device=test_device, ) test = quantizer(test) elif quantization == "mxfp8": - test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + test = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3)(test) elif quantization == "nvfp4": test = NVFP4Quantizer( with_rht=False, @@ -1885,9 +1886,9 @@ def test_clamped_swiglu( # Expected numerical error tols = dtype_tols(dtype) if quantized_compute and quantization == "nvfp4": - tols = dtype_tols(tex.DType.kFloat4E2M1) + tols = dtype_tols(TE_DType.kFloat4E2M1) elif quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = dtype_tols(TE_DType.kFloat8E4M3) # Check results assert_close(y_test, y_ref, **tols) @@ -5232,7 +5233,7 @@ def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None: total_m = num_groups * m split_sizes = torch.full((num_groups,), m, device=device, dtype=torch.int64) - q = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=False) + q = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=False) q.optimize_for_gemm = False torch.manual_seed(0) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index c54c9758ff..47a3dbbe1d 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -17,7 +17,7 @@ MXFP8Quantizer, NVFP4Quantizer, ) -from transformer_engine.pytorch.constants import TE_DType_To_Torch +from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch import transformer_engine_torch as tex # Check available recipes @@ -61,17 +61,17 @@ def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, i quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device="cuda"), amax=torch.zeros(1, dtype=torch.float32, device="cuda"), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) elif quantization == "fp8_current_scaling": quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, device="cuda", ) quantizer.set_usage(rowwise=True, columnwise=False) elif quantization == "fp8_blockwise": quantizer = Float8BlockQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=False, force_pow_2_scales=True, @@ -79,7 +79,7 @@ def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, i block_scaling_dim=1, ) elif quantization == "mxfp8": - quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3) elif quantization == "nvfp4": quantizer = NVFP4Quantizer( with_rht=False, @@ -369,7 +369,7 @@ def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]], output_dbias grouped_input = torch.cat(input_tensors, dim=0) # Create MXFP8 output grouped tensor (rowwise only for easier validation) - quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3) quantizer.set_usage(rowwise=True, columnwise=False) first_dims = torch.tensor( [shape[0][0] for _ in range(num_tensors)], @@ -417,7 +417,7 @@ def test_bgrad_group_quantize_zero_size_tensor(self) -> None: last_dim = 1024 grouped_input = torch.empty(0, last_dim, dtype=torch.bfloat16, device="cuda") - quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3) quantizer.set_usage(rowwise=True, columnwise=False) first_dims = torch.zeros(num_tensors, dtype=torch.int64, device="cuda") @@ -440,7 +440,7 @@ def test_group_quantize_cudagraph_capturable(self, output_dbias: bool) -> None: input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] grouped_input = torch.cat(input_tensors, dim=0) - quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3) quantizer.set_usage(rowwise=True, columnwise=False) first_dims = torch.tensor( [shape[0][0] for _ in range(num_tensors)], @@ -513,7 +513,7 @@ def test_group_dequantize(self, shape: List[Tuple[int, int]]) -> None: input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] grouped_input = torch.cat(input_tensors, dim=0) - quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3) quantizer.set_usage(rowwise=True, columnwise=False) first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device="cuda") @@ -521,7 +521,7 @@ def test_group_dequantize(self, shape: List[Tuple[int, int]]) -> None: quantized = tex.group_quantize(grouped_input, quantizer, num_tensors, first_dims) # Dequantize. - dequantized = tex.group_dequantize(quantized, tex.DType.kBFloat16) + dequantized = tex.group_dequantize(quantized, TE_DType.kBFloat16) # Verify output metadata. assert dequantized.num_tensors == num_tensors @@ -543,7 +543,7 @@ def test_group_dequantize_cudagraph_capturable(self) -> None: input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] grouped_input = torch.cat(input_tensors, dim=0) - quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3) quantizer.set_usage(rowwise=True, columnwise=False) first_dims = torch.tensor( [shape[0][0] for _ in range(num_tensors)], @@ -556,12 +556,12 @@ def test_group_dequantize_cudagraph_capturable(self) -> None: # Warmup dequantize. torch.cuda.synchronize() - _ = tex.group_dequantize(quantized, tex.DType.kBFloat16) + _ = tex.group_dequantize(quantized, TE_DType.kBFloat16) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): - static_output = tex.group_dequantize(quantized, tex.DType.kBFloat16) + static_output = tex.group_dequantize(quantized, TE_DType.kBFloat16) # Replay with different input data. fresh_input = torch.cat( @@ -575,7 +575,7 @@ def test_group_dequantize_cudagraph_capturable(self) -> None: graph.replay() torch.cuda.synchronize() - expected = tex.group_dequantize(quantized, tex.DType.kBFloat16) + expected = tex.group_dequantize(quantized, TE_DType.kBFloat16) expected_tensors = expected.split_into_quantized_tensors() static_tensors = static_output.split_into_quantized_tensors() for exp, got in zip(expected_tensors, static_tensors): diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a718ea2a8a..141353fbf0 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -53,6 +53,7 @@ ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.common import recipe +from transformer_engine.pytorch.constants import TE_DType import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states @@ -3063,7 +3064,7 @@ def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a): else: rowwise, columnwise = not transb, transb quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=rowwise, columnwise=columnwise, ) @@ -3159,7 +3160,7 @@ def _make_grouped_tensor_quantized_mxfp8( if not tensors: raise ValueError("Expected non-empty tensor list for grouped quantization.") quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=rowwise, columnwise=columnwise, ) @@ -3182,7 +3183,7 @@ def _per_tensor_quantize_mxfp8( Used to build reference discrete inputs for grouped GEMM. """ quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=rowwise, columnwise=columnwise, ) @@ -3299,17 +3300,17 @@ def test_grouped_gemm_grouped_tensor_mxfp8( @pytest.mark.parametrize( "input_quantizer", [ - Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), - MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + Float8CurrentScalingQuantizer(fp8_dtype=TE_DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3), ], ) @pytest.mark.parametrize( "out_quantizer", [ - Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), - MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + Float8CurrentScalingQuantizer(fp8_dtype=TE_DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3), Float8Quantizer( - torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), tex.DType.kFloat8E4M3 + torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), TE_DType.kFloat8E4M3 ), ], ) @@ -3389,7 +3390,7 @@ def test_fp8_grouped_gemm(shape, accumulate): Float8Quantizer( scale.clone(), amax.clone(), - tex.DType.kFloat8E4M3, + TE_DType.kFloat8E4M3, ) for _ in range(z) ] @@ -3397,7 +3398,7 @@ def test_fp8_grouped_gemm(shape, accumulate): Float8Quantizer( scale.clone(), amax.clone(), - tex.DType.kFloat8E4M3, + TE_DType.kFloat8E4M3, ) for _ in range(z) ] diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 9aea3bc274..e48b837409 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -32,7 +32,7 @@ from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op import transformer_engine.pytorch as te from transformer_engine.common import recipe -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.utils import get_default_init_method @@ -88,7 +88,7 @@ def trt_fp8_quantize(t, scale_inv): q = te.tensor.float8_tensor.Float8Quantizer( scale=1 / torch.from_numpy(scale_inv).cuda(), amax=torch.zeros([1]).cuda(), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) return q(x)._data.cpu().numpy() @@ -108,7 +108,7 @@ def trt_fp8_dequantize(t, scale_inv): q = te.tensor.float8_tensor.Float8Quantizer( scale=1 / torch.from_numpy(scale_inv).cuda(), amax=torch.zeros([1]).cuda(), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) quantizer_tensor = q.create_tensor_from_data(x, fake_dtype=torch.float32) return quantizer_tensor.dequantize().cpu().numpy() @@ -125,7 +125,7 @@ def trt_fp8_dequantize(t, scale_inv): def trt_mxfp8_quantize(t): """MXFP8 quantization extension for ONNX Runtime.""" x = torch.from_numpy(t).cuda() - q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3) + q = te.tensor.mxfp8_tensor.MXFP8Quantizer(TE_DType.kFloat8E4M3) return q(x)._rowwise_data.cpu().numpy(), q(x)._rowwise_scale_inv.cpu().numpy() @@ -142,7 +142,7 @@ def trt_mxfp8_dequantize(t, scale_inv): """MXFP8 dequantization extension for ONNX Runtime.""" x = torch.from_numpy(t).cuda() scale_inv_tensor = torch.from_numpy(scale_inv).cuda() - q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3) + q = te.tensor.mxfp8_tensor.MXFP8Quantizer(TE_DType.kFloat8E4M3) quantizer_tensor = q.create_tensor_from_data(x, scale_inv_tensor, fake_dtype=torch.float32) return quantizer_tensor.dequantize().cpu().numpy() @@ -382,9 +382,9 @@ def dtype2str(dtype: torch.dtype, fake_bf16_io=False): def as_te_type(dtype: torch.dtype): return { - torch.float32: tex.DType.kFloat32, - torch.float16: tex.DType.kFloat16, - torch.bfloat16: tex.DType.kBFloat16, + torch.float32: TE_DType.kFloat32, + torch.float16: TE_DType.kFloat16, + torch.bfloat16: TE_DType.kBFloat16, }[dtype] diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 66c685e139..dcc668c545 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -25,7 +25,7 @@ Float8BlockQuantizer, MXFP8Quantizer, ) -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding import copy @@ -191,19 +191,19 @@ def pytorch_sort_chunks_by_index( return output -def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: +def dtype_tols(te_dtype: TE_DType) -> Dict[str, float]: """Estimated tolerances for a datatype Based on tolerances for torch.testing.assert_close. """ - if te_dtype == tex.DType.kFloat32: + if te_dtype == TE_DType.kFloat32: return dict(rtol=1.0e-6, atol=1.0e-6) - if te_dtype == tex.DType.kFloat16: + if te_dtype == TE_DType.kFloat16: return dict(rtol=3.0e-3, atol=1.0e-5) - if te_dtype == tex.DType.kBFloat16: + if te_dtype == TE_DType.kBFloat16: return dict(rtol=2.0e-2, atol=1.0e-5) - if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3: + if te_dtype == TE_DType.kFloat8E5M2 or te_dtype == TE_DType.kFloat8E4M3: return dict(rtol=2.0e-1, atol=1.0e-1) raise ValueError(f"Unsuppored dtype ({te_dtype})") @@ -255,11 +255,11 @@ def _test_permutation_index_map( ) # Convert TE dtypes to PyTorch dtypes - if te_dtype == tex.DType.kFloat32: + if te_dtype == TE_DType.kFloat32: dtype = torch.float32 - elif te_dtype == tex.DType.kFloat16: + elif te_dtype == TE_DType.kFloat16: dtype = torch.float16 - elif te_dtype == tex.DType.kBFloat16: + elif te_dtype == TE_DType.kBFloat16: dtype = torch.bfloat16 else: pytest.skip("Invalid dtype.") @@ -476,11 +476,11 @@ def _test_permutation_mask_map( ) # Convert TE dtypes to PyTorch dtypes - if te_dtype == tex.DType.kFloat32: + if te_dtype == TE_DType.kFloat32: dtype = torch.float32 - elif te_dtype == tex.DType.kFloat16: + elif te_dtype == TE_DType.kFloat16: dtype = torch.float16 - elif te_dtype == tex.DType.kBFloat16: + elif te_dtype == TE_DType.kBFloat16: dtype = torch.bfloat16 else: pytest.skip("Invalid dtype.") @@ -704,11 +704,11 @@ def _test_permutation_and_padding_mask_map( ) # Convert TE dtypes to PyTorch dtypes - if te_dtype == tex.DType.kFloat32: + if te_dtype == TE_DType.kFloat32: dtype = torch.float32 - elif te_dtype == tex.DType.kFloat16: + elif te_dtype == TE_DType.kFloat16: dtype = torch.float16 - elif te_dtype == tex.DType.kBFloat16: + elif te_dtype == TE_DType.kBFloat16: dtype = torch.bfloat16 else: pytest.skip("Invalid dtype.") @@ -1000,11 +1000,11 @@ def _test_permutation_and_padding_with_merging_probs( ) # Convert TE dtypes to PyTorch dtypes - if te_dtype == tex.DType.kFloat32: + if te_dtype == TE_DType.kFloat32: dtype = torch.float32 - elif te_dtype == tex.DType.kFloat16: + elif te_dtype == TE_DType.kFloat16: dtype = torch.float16 - elif te_dtype == tex.DType.kBFloat16: + elif te_dtype == TE_DType.kBFloat16: dtype = torch.bfloat16 else: pytest.skip("Invalid dtype.") @@ -1327,11 +1327,11 @@ def _test_moe_chunk_sort( ) # Convert TE dtypes to PyTorch dtypes - if te_dtype == tex.DType.kFloat32: + if te_dtype == TE_DType.kFloat32: dtype = torch.float32 - elif te_dtype == tex.DType.kFloat16: + elif te_dtype == TE_DType.kFloat16: dtype = torch.float16 - elif te_dtype == tex.DType.kBFloat16: + elif te_dtype == TE_DType.kBFloat16: dtype = torch.bfloat16 else: pytest.skip("Invalid dtype.") @@ -1462,11 +1462,11 @@ def _test_permutation_mask_map_alongside_probs( ) # Convert TE dtypes to PyTorch dtypes - if te_dtype == tex.DType.kFloat32: + if te_dtype == TE_DType.kFloat32: dtype = torch.float32 - elif te_dtype == tex.DType.kFloat16: + elif te_dtype == TE_DType.kFloat16: dtype = torch.float16 - elif te_dtype == tex.DType.kBFloat16: + elif te_dtype == TE_DType.kBFloat16: dtype = torch.bfloat16 else: pytest.skip("Invalid dtype.") @@ -1667,9 +1667,9 @@ def perf_test_cuda_kernel(cuda_kernel_fn): # TE tensor dtypes -_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] +_te_dtypes: List[TE_DType] = [TE_DType.kFloat32, TE_DType.kFloat16] if te.is_bf16_available(): - _te_dtypes.append(tex.DType.kBFloat16) + _te_dtypes.append(TE_DType.kBFloat16) @pytest.mark.parametrize("te_dtype", _te_dtypes) @@ -1899,7 +1899,7 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype, use_torch_co @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("te_dtype", [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2]) @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) @@ -2048,9 +2048,9 @@ def test_chunk_permutation_empty_input(te_dtype, use_torch_compile): def test_permutation_single_case(): print("GPU:", torch.cuda.get_device_name(0)) - # te_dtype = tex.DType.kFloat32 - # te_dtype = tex.DType.kFloat16 - te_dtype = tex.DType.kBFloat16 + # te_dtype = TE_DType.kFloat32 + # te_dtype = TE_DType.kFloat16 + te_dtype = TE_DType.kBFloat16 num_tokens = 12 num_expert = 4 @@ -2216,9 +2216,9 @@ def test_benchmark_multiple_cases(): """Benchmark test - skipped by default. Run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark""" print("GPU:", torch.cuda.get_device_name(0)) - # te_dtype = tex.DType.kFloat32 - # te_dtype = tex.DType.kFloat16 - te_dtype = tex.DType.kBFloat16 + # te_dtype = TE_DType.kFloat32 + # te_dtype = TE_DType.kFloat16 + te_dtype = TE_DType.kBFloat16 ep_size = 64 tp_size = 2 diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index 119914fbc3..e44c456b47 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -25,7 +25,7 @@ ) from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from references.ref_per_tensor_cs import ref_per_tensor_cs_cast from utils import assert_close, quantization_tols @@ -33,12 +33,12 @@ # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] # TE FP8 dtypes -_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] +_fp8_dtypes: List[TE_DType] = [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2] # Numerical tolerances with FP8 types -_tols: Dict[tex.DType, Dict[str, float]] = { - tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625 - tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 +_tols: Dict[TE_DType, Dict[str, float]] = { + TE_DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625 + TE_DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 } @@ -74,7 +74,7 @@ def _to_list(x: Union[Iterable, Any]) -> List: # delayed scaling def to_float8( tensor: torch.Tensor, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, scale: float = 1.0, ) -> Float8Tensor: """Cast tensor to FP8""" @@ -89,7 +89,7 @@ def to_float8( # current scaling def to_float8_CS( tensor: torch.Tensor, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, return_transpose: bool = False, force_pow_2_scales: bool = False, amax_epsilon: float = 0.0, @@ -142,18 +142,18 @@ def make_reference_and_test_tensors( quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=test_device), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) test = quantizer(test) elif quantization == "fp8_current_scaling": quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, device=test_device, ) test = quantizer(test) elif quantization == "fp8_blockwise": quantizer = Float8BlockQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=True, force_pow_2_scales=True, @@ -162,7 +162,7 @@ def make_reference_and_test_tensors( ) test = quantizer(test) elif quantization == "mxfp8": - test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + test = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3)(test) elif quantization == "nvfp4": test = NVFP4Quantizer( with_rht=False, @@ -195,7 +195,7 @@ def setup_class(cls) -> None: def test_constructor( self, dims: DimsType = 1, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, scale_inv: float = 0.375, dtype: torch.dtype = torch.float32, ) -> None: @@ -214,7 +214,7 @@ def test_constructor( def _test_quantize_dequantize( self, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, scale: float = 3.5, dtype: torch.dtype = torch.float32, dims: DimsType = 23, @@ -239,7 +239,7 @@ def _test_quantize_dequantize( @pytest.mark.parametrize("dtype", _dtypes) def test_quantize_dequantize_dtypes( self, - fp8_dtype: tex.DType, + fp8_dtype: TE_DType, dtype: torch.dtype, ) -> None: self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype) @@ -256,7 +256,7 @@ def test_quantize_dequantize_dims(self, dims: DimsType) -> None: @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("noop", [True, False]) def test_quantize_dequantize_noop( - self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool + self, fp8_dtype: TE_DType, dtype: torch.dtype, noop: bool ) -> None: noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda") if noop: @@ -281,7 +281,7 @@ def test_quantize_dequantize_noop( def test_basic_ops( self, dims: DimsType = 23, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, scale: float = 3.5, dtype: torch.dtype = torch.float32, ) -> None: @@ -317,7 +317,7 @@ def test_basic_ops( def test_chunk_op( self, dims: DimsType, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, scale: float = 3.5, dtype: torch.dtype = torch.float32, ) -> None: @@ -346,7 +346,7 @@ def test_chunk_op( def test_inplace_ops( self, dims: DimsType = 23, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, scale: float = 3.5, dtype: torch.dtype = torch.float32, ) -> None: @@ -384,7 +384,7 @@ def test_inplace_ops( def test_serialization( self, dims: DimsType = [2, 3, 5], - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, scale: float = 0.5, dtype: torch.dtype = torch.float32, ): @@ -484,7 +484,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("amax_epsilon", [0.0, 1e-6], ids=str) def test_quantize( self, - fp8_dtype: tex.DType, + fp8_dtype: TE_DType, dtype: torch.dtype, dims: DimsType, return_transpose: bool, @@ -526,11 +526,11 @@ def test_quantize( x_fp8._transpose, x_fp8_t_ref.view(torch.uint8), atol=0.0, rtol=0.0 ) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("fp8_dtype", [TE_DType.kFloat8E4M3], ids=str) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]]) def test_quantize_dequantize( - self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType + self, fp8_dtype: TE_DType, dtype: torch.dtype, dims: DimsType ) -> None: """Check numerical error when casting to FP8 and back""" @@ -772,11 +772,11 @@ def test_update_nd_tensor( quantizer = Float8Quantizer( scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(), amax=torch.zeros(1, dtype=torch.float32, device=device), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) elif quantization == "fp8_blockwise": quantizer = Float8BlockQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=True, force_pow_2_scales=True, @@ -784,7 +784,7 @@ def test_update_nd_tensor( block_scaling_dim=1, ) elif quantization == "mxfp8": - quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3) elif quantization in ("nvfp4", "nvfp4_2d"): quantizer = NVFP4Quantizer( rowwise=True, @@ -826,7 +826,7 @@ def setup_class(cls) -> None: @pytest.mark.parametrize("dims", [[128, 128], [256, 256], [128, 256]]) def test_mxfp8_dequantize_columnwise_only( self, - fp8_dtype: tex.DType, + fp8_dtype: TE_DType, dtype: torch.dtype, dims: DimsType, ) -> None: @@ -867,7 +867,7 @@ def test_mxfp8_dequantize_columnwise_only( @pytest.mark.parametrize("dims", [[128, 128], [256, 256]]) def test_mxfp8_dequantize_columnwise_only_quantized_separately( self, - fp8_dtype: tex.DType, + fp8_dtype: TE_DType, dims: DimsType, ) -> None: """Check dequantization of MXFP8 tensor quantized with columnwise only""" diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 5f5221af76..d8e71adf8a 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -23,6 +23,7 @@ ) import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, NVFP4BlockScalingRecipeState, @@ -296,14 +297,14 @@ def check_metas( @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"]) @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"]) @pytest.mark.parametrize( - "fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=["E4M3", "E5M2"] + "fp8_dtype", [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2], ids=["E4M3", "E5M2"] ) def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype): - if fp8_dtype == tex.DType.kFloat8E4M3: + if fp8_dtype == TE_DType.kFloat8E4M3: fp8_format = transformer_engine.common.recipe.Format.E4M3 fp8_max = fp8_format.value.max_fwd - elif fp8_dtype == tex.DType.kFloat8E5M2: + elif fp8_dtype == TE_DType.kFloat8E5M2: fp8_format = transformer_engine.common.recipe.Format.HYBRID fp8_max = fp8_format.value.max_bwd else: diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c811342df5..093e32a527 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -35,7 +35,7 @@ is_bf16_available, ) from transformer_engine.common import recipe -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.tensor.utils import replace_raw_data from utils import ModelConfig, recipe_id, skip_unsupported_backward_override @@ -1031,7 +1031,7 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): scales = torch.ones(1).cuda().squeeze() amaxes = torch.ones(1).cuda().squeeze() - dtype = tex.DType.kFloat8E4M3 + dtype = TE_DType.kFloat8E4M3 fp8_quantizer = Float8Quantizer(scales, amaxes, dtype) outp_type = datatype @@ -1055,7 +1055,7 @@ def test_replace_raw_data_for_float8tensor(): torch.manual_seed(12345) torch.cuda.manual_seed(12345) - fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda") + fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=TE_DType.kFloat8E4M3, device="cuda") fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda") random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda") fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 51f72b1e56..7caf53575c 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -22,7 +22,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.common import recipe -from transformer_engine.pytorch.constants import FP8FwdTensorIdx, FP8BwdTensorIdx +from transformer_engine.pytorch.constants import FP8FwdTensorIdx, FP8BwdTensorIdx, TE_DType from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer @@ -66,7 +66,7 @@ class ToyQuantizer(Float8CurrentScalingQuantizer, metaclass=_ToyQuantizerMeta): opaque value type so torch.compile can treat it as a baked-in constant.""" def __init__(self, tag: str): - super().__init__(fp8_dtype=tex.DType.kFloat8E4M3, device=torch.device("cuda")) + super().__init__(fp8_dtype=TE_DType.kFloat8E4M3, device=torch.device("cuda")) self.tag = tag def __eq__(self, other): diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 2ee18aaf57..4f17b8ccce 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -17,9 +17,9 @@ import torch import transformer_engine -import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import InferenceParams, QuantizedTensor +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( get_attention_backend, @@ -70,7 +70,7 @@ def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: return dtype -def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: +def dtype_tols(dtype: torch.dtype | TE_DType) -> dict[str, float]: """Estimated numerical error for a datatype Based on tolerances for torch.testing.assert_close. @@ -78,17 +78,17 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: """ # Transformer Engine dtypes - if isinstance(dtype, tex.DType): - if dtype == tex.DType.kFloat4E2M1: + if isinstance(dtype, TE_DType): + if dtype == TE_DType.kFloat4E2M1: return dict(rtol=0.25, atol=0.125) # epsilon = 0.25 dtype = { - tex.DType.kByte: torch.uint8, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - tex.DType.kFloat8E4M3: torch.float8_e4m3fn, - tex.DType.kFloat8E5M2: torch.float8_e5m2, + TE_DType.kByte: torch.uint8, + TE_DType.kInt32: torch.int32, + TE_DType.kFloat32: torch.float32, + TE_DType.kFloat16: torch.half, + TE_DType.kBFloat16: torch.bfloat16, + TE_DType.kFloat8E4M3: torch.float8_e4m3fn, + TE_DType.kFloat8E5M2: torch.float8_e5m2, }[dtype] # PyTorch dtypes @@ -117,9 +117,9 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8", "mxfp8_block_scaling", ): - return dtype_tols(tex.DType.kFloat8E4M3) + return dtype_tols(TE_DType.kFloat8E4M3) if name in ("nvfp4", "nvfp4_row_scaled"): - return dtype_tols(tex.DType.kFloat4E2M1) + return dtype_tols(TE_DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index ed48fe4d61..061a55a96c 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -23,16 +23,10 @@ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ - .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1) \ - .def("__reduce_ex__", \ - [](transformer_engine::DType self, pybind11::object /*protocol*/) { \ - return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \ - pybind11::make_tuple(static_cast(self))); \ - }) \ - .def("__reduce__", [](transformer_engine::DType self) { \ - return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \ - pybind11::make_tuple(static_cast(self))); \ - }); \ + .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ + /* Allow Python int (and IntEnum subclasses like transformer_engine.pytorch.TE_DType) to be */ \ + /* passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected. */ \ + pybind11::implicitly_convertible(); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ diff --git a/transformer_engine/debug/features/fake_quant.py b/transformer_engine/debug/features/fake_quant.py index f48b49b725..0e7c978df1 100644 --- a/transformer_engine/debug/features/fake_quant.py +++ b/transformer_engine/debug/features/fake_quant.py @@ -13,16 +13,16 @@ from nvdlfw_inspect.utils import append_parent_docstring -import transformer_engine_torch as tex from transformer_engine.debug.features.api import TEConfigAPIMapper from transformer_engine.common.recipe import Format +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.quantization import _default_sf_compute -def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None): +def fake_quantize(tensor: torch.Tensor, fp8_format: str, out=None): """Input tensor is quantized to fp8 and then dequantized.""" assert tensor.dtype in ( @@ -43,10 +43,10 @@ def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None): if fp8_format in ["FP8E4M3", "FP8E5M2"]: if fp8_format == "FP8E4M3": fp8_max = Format.E4M3.value.max_fwd - fp8_dtype = tex.DType.kFloat8E4M3 + fp8_dtype = TE_DType.kFloat8E4M3 else: fp8_max = Format.E5M2.value.max_fwd - fp8_dtype = tex.DType.kFloat8E5M2 + fp8_dtype = TE_DType.kFloat8E5M2 amax = tensor.abs().max().float() one = torch.ones(1, device=tensor.device) scale = _default_sf_compute(amax, one, fp8_max, 0) diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index d26f9ef7f6..12fd5ba086 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -12,7 +12,7 @@ import nvdlfw_inspect.api as debug_api from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter @@ -251,7 +251,7 @@ def update_aux_dict( Yields the aux_dict. Needs to clean after usage, because it possibly change the usage of the quantized tensor. """ - fp8_dtype = tex.DType.kFloat8E4M3 + fp8_dtype = TE_DType.kFloat8E4M3 if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]: assert isinstance( quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer) diff --git a/transformer_engine/debug/features/per_tensor_scaling.py b/transformer_engine/debug/features/per_tensor_scaling.py index a4bab4eaf5..ad111e8afa 100644 --- a/transformer_engine/debug/features/per_tensor_scaling.py +++ b/transformer_engine/debug/features/per_tensor_scaling.py @@ -11,7 +11,7 @@ import nvdlfw_inspect.api as debug_api from nvdlfw_inspect.registry import Registry, api_method -import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Tensor, @@ -22,7 +22,7 @@ def per_tensor_cast( - tensor: torch.Tensor, fp8_dtype: tex.DType, out: Float8Tensor = None + tensor: torch.Tensor, fp8_dtype: TE_DType, out: Float8Tensor = None ) -> Float8Tensor: """ This function computes the scaling factors based on the tensor amax and then casts it to the fp8 @@ -35,8 +35,8 @@ def per_tensor_cast( ), "[NVTORCH INSPECT ERROR] Unsupported tensor type for per tensor current scaling" assert tensor.is_cuda, "[NVTORCH INSPECT ERROR] Must be a GPU tensor." assert fp8_dtype in { - tex.DType.kFloat8E4M3, - tex.DType.kFloat8E5M2, + TE_DType.kFloat8E4M3, + TE_DType.kFloat8E5M2, }, "[NVTORCH INSPECT ERROR] Only 2 FP8 types: E4M3 and E5M2 are supported in TE." tensor = tensor.contiguous() diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index b0002ffee6..0ae1108d01 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -11,8 +11,8 @@ import torch import torch.nn.functional as F -import transformer_engine_torch as tex from transformer_engine.common.recipe import Format +from transformer_engine.pytorch.constants import TE_DType class BlockwiseDynamicRangeStat( @@ -142,8 +142,8 @@ def compute_fp8_delayed_scaling_overflows_num(tensor, quantized_tensor): # Map each supported FP8 dtype to its corresponding max forward value. dtype_to_max = { - tex.DType.kFloat8E4M3: Format.E4M3.value.max_fwd, - tex.DType.kFloat8E5M2: Format.E5M2.value.max_fwd, + TE_DType.kFloat8E4M3: Format.E4M3.value.max_fwd, + TE_DType.kFloat8E5M2: Format.E5M2.value.max_fwd, } if dtype not in dtype_to_max: diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 7653d5992e..cc1ad1a5dd 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -16,6 +16,7 @@ assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." load_framework_extension("torch") +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import LayerNormMLP @@ -108,14 +109,9 @@ pass # error_on_nested_jit_trace was added in PyTorch 2.2.0 # To allow for safe unpickling of QuantizedTensors when using DCP -# checkpointing with FSDP2. ``tex.DType`` (the pybind11 enum) has its -# ``__reduce_ex__`` / ``__reduce__`` overridden in the C++ binding (see -# ``transformer_engine/common/util/pybind_helper.h``) so its pickle -# stream encodes as ``(tex.DType, (int,))`` and only the class itself -# needs to be allow-listed below. +# checkpointing with FSDP2. try: from torch.serialization import add_safe_globals - import transformer_engine_torch as tex add_safe_globals( [ @@ -132,8 +128,8 @@ MXFP8Quantizer, NVFP4Quantizer, Float8BlockQuantizer, - # pybind11 enum used as Quantizer.dtype - tex.DType, + # Python IntEnum used as Quantizer.dtype + TE_DType, # __reduce_ex__ reconstructors (module-level functions). _make_float8_tensor_in_reduce_ex, _make_mxfp8_tensor_in_reduce_ex, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b38b66c3e6..18c22a0a6b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -38,6 +38,7 @@ from transformer_engine.pytorch.constants import ( AttnMaskTypes, AttnTypes, + TE_DType, dist_group_type, ) from transformer_engine.pytorch.distributed import ( @@ -1232,11 +1233,11 @@ def forward( forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True) backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False) assert forward_dtype in [ - tex.DType.kFloat8E4M3, - tex.DType.kFloat8E5M2, + TE_DType.kFloat8E4M3, + TE_DType.kFloat8E5M2, ] and backward_dtype in [ - tex.DType.kFloat8E4M3, - tex.DType.kFloat8E5M2, + TE_DType.kFloat8E4M3, + TE_DType.kFloat8E5M2, ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" else: fp8_output = False diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1f1637cecd..18aaf41e04 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -46,7 +46,10 @@ from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from transformer_engine.pytorch.quantization import get_fp8_te_dtype -from transformer_engine.pytorch.constants import TE_DType, MXFP8_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.constants import ( + TE_DType, + MXFP8_BLOCK_SCALING_SIZE, +) from transformer_engine.pytorch.utils import ( @@ -1220,6 +1223,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # Filter: cuDNN support fused_attention_backend = None if use_fused_attention: + # ``TE_DType`` is implicitly convertible to ``transformer_engine::DType`` + # on the C++ side, so pass it straight to the pybind function. q_type = TE_DType[qkv_dtype] kv_type = q_type if fp8 and fp8_meta["recipe"].fp8_dpa: diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 2aff4fd8e8..13e791af53 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -3,42 +3,101 @@ # See LICENSE for license information. """Enums for e2e transformer""" +import enum from types import SimpleNamespace import torch import torch.distributed import transformer_engine_torch as tex -""" -This is a map: torch.dtype -> int -Used for passing dtypes into cuda -extension. Has one to one mapping -with enum in transformer_engine.h -""" -TE_DType = { - torch.uint8: tex.DType.kByte, - torch.float8_e4m3fn: tex.DType.kFloat8E4M3, - torch.float8_e5m2: tex.DType.kFloat8E5M2, - torch.int32: tex.DType.kInt32, - torch.float32: tex.DType.kFloat32, - torch.half: tex.DType.kFloat16, - torch.bfloat16: tex.DType.kBFloat16, +class _TE_DTypeMeta(enum.EnumMeta): + """Metaclass that extends ``cls[key]`` / ``key in cls`` on ``TE_DType``. + + - ``TE_DType[torch.dtype]`` returns the matching ``TE_DType`` member + (replaces the old ``TORCH_DTYPE_TO_TE_DTYPE[dtype]`` pattern). + - ``torch.dtype in TE_DType`` reports whether a mapping exists. + - Anything that is not a ``torch.dtype`` (most notably member-name + strings) falls through to the standard ``EnumMeta`` behavior, so + ``TE_DType["kFloat32"]`` and ``TE_DType.kFloat32 in TE_DType`` + keep working exactly as before. + """ + + def __getitem__(cls, key): + if isinstance(key, torch.dtype): + return _TORCH_DTYPE_TO_TE_DTYPE[key] + return super().__getitem__(key) + + def __contains__(cls, key): + if isinstance(key, torch.dtype): + return key in _TORCH_DTYPE_TO_TE_DTYPE + return super().__contains__(key) + + +class TE_DType(enum.IntEnum, metaclass=_TE_DTypeMeta): + """Python mirror of ``transformer_engine_torch.DType`` (pybind11 enum). + + Members are constructed manually from the underlying pybind enum so + that this class is the single source of truth for dtype tags used + across ``transformer_engine.pytorch``. Using a Python ``IntEnum`` + avoids the per-access cost of looking up attributes on the pybind11 + enum class (which traverses C++ ``tp_getattro``) and reduces + comparisons to plain ``int.__eq__``. + + The custom metaclass adds dict-like lookup by ``torch.dtype``: + ``TE_DType[torch.float32] is TE_DType.kFloat32``. + """ + + kByte = int(tex.DType.kByte) + kInt32 = int(tex.DType.kInt32) + kFloat32 = int(tex.DType.kFloat32) + kFloat16 = int(tex.DType.kFloat16) + kBFloat16 = int(tex.DType.kBFloat16) + kFloat8E4M3 = int(tex.DType.kFloat8E4M3) + kFloat8E5M2 = int(tex.DType.kFloat8E5M2) + kFloat4E2M1 = int(tex.DType.kFloat4E2M1) + + +# Fail fast at import time if a new enumerator is added +# on the C++ side without being mirrored above. +assert {m.name for m in TE_DType} == set(tex.DType.__members__), ( + "TE_DType is out of sync with transformer_engine_torch.DType; " + "add the new pybind enumerator to TE_DType in constants.py." +) + + +# Private one-to-one mapping ``torch.dtype -> TE_DType`` (mirrors the +# enum order in ``transformer_engine.h``). The metaclass above forwards +# ``TE_DType[torch_dtype]`` to this dict, so callers should use the +# bracket syntax on ``TE_DType`` rather than importing this directly. +_TORCH_DTYPE_TO_TE_DTYPE = { + torch.uint8: TE_DType.kByte, + torch.float8_e4m3fn: TE_DType.kFloat8E4M3, + torch.float8_e5m2: TE_DType.kFloat8E5M2, + torch.int32: TE_DType.kInt32, + torch.float32: TE_DType.kFloat32, + torch.half: TE_DType.kFloat16, + torch.bfloat16: TE_DType.kBFloat16, } -""" -This is a map: int -> torch.dtype -Used for resolving cuda extension types to torch. -Has one to one mapping with enum in -transformer_engine.h -""" + +# Map ``TE_DType -> torch.dtype`` for resolving cuda extension types to +# torch. One-to-one with the enum in ``transformer_engine.h``. +# +# C++ sites that stamp dtype tags onto Python tensors (e.g. ``_fp8_dtype``, +# ``_fp4_dtype``) route through the ``MakeTEDType`` helper in +# ``transformer_engine/pytorch/csrc/common.{h,cpp}``, so every key we +# look up here is guaranteed to be a ``TE_DType`` member. Keep this dict +# keyed by ``TE_DType`` (not ``int``) so accidental mixing with the +# pybind11 ``tex.DType`` enum surfaces as a ``KeyError`` instead of +# silently succeeding. TE_DType_To_Torch = { - tex.DType.kByte: torch.uint8, - tex.DType.kFloat8E4M3: torch.float8_e4m3fn, - tex.DType.kFloat8E5M2: torch.float8_e5m2, - tex.DType.kInt32: torch.int32, - tex.DType.kFloat32: torch.float32, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, + TE_DType.kByte: torch.uint8, + TE_DType.kFloat8E4M3: torch.float8_e4m3fn, + TE_DType.kFloat8E5M2: torch.float8_e5m2, + TE_DType.kInt32: torch.int32, + TE_DType.kFloat32: torch.float32, + TE_DType.kFloat16: torch.half, + TE_DType.kBFloat16: torch.bfloat16, } # Cache enum -> int conversions to avoid repeated PyObject lookups. diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 2ce939430d..b582c27bd5 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -16,7 +16,7 @@ NVTE_Fused_Attn_Backend, ) from ..quantized_tensor import Quantizer -from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx +from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, TE_DType __all__ = [ @@ -26,12 +26,12 @@ TORCH_DType = { - tex.DType.kFloat8E4M3: torch.uint8, - tex.DType.kFloat8E5M2: torch.uint8, - tex.DType.kFloat16: torch.half, - tex.DType.kBFloat16: torch.bfloat16, - tex.DType.kFloat32: torch.float32, - tex.DType.kInt32: torch.int32, + TE_DType.kFloat8E4M3: torch.uint8, + TE_DType.kFloat8E5M2: torch.uint8, + TE_DType.kFloat16: torch.half, + TE_DType.kBFloat16: torch.bfloat16, + TE_DType.kFloat32: torch.float32, + TE_DType.kInt32: torch.int32, } QKVFormat = { @@ -173,7 +173,7 @@ def fused_attn_fwd( input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details) v : torch.Tensor input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details) - fake_dtype : tex.DType + fake_dtype : TE_DType data type of Q, K and V - in case of high precision, fake dtype in case of FP8; in torch.dtype fused_attention_backend : tex.NVTE_Fused_Attn_Backend @@ -455,7 +455,7 @@ def fused_attn_bwd( d_o : torch.Tensor input tensor dO (gradient of O); same data type as Q, K and V; same shape as Q - fake_dtype : tex.DType + fake_dtype : TE_DType data type of Q, K and V - in case of high precision, fake dtype in case of FP8; in torch.dtype aux_ctx_tensors : List[torch.Tensor] diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index edf2c1e1c2..2bf884bd81 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -288,7 +288,7 @@ def general_grouped_gemm( bias: Optional[List[torch.Tensor]] = None, use_bias: bool = False, use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, + D_dtype: Optional[TE_DType] = None, single_output=False, ) -> Tuple[List[torch.Tensor], ...]: """ diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 66bb2dc40e..4f85920671 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -86,6 +86,16 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, return transformer_engine::DType::kFloat8E5M2; } +pybind11::object MakeTEDType(transformer_engine::DType dtype) { + // Cache the Python ``TE_DType`` class object on first call so subsequent + // invocations avoid re-importing the module. ``static`` initialization is + // thread-safe under C++11. We are always inside a pybind11-invoked function + // when this runs, so the GIL is held and Python imports are legal. + static pybind11::object te_dtype_cls = + pybind11::module_::import("transformer_engine.pytorch.constants").attr("TE_DType"); + return te_dtype_cls(static_cast(dtype)); +} + TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) { NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!"); std::unique_ptr my_quantizer = convert_quantizer(quantizer); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 94350da1e6..8d0093e6eb 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -388,6 +388,24 @@ std::vector getTensorShape(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); +/*! @brief Wrap a C++ ``transformer_engine::DType`` as the Python + * ``transformer_engine.pytorch.TE_DType`` ``IntEnum`` member. + * + * pybind11's default ``py::cast`` of a C++ enum produces an instance + * of the pybind11-bound enum class (``tex.DType``), which does not + * compare equal to a ``TE_DType`` member of the same int value (cross- + * type equality is not implemented). To keep tensor attributes like + * ``_fp8_dtype`` typed consistently as ``TE_DType`` regardless of + * whether the tensor was constructed from Python or from C++, all C++ + * sites that set such attributes should go through this helper. + * + * The Python class object is imported once and cached in a function- + * local ``static`` (initialization is thread-safe under C++11), so the + * runtime cost per call is one ``IntEnum.__call__`` (a dict lookup in + * ``_value2member_map_``). + */ +pybind11::object MakeTEDType(transformer_engine::DType dtype); + inline size_t typeToNumBits(transformer_engine::DType t) { switch (t) { case transformer_engine::DType::kInt64: diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2b38339d67..3fbc80ab2e 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -573,10 +573,12 @@ std::tuple, std::vector> bulk_allocate_fp py::object columnwise_scale = (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); - // Construct Python tensor + // Construct Python tensor (wrap C++ DType so ``_fp8_dtype`` on the + // Python tensor is the Python ``TE_DType`` IntEnum, matching the + // contract documented in ``common.h``::MakeTEDType). tensor_py_list.emplace_back( Float8BlockwiseQTensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, - fp8_dtype, quantizer_py_list[i], is_2D_scaled)); + MakeTEDType(fp8_dtype), quantizer_py_list[i], is_2D_scaled)); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( @@ -679,10 +681,12 @@ std::tuple, std::vector> bulk_allocate_mx py::object columnwise_scale = (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); - // Construct Python tensor + // Construct Python tensor (wrap C++ DType so ``_fp8_dtype`` on the + // Python tensor is the Python ``TE_DType`` IntEnum, matching the + // contract documented in ``common.h``::MakeTEDType). tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, - columnwise_scale, fp8_dtype, quantizer_py_list[i], - with_gemm_swizzled_scales)); + columnwise_scale, MakeTEDType(fp8_dtype), + quantizer_py_list[i], with_gemm_swizzled_scales)); // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( @@ -865,10 +869,12 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_rowwise = rowwise_usage ? py::cast(amax_rowwise_list[i]) : py::none(); py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); - // Construct Python tensor + // Construct Python tensor (wrap C++ DType so ``_fp4_dtype`` on the + // Python tensor is the Python ``TE_DType`` IntEnum, matching the + // contract documented in ``common.h``::MakeTEDType). tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, amax_columnwise, - fp4_dtype, quantizer_py_list[i], + MakeTEDType(fp4_dtype), quantizer_py_list[i], with_gemm_swizzled_scales, row_scaled_nvfp4)); // Construct C++ tensor diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7045995dd7..213195fe84 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -319,7 +319,7 @@ std::pair Float8Quantizer::create_tensor( py::tuple args(0); kwargs["data"] = data_py; kwargs["fp8_scale_inv"] = scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["fp8_dtype"] = MakeTEDType(this->dtype); kwargs["data_transpose"] = transpose_py; kwargs["quantizer"] = this->quantizer; kwargs["fake_dtype"] = GetATenDType(dtype); @@ -343,7 +343,7 @@ std::pair Float8Quantizer::create_tensor( kwargs["dtype"] = py::cast(GetATenDType(dtype)); kwargs["data"] = data_py; kwargs["fp8_scale_inv"] = scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["fp8_dtype"] = MakeTEDType(this->dtype); kwargs["data_transpose"] = transpose_py; kwargs["quantizer"] = this->quantizer; kwargs["device"] = py::cast(device); @@ -527,7 +527,7 @@ std::pair Float8Quantizer::convert_and_update_tensor( tensor.attr("_transpose_invalid") = !need_transpose; // Coerce other attrs - tensor.attr("_fp8_dtype") = dtype; + tensor.attr("_fp8_dtype") = MakeTEDType(dtype); // Construct C++ FP8 tensor TensorWrapper out_cpp; @@ -628,7 +628,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso py::dict kwargs; kwargs["data"] = data_py; kwargs["fp8_scale_inv"] = scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["fp8_dtype"] = MakeTEDType(this->dtype); kwargs["data_transpose"] = transpose_py; kwargs["quantizer"] = this->quantizer; kwargs["fake_dtype"] = GetATenDType(dtype); @@ -651,7 +651,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso kwargs["dtype"] = py::cast(GetATenDType(dtype)); kwargs["data"] = data_py; kwargs["fp8_scale_inv"] = scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["fp8_dtype"] = MakeTEDType(this->dtype); kwargs["data_transpose"] = transpose_py; kwargs["quantizer"] = this->quantizer; kwargs["device"] = py::cast(device); @@ -854,7 +854,7 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ tensor.attr("_transpose_invalid") = !need_transpose; // Coerce other attrs - tensor.attr("_fp8_dtype") = dtype; + tensor.attr("_fp8_dtype") = MakeTEDType(dtype); // Construct C++ FP8 tensor TensorWrapper out_cpp; @@ -1019,7 +1019,7 @@ std::pair Float8BlockQuantizer::create_tensor( kwargs["columnwise_data"] = py::cast(data_colwise); kwargs["rowwise_scale_inv"] = py::cast(scale_inv_rowwise); kwargs["columnwise_scale_inv"] = py::cast(scale_inv_colwise); - kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["fp8_dtype"] = MakeTEDType(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); kwargs["fake_dtype"] = GetATenDType(dtype); @@ -1045,7 +1045,7 @@ std::pair Float8BlockQuantizer::create_tensor( kwargs["columnwise_data"] = py::cast(data_colwise); kwargs["rowwise_scale_inv"] = py::cast(scale_inv_rowwise); kwargs["columnwise_scale_inv"] = py::cast(scale_inv_colwise); - kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["fp8_dtype"] = MakeTEDType(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["is_2D_scaled"] = py::cast(block_scaling_dim == 2); kwargs["device"] = py::cast(device); @@ -1423,7 +1423,7 @@ std::pair MXFP8Quantizer::create_tensor( kwargs["columnwise_data"] = columnwise_data_py; kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["fp8_dtype"] = MakeTEDType(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["fake_dtype"] = GetATenDType(dtype); @@ -1447,7 +1447,7 @@ std::pair MXFP8Quantizer::create_tensor( kwargs["columnwise_data"] = columnwise_data_py; kwargs["rowwise_scale_inv"] = rowwise_scale_inv_py; kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; - kwargs["fp8_dtype"] = py::cast(this->dtype); + kwargs["fp8_dtype"] = MakeTEDType(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["device"] = py::cast(device); @@ -1657,7 +1657,7 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( } // Coerce other attrs - tensor.attr("_fp8_dtype") = dtype; + tensor.attr("_fp8_dtype") = MakeTEDType(dtype); tensor.attr("_with_gemm_swizzled_scales") = with_gemm_swizzled_scales; // Construct C++ MXFP8 tensor @@ -1841,7 +1841,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; kwargs["amax_rowwise"] = amax_rowwise_py; kwargs["amax_columnwise"] = amax_columnwise_py; - kwargs["fp4_dtype"] = py::cast(this->dtype); + kwargs["fp4_dtype"] = MakeTEDType(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); @@ -1870,7 +1870,7 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["columnwise_scale_inv"] = columnwise_scale_inv_py; kwargs["amax_rowwise"] = amax_rowwise_py; kwargs["amax_columnwise"] = amax_columnwise_py; - kwargs["fp4_dtype"] = py::cast(this->dtype); + kwargs["fp4_dtype"] = MakeTEDType(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["device"] = py::cast(device); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py index d660e5a53b..0147d55f78 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py @@ -38,8 +38,7 @@ from typing import Optional -import transformer_engine_torch as tex - +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.quantization import QuantizerRole @@ -70,7 +69,7 @@ def _make_mxfp8_quantizer(): from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer return MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) @@ -87,7 +86,7 @@ def _make_nvfp4_quantizer(role: Optional[QuantizerRole]): if is_weight: return NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, + fp4_dtype=TE_DType.kFloat4E2M1, with_rht=False, with_post_rht_amax=False, with_2d_quantization=True, @@ -97,7 +96,7 @@ def _make_nvfp4_quantizer(role: Optional[QuantizerRole]): if is_grad: return NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, + fp4_dtype=TE_DType.kFloat4E2M1, rowwise=True, columnwise=True, with_rht=True, @@ -108,7 +107,7 @@ def _make_nvfp4_quantizer(role: Optional[QuantizerRole]): ) return NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, + fp4_dtype=TE_DType.kFloat4E2M1, rowwise=True, columnwise=True, with_rht=True, @@ -180,7 +179,7 @@ def nvfp4_linear_fp8_dpa_factory( if is_dpa: return Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, device="cuda", ) @@ -193,7 +192,7 @@ def nvfp4_linear_fp8_dpa_factory( ) if is_dpa_boundary: return Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, device="cuda", ) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py index 22eafaa665..24f46151f6 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py @@ -28,8 +28,7 @@ from typing import Optional import torch -import transformer_engine_torch as tex - +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.quantization import QuantizerRole @@ -65,7 +64,7 @@ def current_scaling_quantizer_factory( ) is_backward = role is not None and role.tensor_type == "grad_output" - fp8_dtype = tex.DType.kFloat8E5M2 if is_backward else tex.DType.kFloat8E4M3 + fp8_dtype = TE_DType.kFloat8E5M2 if is_backward else TE_DType.kFloat8E4M3 return Float8CurrentScalingQuantizer( fp8_dtype=fp8_dtype, @@ -86,7 +85,7 @@ def mxfp8_quantizer_factory( from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer return MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) @@ -111,7 +110,7 @@ def float8_block_scaling_quantizer_factory( block_scaling_dim = 2 if is_weight else 1 return Float8BlockQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=True, amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero @@ -146,7 +145,7 @@ def nvfp4_quantizer_factory( if is_weight: return NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, + fp4_dtype=TE_DType.kFloat4E2M1, with_rht=False, with_post_rht_amax=False, with_2d_quantization=True, @@ -156,7 +155,7 @@ def nvfp4_quantizer_factory( if is_grad: return NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, + fp4_dtype=TE_DType.kFloat4E2M1, rowwise=True, columnwise=True, with_rht=True, @@ -168,7 +167,7 @@ def nvfp4_quantizer_factory( # For input and unknown roles return NVFP4Quantizer( - fp4_dtype=tex.DType.kFloat4E2M1, + fp4_dtype=TE_DType.kFloat4E2M1, rowwise=True, columnwise=True, with_rht=True, diff --git a/transformer_engine/pytorch/onnx_extensions.py b/transformer_engine/pytorch/onnx_extensions.py index 4d3b90bf63..6f164168fb 100644 --- a/transformer_engine/pytorch/onnx_extensions.py +++ b/transformer_engine/pytorch/onnx_extensions.py @@ -30,7 +30,7 @@ from .tensor.float8_tensor import Float8Quantizer from .tensor.mxfp8_tensor import MXFP8Quantizer -from .constants import MXFP8_BLOCK_SCALING_SIZE +from .constants import MXFP8_BLOCK_SCALING_SIZE, TE_DType from .utils import round_up_to_nearest_multiple from .export import is_in_onnx_export_mode @@ -85,7 +85,7 @@ def onnx_quantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor: """Quantize to Float8Tensor used for inference.""" scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device) amax_tensor = torch.tensor([1], dtype=torch.float32, device=tensor.device) - quantizer = Float8Quantizer(scale_tensor, amax_tensor, tex.DType.kFloat8E4M3) + quantizer = Float8Quantizer(scale_tensor, amax_tensor, TE_DType.kFloat8E4M3) return quantizer.quantize(tensor)._data @@ -131,7 +131,7 @@ def onnx_quantize_fp8_symbolic( def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: """Dequantize from Float8Tensor used for inference.""" quantizer = Float8Quantizer( - 1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3 + 1 / scale_inv, torch.zeros(1).to(tensor.device), TE_DType.kFloat8E4M3 ) quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32) return quantizer_tensor.dequantize() @@ -212,7 +212,7 @@ def onnx_quantize_fp8_cs_symbolic( @torch.library.custom_op("tex::mxfp8_quantize", mutates_args=[]) def onnx_quantize_mxfp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize to MXFP8Tensor used for inference.""" - quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3) + quantizer = MXFP8Quantizer(TE_DType.kFloat8E4M3) quantized_tensor = quantizer(tensor) return quantized_tensor._rowwise_data, quantized_tensor._rowwise_scale_inv @@ -264,7 +264,7 @@ def onnx_quantize_mxfp8_symbolic( @torch.library.custom_op("tex::mxfp8_dequantize", mutates_args=[]) def onnx_dequantize_mxfp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: """Dequantize from MXFP8Tensor used for inference.""" - quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3) + quantizer = MXFP8Quantizer(TE_DType.kFloat8E4M3) quantizer_tensor = quantizer.create_tensor_from_data( tensor, scale_inv, fake_dtype=torch.float32 ) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index eacc36b36c..3efacb20fa 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -12,6 +12,7 @@ import torch import transformer_engine_torch as tex +from ...constants import TE_DType from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data @@ -107,7 +108,7 @@ def op_forward( # Quantize input to FP8 before caching if needed if self.cache_quantized_input: - input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + input_quantizer = Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, x.device) input_quantizer.set_usage(rowwise=True, columnwise=False) x = input_quantizer(x) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 1f00d92284..32adc6c5ff 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -14,6 +14,7 @@ import torch import transformer_engine_torch as tex +from ...constants import TE_DType from ...cpp_extensions import general_grouped_gemm, general_grouped_gemm_for_grouped_tensor from ...distributed import CudaRNGStatesTracker from ...module._common import WeightGradStore @@ -518,7 +519,7 @@ def _quantize_weights_mxfp8( weight = MXFP8Tensor( shape=unpacked_shape, dtype=dtype, - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, rowwise_data=rowwise_data[group_idx], rowwise_scale_inv=rowwise_scales[group_idx], columnwise_data=columnwise_data[group_idx], diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index 9267d9bbbb..39fbbcb689 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -11,6 +11,7 @@ import torch import transformer_engine_torch as tex +from ...constants import TE_DType from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data @@ -118,7 +119,7 @@ def op_forward( # Quantize input to FP8 before caching if needed if self.cache_quantized_input: input_quantizer = Float8CurrentScalingQuantizer( - tex.DType.kFloat8E4M3, + TE_DType.kFloat8E4M3, input_.device, ) input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -297,7 +298,7 @@ def op_forward( # Quantize input to FP8 before caching if needed if self.cache_quantized_input: - input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + input_quantizer = Float8CurrentScalingQuantizer(TE_DType.kFloat8E4M3, x.device) input_quantizer.set_usage(rowwise=True, columnwise=False) x = input_quantizer(x) diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 828c34f539..e6b504f4cf 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -13,6 +13,7 @@ import torch from torch.distributed._tensor import DTensor import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from .multi_tensor_apply import multi_tensor_applier @@ -422,7 +423,7 @@ def _initialize_state( quantizer = Float8Quantizer( scale=torch.ones([1], dtype=torch.float32, device=param.device), amax=torch.zeros([1], dtype=torch.float32, device=param.device), - fp8_dtype=tex.DType.kFloat8E4M3, + fp8_dtype=TE_DType.kFloat8E4M3, ) self.state[param][state_name] = quantizer.make_empty(data.shape) self.state[param][state_name].quantize_(data.float()) @@ -599,7 +600,7 @@ def step(self, closure=None, grad_scaler=None): state_scales = {"exp_avg": [], "exp_avg_sq": [], "master_param": []} # Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is. - out_dtype = tex.DType.kFloat32 + out_dtype = TE_DType.kFloat32 has_fp16 = False has_bf16 = False diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0c40723517..ef0f9dc84c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -26,7 +26,7 @@ NVFP4BlockScaling, CustomRecipe, ) -from .constants import dist_group_type +from .constants import dist_group_type, TE_DType from .utils import get_device_compute_capability from .jit import jit_fuser @@ -277,23 +277,23 @@ def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch. return torch.float8_e5m2 -def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: +def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> TE_DType: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor ): - return tex.DType.kFloat8E4M3 - return tex.DType.kFloat8E5M2 + return TE_DType.kFloat8E4M3 + return TE_DType.kFloat8E5M2 -def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: +def get_fp4_te_dtype(fp4_recipe: Recipe) -> TE_DType: """Get fp4 data type according to recipe and tensor""" if fp4_recipe.fp4_format == Format.E2M1: - return tex.DType.kFloat4E2M1 + return TE_DType.kFloat4E2M1 raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") -def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: +def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> TE_DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -1382,7 +1382,7 @@ class DelayedScalingRecipeState(RecipeState): recipe: DelayedScaling mode: str - dtype: tex.DType + dtype: TE_DType scale: torch.Tensor amax_history: torch.Tensor @@ -1436,7 +1436,7 @@ class Float8CurrentScalingRecipeState(RecipeState): recipe: Float8CurrentScaling mode: str - dtype: tex.DType + dtype: TE_DType device: torch.device def __init__( @@ -1480,7 +1480,7 @@ class MXFP8BlockScalingRecipeState(RecipeState): recipe: MXFP8BlockScaling mode: str - dtype: tex.DType + dtype: TE_DType def __init__( self, @@ -1518,9 +1518,9 @@ class Float8BlockScalingRecipeState(RecipeState): recipe: Float8BlockScaling mode: str - qx_dtype: tex.DType - qw_dtype: tex.DType - qgrad_dtype: tex.DType + qx_dtype: TE_DType + qw_dtype: TE_DType + qgrad_dtype: TE_DType def __init__( self, @@ -1605,7 +1605,7 @@ class NVFP4BlockScalingRecipeState(RecipeState): recipe: NVFP4BlockScaling mode: str - dtype: tex.DType + dtype: TE_DType def __init__( self, diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index e091e27e59..086b538798 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -11,11 +11,11 @@ import torch import transformer_engine_torch as tex -from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc +from ..constants import TE_DType from ..utils import devices_match, round_up_to_nearest_multiple aten = torch.ops.aten diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 7842ccc127..22a064498b 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -9,7 +9,6 @@ import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex -from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import ( DelayedScaling, @@ -20,7 +19,7 @@ from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc -from ..constants import dist_group_type +from ..constants import dist_group_type, TE_DType aten = torch.ops.aten diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 2815aaa96e..264efe54cf 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -12,10 +12,9 @@ import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex -from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe -from ..constants import MXFP8_BLOCK_SCALING_SIZE +from ..constants import MXFP8_BLOCK_SCALING_SIZE, TE_DType from ..utils import devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer @@ -148,7 +147,7 @@ def create_tensor_from_data( data: torch.Tensor, scale_inv: torch.Tensor, fake_dtype: torch.dtype, - fp8_dtype: TE_DType = tex.DType.kFloat8E4M3, + fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, ) -> MXFP8Tensor: """Create a new MXFP8Tensor from data and scale_inv.""" return MXFP8Tensor( diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 2ebefefaaa..8d9222ff81 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -12,10 +12,9 @@ import torch import transformer_engine_torch as tex -from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe -from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type +from ..constants import NVFP4_BLOCK_SCALING_SIZE, TE_DType, dist_group_type from ..utils import ( canonicalize_process_group, devices_match, @@ -137,7 +136,7 @@ class NVFP4Quantizer(Quantizer): def __init__( self, - fp4_dtype: TE_DType = tex.DType.kFloat4E2M1, + fp4_dtype: TE_DType = TE_DType.kFloat4E2M1, rowwise: bool = True, columnwise: bool = True, with_amax_reduction: bool = False, diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index ca3913762f..f57e1d2196 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -10,11 +10,10 @@ import torch import transformer_engine_torch as tex -from transformer_engine_torch import DType as TE_DType from ...quantized_tensor import QuantizedTensorStorage, Quantizer -from ...constants import TE_DType_To_Torch +from ...constants import TE_DType, TE_DType_To_Torch from ...utils import _empty_tensor diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 3a72ec5d1a..38d7b30b8c 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -10,11 +10,13 @@ import torch import transformer_engine_torch as tex -from transformer_engine_torch import DType as TE_DType from ...quantized_tensor import QuantizedTensorStorage, Quantizer -from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch +from ...constants import ( + TE_DType, + TE_DType_To_Torch, +) from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor @@ -29,7 +31,7 @@ def forward( dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - te_dtype = torch_to_transformer_engine_dtype[dtype] + te_dtype = TE_DType[dtype] # Make sure FP8 data is in expected format if tensor._data is not None: @@ -42,7 +44,9 @@ def forward( tensor._data.view(fp8_torch_dtype).float() * tensor._scale_inv.to(tensor._data.device) ).to(dtype) - # Cast from FP8 + # Cast from FP8. ``TE_DType`` is implicitly convertible to + # ``transformer_engine::DType`` on the C++ side, so pass it + # directly to ``tex.dequantize``. return tex.dequantize(tensor, te_dtype) raise NotImplementedError("Casting back from the transpose not implemented yet!") diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 874555f465..1c6c05f08f 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -11,11 +11,10 @@ import torch import transformer_engine_torch as tex -from transformer_engine_torch import DType as TE_DType from ...quantized_tensor import QuantizedTensorStorage, Quantizer -from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...constants import TE_DType from ...utils import _empty_tensor @@ -37,8 +36,10 @@ def forward( if tensor._rowwise_data is None and tensor._columnwise_data is None: raise ValueError("Cannot dequantize MXFP8 tensor with no data") - te_dtype = torch_to_transformer_engine_dtype[dtype] - # ``tex.dequantize`` requires CUDA-resident buffers. + # ``tex.dequantize`` requires CUDA-resident buffers. ``TE_DType`` + # is implicitly convertible to ``transformer_engine::DType`` on + # the C++ side (see ``pybind_helper.h``), so pass it directly. + te_dtype = TE_DType[dtype] src_device = tensor.device if src_device.type != "cuda": cuda_tensor = tensor.to(device=torch.device("cuda")) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 490184e5f8..31cb9a8786 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -14,11 +14,10 @@ import torch import transformer_engine_torch as tex -from transformer_engine_torch import DType as TE_DType from ...quantized_tensor import QuantizedTensorStorage, Quantizer -from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...constants import TE_DType from ...utils import _empty_tensor @@ -56,9 +55,9 @@ def forward( src_device = tensor.device if src_device.type != "cuda": cuda_tensor = tensor.to(device=torch.device("cuda")) - result = tex.dequantize(cuda_tensor, torch_to_transformer_engine_dtype[dtype]) + result = tex.dequantize(cuda_tensor, TE_DType[dtype]) return result.to(device=src_device) - return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype]) + return tex.dequantize(tensor, TE_DType[dtype]) @staticmethod def backward( diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 8b22097f7e..98e3f68a47 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -21,7 +21,7 @@ from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from ..optimizers.multi_tensor_apply import multi_tensor_applier from ..utils import is_non_tn_fp8_gemm_supported -from ..constants import NVFP4_BLOCK_SCALING_SIZE +from ..constants import NVFP4_BLOCK_SCALING_SIZE, TE_DType def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): @@ -367,9 +367,9 @@ def _cast_master_weights_to_fp8_current_scaling( # --------------------------------------------------------------------------------------------- # Step 3: Update scales and scale_invs. # --------------------------------------------------------------------------------------------- - if fp8_dtype == tex.DType.kFloat8E4M3: + if fp8_dtype == TE_DType.kFloat8E4M3: max_fp8 = 448.0 - elif fp8_dtype == tex.DType.kFloat8E5M2: + elif fp8_dtype == TE_DType.kFloat8E5M2: max_fp8 = 57344.0 else: raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}") @@ -530,9 +530,9 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # --------------------------------------------------------------------------------------------- # Step 3: Update scales and scale_invs. # --------------------------------------------------------------------------------------------- - if fp8_dtype == tex.DType.kFloat8E4M3: + if fp8_dtype == TE_DType.kFloat8E4M3: max_fp8 = 448.0 - elif fp8_dtype == tex.DType.kFloat8E5M2: + elif fp8_dtype == TE_DType.kFloat8E5M2: max_fp8 = 57344.0 else: raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")