Skip to content

[PyTorch] [torch.compile] torch.compile support for Linear#3053

Draft
pggPL wants to merge 16 commits into
NVIDIA:mainfrom
pggPL:linear_torch_compile_final_attempt
Draft

[PyTorch] [torch.compile] torch.compile support for Linear#3053
pggPL wants to merge 16 commits into
NVIDIA:mainfrom
pggPL:linear_torch_compile_final_attempt

Conversation

@pggPL
Copy link
Copy Markdown
Collaborator

@pggPL pggPL commented May 28, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL added 16 commits May 12, 2026 16:07
Add torch custom-op fake implementations for `_linear_forward_impl` and
`_linear_backward` so torch.compile can perform shape inference without
running the real GEMM / communication / quantization paths.

* `fake_cast_if_needed` in `utils.py`: returns an empty tensor of the
  target dtype when a cast would happen, otherwise returns the input.
* `fake_quantize_weight` in `module/base.py`: mirrors the cache-hit /
  cache-miss control flow of `quantize_weight` but fills cache misses
  with `quantizer.make_empty`.
* `_linear_forward_fake_impl` in `module/linear.py`: mirrors the real
  forward's control flow and `set_usage` / `update_usage` calls, but
  replaces computation with empty tensors and skips side effects
  irrelevant for shape inference (CPU offload, calibration, NCCL/UB
  collectives, `clear_tensor_data`, FSDP scatter). Communication is
  simulated by producing tensors with the post-comm shape. Manual TE
  FSDP (`fsdp_group is not None`) is unsupported.
* `_linear_backward_fake_impl` in `module/linear.py`: backward output
  shapes/dtypes are deterministic, so just allocates empty tensors of
  the right shape. `grad_input_quantizer.set_usage` is preserved
  because it influences `dgrad`'s `make_empty`.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop in the `dynamo.py` helper from `gh_linear_torch_compile_support`
@dbe4ca79 ("Add torch.compile support for Linear"), before the
"Cache torch.compile unpack output" (8ab2425) and "Hoist constant
Linear setup out of opaque custom-op body" (9a98ff5) optimizations.

Exposes `ArgObject`, `OpaqueSimpleMetadata`, and
`_te_register_custom_op`, which will be used to route the TE Linear
forward/backward through a torch custom op so torch.compile can trace
through it without entering the eager autograd.Function machinery.

The module is self-contained: it imports only `torch`, `dataclasses`,
`enum`, and `typing`, and does not yet have any callers in this branch.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Six blocks in dynamo.py shipped with broken indentation in the
previous commit and prevented the module from being imported
(`IndentationError` / `SyntaxError`):

- `_UniversalTensorBucket.pack`: storage branch had its body / return
  split across two indentation levels and ended up outside the
  matching `if isinstance(value, qts):` block.
- `_quantizer_cls` / `_recipe_cls`: `try`/`except` pair misaligned.
- `_FlattenableBucket._pack_value`: `if hasattr(value, "_flatten"):`
  pulled inside the wrong outer branch.
- `_SimpleBundleBucket.matches_field`: stray extra indent on a
  `return True`.
- `_resolved_field_annotations`, `_get_buckets`, `_pack`, `_unpack`:
  function bodies indented under their own header by 8 spaces
  instead of 4.

No semantic changes intended -- this commit restores the structure
that the docstrings and surrounding code already document.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Linear.forward / backward under torch.compile previously relied on
hand-written ``_linear_forward_fake_impl`` / ``_linear_backward_fake_impl``
which duplicated every per-precision / per-mode branch from the eager
impl just to materialise fake tensors during AOTAutograd tracing.

This commit replaces both fake impls with pure-Python alloc descriptors
and auto-synthesizes the fake-impls in dynamo.py:

* ``_te_register_custom_op`` now accepts ``bwd_output_info_fn`` next to
  the existing ``output_info_fn`` and synthesizes ``fwd_fake_impl`` /
  ``backward_fake_impl`` from their alloc tuples via
  ``_make_fake_impl_from_output_info`` /
  ``_make_fake_impl_from_bwd_output_info``.
