From 5fd7eab62b91efdaf2f18f70526f42741c20efa0 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 4 Jun 2026 14:26:29 +1200 Subject: [PATCH 1/3] LangSplatV2: cut GPU memory and CPU time in SAM mask post-processing Two independent optimizations to mask_utils.py (the small-region change is exactly results-preserving; the NMS change is results-preserving up to borderline threshold flips, and exactly so with nms_max_dim=None), aimed at high-resolution (e.g. 4K) inputs where mask post-processing dominated peak GPU RAM and wall-clock time. mask_nms: bound peak GPU memory (~4-5x lower at high mask counts) - The pairwise IoU / containment matrices are now computed on masks downsampled so their longest side is <= nms_max_dim (default 1024), instead of at full resolution. These matrices only drive thresholding decisions and the returned indices point back at the full-resolution masks, so the kept set is essentially unchanged (borderline-only flips). - Downsampling is chunked (interpolate needs float input) so the transient full-res float buffer never materialises; this was the dominant ~33MB/mask allocation (tens of GB for a few thousand masks). - Clamp area/union denominators so masks that downsample to zero area can't produce NaNs; no-op for non-zero areas. nms_max_dim=None restores the exact original behaviour. Threaded through masks_update. postprocess_small_regions: ~4.4x faster per frame - _clean_single_mask now runs cv2.connectedComponentsWithStats on each mask's bounding-box crop (padded 1px) rather than the full frame, avoiding scanning the mostly-empty image. The 1px pad keeps the exterior background a single border-touching component, so hole/island labelling is identical; the original "islands" full-frame behaviour (keeps background label 0) is reproduced exactly. Verified byte-for-byte equal to the previous output. 
Signed-off-by: Jonathan Swartz --- .../scene_transforms/mask_utils.py | 124 +++++++++++++++--- 1 file changed, 108 insertions(+), 16 deletions(-) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py index fc1145b..053d0ef 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py @@ -56,21 +56,57 @@ def _clean_single_mask( ) -> Tuple[np.ndarray, float, List[float]]: """Clean one mask and compute its bounding box. Thread-safe (CPU-only). + The connected-component work runs on the mask's bounding-box crop (padded + by 1px) rather than the full frame. This is exactly equivalent to cleaning + the full-resolution mask: the 1px pad guarantees the exterior background is + a single border-touching component, so the hole/island labelling is + identical, but it avoids scanning the (mostly empty) full image, which + dominates runtime at high resolution. + Returns: (cleaned_seg, score, box_xyxy) where score is 1.0 if the mask was unchanged and 0.0 if it was modified. 
""" - seg = seg_raw.astype(bool) - seg, changed_holes = remove_small_regions(seg, min_area, mode="holes") - unchanged = not changed_holes - seg, changed_islands = remove_small_regions(seg, min_area, mode="islands") - unchanged = unchanged and not changed_islands - - ys, xs = np.where(seg) - if len(xs) == 0: - box = [0.0, 0.0, 0.0, 0.0] + seg_full = seg_raw.astype(bool) + h, w = seg_full.shape + + rows = np.any(seg_full, axis=1) + if not rows.any(): + return seg_full, 1.0, [0.0, 0.0, 0.0, 0.0] + cols = np.any(seg_full, axis=0) + row_idx = np.where(rows)[0] + col_idx = np.where(cols)[0] + y0 = max(0, int(row_idx[0]) - 1) + y1 = min(h, int(row_idx[-1]) + 2) + x0 = max(0, int(col_idx[0]) - 1) + x1 = min(w, int(col_idx[-1]) + 2) + + sub = seg_full[y0:y1, x0:x1] + sub, changed_holes = remove_small_regions(sub, min_area, mode="holes") + sub, changed_islands = remove_small_regions(sub, min_area, mode="islands") + unchanged = (not changed_holes) and (not changed_islands) + + # The "islands" pass keeps background label 0 (matching the original + # LangSplatV2 logic), so when it fires the result is the whole frame True + # except the small islands. Reproduce that: everything outside the crop is + # True iff islands changed the mask, otherwise False. 
+ seg = np.ones((h, w), dtype=bool) if changed_islands else np.zeros((h, w), dtype=bool) + seg[y0:y1, x0:x1] = sub + + if changed_islands: + box = [0.0, 0.0, float(w - 1), float(h - 1)] else: - box = [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())] + sub_rows = np.where(np.any(sub, axis=1))[0] + if len(sub_rows) == 0: + box = [0.0, 0.0, 0.0, 0.0] + else: + sub_cols = np.where(np.any(sub, axis=0))[0] + box = [ + float(int(sub_cols[0]) + x0), + float(int(sub_rows[0]) + y0), + float(int(sub_cols[-1]) + x0), + float(int(sub_rows[-1]) + y0), + ] return seg, float(unchanged), box @@ -145,6 +181,7 @@ def mask_nms( iou_thr: float = 0.7, score_thr: float = 0.1, inner_thr: float = 0.2, + nms_max_dim: int | None = 1024, **kwargs, ) -> torch.Tensor: """ @@ -159,6 +196,14 @@ def mask_nms( iou_thr: IoU threshold for NMS. score_thr: Minimum score threshold. inner_thr: Inner overlap threshold for removing contained masks. + nms_max_dim: If set, masks whose longest side exceeds this value are + downsampled (longest side capped at ``nms_max_dim``) *only* for + computing the pairwise IoU / containment matrices. The dominant + memory cost here is the full-resolution ``masks_flat @ masks_flat.T`` + (e.g. a 4K mask is ~33 MB as float32, so a few thousand masks is + tens of GB). Capping the working resolution bounds that cost while + leaving the returned masks at full resolution. ``None`` disables + downsampling (exact original behavior). Returns: Indices of selected masks after NMS. @@ -174,15 +219,51 @@ def mask_nms( scores, idx = scores.sort(0, descending=True) num_masks = idx.shape[0] - masks_ord = masks[idx.view(-1), :] - masks_area = torch.sum(masks_ord, dim=(1, 2), dtype=torch.float) + sorted_idx = idx.view(-1) + + # The pairwise IoU / containment matrices below only drive thresholding + # decisions, and the indices returned by this function point back at the + # original full-resolution masks. 
So we can compute those matrices on a + # downsampled copy to bound peak memory (the float `masks_flat @ + # masks_flat.T` is by far the largest allocation at high resolution) with + # negligible effect on which masks are kept. + # + # The downsampling is done in chunks: `interpolate` needs a float input, so + # converting all masks to full-resolution float at once would allocate the + # very tensor we are trying to avoid. Casting/resizing a chunk at a time + # keeps the transient float buffer to `chunk_size` masks. + h, w = masks.shape[-2:] + longest = max(h, w) + if nms_max_dim is not None and longest > nms_max_dim: + scale = nms_max_dim / longest + new_h = max(1, int(round(h * scale))) + new_w = max(1, int(round(w * scale))) + chunk_size = 64 + chunks = [] + for start in range(0, num_masks, chunk_size): + block = masks[sorted_idx[start : start + chunk_size]].unsqueeze(1).float() + block = ( + torch.nn.functional.interpolate(block, size=(new_h, new_w), mode="area").squeeze(1) > 0.5 + ) + chunks.append(block) + masks_work = torch.cat(chunks, dim=0) + del chunks + else: + masks_work = masks[sorted_idx, :] - masks_flat = masks_ord.reshape(num_masks, -1).float() + masks_area = torch.sum(masks_work, dim=(1, 2), dtype=torch.float) + # A small mask can downsample to zero area; clamp the denominators so the + # IoU / containment ratios stay finite (a vanished mask just scores ~0 + # overlap against everything, i.e. it is kept). For non-zero areas this + # clamp is a no-op, matching the original numerics. 
+ masks_area_safe = masks_area.clamp_min(1.0) + + masks_flat = masks_work.reshape(num_masks, -1).float() intersection = masks_flat @ masks_flat.T - union = masks_area[:, None] + masks_area[None, :] - intersection + union = (masks_area_safe[:, None] + masks_area_safe[None, :] - intersection).clamp_min(1.0) iou_matrix = intersection / union - R = intersection / masks_area[:, None] + R = intersection / masks_area_safe[:, None] inner_val = 1 - R * R.T cond = (R < 0.5) & (R.T >= 0.85) inner_iou_matrix = torch.where(cond, inner_val, torch.zeros_like(inner_val)) @@ -249,6 +330,7 @@ def masks_update( score_thr: float = 0.7, inner_thr: float = 0.5, max_area_frac: float = 0.95, + nms_max_dim: int | None = 1024, ) -> tuple: """ Apply mask NMS to multiple lists of masks. @@ -261,6 +343,9 @@ def masks_update( max_area_frac: Discard masks covering more than this fraction of the image. Near-full-image masks poison the inner-containment check by appearing to contain every other mask. + nms_max_dim: Longest-side cap for the (memory-heavy) IoU / containment + computation inside :func:`mask_nms`. See that function for details. + ``None`` disables downsampling. Returns: Tuple of filtered mask lists. 
@@ -310,7 +395,14 @@ def masks_update( stability = stability.cuda() scores = stability * iou_pred - keep_mask_nms = mask_nms(seg_pred, scores, iou_thr=iou_thr, score_thr=score_thr, inner_thr=inner_thr) + keep_mask_nms = mask_nms( + seg_pred, + scores, + iou_thr=iou_thr, + score_thr=score_thr, + inner_thr=inner_thr, + nms_max_dim=nms_max_dim, + ) keep_set = set(keep_mask_nms.int().cpu().numpy().tolist()) filtered_masks = [m for i, m in enumerate(masks_lvl) if i in keep_set] From 4a483d9dd2b39fb8345f6f4cde2011122ed495ec Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 4 Jun 2026 14:54:51 +1200 Subject: [PATCH 2/3] LangSplatV2: cut training GPU memory ~2.5x via sparse loss + allocator config Training previously built dense [B, H, W, 512] feature maps and computed the loss over the full image, even though only the masked pixels carry features (~22% on safety_park). Profiling one step showed the loss stage alone spiking peak allocated memory from ~11 GB to ~33 GB. Sparse loss path (numerically identical to the dense loss): - loss.py: add calculate_langsplatv2_loss_sparse(), operating on the flat valid-pixel set and normalising by mask.numel() so values and gradients match the dense loss exactly. The cosine_loss_valid diagnostic is gated behind compute_valid_metric (it was previously gathered every step but only logged occasionally). - dataset.py: add build_sparse_gt_features(), gathering GT features only at valid pixels (ordered to match weight_maps[mask]), with the same out-of-bounds handling as build_feature_map. - trainer.py: the training step now renders the full weight maps but decodes only the valid pixels (weight_maps[mask] @ codebook) and uses the sparse loss. Dense maps are reconstructed only on logging steps when log_test_images is set; eval() is unchanged. Verified on a real frame: total_loss/cosine_loss bit-identical, gradients match to float precision (logits ~7e-12, codebooks ~1e-7); peak allocated 38.7 GB -> 9.8 GB. 
Allocator fragmentation fix: - train_langsplatv2.py: default PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (via setdefault, before CUDA init). The variable per-image valid-pixel count makes the sparse tensors different sizes each step, fragmenting the default caching allocator (reserved memory grew ~17 GB -> ~44 GB despite ~13 GB peak allocated). Expandable segments keep reserved tracking actual usage. Net: real training peak ~38-43 GB -> ~18 GB, plus a per-step speedup (decode + loss now run over ~4.5x fewer elements). No change to training results. Signed-off-by: Jonathan Swartz --- .../langsplatv2/langsplatv2/loss.py | 66 +++++++++++++++++++ .../langsplatv2/training/dataset.py | 58 ++++++++++++++++ .../langsplatv2/training/trainer.py | 54 +++++++++------ .../langsplatv2/scripts/train_langsplatv2.py | 11 ++++ 4 files changed, 170 insertions(+), 19 deletions(-) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py index e008fcc..23da96c 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py @@ -105,3 +105,69 @@ def calculate_langsplatv2_loss( loss_dict["total_loss"] = total_loss return loss_dict + + +def calculate_langsplatv2_loss_sparse( + pred_valid: torch.Tensor, + gt_valid: torch.Tensor, + num_pixels: int, + use_cosine_loss: bool = True, + use_l1_loss: bool = False, + normalize_features: bool = False, + compute_valid_metric: bool = False, +) -> dict[str, torch.Tensor]: + """Sparse equivalent of :func:`calculate_langsplatv2_loss`. + + Operates only on the valid (masked) pixels, avoiding the dense + ``[B, H, W, C]`` feature maps entirely. 
This is numerically identical to the dense version because unmapped pixels contribute exactly zero to the dense loss (their GT features are zero and the per-pixel loss is multiplied by the mask), and the dense normalisation divides by ``mask.numel()`` -- so ``sum_over_valid / num_pixels`` reproduces the same value and gradients. + + Args: + pred_valid: Predicted features at valid pixels, shape ``[N_valid, C]``. + gt_valid: Ground-truth features at the same pixels, shape ``[N_valid, C]``. + num_pixels: Total pixel count of the dense map (``mask.numel()``), used + as the normalisation denominator to match the dense loss exactly. + use_cosine_loss: Whether to include cosine similarity loss. + use_l1_loss: Whether to include L1 loss. + normalize_features: Whether to L2-normalize predicted features. + compute_valid_metric: If True, also compute the ``cosine_loss_valid`` + diagnostic (mean over valid pixels). Off by default since it is only + used for occasional logging. + + Returns: + Dictionary with the same keys as :func:`calculate_langsplatv2_loss`, except that ``cosine_loss_valid`` is included only when ``compute_valid_metric`` is True. 
+ """ + assert use_cosine_loss or use_l1_loss, "At least one loss type must be enabled" + + if normalize_features: + pred_valid = pred_valid / (pred_valid.norm(dim=-1, keepdim=True) + 1e-10) + + loss_dict: dict[str, torch.Tensor] = {} + total_loss = torch.tensor(0.0, device=pred_valid.device) + has_valid = pred_valid.shape[0] > 0 + + if use_cosine_loss: + if has_valid: + per_valid = 1.0 - F.cosine_similarity(pred_valid, gt_valid, dim=-1) # [N_valid] + cos_loss_all = per_valid.sum() / num_pixels + else: + cos_loss_all = total_loss + loss_dict["cosine_loss"] = cos_loss_all + total_loss = total_loss + cos_loss_all + + if compute_valid_metric: + loss_dict["cosine_loss_valid"] = per_valid.mean() if has_valid else cos_loss_all + + if use_l1_loss: + if has_valid: + l1 = torch.abs(pred_valid - gt_valid).sum() / num_pixels + else: + l1 = total_loss + loss_dict["l1_loss"] = l1 + total_loss = total_loss + l1 + + loss_dict["total_loss"] = total_loss + return loss_dict diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py index 86eedf7..b3dfdcb 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py @@ -291,6 +291,64 @@ def build_feature_map( return gt_features, feature_mask +def build_sparse_gt_features( + features: JaggedTensor | torch.Tensor, + seg_map: torch.Tensor, + clip_n_dims: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Gather only the valid GT features, without building a dense ``[H,W,C]`` map. + + This is the sparse counterpart to :func:`build_feature_map`. It returns the + GT features for valid pixels in the row-major order of a boolean index over + ``[B, H, W]`` (i.e. matching ``weight_maps[mask]``), so the predicted and GT + features stay aligned. The returned mask reflects any out-of-bounds dropping, + exactly as :func:`build_feature_map` does. 
+ + Args: + features: CLIP features as a :class:`JaggedTensor` (batched) or a plain + ``torch.Tensor`` ``[N_masks, clip_n_dims]`` (unbatched). + seg_map: Segmentation map ``[B, H, W]`` or ``[H, W]`` (-1 = no feature). + clip_n_dims: Feature dimensionality (for the empty-output fallback). + + Returns: + Tuple of (gt_valid, mask): + - gt_valid: ``[N_valid, clip_n_dims]`` GT features at valid pixels. + - mask: ``[B, H, W]`` or ``[H, W]`` bool, post out-of-bounds filtering. + """ + if isinstance(features, torch.Tensor): + mask = seg_map >= 0 # [H, W] + idx = seg_map[mask].long() + in_bounds = idx < features.shape[0] + if not bool(in_bounds.all()): + positions = mask.nonzero(as_tuple=False)[in_bounds] + mask = torch.zeros_like(mask) + mask[positions[:, 0], positions[:, 1]] = True + idx = idx[in_bounds] + return features[idx], mask + + # Batched JaggedTensor path + B = seg_map.shape[0] if seg_map.dim() == 3 else 1 + device = features.jdata.device + dtype = features.jdata.dtype + mask = seg_map >= 0 # [B, H, W] + parts: list[torch.Tensor] = [] + for b in range(B): + mask_b = mask[b] + feat_b = features[b].jdata + idx = seg_map[b][mask_b].long() + in_bounds = idx < feat_b.shape[0] + if not bool(in_bounds.all()): + positions = mask_b.nonzero(as_tuple=False)[in_bounds] + new_mask_b = torch.zeros_like(mask_b) + new_mask_b[positions[:, 0], positions[:, 1]] = True + mask[b] = new_mask_b + idx = idx[in_bounds] + parts.append(feat_b[idx]) + + gt_valid = torch.cat(parts, dim=0) if parts else torch.zeros(0, clip_n_dims, device=device, dtype=dtype) + return gt_valid, mask + + class LangSplatV2Input(dict): """Batched input dictionary for the LangSplatV2 model. 
diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py index 5859278..866fd6c 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py @@ -20,7 +20,7 @@ from fvdb_reality_capture.sfm_scene import SfmScene from ..config import LangSplatV2ModelConfig, LangSplatV2TrainingConfig -from ..loss import calculate_langsplatv2_loss +from ..loss import calculate_langsplatv2_loss, calculate_langsplatv2_loss_sparse from ..model import LangSplatV2Model from ..util import calculate_pca_projection, cosine_error_map, pca_projection_fast from ..vq_utils import ( @@ -33,6 +33,7 @@ LangSplatV2CollateFn, LangSplatV2Dataset, build_feature_map, + build_sparse_gt_features, ) from .langsplatv2_writer import LangSplatV2Writer @@ -502,9 +503,11 @@ def train(self, show_progress: bool = True) -> None: with nvtx.range("data_to_device"): minibatch = minibatch.to(self.device) - # Build dense feature map on GPU from compact data - with nvtx.range("build_feature_map"): - gt_features, feature_mask = build_feature_map( + # Gather GT features only at valid pixels (sparse). Avoids the dense + # [B,H,W,512] GT map; the returned mask drives the matching gather of + # predicted features below. 
+ with nvtx.range("build_sparse_gt"): + gt_valid, feature_mask = build_sparse_gt_features( features=minibatch["features"], seg_map=minibatch["seg_map"], clip_n_dims=self._cfg.model.clip_n_dims, @@ -519,25 +522,33 @@ def train(self, show_progress: bool = True) -> None: layer_num - 1, ) - # Forward pass + # Whether this step logs metrics/images (drives optional dense work) + is_log_step = (self._global_step + 1) % self._log_interval_steps == 0 + want_dense = is_log_step and self._cfg.log_test_images + + # Forward pass: render full weight maps, then decode only the valid + # pixels into CLIP features (sparse), keeping memory proportional to + # the number of mapped pixels rather than the full image. with nvtx.range("forward_pass"): - predicted_features, alpha = self._model( + weight_maps, alpha = self._model.render_weight_maps( world_to_camera=minibatch["world_to_camera"], projection=minibatch["projection"], image_width=img_w, image_height=img_h, - layer_idx=layer_idx, ) + pred_valid = self._model.decode_weight_maps(weight_maps[feature_mask], layer_idx) - # Compute loss + # Compute loss on the sparse valid set (numerically identical to the + # dense loss; see calculate_langsplatv2_loss_sparse). 
with nvtx.range("loss_computation"): - loss_dict = calculate_langsplatv2_loss( - predicted_features=predicted_features, - gt_features=gt_features, - mask=feature_mask, + loss_dict = calculate_langsplatv2_loss_sparse( + pred_valid=pred_valid, + gt_valid=gt_valid, + num_pixels=feature_mask.numel(), use_cosine_loss=self._cfg.use_cosine_loss, use_l1_loss=self._cfg.use_l1_loss, normalize_features=self._cfg.normalize_features, + compute_valid_metric=is_log_step, ) loss = loss_dict["total_loss"] @@ -580,13 +591,18 @@ def train(self, show_progress: bool = True) -> None: if key != "total_loss": self._writer.log_metric(self._global_step, f"train/{key}", val.item()) - # Log training images when enabled - if self._cfg.log_test_images: - self._log_training_images( - predicted_features.detach(), - gt_features.detach(), - feature_mask, - ) + # Log training images when enabled. The training step runs in + # the sparse regime, so reconstruct the dense maps here (only on + # logging steps) for the PCA / error visualizations. + if want_dense: + with torch.no_grad(): + dense_pred = self._model.decode_weight_maps(weight_maps, layer_idx) + dense_gt, dense_mask = build_feature_map( + features=minibatch["features"], + seg_map=minibatch["seg_map"], + clip_n_dims=self._cfg.model.clip_n_dims, + ) + self._log_training_images(dense_pred, dense_gt, dense_mask) pbar.update(1) diff --git a/open_vocabulary_segmentation/langsplatv2/scripts/train_langsplatv2.py b/open_vocabulary_segmentation/langsplatv2/scripts/train_langsplatv2.py index d39ade6..947377e 100644 --- a/open_vocabulary_segmentation/langsplatv2/scripts/train_langsplatv2.py +++ b/open_vocabulary_segmentation/langsplatv2/scripts/train_langsplatv2.py @@ -23,9 +23,20 @@ done """ import logging +import os import pathlib from typing import Literal +# Use expandable CUDA memory segments to avoid caching-allocator fragmentation. 
+# Training renders a fixed-size weight map but decodes/loss-computes only the +# valid (masked) pixels, whose count varies per image. Those variable-size +# allocations fragment the default allocator (reserved memory grows step over +# step, e.g. ~17 GB -> ~44 GB), even though peak *allocated* memory stays ~13 GB. +# Expandable segments keep reserved memory tracking actual usage. Must be set +# before CUDA initializes (i.e. before the first CUDA allocation), so we set it +# here at import time. A user-provided value is respected. +os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + import torch import tyro From 93ac68a597901810766306386174b6603c599098 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 4 Jun 2026 15:01:31 +1200 Subject: [PATCH 3/3] LangSplatV2: use fused CUDA Adam for the training optimizer The optimizer step was ~5% of each training iteration, dominated by updating the large per-Gaussian logits tensor. Switch to torch.optim.Adam(fused=True) when training on CUDA -- it is numerically equivalent but roughly halves the step time (5.5 ms -> 2.6 ms, ~3% faster iterations overall). Guarded by fused=(device.type == "cuda") in both LangSplatV2Trainer.new and from_state_dict so CPU/eval paths and checkpoint restore are unaffected. Verified the loss trajectory is unchanged. 
Signed-off-by: Jonathan Swartz --- .../langsplatv2/langsplatv2/training/trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py index 866fd6c..b5d41c8 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py @@ -279,11 +279,15 @@ def new( logger.info(f"Model initialized with {gs_model.num_gaussians:,} Gaussians") - # Create optimizer (only optimize language feature parameters) + # Create optimizer (only optimize language feature parameters). Use the + # fused CUDA implementation when available -- it is numerically + # equivalent but noticeably faster for the large per-Gaussian logits + # tensor (the optimizer step is otherwise ~5% of each training step). optimizer = torch.optim.Adam( params=[model.logits, model.codebooks], lr=config.learning_rate, eps=1e-15, + fused=(torch.device(device).type == "cuda"), ) # No scheduler needed for constant LR (matching original LangSplatV2) @@ -359,11 +363,12 @@ def from_state_dict( for param in model.parameters(): param.requires_grad_(False) - # Restore optimizer + # Restore optimizer (fused CUDA Adam when available; see `new`) optimizer = torch.optim.Adam( params=[model.logits, model.codebooks], lr=config.learning_rate, eps=1e-15, + fused=(torch.device(device).type == "cuda"), ) if not eval_only: optimizer.load_state_dict(state_dict["optimizer"])