From a33b51f86607b0d7faa8093c2273b5ac40c88375 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 4 Mar 2026 18:33:04 +1300 Subject: [PATCH 01/13] Enhance LangSplatV2 with new evaluation and training scripts - Added `train_eval.sh` and `train_scenes.sh` for streamlined training and evaluation processes. - Introduced `download_data.py` for downloading the LERF-OVS dataset, along with `eval_lerf.py` for model evaluation on this dataset. - Created `environment.yml` to define the Conda environment for LangSplatV2, including necessary dependencies. - Updated `.gitignore` to exclude new log and result directories. - Enhanced `train_langsplatv2.py` to save final checkpoints and log paths for better tracking of training progress. - Various alterations to the training, data generation and evaluation to be more in line with the original implementation results. These changes improve the usability and functionality of the LangSplatV2 framework for open-vocabulary segmentation tasks. Signed-off-by: Jonathan Swartz --- .../langsplatv2/.gitignore | 2 + .../langsplatv2/environment.yml | 17 + .../langsplatv2/evaluation/download_data.py | 26 + .../langsplatv2/evaluation/eval_lerf.py | 998 ++++++++++++++++++ .../langsplatv2/langsplatv2/config.py | 192 +++- .../langsplatv2/evaluation/__init__.py | 3 + .../evaluation/datasets/__init__.py | 23 + .../langsplatv2/evaluation/datasets/lerf.py | 78 ++ .../langsplatv2/evaluation/datasets/util.py | 43 + .../evaluation/openclip_relevancy.py | 149 +++ .../langsplatv2/langsplatv2/loss.py | 69 +- .../langsplatv2/scene_transforms/__init__.py | 4 + .../scene_transforms/clip_feature_encoding.py | 139 ++- .../import_original_features.py | 252 +++++ .../scene_transforms/mask_utils.py | 336 ++++++ .../multi_scale_sam1_masks.py | 422 ++++++++ .../scene_transforms/multi_scale_sam_masks.py | 259 ++--- .../langsplatv2/training/dataset.py | 20 +- .../langsplatv2/training/langsplatv2_writer.py | 20 + .../langsplatv2/training/trainer.py | 15 +-
.../langsplatv2/langsplatv2/vq_utils.py | 29 + .../langsplatv2/pyproject.toml | 1 + .../langsplatv2/train_eval.sh | 21 + .../langsplatv2/train_langsplatv2.py | 6 + .../langsplatv2/train_scenes.sh | 13 + 25 files changed, 2778 insertions(+), 359 deletions(-) create mode 100644 open_vocabulary_segmentation/langsplatv2/environment.yml create mode 100644 open_vocabulary_segmentation/langsplatv2/evaluation/download_data.py create mode 100644 open_vocabulary_segmentation/langsplatv2/evaluation/eval_lerf.py create mode 100644 open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/__init__.py create mode 100644 open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/__init__.py create mode 100644 open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/lerf.py create mode 100644 open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/util.py create mode 100644 open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py create mode 100644 open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/import_original_features.py create mode 100644 open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py create mode 100644 open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam1_masks.py create mode 100644 open_vocabulary_segmentation/langsplatv2/train_eval.sh create mode 100644 open_vocabulary_segmentation/langsplatv2/train_scenes.sh diff --git a/open_vocabulary_segmentation/langsplatv2/.gitignore b/open_vocabulary_segmentation/langsplatv2/.gitignore index 4bf5408..417a8a9 100644 --- a/open_vocabulary_segmentation/langsplatv2/.gitignore +++ b/open_vocabulary_segmentation/langsplatv2/.gitignore @@ -1,2 +1,4 @@ *.nsys-rep langsplatv2_logs/ +frgs_logs/ +*_results*/ diff --git a/open_vocabulary_segmentation/langsplatv2/environment.yml b/open_vocabulary_segmentation/langsplatv2/environment.yml new file mode 100644 index 
0000000..625fcb4 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/environment.yml @@ -0,0 +1,17 @@ +name: fvdb_langsplatv2 +channels: + - conda-forge +dependencies: + - python >=3.11 + - numpy + - pytorch + - torchvision + - opencv + - tqdm + - scikit-learn + - pip + - gdown + - open-clip-torch + - tyro + - segment-anything + - fvdb-reality-capture diff --git a/open_vocabulary_segmentation/langsplatv2/evaluation/download_data.py b/open_vocabulary_segmentation/langsplatv2/evaluation/download_data.py new file mode 100644 index 0000000..fce3f72 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/evaluation/download_data.py @@ -0,0 +1,26 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""CLI script to download the LERF-OVS evaluation dataset.""" +from dataclasses import dataclass +from pathlib import Path + +import tyro +from langsplatv2.evaluation.datasets import set_dataset_root +from langsplatv2.evaluation.datasets.lerf import download_lerf_data + + +@dataclass +class DownloadLERFData: + """Download the LERF-OVS dataset for open-vocabulary segmentation evaluation.""" + + dataset_root: Path = Path("data") + """Root directory to store downloaded datasets.""" + + def main(self): + set_dataset_root(self.dataset_root) + download_lerf_data() + + +if __name__ == "__main__": + tyro.cli(DownloadLERFData).main() diff --git a/open_vocabulary_segmentation/langsplatv2/evaluation/eval_lerf.py b/open_vocabulary_segmentation/langsplatv2/evaluation/eval_lerf.py new file mode 100644 index 0000000..7615e88 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/evaluation/eval_lerf.py @@ -0,0 +1,998 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +""" +Evaluation script for LERF dataset open-vocabulary segmentation. + +This script evaluates trained LangSplatV2 models on the LERF-OVS dataset by: +1. Loading 3 trained checkpoints (one per SAM scale level) +2. 
Loading LERF ground-truth annotations (labelme-format JSONs) +3. Rendering CLIP features from each level and computing relevancy maps +4. Computing segmentation mIoU and localization accuracy + +The evaluation matches the original LangSplatV2 eval_lerf.py methodology: +- Relevancy is computed using OpenCLIP (ViT-B-16) with pos/neg softmax scoring +- Segmentation uses AvgPool smoothing + normalization + thresholding +- The best level is chosen per-prompt based on max relevancy score +- Localization checks if the peak of the relevancy map is inside any GT bbox + +LERF-OVS dataset structure: + lerf_ovs/ + label//frame_XXXXX.json, frame_XXXXX.jpg + / + images/ + sparse/ + output//point_cloud/iteration_30000/point_cloud.ply + +Usage: + # Evaluate a single scene + python eval_lerf.py \\ + --lerf-root /path/to/lerf_ovs \\ + --results-root ./langsplatv2_results \\ + --gs-model-path /path/to/point_cloud.ply \\ + --scenes teatime + + # Evaluate all scenes + python eval_lerf.py \\ + --lerf-root /path/to/lerf_ovs \\ + --results-root ./langsplatv2_results +""" +import argparse +import glob +import json +import logging +import os +import pathlib +import sys +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, Tuple + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +import tqdm +from fvdb_reality_capture.sfm_scene import SfmScene +from langsplatv2.evaluation.openclip_relevancy import OpenCLIPRelevancy +from langsplatv2.model import LangSplatV2Model +from langsplatv2.training.trainer import LangSplatV2Trainer +from langsplatv2.util import load_splats_from_file + +matplotlib.use("Agg") # Use non-interactive backend + + +# --------------------------------------------------------------------------- +# Ground truth parsing (matching original eval_gt_lerfdata) +# --------------------------------------------------------------------------- + + +def polygon_to_mask(img_shape: Tuple[int, int], 
points_list: list) -> np.ndarray: + """Convert a polygon to a binary mask. + + Args: + img_shape: (height, width) of the target image. + points_list: List of [x, y] polygon vertices. + + Returns: + Binary mask of shape ``(height, width)`` with dtype uint8. + """ + points = np.asarray(points_list, dtype=np.int32) + mask = np.zeros(img_shape, dtype=np.uint8) + cv2.fillPoly(mask, [points], 1) + return mask + + +def stack_mask(mask_base: np.ndarray, mask_add: np.ndarray) -> np.ndarray: + """Merge two binary masks (logical OR).""" + mask = mask_base.copy() + mask[mask_add != 0] = 1 + return mask + + +def load_lerf_ground_truth( + json_folder: pathlib.Path, + logger: logging.Logger, +) -> Tuple[Dict, Tuple[int, int], list]: + """Parse LERF-OVS ground truth annotations from labelme-format JSONs. + + Matches the original ``eval_gt_lerfdata`` function exactly. + + Args: + json_folder: Path to the label folder for a specific scene + (e.g. ``lerf_ovs/label/teatime``). + logger: Logger instance. + + Returns: + Tuple of: + - gt_ann: ``{str(frame_idx): {label: {"bboxes": ndarray, "mask": ndarray}}}`` + - image_shape: ``(height, width)`` of the ground truth images + - img_paths: sorted list of GT JPEG image paths + """ + gt_json_paths = sorted(glob.glob(os.path.join(str(json_folder), "frame_*.json"))) + img_paths = sorted(glob.glob(os.path.join(str(json_folder), "frame_*.jpg"))) + + if not gt_json_paths: + raise FileNotFoundError(f"No frame_*.json files found in {json_folder}") + + logger.info(f"Found {len(gt_json_paths)} GT annotations, {len(img_paths)} GT images in {json_folder}") + + gt_ann = {} + h, w = 0, 0 + for js_path in gt_json_paths: + img_ann: Dict[str, dict] = defaultdict(dict) + with open(js_path, "r") as f: + gt_data = json.load(f) + + h, w = gt_data["info"]["height"], gt_data["info"]["width"] + # Frame index: frame_00001 -> idx=0 (1-indexed filename to 0-indexed) + idx = int(gt_data["info"]["name"].split("_")[-1].split(".jpg")[0]) - 1 + + for prompt_data in 
gt_data["objects"]: + label = prompt_data["category"] + box = np.asarray(prompt_data["bbox"]).reshape(-1) # x1y1x2y2 + mask = polygon_to_mask((h, w), prompt_data["segmentation"]) + + if img_ann[label].get("mask", None) is not None: + # Merge multiple objects with the same label + mask = stack_mask(img_ann[label]["mask"], mask) + img_ann[label]["bboxes"] = np.concatenate( + [img_ann[label]["bboxes"].reshape(-1, 4), box.reshape(-1, 4)], axis=0 + ) + else: + img_ann[label]["bboxes"] = box + img_ann[label]["mask"] = mask + + gt_ann[str(idx)] = dict(img_ann) + + logger.info(f"Parsed GT: {len(gt_ann)} frames, image size {w}x{h}") + for idx_str, ann in gt_ann.items(): + labels = list(ann.keys()) + logger.debug(f" Frame {idx_str}: {len(labels)} labels: {labels}") + + return gt_ann, (h, w), img_paths + + +# --------------------------------------------------------------------------- +# Segmentation and localization (matching original eval_lerf.py) +# --------------------------------------------------------------------------- + + +def smooth_mask(mask_pred: torch.Tensor) -> torch.Tensor: + """Smooth a binary mask using average pooling (matching original smooth_cuda). + + Args: + mask_pred: Binary mask tensor ``[H, W]`` of type uint8. + + Returns: + Smoothed binary mask ``[H, W]`` of type uint8. + """ + scale = 7 + avg_pool = torch.nn.AvgPool2d(kernel_size=scale, stride=1, padding=3, count_include_pad=False).to(mask_pred.device) + avg_filtered = avg_pool(mask_pred.float().unsqueeze(0).unsqueeze(0)) + mask = (avg_filtered > 0.5).to(torch.uint8).squeeze(0).squeeze(0) + return mask + + +def segmentation_process( + relevancy_map: torch.Tensor, + thresh: float, + img_ann: dict, + prompts: list, + device: torch.device, +) -> Tuple[list, list, dict]: + """Compute segmentation IoU for each prompt across all levels. + + Replicates the original ``segmentation_process_cuda`` exactly: + 1. AvgPool2d(29) smoothing blended 50/50 with raw relevancy + 2. 
Min-max normalize to [-1, 1] then clip to [0, 1] + 3. Threshold and smooth with AvgPool2d(7) + 4. IoU against GT mask + 5. Choose level with highest max relevancy score + + Args: + relevancy_map: ``[n_levels, n_prompts, H, W]`` relevancy values. + thresh: Mask threshold (default 0.4). + img_ann: GT annotations for this frame ``{label: {"mask": ndarray, "bboxes": ndarray}}``. + prompts: List of prompt labels (same order as relevancy_map's prompt dim). + device: Torch device. + + Returns: + Tuple of: + - chosen_iou_list: per-prompt IoU at the chosen level + - chosen_lvl_list: chosen level index per prompt + - iou_all: ``{prompt: [iou_level_0, iou_level_1, iou_level_2]}`` + """ + n_head, n_prompt, h, w = relevancy_map.shape + valid_map = relevancy_map.clone() + + chosen_iou_list = [] + chosen_lvl_list = [] + iou_all = {} + + for k in range(n_prompt): + iou_lvl = torch.zeros(n_head, device=device) + for i in range(n_head): + # AvgPool smoothing (kernel=29) + avg_pool = torch.nn.AvgPool2d(kernel_size=29, stride=1, padding=14, count_include_pad=False).to(device) + avg_filtered = avg_pool(valid_map[i][k].unsqueeze(0).unsqueeze(0)) + valid_map[i][k] = 0.5 * (avg_filtered.squeeze(0).squeeze(0) + valid_map[i][k]) + + # Normalize to [-1, 1] then clip to [0, 1] + output = valid_map[i][k] + output = output - torch.min(output) + output = output / (torch.max(output) + 1e-9) + output = output * 2.0 - 1.0 # scale to [-1, 1] + output = torch.clip(output, 0, 1) + + # Threshold and smooth + mask_pred = (output > thresh).to(torch.uint8) + mask_pred = smooth_mask(mask_pred) + + # GT mask + mask_gt = torch.from_numpy(img_ann[prompts[k]]["mask"].astype(np.uint8)).to(device) + + # IoU + intersection = torch.sum(torch.logical_and(mask_gt, mask_pred)) + union = torch.sum(torch.logical_or(mask_gt, mask_pred)) + iou = intersection.float() / union.float() + iou_lvl[i] = iou + + iou_all[prompts[k]] = iou_lvl.tolist() + + # Choose level with highest max relevancy score + score_lvl = 
torch.zeros(n_head, device=valid_map.device) + for i in range(n_head): + score_lvl[i] = valid_map[i, k].max() + chosen_lvl = torch.argmax(score_lvl) + + chosen_iou_list.append(iou_lvl[chosen_lvl].cpu().item()) + chosen_lvl_list.append(chosen_lvl.cpu().item()) + + return chosen_iou_list, chosen_lvl_list, iou_all + + +def localization_process( + relevancy_map: torch.Tensor, + img_ann: dict, + device: torch.device, +) -> int: + """Compute localization accuracy (peak-in-bbox check). + + Replicates the original ``localization_process_cuda`` exactly: + 1. AvgPool2d(29) smoothing of relevancy + 2. For each prompt, find the peak location at the best level + 3. Check if peak falls inside any GT bbox + + Args: + relevancy_map: ``[n_levels, n_prompts, H, W]`` relevancy values. + img_ann: GT annotations ``{label: {"mask": ndarray, "bboxes": ndarray}}``. + device: Torch device. + + Returns: + Number of correctly localized bboxes. + """ + n_head, n_prompt, h, w = relevancy_map.shape + + positives = list(img_ann.keys()) + acc_num = 0 + + for k in range(n_prompt): + select_output = relevancy_map[:, k] # [n_head, H, W] + avg_pool = torch.nn.AvgPool2d(kernel_size=29, stride=1, padding=14, count_include_pad=False).to(device) + avg_filtered = avg_pool(select_output.unsqueeze(1)).squeeze(1) # [n_head, H, W] + + score_lvl = torch.zeros(n_head) + coord_lvl = [] + for i in range(n_head): + score = avg_filtered[i].max() + coord = torch.nonzero((avg_filtered[i] == score).to(torch.uint8)) + score_lvl[i] = score + coord_lvl.append(coord) + + selec_head = torch.argmax(score_lvl) + coord_final = coord_lvl[selec_head] + + for box in img_ann[positives[k]]["bboxes"].reshape(-1, 4): + x1, y1, x2, y2 = box + x_min, x_max = min(x1, x2), max(x1, x2) + y_min, y_max = min(y1, y2), max(y1, y2) + flag = 0 + for cord_list in coord_final: + # coord is (row, col) = (y, x) + if cord_list[1] >= x_min and cord_list[1] <= x_max and cord_list[0] >= y_min and cord_list[0] <= y_max: + acc_num += 1 + flag = 1 + 
break + if flag != 0: + break + + return acc_num + + +# --------------------------------------------------------------------------- +# Model loading helpers +# --------------------------------------------------------------------------- + + +def load_langsplatv2_model( + checkpoint_path: pathlib.Path, + gs_model_path: pathlib.Path, + device: torch.device, + logger: logging.Logger, + eval_topk: int | None = None, +) -> Tuple[LangSplatV2Model, SfmScene]: + """Load a trained LangSplatV2 model from a checkpoint. + + Args: + checkpoint_path: Path to the ``.pt`` checkpoint file. + gs_model_path: Path to the Gaussian splat PLY file. + device: Device to load the model on. + logger: Logger instance. + eval_topk: If set, override the checkpoint's topk value. The + original LangSplatV2 trains with topk=4. + + Returns: + Tuple of (LangSplatV2Model, SfmScene) from the checkpoint. + """ + # Load the base Gaussian splat + gs_model, _ = load_splats_from_file(gs_model_path, device) + logger.info(f"Loaded Gaussian splat with {gs_model.num_gaussians} gaussians from {gs_model_path}") + + # Load checkpoint and create trainer (eval-only mode) + state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False) + trainer = LangSplatV2Trainer.from_state_dict( + state_dict=state_dict, + gs_model=gs_model, + gs_model_path=gs_model_path, + device=device, + eval_only=True, + ) + + model = trainer._model + sfm_scene = trainer._sfm_scene + feature_level = trainer._cfg.feature_level + + if eval_topk is not None and model.topk != eval_topk: + logger.info(f"Overriding topk: {model.topk} (checkpoint) -> {eval_topk} (eval)") + model.topk = eval_topk + + logger.info( + f"Loaded LangSplatV2 model (feature_level={feature_level}, topk={model.topk}) " f"from {checkpoint_path}" + ) + + return model, sfm_scene + + +def render_clip_features( + model: LangSplatV2Model, + world_to_camera: torch.Tensor, + projection: torch.Tensor, + image_width: int, + image_height: int, +) -> torch.Tensor: + 
"""Render CLIP feature maps from a LangSplatV2 model. + + Args: + model: Trained LangSplatV2Model. + world_to_camera: ``[1, 4, 4]`` world-to-camera matrix. + projection: ``[1, 3, 3]`` camera intrinsics. + image_width: Render width. + image_height: Render height. + + Returns: + Normalized CLIP features ``[H, W, 512]``. + """ + with torch.no_grad(): + feature_maps, _ = model( + world_to_camera=world_to_camera, + projection=projection, + image_width=image_width, + image_height=image_height, + ) + # feature_maps: [1, H, W, 512], normalize + feat = feature_maps[0] # [H, W, 512] + feat = feat / (feat.norm(dim=-1, keepdim=True) + 1e-10) + return feat + + +# --------------------------------------------------------------------------- +# Visualization +# --------------------------------------------------------------------------- + + +def save_frame_visualization( + output_path: pathlib.Path, + frame_idx: int, + gt_img: np.ndarray | None, + relevancy_map: torch.Tensor, + prompts: list, + img_ann: dict, + chosen_iou_list: list, + chosen_lvl_list: list, + thresh: float, + device: torch.device, +): + """Save a per-frame visualization showing relevancy and segmentation. + + Creates a grid with one row per prompt showing: + - GT image with GT mask overlay + - Relevancy heatmap at chosen level + - Predicted mask at chosen level + - Overlay (predicted=red, GT=blue) + + Args: + output_path: Directory for output images. + frame_idx: Frame index. + gt_img: Optional GT image ``[H, W, 3]`` (RGB uint8). + relevancy_map: ``[n_levels, n_prompts, H, W]``. + prompts: List of prompt labels. + img_ann: GT annotations for this frame. + chosen_iou_list: IoU values per prompt. + chosen_lvl_list: Chosen level per prompt. + thresh: Mask threshold. + device: Torch device. 
+ """ + n_prompts = len(prompts) + n_levels, _, h, w = relevancy_map.shape + + fig, axes = plt.subplots(n_prompts, 4, figsize=(24, 6 * n_prompts)) + if n_prompts == 1: + axes = axes.reshape(1, -1) + + for k, prompt in enumerate(prompts): + lvl = chosen_lvl_list[k] + iou = chosen_iou_list[k] + + # Get relevancy at chosen level + relev = relevancy_map[lvl, k].cpu().numpy() + + # Recompute predicted mask at chosen level (same as segmentation_process) + relev_t = relevancy_map[lvl, k].clone() + avg_pool = torch.nn.AvgPool2d(kernel_size=29, stride=1, padding=14, count_include_pad=False).to(device) + avg_filtered = avg_pool(relev_t.unsqueeze(0).unsqueeze(0)) + blended = 0.5 * (avg_filtered.squeeze(0).squeeze(0) + relev_t) + output = blended - torch.min(blended) + output = output / (torch.max(output) + 1e-9) + output = output * 2.0 - 1.0 + output = torch.clip(output, 0, 1) + mask_pred = (output > thresh).to(torch.uint8) + mask_pred = smooth_mask(mask_pred) + mask_pred_np = mask_pred.cpu().numpy() + + mask_gt = img_ann[prompt]["mask"] + + # Col 0: GT image with GT mask overlay + if gt_img is not None: + overlay_gt = gt_img.copy().astype(np.float32) / 255.0 + else: + overlay_gt = np.zeros((h, w, 3), dtype=np.float32) + overlay_gt[:, :, 2] = np.clip(overlay_gt[:, :, 2] + mask_gt * 0.4, 0, 1) + axes[k, 0].imshow(overlay_gt) + axes[k, 0].set_title(f'GT: "{prompt}"') + axes[k, 0].axis("off") + + # Col 1: Relevancy heatmap + im = axes[k, 1].imshow(relev, cmap="jet", vmin=0, vmax=1) + axes[k, 1].set_title(f"Relevancy (level {lvl + 1})") + axes[k, 1].axis("off") + plt.colorbar(im, ax=axes[k, 1], fraction=0.046, pad=0.04) + + # Col 2: Predicted mask + axes[k, 2].imshow(mask_pred_np, cmap="gray") + axes[k, 2].set_title(f"Pred mask (thresh={thresh})") + axes[k, 2].axis("off") + + # Col 3: Overlay (red=pred, blue=GT) + if gt_img is not None: + overlay = gt_img.copy().astype(np.float32) / 255.0 + else: + overlay = np.zeros((h, w, 3), dtype=np.float32) + mask_overlay = 
np.zeros_like(overlay) + mask_overlay[:, :, 0] = mask_pred_np # Red for prediction + mask_overlay[:, :, 2] = mask_gt # Blue for GT + overlay = overlay * 0.5 + mask_overlay * 0.5 + axes[k, 3].imshow(np.clip(overlay, 0, 1)) + axes[k, 3].contour(mask_gt, colors="blue", linewidths=1, linestyles="solid") + axes[k, 3].contour(mask_pred_np, colors="red", linewidths=1, linestyles="dashed") + axes[k, 3].set_title(f"Overlay (IoU={iou * 100:.1f}%)") + axes[k, 3].axis("off") + + fig.suptitle(f"Frame {frame_idx}") + fig.tight_layout() + fig.savefig(output_path / f"frame_{frame_idx:05d}.jpg", dpi=150, pil_kwargs={"quality": 90}) + plt.close(fig) + + +# --------------------------------------------------------------------------- +# Main evaluation +# --------------------------------------------------------------------------- + + +@dataclass +class EvaluationConfig: + """Configuration for LERF evaluation.""" + + lerf_root: pathlib.Path = pathlib.Path("data/lerf_ovs") + """Root directory of the LERF-OVS dataset.""" + + results_root: pathlib.Path = pathlib.Path("langsplatv2_results") + """Directory containing trained model checkpoints. + + Expected layout:: + + results_root/ + _level_1.pt + _level_2.pt + _level_3.pt + """ + + reconstructions_root: pathlib.Path | None = pathlib.Path("reconstructions") + """Directory containing per-scene Gaussian splat reconstructions. + + Expected layout:: + + reconstructions_root/ + .ply + + If None, falls back to the LERF dataset structure at + ``lerf_root//output//point_cloud/iteration_30000/point_cloud.ply``. 
+ """ + + output_dir: pathlib.Path = pathlib.Path("lerf_eval_results") + """Directory to save evaluation results and visualizations.""" + + device: str = "cuda" + """Device for computation.""" + + mask_thresh: float = 0.4 + """Threshold for converting relevancy map to binary mask (matching original).""" + + save_visualizations: bool = True + """Whether to save per-frame visualization images.""" + + eval_topk: int = 4 + """Number of codebook entries to combine at evaluation time. + + The original LangSplatV2 trains with topk=4""" + + +def get_camera_for_frame( + sfm_scene, + frame_idx: int, + device: torch.device, + logger: logging.Logger, +) -> Tuple[torch.Tensor, torch.Tensor, int, int] | None: + """Get camera parameters for a specific frame index. + + Images in the SfmScene are sorted by name to match the COLMAP ordering + used by the original LangSplatV2 dataset. + + Args: + sfm_scene: The SfmScene from the checkpoint. + frame_idx: 0-based frame index. + device: Torch device. + logger: Logger instance. + + Returns: + Tuple of (world_to_camera [1,4,4], projection [1,3,3], height, width) + or None if index is out of range. + """ + # Sort images by name to match COLMAP ordering + sorted_images = sorted(sfm_scene.images, key=lambda img: img.image_path) + + if frame_idx >= len(sorted_images): + logger.error(f"Frame index {frame_idx} out of range ({len(sorted_images)} images)") + return None + + img_meta = sorted_images[frame_idx] + c2w = torch.from_numpy(img_meta.camera_to_world_matrix).float() + K = torch.from_numpy(img_meta.camera_metadata.projection_matrix).float() + h = img_meta.camera_metadata.height + w = img_meta.camera_metadata.width + + w2c = torch.linalg.inv(c2w).contiguous() + + return ( + w2c.unsqueeze(0).to(device), + K.unsqueeze(0).to(device), + h, + w, + ) + + +def run_lerf_evaluation( + scene_name: str, + config: EvaluationConfig, + logger: logging.Logger, +) -> dict | None: + """Run evaluation on a single LERF scene. 
+ + Args: + scene_name: Name of the scene (e.g. "teatime", "figurines"). + config: Evaluation configuration. + logger: Logger instance. + + Returns: + Dictionary with evaluation results, or None if evaluation failed. + """ + device = torch.device(config.device) + + # --- Locate checkpoint files --- + level_checkpoints = [] + for level in [1, 2, 3]: + ckpt_path = config.results_root / f"{scene_name}_level_{level}.pt" + if not ckpt_path.exists(): + logger.error(f"Checkpoint not found: {ckpt_path}") + return None + level_checkpoints.append(ckpt_path) + + # --- Locate GS model for this scene --- + gs_path = None + # Try reconstructions_root/.ply first + if config.reconstructions_root is not None: + candidate = config.reconstructions_root / f"{scene_name}.ply" + if candidate.exists(): + gs_path = candidate + else: + # Also try .pt/.pth extensions + for ext in (".pt", ".pth"): + candidate = config.reconstructions_root / f"{scene_name}{ext}" + if candidate.exists(): + gs_path = candidate + break + + # Fall back to LERF dataset structure + if gs_path is None: + gs_path = ( + config.lerf_root + / scene_name + / "output" + / scene_name + / "point_cloud" + / "iteration_30000" + / "point_cloud.ply" + ) + + if not gs_path.exists(): + logger.error( + f"GS model not found for scene '{scene_name}'. " + f"Searched: reconstructions_root={config.reconstructions_root}, " + f"lerf_root={config.lerf_root}//output/..." 
+ ) + return None + + logger.info(f"Using GS model: {gs_path}") + + # --- Load ground truth --- + label_dir = config.lerf_root / "label" / scene_name + if not label_dir.exists(): + logger.error(f"Label directory not found: {label_dir}") + return None + + gt_ann, image_shape, gt_img_paths = load_lerf_ground_truth(label_dir, logger) + eval_frame_indices = [int(idx) for idx in gt_ann.keys()] + gt_h, gt_w = image_shape + + # --- Load models (3 levels) --- + models: list[LangSplatV2Model] = [] + sfm_scene = None + for i, ckpt_path in enumerate(level_checkpoints): + model, scene = load_langsplatv2_model(ckpt_path, gs_path, device, logger, eval_topk=config.eval_topk) + models.append(model) + if sfm_scene is None: + sfm_scene = scene + + # --- Load OpenCLIP for relevancy --- + clip_relevancy = OpenCLIPRelevancy(device=config.device) + + # --- Evaluate each annotated frame --- + chosen_iou_all = [] + chosen_lvl_list_all = [] + acc_num_total = 0 + total_bboxes = 0 + per_frame_results = [] + + scene_output_dir = config.output_dir / scene_name + scene_output_dir.mkdir(parents=True, exist_ok=True) + + for frame_enum_idx, frame_idx in enumerate( + tqdm.tqdm(eval_frame_indices, desc=f"Evaluating {scene_name}", leave=False) + ): + # Load GT image for visualization (optional) + gt_img = None + if frame_enum_idx < len(gt_img_paths) and config.save_visualizations: + gt_img_bgr = cv2.imread(gt_img_paths[frame_enum_idx]) + if gt_img_bgr is not None: + gt_img = cv2.cvtColor(gt_img_bgr, cv2.COLOR_BGR2RGB) + + # Get camera parameters for this frame + cam_info = get_camera_for_frame(sfm_scene, frame_idx, device, logger) + if cam_info is None: + logger.warning(f"Skipping frame {frame_idx}: camera not found") + continue + w2c, K, img_h, img_w = cam_info + + # Render CLIP features from all 3 levels + sem_feats = [] + for model in models: + feat = render_clip_features(model, w2c, K, img_w, img_h) + sem_feats.append(feat) + + # Stack: [3, H, W, 512] + sem_map = torch.stack(sem_feats, dim=0) + + 
# Get GT annotations for this frame + img_ann = gt_ann[str(frame_idx)] + prompts = list(img_ann.keys()) + + # Set positive prompts in CLIP model + clip_relevancy.set_positives(prompts) + + # Compute relevancy: [3, n_prompts, H, W] + relevancy_map = clip_relevancy.get_relevancy_map(sem_map) + + # Resize relevancy to GT resolution if needed + _, _, relev_h, relev_w = relevancy_map.shape + if relev_h != gt_h or relev_w != gt_w: + relevancy_map = torch.nn.functional.interpolate( + relevancy_map.reshape(-1, 1, relev_h, relev_w), + size=(gt_h, gt_w), + mode="bilinear", + align_corners=False, + ).reshape(3, len(prompts), gt_h, gt_w) + + # Segmentation IoU + # Clone relevancy_map because segmentation_process modifies it in-place + c_iou_list, c_lvl_list, iou_all = segmentation_process( + relevancy_map.clone(), config.mask_thresh, img_ann, prompts, device + ) + chosen_iou_all.extend(c_iou_list) + chosen_lvl_list_all.extend(c_lvl_list) + + # Localization accuracy + acc_num_img = localization_process(relevancy_map.clone(), img_ann, device) + acc_num_total += acc_num_img + total_bboxes += len(prompts) + + # Per-frame record + frame_result = { + "frame_idx": frame_idx, + "prompts": prompts, + "ious": c_iou_list, + "chosen_levels": c_lvl_list, + "iou_all_levels": iou_all, + "localization_correct": acc_num_img, + "num_bboxes": len(prompts), + } + per_frame_results.append(frame_result) + + logger.info( + f" Frame {frame_idx}: " + + ", ".join( + f'"{p}" IoU={iou * 100:.1f}% (lvl {lvl + 1})' for p, iou, lvl in zip(prompts, c_iou_list, c_lvl_list) + ) + + f" | loc={acc_num_img}/{len(prompts)}" + ) + + # Save visualization + if config.save_visualizations: + save_frame_visualization( + scene_output_dir, + frame_idx, + gt_img, + relevancy_map, + prompts, + img_ann, + c_iou_list, + c_lvl_list, + config.mask_thresh, + device, + ) + + # --- Scene summary --- + if not chosen_iou_all: + logger.warning(f"No successful evaluations for scene {scene_name}") + return None + + mean_iou = 
float(np.mean(chosen_iou_all)) + loc_accuracy = acc_num_total / total_bboxes if total_bboxes > 0 else 0.0 + + logger.info( + f"Scene {scene_name}: mean IoU = {mean_iou * 100:.1f}%, localization accuracy = {loc_accuracy * 100:.1f}%" + ) + logger.info(f" Chosen levels: {chosen_lvl_list_all}") + + return { + "scene": scene_name, + "mean_iou": mean_iou, + "localization_accuracy": loc_accuracy, + "localization_correct": acc_num_total, + "localization_total": total_bboxes, + "per_prompt_ious": chosen_iou_all, + "chosen_levels": chosen_lvl_list_all, + "mask_thresh": config.mask_thresh, + "per_frame_results": per_frame_results, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate LangSplatV2 models on the LERF-OVS dataset", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--lerf-root", + type=pathlib.Path, + default=pathlib.Path("data/lerf_ovs"), + help="Root directory of the LERF-OVS dataset", + ) + parser.add_argument( + "--results-root", + type=pathlib.Path, + default=pathlib.Path("langsplatv2_results"), + help="Directory containing trained model checkpoints " + "(_level_1.pt, _level_2.pt, _level_3.pt)", + ) + parser.add_argument( + "--reconstructions-root", + type=pathlib.Path, + default=pathlib.Path("reconstructions"), + help="Directory containing per-scene Gaussian splat reconstructions " + "(.ply). If not found, falls back to LERF dataset structure.", + ) + parser.add_argument( + "--output-dir", + type=pathlib.Path, + default=pathlib.Path("lerf_eval_results"), + help="Directory to save evaluation results and visualizations", + ) + parser.add_argument( + "--scenes", + nargs="+", + type=str, + default=None, + help="Specific scenes to evaluate. 
If not specified, discovers available scenes.", + ) + parser.add_argument( + "--mask-thresh", + type=float, + default=0.4, + help="Threshold for converting relevancy map to binary mask", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for computation (cuda or cuda:N)", + ) + parser.add_argument( + "--eval-topk", + type=int, + default=4, + help="Number of codebook entries to combine at evaluation time. " + "The original LangSplatV2 trains with topk=4.", + ) + parser.add_argument( + "--no-visualizations", + action="store_true", + help="Disable saving per-frame visualization images", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + # Set up logging + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logger = logging.getLogger("lerf_eval") + + # Build config + config = EvaluationConfig( + lerf_root=args.lerf_root, + results_root=args.results_root, + reconstructions_root=args.reconstructions_root, + output_dir=args.output_dir, + device=args.device, + mask_thresh=args.mask_thresh, + save_visualizations=not args.no_visualizations, + eval_topk=args.eval_topk, + ) + + # Determine scenes to evaluate + if args.scenes: + scenes_to_eval = args.scenes + else: + # Auto-discover scenes from checkpoint files + ckpt_files = list(config.results_root.glob("*_level_1.pt")) + if ckpt_files: + scenes_to_eval = [f.stem.replace("_level_1", "") for f in ckpt_files] + else: + # Fall back to label directories + label_root = config.lerf_root / "label" + if label_root.exists(): + scenes_to_eval = [d.name for d in label_root.iterdir() if d.is_dir()] + else: + logger.error("No scenes found. 
Specify --scenes or check paths.") + return + + logger.info(f"Evaluating scenes: {scenes_to_eval}") + logger.info(f"Mask threshold: {config.mask_thresh}") + + # Run evaluation + results = [] + for scene_name in scenes_to_eval: + logger.info("=" * 60) + logger.info(f"Evaluating scene: {scene_name}") + logger.info("=" * 60) + + result = run_lerf_evaluation(scene_name, config, logger) + if result is not None: + results.append(result) + + # Print summary + logger.info("=" * 60) + logger.info("EVALUATION COMPLETE - Summary:") + logger.info("=" * 60) + + if results: + ious = [r["mean_iou"] for r in results] + accs = [r["localization_accuracy"] for r in results] + overall_mean_iou = float(np.mean(ious)) + overall_loc_acc = float(np.mean(accs)) + + logger.info(f"{'Scene':<20} {'mean IoU':>10} {'Loc Acc':>10}") + logger.info("-" * 42) + for r in results: + logger.info(f"{r['scene']:<20} {r['mean_iou'] * 100:>9.1f}% {r['localization_accuracy'] * 100:>9.1f}%") + logger.info("-" * 42) + logger.info(f"{'Overall':<20} {overall_mean_iou * 100:>9.1f}% {overall_loc_acc * 100:>9.1f}%") + + # Save results JSON + results_file = config.output_dir / "lerf_results.json" + results_file.parent.mkdir(parents=True, exist_ok=True) + + # Strip per-frame results for the summary (keep in per-scene files) + summary_results = [] + for r in results: + # Save per-scene detailed results + scene_results_file = config.output_dir / r["scene"] / "results.json" + scene_results_file.parent.mkdir(parents=True, exist_ok=True) + with open(scene_results_file, "w") as f: + json.dump(r, f, indent=2) + logger.info(f"Per-scene results saved to {scene_results_file}") + + # Summary (without per-frame details) + summary = {k: v for k, v in r.items() if k != "per_frame_results"} + summary_results.append(summary) + + with open(results_file, "w") as f: + json.dump( + { + "results": summary_results, + "overall_mean_iou": overall_mean_iou, + "overall_localization_accuracy": overall_loc_acc, + "config": { + "mask_thresh": 
config.mask_thresh, + "device": config.device, + }, + }, + f, + indent=2, + ) + logger.info(f"Summary results saved to {results_file}") + else: + logger.warning("No successful evaluations") + + +if __name__ == "__main__": + main() diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/config.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/config.py index eaa05cb..56ff78d 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/config.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/config.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # from dataclasses import dataclass, field +from pathlib import Path from typing import Literal import numpy as np @@ -17,7 +18,12 @@ TransformScene, ) -from .scene_transforms import ComputeCLIPFeatures, ComputeMultiScaleSAM2Masks +from .scene_transforms import ( + ComputeCLIPFeatures, + ComputeMultiScaleSAM1Masks, + ComputeMultiScaleSAM2Masks, + ImportOriginalLangSplatV2Features, +) @dataclass @@ -25,13 +31,77 @@ class SAM2Config: """Configuration for SAM2 multi-scale mask generation.""" checkpoint: Literal["large", "small", "tiny", "base_plus"] = "large" - """SAM2 checkpoint size to use.""" + """SAM2 checkpoint size to use. Larger models produce higher-quality masks + but run a heavier image encoder. "base_plus" or "small" can noticeably + reduce per-image encoding time at some cost to mask quality.""" + + points_per_side: int = 32 + """Grid density for point prompts. Total points = points_per_side**2 + (e.g. 32 -> 1024, 16 -> 256). Reducing this value is the single + largest lever for speeding up mask generation because each point + requires a decoder forward pass.""" + + points_per_batch: int = 256 + """Points processed simultaneously by SAM2. Larger values reduce the + number of decoder forward passes (fewer kernel launches) at the cost + of higher peak GPU memory. 256 is safe on 24 GB+ GPUs.""" + + pred_iou_thresh: float = 0.5 + """Predicted IoU threshold for mask filtering. 
SAM2 predicts + substantially lower IoU scores than SAM1 (e.g. small-mask means + of 0.3-0.6 vs 0.8-0.9), so the threshold is set lower to achieve + comparable mask survival rates.""" + + stability_score_thresh: float = 0.85 + """Stability score threshold for mask filtering.""" + + crop_n_layers: int = 1 + """Number of crop layers. With ``crop_n_layers=1`` (the LangSplatV2 + default), 5 crops are generated per image (1 full + 4 overlapping + sub-crops), each running a full encoder + decoder pass. Setting + this to 0 reduces to a single full-image crop (~5x fewer passes) + at the cost of losing detail from sub-crop masks.""" + + crop_n_points_downscale_factor: int = 2 + """Point grid downscale factor per crop layer. Sub-crop layers use + (points_per_side / 2)**2 points instead of points_per_side**2, + reducing decoder cost on sub-crops by ~4x while keeping the + full-image crop at full density. The original LangSplatV2 uses 1 + with SAM ViT-H; 2 is a reasonable default with SAM2.""" + + min_mask_region_area: int = 100 + """Minimum mask region area for post-processing (matching the original + LangSplatV2 which uses ``min_mask_region_area=100``).""" + + box_nms_thresh: float = 0.7 + """Box NMS IoU threshold within each crop.""" + + nms_iou_thr: float = 0.8 + """IoU threshold for mask NMS post-processing.""" + + nms_score_thr: float = 0.7 + """Score threshold for mask NMS.""" + + nms_inner_thr: float = 0.5 + """Inner overlap threshold for mask NMS.""" + + +@dataclass +class SAM1Config: + """Configuration for SAM1 multi-scale mask generation. + + Default values match the original LangSplatV2 ``preprocess.py`` exactly + (SAM ViT-H with ``crop_n_layers=1``, ``crop_n_points_downscale_factor=1``). + """ + + checkpoint: Literal["vit_h", "vit_l", "vit_b"] = "vit_h" + """SAM1 model variant. 
The original LangSplatV2 uses ViT-H.""" points_per_side: int = 32 """Grid density for point prompts.""" - points_per_batch: int = 64 - """Points processed simultaneously by SAM2.""" + points_per_batch: int = 256 + """Points processed simultaneously by SAM1.""" pred_iou_thresh: float = 0.7 """Predicted IoU threshold for mask filtering.""" @@ -40,15 +110,14 @@ class SAM2Config: """Stability score threshold for mask filtering.""" crop_n_layers: int = 1 - """Number of crop layers. 1 = also run SAM on image crops (matching - the original LangSplatV2 which uses ``crop_n_layers=1``).""" + """Number of crop layers (1 = also run on crops, matching original).""" crop_n_points_downscale_factor: int = 1 - """Point grid downscale factor per crop layer.""" + """Point grid downscale factor per crop layer. The original LangSplatV2 + uses 1 (no downscaling on sub-crops).""" min_mask_region_area: int = 100 - """Minimum mask region area for post-processing (matching the original - LangSplatV2 which uses ``min_mask_region_area=100``).""" + """Minimum mask region area for post-processing.""" box_nms_thresh: float = 0.7 """Box NMS IoU threshold within each crop.""" @@ -124,12 +193,20 @@ class LangSplatV2PreprocessConfig: crop_to_points: bool = False """If True, crop scene bounds to the point cloud extent.""" - # SAM2 configuration + # SAM configuration + sam_model: Literal["sam1", "sam2"] = "sam2" + """Which SAM model to use for mask generation. ``"sam1"`` uses the + original Segment Anything Model (ViT-H by default) matching the original + LangSplatV2 pipeline. 
``"sam2"`` uses SAM2 (Hiera-Large by default).""" + + sam1: SAM1Config = field(default_factory=SAM1Config) + """Configuration for SAM1 mask generation (used when ``sam_model="sam1"``).""" + sam2: SAM2Config = field(default_factory=SAM2Config) - """Configuration for SAM2 mask generation.""" + """Configuration for SAM2 mask generation (used when ``sam_model="sam2"``).""" compute_sam_masks: bool = True - """Whether to compute SAM2 segmentation masks.""" + """Whether to compute SAM segmentation masks.""" # CLIP configuration clip: OpenCLIPConfig = field(default_factory=OpenCLIPConfig) @@ -138,6 +215,15 @@ class LangSplatV2PreprocessConfig: compute_clip_features: bool = True """Whether to compute CLIP features for masked regions.""" + # Import original features (bypasses SAM + CLIP) + original_features_dir: Path | None = None + """Path to the original LangSplatV2 ``language_features/`` directory. + + When set, imports pre-computed ``_f.npy`` / ``_s.npy`` files from the + original LangSplatV2 ``preprocess.py`` instead of running SAM mask + generation and CLIP feature encoding. Useful for A/B testing against + the original pipeline.""" + # Device device: torch.device | str = "cuda" """Device for model inference.""" @@ -155,9 +241,10 @@ def build_scene_transforms( 2. Point cloud percentile filtering 3. Image downsampling 4. Image filtering by visible points - 5. Multi-scale SAM2 mask generation - 6. CLIP feature encoding - 7. Scene cropping (optional) + 5a. Multi-scale SAM1/SAM2 mask generation + CLIP feature encoding, OR + 5b. Import original LangSplatV2 features (when ``original_features_dir`` + is set) + 6. 
Scene cropping (optional) Args: normalization_transform: Optional 4x4 transformation matrix @@ -184,28 +271,26 @@ def build_scene_transforms( rescaled_jpeg_quality=self.rescale_jpeg_quality, ), FilterImagesWithLowPoints(min_num_points=self.min_points_per_image), - # SAM2 mask generation - ( - ComputeMultiScaleSAM2Masks( - checkpoint=self.sam2.checkpoint, - points_per_side=self.sam2.points_per_side, - points_per_batch=self.sam2.points_per_batch, - pred_iou_thresh=self.sam2.pred_iou_thresh, - stability_score_thresh=self.sam2.stability_score_thresh, - crop_n_layers=self.sam2.crop_n_layers, - crop_n_points_downscale_factor=self.sam2.crop_n_points_downscale_factor, - min_mask_region_area=self.sam2.min_mask_region_area, - box_nms_thresh=self.sam2.box_nms_thresh, - nms_iou_thr=self.sam2.nms_iou_thr, - nms_score_thr=self.sam2.nms_score_thr, - nms_inner_thr=self.sam2.nms_inner_thr, - device=self.device, + ] + + if self.original_features_dir is not None: + # Import pre-computed features from original LangSplatV2 + transforms.append( + ImportOriginalLangSplatV2Features( + original_features_dir=self.original_features_dir, + clip_n_dims=self.clip.clip_n_dims, ) - if self.compute_sam_masks - else Identity() - ), - # CLIP feature encoding - ( + ) + else: + # Standard pipeline: SAM masks then CLIP features + if self.compute_sam_masks: + if self.sam_model == "sam1": + transforms.append(self.build_sam1_transform()) + else: + transforms.append(self.build_sam2_transform()) + else: + transforms.append(Identity()) + transforms.append( ComputeCLIPFeatures( clip_model_type=self.clip.clip_model_type, clip_model_pretrained=self.clip.clip_model_pretrained, @@ -214,8 +299,7 @@ def build_scene_transforms( ) if self.compute_clip_features else Identity() - ), - ] + ) # Optional scene cropping if self.crop_bbox is not None: @@ -225,6 +309,29 @@ def build_scene_transforms( return Compose(*transforms) + def build_sam1_transform(self): + """ + Build only the SAM1 mask generation transform. 
+ + Returns: + ComputeMultiScaleSAM1Masks transform. + """ + return ComputeMultiScaleSAM1Masks( + checkpoint=self.sam1.checkpoint, + points_per_side=self.sam1.points_per_side, + points_per_batch=self.sam1.points_per_batch, + pred_iou_thresh=self.sam1.pred_iou_thresh, + stability_score_thresh=self.sam1.stability_score_thresh, + crop_n_layers=self.sam1.crop_n_layers, + crop_n_points_downscale_factor=self.sam1.crop_n_points_downscale_factor, + min_mask_region_area=self.sam1.min_mask_region_area, + box_nms_thresh=self.sam1.box_nms_thresh, + nms_iou_thr=self.sam1.nms_iou_thr, + nms_score_thr=self.sam1.nms_score_thr, + nms_inner_thr=self.sam1.nms_inner_thr, + device=self.device, + ) + def build_sam2_transform(self): """ Build only the SAM2 mask generation transform. @@ -277,7 +384,9 @@ class LangSplatV2ModelConfig: """Dimensionality of CLIP embeddings.""" topk: int = 4 - """Number of non-zero sparse coefficients per VQ layer.""" + """Number of non-zero sparse coefficients per VQ layer. + + The original LangSplatV2 uses topk=4 for both training and evaluation.""" @dataclass @@ -318,6 +427,11 @@ class LangSplatV2TrainingConfig: normalize_features: bool = False """Whether to L2-normalize predicted features before computing loss.""" + init_codebooks_all_levels: bool = True + """If True, initialize codebooks using features from ALL scale levels + (matching original LangSplatV2). 
If False, use only the target + ``feature_level``.""" + model: LangSplatV2ModelConfig = field(default_factory=LangSplatV2ModelConfig) """Model architecture configuration.""" diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/__init__.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/__init__.py new file mode 100644 index 0000000..6e14057 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/__init__.py @@ -0,0 +1,3 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/__init__.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/__init__.py new file mode 100644 index 0000000..56d9ad9 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/__init__.py @@ -0,0 +1,23 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +from pathlib import Path +from typing import Optional + +_dataset_root = None + + +def set_dataset_root(dataset_root: Path): + global _dataset_root + _dataset_root = dataset_root + + +def get_dataset_root() -> Optional[Path]: + global _dataset_root + if _dataset_root is None: + return None + _dataset_root.mkdir(parents=True, exist_ok=True) + return _dataset_root + + +__all__ = ["set_dataset_root", "get_dataset_root"] diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/lerf.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/lerf.py new file mode 100644 index 0000000..4041ab5 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/lerf.py @@ -0,0 +1,78 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +from pathlib import Path + +from . 
import get_dataset_root +from .util import download_google_drive_file + +# Google Drive file ID extracted from the download link in the LangSplatV2 README: +# https://drive.google.com/file/d/1QF1Po5p5DwTjFHu6tnTeYs_G0egMVmHt/view?usp=sharing +LERF_GDRIVE_FILE_ID = "1QF1Po5p5DwTjFHu6tnTeYs_G0egMVmHt" + +# Scenes available in the LERF-OVS dataset +LERF_SCENES = [ + "figurines", + "ramen", + "teatime", + "waldo_kitchen", +] + + +def get_lerf_data_path() -> Path: + """Get the path to the LERF-OVS dataset root. + + Returns: + Path to the ``lerf_ovs`` directory inside the dataset root. + + Raises: + ValueError: If the dataset root has not been set via ``set_dataset_root()``. + """ + dataset_root = get_dataset_root() + if dataset_root is None: + raise ValueError("Dataset root is not set. Call set_dataset_root() first.") + result = dataset_root / "lerf_ovs" + if not result.exists(): + result.mkdir(parents=True, exist_ok=True) + return result + + +def download_lerf_data(): + """ + Download the LERF-OVS dataset from Google Drive. + + The dataset contains COLMAP scenes and labelme-format ground truth labels + for open-vocabulary segmentation evaluation. + + Expected layout after download:: + + lerf_ovs/ + label/ + <scene>/ + frame_XXXXX.json + frame_XXXXX.jpg + <scene>/ (COLMAP scene) + images/ + sparse/ + output/ + + The data will be saved to the LERF data path (dataset_root/lerf_ovs/). 
+ """ + output_path = get_lerf_data_path() + + print(f"Downloading LERF-OVS dataset to: {output_path}") + + # Check if data already exists + label_dir = output_path / "label" + if label_dir.exists() and any(label_dir.iterdir()): + print(f"LERF-OVS data already exists at: {output_path}") + print("Delete the directory to re-download.") + return + + download_google_drive_file( + file_id=LERF_GDRIVE_FILE_ID, + output_path=output_path.parent, + filename="lerf_ovs.zip", + ) + + print(f"\nLERF-OVS dataset download complete: {output_path}") diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/util.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/util.py new file mode 100644 index 0000000..535612a --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/util.py @@ -0,0 +1,43 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +from pathlib import Path +from zipfile import ZipFile + + +def download_google_drive_file(file_id: str, output_path: Path, filename: str = "download.zip") -> Path: + """ + Download a file from Google Drive using gdown. + + Args: + file_id: Google Drive file ID (from the sharing URL). + output_path: Directory to save and extract to. + filename: Name for the downloaded file. + + Returns: + Path to the output directory after extraction. + """ + try: + import gdown + except ImportError: + raise ImportError( + "gdown is required to download files from Google Drive. 
" + "Install it with: pip install gdown" + ) + + output_path.mkdir(parents=True, exist_ok=True) + zip_path = output_path / filename + + url = f"https://drive.google.com/uc?id={file_id}" + print(f"Downloading from Google Drive (file_id={file_id})...") + gdown.download(url, str(zip_path), quiet=False) + + if zip_path.suffix == ".zip": + print("Extracting archive...") + with ZipFile(zip_path, "r") as zf: + zf.extractall(output_path) + # Clean up zip file + zip_path.unlink() + print(f"Extracted to: {output_path}") + + return output_path diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py new file mode 100644 index 0000000..21ba846 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py @@ -0,0 +1,149 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""OpenCLIP relevancy computation for LERF evaluation. + +Implements the relevancy scoring from the original LangSplatV2 evaluation +(``OpenCLIPNetwork.get_max_across_quick``). For each pixel, the score is the +minimum across all negative prompts of ``softmax(10 * [pos_sim, neg_sim])[0]``. + +""" +import logging +from typing import Sequence + +import open_clip +import torch +import torchvision + +logger = logging.getLogger(__name__) + +# Default negative prompts (matching the original LangSplatV2 evaluation) +DEFAULT_NEGATIVES = ("object", "things", "stuff", "texture") + + +class OpenCLIPRelevancy: + """Compute CLIP-based relevancy maps for open-vocabulary segmentation evaluation. 
+ + This class encapsulates: + + * Loading an OpenCLIP model (ViT-B-16 by default) + * Encoding positive (query) and negative (distractor) text prompts + * Computing per-pixel relevancy scores from rendered CLIP feature maps + + The relevancy computation matches the original LangSplatV2 + ``get_max_across_quick`` exactly: + + 1. Dot-product similarity between per-pixel features and all prompt embeddings + 2. For each positive prompt and each negative prompt, form a 2-way softmax + with temperature 10 + 3. Take the *minimum* positive probability across all negatives as the + final relevancy score + + Args: + clip_model_type: OpenCLIP model architecture (default ``"ViT-B-16"``). + clip_model_pretrained: Pretrained weights identifier + (default ``"laion2b_s34b_b88k"``). + negatives: Tuple of negative distractor text prompts. + device: Device for model and embeddings. + """ + + def __init__( + self, + clip_model_type: str = "ViT-B-16", + clip_model_pretrained: str = "laion2b_s34b_b88k", + negatives: Sequence[str] = DEFAULT_NEGATIVES, + device: str | torch.device = "cuda", + ): + self.device = torch.device(device) + self.clip_model_type = clip_model_type + self.clip_model_pretrained = clip_model_pretrained + + # Load OpenCLIP model + model, _, _ = open_clip.create_model_and_transforms( + clip_model_type, + pretrained=clip_model_pretrained, + precision="fp16", + ) + model.eval() + self.model = model.to(self.device) + self.tokenizer = open_clip.get_tokenizer(clip_model_type) + + # Encode negative prompts + self.negatives = tuple(negatives) + with torch.no_grad(): + tok = torch.cat([self.tokenizer(phrase) for phrase in self.negatives]).to(self.device) + self.neg_embeds = self.model.encode_text(tok) + self.neg_embeds = self.neg_embeds / self.neg_embeds.norm(dim=-1, keepdim=True) + + # Positive embeddings (set per evaluation frame) + self.positives: tuple[str, ...] 
= () + self.pos_embeds: torch.Tensor | None = None + + logger.info( + f"OpenCLIPRelevancy initialized: model={clip_model_type}, " + f"pretrained={clip_model_pretrained}, negatives={self.negatives}" + ) + + def set_positives(self, text_list: Sequence[str]) -> None: + """Encode and cache positive (query) text prompts. + + Args: + text_list: List of positive text prompts (e.g. category labels + from the ground-truth annotations). + """ + self.positives = tuple(text_list) + with torch.no_grad(): + tok = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to(self.device) + self.pos_embeds = self.model.encode_text(tok) + self.pos_embeds = self.pos_embeds / self.pos_embeds.norm(dim=-1, keepdim=True) + + @torch.no_grad() + def get_relevancy_map(self, sem_map: torch.Tensor) -> torch.Tensor: + """Compute per-pixel relevancy for all positive prompts across all levels. + + This exactly replicates ``OpenCLIPNetwork.get_max_across_quick`` from the + original LangSplatV2 evaluation. + + Args: + sem_map: CLIP feature maps of shape ``[n_levels, H, W, 512]``. + Each level corresponds to a different SAM scale model. + + Returns: + Relevancy tensor of shape ``[n_levels, n_prompts, H, W]`` where each + value is in ``[0, 1]``. 
+ """ + if self.pos_embeds is None: + raise RuntimeError("Call set_positives() before computing relevancy maps.") + + n_levels, h, w, c = sem_map.shape + n_phrases = len(self.positives) + n_negatives = len(self.negatives) + + # Flatten spatial dims: [n_levels, H*W, 512] + sem_map_flat = sem_map.reshape(n_levels, h * w, c).contiguous() + + # All prompt embeddings: [P+N, 512] + phrase_embeds = torch.cat([self.pos_embeds, self.neg_embeds], dim=0) + phrase_embeds = phrase_embeds.to(sem_map.dtype).to(sem_map.device) + + # Dot-product similarity: [n_levels, H*W, P+N] + sim = torch.einsum("nqc,pc->nqp", sem_map_flat, phrase_embeds) + + # Split into positive and negative similarities + pos_vals = sim[:, :, :n_phrases] # [n_levels, H*W, P] + neg_vals = sim[:, :, n_phrases:] # [n_levels, H*W, N] + + # For each (positive, negative) pair, compute 2-way softmax with temperature 10 + repeated_pos = pos_vals.unsqueeze(-1).expand(-1, -1, -1, n_negatives) # [n_levels, H*W, P, N] + neg_vals_exp = neg_vals.unsqueeze(2).expand(-1, -1, n_phrases, -1) # [n_levels, H*W, P, N] + + sims = torch.stack([repeated_pos, neg_vals_exp], dim=-1) # [n_levels, H*W, P, N, 2] + softmax = torch.softmax(10 * sims, dim=-1) # [n_levels, H*W, P, N, 2] + + # Take minimum positive probability across all negatives + min_pos_prob, _ = softmax[..., 0].min(dim=-1) # [n_levels, H*W, P] + + # Reshape back to spatial dimensions + relev_map = min_pos_prob.permute(0, 2, 1).reshape(n_levels, n_phrases, h, w) + + return relev_map diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py index ea4d30f..e008fcc 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py @@ -43,13 +43,22 @@ def calculate_langsplatv2_loss( ) -> dict[str, torch.Tensor]: """Compute the LangSplatV2 language feature loss. 
- Compares predicted CLIP features with ground truth features, masked to only - include valid pixels (those covered by a SAM mask). + Both predicted and ground-truth features are zeroed for unmapped pixels + (mask == False), matching the original's + ``language_feature * language_feature_mask`` approach. Per-pixel loss + is then zeroed for unmapped pixels and averaged over **all** pixels. + This means: + + - Unmapped pixels contribute 0 to the reported loss. + - The mean over all pixels implicitly down-scales gradients from valid + pixels by ``N_valid / N_total``, matching the original's gradient + magnitudes exactly. Args: - predicted_features: Predicted feature map of shape ``[B, H, W, clip_n_dims]``. - gt_features: Ground truth feature map of shape ``[B, H, W, clip_n_dims]``. - mask: Boolean mask of shape ``[B, H, W]`` indicating valid pixels. + predicted_features: Predicted feature map, shape ``[B, H, W, C]`` + or ``[H, W, C]``. + gt_features: Ground truth feature map, same shape. + mask: Boolean mask, shape ``[B, H, W]`` or ``[H, W]``. use_cosine_loss: Whether to include cosine similarity loss. use_l1_loss: Whether to include L1 loss. normalize_features: Whether to L2-normalize predicted features @@ -59,52 +68,38 @@ def calculate_langsplatv2_loss( Dictionary with loss components: - ``"total_loss"``: Combined loss value. - ``"cosine_loss"``: Cosine loss component (if enabled). + - ``"cosine_loss_valid"``: Cosine loss on valid pixels only (for logging). - ``"l1_loss"``: L1 loss component (if enabled). 
""" assert use_cosine_loss or use_l1_loss, "At least one loss type must be enabled" - # Optionally normalize predicted features if normalize_features: predicted_features = predicted_features / (predicted_features.norm(dim=-1, keepdim=True) + 1e-10) loss_dict: dict[str, torch.Tensor] = {} total_loss = torch.tensor(0.0, device=predicted_features.device) - if not mask.any(): - # No valid pixels - return zero loss - loss_dict["total_loss"] = total_loss - if use_cosine_loss: - loss_dict["cosine_loss"] = total_loss - if use_l1_loss: - loss_dict["l1_loss"] = total_loss - return loss_dict - - # Gather only valid pixels (clean signal, no NaN risk from torch.empty). - valid_pred = predicted_features[mask] # [N_valid, clip_n_dims] - valid_gt = gt_features[mask] # [N_valid, clip_n_dims] - - # The original LangSplatV2 computes .mean() over ALL H*W pixels, where - # masked-out pixels are zero-vectors that contribute ~0 to the sum but - # inflate the denominator. This implicitly scales gradients down by - # (N_valid / N_total). 
We replicate this by computing the loss on valid - # pixels only (clean, interpretable values) and multiplying by the mask - # coverage fraction so that gradient magnitudes match the original exactly: - # - # grad_original = (1/N_total) * sum_valid(dL_i) - # grad_ours = (ratio/N_valid) * sum_valid(dL_i) - # = (1/N_total) * sum_valid(dL_i) [identical] - mask_fraction = mask.sum().float() / mask.numel() + mask_expanded = mask.unsqueeze(-1) # [..., 1] + masked_pred = predicted_features * mask_expanded + masked_gt = gt_features * mask_expanded if use_cosine_loss: - cos_loss_raw = cosine_loss(valid_pred, valid_gt) - cos_loss = cos_loss_raw * mask_fraction - loss_dict["cosine_loss"] = cos_loss - # Mean over valid pixels only (no mask_fraction); stable for logging when coverage varies - loss_dict["cosine_loss_valid"] = cos_loss_raw - total_loss = total_loss + cos_loss + per_pixel_cos = F.cosine_similarity(masked_pred, masked_gt, dim=-1) + per_pixel_loss = (1.0 - per_pixel_cos) * mask.float() + cos_loss_all = per_pixel_loss.sum() / mask.numel() + loss_dict["cosine_loss"] = cos_loss_all + total_loss = total_loss + cos_loss_all + + if mask.any(): + valid_pred = predicted_features[mask] + valid_gt = gt_features[mask] + loss_dict["cosine_loss_valid"] = cosine_loss(valid_pred, valid_gt) + else: + loss_dict["cosine_loss_valid"] = cos_loss_all if use_l1_loss: - l1 = l1_loss(valid_pred, valid_gt) * mask_fraction + per_pixel_l1 = torch.abs(masked_pred - masked_gt) * mask_expanded + l1 = per_pixel_l1.sum() / mask.numel() loss_dict["l1_loss"] = l1 total_loss = total_loss + l1 diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/__init__.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/__init__.py index de05bce..b7d182f 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/__init__.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/__init__.py @@ -2,9 +2,13 @@ # 
SPDX-License-Identifier: Apache-2.0 # from .clip_feature_encoding import ComputeCLIPFeatures +from .import_original_features import ImportOriginalLangSplatV2Features +from .multi_scale_sam1_masks import ComputeMultiScaleSAM1Masks from .multi_scale_sam_masks import ComputeMultiScaleSAM2Masks __all__ = [ "ComputeCLIPFeatures", + "ComputeMultiScaleSAM1Masks", "ComputeMultiScaleSAM2Masks", + "ImportOriginalLangSplatV2Features", ] diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py index 2ce3372..5ef25d9 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py @@ -5,6 +5,9 @@ This transform computes CLIP features for masked image regions generated by the multi-scale SAM transform, following the LangSplatV2 approach. + +Crop extraction, masking, padding, and resize are performed on the GPU in +float32 to avoid uint8 quantisation artefacts that occur with ``cv2.resize`` on small masks. """ import logging from typing import Any @@ -12,57 +15,12 @@ import cv2 import numpy as np import torch +import torch.nn.functional as F import tqdm from fvdb_reality_capture.sfm_scene import SfmCache, SfmScene from fvdb_reality_capture.transforms import BaseTransform, transform -def get_seg_img(mask: np.ndarray, image: np.ndarray, bbox: np.ndarray) -> np.ndarray: - """ - Extract and crop the segmented region from an image. - - Args: - mask: Binary segmentation mask, shape [H, W]. - image: Source image, shape [H, W, 3]. - bbox: Bounding box in XYWH format. - - Returns: - Cropped and masked image region. 
- """ - x, y, w, h = map(int, bbox) - - # Crop first (view, no copy), then copy only the small region - cropped_img = image[y : y + h, x : x + w].copy() - cropped_mask = mask[y : y + h, x : x + w] - - # Apply mask only to the cropped region - cropped_img[cropped_mask == 0] = 0 - - return cropped_img - - -def pad_img(img: np.ndarray) -> np.ndarray: - """ - Pad an image to make it square. - - Args: - img: Input image, shape [H, W, 3]. - - Returns: - Square padded image. - """ - h, w = img.shape[:2] - l = max(w, h) - pad = np.zeros((l, l, 3), dtype=np.uint8) - - if h > w: - pad[:, (h - w) // 2 : (h - w) // 2 + w, :] = img - else: - pad[(w - h) // 2 : (w - h) // 2 + h, :, :] = img - - return pad - - @transform class ComputeCLIPFeatures(BaseTransform): """ @@ -76,7 +34,7 @@ class ComputeCLIPFeatures(BaseTransform): This transform must be run after ComputeMultiScaleSAM2Masks. """ - version = "1.0.0" + version = "1.1.0" def __init__( self, @@ -118,16 +76,21 @@ def _get_clip_model(self): def _encode_masked_regions( self, - image: np.ndarray, + image_gpu: torch.Tensor, masks: np.ndarray, bboxes: np.ndarray, ) -> torch.Tensor: """ Encode masked image regions using CLIP. + Performs crop-mask-pad-resize entirely on GPU in float32 (matching + the original LangSplatV2 ``mask2segmap`` + ``_embed_clip_sam_tiles``), + then normalises and encodes through CLIP in one batch. + Args: - image: Source image in RGB format, shape [H, W, 3]. - masks: Binary masks, shape [N, H, W]. + image_gpu: Source image on CUDA as float32, shape [H, W, 3], + values in [0, 255]. + masks: Binary masks, shape [N, H, W] (uint8 0/1). bboxes: Bounding boxes in XYWH format, shape [N, 4]. 
Returns: @@ -138,24 +101,40 @@ def _encode_masked_regions( clip_model = self._get_clip_model() image_size = clip_model.image_size - - # Extract each masked region, pad to square, and resize to model's expected size - seg_imgs = [] - for mask, bbox in zip(masks, bboxes): - seg_img = get_seg_img(mask, image, bbox) - pad_seg_img = pad_img(seg_img) - resized_img = cv2.resize(pad_seg_img, (image_size, image_size)) - seg_imgs.append(resized_img) - - # Stack and convert to tensor with values in [0, 1] - seg_imgs_np = np.stack(seg_imgs, axis=0) # [N, image_size, image_size, 3] - seg_imgs_tensor = ( - torch.from_numpy(seg_imgs_np.astype("float32")).permute(0, 3, 1, 2) / 255.0 - ) # [N, 3, image_size, image_size] - - # Apply model's tensor preprocessing (handles normalization) and encode + n = len(masks) + device = image_gpu.device + + all_segs = torch.from_numpy(masks).to(device) + seg_imgs = torch.empty(n, 3, image_size, image_size, device=device) + + for i in range(n): + x, y, w, h = int(bboxes[i][0]), int(bboxes[i][1]), int(bboxes[i][2]), int(bboxes[i][3]) + w = max(w, 1) + h = max(h, 1) + + crop = image_gpu[y : y + h, x : x + w].clone() + crop[~all_segs[i, y : y + h, x : x + w].bool()] = 0 + crop = crop.permute(2, 0, 1) # HWC -> CHW + + side = max(h, w) + padded = torch.zeros(3, side, side, device=device) + if h > w: + offset = (h - w) // 2 + padded[:, :, offset : offset + w] = crop + else: + offset = (w - h) // 2 + padded[:, offset : offset + h, :] = crop + + seg_imgs[i] = F.interpolate( + padded.unsqueeze(0), size=(image_size, image_size), + mode="bilinear", align_corners=False, + ).squeeze(0) + + seg_imgs /= 255.0 + + # Normalise and encode with torch.no_grad(): - seg_imgs_preprocessed = clip_model.preprocess_tensor(seg_imgs_tensor) + seg_imgs_preprocessed = clip_model.preprocess_tensor(seg_imgs) clip_embeds = clip_model.encode_image(seg_imgs_preprocessed) clip_embeds = clip_embeds / clip_embeds.norm(dim=-1, keepdim=True) @@ -226,7 +205,8 @@ def __call__(self, 
input_scene: SfmScene) -> SfmScene: # Create cache folder model_type_safe = self._clip_model_type.replace("-", "_") pretrained_safe = self._clip_model_pretrained.replace("-", "_") - cache_prefix = f"clip_features_{model_type_safe}_{pretrained_safe}_{self._clip_n_dims}" + version_safe = self.version.replace(".", "_") + cache_prefix = f"clip_features_{model_type_safe}_{pretrained_safe}_{self._clip_n_dims}_v{version_safe}" output_cache = input_cache.make_folder( cache_prefix, description=f"CLIP features using {self._clip_model_type}", @@ -286,16 +266,17 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) h, w = img.shape[:2] - # Load SAM2 masks from parent cache + img_gpu = torch.from_numpy(img_rgb).to("cuda", dtype=torch.float32) + mask_filename = f"masks_{image_meta.image_id:0{num_zeropad}}" if not input_cache.has_file(mask_filename): raise RuntimeError( - f"Mask file {mask_filename} not found in cache. " "Run ComputeMultiScaleSAM2Masks first." + f"Mask file {mask_filename} not found in cache. " + "Run ComputeMultiScaleSAM2Masks or ComputeMultiScaleSAM1Masks first." 
) _, mask_data = input_cache.read_file(mask_filename) - # Encode all masked regions across scales all_features = [] scale_names = ["default", "s", "m", "l"] @@ -304,7 +285,7 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: bboxes = mask_data.get(f"{scale_name}_bboxes", np.zeros((0, 4))) if len(masks) > 0: - scale_features = self._encode_masked_regions(img_rgb, masks, bboxes) + scale_features = self._encode_masked_regions(img_gpu, masks, bboxes) all_features.append(scale_features) # Concatenate features from all scales @@ -320,6 +301,20 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: total_masks = sum(lengths) assert features.shape[0] == total_masks, f"Feature count mismatch: {features.shape[0]} vs {total_masks}" + if features.shape[0] > 0 and torch.isnan(features).any(): + self._logger.warning( + f"Image {image_meta.image_id}: CLIP features contain NaN " + f"({torch.isnan(features).sum().item()} / {features.numel()} values)" + ) + + # Report per-scale mask coverage for the first few images + if image_meta.image_id < 3: + coverage = {sn: int((seg_maps[i] >= 0).sum()) for i, sn in enumerate(scale_names)} + self._logger.info( + f"Image {image_meta.image_id}: {total_masks} masks, " + f"lengths={lengths}, pixel coverage={coverage}" + ) + # Save cache_filename = f"features_{image_meta.image_id:0{num_zeropad}}" output_cache.write_file( diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/import_original_features.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/import_original_features.py new file mode 100644 index 0000000..059a13d --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/import_original_features.py @@ -0,0 +1,252 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Import pre-computed features from the original LangSplatV2 repository. 
def _compute_lengths_from_seg_maps(
    seg_maps: np.ndarray,
    total_features: int,
) -> list[int]:
    """Derive per-level feature counts from the original's segmentation maps.

    The original concatenates features in order (default, s, m, l) and stores
    globally-offset feature indices in each seg_map channel, with ``-1``
    marking "no mask".  The largest index found in a channel is taken to be
    ``cumulative_length[level] - 1``, so the cumulative length of a level is
    ``max(channel) + 1``.  A channel with no valid index contributes zero
    features and inherits the previous cumulative length.

    Only the per-channel maximum is inspected, so masks whose index never
    appears in the map (zero pixels assigned) do not disturb the result as
    long as the maximum itself is present.

    Args:
        seg_maps: Array of shape ``[4, H, W]`` holding integer feature
            indices (``-1`` = none).  Extra channels beyond the first four
            are ignored.
        total_features: Total number of rows in the matching ``_f.npy``
            file.  Kept for interface compatibility; it is *not* consulted
            here -- the caller compares it against ``sum(lengths)``.

    Returns:
        List of 4 integers giving the number of features at each scale level.

    Raises:
        ValueError: If ``seg_maps`` has fewer than four channels.
    """
    if seg_maps.ndim == 0 or seg_maps.shape[0] < 4:
        raise ValueError(
            "seg_maps must have at least 4 channels (default, s, m, l); "
            f"got shape {seg_maps.shape}"
        )

    cumulative: list[int] = []
    running = 0
    for level in range(4):
        channel = seg_maps[level]
        # Invalid pixels are negative, so the channel holds a valid index
        # exactly when its maximum is >= 0.  Taking the max directly avoids
        # materialising a boolean-filtered copy of the whole H x W channel.
        if channel.size > 0:
            top_index = int(channel.max())
            if top_index >= 0:
                running = top_index + 1
        cumulative.append(running)

    return [cumulative[0]] + [
        later - earlier for earlier, later in zip(cumulative, cumulative[1:])
    ]
@transform
class ImportOriginalLangSplatV2Features(BaseTransform):
    """Import pre-computed language features from the original LangSplatV2.

    Reads ``_f.npy`` / ``_s.npy`` file pairs produced by the original
    ``preprocess.py`` and writes them into the SfmCache in the same dict
    format that ``ComputeCLIPFeatures`` produces:

    .. code-block:: python

        {"features": Tensor, "seg_maps": Tensor, "lengths": Tensor}

    This allows the training pipeline to consume original features without
    any changes to the dataset or trainer code.
    """

    version = "1.0.2"

    def __init__(
        self,
        original_features_dir: Path | str,
        clip_n_dims: int = 512,
    ):
        """
        Args:
            original_features_dir: Path to the original ``language_features/``
                directory containing ``*_f.npy`` and ``*_s.npy`` files.
            clip_n_dims: Expected CLIP embedding dimensionality (used for
                cache naming and validation).
        """
        self._features_dir = Path(original_features_dir)
        self._clip_n_dims = clip_n_dims
        self._logger = logging.getLogger(
            f"{self.__class__.__module__}.{self.__class__.__name__}"
        )

    @staticmethod
    def _cache_filename(image_id: int, num_zeropad: int) -> str:
        """Return the cache entry name used for ``image_id``."""
        return f"features_{image_id:0{num_zeropad}}"

    def _prepare_cache(
        self,
        output_cache: SfmCache,
        input_scene: SfmScene,
        num_zeropad: int,
    ) -> bool:
        """Decide whether the cache must be rebuilt, clearing stale content.

        Returns:
            True if features must be (re)imported, False if every expected
            entry is already present.
        """
        if output_cache.num_files != input_scene.num_images:
            if output_cache.num_files != 0:
                self._logger.info(
                    f"Cache has {output_cache.num_files} files but expected "
                    f"{input_scene.num_images}. Regenerating cache."
                )
                output_cache.clear_current_folder()
            return True

        # Look up exactly the names the import loop writes (keyed on
        # image_meta.image_id), not positional indices, so the check and the
        # writer cannot disagree when ids are not 0..N-1.
        for image_meta in input_scene.images:
            cache_filename = self._cache_filename(image_meta.image_id, num_zeropad)
            if not output_cache.has_file(cache_filename):
                self._logger.info(
                    f"{cache_filename} missing from cache. Regenerating."
                )
                output_cache.clear_current_folder()
                return True
        return False

    def _import_image(self, image_meta, output_cache: SfmCache, num_zeropad: int) -> None:
        """Load one ``_f.npy`` / ``_s.npy`` pair and write it to the cache.

        Raises:
            FileNotFoundError: If either source file is missing.
            ValueError: If the feature array is not 2-D or its width differs
                from ``clip_n_dims``.
        """
        image_name = Path(image_meta.image_path).stem
        f_path = self._features_dir / f"{image_name}_f.npy"
        s_path = self._features_dir / f"{image_name}_s.npy"

        if not f_path.exists():
            raise FileNotFoundError(
                f"Feature file not found: {f_path}. "
                f"Run the original LangSplatV2 preprocess.py first."
            )
        if not s_path.exists():
            raise FileNotFoundError(
                f"Segmentation map not found: {s_path}. "
                f"Run the original LangSplatV2 preprocess.py first."
            )

        features_np = np.load(str(f_path))  # expected [N_total, clip_n_dims]
        seg_maps_np = np.load(str(s_path))  # expected [4, H, W]

        # Guard the shape[1] access below: a 1-D (e.g. empty) array would
        # otherwise surface as an unexplained IndexError.
        if features_np.ndim != 2:
            raise ValueError(
                f"Feature array for {image_name} must be 2-D, "
                f"got shape {features_np.shape}"
            )
        if features_np.shape[1] != self._clip_n_dims:
            raise ValueError(
                f"Feature dimension mismatch for {image_name}: "
                f"expected {self._clip_n_dims}, got {features_np.shape[1]}"
            )

        lengths = _compute_lengths_from_seg_maps(seg_maps_np, features_np.shape[0])

        # A mismatch is reported but not fatal: the pair is still imported.
        expected_total = sum(lengths)
        if features_np.shape[0] != expected_total:
            self._logger.warning(
                f"{image_name}: feature count ({features_np.shape[0]}) "
                f"!= sum(lengths) ({expected_total}). "
                f"Possible corrupt feature/seg_map pair."
            )

        output_cache.write_file(
            name=self._cache_filename(image_meta.image_id, num_zeropad),
            data={
                "features": torch.from_numpy(features_np).half(),
                "seg_maps": torch.from_numpy(seg_maps_np).int(),
                "lengths": torch.tensor(lengths, dtype=torch.int32),
            },
            data_type="pt",
            metadata={
                "source": "original_langsplatv2",
                "clip_n_dims": self._clip_n_dims,
                "original_features_dir": str(self._features_dir),
            },
        )

    def __call__(self, input_scene: SfmScene) -> SfmScene:
        """Import the original features for every image in ``input_scene``.

        Args:
            input_scene: Scene whose images name the ``*_f.npy`` / ``*_s.npy``
                pairs to import (matched by image file stem).

        Returns:
            A scene identical to the input except that its cache points at
            the folder holding the imported features.  An empty scene is
            returned unchanged.
        """
        if len(input_scene.images) == 0:
            self._logger.warning("No images in SfmScene. Returning unchanged.")
            return input_scene

        input_cache: SfmCache = input_scene.cache

        version_safe = self.version.replace(".", "_")
        cache_prefix = (
            f"imported_original_langsplatv2_features"
            f"_{self._clip_n_dims}_v{version_safe}"
        )
        output_cache = input_cache.make_folder(
            cache_prefix,
            description="Imported features from original LangSplatV2",
        )

        num_zeropad = len(str(input_scene.num_images)) + 2

        if self._prepare_cache(output_cache, input_scene, num_zeropad):
            self._logger.info(
                f"Importing original features from {self._features_dir}"
            )
            # Context manager closes the progress bar even if an import fails.
            with tqdm.tqdm(
                input_scene.images,
                unit="imgs",
                desc="Importing original features",
            ) as pbar:
                for image_meta in pbar:
                    self._import_image(image_meta, output_cache, num_zeropad)
            self._logger.info(
                f"Imported features for {input_scene.num_images} images."
            )
        else:
            self._logger.info("Original features already cached.")

        return SfmScene(
            cameras=input_scene.cameras,
            images=input_scene.images,
            points=input_scene.points,
            points_err=input_scene.points_err,
            points_rgb=input_scene.points_rgb,
            scene_bbox=input_scene.scene_bbox,
            transformation_matrix=input_scene.transformation_matrix,
            cache=output_cache,
        )

    @staticmethod
    def name() -> str:
        """Return the registry name of this transform."""
        return "ImportOriginalLangSplatV2Features"

    def state_dict(self) -> dict[str, Any]:
        """Serialise the constructor arguments together with name and version."""
        return {
            "name": self.name(),
            "version": self.version,
            "original_features_dir": str(self._features_dir),
            "clip_n_dims": self._clip_n_dims,
        }

    @staticmethod
    def from_state_dict(
        state_dict: dict[str, Any],
    ) -> "ImportOriginalLangSplatV2Features":
        """Rebuild the transform from :meth:`state_dict` output.

        Raises:
            ValueError: If ``state_dict["name"]`` names a different transform.
        """
        if state_dict["name"] != "ImportOriginalLangSplatV2Features":
            raise ValueError(
                f"Expected 'ImportOriginalLangSplatV2Features', "
                f"got {state_dict['name']}"
            )
        return ImportOriginalLangSplatV2Features(
            original_features_dir=state_dict["original_features_dir"],
            clip_n_dims=state_dict["clip_n_dims"],
        )
def remove_small_regions(
    mask: np.ndarray,
    area_thresh: float,
    mode: str,
) -> Tuple[np.ndarray, bool]:
    """Remove small disconnected regions or holes from a binary mask.

    OpenCV implementation following
    ``segment_anything.utils.amg.remove_small_regions``.

    Args:
        mask: Boolean mask, shape ``[H, W]``.
        area_thresh: Minimum area; components smaller than this are removed.
        mode: ``"holes"`` to fill small holes, ``"islands"`` to remove small islands.

    Returns:
        ``(cleaned_mask, changed)`` where *changed* is True if the mask was modified.
    """
    assert mode in ("holes", "islands")
    correct_holes = mode == "holes"
    # In "holes" mode the mask is inverted so that holes become the connected
    # components; in "islands" mode the components are the mask regions.
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1]  # last stats column is the component area
    # Label 0 is the background of ``working_mask`` and is never a candidate.
    small_regions = [label for label in range(1, n_labels) if sizes[label] < area_thresh]
    if len(small_regions) == 0:
        return mask, False

    if correct_holes:
        # Label 0 is the original foreground here: keep it and add the small
        # holes, which fills them.
        fill_labels = [0] + small_regions
    else:
        # Keep only the components that are NOT small.  Label 0 must stay
        # excluded: it is the true background, and keeping it would turn every
        # background pixel into foreground.
        small_set = set(small_regions)
        fill_labels = [label for label in range(1, n_labels) if label not in small_set]
        # If every region is below the threshold, keep the largest one rather
        # than returning an empty mask (same fallback as SAM).
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes[1:])) + 1]

    mask = np.isin(regions, fill_labels)
    return mask, True
