LangSplatV2: GPU memory and runtime improvements#60
Merged
Conversation
Two independent, results-preserving optimizations to mask_utils.py, aimed at high-resolution (e.g. 4K) inputs where mask post-processing dominated peak GPU RAM and wall-clock time. mask_nms: bound peak GPU memory (~4-5x lower at high mask counts) - The pairwise IoU / containment matrices are now computed on masks downsampled so their longest side is <= nms_max_dim (default 1024), instead of at full resolution. These matrices only drive thresholding decisions and the returned indices point back at the full-resolution masks, so the kept set is essentially unchanged (borderline-only flips). - Downsampling is chunked (interpolate needs float input) so the transient full-res float buffer never materialises; this was the dominant ~33MB/mask allocation (tens of GB for a few thousand masks). - Clamp area/union denominators so masks that downsample to zero area can't produce NaNs; no-op for non-zero areas. nms_max_dim=None restores the exact original behaviour. Threaded through masks_update. postprocess_small_regions: ~4.4x faster per frame - _clean_single_mask now runs cv2.connectedComponentsWithStats on each mask's bounding-box crop (padded 1px) rather than the full frame, avoiding scanning the mostly-empty image. The 1px pad keeps the exterior background a single border-touching component, so hole/island labelling is identical; the original "islands" full-frame behaviour (keeps background label 0) is reproduced exactly. Verified byte-for-byte equal to the previous output. Signed-off-by: Jonathan Swartz <jonathan@jswartz.info>
…r config Training previously built dense [B, H, W, 512] feature maps and computed the loss over the full image, even though only the masked pixels carry features (~22% on safety_park). Profiling one step showed the loss stage alone spiking peak allocated memory from ~11 GB to ~33 GB. Sparse loss path (numerically identical to the dense loss): - loss.py: add calculate_langsplatv2_loss_sparse(), operating on the flat valid-pixel set and normalising by mask.numel() so values and gradients match the dense loss exactly. The cosine_loss_valid diagnostic is gated behind compute_valid_metric (it was previously gathered every step but only logged occasionally). - dataset.py: add build_sparse_gt_features(), gathering GT features only at valid pixels (ordered to match weight_maps[mask]), with the same out-of-bounds handling as build_feature_map. - trainer.py: the training step now renders the full weight maps but decodes only the valid pixels (weight_maps[mask] @ codebook) and uses the sparse loss. Dense maps are reconstructed only on logging steps when log_test_images is set; eval() is unchanged. Verified on a real frame: total_loss/cosine_loss bit-identical, gradients match to float precision (logits ~7e-12, codebooks ~1e-7); peak allocated 38.7 GB -> 9.8 GB. Allocator fragmentation fix: - train_langsplatv2.py: default PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (via setdefault, before CUDA init). The variable per-image valid-pixel count makes the sparse tensors different sizes each step, fragmenting the default caching allocator (reserved memory grew ~17 GB -> ~44 GB despite ~13 GB peak allocated). Expandable segments keep reserved tracking actual usage. Net: real training peak ~38-43 GB -> ~18 GB, plus a per-step speedup (decode + loss now run over ~4.5x fewer elements). No change to training results. Signed-off-by: Jonathan Swartz <jonathan@jswartz.info>
The optimizer step was ~5% of each training iteration, dominated by updating the large per-Gaussian logits tensor. Switch to torch.optim.Adam(fused=True) when training on CUDA -- it is numerically equivalent but roughly halves the step time (5.5 ms -> 2.6 ms, ~3% faster iterations overall). Guarded by fused=(device.type == "cuda") in both LangSplatV2Trainer.new and from_state_dict so CPU/eval paths and checkpoint restore are unaffected. Verified the loss trajectory is unchanged. Signed-off-by: Jonathan Swartz <jonathan@jswartz.info>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Performance improvements to the LangSplatV2 example, reducing both GPU memory and runtime in mask preprocessing and feature training at high resolution. No changes to results — verified numerically equivalent throughout.
SAM mask post-processing (
scene_transforms/mask_utils.py)mask_nms: compute the pairwise IoU/containment matrices on masks downsampled to a capped longest side (default 1024, chunked so the full-resolution float tensor is never materialized). The dominantmasks_flat @ masks_flat.Tallocation no longer scales with full resolution — ~4–5× lower peak GPU memory at high mask counts.nms_max_dim=Nonerestores exact original behavior.postprocess_small_regions: run connected-components on each mask's bounding-box crop instead of the full frame — ~4.4× faster on 4K frames, byte-for-byte identical output (verified).Feature training memory (
loss.py,training/dataset.py,training/trainer.py)[B, H, W, 512]maps (only ~22% of pixels carry features). Newcalculate_langsplatv2_loss_sparse+build_sparse_gt_features; the training step decodes only valid pixels. Dense maps are reconstructed only for optional image logging. Peak allocated memory 38.7 GB → 9.8 GB; loss and gradients match the dense path to float precision.Training memory and speed (
scripts/train_langsplatv2.py,training/trainer.py)PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True(viasetdefault, before CUDA init): the per-image variable valid-pixel count fragmented the caching allocator (reserved memory grew ~17 GB → ~44 GB despite ~13 GB allocated). Brings real training peak ~38–43 GB → ~18 GB.Combined: SAM mask generation uses far less GPU RAM and runs faster; training peak memory is ~2.5× lower and iterations are modestly faster. Remaining training time is ~84% Gaussian rasterizer (render + backward), which is the floor short of a resolution/quality tradeoff.