Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,69 @@ def calculate_langsplatv2_loss(

loss_dict["total_loss"] = total_loss
return loss_dict


def calculate_langsplatv2_loss_sparse(
pred_valid: torch.Tensor,
gt_valid: torch.Tensor,
num_pixels: int,
use_cosine_loss: bool = True,
use_l1_loss: bool = False,
normalize_features: bool = False,
compute_valid_metric: bool = False,
) -> dict[str, torch.Tensor]:
"""Sparse equivalent of :func:`calculate_langsplatv2_loss`.

Operates only on the valid (masked) pixels, avoiding the dense
``[B, H, W, C]`` feature maps entirely. This is numerically identical to
the dense version because unmapped pixels contribute exactly zero to the
dense loss (their GT features are zero and the per-pixel loss is multiplied
by the mask), and the dense normalisation divides by ``mask.numel()`` -- so
``sum_over_valid / num_pixels`` reproduces the same value and gradients.

Args:
pred_valid: Predicted features at valid pixels, shape ``[N_valid, C]``.
gt_valid: Ground-truth features at the same pixels, shape ``[N_valid, C]``.
num_pixels: Total pixel count of the dense map (``mask.numel()``), used
as the normalisation denominator to match the dense loss exactly.
use_cosine_loss: Whether to include cosine similarity loss.
use_l1_loss: Whether to include L1 loss.
normalize_features: Whether to L2-normalize predicted features.
compute_valid_metric: If True, also compute the ``cosine_loss_valid``
diagnostic (mean over valid pixels). Off by default since it is only
used for occasional logging.

Returns:
Dictionary with the same keys as :func:`calculate_langsplatv2_loss`.
"""
assert use_cosine_loss or use_l1_loss, "At least one loss type must be enabled"

if normalize_features:
pred_valid = pred_valid / (pred_valid.norm(dim=-1, keepdim=True) + 1e-10)

loss_dict: dict[str, torch.Tensor] = {}
total_loss = torch.tensor(0.0, device=pred_valid.device)
has_valid = pred_valid.shape[0] > 0

if use_cosine_loss:
if has_valid:
per_valid = 1.0 - F.cosine_similarity(pred_valid, gt_valid, dim=-1) # [N_valid]
cos_loss_all = per_valid.sum() / num_pixels
else:
cos_loss_all = total_loss
loss_dict["cosine_loss"] = cos_loss_all
total_loss = total_loss + cos_loss_all

if compute_valid_metric:
loss_dict["cosine_loss_valid"] = per_valid.mean() if has_valid else cos_loss_all

if use_l1_loss:
if has_valid:
l1 = torch.abs(pred_valid - gt_valid).sum() / num_pixels
else:
l1 = total_loss
loss_dict["l1_loss"] = l1
total_loss = total_loss + l1

loss_dict["total_loss"] = total_loss
return loss_dict
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,57 @@ def _clean_single_mask(
) -> Tuple[np.ndarray, float, List[float]]:
"""Clean one mask and compute its bounding box. Thread-safe (CPU-only).

The connected-component work runs on the mask's bounding-box crop (padded
by 1px) rather than the full frame. This is exactly equivalent to cleaning
the full-resolution mask: the 1px pad guarantees the exterior background is
a single border-touching component, so the hole/island labelling is
identical, but it avoids scanning the (mostly empty) full image, which
dominates runtime at high resolution.

Returns:
(cleaned_seg, score, box_xyxy) where score is 1.0 if the mask was
unchanged and 0.0 if it was modified.
"""
seg = seg_raw.astype(bool)
seg, changed_holes = remove_small_regions(seg, min_area, mode="holes")
unchanged = not changed_holes
seg, changed_islands = remove_small_regions(seg, min_area, mode="islands")
unchanged = unchanged and not changed_islands

ys, xs = np.where(seg)
if len(xs) == 0:
box = [0.0, 0.0, 0.0, 0.0]
seg_full = seg_raw.astype(bool)
h, w = seg_full.shape

rows = np.any(seg_full, axis=1)
if not rows.any():
return seg_full, 1.0, [0.0, 0.0, 0.0, 0.0]
cols = np.any(seg_full, axis=0)
row_idx = np.where(rows)[0]
col_idx = np.where(cols)[0]
y0 = max(0, int(row_idx[0]) - 1)
y1 = min(h, int(row_idx[-1]) + 2)
x0 = max(0, int(col_idx[0]) - 1)
x1 = min(w, int(col_idx[-1]) + 2)

sub = seg_full[y0:y1, x0:x1]
sub, changed_holes = remove_small_regions(sub, min_area, mode="holes")
sub, changed_islands = remove_small_regions(sub, min_area, mode="islands")
unchanged = (not changed_holes) and (not changed_islands)

# The "islands" pass keeps background label 0 (matching the original
# LangSplatV2 logic), so when it fires the result is the whole frame True
# except the small islands. Reproduce that: everything outside the crop is
# True iff islands changed the mask, otherwise False.
seg = np.ones((h, w), dtype=bool) if changed_islands else np.zeros((h, w), dtype=bool)
seg[y0:y1, x0:x1] = sub

if changed_islands:
box = [0.0, 0.0, float(w - 1), float(h - 1)]
else:
box = [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())]
sub_rows = np.where(np.any(sub, axis=1))[0]
if len(sub_rows) == 0:
box = [0.0, 0.0, 0.0, 0.0]
else:
sub_cols = np.where(np.any(sub, axis=0))[0]
box = [
float(int(sub_cols[0]) + x0),
float(int(sub_rows[0]) + y0),
float(int(sub_cols[-1]) + x0),
float(int(sub_rows[-1]) + y0),
]