def _clean_single_mask(
    seg_raw: np.ndarray,
    min_area: int,
) -> Tuple[np.ndarray, float, List[float]]:
    """Clean a single mask and derive its bounding box (CPU-only work).

    Small holes are filled first, then small islands are removed, both via
    ``remove_small_regions`` with ``min_area`` as the threshold.

    Returns:
        ``(cleaned_seg, score, box_xyxy)``.  ``score`` is 1.0 when neither
        pass altered the mask and 0.0 otherwise.  ``box_xyxy`` holds the
        min/max column and row of the remaining pixels, or four zeros when
        nothing is left.
    """
    cleaned = seg_raw.astype(bool)
    untouched = True
    for cleanup_mode in ("holes", "islands"):
        cleaned, was_modified = remove_small_regions(cleaned, min_area, mode=cleanup_mode)
        untouched = untouched and not was_modified

    rows, cols = np.where(cleaned)
    if len(cols) == 0:
        return cleaned, float(untouched), [0.0, 0.0, 0.0, 0.0]

    bounds = [
        float(cols.min()),
        float(rows.min()),
        float(cols.max()),
        float(rows.max()),
    ]
    return cleaned, float(untouched), bounds
def postprocess_small_regions(
    masks: List[Dict[str, Any]],
    min_area: int,
    nms_thresh: float,
) -> List[Dict[str, Any]]:
    """Clean every mask's segmentation, then deduplicate with box NMS.

    Each mask is passed through ``_clean_single_mask`` (hole filling and
    island removal) on a thread pool.  The cleaned boxes are then fed to
    ``batched_nms`` with a score of 1.0 for masks that needed no cleaning and
    0.0 for those that did, so untouched masks win ties.

    Args:
        masks: Mask annotation dicts carrying at least ``segmentation``
            (HxW array) and ``bbox`` (``[x, y, w, h]``).
        min_area: Area threshold handed to the cleaning step.
        nms_thresh: Box IoU threshold for the NMS pass.

    Returns:
        Shallow copies of the surviving dicts with ``segmentation`` (uint8),
        ``bbox`` and ``area`` replaced by the cleaned values.  An empty input
        is returned as-is.
    """
    if len(masks) == 0:
        return masks

    def _clean(record: Dict[str, Any]):
        return _clean_single_mask(record["segmentation"], min_area)

    worker_count = min(len(masks), os.cpu_count() or 4)
    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        cleaned = list(executor.map(_clean, masks))

    segmentations = [entry[0] for entry in cleaned]
    unchanged_flags = [entry[1] for entry in cleaned]
    xyxy_boxes = [entry[2] for entry in cleaned]

    keep = batched_nms(
        torch.tensor(xyxy_boxes, dtype=torch.float32),
        torch.tensor(unchanged_flags, dtype=torch.float32),
        torch.zeros(len(masks), dtype=torch.long),  # single category
        iou_threshold=nms_thresh,
    )

    survivors = []
    for kept_index in keep.tolist():
        left, top, right, bottom = xyxy_boxes[kept_index]
        record = masks[kept_index].copy()
        record["segmentation"] = segmentations[kept_index].astype(np.uint8)
        record["bbox"] = [left, top, right - left, bottom - top]
        record["area"] = int(segmentations[kept_index].sum())
        survivors.append(record)

    return survivors
