Add opt-in FP8 DiT TensorRT engine for SA3-medium #47
Draft
ryanontheinside wants to merge 1 commit into
Producer recipe (build_dit_fp8.py) builds a ModelOpt FP8 GEMM-trunk DiT on top of the FP16-mixed graph: FP8 PTQ on MatMul/Gemm, initializer repair plus activation-scale recalibration, re-applied FP32 islands (RMSNorm/Softmax/RoPE plus the conditioning front-end), and per-channel weight scales. make_calib.py captures calibration inputs from the model's own pingpong generate(), pulling prompts from interface/reprompt.py. ~1.8x faster steps than FP16-mixed at B=1, amortizing further under batched dispatch. Consumer wiring adds sa3-m-fp8 as an opt-in target (excluded from all/all-both and 'build all missing'; built only by explicit name, gated on the published ONNX) and a --precision fp8 selection that pairs the FP8 DiT with the FP16-mixed decoder, guarded so non-medium DiTs cannot request it.
DRAFT: FP8 DiT TensorRT engine (opt-in, ~1.8x per step, ~2.5x batched throughput)
Adds an optional FP8 GEMM-trunk TensorRT engine for the SA3-medium DiT, on top
of the existing FP16-mixed recipe. The DiT step is the inner loop of the
pingpong sampler, so this is the highest-leverage place to cut latency. The
engine keeps FP32 inputs/outputs, so it is a drop-in swap for the FP16-mixed
engine at inference (sa3_trt --precision fp8), pairing an FP8 DiT with the
existing FP16-mixed decoder.
Why FP8 here (and why batching is the larger win)
The FP16-mixed DiT engine is compute-saturated: a single batched forward barely
amortizes (<=1.09x at B=4), so a ring-buffer pipeline at depth > 1 hits a flat
throughput ceiling. FP8 cuts per-row GEMM compute ~1.8x, which frees the SM
throughput the FP16 engine saturated, so batching amortizes.
The recipe (why it is more than mtq.quantize)
build_dit_fp8.py takes the published dit_fp16mixed.onnx plus a calibration .npz and produces dit_fp8.onnx:
- [...] the island surgery, and ModelOpt/ORT reject that).
- [...] disable_mha_qdq (attention BMMs stay on the FP16/FP32 path), max calibration.
- [...] and recalibrate activation scales on a Q/DQ-bypassed copy with real conditioning.
- [...] Softmax / RoPE) plus the conditioning front-end, which must stay FP32 or the t>=0.984 timestep features flush.
- [...] free at runtime). Activations stay per-tensor (TRT requirement for FP8 activation quant); max calibration, because the activation outliers are signal and percentile clipping regresses parity.
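The scale choices described above can be illustrated with a small NumPy sketch: one max-calibrated scale per output channel for a weight matrix, and one per-tensor scale for activations, with a percentile variant shown only for contrast. It assumes the FP8 E4M3 format (largest finite value 448) and a hypothetical (out_channels, in_channels) weight layout. The function names are illustrative; this is not the ModelOpt implementation or the code in build_dit_fp8.py.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def per_channel_weight_scales(weight: np.ndarray) -> np.ndarray:
    """One max-calibrated scale per output channel (row) of a 2-D weight."""
    amax = np.abs(weight).max(axis=1)
    # Guard against all-zero channels so the scale is never zero.
    return np.maximum(amax, 1e-12) / FP8_E4M3_MAX


def per_tensor_activation_scale(samples: list[np.ndarray]) -> float:
    """Max calibration: a single scale from the largest |value| seen."""
    amax = max(float(np.abs(s).max()) for s in samples)
    return amax / FP8_E4M3_MAX


def percentile_activation_scale(samples: list[np.ndarray], pct: float) -> float:
    """Percentile calibration, for contrast: values above the chosen
    percentile would saturate when quantized with this scale."""
    flat = np.concatenate([np.abs(s).ravel() for s in samples])
    return float(np.percentile(flat, pct)) / FP8_E4M3_MAX


rng = np.random.default_rng(0)
weight = rng.normal(size=(4, 16)).astype(np.float32)
acts = [rng.normal(size=(8, 16)).astype(np.float32) for _ in range(3)]
acts[0][0, 0] = 50.0  # an outlier that carries signal

w_scales = per_channel_weight_scales(weight)
a_max = per_tensor_activation_scale(acts)
a_pct = percentile_activation_scale(acts, 99.0)

# With the max scale the outlier sits at the edge of the FP8 range;
# with the percentile scale it falls outside the range and would be clipped.
print(w_scales.shape, 50.0 / a_max, 50.0 / a_pct > FP8_E4M3_MAX)
```

The per-channel scales cost nothing extra per activation because they can be folded into the weight dequantization, while a per-tensor activation scale keeps the runtime quantize step a single multiply.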
Calibration data is captured from the model's own generate() by make_calib.py,
which records the six DiT engine inputs across the pingpong schedule. Prompts
come from this repo's interface/reprompt.py Music examples, the
deployment-matched reprompt format. The .npz is a reproducible producer
artifact (gitignored, never committed).
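A minimal sketch of what such a capture file could look like follows. The six input names are the ones this PR lists; the array shapes, the stack_calibration and fake_capture helpers, and the one-array-per-input layout are assumptions for illustration, since the actual make_calib.py format is not shown here.

```python
import io
import numpy as np

# The six DiT engine inputs named in this PR.
DIT_INPUTS = ("x", "t", "t5_hidden", "t5_mask", "seconds_total", "local_add_cond")


def stack_calibration(records: list[dict]) -> dict:
    """Stack per-step captures into one array per input name."""
    missing = [k for r in records for k in DIT_INPUTS if k not in r]
    if missing:
        raise KeyError(f"capture is missing inputs: {sorted(set(missing))}")
    return {k: np.stack([r[k] for r in records]) for k in DIT_INPUTS}


def fake_capture(sigma: float, rng: np.random.Generator) -> dict:
    """Stand-in for one recorded DiT call; shapes are illustrative only."""
    return {
        "x": rng.normal(size=(4, 16)).astype(np.float32),
        "t": np.array([sigma], dtype=np.float32),
        "t5_hidden": rng.normal(size=(8, 32)).astype(np.float32),
        "t5_mask": np.ones((8,), dtype=np.float32),
        "seconds_total": np.array([54.0], dtype=np.float32),
        "local_add_cond": np.zeros((4, 16), dtype=np.float32),
    }


rng = np.random.default_rng(0)
sigmas = [1.0, 0.9944, 0.9845, 0.9579, 0.8909, 0.7455, 0.5125, 0.2739]
records = [fake_capture(s, rng) for _prompt in range(2) for s in sigmas]

buffer = io.BytesIO()  # in-memory file; a real run would write a .npz path
np.savez(buffer, **stack_calibration(records))
buffer.seek(0)
calib = np.load(buffer)
print({name: calib[name].shape for name in DIT_INPUTS})
```

Storing one stacked array per input keeps the sample index aligned across all six inputs, which is what a calibration reader needs when it replays samples one at a time.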
What is in this PR
Producer (model maintainers):
- build/make_calib.py (new): calibration capture from the checkpoint.
- build/build_dit_fp8.py (new): the FP8 build recipe above.
Consumer:
- build/build_from_onnx.py, build/build.py: add sa3-m-fp8 as an opt-in target. It is excluded from all / all-both / "build all missing" and is built only by explicit name, gated on the published ONNX existing.
- scripts/sa3_trt.py, scripts/sa3_trt_core.py: --precision fp8 selection (FP8 DiT + FP16-mixed decoder).
- build/README.md: producer and consumer documentation.
- .gitignore: ignore *.calib.npz.
Testing and validation
All results below were produced on a single RTX 5090 (sm_120), TensorRT
10.16, on SA3-medium at L=646 (the latent length of a ~54 s generation). Step
times are hardware-dependent; the speedup ratios are the portable claim.
1. Clean-room reproduction of the producer chain
The full producer path was rebuilt from a clean checkout, using only the inputs
a model maintainer has: the SA3-medium checkpoint and the published
dit_fp16mixed.onnx (both pulled from HuggingFace), nothing carried over from
development. make_calib.py captured a fresh calibration set, then
build_dit_fp8.py produced the engine end to end.
Calibration capture: 376 samples (47 reprompt Music prompts x 8 sigmas) at
L=646, schedule [1.0, 0.9944, 0.9845, 0.9579, 0.8909, 0.7455, 0.5125, 0.2739],
t5_hidden range [-52.33, 36.10], x range [-5.87, 5.38], local_add_cond all zero
(text-to-music), matching the expected reference profile.
Build stages, all clean:
- [...] max calibration over the 376 samples.
- [...] (to_timestep_embed.2.bias 6060 -> 0.12, layers.22.to_local_embed 3286 -> 1.30), 417 Q/DQ pairs bypassed, 834 activation scales recalibrated, 0 mask-path pairs.
- [...] front-end.
- [...] verified.
The resulting engine deserializes and exposes the expected six FP32 inputs
(x, t, t5_hidden, t5_mask, seconds_total, local_add_cond) and the velocity FP32
output, with dynamic latent length.
2. Numerical parity vs the FP16-mixed engine
Parity was measured by feeding the captured DiT inputs through both the FP8 and
FP16-mixed engines at batch 1 and comparing outputs.
Single-step latent agreement (x + dt * v, the quantity that actually advances
the sampler), over all 376 samples, by sigma:
Worst single-step latent cosine 0.99824, mean 0.99971.
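As a sketch of how the single-step latent metric differs from a raw velocity cosine, the snippet below compares x + dt * v for a reference and a perturbed velocity. The latent shape, the noise level, and the sign convention for dt (taken here as the first schedule gap, 0.9944 - 1.0) are assumptions, and the random data stands in for real engine outputs.

```python
import numpy as np


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two arrays, flattened and accumulated in float64."""
    a = a.ravel().astype(np.float64)
    b = b.ravel().astype(np.float64)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def single_step_latent_cosine(x, dt, v_ref, v_test) -> float:
    """Cosine between the latents each engine would produce after one step."""
    return cosine(x + dt * v_ref, x + dt * v_test)


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 646))  # hypothetical (channels, latent length)
v_ref = rng.normal(size=x.shape)
v_test = v_ref + 0.05 * rng.normal(size=x.shape)  # simulated quantization error
dt = -0.0056  # assumed first schedule gap, 0.9944 - 1.0

latent_cos = single_step_latent_cosine(x, dt, v_ref, v_test)
velocity_cos = cosine(v_ref, v_test)
print(latent_cos, velocity_cos)
```

Because the shared x term dominates a small step, the latent cosine is always at least as forgiving as the velocity cosine, which is why the two are reported separately.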
Compounded agreement was measured with an 8-step deterministic Euler rollout
from a fixed noise seed, one prompt's conditioning held constant across the
schedule, both engines warmed once before measuring (the FP8 engine's first
in-process rollout carries a transient that settles afterward). The final-latent
cosine of the FP8 rollout vs the FP16-mixed rollout was the metric. Using the
same procedure and calibration set the recipe was originally tuned on, this
reproduces 0.97600 compounded and 0.97807 worst single-step velocity cosine.
The clean-room rebuild scores 0.963 compounded and 0.978 single-step on that
prompt. The compounded figure is a single-prompt, single-rollout statistic and
is sensitive to which prompt drives the rollout (the ranking between two FP8
engines flips across prompts), so it is reported as a range, ~0.96 to 0.976,
rather than a single number.
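The compounded measurement can be sketched as two deterministic Euler rollouts from the same starting noise, compared by final-latent cosine. The sketch assumes the update x <- x + (sigma_next - sigma) * v and a terminal sigma of 0.0 appended to the eight listed sigmas; the toy velocity fields are stand-ins for the two engines, not the real DiT.

```python
import numpy as np

# The eight sigmas from the calibration schedule, plus an assumed terminal 0.0.
SIGMAS = [1.0, 0.9944, 0.9845, 0.9579, 0.8909, 0.7455, 0.5125, 0.2739, 0.0]


def euler_rollout(velocity_fn, x0: np.ndarray, sigmas=SIGMAS) -> np.ndarray:
    """Deterministic Euler: x <- x + (sigma_next - sigma) * v(x, sigma)."""
    x = x0.copy()
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        x = x + (sigma_next - sigma) * velocity_fn(x, sigma)
    return x


def cosine(a, b) -> float:
    a, b = a.ravel(), b.ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


rng = np.random.default_rng(0)
target = rng.normal(size=(8, 32))          # toy "clean" latent
noise_dir = rng.normal(size=target.shape)  # fixed perturbation direction


def v_ref(x, sigma):
    # Toy velocity field pointing from the target toward the current state.
    return (x - target) / max(sigma, 1e-3)


def v_test(x, sigma):
    # Same field with a small error, standing in for quantization.
    return v_ref(x, sigma) * 1.01 + 0.01 * noise_dir


x0 = rng.normal(size=target.shape)  # fixed noise shared by both rollouts
final_ref = euler_rollout(v_ref, x0)
final_test = euler_rollout(v_test, x0)
print(cosine(final_ref, final_test))
```

Each step feeds its own error into the next velocity evaluation, so the final-latent cosine reflects accumulated drift rather than the per-step error alone.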
For reference, the FP16-mixed engine itself scores only ~0.998 compounded vs PT
eager, so cosine is a guide rather than a gate; final judgement is by ear. BF16
was tried and rejected earlier: it compounds error over the 8 steps
(final-latent cos ~0.81) and is audibly degraded. Under the stochastic pingpong
sampler the FP8 engine yields a different but comparable sample.
3. Per-step latency (B=1, L=646)
Median of 200 timed steps on a real calibration sample:
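A timing loop of the kind described above might look like the following. The helper name, the warmup count, and the toy workload are assumptions; for a real TensorRT engine the timed callable would have to synchronize the CUDA stream before returning, otherwise only the launch cost is measured.

```python
import statistics
import time


def median_step_ms(step_fn, warmup: int = 10, iters: int = 200) -> float:
    """Median wall-clock time of step_fn in milliseconds.

    step_fn must block until the work is finished; for a GPU engine that
    means synchronizing the stream inside step_fn before it returns.
    """
    for _ in range(warmup):
        step_fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        step_fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def toy_step():
    # Placeholder workload standing in for one DiT engine step.
    sum(i * i for i in range(10_000))


step_ms = median_step_ms(toy_step)
print(f"median step: {step_ms:.3f} ms")
```

The median is used rather than the mean so that occasional scheduler or clock-ramp spikes do not skew the reported step time.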
4. Batched throughput, depths 1..8
For each engine the input batch dimension was made dynamic, a STRONGLY_TYPED
B=1..8 engine was compiled, each batched row was validated against the serial
B=1 engine, then the median batched step time was benched against serial
dispatch.
gens/s is the steady-state generation rate of an 8-step pingpong ring buffer
running at that batch (depth).
FP16-mixed:
(serial B=1 reference engine: 20.9 ms)
FP8:
(serial B=1 reference engine: 10.6 ms)
Reading these together:
- [...] batch, batching buys <=1.09x, and the ring-buffer ceiling is ~6.5 gens/s.
- [...] 1.41x at B=4..5, lifting the ceiling to ~16.5 gens/s.
- [...] at B=4, 2.65x at B=8), and the end-to-end pipeline throughput a depth > 1 ring buffer actually hits rises from ~6.5 to ~16.5 gens/s, about 2.5x. The headline 1.8x is the B=1 step; under the batching the pipeline uses, FP8 compounds it.
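The quoted ceilings can be roughly cross-checked from the serial step times and the batching amortization factors. The model below, gens/s = amortization / (steps x serial step time), is an assumption about how the figures relate rather than the benchmark script, and it lands close to, not exactly on, the rounded numbers above.

```python
STEPS_PER_GENERATION = 8  # 8-step pingpong schedule


def gens_per_second(serial_step_ms: float, amortization: float) -> float:
    """Steady-state rate of a ring buffer whose batched step is
    `amortization` times more efficient per row than serial B=1 dispatch."""
    serial_rate = 1000.0 / (STEPS_PER_GENERATION * serial_step_ms)
    return serial_rate * amortization


fp16_ceiling = gens_per_second(serial_step_ms=20.9, amortization=1.09)
fp8_ceiling = gens_per_second(serial_step_ms=10.6, amortization=1.41)

print(f"FP16-mixed ceiling: {fp16_ceiling:.1f} gens/s")
print(f"FP8 ceiling:        {fp8_ceiling:.1f} gens/s")
print(f"pipeline speedup:   {fp8_ceiling / fp16_ceiling:.2f}x")
```

Under this model the pipeline gain is the product of the per-step speedup and the ratio of the two amortization factors, which is why it exceeds the B=1 figure.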
5. Batching correctness
Each batched row was validated against the serial B=1 engine on identical
inputs. The FP16-mixed batched engine matched its serial engine at 0.99994
(worst row). The FP8 batched engine matched at a uniform 0.992 to 0.995 across
all rows: this is FP8 kernel-tactic variance between two separate engine builds
(the per-tensor activation scales are the same), not a batch-size
specialization bug, which the uniformity across rows confirms. Timing is
unaffected.
Status: ONNX not yet on HF
dit_fp8.onnx + dit_fp8.onnx.data are not in the model repo yet, so
build_from_onnx.py sa3-m-fp8 and sa3_trt --precision fp8 will 404 until a
producer run uploads them under exactly those filenames. The consumer wiring
and --precision fp8 plumbing land here so they can be reviewed; the artifact
upload is the follow-up. This is why sa3-m-fp8 is opt-in and kept out of the
default all build paths.
Usage
Producer:
Requires nvidia-modelopt + onnxruntime-gpu on top of the consumer deps.
Consumer (once the ONNX is published):