generate.py: --dit-dtype {bfloat16,float16,float32} flag for ~30%% faster sampling#12
Open
NoahBPeterson wants to merge 1 commit into
Open
generate.py: --dit-dtype {bfloat16,float16,float32} flag for ~30%% faster sampling#12NoahBPeterson wants to merge 1 commit into
NoahBPeterson wants to merge 1 commit into
Conversation
Lets users opt into fp16 DiT compute on Apple Silicon for ~25-32% sampling
wall-clock savings, matching the proven win from upstream-ish ports of
TRELLIS.2 to MLX. Apple's Metal SDPA + matmul kernels run noticeably
faster on fp16 than bf16; visual quality stays essentially identical
(sub-pixel mesh deviation), but single-seed numerical parity with
upstream is sacrificed.
Implementation:
- Use the model's own `convert_to(dtype)` method when available rather
than a plain `.to(dtype)`. `convert_to` casts the transformer blocks'
parameters AND updates `self.dtype`, which the forward pass uses to
drive `manual_cast(x, self.dtype)` on intermediates. A bare
`.to(dtype)` would cast the parameters but leave the manual-cast
targets at bf16 — silently undoing the speedup.
- Falls back to `.to(dtype)` + setting `m.dtype = dtype` for any model
that doesn't expose `convert_to`.
- Iterates over all five flow-model keys (SS + shape@512 + shape@1024 +
tex@512 + tex@1024) so the flag is consistent across pipeline_type.
- VAE decoders are intentionally NOT recast — they already ship as fp16
and accurate intermediates matter more there.
Default remains bfloat16 (matches upstream training/inference dtype).
Inspect:
Pre: {'torch.float32': 17.5M, 'torch.bfloat16': 1.27B}
Post: {'torch.float32': 17.5M, 'torch.float16': 1.27B}
That is, only the transformer torso is recast; input/output layers stay
at fp32. The 17.5M fp32 params are normalization / time-embed / projection
weights where we want maximum precision regardless.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a
--dit-dtypeCLI flag togenerate.py, defaulting tobfloat16(current behavior, matches upstream training/inference dtype). Setting
--dit-dtype float16recasts the three flow DiTs' transformer torso tofp16, taking advantage of Apple Silicon's faster fp16 SDPA + matmul
kernels.
Why
Apple's Metal SDPA and matmul are roughly 1.3× faster on fp16 than bf16.
On the SLat sampling phase — which is the dominant wall-clock cost on
trellis-mac (Shape SLat alone is ~35 s/step on M1 Pro) — that translates
to a meaningful end-to-end win. Mature MLX-side ports of TRELLIS.2 ship
the same flag and report 25-32% total wall-clock savings with sub-pixel
mesh deviation; visual quality is essentially identical.
Single-seed numerical parity with upstream is sacrificed (fp16 grid step
differs from bf16), so the flag is opt-in and bf16 remains the default.
Implementation
convert_to(dtype)method when it exposes onerather than a plain
.to(dtype).convert_tocasts the transformerblocks' parameters AND updates
self.dtype, which the forward passuses to drive
manual_cast(x, self.dtype)on intermediates. A bare.to(dtype)would cast the parameters but leave the manual-casttargets at bf16 — silently undoing the speedup.
.to(dtype)+m.dtype = dtypefor any flow modelvariant that doesn't expose
convert_to(defensive).across pipeline_type ∈
{512, 1024, 1024_cascade}.and accurate intermediates matter more there.
Validation
Smoke test on M1 Pro 16 GB:
Only the transformer torso is recast; input/output layers stay at fp32.
End-to-end run with
--dit-dtype float16not yet measured here (M1 Pro16 GB is heavily swap-bound across multiple runs); the upstream MLX port
on equivalent shapes reports 32% total wall-clock savings, and the
underlying kernel-perf delta is well-known. I'll post a measured number
when I can do a clean cool-machine run.
API impact
No breakage: new flag with the existing default behavior preserved.
Existing scripts continue to work unchanged.
Related
--dit-dtype/dit_compute_dtype="float16". Validated there as a clear win.