_mask_nms_logger = logging.getLogger(__name__ + ".mask_nms")
_mask_nms_diag_count = 0


def mask_nms(
    masks: torch.Tensor,
    scores: torch.Tensor,
    iou_thr: float = 0.7,
    score_thr: float = 0.1,
    inner_thr: float = 0.2,
    **kwargs,
) -> torch.Tensor:
    """
    Perform mask non-maximum suppression (NMS) on a set of masks.

    Reimplementation of the ``mask_nms`` from the original LangSplatV2
    ``preprocess.py``.  Masks are ranked by score; a mask is kept when its
    IoU with every higher-ranked mask is at most ``iou_thr``, its score
    exceeds ``score_thr``, and it passes both inner-containment checks.  If
    one of the score/containment filters would reject *every* mask, the (up
    to) three best-scoring masks are re-admitted for that filter.

    Args:
        masks: Binary masks, shape ``[num_masks, H, W]``.
        scores: Mask scores, shape ``[num_masks]``.
        iou_thr: IoU threshold for NMS.
        score_thr: Minimum score threshold.
        inner_thr: Inner overlap threshold for removing contained masks.
        **kwargs: Accepted and ignored.

    Returns:
        Indices (into the input ``masks``) of the selected masks, ordered by
        descending score.  Empty input yields an empty ``long`` tensor.
    """
    global _mask_nms_diag_count

    if len(masks) == 0:
        return torch.tensor([], dtype=torch.long)

    # Only the first few non-empty calls emit diagnostics.
    log_diag = _mask_nms_diag_count < 8
    if log_diag:
        _mask_nms_diag_count += 1

    scores, idx = scores.sort(0, descending=True)
    num_masks = idx.shape[0]

    masks_ord = masks[idx.view(-1), :]
    masks_area = torch.sum(masks_ord, dim=(1, 2), dtype=torch.float)

    # Pairwise intersections via one matrix product over flattened masks.
    masks_flat = masks_ord.reshape(num_masks, -1).float()
    intersection = masks_flat @ masks_flat.T
    union = masks_area[:, None] + masks_area[None, :] - intersection
    iou_matrix = intersection / union

    # R[i, j] = fraction of mask i covered by mask j.
    R = intersection / masks_area[:, None]
    inner_val = 1 - R * R.T
    cond = (R < 0.5) & (R.T >= 0.85)
    inner_iou_matrix = torch.where(cond, inner_val, torch.zeros_like(inner_val))

    # Column j of the strict upper triangle compares mask j only with
    # higher-scoring masks.
    iou_matrix.triu_(diagonal=1)
    iou_max, _ = iou_matrix.max(dim=0)
    inner_iou_matrix_u = torch.triu(inner_iou_matrix, diagonal=1)
    inner_iou_max_u, _ = inner_iou_matrix_u.max(dim=0)
    # NOTE(review): diagonal=1 also lets the first super-diagonal through; a
    # strict lower triangle would be diagonal=-1.  Left unchanged because this
    # function is meant to reproduce the upstream selections -- confirm.
    inner_iou_matrix_l = torch.tril(inner_iou_matrix, diagonal=1)
    inner_iou_max_l, _ = inner_iou_matrix_l.max(dim=0)

    keep = iou_max <= iou_thr
    keep_conf = scores > score_thr
    keep_inner_u = inner_iou_max_u <= 1 - inner_thr
    keep_inner_l = inner_iou_max_l <= 1 - inner_thr

    if log_diag:
        _mask_nms_logger.info(
            "[mask_nms diag] input=%d masks scores: min=%.4f max=%.4f "
            "areas: min=%.0f max=%.0f iou_max: min=%.4f max=%.4f mean=%.4f "
            "pass_iou=%d pass_conf=%d pass_inner_u=%d pass_inner_l=%d",
            num_masks,
            scores.min().item(), scores.max().item(),
            masks_area.min().item(), masks_area.max().item(),
            iou_max.min().item(), iou_max.max().item(), iou_max.mean().item(),
            keep.sum().item(), keep_conf.sum().item(),
            keep_inner_u.sum().item(), keep_inner_l.sum().item(),
        )

    # Fallback: a filter that rejects everything re-admits the top-3 scores.
    # The flag tensors are 1-D, so they take a single index; the previous
    # ``flags[index, 0]`` form raised IndexError whenever a fallback fired.
    for flags in (keep_conf, keep_inner_u, keep_inner_l):
        if not flags.any():
            flags[scores.topk(min(3, num_masks)).indices] = True

    keep = keep & keep_conf & keep_inner_u & keep_inner_l

    if log_diag:
        _mask_nms_logger.info(
            "[mask_nms diag] final keep=%d / %d", keep.sum().item(), num_masks,
        )

    selected_idx = idx[keep]
    return selected_idx
