Skip to content

Add opt-in FP8 DiT TensorRT engine for SA3-medium#47

Draft
ryanontheinside wants to merge 1 commit into
Stability-AI:mainfrom
ryanontheinside:feat/dit-fp8
Draft

Add opt-in FP8 DiT TensorRT engine for SA3-medium#47
ryanontheinside wants to merge 1 commit into
Stability-AI:mainfrom
ryanontheinside:feat/dit-fp8

Conversation

@ryanontheinside

Copy link
Copy Markdown

DRAFT: FP8 DiT TensorRT engine (opt-in, ~1.8x per step, ~2.5x batched throughput)

Draft. The dit_fp8.onnx artifact is not yet published to HuggingFace, so
this PR lands the build recipe and the --precision fp8 wiring for review; the
consumer paths 404 until a producer run uploads the ONNX (see Status below).

Adds an optional FP8 GEMM-trunk TensorRT engine for the SA3-medium DiT, on top
of the existing FP16-mixed recipe. The DiT step is the inner loop of the
pingpong sampler, so this is the highest-leverage place to cut latency. The
engine keeps FP32 inputs/outputs, so it is a drop-in swap for the FP16-mixed
engine at inference (sa3_trt --precision fp8), pairing an FP8 DiT with the
existing FP16-mixed decoder.

Why FP8 here (and why batching is the larger win)

The FP16-mixed DiT engine is compute-saturated: a single batched forward barely
amortizes (<=1.09x at B=4), so a ring-buffer pipeline at depth > 1 hits a flat
throughput ceiling. FP8 cuts per-row GEMM compute ~1.8x, which frees the SM
throughput the FP16 engine saturated, so batching amortizes.

The recipe (why it is more than mtq.quantize)

build_dit_fp8.py takes the published dit_fp16mixed.onnx plus a calibration
.npz and produces dit_fp8.onnx:

  1. Kahn-sort + opset-19 convert (the FP16-mixed graph is left un-toposorted by
    the island surgery, and ModelOpt/ORT reject that).
  2. ModelOpt FP8 PTQ on MatMul/Gemm only, disable_mha_qdq (attention BMMs stay
    on the FP16/FP32 path), max calibration.
  3. Restore the handful of initializers ModelOpt corrupts during preprocessing,
    and recalibrate activation scales on a Q/DQ-bypassed copy with real
    conditioning.
  4. Re-apply the FP32 islands the FP16-mixed recipe established (RMSNorm /
    Softmax / RoPE) plus the conditioning front-end, which must stay FP32 or the
    t>=0.984 timestep features flush.
  5. Per-channel weight scales along the GEMM N axis (constant-folded at build,
    free at runtime). Activations stay per-tensor (TRT requirement for FP8
    activation quant); max calibration, because the activation outliers are
    signal and percentile clipping regresses parity.

Calibration data is captured from the model's own generate() by
make_calib.py, which records the six DiT engine inputs across the pingpong
schedule. Prompts come from this repo's interface/reprompt.py Music examples,
the deployment-matched reprompt format. The .npz is a reproducible producer
artifact (gitignored, never committed).

What is in this PR

Producer (model maintainers):

  • build/make_calib.py (new): calibration capture from the checkpoint.
  • build/build_dit_fp8.py (new): the FP8 build recipe above.

Consumer:

  • build/build_from_onnx.py, build/build.py: add sa3-m-fp8 as an opt-in
    target. It is excluded from all / all-both / "build all missing" and is
    built only by explicit name, gated on the published ONNX existing.
  • scripts/sa3_trt.py, scripts/sa3_trt_core.py: --precision fp8 selection
    (FP8 DiT + FP16-mixed decoder).
  • build/README.md: producer and consumer documentation.
  • .gitignore: ignore *.calib.npz.

Testing and validation

All results below were produced on a single RTX 5090 (sm_120), TensorRT
10.16, on SA3-medium at L=646 (the latent length of a ~54 s generation). Step
times are hardware-dependent; the speedup ratios are the portable claim.

1. Clean-room reproduction of the producer chain

The full producer path was rebuilt from a clean checkout, using only the inputs
a model maintainer has: the SA3-medium checkpoint and the published
dit_fp16mixed.onnx (both pulled from HuggingFace), nothing carried over from
development. make_calib.py captured a fresh calibration set, then
build_dit_fp8.py produced the engine end to end.

