diff --git a/open_vocabulary_segmentation/langsplatv2/.gitignore b/open_vocabulary_segmentation/langsplatv2/.gitignore index 4bf5408..417a8a9 100644 --- a/open_vocabulary_segmentation/langsplatv2/.gitignore +++ b/open_vocabulary_segmentation/langsplatv2/.gitignore @@ -1,2 +1,4 @@ *.nsys-rep langsplatv2_logs/ +frgs_logs/ +*_results*/ diff --git a/open_vocabulary_segmentation/langsplatv2/README.md b/open_vocabulary_segmentation/langsplatv2/README.md index 4975336..c27147d 100644 --- a/open_vocabulary_segmentation/langsplatv2/README.md +++ b/open_vocabulary_segmentation/langsplatv2/README.md @@ -1,10 +1,10 @@ # LangSplatV2 (fVDB) -LangSplatV2-style open-vocabulary 3D segmentation using [fVDB](https://github.com/openvdb/fvdb-core) and pre-trained Gaussian splat reconstructions. This implementation trains per-Gaussian sparse coefficient fields and shared CLIP-aligned codebooks on an existing reconstruction; it does not train the underlying Gaussians or colors. +LangSplatV2-style open-vocabulary 3D segmentation using [fVDB](https://github.com/openvdb/fvdb-core) and pre-trained Gaussian splat reconstructions. This implementation trains per-Gaussian sparse coefficient fields and shared CLIP-aligned codebooks on an existing reconstruction. ## What this implements -- **Preprocessing**: Multi-scale SAM2 masks and OpenCLIP feature encoding for each image (cached on disk). +- **Preprocessing**: Multi-scale SAM masks (SAM1 or SAM2, configurable) and OpenCLIP feature encoding for each image (cached on disk). - **Training**: Residual VQ codebooks and per-splat sparse logits so that rendered language features match the CLIP embeddings from SAM masks. One feature level (scale) per run; train multiple levels separately and combine at inference. - **Compatibility**: Same feature pipeline and training setup (loss, LR, layer schedule) as the original LangSplatV2; uses fVDB for the 3D representation and rendering. @@ -22,7 +22,7 @@ conda activate fvdb pip install -e . 
``` -Dependencies (see `pyproject.toml`) include `torch`, `open-clip-torch`, `fvdb-reality-capture`, `tyro`, and optional TensorBoard for logging. +Dependencies (see `pyproject.toml`) include `fvdb-core`, `fvdb-reality-capture`, `torch`, `open-clip-torch`, `sam2`, `tyro`, and `matplotlib`. ## How to run @@ -31,7 +31,7 @@ Training loads the SfM scene, applies preprocessing (SAM2 + CLIP) with caching, **Minimal (COLMAP scene + PLY reconstruction):** ```bash -python train_langsplatv2.py \ +python scripts/train_langsplatv2.py \ --sfm-dataset-path /path/to/colmap/scene \ --reconstruction-path /path/to/point_cloud.ply ``` @@ -39,7 +39,7 @@ python train_langsplatv2.py \ **With explicit feature level and log directory:** ```bash -python train_langsplatv2.py \ +python scripts/train_langsplatv2.py \ --sfm-dataset-path /path/to/colmap/scene \ --reconstruction-path /path/to/point_cloud.ply \ --config.feature-level 1 \ @@ -50,7 +50,7 @@ python train_langsplatv2.py \ ```bash for level in 1 2 3; do - python train_langsplatv2.py \ + python scripts/train_langsplatv2.py \ --sfm-dataset-path /path/to/scene \ --reconstruction-path /path/to/gaussians.ply \ --config.feature-level $level \ @@ -63,8 +63,10 @@ done - `--config.feature-level` — 0=default, 1=small, 2=medium, 3=large (default: 1). - `--config.max-steps` — Training steps (default from max_epochs if not set). -- `--preprocess.image-downsample-factor` — Downsample images before SAM2/CLIP (e.g. 2 for speed). +- `--preprocess.image-downsample-factor` — Downsample images before SAM/CLIP (e.g. 2 for speed). +- `--preprocess.sam-model` — `sam1` or `sam2` (default: `sam2`). - `--preprocess.sam2.checkpoint` — SAM2 size: `large`, `small`, `tiny`, `base_plus`. +- `--preprocess.sam1.checkpoint` — SAM1 variant: `vit_h`, `vit_l`, `vit_b`. - `--log-path` — Directory for run subdirs (checkpoints, metrics). Use `None` to disable saving. - `--io.use-tensorboard` — Log scalars (and optionally images) to TensorBoard. 
- `--use-every-n-as-val` — Hold out every N-th image for validation (e.g. 5); -1 = no validation. @@ -74,7 +76,8 @@ done With `--log-path` set (e.g. `langsplatv2_logs`), each run writes: - `log_path/run_/` (or `log_path//` if `--run-name` is set) - - `checkpoints//langsplatv2_ckpt.pt` — Model state and config (when `io.save_checkpoints` is True). + - `final_checkpoint.pt` — Final model checkpoint saved at the run's top level for easy access. + - `checkpoints//langsplatv2_ckpt.pt` — Per-step model state and config (when `io.save_checkpoints` is True). - `metrics_log.csv` — Step, loss, and optional validation metrics. - `tensorboard/` — If `io.use_tensorboard` is True. - `images/` — If `io.save_images` is True (e.g. feature visualizations at save steps). @@ -83,7 +86,7 @@ Preprocessing caches (SAM2 masks, CLIP features) are stored under the scene’s ## Preprocessing pipeline and cache format -The pipeline (see `LangSplatV2PreprocessConfig` in `config.py`) runs in order: optional scene normalization, point filtering, image downsampling, filter images by visible points, **ComputeMultiScaleSAM2Masks**, **ComputeCLIPFeatures**, and optional cropping. +The pipeline (see `LangSplatV2PreprocessConfig` in `config.py`) runs in order: optional scene normalization, point filtering, image downsampling, filter images by visible points, **ComputeMultiScaleSAM1Masks** or **ComputeMultiScaleSAM2Masks** (controlled by `--preprocess.sam-model`), **ComputeCLIPFeatures**, and optional cropping. ### SAM2 masks (per image) @@ -104,11 +107,10 @@ Training uses a single `feature_level` (0–3) to choose which scale’s seg map ## Training details and comparison with original LangSplatV2 - **Feature generation**: Same as original — crop mask region → pad to square → resize to 224 → OpenCLIP encode → L2-normalize. Scale order and seg-map indexing (default → s → m → l, cumulative) match. 
-- **Optimization**: Same language-feature LR (0.0025), layer schedule (every 10k steps), and cosine loss over valid pixels with gradient scaling via mask fraction. The scalar `train/loss` is the (mask-fraction-scaled) total loss used for backprop. For a smoother, more interpretable curve when mask coverage varies across images, use `train/cosine_loss_valid`, which is the mean cosine loss over valid pixels only (no mask-fraction scaling), we use this for logging. -- **Data sampling**: One random permutation of all training views per “epoch” (InfiniteSampler with shuffle), one view per step when `batch_size=1`, matching the original’s viewpoint-stack behavior. ## References -- [LangSplatV2: Vision-Language Gaussian Splatting](https://arxiv.org/abs/2312.16084) +- [LangSplat: 3D Language Gaussian Splatting](https://arxiv.org/abs/2312.16084) +- [LangSplatV2: High-dimensional 3D Language Gaussian Splatting with 450+ FPS](https://arxiv.org/abs/2507.07136) - [Segment Anything 2 (SAM2)](https://github.com/facebookresearch/segment-anything-2) - [OpenCLIP](https://github.com/mlfoundations/open_clip) diff --git a/open_vocabulary_segmentation/langsplatv2/environment.yml b/open_vocabulary_segmentation/langsplatv2/environment.yml new file mode 100644 index 0000000..d7a20df --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/environment.yml @@ -0,0 +1,19 @@ +name: fvdb_langsplatv2 +channels: + - conda-forge +dependencies: + - python >=3.10 + - numpy + - pytorch + - torchvision + - opencv + - tqdm + - scikit-learn + - matplotlib + - gdown + - open-clip-torch + - tyro + - sam2 + - segment-anything + - fvdb-core + - fvdb-reality-capture diff --git a/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md new file mode 100644 index 0000000..bb42623 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/README.md @@ -0,0 +1,131 @@ +# LERF-OVS Evaluation + 
+Open-vocabulary segmentation evaluation on the [LERF-OVS](https://www.lerf.io/) dataset, comparing against ground-truth labelme annotations. Computes segmentation mIoU and localization accuracy across four scenes: `ramen`, `figurines`, `teatime`, `waldo_kitchen`. + +All commands below should be run from this directory (`evaluation/lerf_ovs/`). + +## Prerequisites + +- The `fvdb` conda environment with the `fvdb-langsplatv2` package installed (see the [parent README](../../README.md)). +- `fvdb-reality-capture` installed. + +## Step 1: Download the LERF-OVS dataset + +```bash +python download_data.py --dataset-root data +``` + +This downloads and extracts the LERF-OVS data from Google Drive into `data/lerf_ovs/`. The resulting layout is: + +``` +data/lerf_ovs/ + label/<scene>/frame_XXXXX.json, frame_XXXXX.jpg + <scene>/images/, sparse/ +``` + +## Step 2: Reconstruct Gaussian splats + +```bash +bash batch_reconstruct_eval_scenes.sh +``` + +This runs `frgs reconstruct` on each scene with default settings. Outputs are written to: + +``` +reconstructions/<scene>.ply +``` + +To reconstruct a single scene manually: + +```bash +frgs reconstruct \ + --run-name teatime \ + --tx.image-downsample-factor 1 \ + data/lerf_ovs/teatime/ \ + -uv 10 \ + -o reconstructions/teatime.ply \ + --cfg.batch-size 1 \ + --cfg.pose_opt_start_epoch 20 +``` + +## Step 3: Train LangSplatV2 features + +```bash +bash batch_train_eval_langsplat.sh +``` + +For each scene, this trains three models (one per SAM scale level: 1=small, 2=medium, 3=large) for 10k steps using SAM1. 
The final checkpoints are collected into: + +``` +langsplatv2_results/<scene>_level_1.pt +langsplatv2_results/<scene>_level_2.pt +langsplatv2_results/<scene>_level_3.pt +``` + +To train a single scene and level manually: + +```bash +python ../../scripts/train_langsplatv2.py \ + --sfm-dataset-path data/lerf_ovs/teatime \ + --reconstruction-path reconstructions/teatime.ply \ + --config.feature-level 1 \ + --run-name teatime_level_1 \ + --log-path langsplatv2_logs \ + --config.max-steps 10000 \ + --preprocess.sam-model sam1 +``` + +Then copy the final checkpoint: + +```bash +cp langsplatv2_logs/teatime_level_1/final_checkpoint.pt \ + langsplatv2_results/teatime_level_1.pt +``` + +## Step 4: Evaluate + +**All scenes (auto-discovered from checkpoints):** + +```bash +python eval_lerf.py \ + --lerf-root data/lerf_ovs \ + --results-root langsplatv2_results \ + --reconstructions-root reconstructions +``` + +**Single scene:** + +```bash +python eval_lerf.py \ + --lerf-root data/lerf_ovs \ + --results-root langsplatv2_results \ + --reconstructions-root reconstructions \ + --scenes teatime +``` + +The evaluation: +1. Loads all three level checkpoints per scene +2. Renders CLIP features from each level for each annotated frame +3. Computes OpenCLIP relevancy maps for each ground-truth text prompt +4. Selects the best level per prompt (highest max relevancy score) +5. Reports **segmentation mIoU** (thresholded relevancy vs GT masks) and **localization accuracy** (relevancy peak inside GT bounding box) + +### Evaluation flags + +- `--mask-thresh` — Relevancy threshold for binary segmentation mask (default: 0.4). +- `--eval-topk` — Number of codebook entries to combine at eval (default: 4). +- `--output-dir` — Where to write results and visualizations (default: `lerf_eval_results`). +- `--no-visualizations` — Skip saving per-frame visualization images. +- `--verbose` — Enable debug logging. 
#! /bin/bash
# Reconstruct a Gaussian-splat model for each of the four LERF-OVS evaluation scenes.
# All paths are relative: the script reads data/lerf_ovs/<scene>/ and writes
# reconstructions/<scene>.ply, so run it from the directory that holds data/.
# NOTE(review): presumably set so Python output from frgs is not buffered — confirm.
export PYTHONUNBUFFERED=1

for scene in ramen figurines teatime waldo_kitchen; do
    # One `frgs reconstruct` run per scene; the run is named after the scene.
    frgs reconstruct \
        --run-name ${scene} \
        --tx.image-downsample-factor 1 \
        data/lerf_ovs/${scene}/ \
        -uv 10 \
        -o reconstructions/${scene}.ply \
        --cfg.batch-size 1 \
        --cfg.pose_opt_start_epoch 20
done
@dataclass
class DownloadLERFData:
    """Download the LERF-OVS dataset for open-vocabulary segmentation evaluation."""

    # Exposed as a command-line option: the script's __main__ guard passes this class to tyro.cli.
    dataset_root: Path = Path("data")
    """Root directory to store downloaded datasets."""

    def main(self) -> None:
        """Register ``dataset_root`` with the dataset helpers, then start the download."""
        # The root is registered before downloading.
        # NOTE(review): presumably download_lerf_data() reads the root set here — confirm in
        # langsplatv2.evaluation.datasets.
        set_dataset_root(self.dataset_root)
        download_lerf_data()
b/open_vocabulary_segmentation/langsplatv2/evaluation/lerf_ovs/eval_lerf.py @@ -0,0 +1,998 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +""" +Evaluation script for LERF dataset open-vocabulary segmentation. + +This script evaluates trained LangSplatV2 models on the LERF-OVS dataset by: +1. Loading 3 trained checkpoints (one per SAM scale level) +2. Loading LERF ground-truth annotations (labelme-format JSONs) +3. Rendering CLIP features from each level and computing relevancy maps +4. Computing segmentation mIoU and localization accuracy + +The evaluation matches the original LangSplatV2 eval_lerf.py methodology: +- Relevancy is computed using OpenCLIP (ViT-B-16) with pos/neg softmax scoring +- Segmentation uses AvgPool smoothing + normalization + thresholding +- The best level is chosen per-prompt based on max relevancy score +- Localization checks if the peak of the relevancy map is inside any GT bbox + +LERF-OVS dataset structure: + lerf_ovs/ + label//frame_XXXXX.json, frame_XXXXX.jpg + / + images/ + sparse/ + output//point_cloud/iteration_30000/point_cloud.ply + +Usage: + # Evaluate a single scene + python eval_lerf.py \\ + --lerf-root /path/to/lerf_ovs \\ + --results-root ./langsplatv2_results \\ + --gs-model-path /path/to/point_cloud.ply \\ + --scenes teatime + + # Evaluate all scenes + python eval_lerf.py \\ + --lerf-root /path/to/lerf_ovs \\ + --results-root ./langsplatv2_results +""" +import argparse +import glob +import json +import logging +import os +import pathlib +import sys +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, Tuple + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +import tqdm +from fvdb_reality_capture.sfm_scene import SfmScene +from langsplatv2.evaluation.openclip_relevancy import OpenCLIPRelevancy +from langsplatv2.model import LangSplatV2Model +from langsplatv2.training.trainer import 
def polygon_to_mask(img_shape: Tuple[int, int], points_list: list) -> np.ndarray:
    """Rasterize one polygon into a filled binary mask.

    Args:
        img_shape: ``(height, width)`` of the mask to create.
        points_list: Polygon vertices as ``[x, y]`` pairs.

    Returns:
        ``uint8`` array of shape ``img_shape`` holding 1 inside the polygon
        and 0 everywhere else.
    """
    vertices = np.asarray(points_list, dtype=np.int32)
    canvas = np.zeros(img_shape, dtype=np.uint8)
    # fillPoly paints directly into ``canvas``.
    cv2.fillPoly(canvas, [vertices], 1)
    return canvas


def stack_mask(mask_base: np.ndarray, mask_add: np.ndarray) -> np.ndarray:
    """Return a copy of ``mask_base`` with every non-zero pixel of ``mask_add`` set to 1.

    Neither input is modified.
    """
    merged = mask_base.copy()
    merged[mask_add.astype(bool)] = 1
    return merged
def load_lerf_ground_truth(
    json_folder: pathlib.Path,
    logger: logging.Logger,
) -> Tuple[Dict, Tuple[int, int], list]:
    """Read every ``frame_*.json`` labelme annotation in a scene's label folder.

    Objects that share a category within one frame are merged: their masks are
    OR-ed together and their boxes are stacked into an ``(N, 4)`` array.

    Args:
        json_folder: Label folder of one scene (e.g. ``lerf_ovs/label/teatime``).
        logger: Logger used for summary and per-frame debug output.

    Returns:
        Tuple of:
        - gt_ann: ``{str(frame_idx): {label: {"bboxes": ndarray, "mask": ndarray}}}``,
          where ``frame_idx`` is the number in the annotation's ``info.name`` minus one.
        - image_shape: ``(height, width)`` taken from the last annotation read.
        - img_paths: sorted list of ``frame_*.jpg`` paths in the folder.

    Raises:
        FileNotFoundError: If the folder contains no ``frame_*.json`` file.
    """
    folder = str(json_folder)
    gt_json_paths = sorted(glob.glob(os.path.join(folder, "frame_*.json")))
    img_paths = sorted(glob.glob(os.path.join(folder, "frame_*.jpg")))

    if len(gt_json_paths) == 0:
        raise FileNotFoundError(f"No frame_*.json files found in {json_folder}")

    logger.info(f"Found {len(gt_json_paths)} GT annotations, {len(img_paths)} GT images in {json_folder}")

    gt_ann = {}
    h, w = 0, 0
    for js_path in gt_json_paths:
        with open(js_path, "r") as f:
            gt_data = json.load(f)

        info = gt_data["info"]
        h, w = info["height"], info["width"]
        # File names are 1-based (frame_00001); stored indices are 0-based.
        frame_number = info["name"].split("_")[-1].split(".jpg")[0]
        idx = int(frame_number) - 1

        img_ann: Dict[str, dict] = defaultdict(dict)
        for prompt_data in gt_data["objects"]:
            label = prompt_data["category"]
            box = np.asarray(prompt_data["bbox"]).reshape(-1)  # x1y1x2y2
            mask = polygon_to_mask((h, w), prompt_data["segmentation"])

            entry = img_ann[label]
            previous_mask = entry.get("mask", None)
            if previous_mask is None:
                entry["bboxes"] = box
            else:
                # Same label seen before in this frame: union masks, stack boxes.
                mask = stack_mask(previous_mask, mask)
                entry["bboxes"] = np.concatenate(
                    [entry["bboxes"].reshape(-1, 4), box.reshape(-1, 4)], axis=0
                )
            entry["mask"] = mask

        gt_ann[str(idx)] = dict(img_ann)

    logger.info(f"Parsed GT: {len(gt_ann)} frames, image size {w}x{h}")
    for idx_str, ann in gt_ann.items():
        labels = list(ann.keys())
        logger.debug(f" Frame {idx_str}: {len(labels)} labels: {labels}")

    return gt_ann, (h, w), img_paths
def smooth_mask(mask_pred: torch.Tensor) -> torch.Tensor:
    """Smooth a binary mask with a 7x7 average filter and re-threshold at 0.5.

    Args:
        mask_pred: Binary mask tensor ``[H, W]``.

    Returns:
        Smoothed binary mask ``[H, W]`` of type uint8. A pixel stays set only
        when more than half of its (border-clipped) 7x7 neighbourhood is set.
    """
    # count_include_pad=False averages only over pixels that lie inside the image.
    avg_pool = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=3, count_include_pad=False)
    avg_filtered = avg_pool(mask_pred.float().unsqueeze(0).unsqueeze(0))
    return (avg_filtered > 0.5).to(torch.uint8).squeeze(0).squeeze(0)