_masks_update_logger = logging.getLogger(__name__ + ".masks_update")


def masks_update(
    *mask_lists,
    iou_thr: float = 0.8,
    score_thr: float = 0.7,
    inner_thr: float = 0.5,
    max_area_frac: float = 0.95,
) -> tuple:
    """
    Run mask NMS independently over each supplied list of masks.

    For every list, masks covering more than ``max_area_frac`` of the image
    are discarded first (skipped when ``max_area_frac >= 1.0``); the rest are
    scored as ``stability_score * predicted_iou`` and filtered by
    ``mask_nms``.  Tensors are moved to CUDA when it is available.

    Args:
        *mask_lists: Any number of lists of mask dicts; each dict provides
            ``segmentation``, ``predicted_iou`` and ``stability_score``.
        iou_thr: IoU threshold forwarded to ``mask_nms``.
        score_thr: Score threshold forwarded to ``mask_nms``.
        inner_thr: Inner overlap threshold forwarded to ``mask_nms``.
        max_area_frac: Largest allowed mask area as a fraction of the image.

    Returns:
        Tuple with one filtered list per input list, in the same order, each
        preserving the input order of its surviving masks.
    """
    filtered_levels = []

    for level_masks in mask_lists:
        if len(level_masks) == 0:
            filtered_levels.append([])
            continue

        if max_area_frac < 1.0:
            height, width = level_masks[0]["segmentation"].shape[:2]
            area_limit = height * width * max_area_frac
            small_enough = [
                record for record in level_masks
                if record["segmentation"].sum() <= area_limit
            ]
            n_dropped = len(level_masks) - len(small_enough)
            level_masks = small_enough
            if n_dropped > 0:
                _masks_update_logger.info(
                    "[masks_update] dropped %d masks covering >%.0f%% of image (%d remain)",
                    n_dropped, max_area_frac * 100, len(level_masks),
                )
            if len(level_masks) == 0:
                filtered_levels.append([])
                continue

        def _stacked(key):
            # Stack one field across the level's masks into a tensor.
            return torch.from_numpy(np.stack([record[key] for record in level_masks], axis=0))

        seg_pred = _stacked("segmentation")
        iou_pred = _stacked("predicted_iou")
        stability = _stacked("stability_score")

        if torch.cuda.is_available():
            seg_pred, iou_pred, stability = seg_pred.cuda(), iou_pred.cuda(), stability.cuda()

        kept = mask_nms(
            seg_pred,
            stability * iou_pred,
            iou_thr=iou_thr,
            score_thr=score_thr,
            inner_thr=inner_thr,
        )

        kept_positions = set(kept.int().cpu().numpy().tolist())
        filtered_levels.append(
            [record for position, record in enumerate(level_masks) if position in kept_positions]
        )

    return tuple(filtered_levels)
def cross_crop_nms(
    mask_list: List[Dict[str, Any]],
    iou_threshold: float = 0.7,
) -> List[Dict[str, Any]]:
    """Deduplicate masks gathered from several crops with box NMS.

    Each mask is scored by the reciprocal of its crop's area, so when two
    boxes overlap beyond ``iou_threshold`` the one from the smaller crop wins.

    Args:
        mask_list: Mask records with ``"bbox"`` (xywh) and ``"crop_box"``
            (its third and fourth entries are multiplied to get the crop area).
        iou_threshold: Box IoU threshold for NMS.

    Returns:
        The surviving records in the order returned by ``batched_nms``.
        Lists with zero or one record are returned untouched.
    """
    if len(mask_list) <= 1:
        return mask_list

    # xywh -> xyxy in float32: add the origin to the extent in place.
    boxes = np.array([record["bbox"] for record in mask_list], dtype=np.float32)
    corners = boxes.copy()
    corners[:, 2] = boxes[:, 0] + boxes[:, 2]
    corners[:, 3] = boxes[:, 1] + boxes[:, 3]

    crop_areas = np.array(
        [record["crop_box"][2] * record["crop_box"][3] for record in mask_list],
        dtype=np.float32,
    )
    # Smaller crop -> larger score; epsilon guards against a zero-area crop.
    preference = torch.from_numpy(1.0 / (crop_areas + 1e-6))

    survivors = batched_nms(
        torch.from_numpy(corners).float(),
        preference,
        torch.zeros(len(mask_list), dtype=torch.long),  # single category
        iou_threshold=iou_threshold,
    )
    return [mask_list[position] for position in survivors.tolist()]
def _sam1_amg():
    """Import and return ``segment_anything.utils.amg`` on demand.

    Deferring the import keeps this module importable when the
    ``segment-anything`` package is absent; the dependency is only needed
    once this function is called.

    Raises:
        ImportError: With installation instructions, if the package cannot
            be imported.  The original import error is suppressed.
    """
    install_hint = (
        "SAM1 transform requires the segment-anything package. Install with:\n"
        " conda install -c conda-forge segment-anything\n"
        "Or update your environment:\n"
        " conda env update -f open_vocabulary_segmentation/langsplatv2/environment.yml"
    )
    try:
        import segment_anything.utils.amg as amg_module
    except ImportError:
        raise ImportError(install_hint) from None
    else:
        return amg_module
+ """ + + version = "1.5.0" + + def __init__( + self, + checkpoint: Literal["vit_h", "vit_l", "vit_b"] = "vit_h", + points_per_side: int = 32, + points_per_batch: int = 256, + pred_iou_thresh: float = 0.7, + stability_score_thresh: float = 0.85, + crop_n_layers: int = 1, + crop_n_points_downscale_factor: int = 1, + min_mask_region_area: int = 100, + box_nms_thresh: float = 0.7, + nms_iou_thr: float = 0.8, + nms_score_thr: float = 0.7, + nms_inner_thr: float = 0.5, + device: torch.device | str = "cuda", + ): + """ + Create a multi-scale SAM1 mask generation transform. + + Args: + checkpoint: SAM1 model variant (vit_h, vit_l, vit_b). + points_per_side: Grid density for point prompts. + points_per_batch: Points processed simultaneously. + pred_iou_thresh: Predicted IoU threshold. + stability_score_thresh: Stability score threshold. + crop_n_layers: Number of crop layers (1 = also run on crops, + matching the original LangSplatV2). + crop_n_points_downscale_factor: Point grid downscale per crop layer. + min_mask_region_area: Minimum mask region area for post-processing. + box_nms_thresh: Box NMS IoU threshold within each crop. + nms_iou_thr: IoU threshold for mask NMS post-processing. + nms_score_thr: Score threshold for mask NMS. + nms_inner_thr: Inner overlap threshold for mask NMS. + device: Device to run SAM1 on. 
def _get_sam1_model(self) -> SAM1Model:
    """Return the SAM1 model wrapper, constructing it on first use.

    The instance is built from the parameters stored on this transform and
    kept in ``self._sam1_model`` so that later calls reuse it.

    Returns:
        The cached ``SAM1Model`` created with ``output_mode="multi_scale"``.
    """
    if self._sam1_model is None:
        self._sam1_model = SAM1Model(
            checkpoint=self._checkpoint,
            points_per_side=self._points_per_side,
            points_per_batch=self._points_per_batch,
            pred_iou_thresh=self._pred_iou_thresh,
            stability_score_thresh=self._stability_score_thresh,
            min_mask_region_area=self._min_mask_region_area,
            box_nms_thresh=self._box_nms_thresh,
            output_mode="multi_scale",
            device=self._device,
        )
    return self._sam1_model


def _generate_multi_scale_masks(self, image: np.ndarray) -> dict:
    """Generate masks at multiple scales.

    Uses multi-crop generation and cross-crop NMS, then removes small
    regions (when ``min_mask_region_area > 0``) and finally applies mask NMS
    per scale.

    Args:
        image: Input image in BGR format (OpenCV default), shape ``[H, W, 3]``.

    Returns:
        Dictionary with mask lists keyed by ``"default"``, ``"s"``,
        ``"m"``, ``"l"``.
    """
    scale_names = ("default", "s", "m", "l")
    # NOTE(review): 512 / 1500 looks like SAM's default crop overlap ratio --
    # confirm against the automatic mask generator before changing it.
    crop_overlap_ratio = 512 / 1500

    amg = _sam1_amg()
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    orig_size = image_rgb.shape[:2]
    sam1 = self._get_sam1_model()

    crop_boxes, layer_idxs = amg.generate_crop_boxes(
        orig_size, self._crop_n_layers, crop_overlap_ratio
    )
    point_grids = amg.build_all_layer_point_grids(
        self._points_per_side,
        self._crop_n_layers,
        self._crop_n_points_downscale_factor,
    )

    # One running list of mask records per scale, accumulated over all crops.
    masks: Dict[str, List[Dict[str, Any]]] = {scale: [] for scale in scale_names}

    for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
        x0, y0, x1, y1 = crop_box
        cropped_im = image_rgb[y0:y1, x0:x1, :]
        cropped_h, cropped_w = cropped_im.shape[:2]
        # Scale the layer's point grid by the crop's (width, height).
        points_scale = np.array([cropped_w, cropped_h], dtype=np.float64)
        point_coords = point_grids[layer_idx] * points_scale

        crop_masks = sam1.predict_masks_multi_scale(
            cropped_im,
            point_coords=point_coords,
            crop_box=crop_box,
            orig_size=orig_size,
        )
        # strict=True keeps the original guarantee that exactly one result
        # per scale is returned (a mismatch raises ValueError).
        for scale, found in zip(scale_names, crop_masks, strict=True):
            masks[scale].extend(found)

    # Emit per-stage mask counts for the first two images only, so the
    # diagnostics do not flood the log on large scenes.
    diag_count = getattr(self, "_diag_count", 0)
    log_diag = diag_count < 2
    if log_diag:
        self._diag_count = diag_count + 1

    def log_counts(stage: str) -> None:
        if log_diag:
            self._logger.info(
                "[diag] after %s: default=%d, s=%d, m=%d, l=%d",
                stage,
                *(len(masks[scale]) for scale in scale_names),
            )

    log_counts("predict")

    # Masks from different crops can duplicate each other; with a single
    # crop there is nothing to deduplicate.
    if len(crop_boxes) > 1:
        for scale in scale_names:
            masks[scale] = cross_crop_nms(masks[scale], iou_threshold=self._box_nms_thresh)
    log_counts("cross-crop NMS")

    if self._min_mask_region_area > 0:
        for scale in scale_names:
            masks[scale] = postprocess_small_regions(
                masks[scale], self._min_mask_region_area, self._box_nms_thresh
            )
    log_counts("postprocess_small_regions")

    filtered = masks_update(
        *(masks[scale] for scale in scale_names),
        iou_thr=self._nms_iou_thr,
        score_thr=self._nms_score_thr,
        inner_thr=self._nms_inner_thr,
    )
    masks = dict(zip(scale_names, filtered, strict=True))
    log_counts("masks_update")

    return masks


def __call__(self, input_scene: SfmScene) -> SfmScene:
    """Generate multi-scale SAM1 masks for all images in the scene.

    Masks are written to a cache folder whose name encodes the main
    parameters and the transform version. An existing cache is reused only
    if it holds one file per image and every file's stored metadata matches
    the current parameters; otherwise it is cleared and regenerated.

    Args:
        input_scene: Input scene containing images.

    Returns:
        Scene with cache containing multi-scale mask data. The input scene
        itself is returned unchanged when it has no images.

    Raises:
        RuntimeError: If an image cannot be read from disk.
    """
    if len(input_scene.images) == 0:
        self._logger.warning("No images found in the SfmScene. Returning unchanged.")
        return input_scene

    input_cache: SfmCache = input_scene.cache

    version_safe = self.version.replace(".", "_")
    cache_prefix = (
        f"sam1_multi_scale_masks_{self._checkpoint}_"
        f"p{self._points_per_side}_"
        f"iou{int(self._pred_iou_thresh * 100)}_"
        f"stab{int(self._stability_score_thresh * 100)}_"
        f"crop{self._crop_n_layers}_"
        f"nmsiou{int(self._nms_iou_thr * 100)}_"
        f"nmsscore{int(self._nms_score_thr * 100)}_"
        f"nmsinner{int(self._nms_inner_thr * 100)}_"
        f"v{version_safe}"
    )
    output_cache = input_cache.make_folder(
        cache_prefix,
        description=f"Multi-scale SAM1 masks with {self._checkpoint} checkpoint",
    )

    num_zeropad = len(str(len(input_scene.images))) + 2

    # Parameters every cached file must have been generated with.
    required_meta = {
        "checkpoint": self._checkpoint,
        "points_per_side": self._points_per_side,
        "pred_iou_thresh": self._pred_iou_thresh,
        "stability_score_thresh": self._stability_score_thresh,
        "crop_n_layers": self._crop_n_layers,
        "min_mask_region_area": self._min_mask_region_area,
        "nms_iou_thr": self._nms_iou_thr,
        "nms_score_thr": self._nms_score_thr,
        "nms_inner_thr": self._nms_inner_thr,
    }
    # These also change the generated masks (point grid density and box NMS)
    # but were not recorded by earlier cache files. A missing key is
    # therefore tolerated so existing caches stay valid; a recorded value
    # that differs invalidates the cache.
    optional_meta = {
        "crop_n_points_downscale_factor": self._crop_n_points_downscale_factor,
        "box_nms_thresh": self._box_nms_thresh,
    }

    regenerate_cache = False

    if output_cache.num_files != input_scene.num_images:
        if output_cache.num_files != 0:
            self._logger.info(
                f"Cache has {output_cache.num_files} files but expected "
                f"{input_scene.num_images}. Regenerating cache."
            )
            output_cache.clear_current_folder()
        regenerate_cache = True
    else:
        # Look files up by image_id, the same key used when writing below,
        # so validation also works when ids are not simply 0..N-1.
        for image_meta in input_scene.images:
            cache_filename = f"masks_{image_meta.image_id:0{num_zeropad}}"
            if not output_cache.has_file(cache_filename):
                self._logger.info(
                    f"Masks {cache_filename} not found in the cache. " f"Clearing cache and regenerating."
                )
                output_cache.clear_current_folder()
                regenerate_cache = True
                break

            cache_meta = output_cache.get_file_metadata(cache_filename)
            value_meta = cache_meta.get("metadata", {})
            stale = any(value_meta.get(key) != value for key, value in required_meta.items()) or any(
                value_meta.get(key, value) != value for key, value in optional_meta.items()
            )
            if stale:
                self._logger.info(
                    f"Cache metadata does not match expected parameters. " f"Clearing cache and regenerating."
                )
                output_cache.clear_current_folder()
                regenerate_cache = True
                break

    if regenerate_cache:
        self._logger.info("Generating multi-scale SAM1 masks for all images.")
        file_metadata = {**required_meta, **optional_meta}

        # The context manager closes the progress bar even if an image fails.
        with tqdm.tqdm(input_scene.images, unit="imgs", desc="Generating SAM1 masks") as pbar:
            for image_meta in pbar:
                image_path = image_meta.image_path
                img = cv2.imread(image_path)
                # cv2.imread signals failure by returning None; raise
                # explicitly because an assert would be stripped under -O.
                if img is None:
                    raise RuntimeError(f"Failed to load image {image_path}")

                img = image_meta.camera_metadata.undistort_image(img)

                masks_dict = self._generate_multi_scale_masks(img)

                mask_data = {}
                for scale_name, masks in masks_dict.items():
                    if len(masks) > 0:
                        mask_data[f"{scale_name}_segmentations"] = np.stack(
                            [m["segmentation"].astype(np.uint8) for m in masks], axis=0
                        )
                        mask_data[f"{scale_name}_bboxes"] = np.array([m["bbox"] for m in masks], dtype=np.float32)
                        mask_data[f"{scale_name}_areas"] = np.array([m["area"] for m in masks], dtype=np.int32)
                        mask_data[f"{scale_name}_predicted_ious"] = np.array(
                            [m["predicted_iou"] for m in masks], dtype=np.float32
                        )
                        mask_data[f"{scale_name}_stability_scores"] = np.array(
                            [m["stability_score"] for m in masks], dtype=np.float32
                        )
                    else:
                        # Keep the same keys and trailing dimensions for
                        # scales that produced no masks.
                        mask_data[f"{scale_name}_segmentations"] = np.zeros(
                            (0, img.shape[0], img.shape[1]), dtype=np.uint8
                        )
                        mask_data[f"{scale_name}_bboxes"] = np.zeros((0, 4), dtype=np.float32)
                        mask_data[f"{scale_name}_areas"] = np.zeros(0, dtype=np.int32)
                        mask_data[f"{scale_name}_predicted_ious"] = np.zeros(0, dtype=np.float32)
                        mask_data[f"{scale_name}_stability_scores"] = np.zeros(0, dtype=np.float32)

                cache_filename = f"masks_{image_meta.image_id:0{num_zeropad}}"
                output_cache.write_file(
                    name=cache_filename,
                    data=mask_data,
                    data_type="pt",
                    metadata=dict(file_metadata),
                )

        self._logger.info(f"Generated masks for {input_scene.num_images} images.")
    else:
        self._logger.info("Loading masks from cache.")

    output_scene = SfmScene(
        cameras=input_scene.cameras,
        images=input_scene.images,
        points=input_scene.points,
        points_err=input_scene.points_err,
        points_rgb=input_scene.points_rgb,
        scene_bbox=input_scene.scene_bbox,
        transformation_matrix=input_scene.transformation_matrix,
        cache=output_cache,
    )

    return output_scene


@staticmethod
def name() -> str:
    """Return the identifier stored under ``"name"`` in ``state_dict()``."""
    return "ComputeMultiScaleSAM1Masks"


def state_dict(self) -> dict[str, Any]:
    """Return the transform's name, version and constructor parameters.

    The device is stored as a string; ``from_state_dict`` passes it back to
    the constructor unchanged.
    """
    return {
        "name": self.name(),
        "version": self.version,
        "checkpoint": self._checkpoint,
        "points_per_side": self._points_per_side,
        "points_per_batch": self._points_per_batch,
        "pred_iou_thresh": self._pred_iou_thresh,
        "stability_score_thresh": self._stability_score_thresh,
        "crop_n_layers": self._crop_n_layers,
        "crop_n_points_downscale_factor": self._crop_n_points_downscale_factor,
        "min_mask_region_area": self._min_mask_region_area,
        "box_nms_thresh": self._box_nms_thresh,
        "nms_iou_thr": self._nms_iou_thr,
        "nms_score_thr": self._nms_score_thr,
        "nms_inner_thr": self._nms_inner_thr,
        "device": str(self._device),
    }


@staticmethod
def from_state_dict(state_dict: dict[str, Any]) -> "ComputeMultiScaleSAM1Masks":
    """Rebuild the transform from a dictionary produced by ``state_dict()``.

    Keys read with ``.get`` fall back to the defaults given here when they
    are missing from ``state_dict``.

    Args:
        state_dict: Serialized transform state.

    Returns:
        A new ``ComputeMultiScaleSAM1Masks`` instance.

    Raises:
        ValueError: If ``state_dict["name"]`` is not
            ``"ComputeMultiScaleSAM1Masks"``.
    """
    if state_dict["name"] != "ComputeMultiScaleSAM1Masks":
        raise ValueError(
            f"Expected state_dict with name 'ComputeMultiScaleSAM1Masks', "
            f"got {state_dict['name']} instead."
        )

    return ComputeMultiScaleSAM1Masks(
        checkpoint=state_dict["checkpoint"],
        points_per_side=state_dict["points_per_side"],
        points_per_batch=state_dict.get("points_per_batch", 256),
        pred_iou_thresh=state_dict["pred_iou_thresh"],
        stability_score_thresh=state_dict["stability_score_thresh"],
        crop_n_layers=state_dict.get("crop_n_layers", 1),
        crop_n_points_downscale_factor=state_dict.get("crop_n_points_downscale_factor", 1),
        min_mask_region_area=state_dict.get("min_mask_region_area", 100),
        box_nms_thresh=state_dict.get("box_nms_thresh", 0.7),
        nms_iou_thr=state_dict["nms_iou_thr"],
        nms_score_thr=state_dict["nms_score_thr"],
        nms_inner_thr=state_dict["nms_inner_thr"],
        device=state_dict["device"],
    )