Calibration capture: 376 samples (47 reprompt Music prompts x 8 sigmas) at
L=646, schedule [1.0, 0.9944, 0.9845, 0.9579, 0.8909, 0.7455, 0.5125, 0.2739], t5_hidden range [-52.33, 36.10], x range [-5.87, 5.38],
local_add_cond all zero (text-to-music), matching the expected reference
profile.

Build stages, all clean:

  • toposort + opset 17 -> 19.
  • FP8 PTQ: 619 quantizable nodes, MHA Q/DQ disabled, max calibration over the
    376 samples.
  • repair: the two known ModelOpt-corrupted initializers restored
    (to_timestep_embed.2.bias 6060 -> 0.12, layers.22.to_local_embed
    3286 -> 1.30), 417 Q/DQ pairs bypassed, 834 activation scales recalibrated,
    0 mask-path pairs.
  • FP32 islands (hybrid): 2021 island nodes (96 Softmax) plus the conditioning
    front-end.
  • per-channel weights: 220 weight pairs upgraded, 440 per-channel scale vectors
    verified.
  • TensorRT STRONGLY_TYPED compile: 1494 MB engine.

The resulting engine deserializes and exposes the expected six FP32 inputs
(x, t, t5_hidden, t5_mask, seconds_total, local_add_cond) and the
velocity FP32 output, with dynamic latent length.

2. Numerical parity vs the FP16-mixed engine

Parity was measured by feeding the captured DiT inputs through both the FP8 and
FP16-mixed engines at batch 1 and comparing outputs.

Single-step latent agreement (x + dt * v, the quantity that actually advances
the sampler), over all 376 samples, by sigma:

sigma min cos mean cos
1.0000 1.00000 1.00000
0.9944 0.99999 1.00000
0.9845 0.99979 0.99998
0.9579 0.99985 0.99992
0.8909 0.99955 0.99973
0.7455 0.99893 0.99945
0.5125 0.99921 0.99948
0.2739 0.99824 0.99912

Worst single-step latent cosine 0.99824, mean 0.99971.

Compounded agreement was measured with an 8-step deterministic euler rollout
from a fixed noise seed, one prompt's conditioning held constant across the
schedule, both engines warmed once before measuring (the FP8 engine's first
in-process rollout carries a transient that settles afterward). The final-latent
cosine of the FP8 rollout vs the FP16-mixed rollout was the metric. Using the
same procedure and calibration set the recipe was originally tuned on, this
reproduces 0.97600 compounded and 0.97807 worst single-step velocity cosine.
The clean-room rebuild scores 0.963 compounded and 0.978 single-step on that
prompt. The compounded figure is a single-prompt, single-rollout statistic and
is sensitive to which prompt drives the rollout (the ranking between two FP8
engines flips across prompts), so it is reported as a range, ~0.96 to 0.976,
rather than a single number.

For reference, the FP16-mixed engine itself scores only ~0.998 compounded vs PT
eager, so cosine is a guide rather than a gate; final judgement is by ear. BF16
was tried and rejected earlier: it compounds error over the 8 steps
(final-latent cos ~0.81) and is audibly degraded. Under the stochastic pingpong
sampler the FP8 engine yields a different but comparable sample.

3. Per-step latency (B=1, L=646)

Median of 200 timed steps on a real calibration sample:

engine step time speedup
FP16-mixed 19.9 ms 1.0x
FP8 11.2 ms 1.77x

4. Batched throughput, depths 1..8

For each engine the input batch dimension was made dynamic, a STRONGLY_TYPED
B=1..8 engine was compiled, each batched row was validated against the serial
B=1 engine, then the median batched step time was benched against serial
dispatch. gens/s is the steady-state generation rate of an 8-step pingpong
ring buffer running at that batch (depth).

FP16-mixed:

B batched ms per-slot ms speedup gens/s
1 28.4 28.4 0.74x 4.4
2 42.7 21.4 0.98x 5.9
3 60.5 20.2 1.04x 6.2
4 76.9 19.2 1.09x 6.5
5 96.4 19.3 1.09x 6.5
6 125.5 20.9 1.00x 6.0
7 144.1 20.6 1.02x 6.1
8 164.1 20.5 1.02x 6.1