* ``output_info_fn`` now returns a 4-tuple
  ``(user_specs, tensor_objects, ctx_attrs, fake_specs)``; the new
  ``fake_specs`` dict carries one alloc spec per user output and per
  saved slot (``("plain", shape, dtype, device)`` or
  ``("quantized", quantizer, shape, dtype, device)``).
* ``_linear_forward_fake_impl`` and ``_linear_backward_fake_impl`` are
  removed; ``_linear_forward_output_info`` is extended to produce
  ``fake_specs``, and a new ``_linear_backward_output_info`` (~55 LoC)
  describes gradient shapes/dtypes/devices.
* The matching quantizer plumbing -- ``create_metadata`` /
  ``create_storage_metadata`` / ``create_save_shell`` -- is filled in
  for ``MXFP8Quantizer``, ``Float8BlockQuantizer`` and
  ``NVFP4Quantizer`` so every storage family takes the same path.
* ``Float8TensorStorage._torch_compile_do_unflatten`` no longer
  overwrites ``_transpose_invalid`` with stale metadata captured at
  flatten time; the post-restore semantic (``_transpose_invalid =
  (data_transpose is None)``) is now preserved by ``__new__`` /
  ``restore_from_saved`` instead.

Net result: a single Python function per op now owns layout,
restore-shape and fake-alloc bookkeeping; the auto-synthesized
fake-impl runs only under ``FakeTensorMode`` (so live quantizers and
``tex.DType`` pybind enums are fine) while the layout descriptor stays
purely in plain Python and remains traceable by Dynamo under
``fullgraph=True``.

All ``tests/pytorch/test_torch_compile.py`` cases pass; ``test_numerics``
linear coverage (882 cases) is green.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Consolidates the spec machinery used by ``_te_register_custom_op``:

* Merge ``QuantizedAllocSpec`` into ``SubclassTensorSpec``: a single
  class now covers both *full* (forward outputs that round-trip
  through the op's flat ``Tensor[]``) and *alloc-only* (backward grad
  outputs that go straight to autograd) modes, picked by whether
  ``wrapper_cls`` is supplied to ``SubclassTensorSpec.from_quantizer``.
* Introduce ``tensor_spec(...)`` as the single user-facing factory
  for declaring every op output / saved slot / grad spec. Picks the
  right ``TensorSpec`` subclass based on the kwargs given
  (``alias``, ``shape``, ``quantizer``, ``wrapper_cls``,
  ``storage``); ``output_info_fn`` authors no longer need to know
  the underlying class hierarchy.
* Drop the explicit ``num_outputs`` (and ``output_annotations``)
  parameter from ``_te_register_custom_op`` and friends. The user
  output count is inferred dynamically at call time from the impl
  return shape -- the trailing three slots are always
  ``(tensors_to_save, tensor_objects, ctx_attrs)``, so
  ``num_outputs = len(result) - _FWD_TRAILING_SLOTS`` /
  ``len(user_specs)`` everywhere they're needed. The op schema is
  ``Tensor[]`` (variable-length list) so this matches the actual
  graph contract.
* Move ``_contiguous_stride`` from ``linear.py`` to ``dynamo.py``
  (private helper for ``from_quantizer`` classmethods).
* Sync ``create_storage_metadata`` in every quantized tensor module
  to include the quantizer's flattened tensor count in
  ``StorageSpec.tensor_count`` so slot counts match the actual
  ``_torch_compile_flatten`` payload.