def segmentation_process(
    relevancy_map: torch.Tensor,
    thresh: float,
    img_ann: dict,
    prompts: list,
    device: torch.device,
) -> Tuple[list, list, dict]:
    """Compute segmentation IoU for each prompt across all levels.

    For every (level, prompt) pair:
    1. Blend the relevancy 50/50 with its AvgPool2d(29) smoothed version
    2. Min-max normalize, rescale to [-1, 1], then clip to [0, 1]
    3. Threshold and smooth with ``smooth_mask``
    4. Compute IoU against the GT mask

    The reported level per prompt is the one whose blended relevancy has the
    highest maximum. When prediction and GT are both empty the IoU is defined
    as 0.0 (instead of 0/0 = NaN) so one degenerate pair cannot turn an
    averaged mIoU into NaN. ``relevancy_map`` is not modified.

    Args:
        relevancy_map: ``[n_levels, n_prompts, H, W]`` relevancy values.
        thresh: Mask threshold applied to the normalized relevancy.
        img_ann: GT annotations for this frame ``{label: {"mask": ndarray, ...}}``.
        prompts: List of prompt labels (same order as relevancy_map's prompt dim).
        device: Torch device for the GT mask and the per-level IoU tensor.

    Returns:
        Tuple of:
        - chosen_iou_list: per-prompt IoU at the chosen level
        - chosen_lvl_list: chosen level index per prompt
        - iou_all: ``{prompt: [iou at each level]}``
    """
    n_head, n_prompt, _, _ = relevancy_map.shape
    # The pooling layer has no parameters, so one instance serves every iteration.
    avg_pool = torch.nn.AvgPool2d(kernel_size=29, stride=1, padding=14, count_include_pad=False)

    chosen_iou_list = []
    chosen_lvl_list = []
    iou_all = {}

    for k in range(n_prompt):
        prompt = prompts[k]
        mask_gt = torch.from_numpy(img_ann[prompt]["mask"].astype(np.uint8)).to(device)

        iou_lvl = torch.zeros(n_head, device=device)
        score_lvl = torch.zeros(n_head, device=relevancy_map.device)
        for i in range(n_head):
            raw = relevancy_map[i, k]
            smoothed = avg_pool(raw.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
            blended = 0.5 * (smoothed + raw)
            # Level selection uses the peak of the blended (pre-normalization) map.
            score_lvl[i] = blended.max()

            # Normalize to [-1, 1] then clip to [0, 1]
            output = blended - torch.min(blended)
            output = output / (torch.max(output) + 1e-9)
            output = torch.clip(output * 2.0 - 1.0, 0, 1)

            mask_pred = smooth_mask((output > thresh).to(torch.uint8))

            intersection = torch.sum(torch.logical_and(mask_gt, mask_pred))
            union = torch.sum(torch.logical_or(mask_gt, mask_pred))
            # union == 0 implies intersection == 0, so clamping yields 0/1 = 0.0
            # and leaves every non-degenerate IoU unchanged.
            iou_lvl[i] = intersection.float() / union.clamp(min=1).float()

        iou_all[prompt] = iou_lvl.tolist()

        chosen_lvl = torch.argmax(score_lvl)
        chosen_iou_list.append(iou_lvl[chosen_lvl].cpu().item())
        chosen_lvl_list.append(chosen_lvl.cpu().item())

    return chosen_iou_list, chosen_lvl_list, iou_all
def localization_process(
    relevancy_map: torch.Tensor,
    img_ann: dict,
    device: torch.device,
) -> int:
    """Count prompts whose smoothed relevancy peak lands inside a GT box.

    Per prompt: smooth every level with AvgPool2d(29), pick the level with the
    highest smoothed maximum, collect all pixels attaining that maximum, and
    score one hit if any of them lies inside any of the prompt's boxes
    (box corners may be given in either order). Prompt ``k`` is matched to the
    ``k``-th key of ``img_ann``.

    Args:
        relevancy_map: ``[n_levels, n_prompts, H, W]`` relevancy values.
        img_ann: GT annotations ``{label: {"bboxes": ndarray, ...}}``.
        device: Torch device the pooling layer is moved to.

    Returns:
        Number of prompts localized correctly (at most one hit per prompt).
    """
    n_head, n_prompt, _, _ = relevancy_map.shape
    labels = list(img_ann.keys())
    smoother = torch.nn.AvgPool2d(kernel_size=29, stride=1, padding=14, count_include_pad=False).to(device)

    n_hits = 0
    for k in range(n_prompt):
        smoothed = smoother(relevancy_map[:, k].unsqueeze(1)).squeeze(1)  # [n_head, H, W]

        level_scores = torch.zeros(n_head)
        level_peaks = []
        for i in range(n_head):
            top = smoothed[i].max()
            level_scores[i] = top
            level_peaks.append(torch.nonzero((smoothed[i] == top).to(torch.uint8)))

        # Peak coordinates are (row, col) = (y, x) pairs of the best level.
        peaks = level_peaks[torch.argmax(level_scores)].tolist()
        boxes = img_ann[labels[k]]["bboxes"].reshape(-1, 4)

        localized = any(
            min(x1, x2) <= col <= max(x1, x2) and min(y1, y2) <= row <= max(y1, y2)
            for x1, y1, x2, y2 in boxes
            for row, col in peaks
        )
        if localized:
            n_hits += 1

    return n_hits
def load_langsplatv2_model(
    checkpoint_path: pathlib.Path,
    gs_model_path: pathlib.Path,
    device: torch.device,
    logger: logging.Logger,
    eval_topk: int | None = None,
) -> Tuple[LangSplatV2Model, SfmScene]:
    """Restore a trained LangSplatV2 model and its scene from a checkpoint.

    The Gaussian splat is loaded first, then an eval-only trainer is rebuilt
    from the checkpoint; the model and scene are taken from that trainer.

    Args:
        checkpoint_path: Path to the ``.pt`` checkpoint file.
        gs_model_path: Path to the Gaussian splat file.
        device: Device to load the model on.
        logger: Logger instance.
        eval_topk: When given and different from the model's ``topk``, the
            model's ``topk`` is overwritten with this value.

    Returns:
        Tuple of (LangSplatV2Model, SfmScene) from the checkpoint.
    """
    splats, _ = load_splats_from_file(gs_model_path, device)
    logger.info(f"Loaded Gaussian splat with {splats.num_gaussians} gaussians from {gs_model_path}")

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only checkpoints from a trusted source should be loaded here.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    trainer = LangSplatV2Trainer.from_state_dict(
        state_dict=checkpoint,
        gs_model=splats,
        gs_model_path=gs_model_path,
        device=device,
        eval_only=True,
    )

    model = trainer._model
    needs_override = eval_topk is not None and model.topk != eval_topk
    if needs_override:
        logger.info(f"Overriding topk: {model.topk} (checkpoint) -> {eval_topk} (eval)")
        model.topk = eval_topk

    logger.info(
        f"Loaded LangSplatV2 model (feature_level={trainer._cfg.feature_level}, topk={model.topk}) "
        f"from {checkpoint_path}"
    )

    return model, trainer._sfm_scene
def render_clip_features(
    model: LangSplatV2Model,
    world_to_camera: torch.Tensor,
    projection: torch.Tensor,
    image_width: int,
    image_height: int,
) -> torch.Tensor:
    """Render one view with the model and L2-normalize its per-pixel features.

    Runs entirely under ``torch.no_grad()``.

    Args:
        model: Trained LangSplatV2Model; called with the camera arguments below
            and expected to return a pair whose first item is the feature map.
        world_to_camera: ``[1, 4, 4]`` world-to-camera matrix.
        projection: ``[1, 3, 3]`` camera intrinsics.
        image_width: Render width.
        image_height: Render height.

    Returns:
        Features of the first batch entry, unit-normalized along the last
        axis (``[H, W, 512]`` for the standard model).
    """
    camera = {
        "world_to_camera": world_to_camera,
        "projection": projection,
        "image_width": image_width,
        "image_height": image_height,
    }
    with torch.no_grad():
        feature_maps, _ = model(**camera)
        first_view = feature_maps[0]  # drop the batch dimension
        norms = first_view.norm(dim=-1, keepdim=True)
        # The small epsilon keeps all-zero pixels from dividing by zero.
        return first_view / (norms + 1e-10)
def save_frame_visualization(
    output_path: pathlib.Path,
    frame_idx: int,
    gt_img: np.ndarray | None,
    relevancy_map: torch.Tensor,
    prompts: list,
    img_ann: dict,
    chosen_iou_list: list,
    chosen_lvl_list: list,
    thresh: float,
    device: torch.device,
) -> None:
    """Save a per-frame visualization showing relevancy and segmentation.

    Creates a grid with one row per prompt showing:
    - GT image with GT mask overlay
    - Relevancy heatmap at chosen level
    - Predicted mask at chosen level
    - Overlay (predicted=red, GT=blue)

    The figure is written to ``output_path / "frame_<frame_idx, 5 digits>.jpg"``
    and closed afterwards.

    Args:
        output_path: Directory for output images.
        frame_idx: Frame index.
        gt_img: Optional GT image ``[H, W, 3]`` (RGB uint8). When None, a black
            background of the relevancy map's size is used instead.
        relevancy_map: ``[n_levels, n_prompts, H, W]``.
        prompts: List of prompt labels.
        img_ann: GT annotations for this frame.
        chosen_iou_list: IoU values per prompt.
        chosen_lvl_list: Chosen level per prompt.
        thresh: Mask threshold.
        device: Torch device.
    """
    n_prompts = len(prompts)
    n_levels, _, h, w = relevancy_map.shape

    fig, axes = plt.subplots(n_prompts, 4, figsize=(24, 6 * n_prompts))
    # With a single row, subplots returns a 1-D axes array; restore the 2-D
    # [row, col] shape so the indexing below works for any prompt count.
    if n_prompts == 1:
        axes = axes.reshape(1, -1)

    for k, prompt in enumerate(prompts):
        lvl = chosen_lvl_list[k]
        iou = chosen_iou_list[k]

        # Get relevancy at chosen level
        relev = relevancy_map[lvl, k].cpu().numpy()

        # Recompute predicted mask at chosen level (same as segmentation_process)
        relev_t = relevancy_map[lvl, k].clone()
        avg_pool = torch.nn.AvgPool2d(kernel_size=29, stride=1, padding=14, count_include_pad=False).to(device)
        avg_filtered = avg_pool(relev_t.unsqueeze(0).unsqueeze(0))
        # 50/50 blend of the raw and smoothed relevancy
        blended = 0.5 * (avg_filtered.squeeze(0).squeeze(0) + relev_t)
        # Min-max normalize, rescale to [-1, 1], clip to [0, 1]
        output = blended - torch.min(blended)
        output = output / (torch.max(output) + 1e-9)
        output = output * 2.0 - 1.0
        output = torch.clip(output, 0, 1)
        mask_pred = (output > thresh).to(torch.uint8)
        mask_pred = smooth_mask(mask_pred)
        mask_pred_np = mask_pred.cpu().numpy()

        mask_gt = img_ann[prompt]["mask"]

        # Col 0: GT image with GT mask overlay
        if gt_img is not None:
            overlay_gt = gt_img.copy().astype(np.float32) / 255.0
        else:
            overlay_gt = np.zeros((h, w, 3), dtype=np.float32)
        # Tint the GT region by boosting the blue channel
        overlay_gt[:, :, 2] = np.clip(overlay_gt[:, :, 2] + mask_gt * 0.4, 0, 1)
        axes[k, 0].imshow(overlay_gt)
        axes[k, 0].set_title(f'GT: "{prompt}"')
        axes[k, 0].axis("off")

        # Col 1: Relevancy heatmap
        im = axes[k, 1].imshow(relev, cmap="jet", vmin=0, vmax=1)
        # The title shows the 0-based level index shifted to a 1-based label
        axes[k, 1].set_title(f"Relevancy (level {lvl + 1})")
        axes[k, 1].axis("off")
        plt.colorbar(im, ax=axes[k, 1], fraction=0.046, pad=0.04)

        # Col 2: Predicted mask
        axes[k, 2].imshow(mask_pred_np, cmap="gray")
        axes[k, 2].set_title(f"Pred mask (thresh={thresh})")
        axes[k, 2].axis("off")

        # Col 3: Overlay (red=pred, blue=GT)
        if gt_img is not None:
            overlay = gt_img.copy().astype(np.float32) / 255.0
        else:
            overlay = np.zeros((h, w, 3), dtype=np.float32)
        mask_overlay = np.zeros_like(overlay)
        mask_overlay[:, :, 0] = mask_pred_np  # Red for prediction
        mask_overlay[:, :, 2] = mask_gt  # Blue for GT
        overlay = overlay * 0.5 + mask_overlay * 0.5
        axes[k, 3].imshow(np.clip(overlay, 0, 1))
        axes[k, 3].contour(mask_gt, colors="blue", linewidths=1, linestyles="solid")
        axes[k, 3].contour(mask_pred_np, colors="red", linewidths=1, linestyles="dashed")
        axes[k, 3].set_title(f"Overlay (IoU={iou * 100:.1f}%)")
        axes[k, 3].axis("off")

    fig.suptitle(f"Frame {frame_idx}")
    fig.tight_layout()
    fig.savefig(output_path / f"frame_{frame_idx:05d}.jpg", dpi=150, pil_kwargs={"quality": 90})
    # Close explicitly so figures do not accumulate across frames.
    plt.close(fig)
@dataclass
class EvaluationConfig:
    """Configuration for LERF evaluation."""

    lerf_root: pathlib.Path = pathlib.Path("data/lerf_ovs")
    """Root directory of the LERF-OVS dataset."""

    # run_lerf_evaluation looks for "<scene>_level_<N>.pt" with N in 1..3 in this
    # directory; the scene-name prefix is missing from the layout sketch below.
    results_root: pathlib.Path = pathlib.Path("langsplatv2_results")
    """Directory containing trained model checkpoints.

    Expected layout::

        results_root/
            _level_1.pt
            _level_2.pt
            _level_3.pt
    """

    # run_lerf_evaluation tries "<scene>.ply" first, then "<scene>.pt" / "<scene>.pth".
    # The fallback path is
    # lerf_root/<scene>/output/<scene>/point_cloud/iteration_30000/point_cloud.ply.
    reconstructions_root: pathlib.Path | None = pathlib.Path("reconstructions")
    """Directory containing per-scene Gaussian splat reconstructions.

    Expected layout::

        reconstructions_root/
            .ply

    If None, falls back to the LERF dataset structure at
    ``lerf_root//output//point_cloud/iteration_30000/point_cloud.ply``.
    """

    output_dir: pathlib.Path = pathlib.Path("lerf_eval_results")
    """Directory to save evaluation results and visualizations."""

    # Passed to torch.device() by run_lerf_evaluation.
    device: str = "cuda"
    """Device for computation."""

    mask_thresh: float = 0.4
    """Threshold for converting relevancy map to binary mask (matching original)."""

    save_visualizations: bool = True
    """Whether to save per-frame visualization images."""

    eval_topk: int = 4
    """Number of codebook entries to combine at evaluation time.

    The original LangSplatV2 trains with topk=4"""
+ """ + + output_dir: pathlib.Path = pathlib.Path("lerf_eval_results") + """Directory to save evaluation results and visualizations.""" + + device: str = "cuda" + """Device for computation.""" + + mask_thresh: float = 0.4 + """Threshold for converting relevancy map to binary mask (matching original).""" + + save_visualizations: bool = True + """Whether to save per-frame visualization images.""" + + eval_topk: int = 4 + """Number of codebook entries to combine at evaluation time. + + The original LangSplatV2 trains with topk=4""" + + +def get_camera_for_frame( + sfm_scene, + frame_idx: int, + device: torch.device, + logger: logging.Logger, +) -> Tuple[torch.Tensor, torch.Tensor, int, int] | None: + """Get camera parameters for a specific frame index. + + Images in the SfmScene are sorted by name to match the COLMAP ordering + used by the original LangSplatV2 dataset. + + Args: + sfm_scene: The SfmScene from the checkpoint. + frame_idx: 0-based frame index. + device: Torch device. + logger: Logger instance. + + Returns: + Tuple of (world_to_camera [1,4,4], projection [1,3,3], height, width) + or None if index is out of range. + """ + # Sort images by name to match COLMAP ordering + sorted_images = sorted(sfm_scene.images, key=lambda img: img.image_path) + + if frame_idx >= len(sorted_images): + logger.error(f"Frame index {frame_idx} out of range ({len(sorted_images)} images)") + return None + + img_meta = sorted_images[frame_idx] + c2w = torch.from_numpy(img_meta.camera_to_world_matrix).float() + K = torch.from_numpy(img_meta.camera_metadata.projection_matrix).float() + h = img_meta.camera_metadata.height + w = img_meta.camera_metadata.width + + w2c = torch.linalg.inv(c2w).contiguous() + + return ( + w2c.unsqueeze(0).to(device), + K.unsqueeze(0).to(device), + h, + w, + ) + + +def run_lerf_evaluation( + scene_name: str, + config: EvaluationConfig, + logger: logging.Logger, +) -> dict | None: + """Run evaluation on a single LERF scene. 
+ + Args: + scene_name: Name of the scene (e.g. "teatime", "figurines"). + config: Evaluation configuration. + logger: Logger instance. + + Returns: + Dictionary with evaluation results, or None if evaluation failed. + """ + device = torch.device(config.device) + + # --- Locate checkpoint files --- + level_checkpoints = [] + for level in [1, 2, 3]: + ckpt_path = config.results_root / f"{scene_name}_level_{level}.pt" + if not ckpt_path.exists(): + logger.error(f"Checkpoint not found: {ckpt_path}") + return None + level_checkpoints.append(ckpt_path) + + # --- Locate GS model for this scene --- + gs_path = None + # Try reconstructions_root/.ply first + if config.reconstructions_root is not None: + candidate = config.reconstructions_root / f"{scene_name}.ply" + if candidate.exists(): + gs_path = candidate + else: + # Also try .pt/.pth extensions + for ext in (".pt", ".pth"): + candidate = config.reconstructions_root / f"{scene_name}{ext}" + if candidate.exists(): + gs_path = candidate + break + + # Fall back to LERF dataset structure + if gs_path is None: + gs_path = ( + config.lerf_root + / scene_name + / "output" + / scene_name + / "point_cloud" + / "iteration_30000" + / "point_cloud.ply" + ) + + if not gs_path.exists(): + logger.error( + f"GS model not found for scene '{scene_name}'. " + f"Searched: reconstructions_root={config.reconstructions_root}, " + f"lerf_root={config.lerf_root}//output/..." 
+ ) + return None + + logger.info(f"Using GS model: {gs_path}") + + # --- Load ground truth --- + label_dir = config.lerf_root / "label" / scene_name + if not label_dir.exists(): + logger.error(f"Label directory not found: {label_dir}") + return None + + gt_ann, image_shape, gt_img_paths = load_lerf_ground_truth(label_dir, logger) + eval_frame_indices = [int(idx) for idx in gt_ann.keys()] + gt_h, gt_w = image_shape + + # --- Load models (3 levels) --- + models: list[LangSplatV2Model] = [] + sfm_scene = None + for i, ckpt_path in enumerate(level_checkpoints): + model, scene = load_langsplatv2_model(ckpt_path, gs_path, device, logger, eval_topk=config.eval_topk) + models.append(model) + if sfm_scene is None: + sfm_scene = scene + + # --- Load OpenCLIP for relevancy --- + clip_relevancy = OpenCLIPRelevancy(device=config.device) + + # --- Evaluate each annotated frame --- + chosen_iou_all = [] + chosen_lvl_list_all = [] + acc_num_total = 0 + total_bboxes = 0 + per_frame_results = [] + + scene_output_dir = config.output_dir / scene_name + scene_output_dir.mkdir(parents=True, exist_ok=True) + + for frame_enum_idx, frame_idx in enumerate( + tqdm.tqdm(eval_frame_indices, desc=f"Evaluating {scene_name}", leave=False) + ): + # Load GT image for visualization (optional) + gt_img = None + if frame_enum_idx < len(gt_img_paths) and config.save_visualizations: + gt_img_bgr = cv2.imread(gt_img_paths[frame_enum_idx]) + if gt_img_bgr is not None: + gt_img = cv2.cvtColor(gt_img_bgr, cv2.COLOR_BGR2RGB) + + # Get camera parameters for this frame + cam_info = get_camera_for_frame(sfm_scene, frame_idx, device, logger) + if cam_info is None: + logger.warning(f"Skipping frame {frame_idx}: camera not found") + continue + w2c, K, img_h, img_w = cam_info + + # Render CLIP features from all 3 levels + sem_feats = [] + for model in models: + feat = render_clip_features(model, w2c, K, img_w, img_h) + sem_feats.append(feat) + + # Stack: [3, H, W, 512] + sem_map = torch.stack(sem_feats, dim=0) + + 
# Get GT annotations for this frame + img_ann = gt_ann[str(frame_idx)] + prompts = list(img_ann.keys()) + + # Set positive prompts in CLIP model + clip_relevancy.set_positives(prompts) + + # Compute relevancy: [3, n_prompts, H, W] + relevancy_map = clip_relevancy.get_relevancy_map(sem_map) + + # Resize relevancy to GT resolution if needed + _, _, relev_h, relev_w = relevancy_map.shape + if relev_h != gt_h or relev_w != gt_w: + relevancy_map = torch.nn.functional.interpolate( + relevancy_map.reshape(-1, 1, relev_h, relev_w), + size=(gt_h, gt_w), + mode="bilinear", + align_corners=False, + ).reshape(3, len(prompts), gt_h, gt_w) + + # Segmentation IoU + # Clone relevancy_map because segmentation_process modifies it in-place + c_iou_list, c_lvl_list, iou_all = segmentation_process( + relevancy_map.clone(), config.mask_thresh, img_ann, prompts, device + ) + chosen_iou_all.extend(c_iou_list) + chosen_lvl_list_all.extend(c_lvl_list) + + # Localization accuracy + acc_num_img = localization_process(relevancy_map.clone(), img_ann, device) + acc_num_total += acc_num_img + total_bboxes += len(prompts) + + # Per-frame record + frame_result = { + "frame_idx": frame_idx, + "prompts": prompts, + "ious": c_iou_list, + "chosen_levels": c_lvl_list, + "iou_all_levels": iou_all, + "localization_correct": acc_num_img, + "num_bboxes": len(prompts), + } + per_frame_results.append(frame_result) + + logger.info( + f" Frame {frame_idx}: " + + ", ".join( + f'"{p}" IoU={iou * 100:.1f}% (lvl {lvl + 1})' for p, iou, lvl in zip(prompts, c_iou_list, c_lvl_list) + ) + + f" | loc={acc_num_img}/{len(prompts)}" + ) + + # Save visualization + if config.save_visualizations: + save_frame_visualization( + scene_output_dir, + frame_idx, + gt_img, + relevancy_map, + prompts, + img_ann, + c_iou_list, + c_lvl_list, + config.mask_thresh, + device, + ) + + # --- Scene summary --- + if not chosen_iou_all: + logger.warning(f"No successful evaluations for scene {scene_name}") + return None + + mean_iou = 
float(np.mean(chosen_iou_all)) + loc_accuracy = acc_num_total / total_bboxes if total_bboxes > 0 else 0.0 + + logger.info( + f"Scene {scene_name}: mean IoU = {mean_iou * 100:.1f}%, localization accuracy = {loc_accuracy * 100:.1f}%" + ) + logger.info(f" Chosen levels: {chosen_lvl_list_all}") + + return { + "scene": scene_name, + "mean_iou": mean_iou, + "localization_accuracy": loc_accuracy, + "localization_correct": acc_num_total, + "localization_total": total_bboxes, + "per_prompt_ious": chosen_iou_all, + "chosen_levels": chosen_lvl_list_all, + "mask_thresh": config.mask_thresh, + "per_frame_results": per_frame_results, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate LangSplatV2 models on the LERF-OVS dataset", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--lerf-root", + type=pathlib.Path, + default=pathlib.Path("data/lerf_ovs"), + help="Root directory of the LERF-OVS dataset", + ) + parser.add_argument( + "--results-root", + type=pathlib.Path, + default=pathlib.Path("langsplatv2_results"), + help="Directory containing trained model checkpoints " + "(_level_1.pt, _level_2.pt, _level_3.pt)", + ) + parser.add_argument( + "--reconstructions-root", + type=pathlib.Path, + default=pathlib.Path("reconstructions"), + help="Directory containing per-scene Gaussian splat reconstructions " + "(.ply). If not found, falls back to LERF dataset structure.", + ) + parser.add_argument( + "--output-dir", + type=pathlib.Path, + default=pathlib.Path("lerf_eval_results"), + help="Directory to save evaluation results and visualizations", + ) + parser.add_argument( + "--scenes", + nargs="+", + type=str, + default=None, + help="Specific scenes to evaluate. 
If not specified, discovers available scenes.", + ) + parser.add_argument( + "--mask-thresh", + type=float, + default=0.4, + help="Threshold for converting relevancy map to binary mask", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for computation (cuda or cuda:N)", + ) + parser.add_argument( + "--eval-topk", + type=int, + default=4, + help="Number of codebook entries to combine at evaluation time. " + "The original LangSplatV2 trains with topk=4.", + ) + parser.add_argument( + "--no-visualizations", + action="store_true", + help="Disable saving per-frame visualization images", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + # Set up logging + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logger = logging.getLogger("lerf_eval") + + # Build config + config = EvaluationConfig( + lerf_root=args.lerf_root, + results_root=args.results_root, + reconstructions_root=args.reconstructions_root, + output_dir=args.output_dir, + device=args.device, + mask_thresh=args.mask_thresh, + save_visualizations=not args.no_visualizations, + eval_topk=args.eval_topk, + ) + + # Determine scenes to evaluate + if args.scenes: + scenes_to_eval = args.scenes + else: + # Auto-discover scenes from checkpoint files + ckpt_files = list(config.results_root.glob("*_level_1.pt")) + if ckpt_files: + scenes_to_eval = [f.stem.replace("_level_1", "") for f in ckpt_files] + else: + # Fall back to label directories + label_root = config.lerf_root / "label" + if label_root.exists(): + scenes_to_eval = [d.name for d in label_root.iterdir() if d.is_dir()] + else: + logger.error("No scenes found. 
Specify --scenes or check paths.") + return + + logger.info(f"Evaluating scenes: {scenes_to_eval}") + logger.info(f"Mask threshold: {config.mask_thresh}") + + # Run evaluation + results = [] + for scene_name in scenes_to_eval: + logger.info("=" * 60) + logger.info(f"Evaluating scene: {scene_name}") + logger.info("=" * 60) + + result = run_lerf_evaluation(scene_name, config, logger) + if result is not None: + results.append(result) + + # Print summary + logger.info("=" * 60) + logger.info("EVALUATION COMPLETE - Summary:") + logger.info("=" * 60) + + if results: + ious = [r["mean_iou"] for r in results] + accs = [r["localization_accuracy"] for r in results] + overall_mean_iou = float(np.mean(ious)) + overall_loc_acc = float(np.mean(accs)) + + logger.info(f"{'Scene':<20} {'mean IoU':>10} {'Loc Acc':>10}") + logger.info("-" * 42) + for r in results: + logger.info(f"{r['scene']:<20} {r['mean_iou'] * 100:>9.1f}% {r['localization_accuracy'] * 100:>9.1f}%") + logger.info("-" * 42) + logger.info(f"{'Overall':<20} {overall_mean_iou * 100:>9.1f}% {overall_loc_acc * 100:>9.1f}%") + + # Save results JSON + results_file = config.output_dir / "lerf_results.json" + results_file.parent.mkdir(parents=True, exist_ok=True) + + # Strip per-frame results for the summary (keep in per-scene files) + summary_results = [] + for r in results: + # Save per-scene detailed results + scene_results_file = config.output_dir / r["scene"] / "results.json" + scene_results_file.parent.mkdir(parents=True, exist_ok=True) + with open(scene_results_file, "w") as f: + json.dump(r, f, indent=2) + logger.info(f"Per-scene results saved to {scene_results_file}") + + # Summary (without per-frame details) + summary = {k: v for k, v in r.items() if k != "per_frame_results"} + summary_results.append(summary) + + with open(results_file, "w") as f: + json.dump( + { + "results": summary_results, + "overall_mean_iou": overall_mean_iou, + "overall_localization_accuracy": overall_loc_acc, + "config": { + "mask_thresh": 
config.mask_thresh, + "device": config.device, + }, + }, + f, + indent=2, + ) + logger.info(f"Summary results saved to {results_file}") + else: + logger.warning("No successful evaluations") + + +if __name__ == "__main__": + main() diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/config.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/config.py index eaa05cb..56ff78d 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/config.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/config.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # from dataclasses import dataclass, field +from pathlib import Path from typing import Literal import numpy as np @@ -17,7 +18,12 @@ TransformScene, ) -from .scene_transforms import ComputeCLIPFeatures, ComputeMultiScaleSAM2Masks +from .scene_transforms import ( + ComputeCLIPFeatures, + ComputeMultiScaleSAM1Masks, + ComputeMultiScaleSAM2Masks, + ImportOriginalLangSplatV2Features, +) @dataclass @@ -25,13 +31,77 @@ class SAM2Config: """Configuration for SAM2 multi-scale mask generation.""" checkpoint: Literal["large", "small", "tiny", "base_plus"] = "large" - """SAM2 checkpoint size to use.""" + """SAM2 checkpoint size to use. Larger models produce higher-quality masks + but run a heavier image encoder. "base_plus" or "small" can noticeably + reduce per-image encoding time at some cost to mask quality.""" + + points_per_side: int = 32 + """Grid density for point prompts. Total points = points_per_side**2 + (e.g. 32 -> 1024, 16 -> 256). Reducing this value is the single + largest lever for speeding up mask generation because each point + requires a decoder forward pass.""" + + points_per_batch: int = 256 + """Points processed simultaneously by SAM2. Larger values reduce the + number of decoder forward passes (fewer kernel launches) at the cost + of higher peak GPU memory. 256 is safe on 24 GB+ GPUs.""" + + pred_iou_thresh: float = 0.5 + """Predicted IoU threshold for mask filtering. 
SAM2 predicts + substantially lower IoU scores than SAM1 (e.g. small-mask means + of 0.3-0.6 vs 0.8-0.9), so the threshold is set lower to achieve + comparable mask survival rates.""" + + stability_score_thresh: float = 0.85 + """Stability score threshold for mask filtering.""" + + crop_n_layers: int = 1 + """Number of crop layers. With ``crop_n_layers=1`` (the LangSplatV2 + default), 5 crops are generated per image (1 full + 4 overlapping + sub-crops), each running a full encoder + decoder pass. Setting + this to 0 reduces to a single full-image crop (~5x fewer passes) + at the cost of losing detail from sub-crop masks.""" + + crop_n_points_downscale_factor: int = 2 + """Point grid downscale factor per crop layer. Sub-crop layers use + (points_per_side / 2)**2 points instead of points_per_side**2, + reducing decoder cost on sub-crops by ~4x while keeping the + full-image crop at full density. The original LangSplatV2 uses 1 + with SAM ViT-H; 2 is a reasonable default with SAM2.""" + + min_mask_region_area: int = 100 + """Minimum mask region area for post-processing (matching the original + LangSplatV2 which uses ``min_mask_region_area=100``).""" + + box_nms_thresh: float = 0.7 + """Box NMS IoU threshold within each crop.""" + + nms_iou_thr: float = 0.8 + """IoU threshold for mask NMS post-processing.""" + + nms_score_thr: float = 0.7 + """Score threshold for mask NMS.""" + + nms_inner_thr: float = 0.5 + """Inner overlap threshold for mask NMS.""" + + +@dataclass +class SAM1Config: + """Configuration for SAM1 multi-scale mask generation. + + Default values match the original LangSplatV2 ``preprocess.py`` exactly + (SAM ViT-H with ``crop_n_layers=1``, ``crop_n_points_downscale_factor=1``). + """ + + checkpoint: Literal["vit_h", "vit_l", "vit_b"] = "vit_h" + """SAM1 model variant. 
The original LangSplatV2 uses ViT-H.""" points_per_side: int = 32 """Grid density for point prompts.""" - points_per_batch: int = 64 - """Points processed simultaneously by SAM2.""" + points_per_batch: int = 256 + """Points processed simultaneously by SAM1.""" pred_iou_thresh: float = 0.7 """Predicted IoU threshold for mask filtering.""" @@ -40,15 +110,14 @@ class SAM2Config: """Stability score threshold for mask filtering.""" crop_n_layers: int = 1 - """Number of crop layers. 1 = also run SAM on image crops (matching - the original LangSplatV2 which uses ``crop_n_layers=1``).""" + """Number of crop layers (1 = also run on crops, matching original).""" crop_n_points_downscale_factor: int = 1 - """Point grid downscale factor per crop layer.""" + """Point grid downscale factor per crop layer. The original LangSplatV2 + uses 1 (no downscaling on sub-crops).""" min_mask_region_area: int = 100 - """Minimum mask region area for post-processing (matching the original - LangSplatV2 which uses ``min_mask_region_area=100``).""" + """Minimum mask region area for post-processing.""" box_nms_thresh: float = 0.7 """Box NMS IoU threshold within each crop.""" @@ -124,12 +193,20 @@ class LangSplatV2PreprocessConfig: crop_to_points: bool = False """If True, crop scene bounds to the point cloud extent.""" - # SAM2 configuration + # SAM configuration + sam_model: Literal["sam1", "sam2"] = "sam2" + """Which SAM model to use for mask generation. ``"sam1"`` uses the + original Segment Anything Model (ViT-H by default) matching the original + LangSplatV2 pipeline. 
``"sam2"`` uses SAM2 (Hiera-Large by default).""" + + sam1: SAM1Config = field(default_factory=SAM1Config) + """Configuration for SAM1 mask generation (used when ``sam_model="sam1"``).""" + sam2: SAM2Config = field(default_factory=SAM2Config) - """Configuration for SAM2 mask generation.""" + """Configuration for SAM2 mask generation (used when ``sam_model="sam2"``).""" compute_sam_masks: bool = True - """Whether to compute SAM2 segmentation masks.""" + """Whether to compute SAM segmentation masks.""" # CLIP configuration clip: OpenCLIPConfig = field(default_factory=OpenCLIPConfig) @@ -138,6 +215,15 @@ class LangSplatV2PreprocessConfig: compute_clip_features: bool = True """Whether to compute CLIP features for masked regions.""" + # Import original features (bypasses SAM2 + CLIP) + original_features_dir: Path | None = None + """Path to the original LangSplatV2 ``language_features/`` directory. + + When set, imports pre-computed ``_f.npy`` / ``_s.npy`` files from the + original LangSplatV2 ``preprocess.py`` instead of running SAM2 mask + generation and CLIP feature encoding. Useful for A/B testing against + the original pipeline.""" + # Device device: torch.device | str = "cuda" """Device for model inference.""" @@ -155,9 +241,10 @@ def build_scene_transforms( 2. Point cloud percentile filtering 3. Image downsampling 4. Image filtering by visible points - 5. Multi-scale SAM2 mask generation - 6. CLIP feature encoding - 7. Scene cropping (optional) + 5a. Multi-scale SAM1/SAM2 mask generation + CLIP feature encoding, OR + 5b. Import original LangSplatV2 features (when ``original_features_dir`` + is set) + 6. 
Scene cropping (optional) Args: normalization_transform: Optional 4x4 transformation matrix @@ -184,28 +271,26 @@ def build_scene_transforms( rescaled_jpeg_quality=self.rescale_jpeg_quality, ), FilterImagesWithLowPoints(min_num_points=self.min_points_per_image), - # SAM2 mask generation - ( - ComputeMultiScaleSAM2Masks( - checkpoint=self.sam2.checkpoint, - points_per_side=self.sam2.points_per_side, - points_per_batch=self.sam2.points_per_batch, - pred_iou_thresh=self.sam2.pred_iou_thresh, - stability_score_thresh=self.sam2.stability_score_thresh, - crop_n_layers=self.sam2.crop_n_layers, - crop_n_points_downscale_factor=self.sam2.crop_n_points_downscale_factor, - min_mask_region_area=self.sam2.min_mask_region_area, - box_nms_thresh=self.sam2.box_nms_thresh, - nms_iou_thr=self.sam2.nms_iou_thr, - nms_score_thr=self.sam2.nms_score_thr, - nms_inner_thr=self.sam2.nms_inner_thr, - device=self.device, + ] + + if self.original_features_dir is not None: + # Import pre-computed features from original LangSplatV2 + transforms.append( + ImportOriginalLangSplatV2Features( + original_features_dir=self.original_features_dir, + clip_n_dims=self.clip.clip_n_dims, ) - if self.compute_sam_masks - else Identity() - ), - # CLIP feature encoding - ( + ) + else: + # Standard pipeline: SAM masks then CLIP features + if self.compute_sam_masks: + if self.sam_model == "sam1": + transforms.append(self.build_sam1_transform()) + else: + transforms.append(self.build_sam2_transform()) + else: + transforms.append(Identity()) + transforms.append( ComputeCLIPFeatures( clip_model_type=self.clip.clip_model_type, clip_model_pretrained=self.clip.clip_model_pretrained, @@ -214,8 +299,7 @@ def build_scene_transforms( ) if self.compute_clip_features else Identity() - ), - ] + ) # Optional scene cropping if self.crop_bbox is not None: @@ -225,6 +309,29 @@ def build_scene_transforms( return Compose(*transforms) + def build_sam1_transform(self): + """ + Build only the SAM1 mask generation transform. 
+ + Returns: + ComputeMultiScaleSAM1Masks transform. + """ + return ComputeMultiScaleSAM1Masks( + checkpoint=self.sam1.checkpoint, + points_per_side=self.sam1.points_per_side, + points_per_batch=self.sam1.points_per_batch, + pred_iou_thresh=self.sam1.pred_iou_thresh, + stability_score_thresh=self.sam1.stability_score_thresh, + crop_n_layers=self.sam1.crop_n_layers, + crop_n_points_downscale_factor=self.sam1.crop_n_points_downscale_factor, + min_mask_region_area=self.sam1.min_mask_region_area, + box_nms_thresh=self.sam1.box_nms_thresh, + nms_iou_thr=self.sam1.nms_iou_thr, + nms_score_thr=self.sam1.nms_score_thr, + nms_inner_thr=self.sam1.nms_inner_thr, + device=self.device, + ) + def build_sam2_transform(self): """ Build only the SAM2 mask generation transform. @@ -277,7 +384,9 @@ class LangSplatV2ModelConfig: """Dimensionality of CLIP embeddings.""" topk: int = 4 - """Number of non-zero sparse coefficients per VQ layer.""" + """Number of non-zero sparse coefficients per VQ layer. + + The original LangSplatV2 uses topk=4 for both training and evaluation.""" @dataclass @@ -318,6 +427,11 @@ class LangSplatV2TrainingConfig: normalize_features: bool = False """Whether to L2-normalize predicted features before computing loss.""" + init_codebooks_all_levels: bool = True + """If True, initialize codebooks using features from ALL scale levels + (matching original LangSplatV2). 
If False, use only the target + ``feature_level``.""" + model: LangSplatV2ModelConfig = field(default_factory=LangSplatV2ModelConfig) """Model architecture configuration.""" diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/__init__.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/__init__.py new file mode 100644 index 0000000..6e14057 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/__init__.py @@ -0,0 +1,3 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/__init__.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/__init__.py new file mode 100644 index 0000000..56d9ad9 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/__init__.py @@ -0,0 +1,23 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +from pathlib import Path +from typing import Optional + +_dataset_root = None + + +def set_dataset_root(dataset_root: Path): + global _dataset_root + _dataset_root = dataset_root + + +def get_dataset_root() -> Optional[Path]: + global _dataset_root + if _dataset_root is None: + return None + _dataset_root.mkdir(parents=True, exist_ok=True) + return _dataset_root + + +__all__ = ["set_dataset_root", "get_dataset_root"] diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/lerf.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/lerf.py new file mode 100644 index 0000000..4041ab5 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/lerf.py @@ -0,0 +1,78 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +from pathlib import Path + +from . 
import get_dataset_root +from .util import download_google_drive_file + +# Google Drive file ID extracted from the download link in the LangSplatV2 README: +# https://drive.google.com/file/d/1QF1Po5p5DwTjFHu6tnTeYs_G0egMVmHt/view?usp=sharing +LERF_GDRIVE_FILE_ID = "1QF1Po5p5DwTjFHu6tnTeYs_G0egMVmHt" + +# Scenes available in the LERF-OVS dataset +LERF_SCENES = [ + "figurines", + "ramen", + "teatime", + "waldo_kitchen", +] + + +def get_lerf_data_path() -> Path: + """Get the path to the LERF-OVS dataset root. + + Returns: + Path to the ``lerf_ovs`` directory inside the dataset root. + + Raises: + ValueError: If the dataset root has not been set via ``set_dataset_root()``. + """ + dataset_root = get_dataset_root() + if dataset_root is None: + raise ValueError("Dataset root is not set. Call set_dataset_root() first.") + result = dataset_root / "lerf_ovs" + if not result.exists(): + result.mkdir(parents=True, exist_ok=True) + return result + + +def download_lerf_data(): + """ + Download the LERF-OVS dataset from Google Drive. + + The dataset contains COLMAP scenes and labelme-format ground truth labels + for open-vocabulary segmentation evaluation. + + Expected layout after download:: + + lerf_ovs/ + label/ + / + frame_XXXXX.json + frame_XXXXX.jpg + / (COLMAP scene) + images/ + sparse/ + output/ + + The data will be saved to the LERF data path (dataset_root/lerf_ovs/). 
+ """ + output_path = get_lerf_data_path() + + print(f"Downloading LERF-OVS dataset to: {output_path}") + + # Check if data already exists + label_dir = output_path / "label" + if label_dir.exists() and any(label_dir.iterdir()): + print(f"LERF-OVS data already exists at: {output_path}") + print("Delete the directory to re-download.") + return + + download_google_drive_file( + file_id=LERF_GDRIVE_FILE_ID, + output_path=output_path.parent, + filename="lerf_ovs.zip", + ) + + print(f"\nLERF-OVS dataset download complete: {output_path}") diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/util.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/util.py new file mode 100644 index 0000000..535612a --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/datasets/util.py @@ -0,0 +1,43 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +from pathlib import Path +from zipfile import ZipFile + + +def download_google_drive_file(file_id: str, output_path: Path, filename: str = "download.zip") -> Path: + """ + Download a file from Google Drive using gdown. + + Args: + file_id: Google Drive file ID (from the sharing URL). + output_path: Directory to save and extract to. + filename: Name for the downloaded file. + + Returns: + Path to the output directory after extraction. + """ + try: + import gdown + except ImportError: + raise ImportError( + "gdown is required to download files from Google Drive. 
" + "Install it with: pip install gdown" + ) + + output_path.mkdir(parents=True, exist_ok=True) + zip_path = output_path / filename + + url = f"https://drive.google.com/uc?id={file_id}" + print(f"Downloading from Google Drive (file_id={file_id})...") + gdown.download(url, str(zip_path), quiet=False) + + if zip_path.suffix == ".zip": + print("Extracting archive...") + with ZipFile(zip_path, "r") as zf: + zf.extractall(output_path) + # Clean up zip file + zip_path.unlink() + print(f"Extracted to: {output_path}") + + return output_path diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py new file mode 100644 index 0000000..5ae9b25 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/evaluation/openclip_relevancy.py @@ -0,0 +1,152 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""OpenCLIP relevancy computation for LERF evaluation. + +Implements the relevancy scoring from the original LangSplatV2 evaluation +(``OpenCLIPNetwork.get_max_across_quick``). For each pixel, the score is the +minimum across all negative prompts of ``softmax(10 * [pos_sim, neg_sim])[0]``. + +""" +import logging +from typing import Sequence + +import open_clip +import torch +import torchvision + +logger = logging.getLogger(__name__) + +# Default negative prompts (matching the original LangSplatV2 evaluation) +DEFAULT_NEGATIVES = ("object", "things", "stuff", "texture") + + +class OpenCLIPRelevancy: + """Compute CLIP-based relevancy maps for open-vocabulary segmentation evaluation. 
+ + This class encapsulates: + + * Loading an OpenCLIP model (ViT-B-16 by default) + * Encoding positive (query) and negative (distractor) text prompts + * Computing per-pixel relevancy scores from rendered CLIP feature maps + + The relevancy computation matches the original LangSplatV2 + ``get_max_across_quick`` exactly: + + 1. Dot-product similarity between per-pixel features and all prompt embeddings + 2. For each positive prompt and each negative prompt, form a 2-way softmax + with temperature 10 + 3. Take the *minimum* positive probability across all negatives as the + final relevancy score + + Args: + clip_model_type: OpenCLIP model architecture (default ``"ViT-B-16"``). + clip_model_pretrained: Pretrained weights identifier + (default ``"laion2b_s34b_b88k"``). + negatives: Tuple of negative distractor text prompts. + device: Device for model and embeddings. + """ + + def __init__( + self, + clip_model_type: str = "ViT-B-16", + clip_model_pretrained: str = "laion2b_s34b_b88k", + negatives: Sequence[str] = DEFAULT_NEGATIVES, + device: str | torch.device = "cuda", + ): + self.device = torch.device(device) + self.clip_model_type = clip_model_type + self.clip_model_pretrained = clip_model_pretrained + + # Select precision based on device: use fp16 only on CUDA, fp32 otherwise + precision = "fp16" if self.device.type == "cuda" else "fp32" + + # Load OpenCLIP model + model, _, _ = open_clip.create_model_and_transforms( + clip_model_type, + pretrained=clip_model_pretrained, + precision=precision, + ) + model.eval() + self.model = model.to(self.device) + self.tokenizer = open_clip.get_tokenizer(clip_model_type) + + # Encode negative prompts + self.negatives = tuple(negatives) + with torch.no_grad(): + tok = torch.cat([self.tokenizer(phrase) for phrase in self.negatives]).to(self.device) + self.neg_embeds = self.model.encode_text(tok) + self.neg_embeds = self.neg_embeds / self.neg_embeds.norm(dim=-1, keepdim=True) + + # Positive embeddings (set per evaluation frame) + 
self.positives: tuple[str, ...] = () + self.pos_embeds: torch.Tensor | None = None + + logger.info( + f"OpenCLIPRelevancy initialized: model={clip_model_type}, " + f"pretrained={clip_model_pretrained}, negatives={self.negatives}" + ) + + def set_positives(self, text_list: Sequence[str]) -> None: + """Encode and cache positive (query) text prompts. + + Args: + text_list: List of positive text prompts (e.g. category labels + from the ground-truth annotations). + """ + self.positives = tuple(text_list) + with torch.no_grad(): + tok = torch.cat([self.tokenizer(phrase) for phrase in self.positives]).to(self.device) + self.pos_embeds = self.model.encode_text(tok) + self.pos_embeds = self.pos_embeds / self.pos_embeds.norm(dim=-1, keepdim=True) + + @torch.no_grad() + def get_relevancy_map(self, sem_map: torch.Tensor) -> torch.Tensor: + """Compute per-pixel relevancy for all positive prompts across all levels. + + This exactly replicates ``OpenCLIPNetwork.get_max_across_quick`` from the + original LangSplatV2 evaluation. + + Args: + sem_map: CLIP feature maps of shape ``[n_levels, H, W, 512]``. + Each level corresponds to a different SAM scale model. + + Returns: + Relevancy tensor of shape ``[n_levels, n_prompts, H, W]`` where each + value is in ``[0, 1]``. 
+ """ + if self.pos_embeds is None: + raise RuntimeError("Call set_positives() before computing relevancy maps.") + + n_levels, h, w, c = sem_map.shape + n_phrases = len(self.positives) + n_negatives = len(self.negatives) + + # Flatten spatial dims: [n_levels, H*W, 512] + sem_map_flat = sem_map.reshape(n_levels, h * w, c).contiguous() + + # All prompt embeddings: [P+N, 512] + phrase_embeds = torch.cat([self.pos_embeds, self.neg_embeds], dim=0) + phrase_embeds = phrase_embeds.to(sem_map.dtype).to(sem_map.device) + + # Dot-product similarity: [n_levels, H*W, P+N] + sim = torch.einsum("nqc,pc->nqp", sem_map_flat, phrase_embeds) + + # Split into positive and negative similarities + pos_vals = sim[:, :, :n_phrases] # [n_levels, H*W, P] + neg_vals = sim[:, :, n_phrases:] # [n_levels, H*W, N] + + # For each (positive, negative) pair, compute 2-way softmax with temperature 10 + repeated_pos = pos_vals.unsqueeze(-1).expand(-1, -1, -1, n_negatives) # [n_levels, H*W, P, N] + neg_vals_exp = neg_vals.unsqueeze(2).expand(-1, -1, n_phrases, -1) # [n_levels, H*W, P, N] + + sims = torch.stack([repeated_pos, neg_vals_exp], dim=-1) # [n_levels, H*W, P, N, 2] + softmax = torch.softmax(10 * sims, dim=-1) # [n_levels, H*W, P, N, 2] + + # Take minimum positive probability across all negatives + min_pos_prob, _ = softmax[..., 0].min(dim=-1) # [n_levels, H*W, P] + + # Reshape back to spatial dimensions + relev_map = min_pos_prob.permute(0, 2, 1).reshape(n_levels, n_phrases, h, w) + + return relev_map diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py index ea4d30f..e008fcc 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/loss.py @@ -43,13 +43,22 @@ def calculate_langsplatv2_loss( ) -> dict[str, torch.Tensor]: """Compute the LangSplatV2 language feature loss. 
- Compares predicted CLIP features with ground truth features, masked to only - include valid pixels (those covered by a SAM mask). + Both predicted and ground-truth features are zeroed for unmapped pixels + (mask == False), matching the original's + ``language_feature * language_feature_mask`` approach. Per-pixel loss + is then zeroed for unmapped pixels and averaged over **all** pixels. + This means: + + - Unmapped pixels contribute 0 to the reported loss. + - The mean over all pixels implicitly down-scales gradients from valid + pixels by ``N_valid / N_total``, matching the original's gradient + magnitudes exactly. Args: - predicted_features: Predicted feature map of shape ``[B, H, W, clip_n_dims]``. - gt_features: Ground truth feature map of shape ``[B, H, W, clip_n_dims]``. - mask: Boolean mask of shape ``[B, H, W]`` indicating valid pixels. + predicted_features: Predicted feature map, shape ``[B, H, W, C]`` + or ``[H, W, C]``. + gt_features: Ground truth feature map, same shape. + mask: Boolean mask, shape ``[B, H, W]`` or ``[H, W]``. use_cosine_loss: Whether to include cosine similarity loss. use_l1_loss: Whether to include L1 loss. normalize_features: Whether to L2-normalize predicted features @@ -59,52 +68,38 @@ def calculate_langsplatv2_loss( Dictionary with loss components: - ``"total_loss"``: Combined loss value. - ``"cosine_loss"``: Cosine loss component (if enabled). + - ``"cosine_loss_valid"``: Cosine loss on valid pixels only (for logging). - ``"l1_loss"``: L1 loss component (if enabled). 
""" assert use_cosine_loss or use_l1_loss, "At least one loss type must be enabled" - # Optionally normalize predicted features if normalize_features: predicted_features = predicted_features / (predicted_features.norm(dim=-1, keepdim=True) + 1e-10) loss_dict: dict[str, torch.Tensor] = {} total_loss = torch.tensor(0.0, device=predicted_features.device) - if not mask.any(): - # No valid pixels - return zero loss - loss_dict["total_loss"] = total_loss - if use_cosine_loss: - loss_dict["cosine_loss"] = total_loss - if use_l1_loss: - loss_dict["l1_loss"] = total_loss - return loss_dict - - # Gather only valid pixels (clean signal, no NaN risk from torch.empty). - valid_pred = predicted_features[mask] # [N_valid, clip_n_dims] - valid_gt = gt_features[mask] # [N_valid, clip_n_dims] - - # The original LangSplatV2 computes .mean() over ALL H*W pixels, where - # masked-out pixels are zero-vectors that contribute ~0 to the sum but - # inflate the denominator. This implicitly scales gradients down by - # (N_valid / N_total). 
We replicate this by computing the loss on valid - # pixels only (clean, interpretable values) and multiplying by the mask - # coverage fraction so that gradient magnitudes match the original exactly: - # - # grad_original = (1/N_total) * sum_valid(dL_i) - # grad_ours = (ratio/N_valid) * sum_valid(dL_i) - # = (1/N_total) * sum_valid(dL_i) [identical] - mask_fraction = mask.sum().float() / mask.numel() + mask_expanded = mask.unsqueeze(-1) # [..., 1] + masked_pred = predicted_features * mask_expanded + masked_gt = gt_features * mask_expanded if use_cosine_loss: - cos_loss_raw = cosine_loss(valid_pred, valid_gt) - cos_loss = cos_loss_raw * mask_fraction - loss_dict["cosine_loss"] = cos_loss - # Mean over valid pixels only (no mask_fraction); stable for logging when coverage varies - loss_dict["cosine_loss_valid"] = cos_loss_raw - total_loss = total_loss + cos_loss + per_pixel_cos = F.cosine_similarity(masked_pred, masked_gt, dim=-1) + per_pixel_loss = (1.0 - per_pixel_cos) * mask.float() + cos_loss_all = per_pixel_loss.sum() / mask.numel() + loss_dict["cosine_loss"] = cos_loss_all + total_loss = total_loss + cos_loss_all + + if mask.any(): + valid_pred = predicted_features[mask] + valid_gt = gt_features[mask] + loss_dict["cosine_loss_valid"] = cosine_loss(valid_pred, valid_gt) + else: + loss_dict["cosine_loss_valid"] = cos_loss_all if use_l1_loss: - l1 = l1_loss(valid_pred, valid_gt) * mask_fraction + per_pixel_l1 = torch.abs(masked_pred - masked_gt) * mask_expanded + l1 = per_pixel_l1.sum() / mask.numel() loss_dict["l1_loss"] = l1 total_loss = total_loss + l1 diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/__init__.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/__init__.py index de05bce..b7d182f 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/__init__.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/__init__.py @@ -2,9 +2,13 @@ # 
SPDX-License-Identifier: Apache-2.0 # from .clip_feature_encoding import ComputeCLIPFeatures +from .import_original_features import ImportOriginalLangSplatV2Features +from .multi_scale_sam1_masks import ComputeMultiScaleSAM1Masks from .multi_scale_sam_masks import ComputeMultiScaleSAM2Masks __all__ = [ "ComputeCLIPFeatures", + "ComputeMultiScaleSAM1Masks", "ComputeMultiScaleSAM2Masks", + "ImportOriginalLangSplatV2Features", ] diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py index 2ce3372..3c10b51 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/clip_feature_encoding.py @@ -5,6 +5,9 @@ This transform computes CLIP features for masked image regions generated by the multi-scale SAM transform, following the LangSplatV2 approach. + +Crop extraction, masking, padding, and resize are performed on the GPU in +float32 to avoid uint8 quantisation artefacts that occur with ``cv2.resize`` on small masks. """ import logging from typing import Any @@ -12,57 +15,12 @@ import cv2 import numpy as np import torch +import torch.nn.functional as F import tqdm from fvdb_reality_capture.sfm_scene import SfmCache, SfmScene from fvdb_reality_capture.transforms import BaseTransform, transform -def get_seg_img(mask: np.ndarray, image: np.ndarray, bbox: np.ndarray) -> np.ndarray: - """ - Extract and crop the segmented region from an image. - - Args: - mask: Binary segmentation mask, shape [H, W]. - image: Source image, shape [H, W, 3]. - bbox: Bounding box in XYWH format. - - Returns: - Cropped and masked image region. 
- """ - x, y, w, h = map(int, bbox) - - # Crop first (view, no copy), then copy only the small region - cropped_img = image[y : y + h, x : x + w].copy() - cropped_mask = mask[y : y + h, x : x + w] - - # Apply mask only to the cropped region - cropped_img[cropped_mask == 0] = 0 - - return cropped_img - - -def pad_img(img: np.ndarray) -> np.ndarray: - """ - Pad an image to make it square. - - Args: - img: Input image, shape [H, W, 3]. - - Returns: - Square padded image. - """ - h, w = img.shape[:2] - l = max(w, h) - pad = np.zeros((l, l, 3), dtype=np.uint8) - - if h > w: - pad[:, (h - w) // 2 : (h - w) // 2 + w, :] = img - else: - pad[(w - h) // 2 : (w - h) // 2 + h, :, :] = img - - return pad - - @transform class ComputeCLIPFeatures(BaseTransform): """ @@ -76,7 +34,7 @@ class ComputeCLIPFeatures(BaseTransform): This transform must be run after ComputeMultiScaleSAM2Masks. """ - version = "1.0.0" + version = "1.1.0" def __init__( self, @@ -118,16 +76,21 @@ def _get_clip_model(self): def _encode_masked_regions( self, - image: np.ndarray, + image_gpu: torch.Tensor, masks: np.ndarray, bboxes: np.ndarray, ) -> torch.Tensor: """ Encode masked image regions using CLIP. + Performs crop-mask-pad-resize entirely on GPU in float32 (matching + the original LangSplatV2 ``mask2segmap`` + ``_embed_clip_sam_tiles``), + then normalises and encodes through CLIP in one batch. + Args: - image: Source image in RGB format, shape [H, W, 3]. - masks: Binary masks, shape [N, H, W]. + image_gpu: Source image on CUDA as float32, shape [H, W, 3], + values in [0, 255]. + masks: Binary masks, shape [N, H, W] (uint8 0/1). bboxes: Bounding boxes in XYWH format, shape [N, 4]. 
Returns: @@ -138,24 +101,40 @@ def _encode_masked_regions( clip_model = self._get_clip_model() image_size = clip_model.image_size - - # Extract each masked region, pad to square, and resize to model's expected size - seg_imgs = [] - for mask, bbox in zip(masks, bboxes): - seg_img = get_seg_img(mask, image, bbox) - pad_seg_img = pad_img(seg_img) - resized_img = cv2.resize(pad_seg_img, (image_size, image_size)) - seg_imgs.append(resized_img) - - # Stack and convert to tensor with values in [0, 1] - seg_imgs_np = np.stack(seg_imgs, axis=0) # [N, image_size, image_size, 3] - seg_imgs_tensor = ( - torch.from_numpy(seg_imgs_np.astype("float32")).permute(0, 3, 1, 2) / 255.0 - ) # [N, 3, image_size, image_size] - - # Apply model's tensor preprocessing (handles normalization) and encode + n = len(masks) + device = image_gpu.device + + all_segs = torch.from_numpy(masks).to(device) + seg_imgs = torch.empty(n, 3, image_size, image_size, device=device) + + for i in range(n): + x, y, w, h = int(bboxes[i][0]), int(bboxes[i][1]), int(bboxes[i][2]), int(bboxes[i][3]) + w = max(w, 1) + h = max(h, 1) + + crop = image_gpu[y : y + h, x : x + w].clone() + crop[~all_segs[i, y : y + h, x : x + w].bool()] = 0 + crop = crop.permute(2, 0, 1) # HWC -> CHW + + side = max(h, w) + padded = torch.zeros(3, side, side, device=device) + if h > w: + offset = (h - w) // 2 + padded[:, :, offset : offset + w] = crop + else: + offset = (w - h) // 2 + padded[:, offset : offset + h, :] = crop + + seg_imgs[i] = F.interpolate( + padded.unsqueeze(0), size=(image_size, image_size), + mode="bilinear", align_corners=False, + ).squeeze(0) + + seg_imgs /= 255.0 + + # Normalise and encode with torch.no_grad(): - seg_imgs_preprocessed = clip_model.preprocess_tensor(seg_imgs_tensor) + seg_imgs_preprocessed = clip_model.preprocess_tensor(seg_imgs) clip_embeds = clip_model.encode_image(seg_imgs_preprocessed) clip_embeds = clip_embeds / clip_embeds.norm(dim=-1, keepdim=True) @@ -226,7 +205,8 @@ def __call__(self, 
input_scene: SfmScene) -> SfmScene: # Create cache folder model_type_safe = self._clip_model_type.replace("-", "_") pretrained_safe = self._clip_model_pretrained.replace("-", "_") - cache_prefix = f"clip_features_{model_type_safe}_{pretrained_safe}_{self._clip_n_dims}" + version_safe = self.version.replace(".", "_") + cache_prefix = f"clip_features_{model_type_safe}_{pretrained_safe}_{self._clip_n_dims}_v{version_safe}" output_cache = input_cache.make_folder( cache_prefix, description=f"CLIP features using {self._clip_model_type}", @@ -286,16 +266,17 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) h, w = img.shape[:2] - # Load SAM2 masks from parent cache + img_gpu = torch.from_numpy(img_rgb).to(self._device, dtype=torch.float32) + mask_filename = f"masks_{image_meta.image_id:0{num_zeropad}}" if not input_cache.has_file(mask_filename): raise RuntimeError( - f"Mask file {mask_filename} not found in cache. " "Run ComputeMultiScaleSAM2Masks first." + f"Mask file {mask_filename} not found in cache. " + "Run ComputeMultiScaleSAM2Masks or ComputeMultiScaleSAM1Masks first." 
) _, mask_data = input_cache.read_file(mask_filename) - # Encode all masked regions across scales all_features = [] scale_names = ["default", "s", "m", "l"] @@ -304,7 +285,7 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: bboxes = mask_data.get(f"{scale_name}_bboxes", np.zeros((0, 4))) if len(masks) > 0: - scale_features = self._encode_masked_regions(img_rgb, masks, bboxes) + scale_features = self._encode_masked_regions(img_gpu, masks, bboxes) all_features.append(scale_features) # Concatenate features from all scales @@ -320,6 +301,20 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: total_masks = sum(lengths) assert features.shape[0] == total_masks, f"Feature count mismatch: {features.shape[0]} vs {total_masks}" + if features.shape[0] > 0 and torch.isnan(features).any(): + self._logger.warning( + f"Image {image_meta.image_id}: CLIP features contain NaN " + f"({torch.isnan(features).sum().item()} / {features.numel()} values)" + ) + + # Report per-scale mask coverage for the first few images + if image_meta.image_id < 3: + coverage = {sn: int((seg_maps[i] >= 0).sum()) for i, sn in enumerate(scale_names)} + self._logger.debug( + f"Image {image_meta.image_id}: {total_masks} masks, " + f"lengths={lengths}, pixel coverage={coverage}" + ) + # Save cache_filename = f"features_{image_meta.image_id:0{num_zeropad}}" output_cache.write_file( diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/import_original_features.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/import_original_features.py new file mode 100644 index 0000000..059a13d --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/import_original_features.py @@ -0,0 +1,252 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Import pre-computed features from the original LangSplatV2 repository. 
+ +The original LangSplatV2 ``preprocess.py`` saves per-image features as two +numpy files: + +- ``{image_name}_f.npy``: CLIP feature vectors ``[N_total, 512]`` (float16) +- ``{image_name}_s.npy``: segmentation maps ``[4, H, W]`` (int32) + +This transform reads those files and writes them into our SfmCache format, +acting as a drop-in replacement for the ``ComputeMultiScaleSAM2Masks`` + +``ComputeCLIPFeatures`` pipeline to test our pipeline's outputs. +""" +import logging +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import tqdm +from fvdb_reality_capture.sfm_scene import SfmCache, SfmScene +from fvdb_reality_capture.transforms import BaseTransform, transform + + +def _compute_lengths_from_seg_maps( + seg_maps: np.ndarray, + total_features: int, +) -> list[int]: + """Derive per-level feature counts from the original's segmentation maps. + + The original concatenates features in order (default, s, m, l) and assigns + globally-offset indices in each seg_map channel. It guarantees that the + **max** index in each non-empty channel equals + ``cumulative_length[level] - 1`` (asserted in ``preprocess.py``). We + exploit this by recovering cumulative lengths from the max index per + channel: ``cum[j] = max(channel_j) + 1``. + + This is robust to masks whose indices never appear in the seg_map (e.g. + a mask at the start of a level with zero pixels assigned) because only + the max -- which the original guarantees correct -- is used. + + Args: + seg_maps: Array of shape ``[4, H, W]`` with int32 indices (-1 = none). + total_features: Total number of features in the corresponding + ``_f.npy`` file (unused here; the caller cross-checks it). + + Returns: + List of 4 integers giving the number of features at each scale level. 
+ """ + cum: list[int] = [] + for level in range(4): + channel = seg_maps[level] + valid = channel[channel >= 0] + if len(valid) == 0: + cum.append(cum[-1] if cum else 0) + else: + cum.append(int(valid.max()) + 1) + + lengths = [cum[0]] + [cum[j] - cum[j - 1] for j in range(1, 4)] + return lengths + + +@transform +class ImportOriginalLangSplatV2Features(BaseTransform): + """Import pre-computed language features from the original LangSplatV2. + + Reads ``_f.npy`` / ``_s.npy`` file pairs produced by the original + ``preprocess.py`` and writes them into the SfmCache in the same dict + format that ``ComputeCLIPFeatures`` produces: + + .. code-block:: python + + {"features": Tensor, "seg_maps": Tensor, "lengths": Tensor} + + This allows the training pipeline to consume original features without + any changes to the dataset or trainer code. + """ + + version = "1.0.2" + + def __init__( + self, + original_features_dir: Path | str, + clip_n_dims: int = 512, + ): + """ + Args: + original_features_dir: Path to the original ``language_features/`` + directory containing ``*_f.npy`` and ``*_s.npy`` files. + clip_n_dims: Expected CLIP embedding dimensionality (used for + cache naming and validation). + """ + self._features_dir = Path(original_features_dir) + self._clip_n_dims = clip_n_dims + self._logger = logging.getLogger( + f"{self.__class__.__module__}.{self.__class__.__name__}" + ) + + def __call__(self, input_scene: SfmScene) -> SfmScene: + if len(input_scene.images) == 0: + self._logger.warning("No images in SfmScene. 
Returning unchanged.") + return input_scene + + input_cache: SfmCache = input_scene.cache + + version_safe = self.version.replace(".", "_") + cache_prefix = ( + f"imported_original_langsplatv2_features" + f"_{self._clip_n_dims}_v{version_safe}" + ) + output_cache = input_cache.make_folder( + cache_prefix, + description="Imported features from original LangSplatV2", + ) + + num_zeropad = len(str(input_scene.num_images)) + 2 + regenerate_cache = False + + if output_cache.num_files != input_scene.num_images: + if output_cache.num_files != 0: + self._logger.info( + f"Cache has {output_cache.num_files} files but expected " + f"{input_scene.num_images}. Regenerating cache." + ) + output_cache.clear_current_folder() + regenerate_cache = True + + if not regenerate_cache: + for image_id in range(input_scene.num_images): + cache_filename = f"features_{image_id:0{num_zeropad}}" + if not output_cache.has_file(cache_filename): + self._logger.info( + f"{cache_filename} missing from cache. Regenerating." + ) + output_cache.clear_current_folder() + regenerate_cache = True + break + + if regenerate_cache: + self._logger.info( + f"Importing original features from {self._features_dir}" + ) + pbar = tqdm.tqdm( + input_scene.images, + unit="imgs", + desc="Importing original features", + ) + + for image_meta in pbar: + image_name = Path(image_meta.image_path).stem + f_path = self._features_dir / f"{image_name}_f.npy" + s_path = self._features_dir / f"{image_name}_s.npy" + + if not f_path.exists(): + raise FileNotFoundError( + f"Feature file not found: {f_path}. " + f"Run the original LangSplatV2 preprocess.py first." + ) + if not s_path.exists(): + raise FileNotFoundError( + f"Segmentation map not found: {s_path}. " + f"Run the original LangSplatV2 preprocess.py first." 
+ ) + + features_np = np.load(str(f_path)) # [N_total, 512], float16 + seg_maps_np = np.load(str(s_path)) # [4, H, W], int32 + + if features_np.shape[1] != self._clip_n_dims: + raise ValueError( + f"Feature dimension mismatch for {image_name}: " + f"expected {self._clip_n_dims}, got {features_np.shape[1]}" + ) + + lengths = _compute_lengths_from_seg_maps( + seg_maps_np, features_np.shape[0] + ) + + expected_total = sum(lengths) + if features_np.shape[0] != expected_total: + self._logger.warning( + f"{image_name}: feature count ({features_np.shape[0]}) " + f"!= sum(lengths) ({expected_total}). " + f"Possible corrupt feature/seg_map pair." + ) + + features = torch.from_numpy(features_np).half() + seg_maps = torch.from_numpy(seg_maps_np).int() + lengths_t = torch.tensor(lengths, dtype=torch.int32) + + cache_filename = f"features_{image_meta.image_id:0{num_zeropad}}" + output_cache.write_file( + name=cache_filename, + data={ + "features": features, + "seg_maps": seg_maps, + "lengths": lengths_t, + }, + data_type="pt", + metadata={ + "source": "original_langsplatv2", + "clip_n_dims": self._clip_n_dims, + "original_features_dir": str(self._features_dir), + }, + ) + + pbar.close() + self._logger.info( + f"Imported features for {input_scene.num_images} images." 
+ ) + else: + self._logger.info("Original features already cached.") + + output_scene = SfmScene( + cameras=input_scene.cameras, + images=input_scene.images, + points=input_scene.points, + points_err=input_scene.points_err, + points_rgb=input_scene.points_rgb, + scene_bbox=input_scene.scene_bbox, + transformation_matrix=input_scene.transformation_matrix, + cache=output_cache, + ) + + return output_scene + + @staticmethod + def name() -> str: + return "ImportOriginalLangSplatV2Features" + + def state_dict(self) -> dict[str, Any]: + return { + "name": self.name(), + "version": self.version, + "original_features_dir": str(self._features_dir), + "clip_n_dims": self._clip_n_dims, + } + + @staticmethod + def from_state_dict( + state_dict: dict[str, Any], + ) -> "ImportOriginalLangSplatV2Features": + if state_dict["name"] != "ImportOriginalLangSplatV2Features": + raise ValueError( + f"Expected 'ImportOriginalLangSplatV2Features', " + f"got {state_dict['name']}" + ) + return ImportOriginalLangSplatV2Features( + original_features_dir=state_dict["original_features_dir"], + clip_n_dims=state_dict["clip_n_dims"], + ) diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py new file mode 100644 index 0000000..fc1145b --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/scene_transforms/mask_utils.py @@ -0,0 +1,354 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Shared mask post-processing utilities for multi-scale SAM transforms. + +These functions are used by both :class:`ComputeMultiScaleSAM1Masks` and +:class:`ComputeMultiScaleSAM2Masks`. They have no dependency on either the +``segment_anything`` or ``sam2`` packages. 
+""" +import logging +import os +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Tuple + +import cv2 +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms + + +def remove_small_regions( + mask: np.ndarray, + area_thresh: float, + mode: str, +) -> Tuple[np.ndarray, bool]: + """Remove small disconnected regions or holes from a binary mask. + + Pure OpenCV implementation matching ``segment_anything.utils.amg.remove_small_regions``. + + Args: + mask: Boolean mask, shape ``[H, W]``. + area_thresh: Minimum area; components smaller than this are removed. + mode: ``"holes"`` to fill small holes, ``"islands"`` to remove small islands. + + Returns: + ``(cleaned_mask, changed)`` where *changed* is True if the mask was modified. + """ + assert mode in ("holes", "islands") + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1] + small_regions = [i for i, s in enumerate(sizes) if s < area_thresh and i != 0] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels or i == 0] + mask = np.isin(regions, fill_labels) + return mask, True + + +def _clean_single_mask( + seg_raw: np.ndarray, + min_area: int, +) -> Tuple[np.ndarray, float, List[float]]: + """Clean one mask and compute its bounding box. Thread-safe (CPU-only). + + Returns: + (cleaned_seg, score, box_xyxy) where score is 1.0 if the mask was + unchanged and 0.0 if it was modified. 
+ """ + seg = seg_raw.astype(bool) + seg, changed_holes = remove_small_regions(seg, min_area, mode="holes") + unchanged = not changed_holes + seg, changed_islands = remove_small_regions(seg, min_area, mode="islands") + unchanged = unchanged and not changed_islands + + ys, xs = np.where(seg) + if len(xs) == 0: + box = [0.0, 0.0, 0.0, 0.0] + else: + box = [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())] + + return seg, float(unchanged), box + + +def postprocess_small_regions( + masks: List[Dict[str, Any]], + min_area: int, + nms_thresh: float, +) -> List[Dict[str, Any]]: + """Remove small disconnected regions and holes from masks, then re-run box NMS. + + Mirrors the ``postprocess_small_regions`` step in the original SAM + ``SamAutomaticMaskGenerator.generate_curr_anns`` which cleans up each + mask's binary segmentation before returning annotations. + + The per-mask cleaning (``cv2.connectedComponentsWithStats``) is + parallelized across CPU cores via :class:`ThreadPoolExecutor`. + + Args: + masks: List of mask annotation dicts with at least ``segmentation`` + (np.ndarray bool/uint8 HxW) and ``bbox`` ([x, y, w, h]). + min_area: Minimum area threshold for ``remove_small_regions``. + nms_thresh: Box NMS IoU threshold for deduplication after cleaning. + + Returns: + Filtered list of mask annotation dicts with cleaned segmentations. 
+ """ + if len(masks) == 0: + return masks + + max_workers = min(len(masks), os.cpu_count() or 4) + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = [ + pool.submit(_clean_single_mask, m["segmentation"], min_area) + for m in masks + ] + results = [f.result() for f in futures] + + new_segmentations = [r[0] for r in results] + scores = [r[1] for r in results] + boxes_list = [r[2] for r in results] + + boxes_xyxy = torch.tensor(boxes_list, dtype=torch.float32) + scores_t = torch.tensor(scores, dtype=torch.float32) + + keep = batched_nms( + boxes_xyxy, + scores_t, + torch.zeros(len(masks), dtype=torch.long), + iou_threshold=nms_thresh, + ) + + result = [] + for idx in keep.tolist(): + m = masks[idx].copy() + m["segmentation"] = new_segmentations[idx].astype(np.uint8) + x_min, y_min, x_max, y_max = boxes_list[idx] + m["bbox"] = [x_min, y_min, x_max - x_min, y_max - y_min] + m["area"] = int(new_segmentations[idx].sum()) + result.append(m) + + return result + + +_mask_nms_logger = logging.getLogger(__name__ + ".mask_nms") +_mask_nms_diag_count = 0 + + +def mask_nms( + masks: torch.Tensor, + scores: torch.Tensor, + iou_thr: float = 0.7, + score_thr: float = 0.1, + inner_thr: float = 0.2, + **kwargs, +) -> torch.Tensor: + """ + Perform mask non-maximum suppression (NMS) on a set of masks. + + Faithful reimplementation of the ``mask_nms`` from the original + LangSplatV2 ``preprocess.py``. + + Args: + masks: Binary masks, shape ``[num_masks, H, W]``. + scores: Mask scores, shape ``[num_masks]``. + iou_thr: IoU threshold for NMS. + score_thr: Minimum score threshold. + inner_thr: Inner overlap threshold for removing contained masks. + + Returns: + Indices of selected masks after NMS. 
+ """ + global _mask_nms_diag_count + log_diag = _mask_nms_diag_count < 8 + if log_diag: + _mask_nms_diag_count += 1 + + if len(masks) == 0: + return torch.tensor([], dtype=torch.long) + + scores, idx = scores.sort(0, descending=True) + num_masks = idx.shape[0] + + masks_ord = masks[idx.view(-1), :] + masks_area = torch.sum(masks_ord, dim=(1, 2), dtype=torch.float) + + masks_flat = masks_ord.reshape(num_masks, -1).float() + intersection = masks_flat @ masks_flat.T + union = masks_area[:, None] + masks_area[None, :] - intersection + iou_matrix = intersection / union + + R = intersection / masks_area[:, None] + inner_val = 1 - R * R.T + cond = (R < 0.5) & (R.T >= 0.85) + inner_iou_matrix = torch.where(cond, inner_val, torch.zeros_like(inner_val)) + + iou_matrix.triu_(diagonal=1) + iou_max, _ = iou_matrix.max(dim=0) + inner_iou_matrix_u = torch.triu(inner_iou_matrix, diagonal=1) + inner_iou_max_u, _ = inner_iou_matrix_u.max(dim=0) + # NOTE: this includes the diagonal and the first super-diagonal, which + # doesn't match the intended “lower triangle excluding diagonal” + # logic used for the containment check. 
+ # this should use the lower triangle below the diagonal (-1) + # but this won't really affect results and we want to keep the original logic + # to match the original LangSplatV2 implementation + inner_iou_matrix_l = torch.tril(inner_iou_matrix, diagonal=1) + inner_iou_max_l, _ = inner_iou_matrix_l.max(dim=0) + + keep = iou_max <= iou_thr + keep_conf = scores > score_thr + keep_inner_u = inner_iou_max_u <= 1 - inner_thr + keep_inner_l = inner_iou_max_l <= 1 - inner_thr + + if log_diag: + _mask_nms_logger.info( + "[mask_nms diag] input=%d masks scores: min=%.4f max=%.4f " + "areas: min=%.0f max=%.0f iou_max: min=%.4f max=%.4f mean=%.4f " + "pass_iou=%d pass_conf=%d pass_inner_u=%d pass_inner_l=%d", + num_masks, + scores.min().item(), scores.max().item(), + masks_area.min().item(), masks_area.max().item(), + iou_max.min().item(), iou_max.max().item(), iou_max.mean().item(), + keep.sum().item(), keep_conf.sum().item(), + keep_inner_u.sum().item(), keep_inner_l.sum().item(), + ) + + if keep_conf.sum() == 0: + index = scores.topk(min(3, len(scores))).indices + keep_conf[index] = True + if keep_inner_u.sum() == 0: + index = scores.topk(min(3, len(scores))).indices + keep_inner_u[index] = True + if keep_inner_l.sum() == 0: + index = scores.topk(min(3, len(scores))).indices + keep_inner_l[index] = True + keep *= keep_conf + keep *= keep_inner_u + keep *= keep_inner_l + + if log_diag: + _mask_nms_logger.info( + "[mask_nms diag] final keep=%d / %d", keep.sum().item(), num_masks, + ) + + selected_idx = idx[keep] + return selected_idx + + +_masks_update_logger = logging.getLogger(__name__ + ".masks_update") + + +def masks_update( + *mask_lists, + iou_thr: float = 0.8, + score_thr: float = 0.7, + inner_thr: float = 0.5, + max_area_frac: float = 0.95, +) -> tuple: + """ + Apply mask NMS to multiple lists of masks. + + Args: + *mask_lists: Variable number of mask lists to filter. + iou_thr: IoU threshold for NMS. + score_thr: Score threshold. 
def cross_crop_nms(
    mask_list: List[Dict[str, Any]],
    iou_threshold: float = 0.7,
) -> List[Dict[str, Any]]:
    """Suppress duplicate masks that were produced by overlapping crops.

    Box NMS is run once over all records. Every record is scored with the
    reciprocal of its crop's area, so when two boxes overlap by more than
    ``iou_threshold`` the record that came from the smaller crop is kept.

    Args:
        mask_list: Mask records; each provides ``"bbox"`` and ``"crop_box"``
            as ``[x, y, w, h]``.
        iou_threshold: Box IoU threshold handed to ``batched_nms``.

    Returns:
        The surviving records, in the order ``batched_nms`` reports them.
        A list with zero or one record is returned as-is (same object).
    """
    num_records = len(mask_list)
    if num_records <= 1:
        return mask_list

    # "bbox" is [x, y, w, h]; batched_nms expects [x1, y1, x2, y2].
    xywh = np.array([record["bbox"] for record in mask_list], dtype=np.float32)
    left, top = xywh[:, 0], xywh[:, 1]
    corners = np.stack([left, top, left + xywh[:, 2], top + xywh[:, 3]], axis=1)

    # Smaller crop area -> larger score -> wins ties during suppression.
    crop_area = np.array(
        [record["crop_box"][2] * record["crop_box"][3] for record in mask_list],
        dtype=np.float32,
    )
    priority = torch.from_numpy(1.0 / (crop_area + 1e-6))

    # Every box gets category 0, so all boxes compete with each other.
    single_category = torch.zeros(num_records, dtype=torch.long)
    survivors = batched_nms(
        torch.from_numpy(corners).float(),
        priority,
        single_category,
        iou_threshold=iou_threshold,
    )
    return [mask_list[index] for index in survivors.tolist()]
