generate.py: --dit-dtype {bfloat16,float16,float32} flag for ~30% faster sampling #12

Open
NoahBPeterson wants to merge 1 commit into shivampkumar:main from NoahBPeterson:dit-dtype-flag

Summary

Adds a --dit-dtype CLI flag to generate.py, defaulting to bfloat16
(current behavior, matches upstream training/inference dtype). Setting
--dit-dtype float16 recasts the three flow DiTs' transformer torso to
fp16, taking advantage of Apple Silicon's faster fp16 SDPA + matmul
kernels.
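
A minimal sketch of how such a flag could be declared with argparse. The `build_parser` helper and the help text are illustrative assumptions; the actual wiring in generate.py may differ.

```python
import argparse

# The three dtype names accepted by the flag, as listed in the PR title.
DIT_DTYPE_CHOICES = ("bfloat16", "float16", "float32")


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser showing only the --dit-dtype flag."""
    parser = argparse.ArgumentParser(description="generate.py (sketch)")
    parser.add_argument(
        "--dit-dtype",
        choices=DIT_DTYPE_CHOICES,
        default="bfloat16",  # preserves the current behavior
        help="dtype for the flow DiT transformer torso",
    )
    return parser


# With PyTorch installed, the chosen name can be mapped to a dtype object
# with getattr(torch, args.dit_dtype), e.g. "float16" -> torch.float16.
```

Keeping the value a plain string until the models are loaded means the parser itself does not need to import torch.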

Why

Apple's Metal SDPA and matmul are roughly 1.3× faster on fp16 than bf16.
On the SLat sampling phase — which is the dominant wall-clock cost on
trellis-mac (Shape SLat alone is ~35 s/step on M1 Pro) — that translates
to a meaningful end-to-end win. Mature MLX-side ports of TRELLIS.2 ship
the same flag and report 25-32% total wall-clock savings with sub-pixel
mesh deviation; visual quality is essentially identical.

Single-seed numerical parity with upstream is sacrificed (fp16 and bf16
space their representable values differently, so the same seed rounds to
slightly different intermediates), so the flag is opt-in and bf16 remains
the default.
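
To illustrate why parity is lost, the sketch below rounds one value to each format using only the standard library. It assumes "grid step" refers to the spacing of representable values (fp16 stores 10 mantissa bits, bf16 stores 7); the helper names are hypothetical and only finite inputs within fp16 range are handled.

```python
import struct


def round_to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 binary16 value (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value.

    bfloat16 is the upper 16 bits of a float32, so we round the float32
    bit pattern to 16 bits with round-to-nearest, ties-to-even.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(round_to_fp16(1.001))  # 1.0009765625 (spacing near 1.0 is 2**-10)
print(round_to_bf16(1.001))  # 1.0          (spacing near 1.0 is 2**-7)
```

The same input lands on different representable values, so activations diverge slightly from the first step even with an identical seed.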

Implementation

  • Use the model's own convert_to(dtype) method when it exposes one
    rather than a plain .to(dtype). convert_to casts the transformer
    blocks' parameters AND updates self.dtype, which the forward pass
    uses to drive manual_cast(x, self.dtype) on intermediates. A bare
    .to(dtype) would cast the parameters but leave the manual-cast
    targets at bf16 — silently undoing the speedup.
  • Fall back to .to(dtype) + m.dtype = dtype for any flow model
    variant that doesn't expose convert_to (defensive).
  • Iterate over all five flow-model keys so the flag is consistent
    across pipeline_type ∈ {512, 1024, 1024_cascade}.
  • VAE decoders are intentionally NOT recast — they already ship as fp16
    and accurate intermediates matter more there.
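
The recast logic described above can be sketched as follows. This is an assumption-laden illustration, not the code in the PR: `recast_flow_models` is a hypothetical helper, the caller supplies the model mapping and keys, and the models are duck-typed so the sketch does not import torch.

```python
from typing import Any, Iterable, List, Mapping


def recast_flow_models(
    models: Mapping[str, Any], keys: Iterable[str], dtype: Any
) -> List[str]:
    """Recast each flow model named in `keys`; return the keys that were recast."""
    recast = []
    for key in keys:
        m = models.get(key)
        if m is None:
            continue  # this pipeline_type did not load that model
        if callable(getattr(m, "convert_to", None)):
            # Preferred path: casts the transformer blocks and updates
            # m.dtype together, so manual_cast targets follow the new dtype.
            m.convert_to(dtype)
        else:
            # Defensive fallback: cast parameters, then set the
            # manual-cast target by hand.
            m.to(dtype)
            m.dtype = dtype
        recast.append(key)
    return recast
```

VAE decoders are simply left out of `keys`, which keeps them at their shipped dtype.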

Validation

Smoke test on M1 Pro 16 GB:

Pre:  {'torch.float32': 17.5M, 'torch.bfloat16': 1274.7M}
Post: {'torch.float32': 17.5M, 'torch.float16':  1274.7M}

Only the transformer torso is recast; input/output layers stay at fp32.
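
A per-dtype parameter tally like the one above could be produced with a helper along these lines. The function names are hypothetical; the sketch relies only on `model.parameters()`, `p.dtype`, and `p.numel()`, which a PyTorch `nn.Module` and its parameters provide.

```python
from collections import Counter
from typing import Any, Dict


def param_count_by_dtype(model: Any) -> Dict[str, int]:
    """Sum parameter element counts per dtype, keyed by str(dtype)."""
    counts: Counter = Counter()
    for p in model.parameters():
        counts[str(p.dtype)] += p.numel()
    return dict(counts)


def format_counts(counts: Dict[str, int]) -> Dict[str, str]:
    """Render raw counts in millions with one decimal, e.g. '17.5M'."""
    return {k: f"{v / 1e6:.1f}M" for k, v in counts.items()}
```

Calling it once before and once after the recast gives the Pre/Post comparison.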

End-to-end run with --dit-dtype float16 not yet measured here (M1 Pro
16 GB is heavily swap-bound across multiple runs); the upstream MLX port
on equivalent shapes reports 32% total wall-clock savings, and the
underlying kernel-perf delta is well-known. I'll post a measured number
when I can do a clean cool-machine run.

API impact

No breakage: new flag with the existing default behavior preserved.
Existing scripts continue to work unchanged.

Related

  • Same flag pattern shipped in MLX-side TRELLIS.2 ports as --dit-dtype /
    dit_compute_dtype="float16". Validated there as a clear win.

Lets users opt into fp16 DiT compute on Apple Silicon for ~25-32% sampling
wall-clock savings, matching the gains reported by MLX ports of
TRELLIS.2. Apple's Metal SDPA + matmul kernels run noticeably
faster on fp16 than bf16; visual quality stays essentially identical
(sub-pixel mesh deviation), but single-seed numerical parity with
upstream is sacrificed.

Implementation:

  - Use the model's own `convert_to(dtype)` method when available rather
    than a plain `.to(dtype)`. `convert_to` casts the transformer blocks'
    parameters AND updates `self.dtype`, which the forward pass uses to
    drive `manual_cast(x, self.dtype)` on intermediates. A bare
    `.to(dtype)` would cast the parameters but leave the manual-cast
    targets at bf16 — silently undoing the speedup.
  - Fall back to `.to(dtype)` + setting `m.dtype = dtype` for any model
    that doesn't expose `convert_to`.
  - Iterate over all five flow-model keys (SS + shape@512 + shape@1024 +
    tex@512 + tex@1024) so the flag is consistent across pipeline_type.
  - VAE decoders are intentionally NOT recast — they already ship as fp16
    and accurate intermediates matter more there.

Default remains bfloat16 (matches upstream training/inference dtype).

Inspect:

    Pre:  {'torch.float32': 17.5M, 'torch.bfloat16': 1.27B}
    Post: {'torch.float32': 17.5M, 'torch.float16':  1.27B}

That is, only the transformer torso is recast; input/output layers stay
at fp32. The 17.5M fp32 params are normalization / time-embed / projection
weights where we want maximum precision regardless.