Net result for ``linear.py``: the public surface from ``dynamo.py``
shrinks to ``TensorSpec`` (type hint), ``_te_register_custom_op``,
and ``tensor_spec``; all five concrete spec classes become
implementation details of ``tensor_spec``.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Remove the legacy fake-impl prototype path so compiled TE ops derive output layouts and fake kernels from explicit TensorSpec descriptors.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Remove ``requires_grad``/``as_tensor`` from ``SubclassTensorSpec.from_quantizer``
and ``StorageSpec.from_quantizer`` (always defaulted, never threaded through),
drop the vestigial ``ctx_attrs`` slot from ``_make_fake_impl_from_output_info``
(``_format_fwd_result`` never reads it), and make ``fake_impl`` required in
``_register_kernel`` (always supplied since output_info_fn became mandatory).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Move ``__tensor_flatten__`` / ``__tensor_unflatten__`` and
``_torch_compile_flatten`` / ``_torch_compile_do_unflatten`` onto the
``QuantizedTensor`` / ``QuantizedTensorStorage`` base classes, driven by
three per-subclass class attributes: ``_FLATTEN_TENSOR_ATTRS``,
``_FLATTEN_META_ATTRS``, and ``_FLATTEN_CTOR_KWARG``, plus a
``_flatten_meta_overrides`` hook (used by ``Float8Tensor`` to bridge
``FP8DType`` <-> ``tex.DType``).

Both the PyTorch wrapper-subclass protocol and the storage triplet
protocol now share a single source of truth for the field schema, and
the four storage shells + ``Float8Tensor`` drop their per-class
implementations in favor of ~10 lines of declarations.

As a side effect, ``MXFP8Tensor`` / ``NVFP4Tensor`` /
``Float8BlockwiseQTensor`` gain ``__tensor_flatten__`` /
``__tensor_unflatten__`` for free, enabling future use as
``torch.compile`` input subclasses.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…chema

Move ``create_storage_metadata`` onto the ``Quantizer`` base class,
driven by the storage's ``_FLATTEN_*`` declarations plus two new
hooks: a per-quantizer ``_storage_cls`` class attribute pointing at
the storage subclass it produces and a ``_storage_scalars()`` method
returning the quantizer-specific scalar fields (``fp8_dtype``,
``with_gemm_swizzled_scales``, ...). Per-attribute presence flags are
derived from the storage's new ``_FLATTEN_TENSOR_USAGE`` map plus the
quantizer's ``rowwise_usage`` / ``columnwise_usage``.

The four bespoke ``*Quantizer.create_storage_metadata`` methods and
the ``_float8_create_storage_metadata`` helper collapse to ~5 lines
of declarations per quantizer.

Also slim the storage flatten path: ``is_tensor`` / ``shape`` /
``requires_grad`` / ``device`` are only emitted when the storage is
flattened as a ``torch.Tensor`` (the ``_rewrite_subclass_to_storage``
path); the storage-only path keeps a smaller meta payload, and
``_torch_compile_do_unflatten`` falls back via ``meta.get`` instead
of requiring those keys.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…schema

Move the TE-internal Quantizer flattening protocol (used by
``_FlattenableBucket`` in ``dynamo.py`` to round-trip quantizers
through compiled regions and to embed ``quantizer_meta`` inside
storage flatten payloads) onto the ``Quantizer`` base class.

Subclasses now declare per-class attrs instead of writing 30-line
``_flatten`` / ``_do_unflatten`` pairs:

* ``_DTYPE_INIT_KWARG`` -- ``"fp8_dtype"`` (default) or ``"fp4_dtype"``.
* ``_INIT_META_ATTRS`` / ``_POST_INIT_META_ATTRS`` -- scalar attrs
  threaded through ``__init__`` vs. set after construction.
* ``_INIT_TENSOR_ATTRS`` / ``_POST_INIT_TENSOR_ATTRS`` -- tensor attrs
  in flatten order, split the same way.
* ``_PG_ATTR`` / ``_PG_INIT_KWARG`` -- name of the optional
  ``ProcessGroup`` attribute and (optionally) the ``__init__`` kwarg
  to thread it through.
* ``_FIXED_INIT_KWARGS`` -- hardcoded ``__init__`` kwargs not derived
  from meta (e.g. ``device=torch.device("cuda")`` for
  ``Float8CurrentScalingQuantizer``).

The base ``dtype`` / ``rowwise_usage`` / ``columnwise_usage`` round-trip
through ``__init__`` and ``internal`` / ``optimize_for_gemm`` through
post-init are handled by the base class.

``MXFP8Quantizer`` drops both methods entirely (no extras to declare).
``Float8Quantizer`` collapses to a single ``_INIT_TENSOR_ATTRS``
declaration. ``NVFP4Quantizer`` drops ~45 lines.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Remove `fake_cast_if_needed` (utils.py) and `fake_quantize_weight`
  (module/base.py): leftover from the hand-written fake-impl design that
  output_info_fn / TensorSpec.alloc replaced; no remaining call sites.