(serial B=1 reference engine: 20.9 ms)

FP8:

B batched ms per-slot ms speedup gens/s
1 11.0 11.0 0.96x 11.3
2 17.1 8.6 1.24x 14.6
3 24.0 8.0 1.33x 15.6
4 30.4 7.6 1.40x 16.5
5 37.6 7.5 1.41x 16.6
6 46.0 7.7 1.38x 16.3
7 53.2 7.6 1.40x 16.5
8 61.9 7.7 1.37x 16.2

(serial B=1 reference engine: 10.6 ms)

Reading these together:

  • FP16-mixed is compute-saturated: per-slot stays ~19 to 21 ms regardless of
    batch, batching buys <=1.09x, and the ring-buffer ceiling is ~6.5 gens/s.
  • FP8 exceeds the ceiling: per-slot drops to ~7.6 ms and batching amortizes up to
    1.41x at B=4..5, lifting the ceiling to ~16.5 gens/s.
  • The per-slot FP8 advantage holds ~2.5x across the batch (2.57x at B=1, 2.53x
    at B=4, 2.65x at B=8), and the end-to-end pipeline throughput a depth > 1
    ring buffer actually hits rises from ~6.5 to ~16.5 gens/s, about 2.5x. The
    headline 1.8x is the B=1 step; under the batching the pipeline uses, FP8
    compounds it.

5. Batching correctness

Each batched row was validated against the serial B=1 engine on identical
inputs. The FP16-mixed batched engine matched its serial engine at 0.99994
(worst row). The FP8 batched engine matched at a uniform 0.992 to 0.995 across
all rows: this is FP8 kernel-tactic variance between two separate engine builds
(the per-tensor activation scales are the same), not a batch-size
specialization bug, which the uniformity across rows confirms. Timing is
unaffected.


Status: ONNX not yet on HF

dit_fp8.onnx + dit_fp8.onnx.data are not in the model repo yet, so
build_from_onnx.py sa3-m-fp8 and sa3_trt --precision fp8 will 404 until a
producer run uploads them under exactly those filenames. The consumer wiring
and --precision fp8 plumbing land here so they can be reviewed; the artifact
upload is the follow-up. This is why sa3-m-fp8 is opt-in and kept out of the
default all build paths.

Usage

Producer:

# 1. capture calibration data from the checkpoint
python build/make_calib.py \
  --model-config <MODELS_ROOT>/SA3-M-hf/model_config.json \
  --checkpoint   <MODELS_ROOT>/SA3-M-hf/model.safetensors \
  --out          sa3-m.calib.npz

# 2. build the FP8 ONNX + engine
python build/build_dit_fp8.py \
  --input  <HF_REPO>/onnx/sa3-m/dit_fp16mixed.onnx \
  --calib  sa3-m.calib.npz \
  --onnx   <HF_REPO>/onnx/sa3-m/dit_fp8.onnx \
  --engine ../models/<arch>/sa3-m/dit_fp8.trt

Requires nvidia-modelopt + onnxruntime-gpu on top of the consumer deps.

Consumer (once the ONNX is published):

python build/build_from_onnx.py sa3-m-fp8   # STRONGLY_TYPED compile, no ModelOpt
# then: sa3_trt --precision fp8

Producer recipe (build_dit_fp8.py) builds a ModelOpt FP8 GEMM-trunk DiT on
top of the FP16-mixed graph: FP8 PTQ on MatMul/Gemm, initializer repair plus
activation-scale recalibration, re-applied FP32 islands (RMSNorm/Softmax/RoPE
plus the conditioning front-end), and per-channel weight scales. make_calib.py
captures calibration inputs from the model's own pingpong generate(), pulling
prompts from interface/reprompt.py. ~1.8x faster steps than FP16-mixed at B=1,
amortizing further under batched dispatch.

Consumer wiring adds sa3-m-fp8 as an opt-in target (excluded from all/all-both
and 'build all missing'; built only by explicit name, gated on the published
ONNX) and a --precision fp8 selection that pairs the FP8 DiT with the
FP16-mixed decoder, guarded so non-medium DiTs cannot request it.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant