Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 138 additions & 2 deletions tests/pytorch/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __fx_repr__(self):
def _make_qfactory(tag: str):
"""Return a qfactory that produces ToyQuantizer instances tagged with *tag*."""

def qfactory(role: str):
return ToyQuantizer(tag=f"{tag}:{role}")
def qfactory(role):
return ToyQuantizer(tag=f"{tag}:{role.tensor_type}")

return qfactory

Expand Down Expand Up @@ -324,3 +324,139 @@ def fn(inp):

out = compiled(inp)
out.sum().backward()


@pytest.mark.parametrize(
"fp8_recipe",
[None, *_all_recipes],
ids=lambda r: "bf16" if r is None else type(r).__name__,
)
def test_te_linear_compiles(fp8_recipe):
"""torch.compile(fullgraph=True) of ``te.Linear`` under every built-in
recipe (and the bf16-only baseline with no autocast).

Exercises the custom-op path in
:mod:`transformer_engine.pytorch.dynamo`: forward goes through
``_linear_compiled_op``, backward through the registered
``transformer_engine::linear_backward`` op, and the dataclass
arg-objects are packed/unpacked via the bucket dispatch in
:mod:`transformer_engine.pytorch.dynamo`.
"""
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)

dtype = torch.bfloat16
device = "cuda"

# FP8 GEMMs require leading dimensions divisible by 16; pick
# in/out features and batch comfortably above that minimum.
model = te.Linear(64, 32, params_dtype=dtype, device=device)
inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True)

def fn(inp):
if fp8_recipe is None:
return model(inp)
with te.autocast(recipe=fp8_recipe):
return model(inp)

torch._dynamo.reset()
compiled = torch.compile(fn, fullgraph=True)

out = compiled(inp)
out.sum().backward()
assert out.shape == (32, 32)
assert inp.grad is not None
assert model.weight.grad is not None, "weight.grad missing"
assert model.weight.grad.shape == model.weight.shape, (
f"weight.grad shape {tuple(model.weight.grad.shape)} != "
f"weight shape {tuple(model.weight.shape)}"
)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_te_linear_compile_with_quantized_fp8_weight():
"""torch.compile should handle Linear weights initialized as FP8 tensors."""
dtype = torch.bfloat16
device = "cuda"
fp8_recipe = recipe.Float8CurrentScaling()

with te.quantized_model_init(enabled=True, recipe=fp8_recipe):
model = te.Linear(64, 32, params_dtype=dtype, device=device)

assert isinstance(model.weight, te.Float8Tensor)
inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True)

def fn(inp):
with te.autocast(recipe=fp8_recipe):
return model(inp)

torch._dynamo.reset()
compiled = torch.compile(fn, fullgraph=True)

out = compiled(inp)
out.sum().backward()
assert out.shape == (32, 32)
assert inp.grad is not None
assert model.weight.grad is not None, "Float8Tensor weight.grad missing"
assert model.weight.grad.shape == model.weight.shape, (
f"Float8Tensor weight.grad shape {tuple(model.weight.grad.shape)} != "
f"weight shape {tuple(model.weight.shape)}"
)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_te_linear_compile_with_fp8_output():
"""torch.compile of ``te.Linear(..., fp8_output=True)``: forward returns
a :class:`Float8Tensor`.

Exercises the output-rewrap path in
:mod:`transformer_engine.pytorch.dynamo`: the first user output is
declared ``Union[torch.Tensor, Float8Tensor]`` in ``output_annotations``,
and when an output quantizer is active the eager + fake paths must
rewrap the inner data tensors back into a ``Float8Tensor`` for the
user-facing slot.

Backward through a subclass return value is a known PyTorch
``torch.compile`` limitation (Dynamo / AOT autograd drop the
``grad_fn`` on wrapper-subclass outputs of custom ops, so
``out.sum().backward()`` errors with "element 0 of tensors does
not require grad and does not have a grad_fn"). The forward shape
+ type assertions below are sufficient to exercise the rewrap;
grad-routing on FP8 outputs under compile is left as future work.
"""
dtype = torch.bfloat16
device = "cuda"
fp8_recipe = recipe.Float8CurrentScaling()

model = te.Linear(64, 32, params_dtype=dtype, device=device)
inp = torch.randn(32, 64, dtype=dtype, device=device, requires_grad=True)

def fn(inp):
with te.autocast(recipe=fp8_recipe):
return model(inp, fp8_output=True)

torch._dynamo.reset()
compiled = torch.compile(fn, fullgraph=True)

out = compiled(inp)
assert isinstance(out, te.Float8Tensor), (
f"expected Float8Tensor output, got {type(out).__name__}"
)
assert out.shape == (32, 32)
# The compile-path reassembly rebuilds the wrapper via
# ``__tensor_unflatten__``, whose snapshot-free ``meta`` forces
# ``quantizer=None`` (a live ``ProcessGroup`` / amax-reduction group
# can't survive Dynamo guards). ``make_fake_empty`` stashes the live
# quantizer on the fake template and the reassembly helper restores it,
# so the output must keep a (non-``None``) quantizer rather than losing
# its amax-reduction group.
assert out._quantizer is not None, (
"FP8 output lost its quantizer (and thus its amax-reduction group) "
"on the torch.compile path"
)
# Dequantising outside the compiled region exercises the
# ``Float8Tensor`` machinery (scale + data + dtype all wired up
# by the rewrap) on the value returned from the compiled fn.
deq = out.dequantize()
assert deq.shape == (32, 32)
assert deq.dtype == dtype
2 changes: 2 additions & 0 deletions transformer_engine/common/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,5 @@ def _make_repr(self) -> str:
f"qfactory={self.qfactory}, "
f"backward_override={self.backward_override}"
)


28 changes: 28 additions & 0 deletions transformer_engine/pytorch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,34 @@
tex.DType.kBFloat16: torch.bfloat16,
}

# Map: TE DType *id* (Python int) -> TE DType enum. Used by
# :func:`canonicalize_te_dtype` to recover the pybind enum from its
# integer id without going through ``tex.DType(int)``, which Dynamo
# cannot trace (pybind11 enum constructor is opaque).
TE_DType_ID_To_TE = {
int(tex.DType.kByte): tex.DType.kByte,
int(tex.DType.kFloat8E4M3): tex.DType.kFloat8E4M3,
int(tex.DType.kFloat8E5M2): tex.DType.kFloat8E5M2,
int(tex.DType.kFloat4E2M1): tex.DType.kFloat4E2M1,
int(tex.DType.kInt32): tex.DType.kInt32,
int(tex.DType.kFloat32): tex.DType.kFloat32,
int(tex.DType.kFloat16): tex.DType.kFloat16,
int(tex.DType.kBFloat16): tex.DType.kBFloat16,
}


def canonicalize_te_dtype(dtype):
"""Accept either a TE ``DType`` enum or its Python ``int`` id.

Recipe state keeps dtype ids as Python ``int`` values for cheap,
trace-friendly comparisons. Quantizer objects, however, are passed to
TE's C++ bindings, which expect the pybind ``tex.DType`` enum.
"""
if isinstance(dtype, int):
return TE_DType_ID_To_TE[dtype]
return dtype


# Cache enum -> int conversions to avoid repeated PyObject lookups.
FP8FwdTensorIdx = SimpleNamespace(
GEMM1_INPUT=int(tex.FP8FwdTensors.GEMM1_INPUT),
Expand Down
Loading