- Remove unused ``KIND`` class attributes from TensorSpec subclasses in
  dynamo.py; never read after the spec-dispatch refactor.
- Fix stale comment references to nonexistent symbols
  (`_rewrite_subclass_to_storage`, `_generic_tensor_flatten`).
- Trim verbose narration about "replaced hand-written fake-impls" in
  linear.py docstrings; the current behavior is described, history
  belongs in git.
- Drop redundant "generic implementations on base" comments from MXFP8
  and Float8Block storage; the ``_FLATTEN_*`` declarations speak for
  themselves. Keep the genuine `_transpose_invalid` rationale in
  float8_tensor_storage.py.
- Fix copy-paste comment in nvfp4_tensor_storage.py (MXFP8 -> NVFP4).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Move the (inner_names, meta) builder for SubclassTensorSpec onto the
base ``Quantizer``, driven by the storage class's
``_FLATTEN_TENSOR_ATTRS`` / ``_FLATTEN_TENSOR_USAGE`` plus the
quantizer's existing ``_storage_scalars()`` hook. ``inner_names``
now follows declaration order so it stays in sync with
``__tensor_flatten__`` (used by ``_flatten_value_into`` when an FP8
output is serialized back into the op's flat Tensor[] payload).

The pybind ``tex.DType`` -> Python ``FP8DType`` conversion (needed
to keep the meta dict Dynamo-traceable as a constant on
``_ToSubclassFn``) is centralised behind a single class attribute
``_SUBCLASS_META_TEX_KEYS`` that defaults to ``("fp8_dtype",)``.

Drops the standalone ``_float8_create_subclass_metadata`` helper and
the two duplicate ``Float8*Quantizer.create_metadata`` overrides;
they reduce to the generic base behavior plus the existing
``_storage_scalars`` declaration.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
``LinearBwdArgs.fp8_recipe: Optional[Recipe]`` was shipped through
the compiled custom op only to read two booleans on the backward
side -- ``recipe.fp8_gemm_dgrad.use_split_accumulator`` and
``recipe.fp8_gemm_wgrad.use_split_accumulator``. Replace the
Recipe-typed field with those two booleans, populated in
setup-ctx the same way the forward already computes
``use_split_accumulator`` from the live recipe.

With nothing left to flatten through the bucket scanner:

- Drop ``Recipe._flatten`` / ``Recipe._unflatten``, the
  ``_RECIPE_REGISTRY`` dispatch dict and the explicit
  ``_register_recipe_subclass`` calls at the bottom of
  ``common/recipe/__init__.py`` (~90 lines).
- Drop ``_recipe_cls()`` / ``_RECIPE_REF`` and remove ``Recipe``
  from ``_flattenable_bases`` in ``pytorch/dynamo.py``.
- Tighten the now-quantizer-only docstrings on
  ``_FlattenableBucket`` / ``_MetaPGTensorsBucket``.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The four storage modules each carried an identical try/except block
that injected ``__fx_repr__`` on ``transformer_engine_torch.DType``
and registered it as a torch.compile value-opaque type. Move the
registration to its natural home in ``fp8_dtype.py`` -- right next
to ``to_tex`` / ``from_tex`` -- and load it once from
``quantized_tensor.py`` (which every storage / quantizer subclass
already pulls in). Same behavior, one place instead of four.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
… templates

Replace the TensorSpec descriptor hierarchy in dynamo.py with two free
helpers (_template_slot_count / _template_reassemble) that read the slot
layout and rebuild the user-facing objects straight off the forward
fake_impl's fake values. forward_fn / setup_context now consume
fwd_fake_impl directly instead of a derived output_info_fn.

Also preserve the FP8 output's quantizer (and its amax-reduction group)
across the compile-path reassembly: make_fake_empty stashes the live
quantizer on the fake template and _template_reassemble restores it,
since __tensor_unflatten__ rebuilds with quantizer=None.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant