From 9d041dfe12cdeee5751b4b730677ddb5b1377876 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Mon, 16 Mar 2026 13:04:27 +1300 Subject: [PATCH 01/15] Fix for mask overlay visualization Signed-off-by: Jonathan Swartz --- .../garfvdb/garfvdb/model.py | 7 +- .../garfvdb/train_segmentation.py | 521 +++++++----------- 2 files changed, 216 insertions(+), 312 deletions(-) diff --git a/instance_segmentation/garfvdb/garfvdb/model.py b/instance_segmentation/garfvdb/garfvdb/model.py index cf542e6..50844fc 100644 --- a/instance_segmentation/garfvdb/garfvdb/model.py +++ b/instance_segmentation/garfvdb/garfvdb/model.py @@ -576,7 +576,12 @@ def get_encoded_features(self, input: GARfVDBInput) -> torch.Tensor: if not self.model_config.enc_feats_one_idx_per_ray: # Weighted sum of the enc_feats and transmittance weights enc_feats.jdata = enc_feats.jdata * weights.jdata.unsqueeze(-1) - enc_feats = enc_feats.jsum(dim=0, keepdim=True) + # When every pixel has exactly 1 contributor, fvdb returns a + # 1-level JaggedTensor [C,[R]] instead of 2-level [C,[R,[K]]]. + # jsum(dim=0) on a 1-level tensor would collapse pixels within + # each camera rather than depth samples within each pixel. + if isinstance(enc_feats.lshape[0], list): + enc_feats = enc_feats.jsum(dim=0, keepdim=True) epsilon = 1e-6 enc_feats.jdata = enc_feats.jdata / (torch.linalg.norm(enc_feats.jdata, dim=-1, keepdim=True) + epsilon) diff --git a/instance_segmentation/garfvdb/train_segmentation.py b/instance_segmentation/garfvdb/train_segmentation.py index 61d990e..8aca263 100644 --- a/instance_segmentation/garfvdb/train_segmentation.py +++ b/instance_segmentation/garfvdb/train_segmentation.py @@ -3,16 +3,13 @@ # import logging import pathlib -import threading import time -from typing import Callable, Literal +from typing import Literal import cv2 -import fvdb.viz as fviz import numpy as np import torch import tyro - from garfvdb.config import ( GaussianSplatSegmentationTrainingConfig, SfmSceneSegmentationTransformConfig, @@ -41,7 +38,6 @@ def main( visualize_every: int = -1, viewer_port: int = 8080, viewer_ip_address: str = "127.0.0.1", - camera_check_interval: float = 0.5, overlay_width: int = 1920, overlay_height: int = 1080, overlay_downsample: int = 2, @@ -55,61 +51,76 @@ def main( Args: sfm_dataset_path (pathlib.Path): Path to the dataset. For "colmap" datasets, this should be the - directory containing the `images` and `sparse` subdirectories. For "simple_directory" datasets, - this should be the directory containing the images and a `cameras.txt` file. For "e57" datasets, - this should be the path to the E57 file or directory. + directory containing the ``images`` and ``sparse`` subdirectories. For "simple_directory" + datasets, this should be the directory containing the images and a ``cameras.txt`` file. + For "e57" datasets, this should be the path to the E57 file or directory. reconstruction_path (pathlib.Path): Path to the precomputed Gaussian splat reconstruction to load. - config (GaussianSplatSegmentationTrainingConfig): Configuration for the Gaussian splat segmentation training. - tx (SfmSceneSegmentationTransformConfig): Configuration for the transforms to apply to the SfmScene before training. - io (GaussianSplatSegmentationWriterConfig): Configuration for saving segmentation metrics and checkpoints. + config (GaussianSplatSegmentationTrainingConfig): Configuration for the Gaussian splat + segmentation training. + tx (SfmSceneSegmentationTransformConfig): Configuration for the transforms to apply to the + SfmScene before training. + io (GaussianSplatSegmentationWriterConfig): Configuration for saving segmentation metrics + and checkpoints. dataset_type (Literal["colmap", "simple_directory", "e57"]): Type of dataset to load. - run_name (str | None): Name of the run. If None, a name will be generated based on the current date and time. - log_path (pathlib.Path | None): Path to log metrics and checkpoints. If None, no metrics or checkpoints will be saved. - use_every_n_as_val (int): Use every n-th image as a validation image. If -1, do not use a separate validation split. + run_name (str | None): Name of the run. If None, a name will be generated based on the + current date and time. + log_path (pathlib.Path | None): Path to log metrics and checkpoints. If None, no metrics + or checkpoints will be saved. + use_every_n_as_val (int): Use every n-th image as a validation image. If -1, do not use + a separate validation split. device (str | torch.device): Device to use for training. log_every (int): Log training metrics every n steps. visualize_every (int): Update the viewer every n epochs. If -1, do not visualize. - viewer_port (int): The port to expose the viewer server on if visualize_every > 0. - viewer_ip_address (str): The IP address to expose the viewer server on if visualize_every > 0. - camera_check_interval (float): How often to check for camera changes (seconds). The segmentation - overlay will update when the viewer camera moves. + viewer_port (int): The port to expose the viewer server on if ``visualize_every > 0``. + viewer_ip_address (str): The IP address to expose the viewer server on if + ``visualize_every > 0``. overlay_width (int): Width of the segmentation overlay in the viewer. overlay_height (int): Height of the segmentation overlay in the viewer. - overlay_downsample (int): Downsample factor for rendering. Renders at overlay_size / overlay_downsample - and then scales up for better performance. - mask_scale (float): Fraction of scene max scale to use for rendering segmentation masks (0.1 = 10%). - mask_blend (float): Blend factor for the segmentation overlay (0.0 = transparent, 1.0 = opaque). + overlay_downsample (int): Downsample factor for rendering. Renders at + ``overlay_size / overlay_downsample`` and then scales up for better performance. + mask_scale (float): Fraction of scene max scale to use for rendering segmentation masks + (0.1 = 10%). + mask_blend (float): Blend factor for the segmentation overlay (0.0 = transparent, + 1.0 = opaque). verbose (bool): Whether to log debug messages. cache_dataset (bool): If True, cache images and masks in memory to speed up data loading. Set to False to reduce memory usage for large datasets. Default is True. """ log_level = logging.DEBUG if verbose else logging.WARNING - logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s") + logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s", force=True) logger = logging.getLogger(__name__) - # Validate camera_check_interval when visualization is enabled - if visualize_every > 0 and camera_check_interval <= 0: - raise ValueError( - f"camera_check_interval must be positive when visualize_every > 0, got {camera_check_interval}" - ) + viewer_enabled = visualize_every > 0 - # Load the SfmScene + # ---- Load data ---- sfm_scene = load_sfm_scene(sfm_dataset_path, dataset_type) - - # Load the GaussianSplat3D model gs_model, metadata = load_splats_from_file(reconstruction_path, device) normalization_transform = metadata.get("normalization_transform", None) - # Set up visualization if enabled + # ---- Create the runner (SAM2 masks + model init) ---- + writer = GaussianSplatSegmentationWriter(run_name=run_name, save_path=log_path, config=io, exist_ok=False) + runner = GaussianSplatScaleConditionedSegmentation.new( + sfm_scene=tx.build_scene_transforms(gs_model, normalization_transform)(sfm_scene), + gs_model=gs_model, + gs_model_path=reconstruction_path, + writer=writer, + config=config, + device=device, + use_every_n_as_val=use_every_n_as_val, + viewer_update_interval_epochs=visualize_every, + log_interval_steps=log_every, + viz_callback=None, + cache_dataset=cache_dataset, + ) + + # ---- Start the viewer ---- viz_scene = None - viz_callback: Callable[[GaussianSplatScaleConditionedSegmentation, int], None] | None = None + image_view = None + if viewer_enabled: + import fvdb.viz as fviz - if visualize_every > 0: - logger.info(f"Starting viewer server on {viewer_ip_address}:{viewer_port}") fviz.init(ip_address=viewer_ip_address, port=viewer_port, verbose=verbose) viz_scene = fviz.get_scene("Gaussian Splat Segmentation Training") - - # Add the Gaussian splats to the scene viz_scene.add_gaussian_splat_3d("Gaussian Splats", gs_model) # Set up initial camera position based on scene geometry @@ -129,300 +140,188 @@ def main( up=[0, 0, 1], ) - # Get projection matrix from metadata for camera intrinsics - projection_matrices = metadata.get("projection_matrices", None) - image_sizes = metadata.get("image_sizes", None) - - # Compute render dimensions (smaller for performance) - render_w = overlay_width // overlay_downsample - render_h = overlay_height // overlay_downsample - - # Try to set up the segmentation overlay image at full resolution - image_view = None + # Add the segmentation overlay image + initial_image = np.zeros((overlay_height, overlay_width, 4), dtype=np.uint8) + initial_image[..., 3] = 128 try: - initial_image = np.zeros((overlay_height, overlay_width, 4), dtype=np.uint8) - initial_image[..., 3] = 128 # Semi-transparent image_view = viz_scene.add_image( name="Segmentation Overlay", width=overlay_width, height=overlay_height, rgba_image=initial_image.flatten(), ) - logger.info(f"Segmentation overlay: {overlay_width}x{overlay_height} (render: {render_w}x{render_h})") - except Exception as e: - logger.warning(f"Failed to set up segmentation overlay: {e}") - logger.info("Running without segmentation overlay") - - # Get reference projection for intrinsics from metadata and scale for render resolution - reference_projection = None - if projection_matrices is not None and image_sizes is not None: - orig_projection = projection_matrices[0].float() - orig_w = float(image_sizes[0, 0].item()) - orig_h = float(image_sizes[0, 1].item()) - # Scale the projection matrix to the render resolution - # fx, fy scale with resolution, cx, cy scale with resolution - scale_x = render_w / orig_w - scale_y = render_h / orig_h - scaled_projection = orig_projection.clone() - scaled_projection[0, 0] *= scale_x # fx - scaled_projection[1, 1] *= scale_y # fy - scaled_projection[0, 2] *= scale_x # cx - scaled_projection[1, 2] *= scale_y # cy - reference_projection = scaled_projection.to(device) - - # OpenGL to OpenCV conversion matrix - opengl_to_opencv = np.diag([1.0, -1.0, -1.0, 1.0]).astype(np.float32) - - def get_camera_from_viewer() -> torch.Tensor | None: - """Compute camera-to-world matrix from the viewer's orbit camera state. - - Constructs a 4x4 transformation matrix from the viewer's current - orbit parameters (center, direction, radius, up vector) and converts - it from OpenGL to OpenCV convention for use with the segmentation model. - - Returns: - torch.Tensor | None: A 4x4 camera-to-world matrix in OpenCV - convention, or ``None`` if the camera state cannot be retrieved. - """ - try: - # Get orbit camera state from viewer - center = viz_scene.camera_orbit_center.cpu().numpy().copy() - eye_direction = viz_scene.camera_orbit_direction.cpu().numpy().copy() - radius = viz_scene.camera_orbit_radius - up_world = viz_scene.camera_up_direction.cpu().numpy().copy() - - # Camera position: center - eye_direction * radius - position = center - eye_direction * radius - - # Forward = eye_direction (already the look direction) - forward = eye_direction / np.linalg.norm(eye_direction) - - # Right vector = forward × up_world - right = np.cross(forward, up_world) - right = right / np.linalg.norm(right) - - # Up vector = right × forward - up = np.cross(right, forward) - up = up / np.linalg.norm(up) - - # Build OpenGL-style camera-to-world (X-right, Y-up, Z-backward) - c2w_opengl = np.eye(4, dtype=np.float32) - c2w_opengl[:3, 0] = right - c2w_opengl[:3, 1] = up - c2w_opengl[:3, 2] = -forward # OpenGL: camera looks along -Z - c2w_opengl[:3, 3] = position - - # Convert to OpenCV convention for the segmentation model - c2w_opencv = c2w_opengl @ opengl_to_opencv - return torch.from_numpy(c2w_opencv).float().to(device) - except Exception as e: - logger.debug(f"Failed to get camera from viewer: {e}") - return None - - def render_segmentation_overlay(model: GARfVDBModel, log_prefix: str = "") -> None: - """Render a segmentation mask overlay in the interactive viewer. - - Captures the current viewer camera pose, renders the model's mask - features at a reduced resolution, applies PCA to produce an RGB - visualization, and updates the viewer's overlay image. - - Args: - model: The segmentation model to render. - log_prefix: Optional message to log on successful render. - """ - if image_view is None or reference_projection is None: - return - - camera_to_world = get_camera_from_viewer() - if camera_to_world is None: - return - - try: - from garfvdb.training.dataset import GARfVDBInput - - # Render at lower resolution for performance - model_input = GARfVDBInput( - { - "projection": reference_projection.unsqueeze(0), - "camera_to_world": camera_to_world.unsqueeze(0), - "image_w": [render_w], - "image_h": [render_h], - } - ) - - # Render at configured fraction of max scale - max_scale_tensor = model.max_grouping_scale - desired_scale = float(max_scale_tensor.item()) * mask_scale - - with torch.no_grad(): - mask_features_output, mask_alpha = model.get_mask_output(model_input, desired_scale) - - # Check for invalid values - if torch.isnan(mask_features_output).any() or torch.isinf(mask_features_output).any(): - logger.warning("Invalid values in mask features, skipping visualization update") - return - - # Apply PCA projection to get RGB visualization - mask_pca = pca_projection_fast(mask_features_output, 3, mask=mask_alpha.squeeze(-1) > 0)[0] - - # Blend mask with alpha - blended_alpha = mask_alpha[0] * mask_blend - - rgba = np.concatenate( - [mask_pca.detach().cpu().numpy(), blended_alpha.detach().cpu().numpy()], axis=-1 - ) - rgba_uint8 = (rgba.clip(0.0, 1.0) * 255).astype(np.uint8) - - # Scale up to overlay resolution - if overlay_downsample > 1: - rgba_uint8 = cv2.resize( - rgba_uint8, (overlay_width, overlay_height), interpolation=cv2.INTER_LINEAR - ) - - image_view.update(rgba_uint8.flatten()) - if log_prefix: - logger.debug(f"{log_prefix}") - - except Exception as e: - logger.warning(f"Error updating visualization: {e}") - - # Create visualization callback that uses the viewer camera - def update_visualization(runner: GaussianSplatScaleConditionedSegmentation, epoch: int) -> None: - """Update the visualization overlay during training from current viewer camera.""" - render_segmentation_overlay(runner.model, f"Updated segmentation visualization at epoch {epoch}") - - viz_callback = update_visualization + except Exception: + pass - writer = GaussianSplatSegmentationWriter(run_name=run_name, save_path=log_path, config=io, exist_ok=False) - runner = GaussianSplatScaleConditionedSegmentation.new( - sfm_scene=tx.build_scene_transforms(gs_model, normalization_transform)(sfm_scene), - gs_model=gs_model, - gs_model_path=reconstruction_path, - writer=writer, - config=config, - device=device, - use_every_n_as_val=use_every_n_as_val, - viewer_update_interval_epochs=visualize_every, - log_interval_steps=log_every, - viz_callback=viz_callback, - cache_dataset=cache_dataset, - ) - - # Set up camera tracking thread for interactive visualization - camera_thread = None - stop_camera_thread = threading.Event() - - if viz_scene is not None and image_view is not None: - # Camera state for change detection - prev_center = None - prev_direction = None - prev_radius = None - prev_up = None - - def camera_changed() -> bool: - """Determine whether the interactive viewer camera state has changed. - - Compares the current camera orbit parameters (center, direction, - radius, and up vector) from ``viz_scene`` against the previously - observed values stored in the nonlocal variables ``prev_center``, - ``prev_direction``, ``prev_radius``, and ``prev_up``. - - Notes: - This function updates the ``prev_*`` variables when called for - the first time or when a change is detected. - - Returns: - bool: ``True`` if this is the first call (to force an initial - overlay update) or if any of the tracked camera parameters - differ from the previously stored values; ``False`` otherwise - or if an error occurs while querying the camera state. - """ - nonlocal prev_center, prev_direction, prev_radius, prev_up - try: - center = viz_scene.camera_orbit_center - direction = viz_scene.camera_orbit_direction - radius = viz_scene.camera_orbit_radius - up = viz_scene.camera_up_direction - - # First time - always update - if prev_center is None: - prev_center = center - prev_direction = direction - prev_radius = radius - prev_up = up - return True - - changed = ( - not torch.allclose(center, prev_center) - or not torch.allclose(direction, prev_direction) # type: ignore - or radius != prev_radius - or not torch.allclose(up, prev_up) # type: ignore - ) - - if changed: - prev_center = center - prev_direction = direction - prev_radius = radius - prev_up = up - - return changed - except Exception as exc: - logger.debug( - "Failed to retrieve camera state in camera_changed: %s", - exc, - exc_info=True, - ) - return False - - def camera_monitor_loop() -> None: - """Monitor camera changes in a background thread and trigger overlay updates. - - This function is intended to be run in a dedicated daemon thread started - from :func:`main`. It periodically checks the interactive viewer camera - state and, when a change is detected, re-renders the segmentation overlay - for the current model. - - The loop runs until the ``stop_camera_thread`` :class:`threading.Event` - is set (e.g. in the main thread's shutdown/KeyboardInterrupt handler), - at which point the background thread exits cleanly. - - While :meth:`runner.train` executes in the main thread, this monitor - runs concurrently to keep the visualization in sync with user-driven - camera movements, and it continues to run during the post-training - viewing phase until shutdown. - """ - while not stop_camera_thread.is_set(): - time.sleep(camera_check_interval) - if camera_changed(): - render_segmentation_overlay(runner.model, "Updated segmentation overlay from camera") - - # Start the camera monitoring thread - camera_thread = threading.Thread(target=camera_monitor_loop, daemon=True) - camera_thread.start() - logger.info(f"Camera tracking enabled (checking every {camera_check_interval}s)") - - if viz_scene is not None: + fviz.show() logger.info("=" * 60) logger.info(f"Viewer running at http://{viewer_ip_address}:{viewer_port}") logger.info(f"Visualization updates every {visualize_every} epoch(s)") - if camera_thread is not None: - logger.info(f"Camera tracking updates every {camera_check_interval}s") logger.info("=" * 60) - fviz.show() + # ---- Build the overlay helpers used by the viz_callback ---- + # Compute render dimensions (smaller for performance) + render_w = overlay_width // overlay_downsample + render_h = overlay_height // overlay_downsample + + # Get reference projection for intrinsics from metadata and scale for render resolution + projection_matrices = metadata.get("projection_matrices", None) + image_sizes = metadata.get("image_sizes", None) + reference_projection = None + if projection_matrices is not None and image_sizes is not None: + orig_projection = projection_matrices[0].float() + orig_w = float(image_sizes[0, 0].item()) + orig_h = float(image_sizes[0, 1].item()) + # Scale the projection matrix to the render resolution + # fx, fy scale with resolution, cx, cy scale with resolution + scale_x = render_w / orig_w + scale_y = render_h / orig_h + scaled_projection = orig_projection.clone() + scaled_projection[0, 0] *= scale_x # fx + scaled_projection[1, 1] *= scale_y # fy + scaled_projection[0, 2] *= scale_x # cx + scaled_projection[1, 2] *= scale_y # cy + reference_projection = scaled_projection.to(device) + + # OpenGL to OpenCV conversion matrix + opengl_to_opencv = np.diag([1.0, -1.0, -1.0, 1.0]).astype(np.float32) + + def _camera_tuple_to_c2w( + center: np.ndarray, + eye_direction: np.ndarray, + radius: float, + up_world: np.ndarray, + ) -> torch.Tensor: + """Convert orbit camera parameters to a 4x4 camera-to-world matrix. + + Constructs a camera-to-world transformation matrix from the viewer's + orbit parameters (center, direction, radius, up vector) and converts + it from OpenGL to OpenCV convention for use with the segmentation model. + + Returns: + A 4x4 camera-to-world matrix in OpenCV convention on ``device``. + """ + # Camera position: center - eye_direction * radius + position = center - eye_direction * radius + + # Forward = eye_direction (already the look direction) + forward = eye_direction / np.linalg.norm(eye_direction) + + # Right vector = forward x up_world + right = np.cross(forward, up_world) + right = right / np.linalg.norm(right) + + # Up vector = right x forward + up = np.cross(right, forward) + up = up / np.linalg.norm(up) + + # Build OpenGL-style camera-to-world (X-right, Y-up, Z-backward) + c2w_gl = np.eye(4, dtype=np.float32) + c2w_gl[:3, 0] = right + c2w_gl[:3, 1] = up + c2w_gl[:3, 2] = -forward # OpenGL: camera looks along -Z + c2w_gl[:3, 3] = position + + # Convert to OpenCV convention for the segmentation model + c2w_cv = c2w_gl @ opengl_to_opencv + return torch.from_numpy(c2w_cv).float().to(device) + + def render_overlay(model: GARfVDBModel, camera_to_world: torch.Tensor) -> np.ndarray | None: + """Render the segmentation overlay for a given camera pose. + + Returns an ``(H, W, 4)`` uint8 RGBA image, or ``None`` on failure. + """ + if reference_projection is None: + return None + try: + from garfvdb.training.dataset import GARfVDBInput + + # Render at lower resolution for performance + world_to_camera = torch.inverse(camera_to_world).contiguous() + model_input = GARfVDBInput( + { + "projection": reference_projection.unsqueeze(0), + "camera_to_world": camera_to_world.unsqueeze(0).contiguous(), + "world_to_camera": world_to_camera.unsqueeze(0), + "image_w": [render_w], + "image_h": [render_h], + } + ) + + # Render at configured fraction of max scale + max_scale_tensor = model.max_grouping_scale + desired_scale = float(max_scale_tensor.item()) * mask_scale + + with torch.no_grad(): + mask_features_output, mask_alpha = model.get_mask_output(model_input, desired_scale) + + # Check for invalid values + if torch.isnan(mask_features_output).any() or torch.isinf(mask_features_output).any(): + return None + + # Apply PCA projection to get RGB visualization + mask_pca = pca_projection_fast(mask_features_output, 3, mask=mask_alpha.squeeze(-1) > 0)[0] + + # Blend mask with alpha + blended_alpha = mask_alpha[0] * mask_blend + rgba = np.concatenate([mask_pca.detach().cpu().numpy(), blended_alpha.detach().cpu().numpy()], axis=-1) + rgba_uint8 = (rgba.clip(0.0, 1.0) * 255).astype(np.uint8) + + # Scale up to overlay resolution + if overlay_downsample > 1: + rgba_uint8 = cv2.resize(rgba_uint8, (overlay_width, overlay_height), interpolation=cv2.INTER_LINEAR) + return rgba_uint8 + except Exception as exc: + logger.warning(f"Error rendering overlay: {exc}") + return None + + def get_viewer_camera() -> tuple[np.ndarray, np.ndarray, float, np.ndarray] | None: + """Read the current orbit camera state directly from the viewer.""" + if viz_scene is None: + return None + try: + return ( + viz_scene.camera_orbit_center.cpu().numpy(), + viz_scene.camera_orbit_direction.cpu().numpy(), + float(viz_scene.camera_orbit_radius), + viz_scene.camera_up_direction.cpu().numpy(), + ) + except Exception: + return None + + def update_visualization(runner_arg: GaussianSplatScaleConditionedSegmentation, epoch: int) -> None: + """Viz callback invoked at epoch boundaries to update the segmentation overlay.""" + if image_view is None: + return + cam = get_viewer_camera() + if cam is None: + return + c2w = _camera_tuple_to_c2w(*cam) + frame = render_overlay(runner_arg.model, c2w) + if frame is not None: + image_view.update(frame.flatten()) + logger.debug(f"Updated segmentation overlay at epoch {epoch}") + + if image_view is not None: + runner._viz_callback = update_visualization + + # ---- Train ---- runner.train() + # ---- Post-training interactive viewing ---- if viz_scene is not None: logger.info("Training complete. Viewer running... Ctrl+C to exit.") try: - # Camera thread continues running during post-training viewing while True: - time.sleep(1) + cam = get_viewer_camera() + if cam is not None: + c2w = _camera_tuple_to_c2w(*cam) + frame = render_overlay(runner.model, c2w) + if frame is not None and image_view is not None: + image_view.update(frame.flatten()) + time.sleep(0.5) except KeyboardInterrupt: - if camera_thread is not None: - stop_camera_thread.set() - camera_thread.join(timeout=5.0) - logger.info("Shutting down...") + pass if __name__ == "__main__": From 075fcef8d648197225ede5bd575dafbb06d82f02 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Tue, 17 Mar 2026 13:29:02 +1300 Subject: [PATCH 02/15] Change view_checkpoint to initialize fvdbviz first and to lazily create the image_view on first add_image Signed-off-by: Jonathan Swartz --- .../garfvdb/view_checkpoint.py | 82 +++++++++---------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/instance_segmentation/garfvdb/view_checkpoint.py b/instance_segmentation/garfvdb/view_checkpoint.py index ac1d806..087217b 100644 --- a/instance_segmentation/garfvdb/view_checkpoint.py +++ b/instance_segmentation/garfvdb/view_checkpoint.py @@ -260,6 +260,11 @@ def execute(self) -> None: logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s") logger = logging.getLogger(__name__) + # Initialize fvdb.viz + logger.info(f"Starting viewer server on {self.viewer_ip_address}:{self.viewer_port}") + fviz.init(ip_address=self.viewer_ip_address, port=self.viewer_port, verbose=self.verbose) + viz_scene = fviz.get_scene("GarfVDB Segmentation Viewer") + device = torch.device(self.device) # Validate segmentation checkpoint path @@ -302,11 +307,6 @@ def execute(self) -> None: renderer.scale = self.initial_scale * float(segmentation_model.max_grouping_scale.item()) renderer.mask_blend = self.initial_blend - # Initialize fvdb.viz - logger.info(f"Starting viewer server on {self.viewer_ip_address}:{self.viewer_port}") - fviz.init(ip_address=self.viewer_ip_address, port=self.viewer_port, verbose=self.verbose) - viz_scene = fviz.get_scene("GarfVDB Segmentation Viewer") - # Set initial camera position scene_centroid = gs_model.means.mean(dim=0).cpu().numpy() cam_to_world_matrices = metadata.get("camera_to_world_matrices", None) @@ -351,29 +351,10 @@ def execute(self) -> None: # Add the Gaussian splat model to the scene viz_scene.add_gaussian_splat_3d("Gaussian Splats", gs_model) - # Set up the segmentation overlay if enabled + # Overlay is created lazily on first render to avoid the C++ viewer + # render thread touching an image grid before we have real content. image_view = None - if not self.no_overlay: - try: - # Create an initial blank image at full overlay resolution - initial_image = np.zeros((self.overlay_height, self.overlay_width, 4), dtype=np.uint8) - initial_image[..., 3] = 128 # Semi-transparent - - # Add the image overlay using the add_image API - image_view = viz_scene.add_image( # type: ignore[call-arg] - name="Segmentation Overlay", - width=self.overlay_width, - height=self.overlay_height, - rgba_image=initial_image.flatten(), - ) - logger.info("Segmentation overlay enabled") - - except AttributeError as e: - logger.warning(f"add_image API not available: {e}") - logger.info("Running without segmentation overlay") - except Exception as e: - logger.warning(f"Failed to set up segmentation overlay: {e}") - logger.info("Running without segmentation overlay") + overlay_enabled = not self.no_overlay logger.info("=" * 60) logger.info("Viewer running... Ctrl+C to exit.") @@ -382,7 +363,7 @@ def execute(self) -> None: logger.info("Segmentation settings:") logger.info(f" - Scale: {renderer.scale:.4f} (max: {segmentation_model.max_grouping_scale:.4f})") logger.info(f" - Mask blend: {renderer.mask_blend:.2f}") - if image_view is not None: + if overlay_enabled: logger.info(f" - Overlay: {self.overlay_width}x{self.overlay_height} (render: {render_w}x{render_h})") logger.info(f" - Update interval: {self.camera_check_interval}s") else: @@ -401,10 +382,10 @@ def camera_changed() -> bool: """Check if camera state changed using documented fvdb.viz.Scene properties.""" nonlocal prev_center, prev_direction, prev_radius, prev_up try: - center = viz_scene.camera_orbit_center - direction = viz_scene.camera_orbit_direction + center = viz_scene.camera_orbit_center.clone().cpu().numpy() + direction = viz_scene.camera_orbit_direction.clone().cpu().numpy() radius = viz_scene.camera_orbit_radius - up = viz_scene.camera_up_direction + up = viz_scene.camera_up_direction.clone().cpu().numpy() # First time - always update if prev_center is None: @@ -451,18 +432,19 @@ def camera_changed() -> bool: opengl_to_opencv = np.diag([1.0, -1.0, -1.0, 1.0]).astype(np.float32) def update_overlay() -> None: - """Render and update the segmentation overlay.""" - if image_view is None: + """Render segmentation and lazily create/update the image overlay.""" + nonlocal image_view + if not overlay_enabled: return try: # Get orbit camera state from viewer # NOTE: Despite the Python API name "camera_orbit_direction", the C++ implementation # returns eye_direction which is the direction the camera is LOOKING (toward scene). # Camera position = center - eye_direction * distance (see Camera.h line 387-390) - center = np.array(viz_scene.camera_orbit_center.cpu().numpy()) - eye_direction = np.array(viz_scene.camera_orbit_direction.cpu().numpy()) + center = viz_scene.camera_orbit_center.clone().cpu().numpy() + eye_direction = viz_scene.camera_orbit_direction.clone().cpu().numpy() radius = viz_scene.camera_orbit_radius - up_world = np.array(viz_scene.camera_up_direction.cpu().numpy()) + up_world = viz_scene.camera_up_direction.clone().cpu().numpy() # Camera position: center - eye_direction * radius # (eye_direction points FROM camera TOWARD center) @@ -471,11 +453,11 @@ def update_overlay() -> None: # Forward = eye_direction (already the look direction) forward = eye_direction / np.linalg.norm(eye_direction) - # Right vector = forward × up_world + # Right vector = forward x up_world right = np.cross(forward, up_world) right = right / np.linalg.norm(right) - # Up vector = right × forward + # Up vector = right x forward up = np.cross(right, forward) up = up / np.linalg.norm(up) @@ -491,7 +473,8 @@ def update_overlay() -> None: c2w_opencv = c2w_opengl @ opengl_to_opencv camera_to_world = torch.from_numpy(c2w_opencv).float().to(renderer.device) - world_to_camera = torch.linalg.inv(camera_to_world) + world_to_camera = torch.linalg.inv(camera_to_world).contiguous() + # Render at lower resolution for performance rgba_image = renderer.render_segmentation_image( camera_to_world, world_to_camera, reference_projection, render_w, render_h @@ -501,7 +484,24 @@ def update_overlay() -> None: rgba_image = cv2.resize( rgba_image, (self.overlay_width, self.overlay_height), interpolation=cv2.INTER_LINEAR ) - image_view.update(rgba_image.flatten()) # type: ignore[attr-defined] + + flat_rgba = rgba_image.flatten() + + if image_view is None: + # Lazily create the image overlay on first render + try: + image_view = viz_scene.add_image( # type: ignore[call-arg] + name="Segmentation Overlay", + width=self.overlay_width, + height=self.overlay_height, + rgba_image=flat_rgba, + ) + except Exception as e: + logger.warning(f"add_image API not available or failed: {e}") + return + else: + image_view.update(flat_rgba) # type: ignore[attr-defined] + logger.debug( f"Updated segmentation overlay ({render_w}x{render_h} -> {self.overlay_width}x{self.overlay_height})" ) @@ -515,7 +515,7 @@ def update_overlay() -> None: while True: time.sleep(self.camera_check_interval) - if image_view is not None and camera_changed(): + if overlay_enabled and camera_changed(): logger.debug("Camera changed, updating overlay...") update_overlay() From f392033af7081753f8cbe7d2ad149bd75d65871f Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Tue, 17 Mar 2026 13:48:12 +1300 Subject: [PATCH 03/15] env touchup Signed-off-by: Jonathan Swartz --- instance_segmentation/garfvdb/garfvdb_environment.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/instance_segmentation/garfvdb/garfvdb_environment.yml b/instance_segmentation/garfvdb/garfvdb_environment.yml index fa792e0..bf5a588 100644 --- a/instance_segmentation/garfvdb/garfvdb_environment.yml +++ b/instance_segmentation/garfvdb/garfvdb_environment.yml @@ -3,9 +3,10 @@ channels: - conda-forge - nodefaults dependencies: + - cxx-compiler - blas=*=mkl - python=3.12 - - pytorch-gpu=2.8.0 + - pytorch-gpu=2.10.0 - cuda-version=12.9 - pip - git @@ -16,6 +17,8 @@ dependencies: - tyro - scikit-learn - py-opencv + - open-clip-torch + - openusd - imageio - sam2 - pyproj @@ -32,3 +35,5 @@ dependencies: - rapidsai::cuml - cupy - matplotlib + - pip: + - dlnr_lite From 123865f80e4045838fc878d6f02bd49a4fd17256 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Fri, 20 Mar 2026 13:29:26 +1300 Subject: [PATCH 04/15] guard against degenerate camera states Signed-off-by: Jonathan Swartz --- .../garfvdb/view_checkpoint.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/instance_segmentation/garfvdb/view_checkpoint.py b/instance_segmentation/garfvdb/view_checkpoint.py index 087217b..cb96361 100644 --- a/instance_segmentation/garfvdb/view_checkpoint.py +++ b/instance_segmentation/garfvdb/view_checkpoint.py @@ -450,16 +450,27 @@ def update_overlay() -> None: # (eye_direction points FROM camera TOWARD center) position = center - eye_direction * radius - # Forward = eye_direction (already the look direction) - forward = eye_direction / np.linalg.norm(eye_direction) + # Guard against degenerate camera states (zero-length vectors + # produce NaN and ultimately a singular matrix). + eye_norm = np.linalg.norm(eye_direction) + if eye_norm < 1e-8: + return + + forward = eye_direction / eye_norm # Right vector = forward x up_world right = np.cross(forward, up_world) - right = right / np.linalg.norm(right) + right_norm = np.linalg.norm(right) + if right_norm < 1e-8: + return + right = right / right_norm # Up vector = right x forward up = np.cross(right, forward) - up = up / np.linalg.norm(up) + up_norm = np.linalg.norm(up) + if up_norm < 1e-8: + return + up = up / up_norm # Build OpenGL-style camera-to-world (X-right, Y-up, Z-backward) # In OpenGL, camera looks along -Z, so Z column = -forward @@ -473,7 +484,10 @@ def update_overlay() -> None: c2w_opencv = c2w_opengl @ opengl_to_opencv camera_to_world = torch.from_numpy(c2w_opencv).float().to(renderer.device) - world_to_camera = torch.linalg.inv(camera_to_world).contiguous() + try: + world_to_camera = torch.linalg.inv(camera_to_world).contiguous() + except torch._C._LinAlgError: + return # Render at lower resolution for performance rgba_image = renderer.render_segmentation_image( From 8f84e26e9ee691c933597244eff5f09e9cf50794 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Mon, 23 Mar 2026 16:31:22 +1300 Subject: [PATCH 05/15] Update overlay dimensions and improve camera state handling in ViewCheckpoint - Adjusted overlay width to 1440 and height to 720 to match the nanovdb-editor viewport for correct alignment. - Removed unused original training image size assertions and added handling for field of view (FOV) changes in camera state checks. Signed-off-by: Jonathan Swartz --- .../garfvdb/view_checkpoint.py | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/instance_segmentation/garfvdb/view_checkpoint.py b/instance_segmentation/garfvdb/view_checkpoint.py index cb96361..1eedcba 100644 --- a/instance_segmentation/garfvdb/view_checkpoint.py +++ b/instance_segmentation/garfvdb/view_checkpoint.py @@ -245,11 +245,13 @@ class ViewCheckpoint: no_overlay: bool = False """Disable the segmentation overlay (show Gaussian splats only).""" - overlay_width: int = 1920 - """Width of the segmentation overlay in pixels.""" + overlay_width: int = 1440 + """Width of the segmentation overlay in pixels. Must match the + nanovdb-editor viewport width for correct alignment (default 1440).""" - overlay_height: int = 1080 - """Height of the segmentation overlay in pixels.""" + overlay_height: int = 720 + """Height of the segmentation overlay in pixels. Must match the + nanovdb-editor viewport height for correct alignment (default 720).""" overlay_downsample: int = 2 """Downsample factor for rendering (renders at overlay_size/downsample).""" @@ -338,11 +340,6 @@ def execute(self) -> None: image_sizes=image_sizes, ) - # Get original training image size for projection scaling - assert image_sizes is not None - orig_img_w = int(image_sizes[0, 0].item()) - orig_img_h = int(image_sizes[0, 1].item()) - # Compute render dimensions (smaller for performance) render_w = self.overlay_width // self.overlay_downsample render_h = self.overlay_height // self.overlay_downsample @@ -377,11 +374,13 @@ def execute(self) -> None: prev_direction = None prev_radius = None prev_up = None + prev_fov = None def camera_changed() -> bool: """Check if camera state changed using documented fvdb.viz.Scene properties.""" - nonlocal prev_center, prev_direction, prev_radius, prev_up + nonlocal prev_center, prev_direction, prev_radius, prev_up, prev_fov try: + fov = viz_scene.camera_fov center = viz_scene.camera_orbit_center.clone().cpu().numpy() direction = viz_scene.camera_orbit_direction.clone().cpu().numpy() radius = viz_scene.camera_orbit_radius @@ -393,6 +392,7 @@ def camera_changed() -> bool: prev_direction = direction prev_radius = radius prev_up = up + prev_fov = fov return True assert prev_center is not None and prev_direction is not None and prev_up is not None @@ -401,6 +401,7 @@ def camera_changed() -> bool: or not np.allclose(direction, prev_direction) or radius != prev_radius or not np.allclose(up, prev_up) + or fov != prev_fov ) if changed: @@ -408,23 +409,28 @@ def camera_changed() -> bool: prev_direction = direction prev_radius = radius prev_up = up + prev_fov = fov return changed except Exception: return False - # Get reference projection from SfmScene and scale for render resolution - sfm_projection = torch.from_numpy(sfm_scene.projection_matrices).float() - orig_projection = sfm_projection[0] - # Scale the projection matrix from original training image size to render resolution - scale_x = render_w / orig_img_w - scale_y = render_h / orig_img_h - scaled_projection = orig_projection.clone() - scaled_projection[0, 0] *= scale_x # fx - scaled_projection[1, 1] *= scale_y # fy - scaled_projection[0, 2] *= scale_x # cx - scaled_projection[1, 2] *= scale_y # cy - reference_projection = scaled_projection.to(device) + # Build intrinsic matrix from the viewer's FOV to match its perspective + # camera. Recomputed whenever the FOV changes. + cx = render_w / 2.0 + cy = render_h / 2.0 + cached_fov: float | None = None + reference_projection: torch.Tensor | None = None + + def _update_projection(fov_y_rad: float) -> torch.Tensor: + nonlocal cached_fov, reference_projection + fy = render_h / (2.0 * np.tan(fov_y_rad / 2.0)) + reference_projection = torch.tensor( + [[fy, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], + dtype=torch.float32, + ).to(device) + cached_fov = fov_y_rad + return reference_projection # OpenGL to OpenCV conversion matrix (applied to camera axes) # OpenGL: X-right, Y-up, Z-backward @@ -437,7 +443,10 @@ def update_overlay() -> None: if not overlay_enabled: return try: - # Get orbit camera state from viewer + fov_y_rad = viz_scene.camera_fov + if reference_projection is None or fov_y_rad != cached_fov: + _update_projection(fov_y_rad) + # NOTE: Despite the Python API name "camera_orbit_direction", the C++ implementation # returns eye_direction which is the direction the camera is LOOKING (toward scene). # Camera position = center - eye_direction * distance (see Camera.h line 387-390) @@ -445,9 +454,6 @@ def update_overlay() -> None: eye_direction = viz_scene.camera_orbit_direction.clone().cpu().numpy() radius = viz_scene.camera_orbit_radius up_world = viz_scene.camera_up_direction.clone().cpu().numpy() - - # Camera position: center - eye_direction * radius - # (eye_direction points FROM camera TOWARD center) position = center - eye_direction * radius # Guard against degenerate camera states (zero-length vectors @@ -490,6 +496,7 @@ def update_overlay() -> None: return # Render at lower resolution for performance + assert reference_projection is not None rgba_image = renderer.render_segmentation_image( camera_to_world, world_to_camera, reference_projection, render_w, render_h ) From 0529cb3a7dd33dd3641505ec03b8e208b404bc23 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Mon, 30 Mar 2026 16:11:23 +1300 Subject: [PATCH 06/15] camera fixes in train_segmentation Signed-off-by: Jonathan Swartz --- .../garfvdb/train_segmentation.py | 61 +++++++++++-------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/instance_segmentation/garfvdb/train_segmentation.py b/instance_segmentation/garfvdb/train_segmentation.py index 8aca263..6994d0a 100644 --- a/instance_segmentation/garfvdb/train_segmentation.py +++ b/instance_segmentation/garfvdb/train_segmentation.py @@ -38,8 +38,8 @@ def main( visualize_every: int = -1, viewer_port: int = 8080, viewer_ip_address: str = "127.0.0.1", - overlay_width: int = 1920, - overlay_height: int = 1080, + overlay_width: int = 1440, + overlay_height: int = 720, overlay_downsample: int = 2, mask_scale: float = 0.1, mask_blend: float = 0.5, @@ -74,8 +74,10 @@ def main( viewer_port (int): The port to expose the viewer server on if ``visualize_every > 0``. viewer_ip_address (str): The IP address to expose the viewer server on if ``visualize_every > 0``. - overlay_width (int): Width of the segmentation overlay in the viewer. - overlay_height (int): Height of the segmentation overlay in the viewer. + overlay_width (int): Width of the segmentation overlay in the viewer. Must match + the nanovdb-editor viewport width for correct alignment (default 1440). + overlay_height (int): Height of the segmentation overlay in the viewer. Must match + the nanovdb-editor viewport height for correct alignment (default 720). overlay_downsample (int): Downsample factor for rendering. Renders at ``overlay_size / overlay_downsample`` and then scales up for better performance. mask_scale (float): Fraction of scene max scale to use for rendering segmentation masks @@ -91,6 +93,14 @@ def main( logger = logging.getLogger(__name__) viewer_enabled = visualize_every > 0 + # ---- Initialize viewer BEFORE any CUDA operations ---- + # The Vulkan device created by fviz.init() must be set up before the CUDA + # context is first created (by load_splats_from_file). This matches the + # initialization order used in frgs reconstruction. + if viewer_enabled: + import fvdb.viz as fviz + fviz.init(ip_address=viewer_ip_address, port=viewer_port, verbose=verbose) + # ---- Load data ---- sfm_scene = load_sfm_scene(sfm_dataset_path, dataset_type) @@ -119,7 +129,6 @@ def main( if viewer_enabled: import fvdb.viz as fviz - fviz.init(ip_address=viewer_ip_address, port=viewer_port, verbose=verbose) viz_scene = fviz.get_scene("Gaussian Splat Segmentation Training") viz_scene.add_gaussian_splat_3d("Gaussian Splats", gs_model) @@ -164,24 +173,22 @@ def main( render_w = overlay_width // overlay_downsample render_h = overlay_height // overlay_downsample - # Get reference projection for intrinsics from metadata and scale for render resolution - projection_matrices = metadata.get("projection_matrices", None) - image_sizes = metadata.get("image_sizes", None) - reference_projection = None - if projection_matrices is not None and image_sizes is not None: - orig_projection = projection_matrices[0].float() - orig_w = float(image_sizes[0, 0].item()) - orig_h = float(image_sizes[0, 1].item()) - # Scale the projection matrix to the render resolution - # fx, fy scale with resolution, cx, cy scale with resolution - scale_x = render_w / orig_w - scale_y = render_h / orig_h - scaled_projection = orig_projection.clone() - scaled_projection[0, 0] *= scale_x # fx - scaled_projection[1, 1] *= scale_y # fy - scaled_projection[0, 2] *= scale_x # cx - scaled_projection[1, 2] *= scale_y # cy - reference_projection = scaled_projection.to(device) + # Build intrinsic matrix from the viewer's FOV to match its perspective + # camera. Recomputed whenever the FOV changes. + cx = render_w / 2.0 + cy = render_h / 2.0 + cached_fov: float | None = None + reference_projection: torch.Tensor | None = None + + def _update_projection(fov_y_rad: float) -> torch.Tensor: + nonlocal cached_fov, reference_projection + fy = render_h / (2.0 * np.tan(fov_y_rad / 2.0)) + reference_projection = torch.tensor( + [[fy, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], + dtype=torch.float32, + ).to(device) + cached_fov = fov_y_rad + return reference_projection # OpenGL to OpenCV conversion matrix opengl_to_opencv = np.diag([1.0, -1.0, -1.0, 1.0]).astype(np.float32) @@ -291,11 +298,14 @@ def get_viewer_camera() -> tuple[np.ndarray, np.ndarray, float, np.ndarray] | No def update_visualization(runner_arg: GaussianSplatScaleConditionedSegmentation, epoch: int) -> None: """Viz callback invoked at epoch boundaries to update the segmentation overlay.""" - if image_view is None: + if image_view is None or viz_scene is None: return cam = get_viewer_camera() if cam is None: return + fov_y_rad = viz_scene.camera_fov + if reference_projection is None or fov_y_rad != cached_fov: + _update_projection(fov_y_rad) c2w = _camera_tuple_to_c2w(*cam) frame = render_overlay(runner_arg.model, c2w) if frame is not None: @@ -315,6 +325,9 @@ def update_visualization(runner_arg: GaussianSplatScaleConditionedSegmentation, while True: cam = get_viewer_camera() if cam is not None: + fov_y_rad = viz_scene.camera_fov + if reference_projection is None or fov_y_rad != cached_fov: + _update_projection(fov_y_rad) c2w = _camera_tuple_to_c2w(*cam) frame = render_overlay(runner.model, c2w) if frame is not None and image_view is not None: From 1630a287a91b073a91821bff921da1467d20df97 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Mon, 30 Mar 2026 16:11:33 +1300 Subject: [PATCH 07/15] cuda 13 in env Signed-off-by: Jonathan Swartz --- instance_segmentation/garfvdb/garfvdb_environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/instance_segmentation/garfvdb/garfvdb_environment.yml b/instance_segmentation/garfvdb/garfvdb_environment.yml index bf5a588..aa400e6 100644 --- a/instance_segmentation/garfvdb/garfvdb_environment.yml +++ b/instance_segmentation/garfvdb/garfvdb_environment.yml @@ -7,7 +7,7 @@ dependencies: - blas=*=mkl - python=3.12 - pytorch-gpu=2.10.0 - - cuda-version=12.9 + - cuda-version=13.0 - pip - git - gitpython From 68c2b49c6a4cd50c2885c410e807991832eacd2c Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Mon, 30 Mar 2026 16:31:12 +1300 Subject: [PATCH 08/15] rename of mask parameters in view_checkpoint Signed-off-by: Jonathan Swartz --- .gitignore | 3 +++ instance_segmentation/garfvdb/view_checkpoint.py | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index cf8bc07..c7ced4b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Data +*.ply + # Generated version file fvdb/version.py diff --git a/instance_segmentation/garfvdb/view_checkpoint.py b/instance_segmentation/garfvdb/view_checkpoint.py index 1eedcba..025ce45 100644 --- a/instance_segmentation/garfvdb/view_checkpoint.py +++ b/instance_segmentation/garfvdb/view_checkpoint.py @@ -233,10 +233,10 @@ class ViewCheckpoint: device: str | torch.device = "cuda" """Device for computation (e.g., "cuda" or "cpu").""" - initial_scale: float = 0.1 + mask_scale: float = 0.1 """Initial segmentation scale as a fraction of max scale.""" - initial_blend: float = 0.5 + mask_blend: float = 0.5 """Initial mask blend factor (0=beauty only, 1=mask only).""" camera_check_interval: float = 0.5 @@ -306,8 +306,8 @@ def execute(self) -> None: segmentation_model=segmentation_model, device=device, ) - renderer.scale = self.initial_scale * float(segmentation_model.max_grouping_scale.item()) - renderer.mask_blend = self.initial_blend + renderer.scale = self.mask_scale * float(segmentation_model.max_grouping_scale.item()) + renderer.mask_blend = self.mask_blend # Set initial camera position scene_centroid = gs_model.means.mean(dim=0).cpu().numpy() From 06f92d4112e806f3a594f644cd06ad55c10510fe Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 1 Apr 2026 11:57:49 +1300 Subject: [PATCH 09/15] Handle potential singularity in camera transformation by adding error handling for inverse calculation in train_segmentation.py Signed-off-by: Jonathan Swartz --- instance_segmentation/garfvdb/train_segmentation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/instance_segmentation/garfvdb/train_segmentation.py b/instance_segmentation/garfvdb/train_segmentation.py index 6994d0a..c1076fd 100644 --- a/instance_segmentation/garfvdb/train_segmentation.py +++ b/instance_segmentation/garfvdb/train_segmentation.py @@ -244,7 +244,10 @@ def render_overlay(model: GARfVDBModel, camera_to_world: torch.Tensor) -> np.nda from garfvdb.training.dataset import GARfVDBInput # Render at lower resolution for performance - world_to_camera = torch.inverse(camera_to_world).contiguous() + try: + world_to_camera = torch.linalg.inv(camera_to_world).contiguous() + except torch.linalg.LinAlgError: + return None model_input = GARfVDBInput( { "projection": reference_projection.unsqueeze(0), From d81026013a77ab72a22acb47ddb3595223eae77b Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 1 Apr 2026 12:02:52 +1300 Subject: [PATCH 10/15] Refactor visualization callback handling in segmentation training - Moved the visualization callback assignment to the runner creation in train_segmentation.py. - Added getter and setter for the visualization callback in GaussianSplatScaleConditionedSegmentation class to improve encapsulation. Signed-off-by: Jonathan Swartz --- .../garfvdb/garfvdb/training/segmentation.py | 11 ++++++ .../garfvdb/train_segmentation.py | 34 +++++++++---------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/instance_segmentation/garfvdb/garfvdb/training/segmentation.py b/instance_segmentation/garfvdb/garfvdb/training/segmentation.py index 0c095fc..11902ea 100644 --- a/instance_segmentation/garfvdb/garfvdb/training/segmentation.py +++ b/instance_segmentation/garfvdb/garfvdb/training/segmentation.py @@ -292,6 +292,17 @@ def device(self) -> torch.device: return torch.device(dev) return dev + @property + def viz_callback(self) -> Callable[["GaussianSplatScaleConditionedSegmentation", int], None] | None: + """Get or set the visualization callback invoked at epoch boundaries.""" + return self._viz_callback + + @viz_callback.setter + def viz_callback( + self, callback: Callable[["GaussianSplatScaleConditionedSegmentation", int], None] | None + ) -> None: + self._viz_callback = callback + @staticmethod def _init_model( model_config: GARfVDBModelConfig, diff --git a/instance_segmentation/garfvdb/train_segmentation.py b/instance_segmentation/garfvdb/train_segmentation.py index c1076fd..e6e158d 100644 --- a/instance_segmentation/garfvdb/train_segmentation.py +++ b/instance_segmentation/garfvdb/train_segmentation.py @@ -107,22 +107,6 @@ def main( gs_model, metadata = load_splats_from_file(reconstruction_path, device) normalization_transform = metadata.get("normalization_transform", None) - # ---- Create the runner (SAM2 masks + model init) ---- - writer = GaussianSplatSegmentationWriter(run_name=run_name, save_path=log_path, config=io, exist_ok=False) - runner = GaussianSplatScaleConditionedSegmentation.new( - sfm_scene=tx.build_scene_transforms(gs_model, normalization_transform)(sfm_scene), - gs_model=gs_model, - gs_model_path=reconstruction_path, - writer=writer, - config=config, - device=device, - use_every_n_as_val=use_every_n_as_val, - viewer_update_interval_epochs=visualize_every, - log_interval_steps=log_every, - viz_callback=None, - cache_dataset=cache_dataset, - ) - # ---- Start the viewer ---- viz_scene = None image_view = None @@ -315,8 +299,22 @@ def update_visualization(runner_arg: GaussianSplatScaleConditionedSegmentation, image_view.update(frame.flatten()) logger.debug(f"Updated segmentation overlay at epoch {epoch}") - if image_view is not None: - runner._viz_callback = update_visualization +# ---- Create the runner (SAM2 masks + model init) ---- + # ---- Create the runner (SAM2 masks + model init) ---- + writer = GaussianSplatSegmentationWriter(run_name=run_name, save_path=log_path, config=io, exist_ok=False) + runner = GaussianSplatScaleConditionedSegmentation.new( + sfm_scene=tx.build_scene_transforms(gs_model, normalization_transform)(sfm_scene), + gs_model=gs_model, + gs_model_path=reconstruction_path, + writer=writer, + config=config, + device=device, + use_every_n_as_val=use_every_n_as_val, + viewer_update_interval_epochs=visualize_every, + log_interval_steps=log_every, + viz_callback=update_visualization if image_view is not None else None, + cache_dataset=cache_dataset, + ) # ---- Train ---- runner.train() From 41a00a40f06b3485c173f1f58ce97e7d335dc001 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 1 Apr 2026 12:07:13 +1300 Subject: [PATCH 11/15] - Set `overlay_enabled` to `False` upon failure of the `add_image` API to prevent further attempts. Signed-off-by: Jonathan Swartz --- instance_segmentation/garfvdb/view_checkpoint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/instance_segmentation/garfvdb/view_checkpoint.py b/instance_segmentation/garfvdb/view_checkpoint.py index 025ce45..1f223fa 100644 --- a/instance_segmentation/garfvdb/view_checkpoint.py +++ b/instance_segmentation/garfvdb/view_checkpoint.py @@ -439,7 +439,7 @@ def _update_projection(fov_y_rad: float) -> torch.Tensor: def update_overlay() -> None: """Render segmentation and lazily create/update the image overlay.""" - nonlocal image_view + nonlocal image_view, overlay_enabled if not overlay_enabled: return try: @@ -519,6 +519,7 @@ def update_overlay() -> None: ) except Exception as e: logger.warning(f"add_image API not available or failed: {e}") + overlay_enabled = False return else: image_view.update(flat_rgba) # type: ignore[attr-defined] From c16373f0022c3dce091a7dff737bc0b57acd6fff Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 1 Apr 2026 12:11:37 +1300 Subject: [PATCH 12/15] - Updated the camera-to-world matrix function to return None for degenerate camera states, preventing potential errors during rendering. - Added checks for zero-length vectors in eye direction, right, and up vectors to ensure valid transformations. Signed-off-by: Jonathan Swartz --- .../garfvdb/train_segmentation.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/instance_segmentation/garfvdb/train_segmentation.py b/instance_segmentation/garfvdb/train_segmentation.py index e6e158d..9cbb804 100644 --- a/instance_segmentation/garfvdb/train_segmentation.py +++ b/instance_segmentation/garfvdb/train_segmentation.py @@ -182,7 +182,7 @@ def _camera_tuple_to_c2w( eye_direction: np.ndarray, radius: float, up_world: np.ndarray, - ) -> torch.Tensor: + ) -> torch.Tensor | None: """Convert orbit camera parameters to a 4x4 camera-to-world matrix. Constructs a camera-to-world transformation matrix from the viewer's @@ -190,21 +190,32 @@ def _camera_tuple_to_c2w( it from OpenGL to OpenCV convention for use with the segmentation model. Returns: - A 4x4 camera-to-world matrix in OpenCV convention on ``device``. + A 4x4 camera-to-world matrix in OpenCV convention on ``device``, + or ``None`` if the camera state is degenerate (zero-length or + near-parallel vectors). """ # Camera position: center - eye_direction * radius position = center - eye_direction * radius # Forward = eye_direction (already the look direction) - forward = eye_direction / np.linalg.norm(eye_direction) + eye_norm = np.linalg.norm(eye_direction) + if eye_norm < 1e-8: + return None + forward = eye_direction / eye_norm # Right vector = forward x up_world right = np.cross(forward, up_world) - right = right / np.linalg.norm(right) + right_norm = np.linalg.norm(right) + if right_norm < 1e-8: + return None + right /= right_norm # Up vector = right x forward up = np.cross(right, forward) - up = up / np.linalg.norm(up) + up_norm = np.linalg.norm(up) + if up_norm < 1e-8: + return None + up /= up_norm # Build OpenGL-style camera-to-world (X-right, Y-up, Z-backward) c2w_gl = np.eye(4, dtype=np.float32) @@ -294,6 +305,8 @@ def update_visualization(runner_arg: GaussianSplatScaleConditionedSegmentation, if reference_projection is None or fov_y_rad != cached_fov: _update_projection(fov_y_rad) c2w = _camera_tuple_to_c2w(*cam) + if c2w is None: + return frame = render_overlay(runner_arg.model, c2w) if frame is not None: image_view.update(frame.flatten()) @@ -330,6 +343,8 @@ def update_visualization(runner_arg: GaussianSplatScaleConditionedSegmentation, if reference_projection is None or fov_y_rad != cached_fov: _update_projection(fov_y_rad) c2w = _camera_tuple_to_c2w(*cam) + if c2w is None: + continue frame = render_overlay(runner.model, c2w) if frame is not None and image_view is not None: image_view.update(frame.flatten()) From 81e63d1c4abbbc9068465bcc18b89653bcd0e6f5 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 1 Apr 2026 12:24:55 +1300 Subject: [PATCH 13/15] exception handling in train_segmentation.py to log warnings when creating the segmentation overlay image fails, providing clearer feedback on the issue and disabling overlay visualization for the session. Signed-off-by: Jonathan Swartz --- instance_segmentation/garfvdb/train_segmentation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/instance_segmentation/garfvdb/train_segmentation.py b/instance_segmentation/garfvdb/train_segmentation.py index 9cbb804..344d731 100644 --- a/instance_segmentation/garfvdb/train_segmentation.py +++ b/instance_segmentation/garfvdb/train_segmentation.py @@ -143,8 +143,9 @@ def main( height=overlay_height, rgba_image=initial_image.flatten(), ) - except Exception: - pass + except Exception as exc: + logger.warning(f"Failed to create segmentation overlay image: {exc}") + logger.warning("Overlay visualization will be disabled for this session.") fviz.show() logger.info("=" * 60) From 7251650ec7bcb1d510a350ab7dc53f74a18bc3e4 Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 1 Apr 2026 13:00:30 +1300 Subject: [PATCH 14/15] environment/dependency updates Signed-off-by: Jonathan Swartz --- instance_segmentation/garfvdb/garfvdb_environment.yml | 4 +++- instance_segmentation/garfvdb/pyproject.toml | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/instance_segmentation/garfvdb/garfvdb_environment.yml b/instance_segmentation/garfvdb/garfvdb_environment.yml index aa400e6..a97789c 100644 --- a/instance_segmentation/garfvdb/garfvdb_environment.yml +++ b/instance_segmentation/garfvdb/garfvdb_environment.yml @@ -3,11 +3,12 @@ channels: - conda-forge - nodefaults dependencies: + - fvdb-core>=0.4.2,<0.5.0 - cxx-compiler - blas=*=mkl - python=3.12 - pytorch-gpu=2.10.0 - - cuda-version=13.0 + - cuda-version>=12.9 - pip - git - gitpython @@ -37,3 +38,4 @@ dependencies: - matplotlib - pip: - dlnr_lite + - nanovdb-editor>=0.0.24 diff --git a/instance_segmentation/garfvdb/pyproject.toml b/instance_segmentation/garfvdb/pyproject.toml index 7ca3edb..8a41b85 100644 --- a/instance_segmentation/garfvdb/pyproject.toml +++ b/instance_segmentation/garfvdb/pyproject.toml @@ -17,7 +17,8 @@ dependencies = [ "tqdm", "tyro", "tensorboard", - "fvdb-core", + "fvdb-core>=0.4.2,<0.5.0", + "nanovdb-editor>=0.0.24,<0.2.0", "fvdb-reality-capture", "pycocotools", "matplotlib", From bde5a1f49dcecb7e5a0edaa0c9519201fc0e2d7e Mon Sep 17 00:00:00 2001 From: Jonathan Swartz Date: Wed, 1 Apr 2026 13:16:51 +1300 Subject: [PATCH 15/15] env/requirements updates Signed-off-by: Jonathan Swartz --- instance_segmentation/garfvdb/garfvdb_environment.yml | 1 + instance_segmentation/garfvdb/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/instance_segmentation/garfvdb/garfvdb_environment.yml b/instance_segmentation/garfvdb/garfvdb_environment.yml index a97789c..436926b 100644 --- a/instance_segmentation/garfvdb/garfvdb_environment.yml +++ b/instance_segmentation/garfvdb/garfvdb_environment.yml @@ -38,4 +38,5 @@ dependencies: - matplotlib - pip: - dlnr_lite + - fvdb-reality-capture>=0.4.0<0.5.0 - nanovdb-editor>=0.0.24 diff --git a/instance_segmentation/garfvdb/pyproject.toml b/instance_segmentation/garfvdb/pyproject.toml index 8a41b85..e6a4de0 100644 --- a/instance_segmentation/garfvdb/pyproject.toml +++ b/instance_segmentation/garfvdb/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "tensorboard", "fvdb-core>=0.4.2,<0.5.0", "nanovdb-editor>=0.0.24,<0.2.0", - "fvdb-reality-capture", + "fvdb-reality-capture>=0.4.0<0.5.0", "pycocotools", "matplotlib", "pillow",