import logging
from typing import Any, Dict, List, Literal

import cv2
import numpy as np
import torch
import tqdm

from fvdb_reality_capture.foundation_models import SAM1Model
from fvdb_reality_capture.sfm_scene import SfmCache, SfmScene
from fvdb_reality_capture.transforms import BaseTransform, transform

from .mask_utils import cross_crop_nms, masks_update, postprocess_small_regions


def _sam1_amg():
    """Lazy accessor for ``segment_anything.utils.amg``.

    The import is deferred to call time so that this module can be imported
    without the segment-anything package installed.

    Returns:
        The ``segment_anything.utils.amg`` module.

    Raises:
        ImportError: If ``segment_anything`` cannot be imported; the message
            carries installation instructions.
    """
    try:
        import segment_anything.utils.amg as _amg
    except ImportError:
        raise ImportError(
            "SAM1 transform requires the segment-anything package. Install with:\n"
            "  conda install -c conda-forge segment-anything\n"
            "Or update your environment:\n"
            "  conda env update -f open_vocabulary_segmentation/langsplatv2/environment.yml"
        ) from None
    return _amg


@transform
class ComputeMultiScaleSAM1Masks(BaseTransform):
    """Generate multi-scale segmentation masks using SAM1 (ViT-H/L/B).

    Uses :class:`fvdb_reality_capture.foundation_models.SAM1Model` with
    ``output_mode="multi_scale"`` to split the 3 multimask outputs per point
    by index (small/medium/large). Default parameters match the original
    LangSplatV2 ``preprocess.py`` exactly.

    Calling the transform on an :class:`SfmScene` writes one cache file per
    image (``masks_<image_id>``) holding, for each of the scales
    ``default``/``s``/``m``/``l``, the segmentations, bboxes, areas, predicted
    IoUs and stability scores. Existing cache files are reused when their
    stored parameters match this transform's parameters.
    """

    version = "1.5.0"

    # Scale keys, in the order ``predict_masks_multi_scale`` returns its lists.
    _SCALES = ("default", "s", "m", "l")
    # Overlap ratio passed to ``generate_crop_boxes``. 512 / 1500 is the
    # default ``crop_overlap_ratio`` of SAM's automatic mask generator.
    _CROP_OVERLAP_RATIO = 512 / 1500
    # Number of images for which per-stage mask counts are logged at DEBUG.
    _MAX_DIAG_IMAGES = 2

    def __init__(
        self,
        checkpoint: Literal["vit_h", "vit_l", "vit_b"] = "vit_h",
        points_per_side: int = 32,
        points_per_batch: int = 256,
        pred_iou_thresh: float = 0.7,
        stability_score_thresh: float = 0.85,
        crop_n_layers: int = 1,
        crop_n_points_downscale_factor: int = 1,
        min_mask_region_area: int = 100,
        box_nms_thresh: float = 0.7,
        nms_iou_thr: float = 0.8,
        nms_score_thr: float = 0.7,
        nms_inner_thr: float = 0.5,
        device: torch.device | str = "cuda",
    ):
        """
        Create a multi-scale SAM1 mask generation transform.

        The SAM1 model itself is not loaded here; it is created on first use.

        Args:
            checkpoint: SAM1 model variant (vit_h, vit_l, vit_b).
            points_per_side: Grid density for point prompts.
            points_per_batch: Points processed simultaneously.
            pred_iou_thresh: Predicted IoU threshold.
            stability_score_thresh: Stability score threshold.
            crop_n_layers: Number of crop layers (1 = also run on crops,
                matching the original LangSplatV2).
            crop_n_points_downscale_factor: Point grid downscale per crop layer.
            min_mask_region_area: Minimum mask region area for post-processing.
                ``0`` disables the small-region post-processing step.
            box_nms_thresh: Box NMS IoU threshold within each crop. Also used
                for cross-crop NMS and small-region post-processing.
            nms_iou_thr: IoU threshold for mask NMS post-processing.
            nms_score_thr: Score threshold for mask NMS.
            nms_inner_thr: Inner overlap threshold for mask NMS.
            device: Device to run SAM1 on.
        """
        self._checkpoint = checkpoint
        self._points_per_side = points_per_side
        self._points_per_batch = points_per_batch
        self._pred_iou_thresh = pred_iou_thresh
        self._stability_score_thresh = stability_score_thresh
        self._crop_n_layers = crop_n_layers
        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self._min_mask_region_area = min_mask_region_area
        self._box_nms_thresh = box_nms_thresh
        self._nms_iou_thr = nms_iou_thr
        self._nms_score_thr = nms_score_thr
        self._nms_inner_thr = nms_inner_thr
        self._device = device

        self._sam1_model: SAM1Model | None = None
        self._diag_count = 0
        self._logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")

    # ------------------------------------------------------------------
    # Mask generation
    # ------------------------------------------------------------------

    def _get_sam1_model(self) -> SAM1Model:
        """Return the SAM1 model, constructing it on the first call."""
        if self._sam1_model is None:
            self._sam1_model = SAM1Model(
                checkpoint=self._checkpoint,
                points_per_side=self._points_per_side,
                points_per_batch=self._points_per_batch,
                pred_iou_thresh=self._pred_iou_thresh,
                stability_score_thresh=self._stability_score_thresh,
                min_mask_region_area=self._min_mask_region_area,
                box_nms_thresh=self._box_nms_thresh,
                output_mode="multi_scale",
                device=self._device,
            )
        return self._sam1_model

    def _claim_diag_slot(self) -> bool:
        """Return ``True`` for the first ``_MAX_DIAG_IMAGES`` calls only.

        Each ``True`` result consumes one slot, so per-stage diagnostics are
        emitted for the first few images and then stop.
        """
        # getattr keeps this safe for instances created without __init__.
        used = getattr(self, "_diag_count", 0)
        if used >= self._MAX_DIAG_IMAGES:
            return False
        self._diag_count = used + 1
        return True

    def _log_mask_counts(self, stage: str, masks_by_scale: Dict[str, List[Dict[str, Any]]]) -> None:
        """Log the number of masks per scale after pipeline stage ``stage``."""
        self._logger.debug(
            "[diag] after %s: default=%d, s=%d, m=%d, l=%d",
            stage,
            *(len(masks_by_scale[scale]) for scale in self._SCALES),
        )

    def _generate_multi_scale_masks(self, image: np.ndarray) -> dict:
        """Generate masks at multiple scales.

        Uses multi-crop generation and cross-crop NMS, then mask NMS per scale.

        Args:
            image: Input image in BGR format (OpenCV default), shape ``[H, W, 3]``.

        Returns:
            Dictionary with mask lists keyed by ``"default"``, ``"s"``,
            ``"m"``, ``"l"``.
        """
        amg = _sam1_amg()
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        orig_size = image_rgb.shape[:2]
        sam1 = self._get_sam1_model()

        crop_boxes, layer_idxs = amg.generate_crop_boxes(
            orig_size, self._crop_n_layers, self._CROP_OVERLAP_RATIO
        )
        point_grids = amg.build_all_layer_point_grids(
            self._points_per_side,
            self._crop_n_layers,
            self._crop_n_points_downscale_factor,
        )

        masks_by_scale: Dict[str, List[Dict[str, Any]]] = {scale: [] for scale in self._SCALES}

        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            x0, y0, x1, y1 = crop_box
            cropped_im = image_rgb[y0:y1, x0:x1, :]
            cropped_h, cropped_w = cropped_im.shape[:2]
            # Scale this layer's point grid by the crop's (width, height).
            points_scale = np.array([cropped_w, cropped_h], dtype=np.float64)
            point_coords = point_grids[layer_idx] * points_scale

            # Explicit 4-way unpack: fail loudly if the model's output changes.
            md, ms, mm, ml = sam1.predict_masks_multi_scale(
                cropped_im,
                point_coords=point_coords,
                crop_box=crop_box,
                orig_size=orig_size,
            )
            for scale, crop_masks in zip(self._SCALES, (md, ms, mm, ml)):
                masks_by_scale[scale].extend(crop_masks)

        log_diag = self._claim_diag_slot()
        if log_diag:
            self._log_mask_counts("predict", masks_by_scale)

        # De-duplicate across crops; only meaningful when there are several.
        if len(crop_boxes) > 1:
            masks_by_scale = {
                scale: cross_crop_nms(masks, iou_threshold=self._box_nms_thresh)
                for scale, masks in masks_by_scale.items()
            }
        if log_diag:
            self._log_mask_counts("cross-crop NMS", masks_by_scale)

        if self._min_mask_region_area > 0:
            masks_by_scale = {
                scale: postprocess_small_regions(masks, self._min_mask_region_area, self._box_nms_thresh)
                for scale, masks in masks_by_scale.items()
            }
        if log_diag:
            self._log_mask_counts("postprocess_small_regions", masks_by_scale)

        filtered = masks_update(
            *(masks_by_scale[scale] for scale in self._SCALES),
            iou_thr=self._nms_iou_thr,
            score_thr=self._nms_score_thr,
            inner_thr=self._nms_inner_thr,
        )
        masks_by_scale = dict(zip(self._SCALES, filtered))
        if log_diag:
            self._log_mask_counts("masks_update", masks_by_scale)

        return masks_by_scale

    # ------------------------------------------------------------------
    # Cache helpers
    # ------------------------------------------------------------------

    def _cache_metadata(self) -> dict[str, Any]:
        """Return the parameters stored with, and validated against, each cache file.

        Both the writer and the validator use this single dictionary so the
        two can never disagree about which parameters are recorded.
        """
        return {
            "checkpoint": self._checkpoint,
            "points_per_side": self._points_per_side,
            "pred_iou_thresh": self._pred_iou_thresh,
            "stability_score_thresh": self._stability_score_thresh,
            "crop_n_layers": self._crop_n_layers,
            "min_mask_region_area": self._min_mask_region_area,
            "nms_iou_thr": self._nms_iou_thr,
            "nms_score_thr": self._nms_score_thr,
            "nms_inner_thr": self._nms_inner_thr,
        }

    @staticmethod
    def _cache_filename(image_id: int, num_zeropad: int) -> str:
        """Return the cache entry name for ``image_id``, zero-padded to ``num_zeropad`` digits."""
        return f"masks_{image_id:0{num_zeropad}}"

    def _cache_is_stale(self, output_cache: SfmCache, input_scene: SfmScene, num_zeropad: int) -> bool:
        """Decide whether the cache folder must be regenerated.

        The cache is stale when the file count differs from the number of
        images, when any image's file is missing, or when any file's stored
        parameters differ from :meth:`_cache_metadata`. Files are looked up
        by each image's ``image_id`` -- the same key the writer uses.

        Returns:
            ``True`` if masks must be regenerated, ``False`` if the cache can
            be reused as-is.
        """
        num_cached = output_cache.num_files
        if num_cached != input_scene.num_images:
            if num_cached != 0:
                self._logger.info(
                    f"Cache has {output_cache.num_files} files but expected "
                    f"{input_scene.num_images}. Regenerating cache."
                )
            return True

        expected = self._cache_metadata()
        for image_meta in input_scene.images:
            cache_filename = self._cache_filename(image_meta.image_id, num_zeropad)
            if not output_cache.has_file(cache_filename):
                self._logger.info(
                    f"Masks {cache_filename} not found in the cache. " f"Clearing cache and regenerating."
                )
                return True

            stored = output_cache.get_file_metadata(cache_filename).get("metadata", {})
            if any(stored.get(key) != value for key, value in expected.items()):
                self._logger.info(
                    f"Cache metadata does not match expected parameters. " f"Clearing cache and regenerating."
                )
                return True

        return False

    @staticmethod
    def _pack_masks(masks_dict: dict, image_hw: tuple) -> dict[str, np.ndarray]:
        """Convert per-scale mask records into arrays suitable for caching.

        Args:
            masks_dict: Mapping of scale name to a list of mask records with
                ``"segmentation"``, ``"bbox"``, ``"area"``, ``"predicted_iou"``
                and ``"stability_score"`` entries.
            image_hw: ``(height, width)`` used for the shape of the empty
                segmentation array when a scale has no masks.

        Returns:
            Mapping ``"<scale>_<field>"`` to an array; scales without masks
            get zero-length arrays so every key is always present.
        """
        height, width = image_hw
        packed: dict[str, np.ndarray] = {}
        for scale_name, masks in masks_dict.items():
            if len(masks) > 0:
                packed[f"{scale_name}_segmentations"] = np.stack(
                    [m["segmentation"].astype(np.uint8) for m in masks], axis=0
                )
                packed[f"{scale_name}_bboxes"] = np.array([m["bbox"] for m in masks], dtype=np.float32)
                packed[f"{scale_name}_areas"] = np.array([m["area"] for m in masks], dtype=np.int32)
                packed[f"{scale_name}_predicted_ious"] = np.array(
                    [m["predicted_iou"] for m in masks], dtype=np.float32
                )
                packed[f"{scale_name}_stability_scores"] = np.array(
                    [m["stability_score"] for m in masks], dtype=np.float32
                )
            else:
                packed[f"{scale_name}_segmentations"] = np.zeros((0, height, width), dtype=np.uint8)
                packed[f"{scale_name}_bboxes"] = np.zeros((0, 4), dtype=np.float32)
                packed[f"{scale_name}_areas"] = np.zeros(0, dtype=np.int32)
                packed[f"{scale_name}_predicted_ious"] = np.zeros(0, dtype=np.float32)
                packed[f"{scale_name}_stability_scores"] = np.zeros(0, dtype=np.float32)
        return packed

    def _write_masks_to_cache(self, output_cache: SfmCache, input_scene: SfmScene, num_zeropad: int) -> None:
        """Generate masks for every image in ``input_scene`` and write them to ``output_cache``."""
        # The context manager closes the progress bar even if an image fails.
        with tqdm.tqdm(input_scene.images, unit="imgs", desc="Generating SAM1 masks") as pbar:
            for image_meta in pbar:
                image_path = image_meta.image_path
                img = cv2.imread(image_path)
                assert img is not None, f"Failed to load image {image_path}"

                img = image_meta.camera_metadata.undistort_image(img)

                masks_dict = self._generate_multi_scale_masks(img)

                output_cache.write_file(
                    name=self._cache_filename(image_meta.image_id, num_zeropad),
                    data=self._pack_masks(masks_dict, img.shape[:2]),
                    data_type="pt",
                    metadata=self._cache_metadata(),
                )

    # ------------------------------------------------------------------
    # Transform interface
    # ------------------------------------------------------------------

    def __call__(self, input_scene: SfmScene) -> SfmScene:
        """Generate multi-scale SAM1 masks for all images in the scene.

        Args:
            input_scene: Input scene containing images.

        Returns:
            Scene with cache containing multi-scale mask data. A scene with
            no images is returned unchanged.
        """
        if len(input_scene.images) == 0:
            self._logger.warning("No images found in the SfmScene. Returning unchanged.")
            return input_scene

        input_cache: SfmCache = input_scene.cache

        # The folder name encodes the parameters and the transform version, so
        # different configurations never share a cache folder.
        version_safe = self.version.replace(".", "_")
        cache_prefix = (
            f"sam1_multi_scale_masks_{self._checkpoint}_"
            f"p{self._points_per_side}_"
            f"iou{int(self._pred_iou_thresh * 100)}_"
            f"stab{int(self._stability_score_thresh * 100)}_"
            f"crop{self._crop_n_layers}_"
            f"nmsiou{int(self._nms_iou_thr * 100)}_"
            f"nmsscore{int(self._nms_score_thr * 100)}_"
            f"nmsinner{int(self._nms_inner_thr * 100)}_"
            f"v{version_safe}"
        )
        output_cache = input_cache.make_folder(
            cache_prefix,
            description=f"Multi-scale SAM1 masks with {self._checkpoint} checkpoint",
        )

        num_zeropad = len(str(len(input_scene.images))) + 2

        if self._cache_is_stale(output_cache, input_scene, num_zeropad):
            if output_cache.num_files != 0:
                output_cache.clear_current_folder()
            self._logger.info("Generating multi-scale SAM1 masks for all images.")
            self._write_masks_to_cache(output_cache, input_scene, num_zeropad)
            self._logger.info(f"Generated masks for {input_scene.num_images} images.")
        else:
            self._logger.info("Loading masks from cache.")

        return SfmScene(
            cameras=input_scene.cameras,
            images=input_scene.images,
            points=input_scene.points,
            points_err=input_scene.points_err,
            points_rgb=input_scene.points_rgb,
            scene_bbox=input_scene.scene_bbox,
            transformation_matrix=input_scene.transformation_matrix,
            cache=output_cache,
        )

    @staticmethod
    def name() -> str:
        """Return the name under which this transform is serialized."""
        return "ComputeMultiScaleSAM1Masks"

    def state_dict(self) -> dict[str, Any]:
        """Return the transform's name, version and constructor arguments as a plain dict."""
        return {
            "name": self.name(),
            "version": self.version,
            "checkpoint": self._checkpoint,
            "points_per_side": self._points_per_side,
            "points_per_batch": self._points_per_batch,
            "pred_iou_thresh": self._pred_iou_thresh,
            "stability_score_thresh": self._stability_score_thresh,
            "crop_n_layers": self._crop_n_layers,
            "crop_n_points_downscale_factor": self._crop_n_points_downscale_factor,
            "min_mask_region_area": self._min_mask_region_area,
            "box_nms_thresh": self._box_nms_thresh,
            "nms_iou_thr": self._nms_iou_thr,
            "nms_score_thr": self._nms_score_thr,
            "nms_inner_thr": self._nms_inner_thr,
            "device": str(self._device),
        }

    @staticmethod
    def from_state_dict(state_dict: dict[str, Any]) -> "ComputeMultiScaleSAM1Masks":
        """Rebuild a transform from the output of :meth:`state_dict`.

        Keys read with ``.get`` fall back to the constructor defaults when
        absent; the remaining keys are required.

        Raises:
            ValueError: If ``state_dict["name"]`` is not this transform's name.
        """
        if state_dict["name"] != "ComputeMultiScaleSAM1Masks":
            raise ValueError(
                f"Expected state_dict with name 'ComputeMultiScaleSAM1Masks', "
                f"got {state_dict['name']} instead."
            )

        return ComputeMultiScaleSAM1Masks(
            checkpoint=state_dict["checkpoint"],
            points_per_side=state_dict["points_per_side"],
            points_per_batch=state_dict.get("points_per_batch", 256),
            pred_iou_thresh=state_dict["pred_iou_thresh"],
            stability_score_thresh=state_dict["stability_score_thresh"],
            crop_n_layers=state_dict.get("crop_n_layers", 1),
            crop_n_points_downscale_factor=state_dict.get("crop_n_points_downscale_factor", 1),
            min_mask_region_area=state_dict.get("min_mask_region_area", 100),
            box_nms_thresh=state_dict.get("box_nms_thresh", 0.7),
            nms_iou_thr=state_dict["nms_iou_thr"],
            nms_score_thr=state_dict["nms_score_thr"],
            nms_inner_thr=state_dict["nms_inner_thr"],
            device=state_dict["device"],
        )
