From 728bdc51e4b22563cbb840acbbc36726529fdace Mon Sep 17 00:00:00 2001 From: Zoe Date: Wed, 10 Jun 2026 16:37:18 -0400 Subject: [PATCH] code to rip segements from splat using garfvdb Signed-off-by: Zoe --- .../garfvdb/extract_segments.py | 466 ++++++++++++++++++ .../extract_mesh_for_segments.py | 322 ++++++++++++ 2 files changed, 788 insertions(+) create mode 100644 instance_segmentation/garfvdb/extract_segments.py create mode 100644 segmenation_extraction/extract_mesh_for_segments.py diff --git a/instance_segmentation/garfvdb/extract_segments.py b/instance_segmentation/garfvdb/extract_segments.py new file mode 100644 index 0000000..99b5868 --- /dev/null +++ b/instance_segmentation/garfvdb/extract_segments.py @@ -0,0 +1,466 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Export instance segments from a trained GARfVDB segmentation checkpoint. + +Loads a Gaussian splat reconstruction and segmentation checkpoint, clusters +per-Gaussian affinity features at a chosen scale, filters clusters, then writes +``n`` segment ``.ply`` files (one per selected cluster). + +Example:: + + python extract_segments.py \\ + -s garfvdb_logs/run/checkpoints/00036600/train_ckpt.pt \\ + -r frgs_logs/safety_park_1/checkpoints/00024800/reconstruct_ckpt.pt \\ + -o segments/ \\ + --n 5 \\ + --scale 0.1 +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path +from typing import Literal + +import numpy as np +import torch +from fvdb import GaussianSplat3d +from scipy.spatial import cKDTree + +# Import directly to avoid fvdb_reality_capture.tools.__init__ pulling optional deps (e.g. DLNR). 
+from fvdb_reality_capture.tools._filter_splats import ( + filter_splats_above_scale, + filter_splats_by_mean_percentile, + filter_splats_by_opacity_percentile, +) +from garfvdb.training.segmentation import GaussianSplatScaleConditionedSegmentation +from garfvdb.util import load_splats_from_file + +logger = logging.getLogger(__name__) + + +def load_segmentation_runner_from_checkpoint( + checkpoint_path: Path, + gs_model: GaussianSplat3d, + gs_model_path: Path, + device: str | torch.device = "cuda", +) -> GaussianSplatScaleConditionedSegmentation: + """Restore a segmentation runner from a training checkpoint. + Args: + checkpoint_path: Path to the segmentation checkpoint. + gs_model: The Gaussian splat model to use for the segmentation. + gs_model_path: Path to the Gaussian splat reconstruction. + device: The device to use for the segmentation. + """ + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + runner = GaussianSplatScaleConditionedSegmentation.from_state_dict( + state_dict=checkpoint, + gs_model=gs_model, + gs_model_path=gs_model_path, + device=device, + eval_only=True, + ) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return runner + + +def _is_gpu_oom_error(exc: BaseException) -> bool: + """Return whether an exception indicates GPU out-of-memory. + Args: + exc: The exception raised during clustering. + """ + return "out_of_memory" in str(exc).lower() or isinstance(exc, MemoryError) + + +def _drop_clusters( + cluster_splats: dict[int, GaussianSplat3d], + cluster_coherence: dict[int, float], + keys: list[int], +) -> None: + """Remove clusters from the splat and coherence maps in place. + Args: + cluster_splats: Cluster label to Gaussian splat mapping. + cluster_coherence: Cluster label to coherence score mapping. + keys: Cluster labels to remove. 
+ """ + for key in keys: + cluster_splats.pop(key, None) + cluster_coherence.pop(key, None) + + +def _subsample_gaussians( + gs_model: GaussianSplat3d, + max_gaussians: int, + seed: int, + device: torch.device, +) -> tuple[GaussianSplat3d, torch.Tensor]: + """Randomly subsample Gaussians for memory-bounded clustering. + Args: + gs_model: Full-scene Gaussian splat model. + max_gaussians: Maximum number of Gaussians to keep. + seed: Random seed for reproducible subsampling. + device: Torch device for the returned mask. + """ + rng = np.random.default_rng(seed) + indices = rng.choice(gs_model.num_gaussians, size=max_gaussians, replace=False) + mask = torch.zeros(gs_model.num_gaussians, dtype=torch.bool, device=device) + mask[torch.from_numpy(indices).to(device)] = True + return gs_model[mask], mask + + +def _map_cluster_labels_to_full_scene( + cluster_labels_sub: torch.Tensor, + cluster_probs_sub: torch.Tensor, + gs_model: GaussianSplat3d, + clustering_gs_model: GaussianSplat3d, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """Map cluster labels from a subsampled set back to the full scene. + Args: + cluster_labels_sub: Cluster labels on the subsampled Gaussians. + cluster_probs_sub: Cluster probabilities on the subsampled Gaussians. + gs_model: Full-scene Gaussian splat model. + clustering_gs_model: Subsampled model used for clustering. + device: Torch device for returned tensors. + """ + tree = cKDTree(clustering_gs_model.means.cpu().numpy()) + _, nearest_indices = tree.query(gs_model.means.cpu().numpy(), k=1, workers=-1) + index_tensor = torch.from_numpy(nearest_indices).to(device) + return cluster_labels_sub[index_tensor], cluster_probs_sub[index_tensor] + + +def _filter_high_variance_clusters( + cluster_splats: dict[int, GaussianSplat3d], + cluster_coherence: dict[int, float], + variance_threshold: float, +) -> list[int]: + """Find spatially incoherent clusters by normalized variance. 
+ Args: + cluster_splats: Cluster label to Gaussian splat mapping. + cluster_coherence: Cluster label to coherence score mapping. + variance_threshold: Normalized variance cutoff (variance / extent^2). + """ + removed: list[int] = [] + for label, splat in list(cluster_splats.items()): + means = splat.means + extent = (means.max(dim=0).values - means.min(dim=0).values).max().item() + if extent > 1e-6: + norm_variance = means.var(dim=0).mean().item() / (extent**2) + else: + norm_variance = 0.0 + if norm_variance > variance_threshold: + removed.append(label) + return removed + + +@torch.inference_mode() +def rip_segments( + *, + segmentation_path: Path, + reconstruction_path: Path, + out_dir: Path, + n: int, + scale: float, + scale_is_fraction_of_max: bool, + seed: int, + device: str, + min_splat_scale: float, + opacity_percentile: float, + mean_percentile: tuple[float, ...], + min_cluster_gaussians: int, + filter_high_variance: bool, + variance_threshold: float, + sample_by: Literal["random", "coherence"], + max_gaussians_for_clustering: int, + verbose: bool, +) -> None: + """Cluster Gaussians and export ``n`` segment PLY files.""" + log_level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s") + + # Clustering depends on GPU libs (cuml/cupy); import lazily so --help works without them. 
+ from garfvdb.evaluation.clustering import ( # noqa: PLC0415 + compute_cluster_labels, + split_gaussians_into_clusters, + ) + + device_t = torch.device(device) + + if not segmentation_path.exists(): + raise FileNotFoundError(f"Segmentation checkpoint not found: {segmentation_path}") + if not reconstruction_path.exists(): + raise FileNotFoundError(f"Reconstruction checkpoint not found: {reconstruction_path}") + if n <= 0: + raise ValueError(f"--n must be > 0, got {n}") + + out_dir.mkdir(parents=True, exist_ok=True) + + logger.info("Loading Gaussian splat model from %s", reconstruction_path) + gs_model, original_metadata = load_splats_from_file(reconstruction_path, device_t) + logger.info("Loaded %s Gaussians (pre-filter)", f"{gs_model.num_gaussians:,}") + + if min_splat_scale > 0: + gs_model = filter_splats_above_scale(gs_model, min_splat_scale) + if opacity_percentile > 0: + gs_model = filter_splats_by_opacity_percentile(gs_model, percentile=opacity_percentile) + if mean_percentile: + gs_model = filter_splats_by_mean_percentile(gs_model, percentile=list(mean_percentile)) + logger.info("Remaining %s Gaussians (post-filter)", f"{gs_model.num_gaussians:,}") + + runner = load_segmentation_runner_from_checkpoint( + checkpoint_path=segmentation_path, + gs_model=gs_model, + gs_model_path=reconstruction_path, + device=device_t, + ) + gs_model = runner.gs_model + segmentation_model = runner.model + + max_scale = float(segmentation_model.max_grouping_scale.item()) + scale_abs = float(scale) * max_scale if scale_is_fraction_of_max else float(scale) + logger.info("Segmentation model max scale: %.6f", max_scale) + logger.info("Clustering at scale: %.6f", scale_abs) + + clustering_gs_model = gs_model + subsample_mask: torch.Tensor | None = None + if max_gaussians_for_clustering > 0 and gs_model.num_gaussians > max_gaussians_for_clustering: + logger.warning( + "Scene has %s gaussians (> %s); subsampling for clustering, then mapping labels back.", + f"{gs_model.num_gaussians:,}", 
+ f"{max_gaussians_for_clustering:,}", + ) + clustering_gs_model, subsample_mask = _subsample_gaussians( + gs_model, max_gaussians_for_clustering, seed, device_t + ) + logger.info( + "Clustering on %s subsampled gaussians", + f"{clustering_gs_model.num_gaussians:,}", + ) + + mask_features = segmentation_model.get_gaussian_affinity_output(scale_abs) + if subsample_mask is not None: + mask_features = mask_features[subsample_mask] + + try: + cluster_labels_sub, cluster_probs_sub = compute_cluster_labels(mask_features, device=device_t) + except Exception as exc: + if not _is_gpu_oom_error(exc): + raise + logger.error( + "GPU OOM while clustering %s gaussians. Try tighter pre-filters or a lower " + "--max-gaussians-for-clustering (current: %s). Target: <500k for clustering.", + f"{clustering_gs_model.num_gaussians:,}", + f"{max_gaussians_for_clustering:,}", + ) + raise RuntimeError( + f"Clustering failed due to GPU memory ({clustering_gs_model.num_gaussians:,} gaussians)." + ) from exc + + if subsample_mask is not None: + logger.info("Mapping cluster labels back to all gaussians via nearest neighbor...") + cluster_labels, cluster_probs = _map_cluster_labels_to_full_scene( + cluster_labels_sub, + cluster_probs_sub, + gs_model, + clustering_gs_model, + device_t, + ) + else: + cluster_labels, cluster_probs = cluster_labels_sub, cluster_probs_sub + + cluster_splats, cluster_coherence, _ = split_gaussians_into_clusters(cluster_labels, cluster_probs, gs_model) + + if min_cluster_gaussians > 0: + removed_small = [ + label for label, splat in cluster_splats.items() if splat.num_gaussians < min_cluster_gaussians + ] + _drop_clusters(cluster_splats, cluster_coherence, removed_small) + if removed_small: + logger.info( + "Removed %d clusters with < %d gaussians", + len(removed_small), + min_cluster_gaussians, + ) + + if filter_high_variance: + removed_variance = _filter_high_variance_clusters(cluster_splats, cluster_coherence, variance_threshold) + _drop_clusters(cluster_splats, 
cluster_coherence, removed_variance) + if removed_variance: + logger.info( + "Removed %d spatially incoherent clusters (variance_threshold=%.4f)", + len(removed_variance), + variance_threshold, + ) + + if not cluster_splats: + raise RuntimeError("No clusters remaining after filtering; relax filters or try a different --scale.") + + labels = sorted(cluster_splats.keys()) + n_to_export = min(n, len(labels)) + if n_to_export < n: + logger.warning( + "Requested n=%d but only %d clusters available; exporting %d.", + n, + len(labels), + n_to_export, + ) + + if sample_by == "coherence": + chosen_labels = [ + label + for label, _ in sorted(cluster_coherence.items(), key=lambda item: item[1], reverse=True)[:n_to_export] + ] + else: + rng = np.random.default_rng(seed) + chosen_labels = rng.choice(np.array(labels, dtype=np.int64), size=n_to_export, replace=False).tolist() + + for i, label in enumerate(chosen_labels): + splat = cluster_splats[int(label)] + coherence = float(cluster_coherence[int(label)]) + ply_path = out_dir / (f"segment_{i:04d}_cluster{int(label)}_coh{coherence:.3f}_n{splat.num_gaussians}.ply") + if ply_path.exists(): + logger.warning("Overwriting existing file: %s", ply_path) + + metadata: dict = {} + if original_metadata: + metadata.update(original_metadata) + metadata.update( + { + "cluster_id": int(label), + "coherence": coherence, + "num_gaussians": int(splat.num_gaussians), + "scale_abs": scale_abs, + "segmentation_ckpt": str(segmentation_path), + "reconstruction": str(reconstruction_path), + } + ) + + splat.save_ply(str(ply_path), metadata=metadata) + logger.info("Wrote %s (%s gaussians)", ply_path, f"{splat.num_gaussians:,}") + + +def main() -> None: + """Parse CLI arguments and run segment extraction.""" + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "-s", + "--segmentation-path", + type=Path, + required=True, + help="GARfVDB segmentation checkpoint 
(.pt / .pth)", + ) + parser.add_argument( + "-r", + "--reconstruction-path", + type=Path, + required=True, + help="Gaussian splat reconstruction (.pt / .ply)", + ) + parser.add_argument( + "-o", + "--out-dir", + type=Path, + required=True, + help="Output directory for segment PLY files", + ) + + parser.add_argument("--n", type=int, default=10, help="Number of segments to export") + parser.add_argument( + "--scale", + type=float, + default=0.1, + help="Clustering scale (absolute or fraction of max; see --scale-is-fraction-of-max)", + ) + parser.add_argument( + "--scale-is-fraction-of-max", + action=argparse.BooleanOptionalAction, + default=True, + help="Interpret --scale as a fraction of model.max_grouping_scale", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling/subsampling") + parser.add_argument("--device", type=str, default="cuda", help="Torch device") + parser.add_argument("--verbose", action="store_true", help="Enable debug logging") + + parser.add_argument( + "--min-splat-scale", + type=float, + default=0.1, + help="Drop Gaussians with scale below this value (0 disables)", + ) + parser.add_argument( + "--opacity-percentile", + type=float, + default=0.85, + help="Keep Gaussians above this opacity percentile (0 disables)", + ) + parser.add_argument( + "--mean-percentile", + type=float, + nargs="*", + default=[0.96, 0.96, 0.96, 0.96, 0.98, 0.99], + help="Per-channel mean percentiles for splat pre-filtering", + ) + + parser.add_argument( + "--min-cluster-gaussians", + type=int, + default=200, + help="Drop clusters smaller than this (0 disables)", + ) + parser.add_argument( + "--filter-high-variance", + action=argparse.BooleanOptionalAction, + default=True, + help="Remove spatially incoherent clusters", + ) + parser.add_argument( + "--variance-threshold", + type=float, + default=0.1, + help="Normalized variance cutoff when --filter-high-variance is enabled", + ) + parser.add_argument( + "--sample-by", + choices=["random", 
"coherence"], + default="random", + help="How to pick which clusters to export", + ) + parser.add_argument( + "--max-gaussians-for-clustering", + type=int, + default=500_000, + help="Subsample before clustering when scene is larger (0 disables)", + ) + + args = parser.parse_args() + + rip_segments( + segmentation_path=args.segmentation_path, + reconstruction_path=args.reconstruction_path, + out_dir=args.out_dir, + n=args.n, + scale=args.scale, + scale_is_fraction_of_max=args.scale_is_fraction_of_max, + seed=args.seed, + device=args.device, + min_splat_scale=args.min_splat_scale, + opacity_percentile=args.opacity_percentile, + mean_percentile=tuple(float(x) for x in args.mean_percentile), + min_cluster_gaussians=args.min_cluster_gaussians, + filter_high_variance=args.filter_high_variance, + variance_threshold=args.variance_threshold, + sample_by=args.sample_by, + max_gaussians_for_clustering=args.max_gaussians_for_clustering, + verbose=args.verbose, + ) + + +if __name__ == "__main__": + main() diff --git a/segmenation_extraction/extract_mesh_for_segments.py b/segmenation_extraction/extract_mesh_for_segments.py new file mode 100644 index 0000000..566face --- /dev/null +++ b/segmenation_extraction/extract_mesh_for_segments.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +""" +Pulls mesh segment matching GS segment from a large mesh +Requires frgs mesh-dlnr or similar to be run on full scene first +Optionally closes holes via harmonic Laplacian fill then make_mesh_watertight +Output can be used as input for fvdb-reality-capture/scripts/create_isaac_ready_files.py to make a USDZ +""" +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +import igl +import numpy as np +import point_cloud_utils as pcu +from fvdb import GaussianSplat3d +from scipy.spatial import cKDTree + + +def _reproject_vertex_colors( + target_vertices: np.ndarray, + 
source_vertices: np.ndarray, + source_colors: np.ndarray, +) -> np.ndarray: + """Copy colors onto new mesh vertices via nearest-neighbor lookup. + Args: + target_vertices: The vertices of the new mesh. + source_vertices: The vertices of the source mesh. + source_colors: The colors of the source mesh. + Returns: + The colors of the new mesh. + """ + finite_mask = np.isfinite(source_vertices).all(axis=1) + if not finite_mask.all(): + source_vertices = source_vertices[finite_mask] + source_colors = source_colors[finite_mask] + if source_vertices.shape[0] == 0: + raise ValueError("No finite source vertices available for color reprojection") + + tree = cKDTree(source_vertices) + query_vertices = np.asarray(target_vertices, dtype=np.float64) + bad = ~np.isfinite(query_vertices).all(axis=1) + if bad.any(): + query_vertices = query_vertices.copy() + query_vertices[bad] = source_vertices.mean(axis=0) + _, indices = tree.query(query_vertices, k=1, workers=-1) + return source_colors[indices] + + +def _has_vertex_colors(vertex_colors: np.ndarray | None) -> bool: + """Return True if the mesh has a non-empty per-vertex color array.""" + return vertex_colors is not None and vertex_colors.size > 0 and vertex_colors.shape[0] > 0 + + +def _normalize_vertex_colors(colors: np.ndarray) -> np.ndarray: + """Convert vertex colors to float RGB(A) in [0, 1] for point_cloud_utils I/O.""" + colors = np.asarray(colors) + if colors.dtype == np.uint8: + colors = colors.astype(np.float64) / 255.0 + else: + colors = colors.astype(np.float64) + if colors.size > 0 and colors.max() > 1.0: + colors = colors / 255.0 + return np.clip(colors, 0.0, 1.0).astype(np.float32) + + +def extract_segment_mesh( + *, + full_mesh_path: Path, + segment_ply_path: Path, + output_path: Path, + distance_threshold: float, + no_gap_fill: bool, + resolution: int, + device: str, + verbose: bool, +) -> None: + """Extract a mesh region corresponding to a Gaussian segment. 
+ + Args: + full_mesh_path: Path to the full scene mesh (.ply). + segment_ply_path: Path to the segment Gaussian splats (.ply). + output_path: Path to save the extracted segment mesh. + distance_threshold: Maximum distance from segment Gaussians to include mesh vertices (in world units). + no_gap_fill: Skip watertight + harmonic gap fill on the extracted patch + resolution: Manifold octree resolution for pcu.make_mesh_watertight (capped to segment size) + device: Device for loading Gaussian splats. + verbose: Enable verbose logging. + """ + log_level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s") + logger = logging.getLogger(__name__) + + # Load the full mesh + logger.info(f"Loading full scene mesh from {full_mesh_path}") + vertices, faces, vertex_colors = pcu.load_mesh_vfc(str(full_mesh_path)) + logger.info(f"Loaded mesh with {len(vertices):,} vertices and {len(faces):,} faces") + + # Load the segment Gaussians + logger.info(f"Loading segment Gaussians from {segment_ply_path}") + segment_splat, _ = GaussianSplat3d.from_ply(segment_ply_path, device=device) + segment_means = segment_splat.means.cpu().numpy() # [N, 3] + logger.info(f"Loaded {len(segment_means):,} Gaussians in segment") + + # Build KD-tree for fast nearest neighbor queries + logger.info("Building KD-tree for segment Gaussians...") + tree = cKDTree(segment_means) + + # Find mesh vertices within distance threshold of any segment Gaussian + logger.info(f"Finding mesh vertices within {distance_threshold:.3f} units of segment...") + distances, _ = tree.query(vertices, k=1, distance_upper_bound=distance_threshold) + close_vertex_mask = distances < distance_threshold + num_close_vertices = close_vertex_mask.sum() + + logger.info( + f"Found {num_close_vertices:,} vertices ({100 * num_close_vertices / len(vertices):.1f}%) " f"within threshold" + ) + + if num_close_vertices == 0: + raise ValueError( + f"No mesh vertices found within 
{distance_threshold} units of segment Gaussians. " + f"Try increasing --distance-threshold." + ) + + # Create a mapping from old vertex indices to new vertex indices + old_to_new_vertex_idx = np.full(len(vertices), -1, dtype=np.int64) + old_to_new_vertex_idx[close_vertex_mask] = np.arange(num_close_vertices) + + # Extract only faces where all 3 vertices are close to the segment + face_mask = np.all(close_vertex_mask[faces], axis=1) + extracted_faces = faces[face_mask] + num_extracted_faces = len(extracted_faces) + + logger.info( + f"Extracted {num_extracted_faces:,} faces ({100 * num_extracted_faces / len(faces):.1f}%) " + f"where all vertices are close to segment" + ) + + if num_extracted_faces == 0: + raise ValueError( + "No complete faces found within the distance threshold. " "Try increasing --distance-threshold." + ) + + # Reindex faces to use new vertex indices + extracted_faces_reindexed = old_to_new_vertex_idx[extracted_faces] + + # Extract the corresponding vertices and colors + extracted_vertices = vertices[close_vertex_mask] + extracted_colors = ( + _normalize_vertex_colors(vertex_colors[close_vertex_mask]) if _has_vertex_colors(vertex_colors) else None + ) + + if not no_gap_fill: + # Use the raw extracted patch for color lookup after watertight remeshes vertices. 
+ color_source_vertices = extracted_vertices.copy() + color_source_colors = extracted_colors + + logger.info("Harmonic Laplacian gap fill on extracted patch...") + num_faces_before = len(extracted_faces_reindexed) + num_vertices_before = len(extracted_vertices) + extracted_vertices, extracted_faces_reindexed = fill_mesh_gaps(extracted_vertices, extracted_faces_reindexed) + if len(extracted_faces_reindexed) != num_faces_before: + logger.info( + "Harmonic fill added %d faces and %d vertices (%d -> %d faces)", + len(extracted_faces_reindexed) - num_faces_before, + len(extracted_vertices) - num_vertices_before, + num_faces_before, + len(extracted_faces_reindexed), + ) + effective_resolution = int(min(resolution, max(2_000, len(extracted_faces_reindexed) // 2))) + if effective_resolution != resolution: + logger.info( + "Capped watertight resolution %d -> %d for segment size", + resolution, + effective_resolution, + ) + logger.info("Making mesh watertight (resolution=%d)...", effective_resolution) + num_faces_before = len(extracted_faces_reindexed) + num_vertices_before = len(extracted_vertices) + extracted_vertices, extracted_faces_reindexed = pcu.make_mesh_watertight( + extracted_vertices, + extracted_faces_reindexed, + resolution=effective_resolution, + ) + logger.info( + "Watertight mesh: %d vertices, %d faces (was %d vertices, %d faces)", + len(extracted_vertices), + len(extracted_faces_reindexed), + num_vertices_before, + num_faces_before, + ) + # Reproject colors onto the new mesh vertices if colors are available + if color_source_colors is not None: + extracted_colors = _reproject_vertex_colors(extracted_vertices, color_source_vertices, color_source_colors) + + # Save the extracted mesh + output_path.parent.mkdir(parents=True, exist_ok=True) + logger.info(f"Saving extracted mesh to {output_path}") + + if extracted_colors is not None: + pcu.save_mesh_vfc( + str(output_path), + extracted_vertices, + extracted_faces_reindexed, + 
def fill_mesh_gaps(
    vertices: np.ndarray,
    faces: np.ndarray,
    *,
    k: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
    """Close boundary holes with fan caps and harmonic Laplacian fairing.

    Each open boundary loop is triangulated with a Steiner vertex at the loop
    centroid. Cap vertex positions are then relaxed by solving a k-harmonic
    PDE (k=1: Laplacian, k=2: biharmonic) with the original mesh vertices fixed.

    Args:
        vertices: Mesh vertex positions, shape (#V, 3).
        faces: Triangle indices, shape (#F, 3).
        k: Order of the harmonic operator (1 = harmonic/Laplacian, 2 = biharmonic).

    Returns:
        Updated (vertices, faces) with holes capped. Vertices are returned as
        ``float32`` and faces as ``int32``. A mesh without faces is returned
        unchanged apart from the dtype conversion.

    Raises:
        ValueError: If ``faces`` is not an (#F, 3) array.
    """
    vertices = np.asarray(vertices, dtype=np.float64)
    faces = np.asarray(faces, dtype=np.int32)
    if faces.ndim != 2 or faces.shape[1] != 3:
        raise ValueError("faces must be an (#F, 3) triangle index array")
    if faces.shape[0] == 0:
        # No triangles means no boundary loops to cap.
        return vertices.astype(np.float32), faces

    orig_vertex_count = vertices.shape[0]
    # Each loop uses at least three boundary edges and a mesh with F faces has
    # at most 3F of them, so F caps is a safe upper bound. Going past it means
    # capping is not closing the boundary, so stop rather than spin forever.
    max_caps = faces.shape[0]
    caps_added = 0

    while True:
        loop = np.asarray(igl.boundary_loop(faces))
        if loop.size < 3:
            break
        if caps_added >= max_caps:
            logging.getLogger(__name__).warning(
                "Stopped hole filling after %d caps; mesh boundary did not close",
                caps_added,
            )
            break

        centroid = vertices[loop].mean(axis=0, keepdims=True)
        cap_idx = vertices.shape[0]
        vertices = np.vstack([vertices, centroid])

        # Triangle fan: (cap, loop[i], loop[i + 1]) for every boundary edge.
        cap_faces = np.column_stack(
            [np.full(loop.size, cap_idx), loop, np.roll(loop, -1)]
        ).astype(np.int32)
        faces = np.vstack([faces, cap_faces])
        caps_added += 1

    if vertices.shape[0] == orig_vertex_count:
        return vertices.astype(np.float32), faces

    # Pin every original vertex; only the newly added cap vertices are free.
    b = np.arange(orig_vertex_count, dtype=np.int32)
    bc = vertices[:orig_vertex_count]
    smoothed = np.asarray(igl.harmonic(vertices, faces, b, bc, k), dtype=np.float64)
    bad = ~np.isfinite(smoothed).all(axis=1)
    if bad.any():
        # Fall back to the unsmoothed position where the solve failed.
        smoothed[bad] = vertices[bad]

    return smoothed.astype(np.float32), faces


def main() -> None:
    """Parse CLI arguments and extract the mesh patch matching a Gaussian segment."""
    # Logging is configured inside extract_segment_mesh from --verbose.
    # Calling logging.basicConfig here as well would turn that later call into
    # a no-op and silently ignore --verbose.
    parser = argparse.ArgumentParser(
        description="Segment full scene mesh based on a PLY Gaussian Splat segment",
    )
    parser.add_argument("--input-splat", type=Path, required=True, help="Input splat segment file (PLY format)")
    parser.add_argument("--input-mesh", type=Path, required=True, help="Input full scene mesh file (PLY/OBJ format)")
    parser.add_argument("--output-path", type=Path, required=True, help="Output path")
    parser.add_argument(
        "--no-gap-fill",
        action="store_true",
        help="Skip watertight + harmonic gap fill; may cause collision issues in Isaac Sim",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=20_000,
        help="Manifold octree resolution for make_mesh_watertight (default: 20000)",
    )
    parser.add_argument("--device", type=str, default="cuda", help="Device to use")
    parser.add_argument("--verbose", action="store_true", default=False)
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=0.5,
        help="Maximum distance from segment Gaussians to include mesh vertices (default: 0.5)",
    )
    args = parser.parse_args()

    extract_segment_mesh(
        full_mesh_path=args.input_mesh,
        segment_ply_path=args.input_splat,
        output_path=args.output_path,
        distance_threshold=args.distance_threshold,
        device=args.device,
        verbose=args.verbose,
        no_gap_fill=args.no_gap_fill,
        resolution=args.resolution,
    )


if __name__ == "__main__":
    main()