diff --git a/.gitignore b/.gitignore index cf8bc07..c7ced4b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Data +*.ply + # Generated version file fvdb/version.py diff --git a/instance_segmentation/garfvdb/garfvdb/model.py b/instance_segmentation/garfvdb/garfvdb/model.py index cf542e6..50844fc 100644 --- a/instance_segmentation/garfvdb/garfvdb/model.py +++ b/instance_segmentation/garfvdb/garfvdb/model.py @@ -576,7 +576,12 @@ def get_encoded_features(self, input: GARfVDBInput) -> torch.Tensor: if not self.model_config.enc_feats_one_idx_per_ray: # Weighted sum of the enc_feats and transmittance weights enc_feats.jdata = enc_feats.jdata * weights.jdata.unsqueeze(-1) - enc_feats = enc_feats.jsum(dim=0, keepdim=True) + # When every pixel has exactly 1 contributor, fvdb returns a + # 1-level JaggedTensor [C,[R]] instead of 2-level [C,[R,[K]]]. + # jsum(dim=0) on a 1-level tensor would collapse pixels within + # each camera rather than depth samples within each pixel. + if isinstance(enc_feats.lshape[0], list): + enc_feats = enc_feats.jsum(dim=0, keepdim=True) epsilon = 1e-6 enc_feats.jdata = enc_feats.jdata / (torch.linalg.norm(enc_feats.jdata, dim=-1, keepdim=True) + epsilon) diff --git a/instance_segmentation/garfvdb/garfvdb/training/segmentation.py b/instance_segmentation/garfvdb/garfvdb/training/segmentation.py index 0c095fc..11902ea 100644 --- a/instance_segmentation/garfvdb/garfvdb/training/segmentation.py +++ b/instance_segmentation/garfvdb/garfvdb/training/segmentation.py @@ -292,6 +292,17 @@ def device(self) -> torch.device: return torch.device(dev) return dev + @property + def viz_callback(self) -> Callable[["GaussianSplatScaleConditionedSegmentation", int], None] | None: + """Get or set the visualization callback invoked at epoch boundaries.""" + return self._viz_callback + + @viz_callback.setter + def viz_callback( + self, callback: Callable[["GaussianSplatScaleConditionedSegmentation", int], None] | None + ) -> None: + self._viz_callback = callback + @staticmethod def _init_model( model_config: GARfVDBModelConfig, diff --git a/instance_segmentation/garfvdb/garfvdb_environment.yml b/instance_segmentation/garfvdb/garfvdb_environment.yml index fa792e0..436926b 100644 --- a/instance_segmentation/garfvdb/garfvdb_environment.yml +++ b/instance_segmentation/garfvdb/garfvdb_environment.yml @@ -3,10 +3,12 @@ channels: - conda-forge - nodefaults dependencies: + - fvdb-core>=0.4.2,<0.5.0 + - cxx-compiler - blas=*=mkl - python=3.12 - - pytorch-gpu=2.8.0 - - cuda-version=12.9 + - pytorch-gpu=2.10.0 + - cuda-version>=12.9 - pip - git - gitpython @@ -16,6 +18,8 @@ dependencies: - tyro - scikit-learn - py-opencv + - open-clip-torch + - openusd - imageio - sam2 - pyproj @@ -32,3 +36,7 @@ dependencies: - rapidsai::cuml - cupy - matplotlib + - pip: + - dlnr_lite + - fvdb-reality-capture>=0.4.0<0.5.0 + - nanovdb-editor>=0.0.24 diff --git a/instance_segmentation/garfvdb/pyproject.toml b/instance_segmentation/garfvdb/pyproject.toml index 7ca3edb..e6a4de0 100644 --- a/instance_segmentation/garfvdb/pyproject.toml +++ b/instance_segmentation/garfvdb/pyproject.toml @@ -17,8 +17,9 @@ dependencies = [ "tqdm", "tyro", "tensorboard", - "fvdb-core", - "fvdb-reality-capture", + "fvdb-core>=0.4.2,<0.5.0", + "nanovdb-editor>=0.0.24,<0.2.0", + "fvdb-reality-capture>=0.4.0<0.5.0", "pycocotools", "matplotlib", "pillow", diff --git a/instance_segmentation/garfvdb/train_segmentation.py b/instance_segmentation/garfvdb/train_segmentation.py index 61d990e..344d731 100644 --- a/instance_segmentation/garfvdb/train_segmentation.py +++ b/instance_segmentation/garfvdb/train_segmentation.py @@ -3,16 +3,13 @@ # import logging import pathlib -import threading import time -from typing import Callable, Literal +from typing import Literal import cv2 -import fvdb.viz as fviz import numpy as np import torch import tyro - from garfvdb.config import ( GaussianSplatSegmentationTrainingConfig, SfmSceneSegmentationTransformConfig, @@ -41,9 +38,8 @@ def main( visualize_every: int = -1, viewer_port: int = 8080, viewer_ip_address: str = "127.0.0.1", - camera_check_interval: float = 0.5, - overlay_width: int = 1920, - overlay_height: int = 1080, + overlay_width: int = 1440, + overlay_height: int = 720, overlay_downsample: int = 2, mask_scale: float = 0.1, mask_blend: float = 0.5, @@ -55,61 +51,69 @@ def main( Args: sfm_dataset_path (pathlib.Path): Path to the dataset. For "colmap" datasets, this should be the - directory containing the `images` and `sparse` subdirectories. For "simple_directory" datasets, - this should be the directory containing the images and a `cameras.txt` file. For "e57" datasets, - this should be the path to the E57 file or directory. + directory containing the ``images`` and ``sparse`` subdirectories. For "simple_directory" + datasets, this should be the directory containing the images and a ``cameras.txt`` file. + For "e57" datasets, this should be the path to the E57 file or directory. reconstruction_path (pathlib.Path): Path to the precomputed Gaussian splat reconstruction to load. - config (GaussianSplatSegmentationTrainingConfig): Configuration for the Gaussian splat segmentation training. - tx (SfmSceneSegmentationTransformConfig): Configuration for the transforms to apply to the SfmScene before training. - io (GaussianSplatSegmentationWriterConfig): Configuration for saving segmentation metrics and checkpoints. + config (GaussianSplatSegmentationTrainingConfig): Configuration for the Gaussian splat + segmentation training. + tx (SfmSceneSegmentationTransformConfig): Configuration for the transforms to apply to the + SfmScene before training. + io (GaussianSplatSegmentationWriterConfig): Configuration for saving segmentation metrics + and checkpoints. dataset_type (Literal["colmap", "simple_directory", "e57"]): Type of dataset to load. - run_name (str | None): Name of the run. If None, a name will be generated based on the current date and time. - log_path (pathlib.Path | None): Path to log metrics and checkpoints. If None, no metrics or checkpoints will be saved. - use_every_n_as_val (int): Use every n-th image as a validation image. If -1, do not use a separate validation split. + run_name (str | None): Name of the run. If None, a name will be generated based on the + current date and time. + log_path (pathlib.Path | None): Path to log metrics and checkpoints. If None, no metrics + or checkpoints will be saved. + use_every_n_as_val (int): Use every n-th image as a validation image. If -1, do not use + a separate validation split. device (str | torch.device): Device to use for training. log_every (int): Log training metrics every n steps. visualize_every (int): Update the viewer every n epochs. If -1, do not visualize. - viewer_port (int): The port to expose the viewer server on if visualize_every > 0. - viewer_ip_address (str): The IP address to expose the viewer server on if visualize_every > 0. - camera_check_interval (float): How often to check for camera changes (seconds). The segmentation - overlay will update when the viewer camera moves. - overlay_width (int): Width of the segmentation overlay in the viewer. - overlay_height (int): Height of the segmentation overlay in the viewer. - overlay_downsample (int): Downsample factor for rendering. Renders at overlay_size / overlay_downsample - and then scales up for better performance. - mask_scale (float): Fraction of scene max scale to use for rendering segmentation masks (0.1 = 10%). - mask_blend (float): Blend factor for the segmentation overlay (0.0 = transparent, 1.0 = opaque). + viewer_port (int): The port to expose the viewer server on if ``visualize_every > 0``. + viewer_ip_address (str): The IP address to expose the viewer server on if + ``visualize_every > 0``. + overlay_width (int): Width of the segmentation overlay in the viewer. Must match + the nanovdb-editor viewport width for correct alignment (default 1440). + overlay_height (int): Height of the segmentation overlay in the viewer. Must match + the nanovdb-editor viewport height for correct alignment (default 720). + overlay_downsample (int): Downsample factor for rendering. Renders at + ``overlay_size / overlay_downsample`` and then scales up for better performance. + mask_scale (float): Fraction of scene max scale to use for rendering segmentation masks + (0.1 = 10%). + mask_blend (float): Blend factor for the segmentation overlay (0.0 = transparent, + 1.0 = opaque). verbose (bool): Whether to log debug messages. cache_dataset (bool): If True, cache images and masks in memory to speed up data loading. Set to False to reduce memory usage for large datasets. Default is True. """ log_level = logging.DEBUG if verbose else logging.WARNING - logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s") + logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s", force=True) logger = logging.getLogger(__name__) - # Validate camera_check_interval when visualization is enabled - if visualize_every > 0 and camera_check_interval <= 0: - raise ValueError( - f"camera_check_interval must be positive when visualize_every > 0, got {camera_check_interval}" - ) + viewer_enabled = visualize_every > 0 + # ---- Initialize viewer BEFORE any CUDA operations ---- + # The Vulkan device created by fviz.init() must be set up before the CUDA + # context is first created (by load_splats_from_file). This matches the + # initialization order used in frgs reconstruction. + if viewer_enabled: + import fvdb.viz as fviz + fviz.init(ip_address=viewer_ip_address, port=viewer_port, verbose=verbose) - # Load the SfmScene - sfm_scene = load_sfm_scene(sfm_dataset_path, dataset_type) - # Load the GaussianSplat3D model + # ---- Load data ---- + sfm_scene = load_sfm_scene(sfm_dataset_path, dataset_type) gs_model, metadata = load_splats_from_file(reconstruction_path, device) normalization_transform = metadata.get("normalization_transform", None) - # Set up visualization if enabled + # ---- Start the viewer ---- viz_scene = None - viz_callback: Callable[[GaussianSplatScaleConditionedSegmentation, int], None] | None = None + image_view = None + if viewer_enabled: + import fvdb.viz as fviz - if visualize_every > 0: - logger.info(f"Starting viewer server on {viewer_ip_address}:{viewer_port}") - fviz.init(ip_address=viewer_ip_address, port=viewer_port, verbose=verbose) viz_scene = fviz.get_scene("Gaussian Splat Segmentation Training") - - # Add the Gaussian splats to the scene viz_scene.add_gaussian_splat_3d("Gaussian Splats", gs_model) # Set up initial camera position based on scene geometry @@ -129,170 +133,188 @@ def main( up=[0, 0, 1], ) - # Get projection matrix from metadata for camera intrinsics - projection_matrices = metadata.get("projection_matrices", None) - image_sizes = metadata.get("image_sizes", None) - - # Compute render dimensions (smaller for performance) - render_w = overlay_width // overlay_downsample - render_h = overlay_height // overlay_downsample - - # Try to set up the segmentation overlay image at full resolution - image_view = None + # Add the segmentation overlay image + initial_image = np.zeros((overlay_height, overlay_width, 4), dtype=np.uint8) + initial_image[..., 3] = 128 try: - initial_image = np.zeros((overlay_height, overlay_width, 4), dtype=np.uint8) - initial_image[..., 3] = 128 # Semi-transparent image_view = viz_scene.add_image( name="Segmentation Overlay", width=overlay_width, height=overlay_height, rgba_image=initial_image.flatten(), ) - logger.info(f"Segmentation overlay: {overlay_width}x{overlay_height} (render: {render_w}x{render_h})") - except Exception as e: - logger.warning(f"Failed to set up segmentation overlay: {e}") - logger.info("Running without segmentation overlay") - - # Get reference projection for intrinsics from metadata and scale for render resolution - reference_projection = None - if projection_matrices is not None and image_sizes is not None: - orig_projection = projection_matrices[0].float() - orig_w = float(image_sizes[0, 0].item()) - orig_h = float(image_sizes[0, 1].item()) - # Scale the projection matrix to the render resolution - # fx, fy scale with resolution, cx, cy scale with resolution - scale_x = render_w / orig_w - scale_y = render_h / orig_h - scaled_projection = orig_projection.clone() - scaled_projection[0, 0] *= scale_x # fx - scaled_projection[1, 1] *= scale_y # fy - scaled_projection[0, 2] *= scale_x # cx - scaled_projection[1, 2] *= scale_y # cy - reference_projection = scaled_projection.to(device) - - # OpenGL to OpenCV conversion matrix - opengl_to_opencv = np.diag([1.0, -1.0, -1.0, 1.0]).astype(np.float32) - - def get_camera_from_viewer() -> torch.Tensor | None: - """Compute camera-to-world matrix from the viewer's orbit camera state. - - Constructs a 4x4 transformation matrix from the viewer's current - orbit parameters (center, direction, radius, up vector) and converts - it from OpenGL to OpenCV convention for use with the segmentation model. - - Returns: - torch.Tensor | None: A 4x4 camera-to-world matrix in OpenCV - convention, or ``None`` if the camera state cannot be retrieved. - """ - try: - # Get orbit camera state from viewer - center = viz_scene.camera_orbit_center.cpu().numpy().copy() - eye_direction = viz_scene.camera_orbit_direction.cpu().numpy().copy() - radius = viz_scene.camera_orbit_radius - up_world = viz_scene.camera_up_direction.cpu().numpy().copy() - - # Camera position: center - eye_direction * radius - position = center - eye_direction * radius - - # Forward = eye_direction (already the look direction) - forward = eye_direction / np.linalg.norm(eye_direction) - - # Right vector = forward × up_world - right = np.cross(forward, up_world) - right = right / np.linalg.norm(right) - - # Up vector = right × forward - up = np.cross(right, forward) - up = up / np.linalg.norm(up) - - # Build OpenGL-style camera-to-world (X-right, Y-up, Z-backward) - c2w_opengl = np.eye(4, dtype=np.float32) - c2w_opengl[:3, 0] = right - c2w_opengl[:3, 1] = up - c2w_opengl[:3, 2] = -forward # OpenGL: camera looks along -Z - c2w_opengl[:3, 3] = position - - # Convert to OpenCV convention for the segmentation model - c2w_opencv = c2w_opengl @ opengl_to_opencv - return torch.from_numpy(c2w_opencv).float().to(device) - except Exception as e: - logger.debug(f"Failed to get camera from viewer: {e}") - return None + except Exception as exc: + logger.warning(f"Failed to create segmentation overlay image: {exc}") + logger.warning("Overlay visualization will be disabled for this session.") - def render_segmentation_overlay(model: GARfVDBModel, log_prefix: str = "") -> None: - """Render a segmentation mask overlay in the interactive viewer. - - Captures the current viewer camera pose, renders the model's mask - features at a reduced resolution, applies PCA to produce an RGB - visualization, and updates the viewer's overlay image. - - Args: - model: The segmentation model to render. - log_prefix: Optional message to log on successful render. - """ - if image_view is None or reference_projection is None: - return + fviz.show() + logger.info("=" * 60) + logger.info(f"Viewer running at http://{viewer_ip_address}:{viewer_port}") + logger.info(f"Visualization updates every {visualize_every} epoch(s)") + logger.info("=" * 60) - camera_to_world = get_camera_from_viewer() - if camera_to_world is None: - return + # ---- Build the overlay helpers used by the viz_callback ---- + # Compute render dimensions (smaller for performance) + render_w = overlay_width // overlay_downsample + render_h = overlay_height // overlay_downsample + + # Build intrinsic matrix from the viewer's FOV to match its perspective + # camera. Recomputed whenever the FOV changes. + cx = render_w / 2.0 + cy = render_h / 2.0 + cached_fov: float | None = None + reference_projection: torch.Tensor | None = None + + def _update_projection(fov_y_rad: float) -> torch.Tensor: + nonlocal cached_fov, reference_projection + fy = render_h / (2.0 * np.tan(fov_y_rad / 2.0)) + reference_projection = torch.tensor( + [[fy, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], + dtype=torch.float32, + ).to(device) + cached_fov = fov_y_rad + return reference_projection + + # OpenGL to OpenCV conversion matrix + opengl_to_opencv = np.diag([1.0, -1.0, -1.0, 1.0]).astype(np.float32) + + def _camera_tuple_to_c2w( + center: np.ndarray, + eye_direction: np.ndarray, + radius: float, + up_world: np.ndarray, + ) -> torch.Tensor | None: + """Convert orbit camera parameters to a 4x4 camera-to-world matrix. + + Constructs a camera-to-world transformation matrix from the viewer's + orbit parameters (center, direction, radius, up vector) and converts + it from OpenGL to OpenCV convention for use with the segmentation model. + + Returns: + A 4x4 camera-to-world matrix in OpenCV convention on ``device``, + or ``None`` if the camera state is degenerate (zero-length or + near-parallel vectors). + """ + # Camera position: center - eye_direction * radius + position = center - eye_direction * radius + + # Forward = eye_direction (already the look direction) + eye_norm = np.linalg.norm(eye_direction) + if eye_norm < 1e-8: + return None + forward = eye_direction / eye_norm + + # Right vector = forward x up_world + right = np.cross(forward, up_world) + right_norm = np.linalg.norm(right) + if right_norm < 1e-8: + return None + right /= right_norm + + # Up vector = right x forward + up = np.cross(right, forward) + up_norm = np.linalg.norm(up) + if up_norm < 1e-8: + return None + up /= up_norm + + # Build OpenGL-style camera-to-world (X-right, Y-up, Z-backward) + c2w_gl = np.eye(4, dtype=np.float32) + c2w_gl[:3, 0] = right + c2w_gl[:3, 1] = up + c2w_gl[:3, 2] = -forward # OpenGL: camera looks along -Z + c2w_gl[:3, 3] = position + + # Convert to OpenCV convention for the segmentation model + c2w_cv = c2w_gl @ opengl_to_opencv + return torch.from_numpy(c2w_cv).float().to(device) + + def render_overlay(model: GARfVDBModel, camera_to_world: torch.Tensor) -> np.ndarray | None: + """Render the segmentation overlay for a given camera pose. + + Returns an ``(H, W, 4)`` uint8 RGBA image, or ``None`` on failure. + """ + if reference_projection is None: + return None + try: + from garfvdb.training.dataset import GARfVDBInput + # Render at lower resolution for performance try: - from garfvdb.training.dataset import GARfVDBInput - - # Render at lower resolution for performance - model_input = GARfVDBInput( - { - "projection": reference_projection.unsqueeze(0), - "camera_to_world": camera_to_world.unsqueeze(0), - "image_w": [render_w], - "image_h": [render_h], - } - ) - - # Render at configured fraction of max scale - max_scale_tensor = model.max_grouping_scale - desired_scale = float(max_scale_tensor.item()) * mask_scale - - with torch.no_grad(): - mask_features_output, mask_alpha = model.get_mask_output(model_input, desired_scale) - - # Check for invalid values - if torch.isnan(mask_features_output).any() or torch.isinf(mask_features_output).any(): - logger.warning("Invalid values in mask features, skipping visualization update") - return - - # Apply PCA projection to get RGB visualization - mask_pca = pca_projection_fast(mask_features_output, 3, mask=mask_alpha.squeeze(-1) > 0)[0] - - # Blend mask with alpha - blended_alpha = mask_alpha[0] * mask_blend - - rgba = np.concatenate( - [mask_pca.detach().cpu().numpy(), blended_alpha.detach().cpu().numpy()], axis=-1 - ) - rgba_uint8 = (rgba.clip(0.0, 1.0) * 255).astype(np.uint8) - - # Scale up to overlay resolution - if overlay_downsample > 1: - rgba_uint8 = cv2.resize( - rgba_uint8, (overlay_width, overlay_height), interpolation=cv2.INTER_LINEAR - ) - - image_view.update(rgba_uint8.flatten()) - if log_prefix: - logger.debug(f"{log_prefix}") - - except Exception as e: - logger.warning(f"Error updating visualization: {e}") - - # Create visualization callback that uses the viewer camera - def update_visualization(runner: GaussianSplatScaleConditionedSegmentation, epoch: int) -> None: - """Update the visualization overlay during training from current viewer camera.""" - render_segmentation_overlay(runner.model, f"Updated segmentation visualization at epoch {epoch}") - - viz_callback = update_visualization + world_to_camera = torch.linalg.inv(camera_to_world).contiguous() + except torch.linalg.LinAlgError: + return None + model_input = GARfVDBInput( + { + "projection": reference_projection.unsqueeze(0), + "camera_to_world": camera_to_world.unsqueeze(0).contiguous(), + "world_to_camera": world_to_camera.unsqueeze(0), + "image_w": [render_w], + "image_h": [render_h], + } + ) + # Render at configured fraction of max scale + max_scale_tensor = model.max_grouping_scale + desired_scale = float(max_scale_tensor.item()) * mask_scale + + with torch.no_grad(): + mask_features_output, mask_alpha = model.get_mask_output(model_input, desired_scale) + + # Check for invalid values + if torch.isnan(mask_features_output).any() or torch.isinf(mask_features_output).any(): + return None + + # Apply PCA projection to get RGB visualization + mask_pca = pca_projection_fast(mask_features_output, 3, mask=mask_alpha.squeeze(-1) > 0)[0] + + # Blend mask with alpha + blended_alpha = mask_alpha[0] * mask_blend + rgba = np.concatenate([mask_pca.detach().cpu().numpy(), blended_alpha.detach().cpu().numpy()], axis=-1) + rgba_uint8 = (rgba.clip(0.0, 1.0) * 255).astype(np.uint8) + + # Scale up to overlay resolution + if overlay_downsample > 1: + rgba_uint8 = cv2.resize(rgba_uint8, (overlay_width, overlay_height), interpolation=cv2.INTER_LINEAR) + return rgba_uint8 + except Exception as exc: + logger.warning(f"Error rendering overlay: {exc}") + return None + + def get_viewer_camera() -> tuple[np.ndarray, np.ndarray, float, np.ndarray] | None: + """Read the current orbit camera state directly from the viewer.""" + if viz_scene is None: + return None + try: + return ( + viz_scene.camera_orbit_center.cpu().numpy(), + viz_scene.camera_orbit_direction.cpu().numpy(), + float(viz_scene.camera_orbit_radius), + viz_scene.camera_up_direction.cpu().numpy(), + ) + except Exception: + return None + + def update_visualization(runner_arg: GaussianSplatScaleConditionedSegmentation, epoch: int) -> None: + """Viz callback invoked at epoch boundaries to update the segmentation overlay.""" + if image_view is None or viz_scene is None: + return + cam = get_viewer_camera() + if cam is None: + return + fov_y_rad = viz_scene.camera_fov + if reference_projection is None or fov_y_rad != cached_fov: + _update_projection(fov_y_rad) + c2w = _camera_tuple_to_c2w(*cam) + if c2w is None: + return + frame = render_overlay(runner_arg.model, c2w) + if frame is not None: + image_view.update(frame.flatten()) + logger.debug(f"Updated segmentation overlay at epoch {epoch}") + +# ---- Create the runner (SAM2 masks + model init) ---- + # ---- Create the runner (SAM2 masks + model init) ---- writer = GaussianSplatSegmentationWriter(run_name=run_name, save_path=log_path, config=io, exist_ok=False) runner = GaussianSplatScaleConditionedSegmentation.new( sfm_scene=tx.build_scene_transforms(gs_model, normalization_transform)(sfm_scene), @@ -304,125 +326,32 @@ def update_visualization(runner: GaussianSplatScaleConditionedSegmentation, epoc use_every_n_as_val=use_every_n_as_val, viewer_update_interval_epochs=visualize_every, log_interval_steps=log_every, - viz_callback=viz_callback, + viz_callback=update_visualization if image_view is not None else None, cache_dataset=cache_dataset, ) - # Set up camera tracking thread for interactive visualization - camera_thread = None - stop_camera_thread = threading.Event() - - if viz_scene is not None and image_view is not None: - # Camera state for change detection - prev_center = None - prev_direction = None - prev_radius = None - prev_up = None - - def camera_changed() -> bool: - """Determine whether the interactive viewer camera state has changed. - - Compares the current camera orbit parameters (center, direction, - radius, and up vector) from ``viz_scene`` against the previously - observed values stored in the nonlocal variables ``prev_center``, - ``prev_direction``, ``prev_radius``, and ``prev_up``. - - Notes: - This function updates the ``prev_*`` variables when called for - the first time or when a change is detected. - - Returns: - bool: ``True`` if this is the first call (to force an initial - overlay update) or if any of the tracked camera parameters - differ from the previously stored values; ``False`` otherwise - or if an error occurs while querying the camera state. - """ - nonlocal prev_center, prev_direction, prev_radius, prev_up - try: - center = viz_scene.camera_orbit_center - direction = viz_scene.camera_orbit_direction - radius = viz_scene.camera_orbit_radius - up = viz_scene.camera_up_direction - - # First time - always update - if prev_center is None: - prev_center = center - prev_direction = direction - prev_radius = radius - prev_up = up - return True - - changed = ( - not torch.allclose(center, prev_center) - or not torch.allclose(direction, prev_direction) # type: ignore - or radius != prev_radius - or not torch.allclose(up, prev_up) # type: ignore - ) - - if changed: - prev_center = center - prev_direction = direction - prev_radius = radius - prev_up = up - - return changed - except Exception as exc: - logger.debug( - "Failed to retrieve camera state in camera_changed: %s", - exc, - exc_info=True, - ) - return False - - def camera_monitor_loop() -> None: - """Monitor camera changes in a background thread and trigger overlay updates. - - This function is intended to be run in a dedicated daemon thread started - from :func:`main`. It periodically checks the interactive viewer camera - state and, when a change is detected, re-renders the segmentation overlay - for the current model. - - The loop runs until the ``stop_camera_thread`` :class:`threading.Event` - is set (e.g. in the main thread's shutdown/KeyboardInterrupt handler), - at which point the background thread exits cleanly. - - While :meth:`runner.train` executes in the main thread, this monitor - runs concurrently to keep the visualization in sync with user-driven - camera movements, and it continues to run during the post-training - viewing phase until shutdown. - """ - while not stop_camera_thread.is_set(): - time.sleep(camera_check_interval) - if camera_changed(): - render_segmentation_overlay(runner.model, "Updated segmentation overlay from camera") - - # Start the camera monitoring thread - camera_thread = threading.Thread(target=camera_monitor_loop, daemon=True) - camera_thread.start() - logger.info(f"Camera tracking enabled (checking every {camera_check_interval}s)") - - if viz_scene is not None: - logger.info("=" * 60) - logger.info(f"Viewer running at http://{viewer_ip_address}:{viewer_port}") - logger.info(f"Visualization updates every {visualize_every} epoch(s)") - if camera_thread is not None: - logger.info(f"Camera tracking updates every {camera_check_interval}s") - logger.info("=" * 60) - fviz.show() - + # ---- Train ---- runner.train() + # ---- Post-training interactive viewing ---- if viz_scene is not None: logger.info("Training complete. Viewer running... Ctrl+C to exit.") try: - # Camera thread continues running during post-training viewing while True: - time.sleep(1) + cam = get_viewer_camera() + if cam is not None: + fov_y_rad = viz_scene.camera_fov + if reference_projection is None or fov_y_rad != cached_fov: + _update_projection(fov_y_rad) + c2w = _camera_tuple_to_c2w(*cam) + if c2w is None: + continue + frame = render_overlay(runner.model, c2w) + if frame is not None and image_view is not None: + image_view.update(frame.flatten()) + time.sleep(0.5) except KeyboardInterrupt: - if camera_thread is not None: - stop_camera_thread.set() - camera_thread.join(timeout=5.0) - logger.info("Shutting down...") + pass if __name__ == "__main__": diff --git a/instance_segmentation/garfvdb/view_checkpoint.py b/instance_segmentation/garfvdb/view_checkpoint.py index ac1d806..1f223fa 100644 --- a/instance_segmentation/garfvdb/view_checkpoint.py +++ b/instance_segmentation/garfvdb/view_checkpoint.py @@ -233,10 +233,10 @@ class ViewCheckpoint: device: str | torch.device = "cuda" """Device for computation (e.g., "cuda" or "cpu").""" - initial_scale: float = 0.1 + mask_scale: float = 0.1 """Initial segmentation scale as a fraction of max scale.""" - initial_blend: float = 0.5 + mask_blend: float = 0.5 """Initial mask blend factor (0=beauty only, 1=mask only).""" camera_check_interval: float = 0.5 @@ -245,11 +245,13 @@ class ViewCheckpoint: no_overlay: bool = False """Disable the segmentation overlay (show Gaussian splats only).""" - overlay_width: int = 1920 - """Width of the segmentation overlay in pixels.""" + overlay_width: int = 1440 + """Width of the segmentation overlay in pixels. Must match the + nanovdb-editor viewport width for correct alignment (default 1440).""" - overlay_height: int = 1080 - """Height of the segmentation overlay in pixels.""" + overlay_height: int = 720 + """Height of the segmentation overlay in pixels. Must match the + nanovdb-editor viewport height for correct alignment (default 720).""" overlay_downsample: int = 2 """Downsample factor for rendering (renders at overlay_size/downsample).""" @@ -260,6 +262,11 @@ def execute(self) -> None: logging.basicConfig(level=log_level, format="%(levelname)s : %(message)s") logger = logging.getLogger(__name__) + # Initialize fvdb.viz + logger.info(f"Starting viewer server on {self.viewer_ip_address}:{self.viewer_port}") + fviz.init(ip_address=self.viewer_ip_address, port=self.viewer_port, verbose=self.verbose) + viz_scene = fviz.get_scene("GarfVDB Segmentation Viewer") + device = torch.device(self.device) # Validate segmentation checkpoint path @@ -299,13 +306,8 @@ def execute(self) -> None: segmentation_model=segmentation_model, device=device, ) - renderer.scale = self.initial_scale * float(segmentation_model.max_grouping_scale.item()) - renderer.mask_blend = self.initial_blend - - # Initialize fvdb.viz - logger.info(f"Starting viewer server on {self.viewer_ip_address}:{self.viewer_port}") - fviz.init(ip_address=self.viewer_ip_address, port=self.viewer_port, verbose=self.verbose) - viz_scene = fviz.get_scene("GarfVDB Segmentation Viewer") + renderer.scale = self.mask_scale * float(segmentation_model.max_grouping_scale.item()) + renderer.mask_blend = self.mask_blend # Set initial camera position scene_centroid = gs_model.means.mean(dim=0).cpu().numpy() @@ -338,11 +340,6 @@ def execute(self) -> None: image_sizes=image_sizes, ) - # Get original training image size for projection scaling - assert image_sizes is not None - orig_img_w = int(image_sizes[0, 0].item()) - orig_img_h = int(image_sizes[0, 1].item()) - # Compute render dimensions (smaller for performance) render_w = self.overlay_width // self.overlay_downsample render_h = self.overlay_height // self.overlay_downsample @@ -351,29 +348,10 @@ def execute(self) -> None: # Add the Gaussian splat model to the scene viz_scene.add_gaussian_splat_3d("Gaussian Splats", gs_model) - # Set up the segmentation overlay if enabled + # Overlay is created lazily on first render to avoid the C++ viewer + # render thread touching an image grid before we have real content. image_view = None - if not self.no_overlay: - try: - # Create an initial blank image at full overlay resolution - initial_image = np.zeros((self.overlay_height, self.overlay_width, 4), dtype=np.uint8) - initial_image[..., 3] = 128 # Semi-transparent - - # Add the image overlay using the add_image API - image_view = viz_scene.add_image( # type: ignore[call-arg] - name="Segmentation Overlay", - width=self.overlay_width, - height=self.overlay_height, - rgba_image=initial_image.flatten(), - ) - logger.info("Segmentation overlay enabled") - - except AttributeError as e: - logger.warning(f"add_image API not available: {e}") - logger.info("Running without segmentation overlay") - except Exception as e: - logger.warning(f"Failed to set up segmentation overlay: {e}") - logger.info("Running without segmentation overlay") + overlay_enabled = not self.no_overlay logger.info("=" * 60) logger.info("Viewer running... Ctrl+C to exit.") @@ -382,7 +360,7 @@ def execute(self) -> None: logger.info("Segmentation settings:") logger.info(f" - Scale: {renderer.scale:.4f} (max: {segmentation_model.max_grouping_scale:.4f})") logger.info(f" - Mask blend: {renderer.mask_blend:.2f}") - if image_view is not None: + if overlay_enabled: logger.info(f" - Overlay: {self.overlay_width}x{self.overlay_height} (render: {render_w}x{render_h})") logger.info(f" - Update interval: {self.camera_check_interval}s") else: @@ -396,15 +374,17 @@ def execute(self) -> None: prev_direction = None prev_radius = None prev_up = None + prev_fov = None def camera_changed() -> bool: """Check if camera state changed using documented fvdb.viz.Scene properties.""" - nonlocal prev_center, prev_direction, prev_radius, prev_up + nonlocal prev_center, prev_direction, prev_radius, prev_up, prev_fov try: - center = viz_scene.camera_orbit_center - direction = viz_scene.camera_orbit_direction + fov = viz_scene.camera_fov + center = viz_scene.camera_orbit_center.clone().cpu().numpy() + direction = viz_scene.camera_orbit_direction.clone().cpu().numpy() radius = viz_scene.camera_orbit_radius - up = viz_scene.camera_up_direction + up = viz_scene.camera_up_direction.clone().cpu().numpy() # First time - always update if prev_center is None: @@ -412,6 +392,7 @@ def camera_changed() -> bool: prev_direction = direction prev_radius = radius prev_up = up + prev_fov = fov return True assert prev_center is not None and prev_direction is not None and prev_up is not None @@ -420,6 +401,7 @@ def camera_changed() -> bool: or not np.allclose(direction, prev_direction) or radius != prev_radius or not np.allclose(up, prev_up) + or fov != prev_fov ) if changed: @@ -427,23 +409,28 @@ def camera_changed() -> bool: prev_direction = direction prev_radius = radius prev_up = up + prev_fov = fov return changed except Exception: return False - # Get reference projection from SfmScene and scale for render resolution - sfm_projection = torch.from_numpy(sfm_scene.projection_matrices).float() - orig_projection = sfm_projection[0] - # Scale the projection matrix from original training image size to render resolution - scale_x = render_w / orig_img_w - scale_y = render_h / orig_img_h - scaled_projection = orig_projection.clone() - scaled_projection[0, 0] *= scale_x # fx - scaled_projection[1, 1] *= scale_y # fy - scaled_projection[0, 2] *= scale_x # cx - scaled_projection[1, 2] *= scale_y # cy - reference_projection = scaled_projection.to(device) + # Build intrinsic matrix from the viewer's FOV to match its perspective + # camera. Recomputed whenever the FOV changes. + cx = render_w / 2.0 + cy = render_h / 2.0 + cached_fov: float | None = None + reference_projection: torch.Tensor | None = None + + def _update_projection(fov_y_rad: float) -> torch.Tensor: + nonlocal cached_fov, reference_projection + fy = render_h / (2.0 * np.tan(fov_y_rad / 2.0)) + reference_projection = torch.tensor( + [[fy, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], + dtype=torch.float32, + ).to(device) + cached_fov = fov_y_rad + return reference_projection # OpenGL to OpenCV conversion matrix (applied to camera axes) # OpenGL: X-right, Y-up, Z-backward @@ -451,33 +438,45 @@ def camera_changed() -> bool: opengl_to_opencv = np.diag([1.0, -1.0, -1.0, 1.0]).astype(np.float32) def update_overlay() -> None: - """Render and update the segmentation overlay.""" - if image_view is None: + """Render segmentation and lazily create/update the image overlay.""" + nonlocal image_view, overlay_enabled + if not overlay_enabled: return try: - # Get orbit camera state from viewer + fov_y_rad = viz_scene.camera_fov + if reference_projection is None or fov_y_rad != cached_fov: + _update_projection(fov_y_rad) + # NOTE: Despite the Python API name "camera_orbit_direction", the C++ implementation # returns eye_direction which is the direction the camera is LOOKING (toward scene). # Camera position = center - eye_direction * distance (see Camera.h line 387-390) - center = np.array(viz_scene.camera_orbit_center.cpu().numpy()) - eye_direction = np.array(viz_scene.camera_orbit_direction.cpu().numpy()) + center = viz_scene.camera_orbit_center.clone().cpu().numpy() + eye_direction = viz_scene.camera_orbit_direction.clone().cpu().numpy() radius = viz_scene.camera_orbit_radius - up_world = np.array(viz_scene.camera_up_direction.cpu().numpy()) - - # Camera position: center - eye_direction * radius - # (eye_direction points FROM camera TOWARD center) + up_world = viz_scene.camera_up_direction.clone().cpu().numpy() position = center - eye_direction * radius - # Forward = eye_direction (already the look direction) - forward = eye_direction / np.linalg.norm(eye_direction) + # Guard against degenerate camera states (zero-length vectors + # produce NaN and ultimately a singular matrix). + eye_norm = np.linalg.norm(eye_direction) + if eye_norm < 1e-8: + return - # Right vector = forward × up_world + forward = eye_direction / eye_norm + + # Right vector = forward x up_world right = np.cross(forward, up_world) - right = right / np.linalg.norm(right) + right_norm = np.linalg.norm(right) + if right_norm < 1e-8: + return + right = right / right_norm - # Up vector = right × forward + # Up vector = right x forward up = np.cross(right, forward) - up = up / np.linalg.norm(up) + up_norm = np.linalg.norm(up) + if up_norm < 1e-8: + return + up = up / up_norm # Build OpenGL-style camera-to-world (X-right, Y-up, Z-backward) # In OpenGL, camera looks along -Z, so Z column = -forward @@ -491,8 +490,13 @@ def update_overlay() -> None: c2w_opencv = c2w_opengl @ opengl_to_opencv camera_to_world = torch.from_numpy(c2w_opencv).float().to(renderer.device) - world_to_camera = torch.linalg.inv(camera_to_world) + try: + world_to_camera = torch.linalg.inv(camera_to_world).contiguous() + except torch._C._LinAlgError: + return + # Render at lower resolution for performance + assert reference_projection is not None rgba_image = renderer.render_segmentation_image( camera_to_world, world_to_camera, reference_projection, render_w, render_h ) @@ -501,7 +505,25 @@ def update_overlay() -> None: rgba_image = cv2.resize( rgba_image, (self.overlay_width, self.overlay_height), interpolation=cv2.INTER_LINEAR ) - image_view.update(rgba_image.flatten()) # type: ignore[attr-defined] + + flat_rgba = rgba_image.flatten() + + if image_view is None: + # Lazily create the image overlay on first render + try: + image_view = viz_scene.add_image( # type: ignore[call-arg] + name="Segmentation Overlay", + width=self.overlay_width, + height=self.overlay_height, + rgba_image=flat_rgba, + ) + except Exception as e: + logger.warning(f"add_image API not available or failed: {e}") + overlay_enabled = False + return + else: + image_view.update(flat_rgba) # type: ignore[attr-defined] + logger.debug( f"Updated segmentation overlay ({render_w}x{render_h} -> {self.overlay_width}x{self.overlay_height})" ) @@ -515,7 +537,7 @@ def update_overlay() -> None: while True: time.sleep(self.camera_check_interval) - if image_view is not None and camera_changed(): + if overlay_enabled and camera_changed(): logger.debug("Camera changed, updating overlay...") update_overlay()