SfmScene from fvdb_reality_capture.transforms import BaseTransform, transform -def mask_nms( - masks: torch.Tensor, - scores: torch.Tensor, - iou_thr: float = 0.7, - score_thr: float = 0.1, - inner_thr: float = 0.2, -) -> torch.Tensor: - """ - Perform mask non-maximum suppression (NMS) on a set of masks. - - Removes redundant masks based on IoU overlap and inner containment. - This implementation is fully vectorized for efficient GPU computation. - - Args: - masks: Binary masks, shape [num_masks, H, W]. - scores: Mask scores, shape [num_masks]. - iou_thr: IoU threshold for NMS. - score_thr: Minimum score threshold. - inner_thr: Inner overlap threshold for removing contained masks. - - Returns: - Indices of selected masks after NMS. - """ - if len(masks) == 0: - return torch.tensor([], dtype=torch.long) - - scores, idx = scores.sort(0, descending=True) - num_masks = idx.shape[0] - - masks_ord = masks[idx.view(-1), :] - - # Flatten masks for vectorized computation: [N, H*W] - # Use reshape instead of view since indexing may produce non-contiguous tensor - masks_flat = masks_ord.reshape(num_masks, -1).float() - - # Compute all pairwise intersections in one matrix multiply - # For binary masks: intersection[i,j] = sum(mask_i & mask_j) = mask_i @ mask_j.T - intersection = masks_flat @ masks_flat.T # [N, N] - - # Compute areas - masks_area = masks_flat.sum(dim=1) # [N] - - # Compute unions: union[i,j] = area_i + area_j - intersection[i,j] - union = masks_area[:, None] + masks_area[None, :] - intersection - - # Compute IoU matrix (avoid division by zero) - iou_matrix = torch.where(union > 0, intersection / union, torch.zeros_like(union)) - - # Compute containment ratios for inner IoU - # ratio_i[i,j] = intersection[i,j] / area_i - # ratio_j[i,j] = intersection[i,j] / area_j - area_i = masks_area[:, None].expand_as(intersection) - area_j = masks_area[None, :].expand_as(intersection) - - # Avoid division by zero - ratio_i = torch.where(area_i > 0, intersection / area_i, 
torch.zeros_like(intersection)) - ratio_j = torch.where(area_j > 0, intersection / area_j, torch.zeros_like(intersection)) - - # Inner IoU for containment detection - # Case 1: j is mostly contained in i (ratio_i < 0.5 and ratio_j >= 0.85) - # Case 2: i is mostly contained in j (ratio_i >= 0.85 and ratio_j < 0.5) - inner_iou_values = 1 - ratio_i * ratio_j - - # Build inner_iou_matrix based on containment conditions - inner_iou_matrix = torch.zeros_like(iou_matrix) - - # Upper triangle: j contained in i - condition_upper = (ratio_i < 0.5) & (ratio_j >= 0.85) - inner_iou_matrix = torch.where(condition_upper, inner_iou_values, inner_iou_matrix) - - # Lower triangle: i contained in j (transpose the condition) - condition_lower = (ratio_i >= 0.85) & (ratio_j < 0.5) - inner_iou_matrix = torch.where(condition_lower, inner_iou_values.T, inner_iou_matrix) - - # Apply triangular masks and compute max values - iou_matrix = torch.triu(iou_matrix, diagonal=1) - iou_max, _ = iou_matrix.max(dim=0) - - inner_iou_matrix_u = torch.triu(inner_iou_matrix, diagonal=1) - inner_iou_max_u, _ = inner_iou_matrix_u.max(dim=0) - inner_iou_matrix_l = torch.tril(inner_iou_matrix, diagonal=-1) - inner_iou_max_l, _ = inner_iou_matrix_l.max(dim=0) - - keep = iou_max <= iou_thr - keep_conf = scores > score_thr - keep_inner_u = inner_iou_max_u <= 1 - inner_thr - keep_inner_l = inner_iou_max_l <= 1 - inner_thr - - # If no masks pass thresholds, keep top 3 - if keep_conf.sum() == 0: - index = scores.topk(min(3, len(scores))).indices - keep_conf[index] = True - if keep_inner_u.sum() == 0: - index = scores.topk(min(3, len(scores))).indices - keep_inner_u[index] = True - if keep_inner_l.sum() == 0: - index = scores.topk(min(3, len(scores))).indices - keep_inner_l[index] = True - - keep = keep & keep_conf & keep_inner_u & keep_inner_l - selected_idx = idx[keep] - - return selected_idx - - -def masks_update( - *mask_lists, - iou_thr: float = 0.8, - score_thr: float = 0.7, - inner_thr: float = 0.5, -) -> 
tuple: - """ - Apply mask NMS to multiple lists of masks. - - Args: - *mask_lists: Variable number of mask lists to filter. - iou_thr: IoU threshold for NMS. - score_thr: Score threshold. - inner_thr: Inner overlap threshold. - - Returns: - Tuple of filtered mask lists. - """ - masks_new = [] - - for masks_lvl in mask_lists: - if len(masks_lvl) == 0: - masks_new.append([]) - continue - - seg_pred = torch.from_numpy(np.stack([m["segmentation"] for m in masks_lvl], axis=0)) - iou_pred = torch.from_numpy(np.stack([m["predicted_iou"] for m in masks_lvl], axis=0)) - stability = torch.from_numpy(np.stack([m["stability_score"] for m in masks_lvl], axis=0)) - - scores = stability * iou_pred - keep_mask_nms = mask_nms(seg_pred, scores, iou_thr=iou_thr, score_thr=score_thr, inner_thr=inner_thr) - - keep_set = set(keep_mask_nms.int().cpu().numpy().tolist()) - filtered_masks = [m for i, m in enumerate(masks_lvl) if i in keep_set] - masks_new.append(filtered_masks) - - return tuple(masks_new) - -def _cross_crop_nms( - mask_list: List[Dict[str, Any]], - iou_threshold: float = 0.7, -) -> List[Dict[str, Any]]: - """Run NMS across masks from multiple crops; prefer masks from smaller crops. - - Args: - mask_list: List of mask records with "bbox" (xywh) and "crop_box" (xywh). - iou_threshold: Box IoU threshold for NMS. - - Returns: - Filtered list of mask records. 
- """ - if len(mask_list) <= 1: - return mask_list - # bbox and crop_box are [x, y, w, h]; convert bbox to xyxy for batched_nms - boxes_xywh = np.array([m["bbox"] for m in mask_list], dtype=np.float32) - x1 = boxes_xywh[:, 0] - y1 = boxes_xywh[:, 1] - x2 = boxes_xywh[:, 0] + boxes_xywh[:, 2] - y2 = boxes_xywh[:, 1] + boxes_xywh[:, 3] - boxes_xyxy = torch.from_numpy(np.stack([x1, y1, x2, y2], axis=1)) - crop_areas = np.array( - [m["crop_box"][2] * m["crop_box"][3] for m in mask_list], - dtype=np.float32, - ) - scores = torch.from_numpy(1.0 / (crop_areas + 1e-6)) - keep = batched_nms( - boxes_xyxy.float(), - scores, - torch.zeros(len(mask_list), dtype=torch.long), - iou_threshold=iou_threshold, - ) - return [mask_list[i] for i in keep.tolist()] +from .mask_utils import cross_crop_nms, masks_update, postprocess_small_regions @transform @@ -211,14 +34,14 @@ class ComputeMultiScaleSAM2Masks(BaseTransform): to each scale level independently. """ - version = "1.0.0" + version = "1.2.0" def __init__( self, checkpoint: Literal["large", "small", "tiny", "base_plus"] = "large", points_per_side: int = 32, points_per_batch: int = 64, - pred_iou_thresh: float = 0.7, + pred_iou_thresh: float = 0.5, stability_score_thresh: float = 0.85, crop_n_layers: int = 1, crop_n_points_downscale_factor: int = 1, @@ -329,12 +152,43 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: all_m.extend(mm) all_l.extend(ml) + log_diag = not hasattr(self, '_diag_count') or self._diag_count < 2 + if log_diag: + if not hasattr(self, '_diag_count'): + self._diag_count = 0 + self._diag_count += 1 + self._logger.debug( + "[diag] after predict: default=%d, s=%d, m=%d, l=%d", + len(all_default), len(all_s), len(all_m), len(all_l), + ) + # Cross-crop NMS (prefer smaller crops) if len(crop_boxes) > 1: - all_default = _cross_crop_nms(all_default, iou_threshold=self._box_nms_thresh) - all_s = _cross_crop_nms(all_s, iou_threshold=self._box_nms_thresh) - all_m = _cross_crop_nms(all_m, 
iou_threshold=self._box_nms_thresh) - all_l = _cross_crop_nms(all_l, iou_threshold=self._box_nms_thresh) + all_default = cross_crop_nms(all_default, iou_threshold=self._box_nms_thresh) + all_s = cross_crop_nms(all_s, iou_threshold=self._box_nms_thresh) + all_m = cross_crop_nms(all_m, iou_threshold=self._box_nms_thresh) + all_l = cross_crop_nms(all_l, iou_threshold=self._box_nms_thresh) + + if log_diag: + self._logger.debug( + "[diag] after cross-crop NMS: default=%d, s=%d, m=%d, l=%d", + len(all_default), len(all_s), len(all_m), len(all_l), + ) + + # Remove small disconnected regions and holes (matches original SAM's + # postprocess_small_regions step in generate_curr_anns) + if self._min_mask_region_area > 0: + nms_thresh = self._box_nms_thresh + all_default = postprocess_small_regions(all_default, self._min_mask_region_area, nms_thresh) + all_s = postprocess_small_regions(all_s, self._min_mask_region_area, nms_thresh) + all_m = postprocess_small_regions(all_m, self._min_mask_region_area, nms_thresh) + all_l = postprocess_small_regions(all_l, self._min_mask_region_area, nms_thresh) + + if log_diag: + self._logger.debug( + "[diag] after postprocess_small_regions: default=%d, s=%d, m=%d, l=%d", + len(all_default), len(all_s), len(all_m), len(all_l), + ) # Mask NMS per scale masks_default, masks_s, masks_m, masks_l = masks_update( @@ -347,6 +201,12 @@ def _generate_multi_scale_masks(self, image: np.ndarray) -> dict: inner_thr=self._nms_inner_thr, ) + if log_diag: + self._logger.debug( + "[diag] after masks_update: default=%d, s=%d, m=%d, l=%d", + len(masks_default), len(masks_s), len(masks_m), len(masks_l), + ) + return { "default": masks_default, "s": masks_s, @@ -370,6 +230,7 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: input_cache: SfmCache = input_scene.cache # Create cache folder + version_safe = self.version.replace(".", "_") cache_prefix = ( f"sam2_multi_scale_masks_{self._checkpoint}_" f"p{self._points_per_side}_" @@ -378,7 +239,8 @@ def 
__call__(self, input_scene: SfmScene) -> SfmScene: f"crop{self._crop_n_layers}_" f"nmsiou{int(self._nms_iou_thr * 100)}_" f"nmsscore{int(self._nms_score_thr * 100)}_" - f"nmsinner{int(self._nms_inner_thr * 100)}" + f"nmsinner{int(self._nms_inner_thr * 100)}_" + f"v{version_safe}" ) output_cache = input_cache.make_folder( cache_prefix, @@ -433,84 +295,87 @@ def __call__(self, input_scene: SfmScene) -> SfmScene: if regenerate_cache: self._logger.info("Generating multi-scale SAM2 masks for all images.") # Suppress SAM2's per-image INFO (e.g. "Computing image embeddings...") + # SAM2 logs through root, so we must suppress root. Setting our own + # logger level explicitly ensures propagated messages still reach + # root's handlers (propagation skips the parent's level check). _root = logging.getLogger() - _sam2 = logging.getLogger("sam2") _prev_root = _root.level - _prev_sam2 = _sam2.level - try: - _root.setLevel(logging.WARNING) - _sam2.setLevel(logging.WARNING) - except Exception: - # Silently ignore errors setting logging levels - pass - - pbar = tqdm.tqdm(input_scene.images, unit="imgs", desc="Generating SAM2 masks") - - for image_meta in pbar: - image_path = image_meta.image_path - img = cv2.imread(image_path) - assert img is not None, f"Failed to load image {image_path}" - - # Undistort the image if the camera has distortion parameters - img = image_meta.camera_metadata.undistort_image(img) - - # Generate multi-scale masks - masks_dict = self._generate_multi_scale_masks(img) - - # Convert masks to storable format - mask_data = {} - for scale_name, masks in masks_dict.items(): - if len(masks) > 0: - # Store segmentation masks and metadata - mask_data[f"{scale_name}_segmentations"] = np.stack( - [m["segmentation"].astype(np.uint8) for m in masks], axis=0 - ) - mask_data[f"{scale_name}_bboxes"] = np.array([m["bbox"] for m in masks], dtype=np.float32) - mask_data[f"{scale_name}_areas"] = np.array([m["area"] for m in masks], dtype=np.int32) - 
mask_data[f"{scale_name}_predicted_ious"] = np.array( - [m["predicted_iou"] for m in masks], dtype=np.float32 - ) - mask_data[f"{scale_name}_stability_scores"] = np.array( - [m["stability_score"] for m in masks], dtype=np.float32 - ) - else: - # Empty arrays for scales with no masks - mask_data[f"{scale_name}_segmentations"] = np.zeros( - (0, img.shape[0], img.shape[1]), dtype=np.uint8 - ) - mask_data[f"{scale_name}_bboxes"] = np.zeros((0, 4), dtype=np.float32) - mask_data[f"{scale_name}_areas"] = np.zeros(0, dtype=np.int32) - mask_data[f"{scale_name}_predicted_ious"] = np.zeros(0, dtype=np.float32) - mask_data[f"{scale_name}_stability_scores"] = np.zeros(0, dtype=np.float32) - - # Save to cache - cache_filename = f"masks_{image_meta.image_id:0{num_zeropad}}" - output_cache.write_file( - name=cache_filename, - data=mask_data, - data_type="pt", - metadata={ - "checkpoint": self._checkpoint, - "points_per_side": self._points_per_side, - "pred_iou_thresh": self._pred_iou_thresh, - "stability_score_thresh": self._stability_score_thresh, - "crop_n_layers": self._crop_n_layers, - "min_mask_region_area": self._min_mask_region_area, - "nms_iou_thr": self._nms_iou_thr, - "nms_score_thr": self._nms_score_thr, - "nms_inner_thr": self._nms_inner_thr, - }, - ) + _prev_self = self._logger.level + _sam2_model_logger = logging.getLogger("fvdb_reality_capture.foundation_models.sam2.SAM2Model") + _prev_sam2_model = _sam2_model_logger.level + self._logger.setLevel(logging.INFO) + _sam2_model_logger.setLevel(logging.INFO) + _root.setLevel(logging.WARNING) - pbar.close() - # Restore logging levels try: + pbar = tqdm.tqdm(input_scene.images, unit="imgs", desc="Generating SAM2 masks") + + for image_meta in pbar: + image_path = image_meta.image_path + img = cv2.imread(image_path) + assert img is not None, f"Failed to load image {image_path}" + + # Undistort the image if the camera has distortion parameters + img = image_meta.camera_metadata.undistort_image(img) + + # Generate multi-scale 
masks + masks_dict = self._generate_multi_scale_masks(img) + + # Convert masks to storable format + mask_data = {} + for scale_name, masks in masks_dict.items(): + if len(masks) > 0: + # Store segmentation masks and metadata + mask_data[f"{scale_name}_segmentations"] = np.stack( + [m["segmentation"].astype(np.uint8) for m in masks], axis=0 + ) + mask_data[f"{scale_name}_bboxes"] = np.array( + [m["bbox"] for m in masks], dtype=np.float32 + ) + mask_data[f"{scale_name}_areas"] = np.array( + [m["area"] for m in masks], dtype=np.int32 + ) + mask_data[f"{scale_name}_predicted_ious"] = np.array( + [m["predicted_iou"] for m in masks], dtype=np.float32 + ) + mask_data[f"{scale_name}_stability_scores"] = np.array( + [m["stability_score"] for m in masks], dtype=np.float32 + ) + else: + # Empty arrays for scales with no masks + mask_data[f"{scale_name}_segmentations"] = np.zeros( + (0, img.shape[0], img.shape[1]), dtype=np.uint8 + ) + mask_data[f"{scale_name}_bboxes"] = np.zeros((0, 4), dtype=np.float32) + mask_data[f"{scale_name}_areas"] = np.zeros(0, dtype=np.int32) + mask_data[f"{scale_name}_predicted_ious"] = np.zeros(0, dtype=np.float32) + mask_data[f"{scale_name}_stability_scores"] = np.zeros(0, dtype=np.float32) + + # Save to cache + cache_filename = f"masks_{image_meta.image_id:0{num_zeropad}}" + output_cache.write_file( + name=cache_filename, + data=mask_data, + data_type="pt", + metadata={ + "checkpoint": self._checkpoint, + "points_per_side": self._points_per_side, + "pred_iou_thresh": self._pred_iou_thresh, + "stability_score_thresh": self._stability_score_thresh, + "crop_n_layers": self._crop_n_layers, + "min_mask_region_area": self._min_mask_region_area, + "nms_iou_thr": self._nms_iou_thr, + "nms_score_thr": self._nms_score_thr, + "nms_inner_thr": self._nms_inner_thr, + }, + ) + + pbar.close() + self._logger.info(f"Generated masks for {input_scene.num_images} images.") + finally: _root.setLevel(_prev_root) - _sam2.setLevel(_prev_sam2) - except Exception: - # 
Silently ignore errors restoring logging levels - pass - self._logger.info(f"Generated masks for {input_scene.num_images} images.") + self._logger.setLevel(_prev_self) + _sam2_model_logger.setLevel(_prev_sam2_model) else: self._logger.info("Loading masks from cache.") diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py index 2be8378..86eedf7 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/dataset.py @@ -98,11 +98,24 @@ def warmup_cache(self) -> None: """ import tqdm + n_empty = 0 if self._cache_features: for idx in tqdm.tqdm(range(len(self)), desc="Warming up feature cache"): index = self._indices[idx] - if index not in self._features_cache: - self.get_feature_data(index) + feature_data = self._features_cache.get(index) + if feature_data is None: + feature_data = self.get_feature_data(index) + _, seg_map, _ = feature_data + if not (seg_map >= 0).any(): + n_empty += 1 + + if n_empty > 0: + logger.warning( + "%d / %d training images (%.1f%%) have NO valid masks at level %d " + "-- these contribute zero gradient during training", + n_empty, len(self), 100 * n_empty / len(self), + self.feature_level, + ) if self._cache_images: for idx in tqdm.tqdm(range(len(self)), desc="Warming up image cache"): @@ -217,9 +230,7 @@ def build_feature_map( # Unbatched path: plain tensor [N_masks, clip_n_dims] H, W = seg_map.shape feature_mask = seg_map >= 0 - # Use empty -- invalid pixels are never read (the loss gathers only - # valid pixels via the mask). 
- gt_features = torch.empty( + gt_features = torch.zeros( H, W, clip_n_dims, dtype=features.dtype, device=features.device ) if feature_mask.any(): @@ -250,9 +261,7 @@ def build_feature_map( dtype = features.jdata.dtype feature_mask = seg_map >= 0 # [B, H, W] - # Use empty -- invalid pixels are never read (the loss gathers only - # valid pixels via the mask). - gt_features = torch.empty( + gt_features = torch.zeros( B, H, W, clip_n_dims, dtype=dtype, device=device ) for b in range(B): diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py index 706b835..25c614d 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/langsplatv2_writer.py @@ -312,6 +312,26 @@ def save_checkpoint( torch.save(checkpoint, ckpt_path) self._logger.info(f"Saved checkpoint to {ckpt_path}") + @torch.no_grad() + def save_final_checkpoint(self, checkpoint: dict[str, Any]) -> pathlib.Path | None: + """Save a ``final_checkpoint.pt`` at the run's top-level directory. + + This makes it easy for evaluation scripts to find the finished + checkpoint without globbing through step-numbered subdirectories. + + Args: + checkpoint: Checkpoint dictionary (from ``Trainer.state_dict()``). + + Returns: + Path to the saved file, or *None* if saving is disabled. 
+ """ + if not self._config.save_checkpoints or self._save_path is None: + return None + final_path = self._save_path / "final_checkpoint.pt" + torch.save(checkpoint, final_path) + self._logger.info(f"Saved final checkpoint to {final_path}") + return final_path + # ------------------------------------------------------------------ # Helpers (matching GARfVDB writer helpers) # ------------------------------------------------------------------ diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py index 9885278..5859278 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/training/trainer.py @@ -25,6 +25,7 @@ from ..util import calculate_pca_projection, cosine_error_map, pca_projection_fast from ..vq_utils import ( ResidualVectorQuantizationWithClustering, + load_all_clip_features, load_clip_features_for_level, ) from .dataset import ( @@ -226,7 +227,6 @@ def new( ) # Initialize codebooks via K-means on CLIP features - logger.info("Loading CLIP features for codebook initialization...") all_dataset = LangSplatV2Dataset( sfm_scene=sfm_scene, feature_level=config.feature_level, @@ -234,10 +234,15 @@ def new( cache_features=False, cache_images=False, ) - clip_features = load_clip_features_for_level( - full_dataset=all_dataset, - feature_level=config.feature_level, - ) + if config.init_codebooks_all_levels: + logger.info("Loading CLIP features from ALL scale levels for codebook initialization...") + clip_features = load_all_clip_features(full_dataset=all_dataset) + else: + logger.info(f"Loading CLIP features for level {config.feature_level} for codebook initialization...") + clip_features = load_clip_features_for_level( + full_dataset=all_dataset, + feature_level=config.feature_level, + ) logger.info(f"Loaded {clip_features.shape[0]:,} CLIP features of dimension 
{clip_features.shape[1]}") # Run K-means to get initial codebooks diff --git a/open_vocabulary_segmentation/langsplatv2/langsplatv2/vq_utils.py b/open_vocabulary_segmentation/langsplatv2/langsplatv2/vq_utils.py index 17357f1..a352189 100644 --- a/open_vocabulary_segmentation/langsplatv2/langsplatv2/vq_utils.py +++ b/open_vocabulary_segmentation/langsplatv2/langsplatv2/vq_utils.py @@ -166,3 +166,32 @@ def load_clip_features_for_level( raise RuntimeError(f"No features found for level {feature_level}") return torch.cat(all_features, dim=0).float() + + +def load_all_clip_features( + full_dataset: LangSplatV2Dataset, +) -> torch.Tensor: + """Load CLIP features from ALL scale levels across all images. + + This matches the original LangSplatV2 codebook initialization which + concatenates every ``*_f.npy`` file (each containing features from all + 4 scales) into a single tensor for k-means. + + Args: + full_dataset: The full dataset containing all features. + + Returns: + Feature tensor of shape ``[N_all, clip_n_dims]`` containing + features from all scale levels and all images. 
+ """ + all_features = [] + + for image_id in range(len(full_dataset)): + features, _, _ = full_dataset.get_feature_data(image_id) + if features.shape[0] > 0: + all_features.append(features) + + if len(all_features) == 0: + raise RuntimeError("No features found across any images") + + return torch.cat(all_features, dim=0).float() diff --git a/open_vocabulary_segmentation/langsplatv2/pyproject.toml b/open_vocabulary_segmentation/langsplatv2/pyproject.toml index 47437f2..ee81bc6 100644 --- a/open_vocabulary_segmentation/langsplatv2/pyproject.toml +++ b/open_vocabulary_segmentation/langsplatv2/pyproject.toml @@ -7,17 +7,21 @@ name = "fvdb-langsplatv2" version = "0.0.1" description = "fVDB Implementation of LangSplatV2" readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.10" dependencies = [ + "fvdb-reality-capture", + "gdown", + "matplotlib", "numpy", - "torch", - "torchvision", "opencv-python", - "tqdm", "open-clip-torch", - "fvdb-reality-capture", + "segment-anything", + "sam2", "scikit-learn", + "torch", + "torchvision", + "tqdm", "tyro", ] diff --git a/open_vocabulary_segmentation/langsplatv2/scripts/query_prompt.py b/open_vocabulary_segmentation/langsplatv2/scripts/query_prompt.py new file mode 100644 index 0000000..7369eb7 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/scripts/query_prompt.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Visualise the relevancy mask for a text prompt on a trained LangSplatV2 model. + +Loads a single-level checkpoint, renders CLIP features for a chosen camera +view, computes OpenCLIP relevancy against the user's prompt, and writes an +image showing the source photograph, the relevancy heatmap, and the +thresholded binary mask side-by-side. 
+ +Usage: + python scripts/query_prompt.py \\ + --checkpoint langsplatv2_results/scene_level_2.pt \\ + --reconstruction-path reconstructions/scene.ply \\ + --prompt "coffee cup" + + # Specify a different view and output path: + python scripts/query_prompt.py \\ + --checkpoint langsplatv2_results/scene_level_1.pt \\ + --reconstruction-path reconstructions/scene.ply \\ + --prompt "wooden table" \\ + --image-index 42 \\ + --output table_query.jpg +""" +import argparse +import logging +import pathlib + +import cv2 +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +from langsplatv2.evaluation.openclip_relevancy import OpenCLIPRelevancy +from langsplatv2.training.trainer import LangSplatV2Trainer +from langsplatv2.util import load_splats_from_file + +matplotlib.use("Agg") + +logger = logging.getLogger(__name__) + + +def _load_model(checkpoint_path, gs_model_path, device, eval_topk=None): + """Load a trained LangSplatV2 model and its embedded SfmScene.""" + gs_model, _ = load_splats_from_file(gs_model_path, device) + state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False) + trainer = LangSplatV2Trainer.from_state_dict( + state_dict=state_dict, + gs_model=gs_model, + gs_model_path=gs_model_path, + device=device, + eval_only=True, + ) + model = trainer._model + sfm_scene = trainer._sfm_scene + feature_level = trainer._cfg.feature_level + + if eval_topk is not None and model.topk != eval_topk: + logger.info(f"Overriding topk: {model.topk} -> {eval_topk}") + model.topk = eval_topk + + logger.info( + f"Loaded model (feature_level={feature_level}, topk={model.topk}) " f"with {gs_model.num_gaussians:,} Gaussians" + ) + return model, sfm_scene, feature_level + + +def _get_camera(sfm_scene, image_index, device): + """Return (w2c, K, h, w, image_path) for the given image index.""" + sorted_images = sorted(sfm_scene.images, key=lambda img: img.image_path) + if image_index >= len(sorted_images): + raise 
IndexError(f"--image-index {image_index} out of range " f"(scene has {len(sorted_images)} images)") + img = sorted_images[image_index] + c2w = torch.from_numpy(img.camera_to_world_matrix).float() + K = torch.from_numpy(img.camera_metadata.projection_matrix).float() + w2c = torch.linalg.inv(c2w).contiguous() + return ( + w2c.unsqueeze(0).to(device), + K.unsqueeze(0).to(device), + img.camera_metadata.height, + img.camera_metadata.width, + img.image_path, + ) + + +@torch.no_grad() +def _render_clip(model, w2c, K, width, height): + """Render normalised CLIP features [H, W, 512].""" + feat_maps, _ = model( + world_to_camera=w2c, + projection=K, + image_width=width, + image_height=height, + ) + feat = feat_maps[0] # [H, W, 512] + return feat / (feat.norm(dim=-1, keepdim=True) + 1e-10) + + +def _smooth_mask(mask: torch.Tensor) -> torch.Tensor: + pool = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=3, count_include_pad=False) + pool = pool.to(mask.device) + smoothed = pool(mask.float().unsqueeze(0).unsqueeze(0)) + return (smoothed > 0.5).to(torch.uint8).squeeze(0).squeeze(0) + + +def _relevancy_to_mask(relevancy_2d: torch.Tensor, thresh: float) -> torch.Tensor: + """Convert a raw relevancy map [H, W] to a binary mask. + + Applies the same AvgPool(29) blending, normalisation, and thresholding + that the original LangSplatV2 evaluation uses. 
+ """ + pool = torch.nn.AvgPool2d(kernel_size=29, stride=1, padding=14, count_include_pad=False) + pool = pool.to(relevancy_2d.device) + blended = 0.5 * (pool(relevancy_2d.unsqueeze(0).unsqueeze(0)).squeeze() + relevancy_2d) + blended = blended - blended.min() + blended = blended / (blended.max() + 1e-9) + blended = blended * 2.0 - 1.0 + blended = torch.clip(blended, 0, 1) + return _smooth_mask((blended > thresh).to(torch.uint8)) + + +def _save_visualization( + output_path: pathlib.Path, + prompt: str, + feature_level: int, + relevancy: np.ndarray, + mask: np.ndarray, + source_img: np.ndarray | None, + image_index: int, +): + """Write a side-by-side panel image to *output_path*.""" + has_source = source_img is not None + n_cols = 3 if has_source else 2 + fig, axes = plt.subplots(1, n_cols, figsize=(8 * n_cols, 8)) + + col = 0 + + if has_source: + overlay = source_img.astype(np.float32) / 255.0 + mask_rgb = np.zeros_like(overlay) + mask_rgb[:, :, 0] = mask.astype(np.float32) + blended = np.clip(overlay * 0.6 + mask_rgb * 0.4, 0, 1) + axes[col].imshow(blended) + axes[col].set_title(f"Source (image {image_index}) + mask overlay", fontsize=12) + axes[col].axis("off") + col += 1 + + im = axes[col].imshow(relevancy, cmap="turbo", vmin=0, vmax=1) + axes[col].set_title(f"Relevancy (level {feature_level})", fontsize=12) + axes[col].axis("off") + plt.colorbar(im, ax=axes[col], fraction=0.046, pad=0.04) + col += 1 + + axes[col].imshow(mask, cmap="gray", vmin=0, vmax=1) + axes[col].set_title("Binary mask", fontsize=12) + axes[col].axis("off") + + fig.suptitle(f'"{prompt}"', fontsize=16, fontweight="bold") + fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(output_path), dpi=150, bbox_inches="tight") + plt.close(fig) + logger.info(f"Saved visualisation to {output_path}") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Visualise relevancy for a text prompt on a trained LangSplatV2 model.", + 
formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--checkpoint", + type=pathlib.Path, + required=True, + help="Path to a trained LangSplatV2 checkpoint (.pt) for one feature level.", + ) + parser.add_argument( + "--reconstruction-path", + type=pathlib.Path, + required=True, + help="Path to the Gaussian splat reconstruction (.ply or .pt).", + ) + parser.add_argument( + "--prompt", + type=str, + required=True, + help="Text query to compute relevancy for (e.g. 'coffee cup').", + ) + parser.add_argument( + "--image-index", + type=int, + default=0, + help="Which dataset image (camera view) to render.", + ) + parser.add_argument( + "--output", + type=pathlib.Path, + default=None, + help="Output image path. Defaults to '_level_img.jpg'.", + ) + parser.add_argument( + "--mask-thresh", + type=float, + default=0.4, + help="Threshold for converting relevancy to a binary mask.", + ) + parser.add_argument( + "--eval-topk", + type=int, + default=None, + help="Override the checkpoint's topk value for codebook decoding.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Torch device.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable debug logging.", + ) + + args = parser.parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s") + + device = torch.device(args.device) + + model, sfm_scene, feature_level = _load_model( + args.checkpoint, + args.reconstruction_path, + device, + args.eval_topk, + ) + + w2c, K, img_h, img_w, image_path = _get_camera(sfm_scene, args.image_index, device) + logger.info(f"Rendering view {args.image_index}: {image_path} ({img_w}x{img_h})") + + source_img = None + if image_path and pathlib.Path(image_path).is_file(): + bgr = cv2.imread(str(image_path)) + if bgr is not None: + source_img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + + clip_feat = _render_clip(model, 
w2c, K, img_w, img_h) + + clip_relevancy = OpenCLIPRelevancy(device=args.device) + clip_relevancy.set_positives([args.prompt]) + + # get_relevancy_map expects [n_levels, H, W, 512] + relevancy = clip_relevancy.get_relevancy_map(clip_feat.unsqueeze(0)) # [1, 1, H, W] + relevancy_2d = relevancy[0, 0] # [H, W] + + mask = _relevancy_to_mask(relevancy_2d, args.mask_thresh) + + if args.output is None: + safe_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:40] + args.output = pathlib.Path(f"{safe_prompt}_level{feature_level}_img{args.image_index}.jpg") + + _save_visualization( + args.output, + args.prompt, + feature_level, + relevancy_2d.cpu().numpy(), + mask.cpu().numpy(), + source_img, + args.image_index, + ) + + print(f"\nWrote {args.output}") + + +if __name__ == "__main__": + main() diff --git a/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py b/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py new file mode 100644 index 0000000..dcc11c4 --- /dev/null +++ b/open_vocabulary_segmentation/langsplatv2/scripts/train_all_levels.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Train all 3 LangSplatV2 feature levels for a single scene. + +Wraps ``train_langsplatv2.py``, running it once per feature level and +collecting the final checkpoints into a results directory ready for +evaluation. + +Any arguments not recognised by this wrapper are forwarded verbatim to +``train_langsplatv2.py`` (e.g. ``--config.max-steps``, +``--preprocess.sam-model``, ``--dataset-type``). 
+ +Usage: + # Minimal -- trains levels 1, 2, 3 with default settings: + python scripts/train_all_levels.py \\ + --sfm-dataset-path /path/to/colmap/scene \\ + --reconstruction-path /path/to/scene.ply + + # With extra training options: + python scripts/train_all_levels.py \\ + --sfm-dataset-path /data/my_scene \\ + --reconstruction-path /data/my_scene.ply \\ + --name my_scene \\ + --results-dir my_results \\ + --config.max-steps 10000 \\ + --preprocess.sam-model sam1 +""" +import argparse +import pathlib +import shutil +import subprocess +import sys + +SCRIPT_DIR = pathlib.Path(__file__).resolve().parent +TRAIN_SCRIPT = SCRIPT_DIR / "train_langsplatv2.py" + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Train LangSplatV2 at all 3 feature levels and collect checkpoints.", + epilog=( + "Unrecognised arguments are forwarded to train_langsplatv2.py. " + "For example: --config.max-steps 10000 --preprocess.sam-model sam1" + ), + ) + parser.add_argument( + "--sfm-dataset-path", + type=pathlib.Path, + required=True, + help="Path to the SfM dataset (COLMAP, simple_directory, or E57).", + ) + parser.add_argument( + "--reconstruction-path", + type=pathlib.Path, + required=True, + help="Path to the pre-trained Gaussian splat reconstruction (.ply or .pt).", + ) + parser.add_argument( + "--results-dir", + type=pathlib.Path, + default=pathlib.Path("langsplatv2_results"), + help="Directory to collect final checkpoints into (default: langsplatv2_results).", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Scene name used for run directories and checkpoint files. 
" + "Defaults to the reconstruction file stem.", + ) + parser.add_argument( + "--log-path", + type=pathlib.Path, + default=pathlib.Path("langsplatv2_logs"), + help="Directory for per-run training logs and checkpoints (default: langsplatv2_logs).", + ) + parser.add_argument( + "--levels", + type=int, + nargs="+", + default=[1, 2, 3], + help="Feature levels to train (default: 1 2 3).", + ) + + args, extra = parser.parse_known_args() + name = args.name or args.reconstruction_path.stem + + failed_levels: list[int] = [] + + total_levels = len(args.levels) + for idx, level in enumerate(args.levels, 1): + run_name = f"{name}_level_{level}" + cmd = [ + sys.executable, + str(TRAIN_SCRIPT), + "--sfm-dataset-path", str(args.sfm_dataset_path), + "--reconstruction-path", str(args.reconstruction_path), + "--config.feature-level", str(level), + "--run-name", run_name, + "--log-path", str(args.log_path), + *extra, + ] + + print(f"\n{'=' * 60}") + print(f" Training level {level} ({idx}/{total_levels}): {run_name}") + print(f"{'=' * 60}\n") + + result = subprocess.run(cmd) + if result.returncode != 0: + print(f"\nERROR: Training failed for level {level} (exit code {result.returncode})") + failed_levels.append(level) + + # -- Collect checkpoints ------------------------------------------------ + args.results_dir.mkdir(parents=True, exist_ok=True) + collected = 0 + + for level in args.levels: + if level in failed_levels: + continue + run_name = f"{name}_level_{level}" + src = args.log_path / run_name / "final_checkpoint.pt" + dst = args.results_dir / f"{run_name}.pt" + if src.exists(): + shutil.copy2(src, dst) + print(f" Collected: {src} -> {dst}") + collected += 1 + else: + print(f" WARNING: {src} not found, skipping") + + # -- Summary ------------------------------------------------------------ + print(f"\n{'=' * 60}") + print(f" Done -- {collected}/{len(args.levels)} checkpoints in {args.results_dir}") + if failed_levels: + print(f" Failed levels: {failed_levels}") + print(f"{'=' * 
60}") + + if failed_levels: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/open_vocabulary_segmentation/langsplatv2/train_langsplatv2.py b/open_vocabulary_segmentation/langsplatv2/scripts/train_langsplatv2.py similarity index 95% rename from open_vocabulary_segmentation/langsplatv2/train_langsplatv2.py rename to open_vocabulary_segmentation/langsplatv2/scripts/train_langsplatv2.py index 49f8145..d39ade6 100644 --- a/open_vocabulary_segmentation/langsplatv2/train_langsplatv2.py +++ b/open_vocabulary_segmentation/langsplatv2/scripts/train_langsplatv2.py @@ -136,10 +136,16 @@ def main( runner.train() + # Save a final_checkpoint.pt at the top level of the run directory + # so evaluation scripts can find it without globbing step subdirectories + final_path = writer.save_final_checkpoint(runner.state_dict()) + logger.info("=" * 60) logger.info("Training complete!") if writer.log_path is not None: logger.info(f"Results saved to {writer.log_path}") + if final_path is not None: + logger.info(f"Final checkpoint: {final_path}") logger.info("=" * 60) writer.close()