SfmScene from fvdb_reality_capture.transforms import BaseTransform, transform -def mask_nms( - masks: torch.Tensor, - scores: torch.Tensor, - iou_thr: float = 0.7, - score_thr: float = 0.1, - inner_thr: float = 0.2, -) -> torch.Tensor: - """ - Perform mask non-maximum suppression (NMS) on a set of masks. - - Removes redundant masks based on IoU overlap and inner containment. - This implementation is fully vectorized for efficient GPU computation. - - Args: - masks: Binary masks, shape [num_masks, H, W]. - scores: Mask scores, shape [num_masks]. - iou_thr: IoU threshold for NMS. - score_thr: Minimum score threshold. - inner_thr: Inner overlap threshold for removing contained masks. - - Returns: - Indices of selected masks after NMS. - """ - if len(masks) == 0: - return torch.tensor([], dtype=torch.long) - - scores, idx = scores.sort(0, descending=True) - num_masks = idx.shape[0] - - masks_ord = masks[idx.view(-1), :] - - # Flatten masks for vectorized computation: [N, H*W] - # Use reshape instead of view since indexing may produce non-contiguous tensor - masks_flat = masks_ord.reshape(num_masks, -1).float() - - # Compute all pairwise intersections in one matrix multiply - # For binary masks: intersection[i,j] = sum(mask_i & mask_j) = mask_i @ mask_j.T - intersection = masks_flat @ masks_flat.T # [N, N] - - # Compute areas - masks_area = masks_flat.sum(dim=1) # [N] - - # Compute unions: union[i,j] = area_i + area_j - intersection[i,j] - union = masks_area[:, None] + masks_area[None, :] - intersection - - # Compute IoU matrix (avoid division by zero) - iou_matrix = torch.where(union > 0, intersection / union, torch.zeros_like(union)) - - # Compute containment ratios for inner IoU - # ratio_i[i,j] = intersection[i,j] / area_i - # ratio_j[i,j] = intersection[i,j] / area_j - area_i = masks_area[:, None].expand_as(intersection) - area_j = masks_area[None, :].expand_as(intersection) - - # Avoid division by zero - ratio_i = torch.where(area_i > 0, intersection / area_i, 
torch.zeros_like(intersection)) - ratio_j = torch.where(area_j > 0, intersection / area_j, torch.zeros_like(intersection)) - - # Inner IoU for containment detection - # Case 1: j is mostly contained in i (ratio_i < 0.5 and ratio_j >= 0.85) - # Case 2: i is mostly contained in j (ratio_i >= 0.85 and ratio_j < 0.5) - inner_iou_values = 1 - ratio_i * ratio_j - - # Build inner_iou_matrix based on containment conditions - inner_iou_matrix = torch.zeros_like(iou_matrix) - - # Upper triangle: j contained in i - condition_upper = (ratio_i < 0.5) & (ratio_j >= 0.85) - inner_iou_matrix = torch.where(condition_upper, inner_iou_values, inner_iou_matrix) - - # Lower triangle: i contained in j (transpose the condition) - condition_lower = (ratio_i >= 0.85) & (ratio_j < 0.5) - inner_iou_matrix = torch.where(condition_lower, inner_iou_values.T, inner_iou_matrix) - - # Apply triangular masks and compute max values - iou_matrix = torch.triu(iou_matrix, diagonal=1) - iou_max, _ = iou_matrix.max(dim=0) - - inner_iou_matrix_u = torch.triu(inner_iou_matrix, diagonal=1) - inner_iou_max_u, _ = inner_iou_matrix_u.max(dim=0) - inner_iou_matrix_l = torch.tril(inner_iou_matrix, diagonal=-1) - inner_iou_max_l, _ = inner_iou_matrix_l.max(dim=0) - - keep = iou_max <= iou_thr - keep_conf = scores > score_thr - keep_inner_u = inner_iou_max_u <= 1 - inner_thr - keep_inner_l = inner_iou_max_l <= 1 - inner_thr - - # If no masks pass thresholds, keep top 3 - if keep_conf.sum() == 0: - index = scores.topk(min(3, len(scores))).indices - keep_conf[index] = True - if keep_inner_u.sum() == 0: - index = scores.topk(min(3, len(scores))).indices - keep_inner_u[index] = True - if keep_inner_l.sum() == 0: - index = scores.topk(min(3, len(scores))).indices - keep_inner_l[index] = True - - keep = keep & keep_conf & keep_inner_u & keep_inner_l - selected_idx = idx[keep] - - return selected_idx - - -def masks_update( - *mask_lists, - iou_thr: float = 0.8, - score_thr: float = 0.7, - inner_thr: float = 0.5, -) -> 
tuple: - """ - Apply mask NMS to multiple lists of masks. - - Args: - *mask_lists: Variable number of mask lists to filter. - iou_thr: IoU threshold for NMS. - score_thr: Score threshold. - inner_thr: Inner overlap threshold. - - Returns: - Tuple of filtered mask lists. - """ - masks_new = [] - - for masks_lvl in mask_lists: - if len(masks_lvl) == 0: - masks_new.append([]) - continue - - seg_pred = torch.from_numpy(np.stack([m["segmentation"] for m in masks_lvl], axis=0)) - iou_pred = torch.from_numpy(np.stack([m["predicted_iou"] for m in masks_lvl], axis=0)) - stability = torch.from_numpy(np.stack([m["stability_score"] for m in masks_lvl], axis=0)) - - scores = stability * iou_pred - keep_mask_nms = mask_nms(seg_pred, scores, iou_thr=iou_thr, score_thr=score_thr, inner_thr=inner_thr) - - keep_set = set(keep_mask_nms.int().cpu().numpy().tolist()) - filtered_masks = [m for i, m in enumerate(masks_lvl) if i in keep_set] - masks_new.append(filtered_masks) - return tuple(masks_new) - - -def _cross_crop_nms( - mask_list: List[Dict[str, Any]], - iou_threshold: float = 0.7, -) -> List[Dict[str, Any]]: - """Run NMS across masks from multiple crops; prefer masks from smaller crops. - - Args: - mask_list: List of mask records with "bbox" (xywh) and "crop_box" (xywh). - iou_threshold: Box IoU threshold for NMS. - - Returns: - Filtered list of mask records. 
- """ - if len(mask_list) <= 1: - return mask_list - # bbox and crop_box are [x, y, w, h]; convert bbox to xyxy for batched_nms - boxes_xywh = np.array([m["bbox"] for m in mask_list], dtype=np.float32) - x1 = boxes_xywh[:, 0] - y1 = boxes_xywh[:, 1] - x2 = boxes_xywh[:, 0] + boxes_xywh[:, 2] - y2 = boxes_xywh[:, 1] + boxes_xywh[:, 3] - boxes_xyxy = torch.from_numpy(np.stack([x1, y1, x2, y2], axis=1)) - crop_areas = np.array( - [m["crop_box"][2] * m["crop_box"][3] for m in mask_list], - dtype=np.float32, - ) - scores = torch.from_numpy(1.0 / (crop_areas + 1e-6)) - keep = batched_nms( - boxes_xyxy.float(), - scores, - torch.zeros(len(mask_list), dtype=torch.long), - iou_threshold=iou_threshold, - ) - return [mask_list[i] for i in keep.tolist()] +from .mask_utils import cross_crop_nms, masks_update, postprocess_small_regions @transform @@ -211,14 +34,14 @@ class ComputeMultiScaleSAM2Masks(BaseTransform): to each scale level independently. """ - version = "1.0.0" + version = "1.2.0" def __init__( self, checkpoint: Literal["large", "small", "tiny", "base_plus"] = "large", points_per_side: int = 32, points_per_batch: int = 64, - pred_iou_thresh: float = 0.7, + pred_iou_thresh: float = 0.5, stability_score_thresh: float = 0.85, crop_n_layers: int = 1, crop_n_points_downscale_factor: int = 1, @@ -329,12 +152,43 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: all_m.extend(mm) all_l.extend(ml) + log_diag = not hasattr(self, '_diag_count') or self._diag_count < 2 + if log_diag: + if not hasattr(self, '_diag_count'): + self._diag_count = 0 + self._diag_count += 1 + self._logger.info( + "[diag] after predict: default=%d, s=%d, m=%d, l=%d", + len(all_default), len(all_s), len(all_m), len(all_l), + ) + # Cross-crop NMS (prefer smaller crops) if len(crop_boxes) > 1: - all_default = _cross_crop_nms(all_default, iou_threshold=self._box_nms_thresh) - all_s = _cross_crop_nms(all_s, iou_threshold=self._box_nms_thresh) - all_m = _cross_crop_nms(all_m, 
iou_threshold=self._box_nms_thresh) - all_l = _cross_crop_nms(all_l, iou_threshold=self._box_nms_thresh) + all_default = cross_crop_nms(all_default, iou_threshold=self._box_nms_thresh) + all_s = cross_crop_nms(all_s, iou_threshold=self._box_nms_thresh) + all_m = cross_crop_nms(all_m, iou_threshold=self._box_nms_thresh) + all_l = cross_crop_nms(all_l, iou_threshold=self._box_nms_thresh) + + if log_diag: + self._logger.info( + "[diag] after cross-crop NMS: default=%d, s=%d, m=%d, l=%d", + len(all_default), len(all_s), len(all_m), len(all_l), + ) + + # Remove small disconnected regions and holes (matches original SAM's + # postprocess_small_regions step in generate_curr_anns) + if self._min_mask_region_area > 0: + nms_thresh = self._box_nms_thresh + all_default = postprocess_small_regions(all_default, self._min_mask_region_area, nms_thresh) + all_s = postprocess_small_regions(all_s, self._min_mask_region_area, nms_thresh) + all_m = postprocess_small_regions(all_m, self._min_mask_region_area, nms_thresh) + all_l = postprocess_small_regions(all_l, self._min_mask_region_area, nms_thresh) + + if log_diag: + self._logger.info( + "[diag] after postprocess_small_regions: default=%d, s=%d, m=%d, l=%d", + len(all_default), len(all_s), len(all_m), len(all_l), + ) # Mask NMS per scale masks_default, masks_s, masks_m, masks_l = masks_update( @@ -347,6 +201,12 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: inner_thr=self._nms_inner_thr, ) + if log_diag: + self._logger.info( + "[diag] after masks_update: default=%d, s=%d, m=%d, l=%d", + len(masks_default), len(masks_s), len(masks_m), len(masks_l), + ) + return { "default": masks_default, "s": masks_s, @@ -370,6 +230,7 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: input_cache: SfmCache = input_scene.cache # Create cache folder + version_safe = self.version.replace(".", "_") cache_prefix = ( f"sam2_multi_scale_masks_{self._checkpoint}_" f"p{self._points_per_side}_" @@ -378,7 +239,8 @@ def 
__call__(self, input_scene: SfmScene) -> SfmScene: f"crop{self._crop_n_layers}_" f"nmsiou{int(self._nms_iou_thr * 100)}_" f"nmsscore{int(self._nms_score_thr * 100)}_" - f"nmsinner{int(self._nms_inner_thr * 100)}" + f"nmsinner{int(self._nms_inner_thr * 100)}_" + f"v{version_safe}" ) output_cache = input_cache.make_folder( cache_prefix, @@ -433,16 +295,17 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: if regenerate_cache: self._logger.info("Generating multi-scale SAM2 masks for all images.") # Suppress SAM2's per-image INFO (e.g. "Computing image embeddings...") + # SAM2 logs through root, so we must suppress root. Setting our own + # logger level explicitly ensures propagated messages still reach + # root's handlers (propagation skips the parent's level check). _root = logging.getLogger() - _sam2 = logging.getLogger("sam2") _prev_root = _root.level - _prev_sam2 = _sam2.level - try: - _root.setLevel(logging.WARNING) - _sam2.setLevel(logging.WARNING) - except Exception: - # Silently ignore errors setting logging levels - pass + _prev_self = self._logger.level + _sam2_model_logger = logging.getLogger("fvdb_reality_capture.foundation_models.sam2.SAM2Model") + _prev_sam2_model = _sam2_model_logger.level + self._logger.setLevel(logging.DEBUG) + _sam2_model_logger.setLevel(logging.DEBUG) + _root.setLevel(logging.WARNING) pbar = tqdm.tqdm(input_scene.images, unit="imgs", desc="Generating SAM2 masks") @@ -503,13 +366,9 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: ) pbar.close() - # Restore logging levels - try: - _root.setLevel(_prev_root) - _sam2.setLevel(_prev_sam2) - except Exception: - # Silently ignore errors restoring logging levels - pass + _root.setLevel(_prev_root) + self._logger.setLevel(_prev_self) + _sam2_model_logger.setLevel(_prev_sam2_model) self._logger.info(f"Generated masks for {input_scene.num_images} images.") else: self._logger.info("Loading masks from cache.") diff --git 
a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py index 2be8378..938b66a 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py @@ -98,11 +98,23 @@ def warmup_cache(self) -> None: """ import tqdm + n_empty = 0 if self._cache_features: for idx in tqdm.tqdm(range(len(self)), desc="Warming up feature cache"): index = self._indices[idx] if index not in self._features_cache: self.get_feature_data(index) + _, seg_map, _ = self._features_cache.get(index, self.get_feature_data(index)) + if not (seg_map >= 0).any(): + n_empty += 1 + + if n_empty > 0: + logger.warning( + "%d / %d training images (%.1f%%) have NO valid masks at level %d " + "-- these contribute zero gradient during training", + n_empty, len(self), 100 * n_empty / len(self), + self.feature_level, + ) if self._cache_images: for idx in tqdm.tqdm(range(len(self)), desc="Warming up image cache"): @@ -217,9 +229,7 @@ def build_feature_map( # Unbatched path: plain tensor [N_masks, clip_n_dims] H, W = seg_map.shape feature_mask = seg_map >= 0 - # Use empty -- invalid pixels are never read (the loss gathers only - # valid pixels via the mask). - gt_features = torch.empty( + gt_features = torch.zeros( H, W, clip_n_dims, dtype=features.dtype, device=features.device ) if feature_mask.any(): @@ -250,9 +260,7 @@ def build_feature_map( dtype = features.jdata.dtype feature_mask = seg_map >= 0 # [B, H, W] - # Use empty -- invalid pixels are never read (the loss gathers only - # valid pixels via the mask). 
- gt_features = torch.empty( + gt_features = torch.zeros( B, H, W, clip_n_dims, dtype=dtype, device=device ) for b in range(B): diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py index 706b835..ce47739 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py @@ -312,6 +312,26 @@ def save_checkpoint( torch.save(checkpoint, ckpt_path) self._logger.info(f"Saved checkpoint to {ckpt_path}") + @torch.no_grad() + def save_final_checkpoint(self, checkpoint: dict[str, Any]) -> pathlib.Path | None: + """Save a ``final_checkpoint.pt`` at the run's top-level directory. + + This makes it easy for evaluation scripts to find the finished + checkpoint without globbing through step-numbered subdirectories. + + Args: + checkpoint: Checkpoint dictionary (from ``Trainer.state_dict()``). + + Returns: + Path to the saved file, or *None* if saving is disabled. 
+ """ + if self._save_path is None: + return None + final_path = self._save_path / "final_checkpoint.pt" + torch.save(checkpoint, final_path) + self._logger.info(f"Saved final checkpoint to {final_path}") + return final_path + # ------------------------------------------------------------------ # Helpers (matching GARfVDB writer helpers) # ------------------------------------------------------------------ diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py index 9885278..5859278 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py @@ -25,6 +25,7 @@ from ..util import calculate_pca_projection, cosine_error_map, pca_projection_fast from ..vq_utils import ( ResidualVectorQuantizationWithClustering, + load_all_clip_features, load_clip_features_for_level, ) from .dataset import ( @@ -226,7 +227,6 @@ def new( ) # Initialize codebooks via K-means on CLIP features - logger.info("Loading CLIP features for codebook initialization...") all_dataset = LangSplatV2Dataset( sfm_scene=sfm_scene, feature_level=config.feature_level, @@ -234,10 +234,15 @@ def new( cache_features=False, cache_images=False, ) - clip_features = load_clip_features_for_level( - full_dataset=all_dataset, - feature_level=config.feature_level, - ) + if config.init_codebooks_all_levels: + logger.info("Loading CLIP features from ALL scale levels for codebook initialization...") + clip_features = load_all_clip_features(full_dataset=all_dataset) + else: + logger.info(f"Loading CLIP features for level {config.feature_level} for codebook initialization...") + clip_features = load_clip_features_for_level( + full_dataset=all_dataset, + feature_level=config.feature_level, + ) logger.info(f"Loaded {clip_features.shape[0]:,} CLIP features of dimension {clip_features.shape[1]}") # Run K-means to get 
def load_all_clip_features(
    full_dataset: LangSplatV2Dataset,
) -> torch.Tensor:
    """Gather the CLIP features of every image into one float tensor.

    Intended to match the original LangSplatV2 codebook initialization,
    which concatenates every ``*_f.npy`` file (each containing features from
    all 4 scales) into a single tensor for k-means. No per-level filtering
    happens here: whatever ``get_feature_data`` returns as its first element
    is used as-is.

    Args:
        full_dataset: The full dataset containing all features.

    Returns:
        Feature tensor of shape ``[N_all, clip_n_dims]`` built from all
        images that have at least one feature row.

    Raises:
        RuntimeError: If no image contributes any features.
    """
    # NOTE(review): positions 0..len-1 are passed straight to
    # get_feature_data; this presumably assumes the dataset covers all images
    # in order -- verify against how the dataset maps positions to indices.
    collected = []
    for position in range(len(full_dataset)):
        image_features, _, _ = full_dataset.get_feature_data(position)
        if image_features.shape[0] == 0:
            continue  # nothing to contribute from this image
        collected.append(image_features)

    if not collected:
        raise RuntimeError("No features found across any images")

    stacked = torch.cat(collected, dim=0)
    return stacked.float()
#! /bin/bash
# Train LangSplatV2 on each LERF-OVS scene at feature levels 1-3, then copy
# every run's final checkpoint into langsplatv2_results/.
#
# All paths are relative to the current directory, which must contain
# train_langsplatv2.py, data/lerf_ovs/<scene> and reconstructions/<scene>.ply.

# -e: stop at the first failing command     -u: treat unset variables as errors
# -x: echo each command before running it   pipefail: fail on any pipeline stage
set -euxo pipefail

scenes=(ramen figurines teatime waldo_kitchen)
levels=(1 2 3)
log_dir="langsplatv2_logs"
results_dir="langsplatv2_results"

# Created once up front instead of once per scene.
mkdir -p "${results_dir}"

for scene in "${scenes[@]}"; do
    for level in "${levels[@]}"; do
        run_name="${scene}_level_${level}"
        python train_langsplatv2.py \
            --sfm-dataset-path "data/lerf_ovs/${scene}" \
            --reconstruction-path "reconstructions/${scene}.ply" \
            --config.feature-level "${level}" \
            --run-name "${run_name}" \
            --log-path "${log_dir}" \
            --config.max-steps 10000 \
            --preprocess.sam-model sam2
    done

    # Collect checkpoints (final_checkpoint.pt is saved at the run's top level)
    for level in "${levels[@]}"; do
        run_name="${scene}_level_${level}"
        cp "${log_dir}/${run_name}/final_checkpoint.pt" \
            "${results_dir}/${run_name}.pt"
    done
done
#! /bin/bash
# Run `frgs reconstruct` once per scene, reading data/lerf_ovs/<scene>/ and
# writing reconstructions/<scene>.ply. Paths are relative to the current
# directory.

# Exported so that Python processes started below do not buffer their output.
export PYTHONUNBUFFERED=1

scenes=(teatime waldo_kitchen)

for scene in "${scenes[@]}"; do
    frgs reconstruct \
        --run-name "${scene}" \
        --tx.image-downsample-factor 1 \
        "data/lerf_ovs/${scene}/" \
        -uv 10 \
        -o "reconstructions/${scene}.ply" \
        --cfg.batch-size 1 \
        --cfg.pose_opt_start_epoch 20
done
@@ #! /bin/bash export PYTHONUNBUFFERED=1 -for scene in teatime waldo_kitchen; do +for scene in ramen figurines teatime waldo_kitchen; do frgs reconstruct \ --run-name ${scene} \ --tx.image-downsample-factor 1 \ diff --git a/open_vocabulary_segmentation/langsplatv2/train_eval.sh b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/batch_train_eval_langsplat.sh similarity index 93% rename from open_vocabulary_segmentation/langsplatv2/train_eval.sh rename to open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/batch_train_eval_langsplat.sh index 2a764b2..471073f 100644 --- a/open_vocabulary_segmentation/langsplatv2/train_eval.sh +++ b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/batch_train_eval_langsplat.sh @@ -2,7 +2,7 @@ set -ex for scene in ramen figurines teatime waldo_kitchen; do for level in 1 2 3; do - python train_langsplatv2.py \ + python ../../scripts/train_langsplatv2.py \ --sfm-dataset-path data/lerf_ovs/${scene} \ --reconstruction-path reconstructions/${scene}.ply \ --config.feature-level $level \ diff --git a/open_vocabulary_segmentation/langsplatv2/evaluation/download_data.py b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/download_data.py similarity index 100% rename from open_vocabulary_segmentation/langsplatv2/evaluation/download_data.py rename to open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/download_data.py diff --git a/open_vocabulary_segmentation/langsplatv2/evaluation/eval_lerf.py b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/eval_lerf.py similarity index 100% rename from open_vocabulary_segmentation/langsplatv2/evaluation/eval_lerf.py rename to open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/eval_lerf.py diff --git a/open_vocabulary_segmentation/langsplatv2/pyproject.toml b/open_vocabulary_segmentation/langsplatv2/pyproject.toml index d0fb6d0..aef3981 100644 --- a/open_vocabulary_segmentation/langsplatv2/pyproject.toml +++ 
b/open_vocabulary_segmentation/langsplatv2/pyproject.toml @@ -7,7 +7,7 @@ name = "fvdb-langsplatv2" version = "0.0.1" description = "fVDB Implementation of LangSplatV2" readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.10" dependencies = [ "gdown", diff --git a/open_vocabulary_segmentation/langsplatv2/train_langsplatv2.py b/open_vocabulary_segmentation/langsplatv2/scripts/train_langsplatv2.py similarity index 100% rename from open_vocabulary_segmentation/langsplatv2/train_langsplatv2.py rename to open_vocabulary_segmentation/langsplatv2/scripts/train_langsplatv2.py From 4941f71757246a3d6d690bf895efa344f9105bf8 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 4 Mar 2026 19:02:58 +1300 Subject: [PATCH 03/13] readme updates Signed-off-by: Jonathan Swartz --- .../langsplatv2/README.md | 26 ++-- .../langsplatv2/environment.yml | 6 +- .../langsplatv2/evaluation/lerf_ovs/README.md | 131 ++++++++++++++++++ .../langsplatv2/pyproject.toml | 10 +- 4 files changed, 155 insertions(+), 18 deletions(-) create mode 100644 open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md diff --git a/open_vocabulary_segmentation/langsplatv2/README.md b/open_vocabulary_segmentation/langsplatv2/README.md index 4975336..c27147d 100644 --- a/open_vocabulary_segmentation/langsplatv2/README.md +++ b/open_vocabulary_segmentation/langsplatv2/README.md @@ -1,10 +1,10 @@ # LangSplatV2 (fVDB) -LangSplatV2-style open-vocabulary 3D segmentation using [fVDB](https://github.com/openvdb/fvdb-core) and pre-trained Gaussian splat reconstructions. This implementation trains per-Gaussian sparse coefficient fields and shared CLIP-aligned codebooks on an existing reconstruction; it does not train the underlying Gaussians or colors. +LangSplatV2-style open-vocabulary 3D segmentation using [fVDB](https://github.com/openvdb/fvdb-core) and pre-trained Gaussian splat reconstructions. 
This implementation trains per-Gaussian sparse coefficient fields and shared CLIP-aligned codebooks on an existing reconstruction. ## What this implements -- **Preprocessing**: Multi-scale SAM2 masks and OpenCLIP feature encoding for each image (cached on disk). +- **Preprocessing**: Multi-scale SAM masks (SAM1 or SAM2, configurable) and OpenCLIP feature encoding for each image (cached on disk). - **Training**: Residual VQ codebooks and per-splat sparse logits so that rendered language features match the CLIP embeddings from SAM masks. One feature level (scale) per run; train multiple levels separately and combine at inference. - **Compatibility**: Same feature pipeline and training setup (loss, LR, layer schedule) as the original LangSplatV2; uses fVDB for the 3D representation and rendering. @@ -22,7 +22,7 @@ conda activate fvdb pip install -e . ``` -Dependencies (see `pyproject.toml`) include `torch`, `open-clip-torch`, `fvdb-reality-capture`, `tyro`, and optional TensorBoard for logging. +Dependencies (see `pyproject.toml`) include `fvdb-core`, `fvdb-reality-capture`, `torch`, `open-clip-torch`, `sam2`, `tyro`, and `matplotlib`. 
## How to run @@ -31,7 +31,7 @@ Training loads the SfM scene, applies preprocessing (SAM2 + CLIP) with caching, **Minimal (COLMAP scene + PLY reconstruction):** ```bash -python train_langsplatv2.py \ +python scripts/train_langsplatv2.py \ --sfm-dataset-path /path/to/colmap/scene \ --reconstruction-path /path/to/point_cloud.ply ``` @@ -39,7 +39,7 @@ python train_langsplatv2.py \ **With explicit feature level and log directory:** ```bash -python train_langsplatv2.py \ +python scripts/train_langsplatv2.py \ --sfm-dataset-path /path/to/colmap/scene \ --reconstruction-path /path/to/point_cloud.ply \ --config.feature-level 1 \ @@ -50,7 +50,7 @@ python train_langsplatv2.py \ ```bash for level in 1 2 3; do - python train_langsplatv2.py \ + python scripts/train_langsplatv2.py \ --sfm-dataset-path /path/to/scene \ --reconstruction-path /path/to/gaussians.ply \ --config.feature-level $level \ @@ -63,8 +63,10 @@ done - `--config.feature-level` — 0=default, 1=small, 2=medium, 3=large (default: 1). - `--config.max-steps` — Training steps (default from max_epochs if not set). -- `--preprocess.image-downsample-factor` — Downsample images before SAM2/CLIP (e.g. 2 for speed). +- `--preprocess.image-downsample-factor` — Downsample images before SAM/CLIP (e.g. 2 for speed). +- `--preprocess.sam-model` — `sam1` or `sam2` (default: `sam2`). - `--preprocess.sam2.checkpoint` — SAM2 size: `large`, `small`, `tiny`, `base_plus`. +- `--preprocess.sam1.checkpoint` — SAM1 variant: `vit_h`, `vit_l`, `vit_b`. - `--log-path` — Directory for run subdirs (checkpoints, metrics). Use `None` to disable saving. - `--io.use-tensorboard` — Log scalars (and optionally images) to TensorBoard. - `--use-every-n-as-val` — Hold out every N-th image for validation (e.g. 5); -1 = no validation. @@ -74,7 +76,8 @@ done With `--log-path` set (e.g. 
`langsplatv2_logs`), each run writes: - `log_path/run_/` (or `log_path//` if `--run-name` is set) - - `checkpoints//langsplatv2_ckpt.pt` — Model state and config (when `io.save_checkpoints` is True). + - `final_checkpoint.pt` — Final model checkpoint saved at the run's top level for easy access. + - `checkpoints//langsplatv2_ckpt.pt` — Per-step model state and config (when `io.save_checkpoints` is True). - `metrics_log.csv` — Step, loss, and optional validation metrics. - `tensorboard/` — If `io.use_tensorboard` is True. - `images/` — If `io.save_images` is True (e.g. feature visualizations at save steps). @@ -83,7 +86,7 @@ Preprocessing caches (SAM2 masks, CLIP features) are stored under the scene’s ## Preprocessing pipeline and cache format -The pipeline (see `LangSplatV2PreprocessConfig` in `config.py`) runs in order: optional scene normalization, point filtering, image downsampling, filter images by visible points, **ComputeMultiScaleSAM2Masks**, **ComputeCLIPFeatures**, and optional cropping. +The pipeline (see `LangSplatV2PreprocessConfig` in `config.py`) runs in order: optional scene normalization, point filtering, image downsampling, filter images by visible points, **ComputeMultiScaleSAM1Masks** or **ComputeMultiScaleSAM2Masks** (controlled by `--preprocess.sam-model`), **ComputeCLIPFeatures**, and optional cropping. ### SAM2 masks (per image) @@ -104,11 +107,10 @@ Training uses a single `feature_level` (0–3) to choose which scale’s seg map ## Training details and comparison with original LangSplatV2 - **Feature generation**: Same as original — crop mask region → pad to square → resize to 224 → OpenCLIP encode → L2-normalize. Scale order and seg-map indexing (default → s → m → l, cumulative) match. -- **Optimization**: Same language-feature LR (0.0025), layer schedule (every 10k steps), and cosine loss over valid pixels with gradient scaling via mask fraction. The scalar `train/loss` is the (mask-fraction-scaled) total loss used for backprop. 
For a smoother, more interpretable curve when mask coverage varies across images, use `train/cosine_loss_valid`, which is the mean cosine loss over valid pixels only (no mask-fraction scaling), we use this for logging. -- **Data sampling**: One random permutation of all training views per “epoch” (InfiniteSampler with shuffle), one view per step when `batch_size=1`, matching the original’s viewpoint-stack behavior. ## References -- [LangSplatV2: Vision-Language Gaussian Splatting](https://arxiv.org/abs/2312.16084) +- [LangSplat: 3D Language Gaussian Splatting](https://arxiv.org/abs/2312.16084) +- [LangSplatV2: High-dimensional 3D Language Gaussian Splatting with 450+ FPS](https://arxiv.org/abs/2507.07136) - [Segment Anything 2 (SAM2)](https://github.com/facebookresearch/segment-anything-2) - [OpenCLIP](https://github.com/mlfoundations/open_clip) diff --git a/open_vocabulary_segmentation/langsplatv2/environment.yml b/open_vocabulary_segmentation/langsplatv2/environment.yml index 625fcb4..d7a20df 100644 --- a/open_vocabulary_segmentation/langsplatv2/environment.yml +++ b/open_vocabulary_segmentation/langsplatv2/environment.yml @@ -2,16 +2,18 @@ name: fvdb_langsplatv2 channels: - conda-forge dependencies: - - python >=3.11 + - python >=3.10 - numpy - pytorch - torchvision - opencv - tqdm - scikit-learn - - pip + - matplotlib - gdown - open-clip-torch - tyro + - sam2 - segment-anything + - fvdb-core - fvdb-reality-capture diff --git a/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md new file mode 100644 index 0000000..37a9015 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md @@ -0,0 +1,131 @@ +# LERF-OVS Evaluation + +Open-vocabulary segmentation evaluation on the [LERF-OVS](https://www.lerf.io/) dataset, comparing against ground-truth labelme annotations. 
Computes segmentation mIoU and localization accuracy across four scenes: `ramen`, `figurines`, `teatime`, `waldo_kitchen`. + +All commands below should be run from this directory (`evaluation/lerf_ovs/`). + +## Prerequisites + +- The `fvdb` conda environment with the `fvdb-langsplatv2` package installed (see the [parent README](../../README.md)). +- `fvdb-reality-capture` installed. + +## Step 1: Download the LERF-OVS dataset + +```bash +python download_data.py --dataset-root data +``` + +This downloads and extracts the LERF-OVS data from Google Drive into `data/lerf_ovs/`. The resulting layout is: + +``` +data/lerf_ovs/ + label//frame_XXXXX.json, frame_XXXXX.jpg + /images/, sparse/ +``` + +## Step 2: Reconstruct Gaussian splats + +```bash +bash batch_reconstruct_eval_scenes.sh +``` + +This runs `frgs reconstruct` on each scene with default settings. Outputs are written to: + +``` +reconstructions/.ply +``` + +To reconstruct a single scene manually: + +```bash +frgs reconstruct \ + --run-name teatime \ + --tx.image-downsample-factor 1 \ + data/lerf_ovs/teatime/ \ + -uv 10 \ + -o reconstructions/teatime.ply \ + --cfg.batch-size 1 \ + --cfg.pose_opt_start_epoch 20 +``` + +## Step 3: Train LangSplatV2 features + +```bash +bash batch_train_eval_langsplat.sh +``` + +For each scene, this trains three models (one per SAM scale level: 1=small, 2=medium, 3=large) for 10k steps using SAM2. 
The final checkpoints are collected into: + +``` +langsplatv2_results/_level_1.pt +langsplatv2_results/_level_2.pt +langsplatv2_results/_level_3.pt +``` + +To train a single scene and level manually: + +```bash +python ../../scripts/train_langsplatv2.py \ + --sfm-dataset-path data/lerf_ovs/teatime \ + --reconstruction-path reconstructions/teatime.ply \ + --config.feature-level 1 \ + --run-name teatime_level_1 \ + --log-path langsplatv2_logs \ + --config.max-steps 10000 \ + --preprocess.sam-model sam2 +``` + +Then copy the final checkpoint: + +```bash +cp langsplatv2_logs/teatime_level_1/final_checkpoint.pt \ + langsplatv2_results/teatime_level_1.pt +``` + +## Step 4: Evaluate + +**All scenes (auto-discovered from checkpoints):** + +```bash +python eval_lerf.py \ + --lerf-root data/lerf_ovs \ + --results-root langsplatv2_results \ + --reconstructions-root reconstructions +``` + +**Single scene:** + +```bash +python eval_lerf.py \ + --lerf-root data/lerf_ovs \ + --results-root langsplatv2_results \ + --reconstructions-root reconstructions \ + --scenes teatime +``` + +The evaluation: +1. Loads all three level checkpoints per scene +2. Renders CLIP features from each level for each annotated frame +3. Computes OpenCLIP relevancy maps for each ground-truth text prompt +4. Selects the best level per prompt (highest max relevancy score) +5. Reports **segmentation mIoU** (thresholded relevancy vs GT masks) and **localization accuracy** (relevancy peak inside GT bounding box) + +### Evaluation flags + +- `--mask-thresh` — Relevancy threshold for binary segmentation mask (default: 0.4). +- `--eval-topk` — Number of codebook entries to combine at eval (default: 4). +- `--output-dir` — Where to write results and visualizations (default: `lerf_eval_results`). +- `--no-visualizations` — Skip saving per-frame visualization images. +- `--verbose` — Enable debug logging. 
+ +### Output + +Results are saved to `lerf_eval_results/` (or the path given by `--output-dir`): + +``` +lerf_eval_results/ + lerf_results.json # Summary across all scenes (mIoU, localization accuracy) + / + results.json # Per-frame breakdown for this scene + frame_XXXXX.jpg # Per-frame visualizations (if enabled) +``` diff --git a/open_vocabulary_segmentation/langsplatv2/pyproject.toml b/open_vocabulary_segmentation/langsplatv2/pyproject.toml index aef3981..a8d420b 100644 --- a/open_vocabulary_segmentation/langsplatv2/pyproject.toml +++ b/open_vocabulary_segmentation/langsplatv2/pyproject.toml @@ -10,15 +10,17 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ + "fvdb-reality-capture", "gdown", + "matplotlib", "numpy", - "torch", - "torchvision", "opencv-python", - "tqdm", "open-clip-torch", - "fvdb-reality-capture", + "sam2", "scikit-learn", + "torch", + "torchvision", + "tqdm", "tyro", ] From 6e71c9215003a75a209775b902dbc9f1ea833f90 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 4 Mar 2026 22:43:47 +1300 Subject: [PATCH 04/13] train and query scripts Signed-off-by: Jonathan Swartz --- .../langsplatv2/scripts/query_prompt.py | 279 ++++++++++++++++++ .../langsplatv2/scripts/train_all_levels.py | 143 +++++++++ 2 files changed, 422 insertions(+) create mode 100644 open_vocabulary_segmentation/langsplatv2/scripts/query_prompt.py create mode 100644 open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py diff --git a/open_vocabulary_segmentation/langsplatv2/scripts/query_prompt.py b/open_vocabulary_segmentation/langsplatv2/scripts/query_prompt.py new file mode 100644 index 0000000..7369eb7 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/scripts/query_prompt.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Visualise the relevancy mask for a text prompt on a trained LangSplatV2 model. 
+ +Loads a single-level checkpoint, renders CLIP features for a chosen camera +view, computes OpenCLIP relevancy against the user's prompt, and writes an +image showing the source photograph, the relevancy heatmap, and the +thresholded binary mask side-by-side. + +Usage: + python scripts/query_prompt.py \\ + --checkpoint langsplatv2_results/scene_level_2.pt \\ + --reconstruction-path reconstructions/scene.ply \\ + --prompt "coffee cup" + + # Specify a different view and output path: + python scripts/query_prompt.py \\ + --checkpoint langsplatv2_results/scene_level_1.pt \\ + --reconstruction-path reconstructions/scene.ply \\ + --prompt "wooden table" \\ + --image-index 42 \\ + --output table_query.jpg +""" +import argparse +import logging +import pathlib + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +from langsplatv2.evaluation.openclip_relevancy import OpenCLIPRelevancy +from langsplatv2.training.trainer import LangSplatV2Trainer +from langsplatv2.util import load_splats_from_file + +matplotlib.use("Agg") + +logger = logging.getLogger(__name__) + + +def _load_model(checkpoint_path, gs_model_path, device, eval_topk=None): + """Load a trained LangSplatV2 model and its embedded SfmScene.""" + gs_model, _ = load_splats_from_file(gs_model_path, device) + state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False) + trainer = LangSplatV2Trainer.from_state_dict( + state_dict=state_dict, + gs_model=gs_model, + gs_model_path=gs_model_path, + device=device, + eval_only=True, + ) + model = trainer._model + sfm_scene = trainer._sfm_scene + feature_level = trainer._cfg.feature_level + + if eval_topk is not None and model.topk != eval_topk: + logger.info(f"Overriding topk: {model.topk} -> {eval_topk}") + model.topk = eval_topk + + logger.info( + f"Loaded model (feature_level={feature_level}, topk={model.topk}) " f"with {gs_model.num_gaussians:,} Gaussians" + ) + return model, sfm_scene, feature_level + + 
+def _get_camera(sfm_scene, image_index, device): + """Return (w2c, K, h, w, image_path) for the given image index.""" + sorted_images = sorted(sfm_scene.images, key=lambda img: img.image_path) + if image_index >= len(sorted_images): + raise IndexError(f"--image-index {image_index} out of range " f"(scene has {len(sorted_images)} images)") + img = sorted_images[image_index] + c2w = torch.from_numpy(img.camera_to_world_matrix).float() + K = torch.from_numpy(img.camera_metadata.projection_matrix).float() + w2c = torch.linalg.inv(c2w).contiguous() + return ( + w2c.unsqueeze(0).to(device), + K.unsqueeze(0).to(device), + img.camera_metadata.height, + img.camera_metadata.width, + img.image_path, + ) + + +@torch.no_grad() +def _render_clip(model, w2c, K, width, height): + """Render normalised CLIP features [H, W, 512].""" + feat_maps, _ = model( + world_to_camera=w2c, + projection=K, + image_width=width, + image_height=height, + ) + feat = feat_maps[0] # [H, W, 512] + return feat / (feat.norm(dim=-1, keepdim=True) + 1e-10) + + +def _smooth_mask(mask: torch.Tensor) -> torch.Tensor: + pool = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=3, count_include_pad=False) + pool = pool.to(mask.device) + smoothed = pool(mask.float().unsqueeze(0).unsqueeze(0)) + return (smoothed > 0.5).to(torch.uint8).squeeze(0).squeeze(0) + + +def _relevancy_to_mask(relevancy_2d: torch.Tensor, thresh: float) -> torch.Tensor: + """Convert a raw relevancy map [H, W] to a binary mask. + + Applies the same AvgPool(29) blending, normalisation, and thresholding + that the original LangSplatV2 evaluation uses. 
+ """ + pool = torch.nn.AvgPool2d(kernel_size=29, stride=1, padding=14, count_include_pad=False) + pool = pool.to(relevancy_2d.device) + blended = 0.5 * (pool(relevancy_2d.unsqueeze(0).unsqueeze(0)).squeeze() + relevancy_2d) + blended = blended - blended.min() + blended = blended / (blended.max() + 1e-9) + blended = blended * 2.0 - 1.0 + blended = torch.clip(blended, 0, 1) + return _smooth_mask((blended > thresh).to(torch.uint8)) + + +def _save_visualization( + output_path: pathlib.Path, + prompt: str, + feature_level: int, + relevancy: np.ndarray, + mask: np.ndarray, + source_img: np.ndarray | None, + image_index: int, +): + """Write a side-by-side panel image to *output_path*.""" + has_source = source_img is not None + n_cols = 3 if has_source else 2 + fig, axes = plt.subplots(1, n_cols, figsize=(8 * n_cols, 8)) + + col = 0 + + if has_source: + overlay = source_img.astype(np.float32) / 255.0 + mask_rgb = np.zeros_like(overlay) + mask_rgb[:, :, 0] = mask.astype(np.float32) + blended = np.clip(overlay * 0.6 + mask_rgb * 0.4, 0, 1) + axes[col].imshow(blended) + axes[col].set_title(f"Source (image {image_index}) + mask overlay", fontsize=12) + axes[col].axis("off") + col += 1 + + im = axes[col].imshow(relevancy, cmap="turbo", vmin=0, vmax=1) + axes[col].set_title(f"Relevancy (level {feature_level})", fontsize=12) + axes[col].axis("off") + plt.colorbar(im, ax=axes[col], fraction=0.046, pad=0.04) + col += 1 + + axes[col].imshow(mask, cmap="gray", vmin=0, vmax=1) + axes[col].set_title("Binary mask", fontsize=12) + axes[col].axis("off") + + fig.suptitle(f'"{prompt}"', fontsize=16, fontweight="bold") + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved visualisation to {output_path}") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Visualise relevancy for a text prompt on a trained LangSplatV2 model.", + 
formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--checkpoint", + type=pathlib.Path, + required=True, + help="Path to a trained LangSplatV2 checkpoint (.pt) for one feature level.", + ) + parser.add_argument( + "--reconstruction-path", + type=pathlib.Path, + required=True, + help="Path to the Gaussian splat reconstruction (.ply or .pt).", + ) + parser.add_argument( + "--prompt", + type=str, + required=True, + help="Text query to compute relevancy for (e.g. 'coffee cup').", + ) + parser.add_argument( + "--image-index", + type=int, + default=0, + help="Which dataset image (camera view) to render.", + ) + parser.add_argument( + "--output", + type=pathlib.Path, + default=None, + help="Output image path. Defaults to '_level_img.jpg'.", + ) + parser.add_argument( + "--mask-thresh", + type=float, + default=0.4, + help="Threshold for converting relevancy to a binary mask.", + ) + parser.add_argument( + "--eval-topk", + type=int, + default=None, + help="Override the checkpoint's topk value for codebook decoding.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Torch device.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable debug logging.", + ) + + args = parser.parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s") + + device = torch.device(args.device) + + model, sfm_scene, feature_level = _load_model( + args.checkpoint, + args.reconstruction_path, + device, + args.eval_topk, + ) + + w2c, K, img_h, img_w, image_path = _get_camera(sfm_scene, args.image_index, device) + logger.info(f"Rendering view {args.image_index}: {image_path} ({img_w}x{img_h})") + + source_img = None + if image_path and pathlib.Path(image_path).is_file(): + bgr = cv2.imread(str(image_path)) + if bgr is not None: + source_img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + + clip_feat = _render_clip(model, 
w2c, K, img_w, img_h) + + clip_relevancy = OpenCLIPRelevancy(device=args.device) + clip_relevancy.set_positives([args.prompt]) + + # get_relevancy_map expects [n_levels, H, W, 512] + relevancy = clip_relevancy.get_relevancy_map(clip_feat.unsqueeze(0)) # [1, 1, H, W] + relevancy_2d = relevancy[0, 0] # [H, W] + + mask = _relevancy_to_mask(relevancy_2d, args.mask_thresh) + + if args.output is None: + safe_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:40] + args.output = pathlib.Path(f"{safe_prompt}_level{feature_level}_img{args.image_index}.jpg") + + _save_visualization( + args.output, + args.prompt, + feature_level, + relevancy_2d.cpu().numpy(), + mask.cpu().numpy(), + source_img, + args.image_index, + ) + + print(f"\nWrote {args.output}") + + +if __name__ == "__main__": + main() diff --git a/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py b/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py new file mode 100644 index 0000000..2a01d84 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Train all 3 LangSplatV2 feature levels for a single scene. + +Wraps ``train_langsplatv2.py``, running it once per feature level and +collecting the final checkpoints into a results directory ready for +evaluation. + +Any arguments not recognised by this wrapper are forwarded verbatim to +``train_langsplatv2.py`` (e.g. ``--config.max-steps``, +``--preprocess.sam-model``, ``--dataset-type``). 
+ +Usage: + # Minimal -- trains levels 1, 2, 3 with default settings: + python scripts/train_all_levels.py \\ + --sfm-dataset-path /path/to/colmap/scene \\ + --reconstruction-path /path/to/scene.ply + + # With extra training options: + python scripts/train_all_levels.py \\ + --sfm-dataset-path /data/my_scene \\ + --reconstruction-path /data/my_scene.ply \\ + --name my_scene \\ + --results-dir my_results \\ + --config.max-steps 10000 \\ + --preprocess.sam-model sam1 +""" +import argparse +import pathlib +import shutil +import subprocess +import sys + +SCRIPT_DIR = pathlib.Path(__file__).resolve().parent +TRAIN_SCRIPT = SCRIPT_DIR / "train_langsplatv2.py" + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Train LangSplatV2 at all 3 feature levels and collect checkpoints.", + epilog=( + "Unrecognised arguments are forwarded to train_langsplatv2.py. " + "For example: --config.max-steps 10000 --preprocess.sam-model sam1" + ), + ) + parser.add_argument( + "--sfm-dataset-path", + type=pathlib.Path, + required=True, + help="Path to the SfM dataset (COLMAP, simple_directory, or E57).", + ) + parser.add_argument( + "--reconstruction-path", + type=pathlib.Path, + required=True, + help="Path to the pre-trained Gaussian splat reconstruction (.ply or .pt).", + ) + parser.add_argument( + "--results-dir", + type=pathlib.Path, + default=pathlib.Path("langsplatv2_results"), + help="Directory to collect final checkpoints into (default: langsplatv2_results).", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Scene name used for run directories and checkpoint files. 
" + "Defaults to the reconstruction file stem.", + ) + parser.add_argument( + "--log-path", + type=pathlib.Path, + default=pathlib.Path("langsplatv2_logs"), + help="Directory for per-run training logs and checkpoints (default: langsplatv2_logs).", + ) + parser.add_argument( + "--levels", + type=int, + nargs="+", + default=[1, 2, 3], + help="Feature levels to train (default: 1 2 3).", + ) + + args, extra = parser.parse_known_args() + name = args.name or args.reconstruction_path.stem + + failed_levels: list[int] = [] + + for level in args.levels: + run_name = f"{name}_level_{level}" + cmd = [ + sys.executable, + str(TRAIN_SCRIPT), + "--sfm-dataset-path", str(args.sfm_dataset_path), + "--reconstruction-path", str(args.reconstruction_path), + "--config.feature-level", str(level), + "--run-name", run_name, + "--log-path", str(args.log_path), + *extra, + ] + + print(f"\n{'=' * 60}") + print(f" Training level {level}/{ args.levels[-1]}: {run_name}") + print(f"{'=' * 60}\n") + + result = subprocess.run(cmd) + if result.returncode != 0: + print(f"\nERROR: Training failed for level {level} (exit code {result.returncode})") + failed_levels.append(level) + + # -- Collect checkpoints ------------------------------------------------ + args.results_dir.mkdir(parents=True, exist_ok=True) + collected = 0 + + for level in args.levels: + if level in failed_levels: + continue + run_name = f"{name}_level_{level}" + src = args.log_path / run_name / "final_checkpoint.pt" + dst = args.results_dir / f"{run_name}.pt" + if src.exists(): + shutil.copy2(src, dst) + print(f" Collected: {src} -> {dst}") + collected += 1 + else: + print(f" WARNING: {src} not found, skipping") + + # -- Summary ------------------------------------------------------------ + print(f"\n{'=' * 60}") + print(f" Done -- {collected}/{len(args.levels)} checkpoints in {args.results_dir}") + if failed_levels: + print(f" Failed levels: {failed_levels}") + print(f"{'=' * 60}") + + if failed_levels: + sys.exit(1) + + +if 
__name__ == "__main__": + main() From 51413cb5e2a4244827808a88224e2832daf05b3d Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 5 Mar 2026 10:38:35 +1300 Subject: [PATCH 05/13] Update open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jonathan Swartz --- .../langsplatv2/langsplatv2/scene_transforms/mask_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py index 257e378..ca4b9fb 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py @@ -214,13 +214,13 @@ def mask_nms( if keep_conf.sum() == 0: index = scores.topk(min(3, len(scores))).indices - keep_conf[index, 0] = True + keep_conf[index] = True if keep_inner_u.sum() == 0: index = scores.topk(min(3, len(scores))).indices - keep_inner_u[index, 0] = True + keep_inner_u[index] = True if keep_inner_l.sum() == 0: index = scores.topk(min(3, len(scores))).indices - keep_inner_l[index, 0] = True + keep_inner_l[index] = True keep *= keep_conf keep *= keep_inner_u keep *= keep_inner_l From bd6c57ba1f853d6542f2ff795014b408b9f57751 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 5 Mar 2026 10:39:28 +1300 Subject: [PATCH 06/13] Update open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jonathan Swartz --- .../langsplatv2/scene_transforms/clip_feature_encoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py 
b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py index 5ef25d9..d90de37 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py @@ -266,7 +266,7 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) h, w = img.shape[:2] - img_gpu = torch.from_numpy(img_rgb).to("cuda", dtype=torch.float32) + img_gpu = torch.from_numpy(img_rgb).to(self._device, dtype=torch.float32) mask_filename = f"masks_{image_meta.image_id:0{num_zeropad}}" if not input_cache.has_file(mask_filename): From efbe246a2e038f685da59de200db780e867c1447 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 5 Mar 2026 10:50:12 +1300 Subject: [PATCH 07/13] add prefiltering of zero-area masks add a comment for a small logic issue found in a PR review Signed-off-by: Jonathan Swartz --- .../langsplatv2/scene_transforms/mask_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py index ca4b9fb..fc1145b 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py @@ -191,6 +191,12 @@ def mask_nms( iou_max, _ = iou_matrix.max(dim=0) inner_iou_matrix_u = torch.triu(inner_iou_matrix, diagonal=1) inner_iou_max_u, _ = inner_iou_matrix_u.max(dim=0) + # NOTE: this includes the diagonal and the first super-diagonal, which + # doesn't match the intended “lower triangle excluding diagonal” + # logic used for the containment check. 
+ # this should use the lower triangle below the diagonal (-1) + # but this won't really affect results and we want to keep the original logic + # to match the original LangSplatV2 implementation inner_iou_matrix_l = torch.tril(inner_iou_matrix, diagonal=1) inner_iou_max_l, _ = inner_iou_matrix_l.max(dim=0) @@ -282,6 +288,18 @@ def masks_update( masks_new.append([]) continue + before_empty = len(masks_lvl) + masks_lvl = [m for m in masks_lvl if m["segmentation"].sum() > 0] + n_empty = before_empty - len(masks_lvl) + if n_empty > 0: + _masks_update_logger.info( + "[masks_update] dropped %d zero-area masks (%d remain)", + n_empty, len(masks_lvl), + ) + if len(masks_lvl) == 0: + masks_new.append([]) + continue + seg_pred = torch.from_numpy(np.stack([m["segmentation"] for m in masks_lvl], axis=0)) iou_pred = torch.from_numpy(np.stack([m["predicted_iou"] for m in masks_lvl], axis=0)) stability = torch.from_numpy(np.stack([m["stability_score"] for m in masks_lvl], axis=0)) From 5fdde041f7bb4d2913d11f29eebeb36dd742eaa5 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 5 Mar 2026 12:25:13 +1300 Subject: [PATCH 08/13] Update open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jonathan Swartz --- .../langsplatv2/langsplatv2/training/dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py index 938b66a..86eedf7 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py @@ -102,9 +102,10 @@ def warmup_cache(self) -> None: if self._cache_features: for idx in tqdm.tqdm(range(len(self)), desc="Warming up feature cache"): index = self._indices[idx] - if index not in self._features_cache: - 
self.get_feature_data(index) - _, seg_map, _ = self._features_cache.get(index, self.get_feature_data(index)) + feature_data = self._features_cache.get(index) + if feature_data is None: + feature_data = self.get_feature_data(index) + _, seg_map, _ = feature_data if not (seg_map >= 0).any(): n_empty += 1 From 90801e370111fdc60a21d70404fd0167859944d2 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 5 Mar 2026 12:30:58 +1300 Subject: [PATCH 09/13] wrap the SAM2 verbose logging suppression in a try/finally Signed-off-by: Jonathan Swartz --- .../lerf_ovs/batch_train_eval_langsplat.sh | 2 +- .../scene_transforms/multi_scale_sam_masks.py | 132 +++++++++--------- 2 files changed, 70 insertions(+), 64 deletions(-) diff --git a/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/batch_train_eval_langsplat.sh b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/batch_train_eval_langsplat.sh index 471073f..eb6af15 100644 --- a/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/batch_train_eval_langsplat.sh +++ b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/batch_train_eval_langsplat.sh @@ -9,7 +9,7 @@ for scene in ramen figurines teatime waldo_kitchen; do --run-name ${scene}_level_${level} \ --log-path langsplatv2_logs \ --config.max-steps 10000 \ - --preprocess.sam-model sam2 + --preprocess.sam-model sam1 done # Collect checkpoints (final_checkpoint.pt is saved at the run's top level) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam_masks.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam_masks.py index 5e45b01..43958cc 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam_masks.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam_masks.py @@ -307,69 +307,75 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: 
_sam2_model_logger.setLevel(logging.DEBUG) _root.setLevel(logging.WARNING) - pbar = tqdm.tqdm(input_scene.images, unit="imgs", desc="Generating SAM2 masks") - - for image_meta in pbar: - image_path = image_meta.image_path - img = cv2.imread(image_path) - assert img is not None, f"Failed to load image {image_path}" - - # Undistort the image if the camera has distortion parameters - img = image_meta.camera_metadata.undistort_image(img) - - # Generate multi-scale masks - masks_dict = self._generate_multi_scale_masks(img) - - # Convert masks to storable format - mask_data = {} - for scale_name, masks in masks_dict.items(): - if len(masks) > 0: - # Store segmentation masks and metadata - mask_data[f"{scale_name}_segmentations"] = np.stack( - [m["segmentation"].astype(np.uint8) for m in masks], axis=0 - ) - mask_data[f"{scale_name}_bboxes"] = np.array([m["bbox"] for m in masks], dtype=np.float32) - mask_data[f"{scale_name}_areas"] = np.array([m["area"] for m in masks], dtype=np.int32) - mask_data[f"{scale_name}_predicted_ious"] = np.array( - [m["predicted_iou"] for m in masks], dtype=np.float32 - ) - mask_data[f"{scale_name}_stability_scores"] = np.array( - [m["stability_score"] for m in masks], dtype=np.float32 - ) - else: - # Empty arrays for scales with no masks - mask_data[f"{scale_name}_segmentations"] = np.zeros( - (0, img.shape[0], img.shape[1]), dtype=np.uint8 - ) - mask_data[f"{scale_name}_bboxes"] = np.zeros((0, 4), dtype=np.float32) - mask_data[f"{scale_name}_areas"] = np.zeros(0, dtype=np.int32) - mask_data[f"{scale_name}_predicted_ious"] = np.zeros(0, dtype=np.float32) - mask_data[f"{scale_name}_stability_scores"] = np.zeros(0, dtype=np.float32) - - # Save to cache - cache_filename = f"masks_{image_meta.image_id:0{num_zeropad}}" - output_cache.write_file( - name=cache_filename, - data=mask_data, - data_type="pt", - metadata={ - "checkpoint": self._checkpoint, - "points_per_side": self._points_per_side, - "pred_iou_thresh": self._pred_iou_thresh, - 
"stability_score_thresh": self._stability_score_thresh, - "crop_n_layers": self._crop_n_layers, - "min_mask_region_area": self._min_mask_region_area, - "nms_iou_thr": self._nms_iou_thr, - "nms_score_thr": self._nms_score_thr, - "nms_inner_thr": self._nms_inner_thr, - }, - ) - - pbar.close() - _root.setLevel(_prev_root) - self._logger.setLevel(_prev_self) - _sam2_model_logger.setLevel(_prev_sam2_model) - self._logger.info(f"Generated masks for {input_scene.num_images} images.") + try: + pbar = tqdm.tqdm(input_scene.images, unit="imgs", desc="Generating SAM2 masks") + + for image_meta in pbar: + image_path = image_meta.image_path + img = cv2.imread(image_path) + assert img is not None, f"Failed to load image {image_path}" + + # Undistort the image if the camera has distortion parameters + img = image_meta.camera_metadata.undistort_image(img) + + # Generate multi-scale masks + masks_dict = self._generate_multi_scale_masks(img) + + # Convert masks to storable format + mask_data = {} + for scale_name, masks in masks_dict.items(): + if len(masks) > 0: + # Store segmentation masks and metadata + mask_data[f"{scale_name}_segmentations"] = np.stack( + [m["segmentation"].astype(np.uint8) for m in masks], axis=0 + ) + mask_data[f"{scale_name}_bboxes"] = np.array( + [m["bbox"] for m in masks], dtype=np.float32 + ) + mask_data[f"{scale_name}_areas"] = np.array( + [m["area"] for m in masks], dtype=np.int32 + ) + mask_data[f"{scale_name}_predicted_ious"] = np.array( + [m["predicted_iou"] for m in masks], dtype=np.float32 + ) + mask_data[f"{scale_name}_stability_scores"] = np.array( + [m["stability_score"] for m in masks], dtype=np.float32 + ) + else: + # Empty arrays for scales with no masks + mask_data[f"{scale_name}_segmentations"] = np.zeros( + (0, img.shape[0], img.shape[1]), dtype=np.uint8 + ) + mask_data[f"{scale_name}_bboxes"] = np.zeros((0, 4), dtype=np.float32) + mask_data[f"{scale_name}_areas"] = np.zeros(0, dtype=np.int32) + mask_data[f"{scale_name}_predicted_ious"] = 
np.zeros(0, dtype=np.float32) + mask_data[f"{scale_name}_stability_scores"] = np.zeros(0, dtype=np.float32) + + # Save to cache + cache_filename = f"masks_{image_meta.image_id:0{num_zeropad}}" + output_cache.write_file( + name=cache_filename, + data=mask_data, + data_type="pt", + metadata={ + "checkpoint": self._checkpoint, + "points_per_side": self._points_per_side, + "pred_iou_thresh": self._pred_iou_thresh, + "stability_score_thresh": self._stability_score_thresh, + "crop_n_layers": self._crop_n_layers, + "min_mask_region_area": self._min_mask_region_area, + "nms_iou_thr": self._nms_iou_thr, + "nms_score_thr": self._nms_score_thr, + "nms_inner_thr": self._nms_inner_thr, + }, + ) + + pbar.close() + self._logger.info(f"Generated masks for {input_scene.num_images} images.") + finally: + _root.setLevel(_prev_root) + self._logger.setLevel(_prev_self) + _sam2_model_logger.setLevel(_prev_sam2_model) else: self._logger.info("Loading masks from cache.") From a5dd9606d739d0a982c20815c839ae640073b05c Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 5 Mar 2026 12:46:06 +1300 Subject: [PATCH 10/13] notes Signed-off-by: Jonathan Swartz --- .../langsplatv2/evaluation/lerf_ovs/README.md | 4 ++-- open_vocabulary_segmentation/langsplatv2/pyproject.toml | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md index 37a9015..bb42623 100644 --- a/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md +++ b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md @@ -54,7 +54,7 @@ frgs reconstruct \ bash batch_train_eval_langsplat.sh ``` -For each scene, this trains three models (one per SAM scale level: 1=small, 2=medium, 3=large) for 10k steps using SAM2. 
The final checkpoints are collected into: +For each scene, this trains three models (one per SAM scale level: 1=small, 2=medium, 3=large) for 10k steps using SAM1. The final checkpoints are collected into: ``` langsplatv2_results/_level_1.pt @@ -72,7 +72,7 @@ python ../../scripts/train_langsplatv2.py \ --run-name teatime_level_1 \ --log-path langsplatv2_logs \ --config.max-steps 10000 \ - --preprocess.sam-model sam2 + --preprocess.sam-model sam1 ``` Then copy the final checkpoint: diff --git a/open_vocabulary_segmentation/langsplatv2/pyproject.toml b/open_vocabulary_segmentation/langsplatv2/pyproject.toml index a8d420b..ee81bc6 100644 --- a/open_vocabulary_segmentation/langsplatv2/pyproject.toml +++ b/open_vocabulary_segmentation/langsplatv2/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "numpy", "opencv-python", "open-clip-torch", + "segment-anything", "sam2", "scikit-learn", "torch", From 6e5a43aba1d20abee1d9c03ea3be0a5f643f538f Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 5 Mar 2026 12:47:56 +1300 Subject: [PATCH 11/13] Update open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jonathan Swartz --- .../langsplatv2/langsplatv2/evaluation/openclip_relevancy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py index 21ba846..5ae9b25 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py @@ -58,11 +58,14 @@ def __init__( self.clip_model_type = clip_model_type self.clip_model_pretrained = clip_model_pretrained + # Select precision based on device: use fp16 only on CUDA, fp32 otherwise + precision = "fp16" if 
self.device.type == "cuda" else "fp32" + # Load OpenCLIP model model, _, _ = open_clip.create_model_and_transforms( clip_model_type, pretrained=clip_model_pretrained, - precision="fp16", + precision=precision, ) model.eval() self.model = model.to(self.device) From 91f0afbece0a9c60e63e06cd4b3cf442869467e0 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 5 Mar 2026 12:50:52 +1300 Subject: [PATCH 12/13] notes fixes Signed-off-by: Jonathan Swartz --- .../langsplatv2/langsplatv2/training/langsplatv2_writer.py | 2 +- .../langsplatv2/scripts/train_all_levels.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py index ce47739..25c614d 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py @@ -325,7 +325,7 @@ def save_final_checkpoint(self, checkpoint: dict[str, Any]) -> pathlib.Path | No Returns: Path to the saved file, or *None* if saving is disabled. 
""" - if self._save_path is None: + if not self._config.save_checkpoints or self._save_path is None: return None final_path = self._save_path / "final_checkpoint.pt" torch.save(checkpoint, final_path) diff --git a/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py b/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py index 2a01d84..dcc11c4 100644 --- a/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py +++ b/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py @@ -89,7 +89,8 @@ def main() -> None: failed_levels: list[int] = [] - for level in args.levels: + total_levels = len(args.levels) + for idx, level in enumerate(args.levels, 1): run_name = f"{name}_level_{level}" cmd = [ sys.executable, @@ -103,7 +104,7 @@ def main() -> None: ] print(f"\n{'=' * 60}") - print(f" Training level {level}/{ args.levels[-1]}: {run_name}") + print(f" Training level {level} ({idx}/{total_levels}): {run_name}") print(f"{'=' * 60}\n") result = subprocess.run(cmd) From e7bfeb7b18970fd905650d61216c24e7540ec236 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Thu, 5 Mar 2026 13:22:24 +1300 Subject: [PATCH 13/13] setting debug prints correctly Signed-off-by: Jonathan Swartz --- .../scene_transforms/clip_feature_encoding.py | 2 +- .../scene_transforms/multi_scale_sam1_masks.py | 8 ++++---- .../scene_transforms/multi_scale_sam_masks.py | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py index d90de37..3c10b51 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py @@ -310,7 +310,7 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: # Report 
per-scale mask coverage for the first few images if image_meta.image_id < 3: coverage = {sn: int((seg_maps[i] >= 0).sum()) for i, sn in enumerate(scale_names)} - self._logger.info( + self._logger.debug( f"Image {image_meta.image_id}: {total_masks} masks, " f"lengths={lengths}, pixel coverage={coverage}" ) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam1_masks.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam1_masks.py index d86603c..9b6a664 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam1_masks.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam1_masks.py @@ -172,7 +172,7 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: if not hasattr(self, '_diag_count'): self._diag_count = 0 self._diag_count += 1 - self._logger.info( + self._logger.debug( "[diag] after predict: default=%d, s=%d, m=%d, l=%d", len(all_default), len(all_s), len(all_m), len(all_l), ) @@ -184,7 +184,7 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: all_l = cross_crop_nms(all_l, iou_threshold=self._box_nms_thresh) if log_diag: - self._logger.info( + self._logger.debug( "[diag] after cross-crop NMS: default=%d, s=%d, m=%d, l=%d", len(all_default), len(all_s), len(all_m), len(all_l), ) @@ -197,7 +197,7 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: all_l = postprocess_small_regions(all_l, self._min_mask_region_area, nms_thresh) if log_diag: - self._logger.info( + self._logger.debug( "[diag] after postprocess_small_regions: default=%d, s=%d, m=%d, l=%d", len(all_default), len(all_s), len(all_m), len(all_l), ) @@ -213,7 +213,7 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: ) if log_diag: - self._logger.info( + self._logger.debug( "[diag] after masks_update: default=%d, s=%d, m=%d, l=%d", len(masks_default), len(masks_s), len(masks_m), 
len(masks_l), ) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam_masks.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam_masks.py index 43958cc..7bbef40 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam_masks.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/multi_scale_sam_masks.py @@ -157,7 +157,7 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: if not hasattr(self, '_diag_count'): self._diag_count = 0 self._diag_count += 1 - self._logger.info( + self._logger.debug( "[diag] after predict: default=%d, s=%d, m=%d, l=%d", len(all_default), len(all_s), len(all_m), len(all_l), ) @@ -170,7 +170,7 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: all_l = cross_crop_nms(all_l, iou_threshold=self._box_nms_thresh) if log_diag: - self._logger.info( + self._logger.debug( "[diag] after cross-crop NMS: default=%d, s=%d, m=%d, l=%d", len(all_default), len(all_s), len(all_m), len(all_l), ) @@ -185,7 +185,7 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: all_l = postprocess_small_regions(all_l, self._min_mask_region_area, nms_thresh) if log_diag: - self._logger.info( + self._logger.debug( "[diag] after postprocess_small_regions: default=%d, s=%d, m=%d, l=%d", len(all_default), len(all_s), len(all_m), len(all_l), ) @@ -202,7 +202,7 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: ) if log_diag: - self._logger.info( + self._logger.debug( "[diag] after masks_update: default=%d, s=%d, m=%d, l=%d", len(masks_default), len(masks_s), len(masks_m), len(masks_l), ) @@ -303,8 +303,8 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: _prev_self = self._logger.level _sam2_model_logger = logging.getLogger("fvdb_reality_capture.foundation_models.sam2.SAM2Model") _prev_sam2_model = _sam2_model_logger.level - 
self._logger.setLevel(logging.DEBUG) - _sam2_model_logger.setLevel(logging.DEBUG) + self._logger.setLevel(logging.INFO) + _sam2_model_logger.setLevel(logging.INFO) _root.setLevel(logging.WARNING) try: