From 164f612405e6e300c2334ab891958ed423d52e3b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 12 May 2026 16:07:55 +0200 Subject: [PATCH 01/16] [PyTorch] Add fake implementations for Linear forward/backward Add torch custom-op fake implementations for `_linear_forward_impl` and `_linear_backward` so torch.compile can perform shape inference without running the real GEMM / communication / quantization paths. * `fake_cast_if_needed` in `utils.py`: returns an empty tensor of the target dtype when a cast would happen, otherwise returns the input. * `fake_quantize_weight` in `module/base.py`: mirrors the cache-hit / cache-miss control flow of `quantize_weight` but fills cache misses with `quantizer.make_empty`. * `_linear_forward_fake_impl` in `module/linear.py`: mirrors the real forward's control flow and `set_usage` / `update_usage` calls, but replaces computation with empty tensors and skips side effects irrelevant for shape inference (CPU offload, calibration, NCCL/UB collectives, `clear_tensor_data`, FSDP scatter). Communication is simulated by producing tensors with the post-comm shape. Manual TE FSDP (`fsdp_group is not None`) is unsupported. * `_linear_backward_fake_impl` in `module/linear.py`: backward output shapes/dtypes are deterministic, so just allocates empty tensors of the right shape. `grad_input_quantizer.set_usage` is preserved because it influences `dgrad`'s `make_empty`. Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/base.py | 55 ++++ transformer_engine/pytorch/module/linear.py | 321 ++++++++++++++++++++ transformer_engine/pytorch/utils.py | 15 + 3 files changed, 391 insertions(+) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 746177ec78..6031c809ca 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -763,6 +763,61 @@ def quantize_weight( return out, None +def fake_quantize_weight( + *, + tensor: Optional[torch.Tensor] = None, + quantizer: Optional[Quantizer] = None, + workspace: Optional[QuantizedTensorStorage] = None, + fsdp_group: Optional["dist_group_type"] = None, + workspace_dtype: Optional[torch.dtype] = None, + cache: bool = False, +) -> Tuple[QuantizedTensorStorage, Optional[QuantizedTensorStorage]]: + """Fake counterpart of :func:`quantize_weight` for shape inference. + + Mirrors the cache-hit / cache-miss control flow of :func:`quantize_weight` + but never performs an actual quantization. Cache misses are filled with + ``quantizer.make_empty``. Used by torch custom-op fake registrations. + """ + + # Already-quantized weight (primary FP8 parameters) + if isinstance(tensor, QuantizedTensor): + update_rowwise = True if quantizer.rowwise_usage else None + update_columnwise = True if quantizer.columnwise_usage else None + tensor.update_usage( + rowwise_usage=update_rowwise, + columnwise_usage=update_columnwise, + ) + return tensor, None + + # Validate workspace + if workspace is not None and quantizer is not None: + if not _is_weight_workspace_valid(workspace, quantizer): + workspace = None + + if workspace is not None and fsdp_group is not None: + raise NotImplementedError( + "fake_quantize_weight does not support FSDP weight workspaces" + ) + + # Cache hit + if workspace is not None: + return workspace, None + + # Cache miss — create new (fake) workspace + if tensor is None or quantizer is None: + raise ValueError( + "tensor and quantizer kwargs must be provided to construct FP8 workspace" + ) + out = quantizer.make_empty( + tensor.shape, + dtype=workspace_dtype, + device=tensor.device, + ) + if cache: + return out, out + return out, None + + class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index dcbb9eaf93..ff1a55e1f1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -18,6 +18,7 @@ from transformer_engine.pytorch.torch_version import torch_version from .base import ( + fake_quantize_weight, fill_userbuffers_buffer_for_all_gather, get_dummy_wgrad, get_ub, @@ -33,6 +34,7 @@ cast_if_needed, clear_tensor_data, divide, + fake_cast_if_needed, init_method_constant, needs_quantized_gemm, assert_dim_for_fp8_exec, @@ -601,6 +603,275 @@ def _linear_forward_impl( return out, new_weight_workspace, tensors_to_save_from_forward, None, ctx_attrs +def _linear_forward_fake_impl( + args: LinearFwdArgs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple], None, Optional[Dict]]: + """Fake :func:`_linear_forward_impl` for torch custom-op shape inference. + + Mirrors the real control flow and ``set_usage`` / ``update_usage`` + calls, but replaces computation with empty tensors and skips side + effects irrelevant for shape inference (CPU offload, calibration, + NCCL/UB collectives, ``clear_tensor_data``, FSDP scatter). + """ + + # The few locals below are mutated later in the function; everything + # else is read directly off ``args``. This shape-inference helper is + # not on a hot path, so we don't bother caching attribute lookups. + save_original_input = args.save_original_input + if args.backward_override == "high_precision": + save_original_input = True + weight_quantizer = args.weight_quantizer + + out_features, in_features = args.weight.shape + assert args.inp.shape[-1] == in_features, "GEMM not possible" + + tp_world_size = get_distributed_world_size(args.tp_group) + backward_needs_input = args.is_grad_enabled and args.weight_requires_grad + with_input_all_gather_nccl = ( + args.parallel_mode == "column" + and args.sequence_parallel + and not args.ub_overlap_ag_fprop + ) + + # ------------------------------------------------------ + # Prepare input tensor + # ------------------------------------------------------ + # ``inputmat`` may become a ``QuantizedTensorStorage`` (which does not + # always expose ``.shape``), so track the logical shape separately. + inputmat = args.inp + inputmat_shape = list(args.inp.shape) + inputmat_total = None + inputmat_total_shape: List[int] = inputmat_shape + own_quantized_input = False + if args.fp8: + assert_dim_for_fp8_exec(inputmat, args.weight) + if save_original_input: + assert not isinstance( + args.input_quantizer, Float8Quantizer + ), "DelayedScaling recipe is not supported with save_original_input" + + if with_input_all_gather_nccl or args.ub_overlap_ag_fprop: + + if args.fp8 or args.debug: + if args.input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if not isinstance(inputmat, QuantizedTensorStorage) and not args.custom: + own_quantized_input = True + args.input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and args.backward_override is None, + ) + if isinstance( + args.input_quantizer, (Float8CurrentScalingQuantizer, Float8Quantizer) + ): + args.input_quantizer.set_usage(columnwise=False) + if save_original_input: + args.input_quantizer.set_usage(columnwise=False) + own_quantized_input = False + inputmat = args.input_quantizer.make_empty( + inputmat.shape, + dtype=args.activation_dtype, + device=inputmat.device, + ) + else: + inputmat = fake_cast_if_needed(args.inp, args.activation_dtype) + + # Initialize gathered input tensor (interleaved set_usage stays). + quantizer = None + if args.fp8 or args.debug: + quantizer = args.input_quantizer + quantizer.set_usage(rowwise=True, columnwise=False) + + gathered_shape = list(inputmat_shape) + gathered_shape[0] *= tp_world_size + inputmat_total_shape = gathered_shape + if quantizer is not None: + inputmat_total = quantizer.make_empty( + gathered_shape, + dtype=args.activation_dtype, + device=args.inp.device, + ) + else: + inputmat_total = torch.empty( + gathered_shape, dtype=args.activation_dtype, device=args.inp.device + ) + + else: + if args.fp8 or args.debug: + if isinstance(inputmat, QuantizedTensorStorage): + inputmat.update_usage(rowwise_usage=True) + else: + if args.input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + args.input_quantizer.set_usage( + rowwise=True, + columnwise=( + backward_needs_input + and not save_original_input + and args.backward_override is None + ), + ) + inputmat = args.input_quantizer.make_empty( + inputmat.shape, + dtype=args.activation_dtype, + device=inputmat.device, + ) + own_quantized_input = True + else: + inputmat = fake_cast_if_needed(args.inp, args.activation_dtype) + inputmat_total = inputmat + inputmat_total_shape = inputmat_shape + + # ------------------------------------------------------ + # Prepare weight tensor + # ------------------------------------------------------ + new_weight_workspace = None + weightmat = args.weight + if args.fp8 or args.debug: + if weight_quantizer is not None and ( + not isinstance(args.weight, QuantizedTensor) or args.debug + ): + columnwise_usage = ( + args.is_grad_enabled and args.input_requires_grad and not args.is_fsdp2 + ) + if args.backward_override is not None: + columnwise_usage = False + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + elif isinstance(args.weight, QuantizedTensor): + weight_quantizer = args.weight._quantizer + weightmat, new_weight_workspace = fake_quantize_weight( + tensor=args.weight, + quantizer=weight_quantizer, + workspace=args.weight_workspace, + fsdp_group=args.fsdp_group, + workspace_dtype=args.activation_dtype, + cache=args.cache_weight, + ) + weightmat.update_usage(rowwise_usage=True) + else: + weightmat = fake_cast_if_needed(weightmat, args.activation_dtype) + + # Cast bias to expected dtype + bias_dtype = args.activation_dtype + if needs_quantized_gemm(inputmat_total) and args.activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + bias = fake_cast_if_needed(args.bias, bias_dtype) if args.bias is not None else args.bias + + # Configure output quantizer + if args.output_quantizer is not None: + args.output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Output buffer for Userbuffers reduce-scatter (allocated with the + # post-RS shape so downstream consumers see consistent dimensions). + reduce_scatter_out = None + if args.ub_overlap_rs_fprop: + out_shape = list(args.inp.shape) + out_shape[0] //= tp_world_size + out_shape[-1] = out_features + reduce_scatter_out = torch.empty( + out_shape, dtype=args.activation_dtype, device=args.inp.device + ) + + # ------------------------------------------------------ + # Forward GEMM (fake) + # ------------------------------------------------------ + gemm_out_shape = list(inputmat_total_shape[:-1]) + [out_features] + if args.output_quantizer is not None: + gemm_out = args.output_quantizer.make_empty( + gemm_out_shape, dtype=args.activation_dtype, device=args.inp.device + ) + else: + gemm_out = torch.empty( + gemm_out_shape, dtype=args.activation_dtype, device=args.inp.device + ) + + if with_input_all_gather_nccl: + inputmat_total = None + + # ------------------------------------------------------ + # Prepare output tensor (mirror the real comm path with shape-only ops) + # ------------------------------------------------------ + if args.ub_overlap_rs_fprop: + out = reduce_scatter_out + elif args.parallel_mode == "row" and args.tp_size > 1: + out = gemm_out + if args.sequence_parallel: + new_shape = list(out.shape) + new_shape[0] //= tp_world_size + if args.output_quantizer is not None: + out = args.output_quantizer.make_empty( + new_shape, dtype=out.dtype, device=out.device + ) + else: + out = torch.empty(new_shape, dtype=out.dtype, device=out.device) + # allreduce / symmetric_all_reduce do not change shape. + else: + out = gemm_out + + # Prepare backward state + tensors_to_save_from_forward = None + ctx_attrs = None + + if args.is_grad_enabled: + if save_original_input: + inputmat = args.inp + + if ( + backward_needs_input + and own_quantized_input + and isinstance(inputmat, QuantizedTensorStorage) + ): + if args.backward_override is not None: + inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + elif ( + args.backward_input_needs_gather + and weight_quantizer.supports_only_rowwise_all_gather() + ): + inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + else: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + + saved_inputmat = None + if backward_needs_input: + saved_inputmat = inputmat + + if args.fsdp_group is not None: + raise NotImplementedError( + "Fake Linear forward does not support manual TE FSDP " + "(fsdp_group is not None); use FSDP2 or MCore FSDP." + ) + fsdp_shapes = [] + + wt_save = weightmat + if args.is_fsdp2 and weightmat is not args.weight: + wt_save = None + + saved_tensor_aliases = ( + "inp" if saved_inputmat is args.inp else None, + "weight" if wt_save is args.weight else None, + "weight", + "bias" if bias is not None else None, + ) + tensors_to_save_from_forward = ( + None if saved_tensor_aliases[0] is not None else saved_inputmat, + None if saved_tensor_aliases[1] is not None else wt_save, + None, + None if saved_tensor_aliases[3] is not None else bias, + ) + + ctx_attrs = { + "fsdp_shapes": fsdp_shapes, + "saved_tensor_aliases": saved_tensor_aliases, + } + + return out, new_weight_workspace, tensors_to_save_from_forward, None, ctx_attrs + + def _linear_setup_ctx( bwd_args: LinearBwdArgs, fwd_args: LinearFwdArgs, @@ -1249,6 +1520,56 @@ def wgrad_gemm( ) +def _linear_backward_fake_impl( + args: LinearBwdArgs, +) -> Tuple[Union[torch.Tensor, None], ...]: + """Fake :func:`_linear_backward` for torch custom-op shape inference. + + Backward output shapes/dtypes are deterministic, so we just allocate + empty tensors of the right shape. ``grad_input_quantizer.set_usage`` + is preserved because it influences ``dgrad``'s ``make_empty``. + Manual TE FSDP is unsupported; FSDP2 / MCore FSDP go through the + standard path. + """ + + if args.fsdp_group is not None: + raise NotImplementedError( + "Fake Linear backward does not support manual TE FSDP " + "(fsdp_group is not None); use FSDP2 or MCore FSDP." + ) + + assert args.saved_weight is not None and args.grad_output is not None + out_features, in_features = args.saved_weight.shape + + if args.grad_input_quantizer is not None: + args.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + + def _empty(shape, quantizer): + if quantizer is not None: + return quantizer.make_empty( + shape, dtype=args.activation_dtype, device=args.grad_output.device + ) + return torch.empty( + shape, dtype=args.activation_dtype, device=args.grad_output.device + ) + + wgrad = None + if args.requires_wgrad and not args.fuse_wgrad_accumulation: + wgrad = _empty([out_features, in_features], args.grad_weight_quantizer) + + dgrad = None + if args.requires_dgrad: + dgrad = _empty(list(args.inp_shape), args.grad_input_quantizer) + + grad_bias = None + if args.use_bias and args.requires_wgrad: + grad_bias = torch.empty( + [out_features], dtype=args.activation_dtype, device=args.grad_output.device + ) + + return wgrad, dgrad, grad_bias + + class _Linear(torch.autograd.Function): """Linear semi-top level module Calls custom cuda extensions. diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 250daec67f..76d204deb0 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -502,6 +502,21 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: return tensor.to(dtype=dtype) +def fake_cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """Fake counterpart of :func:`cast_if_needed` for shape inference. + + Returns the same tensor if no cast would happen, otherwise an empty + tensor of the requested dtype with matching shape and device. Used by + torch custom-op fake registrations so the FX graph can reason about + output shapes without actually performing the cast. + """ + if tensor is None: + return None + if tensor.dtype == dtype: + return tensor + return torch.empty_like(tensor, dtype=dtype) + + def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: """Check if tensor dimensions are supported for FP8 TN GEMM""" return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0 From 1a9176cd38e8bd8f9e1a84b0fd8bf124b2317441 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 12 May 2026 16:14:03 +0200 Subject: [PATCH 02/16] [PyTorch] Add torch.compile dynamo helper module Drop in the `dynamo.py` helper from `gh_linear_torch_compile_support` @dbe4ca79 ("Add torch.compile support for Linear"), before the "Cache torch.compile unpack output" (8ab2425b) and "Hoist constant Linear setup out of opaque custom-op body" (9a98ff5b) optimizations. Exposes `ArgObject`, `OpaqueSimpleMetadata`, and `_te_register_custom_op`, which will be used to route the TE Linear forward/backward through a torch custom op so torch.compile can trace through it without entering the eager autograd.Function machinery. The module is self-contained: it imports only `torch`, `dataclasses`, `enum`, and `typing`, and does not yet have any callers in this branch. Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo.py | 1329 ++++++++++++++++++++++++++ 1 file changed, 1329 insertions(+) create mode 100644 transformer_engine/pytorch/dynamo.py diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py new file mode 100644 index 0000000000..e05b865da3 --- /dev/null +++ b/transformer_engine/pytorch/dynamo.py @@ -0,0 +1,1329 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""torch.compile (Dynamo) integration for TransformerEngine modules.""" +from __future__ import annotations + +import dataclasses +from enum import Enum +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + get_args, + get_origin, + get_type_hints, +) + +import torch + + +__all__ = [ + "ArgObject", + "OpaqueSimpleMetadata", + "_te_register_custom_op", +] + + +# Sentinel for ``None`` entries inside the op's flat ``Tensor[]`` return. +# Used by :func:`_te_register_custom_op` to support ``None`` outputs (e.g. +# an FP8 weight workspace returned only on the cache-miss path) on a +# non-nullable schema -- ``Tensor?[]`` returns are not picked up by +# ``torch.library.register_autograd``, so the registered backward never +# attaches a ``grad_fn`` to the op's outputs. +_NONE_SENTINEL_DTYPE = torch.uint8 + + +def _encode_none(t: Optional[torch.Tensor]) -> torch.Tensor: + """Replace ``None`` with a 0-element uint8 sentinel tensor.""" + if t is None: + return torch.empty(0, dtype=_NONE_SENTINEL_DTYPE) + return t + + +def _decode_none(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """Inverse of :func:`_encode_none`.""" + if t is None: + return None + if t.numel() == 0 and t.dtype == _NONE_SENTINEL_DTYPE: + return None + return t + + +# --------------------------------------------------------------------------- # +# OpaqueSimpleMetadata +# --------------------------------------------------------------------------- # + +class OpaqueSimpleMetadata: + """Opaque value-type bundle of simple Python values. + + Wraps a ``{name: value}`` dict so that many small non-Tensor arguments + of a TE custom op can be passed as a single op input. Registered as a + torch.compile *value* opaque type, meaning Dynamo specializes the + traced graph on the bundle's contents: ``__eq__`` installs a guard, + and any change to a wrapped value triggers a recompile. + + Allowed value types: primitives in :attr:`PRIMITIVE_TYPES`, + :class:`enum.Enum`, :class:`torch.Size`, plus arbitrarily nested + tuples/lists thereof. + """ + + # Primitive Python types we are willing to bundle into a single op + # input. The bundle is registered as a torch.compile *value* opaque + # type, so its contents must be hashable, comparable for equality, + # and round-trippable through ``__fx_repr__``. + PRIMITIVE_TYPES: Tuple[type, ...] = ( + type(None), + bool, + int, + float, + str, + torch.dtype, + torch.device, + ) + + @classmethod + def _is_opaque_value(cls, value: Any) -> bool: + """Whether ``value``'s class is registered as a value-opaque type.""" + try: + from torch._library.opaque_object import is_opaque_value_type + except Exception: # pragma: no cover - older torch + return False + return is_opaque_value_type(type(value)) + + @classmethod + def is_simple_value(cls, value: Any) -> bool: + """Whether ``value`` is allowed inside an instance. + + Accepts simple primitives (see :attr:`PRIMITIVE_TYPES`), + :class:`enum.Enum`, :class:`torch.Size`, instances of any class + registered as a torch.compile *value*-opaque type (the latter + already supplies ``__eq__`` / ``__hash__`` / ``__fx_repr__`` as + a registration prerequisite), and arbitrarily nested + tuples / lists thereof. + """ + if isinstance(value, cls.PRIMITIVE_TYPES): + return True + if isinstance(value, Enum): + return True + if isinstance(value, torch.Size): + return True + if cls._is_opaque_value(value): + return True + if isinstance(value, (list, tuple)): + return all(cls.is_simple_value(v) for v in value) + return False + + @classmethod + def _to_hashable(cls, value: Any) -> Any: + """Convert a simple value into something hashable (lists -> tuples).""" + if isinstance(value, (list, tuple, torch.Size)): + return tuple(cls._to_hashable(v) for v in value) + # Opaque-value instances already supply ``__hash__`` (required + # by registration) so they can stay as-is. + return value + + @classmethod + def _fmt_simple(cls, value: Any) -> str: + """Repr for a simple value, evaluable in a context with ``torch`` globals.""" + if isinstance(value, torch.dtype): + return f"__import__('torch').{str(value).split('.')[-1]}" + if isinstance(value, torch.device): + return f"__import__('torch').device({str(value)!r})" + if isinstance(value, torch.Size): + return f"__import__('torch').Size({list(value)!r})" + if isinstance(value, Enum): + return f"{type(value).__name__}.{value.name}" + if isinstance(value, list): + return "[" + ", ".join(cls._fmt_simple(v) for v in value) + "]" + if isinstance(value, tuple): + body = ", ".join(cls._fmt_simple(v) for v in value) + return f"({body},)" if len(value) == 1 else f"({body})" + if cls._is_opaque_value(value): + # Opaque-value types declare their FX reconstruction via + # ``__fx_repr__``; just splice their expression in here. + return value.__fx_repr__()[0] + return repr(value) + + def __init__( + self, + data: Optional[Dict[str, Any]] = None, + /, + **kwargs: Any, + ) -> None: + merged: Dict[str, Any] = dict(data) if data else {} + merged.update(kwargs) + cls = type(self) + for k, v in merged.items(): + if not cls.is_simple_value(v): + raise TypeError( + f"OpaqueSimpleMetadata field '{k}' has unsupported " + f"type {type(v).__name__}; only simple primitives " + f"({', '.join(t.__name__ for t in cls.PRIMITIVE_TYPES)}, " + f"Enum, torch.Size, registered torch.compile value-" + f"opaque types) and tuples/lists thereof are allowed." + ) + self._data: Dict[str, Any] = merged + self._frozen: Tuple[Tuple[str, Any], ...] = tuple( + (k, cls._to_hashable(v)) for k, v in sorted(merged.items()) + ) + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __getattr__(self, name: str) -> Any: + # Only called when normal attribute lookup fails, so ``_data`` / + # ``_frozen`` won't recurse here once set in ``__init__``. + try: + return self._data[name] + except KeyError as e: + raise AttributeError(name) from e + + def __contains__(self, key: str) -> bool: + return key in self._data + + def keys(self) -> List[str]: + return list(self._data.keys()) + + def values(self) -> List[Any]: + return list(self._data.values()) + + def items(self) -> List[Tuple[str, Any]]: + return list(self._data.items()) + + def get(self, key: str, default: Any = None) -> Any: + return self._data.get(key, default) + + def as_dict(self) -> Dict[str, Any]: + return dict(self._data) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OpaqueSimpleMetadata): + return NotImplemented + return self._frozen == other._frozen + + def __hash__(self) -> int: + return hash(self._frozen) + + def __fx_repr__(self) -> Tuple[str, Dict[str, Any]]: + cls = type(self) + items = ", ".join( + f"{k!r}: {cls._fmt_simple(v)}" for k, v in self._data.items() + ) + # Collect every type referenced by a nested opaque-value's + # ``__fx_repr__`` so the FX codegen can resolve those names. + globals_: Dict[str, Any] = { + "OpaqueSimpleMetadata": OpaqueSimpleMetadata, + } + + def _collect(value: Any) -> None: + if isinstance(value, (list, tuple)): + for v in value: + _collect(v) + return + # Skip plain Python / torch primitives up-front: they're + # rendered as literals by ``_fmt_simple`` and need no + # globals entry. + if isinstance(value, cls.PRIMITIVE_TYPES): + return + if isinstance(value, torch.Size): + return + if isinstance(value, Enum): + # ``_fmt_simple`` emits ``EnumName.MEMBER``; the Enum + # class must be in scope when the source string is + # later ``exec``d (e.g. by ``GraphModule.print_readable`` + # or by Inductor's runtime wrapper). + t = type(value) + globals_[t.__name__] = t + return + if cls._is_opaque_value(value): + _, extra = value.__fx_repr__() + globals_.update(extra) + + for v in self._data.values(): + _collect(v) + return (f"OpaqueSimpleMetadata({{{items}}})", globals_) + + def __repr__(self) -> str: + # ``__repr__`` is on hot diagnostic paths (Inductor error + # formatters, FX node printers, ...) and must never raise: + # treating any embedded value's ``repr`` failure as a soft + # placeholder keeps those error reporters from masking the + # actual root-cause exception with a crash inside our repr. + parts: List[str] = [] + for k, v in self._data.items(): + try: + v_repr = repr(v) + except Exception as e: # pylint: disable=broad-except + v_repr = f"<{type(v).__name__}: repr failed: {e!s}>" + parts.append(f"{k!r}: {v_repr}") + return f"OpaqueSimpleMetadata({{{', '.join(parts)}}})" + + +# Register OpaqueSimpleMetadata as a torch.compile value-opaque type, and +# resolve the schema name of ``torch.distributed.ProcessGroup`` (registered +# upstream as a *reference* opaque type via +# ``torch.distributed.device_mesh._register_distributed_opaque_types``). +# Both are done at module import so that any TE op declared via +# ``_te_register_custom_op`` can immediately reference them in its schema. +# Older PyTorch versions without these APIs are tolerated: the eager path +# keeps working, only torch.compile tracing of TE custom ops is unavailable. +try: + from torch._library.opaque_object import ( + get_opaque_type_name, + register_opaque_type, + ) + + register_opaque_type(OpaqueSimpleMetadata, typ="value") + _OPAQUE_SIMPLE_META_TYPE_NAME: Optional[str] = get_opaque_type_name( + OpaqueSimpleMetadata + ) + + _PROCESS_GROUP_TYPE_NAME: Optional[str] = None + try: + from torch.distributed import ProcessGroup + from torch.distributed.device_mesh import ( + _register_distributed_opaque_types, + ) + + _register_distributed_opaque_types() + _PROCESS_GROUP_TYPE_NAME = get_opaque_type_name(ProcessGroup) + except Exception: # pragma: no cover - distributed not built / disabled + _PROCESS_GROUP_TYPE_NAME = None +except Exception: # pragma: no cover - older torch without opaque_object + _OPAQUE_SIMPLE_META_TYPE_NAME = None + _PROCESS_GROUP_TYPE_NAME = None + + +# --------------------------------------------------------------------------- # +# Field buckets +# --------------------------------------------------------------------------- # + +# Each dataclass field of an :class:`ArgObject` is mapped to exactly one +# bucket. A bucket owns the full per-field "vocabulary" -- which schema +# slots it emits, how its packed value(s) are produced from the dataclass +# instance, and how the unpacked value is re-injected into the +# reconstructed instance. ``ArgObject`` then becomes three trivial loops +# over a list of buckets, instead of three parallel branch ladders. +# +# Five bucket kinds are used: +# +# * :class:`_TensorBucket` -- :class:`torch.Tensor` / +# :class:`Optional[torch.Tensor] ` -> one ``Tensor`` / +# ``Tensor?`` slot. +# * :class:`_TensorListBucket` -- ``List[torch.Tensor]`` / +# ``Tuple[torch.Tensor, ...]`` -> one ``Tensor[]`` slot. Used for +# variable-length tensor sequences such as ``ctx.saved_tensors``. +# * :class:`_ProcessGroupBucket` -- :class:`torch.distributed.ProcessGroup` +# (already registered upstream as a value-opaque type) -> one direct +# slot. +# * :class:`_FlattenableBucket` -- a field whose type implements the +# ``_flatten`` / ``_unflatten`` protocol (today: any +# :class:`Quantizer` or :class:`Recipe` subclass) -> three slots +# ``__fmeta`` / ``__fpg`` / ``__ftensors``. Bases +# are discovered via :func:`_flattenable_bases`, lazily imported to +# avoid an import cycle. +# * :class:`_SimpleBundleBucket` -- aggregator over **all** simple-typed +# fields of the dataclass; emits a single ``_simple_meta`` slot +# carrying an :class:`OpaqueSimpleMetadata` bundle. +# * :class:`_UnknownBucket` -- a field whose annotation matches none of +# the above. Emits no schema slot; pack raises if the field holds a +# non-``None`` value, unpack restores it as ``None``. + + +def _strip_optional(annot: Any) -> Tuple[Any, bool]: + """If ``annot`` is ``Optional[X]`` return ``(X, True)``; else ``(annot, False)``. + + Shared by all bucket matchers below. + """ + if get_origin(annot) is Union: + args = get_args(annot) + if type(None) in args: + non_none = [a for a in args if a is not type(None)] + if len(non_none) == 1: + return non_none[0], True + return annot, False + + +class _Bucket: + """Per-field handler for translating between a dataclass field and the + flat ``{slot_name: slot_value}`` view consumed by ``torch.library``. + + Each concrete bucket owns: + + * a :meth:`try_build` classmethod that decides whether a given + ``(name, annotation)`` pair belongs to this bucket (returning an + instance, or ``None`` to defer to the next bucket); + * the runtime :meth:`schema_slots` / :meth:`pack` / :meth:`unpack` + logic for that field. + + :class:`_SimpleBundleBucket` is the exception: it aggregates many + simple-typed fields into a single op input, so it does not implement + ``try_build``. It exposes :meth:`matches_field` for the per-field + membership test, and is constructed once at the end of dispatch + with the collected names. + """ + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_Bucket"]: + """Return an instance handling ``(name, annot)``, or ``None``.""" + raise NotImplementedError + + def schema_slots(self) -> List[Tuple[str, str]]: + """Return ``[(slot_name, schema_type_str), ...]`` for this field.""" + raise NotImplementedError + + def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + """Return ``[(slot_name, value), ...]`` extracted from ``owner``.""" + raise NotImplementedError + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + """Read this field's slots from ``args`` and write the + reconstructed dataclass attribute(s) into ``kwargs``.""" + raise NotImplementedError + + +class _TensorOrStorageBucket(_Bucket): + """``Tensor | QuantizedTensorStorage`` -> meta / pg / Tensor[] slots. + + Plain tensors are carried as a single-element ``Tensor[]``. Quantized + tensor wrappers and storage shells are carried through their + ``_torch_compile_flatten`` protocol so the backward op receives the same + structured object type that eager restoration produced. + """ + + SUFFIX_META = "__tsmeta" + SUFFIX_PG = "__tspg" + SUFFIX_TENSORS = "__tstensors" + + KIND_KEY = "_te_tensor_storage_kind" + KIND_NONE = "none" + KIND_TENSOR = "tensor" + + def __init__(self, name: str) -> None: + if _OPAQUE_SIMPLE_META_TYPE_NAME is None or _PROCESS_GROUP_TYPE_NAME is None: + raise RuntimeError( + f"Tensor/storage field {name!r} requires both " + "OpaqueSimpleMetadata and torch.distributed.ProcessGroup " + "to be registered as torch._library opaque types; one or " + "both are unavailable in this PyTorch build." + ) + self.name = name + + @staticmethod + def _is_tensor_storage_union(annot: Any) -> bool: + origin = get_origin(annot) + if origin is not Union: + return False + members = [a for a in get_args(annot) if a is not type(None)] + if torch.Tensor not in members: + return False + try: + from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage + except Exception: # pragma: no cover - partial init + return False + return any( + isinstance(member, type) and issubclass(member, QuantizedTensorStorage) + for member in members + ) + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_TensorOrStorageBucket"]: + if cls._is_tensor_storage_union(annot): + return cls(name) + return None + + def _slot_meta(self) -> str: + return self.name + self.SUFFIX_META + + def _slot_pg(self) -> str: + return self.name + self.SUFFIX_PG + + def _slot_tensors(self) -> str: + return self.name + self.SUFFIX_TENSORS + + def schema_slots(self) -> List[Tuple[str, str]]: + return [ + (self._slot_meta(), _OPAQUE_SIMPLE_META_TYPE_NAME), + (self._slot_pg(), _PROCESS_GROUP_TYPE_NAME + "?"), + (self._slot_tensors(), "Tensor[]"), + ] + + def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + value = getattr(owner, self.name) + if value is None: + meta = OpaqueSimpleMetadata({self.KIND_KEY: self.KIND_NONE}) + pg: Any = None + tensors: List[torch.Tensor] = [] + else: + from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage + + if isinstance(value, QuantizedTensorStorage): + meta, pg, tensors = value._torch_compile_flatten() + elif isinstance(value, torch.Tensor): + meta = OpaqueSimpleMetadata({self.KIND_KEY: self.KIND_TENSOR}) + pg = None + tensors = [value] + else: + raise TypeError( + f"{type(owner).__name__} field {self.name!r} expected " + "None, torch.Tensor, or QuantizedTensorStorage, got " + f"{type(value).__name__}" + ) + return [ + (self._slot_meta(), meta), + (self._slot_pg(), pg), + (self._slot_tensors(), list(tensors)), + ] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + meta = args[self._slot_meta()] + kind = meta.get(self.KIND_KEY) + if kind == self.KIND_NONE: + kwargs[self.name] = None + return + tensors = args[self._slot_tensors()] + if kind == self.KIND_TENSOR: + kwargs[self.name] = tensors[0] + return + + from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage + + kwargs[self.name] = QuantizedTensorStorage._torch_compile_unflatten( + meta, + args[self._slot_pg()], + tensors, + ) + + +class _TensorBucket(_Bucket): + """``Tensor`` / ``Optional[Tensor]`` -> single ``Tensor`` / ``Tensor?`` slot.""" + + def __init__(self, name: str, is_optional: bool) -> None: + self.name = name + self.type_str = "Tensor?" if is_optional else "Tensor" + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_TensorBucket"]: + stripped, is_optional = _strip_optional(annot) + if stripped is torch.Tensor: + return cls(name, is_optional) + return None + + def schema_slots(self) -> List[Tuple[str, str]]: + return [(self.name, self.type_str)] + + def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + return [(self.name, getattr(owner, self.name))] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + kwargs[self.name] = args[self.name] + + +class _TensorListBucket(_Bucket): + """``List[Tensor]`` / ``Tuple[Tensor, ...]`` -> single ``Tensor[]`` slot. + + Used for fields like ``LinearBwdArgs.saved_tensors`` that carry an + arbitrary-length sequence of tensors (typically the + ``ctx.saved_tensors`` payload restored before invoking the backward + op). The slot itself is non-nullable, but individual ``None`` + elements are smuggled through using :func:`_encode_none` / + :func:`_decode_none` sentinels (matching what the forward op return + list already does). An empty sequence is valid. + """ + + def __init__(self, name: str, container: type) -> None: + self.name = name + # Remember the original container type so unpack returns the + # exact same Python type the dataclass annotation declared. + self.container = container + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_TensorListBucket"]: + stripped, _ = _strip_optional(annot) + origin = get_origin(stripped) + if origin is None: + return None + args = get_args(stripped) + if not args: + return None + # ``Tuple[Tensor, ...]`` -> args = (Tensor, Ellipsis); other forms + # like ``Tuple[Tensor, Tensor]`` or ``List[Tensor]`` only have + # type entries. + if origin is tuple: + if len(args) == 2 and args[1] is Ellipsis: + elem = args[0] + else: + elem = args[0] if all(a is args[0] for a in args) else None + elif origin is list: + elem = args[0] + else: + return None + if elem is not torch.Tensor: + return None + return cls(name, list if origin is list else tuple) + + def schema_slots(self) -> List[Tuple[str, str]]: + return [(self.name, "Tensor[]")] + + def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + value = getattr(owner, self.name) or () + return [(self.name, [_encode_none(t) for t in value])] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + kwargs[self.name] = self.container(_decode_none(t) for t in args[self.name]) + + +class _ProcessGroupBucket(_Bucket): + """``ProcessGroup`` / ``Optional[ProcessGroup]`` -> one direct opaque-ref slot. + + PG is registered upstream (in ``torch.distributed.device_mesh``) as + a value-opaque type, so torch.library carries it without help. + """ + + def __init__(self, name: str, is_optional: bool) -> None: + if _PROCESS_GROUP_TYPE_NAME is None: + raise RuntimeError( + f"ProcessGroup field {name!r} requires torch.distributed " + "and the opaque-type registration in " + "torch.distributed.device_mesh; neither is available in " + "this PyTorch build." + ) + self.name = name + self.type_str = _PROCESS_GROUP_TYPE_NAME + ("?" if is_optional else "") + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_ProcessGroupBucket"]: + stripped, is_optional = _strip_optional(annot) + if not isinstance(stripped, type): + return None + try: + from torch.distributed import ProcessGroup + except Exception: # pragma: no cover - distributed not built + return None + if not issubclass(stripped, ProcessGroup): + return None + return cls(name, is_optional) + + def schema_slots(self) -> List[Tuple[str, str]]: + return [(self.name, self.type_str)] + + def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + return [(self.name, getattr(owner, self.name))] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + kwargs[self.name] = args[self.name] + + +def _flattenable_bases() -> Tuple[type, ...]: + """Return the list of base classes whose subclasses are routed + through :class:`_FlattenableBucket`. + + A "flattenable" type implements the duck-typed pair + + * instance method ``_flatten() -> (OpaqueSimpleMetadata, ref, list[Tensor])`` + * classmethod ``_unflatten(meta, ref, tensors)`` (dispatches by an + identifier stamped into ``meta``) + + Lazy import keeps ``dynamo`` importable before the modules that + define these bases (avoid import cycles). + """ + bases: List[type] = [] + try: + from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer + + bases.append(Quantizer) + bases.append(QuantizedTensorStorage) + except Exception: # pragma: no cover - partial init + pass + try: + from transformer_engine.common.recipe import Recipe + + bases.append(Recipe) + except Exception: # pragma: no cover - partial init + pass + return tuple(bases) + + +class _FlattenableBucket(_Bucket): + """Three-slot expansion (``meta`` / ``ref`` / ``tensors``) for any + field whose type implements the ``_flatten`` / ``_unflatten`` + protocol (see :func:`_flattenable_bases`). Used today for + :class:`~transformer_engine.pytorch.quantized_tensor.Quantizer` and + :class:`~transformer_engine.common.recipe.Recipe`. + """ + + SUFFIX_META = "__fmeta" + SUFFIX_PG = "__fpg" + SUFFIX_TENSORS = "__ftensors" + + # Stored under ``_qcls`` in the metadata bundle to encode ``None`` + # without making any of the three slots nullable. + NONE_MARKER_KEY = "_qcls" + NONE_MARKER_VAL = "" + + def __init__(self, name: str, base_cls: type) -> None: + if _OPAQUE_SIMPLE_META_TYPE_NAME is None or _PROCESS_GROUP_TYPE_NAME is None: + raise RuntimeError( + f"Flattenable field {name!r} requires both " + "OpaqueSimpleMetadata and torch.distributed.ProcessGroup " + "to be registered as torch._library opaque types; one or " + "both are unavailable in this PyTorch build." + ) + self.name = name + self.base_cls = base_cls + + @classmethod + def try_build(cls, name: str, annot: Any) -> Optional["_FlattenableBucket"]: + stripped, _ = _strip_optional(annot) + if not isinstance(stripped, type): + return None + for base in _flattenable_bases(): + if issubclass(stripped, base): + return cls(name, base) + return None + + def _slot_meta(self) -> str: + return self.name + self.SUFFIX_META + + def _slot_pg(self) -> str: + return self.name + self.SUFFIX_PG + + def _slot_tensors(self) -> str: + return self.name + self.SUFFIX_TENSORS + + def schema_slots(self) -> List[Tuple[str, str]]: + return [ + (self._slot_meta(), _OPAQUE_SIMPLE_META_TYPE_NAME), + (self._slot_pg(), _PROCESS_GROUP_TYPE_NAME + "?"), + (self._slot_tensors(), "Tensor[]"), + ] + + def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + value = getattr(owner, self.name) + if value is None: + meta = OpaqueSimpleMetadata({self.NONE_MARKER_KEY: self.NONE_MARKER_VAL}) + pg: Any = None + tensors: List[torch.Tensor] = [] + else: + if hasattr(value, "_flatten"): + meta, pg, tensors = value._flatten() + else: + meta, pg, tensors = value._torch_compile_flatten() + return [ + (self._slot_meta(), meta), + (self._slot_pg(), pg), + (self._slot_tensors(), list(tensors)), + ] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + meta = args[self._slot_meta()] + if meta.get(self.NONE_MARKER_KEY) == self.NONE_MARKER_VAL: + kwargs[self.name] = None + return + if hasattr(self.base_cls, "_unflatten"): + kwargs[self.name] = self.base_cls._unflatten( + meta, args[self._slot_pg()], args[self._slot_tensors()] + ) + else: + kwargs[self.name] = self.base_cls._torch_compile_unflatten( + meta, args[self._slot_pg()], args[self._slot_tensors()] + ) + + +class _SimpleBundleBucket(_Bucket): + """Aggregator: bundles every simple-typed field of the dataclass + into a single :class:`OpaqueSimpleMetadata` slot. + + Does not implement :meth:`try_build` because membership is decided + per-field via :meth:`matches_field`, with construction deferred + until all simple field names are collected. + """ + + SLOT = "_simple_meta" + + def __init__(self, names: List[str]) -> None: + if _OPAQUE_SIMPLE_META_TYPE_NAME is None: + raise RuntimeError( + "OpaqueSimpleMetadata could not be registered with " + "torch._library.opaque_object; cannot bundle simple fields " + f"{names}. Upgrade PyTorch or drop the simple fields." + ) + self.names = list(names) + + @classmethod + def matches_field(cls, annot: Any) -> bool: + """Whether ``annot`` (Optional-aware, recursive on tuple/list) is + bundled-simple, i.e. eligible for this aggregator. + + Accepts simple primitives, :class:`enum.Enum`, :class:`torch.Size`, + any class registered as a torch.compile *value*-opaque type, and + nested tuples / lists thereof. + """ + annot, _ = _strip_optional(annot) + if annot in OpaqueSimpleMetadata.PRIMITIVE_TYPES: + return True + if isinstance(annot, type) and issubclass(annot, Enum): + return True + if annot is torch.Size: + return True + # Any registered value-opaque class is hashable / FX-reproducible + # and therefore safe to embed in the OpaqueSimpleMetadata bundle. + if isinstance(annot, type): + try: + from torch._library.opaque_object import is_opaque_value_type + except Exception: # pragma: no cover - older torch + is_opaque_value_type = None + if is_opaque_value_type is not None and is_opaque_value_type(annot): + return True + origin = get_origin(annot) + if origin in (tuple, list): + # Inner args may contain Ellipsis (e.g. ``Tuple[int, ...]``); + # only require the *concrete* inner annotations to be simple. + inner = [a for a in get_args(annot) if a is not Ellipsis] + return bool(inner) and all(cls.matches_field(a) for a in inner) + return False + + def schema_slots(self) -> List[Tuple[str, str]]: + return [(self.SLOT, _OPAQUE_SIMPLE_META_TYPE_NAME)] + + def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + bundle = OpaqueSimpleMetadata({n: getattr(owner, n) for n in self.names}) + return [(self.SLOT, bundle)] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + if self.SLOT not in args: + return + meta = args[self.SLOT] + for n in self.names: + kwargs[n] = meta[n] + + +class _UnknownBucket(_Bucket): + """Fallback for fields whose annotation no other bucket claims. + Emits no schema slot; pack rejects non-trivial values to avoid silent + data loss; unpack restores the field as ``None``. + + A "trivial" value is one that carries no observable information -- + ``None`` itself or a sequence of all-``None`` entries (e.g. a + ``tensor_objects`` payload from :func:`prepare_for_saving` over a + bag of plain bf16 tensors). Such values are dropped on the way into + the op and reconstructed from companion fields (``saved_tensors``, + quantizer metadata, ...) on the way out. + + Constructed directly by :meth:`ArgObject._buckets` (it has no + annotation-based ``try_build`` -- it's the explicit "no match" case). + """ + + def __init__(self, name: str, owner_cls_name: str) -> None: + self.name = name + self.owner_cls_name = owner_cls_name + + @staticmethod + def _is_trivial(value: Any) -> bool: + if value is None: + return True + if isinstance(value, (list, tuple)): + return all(v is None for v in value) + return False + + def schema_slots(self) -> List[Tuple[str, str]]: + return [] + + def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + value = getattr(owner, self.name, None) + if not self._is_trivial(value): + raise TypeError( + f"{self.owner_cls_name} field {self.name!r} has a type not " + "supported by torch.compile (not Tensor, simple, " + "ProcessGroup, or Quantizer) and carries " + "a non-trivial value; override " + f"{self.owner_cls_name}.torch_compile_pack to handle it." + ) + return [] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + kwargs[self.name] = None + + +# Buckets, in priority order, that own ``try_build`` for a single field. +_FIELD_BUCKETS: Tuple[type, ...] = ( + _TensorOrStorageBucket, + _TensorBucket, + _TensorListBucket, + _ProcessGroupBucket, + _FlattenableBucket, +) + + +# --------------------------------------------------------------------------- # +# ArgObject +# --------------------------------------------------------------------------- # + + +class ArgObject: + """Base class for structured argument containers passed to TE custom ops. + + Subclassed by per-module forward / backward dataclasses + (e.g. ``LinearFwdArgs``, ``LinearBwdArgs``). Provides the pack / + unpack / schema hooks consumed by :func:`_te_register_custom_op` + when wiring the dataclass into a ``torch.library`` schema. + + The default pack / unpack / schema implementations dispatch on + dataclass field annotations. Each field is mapped to exactly one + :class:`_Bucket` (see module-level docstring); the three methods + then become trivial iterations over the bucket list. + """ + + @classmethod + def _resolved_field_annotations(cls) -> List[Tuple[str, Any]]: + if not dataclasses.is_dataclass(cls): + raise TypeError( + f"{cls.__name__} must be a @dataclass to use the default " + f"ArgObject torch_compile_* implementations." + ) + # ``get_type_hints`` resolves forward references and PEP 563 + # ``from __future__ import annotations`` strings. + try: + hints = get_type_hints(cls) + except Exception: + hints = {} + return [ + (f.name, hints.get(f.name, f.type)) for f in dataclasses.fields(cls) + ] + + @classmethod + def _buckets(cls) -> List[_Bucket]: + """Build the bucket list for this dataclass from field annotations. + + Dispatch order per field: try each bucket in :data:`_FIELD_BUCKETS` + (Tensor, ProcessGroup, Quantizer); if none claims the field, route + it to :class:`_SimpleBundleBucket` if its annotation is bundle-able, + else to :class:`_UnknownBucket`. + + Intentionally **not** cached. Caching on ``cls`` (e.g. by writing + ``cls.__te_buckets__``) tickles Dynamo: subsequent reads of + ``cls.__dict__`` from a compiled function trigger + "mappingproxy affected by dictionary mutation" graph breaks. + Hot paths must instead capture the bucket list once at op + registration time and pass it explicitly to :meth:`torch_compile_pack` + / :meth:`torch_compile_unpack`. + """ + buckets: List[_Bucket] = [] + simple_names: List[str] = [] + for name, annot in cls._resolved_field_annotations(): + built: Optional[_Bucket] = None + for bucket_cls in _FIELD_BUCKETS: + built = bucket_cls.try_build(name, annot) + if built is not None: + break + if built is not None: + buckets.append(built) + elif _SimpleBundleBucket.matches_field(annot): + simple_names.append(name) + else: + buckets.append(_UnknownBucket(name, cls.__name__)) + if simple_names: + buckets.append(_SimpleBundleBucket(simple_names)) + return buckets + + @classmethod + def torch_compile_get_schema(cls) -> List[Tuple[str, str]]: + """Default: derive the schema from dataclass annotations. + + See :class:`_Bucket` subclasses for the per-field-kind layout + (Tensor, ProcessGroup, Quantizer, and the + aggregated ``_simple_meta`` bundle of simple fields). + """ + return [slot for b in cls._buckets() for slot in b.schema_slots()] + + def torch_compile_pack( + self, buckets: Optional[List[_Bucket]] = None + ) -> Dict[str, Any]: + """Default: ask each bucket to extract its slot(s) from ``self``. + + ``buckets`` is the precomputed bucket list (from + :meth:`_buckets`). Hot paths -- e.g. the closures created by + :func:`_te_register_custom_op` -- must pass it to avoid recomputing + and, critically, to keep Dynamo away from ``cls.__dict__`` while + tracing. When ``None``, this method recomputes the buckets + (eager-only fallback intended for ad-hoc / test usage). + """ + if buckets is None: + buckets = type(self)._buckets() + out: Dict[str, Any] = {} + for bucket in buckets: + for name, value in bucket.pack(self): + out[name] = value + return out + + @classmethod + def torch_compile_unpack( + cls, + args: Dict[str, Any], + buckets: Optional[List[_Bucket]] = None, + ) -> "ArgObject": + """Default: ask each bucket to inject its field(s) into a fresh + instance built via ``__new__`` (we bypass the dataclass + ``__init__`` so unknown-typed fields can stay as ``None`` even + when they have no default). + + ``buckets`` semantics match :meth:`torch_compile_pack`: hot paths + pass the precomputed list, eager-only callers may omit it. + """ + if buckets is None: + buckets = cls._buckets() + kwargs: Dict[str, Any] = {} + for bucket in buckets: + bucket.unpack(args, kwargs) + obj = cls.__new__(cls) + for k, v in kwargs.items(): + object.__setattr__(obj, k, v) + return obj + + @classmethod + def torch_compile_get_input_tensors_for_grad(cls) -> List[str]: + """Names of forward inputs (from :meth:`torch_compile_get_schema`) + for which the corresponding ``backward_impl`` produces gradients, + in the exact order ``backward_impl`` returns them. + + Only meaningful on the forward arg type. Default is ``[]`` (no + gradients, e.g. for inference-only ops). The wrapper uses this + to pad the autograd return tuple with ``None`` for every input + not listed here, so torch sees one slot per forward input as + required by ``register_autograd``. + """ + return [] + + +def _te_register_custom_op( + *, + linear_impl: Callable[[Any], Any], + linear_arg_type: type, + setup_context: Callable[..., None], + backward_impl: Callable[[Any], Any], + backward_obj: type, + backward_arg_type: type, + num_outputs: int, + linear_fake_impl: Optional[Callable[[Any], Any]] = None, + backward_fake_impl: Optional[Callable[[Any], Any]] = None, + op_namespace: str = "transformer_engine", + op_name: str = "linear", +) -> Callable[..., Any]: + """Register a TE module's forward + backward as a single torch custom op. + + Parameters + ---------- + linear_impl + Eager forward implementation. Receives a single argument of type + ``linear_arg_type`` and must return a tuple of the form + ``(*output_tensors, tensors_to_save, tensor_objects, ctx_attrs)`` + where: + + * ``output_tensors`` -- one or more :class:`torch.Tensor` outputs + returned to the caller. + * ``tensors_to_save`` -- flat list of :class:`torch.Tensor` to be + stashed via ``ctx.save_for_backward``. + * ``tensor_objects`` -- the metadata object produced by + :func:`prepare_for_saving`, paired with ``tensors_to_save`` to + let the backward reconstruct quantized / structured tensors. + * ``ctx_attrs`` -- non-tensor state to attach to the autograd + context, restricted to values that cannot be derived from the + forward args inside ``setup_context``. + linear_arg_type + Dataclass type aggregating all forward inputs (e.g. + :class:`LinearFwdArgs`). Used to (re)build the structured argument + from the flat tensor / non-tensor inputs accepted by the custom op. + setup_context + Eager autograd ``setup_context`` analogue. Receives a freshly + constructed ``backward_obj`` instance, the forward args, the + forward output, and ``ctx_attrs`` produced by ``linear_impl``; + is responsible for populating the backward-state object so that + ``backward_impl`` can later consume it. + backward_impl + Eager backward implementation. Receives a single argument of type + ``backward_arg_type`` and returns the gradient tuple. + backward_obj + Dataclass / class used to instantiate a fresh backward-state + container at the end of the forward pass (typically the same as + ``backward_arg_type``). + backward_arg_type + Type accepted by ``backward_impl``. May differ from ``backward_obj`` + if the backward op needs a wrapped / opaque view of the state. + num_outputs + Number of user-facing tensor outputs returned by ``linear_impl``. + The op concatenates ``[*output_tensors, *tensors_to_save]`` into + a single ``Tensor[]`` return; the wrapper uses ``num_outputs`` to + split the two halves on the way back out. + + The list of forward inputs that receive gradients is declared on + the forward arg type itself, via + :meth:`ArgObject.torch_compile_get_input_tensors_for_grad`. + ``backward_impl`` must return its gradients in that exact order. + linear_fake_impl + Optional fake (shape inference) counterpart of ``linear_impl``, + registered via ``torch.library.register_fake``. Returns the same + tuple shape as ``linear_impl`` -- ``(*output_tensors, + tensors_to_save, tensor_objects, ctx_attrs)`` -- but every + ``torch.Tensor`` is a fake tensor (allocated via + ``quantizer.make_empty`` or ``torch.empty``) carrying only the + correct shape / dtype / device, with no real storage or + computation. ``tensor_objects`` and ``ctx_attrs`` must be + structurally identical to those produced by ``linear_impl`` so + that ``setup_context`` and ``backward_impl`` see the same + non-tensor state in eager and traced modes. + backward_fake_impl + Optional fake counterpart of ``backward_impl``. Returns the same + gradient tuple as ``backward_impl``, with fake tensors in place + of the real gradients. + op_namespace, op_name + Library namespace / op name used when registering with + ``torch.library``. + + Returns + ------- + Callable + A function ``forward_fn(linear_arg_type_instance)`` that dispatches + through the registered custom op, returning the user-facing + outputs (single tensor if ``num_outputs == 1``, otherwise a + tuple). Use under ``torch.compiler.is_compiling()`` as a drop-in + for ``Function.apply``. + """ + + fwd_qualname = f"{op_namespace}::{op_name}" + bwd_op_name = f"{op_name}_backward" + bwd_qualname = f"{op_namespace}::{bwd_op_name}" + + # Precompute the bucket list for both arg types and capture them in + # the closures below. Critical for the compiled path: re-deriving + # buckets at call time would force ``ArgObject._buckets`` to read + # ``cls.__dict__`` from inside a Dynamo-traced function, which + # triggers a "mappingproxy affected by dictionary mutation" graph + # break under ``fullgraph=True``. + fwd_buckets: List[_Bucket] = linear_arg_type._buckets() + bwd_buckets: List[_Bucket] = backward_arg_type._buckets() + + def _build_schema(buckets: List[_Bucket]) -> Tuple[str, List[str]]: + spec = [slot for b in buckets for slot in b.schema_slots()] + names = [name for name, _ in spec] + schema_str = "(" + ", ".join(f"{type_str} {name}" for name, type_str in spec) + ")" + return schema_str, names + + fwd_schema_args, fwd_arg_names = _build_schema(fwd_buckets) + bwd_schema_args, bwd_arg_names = _build_schema(bwd_buckets) + + # ``torch.library.register_autograd`` requires the backward to return + # one grad slot per forward input, with the same Python tree + # structure as the input itself: a ``Tensor[]`` slot must get back a + # ``list``, not a bare ``None``. Precompute the per-slot "no-grad" + # value so the autograd return matches. + fwd_slot_defaults: List[Any] = [] + for bucket in fwd_buckets: + for _, type_str in bucket.schema_slots(): + fwd_slot_defaults.append([] if type_str.endswith("[]") else None) + + # Validate ``input_tensors_for_grad`` references real forward inputs + # and precompute the positions where backward grads land in the + # autograd return tuple. Some logical fields (e.g. Tensor-or-storage + # fields) expand to a ``Tensor[]`` slot; their gradient must be returned + # as a list matching that input tree. + input_tensors_for_grad = linear_arg_type.torch_compile_get_input_tensors_for_grad() + fwd_grad_targets: Dict[str, Tuple[int, bool]] = {} + slot_offset = 0 + for bucket in fwd_buckets: + slots = bucket.schema_slots() + if isinstance(bucket, _TensorBucket): + fwd_grad_targets[bucket.name] = (slot_offset, False) + elif isinstance(bucket, _TensorListBucket): + fwd_grad_targets[bucket.name] = (slot_offset, True) + elif isinstance(bucket, _TensorOrStorageBucket): + for i, (slot_name, _) in enumerate(slots): + if slot_name == bucket._slot_tensors(): + fwd_grad_targets[bucket.name] = (slot_offset + i, True) + break + slot_offset += len(slots) + unknown_grad_names = [n for n in input_tensors_for_grad if n not in fwd_grad_targets] + if unknown_grad_names: + raise ValueError( + f"{linear_arg_type.__name__}.torch_compile_get_input_tensors_for_grad() " + f"contains names not present in " + f"{linear_arg_type.__name__}.torch_compile_get_schema(): " + f"{unknown_grad_names}" + ) + grad_targets = [fwd_grad_targets[n] for n in input_tensors_for_grad] + num_grad_inputs = len(input_tensors_for_grad) + + lib = torch.library.Library(op_namespace, "FRAGMENT") + # Forward op concatenates user outputs and tensors_to_save into a + # single ``Tensor[]`` return so that autograd's ``setup_context`` can + # stash the saved-for-backward tensors without re-running the eager + # impl. The schema is non-nullable (``Tensor[]``, not ``Tensor?[]``) + # because ``torch.library.register_autograd`` does not propagate + # ``grad_fn`` to a nullable list output. ``None`` entries on either + # side are smuggled through via :func:`_encode_none` / + # :func:`_decode_none` sentinels (see below). + lib.define(f"{op_name}{fwd_schema_args} -> Tensor[]") + lib.define(f"{bwd_op_name}{bwd_schema_args} -> Tensor[]") + + def _outputs_for_setup(outputs: List[torch.Tensor]) -> Any: + return outputs[0] if num_outputs == 1 else tuple(outputs) + + def _prepare_for_saving(tensors: Any) -> Tuple[List[Optional[torch.Tensor]], Any]: + from transformer_engine.pytorch.quantized_tensor import prepare_for_saving + + return prepare_for_saving(*(tensors or ())) + + def _restore_from_saved(tensor_objects: Any, saved_tensors: List[Any]) -> Any: + from transformer_engine.pytorch.quantized_tensor import restore_from_saved + + return restore_from_saved(tensor_objects, saved_tensors) + + def _fwd_impl(*flat: Any) -> List[torch.Tensor]: + kwargs = dict(zip(fwd_arg_names, flat)) + obj = linear_arg_type.torch_compile_unpack(kwargs, fwd_buckets) + result = linear_impl(obj) + outputs = list(result[:num_outputs]) + tensors_to_save, _ = _prepare_for_saving(result[num_outputs]) + return [_encode_none(t) for t in outputs + tensors_to_save] + + lib.impl(op_name, _fwd_impl, "CompositeExplicitAutograd") + + if linear_fake_impl is not None: + + def _fwd_fake(*flat: Any) -> List[torch.Tensor]: + kwargs = dict(zip(fwd_arg_names, flat)) + obj = linear_arg_type.torch_compile_unpack(kwargs, fwd_buckets) + result = linear_fake_impl(obj) + outputs = list(result[:num_outputs]) + tensors_to_save, _ = _prepare_for_saving(result[num_outputs]) + return [_encode_none(t) for t in outputs + tensors_to_save] + + torch.library.register_fake(fwd_qualname, _fwd_fake, lib=lib) + + def _check_bwd_len(grads): + if len(grads) != num_grad_inputs: + raise RuntimeError( + f"{op_namespace}::{bwd_op_name} expected backward_impl to " + f"return {num_grad_inputs} grads (one per " + f"input_tensors_for_grad entry), got {len(grads)}" + ) + + def _bwd_impl(*flat: Any) -> List[torch.Tensor]: + kwargs = dict(zip(bwd_arg_names, flat)) + obj = backward_arg_type.torch_compile_unpack(kwargs, bwd_buckets) + grads = list(backward_impl(obj)) + _check_bwd_len(grads) + return [_encode_none(g) for g in grads] + + lib.impl(bwd_op_name, _bwd_impl, "CompositeExplicitAutograd") + + if backward_fake_impl is not None: + + def _bwd_fake(*flat: Any) -> List[torch.Tensor]: + kwargs = dict(zip(bwd_arg_names, flat)) + obj = backward_arg_type.torch_compile_unpack(kwargs, bwd_buckets) + grads = list(backward_fake_impl(obj)) + _check_bwd_len(grads) + return [_encode_none(g) for g in grads] + + torch.library.register_fake(bwd_qualname, _bwd_fake, lib=lib) + + # Re-run fake (or real) impl in setup_context to recover + # tensor_objects / ctx_attrs, which are not part of the op's return. + fake_for_setup = linear_fake_impl if linear_fake_impl is not None else linear_impl + + def _setup_context(ctx, inputs, output): + ctx._te_fwd_tensor_list_lengths = { + i: len(value) for i, value in enumerate(inputs) if isinstance(value, list) + } + kwargs = dict(zip(fwd_arg_names, inputs)) + fwd_obj = linear_arg_type.torch_compile_unpack(kwargs, fwd_buckets) + fake_result = fake_for_setup(fwd_obj) + _, tensor_objects = _prepare_for_saving(fake_result[num_outputs]) + ctx_attrs = fake_result[num_outputs + 2] + + # Split op output: first num_outputs are user-facing tensors, + # the rest are tensors_to_save. ``output`` is a flat ``Tensor[]`` + # with our None-sentinels in place; decode here so downstream + # eager code sees the original ``None``\ s. + user_outputs = [_decode_none(t) for t in output[:num_outputs]] + op_saved_tensors = [_decode_none(t) for t in output[num_outputs:]] + tensors_to_save_from_forward = _restore_from_saved( + tensor_objects, + op_saved_tensors, + ) + + bwd_obj = backward_obj() + tensors_to_save_from_setup = setup_context( + bwd_obj, + fwd_obj, + _outputs_for_setup(user_outputs), + ctx_attrs, + tensors_to_save_from_forward, + ) + tensors_to_save, tensor_objects = _prepare_for_saving(tensors_to_save_from_setup) + ctx.tensor_objects = tensor_objects + ctx.save_for_backward(*tensors_to_save) + ctx.bwd_obj = bwd_obj + + def _autograd_backward(ctx, *grad_outputs): + bwd_obj = ctx.bwd_obj + if hasattr(bwd_obj, "setup_saved_tensors"): + bwd_obj.setup_saved_tensors(ctx.saved_tensors, ctx.tensor_objects) + ctx.tensor_objects = None + # The forward op returns a single ``Tensor[]`` (concatenation of + # user outputs and saved tensors), so ``grad_outputs`` is a + # 1-tuple containing the per-element grad list. Only the first + # ``num_outputs`` of those correspond to user-facing outputs; + # ``grad_output`` for the backward is the grad of the primary + # output. + per_output_grads = grad_outputs[0] + bwd_obj.grad_output = _decode_none(per_output_grads[0]) + kwargs = backward_arg_type.torch_compile_pack(bwd_obj, bwd_buckets) + bwd_args_flat = [kwargs[name] for name in bwd_arg_names] + bwd_op = getattr(getattr(torch.ops, op_namespace), bwd_op_name) + grads = [_decode_none(g) for g in bwd_op(*bwd_args_flat)] + # ``register_autograd`` requires one grad slot per forward input + # with the same tree structure as the input (a ``Tensor[]`` slot + # must get back a list, never a bare ``None``). Start from the + # precomputed per-slot defaults and overlay the produced grads + # at the positions declared by ``input_tensors_for_grad``. + out: List[Any] = list(fwd_slot_defaults) + tensor_list_lengths = getattr(ctx, "_te_fwd_tensor_list_lengths", {}) + for (pos, as_list), g in zip(grad_targets, grads): + if as_list: + length = tensor_list_lengths.get(pos, 1) + out[pos] = [g] + [None] * (length - 1) + else: + out[pos] = g + return tuple(out) + + torch.library.register_autograd( + fwd_qualname, + _autograd_backward, + setup_context=_setup_context, + lib=lib, + ) + + fwd_op = getattr(getattr(torch.ops, op_namespace), op_name) + + def forward_fn(fwd_args): + # Bind ``lib`` here so its registrations (impl / register_fake / + # register_autograd) outlive ``_te_register_custom_op`` even if + # all other references to it are dropped: ``torch.library`` uses + # the ``Library`` instance lifetime for all attached registrations. + _ = lib # noqa: F841 -- closure-captured for lifetime only + kwargs = linear_arg_type.torch_compile_pack(fwd_args, fwd_buckets) + flat = [kwargs[name] for name in fwd_arg_names] + result = fwd_op(*flat) + outputs = [_decode_none(t) for t in result[:num_outputs]] + if num_outputs == 1: + return outputs[0] + return tuple(outputs) + + return forward_fn From 8bc7a1a111b024c9a319a89cc9bcfea6b960d302 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 13 May 2026 20:27:24 +0200 Subject: [PATCH 03/16] [PyTorch] Iterate on torch.compile support for Linear Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 82 +- transformer_engine/common/recipe/__init__.py | 94 ++ transformer_engine/pytorch/constants.py | 28 + transformer_engine/pytorch/dynamo.py | 1419 ++++++++++------- transformer_engine/pytorch/module/base.py | 32 +- transformer_engine/pytorch/module/linear.py | 113 +- transformer_engine/pytorch/quantization.py | 58 +- .../pytorch/quantized_tensor.py | 218 ++- .../pytorch/tensor/float8_blockwise_tensor.py | 56 +- .../pytorch/tensor/float8_tensor.py | 187 ++- .../pytorch/tensor/mxfp8_tensor.py | 45 +- .../pytorch/tensor/nvfp4_tensor.py | 67 +- .../float8_blockwise_tensor_storage.py | 89 +- .../tensor/storage/float8_tensor_storage.py | 102 +- .../tensor/storage/mxfp8_tensor_storage.py | 89 +- .../tensor/storage/nvfp4_tensor_storage.py | 97 +- 16 files changed, 2139 insertions(+), 637 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 51f72b1e56..e14aa39bbf 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -97,8 +97,8 @@ def __fx_repr__(self): def _make_qfactory(tag: str): """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.""" - def qfactory(role: str): - return ToyQuantizer(tag=f"{tag}:{role}") + def qfactory(role): + return ToyQuantizer(tag=f"{tag}:{role.tensor_type}") return qfactory @@ -324,3 +324,81 @@ def fn(inp): out = compiled(inp) out.sum().backward() + + +@pytest.mark.parametrize( + "fp8_recipe", + [None, *_all_recipes], + ids=lambda r: "bf16" if r is None else type(r).__name__, +) +def test_te_linear_compiles(fp8_recipe): + """torch.compile(fullgraph=True) of ``te.Linear`` under every built-in + recipe (and the bf16-only baseline with no autocast). + + Exercises the custom-op path in + :mod:`transformer_engine.pytorch.dynamo`: forward goes through + ``_linear_compiled_op``, backward through the registered + ``transformer_engine::linear_backward`` op, and the dataclass + arg-objects are packed/unpacked via the bucket dispatch in + :mod:`transformer_engine.pytorch.dynamo`. + """ + if fp8_recipe is not None and not fp8_available: + pytest.skip(reason_for_no_fp8) + + dtype = torch.bfloat16 + device = "cuda" + + # FP8 GEMMs require leading dimensions divisible by 16; pick + # in/out features and batch comfortably above that minimum. + model = te.Linear(64, 32, params_dtype=dtype, device=device) + inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + if fp8_recipe is None: + return model(inp) + with te.autocast(recipe=fp8_recipe): + return model(inp) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + out.sum().backward() + assert out.shape == (32, 32) + assert inp.grad is not None + assert model.weight.grad is not None, "weight.grad missing" + assert model.weight.grad.shape == model.weight.shape, ( + f"weight.grad shape {tuple(model.weight.grad.shape)} != " + f"weight shape {tuple(model.weight.shape)}" + ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_te_linear_compile_with_quantized_fp8_weight(): + """torch.compile should handle Linear weights initialized as FP8 tensors.""" + dtype = torch.bfloat16 + device = "cuda" + fp8_recipe = recipe.Float8CurrentScaling() + + with te.quantized_model_init(enabled=True, recipe=fp8_recipe): + model = te.Linear(64, 32, params_dtype=dtype, device=device) + + assert isinstance(model.weight, te.Float8Tensor) + inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + with te.autocast(recipe=fp8_recipe): + return model(inp) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + out.sum().backward() + assert out.shape == (32, 32) + assert inp.grad is not None + assert model.weight.grad is not None, "Float8Tensor weight.grad missing" + assert model.weight.grad.shape == model.weight.shape, ( + f"Float8Tensor weight.grad shape {tuple(model.weight.grad.shape)} != " + f"weight shape {tuple(model.weight.shape)}" + ) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b773a81d1b..57d7e3965a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -166,6 +166,88 @@ def custom(cls): """Whether the given recipe is custom.""" return issubclass(cls, CustomRecipe) + # ------------------------------------------------------------------ # + # torch.compile flatten / unflatten protocol + # ------------------------------------------------------------------ # + # The flattenable bucket in + # :mod:`transformer_engine.pytorch.dynamo` ships ``Recipe`` instances + # through TE custom ops by calling :meth:`_flatten` (instance method + # on each concrete subclass) and :meth:`_unflatten` (classmethod on + # this base, which dispatches by ``_rcls`` stamped into the + # metadata bundle). The default implementation reads + # :func:`dataclasses.fields` and flattens nested ``@dataclass`` + # fields with ``"."`` keys; reconstruction + # instantiates the target class with default args and writes the + # flattened values back. Subclasses can override either method when + # their structure is too irregular for the generic round-trip. + + def _flatten(self): # noqa: D401 -- short name preferred + """Return ``(OpaqueSimpleMetadata, None, [])``.""" + # Lazy imports keep ``common`` independent of pytorch. + from dataclasses import fields, is_dataclass + from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata + + payload: dict = {"_rcls": type(self).__qualname__} + for f in fields(self): + v = getattr(self, f.name) + if is_dataclass(v) and not isinstance(v, type): + for sf in fields(v): + payload[f"{f.name}.{sf.name}"] = getattr(v, sf.name) + else: + payload[f.name] = v + return OpaqueSimpleMetadata(payload), None, [] + + @classmethod + def _unflatten(cls, meta, _ref, _tensors): + """Dispatch to the concrete subclass identified by + ``meta['_rcls']`` and rehydrate fields (including nested + ``@dataclass`` fields written under ``"."`` + keys by :meth:`_flatten`).""" + from dataclasses import fields, is_dataclass + + target_name = meta["_rcls"] + target_cls = _RECIPE_REGISTRY.get(target_name) + if target_cls is None: + raise KeyError( + f"Unknown recipe class {target_name!r} during unflatten; " + "is the subclass imported in transformer_engine.common.recipe?" + ) + + out = target_cls() + nested: dict = {} + for k, v in meta.items(): + if k == "_rcls": + continue + if "." in k: + parent, child = k.split(".", 1) + nested.setdefault(parent, {})[child] = v + else: + setattr(out, k, v) + for parent, children in nested.items(): + target = getattr(out, parent, None) + if target is None or not is_dataclass(target): + continue + # Nested dataclasses (e.g. ``MMParams``) may be frozen, so + # rebuild the instance with merged kwargs and reassign. + cur_kwargs = {f.name: getattr(target, f.name) for f in fields(target)} + cur_kwargs.update(children) + setattr(out, parent, type(target)(**cur_kwargs)) + return out + + +# Lazily populated by :meth:`Recipe.__init_subclass__` so that +# :meth:`Recipe._unflatten` can dispatch by ``__qualname__``. +_RECIPE_REGISTRY: dict = {} + + +def _register_recipe_subclass(cls) -> None: + _RECIPE_REGISTRY[cls.__qualname__] = cls + + +# Recipe uses pydantic.dataclasses which can interfere with hooking +# ``__init_subclass__``; register subclasses explicitly at the bottom of +# this module instead. + @dataclass(repr=False) class DelayedScaling(Recipe): @@ -654,3 +736,15 @@ def _make_repr(self) -> str: f"qfactory={self.qfactory}, " f"backward_override={self.backward_override}" ) + + +# Populate the dispatch registry consumed by :meth:`Recipe._unflatten`. +for _R in ( + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, + Float8BlockScaling, + NVFP4BlockScaling, + CustomRecipe, +): + _register_recipe_subclass(_R) diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 2aff4fd8e8..694de2d94e 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -41,6 +41,34 @@ tex.DType.kBFloat16: torch.bfloat16, } +# Map: TE DType *id* (Python int) -> TE DType enum. Used by +# :func:`canonicalize_te_dtype` to recover the pybind enum from its +# integer id without going through ``tex.DType(int)``, which Dynamo +# cannot trace (pybind11 enum constructor is opaque). +TE_DType_ID_To_TE = { + int(tex.DType.kByte): tex.DType.kByte, + int(tex.DType.kFloat8E4M3): tex.DType.kFloat8E4M3, + int(tex.DType.kFloat8E5M2): tex.DType.kFloat8E5M2, + int(tex.DType.kFloat4E2M1): tex.DType.kFloat4E2M1, + int(tex.DType.kInt32): tex.DType.kInt32, + int(tex.DType.kFloat32): tex.DType.kFloat32, + int(tex.DType.kFloat16): tex.DType.kFloat16, + int(tex.DType.kBFloat16): tex.DType.kBFloat16, +} + + +def canonicalize_te_dtype(dtype): + """Accept either a TE ``DType`` enum or its Python ``int`` id. + + Recipe state keeps dtype ids as Python ``int`` values for cheap, + trace-friendly comparisons. Quantizer objects, however, are passed to + TE's C++ bindings, which expect the pybind ``tex.DType`` enum. + """ + if isinstance(dtype, int): + return TE_DType_ID_To_TE[dtype] + return dtype + + # Cache enum -> int conversions to avoid repeated PyObject lookups. FP8FwdTensorIdx = SimpleNamespace( GEMM1_INPUT=int(tex.FP8FwdTensors.GEMM1_INPUT), diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index e05b865da3..5b62757ee3 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -13,6 +13,7 @@ Dict, List, Optional, + Sequence, Tuple, Union, get_args, @@ -24,12 +25,15 @@ __all__ = [ - "ArgObject", "OpaqueSimpleMetadata", "_te_register_custom_op", ] +_TE_OP_NAMESPACE = "transformer_engine_compile" +_TE_LIB = torch.library.Library(_TE_OP_NAMESPACE, "FRAGMENT") + + # Sentinel for ``None`` entries inside the op's flat ``Tensor[]`` return. # Used by :func:`_te_register_custom_op` to support ``None`` outputs (e.g. # an FP8 weight workspace returned only on the cache-miss path) on a @@ -89,12 +93,9 @@ class OpaqueSimpleMetadata: @classmethod def _is_opaque_value(cls, value: Any) -> bool: - """Whether ``value``'s class is registered as a value-opaque type.""" - try: - from torch._library.opaque_object import is_opaque_value_type - except Exception: # pragma: no cover - older torch - return False - return is_opaque_value_type(type(value)) + """Whether ``value``'s class is registered as a value-opaque type. + """ + return _is_opaque_value_type(type(value)) @classmethod def is_simple_value(cls, value: Any) -> bool: @@ -150,16 +151,10 @@ def _fmt_simple(cls, value: Any) -> str: return value.__fx_repr__()[0] return repr(value) - def __init__( - self, - data: Optional[Dict[str, Any]] = None, - /, - **kwargs: Any, - ) -> None: - merged: Dict[str, Any] = dict(data) if data else {} - merged.update(kwargs) + def __init__(self, data: Optional[Dict[str, Any]] = None) -> None: + data = dict(data) if data else {} cls = type(self) - for k, v in merged.items(): + for k, v in data.items(): if not cls.is_simple_value(v): raise TypeError( f"OpaqueSimpleMetadata field '{k}' has unsupported " @@ -168,9 +163,9 @@ def __init__( f"Enum, torch.Size, registered torch.compile value-" f"opaque types) and tuples/lists thereof are allowed." ) - self._data: Dict[str, Any] = merged + self._data: Dict[str, Any] = data self._frozen: Tuple[Tuple[str, Any], ...] = tuple( - (k, cls._to_hashable(v)) for k, v in sorted(merged.items()) + (k, cls._to_hashable(v)) for k, v in sorted(data.items()) ) def __getitem__(self, key: str) -> Any: @@ -276,6 +271,7 @@ def __repr__(self) -> str: try: from torch._library.opaque_object import ( get_opaque_type_name, + is_opaque_value_type as _is_opaque_value_type, register_opaque_type, ) @@ -296,6 +292,7 @@ def __repr__(self) -> str: except Exception: # pragma: no cover - distributed not built / disabled _PROCESS_GROUP_TYPE_NAME = None except Exception: # pragma: no cover - older torch without opaque_object + _is_opaque_value_type = None _OPAQUE_SIMPLE_META_TYPE_NAME = None _PROCESS_GROUP_TYPE_NAME = None @@ -304,21 +301,20 @@ def __repr__(self) -> str: # Field buckets # --------------------------------------------------------------------------- # -# Each dataclass field of an :class:`ArgObject` is mapped to exactly one +# Each dataclass field of an argument container is mapped to exactly one # bucket. A bucket owns the full per-field "vocabulary" -- which schema # slots it emits, how its packed value(s) are produced from the dataclass # instance, and how the unpacked value is re-injected into the -# reconstructed instance. ``ArgObject`` then becomes three trivial loops -# over a list of buckets, instead of three parallel branch ladders. +# reconstructed instance. The module-level :func:`_get_buckets` / +# :func:`_get_schema` / :func:`_pack` / :func:`_unpack` helpers then +# become trivial loops over a list of buckets, instead of three parallel +# branch ladders. # # Five bucket kinds are used: # # * :class:`_TensorBucket` -- :class:`torch.Tensor` / # :class:`Optional[torch.Tensor] ` -> one ``Tensor`` / # ``Tensor?`` slot. -# * :class:`_TensorListBucket` -- ``List[torch.Tensor]`` / -# ``Tuple[torch.Tensor, ...]`` -> one ``Tensor[]`` slot. Used for -# variable-length tensor sequences such as ``ctx.saved_tensors``. # * :class:`_ProcessGroupBucket` -- :class:`torch.distributed.ProcessGroup` # (already registered upstream as a value-opaque type) -> one direct # slot. @@ -378,7 +374,7 @@ def schema_slots(self) -> List[Tuple[str, str]]: """Return ``[(slot_name, schema_type_str), ...]`` for this field.""" raise NotImplementedError - def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + def pack(self, owner: Any) -> List[Tuple[str, Any]]: """Return ``[(slot_name, value), ...]`` extracted from ``owner``.""" raise NotImplementedError @@ -388,33 +384,155 @@ def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: raise NotImplementedError -class _TensorOrStorageBucket(_Bucket): - """``Tensor | QuantizedTensorStorage`` -> meta / pg / Tensor[] slots. +class _MetaPGTensorsBucket(_Bucket): + """Shared three-slot bucket emitting ``__meta`` / + ``__pg`` / ``__tensors``. - Plain tensors are carried as a single-element ``Tensor[]``. Quantized - tensor wrappers and storage shells are carried through their - ``_torch_compile_flatten`` protocol so the backward op receives the same - structured object type that eager restoration produced. + Used by every field whose value must be carried as the triple + ``(OpaqueSimpleMetadata, ProcessGroup?, Tensor[])`` -- today this + covers ``Tensor | QuantizedTensorStorage`` unions (see + :class:`_UniversalTensorBucket`) and ``Quantizer`` / ``Recipe`` + instances (see :class:`_FlattenableBucket`). Concrete subclasses + implement :meth:`_pack_value` / :meth:`_unpack_value` for their + flatten/unflatten protocol; the rest of the bucket contract is + identical and lives here. """ - SUFFIX_META = "__tsmeta" - SUFFIX_PG = "__tspg" - SUFFIX_TENSORS = "__tstensors" + SUFFIX_META = "__meta" + SUFFIX_PG = "__pg" + SUFFIX_TENSORS = "__tensors" + + def __init__(self, name: str) -> None: + if _OPAQUE_SIMPLE_META_TYPE_NAME is None or _PROCESS_GROUP_TYPE_NAME is None: + raise RuntimeError( + f"Field {name!r} requires both OpaqueSimpleMetadata and " + "torch.distributed.ProcessGroup to be registered as " + "torch._library opaque types; one or both are " + "unavailable in this PyTorch build." + ) + self.name = name + + def _slot_meta(self) -> str: + return self.name + self.SUFFIX_META + + def _slot_pg(self) -> str: + return self.name + self.SUFFIX_PG + + def _slot_tensors(self) -> str: + return self.name + self.SUFFIX_TENSORS + + def schema_slots(self) -> List[Tuple[str, str]]: + return [ + (self._slot_meta(), _OPAQUE_SIMPLE_META_TYPE_NAME), + (self._slot_pg(), _PROCESS_GROUP_TYPE_NAME + "?"), + (self._slot_tensors(), "Tensor[]"), + ] - KIND_KEY = "_te_tensor_storage_kind" + def pack(self, owner: Any) -> List[Tuple[str, Any]]: + value = getattr(owner, self.name) + meta, pg, tensors = self._pack_value(value) + return [ + (self._slot_meta(), meta), + (self._slot_pg(), pg), + (self._slot_tensors(), list(tensors)), + ] + + def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: + kwargs[self.name] = self._unpack_value( + args[self._slot_meta()], + args[self._slot_pg()], + args[self._slot_tensors()], + ) + + def _pack_value( + self, value: Any + ) -> Tuple[Any, Any, List[torch.Tensor]]: + """Flatten one field value into ``(meta, pg, tensors)``.""" + raise NotImplementedError + + def _unpack_value( + self, meta: Any, pg: Any, tensors: List[torch.Tensor] + ) -> Any: + """Inverse of :meth:`_pack_value`.""" + raise NotImplementedError + + +class _UniversalTensorBucket(_Bucket): + """``Tensor | QuantizedTensorStorage`` (also subclass-tensor) field. + + Emits four schema slots per field, regardless of the runtime value: + + * ```` (``Tensor?``) -- plain tensor / subclass tensor + (e.g. :class:`Float8Tensor`) + passes through here untouched. + ``None`` for the storage path. + * ``__tensors`` (``Tensor[]``) -- flat inner tensors when the + value was carried through a + flatten protocol (storage at + pack-time, or a subclass that + was dispatched into flat form + by ``register_torch_dispatch`` + on the outer op). + * ``__pg`` (``ProcessGroup?``) -- distributed handle attached + to the flatten metadata, if + any. + * ``__meta`` (``OpaqueSimpleMetadata``) -- everything else: + the storage / subclass meta + dict, plus a ``__kind__`` + marker telling the unpacker + which slot to look at: + ``"none"``, ``"tensor"``, or + ``"storage"`` (the latter + covers both storage and any + already-flattened subclass). + + Storage values are flattened at ``_pack`` time (callsite). Plain + tensors -- including subclass instances -- are passed unchanged + through ````; under ``torch.compile`` an outer-op + ``register_torch_dispatch`` rule turns each registered subclass + into the storage layout *between* outer and inner op so the + autograd graph stays attached to the user-facing wrapper. + """ + + SUFFIX_TENSORS = "__tensors" + SUFFIX_PG = "__pg" + SUFFIX_META = "__meta" + + KIND_KEY = "__kind__" KIND_NONE = "none" KIND_TENSOR = "tensor" + KIND_STORAGE = "storage" def __init__(self, name: str) -> None: if _OPAQUE_SIMPLE_META_TYPE_NAME is None or _PROCESS_GROUP_TYPE_NAME is None: raise RuntimeError( - f"Tensor/storage field {name!r} requires both " - "OpaqueSimpleMetadata and torch.distributed.ProcessGroup " - "to be registered as torch._library opaque types; one or " - "both are unavailable in this PyTorch build." + f"Field {name!r} requires both OpaqueSimpleMetadata and " + "torch.distributed.ProcessGroup to be registered as " + "torch._library opaque types; one or both are " + "unavailable in this PyTorch build." ) self.name = name + def slot_name(self) -> str: + return self.name + + def slot_tensors(self) -> str: + return self.name + self.SUFFIX_TENSORS + + def slot_pg(self) -> str: + return self.name + self.SUFFIX_PG + + def slot_meta(self) -> str: + return self.name + self.SUFFIX_META + + def schema_slots(self) -> List[Tuple[str, str]]: + return [ + (self.slot_name(), "Tensor?"), + (self.slot_tensors(), "Tensor[]"), + (self.slot_pg(), _PROCESS_GROUP_TYPE_NAME + "?"), + (self.slot_meta(), _OPAQUE_SIMPLE_META_TYPE_NAME), + ] + @staticmethod def _is_tensor_storage_union(annot: Any) -> bool: origin = get_origin(annot) @@ -423,81 +541,71 @@ def _is_tensor_storage_union(annot: Any) -> bool: members = [a for a in get_args(annot) if a is not type(None)] if torch.Tensor not in members: return False - try: - from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage - except Exception: # pragma: no cover - partial init + qts = _quantized_tensor_storage_cls() + if qts is None: return False return any( - isinstance(member, type) and issubclass(member, QuantizedTensorStorage) + isinstance(member, type) and issubclass(member, qts) for member in members ) @classmethod - def try_build(cls, name: str, annot: Any) -> Optional["_TensorOrStorageBucket"]: + def try_build(cls, name: str, annot: Any) -> Optional["_UniversalTensorBucket"]: if cls._is_tensor_storage_union(annot): return cls(name) return None - def _slot_meta(self) -> str: - return self.name + self.SUFFIX_META - - def _slot_pg(self) -> str: - return self.name + self.SUFFIX_PG - - def _slot_tensors(self) -> str: - return self.name + self.SUFFIX_TENSORS - - def schema_slots(self) -> List[Tuple[str, str]]: - return [ - (self._slot_meta(), _OPAQUE_SIMPLE_META_TYPE_NAME), - (self._slot_pg(), _PROCESS_GROUP_TYPE_NAME + "?"), - (self._slot_tensors(), "Tensor[]"), - ] - - def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + def pack(self, owner: Any) -> List[Tuple[str, Any]]: value = getattr(owner, self.name) if value is None: - meta = OpaqueSimpleMetadata({self.KIND_KEY: self.KIND_NONE}) - pg: Any = None - tensors: List[torch.Tensor] = [] - else: - from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage - - if isinstance(value, QuantizedTensorStorage): + return [ + (self.slot_name(), None), + (self.slot_tensors(), []), + (self.slot_pg(), None), + (self.slot_meta(), OpaqueSimpleMetadata({self.KIND_KEY: self.KIND_NONE})), + ] + # Plain ``torch.Tensor`` *and* any subclass (e.g. ``Float8Tensor``) + # hit this branch first -- the wrapper is forwarded untouched + # through the ``Tensor?`` slot so the autograd graph stays + # attached to the user-facing tensor object. Subclass-specific + # flattening (if any) happens later inside the outer op's + # ``register_torch_dispatch`` rule. + if isinstance(value, torch.Tensor): + return [ + (self.slot_name(), value), + (self.slot_tensors(), []), + (self.slot_pg(), None), + (self.slot_meta(), OpaqueSimpleMetadata({self.KIND_KEY: self.KIND_TENSOR})), + ] + qts = _quantized_tensor_storage_cls() + if qts is not None and isinstance(value, qts): meta, pg, tensors = value._torch_compile_flatten() - elif isinstance(value, torch.Tensor): - meta = OpaqueSimpleMetadata({self.KIND_KEY: self.KIND_TENSOR}) - pg = None - tensors = [value] - else: - raise TypeError( - f"{type(owner).__name__} field {self.name!r} expected " - "None, torch.Tensor, or QuantizedTensorStorage, got " - f"{type(value).__name__}" - ) + # Stamp the storage-flatten meta with our kind marker so the + # unpacker can route by ``__kind__`` alone. + meta._data[self.KIND_KEY] = self.KIND_STORAGE return [ - (self._slot_meta(), meta), - (self._slot_pg(), pg), - (self._slot_tensors(), list(tensors)), - ] + (self.slot_name(), None), + (self.slot_tensors(), list(tensors)), + (self.slot_pg(), pg), + (self.slot_meta(), meta), + ] + raise TypeError( + f"field {self.name!r} expected None, torch.Tensor, or " + f"QuantizedTensorStorage, got {type(value).__name__}" + ) def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: - meta = args[self._slot_meta()] + meta = args[self.slot_meta()] kind = meta.get(self.KIND_KEY) if kind == self.KIND_NONE: kwargs[self.name] = None return - tensors = args[self._slot_tensors()] if kind == self.KIND_TENSOR: - kwargs[self.name] = tensors[0] + kwargs[self.name] = args[self.slot_name()] return - - from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage - - kwargs[self.name] = QuantizedTensorStorage._torch_compile_unflatten( - meta, - args[self._slot_pg()], - tensors, + qts = _quantized_tensor_storage_cls() + kwargs[self.name] = qts._torch_compile_unflatten( + meta, args[self.slot_pg()], args[self.slot_tensors()] ) @@ -518,67 +626,13 @@ def try_build(cls, name: str, annot: Any) -> Optional["_TensorBucket"]: def schema_slots(self) -> List[Tuple[str, str]]: return [(self.name, self.type_str)] - def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + def pack(self, owner: Any) -> List[Tuple[str, Any]]: return [(self.name, getattr(owner, self.name))] def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: kwargs[self.name] = args[self.name] -class _TensorListBucket(_Bucket): - """``List[Tensor]`` / ``Tuple[Tensor, ...]`` -> single ``Tensor[]`` slot. - - Used for fields like ``LinearBwdArgs.saved_tensors`` that carry an - arbitrary-length sequence of tensors (typically the - ``ctx.saved_tensors`` payload restored before invoking the backward - op). The slot itself is non-nullable, but individual ``None`` - elements are smuggled through using :func:`_encode_none` / - :func:`_decode_none` sentinels (matching what the forward op return - list already does). An empty sequence is valid. - """ - - def __init__(self, name: str, container: type) -> None: - self.name = name - # Remember the original container type so unpack returns the - # exact same Python type the dataclass annotation declared. - self.container = container - - @classmethod - def try_build(cls, name: str, annot: Any) -> Optional["_TensorListBucket"]: - stripped, _ = _strip_optional(annot) - origin = get_origin(stripped) - if origin is None: - return None - args = get_args(stripped) - if not args: - return None - # ``Tuple[Tensor, ...]`` -> args = (Tensor, Ellipsis); other forms - # like ``Tuple[Tensor, Tensor]`` or ``List[Tensor]`` only have - # type entries. - if origin is tuple: - if len(args) == 2 and args[1] is Ellipsis: - elem = args[0] - else: - elem = args[0] if all(a is args[0] for a in args) else None - elif origin is list: - elem = args[0] - else: - return None - if elem is not torch.Tensor: - return None - return cls(name, list if origin is list else tuple) - - def schema_slots(self) -> List[Tuple[str, str]]: - return [(self.name, "Tensor[]")] - - def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: - value = getattr(owner, self.name) or () - return [(self.name, [_encode_none(t) for t in value])] - - def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: - kwargs[self.name] = self.container(_decode_none(t) for t in args[self.name]) - - class _ProcessGroupBucket(_Bucket): """``ProcessGroup`` / ``Optional[ProcessGroup]`` -> one direct opaque-ref slot. @@ -613,69 +667,96 @@ def try_build(cls, name: str, annot: Any) -> Optional["_ProcessGroupBucket"]: def schema_slots(self) -> List[Tuple[str, str]]: return [(self.name, self.type_str)] - def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + def pack(self, owner: Any) -> List[Tuple[str, Any]]: return [(self.name, getattr(owner, self.name))] def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: kwargs[self.name] = args[self.name] -def _flattenable_bases() -> Tuple[type, ...]: - """Return the list of base classes whose subclasses are routed - through :class:`_FlattenableBucket`. +# Cached resolutions of TE types that ``dynamo`` references lazily to +# avoid import cycles (they live in modules that themselves import this +# one). Each ``_*_cls`` getter resolves its target once and reuses the +# result on every subsequent call; the values are kept module-level +# rather than baked into bucket instances so the cache survives across +# different dataclass registrations. +_QTS_REF: Optional[type] = None +_QUANTIZER_REF: Optional[type] = None +_RECIPE_REF: Optional[type] = None - A "flattenable" type implements the duck-typed pair - * instance method ``_flatten() -> (OpaqueSimpleMetadata, ref, list[Tensor])`` - * classmethod ``_unflatten(meta, ref, tensors)`` (dispatches by an - identifier stamped into ``meta``) +def _quantized_tensor_storage_cls() -> Optional[type]: + """Lazy-resolve :class:`QuantizedTensorStorage`; ``None`` if unavailable.""" + global _QTS_REF + if _QTS_REF is None: + try: + from transformer_engine.pytorch.quantized_tensor import ( + QuantizedTensorStorage, + ) + + _QTS_REF = QuantizedTensorStorage + except Exception: # pragma: no cover - partial init + return None + return _QTS_REF - Lazy import keeps ``dynamo`` importable before the modules that - define these bases (avoid import cycles). - """ - bases: List[type] = [] - try: - from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer - bases.append(Quantizer) - bases.append(QuantizedTensorStorage) +def _quantizer_cls() -> Optional[type]: + """Lazy-resolve :class:`Quantizer`; ``None`` if unavailable.""" + global _QUANTIZER_REF + if _QUANTIZER_REF is None: + try: + from transformer_engine.pytorch.quantized_tensor import Quantizer + + _QUANTIZER_REF = Quantizer except Exception: # pragma: no cover - partial init - pass + return None + return _QUANTIZER_REF + + +def _recipe_cls() -> Optional[type]: + """Lazy-resolve :class:`Recipe`; ``None`` if unavailable.""" + global _RECIPE_REF + if _RECIPE_REF is None: try: from transformer_engine.common.recipe import Recipe - bases.append(Recipe) + _RECIPE_REF = Recipe except Exception: # pragma: no cover - partial init - pass - return tuple(bases) + return None + return _RECIPE_REF -class _FlattenableBucket(_Bucket): - """Three-slot expansion (``meta`` / ``ref`` / ``tensors``) for any - field whose type implements the ``_flatten`` / ``_unflatten`` +def _flattenable_bases() -> Tuple[type, ...]: + """Return the list of base classes whose subclasses are routed + through :class:`_FlattenableBucket`. + + A "flattenable" type implements the duck-typed pair + + * instance method ``_flatten() -> (OpaqueSimpleMetadata, ref, list[Tensor])`` + * classmethod ``_unflatten(meta, ref, tensors)`` (dispatches by an + identifier stamped into ``meta``). + """ + return tuple( + cls + for cls in (_quantizer_cls(), _quantized_tensor_storage_cls(), _recipe_cls()) + if cls is not None + ) + + +class _FlattenableBucket(_MetaPGTensorsBucket): + """Field whose type implements the ``_flatten`` / ``_unflatten`` protocol (see :func:`_flattenable_bases`). Used today for :class:`~transformer_engine.pytorch.quantized_tensor.Quantizer` and :class:`~transformer_engine.common.recipe.Recipe`. """ - SUFFIX_META = "__fmeta" - SUFFIX_PG = "__fpg" - SUFFIX_TENSORS = "__ftensors" - # Stored under ``_qcls`` in the metadata bundle to encode ``None`` # without making any of the three slots nullable. NONE_MARKER_KEY = "_qcls" NONE_MARKER_VAL = "" def __init__(self, name: str, base_cls: type) -> None: - if _OPAQUE_SIMPLE_META_TYPE_NAME is None or _PROCESS_GROUP_TYPE_NAME is None: - raise RuntimeError( - f"Flattenable field {name!r} requires both " - "OpaqueSimpleMetadata and torch.distributed.ProcessGroup " - "to be registered as torch._library opaque types; one or " - "both are unavailable in this PyTorch build." - ) - self.name = name + super().__init__(name) self.base_cls = base_cls @classmethod @@ -688,52 +769,25 @@ def try_build(cls, name: str, annot: Any) -> Optional["_FlattenableBucket"]: return cls(name, base) return None - def _slot_meta(self) -> str: - return self.name + self.SUFFIX_META - - def _slot_pg(self) -> str: - return self.name + self.SUFFIX_PG - - def _slot_tensors(self) -> str: - return self.name + self.SUFFIX_TENSORS - - def schema_slots(self) -> List[Tuple[str, str]]: - return [ - (self._slot_meta(), _OPAQUE_SIMPLE_META_TYPE_NAME), - (self._slot_pg(), _PROCESS_GROUP_TYPE_NAME + "?"), - (self._slot_tensors(), "Tensor[]"), - ] - - def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: - value = getattr(owner, self.name) + def _pack_value(self, value: Any) -> Tuple[Any, Any, List[torch.Tensor]]: if value is None: - meta = OpaqueSimpleMetadata({self.NONE_MARKER_KEY: self.NONE_MARKER_VAL}) - pg: Any = None - tensors: List[torch.Tensor] = [] - else: + return ( + OpaqueSimpleMetadata({self.NONE_MARKER_KEY: self.NONE_MARKER_VAL}), + None, + [], + ) if hasattr(value, "_flatten"): - meta, pg, tensors = value._flatten() - else: - meta, pg, tensors = value._torch_compile_flatten() - return [ - (self._slot_meta(), meta), - (self._slot_pg(), pg), - (self._slot_tensors(), list(tensors)), - ] + return value._flatten() + return value._torch_compile_flatten() - def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: - meta = args[self._slot_meta()] + def _unpack_value( + self, meta: Any, pg: Any, tensors: List[torch.Tensor] + ) -> Any: if meta.get(self.NONE_MARKER_KEY) == self.NONE_MARKER_VAL: - kwargs[self.name] = None - return + return None if hasattr(self.base_cls, "_unflatten"): - kwargs[self.name] = self.base_cls._unflatten( - meta, args[self._slot_pg()], args[self._slot_tensors()] - ) - else: - kwargs[self.name] = self.base_cls._torch_compile_unflatten( - meta, args[self._slot_pg()], args[self._slot_tensors()] - ) + return self.base_cls._unflatten(meta, pg, tensors) + return self.base_cls._torch_compile_unflatten(meta, pg, tensors) class _SimpleBundleBucket(_Bucket): @@ -774,12 +828,7 @@ def matches_field(cls, annot: Any) -> bool: return True # Any registered value-opaque class is hashable / FX-reproducible # and therefore safe to embed in the OpaqueSimpleMetadata bundle. - if isinstance(annot, type): - try: - from torch._library.opaque_object import is_opaque_value_type - except Exception: # pragma: no cover - older torch - is_opaque_value_type = None - if is_opaque_value_type is not None and is_opaque_value_type(annot): + if isinstance(annot, type) and _is_opaque_value_type(annot): return True origin = get_origin(annot) if origin in (tuple, list): @@ -792,7 +841,7 @@ def matches_field(cls, annot: Any) -> bool: def schema_slots(self) -> List[Tuple[str, str]]: return [(self.SLOT, _OPAQUE_SIMPLE_META_TYPE_NAME)] - def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + def pack(self, owner: Any) -> List[Tuple[str, Any]]: bundle = OpaqueSimpleMetadata({n: getattr(owner, n) for n in self.names}) return [(self.SLOT, bundle)] @@ -816,7 +865,7 @@ class _UnknownBucket(_Bucket): the op and reconstructed from companion fields (``saved_tensors``, quantizer metadata, ...) on the way out. - Constructed directly by :meth:`ArgObject._buckets` (it has no + Constructed directly by :func:`_get_buckets` (it has no annotation-based ``try_build`` -- it's the explicit "no match" case). """ @@ -835,15 +884,14 @@ def _is_trivial(value: Any) -> bool: def schema_slots(self) -> List[Tuple[str, str]]: return [] - def pack(self, owner: "ArgObject") -> List[Tuple[str, Any]]: + def pack(self, owner: Any) -> List[Tuple[str, Any]]: value = getattr(owner, self.name, None) if not self._is_trivial(value): raise TypeError( f"{self.owner_cls_name} field {self.name!r} has a type not " "supported by torch.compile (not Tensor, simple, " - "ProcessGroup, or Quantizer) and carries " - "a non-trivial value; override " - f"{self.owner_cls_name}.torch_compile_pack to handle it." + "ProcessGroup, or Quantizer) and carries a non-trivial " + "value; add a matching bucket in dynamo.py to handle it." ) return [] @@ -853,39 +901,34 @@ def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: # Buckets, in priority order, that own ``try_build`` for a single field. _FIELD_BUCKETS: Tuple[type, ...] = ( - _TensorOrStorageBucket, + _UniversalTensorBucket, _TensorBucket, - _TensorListBucket, _ProcessGroupBucket, _FlattenableBucket, ) # --------------------------------------------------------------------------- # -# ArgObject +# Dataclass <-> torch.library plumbing # --------------------------------------------------------------------------- # - - -class ArgObject: - """Base class for structured argument containers passed to TE custom ops. - - Subclassed by per-module forward / backward dataclasses - (e.g. ``LinearFwdArgs``, ``LinearBwdArgs``). Provides the pack / - unpack / schema hooks consumed by :func:`_te_register_custom_op` - when wiring the dataclass into a ``torch.library`` schema. - - The default pack / unpack / schema implementations dispatch on - dataclass field annotations. Each field is mapped to exactly one - :class:`_Bucket` (see module-level docstring); the three methods - then become trivial iterations over the bucket list. - """ - - @classmethod - def _resolved_field_annotations(cls) -> List[Tuple[str, Any]]: +# +# The argument containers consumed by :func:`_te_register_custom_op` +# (e.g. ``LinearFwdArgs`` / ``LinearBwdArgs``) are intentionally just +# plain ``@dataclass`` types -- no base class, no decorators, no special +# methods. All translation between the dataclass and the flat +# ``{slot_name: slot_value}`` view that ``torch.library`` works with is +# provided by the module-level helpers below, which dispatch on dataclass +# field annotations: each field is mapped to exactly one :class:`_Bucket` +# and the three operations (schema / pack / unpack) reduce to a loop +# over the bucket list. + + +def _resolved_field_annotations(cls: type) -> List[Tuple[str, Any]]: + """Return ``[(field_name, resolved_type), ...]`` for a dataclass.""" if not dataclasses.is_dataclass(cls): raise TypeError( - f"{cls.__name__} must be a @dataclass to use the default " - f"ArgObject torch_compile_* implementations." + f"{cls.__name__} must be a @dataclass to be used as a TE " + f"custom-op argument container." ) # ``get_type_hints`` resolves forward references and PEP 563 # ``from __future__ import annotations`` strings. @@ -893,30 +936,27 @@ def _resolved_field_annotations(cls) -> List[Tuple[str, Any]]: hints = get_type_hints(cls) except Exception: hints = {} - return [ - (f.name, hints.get(f.name, f.type)) for f in dataclasses.fields(cls) - ] + return [(f.name, hints.get(f.name, f.type)) for f in dataclasses.fields(cls)] - @classmethod - def _buckets(cls) -> List[_Bucket]: - """Build the bucket list for this dataclass from field annotations. + +def _get_buckets(cls: type) -> List[_Bucket]: + """Build the bucket list for a dataclass from its field annotations. Dispatch order per field: try each bucket in :data:`_FIELD_BUCKETS` (Tensor, ProcessGroup, Quantizer); if none claims the field, route it to :class:`_SimpleBundleBucket` if its annotation is bundle-able, else to :class:`_UnknownBucket`. - Intentionally **not** cached. Caching on ``cls`` (e.g. by writing - ``cls.__te_buckets__``) tickles Dynamo: subsequent reads of + Intentionally **not** cached on ``cls``. Caching there (e.g. by + writing ``cls.__te_buckets__``) tickles Dynamo: subsequent reads of ``cls.__dict__`` from a compiled function trigger - "mappingproxy affected by dictionary mutation" graph breaks. - Hot paths must instead capture the bucket list once at op - registration time and pass it explicitly to :meth:`torch_compile_pack` - / :meth:`torch_compile_unpack`. + "mappingproxy affected by dictionary mutation" graph breaks. Hot + paths must instead capture the bucket list once at op registration + time and pass it explicitly to :func:`_pack` / :func:`_unpack`. """ buckets: List[_Bucket] = [] simple_names: List[str] = [] - for name, annot in cls._resolved_field_annotations(): + for name, annot in _resolved_field_annotations(cls): built: Optional[_Bucket] = None for bucket_cls in _FIELD_BUCKETS: built = bucket_cls.try_build(name, annot) @@ -932,52 +972,46 @@ def _buckets(cls) -> List[_Bucket]: buckets.append(_SimpleBundleBucket(simple_names)) return buckets - @classmethod - def torch_compile_get_schema(cls) -> List[Tuple[str, str]]: - """Default: derive the schema from dataclass annotations. - See :class:`_Bucket` subclasses for the per-field-kind layout - (Tensor, ProcessGroup, Quantizer, and the - aggregated ``_simple_meta`` bundle of simple fields). - """ - return [slot for b in cls._buckets() for slot in b.schema_slots()] - - def torch_compile_pack( - self, buckets: Optional[List[_Bucket]] = None - ) -> Dict[str, Any]: - """Default: ask each bucket to extract its slot(s) from ``self``. - - ``buckets`` is the precomputed bucket list (from - :meth:`_buckets`). Hot paths -- e.g. the closures created by - :func:`_te_register_custom_op` -- must pass it to avoid recomputing - and, critically, to keep Dynamo away from ``cls.__dict__`` while - tracing. When ``None``, this method recomputes the buckets - (eager-only fallback intended for ad-hoc / test usage). - """ - if buckets is None: - buckets = type(self)._buckets() +def _build_schema(buckets: List[_Bucket]) -> Tuple[str, List[str]]: + """Return ``(schema_str, slot_names)`` for a precomputed bucket list. + + ``schema_str`` is the parenthesised argument list (e.g. + ``"(Tensor x, Tensor? y)"``) that ``torch.library.Library.define`` + appends to the op name; ``slot_names`` is the ordered list of slot + keys produced by :func:`_pack`, used to flatten/unflatten the + keyword dict into the positional call. + """ + spec = [slot for b in buckets for slot in b.schema_slots()] + names = [name for name, _ in spec] + schema_str = "(" + ", ".join(f"{type_str} {name}" for name, type_str in spec) + ")" + return schema_str, names + + +def _pack(obj: Any, buckets: List[_Bucket]) -> Dict[str, Any]: + """Ask each bucket to extract its slot(s) from ``obj``. + + ``buckets`` is the precomputed bucket list (from :func:`_get_buckets`). + Hot paths -- e.g. the closures created by + :func:`_te_register_custom_op` -- must pass the precomputed list to + avoid recomputing and, critically, to keep Dynamo away from + ``cls.__dict__`` while tracing. + """ out: Dict[str, Any] = {} for bucket in buckets: - for name, value in bucket.pack(self): + for name, value in bucket.pack(obj): out[name] = value return out - @classmethod - def torch_compile_unpack( - cls, - args: Dict[str, Any], - buckets: Optional[List[_Bucket]] = None, - ) -> "ArgObject": - """Default: ask each bucket to inject its field(s) into a fresh - instance built via ``__new__`` (we bypass the dataclass - ``__init__`` so unknown-typed fields can stay as ``None`` even - when they have no default). - - ``buckets`` semantics match :meth:`torch_compile_pack`: hot paths - pass the precomputed list, eager-only callers may omit it. - """ - if buckets is None: - buckets = cls._buckets() + +def _unpack(cls: type, args: Dict[str, Any], buckets: List[_Bucket]) -> Any: + """Ask each bucket to inject its field(s) into a fresh instance. + + The instance is built via ``cls.__new__(cls)`` (we bypass any + dataclass ``__init__`` so unknown-typed fields can stay as ``None`` + even when they have no default). ``buckets`` semantics match + :func:`_pack`. + """ kwargs: Dict[str, Any] = {} for bucket in buckets: bucket.unpack(args, kwargs) @@ -986,271 +1020,246 @@ def torch_compile_unpack( object.__setattr__(obj, k, v) return obj - @classmethod - def torch_compile_get_input_tensors_for_grad(cls) -> List[str]: - """Names of forward inputs (from :meth:`torch_compile_get_schema`) - for which the corresponding ``backward_impl`` produces gradients, - in the exact order ``backward_impl`` returns them. - - Only meaningful on the forward arg type. Default is ``[]`` (no - gradients, e.g. for inference-only ops). The wrapper uses this - to pad the autograd return tuple with ``None`` for every input - not listed here, so torch sees one slot per forward input as - required by ``register_autograd``. - """ - return [] +# --------------------------------------------------------------------------- # +# Op registration helpers +# --------------------------------------------------------------------------- # +# +# The bottom half of the module turns one or more user-supplied eager +# kernels (forward / backward / their fake counterparts) plus the +# dataclass argument types into a fully registered ``torch.library`` +# custom op. :func:`_te_register_custom_op` is the orchestrator; the +# helpers below are the per-step building blocks (validation, kernel +# wrapping, dispatcher creation). -def _te_register_custom_op( - *, - linear_impl: Callable[[Any], Any], - linear_arg_type: type, - setup_context: Callable[..., None], - backward_impl: Callable[[Any], Any], - backward_obj: type, - backward_arg_type: type, - num_outputs: int, - linear_fake_impl: Optional[Callable[[Any], Any]] = None, - backward_fake_impl: Optional[Callable[[Any], Any]] = None, - op_namespace: str = "transformer_engine", - op_name: str = "linear", -) -> Callable[..., Any]: - """Register a TE module's forward + backward as a single torch custom op. - Parameters - ---------- - linear_impl - Eager forward implementation. Receives a single argument of type - ``linear_arg_type`` and must return a tuple of the form - ``(*output_tensors, tensors_to_save, tensor_objects, ctx_attrs)`` - where: +def _prepare_for_saving(tensors: Any) -> Tuple[List[Optional[torch.Tensor]], Any]: + """Lazy wrapper around :func:`quantized_tensor.prepare_for_saving`. - * ``output_tensors`` -- one or more :class:`torch.Tensor` outputs - returned to the caller. - * ``tensors_to_save`` -- flat list of :class:`torch.Tensor` to be - stashed via ``ctx.save_for_backward``. - * ``tensor_objects`` -- the metadata object produced by - :func:`prepare_for_saving`, paired with ``tensors_to_save`` to - let the backward reconstruct quantized / structured tensors. - * ``ctx_attrs`` -- non-tensor state to attach to the autograd - context, restricted to values that cannot be derived from the - forward args inside ``setup_context``. - linear_arg_type - Dataclass type aggregating all forward inputs (e.g. - :class:`LinearFwdArgs`). Used to (re)build the structured argument - from the flat tensor / non-tensor inputs accepted by the custom op. - setup_context - Eager autograd ``setup_context`` analogue. Receives a freshly - constructed ``backward_obj`` instance, the forward args, the - forward output, and ``ctx_attrs`` produced by ``linear_impl``; - is responsible for populating the backward-state object so that - ``backward_impl`` can later consume it. - backward_impl - Eager backward implementation. Receives a single argument of type - ``backward_arg_type`` and returns the gradient tuple. - backward_obj - Dataclass / class used to instantiate a fresh backward-state - container at the end of the forward pass (typically the same as - ``backward_arg_type``). - backward_arg_type - Type accepted by ``backward_impl``. May differ from ``backward_obj`` - if the backward op needs a wrapped / opaque view of the state. - num_outputs - Number of user-facing tensor outputs returned by ``linear_impl``. - The op concatenates ``[*output_tensors, *tensors_to_save]`` into - a single ``Tensor[]`` return; the wrapper uses ``num_outputs`` to - split the two halves on the way back out. + Lazy-imports to avoid the dynamo<->quantized_tensor circular import + that ``transformer_engine.pytorch`` would otherwise trigger at + module import time. + """ + from transformer_engine.pytorch.quantized_tensor import prepare_for_saving - The list of forward inputs that receive gradients is declared on - the forward arg type itself, via - :meth:`ArgObject.torch_compile_get_input_tensors_for_grad`. - ``backward_impl`` must return its gradients in that exact order. - linear_fake_impl - Optional fake (shape inference) counterpart of ``linear_impl``, - registered via ``torch.library.register_fake``. Returns the same - tuple shape as ``linear_impl`` -- ``(*output_tensors, - tensors_to_save, tensor_objects, ctx_attrs)`` -- but every - ``torch.Tensor`` is a fake tensor (allocated via - ``quantizer.make_empty`` or ``torch.empty``) carrying only the - correct shape / dtype / device, with no real storage or - computation. ``tensor_objects`` and ``ctx_attrs`` must be - structurally identical to those produced by ``linear_impl`` so - that ``setup_context`` and ``backward_impl`` see the same - non-tensor state in eager and traced modes. - backward_fake_impl - Optional fake counterpart of ``backward_impl``. Returns the same - gradient tuple as ``backward_impl``, with fake tensors in place - of the real gradients. - op_namespace, op_name - Library namespace / op name used when registering with - ``torch.library``. + return prepare_for_saving(*(tensors or ())) - Returns - ------- - Callable - A function ``forward_fn(linear_arg_type_instance)`` that dispatches - through the registered custom op, returning the user-facing - outputs (single tensor if ``num_outputs == 1``, otherwise a - tuple). Use under ``torch.compiler.is_compiling()`` as a drop-in - for ``Function.apply``. - """ - fwd_qualname = f"{op_namespace}::{op_name}" - bwd_op_name = f"{op_name}_backward" - bwd_qualname = f"{op_namespace}::{bwd_op_name}" +def _restore_from_saved(tensor_objects: Any, saved_tensors: List[Any]) -> Any: + """Lazy wrapper around :func:`quantized_tensor.restore_from_saved`.""" + from transformer_engine.pytorch.quantized_tensor import restore_from_saved - # Precompute the bucket list for both arg types and capture them in - # the closures below. Critical for the compiled path: re-deriving - # buckets at call time would force ``ArgObject._buckets`` to read - # ``cls.__dict__`` from inside a Dynamo-traced function, which - # triggers a "mappingproxy affected by dictionary mutation" graph - # break under ``fullgraph=True``. - fwd_buckets: List[_Bucket] = linear_arg_type._buckets() - bwd_buckets: List[_Bucket] = backward_arg_type._buckets() + return restore_from_saved(tensor_objects, saved_tensors) - def _build_schema(buckets: List[_Bucket]) -> Tuple[str, List[str]]: - spec = [slot for b in buckets for slot in b.schema_slots()] - names = [name for name, _ in spec] - schema_str = "(" + ", ".join(f"{type_str} {name}" for name, type_str in spec) + ")" - return schema_str, names - fwd_schema_args, fwd_arg_names = _build_schema(fwd_buckets) - bwd_schema_args, bwd_arg_names = _build_schema(bwd_buckets) +def _format_fwd_result(result: Any, num_outputs: int) -> List[torch.Tensor]: + """Pack a fwd-impl return tuple into the op's ``Tensor[]`` payload. - # ``torch.library.register_autograd`` requires the backward to return - # one grad slot per forward input, with the same Python tree - # structure as the input itself: a ``Tensor[]`` slot must get back a - # ``list``, not a bare ``None``. Precompute the per-slot "no-grad" - # value so the autograd return matches. + The op concatenates ``[*output_tensors, *tensors_to_save]`` into a + single non-nullable list; ``None`` entries are smuggled through the + :func:`_encode_none` sentinel so ``register_autograd`` still + attaches a ``grad_fn`` to the result. + """ + outputs = list(result[:num_outputs]) + tensors_to_save, _ = _prepare_for_saving(result[num_outputs]) + return [_encode_none(t) for t in outputs + tensors_to_save] + + +def _format_bwd_result( + grads: Any, num_grad_inputs: int, op_qualname: str +) -> List[torch.Tensor]: + """Pack a backward-impl return tuple into the op's ``Tensor[]`` payload. + + Validates that the user kernel returned exactly one grad per + ``input_tensors_for_grad`` entry; raises with the op's qualified + name on mismatch. + """ + grads = list(grads) + if len(grads) != num_grad_inputs: + raise RuntimeError( + f"{op_qualname} expected backward_impl to return " + f"{num_grad_inputs} grads (one per input_tensors_for_grad " + f"entry), got {len(grads)}" + ) + return [_encode_none(g) for g in grads] + + +def _resolve_grad_targets( + fwd_buckets: List[_Bucket], + fwd_arg_type: type, + input_tensors_for_grad: List[str], +) -> Tuple[List[Any], List[Tuple[int, bool]]]: + """Validate ``input_tensors_for_grad`` and resolve grad-output layout. + + Returns ``(fwd_slot_defaults, grad_targets)`` where: + + * ``fwd_slot_defaults`` is the per-slot "no-grad" template the + autograd return tuple starts from -- ``[]`` for ``Tensor[]`` + slots, ``None`` otherwise. ``register_autograd`` requires one + grad slot per forward input with matching tree structure (a + ``Tensor[]`` slot must get back a list, not bare ``None``). + * ``grad_targets`` is the ``[(slot_index, as_list), ...]`` mapping + for each name in ``input_tensors_for_grad``, in the same order; + ``as_list`` is ``True`` for ``Tensor[]``-shaped slots so the + caller wraps the single grad into a length-matched list. + """ fwd_slot_defaults: List[Any] = [] for bucket in fwd_buckets: for _, type_str in bucket.schema_slots(): fwd_slot_defaults.append([] if type_str.endswith("[]") else None) - # Validate ``input_tensors_for_grad`` references real forward inputs - # and precompute the positions where backward grads land in the - # autograd return tuple. Some logical fields (e.g. Tensor-or-storage - # fields) expand to a ``Tensor[]`` slot; their gradient must be returned - # as a list matching that input tree. - input_tensors_for_grad = linear_arg_type.torch_compile_get_input_tensors_for_grad() fwd_grad_targets: Dict[str, Tuple[int, bool]] = {} slot_offset = 0 for bucket in fwd_buckets: slots = bucket.schema_slots() if isinstance(bucket, _TensorBucket): fwd_grad_targets[bucket.name] = (slot_offset, False) - elif isinstance(bucket, _TensorListBucket): - fwd_grad_targets[bucket.name] = (slot_offset, True) - elif isinstance(bucket, _TensorOrStorageBucket): + elif isinstance(bucket, _UniversalTensorBucket): + # Grad routes to the ``Tensor?`` slot -- the wrapper / + # plain-tensor passthrough -- so the gradient flows back + # to the user-facing object (e.g. an ``nn.Parameter`` + # wrapped as ``Float8Tensor``). In the storage path the + # ``Tensor?`` slot is ``None`` and the kernel does not + # request a grad for it. for i, (slot_name, _) in enumerate(slots): - if slot_name == bucket._slot_tensors(): - fwd_grad_targets[bucket.name] = (slot_offset + i, True) + if slot_name == bucket.slot_name(): + fwd_grad_targets[bucket.name] = (slot_offset + i, False) break slot_offset += len(slots) - unknown_grad_names = [n for n in input_tensors_for_grad if n not in fwd_grad_targets] - if unknown_grad_names: + + unknown = [n for n in input_tensors_for_grad if n not in fwd_grad_targets] + if unknown: raise ValueError( - f"{linear_arg_type.__name__}.torch_compile_get_input_tensors_for_grad() " - f"contains names not present in " - f"{linear_arg_type.__name__}.torch_compile_get_schema(): " - f"{unknown_grad_names}" + f"input_tensors_for_grad contains names not present in " + f"{fwd_arg_type.__name__} schema: {unknown}" ) grad_targets = [fwd_grad_targets[n] for n in input_tensors_for_grad] - num_grad_inputs = len(input_tensors_for_grad) + return fwd_slot_defaults, grad_targets - lib = torch.library.Library(op_namespace, "FRAGMENT") - # Forward op concatenates user outputs and tensors_to_save into a - # single ``Tensor[]`` return so that autograd's ``setup_context`` can - # stash the saved-for-backward tensors without re-running the eager - # impl. The schema is non-nullable (``Tensor[]``, not ``Tensor?[]``) - # because ``torch.library.register_autograd`` does not propagate - # ``grad_fn`` to a nullable list output. ``None`` entries on either - # side are smuggled through via :func:`_encode_none` / - # :func:`_decode_none` sentinels (see below). - lib.define(f"{op_name}{fwd_schema_args} -> Tensor[]") - lib.define(f"{bwd_op_name}{bwd_schema_args} -> Tensor[]") - - def _outputs_for_setup(outputs: List[torch.Tensor]) -> Any: - return outputs[0] if num_outputs == 1 else tuple(outputs) - - def _prepare_for_saving(tensors: Any) -> Tuple[List[Optional[torch.Tensor]], Any]: - from transformer_engine.pytorch.quantized_tensor import prepare_for_saving - - return prepare_for_saving(*(tensors or ())) - - def _restore_from_saved(tensor_objects: Any, saved_tensors: List[Any]) -> Any: - from transformer_engine.pytorch.quantized_tensor import restore_from_saved - return restore_from_saved(tensor_objects, saved_tensors) - - def _fwd_impl(*flat: Any) -> List[torch.Tensor]: - kwargs = dict(zip(fwd_arg_names, flat)) - obj = linear_arg_type.torch_compile_unpack(kwargs, fwd_buckets) - result = linear_impl(obj) - outputs = list(result[:num_outputs]) - tensors_to_save, _ = _prepare_for_saving(result[num_outputs]) - return [_encode_none(t) for t in outputs + tensors_to_save] - - lib.impl(op_name, _fwd_impl, "CompositeExplicitAutograd") - - if linear_fake_impl is not None: +def _register_kernel( + *, + op_name: str, + op_qualname: str, + arg_type: type, + arg_names: List[str], + buckets: List[_Bucket], + impl: Callable[[Any], Any], + fake_impl: Optional[Callable[[Any], Any]], + format_result: Callable[[Any], List[torch.Tensor]], +) -> None: + """Wire ``impl`` (and optionally ``fake_impl``) into :data:`_TE_LIB` + under ``op_name``. + + The wrapper unpacks the flat positional args using + ``arg_names`` / ``buckets``, calls the user kernel with the rebuilt + dataclass instance, and packs the result through ``format_result`` + (which encodes ``None``s into the op's ``Tensor[]`` return slot). + """ - def _fwd_fake(*flat: Any) -> List[torch.Tensor]: - kwargs = dict(zip(fwd_arg_names, flat)) - obj = linear_arg_type.torch_compile_unpack(kwargs, fwd_buckets) - result = linear_fake_impl(obj) - outputs = list(result[:num_outputs]) - tensors_to_save, _ = _prepare_for_saving(result[num_outputs]) - return [_encode_none(t) for t in outputs + tensors_to_save] + def _eager(*flat: Any) -> List[torch.Tensor]: + kwargs = dict(zip(arg_names, flat)) + obj = _unpack(arg_type, kwargs, buckets) + return format_result(impl(obj)) - torch.library.register_fake(fwd_qualname, _fwd_fake, lib=lib) + _TE_LIB.impl(op_name, _eager, "CompositeExplicitAutograd") - def _check_bwd_len(grads): - if len(grads) != num_grad_inputs: - raise RuntimeError( - f"{op_namespace}::{bwd_op_name} expected backward_impl to " - f"return {num_grad_inputs} grads (one per " - f"input_tensors_for_grad entry), got {len(grads)}" - ) + if fake_impl is not None: - def _bwd_impl(*flat: Any) -> List[torch.Tensor]: - kwargs = dict(zip(bwd_arg_names, flat)) - obj = backward_arg_type.torch_compile_unpack(kwargs, bwd_buckets) - grads = list(backward_impl(obj)) - _check_bwd_len(grads) - return [_encode_none(g) for g in grads] + def _fake(*flat: Any) -> List[torch.Tensor]: + kwargs = dict(zip(arg_names, flat)) + obj = _unpack(arg_type, kwargs, buckets) + return format_result(fake_impl(obj)) - lib.impl(bwd_op_name, _bwd_impl, "CompositeExplicitAutograd") + torch.library.register_fake(op_qualname, _fake, lib=_TE_LIB) - if backward_fake_impl is not None: - def _bwd_fake(*flat: Any) -> List[torch.Tensor]: - kwargs = dict(zip(bwd_arg_names, flat)) - obj = backward_arg_type.torch_compile_unpack(kwargs, bwd_buckets) - grads = list(backward_fake_impl(obj)) - _check_bwd_len(grads) - return [_encode_none(g) for g in grads] +def _collect_universal_slot_offsets(buckets: List[_Bucket]) -> List[int]: + """Return the start index of each :class:`_UniversalTensorBucket` + group inside the flat positional arg list of a registered op. - torch.library.register_fake(bwd_qualname, _bwd_fake, lib=lib) + The four schema slots emitted by a universal bucket are always + contiguous (``name``, ``__tensors``, ``__pg``, ``__meta``); knowing + the offset of the first slot lets a subclass dispatch rule rewrite + all four slots in place at trace / eager time without re-deriving + the bucket list. + """ + offsets: List[int] = [] + pos = 0 + for bucket in buckets: + if isinstance(bucket, _UniversalTensorBucket): + offsets.append(pos) + pos += len(bucket.schema_slots()) + return offsets + + +def _flatten_subclass_into_slots( + new_args: List[Any], slot_offsets: List[int], subclass: type +) -> None: + """Rewrite each ``_UniversalTensorBucket`` group whose ``Tensor?`` + slot holds an instance of ``subclass`` into the storage layout. + + Used as the body of a ``register_torch_dispatch`` rule on the outer + fwd / bwd op: a subclass passed through the user-facing op is + flattened in place (via ``_torch_compile_flatten``) so that the + inner op only ever sees plain tensors plus the storage-flatten + metadata. The wrapper's autograd identity remains attached to the + inner tensors via the wrapper-subclass machinery, so gradients + still flow back to the user-facing tensor. + """ + for offset in slot_offsets: + val = new_args[offset] + if val is None or not isinstance(val, subclass): + continue + meta, pg, tensors = val._torch_compile_flatten() + meta._data[_UniversalTensorBucket.KIND_KEY] = _UniversalTensorBucket.KIND_STORAGE + new_args[offset] = None + new_args[offset + 1] = list(tensors) + new_args[offset + 2] = pg + new_args[offset + 3] = meta + + +def _register_autograd_for_op( + *, + fwd_op_name: str, + bwd_op_name: str, + fwd_arg_type: type, + fwd_arg_names: List[str], + fwd_buckets: List[_Bucket], + bwd_arg_names: List[str], + bwd_buckets: List[_Bucket], + num_outputs: int, + fwd_slot_defaults: List[Any], + grad_targets: List[Tuple[int, bool]], + fwd_fake_impl: Optional[Callable[[Any], Any]], + fwd_impl: Callable[[Any], Any], + setup_context_user: Callable[..., None], + backward_obj_type: type, +) -> None: + """Wire ``register_autograd`` on a forward op so its backward calls + ``bwd_op_name``. + + Both the inner and outer tiers of a two-tier op share an identical + autograd bridge (the wrapper logic only cares about op *names*), so + this helper is called once per tier; the actual kernel + registration is handled separately (by :func:`_register_kernel` + for the inner tier and :func:`_register_outer_forwarder` for the + outer tier). + """ + fwd_qualname = f"{_TE_OP_NAMESPACE}::{fwd_op_name}" - # Re-run fake (or real) impl in setup_context to recover - # tensor_objects / ctx_attrs, which are not part of the op's return. - fake_for_setup = linear_fake_impl if linear_fake_impl is not None else linear_impl + fake_for_setup = fwd_fake_impl if fwd_fake_impl is not None else fwd_impl def _setup_context(ctx, inputs, output): ctx._te_fwd_tensor_list_lengths = { i: len(value) for i, value in enumerate(inputs) if isinstance(value, list) } kwargs = dict(zip(fwd_arg_names, inputs)) - fwd_obj = linear_arg_type.torch_compile_unpack(kwargs, fwd_buckets) + fwd_obj = _unpack(fwd_arg_type, kwargs, fwd_buckets) fake_result = fake_for_setup(fwd_obj) _, tensor_objects = _prepare_for_saving(fake_result[num_outputs]) ctx_attrs = fake_result[num_outputs + 2] - # Split op output: first num_outputs are user-facing tensors, - # the rest are tensors_to_save. ``output`` is a flat ``Tensor[]`` - # with our None-sentinels in place; decode here so downstream - # eager code sees the original ``None``\ s. user_outputs = [_decode_none(t) for t in output[:num_outputs]] op_saved_tensors = [_decode_none(t) for t in output[num_outputs:]] tensors_to_save_from_forward = _restore_from_saved( @@ -1258,11 +1267,11 @@ def _setup_context(ctx, inputs, output): op_saved_tensors, ) - bwd_obj = backward_obj() - tensors_to_save_from_setup = setup_context( + bwd_obj = backward_obj_type() + tensors_to_save_from_setup = setup_context_user( bwd_obj, fwd_obj, - _outputs_for_setup(user_outputs), + user_outputs[0] if num_outputs == 1 else tuple(user_outputs), ctx_attrs, tensors_to_save_from_forward, ) @@ -1274,25 +1283,14 @@ def _setup_context(ctx, inputs, output): def _autograd_backward(ctx, *grad_outputs): bwd_obj = ctx.bwd_obj if hasattr(bwd_obj, "setup_saved_tensors"): - bwd_obj.setup_saved_tensors(ctx.saved_tensors, ctx.tensor_objects) + bwd_obj.setup_saved_tensors(ctx) ctx.tensor_objects = None - # The forward op returns a single ``Tensor[]`` (concatenation of - # user outputs and saved tensors), so ``grad_outputs`` is a - # 1-tuple containing the per-element grad list. Only the first - # ``num_outputs`` of those correspond to user-facing outputs; - # ``grad_output`` for the backward is the grad of the primary - # output. per_output_grads = grad_outputs[0] bwd_obj.grad_output = _decode_none(per_output_grads[0]) - kwargs = backward_arg_type.torch_compile_pack(bwd_obj, bwd_buckets) + kwargs = _pack(bwd_obj, bwd_buckets) bwd_args_flat = [kwargs[name] for name in bwd_arg_names] - bwd_op = getattr(getattr(torch.ops, op_namespace), bwd_op_name) + bwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), bwd_op_name) grads = [_decode_none(g) for g in bwd_op(*bwd_args_flat)] - # ``register_autograd`` requires one grad slot per forward input - # with the same tree structure as the input (a ``Tensor[]`` slot - # must get back a list, never a bare ``None``). Start from the - # precomputed per-slot defaults and overlay the produced grads - # at the positions declared by ``input_tensors_for_grad``. out: List[Any] = list(fwd_slot_defaults) tensor_list_lengths = getattr(ctx, "_te_fwd_tensor_list_lengths", {}) for (pos, as_list), g in zip(grad_targets, grads): @@ -1307,18 +1305,335 @@ def _autograd_backward(ctx, *grad_outputs): fwd_qualname, _autograd_backward, setup_context=_setup_context, - lib=lib, + lib=_TE_LIB, + ) + + +def _register_outer_forwarder( + *, + outer_op_name: str, + inner_op_name: str, + arg_names: List[str], +) -> None: + """Register the outer op's default kernel + fake as a thin + forwarder into the inner op. + + The outer op must remain opaque to compilation (so + ``register_torch_dispatch`` rules installed on it actually fire); + we register the kernel against ``CompositeExplicitAutograd`` and + additionally register a fake impl that simply re-invokes the + inner op. For the subclass path the dispatch rule rewrites the + call into an inner call *before* this kernel/fake ever runs; the + forwarder is only consulted when no rule matches (i.e. the inputs + are plain tensors and / or plain ``QuantizedTensorStorage`` flat + slots that already match the inner schema directly). + """ + inner_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), inner_op_name) + + def _outer_kernel(*flat: Any) -> List[torch.Tensor]: + return inner_op(*flat) + + _TE_LIB.impl(outer_op_name, _outer_kernel, "CompositeExplicitAutograd") + + def _outer_fake(*flat: Any) -> List[torch.Tensor]: + return inner_op(*flat) + + torch.library.register_fake( + f"{_TE_OP_NAMESPACE}::{outer_op_name}", _outer_fake, lib=_TE_LIB + ) + + +def _te_register_custom_op( + *, + op_name: str, + num_outputs: int, + input_tensors_for_grad: List[str], + fwd_arg_type: type, + fwd_impl: Callable[[Any], Any], + fwd_fake_impl: Optional[Callable[[Any], Any]] = None, + setup_context: Callable[..., None], + backward_arg_type: type, + backward_obj: type, + backward_impl: Callable[[Any], Any], + backward_fake_impl: Optional[Callable[[Any], Any]] = None, + subclasses: Optional[Sequence[type]] = None, +) -> Callable[..., Any]: + """Register a TE module's forward + backward as a single torch custom op. + + Parameters + ---------- + op_name + Op name used when registering with ``torch.library``. The + namespace is fixed at module level (:data:`_TE_OP_NAMESPACE`). + num_outputs + Number of user-facing tensor outputs returned by ``fwd_impl``. + The op concatenates ``[*output_tensors, *tensors_to_save]`` into + a single ``Tensor[]`` return; the wrapper uses ``num_outputs`` to + split the two halves on the way back out. + input_tensors_for_grad + Names of forward-arg-type fields for which ``backward_impl`` + returns gradients, in the same order. The wrapper uses this to + pad the autograd return tuple with ``None`` for every input not + listed here, so torch sees one grad slot per forward input as + required by ``register_autograd``. + fwd_arg_type + Dataclass type aggregating all forward inputs (e.g. + ``LinearFwdArgs``). Used to (re)build the structured argument + from the flat tensor / non-tensor inputs accepted by the custom op. + fwd_impl + Eager forward implementation. Receives a single argument of type + ``fwd_arg_type`` and must return a tuple of the form + ``(*output_tensors, tensors_to_save, tensor_objects, ctx_attrs)`` + where: + + * ``output_tensors`` -- one or more :class:`torch.Tensor` outputs + returned to the caller. + * ``tensors_to_save`` -- flat list of :class:`torch.Tensor` to be + stashed via ``ctx.save_for_backward``. + * ``tensor_objects`` -- the metadata object produced by + :func:`prepare_for_saving`, paired with ``tensors_to_save`` to + let the backward reconstruct quantized / structured tensors. + * ``ctx_attrs`` -- non-tensor state to attach to the autograd + context, restricted to values that cannot be derived from the + forward args inside ``setup_context``. + fwd_fake_impl + Optional fake (shape inference) counterpart of ``fwd_impl``, + registered via ``torch.library.register_fake``. Returns the same + tuple shape as ``fwd_impl`` -- ``(*output_tensors, + tensors_to_save, tensor_objects, ctx_attrs)`` -- but every + ``torch.Tensor`` is a fake tensor (allocated via + ``quantizer.make_empty`` or ``torch.empty``) carrying only the + correct shape / dtype / device, with no real storage or + computation. ``tensor_objects`` and ``ctx_attrs`` must be + structurally identical to those produced by ``fwd_impl`` so + that ``setup_context`` and ``backward_impl`` see the same + non-tensor state in eager and traced modes. + setup_context + Eager autograd ``setup_context`` analogue. Receives a freshly + constructed ``backward_obj`` instance, the forward args, the + forward output, and ``ctx_attrs`` produced by ``fwd_impl``; + is responsible for populating the backward-state object so that + ``backward_impl`` can later consume it. + backward_arg_type + Type accepted by ``backward_impl``. May differ from ``backward_obj`` + if the backward op needs a wrapped / opaque view of the state. + backward_obj + Dataclass / class used to instantiate a fresh backward-state + container at the end of the forward pass (typically the same as + ``backward_arg_type``). + backward_impl + Eager backward implementation. Receives a single argument of type + ``backward_arg_type`` and returns the gradient tuple. + backward_fake_impl + Optional fake counterpart of ``backward_impl``. Returns the same + gradient tuple as ``backward_impl``, with fake tensors in place + of the real gradients. + + Returns + ------- + Callable + A function ``forward_fn(fwd_arg_type_instance)`` that dispatches + through the registered custom op, returning the user-facing + outputs (single tensor if ``num_outputs == 1``, otherwise a + tuple). Use under ``torch.compiler.is_compiling()`` as a drop-in + for ``Function.apply``. + """ + + outer_fwd_name = op_name + outer_bwd_name = f"{op_name}_backward" + subclass_list = list(subclasses or ()) + + # Precompute the bucket list once per arg type and capture it in + # the registered closures. Re-deriving the bucket list inside a + # compiled call would force :func:`_get_buckets` to read + # ``cls.__dict__`` from inside a Dynamo-traced function, which + # triggers a "mappingproxy affected by dictionary mutation" graph + # break under ``fullgraph=True``. + fwd_buckets: List[_Bucket] = _get_buckets(fwd_arg_type) + bwd_buckets: List[_Bucket] = _get_buckets(backward_arg_type) + + fwd_schema_args, fwd_arg_names = _build_schema(fwd_buckets) + bwd_schema_args, bwd_arg_names = _build_schema(bwd_buckets) + + num_grad_inputs = len(input_tensors_for_grad) + fwd_slot_defaults, grad_targets = _resolve_grad_targets( + fwd_buckets, fwd_arg_type, input_tensors_for_grad + ) + + # Two-tier layout when subclass dispatch rules are requested: + # inner = ``{op_name}_base`` -- real impl, sees only plain tensors + # and the storage-flatten metadata. + # outer = ``{op_name}`` -- user-facing op that either falls through + # to the inner op (plain-tensor path) or is rewritten by a + # ``register_torch_dispatch`` rule (subclass path) into a + # call to the inner op with subclass tensors flattened in + # place. Both tiers carry their own ``register_autograd`` + # bridge. + # Single-tier when no subclasses are given: only the outer pair is + # defined and it owns the real impl (today's behaviour). + inner_fwd_name = f"{op_name}_base" if subclass_list else outer_fwd_name + inner_bwd_name = f"{outer_bwd_name}_base" if subclass_list else outer_bwd_name + + # Forward op concatenates user outputs and tensors_to_save into a + # single ``Tensor[]`` return so that autograd's ``setup_context`` can + # stash the saved-for-backward tensors without re-running the eager + # impl. The schema is non-nullable (``Tensor[]``, not ``Tensor?[]``) + # because ``torch.library.register_autograd`` does not propagate + # ``grad_fn`` to a nullable list output. ``None`` entries on either + # side are smuggled through via :func:`_encode_none` / + # :func:`_decode_none` sentinels. + _TE_LIB.define(f"{inner_fwd_name}{fwd_schema_args} -> Tensor[]") + _TE_LIB.define(f"{inner_bwd_name}{bwd_schema_args} -> Tensor[]") + if subclass_list: + # Outer fwd / outer bwd are user-facing entry points. The + # outer fwd is the target of ``register_torch_dispatch`` for + # the forward subclass path; outer bwd is the target for the + # backward subclass path. Both forward to the corresponding + # inner op when no rule matches (plain-tensor / pure-storage + # path). + _TE_LIB.define(f"{outer_fwd_name}{fwd_schema_args} -> Tensor[]") + _TE_LIB.define(f"{outer_bwd_name}{bwd_schema_args} -> Tensor[]") + + # Inner pair owns the real implementation. The fwd & bwd kernels + # are registered directly against the user-supplied impls; the + # autograd bridge below wires the inner fwd op's backward to call + # the inner bwd op. + inner_fwd_qualname = f"{_TE_OP_NAMESPACE}::{inner_fwd_name}" + inner_bwd_qualname = f"{_TE_OP_NAMESPACE}::{inner_bwd_name}" + _register_kernel( + op_name=inner_fwd_name, + op_qualname=inner_fwd_qualname, + arg_type=fwd_arg_type, + arg_names=fwd_arg_names, + buckets=fwd_buckets, + impl=fwd_impl, + fake_impl=fwd_fake_impl, + format_result=lambda r: _format_fwd_result(r, num_outputs), + ) + _register_kernel( + op_name=inner_bwd_name, + op_qualname=inner_bwd_qualname, + arg_type=backward_arg_type, + arg_names=bwd_arg_names, + buckets=bwd_buckets, + impl=backward_impl, + fake_impl=backward_fake_impl, + format_result=lambda g: _format_bwd_result(g, num_grad_inputs, inner_bwd_qualname), + ) + _register_autograd_for_op( + fwd_op_name=inner_fwd_name, + bwd_op_name=inner_bwd_name, + fwd_arg_type=fwd_arg_type, + fwd_arg_names=fwd_arg_names, + fwd_buckets=fwd_buckets, + bwd_arg_names=bwd_arg_names, + bwd_buckets=bwd_buckets, + num_outputs=num_outputs, + fwd_slot_defaults=fwd_slot_defaults, + grad_targets=grad_targets, + fwd_fake_impl=fwd_fake_impl, + fwd_impl=fwd_impl, + setup_context_user=setup_context, + backward_obj_type=backward_obj, ) - fwd_op = getattr(getattr(torch.ops, op_namespace), op_name) + if subclass_list: + # Two-tier setup, mirroring the ex.py pattern: + # + # * Inner pair (already registered above) carries the real + # kernels + fakes and a full ``register_autograd`` bridge. + # It only ever sees plain tensors / plain + # ``QuantizedTensorStorage`` flat slots; the subclass + # wrapper never reaches it. + # * Outer pair is a thin opaque shell. Its kernels forward + # to the inner op and its ``register_torch_dispatch`` rules + # flatten registered subclasses inline before forwarding. + # It carries its own autograd bridge so that the user-facing + # tensor (e.g. a ``Float8Tensor`` weight parameter) ends + # up on the autograd graph and receives a ``.grad``. With + # ``__tensor_unflatten__`` rebuilding a real quantizer from + # the subclass meta snapshot, outer's setup_context can run + # the user fake impl on the raw forward inputs even when + # they include reconstructed subclass instances. + _register_outer_forwarder( + outer_op_name=outer_fwd_name, + inner_op_name=inner_fwd_name, + arg_names=fwd_arg_names, + ) + _register_outer_forwarder( + outer_op_name=outer_bwd_name, + inner_op_name=inner_bwd_name, + arg_names=bwd_arg_names, + ) + _register_autograd_for_op( + fwd_op_name=outer_fwd_name, + bwd_op_name=outer_bwd_name, + fwd_arg_type=fwd_arg_type, + fwd_arg_names=fwd_arg_names, + fwd_buckets=fwd_buckets, + bwd_arg_names=bwd_arg_names, + bwd_buckets=bwd_buckets, + num_outputs=num_outputs, + fwd_slot_defaults=fwd_slot_defaults, + grad_targets=grad_targets, + fwd_fake_impl=fwd_fake_impl, + fwd_impl=fwd_impl, + setup_context_user=setup_context, + backward_obj_type=backward_obj, + ) + + # Register ``torch_dispatch`` rules per subclass on both the + # outer fwd and the outer bwd op. The rule replaces the outer + # call entirely: it flattens every ``_UniversalTensorBucket`` + # slot whose ``name`` value is an instance of the registered + # subclass into ``(None, [inner tensors], process_group, + # opaque_meta)`` and invokes the inner op on the rewritten + # args. After the rewrite no subclass tensor remains in the + # call's arg list, and the autograd entry that ends up on the + # output graph is the inner op's (not the outer's), so the + # backward path goes through the inner pair only. + fwd_slot_offsets = _collect_universal_slot_offsets(fwd_buckets) + bwd_slot_offsets = _collect_universal_slot_offsets(bwd_buckets) + inner_fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), inner_fwd_name) + inner_bwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), inner_bwd_name) + outer_fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_fwd_name) + outer_bwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_bwd_name) + outer_fwd_qualname = f"{_TE_OP_NAMESPACE}::{outer_fwd_name}" + outer_bwd_qualname = f"{_TE_OP_NAMESPACE}::{outer_bwd_name}" + for subclass in subclass_list: + def _fwd_rule(mode, func, types, args, kwargs, subclass=subclass): + new_args = list(args) + _flatten_subclass_into_slots(new_args, fwd_slot_offsets, subclass) + return inner_fwd_op(*new_args) + + def _bwd_rule(mode, func, types, args, kwargs, subclass=subclass): + new_args = list(args) + _flatten_subclass_into_slots(new_args, bwd_slot_offsets, subclass) + return inner_bwd_op(*new_args) + + torch.library.register_torch_dispatch( + outer_fwd_qualname, subclass, _fwd_rule, lib=_TE_LIB + ) + torch.library.register_torch_dispatch( + outer_bwd_qualname, subclass, _bwd_rule, lib=_TE_LIB + ) + + # ``QuantizedTensor.__torch_dispatch__`` falls back to + # dequantizing all subclass args for any op it does not + # recognise, which would defeat our + # ``register_torch_dispatch`` rules. Marking both outer ops + # as passthroughs makes QuantizedTensor delegate straight to + # ``super().__torch_dispatch__`` for them, where the + # registered dispatch rules are honoured. + from transformer_engine.pytorch.quantized_tensor import ( + _quantized_tensor_passthrough_ops, + ) + _quantized_tensor_passthrough_ops.add(outer_fwd_op.default) + + fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_fwd_name) def forward_fn(fwd_args): - # Bind ``lib`` here so its registrations (impl / register_fake / - # register_autograd) outlive ``_te_register_custom_op`` even if - # all other references to it are dropped: ``torch.library`` uses - # the ``Library`` instance lifetime for all attached registrations. - _ = lib # noqa: F841 -- closure-captured for lifetime only - kwargs = linear_arg_type.torch_compile_pack(fwd_args, fwd_buckets) + kwargs = _pack(fwd_args, fwd_buckets) flat = [kwargs[name] for name in fwd_arg_names] result = fwd_op(*flat) outputs = [_decode_none(t) for t in result[:num_outputs]] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6031c809ca..58f42781e0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -779,14 +779,18 @@ def fake_quantize_weight( ``quantizer.make_empty``. Used by torch custom-op fake registrations. """ - # Already-quantized weight (primary FP8 parameters) - if isinstance(tensor, QuantizedTensor): - update_rowwise = True if quantizer.rowwise_usage else None - update_columnwise = True if quantizer.columnwise_usage else None - tensor.update_usage( - rowwise_usage=update_rowwise, - columnwise_usage=update_columnwise, - ) + # Already-quantized weight (primary FP8 parameters, both the + # ``Float8Tensor``-style subclass wrappers and the bare + # ``Float8TensorStorage``-style flat carriers produced by the + # outer-op torch_dispatch rule on the way into the inner op). + if isinstance(tensor, QuantizedTensorStorage): + if quantizer is not None: + update_rowwise = True if quantizer.rowwise_usage else None + update_columnwise = True if quantizer.columnwise_usage else None + tensor.update_usage( + rowwise_usage=update_rowwise, + columnwise_usage=update_columnwise, + ) return tensor, None # Validate workspace @@ -1027,7 +1031,11 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: return if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): return - if recipe.custom() and isinstance(recipe_state, CustomRecipeState): + if ( + recipe.custom() + and isinstance(recipe_state, CustomRecipeState) + and recipe_state.recipe is recipe + ): return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and @@ -1914,6 +1922,12 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: return if not hasattr(self, "weight_names") or not self.weight_names: return + # Skip under ``torch.compile`` -- the check is a one-off + # runtime guard that calls ``tensor._get_quantizer()`` (returns + # a ``Quantizer``, not a Tensor) and Dynamo cannot trace + # quantizer objects flowing through ``call_method``. + if torch.compiler.is_compiling(): + return recipe = self.fp8_meta["recipe"] weight_tensors = [getattr(self, name) for name in self.weight_names] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ff1a55e1f1..5fe7602899 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -58,16 +58,21 @@ general_gemm, ) from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type -from ..jit import no_torch_dynamo +from ..dynamo import _te_register_custom_op from ..graph import is_graph_capturing from ..quantized_tensor import ( QuantizedTensor, QuantizedTensorStorage, Quantizer, + TensorOrQuantized, prepare_for_saving, - restore_from_func_ctx, + restore_from_saved, +) +from ..tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + Float8Quantizer, + Float8Tensor, ) -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.utils import clear_columnwise_cache, is_custom from ..export import is_in_onnx_export_mode, assert_warmed_up @@ -82,20 +87,17 @@ __all__ = ["Linear"] -TensorOrQuantized = Union[torch.Tensor, QuantizedTensorStorage] - - @dataclass(slots=True) class LinearFwdArgs: """Single-argument bag for the forward path of :class:`_Linear`.""" # --- Differentiable tensors (also passed positionally to autograd) --- weight: TensorOrQuantized - inp: torch.Tensor + inp: TensorOrQuantized bias: Optional[torch.Tensor] # --- Non-differentiable cached tensors --- - weight_workspace: Optional[torch.Tensor] + weight_workspace: Optional[QuantizedTensorStorage] # --- requires_grad flags (cached so backward does not re-query) --- input_requires_grad: bool @@ -227,16 +229,23 @@ class LinearBwdArgs: # --- Per-backward scratch state (populated inside _linear_backward) --- ub_obj_gradout: Optional[Any] = None - def setup_saved_tensors(self, ctx: torch.autograd.function.FunctionCtx) -> None: - """Pull saved tensors from ``ctx`` into the fields backward consumes.""" + def setup_saved_tensors(self, ctx) -> None: + """Restore saved tensors into the fields consumed by backward. + + Accepts both a ``torch.autograd.Function`` ctx (eager path) and a + ``torch.library.register_autograd`` ctx (compile path); both expose + ``saved_tensors`` and the ``tensor_objects`` attribute we attach + during forward. + """ ( self.inputmat, self.weight_fp8, self.saved_weight, self.bias, - ) = restore_from_func_ctx( - ctx - ) # pylint: disable=unbalanced-tuple-unpacking + ) = restore_from_saved( # pylint: disable=unbalanced-tuple-unpacking + ctx.tensor_objects, + list(ctx.saved_tensors), + ) def _check_fp8_reduce_and_update(): @@ -297,7 +306,19 @@ def _linear_forward_impl( # Configure tensor-parallel communication tp_world_size = get_distributed_world_size(tp_group) - backward_needs_input = is_grad_enabled and weight.requires_grad + # NOTE: prefer the explicit ``args.weight_requires_grad`` flag over + # ``weight.requires_grad`` so we stay consistent with the fake impl + # under ``torch.compile``: when the outer op flattens a + # ``Float8Tensor`` wrapper into a ``Float8TensorStorage`` for the + # inner op, the wrapper's ``requires_grad`` is observed inside the + # autograd Function and reads as ``False`` (autograd detaches its + # forward inputs), so the requires-grad bit baked into the + # storage metadata snapshot ends up ``False`` too. The fake impl + # uses ``args.weight_requires_grad`` (populated at outer-call site + # from the live ``nn.Parameter``) so the real impl must too, + # otherwise their ``backward_needs_input`` flags diverge and + # ``tensors_to_save_from_forward`` ends up with different lengths. + backward_needs_input = is_grad_enabled and args.weight_requires_grad with_input_all_gather_nccl = ( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) @@ -1638,6 +1659,7 @@ def backward( bwd_args: LinearBwdArgs = ctx.backward_objects bwd_args.grad_output = grad_output bwd_args.setup_saved_tensors(ctx) + ctx.tensor_objects = None nvtx_label = "transformer_engine._Linear.backward" if bwd_args.ub_name is not None: nvtx_label = f"{nvtx_label}.{bwd_args.ub_name}" @@ -1654,6 +1676,33 @@ def backward( return result +# Register the linear forward + backward as a single torch custom op so that +# ``torch.compile`` can trace through it without entering the eager +# ``torch.autograd.Function`` machinery. Used by :meth:`Linear.forward` +# under ``torch.compiler.is_compiling()``. +_linear_compiled_op = _te_register_custom_op( + op_name="linear", + num_outputs=2, + input_tensors_for_grad=["weight", "inp", "bias"], + fwd_arg_type=LinearFwdArgs, + fwd_impl=_linear_forward_impl, + fwd_fake_impl=_linear_forward_fake_impl, + setup_context=_linear_setup_ctx, + backward_arg_type=LinearBwdArgs, + backward_obj=LinearBwdArgs, + backward_impl=_linear_backward, + backward_fake_impl=_linear_backward_fake_impl, + # Two-tier custom op: the outer ``linear`` op accepts tensor + # subclasses (e.g. ``Float8Tensor`` as a weight), and an + # ``register_torch_dispatch`` rule flattens each subclass into + # plain tensors plus storage metadata before calling the inner + # ``linear_base`` op. The wrapper's autograd identity stays + # attached to the inner tensors so gradients flow back to the + # user-facing tensor (``Linear.weight.grad`` is populated). + subclasses=[Float8Tensor], +) + + class Linear(TransformerEngineBaseModule): """Applies a linear transformation to the incoming data :math:`y = xA^T + b` @@ -2037,7 +2086,6 @@ def reset_parameters(self, defer_init=False): elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) - @no_torch_dynamo() def forward( self, inp: torch.Tensor, @@ -2116,12 +2164,18 @@ def forward( grad_output_quantizer, ) = quantizers - if is_grad_enabled: - linear_fn = _Linear.apply - autograd_ctx = [] - else: - linear_fn = _Linear.forward - autograd_ctx = [None] + # Under torch.compile we always dispatch through the registered + # custom op (it only takes ``fwd_args``); torch.library handles the + # no-grad case automatically. Otherwise fall back to the eager + # torch.autograd.Function (or its bare forward when grad is off). + use_compiled_op = torch.compiler.is_compiling() + if not use_compiled_op: + if is_grad_enabled: + linear_fn = _Linear.apply + autograd_ctx = [] + else: + linear_fn = _Linear.forward + autograd_ctx = [None] cache_name = None if (is_first_microbatch is None or self.is_fsdp2) else "weight" weight_workspace = ( @@ -2216,13 +2270,16 @@ def forward( cpu_offloading=is_cpu_offload_enabled(), is_grad_enabled=is_grad_enabled, ) - out, new_weight_workspace = linear_fn( - *autograd_ctx, - weight_tensor, - inp, - linear_bias_tensor, - fwd_args, - ) + if use_compiled_op: + out, new_weight_workspace = _linear_compiled_op(fwd_args) + else: + out, new_weight_workspace = linear_fn( + *autograd_ctx, + weight_tensor, + inp, + linear_bias_tensor, + fwd_args, + ) if new_weight_workspace is not None and cache_name is not None: if isinstance(new_weight_workspace, torch.Tensor): diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0c40723517..da999c3a5a 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -32,6 +32,14 @@ from .jit import jit_fuser +# Trace-friendly TE DType ids (Python ints). Materialized once at +# import time so that hot paths (RecipeState init, get_fp8_te_dtype_id) +# never touch the pybind11 enum, which Dynamo cannot trace. +_TE_DTYPE_ID_FLOAT8_E4M3 = int(tex.DType.kFloat8E4M3) +_TE_DTYPE_ID_FLOAT8_E5M2 = int(tex.DType.kFloat8E5M2) +_TE_DTYPE_ID_FLOAT4_E2M1 = int(tex.DType.kFloat4E2M1) + + __all__ = [ "autocast", "quantized_model_init", @@ -286,6 +294,17 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType return tex.DType.kFloat8E5M2 +def get_fp8_te_dtype_id(fp8_recipe: Recipe, fprop_tensor: bool = True) -> int: + """Trace-friendly variant of :func:`get_fp8_te_dtype` returning the + integer id of the TE ``DType`` enum. Use this on any code path that + may be traced by ``torch.compile``.""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return _TE_DTYPE_ID_FLOAT8_E4M3 + return _TE_DTYPE_ID_FLOAT8_E5M2 + + def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: """Get fp4 data type according to recipe and tensor""" if fp4_recipe.fp4_format == Format.E2M1: @@ -293,6 +312,14 @@ def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") +def get_fp4_te_dtype_id(fp4_recipe: Recipe) -> int: + """Trace-friendly variant of :func:`get_fp4_te_dtype` returning the + integer id of the TE ``DType`` enum.""" + if fp4_recipe.fp4_format == Format.E2M1: + return _TE_DTYPE_ID_FLOAT4_E2M1 + raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") + + def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( @@ -1404,7 +1431,7 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + self.dtype = get_fp8_te_dtype_id(recipe, mode == "forward") # Allocate buffers if device is None: @@ -1453,7 +1480,7 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + self.dtype = get_fp8_te_dtype_id(recipe, mode == "forward") # Allocate buffers if device is None: @@ -1496,7 +1523,7 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + self.dtype = get_fp8_te_dtype_id(recipe, mode == "forward") # Allocate buffers if device is None: @@ -1536,9 +1563,9 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.qx_dtype = get_fp8_te_dtype(recipe, True) - self.qw_dtype = get_fp8_te_dtype(recipe, True) - self.qgrad_dtype = get_fp8_te_dtype(recipe, False) + self.qx_dtype = get_fp8_te_dtype_id(recipe, True) + self.qw_dtype = get_fp8_te_dtype_id(recipe, True) + self.qgrad_dtype = get_fp8_te_dtype_id(recipe, False) # Allocate buffers if device is None: @@ -1621,7 +1648,7 @@ def __init__( self.mode = mode self.num_quantizers = num_quantizers self.roles = roles - self.dtype = get_fp4_te_dtype(recipe) + self.dtype = get_fp4_te_dtype_id(recipe) # Allocate buffers if device is None: @@ -1837,12 +1864,17 @@ def make_quantizers(self) -> list: roles = self.roles if roles is None: - warnings.warn( - "CustomRecipeState: no QuantizerRole list provided by the module/op. " - "Falling back to bare QuantizerRole() defaults. " - "Override get_quantizer_roles() to provide meaningful roles.", - stacklevel=2, - ) + # Dynamo cannot trace the Python builtin ``_warnings.warn``, + # which graph-breaks any ``fullgraph=True`` compile that + # eventually calls ``make_quantizers``. The warning is + # informational only and is safe to skip under compile. + if not torch.compiler.is_compiling(): + warnings.warn( + "CustomRecipeState: no QuantizerRole list provided by the module/op. " + "Falling back to bare QuantizerRole() defaults. " + "Override get_quantizer_roles() to provide meaningful roles.", + stacklevel=2, + ) roles = [QuantizerRole() for _ in range(self.num_quantizers)] # qfactory must return a Quantizer or QuantizerRequest for every slot. diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a7722f777e..21e5aca58e 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -5,7 +5,7 @@ """Pure Python base classes for quantization.""" from __future__ import annotations -from typing import Optional, Tuple, Iterable, Any, Dict, Union +from typing import Optional, Tuple, Iterable, Any, Dict, List, Union import abc import warnings import math @@ -21,6 +21,80 @@ ) +# Maps a Quantizer subclass's ``__qualname__`` to the class object. Populated +# lazily via :meth:`Quantizer.__init_subclass__` and consumed by +# :meth:`Quantizer._unflatten` to dispatch reconstruction to the right +# subclass when a TE custom op is unpacked under ``torch.compile``. +_QUANTIZER_REGISTRY: Dict[str, type] = {} + + +def _quantizer_subclass_snapshot( + quantizer: Optional["Quantizer"], +) -> Optional[Tuple[Tuple[str, Any], ...]]: + """Return a Dynamo-guard-stable snapshot of a quantizer, or ``None``. + + Used by tensor subclasses (e.g. :class:`Float8Tensor`) to embed a + tensor-free, comparable representation of their live + :class:`Quantizer` in the ``meta`` dict returned from + ``__tensor_flatten__``. PyTorch's tensor-subclass metadata guard + diff-checks that dict via ``dict.__eq__`` on every entry into the + compiled region, so values that resolve to elementwise tensor + comparison or identity-only equality (live ``torch.Tensor`` + objects, ``ProcessGroup``, the live quantizer instance itself) + cannot appear there. + + The snapshot is a sorted tuple of ``(key, value)`` pairs derived + from ``quantizer._flatten()`` whenever the quantizer's state is + fully expressible without tensors (an empty trailing tensor list + in the ``_flatten`` triplet). Quantizers carrying tensors in their + state (e.g. :class:`Float8Quantizer`'s ``scale`` / ``amax``) and + quantizers that don't implement ``_flatten`` produce ``None``; + in that case the subclass's ``__tensor_unflatten__`` will + rebuild the wrapper with ``quantizer=None`` and any code that + needs the live quantizer must source it from the bucket-level + opaque metadata flowing through the inner custom op. + """ + if quantizer is None: + return None + try: + meta, _pg, tensors = quantizer._flatten() + except NotImplementedError: + return None + if tensors: + return None + if hasattr(meta, "_data"): + meta_dict = meta._data + elif isinstance(meta, dict): + meta_dict = meta + else: + return None + return tuple(sorted(meta_dict.items(), key=lambda kv: kv[0])) + + +def _quantizer_from_subclass_snapshot( + snapshot: Optional[Tuple[Tuple[str, Any], ...]], +) -> Optional["Quantizer"]: + """Inverse of :func:`_quantizer_subclass_snapshot`. + + Rebuilds the quantizer from the qualname stored in the snapshot's + ``"_qcls"`` entry, dispatching via :func:`Quantizer._unflatten` + (and so via the right subclass's ``_do_unflatten``). The + reconstructed quantizer's process-group reference is always + ``None`` -- live ``ProcessGroup`` objects cannot survive the + snapshot round trip; callers that need a real process group + obtain it via the bucket-level opaque metadata instead. + """ + if snapshot is None: + return None + meta_dict = dict(snapshot) + return Quantizer._unflatten(meta_dict, None, []) + +# Same idea for lightweight QuantizedTensorStorage shells. Populated via +# :meth:`QuantizedTensorStorage.__init_subclass__` and consumed by +# :meth:`QuantizedTensorStorage._torch_compile_unflatten`. +_STORAGE_REGISTRY: Dict[str, type] = {} + + # Custom ops that should pass through __torch_dispatch__ without unwrapping # QuantizedTensor subclasses (e.g. Float8Tensor). Register ops here that # handle quantized tensors internally. @@ -130,6 +204,65 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: f"{self.__class__.__name__} class does not implement copy_from_storage function" ) + # ------------------------------------------------------------------ # + # torch.compile flatten / unflatten protocol + # ------------------------------------------------------------------ # + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + _STORAGE_REGISTRY[cls.__qualname__] = cls + + def __eq__(self, other: object) -> bool: + return self is other + + def __hash__(self) -> int: + return id(self) + + def _torch_compile_flatten( + self, + ) -> Tuple[Any, Optional["torch.distributed.ProcessGroup"], List[torch.Tensor]]: + """Pack this storage's metadata and live tensor state for torch.compile.""" + raise NotImplementedError( + f"{type(self).__name__} class does not implement " + "_torch_compile_flatten; required for torch.compile support " + "of QuantizedTensorStorage objects." + ) + + @classmethod + def _torch_compile_do_unflatten( + cls, + meta: Any, + process_group: Optional["torch.distributed.ProcessGroup"], + tensors: List[torch.Tensor], + ) -> "QuantizedTensorStorage": + """Reconstruct an instance of ``cls`` from storage flatten data.""" + raise NotImplementedError( + f"{cls.__name__} class does not implement " + "_torch_compile_do_unflatten; required for torch.compile " + "support of QuantizedTensorStorage objects." + ) + + @classmethod + def _torch_compile_unflatten( + cls, + meta: Any, + process_group: Optional["torch.distributed.ProcessGroup"], + tensors: List[torch.Tensor], + ) -> "QuantizedTensorStorage": + """Dispatch to the right storage subclass based on metadata.""" + storage_cls = meta["_qstorage_cls"] + target = _STORAGE_REGISTRY.get(storage_cls) + if target is None: + raise ValueError( + f"No QuantizedTensorStorage subclass registered under " + f"qualname {storage_cls!r}; known: {sorted(_STORAGE_REGISTRY)}" + ) + return target._torch_compile_do_unflatten(meta, process_group, tensors) + + + +TensorOrQuantized = Union[torch.Tensor, QuantizedTensorStorage] + def prepare_for_saving( *tensors: Union[torch.Tensor, QuantizedTensorStorage], @@ -378,6 +511,75 @@ def get_usages(self) -> Dict[str, bool]: "columnwise": self.columnwise_usage, } + # ------------------------------------------------------------------ # + # torch.compile flatten / unflatten protocol + # ------------------------------------------------------------------ # + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + # Auto-register every Quantizer subclass so ``_unflatten`` can + # dispatch back to it by ``__qualname__``. + _QUANTIZER_REGISTRY[cls.__qualname__] = cls + + def _flatten( + self, + ) -> Tuple[Any, Optional["torch.distributed.ProcessGroup"], List[torch.Tensor]]: + """Pack this quantizer's state into the + ``(meta, process_group, tensors)`` triplet expected by the + flattenable bucket in :mod:`transformer_engine.pytorch.dynamo`. + + * ``meta`` -- :class:`OpaqueSimpleMetadata` of all simple state. + Subclasses **must** include their own ``cls.__qualname__`` under + the ``"_qcls"`` key so :meth:`_unflatten` can dispatch back to + ``_do_unflatten`` on the correct subclass. Common base state + (``rowwise_usage``, ``columnwise_usage``, ``internal``, + ``optimize_for_gemm``) is the subclass's responsibility too. + * ``process_group`` -- the (single) :class:`torch.distributed.ProcessGroup` + this quantizer participates in, or ``None``. Quantizers without a + process group return ``None``. + * ``tensors`` -- the live tensor state the op needs to receive + (e.g. ``scale``, ``amax``, RHT matrix). Order is + quantizer-defined and matches what ``_do_unflatten`` expects. + """ + raise NotImplementedError( + f"{type(self).__name__} class does not implement _flatten; " + "required for torch.compile support of TE custom ops." + ) + + @classmethod + def _do_unflatten( + cls, + meta: Any, + process_group: Optional["torch.distributed.ProcessGroup"], + tensors: List[torch.Tensor], + ) -> "Quantizer": + """Reconstruct an instance of ``cls`` from the triplet returned by a + previous :meth:`_flatten` on the same subclass. Subclasses override. + """ + raise NotImplementedError( + f"{cls.__name__} class does not implement _do_unflatten; " + "required for torch.compile support of TE custom ops." + ) + + @classmethod + def _unflatten( + cls, + meta: Any, + process_group: Optional["torch.distributed.ProcessGroup"], + tensors: List[torch.Tensor], + ) -> "Quantizer": + """Dispatch to the right subclass's :meth:`_do_unflatten` based on + the ``"_qcls"`` qualname stored in ``meta``. + """ + qcls = meta["_qcls"] + target = _QUANTIZER_REGISTRY.get(qcls) + if target is None: + raise ValueError( + f"No Quantizer subclass registered under qualname {qcls!r}; " + f"known: {sorted(_QUANTIZER_REGISTRY)}" + ) + return target._do_unflatten(meta, process_group, tensors) + class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data @@ -686,13 +888,13 @@ def maybe_update_inplace(arg, new_arg, schema_arg): out = super().__torch_dispatch__(func, types, args, kwargs) return out - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - - # Do not force the QuantizedTensor type on the returned tensor - return torch._C._disabled_torch_function_impl(func, types, args, kwargs) + # Set as a class-level attribute rather than a ``@classmethod`` so that + # Dynamo recognises the canonical "torch_function disabled" idiom + # and can trace through custom-op calls that receive a + # QuantizedTensor subclass as an argument. As a method override, + # Dynamo bails with "cannot trace builtin + # torch._C._disabled_torch_function_impl". + __torch_function__ = torch._C._disabled_torch_function_impl def contiguous( self, memory_format: torch.memory_format = torch.contiguous_format diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 914397b9b6..e32081e055 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -14,6 +14,7 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..constants import canonicalize_te_dtype from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple @@ -47,7 +48,7 @@ def __init__( block_scaling_dim: int = 2, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.dtype = fp8_dtype + self.dtype = canonicalize_te_dtype(fp8_dtype) self.block_len = 128 self.force_pow_2_scales = force_pow_2_scales self.amax_epsilon = amax_epsilon @@ -244,7 +245,21 @@ def make_empty( **tensor_kwargs, ) - # Construct FP8 tensor + is_2d_scaled = self.block_scaling_dim == 2 + + # See ``Float8Quantizer.make_empty`` for the rationale. + if self.internal: + return Float8BlockwiseQTensorStorage( + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + self.dtype, + self, + is_2d_scaled, + fake_dtype=dtype, + ) + return Float8BlockwiseQTensor( shape=shape, dtype=dtype, @@ -254,7 +269,7 @@ def make_empty( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, quantizer=self, - is_2D_scaled=self.block_scaling_dim == 2, + is_2D_scaled=is_2d_scaled, requires_grad=requires_grad, ) @@ -266,6 +281,41 @@ def calibrate(self, tensor: torch.Tensor) -> None: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8BlockScaling + def _flatten(self): + from ..dynamo import OpaqueSimpleMetadata + + meta = OpaqueSimpleMetadata( + { + "_qcls": type(self).__qualname__, + "dtype": self.dtype, + "rowwise_usage": self.rowwise_usage, + "columnwise_usage": self.columnwise_usage, + "internal": self.internal, + "optimize_for_gemm": self.optimize_for_gemm, + "block_len": self.block_len, + "amax_epsilon": self.amax_epsilon, + "force_pow_2_scales": self.force_pow_2_scales, + "block_scaling_dim": self.block_scaling_dim, + } + ) + return meta, None, [] + + @classmethod + def _do_unflatten(cls, meta, process_group, tensors): + del process_group, tensors + q = cls( + fp8_dtype=meta["dtype"], + rowwise=meta["rowwise_usage"], + columnwise=meta["columnwise_usage"], + amax_epsilon=meta["amax_epsilon"], + force_pow_2_scales=meta["force_pow_2_scales"], + block_scaling_dim=meta["block_scaling_dim"], + ) + q.block_len = meta["block_len"] + q.internal = meta["internal"] + q.optimize_for_gemm = meta["optimize_for_gemm"] + return q + class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index ed6091c85b..01bc480bb2 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -18,9 +18,14 @@ ) from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func -from ..quantized_tensor import QuantizedTensor, Quantizer +from ..quantized_tensor import ( + QuantizedTensor, + Quantizer, + _quantizer_from_subclass_snapshot, + _quantizer_subclass_snapshot, +) from ._quantization_helpers import _IdentityFunc -from ..constants import dist_group_type +from ..constants import canonicalize_te_dtype, dist_group_type aten = torch.ops.aten @@ -68,7 +73,7 @@ def __init__( super().__init__(rowwise=rowwise, columnwise=columnwise) self.scale = scale self.amax = amax - self.dtype = fp8_dtype + self.dtype = canonicalize_te_dtype(fp8_dtype) def copy(self) -> Float8Quantizer: """Create shallow copy""" @@ -142,12 +147,29 @@ def make_empty( pin_memory=pin_memory, ) - # Construct FP8 tensor + scale_inv = torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory) + + # Honor ``internal``: tex.quantize() returns a bare + # Float8TensorStorage when the quantizer is marked internal + # (lower CPU overhead, no autograd-aware subclass) and so should + # make_empty in order to stay shape/type-equivalent on every + # path that touches it (eager fast-path, fake-impl under + # torch.compile, etc.). + if self.internal: + return Float8TensorStorage( + data=data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.dtype, + fake_dtype=dtype, + data_transpose=data_transpose, + quantizer=self, + ) + return Float8Tensor( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), + fp8_scale_inv=scale_inv, fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -223,6 +245,36 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True + def _flatten(self): + from ..dynamo import OpaqueSimpleMetadata + + meta = OpaqueSimpleMetadata( + { + "_qcls": type(self).__qualname__, + "dtype": self.dtype, + "rowwise_usage": self.rowwise_usage, + "columnwise_usage": self.columnwise_usage, + "internal": self.internal, + "optimize_for_gemm": self.optimize_for_gemm, + } + ) + return meta, None, [self.scale, self.amax] + + @classmethod + def _do_unflatten(cls, meta, process_group, tensors): + del process_group + scale, amax = tensors + q = cls( + scale=scale, + amax=amax, + fp8_dtype=meta["dtype"], + rowwise=meta["rowwise_usage"], + columnwise=meta["columnwise_usage"], + ) + q.internal = meta["internal"] + q.optimize_for_gemm = meta["optimize_for_gemm"] + return q + class Float8CurrentScalingQuantizer(Quantizer): """Builder class for FP8 tensors with per-tensor current scaling @@ -279,7 +331,7 @@ def __init__( stacklevel=2, ) del device, use_existing_amax, scale, amax # Kept for backward compatibility - self.dtype = fp8_dtype + self.dtype = canonicalize_te_dtype(fp8_dtype) self.with_amax_reduction = with_amax_reduction self.amax_reduction_group = amax_reduction_group self.force_pow_2_scales = force_pow_2_scales @@ -366,12 +418,24 @@ def make_empty( device=device, pin_memory=pin_memory, ) - # Construct FP8 tensor + scale_inv = torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory) + + # See ``Float8Quantizer.make_empty`` for the rationale. + if self.internal: + return Float8TensorStorage( + data=data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.dtype, + fake_dtype=dtype, + data_transpose=data_transpose, + quantizer=self, + ) + return Float8Tensor( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), + fp8_scale_inv=scale_inv, fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -461,6 +525,41 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True + def _flatten(self): + from ..dynamo import OpaqueSimpleMetadata + + meta = OpaqueSimpleMetadata( + { + "_qcls": type(self).__qualname__, + "dtype": self.dtype, + "rowwise_usage": self.rowwise_usage, + "columnwise_usage": self.columnwise_usage, + "internal": self.internal, + "optimize_for_gemm": self.optimize_for_gemm, + "with_amax_reduction": self.with_amax_reduction, + "force_pow_2_scales": self.force_pow_2_scales, + "amax_epsilon": self.amax_epsilon, + } + ) + return meta, self.amax_reduction_group, [] + + @classmethod + def _do_unflatten(cls, meta, process_group, tensors): + del tensors + q = cls( + fp8_dtype=meta["dtype"], + device=torch.device("cuda"), + rowwise=meta["rowwise_usage"], + columnwise=meta["columnwise_usage"], + with_amax_reduction=meta["with_amax_reduction"], + amax_reduction_group=process_group, + force_pow_2_scales=meta["force_pow_2_scales"], + amax_epsilon=meta["amax_epsilon"], + ) + q.internal = meta["internal"] + q.optimize_for_gemm = meta["optimize_for_gemm"] + return q + class Float8Tensor(Float8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data @@ -494,14 +593,82 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): """ def __repr__(self, *, tensor_contents=None): + # ``__repr__`` is on hot diagnostic paths (Dynamo's + # ``Dynamo failed to run FX node`` formatter, autograd + # anomaly mode, FX node printers, ...) and must never raise. + # In particular, dequantising a fake/functional tensor here + # would access ``data_ptr()`` and replace the real failure + # with a misleading data-pointer error. + try: + shape = tuple(self.shape) + except BaseException: # pylint: disable=broad-except + shape = "" return ( "Float8Tensor(" f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize()}" + f"shape={shape}" ")" ) + def __tensor_flatten__(self) -> Tuple[list, dict]: + """torch.compile / tensor-subclass flatten protocol. + + Returns ``(inner_tensor_names, meta)`` so that PyTorch's + wrapper-subclass machinery and :func:`register_torch_dispatch` + rules on custom ops can decompose a ``Float8Tensor`` into + plain tensors plus a static metadata dict at trace time. + + The metadata dict must contain only values supporting stable + ``==`` comparison (Dynamo's tensor-subclass metadata guard + re-evaluates it via dict equality on every entry into the + compiled region). Mutable / runtime-only state such as the + ``_transpose_invalid`` flag deliberately does *not* end up + here; it would flip between calls and trip the "Guard failed + on the same frame" assertion. + + ``_quantizer_snapshot`` carries a tensor-free snapshot of + the live ``Quantizer`` so :meth:`__tensor_unflatten__` can + rebuild a structurally-equivalent quantizer on the unflatten + side. Quantizers that carry tensors in their state (e.g. + :class:`Float8Quantizer` keeps ``scale`` / ``amax``) cannot + be snapshotted into a guard-stable dict and produce a + ``None`` snapshot; in that case the reconstructed + ``Float8Tensor`` will have ``_quantizer = None`` and any + downstream code that needs the quantizer must source it from + elsewhere (typically the bucket-level opaque metadata on the + inner op call). + """ + inner: list = [] + if self._data is not None: + inner.append("_data") + if self._scale_inv is not None: + inner.append("_scale_inv") + if self._transpose is not None: + inner.append("_transpose") + meta = { + "_fp8_dtype": self._fp8_dtype, + "_fake_dtype": self._dtype, + "_quantizer_snapshot": _quantizer_subclass_snapshot(self._quantizer), + "_requires_grad": self.requires_grad, + } + return inner, meta + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict, meta: dict, outer_size, outer_stride + ) -> "Float8Tensor": + quantizer = _quantizer_from_subclass_snapshot(meta.get("_quantizer_snapshot")) + return Float8Tensor( + shape=outer_size, + dtype=meta["_fake_dtype"], + data=inner_tensors.get("_data"), + fp8_scale_inv=inner_tensors.get("_scale_inv"), + fp8_dtype=meta["_fp8_dtype"], + data_transpose=inner_tensors.get("_transpose"), + quantizer=quantizer, + requires_grad=meta.get("_requires_grad", False), + ) + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ Construct plain PyTorch tensor from Float8Tensor diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5cab519c79..7616de2247 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -15,7 +15,7 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe -from ..constants import MXFP8_BLOCK_SCALING_SIZE +from ..constants import MXFP8_BLOCK_SCALING_SIZE, canonicalize_te_dtype from ..utils import devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer @@ -43,7 +43,7 @@ def __init__( columnwise: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.dtype = fp8_dtype + self.dtype = canonicalize_te_dtype(fp8_dtype) def copy(self) -> MXFP8Quantizer: """Create shallow copy""" @@ -146,7 +146,19 @@ def make_empty( pin_memory=pin_memory, ) - # Construct FP8 tensor + # See ``Float8Quantizer.make_empty`` for the rationale. + if self.internal: + return MXFP8TensorStorage( + data, + scale_inv, + columnwise_data, + columnwise_scale_inv, + self.dtype, + self, + self.optimize_for_gemm, + fake_dtype=dtype, + ) + return MXFP8Tensor( shape=shape, dtype=dtype, @@ -243,6 +255,33 @@ def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> tor def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return MXFP8BlockScaling + def _flatten(self): + from ..dynamo import OpaqueSimpleMetadata + + meta = OpaqueSimpleMetadata( + { + "_qcls": type(self).__qualname__, + "dtype": self.dtype, + "rowwise_usage": self.rowwise_usage, + "columnwise_usage": self.columnwise_usage, + "internal": self.internal, + "optimize_for_gemm": self.optimize_for_gemm, + } + ) + return meta, None, [] + + @classmethod + def _do_unflatten(cls, meta, process_group, tensors): + del process_group, tensors + q = cls( + fp8_dtype=meta["dtype"], + rowwise=meta["rowwise_usage"], + columnwise=meta["columnwise_usage"], + ) + q.internal = meta["internal"] + q.optimize_for_gemm = meta["optimize_for_gemm"] + return q + class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 285a7f030a..42ccb611c4 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -15,7 +15,7 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe -from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type +from ..constants import NVFP4_BLOCK_SCALING_SIZE, canonicalize_te_dtype, dist_group_type from ..utils import ( canonicalize_process_group, devices_match, @@ -150,7 +150,7 @@ def __init__( with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.dtype = fp4_dtype + self.dtype = canonicalize_te_dtype(fp4_dtype) self.with_rht = with_rht self.with_post_rht_amax = with_post_rht_amax self.with_amax_reduction = with_amax_reduction @@ -373,7 +373,22 @@ def make_empty( 1, dtype=torch.float32, device=device, pin_memory=pin_memory ) - # Construct FP8 tensor + # See ``Float8Quantizer.make_empty`` for the rationale. + if self.internal: + return NVFP4TensorStorage( + data, + scale_inv, + columnwise_data, + columnwise_scale_inv, + amax_rowwise, + amax_columnwise, + self.dtype, + self, + False, + fake_dtype=dtype, + row_scaled_nvfp4=self.row_scaled_nvfp4, + ) + return NVFP4Tensor( shape=shape, dtype=dtype, @@ -400,6 +415,52 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return NVFP4BlockScaling + def _flatten(self): + from ..dynamo import OpaqueSimpleMetadata + + meta = OpaqueSimpleMetadata( + { + "_qcls": type(self).__qualname__, + "dtype": self.dtype, + "rowwise_usage": self.rowwise_usage, + "columnwise_usage": self.columnwise_usage, + "internal": self.internal, + "optimize_for_gemm": self.optimize_for_gemm, + "with_rht": self.with_rht, + "with_post_rht_amax": self.with_post_rht_amax, + "with_amax_reduction": self.with_amax_reduction, + "with_2d_quantization": self.with_2d_quantization, + "stochastic_rounding": self.stochastic_rounding, + "row_scaled_nvfp4": self.row_scaled_nvfp4, + "rht_matrix_random_sign_mask_t": self.rht_matrix_random_sign_mask_t, + } + ) + return meta, self.amax_reduction_group, [self.rht_matrix] + + @classmethod + def _do_unflatten(cls, meta, process_group, tensors): + (rht_matrix,) = tensors + # Construct with default RHT mask, then overwrite the computed + # ``rht_matrix_random_sign_mask_t`` / ``rht_matrix`` with the + # restored values so we don't depend on cuda helpers / device state. + q = cls( + fp4_dtype=meta["dtype"], + rowwise=meta["rowwise_usage"], + columnwise=meta["columnwise_usage"], + with_amax_reduction=meta["with_amax_reduction"], + amax_reduction_group=process_group, + with_rht=meta["with_rht"], + with_post_rht_amax=meta["with_post_rht_amax"], + with_2d_quantization=meta["with_2d_quantization"], + stochastic_rounding=meta["stochastic_rounding"], + row_scaled_nvfp4=meta["row_scaled_nvfp4"], + ) + q.rht_matrix_random_sign_mask_t = meta["rht_matrix_random_sign_mask_t"] + q.rht_matrix = rht_matrix + q.internal = meta["internal"] + q.optimize_for_gemm = meta["optimize_for_gemm"] + return q + class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): """Quantized tensor class with FP4 data diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index ca3913762f..641192dfb2 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -6,7 +6,7 @@ from __future__ import annotations import math -from typing import Optional, Dict, Any, Tuple +from typing import Optional, Dict, Any, List, Tuple import torch import transformer_engine_torch as tex @@ -18,6 +18,16 @@ from ...utils import _empty_tensor +try: + from torch._library.opaque_object import is_opaque_value_type, register_opaque_type + + if not hasattr(TE_DType, "__fx_repr__"): + TE_DType.__fx_repr__ = lambda self: (f"TE_DType({int(self)})", {"TE_DType": TE_DType}) + if not is_opaque_value_type(TE_DType): + register_opaque_type(TE_DType, typ="value", members={}) +except Exception: # pragma: no cover - older torch / partial init + pass + class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of Float8BlockwiseQTensor. @@ -134,6 +144,83 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] + def _torch_compile_flatten(self) -> Tuple[Any, Any, List[torch.Tensor]]: + from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata + + tensors: List[torch.Tensor] = [] + + def _append_if_present(tensor: Optional[torch.Tensor]) -> bool: + if tensor is None: + return False + tensors.append(tensor) + return True + + quantizer_meta = None + process_group = None + quantizer_tensors: List[torch.Tensor] = [] + if self._quantizer is not None: + quantizer_meta, process_group, quantizer_tensors = self._quantizer._flatten() + + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": type(self).__qualname__, + "is_tensor": isinstance(self, torch.Tensor), + "shape": torch.Size(self.shape) if isinstance(self, torch.Tensor) else None, + "requires_grad": self.requires_grad if isinstance(self, torch.Tensor) else False, + "device": self.device if isinstance(self, torch.Tensor) else None, + "fp8_dtype": self._fp8_dtype, + "fake_dtype": self._dtype, + "is_2D_scaled": self._is_2D_scaled, + "has_rowwise_data": _append_if_present(self._rowwise_data), + "has_rowwise_scale_inv": _append_if_present(self._rowwise_scale_inv), + "has_columnwise_data": _append_if_present(self._columnwise_data), + "has_columnwise_scale_inv": _append_if_present(self._columnwise_scale_inv), + "quantizer_meta": quantizer_meta, + } + ) + tensors.extend(quantizer_tensors) + return meta, process_group, tensors + + @classmethod + def _torch_compile_do_unflatten( + cls, + meta: Any, + process_group: Any, + tensors: List[torch.Tensor], + ) -> "Float8BlockwiseQTensorStorage": + tensor_iter = iter(tensors) + rowwise_data = next(tensor_iter) if meta["has_rowwise_data"] else None + rowwise_scale_inv = next(tensor_iter) if meta["has_rowwise_scale_inv"] else None + columnwise_data = next(tensor_iter) if meta["has_columnwise_data"] else None + columnwise_scale_inv = ( + next(tensor_iter) if meta["has_columnwise_scale_inv"] else None + ) + quantizer = None + if meta["quantizer_meta"] is not None: + quantizer = Quantizer._unflatten( + meta["quantizer_meta"], process_group, list(tensor_iter) + ) + kwargs = { + "rowwise_data": rowwise_data, + "rowwise_scale_inv": rowwise_scale_inv, + "columnwise_data": columnwise_data, + "columnwise_scale_inv": columnwise_scale_inv, + "fp8_dtype": meta["fp8_dtype"], + "quantizer": quantizer, + "is_2D_scaled": meta["is_2D_scaled"], + "fake_dtype": meta["fake_dtype"], + } + if meta["is_tensor"]: + kwargs.update( + { + "shape": meta["shape"], + "dtype": meta["fake_dtype"], + "requires_grad": meta["requires_grad"], + "device": meta["device"], + } + ) + return cls(**kwargs) + def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" if rowwise_data and columnwise_data: diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index de7f8f58e2..ecaf1d919f 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -6,7 +6,7 @@ from __future__ import annotations import math -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import transformer_engine_torch as tex @@ -18,6 +18,16 @@ from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor +try: + from torch._library.opaque_object import is_opaque_value_type, register_opaque_type + + if not hasattr(TE_DType, "__fx_repr__"): + TE_DType.__fx_repr__ = lambda self: (f"TE_DType({int(self)})", {"TE_DType": TE_DType}) + if not is_opaque_value_type(TE_DType): + register_opaque_type(TE_DType, typ="value", members={}) +except Exception: # pragma: no cover - older torch / partial init + pass + class _FromFloat8Func(torch.autograd.Function): """Cast from FP8 to other dtype""" @@ -215,14 +225,100 @@ def view(self, shape: torch.Size): ) def __repr__(self): + # Must never raise: this runs from Inductor error formatters, + # FX node dumps, Dynamo guards, etc. Crucially we must also + # avoid any tensor->scalar materialization (``.item()``, + # ``.tolist()``, ``dequantize()``): under fake-tensor mode they + # allocate fresh unbacked symbols which then leak out of the + # current op as "unreturned outputs" and crash the compile. + # Stick to shape/dtype summaries. + scale_shape = list(getattr(self._scale_inv, "shape", ())) + if self._data is None: + data_repr = "" + else: + data_shape = list(getattr(self._data, "shape", ())) + data_repr = f"" return ( "Float8TensorStorage(" f"fp8_dtype={self._fp8_dtype}, " - f"scale_inv={self._scale_inv.item()}, " - f"data={self.dequantize()}" + f"scale_inv=, " + f"data={data_repr}" ")" ) + def _torch_compile_flatten(self) -> Tuple[Any, Any, List[torch.Tensor]]: + from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata + + tensors: List[torch.Tensor] = [] + + def _append_if_present(tensor: Optional[torch.Tensor]) -> bool: + if tensor is None: + return False + tensors.append(tensor) + return True + + quantizer_meta = None + process_group = None + quantizer_tensors: List[torch.Tensor] = [] + if self._quantizer is not None: + quantizer_meta, process_group, quantizer_tensors = self._quantizer._flatten() + + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": type(self).__qualname__, + "is_tensor": isinstance(self, torch.Tensor), + "shape": torch.Size(self.shape) if isinstance(self, torch.Tensor) else None, + "requires_grad": self.requires_grad if isinstance(self, torch.Tensor) else False, + "device": self.device if isinstance(self, torch.Tensor) else None, + "fp8_dtype": self._fp8_dtype, + "fake_dtype": self._dtype, + "transpose_invalid": self._transpose_invalid, + "has_data": _append_if_present(self._data), + "has_transpose": _append_if_present(self._transpose), + "has_scale_inv": _append_if_present(self._scale_inv), + "quantizer_meta": quantizer_meta, + } + ) + tensors.extend(quantizer_tensors) + return meta, process_group, tensors + + @classmethod + def _torch_compile_do_unflatten( + cls, + meta: Any, + process_group: Any, + tensors: List[torch.Tensor], + ) -> "Float8TensorStorage": + tensor_iter = iter(tensors) + data = next(tensor_iter) if meta["has_data"] else None + transpose = next(tensor_iter) if meta["has_transpose"] else None + scale_inv = next(tensor_iter) if meta["has_scale_inv"] else None + quantizer = None + if meta["quantizer_meta"] is not None: + quantizer = Quantizer._unflatten( + meta["quantizer_meta"], process_group, list(tensor_iter) + ) + kwargs = { + "data": data, + "fp8_scale_inv": scale_inv, + "fp8_dtype": meta["fp8_dtype"], + "data_transpose": transpose, + "quantizer": quantizer, + "fake_dtype": meta["fake_dtype"], + } + if meta["is_tensor"]: + kwargs.update( + { + "shape": meta["shape"], + "dtype": meta["fake_dtype"], + "requires_grad": meta["requires_grad"], + "device": meta["device"], + } + ) + out = cls(**kwargs) + out._transpose_invalid = meta["transpose_invalid"] + return out + def _create_transpose(self): """Update FP8 transpose cache""" data = self._data diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 842f42838b..edc1dd8ac1 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -5,7 +5,7 @@ """Mixin class holding data specific for MXFP8Tensor""" from __future__ import annotations -from typing import Optional, Dict, Any, Tuple +from typing import Optional, Dict, Any, List, Tuple from collections.abc import Iterable import math import torch @@ -19,6 +19,16 @@ from ...utils import _empty_tensor +try: + from torch._library.opaque_object import is_opaque_value_type, register_opaque_type + + if not hasattr(TE_DType, "__fx_repr__"): + TE_DType.__fx_repr__ = lambda self: (f"TE_DType({int(self)})", {"TE_DType": TE_DType}) + if not is_opaque_value_type(TE_DType): + register_opaque_type(TE_DType, typ="value", members={}) +except Exception: # pragma: no cover - older torch / partial init + pass + class _FromMXFP8Func(torch.autograd.Function): """Cast from MXFP8 to other dtype""" @@ -173,6 +183,83 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] + def _torch_compile_flatten(self) -> Tuple[Any, Any, List[torch.Tensor]]: + from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata + + tensors: List[torch.Tensor] = [] + + def _append_if_present(tensor: Optional[torch.Tensor]) -> bool: + if tensor is None: + return False + tensors.append(tensor) + return True + + quantizer_meta = None + process_group = None + quantizer_tensors: List[torch.Tensor] = [] + if self._quantizer is not None: + quantizer_meta, process_group, quantizer_tensors = self._quantizer._flatten() + + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": type(self).__qualname__, + "is_tensor": isinstance(self, torch.Tensor), + "shape": torch.Size(self.shape) if isinstance(self, torch.Tensor) else None, + "requires_grad": self.requires_grad if isinstance(self, torch.Tensor) else False, + "device": self.device if isinstance(self, torch.Tensor) else None, + "fp8_dtype": self._fp8_dtype, + "fake_dtype": self._dtype, + "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, + "has_rowwise_data": _append_if_present(self._rowwise_data), + "has_rowwise_scale_inv": _append_if_present(self._rowwise_scale_inv), + "has_columnwise_data": _append_if_present(self._columnwise_data), + "has_columnwise_scale_inv": _append_if_present(self._columnwise_scale_inv), + "quantizer_meta": quantizer_meta, + } + ) + tensors.extend(quantizer_tensors) + return meta, process_group, tensors + + @classmethod + def _torch_compile_do_unflatten( + cls, + meta: Any, + process_group: Any, + tensors: List[torch.Tensor], + ) -> "MXFP8TensorStorage": + tensor_iter = iter(tensors) + rowwise_data = next(tensor_iter) if meta["has_rowwise_data"] else None + rowwise_scale_inv = next(tensor_iter) if meta["has_rowwise_scale_inv"] else None + columnwise_data = next(tensor_iter) if meta["has_columnwise_data"] else None + columnwise_scale_inv = ( + next(tensor_iter) if meta["has_columnwise_scale_inv"] else None + ) + quantizer = None + if meta["quantizer_meta"] is not None: + quantizer = Quantizer._unflatten( + meta["quantizer_meta"], process_group, list(tensor_iter) + ) + kwargs = { + "rowwise_data": rowwise_data, + "rowwise_scale_inv": rowwise_scale_inv, + "columnwise_data": columnwise_data, + "columnwise_scale_inv": columnwise_scale_inv, + "fp8_dtype": meta["fp8_dtype"], + "quantizer": quantizer, + "with_gemm_swizzled_scales": meta["with_gemm_swizzled_scales"], + "fake_dtype": meta["fake_dtype"], + } + if meta["is_tensor"]: + kwargs.update( + { + "shape": meta["shape"], + "dtype": meta["fake_dtype"], + "requires_grad": meta["requires_grad"], + "device": meta["device"], + } + ) + return cls(**kwargs) + def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" if rowwise_data and columnwise_data: diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index e51acb71e5..0e4810a4ea 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -8,7 +8,7 @@ from collections.abc import Iterable import functools import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import warnings import torch @@ -21,6 +21,16 @@ from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...utils import _empty_tensor +try: + from torch._library.opaque_object import is_opaque_value_type, register_opaque_type + + if not hasattr(TE_DType, "__fx_repr__"): + TE_DType.__fx_repr__ = lambda self: (f"TE_DType({int(self)})", {"TE_DType": TE_DType}) + if not is_opaque_value_type(TE_DType): + register_opaque_type(TE_DType, typ="value", members={}) +except Exception: # pragma: no cover - older torch / partial init + pass + @functools.lru_cache(maxsize=None) def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor: @@ -216,6 +226,91 @@ def restore_from_saved( self._amax_columnwise = tensors[5] return tensors[6:] + def _torch_compile_flatten(self) -> Tuple[Any, Any, List[torch.Tensor]]: + from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata + + tensors: List[torch.Tensor] = [] + + def _append_if_present(tensor: Optional[torch.Tensor]) -> bool: + if tensor is None: + return False + tensors.append(tensor) + return True + + quantizer_meta = None + process_group = None + quantizer_tensors: List[torch.Tensor] = [] + if self._quantizer is not None: + quantizer_meta, process_group, quantizer_tensors = self._quantizer._flatten() + + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": type(self).__qualname__, + "is_tensor": isinstance(self, torch.Tensor), + "shape": torch.Size(self.shape) if isinstance(self, torch.Tensor) else None, + "requires_grad": self.requires_grad if isinstance(self, torch.Tensor) else False, + "device": self.device if isinstance(self, torch.Tensor) else None, + "fp4_dtype": self._fp4_dtype, + "fake_dtype": self._dtype, + "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, + "row_scaled_nvfp4": self._row_scaled_nvfp4, + "has_rowwise_data": _append_if_present(self._rowwise_data), + "has_rowwise_scale_inv": _append_if_present(self._rowwise_scale_inv), + "has_columnwise_data": _append_if_present(self._columnwise_data), + "has_columnwise_scale_inv": _append_if_present(self._columnwise_scale_inv), + "has_amax_rowwise": _append_if_present(self._amax_rowwise), + "has_amax_columnwise": _append_if_present(self._amax_columnwise), + "quantizer_meta": quantizer_meta, + } + ) + tensors.extend(quantizer_tensors) + return meta, process_group, tensors + + @classmethod + def _torch_compile_do_unflatten( + cls, + meta: Any, + process_group: Any, + tensors: List[torch.Tensor], + ) -> "NVFP4TensorStorage": + tensor_iter = iter(tensors) + rowwise_data = next(tensor_iter) if meta["has_rowwise_data"] else None + rowwise_scale_inv = next(tensor_iter) if meta["has_rowwise_scale_inv"] else None + columnwise_data = next(tensor_iter) if meta["has_columnwise_data"] else None + columnwise_scale_inv = ( + next(tensor_iter) if meta["has_columnwise_scale_inv"] else None + ) + amax_rowwise = next(tensor_iter) if meta["has_amax_rowwise"] else None + amax_columnwise = next(tensor_iter) if meta["has_amax_columnwise"] else None + quantizer = None + if meta["quantizer_meta"] is not None: + quantizer = Quantizer._unflatten( + meta["quantizer_meta"], process_group, list(tensor_iter) + ) + kwargs = { + "rowwise_data": rowwise_data, + "rowwise_scale_inv": rowwise_scale_inv, + "columnwise_data": columnwise_data, + "columnwise_scale_inv": columnwise_scale_inv, + "amax_rowwise": amax_rowwise, + "amax_columnwise": amax_columnwise, + "fp4_dtype": meta["fp4_dtype"], + "quantizer": quantizer, + "with_gemm_swizzled_scales": meta["with_gemm_swizzled_scales"], + "fake_dtype": meta["fake_dtype"], + "row_scaled_nvfp4": meta["row_scaled_nvfp4"], + } + if meta["is_tensor"]: + kwargs.update( + { + "shape": meta["shape"], + "dtype": meta["fake_dtype"], + "requires_grad": meta["requires_grad"], + "device": meta["device"], + } + ) + return cls(**kwargs) + def get_data_tensors(self): """Get this Tensor's data.""" return self._rowwise_data, self._columnwise_data From f9e45bad897feeb54a762722ccedb9ffc088aed6 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 14 May 2026 11:26:42 +0200 Subject: [PATCH 04/16] [PyTorch] Fix indentation errors in dynamo.py Six blocks in dynamo.py shipped with broken indentation in the previous commit and prevented the module from being imported (`IndentationError` / `SyntaxError`): - `_UniversalTensorBucket.pack`: storage branch had its body / return split across two indentation levels and ended up outside the matching `if isinstance(value, qts):` block. - `_quantizer_cls` / `_recipe_cls`: `try`/`except` pair misaligned. - `_FlattenableBucket._pack_value`: `if hasattr(value, "_flatten"):` pulled inside the wrong outer branch. - `_SimpleBundleBucket.matches_field`: stray extra indent on a `return True`. - `_resolved_field_annotations`, `_get_buckets`, `_pack`, `_unpack`: function bodies indented under their own header by 8 spaces instead of 4. No semantic changes intended -- this commit restores the structure that the docstrings and surrounding code already document. Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo.py | 98 ++++++++++++++-------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index 5b62757ee3..f3c8af1e0b 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -579,11 +579,11 @@ def pack(self, owner: Any) -> List[Tuple[str, Any]]: ] qts = _quantized_tensor_storage_cls() if qts is not None and isinstance(value, qts): - meta, pg, tensors = value._torch_compile_flatten() + meta, pg, tensors = value._torch_compile_flatten() # Stamp the storage-flatten meta with our kind marker so the # unpacker can route by ``__kind__`` alone. meta._data[self.KIND_KEY] = self.KIND_STORAGE - return [ + return [ (self.slot_name(), None), (self.slot_tensors(), list(tensors)), (self.slot_pg(), pg), @@ -708,7 +708,7 @@ def _quantizer_cls() -> Optional[type]: from transformer_engine.pytorch.quantized_tensor import Quantizer _QUANTIZER_REF = Quantizer - except Exception: # pragma: no cover - partial init + except Exception: # pragma: no cover - partial init return None return _QUANTIZER_REF @@ -717,11 +717,11 @@ def _recipe_cls() -> Optional[type]: """Lazy-resolve :class:`Recipe`; ``None`` if unavailable.""" global _RECIPE_REF if _RECIPE_REF is None: - try: - from transformer_engine.common.recipe import Recipe + try: + from transformer_engine.common.recipe import Recipe _RECIPE_REF = Recipe - except Exception: # pragma: no cover - partial init + except Exception: # pragma: no cover - partial init return None return _RECIPE_REF @@ -776,7 +776,7 @@ def _pack_value(self, value: Any) -> Tuple[Any, Any, List[torch.Tensor]]: None, [], ) - if hasattr(value, "_flatten"): + if hasattr(value, "_flatten"): return value._flatten() return value._torch_compile_flatten() @@ -829,7 +829,7 @@ def matches_field(cls, annot: Any) -> bool: # Any registered value-opaque class is hashable / FX-reproducible # and therefore safe to embed in the OpaqueSimpleMetadata bundle. if isinstance(annot, type) and _is_opaque_value_type(annot): - return True + return True origin = get_origin(annot) if origin in (tuple, list): # Inner args may contain Ellipsis (e.g. ``Tuple[int, ...]``); @@ -925,52 +925,52 @@ def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: def _resolved_field_annotations(cls: type) -> List[Tuple[str, Any]]: """Return ``[(field_name, resolved_type), ...]`` for a dataclass.""" - if not dataclasses.is_dataclass(cls): - raise TypeError( + if not dataclasses.is_dataclass(cls): + raise TypeError( f"{cls.__name__} must be a @dataclass to be used as a TE " f"custom-op argument container." - ) - # ``get_type_hints`` resolves forward references and PEP 563 - # ``from __future__ import annotations`` strings. - try: - hints = get_type_hints(cls) - except Exception: - hints = {} + ) + # ``get_type_hints`` resolves forward references and PEP 563 + # ``from __future__ import annotations`` strings. + try: + hints = get_type_hints(cls) + except Exception: + hints = {} return [(f.name, hints.get(f.name, f.type)) for f in dataclasses.fields(cls)] def _get_buckets(cls: type) -> List[_Bucket]: """Build the bucket list for a dataclass from its field annotations. - Dispatch order per field: try each bucket in :data:`_FIELD_BUCKETS` - (Tensor, ProcessGroup, Quantizer); if none claims the field, route - it to :class:`_SimpleBundleBucket` if its annotation is bundle-able, - else to :class:`_UnknownBucket`. + Dispatch order per field: try each bucket in :data:`_FIELD_BUCKETS` + (Tensor, ProcessGroup, Quantizer); if none claims the field, route + it to :class:`_SimpleBundleBucket` if its annotation is bundle-able, + else to :class:`_UnknownBucket`. Intentionally **not** cached on ``cls``. Caching there (e.g. by writing ``cls.__te_buckets__``) tickles Dynamo: subsequent reads of - ``cls.__dict__`` from a compiled function trigger + ``cls.__dict__`` from a compiled function trigger "mappingproxy affected by dictionary mutation" graph breaks. Hot paths must instead capture the bucket list once at op registration time and pass it explicitly to :func:`_pack` / :func:`_unpack`. - """ - buckets: List[_Bucket] = [] - simple_names: List[str] = [] + """ + buckets: List[_Bucket] = [] + simple_names: List[str] = [] for name, annot in _resolved_field_annotations(cls): - built: Optional[_Bucket] = None - for bucket_cls in _FIELD_BUCKETS: - built = bucket_cls.try_build(name, annot) - if built is not None: - break + built: Optional[_Bucket] = None + for bucket_cls in _FIELD_BUCKETS: + built = bucket_cls.try_build(name, annot) if built is not None: - buckets.append(built) - elif _SimpleBundleBucket.matches_field(annot): - simple_names.append(name) - else: - buckets.append(_UnknownBucket(name, cls.__name__)) - if simple_names: - buckets.append(_SimpleBundleBucket(simple_names)) - return buckets + break + if built is not None: + buckets.append(built) + elif _SimpleBundleBucket.matches_field(annot): + simple_names.append(name) + else: + buckets.append(_UnknownBucket(name, cls.__name__)) + if simple_names: + buckets.append(_SimpleBundleBucket(simple_names)) + return buckets def _build_schema(buckets: List[_Bucket]) -> Tuple[str, List[str]]: @@ -997,11 +997,11 @@ def _pack(obj: Any, buckets: List[_Bucket]) -> Dict[str, Any]: avoid recomputing and, critically, to keep Dynamo away from ``cls.__dict__`` while tracing. """ - out: Dict[str, Any] = {} - for bucket in buckets: + out: Dict[str, Any] = {} + for bucket in buckets: for name, value in bucket.pack(obj): - out[name] = value - return out + out[name] = value + return out def _unpack(cls: type, args: Dict[str, Any], buckets: List[_Bucket]) -> Any: @@ -1012,13 +1012,13 @@ def _unpack(cls: type, args: Dict[str, Any], buckets: List[_Bucket]) -> Any: even when they have no default). ``buckets`` semantics match :func:`_pack`. """ - kwargs: Dict[str, Any] = {} - for bucket in buckets: - bucket.unpack(args, kwargs) - obj = cls.__new__(cls) - for k, v in kwargs.items(): - object.__setattr__(obj, k, v) - return obj + kwargs: Dict[str, Any] = {} + for bucket in buckets: + bucket.unpack(args, kwargs) + obj = cls.__new__(cls) + for k, v in kwargs.items(): + object.__setattr__(obj, k, v) + return obj # --------------------------------------------------------------------------- # From be5c4ad592916b8f8223c4bcdcbd12acb5247342 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 27 May 2026 20:58:18 +0200 Subject: [PATCH 05/16] [PyTorch] Replace Linear fake-impls with output-info descriptors Linear.forward / backward under torch.compile previously relied on hand-written ``_linear_forward_fake_impl`` / ``_linear_backward_fake_impl`` which duplicated every per-precision / per-mode branch from the eager impl just to materialise fake tensors during AOTAutograd tracing. This commit replaces both fake impls with pure-Python alloc descriptors and auto-synthesizes the fake-impls in dynamo.py: * ``_te_register_custom_op`` now accepts ``bwd_output_info_fn`` next to the existing ``output_info_fn`` and synthesizes ``fwd_fake_impl`` / ``backward_fake_impl`` from their alloc tuples via ``_make_fake_impl_from_output_info`` / ``_make_fake_impl_from_bwd_output_info``. * ``output_info_fn`` now returns a 4-tuple ``(user_specs, tensor_objects, ctx_attrs, fake_specs)``; the new ``fake_specs`` dict carries one alloc spec per user output and per saved slot (``("plain", shape, dtype, device)`` or ``("quantized", quantizer, shape, dtype, device)``). * ``_linear_forward_fake_impl`` and ``_linear_backward_fake_impl`` are removed; ``_linear_forward_output_info`` is extended to produce ``fake_specs``, and a new ``_linear_backward_output_info`` (~55 LoC) describes gradient shapes/dtypes/devices. * The matching quantizer plumbing -- ``create_metadata`` / ``create_storage_metadata`` / ``create_save_shell`` -- is filled in for ``MXFP8Quantizer``, ``Float8BlockQuantizer`` and ``NVFP4Quantizer`` so every storage family takes the same path. * ``Float8TensorStorage._torch_compile_do_unflatten`` no longer overwrites ``_transpose_invalid`` with stale metadata captured at flatten time; the post-restore semantic (``_transpose_invalid = (data_transpose is None)``) is now preserved by ``__new__`` / ``restore_from_saved`` instead. Net result: a single Python function per op now owns layout, restore-shape and fake-alloc bookkeeping; the auto-synthesized fake-impl runs only under ``FakeTensorMode`` (so live quantizers and ``tex.DType`` pybind enums are fine) while the layout descriptor stays purely in plain Python and remains traceable by Dynamo under ``fullgraph=True``. All ``tests/pytorch/test_torch_compile.py`` cases pass; ``test_numerics`` linear coverage (882 cases) is green. Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 47 + transformer_engine/pytorch/dynamo.py | 877 ++++++++++++++++-- transformer_engine/pytorch/fp8_dtype.py | 67 ++ transformer_engine/pytorch/module/base.py | 16 + transformer_engine/pytorch/module/linear.py | 812 ++++++++++------ .../pytorch/tensor/float8_blockwise_tensor.py | 74 ++ .../pytorch/tensor/float8_tensor.py | 279 +++++- .../pytorch/tensor/mxfp8_tensor.py | 89 ++ .../pytorch/tensor/nvfp4_tensor.py | 85 ++ .../tensor/storage/float8_tensor_storage.py | 34 +- 10 files changed, 2003 insertions(+), 377 deletions(-) create mode 100644 transformer_engine/pytorch/fp8_dtype.py diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index e14aa39bbf..04cd7cc843 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -402,3 +402,50 @@ def fn(inp): f"Float8Tensor weight.grad shape {tuple(model.weight.grad.shape)} != " f"weight shape {tuple(model.weight.shape)}" ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_te_linear_compile_with_fp8_output(): + """torch.compile of ``te.Linear(..., fp8_output=True)``: forward returns + a :class:`Float8Tensor`. + + Exercises the output-rewrap path in + :mod:`transformer_engine.pytorch.dynamo`: the first user output is + declared ``Union[torch.Tensor, Float8Tensor]`` in ``output_annotations``, + and when an output quantizer is active the eager + fake paths must + rewrap the inner data tensors back into a ``Float8Tensor`` for the + user-facing slot. + + Backward through a subclass return value is a known PyTorch + ``torch.compile`` limitation (Dynamo / AOT autograd drop the + ``grad_fn`` on wrapper-subclass outputs of custom ops, so + ``out.sum().backward()`` errors with "element 0 of tensors does + not require grad and does not have a grad_fn"). The forward shape + + type assertions below are sufficient to exercise the rewrap; + grad-routing on FP8 outputs under compile is left as future work. + """ + dtype = torch.bfloat16 + device = "cuda" + fp8_recipe = recipe.Float8CurrentScaling() + + model = te.Linear(64, 32, params_dtype=dtype, device=device) + inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True) + + def fn(inp): + with te.autocast(recipe=fp8_recipe): + return model(inp, fp8_output=True) + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + out = compiled(inp) + assert isinstance(out, te.Float8Tensor), ( + f"expected Float8Tensor output, got {type(out).__name__}" + ) + assert out.shape == (32, 32) + # Dequantising outside the compiled region exercises the + # ``Float8Tensor`` machinery (scale + data + dtype all wired up + # by the rewrap) on the value returned from the compiled fn. + deq = out.dequantize() + assert deq.shape == (32, 32) + assert deq.dtype == dtype diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index f3c8af1e0b..a3163029d9 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -26,6 +26,7 @@ __all__ = [ "OpaqueSimpleMetadata", + "_DispatchTrigger", "_te_register_custom_op", ] @@ -59,6 +60,354 @@ def _decode_none(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return t +# --------------------------------------------------------------------------- # +# Output layout helpers +# --------------------------------------------------------------------------- # +# +# A user output of a TE custom op can be one of: +# * ``None`` -> 1 sentinel slot. +# * plain :class:`torch.Tensor` -> 1 slot. +# * wrapper-subclass tensor with +# ``__tensor_flatten__`` (e.g. +# :class:`Float8Tensor`) -> ``len(inner_names)`` slots. +# * pure-Python class with +# ``_torch_compile_flatten`` (e.g. +# :class:`Float8TensorStorage`) -> ``len(tensors)`` slots. +# +# At op-execution time, :func:`_format_fwd_result` splits each output via +# its flatten protocol and concatenates the inner plain tensors into the +# op's ``Tensor[]`` return. +# +# At call-site time (in :func:`forward_fn`), the layout for each user +# output is learned from a fake run of the user fwd impl (driven by +# :func:`_run_fake_for_proto` -- ``@torch._dynamo.disable``'d so the +# fake call doesn't pollute the surrounding FX graph). The layout +# carries the static (class, inner_names, metadata, shape, stride) +# tuple needed to reassemble the user-facing object from its real +# inner tensors emitted by the op. + + +def _extract_layout(proto_value: Any) -> Tuple[Any, ...]: + """Extract layout info from a fake proto output value. + + Returned tuple starts with a ``kind`` string: ``"none"``, ``"plain"``, + ``"subclass"``, or ``"storage"``; followed by kind-specific fields + consumed by :func:`forward_fn` and the autograd ``setup_context``. + + Used only on the legacy ``fwd_fake_impl``-driven path (see + :func:`_run_fake_for_proto`). The recommended path supplies an + explicit ``output_info_fn`` to :func:`_te_register_custom_op`, + which returns the same shape tuples directly without ever + materialising a fake prototype tensor. + """ + if proto_value is None: + return ("none",) + if isinstance(proto_value, torch.Tensor): + if type(proto_value) is not torch.Tensor and hasattr( + proto_value, "__tensor_flatten__" + ): + inner_names, meta = proto_value.__tensor_flatten__() + return ( + "subclass", + type(proto_value), + tuple(inner_names), + meta, + tuple(proto_value.shape), + tuple(proto_value.stride()), + ) + return ("plain",) + if hasattr(proto_value, "_torch_compile_flatten"): + meta, pg, tensors = proto_value._torch_compile_flatten() + return ("storage", type(proto_value), meta, pg, len(tensors)) + raise TypeError( + f"unsupported output type {type(proto_value).__name__}; expected " + "None / torch.Tensor / tensor subclass with __tensor_flatten__ / " + "class with _torch_compile_flatten." + ) + + +def _spec_slot_count(spec: Tuple[Any, ...]) -> int: + """Number of flat ``Tensor[]`` slots this output spec consumes. + + Accepts both legacy "layout" tuples (from :func:`_extract_layout`) + and the new ``output_info_fn`` spec tuples; the kind-indexed + structure is identical on the slot-count fields. + """ + kind = spec[0] + if kind == "subclass": + return len(spec[2]) # inner_names tuple + if kind == "storage": + return spec[4] # tensor_count + # "none" / "plain": 1 slot + return 1 + + +# Kept as an alias for the small set of internal helpers that still +# spell the old name (e.g. legacy ``_run_fake_for_proto`` paths). +_layout_slot_count = _spec_slot_count + + +def _reassemble_from_spec(spec: Tuple[Any, ...], chunk: List[Any]) -> Any: + """Reconstruct one user-facing output / saved object from its + flat-tensor chunk. + + ``chunk`` is the post-:func:`_decode_none` view of the op's + contribution to this output. Direct ``__tensor_unflatten__`` / + ``_torch_compile_do_unflatten`` is used here (rather than going + through :class:`_ToSubclassFn`); callers that need to interpose an + ``autograd.Function`` between the op output and the user-side + forward fn use :class:`_ToSubclassFn` explicitly. + """ + kind = spec[0] + if kind == "none": + return None + if kind == "plain": + return chunk[0] + if kind == "subclass": + _, cls, inner_names, meta, shape, stride = spec + inner_dict = dict(zip(inner_names, chunk)) + return cls.__tensor_unflatten__(inner_dict, meta, shape, stride) + # kind == "storage" + _, cls, meta, pg, _ = spec + real_tensors = [t for t in chunk if t is not None] + return cls._torch_compile_do_unflatten(meta, pg, real_tensors) + + +# --------------------------------------------------------------------------- # +# Fake-impl synthesis from ``output_info_fn`` allocation specs. +# --------------------------------------------------------------------------- # +# +# The recommended path for TE custom ops is to expose ``output_info_fn`` / +# ``bwd_output_info_fn`` -- pure-Python descriptors of the op's output layout. +# When such a descriptor returns alloc specs alongside the layout / saved +# bookkeeping, ``_te_register_custom_op`` auto-synthesizes a fake-impl from +# them: callers no longer need to maintain a separate +# ``fwd_fake_impl`` / ``backward_fake_impl`` that duplicates the same +# branching logic. The alloc-spec format is intentionally minimal: +# +# * ``None`` -> ``None`` (no slot allocated). +# * ``("plain", shape, dtype, device)`` -> ``torch.empty(...)``. +# * ``("quantized", quantizer, shape, dtype, device)`` +# -> ``quantizer.make_empty(...)``; +# returns either a tensor +# subclass or a +# ``QuantizedTensorStorage`` +# depending on the quantizer. +def _alloc_from_fake_spec(spec: Optional[Tuple[Any, ...]]) -> Any: + """Allocate one fake value from an alloc spec. + + See module-level commentary for the supported spec kinds. ``None`` + /-``("none",)`` is a sentinel meaning "no allocation"; the returned + value is ``None`` and the caller should skip the slot. + """ + if spec is None or spec[0] == "none": + return None + kind = spec[0] + if kind == "plain": + _, shape, dtype, device = spec + return torch.empty(tuple(shape), dtype=dtype, device=device) + if kind == "quantized": + _, quantizer, shape, dtype, device = spec + return quantizer.make_empty(tuple(shape), dtype=dtype, device=device) + raise ValueError(f"unsupported alloc-spec kind: {kind!r}") + + +def _make_fake_impl_from_output_info( + output_info_fn: Callable[[Any], Any], + num_outputs: int, +) -> Callable[[Any], Tuple[Any, ...]]: + """Build a forward fake-impl from an ``output_info_fn``. + + The synthesized fake-impl returns + ``(*user_outputs, tensors_to_save, tensor_objects, ctx_attrs)`` -- + the same shape :func:`_setup_context` expects from a hand-written + ``fwd_fake_impl``. ``user_outputs`` come from + ``fake_specs["user_outputs"]`` (one alloc spec per output), the + saved tuple from ``fake_specs["saved_tensors"]`` (``None`` if the + op did not save anything, e.g. ``is_grad_enabled=False``), and + ``tensor_objects`` / ``ctx_attrs`` are propagated verbatim from + the descriptor. + + The descriptor must return a 4-tuple + ``(user_specs, tensor_objects, ctx_attrs, fake_specs)``. + ``user_specs`` is unused here -- the synthesized fake-impl + delegates layout introspection to :func:`_setup_context`, which + re-invokes ``output_info_fn`` -- but having a single function + return both reassembly specs and alloc specs avoids duplicating + the branching logic. + """ + del num_outputs # informational only; layout comes from fake_specs. + + def _fake(args: Any) -> Tuple[Any, ...]: + _user_specs, tensor_objects, ctx_attrs, fake_specs = output_info_fn(args) + user_outputs = [ + _alloc_from_fake_spec(s) for s in fake_specs["user_outputs"] + ] + saved_specs = fake_specs.get("saved_tensors") + if saved_specs is None: + tensors_to_save: Any = None + else: + tensors_to_save = tuple(_alloc_from_fake_spec(s) for s in saved_specs) + return (*user_outputs, tensors_to_save, tensor_objects, ctx_attrs) + + return _fake + + +def _make_fake_impl_from_bwd_output_info( + bwd_output_info_fn: Callable[[Any], List[Optional[Tuple[Any, ...]]]], +) -> Callable[[Any], Tuple[Any, ...]]: + """Build a backward fake-impl from a ``bwd_output_info_fn``. + + The descriptor returns a flat list of alloc specs (or ``None``) + per gradient output, in the same order as ``backward_impl``'s + return tuple. The synthesized fake-impl just allocates one fake + per slot. + """ + + def _fake(bwd_args: Any) -> Tuple[Any, ...]: + specs = bwd_output_info_fn(bwd_args) + return tuple(_alloc_from_fake_spec(s) for s in specs) + + return _fake + + +class _ToSubclassFn(torch.autograd.Function): + """Construct a wrapper-subclass tensor from its inner plain tensors, + preserving autograd flow through ``__tensor_unflatten__``. + + Non-tensor args (``cls``, ``inner_names``, ``meta``, ``outer_shape``, + ``outer_stride``) are static constants. Dynamo / AOT capture them as + constants on the autograd.Function node; the variadic ``inner_tensors`` + are real / fake graph tensors emitted by the underlying custom op. + """ + + @staticmethod + def forward(ctx, cls, inner_names, meta, outer_shape, outer_stride, *inner_tensors): + """Reassemble ``cls`` from ``inner_tensors`` via ``__tensor_unflatten__``.""" + ctx.inner_names = inner_names + ctx.num_inner = len(inner_tensors) + inner_dict = dict(zip(inner_names, inner_tensors)) + return cls.__tensor_unflatten__(inner_dict, meta, outer_shape, outer_stride) + + @staticmethod + def backward(ctx, grad_output): + """Route ``grad_output`` back to its per-inner-name slots.""" + # Under AOTAutograd, ``grad_output`` typically arrives flattened + # via the subclass machinery; under eager it may be the subclass + # itself. Both paths support ``__tensor_flatten__``-driven routing. + if grad_output is not None and hasattr(grad_output, "__tensor_flatten__"): + names_in_grad, _ = grad_output.__tensor_flatten__() + grad_by_name = {n: getattr(grad_output, n) for n in names_in_grad} + grads = tuple(grad_by_name.get(n) for n in ctx.inner_names) + else: + # Fallback: route the lone grad to the first inner slot; the + # remaining slots (typically derived quantities like scale) + # get ``None``. + grads = (grad_output,) + (None,) * (ctx.num_inner - 1) + # First five returns correspond to the five leading non-tensor args + # to ``forward`` (``cls``, ``inner_names``, ``meta``, ``shape``, + # ``stride``); none of them carries a gradient. + return (None, None, None, None, None) + grads + + +# --------------------------------------------------------------------------- # +# Dispatch trigger +# --------------------------------------------------------------------------- # +# +# ``register_torch_dispatch(op, subclass, rule)`` only fires when at least +# one argument of the call is an instance of ``subclass``. To get the rule +# to fire *unconditionally* (so the user-facing wrapping logic -- output +# rewrapping into ``Float8Tensor`` etc. -- always runs in the same place +# regardless of whether the caller passed any "real" subclass instances), +# we add an internal ``_DispatchTrigger`` tensor as the last positional +# argument of every subclass-aware custom op. The trigger is a 0-element +# wrapper subclass; the schema slot is plain ``Tensor``, so the call is +# transparent to torch autograd / opcheck and the trigger never appears +# in user code. + +class _DispatchTrigger(torch.Tensor): + """Empty wrapper-subclass tensor used solely to force a + ``register_torch_dispatch`` rule to fire on every call to a + subclass-aware custom op. + + Designed to be installed as an ``nn.Module`` buffer (typically on + :class:`TransformerEngineBaseModule`) and threaded through the + custom op's argument dataclass as a regular ``torch.Tensor`` + field. Dynamo lifts ``nn.Module`` buffers as graph inputs, so the + trigger reaches the FX graph as a regular FakeTensor instead of a + Python-side constant -- this is what made every other "always-on + trigger" approach (module-level globals, fresh-per-call + constructors, ...) trip ``FakeTensorMode`` under + ``torch.compile``. + + ``__torch_dispatch__`` is a transparent passthrough: any op + accidentally invoked on a trigger falls back to the underlying op + with the trigger replaced by a plain empty tensor. The + ``register_torch_dispatch(outer_op, _DispatchTrigger, ...)`` + bindings installed by :func:`_te_register_custom_op` shadow this + for the specific ops we care about. + """ + + @staticmethod + def __new__(cls, _inner: Optional[torch.Tensor] = None) -> "_DispatchTrigger": + instance = torch.Tensor._make_wrapper_subclass( # pylint: disable=no-member + cls, (0,), dtype=_NONE_SENTINEL_DTYPE, device="cpu", + ) + # Attach a regular inner tensor so the subclass has something + # for Dynamo / FakeTensorMode to fake out via the standard + # subclass-flattening protocol. Without an inner tensor, + # Dynamo can't reproduce the subclass instance in the fake + # graph and the call to a ``torch.compile``'d module trips + # ``InternalTorchDynamoError: Wrapped Tensor must be this + # graph's fake``. + instance._inner = ( + _inner if _inner is not None + else torch.empty(0, dtype=_NONE_SENTINEL_DTYPE) + ) + return instance + + def __init__(self, _inner: Optional[torch.Tensor] = None) -> None: + # All work is done in ``__new__``; the optional ``_inner`` + # parameter is consumed there. The signature is mirrored here + # so direct ``__init__`` calls (e.g. via ``__tensor_unflatten__`` + # paths inside Dynamo) don't trip ``TypeError`` on the extra + # positional. + del _inner + + def __tensor_flatten__(self) -> Tuple[List[str], Dict[str, Any]]: + return ["_inner"], {} + + @staticmethod + def __tensor_unflatten__( + inner_tensors: Dict[str, torch.Tensor], + meta: Dict[str, Any], + outer_size, + outer_stride, + ) -> "_DispatchTrigger": + del meta, outer_size, outer_stride + return _DispatchTrigger(_inner=inner_tensors["_inner"]) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + def _strip(value: Any) -> Any: + if isinstance(value, _DispatchTrigger): + return torch.empty(0, dtype=_NONE_SENTINEL_DTYPE) + return value + + new_args = [_strip(a) for a in args] + new_kwargs = {k: _strip(v) for k, v in kwargs.items()} + return func(*new_args, **new_kwargs) + + def _stable_hash_for_caching(self) -> str: + # Required by AOT autograd's subclass cache. The trigger + # carries no semantically-relevant state, so a constant string + # is sufficient and ensures different trigger instances cache + # to the same compiled artifact. + return "te.dynamo._DispatchTrigger" + + # --------------------------------------------------------------------------- # # OpaqueSimpleMetadata # --------------------------------------------------------------------------- # @@ -1052,17 +1401,118 @@ def _restore_from_saved(tensor_objects: Any, saved_tensors: List[Any]) -> Any: return restore_from_saved(tensor_objects, saved_tensors) -def _format_fwd_result(result: Any, num_outputs: int) -> List[torch.Tensor]: +# --------------------------------------------------------------------------- # +# Forward-result packing +# --------------------------------------------------------------------------- # +# +# The custom-op schema is fixed at ``-> Tensor[]``: a single flat list of +# plain tensors. To return values that are *not* plain tensors (a +# :class:`Float8Tensor` wrapper subclass, a ``QuantizedTensorStorage`` +# workspace, ``None``...), :func:`_format_fwd_result` runs each user +# output through the relevant flatten protocol and concatenates the +# resulting inner tensors -- one variable-length chunk per output -- +# into the op's flat return. Saved-for-backward tensors follow in +# declaration order. +# +# At call-site time (:func:`forward_fn` and :func:`_setup_context`), +# the per-call output structure is learned from a fake run of the user +# fwd impl driven by :func:`_run_fake_for_proto` (see +# :func:`_extract_layout` and :func:`_layout_slot_count` near the top +# of this file). The static (class, inner-names, metadata, shape, +# stride) captured by each layout is enough to reassemble the +# user-facing object from its real inner tensors emitted by the op; +# subclass reconstruction goes through :class:`_ToSubclassFn` so the +# wrap is recorded on the autograd graph. + + +def _format_fwd_result( + result: Any, + num_outputs: int, +) -> List[torch.Tensor]: """Pack a fwd-impl return tuple into the op's ``Tensor[]`` payload. - The op concatenates ``[*output_tensors, *tensors_to_save]`` into a - single non-nullable list; ``None`` entries are smuggled through the - :func:`_encode_none` sentinel so ``register_autograd`` still - attaches a ``grad_fn`` to the result. + Each user output is decomposed into a deterministic number of inner + plain tensors (see :func:`_extract_layout`): + + * ``None`` -> 1 sentinel slot. + * plain Tensor -> 1 slot. + * subclass with + ``__tensor_flatten__`` -> ``len(inner_names)`` slots, in the order + declared by the class. + * storage with + ``_torch_compile_flatten`` -> ``len(tensors)`` slots. + + Saved-for-backward tensors follow in declaration order. ``None`` + entries on either side are smuggled through :func:`_encode_none` + so the schema stays non-nullable and ``register_autograd`` still + attaches a ``grad_fn`` to the op's outputs. + + The slot layout produced here must match exactly what + :func:`_extract_layout` predicts from a proto fake run, since the + call-site reassembly in :func:`forward_fn` uses the proto-derived + layout to slice this flat list back into user-facing objects. """ outputs = list(result[:num_outputs]) + flat: List[torch.Tensor] = [] + # Flatten user outputs *before* ``_prepare_for_saving`` -- the + # latter mutates storage instances in place (clears ``_data`` / + # ``_transpose`` / ``_scale_inv``), and the same object can be + # both a user output and a saved-for-backward entry. Doing the + # flatten first observes the original tensor state. + for value in outputs: + if value is None: + flat.append(_encode_none(None)) + elif isinstance(value, torch.Tensor): + if type(value) is not torch.Tensor and hasattr(value, "__tensor_flatten__"): + inner_names, _ = value.__tensor_flatten__() + flat.extend(_encode_none(getattr(value, n)) for n in inner_names) + else: + flat.append(_encode_none(value)) + elif hasattr(value, "_torch_compile_flatten"): + _, _, tensors = value._torch_compile_flatten() + flat.extend(_encode_none(t) for t in tensors) + else: + raise TypeError( + f"unsupported output type {type(value).__name__}; expected " + "None / torch.Tensor / tensor subclass with __tensor_flatten__ / " + "class with _torch_compile_flatten." + ) tensors_to_save, _ = _prepare_for_saving(result[num_outputs]) - return [_encode_none(t) for t in outputs + tensors_to_save] + flat.extend(_encode_none(t) for t in tensors_to_save) + return flat + + +@torch._dynamo.allow_in_graph +def _run_fake_for_proto( + fwd_fake_impl: Callable[[Any], Any], + fwd_obj: Any, + num_outputs: int, +) -> List[Any]: + """Execute ``fwd_fake_impl(fwd_obj)`` in isolation and return its + user-facing outputs to be used as prototypes for output layout + extraction. + + Isolated from any ambient ``FakeTensorMode`` (Dynamo / AOT's own + mode included) by stacking ``_disable_current_modes`` plus a + fresh ``FakeTensorMode``. None of the fake allocations performed + inside ``fwd_fake_impl`` pollute the surrounding FX graph; the + proto outputs leave the function as Python objects whose + metadata (class, ``__tensor_flatten__`` names, shape, ...) is + extracted into static layout tuples on the call site. + + Decorated with :func:`torch._dynamo.allow_in_graph` so that + Dynamo encodes the entire call as a single opaque FX node + instead of trying to trace the fake-allocation body. Unlike + ``@torch._dynamo.disable`` this does not graph-break under + ``fullgraph=True``. + """ + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.utils._python_dispatch import _disable_current_modes + + with _disable_current_modes(): + with FakeTensorMode(allow_non_fake_inputs=True): + result = fwd_fake_impl(fwd_obj) + return list(result[:num_outputs]) def _format_bwd_result( @@ -1235,6 +1685,7 @@ def _register_autograd_for_op( fwd_impl: Callable[[Any], Any], setup_context_user: Callable[..., None], backward_obj_type: type, + output_info_fn: Optional[Callable[[Any], Tuple[List[Tuple[Any, ...]], List[Tuple[Any, ...]], Any]]] = None, ) -> None: """Wire ``register_autograd`` on a forward op so its backward calls ``bwd_op_name``. @@ -1245,6 +1696,19 @@ def _register_autograd_for_op( registration is handled separately (by :func:`_register_kernel` for the inner tier and :func:`_register_outer_forwarder` for the outer tier). + + The op's ``Tensor[]`` return holds the flat layout produced by + :func:`_format_fwd_result` -- one chunk per user output / saved + tensor, sliced via: + + * ``output_info_fn(fwd_obj)`` -- the recommended path: a pure + Python function that returns the static + ``(user_specs, saved_specs, ctx_attrs)`` tuple. Traceable by + Dynamo / AOT, no fake tensor allocation involved. + * legacy ``fwd_fake_impl(fwd_obj)`` -- runs the user fake impl + and extracts layouts via :func:`_extract_layout`. Kept for + backwards compatibility with callers that haven't migrated to + ``output_info_fn`` yet. """ fwd_qualname = f"{_TE_OP_NAMESPACE}::{fwd_op_name}" @@ -1256,16 +1720,53 @@ def _setup_context(ctx, inputs, output): } kwargs = dict(zip(fwd_arg_names, inputs)) fwd_obj = _unpack(fwd_arg_type, kwargs, fwd_buckets) - fake_result = fake_for_setup(fwd_obj) - _, tensor_objects = _prepare_for_saving(fake_result[num_outputs]) - ctx_attrs = fake_result[num_outputs + 2] - - user_outputs = [_decode_none(t) for t in output[:num_outputs]] - op_saved_tensors = [_decode_none(t) for t in output[num_outputs:]] - tensors_to_save_from_forward = _restore_from_saved( - tensor_objects, - op_saved_tensors, - ) + + if output_info_fn is not None: + user_specs, tensor_objects, ctx_attrs, _fake_specs = output_info_fn(fwd_obj) + cursor = 0 + user_outputs: List[Any] = [] + for spec in user_specs: + n = _spec_slot_count(spec) + chunk = [_decode_none(t) for t in output[cursor:cursor + n]] + cursor += n + user_outputs.append(_reassemble_from_spec(spec, chunk)) + + # ``tensor_objects`` is the same shape :func:`prepare_for_saving` + # would produce: a list with one entry per element of + # ``tensors_to_save_from_forward`` -- ``None`` for plain + # tensors, a storage shell (with ``_data`` / ``_scale_inv`` + # / ... set to ``None``) for quantized storages. The shells + # are constructed inside ``output_info_fn`` via simple + # ``object.__new__`` + attribute writes so Dynamo can carry + # them as ``UserDefinedObjectVariable``s across the trace + # boundary. :func:`_restore_from_saved` reads each shell's + # ``restore_from_saved`` to consume the right number of + # slots from ``op_saved_tensors``. + op_saved_tensors = [_decode_none(t) for t in output[cursor:]] + tensors_to_save_from_forward = _restore_from_saved( + tensor_objects, + op_saved_tensors, + ) + else: + fake_result = fake_for_setup(fwd_obj) + # Learn output layouts from the fake result. + layouts = [_extract_layout(p) for p in fake_result[:num_outputs]] + + cursor = 0 + user_outputs = [] + for layout in layouts: + n = _spec_slot_count(layout) + chunk = [_decode_none(t) for t in output[cursor:cursor + n]] + cursor += n + user_outputs.append(_reassemble_from_spec(layout, chunk)) + + op_saved_tensors = [_decode_none(t) for t in output[cursor:]] + _, tensor_objects = _prepare_for_saving(fake_result[num_outputs]) + ctx_attrs = fake_result[num_outputs + 2] + tensors_to_save_from_forward = _restore_from_saved( + tensor_objects, + op_saved_tensors, + ) bwd_obj = backward_obj_type() tensors_to_save_from_setup = setup_context_user( @@ -1293,6 +1794,15 @@ def _autograd_backward(ctx, *grad_outputs): grads = [_decode_none(g) for g in bwd_op(*bwd_args_flat)] out: List[Any] = list(fwd_slot_defaults) tensor_list_lengths = getattr(ctx, "_te_fwd_tensor_list_lengths", {}) + # Pad every ``Tensor[]`` slot with ``None`` entries matching the + # corresponding forward input length. AOT's pytree check on the + # backward return rejects an empty list where the forward input + # was a non-empty list -- the list structure must match + # element-for-element. Grad-target slots below overwrite the + # first entry with the actual gradient. + for pos, length in tensor_list_lengths.items(): + if isinstance(out[pos], list): + out[pos] = [None] * length for (pos, as_list), g in zip(grad_targets, grads): if as_list: length = tensor_list_lengths.get(pos, 1) @@ -1313,31 +1823,46 @@ def _register_outer_forwarder( *, outer_op_name: str, inner_op_name: str, - arg_names: List[str], + buckets: Optional[List[_Bucket]] = None, + subclass_list: Optional[List[type]] = None, ) -> None: - """Register the outer op's default kernel + fake as a thin - forwarder into the inner op. - - The outer op must remain opaque to compilation (so - ``register_torch_dispatch`` rules installed on it actually fire); - we register the kernel against ``CompositeExplicitAutograd`` and - additionally register a fake impl that simply re-invokes the - inner op. For the subclass path the dispatch rule rewrites the - call into an inner call *before* this kernel/fake ever runs; the - forwarder is only consulted when no rule matches (i.e. the inputs - are plain tensors and / or plain ``QuantizedTensorStorage`` flat - slots that already match the inner schema directly). + """Register the outer op's default kernel + fake. + + Both kernel and fake forward to the inner op, optionally with an + in-place input flatten step for any registered subclass arg (so the + inner op's plain-tensor schema is satisfied). Outputs travel + untouched in their flat ``Tensor[]`` shape -- the user-facing + wrapping back into subclasses / storage happens in + :func:`forward_fn` via :class:`_ToSubclassFn`. """ inner_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), inner_op_name) - def _outer_kernel(*flat: Any) -> List[torch.Tensor]: - return inner_op(*flat) + input_flatten_enabled = bool(subclass_list) and buckets is not None - _TE_LIB.impl(outer_op_name, _outer_kernel, "CompositeExplicitAutograd") + if input_flatten_enabled: + slot_offsets = _collect_universal_slot_offsets(buckets) - def _outer_fake(*flat: Any) -> List[torch.Tensor]: - return inner_op(*flat) + def _flatten_all(new_args: List[Any]) -> None: + for sub in subclass_list: + _flatten_subclass_into_slots(new_args, slot_offsets, sub) + def _outer_kernel(*flat: Any) -> List[torch.Tensor]: + new_args = list(flat) + _flatten_all(new_args) + return inner_op(*new_args) + + def _outer_fake(*flat: Any) -> List[torch.Tensor]: + new_args = list(flat) + _flatten_all(new_args) + return inner_op(*new_args) + else: + def _outer_kernel(*flat: Any) -> List[torch.Tensor]: + return inner_op(*flat) + + def _outer_fake(*flat: Any) -> List[torch.Tensor]: + return inner_op(*flat) + + _TE_LIB.impl(outer_op_name, _outer_kernel, "CompositeExplicitAutograd") torch.library.register_fake( f"{_TE_OP_NAMESPACE}::{outer_op_name}", _outer_fake, lib=_TE_LIB ) @@ -1346,7 +1871,8 @@ def _outer_fake(*flat: Any) -> List[torch.Tensor]: def _te_register_custom_op( *, op_name: str, - num_outputs: int, + num_outputs: Optional[int] = None, + output_annotations: Optional[Sequence[Any]] = None, input_tensors_for_grad: List[str], fwd_arg_type: type, fwd_impl: Callable[[Any], Any], @@ -1357,6 +1883,15 @@ def _te_register_custom_op( backward_impl: Callable[[Any], Any], backward_fake_impl: Optional[Callable[[Any], Any]] = None, subclasses: Optional[Sequence[type]] = None, + output_info_fn: Optional[ + Callable[ + [Any], + Tuple[List[Tuple[Any, ...]], List[Any], Any, Dict[str, Any]], + ] + ] = None, + bwd_output_info_fn: Optional[ + Callable[[Any], List[Optional[Tuple[Any, ...]]]] + ] = None, ) -> Callable[..., Any]: """Register a TE module's forward + backward as a single torch custom op. @@ -1366,10 +1901,19 @@ def _te_register_custom_op( Op name used when registering with ``torch.library``. The namespace is fixed at module level (:data:`_TE_OP_NAMESPACE`). num_outputs - Number of user-facing tensor outputs returned by ``fwd_impl``. - The op concatenates ``[*output_tensors, *tensors_to_save]`` into - a single ``Tensor[]`` return; the wrapper uses ``num_outputs`` to - split the two halves on the way back out. + Number of user-facing outputs returned by ``fwd_impl``. May be + inferred from ``output_annotations`` if the latter is provided. + output_annotations + Optional per-output type annotation, e.g. + ``[Union[torch.Tensor, Float8Tensor], + Optional[Union[torch.Tensor, Float8TensorStorage]]]``. Kept + for documentation / backward compatibility. The runtime layout + of each output (plain / subclass / storage / ``None``) is + learned dynamically from a fake run of ``fwd_fake_impl`` + executed under ``_disable_current_modes`` and a fresh + ``FakeTensorMode``, so the annotation does not constrain the + flat ``Tensor[]`` payload anymore. If both are passed, + ``num_outputs`` must equal ``len(output_annotations)``. input_tensors_for_grad Names of forward-arg-type fields for which ``backward_impl`` returns gradients, in the same order. The wrapper uses this to @@ -1428,6 +1972,70 @@ def _te_register_custom_op( Optional fake counterpart of ``backward_impl``. Returns the same gradient tuple as ``backward_impl``, with fake tensors in place of the real gradients. + output_info_fn + Optional pure-Python layout descriptor for the op's outputs: + ``fn(fwd_obj) -> (user_specs, tensor_objects, ctx_attrs, fake_specs)``. + + * ``user_specs`` is a list, one entry per user output, where + each entry is: + + - ``("plain",)`` -- plain :class:`torch.Tensor` (or ``None`` + smuggled via :func:`_encode_none`). + - ``("none",)`` -- explicit ``None`` (single sentinel slot). + - ``("subclass", cls, inner_names, meta, shape, stride)`` -- + tensor subclass, reassembled via + ``cls.__tensor_unflatten__``. + - ``("storage", cls, meta, pg, tensor_count)`` -- non-tensor + storage, reassembled via ``cls._torch_compile_do_unflatten``. + + * ``tensor_objects`` is the structured descriptor that + :func:`prepare_for_saving` would produce on the user's + ``tensors_to_save_from_forward`` tuple: a Python list with + one entry per saved object, ``None`` for plain tensors and a + storage *shell* (typically built via + :meth:`Quantizer.create_save_shell` -- ``object.__new__`` + + attribute writes, no constructor logic) for quantized + storages. :func:`_restore_from_saved` uses these shells to + reconstruct the saved tuple from the flat ``op_saved_tensors`` + payload. + + * ``ctx_attrs`` is the non-tensor state attached to the + autograd context (passed through to ``setup_context``). + + * ``fake_specs`` is a dict with the alloc info needed to + synthesize a fake-impl when ``fwd_fake_impl`` is not + supplied (see :func:`_alloc_from_fake_spec`). Keys: + + - ``"user_outputs"`` -- list of alloc specs (one per user + output) used to materialise the fake tensors / subclasses + / storages returned by the synthesized fake-impl. + - ``"saved_tensors"`` -- ``None`` (no saved tensors, e.g. + ``is_grad_enabled=False``) or a list of alloc specs (one + per saved slot) used to build the synthesized + ``tensors_to_save`` tuple. + + When supplied, :func:`forward_fn` and the autograd + ``setup_context`` use this function instead of running + ``fwd_fake_impl`` to learn output layouts -- which is the only + way to keep the layout-extraction step traceable by Dynamo + under ``fullgraph=True`` (fake-impl execution typically tries + to construct subclasses with UDF arguments such as live + quantizers / pybind enums, graph-breaking the trace). + + Required to support tensor-subclass outputs (e.g. + :class:`Float8Tensor`) under ``torch.compile``. Optional for + plain-tensor ops, where the fake-impl path is still cheap. + bwd_output_info_fn + Optional pure-Python alloc descriptor for the backward op: + ``fn(bwd_obj) -> [alloc_spec_per_grad_output]``. Each entry is + either ``None`` (slot is ``None``), ``("plain", shape, dtype, + device)``, or ``("quantized", quantizer, shape, dtype, + device)``. When supplied (and ``backward_fake_impl`` is not), + :func:`_te_register_custom_op` synthesizes the backward + fake-impl by allocating one fake per spec via + :func:`_alloc_from_fake_spec`. Useful so the backward fake no + longer has to duplicate the gradient-shape derivation that + lives in the eager impl / its layout descriptor. Returns ------- @@ -1443,6 +2051,21 @@ def _te_register_custom_op( outer_bwd_name = f"{op_name}_backward" subclass_list = list(subclasses or ()) + if output_annotations is not None: + annotated_count = len(list(output_annotations)) + if num_outputs is not None and num_outputs != annotated_count: + raise ValueError( + "_te_register_custom_op: num_outputs=" + f"{num_outputs} does not match len(output_annotations)=" + f"{annotated_count}" + ) + num_outputs = annotated_count + if num_outputs is None: + raise ValueError( + "_te_register_custom_op requires either ``num_outputs`` or " + "``output_annotations``" + ) + # Precompute the bucket list once per arg type and capture it in # the registered closures. Re-deriving the bucket list inside a # compiled call would force :func:`_get_buckets` to read @@ -1500,6 +2123,25 @@ def _te_register_custom_op( # the inner bwd op. inner_fwd_qualname = f"{_TE_OP_NAMESPACE}::{inner_fwd_name}" inner_bwd_qualname = f"{_TE_OP_NAMESPACE}::{inner_bwd_name}" + + # Auto-synthesize the forward / backward fake impls from the + # alloc-spec descriptors when the caller did not hand-write them. + # The synthesized impls share branching with their layout + # counterparts (``output_info_fn`` / ``bwd_output_info_fn``) so + # there's exactly one place where every per-precision / per-mode + # condition lives. Hand-written fake impls still take precedence + # when supplied, so callers can stage the migration op-by-op. + effective_fwd_fake_impl = fwd_fake_impl + if effective_fwd_fake_impl is None and output_info_fn is not None: + effective_fwd_fake_impl = _make_fake_impl_from_output_info( + output_info_fn, num_outputs + ) + effective_bwd_fake_impl = backward_fake_impl + if effective_bwd_fake_impl is None and bwd_output_info_fn is not None: + effective_bwd_fake_impl = _make_fake_impl_from_bwd_output_info( + bwd_output_info_fn + ) + _register_kernel( op_name=inner_fwd_name, op_qualname=inner_fwd_qualname, @@ -1507,7 +2149,7 @@ def _te_register_custom_op( arg_names=fwd_arg_names, buckets=fwd_buckets, impl=fwd_impl, - fake_impl=fwd_fake_impl, + fake_impl=effective_fwd_fake_impl, format_result=lambda r: _format_fwd_result(r, num_outputs), ) _register_kernel( @@ -1517,7 +2159,7 @@ def _te_register_custom_op( arg_names=bwd_arg_names, buckets=bwd_buckets, impl=backward_impl, - fake_impl=backward_fake_impl, + fake_impl=effective_bwd_fake_impl, format_result=lambda g: _format_bwd_result(g, num_grad_inputs, inner_bwd_qualname), ) _register_autograd_for_op( @@ -1531,10 +2173,11 @@ def _te_register_custom_op( num_outputs=num_outputs, fwd_slot_defaults=fwd_slot_defaults, grad_targets=grad_targets, - fwd_fake_impl=fwd_fake_impl, + fwd_fake_impl=effective_fwd_fake_impl, fwd_impl=fwd_impl, setup_context_user=setup_context, backward_obj_type=backward_obj, + output_info_fn=output_info_fn, ) if subclass_list: @@ -1558,12 +2201,12 @@ def _te_register_custom_op( _register_outer_forwarder( outer_op_name=outer_fwd_name, inner_op_name=inner_fwd_name, - arg_names=fwd_arg_names, + buckets=fwd_buckets, + subclass_list=list(subclass_list), ) _register_outer_forwarder( outer_op_name=outer_bwd_name, inner_op_name=inner_bwd_name, - arg_names=bwd_arg_names, ) _register_autograd_for_op( fwd_op_name=outer_fwd_name, @@ -1576,22 +2219,21 @@ def _te_register_custom_op( num_outputs=num_outputs, fwd_slot_defaults=fwd_slot_defaults, grad_targets=grad_targets, - fwd_fake_impl=fwd_fake_impl, + fwd_fake_impl=effective_fwd_fake_impl, fwd_impl=fwd_impl, setup_context_user=setup_context, backward_obj_type=backward_obj, + output_info_fn=output_info_fn, ) - # Register ``torch_dispatch`` rules per subclass on both the - # outer fwd and the outer bwd op. The rule replaces the outer - # call entirely: it flattens every ``_UniversalTensorBucket`` - # slot whose ``name`` value is an instance of the registered - # subclass into ``(None, [inner tensors], process_group, - # opaque_meta)`` and invokes the inner op on the rewritten - # args. After the rewrite no subclass tensor remains in the - # call's arg list, and the autograd entry that ends up on the - # output graph is the inner op's (not the outer's), so the - # backward path goes through the inner pair only. + # Register per-subclass ``torch_dispatch`` rules. Each rule + # flattens every registered subclass arg into the + # ``_UniversalTensorBucket`` storage layout (so the inner op + # only ever sees plain tensors + opaque metadata) and forwards + # to the inner op. The flat ``Tensor[]`` output travels back + # untouched -- user-facing wrapping into subclasses / storage + # happens in :func:`forward_fn` via :class:`_ToSubclassFn`, + # outside the dispatcher. fwd_slot_offsets = _collect_universal_slot_offsets(fwd_buckets) bwd_slot_offsets = _collect_universal_slot_offsets(bwd_buckets) inner_fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), inner_fwd_name) @@ -1600,43 +2242,130 @@ def _te_register_custom_op( outer_bwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_bwd_name) outer_fwd_qualname = f"{_TE_OP_NAMESPACE}::{outer_fwd_name}" outer_bwd_qualname = f"{_TE_OP_NAMESPACE}::{outer_bwd_name}" - for subclass in subclass_list: - def _fwd_rule(mode, func, types, args, kwargs, subclass=subclass): - new_args = list(args) - _flatten_subclass_into_slots(new_args, fwd_slot_offsets, subclass) - return inner_fwd_op(*new_args) - - def _bwd_rule(mode, func, types, args, kwargs, subclass=subclass): - new_args = list(args) - _flatten_subclass_into_slots(new_args, bwd_slot_offsets, subclass) - return inner_bwd_op(*new_args) + def _flatten_all_subclasses(new_args: List[Any], slot_offsets: List[int]) -> None: + for sub in subclass_list: + _flatten_subclass_into_slots(new_args, slot_offsets, sub) + + def _fwd_rule(mode, func, types, args, kwargs): + del mode, func, types, kwargs + new_args = list(args) + _flatten_all_subclasses(new_args, fwd_slot_offsets) + return inner_fwd_op(*new_args) + + def _bwd_rule(mode, func, types, args, kwargs): + del mode, func, types, kwargs + new_args = list(args) + _flatten_all_subclasses(new_args, bwd_slot_offsets) + return inner_bwd_op(*new_args) + + # EXPERIMENT: temporarily disable trigger-based dispatch rules. + # torch.library.register_torch_dispatch( + # outer_fwd_qualname, _DispatchTrigger, _fwd_rule, lib=_TE_LIB + # ) + # torch.library.register_torch_dispatch( + # outer_bwd_qualname, _DispatchTrigger, _bwd_rule, lib=_TE_LIB + # ) + + # Also register per-subclass dispatch rules. The trigger + # rule above only fires when the dispatcher actually + # consults ``register_torch_dispatch`` (e.g. eager-mode calls + # where the trigger is the only subclass), which doesn't + # cover the case where Dynamo lifts a real wrapper-subclass + # parameter (such as a ``Float8Tensor`` weight) into the FX + # graph: in that case Dynamo invokes the registered fake + # impl instead, so we additionally bind the same rule body + # for every concrete subclass class so the eager dispatcher + # still picks it up alongside the fake impl handling the + # tracing path. + for sub in subclass_list: torch.library.register_torch_dispatch( - outer_fwd_qualname, subclass, _fwd_rule, lib=_TE_LIB + outer_fwd_qualname, sub, _fwd_rule, lib=_TE_LIB ) torch.library.register_torch_dispatch( - outer_bwd_qualname, subclass, _bwd_rule, lib=_TE_LIB + outer_bwd_qualname, sub, _bwd_rule, lib=_TE_LIB ) # ``QuantizedTensor.__torch_dispatch__`` falls back to # dequantizing all subclass args for any op it does not # recognise, which would defeat our - # ``register_torch_dispatch`` rules. Marking both outer ops - # as passthroughs makes QuantizedTensor delegate straight to - # ``super().__torch_dispatch__`` for them, where the - # registered dispatch rules are honoured. + # ``register_torch_dispatch`` rules and would also crash on + # FakeTensors (``tex.dequantize`` needs ``data_ptr``). Mark + # every op we register through this helper -- both tiers and + # both directions -- as passthroughs so QuantizedTensor + # delegates straight to ``super().__torch_dispatch__``. from transformer_engine.pytorch.quantized_tensor import ( _quantized_tensor_passthrough_ops, ) _quantized_tensor_passthrough_ops.add(outer_fwd_op.default) + _quantized_tensor_passthrough_ops.add(outer_bwd_op.default) + _quantized_tensor_passthrough_ops.add(inner_fwd_op.default) + _quantized_tensor_passthrough_ops.add(inner_bwd_op.default) fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_fwd_name) + # Use the auto-synthesized fake-impl when available so the proto + # path stays in sync with the kernel registration above. Falls back + # to ``fwd_impl`` when there is no fake-impl at all (legacy + # plain-tensor ops). + proto_fn = ( + effective_fwd_fake_impl if effective_fwd_fake_impl is not None else fwd_impl + ) def forward_fn(fwd_args): + # 1) Learn user-output layouts. + # ``output_info_fn`` is the recommended path: a pure Python + # function that returns the static spec tuples without ever + # materialising a fake prototype tensor. Traceable by Dynamo + # under ``fullgraph=True``. Fallback: legacy fake-impl run + # via ``_run_fake_for_proto`` (``@torch._dynamo.allow_in_graph`` + # so it stays opaque to Dynamo). + if output_info_fn is not None: + ( + user_specs, + _tensor_objects, + _ctx_attrs, + _fake_specs, + ) = output_info_fn(fwd_args) + else: + proto_outputs = _run_fake_for_proto(proto_fn, fwd_args, num_outputs) + user_specs = [_extract_layout(p) for p in proto_outputs] + + # 2) Invoke the op (graph node). Result is the flat ``Tensor[]`` + # payload produced by :func:`_format_fwd_result`. kwargs = _pack(fwd_args, fwd_buckets) - flat = [kwargs[name] for name in fwd_arg_names] - result = fwd_op(*flat) - outputs = [_decode_none(t) for t in result[:num_outputs]] + flat_in = [kwargs[name] for name in fwd_arg_names] + result = fwd_op(*flat_in) + + # 3) Slice the flat result by spec and reassemble each user + # output. Tensor subclasses go through :class:`_ToSubclassFn` + # so the construction is recorded on the autograd graph and + # Dynamo lifts it as an ``autograd.Function`` call; + # ``QuantizedTensorStorage``-style objects (no autograd of + # their own) are reconstructed directly. + cursor = 0 + outputs: List[Any] = [] + for spec in user_specs: + n = _spec_slot_count(spec) + chunk_raw = result[cursor:cursor + n] + cursor += n + chunk = [_decode_none(t) for t in chunk_raw] + kind = spec[0] + if kind == "none": + outputs.append(None) + elif kind == "plain": + outputs.append(chunk[0]) + elif kind == "subclass": + _, cls, inner_names, meta, shape, stride = spec + outputs.append( + _ToSubclassFn.apply( + cls, inner_names, meta, shape, stride, *chunk + ) + ) + else: # "storage" + _, cls, meta, pg, _slot_count = spec + real_tensors = [t for t in chunk if t is not None] + outputs.append(cls._torch_compile_do_unflatten(meta, pg, real_tensors)) + if num_outputs == 1: return outputs[0] return tuple(outputs) diff --git a/transformer_engine/pytorch/fp8_dtype.py b/transformer_engine/pytorch/fp8_dtype.py new file mode 100644 index 0000000000..2a0fc0cfda --- /dev/null +++ b/transformer_engine/pytorch/fp8_dtype.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Dynamo-friendly mirror of ``transformer_engine_torch.DType``. + +The C++-binded ``transformer_engine_torch.DType`` enum is opaque to +TorchDynamo (see ``UserDefinedObjectVariable(DType)`` graph-break under +``fullgraph=True``): Dynamo cannot proxy a pybind11 enum value as a +constant in the FX graph it builds for tensor-subclass constructors +(e.g. :class:`Float8Tensor`). + +:class:`FP8DType` is a Python :class:`enum.IntEnum` that mirrors +``tex.DType`` 1:1 by integer value. Because :class:`IntEnum` derives +from :class:`int`, Dynamo recognises it as a ``ConstantVariable`` and +captures it as a static constant on subclass-constructor calls inside +a compiled region. Conversion to/from the C++ enum is one +``int(...)`` call. +""" +from __future__ import annotations +from enum import IntEnum + +import transformer_engine_torch as tex + + +class FP8DType(IntEnum): + """Python mirror of :class:`transformer_engine_torch.DType` (int values). + + Values match :class:`tex.DType` 1:1 so that ``int(FP8DType.x) == + int(tex.DType.x)`` for every member. Use :func:`to_tex` to bridge + back to the C++ enum at pybind boundaries. + """ + + kByte = int(tex.DType.kByte) + kInt32 = int(tex.DType.kInt32) + kFloat32 = int(tex.DType.kFloat32) + kFloat16 = int(tex.DType.kFloat16) + kBFloat16 = int(tex.DType.kBFloat16) + kFloat8E4M3 = int(tex.DType.kFloat8E4M3) + kFloat8E5M2 = int(tex.DType.kFloat8E5M2) + kFloat4E2M1 = int(tex.DType.kFloat4E2M1) + + +# Precomputed at module load so Dynamo doesn't have to trace +# ``IntEnum.__new__`` / ``tex.DType.__int__`` inside compiled regions +# (both recurse through Python's internal inspect machinery and exhaust +# Dynamo's frame stack). +_TEX_TO_FP8DTYPE = {member.value: member for member in FP8DType} +_TEX_TO_FP8DTYPE_BY_TEX = {tex.DType(v): m for v, m in _TEX_TO_FP8DTYPE.items()} + + +def to_tex(d) -> tex.DType: + """Coerce ``d`` (``FP8DType`` / ``tex.DType`` / int) to ``tex.DType``.""" + if isinstance(d, tex.DType): + return d + return tex.DType(int(d)) + + +def from_tex(d: tex.DType) -> FP8DType: + """Coerce a ``tex.DType`` (or int matching one of its enum values) to + :class:`FP8DType` via a precomputed lookup table. + """ + if isinstance(d, FP8DType): + return d + if isinstance(d, tex.DType): + return _TEX_TO_FP8DTYPE_BY_TEX[d] + return _TEX_TO_FP8DTYPE[int(d)] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 58f42781e0..fb70b5f4e7 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -854,6 +854,22 @@ def __init__(self, name: Optional[str] = None) -> None: self._output_quantizer_role: Optional[QuantizerRole] = None self._grad_input_quantizer_role: Optional[QuantizerRole] = None + # Empty wrapper-subclass tensor threaded through every TE + # custom op as a regular ``Tensor`` argument. Its sole purpose + # is to make ``register_torch_dispatch`` rules + # (registered in :func:`transformer_engine.pytorch.dynamo._te_register_custom_op` + # against ``_DispatchTrigger``) fire on every call to a + # subclass-aware op, even when no other argument is a + # registered subclass. Routed via ``register_buffer`` so that + # Dynamo lifts it as a regular graph input under + # ``torch.compile`` instead of internalising it as a + # Python-side constant (which would then trip + # ``FakeTensorMode``). + from transformer_engine.pytorch.dynamo import _DispatchTrigger + self.register_buffer( + "_te_dispatch_trigger", _DispatchTrigger(), persistent=False + ) + if not TEDebugState.debug_enabled: TEDebugState.initialize() self._validate_name() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5fe7602899..0468ef6ad5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -18,7 +18,6 @@ from transformer_engine.pytorch.torch_version import torch_version from .base import ( - fake_quantize_weight, fill_userbuffers_buffer_for_all_gather, get_dummy_wgrad, get_ub, @@ -27,6 +26,7 @@ _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, + _is_weight_workspace_valid, ) from ._common import noop_cat, WeightGradStore from ..quantization import FP8GlobalStateManager, QuantizerRole @@ -34,7 +34,6 @@ cast_if_needed, clear_tensor_data, divide, - fake_cast_if_needed, init_method_constant, needs_quantized_gemm, assert_dim_for_fp8_exec, @@ -73,6 +72,7 @@ Float8Quantizer, Float8Tensor, ) +from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.utils import clear_columnwise_cache, is_custom from ..export import is_in_onnx_export_mode, assert_warmed_up @@ -157,6 +157,13 @@ class LinearFwdArgs: cpu_offloading: bool is_grad_enabled: bool + # Always set to ``self._te_dispatch_trigger`` of the calling + # module: a tiny ``_DispatchTrigger`` wrapper-subclass tensor that + # exists only to make ``register_torch_dispatch`` rules fire on + # every call to the outer custom op, so output rewrapping can run + # in a single place. See :class:`transformer_engine.pytorch.dynamo._DispatchTrigger`. + _te_dispatch_trigger: Optional[torch.Tensor] = None + @dataclass(slots=True) class LinearBwdArgs: @@ -226,6 +233,12 @@ class LinearBwdArgs: cpu_offloading: bool = False owns_input: bool = False + # See :class:`LinearFwdArgs._te_dispatch_trigger`. Set in + # ``_linear_setup_ctx`` from the corresponding forward-args field + # so the backward op carries the same trigger and its always-on + # ``register_torch_dispatch`` rule fires too. + _te_dispatch_trigger: Optional[torch.Tensor] = None + # --- Per-backward scratch state (populated inside _linear_backward) --- ub_obj_gradout: Optional[Any] = None @@ -624,275 +637,6 @@ def _linear_forward_impl( return out, new_weight_workspace, tensors_to_save_from_forward, None, ctx_attrs -def _linear_forward_fake_impl( - args: LinearFwdArgs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple], None, Optional[Dict]]: - """Fake :func:`_linear_forward_impl` for torch custom-op shape inference. - - Mirrors the real control flow and ``set_usage`` / ``update_usage`` - calls, but replaces computation with empty tensors and skips side - effects irrelevant for shape inference (CPU offload, calibration, - NCCL/UB collectives, ``clear_tensor_data``, FSDP scatter). - """ - - # The few locals below are mutated later in the function; everything - # else is read directly off ``args``. This shape-inference helper is - # not on a hot path, so we don't bother caching attribute lookups. - save_original_input = args.save_original_input - if args.backward_override == "high_precision": - save_original_input = True - weight_quantizer = args.weight_quantizer - - out_features, in_features = args.weight.shape - assert args.inp.shape[-1] == in_features, "GEMM not possible" - - tp_world_size = get_distributed_world_size(args.tp_group) - backward_needs_input = args.is_grad_enabled and args.weight_requires_grad - with_input_all_gather_nccl = ( - args.parallel_mode == "column" - and args.sequence_parallel - and not args.ub_overlap_ag_fprop - ) - - # ------------------------------------------------------ - # Prepare input tensor - # ------------------------------------------------------ - # ``inputmat`` may become a ``QuantizedTensorStorage`` (which does not - # always expose ``.shape``), so track the logical shape separately. - inputmat = args.inp - inputmat_shape = list(args.inp.shape) - inputmat_total = None - inputmat_total_shape: List[int] = inputmat_shape - own_quantized_input = False - if args.fp8: - assert_dim_for_fp8_exec(inputmat, args.weight) - if save_original_input: - assert not isinstance( - args.input_quantizer, Float8Quantizer - ), "DelayedScaling recipe is not supported with save_original_input" - - if with_input_all_gather_nccl or args.ub_overlap_ag_fprop: - - if args.fp8 or args.debug: - if args.input_quantizer is None: - raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorStorage) and not args.custom: - own_quantized_input = True - args.input_quantizer.set_usage( - rowwise=True, - columnwise=backward_needs_input and args.backward_override is None, - ) - if isinstance( - args.input_quantizer, (Float8CurrentScalingQuantizer, Float8Quantizer) - ): - args.input_quantizer.set_usage(columnwise=False) - if save_original_input: - args.input_quantizer.set_usage(columnwise=False) - own_quantized_input = False - inputmat = args.input_quantizer.make_empty( - inputmat.shape, - dtype=args.activation_dtype, - device=inputmat.device, - ) - else: - inputmat = fake_cast_if_needed(args.inp, args.activation_dtype) - - # Initialize gathered input tensor (interleaved set_usage stays). - quantizer = None - if args.fp8 or args.debug: - quantizer = args.input_quantizer - quantizer.set_usage(rowwise=True, columnwise=False) - - gathered_shape = list(inputmat_shape) - gathered_shape[0] *= tp_world_size - inputmat_total_shape = gathered_shape - if quantizer is not None: - inputmat_total = quantizer.make_empty( - gathered_shape, - dtype=args.activation_dtype, - device=args.inp.device, - ) - else: - inputmat_total = torch.empty( - gathered_shape, dtype=args.activation_dtype, device=args.inp.device - ) - - else: - if args.fp8 or args.debug: - if isinstance(inputmat, QuantizedTensorStorage): - inputmat.update_usage(rowwise_usage=True) - else: - if args.input_quantizer is None: - raise ValueError("Missing quantizer for input tensor") - args.input_quantizer.set_usage( - rowwise=True, - columnwise=( - backward_needs_input - and not save_original_input - and args.backward_override is None - ), - ) - inputmat = args.input_quantizer.make_empty( - inputmat.shape, - dtype=args.activation_dtype, - device=inputmat.device, - ) - own_quantized_input = True - else: - inputmat = fake_cast_if_needed(args.inp, args.activation_dtype) - inputmat_total = inputmat - inputmat_total_shape = inputmat_shape - - # ------------------------------------------------------ - # Prepare weight tensor - # ------------------------------------------------------ - new_weight_workspace = None - weightmat = args.weight - if args.fp8 or args.debug: - if weight_quantizer is not None and ( - not isinstance(args.weight, QuantizedTensor) or args.debug - ): - columnwise_usage = ( - args.is_grad_enabled and args.input_requires_grad and not args.is_fsdp2 - ) - if args.backward_override is not None: - columnwise_usage = False - if not columnwise_usage: - columnwise_usage = ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ) - weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - elif isinstance(args.weight, QuantizedTensor): - weight_quantizer = args.weight._quantizer - weightmat, new_weight_workspace = fake_quantize_weight( - tensor=args.weight, - quantizer=weight_quantizer, - workspace=args.weight_workspace, - fsdp_group=args.fsdp_group, - workspace_dtype=args.activation_dtype, - cache=args.cache_weight, - ) - weightmat.update_usage(rowwise_usage=True) - else: - weightmat = fake_cast_if_needed(weightmat, args.activation_dtype) - - # Cast bias to expected dtype - bias_dtype = args.activation_dtype - if needs_quantized_gemm(inputmat_total) and args.activation_dtype == torch.float32: - bias_dtype = torch.bfloat16 - bias = fake_cast_if_needed(args.bias, bias_dtype) if args.bias is not None else args.bias - - # Configure output quantizer - if args.output_quantizer is not None: - args.output_quantizer.set_usage(rowwise=True, columnwise=False) - - # Output buffer for Userbuffers reduce-scatter (allocated with the - # post-RS shape so downstream consumers see consistent dimensions). - reduce_scatter_out = None - if args.ub_overlap_rs_fprop: - out_shape = list(args.inp.shape) - out_shape[0] //= tp_world_size - out_shape[-1] = out_features - reduce_scatter_out = torch.empty( - out_shape, dtype=args.activation_dtype, device=args.inp.device - ) - - # ------------------------------------------------------ - # Forward GEMM (fake) - # ------------------------------------------------------ - gemm_out_shape = list(inputmat_total_shape[:-1]) + [out_features] - if args.output_quantizer is not None: - gemm_out = args.output_quantizer.make_empty( - gemm_out_shape, dtype=args.activation_dtype, device=args.inp.device - ) - else: - gemm_out = torch.empty( - gemm_out_shape, dtype=args.activation_dtype, device=args.inp.device - ) - - if with_input_all_gather_nccl: - inputmat_total = None - - # ------------------------------------------------------ - # Prepare output tensor (mirror the real comm path with shape-only ops) - # ------------------------------------------------------ - if args.ub_overlap_rs_fprop: - out = reduce_scatter_out - elif args.parallel_mode == "row" and args.tp_size > 1: - out = gemm_out - if args.sequence_parallel: - new_shape = list(out.shape) - new_shape[0] //= tp_world_size - if args.output_quantizer is not None: - out = args.output_quantizer.make_empty( - new_shape, dtype=out.dtype, device=out.device - ) - else: - out = torch.empty(new_shape, dtype=out.dtype, device=out.device) - # allreduce / symmetric_all_reduce do not change shape. - else: - out = gemm_out - - # Prepare backward state - tensors_to_save_from_forward = None - ctx_attrs = None - - if args.is_grad_enabled: - if save_original_input: - inputmat = args.inp - - if ( - backward_needs_input - and own_quantized_input - and isinstance(inputmat, QuantizedTensorStorage) - ): - if args.backward_override is not None: - inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) - elif ( - args.backward_input_needs_gather - and weight_quantizer.supports_only_rowwise_all_gather() - ): - inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) - else: - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) - - saved_inputmat = None - if backward_needs_input: - saved_inputmat = inputmat - - if args.fsdp_group is not None: - raise NotImplementedError( - "Fake Linear forward does not support manual TE FSDP " - "(fsdp_group is not None); use FSDP2 or MCore FSDP." - ) - fsdp_shapes = [] - - wt_save = weightmat - if args.is_fsdp2 and weightmat is not args.weight: - wt_save = None - - saved_tensor_aliases = ( - "inp" if saved_inputmat is args.inp else None, - "weight" if wt_save is args.weight else None, - "weight", - "bias" if bias is not None else None, - ) - tensors_to_save_from_forward = ( - None if saved_tensor_aliases[0] is not None else saved_inputmat, - None if saved_tensor_aliases[1] is not None else wt_save, - None, - None if saved_tensor_aliases[3] is not None else bias, - ) - - ctx_attrs = { - "fsdp_shapes": fsdp_shapes, - "saved_tensor_aliases": saved_tensor_aliases, - } - - return out, new_weight_workspace, tensors_to_save_from_forward, None, ctx_attrs - - def _linear_setup_ctx( bwd_args: LinearBwdArgs, fwd_args: LinearFwdArgs, @@ -977,6 +721,7 @@ def _linear_setup_ctx( # Misc bwd_args.cpu_offloading = fwd_args.cpu_offloading + bwd_args._te_dispatch_trigger = fwd_args._te_dispatch_trigger if backward_override is not None: bwd_args.fp8 = False @@ -1541,16 +1286,22 @@ def wgrad_gemm( ) -def _linear_backward_fake_impl( +def _linear_backward_output_info( args: LinearBwdArgs, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fake :func:`_linear_backward` for torch custom-op shape inference. - - Backward output shapes/dtypes are deterministic, so we just allocate - empty tensors of the right shape. ``grad_input_quantizer.set_usage`` - is preserved because it influences ``dgrad``'s ``make_empty``. - Manual TE FSDP is unsupported; FSDP2 / MCore FSDP go through the - standard path. +) -> List[Optional[Tuple[Any, ...]]]: + """Pure-Python alloc-spec descriptor for :func:`_linear_backward`. + + Returns a list of three alloc specs -- one per gradient output + ``(wgrad, dgrad, grad_bias)`` -- consumed by the auto-synthesized + backward fake-impl in :func:`_make_fake_impl_from_bwd_output_info`. + Replaces the previously hand-written + ``_linear_backward_fake_impl``: gradient shapes/dtypes are + deterministic, so the descriptor just encodes them as alloc + tuples (``("plain", ...)`` /``("quantized", ...)``) instead of + allocating fake tensors. ``set_usage`` on + ``grad_input_quantizer`` is preserved because it influences + ``dgrad``'s downstream ``make_empty``. Manual TE FSDP is + unsupported; FSDP2 / MCore FSDP go through the standard path. """ if args.fsdp_group is not None: @@ -1565,30 +1316,29 @@ def _linear_backward_fake_impl( if args.grad_input_quantizer is not None: args.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - def _empty(shape, quantizer): + activation_dtype = args.activation_dtype + device = args.grad_output.device + + def _alloc( + shape: Tuple[int, ...], quantizer: Any + ) -> Tuple[Any, ...]: if quantizer is not None: - return quantizer.make_empty( - shape, dtype=args.activation_dtype, device=args.grad_output.device - ) - return torch.empty( - shape, dtype=args.activation_dtype, device=args.grad_output.device - ) + return ("quantized", quantizer, tuple(shape), activation_dtype, device) + return ("plain", tuple(shape), activation_dtype, device) - wgrad = None + wgrad_alloc: Optional[Tuple[Any, ...]] = None if args.requires_wgrad and not args.fuse_wgrad_accumulation: - wgrad = _empty([out_features, in_features], args.grad_weight_quantizer) + wgrad_alloc = _alloc((out_features, in_features), args.grad_weight_quantizer) - dgrad = None + dgrad_alloc: Optional[Tuple[Any, ...]] = None if args.requires_dgrad: - dgrad = _empty(list(args.inp_shape), args.grad_input_quantizer) + dgrad_alloc = _alloc(tuple(args.inp_shape), args.grad_input_quantizer) - grad_bias = None + grad_bias_alloc: Optional[Tuple[Any, ...]] = None if args.use_bias and args.requires_wgrad: - grad_bias = torch.empty( - [out_features], dtype=args.activation_dtype, device=args.grad_output.device - ) + grad_bias_alloc = ("plain", (out_features,), activation_dtype, device) - return wgrad, dgrad, grad_bias + return [wgrad_alloc, dgrad_alloc, grad_bias_alloc] class _Linear(torch.autograd.Function): @@ -1680,18 +1430,480 @@ def backward( # ``torch.compile`` can trace through it without entering the eager # ``torch.autograd.Function`` machinery. Used by :meth:`Linear.forward` # under ``torch.compiler.is_compiling()``. +def _linear_forward_output_info( + args: LinearFwdArgs, +) -> Tuple[List[Tuple[Any, ...]], List[Any], Dict[str, Any], Dict[str, Any]]: + """Pure-Python output-layout descriptor for the linear forward. + + Returns ``(user_specs, tensor_objects, ctx_attrs, fake_specs)`` -- + the static, Dynamo-traceable single source of truth for the + forward op's output layout, saved-tensor bookkeeping, and + fake-impl allocation hints. Replaces the previously hand-written + ``_linear_forward_fake_impl``: :func:`_te_register_custom_op` now + auto-synthesizes the fake-impl from ``fake_specs`` via + :func:`_make_fake_impl_from_output_info`, so every per-precision / + per-mode condition lives in exactly one place. + + Why a separate descriptor (vs. a hand-written fake-impl): + constructing real :class:`Float8Tensor` / + :class:`MXFP8TensorStorage` / ... instances inside a fake-impl + relies on the live quantizers, which under ``fullgraph=True`` + Dynamo refuses to trace through (live quantizers are + :class:`UserDefinedObjectVariable`, ``tex.DType`` is a pybind + enum, ...). The descriptor instead emits: + + * ``user_specs`` / ``tensor_objects`` -- pure-Python tuples and + ``object.__new__``-built shells (via + :meth:`Quantizer.create_metadata` / + :meth:`Quantizer.create_save_shell`); consumed by + :func:`forward_fn` and :func:`_setup_context` to reassemble + subclasses and restore saved storages. + * ``fake_specs`` -- alloc tuples + ``("plain", shape, dtype, device)`` / + ``("quantized", quantizer, shape, dtype, device)``; consumed by + the auto-synthesized fake-impl (which only runs under + ``FakeTensorMode``, never under Dynamo's trace -- so live + quantizers / pybind enums are fine here). + + The four return values keep the same branching: every + ``set_usage`` / ``update_usage`` side effect on the live + quantizers happens once in this function and stays consistent + across the layout / fake / forward paths; downstream code + (especially backward) reads the post-forward usage flags off + the same quantizer instance. + """ + fp8 = args.fp8 + debug = args.debug + fp8_or_debug = fp8 or debug + activation_dtype = args.activation_dtype + output_quantizer = args.output_quantizer + input_quantizer = args.input_quantizer + weight_quantizer = args.weight_quantizer + weight = args.weight + inp = args.inp + bias = args.bias + + save_original_input = args.save_original_input + if args.backward_override == "high_precision": + save_original_input = True + + out_features, in_features = weight.shape + assert inp.shape[-1] == in_features, "GEMM not possible" + + tp_world_size = get_distributed_world_size(args.tp_group) + backward_needs_input = args.is_grad_enabled and args.weight_requires_grad + with_input_all_gather_nccl = ( + args.parallel_mode == "column" + and args.sequence_parallel + and not args.ub_overlap_ag_fprop + ) + + # ------------------------------------------------------------------ + # Input pipeline -- mirror :func:`_linear_forward_impl`'s + # ``set_usage`` calls and track which of three end-states the + # ``saved_inputmat`` slot will land in: + # + # * ``inputmat_aliases_inp`` -- saved value IS ``args.inp`` + # (impl-side ``saved_tensor_aliases[0] = "inp"``, slot stored + # as ``None`` and resolved back to ``args.inp`` in + # ``_linear_setup_ctx``). + # * ``inputmat_is_storage`` -- saved value is a fresh + # ``QuantizedTensorStorage`` (created here only as a tensor-free + # shell; the impl produces the real one; the auto-synthesized + # fake-impl allocates a fake one from the slot's alloc spec). + # * neither -- saved value is a plain ``Tensor`` (the cast result). + # + # The branches below match :func:`_linear_forward_impl` + # line-for-line; comments cross-reference the mirrored block when + # not obvious. + # ------------------------------------------------------------------ + inputmat_is_storage = False + inputmat_aliases_inp = False + own_quantized_input = False + inputmat_total_shape: List[int] = list(inp.shape) + + if with_input_all_gather_nccl or args.ub_overlap_ag_fprop: + if fp8_or_debug: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if not isinstance(inp, QuantizedTensorStorage) and not args.custom: + own_quantized_input = True + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and args.backward_override is None, + ) + if isinstance( + input_quantizer, (Float8CurrentScalingQuantizer, Float8Quantizer) + ): + input_quantizer.set_usage(columnwise=False) + if save_original_input: + input_quantizer.set_usage(columnwise=False) + own_quantized_input = False + inputmat_is_storage = True + else: + inputmat_aliases_inp = inp.dtype == activation_dtype + # ``inputmat_total`` only matters for the GEMM output shape; the + # all-gather inflates the leading dim by ``tp_world_size``. + inputmat_total_shape = list(inp.shape) + inputmat_total_shape[0] *= tp_world_size + else: + if fp8_or_debug: + if isinstance(inp, QuantizedTensorStorage): + # In-place ``update_usage`` on the original storage; + # ``inputmat is args.inp`` stays true downstream. + inp.update_usage(rowwise_usage=True) + inputmat_is_storage = True + inputmat_aliases_inp = True + else: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage( + rowwise=True, + columnwise=( + backward_needs_input + and not save_original_input + and args.backward_override is None + ), + ) + inputmat_is_storage = True + own_quantized_input = True + else: + inputmat_aliases_inp = inp.dtype == activation_dtype + + # ``save_original_input`` (and ``backward_override == "high_precision"`` + # in particular) flips ``inputmat`` back to ``args.inp`` at the + # tail of the impl, overriding whatever the input pipeline + # produced above. We mirror that here by forcing the alias bit + # so the saved slot tracks the impl's final ``saved_inputmat is + # args.inp`` check. + if save_original_input: + inputmat_aliases_inp = True + inputmat_is_storage = False + + # ------------------------------------------------------------------ + # Weight pipeline -- mirror of :func:`_linear_forward_impl`'s + # ``quantize_weight`` / ``cast_if_needed`` branches. Tracks the + # same three end-states for ``wt_save``: + # + # * ``weightmat_aliases_weight`` -- ``saved_tensor_aliases[1]`` is + # ``"weight"`` and the slot ends up resolving back to + # ``args.weight`` inside ``_linear_setup_ctx``. + # * ``weightmat_is_storage`` (and not aliased) -- a freshly built + # :class:`QuantizedTensorStorage` (real one in the impl, a + # tensor-free shell here for saved-slot bookkeeping). + # * neither -- a plain cast ``Tensor``. + # + # ``new_weight_workspace_spec`` is the user-output [1] slot: + # non-``("none",)`` only on the cache-miss + ``cache_weight`` + # combination, mirroring the weight-workspace caching branch in + # :func:`_linear_forward_impl`. + # ------------------------------------------------------------------ + new_weight_workspace_spec: Tuple[Any, ...] = ("none",) + weightmat_is_storage = False + weightmat_aliases_weight = False + if fp8_or_debug: + if weight_quantizer is not None and ( + not isinstance(weight, QuantizedTensor) or debug + ): + columnwise_usage = ( + args.is_grad_enabled and args.input_requires_grad and not args.is_fsdp2 + ) + if args.backward_override is not None: + columnwise_usage = False + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + elif isinstance(weight, QuantizedTensor): + weight_quantizer = weight._quantizer + + if isinstance(weight, QuantizedTensorStorage): + # ``_linear_forward_impl`` short-circuits the weight pipeline + # on a primary-quantized weight: ``weightmat = weight``. + weightmat_is_storage = True + weightmat_aliases_weight = True + else: + weightmat_is_storage = True + # ``new_weight_workspace`` is non-``None`` only when we miss + # the workspace cache *and* the caller asked us to publish + # the freshly-built workspace back. + workspace = args.weight_workspace + if workspace is not None and not _is_weight_workspace_valid( + workspace, weight_quantizer + ): + workspace = None + if workspace is None and args.cache_weight: + cls, meta, pg, count = weight_quantizer.create_storage_metadata( + shape=weight.shape, + fake_dtype=activation_dtype, + device=weight.device, + requires_grad=False, + as_tensor=False, + ) + new_weight_workspace_spec = ("storage", cls, meta, pg, count) + # ``weightmat.update_usage(rowwise_usage=True)`` runs in the + # impl after this point; that's a no-op on the layout flags + # we track here (we already requested ``rowwise=True`` above). + else: + weightmat_aliases_weight = weight.dtype == activation_dtype + + # ------------------------------------------------------------------ + # Output configuration + # ------------------------------------------------------------------ + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Compute the GEMM-output shape and the post-comm shape that + # leaves the op. + gemm_out_shape: List[int] = list(inputmat_total_shape[:-1]) + [out_features] + if args.ub_overlap_rs_fprop: + out_shape: List[int] = list(inp.shape) + out_shape[0] //= tp_world_size + out_shape[-1] = out_features + elif args.parallel_mode == "row" and args.tp_size > 1 and args.sequence_parallel: + out_shape = list(gemm_out_shape) + out_shape[0] //= tp_world_size + else: + out_shape = list(gemm_out_shape) + + # ------------------------------------------------------------------ + # Build user-output spec [0] -- the GEMM result. + # ------------------------------------------------------------------ + if output_quantizer is None: + out_spec: Tuple[Any, ...] = ("plain",) + else: + # The only subclass we declare in ``output_annotations`` is + # :class:`Float8Tensor`; other quantizer families flow their + # workspace through ``new_weight_workspace`` instead. + inner_names, meta = output_quantizer.create_metadata( + fake_dtype=activation_dtype, + requires_grad=False, + ) + stride = _contiguous_stride(out_shape) + out_spec = ( + "subclass", + Float8Tensor, + inner_names, + meta, + tuple(out_shape), + stride, + ) + + user_specs: List[Tuple[Any, ...]] = [out_spec, new_weight_workspace_spec] + + # ------------------------------------------------------------------ + # Saved-for-backward tensor_objects + saved_tensor_aliases + # ------------------------------------------------------------------ + tensor_objects: List[Any] = [None, None, None, None] + saved_inputmat_alias: Optional[str] = None + wt_save_alias: Optional[str] = None + bias_alias: Optional[str] = None + + if args.is_grad_enabled: + # Post-forward ``update_usage`` on the cached input. The + # in-place ``set_usage`` flips ``input_quantizer`` 's row/col + # bits so backward sees the same storage layout the impl + # ended up with. (Mirrors the matching block in + # :func:`_linear_forward_impl`; we only need to track the + # side effect on the quantizer, no shell rebuild.) + if ( + backward_needs_input + and own_quantized_input + and inputmat_is_storage + and not save_original_input + ): + if args.backward_override is not None: + input_quantizer.set_usage(rowwise=True, columnwise=False) + elif ( + args.backward_input_needs_gather + and weight_quantizer.supports_only_rowwise_all_gather() + ): + input_quantizer.set_usage(rowwise=True, columnwise=False) + else: + input_quantizer.set_usage(rowwise=False, columnwise=True) + + if backward_needs_input: + if inputmat_aliases_inp: + saved_inputmat_alias = "inp" + elif inputmat_is_storage: + # Fresh storage produced by ``input_quantizer``; emit a + # shell so ``_restore_from_saved`` consumes the right + # number of slots from the saved-tensor payload. + tensor_objects[0] = input_quantizer.create_save_shell( + fake_dtype=activation_dtype, + ) + # else: plain Tensor saved -> tensor_objects[0] stays None. + # else: ``saved_inputmat = None`` -> tensor_objects[0] stays None. + + if weightmat_aliases_weight: + wt_save_alias = "weight" + elif args.is_fsdp2: + # ``wt_save = None`` in :func:`_linear_forward_impl` when + # ``weightmat is not args.weight``; FSDP2 re-quantizes + # from the all-gathered weight on backward. + pass + elif weightmat_is_storage: + tensor_objects[1] = weight_quantizer.create_save_shell( + fake_dtype=activation_dtype, + ) + # else: plain cast Tensor saved -> tensor_objects[1] stays None. + + if bias is not None: + bias_alias = "bias" + + saved_tensor_aliases = ( + saved_inputmat_alias, + wt_save_alias, + "weight", + bias_alias, + ) + + # Manual TE FSDP unsupported under compile. + if args.fsdp_group is not None and args.is_grad_enabled: + raise NotImplementedError( + "Compile-time Linear forward does not support manual TE FSDP " + "(fsdp_group is not None); use FSDP2 or MCore FSDP." + ) + fsdp_shapes: List[Any] = [] + + ctx_attrs: Dict[str, Any] = { + "fsdp_shapes": fsdp_shapes, + "saved_tensor_aliases": saved_tensor_aliases, + } + + # ------------------------------------------------------------------ + # Fake-impl allocation specs -- consumed by the auto-synthesized + # fake-impl in :func:`_make_fake_impl_from_output_info`. One alloc + # spec per user output and per saved-tensor slot. Pure data so it + # can be carried across Dynamo's trace boundary as constants / + # ``UserDefinedObjectVariable``s. + # ------------------------------------------------------------------ + if output_quantizer is None: + out_alloc: Tuple[Any, ...] = ( + "plain", tuple(out_shape), activation_dtype, inp.device, + ) + else: + out_alloc = ( + "quantized", + output_quantizer, + tuple(out_shape), + activation_dtype, + inp.device, + ) + + if new_weight_workspace_spec[0] == "none": + new_weight_workspace_alloc: Optional[Tuple[Any, ...]] = None + else: + new_weight_workspace_alloc = ( + "quantized", + weight_quantizer, + tuple(weight.shape), + activation_dtype, + weight.device, + ) + + user_output_allocs: List[Optional[Tuple[Any, ...]]] = [ + out_alloc, + new_weight_workspace_alloc, + ] + + saved_tensor_allocs: Optional[List[Optional[Tuple[Any, ...]]]] + if not args.is_grad_enabled: + saved_tensor_allocs = None + else: + # Slot 0 -- ``saved_inputmat``. ``None`` when nothing is saved + # (alias to ``inp``, or backward doesn't need the input); + # ``("quantized", ...)`` when the saved value is a quantized + # storage (matches ``tensor_objects[0] != None``); ``("plain", + # ...)`` otherwise (a fresh cast). + if not backward_needs_input or saved_inputmat_alias is not None: + slot0_alloc: Optional[Tuple[Any, ...]] = None + elif tensor_objects[0] is not None: + slot0_alloc = ( + "quantized", + input_quantizer, + tuple(inp.shape), + activation_dtype, + inp.device, + ) + else: + slot0_alloc = ( + "plain", tuple(inp.shape), activation_dtype, inp.device, + ) + + # Slot 1 -- ``wt_save``. ``None`` when aliased to ``weight`` or + # under FSDP2 (the latter rebuilds the workspace on backward). + # ``args.weight_quantizer`` may differ from the local + # ``weight_quantizer`` (which is reassigned to + # ``weight._quantizer`` when the weight is already a + # :class:`QuantizedTensor`); the saved storage's quantizer must + # match the one the impl uses for re-quantization. + weight_quantizer_for_save = ( + weight._quantizer + if isinstance(weight, QuantizedTensor) + else args.weight_quantizer + ) + if wt_save_alias is not None or args.is_fsdp2: + slot1_alloc: Optional[Tuple[Any, ...]] = None + elif tensor_objects[1] is not None: + slot1_alloc = ( + "quantized", + weight_quantizer_for_save, + tuple(weight.shape), + activation_dtype, + weight.device, + ) + else: + slot1_alloc = ( + "plain", tuple(weight.shape), activation_dtype, weight.device, + ) + + # Slot 2 -- ``saved_weight`` always aliased back to ``weight`` + # by :func:`_linear_setup_ctx``; Slot 3 -- ``saved_bias`` is + # either aliased ("bias") or ``None`` when there is no bias. + # Both stored slots are therefore always ``None``. + saved_tensor_allocs = [slot0_alloc, slot1_alloc, None, None] + + fake_specs: Dict[str, Any] = { + "user_outputs": user_output_allocs, + "saved_tensors": saved_tensor_allocs, + } + + return user_specs, tensor_objects, ctx_attrs, fake_specs + + +def _contiguous_stride(shape: List[int]) -> Tuple[int, ...]: + """Row-major contiguous stride for ``shape``.""" + stride: List[int] = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + stride[i] = stride[i + 1] * int(shape[i + 1]) + return tuple(stride) + + _linear_compiled_op = _te_register_custom_op( op_name="linear", - num_outputs=2, + # ``out`` may be a plain Tensor (default path) or a ``Float8Tensor`` + # (when an output quantizer is configured, e.g. ``fp8_output=True`` + # on a downstream module wired through ``output_quantizer``). + # ``new_weight_workspace`` is the optional FP8 weight cache: a + # ``Float8TensorStorage`` on cache miss with ``is_first_microbatch`` + # / ``cache_weight``; ``None`` otherwise (the bookkeeping flows + # through the storage flatten path even when ``None``). + output_annotations=[ + Union[torch.Tensor, Float8Tensor], + Optional[Union[torch.Tensor, Float8TensorStorage]], + ], input_tensors_for_grad=["weight", "inp", "bias"], fwd_arg_type=LinearFwdArgs, fwd_impl=_linear_forward_impl, - fwd_fake_impl=_linear_forward_fake_impl, + output_info_fn=_linear_forward_output_info, setup_context=_linear_setup_ctx, backward_arg_type=LinearBwdArgs, backward_obj=LinearBwdArgs, backward_impl=_linear_backward, - backward_fake_impl=_linear_backward_fake_impl, + bwd_output_info_fn=_linear_backward_output_info, # Two-tier custom op: the outer ``linear`` op accepts tensor # subclasses (e.g. ``Float8Tensor`` as a weight), and an # ``register_torch_dispatch`` rule flattens each subclass into @@ -2269,6 +2481,8 @@ def forward( # misc cpu_offloading=is_cpu_offload_enabled(), is_grad_enabled=is_grad_enabled, + # always-on torch_dispatch trigger + _te_dispatch_trigger=self._te_dispatch_trigger, ) if use_compiled_op: out, new_weight_workspace = _linear_compiled_op(fwd_args) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index e32081e055..93a8843faa 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -281,6 +281,80 @@ def calibrate(self, tensor: torch.Tensor) -> None: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8BlockScaling + def create_storage_metadata( + self, + *, + shape: Iterable[int], + fake_dtype: torch.dtype, + device: Optional[torch.device] = None, + requires_grad: bool = False, + as_tensor: bool = False, + ): + """Return ``(cls, meta, process_group, tensor_count)`` + suitable as the ``("storage", ...)`` payload of a Dynamo + output spec; the dynamo layer hands the trailing + ``(meta, process_group, tensors[: tensor_count])`` triple to + :meth:`Float8BlockwiseQTensorStorage._torch_compile_do_unflatten` + for reconstruction. + + Same contract as :meth:`Float8Quantizer.create_storage_metadata` + / :meth:`MXFP8Quantizer.create_storage_metadata` -- see those + docstrings for the broader rationale; this variant carries the + extra ``is_2D_scaled`` flag that the blockwise storage needs + on reconstruction. + """ + if device is None: + device = torch.device("cuda") + shape = torch.Size(shape) + has_rowwise = bool(self.rowwise_usage) + has_columnwise = bool(self.columnwise_usage) + tensor_count = int(has_rowwise) * 2 + int(has_columnwise) * 2 + from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel + + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": "Float8BlockwiseQTensorStorage", + "is_tensor": as_tensor, + "shape": shape if as_tensor else None, + "requires_grad": requires_grad if as_tensor else False, + "device": device if as_tensor else None, + "fp8_dtype": self.dtype, + "fake_dtype": fake_dtype, + "is_2D_scaled": self.block_scaling_dim == 2, + "has_rowwise_data": has_rowwise, + "has_rowwise_scale_inv": has_rowwise, + "has_columnwise_data": has_columnwise, + "has_columnwise_scale_inv": has_columnwise, + "quantizer_meta": None, + } + ) + return Float8BlockwiseQTensorStorage, meta, None, tensor_count + + def create_save_shell( + self, + *, + fake_dtype: torch.dtype, + ) -> Float8BlockwiseQTensorStorage: + """Return a tensor-free :class:`Float8BlockwiseQTensorStorage` + shell suitable for use as a ``tensor_objects`` entry in + :func:`transformer_engine.pytorch.quantized_tensor.restore_from_saved`. + + Built via ``object.__new__`` + direct attribute writes for + Dynamo traceability. Mirrors + :meth:`Float8Quantizer.create_save_shell` -- see its docstring + for rationale. + """ + shell = object.__new__(Float8BlockwiseQTensorStorage) + shell._dtype = fake_dtype + shell._rowwise_data = None + shell._columnwise_data = None + shell._rowwise_scale_inv = None + shell._columnwise_scale_inv = None + shell._fp8_dtype = self.dtype + shell._quantizer = None + shell._is_2D_scaled = self.block_scaling_dim == 2 + return shell + def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 01bc480bb2..11fde18a2f 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -4,7 +4,7 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Any, Optional, Tuple, Iterable, Union +from typing import Any, List, Optional, Tuple, Iterable, Union import warnings import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState @@ -26,6 +26,7 @@ ) from ._quantization_helpers import _IdentityFunc from ..constants import canonicalize_te_dtype, dist_group_type +from ..fp8_dtype import FP8DType, from_tex, to_tex aten = torch.ops.aten @@ -43,6 +44,182 @@ } +# --------------------------------------------------------------------------- # +# torch.compile output-layout metadata helpers +# --------------------------------------------------------------------------- # +# +# These helpers produce the static (inner-names + meta-dict) and +# storage-meta layouts that the dynamo integration layer needs to +# reassemble a :class:`Float8Tensor` / :class:`Float8TensorStorage` +# from the flat ``Tensor[]`` return of a TE custom op, without +# allocating a fake prototype tensor inside a traced region. +# +# Shared between :class:`Float8Quantizer` and +# :class:`Float8CurrentScalingQuantizer` because both produce identical +# ``Float8Tensor`` / ``Float8TensorStorage`` layouts (rowwise / columnwise / +# scale-inv inner tensors); the per-quantizer ``create_metadata`` / +# ``create_storage_metadata`` methods delegate here. + + +def _float8_create_subclass_metadata( + quantizer: "Quantizer", + *, + fake_dtype: torch.dtype, + requires_grad: bool = False, +) -> Tuple[Tuple[str, ...], dict]: + """Return ``(inner_names, meta)`` for :meth:`Float8Tensor.__tensor_unflatten__`. + + ``inner_names`` reflects the rowwise / columnwise usage flags of the + quantizer (``_data`` and/or ``_transpose``, plus always ``_scale_inv``). + ``meta`` carries the static, Dynamo-friendly attributes + :class:`Float8Tensor`'s constructor needs: + + * ``_fp8_dtype`` -- :class:`FP8DType` (an :class:`IntEnum`, + proxies as a constant for Dynamo; bridges back to ``tex.DType`` + via :func:`to_tex` on the kernel side). + * ``_fake_dtype`` -- caller-supplied torch dtype. + * ``_quantizer_snapshot`` -- always ``None`` on this path. Re-using + the snapshot reconstruction (which builds a fresh quantizer + inside :meth:`Float8Tensor.__tensor_unflatten__`) would force + Dynamo to trace a quantizer constructor call, which routinely + trips ``UserDefinedObjectVariable(Float8...Quantizer)``. + ``quantizer=None`` keeps the wrapper construction within Dynamo's + proxyable surface; user code that needs the live quantizer + sources it from outside the compiled region. + * ``_requires_grad`` -- caller-supplied flag. + """ + inner_names: List[str] = [] + if quantizer.rowwise_usage: + inner_names.append("_data") + inner_names.append("_scale_inv") + if quantizer.columnwise_usage: + inner_names.append("_transpose") + meta = { + "_fp8_dtype": from_tex(quantizer.dtype), + "_fake_dtype": fake_dtype, + "_quantizer_snapshot": None, + "_requires_grad": requires_grad, + } + return tuple(inner_names), meta + + +def _float8_create_storage_metadata( + quantizer: "Quantizer", + *, + shape: Iterable[int], + fake_dtype: torch.dtype, + device: Optional[torch.device] = None, + requires_grad: bool = False, + as_tensor: bool = False, +): + """Return ``(cls, meta, process_group, tensor_count)`` suitable + for use as the ``("storage", ...)`` payload of a Dynamo output + spec; the dynamo layer hands the trailing + ``(meta, process_group, tensors[: tensor_count])`` triple to + :meth:`Float8TensorStorage._torch_compile_do_unflatten` for + reconstruction. + + Companion of :func:`_float8_create_subclass_metadata` for the + pure-storage layout (used today for the FP8 weight workspace + returned alongside ``Linear`` 's primary output). ``meta`` is an + :class:`OpaqueSimpleMetadata` carrying: + + * the storage layout flags (``has_data``, ``has_transpose``, + ``has_scale_inv``) derived from the quantizer's rowwise / + columnwise usage, + * ``fp8_dtype`` (raw ``tex.DType`` -- the storage path does not + cross a Dynamo subclass-constructor boundary, so we can keep + the native enum here), + * ``fake_dtype`` / ``shape`` / ``device`` / ``requires_grad`` + describing the higher-precision view of the storage, + * ``quantizer_meta`` -- ``None`` for the same reason as in + :func:`_float8_create_subclass_metadata`. + + ``tensor_count`` is the number of flat inner tensors the storage + will consume from the op's ``Tensor[]`` return (``_data``, + ``_transpose``, ``_scale_inv``, in that order, only those whose + ``has_*`` flag is ``True``). The dynamo layer uses it to slice the + flat return; the storage's :meth:`_torch_compile_do_unflatten` + reassembles them via the same ``has_*`` flags. + """ + if device is None: + device = torch.device("cuda") + shape = torch.Size(shape) + has_data = bool(quantizer.rowwise_usage) + has_transpose = bool(quantizer.columnwise_usage) + has_scale_inv = True + tensor_count = int(has_data) + int(has_transpose) + int(has_scale_inv) + from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel + + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": "Float8TensorStorage", + "is_tensor": as_tensor, + "shape": shape if as_tensor else None, + "requires_grad": requires_grad if as_tensor else False, + "device": device if as_tensor else None, + "fp8_dtype": quantizer.dtype, + "fake_dtype": fake_dtype, + # ``Float8TensorStorage._torch_compile_do_unflatten`` skips + # the transpose-validity check when reconstructing; we + # publish ``False`` (valid transpose) here since a + # freshly-quantized storage with the configured usage + # always has up-to-date inner buffers. + "transpose_invalid": not has_transpose, + "has_data": has_data, + "has_transpose": has_transpose, + "has_scale_inv": has_scale_inv, + "quantizer_meta": None, + } + ) + return Float8TensorStorage, meta, None, tensor_count + + +def _float8_create_save_shell( + quantizer: "Quantizer", + *, + fake_dtype: torch.dtype, +) -> "Float8TensorStorage": + """Return a tensor-free :class:`Float8TensorStorage` shell suitable + for use as a ``tensor_objects`` entry in + :func:`transformer_engine.pytorch.quantized_tensor.restore_from_saved`. + + The shell is built via ``object.__new__`` + direct attribute writes + rather than the regular constructor: that avoids tripping Dynamo + on the UDF args (live quantizer instance, ``tex.DType``) that + :meth:`Float8TensorStorage.__new__` would otherwise see when this + function is called from a Dynamo-traced region (e.g. from + ``_linear_forward_output_info``). + + The shell holds no inner tensors -- ``restore_from_saved`` fills + them in from the flat saved-tensor list emitted by the op return, + matching the fixed three-slot layout (``_data``, ``_transpose``, + ``_scale_inv``) of :meth:`Float8TensorStorage.prepare_for_saving`. + The ``_quantizer`` slot is intentionally left ``None``; user code + inside the compiled region must source the live quantizer from + outside. + """ + shell = object.__new__(Float8TensorStorage) + shell._dtype = fake_dtype + shell._data = None + shell._transpose = None + shell._scale_inv = None + shell._fp8_dtype = quantizer.dtype + shell._quantizer = None + # ``_transpose_invalid`` flags a transpose buffer that exists but + # whose contents are stale. Saved-for-backward storages always + # come from the forward after the quantizer has filled in the + # transpose (when it was requested), so the saved transpose -- if + # present at all -- is valid. Initialising to ``False`` keeps + # ``has_data_transpose`` true whenever ``_transpose`` ends up + # non-``None`` after :meth:`restore_from_saved` (which itself only + # writes ``_transpose`` and leaves this flag alone). The + # transpose-None case is unaffected since ``has_data_transpose`` + # ANDs in the ``_transpose is not None`` check. + shell._transpose_invalid = False + return shell + + class Float8Quantizer(Quantizer): """Builder class for FP8 tensors with per-tensor delayed scaling @@ -245,6 +422,46 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True + def create_metadata( + self, + *, + fake_dtype: torch.dtype, + requires_grad: bool = False, + ) -> Tuple[Tuple[str, ...], dict]: + # pylint: disable=missing-function-docstring + return _float8_create_subclass_metadata( + self, + fake_dtype=fake_dtype, + requires_grad=requires_grad, + ) + + def create_storage_metadata( + self, + *, + shape: Iterable[int], + fake_dtype: torch.dtype, + device: Optional[torch.device] = None, + requires_grad: bool = False, + as_tensor: bool = False, + ): + # pylint: disable=missing-function-docstring + return _float8_create_storage_metadata( + self, + shape=shape, + fake_dtype=fake_dtype, + device=device, + requires_grad=requires_grad, + as_tensor=as_tensor, + ) + + def create_save_shell( + self, + *, + fake_dtype: torch.dtype, + ) -> Float8TensorStorage: + # pylint: disable=missing-function-docstring + return _float8_create_save_shell(self, fake_dtype=fake_dtype) + def _flatten(self): from ..dynamo import OpaqueSimpleMetadata @@ -420,7 +637,6 @@ def make_empty( ) scale_inv = torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory) - # See ``Float8Quantizer.make_empty`` for the rationale. if self.internal: return Float8TensorStorage( data=data, @@ -525,6 +741,46 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True + def create_metadata( + self, + *, + fake_dtype: torch.dtype, + requires_grad: bool = False, + ) -> Tuple[Tuple[str, ...], dict]: + # pylint: disable=missing-function-docstring + return _float8_create_subclass_metadata( + self, + fake_dtype=fake_dtype, + requires_grad=requires_grad, + ) + + def create_storage_metadata( + self, + *, + shape: Iterable[int], + fake_dtype: torch.dtype, + device: Optional[torch.device] = None, + requires_grad: bool = False, + as_tensor: bool = False, + ): + # pylint: disable=missing-function-docstring + return _float8_create_storage_metadata( + self, + shape=shape, + fake_dtype=fake_dtype, + device=device, + requires_grad=requires_grad, + as_tensor=as_tensor, + ) + + def create_save_shell( + self, + *, + fake_dtype: torch.dtype, + ) -> Float8TensorStorage: + # pylint: disable=missing-function-docstring + return _float8_create_save_shell(self, fake_dtype=fake_dtype) + def _flatten(self): from ..dynamo import OpaqueSimpleMetadata @@ -592,6 +848,13 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): """ + # Upper bound on the number of inner tensors produced by + # :meth:`__tensor_flatten__`. Used by the wide-output layout in + # :mod:`transformer_engine.pytorch.dynamo` to reserve enough slots in + # the custom-op ``Tensor[]`` return for any subclass-shaped output: + # data, scale_inv, transpose. + _TORCH_COMPILE_MAX_INNER_TENSORS = 3 + def __repr__(self, *, tensor_contents=None): # ``__repr__`` is on hot diagnostic paths (Dynamo's # ``Dynamo failed to run FX node`` formatter, autograd @@ -658,12 +921,22 @@ def __tensor_unflatten__( inner_tensors: dict, meta: dict, outer_size, outer_stride ) -> "Float8Tensor": quantizer = _quantizer_from_subclass_snapshot(meta.get("_quantizer_snapshot")) + fp8_dtype = meta["_fp8_dtype"] + if isinstance(fp8_dtype, FP8DType): + # ``meta`` produced by :func:`_float8_create_subclass_metadata` + # carries the Dynamo-friendly :class:`FP8DType` enum (an + # ``IntEnum`` so it proxies as a constant during tracing). + # Pybind-bound TE kernels (e.g. ``tex.dequantize``) accept only + # the native ``transformer_engine_torch.DType``, so bridge back + # here. The eager ``__tensor_flatten__`` path stores the native + # enum directly and skips this conversion. + fp8_dtype = to_tex(fp8_dtype) return Float8Tensor( shape=outer_size, dtype=meta["_fake_dtype"], data=inner_tensors.get("_data"), fp8_scale_inv=inner_tensors.get("_scale_inv"), - fp8_dtype=meta["_fp8_dtype"], + fp8_dtype=fp8_dtype, data_transpose=inner_tensors.get("_transpose"), quantizer=quantizer, requires_grad=meta.get("_requires_grad", False), diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 7616de2247..b4bd410d58 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -255,6 +255,95 @@ def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> tor def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return MXFP8BlockScaling + def create_storage_metadata( + self, + *, + shape: Iterable[int], + fake_dtype: torch.dtype, + device: Optional[torch.device] = None, + requires_grad: bool = False, + as_tensor: bool = False, + ): + """Return ``(cls, meta, process_group, tensor_count)`` + suitable as the ``("storage", ...)`` payload of a Dynamo + output spec; the dynamo layer hands the trailing + ``(meta, process_group, tensors[: tensor_count])`` triple to + :meth:`MXFP8TensorStorage._torch_compile_do_unflatten` for + reconstruction. + + Mirrors what + :meth:`MXFP8TensorStorage._torch_compile_flatten` would emit + for a freshly-quantized storage configured with this + quantizer's rowwise / columnwise usage. ``tensor_count`` is + the variable-length count of present inner tensors + (rowwise_data, rowwise_scale_inv, columnwise_data, + columnwise_scale_inv, only those whose ``has_*`` flag is + true). The dynamo layer uses it to slice the op's flat + ``Tensor[]`` return; the storage's + :meth:`_torch_compile_do_unflatten` reassembles them via the + same ``has_*`` flags. + + ``quantizer_meta`` is set to ``None`` so the reconstructed + storage has ``_quantizer=None`` -- keeping the constructor + traceable by Dynamo, mirroring the behaviour of + :class:`Float8Quantizer.create_storage_metadata`. + """ + if device is None: + device = torch.device("cuda") + shape = torch.Size(shape) + has_rowwise = bool(self.rowwise_usage) + has_columnwise = bool(self.columnwise_usage) + tensor_count = ( + int(has_rowwise) * 2 # rowwise_data + rowwise_scale_inv + + int(has_columnwise) * 2 # columnwise_data + columnwise_scale_inv + ) + from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel + + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": "MXFP8TensorStorage", + "is_tensor": as_tensor, + "shape": shape if as_tensor else None, + "requires_grad": requires_grad if as_tensor else False, + "device": device if as_tensor else None, + "fp8_dtype": self.dtype, + "fake_dtype": fake_dtype, + "with_gemm_swizzled_scales": self.optimize_for_gemm, + "has_rowwise_data": has_rowwise, + "has_rowwise_scale_inv": has_rowwise, + "has_columnwise_data": has_columnwise, + "has_columnwise_scale_inv": has_columnwise, + "quantizer_meta": None, + } + ) + return MXFP8TensorStorage, meta, None, tensor_count + + def create_save_shell( + self, + *, + fake_dtype: torch.dtype, + ) -> MXFP8TensorStorage: + """Return a tensor-free :class:`MXFP8TensorStorage` shell for + use as a ``tensor_objects`` entry in + :func:`transformer_engine.pytorch.quantized_tensor.restore_from_saved`. + + Built via ``object.__new__`` + direct attribute writes for + Dynamo traceability. Mirrors + :meth:`Float8Quantizer.create_save_shell` -- see its docstring + for rationale. Restores from the fixed four-slot layout + emitted by :meth:`MXFP8TensorStorage.prepare_for_saving`. + """ + shell = object.__new__(MXFP8TensorStorage) + shell._dtype = fake_dtype + shell._rowwise_data = None + shell._columnwise_data = None + shell._rowwise_scale_inv = None + shell._columnwise_scale_inv = None + shell._fp8_dtype = self.dtype + shell._quantizer = None + shell._with_gemm_swizzled_scales = self.optimize_for_gemm + return shell + def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 42ccb611c4..bd1938ee2f 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -415,6 +415,91 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return NVFP4BlockScaling + def create_storage_metadata( + self, + *, + shape: Iterable[int], + fake_dtype: torch.dtype, + device: Optional[torch.device] = None, + requires_grad: bool = False, + as_tensor: bool = False, + ): + """Return ``(cls, meta, process_group, tensor_count)`` + suitable as the ``("storage", ...)`` payload of a Dynamo + output spec; the dynamo layer hands the trailing + ``(meta, process_group, tensors[: tensor_count])`` triple to + :meth:`NVFP4TensorStorage._torch_compile_do_unflatten` for + reconstruction. + + See :meth:`Float8Quantizer.create_storage_metadata` for the + general contract. This variant adds the FP4-specific + ``with_gemm_swizzled_scales`` / ``row_scaled_nvfp4`` flags, + and the two amax-row/columnwise inner tensors that come with + the NVFP4 storage layout. + """ + if device is None: + device = torch.device("cuda") + shape = torch.Size(shape) + has_rowwise = bool(self.rowwise_usage) + has_columnwise = bool(self.columnwise_usage) + # Counts: rowwise contributes data + scale_inv + amax; same for + # columnwise. Each pair toggles on its respective usage flag. + tensor_count = ( + int(has_rowwise) * 3 + int(has_columnwise) * 3 + ) + from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel + + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": "NVFP4TensorStorage", + "is_tensor": as_tensor, + "shape": shape if as_tensor else None, + "requires_grad": requires_grad if as_tensor else False, + "device": device if as_tensor else None, + "fp4_dtype": self.dtype, + "fake_dtype": fake_dtype, + "with_gemm_swizzled_scales": self.optimize_for_gemm, + "row_scaled_nvfp4": self.row_scaled_nvfp4, + "has_rowwise_data": has_rowwise, + "has_rowwise_scale_inv": has_rowwise, + "has_columnwise_data": has_columnwise, + "has_columnwise_scale_inv": has_columnwise, + "has_amax_rowwise": has_rowwise, + "has_amax_columnwise": has_columnwise, + "quantizer_meta": None, + } + ) + return NVFP4TensorStorage, meta, None, tensor_count + + def create_save_shell( + self, + *, + fake_dtype: torch.dtype, + ) -> NVFP4TensorStorage: + """Return a tensor-free :class:`NVFP4TensorStorage` shell for + use as a ``tensor_objects`` entry in + :func:`transformer_engine.pytorch.quantized_tensor.restore_from_saved`. + + Built via ``object.__new__`` + direct attribute writes for + Dynamo traceability. Restores from the fixed six-slot layout + emitted by :meth:`NVFP4TensorStorage.prepare_for_saving` + (rowwise_data, columnwise_data, rowwise_scale_inv, + columnwise_scale_inv, amax_rowwise, amax_columnwise). + """ + shell = object.__new__(NVFP4TensorStorage) + shell._dtype = fake_dtype + shell._rowwise_data = None + shell._columnwise_data = None + shell._rowwise_scale_inv = None + shell._columnwise_scale_inv = None + shell._amax_rowwise = None + shell._amax_columnwise = None + shell._fp4_dtype = self.dtype + shell._quantizer = None + shell._with_gemm_swizzled_scales = self.optimize_for_gemm + shell._row_scaled_nvfp4 = self.row_scaled_nvfp4 + return shell + def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index ecaf1d919f..51b28c766d 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -86,6 +86,14 @@ class Float8TensorStorage(QuantizedTensorStorage): _transpose: Optional[torch.Tensor] _transpose_invalid: bool + # Upper bound on the number of inner tensors produced by + # :meth:`_torch_compile_flatten`. Used by the wide-output layout in + # :mod:`transformer_engine.pytorch.dynamo` to reserve enough slots in + # the custom-op ``Tensor[]`` return for any storage-shaped output: + # 3 data tensors (data / transpose / scale_inv) + up to 2 quantizer + # tensors (Float8Quantizer carries scale / amax). + _TORCH_COMPILE_MAX_INNER_TENSORS = 5 + def __new__( cls, *args, @@ -172,6 +180,19 @@ def restore_from_saved( self._data = tensors[0] self._transpose = tensors[1] self._scale_inv = tensors[2] + # Re-derive ``_transpose_invalid`` from the restored buffer: + # the saved transpose, if present, was valid at save time + # (``prepare_for_saving`` never resets this flag, and forward + # producers don't save stale transposes). Tying the flag to + # ``self._transpose`` here makes restoration independent of + # whichever shell carried the storage across the trace + # boundary -- in particular ``torch.compile``'s save/restore + # round-trip, which builds a fresh wrapper shell for backward + # whose pre-restore ``_transpose_invalid`` would otherwise + # come from :meth:`Float8TensorStorage.__new__` (``True`` + # whenever it sees ``data_transpose=None``) and trip + # :meth:`update_usage` downstream. + self._transpose_invalid = self._transpose is None return tensors[3:] def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): @@ -316,7 +337,18 @@ def _torch_compile_do_unflatten( } ) out = cls(**kwargs) - out._transpose_invalid = meta["transpose_invalid"] + # ``__new__`` already sets ``_transpose_invalid = (data_transpose + # is None)``, which is exactly the post-restoration semantic we + # want under :mod:`torch.compile`: a transpose buffer that the + # producer chose to ship through the trace was valid at flatten + # time (forward never emits stale transposes onto saved + # tensors), so the unflattened storage must treat it as valid. + # Trusting ``meta["transpose_invalid"]`` instead would re-pin the + # stale ``True`` that Dynamo embeds into the metadata constant + # because it cannot follow the in-place + # :meth:`restore_from_saved` write through ``ctx.tensor_objects`` + # and would then fail the :meth:`update_usage` + # ``not has_data_transpose`` guard in backward. return out def _create_transpose(self): From 1019b19247a507b93a7cd81276b57d76b0d7f1b5 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 13:43:12 +0200 Subject: [PATCH 06/16] [PyTorch] Unify TensorSpec hierarchy and infer num_outputs dynamically Consolidates the spec machinery used by ``_te_register_custom_op``: * Merge ``QuantizedAllocSpec`` into ``SubclassTensorSpec``: a single class now covers both *full* (forward outputs that round-trip through the op's flat ``Tensor[]``) and *alloc-only* (backward grad outputs that go straight to autograd) modes, picked by whether ``wrapper_cls`` is supplied to ``SubclassTensorSpec.from_quantizer``. * Introduce ``tensor_spec(...)`` as the single user-facing factory for declaring every op output / saved slot / grad spec. Picks the right ``TensorSpec`` subclass based on the kwargs given (``alias``, ``shape``, ``quantizer``, ``wrapper_cls``, ``storage``); ``output_info_fn`` authors no longer need to know the underlying class hierarchy. * Drop the explicit ``num_outputs`` (and ``output_annotations``) parameter from ``_te_register_custom_op`` and friends. The user output count is inferred dynamically at call time from the impl return shape -- the trailing three slots are always ``(tensors_to_save, tensor_objects, ctx_attrs)``, so ``num_outputs = len(result) - _FWD_TRAILING_SLOTS`` / ``len(user_specs)`` everywhere they're needed. The op schema is ``Tensor[]`` (variable-length list) so this matches the actual graph contract. * Move ``_contiguous_stride`` from ``linear.py`` to ``dynamo.py`` (private helper for ``from_quantizer`` classmethods). * Sync ``create_storage_metadata`` in every quantized tensor module to include the quantizer's flattened tensor count in ``StorageSpec.tensor_count`` so slot counts match the actual ``_torch_compile_flatten`` payload. Net result for ``linear.py``: the public surface from ``dynamo.py`` shrinks to ``TensorSpec`` (type hint), ``_te_register_custom_op``, and ``tensor_spec``; all five concrete spec classes become implementation details of ``tensor_spec``. Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo.py | 1113 +++++++++++------ transformer_engine/pytorch/module/linear.py | 645 ++++------ .../pytorch/tensor/float8_blockwise_tensor.py | 33 +- .../pytorch/tensor/float8_tensor.py | 71 +- .../pytorch/tensor/mxfp8_tensor.py | 34 +- .../pytorch/tensor/nvfp4_tensor.py | 37 +- 6 files changed, 991 insertions(+), 942 deletions(-) diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index a3163029d9..196d4e223a 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -26,6 +26,13 @@ __all__ = [ "OpaqueSimpleMetadata", + "TensorSpec", + "NoneSpec", + "AliasedSpec", + "PlainTensorSpec", + "SubclassTensorSpec", + "StorageSpec", + "tensor_spec", "_DispatchTrigger", "_te_register_custom_op", ] @@ -87,186 +94,563 @@ def _decode_none(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: # inner tensors emitted by the op. -def _extract_layout(proto_value: Any) -> Tuple[Any, ...]: - """Extract layout info from a fake proto output value. +def _contiguous_stride(shape: Sequence[int]) -> Tuple[int, ...]: + """Row-major contiguous stride for ``shape``. - Returned tuple starts with a ``kind`` string: ``"none"``, ``"plain"``, - ``"subclass"``, or ``"storage"``; followed by kind-specific fields - consumed by :func:`forward_fn` and the autograd ``setup_context``. + Used by :meth:`SubclassTensorSpec.from_quantizer` to fill in the + ``stride`` field expected by ``__tensor_unflatten__``; user code + that builds :class:`SubclassTensorSpec` directly typically does + not need to touch this. + """ + stride: List[int] = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + stride[i] = stride[i + 1] * int(shape[i + 1]) + return tuple(stride) + + +# --------------------------------------------------------------------------- # +# TensorSpec -- unified per-slot descriptor +# --------------------------------------------------------------------------- # +# +# ``TensorSpec`` is the single source of truth for one user output / one +# backward grad / one fake saved-slot value. Each instance encodes: +# +# * ``slot_count()`` -- how many entries of the op's flat ``Tensor[]`` +# payload this output consumes; +# * ``reassemble(chunk)`` -- how to turn those entries back into the +# user-facing object (plain tensor, tensor +# subclass, ``QuantizedTensorStorage``, ...); +# * ``reassemble_with_autograd(chunk)`` +# -- variant used by :func:`forward_fn` that +# interposes :class:`_ToSubclassFn` for +# subclass paths so the construction stays +# on the autograd graph; +# * ``alloc()`` -- (optional) build an empty fake version of +# the value for shape inference under +# :class:`torch._subclasses.FakeTensorMode`. +# Required only when the op has no +# hand-written ``fwd_fake_impl`` / +# ``backward_fake_impl`` and relies on +# :func:`_make_fake_impl_from_output_info` / +# :func:`_make_fake_impl_from_bwd_output_info` +# to auto-synthesize one. +# +# Replaces the earlier pair of parallel tuple lists (``user_specs`` for +# reassembly, ``fake_specs["user_outputs"]`` for allocation) that every +# ``output_info_fn`` had to keep in lock-step. + + +class TensorSpec: + """Per-output / per-saved-slot layout + (optional) allocation descriptor. - Used only on the legacy ``fwd_fake_impl``-driven path (see - :func:`_run_fake_for_proto`). The recommended path supplies an - explicit ``output_info_fn`` to :func:`_te_register_custom_op`, - which returns the same shape tuples directly without ever - materialising a fake prototype tensor. + Concrete subclasses (:class:`NoneSpec`, :class:`PlainTensorSpec`, + :class:`SubclassTensorSpec`, :class:`StorageSpec`) implement the + methods listed below. See module-level commentary for the role + each method plays in the forward / fake / setup-context pipelines. """ - if proto_value is None: - return ("none",) - if isinstance(proto_value, torch.Tensor): - if type(proto_value) is not torch.Tensor and hasattr( - proto_value, "__tensor_flatten__" - ): - inner_names, meta = proto_value.__tensor_flatten__() - return ( - "subclass", - type(proto_value), - tuple(inner_names), - meta, - tuple(proto_value.shape), - tuple(proto_value.stride()), + + KIND: str = "" + + def slot_count(self) -> int: + raise NotImplementedError( + f"{type(self).__name__}.slot_count() not implemented" + ) + + def reassemble(self, chunk: List[Any]) -> Any: + raise NotImplementedError( + f"{type(self).__name__}.reassemble() not implemented" + ) + + def reassemble_with_autograd(self, chunk: List[Any]) -> Any: + """Reassemble while keeping the autograd graph intact. + + Default to :meth:`reassemble`; only :class:`SubclassTensorSpec` + overrides this to route subclass construction through + :class:`_ToSubclassFn` (so AOTAutograd records the wrap). + """ + return self.reassemble(chunk) + + def alloc(self) -> Any: + raise NotImplementedError( + f"{type(self).__name__}.alloc() not implemented; the spec was " + f"built without allocation info (legacy fake-impl path)." + ) + + @staticmethod + def from_proto(proto_value: Any) -> "TensorSpec": + """Build a reassembly-only spec from a fake-impl proto value. + + Used only by the legacy path where the user provides + ``fwd_fake_impl`` instead of ``output_info_fn``: a fake + prototype tensor is constructed by the user fake-impl, and the + layout (kind, cls, inner_names, meta, shape, stride) is + extracted from it. The returned spec is reassembly-capable but + not alloc-capable -- callers on this path don't need alloc. + """ + if proto_value is None: + return NoneSpec() + if isinstance(proto_value, torch.Tensor): + if type(proto_value) is not torch.Tensor and hasattr( + proto_value, "__tensor_flatten__" + ): + inner_names, meta = proto_value.__tensor_flatten__() + return SubclassTensorSpec( + cls=type(proto_value), + inner_names=tuple(inner_names), + meta=meta, + shape=tuple(proto_value.shape), + stride=tuple(proto_value.stride()), + ) + return PlainTensorSpec( + shape=tuple(proto_value.shape), + dtype=proto_value.dtype, + device=proto_value.device, ) - return ("plain",) - if hasattr(proto_value, "_torch_compile_flatten"): - meta, pg, tensors = proto_value._torch_compile_flatten() - return ("storage", type(proto_value), meta, pg, len(tensors)) - raise TypeError( - f"unsupported output type {type(proto_value).__name__}; expected " - "None / torch.Tensor / tensor subclass with __tensor_flatten__ / " - "class with _torch_compile_flatten." - ) + if hasattr(proto_value, "_torch_compile_flatten"): + meta, pg, tensors = proto_value._torch_compile_flatten() + return StorageSpec( + cls=type(proto_value), + meta=meta, + pg=pg, + tensor_count=len(tensors), + ) + raise TypeError( + f"unsupported output type {type(proto_value).__name__}; expected " + "None / torch.Tensor / tensor subclass with __tensor_flatten__ / " + "class with _torch_compile_flatten." + ) -def _spec_slot_count(spec: Tuple[Any, ...]) -> int: - """Number of flat ``Tensor[]`` slots this output spec consumes. +class NoneSpec(TensorSpec): + """Output / save slot whose value is ``None``. - Accepts both legacy "layout" tuples (from :func:`_extract_layout`) - and the new ``output_info_fn`` spec tuples; the kind-indexed - structure is identical on the slot-count fields. + Consumes one ``Tensor[]`` slot via the :func:`_encode_none` / + :func:`_decode_none` sentinel pair so that the op's schema (which + is non-nullable ``Tensor[]``) can still carry a ``None`` value + end-to-end. """ - kind = spec[0] - if kind == "subclass": - return len(spec[2]) # inner_names tuple - if kind == "storage": - return spec[4] # tensor_count - # "none" / "plain": 1 slot - return 1 - - -# Kept as an alias for the small set of internal helpers that still -# spell the old name (e.g. legacy ``_run_fake_for_proto`` paths). -_layout_slot_count = _spec_slot_count - - -def _reassemble_from_spec(spec: Tuple[Any, ...], chunk: List[Any]) -> Any: - """Reconstruct one user-facing output / saved object from its - flat-tensor chunk. - - ``chunk`` is the post-:func:`_decode_none` view of the op's - contribution to this output. Direct ``__tensor_unflatten__`` / - ``_torch_compile_do_unflatten`` is used here (rather than going - through :class:`_ToSubclassFn`); callers that need to interpose an - ``autograd.Function`` between the op output and the user-side - forward fn use :class:`_ToSubclassFn` explicitly. + + KIND = "none" + + def slot_count(self) -> int: + return 1 + + def reassemble(self, chunk: List[Any]) -> Any: + return None + + def alloc(self) -> Any: + return None + + +class AliasedSpec(TensorSpec): + """Saved-tensor slot whose value is identical to a forward arg. + + The forward impl writes ``None`` into the slot (so no extra storage + moves through the op return) and tags the slot's ``alias`` name in + ``ctx_attrs["saved_tensor_aliases"]``; the user's ``setup_context`` + resolves the alias back to the actual forward arg. + + Behaves like :class:`NoneSpec` on the schema side (1 sentinel slot, + ``reassemble -> None``, ``alloc -> None``); the only difference is + that :func:`_inject_saved_aliases` reads ``self.alias`` to populate + ``ctx_attrs["saved_tensor_aliases"]``. """ - kind = spec[0] - if kind == "none": + + KIND = "aliased" + + def __init__(self, alias: str) -> None: + self.alias = alias + + def slot_count(self) -> int: + return 1 + + def reassemble(self, chunk: List[Any]) -> Any: + return None + + def alloc(self) -> Any: return None - if kind == "plain": + + +class PlainTensorSpec(TensorSpec): + """Plain :class:`torch.Tensor` output / save slot. + + Carries ``shape`` / ``dtype`` / ``device`` for allocation; reassembly + is just the lone slot value. + """ + + KIND = "plain" + + def __init__( + self, + shape: Optional[Sequence[int]] = None, + dtype: Optional["torch.dtype"] = None, + device: Optional["torch.device"] = None, + ) -> None: + self.shape = tuple(shape) if shape is not None else None + self.dtype = dtype + self.device = device + + def slot_count(self) -> int: + return 1 + + def reassemble(self, chunk: List[Any]) -> Any: return chunk[0] - if kind == "subclass": - _, cls, inner_names, meta, shape, stride = spec - inner_dict = dict(zip(inner_names, chunk)) - return cls.__tensor_unflatten__(inner_dict, meta, shape, stride) - # kind == "storage" - _, cls, meta, pg, _ = spec - real_tensors = [t for t in chunk if t is not None] - return cls._torch_compile_do_unflatten(meta, pg, real_tensors) + + def alloc(self) -> Any: + if self.shape is None or self.dtype is None or self.device is None: + return TensorSpec.alloc(self) + return torch.empty(self.shape, dtype=self.dtype, device=self.device) + + +class SubclassTensorSpec(TensorSpec): + """Tensor-subclass output / save slot (e.g. :class:`Float8Tensor`). + + Two modes, picked at construction time via :meth:`from_quantizer`: + + * **Full mode** (``wrapper_cls`` supplied): the spec knows the + subclass identity, ``inner_names`` and ``meta`` for + ``__tensor_unflatten__``, so it can both :meth:`alloc` (under + :class:`FakeTensorMode`) and :meth:`reassemble` slot chunks from + the op's flat ``Tensor[]`` payload back into a user-facing + subclass instance. Used for forward outputs that flow through + the custom op and need to be re-wrapped on the other side. + * **Alloc-only mode** (no ``wrapper_cls``): the spec only carries + enough info to :meth:`alloc` an empty instance via + ``quantizer.make_empty(shape, dtype, device)``. Used for + backward gradient outputs, which never round-trip through the + flat ``Tensor[]`` -- ``_format_bwd_result`` hands them straight + to autograd -- so the layout-aware methods are intentionally + undefined. + """ + + KIND = "subclass" + + def __init__( + self, + *, + shape: Sequence[int], + alloc_quantizer: Any, + alloc_dtype: "torch.dtype", + alloc_device: "torch.device", + cls: Optional[type] = None, + inner_names: Optional[Sequence[str]] = None, + meta: Any = None, + stride: Optional[Sequence[int]] = None, + ) -> None: + self.cls = cls + self.inner_names = tuple(inner_names) if inner_names is not None else None + self.meta = meta + self.shape = tuple(shape) + self.stride = tuple(stride) if stride is not None else None + self.alloc_quantizer = alloc_quantizer + self.alloc_dtype = alloc_dtype + self.alloc_device = alloc_device + + def _require_full_mode(self, method_name: str) -> None: + if self.cls is None: + raise RuntimeError( + f"SubclassTensorSpec.{method_name} is only available in " + "full mode (built with ``wrapper_cls=``). Alloc-only specs " + "(used for backward grad outputs) don't participate in the " + "flat ``Tensor[]`` payload, so they have no slot layout." + ) + + def slot_count(self) -> int: + self._require_full_mode("slot_count") + return len(self.inner_names) + + def reassemble(self, chunk: List[Any]) -> Any: + self._require_full_mode("reassemble") + inner_dict = dict(zip(self.inner_names, chunk)) + return self.cls.__tensor_unflatten__( + inner_dict, self.meta, self.shape, self.stride + ) + + def reassemble_with_autograd(self, chunk: List[Any]) -> Any: + self._require_full_mode("reassemble_with_autograd") + return _ToSubclassFn.apply( + self.cls, self.inner_names, self.meta, self.shape, self.stride, *chunk + ) + + def alloc(self) -> Any: + return self.alloc_quantizer.make_empty( + self.shape, dtype=self.alloc_dtype, device=self.alloc_device + ) + + @classmethod + def from_quantizer( + cls, + quantizer: Any, + *, + shape: Sequence[int], + dtype: "torch.dtype", + device: "torch.device", + wrapper_cls: Optional[type] = None, + requires_grad: bool = False, + ) -> "SubclassTensorSpec": + """Build a :class:`SubclassTensorSpec` from a live quantizer. + + Hides the ``create_metadata`` / inner-name / stride bookkeeping + behind a single call: callers in ``output_info_fn`` / + ``bwd_output_info_fn`` only specify the user-facing identity + (shape, dtype, device, quantizer) -- and, for forward outputs + that need flat-slot reassembly, the ``wrapper_cls`` they + unflatten into. + + Omitting ``wrapper_cls`` yields an alloc-only spec suitable + for backward grad outputs: the quantizer-specific fake + allocation still works (``quantizer.make_empty(...)``), but + :meth:`slot_count` / :meth:`reassemble` are intentionally + disabled because gradients never round-trip through the op's + flat ``Tensor[]`` payload. + """ + if wrapper_cls is None: + return cls( + shape=tuple(shape), + alloc_quantizer=quantizer, + alloc_dtype=dtype, + alloc_device=device, + ) + inner_names, meta = quantizer.create_metadata( + fake_dtype=dtype, + requires_grad=requires_grad, + ) + return cls( + cls=wrapper_cls, + inner_names=inner_names, + meta=meta, + shape=tuple(shape), + stride=_contiguous_stride(shape), + alloc_quantizer=quantizer, + alloc_dtype=dtype, + alloc_device=device, + ) + + +class StorageSpec(TensorSpec): + """Non-tensor storage output / save slot (e.g. :class:`Float8TensorStorage`). + + Reassembled via ``cls._torch_compile_do_unflatten``; allocated via + ``alloc_quantizer.make_empty(shape, ...)``. + """ + + KIND = "storage" + + def __init__( + self, + cls: type, + meta: Any, + pg: Any, + tensor_count: int, + *, + alloc_quantizer: Any = None, + alloc_shape: Optional[Sequence[int]] = None, + alloc_dtype: Optional["torch.dtype"] = None, + alloc_device: Optional["torch.device"] = None, + ) -> None: + self.cls = cls + self.meta = meta + self.pg = pg + self.tensor_count = tensor_count + self.alloc_quantizer = alloc_quantizer + self.alloc_shape = ( + tuple(alloc_shape) if alloc_shape is not None else None + ) + self.alloc_dtype = alloc_dtype + self.alloc_device = alloc_device + + def slot_count(self) -> int: + return self.tensor_count + + def reassemble(self, chunk: List[Any]) -> Any: + real_tensors = [t for t in chunk if t is not None] + return self.cls._torch_compile_do_unflatten(self.meta, self.pg, real_tensors) + + def alloc(self) -> Any: + if self.alloc_quantizer is None or self.alloc_shape is None: + return TensorSpec.alloc(self) + return self.alloc_quantizer.make_empty( + self.alloc_shape, dtype=self.alloc_dtype, device=self.alloc_device + ) + + @classmethod + def from_quantizer( + cls, + quantizer: Any, + *, + shape: Sequence[int], + dtype: "torch.dtype", + device: "torch.device", + requires_grad: bool = False, + as_tensor: bool = False, + ) -> "StorageSpec": + """Build a :class:`StorageSpec` from a live quantizer. + + Hides the ``create_storage_metadata`` four-tuple + ``(cls, meta, process_group, tensor_count)`` behind a single + call: callers in ``output_info_fn`` only need to specify the + quantizer that drives the layout plus the higher-precision + view (shape / dtype / device) the storage represents. + """ + storage_cls, meta, pg, count = quantizer.create_storage_metadata( + shape=shape, + fake_dtype=dtype, + device=device, + requires_grad=requires_grad, + as_tensor=as_tensor, + ) + return cls( + cls=storage_cls, + meta=meta, + pg=pg, + tensor_count=count, + alloc_quantizer=quantizer, + alloc_shape=tuple(shape), + alloc_dtype=dtype, + alloc_device=device, + ) + + +def tensor_spec( + *, + shape: Optional[Sequence[int]] = None, + dtype: Optional["torch.dtype"] = None, + device: Optional["torch.device"] = None, + quantizer: Optional[Any] = None, + wrapper_cls: Optional[type] = None, + storage: bool = False, + alias: Optional[str] = None, +) -> TensorSpec: + """One-stop factory for declaring an op output / saved slot / grad spec. + + Single entry point that authors of ``output_info_fn`` / + ``bwd_output_info_fn`` use to describe every slot the op + produces, regardless of whether the slot is a plain tensor, a + quantized wrapper, a non-tensor storage, an aliased save, an + absent output, or a grad-only alloc target. Internally dispatches + to the appropriate :class:`TensorSpec` subclass based on which + keyword arguments are supplied (first match wins): + + * ``alias`` set -> :class:`AliasedSpec` (saved slot that + reuses a forward arg; no payload moves + through the op). + * ``shape is None`` -> :class:`NoneSpec` (absent output / save). + * ``quantizer is None`` -> :class:`PlainTensorSpec`. + * ``storage=True`` -> :class:`StorageSpec` via + :meth:`StorageSpec.from_quantizer` (used + for quantized saved storages). + * otherwise -> :class:`SubclassTensorSpec` via + :meth:`SubclassTensorSpec.from_quantizer`. + ``wrapper_cls`` picks between *full mode* + (forward outputs that re-wrap from the + flat ``Tensor[]`` payload) and + *alloc-only mode* (backward grad outputs + that never round-trip through the op). + + All quantized paths use ``dtype`` / ``device`` for fake allocation + (``quantizer.make_empty(shape, dtype, device)``); the plain path + requires both as well, since it falls back to ``torch.empty``. + """ + if alias is not None: + return AliasedSpec(alias) + if shape is None: + return NoneSpec() + if quantizer is None: + return PlainTensorSpec(shape=shape, dtype=dtype, device=device) + if storage: + return StorageSpec.from_quantizer( + quantizer, shape=shape, dtype=dtype, device=device + ) + return SubclassTensorSpec.from_quantizer( + quantizer, + shape=shape, + dtype=dtype, + device=device, + wrapper_cls=wrapper_cls, + ) # --------------------------------------------------------------------------- # -# Fake-impl synthesis from ``output_info_fn`` allocation specs. +# Fake-impl synthesis from ``output_info_fn`` / ``bwd_output_info_fn``. # --------------------------------------------------------------------------- # -# -# The recommended path for TE custom ops is to expose ``output_info_fn`` / -# ``bwd_output_info_fn`` -- pure-Python descriptors of the op's output layout. -# When such a descriptor returns alloc specs alongside the layout / saved -# bookkeeping, ``_te_register_custom_op`` auto-synthesizes a fake-impl from -# them: callers no longer need to maintain a separate -# ``fwd_fake_impl`` / ``backward_fake_impl`` that duplicates the same -# branching logic. The alloc-spec format is intentionally minimal: -# -# * ``None`` -> ``None`` (no slot allocated). -# * ``("plain", shape, dtype, device)`` -> ``torch.empty(...)``. -# * ``("quantized", quantizer, shape, dtype, device)`` -# -> ``quantizer.make_empty(...)``; -# returns either a tensor -# subclass or a -# ``QuantizedTensorStorage`` -# depending on the quantizer. -def _alloc_from_fake_spec(spec: Optional[Tuple[Any, ...]]) -> Any: - """Allocate one fake value from an alloc spec. - - See module-level commentary for the supported spec kinds. ``None`` - /-``("none",)`` is a sentinel meaning "no allocation"; the returned - value is ``None`` and the caller should skip the slot. + + +def _inject_saved_aliases( + ctx_attrs: Dict[str, Any], saved_slots: Sequence[TensorSpec] +) -> Dict[str, Any]: + """Inject ``saved_tensor_aliases`` derived from ``saved_slots``. + + The user's ``setup_context`` callback reads aliases off + ``ctx_attrs["saved_tensor_aliases"]`` to resolve aliased saved + slots back to their forward arg. Only :class:`AliasedSpec` + contributes a non-``None`` alias entry; every other spec maps to + ``None`` (no alias, the real value is carried through the op + payload). We expose the tuple on every code path (real op output, + output-info path, auto-synthesized fake) so the callback's + contract stays identical. """ - if spec is None or spec[0] == "none": - return None - kind = spec[0] - if kind == "plain": - _, shape, dtype, device = spec - return torch.empty(tuple(shape), dtype=dtype, device=device) - if kind == "quantized": - _, quantizer, shape, dtype, device = spec - return quantizer.make_empty(tuple(shape), dtype=dtype, device=device) - raise ValueError(f"unsupported alloc-spec kind: {kind!r}") + out = dict(ctx_attrs) if ctx_attrs else {} + out["saved_tensor_aliases"] = tuple( + s.alias if isinstance(s, AliasedSpec) else None for s in saved_slots + ) + return out def _make_fake_impl_from_output_info( output_info_fn: Callable[[Any], Any], - num_outputs: int, ) -> Callable[[Any], Tuple[Any, ...]]: """Build a forward fake-impl from an ``output_info_fn``. The synthesized fake-impl returns ``(*user_outputs, tensors_to_save, tensor_objects, ctx_attrs)`` -- the same shape :func:`_setup_context` expects from a hand-written - ``fwd_fake_impl``. ``user_outputs`` come from - ``fake_specs["user_outputs"]`` (one alloc spec per output), the - saved tuple from ``fake_specs["saved_tensors"]`` (``None`` if the - op did not save anything, e.g. ``is_grad_enabled=False``), and - ``tensor_objects`` / ``ctx_attrs`` are propagated verbatim from - the descriptor. - - The descriptor must return a 4-tuple - ``(user_specs, tensor_objects, ctx_attrs, fake_specs)``. - ``user_specs`` is unused here -- the synthesized fake-impl - delegates layout introspection to :func:`_setup_context`, which - re-invokes ``output_info_fn`` -- but having a single function - return both reassembly specs and alloc specs avoids duplicating - the branching logic. + ``fwd_fake_impl``: + + * ``user_outputs`` comes from ``[s.alloc() for s in user_specs]``. + * ``tensors_to_save`` comes from ``tuple(s.alloc() for s in saved_slots)``, + or ``None`` if ``saved_slots`` is empty + (e.g. ``is_grad_enabled=False``). + * ``tensor_objects`` is a vestigial slot kept for tuple-shape + symmetry with hand-written fake impls; the + compile path no longer consumes it. + * ``ctx_attrs`` is augmented with ``saved_tensor_aliases`` + derived from ``saved_slots`` so the user's + ``setup_context`` sees the same contract. + + ``output_info_fn`` must return a 3-tuple + ``(user_specs: List[TensorSpec], saved_slots: List[TensorSpec], + ctx_attrs: Dict[str, Any])``. """ - del num_outputs # informational only; layout comes from fake_specs. def _fake(args: Any) -> Tuple[Any, ...]: - _user_specs, tensor_objects, ctx_attrs, fake_specs = output_info_fn(args) - user_outputs = [ - _alloc_from_fake_spec(s) for s in fake_specs["user_outputs"] - ] - saved_specs = fake_specs.get("saved_tensors") - if saved_specs is None: + user_specs, saved_slots, ctx_attrs = output_info_fn(args) + user_outputs = [s.alloc() for s in user_specs] + if not saved_slots: tensors_to_save: Any = None else: - tensors_to_save = tuple(_alloc_from_fake_spec(s) for s in saved_specs) - return (*user_outputs, tensors_to_save, tensor_objects, ctx_attrs) + tensors_to_save = tuple(s.alloc() for s in saved_slots) + ctx_attrs = _inject_saved_aliases(ctx_attrs, saved_slots) + return (*user_outputs, tensors_to_save, None, ctx_attrs) return _fake def _make_fake_impl_from_bwd_output_info( - bwd_output_info_fn: Callable[[Any], List[Optional[Tuple[Any, ...]]]], + bwd_output_info_fn: Callable[[Any], List[TensorSpec]], ) -> Callable[[Any], Tuple[Any, ...]]: """Build a backward fake-impl from a ``bwd_output_info_fn``. - The descriptor returns a flat list of alloc specs (or ``None``) - per gradient output, in the same order as ``backward_impl``'s - return tuple. The synthesized fake-impl just allocates one fake - per slot. + The descriptor returns a flat list of :class:`TensorSpec` + (typically :class:`NoneSpec` / :class:`PlainTensorSpec` / + alloc-only :class:`SubclassTensorSpec` for quantized grads), one + per gradient output in the same order as ``backward_impl``'s + return tuple. The synthesized fake-impl just calls + :meth:`TensorSpec.alloc` on each. """ def _fake(bwd_args: Any) -> Tuple[Any, ...]: specs = bwd_output_info_fn(bwd_args) - return tuple(_alloc_from_fake_spec(s) for s in specs) + return tuple(s.alloc() for s in specs) return _fake @@ -1385,22 +1769,19 @@ def _unpack(cls: type, args: Dict[str, Any], buckets: List[_Bucket]) -> Any: def _prepare_for_saving(tensors: Any) -> Tuple[List[Optional[torch.Tensor]], Any]: """Lazy wrapper around :func:`quantized_tensor.prepare_for_saving`. - Lazy-imports to avoid the dynamo<->quantized_tensor circular import - that ``transformer_engine.pytorch`` would otherwise trigger at - module import time. + Used only to flatten the user's setup-context return into a + ``(flat_tensors, tensor_objects)`` pair stashed on ``ctx`` for the + backward; the forward output and saved-tensor restoration on the + compile-path now go through :class:`TensorSpec` instead. Lazy-imports + avoid the dynamo<->quantized_tensor circular import that + ``transformer_engine.pytorch`` would otherwise trigger at module + import time. """ from transformer_engine.pytorch.quantized_tensor import prepare_for_saving return prepare_for_saving(*(tensors or ())) -def _restore_from_saved(tensor_objects: Any, saved_tensors: List[Any]) -> Any: - """Lazy wrapper around :func:`quantized_tensor.restore_from_saved`.""" - from transformer_engine.pytorch.quantized_tensor import restore_from_saved - - return restore_from_saved(tensor_objects, saved_tensors) - - # --------------------------------------------------------------------------- # # Forward-result packing # --------------------------------------------------------------------------- # @@ -1415,70 +1796,85 @@ def _restore_from_saved(tensor_objects: Any, saved_tensors: List[Any]) -> Any: # declaration order. # # At call-site time (:func:`forward_fn` and :func:`_setup_context`), -# the per-call output structure is learned from a fake run of the user -# fwd impl driven by :func:`_run_fake_for_proto` (see -# :func:`_extract_layout` and :func:`_layout_slot_count` near the top -# of this file). The static (class, inner-names, metadata, shape, -# stride) captured by each layout is enough to reassemble the +# the per-call output structure is described by a list of +# :class:`TensorSpec` (preferred path, via ``output_info_fn``) or +# extracted from a fake run of the user fwd impl driven by +# :func:`_run_fake_for_proto` and :meth:`TensorSpec.from_proto` +# (legacy path). Either way, each spec carries enough info +# (class, inner-names, metadata, shape, stride) to reassemble the # user-facing object from its real inner tensors emitted by the op; # subclass reconstruction goes through :class:`_ToSubclassFn` so the # wrap is recorded on the autograd graph. -def _format_fwd_result( - result: Any, - num_outputs: int, -) -> List[torch.Tensor]: - """Pack a fwd-impl return tuple into the op's ``Tensor[]`` payload. +def _flatten_value_into(flat: List[torch.Tensor], value: Any) -> None: + """Append the ``Tensor[]`` slots produced by ``value`` to ``flat``. - Each user output is decomposed into a deterministic number of inner - plain tensors (see :func:`_extract_layout`): + The dispatch matches the four spec kinds in :class:`TensorSpec`: - * ``None`` -> 1 sentinel slot. + * ``None`` -> 1 sentinel slot (via :func:`_encode_none`). * plain Tensor -> 1 slot. - * subclass with - ``__tensor_flatten__`` -> ``len(inner_names)`` slots, in the order - declared by the class. - * storage with - ``_torch_compile_flatten`` -> ``len(tensors)`` slots. - - Saved-for-backward tensors follow in declaration order. ``None`` - entries on either side are smuggled through :func:`_encode_none` - so the schema stays non-nullable and ``register_autograd`` still - attaches a ``grad_fn`` to the op's outputs. - - The slot layout produced here must match exactly what - :func:`_extract_layout` predicts from a proto fake run, since the - call-site reassembly in :func:`forward_fn` uses the proto-derived - layout to slice this flat list back into user-facing objects. + * tensor subclass with ``__tensor_flatten__`` -> ``len(inner_names)`` + slots, in the order declared by the class. + * storage with ``_torch_compile_flatten`` -> ``len(tensors)`` slots. """ - outputs = list(result[:num_outputs]) - flat: List[torch.Tensor] = [] - # Flatten user outputs *before* ``_prepare_for_saving`` -- the - # latter mutates storage instances in place (clears ``_data`` / - # ``_transpose`` / ``_scale_inv``), and the same object can be - # both a user output and a saved-for-backward entry. Doing the - # flatten first observes the original tensor state. - for value in outputs: - if value is None: - flat.append(_encode_none(None)) - elif isinstance(value, torch.Tensor): - if type(value) is not torch.Tensor and hasattr(value, "__tensor_flatten__"): - inner_names, _ = value.__tensor_flatten__() - flat.extend(_encode_none(getattr(value, n)) for n in inner_names) - else: - flat.append(_encode_none(value)) - elif hasattr(value, "_torch_compile_flatten"): - _, _, tensors = value._torch_compile_flatten() - flat.extend(_encode_none(t) for t in tensors) + if value is None: + flat.append(_encode_none(None)) + return + if isinstance(value, torch.Tensor): + if type(value) is not torch.Tensor and hasattr(value, "__tensor_flatten__"): + inner_names, _ = value.__tensor_flatten__() + flat.extend(_encode_none(getattr(value, n)) for n in inner_names) else: - raise TypeError( - f"unsupported output type {type(value).__name__}; expected " - "None / torch.Tensor / tensor subclass with __tensor_flatten__ / " - "class with _torch_compile_flatten." - ) - tensors_to_save, _ = _prepare_for_saving(result[num_outputs]) - flat.extend(_encode_none(t) for t in tensors_to_save) + flat.append(_encode_none(value)) + return + if hasattr(value, "_torch_compile_flatten"): + _, _, tensors = value._torch_compile_flatten() + flat.extend(_encode_none(t) for t in tensors) + return + raise TypeError( + f"unsupported value type {type(value).__name__}; expected " + "None / torch.Tensor / tensor subclass with __tensor_flatten__ / " + "class with _torch_compile_flatten." + ) + + +# Number of trailing slots in every ``fwd_impl`` return tuple: +# ``tensors_to_save, tensor_objects, ctx_attrs``. Everything before +# those is a user output, so ``num_outputs = len(result) - +# _FWD_TRAILING_SLOTS`` -- the same convention every fake-impl (hand +# written or auto-synthesized) follows. +_FWD_TRAILING_SLOTS = 3 + + +def _format_fwd_result(result: Any) -> List[torch.Tensor]: + """Pack a fwd-impl return tuple into the op's ``Tensor[]`` payload. + + User outputs come first, then the saved-for-backward tensors in + declaration order. Both groups go through the same per-value + :func:`_flatten_value_into` dispatch -- the slot layout produced + here must match exactly what :meth:`TensorSpec.slot_count` reports + for the corresponding spec, since the call-site reassembly in + :func:`forward_fn` / :func:`_setup_context` slices this flat list + back into user-facing objects using those per-spec counts. + + ``None`` entries on either side are smuggled through + :func:`_encode_none` so the schema stays non-nullable and + ``register_autograd`` still attaches a ``grad_fn`` to the op's + outputs. + + The split point between user outputs and saved tensors is + inferred from the impl's return shape: + ``(*user_outputs, tensors_to_save, tensor_objects, ctx_attrs)`` + -- the last three slots are the standard ``fwd_impl`` tail. + """ + num_outputs = len(result) - _FWD_TRAILING_SLOTS + flat: List[torch.Tensor] = [] + for value in result[:num_outputs]: + _flatten_value_into(flat, value) + saved = result[num_outputs] or () + for value in saved: + _flatten_value_into(flat, value) return flat @@ -1486,7 +1882,6 @@ def _format_fwd_result( def _run_fake_for_proto( fwd_fake_impl: Callable[[Any], Any], fwd_obj: Any, - num_outputs: int, ) -> List[Any]: """Execute ``fwd_fake_impl(fwd_obj)`` in isolation and return its user-facing outputs to be used as prototypes for output layout @@ -1512,6 +1907,7 @@ def _run_fake_for_proto( with _disable_current_modes(): with FakeTensorMode(allow_non_fake_inputs=True): result = fwd_fake_impl(fwd_obj) + num_outputs = len(result) - _FWD_TRAILING_SLOTS return list(result[:num_outputs]) @@ -1678,7 +2074,6 @@ def _register_autograd_for_op( fwd_buckets: List[_Bucket], bwd_arg_names: List[str], bwd_buckets: List[_Bucket], - num_outputs: int, fwd_slot_defaults: List[Any], grad_targets: List[Tuple[int, bool]], fwd_fake_impl: Optional[Callable[[Any], Any]], @@ -1702,11 +2097,15 @@ def _register_autograd_for_op( tensor, sliced via: * ``output_info_fn(fwd_obj)`` -- the recommended path: a pure - Python function that returns the static - ``(user_specs, saved_specs, ctx_attrs)`` tuple. Traceable by - Dynamo / AOT, no fake tensor allocation involved. + Python function that returns + ``(user_specs: List[TensorSpec], saved_slots: List[TensorSpec], + ctx_attrs)``. Traceable by Dynamo / AOT, no fake tensor + allocation involved. :class:`AliasedSpec` entries on the saved + side carry the forward-arg name the slot aliases, surfaced to + the user's ``setup_context`` via + ``ctx_attrs["saved_tensor_aliases"]``. * legacy ``fwd_fake_impl(fwd_obj)`` -- runs the user fake impl - and extracts layouts via :func:`_extract_layout`. Kept for + and extracts layouts via :meth:`TensorSpec.from_proto`. Kept for backwards compatibility with callers that haven't migrated to ``output_info_fn`` yet. """ @@ -1722,57 +2121,46 @@ def _setup_context(ctx, inputs, output): fwd_obj = _unpack(fwd_arg_type, kwargs, fwd_buckets) if output_info_fn is not None: - user_specs, tensor_objects, ctx_attrs, _fake_specs = output_info_fn(fwd_obj) - cursor = 0 - user_outputs: List[Any] = [] - for spec in user_specs: - n = _spec_slot_count(spec) - chunk = [_decode_none(t) for t in output[cursor:cursor + n]] - cursor += n - user_outputs.append(_reassemble_from_spec(spec, chunk)) - - # ``tensor_objects`` is the same shape :func:`prepare_for_saving` - # would produce: a list with one entry per element of - # ``tensors_to_save_from_forward`` -- ``None`` for plain - # tensors, a storage shell (with ``_data`` / ``_scale_inv`` - # / ... set to ``None``) for quantized storages. The shells - # are constructed inside ``output_info_fn`` via simple - # ``object.__new__`` + attribute writes so Dynamo can carry - # them as ``UserDefinedObjectVariable``s across the trace - # boundary. :func:`_restore_from_saved` reads each shell's - # ``restore_from_saved`` to consume the right number of - # slots from ``op_saved_tensors``. - op_saved_tensors = [_decode_none(t) for t in output[cursor:]] - tensors_to_save_from_forward = _restore_from_saved( - tensor_objects, - op_saved_tensors, - ) + user_specs, saved_slots, ctx_attrs = output_info_fn(fwd_obj) + ctx_attrs = _inject_saved_aliases(ctx_attrs, saved_slots) else: + # Legacy path: learn output and saved-tensor layouts from a + # fake run of the user fwd impl, then reassemble both via + # the same :class:`TensorSpec` machinery. The fake return + # follows the same ``(*user_outputs, tensors_to_save, + # tensor_objects, ctx_attrs)`` shape as the real impl, so + # the user-output count is just ``len(result) - + # _FWD_TRAILING_SLOTS``. fake_result = fake_for_setup(fwd_obj) - # Learn output layouts from the fake result. - layouts = [_extract_layout(p) for p in fake_result[:num_outputs]] - - cursor = 0 - user_outputs = [] - for layout in layouts: - n = _spec_slot_count(layout) - chunk = [_decode_none(t) for t in output[cursor:cursor + n]] - cursor += n - user_outputs.append(_reassemble_from_spec(layout, chunk)) - - op_saved_tensors = [_decode_none(t) for t in output[cursor:]] - _, tensor_objects = _prepare_for_saving(fake_result[num_outputs]) + num_outputs = len(fake_result) - _FWD_TRAILING_SLOTS + user_specs = [ + TensorSpec.from_proto(p) for p in fake_result[:num_outputs] + ] + saved_protos = fake_result[num_outputs] or () + saved_slots = [TensorSpec.from_proto(p) for p in saved_protos] ctx_attrs = fake_result[num_outputs + 2] - tensors_to_save_from_forward = _restore_from_saved( - tensor_objects, - op_saved_tensors, - ) + + cursor = 0 + user_outputs: List[Any] = [] + for spec in user_specs: + n = spec.slot_count() + chunk = [_decode_none(t) for t in output[cursor:cursor + n]] + cursor += n + user_outputs.append(spec.reassemble(chunk)) + + tensors_to_save_from_forward_list: List[Any] = [] + for spec in saved_slots: + n = spec.slot_count() + chunk = [_decode_none(t) for t in output[cursor:cursor + n]] + cursor += n + tensors_to_save_from_forward_list.append(spec.reassemble(chunk)) + tensors_to_save_from_forward = tuple(tensors_to_save_from_forward_list) bwd_obj = backward_obj_type() tensors_to_save_from_setup = setup_context_user( bwd_obj, fwd_obj, - user_outputs[0] if num_outputs == 1 else tuple(user_outputs), + user_outputs[0] if len(user_specs) == 1 else tuple(user_outputs), ctx_attrs, tensors_to_save_from_forward, ) @@ -1868,11 +2256,36 @@ def _outer_fake(*flat: Any) -> List[torch.Tensor]: ) +def _all_quantized_tensor_subclasses() -> List[type]: + """Return every imported ``QuantizedTensor`` wrapper subclass. + + Imports the ``transformer_engine.pytorch.tensor`` package as a side + effect so that all concrete wrapper subclasses (``Float8Tensor``, + ``MXFP8Tensor``, ``Float8BlockwiseQTensor``, ``NVFP4Tensor``) get + registered with Python's subclass tracker before we walk + ``QuantizedTensor.__subclasses__()`` recursively. The lazy import + keeps ``dynamo.py`` itself free of top-level ``tensor`` imports + (which would form a cycle through the in-function ``dynamo`` + imports inside the tensor modules), while still giving every + custom op the full subclass set at registration time. + """ + import transformer_engine.pytorch.tensor # noqa: F401 -- side-effect: registers subclasses + from transformer_engine.pytorch.quantized_tensor import QuantizedTensor + + seen: List[type] = [] + stack: List[type] = list(QuantizedTensor.__subclasses__()) + while stack: + cls = stack.pop() + if cls in seen: + continue + seen.append(cls) + stack.extend(cls.__subclasses__()) + return seen + + def _te_register_custom_op( *, op_name: str, - num_outputs: Optional[int] = None, - output_annotations: Optional[Sequence[Any]] = None, input_tensors_for_grad: List[str], fwd_arg_type: type, fwd_impl: Callable[[Any], Any], @@ -1882,38 +2295,27 @@ def _te_register_custom_op( backward_obj: type, backward_impl: Callable[[Any], Any], backward_fake_impl: Optional[Callable[[Any], Any]] = None, - subclasses: Optional[Sequence[type]] = None, output_info_fn: Optional[ Callable[ [Any], - Tuple[List[Tuple[Any, ...]], List[Any], Any, Dict[str, Any]], + Tuple[List["TensorSpec"], List["TensorSpec"], Dict[str, Any]], ] ] = None, - bwd_output_info_fn: Optional[ - Callable[[Any], List[Optional[Tuple[Any, ...]]]] - ] = None, + bwd_output_info_fn: Optional[Callable[[Any], List["TensorSpec"]]] = None, ) -> Callable[..., Any]: """Register a TE module's forward + backward as a single torch custom op. + The user-output count is derived dynamically at call time from + the impl return shape: ``num_outputs = len(result) - + _FWD_TRAILING_SLOTS`` (the impl tail is always + ``tensors_to_save, tensor_objects, ctx_attrs``). No explicit + ``num_outputs`` argument is required. + Parameters ---------- op_name Op name used when registering with ``torch.library``. The namespace is fixed at module level (:data:`_TE_OP_NAMESPACE`). - num_outputs - Number of user-facing outputs returned by ``fwd_impl``. May be - inferred from ``output_annotations`` if the latter is provided. - output_annotations - Optional per-output type annotation, e.g. - ``[Union[torch.Tensor, Float8Tensor], - Optional[Union[torch.Tensor, Float8TensorStorage]]]``. Kept - for documentation / backward compatibility. The runtime layout - of each output (plain / subclass / storage / ``None``) is - learned dynamically from a fake run of ``fwd_fake_impl`` - executed under ``_disable_current_modes`` and a fresh - ``FakeTensorMode``, so the annotation does not constrain the - flat ``Tensor[]`` payload anymore. If both are passed, - ``num_outputs`` must equal ``len(output_annotations)``. input_tensors_for_grad Names of forward-arg-type fields for which ``backward_impl`` returns gradients, in the same order. The wrapper uses this to @@ -1974,45 +2376,35 @@ def _te_register_custom_op( of the real gradients. output_info_fn Optional pure-Python layout descriptor for the op's outputs: - ``fn(fwd_obj) -> (user_specs, tensor_objects, ctx_attrs, fake_specs)``. - - * ``user_specs`` is a list, one entry per user output, where - each entry is: - - - ``("plain",)`` -- plain :class:`torch.Tensor` (or ``None`` - smuggled via :func:`_encode_none`). - - ``("none",)`` -- explicit ``None`` (single sentinel slot). - - ``("subclass", cls, inner_names, meta, shape, stride)`` -- - tensor subclass, reassembled via - ``cls.__tensor_unflatten__``. - - ``("storage", cls, meta, pg, tensor_count)`` -- non-tensor - storage, reassembled via ``cls._torch_compile_do_unflatten``. - - * ``tensor_objects`` is the structured descriptor that - :func:`prepare_for_saving` would produce on the user's - ``tensors_to_save_from_forward`` tuple: a Python list with - one entry per saved object, ``None`` for plain tensors and a - storage *shell* (typically built via - :meth:`Quantizer.create_save_shell` -- ``object.__new__`` + - attribute writes, no constructor logic) for quantized - storages. :func:`_restore_from_saved` uses these shells to - reconstruct the saved tuple from the flat ``op_saved_tensors`` - payload. + ``fn(fwd_obj) -> (user_specs, saved_slots, ctx_attrs)``. + + * ``user_specs`` is a list, one :class:`TensorSpec` per user + output. Each spec encodes everything dynamo needs about + that slot: ``slot_count()`` for flat-``Tensor[]`` slicing, + ``reassemble(chunk)`` / ``reassemble_with_autograd(chunk)`` + for rebuilding the user-facing object from the op's flat + output, and ``alloc()`` for the auto-synthesized fake-impl + (see below). The four concrete subclasses -- + :class:`NoneSpec`, :class:`PlainTensorSpec`, + :class:`SubclassTensorSpec`, :class:`StorageSpec` -- cover + every output shape TE currently produces. + + * ``saved_slots`` is a list of :class:`TensorSpec`, one per + saved-for-backward slot, mirroring ``user_specs`` but for + the saved-tensor section of the op payload. Use + :class:`AliasedSpec(name)` for slots that the forward impl + leaves as ``None`` because the value is identical to a + forward arg (the alias name is surfaced to + ``setup_context`` via + ``ctx_attrs["saved_tensor_aliases"]``, injected by dynamo). + Use :class:`NoneSpec` / :class:`PlainTensorSpec` / + :class:`StorageSpec` / :class:`SubclassTensorSpec` for the + rest, exactly as for user outputs. * ``ctx_attrs`` is the non-tensor state attached to the autograd context (passed through to ``setup_context``). - - * ``fake_specs`` is a dict with the alloc info needed to - synthesize a fake-impl when ``fwd_fake_impl`` is not - supplied (see :func:`_alloc_from_fake_spec`). Keys: - - - ``"user_outputs"`` -- list of alloc specs (one per user - output) used to materialise the fake tensors / subclasses - / storages returned by the synthesized fake-impl. - - ``"saved_tensors"`` -- ``None`` (no saved tensors, e.g. - ``is_grad_enabled=False``) or a list of alloc specs (one - per saved slot) used to build the synthesized - ``tensors_to_save`` tuple. + Dynamo augments it with ``"saved_tensor_aliases"`` before + the callback runs. When supplied, :func:`forward_fn` and the autograd ``setup_context`` use this function instead of running @@ -2027,44 +2419,37 @@ def _te_register_custom_op( plain-tensor ops, where the fake-impl path is still cheap. bwd_output_info_fn Optional pure-Python alloc descriptor for the backward op: - ``fn(bwd_obj) -> [alloc_spec_per_grad_output]``. Each entry is - either ``None`` (slot is ``None``), ``("plain", shape, dtype, - device)``, or ``("quantized", quantizer, shape, dtype, - device)``. When supplied (and ``backward_fake_impl`` is not), + ``fn(bwd_obj) -> List[TensorSpec]``, one entry per gradient + output in the same order as ``backward_impl``'s return tuple. + Typically :class:`NoneSpec` for missing grads, + :class:`PlainTensorSpec` for plain tensors, and an alloc-only + :class:`SubclassTensorSpec` (built via + :meth:`SubclassTensorSpec.from_quantizer` without a + ``wrapper_cls``) for quantized ones. When supplied (and + ``backward_fake_impl`` is not), :func:`_te_register_custom_op` synthesizes the backward - fake-impl by allocating one fake per spec via - :func:`_alloc_from_fake_spec`. Useful so the backward fake no - longer has to duplicate the gradient-shape derivation that - lives in the eager impl / its layout descriptor. + fake-impl by calling :meth:`TensorSpec.alloc` on each spec -- + the gradient-shape derivation lives entirely in the + descriptor. Returns ------- Callable A function ``forward_fn(fwd_arg_type_instance)`` that dispatches through the registered custom op, returning the user-facing - outputs (single tensor if ``num_outputs == 1``, otherwise a - tuple). Use under ``torch.compiler.is_compiling()`` as a drop-in - for ``Function.apply``. + outputs (single tensor if the impl produced exactly one + user-facing output, otherwise a tuple). Use under + ``torch.compiler.is_compiling()`` as a drop-in for + ``Function.apply``. """ outer_fwd_name = op_name outer_bwd_name = f"{op_name}_backward" - subclass_list = list(subclasses or ()) - - if output_annotations is not None: - annotated_count = len(list(output_annotations)) - if num_outputs is not None and num_outputs != annotated_count: - raise ValueError( - "_te_register_custom_op: num_outputs=" - f"{num_outputs} does not match len(output_annotations)=" - f"{annotated_count}" - ) - num_outputs = annotated_count - if num_outputs is None: - raise ValueError( - "_te_register_custom_op requires either ``num_outputs`` or " - "``output_annotations``" - ) + # Auto-discover every imported ``QuantizedTensor`` wrapper subclass + # so callers never have to enumerate them. Each subclass gets a + # ``register_torch_dispatch`` rule on the outer op (see below) and + # is flattened into plain tensors before the inner op runs. + subclass_list = _all_quantized_tensor_subclasses() # Precompute the bucket list once per arg type and capture it in # the registered closures. Re-deriving the bucket list inside a @@ -2083,7 +2468,9 @@ def _te_register_custom_op( fwd_buckets, fwd_arg_type, input_tensors_for_grad ) - # Two-tier layout when subclass dispatch rules are requested: + # Two-tier layout when at least one ``QuantizedTensor`` subclass is + # imported (the common case -- ``_all_quantized_tensor_subclasses`` + # discovers them automatically): # inner = ``{op_name}_base`` -- real impl, sees only plain tensors # and the storage-flatten metadata. # outer = ``{op_name}`` -- user-facing op that either falls through @@ -2092,8 +2479,9 @@ def _te_register_custom_op( # call to the inner op with subclass tensors flattened in # place. Both tiers carry their own ``register_autograd`` # bridge. - # Single-tier when no subclasses are given: only the outer pair is - # defined and it owns the real impl (today's behaviour). + # Single-tier fallback: if no ``QuantizedTensor`` subclasses have + # been imported (e.g. minimal embedded build) only the outer pair + # is defined and it owns the real impl directly. inner_fwd_name = f"{op_name}_base" if subclass_list else outer_fwd_name inner_bwd_name = f"{outer_bwd_name}_base" if subclass_list else outer_bwd_name @@ -2133,9 +2521,7 @@ def _te_register_custom_op( # when supplied, so callers can stage the migration op-by-op. effective_fwd_fake_impl = fwd_fake_impl if effective_fwd_fake_impl is None and output_info_fn is not None: - effective_fwd_fake_impl = _make_fake_impl_from_output_info( - output_info_fn, num_outputs - ) + effective_fwd_fake_impl = _make_fake_impl_from_output_info(output_info_fn) effective_bwd_fake_impl = backward_fake_impl if effective_bwd_fake_impl is None and bwd_output_info_fn is not None: effective_bwd_fake_impl = _make_fake_impl_from_bwd_output_info( @@ -2150,7 +2536,7 @@ def _te_register_custom_op( buckets=fwd_buckets, impl=fwd_impl, fake_impl=effective_fwd_fake_impl, - format_result=lambda r: _format_fwd_result(r, num_outputs), + format_result=_format_fwd_result, ) _register_kernel( op_name=inner_bwd_name, @@ -2170,7 +2556,6 @@ def _te_register_custom_op( fwd_buckets=fwd_buckets, bwd_arg_names=bwd_arg_names, bwd_buckets=bwd_buckets, - num_outputs=num_outputs, fwd_slot_defaults=fwd_slot_defaults, grad_targets=grad_targets, fwd_fake_impl=effective_fwd_fake_impl, @@ -2216,7 +2601,6 @@ def _te_register_custom_op( fwd_buckets=fwd_buckets, bwd_arg_names=bwd_arg_names, bwd_buckets=bwd_buckets, - num_outputs=num_outputs, fwd_slot_defaults=fwd_slot_defaults, grad_targets=grad_targets, fwd_fake_impl=effective_fwd_fake_impl, @@ -2320,15 +2704,10 @@ def forward_fn(fwd_args): # via ``_run_fake_for_proto`` (``@torch._dynamo.allow_in_graph`` # so it stays opaque to Dynamo). if output_info_fn is not None: - ( - user_specs, - _tensor_objects, - _ctx_attrs, - _fake_specs, - ) = output_info_fn(fwd_args) + user_specs, _saved_slots, _ctx_attrs = output_info_fn(fwd_args) else: - proto_outputs = _run_fake_for_proto(proto_fn, fwd_args, num_outputs) - user_specs = [_extract_layout(p) for p in proto_outputs] + proto_outputs = _run_fake_for_proto(proto_fn, fwd_args) + user_specs = [TensorSpec.from_proto(p) for p in proto_outputs] # 2) Invoke the op (graph node). Result is the flat ``Tensor[]`` # payload produced by :func:`_format_fwd_result`. @@ -2337,36 +2716,22 @@ def forward_fn(fwd_args): result = fwd_op(*flat_in) # 3) Slice the flat result by spec and reassemble each user - # output. Tensor subclasses go through :class:`_ToSubclassFn` - # so the construction is recorded on the autograd graph and - # Dynamo lifts it as an ``autograd.Function`` call; - # ``QuantizedTensorStorage``-style objects (no autograd of - # their own) are reconstructed directly. + # output. :meth:`TensorSpec.reassemble_with_autograd` routes + # subclass paths through :class:`_ToSubclassFn` so the + # construction is recorded on the autograd graph and Dynamo + # lifts it as an ``autograd.Function`` call; plain tensors and + # storage classes (which have no autograd identity of their + # own) are reconstructed directly. cursor = 0 outputs: List[Any] = [] for spec in user_specs: - n = _spec_slot_count(spec) + n = spec.slot_count() chunk_raw = result[cursor:cursor + n] cursor += n chunk = [_decode_none(t) for t in chunk_raw] - kind = spec[0] - if kind == "none": - outputs.append(None) - elif kind == "plain": - outputs.append(chunk[0]) - elif kind == "subclass": - _, cls, inner_names, meta, shape, stride = spec - outputs.append( - _ToSubclassFn.apply( - cls, inner_names, meta, shape, stride, *chunk - ) - ) - else: # "storage" - _, cls, meta, pg, _slot_count = spec - real_tensors = [t for t in chunk if t is not None] - outputs.append(cls._torch_compile_do_unflatten(meta, pg, real_tensors)) + outputs.append(spec.reassemble_with_autograd(chunk)) - if num_outputs == 1: + if len(outputs) == 1: return outputs[0] return tuple(outputs) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0468ef6ad5..9ed57e51b8 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -57,7 +57,11 @@ general_gemm, ) from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type -from ..dynamo import _te_register_custom_op +from ..dynamo import ( + TensorSpec, + _te_register_custom_op, + tensor_spec, +) from ..graph import is_graph_capturing from ..quantized_tensor import ( QuantizedTensor, @@ -1286,22 +1290,31 @@ def wgrad_gemm( ) +# ---------------------------------------------------------------------------- +# Compile-tier wrappers: ``output_info_fn`` descriptors + ``_te_register_custom_op`` +# registration. The custom op lets ``torch.compile`` trace through linear +# forward + backward as a single graph node without entering the eager +# ``_Linear`` autograd.Function machinery. Selected by :meth:`Linear.forward` +# when ``torch.compiler.is_compiling()`` is true. +# ---------------------------------------------------------------------------- def _linear_backward_output_info( args: LinearBwdArgs, -) -> List[Optional[Tuple[Any, ...]]]: +) -> List[TensorSpec]: """Pure-Python alloc-spec descriptor for :func:`_linear_backward`. - Returns a list of three alloc specs -- one per gradient output - ``(wgrad, dgrad, grad_bias)`` -- consumed by the auto-synthesized - backward fake-impl in :func:`_make_fake_impl_from_bwd_output_info`. - Replaces the previously hand-written - ``_linear_backward_fake_impl``: gradient shapes/dtypes are - deterministic, so the descriptor just encodes them as alloc - tuples (``("plain", ...)`` /``("quantized", ...)``) instead of - allocating fake tensors. ``set_usage`` on - ``grad_input_quantizer`` is preserved because it influences - ``dgrad``'s downstream ``make_empty``. Manual TE FSDP is - unsupported; FSDP2 / MCore FSDP go through the standard path. + Returns a list of three :class:`TensorSpec` -- one per gradient + output ``(wgrad, dgrad, grad_bias)`` -- consumed by the + auto-synthesized backward fake-impl in + :func:`_make_fake_impl_from_bwd_output_info`. Replaces the + previously hand-written ``_linear_backward_fake_impl``: gradient + shapes / dtypes are deterministic, so the descriptor just encodes + each slot through :func:`tensor_spec` (passing ``shape=None`` for + absent grads and a ``quantizer`` for quantized ones -- backward + grads use alloc-only ``SubclassTensorSpec`` because they go + straight to autograd, never through the op's flat ``Tensor[]``). + ``set_usage`` on ``grad_input_quantizer`` is preserved because it + influences ``dgrad``'s downstream ``make_empty``. Manual TE FSDP + is unsupported; FSDP2 / MCore FSDP go through the standard path. """ if args.fsdp_group is not None: @@ -1319,158 +1332,47 @@ def _linear_backward_output_info( activation_dtype = args.activation_dtype device = args.grad_output.device - def _alloc( - shape: Tuple[int, ...], quantizer: Any - ) -> Tuple[Any, ...]: - if quantizer is not None: - return ("quantized", quantizer, tuple(shape), activation_dtype, device) - return ("plain", tuple(shape), activation_dtype, device) - - wgrad_alloc: Optional[Tuple[Any, ...]] = None - if args.requires_wgrad and not args.fuse_wgrad_accumulation: - wgrad_alloc = _alloc((out_features, in_features), args.grad_weight_quantizer) - - dgrad_alloc: Optional[Tuple[Any, ...]] = None - if args.requires_dgrad: - dgrad_alloc = _alloc(tuple(args.inp_shape), args.grad_input_quantizer) - - grad_bias_alloc: Optional[Tuple[Any, ...]] = None - if args.use_bias and args.requires_wgrad: - grad_bias_alloc = ("plain", (out_features,), activation_dtype, device) - - return [wgrad_alloc, dgrad_alloc, grad_bias_alloc] - - -class _Linear(torch.autograd.Function): - """Linear semi-top level module - Calls custom cuda extensions. - """ - - @staticmethod - def forward( - ctx, - weight: torch.Tensor, - inp: torch.Tensor, - bias: Optional[torch.Tensor], - fwd_args: LinearFwdArgs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Forward pass: compute linear output and set up autograd context. - - ``weight``, ``inp`` and ``bias`` are positional Tensor arguments so - autograd tracks them; they are immediately re-attached to ``fwd_args`` - so every downstream helper can be invoked with a single argument. - - ``weight_workspace`` is intentionally NOT a positional input: it is a - non-differentiable cached tensor passed in via - ``fwd_args.weight_workspace`` and the freshly produced workspace is - returned as a separate output so the module can refresh its cache. - """ - fwd_args.weight = weight - fwd_args.inp = inp - fwd_args.bias = bias - ( - out, - new_weight_workspace, - tensors_to_save_from_forward, - _, - ctx_attrs, - ) = _linear_forward_impl(fwd_args) - if ctx is not None: - bwd_args = LinearBwdArgs() - tensors_to_save_from_setup = _linear_setup_ctx( - bwd_args, - fwd_args, - out, - ctx_attrs, - tensors_to_save_from_forward, - ) - tensors_to_save, tensor_objects = prepare_for_saving(*tensors_to_save_from_setup) - ctx.save_for_backward(*tensors_to_save) - ctx.tensor_objects = tensor_objects - ctx.backward_objects = bwd_args - if fwd_args.fp8 and ( - fwd_args.input_requires_grad - or fwd_args.weight_requires_grad - or fwd_args.bias_requires_grad - ): - bwd_args.reduce_and_update_bwd_fp8_tensors = _check_fp8_reduce_and_update() - if fwd_args.backward_override is not None: - bwd_args.reduce_and_update_bwd_fp8_tensors = False - - return out, new_weight_workspace + wgrad_shape = ( + (out_features, in_features) + if args.requires_wgrad and not args.fuse_wgrad_accumulation + else None + ) + dgrad_shape = tuple(args.inp_shape) if args.requires_dgrad else None + grad_bias_shape = (out_features,) if args.use_bias and args.requires_wgrad else None - @staticmethod - def backward( - ctx, - grad_output: torch.Tensor, - _grad_weight_workspace, - ) -> Tuple[Union[torch.Tensor, None], ...]: - """Backward pass: compute gradients and reduce FP8 scaling factors.""" - bwd_args: LinearBwdArgs = ctx.backward_objects - bwd_args.grad_output = grad_output - bwd_args.setup_saved_tensors(ctx) - ctx.tensor_objects = None - nvtx_label = "transformer_engine._Linear.backward" - if bwd_args.ub_name is not None: - nvtx_label = f"{nvtx_label}.{bwd_args.ub_name}" - result = _linear_backward(bwd_args) + (None,) # fwd_args grad slot - reduce_and_update_bwd_fp8_tensors = bwd_args.reduce_and_update_bwd_fp8_tensors - # Drop all references held by bwd_args (saved tensors, quantizers, weakrefs, - # main_grad closure) so they don't outlive backward via ctx under retain_graph. - ctx.backward_objects = None - del bwd_args - if reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): - nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) - nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") - return result + return [ + tensor_spec( + shape=wgrad_shape, + dtype=activation_dtype, + device=device, + quantizer=args.grad_weight_quantizer, + ), + tensor_spec( + shape=dgrad_shape, + dtype=activation_dtype, + device=device, + quantizer=args.grad_input_quantizer, + ), + tensor_spec( + shape=grad_bias_shape, + dtype=activation_dtype, + device=device, + ), + ] -# Register the linear forward + backward as a single torch custom op so that -# ``torch.compile`` can trace through it without entering the eager -# ``torch.autograd.Function`` machinery. Used by :meth:`Linear.forward` -# under ``torch.compiler.is_compiling()``. def _linear_forward_output_info( args: LinearFwdArgs, -) -> Tuple[List[Tuple[Any, ...]], List[Any], Dict[str, Any], Dict[str, Any]]: - """Pure-Python output-layout descriptor for the linear forward. - - Returns ``(user_specs, tensor_objects, ctx_attrs, fake_specs)`` -- - the static, Dynamo-traceable single source of truth for the - forward op's output layout, saved-tensor bookkeeping, and - fake-impl allocation hints. Replaces the previously hand-written - ``_linear_forward_fake_impl``: :func:`_te_register_custom_op` now - auto-synthesizes the fake-impl from ``fake_specs`` via - :func:`_make_fake_impl_from_output_info`, so every per-precision / - per-mode condition lives in exactly one place. - - Why a separate descriptor (vs. a hand-written fake-impl): - constructing real :class:`Float8Tensor` / - :class:`MXFP8TensorStorage` / ... instances inside a fake-impl - relies on the live quantizers, which under ``fullgraph=True`` - Dynamo refuses to trace through (live quantizers are - :class:`UserDefinedObjectVariable`, ``tex.DType`` is a pybind - enum, ...). The descriptor instead emits: - - * ``user_specs`` / ``tensor_objects`` -- pure-Python tuples and - ``object.__new__``-built shells (via - :meth:`Quantizer.create_metadata` / - :meth:`Quantizer.create_save_shell`); consumed by - :func:`forward_fn` and :func:`_setup_context` to reassemble - subclasses and restore saved storages. - * ``fake_specs`` -- alloc tuples - ``("plain", shape, dtype, device)`` / - ``("quantized", quantizer, shape, dtype, device)``; consumed by - the auto-synthesized fake-impl (which only runs under - ``FakeTensorMode``, never under Dynamo's trace -- so live - quantizers / pybind enums are fine here). - - The four return values keep the same branching: every - ``set_usage`` / ``update_usage`` side effect on the live - quantizers happens once in this function and stays consistent - across the layout / fake / forward paths; downstream code - (especially backward) reads the post-forward usage flags off - the same quantizer instance. +) -> Tuple[List[TensorSpec], List[TensorSpec], Dict[str, Any]]: + """Output-layout descriptor for the linear forward. + + Returns ``(user_specs, saved_slots, ctx_attrs)`` -- Dynamo-traceable + layout + alloc info for the op's outputs and saved tensors. Replaces + a hand-written fake-impl: :func:`_te_register_custom_op` synthesizes + one by calling :meth:`TensorSpec.alloc` on each entry. + + All ``set_usage`` side effects on the live quantizers happen here + and are observed by both the real fwd impl and backward. """ fp8 = args.fp8 debug = args.debug @@ -1498,25 +1400,9 @@ def _linear_forward_output_info( and not args.ub_overlap_ag_fprop ) - # ------------------------------------------------------------------ - # Input pipeline -- mirror :func:`_linear_forward_impl`'s - # ``set_usage`` calls and track which of three end-states the - # ``saved_inputmat`` slot will land in: - # - # * ``inputmat_aliases_inp`` -- saved value IS ``args.inp`` - # (impl-side ``saved_tensor_aliases[0] = "inp"``, slot stored - # as ``None`` and resolved back to ``args.inp`` in - # ``_linear_setup_ctx``). - # * ``inputmat_is_storage`` -- saved value is a fresh - # ``QuantizedTensorStorage`` (created here only as a tensor-free - # shell; the impl produces the real one; the auto-synthesized - # fake-impl allocates a fake one from the slot's alloc spec). - # * neither -- saved value is a plain ``Tensor`` (the cast result). - # - # The branches below match :func:`_linear_forward_impl` - # line-for-line; comments cross-reference the mirrored block when - # not obvious. - # ------------------------------------------------------------------ + # Input pipeline -- mirror ``_linear_forward_impl``'s ``set_usage`` + # calls and classify the ``saved_inputmat`` slot end-state: + # aliased to ``args.inp``, fresh quantized storage, or plain cast. inputmat_is_storage = False inputmat_aliases_inp = False own_quantized_input = False @@ -1542,15 +1428,12 @@ def _linear_forward_output_info( inputmat_is_storage = True else: inputmat_aliases_inp = inp.dtype == activation_dtype - # ``inputmat_total`` only matters for the GEMM output shape; the - # all-gather inflates the leading dim by ``tp_world_size``. + # All-gather inflates the leading dim of the GEMM-input shape. inputmat_total_shape = list(inp.shape) inputmat_total_shape[0] *= tp_world_size else: if fp8_or_debug: if isinstance(inp, QuantizedTensorStorage): - # In-place ``update_usage`` on the original storage; - # ``inputmat is args.inp`` stays true downstream. inp.update_usage(rowwise_usage=True) inputmat_is_storage = True inputmat_aliases_inp = True @@ -1570,35 +1453,17 @@ def _linear_forward_output_info( else: inputmat_aliases_inp = inp.dtype == activation_dtype - # ``save_original_input`` (and ``backward_override == "high_precision"`` - # in particular) flips ``inputmat`` back to ``args.inp`` at the - # tail of the impl, overriding whatever the input pipeline - # produced above. We mirror that here by forcing the alias bit - # so the saved slot tracks the impl's final ``saved_inputmat is - # args.inp`` check. + # ``save_original_input`` / ``backward_override="high_precision"`` + # flip ``inputmat`` back to ``args.inp`` at the tail of the impl; + # mirror that here so the saved slot ends up aliased. if save_original_input: inputmat_aliases_inp = True inputmat_is_storage = False - # ------------------------------------------------------------------ - # Weight pipeline -- mirror of :func:`_linear_forward_impl`'s - # ``quantize_weight`` / ``cast_if_needed`` branches. Tracks the - # same three end-states for ``wt_save``: - # - # * ``weightmat_aliases_weight`` -- ``saved_tensor_aliases[1]`` is - # ``"weight"`` and the slot ends up resolving back to - # ``args.weight`` inside ``_linear_setup_ctx``. - # * ``weightmat_is_storage`` (and not aliased) -- a freshly built - # :class:`QuantizedTensorStorage` (real one in the impl, a - # tensor-free shell here for saved-slot bookkeeping). - # * neither -- a plain cast ``Tensor``. - # - # ``new_weight_workspace_spec`` is the user-output [1] slot: - # non-``("none",)`` only on the cache-miss + ``cache_weight`` - # combination, mirroring the weight-workspace caching branch in - # :func:`_linear_forward_impl`. - # ------------------------------------------------------------------ - new_weight_workspace_spec: Tuple[Any, ...] = ("none",) + # Weight pipeline -- mirror ``quantize_weight`` / ``cast_if_needed``. + # ``new_weight_workspace_spec`` is non-``NoneSpec`` only on the + # cache-miss + ``cache_weight`` combination. + new_weight_workspace_spec: TensorSpec = tensor_spec() weightmat_is_storage = False weightmat_aliases_weight = False if fp8_or_debug: @@ -1620,43 +1485,31 @@ def _linear_forward_output_info( weight_quantizer = weight._quantizer if isinstance(weight, QuantizedTensorStorage): - # ``_linear_forward_impl`` short-circuits the weight pipeline - # on a primary-quantized weight: ``weightmat = weight``. + # Primary-quantized weight: the impl reuses it as ``weightmat``. weightmat_is_storage = True weightmat_aliases_weight = True else: weightmat_is_storage = True - # ``new_weight_workspace`` is non-``None`` only when we miss - # the workspace cache *and* the caller asked us to publish - # the freshly-built workspace back. workspace = args.weight_workspace if workspace is not None and not _is_weight_workspace_valid( workspace, weight_quantizer ): workspace = None if workspace is None and args.cache_weight: - cls, meta, pg, count = weight_quantizer.create_storage_metadata( + new_weight_workspace_spec = tensor_spec( shape=weight.shape, - fake_dtype=activation_dtype, + dtype=activation_dtype, device=weight.device, - requires_grad=False, - as_tensor=False, + quantizer=weight_quantizer, + storage=True, ) - new_weight_workspace_spec = ("storage", cls, meta, pg, count) - # ``weightmat.update_usage(rowwise_usage=True)`` runs in the - # impl after this point; that's a no-op on the layout flags - # we track here (we already requested ``rowwise=True`` above). else: weightmat_aliases_weight = weight.dtype == activation_dtype - # ------------------------------------------------------------------ - # Output configuration - # ------------------------------------------------------------------ if output_quantizer is not None: output_quantizer.set_usage(rowwise=True, columnwise=False) - # Compute the GEMM-output shape and the post-comm shape that - # leaves the op. + # Post-comm output shape (the value that leaves the op). gemm_out_shape: List[int] = list(inputmat_total_shape[:-1]) + [out_features] if args.ub_overlap_rs_fprop: out_shape: List[int] = list(inp.shape) @@ -1668,46 +1521,26 @@ def _linear_forward_output_info( else: out_shape = list(gemm_out_shape) - # ------------------------------------------------------------------ - # Build user-output spec [0] -- the GEMM result. - # ------------------------------------------------------------------ - if output_quantizer is None: - out_spec: Tuple[Any, ...] = ("plain",) - else: - # The only subclass we declare in ``output_annotations`` is - # :class:`Float8Tensor`; other quantizer families flow their - # workspace through ``new_weight_workspace`` instead. - inner_names, meta = output_quantizer.create_metadata( - fake_dtype=activation_dtype, - requires_grad=False, - ) - stride = _contiguous_stride(out_shape) - out_spec = ( - "subclass", - Float8Tensor, - inner_names, - meta, - tuple(out_shape), - stride, - ) + # User-output [0] -- the GEMM result. ``Float8Tensor`` is the only + # quantized wrapper this op produces directly; other quantizer + # families flow their workspace through ``new_weight_workspace`` + # instead. + out_spec = tensor_spec( + shape=tuple(out_shape), + dtype=activation_dtype, + device=inp.device, + quantizer=output_quantizer, + wrapper_cls=Float8Tensor if output_quantizer is not None else None, + ) - user_specs: List[Tuple[Any, ...]] = [out_spec, new_weight_workspace_spec] + user_specs: List[TensorSpec] = [out_spec, new_weight_workspace_spec] - # ------------------------------------------------------------------ - # Saved-for-backward tensor_objects + saved_tensor_aliases - # ------------------------------------------------------------------ - tensor_objects: List[Any] = [None, None, None, None] - saved_inputmat_alias: Optional[str] = None - wt_save_alias: Optional[str] = None - bias_alias: Optional[str] = None + saved_slots: List[TensorSpec] = [] if args.is_grad_enabled: - # Post-forward ``update_usage`` on the cached input. The - # in-place ``set_usage`` flips ``input_quantizer`` 's row/col - # bits so backward sees the same storage layout the impl - # ended up with. (Mirrors the matching block in - # :func:`_linear_forward_impl`; we only need to track the - # side effect on the quantizer, no shell rebuild.) + # Post-forward ``set_usage`` -- mirrors ``_linear_forward_impl`` + # so backward observes the same row/col layout on the input + # quantizer the impl ended up with. if ( backward_needs_input and own_quantized_input @@ -1724,177 +1557,66 @@ def _linear_forward_output_info( else: input_quantizer.set_usage(rowwise=False, columnwise=True) - if backward_needs_input: - if inputmat_aliases_inp: - saved_inputmat_alias = "inp" - elif inputmat_is_storage: - # Fresh storage produced by ``input_quantizer``; emit a - # shell so ``_restore_from_saved`` consumes the right - # number of slots from the saved-tensor payload. - tensor_objects[0] = input_quantizer.create_save_shell( - fake_dtype=activation_dtype, - ) - # else: plain Tensor saved -> tensor_objects[0] stays None. - # else: ``saved_inputmat = None`` -> tensor_objects[0] stays None. - - if weightmat_aliases_weight: - wt_save_alias = "weight" - elif args.is_fsdp2: - # ``wt_save = None`` in :func:`_linear_forward_impl` when - # ``weightmat is not args.weight``; FSDP2 re-quantizes - # from the all-gathered weight on backward. - pass - elif weightmat_is_storage: - tensor_objects[1] = weight_quantizer.create_save_shell( - fake_dtype=activation_dtype, - ) - # else: plain cast Tensor saved -> tensor_objects[1] stays None. - - if bias is not None: - bias_alias = "bias" - - saved_tensor_aliases = ( - saved_inputmat_alias, - wt_save_alias, - "weight", - bias_alias, - ) - - # Manual TE FSDP unsupported under compile. - if args.fsdp_group is not None and args.is_grad_enabled: - raise NotImplementedError( - "Compile-time Linear forward does not support manual TE FSDP " - "(fsdp_group is not None); use FSDP2 or MCore FSDP." - ) - fsdp_shapes: List[Any] = [] - - ctx_attrs: Dict[str, Any] = { - "fsdp_shapes": fsdp_shapes, - "saved_tensor_aliases": saved_tensor_aliases, - } - - # ------------------------------------------------------------------ - # Fake-impl allocation specs -- consumed by the auto-synthesized - # fake-impl in :func:`_make_fake_impl_from_output_info`. One alloc - # spec per user output and per saved-tensor slot. Pure data so it - # can be carried across Dynamo's trace boundary as constants / - # ``UserDefinedObjectVariable``s. - # ------------------------------------------------------------------ - if output_quantizer is None: - out_alloc: Tuple[Any, ...] = ( - "plain", tuple(out_shape), activation_dtype, inp.device, - ) - else: - out_alloc = ( - "quantized", - output_quantizer, - tuple(out_shape), - activation_dtype, - inp.device, - ) - - if new_weight_workspace_spec[0] == "none": - new_weight_workspace_alloc: Optional[Tuple[Any, ...]] = None - else: - new_weight_workspace_alloc = ( - "quantized", - weight_quantizer, - tuple(weight.shape), - activation_dtype, - weight.device, - ) - - user_output_allocs: List[Optional[Tuple[Any, ...]]] = [ - out_alloc, - new_weight_workspace_alloc, - ] - - saved_tensor_allocs: Optional[List[Optional[Tuple[Any, ...]]]] - if not args.is_grad_enabled: - saved_tensor_allocs = None - else: - # Slot 0 -- ``saved_inputmat``. ``None`` when nothing is saved - # (alias to ``inp``, or backward doesn't need the input); - # ``("quantized", ...)`` when the saved value is a quantized - # storage (matches ``tensor_objects[0] != None``); ``("plain", - # ...)`` otherwise (a fresh cast). - if not backward_needs_input or saved_inputmat_alias is not None: - slot0_alloc: Optional[Tuple[Any, ...]] = None - elif tensor_objects[0] is not None: - slot0_alloc = ( - "quantized", - input_quantizer, - tuple(inp.shape), - activation_dtype, - inp.device, - ) + # Slot 0 -- ``saved_inputmat``: absent / aliased to ``inp`` / + # fresh quantized storage / plain cast (mutually exclusive). + if not backward_needs_input: + saved_slots.append(tensor_spec()) + elif inputmat_aliases_inp: + saved_slots.append(tensor_spec(alias="inp")) else: - slot0_alloc = ( - "plain", tuple(inp.shape), activation_dtype, inp.device, + saved_slots.append( + tensor_spec( + shape=tuple(inp.shape), + dtype=activation_dtype, + device=inp.device, + quantizer=input_quantizer if inputmat_is_storage else None, + storage=inputmat_is_storage, + ) ) - # Slot 1 -- ``wt_save``. ``None`` when aliased to ``weight`` or - # under FSDP2 (the latter rebuilds the workspace on backward). - # ``args.weight_quantizer`` may differ from the local - # ``weight_quantizer`` (which is reassigned to - # ``weight._quantizer`` when the weight is already a - # :class:`QuantizedTensor`); the saved storage's quantizer must - # match the one the impl uses for re-quantization. + # Slot 1 -- ``wt_save``. The saved storage's quantizer must + # match the one the impl uses for re-quantization, which is + # ``weight._quantizer`` for already-quantized weights. FSDP2 + # re-quantizes from the all-gathered weight on backward, so + # the slot is absent in that case. weight_quantizer_for_save = ( weight._quantizer if isinstance(weight, QuantizedTensor) else args.weight_quantizer ) - if wt_save_alias is not None or args.is_fsdp2: - slot1_alloc: Optional[Tuple[Any, ...]] = None - elif tensor_objects[1] is not None: - slot1_alloc = ( - "quantized", - weight_quantizer_for_save, - tuple(weight.shape), - activation_dtype, - weight.device, - ) + if weightmat_aliases_weight: + saved_slots.append(tensor_spec(alias="weight")) + elif args.is_fsdp2: + saved_slots.append(tensor_spec()) else: - slot1_alloc = ( - "plain", tuple(weight.shape), activation_dtype, weight.device, + saved_slots.append( + tensor_spec( + shape=tuple(weight.shape), + dtype=activation_dtype, + device=weight.device, + quantizer=weight_quantizer_for_save if weightmat_is_storage else None, + storage=weightmat_is_storage, + ) ) - # Slot 2 -- ``saved_weight`` always aliased back to ``weight`` - # by :func:`_linear_setup_ctx``; Slot 3 -- ``saved_bias`` is - # either aliased ("bias") or ``None`` when there is no bias. - # Both stored slots are therefore always ``None``. - saved_tensor_allocs = [slot0_alloc, slot1_alloc, None, None] + # Slot 2 -- ``saved_weight`` (always aliased). Slot 3 -- + # ``saved_bias`` (aliased or absent). + saved_slots.append(tensor_spec(alias="weight")) + saved_slots.append(tensor_spec(alias="bias") if bias is not None else tensor_spec()) - fake_specs: Dict[str, Any] = { - "user_outputs": user_output_allocs, - "saved_tensors": saved_tensor_allocs, - } - - return user_specs, tensor_objects, ctx_attrs, fake_specs + if args.fsdp_group is not None and args.is_grad_enabled: + raise NotImplementedError( + "Compile-time Linear forward does not support manual TE FSDP " + "(fsdp_group is not None); use FSDP2 or MCore FSDP." + ) + ctx_attrs: Dict[str, Any] = {"fsdp_shapes": []} -def _contiguous_stride(shape: List[int]) -> Tuple[int, ...]: - """Row-major contiguous stride for ``shape``.""" - stride: List[int] = [1] * len(shape) - for i in range(len(shape) - 2, -1, -1): - stride[i] = stride[i + 1] * int(shape[i + 1]) - return tuple(stride) + return user_specs, saved_slots, ctx_attrs _linear_compiled_op = _te_register_custom_op( op_name="linear", - # ``out`` may be a plain Tensor (default path) or a ``Float8Tensor`` - # (when an output quantizer is configured, e.g. ``fp8_output=True`` - # on a downstream module wired through ``output_quantizer``). - # ``new_weight_workspace`` is the optional FP8 weight cache: a - # ``Float8TensorStorage`` on cache miss with ``is_first_microbatch`` - # / ``cache_weight``; ``None`` otherwise (the bookkeeping flows - # through the storage flatten path even when ``None``). - output_annotations=[ - Union[torch.Tensor, Float8Tensor], - Optional[Union[torch.Tensor, Float8TensorStorage]], - ], input_tensors_for_grad=["weight", "inp", "bias"], fwd_arg_type=LinearFwdArgs, fwd_impl=_linear_forward_impl, @@ -1904,17 +1626,94 @@ def _contiguous_stride(shape: List[int]) -> Tuple[int, ...]: backward_obj=LinearBwdArgs, backward_impl=_linear_backward, bwd_output_info_fn=_linear_backward_output_info, - # Two-tier custom op: the outer ``linear`` op accepts tensor - # subclasses (e.g. ``Float8Tensor`` as a weight), and an - # ``register_torch_dispatch`` rule flattens each subclass into - # plain tensors plus storage metadata before calling the inner - # ``linear_base`` op. The wrapper's autograd identity stays - # attached to the inner tensors so gradients flow back to the - # user-facing tensor (``Linear.weight.grad`` is populated). - subclasses=[Float8Tensor], ) +class _Linear(torch.autograd.Function): + """Linear semi-top level module + Calls custom cuda extensions. + """ + + @staticmethod + def forward( + ctx, + weight: torch.Tensor, + inp: torch.Tensor, + bias: Optional[torch.Tensor], + fwd_args: LinearFwdArgs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass: compute linear output and set up autograd context. + + ``weight``, ``inp`` and ``bias`` are positional Tensor arguments so + autograd tracks them; they are immediately re-attached to ``fwd_args`` + so every downstream helper can be invoked with a single argument. + + ``weight_workspace`` is intentionally NOT a positional input: it is a + non-differentiable cached tensor passed in via + ``fwd_args.weight_workspace`` and the freshly produced workspace is + returned as a separate output so the module can refresh its cache. + """ + fwd_args.weight = weight + fwd_args.inp = inp + fwd_args.bias = bias + ( + out, + new_weight_workspace, + tensors_to_save_from_forward, + _, + ctx_attrs, + ) = _linear_forward_impl(fwd_args) + if ctx is not None: + bwd_args = LinearBwdArgs() + tensors_to_save_from_setup = _linear_setup_ctx( + bwd_args, + fwd_args, + out, + ctx_attrs, + tensors_to_save_from_forward, + ) + tensors_to_save, tensor_objects = prepare_for_saving(*tensors_to_save_from_setup) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + ctx.backward_objects = bwd_args + if fwd_args.fp8 and ( + fwd_args.input_requires_grad + or fwd_args.weight_requires_grad + or fwd_args.bias_requires_grad + ): + bwd_args.reduce_and_update_bwd_fp8_tensors = _check_fp8_reduce_and_update() + if fwd_args.backward_override is not None: + bwd_args.reduce_and_update_bwd_fp8_tensors = False + + return out, new_weight_workspace + + @staticmethod + def backward( + ctx, + grad_output: torch.Tensor, + _grad_weight_workspace, + ) -> Tuple[Union[torch.Tensor, None], ...]: + """Backward pass: compute gradients and reduce FP8 scaling factors.""" + bwd_args: LinearBwdArgs = ctx.backward_objects + bwd_args.grad_output = grad_output + bwd_args.setup_saved_tensors(ctx) + ctx.tensor_objects = None + nvtx_label = "transformer_engine._Linear.backward" + if bwd_args.ub_name is not None: + nvtx_label = f"{nvtx_label}.{bwd_args.ub_name}" + result = _linear_backward(bwd_args) + (None,) # fwd_args grad slot + reduce_and_update_bwd_fp8_tensors = bwd_args.reduce_and_update_bwd_fp8_tensors + # Drop all references held by bwd_args (saved tensors, quantizers, weakrefs, + # main_grad closure) so they don't outlive backward via ctx under retain_graph. + ctx.backward_objects = None + del bwd_args + if reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") + return result + + class Linear(TransformerEngineBaseModule): """Applies a linear transformation to the incoming data :math:`y = xA^T + b` diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 93a8843faa..ec60893d06 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -309,6 +309,12 @@ def create_storage_metadata( has_rowwise = bool(self.rowwise_usage) has_columnwise = bool(self.columnwise_usage) tensor_count = int(has_rowwise) * 2 + int(has_columnwise) * 2 + # Storage's :meth:`_torch_compile_flatten` also emits the live + # quantizer's flatten tensors (see + # :meth:`Float8Quantizer.create_storage_metadata` for + # rationale); keep the count + meta in sync. + quantizer_meta, _, quantizer_tensors = self._flatten() + tensor_count += len(quantizer_tensors) from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel meta = OpaqueSimpleMetadata( @@ -325,36 +331,11 @@ def create_storage_metadata( "has_rowwise_scale_inv": has_rowwise, "has_columnwise_data": has_columnwise, "has_columnwise_scale_inv": has_columnwise, - "quantizer_meta": None, + "quantizer_meta": quantizer_meta, } ) return Float8BlockwiseQTensorStorage, meta, None, tensor_count - def create_save_shell( - self, - *, - fake_dtype: torch.dtype, - ) -> Float8BlockwiseQTensorStorage: - """Return a tensor-free :class:`Float8BlockwiseQTensorStorage` - shell suitable for use as a ``tensor_objects`` entry in - :func:`transformer_engine.pytorch.quantized_tensor.restore_from_saved`. - - Built via ``object.__new__`` + direct attribute writes for - Dynamo traceability. Mirrors - :meth:`Float8Quantizer.create_save_shell` -- see its docstring - for rationale. - """ - shell = object.__new__(Float8BlockwiseQTensorStorage) - shell._dtype = fake_dtype - shell._rowwise_data = None - shell._columnwise_data = None - shell._rowwise_scale_inv = None - shell._columnwise_scale_inv = None - shell._fp8_dtype = self.dtype - shell._quantizer = None - shell._is_2D_scaled = self.block_scaling_dim == 2 - return shell - def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 11fde18a2f..ed26d50773 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -149,6 +149,14 @@ def _float8_create_storage_metadata( has_transpose = bool(quantizer.columnwise_usage) has_scale_inv = True tensor_count = int(has_data) + int(has_transpose) + int(has_scale_inv) + # Storage's :meth:`_torch_compile_flatten` also emits the live + # quantizer's flatten tensors when ``self._quantizer is not None`` + # (the impl-produced storage always carries one). Pull + # ``quantizer._flatten()`` to learn the count + meta so the + # metadata we publish here stays in lock-step with the slot count + # produced at flatten time. + quantizer_meta, _, quantizer_tensors = quantizer._flatten() + tensor_count += len(quantizer_tensors) from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel meta = OpaqueSimpleMetadata( @@ -169,57 +177,12 @@ def _float8_create_storage_metadata( "has_data": has_data, "has_transpose": has_transpose, "has_scale_inv": has_scale_inv, - "quantizer_meta": None, + "quantizer_meta": quantizer_meta, } ) return Float8TensorStorage, meta, None, tensor_count -def _float8_create_save_shell( - quantizer: "Quantizer", - *, - fake_dtype: torch.dtype, -) -> "Float8TensorStorage": - """Return a tensor-free :class:`Float8TensorStorage` shell suitable - for use as a ``tensor_objects`` entry in - :func:`transformer_engine.pytorch.quantized_tensor.restore_from_saved`. - - The shell is built via ``object.__new__`` + direct attribute writes - rather than the regular constructor: that avoids tripping Dynamo - on the UDF args (live quantizer instance, ``tex.DType``) that - :meth:`Float8TensorStorage.__new__` would otherwise see when this - function is called from a Dynamo-traced region (e.g. from - ``_linear_forward_output_info``). - - The shell holds no inner tensors -- ``restore_from_saved`` fills - them in from the flat saved-tensor list emitted by the op return, - matching the fixed three-slot layout (``_data``, ``_transpose``, - ``_scale_inv``) of :meth:`Float8TensorStorage.prepare_for_saving`. - The ``_quantizer`` slot is intentionally left ``None``; user code - inside the compiled region must source the live quantizer from - outside. - """ - shell = object.__new__(Float8TensorStorage) - shell._dtype = fake_dtype - shell._data = None - shell._transpose = None - shell._scale_inv = None - shell._fp8_dtype = quantizer.dtype - shell._quantizer = None - # ``_transpose_invalid`` flags a transpose buffer that exists but - # whose contents are stale. Saved-for-backward storages always - # come from the forward after the quantizer has filled in the - # transpose (when it was requested), so the saved transpose -- if - # present at all -- is valid. Initialising to ``False`` keeps - # ``has_data_transpose`` true whenever ``_transpose`` ends up - # non-``None`` after :meth:`restore_from_saved` (which itself only - # writes ``_transpose`` and leaves this flag alone). The - # transpose-None case is unaffected since ``has_data_transpose`` - # ANDs in the ``_transpose is not None`` check. - shell._transpose_invalid = False - return shell - - class Float8Quantizer(Quantizer): """Builder class for FP8 tensors with per-tensor delayed scaling @@ -454,14 +417,6 @@ def create_storage_metadata( as_tensor=as_tensor, ) - def create_save_shell( - self, - *, - fake_dtype: torch.dtype, - ) -> Float8TensorStorage: - # pylint: disable=missing-function-docstring - return _float8_create_save_shell(self, fake_dtype=fake_dtype) - def _flatten(self): from ..dynamo import OpaqueSimpleMetadata @@ -773,14 +728,6 @@ def create_storage_metadata( as_tensor=as_tensor, ) - def create_save_shell( - self, - *, - fake_dtype: torch.dtype, - ) -> Float8TensorStorage: - # pylint: disable=missing-function-docstring - return _float8_create_save_shell(self, fake_dtype=fake_dtype) - def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index b4bd410d58..5767f254c5 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -297,6 +297,12 @@ def create_storage_metadata( int(has_rowwise) * 2 # rowwise_data + rowwise_scale_inv + int(has_columnwise) * 2 # columnwise_data + columnwise_scale_inv ) + # Storage's :meth:`_torch_compile_flatten` also emits the live + # quantizer's flatten tensors (see + # :meth:`Float8Quantizer.create_storage_metadata` for + # rationale); keep the count + meta in sync. + quantizer_meta, _, quantizer_tensors = self._flatten() + tensor_count += len(quantizer_tensors) from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel meta = OpaqueSimpleMetadata( @@ -313,37 +319,11 @@ def create_storage_metadata( "has_rowwise_scale_inv": has_rowwise, "has_columnwise_data": has_columnwise, "has_columnwise_scale_inv": has_columnwise, - "quantizer_meta": None, + "quantizer_meta": quantizer_meta, } ) return MXFP8TensorStorage, meta, None, tensor_count - def create_save_shell( - self, - *, - fake_dtype: torch.dtype, - ) -> MXFP8TensorStorage: - """Return a tensor-free :class:`MXFP8TensorStorage` shell for - use as a ``tensor_objects`` entry in - :func:`transformer_engine.pytorch.quantized_tensor.restore_from_saved`. - - Built via ``object.__new__`` + direct attribute writes for - Dynamo traceability. Mirrors - :meth:`Float8Quantizer.create_save_shell` -- see its docstring - for rationale. Restores from the fixed four-slot layout - emitted by :meth:`MXFP8TensorStorage.prepare_for_saving`. - """ - shell = object.__new__(MXFP8TensorStorage) - shell._dtype = fake_dtype - shell._rowwise_data = None - shell._columnwise_data = None - shell._rowwise_scale_inv = None - shell._columnwise_scale_inv = None - shell._fp8_dtype = self.dtype - shell._quantizer = None - shell._with_gemm_swizzled_scales = self.optimize_for_gemm - return shell - def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index bd1938ee2f..5ba8ed1833 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -447,6 +447,12 @@ def create_storage_metadata( tensor_count = ( int(has_rowwise) * 3 + int(has_columnwise) * 3 ) + # Storage's :meth:`_torch_compile_flatten` also emits the live + # quantizer's flatten tensors (see + # :meth:`Float8Quantizer.create_storage_metadata` for + # rationale); keep the count + meta in sync. + quantizer_meta, _, quantizer_tensors = self._flatten() + tensor_count += len(quantizer_tensors) from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel meta = OpaqueSimpleMetadata( @@ -466,40 +472,11 @@ def create_storage_metadata( "has_columnwise_scale_inv": has_columnwise, "has_amax_rowwise": has_rowwise, "has_amax_columnwise": has_columnwise, - "quantizer_meta": None, + "quantizer_meta": quantizer_meta, } ) return NVFP4TensorStorage, meta, None, tensor_count - def create_save_shell( - self, - *, - fake_dtype: torch.dtype, - ) -> NVFP4TensorStorage: - """Return a tensor-free :class:`NVFP4TensorStorage` shell for - use as a ``tensor_objects`` entry in - :func:`transformer_engine.pytorch.quantized_tensor.restore_from_saved`. - - Built via ``object.__new__`` + direct attribute writes for - Dynamo traceability. Restores from the fixed six-slot layout - emitted by :meth:`NVFP4TensorStorage.prepare_for_saving` - (rowwise_data, columnwise_data, rowwise_scale_inv, - columnwise_scale_inv, amax_rowwise, amax_columnwise). - """ - shell = object.__new__(NVFP4TensorStorage) - shell._dtype = fake_dtype - shell._rowwise_data = None - shell._columnwise_data = None - shell._rowwise_scale_inv = None - shell._columnwise_scale_inv = None - shell._amax_rowwise = None - shell._amax_columnwise = None - shell._fp4_dtype = self.dtype - shell._quantizer = None - shell._with_gemm_swizzled_scales = self.optimize_for_gemm - shell._row_scaled_nvfp4 = self.row_scaled_nvfp4 - return shell - def _flatten(self): from ..dynamo import OpaqueSimpleMetadata From 23399f3431d5b9957a2c9b722c1a98db508792be Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 16:02:59 +0200 Subject: [PATCH 07/16] [PyTorch] Require output-info descriptors for custom ops Remove the legacy fake-impl prototype path so compiled TE ops derive output layouts and fake kernels from explicit TensorSpec descriptors. Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo.py | 539 +++----------------- transformer_engine/pytorch/module/base.py | 16 - transformer_engine/pytorch/module/linear.py | 16 - 3 files changed, 85 insertions(+), 486 deletions(-) diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index 196d4e223a..13c3375e52 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -27,13 +27,7 @@ __all__ = [ "OpaqueSimpleMetadata", "TensorSpec", - "NoneSpec", - "AliasedSpec", - "PlainTensorSpec", - "SubclassTensorSpec", - "StorageSpec", "tensor_spec", - "_DispatchTrigger", "_te_register_custom_op", ] @@ -86,10 +80,9 @@ def _decode_none(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: # op's ``Tensor[]`` return. # # At call-site time (in :func:`forward_fn`), the layout for each user -# output is learned from a fake run of the user fwd impl (driven by -# :func:`_run_fake_for_proto` -- ``@torch._dynamo.disable``'d so the -# fake call doesn't pollute the surrounding FX graph). The layout -# carries the static (class, inner_names, metadata, shape, stride) +# output is described by the user-supplied ``output_info_fn``: a pure +# Python function that returns a list of :class:`TensorSpec`, each +# carrying the static (class, inner_names, metadata, shape, stride) # tuple needed to reassemble the user-facing object from its real # inner tensors emitted by the op. @@ -125,19 +118,9 @@ def _contiguous_stride(shape: Sequence[int]) -> Tuple[int, ...]: # interposes :class:`_ToSubclassFn` for # subclass paths so the construction stays # on the autograd graph; -# * ``alloc()`` -- (optional) build an empty fake version of -# the value for shape inference under +# * ``alloc()`` -- build an empty fake version of the value +# for shape inference under # :class:`torch._subclasses.FakeTensorMode`. -# Required only when the op has no -# hand-written ``fwd_fake_impl`` / -# ``backward_fake_impl`` and relies on -# :func:`_make_fake_impl_from_output_info` / -# :func:`_make_fake_impl_from_bwd_output_info` -# to auto-synthesize one. -# -# Replaces the earlier pair of parallel tuple lists (``user_specs`` for -# reassembly, ``fake_specs["user_outputs"]`` for allocation) that every -# ``output_info_fn`` had to keep in lock-step. class TensorSpec: @@ -172,52 +155,7 @@ def reassemble_with_autograd(self, chunk: List[Any]) -> Any: def alloc(self) -> Any: raise NotImplementedError( - f"{type(self).__name__}.alloc() not implemented; the spec was " - f"built without allocation info (legacy fake-impl path)." - ) - - @staticmethod - def from_proto(proto_value: Any) -> "TensorSpec": - """Build a reassembly-only spec from a fake-impl proto value. - - Used only by the legacy path where the user provides - ``fwd_fake_impl`` instead of ``output_info_fn``: a fake - prototype tensor is constructed by the user fake-impl, and the - layout (kind, cls, inner_names, meta, shape, stride) is - extracted from it. The returned spec is reassembly-capable but - not alloc-capable -- callers on this path don't need alloc. - """ - if proto_value is None: - return NoneSpec() - if isinstance(proto_value, torch.Tensor): - if type(proto_value) is not torch.Tensor and hasattr( - proto_value, "__tensor_flatten__" - ): - inner_names, meta = proto_value.__tensor_flatten__() - return SubclassTensorSpec( - cls=type(proto_value), - inner_names=tuple(inner_names), - meta=meta, - shape=tuple(proto_value.shape), - stride=tuple(proto_value.stride()), - ) - return PlainTensorSpec( - shape=tuple(proto_value.shape), - dtype=proto_value.dtype, - device=proto_value.device, - ) - if hasattr(proto_value, "_torch_compile_flatten"): - meta, pg, tensors = proto_value._torch_compile_flatten() - return StorageSpec( - cls=type(proto_value), - meta=meta, - pg=pg, - tensor_count=len(tensors), - ) - raise TypeError( - f"unsupported output type {type(proto_value).__name__}; expected " - "None / torch.Tensor / tensor subclass with __tensor_flatten__ / " - "class with _torch_compile_flatten." + f"{type(self).__name__}.alloc() not implemented" ) @@ -602,17 +540,15 @@ def _make_fake_impl_from_output_info( """Build a forward fake-impl from an ``output_info_fn``. The synthesized fake-impl returns - ``(*user_outputs, tensors_to_save, tensor_objects, ctx_attrs)`` -- - the same shape :func:`_setup_context` expects from a hand-written - ``fwd_fake_impl``: + ``(*user_outputs, tensors_to_save, tensor_objects, ctx_attrs)``: * ``user_outputs`` comes from ``[s.alloc() for s in user_specs]``. * ``tensors_to_save`` comes from ``tuple(s.alloc() for s in saved_slots)``, or ``None`` if ``saved_slots`` is empty (e.g. ``is_grad_enabled=False``). * ``tensor_objects`` is a vestigial slot kept for tuple-shape - symmetry with hand-written fake impls; the - compile path no longer consumes it. + symmetry; the compile path does not consume + it. * ``ctx_attrs`` is augmented with ``saved_tensor_aliases`` derived from ``saved_slots`` so the user's ``setup_context`` sees the same contract. @@ -694,104 +630,6 @@ def backward(ctx, grad_output): return (None, None, None, None, None) + grads -# --------------------------------------------------------------------------- # -# Dispatch trigger -# --------------------------------------------------------------------------- # -# -# ``register_torch_dispatch(op, subclass, rule)`` only fires when at least -# one argument of the call is an instance of ``subclass``. To get the rule -# to fire *unconditionally* (so the user-facing wrapping logic -- output -# rewrapping into ``Float8Tensor`` etc. -- always runs in the same place -# regardless of whether the caller passed any "real" subclass instances), -# we add an internal ``_DispatchTrigger`` tensor as the last positional -# argument of every subclass-aware custom op. The trigger is a 0-element -# wrapper subclass; the schema slot is plain ``Tensor``, so the call is -# transparent to torch autograd / opcheck and the trigger never appears -# in user code. - -class _DispatchTrigger(torch.Tensor): - """Empty wrapper-subclass tensor used solely to force a - ``register_torch_dispatch`` rule to fire on every call to a - subclass-aware custom op. - - Designed to be installed as an ``nn.Module`` buffer (typically on - :class:`TransformerEngineBaseModule`) and threaded through the - custom op's argument dataclass as a regular ``torch.Tensor`` - field. Dynamo lifts ``nn.Module`` buffers as graph inputs, so the - trigger reaches the FX graph as a regular FakeTensor instead of a - Python-side constant -- this is what made every other "always-on - trigger" approach (module-level globals, fresh-per-call - constructors, ...) trip ``FakeTensorMode`` under - ``torch.compile``. - - ``__torch_dispatch__`` is a transparent passthrough: any op - accidentally invoked on a trigger falls back to the underlying op - with the trigger replaced by a plain empty tensor. The - ``register_torch_dispatch(outer_op, _DispatchTrigger, ...)`` - bindings installed by :func:`_te_register_custom_op` shadow this - for the specific ops we care about. - """ - - @staticmethod - def __new__(cls, _inner: Optional[torch.Tensor] = None) -> "_DispatchTrigger": - instance = torch.Tensor._make_wrapper_subclass( # pylint: disable=no-member - cls, (0,), dtype=_NONE_SENTINEL_DTYPE, device="cpu", - ) - # Attach a regular inner tensor so the subclass has something - # for Dynamo / FakeTensorMode to fake out via the standard - # subclass-flattening protocol. Without an inner tensor, - # Dynamo can't reproduce the subclass instance in the fake - # graph and the call to a ``torch.compile``'d module trips - # ``InternalTorchDynamoError: Wrapped Tensor must be this - # graph's fake``. - instance._inner = ( - _inner if _inner is not None - else torch.empty(0, dtype=_NONE_SENTINEL_DTYPE) - ) - return instance - - def __init__(self, _inner: Optional[torch.Tensor] = None) -> None: - # All work is done in ``__new__``; the optional ``_inner`` - # parameter is consumed there. The signature is mirrored here - # so direct ``__init__`` calls (e.g. via ``__tensor_unflatten__`` - # paths inside Dynamo) don't trip ``TypeError`` on the extra - # positional. - del _inner - - def __tensor_flatten__(self) -> Tuple[List[str], Dict[str, Any]]: - return ["_inner"], {} - - @staticmethod - def __tensor_unflatten__( - inner_tensors: Dict[str, torch.Tensor], - meta: Dict[str, Any], - outer_size, - outer_stride, - ) -> "_DispatchTrigger": - del meta, outer_size, outer_stride - return _DispatchTrigger(_inner=inner_tensors["_inner"]) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or {} - - def _strip(value: Any) -> Any: - if isinstance(value, _DispatchTrigger): - return torch.empty(0, dtype=_NONE_SENTINEL_DTYPE) - return value - - new_args = [_strip(a) for a in args] - new_kwargs = {k: _strip(v) for k, v in kwargs.items()} - return func(*new_args, **new_kwargs) - - def _stable_hash_for_caching(self) -> str: - # Required by AOT autograd's subclass cache. The trigger - # carries no semantically-relevant state, so a constant string - # is sufficient and ensures different trigger instances cache - # to the same compiled artifact. - return "te.dynamo._DispatchTrigger" - - # --------------------------------------------------------------------------- # # OpaqueSimpleMetadata # --------------------------------------------------------------------------- # @@ -1033,36 +871,11 @@ def __repr__(self) -> str: # --------------------------------------------------------------------------- # # Field buckets # --------------------------------------------------------------------------- # - -# Each dataclass field of an argument container is mapped to exactly one -# bucket. A bucket owns the full per-field "vocabulary" -- which schema -# slots it emits, how its packed value(s) are produced from the dataclass -# instance, and how the unpacked value is re-injected into the -# reconstructed instance. The module-level :func:`_get_buckets` / -# :func:`_get_schema` / :func:`_pack` / :func:`_unpack` helpers then -# become trivial loops over a list of buckets, instead of three parallel -# branch ladders. -# -# Five bucket kinds are used: # -# * :class:`_TensorBucket` -- :class:`torch.Tensor` / -# :class:`Optional[torch.Tensor] ` -> one ``Tensor`` / -# ``Tensor?`` slot. -# * :class:`_ProcessGroupBucket` -- :class:`torch.distributed.ProcessGroup` -# (already registered upstream as a value-opaque type) -> one direct -# slot. -# * :class:`_FlattenableBucket` -- a field whose type implements the -# ``_flatten`` / ``_unflatten`` protocol (today: any -# :class:`Quantizer` or :class:`Recipe` subclass) -> three slots -# ``__fmeta`` / ``__fpg`` / ``__ftensors``. Bases -# are discovered via :func:`_flattenable_bases`, lazily imported to -# avoid an import cycle. -# * :class:`_SimpleBundleBucket` -- aggregator over **all** simple-typed -# fields of the dataclass; emits a single ``_simple_meta`` slot -# carrying an :class:`OpaqueSimpleMetadata` bundle. -# * :class:`_UnknownBucket` -- a field whose annotation matches none of -# the above. Emits no schema slot; pack raises if the field holds a -# non-``None`` value, unpack restores it as ``None``. +# Each dataclass field is mapped to exactly one bucket that owns its +# schema slots and the pack/unpack logic between the dataclass attribute +# and the flat ``torch.library`` view. Concrete bucket types are defined +# below; the per-class docstrings describe what each one matches. def _strip_optional(annot: Any) -> Tuple[Any, bool]: @@ -1645,15 +1458,10 @@ def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: # Dataclass <-> torch.library plumbing # --------------------------------------------------------------------------- # # -# The argument containers consumed by :func:`_te_register_custom_op` -# (e.g. ``LinearFwdArgs`` / ``LinearBwdArgs``) are intentionally just -# plain ``@dataclass`` types -- no base class, no decorators, no special -# methods. All translation between the dataclass and the flat -# ``{slot_name: slot_value}`` view that ``torch.library`` works with is -# provided by the module-level helpers below, which dispatch on dataclass -# field annotations: each field is mapped to exactly one :class:`_Bucket` -# and the three operations (schema / pack / unpack) reduce to a loop -# over the bucket list. +# The helpers below translate a plain ``@dataclass`` argument container +# into the flat ``{slot_name: slot_value}`` view ``torch.library`` works +# with. Each dataclass field is dispatched (by annotation) to one +# :class:`_Bucket`; schema / pack / unpack are then loops over that list. def _resolved_field_annotations(cls: type) -> List[Tuple[str, Any]]: @@ -1758,12 +1566,9 @@ def _unpack(cls: type, args: Dict[str, Any], buckets: List[_Bucket]) -> Any: # Op registration helpers # --------------------------------------------------------------------------- # # -# The bottom half of the module turns one or more user-supplied eager -# kernels (forward / backward / their fake counterparts) plus the -# dataclass argument types into a fully registered ``torch.library`` -# custom op. :func:`_te_register_custom_op` is the orchestrator; the -# helpers below are the per-step building blocks (validation, kernel -# wrapping, dispatcher creation). +# Per-step building blocks (schema, kernel wrapping, autograd bridge, +# dispatcher) used by :func:`_te_register_custom_op` to turn user-supplied +# eager kernels + dataclass arg types into a ``torch.library`` custom op. def _prepare_for_saving(tensors: Any) -> Tuple[List[Optional[torch.Tensor]], Any]: @@ -1786,25 +1591,11 @@ def _prepare_for_saving(tensors: Any) -> Tuple[List[Optional[torch.Tensor]], Any # Forward-result packing # --------------------------------------------------------------------------- # # -# The custom-op schema is fixed at ``-> Tensor[]``: a single flat list of -# plain tensors. To return values that are *not* plain tensors (a -# :class:`Float8Tensor` wrapper subclass, a ``QuantizedTensorStorage`` -# workspace, ``None``...), :func:`_format_fwd_result` runs each user -# output through the relevant flatten protocol and concatenates the -# resulting inner tensors -- one variable-length chunk per output -- -# into the op's flat return. Saved-for-backward tensors follow in -# declaration order. -# -# At call-site time (:func:`forward_fn` and :func:`_setup_context`), -# the per-call output structure is described by a list of -# :class:`TensorSpec` (preferred path, via ``output_info_fn``) or -# extracted from a fake run of the user fwd impl driven by -# :func:`_run_fake_for_proto` and :meth:`TensorSpec.from_proto` -# (legacy path). Either way, each spec carries enough info -# (class, inner-names, metadata, shape, stride) to reassemble the -# user-facing object from its real inner tensors emitted by the op; -# subclass reconstruction goes through :class:`_ToSubclassFn` so the -# wrap is recorded on the autograd graph. +# The op schema is fixed at ``-> Tensor[]``. To return non-tensor +# values (subclass wrappers, ``QuantizedTensorStorage``, ``None``...), +# :func:`_format_fwd_result` runs each user output through its +# flatten protocol and concatenates the inner tensors into the flat +# return; saved-for-backward tensors follow in declaration order. def _flatten_value_into(flat: List[torch.Tensor], value: Any) -> None: @@ -1839,11 +1630,9 @@ def _flatten_value_into(flat: List[torch.Tensor], value: Any) -> None: ) -# Number of trailing slots in every ``fwd_impl`` return tuple: -# ``tensors_to_save, tensor_objects, ctx_attrs``. Everything before -# those is a user output, so ``num_outputs = len(result) - -# _FWD_TRAILING_SLOTS`` -- the same convention every fake-impl (hand -# written or auto-synthesized) follows. +# Trailing slots in every ``fwd_impl`` return tuple: +# ``tensors_to_save, tensor_objects, ctx_attrs``. User-output count +# is ``len(result) - _FWD_TRAILING_SLOTS``. _FWD_TRAILING_SLOTS = 3 @@ -1878,39 +1667,6 @@ def _format_fwd_result(result: Any) -> List[torch.Tensor]: return flat -@torch._dynamo.allow_in_graph -def _run_fake_for_proto( - fwd_fake_impl: Callable[[Any], Any], - fwd_obj: Any, -) -> List[Any]: - """Execute ``fwd_fake_impl(fwd_obj)`` in isolation and return its - user-facing outputs to be used as prototypes for output layout - extraction. - - Isolated from any ambient ``FakeTensorMode`` (Dynamo / AOT's own - mode included) by stacking ``_disable_current_modes`` plus a - fresh ``FakeTensorMode``. None of the fake allocations performed - inside ``fwd_fake_impl`` pollute the surrounding FX graph; the - proto outputs leave the function as Python objects whose - metadata (class, ``__tensor_flatten__`` names, shape, ...) is - extracted into static layout tuples on the call site. - - Decorated with :func:`torch._dynamo.allow_in_graph` so that - Dynamo encodes the entire call as a single opaque FX node - instead of trying to trace the fake-allocation body. Unlike - ``@torch._dynamo.disable`` this does not graph-break under - ``fullgraph=True``. - """ - from torch._subclasses.fake_tensor import FakeTensorMode - from torch.utils._python_dispatch import _disable_current_modes - - with _disable_current_modes(): - with FakeTensorMode(allow_non_fake_inputs=True): - result = fwd_fake_impl(fwd_obj) - num_outputs = len(result) - _FWD_TRAILING_SLOTS - return list(result[:num_outputs]) - - def _format_bwd_result( grads: Any, num_grad_inputs: int, op_qualname: str ) -> List[torch.Tensor]: @@ -2076,11 +1832,9 @@ def _register_autograd_for_op( bwd_buckets: List[_Bucket], fwd_slot_defaults: List[Any], grad_targets: List[Tuple[int, bool]], - fwd_fake_impl: Optional[Callable[[Any], Any]], - fwd_impl: Callable[[Any], Any], setup_context_user: Callable[..., None], backward_obj_type: type, - output_info_fn: Optional[Callable[[Any], Tuple[List[Tuple[Any, ...]], List[Tuple[Any, ...]], Any]]] = None, + output_info_fn: Callable[[Any], Tuple[List["TensorSpec"], List["TensorSpec"], Any]], ) -> None: """Wire ``register_autograd`` on a forward op so its backward calls ``bwd_op_name``. @@ -2094,25 +1848,16 @@ def _register_autograd_for_op( The op's ``Tensor[]`` return holds the flat layout produced by :func:`_format_fwd_result` -- one chunk per user output / saved - tensor, sliced via: - - * ``output_info_fn(fwd_obj)`` -- the recommended path: a pure - Python function that returns - ``(user_specs: List[TensorSpec], saved_slots: List[TensorSpec], - ctx_attrs)``. Traceable by Dynamo / AOT, no fake tensor - allocation involved. :class:`AliasedSpec` entries on the saved - side carry the forward-arg name the slot aliases, surfaced to - the user's ``setup_context`` via - ``ctx_attrs["saved_tensor_aliases"]``. - * legacy ``fwd_fake_impl(fwd_obj)`` -- runs the user fake impl - and extracts layouts via :meth:`TensorSpec.from_proto`. Kept for - backwards compatibility with callers that haven't migrated to - ``output_info_fn`` yet. + tensor, sliced according to the user-supplied ``output_info_fn``: + a pure Python function returning + ``(user_specs: List[TensorSpec], saved_slots: List[TensorSpec], + ctx_attrs)``. Traceable by Dynamo / AOT, no fake tensor allocation + involved. :class:`AliasedSpec` entries on the saved side carry the + forward-arg name the slot aliases, surfaced to the user's + ``setup_context`` via ``ctx_attrs["saved_tensor_aliases"]``. """ fwd_qualname = f"{_TE_OP_NAMESPACE}::{fwd_op_name}" - fake_for_setup = fwd_fake_impl if fwd_fake_impl is not None else fwd_impl - def _setup_context(ctx, inputs, output): ctx._te_fwd_tensor_list_lengths = { i: len(value) for i, value in enumerate(inputs) if isinstance(value, list) @@ -2120,25 +1865,8 @@ def _setup_context(ctx, inputs, output): kwargs = dict(zip(fwd_arg_names, inputs)) fwd_obj = _unpack(fwd_arg_type, kwargs, fwd_buckets) - if output_info_fn is not None: - user_specs, saved_slots, ctx_attrs = output_info_fn(fwd_obj) - ctx_attrs = _inject_saved_aliases(ctx_attrs, saved_slots) - else: - # Legacy path: learn output and saved-tensor layouts from a - # fake run of the user fwd impl, then reassemble both via - # the same :class:`TensorSpec` machinery. The fake return - # follows the same ``(*user_outputs, tensors_to_save, - # tensor_objects, ctx_attrs)`` shape as the real impl, so - # the user-output count is just ``len(result) - - # _FWD_TRAILING_SLOTS``. - fake_result = fake_for_setup(fwd_obj) - num_outputs = len(fake_result) - _FWD_TRAILING_SLOTS - user_specs = [ - TensorSpec.from_proto(p) for p in fake_result[:num_outputs] - ] - saved_protos = fake_result[num_outputs] or () - saved_slots = [TensorSpec.from_proto(p) for p in saved_protos] - ctx_attrs = fake_result[num_outputs + 2] + user_specs, saved_slots, ctx_attrs = output_info_fn(fwd_obj) + ctx_attrs = _inject_saved_aliases(ctx_attrs, saved_slots) cursor = 0 user_outputs: List[Any] = [] @@ -2289,19 +2017,15 @@ def _te_register_custom_op( input_tensors_for_grad: List[str], fwd_arg_type: type, fwd_impl: Callable[[Any], Any], - fwd_fake_impl: Optional[Callable[[Any], Any]] = None, setup_context: Callable[..., None], backward_arg_type: type, backward_obj: type, backward_impl: Callable[[Any], Any], - backward_fake_impl: Optional[Callable[[Any], Any]] = None, - output_info_fn: Optional[ - Callable[ - [Any], - Tuple[List["TensorSpec"], List["TensorSpec"], Dict[str, Any]], - ] - ] = None, - bwd_output_info_fn: Optional[Callable[[Any], List["TensorSpec"]]] = None, + output_info_fn: Callable[ + [Any], + Tuple[List["TensorSpec"], List["TensorSpec"], Dict[str, Any]], + ], + bwd_output_info_fn: Callable[[Any], List["TensorSpec"]], ) -> Callable[..., Any]: """Register a TE module's forward + backward as a single torch custom op. @@ -2342,18 +2066,6 @@ def _te_register_custom_op( * ``ctx_attrs`` -- non-tensor state to attach to the autograd context, restricted to values that cannot be derived from the forward args inside ``setup_context``. - fwd_fake_impl - Optional fake (shape inference) counterpart of ``fwd_impl``, - registered via ``torch.library.register_fake``. Returns the same - tuple shape as ``fwd_impl`` -- ``(*output_tensors, - tensors_to_save, tensor_objects, ctx_attrs)`` -- but every - ``torch.Tensor`` is a fake tensor (allocated via - ``quantizer.make_empty`` or ``torch.empty``) carrying only the - correct shape / dtype / device, with no real storage or - computation. ``tensor_objects`` and ``ctx_attrs`` must be - structurally identical to those produced by ``fwd_impl`` so - that ``setup_context`` and ``backward_impl`` see the same - non-tensor state in eager and traced modes. setup_context Eager autograd ``setup_context`` analogue. Receives a freshly constructed ``backward_obj`` instance, the forward args, the @@ -2370,12 +2082,8 @@ def _te_register_custom_op( backward_impl Eager backward implementation. Receives a single argument of type ``backward_arg_type`` and returns the gradient tuple. - backward_fake_impl - Optional fake counterpart of ``backward_impl``. Returns the same - gradient tuple as ``backward_impl``, with fake tensors in place - of the real gradients. output_info_fn - Optional pure-Python layout descriptor for the op's outputs: + Pure-Python layout descriptor for the op's outputs: ``fn(fwd_obj) -> (user_specs, saved_slots, ctx_attrs)``. * ``user_specs`` is a list, one :class:`TensorSpec` per user @@ -2383,11 +2091,11 @@ def _te_register_custom_op( that slot: ``slot_count()`` for flat-``Tensor[]`` slicing, ``reassemble(chunk)`` / ``reassemble_with_autograd(chunk)`` for rebuilding the user-facing object from the op's flat - output, and ``alloc()`` for the auto-synthesized fake-impl - (see below). The four concrete subclasses -- - :class:`NoneSpec`, :class:`PlainTensorSpec`, - :class:`SubclassTensorSpec`, :class:`StorageSpec` -- cover - every output shape TE currently produces. + output, and ``alloc()`` for the auto-synthesized fake-impl. + The four concrete subclasses -- :class:`NoneSpec`, + :class:`PlainTensorSpec`, :class:`SubclassTensorSpec`, + :class:`StorageSpec` -- cover every output shape TE + currently produces. * ``saved_slots`` is a list of :class:`TensorSpec`, one per saved-for-backward slot, mirroring ``user_specs`` but for @@ -2406,31 +2114,25 @@ def _te_register_custom_op( Dynamo augments it with ``"saved_tensor_aliases"`` before the callback runs. - When supplied, :func:`forward_fn` and the autograd - ``setup_context`` use this function instead of running - ``fwd_fake_impl`` to learn output layouts -- which is the only - way to keep the layout-extraction step traceable by Dynamo - under ``fullgraph=True`` (fake-impl execution typically tries - to construct subclasses with UDF arguments such as live - quantizers / pybind enums, graph-breaking the trace). - - Required to support tensor-subclass outputs (e.g. - :class:`Float8Tensor`) under ``torch.compile``. Optional for - plain-tensor ops, where the fake-impl path is still cheap. + :func:`forward_fn` and the autograd ``setup_context`` use + this descriptor to learn output layouts without ever + materialising a fake prototype tensor -- the only way to + keep layout extraction traceable by Dynamo under + ``fullgraph=True``. The forward fake-impl + (:func:`torch.library.register_fake`) is auto-synthesized + from the same specs via :func:`_make_fake_impl_from_output_info`. bwd_output_info_fn - Optional pure-Python alloc descriptor for the backward op: + Pure-Python alloc descriptor for the backward op: ``fn(bwd_obj) -> List[TensorSpec]``, one entry per gradient output in the same order as ``backward_impl``'s return tuple. Typically :class:`NoneSpec` for missing grads, :class:`PlainTensorSpec` for plain tensors, and an alloc-only :class:`SubclassTensorSpec` (built via :meth:`SubclassTensorSpec.from_quantizer` without a - ``wrapper_cls``) for quantized ones. When supplied (and - ``backward_fake_impl`` is not), - :func:`_te_register_custom_op` synthesizes the backward - fake-impl by calling :meth:`TensorSpec.alloc` on each spec -- - the gradient-shape derivation lives entirely in the - descriptor. + ``wrapper_cls``) for quantized ones. The backward fake-impl + is synthesized from these specs via + :func:`_make_fake_impl_from_bwd_output_info`, so the + gradient-shape derivation lives entirely in the descriptor. Returns ------- @@ -2513,20 +2215,12 @@ def _te_register_custom_op( inner_bwd_qualname = f"{_TE_OP_NAMESPACE}::{inner_bwd_name}" # Auto-synthesize the forward / backward fake impls from the - # alloc-spec descriptors when the caller did not hand-write them. - # The synthesized impls share branching with their layout - # counterparts (``output_info_fn`` / ``bwd_output_info_fn``) so - # there's exactly one place where every per-precision / per-mode - # condition lives. Hand-written fake impls still take precedence - # when supplied, so callers can stage the migration op-by-op. - effective_fwd_fake_impl = fwd_fake_impl - if effective_fwd_fake_impl is None and output_info_fn is not None: - effective_fwd_fake_impl = _make_fake_impl_from_output_info(output_info_fn) - effective_bwd_fake_impl = backward_fake_impl - if effective_bwd_fake_impl is None and bwd_output_info_fn is not None: - effective_bwd_fake_impl = _make_fake_impl_from_bwd_output_info( - bwd_output_info_fn - ) + # alloc-spec descriptors. The synthesized impls share branching + # with their layout counterparts (``output_info_fn`` / + # ``bwd_output_info_fn``) so there's exactly one place where every + # per-precision / per-mode condition lives. + fwd_fake_impl = _make_fake_impl_from_output_info(output_info_fn) + bwd_fake_impl = _make_fake_impl_from_bwd_output_info(bwd_output_info_fn) _register_kernel( op_name=inner_fwd_name, @@ -2535,7 +2229,7 @@ def _te_register_custom_op( arg_names=fwd_arg_names, buckets=fwd_buckets, impl=fwd_impl, - fake_impl=effective_fwd_fake_impl, + fake_impl=fwd_fake_impl, format_result=_format_fwd_result, ) _register_kernel( @@ -2545,7 +2239,7 @@ def _te_register_custom_op( arg_names=bwd_arg_names, buckets=bwd_buckets, impl=backward_impl, - fake_impl=effective_bwd_fake_impl, + fake_impl=bwd_fake_impl, format_result=lambda g: _format_bwd_result(g, num_grad_inputs, inner_bwd_qualname), ) _register_autograd_for_op( @@ -2558,31 +2252,18 @@ def _te_register_custom_op( bwd_buckets=bwd_buckets, fwd_slot_defaults=fwd_slot_defaults, grad_targets=grad_targets, - fwd_fake_impl=effective_fwd_fake_impl, - fwd_impl=fwd_impl, setup_context_user=setup_context, backward_obj_type=backward_obj, output_info_fn=output_info_fn, ) if subclass_list: - # Two-tier setup, mirroring the ex.py pattern: - # - # * Inner pair (already registered above) carries the real - # kernels + fakes and a full ``register_autograd`` bridge. - # It only ever sees plain tensors / plain - # ``QuantizedTensorStorage`` flat slots; the subclass - # wrapper never reaches it. - # * Outer pair is a thin opaque shell. Its kernels forward - # to the inner op and its ``register_torch_dispatch`` rules - # flatten registered subclasses inline before forwarding. - # It carries its own autograd bridge so that the user-facing - # tensor (e.g. a ``Float8Tensor`` weight parameter) ends - # up on the autograd graph and receives a ``.grad``. With - # ``__tensor_unflatten__`` rebuilding a real quantizer from - # the subclass meta snapshot, outer's setup_context can run - # the user fake impl on the raw forward inputs even when - # they include reconstructed subclass instances. + # Outer tier (thin shell): default kernels forward to inner + # plus a ``register_torch_dispatch`` rule per subclass that + # flattens the wrapper in place before forwarding. Carries + # its own autograd bridge so the user-facing subclass tensor + # (e.g. a ``Float8Tensor`` parameter) stays on the autograd + # graph and receives a ``.grad``. _register_outer_forwarder( outer_op_name=outer_fwd_name, inner_op_name=inner_fwd_name, @@ -2603,21 +2284,11 @@ def _te_register_custom_op( bwd_buckets=bwd_buckets, fwd_slot_defaults=fwd_slot_defaults, grad_targets=grad_targets, - fwd_fake_impl=effective_fwd_fake_impl, - fwd_impl=fwd_impl, setup_context_user=setup_context, backward_obj_type=backward_obj, output_info_fn=output_info_fn, ) - # Register per-subclass ``torch_dispatch`` rules. Each rule - # flattens every registered subclass arg into the - # ``_UniversalTensorBucket`` storage layout (so the inner op - # only ever sees plain tensors + opaque metadata) and forwards - # to the inner op. The flat ``Tensor[]`` output travels back - # untouched -- user-facing wrapping into subclasses / storage - # happens in :func:`forward_fn` via :class:`_ToSubclassFn`, - # outside the dispatcher. fwd_slot_offsets = _collect_universal_slot_offsets(fwd_buckets) bwd_slot_offsets = _collect_universal_slot_offsets(bwd_buckets) inner_fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), inner_fwd_name) @@ -2643,25 +2314,11 @@ def _bwd_rule(mode, func, types, args, kwargs): _flatten_all_subclasses(new_args, bwd_slot_offsets) return inner_bwd_op(*new_args) - # EXPERIMENT: temporarily disable trigger-based dispatch rules. - # torch.library.register_torch_dispatch( - # outer_fwd_qualname, _DispatchTrigger, _fwd_rule, lib=_TE_LIB - # ) - # torch.library.register_torch_dispatch( - # outer_bwd_qualname, _DispatchTrigger, _bwd_rule, lib=_TE_LIB - # ) - - # Also register per-subclass dispatch rules. The trigger - # rule above only fires when the dispatcher actually - # consults ``register_torch_dispatch`` (e.g. eager-mode calls - # where the trigger is the only subclass), which doesn't - # cover the case where Dynamo lifts a real wrapper-subclass - # parameter (such as a ``Float8Tensor`` weight) into the FX - # graph: in that case Dynamo invokes the registered fake - # impl instead, so we additionally bind the same rule body - # for every concrete subclass class so the eager dispatcher - # still picks it up alongside the fake impl handling the - # tracing path. + # Per-subclass dispatch rule: any registered subclass arg + # passed to the outer op (e.g. Dynamo lifting a + # ``Float8Tensor`` weight into the FX graph) is flattened + # into its storage layout before forwarding to the inner op, + # which only ever sees plain tensors. for sub in subclass_list: torch.library.register_torch_dispatch( outer_fwd_qualname, sub, _fwd_rule, lib=_TE_LIB @@ -2687,48 +2344,22 @@ def _bwd_rule(mode, func, types, args, kwargs): _quantized_tensor_passthrough_ops.add(inner_bwd_op.default) fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_fwd_name) - # Use the auto-synthesized fake-impl when available so the proto - # path stays in sync with the kernel registration above. Falls back - # to ``fwd_impl`` when there is no fake-impl at all (legacy - # plain-tensor ops). - proto_fn = ( - effective_fwd_fake_impl if effective_fwd_fake_impl is not None else fwd_impl - ) def forward_fn(fwd_args): - # 1) Learn user-output layouts. - # ``output_info_fn`` is the recommended path: a pure Python - # function that returns the static spec tuples without ever - # materialising a fake prototype tensor. Traceable by Dynamo - # under ``fullgraph=True``. Fallback: legacy fake-impl run - # via ``_run_fake_for_proto`` (``@torch._dynamo.allow_in_graph`` - # so it stays opaque to Dynamo). - if output_info_fn is not None: - user_specs, _saved_slots, _ctx_attrs = output_info_fn(fwd_args) - else: - proto_outputs = _run_fake_for_proto(proto_fn, fwd_args) - user_specs = [TensorSpec.from_proto(p) for p in proto_outputs] - - # 2) Invoke the op (graph node). Result is the flat ``Tensor[]`` - # payload produced by :func:`_format_fwd_result`. + user_specs, _saved_slots, _ctx_attrs = output_info_fn(fwd_args) kwargs = _pack(fwd_args, fwd_buckets) flat_in = [kwargs[name] for name in fwd_arg_names] result = fwd_op(*flat_in) - # 3) Slice the flat result by spec and reassemble each user - # output. :meth:`TensorSpec.reassemble_with_autograd` routes - # subclass paths through :class:`_ToSubclassFn` so the - # construction is recorded on the autograd graph and Dynamo - # lifts it as an ``autograd.Function`` call; plain tensors and - # storage classes (which have no autograd identity of their - # own) are reconstructed directly. + # Slice the flat result by spec. Subclass specs route through + # :class:`_ToSubclassFn` to keep the wrap on the autograd graph; + # plain tensors / storage classes are reconstructed directly. cursor = 0 outputs: List[Any] = [] for spec in user_specs: n = spec.slot_count() - chunk_raw = result[cursor:cursor + n] + chunk = [_decode_none(t) for t in result[cursor:cursor + n]] cursor += n - chunk = [_decode_none(t) for t in chunk_raw] outputs.append(spec.reassemble_with_autograd(chunk)) if len(outputs) == 1: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index fb70b5f4e7..58f42781e0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -854,22 +854,6 @@ def __init__(self, name: Optional[str] = None) -> None: self._output_quantizer_role: Optional[QuantizerRole] = None self._grad_input_quantizer_role: Optional[QuantizerRole] = None - # Empty wrapper-subclass tensor threaded through every TE - # custom op as a regular ``Tensor`` argument. Its sole purpose - # is to make ``register_torch_dispatch`` rules - # (registered in :func:`transformer_engine.pytorch.dynamo._te_register_custom_op` - # against ``_DispatchTrigger``) fire on every call to a - # subclass-aware op, even when no other argument is a - # registered subclass. Routed via ``register_buffer`` so that - # Dynamo lifts it as a regular graph input under - # ``torch.compile`` instead of internalising it as a - # Python-side constant (which would then trip - # ``FakeTensorMode``). - from transformer_engine.pytorch.dynamo import _DispatchTrigger - self.register_buffer( - "_te_dispatch_trigger", _DispatchTrigger(), persistent=False - ) - if not TEDebugState.debug_enabled: TEDebugState.initialize() self._validate_name() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 9ed57e51b8..b00bc4bb03 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -161,13 +161,6 @@ class LinearFwdArgs: cpu_offloading: bool is_grad_enabled: bool - # Always set to ``self._te_dispatch_trigger`` of the calling - # module: a tiny ``_DispatchTrigger`` wrapper-subclass tensor that - # exists only to make ``register_torch_dispatch`` rules fire on - # every call to the outer custom op, so output rewrapping can run - # in a single place. See :class:`transformer_engine.pytorch.dynamo._DispatchTrigger`. - _te_dispatch_trigger: Optional[torch.Tensor] = None - @dataclass(slots=True) class LinearBwdArgs: @@ -237,12 +230,6 @@ class LinearBwdArgs: cpu_offloading: bool = False owns_input: bool = False - # See :class:`LinearFwdArgs._te_dispatch_trigger`. Set in - # ``_linear_setup_ctx`` from the corresponding forward-args field - # so the backward op carries the same trigger and its always-on - # ``register_torch_dispatch`` rule fires too. - _te_dispatch_trigger: Optional[torch.Tensor] = None - # --- Per-backward scratch state (populated inside _linear_backward) --- ub_obj_gradout: Optional[Any] = None @@ -725,7 +712,6 @@ def _linear_setup_ctx( # Misc bwd_args.cpu_offloading = fwd_args.cpu_offloading - bwd_args._te_dispatch_trigger = fwd_args._te_dispatch_trigger if backward_override is not None: bwd_args.fp8 = False @@ -2280,8 +2266,6 @@ def forward( # misc cpu_offloading=is_cpu_offload_enabled(), is_grad_enabled=is_grad_enabled, - # always-on torch_dispatch trigger - _te_dispatch_trigger=self._te_dispatch_trigger, ) if use_compiled_op: out, new_weight_workspace = _linear_compiled_op(fwd_args) From 89a80aab53c25419ad87ca986e0eb1e6f104f201 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 17:17:11 +0200 Subject: [PATCH 08/16] [PyTorch] Drop dead parameters from torch.compile custom op registration Remove ``requires_grad``/``as_tensor`` from ``SubclassTensorSpec.from_quantizer`` and ``StorageSpec.from_quantizer`` (always defaulted, never threaded through), drop the vestigial ``ctx_attrs`` slot from ``_make_fake_impl_from_output_info`` (``_format_fwd_result`` never reads it), and make ``fake_impl`` required in ``_register_kernel`` (always supplied since output_info_fn became mandatory). Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo.py | 65 ++++++++++------------------ 1 file changed, 24 insertions(+), 41 deletions(-) diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index 13c3375e52..2b2d440a40 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -324,7 +324,6 @@ def from_quantizer( dtype: "torch.dtype", device: "torch.device", wrapper_cls: Optional[type] = None, - requires_grad: bool = False, ) -> "SubclassTensorSpec": """Build a :class:`SubclassTensorSpec` from a live quantizer. @@ -349,10 +348,7 @@ def from_quantizer( alloc_dtype=dtype, alloc_device=device, ) - inner_names, meta = quantizer.create_metadata( - fake_dtype=dtype, - requires_grad=requires_grad, - ) + inner_names, meta = quantizer.create_metadata(fake_dtype=dtype) return cls( cls=wrapper_cls, inner_names=inner_names, @@ -419,8 +415,6 @@ def from_quantizer( shape: Sequence[int], dtype: "torch.dtype", device: "torch.device", - requires_grad: bool = False, - as_tensor: bool = False, ) -> "StorageSpec": """Build a :class:`StorageSpec` from a live quantizer. @@ -434,8 +428,6 @@ def from_quantizer( shape=shape, fake_dtype=dtype, device=device, - requires_grad=requires_grad, - as_tensor=as_tensor, ) return cls( cls=storage_cls, @@ -540,18 +532,17 @@ def _make_fake_impl_from_output_info( """Build a forward fake-impl from an ``output_info_fn``. The synthesized fake-impl returns - ``(*user_outputs, tensors_to_save, tensor_objects, ctx_attrs)``: + ``(*user_outputs, tensors_to_save, None, None)``: * ``user_outputs`` comes from ``[s.alloc() for s in user_specs]``. * ``tensors_to_save`` comes from ``tuple(s.alloc() for s in saved_slots)``, or ``None`` if ``saved_slots`` is empty (e.g. ``is_grad_enabled=False``). - * ``tensor_objects`` is a vestigial slot kept for tuple-shape - symmetry; the compile path does not consume - it. - * ``ctx_attrs`` is augmented with ``saved_tensor_aliases`` - derived from ``saved_slots`` so the user's - ``setup_context`` sees the same contract. + * The trailing ``tensor_objects`` / ``ctx_attrs`` slots are + ``None`` placeholders -- the eager fwd_impl contract requires + them in the tuple (via ``_FWD_TRAILING_SLOTS``) but + :func:`_format_fwd_result` only reads user outputs + saved + tensors off a fake-impl return. ``output_info_fn`` must return a 3-tuple ``(user_specs: List[TensorSpec], saved_slots: List[TensorSpec], @@ -559,14 +550,16 @@ def _make_fake_impl_from_output_info( """ def _fake(args: Any) -> Tuple[Any, ...]: - user_specs, saved_slots, ctx_attrs = output_info_fn(args) + user_specs, saved_slots, _ = output_info_fn(args) user_outputs = [s.alloc() for s in user_specs] - if not saved_slots: - tensors_to_save: Any = None - else: - tensors_to_save = tuple(s.alloc() for s in saved_slots) - ctx_attrs = _inject_saved_aliases(ctx_attrs, saved_slots) - return (*user_outputs, tensors_to_save, None, ctx_attrs) + tensors_to_save = ( + None if not saved_slots else tuple(s.alloc() for s in saved_slots) + ) + # Trailing ``tensor_objects`` / ``ctx_attrs`` slots are required + # by the eager fwd_impl contract (``_FWD_TRAILING_SLOTS``) but + # are never read off a fake-impl return -- ``_format_fwd_result`` + # only slices user outputs + tensors_to_save out of the tuple. + return (*user_outputs, tensors_to_save, None, None) return _fake @@ -1747,11 +1740,10 @@ def _register_kernel( arg_names: List[str], buckets: List[_Bucket], impl: Callable[[Any], Any], - fake_impl: Optional[Callable[[Any], Any]], + fake_impl: Callable[[Any], Any], format_result: Callable[[Any], List[torch.Tensor]], ) -> None: - """Wire ``impl`` (and optionally ``fake_impl``) into :data:`_TE_LIB` - under ``op_name``. + """Wire ``impl`` + ``fake_impl`` into :data:`_TE_LIB` under ``op_name``. The wrapper unpacks the flat positional args using ``arg_names`` / ``buckets``, calls the user kernel with the rebuilt @@ -1764,16 +1756,13 @@ def _eager(*flat: Any) -> List[torch.Tensor]: obj = _unpack(arg_type, kwargs, buckets) return format_result(impl(obj)) - _TE_LIB.impl(op_name, _eager, "CompositeExplicitAutograd") - - if fake_impl is not None: - - def _fake(*flat: Any) -> List[torch.Tensor]: - kwargs = dict(zip(arg_names, flat)) - obj = _unpack(arg_type, kwargs, buckets) - return format_result(fake_impl(obj)) + def _fake(*flat: Any) -> List[torch.Tensor]: + kwargs = dict(zip(arg_names, flat)) + obj = _unpack(arg_type, kwargs, buckets) + return format_result(fake_impl(obj)) - torch.library.register_fake(op_qualname, _fake, lib=_TE_LIB) + _TE_LIB.impl(op_name, _eager, "CompositeExplicitAutograd") + torch.library.register_fake(op_qualname, _fake, lib=_TE_LIB) def _collect_universal_slot_offsets(buckets: List[_Bucket]) -> List[int]: @@ -2029,12 +2018,6 @@ def _te_register_custom_op( ) -> Callable[..., Any]: """Register a TE module's forward + backward as a single torch custom op. - The user-output count is derived dynamically at call time from - the impl return shape: ``num_outputs = len(result) - - _FWD_TRAILING_SLOTS`` (the impl tail is always - ``tensors_to_save, tensor_objects, ctx_attrs``). No explicit - ``num_outputs`` argument is required. - Parameters ---------- op_name From a3f63539dbad6381f5ad4c79d929545c17e55b2c Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 17:17:24 +0200 Subject: [PATCH 09/16] [PyTorch] Unify quantized tensor flatten via declarative schema Move ``__tensor_flatten__`` / ``__tensor_unflatten__`` and ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` onto the ``QuantizedTensor`` / ``QuantizedTensorStorage`` base classes, driven by three per-subclass class attributes: ``_FLATTEN_TENSOR_ATTRS``, ``_FLATTEN_META_ATTRS``, and ``_FLATTEN_CTOR_KWARG``, plus a ``_flatten_meta_overrides`` hook (used by ``Float8Tensor`` to bridge ``FP8DType`` <-> ``tex.DType``). Both the PyTorch wrapper-subclass protocol and the storage triplet protocol now share a single source of truth for the field schema, and the four storage shells + ``Float8Tensor`` drop their per-class implementations in favor of ~10 lines of declarations. As a side effect, ``MXFP8Tensor`` / ``NVFP4Tensor`` / ``Float8BlockwiseQTensor`` gain ``__tensor_flatten__`` / ``__tensor_unflatten__`` for free, enabling future use as ``torch.compile`` input subclasses. Signed-off-by: Pawel Gadzinski --- .../pytorch/quantized_tensor.py | 186 ++++++++++++++++-- .../pytorch/tensor/float8_tensor.py | 107 +++------- .../float8_blockwise_tensor_storage.py | 101 +++------- .../tensor/storage/float8_tensor_storage.py | 113 +++-------- .../tensor/storage/mxfp8_tensor_storage.py | 101 +++------- .../tensor/storage/nvfp4_tensor_storage.py | 119 ++++------- 6 files changed, 301 insertions(+), 426 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 21e5aca58e..03a7dbe4d1 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -120,6 +120,38 @@ class QuantizedTensorStorage: _dtype: torch.dtype _quantizer: Optional[Quantizer] + # ------------------------------------------------------------------ # + # Declarative schema for the unified flatten / unflatten machinery # + # (consumed by both the storage ``_torch_compile_flatten`` protocol # + # and ``QuantizedTensor``'s PyTorch ``__tensor_flatten__`` helper). # + # ------------------------------------------------------------------ # + + # Names of optional tensor attributes on the instance, in canonical + # order. Each name must be an attribute on ``self`` and must be + # accepted as a kwarg by ``cls(**kwargs)`` (potentially after + # remapping through :attr:`_FLATTEN_CTOR_KWARG`). + _FLATTEN_TENSOR_ATTRS: Tuple[str, ...] = () + + # Names of value-stable scalar / enum attributes needed to round-trip + # the instance. Same naming / kwarg conventions as + # :attr:`_FLATTEN_TENSOR_ATTRS`. + _FLATTEN_META_ATTRS: Tuple[str, ...] = () + + # Map from attribute name to constructor kwarg name, used when they + # differ (e.g. ``_data`` -> ``data``). Identity by default. + _FLATTEN_CTOR_KWARG: Dict[str, str] = {} + + @classmethod + def _flatten_meta_overrides(cls, meta: Dict[str, Any]) -> Dict[str, Any]: + """Hook for last-mile meta value massaging before unflatten dispatches + to ``cls(**kwargs)``. Default: no-op. + + Used today by :class:`Float8Tensor` to bridge :class:`FP8DType` + (carried by the subclass output spec) back to the native + ``tex.DType`` accepted by pybind-bound kernels. + """ + return meta + def update_usage( self, rowwise_usage: Optional[bool] = None, @@ -218,16 +250,66 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: return id(self) + @classmethod + def _flatten_ctor_kw(cls, attr_name: str) -> str: + """Return the constructor kwarg name corresponding to ``attr_name``. + + Identity unless overridden via :attr:`_FLATTEN_CTOR_KWARG`. + """ + return cls._FLATTEN_CTOR_KWARG.get(attr_name, attr_name) + + @staticmethod + def _flatten_presence_key(attr_name: str) -> str: + """Return the ``has_*`` meta key indicating whether ``attr_name`` is + present in the flattened payload. Derived from the attribute name + (with the leading underscore stripped) so the static metadata + constructors in ``float8_tensor.py`` etc. don't need to know about + :attr:`_FLATTEN_CTOR_KWARG` remapping. + """ + return f"has_{attr_name.lstrip('_')}" + def _torch_compile_flatten( self, ) -> Tuple[Any, Optional["torch.distributed.ProcessGroup"], List[torch.Tensor]]: - """Pack this storage's metadata and live tensor state for torch.compile.""" - raise NotImplementedError( - f"{type(self).__name__} class does not implement " - "_torch_compile_flatten; required for torch.compile support " - "of QuantizedTensorStorage objects." + """Pack storage state into the ``(meta, pg, tensors)`` triplet + consumed by :mod:`transformer_engine.pytorch.dynamo`. + + Generic implementation driven by :attr:`_FLATTEN_TENSOR_ATTRS`, + :attr:`_FLATTEN_META_ATTRS`, and :attr:`_FLATTEN_CTOR_KWARG`. + Quantizer-with-tensors (e.g. :class:`Float8Quantizer`'s + ``scale`` / ``amax``) is round-tripped via + :meth:`Quantizer._flatten`; quantizer tensors trail the + storage's own tensors in the flat list. + """ + from transformer_engine.pytorch.dynamo import ( # pylint: disable=import-outside-toplevel + OpaqueSimpleMetadata, ) + tensors: List[torch.Tensor] = [] + is_tensor = isinstance(self, torch.Tensor) + meta_dict: Dict[str, Any] = { + "_qstorage_cls": type(self).__qualname__, + "is_tensor": is_tensor, + "shape": torch.Size(self.shape) if is_tensor else None, + "requires_grad": self.requires_grad if is_tensor else False, + "device": self.device if is_tensor else None, + } + for attr in self._FLATTEN_META_ATTRS: + meta_dict[self._flatten_ctor_kw(attr)] = getattr(self, attr) + for attr in self._FLATTEN_TENSOR_ATTRS: + tensor = getattr(self, attr) + present = tensor is not None + meta_dict[self._flatten_presence_key(attr)] = present + if present: + tensors.append(tensor) + quantizer_meta = None + process_group = None + if self._quantizer is not None: + quantizer_meta, process_group, q_tensors = self._quantizer._flatten() + tensors.extend(q_tensors) + meta_dict["quantizer_meta"] = quantizer_meta + return OpaqueSimpleMetadata(meta_dict), process_group, tensors + @classmethod def _torch_compile_do_unflatten( cls, @@ -235,12 +317,35 @@ def _torch_compile_do_unflatten( process_group: Optional["torch.distributed.ProcessGroup"], tensors: List[torch.Tensor], ) -> "QuantizedTensorStorage": - """Reconstruct an instance of ``cls`` from storage flatten data.""" - raise NotImplementedError( - f"{cls.__name__} class does not implement " - "_torch_compile_do_unflatten; required for torch.compile " - "support of QuantizedTensorStorage objects." - ) + """Reconstruct ``cls`` from a triplet produced by + :meth:`_torch_compile_flatten`. Generic; driven by the same + ``_FLATTEN_*`` declarations. + """ + meta = cls._flatten_meta_overrides(meta) + tensor_iter = iter(tensors) + kwargs: Dict[str, Any] = {} + for attr in cls._FLATTEN_TENSOR_ATTRS: + kw = cls._flatten_ctor_kw(attr) + kwargs[kw] = next(tensor_iter) if meta[cls._flatten_presence_key(attr)] else None + quantizer = None + if meta["quantizer_meta"] is not None: + quantizer = Quantizer._unflatten( + meta["quantizer_meta"], process_group, list(tensor_iter) + ) + for attr in cls._FLATTEN_META_ATTRS: + kw = cls._flatten_ctor_kw(attr) + kwargs[kw] = meta[kw] + kwargs["quantizer"] = quantizer + if meta["is_tensor"]: + kwargs.update( + { + "shape": meta["shape"], + "dtype": kwargs["fake_dtype"], + "requires_grad": meta["requires_grad"], + "device": meta["device"], + } + ) + return cls(**kwargs) @classmethod def _torch_compile_unflatten( @@ -915,6 +1020,65 @@ def get_metadata(self) -> Dict[str, Any]: f"{self.__class__.__name__} class does not implement get_metadata function" ) + # ------------------------------------------------------------------ # + # PyTorch wrapper-subclass flatten / unflatten # + # ------------------------------------------------------------------ # + # + # Driven by the same ``_FLATTEN_*_ATTRS`` / ``_FLATTEN_CTOR_KWARG`` + # declarations as :meth:`QuantizedTensorStorage._torch_compile_flatten`, + # plus the :meth:`_flatten_meta_overrides` hook (Float8Tensor uses it + # to bridge :class:`FP8DType` <-> ``tex.DType``). + # + # Per-subclass differences vs the storage path: PyTorch's protocol + # carries only attributes living on ``self`` (no quantizer tensors, + # no process group). Quantizers whose state contains tensors (e.g. + # :class:`Float8Quantizer`'s ``scale`` / ``amax``, + # :class:`NVFP4Quantizer`'s ``rht_matrix``) therefore round-trip via + # :func:`_quantizer_subclass_snapshot`, which bails to ``None``; the + # reconstructed tensor's ``_quantizer`` is ``None`` and downstream + # code that needs the live quantizer sources it from the bucket-level + # opaque metadata flowing alongside the inner op. + + def __tensor_flatten__(self) -> Tuple[list, dict]: + if not type(self)._FLATTEN_TENSOR_ATTRS: + raise NotImplementedError( + f"{type(self).__name__} did not declare _FLATTEN_TENSOR_ATTRS" + ) + inner: list = [ + attr for attr in self._FLATTEN_TENSOR_ATTRS if getattr(self, attr) is not None + ] + meta: Dict[str, Any] = { + "quantizer_snapshot": _quantizer_subclass_snapshot(self._quantizer), + "requires_grad": self.requires_grad, + } + for attr in self._FLATTEN_META_ATTRS: + meta[self._flatten_ctor_kw(attr)] = getattr(self, attr) + return inner, meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors: dict, + meta: dict, + outer_size, + outer_stride, + ) -> "QuantizedTensor": + meta = cls._flatten_meta_overrides(meta) + quantizer = _quantizer_from_subclass_snapshot(meta.get("quantizer_snapshot")) + kwargs: Dict[str, Any] = { + "shape": outer_size, + "dtype": meta["fake_dtype"], + "requires_grad": meta.get("requires_grad", False), + "quantizer": quantizer, + } + for attr in cls._FLATTEN_TENSOR_ATTRS: + kw = cls._flatten_ctor_kw(attr) + kwargs[kw] = inner_tensors.get(attr) + for attr in cls._FLATTEN_META_ATTRS: + kw = cls._flatten_ctor_kw(attr) + kwargs[kw] = meta[kw] + return cls(**kwargs) + @classmethod def make_like( cls, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index ed26d50773..6f7b6790b9 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -21,8 +21,6 @@ from ..quantized_tensor import ( QuantizedTensor, Quantizer, - _quantizer_from_subclass_snapshot, - _quantizer_subclass_snapshot, ) from ._quantization_helpers import _IdentityFunc from ..constants import canonicalize_te_dtype, dist_group_type @@ -72,13 +70,15 @@ def _float8_create_subclass_metadata( ``inner_names`` reflects the rowwise / columnwise usage flags of the quantizer (``_data`` and/or ``_transpose``, plus always ``_scale_inv``). ``meta`` carries the static, Dynamo-friendly attributes - :class:`Float8Tensor`'s constructor needs: + :class:`Float8Tensor`'s constructor needs (matching the schema produced + by :meth:`Float8Tensor._generic_tensor_flatten`): - * ``_fp8_dtype`` -- :class:`FP8DType` (an :class:`IntEnum`, + * ``fp8_dtype`` -- :class:`FP8DType` (an :class:`IntEnum`, proxies as a constant for Dynamo; bridges back to ``tex.DType`` - via :func:`to_tex` on the kernel side). - * ``_fake_dtype`` -- caller-supplied torch dtype. - * ``_quantizer_snapshot`` -- always ``None`` on this path. Re-using + via :meth:`Float8Tensor._flatten_meta_overrides` inside + ``__tensor_unflatten__``). + * ``fake_dtype`` -- caller-supplied torch dtype. + * ``quantizer_snapshot`` -- always ``None`` on this path. Re-using the snapshot reconstruction (which builds a fresh quantizer inside :meth:`Float8Tensor.__tensor_unflatten__`) would force Dynamo to trace a quantizer constructor call, which routinely @@ -86,7 +86,7 @@ def _float8_create_subclass_metadata( ``quantizer=None`` keeps the wrapper construction within Dynamo's proxyable surface; user code that needs the live quantizer sources it from outside the compiled region. - * ``_requires_grad`` -- caller-supplied flag. + * ``requires_grad`` -- caller-supplied flag. """ inner_names: List[str] = [] if quantizer.rowwise_usage: @@ -95,10 +95,10 @@ def _float8_create_subclass_metadata( if quantizer.columnwise_usage: inner_names.append("_transpose") meta = { - "_fp8_dtype": from_tex(quantizer.dtype), - "_fake_dtype": fake_dtype, - "_quantizer_snapshot": None, - "_requires_grad": requires_grad, + "fp8_dtype": from_tex(quantizer.dtype), + "fake_dtype": fake_dtype, + "quantizer_snapshot": None, + "requires_grad": requires_grad, } return tuple(inner_names), meta @@ -795,13 +795,6 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): """ - # Upper bound on the number of inner tensors produced by - # :meth:`__tensor_flatten__`. Used by the wide-output layout in - # :mod:`transformer_engine.pytorch.dynamo` to reserve enough slots in - # the custom-op ``Tensor[]`` return for any subclass-shaped output: - # data, scale_inv, transpose. - _TORCH_COMPILE_MAX_INNER_TENSORS = 3 - def __repr__(self, *, tensor_contents=None): # ``__repr__`` is on hot diagnostic paths (Dynamo's # ``Dynamo failed to run FX node`` formatter, autograd @@ -820,74 +813,18 @@ def __repr__(self, *, tensor_contents=None): ")" ) - def __tensor_flatten__(self) -> Tuple[list, dict]: - """torch.compile / tensor-subclass flatten protocol. - - Returns ``(inner_tensor_names, meta)`` so that PyTorch's - wrapper-subclass machinery and :func:`register_torch_dispatch` - rules on custom ops can decompose a ``Float8Tensor`` into - plain tensors plus a static metadata dict at trace time. - - The metadata dict must contain only values supporting stable - ``==`` comparison (Dynamo's tensor-subclass metadata guard - re-evaluates it via dict equality on every entry into the - compiled region). Mutable / runtime-only state such as the - ``_transpose_invalid`` flag deliberately does *not* end up - here; it would flip between calls and trip the "Guard failed - on the same frame" assertion. - - ``_quantizer_snapshot`` carries a tensor-free snapshot of - the live ``Quantizer`` so :meth:`__tensor_unflatten__` can - rebuild a structurally-equivalent quantizer on the unflatten - side. Quantizers that carry tensors in their state (e.g. - :class:`Float8Quantizer` keeps ``scale`` / ``amax``) cannot - be snapshotted into a guard-stable dict and produce a - ``None`` snapshot; in that case the reconstructed - ``Float8Tensor`` will have ``_quantizer = None`` and any - downstream code that needs the quantizer must source it from - elsewhere (typically the bucket-level opaque metadata on the - inner op call). + @classmethod + def _flatten_meta_overrides(cls, meta: dict) -> dict: + """Bridge :class:`FP8DType` (carried by the subclass output spec + via :func:`_float8_create_subclass_metadata`) back to the native + ``tex.DType`` accepted by pybind-bound TE kernels. The eager + :meth:`__tensor_flatten__` path stores ``tex.DType`` directly and + is a no-op here. """ - inner: list = [] - if self._data is not None: - inner.append("_data") - if self._scale_inv is not None: - inner.append("_scale_inv") - if self._transpose is not None: - inner.append("_transpose") - meta = { - "_fp8_dtype": self._fp8_dtype, - "_fake_dtype": self._dtype, - "_quantizer_snapshot": _quantizer_subclass_snapshot(self._quantizer), - "_requires_grad": self.requires_grad, - } - return inner, meta - - @staticmethod - def __tensor_unflatten__( - inner_tensors: dict, meta: dict, outer_size, outer_stride - ) -> "Float8Tensor": - quantizer = _quantizer_from_subclass_snapshot(meta.get("_quantizer_snapshot")) - fp8_dtype = meta["_fp8_dtype"] + fp8_dtype = meta.get("fp8_dtype") if isinstance(fp8_dtype, FP8DType): - # ``meta`` produced by :func:`_float8_create_subclass_metadata` - # carries the Dynamo-friendly :class:`FP8DType` enum (an - # ``IntEnum`` so it proxies as a constant during tracing). - # Pybind-bound TE kernels (e.g. ``tex.dequantize``) accept only - # the native ``transformer_engine_torch.DType``, so bridge back - # here. The eager ``__tensor_flatten__`` path stores the native - # enum directly and skips this conversion. - fp8_dtype = to_tex(fp8_dtype) - return Float8Tensor( - shape=outer_size, - dtype=meta["_fake_dtype"], - data=inner_tensors.get("_data"), - fp8_scale_inv=inner_tensors.get("_scale_inv"), - fp8_dtype=fp8_dtype, - data_transpose=inner_tensors.get("_transpose"), - quantizer=quantizer, - requires_grad=meta.get("_requires_grad", False), - ) + meta = {**meta, "fp8_dtype": to_tex(fp8_dtype)} + return meta def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 641192dfb2..b93c426401 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -6,7 +6,7 @@ from __future__ import annotations import math -from typing import Optional, Dict, Any, List, Tuple +from typing import Optional, Dict, Any, Tuple import torch import transformer_engine_torch as tex @@ -46,6 +46,26 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): _columnwise_scale_inv: Optional[torch.Tensor] _is_2D_scaled: bool + # Declarative schema consumed by the generic + # :meth:`QuantizedTensorStorage._torch_compile_flatten` / + # :meth:`_torch_compile_do_unflatten` implementations in the base. + _FLATTEN_TENSOR_ATTRS = ( + "_rowwise_data", + "_rowwise_scale_inv", + "_columnwise_data", + "_columnwise_scale_inv", + ) + _FLATTEN_META_ATTRS = ("_fp8_dtype", "_dtype", "_is_2D_scaled") + _FLATTEN_CTOR_KWARG = { + "_rowwise_data": "rowwise_data", + "_rowwise_scale_inv": "rowwise_scale_inv", + "_columnwise_data": "columnwise_data", + "_columnwise_scale_inv": "columnwise_scale_inv", + "_fp8_dtype": "fp8_dtype", + "_dtype": "fake_dtype", + "_is_2D_scaled": "is_2D_scaled", + } + def __new__( cls, rowwise_data: Optional[torch.Tensor], @@ -144,82 +164,9 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - def _torch_compile_flatten(self) -> Tuple[Any, Any, List[torch.Tensor]]: - from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata - - tensors: List[torch.Tensor] = [] - - def _append_if_present(tensor: Optional[torch.Tensor]) -> bool: - if tensor is None: - return False - tensors.append(tensor) - return True - - quantizer_meta = None - process_group = None - quantizer_tensors: List[torch.Tensor] = [] - if self._quantizer is not None: - quantizer_meta, process_group, quantizer_tensors = self._quantizer._flatten() - - meta = OpaqueSimpleMetadata( - { - "_qstorage_cls": type(self).__qualname__, - "is_tensor": isinstance(self, torch.Tensor), - "shape": torch.Size(self.shape) if isinstance(self, torch.Tensor) else None, - "requires_grad": self.requires_grad if isinstance(self, torch.Tensor) else False, - "device": self.device if isinstance(self, torch.Tensor) else None, - "fp8_dtype": self._fp8_dtype, - "fake_dtype": self._dtype, - "is_2D_scaled": self._is_2D_scaled, - "has_rowwise_data": _append_if_present(self._rowwise_data), - "has_rowwise_scale_inv": _append_if_present(self._rowwise_scale_inv), - "has_columnwise_data": _append_if_present(self._columnwise_data), - "has_columnwise_scale_inv": _append_if_present(self._columnwise_scale_inv), - "quantizer_meta": quantizer_meta, - } - ) - tensors.extend(quantizer_tensors) - return meta, process_group, tensors - - @classmethod - def _torch_compile_do_unflatten( - cls, - meta: Any, - process_group: Any, - tensors: List[torch.Tensor], - ) -> "Float8BlockwiseQTensorStorage": - tensor_iter = iter(tensors) - rowwise_data = next(tensor_iter) if meta["has_rowwise_data"] else None - rowwise_scale_inv = next(tensor_iter) if meta["has_rowwise_scale_inv"] else None - columnwise_data = next(tensor_iter) if meta["has_columnwise_data"] else None - columnwise_scale_inv = ( - next(tensor_iter) if meta["has_columnwise_scale_inv"] else None - ) - quantizer = None - if meta["quantizer_meta"] is not None: - quantizer = Quantizer._unflatten( - meta["quantizer_meta"], process_group, list(tensor_iter) - ) - kwargs = { - "rowwise_data": rowwise_data, - "rowwise_scale_inv": rowwise_scale_inv, - "columnwise_data": columnwise_data, - "columnwise_scale_inv": columnwise_scale_inv, - "fp8_dtype": meta["fp8_dtype"], - "quantizer": quantizer, - "is_2D_scaled": meta["is_2D_scaled"], - "fake_dtype": meta["fake_dtype"], - } - if meta["is_tensor"]: - kwargs.update( - { - "shape": meta["shape"], - "dtype": meta["fake_dtype"], - "requires_grad": meta["requires_grad"], - "device": meta["device"], - } - ) - return cls(**kwargs) + # ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` are + # the generic implementations on :class:`QuantizedTensorStorage`, + # driven by the ``_FLATTEN_*`` declarations above. def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 51b28c766d..3e0625fe2a 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -6,7 +6,7 @@ from __future__ import annotations import math -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch import transformer_engine_torch as tex @@ -86,13 +86,18 @@ class Float8TensorStorage(QuantizedTensorStorage): _transpose: Optional[torch.Tensor] _transpose_invalid: bool - # Upper bound on the number of inner tensors produced by - # :meth:`_torch_compile_flatten`. Used by the wide-output layout in - # :mod:`transformer_engine.pytorch.dynamo` to reserve enough slots in - # the custom-op ``Tensor[]`` return for any storage-shaped output: - # 3 data tensors (data / transpose / scale_inv) + up to 2 quantizer - # tensors (Float8Quantizer carries scale / amax). - _TORCH_COMPILE_MAX_INNER_TENSORS = 5 + # Declarative schema consumed by the generic + # :meth:`QuantizedTensorStorage._torch_compile_flatten` / + # :meth:`_torch_compile_do_unflatten` implementations in the base. + _FLATTEN_TENSOR_ATTRS = ("_data", "_transpose", "_scale_inv") + _FLATTEN_META_ATTRS = ("_fp8_dtype", "_dtype") + _FLATTEN_CTOR_KWARG = { + "_data": "data", + "_transpose": "data_transpose", + "_scale_inv": "fp8_scale_inv", + "_fp8_dtype": "fp8_dtype", + "_dtype": "fake_dtype", + } def __new__( cls, @@ -267,89 +272,15 @@ def __repr__(self): ")" ) - def _torch_compile_flatten(self) -> Tuple[Any, Any, List[torch.Tensor]]: - from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata - - tensors: List[torch.Tensor] = [] - - def _append_if_present(tensor: Optional[torch.Tensor]) -> bool: - if tensor is None: - return False - tensors.append(tensor) - return True - - quantizer_meta = None - process_group = None - quantizer_tensors: List[torch.Tensor] = [] - if self._quantizer is not None: - quantizer_meta, process_group, quantizer_tensors = self._quantizer._flatten() - - meta = OpaqueSimpleMetadata( - { - "_qstorage_cls": type(self).__qualname__, - "is_tensor": isinstance(self, torch.Tensor), - "shape": torch.Size(self.shape) if isinstance(self, torch.Tensor) else None, - "requires_grad": self.requires_grad if isinstance(self, torch.Tensor) else False, - "device": self.device if isinstance(self, torch.Tensor) else None, - "fp8_dtype": self._fp8_dtype, - "fake_dtype": self._dtype, - "transpose_invalid": self._transpose_invalid, - "has_data": _append_if_present(self._data), - "has_transpose": _append_if_present(self._transpose), - "has_scale_inv": _append_if_present(self._scale_inv), - "quantizer_meta": quantizer_meta, - } - ) - tensors.extend(quantizer_tensors) - return meta, process_group, tensors - - @classmethod - def _torch_compile_do_unflatten( - cls, - meta: Any, - process_group: Any, - tensors: List[torch.Tensor], - ) -> "Float8TensorStorage": - tensor_iter = iter(tensors) - data = next(tensor_iter) if meta["has_data"] else None - transpose = next(tensor_iter) if meta["has_transpose"] else None - scale_inv = next(tensor_iter) if meta["has_scale_inv"] else None - quantizer = None - if meta["quantizer_meta"] is not None: - quantizer = Quantizer._unflatten( - meta["quantizer_meta"], process_group, list(tensor_iter) - ) - kwargs = { - "data": data, - "fp8_scale_inv": scale_inv, - "fp8_dtype": meta["fp8_dtype"], - "data_transpose": transpose, - "quantizer": quantizer, - "fake_dtype": meta["fake_dtype"], - } - if meta["is_tensor"]: - kwargs.update( - { - "shape": meta["shape"], - "dtype": meta["fake_dtype"], - "requires_grad": meta["requires_grad"], - "device": meta["device"], - } - ) - out = cls(**kwargs) - # ``__new__`` already sets ``_transpose_invalid = (data_transpose - # is None)``, which is exactly the post-restoration semantic we - # want under :mod:`torch.compile`: a transpose buffer that the - # producer chose to ship through the trace was valid at flatten - # time (forward never emits stale transposes onto saved - # tensors), so the unflattened storage must treat it as valid. - # Trusting ``meta["transpose_invalid"]`` instead would re-pin the - # stale ``True`` that Dynamo embeds into the metadata constant - # because it cannot follow the in-place - # :meth:`restore_from_saved` write through ``ctx.tensor_objects`` - # and would then fail the :meth:`update_usage` - # ``not has_data_transpose`` guard in backward. - return out + # ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` are + # the generic implementations on :class:`QuantizedTensorStorage`, + # driven by the ``_FLATTEN_*`` declarations above. ``__new__`` + # re-derives ``_transpose_invalid`` from the restored ``_transpose`` + # buffer, so we deliberately do not round-trip the flag through + # ``_FLATTEN_META_ATTRS``: a producer that ships a transpose through + # the trace had it valid, and trusting a stale ``True`` from a + # Dynamo-embedded meta constant would trip + # :meth:`update_usage`'s ``not has_data_transpose`` guard in backward. def _create_transpose(self): """Update FP8 transpose cache""" diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index edc1dd8ac1..6c96937428 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -5,7 +5,7 @@ """Mixin class holding data specific for MXFP8Tensor""" from __future__ import annotations -from typing import Optional, Dict, Any, List, Tuple +from typing import Optional, Dict, Any, Tuple from collections.abc import Iterable import math import torch @@ -89,6 +89,26 @@ class MXFP8TensorStorage(QuantizedTensorStorage): # GEMM _with_gemm_swizzled_scales: bool + # Declarative schema consumed by the generic + # :meth:`QuantizedTensorStorage._torch_compile_flatten` / + # :meth:`_torch_compile_do_unflatten` implementations in the base. + _FLATTEN_TENSOR_ATTRS = ( + "_rowwise_data", + "_rowwise_scale_inv", + "_columnwise_data", + "_columnwise_scale_inv", + ) + _FLATTEN_META_ATTRS = ("_fp8_dtype", "_dtype", "_with_gemm_swizzled_scales") + _FLATTEN_CTOR_KWARG = { + "_rowwise_data": "rowwise_data", + "_rowwise_scale_inv": "rowwise_scale_inv", + "_columnwise_data": "columnwise_data", + "_columnwise_scale_inv": "columnwise_scale_inv", + "_fp8_dtype": "fp8_dtype", + "_dtype": "fake_dtype", + "_with_gemm_swizzled_scales": "with_gemm_swizzled_scales", + } + def __new__( cls, rowwise_data: Optional[torch.Tensor], @@ -183,82 +203,9 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - def _torch_compile_flatten(self) -> Tuple[Any, Any, List[torch.Tensor]]: - from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata - - tensors: List[torch.Tensor] = [] - - def _append_if_present(tensor: Optional[torch.Tensor]) -> bool: - if tensor is None: - return False - tensors.append(tensor) - return True - - quantizer_meta = None - process_group = None - quantizer_tensors: List[torch.Tensor] = [] - if self._quantizer is not None: - quantizer_meta, process_group, quantizer_tensors = self._quantizer._flatten() - - meta = OpaqueSimpleMetadata( - { - "_qstorage_cls": type(self).__qualname__, - "is_tensor": isinstance(self, torch.Tensor), - "shape": torch.Size(self.shape) if isinstance(self, torch.Tensor) else None, - "requires_grad": self.requires_grad if isinstance(self, torch.Tensor) else False, - "device": self.device if isinstance(self, torch.Tensor) else None, - "fp8_dtype": self._fp8_dtype, - "fake_dtype": self._dtype, - "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, - "has_rowwise_data": _append_if_present(self._rowwise_data), - "has_rowwise_scale_inv": _append_if_present(self._rowwise_scale_inv), - "has_columnwise_data": _append_if_present(self._columnwise_data), - "has_columnwise_scale_inv": _append_if_present(self._columnwise_scale_inv), - "quantizer_meta": quantizer_meta, - } - ) - tensors.extend(quantizer_tensors) - return meta, process_group, tensors - - @classmethod - def _torch_compile_do_unflatten( - cls, - meta: Any, - process_group: Any, - tensors: List[torch.Tensor], - ) -> "MXFP8TensorStorage": - tensor_iter = iter(tensors) - rowwise_data = next(tensor_iter) if meta["has_rowwise_data"] else None - rowwise_scale_inv = next(tensor_iter) if meta["has_rowwise_scale_inv"] else None - columnwise_data = next(tensor_iter) if meta["has_columnwise_data"] else None - columnwise_scale_inv = ( - next(tensor_iter) if meta["has_columnwise_scale_inv"] else None - ) - quantizer = None - if meta["quantizer_meta"] is not None: - quantizer = Quantizer._unflatten( - meta["quantizer_meta"], process_group, list(tensor_iter) - ) - kwargs = { - "rowwise_data": rowwise_data, - "rowwise_scale_inv": rowwise_scale_inv, - "columnwise_data": columnwise_data, - "columnwise_scale_inv": columnwise_scale_inv, - "fp8_dtype": meta["fp8_dtype"], - "quantizer": quantizer, - "with_gemm_swizzled_scales": meta["with_gemm_swizzled_scales"], - "fake_dtype": meta["fake_dtype"], - } - if meta["is_tensor"]: - kwargs.update( - { - "shape": meta["shape"], - "dtype": meta["fake_dtype"], - "requires_grad": meta["requires_grad"], - "device": meta["device"], - } - ) - return cls(**kwargs) + # ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` are + # the generic implementations on :class:`QuantizedTensorStorage`, + # driven by the ``_FLATTEN_*`` declarations above. def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 0e4810a4ea..ad164ca118 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -8,7 +8,7 @@ from collections.abc import Iterable import functools import math -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import warnings import torch @@ -110,6 +110,36 @@ class NVFP4TensorStorage(QuantizedTensorStorage): # Whether this NVFP4 tensor uses row-scaled amax metadata _row_scaled_nvfp4: bool + # Declarative schema consumed by the generic + # :meth:`QuantizedTensorStorage._torch_compile_flatten` / + # :meth:`_torch_compile_do_unflatten` implementations in the base. + _FLATTEN_TENSOR_ATTRS = ( + "_rowwise_data", + "_rowwise_scale_inv", + "_columnwise_data", + "_columnwise_scale_inv", + "_amax_rowwise", + "_amax_columnwise", + ) + _FLATTEN_META_ATTRS = ( + "_fp4_dtype", + "_dtype", + "_with_gemm_swizzled_scales", + "_row_scaled_nvfp4", + ) + _FLATTEN_CTOR_KWARG = { + "_rowwise_data": "rowwise_data", + "_rowwise_scale_inv": "rowwise_scale_inv", + "_columnwise_data": "columnwise_data", + "_columnwise_scale_inv": "columnwise_scale_inv", + "_amax_rowwise": "amax_rowwise", + "_amax_columnwise": "amax_columnwise", + "_fp4_dtype": "fp4_dtype", + "_dtype": "fake_dtype", + "_with_gemm_swizzled_scales": "with_gemm_swizzled_scales", + "_row_scaled_nvfp4": "row_scaled_nvfp4", + } + def __new__( cls, rowwise_data: Optional[torch.Tensor], @@ -226,90 +256,9 @@ def restore_from_saved( self._amax_columnwise = tensors[5] return tensors[6:] - def _torch_compile_flatten(self) -> Tuple[Any, Any, List[torch.Tensor]]: - from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata - - tensors: List[torch.Tensor] = [] - - def _append_if_present(tensor: Optional[torch.Tensor]) -> bool: - if tensor is None: - return False - tensors.append(tensor) - return True - - quantizer_meta = None - process_group = None - quantizer_tensors: List[torch.Tensor] = [] - if self._quantizer is not None: - quantizer_meta, process_group, quantizer_tensors = self._quantizer._flatten() - - meta = OpaqueSimpleMetadata( - { - "_qstorage_cls": type(self).__qualname__, - "is_tensor": isinstance(self, torch.Tensor), - "shape": torch.Size(self.shape) if isinstance(self, torch.Tensor) else None, - "requires_grad": self.requires_grad if isinstance(self, torch.Tensor) else False, - "device": self.device if isinstance(self, torch.Tensor) else None, - "fp4_dtype": self._fp4_dtype, - "fake_dtype": self._dtype, - "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, - "row_scaled_nvfp4": self._row_scaled_nvfp4, - "has_rowwise_data": _append_if_present(self._rowwise_data), - "has_rowwise_scale_inv": _append_if_present(self._rowwise_scale_inv), - "has_columnwise_data": _append_if_present(self._columnwise_data), - "has_columnwise_scale_inv": _append_if_present(self._columnwise_scale_inv), - "has_amax_rowwise": _append_if_present(self._amax_rowwise), - "has_amax_columnwise": _append_if_present(self._amax_columnwise), - "quantizer_meta": quantizer_meta, - } - ) - tensors.extend(quantizer_tensors) - return meta, process_group, tensors - - @classmethod - def _torch_compile_do_unflatten( - cls, - meta: Any, - process_group: Any, - tensors: List[torch.Tensor], - ) -> "NVFP4TensorStorage": - tensor_iter = iter(tensors) - rowwise_data = next(tensor_iter) if meta["has_rowwise_data"] else None - rowwise_scale_inv = next(tensor_iter) if meta["has_rowwise_scale_inv"] else None - columnwise_data = next(tensor_iter) if meta["has_columnwise_data"] else None - columnwise_scale_inv = ( - next(tensor_iter) if meta["has_columnwise_scale_inv"] else None - ) - amax_rowwise = next(tensor_iter) if meta["has_amax_rowwise"] else None - amax_columnwise = next(tensor_iter) if meta["has_amax_columnwise"] else None - quantizer = None - if meta["quantizer_meta"] is not None: - quantizer = Quantizer._unflatten( - meta["quantizer_meta"], process_group, list(tensor_iter) - ) - kwargs = { - "rowwise_data": rowwise_data, - "rowwise_scale_inv": rowwise_scale_inv, - "columnwise_data": columnwise_data, - "columnwise_scale_inv": columnwise_scale_inv, - "amax_rowwise": amax_rowwise, - "amax_columnwise": amax_columnwise, - "fp4_dtype": meta["fp4_dtype"], - "quantizer": quantizer, - "with_gemm_swizzled_scales": meta["with_gemm_swizzled_scales"], - "fake_dtype": meta["fake_dtype"], - "row_scaled_nvfp4": meta["row_scaled_nvfp4"], - } - if meta["is_tensor"]: - kwargs.update( - { - "shape": meta["shape"], - "dtype": meta["fake_dtype"], - "requires_grad": meta["requires_grad"], - "device": meta["device"], - } - ) - return cls(**kwargs) + # ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` are + # the generic implementations on :class:`QuantizedTensorStorage`, + # driven by the ``_FLATTEN_*`` declarations above. def get_data_tensors(self): """Get this Tensor's data.""" From e4dfb9aff989faaaba454ebbf6a4f14ce3a388ac Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 17:32:24 +0200 Subject: [PATCH 10/16] [PyTorch] Generic Quantizer.create_storage_metadata via declarative schema Move ``create_storage_metadata`` onto the ``Quantizer`` base class, driven by the storage's ``_FLATTEN_*`` declarations plus two new hooks: a per-quantizer ``_storage_cls`` class attribute pointing at the storage subclass it produces and a ``_storage_scalars()`` method returning the quantizer-specific scalar fields (``fp8_dtype``, ``with_gemm_swizzled_scales``, ...). Per-attribute presence flags are derived from the storage's new ``_FLATTEN_TENSOR_USAGE`` map plus the quantizer's ``rowwise_usage`` / ``columnwise_usage``. The four bespoke ``*Quantizer.create_storage_metadata`` methods and the ``_float8_create_storage_metadata`` helper collapse to ~5 lines of declarations per quantizer. Also slim the storage flatten path: ``is_tensor`` / ``shape`` / ``requires_grad`` / ``device`` are only emitted when the storage is flattened as a ``torch.Tensor`` (the ``_rewrite_subclass_to_storage`` path); the storage-only path keeps a smaller meta payload, and ``_torch_compile_do_unflatten`` falls back via ``meta.get`` instead of requiring those keys. Signed-off-by: Pawel Gadzinski --- .../pytorch/quantized_tensor.py | 111 +++++++++++++-- .../pytorch/tensor/float8_blockwise_tensor.py | 61 +-------- .../pytorch/tensor/float8_tensor.py | 126 ++---------------- .../pytorch/tensor/mxfp8_tensor.py | 77 ++--------- .../pytorch/tensor/nvfp4_tensor.py | 71 ++-------- .../float8_blockwise_tensor_storage.py | 6 + .../tensor/storage/float8_tensor_storage.py | 5 + .../tensor/storage/mxfp8_tensor_storage.py | 6 + .../tensor/storage/nvfp4_tensor_storage.py | 8 ++ 9 files changed, 160 insertions(+), 311 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 03a7dbe4d1..7506a2aac1 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -132,6 +132,14 @@ class QuantizedTensorStorage: # remapping through :attr:`_FLATTEN_CTOR_KWARG`). _FLATTEN_TENSOR_ATTRS: Tuple[str, ...] = () + # Maps each entry in :attr:`_FLATTEN_TENSOR_ATTRS` to one of + # ``"rowwise"`` / ``"columnwise"`` / ``"always"``. Consumed by + # :meth:`Quantizer.create_storage_metadata` to translate a live + # quantizer's ``rowwise_usage`` / ``columnwise_usage`` flags into + # per-attribute presence (``has_*``) flags at output-spec time. + # Unmapped attributes default to ``"always"``. + _FLATTEN_TENSOR_USAGE: Dict[str, str] = {} + # Names of value-stable scalar / enum attributes needed to round-trip # the instance. Same naming / kwarg conventions as # :attr:`_FLATTEN_TENSOR_ATTRS`. @@ -286,14 +294,20 @@ def _torch_compile_flatten( ) tensors: List[torch.Tensor] = [] - is_tensor = isinstance(self, torch.Tensor) - meta_dict: Dict[str, Any] = { - "_qstorage_cls": type(self).__qualname__, - "is_tensor": is_tensor, - "shape": torch.Size(self.shape) if is_tensor else None, - "requires_grad": self.requires_grad if is_tensor else False, - "device": self.device if is_tensor else None, - } + meta_dict: Dict[str, Any] = {"_qstorage_cls": type(self).__qualname__} + # Tensor-wrapper fields are only relevant when ``self`` is a live + # ``torch.Tensor`` (e.g. ``Float8Tensor`` rewritten in-place to a + # storage payload by ``_rewrite_subclass_to_storage``); a bare + # storage shell has no outer shape / requires_grad / device. + if isinstance(self, torch.Tensor): + meta_dict.update( + { + "is_tensor": True, + "shape": torch.Size(self.shape), + "requires_grad": self.requires_grad, + "device": self.device, + } + ) for attr in self._FLATTEN_META_ATTRS: meta_dict[self._flatten_ctor_kw(attr)] = getattr(self, attr) for attr in self._FLATTEN_TENSOR_ATTRS: @@ -336,7 +350,7 @@ def _torch_compile_do_unflatten( kw = cls._flatten_ctor_kw(attr) kwargs[kw] = meta[kw] kwargs["quantizer"] = quantizer - if meta["is_tensor"]: + if meta.get("is_tensor", False): kwargs.update( { "shape": meta["shape"], @@ -456,6 +470,13 @@ class Quantizer(abc.ABC): """ rowwise_usage: bool + # The :class:`QuantizedTensorStorage` subclass produced by this + # quantizer's quantize / make_empty path. Consumed by + # :meth:`create_storage_metadata` to declare a ``("storage", ...)`` + # output payload that round-trips through the generic + # :meth:`QuantizedTensorStorage._torch_compile_do_unflatten`. + _storage_cls: type["QuantizedTensorStorage"] + """Whether to construct quantized tensors with "column-wise usage" Hand-wave explanation: Consider the matrix multiplication C = A^T @@ -685,6 +706,78 @@ def _unflatten( ) return target._do_unflatten(meta, process_group, tensors) + def _storage_scalars(self) -> Dict[str, Any]: + """Per-quantizer scalar fields for the storage's ``_FLATTEN_META_ATTRS``. + + Keys are constructor kwarg names (matching the values of + :attr:`QuantizedTensorStorage._FLATTEN_CTOR_KWARG`). ``fake_dtype`` + is supplied separately by :meth:`create_storage_metadata`; subclasses + only need to return their quantizer-specific scalars (e.g. + ``fp8_dtype``, ``with_gemm_swizzled_scales``). + """ + raise NotImplementedError( + f"{type(self).__name__} class does not implement _storage_scalars; " + "required for torch.compile output specs that emit a " + "QuantizedTensorStorage." + ) + + def create_storage_metadata( + self, + *, + shape: Iterable[int], + fake_dtype: torch.dtype, + device: Optional[torch.device] = None, + ) -> Tuple[type["QuantizedTensorStorage"], Any, Optional[Any], int]: + """Return ``(cls, meta, process_group, tensor_count)`` describing + the ``("storage", ...)`` payload of a Dynamo output spec. + + The Dynamo layer hands the trailing + ``(meta, process_group, tensors[: tensor_count])`` triple to + :meth:`QuantizedTensorStorage._torch_compile_do_unflatten` to + reconstruct the freshly-quantized storage on the consumer side. + + Driven entirely by the storage's ``_FLATTEN_*`` schema plus a + per-quantizer :meth:`_storage_scalars` hook; ``has_*`` flags are + derived from ``rowwise_usage`` / ``columnwise_usage`` and the + storage's :attr:`QuantizedTensorStorage._FLATTEN_TENSOR_USAGE` + map. Quantizers with tensor state (e.g. :class:`Float8Quantizer`'s + ``scale`` / ``amax``) append those tensors after the storage's own + slots; :meth:`Quantizer._flatten` provides both the count and the + ``quantizer_meta`` payload needed to rebuild the quantizer. + """ + from .dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel + + if device is None: + device = torch.device("cuda") + del device, shape # storage-only path: no outer tensor view + storage_cls = type(self)._storage_cls + usage_flag = { + "rowwise": self.rowwise_usage, + "columnwise": self.columnwise_usage, + "always": True, + } + has_flags: Dict[str, bool] = {} + tensor_count = 0 + for attr in storage_cls._FLATTEN_TENSOR_ATTRS: + usage = storage_cls._FLATTEN_TENSOR_USAGE.get(attr, "always") + flag = usage_flag[usage] + has_flags[storage_cls._flatten_presence_key(attr)] = flag + if flag: + tensor_count += 1 + quantizer_meta, _, quantizer_tensors = self._flatten() + tensor_count += len(quantizer_tensors) + scalars = self._storage_scalars() + scalars["fake_dtype"] = fake_dtype + meta = OpaqueSimpleMetadata( + { + "_qstorage_cls": storage_cls.__qualname__, + **scalars, + **has_flags, + "quantizer_meta": quantizer_meta, + } + ) + return storage_cls, meta, None, tensor_count + class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ec60893d06..d171062c8e 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -37,6 +37,8 @@ class Float8BlockQuantizer(Quantizer): force_pow_2_scales: bool block_scaling_dim: int + _storage_cls = Float8BlockwiseQTensorStorage + def __init__( self, fp8_dtype: TE_DType, @@ -281,60 +283,11 @@ def calibrate(self, tensor: torch.Tensor) -> None: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8BlockScaling - def create_storage_metadata( - self, - *, - shape: Iterable[int], - fake_dtype: torch.dtype, - device: Optional[torch.device] = None, - requires_grad: bool = False, - as_tensor: bool = False, - ): - """Return ``(cls, meta, process_group, tensor_count)`` - suitable as the ``("storage", ...)`` payload of a Dynamo - output spec; the dynamo layer hands the trailing - ``(meta, process_group, tensors[: tensor_count])`` triple to - :meth:`Float8BlockwiseQTensorStorage._torch_compile_do_unflatten` - for reconstruction. - - Same contract as :meth:`Float8Quantizer.create_storage_metadata` - / :meth:`MXFP8Quantizer.create_storage_metadata` -- see those - docstrings for the broader rationale; this variant carries the - extra ``is_2D_scaled`` flag that the blockwise storage needs - on reconstruction. - """ - if device is None: - device = torch.device("cuda") - shape = torch.Size(shape) - has_rowwise = bool(self.rowwise_usage) - has_columnwise = bool(self.columnwise_usage) - tensor_count = int(has_rowwise) * 2 + int(has_columnwise) * 2 - # Storage's :meth:`_torch_compile_flatten` also emits the live - # quantizer's flatten tensors (see - # :meth:`Float8Quantizer.create_storage_metadata` for - # rationale); keep the count + meta in sync. - quantizer_meta, _, quantizer_tensors = self._flatten() - tensor_count += len(quantizer_tensors) - from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel - - meta = OpaqueSimpleMetadata( - { - "_qstorage_cls": "Float8BlockwiseQTensorStorage", - "is_tensor": as_tensor, - "shape": shape if as_tensor else None, - "requires_grad": requires_grad if as_tensor else False, - "device": device if as_tensor else None, - "fp8_dtype": self.dtype, - "fake_dtype": fake_dtype, - "is_2D_scaled": self.block_scaling_dim == 2, - "has_rowwise_data": has_rowwise, - "has_rowwise_scale_inv": has_rowwise, - "has_columnwise_data": has_columnwise, - "has_columnwise_scale_inv": has_columnwise, - "quantizer_meta": quantizer_meta, - } - ) - return Float8BlockwiseQTensorStorage, meta, None, tensor_count + def _storage_scalars(self) -> dict: + return { + "fp8_dtype": self.dtype, + "is_2D_scaled": self.block_scaling_dim == 2, + } def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 6f7b6790b9..92d20fbb22 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -4,7 +4,7 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Any, List, Optional, Tuple, Iterable, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import warnings import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState @@ -103,86 +103,6 @@ def _float8_create_subclass_metadata( return tuple(inner_names), meta -def _float8_create_storage_metadata( - quantizer: "Quantizer", - *, - shape: Iterable[int], - fake_dtype: torch.dtype, - device: Optional[torch.device] = None, - requires_grad: bool = False, - as_tensor: bool = False, -): - """Return ``(cls, meta, process_group, tensor_count)`` suitable - for use as the ``("storage", ...)`` payload of a Dynamo output - spec; the dynamo layer hands the trailing - ``(meta, process_group, tensors[: tensor_count])`` triple to - :meth:`Float8TensorStorage._torch_compile_do_unflatten` for - reconstruction. - - Companion of :func:`_float8_create_subclass_metadata` for the - pure-storage layout (used today for the FP8 weight workspace - returned alongside ``Linear`` 's primary output). ``meta`` is an - :class:`OpaqueSimpleMetadata` carrying: - - * the storage layout flags (``has_data``, ``has_transpose``, - ``has_scale_inv``) derived from the quantizer's rowwise / - columnwise usage, - * ``fp8_dtype`` (raw ``tex.DType`` -- the storage path does not - cross a Dynamo subclass-constructor boundary, so we can keep - the native enum here), - * ``fake_dtype`` / ``shape`` / ``device`` / ``requires_grad`` - describing the higher-precision view of the storage, - * ``quantizer_meta`` -- ``None`` for the same reason as in - :func:`_float8_create_subclass_metadata`. - - ``tensor_count`` is the number of flat inner tensors the storage - will consume from the op's ``Tensor[]`` return (``_data``, - ``_transpose``, ``_scale_inv``, in that order, only those whose - ``has_*`` flag is ``True``). The dynamo layer uses it to slice the - flat return; the storage's :meth:`_torch_compile_do_unflatten` - reassembles them via the same ``has_*`` flags. - """ - if device is None: - device = torch.device("cuda") - shape = torch.Size(shape) - has_data = bool(quantizer.rowwise_usage) - has_transpose = bool(quantizer.columnwise_usage) - has_scale_inv = True - tensor_count = int(has_data) + int(has_transpose) + int(has_scale_inv) - # Storage's :meth:`_torch_compile_flatten` also emits the live - # quantizer's flatten tensors when ``self._quantizer is not None`` - # (the impl-produced storage always carries one). Pull - # ``quantizer._flatten()`` to learn the count + meta so the - # metadata we publish here stays in lock-step with the slot count - # produced at flatten time. - quantizer_meta, _, quantizer_tensors = quantizer._flatten() - tensor_count += len(quantizer_tensors) - from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel - - meta = OpaqueSimpleMetadata( - { - "_qstorage_cls": "Float8TensorStorage", - "is_tensor": as_tensor, - "shape": shape if as_tensor else None, - "requires_grad": requires_grad if as_tensor else False, - "device": device if as_tensor else None, - "fp8_dtype": quantizer.dtype, - "fake_dtype": fake_dtype, - # ``Float8TensorStorage._torch_compile_do_unflatten`` skips - # the transpose-validity check when reconstructing; we - # publish ``False`` (valid transpose) here since a - # freshly-quantized storage with the configured usage - # always has up-to-date inner buffers. - "transpose_invalid": not has_transpose, - "has_data": has_data, - "has_transpose": has_transpose, - "has_scale_inv": has_scale_inv, - "quantizer_meta": quantizer_meta, - } - ) - return Float8TensorStorage, meta, None, tensor_count - - class Float8Quantizer(Quantizer): """Builder class for FP8 tensors with per-tensor delayed scaling @@ -201,6 +121,8 @@ class Float8Quantizer(Quantizer): """FP8 datatype""" dtype: TE_DType + _storage_cls = Float8TensorStorage + def __init__( self, scale: torch.Tensor, @@ -398,24 +320,8 @@ def create_metadata( requires_grad=requires_grad, ) - def create_storage_metadata( - self, - *, - shape: Iterable[int], - fake_dtype: torch.dtype, - device: Optional[torch.device] = None, - requires_grad: bool = False, - as_tensor: bool = False, - ): - # pylint: disable=missing-function-docstring - return _float8_create_storage_metadata( - self, - shape=shape, - fake_dtype=fake_dtype, - device=device, - requires_grad=requires_grad, - as_tensor=as_tensor, - ) + def _storage_scalars(self) -> Dict[str, Any]: + return {"fp8_dtype": self.dtype} def _flatten(self): from ..dynamo import OpaqueSimpleMetadata @@ -479,6 +385,8 @@ class Float8CurrentScalingQuantizer(Quantizer): force_pow_2_scales: bool amax_epsilon: float + _storage_cls = Float8TensorStorage + def __init__( self, fp8_dtype: TE_DType, @@ -709,24 +617,8 @@ def create_metadata( requires_grad=requires_grad, ) - def create_storage_metadata( - self, - *, - shape: Iterable[int], - fake_dtype: torch.dtype, - device: Optional[torch.device] = None, - requires_grad: bool = False, - as_tensor: bool = False, - ): - # pylint: disable=missing-function-docstring - return _float8_create_storage_metadata( - self, - shape=shape, - fake_dtype=fake_dtype, - device=device, - requires_grad=requires_grad, - as_tensor=as_tensor, - ) + def _storage_scalars(self) -> Dict[str, Any]: + return {"fp8_dtype": self.dtype} def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5767f254c5..83c800dfdf 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Iterable import math -from typing import Optional, Tuple, Union, Any +from typing import Any, Dict, Optional, Tuple, Union import warnings import torch @@ -35,6 +35,8 @@ class MXFP8Quantizer(Quantizer): dtype: TE_DType + _storage_cls = MXFP8TensorStorage + def __init__( self, fp8_dtype: TE_DType, @@ -255,74 +257,11 @@ def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> tor def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return MXFP8BlockScaling - def create_storage_metadata( - self, - *, - shape: Iterable[int], - fake_dtype: torch.dtype, - device: Optional[torch.device] = None, - requires_grad: bool = False, - as_tensor: bool = False, - ): - """Return ``(cls, meta, process_group, tensor_count)`` - suitable as the ``("storage", ...)`` payload of a Dynamo - output spec; the dynamo layer hands the trailing - ``(meta, process_group, tensors[: tensor_count])`` triple to - :meth:`MXFP8TensorStorage._torch_compile_do_unflatten` for - reconstruction. - - Mirrors what - :meth:`MXFP8TensorStorage._torch_compile_flatten` would emit - for a freshly-quantized storage configured with this - quantizer's rowwise / columnwise usage. ``tensor_count`` is - the variable-length count of present inner tensors - (rowwise_data, rowwise_scale_inv, columnwise_data, - columnwise_scale_inv, only those whose ``has_*`` flag is - true). The dynamo layer uses it to slice the op's flat - ``Tensor[]`` return; the storage's - :meth:`_torch_compile_do_unflatten` reassembles them via the - same ``has_*`` flags. - - ``quantizer_meta`` is set to ``None`` so the reconstructed - storage has ``_quantizer=None`` -- keeping the constructor - traceable by Dynamo, mirroring the behaviour of - :class:`Float8Quantizer.create_storage_metadata`. - """ - if device is None: - device = torch.device("cuda") - shape = torch.Size(shape) - has_rowwise = bool(self.rowwise_usage) - has_columnwise = bool(self.columnwise_usage) - tensor_count = ( - int(has_rowwise) * 2 # rowwise_data + rowwise_scale_inv - + int(has_columnwise) * 2 # columnwise_data + columnwise_scale_inv - ) - # Storage's :meth:`_torch_compile_flatten` also emits the live - # quantizer's flatten tensors (see - # :meth:`Float8Quantizer.create_storage_metadata` for - # rationale); keep the count + meta in sync. - quantizer_meta, _, quantizer_tensors = self._flatten() - tensor_count += len(quantizer_tensors) - from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel - - meta = OpaqueSimpleMetadata( - { - "_qstorage_cls": "MXFP8TensorStorage", - "is_tensor": as_tensor, - "shape": shape if as_tensor else None, - "requires_grad": requires_grad if as_tensor else False, - "device": device if as_tensor else None, - "fp8_dtype": self.dtype, - "fake_dtype": fake_dtype, - "with_gemm_swizzled_scales": self.optimize_for_gemm, - "has_rowwise_data": has_rowwise, - "has_rowwise_scale_inv": has_rowwise, - "has_columnwise_data": has_columnwise, - "has_columnwise_scale_inv": has_columnwise, - "quantizer_meta": quantizer_meta, - } - ) - return MXFP8TensorStorage, meta, None, tensor_count + def _storage_scalars(self) -> Dict[str, Any]: + return { + "fp8_dtype": self.dtype, + "with_gemm_swizzled_scales": self.optimize_for_gemm, + } def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 5ba8ed1833..10d45d4561 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -7,7 +7,7 @@ from collections.abc import Iterable import math import warnings -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import functools import torch @@ -135,6 +135,8 @@ class NVFP4Quantizer(Quantizer): rht_matrix_random_sign_mask_t: int rht_matrix: torch.Tensor + _storage_cls = NVFP4TensorStorage + def __init__( self, fp4_dtype: TE_DType = tex.DType.kFloat4E2M1, @@ -415,67 +417,12 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return NVFP4BlockScaling - def create_storage_metadata( - self, - *, - shape: Iterable[int], - fake_dtype: torch.dtype, - device: Optional[torch.device] = None, - requires_grad: bool = False, - as_tensor: bool = False, - ): - """Return ``(cls, meta, process_group, tensor_count)`` - suitable as the ``("storage", ...)`` payload of a Dynamo - output spec; the dynamo layer hands the trailing - ``(meta, process_group, tensors[: tensor_count])`` triple to - :meth:`NVFP4TensorStorage._torch_compile_do_unflatten` for - reconstruction. - - See :meth:`Float8Quantizer.create_storage_metadata` for the - general contract. This variant adds the FP4-specific - ``with_gemm_swizzled_scales`` / ``row_scaled_nvfp4`` flags, - and the two amax-row/columnwise inner tensors that come with - the NVFP4 storage layout. - """ - if device is None: - device = torch.device("cuda") - shape = torch.Size(shape) - has_rowwise = bool(self.rowwise_usage) - has_columnwise = bool(self.columnwise_usage) - # Counts: rowwise contributes data + scale_inv + amax; same for - # columnwise. Each pair toggles on its respective usage flag. - tensor_count = ( - int(has_rowwise) * 3 + int(has_columnwise) * 3 - ) - # Storage's :meth:`_torch_compile_flatten` also emits the live - # quantizer's flatten tensors (see - # :meth:`Float8Quantizer.create_storage_metadata` for - # rationale); keep the count + meta in sync. - quantizer_meta, _, quantizer_tensors = self._flatten() - tensor_count += len(quantizer_tensors) - from ..dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel - - meta = OpaqueSimpleMetadata( - { - "_qstorage_cls": "NVFP4TensorStorage", - "is_tensor": as_tensor, - "shape": shape if as_tensor else None, - "requires_grad": requires_grad if as_tensor else False, - "device": device if as_tensor else None, - "fp4_dtype": self.dtype, - "fake_dtype": fake_dtype, - "with_gemm_swizzled_scales": self.optimize_for_gemm, - "row_scaled_nvfp4": self.row_scaled_nvfp4, - "has_rowwise_data": has_rowwise, - "has_rowwise_scale_inv": has_rowwise, - "has_columnwise_data": has_columnwise, - "has_columnwise_scale_inv": has_columnwise, - "has_amax_rowwise": has_rowwise, - "has_amax_columnwise": has_columnwise, - "quantizer_meta": quantizer_meta, - } - ) - return NVFP4TensorStorage, meta, None, tensor_count + def _storage_scalars(self) -> Dict[str, Any]: + return { + "fp4_dtype": self.dtype, + "with_gemm_swizzled_scales": self.optimize_for_gemm, + "row_scaled_nvfp4": self.row_scaled_nvfp4, + } def _flatten(self): from ..dynamo import OpaqueSimpleMetadata diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index b93c426401..0bd1a3d555 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -55,6 +55,12 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): "_columnwise_data", "_columnwise_scale_inv", ) + _FLATTEN_TENSOR_USAGE = { + "_rowwise_data": "rowwise", + "_rowwise_scale_inv": "rowwise", + "_columnwise_data": "columnwise", + "_columnwise_scale_inv": "columnwise", + } _FLATTEN_META_ATTRS = ("_fp8_dtype", "_dtype", "_is_2D_scaled") _FLATTEN_CTOR_KWARG = { "_rowwise_data": "rowwise_data", diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 3e0625fe2a..966bb43936 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -90,6 +90,11 @@ class Float8TensorStorage(QuantizedTensorStorage): # :meth:`QuantizedTensorStorage._torch_compile_flatten` / # :meth:`_torch_compile_do_unflatten` implementations in the base. _FLATTEN_TENSOR_ATTRS = ("_data", "_transpose", "_scale_inv") + _FLATTEN_TENSOR_USAGE = { + "_data": "rowwise", + "_transpose": "columnwise", + "_scale_inv": "always", + } _FLATTEN_META_ATTRS = ("_fp8_dtype", "_dtype") _FLATTEN_CTOR_KWARG = { "_data": "data", diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 6c96937428..f827695294 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -98,6 +98,12 @@ class MXFP8TensorStorage(QuantizedTensorStorage): "_columnwise_data", "_columnwise_scale_inv", ) + _FLATTEN_TENSOR_USAGE = { + "_rowwise_data": "rowwise", + "_rowwise_scale_inv": "rowwise", + "_columnwise_data": "columnwise", + "_columnwise_scale_inv": "columnwise", + } _FLATTEN_META_ATTRS = ("_fp8_dtype", "_dtype", "_with_gemm_swizzled_scales") _FLATTEN_CTOR_KWARG = { "_rowwise_data": "rowwise_data", diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index ad164ca118..53b5649956 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -121,6 +121,14 @@ class NVFP4TensorStorage(QuantizedTensorStorage): "_amax_rowwise", "_amax_columnwise", ) + _FLATTEN_TENSOR_USAGE = { + "_rowwise_data": "rowwise", + "_rowwise_scale_inv": "rowwise", + "_columnwise_data": "columnwise", + "_columnwise_scale_inv": "columnwise", + "_amax_rowwise": "rowwise", + "_amax_columnwise": "columnwise", + } _FLATTEN_META_ATTRS = ( "_fp4_dtype", "_dtype", From dc875f0f0f8c0e458345b222ab2ccfe2ede81259 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 17:41:53 +0200 Subject: [PATCH 11/16] [PyTorch] Generic Quantizer._flatten / _do_unflatten via declarative schema Move the TE-internal Quantizer flattening protocol (used by ``_FlattenableBucket`` in ``dynamo.py`` to round-trip quantizers through compiled regions and to embed ``quantizer_meta`` inside storage flatten payloads) onto the ``Quantizer`` base class. Subclasses now declare per-class attrs instead of writing 30-line ``_flatten`` / ``_do_unflatten`` pairs: * ``_DTYPE_INIT_KWARG`` -- ``"fp8_dtype"`` (default) or ``"fp4_dtype"``. * ``_INIT_META_ATTRS`` / ``_POST_INIT_META_ATTRS`` -- scalar attrs threaded through ``__init__`` vs. set after construction. * ``_INIT_TENSOR_ATTRS`` / ``_POST_INIT_TENSOR_ATTRS`` -- tensor attrs in flatten order, split the same way. * ``_PG_ATTR`` / ``_PG_INIT_KWARG`` -- name of the optional ``ProcessGroup`` attribute and (optionally) the ``__init__`` kwarg to thread it through. * ``_FIXED_INIT_KWARGS`` -- hardcoded ``__init__`` kwargs not derived from meta (e.g. ``device=torch.device("cuda")`` for ``Float8CurrentScalingQuantizer``). The base ``dtype`` / ``rowwise_usage`` / ``columnwise_usage`` round-trip through ``__init__`` and ``internal`` / ``optimize_for_gemm`` through post-init are handled by the base class. ``MXFP8Quantizer`` drops both methods entirely (no extras to declare). ``Float8Quantizer`` collapses to a single ``_INIT_TENSOR_ATTRS`` declaration. ``NVFP4Quantizer`` drops ~45 lines. Signed-off-by: Pawel Gadzinski --- .../pytorch/quantized_tensor.py | 104 ++++++++++++++---- .../pytorch/tensor/float8_blockwise_tensor.py | 37 +------ .../pytorch/tensor/float8_tensor.py | 70 +----------- .../pytorch/tensor/mxfp8_tensor.py | 27 ----- .../pytorch/tensor/nvfp4_tensor.py | 59 +++------- 5 files changed, 103 insertions(+), 194 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 7506a2aac1..1e6cc9ce8c 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -647,6 +647,40 @@ def __init_subclass__(cls, **kwargs: Any) -> None: # dispatch back to it by ``__qualname__``. _QUANTIZER_REGISTRY[cls.__qualname__] = cls + # ---- Declarative schema for the generic :meth:`_flatten` / ---- # + # ---- :meth:`_do_unflatten` implementations below. ---- # + + # ``__init__`` kwarg name for ``self.dtype`` (e.g. ``"fp8_dtype"``, + # ``"fp4_dtype"``). + _DTYPE_INIT_KWARG: str = "fp8_dtype" + + # Scalar attribute names (besides ``dtype`` / ``rowwise_usage`` / + # ``columnwise_usage``) threaded through ``__init__``. The kwarg name + # is assumed to match the attribute name. + _INIT_META_ATTRS: Tuple[str, ...] = () + + # Scalar attribute names (besides ``internal`` / ``optimize_for_gemm``) + # set on the instance after ``__init__``. + _POST_INIT_META_ATTRS: Tuple[str, ...] = () + + # Tensor attribute names threaded through ``__init__``, in flatten + # order. + _INIT_TENSOR_ATTRS: Tuple[str, ...] = () + + # Tensor attribute names set on the instance after ``__init__``. + _POST_INIT_TENSOR_ATTRS: Tuple[str, ...] = () + + # Attribute name on ``self`` holding the (optional) ``ProcessGroup``, + # or ``None`` if the quantizer has no PG. + _PG_ATTR: Optional[str] = None + # ``__init__`` kwarg name to thread the PG through. ``None`` means + # set ``_PG_ATTR`` directly after ``__init__``. + _PG_INIT_KWARG: Optional[str] = None + + # Hardcoded ``__init__`` kwargs not derived from meta (e.g. + # ``device=torch.device("cuda")`` for ``Float8CurrentScalingQuantizer``). + _FIXED_INIT_KWARGS: Dict[str, Any] = {} + def _flatten( self, ) -> Tuple[Any, Optional["torch.distributed.ProcessGroup"], List[torch.Tensor]]: @@ -654,23 +688,31 @@ def _flatten( ``(meta, process_group, tensors)`` triplet expected by the flattenable bucket in :mod:`transformer_engine.pytorch.dynamo`. - * ``meta`` -- :class:`OpaqueSimpleMetadata` of all simple state. - Subclasses **must** include their own ``cls.__qualname__`` under - the ``"_qcls"`` key so :meth:`_unflatten` can dispatch back to - ``_do_unflatten`` on the correct subclass. Common base state - (``rowwise_usage``, ``columnwise_usage``, ``internal``, - ``optimize_for_gemm``) is the subclass's responsibility too. - * ``process_group`` -- the (single) :class:`torch.distributed.ProcessGroup` - this quantizer participates in, or ``None``. Quantizers without a - process group return ``None``. - * ``tensors`` -- the live tensor state the op needs to receive - (e.g. ``scale``, ``amax``, RHT matrix). Order is - quantizer-defined and matches what ``_do_unflatten`` expects. + Generic implementation driven by the declarative schema attrs above. + Subclasses only declare which scalars / tensors go through + ``__init__`` vs. are set post-init; the base class round-trips + ``dtype`` / ``rowwise_usage`` / ``columnwise_usage`` and + ``internal`` / ``optimize_for_gemm`` on every quantizer. """ - raise NotImplementedError( - f"{type(self).__name__} class does not implement _flatten; " - "required for torch.compile support of TE custom ops." - ) + from .dynamo import OpaqueSimpleMetadata # pylint: disable=import-outside-toplevel + + cls = type(self) + meta_dict: Dict[str, Any] = { + "_qcls": cls.__qualname__, + "dtype": self.dtype, + "rowwise_usage": self.rowwise_usage, + "columnwise_usage": self.columnwise_usage, + "internal": self.internal, + "optimize_for_gemm": self.optimize_for_gemm, + } + for attr in (*cls._INIT_META_ATTRS, *cls._POST_INIT_META_ATTRS): + meta_dict[attr] = getattr(self, attr) + tensors = [ + getattr(self, attr) + for attr in (*cls._INIT_TENSOR_ATTRS, *cls._POST_INIT_TENSOR_ATTRS) + ] + pg = getattr(self, cls._PG_ATTR) if cls._PG_ATTR else None + return OpaqueSimpleMetadata(meta_dict), pg, tensors @classmethod def _do_unflatten( @@ -680,12 +722,32 @@ def _do_unflatten( tensors: List[torch.Tensor], ) -> "Quantizer": """Reconstruct an instance of ``cls`` from the triplet returned by a - previous :meth:`_flatten` on the same subclass. Subclasses override. + previous :meth:`_flatten` on the same subclass. Generic; driven + by the declarative schema attrs. """ - raise NotImplementedError( - f"{cls.__name__} class does not implement _do_unflatten; " - "required for torch.compile support of TE custom ops." - ) + init_kwargs: Dict[str, Any] = { + cls._DTYPE_INIT_KWARG: meta["dtype"], + "rowwise": meta["rowwise_usage"], + "columnwise": meta["columnwise_usage"], + } + for attr in cls._INIT_META_ATTRS: + init_kwargs[attr] = meta[attr] + if cls._PG_INIT_KWARG is not None: + init_kwargs[cls._PG_INIT_KWARG] = process_group + init_kwargs.update(cls._FIXED_INIT_KWARGS) + tensor_iter = iter(tensors) + for attr in cls._INIT_TENSOR_ATTRS: + init_kwargs[attr] = next(tensor_iter) + q = cls(**init_kwargs) + q.internal = meta["internal"] + q.optimize_for_gemm = meta["optimize_for_gemm"] + for attr in cls._POST_INIT_META_ATTRS: + setattr(q, attr, meta[attr]) + for attr in cls._POST_INIT_TENSOR_ATTRS: + setattr(q, attr, next(tensor_iter)) + if cls._PG_ATTR is not None and cls._PG_INIT_KWARG is None: + setattr(q, cls._PG_ATTR, process_group) + return q @classmethod def _unflatten( diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index d171062c8e..2e18465322 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -38,6 +38,8 @@ class Float8BlockQuantizer(Quantizer): block_scaling_dim: int _storage_cls = Float8BlockwiseQTensorStorage + _INIT_META_ATTRS = ("amax_epsilon", "force_pow_2_scales", "block_scaling_dim") + _POST_INIT_META_ATTRS = ("block_len",) def __init__( self, @@ -289,41 +291,6 @@ def _storage_scalars(self) -> dict: "is_2D_scaled": self.block_scaling_dim == 2, } - def _flatten(self): - from ..dynamo import OpaqueSimpleMetadata - - meta = OpaqueSimpleMetadata( - { - "_qcls": type(self).__qualname__, - "dtype": self.dtype, - "rowwise_usage": self.rowwise_usage, - "columnwise_usage": self.columnwise_usage, - "internal": self.internal, - "optimize_for_gemm": self.optimize_for_gemm, - "block_len": self.block_len, - "amax_epsilon": self.amax_epsilon, - "force_pow_2_scales": self.force_pow_2_scales, - "block_scaling_dim": self.block_scaling_dim, - } - ) - return meta, None, [] - - @classmethod - def _do_unflatten(cls, meta, process_group, tensors): - del process_group, tensors - q = cls( - fp8_dtype=meta["dtype"], - rowwise=meta["rowwise_usage"], - columnwise=meta["columnwise_usage"], - amax_epsilon=meta["amax_epsilon"], - force_pow_2_scales=meta["force_pow_2_scales"], - block_scaling_dim=meta["block_scaling_dim"], - ) - q.block_len = meta["block_len"] - q.internal = meta["internal"] - q.optimize_for_gemm = meta["optimize_for_gemm"] - return q - class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 92d20fbb22..0735aa7b3c 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -122,6 +122,7 @@ class Float8Quantizer(Quantizer): dtype: TE_DType _storage_cls = Float8TensorStorage + _INIT_TENSOR_ATTRS = ("scale", "amax") def __init__( self, @@ -323,36 +324,6 @@ def create_metadata( def _storage_scalars(self) -> Dict[str, Any]: return {"fp8_dtype": self.dtype} - def _flatten(self): - from ..dynamo import OpaqueSimpleMetadata - - meta = OpaqueSimpleMetadata( - { - "_qcls": type(self).__qualname__, - "dtype": self.dtype, - "rowwise_usage": self.rowwise_usage, - "columnwise_usage": self.columnwise_usage, - "internal": self.internal, - "optimize_for_gemm": self.optimize_for_gemm, - } - ) - return meta, None, [self.scale, self.amax] - - @classmethod - def _do_unflatten(cls, meta, process_group, tensors): - del process_group - scale, amax = tensors - q = cls( - scale=scale, - amax=amax, - fp8_dtype=meta["dtype"], - rowwise=meta["rowwise_usage"], - columnwise=meta["columnwise_usage"], - ) - q.internal = meta["internal"] - q.optimize_for_gemm = meta["optimize_for_gemm"] - return q - class Float8CurrentScalingQuantizer(Quantizer): """Builder class for FP8 tensors with per-tensor current scaling @@ -386,6 +357,10 @@ class Float8CurrentScalingQuantizer(Quantizer): amax_epsilon: float _storage_cls = Float8TensorStorage + _INIT_META_ATTRS = ("with_amax_reduction", "force_pow_2_scales", "amax_epsilon") + _PG_ATTR = "amax_reduction_group" + _PG_INIT_KWARG = "amax_reduction_group" + _FIXED_INIT_KWARGS = {"device": torch.device("cuda")} def __init__( self, @@ -620,41 +595,6 @@ def create_metadata( def _storage_scalars(self) -> Dict[str, Any]: return {"fp8_dtype": self.dtype} - def _flatten(self): - from ..dynamo import OpaqueSimpleMetadata - - meta = OpaqueSimpleMetadata( - { - "_qcls": type(self).__qualname__, - "dtype": self.dtype, - "rowwise_usage": self.rowwise_usage, - "columnwise_usage": self.columnwise_usage, - "internal": self.internal, - "optimize_for_gemm": self.optimize_for_gemm, - "with_amax_reduction": self.with_amax_reduction, - "force_pow_2_scales": self.force_pow_2_scales, - "amax_epsilon": self.amax_epsilon, - } - ) - return meta, self.amax_reduction_group, [] - - @classmethod - def _do_unflatten(cls, meta, process_group, tensors): - del tensors - q = cls( - fp8_dtype=meta["dtype"], - device=torch.device("cuda"), - rowwise=meta["rowwise_usage"], - columnwise=meta["columnwise_usage"], - with_amax_reduction=meta["with_amax_reduction"], - amax_reduction_group=process_group, - force_pow_2_scales=meta["force_pow_2_scales"], - amax_epsilon=meta["amax_epsilon"], - ) - q.internal = meta["internal"] - q.optimize_for_gemm = meta["optimize_for_gemm"] - return q - class Float8Tensor(Float8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 83c800dfdf..5a17e69a73 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -263,33 +263,6 @@ def _storage_scalars(self) -> Dict[str, Any]: "with_gemm_swizzled_scales": self.optimize_for_gemm, } - def _flatten(self): - from ..dynamo import OpaqueSimpleMetadata - - meta = OpaqueSimpleMetadata( - { - "_qcls": type(self).__qualname__, - "dtype": self.dtype, - "rowwise_usage": self.rowwise_usage, - "columnwise_usage": self.columnwise_usage, - "internal": self.internal, - "optimize_for_gemm": self.optimize_for_gemm, - } - ) - return meta, None, [] - - @classmethod - def _do_unflatten(cls, meta, process_group, tensors): - del process_group, tensors - q = cls( - fp8_dtype=meta["dtype"], - rowwise=meta["rowwise_usage"], - columnwise=meta["columnwise_usage"], - ) - q.internal = meta["internal"] - q.optimize_for_gemm = meta["optimize_for_gemm"] - return q - class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 10d45d4561..e8469b06d9 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -136,6 +136,19 @@ class NVFP4Quantizer(Quantizer): rht_matrix: torch.Tensor _storage_cls = NVFP4TensorStorage + _DTYPE_INIT_KWARG = "fp4_dtype" + _INIT_META_ATTRS = ( + "with_amax_reduction", + "with_rht", + "with_post_rht_amax", + "with_2d_quantization", + "stochastic_rounding", + "row_scaled_nvfp4", + ) + _POST_INIT_META_ATTRS = ("rht_matrix_random_sign_mask_t",) + _POST_INIT_TENSOR_ATTRS = ("rht_matrix",) + _PG_ATTR = "amax_reduction_group" + _PG_INIT_KWARG = "amax_reduction_group" def __init__( self, @@ -424,52 +437,6 @@ def _storage_scalars(self) -> Dict[str, Any]: "row_scaled_nvfp4": self.row_scaled_nvfp4, } - def _flatten(self): - from ..dynamo import OpaqueSimpleMetadata - - meta = OpaqueSimpleMetadata( - { - "_qcls": type(self).__qualname__, - "dtype": self.dtype, - "rowwise_usage": self.rowwise_usage, - "columnwise_usage": self.columnwise_usage, - "internal": self.internal, - "optimize_for_gemm": self.optimize_for_gemm, - "with_rht": self.with_rht, - "with_post_rht_amax": self.with_post_rht_amax, - "with_amax_reduction": self.with_amax_reduction, - "with_2d_quantization": self.with_2d_quantization, - "stochastic_rounding": self.stochastic_rounding, - "row_scaled_nvfp4": self.row_scaled_nvfp4, - "rht_matrix_random_sign_mask_t": self.rht_matrix_random_sign_mask_t, - } - ) - return meta, self.amax_reduction_group, [self.rht_matrix] - - @classmethod - def _do_unflatten(cls, meta, process_group, tensors): - (rht_matrix,) = tensors - # Construct with default RHT mask, then overwrite the computed - # ``rht_matrix_random_sign_mask_t`` / ``rht_matrix`` with the - # restored values so we don't depend on cuda helpers / device state. - q = cls( - fp4_dtype=meta["dtype"], - rowwise=meta["rowwise_usage"], - columnwise=meta["columnwise_usage"], - with_amax_reduction=meta["with_amax_reduction"], - amax_reduction_group=process_group, - with_rht=meta["with_rht"], - with_post_rht_amax=meta["with_post_rht_amax"], - with_2d_quantization=meta["with_2d_quantization"], - stochastic_rounding=meta["stochastic_rounding"], - row_scaled_nvfp4=meta["row_scaled_nvfp4"], - ) - q.rht_matrix_random_sign_mask_t = meta["rht_matrix_random_sign_mask_t"] - q.rht_matrix = rht_matrix - q.internal = meta["internal"] - q.optimize_for_gemm = meta["optimize_for_gemm"] - return q - class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): """Quantized tensor class with FP4 data From e25913c5c775c1143855d5f76ba93340efe03c12 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 17:58:47 +0200 Subject: [PATCH 12/16] [PyTorch] Drop dead code and stale comments from torch.compile branch - Remove `fake_cast_if_needed` (utils.py) and `fake_quantize_weight` (module/base.py): leftover from the hand-written fake-impl design that output_info_fn / TensorSpec.alloc replaced; no remaining call sites. - Remove unused ``KIND`` class attributes from TensorSpec subclasses in dynamo.py; never read after the spec-dispatch refactor. - Fix stale comment references to nonexistent symbols (`_rewrite_subclass_to_storage`, `_generic_tensor_flatten`). - Trim verbose narration about "replaced hand-written fake-impls" in linear.py docstrings; the current behavior is described, history belongs in git. - Drop redundant "generic implementations on base" comments from MXFP8 and Float8Block storage; the ``_FLATTEN_*`` declarations speak for themselves. Keep the genuine `_transpose_invalid` rationale in float8_tensor_storage.py. - Fix copy-paste comment in nvfp4_tensor_storage.py (MXFP8 -> NVFP4). Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo.py | 12 ---- transformer_engine/pytorch/module/base.py | 59 ------------------- transformer_engine/pytorch/module/linear.py | 24 ++++---- .../pytorch/quantized_tensor.py | 4 +- .../pytorch/tensor/float8_tensor.py | 20 +++---- .../float8_blockwise_tensor_storage.py | 4 -- .../tensor/storage/float8_tensor_storage.py | 15 ++--- .../tensor/storage/mxfp8_tensor_storage.py | 4 -- .../tensor/storage/nvfp4_tensor_storage.py | 2 +- transformer_engine/pytorch/utils.py | 15 ----- 10 files changed, 29 insertions(+), 130 deletions(-) diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index 2b2d440a40..87b8224efc 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -132,8 +132,6 @@ class TensorSpec: each method plays in the forward / fake / setup-context pipelines. """ - KIND: str = "" - def slot_count(self) -> int: raise NotImplementedError( f"{type(self).__name__}.slot_count() not implemented" @@ -168,8 +166,6 @@ class NoneSpec(TensorSpec): end-to-end. """ - KIND = "none" - def slot_count(self) -> int: return 1 @@ -194,8 +190,6 @@ class AliasedSpec(TensorSpec): ``ctx_attrs["saved_tensor_aliases"]``. """ - KIND = "aliased" - def __init__(self, alias: str) -> None: self.alias = alias @@ -216,8 +210,6 @@ class PlainTensorSpec(TensorSpec): is just the lone slot value. """ - KIND = "plain" - def __init__( self, shape: Optional[Sequence[int]] = None, @@ -261,8 +253,6 @@ class SubclassTensorSpec(TensorSpec): undefined. """ - KIND = "subclass" - def __init__( self, *, @@ -368,8 +358,6 @@ class StorageSpec(TensorSpec): ``alloc_quantizer.make_empty(shape, ...)``. """ - KIND = "storage" - def __init__( self, cls: type, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 58f42781e0..c4691aa645 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -763,65 +763,6 @@ def quantize_weight( return out, None -def fake_quantize_weight( - *, - tensor: Optional[torch.Tensor] = None, - quantizer: Optional[Quantizer] = None, - workspace: Optional[QuantizedTensorStorage] = None, - fsdp_group: Optional["dist_group_type"] = None, - workspace_dtype: Optional[torch.dtype] = None, - cache: bool = False, -) -> Tuple[QuantizedTensorStorage, Optional[QuantizedTensorStorage]]: - """Fake counterpart of :func:`quantize_weight` for shape inference. - - Mirrors the cache-hit / cache-miss control flow of :func:`quantize_weight` - but never performs an actual quantization. Cache misses are filled with - ``quantizer.make_empty``. Used by torch custom-op fake registrations. - """ - - # Already-quantized weight (primary FP8 parameters, both the - # ``Float8Tensor``-style subclass wrappers and the bare - # ``Float8TensorStorage``-style flat carriers produced by the - # outer-op torch_dispatch rule on the way into the inner op). - if isinstance(tensor, QuantizedTensorStorage): - if quantizer is not None: - update_rowwise = True if quantizer.rowwise_usage else None - update_columnwise = True if quantizer.columnwise_usage else None - tensor.update_usage( - rowwise_usage=update_rowwise, - columnwise_usage=update_columnwise, - ) - return tensor, None - - # Validate workspace - if workspace is not None and quantizer is not None: - if not _is_weight_workspace_valid(workspace, quantizer): - workspace = None - - if workspace is not None and fsdp_group is not None: - raise NotImplementedError( - "fake_quantize_weight does not support FSDP weight workspaces" - ) - - # Cache hit - if workspace is not None: - return workspace, None - - # Cache miss — create new (fake) workspace - if tensor is None or quantizer is None: - raise ValueError( - "tensor and quantizer kwargs must be provided to construct FP8 workspace" - ) - out = quantizer.make_empty( - tensor.shape, - dtype=workspace_dtype, - device=tensor.device, - ) - if cache: - return out, out - return out, None - - class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b00bc4bb03..7ca8ba00d2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1291,16 +1291,14 @@ def _linear_backward_output_info( Returns a list of three :class:`TensorSpec` -- one per gradient output ``(wgrad, dgrad, grad_bias)`` -- consumed by the auto-synthesized backward fake-impl in - :func:`_make_fake_impl_from_bwd_output_info`. Replaces the - previously hand-written ``_linear_backward_fake_impl``: gradient - shapes / dtypes are deterministic, so the descriptor just encodes - each slot through :func:`tensor_spec` (passing ``shape=None`` for - absent grads and a ``quantizer`` for quantized ones -- backward - grads use alloc-only ``SubclassTensorSpec`` because they go - straight to autograd, never through the op's flat ``Tensor[]``). - ``set_usage`` on ``grad_input_quantizer`` is preserved because it - influences ``dgrad``'s downstream ``make_empty``. Manual TE FSDP - is unsupported; FSDP2 / MCore FSDP go through the standard path. + :func:`_make_fake_impl_from_bwd_output_info`. Each slot is encoded + through :func:`tensor_spec` (``shape=None`` for absent grads, + ``quantizer`` for quantized ones -- backward grads use alloc-only + ``SubclassTensorSpec`` because they go straight to autograd and + never through the op's flat ``Tensor[]``). ``set_usage`` on + ``grad_input_quantizer`` is preserved because it influences + ``dgrad``'s downstream ``make_empty``. Manual TE FSDP is + unsupported; FSDP2 / MCore FSDP go through the standard path. """ if args.fsdp_group is not None: @@ -1353,9 +1351,9 @@ def _linear_forward_output_info( """Output-layout descriptor for the linear forward. Returns ``(user_specs, saved_slots, ctx_attrs)`` -- Dynamo-traceable - layout + alloc info for the op's outputs and saved tensors. Replaces - a hand-written fake-impl: :func:`_te_register_custom_op` synthesizes - one by calling :meth:`TensorSpec.alloc` on each entry. + layout + alloc info for the op's outputs and saved tensors. + :func:`_te_register_custom_op` synthesizes the fake-impl by calling + :meth:`TensorSpec.alloc` on each entry. All ``set_usage`` side effects on the live quantizers happen here and are observed by both the real fwd impl and backward. diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 1e6cc9ce8c..afb65f6256 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -296,8 +296,8 @@ def _torch_compile_flatten( tensors: List[torch.Tensor] = [] meta_dict: Dict[str, Any] = {"_qstorage_cls": type(self).__qualname__} # Tensor-wrapper fields are only relevant when ``self`` is a live - # ``torch.Tensor`` (e.g. ``Float8Tensor`` rewritten in-place to a - # storage payload by ``_rewrite_subclass_to_storage``); a bare + # ``torch.Tensor`` (e.g. ``Float8Tensor`` flattened directly into a + # storage payload by ``_flatten_subclass_into_slots``); a bare # storage shell has no outer shape / requires_grad / device. if isinstance(self, torch.Tensor): meta_dict.update( diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 0735aa7b3c..5ef6a23955 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -43,20 +43,19 @@ # --------------------------------------------------------------------------- # -# torch.compile output-layout metadata helpers +# torch.compile output-layout metadata helper # --------------------------------------------------------------------------- # # -# These helpers produce the static (inner-names + meta-dict) and -# storage-meta layouts that the dynamo integration layer needs to -# reassemble a :class:`Float8Tensor` / :class:`Float8TensorStorage` -# from the flat ``Tensor[]`` return of a TE custom op, without -# allocating a fake prototype tensor inside a traced region. +# Produces the static (inner-names + meta-dict) layout that the dynamo +# integration layer needs to reassemble a :class:`Float8Tensor` from +# the flat ``Tensor[]`` return of a TE custom op, without allocating a +# fake prototype tensor inside a traced region. # # Shared between :class:`Float8Quantizer` and # :class:`Float8CurrentScalingQuantizer` because both produce identical -# ``Float8Tensor`` / ``Float8TensorStorage`` layouts (rowwise / columnwise / -# scale-inv inner tensors); the per-quantizer ``create_metadata`` / -# ``create_storage_metadata`` methods delegate here. +# ``Float8Tensor`` layouts (rowwise / columnwise / scale-inv inner +# tensors); the per-quantizer ``create_metadata`` methods delegate +# here. def _float8_create_subclass_metadata( @@ -70,8 +69,7 @@ def _float8_create_subclass_metadata( ``inner_names`` reflects the rowwise / columnwise usage flags of the quantizer (``_data`` and/or ``_transpose``, plus always ``_scale_inv``). ``meta`` carries the static, Dynamo-friendly attributes - :class:`Float8Tensor`'s constructor needs (matching the schema produced - by :meth:`Float8Tensor._generic_tensor_flatten`): + :class:`Float8Tensor`'s constructor needs: * ``fp8_dtype`` -- :class:`FP8DType` (an :class:`IntEnum`, proxies as a constant for Dynamo; bridges back to ``tex.DType`` diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 0bd1a3d555..3733a9d32e 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -170,10 +170,6 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - # ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` are - # the generic implementations on :class:`QuantizedTensorStorage`, - # driven by the ``_FLATTEN_*`` declarations above. - def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" if rowwise_data and columnwise_data: diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 966bb43936..1d712abea1 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -277,15 +277,12 @@ def __repr__(self): ")" ) - # ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` are - # the generic implementations on :class:`QuantizedTensorStorage`, - # driven by the ``_FLATTEN_*`` declarations above. ``__new__`` - # re-derives ``_transpose_invalid`` from the restored ``_transpose`` - # buffer, so we deliberately do not round-trip the flag through - # ``_FLATTEN_META_ATTRS``: a producer that ships a transpose through - # the trace had it valid, and trusting a stale ``True`` from a - # Dynamo-embedded meta constant would trip - # :meth:`update_usage`'s ``not has_data_transpose`` guard in backward. + # ``__new__`` re-derives ``_transpose_invalid`` from the restored + # ``_transpose`` buffer, so the flag is deliberately not round-tripped + # through ``_FLATTEN_META_ATTRS``: a producer that ships a transpose + # through the trace had it valid, and trusting a stale ``True`` from + # a Dynamo-embedded meta constant would trip :meth:`update_usage`'s + # ``not has_data_transpose`` guard in backward. def _create_transpose(self): """Update FP8 transpose cache""" diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index f827695294..d4076b63d9 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -209,10 +209,6 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - # ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` are - # the generic implementations on :class:`QuantizedTensorStorage`, - # driven by the ``_FLATTEN_*`` declarations above. - def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" if rowwise_data and columnwise_data: diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 53b5649956..55cb0c0991 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -100,7 +100,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage): # column-scaled FP4 data) _amax_columnwise: torch.Tensor - # Builder class for casting to MXFP8 + # Builder class for casting to NVFP4 _quantizer: Optional[Quantizer] # FP4 data type _fp4_dtype: TE_DType diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 76d204deb0..250daec67f 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -502,21 +502,6 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: return tensor.to(dtype=dtype) -def fake_cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - """Fake counterpart of :func:`cast_if_needed` for shape inference. - - Returns the same tensor if no cast would happen, otherwise an empty - tensor of the requested dtype with matching shape and device. Used by - torch custom-op fake registrations so the FX graph can reason about - output shapes without actually performing the cast. - """ - if tensor is None: - return None - if tensor.dtype == dtype: - return tensor - return torch.empty_like(tensor, dtype=dtype) - - def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: """Check if tensor dimensions are supported for FP8 TN GEMM""" return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0 From 844f5c2e1f56184094d876c20cb1c2b430db9bb1 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 18:09:26 +0200 Subject: [PATCH 13/16] [PyTorch] Generic Quantizer.create_metadata via declarative schema Move the (inner_names, meta) builder for SubclassTensorSpec onto the base ``Quantizer``, driven by the storage class's ``_FLATTEN_TENSOR_ATTRS`` / ``_FLATTEN_TENSOR_USAGE`` plus the quantizer's existing ``_storage_scalars()`` hook. ``inner_names`` now follows declaration order so it stays in sync with ``__tensor_flatten__`` (used by ``_flatten_value_into`` when an FP8 output is serialized back into the op's flat Tensor[] payload). The pybind ``tex.DType`` -> Python ``FP8DType`` conversion (needed to keep the meta dict Dynamo-traceable as a constant on ``_ToSubclassFn``) is centralised behind a single class attribute ``_SUBCLASS_META_TEX_KEYS`` that defaults to ``("fp8_dtype",)``. Drops the standalone ``_float8_create_subclass_metadata`` helper and the two duplicate ``Float8*Quantizer.create_metadata`` overrides; they reduce to the generic base behavior plus the existing ``_storage_scalars`` declaration. Signed-off-by: Pawel Gadzinski --- .../pytorch/quantized_tensor.py | 55 ++++++++++++ .../pytorch/tensor/float8_tensor.py | 89 +------------------ 2 files changed, 57 insertions(+), 87 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index afb65f6256..65faa62dcc 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -14,6 +14,7 @@ from torch.utils._pytree import tree_map from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.fp8_dtype import from_tex from transformer_engine.pytorch.tensor._quantization_helpers import ( _QuantizeFunc, _IdentityFunc, @@ -783,6 +784,60 @@ def _storage_scalars(self) -> Dict[str, Any]: "QuantizedTensorStorage." ) + # Scalar keys in :meth:`_storage_scalars` whose values are pybind + # enums (currently ``transformer_engine_torch.DType``) and must be + # converted to a Dynamo-traceable Python proxy + # (:class:`FP8DType`) before being embedded in the subclass-spec + # ``meta`` dict. The reverse conversion happens in the tensor + # subclass's :meth:`_flatten_meta_overrides`. + _SUBCLASS_META_TEX_KEYS: Tuple[str, ...] = ("fp8_dtype",) + + def create_metadata( + self, + *, + fake_dtype: torch.dtype, + requires_grad: bool = False, + ) -> Tuple[Tuple[str, ...], Dict[str, Any]]: + """Return ``(inner_names, meta)`` for :meth:`QuantizedTensor.__tensor_unflatten__`. + + Generic implementation driven by the storage class's + :attr:`QuantizedTensorStorage._FLATTEN_TENSOR_ATTRS` / + :attr:`QuantizedTensorStorage._FLATTEN_TENSOR_USAGE` plus this + quantizer's :meth:`_storage_scalars`. ``inner_names`` follows the + declaration order of ``_FLATTEN_TENSOR_ATTRS`` so it matches the + slot order produced by :meth:`QuantizedTensor.__tensor_flatten__` + in :func:`dynamo._flatten_value_into`. + + ``quantizer_snapshot`` is forced to ``None`` on this path: + rebuilding a live :class:`Quantizer` inside + ``__tensor_unflatten__`` would force Dynamo to trace the + constructor, which routinely trips + ``UserDefinedObjectVariable(...Quantizer)``. Code that needs + the live quantizer sources it from outside the compiled region. + """ + storage_cls = type(self)._storage_cls + usage_flag = { + "rowwise": self.rowwise_usage, + "columnwise": self.columnwise_usage, + "always": True, + } + inner_names = tuple( + attr + for attr in storage_cls._FLATTEN_TENSOR_ATTRS + if usage_flag[storage_cls._FLATTEN_TENSOR_USAGE.get(attr, "always")] + ) + scalars = self._storage_scalars() + for key in self._SUBCLASS_META_TEX_KEYS: + if key in scalars: + scalars[key] = from_tex(scalars[key]) + meta: Dict[str, Any] = { + **scalars, + "fake_dtype": fake_dtype, + "quantizer_snapshot": None, + "requires_grad": requires_grad, + } + return inner_names, meta + def create_storage_metadata( self, *, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 5ef6a23955..de1040e7c2 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -24,7 +24,7 @@ ) from ._quantization_helpers import _IdentityFunc from ..constants import canonicalize_te_dtype, dist_group_type -from ..fp8_dtype import FP8DType, from_tex, to_tex +from ..fp8_dtype import FP8DType, to_tex aten = torch.ops.aten @@ -42,65 +42,6 @@ } -# --------------------------------------------------------------------------- # -# torch.compile output-layout metadata helper -# --------------------------------------------------------------------------- # -# -# Produces the static (inner-names + meta-dict) layout that the dynamo -# integration layer needs to reassemble a :class:`Float8Tensor` from -# the flat ``Tensor[]`` return of a TE custom op, without allocating a -# fake prototype tensor inside a traced region. -# -# Shared between :class:`Float8Quantizer` and -# :class:`Float8CurrentScalingQuantizer` because both produce identical -# ``Float8Tensor`` layouts (rowwise / columnwise / scale-inv inner -# tensors); the per-quantizer ``create_metadata`` methods delegate -# here. - - -def _float8_create_subclass_metadata( - quantizer: "Quantizer", - *, - fake_dtype: torch.dtype, - requires_grad: bool = False, -) -> Tuple[Tuple[str, ...], dict]: - """Return ``(inner_names, meta)`` for :meth:`Float8Tensor.__tensor_unflatten__`. - - ``inner_names`` reflects the rowwise / columnwise usage flags of the - quantizer (``_data`` and/or ``_transpose``, plus always ``_scale_inv``). - ``meta`` carries the static, Dynamo-friendly attributes - :class:`Float8Tensor`'s constructor needs: - - * ``fp8_dtype`` -- :class:`FP8DType` (an :class:`IntEnum`, - proxies as a constant for Dynamo; bridges back to ``tex.DType`` - via :meth:`Float8Tensor._flatten_meta_overrides` inside - ``__tensor_unflatten__``). - * ``fake_dtype`` -- caller-supplied torch dtype. - * ``quantizer_snapshot`` -- always ``None`` on this path. Re-using - the snapshot reconstruction (which builds a fresh quantizer - inside :meth:`Float8Tensor.__tensor_unflatten__`) would force - Dynamo to trace a quantizer constructor call, which routinely - trips ``UserDefinedObjectVariable(Float8...Quantizer)``. - ``quantizer=None`` keeps the wrapper construction within Dynamo's - proxyable surface; user code that needs the live quantizer - sources it from outside the compiled region. - * ``requires_grad`` -- caller-supplied flag. - """ - inner_names: List[str] = [] - if quantizer.rowwise_usage: - inner_names.append("_data") - inner_names.append("_scale_inv") - if quantizer.columnwise_usage: - inner_names.append("_transpose") - meta = { - "fp8_dtype": from_tex(quantizer.dtype), - "fake_dtype": fake_dtype, - "quantizer_snapshot": None, - "requires_grad": requires_grad, - } - return tuple(inner_names), meta - - class Float8Quantizer(Quantizer): """Builder class for FP8 tensors with per-tensor delayed scaling @@ -306,19 +247,6 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True - def create_metadata( - self, - *, - fake_dtype: torch.dtype, - requires_grad: bool = False, - ) -> Tuple[Tuple[str, ...], dict]: - # pylint: disable=missing-function-docstring - return _float8_create_subclass_metadata( - self, - fake_dtype=fake_dtype, - requires_grad=requires_grad, - ) - def _storage_scalars(self) -> Dict[str, Any]: return {"fp8_dtype": self.dtype} @@ -577,19 +505,6 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True - def create_metadata( - self, - *, - fake_dtype: torch.dtype, - requires_grad: bool = False, - ) -> Tuple[Tuple[str, ...], dict]: - # pylint: disable=missing-function-docstring - return _float8_create_subclass_metadata( - self, - fake_dtype=fake_dtype, - requires_grad=requires_grad, - ) - def _storage_scalars(self) -> Dict[str, Any]: return {"fp8_dtype": self.dtype} @@ -646,7 +561,7 @@ def __repr__(self, *, tensor_contents=None): @classmethod def _flatten_meta_overrides(cls, meta: dict) -> dict: """Bridge :class:`FP8DType` (carried by the subclass output spec - via :func:`_float8_create_subclass_metadata`) back to the native + via :meth:`Quantizer.create_metadata`) back to the native ``tex.DType`` accepted by pybind-bound TE kernels. The eager :meth:`__tensor_flatten__` path stores ``tex.DType`` directly and is a no-op here. From 14381afe581d301835306fa53fe7864336df5c2e Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 18:15:53 +0200 Subject: [PATCH 14/16] [PyTorch] Drop Recipe._flatten/_unflatten torch.compile protocol ``LinearBwdArgs.fp8_recipe: Optional[Recipe]`` was shipped through the compiled custom op only to read two booleans on the backward side -- ``recipe.fp8_gemm_dgrad.use_split_accumulator`` and ``recipe.fp8_gemm_wgrad.use_split_accumulator``. Replace the Recipe-typed field with those two booleans, populated in setup-ctx the same way the forward already computes ``use_split_accumulator`` from the live recipe. With nothing left to flatten through the bucket scanner: - Drop ``Recipe._flatten`` / ``Recipe._unflatten``, the ``_RECIPE_REGISTRY`` dispatch dict and the explicit ``_register_recipe_subclass`` calls at the bottom of ``common/recipe/__init__.py`` (~90 lines). - Drop ``_recipe_cls()`` / ``_RECIPE_REF`` and remove ``Recipe`` from ``_flattenable_bases`` in ``pytorch/dynamo.py``. - Tighten the now-quantizer-only docstrings on ``_FlattenableBucket`` / ``_MetaPGTensorsBucket``. Signed-off-by: Pawel Gadzinski --- transformer_engine/common/recipe/__init__.py | 92 -------------------- transformer_engine/pytorch/dynamo.py | 22 +---- transformer_engine/pytorch/module/linear.py | 24 +++-- 3 files changed, 14 insertions(+), 124 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 57d7e3965a..97bea190ea 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -166,88 +166,6 @@ def custom(cls): """Whether the given recipe is custom.""" return issubclass(cls, CustomRecipe) - # ------------------------------------------------------------------ # - # torch.compile flatten / unflatten protocol - # ------------------------------------------------------------------ # - # The flattenable bucket in - # :mod:`transformer_engine.pytorch.dynamo` ships ``Recipe`` instances - # through TE custom ops by calling :meth:`_flatten` (instance method - # on each concrete subclass) and :meth:`_unflatten` (classmethod on - # this base, which dispatches by ``_rcls`` stamped into the - # metadata bundle). The default implementation reads - # :func:`dataclasses.fields` and flattens nested ``@dataclass`` - # fields with ``"."`` keys; reconstruction - # instantiates the target class with default args and writes the - # flattened values back. Subclasses can override either method when - # their structure is too irregular for the generic round-trip. - - def _flatten(self): # noqa: D401 -- short name preferred - """Return ``(OpaqueSimpleMetadata, None, [])``.""" - # Lazy imports keep ``common`` independent of pytorch. - from dataclasses import fields, is_dataclass - from transformer_engine.pytorch.dynamo import OpaqueSimpleMetadata - - payload: dict = {"_rcls": type(self).__qualname__} - for f in fields(self): - v = getattr(self, f.name) - if is_dataclass(v) and not isinstance(v, type): - for sf in fields(v): - payload[f"{f.name}.{sf.name}"] = getattr(v, sf.name) - else: - payload[f.name] = v - return OpaqueSimpleMetadata(payload), None, [] - - @classmethod - def _unflatten(cls, meta, _ref, _tensors): - """Dispatch to the concrete subclass identified by - ``meta['_rcls']`` and rehydrate fields (including nested - ``@dataclass`` fields written under ``"."`` - keys by :meth:`_flatten`).""" - from dataclasses import fields, is_dataclass - - target_name = meta["_rcls"] - target_cls = _RECIPE_REGISTRY.get(target_name) - if target_cls is None: - raise KeyError( - f"Unknown recipe class {target_name!r} during unflatten; " - "is the subclass imported in transformer_engine.common.recipe?" - ) - - out = target_cls() - nested: dict = {} - for k, v in meta.items(): - if k == "_rcls": - continue - if "." in k: - parent, child = k.split(".", 1) - nested.setdefault(parent, {})[child] = v - else: - setattr(out, k, v) - for parent, children in nested.items(): - target = getattr(out, parent, None) - if target is None or not is_dataclass(target): - continue - # Nested dataclasses (e.g. ``MMParams``) may be frozen, so - # rebuild the instance with merged kwargs and reassign. - cur_kwargs = {f.name: getattr(target, f.name) for f in fields(target)} - cur_kwargs.update(children) - setattr(out, parent, type(target)(**cur_kwargs)) - return out - - -# Lazily populated by :meth:`Recipe.__init_subclass__` so that -# :meth:`Recipe._unflatten` can dispatch by ``__qualname__``. -_RECIPE_REGISTRY: dict = {} - - -def _register_recipe_subclass(cls) -> None: - _RECIPE_REGISTRY[cls.__qualname__] = cls - - -# Recipe uses pydantic.dataclasses which can interfere with hooking -# ``__init_subclass__``; register subclasses explicitly at the bottom of -# this module instead. - @dataclass(repr=False) class DelayedScaling(Recipe): @@ -738,13 +656,3 @@ def _make_repr(self) -> str: ) -# Populate the dispatch registry consumed by :meth:`Recipe._unflatten`. -for _R in ( - DelayedScaling, - Float8CurrentScaling, - MXFP8BlockScaling, - Float8BlockScaling, - NVFP4BlockScaling, - CustomRecipe, -): - _register_recipe_subclass(_R) diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index 87b8224efc..558db92f68 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -918,8 +918,8 @@ class _MetaPGTensorsBucket(_Bucket): Used by every field whose value must be carried as the triple ``(OpaqueSimpleMetadata, ProcessGroup?, Tensor[])`` -- today this covers ``Tensor | QuantizedTensorStorage`` unions (see - :class:`_UniversalTensorBucket`) and ``Quantizer`` / ``Recipe`` - instances (see :class:`_FlattenableBucket`). Concrete subclasses + :class:`_UniversalTensorBucket`) and ``Quantizer`` instances + (see :class:`_FlattenableBucket`). Concrete subclasses implement :meth:`_pack_value` / :meth:`_unpack_value` for their flatten/unflatten protocol; the rest of the bucket contract is identical and lives here. @@ -1209,7 +1209,6 @@ def unpack(self, args: Dict[str, Any], kwargs: Dict[str, Any]) -> None: # different dataclass registrations. _QTS_REF: Optional[type] = None _QUANTIZER_REF: Optional[type] = None -_RECIPE_REF: Optional[type] = None def _quantized_tensor_storage_cls() -> Optional[type]: @@ -1240,19 +1239,6 @@ def _quantizer_cls() -> Optional[type]: return _QUANTIZER_REF -def _recipe_cls() -> Optional[type]: - """Lazy-resolve :class:`Recipe`; ``None`` if unavailable.""" - global _RECIPE_REF - if _RECIPE_REF is None: - try: - from transformer_engine.common.recipe import Recipe - - _RECIPE_REF = Recipe - except Exception: # pragma: no cover - partial init - return None - return _RECIPE_REF - - def _flattenable_bases() -> Tuple[type, ...]: """Return the list of base classes whose subclasses are routed through :class:`_FlattenableBucket`. @@ -1265,7 +1251,7 @@ def _flattenable_bases() -> Tuple[type, ...]: """ return tuple( cls - for cls in (_quantizer_cls(), _quantized_tensor_storage_cls(), _recipe_cls()) + for cls in (_quantizer_cls(), _quantized_tensor_storage_cls()) if cls is not None ) @@ -1274,7 +1260,7 @@ class _FlattenableBucket(_MetaPGTensorsBucket): """Field whose type implements the ``_flatten`` / ``_unflatten`` protocol (see :func:`_flattenable_bases`). Used today for :class:`~transformer_engine.pytorch.quantized_tensor.Quantizer` and - :class:`~transformer_engine.common.recipe.Recipe`. + :class:`~transformer_engine.pytorch.quantized_tensor.QuantizedTensorStorage`. """ # Stored under ``_qcls`` in the metadata bundle to encode ``None`` diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7ca8ba00d2..c72a21c8a9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -189,7 +189,8 @@ class LinearBwdArgs: # --- Numerical / dtype config --- activation_dtype: Optional[torch.dtype] = None fp8: bool = False - fp8_recipe: Optional[Recipe] = None + use_split_accumulator_dgrad: bool = _2X_ACC_DGRAD + use_split_accumulator_wgrad: bool = _2X_ACC_WGRAD backward_override: Optional[str] = None is_weight_param_quantized: bool = False custom: bool = False @@ -672,7 +673,12 @@ def _linear_setup_ctx( # Numerical / dtype config bwd_args.activation_dtype = fwd_args.activation_dtype bwd_args.fp8 = fp8 - bwd_args.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + if fp8: + _bwd_recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(_bwd_recipe, "fp8_gemm_dgrad"): + bwd_args.use_split_accumulator_dgrad = _bwd_recipe.fp8_gemm_dgrad.use_split_accumulator + if hasattr(_bwd_recipe, "fp8_gemm_wgrad"): + bwd_args.use_split_accumulator_wgrad = _bwd_recipe.fp8_gemm_wgrad.use_split_accumulator bwd_args.backward_override = backward_override bwd_args.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) bwd_args.custom = fwd_args.custom @@ -973,12 +979,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. ): weight_fp8.update_usage(columnwise_usage=True) - # Choose whether to use GEMM kernel with split accumulator - use_split_accumulator = _2X_ACC_DGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_dgrad"): - use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + use_split_accumulator = bwd_args.use_split_accumulator_dgrad # Update grad input quantizer if grad_input_quantizer is not None: @@ -1121,12 +1122,7 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. grad_output_quantizer.set_usage(rowwise=False, columnwise=True) grad_output = grad_output_quantizer(grad_output) - # Figure out whether to use split accumulator - use_split_accumulator = _2X_ACC_WGRAD - if bwd_args.fp8: - recipe = bwd_args.fp8_recipe - if hasattr(recipe, "fp8_gemm_wgrad"): - use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator + use_split_accumulator = bwd_args.use_split_accumulator_wgrad # Figure out whether to output wgrad GEMM directly into main grad if bwd_args.is_first_microbatch is not None: From edf11aa9b3dd4cf61ff38d995c1d7f3f538da74f Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 May 2026 18:30:13 +0200 Subject: [PATCH 15/16] [PyTorch] Consolidate tex.DType opaque-type registration in fp8_dtype.py The four storage modules each carried an identical try/except block that injected ``__fx_repr__`` on ``transformer_engine_torch.DType`` and registered it as a torch.compile value-opaque type. Move the registration to its natural home in ``fp8_dtype.py`` -- right next to ``to_tex`` / ``from_tex`` -- and load it once from ``quantized_tensor.py`` (which every storage / quantizer subclass already pulls in). Same behavior, one place instead of four. Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/fp8_dtype.py | 26 +++++++++++++++++++ .../pytorch/quantized_tensor.py | 19 ++++++++------ .../float8_blockwise_tensor_storage.py | 10 ------- .../tensor/storage/float8_tensor_storage.py | 10 ------- .../tensor/storage/mxfp8_tensor_storage.py | 10 ------- .../tensor/storage/nvfp4_tensor_storage.py | 10 ------- 6 files changed, 37 insertions(+), 48 deletions(-) diff --git a/transformer_engine/pytorch/fp8_dtype.py b/transformer_engine/pytorch/fp8_dtype.py index 2a0fc0cfda..88b2b5c0c1 100644 --- a/transformer_engine/pytorch/fp8_dtype.py +++ b/transformer_engine/pytorch/fp8_dtype.py @@ -65,3 +65,29 @@ def from_tex(d: tex.DType) -> FP8DType: if isinstance(d, tex.DType): return _TEX_TO_FP8DTYPE_BY_TEX[d] return _TEX_TO_FP8DTYPE[int(d)] + + +# Register ``tex.DType`` as a torch.compile value-opaque type so it +# can flow through Dynamo as a constant inside ``__tensor_flatten__`` +# meta dicts and other traced metadata payloads. Without this, +# Dynamo trips on ``UserDefinedObjectVariable(DType)`` because the +# pybind11 enum carries a custom ``__hash__``. ``__fx_repr__`` is +# injected once here so the FX codegen can serialize literal values +# as ``TE_DType()``. Gated by a try/except so importing this +# module remains safe on older PyTorch versions that lack the +# private ``opaque_object`` API. +try: + from torch._library.opaque_object import ( + is_opaque_value_type as _is_opaque_value_type, + register_opaque_type as _register_opaque_type, + ) + + if not hasattr(tex.DType, "__fx_repr__"): + tex.DType.__fx_repr__ = lambda self: ( + f"TE_DType({int(self)})", + {"TE_DType": tex.DType}, + ) + if not _is_opaque_value_type(tex.DType): + _register_opaque_type(tex.DType, typ="value", members={}) +except Exception: # pragma: no cover - older torch / partial init + pass diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 65faa62dcc..8d718b3b12 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -784,12 +784,16 @@ def _storage_scalars(self) -> Dict[str, Any]: "QuantizedTensorStorage." ) - # Scalar keys in :meth:`_storage_scalars` whose values are pybind - # enums (currently ``transformer_engine_torch.DType``) and must be - # converted to a Dynamo-traceable Python proxy - # (:class:`FP8DType`) before being embedded in the subclass-spec - # ``meta`` dict. The reverse conversion happens in the tensor - # subclass's :meth:`_flatten_meta_overrides`. + # Keys in :meth:`_storage_scalars` whose values are pybind enums + # (``transformer_engine_torch.DType``) and must be converted to the + # Python ``FP8DType`` proxy for :meth:`create_metadata`. The opaque + # registration in :mod:`fp8_dtype` is enough to flow ``tex.DType`` + # through Dynamo as an FX constant, but + # :meth:`autograd.Function.apply` -- used by + # :func:`_ToSubclassFn.reassemble_with_autograd` -- still rejects + # opaque values via its proxy-conversion check. The reverse + # conversion lives in the tensor subclass's + # :meth:`_flatten_meta_overrides`. _SUBCLASS_META_TEX_KEYS: Tuple[str, ...] = ("fp8_dtype",) def create_metadata( @@ -812,8 +816,7 @@ def create_metadata( rebuilding a live :class:`Quantizer` inside ``__tensor_unflatten__`` would force Dynamo to trace the constructor, which routinely trips - ``UserDefinedObjectVariable(...Quantizer)``. Code that needs - the live quantizer sources it from outside the compiled region. + ``UserDefinedObjectVariable(...Quantizer)``. """ storage_cls = type(self)._storage_cls usage_flag = { diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 3733a9d32e..84d00c7930 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -18,16 +18,6 @@ from ...utils import _empty_tensor -try: - from torch._library.opaque_object import is_opaque_value_type, register_opaque_type - - if not hasattr(TE_DType, "__fx_repr__"): - TE_DType.__fx_repr__ = lambda self: (f"TE_DType({int(self)})", {"TE_DType": TE_DType}) - if not is_opaque_value_type(TE_DType): - register_opaque_type(TE_DType, typ="value", members={}) -except Exception: # pragma: no cover - older torch / partial init - pass - class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of Float8BlockwiseQTensor. diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 1d712abea1..4cd7162e39 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -18,16 +18,6 @@ from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor -try: - from torch._library.opaque_object import is_opaque_value_type, register_opaque_type - - if not hasattr(TE_DType, "__fx_repr__"): - TE_DType.__fx_repr__ = lambda self: (f"TE_DType({int(self)})", {"TE_DType": TE_DType}) - if not is_opaque_value_type(TE_DType): - register_opaque_type(TE_DType, typ="value", members={}) -except Exception: # pragma: no cover - older torch / partial init - pass - class _FromFloat8Func(torch.autograd.Function): """Cast from FP8 to other dtype""" diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index d4076b63d9..d3f19a3b1c 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -19,16 +19,6 @@ from ...utils import _empty_tensor -try: - from torch._library.opaque_object import is_opaque_value_type, register_opaque_type - - if not hasattr(TE_DType, "__fx_repr__"): - TE_DType.__fx_repr__ = lambda self: (f"TE_DType({int(self)})", {"TE_DType": TE_DType}) - if not is_opaque_value_type(TE_DType): - register_opaque_type(TE_DType, typ="value", members={}) -except Exception: # pragma: no cover - older torch / partial init - pass - class _FromMXFP8Func(torch.autograd.Function): """Cast from MXFP8 to other dtype""" diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 55cb0c0991..f8d79ccf5e 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -21,16 +21,6 @@ from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...utils import _empty_tensor -try: - from torch._library.opaque_object import is_opaque_value_type, register_opaque_type - - if not hasattr(TE_DType, "__fx_repr__"): - TE_DType.__fx_repr__ = lambda self: (f"TE_DType({int(self)})", {"TE_DType": TE_DType}) - if not is_opaque_value_type(TE_DType): - register_opaque_type(TE_DType, typ="value", members={}) -except Exception: # pragma: no cover - older torch / partial init - pass - @functools.lru_cache(maxsize=None) def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor: From c1b0842d3ccff94ae50ae06cf04f44bff7a91460 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 29 May 2026 16:51:57 +0200 Subject: [PATCH 16/16] [PyTorch] Drop TensorSpec; reassemble torch.compile outputs from fake templates Replace the TensorSpec descriptor hierarchy in dynamo.py with two free helpers (_template_slot_count / _template_reassemble) that read the slot layout and rebuild the user-facing objects straight off the forward fake_impl's fake values. forward_fn / setup_context now consume fwd_fake_impl directly instead of a derived output_info_fn. Also preserve the FP8 output's quantizer (and its amax-reduction group) across the compile-path reassembly: make_fake_empty stashes the live quantizer on the fake template and _template_reassemble restores it, since __tensor_unflatten__ rebuilds with quantizer=None. Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 11 + transformer_engine/pytorch/dynamo.py | 807 ++++++------------ transformer_engine/pytorch/module/linear.py | 181 ++-- .../pytorch/tensor/float8_tensor.py | 79 ++ 4 files changed, 437 insertions(+), 641 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 04cd7cc843..439bae84d7 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -443,6 +443,17 @@ def fn(inp): f"expected Float8Tensor output, got {type(out).__name__}" ) assert out.shape == (32, 32) + # The compile-path reassembly rebuilds the wrapper via + # ``__tensor_unflatten__``, whose snapshot-free ``meta`` forces + # ``quantizer=None`` (a live ``ProcessGroup`` / amax-reduction group + # can't survive Dynamo guards). ``make_fake_empty`` stashes the live + # quantizer on the fake template and the reassembly helper restores it, + # so the output must keep a (non-``None``) quantizer rather than losing + # its amax-reduction group. + assert out._quantizer is not None, ( + "FP8 output lost its quantizer (and thus its amax-reduction group) " + "on the torch.compile path" + ) # Dequantising outside the compiled region exercises the # ``Float8Tensor`` machinery (scale + data + dtype all wired up # by the rewrap) on the value returned from the compiled fn. diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index 558db92f68..0c8ec0f378 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -26,8 +26,6 @@ __all__ = [ "OpaqueSimpleMetadata", - "TensorSpec", - "tensor_spec", "_te_register_custom_op", ] @@ -79,21 +77,18 @@ def _decode_none(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: # its flatten protocol and concatenates the inner plain tensors into the # op's ``Tensor[]`` return. # -# At call-site time (in :func:`forward_fn`), the layout for each user -# output is described by the user-supplied ``output_info_fn``: a pure -# Python function that returns a list of :class:`TensorSpec`, each -# carrying the static (class, inner_names, metadata, shape, stride) -# tuple needed to reassemble the user-facing object from its real -# inner tensors emitted by the op. +# At call-site time (in :func:`forward_fn` / ``setup_context``), the layout +# for each output is read straight off the forward ``fake_impl``'s fake +# values, which double as reassembly templates (:func:`_template_slot_count` +# / :func:`_template_reassemble`). def _contiguous_stride(shape: Sequence[int]) -> Tuple[int, ...]: """Row-major contiguous stride for ``shape``. - Used by :meth:`SubclassTensorSpec.from_quantizer` to fill in the - ``stride`` field expected by ``__tensor_unflatten__``; user code - that builds :class:`SubclassTensorSpec` directly typically does - not need to touch this. + Used to fill in the ``stride`` field expected by + ``__tensor_unflatten__`` when rebuilding a wrapper subclass from a + fake template (:func:`_template_reassemble`). """ stride: List[int] = [1] * len(shape) for i in range(len(shape) - 2, -1, -1): @@ -102,474 +97,205 @@ def _contiguous_stride(shape: Sequence[int]) -> Tuple[int, ...]: # --------------------------------------------------------------------------- # -# TensorSpec -- unified per-slot descriptor +# Reassembly: rebuild user-facing objects from the op's flat ``Tensor[]``. # --------------------------------------------------------------------------- # # -# ``TensorSpec`` is the single source of truth for one user output / one -# backward grad / one fake saved-slot value. Each instance encodes: -# -# * ``slot_count()`` -- how many entries of the op's flat ``Tensor[]`` -# payload this output consumes; -# * ``reassemble(chunk)`` -- how to turn those entries back into the -# user-facing object (plain tensor, tensor -# subclass, ``QuantizedTensorStorage``, ...); -# * ``reassemble_with_autograd(chunk)`` -# -- variant used by :func:`forward_fn` that -# interposes :class:`_ToSubclassFn` for -# subclass paths so the construction stays -# on the autograd graph; -# * ``alloc()`` -- build an empty fake version of the value -# for shape inference under -# :class:`torch._subclasses.FakeTensorMode`. - - -class TensorSpec: - """Per-output / per-saved-slot layout + (optional) allocation descriptor. - - Concrete subclasses (:class:`NoneSpec`, :class:`PlainTensorSpec`, - :class:`SubclassTensorSpec`, :class:`StorageSpec`) implement the - methods listed below. See module-level commentary for the role - each method plays in the forward / fake / setup-context pipelines. - """ - - def slot_count(self) -> int: - raise NotImplementedError( - f"{type(self).__name__}.slot_count() not implemented" - ) - - def reassemble(self, chunk: List[Any]) -> Any: - raise NotImplementedError( - f"{type(self).__name__}.reassemble() not implemented" - ) - - def reassemble_with_autograd(self, chunk: List[Any]) -> Any: - """Reassemble while keeping the autograd graph intact. - - Default to :meth:`reassemble`; only :class:`SubclassTensorSpec` - overrides this to route subclass construction through - :class:`_ToSubclassFn` (so AOTAutograd records the wrap). - """ - return self.reassemble(chunk) - - def alloc(self) -> Any: - raise NotImplementedError( - f"{type(self).__name__}.alloc() not implemented" - ) - - -class NoneSpec(TensorSpec): - """Output / save slot whose value is ``None``. - - Consumes one ``Tensor[]`` slot via the :func:`_encode_none` / - :func:`_decode_none` sentinel pair so that the op's schema (which - is non-nullable ``Tensor[]``) can still carry a ``None`` value - end-to-end. +# The forward ``fake_impl`` returns the op's outputs / saved tensors as fake +# values (``make_fake_empty`` wrappers / ``make_empty`` storages / +# ``torch.empty`` plains / aliased forward args / ``None``). Each fake value is +# itself a complete reassembly *template*: it says how many flat slots the real +# value occupies and how to rebuild it. :func:`_flatten_value_into` packs a +# value into slots; the two helpers below are its inverse (slot count + +# rebuild), reading straight off the fake template -- no separate descriptor +# object is materialised. + + +def _template_slot_count(template: Any, *, aliased: bool = False) -> int: + """Flat ``Tensor[]`` slots the real value for ``template`` occupies. + + ``aliased`` arg / ``None`` -> 1 (an :func:`_encode_none` sentinel); a + plain tensor -> 1; a ``make_fake_empty`` subclass -> ``len(inner_names)`` + (from its stamped plan); a storage -> ``len(_torch_compile_flatten())``. """ - - def slot_count(self) -> int: + if aliased or template is None: return 1 - - def reassemble(self, chunk: List[Any]) -> Any: - return None - - def alloc(self) -> Any: - return None - - -class AliasedSpec(TensorSpec): - """Saved-tensor slot whose value is identical to a forward arg. - - The forward impl writes ``None`` into the slot (so no extra storage - moves through the op return) and tags the slot's ``alias`` name in - ``ctx_attrs["saved_tensor_aliases"]``; the user's ``setup_context`` - resolves the alias back to the actual forward arg. - - Behaves like :class:`NoneSpec` on the schema side (1 sentinel slot, - ``reassemble -> None``, ``alloc -> None``); the only difference is - that :func:`_inject_saved_aliases` reads ``self.alias`` to populate - ``ctx_attrs["saved_tensor_aliases"]``. - """ - - def __init__(self, alias: str) -> None: - self.alias = alias - - def slot_count(self) -> int: + if isinstance(template, torch.Tensor): + plan = getattr(template, _TE_COMPILE_UNFLATTEN_PLAN, None) + if plan is not None: + inner_names, _ = plan + return len(inner_names) return 1 + flatten = getattr(template, "_torch_compile_flatten", None) + if flatten is not None: + _, _, tensors = flatten() + return len(tensors) + raise TypeError( + f"fake_impl produced an unsupported value of type {type(template).__name__}; " + "expected None / torch.Tensor (plain or make_fake_empty subclass) / " + "a storage exposing _torch_compile_flatten()." + ) - def reassemble(self, chunk: List[Any]) -> Any: - return None - - def alloc(self) -> Any: - return None - - -class PlainTensorSpec(TensorSpec): - """Plain :class:`torch.Tensor` output / save slot. - Carries ``shape`` / ``dtype`` / ``device`` for allocation; reassembly - is just the lone slot value. +def _template_reassemble( + template: Any, + chunk: List[Any], + *, + with_autograd: bool = False, + aliased: bool = False, +) -> Any: + """Rebuild the user-facing value for ``template`` from real slots ``chunk``. + + Inverse of :func:`_flatten_value_into`, driven by the fake template: an + ``aliased`` arg / ``None`` -> ``None`` (aliases are resolved by the + caller's ``setup_context`` from the alias name); a plain tensor -> + ``chunk[0]``; a ``make_fake_empty`` subclass -> ``__tensor_unflatten__`` + (routed through :class:`_ToSubclassFn` when ``with_autograd`` so the wrap + stays on the autograd graph); a storage -> ``_torch_compile_do_unflatten``. """ - - def __init__( - self, - shape: Optional[Sequence[int]] = None, - dtype: Optional["torch.dtype"] = None, - device: Optional["torch.device"] = None, - ) -> None: - self.shape = tuple(shape) if shape is not None else None - self.dtype = dtype - self.device = device - - def slot_count(self) -> int: - return 1 - - def reassemble(self, chunk: List[Any]) -> Any: + if aliased or template is None: + return None + if isinstance(template, torch.Tensor): + plan = getattr(template, _TE_COMPILE_UNFLATTEN_PLAN, None) + if plan is not None: + inner_names, meta = plan + shape = tuple(template.shape) + stride = _contiguous_stride(shape) + if with_autograd: + result = _ToSubclassFn.apply( + type(template), inner_names, meta, shape, stride, *chunk + ) + else: + inner_dict = dict(zip(inner_names, chunk)) + result = type(template).__tensor_unflatten__( + inner_dict, meta, shape, stride + ) + # ``__tensor_unflatten__`` rebuilds with ``quantizer=None`` (the + # snapshot can't carry a live ``ProcessGroup``); restore the live + # quantizer the fake template stashed so the output keeps its + # amax-reduction group. + quantizer = getattr(template, "_te_compile_quantizer", None) + if quantizer is not None: + result._quantizer = quantizer + return result return chunk[0] - - def alloc(self) -> Any: - if self.shape is None or self.dtype is None or self.device is None: - return TensorSpec.alloc(self) - return torch.empty(self.shape, dtype=self.dtype, device=self.device) - - -class SubclassTensorSpec(TensorSpec): - """Tensor-subclass output / save slot (e.g. :class:`Float8Tensor`). - - Two modes, picked at construction time via :meth:`from_quantizer`: - - * **Full mode** (``wrapper_cls`` supplied): the spec knows the - subclass identity, ``inner_names`` and ``meta`` for - ``__tensor_unflatten__``, so it can both :meth:`alloc` (under - :class:`FakeTensorMode`) and :meth:`reassemble` slot chunks from - the op's flat ``Tensor[]`` payload back into a user-facing - subclass instance. Used for forward outputs that flow through - the custom op and need to be re-wrapped on the other side. - * **Alloc-only mode** (no ``wrapper_cls``): the spec only carries - enough info to :meth:`alloc` an empty instance via - ``quantizer.make_empty(shape, dtype, device)``. Used for - backward gradient outputs, which never round-trip through the - flat ``Tensor[]`` -- ``_format_bwd_result`` hands them straight - to autograd -- so the layout-aware methods are intentionally - undefined. - """ - - def __init__( - self, - *, - shape: Sequence[int], - alloc_quantizer: Any, - alloc_dtype: "torch.dtype", - alloc_device: "torch.device", - cls: Optional[type] = None, - inner_names: Optional[Sequence[str]] = None, - meta: Any = None, - stride: Optional[Sequence[int]] = None, - ) -> None: - self.cls = cls - self.inner_names = tuple(inner_names) if inner_names is not None else None - self.meta = meta - self.shape = tuple(shape) - self.stride = tuple(stride) if stride is not None else None - self.alloc_quantizer = alloc_quantizer - self.alloc_dtype = alloc_dtype - self.alloc_device = alloc_device - - def _require_full_mode(self, method_name: str) -> None: - if self.cls is None: - raise RuntimeError( - f"SubclassTensorSpec.{method_name} is only available in " - "full mode (built with ``wrapper_cls=``). Alloc-only specs " - "(used for backward grad outputs) don't participate in the " - "flat ``Tensor[]`` payload, so they have no slot layout." - ) - - def slot_count(self) -> int: - self._require_full_mode("slot_count") - return len(self.inner_names) - - def reassemble(self, chunk: List[Any]) -> Any: - self._require_full_mode("reassemble") - inner_dict = dict(zip(self.inner_names, chunk)) - return self.cls.__tensor_unflatten__( - inner_dict, self.meta, self.shape, self.stride - ) - - def reassemble_with_autograd(self, chunk: List[Any]) -> Any: - self._require_full_mode("reassemble_with_autograd") - return _ToSubclassFn.apply( - self.cls, self.inner_names, self.meta, self.shape, self.stride, *chunk - ) - - def alloc(self) -> Any: - return self.alloc_quantizer.make_empty( - self.shape, dtype=self.alloc_dtype, device=self.alloc_device - ) - - @classmethod - def from_quantizer( - cls, - quantizer: Any, - *, - shape: Sequence[int], - dtype: "torch.dtype", - device: "torch.device", - wrapper_cls: Optional[type] = None, - ) -> "SubclassTensorSpec": - """Build a :class:`SubclassTensorSpec` from a live quantizer. - - Hides the ``create_metadata`` / inner-name / stride bookkeeping - behind a single call: callers in ``output_info_fn`` / - ``bwd_output_info_fn`` only specify the user-facing identity - (shape, dtype, device, quantizer) -- and, for forward outputs - that need flat-slot reassembly, the ``wrapper_cls`` they - unflatten into. - - Omitting ``wrapper_cls`` yields an alloc-only spec suitable - for backward grad outputs: the quantizer-specific fake - allocation still works (``quantizer.make_empty(...)``), but - :meth:`slot_count` / :meth:`reassemble` are intentionally - disabled because gradients never round-trip through the op's - flat ``Tensor[]`` payload. - """ - if wrapper_cls is None: - return cls( - shape=tuple(shape), - alloc_quantizer=quantizer, - alloc_dtype=dtype, - alloc_device=device, - ) - inner_names, meta = quantizer.create_metadata(fake_dtype=dtype) - return cls( - cls=wrapper_cls, - inner_names=inner_names, - meta=meta, - shape=tuple(shape), - stride=_contiguous_stride(shape), - alloc_quantizer=quantizer, - alloc_dtype=dtype, - alloc_device=device, - ) - - -class StorageSpec(TensorSpec): - """Non-tensor storage output / save slot (e.g. :class:`Float8TensorStorage`). - - Reassembled via ``cls._torch_compile_do_unflatten``; allocated via - ``alloc_quantizer.make_empty(shape, ...)``. - """ - - def __init__( - self, - cls: type, - meta: Any, - pg: Any, - tensor_count: int, - *, - alloc_quantizer: Any = None, - alloc_shape: Optional[Sequence[int]] = None, - alloc_dtype: Optional["torch.dtype"] = None, - alloc_device: Optional["torch.device"] = None, - ) -> None: - self.cls = cls - self.meta = meta - self.pg = pg - self.tensor_count = tensor_count - self.alloc_quantizer = alloc_quantizer - self.alloc_shape = ( - tuple(alloc_shape) if alloc_shape is not None else None - ) - self.alloc_dtype = alloc_dtype - self.alloc_device = alloc_device - - def slot_count(self) -> int: - return self.tensor_count - - def reassemble(self, chunk: List[Any]) -> Any: + flatten = getattr(template, "_torch_compile_flatten", None) + if flatten is not None: + meta, pg, _ = flatten() real_tensors = [t for t in chunk if t is not None] - return self.cls._torch_compile_do_unflatten(self.meta, self.pg, real_tensors) + return type(template)._torch_compile_do_unflatten(meta, pg, real_tensors) + raise TypeError( + f"fake_impl produced an unsupported value of type {type(template).__name__}; " + "expected None / torch.Tensor (plain or make_fake_empty subclass) / " + "a storage exposing _torch_compile_flatten()." + ) - def alloc(self) -> Any: - if self.alloc_quantizer is None or self.alloc_shape is None: - return TensorSpec.alloc(self) - return self.alloc_quantizer.make_empty( - self.alloc_shape, dtype=self.alloc_dtype, device=self.alloc_device - ) - - @classmethod - def from_quantizer( - cls, - quantizer: Any, - *, - shape: Sequence[int], - dtype: "torch.dtype", - device: "torch.device", - ) -> "StorageSpec": - """Build a :class:`StorageSpec` from a live quantizer. - - Hides the ``create_storage_metadata`` four-tuple - ``(cls, meta, process_group, tensor_count)`` behind a single - call: callers in ``output_info_fn`` only need to specify the - quantizer that drives the layout plus the higher-precision - view (shape / dtype / device) the storage represents. - """ - storage_cls, meta, pg, count = quantizer.create_storage_metadata( - shape=shape, - fake_dtype=dtype, - device=device, - ) - return cls( - cls=storage_cls, - meta=meta, - pg=pg, - tensor_count=count, - alloc_quantizer=quantizer, - alloc_shape=tuple(shape), - alloc_dtype=dtype, - alloc_device=device, - ) +def _split_fwd_fake_result( + result: Tuple[Any, ...], +) -> Tuple[List[Any], List[Any], Dict[str, Any]]: + """Slice a forward ``fake_impl`` return into ``(user_fakes, saved_fakes, ctx_attrs)``. -def tensor_spec( - *, - shape: Optional[Sequence[int]] = None, - dtype: Optional["torch.dtype"] = None, - device: Optional["torch.device"] = None, - quantizer: Optional[Any] = None, - wrapper_cls: Optional[type] = None, - storage: bool = False, - alias: Optional[str] = None, -) -> TensorSpec: - """One-stop factory for declaring an op output / saved slot / grad spec. - - Single entry point that authors of ``output_info_fn`` / - ``bwd_output_info_fn`` use to describe every slot the op - produces, regardless of whether the slot is a plain tensor, a - quantized wrapper, a non-tensor storage, an aliased save, an - absent output, or a grad-only alloc target. Internally dispatches - to the appropriate :class:`TensorSpec` subclass based on which - keyword arguments are supplied (first match wins): - - * ``alias`` set -> :class:`AliasedSpec` (saved slot that - reuses a forward arg; no payload moves - through the op). - * ``shape is None`` -> :class:`NoneSpec` (absent output / save). - * ``quantizer is None`` -> :class:`PlainTensorSpec`. - * ``storage=True`` -> :class:`StorageSpec` via - :meth:`StorageSpec.from_quantizer` (used - for quantized saved storages). - * otherwise -> :class:`SubclassTensorSpec` via - :meth:`SubclassTensorSpec.from_quantizer`. - ``wrapper_cls`` picks between *full mode* - (forward outputs that re-wrap from the - flat ``Tensor[]`` payload) and - *alloc-only mode* (backward grad outputs - that never round-trip through the op). - - All quantized paths use ``dtype`` / ``device`` for fake allocation - (``quantizer.make_empty(shape, dtype, device)``); the plain path - requires both as well, since it falls back to ``torch.empty``. + ``result`` has the eager-impl tuple shape ``(*user_outputs, + tensors_to_save, tensor_objects, ctx_attrs)``; the fake values double as + reassembly templates for :func:`_template_slot_count` / + :func:`_template_reassemble`. """ - if alias is not None: - return AliasedSpec(alias) - if shape is None: - return NoneSpec() - if quantizer is None: - return PlainTensorSpec(shape=shape, dtype=dtype, device=device) - if storage: - return StorageSpec.from_quantizer( - quantizer, shape=shape, dtype=dtype, device=device - ) - return SubclassTensorSpec.from_quantizer( - quantizer, - shape=shape, - dtype=dtype, - device=device, - wrapper_cls=wrapper_cls, - ) + num_outputs = len(result) - _FWD_TRAILING_SLOTS + saved = result[num_outputs] + ctx_attrs = result[num_outputs + 2] + user_fakes = list(result[:num_outputs]) + saved_fakes = list(saved) if saved is not None else [] + ctx_attrs = dict(ctx_attrs) if ctx_attrs else {} + return user_fakes, saved_fakes, ctx_attrs # --------------------------------------------------------------------------- # -# Fake-impl synthesis from ``output_info_fn`` / ``bwd_output_info_fn``. +# ``fake_impl`` consumers. +# +# A module describes its forward op outputs directly as a ``fwd_fake_impl`` +# that returns the same ``(*user_outputs, tensors_to_save, tensor_objects, +# ctx_attrs)`` tuple as the eager ``fwd_impl``, but built out of *fake* +# values: +# * ``quantizer.make_fake_empty(...)`` -- Dynamo-safe quantized wrapper. +# * ``quantizer.make_empty(...)`` -- quantized storage. +# * ``torch.empty(...)`` -- plain tensor. +# * the actual forward-arg tensor -- an aliased saved slot. +# * ``None`` -- absent output / saved slot. +# These fake values are the single source of truth for the op's layout: +# * ``forward_fn`` / ``setup_context`` reassemble the real flat ``Tensor[]`` +# using the fakes as templates (:func:`_template_slot_count` / +# :func:`_template_reassemble`), resolving aliased saved slots via +# :func:`_alias_name_for`. +# * :func:`_fwd_register_fake_from_fake_impl` wires the same callable as the +# op's ``register_fake`` (aliased saved slots nulled so the fake flat +# ``Tensor[]`` layout matches the eager impl, which writes ``None`` for +# aliases). +# The backward ``bwd_fake_impl`` is used directly as the backward +# ``register_fake`` -- backward grads never round-trip through the op +# payload, so no reassembly is needed. # --------------------------------------------------------------------------- # +# Attribute stamped on ``make_fake_empty`` outputs carrying the +# ``(inner_names, meta)`` plan needed to rebuild the subclass via +# ``__tensor_unflatten__``. The adapter reads it back (as a Dynamo +# constant) instead of calling ``value.__tensor_flatten__()`` in-trace: +# a tensor method returning non-tensors graph-breaks under fullgraph, +# whereas a plain attribute read is inlined. +_TE_COMPILE_UNFLATTEN_PLAN = "_te_compile_unflatten_plan" -def _inject_saved_aliases( - ctx_attrs: Dict[str, Any], saved_slots: Sequence[TensorSpec] -) -> Dict[str, Any]: - """Inject ``saved_tensor_aliases`` derived from ``saved_slots``. - - The user's ``setup_context`` callback reads aliases off - ``ctx_attrs["saved_tensor_aliases"]`` to resolve aliased saved - slots back to their forward arg. Only :class:`AliasedSpec` - contributes a non-``None`` alias entry; every other spec maps to - ``None`` (no alias, the real value is carried through the op - payload). We expose the tuple on every code path (real op output, - output-info path, auto-synthesized fake) so the callback's - contract stays identical. - """ - out = dict(ctx_attrs) if ctx_attrs else {} - out["saved_tensor_aliases"] = tuple( - s.alias if isinstance(s, AliasedSpec) else None for s in saved_slots - ) - return out +def _fwd_arg_alias_pairs(fwd_obj: Any, field_names: Sequence[str]) -> List[Tuple[torch.Tensor, str]]: + """Collect ``(tensor field value, field name)`` for a fwd-arg object. -def _make_fake_impl_from_output_info( - output_info_fn: Callable[[Any], Any], -) -> Callable[[Any], Tuple[Any, ...]]: - """Build a forward fake-impl from an ``output_info_fn``. - - The synthesized fake-impl returns - ``(*user_outputs, tensors_to_save, None, None)``: - - * ``user_outputs`` comes from ``[s.alloc() for s in user_specs]``. - * ``tensors_to_save`` comes from ``tuple(s.alloc() for s in saved_slots)``, - or ``None`` if ``saved_slots`` is empty - (e.g. ``is_grad_enabled=False``). - * The trailing ``tensor_objects`` / ``ctx_attrs`` slots are - ``None`` placeholders -- the eager fwd_impl contract requires - them in the tuple (via ``_FWD_TRAILING_SLOTS``) but - :func:`_format_fwd_result` only reads user outputs + saved - tensors off a fake-impl return. - - ``output_info_fn`` must return a 3-tuple - ``(user_specs: List[TensorSpec], saved_slots: List[TensorSpec], - ctx_attrs: Dict[str, Any])``. + ``field_names`` is precomputed outside the trace (reading + ``dataclasses.fields`` in-trace would graph-break on the class + ``mappingproxy``); attribute access by name is inlined. Used to + detect saved slots that alias a forward arg by identity (``is``). """ + pairs: List[Tuple[torch.Tensor, str]] = [] + for name in field_names: + value = getattr(fwd_obj, name, None) + if isinstance(value, torch.Tensor): + pairs.append((value, name)) + return pairs - def _fake(args: Any) -> Tuple[Any, ...]: - user_specs, saved_slots, _ = output_info_fn(args) - user_outputs = [s.alloc() for s in user_specs] - tensors_to_save = ( - None if not saved_slots else tuple(s.alloc() for s in saved_slots) - ) - # Trailing ``tensor_objects`` / ``ctx_attrs`` slots are required - # by the eager fwd_impl contract (``_FWD_TRAILING_SLOTS``) but - # are never read off a fake-impl return -- ``_format_fwd_result`` - # only slices user outputs + tensors_to_save out of the tuple. - return (*user_outputs, tensors_to_save, None, None) - return _fake +def _alias_name_for(value: Any, pairs: List[Tuple[torch.Tensor, str]]) -> Optional[str]: + """Return the forward-arg name ``value`` aliases (by ``is``), else ``None``.""" + for tensor, name in pairs: + if value is tensor: + return name + return None -def _make_fake_impl_from_bwd_output_info( - bwd_output_info_fn: Callable[[Any], List[TensorSpec]], +def _fwd_register_fake_from_fake_impl( + fwd_fake_impl: Callable[[Any], Tuple[Any, ...]], + field_names: Sequence[str], ) -> Callable[[Any], Tuple[Any, ...]]: - """Build a backward fake-impl from a ``bwd_output_info_fn``. - - The descriptor returns a flat list of :class:`TensorSpec` - (typically :class:`NoneSpec` / :class:`PlainTensorSpec` / - alloc-only :class:`SubclassTensorSpec` for quantized grads), one - per gradient output in the same order as ``backward_impl``'s - return tuple. The synthesized fake-impl just calls - :meth:`TensorSpec.alloc` on each. + """Adapt a forward ``fake_impl`` into a ``register_fake`` kernel. + + The user's ``fake_impl`` returns the *actual* forward-arg tensor for + aliased saved slots; the eager impl instead writes ``None`` there + (the value rides along as a ctx alias, not through the op payload). + Aliased saved slots are nulled here so the fake flat ``Tensor[]`` + layout stays identical to the eager impl. """ - def _fake(bwd_args: Any) -> Tuple[Any, ...]: - specs = bwd_output_info_fn(bwd_args) - return tuple(s.alloc() for s in specs) + def fwd_fake(fwd_obj: Any) -> Tuple[Any, ...]: + result = fwd_fake_impl(fwd_obj) + num_outputs = len(result) - _FWD_TRAILING_SLOTS + user_outputs = result[:num_outputs] + saved = result[num_outputs] + if saved is None: + tensors_to_save: Any = None + else: + pairs = _fwd_arg_alias_pairs(fwd_obj, field_names) + tensors_to_save = tuple( + None if _alias_name_for(v, pairs) is not None else v for v in saved + ) + return (*user_outputs, tensors_to_save, None, None) - return _fake + return fwd_fake class _ToSubclassFn(torch.autograd.Function): @@ -1544,7 +1270,7 @@ def _prepare_for_saving(tensors: Any) -> Tuple[List[Optional[torch.Tensor]], Any Used only to flatten the user's setup-context return into a ``(flat_tensors, tensor_objects)`` pair stashed on ``ctx`` for the backward; the forward output and saved-tensor restoration on the - compile-path now go through :class:`TensorSpec` instead. Lazy-imports + compile-path go through :func:`_template_reassemble` instead. Lazy-imports avoid the dynamo<->quantized_tensor circular import that ``transformer_engine.pytorch`` would otherwise trigger at module import time. @@ -1568,7 +1294,8 @@ def _prepare_for_saving(tensors: Any) -> Tuple[List[Optional[torch.Tensor]], Any def _flatten_value_into(flat: List[torch.Tensor], value: Any) -> None: """Append the ``Tensor[]`` slots produced by ``value`` to ``flat``. - The dispatch matches the four spec kinds in :class:`TensorSpec`: + The inverse of :func:`_template_reassemble`; the slot counts match + :func:`_template_slot_count`: * ``None`` -> 1 sentinel slot (via :func:`_encode_none`). * plain Tensor -> 1 slot. @@ -1609,10 +1336,10 @@ def _format_fwd_result(result: Any) -> List[torch.Tensor]: User outputs come first, then the saved-for-backward tensors in declaration order. Both groups go through the same per-value :func:`_flatten_value_into` dispatch -- the slot layout produced - here must match exactly what :meth:`TensorSpec.slot_count` reports - for the corresponding spec, since the call-site reassembly in + here must match exactly what :func:`_template_slot_count` reports + for the corresponding fake template, since the call-site reassembly in :func:`forward_fn` / :func:`_setup_context` slices this flat list - back into user-facing objects using those per-spec counts. + back into user-facing objects using those per-template counts. ``None`` entries on either side are smuggled through :func:`_encode_none` so the schema stays non-nullable and @@ -1797,7 +1524,8 @@ def _register_autograd_for_op( grad_targets: List[Tuple[int, bool]], setup_context_user: Callable[..., None], backward_obj_type: type, - output_info_fn: Callable[[Any], Tuple[List["TensorSpec"], List["TensorSpec"], Any]], + fwd_fake_impl: Callable[[Any], Tuple[Any, ...]], + fwd_field_names: Sequence[str], ) -> None: """Wire ``register_autograd`` on a forward op so its backward calls ``bwd_op_name``. @@ -1811,13 +1539,12 @@ def _register_autograd_for_op( The op's ``Tensor[]`` return holds the flat layout produced by :func:`_format_fwd_result` -- one chunk per user output / saved - tensor, sliced according to the user-supplied ``output_info_fn``: - a pure Python function returning - ``(user_specs: List[TensorSpec], saved_slots: List[TensorSpec], - ctx_attrs)``. Traceable by Dynamo / AOT, no fake tensor allocation - involved. :class:`AliasedSpec` entries on the saved side carry the - forward-arg name the slot aliases, surfaced to the user's - ``setup_context`` via ``ctx_attrs["saved_tensor_aliases"]``. + tensor. ``setup_context`` re-runs ``fwd_fake_impl`` to recover the + fake output / saved templates, then reassembles each chunk via + :func:`_template_reassemble`. Saved slots that alias a forward arg + (the fake returns the actual arg) are detected by identity and + surfaced to the user's ``setup_context`` via + ``ctx_attrs["saved_tensor_aliases"]``. """ fwd_qualname = f"{_TE_OP_NAMESPACE}::{fwd_op_name}" @@ -1828,30 +1555,36 @@ def _setup_context(ctx, inputs, output): kwargs = dict(zip(fwd_arg_names, inputs)) fwd_obj = _unpack(fwd_arg_type, kwargs, fwd_buckets) - user_specs, saved_slots, ctx_attrs = output_info_fn(fwd_obj) - ctx_attrs = _inject_saved_aliases(ctx_attrs, saved_slots) + user_fakes, saved_fakes, ctx_attrs = _split_fwd_fake_result(fwd_fake_impl(fwd_obj)) + pairs = _fwd_arg_alias_pairs(fwd_obj, fwd_field_names) + saved_aliases = tuple(_alias_name_for(t, pairs) for t in saved_fakes) + ctx_attrs = dict(ctx_attrs) + ctx_attrs["saved_tensor_aliases"] = saved_aliases cursor = 0 user_outputs: List[Any] = [] - for spec in user_specs: - n = spec.slot_count() + for template in user_fakes: + n = _template_slot_count(template) chunk = [_decode_none(t) for t in output[cursor:cursor + n]] cursor += n - user_outputs.append(spec.reassemble(chunk)) + user_outputs.append(_template_reassemble(template, chunk)) tensors_to_save_from_forward_list: List[Any] = [] - for spec in saved_slots: - n = spec.slot_count() + for template, alias in zip(saved_fakes, saved_aliases): + aliased = alias is not None + n = _template_slot_count(template, aliased=aliased) chunk = [_decode_none(t) for t in output[cursor:cursor + n]] cursor += n - tensors_to_save_from_forward_list.append(spec.reassemble(chunk)) + tensors_to_save_from_forward_list.append( + _template_reassemble(template, chunk, aliased=aliased) + ) tensors_to_save_from_forward = tuple(tensors_to_save_from_forward_list) bwd_obj = backward_obj_type() tensors_to_save_from_setup = setup_context_user( bwd_obj, fwd_obj, - user_outputs[0] if len(user_specs) == 1 else tuple(user_outputs), + user_outputs[0] if len(user_fakes) == 1 else tuple(user_outputs), ctx_attrs, tensors_to_save_from_forward, ) @@ -1984,11 +1717,8 @@ def _te_register_custom_op( backward_arg_type: type, backward_obj: type, backward_impl: Callable[[Any], Any], - output_info_fn: Callable[ - [Any], - Tuple[List["TensorSpec"], List["TensorSpec"], Dict[str, Any]], - ], - bwd_output_info_fn: Callable[[Any], List["TensorSpec"]], + fwd_fake_impl: Callable[[Any], Tuple[Any, ...]], + bwd_fake_impl: Callable[[Any], Tuple[Any, ...]], ) -> Callable[..., Any]: """Register a TE module's forward + backward as a single torch custom op. @@ -2039,57 +1769,38 @@ def _te_register_custom_op( backward_impl Eager backward implementation. Receives a single argument of type ``backward_arg_type`` and returns the gradient tuple. - output_info_fn - Pure-Python layout descriptor for the op's outputs: - ``fn(fwd_obj) -> (user_specs, saved_slots, ctx_attrs)``. - - * ``user_specs`` is a list, one :class:`TensorSpec` per user - output. Each spec encodes everything dynamo needs about - that slot: ``slot_count()`` for flat-``Tensor[]`` slicing, - ``reassemble(chunk)`` / ``reassemble_with_autograd(chunk)`` - for rebuilding the user-facing object from the op's flat - output, and ``alloc()`` for the auto-synthesized fake-impl. - The four concrete subclasses -- :class:`NoneSpec`, - :class:`PlainTensorSpec`, :class:`SubclassTensorSpec`, - :class:`StorageSpec` -- cover every output shape TE - currently produces. - - * ``saved_slots`` is a list of :class:`TensorSpec`, one per - saved-for-backward slot, mirroring ``user_specs`` but for - the saved-tensor section of the op payload. Use - :class:`AliasedSpec(name)` for slots that the forward impl - leaves as ``None`` because the value is identical to a - forward arg (the alias name is surfaced to - ``setup_context`` via - ``ctx_attrs["saved_tensor_aliases"]``, injected by dynamo). - Use :class:`NoneSpec` / :class:`PlainTensorSpec` / - :class:`StorageSpec` / :class:`SubclassTensorSpec` for the - rest, exactly as for user outputs. - - * ``ctx_attrs`` is the non-tensor state attached to the - autograd context (passed through to ``setup_context``). - Dynamo augments it with ``"saved_tensor_aliases"`` before - the callback runs. - - :func:`forward_fn` and the autograd ``setup_context`` use - this descriptor to learn output layouts without ever - materialising a fake prototype tensor -- the only way to - keep layout extraction traceable by Dynamo under - ``fullgraph=True``. The forward fake-impl - (:func:`torch.library.register_fake`) is auto-synthesized - from the same specs via :func:`_make_fake_impl_from_output_info`. - bwd_output_info_fn - Pure-Python alloc descriptor for the backward op: - ``fn(bwd_obj) -> List[TensorSpec]``, one entry per gradient - output in the same order as ``backward_impl``'s return tuple. - Typically :class:`NoneSpec` for missing grads, - :class:`PlainTensorSpec` for plain tensors, and an alloc-only - :class:`SubclassTensorSpec` (built via - :meth:`SubclassTensorSpec.from_quantizer` without a - ``wrapper_cls``) for quantized ones. The backward fake-impl - is synthesized from these specs via - :func:`_make_fake_impl_from_bwd_output_info`, so the - gradient-shape derivation lives entirely in the descriptor. + fwd_fake_impl + Forward fake implementation: ``fn(fwd_obj) -> (*user_outputs, + tensors_to_save, tensor_objects, ctx_attrs)`` -- the same tuple + shape as ``fwd_impl``, but built from *fake* values instead of + running the real kernel. Each slot is one of: + + * ``quantizer.make_fake_empty(shape, dtype, device)`` -- a + Dynamo-safe quantized wrapper (assembled via + ``__tensor_unflatten__`` with a snapshot-free meta). + * ``quantizer.make_empty(shape, dtype, device)`` -- a quantized + storage (e.g. an FP8 weight workspace). + * ``torch.empty(shape, dtype, device)`` -- a plain tensor. + * the actual forward-arg tensor -- for a saved slot that aliases + a forward input (detected by identity). + * ``None`` -- an absent output / saved slot. + + This single callable drives both consumers: ``forward_fn`` / + ``setup_context`` use its fake values directly as reassembly + templates (:func:`_template_slot_count` / + :func:`_template_reassemble`), and + :func:`_fwd_register_fake_from_fake_impl` wires it (with aliased + saved slots nulled) as the op's + :func:`torch.library.register_fake`. The whole callable must be + Dynamo-traceable under ``fullgraph=True``. + bwd_fake_impl + Backward fake implementation: ``fn(bwd_obj) -> grad_tuple``, one + fake grad per gradient output in the same order as + ``backward_impl``'s return tuple (``None`` for missing grads, + ``torch.empty`` for plain, ``quantizer.make_empty`` for + quantized). Wired directly as the backward op's + ``register_fake`` -- backward grads never round-trip through the + op payload, so no layout adapter is needed. Returns ------- @@ -2171,13 +1882,16 @@ def _te_register_custom_op( inner_fwd_qualname = f"{_TE_OP_NAMESPACE}::{inner_fwd_name}" inner_bwd_qualname = f"{_TE_OP_NAMESPACE}::{inner_bwd_name}" - # Auto-synthesize the forward / backward fake impls from the - # alloc-spec descriptors. The synthesized impls share branching - # with their layout counterparts (``output_info_fn`` / - # ``bwd_output_info_fn``) so there's exactly one place where every - # per-precision / per-mode condition lives. - fwd_fake_impl = _make_fake_impl_from_output_info(output_info_fn) - bwd_fake_impl = _make_fake_impl_from_bwd_output_info(bwd_output_info_fn) + # The module supplies its output layout as a forward ``fake_impl`` + # (fake values in the eager-impl tuple shape). ``forward_fn`` / + # ``setup_context`` consume it directly as reassembly templates; the + # forward ``register_fake`` kernel wraps it to null aliased saved slots + # (so the fake flat ``Tensor[]`` matches the eager impl). The backward + # ``fake_impl`` is the backward ``register_fake`` directly. ``field + # names`` are precomputed here (reading ``dataclasses.fields`` in-trace + # would graph-break) for the alias-by-identity detection. + fwd_field_names = [f.name for f in dataclasses.fields(fwd_arg_type)] + fwd_register_fake = _fwd_register_fake_from_fake_impl(fwd_fake_impl, fwd_field_names) _register_kernel( op_name=inner_fwd_name, @@ -2186,7 +1900,7 @@ def _te_register_custom_op( arg_names=fwd_arg_names, buckets=fwd_buckets, impl=fwd_impl, - fake_impl=fwd_fake_impl, + fake_impl=fwd_register_fake, format_result=_format_fwd_result, ) _register_kernel( @@ -2211,7 +1925,8 @@ def _te_register_custom_op( grad_targets=grad_targets, setup_context_user=setup_context, backward_obj_type=backward_obj, - output_info_fn=output_info_fn, + fwd_fake_impl=fwd_fake_impl, + fwd_field_names=fwd_field_names, ) if subclass_list: @@ -2243,7 +1958,8 @@ def _te_register_custom_op( grad_targets=grad_targets, setup_context_user=setup_context, backward_obj_type=backward_obj, - output_info_fn=output_info_fn, + fwd_fake_impl=fwd_fake_impl, + fwd_field_names=fwd_field_names, ) fwd_slot_offsets = _collect_universal_slot_offsets(fwd_buckets) @@ -2303,21 +2019,24 @@ def _bwd_rule(mode, func, types, args, kwargs): fwd_op = getattr(getattr(torch.ops, _TE_OP_NAMESPACE), outer_fwd_name) def forward_fn(fwd_args): - user_specs, _saved_slots, _ctx_attrs = output_info_fn(fwd_args) + user_fakes, _saved_fakes, _ctx_attrs = _split_fwd_fake_result( + fwd_fake_impl(fwd_args) + ) kwargs = _pack(fwd_args, fwd_buckets) flat_in = [kwargs[name] for name in fwd_arg_names] result = fwd_op(*flat_in) - # Slice the flat result by spec. Subclass specs route through - # :class:`_ToSubclassFn` to keep the wrap on the autograd graph; - # plain tensors / storage classes are reconstructed directly. + # Slice the flat result using the fake outputs as templates. Subclass + # templates route through :class:`_ToSubclassFn` to keep the wrap on + # the autograd graph; plain tensors / storage classes are + # reconstructed directly. User outputs never alias a forward arg. cursor = 0 outputs: List[Any] = [] - for spec in user_specs: - n = spec.slot_count() + for template in user_fakes: + n = _template_slot_count(template) chunk = [_decode_none(t) for t in result[cursor:cursor + n]] cursor += n - outputs.append(spec.reassemble_with_autograd(chunk)) + outputs.append(_template_reassemble(template, chunk, with_autograd=True)) if len(outputs) == 1: return outputs[0] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c72a21c8a9..dfa0bb6b51 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -58,9 +58,7 @@ ) from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx, GemmParallelModes, dist_group_type from ..dynamo import ( - TensorSpec, _te_register_custom_op, - tensor_spec, ) from ..graph import is_graph_capturing from ..quantized_tensor import ( @@ -74,9 +72,7 @@ from ..tensor.float8_tensor import ( Float8CurrentScalingQuantizer, Float8Quantizer, - Float8Tensor, ) -from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.utils import clear_columnwise_cache, is_custom from ..export import is_in_onnx_export_mode, assert_warmed_up @@ -1273,28 +1269,26 @@ def wgrad_gemm( # ---------------------------------------------------------------------------- -# Compile-tier wrappers: ``output_info_fn`` descriptors + ``_te_register_custom_op`` +# Compile-tier wrappers: forward / backward ``fake_impl`` + ``_te_register_custom_op`` # registration. The custom op lets ``torch.compile`` trace through linear # forward + backward as a single graph node without entering the eager # ``_Linear`` autograd.Function machinery. Selected by :meth:`Linear.forward` # when ``torch.compiler.is_compiling()`` is true. # ---------------------------------------------------------------------------- -def _linear_backward_output_info( +def _linear_backward_fake_impl( args: LinearBwdArgs, -) -> List[TensorSpec]: - """Pure-Python alloc-spec descriptor for :func:`_linear_backward`. - - Returns a list of three :class:`TensorSpec` -- one per gradient - output ``(wgrad, dgrad, grad_bias)`` -- consumed by the - auto-synthesized backward fake-impl in - :func:`_make_fake_impl_from_bwd_output_info`. Each slot is encoded - through :func:`tensor_spec` (``shape=None`` for absent grads, - ``quantizer`` for quantized ones -- backward grads use alloc-only - ``SubclassTensorSpec`` because they go straight to autograd and - never through the op's flat ``Tensor[]``). ``set_usage`` on +) -> Tuple[Any, Any, Any]: + """Backward fake-impl for :func:`_linear_backward`. + + Returns the ``(wgrad, dgrad, grad_bias)`` gradient triple built from + *fake* values (``None`` for absent grads, ``torch.empty`` for plain, + ``quantizer.make_empty`` for quantized ones), in the same order as + :func:`_linear_backward`'s return tuple. Wired directly as the + backward op's ``register_fake`` -- it runs under fake-prop (not the + Dynamo trace), so ``make_empty`` is fine here. ``set_usage`` on ``grad_input_quantizer`` is preserved because it influences - ``dgrad``'s downstream ``make_empty``. Manual TE FSDP is - unsupported; FSDP2 / MCore FSDP go through the standard path. + ``dgrad``'s allocation. Manual TE FSDP is unsupported; FSDP2 / MCore + FSDP go through the standard path. """ if args.fsdp_group is not None: @@ -1312,44 +1306,37 @@ def _linear_backward_output_info( activation_dtype = args.activation_dtype device = args.grad_output.device - wgrad_shape = ( - (out_features, in_features) + def grad(shape, quantizer): + if shape is None: + return None + if quantizer is not None: + return quantizer.make_empty(list(shape), dtype=activation_dtype, device=device) + return torch.empty(tuple(shape), dtype=activation_dtype, device=device) + + wgrad = ( + grad((out_features, in_features), args.grad_weight_quantizer) if args.requires_wgrad and not args.fuse_wgrad_accumulation else None ) - dgrad_shape = tuple(args.inp_shape) if args.requires_dgrad else None - grad_bias_shape = (out_features,) if args.use_bias and args.requires_wgrad else None + dgrad = grad(args.inp_shape, args.grad_input_quantizer) if args.requires_dgrad else None + grad_bias = grad((out_features,), None) if args.use_bias and args.requires_wgrad else None - return [ - tensor_spec( - shape=wgrad_shape, - dtype=activation_dtype, - device=device, - quantizer=args.grad_weight_quantizer, - ), - tensor_spec( - shape=dgrad_shape, - dtype=activation_dtype, - device=device, - quantizer=args.grad_input_quantizer, - ), - tensor_spec( - shape=grad_bias_shape, - dtype=activation_dtype, - device=device, - ), - ] + return wgrad, dgrad, grad_bias -def _linear_forward_output_info( +def _linear_forward_fake_impl( args: LinearFwdArgs, -) -> Tuple[List[TensorSpec], List[TensorSpec], Dict[str, Any]]: - """Output-layout descriptor for the linear forward. - - Returns ``(user_specs, saved_slots, ctx_attrs)`` -- Dynamo-traceable - layout + alloc info for the op's outputs and saved tensors. - :func:`_te_register_custom_op` synthesizes the fake-impl by calling - :meth:`TensorSpec.alloc` on each entry. +) -> Tuple[Any, Any, Any, Any, Dict[str, Any]]: + """Forward fake-impl for :func:`_linear_forward_impl`. + + Returns ``(out, new_weight_workspace, tensors_to_save, None, + ctx_attrs)`` -- the same tuple shape as the eager impl, but built + from *fake* values (``make_fake_empty`` wrappers / ``make_empty`` + storages / ``torch.empty`` plains / aliased forward args / ``None``). + The ``fake_impl`` -> layout adapter in + :mod:`transformer_engine.pytorch.dynamo` reads the slot layout off + these fake values (and nulls aliased saved slots for the + ``register_fake`` kernel). All ``set_usage`` side effects on the live quantizers happen here and are observed by both the real fwd impl and backward. @@ -1441,9 +1428,9 @@ def _linear_forward_output_info( inputmat_is_storage = False # Weight pipeline -- mirror ``quantize_weight`` / ``cast_if_needed``. - # ``new_weight_workspace_spec`` is non-``NoneSpec`` only on the - # cache-miss + ``cache_weight`` combination. - new_weight_workspace_spec: TensorSpec = tensor_spec() + # ``new_weight_workspace`` is a fresh fake storage only on the + # cache-miss + ``cache_weight`` combination, else ``None``. + new_weight_workspace: Any = None weightmat_is_storage = False weightmat_aliases_weight = False if fp8_or_debug: @@ -1476,12 +1463,10 @@ def _linear_forward_output_info( ): workspace = None if workspace is None and args.cache_weight: - new_weight_workspace_spec = tensor_spec( - shape=weight.shape, - dtype=activation_dtype, - device=weight.device, - quantizer=weight_quantizer, - storage=True, + # Fresh FP8 weight workspace -- a ``*TensorStorage`` + # (``weight_quantizer`` is ``internal``). + new_weight_workspace = weight_quantizer.make_empty( + list(weight.shape), dtype=activation_dtype, device=weight.device ) else: weightmat_aliases_weight = weight.dtype == activation_dtype @@ -1504,18 +1489,17 @@ def _linear_forward_output_info( # User-output [0] -- the GEMM result. ``Float8Tensor`` is the only # quantized wrapper this op produces directly; other quantizer # families flow their workspace through ``new_weight_workspace`` - # instead. - out_spec = tensor_spec( - shape=tuple(out_shape), - dtype=activation_dtype, - device=inp.device, - quantizer=output_quantizer, - wrapper_cls=Float8Tensor if output_quantizer is not None else None, - ) - - user_specs: List[TensorSpec] = [out_spec, new_weight_workspace_spec] + # instead. The quantized output uses ``make_fake_empty`` -- the + # Dynamo-safe wrapper allocator (``make_empty`` cannot build a + # wrapper in-trace because it proxies the live quantizer). + if output_quantizer is not None: + out = output_quantizer.make_fake_empty( + tuple(out_shape), dtype=activation_dtype, device=inp.device + ) + else: + out = torch.empty(tuple(out_shape), dtype=activation_dtype, device=inp.device) - saved_slots: List[TensorSpec] = [] + saved_values: List[Any] = [] if args.is_grad_enabled: # Post-forward ``set_usage`` -- mirrors ``_linear_forward_impl`` @@ -1539,20 +1523,22 @@ def _linear_forward_output_info( # Slot 0 -- ``saved_inputmat``: absent / aliased to ``inp`` / # fresh quantized storage / plain cast (mutually exclusive). + # An aliased slot returns the actual forward arg ``inp``; the + # adapter detects the identity and nulls it in the payload. if not backward_needs_input: - saved_slots.append(tensor_spec()) + saved_values.append(None) elif inputmat_aliases_inp: - saved_slots.append(tensor_spec(alias="inp")) - else: - saved_slots.append( - tensor_spec( - shape=tuple(inp.shape), - dtype=activation_dtype, - device=inp.device, - quantizer=input_quantizer if inputmat_is_storage else None, - storage=inputmat_is_storage, + saved_values.append(inp) + elif inputmat_is_storage: + saved_values.append( + input_quantizer.make_empty( + list(inp.shape), dtype=activation_dtype, device=inp.device ) ) + else: + saved_values.append( + torch.empty(tuple(inp.shape), dtype=activation_dtype, device=inp.device) + ) # Slot 1 -- ``wt_save``. The saved storage's quantizer must # match the one the impl uses for re-quantization, which is @@ -1565,24 +1551,24 @@ def _linear_forward_output_info( else args.weight_quantizer ) if weightmat_aliases_weight: - saved_slots.append(tensor_spec(alias="weight")) + saved_values.append(weight) elif args.is_fsdp2: - saved_slots.append(tensor_spec()) - else: - saved_slots.append( - tensor_spec( - shape=tuple(weight.shape), - dtype=activation_dtype, - device=weight.device, - quantizer=weight_quantizer_for_save if weightmat_is_storage else None, - storage=weightmat_is_storage, + saved_values.append(None) + elif weightmat_is_storage: + saved_values.append( + weight_quantizer_for_save.make_empty( + list(weight.shape), dtype=activation_dtype, device=weight.device ) ) + else: + saved_values.append( + torch.empty(tuple(weight.shape), dtype=activation_dtype, device=weight.device) + ) - # Slot 2 -- ``saved_weight`` (always aliased). Slot 3 -- - # ``saved_bias`` (aliased or absent). - saved_slots.append(tensor_spec(alias="weight")) - saved_slots.append(tensor_spec(alias="bias") if bias is not None else tensor_spec()) + # Slot 2 -- ``saved_weight`` (always aliased to ``weight``). + # Slot 3 -- ``saved_bias`` (aliased to ``bias`` or absent). + saved_values.append(weight) + saved_values.append(bias if bias is not None else None) if args.fsdp_group is not None and args.is_grad_enabled: raise NotImplementedError( @@ -1592,7 +1578,8 @@ def _linear_forward_output_info( ctx_attrs: Dict[str, Any] = {"fsdp_shapes": []} - return user_specs, saved_slots, ctx_attrs + tensors_to_save = tuple(saved_values) if args.is_grad_enabled else None + return out, new_weight_workspace, tensors_to_save, None, ctx_attrs _linear_compiled_op = _te_register_custom_op( @@ -1600,12 +1587,12 @@ def _linear_forward_output_info( input_tensors_for_grad=["weight", "inp", "bias"], fwd_arg_type=LinearFwdArgs, fwd_impl=_linear_forward_impl, - output_info_fn=_linear_forward_output_info, + fwd_fake_impl=_linear_forward_fake_impl, setup_context=_linear_setup_ctx, backward_arg_type=LinearBwdArgs, backward_obj=LinearBwdArgs, backward_impl=_linear_backward, - bwd_output_info_fn=_linear_backward_output_info, + bwd_fake_impl=_linear_backward_fake_impl, ) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index de1040e7c2..0758d2ae9d 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -28,6 +28,56 @@ aten = torch.ops.aten + +def _float8_make_fake_empty( + quantizer: "Quantizer", + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> "Float8Tensor": + """Dynamo-safe ``Float8Tensor`` allocation shared by the FP8 quantizers. + + Mirrors the inner-tensor layout of ``make_empty`` (rowwise ``_data`` / + columnwise ``_transpose`` / ``_scale_inv``) but assembles the wrapper + through :meth:`QuantizedTensor.__tensor_unflatten__` -- which takes a + snapshot-free ``meta`` dict (``quantizer_snapshot=None``, ``FP8DType``) + rather than the live ``Quantizer`` / ``tex.DType`` constructor args that + Dynamo cannot proxy inside a traced frame. + """ + from ..dynamo import _contiguous_stride # pylint: disable=import-outside-toplevel + + if device is None: + device = torch.device("cuda") + shape = list(shape) + + alloc: Dict[str, torch.Tensor] = {} + if quantizer.rowwise_usage: + alloc["_data"] = torch.empty(shape, dtype=torch.uint8, device=device) + if quantizer.columnwise_usage: + transpose_shape = [shape[-1]] + list(shape[:-1]) + alloc["_transpose"] = torch.empty(transpose_shape, dtype=torch.uint8, device=device) + alloc["_scale_inv"] = torch.empty(1, dtype=torch.float32, device=device) + + inner_names, meta = quantizer.create_metadata(fake_dtype=dtype) + inner_dict = {name: alloc[name] for name in inner_names} + out = Float8Tensor.__tensor_unflatten__( + inner_dict, meta, tuple(shape), _contiguous_stride(shape) + ) + # Stamp the reassembly plan so the dynamo reassembly helper + # (:func:`transformer_engine.pytorch.dynamo._template_reassemble`) + # can rebuild this subclass from the op's flat ``Tensor[]`` payload + # by reading an attribute (Dynamo-safe) rather than calling + # ``__tensor_flatten__`` in-trace. + out._te_compile_unflatten_plan = (tuple(inner_names), meta) + # Stash the live quantizer on the template so the reassembly helper can + # restore it on the real output (``__tensor_unflatten__`` rebuilds with + # ``quantizer=None`` because the snapshot can't carry a live + # ``ProcessGroup`` / scale-amax tensors through Dynamo guards). + out._te_compile_quantizer = quantizer + return out + + _ops_to_preserve_subclass_in_fsdp2 = { torch.ops.aten.empty_like.default, torch.ops.aten.new_zeros.default, @@ -179,6 +229,24 @@ def make_empty( device=device, ) + def make_fake_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> Float8Tensor: + """Dynamo-safe analogue of :meth:`make_empty`. + + Builds the :class:`Float8Tensor` via + :meth:`QuantizedTensor.__tensor_unflatten__` (snapshot-free meta, + :class:`FP8DType`) instead of the live-quantizer constructor, so it + traces under ``torch.compile(fullgraph=True)`` -- where + :meth:`make_empty` trips on ``UserDefinedObjectVariable(Quantizer)`` + / ``UserDefinedObjectVariable(DType)``. + """ + return _float8_make_fake_empty(self, shape, dtype=dtype, device=device) + def calibrate(self, tensor: torch.Tensor) -> None: amin, amax = tensor.aminmax() self.amax.copy_(torch.max(-amin, amax)) @@ -423,6 +491,17 @@ def make_empty( device=device, ) + def make_fake_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> Float8Tensor: + """Dynamo-safe analogue of :meth:`make_empty` (see + :func:`_float8_make_fake_empty`).""" + return _float8_make_fake_empty(self, shape, dtype=dtype, device=device) + def calibrate(self, tensor: torch.Tensor) -> None: # current scaling don't need to calibrate return