return seg, float(unchanged), box

Expand Down Expand Up @@ -145,6 +181,7 @@ def mask_nms(
iou_thr: float = 0.7,
score_thr: float = 0.1,
inner_thr: float = 0.2,
nms_max_dim: int | None = 1024,
**kwargs,
) -> torch.Tensor:
"""
Expand All @@ -159,6 +196,14 @@ def mask_nms(
iou_thr: IoU threshold for NMS.
score_thr: Minimum score threshold.
inner_thr: Inner overlap threshold for removing contained masks.
nms_max_dim: If set, masks whose longest side exceeds this value are
downsampled (longest side capped at ``nms_max_dim``) *only* for
computing the pairwise IoU / containment matrices. The dominant
memory cost here is the full-resolution ``masks_flat @ masks_flat.T``
(e.g. a 4K mask is ~33 MB as float32, so a few thousand masks is
tens of GB). Capping the working resolution bounds that cost while
leaving the returned masks at full resolution. ``None`` disables
downsampling (exact original behavior).

Returns:
Indices of selected masks after NMS.
Expand All @@ -174,15 +219,51 @@ def mask_nms(
scores, idx = scores.sort(0, descending=True)
num_masks = idx.shape[0]

masks_ord = masks[idx.view(-1), :]
masks_area = torch.sum(masks_ord, dim=(1, 2), dtype=torch.float)
sorted_idx = idx.view(-1)

# The pairwise IoU / containment matrices below only drive thresholding
# decisions, and the indices returned by this function point back at the
# original full-resolution masks. So we can compute those matrices on a
# downsampled copy to bound peak memory (the float `masks_flat @
# masks_flat.T` is by far the largest allocation at high resolution) with
# negligible effect on which masks are kept.
#
# The downsampling is done in chunks: `interpolate` needs a float input, so
# converting all masks to full-resolution float at once would allocate the
# very tensor we are trying to avoid. Casting/resizing a chunk at a time
# keeps the transient float buffer to `chunk_size` masks.
h, w = masks.shape[-2:]
longest = max(h, w)
if nms_max_dim is not None and longest > nms_max_dim:
scale = nms_max_dim / longest
new_h = max(1, int(round(h * scale)))
new_w = max(1, int(round(w * scale)))
chunk_size = 64
chunks = []
for start in range(0, num_masks, chunk_size):
block = masks[sorted_idx[start : start + chunk_size]].unsqueeze(1).float()
block = (
torch.nn.functional.interpolate(block, size=(new_h, new_w), mode="area").squeeze(1) > 0.5
)
chunks.append(block)
masks_work = torch.cat(chunks, dim=0)
del chunks
else:
masks_work = masks[sorted_idx, :]

masks_flat = masks_ord.reshape(num_masks, -1).float()
masks_area = torch.sum(masks_work, dim=(1, 2), dtype=torch.float)
# A small mask can downsample to zero area; clamp the denominators so the
# IoU / containment ratios stay finite (a vanished mask just scores ~0
# overlap against everything, i.e. it is kept). For non-zero areas this
# clamp is a no-op, matching the original numerics.
masks_area_safe = masks_area.clamp_min(1.0)

masks_flat = masks_work.reshape(num_masks, -1).float()
intersection = masks_flat @ masks_flat.T
union = masks_area[:, None] + masks_area[None, :] - intersection
union = (masks_area_safe[:, None] + masks_area_safe[None, :] - intersection).clamp_min(1.0)
iou_matrix = intersection / union

R = intersection / masks_area[:, None]
R = intersection / masks_area_safe[:, None]
inner_val = 1 - R * R.T
cond = (R < 0.5) & (R.T >= 0.85)
inner_iou_matrix = torch.where(cond, inner_val, torch.zeros_like(inner_val))
Expand Down Expand Up @@ -249,6 +330,7 @@ def masks_update(
score_thr: float = 0.7,
inner_thr: float = 0.5,
max_area_frac: float = 0.95,
nms_max_dim: int | None = 1024,
) -> tuple:
"""
Apply mask NMS to multiple lists of masks.
Expand All @@ -261,6 +343,9 @@ def masks_update(
max_area_frac: Discard masks covering more than this fraction of the
image. Near-full-image masks poison the inner-containment check
by appearing to contain every other mask.
nms_max_dim: Longest-side cap for the (memory-heavy) IoU / containment
computation inside :func:`mask_nms`. See that function for details.
``None`` disables downsampling.

Returns:
Tuple of filtered mask lists.
Expand Down Expand Up @@ -310,7 +395,14 @@ def masks_update(
stability = stability.cuda()

scores = stability * iou_pred
keep_mask_nms = mask_nms(seg_pred, scores, iou_thr=iou_thr, score_thr=score_thr, inner_thr=inner_thr)
keep_mask_nms = mask_nms(
seg_pred,
scores,
iou_thr=iou_thr,
score_thr=score_thr,
inner_thr=inner_thr,
nms_max_dim=nms_max_dim,
)

keep_set = set(keep_mask_nms.int().cpu().numpy().tolist())
filtered_masks = [m for i, m in enumerate(masks_lvl) if i in keep_set]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,64 @@ def build_feature_map(
return gt_features, feature_mask


def build_sparse_gt_features(
features: JaggedTensor | torch.Tensor,
seg_map: torch.Tensor,
clip_n_dims: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Gather only the valid GT features, without building a dense ``[H,W,C]`` map.

This is the sparse counterpart to :func:`build_feature_map`. It returns the
GT features for valid pixels in the row-major order of a boolean index over
``[B, H, W]`` (i.e. matching ``weight_maps[mask]``), so the predicted and GT
features stay aligned. The returned mask reflects any out-of-bounds dropping,
exactly as :func:`build_feature_map` does.

Args:
features: CLIP features as a :class:`JaggedTensor` (batched) or a plain
``torch.Tensor`` ``[N_masks, clip_n_dims]`` (unbatched).
seg_map: Segmentation map ``[B, H, W]`` or ``[H, W]`` (-1 = no feature).
clip_n_dims: Feature dimensionality (for the empty-output fallback).

Returns:
Tuple of (gt_valid, mask):
- gt_valid: ``[N_valid, clip_n_dims]`` GT features at valid pixels.
- mask: ``[B, H, W]`` or ``[H, W]`` bool, post out-of-bounds filtering.
"""
if isinstance(features, torch.Tensor):
mask = seg_map >= 0 # [H, W]
idx = seg_map[mask].long()
in_bounds = idx < features.shape[0]
if not bool(in_bounds.all()):
positions = mask.nonzero(as_tuple=False)[in_bounds]
mask = torch.zeros_like(mask)
mask[positions[:, 0], positions[:, 1]] = True
idx = idx[in_bounds]
return features[idx], mask

# Batched JaggedTensor path
B = seg_map.shape[0] if seg_map.dim() == 3 else 1
device = features.jdata.device
dtype = features.jdata.dtype
mask = seg_map >= 0 # [B, H, W]
parts: list[torch.Tensor] = []
for b in range(B):
mask_b = mask[b]
feat_b = features[b].jdata
idx = seg_map[b][mask_b].long()
in_bounds = idx < feat_b.shape[0]
if not bool(in_bounds.all()):
positions = mask_b.nonzero(as_tuple=False)[in_bounds]
new_mask_b = torch.zeros_like(mask_b)
new_mask_b[positions[:, 0], positions[:, 1]] = True
mask[b] = new_mask_b
idx = idx[in_bounds]
parts.append(feat_b[idx])

gt_valid = torch.cat(parts, dim=0) if parts else torch.zeros(0, clip_n_dims, device=device, dtype=dtype)
return gt_valid, mask


class LangSplatV2Input(dict):
"""Batched input dictionary for the LangSplatV2 model.

Expand Down
Loading
Loading