[PyTorch] [torch.compile] torch.compile support for Linear#3053
Draft
pggPL wants to merge 16 commits into
Draft
Conversation
Add torch custom-op fake implementations for `_linear_forward_impl` and `_linear_backward` so torch.compile can perform shape inference without running the real GEMM / communication / quantization paths. * `fake_cast_if_needed` in `utils.py`: returns an empty tensor of the target dtype when a cast would happen, otherwise returns the input. * `fake_quantize_weight` in `module/base.py`: mirrors the cache-hit / cache-miss control flow of `quantize_weight` but fills cache misses with `quantizer.make_empty`. * `_linear_forward_fake_impl` in `module/linear.py`: mirrors the real forward's control flow and `set_usage` / `update_usage` calls, but replaces computation with empty tensors and skips side effects irrelevant for shape inference (CPU offload, calibration, NCCL/UB collectives, `clear_tensor_data`, FSDP scatter). Communication is simulated by producing tensors with the post-comm shape. Manual TE FSDP (`fsdp_group is not None`) is unsupported. * `_linear_backward_fake_impl` in `module/linear.py`: backward output shapes/dtypes are deterministic, so just allocates empty tensors of the right shape. `grad_input_quantizer.set_usage` is preserved because it influences `dgrad`'s `make_empty`. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop in the `dynamo.py` helper from `gh_linear_torch_compile_support`
@dbe4ca79 ("Add torch.compile support for Linear"), before the
"Cache torch.compile unpack output" (8ab2425) and "Hoist constant
Linear setup out of opaque custom-op body" (9a98ff5) optimizations.
Exposes `ArgObject`, `OpaqueSimpleMetadata`, and
`_te_register_custom_op`, which will be used to route the TE Linear
forward/backward through a torch custom op so torch.compile can trace
through it without entering the eager autograd.Function machinery.
The module is self-contained: it imports only `torch`, `dataclasses`,
`enum`, and `typing`, and does not yet have any callers in this branch.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Six blocks in dynamo.py shipped with broken indentation in the previous commit and prevented the module from being imported (`IndentationError` / `SyntaxError`): - `_UniversalTensorBucket.pack`: storage branch had its body / return split across two indentation levels and ended up outside the matching `if isinstance(value, qts):` block. - `_quantizer_cls` / `_recipe_cls`: `try`/`except` pair misaligned. - `_FlattenableBucket._pack_value`: `if hasattr(value, "_flatten"):` pulled inside the wrong outer branch. - `_SimpleBundleBucket.matches_field`: stray extra indent on a `return True`. - `_resolved_field_annotations`, `_get_buckets`, `_pack`, `_unpack`: function bodies indented under their own header by 8 spaces instead of 4. No semantic changes intended -- this commit restores the structure that the docstrings and surrounding code already document. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Linear.forward / backward under torch.compile previously relied on
hand-written ``_linear_forward_fake_impl`` / ``_linear_backward_fake_impl``
which duplicated every per-precision / per-mode branch from the eager
impl just to materialise fake tensors during AOTAutograd tracing.
This commit replaces both fake impls with pure-Python alloc descriptors
and auto-synthesizes the fake-impls in dynamo.py:
* ``_te_register_custom_op`` now accepts ``bwd_output_info_fn`` next to
the existing ``output_info_fn`` and synthesizes ``fwd_fake_impl`` /
``backward_fake_impl`` from their alloc tuples via
``_make_fake_impl_from_output_info`` /
``_make_fake_impl_from_bwd_output_info``.
* ``output_info_fn`` now returns a 4-tuple
``(user_specs, tensor_objects, ctx_attrs, fake_specs)``; the new
``fake_specs`` dict carries one alloc spec per user output and per
saved slot (``("plain", shape, dtype, device)`` or
``("quantized", quantizer, shape, dtype, device)``).
* ``_linear_forward_fake_impl`` and ``_linear_backward_fake_impl`` are
removed; ``_linear_forward_output_info`` is extended to produce
``fake_specs``, and a new ``_linear_backward_output_info`` (~55 LoC)
describes gradient shapes/dtypes/devices.
* The matching quantizer plumbing -- ``create_metadata`` /
``create_storage_metadata`` / ``create_save_shell`` -- is filled in
for ``MXFP8Quantizer``, ``Float8BlockQuantizer`` and
``NVFP4Quantizer`` so every storage family takes the same path.
* ``Float8TensorStorage._torch_compile_do_unflatten`` no longer
overwrites ``_transpose_invalid`` with stale metadata captured at
flatten time; the post-restore semantic (``_transpose_invalid =
(data_transpose is None)``) is now preserved by ``__new__`` /
``restore_from_saved`` instead.
Net result: a single Python function per op now owns layout,
restore-shape and fake-alloc bookkeeping; the auto-synthesized
fake-impl runs only under ``FakeTensorMode`` (so live quantizers and
``tex.DType`` pybind enums are fine) while the layout descriptor stays
purely in plain Python and remains traceable by Dynamo under
``fullgraph=True``.
All ``tests/pytorch/test_torch_compile.py`` cases pass; ``test_numerics``
linear coverage (882 cases) is green.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Consolidates the spec machinery used by ``_te_register_custom_op``: * Merge ``QuantizedAllocSpec`` into ``SubclassTensorSpec``: a single class now covers both *full* (forward outputs that round-trip through the op's flat ``Tensor[]``) and *alloc-only* (backward grad outputs that go straight to autograd) modes, picked by whether ``wrapper_cls`` is supplied to ``SubclassTensorSpec.from_quantizer``. * Introduce ``tensor_spec(...)`` as the single user-facing factory for declaring every op output / saved slot / grad spec. Picks the right ``TensorSpec`` subclass based on the kwargs given (``alias``, ``shape``, ``quantizer``, ``wrapper_cls``, ``storage``); ``output_info_fn`` authors no longer need to know the underlying class hierarchy. * Drop the explicit ``num_outputs`` (and ``output_annotations``) parameter from ``_te_register_custom_op`` and friends. The user output count is inferred dynamically at call time from the impl return shape -- the trailing three slots are always ``(tensors_to_save, tensor_objects, ctx_attrs)``, so ``num_outputs = len(result) - _FWD_TRAILING_SLOTS`` / ``len(user_specs)`` everywhere they're needed. The op schema is ``Tensor[]`` (variable-length list) so this matches the actual graph contract. * Move ``_contiguous_stride`` from ``linear.py`` to ``dynamo.py`` (private helper for ``from_quantizer`` classmethods). * Sync ``create_storage_metadata`` in every quantized tensor module to include the quantizer's flattened tensor count in ``StorageSpec.tensor_count`` so slot counts match the actual ``_torch_compile_flatten`` payload. Net result for ``linear.py``: the public surface from ``dynamo.py`` shrinks to ``TensorSpec`` (type hint), ``_te_register_custom_op``, and ``tensor_spec``; all five concrete spec classes become implementation details of ``tensor_spec``. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Remove the legacy fake-impl prototype path so compiled TE ops derive output layouts and fake kernels from explicit TensorSpec descriptors. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Remove ``requires_grad``/``as_tensor`` from ``SubclassTensorSpec.from_quantizer`` and ``StorageSpec.from_quantizer`` (always defaulted, never threaded through), drop the vestigial ``ctx_attrs`` slot from ``_make_fake_impl_from_output_info`` (``_format_fwd_result`` never reads it), and make ``fake_impl`` required in ``_register_kernel`` (always supplied since output_info_fn became mandatory). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Move ``__tensor_flatten__`` / ``__tensor_unflatten__`` and ``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` onto the ``QuantizedTensor`` / ``QuantizedTensorStorage`` base classes, driven by three per-subclass class attributes: ``_FLATTEN_TENSOR_ATTRS``, ``_FLATTEN_META_ATTRS``, and ``_FLATTEN_CTOR_KWARG``, plus a ``_flatten_meta_overrides`` hook (used by ``Float8Tensor`` to bridge ``FP8DType`` <-> ``tex.DType``). Both the PyTorch wrapper-subclass protocol and the storage triplet protocol now share a single source of truth for the field schema, and the four storage shells + ``Float8Tensor`` drop their per-class implementations in favor of ~10 lines of declarations. As a side effect, ``MXFP8Tensor`` / ``NVFP4Tensor`` / ``Float8BlockwiseQTensor`` gain ``__tensor_flatten__`` / ``__tensor_unflatten__`` for free, enabling future use as ``torch.compile`` input subclasses. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…chema Move ``create_storage_metadata`` onto the ``Quantizer`` base class, driven by the storage's ``_FLATTEN_*`` declarations plus two new hooks: a per-quantizer ``_storage_cls`` class attribute pointing at the storage subclass it produces and a ``_storage_scalars()`` method returning the quantizer-specific scalar fields (``fp8_dtype``, ``with_gemm_swizzled_scales``, ...). Per-attribute presence flags are derived from the storage's new ``_FLATTEN_TENSOR_USAGE`` map plus the quantizer's ``rowwise_usage`` / ``columnwise_usage``. The four bespoke ``*Quantizer.create_storage_metadata`` methods and the ``_float8_create_storage_metadata`` helper collapse to ~5 lines of declarations per quantizer. Also slim the storage flatten path: ``is_tensor`` / ``shape`` / ``requires_grad`` / ``device`` are only emitted when the storage is flattened as a ``torch.Tensor`` (the ``_rewrite_subclass_to_storage`` path); the storage-only path keeps a smaller meta payload, and ``_torch_compile_do_unflatten`` falls back via ``meta.get`` instead of requiring those keys. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…schema
Move the TE-internal Quantizer flattening protocol (used by
``_FlattenableBucket`` in ``dynamo.py`` to round-trip quantizers
through compiled regions and to embed ``quantizer_meta`` inside
storage flatten payloads) onto the ``Quantizer`` base class.
Subclasses now declare per-class attrs instead of writing 30-line
``_flatten`` / ``_do_unflatten`` pairs:
* ``_DTYPE_INIT_KWARG`` -- ``"fp8_dtype"`` (default) or ``"fp4_dtype"``.
* ``_INIT_META_ATTRS`` / ``_POST_INIT_META_ATTRS`` -- scalar attrs
threaded through ``__init__`` vs. set after construction.
* ``_INIT_TENSOR_ATTRS`` / ``_POST_INIT_TENSOR_ATTRS`` -- tensor attrs
in flatten order, split the same way.
* ``_PG_ATTR`` / ``_PG_INIT_KWARG`` -- name of the optional
``ProcessGroup`` attribute and (optionally) the ``__init__`` kwarg
to thread it through.
* ``_FIXED_INIT_KWARGS`` -- hardcoded ``__init__`` kwargs not derived
from meta (e.g. ``device=torch.device("cuda")`` for
``Float8CurrentScalingQuantizer``).
The base ``dtype`` / ``rowwise_usage`` / ``columnwise_usage`` round-trip
through ``__init__`` and ``internal`` / ``optimize_for_gemm`` through
post-init are handled by the base class.
``MXFP8Quantizer`` drops both methods entirely (no extras to declare).
``Float8Quantizer`` collapses to a single ``_INIT_TENSOR_ATTRS``
declaration. ``NVFP4Quantizer`` drops ~45 lines.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Remove `fake_cast_if_needed` (utils.py) and `fake_quantize_weight` (module/base.py): leftover from the hand-written fake-impl design that output_info_fn / TensorSpec.alloc replaced; no remaining call sites. - Remove unused ``KIND`` class attributes from TensorSpec subclasses in dynamo.py; never read after the spec-dispatch refactor. - Fix stale comment references to nonexistent symbols (`_rewrite_subclass_to_storage`, `_generic_tensor_flatten`). - Trim verbose narration about "replaced hand-written fake-impls" in linear.py docstrings; the current behavior is described, history belongs in git. - Drop redundant "generic implementations on base" comments from MXFP8 and Float8Block storage; the ``_FLATTEN_*`` declarations speak for themselves. Keep the genuine `_transpose_invalid` rationale in float8_tensor_storage.py. - Fix copy-paste comment in nvfp4_tensor_storage.py (MXFP8 -> NVFP4). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Move the (inner_names, meta) builder for SubclassTensorSpec onto the
base ``Quantizer``, driven by the storage class's
``_FLATTEN_TENSOR_ATTRS`` / ``_FLATTEN_TENSOR_USAGE`` plus the
quantizer's existing ``_storage_scalars()`` hook. ``inner_names``
now follows declaration order so it stays in sync with
``__tensor_flatten__`` (used by ``_flatten_value_into`` when an FP8
output is serialized back into the op's flat Tensor[] payload).
The pybind ``tex.DType`` -> Python ``FP8DType`` conversion (needed
to keep the meta dict Dynamo-traceable as a constant on
``_ToSubclassFn``) is centralised behind a single class attribute
``_SUBCLASS_META_TEX_KEYS`` that defaults to ``("fp8_dtype",)``.
Drops the standalone ``_float8_create_subclass_metadata`` helper and
the two duplicate ``Float8*Quantizer.create_metadata`` overrides;
they reduce to the generic base behavior plus the existing
``_storage_scalars`` declaration.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
``LinearBwdArgs.fp8_recipe: Optional[Recipe]`` was shipped through the compiled custom op only to read two booleans on the backward side -- ``recipe.fp8_gemm_dgrad.use_split_accumulator`` and ``recipe.fp8_gemm_wgrad.use_split_accumulator``. Replace the Recipe-typed field with those two booleans, populated in setup-ctx the same way the forward already computes ``use_split_accumulator`` from the live recipe. With nothing left to flatten through the bucket scanner: - Drop ``Recipe._flatten`` / ``Recipe._unflatten``, the ``_RECIPE_REGISTRY`` dispatch dict and the explicit ``_register_recipe_subclass`` calls at the bottom of ``common/recipe/__init__.py`` (~90 lines). - Drop ``_recipe_cls()`` / ``_RECIPE_REF`` and remove ``Recipe`` from ``_flattenable_bases`` in ``pytorch/dynamo.py``. - Tighten the now-quantizer-only docstrings on ``_FlattenableBucket`` / ``_MetaPGTensorsBucket``. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The four storage modules each carried an identical try/except block that injected ``__fx_repr__`` on ``transformer_engine_torch.DType`` and registered it as a torch.compile value-opaque type. Move the registration to its natural home in ``fp8_dtype.py`` -- right next to ``to_tex`` / ``from_tex`` -- and load it once from ``quantized_tensor.py`` (which every storage / quantizer subclass already pulls in). Same behavior, one place instead of four. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
… templates Replace the TensorSpec descriptor hierarchy in dynamo.py with two free helpers (_template_slot_count / _template_reassemble) that read the slot layout and rebuild the user-facing objects straight off the forward fake_impl's fake values. forward_fn / setup_context now consume fwd_fake_impl directly instead of a derived output_info_fn. Also preserve the FP8 output's quantizer (and its amax-reduction group) across the compile-path reassembly: make_fake_empty stashes the live quantizer on the fake template and _template_reassemble restores it, since __tensor_unflatten__ rebuilds with quantizer=None. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: