diff --git a/.gitignore b/.gitignore index f3567a1f0..f14b16aeb 100644 --- a/.gitignore +++ b/.gitignore @@ -72,3 +72,6 @@ scratch/ *_benchmark.json *_results.json comparison_*.png + +# A staging file for PR descriptions +_PR_STAGING.md diff --git a/CMakeLists.txt b/CMakeLists.txt index 02ec5bb58..96be51c02 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,7 +34,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/) set(FVDB_BINDINGS_CPP_FILES src/python/Bindings.cpp src/python/FusedSSIMBinding.cpp - src/python/GaussianSplatBinding.cpp + src/python/GaussianSplatOps.cpp src/python/GridBatchDataBinding.cpp src/python/GridBatchOps.cpp src/python/JaggedTensorBinding.cpp diff --git a/docs/api/functional_splat.rst b/docs/api/functional_splat.rst new file mode 100644 index 000000000..ee79775c4 --- /dev/null +++ b/docs/api/functional_splat.rst @@ -0,0 +1,136 @@ +Functional Gaussian Splatting API +================================= + +.. module:: fvdb.functional + +The Gaussian splatting functions in :mod:`fvdb.functional` provide a +pure-function interface for Gaussian splatting rendering. Every operation is a +standalone function that takes raw tensors as input, following the same design +philosophy as the rest of :mod:`fvdb.functional` for sparse-grid operations. + +The API is organized as a **4-stage composable pipeline**: + +1. ``project_gaussians`` -- geometric projection (Stage 1) +2. ``evaluate_gaussian_sh`` -- SH / feature evaluation (Stage 2) +3. ``intersect_gaussian_tiles`` / ``intersect_gaussian_tiles_sparse`` -- tile intersection (Stage 3) +4. ``rasterize_screen_space_gaussians`` / ``rasterize_world_space_gaussians`` + / ``rasterize_screen_space_gaussians_sparse`` -- rasterization (Stage 4) + +All stages are fully differentiable via Python autograd (except tile intersection). + +.. tip:: + + For standard rendering, the methods on :class:`~fvdb.GaussianSplat3d` + are the simplest entry points. + + The decomposed stages are for users who need fine-grained control over the + rendering pipeline -- for example, to insert custom logic between projection + and rasterization, or to build training loops without the + :class:`~fvdb.GaussianSplat3d` wrapper. + + +**Example: building a custom render pipeline** + +.. code-block:: python + + import torch + import fvdb.functional as F + from fvdb.enums import CameraModel, GaussianRenderMode + + # Raw tensors (no GaussianSplat3d needed) + means = ... # [N, 3] + quats = ... # [N, 4] + log_scales = ... # [N, 3] + logit_opacities = ... # [N] + sh0 = ... # [N, 1, 3] + shN = ... # [N, K-1, 3] + world_to_cam = ... # [C, 4, 4] + K = ... # [C, 3, 3] + + # Stage 1: Geometric projection + projected = F.project_gaussians( + means, quats, log_scales, world_to_cam, K, + image_width=640, image_height=480, + ) + + # Stage 2: View-dependent features (SH evaluation) + features = F.evaluate_gaussian_sh( + means, sh0, shN, world_to_cam, projected, + sh_degree_to_use=3, + render_mode=GaussianRenderMode.FEATURES, + ) + + # Stage 3: Tile intersection + tiles = F.intersect_gaussian_tiles(projected) + + # Stage 4: Rasterize + images, alphas = F.rasterize_screen_space_gaussians( + projected, features, logit_opacities, tiles, + ) + + # Compute loss and backpropagate -- gradients flow through all stages + loss = torch.nn.functional.l1_loss(images, target_images) + loss.backward() + + +Types +------ + +.. autoclass:: ProjectedGaussians + :members: + +.. autoclass:: GaussianTileIntersection + :members: + +.. autoclass:: SparseGaussianTileIntersection + :members: + + +Stage 1: Projection +^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: project_gaussians + + +Stage 2: SH / Feature Evaluation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: evaluate_gaussian_sh + + +Stage 3: Tile Intersection +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: intersect_gaussian_tiles + +.. autofunction:: intersect_gaussian_tiles_sparse + + +Stage 4: Rasterization +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: rasterize_screen_space_gaussians + +.. autofunction:: rasterize_world_space_gaussians + +.. autofunction:: rasterize_screen_space_gaussians_sparse + + +Analysis +--------- + +.. autofunction:: count_contributing_gaussians + +.. autofunction:: identify_contributing_gaussians + +.. autofunction:: count_contributing_gaussians_sparse + +.. autofunction:: identify_contributing_gaussians_sparse + + +Metrics +-------- + +.. autofunction:: psnr + +.. autofunction:: ssim diff --git a/docs/api/gaussian_splatting.rst b/docs/api/gaussian_splatting.rst index 91959cf8b..234b9ac0e 100644 --- a/docs/api/gaussian_splatting.rst +++ b/docs/api/gaussian_splatting.rst @@ -1,6 +1,12 @@ Gaussian Splatting ========================== +The :class:`~fvdb.GaussianSplat3d` class provides an object-oriented interface +for Gaussian splatting rendering. It manages Gaussian parameters (means, +quaternions, scales, opacities, SH coefficients) and provides methods for +projection, rasterization, and training-related operations like gradient +accumulation and MCMC densification. + .. autoclass:: fvdb.ProjectedGaussianSplats :members: :special-members: __getitem__, __setitem__ @@ -8,3 +14,15 @@ Gaussian Splatting .. autoclass:: fvdb.GaussianSplat3d :members: :special-members: __getitem__, __setitem__ + +.. seealso:: + + :mod:`fvdb.functional` provides a pure-function alternative for + building custom rendering pipelines without the + :class:`~fvdb.GaussianSplat3d` wrapper. The functional API decomposes + the rendering pipeline into individually composable stages + (:func:`~fvdb.functional.project_gaussians`, + :func:`~fvdb.functional.evaluate_gaussian_sh`, + :func:`~fvdb.functional.intersect_gaussian_tiles`, + :func:`~fvdb.functional.rasterize_screen_space_gaussians`), enabling + custom training loops and pipeline composition without mutable state. diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 885b18d40..48106dc13 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -7,5 +7,7 @@ Utilities ``fvdb.utils.metrics`` -------------------------- -.. automodule:: fvdb.utils.metrics - :members: +.. deprecated:: + The ``fvdb.utils.metrics`` module re-exports :func:`~fvdb.functional.psnr` + and :func:`~fvdb.functional.ssim` for backward compatibility. Use + ``fvdb.functional.psnr`` and ``fvdb.functional.ssim`` directly. diff --git a/docs/index.rst b/docs/index.rst index 8770f8038..8259ee41a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -100,6 +100,7 @@ algorithms for 3D reconstruction from sensor data. api/convolution_plan api/sparse_grids api/functional + api/functional_splat api/gaussian_splatting api/viz api/enums diff --git a/fvdb/__init__.py b/fvdb/__init__.py index 1d32482ff..c87c35347 100644 --- a/fvdb/__init__.py +++ b/fvdb/__init__.py @@ -55,8 +55,6 @@ def _parse_device_string(device_or_device_string: str | torch.device) -> torch.d pass # isort: off -from ._fvdb_cpp import gaussian_render_jagged as _gaussian_render_jagged_cpp -from ._fvdb_cpp import evaluate_spherical_harmonics from ._fvdb_cpp import ( config, volume_render, @@ -95,34 +93,173 @@ def gaussian_render_jagged( backgrounds: torch.Tensor | None = None, masks: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]: - return _gaussian_render_jagged_cpp( - means=means._impl, - quats=quats._impl, - scales=scales._impl, - opacities=opacities._impl, - sh_coeffs=sh_coeffs._impl, - viewmats=viewmats._impl, - Ks=Ks._impl, - image_width=image_width, - image_height=image_height, - near_plane=near_plane, - far_plane=far_plane, - sh_degree_to_use=sh_degree_to_use, - tile_size=tile_size, - radius_clip=radius_clip, - eps2d=eps2d, - antialias=antialias, - render_depth_channel=render_depth_channel, - return_debug_info=return_debug_info, - ortho=ortho, - backgrounds=backgrounds, - masks=masks, + import math + + from . import _fvdb_cpp as _C + from .functional._gaussian_projection import _ProjectGaussiansJaggedFn + from .functional._gaussian_rasterization import _RasterizeScreenSpaceGaussiansFn + from .functional._gaussian_spherical_harmonics import _EvaluateGaussianSHFn + + ccz = viewmats.jdata.shape[0] # total number of cameras + device = means.jdata.device + dtype = means.jdata.dtype + + # ---- Step 1: Compute batch indices (camera_ids, gaussian_ids) ---- + # g_sizes is [N1, N2, ...], c_sizes is [C1, C2, ...] + g_sizes = means.joffsets[1:] - means.joffsets[:-1] + c_sizes = Ks.joffsets[1:] - Ks.joffsets[:-1] + + # camera_ids: for each element in the expanded (C_i * N_i) layout, which camera it belongs to + tt = g_sizes.repeat_interleave(c_sizes) + camera_ids = torch.arange(ccz, device=device, dtype=torch.int32).repeat_interleave(tt, 0) + + # gaussian_ids: for each element, which gaussian (in the flat jdata) it belongs to + dd0 = means.joffsets[:-1].repeat_interleave(c_sizes, 0) + dd1 = means.joffsets[1:].repeat_interleave(c_sizes, 0) + shifts = dd0[1:] - dd1[:-1] + shifts = torch.cat([torch.tensor([0], device=device), shifts]) + shifts_cumsum = shifts.cumsum(0) + gaussian_ids = torch.arange(camera_ids.shape[0], device=device, dtype=torch.int32) + gaussian_ids = gaussian_ids + shifts_cumsum.repeat_interleave(tt, 0) + + # ---- Step 2: Jagged projection (differentiable via Python autograd) ---- + radii, means2d, depths, conics = _ProjectGaussiansJaggedFn.apply( + g_sizes, + means.jdata, + quats.jdata, + scales.jdata, + c_sizes, + viewmats.jdata, + Ks.jdata, + image_width, + image_height, + eps2d, + near_plane, + far_plane, + radius_clip, + ortho, + ) + + # ---- Step 3: Gather opacities for the expanded layout ---- + opacities_batched = opacities.jdata[gaussian_ids] # [M] + + # ---- Debug info (populated before SH eval so we capture projection outputs) ---- + debug_info: dict[str, torch.Tensor] = {} + if return_debug_info: + debug_info["camera_ids"] = camera_ids + debug_info["gaussian_ids"] = gaussian_ids + debug_info["radii"] = radii + debug_info["means2d"] = means2d + debug_info["depths"] = depths + debug_info["conics"] = conics + debug_info["opacities"] = opacities_batched + + # ---- Step 4: Compute render features (SH eval or depth) ---- + nnz = camera_ids.shape[0] + D = sh_coeffs.jdata.shape[-1] # feature dimension (typically 3 for RGB) + + # sh_coeffs.jdata is [total_N, K, D]; permute to [K, total_N, D] then gather + sh_coeffs_batched = sh_coeffs.jdata.permute(1, 0, 2)[:, gaussian_ids, :] # [K, nnz, D] + K = sh_coeffs_batched.shape[0] + actual_sh_degree = sh_degree_to_use if sh_degree_to_use >= 0 else int(math.sqrt(K)) - 1 + + if actual_sh_degree == 0: + sh0 = sh_coeffs_batched[0].unsqueeze(0) # [1, nnz, D] + features = _EvaluateGaussianSHFn.apply( + actual_sh_degree, + 1, + torch.zeros(1, nnz, 3, device=device, dtype=dtype), + sh0.permute(1, 0, 2), # [nnz, 1, D] + torch.empty(nnz, 0, D, device=device, dtype=dtype), + radii.unsqueeze(0), # [1, nnz] + ) + else: + sh0 = sh_coeffs_batched[0].unsqueeze(0) # [1, nnz, D] + shN = sh_coeffs_batched[1:] # [K-1, nnz, D] + cam_to_world = torch.linalg.inv(viewmats.jdata) # [ccz, 4, 4] + dirs = means.jdata[gaussian_ids] - cam_to_world[camera_ids, :3, 3] # [nnz, 3] + features = _EvaluateGaussianSHFn.apply( + actual_sh_degree, + 1, + dirs.unsqueeze(0), # [1, nnz, 3] + sh0.permute(1, 0, 2), # [nnz, 1, D] + shN.permute(1, 0, 2), # [nnz, K-1, D] + radii.unsqueeze(0), # [1, nnz] + ) + features = features.squeeze(0) # [C=1, nnz, D] -> [nnz, D] + + if render_depth_channel: + features = torch.cat([features, depths[gaussian_ids].unsqueeze(-1)], -1) + + # ---- Step 5: Tile intersection (non-differentiable) ---- + num_tiles_w = math.ceil(image_width / tile_size) + num_tiles_h = math.ceil(image_height / tile_size) + tile_offsets, tile_gaussian_ids = _C.intersect_gaussian_tiles( + means2d, + radii, + depths, + ccz, + tile_size, + num_tiles_h, + num_tiles_w, + camera_ids=camera_ids, + ) + + if return_debug_info: + debug_info["tile_offsets"] = tile_offsets + debug_info["tile_gaussian_ids"] = tile_gaussian_ids + + # ---- Step 6: Rasterize (differentiable via Python autograd) ---- + images, alphas = _RasterizeScreenSpaceGaussiansFn.apply( + means2d, + conics, + features, + opacities_batched.contiguous(), + image_width, + image_height, + 0, + 0, + tile_size, + tile_offsets, + tile_gaussian_ids, + False, + backgrounds, + masks, + ) + + return images, alphas, debug_info + + +def evaluate_spherical_harmonics( + sh_degree: int, + num_cameras: int, + sh0: torch.Tensor, + radii: torch.Tensor, + shN: torch.Tensor | None = None, + view_directions: torch.Tensor | None = None, +) -> torch.Tensor: + """Evaluate spherical harmonics (differentiable via Python autograd).""" + from .functional._gaussian_spherical_harmonics import _EvaluateGaussianSHFn + + if sh_degree > 0: + if view_directions is None: + raise ValueError("view_directions must be provided when sh_degree > 0") + if shN is None: + raise ValueError("shN must be provided when sh_degree > 0") + view_dirs = ( + view_directions + if view_directions is not None + else torch.zeros(num_cameras, sh0.size(0), 3, device=sh0.device, dtype=sh0.dtype) ) + shN_val = shN if shN is not None else torch.empty(sh0.size(0), 0, sh0.size(2), device=sh0.device, dtype=sh0.dtype) + return _EvaluateGaussianSHFn.apply(sh_degree, num_cameras, view_dirs, sh0, shN_val, radii) from .convolution_plan import ConvolutionPlan -from .gaussian_splatting import GaussianSplat3d, ProjectedGaussianSplats -from .enums import CameraModel, ProjectionMethod, RollingShutterType, ShOrderingMode +from .gaussian_splatting import GaussianSplat3d +from .functional._gaussian_projection import ProjectedGaussians +from .functional._gaussian_tile_intersection import GaussianTileIntersection, SparseGaussianTileIntersection +from .enums import CameraModel, GaussianRenderMode, ProjectionMethod, RollingShutterType, ShOrderingMode # Import torch-compatible functions that work with both Tensor and JaggedTensor from .torch_jagged import ( @@ -185,8 +322,11 @@ def gaussian_render_jagged( "GridBatch", "JaggedTensor", "GaussianSplat3d", - "ProjectedGaussianSplats", + "ProjectedGaussians", + "GaussianTileIntersection", + "SparseGaussianTileIntersection", "CameraModel", + "GaussianRenderMode", "ProjectionMethod", "RollingShutterType", "ShOrderingMode", diff --git a/fvdb/__init__.pyi b/fvdb/__init__.pyi deleted file mode 100644 index c0a204f82..000000000 --- a/fvdb/__init__.pyi +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright Contributors to the OpenVDB Project -# SPDX-License-Identifier: Apache-2.0 -# -from __future__ import annotations - -from collections.abc import Sequence -from typing import overload - -import torch - -if torch.cuda.is_available(): - torch.cuda.init() - -def _parse_device_string(device_string: str | torch.device) -> torch.device: ... - -# Make these available without an explicit submodule import -# The following import needs to come after the GridBatch and JaggedTensor imports -# immediately above in order to avoid a circular dependency error. -from . import nn, utils, viz -from ._fvdb_cpp import config, hilbert, morton, volume_render -from .attention import scaled_dot_product_attention -from .convolution_plan import ConvolutionPlan -from .enums import CameraModel, ProjectionMethod, RollingShutterType, ShOrderingMode -from .gaussian_splatting import GaussianSplat3d, ProjectedGaussianSplats -from .grid import Grid -from .grid_batch import GridBatch, gcat -from .jagged_tensor import JaggedTensor, jcat -from .torch_jagged import ( - add, - all, - amax, - amin, - any, - argmax, - argmin, - ceil, - clamp, - eq, - exp, - floor, - floor_divide, - ge, - gt, - le, - log, - lt, - maximum, - mean, - minimum, - mul, - nan_to_num, - ne, - norm, - pow, - relu, - relu_, - remainder, - round, - sigmoid, - sqrt, - std, - sub, - sum, - tanh, - true_divide, - var, - where, -) - -def gaussian_render_jagged( - means: JaggedTensor, - quats: JaggedTensor, - scales: JaggedTensor, - opacities: JaggedTensor, - sh_coeffs: JaggedTensor, - viewmats: JaggedTensor, - Ks: JaggedTensor, - image_width: int, - image_height: int, - near_plane: float = 0.01, - far_plane: float = 1e10, - sh_degree_to_use: int = -1, - tile_size: int = 16, - radius_clip: float = 0.0, - eps2d: float = 0.3, - antialias: bool = False, - render_depth_channel: bool = False, - return_debug_info: bool = False, - render_depth_only: bool = False, - ortho: bool = False, - backgrounds: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]: ... -def evaluate_spherical_harmonics( - sh_degree: int, - num_cameras: int, - sh0: torch.Tensor, - radii: torch.Tensor, - shN: torch.Tensor | None = None, - view_directions: torch.Tensor | None = None, -) -> torch.Tensor: ... - -__all__ = [ - # Core classes - "GridBatch", - "Grid", - "JaggedTensor", - "GaussianSplat3d", - "ProjectedGaussianSplats", - "ConvolutionPlan", - "CameraModel", - "ProjectionMethod", - "RollingShutterType", - "ShOrderingMode", - "Grid", - # JaggedTensor operations - # Concatenation of jagged tensors or grid/grid batches - "jcat", - "gcat", - # Morton/Hilbert operations - "morton", - "hilbert", - # Specialized operations - "scaled_dot_product_attention", - "volume_render", - "gaussian_render_jagged", - "evaluate_spherical_harmonics", - # Torch-compatible functions (work with both Tensor and JaggedTensor) - "relu", - "relu_", - "sigmoid", - "tanh", - "exp", - "log", - "sqrt", - "floor", - "ceil", - "round", - "nan_to_num", - "clamp", - "add", - "sub", - "mul", - "true_divide", - "floor_divide", - "remainder", - "pow", - "maximum", - "minimum", - "eq", - "ne", - "lt", - "le", - "gt", - "ge", - "where", - "sum", - "mean", - "amax", - "amin", - "argmax", - "argmin", - "all", - "any", - "norm", - "var", - "std", - # Config - "config", - # Submodules - "viz", - "nn", - "utils", -] diff --git a/fvdb/_fvdb_cpp.pyi b/fvdb/_fvdb_cpp.pyi index 3a6ad2a87..3b906bcd5 100644 --- a/fvdb/_fvdb_cpp.pyi +++ b/fvdb/_fvdb_cpp.pyi @@ -96,399 +96,6 @@ def pred_gather_igemm_conv( stride: int, ) -> torch.Tensor: ... -class GaussianSplat3d: - log_scales: torch.Tensor - logit_opacities: torch.Tensor - means: torch.Tensor - quats: torch.Tensor - requires_grad: bool - sh0: torch.Tensor - shN: torch.Tensor - def __init__( - self, - means: torch.Tensor, - quats: torch.Tensor, - log_scales: torch.Tensor, - logit_opacities: torch.Tensor, - sh0: torch.Tensor, - shN: torch.Tensor, - accumulate_mean_2d_gradients: bool = ..., - accumulate_max_2d_radii: bool = ..., - detach: bool = ..., - ) -> None: ... - @property - def device(self) -> torch.device: ... - @property - def dtype(self) -> torch.dtype: ... - @staticmethod - def cat( - splats: "list[GaussianSplat3d]", - accumulate_mean_2d_gradients: bool = False, - accumulate_max_2d_radii: bool = False, - detach: bool = False, - ) -> "GaussianSplat3d": ... - def to(self, device: torch.device, dtype: torch.dtype) -> "GaussianSplat3d": ... - def detach(self) -> "GaussianSplat3d": ... - def detach_in_place(self) -> None: ... - def index_select(self, indices: torch.Tensor) -> "GaussianSplat3d": ... - def mask_select(self, mask: torch.Tensor) -> "GaussianSplat3d": ... - def slice_select(self, begin: int, end: int, step: int) -> "GaussianSplat3d": ... - def index_set(self, indices: torch.Tensor, value: "GaussianSplat3d") -> None: ... - def mask_set(self, mask: torch.Tensor, value: "GaussianSplat3d") -> None: ... - def slice_set(self, begin: int, end: int, step: int, value: "GaussianSplat3d") -> None: ... - @property - def sh_degree(self) -> int: ... - @property - def accumulate_mean_2d_gradients(self) -> bool: ... - @accumulate_mean_2d_gradients.setter - def accumulate_mean_2d_gradients(self, value: bool) -> None: ... - @property - def accumulate_max_2d_radii(self) -> bool: ... - @accumulate_max_2d_radii.setter - def accumulate_max_2d_radii(self, value: bool) -> None: ... - @staticmethod - def from_state_dict(state_dict: dict[str, torch.Tensor]) -> GaussianSplat3d: ... - def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: ... - def project_gaussians_for_depths( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - ) -> ProjectedGaussianSplats: ... - def project_gaussians_for_images( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - sh_degree_to_use: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - ) -> ProjectedGaussianSplats: ... - def project_gaussians_for_images_and_depths( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - sh_degree_to_use: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - ) -> ProjectedGaussianSplats: ... - def render_depths( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def render_depths_from_world( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def sparse_render_depths( - self, - pixels_to_render: JaggedTensor, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[JaggedTensor, JaggedTensor]: ... - def render_from_projected_gaussians( - self, - projected_gaussians: ProjectedGaussianSplats, - crop_width: int = ..., - crop_height: int = ..., - crop_origin_w: int = ..., - crop_origin_h: int = ..., - tile_size: int = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def render_images( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - sh_degree_to_use: int = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def render_images_from_world( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - sh_degree_to_use: int = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def sparse_render_images( - self, - pixels_to_render: JaggedTensor, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - sh_degree_to_use: int = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[JaggedTensor, JaggedTensor]: ... - def render_images_and_depths( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - sh_degree_to_use: int = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def render_images_and_depths_from_world( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - sh_degree_to_use: int = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def sparse_render_images_and_depths( - self, - pixels_to_render: JaggedTensor, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - sh_degree_to_use: int = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., - masks: Optional[torch.Tensor] = ..., - ) -> tuple[JaggedTensor, JaggedTensor]: ... - def render_num_contributing_gaussians( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def sparse_render_num_contributing_gaussians( - self, - pixels_to_render: JaggedTensor, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - ) -> tuple[JaggedTensor, JaggedTensor]: ... - def render_contributing_gaussian_ids( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - top_k_contributors: int = ..., - ) -> tuple[JaggedTensor, JaggedTensor]: ... - def sparse_render_contributing_gaussian_ids( - self, - pixels_to_render: JaggedTensor, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: "CameraModel" = ..., - projection_method: "ProjectionMethod" = ..., - distortion_coeffs: Optional[torch.Tensor] = ..., - tile_size: int = ..., - min_radius_2d: float = ..., - eps_2d: float = ..., - antialias: bool = ..., - top_k_contributors: int = ..., - ) -> tuple[JaggedTensor, JaggedTensor]: ... - def relocate_gaussians( - self, - log_scales: torch.Tensor, - logit_opacities: torch.Tensor, - ratios: torch.Tensor, - binomial_coeffs: torch.Tensor, - n_max: int, - min_opacity: float, - ) -> tuple[torch.Tensor, torch.Tensor]: ... - def add_noise_to_means(self, noise_scale: float, t: float = ..., k: float = ...) -> None: ... - def reset_accumulated_gradient_state(self) -> None: ... - def save_ply(self, filename: str, metadata: dict[str, str | int | float | torch.Tensor] | None) -> None: ... - @staticmethod - def from_ply( - filename: str, device: torch.device = ... - ) -> tuple[GaussianSplat3d, dict[str, str | int | float | torch.Tensor]]: ... - def set_state( - self, - means: torch.Tensor, - quats: torch.Tensor, - log_scales: torch.Tensor, - logit_opacities: torch.Tensor, - sh0: torch.Tensor, - shN: torch.Tensor, - ) -> None: ... - def state_dict(self) -> dict[str, torch.Tensor]: ... - @property - def accumulated_gradient_step_counts(self) -> torch.Tensor: ... - @property - def accumulated_max_2d_radii(self) -> torch.Tensor: ... - @property - def accumulated_mean_2d_gradient_norms(self) -> torch.Tensor: ... - @property - def num_channels(self) -> int: ... - @property - def num_gaussians(self) -> int: ... - @property - def num_sh_bases(self) -> int: ... - @property - def opacities(self) -> torch.Tensor: ... - @property - def scales(self) -> torch.Tensor: ... - class GridBatchData: MAX_GRIDS_PER_BATCH: ClassVar[int] = ... @@ -974,7 +581,10 @@ class JaggedTensor: def from_data_and_offsets(arg0: torch.Tensor, arg1: torch.Tensor) -> JaggedTensor: ... @staticmethod def from_data_indices_and_list_ids( - data: torch.Tensor, indices: torch.Tensor, list_ids: torch.Tensor, num_tensors: int + data: torch.Tensor, + indices: torch.Tensor, + list_ids: torch.Tensor, + num_tensors: int, ) -> JaggedTensor: ... @staticmethod def from_data_offsets_and_list_ids( @@ -1206,44 +816,10 @@ class JaggedTensor: def rshape(self) -> tuple[int, ...]: ... def __iter__(self) -> typing.Iterator[JaggedTensor]: ... -class ProjectedGaussianSplats: - def __init__(self, *args, **kwargs) -> None: ... - @property - def antialias(self) -> bool: ... - @property - def conics(self) -> torch.Tensor: ... - @property - def depths(self) -> torch.Tensor: ... - @property - def eps_2d(self) -> float: ... - @property - def far_plane(self) -> float: ... - @property - def image_height(self) -> int: ... - @property - def image_width(self) -> int: ... - @property - def means2d(self) -> torch.Tensor: ... - @property - def min_radius_2d(self) -> float: ... - @property - def near_plane(self) -> float: ... - @property - def opacities(self) -> torch.Tensor: ... - @property - def camera_model(self) -> CameraModel: ... - @property - def projection_method(self) -> ProjectionMethod: ... - @property - def radii(self) -> torch.Tensor: ... - @property - def render_quantities(self) -> torch.Tensor: ... - @property - def sh_degree_to_use(self) -> int: ... - @property - def tile_gaussian_ids(self) -> torch.Tensor: ... - @property - def tile_offsets(self) -> torch.Tensor: ... +# ProjectedGaussianSplats and SparseProjectedGaussianSplats are no longer +# exposed as C++ pybind11 types. All projection functions now return plain +# tuples of tensors; see fvdb.ProjectedGaussians, fvdb.GaussianTileIntersection, +# and fvdb.SparseGaussianTileIntersection for the Python types. class GaussianSplat3dView: @property @@ -1310,7 +886,15 @@ class Viewer: def remove_scene(self, scene_name: str) -> None: ... def remove_view(self, scene_name: str, name: str) -> None: ... def add_gaussian_splat_3d_view( - self, scene_name: str, name: str, gaussian_splat_3d: GaussianSplat3d + self, + scene_name: str, + name: str, + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + logit_opacities: torch.Tensor, + sh0: torch.Tensor, + shN: torch.Tensor, ) -> GaussianSplat3dView: ... def has_gaussian_splat_3d_view(self, name: str) -> bool: ... def get_gaussian_splat_3d_view(self, name: str) -> GaussianSplat3dView: ... @@ -1361,37 +945,6 @@ class config: pedantic_error_checking: ClassVar[bool] = ... def __init__(self, *args, **kwargs) -> None: ... -def gaussian_render_jagged( - means: JaggedTensor, - quats: JaggedTensor, - scales: JaggedTensor, - opacities: JaggedTensor, - sh_coeffs: JaggedTensor, - viewmats: JaggedTensor, - Ks: JaggedTensor, - image_width: int, - image_height: int, - near_plane: float = ..., - far_plane: float = ..., - sh_degree_to_use: int = ..., - tile_size: int = ..., - radius_clip: float = ..., - eps2d: float = ..., - antialias: bool = ..., - render_depth_channel: bool = ..., - return_debug_info: bool = ..., - render_depth_only: bool = ..., - ortho: bool = ..., - backgrounds: Optional[torch.Tensor] = ..., -) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]: ... -def evaluate_spherical_harmonics( - sh_degree: int, - num_cameras: int, - sh0: torch.Tensor, - radii: torch.Tensor, - shN: Optional[torch.Tensor] = ..., - view_directions: Optional[torch.Tensor] = ..., -) -> torch.Tensor: ... @overload def jcat(grid_batches: list[GridBatchData]) -> GridBatchData: ... @overload @@ -1544,3 +1097,377 @@ class ProjectionMethod(Enum): AUTO = ... ANALYTIC = ... UNSCENTED = ... + +# ---- Gaussian Splat Operations ---- + +def count_contributing_gaussians( + means2d: torch.Tensor, + conics: torch.Tensor, + opacities: torch.Tensor, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: ... +def count_contributing_gaussians_sparse( + means2d: torch.Tensor, + conics: torch.Tensor, + opacities: torch.Tensor, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + pixels_to_render: JaggedTensor, + active_tiles: torch.Tensor, + tile_pixel_mask: torch.Tensor, + tile_pixel_cumsum: torch.Tensor, + pixel_map: torch.Tensor, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, +) -> tuple[JaggedTensor, JaggedTensor]: ... +def identify_contributing_gaussians( + means2d: torch.Tensor, + conics: torch.Tensor, + opacities: torch.Tensor, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + num_depth_samples: int, + num_contributing_gaussians: Optional[torch.Tensor] = None, +) -> tuple[JaggedTensor, JaggedTensor]: ... +def identify_contributing_gaussians_sparse( + means2d: torch.Tensor, + conics: torch.Tensor, + opacities: torch.Tensor, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + pixels_to_render: JaggedTensor, + active_tiles: torch.Tensor, + tile_pixel_mask: torch.Tensor, + tile_pixel_cumsum: torch.Tensor, + pixel_map: torch.Tensor, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + num_depth_samples: int, + num_contributing_gaussians: Optional[JaggedTensor] = None, +) -> tuple[JaggedTensor, JaggedTensor]: ... +def relocate_gaussians( + log_scales: torch.Tensor, + logit_opacities: torch.Tensor, + ratios: torch.Tensor, + binomial_coeffs: torch.Tensor, + n_max: int, + min_opacity: float, +) -> tuple[torch.Tensor, torch.Tensor]: ... +def add_noise_to_gaussian_means( + means: torch.Tensor, + log_scales: torch.Tensor, + logit_opacities: torch.Tensor, + quats: torch.Tensor, + noise_scale: float, + t: float, + k: float, +) -> None: ... +def save_gaussians_ply( + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + logit_opacities: torch.Tensor, + sh0: torch.Tensor, + shN: torch.Tensor, + filename: str, + metadata: Optional[dict[str, str | int | float | torch.Tensor]], +) -> None: ... +def load_gaussians_ply( + filename: str, + device: torch.device = ..., +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + dict[str, str | int | float | torch.Tensor], +]: ... + +# ---- Raw forward/backward dispatch (for Python autograd) ---- + +def project_gaussians_analytic_fwd( + means: torch.Tensor, + quats: torch.Tensor, + scales: torch.Tensor, + world_to_cam_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + eps2d: float, + near: float, + far: float, + min_radius_2d: float, + calc_compensations: bool, + ortho: bool, +) -> tuple[torch.Tensor, ...]: ... +def project_gaussians_analytic_bwd( + means: torch.Tensor, + quats: torch.Tensor, + scales: torch.Tensor, + world_to_cam_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + compensations: Optional[torch.Tensor], + image_width: int, + image_height: int, + eps2d: float, + radii: torch.Tensor, + conics: torch.Tensor, + d_loss_d_means2d: torch.Tensor, + d_loss_d_depths: torch.Tensor, + d_loss_d_conics: torch.Tensor, + d_loss_d_compensations: Optional[torch.Tensor], + world_to_cam_matrices_requires_grad: bool, + ortho: bool, + out_normalized_d_loss_d_means2d_norm_accum: Optional[torch.Tensor] = ..., + out_normalized_max_radii_accum: Optional[torch.Tensor] = ..., + out_gradient_step_counts: Optional[torch.Tensor] = ..., +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... +def project_gaussians_ut_fwd( + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + world_to_cam_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + distortion_coeffs: torch.Tensor, + camera_model: CameraModel, + image_width: int, + image_height: int, + eps2d: float, + near: float, + far: float, + min_radius_2d: float, + calc_compensations: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... +def eval_gaussian_sh_fwd( + sh_degree_to_use: int, + num_cameras: int, + view_dirs: torch.Tensor, + sh0_coeffs: torch.Tensor, + sh_n_coeffs: torch.Tensor, + radii: torch.Tensor, +) -> torch.Tensor: ... +def eval_gaussian_sh_bwd( + sh_degree_to_use: int, + num_cameras: int, + num_gaussians: int, + view_dirs: torch.Tensor, + sh_n_coeffs: torch.Tensor, + d_loss_d_colors: torch.Tensor, + radii: torch.Tensor, + compute_d_loss_d_view_dirs: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +def rasterize_screen_space_gaussians_fwd( + means2d: torch.Tensor, + conics: torch.Tensor, + features: torch.Tensor, + opacities: torch.Tensor, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + backgrounds: Optional[torch.Tensor], + masks: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +def rasterize_screen_space_gaussians_bwd( + means2d: torch.Tensor, + conics: torch.Tensor, + features: torch.Tensor, + opacities: torch.Tensor, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + rendered_alphas: torch.Tensor, + last_ids: torch.Tensor, + d_loss_d_rendered_features: torch.Tensor, + d_loss_d_rendered_alphas: torch.Tensor, + abs_grad: bool, + num_shared_channels_override: int = ..., + backgrounds: Optional[torch.Tensor] = ..., + masks: Optional[torch.Tensor] = ..., +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... +def rasterize_screen_space_gaussians_sparse_fwd( + pixels_to_render: JaggedTensor, + means2d: torch.Tensor, + conics: torch.Tensor, + features: torch.Tensor, + opacities: torch.Tensor, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + active_tiles: torch.Tensor, + tile_pixel_mask: torch.Tensor, + tile_pixel_cumsum: torch.Tensor, + pixel_map: torch.Tensor, + backgrounds: Optional[torch.Tensor], + masks: Optional[torch.Tensor], +) -> tuple[JaggedTensor, JaggedTensor, JaggedTensor]: ... +def rasterize_screen_space_gaussians_sparse_bwd( + pixels_to_render: JaggedTensor, + means2d: torch.Tensor, + conics: torch.Tensor, + features: torch.Tensor, + opacities: torch.Tensor, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + rendered_alphas: JaggedTensor, + last_ids: JaggedTensor, + d_loss_d_rendered_features: JaggedTensor, + d_loss_d_rendered_alphas: JaggedTensor, + active_tiles: torch.Tensor, + tile_pixel_mask: torch.Tensor, + tile_pixel_cumsum: torch.Tensor, + pixel_map: torch.Tensor, + abs_grad: bool, + num_shared_channels_override: int = ..., + backgrounds: Optional[torch.Tensor] = ..., + masks: Optional[torch.Tensor] = ..., +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... +def rasterize_world_space_gaussians_fwd( + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + features: torch.Tensor, + opacities: torch.Tensor, + world_to_cam_matrices_start: torch.Tensor, + world_to_cam_matrices_end: torch.Tensor, + projection_matrices: torch.Tensor, + distortion_coeffs: torch.Tensor, + rolling_shutter_type: RollingShutterType, + camera_model: CameraModel, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + backgrounds: Optional[torch.Tensor], + masks: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ... +def rasterize_world_space_gaussians_bwd( + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + features: torch.Tensor, + opacities: torch.Tensor, + world_to_cam_matrices_start: torch.Tensor, + world_to_cam_matrices_end: torch.Tensor, + projection_matrices: torch.Tensor, + distortion_coeffs: torch.Tensor, + rolling_shutter_type: RollingShutterType, + camera_model: CameraModel, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + rendered_alphas: torch.Tensor, + last_ids: torch.Tensor, + d_loss_d_rendered_features: torch.Tensor, + d_loss_d_rendered_alphas: torch.Tensor, + backgrounds: Optional[torch.Tensor], + masks: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... +def project_gaussians_analytic_jagged_fwd( + g_sizes: torch.Tensor, + means: torch.Tensor, + quats: torch.Tensor, + scales: torch.Tensor, + c_sizes: torch.Tensor, + world_to_cam_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + eps2d: float, + near: float, + far: float, + min_radius_2d: float, + ortho: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... +def project_gaussians_analytic_jagged_bwd( + g_sizes: torch.Tensor, + means: torch.Tensor, + quats: torch.Tensor, + scales: torch.Tensor, + c_sizes: torch.Tensor, + world_to_cam_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + eps2d: float, + radii: torch.Tensor, + conics: torch.Tensor, + d_loss_d_means2d: torch.Tensor, + d_loss_d_depths: torch.Tensor, + d_loss_d_conics: torch.Tensor, + world_to_cam_matrices_requires_grad: bool, + ortho: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... + +# ---- Tile intersection (non-differentiable) ---- + +def intersect_gaussian_tiles( + means2d: torch.Tensor, + radii: torch.Tensor, + depths: torch.Tensor, + num_cameras: int, + tile_size: int, + num_tiles_h: int, + num_tiles_w: int, + camera_ids: Optional[torch.Tensor] = ..., +) -> tuple[torch.Tensor, torch.Tensor]: ... +def intersect_gaussian_tiles_sparse( + means2d: torch.Tensor, + radii: torch.Tensor, + depths: torch.Tensor, + tile_mask: torch.Tensor, + active_tiles: torch.Tensor, + num_cameras: int, + tile_size: int, + num_tiles_h: int, + num_tiles_w: int, + camera_ids: Optional[torch.Tensor] = ..., +) -> tuple[torch.Tensor, torch.Tensor]: ... +def build_sparse_gaussian_tile_layout( + tile_side_length: int, + num_tiles_w: int, + num_tiles_h: int, + pixels_to_render: JaggedTensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... diff --git a/fvdb/enums.py b/fvdb/enums.py index 10c43850b..685818e60 100644 --- a/fvdb/enums.py +++ b/fvdb/enums.py @@ -122,3 +122,30 @@ class ProjectionMethod(IntEnum): """ Use the unscented projection path. """ + + +class GaussianRenderMode(IntEnum): + """ + Rendering mode for Gaussian splatting rasterization. + + Controls which quantities are produced by the rasterization stage: + + - ``FEATURES``: Rasterize only the per-Gaussian features (e.g. SH-evaluated colours). + - ``DEPTH``: Rasterize only depth. + - ``FEATURES_AND_DEPTH``: Rasterize features concatenated with depth. + """ + + FEATURES = 0 + """ + Rasterize only per-Gaussian features (e.g. evaluated spherical-harmonic colours). + """ + + DEPTH = 1 + """ + Rasterize only per-Gaussian depth. + """ + + FEATURES_AND_DEPTH = 2 + """ + Rasterize features concatenated with a depth channel. + """ diff --git a/fvdb/functional/__init__.py b/fvdb/functional/__init__.py index f9a93811c..89bfa99b8 100644 --- a/fvdb/functional/__init__.py +++ b/fvdb/functional/__init__.py @@ -171,8 +171,56 @@ # I/O from ._io import load_nanovdb, load_nanovdb_single, save_nanovdb, save_nanovdb_single +from ._io import load_gaussian_ply, save_gaussian_ply + +# Gaussian splatting types +from ._gaussian_projection import ProjectedGaussians +from ._gaussian_tile_intersection import GaussianTileIntersection, SparseGaussianTileIntersection + +# Gaussian splatting pipeline +from ._gaussian_projection import project_gaussians +from ._gaussian_spherical_harmonics import evaluate_gaussian_sh +from ._gaussian_tile_intersection import intersect_gaussian_tiles, intersect_gaussian_tiles_sparse +from ._gaussian_rasterization import ( + compute_gaussian_opacities, + rasterize_screen_space_gaussians, + rasterize_world_space_gaussians, +) +from ._gaussian_rasterization_sparse import rasterize_screen_space_gaussians_sparse + +# Gaussian splatting analysis +from ._gaussian_analysis import ( + count_contributing_gaussians, + count_contributing_gaussians_sparse, + identify_contributing_gaussians, + identify_contributing_gaussians_sparse, +) + +# Gaussian MCMC +from ._gaussian_mcmc import relocate_gaussians, add_noise_to_gaussian_means + +# Metrics +from ._metrics import psnr, ssim __all__ = [ + # Gaussian splatting types + "ProjectedGaussians", + "GaussianTileIntersection", + "SparseGaussianTileIntersection", + # Gaussian splatting pipeline + "project_gaussians", + "evaluate_gaussian_sh", + "intersect_gaussian_tiles", + "intersect_gaussian_tiles_sparse", + "rasterize_screen_space_gaussians", + "rasterize_world_space_gaussians", + "rasterize_screen_space_gaussians_sparse", + "compute_gaussian_opacities", + # Gaussian splatting analysis + "count_contributing_gaussians", + "identify_contributing_gaussians", + "count_contributing_gaussians_sparse", + "identify_contributing_gaussians_sparse", # Interpolation (batch) "sample_trilinear_batch", "sample_trilinear_with_grad_batch", @@ -307,4 +355,12 @@ "load_nanovdb_single", "save_nanovdb", "save_nanovdb_single", + "load_gaussian_ply", + "save_gaussian_ply", + # Gaussian MCMC + "relocate_gaussians", + "add_noise_to_gaussian_means", + # Metrics + "psnr", + "ssim", ] diff --git a/fvdb/functional/_gaussian_analysis.py b/fvdb/functional/_gaussian_analysis.py new file mode 100644 index 000000000..26e2d1991 --- /dev/null +++ b/fvdb/functional/_gaussian_analysis.py @@ -0,0 +1,330 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Functional API for analysing contributing Gaussians per pixel.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from .. import _fvdb_cpp as _C +from .._fvdb_cpp import JaggedTensor as JaggedTensorCpp +from ..jagged_tensor import JaggedTensor + +if TYPE_CHECKING: + from ._gaussian_projection import ProjectedGaussians + from ._gaussian_tile_intersection import ( + GaussianTileIntersection, + SparseGaussianTileIntersection, + ) + + +def _crop_jagged_pixel_pair( + ids: JaggedTensor, + weights: JaggedTensor, + image_width: int, + image_height: int, + ox: int, + oy: int, + cw: int, + ch: int, +) -> tuple[JaggedTensor, JaggedTensor]: + """Crop a pair of pixel-indexed ``JaggedTensor`` s (ldim=2) to a rectangular sub-region. + + The inputs must have structure: ``C`` cameras, each with ``H * W`` pixel + rows in row-major order, each pixel with a variable number of samples. + """ + C = len(ids) + device = ids.jdata.device + pixels_per_camera = cw * ch + + # Flat indices of crop pixels within one camera: [ch * cw] + ys = torch.arange(oy, oy + ch, device=device) + xs = torch.arange(ox, ox + cw, device=device) + cam_pixel_offsets = (ys[:, None] * image_width + xs[None, :]).reshape(-1) + + # Extend to all cameras: [C * ch * cw] + cam_starts = torch.arange(C, device=device) * (image_height * image_width) + all_pixel_indices = (cam_starts[:, None] + cam_pixel_offsets[None, :]).reshape(-1) + + # Both JaggedTensors share the same jagged structure so we reuse offsets. + offsets = ids.joffsets # [C*H*W + 1] + starts = offsets[all_pixel_indices] + ends = offsets[all_pixel_indices + 1] + lengths = ends - starts + + new_offsets = torch.zeros(len(all_pixel_indices) + 1, dtype=torch.int64, device=device) + torch.cumsum(lengths, dim=0, out=new_offsets[1:]) + total = int(new_offsets[-1].item()) + + if total > 0: + pixel_for_sample = torch.repeat_interleave( + torch.arange(len(all_pixel_indices), device=device, dtype=torch.long), + lengths, + ) + within_pixel = torch.arange(total, device=device, dtype=torch.long) - new_offsets[pixel_for_sample] + gather_idx = (starts[pixel_for_sample] + within_pixel).long() + new_ids_data = ids.jdata[gather_idx] + new_weights_data = weights.jdata[gather_idx] + else: + new_ids_data = ids.jdata[:0] + new_weights_data = weights.jdata[:0] + + cam_ids = torch.arange(C, device=device, dtype=torch.int32).repeat_interleave(pixels_per_camera) + inner_ids = torch.arange(pixels_per_camera, device=device, dtype=torch.int32).repeat(C) + list_ids = torch.stack([cam_ids, inner_ids], dim=1) + + return ( + JaggedTensor.from_data_offsets_and_list_ids(new_ids_data, new_offsets, list_ids), + JaggedTensor.from_data_offsets_and_list_ids(new_weights_data, new_offsets.clone(), list_ids.clone()), + ) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def count_contributing_gaussians( + projected: ProjectedGaussians, + logit_opacities: torch.Tensor, + tiles: GaussianTileIntersection, + crop: tuple[int, int, int, int] | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Count the number of contributing Gaussians per pixel (dense). + + Non-differentiable analysis function. + + Args: + projected: :class:`ProjectedGaussians` from Stage 1. + logit_opacities: ``[N]`` Pre-sigmoid opacities. + tiles: :class:`GaussianTileIntersection` from Stage 3. + crop: Optional ``(origin_x, origin_y, width, height)`` tuple defining a + sub-region to analyse. When ``None`` (the default), the full image + is used. The crop region is clamped to the projected image bounds. + + Returns: + Tuple of (num_contributing ``[C, H, W]``, weights ``[C, H, W]``) where + H and W are the crop dimensions (or the full image when ``crop`` is + ``None``). + """ + from ._gaussian_rasterization import _compute_opacities, _validate_crop + + opacities = _compute_opacities(logit_opacities, projected) + num, weights = _C.count_contributing_gaussians( + projected.means2d, + projected.conics, + opacities, + tiles.tile_offsets, + tiles.tile_gaussian_ids, + tiles.image_width, + tiles.image_height, + 0, + 0, + tiles.tile_size, + ) + + if crop is not None: + ox, oy, w, h = _validate_crop(crop, tiles.image_width, tiles.image_height) + num = num[:, oy : oy + h, ox : ox + w] + weights = weights[:, oy : oy + h, ox : ox + w] + + return num, weights + + +def identify_contributing_gaussians( + projected: ProjectedGaussians, + logit_opacities: torch.Tensor, + tiles: GaussianTileIntersection, + num_contributing: torch.Tensor | None = None, + top_k_contributors: int = 0, + crop: tuple[int, int, int, int] | None = None, +) -> tuple[JaggedTensor, JaggedTensor]: + """Get the IDs of contributing Gaussians per pixel (dense). + + Non-differentiable analysis function. + + Args: + projected: :class:`ProjectedGaussians` from Stage 1. + logit_opacities: ``[N]`` Pre-sigmoid opacities. + tiles: :class:`GaussianTileIntersection` from Stage 3. + num_contributing: Optional pre-computed count tensor (from + :func:`count_contributing_gaussians`). + top_k_contributors: If > 0, return only the top-k most opaque + contributors per pixel. + crop: Optional ``(origin_x, origin_y, width, height)`` tuple defining a + sub-region to analyse. When ``None`` (the default), the full image + is used. The crop region is clamped to the projected image bounds. + + Returns: + Tuple of (gaussian_ids, weights) as :class:`~fvdb.JaggedTensor`. + Each JaggedTensor has structure ``C`` cameras, each with ``H * W`` + pixel rows (crop dimensions when ``crop`` is provided), each pixel + with a variable number of contributing Gaussian samples. + """ + from ._gaussian_rasterization import _compute_opacities, _validate_crop + + opacities = _compute_opacities(logit_opacities, projected) + + if top_k_contributors <= 0 and num_contributing is None: + num_contributing, _ = count_contributing_gaussians(projected, logit_opacities, tiles) + + ids_impl, weights_impl = _C.identify_contributing_gaussians( + projected.means2d, + projected.conics, + opacities, + tiles.tile_offsets, + tiles.tile_gaussian_ids, + tiles.image_width, + tiles.image_height, + 0, + 0, + tiles.tile_size, + top_k_contributors, + num_contributing, + ) + ids, weights = JaggedTensor(impl=ids_impl), JaggedTensor(impl=weights_impl) + + if crop is not None: + ox, oy, w, h = _validate_crop(crop, tiles.image_width, tiles.image_height) + ids, weights = _crop_jagged_pixel_pair(ids, weights, tiles.image_width, tiles.image_height, ox, oy, w, h) + + return ids, weights + + +def count_contributing_gaussians_sparse( + projected: ProjectedGaussians, + logit_opacities: torch.Tensor, + sparse_tiles: SparseGaussianTileIntersection, +) -> tuple[JaggedTensor, JaggedTensor]: + """Count the number of contributing Gaussians per pixel (sparse). + + Non-differentiable analysis function. Gets ``pixels_to_render`` from + ``sparse_tiles.pixels_to_render``. + + Args: + projected: :class:`ProjectedGaussians` from Stage 1. + logit_opacities: ``[N]`` Pre-sigmoid opacities. + sparse_tiles: :class:`SparseGaussianTileIntersection` from Stage 3. + + Returns: + Tuple of (num_contributing, weights) as :class:`~fvdb.JaggedTensor`. + """ + from ._gaussian_rasterization import _compute_opacities + + opacities = _compute_opacities(logit_opacities, projected) + + render_pixels = sparse_tiles.unique_pixels if sparse_tiles.has_duplicates else sparse_tiles.pixels_to_render + jt0_impl, jt1_impl = _C.count_contributing_gaussians_sparse( + projected.means2d, + projected.conics, + opacities, + sparse_tiles.tile_offsets, + sparse_tiles.tile_gaussian_ids, + render_pixels._impl, + sparse_tiles.active_tiles, + sparse_tiles.tile_pixel_mask, + sparse_tiles.tile_pixel_cumsum, + sparse_tiles.pixel_map, + sparse_tiles.image_width, + sparse_tiles.image_height, + 0, + 0, + sparse_tiles.tile_size, + ) + + if sparse_tiles.has_duplicates: + inv = sparse_tiles.inverse_indices + pixels_impl = sparse_tiles.pixels_to_render._impl + jt0_impl = pixels_impl.jagged_like(jt0_impl.jdata.index_select(0, inv)) + jt1_impl = pixels_impl.jagged_like(jt1_impl.jdata.index_select(0, inv)) + return JaggedTensor(impl=jt0_impl), JaggedTensor(impl=jt1_impl) + + +def identify_contributing_gaussians_sparse( + projected: ProjectedGaussians, + logit_opacities: torch.Tensor, + sparse_tiles: SparseGaussianTileIntersection, + num_contributing: JaggedTensor | None = None, +) -> tuple[JaggedTensor, JaggedTensor]: + """Get the IDs of contributing Gaussians per pixel (sparse). + + Non-differentiable analysis function. Gets ``pixels_to_render`` from + ``sparse_tiles.pixels_to_render``. + + Args: + projected: :class:`ProjectedGaussians` from Stage 1. + logit_opacities: ``[N]`` Pre-sigmoid opacities. + sparse_tiles: :class:`SparseGaussianTileIntersection` from Stage 3. + num_contributing: Optional pre-computed count :class:`~fvdb.JaggedTensor`. + + Returns: + Tuple of (gaussian_ids, weights) as :class:`~fvdb.JaggedTensor`. + """ + from ._gaussian_rasterization import _compute_opacities + + opacities = _compute_opacities(logit_opacities, projected) + + render_pixels = sparse_tiles.unique_pixels if sparse_tiles.has_duplicates else sparse_tiles.pixels_to_render + + # Resolve num_contributing into unique-pixel space for the kernel + jt_arg: JaggedTensorCpp + if num_contributing is None: + jt_arg, _ = _C.count_contributing_gaussians_sparse( + projected.means2d, + projected.conics, + opacities, + sparse_tiles.tile_offsets, + sparse_tiles.tile_gaussian_ids, + render_pixels._impl, + sparse_tiles.active_tiles, + sparse_tiles.tile_pixel_mask, + sparse_tiles.tile_pixel_cumsum, + sparse_tiles.pixel_map, + sparse_tiles.image_width, + sparse_tiles.image_height, + 0, + 0, + sparse_tiles.tile_size, + ) + else: + nc_impl: JaggedTensorCpp = num_contributing._impl if isinstance(num_contributing, JaggedTensor) else num_contributing # type: ignore[assignment] + if sparse_tiles.has_duplicates: + inv = sparse_tiles.inverse_indices + rep_idx = torch.empty(render_pixels._impl.rsize(0), dtype=torch.long, device=inv.device) + rep_idx.scatter_( + 0, inv, torch.arange(sparse_tiles.pixels_to_render._impl.rsize(0), dtype=torch.long, device=inv.device) + ) + unique_data = nc_impl.jdata.index_select(0, rep_idx) + jt_arg = render_pixels._impl.jagged_like(unique_data) + else: + jt_arg = nc_impl + + ids_impl, weights_impl = _C.identify_contributing_gaussians_sparse( + projected.means2d, + projected.conics, + opacities, + sparse_tiles.tile_offsets, + sparse_tiles.tile_gaussian_ids, + render_pixels._impl, + sparse_tiles.active_tiles, + sparse_tiles.tile_pixel_mask, + sparse_tiles.tile_pixel_cumsum, + sparse_tiles.pixel_map, + sparse_tiles.image_width, + sparse_tiles.image_height, + 0, + 0, + sparse_tiles.tile_size, + -1, + jt_arg, + ) + + if sparse_tiles.has_duplicates: + inv = sparse_tiles.inverse_indices + pixels_impl = sparse_tiles.pixels_to_render._impl + ids_impl = pixels_impl.jagged_like(ids_impl.jdata.index_select(0, inv)) + weights_impl = pixels_impl.jagged_like(weights_impl.jdata.index_select(0, inv)) + return JaggedTensor(impl=ids_impl), JaggedTensor(impl=weights_impl) diff --git a/fvdb/functional/_gaussian_mcmc.py b/fvdb/functional/_gaussian_mcmc.py new file mode 100644 index 000000000..98b0c7c6f --- /dev/null +++ b/fvdb/functional/_gaussian_mcmc.py @@ -0,0 +1,76 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Functional API for MCMC-based Gaussian operations. + +``relocate_gaussians`` and ``add_noise_to_gaussian_means`` dispatch directly +to C++ free functions via pybind. +""" + +from __future__ import annotations + +import torch + +from .. import _fvdb_cpp as _C + + +def relocate_gaussians( + log_scales: torch.Tensor, + logit_opacities: torch.Tensor, + ratios: torch.Tensor, + binomial_coeffs: torch.Tensor, + n_max: int, + min_opacity: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """Relocate dead Gaussians to high-gradient regions (MCMC strategy). + + Args: + log_scales: ``[N, 3]`` log-scale parameters. + logit_opacities: ``[N]`` pre-sigmoid opacity logits. + ratios: ``[N]`` relocation ratios. + binomial_coeffs: ``[N]`` binomial sampling coefficients. + n_max: Maximum number of relocation candidates. + min_opacity: Minimum opacity threshold for liveness. + + Returns: + Tuple of (logit_opacities_new ``[N]``, log_scales_new ``[N, 3]``). + """ + return _C.relocate_gaussians( + log_scales, + logit_opacities, + ratios, + binomial_coeffs, + n_max, + min_opacity, + ) + + +def add_noise_to_gaussian_means( + means: torch.Tensor, + log_scales: torch.Tensor, + logit_opacities: torch.Tensor, + quats: torch.Tensor, + noise_scale: float, + t: float = 0.005, + k: float = 100.0, +) -> None: + """Add scale-dependent noise to Gaussian positions in-place. + + Args: + means: ``[N, 3]`` Gaussian centres (mutated in-place). + log_scales: ``[N, 3]`` log-scale parameters. + logit_opacities: ``[N]`` pre-sigmoid opacity logits. + quats: ``[N, 4]`` quaternion rotations. + noise_scale: Noise scale factor. + t: Noise scaling parameter. Defaults to ``0.005``. + k: Noise scaling parameter. Defaults to ``100.0``. + """ + _C.add_noise_to_gaussian_means( + means, + log_scales, + logit_opacities, + quats, + noise_scale, + t, + k, + ) diff --git a/fvdb/functional/_gaussian_projection.py b/fvdb/functional/_gaussian_projection.py new file mode 100644 index 000000000..65ff8dcfa --- /dev/null +++ b/fvdb/functional/_gaussian_projection.py @@ -0,0 +1,490 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Functional API for Gaussian projection (analytic and camera-dispatched).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, cast + +import torch + +from .. import _fvdb_cpp as _C +from ..enums import CameraModel, ProjectionMethod + +# --------------------------------------------------------------------------- +# Autograd functions (raw dispatch wrappers) +# --------------------------------------------------------------------------- + + +class _ProjectGaussiansFn(torch.autograd.Function): + """Python autograd wrapper for the Gaussian projection forward/backward dispatch.""" + + @staticmethod + def forward( + ctx, + means: torch.Tensor, # [N, 3] + quats: torch.Tensor, # [N, 4] + log_scales: torch.Tensor, # [N, 3] + world_to_cam: torch.Tensor, # [C, 4, 4] + projection_matrices: torch.Tensor, # [C, 3, 3] + image_width: int, + image_height: int, + eps2d: float, + near: float, + far: float, + min_radius_2d: float, + calc_compensations: bool, + ortho: bool, + accum_grad_norms: torch.Tensor | None = None, + accum_step_counts: torch.Tensor | None = None, + accum_max_radii: torch.Tensor | None = None, + ): + result = _C.project_gaussians_analytic_fwd( + means, + quats, + log_scales, + world_to_cam, + projection_matrices, + image_width, + image_height, + eps2d, + near, + far, + min_radius_2d, + calc_compensations, + ortho, + ) + radii: torch.Tensor = result[0] + means2d: torch.Tensor = result[1] + depths: torch.Tensor = result[2] + conics: torch.Tensor = result[3] + compensations: torch.Tensor | None = result[4] if calc_compensations else None + + to_save = [ + means, + quats, + log_scales, + world_to_cam, + projection_matrices, + radii, + conics, + ] + if compensations is not None: + to_save.append(compensations) + ctx.save_for_backward(*to_save) + + ctx.image_width = image_width + ctx.image_height = image_height + ctx.eps2d = eps2d + ctx.calc_compensations = calc_compensations + ctx.ortho = ortho + ctx.accum_grad_norms = accum_grad_norms + ctx.accum_step_counts = accum_step_counts + ctx.accum_max_radii = accum_max_radii + + if compensations is not None: + return radii, means2d, depths, conics, compensations + return radii, means2d, depths, conics + + @staticmethod + def backward(ctx: Any, *grad_outputs: torch.Tensor | None) -> tuple[torch.Tensor | None, ...]: + grad_radii = grad_outputs[0] + grad_means2d = grad_outputs[1] + grad_depths = grad_outputs[2] + grad_conics = grad_outputs[3] + maybe_grad_comp = grad_outputs[4:] + # Make gradients contiguous (required by the CUDA backward kernel) + if grad_radii is not None: + grad_radii = grad_radii.contiguous() + if grad_means2d is not None: + grad_means2d = grad_means2d.contiguous() + if grad_depths is not None: + grad_depths = grad_depths.contiguous() + if grad_conics is not None: + grad_conics = grad_conics.contiguous() + + grad_compensations: torch.Tensor | None = None + if ctx.calc_compensations and maybe_grad_comp: + gc = maybe_grad_comp[0] + if gc is not None: + grad_compensations = gc.contiguous() + + saved = ctx.saved_tensors + means = saved[0] + quats = saved[1] + log_scales = saved[2] + world_to_cam = saved[3] + projection_matrices = saved[4] + radii = saved[5] + conics = saved[6] + compensations = saved[7] if ctx.calc_compensations else None + + # --- Direct in-place accumulation --- + # The kernel uses gpuAtomicAdd/atomicMax which are inherently accumulative, + # so we pass the persistent accumulators directly — no temporary tensors needed. + assert grad_means2d is not None + assert grad_depths is not None + assert grad_conics is not None + d_means, _, d_quats, d_scales, d_w2c = _C.project_gaussians_analytic_bwd( + means, + quats, + log_scales, + world_to_cam, + projection_matrices, + compensations, + ctx.image_width, + ctx.image_height, + ctx.eps2d, + radii, + conics, + grad_means2d, + grad_depths, + grad_conics, + grad_compensations, + ctx.needs_input_grad[3], # world_to_cam requires_grad + ctx.ortho, + ctx.accum_grad_norms, + ctx.accum_max_radii, + ctx.accum_step_counts, + ) + + # Return None for all non-differentiable inputs + # Order: means, quats, log_scales, world_to_cam, projection_matrices, + # image_width, image_height, eps2d, near_plane, far_plane, min_radius_2d, + # calc_compensations, ortho, accum_grad_norms, accum_step_counts, accum_max_radii + return ( + d_means, + d_quats, + d_scales, + d_w2c, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class _ProjectGaussiansJaggedFn(torch.autograd.Function): + """Python autograd wrapper for the jagged Gaussian projection dispatch.""" + + @staticmethod + def forward( + ctx, + g_sizes: torch.Tensor, + means: torch.Tensor, + quats: torch.Tensor, + scales: torch.Tensor, + c_sizes: torch.Tensor, + world_to_cam: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + eps2d: float, + near: float, + far: float, + min_radius_2d: float, + ortho: bool, + ): + result = _C.project_gaussians_analytic_jagged_fwd( + g_sizes, + means, + quats, + scales, + c_sizes, + world_to_cam, + projection_matrices, + image_width, + image_height, + eps2d, + near, + far, + min_radius_2d, + ortho, + ) + radii, means2d, depths, conics = result[0], result[1], result[2], result[3] + + ctx.save_for_backward( + g_sizes, + means, + quats, + scales, + c_sizes, + world_to_cam, + projection_matrices, + radii, + conics, + ) + ctx.image_width = image_width + ctx.image_height = image_height + ctx.eps2d = eps2d + ctx.ortho = ortho + + return radii, means2d, depths, conics + + @staticmethod + def backward(ctx: Any, *grad_outputs: torch.Tensor | None) -> tuple[torch.Tensor | None, ...]: + grad_radii = grad_outputs[0] + grad_means2d = grad_outputs[1] + grad_depths = grad_outputs[2] + grad_conics = grad_outputs[3] + if grad_radii is not None: + grad_radii = grad_radii.contiguous() + if grad_means2d is not None: + grad_means2d = grad_means2d.contiguous() + if grad_depths is not None: + grad_depths = grad_depths.contiguous() + if grad_conics is not None: + grad_conics = grad_conics.contiguous() + + g_sizes, means, quats, scales, c_sizes = ctx.saved_tensors[:5] + world_to_cam, projection_matrices, radii, conics = ctx.saved_tensors[5:] + + assert grad_means2d is not None + assert grad_depths is not None + assert grad_conics is not None + d_means, _, d_quats, d_scales, d_w2c = _C.project_gaussians_analytic_jagged_bwd( + g_sizes, + means, + quats, + scales, + c_sizes, + world_to_cam, + projection_matrices, + ctx.image_width, + ctx.image_height, + ctx.eps2d, + radii, + conics, + grad_means2d, + grad_depths, + grad_conics, + ctx.needs_input_grad[5], # world_to_cam requires_grad + ctx.ortho, + ) + + # g_sizes, means, quats, scales, c_sizes, world_to_cam, projection_matrices, + # image_width, image_height, eps2d, near_plane, far_plane, min_radius_2d, ortho + return ( + None, + d_means, + d_quats, + d_scales, + None, + d_w2c, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def _resolve_projection_method( + camera_model: CameraModel, + projection_method: ProjectionMethod, +) -> ProjectionMethod: + """Resolve AUTO -> ANALYTIC or UNSCENTED based on camera model. + + Accepts both the Python ``CameraModel`` enum and the C++ pybind enum + (``_fvdb_cpp.CameraModel``), normalising via ``int()`` to avoid + cross-type comparison pitfalls. + """ + pm_int = int(projection_method) + if pm_int != int(ProjectionMethod.AUTO): + return ProjectionMethod(pm_int) + cm_int = int(camera_model) + if cm_int in (int(CameraModel.PINHOLE), int(CameraModel.ORTHOGRAPHIC)): + return ProjectionMethod.ANALYTIC + return ProjectionMethod.UNSCENTED + + +@dataclass(frozen=True) +class ProjectedGaussians: + """Result of geometric projection of 3D Gaussians onto 2D image planes. + + This is the output of Stage 1 (``project_gaussians``) and serves as input + to all subsequent pipeline stages. Contains only the raw geometric + projection outputs -- no opacity computation, SH evaluation, or tile + intersection data. + + Attributes: + radii: ``[C, N]`` int32 projected radii (<=0 means culled). + means2d: ``[C, N, 2]`` Projected 2D means. + depths: ``[C, N]`` Depths along camera z-axis. + conics: ``[C, N, 3]`` Upper-triangle of inverse 2D covariance. + compensations: ``[C, N]`` Antialiasing compensation factors, or ``None`` + when antialiasing is disabled. + image_width: Image width in pixels (carried forward to tile intersection). + image_height: Image height in pixels (carried forward to tile intersection). + camera_model: The camera model used for projection. + projection_method: The resolved projection method (never ``AUTO``). + """ + + radii: torch.Tensor + means2d: torch.Tensor + depths: torch.Tensor + conics: torch.Tensor + compensations: torch.Tensor | None + image_width: int + image_height: int + camera_model: CameraModel + projection_method: ProjectionMethod + + +def project_gaussians( + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + world_to_camera_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + eps_2d: float = 0.3, + near: float = 0.01, + far: float = 1e10, + radius_clip: float = 0.0, + antialias: bool = False, + camera_model: CameraModel = CameraModel.PINHOLE, + projection_method: ProjectionMethod = ProjectionMethod.AUTO, + distortion_coeffs: torch.Tensor | None = None, + accum_grad_norms: torch.Tensor | None = None, + accum_step_counts: torch.Tensor | None = None, + accum_max_radii: torch.Tensor | None = None, +) -> ProjectedGaussians: + """Geometric projection of 3D Gaussians onto 2D image planes (Stage 1). + + Dispatches between analytic projection (differentiable) and unscented + transform projection (forward-only) based on ``projection_method``. + + Accumulator tensors (``accum_grad_norms``, ``accum_step_counts``, + ``accum_max_radii``) are only used with analytic projection; they are + ignored for UT projection. + + Args: + means: ``[N, 3]`` Gaussian means. + quats: ``[N, 4]`` Quaternion rotations. + log_scales: ``[N, 3]`` Log scale factors. + world_to_camera_matrices: ``[C, 4, 4]`` World-to-camera transforms. + projection_matrices: ``[C, 3, 3]`` Projection/intrinsic matrices. + image_width: Output image width in pixels. + image_height: Output image height in pixels. + eps_2d: Epsilon for 2D projection numerical stability. + near: Near clipping plane distance. + far: Far clipping plane distance. + radius_clip: Minimum projected radius for culling. + antialias: Whether to compute antialiasing compensations. + camera_model: Camera distortion model. + projection_method: ``AUTO`` resolves to ``ANALYTIC`` for pinhole/ortho, + ``UNSCENTED`` for distortion-based camera models. + distortion_coeffs: ``[C, 12]`` Optional OpenCV distortion coefficients + (required for UT projection with distortion-based cameras). + accum_grad_norms: ``[N]`` Persistent accumulator for gradient norms + (analytic only; mutated in-place during backward). + accum_step_counts: ``[N]`` Persistent accumulator for gradient step + counts (analytic only; mutated in-place during backward). + accum_max_radii: ``[N]`` Persistent accumulator for max projected radii + (analytic only; mutated in-place during backward). + + Returns: + A :class:`ProjectedGaussians` containing radii, means2d, depths, conics, + compensations, image_width, and image_height. + """ + if not projection_matrices.is_contiguous(): + raise RuntimeError("projectionMatrices must be contiguous") + + _OPENCV_MODELS = { + int(CameraModel.OPENCV_RADTAN_5), + int(CameraModel.OPENCV_RATIONAL_8), + int(CameraModel.OPENCV_RADTAN_THIN_PRISM_9), + int(CameraModel.OPENCV_THIN_PRISM_12), + } + cm_int = int(camera_model) + if cm_int in _OPENCV_MODELS: + if distortion_coeffs is None: + raise RuntimeError("distortionCoeffs must be provided for OPENCV_* camera models") + C = world_to_camera_matrices.size(0) + if distortion_coeffs.dim() != 2 or distortion_coeffs.size(0) != C or distortion_coeffs.size(1) != 12: + raise RuntimeError(f"distortionCoeffs must have shape [{C}, 12], got {list(distortion_coeffs.shape)}") + + resolved = _resolve_projection_method(camera_model, projection_method) + + if cm_int in _OPENCV_MODELS and resolved == ProjectionMethod.ANALYTIC: + raise RuntimeError("OPENCV_* camera models require ProjectionMethod::UNSCENTED or AUTO") + + if resolved == ProjectionMethod.ANALYTIC: + ortho = int(camera_model) == int(CameraModel.ORTHOGRAPHIC) + proj_result = cast( + tuple[torch.Tensor, ...], + _ProjectGaussiansFn.apply( + means, + quats, + log_scales, + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + eps_2d, + near, + far, + radius_clip, + antialias, + ortho, + accum_grad_norms, + accum_step_counts, + accum_max_radii, + ), + ) + radii = proj_result[0] + means2d = proj_result[1] + depths = proj_result[2] + conics = proj_result[3] + compensations: torch.Tensor | None = proj_result[4] if antialias else None + else: + C = world_to_camera_matrices.size(0) + dc = ( + distortion_coeffs + if distortion_coeffs is not None + else torch.empty(C, 0, device=means.device, dtype=means.dtype) + ) + radii, means2d, depths, conics, compensations_raw = _C.project_gaussians_ut_fwd( + means, + quats, + log_scales, + world_to_camera_matrices, + projection_matrices, + dc, + _C.CameraModel(camera_model), + image_width, + image_height, + eps_2d, + near, + far, + radius_clip, + antialias, + ) + compensations = compensations_raw if antialias else None + + return ProjectedGaussians( + radii=radii, + means2d=means2d, + depths=depths, + conics=conics, + compensations=compensations, + image_width=image_width, + image_height=image_height, + camera_model=CameraModel(int(camera_model)), + projection_method=resolved, + ) diff --git a/fvdb/functional/_gaussian_rasterization.py b/fvdb/functional/_gaussian_rasterization.py new file mode 100644 index 000000000..e9169b6cf --- /dev/null +++ b/fvdb/functional/_gaussian_rasterization.py @@ -0,0 +1,579 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Functional API for dense Gaussian rasterization (Stage 4). + +Provides both screen-space and world-space rasterization paths: + +- **Screen-space** (``rasterize_screen_space_gaussians``): operates on + pre-projected 2D Gaussians; used with the analytic projection pipeline. +- **World-space** (``rasterize_world_space_gaussians``): reprojects from 3D + geometry during backpropagation so that gradients flow through the Gaussian + means, quats, and log_scales; used with the UT projection pipeline. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, cast + +import torch + +from .. import _fvdb_cpp as _C +from ..enums import CameraModel, RollingShutterType + +if TYPE_CHECKING: + from ._gaussian_projection import ProjectedGaussians + from ._gaussian_tile_intersection import GaussianTileIntersection + + +# --------------------------------------------------------------------------- +# Internal opacity computation helper +# --------------------------------------------------------------------------- + + +def _compute_opacities(logit_opacities: torch.Tensor, projected: ProjectedGaussians) -> torch.Tensor: + """``[N]`` logit_opacities -> ``[C, N]`` opacities with optional compensation.""" + return compute_gaussian_opacities(logit_opacities, projected) + + +def compute_gaussian_opacities(logit_opacities: torch.Tensor, projected: ProjectedGaussians) -> torch.Tensor: + """Convert logit opacities to per-camera opacities with optional compensation. + + Args: + logit_opacities: ``[N]`` pre-sigmoid opacity logits. + projected: :class:`ProjectedGaussians` from Stage 1. + + Returns: + ``[C, N]`` opacities (sigmoid-activated, optionally compensated). + """ + C = projected.means2d.shape[0] + opacities = torch.sigmoid(logit_opacities).unsqueeze(0).expand(C, -1) + if projected.compensations is not None: + opacities = opacities * projected.compensations + return opacities + + +# --------------------------------------------------------------------------- +# Crop validation +# --------------------------------------------------------------------------- + + +def _validate_crop( + crop: tuple[int, int, int, int], + image_width: int, + image_height: int, +) -> tuple[int, int, int, int]: + """Validate and clamp a ``(origin_x, origin_y, width, height)`` crop rect. + + Returns the clamped ``(origin_x, origin_y, width, height)`` that lies + within ``[0, image_width) x [0, image_height)``. + + Raises: + ValueError: If any component is negative or if the clamped region + has zero area. + """ + ox, oy, w, h = crop + if ox < 0 or oy < 0: + raise ValueError(f"Crop origin must be non-negative, got ({ox}, {oy})") + if w <= 0 or h <= 0: + raise ValueError(f"Crop size must be positive, got ({w}, {h})") + # Clamp so the crop doesn't extend beyond the projected image. + w = min(w, image_width - ox) + h = min(h, image_height - oy) + if w <= 0 or h <= 0: + raise ValueError( + f"Crop region (origin=({ox}, {oy}), size=({crop[2]}, {crop[3]})) " + f"has no overlap with the {image_width}x{image_height} image" + ) + return ox, oy, w, h + + +# =========================================================================== +# Screen-space rasterization +# =========================================================================== + + +class _RasterizeScreenSpaceGaussiansFn(torch.autograd.Function): + """Python autograd wrapper for the dense Gaussian rasterization forward/backward dispatch.""" + + @staticmethod + def forward( + ctx, + means2d: torch.Tensor, # [C, N, 2] + conics: torch.Tensor, # [C, N, 3] + colors: torch.Tensor, # [C, N, D] + opacities: torch.Tensor, # [N] + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + tile_offsets: torch.Tensor, # [C, tile_height, tile_width] + tile_gaussian_ids: torch.Tensor, # [n_isects] + absgrad: bool, + backgrounds: torch.Tensor | None, # [C, D] or None + masks: torch.Tensor | None, # [C, tileH, tileW] or None + ): + result = _C.rasterize_screen_space_gaussians_fwd( + means2d, + conics, + colors, + opacities, + image_width, + image_height, + image_origin_w, + image_origin_h, + tile_size, + tile_offsets, + tile_gaussian_ids, + backgrounds, + masks, + ) + rendered_colors = result[0] + rendered_alphas = result[1] + last_ids = result[2] + + to_save = [ + means2d, + conics, + colors, + opacities, + tile_offsets, + tile_gaussian_ids, + rendered_alphas, + last_ids, + ] + if backgrounds is not None: + to_save.append(backgrounds) + ctx.has_backgrounds = True + else: + ctx.has_backgrounds = False + if masks is not None: + to_save.append(masks) + ctx.has_masks = True + else: + ctx.has_masks = False + ctx.save_for_backward(*to_save) + + ctx.image_width = image_width + ctx.image_height = image_height + ctx.image_origin_w = image_origin_w + ctx.image_origin_h = image_origin_h + ctx.tile_size = tile_size + ctx.absgrad = absgrad + + return rendered_colors, rendered_alphas + + @staticmethod + def backward(ctx: Any, *grad_outputs: torch.Tensor | None) -> tuple[torch.Tensor | None, ...]: + d_loss_d_rendered_colors = grad_outputs[0] + d_loss_d_rendered_alphas = grad_outputs[1] + if d_loss_d_rendered_colors is not None: + d_loss_d_rendered_colors = d_loss_d_rendered_colors.contiguous() + if d_loss_d_rendered_alphas is not None: + d_loss_d_rendered_alphas = d_loss_d_rendered_alphas.contiguous() + + saved = ctx.saved_tensors + means2d = saved[0] + conics = saved[1] + colors = saved[2] + opacities = saved[3] + tile_offsets = saved[4] + tile_gaussian_ids = saved[5] + rendered_alphas = saved[6] + last_ids = saved[7] + + backgrounds: torch.Tensor | None = None + masks: torch.Tensor | None = None + opt_idx = 8 + if ctx.has_backgrounds: + backgrounds = saved[opt_idx] + opt_idx += 1 + if ctx.has_masks: + masks = saved[opt_idx] + opt_idx += 1 + + assert d_loss_d_rendered_colors is not None + assert d_loss_d_rendered_alphas is not None + result = _C.rasterize_screen_space_gaussians_bwd( + means2d, + conics, + colors, + opacities, + ctx.image_width, + ctx.image_height, + ctx.image_origin_w, + ctx.image_origin_h, + ctx.tile_size, + tile_offsets, + tile_gaussian_ids, + rendered_alphas, + last_ids, + d_loss_d_rendered_colors, + d_loss_d_rendered_alphas, + ctx.absgrad, + -1, + backgrounds, + masks, + ) + # result: (dMean2dAbs, dMeans2d, dConics, dColors, dOpacities) + d_means2d = result[1] + d_conics = result[2] + d_colors = result[3] + d_opacities = result[4] + + return ( + d_means2d, + d_conics, + d_colors, + d_opacities, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def rasterize_screen_space_gaussians( + projected: ProjectedGaussians, + features: torch.Tensor, + logit_opacities: torch.Tensor, + tiles: GaussianTileIntersection, + backgrounds: torch.Tensor | None = None, + masks: torch.Tensor | None = None, + crop: tuple[int, int, int, int] | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Rasterize screen-space Gaussians to produce images and alpha maps (Stage 4). + + Computes opacities internally from ``logit_opacities`` (sigmoid + optional + antialiasing compensation from ``projected``). + + Differentiable via Python autograd. + + Args: + projected: :class:`ProjectedGaussians` from Stage 1. + features: ``[C, N, D]`` Render features from Stage 2. + logit_opacities: ``[N]`` Pre-sigmoid opacities. + tiles: :class:`GaussianTileIntersection` from Stage 3. + backgrounds: ``[C, D]`` Optional per-camera background colours. + masks: ``[C, tileH, tileW]`` Optional per-tile masks. + crop: Optional ``(origin_x, origin_y, width, height)`` tuple defining a + sub-region of the projected image to rasterize. When ``None`` + (the default), the full image is rendered. The crop region is + clamped to the projected image bounds so it is safe to specify a + region that partially or fully extends beyond the image. + + Returns: + Tuple of (rendered_images ``[C, H, W, D]``, alphas ``[C, H, W, 1]``) + where H and W are the crop dimensions (or the full image dimensions + when ``crop`` is ``None``). + """ + opacities = _compute_opacities(logit_opacities, projected) + + result, alphas = cast( + tuple[torch.Tensor, torch.Tensor], + _RasterizeScreenSpaceGaussiansFn.apply( + projected.means2d, + projected.conics, + features, + opacities, + tiles.image_width, + tiles.image_height, + 0, + 0, + tiles.tile_size, + tiles.tile_offsets, + tiles.tile_gaussian_ids, + False, # absgrad + backgrounds, + masks, + ), + ) + + if crop is not None: + ox, oy, w, h = _validate_crop(crop, tiles.image_width, tiles.image_height) + result = result[:, oy : oy + h, ox : ox + w, :] + alphas = alphas[:, oy : oy + h, ox : ox + w, :] + + return result, alphas + + +# =========================================================================== +# World-space rasterization +# =========================================================================== + + +class _RasterizeWorldSpaceGaussiansFn(torch.autograd.Function): + """Python autograd wrapper for world-space Gaussian rasterization forward/backward dispatch.""" + + @staticmethod + def forward( + ctx, + means: torch.Tensor, # [N, 3] + quats: torch.Tensor, # [N, 4] + log_scales: torch.Tensor, # [N, 3] + features: torch.Tensor, # [C, N, D] + opacities: torch.Tensor, # [C, N] + world_to_cam_start: torch.Tensor, # [C, 4, 4] + world_to_cam_end: torch.Tensor, # [C, 4, 4] + projection_matrices: torch.Tensor, # [C, 3, 3] + distortion_coeffs: torch.Tensor, # [C, K] + rolling_shutter_type: RollingShutterType, + camera_model: CameraModel, + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + backgrounds: torch.Tensor | None, # [C, D] or None + masks: torch.Tensor | None, # [C, tileH, tileW] or None + ): + result = _C.rasterize_world_space_gaussians_fwd( + means, + quats, + log_scales, + features, + opacities, + world_to_cam_start, + world_to_cam_end, + projection_matrices, + distortion_coeffs, + _C.RollingShutterType(rolling_shutter_type), + _C.CameraModel(camera_model), + image_width, + image_height, + image_origin_w, + image_origin_h, + tile_size, + tile_offsets, + tile_gaussian_ids, + backgrounds, + masks, + ) + rendered_features = result[0] + rendered_alphas = result[1] + last_ids = result[2] + + to_save = [ + means, + quats, + log_scales, + features, + opacities, + world_to_cam_start, + world_to_cam_end, + projection_matrices, + distortion_coeffs, + tile_offsets, + tile_gaussian_ids, + rendered_alphas, + last_ids, + ] + if backgrounds is not None: + to_save.append(backgrounds) + ctx.has_backgrounds = True + else: + ctx.has_backgrounds = False + if masks is not None: + to_save.append(masks) + ctx.has_masks = True + else: + ctx.has_masks = False + ctx.save_for_backward(*to_save) + + ctx.image_width = image_width + ctx.image_height = image_height + ctx.image_origin_w = image_origin_w + ctx.image_origin_h = image_origin_h + ctx.tile_size = tile_size + ctx.rolling_shutter_type = rolling_shutter_type + ctx.camera_model = camera_model + + return rendered_features, rendered_alphas + + @staticmethod + def backward(ctx: Any, *grad_outputs: torch.Tensor | None) -> tuple[torch.Tensor | None, ...]: + d_loss_d_rendered_features = grad_outputs[0] + d_loss_d_rendered_alphas = grad_outputs[1] + if d_loss_d_rendered_features is not None: + d_loss_d_rendered_features = d_loss_d_rendered_features.contiguous() + if d_loss_d_rendered_alphas is not None: + d_loss_d_rendered_alphas = d_loss_d_rendered_alphas.contiguous() + + saved = ctx.saved_tensors + means = saved[0] + quats = saved[1] + log_scales = saved[2] + features = saved[3] + opacities = saved[4] + world_to_cam_start = saved[5] + world_to_cam_end = saved[6] + projection_matrices = saved[7] + distortion_coeffs = saved[8] + tile_offsets = saved[9] + tile_gaussian_ids = saved[10] + rendered_alphas = saved[11] + last_ids = saved[12] + + backgrounds: torch.Tensor | None = None + masks: torch.Tensor | None = None + opt_idx = 13 + if ctx.has_backgrounds: + backgrounds = saved[opt_idx] + opt_idx += 1 + if ctx.has_masks: + masks = saved[opt_idx] + opt_idx += 1 + + assert d_loss_d_rendered_features is not None + assert d_loss_d_rendered_alphas is not None + result = _C.rasterize_world_space_gaussians_bwd( + means, + quats, + log_scales, + features, + opacities, + world_to_cam_start, + world_to_cam_end, + projection_matrices, + distortion_coeffs, + _C.RollingShutterType(ctx.rolling_shutter_type), + _C.CameraModel(ctx.camera_model), + ctx.image_width, + ctx.image_height, + ctx.image_origin_w, + ctx.image_origin_h, + ctx.tile_size, + tile_offsets, + tile_gaussian_ids, + rendered_alphas, + last_ids, + d_loss_d_rendered_features, + d_loss_d_rendered_alphas, + backgrounds, + masks, + ) + d_means = result[0] + d_quats = result[1] + d_log_scales = result[2] + d_features = result[3] + d_opacities = result[4] + + return ( + d_means, + d_quats, + d_log_scales, + d_features, + d_opacities, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def rasterize_world_space_gaussians( + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + projected: ProjectedGaussians, + features: torch.Tensor, + logit_opacities: torch.Tensor, + world_to_camera_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + distortion_coeffs: torch.Tensor, + camera_model: CameraModel, + tiles: GaussianTileIntersection, + backgrounds: torch.Tensor | None = None, + masks: torch.Tensor | None = None, + crop: tuple[int, int, int, int] | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Rasterize Gaussians from world-space with geometry gradients (Stage 4). + + Unlike :func:`rasterize_screen_space_gaussians`, this path computes gradients + with respect to the 3D Gaussian geometry (means, quats, log_scales) during + backpropagation, enabling world-space optimization for the UT projection + pipeline. + + Computes opacities internally from ``logit_opacities`` (sigmoid + optional + antialiasing compensation from ``projected``). + + Differentiable via Python autograd. + + Args: + means: ``[N, 3]`` Gaussian means. + quats: ``[N, 4]`` Quaternion rotations. + log_scales: ``[N, 3]`` Log scale factors. + projected: :class:`ProjectedGaussians` from Stage 1. + features: ``[C, N, D]`` Render features from Stage 2. + logit_opacities: ``[N]`` Pre-sigmoid opacities. + world_to_camera_matrices: ``[C, 4, 4]`` World-to-camera transforms. + projection_matrices: ``[C, 3, 3]`` Projection matrices. + distortion_coeffs: ``[C, K]`` Distortion coefficients (empty ``[C, 0]`` for none). + camera_model: Camera distortion model. + tiles: :class:`GaussianTileIntersection` from Stage 3. + backgrounds: ``[C, D]`` Optional per-camera backgrounds. + masks: ``[C, tileH, tileW]`` Optional per-tile masks. + crop: Optional ``(origin_x, origin_y, width, height)`` tuple defining a + sub-region of the projected image to rasterize. When ``None`` + (the default), the full image is rendered. The crop region is + clamped to the projected image bounds. + + Returns: + Tuple of (rendered_images ``[C, H, W, D]``, alphas ``[C, H, W, 1]``) + where H and W are the crop dimensions (or the full image dimensions + when ``crop`` is ``None``). + """ + opacities = _compute_opacities(logit_opacities, projected) + + result, alphas = cast( + tuple[torch.Tensor, torch.Tensor], + _RasterizeWorldSpaceGaussiansFn.apply( + means, + quats, + log_scales, + features, + opacities, + world_to_camera_matrices, + world_to_camera_matrices, + projection_matrices, + distortion_coeffs, + RollingShutterType.NONE, + camera_model, + tiles.image_width, + tiles.image_height, + 0, + 0, + tiles.tile_size, + tiles.tile_offsets, + tiles.tile_gaussian_ids, + backgrounds, + masks, + ), + ) + + if crop is not None: + ox, oy, w, h = _validate_crop(crop, tiles.image_width, tiles.image_height) + result = result[:, oy : oy + h, ox : ox + w, :] + alphas = alphas[:, oy : oy + h, ox : ox + w, :] + + return result, alphas diff --git a/fvdb/functional/_gaussian_rasterization_sparse.py b/fvdb/functional/_gaussian_rasterization_sparse.py new file mode 100644 index 000000000..885162607 --- /dev/null +++ b/fvdb/functional/_gaussian_rasterization_sparse.py @@ -0,0 +1,300 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Functional API for sparse Gaussian rasterization (Stage 4).""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Any, cast + +import torch + +from .. import _fvdb_cpp as _C +from ..enums import CameraModel, ProjectionMethod +from ..jagged_tensor import JaggedTensor + +if TYPE_CHECKING: + from ._gaussian_projection import ProjectedGaussians + from ._gaussian_tile_intersection import SparseGaussianTileIntersection + + +# --------------------------------------------------------------------------- +# Autograd function (raw dispatch wrapper) +# --------------------------------------------------------------------------- + + +class _RasterizeScreenSpaceGaussiansSparseFn(torch.autograd.Function): + """Python autograd wrapper for the sparse Gaussian rasterization forward/backward dispatch. + + The complexity here is that the rasterize kernels operate on JaggedTensors, + but torch.autograd.Function only tracks plain tensors for gradient computation. + We decompose JaggedTensors into their component tensors (jdata, joffsets, + jidx, jlidx) for saving, then reconstruct in backward. + """ + + @staticmethod + def forward( + ctx, + # Differentiable inputs (plain tensors) + means2d: torch.Tensor, # [C, N, 2] + conics: torch.Tensor, # [C, N, 3] + features: torch.Tensor, # [C, N, D] + opacities: torch.Tensor, # [N] + # Non-differentiable inputs + pixels_to_render: JaggedTensor, # JaggedTensor [C, num_pixels, 2] + image_width: int, + image_height: int, + image_origin_w: int, + image_origin_h: int, + tile_size: int, + tile_offsets: torch.Tensor, + tile_gaussian_ids: torch.Tensor, + active_tiles: torch.Tensor, + tile_pixel_mask: torch.Tensor, + tile_pixel_cumsum: torch.Tensor, + pixel_map: torch.Tensor, + absgrad: bool, + backgrounds: torch.Tensor | None, + masks: torch.Tensor | None, + ): + result = _C.rasterize_screen_space_gaussians_sparse_fwd( + pixels_to_render._impl, + means2d, + conics, + features, + opacities, + image_width, + image_height, + image_origin_w, + image_origin_h, + tile_size, + tile_offsets, + tile_gaussian_ids, + active_tiles, + tile_pixel_mask, + tile_pixel_cumsum, + pixel_map, + backgrounds, + masks, + ) + rendered_colors_jt = JaggedTensor(impl=result[0]) + rendered_alphas_jt = JaggedTensor(impl=result[1]) + last_ids_jt = JaggedTensor(impl=result[2]) + + joffsets = pixels_to_render.joffsets + jidx = pixels_to_render.jidx + jlidx = pixels_to_render.jlidx + + to_save = [ + means2d, # 0 + conics, # 1 + features, # 2 + opacities, # 3 + tile_offsets, # 4 + tile_gaussian_ids, # 5 + pixels_to_render.jdata, # 6 + rendered_colors_jt.jdata, # 7 + rendered_alphas_jt.jdata, # 8 + last_ids_jt.jdata, # 9 + joffsets, # 10 + jidx, # 11 + jlidx, # 12 + active_tiles, # 13 + tile_pixel_mask, # 14 + tile_pixel_cumsum, # 15 + pixel_map, # 16 + ] + if backgrounds is not None: + to_save.append(backgrounds) + ctx.has_backgrounds = True + else: + ctx.has_backgrounds = False + if masks is not None: + to_save.append(masks) + ctx.has_masks = True + else: + ctx.has_masks = False + ctx.save_for_backward(*to_save) + + ctx.image_width = image_width + ctx.image_height = image_height + ctx.image_origin_w = image_origin_w + ctx.image_origin_h = image_origin_h + ctx.tile_size = tile_size + ctx.absgrad = absgrad + ctx.num_outer_lists = len(pixels_to_render) + + return rendered_colors_jt.jdata, rendered_alphas_jt.jdata + + @staticmethod + def backward(ctx: Any, *grad_outputs: torch.Tensor | None) -> tuple[torch.Tensor | None, ...]: + d_loss_d_rendered_features_jdata = grad_outputs[0] + d_loss_d_rendered_alphas_jdata = grad_outputs[1] + if d_loss_d_rendered_features_jdata is not None: + d_loss_d_rendered_features_jdata = d_loss_d_rendered_features_jdata.contiguous() + if d_loss_d_rendered_alphas_jdata is not None: + d_loss_d_rendered_alphas_jdata = d_loss_d_rendered_alphas_jdata.contiguous() + + saved = ctx.saved_tensors + means2d = saved[0] + conics = saved[1] + features = saved[2] + opacities = saved[3] + tile_offsets = saved[4] + tile_gaussian_ids = saved[5] + pixels_jdata = saved[6] + rendered_colors_jdata = saved[7] + rendered_alphas_jdata = saved[8] + last_ids_jdata = saved[9] + joffsets = saved[10] + jidx = saved[11] + jlidx = saved[12] + active_tiles = saved[13] + tile_pixel_mask = saved[14] + tile_pixel_cumsum = saved[15] + pixel_map = saved[16] + + backgrounds: torch.Tensor | None = None + masks: torch.Tensor | None = None + opt_idx = 17 + if ctx.has_backgrounds: + backgrounds = saved[opt_idx] + opt_idx += 1 + if ctx.has_masks: + masks = saved[opt_idx] + opt_idx += 1 + + pixels_jt = JaggedTensor(impl=_C.JaggedTensor.from_data_offsets_and_list_ids(pixels_jdata, joffsets, jlidx)) + rendered_alphas_jt = pixels_jt.jagged_like(rendered_alphas_jdata) + last_ids_jt = pixels_jt.jagged_like(last_ids_jdata) + assert d_loss_d_rendered_features_jdata is not None + assert d_loss_d_rendered_alphas_jdata is not None + d_loss_d_rendered_features_jt = pixels_jt.jagged_like(d_loss_d_rendered_features_jdata) + d_loss_d_rendered_alphas_jt = pixels_jt.jagged_like(d_loss_d_rendered_alphas_jdata) + + result = _C.rasterize_screen_space_gaussians_sparse_bwd( + pixels_jt._impl, + means2d, + conics, + features, + opacities, + ctx.image_width, + ctx.image_height, + ctx.image_origin_w, + ctx.image_origin_h, + ctx.tile_size, + tile_offsets, + tile_gaussian_ids, + rendered_alphas_jt._impl, + last_ids_jt._impl, + d_loss_d_rendered_features_jt._impl, + d_loss_d_rendered_alphas_jt._impl, + active_tiles, + tile_pixel_mask, + tile_pixel_cumsum, + pixel_map, + ctx.absgrad, + -1, + backgrounds, + masks, + ) + d_means2d = result[1] + d_conics = result[2] + d_colors = result[3] + d_opacities = result[4] + + # Order: means2d, conics, features, opacities, + # pixels_to_render, image_width, image_height, image_origin_w, image_origin_h, + # tile_size, tile_offsets, tile_gaussian_ids, active_tiles, + # tile_pixel_mask, tile_pixel_cumsum, pixel_map, absgrad, + # backgrounds, masks + return ( + d_means2d, + d_conics, + d_colors, + d_opacities, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def rasterize_screen_space_gaussians_sparse( + projected: ProjectedGaussians, + features: torch.Tensor, + logit_opacities: torch.Tensor, + sparse_tiles: SparseGaussianTileIntersection, + backgrounds: torch.Tensor | None = None, + masks: torch.Tensor | None = None, +) -> tuple[JaggedTensor, JaggedTensor]: + """Rasterize screen-space Gaussians at sparse pixel locations (Stage 4). + + Computes opacities internally from ``logit_opacities``. Returns results + for the deduplicated pixel set only -- the caller is responsible for + scatter-back using ``sparse_tiles.inverse_indices`` when duplicates exist. + + Differentiable via Python autograd. + + Args: + projected: :class:`ProjectedGaussians` from Stage 1. + features: ``[C, N, D]`` Render features from Stage 2. + logit_opacities: ``[N]`` Pre-sigmoid opacities. + sparse_tiles: :class:`SparseGaussianTileIntersection` from Stage 3. + backgrounds: ``[C, D]`` Optional per-camera backgrounds. + masks: ``[C, tileH, tileW]`` Optional per-tile masks. + + Returns: + Tuple of (rendered_features, rendered_alphas) as :class:`JaggedTensor` + for the **unique** pixel set. + """ + from ._gaussian_rasterization import _compute_opacities + + opacities = _compute_opacities(logit_opacities, projected) + + rendered_jdata, rendered_alphas_jdata = cast( + tuple[torch.Tensor, torch.Tensor], + _RasterizeScreenSpaceGaussiansSparseFn.apply( + projected.means2d, + projected.conics, + features, + opacities, + sparse_tiles.unique_pixels, + sparse_tiles.image_width, + sparse_tiles.image_height, + 0, + 0, + sparse_tiles.tile_size, + sparse_tiles.tile_offsets, + sparse_tiles.tile_gaussian_ids, + sparse_tiles.active_tiles, + sparse_tiles.tile_pixel_mask, + sparse_tiles.tile_pixel_cumsum, + sparse_tiles.pixel_map, + False, # absgrad + backgrounds, + masks, + ), + ) + + rendered_jt = sparse_tiles.unique_pixels.jagged_like(rendered_jdata) + rendered_alphas_jt = sparse_tiles.unique_pixels.jagged_like(rendered_alphas_jdata) + + return rendered_jt, rendered_alphas_jt diff --git a/fvdb/functional/_gaussian_spherical_harmonics.py b/fvdb/functional/_gaussian_spherical_harmonics.py new file mode 100644 index 000000000..bd200b8a2 --- /dev/null +++ b/fvdb/functional/_gaussian_spherical_harmonics.py @@ -0,0 +1,144 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Functional API for spherical harmonics evaluation (Stage 2).""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Any, cast + +import torch + +from .. import _fvdb_cpp as _C +from ..enums import GaussianRenderMode + +if TYPE_CHECKING: + from ._gaussian_projection import ProjectedGaussians + + +# --------------------------------------------------------------------------- +# Autograd function (raw dispatch wrapper) +# --------------------------------------------------------------------------- + + +class _EvaluateGaussianSHFn(torch.autograd.Function): + """Python autograd wrapper for the SH evaluation forward/backward dispatch.""" + + @staticmethod + def forward( + ctx, + sh_degree_to_use: int, + num_cameras: int, + view_dirs: torch.Tensor, # [C, N, 3] or empty + sh0_coeffs: torch.Tensor, # [N, 1, D] + shN_coeffs: torch.Tensor, # [N, K-1, D] or empty + radii: torch.Tensor, # [C, N] + ) -> torch.Tensor: + render_quantities = _C.eval_gaussian_sh_fwd( + sh_degree_to_use, + num_cameras, + view_dirs, + sh0_coeffs, + shN_coeffs, + radii, + ) + + ctx.save_for_backward(view_dirs, shN_coeffs, radii) + ctx.sh_degree_to_use = sh_degree_to_use + ctx.num_cameras = num_cameras + ctx.num_gaussians = sh0_coeffs.size(0) + + return render_quantities + + @staticmethod + def backward(ctx: Any, *grad_outputs: torch.Tensor | None) -> tuple[torch.Tensor | None, ...]: + d_loss_d_colors = grad_outputs[0] + if d_loss_d_colors is None: + return (None, None, None, None, None, None) + d_loss_d_colors = d_loss_d_colors.contiguous() + + view_dirs, shN_coeffs, radii = ctx.saved_tensors + + compute_d_loss_d_view_dirs = view_dirs.numel() > 0 and view_dirs.requires_grad + + d_sh0, d_shN, d_view_dirs = _C.eval_gaussian_sh_bwd( + ctx.sh_degree_to_use, + ctx.num_cameras, + ctx.num_gaussians, + view_dirs, + shN_coeffs, + d_loss_d_colors, + radii, + compute_d_loss_d_view_dirs, + ) + + # Order: sh_degree_to_use, num_cameras, view_dirs, sh0_coeffs, shN_coeffs, radii + return (None, None, d_view_dirs, d_sh0, d_shN, None) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def evaluate_gaussian_sh( + means: torch.Tensor, + sh0: torch.Tensor, + shN: torch.Tensor, + world_to_camera_matrices: torch.Tensor, + projected: ProjectedGaussians, + sh_degree_to_use: int = -1, + render_mode: GaussianRenderMode = GaussianRenderMode.FEATURES, +) -> torch.Tensor: + """Evaluate per-Gaussian render features from SH coefficients (Stage 2). + + Computes view-dependent features based on ``render_mode``: + + - ``FEATURES``: evaluates spherical harmonics to produce view-dependent + colours (or any per-Gaussian feature encoded as SH coefficients). + - ``DEPTH``: returns depths as a single-channel feature (no SH evaluation). + - ``FEATURES_AND_DEPTH``: concatenates SH-evaluated features with depths. + + Differentiable via Python autograd. + + Args: + means: ``[N, 3]`` Gaussian means (used to compute view directions when + the SH degree is > 0). + sh0: ``[N, 1, D]`` Degree-0 SH coefficients. + shN: ``[N, K-1, D]`` Higher-degree SH coefficients. + world_to_camera_matrices: ``[C, 4, 4]`` World-to-camera transforms. + projected: :class:`ProjectedGaussians` from Stage 1 (provides + ``radii`` and ``depths``). + sh_degree_to_use: SH degree to use (-1 for all available). + render_mode: Which quantities to produce. + + Returns: + ``[C, N, D]`` Render features (``D`` depends on render mode). + """ + radii = projected.radii + depths = projected.depths + + if render_mode == GaussianRenderMode.DEPTH: + return depths.unsqueeze(-1) + + K = shN.size(1) + 1 + C = world_to_camera_matrices.size(0) + actual_sh_degree = sh_degree_to_use if sh_degree_to_use >= 0 else int(math.sqrt(K)) - 1 + + if actual_sh_degree == 0: + view_dirs = means.new_empty(0) + else: + cam_to_world = torch.linalg.inv(world_to_camera_matrices) + camera_pos = cam_to_world[:, :3, 3] + view_dirs = means[None, :, :] - camera_pos[:, None, :] + + features = cast( + torch.Tensor, + _EvaluateGaussianSHFn.apply(actual_sh_degree, C, view_dirs, sh0, shN, radii), + ) + + if render_mode == GaussianRenderMode.FEATURES_AND_DEPTH: + features = torch.cat([features, depths.unsqueeze(-1)], -1) + + return features diff --git a/fvdb/functional/_gaussian_tile_intersection.py b/fvdb/functional/_gaussian_tile_intersection.py new file mode 100644 index 000000000..a1eabc351 --- /dev/null +++ b/fvdb/functional/_gaussian_tile_intersection.py @@ -0,0 +1,266 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Helpers for tile intersection (Stage 3).""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch + +from .. import _fvdb_cpp as _C +from ..jagged_tensor import JaggedTensor + +if TYPE_CHECKING: + from ._gaussian_projection import ProjectedGaussians + + +@dataclass(frozen=True) +class GaussianTileIntersection: + """Result of tile-based Gaussian culling for dense rasterization. + + Identifies which Gaussians overlap which screen-space tiles, producing + sorted lists for the tiled rasterizer. + + Attributes: + tile_offsets: Per-tile start offsets into ``tile_gaussian_ids``. + tile_gaussian_ids: Sorted Gaussian indices per tile. + tile_size: Tile side length in pixels. + image_width: Image width in pixels (copied from ``ProjectedGaussians``). + image_height: Image height in pixels (copied from ``ProjectedGaussians``). + """ + + tile_offsets: torch.Tensor + tile_gaussian_ids: torch.Tensor + tile_size: int + image_width: int + image_height: int + + +@dataclass(frozen=True) +class SparseGaussianTileIntersection: + """Result of sparse tile-based Gaussian culling for sparse rasterization. + + Bundles the sparse tile layout, deduplicated pixel information, and tile + intersection data needed by ``rasterize_screen_space_gaussians_sparse``. + + Attributes: + tile_offsets: Sparse tile offsets. + tile_gaussian_ids: Sparse tile Gaussian IDs. + unique_pixels: Deduplicated pixel coordinates as a :class:`JaggedTensor`. + inverse_indices: Mapping back to original (possibly duplicated) pixels. + has_duplicates: Whether the pixel set contained duplicates (and + ``inverse_indices`` is valid). + pixels_to_render: Original pixel coordinates passed by the caller. + active_tiles: ``[num_active_tiles]`` Indices of tiles with at least one pixel. + active_tile_mask: ``[C, TH, TW]`` Boolean mask of active tiles. + tile_pixel_mask: ``[num_active_tiles, words_per_tile]`` Per-tile pixel bitmask. + tile_pixel_cumsum: ``[num_active_tiles]`` Cumulative pixel counts per tile. + pixel_map: ``[num_active_pixels]`` Pixel-to-tile mapping. + tile_size: Tile side length in pixels. + image_width: Image width in pixels (copied from ``ProjectedGaussians``). + image_height: Image height in pixels (copied from ``ProjectedGaussians``). + """ + + tile_offsets: torch.Tensor + tile_gaussian_ids: torch.Tensor + unique_pixels: JaggedTensor + inverse_indices: torch.Tensor + has_duplicates: bool + pixels_to_render: JaggedTensor + active_tiles: torch.Tensor + active_tile_mask: torch.Tensor + tile_pixel_mask: torch.Tensor + tile_pixel_cumsum: torch.Tensor + pixel_map: torch.Tensor + tile_size: int + image_width: int + image_height: int + + +def _find_unique_pixels( + pixels_to_render: JaggedTensor, + image_width: int | None = None, + image_height: int | None = None, +) -> tuple[JaggedTensor, torch.Tensor, bool]: + """Deduplicate pixel coordinates within a JaggedTensor. + + Given a JaggedTensor of ``[row, col]`` pixel coordinates (possibly with + duplicates across or within batches), returns a new JaggedTensor containing + only unique pixels, an inverse-index tensor mapping original positions to + unique positions, and a boolean indicating whether any duplicates were found. + + Args: + pixels_to_render: JaggedTensor with ``rshape = (..., 2)`` holding + ``[row, col]`` coordinates. + image_width: Image width (used for key encoding). Inferred if ``None``. + image_height: Image height (used for key encoding). Inferred if ``None``. + + Returns: + Tuple of ``(unique_pixels, inverse_indices, has_duplicates)``. + """ + jdata = pixels_to_render.jdata + total_pixels = jdata.size(0) + + if total_pixels == 0: + empty_inv = torch.empty(0, dtype=torch.long, device=jdata.device) + return pixels_to_render, empty_inv, False + + device = jdata.device + jidx = pixels_to_render.jidx + + if image_width is None: + image_width = int(jdata[:, 1].max().item()) + 1 + if image_height is None: + image_height = int(jdata[:, 0].max().item()) + 1 + num_pixels_per_image = image_height * image_width + + rows = jdata[:, 0].long() + cols = jdata[:, 1].long() + + single_list = jidx.numel() == 0 + if single_list: + keys = rows * image_width + cols + else: + keys = jidx.long() * num_pixels_per_image + rows * image_width + cols + + sorted_keys, sort_perm = keys.sort() + + is_group_start = torch.ones(total_pixels, dtype=torch.bool, device=device) + if total_pixels > 1: + is_group_start[1:] = sorted_keys[1:] != sorted_keys[:-1] + + first_in_sorted = is_group_start.nonzero(as_tuple=False).squeeze(1) + + group_ids = is_group_start.long().cumsum(0) - 1 + num_unique = group_ids[-1].item() + 1 + + if num_unique == total_pixels: + return pixels_to_render, torch.arange(total_pixels, dtype=torch.long, device=device), False + + inverse_indices = torch.empty(total_pixels, dtype=torch.long, device=device) + inverse_indices[sort_perm] = group_ids + + unique_orig_indices = sort_perm[first_in_sorted] + unique_jdata = jdata[unique_orig_indices] + + if single_list: + unique_batch_idx = torch.zeros(num_unique, dtype=torch.long, device=device) + else: + unique_batch_idx = jidx.long()[unique_orig_indices] + + num_lists = len(pixels_to_render) + counts_per_list = torch.bincount(unique_batch_idx, minlength=num_lists) + new_offsets = torch.zeros(num_lists + 1, dtype=torch.long, device=device) + new_offsets[1:] = counts_per_list.cumsum(0) + + empty_lidx = torch.empty((0, 1), dtype=torch.int32, device=device) + unique_impl = _C.JaggedTensor.from_data_offsets_and_list_ids(unique_jdata, new_offsets, empty_lidx) + unique_pixels = JaggedTensor(impl=unique_impl) + + return unique_pixels, inverse_indices, True + + +def intersect_gaussian_tiles( + projected: ProjectedGaussians, + tile_size: int = 16, +) -> GaussianTileIntersection: + """Compute tile-Gaussian intersections for tiled rasterization (Stage 3). + + Non-differentiable. ``num_cameras``, ``image_width``, and ``image_height`` + are derived from ``projected`` -- not passed explicitly. + + Args: + projected: :class:`ProjectedGaussians` from Stage 1. + tile_size: Tile side length in pixels (default 16). + + Returns: + A :class:`GaussianTileIntersection` with tile offsets, sorted Gaussian + IDs, and the tile/image dimensions. + """ + image_width = projected.image_width + image_height = projected.image_height + num_cameras = projected.means2d.shape[0] + num_tiles_w = math.ceil(image_width / tile_size) + num_tiles_h = math.ceil(image_height / tile_size) + tile_offsets, tile_gaussian_ids = _C.intersect_gaussian_tiles( + projected.means2d, + projected.radii, + projected.depths, + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w, + ) + return GaussianTileIntersection( + tile_offsets=tile_offsets, + tile_gaussian_ids=tile_gaussian_ids, + tile_size=tile_size, + image_width=image_width, + image_height=image_height, + ) + + +def intersect_gaussian_tiles_sparse( + pixels_to_render: JaggedTensor, + projected: ProjectedGaussians, + tile_size: int = 16, +) -> SparseGaussianTileIntersection: + """Compute sparse tile-Gaussian intersections for sparse rasterization (Stage 3). + + Fuses pixel deduplication, sparse tile layout computation, and sparse tile + intersection into a single call. Non-differentiable. + + Args: + pixels_to_render: :class:`JaggedTensor` of ``[C, num_pixels, 2]`` pixel + coordinates (may contain duplicates). + projected: :class:`ProjectedGaussians` from Stage 1. + tile_size: Tile side length in pixels (default 16). + + Returns: + A :class:`SparseGaussianTileIntersection` bundling all sparse tile + layout and intersection data. + """ + image_width = projected.image_width + image_height = projected.image_height + num_cameras = projected.means2d.shape[0] + num_tiles_w = math.ceil(image_width / tile_size) + num_tiles_h = math.ceil(image_height / tile_size) + + unique_pixels, inverse_indices, has_duplicates = _find_unique_pixels(pixels_to_render) + + active_tiles, active_tile_mask, tile_pixel_mask, tile_pixel_cumsum, pixel_map = ( + _C.build_sparse_gaussian_tile_layout(tile_size, num_tiles_w, num_tiles_h, unique_pixels._impl) + ) + + sparse_tile_offsets, sparse_tile_gaussian_ids = _C.intersect_gaussian_tiles_sparse( + projected.means2d, + projected.radii, + projected.depths, + active_tile_mask, + active_tiles, + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w, + ) + + return SparseGaussianTileIntersection( + tile_offsets=sparse_tile_offsets, + tile_gaussian_ids=sparse_tile_gaussian_ids, + unique_pixels=unique_pixels, + inverse_indices=inverse_indices, + has_duplicates=has_duplicates, + pixels_to_render=pixels_to_render, + active_tiles=active_tiles, + active_tile_mask=active_tile_mask, + tile_pixel_mask=tile_pixel_mask, + tile_pixel_cumsum=tile_pixel_cumsum, + pixel_map=pixel_map, + tile_size=tile_size, + image_width=image_width, + image_height=image_height, + ) diff --git a/fvdb/functional/_io.py b/fvdb/functional/_io.py index 95dd2891c..62461f38d 100644 --- a/fvdb/functional/_io.py +++ b/fvdb/functional/_io.py @@ -1,14 +1,15 @@ # Copyright Contributors to the OpenVDB Project # SPDX-License-Identifier: Apache-2.0 # -"""Functional API for loading and saving grid batches in NanoVDB format.""" +"""Functional API for loading and saving grid batches in NanoVDB format, and Gaussian PLY I/O.""" + from __future__ import annotations from typing import TYPE_CHECKING, overload import torch -from .. import _fvdb_cpp +from .. import _fvdb_cpp as _C from ..jagged_tensor import JaggedTensor from ..types import DeviceIdentifier, resolve_device @@ -235,6 +236,7 @@ def save_nanovdb_single( .. seealso:: :func:`save_nanovdb` """ import torch + from .._fvdb_cpp import save as _save grid_data = grid.data @@ -246,3 +248,67 @@ def save_nanovdb_single( _save(path, grid_data, data_impl, name, compressed, verbose) else: _save(path, grid_data, data_impl, [], compressed, verbose) + + +# --------------------------------------------------------------------------- +# Gaussian PLY I/O +# --------------------------------------------------------------------------- + + +def load_gaussian_ply( + filename: str, + device: DeviceIdentifier = "cpu", +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + dict[str, str | int | float | torch.Tensor], +]: + """Load a Gaussian splat model from a PLY file. + + Args: + filename: Path to the ``.ply`` file. + device: Device to load tensors onto. Defaults to ``"cpu"``. + + Returns: + Tuple of (means, quats, log_scales, logit_opacities, sh0, shN, metadata). + """ + device = resolve_device(device) + return _C.load_gaussians_ply(filename, device) + + +def save_gaussian_ply( + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + logit_opacities: torch.Tensor, + sh0: torch.Tensor, + shN: torch.Tensor, + filename: str, + metadata: dict[str, str | int | float | torch.Tensor] | None = None, +) -> None: + """Save a Gaussian splat model to a PLY file. + + Args: + means: ``[N, 3]`` Gaussian centres. + quats: ``[N, 4]`` quaternion rotations. + log_scales: ``[N, 3]`` log-scale parameters. + logit_opacities: ``[N]`` pre-sigmoid opacity logits. + sh0: ``[N, 1, 3]`` zero-order SH coefficients. + shN: ``[N, K, 3]`` higher-order SH coefficients. + filename: Output file path. + metadata: Optional dict of scalar/tensor metadata to embed in the PLY. + """ + _C.save_gaussians_ply( + means, + quats, + log_scales, + logit_opacities, + sh0, + shN, + filename, + dict(metadata) if metadata is not None else None, + ) diff --git a/fvdb/utils/metrics/ssim.py b/fvdb/functional/_metrics.py similarity index 50% rename from fvdb/utils/metrics/ssim.py rename to fvdb/functional/_metrics.py index 8e9115111..fb8d1b622 100644 --- a/fvdb/utils/metrics/ssim.py +++ b/fvdb/functional/_metrics.py @@ -1,3 +1,6 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 + # This file contains source code from the fused-ssim library obtained from # https://github.com/rahul-goel/fused-ssim. The fused-ssim library is licensed under the MIT # License. Refer to ORSB 5512107 for more. Original license text follows. @@ -24,21 +27,29 @@ # Copyright Contributors to the OpenVDB Project # SPDX-License-Identifier: Apache-2.0 -# +"""Image-quality metrics: PSNR and SSIM.""" -from typing import NamedTuple +from __future__ import annotations + +import math +from typing import Literal import torch -import fvdb +from .. import _fvdb_cpp as _fvdb_cpp # noqa: F401 -- loads the custom C++ ops used by torch.ops.fvdb below + +_ALLOWED_PADDING = ("same", "valid") + -allowed_padding = ["same", "valid"] +# --------------------------------------------------------------------------- +# SSIM (fused CUDA implementation) +# --------------------------------------------------------------------------- -class FusedSSIMMap(torch.autograd.Function): +class _FusedSSIMFn(torch.autograd.Function): @staticmethod - def forward(ctx, C1, C2, img1, img2, padding="same", train=True): + def forward(ctx, C1, C2, img1, img2, padding="same", train=True): # type: ignore[override] ( ssim_map, dm_dmu1, @@ -57,7 +68,7 @@ def forward(ctx, C1, C2, img1, img2, padding="same", train=True): return ssim_map @staticmethod - def backward(ctx, opt_grad): + def backward(ctx, opt_grad): # type: ignore[override] img1, img2, dm_dmu1, dm_dsigma1_sq, dm_dsigma12 = ctx.saved_tensors C1, C2, padding = ctx.C1, ctx.C2, ctx.padding dL_dmap = opt_grad @@ -70,18 +81,16 @@ def backward(ctx, opt_grad): return None, None, grad, None, None, None -def fused_ssim(img1, img2, padding="same", train=True): +def _fused_ssim(img1: torch.Tensor, img2: torch.Tensor, padding: str = "same", train: bool = True) -> torch.Tensor: C1 = 0.01**2 C2 = 0.03**2 - assert padding in allowed_padding + if padding not in _ALLOWED_PADDING: + raise ValueError(f"padding must be one of {_ALLOWED_PADDING}, got {padding!r}") img1 = img1.contiguous() - map = FusedSSIMMap.apply(C1, C2, img1, img2, padding, train) - return map.mean() # type: ignore - - -from typing import Literal + ssim_map = _FusedSSIMFn.apply(C1, C2, img1, img2, padding, train) + return ssim_map.mean() # type: ignore[union-attr] def ssim( @@ -102,4 +111,53 @@ def ssim( Returns: ssim (torch.Tensor): The average SSIM between each image over the batch. """ - return fused_ssim(img1, img2, padding, train) + return _fused_ssim(img1, img2, padding, train) + + +# --------------------------------------------------------------------------- +# PSNR +# --------------------------------------------------------------------------- + + +def psnr( + noisy_images: torch.Tensor, + ground_truth_images: torch.Tensor, + max_value: float = 1.0, + reduction: Literal["none", "mean", "sum"] = "mean", +) -> torch.Tensor: + """ + Compute the Peak-Signal-to-Noise-Ratio (PSNR) ratio between two batches of images. + + Args: + noisy_images (torch.Tensor): A batch of noisy images of shape ``(B, C, H, W)`` + ground_truth_images (torch.Tensor): A batch of ground truth images of shape ``(B, C, H, W)`` + max_value (float): The maximum possible value images computed with this loss can have. + Default is 1.0. + reduction (Literal["none", "mean", "sum"]): How to reduce over the batch dimension. ``"sum"`` + and ``"mean"`` will add-up and average the losses across the batch respectively. ``"none"`` will + return each loss as a separate entry in the tensor. Default is ``"mean"``. + + Returns: + psnr (torch.Tensor): The PSNR between the two images. If reduction is not "none", the result + will be reduced over the batch dimension (*i.e.* will be a single scalar), otherwise it will + be a tensor of shape ``(B,)``. + """ + if max_value <= 0: + raise ValueError("max_value must be a positive number") + + if reduction not in ("none", "mean", "sum"): + raise ValueError("reduction must be one of ('none', 'mean', 'sum')") + + if (noisy_images.shape != ground_truth_images.shape) or (noisy_images.dim() != 4): + raise ValueError("Input images must have the same shape and be 4-dimensional with shape (B, C, H, W)") + + mse = torch.mean((noisy_images - ground_truth_images) ** 2, dim=(1, 2, 3)) # [B] + + psnr_val = 10.0 * (2.0 * math.log10(max_value) - torch.log10(mse)) + if reduction == "none": + return psnr_val + elif reduction == "mean": + return torch.mean(psnr_val) + elif reduction == "sum": + return torch.sum(psnr_val) + raise ValueError("reduction must be one of ('none', 'mean', 'sum')") diff --git a/fvdb/gaussian_splatting.py b/fvdb/gaussian_splatting.py index 9fd422112..bfc7f9fe9 100644 --- a/fvdb/gaussian_splatting.py +++ b/fvdb/gaussian_splatting.py @@ -1,18 +1,43 @@ # Copyright Contributors to the OpenVDB Project # SPDX-License-Identifier: Apache-2.0 # +""" +Object-oriented Gaussian splatting interface. + +This module provides :class:`GaussianSplat3d`, the primary user-facing class +for Gaussian splatting in fVDB. It manages parameter tensors, gradient +accumulators, and rendering state, and delegates the actual computation to +the pure-functional stages in :mod:`fvdb.functional`. + +**Relationship to the functional API.** Internally, +:meth:`GaussianSplat3d.render_images` and the other render methods compose +the same pipeline stages that :mod:`fvdb.functional` exposes publicly +(``project_gaussians`` -> ``evaluate_gaussian_sh`` -> +``intersect_gaussian_tiles`` -> ``rasterize_screen_space_gaussians``). +The class adds three things the functional API intentionally omits: + +1. *Mutable accumulator state* -- gradient norm tracking, max-radii tracking, + and step counts used by densification heuristics during training. +2. *Lazy initialization* -- accumulators are allocated on first use. +3. *Projection method routing* -- automatic dispatch between the differentiable + analytic path and the non-differentiable unscented-transform fallback for + OpenCV camera models. + +Users who need custom pipeline composition or want to avoid mutable state +should use :mod:`fvdb.functional` directly. +""" + +import math import pathlib -from typing import Any, Mapping, Sequence, TypeVar, overload +from typing import Any, Mapping, Sequence, TypeVar, cast, overload import torch import torch.nn.functional as F +from fvdb.enums import CameraModel, GaussianRenderMode, ProjectionMethod -from fvdb.enums import CameraModel, ProjectionMethod - -from . import _fvdb_cpp as _C -from ._fvdb_cpp import GaussianSplat3d as GaussianSplat3dCpp +from . import functional as GF from ._fvdb_cpp import JaggedTensor as JaggedTensorCpp -from ._fvdb_cpp import ProjectedGaussianSplats as ProjectedGaussianSplatsCpp +from .functional import ProjectedGaussians from .grid import Grid from .grid_batch import GridBatch from .jagged_tensor import JaggedTensor @@ -21,13 +46,13 @@ JaggedTensorOrTensorT = TypeVar("JaggedTensorOrTensorT", JaggedTensor, torch.Tensor) -def _pixel_mask_to_tile_mask(pixel_mask: torch.Tensor, tile_size: int) -> torch.Tensor: - """Convert a per-pixel boolean mask ``[C, H, W]`` to a per-tile boolean mask ``[C, tileH, tileW]``. +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- - A tile is ``True`` (render) if **any** pixel in that tile is ``True``. - Uses ``max_pool2d`` with ``ceil_mode=True`` so that partial edge tiles are - handled correctly when ``H`` or ``W`` is not divisible by ``tile_size``. - """ + +def _pixel_mask_to_tile_mask(pixel_mask: torch.Tensor, tile_size: int) -> torch.Tensor: + """``[C, H, W]`` per-pixel boolean mask -> ``[C, tileH, tileW]`` per-tile mask.""" return ( F.max_pool2d( pixel_mask.unsqueeze(1).float(), @@ -46,15 +71,10 @@ def _apply_pixel_mask( pixel_mask: torch.Tensor, backgrounds: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Apply a per-pixel boolean mask ``[C, H, W]`` to rendered features and alphas. - - Masked-out pixels (``False``) are filled with the background colour (or zero) - and their alpha is set to zero. The operation is differentiable: gradients - flow through unmasked pixels and are zero for masked pixels. - """ - mask_float = pixel_mask.unsqueeze(-1).float() # [C, H, W, 1] + """Apply a per-pixel boolean mask to rendered features/alphas (differentiable).""" + mask_float = pixel_mask.unsqueeze(-1).float() if backgrounds is not None: - bg = backgrounds[:, None, None, :] # [C, 1, 1, D] + bg = backgrounds[:, None, None, :] else: bg = torch.zeros(1, 1, 1, features.shape[-1], device=features.device, dtype=features.dtype) features = features * mask_float + bg * (1.0 - mask_float) @@ -62,241 +82,6 @@ def _apply_pixel_mask( return features, alphas -class ProjectedGaussianSplats: - """ - A class representing a set of Gaussian splats projected onto a batch of 2D image planes. - - A :class:`ProjectedGaussianSplats` instance contains the 2D projections of 3D Gaussian splats, which can be used to render - images onto the image planes. Instances of this class are created by calling the :meth:`GaussianSplat3d.project_gaussians_for_images`, - :meth:`GaussianSplat3d.project_gaussians_for_images_and_depths`, etc. methods. - - .. note:: - - The reason to have a separate class for projected Gaussian splats is to be able to run projection once, and then render - the splats multiple times (e.g. rendering crops) without re-projecting them each time. This can save significant computation time. - """ - - __PRIVATE__ = object() - - def __init__(self, impl: ProjectedGaussianSplatsCpp, _private: Any = None) -> None: - """ - Private constructor. Use :meth:`GaussianSplat3d.project_gaussians_for_images` or similar methods to create instances. - - Args: - impl (ProjectedGaussianSplatsCpp): The underlying C++ implementation. - _private (Any): A private object to prevent direct construction. Must be :attr:`ProjectedGaussianSplats.__PRIVATE__`. - """ - if _private is not self.__PRIVATE__: - raise ValueError( - "ProjectedGaussianSplats constructor is private. Use GaussianSplat3d.project_gaussians_for_images or similar methods instead." - ) - self._impl = impl - - @property - def antialias(self) -> bool: - """ - Return whether antialiasing was enabled during the projection of the Gaussian splats. - - Returns: - antialias (bool): ``True`` if antialiasing was enabled during projection, ``False`` otherwise. - """ - return self._impl.antialias - - @property - def inv_covar_2d(self) -> torch.Tensor: - """ - The inverse of the 2D covariance matrices of the Gaussians projected into each image plane. These define the - spatial extent of ellipses for each splatted Gaussian. Note that - since covariance matrices are symmetric, we pack them into a tensor of shape ``(num_projected_gaussians, 3)`` - where each covariance matrix is represented as ``(Cxx, Cxy, Cyy)``. - - Returns: - inv_covar_2d (torch.Tensor): A tensor of shape ``(C, N, D)`` representing the packed inverse 2D covariance matrices, - where ``C`` is the number of image planes, ``N`` is the number of projected Gaussians, and ``D`` is number of feature channels for each - Gaussian (see :attr:`GaussianSplat3d.num_channels`). - """ - return self._impl.conics - - @property - def depths(self) -> torch.Tensor: - """ - Return the depth of each projected Gaussian in each image plane. The depth is defined as the - distance from the camera to the mean of the Gaussian along the camera's viewing direction. - - Returns: - depths (torch.Tensor): A tensor of shape ``(C, N)`` representing the depth of each projected Gaussian, where - ``C`` is the number of image planes, and ``N`` is the number of projected Gaussians. - """ - return self._impl.depths - - @property - def eps_2d(self) -> float: - """ - Return the epsilon value used during the projection of the Gaussian splats to avoid - numerical issues. This value is used to clamp very small radii during projection. - - Returns: - eps_2d (float): The epsilon value used during projection. - """ - return self._impl.eps_2d - - @property - def far_plane(self) -> float: - """ - Return the far plane distance used during the projection of the Gaussian splats. - - Returns: - far_plane (float): The far plane distance. - """ - return self._impl.far_plane - - @property - def image_height(self) -> int: - """ - Return the height of the image planes used during the projection of the Gaussian splats. - - Returns: - image_height (int): The height of the image planes. - """ - return self._impl.image_height - - @property - def image_width(self) -> int: - """ - Return the width of the image planes used during the projection of the Gaussian splats. - - Returns: - image_width (int): The width of the image planes. - """ - return self._impl.image_width - - @property - def means2d(self) -> torch.Tensor: - """ - Return the 2D projected means (in pixel units) of the Gaussians in each image plane. - - Returns: - means2d (torch.Tensor): A tensor of shape ``(C, N, 2)`` representing the 2D projected means, - where ``C`` is the number of image planes, ``N`` is the number of projected Gaussians, - and the last dimension contains the (x, y) coordinates of the means in pixel space. - """ - return self._impl.means2d - - @property - def min_radius_2d(self) -> float: - """ - Return the minimum radius (in pixels) used to clip Gaussians during projection. Gaussians - whose radius projected to less than this value are ignored to avoid numerical issues. - - Returns: - min_radius_2d (float): The minimum radius used during projection. - """ - return self._impl.min_radius_2d - - @property - def near_plane(self) -> float: - """ - Return the near plane distance used during the projection of the Gaussian splats. - - Returns: - near_plane (float): The near plane distance. - """ - return self._impl.near_plane - - @property - def opacities(self) -> torch.Tensor: - """ - Return the opacities of each projected Gaussian in each image plane. - - Returns: - opacities (torch.Tensor): A tensor of shape ``(C, N)`` representing the opacity of each projected Gaussian, where - ``C`` is the number of image planes, and ``N`` is the number of projected Gaussians. - """ - return self._impl.opacities - - @property - def camera_model(self) -> CameraModel: - """ - Return the camera model used during projection. - - Returns: - camera_model (CameraModel): The camera model used during projection. - """ - return GaussianSplat3d._camera_model_from_cpp(self._impl.camera_model) - - @property - def projection_method(self) -> ProjectionMethod: - """ - Return the resolved projection method used during projection. - - Returns: - projection_method (ProjectionMethod): The resolved projection method. - """ - return GaussianSplat3d._projection_method_from_cpp(self._impl.projection_method) - - @property - def radii(self) -> torch.Tensor: - """ - Return the 2D radii (in pixels) of each projected Gaussian in each image plane. The radius of a Gaussian is the maximum extent - of the Gaussian along any direction in the image plane. - - Returns: - radii (torch.Tensor): A tensor of shape ``(C, N)`` representing the 2D radius of each projected Gaussian, where - ``C`` is the number of image planes, and ``N`` is the number of projected Gaussians. - """ - return self._impl.radii - - @property - def render_quantities(self) -> torch.Tensor: - """ - Return the render quantities of each projected Gaussian in each image plane. The render quantities - are used for shading and lighting calculations during rendering. - - Returns: - render_quantities (torch.Tensor): A tensor of shape ``(C, N, D)`` representing the render quantities of each projected Gaussian, - where ``C`` is the number of image planes, ``N`` is the number of projected Gaussians, and ``D`` is the number of feature - channels for each Gaussian (see :attr:`GaussianSplat3d.num_channels`). - """ - return self._impl.render_quantities - - @property - def sh_degree_to_use(self) -> int: - """ - Return the spherical harmonic degree used during the projection of the Gaussian splats. - - .. note:: - - This indicates up to which degree the spherical harmonics coefficients were projected - for each Gaussian. For example, if this value is ``0``, only the diffuse (degree 0) coefficients - were projected. If this value is ``2``, coefficients up to degree 2 were projected. - - Returns: - sh_degree_to_use (int): The spherical harmonic degree used during projection. - """ - return self._impl.sh_degree_to_use - - @property - def tile_gaussian_ids(self) -> torch.Tensor: - """ - Return a tensor containing the ID of each tile/gaussian intersection. - - Returns: - tile_gaussian_ids (torch.Tensor): A tensor of shape ``(M,)`` containing the IDs of the Gaussians. - """ - return self._impl.tile_gaussian_ids - - @property - def tile_offsets(self) -> torch.Tensor: - """ - Return the starting offset of the set of intersections for each tile into :attr:`tile_gaussian_ids`. - - Returns: - tile_offsets (torch.Tensor): A tensor of shape ``(C, TH, TW,)`` where ``C`` is the number of image planes, - ``TH`` is the number of tiles in the height dimension, and ``TW`` is the number of tiles in the width dimension. - """ - return self._impl.tile_offsets - - class GaussianSplat3d: """ An efficient data structure representing a Gaussian splat radiance field in 3D space. @@ -374,15 +159,55 @@ class GaussianSplat3d: __PRIVATE__ = object() + @staticmethod + def _check_gaussian_state( + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + logit_opacities: torch.Tensor, + sh0: torch.Tensor, + shN: torch.Tensor, + ) -> None: + """Validate tensor shapes, devices, and dtypes for Gaussian splat state.""" + N = means.size(0) + if list(means.shape) != [N, 3]: + raise ValueError("means must have shape (N, 3)") + if list(quats.shape) != [N, 4]: + raise ValueError("quats must have shape (N, 4)") + if list(log_scales.shape) != [N, 3]: + raise ValueError("scales must have shape (N, 3)") + if list(logit_opacities.shape) != [N]: + raise ValueError("opacities must have shape (N)") + if sh0.size(0) != N or sh0.size(1) != 1 or sh0.dim() != 3: + raise ValueError("sh0 must have shape (N, 1, D)") + if shN.size(0) != N or shN.dim() != 3: + raise ValueError("shN must have shape (N, K-1, D)") + + device = means.device + if not all(t.device == device for t in (quats, log_scales, logit_opacities, sh0, shN)): + raise ValueError("All tensors must be on the same device") + if not means.is_floating_point(): + raise ValueError("All tensors must be of floating point type") + dtype = means.dtype + if not all(t.dtype == dtype for t in (quats, log_scales, logit_opacities, sh0, shN)): + raise ValueError("All tensors must be of the same type") + def __init__( self, - impl: GaussianSplat3dCpp, + means: torch.Tensor, + quats: torch.Tensor, + log_scales: torch.Tensor, + logit_opacities: torch.Tensor, + sh0: torch.Tensor, + shN: torch.Tensor, + accumulate_mean_2d_gradients: bool = False, + accumulate_max_2d_radii: bool = False, + detach: bool = False, _private: Any = None, ) -> None: """ - Initializes the :class:`GaussianSplat3d` with an existing C++ implementation. - This constructor is used to wrap an existing instance of :class:`GaussianSplat3dCpp`. - It is only called internally within this class and should not be used directly. + Initializes the :class:`GaussianSplat3d` with raw tensors. + This constructor is used internally. You should not call it directly. .. note:: @@ -391,11 +216,43 @@ def __init__( :class:`GaussianSplat3d`. Args: - impl (GaussianSplat3dCpp): An instance of the C++ implementation. + means (torch.Tensor): Tensor of shape ``(N, 3)`` representing the means. + quats (torch.Tensor): Tensor of shape ``(N, 4)`` representing the quaternions. + log_scales (torch.Tensor): Tensor of shape ``(N, 3)`` representing the log scales. + logit_opacities (torch.Tensor): Tensor of shape ``(N,)`` representing the logit opacities. + sh0 (torch.Tensor): Tensor of shape ``(N, 1, D)`` representing the diffuse SH coefficients. + shN (torch.Tensor): Tensor of shape ``(N, K-1, D)`` representing the higher-degree SH coefficients. + accumulate_mean_2d_gradients (bool): If ``True``, track gradient norms. + accumulate_max_2d_radii (bool): If ``True``, track max 2D radii. + detach (bool): If ``True``, detach tensors from the computation graph. + _private (Any): A private object to prevent direct construction. """ if _private is not self.__PRIVATE__: raise ValueError("GaussianSplat3d constructor is private. Use from_tensors or from_ply instead.") - self._impl = impl + + self._check_gaussian_state(means, quats, log_scales, logit_opacities, sh0, shN) + + if detach: + means = means.detach() + quats = quats.detach() + log_scales = log_scales.detach() + logit_opacities = logit_opacities.detach() + sh0 = sh0.detach() + shN = shN.detach() + + self._means = means + self._quats = quats + self._log_scales = log_scales + self._logit_opacities = logit_opacities + self._sh0 = sh0 + self._shN = shN + self._accumulate_mean_2d_gradients = accumulate_mean_2d_gradients + self._accumulate_max_2d_radii = accumulate_max_2d_radii + + # Accumulator tensors -- lazily initialized by projection with_accum + self._accum_grad_norms: torch.Tensor | None = None + self._accum_step_counts: torch.Tensor | None = None + self._accum_max_2d_radii: torch.Tensor | None = None @classmethod def from_tensors( @@ -448,19 +305,16 @@ def from_tensors( detach (bool, optional): If ``True``, creates copies of the input tensors and detaches them from the computation graph. Defaults to ``False``. """ - return GaussianSplat3d( - GaussianSplat3dCpp( - means=means, - quats=quats, - log_scales=log_scales, - logit_opacities=logit_opacities, - sh0=sh0, - shN=shN, - accumulate_mean_2d_gradients=accumulate_mean_2d_gradients, - accumulate_max_2d_radii=accumulate_max_2d_radii, - detach=detach, - ), + means=means, + quats=quats, + log_scales=log_scales, + logit_opacities=logit_opacities, + sh0=sh0, + shN=shN, + accumulate_mean_2d_gradients=accumulate_mean_2d_gradients, + accumulate_max_2d_radii=accumulate_max_2d_radii, + detach=detach, _private=cls.__PRIVATE__, ) @@ -484,9 +338,22 @@ def from_ply( if isinstance(filename, pathlib.Path): filename = str(filename) - gs_impl, metadata = GaussianSplat3dCpp.from_ply(filename=filename, device=device) + means, quats, log_scales, logit_opacities, sh0, shN, metadata = GF.load_gaussian_ply(filename, device) + + result = cls.__new__(cls) + result._means = means + result._quats = quats + result._log_scales = log_scales + result._logit_opacities = logit_opacities + result._sh0 = sh0 + result._shN = shN + result._accumulate_mean_2d_gradients = False + result._accumulate_max_2d_radii = False + result._accum_grad_norms = None + result._accum_step_counts = None + result._accum_max_2d_radii = None - return cls(impl=gs_impl, _private=cls.__PRIVATE__), metadata + return result, metadata @overload def __getitem__(self, index: slice) -> "GaussianSplat3d": ... @@ -528,14 +395,7 @@ def __getitem__(self, index: slice | torch.Tensor) -> "GaussianSplat3d": """ if isinstance(index, slice): - return GaussianSplat3d( - impl=self._impl.slice_select( - index.start if index.start is not None else 0, - index.stop if index.stop is not None else self.num_gaussians, - index.step if index.step is not None else 1, - ), - _private=self.__PRIVATE__, - ) + return self._index_with_slice(index) elif isinstance(index, torch.Tensor): if index.dim() != 1: raise ValueError("Expected 'index' to be a 1D tensor.") @@ -546,14 +406,46 @@ def __getitem__(self, index: slice | torch.Tensor) -> "GaussianSplat3d": f"Expected 'index_or_mask' to have the same length as the number of Gaussians ({self.num_gaussians}), " f"but got {len(index)}." ) - return GaussianSplat3d(impl=self._impl.mask_select(index), _private=self.__PRIVATE__) + return self._index_with_tensor(index) elif index.dtype == torch.int64 or index.dtype == torch.int32: - return GaussianSplat3d(impl=self._impl.index_select(index), _private=self.__PRIVATE__) + return self._index_with_tensor(index) else: raise ValueError("Expected 'index' to be a boolean or integer (int32 or int64) tensor.") else: raise TypeError("Expected 'index' to be a slice or a torch.Tensor.") + def _index_with_tensor(self, index: torch.Tensor) -> "GaussianSplat3d": + """Internal helper: select Gaussians by boolean mask or integer index tensor.""" + result = GaussianSplat3d.__new__(GaussianSplat3d) + result._means = self._means[index] + result._quats = self._quats[index] + result._log_scales = self._log_scales[index] + result._logit_opacities = self._logit_opacities[index] + result._sh0 = self._sh0[index] + result._shN = self._shN[index] + result._accumulate_mean_2d_gradients = self._accumulate_mean_2d_gradients + result._accumulate_max_2d_radii = self._accumulate_max_2d_radii + result._accum_grad_norms = self._accum_grad_norms[index] if self._accum_grad_norms is not None else None + result._accum_step_counts = self._accum_step_counts[index] if self._accum_step_counts is not None else None + result._accum_max_2d_radii = self._accum_max_2d_radii[index] if self._accum_max_2d_radii is not None else None + return result + + def _index_with_slice(self, s: slice) -> "GaussianSplat3d": + """Internal helper: select Gaussians by slice.""" + result = GaussianSplat3d.__new__(GaussianSplat3d) + result._means = self._means[s] + result._quats = self._quats[s] + result._log_scales = self._log_scales[s] + result._logit_opacities = self._logit_opacities[s] + result._sh0 = self._sh0[s] + result._shN = self._shN[s] + result._accumulate_mean_2d_gradients = self._accumulate_mean_2d_gradients + result._accumulate_max_2d_radii = self._accumulate_max_2d_radii + result._accum_grad_norms = self._accum_grad_norms[s] if self._accum_grad_norms is not None else None + result._accum_step_counts = self._accum_step_counts[s] if self._accum_step_counts is not None else None + result._accum_max_2d_radii = self._accum_max_2d_radii[s] if self._accum_max_2d_radii is not None else None + return result + @overload def __setitem__(self, index: slice, value: "GaussianSplat3d") -> None: ... @@ -601,12 +493,7 @@ def __setitem__(self, index: torch.Tensor | slice, value: "GaussianSplat3d") -> Must have the same number of Gaussians as the selected indices or mask. """ if isinstance(index, slice): - self._impl.slice_set( - index.start if index.start is not None else 0, - index.stop if index.stop is not None else self.num_gaussians, - index.step if index.step is not None else 1, - value._impl, - ) + self._set_with_slice(index, value) return elif isinstance(index, torch.Tensor): @@ -619,14 +506,70 @@ def __setitem__(self, index: torch.Tensor | slice, value: "GaussianSplat3d") -> f"Expected 'index' to have the same length as the number of Gaussians ({self.num_gaussians}), " f"but got {len(index)}." ) - self._impl.mask_set(index, value._impl) + self._set_with_tensor(index, value) elif index.dtype == torch.int64 or index.dtype == torch.int32: - self._impl.index_set(index, value._impl) + self._set_with_tensor(index, value) else: raise ValueError("Expected 'index' to be a boolean or integer (int32 or int64) tensor.") else: raise TypeError("Expected 'index' to be a slice or a torch.Tensor") + def _set_with_tensor(self, index: torch.Tensor, value: "GaussianSplat3d") -> None: + """Internal helper: set Gaussians by boolean mask or integer index tensor.""" + # Use index_put (out-of-place) to avoid in-place errors on leaf tensors with requires_grad + self._means = self._means.index_put((index,), value._means) + self._quats = self._quats.index_put((index,), value._quats) + self._log_scales = self._log_scales.index_put((index,), value._log_scales) + self._logit_opacities = self._logit_opacities.index_put((index,), value._logit_opacities) + self._sh0 = self._sh0.index_put((index,), value._sh0) + self._shN = self._shN.index_put((index,), value._shN) + if self._accum_grad_norms is not None: + if value._accum_grad_norms is not None: + self._accum_grad_norms[index] = value._accum_grad_norms + else: + self._accum_grad_norms[index] = 0.0 + if self._accum_step_counts is not None: + if value._accum_step_counts is not None: + self._accum_step_counts[index] = value._accum_step_counts + else: + self._accum_step_counts[index] = 0.0 + if self._accum_max_2d_radii is not None: + if value._accum_max_2d_radii is not None: + self._accum_max_2d_radii[index] = value._accum_max_2d_radii + else: + self._accum_max_2d_radii[index] = 0.0 + + def _set_with_slice(self, s: slice, value: "GaussianSplat3d") -> None: + """Internal helper: set Gaussians by slice.""" + # Use clone + slice assign to avoid in-place errors on leaf tensors with requires_grad + self._means = self._means.clone() + self._means[s] = value._means + self._quats = self._quats.clone() + self._quats[s] = value._quats + self._log_scales = self._log_scales.clone() + self._log_scales[s] = value._log_scales + self._logit_opacities = self._logit_opacities.clone() + self._logit_opacities[s] = value._logit_opacities + self._sh0 = self._sh0.clone() + self._sh0[s] = value._sh0 + self._shN = self._shN.clone() + self._shN[s] = value._shN + if self._accum_grad_norms is not None: + if value._accum_grad_norms is not None: + self._accum_grad_norms[s] = value._accum_grad_norms + else: + self._accum_grad_norms[s] = 0.0 + if self._accum_step_counts is not None: + if value._accum_step_counts is not None: + self._accum_step_counts[s] = value._accum_step_counts + else: + self._accum_step_counts[s] = 0.0 + if self._accum_max_2d_radii is not None: + if value._accum_max_2d_radii is not None: + self._accum_max_2d_radii[s] = value._accum_max_2d_radii + else: + self._accum_max_2d_radii[s] = 0.0 + def detach(self) -> "GaussianSplat3d": """ Return a new :class:`GaussianSplat3d` instance whose tensors are detached from the computation graph. @@ -636,7 +579,19 @@ def detach(self) -> "GaussianSplat3d": gaussian_splat (GaussianSplat3d): A new :class:`GaussianSplat3d` instance whose tensors are detached. """ - return GaussianSplat3d(impl=self._impl.detach(), _private=self.__PRIVATE__) + result = GaussianSplat3d.__new__(GaussianSplat3d) + result._means = self._means.detach() + result._quats = self._quats.detach() + result._log_scales = self._log_scales.detach() + result._logit_opacities = self._logit_opacities.detach() + result._sh0 = self._sh0.detach() + result._shN = self._shN.detach() + result._accumulate_mean_2d_gradients = self._accumulate_mean_2d_gradients + result._accumulate_max_2d_radii = self._accumulate_max_2d_radii + result._accum_grad_norms = self._accum_grad_norms.detach() if self._accum_grad_norms is not None else None + result._accum_step_counts = self._accum_step_counts.detach() if self._accum_step_counts is not None else None + result._accum_max_2d_radii = self._accum_max_2d_radii.detach() if self._accum_max_2d_radii is not None else None + return result def detach_(self) -> None: """ @@ -648,7 +603,18 @@ def detach_(self) -> None: This method modifies the current instance and does not return a new instance. """ - self._impl.detach_in_place() + self._means = self._means.detach() + self._quats = self._quats.detach() + self._log_scales = self._log_scales.detach() + self._logit_opacities = self._logit_opacities.detach() + self._sh0 = self._sh0.detach() + self._shN = self._shN.detach() + if self._accum_grad_norms is not None: + self._accum_grad_norms = self._accum_grad_norms.detach() + if self._accum_step_counts is not None: + self._accum_step_counts = self._accum_step_counts.detach() + if self._accum_max_2d_radii is not None: + self._accum_max_2d_radii = self._accum_max_2d_radii.detach() @staticmethod def cat( @@ -697,11 +663,75 @@ def cat( Returns: GaussianSplat3d: A new instance of GaussianSplat3d containing the concatenated Gaussians. """ - splat_list = [splat._impl for splat in splats] - return GaussianSplat3d( - impl=GaussianSplat3dCpp.cat(splat_list, accumulate_mean_2d_gradients, accumulate_max_2d_radii, detach), - _private=GaussianSplat3d.__PRIVATE__, - ) + result = GaussianSplat3d.__new__(GaussianSplat3d) + + means_list = [s._means for s in splats] + quats_list = [s._quats for s in splats] + log_scales_list = [s._log_scales for s in splats] + logit_opacities_list = [s._logit_opacities for s in splats] + sh0_list = [s._sh0 for s in splats] + shN_list = [s._shN for s in splats] + + result._means = torch.cat(means_list, dim=0) + result._quats = torch.cat(quats_list, dim=0) + result._log_scales = torch.cat(log_scales_list, dim=0) + result._logit_opacities = torch.cat(logit_opacities_list, dim=0) + result._sh0 = torch.cat(sh0_list, dim=0) + result._shN = torch.cat(shN_list, dim=0) + + if detach: + result._means = result._means.detach() + result._quats = result._quats.detach() + result._log_scales = result._log_scales.detach() + result._logit_opacities = result._logit_opacities.detach() + result._sh0 = result._sh0.detach() + result._shN = result._shN.detach() + + result._accumulate_mean_2d_gradients = accumulate_mean_2d_gradients + result._accumulate_max_2d_radii = accumulate_max_2d_radii + + # Handle accumulator concatenation + if accumulate_mean_2d_gradients: + grad_norms_list = [] + step_counts_list = [] + for s in splats: + n = s.num_gaussians + dev = s._means.device + dtype = s._means.dtype + if s._accum_grad_norms is not None: + grad_norms_list.append(s._accum_grad_norms) + else: + grad_norms_list.append(torch.zeros(n, device=dev, dtype=dtype)) + if s._accum_step_counts is not None: + step_counts_list.append(s._accum_step_counts) + else: + step_counts_list.append(torch.zeros(n, device=dev, dtype=torch.int32)) + result._accum_grad_norms = torch.cat(grad_norms_list, dim=0) + result._accum_step_counts = torch.cat(step_counts_list, dim=0) + if detach: + result._accum_grad_norms = result._accum_grad_norms.detach() + result._accum_step_counts = result._accum_step_counts.detach() + else: + result._accum_grad_norms = None + result._accum_step_counts = None + + if accumulate_max_2d_radii: + max_radii_list = [] + for s in splats: + n = s.num_gaussians + dev = s._means.device + dtype = s._means.dtype + if s._accum_max_2d_radii is not None: + max_radii_list.append(s._accum_max_2d_radii) + else: + max_radii_list.append(torch.zeros(n, device=dev, dtype=torch.int32)) + result._accum_max_2d_radii = torch.cat(max_radii_list, dim=0) + if detach: + result._accum_max_2d_radii = result._accum_max_2d_radii.detach() + else: + result._accum_max_2d_radii = None + + return result @classmethod def from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "GaussianSplat3d": @@ -741,7 +771,21 @@ def from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "GaussianSplat3 Returns: gaussian_splat (GaussianSplat3d): An instance of :class:`GaussianSplat3d` initialized with the provided state dictionary. """ - return cls(impl=GaussianSplat3dCpp.from_state_dict(state_dict), _private=cls.__PRIVATE__) + result = cls.__new__(cls) + result._means = state_dict["means"] + result._quats = state_dict["quats"] + result._log_scales = state_dict["log_scales"] + result._logit_opacities = state_dict["logit_opacities"] + result._sh0 = state_dict["sh0"] + result._shN = state_dict["shN"] + result._accumulate_mean_2d_gradients = bool(state_dict["accumulate_mean_2d_gradients"].item()) + result._accumulate_max_2d_radii = bool(state_dict["accumulate_max_2d_radii"].item()) + + result._accum_grad_norms = state_dict.get("accumulated_mean_2d_gradient_norms", None) + result._accum_step_counts = state_dict.get("accumulated_gradient_step_counts", None) + result._accum_max_2d_radii = state_dict.get("accumulated_max_2d_radii", None) + + return result @property def device(self) -> torch.device: @@ -751,7 +795,7 @@ def device(self) -> torch.device: Returns: device (torch.device): The device of this :class:`GaussianSplat3d` instance. """ - return self._impl.device + return self._means.device @property def dtype(self) -> torch.dtype: @@ -762,7 +806,7 @@ def dtype(self) -> torch.dtype: Returns: torch.dtype: The data type of the tensors managed by this :class:`GaussianSplat3d` instance. """ - return self._impl.dtype + return self._means.dtype @property def sh_degree(self) -> int: @@ -779,7 +823,8 @@ def sh_degree(self) -> int: Returns: sh_degree (int): The degree of the spherical harmonics. """ - return self._impl.sh_degree + num_bases = self._shN.size(1) + 1 + return int(math.isqrt(num_bases)) - 1 @property def num_channels(self) -> int: @@ -790,7 +835,7 @@ def num_channels(self) -> int: Returns: num_channels (int): The number of channels. """ - return self._impl.num_channels + return self._shN.size(2) @property def num_gaussians(self) -> int: @@ -801,7 +846,7 @@ def num_gaussians(self) -> int: Returns: num_gaussians (int): The number of Gaussians. """ - return self._impl.num_gaussians + return self._means.size(0) @property def num_sh_bases(self) -> int: @@ -816,7 +861,7 @@ def num_sh_bases(self) -> int: Returns: num_sh_bases (int): The number of spherical harmonics bases. """ - return self._impl.num_sh_bases + return self._shN.size(1) + 1 @property def log_scales(self) -> torch.Tensor: @@ -842,7 +887,7 @@ def log_scales(self) -> torch.Tensor: log_scales (torch.Tensor): A tensor of shape ``(N, 3)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). Each row represents the log of the scale of a Gaussian in 3D space. """ - return self._impl.log_scales + return self._log_scales @log_scales.setter def log_scales(self, value: torch.Tensor) -> None: @@ -870,7 +915,7 @@ def log_scales(self, value: torch.Tensor) -> None: scale of a Gaussian in 3D space. """ - self._impl.log_scales = cast_check(value, torch.Tensor, "log_scales") + self._log_scales = cast_check(value, torch.Tensor, "log_scales") @property def logit_opacities(self) -> torch.Tensor: @@ -887,7 +932,7 @@ def logit_opacities(self) -> torch.Tensor: logit_opacities (torch.Tensor): A tensor of shape ``(N,)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). Each row represents the logit of the opacity of a Gaussian in 3D space. """ - return self._impl.logit_opacities + return self._logit_opacities @logit_opacities.setter def logit_opacities(self, value: torch.Tensor) -> None: @@ -904,7 +949,7 @@ def logit_opacities(self, value: torch.Tensor) -> None: value (torch.Tensor): A tensor of shape ``(N,)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). Each row represents the logit of the opacity of a Gaussian in 3D space. """ - self._impl.logit_opacities = cast_check(value, torch.Tensor, "logit_opacities") + self._logit_opacities = cast_check(value, torch.Tensor, "logit_opacities") @property def means(self) -> torch.Tensor: @@ -926,7 +971,7 @@ def means(self) -> torch.Tensor: torch.Tensor: A tensor of shape (N, 3) where N is the number of Gaussians (see `num_gaussians`). Each row represents the mean of a Gaussian in 3D space. """ - return self._impl.means + return self._means @means.setter def means(self, value: torch.Tensor) -> None: @@ -947,7 +992,7 @@ def means(self, value: torch.Tensor) -> None: value (torch.Tensor): A tensor of shape ``(N, 3)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). Each row represents the mean of a Gaussian in 3D space. """ - self._impl.means = cast_check(value, torch.Tensor, "means") + self._means = cast_check(value, torch.Tensor, "means") @property def quats(self) -> torch.Tensor: @@ -967,7 +1012,7 @@ def quats(self) -> torch.Tensor: quats (torch.Tensor): A tensor of shape ``(N, 4)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). Each row represents the unit quaternion of a Gaussian in 3D space. """ - return self._impl.quats + return self._quats @quats.setter def quats(self, value: torch.Tensor) -> None: @@ -987,7 +1032,7 @@ def quats(self, value: torch.Tensor) -> None: value (torch.Tensor): A tensor of shape ``(N, 4)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). Each row represents the unit quaternion of a Gaussian in 3D space. """ - self._impl.quats = cast_check(value, torch.Tensor, "quats") + self._quats = cast_check(value, torch.Tensor, "quats") @property def requires_grad(self) -> bool: @@ -1012,7 +1057,14 @@ def requires_grad(self) -> bool: Returns: requires_grad (bool): ``True`` if gradients are required, ``False`` otherwise. """ - return self._impl.requires_grad + return ( + self._means.requires_grad + and self._quats.requires_grad + and self._log_scales.requires_grad + and self._logit_opacities.requires_grad + and self._sh0.requires_grad + and self._shN.requires_grad + ) @requires_grad.setter def requires_grad(self, value: bool) -> None: @@ -1037,7 +1089,13 @@ def requires_grad(self, value: bool) -> None: Returns: requires_grad (bool): ``True`` if gradients are required, ``False`` otherwise. """ - self._impl.requires_grad = cast_check(value, bool, "requires_grad") + value = cast_check(value, bool, "requires_grad") + self._means.requires_grad_(value) + self._quats.requires_grad_(value) + self._log_scales.requires_grad_(value) + self._logit_opacities.requires_grad_(value) + self._sh0.requires_grad_(value) + self._shN.requires_grad_(value) @property def sh0(self) -> torch.Tensor: @@ -1050,7 +1108,7 @@ def sh0(self) -> torch.Tensor: of Gaussians (see :attr:`num_gaussians`), and ``D`` is the number of channels (see :attr:`num_channels`). Each row represents the diffuse SH coefficients for a Gaussian. """ - return self._impl.sh0 + return self._sh0 @sh0.setter def sh0(self, value: torch.Tensor) -> None: @@ -1063,7 +1121,7 @@ def sh0(self, value: torch.Tensor) -> None: of Gaussians (see :attr:`num_gaussians`), and ``D`` is the number of channels (see :attr:`num_channels`). Each row represents the diffuse SH coefficients for a Gaussian. """ - self._impl.sh0 = cast_check(value, torch.Tensor, "sh0") + self._sh0 = cast_check(value, torch.Tensor, "sh0") @property def shN(self) -> torch.Tensor: @@ -1077,7 +1135,7 @@ def shN(self) -> torch.Tensor: and K is the number of spherical harmonic bases (see `num_sh_bases`). Each row represents the directionally varying SH coefficients for a Gaussian. """ - return self._impl.shN + return self._shN @shN.setter def shN(self, value: torch.Tensor) -> None: @@ -1091,7 +1149,7 @@ def shN(self, value: torch.Tensor) -> None: and ``K`` is the number of spherical harmonic bases (see :attr:`num_sh_bases`). Each row represents the directionally varying SH coefficients for a Gaussian. """ - self._impl.shN = cast_check(value, torch.Tensor, "shN") + self._shN = cast_check(value, torch.Tensor, "shN") @property def opacities(self) -> torch.Tensor: @@ -1108,7 +1166,7 @@ def opacities(self) -> torch.Tensor: opacities (torch.Tensor): A tensor of shape ``(N,)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). Each element represents the opacity of a Gaussian. """ - return self._impl.opacities + return torch.sigmoid(self._logit_opacities) @property def scales(self) -> torch.Tensor: @@ -1134,10 +1192,10 @@ def scales(self) -> torch.Tensor: scales (torch.Tensor): A tensor of shape ``(N, 3)`` where ``N`` is the number of Gaussians. Each row represents the scale of a Gaussian in 3D space. """ - return self._impl.scales + return torch.exp(self._log_scales) @property - def accumulated_gradient_step_counts(self) -> torch.Tensor: + def accumulated_gradient_step_counts(self) -> torch.Tensor | None: """ Returns the accumulated gradient step counts for each Gaussian. @@ -1155,10 +1213,10 @@ def accumulated_gradient_step_counts(self) -> torch.Tensor: step_counts (torch.Tensor): A tensor of shape ``(N,)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). Each element represents the accumulated gradient step count for a Gaussian. """ - return self._impl.accumulated_gradient_step_counts + return self._accum_step_counts @property - def accumulated_max_2d_radii(self) -> torch.Tensor: + def accumulated_max_2d_radii(self) -> torch.Tensor | None: """ Returns the maximum 2D projected radius (in pixels) for each Gaussian across all calls to `render_*` functions. This is used by certain optimization techniques to ensure that the Gaussians do not become too large or too small during the optimization process. @@ -1177,7 +1235,7 @@ def accumulated_max_2d_radii(self) -> torch.Tensor: Each element represents the maximum 2D radius for a Gaussian across all optimization iterations. """ - return self._impl.accumulated_max_2d_radii + return self._accum_max_2d_radii @property def accumulate_max_2d_radii(self) -> bool: @@ -1193,7 +1251,7 @@ def accumulate_max_2d_radii(self) -> bool: Returns: accumulate_max_radii (bool): ``True`` if the maximum 2D radii are being tracked across rendering calls, ``False`` otherwise. """ - return self._impl.accumulate_max_2d_radii + return self._accumulate_max_2d_radii @accumulate_max_2d_radii.setter def accumulate_max_2d_radii(self, value) -> None: @@ -1208,7 +1266,7 @@ def accumulate_max_2d_radii(self, value) -> None: Args: value (bool): ``True`` if the maximum 2D radii are being tracked across rendering calls, ``False`` otherwise. """ - self._impl.accumulate_max_2d_radii = cast_check(value, bool, "accumulate_max_2d_radii") + self._accumulate_max_2d_radii = cast_check(value, bool, "accumulate_max_2d_radii") @property def accumulate_mean_2d_gradients(self) -> bool: @@ -1231,7 +1289,7 @@ def accumulate_mean_2d_gradients(self) -> bool: Returns: accumulate_mean_2d_grads (bool): ``True`` if the average norm of the gradient of projected means is being tracked, ``False`` otherwise. """ - return self._impl.accumulate_mean_2d_gradients + return self._accumulate_mean_2d_gradients @accumulate_mean_2d_gradients.setter def accumulate_mean_2d_gradients(self, value: bool) -> None: @@ -1254,10 +1312,10 @@ def accumulate_mean_2d_gradients(self, value: bool) -> None: Args: value (bool): ``True`` if the average norm of the gradient of projected means is being tracked, ``False`` otherwise. """ - self._impl.accumulate_mean_2d_gradients = cast_check(value, bool, "accumulate_mean_2d_gradients") + self._accumulate_mean_2d_gradients = cast_check(value, bool, "accumulate_mean_2d_gradients") @property - def accumulated_mean_2d_gradient_norms(self) -> torch.Tensor: + def accumulated_mean_2d_gradient_norms(self) -> torch.Tensor | None: """ Returns the average norm of the gradient of projected (2D) means for each Gaussian across every backward pass. This is used by certain optimization techniques to split/prune/duplicate Gaussians. @@ -1265,239 +1323,35 @@ def accumulated_mean_2d_gradient_norms(self) -> torch.Tensor: .. math:: - \\sum_{t=1}^{T} \\| \\partial_{L_t} \\mu_i^{2D} \\|_2 - - where :math:`\\mu_i^{2D}` is the projection of the mean of Gaussian :math:`g_i` onto the image plane, - and :math:`L_t` is the loss at iteration :math:`t`. - - .. note:: - - To reset the accumulated norms, call the :meth:`reset_accumulated_gradient_state` method. - - Returns: - accumulated_grad_2d_norms (torch.Tensor): A tensor of shape ``(N,)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). - Each element represents the average norm of the gradient of projected means for a Gaussian across all optimization iterations. - The norm is computed in 2D space, i.e., the projected means. - """ - return self._impl.accumulated_mean_2d_gradient_norms - - def project_gaussians_for_depths( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: CameraModel = CameraModel.PINHOLE, - projection_method: ProjectionMethod = ProjectionMethod.AUTO, - distortion_coeffs: torch.Tensor | None = None, - min_radius_2d: float = 0.0, - eps_2d: float = 0.3, - antialias: bool = False, - ) -> ProjectedGaussianSplats: - """ - Projects this :class:`GaussianSplat3d` onto one or more image planes for rendering depth images in those planes. - You can render depth images from the projected Gaussians by calling :meth:`render_projected_gaussians`. - - .. note:: - - The reason to have a separate projection and rendering step is to enable rendering crops of an image without - having to project the Gaussians again. - - - .. note:: - - All images being rendered must have the same width and height. - - - .. seealso:: - - :class:`fvdb.ProjectedGaussianSplats` for the projected Gaussians representation. - - .. code-block:: python - - # Assume gaussian_splat_3d is an instance of GaussianSplat3d - # Project the Gaussians for rendering depth images onto C image planes - projected_gaussians = gaussian_splat_3d.project_gaussians_for_depths( - world_to_camera_matrices, # tensor of shape [C, 4, 4] - projection_matrices, # tensor of shape [C, 3, 3] - image_width, # width of the C images - image_height, # height of the C images - near, # near clipping plane - far) # far clipping plane - - # Now render a crop of size 100x100 starting at (10, 10) from the projected Gaussians - # in each image plane. - # Returns a tensor of shape [C, 100, 100, 1] containing the depth images, - # and a tensor of shape [C, 100, 100, 1] containing the final alpha (opacity) values - # of each pixel. - cropped_depth_images_1, cropped_alphas = gaussian_splat_3d.render_from_projected_gaussians( - projected_gaussians, - crop_width=100, - crop_height=100, - crop_origin_w=10, - crop_origin_h=10) - - # To get the depth images, divide the last channel by the alpha values - true_depths_1 = cropped_images_1[..., -1:] / cropped_alphas - - Args: - world_to_camera_matrices (torch.Tensor): Tensor of shape ``(C, 4, 4)`` representing the world-to-camera transformation matrices for ``C`` cameras. - Each matrix transforms points from world coordinates to camera coordinates. - projection_matrices (torch.Tensor): Tensor of shape ``(C, 3, 3)`` representing the projection matrices for ``C`` cameras. - Each matrix projects points in camera space into homogeneous pixel coordinates. - image_width (int): The width of the images to be rendered. Note that all images must have the same width. - image_height (int): The height of the images to be rendered. Note that all images must have the same height. - near (float): The near clipping plane distance for the projection. - far (float): The far clipping plane distance for the projection. - camera_model (CameraModel): Semantic camera model for projection. Default is - :attr:`fvdb.CameraModel.PINHOLE`. - projection_method (ProjectionMethod): Projection implementation selector. Default is - :attr:`fvdb.ProjectionMethod.AUTO`. - distortion_coeffs (torch.Tensor | None): Distortion coefficients with shape ``(C, 12)``. - Required for :class:`CameraModel.OPENCV_*` camera models. For - :class:`CameraModel.PINHOLE` and :class:`CameraModel.ORTHOGRAPHIC`, pass - ``None`` or a ``(C, 12)`` tensor, which is ignored. To represent no - distortion with an OpenCV camera model, pass a zero-filled tensor. - min_radius_2d (float): The minimum radius (in pixels) below which Gaussians are ignored during rendering. - eps_2d (float): A value used to pad Gaussians when projecting them onto the image plane, to avoid very projected Gaussians which create artifacts and - numerical issues. - antialias (bool): If ``True``, applies opacity correction to the projected Gaussians when using ``eps_2d > 0.0``. - - Returns: - projected_gaussians (ProjectedGaussianSplats): An instance of ProjectedGaussianSplats containing the projected Gaussians. - This object contains the projected 2D representations of the Gaussians, which can be used for rendering depth images or further processing. - - """ - return ProjectedGaussianSplats( - self._impl.project_gaussians_for_depths( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - ), - _private=ProjectedGaussianSplats.__PRIVATE__, - ) - - def project_gaussians_for_images( - self, - world_to_camera_matrices: torch.Tensor, - projection_matrices: torch.Tensor, - image_width: int, - image_height: int, - near: float, - far: float, - camera_model: CameraModel = CameraModel.PINHOLE, - projection_method: ProjectionMethod = ProjectionMethod.AUTO, - distortion_coeffs: torch.Tensor | None = None, - sh_degree_to_use: int = -1, - min_radius_2d: float = 0.0, - eps_2d: float = 0.3, - antialias: bool = False, - ) -> ProjectedGaussianSplats: - """ - Projects this :class:`GaussianSplat3d` onto one or more image planes for rendering multi-channel (see :attr:`num_channels`) images in those planes. - You can render images from the projected Gaussians by calling :meth:`render_projected_gaussians`. - - .. note:: - - The reason to have a separate projection and rendering step is to enable rendering crops of an image without - having to project the Gaussians again. - - - .. note:: - - All images being rendered must have the same width and height. - - - .. seealso:: - - :class:`fvdb.ProjectedGaussianSplats` for the projected Gaussians representation. - - .. code-block:: python - - # Assume gaussian_splat_3d is an instance of GaussianSplat3d - # Project the Gaussians for rendering images onto C image planes - projected_gaussians = gaussian_splat_3d.project_gaussians_for_images( - world_to_camera_matrices, # tensor of shape [C, 4, 4] - projection_matrices, # tensor of shape [C, 3, 3] - image_width, # width of the C images - image_height, # height of the C images - near, # near clipping plane - far) # far clipping plane - - # Now render a crop of size 100x100 starting at (10, 10) from the projected Gaussians - # in each image plane. - # Returns a tensor of shape [C, 100, 100, D] containing the images (where D is num_channels), - # and a tensor of shape [C, 100, 100, 1] containing the final alpha (opacity) values - # of each pixel. - cropped_images_1, cropped_alphas = gaussian_splat_3d.render_from_projected_gaussians( - projected_gaussians, - crop_width=100, - crop_height=100, - crop_origin_w=10, - crop_origin_h=10) - - Args: - world_to_camera_matrices (torch.Tensor): Tensor of shape ``(C, 4, 4)`` representing the world-to-camera transformation matrices for ``C`` cameras. - Each matrix transforms points from world coordinates to camera coordinates. - projection_matrices (torch.Tensor): Tensor of shape ``(C, 3, 3)`` representing the projection matrices for ``C`` cameras. - Each matrix projects points in camera space into homogeneous pixel coordinates. - image_width (int): The width of the images to be rendered. Note that all images must have the same width. - image_height (int): The height of the images to be rendered. Note that all images must have the same height. - near (float): The near clipping plane distance for the projection. - far (float): The far clipping plane distance for the projection. - camera_model (CameraModel): Semantic camera model for projection. Default is - :attr:`fvdb.CameraModel.PINHOLE`. - projection_method (ProjectionMethod): Projection implementation selector. Default is - :attr:`fvdb.ProjectionMethod.AUTO`. - distortion_coeffs (torch.Tensor | None): Distortion coefficients with shape ``(C, 12)``. - Required for :class:`CameraModel.OPENCV_*` camera models. For - :class:`CameraModel.PINHOLE` and :class:`CameraModel.ORTHOGRAPHIC`, pass - ``None`` or a ``(C, 12)`` tensor, which is ignored. To represent no - distortion with an OpenCV camera model, pass a zero-filled tensor. - sh_degree_to_use (int): The degree of spherical harmonics to use for rendering. -1 means use all available SH bases. - 0 means use only the first SH base (constant color). Note that you can't use more SH bases than available in the GaussianSplat3d instance. - Default is -1. - min_radius_2d (float): The minimum radius (in pixels) below which Gaussians are ignored during rendering. - eps_2d (float): A value used to pad Gaussians when projecting them onto the image plane, to avoid very projected Gaussians which create artifacts and - numerical issues. - antialias (bool): If ``True``, applies opacity correction to the projected Gaussians when using ``eps_2d > 0.0``. + \\sum_{t=1}^{T} \\| \\partial_{L_t} \\mu_i^{2D} \\|_2 - Returns: - projected_gaussians (ProjectedGaussianSplats): An instance of ProjectedGaussianSplats containing the projected Gaussians. - This object contains the projected 2D representations of the Gaussians, which can be used for rendering images or further processing. - - """ - return ProjectedGaussianSplats( - self._impl.project_gaussians_for_images( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - sh_degree_to_use=sh_degree_to_use, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - ), - _private=ProjectedGaussianSplats.__PRIVATE__, - ) + where :math:`\\mu_i^{2D}` is the projection of the mean of Gaussian :math:`g_i` onto the image plane, + and :math:`L_t` is the loss at iteration :math:`t`. + + .. note:: + + To reset the accumulated norms, call the :meth:`reset_accumulated_gradient_state` method. - def project_gaussians_for_images_and_depths( + Returns: + accumulated_grad_2d_norms (torch.Tensor): A tensor of shape ``(N,)`` where ``N`` is the number of Gaussians (see :attr:`num_gaussians`). + Each element represents the average norm of the gradient of projected means for a Gaussian across all optimization iterations. + The norm is computed in 2D space, i.e., the projected means. + """ + return self._accum_grad_norms + + def _init_accumulators(self) -> None: + """Lazily initialize gradient/radii accumulators if enabled.""" + N = self._means.size(0) + if self._accumulate_mean_2d_gradients: + if self._accum_grad_norms is None: + self._accum_grad_norms = torch.zeros(N, device=self._means.device, dtype=self._means.dtype) + if self._accum_step_counts is None: + self._accum_step_counts = torch.zeros(N, device=self._means.device, dtype=torch.int32) + if self._accumulate_max_2d_radii: + if self._accum_max_2d_radii is None: + self._accum_max_2d_radii = torch.zeros(N, device=self._means.device, dtype=torch.int32) + + def _project_and_eval_sh( self, world_to_camera_matrices: torch.Tensor, projection_matrices: torch.Tensor, @@ -1505,223 +1359,219 @@ def project_gaussians_for_images_and_depths( image_height: int, near: float, far: float, + render_mode: GaussianRenderMode, camera_model: CameraModel = CameraModel.PINHOLE, projection_method: ProjectionMethod = ProjectionMethod.AUTO, distortion_coeffs: torch.Tensor | None = None, sh_degree_to_use: int = -1, - min_radius_2d: float = 0.0, eps_2d: float = 0.3, + radius_clip: float = 0.0, antialias: bool = False, - ) -> ProjectedGaussianSplats: + ) -> tuple[ProjectedGaussians, torch.Tensor]: + """Project gaussians and evaluate SH features (Stages 1-2). + + Returns ``(projected, features)``. """ - Projects this :class:`GaussianSplat3d` onto one or more image planes for rendering multi-channel (see :attr:`num_channels`) images with depths - in the last channel. - You can render images+depths from the projected Gaussians by calling :meth:`render_projected_gaussians`. + self._init_accumulators() - .. note:: + projected = GF.project_gaussians( + self._means, + self._quats, + self._log_scales, + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + eps_2d=eps_2d, + near=near, + far=far, + radius_clip=radius_clip, + antialias=antialias, + camera_model=camera_model, + projection_method=projection_method, + distortion_coeffs=distortion_coeffs, + accum_grad_norms=self._accum_grad_norms, + accum_step_counts=self._accum_step_counts, + accum_max_radii=self._accum_max_2d_radii, + ) + features = GF.evaluate_gaussian_sh( + self._means, + self._sh0, + self._shN, + world_to_camera_matrices, + projected, + sh_degree_to_use=sh_degree_to_use, + render_mode=render_mode, + ) + return projected, features - The reason to have a separate projection and rendering step is to enable rendering crops of an image without - having to project the Gaussians again. + def project_gaussians( + self, + world_to_camera_matrices: torch.Tensor, + projection_matrices: torch.Tensor, + image_width: int, + image_height: int, + near: float, + far: float, + camera_model: CameraModel = CameraModel.PINHOLE, + projection_method: ProjectionMethod = ProjectionMethod.AUTO, + distortion_coeffs: torch.Tensor | None = None, + min_radius_2d: float = 0.0, + eps_2d: float = 0.3, + antialias: bool = False, + ) -> ProjectedGaussians: + """ + Projects this :class:`GaussianSplat3d` onto one or more image planes. + Returns a :class:`ProjectedGaussians` containing the 2D screen-space + projections. You can then render from this projection using + :meth:`render_from_projected_gaussians`, which is useful for rendering + multiple crops without re-projecting. .. note:: All images being rendered must have the same width and height. - - .. seealso:: - - :class:`fvdb.ProjectedGaussianSplats` for the projected Gaussians representation. + Example: .. code-block:: python - # Assume gaussian_splat_3d is an instance of GaussianSplat3d - # Project the Gaussians for rendering images onto C image planes - projected_gaussians = gaussian_splat_3d.project_gaussians_for_images_and_depths( - world_to_camera_matrices, # tensor of shape [C, 4, 4] - projection_matrices, # tensor of shape [C, 3, 3] - image_width, # width of the C images - image_height, # height of the C images - near, # near clipping plane - far) # far clipping plane - - # Now render a crop of size 100x100 starting at (10, 10) from the projected Gaussians - # in each image plane. - # Returns a tensor of shape [C, 100, 100, D] containing the images (where D is num_channels + 1 for depth), - # and a tensor of shape [C, 100, 100, 1] containing the final alpha (opacity) values - # of each pixel. - cropped_images_1, cropped_alphas = gaussian_splat_3d.render_from_projected_gaussians( - projected_gaussians, - crop_width=100, - crop_height=100, - crop_origin_w=10, - crop_origin_h=10) - - cropped_images = cropped_images_1[..., :-1] # Extract image channels + projected = splat.project_gaussians( + world_to_camera_matrices, # [C, 4, 4] + projection_matrices, # [C, 3, 3] + image_width, image_height, + near, far) - # Divide by alpha to get the final true depth values - cropped_depths = cropped_images_1[..., -1:] / cropped_alphas # Extract depth channel + crop1, alpha1 = splat.render_from_projected_gaussians( + projected, world_to_camera_matrices, + render_mode=GaussianRenderMode.FEATURES, + crop_width=100, crop_height=100, + crop_origin_w=0, crop_origin_h=0) Args: - world_to_camera_matrices (torch.Tensor): Tensor of shape ``(C, 4, 4)`` representing the world-to-camera transformation matrices for ``C`` cameras. - Each matrix transforms points from world coordinates to camera coordinates. - projection_matrices (torch.Tensor): Tensor of shape ``(C, 3, 3)`` representing the projection matrices for ``C`` cameras. - Each matrix projects points in camera space into homogeneous pixel coordinates. - image_width (int): The width of the images to be rendered. Note that all images must have the same width. - image_height (int): The height of the images to be rendered. Note that all images must have the same height. - near (float): The near clipping plane distance for the projection. - far (float): The far clipping plane distance for the projection. - camera_model (CameraModel): Semantic camera model for projection. Default is - :attr:`fvdb.CameraModel.PINHOLE`. - projection_method (ProjectionMethod): Projection implementation selector. Default is - :attr:`fvdb.ProjectionMethod.AUTO`. - distortion_coeffs (torch.Tensor | None): Distortion coefficients with shape ``(C, 12)``. - Required for :class:`CameraModel.OPENCV_*` camera models. For - :class:`CameraModel.PINHOLE` and :class:`CameraModel.ORTHOGRAPHIC`, pass - ``None`` or a ``(C, 12)`` tensor, which is ignored. To represent no - distortion with an OpenCV camera model, pass a zero-filled tensor. - sh_degree_to_use (int): The degree of spherical harmonics to use for rendering. -1 means use all available SH bases. - 0 means use only the first SH base (constant color). Note that you can't use more SH bases than available in the GaussianSplat3d instance. - Default is -1. - min_radius_2d (float): The minimum radius (in pixels) below which Gaussians are ignored during rendering. - eps_2d (float): A value used to pad Gaussians when projecting them onto the image plane, to avoid very projected Gaussians which create artifacts and - numerical issues. - antialias (bool): If ``True``, applies opacity correction to the projected Gaussians when using ``eps_2d > 0.0``. + world_to_camera_matrices (torch.Tensor): ``(C, 4, 4)`` world-to-camera transforms. + projection_matrices (torch.Tensor): ``(C, 3, 3)`` projection matrices. + image_width (int): Width of the target images. + image_height (int): Height of the target images. + near (float): Near clipping plane. + far (float): Far clipping plane. + camera_model (CameraModel): Camera model. Default :attr:`CameraModel.PINHOLE`. + projection_method (ProjectionMethod): Projection selector. Default :attr:`ProjectionMethod.AUTO`. + distortion_coeffs (torch.Tensor | None): ``(C, 12)`` distortion coefficients. + Required for ``OPENCV_*`` camera models. + min_radius_2d (float): Minimum 2D radius in pixels (smaller Gaussians are culled). + eps_2d (float): Padding applied to projected Gaussians. + antialias (bool): Apply opacity correction when ``eps_2d > 0``. Returns: - projected_gaussians (ProjectedGaussianSplats): An instance of ProjectedGaussianSplats containing the projected Gaussians. - This object contains the projected 2D representations of the Gaussians, which can be used for rendering images or further processing. - - """ - return ProjectedGaussianSplats( - self._impl.project_gaussians_for_images_and_depths( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - sh_degree_to_use=sh_degree_to_use, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - ), - _private=ProjectedGaussianSplats.__PRIVATE__, + :class:`ProjectedGaussians` with the screen-space projection. + """ + self._init_accumulators() + return GF.project_gaussians( + self._means, + self._quats, + self._log_scales, + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + eps_2d=eps_2d, + near=near, + far=far, + radius_clip=min_radius_2d, + antialias=antialias, + camera_model=camera_model, + projection_method=projection_method, + distortion_coeffs=distortion_coeffs, + accum_grad_norms=self._accum_grad_norms, + accum_step_counts=self._accum_step_counts, + accum_max_radii=self._accum_max_2d_radii, ) def render_from_projected_gaussians( self, - projected_gaussians: ProjectedGaussianSplats, + projected_gaussians: ProjectedGaussians, + world_to_camera_matrices: torch.Tensor, + render_mode: GaussianRenderMode = GaussianRenderMode.FEATURES, + sh_degree_to_use: int = -1, + tile_size: int = 16, crop_width: int = -1, crop_height: int = -1, crop_origin_w: int = -1, crop_origin_h: int = -1, - tile_size: int = 16, backgrounds: torch.Tensor | None = None, masks: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Render a set of images from Gaussian splats that have already been projected onto image planes - (See for example :meth:`project_gaussians_for_images`). - This method is useful when you want to render images from pre-computed projected Gaussians, - for example, when rendering crops of images without having to re-project the Gaussians. - - .. note:: - - If you want to render the full image, pass negative values for ``crop_width``, ``crop_height``, - ``crop_origin_w``, and ``crop_origin_h`` (default behavior). To render full images, - all these values must be negative or this method will raise an error. - - .. note:: - - If your crop goes beyond the image boundaries, the resulting image will be clipped to - be within the image boundaries. + Render images from pre-projected Gaussians. + This avoids the expensive projection step when rendering multiple crops + of the same scene. SH evaluation, tile intersection, and rasterization + are performed inside this call. Example: .. code-block:: python - # Assume gaussian_splat_3d is an instance of GaussianSplat3d - # Project the Gaussians for rendering images onto C image planes - projected_gaussians = gaussian_splat_3d.project_gaussians_for_images_and_depths( - world_to_camera_matrices, # tensor of shape [C, 4, 4] - projection_matrices, # tensor of shape [C, 3, 3] - image_width, # width of the C images - image_height, # height of the C images - near, # near clipping plane - far) # far clipping plane - - # Now render a crop of size 100x100 starting at (10, 10) from the projected Gaussians - # in each image plane. - # Returns a tensor of shape [C, 100, 100, D] containing the images (where D is num_channels + 1 for depth), - # and a tensor of shape [C, 100, 100, 1] containing the final alpha (opacity) values - # of each pixel. - cropped_images_1, cropped_alphas = gaussian_splat_3d.render_from_projected_gaussians( - projected_gaussians, - crop_width=100, - crop_height=100, - crop_origin_w=10, - crop_origin_h=10) - - cropped_images = cropped_images_1[..., :-1] # Extract image channels - - # Divide by alpha to get the final true depth values - cropped_depths = cropped_images_1[..., -1:] / cropped_alphas # Extract depth channel + projected = splat.project_gaussians(w2c, K, W, H, near, far) + crop, alpha = splat.render_from_projected_gaussians( + projected, w2c, + render_mode=GaussianRenderMode.FEATURES, + crop_width=100, crop_height=100, + crop_origin_w=10, crop_origin_h=10) Args: - projected_gaussians (ProjectedGaussianSplats): An instance of :class:`fvdb.ProjectedGaussianSplats` - containing the projected Gaussians after spherical harmonic evaluation. This object should have been created by calling - :meth:`project_gaussians_for_images`, :meth:`project_gaussians_for_depths`, - :meth:`project_gaussians_for_images_and_depths`, etc. - crop_width (int): The width of the crop to render. If -1, the full image width is used. - Default is -1. - crop_height (int): The height of the crop to render. If -1, the full image height is used. - Default is -1. - crop_origin_w (int): The x-coordinate of the top-left corner of the crop. If -1, the crop starts at (0, 0). - Default is -1. - crop_origin_h (int): The y-coordinate of the top-left corner of the crop. If -1, the crop starts at (0, 0). - Default is -1. - tile_size (int): The size of the tiles to use for rendering. Default is 16. - This parameter controls the size of the tiles used for rendering the images. - You shouldn't set this parameter unless you really know what you are doing. - backgrounds (torch.Tensor | None): Optional background colors of shape ``(C, D)``. - If ``None``, background is treated as 0. - masks (torch.Tensor | None): Optional per-pixel boolean mask of shape ``(C, cropH, cropW)`` - (in crop coordinate space, matching the output dimensions). - ``True`` means render, ``False`` means skip (filled with background). - + projected_gaussians (ProjectedGaussians): Projection from :meth:`project_gaussians`. + world_to_camera_matrices (torch.Tensor): ``(C, 4, 4)`` world-to-camera transforms + (needed for view-dependent SH evaluation). + render_mode (GaussianRenderMode): What to render. Default ``FEATURES``. + sh_degree_to_use (int): SH degree to use (``-1`` = all available). + tile_size (int): Tile size for rasterization. + crop_width (int): Crop width (``-1`` for full image). + crop_height (int): Crop height (``-1`` for full image). + crop_origin_w (int): Crop x-origin (``-1`` for 0). + crop_origin_h (int): Crop y-origin (``-1`` for 0). + backgrounds (torch.Tensor | None): ``(C, D)`` background colors. + masks (torch.Tensor | None): ``(C, cropH, cropW)`` per-pixel boolean mask. Returns: - rendered_images (torch.Tensor): A tensor of shape ``(C, H, W, D)`` where ``C`` is the number of image planes, - ``H`` is the height of the rendered images, ``W`` is the width of the rendered images, and ``D`` is the - number of channels (e.g., RGB, RGBD, etc.). - alpha_images (torch.Tensor): A tensor of shape ``(C, H, W, 1)`` where ``C`` is the number of cameras, - ``H`` is the height of the images, and ``W`` is the width of the images. - Each element represents the alpha value (opacity) at a pixel such that 0 <= alpha < 1, - and 0 means the pixel is fully transparent, and 1 means the pixel is fully opaque. - """ + Tuple of ``(rendered_images, alpha_images)``. + """ + features = GF.evaluate_gaussian_sh( + self._means, + self._sh0, + self._shN, + world_to_camera_matrices, + projected_gaussians, + sh_degree_to_use=sh_degree_to_use, + render_mode=render_mode, + ) + tiles = GF.intersect_gaussian_tiles(projected_gaussians, tile_size=tile_size) tile_masks = _pixel_mask_to_tile_mask(masks, tile_size) if masks is not None else None - features, alphas = self._impl.render_from_projected_gaussians( - projected_gaussians=projected_gaussians._impl, - crop_width=crop_width, - crop_height=crop_height, - crop_origin_w=crop_origin_w, - crop_origin_h=crop_origin_h, - tile_size=tile_size, - backgrounds=backgrounds, - masks=tile_masks, + is_crop = crop_width > 0 or crop_height > 0 or crop_origin_w >= 0 or crop_origin_h >= 0 + crop_rect: tuple[int, int, int, int] | None = None + if is_crop: + crop_rect = ( + crop_origin_w if crop_origin_w >= 0 else 0, + crop_origin_h if crop_origin_h >= 0 else 0, + crop_width if crop_width > 0 else projected_gaussians.image_width, + crop_height if crop_height > 0 else projected_gaussians.image_height, + ) + result, alphas = GF.rasterize_screen_space_gaussians( + projected_gaussians, + features, + self._logit_opacities, + tiles, + backgrounds, + tile_masks, + crop=crop_rect, ) if masks is not None: - features, alphas = _apply_pixel_mask(features, alphas, masks, backgrounds) - - return features, alphas + result, alphas = _apply_pixel_mask(result, alphas, masks, backgrounds) + return result, alphas def render_depths( self, @@ -1805,30 +1655,35 @@ def render_depths( Each element represents the alpha value (opacity) at a pixel such that ``0 <= alpha < 1``, and 0 means the pixel is fully transparent, and 1 means the pixel is fully opaque. """ + projected, render_features = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.DEPTH, + camera_model, + projection_method, + distortion_coeffs, + -1, + eps_2d, + min_radius_2d, + antialias, + ) + tiles = GF.intersect_gaussian_tiles(projected, tile_size=tile_size) tile_masks = _pixel_mask_to_tile_mask(masks, tile_size) if masks is not None else None - - features, alphas = self._impl.render_depths( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - backgrounds=backgrounds, - masks=tile_masks, + result, alphas = GF.rasterize_screen_space_gaussians( + projected, + render_features, + self._logit_opacities, + tiles, + backgrounds, + tile_masks, ) - if masks is not None: - features, alphas = _apply_pixel_mask(features, alphas, masks, backgrounds) - - return features, alphas + result, alphas = _apply_pixel_mask(result, alphas, masks, backgrounds) + return result, alphas def sparse_render_depths( self, @@ -1921,29 +1776,44 @@ def sparse_render_depths( else: raise TypeError("pixels_to_render must be either a torch.Tensor or a fvdb.JaggedTensor") - ret_depths, ret_alphas = self._impl.sparse_render_depths( - pixels_to_render=pixels_to_render_impl, - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - backgrounds=backgrounds, - masks=masks, + projected, render_features = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.DEPTH, + camera_model, + projection_method, + distortion_coeffs, + -1, + eps_2d, + min_radius_2d, + antialias, + ) + pixels_jt = JaggedTensor(impl=pixels_to_render_impl) + sparse_tiles = GF.intersect_gaussian_tiles_sparse(pixels_jt, projected, tile_size=tile_size) + rendered_jt, rendered_alphas_jt = GF.rasterize_screen_space_gaussians_sparse( + projected, + render_features, + self._logit_opacities, + sparse_tiles, + backgrounds, + masks, ) + rendered_colors = rendered_jt.jdata + rendered_alphas = rendered_alphas_jt.jdata + if sparse_tiles.has_duplicates: + rendered_colors = rendered_colors.index_select(0, sparse_tiles.inverse_indices) + rendered_alphas = rendered_alphas.index_select(0, sparse_tiles.inverse_indices) + ret_features = pixels_jt.jagged_like(rendered_colors) + ret_alphas = pixels_jt.jagged_like(rendered_alphas) if isinstance(pixels_to_render, torch.Tensor): - return ret_depths.jdata, ret_alphas.jdata + return ret_features._impl.jdata, ret_alphas._impl.jdata else: - return JaggedTensor(impl=ret_depths), JaggedTensor(impl=ret_alphas) + return ret_features, ret_alphas def render_images( self, @@ -2029,31 +1899,35 @@ def render_images( Each element represents the alpha value (opacity) at a pixel such that ``0 <= alpha < 1``, and 0 means the pixel is fully transparent, and 1 means the pixel is fully opaque. """ + projected, render_features = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.FEATURES, + camera_model, + projection_method, + distortion_coeffs, + sh_degree_to_use, + eps_2d, + min_radius_2d, + antialias, + ) + tiles = GF.intersect_gaussian_tiles(projected, tile_size=tile_size) tile_masks = _pixel_mask_to_tile_mask(masks, tile_size) if masks is not None else None - - features, alphas = self._impl.render_images( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - sh_degree_to_use=sh_degree_to_use, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - backgrounds=backgrounds, - masks=tile_masks, + result, alphas = GF.rasterize_screen_space_gaussians( + projected, + render_features, + self._logit_opacities, + tiles, + backgrounds, + tile_masks, ) - if masks is not None: - features, alphas = _apply_pixel_mask(features, alphas, masks, backgrounds) - - return features, alphas + result, alphas = _apply_pixel_mask(result, alphas, masks, backgrounds) + return result, alphas def render_images_from_world( self, @@ -2149,31 +2023,48 @@ def render_images_from_world( images (torch.Tensor): Rendered images of shape ``(C, H, W, D)``. alpha_images (torch.Tensor): Alpha images of shape ``(C, H, W, 1)``. """ + projected, render_features = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.FEATURES, + camera_model, + projection_method, + distortion_coeffs, + sh_degree_to_use, + eps_2d, + min_radius_2d, + antialias, + ) + tiles = GF.intersect_gaussian_tiles(projected, tile_size=tile_size) tile_masks = _pixel_mask_to_tile_mask(masks, tile_size) if masks is not None else None - - features, alphas = self._impl.render_images_from_world( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - sh_degree_to_use=sh_degree_to_use, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - backgrounds=backgrounds, - masks=tile_masks, + C = world_to_camera_matrices.size(0) + dc = ( + distortion_coeffs + if distortion_coeffs is not None + else torch.empty(C, 0, device=self._means.device, dtype=self._means.dtype) + ) + result, alphas = GF.rasterize_world_space_gaussians( + self._means, + self._quats, + self._log_scales, + projected, + render_features, + self._logit_opacities, + world_to_camera_matrices, + projection_matrices, + dc, + camera_model, + tiles, + backgrounds, + tile_masks, ) - if masks is not None: - features, alphas = _apply_pixel_mask(features, alphas, masks, backgrounds) - - return features, alphas + result, alphas = _apply_pixel_mask(result, alphas, masks, backgrounds) + return result, alphas def render_depths_from_world( self, @@ -2199,30 +2090,48 @@ def render_depths_from_world( This mirrors :meth:`render_images_from_world`, but renders depth-only outputs with the same camera-model and projection-method dispatch. """ + projected, render_features = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.DEPTH, + camera_model, + projection_method, + distortion_coeffs, + -1, + eps_2d, + min_radius_2d, + antialias, + ) + tiles = GF.intersect_gaussian_tiles(projected, tile_size=tile_size) tile_masks = _pixel_mask_to_tile_mask(masks, tile_size) if masks is not None else None - - features, alphas = self._impl.render_depths_from_world( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - backgrounds=backgrounds, - masks=tile_masks, + C = world_to_camera_matrices.size(0) + dc = ( + distortion_coeffs + if distortion_coeffs is not None + else torch.empty(C, 0, device=self._means.device, dtype=self._means.dtype) + ) + result, alphas = GF.rasterize_world_space_gaussians( + self._means, + self._quats, + self._log_scales, + projected, + render_features, + self._logit_opacities, + world_to_camera_matrices, + projection_matrices, + dc, + camera_model, + tiles, + backgrounds, + tile_masks, ) - if masks is not None: - features, alphas = _apply_pixel_mask(features, alphas, masks, backgrounds) - - return features, alphas + result, alphas = _apply_pixel_mask(result, alphas, masks, backgrounds) + return result, alphas def sparse_render_images( self, @@ -2319,30 +2228,44 @@ def sparse_render_images( else: raise TypeError("pixels_to_render must be either a torch.Tensor or a fvdb.JaggedTensor") - ret_features, ret_alphas = self._impl.sparse_render_images( - pixels_to_render=pixels_to_render_impl, - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - sh_degree_to_use=sh_degree_to_use, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - backgrounds=backgrounds, - masks=masks, + projected, render_features = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.FEATURES, + camera_model, + projection_method, + distortion_coeffs, + sh_degree_to_use, + eps_2d, + min_radius_2d, + antialias, + ) + pixels_jt = JaggedTensor(impl=pixels_to_render_impl) + sparse_tiles = GF.intersect_gaussian_tiles_sparse(pixels_jt, projected, tile_size=tile_size) + rendered_jt, rendered_alphas_jt = GF.rasterize_screen_space_gaussians_sparse( + projected, + render_features, + self._logit_opacities, + sparse_tiles, + backgrounds, + masks, ) + rendered_colors = rendered_jt.jdata + rendered_alphas = rendered_alphas_jt.jdata + if sparse_tiles.has_duplicates: + rendered_colors = rendered_colors.index_select(0, sparse_tiles.inverse_indices) + rendered_alphas = rendered_alphas.index_select(0, sparse_tiles.inverse_indices) + ret_features = pixels_jt.jagged_like(rendered_colors) + ret_alphas = pixels_jt.jagged_like(rendered_alphas) if isinstance(pixels_to_render, torch.Tensor): - return ret_features.jdata, ret_alphas.jdata + return ret_features._impl.jdata, ret_alphas._impl.jdata else: - return JaggedTensor(impl=ret_features), JaggedTensor(impl=ret_alphas) + return ret_features, ret_alphas def sparse_render_images_and_depths( self, @@ -2440,30 +2363,44 @@ def sparse_render_images_and_depths( else: raise TypeError("pixels_to_render must be either a torch.Tensor or a fvdb.JaggedTensor") - ret_features, ret_alphas = self._impl.sparse_render_images_and_depths( - pixels_to_render=pixels_to_render_impl, - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - sh_degree_to_use=sh_degree_to_use, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - backgrounds=backgrounds, - masks=masks, + projected, render_features = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.FEATURES_AND_DEPTH, + camera_model, + projection_method, + distortion_coeffs, + sh_degree_to_use, + eps_2d, + min_radius_2d, + antialias, + ) + pixels_jt = JaggedTensor(impl=pixels_to_render_impl) + sparse_tiles = GF.intersect_gaussian_tiles_sparse(pixels_jt, projected, tile_size=tile_size) + rendered_jt, rendered_alphas_jt = GF.rasterize_screen_space_gaussians_sparse( + projected, + render_features, + self._logit_opacities, + sparse_tiles, + backgrounds, + masks, ) + rendered_colors = rendered_jt.jdata + rendered_alphas = rendered_alphas_jt.jdata + if sparse_tiles.has_duplicates: + rendered_colors = rendered_colors.index_select(0, sparse_tiles.inverse_indices) + rendered_alphas = rendered_alphas.index_select(0, sparse_tiles.inverse_indices) + ret_features = pixels_jt.jagged_like(rendered_colors) + ret_alphas = pixels_jt.jagged_like(rendered_alphas) if isinstance(pixels_to_render, torch.Tensor): - return ret_features.jdata, ret_alphas.jdata + return ret_features._impl.jdata, ret_alphas._impl.jdata else: - return JaggedTensor(impl=ret_features), JaggedTensor(impl=ret_alphas) + return ret_features, ret_alphas def render_images_and_depths( self, @@ -2552,31 +2489,35 @@ def render_images_and_depths( Each element represents the alpha value (opacity) at a pixel such that ``0 <= alpha < 1``, and 0 means the pixel is fully transparent, and 1 means the pixel is fully opaque. """ + projected, render_features = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.FEATURES_AND_DEPTH, + camera_model, + projection_method, + distortion_coeffs, + sh_degree_to_use, + eps_2d, + min_radius_2d, + antialias, + ) + tiles = GF.intersect_gaussian_tiles(projected, tile_size=tile_size) tile_masks = _pixel_mask_to_tile_mask(masks, tile_size) if masks is not None else None - - features, alphas = self._impl.render_images_and_depths( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - sh_degree_to_use=sh_degree_to_use, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - backgrounds=backgrounds, - masks=tile_masks, + result, alphas = GF.rasterize_screen_space_gaussians( + projected, + render_features, + self._logit_opacities, + tiles, + backgrounds, + tile_masks, ) - if masks is not None: - features, alphas = _apply_pixel_mask(features, alphas, masks, backgrounds) - - return features, alphas + result, alphas = _apply_pixel_mask(result, alphas, masks, backgrounds) + return result, alphas def render_images_and_depths_from_world( self, @@ -2603,31 +2544,48 @@ def render_images_and_depths_from_world( This mirrors :meth:`render_images_from_world`, but returns image channels with depth in the final channel while using the same camera-model and projection-method dispatch. """ + projected, render_features = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.FEATURES_AND_DEPTH, + camera_model, + projection_method, + distortion_coeffs, + sh_degree_to_use, + eps_2d, + min_radius_2d, + antialias, + ) + tiles = GF.intersect_gaussian_tiles(projected, tile_size=tile_size) tile_masks = _pixel_mask_to_tile_mask(masks, tile_size) if masks is not None else None - - features, alphas = self._impl.render_images_and_depths_from_world( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - sh_degree_to_use=sh_degree_to_use, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - backgrounds=backgrounds, - masks=tile_masks, + C = world_to_camera_matrices.size(0) + dc = ( + distortion_coeffs + if distortion_coeffs is not None + else torch.empty(C, 0, device=self._means.device, dtype=self._means.dtype) + ) + result, alphas = GF.rasterize_world_space_gaussians( + self._means, + self._quats, + self._log_scales, + projected, + render_features, + self._logit_opacities, + world_to_camera_matrices, + projection_matrices, + dc, + camera_model, + tiles, + backgrounds, + tile_masks, ) - if masks is not None: - features, alphas = _apply_pixel_mask(features, alphas, masks, backgrounds) - - return features, alphas + result, alphas = _apply_pixel_mask(result, alphas, masks, backgrounds) + return result, alphas def render_num_contributing_gaussians( self, @@ -2705,21 +2663,24 @@ def render_num_contributing_gaussians( Each element represents the alpha value (opacity) at a pixel such that ``0 <= alpha < 1``, and 0 means the pixel is fully transparent, and 1 means the pixel is fully opaque. """ - return self._impl.render_num_contributing_gaussians( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, + projected, _ = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.DEPTH, + camera_model, + projection_method, + distortion_coeffs, + -1, + eps_2d, + min_radius_2d, + antialias, ) + tiles = GF.intersect_gaussian_tiles(projected, tile_size=tile_size) + return GF.count_contributing_gaussians(projected, self._logit_opacities, tiles) @overload def sparse_render_num_contributing_gaussians( @@ -2827,51 +2788,42 @@ def sparse_render_num_contributing_gaussians( C, R, _ = pixels_to_render.shape tensors = [pixels_to_render[i] for i in range(C)] pixels_to_render_jagged = JaggedTensor(tensors) + else: + pixels_to_render_jagged = pixels_to_render + + projected, _ = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.DEPTH, + camera_model, + projection_method, + distortion_coeffs, + -1, + eps_2d, + min_radius_2d, + antialias, + ) + sparse_tiles = GF.intersect_gaussian_tiles_sparse( + pixels_to_render_jagged, + projected, + tile_size=tile_size, + ) + result_num, result_alphas = GF.count_contributing_gaussians_sparse( + projected, + self._logit_opacities, + sparse_tiles, + ) - result_num_contributing_gaussians, result_alphas = self._impl.sparse_render_num_contributing_gaussians( - pixels_to_render=pixels_to_render_jagged._impl, - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - ) - - num_contributing_gaussians_list = result_num_contributing_gaussians.unbind() - alphas_list = result_alphas.unbind() - dense_num_contributing_gaussians = torch.stack(num_contributing_gaussians_list, dim=0) # type: ignore # Shape: (C, R) - dense_alphas = torch.stack(alphas_list, dim=0) # type: ignore # Shape: (C, R) - - return dense_num_contributing_gaussians, dense_alphas + if isinstance(pixels_to_render, torch.Tensor): + num_list = cast(list[torch.Tensor], result_num.unbind()) + alphas_list = cast(list[torch.Tensor], result_alphas.unbind()) + return torch.stack(num_list, dim=0), torch.stack(alphas_list, dim=0) else: - # Already a JaggedTensor, call C++ implementation directly - result_num_contributing_gaussians_impl, result_alphas_impl = ( - self._impl.sparse_render_num_contributing_gaussians( - pixels_to_render=pixels_to_render._impl, - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - ) - ) - return JaggedTensor(impl=result_num_contributing_gaussians_impl), JaggedTensor(impl=result_alphas_impl) + return result_num, result_alphas def render_contributing_gaussian_ids( self, @@ -2927,23 +2879,29 @@ def render_contributing_gaussian_ids( jagged tensor containing the weights of the contributing Gaussians of each rendered pixel for each camera. The weights are in row-major order and sum to 1 for each pixel if that pixel is opaque (alpha=1). """ - ids, weights = self._impl.render_contributing_gaussian_ids( - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, + projected, _ = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.DEPTH, + camera_model, + projection_method, + distortion_coeffs, + -1, + eps_2d, + min_radius_2d, + antialias, + ) + tiles = GF.intersect_gaussian_tiles(projected, tile_size=tile_size) + return GF.identify_contributing_gaussians( + projected, + self._logit_opacities, + tiles, top_k_contributors=top_k_contributors, ) - return JaggedTensor(impl=ids), JaggedTensor(impl=weights) @overload def sparse_render_contributing_gaussian_ids( @@ -3049,47 +3007,36 @@ def sparse_render_contributing_gaussian_ids( if isinstance(pixels_to_render, torch.Tensor): C, R, _ = pixels_to_render.shape tensors = [pixels_to_render[i] for i in range(C)] - pixels_to_render_jagged = JaggedTensor(tensors) - - result_ids, result_weights = self._impl.sparse_render_contributing_gaussian_ids( - pixels_to_render=pixels_to_render_jagged._impl, - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - top_k_contributors=top_k_contributors, - ) - - return JaggedTensor(impl=result_ids), JaggedTensor(impl=result_weights) + pixels_jt = JaggedTensor(tensors) else: - # Already a JaggedTensor, call C++ implementation directly - result_ids_impl, result_weights_impl = self._impl.sparse_render_contributing_gaussian_ids( - pixels_to_render=pixels_to_render._impl, - world_to_camera_matrices=world_to_camera_matrices, - projection_matrices=projection_matrices, - image_width=image_width, - image_height=image_height, - near=near, - far=far, - camera_model=self._camera_model_to_cpp(camera_model), - projection_method=self._projection_method_to_cpp(projection_method), - distortion_coeffs=distortion_coeffs, - tile_size=tile_size, - min_radius_2d=min_radius_2d, - eps_2d=eps_2d, - antialias=antialias, - top_k_contributors=top_k_contributors, - ) - return JaggedTensor(impl=result_ids_impl), JaggedTensor(impl=result_weights_impl) + pixels_jt = pixels_to_render + + projected, _ = self._project_and_eval_sh( + world_to_camera_matrices, + projection_matrices, + image_width, + image_height, + near, + far, + GaussianRenderMode.DEPTH, + camera_model, + projection_method, + distortion_coeffs, + -1, + eps_2d, + min_radius_2d, + antialias, + ) + sparse_tiles = GF.intersect_gaussian_tiles_sparse( + pixels_jt, + projected, + tile_size=tile_size, + ) + return GF.identify_contributing_gaussians_sparse( + projected, + self._logit_opacities, + sparse_tiles, + ) def relocate_gaussians( self, @@ -3113,7 +3060,7 @@ def relocate_gaussians( Returns: tuple[torch.Tensor, torch.Tensor]: Tuple of (logit_opacities_new [N], log_scales_new [N, 3]). """ - return self._impl.relocate_gaussians( + return GF.relocate_gaussians( log_scales, logit_opacities, ratios, @@ -3131,7 +3078,15 @@ def add_noise_to_means(self, noise_scale: float, t: float = 0.005, k: float = 10 t (float): Parameter t for noise scaling. Defaults to 0.005. k (float): Parameter k for noise scaling. Defaults to 100.0. """ - self._impl.add_noise_to_means(noise_scale, t, k) + GF.add_noise_to_gaussian_means( + self._means, + self._log_scales, + self._logit_opacities, + self._quats, + noise_scale, + t, + k, + ) def reset_accumulated_gradient_state(self) -> None: """ @@ -3150,10 +3105,17 @@ def reset_accumulated_gradient_state(self) -> None: for the actual accumulated state being reset. """ - self._impl.reset_accumulated_gradient_state() + if self._accum_grad_norms is not None: + self._accum_grad_norms.zero_() + if self._accum_step_counts is not None: + self._accum_step_counts.zero_() + if self._accum_max_2d_radii is not None: + self._accum_max_2d_radii.zero_() def save_ply( - self, filename: pathlib.Path | str, metadata: Mapping[str, str | int | float | torch.Tensor] | None = None + self, + filename: pathlib.Path | str, + metadata: Mapping[str, str | int | float | torch.Tensor] | None = None, ) -> None: """ Save this :class:`GaussianSplat3d` to a PLY file. and include any metadata provided. @@ -3165,7 +3127,16 @@ def save_ply( """ if isinstance(filename, pathlib.Path): filename = str(filename) - self._impl.save_ply(filename, metadata) # type: ignore -- mapping to dict is fine here + GF.save_gaussian_ply( + self._means, + self._quats, + self._log_scales, + self._logit_opacities, + self._sh0, + self._shN, + filename, + dict(metadata) if metadata is not None else metadata, + ) @overload def to(self, dtype: torch.dtype | None = None) -> "GaussianSplat3d": ... @@ -3228,6 +3199,10 @@ def to( gaussian_splat_3d (GaussianSplat3d): A new instance of :class:`GaussianSplat3d` with the specified device and/or data type. """ + # Initialize to satisfy pyright; all valid paths below assign both. + device: torch.device | None = None + dtype: torch.dtype | None = None + # All values passed by keyword arguments if len(args) == 0: if len(kwargs) == 1: @@ -3285,13 +3260,26 @@ def to( device = resolve_device(device, inherit_from=self) dtype = self.dtype if dtype is None else cast_check(dtype, torch.dtype, "dtype") - return GaussianSplat3d( - impl=self._impl.to( - device=device, - dtype=dtype, - ), - _private=GaussianSplat3d.__PRIVATE__, + result = GaussianSplat3d.__new__(GaussianSplat3d) + result._means = self._means.to(device=device, dtype=dtype) + result._quats = self._quats.to(device=device, dtype=dtype) + result._log_scales = self._log_scales.to(device=device, dtype=dtype) + result._logit_opacities = self._logit_opacities.to(device=device, dtype=dtype) + result._sh0 = self._sh0.to(device=device, dtype=dtype) + result._shN = self._shN.to(device=device, dtype=dtype) + result._accumulate_mean_2d_gradients = self._accumulate_mean_2d_gradients + result._accumulate_max_2d_radii = self._accumulate_max_2d_radii + # Grad norms follow the main dtype; step counts and max radii are int32 (device-only move) + result._accum_grad_norms = ( + self._accum_grad_norms.to(device=device, dtype=dtype) if self._accum_grad_norms is not None else None + ) + result._accum_step_counts = ( + self._accum_step_counts.to(device=device) if self._accum_step_counts is not None else None ) + result._accum_max_2d_radii = ( + self._accum_max_2d_radii.to(device=device) if self._accum_max_2d_radii is not None else None + ) + return result def set_state( self, @@ -3324,14 +3312,15 @@ def set_state( ``D`` is the number of channels (see :attr:`num_channels`), and ``K`` is the number of spherical harmonic bases (see :attr:`num_sh_bases`). """ - self._impl.set_state( - means=means, - quats=quats, - log_scales=log_scales, - logit_opacities=logit_opacities, - sh0=sh0, - shN=shN, - ) + self._check_gaussian_state(means, quats, log_scales, logit_opacities, sh0, shN) + self._means = means + self._quats = quats + self._log_scales = log_scales + self._logit_opacities = logit_opacities + self._sh0 = sh0 + self._shN = shN + # Reset accumulators when state changes (matches C++ behavior) + self.reset_accumulated_gradient_state() def state_dict(self) -> dict[str, torch.Tensor]: """ @@ -3370,30 +3359,20 @@ def state_dict(self) -> dict[str, torch.Tensor]: state_dict (dict[str, torch.Tensor]): A dictionary containing the state of the :class:`GaussianSplat3d` instance. """ - return self._impl.state_dict() - - @staticmethod - def _camera_model_from_cpp(camera_model: _C.CameraModel) -> CameraModel: - try: - return CameraModel[camera_model.name] - except KeyError as exc: - raise ValueError(f"Invalid camera model: {camera_model}") from exc - - @staticmethod - def _camera_model_to_cpp(camera_model: CameraModel) -> _C.CameraModel: - if isinstance(camera_model, CameraModel): - return getattr(_C.CameraModel, camera_model.name) - return camera_model - - @staticmethod - def _projection_method_from_cpp(projection_method: _C.ProjectionMethod) -> ProjectionMethod: - try: - return ProjectionMethod[projection_method.name] - except KeyError as exc: - raise ValueError(f"Invalid projection method: {projection_method}") from exc - - @staticmethod - def _projection_method_to_cpp(projection_method: ProjectionMethod) -> _C.ProjectionMethod: - if isinstance(projection_method, ProjectionMethod): - return getattr(_C.ProjectionMethod, projection_method.name) - return projection_method + sd: dict[str, torch.Tensor] = { + "means": self._means, + "quats": self._quats, + "log_scales": self._log_scales, + "logit_opacities": self._logit_opacities, + "sh0": self._sh0, + "shN": self._shN, + "accumulate_mean_2d_gradients": torch.tensor(self._accumulate_mean_2d_gradients), + "accumulate_max_2d_radii": torch.tensor(self._accumulate_max_2d_radii), + } + if self._accum_grad_norms is not None: + sd["accumulated_mean_2d_gradient_norms"] = self._accum_grad_norms + if self._accum_step_counts is not None: + sd["accumulated_gradient_step_counts"] = self._accum_step_counts + if self._accum_max_2d_radii is not None: + sd["accumulated_max_2d_radii"] = self._accum_max_2d_radii + return sd diff --git a/fvdb/types.py b/fvdb/types.py index 9b2d8a4e6..0590a0ba5 100644 --- a/fvdb/types.py +++ b/fvdb/types.py @@ -72,7 +72,12 @@ def is_Vec3i(x: Any) -> bool: if isinstance(x, torch.Size): return len(x) == 3 if isinstance(x, (torch.Tensor, numpy.ndarray)): - return x.shape == (3,) and x.dtype in (torch.int32, torch.int64, numpy.int32, numpy.int64) + return x.shape == (3,) and x.dtype in ( + torch.int32, + torch.int64, + numpy.int32, + numpy.int64, + ) if isinstance(x, list): return len(x) == 3 and all(isinstance(i, int) for i in x) if isinstance(x, tuple): @@ -108,7 +113,12 @@ def is_Vec3iOrScalar(x: Any) -> bool: def is_Vec4i(x: Any) -> bool: if isinstance(x, (torch.Tensor, numpy.ndarray)): - return x.shape == (4,) and x.dtype in (torch.int32, torch.int64, numpy.int32, numpy.int64) + return x.shape == (4,) and x.dtype in ( + torch.int32, + torch.int64, + numpy.int32, + numpy.int64, + ) if isinstance(x, list): return len(x) == 4 and all(isinstance(i, int) for i in x) if isinstance(x, tuple): @@ -144,7 +154,14 @@ def is_Vec3dBatch(x: Any) -> bool: return ( len(x.shape) >= 1 and x.shape[-1] == 3 - and x.dtype in (torch.float16, torch.float32, torch.float64, numpy.float32, numpy.float64) + and x.dtype + in ( + torch.float16, + torch.float32, + torch.float64, + numpy.float32, + numpy.float64, + ) ) if isinstance(x, list): if len(x) == 0: @@ -262,11 +279,15 @@ def is_NumericScalar(x: Any) -> TypeGuard[NumericScalar]: return is_NumericScalarNative(x) or (isinstance(x, (torch.Tensor, numpy.ndarray)) and x.ndim == 0) -def is_SequenceOfNumericScalarNative(x: Any) -> TypeGuard[Sequence[NumericScalarNative]]: +def is_SequenceOfNumericScalarNative( + x: Any, +) -> TypeGuard[Sequence[NumericScalarNative]]: return isinstance(x, Sequence) and all(is_NumericScalarNative(item) for item in x) -def is_SequenceOfSequenceOfNumericScalarNative(x: Any) -> TypeGuard[Sequence[Sequence[NumericScalarNative]]]: +def is_SequenceOfSequenceOfNumericScalarNative( + x: Any, +) -> TypeGuard[Sequence[Sequence[NumericScalarNative]]]: return isinstance(x, Sequence) and all(is_SequenceOfNumericScalarNative(item) for item in x) @@ -472,7 +493,12 @@ def to_GenericTensorBroadcastableRank1( if is_NumericScalar(x): result = to_GenericScalar( - x, dtype, allowed_torch_dtypes, allowed_numpy_dtypes, dtype_category, value_constraint=ValueConstraint.NONE + x, + dtype, + allowed_torch_dtypes, + allowed_numpy_dtypes, + dtype_category, + value_constraint=ValueConstraint.NONE, ) try: result_shape = torch.broadcast_shapes(result.shape, test_shape) @@ -762,7 +788,9 @@ def to_GenericTensorBroadcastableRank3( def to_IntegerScalar( - x: NumericScalar, dtype: torch.dtype = torch.int64, value_constraint: ValueConstraint = ValueConstraint.NONE + x: NumericScalar, + dtype: torch.dtype = torch.int64, + value_constraint: ValueConstraint = ValueConstraint.NONE, ) -> torch.Tensor: """ Converts a NumericScalar to an integer scalar tensor. @@ -789,7 +817,9 @@ def to_IntegerScalar( def to_FloatingScalar( - x: NumericScalar, dtype: torch.dtype = torch.float32, value_constraint: ValueConstraint = ValueConstraint.NONE + x: NumericScalar, + dtype: torch.dtype = torch.float32, + value_constraint: ValueConstraint = ValueConstraint.NONE, ) -> torch.Tensor: """ Converts a NumericScalar to a floating scalar tensor. @@ -808,8 +838,22 @@ def to_FloatingScalar( return to_GenericScalar( x, dtype, - allowed_torch_dtypes=(torch.int32, torch.int64, torch.float16, torch.float32, torch.float64), - allowed_numpy_dtypes=(np.int32, np.int64, np.uint32, np.uint64, np.float16, np.float32, np.float64), + allowed_torch_dtypes=( + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + ), + allowed_numpy_dtypes=( + np.int32, + np.int64, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + ), dtype_category="int or float", value_constraint=value_constraint, ) @@ -880,8 +924,22 @@ def to_FloatingTensorBroadcastableRank1( x, test_shape, dtype, - allowed_torch_dtypes=(torch.int32, torch.int64, torch.float16, torch.float32, torch.float64), - allowed_numpy_dtypes=(np.int32, np.int64, np.uint32, np.uint64, np.float16, np.float32, np.float64), + allowed_torch_dtypes=( + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + ), + allowed_numpy_dtypes=( + np.int32, + np.int64, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + ), dtype_category="int or float", value_constraint=value_constraint, do_broadcast_to=do_broadcast_to, @@ -988,8 +1046,22 @@ def to_FloatingTensorBroadcastableRank2( x, test_shape, dtype, - allowed_torch_dtypes=(torch.int32, torch.int64, torch.float16, torch.float32, torch.float64), - allowed_numpy_dtypes=(np.int32, np.int64, np.uint32, np.uint64, np.float16, np.float32, np.float64), + allowed_torch_dtypes=( + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + ), + allowed_numpy_dtypes=( + np.int32, + np.int64, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + ), dtype_category="int or float", value_constraint=value_constraint, do_broadcast_to=do_broadcast_to, @@ -1024,8 +1096,22 @@ def to_FloatingTensorBroadcastableRank3( x, test_shape, dtype, - allowed_torch_dtypes=(torch.int32, torch.int64, torch.float16, torch.float32, torch.float64), - allowed_numpy_dtypes=(np.int32, np.int64, np.uint32, np.uint64, np.float16, np.float32, np.float64), + allowed_torch_dtypes=( + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + ), + allowed_numpy_dtypes=( + np.int32, + np.int64, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + ), dtype_category="int or float", value_constraint=value_constraint, do_broadcast_to=do_broadcast_to, diff --git a/fvdb/utils/metrics/__init__.py b/fvdb/utils/metrics/__init__.py index 00bf3ba88..5102335fa 100644 --- a/fvdb/utils/metrics/__init__.py +++ b/fvdb/utils/metrics/__init__.py @@ -1,9 +1,7 @@ # Copyright Contributors to the OpenVDB Project # SPDX-License-Identifier: Apache-2.0 -# Package exports for fvdb.utils.metrics - -from fvdb.utils.metrics.psnr import psnr -from fvdb.utils.metrics.ssim import ssim +# Backward-compatible re-exports; canonical location is fvdb.functional. +from fvdb.functional._metrics import psnr, ssim __all__ = ["psnr", "ssim"] diff --git a/fvdb/utils/metrics/psnr.py b/fvdb/utils/metrics/psnr.py deleted file mode 100644 index 7767944bb..000000000 --- a/fvdb/utils/metrics/psnr.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright Contributors to the OpenVDB Project -# SPDX-License-Identifier: Apache-2.0 - -import math -from typing import Literal - -import torch - - -def psnr( - noisy_images: torch.Tensor, - ground_truth_images: torch.Tensor, - max_value: float = 1.0, - reduction: Literal["none", "mean", "sum"] = "mean", -) -> torch.Tensor: - """ - Compute the Peak-Signal-to-Noise-Ratio (PSNR) ratio between two batches of images. - - Args: - noisy_images (torch.Tensor): A batch of noisy images of shape ``(B, C, H, W)`` - ground_truth_images (torch.Tensor): A batch of ground truth images of shape ``(B, C, H, W)`` - max_value (float): The maximum possible value images computed with this loss can have. - Default is 1.0. - reduction (Literal["none", "mean", "sum"]): How to reduce over the batch dimension. ``"sum"`` - and ``"mean"`` will add-up and average the losses across the batch respectively. ``"none"`` will - return each loss as a separate entry in the tensor. Default is ``"mean"``. - - Returns: - psnr (torch.Tensor): The PSNR between the two images. If reduction is not "none", the result - will be reduced over the batch dimension (*i.e.* will be a single scalar), otherwise it will - be a tensor of shape ``(B,)``. - """ - if max_value <= 0: - raise ValueError("max_value must be a positive number") - - if reduction not in ("none", "mean", "sum"): - raise ValueError("reduction must be one of ('none', 'mean', 'sum')") - - if (noisy_images.shape != ground_truth_images.shape) or (noisy_images.dim() != 4): - raise ValueError("Input images must have the same shape and be 4-dimensional with shape (B, C, H, W)") - - mse = torch.mean((noisy_images - ground_truth_images) ** 2, dim=(1, 2, 3)) # [B] - - # Expand log of ratio to difference of logs for better stability - psnr = 10.0 * (2.0 * math.log10(max_value) - torch.log10(mse)) - if reduction == "none": - return psnr - elif reduction == "mean": - return torch.mean(psnr) - elif reduction == "sum": - return torch.sum(psnr) diff --git a/fvdb/viz/_gaussian_splat_3d_view.py b/fvdb/viz/_gaussian_splat_3d_view.py index 3d23035cf..34d1cde87 100644 --- a/fvdb/viz/_gaussian_splat_3d_view.py +++ b/fvdb/viz/_gaussian_splat_3d_view.py @@ -4,8 +4,8 @@ from enum import Enum from typing import Any -from .._fvdb_cpp import GaussianSplat3d as GaussianSplat3dCpp from .._fvdb_cpp import GaussianSplat3dView as GaussianSplat3dViewCpp +import torch from ._viewer_server import _get_viewer_server_cpp @@ -34,7 +34,7 @@ def __init__( self, scene_name: str, name: str, - gaussian_splat_3d: GaussianSplat3dCpp, + gaussian_splat_3d, # GaussianSplat3d (Python class) tile_size: int = 16, min_radius_2d: float = 0.0, eps_2d: float = 0.3, @@ -69,9 +69,23 @@ def __init__( self._scene_name = scene_name self._name = name server = _get_viewer_server_cpp() - view = server.add_gaussian_splat_3d_view(scene_name=scene_name, name=name, gaussian_splat_3d=gaussian_splat_3d) - - if sh_ordering_mode not in (ShOrderingMode.RGB_RGB_RGB, ShOrderingMode.RRR_GGG_BBB): + # Pass raw tensors to the C++ viewer (no longer depends on C++ GaussianSplat3d class) + gs = gaussian_splat_3d + view = server.add_gaussian_splat_3d_view( + scene_name=scene_name, + name=name, + means=gs.means, + quats=gs.quats, + log_scales=gs.log_scales, + logit_opacities=gs.logit_opacities, + sh0=gs.sh0, + shN=gs.shN, + ) + + if sh_ordering_mode not in ( + ShOrderingMode.RGB_RGB_RGB, + ShOrderingMode.RRR_GGG_BBB, + ): raise ValueError(f"Invalid ShOrderingMode: {sh_ordering_mode}") view.tile_size = tile_size diff --git a/fvdb/viz/_point_cloud_view.py b/fvdb/viz/_point_cloud_view.py index 71b84cc42..6b182c9fe 100644 --- a/fvdb/viz/_point_cloud_view.py +++ b/fvdb/viz/_point_cloud_view.py @@ -5,8 +5,6 @@ import torch -from fvdb import GaussianSplat3d - from .._fvdb_cpp import GaussianSplat3dView as GaussianSplat3dViewCpp from ._viewer_server import _get_viewer_server_cpp @@ -68,19 +66,18 @@ def _rgb_to_sh(rgb: torch.Tensor) -> torch.Tensor: quats[:, 0] = 1.0 # identity rotation logit_opacities = torch.full((positions.shape[0],), 10.0, dtype=torch.float32) log_scales = torch.full((positions.shape[0], 3), -20.0, dtype=torch.float32) # since scales are exp(log_scale) - sh0 = _rgb_to_sh(colors) + sh0 = _rgb_to_sh(colors).unsqueeze(1) shN = torch.zeros((positions.shape[0], 0, 3), dtype=torch.float32) - gs_impl = GaussianSplat3d.from_tensors( + view: GaussianSplat3dViewCpp = server.add_gaussian_splat_3d_view( + scene_name=scene_name, + name=name, means=means, quats=quats, log_scales=log_scales, logit_opacities=logit_opacities, sh0=sh0, shN=shN, - )._impl - view: GaussianSplat3dViewCpp = server.add_gaussian_splat_3d_view( - scene_name=scene_name, name=name, gaussian_splat_3d=gs_impl ) view.tile_size = 16 view.min_radius_2d = 0.0 diff --git a/fvdb/viz/_scene.py b/fvdb/viz/_scene.py index bc6e71ba1..669b91b4e 100644 --- a/fvdb/viz/_scene.py +++ b/fvdb/viz/_scene.py @@ -146,7 +146,7 @@ def add_gaussian_splat_3d( return GaussianSplat3dView( scene_name=self._name, name=name, - gaussian_splat_3d=gaussian_splat_3d._impl, + gaussian_splat_3d=gaussian_splat_3d, tile_size=tile_size, min_radius_2d=min_radius_2d, eps_2d=eps_2d, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0876dc052..2417eafaf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -44,22 +44,15 @@ execute_process( # Source files set(FVDB_CPP_FILES fvdb/Config.cpp - fvdb/detail/autograd/EvaluateSphericalHarmonics.cpp - fvdb/detail/autograd/GaussianProjection.cpp - fvdb/detail/autograd/GaussianRasterize.cpp - fvdb/detail/autograd/GaussianRasterizeFromWorld.cpp - fvdb/detail/autograd/GaussianRasterizeSparse.cpp fvdb/detail/autograd/VolumeRender.cpp fvdb/detail/io/GaussianPlyIO.cpp fvdb/detail/io/LoadNanovdb.cpp fvdb/detail/io/SaveNanoVDB.cpp - fvdb/detail/ops/gsplat/GaussianUtils.cpp fvdb/detail/ops/jagged/JaggedReductions.cpp fvdb/detail/TorchDeviceBuffer.cpp fvdb/detail/viewer/GaussianSplat3dView.cpp fvdb/detail/viewer/Viewer.cpp fvdb/FVDB.cpp - fvdb/GaussianSplat3d.cpp fvdb/JaggedTensor.cpp ) @@ -95,26 +88,26 @@ set(FVDB_CU_FILES fvdb/detail/ops/AvgPool.cu fvdb/detail/ops/MaxPool.cu fvdb/detail/ops/GridEdgeNetwork.cu - fvdb/detail/ops/gsplat/FusedSSIM.cu - fvdb/detail/ops/gsplat/GaussianComputeNanInfMask.cu - fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.cu - fvdb/detail/ops/gsplat/GaussianMCMCRelocation.cu - fvdb/detail/ops/gsplat/GaussianProjectionBackward.cu - fvdb/detail/ops/gsplat/GaussianProjectionForward.cu - fvdb/detail/ops/gsplat/GaussianProjectionUT.cu - fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.cu - fvdb/detail/ops/gsplat/GaussianProjectionJaggedForward.cu - fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu - fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu - fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldBackward.cu - fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldForward.cu - fvdb/detail/ops/gsplat/GaussianRasterizeNumContributingGaussians.cu - fvdb/detail/ops/gsplat/GaussianRasterizeTopContributingGaussianIds.cu - fvdb/detail/ops/gsplat/GaussianRasterizeContributingGaussianIds.cu - fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsBackward.cu - fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsForward.cu - fvdb/detail/ops/gsplat/GaussianSplatSparse.cu - fvdb/detail/ops/gsplat/GaussianTileIntersection.cu + fvdb/detail/ops/FusedSSIM.cu + fvdb/detail/ops/ComputeGaussianNanInfMask.cu + fvdb/detail/ops/AddNoiseToGaussianMeans.cu + fvdb/detail/ops/RelocateGaussians.cu + fvdb/detail/ops/ProjectGaussiansAnalyticBackward.cu + fvdb/detail/ops/ProjectGaussiansAnalyticForward.cu + fvdb/detail/ops/ProjectGaussiansUtForward.cu + fvdb/detail/ops/ProjectGaussiansAnalyticJaggedBackward.cu + fvdb/detail/ops/ProjectGaussiansAnalyticJaggedForward.cu + fvdb/detail/ops/RasterizeScreenSpaceGaussiansBackward.cu + fvdb/detail/ops/RasterizeScreenSpaceGaussiansForward.cu + fvdb/detail/ops/RasterizeWorldSpaceGaussiansBackward.cu + fvdb/detail/ops/RasterizeWorldSpaceGaussiansForward.cu + fvdb/detail/ops/CountContributingGaussians.cu + fvdb/detail/ops/IdentifyTopContributingGaussians.cu + fvdb/detail/ops/IdentifyContributingGaussians.cu + fvdb/detail/ops/EvalGaussianShBackward.cu + fvdb/detail/ops/EvalGaussianShForward.cu + fvdb/detail/ops/BuildSparseGaussianTileLayout.cu + fvdb/detail/ops/IntersectGaussianTiles.cu fvdb/detail/ops/IjkForMesh.cu fvdb/detail/ops/IjkToIndex.cu fvdb/detail/ops/IjkToInvIndex.cu diff --git a/src/fvdb/GaussianSplat3d.cpp b/src/fvdb/GaussianSplat3d.cpp deleted file mode 100644 index 2e03d862c..000000000 --- a/src/fvdb/GaussianSplat3d.cpp +++ /dev/null @@ -1,2402 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#include -#include -#include -#include -#include -#include - -// Autograd headers -#include -#include -#include -#include - -// Ops headers -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace fvdb { - -using RenderMode = fvdb::detail::ops::RenderSettings::RenderMode; -using RenderSettings = fvdb::detail::ops::RenderSettings; - -namespace { - -using CameraModel = fvdb::GaussianSplat3d::CameraModel; -using ProjectionMethod = fvdb::GaussianSplat3d::ProjectionMethod; - -bool -usesOpenCVDistortion(const CameraModel cameraModel) { - return cameraModel == CameraModel::OPENCV_RADTAN_5 || - cameraModel == CameraModel::OPENCV_RATIONAL_8 || - cameraModel == CameraModel::OPENCV_RADTAN_THIN_PRISM_9 || - cameraModel == CameraModel::OPENCV_THIN_PRISM_12; -} - -ProjectionMethod -resolveProjectionMethod(const CameraModel cameraModel, const ProjectionMethod projectionMethod) { - if (projectionMethod == ProjectionMethod::AUTO) { - return usesOpenCVDistortion(cameraModel) ? ProjectionMethod::UNSCENTED - : ProjectionMethod::ANALYTIC; - } - return projectionMethod; -} - -void -validateCameraProjectionArgs(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const CameraModel cameraModel, - const ProjectionMethod requestedProjectionMethod, - const std::optional &distortionCoeffs) { - const int64_t C = worldToCameraMatrices.size(0); - TORCH_CHECK(C > 0, "At least one camera must be provided (got 0)"); - TORCH_CHECK(worldToCameraMatrices.sizes() == torch::IntArrayRef({C, 4, 4}), - "worldToCameraMatrices must have shape (C, 4, 4)"); - TORCH_CHECK(projectionMatrices.sizes() == torch::IntArrayRef({C, 3, 3}), - "projectionMatrices must have shape (C, 3, 3)"); - TORCH_CHECK(worldToCameraMatrices.is_contiguous(), "worldToCameraMatrices must be contiguous"); - TORCH_CHECK(projectionMatrices.is_contiguous(), "projectionMatrices must be contiguous"); - - const ProjectionMethod resolvedProjectionMethod = - resolveProjectionMethod(cameraModel, requestedProjectionMethod); - - if (distortionCoeffs.has_value()) { - TORCH_CHECK(distortionCoeffs->sizes() == torch::IntArrayRef({C, 12}), - "distortionCoeffs must have shape (C, 12)"); - TORCH_CHECK(distortionCoeffs->is_contiguous(), "distortionCoeffs must be contiguous"); - } - - if (usesOpenCVDistortion(cameraModel)) { - TORCH_CHECK(distortionCoeffs.has_value(), - "distortionCoeffs must be provided for OpenCV camera models"); - TORCH_CHECK(resolvedProjectionMethod == ProjectionMethod::UNSCENTED, - "OpenCV camera models require ProjectionMethod::UNSCENTED or AUTO"); - } -} - -} // namespace - -torch::Tensor -GaussianSplat3d::evalSphericalHarmonicsImpl(const int64_t shDegreeToUse, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &perGaussianProjectedRadii) const { - FVDB_FUNC_RANGE(); - const auto K = mShN.size(1) + 1; // number of SH bases - const auto C = worldToCameraMatrices.size(0); // number of cameras - const auto actualShDegree = shDegreeToUse < 0 ? (std::sqrt(K) - 1) : shDegreeToUse; - if (actualShDegree == 0) { - return detail::autograd::EvaluateSphericalHarmonics::apply( - actualShDegree, C, torch::nullopt, mSh0, torch::nullopt, perGaussianProjectedRadii)[0]; - } else { - // FIXME (Francis): Do this in the kernel instead of materializing a large - // tensor here. It's a bit annoying because we'll have to update - // the current backward pass - auto [camToWorldMatrices, info] = torch::linalg_inv_ex(worldToCameraMatrices); - // Equivalent to viewDirs = means[None, :, :] - camToWorldMatrices[:, None, :3, 3] - // NOTE: viewDirs are not normalized here, they get normalized in the spherical - // harmonics evaluation kernel - const torch::Tensor viewDirs = - mMeans.index( - {torch::indexing::None, torch::indexing::Slice(), torch::indexing::Slice()}) - - camToWorldMatrices.index({torch::indexing::Slice(), - torch::indexing::None, - torch::indexing::Slice(0, 3), - 3}); // [1, N, 3] - [C, 1, 3] - return detail::autograd::EvaluateSphericalHarmonics::apply( - actualShDegree, C, viewDirs, mSh0, mShN, perGaussianProjectedRadii)[0]; - } -} - -void -GaussianSplat3d::checkState(const torch::Tensor &means, - const torch::Tensor &quats, - const torch::Tensor &logScales, - const torch::Tensor &logitOpacities, - const torch::Tensor &sh0, - const torch::Tensor &shN) { - const int64_t N = means.size(0); // number of gaussians - - TORCH_CHECK_VALUE(means.sizes() == torch::IntArrayRef({N, 3}), "means must have shape (N, 3)"); - TORCH_CHECK_VALUE(quats.sizes() == torch::IntArrayRef({N, 4}), "quats must have shape (N, 4)"); - TORCH_CHECK_VALUE(logScales.sizes() == torch::IntArrayRef({N, 3}), - "scales must have shape (N, 3)"); - TORCH_CHECK_VALUE(logitOpacities.sizes() == torch::IntArrayRef({N}), - "opacities must have shape (N)"); - TORCH_CHECK_VALUE(sh0.size(0) == N, "sh0 must have shape (N, 1, D)"); - TORCH_CHECK_VALUE(sh0.size(1) == 1, "sh0 must have shape (N, 1, D)"); - TORCH_CHECK_VALUE(sh0.dim() == 3, "sh0 must have shape (N, 1, D)"); - TORCH_CHECK_VALUE(shN.size(0) == N, "shN must have shape (N, K-1, D)"); - TORCH_CHECK_VALUE(shN.dim() == 3, "shN must have shape (N, K-1, D)"); - - TORCH_CHECK_VALUE(means.device() == quats.device(), "All tensors must be on the same device"); - TORCH_CHECK_VALUE(means.device() == logScales.device(), - "All tensors must be on the same device"); - TORCH_CHECK_VALUE(means.device() == logitOpacities.device(), - "All tensors must be on the same device"); - TORCH_CHECK_VALUE(means.device() == sh0.device(), "All tensors must be on the same device"); - TORCH_CHECK_VALUE(means.device() == shN.device(), "All tensors must be on the same device"); - - TORCH_CHECK_VALUE(torch::isFloatingType(means.scalar_type()), - "All tensors must be of floating point type"); - TORCH_CHECK_VALUE(means.scalar_type() == quats.scalar_type(), - "All tensors must be of the same type"); - TORCH_CHECK_VALUE(means.scalar_type() == logScales.scalar_type(), - "All tensors must be of the same type"); - TORCH_CHECK_VALUE(means.scalar_type() == logitOpacities.scalar_type(), - "All tensors must be of the same type"); - TORCH_CHECK_VALUE(means.scalar_type() == sh0.scalar_type(), - "All tensors must be of the same type"); - TORCH_CHECK_VALUE(means.scalar_type() == shN.scalar_type(), - "All tensors must be of the same type"); -} - -GaussianSplat3d::GaussianSplat3d(const torch::Tensor &means, - const torch::Tensor &quats, - const torch::Tensor &logScales, - const torch::Tensor &logitOpacities, - const torch::Tensor &sh0, - const torch::Tensor &shN, - const bool accumulateMeans2dGradients, - const bool accumulateMax2dRadii, - const bool copyAndDetach) - : mMeans(means), mQuats(quats), mLogScales(logScales), mLogitOpacities(logitOpacities), - mSh0(sh0), mShN(shN), mAccumulateMean2dGradients(accumulateMeans2dGradients), - mAccumulateMax2dRadii(accumulateMax2dRadii) { - const int64_t N = means.size(0); // number of gaussians - if (mSh0.dim() == 2) { - TORCH_CHECK(mSh0.size(0) == N, "sh0 must have shape (N, 1, D) or (N, D)"); - mSh0 = mSh0.unsqueeze(1); - } - if (copyAndDetach) { - mMeans = means.detach(); - mQuats = quats.detach(); - mLogScales = logScales.detach(); - mLogitOpacities = logitOpacities.detach(); - mSh0 = sh0.detach(); - mShN = shN.detach(); - } - checkState(mMeans, mQuats, mLogScales, mLogitOpacities, mSh0, mShN); -} - -void -GaussianSplat3d::setState(const torch::Tensor &means, - const torch::Tensor &quats, - const torch::Tensor &logScales, - const torch::Tensor &logitOpacities, - const torch::Tensor &sh0, - const torch::Tensor &shN) { - checkState(means, quats, logScales, logitOpacities, sh0, shN); - resetAccumulatedGradientState(); - - mMeans = means; - mQuats = quats; - mLogScales = logScales; - mLogitOpacities = logitOpacities; - mSh0 = sh0; - mShN = shN; -} - -std::unordered_map -GaussianSplat3d::stateDict() const { - auto ret = std::unordered_map{{"means", mMeans}, - {"quats", mQuats}, - {"log_scales", mLogScales}, - {"logit_opacities", mLogitOpacities}, - {"sh0", mSh0}, - {"shN", mShN}}; - - const auto boolOpts = torch::TensorOptions().dtype(torch::kBool); - ret["accumulate_means_2d_gradients"] = - mAccumulateMean2dGradients ? torch::ones({}, boolOpts) : torch::zeros({}, boolOpts); - ret["accumulate_max_2d_radii"] = - mAccumulateMax2dRadii ? torch::ones({}, boolOpts) : torch::zeros({}, boolOpts); - - if (mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() != 0) { - ret["accumulated_mean_2d_gradient_norms_for_grad"] = - mAccumulatedNormalized2dMeansGradientNormsForGrad; - } - if (mAccumulated2dRadiiForGrad.numel() != 0) { - ret["accumulated_max_2d_radii_for_grad"] = mAccumulated2dRadiiForGrad; - } - if (mGradientStepCountForGrad.numel() != 0) { - ret["accumulated_gradient_step_counts_for_grad"] = mGradientStepCountForGrad; - } - return ret; -} - -void -GaussianSplat3d::loadStateDict(const std::unordered_map &stateDict) { - TORCH_CHECK_VALUE(stateDict.count("means") == 1, "Missing key 'means' in state dict"); - TORCH_CHECK_VALUE(stateDict.count("quats") == 1, "Missing key 'quats' in state dict"); - TORCH_CHECK_VALUE(stateDict.count("log_scales") == 1, "Missing key 'log_scales' in state dict"); - TORCH_CHECK_VALUE(stateDict.count("logit_opacities") == 1, - "Missing key 'logit_opacities' in state dict"); - TORCH_CHECK_VALUE(stateDict.count("sh0") == 1, "Missing key 'sh0' in state dict"); - TORCH_CHECK_VALUE(stateDict.count("shN") == 1, "Missing key 'shN' in state dict"); - - TORCH_CHECK_VALUE(stateDict.count("accumulate_means_2d_gradients") == 1, - "Missing key 'accumulate_means_2d_gradients' in state dict"); - - TORCH_CHECK_VALUE(stateDict.count("accumulate_max_2d_radii") == 1, - "Missing key 'accumulate_max_2d_radii' in state dict"); - - const torch::Tensor means = stateDict.at("means"); - const torch::Tensor quats = stateDict.at("quats"); - const torch::Tensor logScales = stateDict.at("log_scales"); - const torch::Tensor logitOpacities = stateDict.at("logit_opacities"); - const torch::Tensor sh0 = stateDict.at("sh0"); - const torch::Tensor shN = stateDict.at("shN"); - - const int64_t N = means.size(0); // number of gaussians - - checkState(means, quats, logScales, logitOpacities, sh0, shN); - - const bool accumulateMeans2dGrad = - stateDict.at("accumulate_means_2d_gradients").item().toBool(); - const bool accumulateMax2dRadii = stateDict.at("accumulate_max_2d_radii").item().toBool(); - torch::Tensor accumulatedNormalized2dMeansGradientNormsForGrad; - torch::Tensor accumulated2dRadiiForGrad; - torch::Tensor gradientStepCountForGrad; - - if (stateDict.count("accumulated_mean_2d_gradient_norms_for_grad") > 0) { - accumulatedNormalized2dMeansGradientNormsForGrad = - stateDict.at("accumulated_mean_2d_gradient_norms_for_grad"); - TORCH_CHECK_VALUE(accumulatedNormalized2dMeansGradientNormsForGrad.numel() == N, - "accumulated_mean_2d_gradient_norms_for_grad must have shape (N)"); - TORCH_CHECK_VALUE( - accumulatedNormalized2dMeansGradientNormsForGrad.device() == means.device(), - "accumulated_mean_2d_gradient_norms_for_grad must be on the same device as " - "means"); - TORCH_CHECK_VALUE(accumulatedNormalized2dMeansGradientNormsForGrad.dim() == 1, - "accumulated_mean_2d_gradient_norms_for_grad must have one dimension"); - TORCH_CHECK_VALUE(accumulatedNormalized2dMeansGradientNormsForGrad.scalar_type() == - means.scalar_type(), - "accumulated_mean_2d_gradient_norms_for_grad must have the same type as " - "means"); - TORCH_CHECK_VALUE(stateDict.count("accumulated_gradient_step_counts_for_grad") != 0, - "gradient_step_counts_for_grad " - "must be non-empty if " - "accumulated_mean_2d_gradient_norms_for_grad " - "is non-empty"); - gradientStepCountForGrad = stateDict.at("accumulated_gradient_step_counts_for_grad"); - TORCH_CHECK_VALUE(gradientStepCountForGrad.numel() != 0, - "gradient_step_counts_for_grad " - "must be non-empty if " - "accumulated_mean_2d_gradient_norms_for_grad " - "is non-empty"); - TORCH_CHECK_VALUE(gradientStepCountForGrad.numel() == N, - "accumulated_gradient_step_counts_for_grad must have shape (N)"); - TORCH_CHECK_VALUE(gradientStepCountForGrad.device() == means.device(), - "accumulated_gradient_step_counts_for_grad must be on the same device as " - "means"); - TORCH_CHECK_VALUE(gradientStepCountForGrad.dim() == 1, - "accumulated_gradient_step_counts_for_grad must have one dimension"); - TORCH_CHECK_VALUE(gradientStepCountForGrad.scalar_type() == torch::kInt32, - "accumulated_gradient_step_counts_for_grad must be of type int32"); - } - - if (stateDict.count("accumulated_max_2d_radii_for_grad") > 0) { - accumulated2dRadiiForGrad = stateDict.at("accumulated_max_2d_radii_for_grad"); - TORCH_CHECK_VALUE(accumulated2dRadiiForGrad.numel() == N, - "accumulated_max_2d_radii_for_grad must have shape (N)"); - TORCH_CHECK_VALUE(accumulated2dRadiiForGrad.device() == means.device(), - "accumulated_max_2d_radii_for_grad must be on the same device as " - "means"); - TORCH_CHECK_VALUE(accumulated2dRadiiForGrad.dim() == 1, - "accumulated_max_2d_radii_for_grad must have one dimension"); - TORCH_CHECK_VALUE(accumulated2dRadiiForGrad.scalar_type() == torch::kInt32, - "accumulated_max_2d_radii_for_grad must be of type int32"); - } - - mMeans = means; - mQuats = quats; - mLogScales = logScales; - mLogitOpacities = logitOpacities; - mSh0 = sh0; - mShN = shN; - - mAccumulateMean2dGradients = accumulateMeans2dGrad; - mAccumulateMax2dRadii = accumulateMax2dRadii; - mAccumulatedNormalized2dMeansGradientNormsForGrad = - accumulatedNormalized2dMeansGradientNormsForGrad; - mAccumulated2dRadiiForGrad = accumulated2dRadiiForGrad; - mGradientStepCountForGrad = gradientStepCountForGrad; -} - -GaussianSplat3d::ProjectedGaussianSplats -GaussianSplat3d::projectGaussiansImpl(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const RenderSettings &settings, - const CameraModel cameraModel) { - FVDB_FUNC_RANGE(); - const bool ortho = cameraModel == CameraModel::ORTHOGRAPHIC; - const int C = worldToCameraMatrices.size(0); // number of cameras - const int N = mMeans.size(0); // number of gaussians - - ProjectedGaussianSplats ret; - ret.mRenderSettings = settings; - ret.mCameraModel = cameraModel; - ret.mProjectionMethod = ProjectionMethod::ANALYTIC; - - // Track gradients for the 2D means in the backward pass if you're optimizing - std::optional maybeNormalizedMeans2dGradientNorms = std::nullopt; - std::optional maybePerGaussianRadiiForGrad = std::nullopt; - std::optional maybeGradientStepCount = std::nullopt; - if (mAccumulateMean2dGradients) { - if (mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() != N) { - mAccumulatedNormalized2dMeansGradientNormsForGrad = torch::zeros({N}, mMeans.options()); - } - if (mGradientStepCountForGrad.numel() != N) { - mGradientStepCountForGrad = torch::zeros( - {N}, torch::TensorOptions().dtype(torch::kInt32).device(mMeans.device())); - } - maybeNormalizedMeans2dGradientNorms = mAccumulatedNormalized2dMeansGradientNormsForGrad; - maybeGradientStepCount = mGradientStepCountForGrad; - } - if (mAccumulateMax2dRadii) { - if (mAccumulated2dRadiiForGrad.numel() != N && mAccumulateMax2dRadii) { - mAccumulated2dRadiiForGrad = torch::zeros( - {N}, torch::TensorOptions().dtype(torch::kInt32).device(mMeans.device())); - } - maybePerGaussianRadiiForGrad = mAccumulated2dRadiiForGrad; - } - - // Project to image plane - const auto projectionResults = - detail::autograd::ProjectGaussians::apply(mMeans, - mQuats, - mLogScales, - worldToCameraMatrices, - projectionMatrices, - settings.imageWidth, - settings.imageHeight, - settings.eps2d, - settings.nearPlane, - settings.farPlane, - settings.radiusClip, - settings.antialias, - ortho, - maybeNormalizedMeans2dGradientNorms, - maybePerGaussianRadiiForGrad, - maybeGradientStepCount); - ret.perGaussianRadius = projectionResults[0]; - ret.perGaussian2dMean = projectionResults[1]; - ret.perGaussianDepth = projectionResults[2]; - ret.perGaussianConic = projectionResults[3]; - // FIXME: Use accessors in the kernel and use exapand - ret.perGaussianOpacity = opacities().repeat({C, 1}); - if (settings.antialias) { - ret.perGaussianOpacity *= projectionResults[4]; - // FIXME (Francis): The contiguity requirement is dumb and should be - // removed by using accessors in the kernel - ret.perGaussianOpacity = ret.perGaussianOpacity.contiguous(); - } - - ret.perGaussianRenderQuantity = [&]() { - torch::Tensor renderQuantity; - if (settings.renderMode == RenderMode::DEPTH) { - renderQuantity = ret.perGaussianDepth.unsqueeze(-1); // [C, N, 1] - } else if (settings.renderMode == RenderMode::RGB || - settings.renderMode == RenderMode::RGBD) { - renderQuantity = evalSphericalHarmonicsImpl( - settings.shDegreeToUse, worldToCameraMatrices, ret.perGaussianRadius); - - if (settings.renderMode == RenderMode::RGBD) { - renderQuantity = torch::cat({renderQuantity, ret.perGaussianDepth.unsqueeze(-1)}, - -1); // [C, N, D + 1] - } - } else { - TORCH_CHECK_VALUE(false, "Invalid render mode"); - } - return renderQuantity; - }(); - - // Intersect projected Gaussians with image tiles [non-differentiable] - const int numTilesW = std::ceil(settings.imageWidth / static_cast(settings.tileSize)); - const int numTilesH = std::ceil(settings.imageHeight / static_cast(settings.tileSize)); - const auto [tileOffsets, tileGaussianIds] = FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() { - return detail::ops::dispatchGaussianTileIntersection(ret.perGaussian2dMean, - ret.perGaussianRadius, - ret.perGaussianDepth, - at::nullopt, - C, - settings.tileSize, - numTilesH, - numTilesW); - }); - ret.tileOffsets = tileOffsets; // [C, TH, TW] - ret.tileGaussianIds = tileGaussianIds; // [TOT_INTERSECTIONS] - - return ret; -} - -GaussianSplat3d::ProjectedGaussianSplats -GaussianSplat3d::projectGaussiansForCameraImpl( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs) { - FVDB_FUNC_RANGE(); - validateCameraProjectionArgs( - worldToCameraMatrices, projectionMatrices, cameraModel, projectionMethod, distortionCoeffs); - - const ProjectionMethod resolvedProjectionMethod = - resolveProjectionMethod(cameraModel, projectionMethod); - - RenderSettings settingsForProjection = settings; - - if (resolvedProjectionMethod == ProjectionMethod::ANALYTIC) { - return projectGaussiansImpl( - worldToCameraMatrices, projectionMatrices, settingsForProjection, cameraModel); - } - - const int C = worldToCameraMatrices.size(0); - ProjectedGaussianSplats ret; - ret.mRenderSettings = settingsForProjection; - ret.mCameraModel = cameraModel; - ret.mProjectionMethod = resolvedProjectionMethod; - - const torch::Tensor distortionCoeffsTensor = distortionCoeffs.has_value() - ? distortionCoeffs.value() - : torch::empty({C, 0}, mMeans.options()); - fvdb::detail::ops::UTParams utParams = fvdb::detail::ops::UTParams{}; - const auto projectionResults = FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() { - return fvdb::detail::ops::dispatchGaussianProjectionForwardUT( - mMeans, - mQuats, - mLogScales, - worldToCameraMatrices, - worldToCameraMatrices, - projectionMatrices, - fvdb::detail::ops::RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffsTensor, - static_cast(settings.imageWidth), - static_cast(settings.imageHeight), - settings.eps2d, - settings.nearPlane, - settings.farPlane, - settings.radiusClip, - settings.antialias); - }); - - ret.perGaussianRadius = std::get<0>(projectionResults); - ret.perGaussian2dMean = std::get<1>(projectionResults); - ret.perGaussianDepth = std::get<2>(projectionResults); - ret.perGaussianConic = std::get<3>(projectionResults); - - ret.perGaussianOpacity = opacities().repeat({C, 1}); - if (settings.antialias) { - const torch::Tensor compensations = std::get<4>(projectionResults); - TORCH_CHECK(compensations.defined(), - "UT projection returned an undefined compensation tensor in antialias mode"); - ret.perGaussianOpacity *= compensations; - ret.perGaussianOpacity = ret.perGaussianOpacity.contiguous(); - } - - ret.perGaussianRenderQuantity = [&]() { - torch::Tensor renderQuantity; - if (settings.renderMode == RenderMode::DEPTH) { - renderQuantity = ret.perGaussianDepth.unsqueeze(-1); - } else if (settings.renderMode == RenderMode::RGB || - settings.renderMode == RenderMode::RGBD) { - renderQuantity = evalSphericalHarmonicsImpl( - settings.shDegreeToUse, worldToCameraMatrices, ret.perGaussianRadius); - if (settings.renderMode == RenderMode::RGBD) { - renderQuantity = - torch::cat({renderQuantity, ret.perGaussianDepth.unsqueeze(-1)}, -1); - } - } else { - TORCH_CHECK_VALUE(false, "Invalid render mode"); - } - return renderQuantity; - }(); - - const int numTilesW = std::ceil(settings.imageWidth / static_cast(settings.tileSize)); - const int numTilesH = std::ceil(settings.imageHeight / static_cast(settings.tileSize)); - std::tie(ret.tileOffsets, ret.tileGaussianIds) = FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() { - return detail::ops::dispatchGaussianTileIntersection(ret.perGaussian2dMean, - ret.perGaussianRadius, - ret.perGaussianDepth, - at::nullopt, - C, - settings.tileSize, - numTilesH, - numTilesW); - }); - - return ret; -} - -/// @brief Deduplicate pixel coordinates in a JaggedTensor. -/// -/// Encodes each pixel as a single int64 key incorporating its batch index and 2D coordinate, -/// sorts keys to find unique groups and builds an inverse mapping. Returns the deduplicated -/// pixels as a new JaggedTensor, the inverse index tensor, and a flag indicating whether any -/// duplicates were found. -/// -/// @param pixelsToRender The input JaggedTensor of pixel coordinates [total_pixels, 2] -/// @param imageWidth Width of each image in pixels -/// @param imageHeight Height of each image in pixels -/// @return Tuple of (uniquePixels JaggedTensor, inverseIndices tensor, hasDuplicates bool) -std::tuple -deduplicatePixels(const JaggedTensor &pixelsToRender, int64_t imageWidth, int64_t imageHeight) { - const auto totalPixels = pixelsToRender.rsize(0); - if (totalPixels == 0) { - auto emptyInverse = torch::empty({0}, pixelsToRender.jdata().options().dtype(torch::kLong)); - return {pixelsToRender, emptyInverse, false}; - } - - const auto device = pixelsToRender.device(); - const auto jdata = pixelsToRender.jdata(); - const auto jidx = pixelsToRender.jidx(); - const int64_t numPixelsPerImage = imageHeight * imageWidth; - const auto longOpts = torch::TensorOptions().device(device).dtype(torch::kLong); - const auto boolOpts = torch::TensorOptions().device(device).dtype(torch::kBool); - - // Encode (batchIdx, row, col) into a single int64 key: - // key = batchIdx * (H * W) + row * W + col - // For single-list JaggedTensors, jidx is empty so we skip the batch term entirely. - const bool singleList = (jidx.size(0) == 0); - torch::Tensor rows, cols; - if (jdata.scalar_type() == torch::kInt32) { - rows = jdata.select(1, 0).to(torch::kLong); - cols = jdata.select(1, 1).to(torch::kLong); - } else { - rows = jdata.select(1, 0); - cols = jdata.select(1, 1); - } - torch::Tensor keys; - if (singleList) { - keys = rows * imageWidth + cols; - } else { - auto jidxLong = jidx.to(torch::kLong); - keys = jidxLong * numPixelsPerImage + rows * imageWidth + cols; - } - - // Sort keys and find group boundaries - auto [sortedKeys, sortPerm] = keys.sort(); - - auto isGroupStart = torch::ones({totalPixels}, boolOpts); - if (totalPixels > 1) { - isGroupStart.slice(0, 1).copy_(sortedKeys.slice(0, 1) != sortedKeys.slice(0, 0, -1)); - } - - // Extract first-of-group positions before mutating isGroupStart - auto firstInSorted = isGroupStart.nonzero().squeeze(1); - - // Assign a group ID (0-based) to each sorted position via in-place cumsum - auto groupIds = isGroupStart.to(torch::kLong); - groupIds.cumsum_(0).sub_(1); - const auto numUnique = groupIds[-1].item() + 1; - - if (numUnique == totalPixels) { - return {pixelsToRender, torch::arange(totalPixels, longOpts), false}; - } - - // inverseIndices: map each original position to its group ID (= index in unique output) - auto inverseIndices = torch::empty({totalPixels}, longOpts); - inverseIndices.index_put_({sortPerm}, groupIds); - - // Pick the first occurrence of each group (in sorted order) and map to original indices - auto uniqueOrigIndices = sortPerm.index_select(0, firstInSorted); - auto uniqueJData = jdata.index_select(0, uniqueOrigIndices); - - // Build new JaggedTensor offsets for the unique pixels - auto uniqueBatchIdx = singleList ? torch::zeros({numUnique}, longOpts) - : jidx.to(torch::kLong).index_select(0, uniqueOrigIndices); - auto numLists = pixelsToRender.num_outer_lists(); - auto countsPerList = torch::bincount(uniqueBatchIdx, {}, numLists); - auto newOffsets = torch::zeros({numLists + 1}, longOpts); - newOffsets.slice(0, 1).copy_(countsPerList.cumsum(0)); - - auto newJidx = uniqueBatchIdx.to(fvdb::JIdxScalarType); - - auto uniquePixels = JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( - uniqueJData, newOffsets, newJidx, pixelsToRender.jlidx(), numLists); - - return {uniquePixels, inverseIndices, true}; -} - -GaussianSplat3d::SparseProjectedGaussianSplats -GaussianSplat3d::sparseProjectGaussiansImpl(const JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const RenderSettings &settings, - const CameraModel cameraModel) { - FVDB_FUNC_RANGE(); - const bool ortho = cameraModel == CameraModel::ORTHOGRAPHIC; - const int C = worldToCameraMatrices.size(0); // number of cameras - const int N = mMeans.size(0); // number of gaussians - TORCH_CHECK(static_cast(pixelsToRender.num_outer_lists()) == C, - "pixelsToRender must have the same number of outer lists as the number of cameras. " - "Got ", - pixelsToRender.num_outer_lists(), - " outer lists but ", - C, - " cameras. "); - - SparseProjectedGaussianSplats ret; - ret.mRenderSettings = settings; - ret.mCameraModel = cameraModel; - ret.mProjectionMethod = ProjectionMethod::ANALYTIC; - - // Deduplicate pixel coordinates. computeSparseInfo requires unique pixels because its - // tile bitmask has one bit per pixel position. We scatter results back after rendering. - auto [uniquePixels, inverseIndices, hasDuplicates] = - deduplicatePixels(pixelsToRender, settings.imageWidth, settings.imageHeight); - ret.inverseIndices = inverseIndices; - ret.uniquePixelsToRender = uniquePixels; - ret.hasDuplicates = hasDuplicates; - - // Compute sparse tile info using deduplicated pixels - const int numTilesW = std::ceil(settings.imageWidth / static_cast(settings.tileSize)); - const int numTilesH = std::ceil(settings.imageHeight / static_cast(settings.tileSize)); - - const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo(settings.tileSize, numTilesW, numTilesH, uniquePixels); - - ret.activeTiles = activeTiles; - ret.activeTileMask = activeTileMask; - ret.tilePixelMask = tilePixelMask; - ret.tilePixelCumsum = tilePixelCumsum; - ret.pixelMap = pixelMap; - - // Track gradients for the 2D means in the backward pass if you're optimizing - std::optional maybeNormalizedMeans2dGradientNorms = std::nullopt; - std::optional maybePerGaussianRadiiForGrad = std::nullopt; - std::optional maybeGradientStepCount = std::nullopt; - if (mAccumulateMean2dGradients) { - if (mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() != N) { - mAccumulatedNormalized2dMeansGradientNormsForGrad = torch::zeros({N}, mMeans.options()); - } - if (mGradientStepCountForGrad.numel() != N) { - mGradientStepCountForGrad = torch::zeros( - {N}, torch::TensorOptions().dtype(torch::kInt32).device(mMeans.device())); - } - maybeNormalizedMeans2dGradientNorms = mAccumulatedNormalized2dMeansGradientNormsForGrad; - maybeGradientStepCount = mGradientStepCountForGrad; - } - if (mAccumulateMax2dRadii) { - if (mAccumulated2dRadiiForGrad.numel() != N && mAccumulateMax2dRadii) { - mAccumulated2dRadiiForGrad = torch::zeros( - {N}, torch::TensorOptions().dtype(torch::kInt32).device(mMeans.device())); - } - maybePerGaussianRadiiForGrad = mAccumulated2dRadiiForGrad; - } - - // Project to image plane - const auto projectionResults = - detail::autograd::ProjectGaussians::apply(mMeans, - mQuats, - mLogScales, - worldToCameraMatrices, - projectionMatrices, - settings.imageWidth, - settings.imageHeight, - settings.eps2d, - settings.nearPlane, - settings.farPlane, - settings.radiusClip, - settings.antialias, - ortho, - maybeNormalizedMeans2dGradientNorms, - maybePerGaussianRadiiForGrad, - maybeGradientStepCount); - ret.perGaussianRadius = projectionResults[0]; - ret.perGaussian2dMean = projectionResults[1]; - ret.perGaussianDepth = projectionResults[2]; - ret.perGaussianConic = projectionResults[3]; - // FIXME: Use accessors in the kernel and use expand - ret.perGaussianOpacity = opacities().repeat({C, 1}); - if (settings.antialias) { - ret.perGaussianOpacity *= projectionResults[4]; - // FIXME (Francis): The contiguity requirement is dumb and should be - // removed by using accessors in the kernel - ret.perGaussianOpacity = ret.perGaussianOpacity.contiguous(); - } - - ret.perGaussianRenderQuantity = [&]() { - torch::Tensor renderQuantity; - if (settings.renderMode == RenderMode::DEPTH) { - renderQuantity = ret.perGaussianDepth.unsqueeze(-1); // [C, N, 1] - } else if (settings.renderMode == RenderMode::RGB || - settings.renderMode == RenderMode::RGBD) { - renderQuantity = evalSphericalHarmonicsImpl( - settings.shDegreeToUse, worldToCameraMatrices, ret.perGaussianRadius); - - if (settings.renderMode == RenderMode::RGBD) { - renderQuantity = torch::cat({renderQuantity, ret.perGaussianDepth.unsqueeze(-1)}, - -1); // [C, N, D + 1] - } - } else { - TORCH_CHECK_VALUE(false, "Invalid render mode"); - } - return renderQuantity; - }(); - - // Intersect projected Gaussians with image tiles [non-differentiable] - // Use sparse tile intersection which only computes intersections for active tiles - const auto [sparseTileOffsets, tileGaussianIds] = FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() { - return detail::ops::dispatchGaussianSparseTileIntersection(ret.perGaussian2dMean, - ret.perGaussianRadius, - ret.perGaussianDepth, - ret.activeTileMask, - ret.activeTiles, - at::nullopt, - C, - settings.tileSize, - numTilesH, - numTilesW); - }); - // Use sparse 1D tile offsets - RasterizeCommonArgs detects the format from dimensions - ret.tileOffsets = sparseTileOffsets; // [num_active_tiles + 1] - ret.tileGaussianIds = tileGaussianIds; // [TOT_INTERSECTIONS] - - return ret; -} - -GaussianSplat3d::SparseProjectedGaussianSplats -GaussianSplat3d::sparseProjectGaussiansForCameraImpl( - const JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs) { - FVDB_FUNC_RANGE(); - validateCameraProjectionArgs( - worldToCameraMatrices, projectionMatrices, cameraModel, projectionMethod, distortionCoeffs); - - const ProjectionMethod resolvedProjectionMethod = - resolveProjectionMethod(cameraModel, projectionMethod); - - RenderSettings settingsForProjection = settings; - - if (resolvedProjectionMethod == ProjectionMethod::ANALYTIC) { - return sparseProjectGaussiansImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settingsForProjection, - cameraModel); - } - - const int C = worldToCameraMatrices.size(0); - TORCH_CHECK(static_cast(pixelsToRender.num_outer_lists()) == C, - "pixelsToRender must have the same number of outer lists as the number of cameras. " - "Got ", - pixelsToRender.num_outer_lists(), - " outer lists but ", - C, - " cameras. "); - - SparseProjectedGaussianSplats ret; - ret.mRenderSettings = settingsForProjection; - ret.mCameraModel = cameraModel; - ret.mProjectionMethod = resolvedProjectionMethod; - - auto [uniquePixels, inverseIndices, hasDuplicates] = - deduplicatePixels(pixelsToRender, settings.imageWidth, settings.imageHeight); - ret.inverseIndices = inverseIndices; - ret.uniquePixelsToRender = uniquePixels; - ret.hasDuplicates = hasDuplicates; - - const int numTilesW = std::ceil(settings.imageWidth / static_cast(settings.tileSize)); - const int numTilesH = std::ceil(settings.imageHeight / static_cast(settings.tileSize)); - const auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo(settings.tileSize, numTilesW, numTilesH, uniquePixels); - ret.activeTiles = activeTiles; - ret.activeTileMask = activeTileMask; - ret.tilePixelMask = tilePixelMask; - ret.tilePixelCumsum = tilePixelCumsum; - ret.pixelMap = pixelMap; - - const torch::Tensor distortionCoeffsTensor = distortionCoeffs.has_value() - ? distortionCoeffs.value() - : torch::empty({C, 0}, mMeans.options()); - fvdb::detail::ops::UTParams utParams = fvdb::detail::ops::UTParams{}; - const auto projectionResults = FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() { - return fvdb::detail::ops::dispatchGaussianProjectionForwardUT( - mMeans, - mQuats, - mLogScales, - worldToCameraMatrices, - worldToCameraMatrices, - projectionMatrices, - fvdb::detail::ops::RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffsTensor, - static_cast(settings.imageWidth), - static_cast(settings.imageHeight), - settings.eps2d, - settings.nearPlane, - settings.farPlane, - settings.radiusClip, - settings.antialias); - }); - ret.perGaussianRadius = std::get<0>(projectionResults); - ret.perGaussian2dMean = std::get<1>(projectionResults); - ret.perGaussianDepth = std::get<2>(projectionResults); - ret.perGaussianConic = std::get<3>(projectionResults); - - ret.perGaussianOpacity = opacities().repeat({C, 1}); - if (settings.antialias) { - const torch::Tensor compensations = std::get<4>(projectionResults); - TORCH_CHECK(compensations.defined(), - "UT projection returned an undefined compensation tensor in antialias mode"); - ret.perGaussianOpacity *= compensations; - ret.perGaussianOpacity = ret.perGaussianOpacity.contiguous(); - } - - ret.perGaussianRenderQuantity = [&]() { - torch::Tensor renderQuantity; - if (settings.renderMode == RenderMode::DEPTH) { - renderQuantity = ret.perGaussianDepth.unsqueeze(-1); - } else if (settings.renderMode == RenderMode::RGB || - settings.renderMode == RenderMode::RGBD) { - renderQuantity = evalSphericalHarmonicsImpl( - settings.shDegreeToUse, worldToCameraMatrices, ret.perGaussianRadius); - if (settings.renderMode == RenderMode::RGBD) { - renderQuantity = - torch::cat({renderQuantity, ret.perGaussianDepth.unsqueeze(-1)}, -1); - } - } else { - TORCH_CHECK_VALUE(false, "Invalid render mode"); - } - return renderQuantity; - }(); - - const auto [sparseTileOffsets, tileGaussianIds] = FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() { - return detail::ops::dispatchGaussianSparseTileIntersection(ret.perGaussian2dMean, - ret.perGaussianRadius, - ret.perGaussianDepth, - ret.activeTileMask, - ret.activeTiles, - at::nullopt, - C, - settings.tileSize, - numTilesH, - numTilesW); - }); - ret.tileOffsets = sparseTileOffsets; - ret.tileGaussianIds = tileGaussianIds; - - return ret; -} - -std::tuple -GaussianSplat3d::renderCropFromProjectedGaussiansImpl( - const ProjectedGaussianSplats &projectedGaussians, - const size_t tileSize, - const ssize_t cropWidth, - const ssize_t cropHeight, - const ssize_t cropOriginW, - const ssize_t cropOriginH, - const std::optional &backgrounds, - const std::optional &masks) { - FVDB_FUNC_RANGE(); - // Negative values mean use the whole image, but all values must be negative - if (cropWidth <= 0 || cropHeight <= 0 || cropOriginW < 0 || cropOriginH < 0) { - TORCH_CHECK_VALUE(cropWidth <= 0 && cropHeight <= 0 && cropOriginW <= 0 && cropOriginH <= 0, - "Invalid crop dimensions"); - } else { - TORCH_CHECK_VALUE(cropWidth > 0 && cropHeight > 0 && cropOriginW >= 0 && cropOriginH >= 0, - "Invalid crop dimensions"); - } - - const size_t cropWidth_ = cropWidth <= 0 ? projectedGaussians.imageWidth() : cropWidth; - const size_t cropHeight_ = cropHeight <= 0 ? projectedGaussians.imageHeight() : cropHeight; - const size_t cropOriginW_ = cropOriginW < 0 ? 0 : cropOriginW; - const size_t cropOriginH_ = cropOriginH < 0 ? 0 : cropOriginH; - - // Rasterize projected Gaussians to pixels (differentiable) - // NOTE: projectGaussians* performs input checking, we need to apply some further - // checking before GaussianRasterizeToPixels - auto outputs = detail::autograd::RasterizeGaussiansToPixels::apply( - projectedGaussians.perGaussian2dMean, - projectedGaussians.perGaussianConic, - projectedGaussians.perGaussianRenderQuantity, - projectedGaussians.perGaussianOpacity, - cropWidth_, - cropHeight_, - cropOriginW_, - cropOriginH_, - tileSize, - projectedGaussians.tileOffsets, - projectedGaussians.tileGaussianIds, - false, - backgrounds, - masks); - torch::Tensor renderedImage = outputs[0]; - torch::Tensor renderedAlphas = outputs[1]; - - return {renderedImage, renderedAlphas}; -} - -std::tuple -GaussianSplat3d::sparseRenderImpl(const JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const std::optional &backgrounds, - const std::optional &masks) { - FVDB_FUNC_RANGE(); - - const SparseProjectedGaussianSplats &state = - sparseProjectGaussiansForCameraImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - - // Render using unique (deduplicated) pixels - const auto &renderPixels = state.hasDuplicates ? state.uniquePixelsToRender : pixelsToRender; - - auto rasterizeResult = - detail::autograd::RasterizeGaussiansToPixelsSparse::apply(renderPixels, - state.perGaussian2dMean, - state.perGaussianConic, - state.perGaussianRenderQuantity, - state.perGaussianOpacity, - settings.imageWidth, - settings.imageHeight, - 0, - 0, - settings.tileSize, - state.tileOffsets, - state.tileGaussianIds, - state.activeTiles, - state.tilePixelMask, - state.tilePixelCumsum, - state.pixelMap, - false, - backgrounds, - masks); - auto renderedPixelsJData = rasterizeResult[0]; - auto renderedAlphasJData = rasterizeResult[1]; - - // Scatter unique results back to all original positions (including duplicates). - // index_select's autograd backward (scatter_add) naturally sums gradients from duplicates. - if (state.hasDuplicates) { - renderedPixelsJData = renderedPixelsJData.index_select(0, state.inverseIndices); - renderedAlphasJData = renderedAlphasJData.index_select(0, state.inverseIndices); - } - - return {pixelsToRender.jagged_like(renderedPixelsJData), - pixelsToRender.jagged_like(renderedAlphasJData)}; -} - -std::tuple -GaussianSplat3d::renderNumContributingGaussiansImpl( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs) { - FVDB_FUNC_RANGE(); - const ProjectedGaussianSplats &state = projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - return FVDB_DISPATCH_KERNEL_DEVICE(state.perGaussian2dMean.device(), [&]() { - return fvdb::detail::ops::dispatchGaussianRasterizeNumContributingGaussians( - state.perGaussian2dMean, - state.perGaussianConic, - state.perGaussianOpacity, - state.tileOffsets, - state.tileGaussianIds, - settings); - }); -} - -std::tuple -GaussianSplat3d::sparseRenderNumContributingGaussiansImpl( - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs) { - FVDB_FUNC_RANGE(); - - const SparseProjectedGaussianSplats &state = - sparseProjectGaussiansForCameraImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - - const auto &renderPixels = state.hasDuplicates ? state.uniquePixelsToRender : pixelsToRender; - - auto result = FVDB_DISPATCH_KERNEL_DEVICE(state.perGaussian2dMean.device(), [&]() { - return fvdb::detail::ops::dispatchGaussianSparseRasterizeNumContributingGaussians< - DeviceTag>(state.perGaussian2dMean, - state.perGaussianConic, - state.perGaussianOpacity, - state.tileOffsets, - state.tileGaussianIds, - renderPixels, - state.activeTiles, - state.tilePixelMask, - state.tilePixelCumsum, - state.pixelMap, - settings); - }); - - if (state.hasDuplicates) { - auto &[jt0, jt1] = result; - return {pixelsToRender.jagged_like(jt0.jdata().index_select(0, state.inverseIndices)), - pixelsToRender.jagged_like(jt1.jdata().index_select(0, state.inverseIndices))}; - } - return result; -} - -std::tuple -GaussianSplat3d::renderContributingGaussianIdsImpl( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const std::optional &maybeNumContributingGaussians) { - FVDB_FUNC_RANGE(); - const ProjectedGaussianSplats &state = projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - // TODO: Currently projection only performs spherical harmonics evaluation on input SH/features, - // whereas we'd really like rendering to be more generic to be able to supply some - // other render quantity directly, such as an integer ID as we need in this case. So, to - // just test the ID rendering we need, we'll pass the IDs we want to render here rather - // than have a 'render deep samples of any quantity or evaluated shading' function. - // - // It would be more reusable here to render any 'features' we want by being able to - // provide a tensor of any type/value without any shading evaluation (or maybe in the - // future support other shading models/quantities). This would add complexity to support - // more 'shading models' during projection, i.e. not evaluating SH and passing through the - // 'raw' features. Secondly, this would require carrying around a separate 'feature' - // tensor from the sh0/shN ones to support a different 'shading model' and there's - // currently quite a lot of logic that assumes the primacy of the sh0/shN tensors, so we - // leave this all to a further refactor and just render 'deep IDs' as a fixed function. - // Currently, this function uses the existing 'projectGuassiansImpl' which performs SH - // evaluation which creates some wasted computation because we don't use the SH values. - - return FVDB_DISPATCH_KERNEL_DEVICE(state.perGaussian2dMean.device(), [&]() { - return fvdb::detail::ops::dispatchGaussianRasterizeContributingGaussianIds( - state.perGaussian2dMean, - state.perGaussianConic, - state.perGaussianOpacity, - state.tileOffsets, - state.tileGaussianIds, - settings, - maybeNumContributingGaussians); - }); -} - -std::tuple -GaussianSplat3d::sparseRenderContributingGaussianIdsImpl( - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const std::optional &maybeNumContributingGaussians) { - FVDB_FUNC_RANGE(); - - const SparseProjectedGaussianSplats &state = - sparseProjectGaussiansForCameraImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - - const auto &renderPixels = state.hasDuplicates ? state.uniquePixelsToRender : pixelsToRender; - - // When duplicates were removed, numContributingGaussians is in original (duplicated) space - // but the kernel expects it in unique-pixel space. Pick one representative per group. - std::optional kernelNumContrib = maybeNumContributingGaussians; - if (state.hasDuplicates && maybeNumContributingGaussians.has_value()) { - const auto device = state.inverseIndices.device(); - const auto longOpt = torch::TensorOptions().device(device).dtype(torch::kLong); - auto repIdx = torch::empty({renderPixels.rsize(0)}, longOpt); - repIdx.scatter_(0, state.inverseIndices, torch::arange(pixelsToRender.rsize(0), longOpt)); - auto uniqueData = maybeNumContributingGaussians->jdata().index_select(0, repIdx); - kernelNumContrib = renderPixels.jagged_like(uniqueData); - } - - auto result = FVDB_DISPATCH_KERNEL_DEVICE(state.perGaussian2dMean.device(), [&]() { - return fvdb::detail::ops::dispatchGaussianSparseRasterizeContributingGaussianIds( - state.perGaussian2dMean, - state.perGaussianConic, - state.perGaussianOpacity, - state.tileOffsets, - state.tileGaussianIds, - renderPixels, - state.activeTiles, - state.tilePixelMask, - state.tilePixelCumsum, - state.pixelMap, - settings, - kernelNumContrib); - }); - - if (state.hasDuplicates) { - auto &[jt0, jt1] = result; - return {pixelsToRender.jagged_like(jt0.jdata().index_select(0, state.inverseIndices)), - pixelsToRender.jagged_like(jt1.jdata().index_select(0, state.inverseIndices))}; - } - return result; -} - -GaussianSplat3d::ProjectedGaussianSplats -GaussianSplat3d::projectGaussiansForImages(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - size_t imageWidth, - size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const int64_t shDegreeToUse, - const float minRadius2d, - const float eps2d, - const bool antialias) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = shDegreeToUse; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - - settings.renderMode = RenderMode::RGB; - - return projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); -} - -GaussianSplat3d::ProjectedGaussianSplats -GaussianSplat3d::projectGaussiansForDepths(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - size_t imageWidth, - size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const float minRadius2d, - const float eps2d, - const bool antialias) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = -1; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.renderMode = RenderMode::DEPTH; - - return projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); -} - -GaussianSplat3d::ProjectedGaussianSplats -GaussianSplat3d::projectGaussiansForImagesAndDepths( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - size_t imageWidth, - size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const int64_t shDegreeToUse, - const float minRadius2d, - const float eps2d, - const bool antialias) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = shDegreeToUse; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - - settings.renderMode = RenderMode::RGBD; - - return projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); -} - -namespace { - -/// @brief Get a uint8_t pointer to the data of a tensor -/// @param tensor The tensor to get the pointer to -/// @return A uint8_t pointer to the data of the tensor -inline uint8_t * -tensorBytePointer(const torch::Tensor &tensor) { - return static_cast(tensor.data_ptr()); -} - -} // namespace - -void -GaussianSplat3d::savePly( - const std::string &filename, - std::optional> metadata) const { - detail::io::saveGaussianPly(filename, *this, metadata); -} - -std::tuple> -GaussianSplat3d::fromPly(const std::string &filename, torch::Device device) { - return detail::io::loadGaussianPly(filename, device); -} - -std::tuple -GaussianSplat3d::renderFromProjectedGaussians( - const GaussianSplat3d::ProjectedGaussianSplats &projectedGaussians, - const ssize_t cropWidth, - const ssize_t cropHeight, - const ssize_t cropOriginW, - const ssize_t cropOriginH, - const size_t tileSize, - const std::optional &backgrounds, - const std::optional &masks) { - return renderCropFromProjectedGaussiansImpl(projectedGaussians, - tileSize, - cropWidth, - cropHeight, - cropOriginW, - cropOriginH, - backgrounds, - masks); -} - -std::tuple -GaussianSplat3d::renderImages(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const int64_t shDegreeToUse, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const std::optional &backgrounds, - const std::optional &masks) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = shDegreeToUse; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.tileSize = tileSize; - settings.renderMode = RenderSettings::RenderMode::RGB; - - const ProjectedGaussianSplats state = projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - return renderCropFromProjectedGaussiansImpl(state, - settings.tileSize, - settings.imageWidth, - settings.imageHeight, - 0, - 0, - backgrounds, - masks); -} - -std::tuple -GaussianSplat3d::renderImagesFromWorld(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const int64_t shDegreeToUse, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const std::optional &backgrounds, - const std::optional &masks) { - FVDB_FUNC_RANGE(); - const int C = worldToCameraMatrices.size(0); // number of cameras - TORCH_CHECK(C > 0, "At least one camera must be provided (got 0)"); - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = shDegreeToUse; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.tileSize = tileSize; - settings.renderMode = RenderSettings::RenderMode::RGB; - - const ProjectedGaussianSplats state = projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - - const torch::Tensor distortionCoeffsForRaster = distortionCoeffs.has_value() - ? distortionCoeffs.value() - : torch::empty({C, 0}, mMeans.options()); - - auto outputs = detail::autograd::RasterizeGaussiansToPixelsFromWorld3DGS::apply( - mMeans, - mQuats, - mLogScales, - state.perGaussianRenderQuantity, - state.perGaussianOpacity, - worldToCameraMatrices, - worldToCameraMatrices, - projectionMatrices, - distortionCoeffsForRaster, - fvdb::detail::ops::RollingShutterType::NONE, - cameraModel, - static_cast(imageWidth), - static_cast(imageHeight), - 0, - 0, - static_cast(tileSize), - state.tileOffsets, - state.tileGaussianIds, - backgrounds, - masks); - - return {outputs[0], outputs[1]}; -} - -std::tuple -GaussianSplat3d::renderDepths(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const std::optional &backgrounds, - const std::optional &masks) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = -1; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.tileSize = tileSize; - settings.renderMode = RenderSettings::RenderMode::DEPTH; - - const ProjectedGaussianSplats state = projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - return renderCropFromProjectedGaussiansImpl(state, - settings.tileSize, - settings.imageWidth, - settings.imageHeight, - 0, - 0, - backgrounds, - masks); -} - -std::tuple -GaussianSplat3d::renderDepthsFromWorld(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const std::optional &backgrounds, - const std::optional &masks) { - FVDB_FUNC_RANGE(); - const int C = worldToCameraMatrices.size(0); - TORCH_CHECK(C > 0, "At least one camera must be provided (got 0)"); - - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = -1; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.tileSize = tileSize; - settings.renderMode = RenderSettings::RenderMode::DEPTH; - - const ProjectedGaussianSplats state = projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - const torch::Tensor distortionCoeffsForRaster = distortionCoeffs.has_value() - ? distortionCoeffs.value() - : torch::empty({C, 0}, mMeans.options()); - - auto outputs = detail::autograd::RasterizeGaussiansToPixelsFromWorld3DGS::apply( - mMeans, - mQuats, - mLogScales, - state.perGaussianRenderQuantity, - state.perGaussianOpacity, - worldToCameraMatrices, - worldToCameraMatrices, - projectionMatrices, - distortionCoeffsForRaster, - fvdb::detail::ops::RollingShutterType::NONE, - cameraModel, - static_cast(imageWidth), - static_cast(imageHeight), - 0, - 0, - static_cast(tileSize), - state.tileOffsets, - state.tileGaussianIds, - backgrounds, - masks); - return {outputs[0], outputs[1]}; -} - -std::tuple -GaussianSplat3d::renderNumContributingGaussians( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = 0; - settings.tileSize = tileSize; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.renderMode = RenderSettings::RenderMode::DEPTH; - - return renderNumContributingGaussiansImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); -} - -std::tuple -GaussianSplat3d::sparseRenderDepths(const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const std::optional &backgrounds, - const std::optional &masks) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = 0; - settings.tileSize = tileSize; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.renderMode = RenderSettings::RenderMode::DEPTH; - - return sparseRenderImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs, - backgrounds, - masks); -} - -std::tuple -GaussianSplat3d::sparseRenderImages(const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const int64_t shDegreeToUse, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const std::optional &backgrounds, - const std::optional &masks) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = shDegreeToUse; - settings.tileSize = tileSize; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.renderMode = RenderSettings::RenderMode::RGB; - - return sparseRenderImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs, - backgrounds, - masks); -} - -std::tuple -GaussianSplat3d::sparseRenderImagesAndDepths(const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const int64_t shDegreeToUse, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const std::optional &backgrounds, - const std::optional &masks) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = shDegreeToUse; - settings.tileSize = tileSize; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.renderMode = RenderSettings::RenderMode::RGBD; - - return sparseRenderImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs, - backgrounds, - masks); -} - -std::tuple -GaussianSplat3d::sparseRenderNumContributingGaussians( - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = 0; - settings.tileSize = tileSize; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.renderMode = RenderSettings::RenderMode::DEPTH; - - return sparseRenderNumContributingGaussiansImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); -} - -std::tuple -GaussianSplat3d::renderContributingGaussianIds(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const int topKContributors) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = 0; - settings.tileSize = tileSize; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.renderMode = RenderSettings::RenderMode::DEPTH; - settings.numDepthSamples = topKContributors; - - if (topKContributors > 0) { - return renderContributingGaussianIdsImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - } else { - // Use the standard path - compute actual number of contributing gaussians - torch::Tensor numContributingGaussians, weights; - std::tie(numContributingGaussians, weights) = - renderNumContributingGaussiansImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - return renderContributingGaussianIdsImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs, - numContributingGaussians); - } -} - -std::tuple -GaussianSplat3d::sparseRenderContributingGaussianIds( - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const int topKContributors) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = 0; - settings.tileSize = tileSize; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.renderMode = RenderSettings::RenderMode::DEPTH; - settings.numDepthSamples = topKContributors; - - if (topKContributors > 0) { - return sparseRenderContributingGaussianIdsImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - } else { - fvdb::JaggedTensor numContributingGaussians, weights; - std::tie(numContributingGaussians, weights) = - sparseRenderNumContributingGaussiansImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - return sparseRenderContributingGaussianIdsImpl(pixelsToRender, - worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs, - numContributingGaussians); - } -} - -std::tuple -GaussianSplat3d::renderImagesAndDepths(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const int64_t shDegreeToUse, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const std::optional &backgrounds, - const std::optional &masks) { - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = shDegreeToUse; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.tileSize = tileSize; - settings.renderMode = RenderSettings::RenderMode::RGBD; - - const ProjectedGaussianSplats state = projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - return renderCropFromProjectedGaussiansImpl(state, - settings.tileSize, - settings.imageWidth, - settings.imageHeight, - 0, - 0, - backgrounds, - masks); -} - -std::tuple -GaussianSplat3d::renderImagesAndDepthsFromWorld( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const int64_t shDegreeToUse, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias, - const std::optional &backgrounds, - const std::optional &masks) { - FVDB_FUNC_RANGE(); - const int C = worldToCameraMatrices.size(0); - TORCH_CHECK(C > 0, "At least one camera must be provided (got 0)"); - - RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.nearPlane = near; - settings.farPlane = far; - settings.shDegreeToUse = shDegreeToUse; - settings.radiusClip = minRadius2d; - settings.eps2d = eps2d; - settings.antialias = antialias; - settings.tileSize = tileSize; - settings.renderMode = RenderSettings::RenderMode::RGBD; - - const ProjectedGaussianSplats state = projectGaussiansForCameraImpl(worldToCameraMatrices, - projectionMatrices, - settings, - cameraModel, - projectionMethod, - distortionCoeffs); - const torch::Tensor distortionCoeffsForRaster = distortionCoeffs.has_value() - ? distortionCoeffs.value() - : torch::empty({C, 0}, mMeans.options()); - - auto outputs = detail::autograd::RasterizeGaussiansToPixelsFromWorld3DGS::apply( - mMeans, - mQuats, - mLogScales, - state.perGaussianRenderQuantity, - state.perGaussianOpacity, - worldToCameraMatrices, - worldToCameraMatrices, - projectionMatrices, - distortionCoeffsForRaster, - fvdb::detail::ops::RollingShutterType::NONE, - cameraModel, - static_cast(imageWidth), - static_cast(imageHeight), - 0, - 0, - static_cast(tileSize), - state.tileOffsets, - state.tileGaussianIds, - backgrounds, - masks); - return {outputs[0], outputs[1]}; -} - -std::tuple -GaussianSplat3d::relocateGaussians(const torch::Tensor &logScales, - const torch::Tensor &logitOpacities, - const torch::Tensor &ratios, - const torch::Tensor &binomialCoeffs, - const int nMax, - const float minOpacity) { - return FVDB_DISPATCH_KERNEL(logScales.device(), [&]() { - return detail::ops::dispatchGaussianRelocation( - logScales, logitOpacities, ratios, binomialCoeffs, nMax, minOpacity); - }); -} - -void -GaussianSplat3d::addNoiseToMeans(const float noiseScale, const float t, const float k) { - FVDB_DISPATCH_KERNEL(mMeans.device(), [&]() { - return detail::ops::dispatchGaussianMCMCAddNoise( - mMeans, mLogScales, mLogitOpacities, mQuats, noiseScale, t, k); - }); -} - -GaussianSplat3d -GaussianSplat3d::tensorIndexGetImpl(const torch::Tensor &indices) const { - auto ret = GaussianSplat3d(mMeans.index({indices}), - mQuats.index({indices}), - mLogScales.index({indices}), - mLogitOpacities.index({indices}), - mSh0.index({indices}), - mShN.index({indices}), - mAccumulateMean2dGradients, - mAccumulateMax2dRadii, - false); - - if (mAccumulated2dRadiiForGrad.numel() > 0) { - ret.mAccumulated2dRadiiForGrad = mAccumulated2dRadiiForGrad.index({indices}); - } - - if (mGradientStepCountForGrad.numel() > 0) { - ret.mGradientStepCountForGrad = mGradientStepCountForGrad.index({indices}); - } - - if (mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() > 0) { - ret.mAccumulatedNormalized2dMeansGradientNormsForGrad = - mAccumulatedNormalized2dMeansGradientNormsForGrad.index({indices}); - } - - return ret; -} -GaussianSplat3d -GaussianSplat3d::sliceSelect(const int64_t begin, const int64_t stop, const int64_t step) const { - auto slice = torch::indexing::Slice(begin, stop, step); - - auto ret = GaussianSplat3d(mMeans.index({slice}), - mQuats.index({slice}), - mLogScales.index({slice}), - mLogitOpacities.index({slice}), - mSh0.index({slice}), - mShN.index({slice}), - mAccumulateMean2dGradients, - mAccumulateMax2dRadii, - false); - - if (mAccumulated2dRadiiForGrad.numel() > 0) { - ret.mAccumulated2dRadiiForGrad = mAccumulated2dRadiiForGrad.index({slice}); - } - - if (mGradientStepCountForGrad.numel() > 0) { - ret.mGradientStepCountForGrad = mGradientStepCountForGrad.index({slice}); - } - - if (mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() > 0) { - ret.mAccumulatedNormalized2dMeansGradientNormsForGrad = - mAccumulatedNormalized2dMeansGradientNormsForGrad.index({slice}); - } - - return ret; -} - -GaussianSplat3d -GaussianSplat3d::indexSelect(const torch::Tensor &indices) const { - TORCH_CHECK_VALUE(indices.dim() == 1, "indices must be a 1D tensor"); - TORCH_CHECK_VALUE(indices.dtype() == torch::kInt64 || indices.dtype() == torch::kInt32, - "indices must be of type int64 or int32"); - TORCH_CHECK_VALUE(indices.device() == indices.device(), - "indices must be on the same device as the GaussianSplat3d object"); - - return tensorIndexGetImpl(indices); -} - -GaussianSplat3d -GaussianSplat3d::maskSelect(const torch::Tensor &mask) const { - TORCH_CHECK_VALUE(mask.dim() == 1, "mask must be a 1D tensor"); - TORCH_CHECK_VALUE(mask.dtype() == torch::kBool, "mask must be of type bool"); - TORCH_CHECK_VALUE(mask.device() == mMeans.device(), - "mask must be on the same device as the GaussianSplat3d object"); - TORCH_CHECK_VALUE(mask.size(0) == mMeans.size(0), - "mask must have the same size as the number of gaussians"); - - return tensorIndexGetImpl(mask); -} - -void -GaussianSplat3d::tensorIndexSetImpl(const torch::Tensor &indices, const GaussianSplat3d &other) { - mMeans = mMeans.index_put({indices}, other.mMeans); - mQuats = mQuats.index_put({indices}, other.mQuats); - mLogScales = mLogScales.index_put({indices}, other.mLogScales); - mLogitOpacities = mLogitOpacities.index_put({indices}, other.mLogitOpacities); - mSh0 = mSh0.index_put({indices}, other.mSh0); - mShN = mShN.index_put({indices}, other.mShN); - - if (mAccumulated2dRadiiForGrad.numel() > 0) { - if (other.mAccumulated2dRadiiForGrad.numel() > 0) { - // If other is also tracking max 2d radii, make sure we copy them over - mAccumulated2dRadiiForGrad.index_put_({indices}, other.mAccumulated2dRadiiForGrad); - } else { - // If the other does not have accumulated radii, we set it to zero - mAccumulated2dRadiiForGrad.index_put_( - {indices}, - torch::zeros(other.numGaussians(), mAccumulated2dRadiiForGrad.options())); - } - } - - if (mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() > 0) { - if (other.mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() > 0) { - // If other is also tracking accumulated normalized means gradient norms, - // make sure we copy them over - mAccumulatedNormalized2dMeansGradientNormsForGrad.index_put_( - {indices}, other.mAccumulatedNormalized2dMeansGradientNormsForGrad); - } else { - // If the other does not have accumulated normalized means gradient norms, we set it to - // zero - mAccumulatedNormalized2dMeansGradientNormsForGrad.index_put_( - {indices}, - torch::zeros(other.numGaussians(), - mAccumulatedNormalized2dMeansGradientNormsForGrad.options())); - } - } - - if (mGradientStepCountForGrad.numel() > 0) { - if (other.mGradientStepCountForGrad.numel() > 0) { - // If other is also tracking gradient step counts, make sure we copy them over - mGradientStepCountForGrad.index_put_({indices}, other.mGradientStepCountForGrad); - } else { - // If the other does not have gradient step counts, we set it to zero - mGradientStepCountForGrad.index_put_( - {indices}, torch::zeros(other.numGaussians(), mGradientStepCountForGrad.options())); - } - } -} - -void -GaussianSplat3d::sliceSet(const int64_t begin, - const int64_t end, - const int64_t step, - const GaussianSplat3d &other) { - const auto slice = torch::indexing::Slice(begin, end, step); - - mMeans.index({slice}) = other.mMeans; - mQuats.index({slice}) = other.mQuats; - mLogScales.index({slice}) = other.mLogScales; - mLogitOpacities.index({slice}) = other.mLogitOpacities; - mSh0.index({slice}) = other.mSh0; - mShN.index({slice}) = other.mShN; - - if (mAccumulated2dRadiiForGrad.numel() > 0) { - if (other.mAccumulated2dRadiiForGrad.numel() > 0) { - // If other is also tracking max 2d radii, make sure we copy them over - mAccumulated2dRadiiForGrad.index({slice}) = other.mAccumulated2dRadiiForGrad; - } else { - // If the other does not have accumulated radii, we set it to zero - mAccumulated2dRadiiForGrad.index({slice}) = - torch::zeros(other.numGaussians(), mAccumulated2dRadiiForGrad.options()); - } - } - - if (mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() > 0) { - if (other.mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() > 0) { - // If other is also tracking accumulated normalized means gradient norms, - // make sure we copy them over - mAccumulatedNormalized2dMeansGradientNormsForGrad.index({slice}) = - other.mAccumulatedNormalized2dMeansGradientNormsForGrad; - } else { - // If the other does not have accumulated normalized means gradient norms, we set it to - // zero - mAccumulatedNormalized2dMeansGradientNormsForGrad.index({slice}) = torch::zeros( - other.numGaussians(), mAccumulatedNormalized2dMeansGradientNormsForGrad.options()); - } - } - - if (mGradientStepCountForGrad.numel() > 0) { - if (other.mGradientStepCountForGrad.numel() > 0) { - // If other is also tracking gradient step counts, make sure we copy them over - mGradientStepCountForGrad.index({slice}) = other.mGradientStepCountForGrad; - } else { - // If the other does not have gradient step counts, we set it to zero - mGradientStepCountForGrad.index({slice}) = - torch::zeros(other.numGaussians(), mGradientStepCountForGrad.options()); - } - } -} - -void -GaussianSplat3d::indexSet(const torch::Tensor &indices, const GaussianSplat3d &other) { - TORCH_CHECK_VALUE(indices.dim() == 1, "indices must be a 1D tensor"); - TORCH_CHECK_VALUE(indices.dtype() == torch::kInt64 || indices.dtype() == torch::kInt32, - "indices must be of type int64 or int32"); - TORCH_CHECK_VALUE(indices.device() == indices.device(), - "indices must be on the same device as the GaussianSplat3d object"); - - tensorIndexSetImpl(indices, other); -} - -void -GaussianSplat3d::maskSet(const torch::Tensor &mask, const GaussianSplat3d &other) { - TORCH_CHECK_VALUE(mask.dim() == 1, "mask must be a 1D tensor"); - TORCH_CHECK_VALUE(mask.dtype() == torch::kBool, "mask must be of type bool"); - TORCH_CHECK_VALUE(mask.device() == mMeans.device(), - "mask must be on the same device as the GaussianSplat3d object"); - TORCH_CHECK_VALUE(mask.size(0) == mMeans.size(0), - "mask must have the same size as the number of gaussians"); - - tensorIndexSetImpl(mask, other); -} - -// TODO: Make a batched class -std::tuple> -gaussianRenderJagged(const JaggedTensor &means, // [N1 + N2 + ..., 3] - const JaggedTensor &quats, // [N1 + N2 + ..., 4] - const JaggedTensor &scales, // [N1 + N2 + ..., 3] - const JaggedTensor &opacities, // [N1 + N2 + ...] - const JaggedTensor &sh_coeffs, // [N1 + N2 + ..., K, 3] - const JaggedTensor &viewmats, // [C1 + C2 + ..., 4, 4] - const JaggedTensor &Ks, // [C1 + C2 + ..., 3, 3] - const uint32_t image_width, - const uint32_t image_height, - const float near_plane, - const float far_plane, - const int sh_degree_to_use, - const int tile_size, - const float radius_clip, - const float eps2d, - const bool antialias, - const bool render_depth_channel, - const bool return_debug_info, - const bool render_depth_only, - const bool ortho, - const std::optional &backgrounds, - const std::optional &masks) { - const int ccz = viewmats.rsize(0); // number of cameras - const int ggz = means.rsize(0); // number of gaussians - const int D = render_depth_only ? 1 : sh_coeffs.rsize(-1); // Dimension of output - - using namespace torch::indexing; // For the Slice operation - - TORCH_CHECK(means.rsizes() == torch::IntArrayRef({ggz, 3}), "means must have shape (ggz, 3)"); - TORCH_CHECK(quats.rsizes() == torch::IntArrayRef({ggz, 4}), "quats must have shape (ggz, 4)"); - TORCH_CHECK(scales.rsizes() == torch::IntArrayRef({ggz, 3}), "scales must have shape (ggz, 3)"); - TORCH_CHECK(opacities.rsizes() == torch::IntArrayRef({ggz}), "opacities must have shape (ggz)"); - TORCH_CHECK(viewmats.rsizes() == torch::IntArrayRef({ccz, 4, 4}), - "viewmats must have shape (C, 4, 4)"); - TORCH_CHECK(Ks.rsizes() == torch::IntArrayRef({ccz, 3, 3}), "Ks must have shape (ccz, 3, 3)"); - - TORCH_CHECK(means.is_contiguous(), "means must be contiguous"); - TORCH_CHECK(quats.is_contiguous(), "quats must be contiguous"); - TORCH_CHECK(scales.is_contiguous(), "scales must be contiguous"); - TORCH_CHECK(opacities.is_contiguous(), "opacities must be contiguous"); - TORCH_CHECK(viewmats.is_contiguous(), "viewmats must be contiguous"); - TORCH_CHECK(Ks.is_contiguous(), "Ks must be contiguous"); - - // Check after we dispatch the unbatched version since the unbatched version accepts a - // [K, N, D] tensor for sh_coeffs while the batched version accepts a [ggz, K, D] tensor, - // which gets permuted later on. - const int K = render_depth_only ? 1 : sh_coeffs.rsize(-2); // number of SH bases - TORCH_CHECK(render_depth_only || sh_coeffs.rsizes() == torch::IntArrayRef({ggz, K, D}), - "sh_coeffs must have shape (ggz, K, D)"); - - // TODO: this part is very convoluted. But I don't have a better way of coding it without - // customized CUDA kernels. The idea is that given Gaussians with shape [\sum(N_i), ...] and - // cameras with shape [\sum(C_i), ...], we would calculate the intersection of each Gaussian - // with each camera, which result in a JaggedTensor with shape - // [\sum(C_i * N_i), ...]. And I need to keep track of the camera and Gaussian IDs (the index in - // the jagged tensor) for each intersection: - // - camera_ids: Shape of [\sum(C_i * N_i), ...], with each value \in [0, \sum(C_i)) - // - gaussian_ids: Shape of [\sum(C_i * N_i), ...], with each value \in [0, \sum(N_i)) - - // g_sizes is [N1, N2, ...] - torch::Tensor g_sizes = - means.joffsets().index({Slice(1, None)}) - means.joffsets().index({Slice(0, -1)}); - // c_sizes is [C1, C2, ...] - torch::Tensor c_sizes = - Ks.joffsets().index({Slice(1, None)}) - Ks.joffsets().index({Slice(0, -1)}); - // camera_ids is [0, 0, ..., 1, 1, ...] - torch::Tensor tt = g_sizes.repeat_interleave(c_sizes); - torch::Tensor camera_ids = - torch::arange(viewmats.rsize(0), means.options().dtype(torch::kInt32)) - .repeat_interleave(tt, 0); - // gaussian_ids is [0, 1, ..., 0, 1, ...] - torch::Tensor dd0 = means.joffsets().index({Slice(0, -1)}).repeat_interleave(c_sizes, 0); - torch::Tensor dd1 = means.joffsets().index({Slice(1, None)}).repeat_interleave(c_sizes, 0); - torch::Tensor shifts = dd0.index({Slice(1, None)}) - dd1.index({Slice(0, -1)}); - shifts = torch::cat({torch::tensor({0}, means.device()), shifts}); - torch::Tensor shifts_cumsum = shifts.cumsum(0); - torch::Tensor gaussian_ids = - torch::arange(camera_ids.size(0), means.options().dtype(torch::kInt32)); - gaussian_ids += shifts_cumsum.repeat_interleave(tt, 0); - - // Project to image plane [differentiable] - auto projection_results = detail::autograd::ProjectGaussiansJagged::apply(g_sizes, - means.jdata(), - quats.jdata(), - scales.jdata(), - c_sizes, - viewmats.jdata(), - Ks.jdata(), - image_width, - image_height, - eps2d, - near_plane, - far_plane, - radius_clip, - ortho); - torch::Tensor radii = projection_results[0]; - torch::Tensor means2d = projection_results[1]; - torch::Tensor depths = projection_results[2]; - torch::Tensor conics = projection_results[3]; - - // Turn [N1 + N2 + N3 + ..., ...] into [C1*N1 + C2*N2 + ..., ...] - torch::Tensor opacities_batched = opacities.jdata().index({gaussian_ids}); // [M] - if (antialias) { - opacities_batched *= projection_results[4]; - } - - std::unordered_map debug_info; - if (return_debug_info) { - debug_info["camera_ids"] = camera_ids; - debug_info["gaussian_ids"] = gaussian_ids; - debug_info["radii"] = radii; - debug_info["means2d"] = means2d; - debug_info["depths"] = depths; - debug_info["conics"] = conics; - debug_info["opacities"] = opacities_batched; - } - - torch::Tensor renderQuantities; - if (render_depth_only) { - renderQuantities = depths.index({gaussian_ids}).unsqueeze(-1); // [nnz, 1] - } else { - // Render quantities from SH coefficients [differentiable] - const torch::Tensor sh_coeffs_batched = sh_coeffs.jdata().permute({1, 0, 2}).index( - {Slice(), gaussian_ids, Slice()}); // [K, nnz, 3] - - const int K = sh_coeffs_batched.size(0); // number of SH bases - const int actualShDegree = sh_degree_to_use < 0 ? (std::sqrt(K) - 1) : sh_degree_to_use; - TORCH_CHECK(K >= (actualShDegree + 1) * (actualShDegree + 1), - "K must be at least (shDegreeToUse + 1)^2"); - - if (actualShDegree == 0) { - const auto sh0 = - sh_coeffs_batched.index({0, Slice(), Slice()}).unsqueeze(0); // [1, nnz, 3] - renderQuantities = - detail::autograd::EvaluateSphericalHarmonics::apply(actualShDegree, - 1, - torch::nullopt, - sh0.permute({1, 0, 2}), - torch::nullopt, - radii.unsqueeze(0))[0]; - } else { - const auto sh0 = - sh_coeffs_batched.index({0, Slice(), Slice()}).unsqueeze(0); // [1, nnz, 3] - const auto shN = - sh_coeffs_batched.index({Slice(1, None), Slice(), Slice()}); // [K-1, nnz, 3] - auto [camtoworlds, info] = torch::linalg_inv_ex(viewmats.jdata()); // [ccz, 4, 4] - const torch::Tensor dirs = means.jdata().index({gaussian_ids, Slice()}) - - camtoworlds.index({camera_ids, Slice(None, 3), 3}); - renderQuantities = - detail::autograd::EvaluateSphericalHarmonics::apply(actualShDegree, - 1, - dirs.unsqueeze(0), - sh0.permute({1, 0, 2}), - shN.permute({1, 0, 2}), - radii.unsqueeze(0))[0] - .squeeze(0); - } - - if (render_depth_channel) { - renderQuantities = - torch::cat({renderQuantities, depths.index({gaussian_ids}).unsqueeze(-1)}, -1); - } - } - - // Intersect projected Gaussians with image tiles [non-differentiable] - const int num_tiles_w = std::ceil(image_width / static_cast(tile_size)); - const int num_tiles_h = std::ceil(image_height / static_cast(tile_size)); - std::tuple tile_intersections = - FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() { - return detail::ops::dispatchGaussianTileIntersection( - means2d, radii, depths, camera_ids, ccz, tile_size, num_tiles_h, num_tiles_w); - }); - torch::Tensor tile_offsets = std::get<0>(tile_intersections); - torch::Tensor tile_gaussian_ids = std::get<1>(tile_intersections); - if (return_debug_info) { - debug_info["tile_offsets"] = tile_offsets; - debug_info["tile_gaussian_ids"] = tile_gaussian_ids; - } - - // Rasterize projected Gaussians to pixels [differentiable] - auto outputs = - detail::autograd::RasterizeGaussiansToPixels::apply(means2d, - conics, - renderQuantities, - opacities_batched.contiguous(), - image_width, - image_height, - 0, - 0, - tile_size, - tile_offsets, - tile_gaussian_ids, - false, - backgrounds, - masks); - torch::Tensor renderedImages = outputs[0]; - torch::Tensor renderedAlphaImages = outputs[1]; - - return {renderedImages, renderedAlphaImages, debug_info}; -} - -} // namespace fvdb diff --git a/src/fvdb/GaussianSplat3d.h b/src/fvdb/GaussianSplat3d.h deleted file mode 100644 index 044fa3072..000000000 --- a/src/fvdb/GaussianSplat3d.h +++ /dev/null @@ -1,1678 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_GAUSSIANSPLAT3D_H -#define FVDB_GAUSSIANSPLAT3D_H - -#include -#include -#include - -#include -#include -#include - -#include - -namespace fvdb { - -/// @brief A class representing a Gaussian splat scene in 3D space. -/// This class is used to store the parameters of the Gaussians in the scene and provides -/// methods to project the Gaussians onto a 2D image plane, render images and depths, -/// and save the scene to a PLY file. -/// The Gaussians are represented by their means, quaternions (for rotation), log scales, -/// logit opacities, and SH coefficients. We use log_scales and logit_opacities since we can -/// optimize these quantities without clipping them to a specific range. -class GaussianSplat3d { - public: - /// Magic string prepended to additional metadata properties stored in PLY files - inline static const std::string PLY_MAGIC = "fvdb_ply_af_8198767135"; - - /// We won't allow keys in a PLY file longer than this many characters. - inline static const size_t MAX_PLY_KEY_LENGTH = 256; - - inline static const std::string PLY_VERSION_STRING = "fvdb_ply 1.0.0"; - - using PlyMetadataTypes = std::variant; - - GaussianSplat3d(const torch::Tensor &means, - const torch::Tensor &quats, - const torch::Tensor &logScales, - const torch::Tensor &logitOpacities, - const torch::Tensor &sh0, - const torch::Tensor &shN, - const bool accumulateMean2dGradients, - const bool accumulateMax2dRadii, - const bool detach); - - /// @brief Create a GaussianSplat3d object from a state_dict (similar to Pytorch's nn.Module). - /// @param stateDict A dictionary containing the state of the GaussianSplat3d object. - /// @return A GaussianSplat3d object created from the state_dict. - GaussianSplat3d(const std::unordered_map &stateDict) { - loadStateDict(stateDict); - } - - using CameraModel = fvdb::detail::ops::DistortionModel; - using ProjectionMethod = fvdb::detail::ops::ProjectionMethod; - - /// @brief A set of projected Gaussians that can be used to render images. - struct ProjectedGaussianSplats { - torch::Tensor perGaussian2dMean; // [C, N, 2] - torch::Tensor perGaussianConic; // [C, N, 3] - torch::Tensor perGaussianRenderQuantity; // [C, N, 3] - torch::Tensor perGaussianDepth; // [C, N, 1] - torch::Tensor perGaussianOpacity; // [N] or [C, N] if antialias is true - torch::Tensor perGaussianRadius; // [C, N] - torch::Tensor tileOffsets; // [C, num_tiles_h, num_tiles_w, 2] - torch::Tensor tileGaussianIds; // [C, num_tiles_h, num_tiles_w, max_gaussians_per_tile] - - fvdb::detail::ops::RenderSettings mRenderSettings; - CameraModel mCameraModel = CameraModel::PINHOLE; - ProjectionMethod mProjectionMethod = ProjectionMethod::ANALYTIC; - - ssize_t - imageHeight() const { - return mRenderSettings.imageHeight; - } - - ssize_t - imageWidth() const { - return mRenderSettings.imageWidth; - } - - float - nearPlane() const { - return mRenderSettings.nearPlane; - } - - float - farPlane() const { - return mRenderSettings.farPlane; - } - - CameraModel - cameraModel() const { - return mCameraModel; - } - - ProjectionMethod - projectionMethod() const { - return mProjectionMethod; - } - - int64_t - shDegreeToUse() const { - return mRenderSettings.shDegreeToUse; - } - - float - minRadius2d() const { - return mRenderSettings.radiusClip; - } - - float - eps2d() const { - return mRenderSettings.eps2d; - } - - bool - antialias() const { - return mRenderSettings.antialias; - } - - torch::Tensor - means2d() const { - return perGaussian2dMean; - } - - torch::Tensor - conics() const { - return perGaussianConic; - } - - torch::Tensor - renderQuantities() const { - return perGaussianRenderQuantity; - } - - torch::Tensor - depths() const { - return perGaussianDepth; - } - - torch::Tensor - opacities() const { - if (perGaussianOpacity.dim() == 1) { - // perGaussianOpacity is [N]; expand to [C, N] view - const int64_t C = perGaussian2dMean.size(0); - return perGaussianOpacity.unsqueeze(0).expand({C, -1}); - } - // Already [C, N] (e.g. antialias case where compensation varies per camera) - return perGaussianOpacity; - } - - torch::Tensor - radii() const { - return perGaussianRadius; - } - - torch::Tensor - offsets() const { - return tileOffsets; - } - - torch::Tensor - gaussianIds() const { - return tileGaussianIds; - } - }; - - /// @brief A set of projected Gaussians with sparse tile intersection data for sparse rendering. - /// This struct extends ProjectedGaussianSplats with additional sparse-specific tensors. - struct SparseProjectedGaussianSplats : public ProjectedGaussianSplats { - torch::Tensor activeTiles; // [num_active_tiles] - tile IDs of active tiles - torch::Tensor activeTileMask; // [C, TH, TW] - boolean mask of active tiles - torch::Tensor tilePixelMask; // [num_active_tiles, words_per_tile] - bitmask of pixels - torch::Tensor tilePixelCumsum; // [num_active_tiles] - cumulative sum of active pixels - torch::Tensor pixelMap; // [num_active_pixels] - mapping for pixel write order - // Note: tileOffsets (inherited) is 1D [num_active_tiles + 1] in sparse mode - - // Duplicate pixel handling: when pixelsToRender contains duplicates, we deduplicate - // before rendering and scatter results back. inverseIndices maps each original pixel - // position to its corresponding unique pixel index. - torch::Tensor inverseIndices; // [total_pixels] maps original -> unique index - JaggedTensor uniquePixelsToRender; // deduplicated pixels passed to computeSparseInfo - bool hasDuplicates = false; - }; - - public: - /// @brief Concatenate a vector of GaussianSplat3d objects into a single GaussianSplat3d object. - /// @param splats A vector of GaussianSplat3d objects to concatenate. - /// @param accumulateMean2dGradients Whether to accumulate the mean 2D gradients for each - /// Gaussian. - /// For splats that do not have mean2d gradients, zeros will be copied to the means2d - /// gradient norm state in the output. - /// @param accumulateMax2dRadii Whether to accumulate the maximum 2D radii for each Gaussian. - /// For splats that do not have 2D radii, zeros will be copied to the radii state in the - /// output. - /// @return A new GaussianSplat3d object that is the concatenation of the input splats. - static GaussianSplat3d - cat(const std::vector &splats, - bool accumulateMean2dGradients, - bool accumulateMax2dRadii, - bool detach) { - TORCH_CHECK_VALUE(!splats.empty(), "Cannot concatenate an empty vector of splats"); - - std::vector meansVec, quatsVec, logScalesVec, logitOpacitiesVec, sh0Vec, - shNVec; - - std::vector accStepCountsVec, accMax2dRadiiVec, accNorm2dMeansGradientsVec; - - const auto device = splats[0].device(); - const auto dtype = splats[0].scalarType(); - - for (const auto &splat: splats) { - TORCH_CHECK_VALUE(splat.device() == device, "All splats must be on the same device"); - TORCH_CHECK_VALUE(splat.scalarType() == dtype, "All splats must be of the same type"); - - meansVec.push_back(splat.mMeans); - quatsVec.push_back(splat.mQuats); - logScalesVec.push_back(splat.mLogScales); - logitOpacitiesVec.push_back(splat.mLogitOpacities); - sh0Vec.push_back(splat.mSh0); - shNVec.push_back(splat.mShN); - - const auto N = splat.numGaussians(); - if (accumulateMean2dGradients) { - auto [accNorm2dMeansGradients, accGradientStepCounts] = [&]() { - if (splat.mAccumulatedNormalized2dMeansGradientNormsForGrad.defined()) { - TORCH_CHECK( - splat.mAccumulatedNormalized2dMeansGradientNormsForGrad.numel() == N, - "accumulated_mean_2d_gradient_norms_for_grad must have shape (N)"); - TORCH_CHECK( - splat.mAccumulatedNormalized2dMeansGradientNormsForGrad.device() == - splat.device(), - "accumulated_mean_2d_gradient_norms_for_grad must be on the same device as " - "means"); - TORCH_CHECK(splat.mGradientStepCountForGrad.defined(), - "gradient_step_counts_for_grad must be non-empty if " - "accumulated_mean_2d_gradient_norms_for_grad is non-empty"); - TORCH_CHECK( - splat.mGradientStepCountForGrad.numel() == N, - "accumulated_gradient_step_counts_for_grad must have shape (N)"); - return std::make_tuple( - splat.mAccumulatedNormalized2dMeansGradientNormsForGrad, - splat.mGradientStepCountForGrad); - } else { - return std::make_tuple(torch::zeros({N}, splat.mMeans.options()), - torch::zeros({N}, torch::kInt32).to(splat.device())); - } - }(); - accNorm2dMeansGradientsVec.push_back(accNorm2dMeansGradients); - accStepCountsVec.push_back(accGradientStepCounts); - } - if (accumulateMax2dRadii) { - if (splat.mAccumulated2dRadiiForGrad.defined()) { - TORCH_CHECK(splat.mAccumulated2dRadiiForGrad.numel() == N, - "accumulated_max_2d_radii_for_grad must have shape (N)"); - TORCH_CHECK( - splat.mAccumulated2dRadiiForGrad.device() == splat.device(), - "accumulated_max_2d_radii_for_grad must be on the same device as means"); - accMax2dRadiiVec.push_back(splat.mAccumulated2dRadiiForGrad); - } else { - accMax2dRadiiVec.push_back(torch::zeros({N}, torch::kInt32).to(splat.device())); - } - } - } - - torch::Tensor meansCat = torch::cat(meansVec, 0); - torch::Tensor quatsCat = torch::cat(quatsVec, 0); - torch::Tensor logScalesCat = torch::cat(logScalesVec, 0); - torch::Tensor logitOpacitiesCat = torch::cat(logitOpacitiesVec, 0); - torch::Tensor sh0Cat = torch::cat(sh0Vec, 0); - torch::Tensor shNCat = torch::cat(shNVec, 0); - - auto ret = GaussianSplat3d(meansCat, - quatsCat, - logScalesCat, - logitOpacitiesCat, - sh0Cat, - shNCat, - accumulateMean2dGradients, - accumulateMax2dRadii, - detach); - - if (accumulateMean2dGradients) { - auto catNorm2dGradMeans = torch::cat(accNorm2dMeansGradientsVec, 0); - if (detach) { - catNorm2dGradMeans = catNorm2dGradMeans.detach(); - } - auto catStepCounts = torch::cat(accStepCountsVec, 0); - if (detach) { - catStepCounts = catStepCounts.detach(); - } - ret.mAccumulatedNormalized2dMeansGradientNormsForGrad = catNorm2dGradMeans; - ret.mGradientStepCountForGrad = catStepCounts; - } - if (accumulateMax2dRadii) { - auto catMax2dRadii = torch::cat(accMax2dRadiiVec, 0); - if (detach) { - catMax2dRadii = catMax2dRadii.detach(); - } - ret.mAccumulated2dRadiiForGrad = catMax2dRadii; - } - - return ret; - } - - /// @brief Get the device this Gaussian splat is on. - /// @return The device of the means tensor. - torch::Device - device() const { - TORCH_CHECK(mMeans.device() == mQuats.device(), - "All tensors must be on the same device. Means and quats must match."); - TORCH_CHECK(mMeans.device() == mLogScales.device(), - "All tensors must be on the same device. Means and log scales must match."); - TORCH_CHECK( - mMeans.device() == mLogitOpacities.device(), - "All tensors must be on the same device. Means and logit opacities must match."); - TORCH_CHECK(mMeans.device() == mSh0.device(), - "All tensors must be on the same device. Means and SH0 must match."); - TORCH_CHECK(mMeans.device() == mShN.device(), - "All tensors must be on the same device. Means and SHN must match."); - return mMeans.device(); - } - - /// @brief Get the scalar type of the tensors in this Gaussian splat. - /// @return The scalar type of the means tensor. - /// All tensors are expected to have the same scalar type. - torch::ScalarType - scalarType() const { - TORCH_CHECK(mMeans.scalar_type() == mQuats.scalar_type(), - "All tensors must be of the same type. Means and quats must match."); - TORCH_CHECK(mMeans.scalar_type() == mLogScales.scalar_type(), - "All tensors must be of the same type. Means and log scales must match."); - TORCH_CHECK(mMeans.scalar_type() == mLogitOpacities.scalar_type(), - "All tensors must be of the same type. Means and logit opacities must match."); - TORCH_CHECK(mMeans.scalar_type() == mSh0.scalar_type(), - "All tensors must be of the same type. Means and SH0 must match."); - TORCH_CHECK(mMeans.scalar_type() == mShN.scalar_type(), - "All tensors must be of the same type. Means and SHN must match."); - return mMeans.scalar_type(); - } - - /// @brief Return the means of the Gaussians in this scene. - /// @return An [N, 3]-shaped tensor representing the means of the Gaussians in this scenes. - torch::Tensor - means() const { - return mMeans; - } - - /// @brief Return the quaternions of the Gaussians in this scene which define the rotation - /// component of the covariance of each Gaussian (in the form [x, y, z, w]). - /// @return An [N, 4]-shaped tensor representing the quaternions of the Gaussians in this scene. - torch::Tensor - quats() const { - return mQuats; - } - - /// @brief Return the log of the scales of the Gaussians in this scene. - /// @return An [N]-shaped tensor representing the log of the scales of the - /// Gaussians in this scene. - torch::Tensor - logScales() const { - return mLogScales; - } - - /// @brief Return the logit (inverse of Sigmoid) of the opacities of the Gaussians in this - /// scene. - /// @return An [N]-shaped tensor representing the logit of the opacities of the - /// Gaussians in this scene. - torch::Tensor - logitOpacities() const { - return mLogitOpacities; - } - - /// @brief Return the diffuse SH coefficients of the Gaussians in this scene - /// @return An [N, 1, D]-shaped tensor representing the diffuse SH coefficients of the - /// Gaussians in this scene. - torch::Tensor - sh0() const { - return mSh0; - } - - /// @brief Return the directionally-dependent SH coefficients of the Gaussians in this scene - /// @return A [N, K-1, D]-shaped tensor representing the directionally-dependent SH - /// coefficients of the Gaussians in this scene. - torch::Tensor - shN() const { - return mShN; - } - - /// @brief Return the scales of the Gaussians in this scene. - /// @return An [N, 3]-shaped tensor representing the scales of the Gaussians in this scene. - /// (i.e. exp(logScales)). - torch::Tensor - scales() const { - return torch::exp(mLogScales); - } - - /// @brief Return the opacities of the Gaussians in this scene. - /// @return An [N]-shaped tensor representing the opacities of the Gaussians in this scene. - /// (i.e. sigmoid(logitOpacities)). - torch::Tensor - opacities() const { - return torch::sigmoid(mLogitOpacities); - } - - int64_t - shDegree() const { - // The SH degree is determined by the number of SH coefficients in shN. - // If shN is empty, we return -1 to indicate that no SH coefficients are used. - const auto K = mShN.size(1) + 1; // number of SH bases - const auto shDegree = static_cast(std::sqrt(K) - 1); - return shDegree; - } - - /// @brief Return a copy of this GaussianSplat3d object with the same parameters, but detached - /// from the computation graph. - /// @return A new instance of GaussianSplat3d with the same parameters, but detached from the - /// computation graph. - GaussianSplat3d - detach() const { - return GaussianSplat3d(mMeans.detach(), - mQuats.detach(), - mLogScales.detach(), - mLogitOpacities.detach(), - mSh0.detach(), - mShN.detach(), - mAccumulateMean2dGradients, - mAccumulateMax2dRadii, - false); - } - - /// @brief Return a copy of this GaussianSplat3d object with the same parameters, but moved to - /// the specified device and dtype, or return *this if the device and dtype match this. - /// @param device The device to move the tensors to. - /// @param dtype The data type to convert the tensors to. - /// @return A new instance of GaussianSplat3d with the same parameters, but moved to the - /// specified device and dtype, or *this if the device and dtype match this. - GaussianSplat3d - to(torch::Device device, torch::ScalarType dtype) { - if (this->device() == device && this->scalarType() == dtype) { - return *this; // No need to copy if already on the right device and type - } else { - auto ret = GaussianSplat3d( - mMeans.to(device, dtype), - mQuats.to(device, dtype), - mLogScales.to(device, dtype), - mLogitOpacities.to(device, dtype), - mSh0.to(device, dtype), - mShN.to(device, dtype), - mAccumulateMean2dGradients, - mAccumulateMax2dRadii, - false // Detach is false since we are copying the data (not detaching it) - ); - if (mAccumulated2dRadiiForGrad.defined()) { - ret.mAccumulated2dRadiiForGrad = mAccumulated2dRadiiForGrad.to(device); - } - if (mAccumulatedNormalized2dMeansGradientNormsForGrad.defined()) { - ret.mAccumulatedNormalized2dMeansGradientNormsForGrad = - mAccumulatedNormalized2dMeansGradientNormsForGrad.to(device, dtype); - } - if (mGradientStepCountForGrad.defined()) { - ret.mGradientStepCountForGrad = mGradientStepCountForGrad.to(device); - } - return ret; - } - } - - /// @brief Detach the parameters of this GaussianSplat3d object in place. - /// This will detach the parameters from the computation graph, allowing them to be - /// modified without affecting the gradients of the original tensors. - void - detachInPlace() { - mMeans.detach_(); - mQuats.detach_(); - mLogScales.detach_(); - mLogitOpacities.detach_(); - mSh0.detach_(); - mShN.detach_(); - } - - /// @brief Set the log of the opacities of the Gaussians in this scene. - /// @param logitOpacities An [N]-shaped tensor representing the log of the opacities of the - /// Gaussians in this scene. - void - setLogitOpacities(const torch::Tensor &logitOpacities) { - TORCH_CHECK_VALUE(logitOpacities.sizes() == mLogitOpacities.sizes(), - "logit_opacities must have the same shape as the current opacities"); - TORCH_CHECK_VALUE( - logitOpacities.device() == mLogitOpacities.device(), - "logit_opacities must be on the same device as the current logit_opacities"); - mLogitOpacities = logitOpacities; - } - - /// @brief Set the log of the scales of the Gaussians in this scene. - /// @param logScales An [N, 3]-shaped tensor representing the log of the scales of the - void - setLogScales(const torch::Tensor &logScales) { - TORCH_CHECK_VALUE(logScales.sizes() == mLogScales.sizes(), - "log_scales must have the same shape as the current scales"); - TORCH_CHECK_VALUE(logScales.device() == mLogScales.device(), - "log_scales must be on the same device as the current log_scales"); - mLogScales = logScales; - } - - /// @brief Set the quaternions of the Gaussians in this scene which define the rotation - /// component of the covariance of each Gaussian (in the form [x, y, z, w]). - /// @param quats An [N, 4]-shaped tensor representing the quaternions of the Gaussians in this - /// scene. - void - setQuats(const torch::Tensor &quats) { - TORCH_CHECK_VALUE(quats.sizes() == mQuats.sizes(), - "quats must have the same shape as the current quats"); - TORCH_CHECK_VALUE(quats.device() == mQuats.device(), - "quats must be on the same device as the current quats"); - mQuats = quats; - } - - /// @brief Set the means of the Gaussians in this scene. - /// @param means An [N, 3]-shaped tensor representing the means of the Gaussians in this scene. - void - setMeans(const torch::Tensor &means) { - TORCH_CHECK_VALUE(means.sizes() == mMeans.sizes(), - "means must have the same shape as the current means"); - TORCH_CHECK_VALUE(means.device() == mMeans.device(), - "means must be on the same device as the current means"); - mMeans = means; - } - - /// @brief Set the diffuse SH coefficients of the Gaussians in this scene. - /// @param sh0 An [N, 1, D]-shaped tensor representing the diffuse SH coefficients of the - /// Gaussians in this scene. - void - setSh0(const torch::Tensor &sh0) { - TORCH_CHECK_VALUE(sh0.sizes() == mSh0.sizes(), - "sh0 must have the same shape as the current sh0"); - TORCH_CHECK_VALUE(sh0.device() == mSh0.device(), - "sh0 must be on the same device as the current sh0"); - mSh0 = sh0; - } - - /// @brief Set the directionally-dependent SH coefficients of the Gaussians in this scene. - /// @param shN A [N, K-1, D]-shaped tensor representing the directionally-dependent SH - /// coefficients of the Gaussians in this scene. - void - setShN(const torch::Tensor &shN) { - TORCH_CHECK_VALUE(shN.sizes() == mShN.sizes(), - "shN must have the same shape as the current shN"); - TORCH_CHECK_VALUE(shN.device() == mShN.device(), - "shN must be on the same device as the current shN"); - mShN = shN; - } - - /// @brief Return whether to track the maximum 2D radii of each Gaussian over backward passes - /// of projection. - /// @return True if the maximum 2D radii are tracked, false otherwise. - /// @note This is used by some optimizers to decide whether to split/delete/duplicate Gaussians. - /// If this is set to true, the maximum 2D radii will be accumulated during the - /// backward pass of projection. - bool - accumulateMax2dRadii() const { - return mAccumulateMax2dRadii; - } - - /// @brief Set whether to accumulate the maximum 2D radii of each Gaussian over backward passes - /// of projection. - /// @param accumulateMax2dRadii Whether to accumulate the maximum 2D radii of each Gaussian - /// over backward passes of projection. - /// @note This is used by some optimizers to decide whether to split/delete/duplicate Gaussians. - /// If this is set to true, the maximum 2D radii will be accumulated during the - /// backward pass of projection. - void - setAccumulateMax2dRadii(bool accumulateMax2dRadii) { - mAccumulateMax2dRadii = accumulateMax2dRadii; - } - - /// @brief Return whether to accumulate the means 2D gradients of each Gaussian over backward - /// passes of projection. - /// @return True if the means 2D gradients are accumulated, false otherwise. - /// @note This is used by some optimizers to decide whether to split/delete/duplicate Gaussians. - /// If this is set to true, the average norm of the gradient of projected means - /// for each Gaussian will be accumulated during the backward pass of projection. - /// This is useful for some optimization techniques. - bool - accumulateMean2dGradients() const { - return mAccumulateMean2dGradients; - } - - /// @brief Set whether to accumulate the means 2D gradients of each Gaussian over backward - /// passes of projection. - /// @param accumulateMean2dGradients Whether to accumulate the means 2D gradients - /// of each Gaussian over backward passes of projection. - /// @note This is used by some optimizers to decide whether to split/delete/duplicate Gaussians. - /// If this is set to true, the average norm of the gradient - /// of projected means for each Gaussian will be accumulated during the backward pass - /// of projection. This is useful for some optimization techniques. - void - setAccumulateMean2dGradients(bool accumulateMean2dGradients) { - mAccumulateMean2dGradients = accumulateMean2dGradients; - } - - /// @brief Return true if all tensors tracked by this object require gradients. - /// @return True if all tensors tracked by this object require gradients, false otherwise. - /// @note This function checks if all tensors are leaf tensors and have requires_grad set to - /// true. - /// If any of the tensors are non-leaf tensors, this function will return false. - /// If you want to check if the tensors require gradients individually, you can use - /// the `requires_grad()` method on each tensor directly. - /// @note If you want to ensure all tensors are leaf tensors, you can create a - /// new GaussianSplat3d object with the `detach` flag set to `True` when - /// creating the object. - bool - requiresGrad() const { - return mMeans.requires_grad() && mQuats.requires_grad() && mLogScales.requires_grad() && - mLogitOpacities.requires_grad() && mSh0.requires_grad() && mShN.requires_grad(); - } - - /// @brief Set requires_grad on all tensors managed by this object. - /// @param requiresGrad Whether the tensors should require gradients. - /// @note This function will throw an error if any of the tensors are non-leaf tensors. - /// If you want to set requires_grad on specific tensors, set them on the tensors directly - /// instead of using this function. - /// @note If you want to ensure all tensors are leaf tensors, you can call .detach() or create a - /// new GaussianSplat3d object with the `detach` flag set to `True` when - /// creating the object. - void - setRequiresGrad(bool requiresGrad) { - TORCH_CHECK_VALUE( - mMeans.is_leaf(), - "Cannot set requires_grad of means which is a non-leaf tensor. " - "Call .detach() on this object or create a new GaussianSplat3d object with leaf tensors."); - - TORCH_CHECK_VALUE( - mQuats.is_leaf(), - "Cannot set requires_grad of quats which is a non-leaf tensor. " - "Call .detach() on this object or create a new GaussianSplat3d object with leaf tensors."); - - TORCH_CHECK_VALUE( - mLogScales.is_leaf(), - "Cannot set requires_grad of log_scales which is a non-leaf tensor. " - "Call .detach() on this object or create a new GaussianSplat3d object with leaf tensors."); - - TORCH_CHECK_VALUE( - mLogitOpacities.is_leaf(), - "Cannot set requires_grad of logit_opacities which is a non-leaf tensor. " - "Call .detach() on this object or create a new GaussianSplat3d object with leaf tensors."); - - TORCH_CHECK_VALUE( - mSh0.is_leaf(), - "Cannot set requires_grad of sh0 which is a non-leaf tensor. " - "Call .detach() on this object or create a new GaussianSplat3d object with leaf tensors."); - - TORCH_CHECK_VALUE( - mShN.is_leaf(), - "Cannot set requires_grad of shN which is a non-leaf tensor. " - "Call .detach() on this object or create a new GaussianSplat3d object with leaf tensors."); - - mMeans.requires_grad_(requiresGrad); - mQuats.requires_grad_(requiresGrad); - mLogScales.requires_grad_(requiresGrad); - mLogitOpacities.requires_grad_(requiresGrad); - mSh0.requires_grad_(requiresGrad); - mShN.requires_grad_(requiresGrad); - } - - /// @brief Set the data of the GaussianSplat3d object from the given tensors. - /// @param means An [N, 3]-shaped tensor representing the means of the Gaussians in this scene. - /// @param quats An [N, 4]-shaped tensor representing the quaternions of the Gaussians in this - /// scene. - /// @param logScales An [N, 3]-shaped tensor representing the log of the scales of the - /// Gaussians in this scene. - /// @param logitOpacities An [N]-shaped tensor representing the logit of the opacities of the - /// Gaussians in this scene. - /// @param sh0 An [N, 1, D]-shaped tensor representing the diffuse SH coefficients of the - /// Gaussians in this scene. - /// @param shN A [N, K-1, D]-shaped tensor representing the directionally-dependent SH - /// coefficients of the Gaussians in this scene. - void setState(const torch::Tensor &means, - const torch::Tensor &quats, - const torch::Tensor &logScales, - const torch::Tensor &logitOpacities, - const torch::Tensor &sh0, - const torch::Tensor &shN); - - /// @brief Return the number of Gaussians in the scene. - /// @return The number of Gaussians in the scene. - int64_t - numGaussians() const { - return mMeans.size(0); - } - - /// @brief Return the number of SH basis coeffients used in the scene. - /// @return The number of SH bases used in the scene. - int64_t - numShBases() const { - return mShN.size(1) + 1; - } - - /// @brief Return the number of channels used in the scene (e.g. 3 for RGB colors). - /// @return The number of channels used in the scene. - int64_t - numChannels() const { - return mShN.size(2); - } - - /// @brief Return the accumulated gradient norms of projected Gaussians in this - /// scene across backward passes. - /// This is used during optimization to decide whether to split/delete/duplicate - /// Gaussians. - /// @return An [N]-shaped tensor representing the accumulated gradient norms of projected - /// Gaussians in this scene across backward passes or an empty tensor if - /// accumulateMean2dGradients is false. - torch::Tensor - accumulated2dMeansGradientNormsForGrad() const { - return mAccumulatedNormalized2dMeansGradientNormsForGrad; - } - - /// @brief Return the accumulated maximum 2D radii of projected Gaussians in this - /// scene across backward passes. - /// This is used during optimization to decide whether to split/delete/duplicate - /// Gaussians. - /// @return An [N]-shaped tensor representing the accumulated maximum 2D radii of projected - /// Gaussians in this scene across backward passes or an empty tensor if - /// accumulateMax2dRadii is false. - torch::Tensor - accumulatedMax2dRadiiForGrad() const { - return mAccumulated2dRadiiForGrad; - } - - /// @brief Return the backward passes used to accumulate each Gaussian during optimization. - /// This is used during optimization to decide whether to split/delete/duplicate - /// Gaussians. - /// @return An [N]-shaped tensor representing the backward passes used to accumulate each - /// Gaussian during optimization or an empty tensor if accumulateMean2dGradients is - /// false. - torch::Tensor - gradientStepCountsForGrad() const { - return mGradientStepCountForGrad; - } - - /// @brief Reset the gradient statistics of the Gaussians in this scene. - /// See @ref accumulated2dMeansGradientNormsForGrad, @ref gradientStepCountsForGrad, - /// @ref accumulatedMax2dRadiiForGrad. - /// @note This function is only valid if requiresGrad is true. - void - resetAccumulatedGradientState() { - if (mAccumulateMean2dGradients) { - mAccumulatedNormalized2dMeansGradientNormsForGrad = torch::Tensor(); - mGradientStepCountForGrad = torch::Tensor(); - } - if (mAccumulateMax2dRadii) { - mAccumulated2dRadiiForGrad = torch::Tensor(); - } - } - - /// @brief Return the state of the GaussianSplat3d object as a dictionary (similar to Pytorch's - /// nn.Module). - /// @return A dictionary containing the state of the GaussianSplat3d object. - std::unordered_map stateDict() const; - - /// @brief Load the state of the GaussianSplat3d object from a state_dict (similar to Pytorch's - /// nn.Module). - /// @param stateDict A dictionary containing the state of the GaussianSplat3d object. - void loadStateDict(const std::unordered_map &stateDict); - - /// @brief Precompute the projected Gaussians to be re-used for rendering images (e.g. if you - /// want to render multiple images with the same camera settings or image patches). - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param shDegreeToUse Degree of SH to use for rendering (use -1 to use all SH bases) - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @return ProjectedGaussianSplats object that can be used to render images with @ref - /// renderFromProjectedGaussians - ProjectedGaussianSplats - projectGaussiansForImages(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - size_t imageWidth, - size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const int64_t shDegreeToUse = -1, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false); - - /// @brief Precompute the projected Gaussians to be re-used for rendering depths (e.g. if - /// you want to render multiple depth maps with the same camera settings or image patches). - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @return ProjectedGaussianSplats object that can be used to render depths with @ref - /// renderFromProjectedGaussians - ProjectedGaussianSplats - projectGaussiansForDepths(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - size_t imageWidth, - size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false); - - /// @brief Precompute the projected Gaussians to be re-used for rendering images and depths - /// (e.g. if you want to render multiple images and depth maps with the same camera settings - /// or image patches). - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param shDegreeToUse Degree of SH to use for rendering (use -1 to use all SH bases) - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @return ProjectedGaussianSplats object that can be used to render images and depths with - /// @ref renderFromProjectedGaussians - ProjectedGaussianSplats projectGaussiansForImagesAndDepths( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - size_t imageWidth, - size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const int64_t shDegreeToUse = -1, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false); - - /// @brief Save this scene and optional training metadata to a PLY file with the given filename - /// @param filename The path to save the PLY file to - /// @param metadata An optional dictionary of training metadata to include in the PLY file. The - /// keys are strings and the values are either strings, int64s, doubles, or tensors - void savePly(const std::string &filename, - std::optional> metadata) const; - - /// @brief Load a PLY file's means, quats, scales, opacities, and SH coefficients as the state - /// of this GaussianSplat3d object - /// @param filename Filename of the PLY file - /// @param device Device to transfer the loaded tensors to - /// @return The loaded GaussianSplat3d class, and a dictionary of metadata (can be empty if no - // metadata was saved in the PLY file). The metadata keys are strings and the values are either - // strings, int64s, doubles, or tensors. - static std::tuple> - fromPly(const std::string &filename, torch::Device device = torch::kCPU); - - /// @brief Render using precomputed projected Gaussians (see - /// @ref projectGaussiansForImages, @ref projectGaussiansForDepths, - /// @ref projectGaussiansForImagesAndDepths). - /// Optionally lets you render a cropped image by specifying the crop width, height, and origin. - /// @param projectedGaussians ProjectedGaussianSplats object obtained from @ref - /// projectGaussiansForImages, @ref projectGaussiansForDepths, or @ref - /// projectGaussiansForImagesAndDepths - /// @param cropWidth Width of the cropped image (use -1 for no cropping) - /// @param cropHeight Height of the cropped image (use -1 for no cropping) - /// @param cropOriginW Origin of the cropped image in the width dimension (use -1 for no - /// cropping) - /// @param cropOriginH Origin of the cropped image in the height dimension (use -1 for no - /// cropping) - /// @param tileSize Size of the tiles used for rendering - /// @param backgrounds Optional [C, D] tensor of background colors for each camera - /// @return Tuple of two tensors: - /// images: A [C, H, W, D|1|D+1] tensor containing the the rendered image - /// (or depth or image and depth) for each camera - /// alphas: A [C, H, W, 1] tensor containing the alpha values of the rendered images - std::tuple - renderFromProjectedGaussians(const GaussianSplat3d::ProjectedGaussianSplats &projectedGaussians, - const ssize_t cropWidth = -1, - const ssize_t cropHeight = -1, - const ssize_t cropOriginW = -1, - const ssize_t cropOriginH = -1, - const size_t tileSize = 16, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - /// @brief Render images of this Gaussian splat scene from the given camera matrices and - /// projection matrices. - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param shDegreeToUse Degree of SH to use for rendering (use -1 to use all SH bases) - /// @param tileSize Size of the tiles used for rendering - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @param backgrounds Optional [C, D] tensor of background colors for each camera - /// @return Tuple of two tensors: - /// images: A [C, H, W, D] tensor containing the the rendered image for each camera - /// alphas: A [C, H, W, 1] tensor containing the alpha values of the rendered images - std::tuple - renderImages(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const int64_t shDegreeToUse = -1, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - /// @brief Render images by rasterizing directly from world-space 3D Gaussians. - /// - /// This is similar to @ref renderImages but performs rasterization directly from world-space - /// Gaussians (means/quats/log-scales) rather than from their 2D projections. This enables - /// geometry gradients through the rasterization step. - /// - /// Tile intersections are still computed using a (non-differentiable) projection step: - /// - For `cameraModel == CameraModel::PINHOLE` or `CameraModel::ORTHOGRAPHIC`, - /// `ProjectionMethod::AUTO` reuses the classic analytic projection path, while - /// callers may explicitly request `ProjectionMethod::UNSCENTED`. - /// - For OpenCV camera models, we use the Unscented Transform (UT) projection kernel to - /// compute per-Gaussian radii and depths for sorting / tiling, then rasterize with 3DGS. - std::tuple - renderImagesFromWorld(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const int64_t shDegreeToUse = -1, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - /// @brief Render depths of this Gaussian splat scene from the given camera matrices and - /// projection matrices. - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param tileSize Size of the tiles used for rendering - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @param backgrounds Optional [C, 1] tensor of background depths for each camera - /// @return Tuple of two tensors: - /// images: A [C, H, W, 1] tensor containing the the rendered depths for each camera - /// alphas: A [C, H, W, 1] tensor containing the alpha values of the rendered depths - std::tuple - renderDepths(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - std::tuple - renderDepthsFromWorld(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - std::tuple - renderImagesAndDepths(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const int64_t shDegreeToUse = -1, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - std::tuple renderImagesAndDepthsFromWorld( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const int64_t shDegreeToUse = -1, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - std::tuple - sparseRenderImages(const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const int64_t shDegreeToUse = -1, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - std::tuple - sparseRenderDepths(const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - std::tuple - sparseRenderImagesAndDepths(const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const int64_t shDegreeToUse = -1, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - /// @brief Render the number of contributing Gaussians for each pixel in the image. - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param tileSize Size of the tiles used for rendering - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @return Tuple of two tensors: - /// num_contributing_gaussians: A [C, H, W] tensor containing the number of contributing - /// Gaussians for each pixel for each camera - /// alphas: A [C, H, W] tensor containing the alpha values of the rendered images - std::tuple renderNumContributingGaussians( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false); - - /// @brief Render the number of contributing Gaussians for each pixel in the image. - /// @param pixelsToRender [P1 + P2 + ..., 2] JaggedTensor of pixels per camera to render. - /// @param worldToCameraMatrices [C, 4, 4] - /// @param projectionMatrices [C, 3, 3] - /// @param settings - /// @return Tuple of two tensors: - /// num_contributing_gaussians: A [P1 + P2 + ..., 1] jagged tensor containing the number of - /// contributing - /// Gaussians for each pixel for each camera - /// alphas: A [P1 + P2 + ..., 1] jagged tensor containing the composited alpha value of the - /// pixels - std::tuple - sparseRenderNumContributingGaussians(const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const size_t tileSize, - const float minRadius2d, - const float eps2d, - const bool antialias); - - /// @brief Render the IDs of the gaussians that are the top K contributors to the rendered - /// pixels and the value of the weighted contribution to the rendered pixels. If the size of - /// `numSamples`(i.e. K) is greater than the number of contributing samples for a pixel, the - /// remaining samples' weights are filled with zeros and the IDs are filled with -1. - /// @param numSamples Requested number of top K contributing samples per pixel - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param tileSize Size of the tiles used for rendering - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @return Tuple of two tensors: - /// ids: A [C, H, W, K] tensor containing the the IDs of the top K contributors to the - /// rendered pixel for each camera - /// weights: A [C, H, W, K] tensor containing the weights of the top K contributors to the - /// rendered pixel for each camera. The weights are normalized to sum to 1 if the - /// list is exahustive of all contributing samples. - std::tuple renderTopContributingGaussianIds( - const int numSamples, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false); - - /// @brief Render the IDs of the gaussians that are the top K contributors to the rendered - /// pixels and the value of the weighted contribution to the rendered pixels. If the size of - /// `numSamples`(i.e. K) is greater than the number of contributing samples for a pixel, the - /// remaining samples' weights are filled with zeros and the IDs are filled with -1. This - /// function will render only a sparse subset of the pixels in the overall image, as specified - /// by the `pixelsToRender` parameter. - /// @param numSamples Requested number of top K contributing samples per pixel - /// @param pixelsToRender [P1 + P2 + ..., 2] JaggedTensor of pixels per camera to render. - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param tileSize Size of the tiles used for rendering - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @return Tuple of two tensors: - /// ids: A [P1 + P2 + ..., K] jagged tensor containing the the IDs of the top K contributors - /// to the - /// rendered pixel for each camera - /// weights: A [P1 + P2 + ..., K] jagged tensor containing the weights of the top K - /// contributors to the - /// rendered pixel for each camera. The weights are normalized to sum to 1 if the - /// list is exahustive of all contributing samples. - std::tuple sparseRenderTopContributingGaussianIds( - const int numSamples, - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false); - - /// @brief Render the IDs of the gaussians that are the contributors to the rendered images' - /// pixels and the value of their weighted contributions to the rendered pixels. - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param tileSize Size of the tiles used for rendering - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @param topKContributors If > 0, uses the efficient top-K kernel to return only the top K - /// contributors per pixel. If 0 (default), returns all contributing Gaussians. - /// @return Tuple of two JaggedTensors: - /// ids: A [[C1P1 + C1P2 + ... C1P(imageWidth * imageHeight), 1], ... [CNP1 + CNP2 + ... - /// CNP(imageWidth * imageHeight), 1]] jagged tensor containing the IDs of the - /// contributing Gaussians of each rendered pixel for each camera. The IDs are in - /// row-major order. - /// weights: A [[C1P1 + C1P2 + ... C1P(imageWidth * imageHeight), 1], ... [CNP1 + CNP2 + ... - /// CNP(imageWidth * imageHeight), 1]] jagged tensor containing the weights of the - /// contributing Gaussians of each rendered pixel for each camera. The weights are in - /// row-major order and sum to 1 for each pixel if that pixel is opaque (alpha=1). - std::tuple renderContributingGaussianIds( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const int topKContributors = 0); - - /// @brief Render the IDs of the gaussians that are the contributors to the rendered images' - /// pixels and the value of their weighted contributions to the rendered pixels. This - /// function will render only a sparse subset of the pixels in the overall image, as specified - /// by the `pixelsToRender` parameter. - /// @param pixelsToRender [P1 + P2 + ..., 2] JaggedTensor of pixels per camera to render. - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 4, 4] Projection matrices - /// @param imageWidth Width of the image - /// @param imageHeight Height of the image - /// @param near Near plane - /// @param far Far plane - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param tileSize Size of the tiles used for rendering - /// @param minRadius2d Minimum radius in pixels below which projected Gaussians are ignored - /// @param eps2d Blur factor for antialiasing (only used if antialias is true) - /// @param antialias Whether to antialias the image - /// @param topKContributors If > 0, uses the efficient top-K kernel to return only the top K - /// contributors per pixel. If 0 (default), returns all contributing Gaussians. - /// @return Tuple of two JaggedTensors: - /// ids: A [[C1P1 + C1P2 + ... C1PN1, 1], ... [CNP1 + CNP2 + ... CNPNN, 1]] jagged tensor - /// containing the IDs of the contributing Gaussians of each rendered pixel for each - /// camera. The IDs are in row-major order. - /// weights: A [[C1P1 + C1P2 + ... C1PN1, 1], ... [CNP1 + CNP2 + ... CNPNN, 1]] jagged - /// tensor containing the weights of the contributing Gaussians of each rendered - /// pixel for each camera. The weights are in row-major order and sum to 1 for each - /// pixel if that pixel is opaque (alpha=1). - std::tuple sparseRenderContributingGaussianIds( - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const size_t imageWidth, - const size_t imageHeight, - const float near, - const float far, - const CameraModel cameraModel = CameraModel::PINHOLE, - const ProjectionMethod projectionMethod = ProjectionMethod::AUTO, - const std::optional &distortionCoeffs = std::nullopt, - const size_t tileSize = 16, - const float minRadius2d = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const int topKContributors = 0); - - /// @brief Relocate Gaussians by adjusting opacity and scale based on replication ratio. - /// @param logScales Log scales of the Gaussians to relocate [N, 3]. - /// @param logitOpacities Logit opacities of the Gaussians to relocate [N]. - /// @param ratios Replication ratios per Gaussian [N]. - /// @param binomialCoeffs Binomial coefficients table [nMax, nMax]. - /// @param nMax Maximum replication ratio (size of binomial table). - /// @param minOpacity Minimum opacity - /// @return Tuple of (logitOpacitiesNew [N], logScalesNew [N, 3]) - std::tuple - relocateGaussians(const torch::Tensor &logScales, // [N, 3] - const torch::Tensor &logitOpacities, // [N] - const torch::Tensor &ratios, // [N] - const torch::Tensor &binomialCoeffs, // [nMax, nMax] - const int nMax, - const float minOpacity); - - /// @brief Add noise to the Gaussian positions (means), scaled by noiseScale. - /// @param noiseScale Noise scale - /// @param t Cutoff for opacity scaling - /// @param k Exponent for opacity scaling - void addNoiseToMeans(const float noiseScale, const float t = 0.005, const float k = 100.0); - - /// @brief Select a subset of the Gaussians in this scene based on the given slice. - /// @param begin The start index of the slice (inclusive) - /// @param end The end index of the slice (exclusive) - /// @param step The step size of the slice - /// @return A new GaussianSplat3d object containing only the selected Gaussians. - GaussianSplat3d sliceSelect(const int64_t begin, const int64_t end, const int64_t step) const; - - /// @brief Select a subset of the Gaussians in this scene based on the given indices. - /// @param indices A 1D tensor of indices in the range [0, numGaussians-1] to select from the - // Gaussians in this scene. - /// @return A new GaussianSplat3d object containing only the selected Gaussians. - GaussianSplat3d indexSelect(const torch::Tensor &indices) const; - - /// @brief Select a subset of the Gaussians in this scene based on the given mask. - /// @param mask A 1D boolean tensor of shape [N] where N is the number of Gaussians in this - /// scene. The mask indicates which Gaussians to select. - /// @return A new GaussianSplat3d object containing only the selected Gaussians. - /// The mask must have the same length as the number of Gaussians in this scene. - GaussianSplat3d maskSelect(const torch::Tensor &mask) const; - - /// @brief Assign new Gaussians to a subset of the Gaussians in this scene based on the given - /// indices. - /// @param indices A 1D tensor of indices in the range [0, numGaussians-1] to assign new - /// Gaussians to. - /// @param other A GaussianSplat3d object containing the new Gaussians to assign. - void indexSet(const torch::Tensor &indices, const GaussianSplat3d &other); - - /// @brief Assign new Gaussians to a subset of the Gaussians in this scene based on the given - /// slice. - /// @param begin The start index of the slice (inclusive) - /// @param end The end index of the slice (exclusive) - /// @param step The step size of the slice - /// @param other A GaussianSplat3d object containing the new Gaussians to assign. - /// The mask must have the same length as the number of Gaussians in this scene. - void sliceSet(const int64_t begin, - const int64_t end, - const int64_t step, - const GaussianSplat3d &other); - - /// @brief Assign new Gaussians to a subset of the Gaussians in this scene based on the given - /// mask. - /// @param mask A 1D boolean tensor of shape [N] where N is the number of Gaussians in this - /// scene. The mask indicates which Gaussians to assign. - /// @param other A GaussianSplat3d object containing the new Gaussians to assign. - /// The mask must have the same length as the number of Gaussians in this scene. - void maskSet(const torch::Tensor &mask, const GaussianSplat3d &other); - - private: - torch::Tensor mMeans; // [N, 3] - torch::Tensor mQuats; // [N, 4] - torch::Tensor mLogScales; // [N, 3] - torch::Tensor mLogitOpacities; // [N] - torch::Tensor mSh0; // [N, 1, D] - torch::Tensor mShN; // [N, K-1, D] - - // Used for subdivision during optimization - torch::Tensor mAccumulatedNormalized2dMeansGradientNormsForGrad; // [N] - torch::Tensor mAccumulated2dRadiiForGrad; // [N] - torch::Tensor mGradientStepCountForGrad; // [N] - bool mAccumulateMean2dGradients = false; - bool mAccumulateMax2dRadii = false; - - static void checkState(const torch::Tensor &means, - const torch::Tensor &quats, - const torch::Tensor &logScales, - const torch::Tensor &logitOpacities, - const torch::Tensor &sh0, - const torch::Tensor &shN); - - ProjectedGaussianSplats projectGaussiansImpl(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel); - - ProjectedGaussianSplats - projectGaussiansForCameraImpl(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs); - - /// @brief Project Gaussians with sparse tile intersection for efficient sparse rendering. - /// @param pixelsToRender JaggedTensor of pixel coordinates to render [P1 + P2 + ..., 2] - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 3, 3] Projection matrices - /// @param settings Render settings - /// @return SparseProjectedGaussianSplats containing projected Gaussians and sparse tile data - SparseProjectedGaussianSplats - sparseProjectGaussiansImpl(const JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel); - - /// @brief Sparse-project Gaussians with explicit camera-model configuration. - /// @param pixelsToRender JaggedTensor of pixel coordinates to render [P1 + P2 + ..., 2] - /// @param worldToCameraMatrices [C, 4, 4] Camera-to-world matrices - /// @param projectionMatrices [C, 3, 3] Projection matrices - /// @param settings Render settings - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @return SparseProjectedGaussianSplats containing projected Gaussians and sparse tile data - SparseProjectedGaussianSplats - sparseProjectGaussiansForCameraImpl(const JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs); - - std::tuple renderCropFromProjectedGaussiansImpl( - const ProjectedGaussianSplats &state, - const size_t tileSize, - const ssize_t cropWidth, - const ssize_t cropHeight, - const ssize_t cropOriginW, - const ssize_t cropOriginH, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - /// @brief Implements index set with a tensor of booleans or integer indices - /// @param indexOrMask A 1D tensor of indices in the range [0, numGaussians-1] or a boolean mask - /// of shape [N] where N is the number of Gaussians in this scene. - /// @param other A GaussianSplat3d object containing the new Gaussians to assign. - /// The mask must have the same length as the number of Gaussians in this scene. - void tensorIndexSetImpl(const torch::Tensor &indexOrMask, const GaussianSplat3d &other); - - /// @brief Implements indexing with a tensor of booleans or integer indices - /// @param indexOrMask A 1D tensor of indices in the range [0, numGaussians-1] or a boolean mask - /// of shape [N] where N is the number of Gaussians in this scene. - /// @return A new GaussianSplat3d object containing only the selected Gaussians. - /// The mask must have the same length as the number of Gaussians in this scene. - GaussianSplat3d tensorIndexGetImpl(const torch::Tensor &indexOrMask) const; - - /// @brief Render the scene described by the Gaussian splats at the specified pixels in the - /// specified views. This function returns a single render quantity (RGB, depth, RGB+D) - /// and single alpha value per pixel. - /// @param pixelsToRender [P1 + P2 + ..., 2] JaggedTensor of pixels per camera to render. - /// @param worldToCameraMatrices [C, 4, 4] - /// @param projectionMatrices [C, 3, 3] - /// @param settings - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @return Tuple of (render quantity, alpha value) - std::tuple - sparseRenderImpl(const JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - - /// @brief Render the number of contributing Gaussians for each pixel in the image. - /// @param worldToCameraMatrices [C, 4, 4] - /// @param projectionMatrices [C, 3, 3] - /// @param settings - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @return Tuple of two tensors: - /// num_contributing_gaussians: A [B, H, W] tensor containing the number of contributing - /// Gaussians for each pixel for each camera - std::tuple - renderNumContributingGaussiansImpl(const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs); - - /// @brief Render the number of contributing Gaussians for each pixel in the image. - /// @param pixelsToRender [P1 + P2 + ..., 2] JaggedTensor of pixels per camera to render. - /// @param worldToCameraMatrices [C, 4, 4] - /// @param projectionMatrices [C, 3, 3] - /// @param settings - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @return Tuple of two tensors: - /// num_contributing_gaussians: A [P1 + P2 + ..., 1] jagged tensor containing the number of - /// contributing - /// Gaussians for each pixel for each camera - /// alphas: A [P1 + P2 + ..., 1] jagged tensor containing the composited alpha value of the - /// pixels - std::tuple - sparseRenderNumContributingGaussiansImpl(const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs); - - /// @brief Render the gaussian splatting scene - /// For every pixel being rendered, this function returns multiple samples in depth of - /// the gaussian IDs and multiple samples of the weighted alpha values. The samples are - /// ordered front to back in their depth ordering from camera. - /// @param worldToCameraMatrices [C, 4, 4] - /// @param projectionMatrices [C, 3, 3] - /// @param settings - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param maybeNumContributingGaussians [C, H, W] tensor containing the number of contributing - /// Gaussians for each pixel for each camera. If not - /// provided, ``settings`` must have ``numDepthSamples`` - /// set to a value greater than 0. - /// @return Tuple of two JaggedTensors: - /// ids: A [[C1P1 + C1P2 + ... C1P(imageWidth * imageHeight), 1], ... [CNP1 + CNP2 + ... - /// CNP(imageWidth * imageHeight), 1]] jagged tensor containing the IDs of the - /// contributing Gaussians of each rendered pixel for each camera. The IDs are in - /// row-major order. - /// weights: A [[C1P1 + C1P2 + ... C1P(imageWidth * imageHeight), 1], ... [CNP1 + CNP2 + ... - /// CNP(imageWidth * imageHeight), 1]] jagged tensor containing the weights of the - /// contributing Gaussians of each rendered pixel for each camera. The weights are in - /// row-major order and sum to 1 for each pixel if that pixel is opaque (alpha=1). - std::tuple renderContributingGaussianIdsImpl( - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const std::optional &maybeNumContributingGaussians = std::nullopt); - - /// @brief Sparse render the gaussian splatting scene - /// For every pixel being rendered, this function returns multiple samples in depth of - /// the gaussian IDs and multiple samples of the weighted alpha values. The number of - /// samples per pixel is determined by the sampling parameters in the settings. If - /// the size of the requested number of samples is greater than the number of - /// contributing samples for a pixel, the remaining samples' weights are filled with - /// zeros and the IDs are filled with -1. The samples are ordered front to back in - /// their depth ordering from camera. - /// @param pixelsToRender [P1 + P2 + ..., 2] JaggedTensor of pixels per camera to render. - /// @param worldToCameraMatrices [C, 4, 4] - /// @param projectionMatrices [C, 3, 3] - /// @param settings - /// @param cameraModel Semantic camera model for projection - /// @param projectionMethod Projection implementation selector - /// @param distortionCoeffs Optional OpenCV distortion coefficients for distorted cameras - /// @param maybeNumContributingGaussians [C, H, W] tensor containing the number of contributing - /// Gaussians for each pixel for each camera. If provided, - /// the kernel will use the top-k path and ignore this - /// tensor. If not provided, ``settings`` must have - /// ``numDepthSamples`` set to a value greater than 0. - /// @return Tuple of two tensors: - /// ids: A [P1 + P2 + ..., K] jagged tensor containing the the IDs of the top K contributors - /// to the - /// rendered pixel for each camera - /// weights: A [P1 + P2 + ..., K] jagged tensor containing the weights of the top K - /// contributors to the - /// rendered pixel for each camera. The weights are normalized to sum to the alpha - /// value of the final rendered pixel if the list is exahustive of all contributing - /// samples. - std::tuple sparseRenderContributingGaussianIdsImpl( - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &projectionMatrices, - const fvdb::detail::ops::RenderSettings &settings, - const CameraModel cameraModel, - const ProjectionMethod projectionMethod, - const std::optional &distortionCoeffs, - const std::optional &maybeNumContributingGaussians = std::nullopt); - - torch::Tensor evalSphericalHarmonicsImpl(const int64_t shDegreeToUse, - const torch::Tensor &worldToCameraMatrices, - const torch::Tensor &perGaussianProjectedRadii) const; -}; - -std::tuple> -gaussianRenderJagged(const JaggedTensor &means, // [N1 + N2 + ..., 3] - const JaggedTensor &quats, // [N1 + N2 + ..., 4] - const JaggedTensor &scales, // [N1 + N2 + ..., 3] - const JaggedTensor &opacities, // [N1 + N2 + ...] - const JaggedTensor &sh_coeffs, // [N1 + N2 + ..., K, 3] - const JaggedTensor &viewmats, // [C1 + C2 + ..., 4, 4] - const JaggedTensor &Ks, // [C1 + C2 + ..., 3, 3] - const uint32_t image_width, - const uint32_t image_height, - const float near_plane = 0.01, - const float far_plane = 1e10, - const int sh_degree_to_use = -1, - const int tile_size = 16, - const float radius_clip = 0.0, - const float eps2d = 0.3, - const bool antialias = false, - const bool render_depth_channel = false, - const bool return_debug_info = false, - const bool render_depth_only = false, - const bool ortho = false, - const std::optional &backgrounds = std::nullopt, - const std::optional &masks = std::nullopt); - -} // namespace fvdb - -#endif // FVDB_GAUSSIANSPLAT3D_H diff --git a/src/fvdb/detail/autograd/EvaluateSphericalHarmonics.cpp b/src/fvdb/detail/autograd/EvaluateSphericalHarmonics.cpp deleted file mode 100644 index ab19354d0..000000000 --- a/src/fvdb/detail/autograd/EvaluateSphericalHarmonics.cpp +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#include -#include -#include -#include -#include - -namespace fvdb { -namespace detail { -namespace autograd { - -EvaluateSphericalHarmonics::VariableList -EvaluateSphericalHarmonics::forward( - EvaluateSphericalHarmonics::AutogradContext *ctx, - const ssize_t shDegreeToUse, - const size_t numCameras, - const std::optional - viewDirections, // [C, N, 3] (optional) - const EvaluateSphericalHarmonics::Variable &sh0Coeffs, // [N, 1, D] - const std::optional &shNCoeffs, // [N, K-1, D] - const EvaluateSphericalHarmonics::Variable &radii // [C, N] -) { - FVDB_FUNC_RANGE_WITH_NAME("EvaluateSphericalHarmonics::forward"); - const Variable viewDirectionsValue = viewDirections.value_or(torch::Tensor()); - const Variable shNCoeffsValue = shNCoeffs.value_or(torch::Tensor()); - const Variable renderQuantities = FVDB_DISPATCH_KERNEL(sh0Coeffs.device(), [&]() { - return ops::dispatchSphericalHarmonicsForward( - shDegreeToUse, numCameras, viewDirectionsValue, sh0Coeffs, shNCoeffsValue, radii); - }); - ctx->save_for_backward({viewDirectionsValue, shNCoeffsValue, radii}); - ctx->saved_data["shDegreeToUse"] = static_cast(shDegreeToUse); - ctx->saved_data["numCameras"] = static_cast(numCameras); - ctx->saved_data["numGaussians"] = static_cast(sh0Coeffs.size(0)); - return {renderQuantities}; -} - -EvaluateSphericalHarmonics::VariableList -EvaluateSphericalHarmonics::backward(EvaluateSphericalHarmonics::AutogradContext *ctx, - EvaluateSphericalHarmonics::VariableList gradOutput) { - FVDB_FUNC_RANGE_WITH_NAME("EvaluateSphericalHarmonics::backward"); - - // ensure the gradients are contiguous if they are not None - auto const dLossDColors = - gradOutput.at(0).defined() ? gradOutput.at(0).contiguous() : gradOutput.at(0); - - VariableList saved = ctx->get_saved_variables(); - Variable viewDirs = saved.at(0); - Variable shNCoeffs = saved.at(1); - Variable radii = saved.at(2); - - const int shDegreeToUse = static_cast(ctx->saved_data["shDegreeToUse"].toInt()); - const int numCameras = static_cast(ctx->saved_data["numCameras"].toInt()); - const int numGaussians = static_cast(ctx->saved_data["numGaussians"].toInt()); - // Only compute viewDirs gradients if viewDirs is defined and requires grad - const bool computeDLossDViewDirs = viewDirs.defined() && viewDirs.requires_grad(); - - auto variables = FVDB_DISPATCH_KERNEL(dLossDColors.device(), [&]() { - return ops::dispatchSphericalHarmonicsBackward(shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDColors, - radii, - computeDLossDViewDirs); - }); - Variable dLossDSh0Coeffs = std::get<0>(variables); - Variable dLossDShNCoeffs = std::get<1>(variables); - Variable dLossDViewDirs = std::get<2>(variables); - - return {Variable(), Variable(), dLossDViewDirs, dLossDSh0Coeffs, dLossDShNCoeffs, Variable()}; -} - -} // namespace autograd -} // namespace detail -} // namespace fvdb diff --git a/src/fvdb/detail/autograd/EvaluateSphericalHarmonics.h b/src/fvdb/detail/autograd/EvaluateSphericalHarmonics.h deleted file mode 100644 index 46d0ea43f..000000000 --- a/src/fvdb/detail/autograd/EvaluateSphericalHarmonics.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_AUTOGRAD_EVALUATESPHERICALHARMONICS_H -#define FVDB_DETAIL_AUTOGRAD_EVALUATESPHERICALHARMONICS_H - -#include - -namespace fvdb { -namespace detail { -namespace autograd { - -struct EvaluateSphericalHarmonics : public torch::autograd::Function { - using VariableList = torch::autograd::variable_list; - using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - - static VariableList - forward(AutogradContext *ctx, - const ssize_t shDegreeToUse, - const size_t numCameras, - const std::optional viewDirections, // [C, N, 3] or empty for deg 0 - const Variable &sh0Coeffs, // [N, 1, D] - const std::optional &shNCoeffs, // [N, K-1, D] - const Variable &radii // [C, N] - ); - - static VariableList backward(AutogradContext *ctx, VariableList gradOutput); -}; - -} // namespace autograd -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_AUTOGRAD_EVALUATESPHERICALHARMONICS_H diff --git a/src/fvdb/detail/autograd/GaussianProjection.cpp b/src/fvdb/detail/autograd/GaussianProjection.cpp deleted file mode 100644 index d2e84551a..000000000 --- a/src/fvdb/detail/autograd/GaussianProjection.cpp +++ /dev/null @@ -1,338 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace fvdb::detail::autograd { - -ProjectGaussians::VariableList -ProjectGaussians::forward(ProjectGaussians::AutogradContext *ctx, - const ProjectGaussians::Variable &means, - const ProjectGaussians::Variable &quats, - const ProjectGaussians::Variable &logScales, - const ProjectGaussians::Variable &worldToCamMatrices, - const ProjectGaussians::Variable &projectionMatrices, - const uint32_t imageWidth, - const uint32_t imageHeight, - const float eps2d, - const float nearPlane, - const float farPlane, - const float minRadius2D, - const bool calcCompensations, - const bool ortho, - std::optional outNormalizeddLossdMeans2dNormAccum, - std::optional outNormalizedMaxRadiiAccum, - std::optional outGradientStepCount) { - FVDB_FUNC_RANGE_WITH_NAME("ProjectGaussians::forward"); - TORCH_CHECK(means.dim() == 2, "means must have shape (N, 3)"); - TORCH_CHECK(worldToCamMatrices.dim() == 3, "worldToCamMatrices must have shape (C, 4, 4)"); - TORCH_CHECK(projectionMatrices.dim() == 3, "projectionMatrices must have shape (C, 3, 3)"); - - auto variables = FVDB_DISPATCH_KERNEL(means.device(), [&]() { - return ops::dispatchGaussianProjectionForward(means, - quats, - logScales, - worldToCamMatrices, - projectionMatrices, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2D, - calcCompensations, - ortho); - }); - Variable radii = std::get<0>(variables); - Variable means2d = std::get<1>(variables); - Variable depths = std::get<2>(variables); - Variable conics = std::get<3>(variables); - - ctx->saved_data["imageWidth"] = static_cast(imageWidth); - ctx->saved_data["imageHeight"] = static_cast(imageHeight); - ctx->saved_data["eps2d"] = static_cast(eps2d); - ctx->saved_data["calcCompensations"] = static_cast(calcCompensations); - ctx->saved_data["ortho"] = static_cast(ortho); - - const bool saveAccumState = outNormalizeddLossdMeans2dNormAccum.has_value(); - const bool trackMaxRadii = outNormalizedMaxRadiiAccum.has_value(); - ctx->saved_data["saveAccumState"] = saveAccumState; - ctx->saved_data["trackMaxRadii"] = trackMaxRadii; - if (saveAccumState) { - ctx->saved_data["outNormalizeddLossdMeans2dNormAccum"] = - outNormalizeddLossdMeans2dNormAccum.value(); - ctx->saved_data["outGradientStepCount"] = outGradientStepCount.value(); - } - if (trackMaxRadii) { - ctx->saved_data["outNormalizedMaxRadiiAccum"] = outNormalizedMaxRadiiAccum.value(); - } - - if (calcCompensations) { - Variable compensations = std::get<4>(variables); - ctx->save_for_backward({means, - quats, - logScales, - worldToCamMatrices, - projectionMatrices, - radii, - conics, - compensations}); - return {radii, means2d, depths, conics, compensations}; - } else { - ctx->save_for_backward( - {means, quats, logScales, worldToCamMatrices, projectionMatrices, radii, conics}); - return {radii, means2d, depths, conics}; - } -} - -ProjectGaussians::VariableList -ProjectGaussians::backward(ProjectGaussians::AutogradContext *ctx, - ProjectGaussians::VariableList gradOutput) { - FVDB_FUNC_RANGE_WITH_NAME("ProjectGaussians::backward"); - Variable dLossDRadii = gradOutput.at(0); - Variable dLossDMeans2d = gradOutput.at(1); - Variable dLossDDepths = gradOutput.at(2); - Variable dLossDConics = gradOutput.at(3); - - // ensure the gradients are contiguous if they are not None - if (dLossDRadii.defined()) { - dLossDRadii = dLossDRadii.contiguous(); - } - if (dLossDMeans2d.defined()) { - dLossDMeans2d = dLossDMeans2d.contiguous(); - } - if (dLossDDepths.defined()) { - dLossDDepths = dLossDDepths.contiguous(); - } - if (dLossDConics.defined()) { - dLossDConics = dLossDConics.contiguous(); - } - - VariableList saved = ctx->get_saved_variables(); - Variable means = saved.at(0); - Variable quats = saved.at(1); - Variable logScales = saved.at(2); - Variable worldToCamMatrices = saved.at(3); - Variable projectionMatrices = saved.at(4); - Variable radii = saved.at(5); - Variable conics = saved.at(6); - - const bool calcCompensations = ctx->saved_data["calcCompensations"].toBool(); - - at::optional compensations, dLossDCompensations; - if (calcCompensations) { - Variable vcomp = gradOutput.at(4); - if (vcomp.defined()) { - vcomp = vcomp.contiguous(); - } - dLossDCompensations = vcomp; - compensations = saved.at(7); - } - - const int imageWidth = static_cast(ctx->saved_data["imageWidth"].toInt()); - const int imageHeight = static_cast(ctx->saved_data["imageHeight"].toInt()); - const float eps2d = static_cast(ctx->saved_data["eps2d"].toDouble()); - const bool ortho = ctx->saved_data["ortho"].toBool(); - const bool saveAccumState = ctx->saved_data["saveAccumState"].toBool(); - const bool trackMaxRadii = ctx->saved_data["trackMaxRadii"].toBool(); - - auto [normalizeddLossdMeans2dNormAccum, normalizedMaxRadiiAccum, gradientStepCount] = [&]() { - return std::make_tuple( - saveAccumState ? std::optional( - ctx->saved_data["outNormalizeddLossdMeans2dNormAccum"].toTensor()) - : std::nullopt, - trackMaxRadii ? std::optional( - ctx->saved_data["outNormalizedMaxRadiiAccum"].toTensor()) - : std::nullopt, - saveAccumState - ? std::optional(ctx->saved_data["outGradientStepCount"].toTensor()) - : std::nullopt); - }(); - auto variables = FVDB_DISPATCH_KERNEL(means.device(), [&]() { - return ops::dispatchGaussianProjectionBackward(means, - quats, - logScales, - worldToCamMatrices, - projectionMatrices, - compensations, - imageWidth, - imageHeight, - eps2d, - radii, - conics, - dLossDMeans2d, - dLossDDepths, - dLossDConics, - dLossDCompensations, - ctx->needs_input_grad(4), - ortho, - normalizeddLossdMeans2dNormAccum, - normalizedMaxRadiiAccum, - gradientStepCount); - }); - - Variable dLossDMeans = std::get<0>(variables); - // Variable dLossDCovars = std::get<1>(variables); - Variable dLossDQuats = std::get<2>(variables); - Variable dLossDScales = std::get<3>(variables); - Variable dLossDWorldToCams = std::get<4>(variables); - - return {dLossDMeans, - dLossDQuats, - dLossDScales, - dLossDWorldToCams, - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable()}; -} - -ProjectGaussiansJagged::VariableList -ProjectGaussiansJagged::forward( - ProjectGaussiansJagged::AutogradContext *ctx, - const ProjectGaussiansJagged::Variable &gSizes, // [B] gaussian sizes - const ProjectGaussiansJagged::Variable &means, // [ggz, 3] - const ProjectGaussiansJagged::Variable &quats, // [ggz, 4] optional - const ProjectGaussiansJagged::Variable &scales, // [ggz, 3] optional - const ProjectGaussiansJagged::Variable &cSizes, // [B] camera sizes - const ProjectGaussiansJagged::Variable &worldToCamMatrices, // [ccz, 4, 4] - const ProjectGaussiansJagged::Variable &projectionMatrices, // [ccz, 3, 3] - const uint32_t imageWidth, - const uint32_t imageHeight, - const float eps2d, - const float nearPlane, - const float farPlane, - const float minRadius2D, - const bool ortho) { - FVDB_FUNC_RANGE_WITH_NAME("ProjectGaussiansJagged::forward"); - auto variables = FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() { - return ops::dispatchGaussianProjectionJaggedForward(gSizes, - means, - quats, - scales, - cSizes, - worldToCamMatrices, - projectionMatrices, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2D, - ortho); - }); - Variable radii = std::get<0>(variables); - Variable means2d = std::get<1>(variables); - Variable depths = std::get<2>(variables); - Variable conics = std::get<3>(variables); - - ctx->save_for_backward({gSizes, - means, - quats, - scales, - cSizes, - worldToCamMatrices, - projectionMatrices, - radii, - conics}); - ctx->saved_data["imageWidth"] = (int64_t)imageWidth; - ctx->saved_data["imageHeight"] = (int64_t)imageHeight; - ctx->saved_data["eps2d"] = (double)eps2d; - ctx->saved_data["ortho"] = (bool)ortho; - - return {radii, means2d, depths, conics}; -} - -ProjectGaussiansJagged::VariableList -ProjectGaussiansJagged::backward(ProjectGaussiansJagged::AutogradContext *ctx, - ProjectGaussiansJagged::VariableList gradOutput) { - FVDB_FUNC_RANGE_WITH_NAME("ProjectGaussiansJagged::backward"); - Variable dLossDRadii = gradOutput.at(0); - Variable dLossDMeans2d = gradOutput.at(1); - Variable dLossDDepths = gradOutput.at(2); - Variable dLossDConics = gradOutput.at(3); - - // ensure the gradients are contiguous if they are not None - if (dLossDRadii.defined()) - dLossDRadii = dLossDRadii.contiguous(); - if (dLossDMeans2d.defined()) - dLossDMeans2d = dLossDMeans2d.contiguous(); - if (dLossDDepths.defined()) - dLossDDepths = dLossDDepths.contiguous(); - if (dLossDConics.defined()) - dLossDConics = dLossDConics.contiguous(); - - VariableList saved = ctx->get_saved_variables(); - Variable gSizes = saved.at(0); - Variable means = saved.at(1); - Variable quats = saved.at(2); - Variable scales = saved.at(3); - Variable cSizes = saved.at(4); - Variable worldToCamMatrices = saved.at(5); - Variable projectionMatrices = saved.at(6); - Variable radii = saved.at(7); - Variable conics = saved.at(8); - - const int imageWidth = (int)ctx->saved_data["imageWidth"].toInt(); - const int imageHeight = (int)ctx->saved_data["imageHeight"].toInt(); - const float eps2d = (float)ctx->saved_data["eps2d"].toDouble(); - const bool ortho = (bool)ctx->saved_data["ortho"].toBool(); - - auto variables = FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() { - return ops::dispatchGaussianProjectionJaggedBackward(gSizes, - means, - quats, - scales, - cSizes, - worldToCamMatrices, - projectionMatrices, - imageWidth, - imageHeight, - eps2d, - radii, - conics, - dLossDMeans2d, - dLossDDepths, - dLossDConics, - ctx->needs_input_grad(6), - ortho); - }); - Variable dLossDMeans = std::get<0>(variables); - // Variable dLossDCovars = std::get<1>(variables); - Variable dLossDQuats = std::get<2>(variables); - Variable dLossDScales = std::get<3>(variables); - Variable dLossDWorldToCams = std::get<4>(variables); - - return {Variable(), - dLossDMeans, - dLossDQuats, - dLossDScales, - Variable(), - dLossDWorldToCams, - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable()}; -} - -} // namespace fvdb::detail::autograd diff --git a/src/fvdb/detail/autograd/GaussianProjection.h b/src/fvdb/detail/autograd/GaussianProjection.h deleted file mode 100644 index d0ed5c33c..000000000 --- a/src/fvdb/detail/autograd/GaussianProjection.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_AUTOGRAD_GAUSSIANPROJECTION_H -#define FVDB_DETAIL_AUTOGRAD_GAUSSIANPROJECTION_H - -#include - -namespace fvdb::detail::autograd { - -struct ProjectGaussians : public torch::autograd::Function { - using VariableList = torch::autograd::variable_list; - using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - - static VariableList - forward(AutogradContext *ctx, - const Variable &means, // [N, 3] - const Variable &quats, // [N, 4] - const Variable &scales, // [N, 3] - const Variable &camToWorldMatrices, // [C, 4, 4] - const Variable &projectionMatrices, // [C, 3, 3] - const uint32_t imageWidth, - const uint32_t imageHeight, - const float eps2d, - const float nearPlane, - const float farPlane, - const float minRadius2D, - const bool calcCompensions, - const bool ortho, - std::optional outNormalizeddLossdMeans2dNormAccum = std::nullopt, - std::optional outNormalizedMaxRadiiAccum = std::nullopt, - std::optional outGradientStepCount = std::nullopt); - - static VariableList backward(AutogradContext *ctx, VariableList gradOutput); -}; - -struct ProjectGaussiansJagged : public torch::autograd::Function { - using VariableList = torch::autograd::variable_list; - using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - - static VariableList forward(AutogradContext *ctx, - const Variable &gSizes, // [B] gaussian sizes - const Variable &means, // [ggz, 3] - const Variable &quats, // [ggz, 4] optional - const Variable &scales, // [ggz, 3] optional - const Variable &cSizes, // [B] camera sizes - const Variable &camToWorldMatrices, // [ccz, 4, 4] - const Variable &projectionMatrices, // [ccz, 3, 3] - const uint32_t imageWidth, - const uint32_t imageHeight, - const float eps2d, - const float nearPlane, - const float farPlane, - const float minRadius2D, - const bool ortho); - - static VariableList backward(AutogradContext *ctx, VariableList gradOutput); -}; - -} // namespace fvdb::detail::autograd - -#endif // FVDB_DETAIL_AUTOGRAD_GAUSSIANPROJECTION_H diff --git a/src/fvdb/detail/autograd/GaussianRasterize.cpp b/src/fvdb/detail/autograd/GaussianRasterize.cpp deleted file mode 100644 index 72e808764..000000000 --- a/src/fvdb/detail/autograd/GaussianRasterize.cpp +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#include -#include -#include -#include -#include - -#include - -namespace fvdb::detail::autograd { - -RasterizeGaussiansToPixels::VariableList -RasterizeGaussiansToPixels::forward( - RasterizeGaussiansToPixels::AutogradContext *ctx, - const RasterizeGaussiansToPixels::Variable &means2d, // [C, N, 2] - const RasterizeGaussiansToPixels::Variable &conics, // [C, N, 3] - const RasterizeGaussiansToPixels::Variable &colors, // [C, N, 3] - const RasterizeGaussiansToPixels::Variable &opacities, // [N] - const uint32_t imageWidth, - const uint32_t imageHeight, - const uint32_t imageOriginW, - const uint32_t imageOriginH, - const uint32_t tileSize, - const RasterizeGaussiansToPixels::Variable &tileOffsets, // [C, tile_height, tile_width] - const RasterizeGaussiansToPixels::Variable &tileGaussianIds, // [n_isects] - const bool absgrad, - std::optional backgrounds, - std::optional masks) { - FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixels::forward"); - - auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { - const ops::RenderWindow2D renderWindow{imageWidth, imageHeight, imageOriginW, imageOriginH}; - return ops::dispatchGaussianRasterizeForward(means2d, - conics, - colors, - opacities, - renderWindow, - tileSize, - tileOffsets, - tileGaussianIds, - backgrounds, - masks); - }); - Variable renderedColors = std::get<0>(variables); - Variable renderedAlphas = std::get<1>(variables); - Variable lastIds = std::get<2>(variables); - - std::vector toSave = { - means2d, conics, colors, opacities, tileOffsets, tileGaussianIds, renderedAlphas, lastIds}; - if (backgrounds.has_value()) { - toSave.push_back(backgrounds.value()); - ctx->saved_data["has_backgrounds"] = true; - } else { - ctx->saved_data["has_backgrounds"] = false; - } - if (masks.has_value()) { - toSave.push_back(masks.value()); - ctx->saved_data["has_masks"] = true; - } else { - ctx->saved_data["has_masks"] = false; - } - ctx->save_for_backward(toSave); - - ctx->saved_data["imageWidth"] = (int64_t)imageWidth; - ctx->saved_data["imageHeight"] = (int64_t)imageHeight; - ctx->saved_data["tileSize"] = (int64_t)tileSize; - ctx->saved_data["imageOriginW"] = (int64_t)imageOriginW; - ctx->saved_data["imageOriginH"] = (int64_t)imageOriginH; - ctx->saved_data["absgrad"] = absgrad; - - return {renderedColors, renderedAlphas}; -} - -RasterizeGaussiansToPixels::VariableList -RasterizeGaussiansToPixels::backward(RasterizeGaussiansToPixels::AutogradContext *ctx, - RasterizeGaussiansToPixels::VariableList gradOutput) { - FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixels::backward"); - Variable dLossDRenderedColors = gradOutput.at(0); - Variable dLossDRenderedAlphas = gradOutput.at(1); - - // ensure the gradients are contiguous if they are not None - if (dLossDRenderedColors.defined()) { - dLossDRenderedColors = dLossDRenderedColors.contiguous(); - } - if (dLossDRenderedAlphas.defined()) { - dLossDRenderedAlphas = dLossDRenderedAlphas.contiguous(); - } - - VariableList saved = ctx->get_saved_variables(); - Variable means2d = saved.at(0); - Variable conics = saved.at(1); - Variable colors = saved.at(2); - Variable opacities = saved.at(3); - Variable tileOffsets = saved.at(4); - Variable tileGaussianIds = saved.at(5); - Variable renderedAlphas = saved.at(6); - Variable lastIds = saved.at(7); - - const bool hasBackgrounds = ctx->saved_data["has_backgrounds"].toBool(); - const bool hasMasks = ctx->saved_data["has_masks"].toBool(); - std::optional backgrounds = std::nullopt; - std::optional masks = std::nullopt; - int64_t optIdx = 8; - if (hasBackgrounds) { - backgrounds = saved.at(optIdx++); - } - if (hasMasks) { - masks = saved.at(optIdx++); - } - - const int imageWidth = (int)ctx->saved_data["imageWidth"].toInt(); - const int imageHeight = (int)ctx->saved_data["imageHeight"].toInt(); - const int tileSize = (int)ctx->saved_data["tileSize"].toInt(); - const int imageOriginW = (int)ctx->saved_data["imageOriginW"].toInt(); - const int imageOriginH = (int)ctx->saved_data["imageOriginH"].toInt(); - const bool absgrad = ctx->saved_data["absgrad"].toBool(); - - auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { - const ops::RenderWindow2D renderWindow{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}; - return ops::dispatchGaussianRasterizeBackward(means2d, - conics, - colors, - opacities, - renderWindow, - tileSize, - tileOffsets, - tileGaussianIds, - renderedAlphas, - lastIds, - dLossDRenderedColors, - dLossDRenderedAlphas, - absgrad, - -1, - backgrounds, - masks); - }); - Variable dLossDMean2dAbs; - if (absgrad) { - dLossDMean2dAbs = std::get<0>(variables); - } else { - dLossDMean2dAbs = Variable(); - } - Variable dLossDMeans2d = std::get<1>(variables); - Variable dLossDConics = std::get<2>(variables); - Variable dLossDColors = std::get<3>(variables); - Variable dLossDOpacities = std::get<4>(variables); - - return { - dLossDMeans2d, - dLossDConics, - dLossDColors, - dLossDOpacities, - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), - Variable(), // backgrounds - Variable(), // masks - }; -} - -} // namespace fvdb::detail::autograd diff --git a/src/fvdb/detail/autograd/GaussianRasterize.h b/src/fvdb/detail/autograd/GaussianRasterize.h deleted file mode 100644 index 85d0d1a84..000000000 --- a/src/fvdb/detail/autograd/GaussianRasterize.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZE_H -#define FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZE_H - -#include - -namespace fvdb::detail::autograd { - -struct RasterizeGaussiansToPixels : public torch::autograd::Function { - using VariableList = torch::autograd::variable_list; - using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - - static VariableList forward(AutogradContext *ctx, - const Variable &means2d, // [C, N, 2] - const Variable &conics, // [C, N, 3] - const Variable &colors, // [C, N, 3] - const Variable &opacities, // [N] - const uint32_t imageWidth, - const uint32_t imageHeight, - const uint32_t imageOriginW, - const uint32_t imageOriginH, - const uint32_t tileSize, - const Variable &tileOffsets, // [C, tile_height, tile_width] - const Variable &tileGaussianIds, // [n_isects] - const bool absgrad, - std::optional backgrounds = std::nullopt, // [C, D] - std::optional masks = std::nullopt); // [C, tileH, tileW] - - static VariableList backward(AutogradContext *ctx, VariableList gradOutput); -}; - -} // namespace fvdb::detail::autograd - -#endif // FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZE_H diff --git a/src/fvdb/detail/autograd/GaussianRasterizeFromWorld.cpp b/src/fvdb/detail/autograd/GaussianRasterizeFromWorld.cpp deleted file mode 100644 index 927da61ec..000000000 --- a/src/fvdb/detail/autograd/GaussianRasterizeFromWorld.cpp +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#include -#include -#include -#include -#include -#include - -#include - -namespace fvdb::detail::autograd { - -RasterizeGaussiansToPixelsFromWorld3DGS::VariableList -RasterizeGaussiansToPixelsFromWorld3DGS::forward( - RasterizeGaussiansToPixelsFromWorld3DGS::AutogradContext *ctx, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &means, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &quats, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &logScales, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &features, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &opacities, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &worldToCamMatricesStart, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &worldToCamMatricesEnd, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &projectionMatrices, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &distortionCoeffs, - const fvdb::detail::ops::RollingShutterType rollingShutterType, - const fvdb::detail::ops::DistortionModel distortionModel, - const uint32_t imageWidth, - const uint32_t imageHeight, - const uint32_t imageOriginW, - const uint32_t imageOriginH, - const uint32_t tileSize, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &tileOffsets, - const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &tileGaussianIds, - std::optional backgrounds, - std::optional masks) { - FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixelsFromWorld3DGS::forward"); - - fvdb::detail::ops::RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.imageOriginW = imageOriginW; - settings.imageOriginH = imageOriginH; - settings.tileSize = tileSize; - - auto outputs = FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() { - return ops::dispatchGaussianRasterizeFromWorld3DGSForward( - means, - quats, - logScales, - features, - opacities, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - distortionCoeffs, - rollingShutterType, - distortionModel, - settings, - tileOffsets, - tileGaussianIds, - backgrounds, - masks); - }); - - Variable renderedFeatures = std::get<0>(outputs); - Variable renderedAlphas = std::get<1>(outputs); - Variable lastIds = std::get<2>(outputs); - - std::vector toSave = {means, - quats, - logScales, - features, - opacities, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - distortionCoeffs, - tileOffsets, - tileGaussianIds, - renderedAlphas, - lastIds}; - if (backgrounds.has_value()) { - toSave.push_back(backgrounds.value()); - ctx->saved_data["has_backgrounds"] = true; - } else { - ctx->saved_data["has_backgrounds"] = false; - } - if (masks.has_value()) { - toSave.push_back(masks.value()); - ctx->saved_data["has_masks"] = true; - } else { - ctx->saved_data["has_masks"] = false; - } - ctx->save_for_backward(toSave); - - ctx->saved_data["imageWidth"] = (int64_t)imageWidth; - ctx->saved_data["imageHeight"] = (int64_t)imageHeight; - ctx->saved_data["imageOriginW"] = (int64_t)imageOriginW; - ctx->saved_data["imageOriginH"] = (int64_t)imageOriginH; - ctx->saved_data["tileSize"] = (int64_t)tileSize; - ctx->saved_data["distortionModel"] = (int64_t)distortionModel; - ctx->saved_data["rollingShutterType"] = (int64_t)rollingShutterType; - - return {renderedFeatures, renderedAlphas}; -} - -RasterizeGaussiansToPixelsFromWorld3DGS::VariableList -RasterizeGaussiansToPixelsFromWorld3DGS::backward( - RasterizeGaussiansToPixelsFromWorld3DGS::AutogradContext *ctx, - RasterizeGaussiansToPixelsFromWorld3DGS::VariableList gradOutput) { - FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixelsFromWorld3DGS::backward"); - - Variable dLossDRenderedFeatures = gradOutput.at(0); - Variable dLossDRenderedAlphas = gradOutput.at(1); - if (dLossDRenderedFeatures.defined()) { - dLossDRenderedFeatures = dLossDRenderedFeatures.contiguous(); - } - if (dLossDRenderedAlphas.defined()) { - dLossDRenderedAlphas = dLossDRenderedAlphas.contiguous(); - } - - VariableList saved = ctx->get_saved_variables(); - Variable means = saved.at(0); - Variable quats = saved.at(1); - Variable logScales = saved.at(2); - Variable features = saved.at(3); - Variable opacities = saved.at(4); - Variable worldToCamMatricesStart = saved.at(5); - Variable worldToCamMatricesEnd = saved.at(6); - Variable projectionMatrices = saved.at(7); - Variable distortionCoeffs = saved.at(8); - Variable tileOffsets = saved.at(9); - Variable tileGaussianIds = saved.at(10); - Variable renderedAlphas = saved.at(11); - Variable lastIds = saved.at(12); - - const bool hasBackgrounds = ctx->saved_data["has_backgrounds"].toBool(); - const bool hasMasks = ctx->saved_data["has_masks"].toBool(); - std::optional backgrounds = std::nullopt; - std::optional masks = std::nullopt; - int64_t optIdx = 13; - if (hasBackgrounds) { - backgrounds = saved.at(optIdx++); - } - if (hasMasks) { - masks = saved.at(optIdx++); - } - - const uint32_t imageWidth = (uint32_t)ctx->saved_data["imageWidth"].toInt(); - const uint32_t imageHeight = (uint32_t)ctx->saved_data["imageHeight"].toInt(); - const uint32_t imageOriginW = (uint32_t)ctx->saved_data["imageOriginW"].toInt(); - const uint32_t imageOriginH = (uint32_t)ctx->saved_data["imageOriginH"].toInt(); - const uint32_t tileSize = (uint32_t)ctx->saved_data["tileSize"].toInt(); - const auto distortionModel = - static_cast(ctx->saved_data["distortionModel"].toInt()); - const auto rollingShutterType = static_cast( - ctx->saved_data["rollingShutterType"].toInt()); - - fvdb::detail::ops::RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.imageOriginW = imageOriginW; - settings.imageOriginH = imageOriginH; - settings.tileSize = tileSize; - - auto grads = FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() { - return ops::dispatchGaussianRasterizeFromWorld3DGSBackward( - means, - quats, - logScales, - features, - opacities, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - distortionCoeffs, - rollingShutterType, - distortionModel, - settings, - tileOffsets, - tileGaussianIds, - renderedAlphas, - lastIds, - dLossDRenderedFeatures, - dLossDRenderedAlphas, - backgrounds, - masks); - }); - - Variable dMeans = std::get<0>(grads); - Variable dQuats = std::get<1>(grads); - Variable dLogScales = std::get<2>(grads); - Variable dFeatures = std::get<3>(grads); - Variable dOpacities = std::get<4>(grads); - - // Return gradients in the same order as forward inputs. - return {dMeans, dQuats, dLogScales, dFeatures, dOpacities, Variable(), Variable(), - Variable(), Variable(), Variable(), Variable(), Variable(), Variable(), Variable(), - Variable(), Variable(), Variable(), Variable(), Variable(), Variable()}; -} - -} // namespace fvdb::detail::autograd diff --git a/src/fvdb/detail/autograd/GaussianRasterizeFromWorld.h b/src/fvdb/detail/autograd/GaussianRasterizeFromWorld.h deleted file mode 100644 index da506a795..000000000 --- a/src/fvdb/detail/autograd/GaussianRasterizeFromWorld.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZEFROMWORLD_H -#define FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZEFROMWORLD_H - -#include - -#include - -#include - -namespace fvdb::detail::autograd { - -/// @brief Autograd wrapper for dense rasterization from world-space 3D Gaussians. -struct RasterizeGaussiansToPixelsFromWorld3DGS - : public torch::autograd::Function { - using VariableList = torch::autograd::variable_list; - using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - - static VariableList - forward(AutogradContext *ctx, - const Variable &means, // [N,3] - const Variable &quats, // [N,4] - const Variable &logScales, // [N,3] - const Variable &features, // [C,N,D] - const Variable &opacities, // [C,N] - const Variable &worldToCamMatricesStart, // [C,4,4] - const Variable &worldToCamMatricesEnd, // [C,4,4] - const Variable &projectionMatrices, // [C,3,3] - const Variable &distortionCoeffs, // [C,K] - const fvdb::detail::ops::RollingShutterType rollingShutterType, - const fvdb::detail::ops::DistortionModel distortionModel, - const uint32_t imageWidth, - const uint32_t imageHeight, - const uint32_t imageOriginW, - const uint32_t imageOriginH, - const uint32_t tileSize, - const Variable &tileOffsets, // [C, tileH, tileW] - const Variable &tileGaussianIds, // [n_isects] - std::optional backgrounds = std::nullopt, // [C,D] - std::optional masks = std::nullopt); // [C,tileH,tileW] bool - - static VariableList backward(AutogradContext *ctx, VariableList gradOutput); -}; - -} // namespace fvdb::detail::autograd - -#endif // FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZEFROMWORLD_H diff --git a/src/fvdb/detail/autograd/GaussianRasterizeSparse.cpp b/src/fvdb/detail/autograd/GaussianRasterizeSparse.cpp deleted file mode 100644 index da076269c..000000000 --- a/src/fvdb/detail/autograd/GaussianRasterizeSparse.cpp +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#include -#include -#include -#include -#include - -#include - -namespace fvdb::detail::autograd { - -RasterizeGaussiansToPixelsSparse::VariableList -RasterizeGaussiansToPixelsSparse::forward( - RasterizeGaussiansToPixelsSparse::AutogradContext *ctx, - const JaggedTensor &pixelsToRender, // [C, num_pixels, 2] - const RasterizeGaussiansToPixelsSparse::Variable &means2d, // [C, N, 2] - const RasterizeGaussiansToPixelsSparse::Variable &conics, // [C, N, 3] - const RasterizeGaussiansToPixelsSparse::Variable &colors, // [C, N, 3] - const RasterizeGaussiansToPixelsSparse::Variable &opacities, // [N] - const uint32_t imageWidth, - const uint32_t imageHeight, - const uint32_t imageOriginW, - const uint32_t imageOriginH, - const uint32_t tileSize, - const RasterizeGaussiansToPixelsSparse::Variable - &tileOffsets, // [C, tile_height, tile_width] (dense) or [num_active_tiles + 1] (sparse) - const RasterizeGaussiansToPixelsSparse::Variable &tileGaussianIds, // [n_isects] - const RasterizeGaussiansToPixelsSparse::Variable &activeTiles, // [num_active_tiles] - const RasterizeGaussiansToPixelsSparse::Variable - &tilePixelMask, // [num_active_tiles, tileSize, tileSize] - const RasterizeGaussiansToPixelsSparse::Variable &tilePixelCumsum, // [num_active_tiles + 1] - const RasterizeGaussiansToPixelsSparse::Variable &pixelMap, // [num_pixels] - const bool absgrad, - std::optional backgrounds, - std::optional masks) { - FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixelsSparse::forward"); - - auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { - const ops::RenderWindow2D renderWindow{imageWidth, imageHeight, imageOriginW, imageOriginH}; - return ops::dispatchGaussianSparseRasterizeForward(pixelsToRender, - means2d, - conics, - colors, - opacities, - renderWindow, - tileSize, - tileOffsets, - tileGaussianIds, - activeTiles, - tilePixelMask, - tilePixelCumsum, - pixelMap, - backgrounds, - masks); - }); - JaggedTensor renderedColors = std::get<0>(variables); - JaggedTensor renderedAlphas = std::get<1>(variables); - JaggedTensor lastIds = std::get<2>(variables); - - const auto joffsets = pixelsToRender.joffsets(); - const auto jidx = pixelsToRender.jidx(); - const auto jlidx = pixelsToRender.jlidx(); - const auto numOuterLists = pixelsToRender.num_outer_lists(); - - std::vector toSave = {means2d, - conics, - colors, - opacities, - tileOffsets, - tileGaussianIds, - pixelsToRender.jdata(), - renderedColors.jdata(), - renderedAlphas.jdata(), - lastIds.jdata(), - joffsets, - jidx, - jlidx, - activeTiles, - tilePixelMask, - tilePixelCumsum, - pixelMap}; - if (backgrounds.has_value()) { - toSave.push_back(backgrounds.value()); - ctx->saved_data["has_backgrounds"] = true; - } else { - ctx->saved_data["has_backgrounds"] = false; - } - if (masks.has_value()) { - toSave.push_back(masks.value()); - ctx->saved_data["has_masks"] = true; - } else { - ctx->saved_data["has_masks"] = false; - } - ctx->save_for_backward(toSave); - - ctx->saved_data["imageWidth"] = (int64_t)imageWidth; - ctx->saved_data["imageHeight"] = (int64_t)imageHeight; - ctx->saved_data["tileSize"] = (int64_t)tileSize; - ctx->saved_data["imageOriginW"] = (int64_t)imageOriginW; - ctx->saved_data["imageOriginH"] = (int64_t)imageOriginH; - ctx->saved_data["numOuterLists"] = (int64_t)numOuterLists; - ctx->saved_data["absgrad"] = absgrad; - - return {renderedColors.jdata(), renderedAlphas.jdata()}; -} - -RasterizeGaussiansToPixelsSparse::VariableList -RasterizeGaussiansToPixelsSparse::backward( - RasterizeGaussiansToPixelsSparse::AutogradContext *ctx, - RasterizeGaussiansToPixelsSparse::VariableList gradOutput) { - FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixelsSparse::backward"); - Variable dLossDRenderedFeaturesJData = gradOutput.at(0); - Variable dLossDRenderedAlphasJData = gradOutput.at(1); - - // ensure the gradients are contiguous if they are not None - if (dLossDRenderedFeaturesJData.defined()) { - dLossDRenderedFeaturesJData = dLossDRenderedFeaturesJData.contiguous(); - } - if (dLossDRenderedAlphasJData.defined()) { - dLossDRenderedAlphasJData = dLossDRenderedAlphasJData.contiguous(); - } - - VariableList saved = ctx->get_saved_variables(); - Variable means2d = saved.at(0); - Variable conics = saved.at(1); - Variable features = saved.at(2); - Variable opacities = saved.at(3); - Variable tileOffsets = saved.at(4); - Variable tileGaussianIds = saved.at(5); - Variable pixelsToRenderJData = saved.at(6); - Variable renderedColorsJData = saved.at(7); - Variable renderedAlphasJData = saved.at(8); - Variable lastIdsJData = saved.at(9); - Variable joffsets = saved.at(10); - Variable jidx = saved.at(11); - Variable jlidx = saved.at(12); - Variable activeTiles = saved.at(13); - Variable tilePixelMask = saved.at(14); - Variable tilePixelCumsum = saved.at(15); - Variable pixelMap = saved.at(16); - - const bool hasBackgrounds = ctx->saved_data["has_backgrounds"].toBool(); - const bool hasMasks = ctx->saved_data["has_masks"].toBool(); - std::optional backgrounds = std::nullopt; - std::optional masks = std::nullopt; - int64_t optIdx = 17; - if (hasBackgrounds) { - backgrounds = saved.at(optIdx++); - } - if (hasMasks) { - masks = saved.at(optIdx++); - } - - const int imageWidth = (int)ctx->saved_data["imageWidth"].toInt(); - const int imageHeight = (int)ctx->saved_data["imageHeight"].toInt(); - const int tileSize = (int)ctx->saved_data["tileSize"].toInt(); - const int imageOriginW = (int)ctx->saved_data["imageOriginW"].toInt(); - const int imageOriginH = (int)ctx->saved_data["imageOriginH"].toInt(); - const int64_t numOuterLists = ctx->saved_data["numOuterLists"].toInt(); - const bool absgrad = ctx->saved_data["absgrad"].toBool(); - - auto pixelsToRender = JaggedTensor::from_jdata_joffsets_jidx_and_lidx_unsafe( - pixelsToRenderJData, joffsets, jidx, jlidx, numOuterLists); - auto renderedAlphas = pixelsToRender.jagged_like(renderedAlphasJData); - auto lastIds = pixelsToRender.jagged_like(lastIdsJData); - - auto dLossDRenderedFeatures = pixelsToRender.jagged_like(dLossDRenderedFeaturesJData); - auto dLossDRenderedAlphas = pixelsToRender.jagged_like(dLossDRenderedAlphasJData); - - auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { - const ops::RenderWindow2D renderWindow{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}; - return ops::dispatchGaussianSparseRasterizeBackward(pixelsToRender, - means2d, - conics, - features, - opacities, - renderWindow, - tileSize, - tileOffsets, - tileGaussianIds, - renderedAlphas, - lastIds, - dLossDRenderedFeatures, - dLossDRenderedAlphas, - activeTiles, - tilePixelMask, - tilePixelCumsum, - pixelMap, - absgrad, - -1, - backgrounds, - masks); - }); - Variable dLossDMean2dAbs; - if (absgrad) { - dLossDMean2dAbs = std::get<0>(variables); - } else { - dLossDMean2dAbs = Variable(); - } - Variable dLossDMeans2d = std::get<1>(variables); - Variable dLossDConics = std::get<2>(variables); - Variable dLossDColors = std::get<3>(variables); - Variable dLossDOpacities = std::get<4>(variables); - - return { - Variable(), // pixelsToRender - dLossDMeans2d, // means2d - dLossDConics, // conics - dLossDColors, // features - dLossDOpacities, // opacities - Variable(), // imageWidth - Variable(), // imageHeight - Variable(), // imageOriginW - Variable(), // imageOriginH - Variable(), // tileSize - Variable(), // tileOffsets - Variable(), // tileGaussianIds - Variable(), // activeTiles - Variable(), // tilePixelMask - Variable(), // tilePixelCumsum - Variable(), // pixelMap - Variable(), // absgrad - Variable(), // backgrounds - Variable(), // masks - }; -} - -} // namespace fvdb::detail::autograd diff --git a/src/fvdb/detail/autograd/GaussianRasterizeSparse.h b/src/fvdb/detail/autograd/GaussianRasterizeSparse.h deleted file mode 100644 index d68c45e37..000000000 --- a/src/fvdb/detail/autograd/GaussianRasterizeSparse.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZESPARSE_H -#define FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZESPARSE_H - -#include - -#include - -namespace fvdb::detail::autograd { - -struct RasterizeGaussiansToPixelsSparse - : public torch::autograd::Function { - using VariableList = torch::autograd::variable_list; - using AutogradContext = torch::autograd::AutogradContext; - using Variable = torch::autograd::Variable; - - static VariableList forward( - AutogradContext *ctx, - const JaggedTensor &pixelsToRender, // [C, num_pixels, 2] - const Variable &means2d, // [C, N, 2] - const Variable &conics, // [C, N, 3] - const Variable &features, // [C, N, D] - const Variable &opacities, // [N] - const uint32_t imageWidth, - const uint32_t imageHeight, - const uint32_t imageOriginW, - const uint32_t imageOriginH, - const uint32_t tileSize, - const Variable - &tileOffsets, // [C, tile_height, tile_width] (dense) or [num_active_tiles + 1] (sparse) - const Variable &tileGaussianIds, // [n_isects] - const Variable &activeTiles, // [num_active_tiles] - const Variable &tilePixelMask, // [num_active_tiles, tileSize, tileSize] - const Variable &tilePixelCumsum, // [num_active_tiles + 1] - const Variable &pixelMap, // [num_pixels] - const bool absgrad, - std::optional backgrounds = std::nullopt, // [C, D] - std::optional masks = std::nullopt); // [C, tileH, tileW] - - static VariableList backward(AutogradContext *ctx, VariableList gradOutput); -}; - -} // namespace fvdb::detail::autograd - -#endif // FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZESPARSE_H diff --git a/src/fvdb/detail/io/GaussianPlyIO.cpp b/src/fvdb/detail/io/GaussianPlyIO.cpp index 5d1ed9a6b..34487b3ab 100644 --- a/src/fvdb/detail/io/GaussianPlyIO.cpp +++ b/src/fvdb/detail/io/GaussianPlyIO.cpp @@ -4,10 +4,7 @@ #include // Ops headers -#include - -// Utils headers -#include +#include #include #include @@ -125,7 +122,7 @@ parsePlyMetadataComments(tinyply::PlyFile &plyf) { plyf.request_properties_from_element(key, {"value"}); TORCH_CHECK(tensorData != nullptr, "Failed to read tensor metadata '" + key + - "'. Make sure it was written with fvdb::GaussianSplat3d::savePly"); + "'. Make sure it was written with fvdb::detail::io::saveGaussianPly"); retTensorMetadata[key] = std::make_tuple(tensorData, tensorShape); } else { continue; // Not a metadata comment @@ -180,7 +177,13 @@ plyMetadataComment(const std::string &key, const PlyMetadataTypes &value) { } } -std::tuple> +std::tuple> loadGaussianPly(const std::string &filename, torch::Device device) { using namespace tinyply; @@ -301,32 +304,28 @@ loadGaussianPly(const std::string &filename, torch::Device device) { retMetadata[key] = tensor; } - return std::make_tuple(GaussianSplat3d(means.to(device), - quats.to(device), - logScales.to(device), - logitOpacities.to(device), - sh0Coeffs.to(device), - shNCoeffs.to(device), - false, - false, - false), + return std::make_tuple(means.to(device), + quats.to(device), + logScales.to(device), + logitOpacities.to(device), + sh0Coeffs.to(device), + shNCoeffs.to(device), retMetadata); } void saveGaussianPly(const std::string &filename, - const GaussianSplat3d &gaussians, + const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &logitOpacities, + const torch::Tensor &sh0, + const torch::Tensor &shN, std::optional> trainingMetadata) { using namespace tinyply; - const fvdb::JaggedTensor validMask = FVDB_DISPATCH_KERNEL(gaussians.means().device(), [&]() { - return detail::ops::dispatchGaussianNanInfMask(gaussians.means(), - gaussians.quats(), - gaussians.logScales(), - gaussians.logitOpacities(), - gaussians.sh0(), - gaussians.shN()); - }); + const fvdb::JaggedTensor validMask = detail::ops::compute_gaussian_nan_inf_mask( + means, quats, logScales, logitOpacities, sh0, shN); std::filebuf fb; fb.open(filename, std::ios::out | std::ios::binary); @@ -337,31 +336,27 @@ saveGaussianPly(const std::string &filename, PlyFile plyf; const torch::Tensor meansCPU = - gaussians.means().index({validMask.jdata(), torch::indexing::Ellipsis}).cpu().contiguous(); + means.index({validMask.jdata(), torch::indexing::Ellipsis}).cpu().contiguous(); const torch::Tensor quatsCPU = - gaussians.quats().index({validMask.jdata(), torch::indexing::Ellipsis}).cpu().contiguous(); - const torch::Tensor scalesCPU = gaussians.logScales() - .index({validMask.jdata(), torch::indexing::Ellipsis}) - .cpu() - .contiguous(); - const torch::Tensor opacitiesCPU = - gaussians.logitOpacities().index({validMask.jdata()}).cpu().contiguous(); + quats.index({validMask.jdata(), torch::indexing::Ellipsis}).cpu().contiguous(); + const torch::Tensor scalesCPU = + logScales.index({validMask.jdata(), torch::indexing::Ellipsis}).cpu().contiguous(); + const torch::Tensor opacitiesCPU = logitOpacities.index({validMask.jdata()}).cpu().contiguous(); // [N, D] const torch::Tensor shCoeffs0CPU = - gaussians.sh0().index({validMask.jdata(), 0, torch::indexing::Ellipsis}).cpu().contiguous(); + sh0.index({validMask.jdata(), 0, torch::indexing::Ellipsis}).cpu().contiguous(); // [N, K-1, D] const torch::Tensor shCoeffsNCPU = [&]() { - if (gaussians.shN().numel() <= 0) { - return torch::zeros({meansCPU.size(0), 0}, - gaussians.shN().options().device(torch::kCPU)); + if (shN.numel() <= 0) { + return torch::zeros({meansCPU.size(0), 0}, shN.options().device(torch::kCPU)); } else { // ShN has shape [N, K-1, D], meaning the spherical harmonic coefficients are ordered // by basis, then channel. i.e. RGBRGB... // Gaussian PLYs expect the coefficients to be ordered by channel, then basis. i.e. // RR...GG...BB... So we permute the axes to [N, D, K-1] and then reshape to [N, // D*(K-1)] - return gaussians.shN() + return shN .index({validMask.jdata(), torch::indexing::Slice(), torch::indexing::Ellipsis}) .cpu() .contiguous() diff --git a/src/fvdb/detail/io/GaussianPlyIO.h b/src/fvdb/detail/io/GaussianPlyIO.h index 63ba9f1cd..a60325d2f 100644 --- a/src/fvdb/detail/io/GaussianPlyIO.h +++ b/src/fvdb/detail/io/GaussianPlyIO.h @@ -4,11 +4,17 @@ #ifndef FVDB_DETAIL_IO_GAUSSIANPLYIO_H #define FVDB_DETAIL_IO_GAUSSIANPLYIO_H -#include +#include + +#include +#include +#include +#include +#include namespace fvdb::detail::io { -/// The types of valid metadata you can save in a PLY file alongside Gaussians +/// Type alias for PLY metadata values. using PlyMetadataTypes = std::variant; /// Magic string prepended to additional metadata properties stored in PLY files @@ -19,24 +25,36 @@ inline static const size_t MAX_PLY_KEY_LENGTH = 256; inline static const std::string PLY_VERSION_STRING = "fvdb_ply 1.0.0"; -/// @brief Load a PLY file's means, quats, scales, opacities, and SH coefficients as the state -/// of this GaussianSplat3d object +/// @brief Load a PLY file's means, quats, scales, opacities, and SH coefficients as raw tensors. /// @param filename Filename of the PLY file /// @param device Device to transfer the loaded tensors to -/// @return The loaded GaussianSplat3d class, and a dictionary of metadata (can be empty if no -// metadata was saved in the PLY file). The metadata keys are strings and the values are either -// strings, int64s, doubles, or tensors. -std::tuple> +/// @return A tuple of (means, quats, logScales, logitOpacities, sh0, shN, metadata) +std::tuple> loadGaussianPly(const std::string &filename, torch::Device device = torch::kCPU); -/// @brief Save this scene and optional training metadata to a PLY file with the given filename +/// @brief Save Gaussian splat data and optional training metadata to a PLY file. /// @param filename The path to save the PLY file to -/// @param gaussians The GaussianSplat3d object containing the Gaussians to saved. -/// @param metadata An optional dictionary of training metadata to include in the PLY file. The -/// keys are strings and the values are either strings, int64s, doubles, or tensors +/// @param means [N, 3] Gaussian means +/// @param quats [N, 4] Gaussian quaternions +/// @param logScales [N, 3] Log of Gaussian scales +/// @param logitOpacities [N] Logit of Gaussian opacities +/// @param sh0 [N, 1, D] Degree-0 SH coefficients +/// @param shN [N, K-1, D] Higher-degree SH coefficients +/// @param metadata An optional dictionary of training metadata to include in the PLY file void saveGaussianPly( const std::string &filename, - const GaussianSplat3d &gaussians, + const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &logitOpacities, + const torch::Tensor &sh0, + const torch::Tensor &shN, std::optional> metadata = std::nullopt); } // namespace fvdb::detail::io diff --git a/src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.cu b/src/fvdb/detail/ops/AddNoiseToGaussianMeans.cu similarity index 66% rename from src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.cu rename to src/fvdb/detail/ops/AddNoiseToGaussianMeans.cu index 41acd1491..2ad52931c 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.cu +++ b/src/fvdb/detail/ops/AddNoiseToGaussianMeans.cu @@ -2,12 +2,16 @@ // SPDX-License-Identifier: Apache-2.0 // -#include -#include +#include #include #include +#include +#include #include #include +#include +#include +#include #include @@ -18,6 +22,16 @@ namespace fvdb::detail::ops { using fvdb::detail::deviceChunk; using fvdb::detail::mergeStreams; +// Internal dispatch template (specializations defined below). +template +void dispatch_add_noise_to_gaussian_means(torch::Tensor &means, + const torch::Tensor &logScales, + const torch::Tensor &logitOpacities, + const torch::Tensor &quats, + const float noiseScale, + const float t, + const float k); + template inline __device__ ScalarType sigmoid(ScalarType x) { @@ -26,17 +40,16 @@ sigmoid(ScalarType x) { template __global__ void -gaussianMCMCAddNoiseKernel(int64_t localToGlobalOffset, - int64_t localSize, - fvdb::TorchRAcc64 outMeans, - fvdb::TorchRAcc64 logScales, - fvdb::TorchRAcc64 logitOpacities, - fvdb::TorchRAcc64 quats, - fvdb::TorchRAcc64 baseNoise, - const ScalarType noiseScale, - const ScalarType t, - const ScalarType k) { - const auto N = outMeans.size(0); +add_noise_to_gaussian_means_kernel(int64_t localToGlobalOffset, + int64_t localSize, + fvdb::TorchRAcc64 outMeans, + fvdb::TorchRAcc64 logScales, + fvdb::TorchRAcc64 logitOpacities, + fvdb::TorchRAcc64 quats, + fvdb::TorchRAcc64 baseNoise, + const ScalarType noiseScale, + const ScalarType t, + const ScalarType k) { for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x + localToGlobalOffset; idx < localSize + localToGlobalOffset; idx += blockDim.x * gridDim.x) { @@ -81,7 +94,7 @@ launchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3] const int blockDim = DEFAULT_BLOCK_DIM; const int gridDim = fvdb::GET_BLOCKS(size, blockDim); - gaussianMCMCAddNoiseKernel<<>>( + add_noise_to_gaussian_means_kernel<<>>( offset, size, means.packed_accessor64(), @@ -98,13 +111,13 @@ launchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3] template <> void -dispatchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3] - const torch::Tensor &logScales, // [N] - const torch::Tensor &logitOpacities, // [N] - const torch::Tensor &quats, // [N, 4] - const float noiseScale, - const float t, - const float k) { +dispatch_add_noise_to_gaussian_means(torch::Tensor &means, // [N, 3] + const torch::Tensor &logScales, // [N] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &quats, // [N, 4] + const float noiseScale, + const float t, + const float k) { FVDB_FUNC_RANGE(); const at::cuda::OptionalCUDAGuard device_guard(device_of(means)); @@ -119,13 +132,14 @@ dispatchGaussianMCMCAddNoise(torch::Tensor &means, template <> void -dispatchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3] input/output - const torch::Tensor &logScales, // [N, 3] - const torch::Tensor &logitOpacities, // [N] - const torch::Tensor &quats, // [N, 4] - const float noiseScale, - const float t, - const float k) { +dispatch_add_noise_to_gaussian_means( + torch::Tensor &means, // [N, 3] input/output + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &quats, // [N, 4] + const float noiseScale, + const float t, + const float k) { FVDB_FUNC_RANGE(); const auto N = means.size(0); @@ -160,14 +174,28 @@ dispatchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3 template <> void -dispatchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3] input/output - const torch::Tensor &logScales, // [N, 3] - const torch::Tensor &logitOpacities, // [N] - const torch::Tensor &quats, // [N, 4] - const float noiseScale, - const float t, - const float k) { +dispatch_add_noise_to_gaussian_means(torch::Tensor &means, // [N, 3] input/output + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &quats, // [N, 4] + const float noiseScale, + const float t, + const float k) { TORCH_CHECK_NOT_IMPLEMENTED(false, "GaussianMCMCAddNoise is not implemented for CPU"); } +void +add_noise_to_gaussian_means(torch::Tensor &means, // [N, 3] input/output + const torch::Tensor &logScales, // [N] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &quats, // [N, 4] + const float noiseScale, + const float t, + const float k) { + FVDB_DISPATCH_KERNEL(means.device(), [&]() { + return dispatch_add_noise_to_gaussian_means( + means, logScales, logitOpacities, quats, noiseScale, t, k); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/AddNoiseToGaussianMeans.h b/src/fvdb/detail/ops/AddNoiseToGaussianMeans.h new file mode 100644 index 000000000..47a3f285f --- /dev/null +++ b/src/fvdb/detail/ops/AddNoiseToGaussianMeans.h @@ -0,0 +1,38 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef FVDB_DETAIL_OPS_ADDNOISETOGAUSSIANMEANS_H +#define FVDB_DETAIL_OPS_ADDNOISETOGAUSSIANMEANS_H + +#include + +namespace fvdb { +namespace detail { +namespace ops { + +/// @brief Add noise to Gaussian means, scaled by opacity and covariance. +/// +/// Dispatches to the appropriate device implementation (CPU, CUDA, or PrivateUse1) +/// based on the device of the input tensors. +/// +/// @param[in,out] means 3D positions of Gaussians [N, 3] (modified in place) +/// @param[in] logScales Log scale factors of Gaussians [N, 3] +/// @param[in] logitOpacities Logit opacity values of Gaussians [N] +/// @param[in] quats Quaternion rotations of Gaussians [N, 4] +/// @param[in] noiseScale Overall noise magnitude +/// @param[in] t Opacity threshold for noise scaling sigmoid +/// @param[in] k Sharpness of the noise scaling sigmoid +void add_noise_to_gaussian_means(torch::Tensor &means, // [N, 3] input/output + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &quats, // [N, 4] + const float noiseScale, + const float t, + const float k); + +} // namespace ops +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_OPS_ADDNOISETOGAUSSIANMEANS_H diff --git a/src/fvdb/detail/ops/BuildFineGridFromCoarse.cu b/src/fvdb/detail/ops/BuildFineGridFromCoarse.cu index 7d9827338..a1984b0ee 100644 --- a/src/fvdb/detail/ops/BuildFineGridFromCoarse.cu +++ b/src/fvdb/detail/ops/BuildFineGridFromCoarse.cu @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -36,44 +37,7 @@ dispatchBuildFineGridFromCoarse(const GridBatchData &coarseBatchHdl, const nanovdb::Coord subdivisionFactor, const std::optional &subdivMask); -__device__ inline void -copyCoords(const fvdb::JIdxType bidx, - const int64_t base, - const nanovdb::Coord &ijk0, - const nanovdb::CoordBBox &bbox, - TorchRAcc64 outIJK, - TorchRAcc64 outIJKBIdx) { - static_assert(sizeof(nanovdb::Coord) == 3 * sizeof(int32_t)); - nanovdb::Coord ijk; - int32_t count = 0; - for (int di = bbox.min()[0]; di <= bbox.max()[0]; di += 1) { - for (int dj = bbox.min()[1]; dj <= bbox.max()[1]; dj += 1) { - for (int dk = bbox.min()[2]; dk <= bbox.max()[2]; dk += 1) { - ijk = ijk0 + nanovdb::Coord(di, dj, dk); - outIJK[base + count][0] = ijk[0]; - outIJK[base + count][1] = ijk[1]; - outIJK[base + count][2] = ijk[2]; - outIJKBIdx[base + count] = bidx; - count += 1; - } - } - } -} - -__device__ inline void -copyCoords(const fvdb::JIdxType bidx, - const int64_t base, - const nanovdb::Coord size, - const nanovdb::Coord &ijk0, - TorchRAcc64 outIJK, - TorchRAcc64 outIJKBIdx) { - return copyCoords(bidx, - base, - ijk0, - nanovdb::CoordBBox(nanovdb::Coord(0), size - nanovdb::Coord(1)), - outIJK, - outIJKBIdx); -} +using fvdb::detail::copyCoords; __device__ void fineIjkForCoarseGridVoxelCallback(int32_t bidx, diff --git a/src/fvdb/detail/ops/BuildPaddedGrid.cu b/src/fvdb/detail/ops/BuildPaddedGrid.cu index 8eeb0c76d..d2a474857 100644 --- a/src/fvdb/detail/ops/BuildPaddedGrid.cu +++ b/src/fvdb/detail/ops/BuildPaddedGrid.cu @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -28,97 +29,9 @@ template nanovdb::GridHandle dispatchBuildPaddedGrid(const GridBatchData &baseBatchHdl, int bmin, int bmax, bool excludeBorder); -__device__ inline void -copyCoords(const fvdb::JIdxType bidx, - const int64_t base, - const nanovdb::Coord &ijk0, - const nanovdb::CoordBBox &bbox, - TorchRAcc64 outIJK, - TorchRAcc64 outIJKBIdx) { - static_assert(sizeof(nanovdb::Coord) == 3 * sizeof(int32_t)); - nanovdb::Coord ijk; - int32_t count = 0; - for (int di = bbox.min()[0]; di <= bbox.max()[0]; di += 1) { - for (int dj = bbox.min()[1]; dj <= bbox.max()[1]; dj += 1) { - for (int dk = bbox.min()[2]; dk <= bbox.max()[2]; dk += 1) { - ijk = ijk0 + nanovdb::Coord(di, dj, dk); - outIJK[base + count][0] = ijk[0]; - outIJK[base + count][1] = ijk[1]; - outIJK[base + count][2] = ijk[2]; - outIJKBIdx[base + count] = bidx; - count += 1; - } - } - } -} - -__device__ inline void -copyCoords(const fvdb::JIdxType bidx, - const int64_t base, - const nanovdb::Coord size, - const nanovdb::Coord &ijk0, - TorchRAcc64 outIJK, - TorchRAcc64 outIJKBIdx) { - return copyCoords(bidx, - base, - ijk0, - nanovdb::CoordBBox(nanovdb::Coord(0), size - nanovdb::Coord(1)), - outIJK, - outIJKBIdx); -} - -__device__ inline void -copyCoordsWithoutBorder( - const typename nanovdb::DefaultReadAccessor gridAccessor, - const fvdb::JIdxType bidx, - const int64_t base, - const nanovdb::Coord &ijk0, - const nanovdb::CoordBBox &bbox, - const TorchRAcc64 packInfoBase, - TorchRAcc64 outIJK, - TorchRAcc64 outIJKBIdx) { - static_assert(sizeof(nanovdb::Coord) == 3 * sizeof(int32_t)); - nanovdb::Coord ijk; - bool active = true; - for (int di = bbox.min()[0]; di <= bbox.max()[0]; di += 1) { - for (int dj = bbox.min()[1]; dj <= bbox.max()[1]; dj += 1) { - for (int dk = bbox.min()[2]; dk <= bbox.max()[2]; dk += 1) { - ijk = ijk0 + nanovdb::Coord(di, dj, dk); - active = active && gridAccessor.isActive(ijk); - } - } - } - if (active) { - int64_t outBase = packInfoBase[base]; - outIJK[outBase][0] = ijk0[0]; - outIJK[outBase][1] = ijk0[1]; - outIJK[outBase][2] = ijk0[2]; - outIJKBIdx[outBase] = bidx; - } -} - -__device__ inline void -countCoordsWithoutBorder( - const typename nanovdb::DefaultReadAccessor gridAccessor, - const fvdb::JIdxType bidx, - const int64_t base, - const nanovdb::Coord &ijk0, - const nanovdb::CoordBBox &bbox, - TorchRAcc64 outCounter) { - static_assert(sizeof(nanovdb::Coord) == 3 * sizeof(int32_t)); - nanovdb::Coord ijk; - bool active = true; - for (int di = bbox.min()[0]; di <= bbox.max()[0]; di += 1) { - for (int dj = bbox.min()[1]; dj <= bbox.max()[1]; dj += 1) { - for (int dk = bbox.min()[2]; dk <= bbox.max()[2]; dk += 1) { - ijk = ijk0 + nanovdb::Coord(di, dj, dk); - active = active && gridAccessor.isActive(ijk); - } - } - } - - outCounter[base] = active ? 1 : 0; -} +using fvdb::detail::copyCoords; +using fvdb::detail::copyCoordsWithoutBorder; +using fvdb::detail::countCoordsWithoutBorder; __device__ void ijkForGridVoxelCallback(int32_t bidx, diff --git a/src/fvdb/detail/ops/gsplat/GaussianSplatSparse.cu b/src/fvdb/detail/ops/BuildSparseGaussianTileLayout.cu similarity index 93% rename from src/fvdb/detail/ops/gsplat/GaussianSplatSparse.cu rename to src/fvdb/detail/ops/BuildSparseGaussianTileLayout.cu index 46b5d4edd..982efa774 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianSplatSparse.cu +++ b/src/fvdb/detail/ops/BuildSparseGaussianTileLayout.cu @@ -1,11 +1,15 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include +#include #include #include +#include +#include #include +#include +#include +#include #include #include @@ -16,15 +20,6 @@ namespace fvdb::detail::ops { -#define CUB_WRAPPER(func, ...) \ - do { \ - size_t temp_storage_bytes = 0; \ - func(nullptr, temp_storage_bytes, __VA_ARGS__); \ - auto &caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \ - auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \ - func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \ - } while (false) - // Definitions: // ----------------- // @@ -62,8 +57,8 @@ namespace fvdb::detail::ops { // where uv_n is a tensor of shape [P_n, 2] of pixel coordinates in the n^th image in the batch. // If a pixel (c, i, j) is in in pixels_to_render we call it *active*, otherwise it is *inactive*. // PRECONDITION: pixels_to_render must not contain duplicates. The caller -// (sparseProjectGaussiansImpl) deduplicates before calling computeSparseInfo and scatters results -// back afterward. +// (sparseProjectGaussiansImpl) deduplicates before calling build_sparse_gaussian_tile_layout and +// scatters results back afterward. // // Let: // AP denote the number of active pixels. @@ -212,13 +207,13 @@ zeros(at::IntArrayRef dims, c10::ScalarType dtype, torch::Device device) { } std::tuple -computeSparseInfo(const int32_t tileSideLength, - const int32_t numTilesW, - const int32_t numTilesH, - const fvdb::JaggedTensor &pixelsToRender) { +build_sparse_gaussian_tile_layout(const int32_t tileSideLength, + const int32_t numTilesW, + const int32_t numTilesH, + const fvdb::JaggedTensor &pixelsToRender) { FVDB_FUNC_RANGE(); TORCH_CHECK_NOT_IMPLEMENTED(pixelsToRender.device().is_cuda(), - "computeSparseInfo only implemented on the device"); + "build_sparse_gaussian_tile_layout only implemented on the device"); const at::cuda::OptionalCUDAGuard device_guard(device_of(pixelsToRender.jdata())); at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(pixelsToRender.device().index()); @@ -254,7 +249,6 @@ computeSparseInfo(const int32_t tileSideLength, auto outMaskAccessor = fvdb::tensorAccessor(tileMask); auto outTileIdAccessor = fvdb::tensorAccessor(perPixelTileIds); - // TODO we do not output tileMask currently, but we should AT_DISPATCH_INDEX_TYPES(pixelsToRender.scalar_type(), "computeTileMask", [&]() { const int32_t NUM_BLOCKS = GET_BLOCKS(numPixels, DEFAULT_BLOCK_DIM); computeTileMask<<>>( diff --git a/src/fvdb/detail/ops/gsplat/GaussianSplatSparse.h b/src/fvdb/detail/ops/BuildSparseGaussianTileLayout.h similarity index 84% rename from src/fvdb/detail/ops/gsplat/GaussianSplatSparse.h rename to src/fvdb/detail/ops/BuildSparseGaussianTileLayout.h index 94ac44faa..75325e4c6 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianSplatSparse.h +++ b/src/fvdb/detail/ops/BuildSparseGaussianTileLayout.h @@ -1,8 +1,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANSPLATSPARSE_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANSPLATSPARSE_H +#ifndef FVDB_DETAIL_OPS_BUILDSPARSEGAUSSIANTILELAYOUT_H +#define FVDB_DETAIL_OPS_BUILDSPARSEGAUSSIANTILELAYOUT_H #include @@ -42,11 +42,11 @@ namespace fvdb::detail::ops { /// the same image. Callers must deduplicate before calling this function. /// @return Tuple of (active_tiles, active_tile_mask, tile_pixel_mask, tile_pixel_cumsum, pixel_map) std::tuple -computeSparseInfo(const int32_t tileSideLength, - const int32_t numTilesW, - const int32_t numTilesH, - const fvdb::JaggedTensor &pixelsToRender); +build_sparse_gaussian_tile_layout(const int32_t tileSideLength, + const int32_t numTilesW, + const int32_t numTilesH, + const fvdb::JaggedTensor &pixelsToRender); } // namespace fvdb::detail::ops -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANSPLATSPARSE_H +#endif // FVDB_DETAIL_OPS_BUILDSPARSEGAUSSIANTILELAYOUT_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianComputeNanInfMask.cu b/src/fvdb/detail/ops/ComputeGaussianNanInfMask.cu similarity index 75% rename from src/fvdb/detail/ops/gsplat/GaussianComputeNanInfMask.cu rename to src/fvdb/detail/ops/ComputeGaussianNanInfMask.cu index 85a870e95..a94b25e3b 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianComputeNanInfMask.cu +++ b/src/fvdb/detail/ops/ComputeGaussianNanInfMask.cu @@ -1,9 +1,10 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include +#include #include #include +#include #include #include @@ -13,6 +14,15 @@ namespace fvdb { namespace detail { namespace ops { +// Internal dispatch template (specializations defined below). +template +fvdb::JaggedTensor dispatch_compute_gaussian_nan_inf_mask(const fvdb::JaggedTensor &means, + const fvdb::JaggedTensor &quats, + const fvdb::JaggedTensor &logScales, + const fvdb::JaggedTensor &logitOpacities, + const fvdb::JaggedTensor &sh0, + const fvdb::JaggedTensor &shN); + template __global__ __launch_bounds__(DEFAULT_BLOCK_DIM) void computeNanInfMaskKernel(int64_t localToGlobalOffset, @@ -71,12 +81,12 @@ computeNanInfMaskKernel(int64_t localToGlobalOffset, template <> fvdb::JaggedTensor -dispatchGaussianNanInfMask(const fvdb::JaggedTensor &means, - const fvdb::JaggedTensor &quats, - const fvdb::JaggedTensor &logScales, - const fvdb::JaggedTensor &logitOpacities, - const fvdb::JaggedTensor &sh0, - const fvdb::JaggedTensor &shN) { +dispatch_compute_gaussian_nan_inf_mask(const fvdb::JaggedTensor &means, + const fvdb::JaggedTensor &quats, + const fvdb::JaggedTensor &logScales, + const fvdb::JaggedTensor &logitOpacities, + const fvdb::JaggedTensor &sh0, + const fvdb::JaggedTensor &shN) { FVDB_FUNC_RANGE(); TORCH_CHECK_VALUE(means.rsize(0) == quats.rsize(0), "All inputs must have the same number of gaussians"); @@ -134,12 +144,13 @@ dispatchGaussianNanInfMask(const fvdb::JaggedTensor &means, template <> fvdb::JaggedTensor -dispatchGaussianNanInfMask(const fvdb::JaggedTensor &means, - const fvdb::JaggedTensor &quats, - const fvdb::JaggedTensor &logScales, - const fvdb::JaggedTensor &logitOpacities, - const fvdb::JaggedTensor &sh0, - const fvdb::JaggedTensor &shN) { +dispatch_compute_gaussian_nan_inf_mask( + const fvdb::JaggedTensor &means, + const fvdb::JaggedTensor &quats, + const fvdb::JaggedTensor &logScales, + const fvdb::JaggedTensor &logitOpacities, + const fvdb::JaggedTensor &sh0, + const fvdb::JaggedTensor &shN) { FVDB_FUNC_RANGE(); TORCH_CHECK_VALUE(means.rsize(0) == quats.rsize(0), "All inputs must have the same number of gaussians"); @@ -206,13 +217,27 @@ dispatchGaussianNanInfMask(const fvdb::JaggedTensor &means, template <> fvdb::JaggedTensor -dispatchGaussianNanInfMask(const fvdb::JaggedTensor &means, - const fvdb::JaggedTensor &quats, - const fvdb::JaggedTensor &logScales, - const fvdb::JaggedTensor &logitOpacities, - const fvdb::JaggedTensor &sh0, - const fvdb::JaggedTensor &shN) { - TORCH_CHECK_NOT_IMPLEMENTED(false, "dispatchGaussianNanInfMask not implemented on the CPU"); +dispatch_compute_gaussian_nan_inf_mask(const fvdb::JaggedTensor &means, + const fvdb::JaggedTensor &quats, + const fvdb::JaggedTensor &logScales, + const fvdb::JaggedTensor &logitOpacities, + const fvdb::JaggedTensor &sh0, + const fvdb::JaggedTensor &shN) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "dispatch_compute_gaussian_nan_inf_mask not implemented on the CPU"); +} + +fvdb::JaggedTensor +compute_gaussian_nan_inf_mask(const fvdb::JaggedTensor &means, + const fvdb::JaggedTensor &quats, + const fvdb::JaggedTensor &logScales, + const fvdb::JaggedTensor &logitOpacities, + const fvdb::JaggedTensor &sh0, + const fvdb::JaggedTensor &shN) { + return FVDB_DISPATCH_KERNEL(means.device(), [&]() { + return dispatch_compute_gaussian_nan_inf_mask( + means, quats, logScales, logitOpacities, sh0, shN); + }); } } // namespace ops diff --git a/src/fvdb/detail/ops/gsplat/GaussianComputeNanInfMask.h b/src/fvdb/detail/ops/ComputeGaussianNanInfMask.h similarity index 62% rename from src/fvdb/detail/ops/gsplat/GaussianComputeNanInfMask.h rename to src/fvdb/detail/ops/ComputeGaussianNanInfMask.h index 6f3309466..ad3fbb3c0 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianComputeNanInfMask.h +++ b/src/fvdb/detail/ops/ComputeGaussianNanInfMask.h @@ -1,8 +1,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCOMPUTENANINFMASK_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCOMPUTENANINFMASK_H +#ifndef FVDB_DETAIL_OPS_COMPUTEGAUSSIANNANINFMASK_H +#define FVDB_DETAIL_OPS_COMPUTEGAUSSIANNANINFMASK_H #include @@ -19,7 +19,8 @@ namespace ops { /// numerical stability in Gaussian Splatting algorithms, allowing invalid Gaussians to be /// filtered out before rendering. /// -/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) +/// Dispatches to the appropriate device implementation (CPU, CUDA, or PrivateUse1) +/// based on the device of the input tensors. /// /// @param[in] means 3D positions of Gaussians as a jagged tensor [C, N, 3] /// @param[in] quats Quaternion rotations of Gaussians as a jagged tensor [C, N, 4] @@ -30,16 +31,15 @@ namespace ops { /// /// @return A jagged tensor mask where True indicates valid values (no NaN/Inf) and False indicates /// invalid values -template -fvdb::JaggedTensor dispatchGaussianNanInfMask(const fvdb::JaggedTensor &means, - const fvdb::JaggedTensor &quats, - const fvdb::JaggedTensor &logScales, - const fvdb::JaggedTensor &logitOpacities, - const fvdb::JaggedTensor &sh0, - const fvdb::JaggedTensor &shN); +fvdb::JaggedTensor compute_gaussian_nan_inf_mask(const fvdb::JaggedTensor &means, + const fvdb::JaggedTensor &quats, + const fvdb::JaggedTensor &logScales, + const fvdb::JaggedTensor &logitOpacities, + const fvdb::JaggedTensor &sh0, + const fvdb::JaggedTensor &shN); } // namespace ops } // namespace detail } // namespace fvdb -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCOMPUTENANINFMASK_H +#endif // FVDB_DETAIL_OPS_COMPUTEGAUSSIANNANINFMASK_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeNumContributingGaussians.cu b/src/fvdb/detail/ops/CountContributingGaussians.cu similarity index 69% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeNumContributingGaussians.cu rename to src/fvdb/detail/ops/CountContributingGaussians.cu index d0bfd130f..5634e2335 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeNumContributingGaussians.cu +++ b/src/fvdb/detail/ops/CountContributingGaussians.cu @@ -2,16 +2,17 @@ // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include -#include -#include +#include #include #include +#include +#include +#include +#include #include +#include #include namespace fvdb::detail::ops { @@ -149,7 +150,7 @@ template struct RasterizeNumContributingGa } const ScalarType nextTransmittance = accumTransmittance * (1.0f - alpha); - if (nextTransmittance <= 1e-4f) { // this pixel is done: exclusive + if (nextTransmittance <= kTransmittanceThreshold) { // this pixel is done done = true; break; } @@ -236,7 +237,11 @@ launchRasterizeNumContributingGaussiansForwardKernel( // intersections const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] const torch::Tensor &tileGaussianIds, // [n_isects] - const RenderSettings &settings, // render settings + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, const std::optional &pixelsToRender = std::nullopt, // [C, NumPixels, 2] const std::optional &activeTiles = std::nullopt, const std::optional &tilePixelMask = std::nullopt, @@ -247,19 +252,17 @@ launchRasterizeNumContributingGaussiansForwardKernel( // tileOffsets can be 3D (dense) or 1D (sparse) const bool tileOffsetsAreSparse = tileOffsets.dim() == 1; if (!tileOffsetsAreSparse) { - TORCH_CHECK_VALUE(tileOffsets.size(2) == - (settings.imageWidth + settings.tileSize - 1) / settings.tileSize, + TORCH_CHECK_VALUE(tileOffsets.size(2) == (imageWidth + tileSize - 1) / tileSize, "tileOffsets width must match the number of tiles in image size"); - TORCH_CHECK_VALUE(tileOffsets.size(1) == - (settings.imageHeight + settings.tileSize - 1) / settings.tileSize, + TORCH_CHECK_VALUE(tileOffsets.size(1) == (imageHeight + tileSize - 1) / tileSize, "tileOffsets height must match the number of tiles in image size"); } // Get C from tileOffsets for dense mode // For sparse mode, C is unused, only used for output sizing for dense mode const uint32_t C = tileOffsetsAreSparse ? 0 : tileOffsets.size(0); - const uint32_t tileExtentH = (settings.imageHeight + settings.tileSize - 1) / settings.tileSize; - const uint32_t tileExtentW = (settings.imageWidth + settings.tileSize - 1) / settings.tileSize; + const uint32_t tileExtentH = (imageHeight + tileSize - 1) / tileSize; + const uint32_t tileExtentW = (imageWidth + tileSize - 1) / tileSize; TORCH_CHECK_VALUE(pixelMap.has_value() == pixelsToRender.has_value(), "pixelMap and pixelsToRender must be provided together"); @@ -268,9 +271,8 @@ launchRasterizeNumContributingGaussiansForwardKernel( "pixelMap must have the same number of elements as pixelsToRender"); } - auto sizes = pixelsToRender.has_value() - ? pixelsToRender->lsizes1() - : std::vector{C * settings.imageHeight * settings.imageWidth}; + auto sizes = pixelsToRender.has_value() ? pixelsToRender->lsizes1() + : std::vector{C * imageHeight * imageWidth}; std::vector numContributingGaussiansToRenderVec; std::vector alphasToRenderVec; @@ -291,8 +293,7 @@ launchRasterizeNumContributingGaussiansForwardKernel( // - vec2t xy; -- 8 bytes for float32 // - scalar_t opacity; -- 4 bytes for float32 // - vec3t conic; -- 12 bytes for float32 - const uint32_t sharedMem = - settings.tileSize * settings.tileSize * sizeof(Gaussian2D); + const uint32_t sharedMem = tileSize * tileSize * sizeof(Gaussian2D); if (cudaFuncSetAttribute(rasterizeNumContributingGaussiansForward, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -302,7 +303,7 @@ launchRasterizeNumContributingGaussiansForwardKernel( " bytes), try lowering tile_size."); } - const dim3 blockDim = {settings.tileSize, settings.tileSize, 1}; + const dim3 blockDim = {tileSize, tileSize, 1}; const dim3 gridDim = activeTiles.has_value() // sparse mode ? dim3(activeTiles.value().size(0), 1, 1) : dim3(C * tileExtentH * tileExtentW, 1, 1); @@ -312,11 +313,11 @@ launchRasterizeNumContributingGaussiansForwardKernel( opacities, backgrounds, masks, - settings.imageWidth, - settings.imageHeight, - settings.imageOriginW, - settings.imageOriginH, - settings.tileSize, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, tileOffsets, tileGaussianIds, outNumContributingGaussians, @@ -335,16 +336,51 @@ launchRasterizeNumContributingGaussiansForwardKernel( } // namespace +template +std::tuple +dispatch_count_contributing_gaussians(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize); + +template +std::tuple +dispatch_count_contributing_gaussians_sparse(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize); + template <> std::tuple -dispatchGaussianRasterizeNumContributingGaussians( +dispatch_count_contributing_gaussians( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] const torch::Tensor &tileGaussianIds, // [n_isects] - const RenderSettings &settings // render settings + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize ) { FVDB_FUNC_RANGE(); @@ -366,7 +402,11 @@ dispatchGaussianRasterizeNumContributingGaussians( masks, tileOffsets, tileGaussianIds, - settings) + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize) : launchRasterizeNumContributingGaussiansForwardKernel( means2d, conics, @@ -375,13 +415,16 @@ dispatchGaussianRasterizeNumContributingGaussians( masks, tileOffsets, tileGaussianIds, - settings); + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize); // Get C from tileOffsets for dense mode const auto C = tileOffsets.size(0); return std::make_tuple( - numContributingGaussians.jdata().reshape( - {C, settings.imageHeight, settings.imageWidth}), - alphas.jdata().reshape({C, settings.imageHeight, settings.imageWidth})); + numContributingGaussians.jdata().reshape({C, imageHeight, imageWidth}), + alphas.jdata().reshape({C, imageHeight, imageWidth})); }), AT_EXPAND(AT_FLOATING_TYPES), c10::kHalf); @@ -389,21 +432,24 @@ dispatchGaussianRasterizeNumContributingGaussians( template <> std::tuple -dispatchGaussianRasterizeNumContributingGaussians( +dispatch_count_contributing_gaussians( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] const torch::Tensor &tileGaussianIds, // [n_isects] - const RenderSettings &settings // render settings -) { + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize) { TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } template <> std::tuple -dispatchGaussianSparseRasterizeNumContributingGaussians( +dispatch_count_contributing_gaussians_sparse( const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] @@ -414,7 +460,11 @@ dispatchGaussianSparseRasterizeNumContributingGaussians( const torch::Tensor &tilePixelMask, const torch::Tensor &tilePixelCumsum, const torch::Tensor &pixelMap, - const RenderSettings &settings) { // render settings + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize) { FVDB_FUNC_RANGE(); const bool isPacked = means2d.dim() == 2; @@ -434,7 +484,11 @@ dispatchGaussianSparseRasterizeNumContributingGaussians( masks, tileOffsets, tileGaussianIds, - settings, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, pixelsToRender, activeTiles, tilePixelMask, @@ -449,7 +503,11 @@ dispatchGaussianSparseRasterizeNumContributingGaussians( masks, tileOffsets, tileGaussianIds, - settings, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, pixelsToRender, activeTiles, tilePixelMask, @@ -463,7 +521,7 @@ dispatchGaussianSparseRasterizeNumContributingGaussians( template <> std::tuple -dispatchGaussianSparseRasterizeNumContributingGaussians( +dispatch_count_contributing_gaussians_sparse( const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] @@ -474,8 +532,72 @@ dispatchGaussianSparseRasterizeNumContributingGaussians( const torch::Tensor &tilePixelMask, const torch::Tensor &tilePixelCumsum, const torch::Tensor &pixelMap, - const RenderSettings &settings) { // render settings + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize) { TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } +std::tuple +count_contributing_gaussians(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize) { + return FVDB_DISPATCH_KERNEL_DEVICE(means2d.device(), [&]() { + return dispatch_count_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize); + }); +} + +std::tuple +count_contributing_gaussians_sparse(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize) { + return FVDB_DISPATCH_KERNEL_DEVICE(means2d.device(), [&]() { + return dispatch_count_contributing_gaussians_sparse(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + pixelsToRender, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/CountContributingGaussians.h b/src/fvdb/detail/ops/CountContributingGaussians.h new file mode 100644 index 000000000..3b09528b3 --- /dev/null +++ b/src/fvdb/detail/ops/CountContributingGaussians.h @@ -0,0 +1,62 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_COUNTCONTRIBUTINGGAUSSIANS_H +#define FVDB_DETAIL_OPS_COUNTCONTRIBUTINGGAUSSIANS_H + +#include + +#include + +#include +#include + +namespace fvdb { +namespace detail { +namespace ops { + +/// @brief Count contributing Gaussians per pixel (dense). +/// +/// For each pixel, counts how many Gaussians contribute non-negligible opacity +/// using the same tile-based traversal as the rasterizer. +/// +/// @return (num_contributing [C, H, W], alpha [C, H, W, 1]) +std::tuple +count_contributing_gaussians(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize); + +/// @brief Count contributing Gaussians at specified pixels (sparse). +/// +/// Sparse variant that counts only at the requested pixel locations. +/// +/// @return (num_contributing, alpha) as JaggedTensors +std::tuple +count_contributing_gaussians_sparse(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize); + +} // namespace ops +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_OPS_COUNTCONTRIBUTINGGAUSSIANS_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsBackward.cu b/src/fvdb/detail/ops/EvalGaussianShBackward.cu similarity index 89% rename from src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsBackward.cu rename to src/fvdb/detail/ops/EvalGaussianShBackward.cu index 63a7fd8d1..a9b27b58c 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsBackward.cu +++ b/src/fvdb/detail/ops/EvalGaussianShBackward.cu @@ -1,15 +1,15 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include +#include #include #include +#include #include +#include #include #include -#include #include namespace fvdb { @@ -22,9 +22,7 @@ namespace { // direction and write it out, so pull this into a function. template __device__ inline void -writeDLossDViewDir( - T x, T y, T z, T vX, T vY, T vZ, T inorm, typename Vec3Type::type *dLossDViewDir) { - using vec3t = typename Vec3Type::type; +writeDLossDViewDir(T x, T y, T z, T vX, T vY, T vZ, T inorm, float3 *dLossDViewDir) { const T dLossDViewDirDotViewDir = x * vX + y * vY + z * vZ; dLossDViewDir->x = (vX - dLossDViewDirDotViewDir * x) * inorm; @@ -34,16 +32,15 @@ writeDLossDViewDir( template inline __device__ void -evalShFunctionVJP(const int64_t degree, // degree of SH to be evaluated - const int64_t ci, // camera index - const int64_t gi, // gaussian index - const int64_t c, // render channel - const typename Vec3Type::type &dir, // [3] +evalShFunctionVJP(const int64_t degree, // degree of SH to be evaluated + const int64_t gi, // gaussian index + const int64_t c, // render channel + const float3 &dir, // [3] const torch::PackedTensorAccessor64 coeffsN, - const T *dLossDRenderQuantities, // [D] + const T *dLossDRenderQuantities, // [D] torch::PackedTensorAccessor64 dLossDSh0Coeffs, torch::PackedTensorAccessor64 dLossDShNCoeffs, - typename Vec3Type::type *dLossDViewDir // [3] optional + float3 *dLossDViewDir // [3] optional ) { T dLossDRenderQuantitiesLocal = dLossDRenderQuantities[c]; @@ -312,7 +309,9 @@ computeShBackward( return; } - using vec3t = typename Vec3Type::type; + static_assert(std::is_same::value, + "SH kernels assume float precision (float3 casts)"); + using vec3t = float3; const bool hasViewDirs = viewDirs.size(0) > 0; const vec3t viewDir = hasViewDirs ? *reinterpret_cast(viewDirs[cid][gid].data()) : vec3t{T(0), T(0), T(0)}; @@ -322,7 +321,6 @@ computeShBackward( vec3t *outDLossDViewDirPtr = outDLossDViewDirs == nullptr ? nullptr : &dLossDViewDir; evalShFunctionVJP(shDegreeToUse, - cid, gid, c, viewDir, @@ -371,9 +369,20 @@ computeShDiffuseOnlyBackward( } // namespace +template +std::tuple +dispatch_eval_gaussian_sh_bwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const int64_t numGaussians, + const torch::Tensor &viewDirs, // [C, N, 3] or empty + const torch::Tensor &shNCoeffs, // [N, K-1, D] + const torch::Tensor &dLossDColors, + const torch::Tensor &radii, // [C, N] + const bool computeDLossDViewDirs); + template <> std::tuple -dispatchSphericalHarmonicsBackward( +dispatch_eval_gaussian_sh_bwd( const int64_t shDegreeToUse, const int64_t numCameras, const int64_t numGaussians, @@ -488,7 +497,7 @@ dispatchSphericalHarmonicsBackward( template <> std::tuple -dispatchSphericalHarmonicsBackward( +dispatch_eval_gaussian_sh_bwd( const int64_t shDegreeToUse, const int64_t numCameras, const int64_t numGaussians, @@ -587,7 +596,6 @@ dispatchSphericalHarmonicsBackward( std::vector prefetchSizes; const cudaMemLocation location = {cudaMemLocationTypeDevice, deviceId}; std::vector prefetchLocations = {location}; - std::vector prefetchLocationIndices = {0}; for (int cameraIndex = 0; cameraIndex < C; ++cameraIndex) { prefetchPtrs.emplace_back(dLossDViewDirs.data_ptr() + @@ -597,6 +605,7 @@ dispatchSphericalHarmonicsBackward( sizeof(scalar_t)); } + std::vector prefetchLocationIndices(prefetchPtrs.size(), 0); C10_CUDA_CHECK(cudaMemPrefetchBatchAsync(prefetchPtrs.data(), prefetchSizes.data(), prefetchPtrs.size(), @@ -686,18 +695,38 @@ dispatchSphericalHarmonicsBackward( template <> std::tuple -dispatchSphericalHarmonicsBackward( - const int64_t shDegreeToUse, - const int64_t numCameras, - const int64_t numGaussians, - const torch::Tensor &viewDirs, // [C, N, 3] - const torch::Tensor &shNCoeffs, // [N, K-1, D] - const torch::Tensor &dLossDRenderQuantities, // [C, N, D] - const torch::Tensor &radii, // [C, N] - const bool computeDLossDViewDirs) { +dispatch_eval_gaussian_sh_bwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const int64_t numGaussians, + const torch::Tensor &viewDirs, // [C, N, 3] or empty + const torch::Tensor &shNCoeffs, // [N, K-1, D] + const torch::Tensor &dLossDRenderQuantities, // [C, N, D] + const torch::Tensor &radii, // [C, N] + const bool computeDLossDViewDirs) { TORCH_CHECK(false, "CPU implementation not available"); } +std::tuple +eval_gaussian_sh_bwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const int64_t numGaussians, + const torch::Tensor &viewDirs, // [C, N, 3] or empty + const torch::Tensor &shNCoeffs, // [N, K-1, D] + const torch::Tensor &dLossDColors, + const torch::Tensor &radii, // [C, N] + const bool computeDLossDViewDirs) { + return FVDB_DISPATCH_KERNEL(dLossDColors.device(), [&]() { + return dispatch_eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDColors, + radii, + computeDLossDViewDirs); + }); +} + } // namespace ops } // namespace detail } // namespace fvdb diff --git a/src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsBackward.h b/src/fvdb/detail/ops/EvalGaussianShBackward.h similarity index 66% rename from src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsBackward.h rename to src/fvdb/detail/ops/EvalGaussianShBackward.h index fe5b0077b..f758d7b17 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsBackward.h +++ b/src/fvdb/detail/ops/EvalGaussianShBackward.h @@ -1,8 +1,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANSPHERICALHARMONICSBACKWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANSPHERICALHARMONICSBACKWARD_H +#ifndef FVDB_DETAIL_OPS_EVALGAUSSIANSHBACKWARD_H +#define FVDB_DETAIL_OPS_EVALGAUSSIANSHBACKWARD_H #include @@ -35,19 +35,18 @@ namespace ops { /// - SH coefficients [N, K, 3] - ∂L/∂sh_coeffs /// - Direction vectors [N, 3] - ∂L/∂dirs (if compute_v_dirs is true, otherwise empty /// tensor) -template std::tuple -dispatchSphericalHarmonicsBackward(const int64_t shDegreeToUse, - const int64_t numCameras, - const int64_t numGaussians, - const torch::Tensor &viewDirs, // [N, 3] - const torch::Tensor &shNCoeffs, // [N, K-1, D] - const torch::Tensor &dLossDColors, - const torch::Tensor &radii, // [N] - const bool computeDLossDViewDirs); +eval_gaussian_sh_bwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const int64_t numGaussians, + const torch::Tensor &viewDirs, // [C, N, 3] or empty + const torch::Tensor &shNCoeffs, // [N, K-1, D] + const torch::Tensor &dLossDColors, + const torch::Tensor &radii, // [C, N] + const bool computeDLossDViewDirs); } // namespace ops } // namespace detail } // namespace fvdb -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANSPHERICALHARMONICSBACKWARD_H +#endif // FVDB_DETAIL_OPS_EVALGAUSSIANSHBACKWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsForward.cu b/src/fvdb/detail/ops/EvalGaussianShForward.cu similarity index 76% rename from src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsForward.cu rename to src/fvdb/detail/ops/EvalGaussianShForward.cu index 62dd504fb..e2f401f20 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsForward.cu +++ b/src/fvdb/detail/ops/EvalGaussianShForward.cu @@ -1,13 +1,13 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include +#include #include +#include #include +#include #include -#include #include namespace fvdb { @@ -22,11 +22,10 @@ namespace { // implementation template inline __device__ T -evalShFunction(const int64_t degree, // degree of SH to be evaluated - const int64_t ci, // camera index - const int64_t gi, // gaussian index - const int64_t c, // render channel - const typename Vec3Type::type &viewDir, // [D] +evalShFunction(const int64_t degree, // degree of SH to be evaluated + const int64_t gi, // gaussian index + const int64_t c, // render channel + const float3 &viewDir, // [D] const torch::PackedTensorAccessor64 sh0Coeffs, const torch::PackedTensorAccessor64 shNCoeffs) { const T cSH0 = sh0Coeffs[gi][0][c]; @@ -160,11 +159,13 @@ computeSh( T result = T(0); if (!(radii != nullptr && radii[cid * N + gid] <= 0)) { - using vec3t = typename Vec3Type::type; + static_assert(std::is_same::value, + "SH kernels assume float precision (float3 casts)"); + using vec3t = float3; const bool hasViewDirs = viewDirs.size(0) > 0; const vec3t dir = hasViewDirs ? *reinterpret_cast(viewDirs[cid][gid].data()) : vec3t{0.f, 0.f, 0.f}; - result = evalShFunction(shDegreeToUse, cid, gid, c, dir, sh0Coeffs, shNCoeffs); + result = evalShFunction(shDegreeToUse, gid, c, dir, sh0Coeffs, shNCoeffs); } outRenderQuantities[(cid * N + gid) * D + c] = result; } @@ -199,14 +200,23 @@ computeShDiffuseOnly(const int64_t offset, } // namespace +template +torch::Tensor dispatch_eval_gaussian_sh_fwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const torch::Tensor &viewDirs, // [C, N, 3] + const torch::Tensor &sh0Coeffs, // [1, N, D] + const torch::Tensor &shNCoeffs, // [N, K-1, D] + const torch::Tensor &radii // [C, N] +); + template <> torch::Tensor -dispatchSphericalHarmonicsForward(const int64_t shDegreeToUse, - const int64_t numCameras, - const torch::Tensor &viewDirs, // [C, N, 3] - const torch::Tensor &sh0Coeffs, // [N, 1, D] - const torch::Tensor &shNCoeffs, // [N, K-1, D] - const torch::Tensor &radii // [C, N] +dispatch_eval_gaussian_sh_fwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const torch::Tensor &viewDirs, // [C, N, 3] + const torch::Tensor &sh0Coeffs, // [N, 1, D] + const torch::Tensor &shNCoeffs, // [N, K-1, D] + const torch::Tensor &radii // [C, N] ) { FVDB_FUNC_RANGE(); // Valid modes: @@ -302,13 +312,12 @@ dispatchSphericalHarmonicsForward(const int64_t shDegreeToUse, template <> torch::Tensor -dispatchSphericalHarmonicsForward( - const int64_t shDegreeToUse, - const int64_t numCameras, - const torch::Tensor &viewDirs, // [C, N, 3] - const torch::Tensor &sh0Coeffs, // [N, 1, D] - const torch::Tensor &shNCoeffs, // [N, K-1, D] - const torch::Tensor &radii // [C, N] +dispatch_eval_gaussian_sh_fwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const torch::Tensor &viewDirs, // [C, N, 3] + const torch::Tensor &sh0Coeffs, // [N, 1, D] + const torch::Tensor &shNCoeffs, // [N, K-1, D] + const torch::Tensor &radii // [C, N] ) { FVDB_FUNC_RANGE(); // Valid modes: @@ -412,16 +421,66 @@ dispatchSphericalHarmonicsForward( template <> torch::Tensor -dispatchSphericalHarmonicsForward(const int64_t shDegreeToUse, - const int64_t numCameras, - const torch::Tensor &dirs, // [N, 3] - const torch::Tensor &sh0Coeffs, // [1, N, D] - const torch::Tensor &shNCoeffs, // [K-1, N, D] - const torch::Tensor &radii // [N] +dispatch_eval_gaussian_sh_fwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const torch::Tensor &dirs, // [N, 3] + const torch::Tensor &sh0Coeffs, // [1, N, D] + const torch::Tensor &shNCoeffs, // [K-1, N, D] + const torch::Tensor &radii // [N] ) { TORCH_CHECK(false, "CPU implementation not available"); } +torch::Tensor +eval_spherical_harmonics(const torch::Tensor &means, + const torch::Tensor &sh0, + const torch::Tensor &shN, + const int64_t shDegreeToUse, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &perGaussianProjectedRadii) { + FVDB_FUNC_RANGE(); + const auto K = shN.size(1) + 1; // number of SH bases + const auto C = worldToCameraMatrices.size(0); // number of cameras + const auto actualShDegree = shDegreeToUse < 0 ? (std::sqrt(K) - 1) : shDegreeToUse; + if (actualShDegree == 0) { + return FVDB_DISPATCH_KERNEL(sh0.device(), [&]() { + return dispatch_eval_gaussian_sh_fwd(actualShDegree, + C, + torch::Tensor(), + sh0, + torch::Tensor(), + perGaussianProjectedRadii); + }); + } else { + auto [camToWorldMatrices, info] = torch::linalg_inv_ex(worldToCameraMatrices); + const torch::Tensor viewDirs = + means.index( + {torch::indexing::None, torch::indexing::Slice(), torch::indexing::Slice()}) - + camToWorldMatrices.index({torch::indexing::Slice(), + torch::indexing::None, + torch::indexing::Slice(0, 3), + 3}); // [1, N, 3] - [C, 1, 3] + return FVDB_DISPATCH_KERNEL(sh0.device(), [&]() { + return dispatch_eval_gaussian_sh_fwd( + actualShDegree, C, viewDirs, sh0, shN, perGaussianProjectedRadii); + }); + } +} + +torch::Tensor +eval_gaussian_sh_fwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const torch::Tensor &viewDirs, // [C, N, 3] + const torch::Tensor &sh0Coeffs, // [1, N, D] + const torch::Tensor &shNCoeffs, // [N, K-1, D] + const torch::Tensor &radii // [C, N] +) { + return FVDB_DISPATCH_KERNEL(sh0Coeffs.device(), [&]() { + return dispatch_eval_gaussian_sh_fwd( + shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); + }); +} + } // namespace ops } // namespace detail } // namespace fvdb diff --git a/src/fvdb/detail/ops/EvalGaussianShForward.h b/src/fvdb/detail/ops/EvalGaussianShForward.h new file mode 100644 index 000000000..c4d818136 --- /dev/null +++ b/src/fvdb/detail/ops/EvalGaussianShForward.h @@ -0,0 +1,67 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_EVALGAUSSIANSHFORWARD_H +#define FVDB_DETAIL_OPS_EVALGAUSSIANSHFORWARD_H + +#include + +namespace fvdb { +namespace detail { +namespace ops { + +/// @brief Evaluate spherical harmonics functions to compute features/colors. +/// +/// This function computes the features/colors for points in 3D space using spherical harmonics +/// (SH) representation. Spherical harmonics provide an efficient way to represent view-dependent +/// appearance for Gaussian Splatting and other rendering techniques. The output features are not +/// limited to RGB colors; they can have any number of channels. +/// +/// @param[in] shDegreeToUse Degree of spherical harmonics to use (0-3 typically, higher degrees +/// provide more detail) +/// @param[in] numCameras Number of cameras used for rendering +/// @param[in] viewDirs Direction vectors [N, 3] (packed) or [C, N, 3] (unpacked) representing +/// view directions. Need not be unit-length; the kernel normalizes internally. +/// @param[in] sh0Coeffs Spherical harmonic coefficients [N, 1, D] (packed) or +/// [1, N, D] (unpacked), where D is the number of feature channels +/// @param[in] shNCoeffs Higher order spherical harmonic coefficients [N, K-1, D] (packed) or +/// [K-1, N, D] (unpacked), where K depends on sh_degree_to_use (K=(sh_degree_to_use+1)²) +/// @param[in] radii radii [N] (packed) or [C, N] (unpacked) for view-dependent level-of-detail +/// control +/// +/// @return Features/colors [N, D] computed from the spherical harmonics evaluation +torch::Tensor eval_gaussian_sh_fwd(const int64_t shDegreeToUse, + const int64_t numCameras, + const torch::Tensor &viewDirs, // [C, N, 3] + const torch::Tensor &sh0Coeffs, // [1, N, D] + const torch::Tensor &shNCoeffs, // [N, K-1, D] + const torch::Tensor &radii // [C, N] +); + +/// @brief Evaluate spherical harmonics to compute view-dependent features/colors. +/// +/// Computes per-camera, per-Gaussian features using spherical harmonics (SH) representation. +/// Internally derives view directions from the camera matrices and Gaussian means, then dispatches +/// to the SH forward kernel. When @p shDegreeToUse is 0, view directions are not needed. +/// The output features are not limited to RGB colors; they can have any number of channels. +/// +/// @param[in] means Gaussian mean positions [N, 3] +/// @param[in] sh0 Degree-0 SH coefficients [N, 1, D] where D is number of channels +/// @param[in] shN Higher-degree SH coefficients [N, K-1, D] where +/// K = (shDegreeToUse+1)² +/// @param[in] shDegreeToUse SH degree to use (0-3 typically, -1 to use all available degrees) +/// @param[in] worldToCameraMatrices Camera extrinsics [C, 4, 4] +/// @param[in] perGaussianProjectedRadii Projected radii [C, N] for view-dependent level-of-detail +/// @return Evaluated SH features [C, N, D] +torch::Tensor eval_spherical_harmonics(const torch::Tensor &means, + const torch::Tensor &sh0, + const torch::Tensor &shN, + int64_t shDegreeToUse, + const torch::Tensor &worldToCameraMatrices, + const torch::Tensor &perGaussianProjectedRadii); + +} // namespace ops +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_OPS_EVALGAUSSIANSHFORWARD_H diff --git a/src/fvdb/detail/ops/gsplat/FusedSSIM.cu b/src/fvdb/detail/ops/FusedSSIM.cu similarity index 93% rename from src/fvdb/detail/ops/gsplat/FusedSSIM.cu rename to src/fvdb/detail/ops/FusedSSIM.cu index 0124dd93c..183058457 100644 --- a/src/fvdb/detail/ops/gsplat/FusedSSIM.cu +++ b/src/fvdb/detail/ops/FusedSSIM.cu @@ -1,3 +1,6 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 + // This file contains source code from the fused-ssim library obtained from // https://github.com/rahul-goel/fused-ssim. The fused-ssim library is licensed under the MIT // License. Refer to ORSB 5512107 for more. Original license text follows. @@ -25,11 +28,10 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include +#include +#include #include -#include - #include #include @@ -43,6 +45,8 @@ namespace detail { namespace ops { +using fvdb::detail::imagePrefetchBatchAsync; + namespace { namespace cg = cooperative_groups; @@ -332,8 +336,6 @@ fusedSSIMBackwardKernel(int localToGlobalOffset, int H, int W, int CH, - float C1, - float C2, const float *__restrict__ img1, const float *__restrict__ img2, const float *__restrict__ dL_dmap, @@ -566,8 +568,6 @@ fusedSSIMBackwardCUDA(double C1, H, W, CH, - static_cast(C1), - static_cast(C2), img1.contiguous().const_data_ptr(), img2.contiguous().const_data_ptr(), dL_dmap.contiguous().const_data_ptr(), @@ -580,50 +580,7 @@ fusedSSIMBackwardCUDA(double C1, return dL_dimg1; } -namespace { - -void -imagePrefetchBatchAsync(const torch::TensorList &tensors, - int localElementOffset, - int localElementCount, - int deviceId, - cudaStream_t stream) { - TORCH_CHECK(stream, "cudaMemPrefetchBatchAsync does not support the default stream"); -#if (CUDART_VERSION < 13000) - for (size_t i = 0; i < tensors.size(); ++i) { - const auto &tensor = tensors[i]; - TORCH_CHECK(tensor.is_contiguous(), "Tensor to prefetch is not contiguous"); - C10_CUDA_CHECK( - nanovdb::util::cuda::memPrefetchAsync(tensor.data_ptr() + localElementOffset, - localElementCount * sizeof(float), - deviceId, - stream)); - } -#else - std::vector prefetchPointers; - std::vector prefetchSizes; - cudaMemLocation location = {cudaMemLocationTypeDevice, deviceId}; - std::vector prefetchLocations = {location}; - std::vector prefetchLocationIndices = {0}; - - for (size_t i = 0; i < tensors.size(); ++i) { - const auto &tensor = tensors[i]; - TORCH_CHECK(tensor.is_contiguous(), "Tensor to prefetch is not contiguous"); - prefetchPointers.emplace_back(tensor.data_ptr() + localElementOffset); - prefetchSizes.emplace_back(localElementCount * sizeof(float)); - } - C10_CUDA_CHECK(cudaMemPrefetchBatchAsync(prefetchPointers.data(), - prefetchSizes.data(), - prefetchPointers.size(), - prefetchLocations.data(), - prefetchLocationIndices.data(), - prefetchLocations.size(), - 0, - stream)); -#endif -} - -} // namespace +namespace {} // namespace // ------------------------------------------ // PyTorch Interface (Forward) @@ -841,8 +798,6 @@ fusedSSIMBackwardPrivateUse1(double C1, H, W, CH, - static_cast(C1), - static_cast(C2), img1_.const_data_ptr(), img2_.const_data_ptr(), dL_dmap_.const_data_ptr(), diff --git a/src/fvdb/detail/ops/gsplat/FusedSSIM.h b/src/fvdb/detail/ops/FusedSSIM.h similarity index 94% rename from src/fvdb/detail/ops/gsplat/FusedSSIM.h rename to src/fvdb/detail/ops/FusedSSIM.h index 27f55101d..767c5afdc 100644 --- a/src/fvdb/detail/ops/gsplat/FusedSSIM.h +++ b/src/fvdb/detail/ops/FusedSSIM.h @@ -1,3 +1,6 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 + // This file contains source code from the fused-ssim library obtained from // https://github.com/rahul-goel/fused-ssim. The fused-ssim library is licensed under the MIT // License. Refer to ORSB 5512107 for more. Original license text follows. @@ -25,8 +28,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_FUSEDSSIM_H -#define FVDB_DETAIL_OPS_GSPLAT_FUSEDSSIM_H +#ifndef FVDB_DETAIL_OPS_FUSEDSSIM_H +#define FVDB_DETAIL_OPS_FUSEDSSIM_H #include diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeContributingGaussianIds.cu b/src/fvdb/detail/ops/IdentifyContributingGaussians.cu similarity index 81% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeContributingGaussianIds.cu rename to src/fvdb/detail/ops/IdentifyContributingGaussians.cu index dadf4ec8f..0081eceb9 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeContributingGaussianIds.cu +++ b/src/fvdb/detail/ops/IdentifyContributingGaussians.cu @@ -2,18 +2,19 @@ // SPDX-License-Identifier: Apache-2.0 // #include -#include -#include -#include -#include -#include -#include +#include +#include #include #include +#include +#include +#include +#include #include #include +#include #include namespace fvdb::detail::ops { @@ -404,7 +405,7 @@ template struct RasterizeContributingGauss } const ScalarType nextTransmittance = accumTransmittance * (1.0f - alpha); - if (nextTransmittance <= 1e-4f) { // this pixel is done: exclusive + if (nextTransmittance <= kTransmittanceThreshold) { // this pixel is done done = true; break; } @@ -516,9 +517,14 @@ launchRasterizeContributingGaussianIdsForwardKernel( // intersections const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] const torch::Tensor &tileGaussianIds, // [n_isects] - // render settings + // render parameters const std::optional &maybeNumContributingGaussians, // [C, NumPixels, 1] - const RenderSettings &settings, // render settings + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples, // sparse rendering parameters const std::optional &pixelsToRender = std::nullopt, // [C, NumPixels, 2] const std::optional &activeTiles = std::nullopt, @@ -533,31 +539,43 @@ launchRasterizeContributingGaussianIdsForwardKernel( // Check if we're in top-k mode (numDepthSamples > 0 indicates top-k) // If so, call the top-k kernel and reformat the results - if (settings.numDepthSamples > 0) { + if (numDepthSamples > 0) { // Call the top-k dispatch function fvdb::JaggedTensor outIds, outWeights; if (pixelsToRender.has_value()) { // Sparse mode: call sparse top-k dispatch std::tie(outIds, outWeights) = - dispatchGaussianSparseRasterizeTopContributingGaussianIds( - means2d, - conics, - opacities, - tileOffsets, - tileGaussianIds, - pixelsToRender.value(), - activeTiles.value(), - tilePixelMask.value(), - tilePixelCumsum.value(), - pixelMap.value(), - settings); + identify_top_contributing_gaussians_sparse(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + pixelsToRender.value(), + activeTiles.value(), + tilePixelMask.value(), + tilePixelCumsum.value(), + pixelMap.value(), + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples); } else { // Dense mode: call dense top-k dispatch torch::Tensor denseIds, denseWeights; - std::tie(denseIds, denseWeights) = - dispatchGaussianRasterizeTopContributingGaussianIds( - means2d, conics, opacities, tileOffsets, tileGaussianIds, settings); + std::tie(denseIds, denseWeights) = identify_top_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples); // Convert dense output [C, H, W, K] to JaggedTensor format const auto C = denseIds.size(0); @@ -595,7 +613,7 @@ launchRasterizeContributingGaussianIdsForwardKernel( numValidSamplesData, totalCount, C, - settings.numDepthSamples, + numDepthSamples, outIds.joffsets(), pixelsToRender.has_value()); } @@ -632,12 +650,12 @@ launchRasterizeContributingGaussianIdsForwardKernel( TORCH_CHECK_VALUE( torch::equal(numContributingGaussians.joffsets(), torch::arange(0, C + 1, 1, numContributingGaussians.options()) * - settings.imageHeight * settings.imageWidth), + imageHeight * imageWidth), "numContributingGaussians must have the same number of elements as the number of pixels in the images"); } - const auto tileExtentH = (settings.imageHeight + settings.tileSize - 1) / settings.tileSize; - const auto tileExtentW = (settings.imageWidth + settings.tileSize - 1) / settings.tileSize; + const auto tileExtentH = (imageHeight + tileSize - 1) / tileSize; + const auto tileExtentW = (imageWidth + tileSize - 1) / tileSize; if (tileOffsets.dim() == 3) { TORCH_CHECK_VALUE(tileOffsets.size(2) == tileExtentW, @@ -655,7 +673,7 @@ launchRasterizeContributingGaussianIdsForwardKernel( const auto &sizes = pixelsToRender.has_value() ? pixelsToRender->lsizes1() - : std::vector{C * settings.imageHeight * settings.imageWidth}; + : std::vector{C * imageHeight * imageWidth}; // maximum possible number of depth samples per pixel const auto maxDepthSamplesPerPixel = numContributingGaussians.jdata().max().item(); @@ -698,8 +716,7 @@ launchRasterizeContributingGaussianIdsForwardKernel( // - vec3t conic; -- 12 bytes for float32 // Note: We use thread-local storage for buffering writes, so only need shared memory for // Gaussians. We add 32 bytes for alignment padding. - const uint32_t sharedMem = - settings.tileSize * settings.tileSize * sizeof(Gaussian2D); + const uint32_t sharedMem = tileSize * tileSize * sizeof(Gaussian2D); if (cudaFuncSetAttribute(rasterizeContributingGaussianIdsForward, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -709,7 +726,7 @@ launchRasterizeContributingGaussianIdsForwardKernel( " bytes), try lowering tile_size."); } - const dim3 blockDim = {settings.tileSize, settings.tileSize, 1}; + const dim3 blockDim = {tileSize, tileSize, 1}; const dim3 gridDim = activeTiles.has_value() // sparse mode ? dim3(activeTiles.value().size(0), 1, 1) : dim3(C * tileExtentH * tileExtentW, 1, 1); @@ -718,11 +735,11 @@ launchRasterizeContributingGaussianIdsForwardKernel( opacities, backgrounds, masks, - settings.imageWidth, - settings.imageHeight, - settings.imageOriginW, - settings.imageOriginH, - settings.tileSize, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, maxDepthSamplesPerPixel, tileOffsets, tileGaussianIds, @@ -768,16 +785,56 @@ launchRasterizeContributingGaussianIdsForwardKernel( } // namespace +template +std::tuple dispatch_identify_contributing_gaussians( + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples, + const std::optional &maybeNumContributingGaussians); + +template +std::tuple dispatch_identify_contributing_gaussians_sparse( + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples, + const std::optional &maybeNumContributingGaussians); + template <> __host__ std::tuple -dispatchGaussianRasterizeContributingGaussianIds( +dispatch_identify_contributing_gaussians( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] const torch::Tensor &tileGaussianIds, // [n_isects] - const RenderSettings &settings, // render settings + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples, const std::optional &maybeNumContributingGaussians // [C, H, W] ) { FVDB_FUNC_RANGE(); @@ -826,7 +883,12 @@ dispatchGaussianRasterizeContributingGaussianIds( tileOffsets, tileGaussianIds, numContributingGaussiansJagged, - settings) + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples) : launchRasterizeContributingGaussianIdsForwardKernel( means2d, conics, @@ -836,7 +898,12 @@ dispatchGaussianRasterizeContributingGaussianIds( tileOffsets, tileGaussianIds, numContributingGaussiansJagged, - settings); + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples); return std::make_tuple(ids, weights); }), AT_EXPAND(AT_FLOATING_TYPES), @@ -845,21 +912,26 @@ dispatchGaussianRasterizeContributingGaussianIds( template <> std::tuple -dispatchGaussianRasterizeContributingGaussianIds( +dispatch_identify_contributing_gaussians( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] const torch::Tensor &tileGaussianIds, // [n_isects] - const RenderSettings &settings, // render settings + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples, const std::optional &maybeNumContributingGaussians) { TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } template <> __host__ std::tuple -dispatchGaussianSparseRasterizeContributingGaussianIds( +dispatch_identify_contributing_gaussians_sparse( const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] @@ -870,7 +942,12 @@ dispatchGaussianSparseRasterizeContributingGaussianIds( const torch::Tensor &tilePixelMask, const torch::Tensor &tilePixelCumsum, const torch::Tensor &pixelMap, - const RenderSettings &settings, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples, const std::optional &maybeNumContributingGaussians) { FVDB_FUNC_RANGE(); const bool isPacked = means2d.dim() == 2; @@ -892,7 +969,12 @@ dispatchGaussianSparseRasterizeContributingGaussianIds( tileOffsets, tileGaussianIds, maybeNumContributingGaussians, - settings, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples, pixelsToRender, activeTiles, tilePixelMask, @@ -908,7 +990,12 @@ dispatchGaussianSparseRasterizeContributingGaussianIds( tileOffsets, tileGaussianIds, maybeNumContributingGaussians, - settings, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples, pixelsToRender, activeTiles, tilePixelMask, @@ -922,7 +1009,7 @@ dispatchGaussianSparseRasterizeContributingGaussianIds( template <> std::tuple -dispatchGaussianSparseRasterizeContributingGaussianIds( +dispatch_identify_contributing_gaussians_sparse( const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] @@ -933,9 +1020,84 @@ dispatchGaussianSparseRasterizeContributingGaussianIds( const torch::Tensor &tilePixelMask, const torch::Tensor &tilePixelCumsum, const torch::Tensor &pixelMap, - const RenderSettings &settings, // render settings + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples, const std::optional &maybeNumContributingGaussians) { TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } +std::tuple +identify_contributing_gaussians(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples, + const std::optional &maybeNumContributingGaussians) { + return FVDB_DISPATCH_KERNEL_DEVICE(means2d.device(), [&]() { + return dispatch_identify_contributing_gaussians(means2d, + conics, + opacities, + tile_offsets, + tile_gaussian_ids, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples, + maybeNumContributingGaussians); + }); +} + +std::tuple +identify_contributing_gaussians_sparse( + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples, + const std::optional &maybeNumContributingGaussians) { + return FVDB_DISPATCH_KERNEL_DEVICE(means2d.device(), [&]() { + return dispatch_identify_contributing_gaussians_sparse( + means2d, + conics, + opacities, + tile_offsets, + tile_gaussian_ids, + pixelsToRender, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples, + maybeNumContributingGaussians); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/IdentifyContributingGaussians.h b/src/fvdb/detail/ops/IdentifyContributingGaussians.h new file mode 100644 index 000000000..0945ccb21 --- /dev/null +++ b/src/fvdb/detail/ops/IdentifyContributingGaussians.h @@ -0,0 +1,66 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_IDENTIFYCONTRIBUTINGGAUSSIANS_H +#define FVDB_DETAIL_OPS_IDENTIFYCONTRIBUTINGGAUSSIANS_H + +#include + +#include + +#include +#include + +namespace fvdb { +namespace detail { +namespace ops { + +/// @brief Identify contributing Gaussians per pixel (dense). +/// +/// Deep image rasterization: for each pixel, returns the IDs of all Gaussians +/// that contributed non-negligible opacity, along with per-Gaussian alpha weights. +/// +/// @return (gaussian_ids, weights) as JaggedTensors +std::tuple identify_contributing_gaussians( + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples, + const std::optional &maybeNumContributingGaussians = std::nullopt); + +/// @brief Identify contributing Gaussians at specified pixels (sparse). +/// +/// Sparse variant that identifies only at the requested pixel locations. +/// +/// @return (gaussian_ids, weights) as JaggedTensors +std::tuple identify_contributing_gaussians_sparse( + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples, + const std::optional &maybeNumContributingGaussians = std::nullopt); + +} // namespace ops +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_OPS_IDENTIFYCONTRIBUTINGGAUSSIANS_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeTopContributingGaussianIds.cu b/src/fvdb/detail/ops/IdentifyTopContributingGaussians.cu similarity index 71% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeTopContributingGaussianIds.cu rename to src/fvdb/detail/ops/IdentifyTopContributingGaussians.cu index 16bbedfb2..6723d3242 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeTopContributingGaussianIds.cu +++ b/src/fvdb/detail/ops/IdentifyTopContributingGaussians.cu @@ -2,16 +2,17 @@ // SPDX-License-Identifier: Apache-2.0 // #include -#include -#include -#include -#include -#include +#include #include #include +#include +#include +#include +#include #include +#include #include namespace fvdb::detail::ops { @@ -238,7 +239,7 @@ template struct RasterizeTopContributingGa } const ScalarType nextTransmittance = accumTransmittance * (1.0f - alpha); - if (nextTransmittance <= 1e-4f) { // this pixel is done: exclusive + if (nextTransmittance <= kTransmittanceThreshold) { // this pixel is done done = true; break; } @@ -358,7 +359,12 @@ launchRasterizeTopContributingGaussianIdsForwardKernel( // intersections const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] const torch::Tensor &tileGaussianIds, // [n_isects] - const RenderSettings &settings, // render settings + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples, const std::optional &pixelsToRender = std::nullopt, // [C, NumPixels, 2] const std::optional &activeTiles = std::nullopt, const std::optional &tilePixelMask = std::nullopt, @@ -366,7 +372,7 @@ launchRasterizeTopContributingGaussianIdsForwardKernel( const std::optional &pixelMap = std::nullopt) { const at::cuda::OptionalCUDAGuard device_guard(device_of(means2d)); - TORCH_CHECK_VALUE(settings.numDepthSamples > 0, "numDepthSamples must be greater than 0"); + TORCH_CHECK_VALUE(numDepthSamples > 0, "numDepthSamples must be greater than 0"); if (IS_PACKED) { TORCH_CHECK_VALUE(means2d.size(1) > 0, "means2d cannot be empty"); } @@ -374,19 +380,17 @@ launchRasterizeTopContributingGaussianIdsForwardKernel( // tileOffsets can be 3D (dense) or 1D (sparse) const bool tileOffsetsAreSparse = tileOffsets.dim() == 1; if (!tileOffsetsAreSparse) { - TORCH_CHECK_VALUE(tileOffsets.size(2) == - (settings.imageWidth + settings.tileSize - 1) / settings.tileSize, + TORCH_CHECK_VALUE(tileOffsets.size(2) == (imageWidth + tileSize - 1) / tileSize, "tileOffsets width must match the number of tiles in image size"); - TORCH_CHECK_VALUE(tileOffsets.size(1) == - (settings.imageHeight + settings.tileSize - 1) / settings.tileSize, + TORCH_CHECK_VALUE(tileOffsets.size(1) == (imageHeight + tileSize - 1) / tileSize, "tileOffsets height must match the number of tiles in image size"); } // Get C from tileOffsets for dense mode // For sparse mode, C is unused, only used for output sizing for dense mode const uint32_t C = tileOffsetsAreSparse ? 0 : tileOffsets.size(0); - const uint32_t tileExtentH = (settings.imageHeight + settings.tileSize - 1) / settings.tileSize; - const uint32_t tileExtentW = (settings.imageWidth + settings.tileSize - 1) / settings.tileSize; + const uint32_t tileExtentH = (imageHeight + tileSize - 1) / tileSize; + const uint32_t tileExtentW = (imageWidth + tileSize - 1) / tileSize; TORCH_CHECK_VALUE(pixelMap.has_value() == pixelsToRender.has_value(), "pixelMap and pixelsToRender must be provided together"); @@ -401,7 +405,7 @@ launchRasterizeTopContributingGaussianIdsForwardKernel( // Calculate total size and build offsets const auto &sizes = pixelsToRender.has_value() ? pixelsToRender->lsizes1() - : std::vector{C * settings.imageHeight * settings.imageWidth}; + : std::vector{C * imageHeight * imageWidth}; int64_t totalSize = 0; std::vector offsetsVec; @@ -413,10 +417,10 @@ launchRasterizeTopContributingGaussianIdsForwardKernel( } auto outIdsData = - torch::empty({totalSize, settings.numDepthSamples}, means2d.options().dtype(torch::kInt32)); + torch::empty({totalSize, numDepthSamples}, means2d.options().dtype(torch::kInt32)); auto outWeightsData = - torch::empty({totalSize, settings.numDepthSamples}, + torch::empty({totalSize, numDepthSamples}, means2d.options().dtype(c10::CppTypeToScalarType::value)); // Initialize IDs to -1 (sentinel for "no valid ID") and weights to 0 @@ -448,9 +452,9 @@ launchRasterizeTopContributingGaussianIdsForwardKernel( // - scalar_t opacity; -- 4 bytes for float32 // - vec3t conic; -- 12 bytes for float32 const uint32_t sharedMem = - settings.tileSize * settings.tileSize * + tileSize * tileSize * (sizeof(Gaussian2D) + - (sizeof(int32_t) + sizeof(ScalarType) + sizeof(uint32_t)) * settings.numDepthSamples); + (sizeof(int32_t) + sizeof(ScalarType) + sizeof(uint32_t)) * numDepthSamples); if (cudaFuncSetAttribute(rasterizeTopContributingGaussianIdsForward, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -460,30 +464,29 @@ launchRasterizeTopContributingGaussianIdsForwardKernel( " bytes), try lowering tile_size."); } - const dim3 blockDim = {settings.tileSize, settings.tileSize, 1}; + const dim3 blockDim = {tileSize, tileSize, 1}; const dim3 gridDim = activeTiles.has_value() // sparse mode ? dim3(activeTiles.value().size(0), 1, 1) : dim3(C * tileExtentH * tileExtentW, 1, 1); - auto args = - RasterizeTopContributingGaussianIdsArgs(means2d, - conics, - opacities, - backgrounds, - masks, - settings.imageWidth, - settings.imageHeight, - settings.imageOriginW, - settings.imageOriginH, - settings.tileSize, - settings.numDepthSamples, - tileOffsets, - tileGaussianIds, - outIds, - outWeights, - activeTiles, - tilePixelMask, - tilePixelCumsum, - pixelMap); + auto args = RasterizeTopContributingGaussianIdsArgs(means2d, + conics, + opacities, + backgrounds, + masks, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples, + tileOffsets, + tileGaussianIds, + outIds, + outWeights, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap); rasterizeTopContributingGaussianIdsForward<<>>(args); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -493,18 +496,54 @@ launchRasterizeTopContributingGaussianIdsForwardKernel( } // namespace +template +std::tuple +dispatch_identify_top_contributing_gaussians(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples); + +template +std::tuple +dispatch_identify_top_contributing_gaussians_sparse(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples); + template <> std::tuple -dispatchGaussianRasterizeTopContributingGaussianIds( +dispatch_identify_top_contributing_gaussians( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] const torch::Tensor &tileGaussianIds, // [n_isects] - const RenderSettings &settings // render settings - -) { + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples) { FVDB_FUNC_RANGE(); const bool isPacked = means2d.dim() == 2; @@ -524,7 +563,12 @@ dispatchGaussianRasterizeTopContributingGaussianIds( masks, tileOffsets, tileGaussianIds, - settings) + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples) : launchRasterizeTopContributingGaussianIdsForwardKernel( means2d, conics, @@ -533,14 +577,17 @@ dispatchGaussianRasterizeTopContributingGaussianIds( masks, tileOffsets, tileGaussianIds, - settings); + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples); // Get C from tileOffsets for dense mode const auto C = tileOffsets.size(0); return std::make_tuple( - ids.jdata().reshape( - {C, settings.imageHeight, settings.imageWidth, settings.numDepthSamples}), - weights.jdata().reshape( - {C, settings.imageHeight, settings.imageWidth, settings.numDepthSamples})); + ids.jdata().reshape({C, imageHeight, imageWidth, numDepthSamples}), + weights.jdata().reshape({C, imageHeight, imageWidth, numDepthSamples})); }), AT_EXPAND(AT_FLOATING_TYPES), c10::kHalf); @@ -548,21 +595,25 @@ dispatchGaussianRasterizeTopContributingGaussianIds( template <> std::tuple -dispatchGaussianRasterizeTopContributingGaussianIds( +dispatch_identify_top_contributing_gaussians( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] const torch::Tensor &tileGaussianIds, // [n_isects] - const RenderSettings &settings // render settings -) { + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples) { TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } template <> std::tuple -dispatchGaussianSparseRasterizeTopContributingGaussianIds( +dispatch_identify_top_contributing_gaussians_sparse( const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] @@ -573,8 +624,12 @@ dispatchGaussianSparseRasterizeTopContributingGaussianIds( const torch::Tensor &tilePixelMask, const torch::Tensor &tilePixelCumsum, const torch::Tensor &pixelMap, - const RenderSettings &settings // render settings -) { + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples) { FVDB_FUNC_RANGE(); const bool isPacked = means2d.dim() == 2; @@ -594,7 +649,12 @@ dispatchGaussianSparseRasterizeTopContributingGaussianIds( masks, tileOffsets, tileGaussianIds, - settings, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples, pixelsToRender, activeTiles, tilePixelMask, @@ -609,7 +669,12 @@ dispatchGaussianSparseRasterizeTopContributingGaussianIds( masks, tileOffsets, tileGaussianIds, - settings, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples, pixelsToRender, activeTiles, tilePixelMask, @@ -623,7 +688,7 @@ dispatchGaussianSparseRasterizeTopContributingGaussianIds( template <> std::tuple -dispatchGaussianSparseRasterizeTopContributingGaussianIds( +dispatch_identify_top_contributing_gaussians_sparse( const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] @@ -634,9 +699,77 @@ dispatchGaussianSparseRasterizeTopContributingGaussianIds( const torch::Tensor &tilePixelMask, const torch::Tensor &tilePixelCumsum, const torch::Tensor &pixelMap, - const RenderSettings &settings // render settings -) { + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const int numDepthSamples) { TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } +std::tuple +identify_top_contributing_gaussians(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples) { + return FVDB_DISPATCH_KERNEL_DEVICE(means2d.device(), [&]() { + return dispatch_identify_top_contributing_gaussians(means2d, + conics, + opacities, + tile_offsets, + tile_gaussian_ids, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples); + }); +} + +std::tuple +identify_top_contributing_gaussians_sparse(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples) { + return FVDB_DISPATCH_KERNEL_DEVICE(means2d.device(), [&]() { + return dispatch_identify_top_contributing_gaussians_sparse(means2d, + conics, + opacities, + tile_offsets, + tile_gaussian_ids, + pixelsToRender, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + numDepthSamples); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/IdentifyTopContributingGaussians.h b/src/fvdb/detail/ops/IdentifyTopContributingGaussians.h new file mode 100644 index 000000000..0164c478f --- /dev/null +++ b/src/fvdb/detail/ops/IdentifyTopContributingGaussians.h @@ -0,0 +1,58 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_IDENTIFYTOPCONTRIBUTINGGAUSSIANS_H +#define FVDB_DETAIL_OPS_IDENTIFYTOPCONTRIBUTINGGAUSSIANS_H + +#include + +#include + +#include +#include + +namespace fvdb { +namespace detail { +namespace ops { + +/// @brief Performs deep image rasterization to render the IDs and weighted alpha values of the +/// top-K most visible Gaussians for each pixel (dispatch wrapper). +std::tuple +identify_top_contributing_gaussians(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples); + +/// @brief Performs sparse deep image rasterization to render the IDs and weighted alpha values of +/// the top-K most visible Gaussians for each pixel. Renders only specified pixels (dispatch +/// wrapper). +std::tuple +identify_top_contributing_gaussians_sparse(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &opacities, + const torch::Tensor &tile_offsets, + const torch::Tensor &tile_gaussian_ids, + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + int numDepthSamples); + +} // namespace ops +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_OPS_IDENTIFYTOPCONTRIBUTINGGAUSSIANS_H diff --git a/src/fvdb/detail/ops/IntegrateTSDF.cu b/src/fvdb/detail/ops/IntegrateTSDF.cu index a99b9fe90..d7f2d7c2e 100644 --- a/src/fvdb/detail/ops/IntegrateTSDF.cu +++ b/src/fvdb/detail/ops/IntegrateTSDF.cu @@ -10,6 +10,7 @@ #include #include #include +#include #include @@ -92,17 +93,7 @@ unprojectDepthmapKernel(int64_t imageWidth, } } -template struct OpType { - using type = T; -}; - -template <> struct OpType { - using type = float; -}; - -template <> struct OpType { - using type = float; -}; +using fvdb::detail::OpType; template __global__ __launch_bounds__(DEFAULT_BLOCK_DIM) void diff --git a/src/fvdb/detail/ops/gsplat/GaussianTileIntersection.cu b/src/fvdb/detail/ops/IntersectGaussianTiles.cu similarity index 72% rename from src/fvdb/detail/ops/gsplat/GaussianTileIntersection.cu rename to src/fvdb/detail/ops/IntersectGaussianTiles.cu index fc622668d..adf4979f1 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianTileIntersection.cu +++ b/src/fvdb/detail/ops/IntersectGaussianTiles.cu @@ -1,10 +1,12 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include +#include +#include #include +#include +#include +#include #include #include @@ -22,15 +24,6 @@ namespace { #define NUM_THREADS 256 -#define CUB_WRAPPER(func, ...) \ - do { \ - size_t tempStorageBytes = 0; \ - func(nullptr, tempStorageBytes, __VA_ARGS__); \ - auto &cachingAllocator = *::c10::cuda::CUDACachingAllocator::get(); \ - auto tempStorage = cachingAllocator.allocate(tempStorageBytes); \ - func(tempStorage.get(), tempStorageBytes, __VA_ARGS__); \ - } while (false) - // Compute the number of 2d image tiles intersected by a set of 2D projected Gaussians. // // The input is a set of 2D circles with depths approximating the projection of 3D gaussians onto @@ -65,7 +58,7 @@ countTilesPerGaussian(const uint32_t gaussianOffset, if (radius <= 0) { outNumTilesPerGaussian[gidx] = static_cast(0); } else { - using vec2f = typename Vec2Type::type; + using vec2f = float2; const vec2f mean2d = *reinterpret_cast(means2d + gidx * 2); const OpT tileRadius = radius / static_cast(tileSize); @@ -188,7 +181,7 @@ computeGaussianTileIntersections( const OpT radius = radii[gidx]; if (radius > 0) { - using vec2f = typename Vec2Type::type; + using vec2f = float2; const vec2f mean2d = *reinterpret_cast(means2d + 2 * gidx); const OpT tileRadius = radius / static_cast(tileSize); @@ -348,7 +341,7 @@ computeTileOffsets(const uint32_t offset, } std::tuple -gaussianTileIntersectionCUDAImpl( +intersect_gaussian_tiles_cuda_impl( const torch::Tensor &means2d, // [C, N, 2] or [M, 2] const torch::Tensor &radii, // [C, N] or [M] const torch::Tensor &depths, // [C, N] or [M] @@ -570,311 +563,6 @@ gaussianTileIntersectionCUDAImpl( } } -/// @brief Implements the merge path binary search algorithm in order to find the median across two -/// sorted input key arrays -template -__device__ void -mergePath(KeyIteratorIn keys1, - size_t keys1Count, - KeyIteratorIn keys2, - size_t keys2Count, - ptrdiff_t *key1Intervals, - ptrdiff_t *key2Intervals, - int intervalIndex) { - using KeyType = typename ::cuda::std::iterator_traits::value_type; - - const size_t combinedIndex = intervalIndex * (keys1Count + keys2Count) / 2; - size_t leftTop = combinedIndex > keys1Count ? keys1Count : combinedIndex; - size_t rightTop = combinedIndex > keys1Count ? combinedIndex - keys1Count : 0; - size_t leftBottom = rightTop; - - KeyType leftKey; - KeyType rightKey; - while (true) { - ptrdiff_t offset = (leftTop - leftBottom) / 2; - ptrdiff_t leftMid = leftTop - offset; - ptrdiff_t rightMid = rightTop + offset; - - if (leftMid > keys1Count - 1 || rightMid < 1) { - leftKey = 1; - rightKey = 0; - } else { - leftKey = *(keys1 + leftMid); - rightKey = *(keys2 + rightMid - 1); - } - - if (leftKey > rightKey) { - if (rightMid > keys2Count - 1 || leftMid < 1) { - leftKey = 0; - rightKey = 1; - } else { - leftKey = *(keys1 + leftMid - 1); - rightKey = *(keys2 + rightMid); - } - - if (leftKey <= rightKey) { - *key1Intervals = leftMid; - *key2Intervals = rightMid; - break; - } else { - leftTop = leftMid - 1; - rightTop = rightMid + 1; - } - } else { - leftBottom = leftMid + 1; - } - } -} - -/// @brief Kernel wrapper for the merge path algorithm -template -__global__ void -mergePathKernel(KeyIteratorIn keys1, - size_t keys1Count, - KeyIteratorIn keys2, - size_t keys2Count, - ptrdiff_t *key1Intervals, - ptrdiff_t *key2Intervals, - size_t intervalOffset) { - const unsigned int intervalIndex = threadIdx.x + blockIdx.x * blockDim.x + intervalOffset; - mergePath(keys1, keys1Count, keys2, keys2Count, key1Intervals, key2Intervals, intervalIndex); -} - -template -void -radixSortAsync(KeyT *keysIn, - KeyT *keysOut, - ValueT *valuesIn, - ValueT *valuesOut, - NumItemsT numItems, - int beginBit, - int endBit, - cudaEvent_t *events) { - using OffsetT = int64_t; - using CountT = int64_t; - - auto hostOptions = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU); - auto itemOffsets = torch::empty({c10::cuda::device_count()}, hostOptions); - auto itemCounts = torch::empty({c10::cuda::device_count()}, hostOptions); - for (const auto deviceId: c10::irange(c10::cuda::device_count())) { - std::tie(itemOffsets.data_ptr()[deviceId], - itemCounts.data_ptr()[deviceId]) = deviceChunk(numItems, deviceId); - } - const auto *offsets = itemOffsets.const_data_ptr(); - const auto *counts = itemCounts.const_data_ptr(); - - torch::Tensor deviceMergeIntervals = - torch::empty({2 * c10::cuda::device_count()}, - torch::TensorOptions().dtype(torch::kInt64).device(torch::kPrivateUse1)); - auto mergeIntervals = deviceMergeIntervals.data_ptr(); - - // Radix sort the subset of keys assigned to each device in parallel - for (const auto deviceId: c10::irange(c10::cuda::device_count())) { - C10_CUDA_CHECK(cudaSetDevice(deviceId)); - auto stream = c10::cuda::getCurrentCUDAStream(deviceId); - // C10_CUDA_CHECK(cudaEventSynchronize(events[deviceId])); - - const KeyT *deviceKeysIn = keysIn + offsets[deviceId]; - const ValueT *deviceValuesIn = valuesIn + offsets[deviceId]; - KeyT *deviceKeysOut = keysOut + offsets[deviceId]; - ValueT *deviceValuesOut = valuesOut + offsets[deviceId]; - - C10_CUDA_CHECK(nanovdb::util::cuda::memPrefetchAsync( - deviceKeysIn, counts[deviceId] * sizeof(KeyT), deviceId, stream)); - C10_CUDA_CHECK(nanovdb::util::cuda::memPrefetchAsync( - deviceValuesIn, counts[deviceId] * sizeof(ValueT), deviceId, stream)); - C10_CUDA_CHECK(nanovdb::util::cuda::memPrefetchAsync( - deviceKeysOut, counts[deviceId] * sizeof(KeyT), deviceId, stream)); - C10_CUDA_CHECK(nanovdb::util::cuda::memPrefetchAsync( - deviceValuesOut, counts[deviceId] * sizeof(ValueT), deviceId, stream)); - - CUB_WRAPPER(cub::DeviceRadixSort::SortPairs, - deviceKeysIn, - deviceKeysOut, - deviceValuesIn, - deviceValuesOut, - counts[deviceId], - beginBit, - endBit, - stream); - C10_CUDA_CHECK(cudaEventRecord(events[deviceId], stream)); - } - - // TODO: Generalize to numbers of GPUs that aren't powers of two - // For each pair of devices, merge the local sorts by first computing the median across the two - // devices followed by merging the elements less than and greater than/equal to the median onto - // the first and second device of the pair respectively. This avoids the allocating memory for - // and gathering the values from both devices onto a single device. - const int log2DeviceCount = log2(c10::cuda::device_count()); - OffsetT *leftIntervals = mergeIntervals; - OffsetT *rightIntervals = mergeIntervals + c10::cuda::device_count(); - for (int deviceExponent = 0; deviceExponent < log2DeviceCount; ++deviceExponent) { - std::swap(keysIn, keysOut); - std::swap(valuesIn, valuesOut); - const int deviceInc = 1 << deviceExponent; - const int deviceCount = static_cast(c10::cuda::device_count()); - - for (int leftDeviceId = 0; leftDeviceId < deviceCount; leftDeviceId += 2 * deviceInc) { - const int rightDeviceId = leftDeviceId + deviceInc; - - CountT leftDeviceItemCount = 0; - for (int deviceId = leftDeviceId; deviceId < rightDeviceId; ++deviceId) - leftDeviceItemCount += counts[deviceId]; - - CountT rightDeviceItemCount = 0; - for (int deviceId = rightDeviceId; deviceId < rightDeviceId + deviceInc; ++deviceId) - rightDeviceItemCount += counts[deviceId]; - - const KeyT *leftDeviceKeysIn = keysIn + offsets[leftDeviceId]; - const ValueT *leftDeviceValuesIn = valuesIn + offsets[leftDeviceId]; - const KeyT *rightDeviceKeysIn = keysIn + offsets[leftDeviceId] + leftDeviceItemCount; - const ValueT *rightDeviceValuesIn = - valuesIn + offsets[leftDeviceId] + leftDeviceItemCount; - - // Wait on the prior sort to finish on both devices before computing the median across - // both devices - auto mergePathSubfunc = [&](int deviceId, int otherDeviceId, int intervalIndex) { - C10_CUDA_CHECK(cudaSetDevice(deviceId)); - - C10_CUDA_CHECK(cudaStreamWaitEvent(c10::cuda::getCurrentCUDAStream(deviceId), - events[otherDeviceId])); - mergePathKernel<<<1, 1, 0, c10::cuda::getCurrentCUDAStream(deviceId)>>>( - leftDeviceKeysIn, - leftDeviceItemCount, - rightDeviceKeysIn, - rightDeviceItemCount, - leftIntervals + deviceId, - rightIntervals + deviceId, - intervalIndex); - C10_CUDA_CHECK( - cudaEventRecord(events[deviceId], c10::cuda::getCurrentCUDAStream(deviceId))); - }; - mergePathSubfunc(leftDeviceId, rightDeviceId, 0); - mergePathSubfunc(rightDeviceId, leftDeviceId, 1); - } - - for (int leftDeviceId = 0; leftDeviceId < deviceCount; leftDeviceId += 2 * deviceInc) { - const int rightDeviceId = leftDeviceId + deviceInc; - - CountT leftDeviceItemCount = 0; - for (int deviceId = leftDeviceId; deviceId < rightDeviceId; ++deviceId) - leftDeviceItemCount += counts[deviceId]; - - CountT rightDeviceItemCount = 0; - for (int deviceId = rightDeviceId; deviceId < rightDeviceId + deviceInc; ++deviceId) - rightDeviceItemCount += counts[deviceId]; - - const KeyT *leftDeviceKeysIn = keysIn + offsets[leftDeviceId]; - const ValueT *leftDeviceValuesIn = valuesIn + offsets[leftDeviceId]; - const KeyT *rightDeviceKeysIn = keysIn + offsets[leftDeviceId] + leftDeviceItemCount; - const ValueT *rightDeviceValuesIn = - valuesIn + offsets[leftDeviceId] + leftDeviceItemCount; - - // Synchronize to read back the results of the merge path kernel - C10_CUDA_CHECK(cudaEventSynchronize(events[leftDeviceId])); - C10_CUDA_CHECK(cudaEventSynchronize(events[rightDeviceId])); - - // Merge the pairs less than the median to the left device - { - C10_CUDA_CHECK(cudaSetDevice(leftDeviceId)); - auto leftStream = c10::cuda::getCurrentCUDAStream(leftDeviceId); - - const KeyT *leftKeysIn = leftDeviceKeysIn + leftIntervals[leftDeviceId]; - const ValueT *leftValuesIn = leftDeviceValuesIn + leftIntervals[leftDeviceId]; - CountT leftCount = leftIntervals[rightDeviceId] - leftIntervals[leftDeviceId]; - - const KeyT *rightKeysIn = rightDeviceKeysIn + rightIntervals[leftDeviceId]; - const ValueT *rightValuesIn = rightDeviceValuesIn + rightIntervals[leftDeviceId]; - CountT rightCount = rightIntervals[rightDeviceId] - rightIntervals[leftDeviceId]; - - OffsetT outputOffset = offsets[leftDeviceId] + leftIntervals[leftDeviceId] + - rightIntervals[leftDeviceId]; - - CUB_WRAPPER(cub::DeviceMerge::MergePairs, - leftKeysIn, - leftValuesIn, - leftCount, - rightKeysIn, - rightValuesIn, - rightCount, - keysOut + outputOffset, - valuesOut + outputOffset, - {}, - leftStream); - C10_CUDA_CHECK(cudaEventRecord(events[leftDeviceId], leftStream)); - }; - - // Merge the pairs greater than/equal to the median to the right device - { - C10_CUDA_CHECK(cudaSetDevice(rightDeviceId)); - auto rightStream = c10::cuda::getCurrentCUDAStream(rightDeviceId); - - const KeyT *leftKeysIn = leftDeviceKeysIn + leftIntervals[rightDeviceId]; - const ValueT *leftValuesIn = leftDeviceValuesIn + leftIntervals[rightDeviceId]; - CountT leftCount = leftDeviceItemCount - leftIntervals[rightDeviceId]; - - const KeyT *rightKeysIn = rightDeviceKeysIn + rightIntervals[rightDeviceId]; - const ValueT *rightValuesIn = rightDeviceValuesIn + rightIntervals[rightDeviceId]; - CountT rightCount = rightDeviceItemCount - rightIntervals[rightDeviceId]; - - OffsetT outputOffset = offsets[leftDeviceId] + leftIntervals[rightDeviceId] + - rightIntervals[rightDeviceId]; - - CUB_WRAPPER(cub::DeviceMerge::MergePairs, - leftKeysIn, - leftValuesIn, - leftCount, - rightKeysIn, - rightValuesIn, - rightCount, - keysOut + outputOffset, - valuesOut + outputOffset, - {}, - rightStream); - C10_CUDA_CHECK(cudaEventRecord(events[rightDeviceId], rightStream)); - }; - } - - for (int leftDeviceId = 0; leftDeviceId < deviceCount; leftDeviceId += 2 * deviceInc) { - const int rightDeviceId = leftDeviceId + deviceInc; - - C10_CUDA_CHECK(cudaSetDevice(leftDeviceId)); - auto leftStream = c10::cuda::getCurrentCUDAStream(leftDeviceId); - C10_CUDA_CHECK(cudaStreamWaitEvent(leftStream, events[rightDeviceId])); - - C10_CUDA_CHECK(cudaSetDevice(rightDeviceId)); - auto rightStream = c10::cuda::getCurrentCUDAStream(rightDeviceId); - C10_CUDA_CHECK(cudaStreamWaitEvent(rightStream, events[leftDeviceId])); - } - } - - // There is no merging required for a single device so we simply copy the sorted result to the - // destination array (where the sort would have been merged to). - if (log2DeviceCount % 2) { - mergeStreams(); - - std::swap(keysIn, keysOut); - std::swap(valuesIn, valuesOut); - for (const auto deviceId: c10::irange(c10::cuda::device_count())) { - C10_CUDA_CHECK(cudaSetDevice(deviceId)); - auto stream = c10::cuda::getCurrentCUDAStream(deviceId); - - C10_CUDA_CHECK(cudaMemcpyAsync(keysOut + offsets[deviceId], - keysIn + offsets[deviceId], - counts[deviceId] * sizeof(KeyT), - cudaMemcpyDefault, - stream)); - C10_CUDA_CHECK(cudaMemcpyAsync(valuesOut + offsets[deviceId], - valuesIn + offsets[deviceId], - counts[deviceId] * sizeof(ValueT), - cudaMemcpyDefault, - stream)); - - C10_CUDA_CHECK(cudaEventRecord(events[deviceId], stream)); - } - } -} - namespace { [[maybe_unused]] __global__ void @@ -885,7 +573,7 @@ sleepKernel() { } // namespace std::tuple -gaussianTileIntersectionPrivateUse1Impl( +intersect_gaussian_tiles_private_use1_impl( const torch::Tensor &means2d, // [C, N, 2] or [M, 2] const torch::Tensor &radii, // [C, N] or [M] const torch::Tensor &depths, // [C, N] or [M] @@ -1075,7 +763,6 @@ gaussianTileIntersectionPrivateUse1Impl( std::vector prefetchSizes; const cudaMemLocation location = {cudaMemLocationTypeDevice, deviceId}; std::vector prefetchLocations = {location}; - std::vector prefetchLocationIndices = {0}; prefetchPointers.emplace_back(intersectionKeys.data_ptr() + intersectionsOffset); @@ -1084,6 +771,7 @@ gaussianTileIntersectionPrivateUse1Impl( intersectionsOffset); prefetchSizes.emplace_back(intersectionsCount * sizeof(int32_t)); + std::vector prefetchLocationIndices(prefetchPointers.size(), 0); sleepKernel<<<1, 1, 0, stream>>>(); C10_CUDA_CHECK(cudaMemPrefetchBatchAsync(prefetchPointers.data(), prefetchSizes.data(), @@ -1205,9 +893,33 @@ gaussianTileIntersectionPrivateUse1Impl( } // namespace +template +std::tuple +dispatch_intersect_gaussian_tiles(const torch::Tensor &means2d, // [C, N, 2] or [M, 2] + const torch::Tensor &radii, // [C, N] or [M] + const torch::Tensor &depths, // [C, N] or [M] + const at::optional &cameraIds, // NULL or [M] + const uint32_t numCameras, + const uint32_t tileSize, + const uint32_t numTilesH, + const uint32_t numTilesW); + +template +std::tuple dispatch_intersect_gaussian_tiles_sparse( + const torch::Tensor &means2d, // [C, N, 2] or [M, 2] + const torch::Tensor &radii, // [C, N] or [M] + const torch::Tensor &depths, // [C, N] or [M] + const torch::Tensor &tileMask, // [C, H, W] + const torch::Tensor &activeTiles, // [num_active_tiles] + const at::optional &cameraIds, // NULL or [M] + const uint32_t numCameras, + const uint32_t tileSize, + const uint32_t numTilesH, + const uint32_t numTilesW); + template <> std::tuple -dispatchGaussianTileIntersection( +dispatch_intersect_gaussian_tiles( const torch::Tensor &means2d, // [C, N, 2] or [M, 2] const torch::Tensor &radii, // [C, N] or [M] const torch::Tensor &depths, // [C, N] or [M] @@ -1217,21 +929,21 @@ dispatchGaussianTileIntersection( const uint32_t numTilesH, const uint32_t numTilesW) { FVDB_FUNC_RANGE(); - return gaussianTileIntersectionCUDAImpl(means2d, - radii, - depths, - cameraJIdx, - at::nullopt, - at::nullopt, - numCameras, - tileSize, - numTilesH, - numTilesW); + return intersect_gaussian_tiles_cuda_impl(means2d, + radii, + depths, + cameraJIdx, + at::nullopt, + at::nullopt, + numCameras, + tileSize, + numTilesH, + numTilesW); } template <> std::tuple -dispatchGaussianTileIntersection( +dispatch_intersect_gaussian_tiles( const torch::Tensor &means2d, // [C, N, 2] or [M, 2] const torch::Tensor &radii, // [C, N] or [M] const torch::Tensor &depths, // [C, N] or [M] @@ -1241,21 +953,21 @@ dispatchGaussianTileIntersection( const uint32_t numTilesH, const uint32_t numTilesW) { FVDB_FUNC_RANGE(); - return gaussianTileIntersectionPrivateUse1Impl(means2d, - radii, - depths, - cameraJIdx, - at::nullopt, - at::nullopt, - numCameras, - tileSize, - numTilesH, - numTilesW); + return intersect_gaussian_tiles_private_use1_impl(means2d, + radii, + depths, + cameraJIdx, + at::nullopt, + at::nullopt, + numCameras, + tileSize, + numTilesH, + numTilesW); } template <> std::tuple -dispatchGaussianTileIntersection( +dispatch_intersect_gaussian_tiles( const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &radii, // [C, N] or [nnz] const torch::Tensor &depths, // [C, N] or [nnz] @@ -1269,7 +981,7 @@ dispatchGaussianTileIntersection( template <> std::tuple -dispatchGaussianSparseTileIntersection( +dispatch_intersect_gaussian_tiles_sparse( const torch::Tensor &means2d, // [C, N, 2] or [M, 2] const torch::Tensor &radii, // [C, N] or [M] const torch::Tensor &depths, // [C, N] or [M] @@ -1281,21 +993,21 @@ dispatchGaussianSparseTileIntersection( const uint32_t numTilesH, const uint32_t numTilesW) { FVDB_FUNC_RANGE(); - return gaussianTileIntersectionCUDAImpl(means2d, - radii, - depths, - cameraJIdx, - tileMask, - activeTiles, - numCameras, - tileSize, - numTilesH, - numTilesW); + return intersect_gaussian_tiles_cuda_impl(means2d, + radii, + depths, + cameraJIdx, + tileMask, + activeTiles, + numCameras, + tileSize, + numTilesH, + numTilesW); } template <> std::tuple -dispatchGaussianSparseTileIntersection( +dispatch_intersect_gaussian_tiles_sparse( const torch::Tensor &means2d, // [C, N, 2] or [M, 2] const torch::Tensor &radii, // [C, N] or [M] const torch::Tensor &depths, // [C, N] or [M] @@ -1309,21 +1021,21 @@ dispatchGaussianSparseTileIntersection( FVDB_FUNC_RANGE(); // Sparse tile intersection is not implemented for multi-GPU (PrivateUse1) // The PrivateUse1 impl already checks for this and throws an appropriate error - return gaussianTileIntersectionPrivateUse1Impl(means2d, - radii, - depths, - cameraJIdx, - tileMask, - activeTiles, - numCameras, - tileSize, - numTilesH, - numTilesW); + return intersect_gaussian_tiles_private_use1_impl(means2d, + radii, + depths, + cameraJIdx, + tileMask, + activeTiles, + numCameras, + tileSize, + numTilesH, + numTilesW); } template <> std::tuple -dispatchGaussianSparseTileIntersection( +dispatch_intersect_gaussian_tiles_sparse( const torch::Tensor &means2d, // [C, N, 2] or [M, 2] const torch::Tensor &radii, // [C, N] or [M] const torch::Tensor &depths, // [C, N] or [M] @@ -1338,6 +1050,46 @@ dispatchGaussianSparseTileIntersection( TORCH_CHECK(false, "CPU implementation not available for sparse tile intersection"); } +std::tuple +intersect_gaussian_tiles(const torch::Tensor &means2d, // [C, N, 2] or [M, 2] + const torch::Tensor &radii, // [C, N] or [M] + const torch::Tensor &depths, // [C, N] or [M] + const at::optional &cameraIds, // NULL or [M] + const uint32_t numCameras, + const uint32_t tileSize, + const uint32_t numTilesH, + const uint32_t numTilesW) { + return FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return dispatch_intersect_gaussian_tiles( + means2d, radii, depths, cameraIds, numCameras, tileSize, numTilesH, numTilesW); + }); +} + +std::tuple +intersect_gaussian_tiles_sparse(const torch::Tensor &means2d, // [C, N, 2] or [M, 2] + const torch::Tensor &radii, // [C, N] or [M] + const torch::Tensor &depths, // [C, N] or [M] + const torch::Tensor &tileMask, // [C, H, W] + const torch::Tensor &activeTiles, // [num_active_tiles] + const at::optional &cameraIds, // NULL or [M] + const uint32_t numCameras, + const uint32_t tileSize, + const uint32_t numTilesH, + const uint32_t numTilesW) { + return FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return dispatch_intersect_gaussian_tiles_sparse(means2d, + radii, + depths, + tileMask, + activeTiles, + cameraIds, + numCameras, + tileSize, + numTilesH, + numTilesW); + }); +} + } // namespace ops } // namespace detail } // namespace fvdb diff --git a/src/fvdb/detail/ops/IntersectGaussianTiles.h b/src/fvdb/detail/ops/IntersectGaussianTiles.h new file mode 100644 index 000000000..81ed9d182 --- /dev/null +++ b/src/fvdb/detail/ops/IntersectGaussianTiles.h @@ -0,0 +1,43 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_INTERSECTGAUSSIANTILES_H +#define FVDB_DETAIL_OPS_INTERSECTGAUSSIANTILES_H + +#include + +#include + +namespace fvdb { +namespace detail { +namespace ops { + +/// @brief Compute the intersection of 2D Gaussians with image tiles for efficient rasterization +std::tuple +intersect_gaussian_tiles(const torch::Tensor &means2d, // [C, N, 2] or [M, 2] + const torch::Tensor &radii, // [C, N] or [M] + const torch::Tensor &depths, // [C, N] or [M] + const at::optional &cameraIds, // NULL or [M] + const uint32_t numCameras, + const uint32_t tileSize, + const uint32_t numTilesH, + const uint32_t numTilesW); + +/// @brief Compute the intersection of 2D Gaussians with image tiles for sparse rendering +std::tuple +intersect_gaussian_tiles_sparse(const torch::Tensor &means2d, // [C, N, 2] or [M, 2] + const torch::Tensor &radii, // [C, N] or [M] + const torch::Tensor &depths, // [C, N] or [M] + const torch::Tensor &tileMask, // [C, H, W] + const torch::Tensor &activeTiles, // [num_active_tiles] + const at::optional &cameraIds, // NULL or [M] + const uint32_t numCameras, + const uint32_t tileSize, + const uint32_t numTilesH, + const uint32_t numTilesW); + +} // namespace ops +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_OPS_INTERSECTGAUSSIANTILES_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionBackward.cu b/src/fvdb/detail/ops/ProjectGaussiansAnalyticBackward.cu similarity index 82% rename from src/fvdb/detail/ops/gsplat/GaussianProjectionBackward.cu rename to src/fvdb/detail/ops/ProjectGaussiansAnalyticBackward.cu index 104b8f58b..50bf16922 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionBackward.cu +++ b/src/fvdb/detail/ops/ProjectGaussiansAnalyticBackward.cu @@ -1,14 +1,17 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include -#include -#include +#include #include +#include +#include #include #include +#include +#include +#include +#include +#include #include @@ -129,7 +132,7 @@ projectionBackwardKernel(const int32_t offset, warpSum(dLossDPoint, warp_group_g); if (warp_group_g.thread_rank() == 0) { outDLossDMeans += gId * 3; - GSPLAT_PRAGMA_UNROLL +#pragma unroll for (uint32_t i = 0; i < 3; i++) { gpuAtomicAdd(outDLossDMeans + i, dLossDPoint[i]); } @@ -177,9 +180,9 @@ projectionBackwardKernel(const int32_t offset, warpSum(dLossDTranslation, warp_group_c); if (warp_group_c.thread_rank() == 0) { outDLossDWorldToCamMatrices += cId * 16; - GSPLAT_PRAGMA_UNROLL +#pragma unroll for (uint32_t i = 0; i < 3; i++) { // rows - GSPLAT_PRAGMA_UNROLL +#pragma unroll for (uint32_t j = 0; j < 3; j++) { // cols atomicAdd_system(outDLossDWorldToCamMatrices + i * 4 + j, dLossDRotation[i][j]); } @@ -202,9 +205,33 @@ projectionBackwardKernel(const int32_t offset, } // namespace +template +std::tuple +dispatch_project_gaussians_analytic_bwd( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const at::optional &compensations, // [N, 6] optional + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const torch::Tensor &radii, // [C, N] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &dLossDMeans2d, // [C, N, 2] + const torch::Tensor &dLossDDepths, // [C, N] + const torch::Tensor &dLossDConics, // [C, N, 3] + const at::optional &dLossDCompensations, // [C, N] optional + const bool worldToCamMatricesRequiresGrad, + const bool ortho, + at::optional outNormalizeddLossdMeans2dNormAccum, + at::optional outNormalizedMaxRadiiAccum, + at::optional outGradientStepCounts); + template <> std::tuple -dispatchGaussianProjectionBackward( +dispatch_project_gaussians_analytic_bwd( // fwd inputs const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] @@ -235,26 +262,38 @@ dispatchGaussianProjectionBackward( // const at::optional &compensations = std::nullopt; // const at::optional &dLossDCompensations = std::nullopt; - GSPLAT_DEVICE_GUARD(means); - GSPLAT_CHECK_INPUT(means); + const at::cuda::OptionalCUDAGuard device_guard(device_of(means)); + TORCH_CHECK(means.is_cuda() && means.is_contiguous(), "means must be a contiguous CUDA tensor"); if (covars.has_value()) { - GSPLAT_CHECK_INPUT(covars.value()); + TORCH_CHECK(covars.value().is_cuda() && covars.value().is_contiguous(), + "covars must be a contiguous CUDA tensor"); } else { - GSPLAT_CHECK_INPUT(quats); - GSPLAT_CHECK_INPUT(logScales); + TORCH_CHECK(quats.is_cuda() && quats.is_contiguous(), + "quats must be a contiguous CUDA tensor"); + TORCH_CHECK(logScales.is_cuda() && logScales.is_contiguous(), + "logScales must be a contiguous CUDA tensor"); } - GSPLAT_CHECK_INPUT(worldToCamMatrices); - GSPLAT_CHECK_INPUT(projectionMatrices); - GSPLAT_CHECK_INPUT(radii); - GSPLAT_CHECK_INPUT(conics); - GSPLAT_CHECK_INPUT(dLossDMeans2d); - GSPLAT_CHECK_INPUT(dLossDDepths); - GSPLAT_CHECK_INPUT(dLossDConics); + TORCH_CHECK(worldToCamMatrices.is_cuda() && worldToCamMatrices.is_contiguous(), + "worldToCamMatrices must be a contiguous CUDA tensor"); + TORCH_CHECK(projectionMatrices.is_cuda() && projectionMatrices.is_contiguous(), + "projectionMatrices must be a contiguous CUDA tensor"); + TORCH_CHECK(radii.is_cuda() && radii.is_contiguous(), "radii must be a contiguous CUDA tensor"); + TORCH_CHECK(conics.is_cuda() && conics.is_contiguous(), + "conics must be a contiguous CUDA tensor"); + TORCH_CHECK(dLossDMeans2d.is_cuda() && dLossDMeans2d.is_contiguous(), + "dLossDMeans2d must be a contiguous CUDA tensor"); + TORCH_CHECK(dLossDDepths.is_cuda() && dLossDDepths.is_contiguous(), + "dLossDDepths must be a contiguous CUDA tensor"); + TORCH_CHECK(dLossDConics.is_cuda() && dLossDConics.is_contiguous(), + "dLossDConics must be a contiguous CUDA tensor"); if (compensations.has_value()) { - GSPLAT_CHECK_INPUT(compensations.value()); + TORCH_CHECK(compensations.value().is_cuda() && compensations.value().is_contiguous(), + "compensations must be a contiguous CUDA tensor"); } if (dLossDCompensations.has_value()) { - GSPLAT_CHECK_INPUT(dLossDCompensations.value()); + TORCH_CHECK(dLossDCompensations.value().is_cuda() && + dLossDCompensations.value().is_contiguous(), + "dLossDCompensations must be a contiguous CUDA tensor"); assert(compensations.has_value()); } @@ -279,7 +318,6 @@ dispatchGaussianProjectionBackward( if (ortho) { const auto camera = OrthographicCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), kBackwardProjectionNearPlane, @@ -324,7 +362,6 @@ dispatchGaussianProjectionBackward( } else { const auto camera = PerspectiveCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), kBackwardProjectionNearPlane, @@ -375,7 +412,7 @@ dispatchGaussianProjectionBackward( template <> std::tuple -dispatchGaussianProjectionBackward( +dispatch_project_gaussians_analytic_bwd( // fwd inputs const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] @@ -432,7 +469,6 @@ dispatchGaussianProjectionBackward( if (ortho) { const auto camera = OrthographicCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), kBackwardProjectionNearPlane, @@ -479,7 +515,6 @@ dispatchGaussianProjectionBackward( } else { const auto camera = PerspectiveCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), kBackwardProjectionNearPlane, @@ -535,7 +570,7 @@ dispatchGaussianProjectionBackward( template <> std::tuple -dispatchGaussianProjectionBackward( +dispatch_project_gaussians_analytic_bwd( // fwd inputs const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] @@ -562,6 +597,53 @@ dispatchGaussianProjectionBackward( TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } +std::tuple +project_gaussians_analytic_bwd( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const at::optional &compensations, // [N, 6] optional + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const torch::Tensor &radii, // [C, N] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &dLossDMeans2d, // [C, N, 2] + const torch::Tensor &dLossDDepths, // [C, N] + const torch::Tensor &dLossDConics, // [C, N, 3] + const at::optional &dLossDCompensations, // [C, N] optional + const bool worldToCamMatricesRequiresGrad, + const bool ortho, + at::optional outNormalizeddLossdMeans2dNormAccum, + at::optional outNormalizedMaxRadiiAccum, + at::optional outGradientStepCounts) { + return FVDB_DISPATCH_KERNEL(means.device(), [&]() { + return dispatch_project_gaussians_analytic_bwd( + means, + quats, + scales, + worldToCamMatrices, + projectionMatrices, + compensations, + imageWidth, + imageHeight, + eps2d, + radii, + conics, + dLossDMeans2d, + dLossDDepths, + dLossDConics, + dLossDCompensations, + worldToCamMatricesRequiresGrad, + ortho, + outNormalizeddLossdMeans2dNormAccum, + outNormalizedMaxRadiiAccum, + outGradientStepCounts); + }); +} + } // namespace ops } // namespace detail } // namespace fvdb diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionBackward.h b/src/fvdb/detail/ops/ProjectGaussiansAnalyticBackward.h similarity index 85% rename from src/fvdb/detail/ops/gsplat/GaussianProjectionBackward.h rename to src/fvdb/detail/ops/ProjectGaussiansAnalyticBackward.h index 9f08ec787..e7ae30b11 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionBackward.h +++ b/src/fvdb/detail/ops/ProjectGaussiansAnalyticBackward.h @@ -1,10 +1,10 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONBACKWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONBACKWARD_H +#ifndef FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICBACKWARD_H +#define FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICBACKWARD_H -#include +#include #include @@ -21,8 +21,6 @@ namespace ops { /// intrinsics. It enables backpropagation through the projection step in the Gaussian Splatting /// pipeline. /// -/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) -/// /// @param[in] means 3D positions of Gaussians [N, 3] /// @param[in] quats Quaternion rotations of Gaussians [N, 4] in format (x, y, z, w) /// @param[in] scales Scale factors of Gaussians [N, 3] representing extent in each dimension @@ -49,14 +47,13 @@ namespace ops { /// /// @return std::tuple containing gradients of the loss function with respect to the input /// parameters: -/// - 3D means [N, 3] - ∂L/∂means -/// - Quaternions [N, 4] - ∂L/∂quats -/// - Scales [N, 3] - ∂L/∂scales -/// - View matrices [C, 4, 4] - ∂L/∂viewmats -/// - Camera intrinsics [C, 3, 3] - ∂L/∂Ks -template +/// - 3D means [N, 3] - dL/dmeans +/// - Quaternions [N, 4] - dL/dquats +/// - Scales [N, 3] - dL/dscales +/// - View matrices [C, 4, 4] - dL/dviewmats +/// - Camera intrinsics [C, 3, 3] - dL/dKs std::tuple -dispatchGaussianProjectionBackward( +project_gaussians_analytic_bwd( const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] const torch::Tensor &scales, // [N, 3] @@ -82,4 +79,4 @@ dispatchGaussianProjectionBackward( } // namespace detail } // namespace fvdb -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONBACKWARD_H +#endif // FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICBACKWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionForward.cu b/src/fvdb/detail/ops/ProjectGaussiansAnalyticForward.cu similarity index 79% rename from src/fvdb/detail/ops/gsplat/GaussianProjectionForward.cu rename to src/fvdb/detail/ops/ProjectGaussiansAnalyticForward.cu index 702df13a7..3376684b6 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionForward.cu +++ b/src/fvdb/detail/ops/ProjectGaussiansAnalyticForward.cu @@ -1,13 +1,17 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include +#include #include #include +#include +#include #include #include +#include +#include +#include +#include #include @@ -172,9 +176,25 @@ projectionForwardKernel(int64_t offset, } // namespace +template +std::tuple +dispatch_project_gaussians_analytic_fwd(const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations, + const bool ortho); + template <> std::tuple -dispatchGaussianProjectionForward( +dispatch_project_gaussians_analytic_fwd( const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] const torch::Tensor &logScales, // [N, 3] @@ -224,7 +244,6 @@ dispatchGaussianProjectionForward( if (ortho) { const auto camera = OrthographicCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), nearPlane, @@ -249,7 +268,6 @@ dispatchGaussianProjectionForward( } else { const auto camera = PerspectiveCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), nearPlane, @@ -277,7 +295,7 @@ dispatchGaussianProjectionForward( template <> std::tuple -dispatchGaussianProjectionForward( +dispatch_project_gaussians_analytic_fwd( const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] const torch::Tensor &logScales, // [N, 3] @@ -331,7 +349,6 @@ dispatchGaussianProjectionForward( if (ortho) { const auto camera = OrthographicCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), nearPlane, @@ -356,7 +373,6 @@ dispatchGaussianProjectionForward( } else { const auto camera = PerspectiveCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), nearPlane, @@ -387,22 +403,54 @@ dispatchGaussianProjectionForward( template <> std::tuple -dispatchGaussianProjectionForward(const torch::Tensor &means, // [N, 3] - const torch::Tensor &quats, // [N, 4] - const torch::Tensor &logScales, // [N, 3] - const torch::Tensor &worldToCamMatrices, // [C, 4, 4] - const torch::Tensor &projectionMatrices, // [C, 3, 3] - const int64_t imageWidth, - const int64_t imageHeight, - const float eps2d, - const float nearPlane, - const float farPlane, - const float radiusClip, - const bool calcCompensations, - const bool ortho) { +dispatch_project_gaussians_analytic_fwd( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float radiusClip, + const bool calcCompensations, + const bool ortho) { TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } +std::tuple +project_gaussians_analytic_fwd(const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations, + const bool ortho) { + return FVDB_DISPATCH_KERNEL(means.device(), [&]() { + return dispatch_project_gaussians_analytic_fwd(means, + quats, + scales, + worldToCamMatrices, + projectionMatrices, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + calcCompensations, + ortho); + }); +} + } // namespace ops } // namespace detail } // namespace fvdb diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionForward.h b/src/fvdb/detail/ops/ProjectGaussiansAnalyticForward.h similarity index 66% rename from src/fvdb/detail/ops/gsplat/GaussianProjectionForward.h rename to src/fvdb/detail/ops/ProjectGaussiansAnalyticForward.h index 5b76bb6d7..b5143ffb1 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionForward.h +++ b/src/fvdb/detail/ops/ProjectGaussiansAnalyticForward.h @@ -1,10 +1,10 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONFORWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONFORWARD_H +#ifndef FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICFORWARD_H +#define FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICFORWARD_H -#include +#include #include @@ -27,8 +27,6 @@ namespace ops { /// too small) are set to zero, but the other output values of discarded Gaussians are uninitialized /// (undefined). /// -/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) -/// /// @param[in] means 3D positions of Gaussians [N, 3] where N is number of Gaussians /// @param[in] quats Quaternion rotations of Gaussians [N, 4] in format (x, y, z, w) /// @param[in] scales Scale factors of Gaussians [N, 3] representing extent in each dimension @@ -49,24 +47,23 @@ namespace ops { /// - Covariance matrices in conic form [C, N, 3] representing (a, b, c) in ax² + 2bxy + cy² /// - Radii of 2D Gaussians [C, N] /// - Compensation factors [C, N] (if calc_compensations is true, otherwise empty tensor) -template std::tuple -dispatchGaussianProjectionForward(const torch::Tensor &means, // [N, 3] - const torch::Tensor &quats, // [N, 4] - const torch::Tensor &scales, // [N, 3] - const torch::Tensor &worldToCamMatrices, // [C, 4, 4] - const torch::Tensor &projectionMatrices, // [C, 3, 3] - const int64_t imageWidth, - const int64_t imageHeight, - const float eps2d, - const float nearPlane, - const float farPlane, - const float minRadius2d, - const bool calcCompensations, - const bool ortho); +project_gaussians_analytic_fwd(const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations, + const bool ortho); } // namespace ops } // namespace detail } // namespace fvdb -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONFORWARD_H +#endif // FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICFORWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.cu b/src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedBackward.cu similarity index 71% rename from src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.cu rename to src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedBackward.cu index 4026f130a..8b142d6f8 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.cu +++ b/src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedBackward.cu @@ -1,13 +1,16 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include -#include -#include +#include #include +#include +#include #include +#include +#include +#include +#include +#include #include #include @@ -119,7 +122,7 @@ jaggedProjectionBackwardKernel( warpSum(dLossDPoint, warp_group_g); if (warp_group_g.thread_rank() == 0) { outDLossDMeans += gId * 3; - GSPLAT_PRAGMA_UNROLL +#pragma unroll for (uint32_t i = 0; i < 3; i++) { gpuAtomicAdd(outDLossDMeans + i, dLossDPoint[i]); } @@ -167,9 +170,9 @@ jaggedProjectionBackwardKernel( warpSum(dLossDTranslation, warp_group_c); if (warp_group_c.thread_rank() == 0) { outDLossDWorldToCamMatrices += cId * 16; - GSPLAT_PRAGMA_UNROLL +#pragma unroll for (uint32_t i = 0; i < 3; i++) { // rows - GSPLAT_PRAGMA_UNROLL +#pragma unroll for (uint32_t j = 0; j < 3; j++) { // cols gpuAtomicAdd(outDLossDWorldToCamMatrices + i * 4 + j, dLossDRotation[i][j]); } @@ -179,9 +182,29 @@ jaggedProjectionBackwardKernel( } } +template +std::tuple +dispatch_project_gaussians_analytic_jagged_bwd(const torch::Tensor &gSizes, // [B] gaussian sizes + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] optional + const torch::Tensor &scales, // [N, 3] optional + const torch::Tensor &cSizes, // [B] camera sizes + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const torch::Tensor &radii, // [N] + const torch::Tensor &conics, // [N, 3] + const torch::Tensor &dLossDMeans2d, // [N, 2] + const torch::Tensor &dLossDDepths, // [N] + const torch::Tensor &dLossDConics, // [N, 3] + const bool worldToCamMatricesRequiresGrad, + const bool ortho); + template <> std::tuple -dispatchGaussianProjectionJaggedBackward( +dispatch_project_gaussians_analytic_jagged_bwd( const torch::Tensor &gSizes, // [B] gaussian sizes const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] optional @@ -207,28 +230,42 @@ dispatchGaussianProjectionJaggedBackward( const at::optional &compensations = std::nullopt; const at::optional &dLossDCompensations = std::nullopt; - GSPLAT_DEVICE_GUARD(means); - GSPLAT_CHECK_INPUT(gSizes); - GSPLAT_CHECK_INPUT(means); + const at::cuda::OptionalCUDAGuard device_guard(device_of(means)); + TORCH_CHECK(gSizes.is_cuda() && gSizes.is_contiguous(), + "gSizes must be a contiguous CUDA tensor"); + TORCH_CHECK(means.is_cuda() && means.is_contiguous(), "means must be a contiguous CUDA tensor"); if (covars.has_value()) { - GSPLAT_CHECK_INPUT(covars.value()); + TORCH_CHECK(covars.value().is_cuda() && covars.value().is_contiguous(), + "covars must be a contiguous CUDA tensor"); } else { - GSPLAT_CHECK_INPUT(quats); - GSPLAT_CHECK_INPUT(scales); + TORCH_CHECK(quats.is_cuda() && quats.is_contiguous(), + "quats must be a contiguous CUDA tensor"); + TORCH_CHECK(scales.is_cuda() && scales.is_contiguous(), + "scales must be a contiguous CUDA tensor"); } - GSPLAT_CHECK_INPUT(cSizes); - GSPLAT_CHECK_INPUT(worldToCamMatrices); - GSPLAT_CHECK_INPUT(projectionMatrices); - GSPLAT_CHECK_INPUT(radii); - GSPLAT_CHECK_INPUT(conics); - GSPLAT_CHECK_INPUT(dLossDMeans2d); - GSPLAT_CHECK_INPUT(dLossDDepths); - GSPLAT_CHECK_INPUT(dLossDConics); + TORCH_CHECK(cSizes.is_cuda() && cSizes.is_contiguous(), + "cSizes must be a contiguous CUDA tensor"); + TORCH_CHECK(worldToCamMatrices.is_cuda() && worldToCamMatrices.is_contiguous(), + "worldToCamMatrices must be a contiguous CUDA tensor"); + TORCH_CHECK(projectionMatrices.is_cuda() && projectionMatrices.is_contiguous(), + "projectionMatrices must be a contiguous CUDA tensor"); + TORCH_CHECK(radii.is_cuda() && radii.is_contiguous(), "radii must be a contiguous CUDA tensor"); + TORCH_CHECK(conics.is_cuda() && conics.is_contiguous(), + "conics must be a contiguous CUDA tensor"); + TORCH_CHECK(dLossDMeans2d.is_cuda() && dLossDMeans2d.is_contiguous(), + "dLossDMeans2d must be a contiguous CUDA tensor"); + TORCH_CHECK(dLossDDepths.is_cuda() && dLossDDepths.is_contiguous(), + "dLossDDepths must be a contiguous CUDA tensor"); + TORCH_CHECK(dLossDConics.is_cuda() && dLossDConics.is_contiguous(), + "dLossDConics must be a contiguous CUDA tensor"); if (compensations.has_value()) { - GSPLAT_CHECK_INPUT(compensations.value()); + TORCH_CHECK(compensations.value().is_cuda() && compensations.value().is_contiguous(), + "compensations must be a contiguous CUDA tensor"); } if (dLossDCompensations.has_value()) { - GSPLAT_CHECK_INPUT(dLossDCompensations.value()); + TORCH_CHECK(dLossDCompensations.value().is_cuda() && + dLossDCompensations.value().is_contiguous(), + "dLossDCompensations must be a contiguous CUDA tensor"); assert(compensations.has_value()); } @@ -259,7 +296,6 @@ dispatchGaussianProjectionJaggedBackward( if (ortho) { const auto camera = OrthographicCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), kBackwardProjectionNearPlane, @@ -296,7 +332,6 @@ dispatchGaussianProjectionJaggedBackward( } else { const auto camera = PerspectiveCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), kBackwardProjectionNearPlane, @@ -343,7 +378,7 @@ dispatchGaussianProjectionJaggedBackward( template <> std::tuple -dispatchGaussianProjectionJaggedBackward( +dispatch_project_gaussians_analytic_jagged_bwd( const torch::Tensor &gSizes, // [B] gaussian sizes const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] optional @@ -364,6 +399,46 @@ dispatchGaussianProjectionJaggedBackward( TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } +std::tuple +project_gaussians_analytic_jagged_bwd(const torch::Tensor &gSizes, // [B] gaussian sizes + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] optional + const torch::Tensor &scales, // [N, 3] optional + const torch::Tensor &cSizes, // [B] camera sizes + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const torch::Tensor &radii, // [N] + const torch::Tensor &conics, // [N, 3] + const torch::Tensor &dLossDMeans2d, // [N, 2] + const torch::Tensor &dLossDDepths, // [N] + const torch::Tensor &dLossDConics, // [N, 3] + const bool worldToCamMatricesRequiresGrad, + const bool ortho) { + return FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() { + return dispatch_project_gaussians_analytic_jagged_bwd( + gSizes, + means, + quats, + scales, + cSizes, + worldToCamMatrices, + projectionMatrices, + imageWidth, + imageHeight, + eps2d, + radii, + conics, + dLossDMeans2d, + dLossDDepths, + dLossDConics, + worldToCamMatricesRequiresGrad, + ortho); + }); +} + } // namespace ops } // namespace detail } // namespace fvdb diff --git a/src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedBackward.h b/src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedBackward.h new file mode 100644 index 000000000..80cf57684 --- /dev/null +++ b/src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedBackward.h @@ -0,0 +1,73 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICJAGGEDBACKWARD_H +#define FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICJAGGEDBACKWARD_H + +#include + +#include + +#include + +namespace fvdb { +namespace detail { +namespace ops { + +/// @brief Calculate gradients for the jagged 3D to 2D Gaussian projection (backward pass) +/// +/// This function computes the gradients of the 3D to 2D Gaussian projection with respect to +/// the input parameters when using jagged tensors for batch processing. It enables backpropagation +/// through the projection step in the Gaussian Splatting pipeline for scenes with variable +/// numbers of objects and cameras per batch. +/// +/// @param[in] gSizes Batch sizes for Gaussians [B] +/// @param[in] means 3D positions of Gaussians [M, 3] +/// @param[in] quats Quaternion rotations of Gaussians [M, 4] in format (x, y, z, w) +/// @param[in] scales Scale factors of Gaussians [M, 3] representing extent in each dimension +/// @param[in] cSizes Batch sizes for cameras [B] +/// @param[in] worldToCamMatrices Camera view matrices [BC, 4, 4] +/// @param[in] projectionMatrices Camera intrinsic matrices [BC, 3, 3] +/// @param[in] imageWidth Width of the output image in pixels +/// @param[in] imageHeight Height of the output image in pixels +/// @param[in] eps2d 2D projection epsilon for numerical stability +/// @param[in] radii Output radii from forward pass [M] +/// @param[in] conics Output conics from forward pass [M, 3] +/// @param[out] dLossDMeans2d Gradients with respect to projected 2D means [M, 2] +/// @param[out] dLossDDepths Gradients with respect to depths [M] +/// @param[out] dLossDConics Gradients with respect to conics [M, 3] +/// @param[in] worldToCamMatricesRequiresGrad Whether viewmats requires gradient +/// @param[in] ortho Whether orthographic projection was used in forward pass +/// +/// @return std::tuple containing gradients of the loss function with respect to the input +/// parameters: +/// - 3D means [M, 3] - dL/dmeans +/// - Quaternions [M, 4] - dL/dquats +/// - Scales [M, 3] - dL/dscales +/// - View matrices [BC, 4, 4] - dL/dviewmats (if viewmats_requires_grad is true, otherwise +/// empty tensor) +/// - Camera intrinsics [BC, 3, 3] - dL/dKs +std::tuple +project_gaussians_analytic_jagged_bwd(const torch::Tensor &gSizes, // [B] gaussian sizes + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] optional + const torch::Tensor &scales, // [N, 3] optional + const torch::Tensor &cSizes, // [B] camera sizes + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const torch::Tensor &radii, // [N] + const torch::Tensor &conics, // [N, 3] + const torch::Tensor &dLossDMeans2d, // [N, 2] + const torch::Tensor &dLossDDepths, // [N] + const torch::Tensor &dLossDConics, // [N, 3] + const bool worldToCamMatricesRequiresGrad, + const bool ortho); + +} // namespace ops +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICJAGGEDBACKWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedForward.cu b/src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedForward.cu similarity index 66% rename from src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedForward.cu rename to src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedForward.cu index 3026b3b11..a48351d5c 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedForward.cu +++ b/src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedForward.cu @@ -1,13 +1,16 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include -#include -#include +#include #include +#include +#include #include +#include +#include +#include +#include +#include #include #include @@ -115,9 +118,26 @@ jaggedProjectionForwardKernel(const uint32_t B, } } +template +std::tuple +dispatch_project_gaussians_analytic_jagged_fwd(const torch::Tensor &gSizes, // [B] gaussian sizes + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] optional + const torch::Tensor &scales, // [N, 3] optional + const torch::Tensor &cSizes, // [B] camera sizes + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool ortho); + template <> std::tuple -dispatchGaussianProjectionJaggedForward( +dispatch_project_gaussians_analytic_jagged_fwd( const torch::Tensor &gSizes, // [B] gaussian sizes const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] optional @@ -137,18 +157,25 @@ dispatchGaussianProjectionJaggedForward( const at::optional &covars = std::nullopt; constexpr bool calc_compensations = false; - GSPLAT_DEVICE_GUARD(means); - GSPLAT_CHECK_INPUT(gSizes); - GSPLAT_CHECK_INPUT(means); + const at::cuda::OptionalCUDAGuard device_guard(device_of(means)); + TORCH_CHECK(gSizes.is_cuda() && gSizes.is_contiguous(), + "gSizes must be a contiguous CUDA tensor"); + TORCH_CHECK(means.is_cuda() && means.is_contiguous(), "means must be a contiguous CUDA tensor"); if (covars.has_value()) { - GSPLAT_CHECK_INPUT(covars.value()); + TORCH_CHECK(covars.value().is_cuda() && covars.value().is_contiguous(), + "covars must be a contiguous CUDA tensor"); } else { - GSPLAT_CHECK_INPUT(quats); - GSPLAT_CHECK_INPUT(scales); + TORCH_CHECK(quats.is_cuda() && quats.is_contiguous(), + "quats must be a contiguous CUDA tensor"); + TORCH_CHECK(scales.is_cuda() && scales.is_contiguous(), + "scales must be a contiguous CUDA tensor"); } - GSPLAT_CHECK_INPUT(cSizes); - GSPLAT_CHECK_INPUT(worldToCamMatrices); - GSPLAT_CHECK_INPUT(projectionMatrices); + TORCH_CHECK(cSizes.is_cuda() && cSizes.is_contiguous(), + "cSizes must be a contiguous CUDA tensor"); + TORCH_CHECK(worldToCamMatrices.is_cuda() && worldToCamMatrices.is_contiguous(), + "worldToCamMatrices must be a contiguous CUDA tensor"); + TORCH_CHECK(projectionMatrices.is_cuda() && projectionMatrices.is_contiguous(), + "projectionMatrices must be a contiguous CUDA tensor"); // TODO: use inclusive sum const uint32_t B = gSizes.size(0); @@ -175,7 +202,6 @@ dispatchGaussianProjectionJaggedForward( if (ortho) { const auto camera = OrthographicCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), nearPlane, @@ -204,7 +230,6 @@ dispatchGaussianProjectionJaggedForward( } else { const auto camera = PerspectiveCamera{projectionMatrices, worldToCamMatrices, - static_cast(C), static_cast(imageWidth), static_cast(imageHeight), nearPlane, @@ -239,7 +264,7 @@ dispatchGaussianProjectionJaggedForward( template <> std::tuple -dispatchGaussianProjectionJaggedForward( +dispatch_project_gaussians_analytic_jagged_fwd( const torch::Tensor &gSizes, // [B] gaussian sizes const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] optional @@ -257,6 +282,39 @@ dispatchGaussianProjectionJaggedForward( TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } +std::tuple +project_gaussians_analytic_jagged_fwd(const torch::Tensor &gSizes, // [B] gaussian sizes + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] optional + const torch::Tensor &scales, // [N, 3] optional + const torch::Tensor &cSizes, // [B] camera sizes + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool ortho) { + return FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() { + return dispatch_project_gaussians_analytic_jagged_fwd(gSizes, + means, + quats, + scales, + cSizes, + worldToCamMatrices, + projectionMatrices, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + ortho); + }); +} + } // namespace ops } // namespace detail } // namespace fvdb diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedForward.h b/src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedForward.h similarity index 60% rename from src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedForward.h rename to src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedForward.h index 33ed9ecdb..a42f79be9 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedForward.h +++ b/src/fvdb/detail/ops/ProjectGaussiansAnalyticJaggedForward.h @@ -1,10 +1,10 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONJAGGEDFORWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONJAGGEDFORWARD_H +#ifndef FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICJAGGEDFORWARD_H +#define FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICJAGGEDFORWARD_H -#include +#include #include @@ -24,8 +24,6 @@ namespace ops { /// too small) are set to zero, but the other output values of discarded Gaussians are uninitialized /// (undefined). /// -/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) -/// /// @param[in] gSizes Batch sizes for Gaussians [B] /// @param[in] means 3D positions of Gaussians [M, 3] /// @param[in] quats Quaternion rotations of Gaussians [M, 4] in format (x, y, z, w) @@ -44,28 +42,27 @@ namespace ops { /// @return std::tuple containing: /// - 2D projected Gaussian centers [M, 2] /// - Depths of Gaussians [M] -/// - Covariance matrices in conic form [M, 3] representing (a, b, c) in ax² + 2bxy + cy² +/// - Covariance matrices in conic form [M, 3] representing (a, b, c) in ax^2 + 2bxy + cy^2 /// - Radii of 2D Gaussians [M] /// - Flattened camera indices [M] indicating which camera each projection corresponds to -template std::tuple -dispatchGaussianProjectionJaggedForward(const torch::Tensor &gSizes, // [B] gaussian sizes - const torch::Tensor &means, // [N, 3] - const torch::Tensor &quats, // [N, 4] optional - const torch::Tensor &scales, // [N, 3] optional - const torch::Tensor &cSizes, // [B] camera sizes - const torch::Tensor &worldToCamMatrices, // [C, 4, 4] - const torch::Tensor &projectionMatrices, // [C, 3, 3] - const uint32_t imageWidth, - const uint32_t imageHeight, - const float eps2d, - const float nearPlane, - const float farPlane, - const float minRadius2d, - const bool ortho); +project_gaussians_analytic_jagged_fwd(const torch::Tensor &gSizes, // [B] gaussian sizes + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] optional + const torch::Tensor &scales, // [N, 3] optional + const torch::Tensor &cSizes, // [B] camera sizes + const torch::Tensor &worldToCamMatrices, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool ortho); } // namespace ops } // namespace detail } // namespace fvdb -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONJAGGEDFORWARD_H +#endif // FVDB_DETAIL_OPS_PROJECTGAUSSIANSANALYTICJAGGEDFORWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.cu b/src/fvdb/detail/ops/ProjectGaussiansUtForward.cu similarity index 89% rename from src/fvdb/detail/ops/gsplat/GaussianProjectionUT.cu rename to src/fvdb/detail/ops/ProjectGaussiansUtForward.cu index c48c42563..606d67487 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.cu +++ b/src/fvdb/detail/ops/ProjectGaussiansUtForward.cu @@ -2,14 +2,17 @@ // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include -#include -#include +#include #include #include +#include +#include #include +#include +#include +#include +#include +#include #include #include @@ -450,6 +453,27 @@ projectionForwardUTKernel(int64_t offset, } } +template +std::tuple +dispatchGaussianProjectionForwardUT( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4] + const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const RollingShutterType rollingShutterType, + const UTParams &utParams, + const DistortionModel cameraModel, + const torch::Tensor &distortionCoeffs, // [C, 12] for OPENCV_*, or [C, 0] for PINHOLE/ORTHO + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations); + /// @brief CUDA specialization for UT forward projection dispatch. /// /// Performs host-side validation and launches `projectionForwardUTKernel`. @@ -672,4 +696,44 @@ dispatchGaussianProjectionForwardUT( "GaussianProjectionForwardUT not implemented for this device type"); } +std::tuple +project_gaussians_ut_fwd( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4] + const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4] + const torch::Tensor &projectionMatrices, // [C, 3, 3] + const RollingShutterType rollingShutterType, + const UTParams &utParams, + const DistortionModel cameraModel, + const torch::Tensor &distortionCoeffs, // [C, 12] for OPENCV_*, or [C, 0] for PINHOLE/ORTHO + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations) { + return FVDB_DISPATCH_KERNEL(means.device(), [&]() { + return dispatchGaussianProjectionForwardUT(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + rollingShutterType, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + calcCompensations); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.h b/src/fvdb/detail/ops/ProjectGaussiansUtForward.h similarity index 91% rename from src/fvdb/detail/ops/gsplat/GaussianProjectionUT.h rename to src/fvdb/detail/ops/ProjectGaussiansUtForward.h index d5d114d0f..f69467bdb 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionUT.h +++ b/src/fvdb/detail/ops/ProjectGaussiansUtForward.h @@ -1,12 +1,11 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONUT_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONUT_H +#ifndef FVDB_DETAIL_OPS_PROJECTGAUSSIANSUTFORWARD_H +#define FVDB_DETAIL_OPS_PROJECTGAUSSIANSUTFORWARD_H -#include +#include -#include #include #include @@ -43,8 +42,6 @@ namespace ops { /// 5. **Cull** gaussians that are out-of-range (near/far) or too small (min radius), and write /// outputs for the survivors. /// -/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) -/// /// @param[in] means 3D positions of Gaussians [N, 3] where N is number of Gaussians /// @param[in] quats Quaternion rotations of Gaussians [N, 4] in format (w, x, y, z) /// @param[in] logScales Log-scale factors of Gaussians [N, 3] (natural log), representing extent in @@ -76,11 +73,11 @@ namespace ops { /// - Radii of 2D Gaussians [C, N] /// - 2D projected Gaussian centers [C, N, 2] /// - Depths of Gaussians [C, N] -/// - Covariance matrices in conic form [C, N, 3] representing (a, b, c) in ax² + 2bxy + cy² +/// - Covariance matrices in conic form [C, N, 3] representing (a, b, c) in ax^2 + 2bxy + +/// cy^2 /// - Compensation factors [C, N] (if calc_compensations is true, otherwise empty tensor) -template std::tuple -dispatchGaussianProjectionForwardUT( +project_gaussians_ut_fwd( const torch::Tensor &means, // [N, 3] const torch::Tensor &quats, // [N, 4] const torch::Tensor &logScales, // [N, 3] @@ -103,4 +100,4 @@ dispatchGaussianProjectionForwardUT( } // namespace detail } // namespace fvdb -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONUT_H +#endif // FVDB_DETAIL_OPS_PROJECTGAUSSIANSUTFORWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu b/src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansBackward.cu similarity index 90% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu rename to src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansBackward.cu index 3ea5ad9c5..fd86ecd5d 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.cu +++ b/src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansBackward.cu @@ -1,15 +1,16 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include -#include -#include -#include +#include #include #include +#include +#include +#include #include +#include +#include +#include #include #include @@ -1324,9 +1325,53 @@ callRasterizeBackwardPrivateUse1( } // namespace +template +std::tuple +dispatch_rasterize_screen_space_gaussians_bwd(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const RenderWindow2D &renderWindow, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &renderedAlphas, + const torch::Tensor &lastIds, + const torch::Tensor &dLossDRenderedFeatures, + const torch::Tensor &dLossDRenderedAlphas, + const bool absGrad, + const int64_t numSharedChannelsOverride, + const at::optional &backgrounds, + const at::optional &masks); + +template +std::tuple +dispatch_rasterize_screen_space_gaussians_sparse_bwd( + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const RenderWindow2D &renderWindow, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const fvdb::JaggedTensor &renderedAlphas, + const fvdb::JaggedTensor &lastIds, + const fvdb::JaggedTensor &dLossDRenderedFeatures, + const fvdb::JaggedTensor &dLossDRenderedAlphas, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + const bool absGrad, + const int64_t numSharedChannelsOverride, + const at::optional &backgrounds, + const at::optional &masks); + template <> std::tuple -dispatchGaussianRasterizeBackward( +dispatch_rasterize_screen_space_gaussians_bwd( const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] const torch::Tensor &features, // [C, N, 3] @@ -1417,7 +1462,7 @@ dispatchGaussianRasterizeBackward( template <> std::tuple -dispatchGaussianRasterizeBackward( +dispatch_rasterize_screen_space_gaussians_bwd( const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] const torch::Tensor &features, // [C, N, 3] @@ -1505,7 +1550,7 @@ dispatchGaussianRasterizeBackward( template <> std::tuple -dispatchGaussianRasterizeBackward( +dispatch_rasterize_screen_space_gaussians_bwd( const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] const torch::Tensor &features, // [C, N, 3] @@ -1527,7 +1572,7 @@ dispatchGaussianRasterizeBackward( template <> std::tuple -dispatchGaussianSparseRasterizeBackward( +dispatch_rasterize_screen_space_gaussians_sparse_bwd( const fvdb::JaggedTensor &pixelsToRender, // [C, NumPixels, 2] const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] @@ -1630,7 +1675,7 @@ dispatchGaussianSparseRasterizeBackward( template <> std::tuple -dispatchGaussianSparseRasterizeBackward( +dispatch_rasterize_screen_space_gaussians_sparse_bwd( const fvdb::JaggedTensor &pixelsToRender, // [C, NumPixels, 2] const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] @@ -1657,7 +1702,7 @@ dispatchGaussianSparseRasterizeBackward( template <> std::tuple -dispatchGaussianSparseRasterizeBackward( +dispatch_rasterize_screen_space_gaussians_sparse_bwd( const fvdb::JaggedTensor &pixelsToRender, // [C, NumPixels, 2] const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] @@ -1682,4 +1727,97 @@ dispatchGaussianSparseRasterizeBackward( TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } +std::tuple +rasterize_screen_space_gaussians_bwd(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &renderedAlphas, + const torch::Tensor &lastIds, + const torch::Tensor &dLossDRenderedFeatures, + const torch::Tensor &dLossDRenderedAlphas, + const bool absGrad, + const int64_t numSharedChannelsOverride, + const at::optional &backgrounds, + const at::optional &masks) { + const RenderWindow2D renderWindow{imageWidth, imageHeight, imageOriginW, imageOriginH}; + return FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return dispatch_rasterize_screen_space_gaussians_bwd(means2d, + conics, + features, + opacities, + renderWindow, + tileSize, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastIds, + dLossDRenderedFeatures, + dLossDRenderedAlphas, + absGrad, + numSharedChannelsOverride, + backgrounds, + masks); + }); +} + +std::tuple +rasterize_screen_space_gaussians_sparse_bwd(const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const fvdb::JaggedTensor &renderedAlphas, + const fvdb::JaggedTensor &lastIds, + const fvdb::JaggedTensor &dLossDRenderedFeatures, + const fvdb::JaggedTensor &dLossDRenderedAlphas, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + const bool absGrad, + const int64_t numSharedChannelsOverride, + const at::optional &backgrounds, + const at::optional &masks) { + const RenderWindow2D renderWindow{imageWidth, imageHeight, imageOriginW, imageOriginH}; + return FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return dispatch_rasterize_screen_space_gaussians_sparse_bwd( + pixelsToRender, + means2d, + conics, + features, + opacities, + renderWindow, + tileSize, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastIds, + dLossDRenderedFeatures, + dLossDRenderedAlphas, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + absGrad, + numSharedChannelsOverride, + backgrounds, + masks); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansBackward.h b/src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansBackward.h new file mode 100644 index 000000000..20dd59553 --- /dev/null +++ b/src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansBackward.h @@ -0,0 +1,153 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_RASTERIZESCREENSPACEGAUSSIANSBACKWARD_H +#define FVDB_DETAIL_OPS_RASTERIZESCREENSPACEGAUSSIANSBACKWARD_H + +#include + +#include + +#include + +#include + +namespace fvdb { +namespace detail { +namespace ops { + +/// @brief Calculate gradients for the Gaussian rasterization process (backward pass) +/// +/// This function computes the gradients of the Gaussian splatting rendering with respect to +/// its input parameters: 2D projected Gaussian means, conics, features/colors, and opacities. +/// It is used during backpropagation to update the Gaussian parameters during training. +/// +/// @param[in] means2d 2D projected Gaussian centers [C, N, 2] +/// @param[in] conics Gaussian covariance matrices in conic form [C, N, 3] representing (a, b, c) in +/// ax² + 2bxy + cy² +/// @param[in] features Feature / color values of Gaussians [C, N, D] +/// @param[in] opacities Opacity values for each Gaussian [N] +/// @param[in] imageWidth Width of the render window in pixels +/// @param[in] imageHeight Height of the render window in pixels +/// @param[in] imageOriginW Horizontal origin of the render window +/// @param[in] imageOriginH Vertical origin of the render window +/// @param[in] tileSize Size of tiles used for rasterization optimization +/// @param[in] tileOffsets Offsets for tiles [C, tile_height, tile_width] +/// @param[in] tileGaussianIds Flattened Gaussian IDs for tile intersection [n_isects] +/// @param[in] renderedAlphas Alpha values from forward pass [C, render_height, render_width, 1] +/// @param[in] lastIds Last Gaussian IDs per pixel from forward pass [C, render_height, +/// render_width] +/// @param[out] dLossDRenderedFeatures Gradients of loss with respect to rendered features [C, +/// render_height, render_width, D] +/// @param[out] dLossDRenderedAlphas Gradients of loss with respect to rendered alphas [C, +/// render_height, render_width, 1] +/// @param[in] absGrad Whether to use absolute gradients +/// @param[in] numSharedChannelsOverride Override for number of shared memory channels (-1 means +/// auto-select) +/// @param[in] backgrounds Optional background color per camera [C, D]. If provided, background +/// colors affect gradient computation for transparent pixels. If not provided, background is +/// assumed to be black. +/// @param[in] masks Optional per-tile boolean mask [C, tile_height, tile_width] +/// +/// @return std::tuple containing gradients of the loss function with respect to the input +/// parameters: +/// - Absolute value of 2D means [C, N, 2] - gradients dL/d|means2d| (optional: if +/// absGrad is true, this tensor is returned, otherwise it is an empty tensor) +/// - 2D means [C, N, 2] - gradients dL/dmeans2d +/// - conics [C, N, 3] - gradients dL/dconics +/// - features [C, N, D] - gradients dL/dfeatures +/// - opacities [N] - gradients dL/dopacities +std::tuple +rasterize_screen_space_gaussians_bwd(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &renderedAlphas, + const torch::Tensor &lastIds, + const torch::Tensor &dLossDRenderedFeatures, + const torch::Tensor &dLossDRenderedAlphas, + bool absGrad, + int64_t numSharedChannelsOverride = -1, + const at::optional &backgrounds = at::nullopt, + const at::optional &masks = at::nullopt); + +/// @brief Calculate gradients for the sparse Gaussian rasterization process (backward pass) +/// +/// This function computes the gradients of the sparse Gaussian splatting rendering with respect to +/// its input parameters for only the specified pixels. It combines the efficiency of sparse +/// rasterization with gradient computation, processing only the pixels specified in pixelsToRender. +/// +/// @param[in] pixelsToRender JaggedTensor containing pixel coordinates to render [C, NumPixels, 2] +/// @param[in] means2d 2D projected Gaussian centers [C, N, 2] +/// @param[in] conics Gaussian covariance matrices in conic form [C, N, 3] representing (a, b, c) in +/// ax² + 2bxy + cy² +/// @param[in] features Feature / color values of Gaussians [C, N, D] +/// @param[in] opacities Opacity values for each Gaussian [N] +/// @param[in] imageWidth Width of the render window in pixels +/// @param[in] imageHeight Height of the render window in pixels +/// @param[in] imageOriginW Horizontal origin of the render window +/// @param[in] imageOriginH Vertical origin of the render window +/// @param[in] tileSize Size of tiles used for rasterization optimization +/// @param[in] tileOffsets Offsets for tiles [C, tile_height, tile_width] +/// @param[in] tileGaussianIds Flattened Gaussian IDs for tile intersection [n_isects] +/// @param[in] renderedAlphas Alpha values from sparse forward pass [JaggedTensor] +/// @param[in] lastIds Last Gaussian IDs per pixel from sparse forward pass [JaggedTensor] +/// @param[in] dLossDRenderedFeatures Gradients of loss w.r.t sparse rendered features +/// [JaggedTensor] +/// @param[in] dLossDRenderedAlphas Gradients of loss w.r.t sparse rendered alphas [JaggedTensor] +/// @param[in] activeTiles Tensor containing indices of active tiles +/// @param[in] tilePixelMask Tensor containing the mask for each tile pixel +/// @param[in] tilePixelCumsum Tensor containing cumulative sum of tile pixels +/// @param[in] pixelMap Tensor containing mapping of pixels to output indices +/// @param[in] absGrad Whether to use absolute gradients +/// @param[in] numSharedChannelsOverride Override for number of shared memory channels (-1 means +/// auto-select) +/// @param[in] backgrounds Optional background color per camera [C, D] +/// @param[in] masks Optional per-tile boolean mask [C, tile_height, tile_width] +/// +/// @return std::tuple containing gradients of the loss function with respect to the input +/// parameters: +/// - Absolute value of 2D means [C, N, 2] - gradients dL/d|means2d| +/// - 2D means [C, N, 2] - gradients dL/dmeans2d +/// - conics [C, N, 3] - gradients dL/dconics +/// - features [C, N, D] - gradients dL/dfeatures +/// - opacities [N] - gradients dL/dopacities +std::tuple +rasterize_screen_space_gaussians_sparse_bwd( + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const fvdb::JaggedTensor &renderedAlphas, + const fvdb::JaggedTensor &lastIds, + const fvdb::JaggedTensor &dLossDRenderedFeatures, + const fvdb::JaggedTensor &dLossDRenderedAlphas, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + bool absGrad, + int64_t numSharedChannelsOverride = -1, + const at::optional &backgrounds = at::nullopt, + const at::optional &masks = at::nullopt); + +} // namespace ops +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_OPS_RASTERIZESCREENSPACEGAUSSIANSBACKWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu b/src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansForward.cu similarity index 84% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu rename to src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansForward.cu index 75dbde1ab..e74914d47 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.cu +++ b/src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansForward.cu @@ -2,15 +2,19 @@ // SPDX-License-Identifier: Apache-2.0 // #include -#include -#include -#include -#include -#include -#include +#include #include #include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include #include @@ -99,7 +103,7 @@ template struct Ras template __device__ void writeFeatures(uint64_t pixelIndex, F &&f) { - PRAGMA_UNROLL +#pragma unroll for (uint32_t k = 0; k < NUM_CHANNELS; ++k) { mOutFeatures.data()[pixelIndex][k] = f(k); } @@ -189,7 +193,7 @@ template struct Ras } const ScalarType nextTransmittance = accumTransmittance * (1.0f - alpha); - if (nextTransmittance <= 1e-4f) { // this pixel is done: exclusive + if (nextTransmittance <= kTransmittanceThreshold) { // this pixel is done done = true; break; } @@ -204,7 +208,7 @@ template struct Ras return commonArgs.mFeatures[cid][gid]; } }(); - PRAGMA_UNROLL +#pragma unroll for (uint32_t k = 0; k < NUM_CHANNELS; ++k) { pixOut[k] += featureAccessor[k] * vis; } @@ -589,9 +593,40 @@ launchRasterizeForwardKernels( } // namespace +template +std::tuple +dispatch_rasterize_screen_space_gaussians_fwd(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const RenderWindow2D &renderWindow, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const at::optional &backgrounds, + const at::optional &masks); + +template +std::tuple +dispatch_rasterize_screen_space_gaussians_sparse_fwd(const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const RenderWindow2D &renderWindow, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + const at::optional &backgrounds, + const at::optional &masks); + template <> std::tuple -dispatchGaussianRasterizeForward( +dispatch_rasterize_screen_space_gaussians_fwd( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] @@ -671,7 +706,7 @@ dispatchGaussianRasterizeForward( template <> std::tuple -dispatchGaussianRasterizeForward( +dispatch_rasterize_screen_space_gaussians_fwd( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] @@ -753,7 +788,7 @@ dispatchGaussianRasterizeForward( template <> std::tuple -dispatchGaussianRasterizeForward( +dispatch_rasterize_screen_space_gaussians_fwd( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] @@ -773,7 +808,7 @@ dispatchGaussianRasterizeForward( template <> std::tuple -dispatchGaussianSparseRasterizeForward( +dispatch_rasterize_screen_space_gaussians_sparse_fwd( // sparse pixel coordinates const fvdb::JaggedTensor &pixelsToRender, // [C, maxPixelsPerCamera, 2] // Gaussian parameters @@ -864,7 +899,7 @@ dispatchGaussianSparseRasterizeForward( template <> std::tuple -dispatchGaussianSparseRasterizeForward( +dispatch_rasterize_screen_space_gaussians_sparse_fwd( // sparse pixel coordinates const fvdb::JaggedTensor &pixelsToRender, // [C, maxPixelsPerCamera, 2] // Gaussian parameters @@ -887,7 +922,7 @@ dispatchGaussianSparseRasterizeForward( template <> std::tuple -dispatchGaussianSparseRasterizeForward( +dispatch_rasterize_screen_space_gaussians_sparse_fwd( // sparse pixel coordinates const fvdb::JaggedTensor &pixelsToRender, // [C, maxPixelsPerCamera, 2] // Gaussian parameters @@ -908,4 +943,72 @@ dispatchGaussianSparseRasterizeForward( TORCH_CHECK_NOT_IMPLEMENTED(false, "CPU implementation not available"); } +std::tuple +rasterize_screen_space_gaussians_fwd(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const at::optional &backgrounds, + const at::optional &masks) { + const RenderWindow2D renderWindow{imageWidth, imageHeight, imageOriginW, imageOriginH}; + return FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return dispatch_rasterize_screen_space_gaussians_fwd(means2d, + conics, + features, + opacities, + renderWindow, + tileSize, + tileOffsets, + tileGaussianIds, + backgrounds, + masks); + }); +} + +std::tuple +rasterize_screen_space_gaussians_sparse_fwd(const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + const at::optional &backgrounds, + const at::optional &masks) { + const RenderWindow2D renderWindow{imageWidth, imageHeight, imageOriginW, imageOriginH}; + return FVDB_DISPATCH_KERNEL(means2d.device(), [&]() { + return dispatch_rasterize_screen_space_gaussians_sparse_fwd(pixelsToRender, + means2d, + conics, + features, + opacities, + renderWindow, + tileSize, + tileOffsets, + tileGaussianIds, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + backgrounds, + masks); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.h b/src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansForward.h similarity index 55% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.h rename to src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansForward.h index 65ce897ff..706c4fbeb 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeForward.h +++ b/src/fvdb/detail/ops/RasterizeScreenSpaceGaussiansForward.h @@ -1,16 +1,16 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFORWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFORWARD_H +#ifndef FVDB_DETAIL_OPS_RASTERIZESCREENSPACEGAUSSIANSFORWARD_H +#define FVDB_DETAIL_OPS_RASTERIZESCREENSPACEGAUSSIANSFORWARD_H #include -#include #include #include +#include #include namespace fvdb { @@ -24,14 +24,15 @@ namespace ops { /// feature/color, and opacity. The function performs alpha-blending of the Gaussians to generate /// the final rendered image. /// -/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) -/// /// @param[in] means2d 2D projected Gaussian centers [C, N, 2] /// @param[in] conics Gaussian covariance matrices in conic form [C, N, 3] representing (a, b, c) in /// ax² + 2bxy + cy² /// @param[in] features Feature / color values of Gaussians [C, N, D] /// @param[in] opacities Opacity values for each Gaussian [N] -/// @param[in] renderWindow Render window dimensions and origin. +/// @param[in] imageWidth Width of the render window in pixels +/// @param[in] imageHeight Height of the render window in pixels +/// @param[in] imageOriginW Horizontal origin of the render window +/// @param[in] imageOriginH Vertical origin of the render window /// @param[in] tileSize Size of tiles used for rasterization optimization /// @param[in] tileOffsets Offsets for tiles [C, tile_height, tile_width] indicating for each tile /// where its Gaussians start @@ -40,34 +41,38 @@ namespace ops { /// @param[in] backgrounds Optional background color per camera [C, D]. If provided, background /// colors will be blended with transparent pixels. If not provided, background is assumed to be /// black. +/// @param[in] masks Optional per-tile boolean mask [C, tile_height, tile_width] /// /// @return std::tuple containing: /// - Rendered image features/colors [C, render_height, render_width, D] /// - Alpha values [C, render_height, render_width, 1] /// - Last Gaussian ID rendered at each pixel [C, render_height, render_width] -template -std::tuple dispatchGaussianRasterizeForward( - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &features, // [C, N, D] - const torch::Tensor &opacities, // [N] - const RenderWindow2D &renderWindow, - const uint32_t tileSize, - const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] - const torch::Tensor &tileGaussianIds, // [n_isects] - const at::optional &backgrounds = at::nullopt, // [C, D] - const at::optional &masks = at::nullopt // [C, tile_height, tile_width] bool -); +std::tuple +rasterize_screen_space_gaussians_fwd(const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const at::optional &backgrounds = at::nullopt, + const at::optional &masks = at::nullopt); -/// @brief Dispatches the sparse Gaussian rasterization forward pass to the specified device. -/// Renders only specified pixels. -/// @tparam Device The device type (e.g., torch::kCPU or torch::kCUDA). +/// @brief Sparse Gaussian rasterization forward pass. Renders only specified pixels. +/// /// @param pixelsToRender Tensor containing the indices of pixels to render [C, NumPixels, 2]. /// @param means2d Tensor of 2D means. /// @param conics Tensor of conic parameters. /// @param features Tensor of features (colors, etc). /// @param opacities Tensor of opacities. -/// @param renderWindow Render window dimensions and origin. +/// @param imageWidth Width of the render window in pixels +/// @param imageHeight Height of the render window in pixels +/// @param imageOriginW Horizontal origin of the render window +/// @param imageOriginH Vertical origin of the render window /// @param tileSize Size of the tiles used for processing. /// @param tileOffsets Tensor containing offsets for each tile. /// @param tileGaussianIds Tensor mapping tiles to Gaussian IDs. @@ -78,30 +83,34 @@ std::tuple dispatchGaussianRasteriz /// @param backgrounds Optional background color per camera [C, D]. If provided, background /// colors will be blended with transparent pixels. If not provided, background is assumed to be /// black. +/// @param masks Optional per-tile boolean mask [C, tile_height, tile_width] /// @return A tuple containing: /// - Output colors JaggedTensor for the specified pixels. /// - Output alphas JaggedTensor for the specified pixels. /// - Output last Gaussian IDs JaggedTensor for the specified pixels. -template std::tuple -dispatchGaussianSparseRasterizeForward(const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &means2d, - const torch::Tensor &conics, - const torch::Tensor &features, - const torch::Tensor &opacities, - const RenderWindow2D &renderWindow, - const uint32_t tileSize, - const torch::Tensor &tileOffsets, - const torch::Tensor &tileGaussianIds, - const torch::Tensor &activeTiles, - const torch::Tensor &tilePixelMask, - const torch::Tensor &tilePixelCumsum, - const torch::Tensor &pixelMap, - const at::optional &backgrounds = at::nullopt, - const at::optional &masks = at::nullopt); +rasterize_screen_space_gaussians_sparse_fwd( + const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + const at::optional &backgrounds = at::nullopt, + const at::optional &masks = at::nullopt); } // namespace ops } // namespace detail } // namespace fvdb -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFORWARD_H +#endif // FVDB_DETAIL_OPS_RASTERIZESCREENSPACEGAUSSIANSFORWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldBackward.cu b/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansBackward.cu similarity index 82% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldBackward.cu rename to src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansBackward.cu index 2294a6d63..59f3f383f 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldBackward.cu +++ b/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansBackward.cu @@ -1,11 +1,12 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include -#include +#include #include +#include +#include +#include +#include #include #include @@ -14,6 +15,8 @@ #include +#include + namespace fvdb::detail::ops { namespace cg = cooperative_groups; @@ -425,7 +428,7 @@ launchBackward(const torch::Tensor &means, features.packed_accessor64(), opacities.packed_accessor64()}; - const PreparedRasterOptionalInputs opt = prepareRasterOptionalInputs( + const PreparedRasterOptionalInputs opt = prepare_raster_optional_inputs( features, C, tileExtentH, tileExtentW, (int64_t)NUM_CHANNELS, backgrounds, masks); args.backgrounds = opt.backgrounds; args.masks = opt.masks; @@ -458,6 +461,33 @@ launchBackward(const torch::Tensor &means, } // namespace +template +std::tuple +dispatchGaussianRasterizeFromWorld3DGSBackward(const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &features, + const torch::Tensor &opacities, + const torch::Tensor &worldToCamMatricesStart, + const torch::Tensor &worldToCamMatricesEnd, + const torch::Tensor &projectionMatrices, + const torch::Tensor &distortionCoeffs, + RollingShutterType rollingShutterType, + DistortionModel cameraModel, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &renderedAlphas, + const torch::Tensor &lastIds, + const torch::Tensor &dLossDRenderedFeatures, + const torch::Tensor &dLossDRenderedAlphas, + const at::optional &backgrounds, + const at::optional &masks); + template <> std::tuple dispatchGaussianRasterizeFromWorld3DGSBackward( @@ -472,7 +502,11 @@ dispatchGaussianRasterizeFromWorld3DGSBackward( const torch::Tensor &distortionCoeffs, const RollingShutterType rollingShutterType, const DistortionModel cameraModel, - const RenderSettings &settings, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, const torch::Tensor &tileOffsets, const torch::Tensor &tileGaussianIds, const torch::Tensor &renderedAlphas, @@ -485,12 +519,6 @@ dispatchGaussianRasterizeFromWorld3DGSBackward( const at::cuda::OptionalCUDAGuard device_guard(device_of(means)); - const uint32_t imageWidth = settings.imageWidth; - const uint32_t imageHeight = settings.imageHeight; - const uint32_t imageOriginW = settings.imageOriginW; - const uint32_t imageOriginH = settings.imageOriginH; - const uint32_t tileSize = settings.tileSize; - TORCH_CHECK_VALUE(means.is_cuda(), "means must be CUDA"); TORCH_CHECK_VALUE(features.is_cuda(), "features must be CUDA"); TORCH_CHECK_VALUE(opacities.is_cuda(), "opacities must be CUDA"); @@ -628,7 +656,11 @@ dispatchGaussianRasterizeFromWorld3DGSBackward(const torch::Tensor const torch::Tensor &, const RollingShutterType, const DistortionModel, - const RenderSettings &, + const uint32_t, + const uint32_t, + const uint32_t, + const uint32_t, + const uint32_t, const torch::Tensor &, const torch::Tensor &, const torch::Tensor &, @@ -640,4 +672,57 @@ dispatchGaussianRasterizeFromWorld3DGSBackward(const torch::Tensor TORCH_CHECK_VALUE(false, "dispatchGaussianRasterizeFromWorld3DGSBackward is CUDA-only"); } +std::tuple +rasterize_world_space_gaussians_bwd(const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &features, + const torch::Tensor &opacities, + const torch::Tensor &worldToCamMatricesStart, + const torch::Tensor &worldToCamMatricesEnd, + const torch::Tensor &projectionMatrices, + const torch::Tensor &distortionCoeffs, + const RollingShutterType rollingShutterType, + const DistortionModel cameraModel, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &renderedAlphas, + const torch::Tensor &lastIds, + const torch::Tensor &dLossDRenderedFeatures, + const torch::Tensor &dLossDRenderedAlphas, + const at::optional &backgrounds, + const at::optional &masks) { + return FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() { + return dispatchGaussianRasterizeFromWorld3DGSBackward(means, + quats, + logScales, + features, + opacities, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + distortionCoeffs, + rollingShutterType, + cameraModel, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastIds, + dLossDRenderedFeatures, + dLossDRenderedAlphas, + backgrounds, + masks); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansBackward.h b/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansBackward.h new file mode 100644 index 000000000..ca8e1aba8 --- /dev/null +++ b/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansBackward.h @@ -0,0 +1,81 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_RASTERIZEWORLDSPACEGAUSSIANSBACKWARD_H +#define FVDB_DETAIL_OPS_RASTERIZEWORLDSPACEGAUSSIANSBACKWARD_H + +#include + +#include + +#include +#include +#include + +namespace fvdb::detail::ops { + +/// @brief Backward pass for dense rasterization from 3D Gaussians using per-pixel rays. +/// +/// Gradients are produced for: +/// - means: [N, 3] +/// - quats: [N, 4] +/// - logScales: [N, 3] +/// - features: [C, N, D] +/// - opacities: [C, N] +/// +/// @param[in] means Gaussian mean positions [N, 3] +/// @param[in] quats Gaussian quaternion rotations [N, 4] +/// @param[in] logScales Gaussian log-scale factors [N, 3] +/// @param[in] features Feature/color values [C, N, D] +/// @param[in] opacities Opacity values [C, N] +/// @param[in] worldToCamMatricesStart World-to-camera matrices (start) [C, 4, 4] +/// @param[in] worldToCamMatricesEnd World-to-camera matrices (end) [C, 4, 4] +/// @param[in] projectionMatrices Camera intrinsics [C, 3, 3] +/// @param[in] distortionCoeffs Distortion coefficients [C, K] +/// @param[in] rollingShutterType Rolling shutter policy +/// @param[in] cameraModel Camera/distortion model +/// @param[in] settings Render settings (image dimensions, tile size, etc.) +/// @param[in] tileOffsets Tile offsets [C, tileH, tileW] +/// @param[in] tileGaussianIds Tile Gaussian IDs [n_isects] +/// @param[in] renderedAlphas Alpha values from forward pass [C, H, W, 1] +/// @param[in] lastIds Last Gaussian ID per pixel [C, H, W] +/// @param[in] dLossDRenderedFeatures Gradients w.r.t. rendered features [C, H, W, D] +/// @param[in] dLossDRenderedAlphas Gradients w.r.t. rendered alphas [C, H, W, 1] +/// @param[in] backgrounds Optional per-camera background [C, D] +/// @param[in] masks Optional per-tile boolean mask [C, tileH, tileW] +/// +/// @return std::tuple containing gradients: +/// - dL/dmeans [N, 3] +/// - dL/dquats [N, 4] +/// - dL/dlogScales [N, 3] +/// - dL/dfeatures [C, N, D] +/// - dL/dopacities [C, N] +std::tuple +rasterize_world_space_gaussians_bwd(const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &features, + const torch::Tensor &opacities, + const torch::Tensor &worldToCamMatricesStart, + const torch::Tensor &worldToCamMatricesEnd, + const torch::Tensor &projectionMatrices, + const torch::Tensor &distortionCoeffs, + RollingShutterType rollingShutterType, + DistortionModel cameraModel, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &renderedAlphas, + const torch::Tensor &lastIds, + const torch::Tensor &dLossDRenderedFeatures, + const torch::Tensor &dLossDRenderedAlphas, + const at::optional &backgrounds = at::nullopt, + const at::optional &masks = at::nullopt); + +} // namespace fvdb::detail::ops + +#endif // FVDB_DETAIL_OPS_RASTERIZEWORLDSPACEGAUSSIANSBACKWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldForward.cu b/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansForward.cu similarity index 81% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldForward.cu rename to src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansForward.cu index 556fbb420..2ddc1a2aa 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldForward.cu +++ b/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansForward.cu @@ -1,10 +1,11 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include +#include #include +#include +#include +#include #include #include @@ -12,6 +13,8 @@ #include +#include + namespace fvdb::detail::ops { namespace cg = cooperative_groups; @@ -170,7 +173,7 @@ template struct RasterizeFromWorldForwa continue; } const float nextTransmittance = transmittance * (1.0f - alpha); - if (nextTransmittance <= 1e-4f) { + if (nextTransmittance <= kTransmittanceThreshold) { done = true; break; } @@ -260,7 +263,7 @@ launchForward(const torch::Tensor &means, features.packed_accessor64(), opacities.packed_accessor64()}; - const PreparedRasterOptionalInputs opt = prepareRasterOptionalInputs( + const PreparedRasterOptionalInputs opt = prepare_raster_optional_inputs( features, C, tileExtentH, tileExtentW, (int64_t)NUM_CHANNELS, backgrounds, masks); commonArgs.backgrounds = opt.backgrounds; commonArgs.masks = opt.masks; @@ -288,6 +291,29 @@ launchForward(const torch::Tensor &means, } // namespace +template +std::tuple +dispatchGaussianRasterizeFromWorld3DGSForward(const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &features, + const torch::Tensor &opacities, + const torch::Tensor &worldToCamMatricesStart, + const torch::Tensor &worldToCamMatricesEnd, + const torch::Tensor &projectionMatrices, + const torch::Tensor &distortionCoeffs, + RollingShutterType rollingShutterType, + DistortionModel cameraModel, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const at::optional &backgrounds, + const at::optional &masks); + template <> std::tuple dispatchGaussianRasterizeFromWorld3DGSForward( @@ -302,7 +328,11 @@ dispatchGaussianRasterizeFromWorld3DGSForward( const torch::Tensor &distortionCoeffs, const RollingShutterType rollingShutterType, const DistortionModel cameraModel, - const RenderSettings &settings, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, const torch::Tensor &tileOffsets, const torch::Tensor &tileGaussianIds, const at::optional &backgrounds, @@ -311,12 +341,6 @@ dispatchGaussianRasterizeFromWorld3DGSForward( const at::cuda::OptionalCUDAGuard device_guard(device_of(means)); - const uint32_t imageWidth = settings.imageWidth; - const uint32_t imageHeight = settings.imageHeight; - const uint32_t imageOriginW = settings.imageOriginW; - const uint32_t imageOriginH = settings.imageOriginH; - const uint32_t tileSize = settings.tileSize; - TORCH_CHECK_VALUE(means.is_cuda(), "means must be CUDA"); TORCH_CHECK_VALUE(features.is_cuda(), "features must be CUDA"); TORCH_CHECK_VALUE(tileOffsets.is_cuda(), "tileOffsets must be CUDA"); @@ -479,7 +503,11 @@ dispatchGaussianRasterizeFromWorld3DGSForward(const torch::Tensor & const torch::Tensor &, const RollingShutterType, const DistortionModel, - const RenderSettings &, + const uint32_t, + const uint32_t, + const uint32_t, + const uint32_t, + const uint32_t, const torch::Tensor &, const torch::Tensor &, const at::optional &, @@ -487,4 +515,49 @@ dispatchGaussianRasterizeFromWorld3DGSForward(const torch::Tensor & TORCH_CHECK_VALUE(false, "dispatchGaussianRasterizeFromWorld3DGSForward is CUDA-only"); } +std::tuple +rasterize_world_space_gaussians_fwd(const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &features, + const torch::Tensor &opacities, + const torch::Tensor &worldToCamMatricesStart, + const torch::Tensor &worldToCamMatricesEnd, + const torch::Tensor &projectionMatrices, + const torch::Tensor &distortionCoeffs, + const RollingShutterType rollingShutterType, + const DistortionModel cameraModel, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const at::optional &backgrounds, + const at::optional &masks) { + return FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() { + return dispatchGaussianRasterizeFromWorld3DGSForward(means, + quats, + logScales, + features, + opacities, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + distortionCoeffs, + rollingShutterType, + cameraModel, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + backgrounds, + masks); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansForward.h b/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansForward.h new file mode 100644 index 000000000..bdf6fb689 --- /dev/null +++ b/src/fvdb/detail/ops/RasterizeWorldSpaceGaussiansForward.h @@ -0,0 +1,76 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_OPS_RASTERIZEWORLDSPACEGAUSSIANSFORWARD_H +#define FVDB_DETAIL_OPS_RASTERIZEWORLDSPACEGAUSSIANSFORWARD_H + +#include + +#include + +#include +#include +#include + +namespace fvdb::detail::ops { + +/// @brief Rasterize images directly from 3D Gaussians using per-pixel rays. +/// +/// This kernel follows the gsplat "RasterizeToPixelsFromWorld3DGS" algorithm, but is wired to +/// FVDB's existing tile intersection representation (`tileOffsets`, `tileGaussianIds`). +/// +/// Inputs are world-space Gaussians (means/quats/logScales) and per-camera per-gaussian features +/// and opacities. The camera is defined via world->camera matrices (start/end), intrinsics, +/// `DistortionModel`, rolling shutter policy, and packed OpenCV distortion coefficients. +/// +/// This is a dense-only rasterizer: outputs are dense tensors of shape +/// - renderedFeatures: [C, H, W, D] +/// - renderedAlphas: [C, H, W, 1] +/// - lastIds: [C, H, W] +/// +/// @param[in] means Gaussian mean positions [N, 3] +/// @param[in] quats Gaussian quaternion rotations [N, 4] (w,x,y,z) +/// @param[in] logScales Gaussian log-scale factors [N, 3] +/// @param[in] features Feature/color values [C, N, D] +/// @param[in] opacities Opacity values [C, N] +/// @param[in] worldToCamMatricesStart World-to-camera matrices (start) [C, 4, 4] +/// @param[in] worldToCamMatricesEnd World-to-camera matrices (end) [C, 4, 4] +/// @param[in] projectionMatrices Camera intrinsics [C, 3, 3] +/// @param[in] distortionCoeffs Distortion coefficients [C, K] (K=0 or 12) +/// @param[in] rollingShutterType Rolling shutter policy +/// @param[in] cameraModel Camera/distortion model +/// @param[in] settings Render settings (image dimensions, tile size, etc.) +/// @param[in] tileOffsets Tile offsets [C, tileH, tileW] +/// @param[in] tileGaussianIds Tile Gaussian IDs [n_isects] +/// @param[in] backgrounds Optional per-camera background [C, D] +/// @param[in] masks Optional per-tile boolean mask [C, tileH, tileW] +/// +/// @return std::tuple containing: +/// - Rendered features [C, H, W, D] +/// - Alpha values [C, H, W, 1] +/// - Last Gaussian ID per pixel [C, H, W] +std::tuple +rasterize_world_space_gaussians_fwd(const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &features, + const torch::Tensor &opacities, + const torch::Tensor &worldToCamMatricesStart, + const torch::Tensor &worldToCamMatricesEnd, + const torch::Tensor &projectionMatrices, + const torch::Tensor &distortionCoeffs, + RollingShutterType rollingShutterType, + DistortionModel cameraModel, + uint32_t imageWidth, + uint32_t imageHeight, + uint32_t imageOriginW, + uint32_t imageOriginH, + uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const at::optional &backgrounds = at::nullopt, + const at::optional &masks = at::nullopt); + +} // namespace fvdb::detail::ops + +#endif // FVDB_DETAIL_OPS_RASTERIZEWORLDSPACEGAUSSIANSFORWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianMCMCRelocation.cu b/src/fvdb/detail/ops/RelocateGaussians.cu similarity index 77% rename from src/fvdb/detail/ops/gsplat/GaussianMCMCRelocation.cu rename to src/fvdb/detail/ops/RelocateGaussians.cu index 0a186573f..e78383aac 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianMCMCRelocation.cu +++ b/src/fvdb/detail/ops/RelocateGaussians.cu @@ -2,17 +2,16 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include #include #include +#include #include #include #include #include -#include - #include namespace fvdb::detail::ops { @@ -20,21 +19,30 @@ namespace fvdb::detail::ops { using fvdb::detail::deviceChunk; using fvdb::detail::mergeStreams; +// Internal dispatch template (specializations defined below). +template +std::tuple +dispatch_relocate_gaussians(const torch::Tensor &logScales, + const torch::Tensor &logitOpacities, + const torch::Tensor &ratios, + const torch::Tensor &binomialCoeffs, + const int nMax, + float minOpacity); + namespace { template __global__ void -gaussianRelocationKernel(int64_t localToGlobalOffset, - int64_t localSize, - fvdb::TorchRAcc64 logScales, - fvdb::TorchRAcc64 logitOpacities, - fvdb::TorchRAcc64 ratios, - fvdb::TorchRAcc64 binomialCoeffs, - fvdb::TorchRAcc64 logScalesNew, - fvdb::TorchRAcc64 logitOpacitiesNew, - std::size_t nMax, - float minOpacity) { - const auto N = logScales.size(0); +relocate_gaussians_kernel(int64_t localToGlobalOffset, + int64_t localSize, + fvdb::TorchRAcc64 logScales, + fvdb::TorchRAcc64 logitOpacities, + fvdb::TorchRAcc64 ratios, + fvdb::TorchRAcc64 binomialCoeffs, + fvdb::TorchRAcc64 logScalesNew, + fvdb::TorchRAcc64 logitOpacitiesNew, + std::size_t nMax, + float minOpacity) { for (uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x + localToGlobalOffset; idx < localSize + localToGlobalOffset; idx += blockDim.x * gridDim.x) { @@ -83,7 +91,7 @@ launchGaussianRelocation(const torch::Tensor &logScales, // [N, 3] const int blockDim = DEFAULT_BLOCK_DIM; const int gridDim = fvdb::GET_BLOCKS(size, blockDim); - gaussianRelocationKernel<<>>( + relocate_gaussians_kernel<<>>( offset, size, logScales.packed_accessor64(), @@ -102,12 +110,12 @@ launchGaussianRelocation(const torch::Tensor &logScales, // [N, 3] template <> std::tuple -dispatchGaussianRelocation(const torch::Tensor &logScales, // [N, 3] - const torch::Tensor &logitOpacities, // [N] - const torch::Tensor &ratios, // [N] - const torch::Tensor &binomialCoeffs, // [nMax, nMax] - const int nMax, - float minOpacity) { +dispatch_relocate_gaussians(const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &ratios, // [N] + const torch::Tensor &binomialCoeffs, // [nMax, nMax] + const int nMax, + float minOpacity) { FVDB_FUNC_RANGE(); const at::cuda::OptionalCUDAGuard device_guard(device_of(logScales)); @@ -155,12 +163,13 @@ dispatchGaussianRelocation(const torch::Tensor &logScales, // template <> std::tuple -dispatchGaussianRelocation(const torch::Tensor &logScales, // [N, 3] - const torch::Tensor &logitOpacities, // [N] - const torch::Tensor &ratios, // [N] - const torch::Tensor &binomialCoeffs, // [nMax, nMax] - const int nMax, - float minOpacity) { +dispatch_relocate_gaussians( + const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &ratios, // [N] + const torch::Tensor &binomialCoeffs, // [nMax, nMax] + const int nMax, + float minOpacity) { FVDB_FUNC_RANGE(); const auto N = logScales.size(0); @@ -219,14 +228,27 @@ dispatchGaussianRelocation(const torch::Tensor &logScales, template <> std::tuple -dispatchGaussianRelocation(const torch::Tensor &logScales, // [N, 3] - const torch::Tensor &logitOpacities, // [N] - const torch::Tensor &ratios, // [N] - const torch::Tensor &binomialCoeffs, // [nMax, nMax] - const int nMax, - float minOpacity) { +dispatch_relocate_gaussians(const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &ratios, // [N] + const torch::Tensor &binomialCoeffs, // [nMax, nMax] + const int nMax, + float minOpacity) { // CPU path intentionally unsupported; keep signature for clearer error messaging in tests. TORCH_CHECK_NOT_IMPLEMENTED(false, "GaussianRelocation is not implemented for CPU"); } +std::tuple +relocate_gaussians(const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &ratios, // [N] + const torch::Tensor &binomialCoeffs, // [nMax, nMax] + const int nMax, + float minOpacity) { + return FVDB_DISPATCH_KERNEL(logScales.device(), [&]() { + return dispatch_relocate_gaussians( + logScales, logitOpacities, ratios, binomialCoeffs, nMax, minOpacity); + }); +} + } // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/gsplat/GaussianMCMCRelocation.h b/src/fvdb/detail/ops/RelocateGaussians.h similarity index 59% rename from src/fvdb/detail/ops/gsplat/GaussianMCMCRelocation.h rename to src/fvdb/detail/ops/RelocateGaussians.h index 3744df162..952edda59 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianMCMCRelocation.h +++ b/src/fvdb/detail/ops/RelocateGaussians.h @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCRELOCATION_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCRELOCATION_H +#ifndef FVDB_DETAIL_OPS_RELOCATEGAUSSIANS_H +#define FVDB_DETAIL_OPS_RELOCATEGAUSSIANS_H #include @@ -15,6 +15,9 @@ namespace ops { /// @brief Relocate Gaussians by adjusting opacity and scale based on replication ratio. /// +/// Dispatches to the appropriate device implementation (CPU, CUDA, or PrivateUse1) +/// based on the device of the input tensors. +/// /// @param logScales Input log scales [N, 3] /// @param logitOpacities Input logit opacities [N] /// @param ratios Replication ratios per Gaussian [N] (int32) @@ -22,14 +25,13 @@ namespace ops { /// @param nMax Maximum replication ratio (size of binomial table) /// @param minOpacity Minimum opacity /// @return tuple of (opacitiesNew [N], scalesNew [N, 3]) -template std::tuple -dispatchGaussianRelocation(const torch::Tensor &logScales, // [N, 3] - const torch::Tensor &logitOpacities, // [N] - const torch::Tensor &ratios, // [N] - const torch::Tensor &binomialCoeffs, // [nMax, nMax] - const int nMax, - float minOpacity); +relocate_gaussians(const torch::Tensor &logScales, // [N, 3] + const torch::Tensor &logitOpacities, // [N] + const torch::Tensor &ratios, // [N] + const torch::Tensor &binomialCoeffs, // [nMax, nMax] + const int nMax, + float minOpacity); } // namespace ops } // namespace detail diff --git a/src/fvdb/detail/ops/gsplat/GaussianCameraMatrixUtils.cuh b/src/fvdb/detail/ops/gsplat/GaussianCameraMatrixUtils.cuh deleted file mode 100644 index 2343b3ab1..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianCameraMatrixUtils.cuh +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCAMERAMATRIXUTILS_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCAMERAMATRIXUTILS_CUH - -#include - -#include - -namespace fvdb::detail::ops { - -/// @brief Load world-to-camera rotation+translation from a row-major 4x4 matrix pointer. -/// -/// Input layout is row-major [4,4], where the translation is in the last column. -/// Returns a tuple (R, t) with R as Mat3 and t as Vec3. -template -inline __host__ __device__ cuda::std::tuple, nanovdb::math::Vec3> -loadWorldToCamRtRowMajor4x4(const T *m44) { - // Row-major 4x4 with last column = translation. - return {nanovdb::math::Mat3(m44[0], - m44[1], - m44[2], // 1st row - m44[4], - m44[5], - m44[6], // 2nd row - m44[8], - m44[9], - m44[10]), // 3rd row - nanovdb::math::Vec3(m44[3], m44[7], m44[11])}; -} - -/// @brief Load world-to-camera rotation+translation from a [C,4,4] accessor for camera `camId`. -/// -/// Returns a tuple (R, t) with R as Mat3 and t as Vec3, matching -/// `loadWorldToCamRtRowMajor4x4`. -template -inline __device__ cuda::std::tuple, nanovdb::math::Vec3> -loadWorldToCamRtFromAccessor44(const Acc44 &m44 /* [C,4,4] */, const int64_t camId) { - return {nanovdb::math::Mat3(m44[camId][0][0], - m44[camId][0][1], - m44[camId][0][2], // 1st row - m44[camId][1][0], - m44[camId][1][1], - m44[camId][1][2], // 2nd row - m44[camId][2][0], - m44[camId][2][1], - m44[camId][2][2]), // 3rd row - nanovdb::math::Vec3(m44[camId][0][3], m44[camId][1][3], m44[camId][2][3])}; -} - -/// @brief Load a 3x3 matrix from a [C,3,3] accessor for camera `camId`. -template -inline __device__ nanovdb::math::Mat3 -loadMat3FromAccessor33(const Acc33 &m33 /* [C,3,3] */, const int64_t camId) { - return nanovdb::math::Mat3(m33[camId][0][0], - m33[camId][0][1], - m33[camId][0][2], - m33[camId][1][0], - m33[camId][1][1], - m33[camId][1][2], - m33[camId][2][0], - m33[camId][2][1], - m33[camId][2][2]); -} - -} // namespace fvdb::detail::ops - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCAMERAMATRIXUTILS_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h b/src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h deleted file mode 100644 index b5134d126..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianMCMCAddNoise.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// - -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCADDNOISE_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCADDNOISE_H - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -template -void dispatchGaussianMCMCAddNoise(torch::Tensor &means, // [N, 3] input/output - const torch::Tensor &logScales, // [N] - const torch::Tensor &logitOpacities, // [N] - const torch::Tensor &quats, // [N, 4] - const float noiseScale, - const float t, - const float k); - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMCMCADDNOISE_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianMacros.cuh b/src/fvdb/detail/ops/gsplat/GaussianMacros.cuh deleted file mode 100644 index 854a68ad9..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianMacros.cuh +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMACROS_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMACROS_CUH - -#define GSPLAT_CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define GSPLAT_CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define GSPLAT_CHECK_INPUT(x) \ - GSPLAT_CHECK_CUDA(x); \ - GSPLAT_CHECK_CONTIGUOUS(x) -#define GSPLAT_DEVICE_GUARD(_ten) const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); - -#define GSPLAT_PRAGMA_UNROLL _Pragma("unroll") - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANMACROS_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.h b/src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.h deleted file mode 100644 index a39c6ca32..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianProjectionJaggedBackward.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONJAGGEDBACKWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONJAGGEDBACKWARD_H - -#include - -#include - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -/// @brief Calculate gradients for the jagged 3D to 2D Gaussian projection (backward pass) -/// -/// This function computes the gradients of the 3D to 2D Gaussian projection with respect to -/// the input parameters when using jagged tensors for batch processing. It enables backpropagation -/// through the projection step in the Gaussian Splatting pipeline for scenes with variable -/// numbers of objects and cameras per batch. -/// -/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) -/// -/// @param[in] gSizes Batch sizes for Gaussians [B] -/// @param[in] means 3D positions of Gaussians [M, 3] -/// @param[in] quats Quaternion rotations of Gaussians [M, 4] in format (x, y, z, w) -/// @param[in] scales Scale factors of Gaussians [M, 3] representing extent in each dimension -/// @param[in] cSizes Batch sizes for cameras [B] -/// @param[in] worldToCamMatrices Camera view matrices [BC, 4, 4] -/// @param[in] projectionMatrices Camera intrinsic matrices [BC, 3, 3] -/// @param[in] imageWidth Width of the output image in pixels -/// @param[in] imageHeight Height of the output image in pixels -/// @param[in] eps2d 2D projection epsilon for numerical stability -/// @param[in] radii Output radii from forward pass [M] -/// @param[in] conics Output conics from forward pass [M, 3] -/// @param[out] dLossDMeans2d Gradients with respect to projected 2D means [M, 2] -/// @param[out] dLossDDepths Gradients with respect to depths [M] -/// @param[out] dLossDConics Gradients with respect to conics [M, 3] -/// @param[in] worldToCamMatricesRequiresGrad Whether viewmats requires gradient -/// @param[in] ortho Whether orthographic projection was used in forward pass -/// -/// @return std::tuple containing gradients of the loss function with respect to the input -/// parameters: -/// - 3D means [M, 3] - ∂L/∂means -/// - Quaternions [M, 4] - ∂L/∂quats -/// - Scales [M, 3] - ∂L/∂scales -/// - View matrices [BC, 4, 4] - ∂L/∂viewmats (if viewmats_requires_grad is true, otherwise -/// empty tensor) -/// - Camera intrinsics [BC, 3, 3] - ∂L/∂Ks -template -std::tuple -dispatchGaussianProjectionJaggedBackward(const torch::Tensor &gSizes, // [B] gaussian sizes - const torch::Tensor &means, // [N, 3] - const torch::Tensor &quats, // [N, 4] optional - const torch::Tensor &scales, // [N, 3] optional - const torch::Tensor &cSizes, // [B] camera sizes - const torch::Tensor &worldToCamMatrices, // [C, 4, 4] - const torch::Tensor &projectionMatrices, // [C, 3, 3] - const uint32_t imageWidth, - const uint32_t imageHeight, - const float eps2d, - const torch::Tensor &radii, // [N] - const torch::Tensor &conics, // [N, 3] - const torch::Tensor &dLossDMeans2d, // [N, 2] - const torch::Tensor &dLossDDepths, // [N] - const torch::Tensor &dLossDConics, // [N, 3] - const bool worldToCamMatricesRequiresGrad, - const bool ortho); - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANPROJECTIONJAGGEDBACKWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.h b/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.h deleted file mode 100644 index ef545415b..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeBackward.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEBACKWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEBACKWARD_H - -#include -#include - -#include - -#include - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -/// @brief Calculate gradients for the Gaussian rasterization process (backward pass) -/// -/// This function computes the gradients of the Gaussian splatting rendering with respect to -/// its input parameters: 2D projected Gaussian means, conics, features/colors, and opacities. -/// It is used during backpropagation to update the Gaussian parameters during training. -/// -/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) -/// -/// @param[in] means2d 2D projected Gaussian centers [C, N, 2] -/// @param[in] conics Gaussian covariance matrices in conic form [C, N, 3] representing (a, b, c) in -/// ax² + 2bxy + cy² -/// @param[in] features Feature / color values of Gaussians [C, N, D] -/// @param[in] opacities Opacity values for each Gaussian [N] -/// @param[in] renderWindow Render window dimensions and origin. -/// @param[in] tileSize Size of tiles used for rasterization optimization -/// @param[in] tileOffsets Offsets for tiles [C, tile_height, tile_width] -/// @param[in] tileGaussianIds Flattened Gaussian IDs for tile intersection [n_isects] -/// @param[in] renderedAlphas Alpha values from forward pass [C, render_height, render_width, 1] -/// @param[in] lastIds Last Gaussian IDs per pixel from forward pass [C, render_height, -/// render_width] -/// @param[out] dLossDRenderedFeatures Gradients of loss with respect to rendered features [C, -/// render_height, render_width, D] -/// @param[out] dLossDRenderedAlphas Gradients of loss with respect to rendered alphas [C, -/// render_height, render_width, 1] -/// @param[in] absGrad Whether to use absolute gradients -/// @param[in] numSharedChannelsOverride Override for number of shared memory channels (-1 means -/// auto-select) -/// @param[in] backgrounds Optional background color per camera [C, D]. If provided, background -/// colors affect gradient computation for transparent pixels. If not provided, background is -/// assumed to be black. -/// -/// @return std::tuple containing gradients of the loss function with respect to the input -/// parameters: -/// - Absolute value of 2D means [C, N, 2] - gradients ∂L/∂|means2d| (optional: if -/// absGrad is true, this tensor is returned, otherwise it is an empty tensor) -/// - 2D means [C, N, 2] - gradients ∂L/∂means2d -/// - conics [C, N, 3] - gradients ∂L/∂conics -/// - features [C, N, D] - gradients ∂L/∂features -/// - opacities [N] - gradients ∂L/∂opacities -template -std::tuple -dispatchGaussianRasterizeBackward( - // Gaussian parameters - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &features, // [C, N, D] - const torch::Tensor &opacities, // [N] - const RenderWindow2D &renderWindow, - const uint32_t tileSize, - const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] - const torch::Tensor &tileGaussianIds, // [n_isects] - const torch::Tensor &renderedAlphas, // [C, imageHeight, imageWidth, 1] - const torch::Tensor &lastIds, // [C, imageHeight, imageWidth] - const torch::Tensor &dLossDRenderedFeatures, // [C, imageHeight, imageWidth, D] - const torch::Tensor &dLossDRenderedAlphas, // [C, imageHeight, imageWidth, 1] - const bool absGrad, - const int64_t numSharedChannelsOverride = -1, - const at::optional &backgrounds = at::nullopt, // [C, D] - const at::optional &masks = at::nullopt // [C, tile_height, tile_width] bool -); - -/// @brief Calculate gradients for the sparse Gaussian rasterization process (backward pass) -/// -/// This function computes the gradients of the sparse Gaussian splatting rendering with respect to -/// its input parameters for only the specified pixels. It combines the efficiency of sparse -/// rasterization with gradient computation, processing only the pixels specified in pixelsToRender. -/// -/// @tparam DeviceType Device type template parameter (torch::kCUDA or torch::kCPU) -/// -/// @param[in] pixelsToRender JaggedTensor containing pixel coordinates to render [C, NumPixels, 2] -/// @param[in] means2d 2D projected Gaussian centers [C, N, 2] -/// @param[in] conics Gaussian covariance matrices in conic form [C, N, 3] representing (a, b, c) in -/// ax² + 2bxy + cy² -/// @param[in] features Feature / color values of Gaussians [C, N, D] -/// @param[in] opacities Opacity values for each Gaussian [N] -/// @param[in] renderWindow Render window dimensions and origin. -/// @param[in] tileSize Size of tiles used for rasterization optimization -/// @param[in] tileOffsets Offsets for tiles [C, tile_height, tile_width] -/// @param[in] tileGaussianIds Flattened Gaussian IDs for tile intersection [n_isects] -/// @param[in] activeTiles Tensor containing indices of active tiles -/// @param[in] tilePixelMask Tensor containing the mask for each tile pixel -/// @param[in] tilePixelCumsum Tensor containing cumulative sum of tile pixels -/// @param[in] pixelMap Tensor containing mapping of pixels to output indices -/// @param[in] renderedAlphas Alpha values from sparse forward pass [JaggedTensor: C lists of -/// varying sizes, each element [1]] -/// @param[in] lastIds Last Gaussian IDs per pixel from sparse forward pass [JaggedTensor: C lists -/// of varying sizes] -/// @param[in] dLossDRenderedFeatures Gradients of loss w.r.t sparse rendered features -/// [JaggedTensor: C lists of varying sizes, each element [D]] -/// @param[in] dLossDRenderedAlphas Gradients of loss w.r.t sparse rendered alphas [JaggedTensor: C -/// lists of varying sizes, each element [1]] -/// @param[in] absGrad Whether to use absolute gradients -/// @param[in] numSharedChannelsOverride Override for number of shared memory channels (-1 means -/// auto-select) -/// @param[in] backgrounds Optional background color per camera [C, D]. If provided, background -/// colors affect gradient computation for transparent pixels. If not provided, background is -/// assumed to be black. -/// -/// @return std::tuple containing gradients of the loss function with respect to the input -/// parameters: -/// - Absolute value of 2D means [C, N, 2] - gradients ∂L/∂|means2d| (optional: if -/// absGrad is true, this tensor is returned, otherwise it is an empty tensor) -/// - 2D means [C, N, 2] - gradients ∂L/∂means2d -/// - conics [C, N, 3] - gradients ∂L/∂conics -/// - features [C, N, D] - gradients ∂L/∂features -/// - opacities [N] - gradients ∂L/∂opacities -template -std::tuple -dispatchGaussianSparseRasterizeBackward( - // Sparse pixel coordinates and setup - const fvdb::JaggedTensor &pixelsToRender, // [C, NumPixels, 2] - // Gaussian parameters - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &features, // [C, N, D] - const torch::Tensor &opacities, // [N] - // Image and tile setup - const RenderWindow2D &renderWindow, - const uint32_t tileSize, - const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] (dense) or [AT + 1] (sparse) - const torch::Tensor &tileGaussianIds, // [n_isects] - // Forward pass outputs (sparse) - const fvdb::JaggedTensor &renderedAlphas, // [C lists: varying sizes, each element [1]] - const fvdb::JaggedTensor &lastIds, // [C lists: varying sizes] - // Gradients (sparse) - const fvdb::JaggedTensor &dLossDRenderedFeatures, // [C lists: varying sizes, each element [D]] - const fvdb::JaggedTensor &dLossDRenderedAlphas, // [C lists: varying sizes, each element [1]] - // Sparse processing setup - const torch::Tensor &activeTiles, // [AT] - const torch::Tensor &tilePixelMask, // [AT, wordsPerTile] - const torch::Tensor &tilePixelCumsum, // [AT] - const torch::Tensor &pixelMap, // [AP] - // Options - const bool absGrad, - const int64_t numSharedChannelsOverride = -1, - const at::optional &backgrounds = at::nullopt, // [C, D] - const at::optional &masks = at::nullopt // [C, tile_height, tile_width] bool -); - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEBACKWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeContributingGaussianIds.h b/src/fvdb/detail/ops/gsplat/GaussianRasterizeContributingGaussianIds.h deleted file mode 100644 index 0c586d85b..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeContributingGaussianIds.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZECONTRIBUTINGGAUSSIANIDS_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZECONTRIBUTINGGAUSSIANIDS_H - -#include -#include - -#include - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -/// @brief Performs deep image rasterization to render the IDs and weighted alpha values of the -/// contributing Gaussians for each pixel -template -std::tuple dispatchGaussianRasterizeContributingGaussianIds( - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &opacities, // [N] - const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] - const torch::Tensor &tile_gaussian_ids, // [n_isects] - const RenderSettings &settings, - const std::optional &maybeNumContributingGaussians = std::nullopt); - -/// @brief Performs sparse deep image rasterization to render the IDs and weighted alpha values of -/// the top-K most visible Gaussians for each pixel. Renders only specified pixels. -template -std::tuple -dispatchGaussianSparseRasterizeContributingGaussianIds( - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &opacities, // [N] - const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] (dense) or [AT + 1] (sparse) - const torch::Tensor &tile_gaussian_ids, // [n_isects] - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &activeTiles, - const torch::Tensor &tilePixelMask, - const torch::Tensor &tilePixelCumsum, - const torch::Tensor &pixelMap, - const RenderSettings &settings, - const std::optional &maybeNumContributingGaussians = std::nullopt); - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZECONTRIBUTINGGAUSSIANIDS_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldBackward.h b/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldBackward.h deleted file mode 100644 index 09856f678..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldBackward.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFROMWORLDBACKWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFROMWORLDBACKWARD_H - -#include -#include - -#include - -#include -#include - -namespace fvdb::detail::ops { - -/// @brief Backward pass for dense rasterization from 3D Gaussians using per-pixel rays. -/// -/// Gradients are produced for: -/// - means: [N, 3] -/// - quats: [N, 4] -/// - logScales: [N, 3] -/// - features: [C, N, D] -/// - opacities: [C, N] -/// -/// @tparam DeviceType torch::kCUDA (CPU not implemented). -template -std::tuple -dispatchGaussianRasterizeFromWorld3DGSBackward( - // Gaussian parameters (world space) - const torch::Tensor &means, // [N, 3] - const torch::Tensor &quats, // [N, 4] - const torch::Tensor &logScales, // [N, 3] - // Per-camera quantities - const torch::Tensor &features, // [C, N, D] - const torch::Tensor &opacities, // [C, N] - const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4] - const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4] - const torch::Tensor &projectionMatrices, // [C, 3, 3] - const torch::Tensor &distortionCoeffs, // [C, K] (K=0 or 12) - const RollingShutterType rollingShutterType, - const DistortionModel cameraModel, - // Render settings - const RenderSettings &settings, - // Intersections - const torch::Tensor &tileOffsets, // [C, tileH, tileW] - const torch::Tensor &tileGaussianIds, // [n_isects] values in [0, C*N) - // Forward outputs needed for backward - const torch::Tensor &renderedAlphas, // [C, H, W, 1] - const torch::Tensor &lastIds, // [C, H, W] - // Gradients of outputs - const torch::Tensor &dLossDRenderedFeatures, // [C, H, W, D] - const torch::Tensor &dLossDRenderedAlphas, // [C, H, W, 1] - // Optional background (only affects alpha gradient term) - const at::optional &backgrounds = at::nullopt, // [C, D] - // Optional tile masks (parity with classic rasterizer) - const at::optional &masks = at::nullopt // [C, tileH, tileW] bool -); - -} // namespace fvdb::detail::ops - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFROMWORLDBACKWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldForward.h b/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldForward.h deleted file mode 100644 index c9f76a1f5..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorldForward.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFROMWORLDFORWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFROMWORLDFORWARD_H - -#include -#include - -#include - -#include -#include - -namespace fvdb::detail::ops { - -/// @brief Rasterize images directly from 3D Gaussians using per-pixel rays. -/// -/// This kernel follows the gsplat "RasterizeToPixelsFromWorld3DGS" algorithm, but is wired to -/// FVDB's existing tile intersection representation (`tileOffsets`, `tileGaussianIds`). -/// -/// Inputs are world-space Gaussians (means/quats/logScales) and per-camera per-gaussian features -/// and opacities. The camera is defined via world->camera matrices (start/end), intrinsics, -/// `DistortionModel`, rolling shutter policy, and packed OpenCV distortion coefficients. -/// -/// This is a dense-only rasterizer: outputs are dense tensors of shape -/// - renderedFeatures: [C, H, W, D] -/// - renderedAlphas: [C, H, W, 1] -/// - lastIds: [C, H, W] -/// -/// @tparam DeviceType torch::kCUDA (CPU not implemented). -template -std::tuple -dispatchGaussianRasterizeFromWorld3DGSForward( - // Gaussian parameters (world space) - const torch::Tensor &means, // [N, 3] - const torch::Tensor &quats, // [N, 4] (w,x,y,z) - const torch::Tensor &logScales, // [N, 3] - // Per-camera quantities - const torch::Tensor &features, // [C, N, D] - const torch::Tensor &opacities, // [C, N] - const torch::Tensor &worldToCamMatricesStart, // [C, 4, 4] - const torch::Tensor &worldToCamMatricesEnd, // [C, 4, 4] - const torch::Tensor &projectionMatrices, // [C, 3, 3] - const torch::Tensor &distortionCoeffs, // [C, K] (K=0 or 12) - const RollingShutterType rollingShutterType, - const DistortionModel cameraModel, - // Render settings - const RenderSettings &settings, - // Intersections - const torch::Tensor &tileOffsets, // [C, tileH, tileW] - const torch::Tensor &tileGaussianIds, // [n_isects] values in [0, C*N) - // Optional background - const at::optional &backgrounds = at::nullopt, // [C, D] - // Optional tile masks (parity with classic rasterizer) - const at::optional &masks = at::nullopt // [C, tileH, tileW] bool -); - -} // namespace fvdb::detail::ops - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFROMWORLDFORWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeNumContributingGaussians.h b/src/fvdb/detail/ops/gsplat/GaussianRasterizeNumContributingGaussians.h deleted file mode 100644 index 2eb672639..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeNumContributingGaussians.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZENUMCONTRIBUTINGGAUSSIANS_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZENUMCONTRIBUTINGGAUSSIANS_H - -#include -#include - -#include - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -template -std::tuple dispatchGaussianRasterizeNumContributingGaussians( - // Gaussian parameters - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &opacities, // [N] - const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] - const torch::Tensor &tileGaussianIds, // [n_isects] - const RenderSettings &settings // render settings -); - -template -std::tuple -dispatchGaussianSparseRasterizeNumContributingGaussians( - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &opacities, // [N] - const torch::Tensor &tileOffsets, // [C, tile_height, tile_width] (dense) or [AT + 1] (sparse) - const torch::Tensor &tileGaussianIds, // [n_isects] - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &activeTiles, - const torch::Tensor &tilePixelMask, - const torch::Tensor &tilePixelCumsum, - const torch::Tensor &pixelMap, - const RenderSettings &settings); // render settings - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZENUMCONTRIBUTINGGAUSSIANS_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeTopContributingGaussianIds.h b/src/fvdb/detail/ops/gsplat/GaussianRasterizeTopContributingGaussianIds.h deleted file mode 100644 index 740656672..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeTopContributingGaussianIds.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZETOPCONTRIBUTINGGAUSSIANIDS_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZETOPCONTRIBUTINGGAUSSIANIDS_H - -#include -#include - -#include - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -/// @brief Performs deep image rasterization to render the IDs and weighted alpha values of the -/// top-K most visible Gaussians for each pixel -template -std::tuple dispatchGaussianRasterizeTopContributingGaussianIds( - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &opacities, // [N] - const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] - const torch::Tensor &tile_gaussian_ids, // [n_isects] - const RenderSettings &settings); - -/// @brief Performs sparse deep image rasterization to render the IDs and weighted alpha values of -/// the top-K most visible Gaussians for each pixel. Renders only specified pixels. -template -std::tuple -dispatchGaussianSparseRasterizeTopContributingGaussianIds( - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &opacities, // [N] - const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] (dense) or [AT + 1] (sparse) - const torch::Tensor &tile_gaussian_ids, // [n_isects] - const fvdb::JaggedTensor &pixelsToRender, - const torch::Tensor &activeTiles, - const torch::Tensor &tilePixelMask, - const torch::Tensor &tilePixelCumsum, - const torch::Tensor &pixelMap, - const RenderSettings &settings); - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZETOPCONTRIBUTINGGAUSSIANIDS_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRenderSettings.h b/src/fvdb/detail/ops/gsplat/GaussianRenderSettings.h deleted file mode 100644 index 34b47973b..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianRenderSettings.h +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRENDERSETTINGS_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRENDERSETTINGS_H - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -struct RenderWindow2D { - std::uint32_t width = 0; - std::uint32_t height = 0; - std::uint32_t originW = 0; - std::uint32_t originH = 0; - - inline constexpr std::uint32_t - pixelCountPerCamera() const { - return width * height; - } - - inline constexpr std::uint32_t - tileExtentW(const std::uint32_t tileSize) const { - return (width + tileSize - 1) / tileSize; - } - - inline constexpr std::uint32_t - tileExtentH(const std::uint32_t tileSize) const { - return (height + tileSize - 1) / tileSize; - } -}; - -struct RenderSettings { - enum class RenderMode { - RGB = 0, - DEPTH = 1, - RGBD = 2, - }; - - std::uint32_t imageWidth; - std::uint32_t imageHeight; - std::uint32_t imageOriginW = 0; - std::uint32_t imageOriginH = 0; - RenderMode renderMode = RenderMode::RGB; - float nearPlane = 0.01; - float farPlane = 1e10; - std::uint32_t tileSize = 16; - float radiusClip = 0.0; - float eps2d = 0.3; - bool antialias = false; - int shDegreeToUse = -1; - int numDepthSamples = -1; -}; -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRENDERSETTINGS_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianRigidTransform.cuh b/src/fvdb/detail/ops/gsplat/GaussianRigidTransform.cuh deleted file mode 100644 index ff8d4a480..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianRigidTransform.cuh +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRIGIDTRANSFORM_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRIGIDTRANSFORM_CUH - -#include - -#include - -namespace fvdb::detail::ops { - -/// @brief Rigid transform (cached rotation + translation). -/// -/// Quaternion is stored as \([w,x,y,z]\) and is assumed to represent a rotation. -/// The corresponding rotation matrix \(R(q)\) is cached to avoid recomputing it for every point -/// transform (UT sigma points, rolling-shutter iterations, ray generation, etc.). -template struct RigidTransform { - nanovdb::math::Mat3 R; - nanovdb::math::Vec4 q; - nanovdb::math::Vec3 t; - - /// @brief Default constructor (identity transform). - /// - /// Initializes to unit quaternion \([1,0,0,0]\) and zero translation. - inline __host__ __device__ - RigidTransform() - : R(nanovdb::math::Mat3(nanovdb::math::Vec3(T(1), T(0), T(0)), - nanovdb::math::Vec3(T(0), T(1), T(0)), - nanovdb::math::Vec3(T(0), T(0), T(1)))), - q(T(1), T(0), T(0), T(0)), t(T(0), T(0), T(0)) {} - - /// @brief Construct from quaternion and translation. - /// @param[in] q_in Rotation quaternion \([w,x,y,z]\). - /// @param[in] t_in Translation vector. - inline __host__ __device__ - RigidTransform(const nanovdb::math::Vec4 &q_in, const nanovdb::math::Vec3 &t_in) - : R(quaternionToRotationMatrix(q_in)), q(q_in), t(t_in) {} - - /// @brief Construct from rotation matrix and translation. - /// @param[in] R_in Rotation matrix. - /// @param[in] t_in Translation vector. - inline __host__ __device__ - RigidTransform(const nanovdb::math::Mat3 &R_in, const nanovdb::math::Vec3 &t_in) - : R(R_in), q(rotationMatrixToQuaternion(R_in)), t(t_in) {} - - /// @brief Apply the transform to a 3D point: \(R(q)\,p + t\). - inline __host__ __device__ nanovdb::math::Vec3 - apply(const nanovdb::math::Vec3 &p_world) const { - // p_cam = R * p_world + t - return R * p_world + t; - } - - /// @brief Interpolate between two rigid transforms. - /// - /// Translation is linearly interpolated; rotation uses NLERP along the shortest arc. - inline static __host__ __device__ RigidTransform - interpolate(const T u, const RigidTransform &start, const RigidTransform &end) { - const nanovdb::math::Vec3 t_interp = start.t + u * (end.t - start.t); - const nanovdb::math::Vec4 q_interp = nlerpQuaternionShortestPath(start.q, end.q, u); - return RigidTransform(q_interp, t_interp); - } -}; - -} // namespace fvdb::detail::ops - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRIGIDTRANSFORM_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsForward.h b/src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsForward.h deleted file mode 100644 index ac6b47565..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianSphericalHarmonicsForward.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANSPHERICALHARMONICSFORWARD_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANSPHERICALHARMONICSFORWARD_H - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -/// @brief Evaluate spherical harmonics functions to compute features/colors. -/// -/// This function computes the features/colors for points in 3D space using spherical harmonics -/// (SH) representation. Spherical harmonics provide an efficient way to represent view-dependent -/// appearance for Gaussian Splatting and other rendering techniques. The output features are not -/// limited to RGB colors; they can have any number of channels. -/// -/// @param[in] shDegreeToUse Degree of spherical harmonics to use (0-3 typically, higher degrees -/// provide more detail) -/// @param[in] numCameras Number of cameras used for rendering -/// @param[in] viewDirs Direction vectors [N, 3] (packed) or [C, N, 3] (unpacked) normalized to unit -/// length, representing view directions -/// @param[in] sh0Coeffs Spherical harmonic coefficients [N, 1, D] (packed) or -/// [1, N, D] (unpacked), where D is the number of feature channels -/// @param[in] shNCoeffs Higher order spherical harmonic coefficients [N, K-1, D] (packed) or -/// [K-1, N, D] (unpacked), where K depends on sh_degree_to_use (K=(sh_degree_to_use+1)²) -/// @param[in] radii radii [N] (packed) or [C, N] (unpacked) for view-dependent level-of-detail -/// control -/// -/// @return Features/colors [N, D] computed from the spherical harmonics evaluation -template -torch::Tensor dispatchSphericalHarmonicsForward(const int64_t shDegreeToUse, - const int64_t numCameras, - const torch::Tensor &viewDirs, // [C, N, 3] - const torch::Tensor &sh0Coeffs, // [1, N, D] - const torch::Tensor &shNCoeffs, // [N, K-1, D] - const torch::Tensor &radii // [C, N] -); - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANSPHERICALHARMONICSFORWARD_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianTileIntersection.h b/src/fvdb/detail/ops/gsplat/GaussianTileIntersection.h deleted file mode 100644 index bcd04cf33..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianTileIntersection.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANTILEINTERSECTION_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANTILEINTERSECTION_H - -#include - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -/// @brief Compute the intersection of 2D Gaussians with image tiles for efficient rasterization -template -std::tuple -dispatchGaussianTileIntersection(const torch::Tensor &means2d, // [C, N, 2] or [M, 2] - const torch::Tensor &radii, // [C, N] or [M] - const torch::Tensor &depths, // [C, N] or [M] - const at::optional &cameraIds, // NULL or [M] - const uint32_t numCameras, - const uint32_t tileSize, - const uint32_t numTilesH, - const uint32_t numTilesW); - -/// @brief Compute the intersection of 2D Gaussians with image tiles for sparse rendering -template -std::tuple -dispatchGaussianSparseTileIntersection(const torch::Tensor &means2d, // [C, N, 2] or [M, 2] - const torch::Tensor &radii, // [C, N] or [M] - const torch::Tensor &depths, // [C, N] or [M] - const torch::Tensor &tileMask, // [C, H, W] - const torch::Tensor &activeTiles, // [num_active_tiles] - const at::optional &cameraIds, // NULL or [M] - const uint32_t numCameras, - const uint32_t tileSize, - const uint32_t numTilesH, - const uint32_t numTilesW); - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANTILEINTERSECTION_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianUtils.cpp b/src/fvdb/detail/ops/gsplat/GaussianUtils.cpp deleted file mode 100644 index 71e877036..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianUtils.cpp +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include - -#include -#include - -namespace fvdb::detail::ops { - -void -perCameraPrefetchAsync(const torch::Tensor &tensor, - uint32_t cameraOffset, - uint32_t cameraCount, - int deviceId, - cudaStream_t stream) { - TORCH_CHECK(tensor.is_contiguous(), "Tensor to prefetch is not contiguous"); - TORCH_CHECK(cameraOffset + cameraCount <= tensor.size(0), - "Tensor does not have a batched first dimension"); - size_t scalarSize = c10::elementSize(tensor.scalar_type()); - nanovdb::util::cuda::memPrefetchAsync(static_cast(tensor.const_data_ptr()) + - cameraOffset * tensor.stride(0) * scalarSize, - cameraCount * tensor.stride(0) * scalarSize, - deviceId, - stream); -} - -void -perCameraPrefetchBatchAsync(const torch::TensorList &tensors, - uint32_t cameraOffset, - uint32_t cameraCount, - int deviceId, - cudaStream_t stream) { - TORCH_CHECK(stream, "cudaMemPrefetchBatchAsync does not support the default stream"); -#if (CUDART_VERSION < 13000) - for (size_t i = 0; i < tensors.size(); ++i) { - const auto &tensor = tensors[i]; - TORCH_CHECK(tensor.is_contiguous(), "Tensor to prefetch is not contiguous"); - TORCH_CHECK(cameraOffset + cameraCount <= tensor.size(0), - "Tensor does not have a batched first dimension"); - size_t scalarSize = c10::elementSize(tensor.scalar_type()); - C10_CUDA_CHECK( - nanovdb::util::cuda::memPrefetchAsync(static_cast(tensor.data_ptr()) + - cameraOffset * tensor.stride(0) * scalarSize, - cameraCount * tensor.stride(0) * scalarSize, - deviceId, - stream)); - } -#else - std::vector prefetchPointers; - std::vector prefetchSizes; - const cudaMemLocation location = {cudaMemLocationTypeDevice, deviceId}; - std::vector prefetchLocations = {location}; - std::vector prefetchLocationIndices = {0}; - - for (size_t i = 0; i < tensors.size(); ++i) { - const auto &tensor = tensors[i]; - TORCH_CHECK(tensor.is_contiguous(), "Tensor to prefetch is not contiguous"); - TORCH_CHECK(cameraOffset + cameraCount <= tensor.size(0), - "Tensor does not have a batched first dimension"); - size_t scalarSize = c10::elementSize(tensor.scalar_type()); - prefetchPointers.emplace_back(static_cast(tensor.data_ptr()) + - cameraOffset * tensor.stride(0) * scalarSize); - prefetchSizes.emplace_back(cameraCount * tensor.stride(0) * scalarSize); - } - C10_CUDA_CHECK(cudaMemPrefetchBatchAsync(prefetchPointers.data(), - prefetchSizes.data(), - prefetchPointers.size(), - prefetchLocations.data(), - prefetchLocationIndices.data(), - prefetchLocations.size(), - 0, - stream)); -#endif -} - -} // namespace fvdb::detail::ops diff --git a/src/fvdb/detail/ops/gsplat/GaussianUtils.cuh b/src/fvdb/detail/ops/gsplat/GaussianUtils.cuh deleted file mode 100644 index 1656a6e19..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianUtils.cuh +++ /dev/null @@ -1,644 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANUTILS_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANUTILS_CUH - -#include - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -/// @brief Performs a binary search on a sorted array to find the insertion point for a value -/// -/// This function implements a standard binary search algorithm to find the last index -/// in the array where the element is less than or equal to the specified value. -/// -/// The function assumes that the array is sorted in non-decreasing order: -/// arr[0] <= arr[1] <= ... <= arr[len-1] -/// -/// Time complexity: O(log n) where n is the length of the array -/// -/// @tparam T Type of elements in the array (must support comparison operators) -/// @param arr Pointer to the sorted array -/// @param len Length of the array -/// @param val Value to search for -/// @return Index of the last element that is less than or equal to val, or -1 if no such element -/// exists -template -inline __device__ uint32_t -binSearch(const T *arr, const uint32_t len, const T val) { - uint32_t low = 0, high = len - 1; - while (low <= high) { - const uint32_t mid = (low + high) / 2; - if (arr[mid] <= val) { - low = mid + 1; - } else { - high = mid - 1; - } - } - return low - 1; -} - -/// @brief Converts a 3x3 rotation matrix to a quaternion. -/// -/// This function converts a 3x3 rotation matrix to the equivalent quaternion \([w,x,y,z]\) and -/// normalizes the result. -/// -/// The implementation uses the standard **branch-based** algorithm for numerical robustness: -/// - If \(\mathrm{trace}(R) > 0\), it uses the closed-form trace formula -/// \(w = \tfrac{1}{2}\sqrt{1 + \mathrm{trace}(R)}\) and derives \((x,y,z)\) from the -/// off-diagonals. -/// - Otherwise it selects the largest diagonal element and computes the quaternion from that -/// branch (x-dominant / y-dominant / z-dominant cases). -/// -/// Degenerate inputs (e.g. non-rotation matrices, NaNs) are guarded against to avoid division by -/// near-zero intermediates; in such cases the function falls back to the identity quaternion. -/// -/// @param R Input 3x3 rotation matrix. -/// @return nanovdb::math::Vec4 Quaternion equivalent to the rotation matrix. -template -__host__ __device__ nanovdb::math::Vec4 -rotationMatrixToQuaternion(const nanovdb::math::Mat3 &R) { - T trace = R[0][0] + R[1][1] + R[2][2]; - T x, y, z, w; - - // Guard against division by ~0 in the branch formulas below. - // This can happen for degenerate / NaN inputs where `t` underflows to 0 or is clamped to 0, - // causing `s = 2*sqrt(t)` to be 0, while the numerators remain finite -> inf/NaN. - const T s_min = (sizeof(T) == sizeof(float)) ? T(1e-8) : T(1e-12); - - if (trace > 0) { - T t = trace + T(1); - t = (t > T(0)) ? t : T(0); - T s = sqrt(t) * T(2); // S=4*qw - if (!(s > s_min)) { - // Degenerate input; fall back to identity. - w = T(1); - x = y = z = T(0); - } else { - w = T(0.25) * s; - x = (R[2][1] - R[1][2]) / s; - y = (R[0][2] - R[2][0]) / s; - z = (R[1][0] - R[0][1]) / s; - } - } else if ((R[0][0] > R[1][1]) && (R[0][0] > R[2][2])) { - T t = T(1) + R[0][0] - R[1][1] - R[2][2]; - t = (t > T(0)) ? t : T(0); - T s = sqrt(t) * T(2); // S=4*qx - if (!(s > s_min)) { - w = T(1); - x = y = z = T(0); - } else { - w = (R[2][1] - R[1][2]) / s; - x = T(0.25) * s; - y = (R[0][1] + R[1][0]) / s; - z = (R[0][2] + R[2][0]) / s; - } - } else if (R[1][1] > R[2][2]) { - T t = T(1) + R[1][1] - R[0][0] - R[2][2]; - t = (t > T(0)) ? t : T(0); - T s = sqrt(t) * T(2); // S=4*qy - if (!(s > s_min)) { - w = T(1); - x = y = z = T(0); - } else { - w = (R[0][2] - R[2][0]) / s; - x = (R[0][1] + R[1][0]) / s; - y = T(0.25) * s; - z = (R[1][2] + R[2][1]) / s; - } - } else { - T t = T(1) + R[2][2] - R[0][0] - R[1][1]; - t = (t > T(0)) ? t : T(0); - T s = sqrt(t) * T(2); // S=4*qz - if (!(s > s_min)) { - w = T(1); - x = y = z = T(0); - } else { - w = (R[1][0] - R[0][1]) / s; - x = (R[0][2] + R[2][0]) / s; - y = (R[1][2] + R[2][1]) / s; - z = T(0.25) * s; - } - } - - // Normalize to guard against accumulated FP error / slightly non-orthonormal inputs. - const T norm2 = (w * w + x * x + y * y + z * z); - if (norm2 > T(0)) { - const T invNorm = T(1) / sqrt(norm2); - w *= invNorm; - x *= invNorm; - y *= invNorm; - z *= invNorm; - } else { - // Degenerate input; fall back to identity. - w = T(1); - x = y = z = T(0); - } - - // Optional convention: keep a consistent sign (q and -q represent the same rotation). - if (w < T(0)) { - w = -w; - x = -x; - y = -y; - z = -z; - } - return nanovdb::math::Vec4(w, x, y, z); -} - -/// @brief Converts a quaternion to a 3x3 rotation matrix -/// -/// This function takes a quaternion [w,x,y,z] and converts it to the equivalent -/// 3x3 rotation matrix representation. The quaternion is first normalized to ensure -/// it has unit length, which is required for a proper rotation. -/// -/// The conversion uses the standard formula: -/// R = [ -/// 1-2(y²+z²) 2(xy-wz) 2(xz+wy) -/// 2(xy+wz) 1-2(x²+z²) 2(yz-wx) -/// 2(xz-wy) 2(yz+wx) 1-2(x²+y²) -/// ] -/// -/// Where w,x,y,z are the components of the normalized quaternion. -/// -/// @param quat Input quaternion in [w,x,y,z] format -/// @return nanovdb::math::Mat3 3x3 rotation matrix equivalent to the quaternion -template -inline __device__ nanovdb::math::Mat3 -quaternionToRotationMatrix(nanovdb::math::Vec4 const &quat) { - T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; - // normalize - T inverseNormalization = rsqrt(x * x + y * y + z * z + w * w); - x *= inverseNormalization; - y *= inverseNormalization; - z *= inverseNormalization; - w *= inverseNormalization; - T x2 = x * x, y2 = y * y, z2 = z * z; - T xy = x * y, xz = x * z, yz = y * z; - T wx = w * x, wy = w * y, wz = w * z; - return nanovdb::math::Mat3((1.f - 2.f * (y2 + z2)), - (2.f * (xy - wz)), - (2.f * (xz + wy)), // 1st row - (2.f * (xy + wz)), - (1.f - 2.f * (x2 + z2)), - (2.f * (yz - wx)), // 2nd row - (2.f * (xz - wy)), - (2.f * (yz + wx)), - (1.f - 2.f * (x2 + y2)) // 3rd row - ); -} - -/// @brief Normalizes a quaternion to unit length -/// -/// This function normalizes a quaternion to unit length. If the quaternion is zero, it is set to -/// the identity quaternion. -/// -/// @param q Input quaternion [w,x,y,z] -/// @return nanovdb::math::Vec4 Normalized quaternion -template -inline __host__ __device__ nanovdb::math::Vec4 -normalizeQuaternionSafe(nanovdb::math::Vec4 q) { - const T n2 = q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]; - if (n2 > T(0)) { - const T invN = T(1) / sqrt(n2); - q[0] *= invN; - q[1] *= invN; - q[2] *= invN; - q[3] *= invN; - } else { - q[0] = T(1); - q[1] = q[2] = q[3] = T(0); - } - return q; -} - -/// @brief Interpolates between two quaternions using normalized linear interpolation along the -/// shortest path -/// -/// This function interpolates between two quaternions using normalized linear interpolation along -/// the shortest path. -/// -/// @param q0 First quaternion [w,x,y,z] -/// @param q1 Second quaternion [w,x,y,z] -/// @param u Interpolation factor in [0,1] -/// @return nanovdb::math::Vec4 Interpolated quaternion -template -inline __host__ __device__ nanovdb::math::Vec4 -nlerpQuaternionShortestPath(const nanovdb::math::Vec4 &q0, - nanovdb::math::Vec4 q1, - const T u) { - // Ensure shortest arc (q and -q represent the same rotation). - T dot = q0[0] * q1[0] + q0[1] * q1[1] + q0[2] * q1[2] + q0[3] * q1[3]; - if (dot < T(0)) { - q1[0] = -q1[0]; - q1[1] = -q1[1]; - q1[2] = -q1[2]; - q1[3] = -q1[3]; - } - - const T s = T(1) - u; - return normalizeQuaternionSafe(nanovdb::math::Vec4(s * q0[0] + u * q1[0], - s * q0[1] + u * q1[1], - s * q0[2] + u * q1[2], - s * q0[3] + u * q1[3])); -} - -/// @brief Computes the vector-Jacobian product for quaternion to rotation matrix transformation -/// -/// This function computes the gradient of the loss with respect to a quaternion (dL/dq) -/// given the gradient of the loss with respect to a rotation matrix (dL/dR) that was -/// derived from the quaternion. This is essentially a backwards pass through the quaternion -/// to rotation matrix transformation. -/// -/// The function first normalizes the quaternion, computes the vector-Jacobian product -/// for the normalized quaternion, and then applies the chain rule to get the gradient -/// with respect to the original quaternion. -/// -/// Mathematical details: -/// 1. Normalize quaternion to unit length -/// 2. Compute vector-Jacobian product for rotation matrix derivatives -/// 3. Project gradient to ensure it's orthogonal to the quaternion (preserving unit length) -/// 4. Apply chain rule for normalization -/// -/// @param quat Input quaternion [w,x,y,z] -/// @param dLossDRotation Gradient of loss with respect to rotation matrix (dL/dR) -/// @return nanovdb::math::Vec4 Gradient of loss with respect to quaternion (dL/dq) -template -inline __device__ nanovdb::math::Vec4 -quaternionToRotationMatrixVectorJacobianProduct(const nanovdb::math::Vec4 &quat, - const nanovdb::math::Mat3 &dLossDRotation) { - T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; - // normalize - const T inverseNormalization = rsqrt(x * x + y * y + z * z + w * w); - x *= inverseNormalization; - y *= inverseNormalization; - z *= inverseNormalization; - w *= inverseNormalization; - const nanovdb::math::Vec4 dLossDQuatNormalized( - 2.f * (x * (dLossDRotation[2][1] - dLossDRotation[1][2]) + - y * (dLossDRotation[0][2] - dLossDRotation[2][0]) + - z * (dLossDRotation[1][0] - dLossDRotation[0][1])), - 2.f * (-2.f * x * (dLossDRotation[1][1] + dLossDRotation[2][2]) + - y * (dLossDRotation[1][0] + dLossDRotation[0][1]) + - z * (dLossDRotation[2][0] + dLossDRotation[0][2]) + - w * (dLossDRotation[2][1] - dLossDRotation[1][2])), - 2.f * (x * (dLossDRotation[1][0] + dLossDRotation[0][1]) - - 2.f * y * (dLossDRotation[0][0] + dLossDRotation[2][2]) + - z * (dLossDRotation[2][1] + dLossDRotation[1][2]) + - w * (dLossDRotation[0][2] - dLossDRotation[2][0])), - 2.f * (x * (dLossDRotation[2][0] + dLossDRotation[0][2]) + - y * (dLossDRotation[2][1] + dLossDRotation[1][2]) - - 2.f * z * (dLossDRotation[0][0] + dLossDRotation[1][1]) + - w * (dLossDRotation[1][0] - dLossDRotation[0][1]))); - - const nanovdb::math::Vec4 quatNormalized(w, x, y, z); - return (dLossDQuatNormalized - dLossDQuatNormalized.dot(quatNormalized) * quatNormalized) * - inverseNormalization; -} - -/// @brief Computes gradients of loss with respect to quaternion and scale (or log_scale) parameters -/// -/// This function calculates the vector-Jacobian product for quaternion and scale parameters -/// that were used to generate a covariance matrix. It's used in the backward pass of -/// automatic differentiation when computing gradients through the covariance matrix computation. -/// -/// The covariance matrix is computed as C = M * M^T where M = R * S, with: -/// - R being the rotation matrix derived from the quaternion -/// - S being the diagonal scale matrix -/// -/// The function implements the chain rule to propagate gradients from the covariance matrix -/// back to the quaternion and scale parameters. -/// -/// Mathematical details: -/// 1. For matrix operations D = M * M^T, the gradient follows: -/// dL/dM = (dL/dD + (dL/dD)^T) * M -/// 2. For D = R * S, the gradient follows: -/// dL/dR = (dL/dD) * S^T and dL/dS = R^T * (dL/dD) -/// 3. When ApplyLogScaleChainRule is true and scale = exp(log_scale), chain rule gives: -/// dL/d(log_scale) = dL/d(scale) * scale -/// -/// @tparam T Scalar type (float or double) -/// @tparam ApplyLogScaleChainRule If true, returns dL/d(log_scale) by multiplying dL/d(scale) -/// by scale. If false, returns raw dL/d(scale). Default is true for backward -/// compatibility with callers that pass log_scales. -/// @param quat Input quaternion [w,x,y,z] -/// @param scale Scale parameters [sx,sy,sz]. When ApplyLogScaleChainRule=true, these should -/// be exp(log_scale). -/// @param R Precomputed rotation matrix from the quaternion -/// @param dLossDCovar Gradient of loss with respect to the covariance matrix -/// @return Tuple containing gradients for quaternion and scale (or log_scale) parameters -template -inline __device__ std::tuple, nanovdb::math::Vec3> -quaternionAndScaleToCovarianceVectorJacobianProduct(const nanovdb::math::Vec4 &quat, - const nanovdb::math::Vec3 &scale, - // precompute - const nanovdb::math::Mat3 &R, - // grad outputs - const nanovdb::math::Mat3 &dLossDCovar) { - T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; - T sx = scale[0], sy = scale[1], sz = scale[2]; - - // M = R * S - const nanovdb::math::Mat3 S(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); - const nanovdb::math::Mat3 M = R * S; - - // https://math.stackexchange.com/a/3850121 - // for D = W * X, G = df/dD - // df/dW = G * XT, df/dX = WT * G - // so - // for D = M * Mt, - // df/dM = df/dM + df/dMt = G * M + (Mt * G)t = G * M + Gt * M - const nanovdb::math::Mat3 dLossDM = (dLossDCovar + dLossDCovar.transpose()) * M; - const nanovdb::math::Mat3 dLossDR = dLossDM * S.transpose(); - - // grad for (quat, scale) from covar - const nanovdb::math::Vec4 &dLossDQuat = - quaternionToRotationMatrixVectorJacobianProduct(quat, dLossDR); - - // Row-major dot products for gradients w.r.t. scale - const nanovdb::math::Vec3 dLossDScale( - R[0][0] * dLossDM[0][0] + R[1][0] * dLossDM[1][0] + R[2][0] * dLossDM[2][0], - R[0][1] * dLossDM[0][1] + R[1][1] * dLossDM[1][1] + R[2][1] * dLossDM[2][1], - R[0][2] * dLossDM[0][2] + R[1][2] * dLossDM[1][2] + R[2][2] * dLossDM[2][2]); - - if constexpr (ApplyLogScaleChainRule) { - // Apply chain rule for log_scale: dL/d(log_scale) = dL/d(scale) * scale - // since scale = exp(log_scale), d(scale)/d(log_scale) = scale - return { - dLossDQuat, - nanovdb::math::Vec3(sx * dLossDScale[0], sy * dLossDScale[1], sz * dLossDScale[2])}; - } else { - // Return raw dL/d(scale) without chain rule - return {dLossDQuat, dLossDScale}; - } -} - -/// @brief Convert a quaternion and scale to a covariance matrix -/// -/// This function computes a 3x3 covariance matrix from a quaternion and scale parameters. -/// The covariance matrix represents the shape and orientation of a 3D Gaussian distribution. -/// -/// The computation follows the formula C = M * M^T where M = R * S, with: -/// - R being the rotation matrix derived from the quaternion -/// - S being the diagonal scale matrix -/// -/// This representation allows for efficient transformation of Gaussian distributions -/// in 3D space, where the quaternion controls the orientation and the scale parameters -/// control the extent along each principal axis. -/// -/// @param quat Input quaternion [w,x,y,z] representing rotation -/// @param scale Scale parameters [sx,sy,sz] representing extent along principal axes -/// @return 3x3 covariance matrix representing the Gaussian's shape and orientation -template -inline __device__ nanovdb::math::Mat3 -quaternionAndScaleToCovariance(const nanovdb::math::Vec4 &quat, - const nanovdb::math::Vec3 &scale) { - const nanovdb::math::Mat3 &R = quaternionToRotationMatrix(quat); - // C = R * S * S * Rt - const nanovdb::math::Mat3 S(scale[0], 0.f, 0.f, 0.f, scale[1], 0.f, 0.f, 0.f, scale[2]); - const nanovdb::math::Mat3 M = R * S; - return M * M.transpose(); -} - -/// @brief Adds blur to a 2D covariance matrix and computes compensation factor -/// -/// This function adds a small epsilon value to the diagonal elements of a 2D covariance matrix -/// to ensure numerical stability and prevent degenerate cases. It also computes a compensation -/// factor that can be used to adjust other calculations to account for this added blur. -/// -/// The blur is added by increasing the diagonal elements of the covariance matrix by eps2d. -/// The compensation factor is calculated as the square root of the ratio between the original -/// determinant and the determinant after adding blur, which helps maintain proper normalization -/// of the Gaussian when rendered. -/// -/// @param eps2d Epsilon value to add to diagonal elements of the covariance matrix -/// @param outCovar Input/output 2D covariance matrix that will be modified with added blur -/// @param outCompensation Output compensation factor to adjust for the added blur -/// @return Determinant of the covariance matrix after adding blur -template -inline __device__ T -addBlur(const T eps2d, nanovdb::math::Mat2 &outCovar, T &outCompensation) { - const T det_orig = outCovar[0][0] * outCovar[1][1] - outCovar[0][1] * outCovar[1][0]; - outCovar[0][0] += eps2d; - outCovar[1][1] += eps2d; - const T det_blur = outCovar[0][0] * outCovar[1][1] - outCovar[0][1] * outCovar[1][0]; - outCompensation = sqrt(max(0.f, det_orig / det_blur)); - return det_blur; -} - -/// @brief Computes the gradient of loss with respect to a covariance matrix from blur operations -/// -/// This function implements the vector-Jacobian product calculation for the backward pass -/// of the addBlur operation. It propagates gradients from the compensation factor back -/// to the original covariance matrix before blur was applied. -/// -/// During the forward pass, a blur is added to the covariance matrix by adding an epsilon -/// to its diagonal elements, and a compensation factor is computed to normalize the Gaussian. -/// This function computes how changes in that compensation factor affect the original matrix. -/// -/// The calculations account for: -/// 1. The change in determinant from adding blur -/// 2. The relationship between the compensation factor and the covariance determinant -/// 3. The effect of the epsilon value on the gradient -/// -/// @param eps2d Epsilon value that was added to the diagonal elements -/// @param conic_blur The 2x2 covariance matrix after blur was applied -/// @param compensation The compensation factor computed during forward pass -/// @param dLossDCompensation Gradient of loss with respect to the compensation factor -/// @return 2x2 matrix representing the gradient of loss with respect to the original covariance -template -inline __device__ nanovdb::math::Mat2 -generateBlurVectorJacobianProduct(const T eps2d, - const nanovdb::math::Mat2 &conic_blur, - const T compensation, - const T dLossDCompensation) { - const T det_conic_blur = - conic_blur[0][0] * conic_blur[1][1] - conic_blur[0][1] * conic_blur[1][0]; - const T v_sqr_comp = dLossDCompensation * 0.5 / (compensation + 1e-6); - const T one_minus_sqr_comp = 1 - compensation * compensation; - return v_sqr_comp * - nanovdb::math::Mat2(one_minus_sqr_comp * conic_blur[0][0] - eps2d * det_conic_blur, - one_minus_sqr_comp * conic_blur[0][1], - one_minus_sqr_comp * conic_blur[1][0], - one_minus_sqr_comp * conic_blur[1][1] - eps2d * det_conic_blur); -} - -/// @brief Transform a point from world to camera coordinates -/// -/// This function applies the world-to-camera transformation to convert a point -/// from world coordinates to camera coordinates. The transformation consists of: -/// - A rotation matrix that defines the camera's orientation in world space -/// - A translation vector that defines the camera's position in world space -/// -/// @param camToWorldRotation Rotation matrix from camera to world -/// @param camToWorldTranslation Translation vector from camera to world -/// @param worldSpacePoint The point in world coordinates -/// @return The transformed point in camera coordinates -template -inline __device__ nanovdb::math::Vec3 -transformPointWorldToCam(nanovdb::math::Mat3 const &camToWorldRotation, - nanovdb::math::Vec3 const &camToWorldTranslation, - nanovdb::math::Vec3 const &worldSpacePoint) { - return camToWorldRotation * worldSpacePoint + camToWorldTranslation; -} - -/// @brief Computes gradients for the world-to-camera point transformation -/// -/// This function calculates the vector-Jacobian product (VJP) for the backward pass -/// of the point transformation from world to camera space. It propagates gradients -/// through the transformation p_camera = R * p_world + t. -/// -/// Given upstream gradients with respect to the camera-space point (dL/dp_camera), -/// this function computes the gradients with respect to: -/// 1. The rotation matrix (dL/dR) -/// 2. The translation vector (dL/dt) -/// 3. The original world-space point (dL/dp_world) -/// -/// The implementation follows the chain rule for matrix-vector operations: -/// - dL/dR = (dL/dp_camera) * (p_world)^T -/// - dL/dt = dL/dp_camera -/// - dL/dp_world = R^T * (dL/dp_camera) -/// -/// @param camToWorldRotation Rotation matrix from camera to world -/// @param camToWorldTranslation Translation vector from camera to world -/// @param worldSpacePoint The original point in world coordinates -/// @param dLossDPointCamera Upstream gradient with respect to camera-space point (dL/dp_camera) -/// @return Tuple of gradients (dL/dR, dL/dt, dL/dp_world) -template -inline __device__ std::tuple, nanovdb::math::Vec3, nanovdb::math::Vec3> -transformPointWorldToCamVectorJacobianProduct(const nanovdb::math::Mat3 &camToWorldRotation, - const nanovdb::math::Vec3 &camToWorldTranslation, - const nanovdb::math::Vec3 &worldSpacePoint, - // grad - const nanovdb::math::Vec3 &dLossDPointCamera) { - // for D = W * X, G = df/dD - // df/dW = G * XT, df/dX = WT * G - return {dLossDPointCamera.outer(worldSpacePoint), - dLossDPointCamera, - camToWorldRotation.transpose() * dLossDPointCamera}; -} - -/// @brief Transform a covariance matrix from world to camera coordinates -/// -/// This function transforms a 3x3 covariance matrix from world coordinate space to -/// camera coordinate space using a rotation matrix. The transformation follows -/// the sandwich rule for covariance matrices: -/// -/// covar_camera = R * covar_world * R^T -/// -/// where R is the rotation matrix from world to camera space and R^T is its transpose. -/// -/// This transformation preserves the properties of the covariance matrix while -/// correctly reorienting it according to the camera's viewpoint. It's commonly -/// used when projecting 3D Gaussian distributions to camera space for rendering -/// or further processing. -/// -/// @param R Rotation matrix from world to camera space -/// @param covar Covariance matrix in world coordinates -/// @return Transformed covariance matrix in camera coordinates -template -inline __device__ nanovdb::math::Mat3 -transformCovarianceWorldToCam(nanovdb::math::Mat3 const &R, - nanovdb::math::Mat3 const &covar) { - return R * covar * R.transpose(); -} - -/// @brief Computes gradients for the world-to-camera covariance transformation -/// -/// This function calculates the vector-Jacobian product (VJP) for the backward pass -/// of the covariance matrix transformation from world to camera space. It propagates -/// gradients through the transformation covar_camera = R * covar_world * R^T. -/// -/// Given upstream gradients with respect to the camera-space covariance (dL/dcovar_camera), -/// this function computes the gradients with respect to: -/// 1. The rotation matrix (dL/dR) -/// 2. The original world-space covariance matrix (dL/dcovar_world) -/// -/// The implementation applies the chain rule for matrix operations: -/// - dL/dR = dL/dcovar_camera * R * covar_world^T + (dL/dcovar_camera)^T * R * covar_world -/// - dL/dcovar_world = R^T * dL/dcovar_camera * R -/// -/// @param R Rotation matrix from world to camera -/// @param covar World-space covariance matrix -/// @param dLossDCovarCamera Upstream gradient with respect to camera-space covariance -/// @return Tuple of gradients (dL/dR, dL/dcovar_world) -template -inline __device__ std::tuple, nanovdb::math::Mat3> -transformCovarianceWorldToCamVectorJacobianProduct( - // fwd inputs - const nanovdb::math::Mat3 &R, - const nanovdb::math::Mat3 &covar, - // grad outputs - const nanovdb::math::Mat3 &dLossDCovarCamera) { - // for D = W * X * WT, G = df/dD - // df/dX = WT * G * W - // df/dW - // = G * (X * WT)T + ((W * X)T * G)T - // = G * W * XT + (XT * WT * G)T - // = G * W * XT + GT * W * X - return {dLossDCovarCamera * R * covar.transpose() + dLossDCovarCamera.transpose() * R * covar, - R.transpose() * dLossDCovarCamera * R}; -} - -/// @brief Computes gradient for matrix inverse operation -/// -/// This function calculates the vector-Jacobian product (VJP) for the backward pass -/// of a matrix inverse operation. Given the inverse of a matrix (P = M^-1) and -/// the upstream gradient with respect to that inverse (dL/dP), it computes the -/// gradient with respect to the original matrix (dL/dM). -/// -/// The mathematical formula used is: -/// dL/dM = -P * dL/dP * P = -M^-1 * dL/dP * M^-1 -/// -/// This calculation is derived from the differential of matrix inverse: -/// d(M^-1) = -M^-1 * dM * M^-1 -/// -/// @tparam T Matrix type that supports multiplication operations -/// @param MInv The inverse matrix (M^-1) -/// @param dLossDMInv Upstream gradient with respect to the inverse matrix (dL/dP) -/// @return Gradient with respect to the original matrix (dL/dM) -template -inline __device__ T -inverseVectorJacobianProduct(const T &MInv, const T &dLossDMInv) { - // P = M^-1 - // df/dM = -P * df/dP * P - return -MInv * dLossDMInv * MInv; -} - -using tilePixelMaskAccessor = fvdb::TorchRAcc64; -static constexpr uint32_t sTileBitmaskBitsPerWord = 64; - -inline uint32_t -numWordsPerTileBitmask(const uint32_t tileSideLength) { - return (tileSideLength * tileSideLength + sTileBitmaskBitsPerWord - 1) / - sTileBitmaskBitsPerWord; -} - -inline __device__ uint32_t -bitmaskWordIndex(const uint32_t bitIndex) { - return bitIndex / sTileBitmaskBitsPerWord; -} -inline __device__ uint32_t -bitmaskBitIndex(const uint32_t bitIndex) { - return bitIndex % sTileBitmaskBitsPerWord; -} - -inline __device__ bool -tilePixelActive(tilePixelMaskAccessor const &tilePixelMask, - const uint32_t tileSideLength, - const uint32_t tileId, - const uint32_t iInTile, - const uint32_t jInTile) { - const uint32_t bitIndex = iInTile * tileSideLength + jInTile; - return tilePixelMask[tileId][bitmaskWordIndex(bitIndex)] & (1ull << bitmaskBitIndex(bitIndex)); -} - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANUTILS_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianUtils.h b/src/fvdb/detail/ops/gsplat/GaussianUtils.h deleted file mode 100644 index b11ba9718..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianUtils.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANUTILS_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANUTILS_H - -#include - -#include - -namespace fvdb { -namespace detail { -namespace ops { - -// Given a contiguous tensor with dimensions [C, ...] where C is the number of cameras, we prefetch -// the slices [cameraOffset : cameraCount, ...] to the specified device ordered on the input stream. -void perCameraPrefetchAsync(const torch::Tensor &tensor, - uint32_t cameraOffset, - uint32_t cameraCount, - int deviceId, - cudaStream_t stream); - -// Given a list of contiguous tensors each with dimensions [C, ...] where C is the number of -// cameras, we prefetch the slices [cameraOffset : cameraCount, ...] to the specified device ordered -// on the input stream in a single asynchronous batched prefetch call. -void perCameraPrefetchBatchAsync(const torch::TensorList &tensors, - uint32_t cameraOffset, - uint32_t cameraCount, - int deviceId, - cudaStream_t stream); - -} // namespace ops -} // namespace detail -} // namespace fvdb - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANUTILS_H diff --git a/src/fvdb/detail/ops/gsplat/GaussianVectorTypes.cuh b/src/fvdb/detail/ops/gsplat/GaussianVectorTypes.cuh deleted file mode 100644 index ada481101..000000000 --- a/src/fvdb/detail/ops/gsplat/GaussianVectorTypes.cuh +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANVECTORTYPES_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANVECTORTYPES_CUH - -#include - -#include -#include -#include - -#include - -/* -Wrap 2D vector types for different scalar types -*/ -template struct Vec2Type {}; - -template <> struct Vec2Type { - using type = char2; -}; - -template <> struct Vec2Type { - using type = uchar2; -}; - -template <> struct Vec2Type { - using type = short2; -}; - -template <> struct Vec2Type { - using type = ushort2; -}; - -template <> struct Vec2Type { - using type = int2; -}; - -template <> struct Vec2Type { - using type = uint2; -}; - -template <> struct Vec2Type { - using type = long2; -}; - -template <> struct Vec2Type { - using type = ulong2; -}; - -template <> struct Vec2Type { - using type = float2; -}; - -template <> struct Vec2Type { - using type = double2; -}; - -/* -Wrap 3D vector types for different scalar types -*/ -template struct Vec3Type {}; - -template <> struct Vec3Type { - using type = char3; -}; - -template <> struct Vec3Type { - using type = uchar3; -}; - -template <> struct Vec3Type { - using type = short3; -}; - -template <> struct Vec3Type { - using type = ushort3; -}; - -template <> struct Vec3Type { - using type = int3; -}; - -template <> struct Vec3Type { - using type = uint3; -}; - -template <> struct Vec3Type { - using type = long3; -}; - -template <> struct Vec3Type { - using type = ulong3; -}; - -template <> struct Vec3Type { - using type = float3; -}; - -template <> struct Vec3Type { - using type = double3; -}; - -/* -Wrap scalar types, and upcast half precision to float32 -*/ -template struct OpType { - typedef T type; -}; - -template <> struct OpType<__nv_bfloat16> { - typedef float type; -}; - -template <> struct OpType<__half> { - typedef float type; -}; - -template <> struct OpType { - typedef float type; -}; - -template <> struct OpType { - typedef float type; -}; - -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANVECTORTYPES_CUH diff --git a/src/fvdb/detail/utils/cuda/Alignment.cuh b/src/fvdb/detail/utils/cuda/Alignment.cuh new file mode 100644 index 000000000..3e17f72e0 --- /dev/null +++ b/src/fvdb/detail/utils/cuda/Alignment.cuh @@ -0,0 +1,38 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_UTILS_CUDA_ALIGNMENT_CUH +#define FVDB_DETAIL_UTILS_CUDA_ALIGNMENT_CUH + +#include +#include + +namespace fvdb { +namespace detail { + +/// Round up a byte count to the next multiple of a power-of-two alignment. +#ifdef __CUDACC__ +inline __host__ __device__ +#else +inline +#endif +constexpr size_t +alignUpBytes(const size_t value, const size_t alignment) { + return (value + alignment - 1) & ~(alignment - 1); +} + +/// Advance an address (as uintptr_t) to the next multiple of a power-of-two alignment. +#ifdef __CUDACC__ +inline __host__ __device__ +#else +inline +#endif +constexpr uintptr_t +alignUpAddress(const uintptr_t value, const size_t alignment) { + return (value + alignment - 1) & ~(alignment - 1); +} + +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_UTILS_CUDA_ALIGNMENT_CUH diff --git a/src/fvdb/detail/utils/cuda/BinSearch.cuh b/src/fvdb/detail/utils/cuda/BinSearch.cuh new file mode 100644 index 000000000..671e6470f --- /dev/null +++ b/src/fvdb/detail/utils/cuda/BinSearch.cuh @@ -0,0 +1,52 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_UTILS_CUDA_BINSEARCH_CUH +#define FVDB_DETAIL_UTILS_CUDA_BINSEARCH_CUH + +#include + +namespace fvdb { +namespace detail { + +/// @brief Binary search on a sorted array for the insertion point of a value. +/// +/// Finds the last index where `arr[i] <= val` in a non-decreasing array. +/// +/// @pre `len > 0` and `val >= arr[0]`. All current call sites satisfy both +/// preconditions (tile-offset lookups are always within bounds). +/// Passing `len == 0` or `val < arr[0]` returns 0 as a safe fallback, +/// but callers should not rely on this. +/// +/// Time complexity: O(log n). +/// +/// @tparam T Element type (must support comparison operators). +/// @param arr Pointer to the sorted array. +/// @param len Length of the array. +/// @param val Value to search for. +/// @return Index of the last element <= val, or 0 if preconditions are violated. +template +inline __device__ uint32_t +binSearch(const T *arr, const uint32_t len, const T val) { + if (len == 0) { + return 0; + } + uint32_t low = 0, high = len - 1; + while (low <= high) { + const uint32_t mid = (low + high) / 2; + if (arr[mid] <= val) { + low = mid + 1; + } else { + if (mid == 0) { + return 0; + } + high = mid - 1; + } + } + return low - 1; +} + +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_UTILS_CUDA_BINSEARCH_CUH diff --git a/src/fvdb/detail/utils/cuda/CopyCoords.cuh b/src/fvdb/detail/utils/cuda/CopyCoords.cuh new file mode 100644 index 000000000..6f80f6abd --- /dev/null +++ b/src/fvdb/detail/utils/cuda/CopyCoords.cuh @@ -0,0 +1,113 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_UTILS_CUDA_COPYCOORDS_CUH +#define FVDB_DETAIL_UTILS_CUDA_COPYCOORDS_CUH + +#include + +#include + +namespace fvdb { +namespace detail { + +/// Write all ijk coordinates within a bounding box (offset by ijk0) into output tensors. +__device__ inline void +copyCoords(const fvdb::JIdxType bidx, + const int64_t base, + const nanovdb::Coord &ijk0, + const nanovdb::CoordBBox &bbox, + TorchRAcc64 outIJK, + TorchRAcc64 outIJKBIdx) { + static_assert(sizeof(nanovdb::Coord) == 3 * sizeof(int32_t)); + nanovdb::Coord ijk; + int32_t count = 0; + for (int di = bbox.min()[0]; di <= bbox.max()[0]; di += 1) { + for (int dj = bbox.min()[1]; dj <= bbox.max()[1]; dj += 1) { + for (int dk = bbox.min()[2]; dk <= bbox.max()[2]; dk += 1) { + ijk = ijk0 + nanovdb::Coord(di, dj, dk); + outIJK[base + count][0] = ijk[0]; + outIJK[base + count][1] = ijk[1]; + outIJK[base + count][2] = ijk[2]; + outIJKBIdx[base + count] = bidx; + count += 1; + } + } + } +} + +/// Overload taking a size Coord instead of a CoordBBox (builds bbox as [0, size-1]). +__device__ inline void +copyCoords(const fvdb::JIdxType bidx, + const int64_t base, + const nanovdb::Coord size, + const nanovdb::Coord &ijk0, + TorchRAcc64 outIJK, + TorchRAcc64 outIJKBIdx) { + return copyCoords(bidx, + base, + ijk0, + nanovdb::CoordBBox(nanovdb::Coord(0), size - nanovdb::Coord(1)), + outIJK, + outIJKBIdx); +} + +/// Write a single coordinate if all voxels in the bbox (offset by ijk0) are active. +__device__ inline void +copyCoordsWithoutBorder( + const typename nanovdb::DefaultReadAccessor gridAccessor, + const fvdb::JIdxType bidx, + const int64_t base, + const nanovdb::Coord &ijk0, + const nanovdb::CoordBBox &bbox, + const TorchRAcc64 packInfoBase, + TorchRAcc64 outIJK, + TorchRAcc64 outIJKBIdx) { + static_assert(sizeof(nanovdb::Coord) == 3 * sizeof(int32_t)); + nanovdb::Coord ijk; + bool active = true; + for (int di = bbox.min()[0]; di <= bbox.max()[0]; di += 1) { + for (int dj = bbox.min()[1]; dj <= bbox.max()[1]; dj += 1) { + for (int dk = bbox.min()[2]; dk <= bbox.max()[2]; dk += 1) { + ijk = ijk0 + nanovdb::Coord(di, dj, dk); + active = active && gridAccessor.isActive(ijk); + } + } + } + if (active) { + int64_t outBase = packInfoBase[base]; + outIJK[outBase][0] = ijk0[0]; + outIJK[outBase][1] = ijk0[1]; + outIJK[outBase][2] = ijk0[2]; + outIJKBIdx[outBase] = bidx; + } +} + +/// Count 1 if all voxels in the bbox (offset by ijk0) are active, 0 otherwise. +__device__ inline void +countCoordsWithoutBorder( + const typename nanovdb::DefaultReadAccessor gridAccessor, + const fvdb::JIdxType bidx, + const int64_t base, + const nanovdb::Coord &ijk0, + const nanovdb::CoordBBox &bbox, + TorchRAcc64 outCounter) { + static_assert(sizeof(nanovdb::Coord) == 3 * sizeof(int32_t)); + nanovdb::Coord ijk; + bool active = true; + for (int di = bbox.min()[0]; di <= bbox.max()[0]; di += 1) { + for (int dj = bbox.min()[1]; dj <= bbox.max()[1]; dj += 1) { + for (int dk = bbox.min()[2]; dk <= bbox.max()[2]; dk += 1) { + ijk = ijk0 + nanovdb::Coord(di, dj, dk); + active = active && gridAccessor.isActive(ijk); + } + } + } + + outCounter[base] = active ? 1 : 0; +} + +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_UTILS_CUDA_COPYCOORDS_CUH diff --git a/src/fvdb/detail/utils/cuda/CubWrapper.cuh b/src/fvdb/detail/utils/cuda/CubWrapper.cuh new file mode 100644 index 000000000..522a93ef1 --- /dev/null +++ b/src/fvdb/detail/utils/cuda/CubWrapper.cuh @@ -0,0 +1,23 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_UTILS_CUDA_CUBWRAPPER_CUH +#define FVDB_DETAIL_UTILS_CUDA_CUBWRAPPER_CUH + +#include +#include + +/// Convenience wrapper for CUB API calls that require a temp-storage probe + allocation +/// pattern. Usage: +/// CUB_WRAPPER(cub::DeviceRadixSort::SortPairs, keys, values, N, 0, sizeof(int64_t) * 8, stream); +#define CUB_WRAPPER(func, ...) \ + do { \ + size_t tempStorageBytes = 0; \ + C10_CUDA_CHECK(func(nullptr, tempStorageBytes, __VA_ARGS__)); \ + auto &cachingAllocator = *::c10::cuda::CUDACachingAllocator::get(); \ + auto tempStorage = \ + tempStorageBytes > 0 ? cachingAllocator.allocate(tempStorageBytes) : ::c10::DataPtr(); \ + C10_CUDA_CHECK(func(tempStorage.get(), tempStorageBytes, __VA_ARGS__)); \ + } while (false) + +#endif // FVDB_DETAIL_UTILS_CUDA_CUBWRAPPER_CUH diff --git a/src/fvdb/detail/utils/cuda/OpType.cuh b/src/fvdb/detail/utils/cuda/OpType.cuh new file mode 100644 index 000000000..6590f2c4a --- /dev/null +++ b/src/fvdb/detail/utils/cuda/OpType.cuh @@ -0,0 +1,41 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_UTILS_CUDA_OPTYPE_CUH +#define FVDB_DETAIL_UTILS_CUDA_OPTYPE_CUH + +#include +#include + +#include +#include + +namespace fvdb { +namespace detail { + +/// Scalar type promotion trait: upcasts half-precision types to float for +/// numerically stable computation while leaving other types unchanged. +template struct OpType { + using type = T; +}; + +template <> struct OpType<__nv_bfloat16> { + using type = float; +}; + +template <> struct OpType<__half> { + using type = float; +}; + +template <> struct OpType { + using type = float; +}; + +template <> struct OpType { + using type = float; +}; + +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_UTILS_CUDA_OPTYPE_CUH diff --git a/src/fvdb/detail/utils/cuda/Prefetch.cuh b/src/fvdb/detail/utils/cuda/Prefetch.cuh new file mode 100644 index 000000000..300ee89f2 --- /dev/null +++ b/src/fvdb/detail/utils/cuda/Prefetch.cuh @@ -0,0 +1,130 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_UTILS_CUDA_PREFETCH_CUH +#define FVDB_DETAIL_UTILS_CUDA_PREFETCH_CUH + +#include + +#include +#include +#include + +#include + +#include +#include +#include + +namespace fvdb { +namespace detail { + +/// Prefetch a contiguous batch of slices from tensor dimension 0 to a GPU. +/// +/// Given a contiguous tensor with dimensions [C, ...] where C is the number of +/// cameras/batches, prefetch slices [offset, offset+count) to the specified device. +inline void +perCameraPrefetchBatchAsync(const torch::TensorList &tensors, + uint32_t cameraOffset, + uint32_t cameraCount, + int deviceId, + cudaStream_t stream) { + TORCH_CHECK(stream, "cudaMemPrefetchBatchAsync does not support the default stream"); +#if (CUDART_VERSION < 13000) + for (size_t i = 0; i < tensors.size(); ++i) { + const auto &tensor = tensors[i]; + TORCH_CHECK(tensor.is_contiguous(), "Tensor to prefetch is not contiguous"); + TORCH_CHECK(cameraOffset + cameraCount <= tensor.size(0), + "Tensor does not have a batched first dimension"); + size_t scalarSize = c10::elementSize(tensor.scalar_type()); + C10_CUDA_CHECK( + nanovdb::util::cuda::memPrefetchAsync(static_cast(tensor.data_ptr()) + + cameraOffset * tensor.stride(0) * scalarSize, + cameraCount * tensor.stride(0) * scalarSize, + deviceId, + stream)); + } +#else + std::vector prefetchPointers; + std::vector prefetchSizes; + const cudaMemLocation location = {cudaMemLocationTypeDevice, deviceId}; + std::vector prefetchLocations = {location}; + + for (size_t i = 0; i < tensors.size(); ++i) { + const auto &tensor = tensors[i]; + TORCH_CHECK(tensor.is_contiguous(), "Tensor to prefetch is not contiguous"); + TORCH_CHECK(cameraOffset + cameraCount <= tensor.size(0), + "Tensor does not have a batched first dimension"); + size_t scalarSize = c10::elementSize(tensor.scalar_type()); + prefetchPointers.emplace_back(static_cast(tensor.data_ptr()) + + cameraOffset * tensor.stride(0) * scalarSize); + prefetchSizes.emplace_back(cameraCount * tensor.stride(0) * scalarSize); + } + std::vector prefetchLocationIndices(prefetchPointers.size(), 0); + C10_CUDA_CHECK(cudaMemPrefetchBatchAsync(prefetchPointers.data(), + prefetchSizes.data(), + prefetchPointers.size(), + prefetchLocations.data(), + prefetchLocationIndices.data(), + prefetchLocations.size(), + 0, + stream)); +#endif +} + +/// Prefetch float tensors by raw element offset and count. +/// +/// Used by FusedSSIM where the tensors are flat float buffers rather than +/// camera-batched tensors. +inline void +imagePrefetchBatchAsync(const torch::TensorList &tensors, + int localElementOffset, + int localElementCount, + int deviceId, + cudaStream_t stream) { + TORCH_CHECK(stream, "cudaMemPrefetchBatchAsync does not support the default stream"); + for (size_t i = 0; i < tensors.size(); ++i) { + TORCH_CHECK(tensors[i].scalar_type() == torch::kFloat32, + "imagePrefetchBatchAsync expects float32 tensors, got ", + tensors[i].scalar_type(), + " for tensor ", + i); + } +#if (CUDART_VERSION < 13000) + for (size_t i = 0; i < tensors.size(); ++i) { + const auto &tensor = tensors[i]; + TORCH_CHECK(tensor.is_contiguous(), "Tensor to prefetch is not contiguous"); + C10_CUDA_CHECK( + nanovdb::util::cuda::memPrefetchAsync(tensor.data_ptr() + localElementOffset, + localElementCount * sizeof(float), + deviceId, + stream)); + } +#else + std::vector prefetchPointers; + std::vector prefetchSizes; + cudaMemLocation location = {cudaMemLocationTypeDevice, deviceId}; + std::vector prefetchLocations = {location}; + + for (size_t i = 0; i < tensors.size(); ++i) { + const auto &tensor = tensors[i]; + TORCH_CHECK(tensor.is_contiguous(), "Tensor to prefetch is not contiguous"); + prefetchPointers.emplace_back(tensor.data_ptr() + localElementOffset); + prefetchSizes.emplace_back(localElementCount * sizeof(float)); + } + std::vector prefetchLocationIndices(prefetchPointers.size(), 0); + C10_CUDA_CHECK(cudaMemPrefetchBatchAsync(prefetchPointers.data(), + prefetchSizes.data(), + prefetchPointers.size(), + prefetchLocations.data(), + prefetchLocationIndices.data(), + prefetchLocations.size(), + 0, + stream)); +#endif +} + +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_UTILS_CUDA_PREFETCH_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianWarpUtils.cuh b/src/fvdb/detail/utils/cuda/WarpReduce.cuh similarity index 73% rename from src/fvdb/detail/ops/gsplat/GaussianWarpUtils.cuh rename to src/fvdb/detail/utils/cuda/WarpReduce.cuh index 10e24b294..d958ecb9e 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianWarpUtils.cuh +++ b/src/fvdb/detail/utils/cuda/WarpReduce.cuh @@ -1,10 +1,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANWARPUTILS_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANWARPUTILS_CUH - -#include +#ifndef FVDB_DETAIL_UTILS_CUDA_WARPREDUCE_CUH +#define FVDB_DETAIL_UTILS_CUDA_WARPREDUCE_CUH #include @@ -13,7 +11,6 @@ namespace fvdb { namespace detail { -namespace ops { template inline __device__ void @@ -58,13 +55,6 @@ warpSumMut(ScalarT &val, WarpT &warp) { val = cooperative_groups::reduce(warp, val, cooperative_groups::plus()); } -template -inline __device__ void -warpSumMut(typename Vec2Type::type &val, WarpT &warp) { - val.x = cooperative_groups::reduce(warp, val.x, cooperative_groups::plus()); - val.y = cooperative_groups::reduce(warp, val.y, cooperative_groups::plus()); -} - template inline __device__ void warpSumMut(nanovdb::math::Vec2 &val, WarpT &warp) { @@ -72,14 +62,6 @@ warpSumMut(nanovdb::math::Vec2 &val, WarpT &warp) { val[1] = cooperative_groups::reduce(warp, val[1], cooperative_groups::plus()); } -template -inline __device__ void -warpSumMut(typename Vec3Type::type &val, WarpT &warp) { - val.x = cooperative_groups::reduce(warp, val.x, cooperative_groups::plus()); - val.y = cooperative_groups::reduce(warp, val.y, cooperative_groups::plus()); - val.z = cooperative_groups::reduce(warp, val.z, cooperative_groups::plus()); -} - template inline __device__ void warpSumMut(nanovdb::math::Vec3 &val, WarpT &warp) { @@ -105,8 +87,7 @@ warpSumMut(ScalarT *val, size_t nDims, WarpT &warp) { } } -} // namespace ops } // namespace detail } // namespace fvdb -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANWARPUTILS_CUH +#endif // FVDB_DETAIL_UTILS_CUDA_WARPREDUCE_CUH diff --git a/src/fvdb/detail/utils/cuda/math/AffineTransform.cuh b/src/fvdb/detail/utils/cuda/math/AffineTransform.cuh new file mode 100644 index 000000000..9fefbc8a8 --- /dev/null +++ b/src/fvdb/detail/utils/cuda/math/AffineTransform.cuh @@ -0,0 +1,64 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_UTILS_CUDA_MATH_AFFINETRANSFORM_CUH +#define FVDB_DETAIL_UTILS_CUDA_MATH_AFFINETRANSFORM_CUH + +#include + +#include + +namespace fvdb { +namespace detail { + +/// Transform a point from world to camera coordinates: p_cam = R * p_world + t. +template +inline __device__ nanovdb::math::Vec3 +transformPointWorldToCam(nanovdb::math::Mat3 const &worldToCamRotation, + nanovdb::math::Vec3 const &worldToCamTranslation, + nanovdb::math::Vec3 const &worldSpacePoint) { + return worldToCamRotation * worldSpacePoint + worldToCamTranslation; +} + +/// VJP for transformPointWorldToCam. Returns (dL/dR, dL/dt, dL/dp_world). +template +inline __device__ std::tuple, nanovdb::math::Vec3, nanovdb::math::Vec3> +transformPointWorldToCamVectorJacobianProduct(const nanovdb::math::Mat3 &worldToCamRotation, + const nanovdb::math::Vec3 &worldToCamTranslation, + const nanovdb::math::Vec3 &worldSpacePoint, + const nanovdb::math::Vec3 &dLossDPointCamera) { + return {dLossDPointCamera.outer(worldSpacePoint), + dLossDPointCamera, + worldToCamRotation.transpose() * dLossDPointCamera}; +} + +/// Transform a covariance matrix from world to camera: covar_cam = R * covar_world * R^T. +template +inline __device__ nanovdb::math::Mat3 +transformCovarianceWorldToCam(nanovdb::math::Mat3 const &R, + nanovdb::math::Mat3 const &covar) { + return R * covar * R.transpose(); +} + +/// VJP for transformCovarianceWorldToCam. Returns (dL/dR, dL/dcovar_world). +template +inline __device__ std::tuple, nanovdb::math::Mat3> +transformCovarianceWorldToCamVectorJacobianProduct( + const nanovdb::math::Mat3 &R, + const nanovdb::math::Mat3 &covar, + const nanovdb::math::Mat3 &dLossDCovarCamera) { + return {dLossDCovarCamera * R * covar.transpose() + dLossDCovarCamera.transpose() * R * covar, + R.transpose() * dLossDCovarCamera * R}; +} + +/// VJP for matrix inverse: dL/dM = -M^{-1} * dL/dM^{-1} * M^{-1}. +template +inline __device__ T +inverseVectorJacobianProduct(const T &MInv, const T &dLossDMInv) { + return -MInv * dLossDMInv * MInv; +} + +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_UTILS_CUDA_MATH_AFFINETRANSFORM_CUH diff --git a/src/fvdb/detail/utils/cuda/math/Rotation.cuh b/src/fvdb/detail/utils/cuda/math/Rotation.cuh new file mode 100644 index 000000000..c78995d64 --- /dev/null +++ b/src/fvdb/detail/utils/cuda/math/Rotation.cuh @@ -0,0 +1,272 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_UTILS_CUDA_MATH_ROTATION_CUH +#define FVDB_DETAIL_UTILS_CUDA_MATH_ROTATION_CUH + +#include + +namespace fvdb { +namespace detail { + +/// Safely normalize a 3D vector. +/// +/// Returns `v / ||v||` when `||v|| > 0`, otherwise returns zero. +template +inline __device__ nanovdb::math::Vec3 +normalizeSafe(const nanovdb::math::Vec3 &v) { + const T n2 = v.dot(v); + if (n2 > T(0)) { + return v * (T(1) / sqrt(n2)); + } + return nanovdb::math::Vec3(T(0), T(0), T(0)); +} + +/// Vector-Jacobian product for `y = normalizeSafe(x)`. +/// +/// Given upstream gradient `v_y = dL/dy`, returns `dL/dx`. +template +inline __device__ nanovdb::math::Vec3 +normalizeSafeVJP(const nanovdb::math::Vec3 &x, const nanovdb::math::Vec3 &v_y) { + const T n2 = x.dot(x); + if (!(n2 > T(0))) { + return nanovdb::math::Vec3(T(0), T(0), T(0)); + } + const T n = sqrt(n2); + const T invn = T(1) / n; + const T invn3 = invn * invn * invn; + const T xdotv = x.dot(v_y); + return v_y * invn - x * (xdotv * invn3); +} + +/// Clamp a scalar to [0, 1]. +template +inline __device__ T +clamp01(const T x) { + return x < T(0) ? T(0) : (x > T(1) ? T(1) : x); +} + +/// @brief Converts a 3x3 rotation matrix to a quaternion [w,x,y,z]. +/// +/// Uses a branch-based algorithm for numerical robustness. Degenerate inputs +/// fall back to the identity quaternion. +template +__host__ __device__ nanovdb::math::Vec4 +rotationMatrixToQuaternion(const nanovdb::math::Mat3 &R) { + T trace = R[0][0] + R[1][1] + R[2][2]; + T x, y, z, w; + + const T s_min = (sizeof(T) == sizeof(float)) ? T(1e-8) : T(1e-12); + + if (trace > 0) { + T t = trace + T(1); + t = (t > T(0)) ? t : T(0); + T s = sqrt(t) * T(2); + if (!(s > s_min)) { + w = T(1); + x = y = z = T(0); + } else { + w = T(0.25) * s; + x = (R[2][1] - R[1][2]) / s; + y = (R[0][2] - R[2][0]) / s; + z = (R[1][0] - R[0][1]) / s; + } + } else if ((R[0][0] > R[1][1]) && (R[0][0] > R[2][2])) { + T t = T(1) + R[0][0] - R[1][1] - R[2][2]; + t = (t > T(0)) ? t : T(0); + T s = sqrt(t) * T(2); + if (!(s > s_min)) { + w = T(1); + x = y = z = T(0); + } else { + w = (R[2][1] - R[1][2]) / s; + x = T(0.25) * s; + y = (R[0][1] + R[1][0]) / s; + z = (R[0][2] + R[2][0]) / s; + } + } else if (R[1][1] > R[2][2]) { + T t = T(1) + R[1][1] - R[0][0] - R[2][2]; + t = (t > T(0)) ? t : T(0); + T s = sqrt(t) * T(2); + if (!(s > s_min)) { + w = T(1); + x = y = z = T(0); + } else { + w = (R[0][2] - R[2][0]) / s; + x = (R[0][1] + R[1][0]) / s; + y = T(0.25) * s; + z = (R[1][2] + R[2][1]) / s; + } + } else { + T t = T(1) + R[2][2] - R[0][0] - R[1][1]; + t = (t > T(0)) ? t : T(0); + T s = sqrt(t) * T(2); + if (!(s > s_min)) { + w = T(1); + x = y = z = T(0); + } else { + w = (R[1][0] - R[0][1]) / s; + x = (R[0][2] + R[2][0]) / s; + y = (R[1][2] + R[2][1]) / s; + z = T(0.25) * s; + } + } + + const T norm2 = (w * w + x * x + y * y + z * z); + if (norm2 > T(0)) { + const T invNorm = T(1) / sqrt(norm2); + w *= invNorm; + x *= invNorm; + y *= invNorm; + z *= invNorm; + } else { + w = T(1); + x = y = z = T(0); + } + + if (w < T(0)) { + w = -w; + x = -x; + y = -y; + z = -z; + } + return nanovdb::math::Vec4(w, x, y, z); +} + +/// @brief Converts a quaternion [w,x,y,z] to a 3x3 rotation matrix. +template +inline __device__ nanovdb::math::Mat3 +quaternionToRotationMatrix(nanovdb::math::Vec4 const &quat) { + T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; + T inverseNormalization = rsqrt(x * x + y * y + z * z + w * w); + x *= inverseNormalization; + y *= inverseNormalization; + z *= inverseNormalization; + w *= inverseNormalization; + T x2 = x * x, y2 = y * y, z2 = z * z; + T xy = x * y, xz = x * z, yz = y * z; + T wx = w * x, wy = w * y, wz = w * z; + return nanovdb::math::Mat3((1.f - 2.f * (y2 + z2)), + (2.f * (xy - wz)), + (2.f * (xz + wy)), + (2.f * (xy + wz)), + (1.f - 2.f * (x2 + z2)), + (2.f * (yz - wx)), + (2.f * (xz - wy)), + (2.f * (yz + wx)), + (1.f - 2.f * (x2 + y2))); +} + +/// @brief Normalizes a quaternion to unit length. +/// +/// If the quaternion is zero, returns the identity quaternion. +template +inline __host__ __device__ nanovdb::math::Vec4 +normalizeQuaternionSafe(nanovdb::math::Vec4 q) { + const T n2 = q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]; + if (n2 > T(0)) { + const T invN = T(1) / sqrt(n2); + q[0] *= invN; + q[1] *= invN; + q[2] *= invN; + q[3] *= invN; + } else { + q[0] = T(1); + q[1] = q[2] = q[3] = T(0); + } + return q; +} + +/// @brief NLERP between two quaternions along the shortest arc. +template +inline __host__ __device__ nanovdb::math::Vec4 +nlerpQuaternionShortestPath(const nanovdb::math::Vec4 &q0, + nanovdb::math::Vec4 q1, + const T u) { + T dot = q0[0] * q1[0] + q0[1] * q1[1] + q0[2] * q1[2] + q0[3] * q1[3]; + if (dot < T(0)) { + q1[0] = -q1[0]; + q1[1] = -q1[1]; + q1[2] = -q1[2]; + q1[3] = -q1[3]; + } + + const T s = T(1) - u; + return normalizeQuaternionSafe(nanovdb::math::Vec4(s * q0[0] + u * q1[0], + s * q0[1] + u * q1[1], + s * q0[2] + u * q1[2], + s * q0[3] + u * q1[3])); +} + +/// @brief VJP for quaternion-to-rotation-matrix conversion. +/// +/// Given dL/dR, computes dL/dq. +template +inline __device__ nanovdb::math::Vec4 +quaternionToRotationMatrixVectorJacobianProduct(const nanovdb::math::Vec4 &quat, + const nanovdb::math::Mat3 &dLossDRotation) { + T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; + const T inverseNormalization = rsqrt(x * x + y * y + z * z + w * w); + x *= inverseNormalization; + y *= inverseNormalization; + z *= inverseNormalization; + w *= inverseNormalization; + const nanovdb::math::Vec4 dLossDQuatNormalized( + 2.f * (x * (dLossDRotation[2][1] - dLossDRotation[1][2]) + + y * (dLossDRotation[0][2] - dLossDRotation[2][0]) + + z * (dLossDRotation[1][0] - dLossDRotation[0][1])), + 2.f * (-2.f * x * (dLossDRotation[1][1] + dLossDRotation[2][2]) + + y * (dLossDRotation[1][0] + dLossDRotation[0][1]) + + z * (dLossDRotation[2][0] + dLossDRotation[0][2]) + + w * (dLossDRotation[2][1] - dLossDRotation[1][2])), + 2.f * (x * (dLossDRotation[1][0] + dLossDRotation[0][1]) - + 2.f * y * (dLossDRotation[0][0] + dLossDRotation[2][2]) + + z * (dLossDRotation[2][1] + dLossDRotation[1][2]) + + w * (dLossDRotation[0][2] - dLossDRotation[2][0])), + 2.f * (x * (dLossDRotation[2][0] + dLossDRotation[0][2]) + + y * (dLossDRotation[2][1] + dLossDRotation[1][2]) - + 2.f * z * (dLossDRotation[0][0] + dLossDRotation[1][1]) + + w * (dLossDRotation[1][0] - dLossDRotation[0][1]))); + + const nanovdb::math::Vec4 quatNormalized(w, x, y, z); + return (dLossDQuatNormalized - dLossDQuatNormalized.dot(quatNormalized) * quatNormalized) * + inverseNormalization; +} + +/// @brief Rigid transform (cached rotation + translation). +/// +/// Quaternion is stored as [w,x,y,z]. The rotation matrix is cached. +template struct RigidTransform { + nanovdb::math::Mat3 R; + nanovdb::math::Vec4 q; + nanovdb::math::Vec3 t; + + /// Construct from quaternion and translation. + inline __host__ __device__ + RigidTransform(const nanovdb::math::Vec4 &q_in, const nanovdb::math::Vec3 &t_in) + : R(quaternionToRotationMatrix(q_in)), q(q_in), t(t_in) {} + + /// Construct from rotation matrix and translation. + inline __host__ __device__ + RigidTransform(const nanovdb::math::Mat3 &R_in, const nanovdb::math::Vec3 &t_in) + : R(R_in), q(rotationMatrixToQuaternion(R_in)), t(t_in) {} + + /// Apply the transform: R * p + t. + inline __host__ __device__ nanovdb::math::Vec3 + apply(const nanovdb::math::Vec3 &p_world) const { + return R * p_world + t; + } + + /// Interpolate between two rigid transforms (linear t, NLERP q). + inline static __host__ __device__ RigidTransform + interpolate(const T u, const RigidTransform &start, const RigidTransform &end) { + const nanovdb::math::Vec3 t_interp = start.t + u * (end.t - start.t); + const nanovdb::math::Vec4 q_interp = nlerpQuaternionShortestPath(start.q, end.q, u); + return RigidTransform(q_interp, t_interp); + } +}; + +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_UTILS_CUDA_MATH_ROTATION_CUH diff --git a/src/fvdb/detail/ops/gsplat/Gaussian2D.cuh b/src/fvdb/detail/utils/gaussian/Gaussian2D.cuh similarity index 83% rename from src/fvdb/detail/ops/gsplat/Gaussian2D.cuh rename to src/fvdb/detail/utils/gaussian/Gaussian2D.cuh index 6c118e89b..0f5a9bbd1 100644 --- a/src/fvdb/detail/ops/gsplat/Gaussian2D.cuh +++ b/src/fvdb/detail/utils/gaussian/Gaussian2D.cuh @@ -1,10 +1,10 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIAN2D_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIAN2D_CUH +#ifndef FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIAN2D_CUH +#define FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIAN2D_CUH -#include +#include #include @@ -38,4 +38,4 @@ template struct alignas(32) Gaussian2D { // 28 bytes } // namespace fvdb::detail::ops -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIAN2D_CUH +#endif // FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIAN2D_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianCameraAccessorCopy.cuh b/src/fvdb/detail/utils/gaussian/GaussianCameraAccessorCopy.cuh similarity index 92% rename from src/fvdb/detail/ops/gsplat/GaussianCameraAccessorCopy.cuh rename to src/fvdb/detail/utils/gaussian/GaussianCameraAccessorCopy.cuh index 060baa7dc..c6e048b81 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianCameraAccessorCopy.cuh +++ b/src/fvdb/detail/utils/gaussian/GaussianCameraAccessorCopy.cuh @@ -1,8 +1,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCAMERAACCESSORCOPY_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCAMERAACCESSORCOPY_CUH +#ifndef FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANCAMERAACCESSORCOPY_CUH +#define FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANCAMERAACCESSORCOPY_CUH #include @@ -62,4 +62,4 @@ copyDistortionCoeffs(const int64_t C, } // namespace fvdb::detail::ops -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCAMERAACCESSORCOPY_CUH +#endif // FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANCAMERAACCESSORCOPY_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianCameras.cuh b/src/fvdb/detail/utils/gaussian/GaussianCameras.cuh similarity index 96% rename from src/fvdb/detail/ops/gsplat/GaussianCameras.cuh rename to src/fvdb/detail/utils/gaussian/GaussianCameras.cuh index c811d7014..52424413a 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianCameras.cuh +++ b/src/fvdb/detail/utils/gaussian/GaussianCameras.cuh @@ -1,25 +1,21 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCAMERAS_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCAMERAS_CUH +#ifndef FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANCAMERAS_CUH +#define FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANCAMERAS_CUH + +#include +#include #include #include namespace fvdb::detail::ops { -/// @brief Align a byte count up to the next multiple of `alignment`. -inline constexpr size_t -alignUpBytes(const size_t value, const size_t alignment) { - return (value + alignment - 1) & ~(alignment - 1); -} - -/// @brief Align an address up to the next multiple of `alignment`. -inline constexpr uintptr_t -alignUpAddress(const uintptr_t value, const size_t alignment) { - return (value + alignment - 1) & ~(alignment - 1); -} +using fvdb::detail::alignUpAddress; +using fvdb::detail::alignUpBytes; +using fvdb::detail::clamp01; +using fvdb::detail::normalizeSafe; /// @brief Rolling shutter policy for camera projection / ray generation. enum class RollingShutterType : int32_t { NONE = 0, VERTICAL = 1, HORIZONTAL = 2 }; @@ -64,11 +60,11 @@ struct UTParams { #if defined(__CUDACC__) -#include -#include -#include -#include #include +#include +#include +#include +#include #include #include @@ -180,7 +176,6 @@ template struct PerspectiveCamera { PerspectiveCamera(const torch::Tensor &projectionMatrices, const torch::Tensor &worldToCamMatrices, - int32_t numCameras, int32_t imageWidth, int32_t imageHeight, T nearPlane, @@ -189,13 +184,12 @@ template struct PerspectiveCamera { projectionMatrices.template packed_accessor64()), worldToCamMatricesAcc( worldToCamMatrices.template packed_accessor64()), - numCameras(numCameras), imageWidth(imageWidth), imageHeight(imageHeight), - nearPlane(nearPlane), farPlane(farPlane) {} + imageWidth(imageWidth), imageHeight(imageHeight), nearPlane(nearPlane), + farPlane(farPlane) {} private: fvdb::TorchRAcc64 projectionMatricesAcc; // [C,3,3] fvdb::TorchRAcc64 worldToCamMatricesAcc; // [C,4,4] - int32_t numCameras = 0; int32_t imageWidth = 0; int32_t imageHeight = 0; T nearPlane = T(kBackwardProjectionNearPlane); @@ -462,7 +456,6 @@ template struct OrthographicCamera { OrthographicCamera(const torch::Tensor &projectionMatrices, const torch::Tensor &worldToCamMatrices, - int32_t numCameras, int32_t imageWidth, int32_t imageHeight, T nearPlane, @@ -471,13 +464,12 @@ template struct OrthographicCamera { projectionMatrices.template packed_accessor64()), worldToCamMatricesAcc( worldToCamMatrices.template packed_accessor64()), - numCameras(numCameras), imageWidth(imageWidth), imageHeight(imageHeight), - nearPlane(nearPlane), farPlane(farPlane) {} + imageWidth(imageWidth), imageHeight(imageHeight), nearPlane(nearPlane), + farPlane(farPlane) {} private: fvdb::TorchRAcc64 projectionMatricesAcc; // [C,3,3] fvdb::TorchRAcc64 worldToCamMatricesAcc; // [C,4,4] - int32_t numCameras = 0; int32_t imageWidth = 0; int32_t imageHeight = 0; T nearPlane = T(kBackwardProjectionNearPlane); @@ -950,22 +942,6 @@ template struct PerspectiveWithDistortionCamera { return &distortionCoeffsAcc[cid][0]; } - /// @brief Returns normalized vector, or zero when input norm is zero. - inline __device__ nanovdb::math::Vec3 - normalizeSafe(const nanovdb::math::Vec3 &v) const { - const T n2 = v.dot(v); - if (n2 > T(0)) { - return v * (T(1) / sqrt(n2)); - } - return nanovdb::math::Vec3(T(0), T(0), T(0)); - } - - /// @brief Returns scalar clamped to [0, 1]. - inline static __device__ T - clamp01(const T x) { - return (x < T(0)) ? T(0) : ((x > T(1)) ? T(1) : x); - } - /// @brief Converts pixel position to normalized rolling-shutter time in [0, 1]. inline static __device__ T rollingShutterTimeFromPixel(const RollingShutterType rollingShutterType, @@ -1335,22 +1311,6 @@ template struct OrthographicWithDistortionCamera { K[0][0], K[0][1], K[0][2], K[1][0], K[1][1], K[1][2], K[2][0], K[2][1], K[2][2]); } - /// @brief Returns normalized vector, or zero when input norm is zero. - inline __device__ nanovdb::math::Vec3 - normalizeSafe(const nanovdb::math::Vec3 &v) const { - const T n2 = v.dot(v); - if (n2 > T(0)) { - return v * (T(1) / sqrt(n2)); - } - return nanovdb::math::Vec3(T(0), T(0), T(0)); - } - - /// @brief Returns scalar clamped to [0, 1]. - inline static __device__ T - clamp01(const T x) { - return (x < T(0)) ? T(0) : ((x > T(1)) ? T(1) : x); - } - /// @brief Converts pixel position to normalized rolling-shutter time in [0, 1]. inline static __device__ T rollingShutterTimeFromPixel(const RollingShutterType rollingShutterType, @@ -1417,4 +1377,4 @@ template struct OrthographicWithDistortionCamera { #endif // defined(__CUDACC__) -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANCAMERAS_CUH +#endif // FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANCAMERAS_CUH diff --git a/src/fvdb/detail/utils/gaussian/GaussianMath.cuh b/src/fvdb/detail/utils/gaussian/GaussianMath.cuh new file mode 100644 index 000000000..6ca4a7047 --- /dev/null +++ b/src/fvdb/detail/utils/gaussian/GaussianMath.cuh @@ -0,0 +1,131 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +#ifndef FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANMATH_CUH +#define FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANMATH_CUH + +#include +#include + +#include + +#include +#include + +namespace fvdb { +namespace detail { + +/// @brief VJP for quaternion-and-scale to covariance matrix transformation. +/// +/// The covariance is C = M * M^T where M = R * S, with R from the quaternion and +/// S = diag(scale). When ApplyLogScaleChainRule is true (default), returns +/// dL/d(log_scale) by multiplying by scale. +template +inline __device__ std::tuple, nanovdb::math::Vec3> +quaternionAndScaleToCovarianceVectorJacobianProduct(const nanovdb::math::Vec4 &quat, + const nanovdb::math::Vec3 &scale, + const nanovdb::math::Mat3 &R, + const nanovdb::math::Mat3 &dLossDCovar) { + T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; + T sx = scale[0], sy = scale[1], sz = scale[2]; + + const nanovdb::math::Mat3 S(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); + const nanovdb::math::Mat3 M = R * S; + + const nanovdb::math::Mat3 dLossDM = (dLossDCovar + dLossDCovar.transpose()) * M; + const nanovdb::math::Mat3 dLossDR = dLossDM * S.transpose(); + + const nanovdb::math::Vec4 &dLossDQuat = + quaternionToRotationMatrixVectorJacobianProduct(quat, dLossDR); + + const nanovdb::math::Vec3 dLossDScale( + R[0][0] * dLossDM[0][0] + R[1][0] * dLossDM[1][0] + R[2][0] * dLossDM[2][0], + R[0][1] * dLossDM[0][1] + R[1][1] * dLossDM[1][1] + R[2][1] * dLossDM[2][1], + R[0][2] * dLossDM[0][2] + R[1][2] * dLossDM[1][2] + R[2][2] * dLossDM[2][2]); + + if constexpr (ApplyLogScaleChainRule) { + return { + dLossDQuat, + nanovdb::math::Vec3(sx * dLossDScale[0], sy * dLossDScale[1], sz * dLossDScale[2])}; + } else { + return {dLossDQuat, dLossDScale}; + } +} + +/// @brief Compute covariance C = R * S * S^T * R^T from quaternion and scale. +template +inline __device__ nanovdb::math::Mat3 +quaternionAndScaleToCovariance(const nanovdb::math::Vec4 &quat, + const nanovdb::math::Vec3 &scale) { + const nanovdb::math::Mat3 &R = quaternionToRotationMatrix(quat); + const nanovdb::math::Mat3 S(scale[0], 0.f, 0.f, 0.f, scale[1], 0.f, 0.f, 0.f, scale[2]); + const nanovdb::math::Mat3 M = R * S; + return M * M.transpose(); +} + +/// @brief Add blur to a 2D covariance and compute a compensation factor. +template +inline __device__ T +addBlur(const T eps2d, nanovdb::math::Mat2 &outCovar, T &outCompensation) { + const T det_orig = outCovar[0][0] * outCovar[1][1] - outCovar[0][1] * outCovar[1][0]; + outCovar[0][0] += eps2d; + outCovar[1][1] += eps2d; + const T det_blur = outCovar[0][0] * outCovar[1][1] - outCovar[0][1] * outCovar[1][0]; + outCompensation = sqrt(max(0.f, det_orig / det_blur)); + return det_blur; +} + +/// @brief VJP for the addBlur operation. +template +inline __device__ nanovdb::math::Mat2 +generateBlurVectorJacobianProduct(const T eps2d, + const nanovdb::math::Mat2 &conic_blur, + const T compensation, + const T dLossDCompensation) { + const T det_conic_blur = + conic_blur[0][0] * conic_blur[1][1] - conic_blur[0][1] * conic_blur[1][0]; + const T v_sqr_comp = dLossDCompensation * 0.5 / (compensation + 1e-6); + const T one_minus_sqr_comp = 1 - compensation * compensation; + return v_sqr_comp * + nanovdb::math::Mat2(one_minus_sqr_comp * conic_blur[0][0] - eps2d * det_conic_blur, + one_minus_sqr_comp * conic_blur[0][1], + one_minus_sqr_comp * conic_blur[1][0], + one_minus_sqr_comp * conic_blur[1][1] - eps2d * det_conic_blur); +} + +/// Transmittance threshold below which a pixel is considered fully opaque. +/// Matches the 3DGS convention (exclusive comparison). +constexpr float kTransmittanceThreshold = 1e-4f; + +using tilePixelMaskAccessor = fvdb::TorchRAcc64; +static constexpr uint32_t sTileBitmaskBitsPerWord = 64; + +inline uint32_t +numWordsPerTileBitmask(const uint32_t tileSideLength) { + return (tileSideLength * tileSideLength + sTileBitmaskBitsPerWord - 1) / + sTileBitmaskBitsPerWord; +} + +inline __device__ uint32_t +bitmaskWordIndex(const uint32_t bitIndex) { + return bitIndex / sTileBitmaskBitsPerWord; +} +inline __device__ uint32_t +bitmaskBitIndex(const uint32_t bitIndex) { + return bitIndex % sTileBitmaskBitsPerWord; +} + +inline __device__ bool +tilePixelActive(tilePixelMaskAccessor const &tilePixelMask, + const uint32_t tileSideLength, + const uint32_t tileId, + const uint32_t iInTile, + const uint32_t jInTile) { + const uint32_t bitIndex = iInTile * tileSideLength + jInTile; + return tilePixelMask[tileId][bitmaskWordIndex(bitIndex)] & (1ull << bitmaskBitIndex(bitIndex)); +} + +} // namespace detail +} // namespace fvdb + +#endif // FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANMATH_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterize.cuh b/src/fvdb/detail/utils/gaussian/GaussianRasterize.cuh similarity index 95% rename from src/fvdb/detail/ops/gsplat/GaussianRasterize.cuh rename to src/fvdb/detail/utils/gaussian/GaussianRasterize.cuh index 238fcd013..2bb2c7269 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterize.cuh +++ b/src/fvdb/detail/utils/gaussian/GaussianRasterize.cuh @@ -1,23 +1,48 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZE_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZE_CUH +#ifndef FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANRASTERIZE_CUH +#define FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANRASTERIZE_CUH -#include -#include -#include #include +#include +#include +#include +#include +#include #include #include #include -#define PRAGMA_UNROLL _Pragma("unroll") - namespace fvdb::detail::ops { +using fvdb::detail::numWordsPerTileBitmask; +using fvdb::detail::tilePixelActive; + +struct RenderWindow2D { + std::uint32_t width = 0; + std::uint32_t height = 0; + std::uint32_t originW = 0; + std::uint32_t originH = 0; + + inline constexpr std::uint32_t + pixelCountPerCamera() const { + return width * height; + } + + inline constexpr std::uint32_t + tileExtentW(const std::uint32_t tileSize) const { + return (width + tileSize - 1) / tileSize; + } + + inline constexpr std::uint32_t + tileExtentH(const std::uint32_t tileSize) const { + return (height + tileSize - 1) / tileSize; + } +}; + // Initialize an accessor for a tensor. The tensor must be a CUDA tensor. template inline auto @@ -441,4 +466,4 @@ template struct Raste }; } // namespace fvdb::detail::ops -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZE_CUH +#endif // FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANRASTERIZE_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorld.cuh b/src/fvdb/detail/utils/gaussian/GaussianRasterizeFromWorld.cuh similarity index 75% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorld.cuh rename to src/fvdb/detail/utils/gaussian/GaussianRasterizeFromWorld.cuh index b169d2df7..2717309f0 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeFromWorld.cuh +++ b/src/fvdb/detail/utils/gaussian/GaussianRasterizeFromWorld.cuh @@ -1,13 +1,15 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFROMWORLD_CUH -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFROMWORLD_CUH +#ifndef FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANRASTERIZEFROMWORLD_CUH +#define FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANRASTERIZEFROMWORLD_CUH -#include -#include -#include #include +#include +#include +#include +#include +#include #include #include @@ -94,65 +96,12 @@ struct RasterizeFromWorldCommonArgs { return {firstGaussianIdInBlock, lastGaussianIdInBlock}; } - inline __device__ uint32_t - pixelId(const uint32_t row, const uint32_t col) const { - return row * imageWidth + col; - } - - inline __device__ uint32_t - outputPixelBase(const uint32_t cameraId, const uint32_t pixId) const { - return cameraId * imageHeight * imageWidth + pixId; - } - - inline __device__ uint32_t - outputFeatureBase(const uint32_t cameraId, const uint32_t pixId) const { - return outputPixelBase(cameraId, pixId) * numChannels; - } - inline __device__ float backgroundValue(const uint32_t cameraId, const uint32_t channelId) const { return (backgrounds != nullptr) ? backgrounds[cameraId * numChannels + channelId] : 0.0f; } }; -/// Safely normalize a 3D vector. -/// -/// Returns `v / ||v||` when `||v|| > 0`, otherwise returns zero. -template -inline __device__ nanovdb::math::Vec3 -normalizeSafe(const nanovdb::math::Vec3 &v) { - const T n2 = v.dot(v); - if (n2 > T(0)) { - return v * (T(1) / sqrt(n2)); - } - return nanovdb::math::Vec3(T(0), T(0), T(0)); -} - -/// Vector-Jacobian product for `y = normalizeSafe(x)`. -/// -/// Given upstream gradient `v_y = dL/dy`, returns `dL/dx`. -template -inline __device__ nanovdb::math::Vec3 -normalizeSafeVJP(const nanovdb::math::Vec3 &x, const nanovdb::math::Vec3 &v_y) { - const T n2 = x.dot(x); - if (!(n2 > T(0))) { - return nanovdb::math::Vec3(T(0), T(0), T(0)); - } - const T n = sqrt(n2); - const T invn = T(1) / n; - // v_x = (I/n - x x^T / n^3) v_y - const T invn3 = invn * invn * invn; - const T xdotv = x.dot(v_y); - return v_y * invn - x * (xdotv * invn3); -} - -/// Load quaternion in [w,x,y,z] order. -template -inline __device__ nanovdb::math::Vec4 -quatLoadWxyz(const T *q) { - return nanovdb::math::Vec4(q[0], q[1], q[2], q[3]); -} - /// Build S^{-1} R^T from quaternion + scale. template inline __device__ nanovdb::math::Mat3 @@ -197,4 +146,4 @@ isclRotVectorJacobianProduct(const nanovdb::math::Vec4 &quat_wxyz, } // namespace fvdb::detail::ops -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEFROMWORLD_CUH +#endif // FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANRASTERIZEFROMWORLD_CUH diff --git a/src/fvdb/detail/ops/gsplat/GaussianRasterizeOptionalInputs.h b/src/fvdb/detail/utils/gaussian/GaussianRasterizeOptionalInputs.h similarity index 73% rename from src/fvdb/detail/ops/gsplat/GaussianRasterizeOptionalInputs.h rename to src/fvdb/detail/utils/gaussian/GaussianRasterizeOptionalInputs.h index 79239f53d..4d38c32e0 100644 --- a/src/fvdb/detail/ops/gsplat/GaussianRasterizeOptionalInputs.h +++ b/src/fvdb/detail/utils/gaussian/GaussianRasterizeOptionalInputs.h @@ -1,8 +1,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 // -#ifndef FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEOPTIONALINPUTS_H -#define FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEOPTIONALINPUTS_H +#ifndef FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANRASTERIZEOPTIONALINPUTS_H +#define FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANRASTERIZEOPTIONALINPUTS_H #include #include @@ -19,13 +19,13 @@ struct PreparedRasterOptionalInputs { }; inline PreparedRasterOptionalInputs -prepareRasterOptionalInputs(const torch::Tensor &features, - const int64_t C, - const int64_t tileExtentH, - const int64_t tileExtentW, - const int64_t numChannels, - const at::optional &backgrounds, - const at::optional &masks) { +prepare_raster_optional_inputs(const torch::Tensor &features, + const int64_t C, + const int64_t tileExtentH, + const int64_t tileExtentW, + const int64_t numChannels, + const at::optional &backgrounds, + const at::optional &masks) { PreparedRasterOptionalInputs out; if (backgrounds.has_value()) { @@ -56,4 +56,4 @@ prepareRasterOptionalInputs(const torch::Tensor &features, } // namespace fvdb::detail::ops -#endif // FVDB_DETAIL_OPS_GSPLAT_GAUSSIANRASTERIZEOPTIONALINPUTS_H +#endif // FVDB_DETAIL_UTILS_GAUSSIAN_GAUSSIANRASTERIZEOPTIONALINPUTS_H diff --git a/src/fvdb/detail/viewer/GaussianSplat3dView.h b/src/fvdb/detail/viewer/GaussianSplat3dView.h index 40784d745..64ba462a7 100644 --- a/src/fvdb/detail/viewer/GaussianSplat3dView.h +++ b/src/fvdb/detail/viewer/GaussianSplat3dView.h @@ -4,10 +4,9 @@ #ifndef FVDB_DETAIL_VIEWER_GAUSSIANSPLAT3DVIEW_H #define FVDB_DETAIL_VIEWER_GAUSSIANSPLAT3DVIEW_H -#include - #include +#include #include namespace fvdb::detail::viewer { diff --git a/src/fvdb/detail/viewer/Viewer.cpp b/src/fvdb/detail/viewer/Viewer.cpp index 2737a0dcb..1af8c8d68 100644 --- a/src/fvdb/detail/viewer/Viewer.cpp +++ b/src/fvdb/detail/viewer/Viewer.cpp @@ -2,9 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "fvdb/detail/viewer/GaussianSplat3dView.h" - #include +#include #include #include @@ -181,7 +180,12 @@ Viewer::removeView(const std::string &scene_name, const std::string &name) { fvdb::detail::viewer::GaussianSplat3dView & Viewer::addGaussianSplat3dView(const std::string &scene_name, const std::string &name, - const GaussianSplat3d &splats) { + const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &logitOpacities, + const torch::Tensor &sh0, + const torch::Tensor &shN) { std::shared_ptr oldData; auto itPrev = mSplat3dViews.find(name); if (itPrev != mSplat3dViews.end()) { @@ -191,14 +195,6 @@ Viewer::addGaussianSplat3dView(const std::string &scene_name, auto [it, inserted] = mSplat3dViews.emplace( std::piecewise_construct, std::forward_as_tuple(name), std::forward_as_tuple(name, *this)); - // Get the various tensors to pass to the viewer - torch::Tensor means = splats.means(); - torch::Tensor quats = splats.quats(); - torch::Tensor logScales = splats.logScales(); - torch::Tensor logitOpacities = splats.logitOpacities(); - torch::Tensor sh0 = splats.sh0(); - torch::Tensor shN = splats.shN(); - auto makeComputeArray = [this](const torch::Tensor &tensor) -> pnanovdb_compute_array_t * { torch::Tensor contig = tensor.cpu().contiguous(); size_t total_size = 1; @@ -376,19 +372,19 @@ Viewer::setCameraFar(const std::string &scene_name, float far) { updateCamera(scene_name); } -GaussianSplat3d::CameraModel +fvdb::detail::ops::DistortionModel Viewer::cameraModel(const std::string &scene_name) { getCamera(scene_name); - return mEditor.camera.config.is_orthographic ? GaussianSplat3d::CameraModel::ORTHOGRAPHIC - : GaussianSplat3d::CameraModel::PINHOLE; + return mEditor.camera.config.is_orthographic ? fvdb::detail::ops::DistortionModel::ORTHOGRAPHIC + : fvdb::detail::ops::DistortionModel::PINHOLE; } void -Viewer::setCameraModel(const std::string &scene_name, GaussianSplat3d::CameraModel model) { +Viewer::setCameraModel(const std::string &scene_name, fvdb::detail::ops::DistortionModel model) { getCamera(scene_name); - if (model == GaussianSplat3d::CameraModel::PINHOLE) { + if (model == fvdb::detail::ops::DistortionModel::PINHOLE) { mEditor.camera.config.is_orthographic = PNANOVDB_FALSE; - } else if (model == GaussianSplat3d::CameraModel::ORTHOGRAPHIC) { + } else if (model == fvdb::detail::ops::DistortionModel::ORTHOGRAPHIC) { mEditor.camera.config.is_orthographic = PNANOVDB_TRUE; } else { throw std::invalid_argument( diff --git a/src/fvdb/detail/viewer/Viewer.h b/src/fvdb/detail/viewer/Viewer.h index d472ce273..1b2bf74bc 100644 --- a/src/fvdb/detail/viewer/Viewer.h +++ b/src/fvdb/detail/viewer/Viewer.h @@ -4,7 +4,7 @@ #ifndef FVDB_DETAIL_VIEWER_VIEWER_H #define FVDB_DETAIL_VIEWER_VIEWER_H -#include +#include #include #include @@ -69,7 +69,12 @@ class Viewer { GaussianSplat3dView &addGaussianSplat3dView(const std::string &scene_name, const std::string &name, - const GaussianSplat3d &splats); + const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &logitOpacities, + const torch::Tensor &sh0, + const torch::Tensor &shN); CameraView &addCameraView(const std::string &scene_name, const std::string &name, const torch::Tensor &cameraToWorldMatrices, @@ -137,8 +142,8 @@ class Viewer { float cameraFar(const std::string &scene_name); void setCameraFar(const std::string &scene_name, float far); - void setCameraModel(const std::string &scene_name, GaussianSplat3d::CameraModel model); - GaussianSplat3d::CameraModel cameraModel(const std::string &scene_name); + void setCameraModel(const std::string &scene_name, fvdb::detail::ops::DistortionModel model); + fvdb::detail::ops::DistortionModel cameraModel(const std::string &scene_name); std::string ipAddress() const { diff --git a/src/python/Bindings.cpp b/src/python/Bindings.cpp index 438a43be7..99b1dd057 100644 --- a/src/python/Bindings.cpp +++ b/src/python/Bindings.cpp @@ -23,7 +23,7 @@ void bind_grid_batch_data(py::module &m); void bind_grid_batch_ops(py::module &m); void bind_jagged_tensor(py::module &m); -void bind_gaussian_splat3d(py::module &m); +void bind_gaussian_splat_ops(py::module &m); void bind_viewer(py::module &m); #define __FVDB__BUILDER_INNER(FUNC_NAME, FUNC_STR, LSHAPE_TYPE) \ @@ -135,7 +135,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { bind_grid_batch_data(m); bind_grid_batch_ops(m); bind_jagged_tensor(m); - bind_gaussian_splat3d(m); + bind_gaussian_splat_ops(m); bind_viewer(m); // diff --git a/src/python/FusedSSIMBinding.cpp b/src/python/FusedSSIMBinding.cpp index a4017a40a..4160d827d 100644 --- a/src/python/FusedSSIMBinding.cpp +++ b/src/python/FusedSSIMBinding.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include #include diff --git a/src/python/GaussianSplatBinding.cpp b/src/python/GaussianSplatBinding.cpp deleted file mode 100644 index 6b24bcb80..000000000 --- a/src/python/GaussianSplatBinding.cpp +++ /dev/null @@ -1,552 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 -// -#include - -#include "TypeCasters.h" - -#include -#include -#include -#include - -#include - -void -bind_gaussian_splat3d(py::module &m) { - py::enum_(m, "RollingShutterType") - .value("NONE", fvdb::detail::ops::RollingShutterType::NONE) - .value("VERTICAL", fvdb::detail::ops::RollingShutterType::VERTICAL) - .value("HORIZONTAL", fvdb::detail::ops::RollingShutterType::HORIZONTAL) - .export_values(); - - py::enum_(m, "CameraModel") - .value("PINHOLE", fvdb::detail::ops::DistortionModel::PINHOLE) - .value("OPENCV_RADTAN_5", fvdb::detail::ops::DistortionModel::OPENCV_RADTAN_5) - .value("OPENCV_RATIONAL_8", fvdb::detail::ops::DistortionModel::OPENCV_RATIONAL_8) - .value("OPENCV_RADTAN_THIN_PRISM_9", - fvdb::detail::ops::DistortionModel::OPENCV_RADTAN_THIN_PRISM_9) - .value("OPENCV_THIN_PRISM_12", fvdb::detail::ops::DistortionModel::OPENCV_THIN_PRISM_12) - .value("ORTHOGRAPHIC", fvdb::detail::ops::DistortionModel::ORTHOGRAPHIC) - .export_values(); - - py::enum_(m, "ProjectionMethod") - .value("AUTO", fvdb::detail::ops::ProjectionMethod::AUTO) - .value("ANALYTIC", fvdb::detail::ops::ProjectionMethod::ANALYTIC) - .value("UNSCENTED", fvdb::detail::ops::ProjectionMethod::UNSCENTED) - .export_values(); - - py::class_(m, "ProjectedGaussianSplats") - .def_property_readonly("means2d", &fvdb::GaussianSplat3d::ProjectedGaussianSplats::means2d) - .def_property_readonly("conics", &fvdb::GaussianSplat3d::ProjectedGaussianSplats::conics) - .def_property_readonly("render_quantities", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::renderQuantities) - .def_property_readonly("depths", &fvdb::GaussianSplat3d::ProjectedGaussianSplats::depths) - .def_property_readonly("opacities", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::opacities) - .def_property_readonly("radii", &fvdb::GaussianSplat3d::ProjectedGaussianSplats::radii) - .def_property_readonly("tile_offsets", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::offsets) - .def_property_readonly("tile_gaussian_ids", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::gaussianIds) - .def_property_readonly("image_width", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::imageWidth) - .def_property_readonly("image_height", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::imageHeight) - .def_property_readonly("near_plane", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::nearPlane) - .def_property_readonly("far_plane", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::farPlane) - .def_property_readonly("camera_model", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::cameraModel) - .def_property_readonly("projection_method", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::projectionMethod) - .def_property_readonly("sh_degree_to_use", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::shDegreeToUse) - .def_property_readonly("min_radius_2d", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::minRadius2d) - .def_property_readonly("eps_2d", &fvdb::GaussianSplat3d::ProjectedGaussianSplats::eps2d) - .def_property_readonly("antialias", - &fvdb::GaussianSplat3d::ProjectedGaussianSplats::antialias); - - py::class_ gs3d(m, "GaussianSplat3d", "A gaussian splat scene"); - - gs3d.def(py::init(), - py::arg("means"), - py::arg("quats"), - py::arg("log_scales"), - py::arg("logit_opacities"), - py::arg("sh0"), - py::arg("shN"), - py::arg("accumulate_mean_2d_gradients"), - py::arg("accumulate_max_2d_radii"), - py::arg("detach")) - .def_property_readonly("device", &fvdb::GaussianSplat3d::device) - .def_property_readonly("dtype", &fvdb::GaussianSplat3d::scalarType) - .def_property_readonly("sh_degree", &fvdb::GaussianSplat3d::shDegree) - .def_property("means", &fvdb::GaussianSplat3d::means, &fvdb::GaussianSplat3d::setMeans) - .def_property("quats", &fvdb::GaussianSplat3d::quats, &fvdb::GaussianSplat3d::setQuats) - .def_property_readonly("scales", &fvdb::GaussianSplat3d::scales) - .def_property( - "log_scales", &fvdb::GaussianSplat3d::logScales, &fvdb::GaussianSplat3d::setLogScales) - .def_property_readonly("opacities", &fvdb::GaussianSplat3d::opacities) - .def_property("logit_opacities", - &fvdb::GaussianSplat3d::logitOpacities, - &fvdb::GaussianSplat3d::setLogitOpacities) - .def_property("sh0", &fvdb::GaussianSplat3d::sh0, &fvdb::GaussianSplat3d::setSh0) - .def_property("shN", &fvdb::GaussianSplat3d::shN, &fvdb::GaussianSplat3d::setShN) - .def_property_readonly("num_gaussians", &fvdb::GaussianSplat3d::numGaussians) - .def_property_readonly("num_sh_bases", &fvdb::GaussianSplat3d::numShBases) - .def_property_readonly("num_channels", &fvdb::GaussianSplat3d::numChannels) - .def_property_readonly("requires_grad", &fvdb::GaussianSplat3d::requiresGrad) - .def_property("accumulate_max_2d_radii", - &fvdb::GaussianSplat3d::accumulateMax2dRadii, - &fvdb::GaussianSplat3d::setAccumulateMax2dRadii) - .def_property("accumulate_mean_2d_gradients", - &fvdb::GaussianSplat3d::accumulateMean2dGradients, - &fvdb::GaussianSplat3d::setAccumulateMean2dGradients) - .def_property_readonly("accumulated_mean_2d_gradient_norms", - &fvdb::GaussianSplat3d::accumulated2dMeansGradientNormsForGrad) - .def_property_readonly("accumulated_max_2d_radii", - &fvdb::GaussianSplat3d::accumulatedMax2dRadiiForGrad) - .def_property_readonly("accumulated_gradient_step_counts", - &fvdb::GaussianSplat3d::gradientStepCountsForGrad) - .def_static( - "from_state_dict", - [](const std::unordered_map &stateDict) { - return fvdb::GaussianSplat3d(stateDict); - }, - py::arg("state_dict")) - .def("to", &fvdb::GaussianSplat3d::to, py::arg("device"), py::arg("dtype")) - .def("detach", &fvdb::GaussianSplat3d::detach) - .def("detach_in_place", &fvdb::GaussianSplat3d::detachInPlace) - .def("state_dict", &fvdb::GaussianSplat3d::stateDict) - .def("load_state_dict", &fvdb::GaussianSplat3d::loadStateDict, py::arg("state_dict")) - .def_property("requires_grad", - &fvdb::GaussianSplat3d::requiresGrad, - &fvdb::GaussianSplat3d::setRequiresGrad) - .def("set_state", - &fvdb::GaussianSplat3d::setState, - py::arg("means"), - py::arg("quats"), - py::arg("log_scales"), - py::arg("logit_opacities"), - py::arg("sh0"), - py::arg("shN")) - .def("save_ply", &fvdb::GaussianSplat3d::savePly, py::arg("filename"), py::arg("metadata")) - .def_static("from_ply", - &fvdb::GaussianSplat3d::fromPly, - py::arg("filename"), - py::arg("device") = torch::kCPU) - .def_static("cat", - &fvdb::GaussianSplat3d::cat, - py::arg("splats_to_cat"), - py::arg("accumulate_mean_2d_gradients") = false, - py::arg("accumulate_max_2d_radii") = false, - py::arg("detach") = false) - .def("reset_accumulated_gradient_state", - &fvdb::GaussianSplat3d::resetAccumulatedGradientState) - .def("project_gaussians_for_images", - &fvdb::GaussianSplat3d::projectGaussiansForImages, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("sh_degree_to_use") = -1, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false) - - .def("project_gaussians_for_depths", - &fvdb::GaussianSplat3d::projectGaussiansForDepths, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false) - - .def("project_gaussians_for_images_and_depths", - &fvdb::GaussianSplat3d::projectGaussiansForImagesAndDepths, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("sh_degree_to_use") = -1, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false) - - .def("render_from_projected_gaussians", - &fvdb::GaussianSplat3d::renderFromProjectedGaussians, - py::arg("projected_gaussians"), - py::arg("crop_width") = -1, - py::arg("crop_height") = -1, - py::arg("crop_origin_w") = -1, - py::arg("crop_origin_h") = -1, - py::arg("tile_size") = 16, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("render_images", - &fvdb::GaussianSplat3d::renderImages, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("sh_degree_to_use") = -1, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("render_images_from_world", - &fvdb::GaussianSplat3d::renderImagesFromWorld, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("sh_degree_to_use") = -1, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("render_depths_from_world", - &fvdb::GaussianSplat3d::renderDepthsFromWorld, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("render_depths", - &fvdb::GaussianSplat3d::renderDepths, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("render_images_and_depths", - &fvdb::GaussianSplat3d::renderImagesAndDepths, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("sh_degree_to_use") = -1, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("render_images_and_depths_from_world", - &fvdb::GaussianSplat3d::renderImagesAndDepthsFromWorld, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("sh_degree_to_use") = -1, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("sparse_render_images", - &fvdb::GaussianSplat3d::sparseRenderImages, - py::arg("pixels_to_render"), - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("sh_degree_to_use") = -1, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("sparse_render_depths", - &fvdb::GaussianSplat3d::sparseRenderDepths, - py::arg("pixels_to_render"), - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("sparse_render_images_and_depths", - &fvdb::GaussianSplat3d::sparseRenderImagesAndDepths, - py::arg("pixels_to_render"), - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("sh_degree_to_use") = -1, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt) - - .def("render_num_contributing_gaussians", - &fvdb::GaussianSplat3d::renderNumContributingGaussians, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false) - - .def("sparse_render_num_contributing_gaussians", - &fvdb::GaussianSplat3d::sparseRenderNumContributingGaussians, - py::arg("pixels_to_render"), - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false) - - .def("render_contributing_gaussian_ids", - &fvdb::GaussianSplat3d::renderContributingGaussianIds, - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("top_k_contributors") = 0) - - .def("sparse_render_contributing_gaussian_ids", - &fvdb::GaussianSplat3d::sparseRenderContributingGaussianIds, - py::arg("pixels_to_render"), - py::arg("world_to_camera_matrices"), - py::arg("projection_matrices"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near"), - py::arg("far"), - py::arg("camera_model") = fvdb::GaussianSplat3d::CameraModel::PINHOLE, - py::arg("projection_method") = fvdb::GaussianSplat3d::ProjectionMethod::AUTO, - py::arg("distortion_coeffs") = std::nullopt, - py::arg("tile_size") = 16, - py::arg("min_radius_2d") = 0.0, - py::arg("eps_2d") = 0.3, - py::arg("antialias") = false, - py::arg("top_k_contributors") = 0) - - .def("relocate_gaussians", - &fvdb::GaussianSplat3d::relocateGaussians, - py::arg("log_scales"), - py::arg("logit_opacities"), - py::arg("ratios"), - py::arg("binomial_coeffs"), - py::arg("n_max"), - py::arg("min_opacity")) - - .def("add_noise_to_means", - &fvdb::GaussianSplat3d::addNoiseToMeans, - py::arg("noise_scale"), - py::arg("t") = 0.005, - py::arg("k") = 100.0) - - .def("index_select", &fvdb::GaussianSplat3d::indexSelect, py::arg("indices")) - .def("mask_select", &fvdb::GaussianSplat3d::maskSelect, py::arg("mask")) - .def("slice_select", - &fvdb::GaussianSplat3d::sliceSelect, - py::arg("begin"), - py::arg("end"), - py::arg("step")) - .def("index_set", &fvdb::GaussianSplat3d::indexSet, py::arg("indices"), py::arg("value")) - .def("mask_set", &fvdb::GaussianSplat3d::maskSet, py::arg("mask"), py::arg("value")) - .def("slice_set", - &fvdb::GaussianSplat3d::sliceSet, - py::arg("begin"), - py::arg("end"), - py::arg("step"), - py::arg("value")); - - m.def("gaussian_render_jagged", - &fvdb::gaussianRenderJagged, - py::arg("means"), - py::arg("quats"), - py::arg("scales"), - py::arg("opacities"), - py::arg("sh_coeffs"), - py::arg("viewmats"), - py::arg("Ks"), - py::arg("image_width"), - py::arg("image_height"), - py::arg("near_plane") = 0.01, - py::arg("far_plane") = 1e10, - py::arg("sh_degree_to_use") = 3, - py::arg("tile_size") = 16, - py::arg("radius_clip") = 0.0, - py::arg("eps2d") = 0.3, - py::arg("antialias") = false, - py::arg("render_depth_channel") = false, - py::arg("return_debug_info") = false, - py::arg("render_depth_only") = false, - py::arg("ortho") = false, - py::arg("backgrounds") = std::nullopt, - py::arg("masks") = std::nullopt); - - m.def( - "evaluate_spherical_harmonics", - [](int64_t shDegree, - int64_t numCameras, - const torch::Tensor &sh0, - const torch::Tensor &radii, - const std::optional &shN, - const std::optional &viewDirections) { - return fvdb::detail::autograd::EvaluateSphericalHarmonics::apply( - shDegree, numCameras, viewDirections, sh0, shN, radii)[0]; - }, - R"doc( -Evaluate spherical harmonics to compute view-dependent features/colors. - -This function evaluates spherical harmonics (SH) coefficients to compute -features (typically RGB colors) for a set of points, optionally considering -view directions for view-dependent appearance. - -Args: - sh_degree: Degree of spherical harmonics to use (0-3 typically). - Degree 0 uses only sh0 (view-independent). - Higher degrees require view_directions and shN. - num_cameras: Number of camera views (C). The output will have shape [C, N, D]. - sh0: DC term coefficients with shape [N, 1, D] where N is the number of - points and D is the number of feature channels. - radii: Projected radii with shape [C, N] (int32). Points with radii <= 0 - will output zeros (used to skip invisible gaussians). Pass a tensor - of ones to evaluate all points. - shN: Higher-order SH coefficients with shape [N, K-1, D] where - K = (sh_degree+1)^2. Required when sh_degree > 0. Pass None for degree 0. - view_directions: Unnormalized view directions with shape [C, N, 3]. - Required when sh_degree > 0. Pass None for degree 0. - -Returns: - Tensor of shape [C, N, D] containing the evaluated features/colors. -)doc", - py::arg("sh_degree"), - py::arg("num_cameras"), - py::arg("sh0"), - py::arg("radii"), - py::arg("shN") = std::nullopt, - py::arg("view_directions") = std::nullopt); -} diff --git a/src/python/GaussianSplatOps.cpp b/src/python/GaussianSplatOps.cpp new file mode 100644 index 000000000..d64050ac3 --- /dev/null +++ b/src/python/GaussianSplatOps.cpp @@ -0,0 +1,1006 @@ +// Copyright Contributors to the OpenVDB Project +// SPDX-License-Identifier: Apache-2.0 +// +// Pybind11 bindings for Gaussian splat free-function ops. +// These expose the fvdb::detail::ops functions as module-level functions on +// _fvdb_cpp, enabling the Python functional layer. +// +// Design note on accumulator mutability: +// The C++ projection backward kernel mutates three accumulator tensors in-place +// (gradient norms, max 2D radii, step counts) via atomicAdd. These support +// Gaussian densification (split/clone/prune decisions during training). +// The backward binding (project_gaussians_analytic_bwd) accepts these as +// optional tensors. The Python GaussianSplat3d class owns the accumulators +// and passes them through to the C++ backward dispatch. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +void +bind_gaussian_splat_ops(py::module &m) { + namespace ops = fvdb::detail::ops; + using DistortionModel = fvdb::detail::ops::DistortionModel; + using RollingShutterType = fvdb::detail::ops::RollingShutterType; + + // ----------------------------------------------------------------------- + // Enum types (moved from GaussianSplatBinding.cpp) + // ----------------------------------------------------------------------- + + py::enum_(m, "RollingShutterType") + .value("NONE", fvdb::detail::ops::RollingShutterType::NONE) + .value("VERTICAL", fvdb::detail::ops::RollingShutterType::VERTICAL) + .value("HORIZONTAL", fvdb::detail::ops::RollingShutterType::HORIZONTAL) + .export_values(); + + py::enum_(m, "CameraModel") + .value("PINHOLE", fvdb::detail::ops::DistortionModel::PINHOLE) + .value("OPENCV_RADTAN_5", fvdb::detail::ops::DistortionModel::OPENCV_RADTAN_5) + .value("OPENCV_RATIONAL_8", fvdb::detail::ops::DistortionModel::OPENCV_RATIONAL_8) + .value("OPENCV_RADTAN_THIN_PRISM_9", + fvdb::detail::ops::DistortionModel::OPENCV_RADTAN_THIN_PRISM_9) + .value("OPENCV_THIN_PRISM_12", fvdb::detail::ops::DistortionModel::OPENCV_THIN_PRISM_12) + .value("ORTHOGRAPHIC", fvdb::detail::ops::DistortionModel::ORTHOGRAPHIC) + .export_values(); + + py::enum_(m, "ProjectionMethod") + .value("AUTO", fvdb::detail::ops::ProjectionMethod::AUTO) + .value("ANALYTIC", fvdb::detail::ops::ProjectionMethod::ANALYTIC) + .value("UNSCENTED", fvdb::detail::ops::ProjectionMethod::UNSCENTED) + .export_values(); + + // ----------------------------------------------------------------------- + // Data types needed by the functional ops + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + // Kernel-level bindings (for Python autograd and composition) + // ----------------------------------------------------------------------- + + // ----------------------------------------------------------------------- + // Analysis operations (call raw tensor dispatch functions directly) + // ----------------------------------------------------------------------- + + m.def("count_contributing_gaussians", + &ops::count_contributing_gaussians, + py::arg("means2d"), + py::arg("conics"), + py::arg("opacities"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size")); + + m.def("count_contributing_gaussians_sparse", + &ops::count_contributing_gaussians_sparse, + py::arg("means2d"), + py::arg("conics"), + py::arg("opacities"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("pixels_to_render"), + py::arg("active_tiles"), + py::arg("tile_pixel_mask"), + py::arg("tile_pixel_cumsum"), + py::arg("pixel_map"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size")); + + m.def("identify_contributing_gaussians", + &ops::identify_contributing_gaussians, + py::arg("means2d"), + py::arg("conics"), + py::arg("opacities"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size"), + py::arg("num_depth_samples"), + py::arg("num_contributing_gaussians") = py::none()); + + m.def("identify_contributing_gaussians_sparse", + &ops::identify_contributing_gaussians_sparse, + py::arg("means2d"), + py::arg("conics"), + py::arg("opacities"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("pixels_to_render"), + py::arg("active_tiles"), + py::arg("tile_pixel_mask"), + py::arg("tile_pixel_cumsum"), + py::arg("pixel_map"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size"), + py::arg("num_depth_samples"), + py::arg("num_contributing_gaussians") = py::none()); + + // ----------------------------------------------------------------------- + // MCMC operations (thin dispatch wrappers) + // ----------------------------------------------------------------------- + + m.def( + "relocate_gaussians", + [](const torch::Tensor &logScales, + const torch::Tensor &logitOpacities, + const torch::Tensor &ratios, + const torch::Tensor &binomialCoeffs, + const int nMax, + const float minOpacity) { + return ops::relocate_gaussians( + logScales, logitOpacities, ratios, binomialCoeffs, nMax, minOpacity); + }, + py::arg("log_scales"), + py::arg("logit_opacities"), + py::arg("ratios"), + py::arg("binomial_coeffs"), + py::arg("n_max"), + py::arg("min_opacity")); + + m.def( + "add_noise_to_gaussian_means", + [](torch::Tensor &means, + const torch::Tensor &logScales, + const torch::Tensor &logitOpacities, + const torch::Tensor &quats, + const float noiseScale, + const float t, + const float k) { + ops::add_noise_to_gaussian_means( + means, logScales, logitOpacities, quats, noiseScale, t, k); + }, + py::arg("means"), + py::arg("log_scales"), + py::arg("logit_opacities"), + py::arg("quats"), + py::arg("noise_scale"), + py::arg("t"), + py::arg("k")); + + // ----------------------------------------------------------------------- + // PLY I/O (wraps C++ PLY functions directly with raw tensors) + // ----------------------------------------------------------------------- + + m.def( + "save_gaussians_ply", + [](const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &logitOpacities, + const torch::Tensor &sh0, + const torch::Tensor &shN, + const std::string &filename, + std::optional> + metadata) { + fvdb::detail::io::saveGaussianPly( + filename, means, quats, logScales, logitOpacities, sh0, shN, metadata); + }, + py::arg("means"), + py::arg("quats"), + py::arg("log_scales"), + py::arg("logit_opacities"), + py::arg("sh0"), + py::arg("shN"), + py::arg("filename"), + py::arg("metadata")); + + m.def( + "load_gaussians_ply", + [](const std::string &filename, torch::Device device) + -> std::tuple> { + return fvdb::detail::io::loadGaussianPly(filename, device); + }, + py::arg("filename"), + py::arg("device") = torch::kCPU); + + // ------- Raw forward/backward dispatch (for Python autograd) ------- + + // 1. project_gaussians_analytic_fwd + m.def( + "project_gaussians_analytic_fwd", + [](const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &scales, + const torch::Tensor &worldToCamMatrices, + const torch::Tensor &projectionMatrices, + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations, + const bool ortho) { + return ops::project_gaussians_analytic_fwd(means, + quats, + scales, + worldToCamMatrices, + projectionMatrices, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + calcCompensations, + ortho); + }, + py::arg("means"), + py::arg("quats"), + py::arg("scales"), + py::arg("world_to_cam_matrices"), + py::arg("projection_matrices"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("eps2d"), + py::arg("near"), + py::arg("far"), + py::arg("min_radius_2d"), + py::arg("calc_compensations"), + py::arg("ortho")); + + // 2. project_gaussians_analytic_bwd + m.def( + "project_gaussians_analytic_bwd", + [](const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &scales, + const torch::Tensor &worldToCamMatrices, + const torch::Tensor &projectionMatrices, + const at::optional &compensations, + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const torch::Tensor &radii, + const torch::Tensor &conics, + const torch::Tensor &dLossDMeans2d, + const torch::Tensor &dLossDDepths, + const torch::Tensor &dLossDConics, + const at::optional &dLossDCompensations, + const bool worldToCamMatricesRequiresGrad, + const bool ortho, + at::optional outNormalizeddLossdMeans2dNormAccum, + at::optional outNormalizedMaxRadiiAccum, + at::optional outGradientStepCounts) { + return ops::project_gaussians_analytic_bwd(means, + quats, + scales, + worldToCamMatrices, + projectionMatrices, + compensations, + imageWidth, + imageHeight, + eps2d, + radii, + conics, + dLossDMeans2d, + dLossDDepths, + dLossDConics, + dLossDCompensations, + worldToCamMatricesRequiresGrad, + ortho, + outNormalizeddLossdMeans2dNormAccum, + outNormalizedMaxRadiiAccum, + outGradientStepCounts); + }, + py::arg("means"), + py::arg("quats"), + py::arg("scales"), + py::arg("world_to_cam_matrices"), + py::arg("projection_matrices"), + py::arg("compensations"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("eps2d"), + py::arg("radii"), + py::arg("conics"), + py::arg("d_loss_d_means2d"), + py::arg("d_loss_d_depths"), + py::arg("d_loss_d_conics"), + py::arg("d_loss_d_compensations"), + py::arg("world_to_cam_matrices_requires_grad"), + py::arg("ortho"), + py::arg("out_normalized_d_loss_d_means2d_norm_accum") = py::none(), + py::arg("out_normalized_max_radii_accum") = py::none(), + py::arg("out_gradient_step_counts") = py::none()); + + // 3. eval_gaussian_sh_fwd + m.def( + "eval_gaussian_sh_fwd", + [](const int64_t shDegreeToUse, + const int64_t numCameras, + const torch::Tensor &viewDirs, + const torch::Tensor &sh0Coeffs, + const torch::Tensor &shNCoeffs, + const torch::Tensor &radii) { + return ops::eval_gaussian_sh_fwd( + shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); + }, + py::arg("sh_degree_to_use"), + py::arg("num_cameras"), + py::arg("view_dirs"), + py::arg("sh0_coeffs"), + py::arg("sh_n_coeffs"), + py::arg("radii")); + + // 4. eval_gaussian_sh_bwd + m.def( + "eval_gaussian_sh_bwd", + [](const int64_t shDegreeToUse, + const int64_t numCameras, + const int64_t numGaussians, + const torch::Tensor &viewDirs, + const torch::Tensor &shNCoeffs, + const torch::Tensor &dLossDColors, + const torch::Tensor &radii, + const bool computeDLossDViewDirs) { + return ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDColors, + radii, + computeDLossDViewDirs); + }, + py::arg("sh_degree_to_use"), + py::arg("num_cameras"), + py::arg("num_gaussians"), + py::arg("view_dirs"), + py::arg("sh_n_coeffs"), + py::arg("d_loss_d_colors"), + py::arg("radii"), + py::arg("compute_d_loss_d_view_dirs")); + + // 5. rasterize_screen_space_gaussians_fwd + m.def( + "rasterize_screen_space_gaussians_fwd", + [](const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const at::optional &backgrounds, + const at::optional &masks) { + return ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + features, + opacities, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + backgrounds, + masks); + }, + py::arg("means2d"), + py::arg("conics"), + py::arg("features"), + py::arg("opacities"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("backgrounds"), + py::arg("masks")); + + // 6. rasterize_screen_space_gaussians_bwd + m.def( + "rasterize_screen_space_gaussians_bwd", + [](const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &renderedAlphas, + const torch::Tensor &lastIds, + const torch::Tensor &dLossDRenderedFeatures, + const torch::Tensor &dLossDRenderedAlphas, + const bool absGrad, + const int64_t numSharedChannelsOverride, + const at::optional &backgrounds, + const at::optional &masks) { + return ops::rasterize_screen_space_gaussians_bwd(means2d, + conics, + features, + opacities, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastIds, + dLossDRenderedFeatures, + dLossDRenderedAlphas, + absGrad, + numSharedChannelsOverride, + backgrounds, + masks); + }, + py::arg("means2d"), + py::arg("conics"), + py::arg("features"), + py::arg("opacities"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("rendered_alphas"), + py::arg("last_ids"), + py::arg("d_loss_d_rendered_features"), + py::arg("d_loss_d_rendered_alphas"), + py::arg("abs_grad"), + py::arg("num_shared_channels_override") = -1, + py::arg("backgrounds") = py::none(), + py::arg("masks") = py::none()); + + // 7. rasterize_screen_space_gaussians_sparse_fwd + m.def( + "rasterize_screen_space_gaussians_sparse_fwd", + [](const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + const at::optional &backgrounds, + const at::optional &masks) { + return ops::rasterize_screen_space_gaussians_sparse_fwd(pixelsToRender, + means2d, + conics, + features, + opacities, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + backgrounds, + masks); + }, + py::arg("pixels_to_render"), + py::arg("means2d"), + py::arg("conics"), + py::arg("features"), + py::arg("opacities"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("active_tiles"), + py::arg("tile_pixel_mask"), + py::arg("tile_pixel_cumsum"), + py::arg("pixel_map"), + py::arg("backgrounds"), + py::arg("masks")); + + // 8. rasterize_screen_space_gaussians_sparse_bwd + m.def( + "rasterize_screen_space_gaussians_sparse_bwd", + [](const fvdb::JaggedTensor &pixelsToRender, + const torch::Tensor &means2d, + const torch::Tensor &conics, + const torch::Tensor &features, + const torch::Tensor &opacities, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const fvdb::JaggedTensor &renderedAlphas, + const fvdb::JaggedTensor &lastIds, + const fvdb::JaggedTensor &dLossDRenderedFeatures, + const fvdb::JaggedTensor &dLossDRenderedAlphas, + const torch::Tensor &activeTiles, + const torch::Tensor &tilePixelMask, + const torch::Tensor &tilePixelCumsum, + const torch::Tensor &pixelMap, + const bool absGrad, + const int64_t numSharedChannelsOverride, + const at::optional &backgrounds, + const at::optional &masks) { + return ops::rasterize_screen_space_gaussians_sparse_bwd(pixelsToRender, + means2d, + conics, + features, + opacities, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastIds, + dLossDRenderedFeatures, + dLossDRenderedAlphas, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + absGrad, + numSharedChannelsOverride, + backgrounds, + masks); + }, + py::arg("pixels_to_render"), + py::arg("means2d"), + py::arg("conics"), + py::arg("features"), + py::arg("opacities"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("rendered_alphas"), + py::arg("last_ids"), + py::arg("d_loss_d_rendered_features"), + py::arg("d_loss_d_rendered_alphas"), + py::arg("active_tiles"), + py::arg("tile_pixel_mask"), + py::arg("tile_pixel_cumsum"), + py::arg("pixel_map"), + py::arg("abs_grad"), + py::arg("num_shared_channels_override") = -1, + py::arg("backgrounds") = py::none(), + py::arg("masks") = py::none()); + + // 9. rasterize_world_space_gaussians_fwd + m.def( + "rasterize_world_space_gaussians_fwd", + [](const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &features, + const torch::Tensor &opacities, + const torch::Tensor &worldToCamMatricesStart, + const torch::Tensor &worldToCamMatricesEnd, + const torch::Tensor &projectionMatrices, + const torch::Tensor &distortionCoeffs, + const RollingShutterType rollingShutterType, + const DistortionModel cameraModel, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const at::optional &backgrounds, + const at::optional &masks) { + return ops::rasterize_world_space_gaussians_fwd(means, + quats, + logScales, + features, + opacities, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + distortionCoeffs, + rollingShutterType, + cameraModel, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + backgrounds, + masks); + }, + py::arg("means"), + py::arg("quats"), + py::arg("log_scales"), + py::arg("features"), + py::arg("opacities"), + py::arg("world_to_cam_matrices_start"), + py::arg("world_to_cam_matrices_end"), + py::arg("projection_matrices"), + py::arg("distortion_coeffs"), + py::arg("rolling_shutter_type"), + py::arg("camera_model"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("backgrounds"), + py::arg("masks")); + + // 10. rasterize_world_space_gaussians_bwd + m.def( + "rasterize_world_space_gaussians_bwd", + [](const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &features, + const torch::Tensor &opacities, + const torch::Tensor &worldToCamMatricesStart, + const torch::Tensor &worldToCamMatricesEnd, + const torch::Tensor &projectionMatrices, + const torch::Tensor &distortionCoeffs, + const RollingShutterType rollingShutterType, + const DistortionModel cameraModel, + const uint32_t imageWidth, + const uint32_t imageHeight, + const uint32_t imageOriginW, + const uint32_t imageOriginH, + const uint32_t tileSize, + const torch::Tensor &tileOffsets, + const torch::Tensor &tileGaussianIds, + const torch::Tensor &renderedAlphas, + const torch::Tensor &lastIds, + const torch::Tensor &dLossDRenderedFeatures, + const torch::Tensor &dLossDRenderedAlphas, + const at::optional &backgrounds, + const at::optional &masks) { + return ops::rasterize_world_space_gaussians_bwd(means, + quats, + logScales, + features, + opacities, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + distortionCoeffs, + rollingShutterType, + cameraModel, + imageWidth, + imageHeight, + imageOriginW, + imageOriginH, + tileSize, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastIds, + dLossDRenderedFeatures, + dLossDRenderedAlphas, + backgrounds, + masks); + }, + py::arg("means"), + py::arg("quats"), + py::arg("log_scales"), + py::arg("features"), + py::arg("opacities"), + py::arg("world_to_cam_matrices_start"), + py::arg("world_to_cam_matrices_end"), + py::arg("projection_matrices"), + py::arg("distortion_coeffs"), + py::arg("rolling_shutter_type"), + py::arg("camera_model"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("image_origin_w"), + py::arg("image_origin_h"), + py::arg("tile_size"), + py::arg("tile_offsets"), + py::arg("tile_gaussian_ids"), + py::arg("rendered_alphas"), + py::arg("last_ids"), + py::arg("d_loss_d_rendered_features"), + py::arg("d_loss_d_rendered_alphas"), + py::arg("backgrounds"), + py::arg("masks")); + + // 11. project_gaussians_analytic_jagged_fwd + m.def( + "project_gaussians_analytic_jagged_fwd", + [](const torch::Tensor &gSizes, + const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &scales, + const torch::Tensor &cSizes, + const torch::Tensor &worldToCamMatrices, + const torch::Tensor &projectionMatrices, + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool ortho) { + return ops::project_gaussians_analytic_jagged_fwd(gSizes, + means, + quats, + scales, + cSizes, + worldToCamMatrices, + projectionMatrices, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + ortho); + }, + py::arg("g_sizes"), + py::arg("means"), + py::arg("quats"), + py::arg("scales"), + py::arg("c_sizes"), + py::arg("world_to_cam_matrices"), + py::arg("projection_matrices"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("eps2d"), + py::arg("near"), + py::arg("far"), + py::arg("min_radius_2d"), + py::arg("ortho")); + + // 12. project_gaussians_analytic_jagged_bwd + m.def( + "project_gaussians_analytic_jagged_bwd", + [](const torch::Tensor &gSizes, + const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &scales, + const torch::Tensor &cSizes, + const torch::Tensor &worldToCamMatrices, + const torch::Tensor &projectionMatrices, + const uint32_t imageWidth, + const uint32_t imageHeight, + const float eps2d, + const torch::Tensor &radii, + const torch::Tensor &conics, + const torch::Tensor &dLossDMeans2d, + const torch::Tensor &dLossDDepths, + const torch::Tensor &dLossDConics, + const bool worldToCamMatricesRequiresGrad, + const bool ortho) { + return ops::project_gaussians_analytic_jagged_bwd(gSizes, + means, + quats, + scales, + cSizes, + worldToCamMatrices, + projectionMatrices, + imageWidth, + imageHeight, + eps2d, + radii, + conics, + dLossDMeans2d, + dLossDDepths, + dLossDConics, + worldToCamMatricesRequiresGrad, + ortho); + }, + py::arg("g_sizes"), + py::arg("means"), + py::arg("quats"), + py::arg("scales"), + py::arg("c_sizes"), + py::arg("world_to_cam_matrices"), + py::arg("projection_matrices"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("eps2d"), + py::arg("radii"), + py::arg("conics"), + py::arg("d_loss_d_means2d"), + py::arg("d_loss_d_depths"), + py::arg("d_loss_d_conics"), + py::arg("world_to_cam_matrices_requires_grad"), + py::arg("ortho")); + + // ------- Tile intersection (non-differentiable) ------- + + m.def( + "intersect_gaussian_tiles", + [](const torch::Tensor &means2d, + const torch::Tensor &radii, + const torch::Tensor &depths, + const uint32_t numCameras, + const uint32_t tileSize, + const uint32_t numTilesH, + const uint32_t numTilesW, + const at::optional &cameraIds) { + return ops::intersect_gaussian_tiles( + means2d, radii, depths, cameraIds, numCameras, tileSize, numTilesH, numTilesW); + }, + py::arg("means2d"), + py::arg("radii"), + py::arg("depths"), + py::arg("num_cameras"), + py::arg("tile_size"), + py::arg("num_tiles_h"), + py::arg("num_tiles_w"), + py::arg("camera_ids") = py::none()); + + // ------- Sparse tile intersection (non-differentiable) ------- + + m.def( + "intersect_gaussian_tiles_sparse", + [](const torch::Tensor &means2d, + const torch::Tensor &radii, + const torch::Tensor &depths, + const torch::Tensor &tileMask, + const torch::Tensor &activeTiles, + const uint32_t numCameras, + const uint32_t tileSize, + const uint32_t numTilesH, + const uint32_t numTilesW, + const at::optional &cameraIds) { + return ops::intersect_gaussian_tiles_sparse(means2d, + radii, + depths, + tileMask, + activeTiles, + cameraIds, + numCameras, + tileSize, + numTilesH, + numTilesW); + }, + py::arg("means2d"), + py::arg("radii"), + py::arg("depths"), + py::arg("tile_mask"), + py::arg("active_tiles"), + py::arg("num_cameras"), + py::arg("tile_size"), + py::arg("num_tiles_h"), + py::arg("num_tiles_w"), + py::arg("camera_ids") = py::none()); + + // ------- Sparse tile layout (non-differentiable) ------- + + m.def( + "build_sparse_gaussian_tile_layout", + [](const int32_t tileSideLength, + const int32_t numTilesW, + const int32_t numTilesH, + const fvdb::JaggedTensor &pixelsToRender) { + return ops::build_sparse_gaussian_tile_layout( + tileSideLength, numTilesW, numTilesH, pixelsToRender); + }, + py::arg("tile_side_length"), + py::arg("num_tiles_w"), + py::arg("num_tiles_h"), + py::arg("pixels_to_render")); + + // ------- UT projection forward (non-differentiable) ------- + + m.def( + "project_gaussians_ut_fwd", + [](const torch::Tensor &means, + const torch::Tensor &quats, + const torch::Tensor &logScales, + const torch::Tensor &worldToCamMatrices, + const torch::Tensor &projectionMatrices, + const torch::Tensor &distortionCoeffs, + const DistortionModel cameraModel, + const int64_t imageWidth, + const int64_t imageHeight, + const float eps2d, + const float nearPlane, + const float farPlane, + const float minRadius2d, + const bool calcCompensations) { + ops::UTParams utParams{}; + return ops::project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatrices, + worldToCamMatrices, + projectionMatrices, + ops::RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + calcCompensations); + }, + py::arg("means"), + py::arg("quats"), + py::arg("log_scales"), + py::arg("world_to_cam_matrices"), + py::arg("projection_matrices"), + py::arg("distortion_coeffs"), + py::arg("camera_model"), + py::arg("image_width"), + py::arg("image_height"), + py::arg("eps2d"), + py::arg("near"), + py::arg("far"), + py::arg("min_radius_2d"), + py::arg("calc_compensations")); +} diff --git a/src/python/ViewerBinding.cpp b/src/python/ViewerBinding.cpp index 504358fb7..18393aba1 100644 --- a/src/python/ViewerBinding.cpp +++ b/src/python/ViewerBinding.cpp @@ -7,7 +7,7 @@ #include "TypeCasters.h" -#include +#include #include #include #include @@ -83,15 +83,19 @@ bind_viewer(py::module &m) { py::arg("device_id"), py::arg("verbose"), "Create a new Viewer instance") - .def( - "add_gaussian_splat_3d_view", - &fvdb::detail::viewer::Viewer::addGaussianSplat3dView, - py::arg("scene_name"), - py::arg("name"), - py::arg("gaussian_splat_3d"), - py::return_value_policy::reference_internal, // preserve reference; tie lifetime to - // parent - "Register a Gaussian splat 3D view with the viewer (accepts Python or C++ GaussianSplat3d)") + .def("add_gaussian_splat_3d_view", + &fvdb::detail::viewer::Viewer::addGaussianSplat3dView, + py::arg("scene_name"), + py::arg("name"), + py::arg("means"), + py::arg("quats"), + py::arg("log_scales"), + py::arg("logit_opacities"), + py::arg("sh0"), + py::arg("shN"), + py::return_value_policy::reference_internal, // preserve reference; tie lifetime to + // parent + "Register a Gaussian splat 3D view with the viewer (accepts raw tensors)") .def("has_gaussian_splat_3d_view", &fvdb::detail::viewer::Viewer::hasGaussianSplat3dView, py::arg("name"), @@ -212,12 +216,12 @@ bind_viewer(py::module &m) { "set_camera_model", [](fvdb::detail::viewer::Viewer &viewer, const std::string &sceneName, - fvdb::GaussianSplat3d::CameraModel model) { - if (model != fvdb::GaussianSplat3d::CameraModel::PINHOLE && - model != fvdb::GaussianSplat3d::CameraModel::ORTHOGRAPHIC) { + fvdb::detail::ops::DistortionModel model) { + if (model != fvdb::detail::ops::DistortionModel::PINHOLE && + model != fvdb::detail::ops::DistortionModel::ORTHOGRAPHIC) { PyErr_SetString(PyExc_NotImplementedError, - "Viewer currently only supports CameraModel.PINHOLE and " - "CameraModel.ORTHOGRAPHIC"); + "Viewer currently only supports DistortionModel.PINHOLE and " + "DistortionModel.ORTHOGRAPHIC"); throw py::error_already_set(); } viewer.setCameraModel(sceneName, model); diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 545d45f56..0819bb739 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -142,7 +142,6 @@ ConfigureTest(ExampleTest "ExampleTest.cpp") ConfigureTest(JaggedTensorTest "JaggedTensorTest.cpp") ConfigureTest(PackedJaggedAccessorTest "PackedJaggedAccessorTest.cu") ConfigureTest(GaussianComputeSparseInfoTest "GaussianComputeSparseInfoTest.cpp") -ConfigureTest(DeduplicatePixelsTest "DeduplicatePixelsTest.cpp") ConfigureTest(GaussianTileIntersectionTest "GaussianTileIntersectionTest.cpp") ConfigureTest(GaussianComputeNanInfMaskTest "GaussianComputeNanInfMaskTest.cpp") ConfigureTest(GaussianRasterizeBackwardTest "GaussianRasterizeBackwardTest.cpp") @@ -152,7 +151,7 @@ ConfigureTest(GaussianSphericalHarmonicsBackwardTest "GaussianSphericalHarmonics ConfigureTest(GaussianProjectionForwardTest "GaussianProjectionForwardTest.cpp") ConfigureTest(GaussianProjectionBackwardTest "GaussianProjectionBackwardTest.cpp") ConfigureTest(GaussianProjectionUTTest "GaussianProjectionUTTest.cpp") -ConfigureTest(GaussianSplat3dCameraApiTest "GaussianSplat3dCameraApiTest.cpp") + ConfigureTest(GaussianCamerasTest "GaussianCamerasTest.cu") ConfigureTest(GaussianUtilsTest "GaussianUtilsTest.cu") ConfigureTest(GaussianRasterizeTopContributorsTest "GaussianRasterizeTopContributorsTest.cpp") diff --git a/src/tests/DeduplicatePixelsTest.cpp b/src/tests/DeduplicatePixelsTest.cpp deleted file mode 100644 index 7c36cdf24..000000000 --- a/src/tests/DeduplicatePixelsTest.cpp +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 - -#include "utils/Tensor.h" - -#include - -#include - -#include -#include -#include - -// Forward-declare the function under test (defined in GaussianSplat3d.cpp with external linkage). -namespace fvdb { -std::tuple -deduplicatePixels(const JaggedTensor &pixelsToRender, int64_t imageWidth, int64_t imageHeight); -} // namespace fvdb - -using fvdb::test::tensorOpts; - -static constexpr int64_t kImageWidth = 64; -static constexpr int64_t kImageHeight = 64; - -template struct DeduplicatePixelsTest : public ::testing::Test {}; - -using CoordTypes = ::testing::Types; -TYPED_TEST_SUITE(DeduplicatePixelsTest, CoordTypes); - -TYPED_TEST(DeduplicatePixelsTest, Empty) { - auto opts = tensorOpts(torch::kCUDA); - auto pixels = fvdb::JaggedTensor(torch::empty({0, 2}, opts)); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_FALSE(hasDuplicates); - EXPECT_EQ(inverseIndices.size(0), 0); - EXPECT_EQ(uniquePixels.rsize(0), 0); -} - -TYPED_TEST(DeduplicatePixelsTest, SinglePixel) { - auto opts = tensorOpts(torch::kCPU); - auto coords = torch::tensor({{5, 10}}, opts); - auto pixels = fvdb::JaggedTensor(std::vector{coords}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_FALSE(hasDuplicates); - EXPECT_EQ(uniquePixels.rsize(0), 1); -} - -TYPED_TEST(DeduplicatePixelsTest, AllUnique) { - auto opts = tensorOpts(torch::kCPU); - auto coords = torch::tensor({{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 3}}, opts); - auto pixels = fvdb::JaggedTensor(std::vector{coords}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_FALSE(hasDuplicates); - EXPECT_EQ(uniquePixels.rsize(0), 5); - // inverseIndices should be a permutation of [0..4] (identity if no dups) - EXPECT_EQ(inverseIndices.size(0), 5); -} - -TYPED_TEST(DeduplicatePixelsTest, SomeDuplicates) { - auto opts = tensorOpts(torch::kCPU); - // (0,0) appears at index 0 and 2 - auto coords = torch::tensor({{0, 0}, {1, 1}, {0, 0}, {2, 2}}, opts); - auto pixels = fvdb::JaggedTensor(std::vector{coords}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_TRUE(hasDuplicates); - EXPECT_EQ(uniquePixels.rsize(0), 3); - EXPECT_EQ(inverseIndices.size(0), 4); - - // Indices 0 and 2 in the original both map to the same unique index - auto inv = inverseIndices.cpu(); - EXPECT_EQ(inv[0].template item(), inv[2].template item()); - // Indices 1 and 3 map to different unique indices - EXPECT_NE(inv[1].template item(), inv[3].template item()); -} - -TYPED_TEST(DeduplicatePixelsTest, AllSamePixel) { - auto opts = tensorOpts(torch::kCPU); - auto coords = torch::tensor({{5, 5}, {5, 5}, {5, 5}, {5, 5}}, opts); - auto pixels = fvdb::JaggedTensor(std::vector{coords}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_TRUE(hasDuplicates); - EXPECT_EQ(uniquePixels.rsize(0), 1); - EXPECT_EQ(inverseIndices.size(0), 4); - - // All inverse indices should be 0 - auto inv = inverseIndices.cpu(); - for (int i = 0; i < 4; i++) { - EXPECT_EQ(inv[i].template item(), 0); - } -} - -TYPED_TEST(DeduplicatePixelsTest, MultiBatchNoDuplicates) { - auto opts = tensorOpts(torch::kCPU); - // Same (0,0) in different batches should NOT be considered duplicates - auto batch0 = torch::tensor({{0, 0}, {1, 1}}, opts); - auto batch1 = torch::tensor({{0, 0}, {2, 2}}, opts); - auto pixels = fvdb::JaggedTensor(std::vector{batch0, batch1}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_FALSE(hasDuplicates); - EXPECT_EQ(uniquePixels.rsize(0), 4); - EXPECT_EQ(uniquePixels.num_outer_lists(), 2); -} - -TYPED_TEST(DeduplicatePixelsTest, MultiBatchWithDuplicates) { - auto opts = tensorOpts(torch::kCPU); - // Batch 0: (0,0) duplicated; Batch 1: (0,0) alone (not a dup of batch 0's) - auto batch0 = torch::tensor({{0, 0}, {1, 1}, {0, 0}}, opts); - auto batch1 = torch::tensor({{0, 0}, {3, 3}}, opts); - auto pixels = fvdb::JaggedTensor(std::vector{batch0, batch1}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_TRUE(hasDuplicates); - EXPECT_EQ(uniquePixels.num_outer_lists(), 2); - // Batch 0: 2 unique pixels (0,0) and (1,1); Batch 1: 2 unique pixels (0,0) and (3,3) - EXPECT_EQ(uniquePixels.rsize(0), 4); - EXPECT_EQ(inverseIndices.size(0), 5); - - // Original indices 0 and 2 (both batch 0, pixel (0,0)) should map to same unique - auto inv = inverseIndices.cpu(); - EXPECT_EQ(inv[0].template item(), inv[2].template item()); -} - -TYPED_TEST(DeduplicatePixelsTest, MultiBatchAllSamePixel) { - auto opts = tensorOpts(torch::kCPU); - auto batch0 = torch::tensor({{1, 1}, {1, 1}, {1, 1}}, opts); - auto batch1 = torch::tensor({{2, 2}, {2, 2}}, opts); - auto pixels = fvdb::JaggedTensor(std::vector{batch0, batch1}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_TRUE(hasDuplicates); - EXPECT_EQ(uniquePixels.num_outer_lists(), 2); - EXPECT_EQ(uniquePixels.rsize(0), 2); - - auto offsets = uniquePixels.joffsets().cpu(); - EXPECT_EQ(offsets[0].template item(), 0); - EXPECT_EQ(offsets[1].template item(), 1); - EXPECT_EQ(offsets[2].template item(), 2); - - auto inv = inverseIndices.cpu(); - EXPECT_EQ(inv[0].template item(), inv[1].template item()); - EXPECT_EQ(inv[0].template item(), inv[2].template item()); - EXPECT_EQ(inv[3].template item(), inv[4].template item()); - EXPECT_NE(inv[0].template item(), inv[3].template item()); -} - -TYPED_TEST(DeduplicatePixelsTest, RoundTripSomeDuplicates) { - auto opts = tensorOpts(torch::kCPU); - auto coords = torch::tensor({{3, 7}, {1, 2}, {3, 7}, {5, 5}, {1, 2}, {9, 0}}, opts); - auto pixels = fvdb::JaggedTensor(std::vector{coords}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_TRUE(hasDuplicates); - EXPECT_EQ(uniquePixels.rsize(0), 4); - - // Round-trip: indexing unique pixels by inverseIndices should reconstruct the original - auto reconstructed = uniquePixels.jdata().index_select(0, inverseIndices); - EXPECT_TRUE(torch::equal(reconstructed.cpu(), coords.to(reconstructed.dtype()))); -} - -TYPED_TEST(DeduplicatePixelsTest, RoundTripMultiBatch) { - auto opts = tensorOpts(torch::kCPU); - auto batch0 = torch::tensor({{2, 3}, {4, 5}, {2, 3}}, opts); - auto batch1 = torch::tensor({{6, 7}, {6, 7}, {8, 9}}, opts); - auto pixels = fvdb::JaggedTensor(std::vector{batch0, batch1}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_TRUE(hasDuplicates); - - auto originalJdata = pixels.jdata(); - auto reconstructed = uniquePixels.jdata().index_select(0, inverseIndices); - EXPECT_TRUE(torch::equal(reconstructed.cpu(), originalJdata.cpu().to(reconstructed.dtype()))); -} - -TYPED_TEST(DeduplicatePixelsTest, JaggedTensorOffsets) { - auto opts = tensorOpts(torch::kCPU); - auto batch0 = torch::tensor({{0, 0}, {0, 0}, {1, 1}}, opts); // 3 pixels, 2 unique - auto batch1 = torch::tensor({{2, 2}}, opts); // 1 pixel, 1 unique - auto batch2 = torch::tensor({{3, 3}, {4, 4}, {3, 3}, {4, 4}}, opts); // 4 pixels, 2 unique - auto pixels = - fvdb::JaggedTensor(std::vector{batch0, batch1, batch2}).to(torch::kCUDA); - - auto [uniquePixels, inverseIndices, hasDuplicates] = - fvdb::deduplicatePixels(pixels, kImageWidth, kImageHeight); - - EXPECT_TRUE(hasDuplicates); - EXPECT_EQ(uniquePixels.num_outer_lists(), 3); - EXPECT_EQ(uniquePixels.rsize(0), 5); // 2 + 1 + 2 - - // Verify per-batch counts via offsets - auto offsets = uniquePixels.joffsets().cpu(); - EXPECT_EQ(offsets[0].template item(), 0); - EXPECT_EQ(offsets[1].template item(), 2); // batch 0: 2 unique - EXPECT_EQ(offsets[2].template item(), 3); // batch 1: 1 unique - EXPECT_EQ(offsets[3].template item(), 5); // batch 2: 2 unique -} diff --git a/src/tests/GaussianCamerasTest.cu b/src/tests/GaussianCamerasTest.cu index dffd3efd6..9fce0c1b0 100644 --- a/src/tests/GaussianCamerasTest.cu +++ b/src/tests/GaussianCamerasTest.cu @@ -1,7 +1,7 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 -#include +#include #include #include @@ -327,8 +327,7 @@ TEST(GaussianCamerasTest, PerspectiveEncapsulatedProjectionAndVJPMatchReferenceP auto meansWorld = makeMeansWorld(N); auto covarsWorld6 = makeCovarsWorld6(N); - auto camera = PerspectiveCamera( - projection, worldToCam, static_cast(C), 640, 480, 0.01f, 1.0e8f); + auto camera = PerspectiveCamera(projection, worldToCam, 640, 480, 0.01f, 1.0e8f); runForwardAndVjpParityChecks(camera, C, N, meansWorld, covarsWorld6); } @@ -340,8 +339,7 @@ TEST(GaussianCamerasTest, OrthographicEncapsulatedProjectionAndVJPMatchReference auto meansWorld = makeMeansWorld(N); auto covarsWorld6 = makeCovarsWorld6(N); - auto camera = OrthographicCamera( - projection, worldToCam, static_cast(C), 640, 480, -1.0e8f, 1.0e8f); + auto camera = OrthographicCamera(projection, worldToCam, 640, 480, -1.0e8f, 1.0e8f); runForwardAndVjpParityChecks(camera, C, N, meansWorld, covarsWorld6); } diff --git a/src/tests/GaussianComputeNanInfMaskTest.cpp b/src/tests/GaussianComputeNanInfMaskTest.cpp index ad3ab8229..cbfd578f6 100644 --- a/src/tests/GaussianComputeNanInfMaskTest.cpp +++ b/src/tests/GaussianComputeNanInfMaskTest.cpp @@ -3,7 +3,7 @@ #include "utils/Tensor.h" -#include +#include #include @@ -27,8 +27,8 @@ TEST(NanInfMaskTests, TestEmptyGaussians) { auto const sh0 = torch::rand({numGaussians, 1, 3}, floatOptsCUDA); auto const shN = torch::rand({numGaussians, 26, 3}, floatOptsCUDA); - auto mask = fvdb::detail::ops::dispatchGaussianNanInfMask( - means, quats, scales, opacities, sh0, shN); + auto mask = + fvdb::detail::ops::compute_gaussian_nan_inf_mask(means, quats, scales, opacities, sh0, shN); EXPECT_TRUE(mask.jdata().numel() == 0); EXPECT_TRUE(mask.jdata().is_cuda()); @@ -54,7 +54,7 @@ TEST(NanInfMaskTests, TestExceptionForInconsistentGaussians) { auto const sh0 = torch::rand({config[4], 1, 3}, floatOptsCUDA); auto const shN = torch::rand({config[5], 26, 3}, floatOptsCUDA); - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianNanInfMask( + EXPECT_THROW(fvdb::detail::ops::compute_gaussian_nan_inf_mask( means, quats, scales, opacities, sh0, shN), c10::ValueError); } @@ -136,7 +136,7 @@ TEST_P(NanInfMaskTestFixture, TestNanInfMaskMeansNan) { auto const sh0JTData = sh0JT.jdata(); // [N, 1, 3] auto const shNJTData = shNJT.jdata(); // [N, 26, 3] - auto const mask = fvdb::detail::ops::dispatchGaussianNanInfMask( + auto const mask = fvdb::detail::ops::compute_gaussian_nan_inf_mask( meansJT, quatsJT, scalesJT, opacitiesJT, sh0JTData, shNJTData); EXPECT_TRUE(torch::equal(expectedMask.to(torch::kCUDA), mask.jdata())); diff --git a/src/tests/GaussianComputeSparseInfoTest.cpp b/src/tests/GaussianComputeSparseInfoTest.cpp index 38d4ecc7c..d2a149736 100644 --- a/src/tests/GaussianComputeSparseInfoTest.cpp +++ b/src/tests/GaussianComputeSparseInfoTest.cpp @@ -5,7 +5,7 @@ #include "utils/TestUtilities.h" #include "utils/TileBitMask.h" -#include +#include #include @@ -16,7 +16,7 @@ using fvdb::test::tensorOpts; using fvdb::test::TileBitMask; -// Helper function to calculate the expected tensors for computeSparseInfo: +// Helper function to calculate the expected tensors for build_sparse_gaussian_tile_layout: // 1. activeTiles: A 1D tensor of tile ids that have at least one active pixel // 2. tileBitMasks: A 2D tensor of tile bitmasks of shape {numActiveTiles, numWordsPerTile} // 3. pixelsPerTile: A 1D tensor of the inclusive cumulative sum of the number of active pixels in @@ -196,7 +196,7 @@ template struct ComputeSparseInfo : public ::testing::Test auto uvs = uvsCPU.to(torch::kCUDA); auto [activeTiles, activeTileMask, tileBitMasks, tilePixelOffsets, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( this->mTileSize, this->mNumTilesPerAxis, this->mNumTilesPerAxis, uvs); auto [expectedActiveTiles, @@ -235,13 +235,14 @@ TYPED_TEST_SUITE(BadTypeTest, BadCoordTypes); TYPED_TEST(BadTypeTest, GPUThrows) { auto const emptyPixels = fvdb::JaggedTensor{torch::empty({0, 0}, tensorOpts())}; - EXPECT_THROW(fvdb::detail::ops::computeSparseInfo(16, 4, 4, emptyPixels), c10::TypeError); + EXPECT_THROW(fvdb::detail::ops::build_sparse_gaussian_tile_layout(16, 4, 4, emptyPixels), + c10::TypeError); } TEST(BadTypeTest, CPUThrows) { auto const emptyPixels = fvdb::JaggedTensor{torch::empty({0, 0}, tensorOpts(torch::kCPU))}; - EXPECT_THROW(fvdb::detail::ops::computeSparseInfo(16, 4, 4, emptyPixels), + EXPECT_THROW(fvdb::detail::ops::build_sparse_gaussian_tile_layout(16, 4, 4, emptyPixels), c10::NotImplementedError); } @@ -251,7 +252,7 @@ TYPED_TEST(ComputeSparseInfo, Empty) { auto const emptyPixels = fvdb::JaggedTensor(torch::empty({0, 0}, opts)); auto [activeTiles, activeTileMask, tileBitMask, tilePixelOffsets, pixelMap] = - fvdb::detail::ops::computeSparseInfo(this->mTileSize, 4, 4, emptyPixels); + fvdb::detail::ops::build_sparse_gaussian_tile_layout(this->mTileSize, 4, 4, emptyPixels); EXPECT_TRUE( torch::equal(activeTiles, torch::empty({0}, tensorOpts(torch::kCUDA)))); @@ -353,7 +354,7 @@ TYPED_TEST(ComputeSparseInfo, SinglePixel) { auto uniqueUVs = fvdb::JaggedTensor(std::vector{uvsCPU}).to(torch::kCUDA); auto [activeTiles, activeTileMask, tileBitMasks, tilePixelOffsets, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( this->mTileSize, this->mNumTilesPerAxis, this->mNumTilesPerAxis, uniqueUVs); EXPECT_EQ(activeTiles.size(0), 1); @@ -370,7 +371,7 @@ TYPED_TEST(ComputeSparseInfo, TwoPixelsSameTile) { auto uvs = fvdb::JaggedTensor(std::vector{uvsCPU}).to(torch::kCUDA); auto [activeTiles, activeTileMask, tileBitMasks, tilePixelOffsets, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( this->mTileSize, this->mNumTilesPerAxis, this->mNumTilesPerAxis, uvs); EXPECT_EQ(activeTiles.size(0), 1); @@ -384,7 +385,7 @@ TYPED_TEST(ComputeSparseInfo, MultiImageMultiTile) { auto uniqueUVs = uniqueUVsCPU.to(torch::kCUDA); auto [activeTiles, activeTileMask, tileBitMasks, tilePixelOffsets, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( this->mTileSize, this->mNumTilesPerAxis, this->mNumTilesPerAxis, uniqueUVs); auto [expectedActiveTiles, diff --git a/src/tests/GaussianMCMCAddNoiseTest.cpp b/src/tests/GaussianMCMCAddNoiseTest.cpp index ba556a5db..ac830b31b 100644 --- a/src/tests/GaussianMCMCAddNoiseTest.cpp +++ b/src/tests/GaussianMCMCAddNoiseTest.cpp @@ -3,7 +3,7 @@ #include "utils/Tensor.h" -#include +#include #include #include @@ -36,7 +36,7 @@ class GaussianMCMCAddNoiseTest : public ::testing::Test { } // Save the current CUDA RNG state so we can reproduce the baseNoise that - // dispatchGaussianMCMCAddNoise draws internally. + // add_noise_to_gaussian_means draws internally. torch::Tensor saveCudaGeneratorState() { auto gen = at::cuda::detail::getDefaultCUDAGenerator(); @@ -63,7 +63,7 @@ TEST_F(GaussianMCMCAddNoiseTest, AppliesNoiseWithDeterministicBaseNoise) { const auto rngState = saveCudaGeneratorState(); auto meansBaseline = means.clone(); - fvdb::detail::ops::dispatchGaussianMCMCAddNoise( + fvdb::detail::ops::add_noise_to_gaussian_means( means, logScales, logitOpacities, quats, noiseScale, 0.005, 100.0); restoreCudaGeneratorState(rngState); @@ -90,7 +90,7 @@ TEST_F(GaussianMCMCAddNoiseTest, RespectsAnisotropicScales) { const auto rngState = saveCudaGeneratorState(); - fvdb::detail::ops::dispatchGaussianMCMCAddNoise( + fvdb::detail::ops::add_noise_to_gaussian_means( means, logScales, logitOpacities, quats, noiseScale, 0.005, 100); restoreCudaGeneratorState(rngState); @@ -114,7 +114,7 @@ TEST_F(GaussianMCMCAddNoiseTest, HighOpacitySuppressesNoise) { .contiguous(); constexpr float noiseScale = 1.0f; - fvdb::detail::ops::dispatchGaussianMCMCAddNoise( + fvdb::detail::ops::add_noise_to_gaussian_means( means, logScales, logitOpacities, quats, noiseScale, 0.005, 100.0); // Gate approaches zero when opacity ~1; expect negligible movement. @@ -132,7 +132,7 @@ TEST_F(GaussianMCMCAddNoiseTest, ZeroNoiseScaleNoOp) { {{1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}}, floatOpts()); - fvdb::detail::ops::dispatchGaussianMCMCAddNoise( + fvdb::detail::ops::add_noise_to_gaussian_means( means, logScales, logitOpacities, quats, /*noiseScale=*/0.0f, 0.005, 100.0); EXPECT_TRUE(torch::allclose(means, origMeans)); @@ -145,7 +145,7 @@ TEST_F(GaussianMCMCAddNoiseTest, CpuNotImplemented) { const auto quats = torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}}, fvdb::test::tensorOpts(torch::kCPU)); - EXPECT_THROW((fvdb::detail::ops::dispatchGaussianMCMCAddNoise( + EXPECT_THROW((fvdb::detail::ops::add_noise_to_gaussian_means( means, logScales, logitOpacities, quats, 1.0f, 0.005, 100)), c10::Error); } diff --git a/src/tests/GaussianMCMCRelocationTest.cpp b/src/tests/GaussianMCMCRelocationTest.cpp index bec4ce731..1f0dc4647 100644 --- a/src/tests/GaussianMCMCRelocationTest.cpp +++ b/src/tests/GaussianMCMCRelocationTest.cpp @@ -3,7 +3,7 @@ #include "utils/Tensor.h" -#include +#include #include @@ -100,9 +100,8 @@ class GaussianRelocationTest : public ::testing::Test { auto const binomialCoeffsCPU = buildBinomialCoeffsCPU(nMax); auto const binomialCoeffs = binomialCoeffsCPU.to(logScales.device()); - const auto [gpuLogitOpacitiesNew, gpuLogScalesNew] = - fvdb::detail::ops::dispatchGaussianRelocation( - logScales, logitOpacities, ratios, binomialCoeffs, nMax, mMinOpacity); + const auto [gpuLogitOpacitiesNew, gpuLogScalesNew] = fvdb::detail::ops::relocate_gaussians( + logScales, logitOpacities, ratios, binomialCoeffs, nMax, mMinOpacity); const auto [refLogitNew, refLogScalesNew] = referenceRelocation( logScales.cpu(), logitOpacities.cpu(), ratios.cpu(), binomialCoeffsCPU, mMinOpacity); @@ -160,31 +159,31 @@ TEST_F(GaussianRelocationTest, ValidatesInputs) { auto binomialCoeffs = binomialCoeffsCPU.to(torch::kCUDA); // binomialCoeffs on CPU - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianRelocation( + EXPECT_THROW(fvdb::detail::ops::relocate_gaussians( logScales, logitOpacities, ratios, binomialCoeffsCPU, nMax, mMinOpacity), c10::Error); // binomialCoeffs wrong shape auto badBinomShape = binomialCoeffs.slice(/*dim=*/0, 0, nMax - 1); - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianRelocation( + EXPECT_THROW(fvdb::detail::ops::relocate_gaussians( logScales, logitOpacities, ratios, badBinomShape, nMax, mMinOpacity), c10::Error); // ratios wrong dtype auto ratiosLong = ratios.to(torch::kInt64); - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianRelocation( + EXPECT_THROW(fvdb::detail::ops::relocate_gaussians( logScales, logitOpacities, ratiosLong, binomialCoeffs, nMax, mMinOpacity), c10::Error); // opacities on CPU EXPECT_THROW( - fvdb::detail::ops::dispatchGaussianRelocation( + fvdb::detail::ops::relocate_gaussians( logScales.cpu(), logitOpacities.cpu(), ratios, binomialCoeffs, nMax, mMinOpacity), c10::Error); // scales wrong shape auto logScalesBad = logScales.view({2, 3, 1}); - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianRelocation( + EXPECT_THROW(fvdb::detail::ops::relocate_gaussians( logScalesBad, logitOpacities, ratios, binomialCoeffs, nMax, mMinOpacity), c10::Error); } @@ -201,7 +200,7 @@ TEST_F(GaussianRelocationTest, CpuNotImplemented) { auto ratios = torch::tensor({1, 2}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt32)); - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianRelocation( + EXPECT_THROW(fvdb::detail::ops::relocate_gaussians( logScales, logitOpacities, ratios, binomialCoeffsCPU, nMax, mMinOpacity), c10::Error); } diff --git a/src/tests/GaussianProjectionBackwardTest.cpp b/src/tests/GaussianProjectionBackwardTest.cpp index a0dfb8558..7a49ccbed 100644 --- a/src/tests/GaussianProjectionBackwardTest.cpp +++ b/src/tests/GaussianProjectionBackwardTest.cpp @@ -3,9 +3,9 @@ #include "utils/Tensor.h" -#include -#include -#include +#include +#include +#include #include #include @@ -173,19 +173,19 @@ TEST_F(GaussianProjectionBackwardTestFixture, DISABLED_GenerateOutputData) { { // Perspective projection const auto [radii_proj, means2d_proj, depths_proj, conics_proj, compensations_proj] = - fvdb::detail::ops::dispatchGaussianProjectionForward(means, - quats, - scales, - viewmats, - Ks, - imageWidth, - imageHeight, - 0.3, - 1e-2, - 1e10, - 0, - false, - false); + fvdb::detail::ops::project_gaussians_analytic_fwd(means, + quats, + scales, + viewmats, + Ks, + imageWidth, + imageHeight, + 0.3, + 1e-2, + 1e10, + 0, + false, + false); const auto C = radii_proj.size(0); const auto N = radii_proj.size(1); @@ -209,32 +209,31 @@ TEST_F(GaussianProjectionBackwardTestFixture, DISABLED_GenerateOutputData) { auto outGradientStepCounts = torch::zeros({N}, options.dtype(torch::kInt32)); const auto [dLossDMeans, dLossDCovars, dLossDQuats, dLossDScales, dLossDCamToWorlds] = - fvdb::detail::ops::dispatchGaussianProjectionBackward( - means, - quats, - torch::log(scales), - viewmats, - Ks, - compensations_proj, - imageWidth, - imageHeight, - eps2d, - radii_proj, - conics_proj, - dLossDMeans2d, - dLossDDepths, - dLossDConics, - dLossDCompensations, - true, - false, - outNormalizeddLossdMeans2dNormAccum, - outNormalizedMaxRadiiAccum, - outGradientStepCounts); + fvdb::detail::ops::project_gaussians_analytic_bwd(means, + quats, + torch::log(scales), + viewmats, + Ks, + compensations_proj, + imageWidth, + imageHeight, + eps2d, + radii_proj, + conics_proj, + dLossDMeans2d, + dLossDDepths, + dLossDConics, + dLossDCompensations, + true, + false, + outNormalizeddLossdMeans2dNormAccum, + outNormalizedMaxRadiiAccum, + outGradientStepCounts); std::vector outputData = { dLossDMeans, // dLossDCovars, Currently dLossDCovars is not output, not exposed, see - // dispatchGaussianProjectionBackward + // project_gaussians_analytic_bwd dLossDQuats, dLossDScales, dLossDCamToWorlds, @@ -249,19 +248,19 @@ TEST_F(GaussianProjectionBackwardTestFixture, DISABLED_GenerateOutputData) { { // Orthographic projection const auto [radii_proj, means2d_proj, depths_proj, conics_proj, compensations_proj] = - fvdb::detail::ops::dispatchGaussianProjectionForward(means, - quats, - scales, - viewmats, - Ks, - imageWidth, - imageHeight, - 0.3, - 1e-2, - 1e10, - 0, - false, - true); + fvdb::detail::ops::project_gaussians_analytic_fwd(means, + quats, + scales, + viewmats, + Ks, + imageWidth, + imageHeight, + 0.3, + 1e-2, + 1e10, + 0, + false, + true); const auto C = radii_proj.size(0); const auto N = radii_proj.size(1); @@ -285,32 +284,31 @@ TEST_F(GaussianProjectionBackwardTestFixture, DISABLED_GenerateOutputData) { auto outGradientStepCounts = torch::zeros({N}, options.dtype(torch::kInt32)); const auto [dLossDMeans, dLossDCovars, dLossDQuats, dLossDScales, dLossDCamToWorlds] = - fvdb::detail::ops::dispatchGaussianProjectionBackward( - means, - quats, - torch::log(scales), - viewmats, - Ks, - compensations_proj, - imageWidth, - imageHeight, - eps2d, - radii_proj, - conics_proj, - dLossDMeans2d, - dLossDDepths, - dLossDConics, - dLossDCompensations, - true, - true, - outNormalizeddLossdMeans2dNormAccum, - outNormalizedMaxRadiiAccum, - outGradientStepCounts); + fvdb::detail::ops::project_gaussians_analytic_bwd(means, + quats, + torch::log(scales), + viewmats, + Ks, + compensations_proj, + imageWidth, + imageHeight, + eps2d, + radii_proj, + conics_proj, + dLossDMeans2d, + dLossDDepths, + dLossDConics, + dLossDCompensations, + true, + true, + outNormalizeddLossdMeans2dNormAccum, + outNormalizedMaxRadiiAccum, + outGradientStepCounts); std::vector outputData = { dLossDMeans, // dLossDCovars, Currently dLossDCovars is not output, not exposed, see - // dispatchGaussianProjectionBackward + // project_gaussians_analytic_bwd dLossDQuats, dLossDScales, dLossDCamToWorlds, @@ -342,27 +340,26 @@ TEST_F(GaussianProjectionBackwardTestFixture, TestPerspectiveProjection) { auto outGradientStepCounts = torch::zeros({N}, options.dtype(torch::kInt32)); const auto [dLossDMeans, dLossDCovars, dLossDQuats, dLossDScales, dLossDCamToWorlds] = - fvdb::detail::ops::dispatchGaussianProjectionBackward( - means, - quats, - torch::log(scales), - viewmats, - Ks, - compensations, - imageWidth, - imageHeight, - eps2d, - radii, - conics, - dLossDMeans2d, - dLossDDepths, - dLossDConics, - dLossDCompensations, - true, - false, - outNormalizeddLossdMeans2dNormAccum, - outNormalizedMaxRadiiAccum, - outGradientStepCounts); + fvdb::detail::ops::project_gaussians_analytic_bwd(means, + quats, + torch::log(scales), + viewmats, + Ks, + compensations, + imageWidth, + imageHeight, + eps2d, + radii, + conics, + dLossDMeans2d, + dLossDDepths, + dLossDConics, + dLossDCompensations, + true, + false, + outNormalizeddLossdMeans2dNormAccum, + outNormalizedMaxRadiiAccum, + outGradientStepCounts); auto [rtol, atol] = tolerances(); #if 0 @@ -420,27 +417,26 @@ TEST_F(GaussianProjectionBackwardTestFixture, TestOrthographicProjection) { auto outGradientStepCounts = torch::zeros({N}, options.dtype(torch::kInt32)); const auto [dLossDMeans, dLossDCovars, dLossDQuats, dLossDScales, dLossDCamToWorlds] = - fvdb::detail::ops::dispatchGaussianProjectionBackward( - means, - quats, - torch::log(scales), - viewmats, - Ks, - compensations, - imageWidth, - imageHeight, - eps2d, - radii, - conics, - dLossDMeans2d, - dLossDDepths, - dLossDConics, - dLossDCompensations, - true, - true, - outNormalizeddLossdMeans2dNormAccum, - outNormalizedMaxRadiiAccum, - outGradientStepCounts); + fvdb::detail::ops::project_gaussians_analytic_bwd(means, + quats, + torch::log(scales), + viewmats, + Ks, + compensations, + imageWidth, + imageHeight, + eps2d, + radii, + conics, + dLossDMeans2d, + dLossDDepths, + dLossDConics, + dLossDCompensations, + true, + true, + outNormalizeddLossdMeans2dNormAccum, + outNormalizedMaxRadiiAccum, + outGradientStepCounts); auto tol2 = tolerances(); #if 0 diff --git a/src/tests/GaussianProjectionForwardTest.cpp b/src/tests/GaussianProjectionForwardTest.cpp index cb5d41416..a2bc1a6fa 100644 --- a/src/tests/GaussianProjectionForwardTest.cpp +++ b/src/tests/GaussianProjectionForwardTest.cpp @@ -3,8 +3,8 @@ #include "utils/Tensor.h" -#include -#include +#include +#include #include #include @@ -118,19 +118,19 @@ TEST_F(GaussianProjectionForwardTestFixture, DISABLED_GenerateOutputData) { { // Perspective projection const auto [radii, means2d, depths, conics, compensations] = - fvdb::detail::ops::dispatchGaussianProjectionForward(means, - quats, - torch::log(scales), - viewmats, - Ks, - imageWidth, - imageHeight, - 0.3, - 1e-2, - 1e10, - 0, - false, - false); + fvdb::detail::ops::project_gaussians_analytic_fwd(means, + quats, + torch::log(scales), + viewmats, + Ks, + imageWidth, + imageHeight, + 0.3, + 1e-2, + 1e10, + 0, + false, + false); std::vector outputData = {radii, means2d, depths, conics}; @@ -142,19 +142,19 @@ TEST_F(GaussianProjectionForwardTestFixture, DISABLED_GenerateOutputData) { { // Orthographic projection const auto [radii, means2d, depths, conics, compensations] = - fvdb::detail::ops::dispatchGaussianProjectionForward(means, - quats, - torch::log(scales), - viewmats, - Ks, - imageWidth, - imageHeight, - 0.3, - 1e-2, - 1e10, - 0, - false, - true); + fvdb::detail::ops::project_gaussians_analytic_fwd(means, + quats, + torch::log(scales), + viewmats, + Ks, + imageWidth, + imageHeight, + 0.3, + 1e-2, + 1e10, + 0, + false, + true); std::vector outputData = {radii, means2d, depths, conics}; @@ -168,19 +168,19 @@ TEST_F(GaussianProjectionForwardTestFixture, TestPerspectiveProjection) { loadTestData("projection_forward_inputs.pt", "projection_persp_forward_outputs.pt"); const auto [radii, means2d, depths, conics, compensations] = - fvdb::detail::ops::dispatchGaussianProjectionForward(means, - quats, - torch::log(scales), - viewmats, - Ks, - imageWidth, - imageHeight, - 0.3, - 1e-2, - 1e10, - 0, - false, - false); + fvdb::detail::ops::project_gaussians_analytic_fwd(means, + quats, + torch::log(scales), + viewmats, + Ks, + imageWidth, + imageHeight, + 0.3, + 1e-2, + 1e10, + 0, + false, + false); // Use relaxed tolerances to account for minor numerical differences between debug and release // builds. The default rtol=1e-5, atol=1e-8 are too strict for operations involving exp, sqrt, @@ -205,19 +205,19 @@ TEST_F(GaussianProjectionForwardTestFixture, TestOrthographicProjection) { loadTestData("projection_forward_inputs.pt", "projection_ortho_forward_outputs.pt"); const auto [radii, means2d, depths, conics, compensations] = - fvdb::detail::ops::dispatchGaussianProjectionForward(means, - quats, - torch::log(scales), - viewmats, - Ks, - imageWidth, - imageHeight, - 0.3, - 1e-2, - 1e10, - 0, - false, - true); + fvdb::detail::ops::project_gaussians_analytic_fwd(means, + quats, + torch::log(scales), + viewmats, + Ks, + imageWidth, + imageHeight, + 0.3, + 1e-2, + 1e10, + 0, + false, + true); // other outputs are undefined where radii is zero auto radiiNonZeroMask = radii > 0; // [C, N] diff --git a/src/tests/GaussianProjectionUTTest.cpp b/src/tests/GaussianProjectionUTTest.cpp index 02b0cc214..38226f876 100644 --- a/src/tests/GaussianProjectionUTTest.cpp +++ b/src/tests/GaussianProjectionUTTest.cpp @@ -1,7 +1,7 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 -#include +#include #include #include @@ -173,23 +173,23 @@ TEST_F(GaussianProjectionUTTestFixture, CenteredGaussian_NoDistortion_AnalyticMe distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto means2d_cpu = means2d.cpu(); auto depths_cpu = depths.cpu(); @@ -277,23 +277,23 @@ TEST_F(GaussianProjectionUTTestFixture, NonlinearUTCovariance_ProducesFinitePosi distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); auto conics_cpu = conics.cpu(); @@ -356,23 +356,23 @@ TEST_F(GaussianProjectionUTTestFixture, UTParams_InvalidAlpha_ThrowsOnHost) { projectionMatrices = projectionMatrices.cuda(); distortionCoeffs = distortionCoeffs.cuda(); - EXPECT_THROW((dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false)), + EXPECT_THROW((project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false)), c10::Error); } @@ -419,23 +419,23 @@ TEST_F(GaussianProjectionUTTestFixture, UTParams_InvalidKappa_ThrowsOnHost) { projectionMatrices = projectionMatrices.cuda(); distortionCoeffs = distortionCoeffs.cuda(); - EXPECT_THROW((dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false)), + EXPECT_THROW((project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false)), c10::Error); } @@ -484,23 +484,23 @@ TEST_F(GaussianProjectionUTTestFixture, DepthNearCameraPlane_BelowZEps_HardRejec distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); EXPECT_EQ(radii_cpu[0][0].item(), 0); @@ -550,23 +550,23 @@ TEST_F(GaussianProjectionUTTestFixture, DepthNearCameraPlane_AboveZEps_Projects) distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); auto means2d_cpu = means2d.cpu(); @@ -619,23 +619,23 @@ TEST_F(GaussianProjectionUTTestFixture, Orthographic_NoDistortion_AnalyticMeanAn distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto means2d_cpu = means2d.cpu(); auto depths_cpu = depths.cpu(); @@ -695,23 +695,23 @@ TEST_F(GaussianProjectionUTTestFixture, OffAxisTinyGaussian_NoDistortion_MeanMat distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto means2d_cpu = means2d.cpu(); const float expected_u = fx * (x / z) + cx; @@ -790,23 +790,23 @@ TEST_F(GaussianProjectionUTTestFixture, MultiCamera_RadTanDistortion_PerCameraPa distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); auto means2d_cpu = means2d.cpu(); @@ -879,23 +879,23 @@ TEST_F(GaussianProjectionUTTestFixture, MultiCamera_Pinhole_ZeroCoeffTensor_PerC distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); auto means2d_cpu = means2d.cpu(); @@ -971,23 +971,23 @@ TEST_F(GaussianProjectionUTTestFixture, distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); auto means2d_cpu = means2d.cpu(); @@ -1064,23 +1064,23 @@ TEST_F(GaussianProjectionUTTestFixture, distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); auto means2d_cpu = means2d.cpu(); @@ -1165,23 +1165,23 @@ TEST_F(GaussianProjectionUTTestFixture, distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); auto means2d_cpu = means2d.cpu(); @@ -1261,23 +1261,23 @@ TEST_F(GaussianProjectionUTTestFixture, distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); auto means2d_cpu = means2d.cpu(); @@ -1329,23 +1329,23 @@ TEST_F(GaussianProjectionUTTestFixture, RadTanThinPrism_IgnoresK456EvenIfNonZero utParams = UTParams{}; const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); auto means2d_cpu = means2d.cpu(); @@ -1413,23 +1413,23 @@ TEST_F(GaussianProjectionUTTestFixture, distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); // When the UT kernel discards a Gaussian, only radii are defined to be 0; other outputs are // undefined (may contain garbage). Only assert radii here. @@ -1487,23 +1487,23 @@ TEST_F(GaussianProjectionUTTestFixture, distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto radii_cpu = radii.cpu(); EXPECT_GT(radii_cpu[0][0].item(), 0); @@ -1558,23 +1558,23 @@ TEST_F(GaussianProjectionUTTestFixture, RollingShutterNone_DepthUsesStartPoseNot distortionCoeffs = distortionCoeffs.cuda(); const auto [radii, means2d, depths, conics, compensations] = - dispatchGaussianProjectionForwardUT(means, - quats, - logScales, - worldToCamMatricesStart, - worldToCamMatricesEnd, - projectionMatrices, - RollingShutterType::NONE, - utParams, - cameraModel, - distortionCoeffs, - imageWidth, - imageHeight, - eps2d, - nearPlane, - farPlane, - minRadius2d, - false); + project_gaussians_ut_fwd(means, + quats, + logScales, + worldToCamMatricesStart, + worldToCamMatricesEnd, + projectionMatrices, + RollingShutterType::NONE, + utParams, + cameraModel, + distortionCoeffs, + imageWidth, + imageHeight, + eps2d, + nearPlane, + farPlane, + minRadius2d, + false); auto depths_cpu = depths.cpu(); // Start pose is identity, so depth should be exactly z (not z + 0.5). diff --git a/src/tests/GaussianRasterizeBackwardTest.cpp b/src/tests/GaussianRasterizeBackwardTest.cpp index b8472445c..e345d8f0e 100644 --- a/src/tests/GaussianRasterizeBackwardTest.cpp +++ b/src/tests/GaussianRasterizeBackwardTest.cpp @@ -3,10 +3,10 @@ #include "utils/Tensor.h" -#include -#include -#include -#include +#include +#include +#include +#include #include @@ -167,7 +167,8 @@ class GaussianTestHelper { auto pixelsToRender = fvdb::JaggedTensor(sparsePixelCoords); auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo(tileSize, numTilesW, numTilesH, pixelsToRender); + fvdb::detail::ops::build_sparse_gaussian_tile_layout( + tileSize, numTilesW, numTilesH, pixelsToRender); return {pixelsToRender, activeTiles, tilePixelMask, tilePixelCumsum, pixelMap}; } @@ -178,19 +179,18 @@ class GaussianTestHelper { int imageWidth, int imageHeight, int tileSize) { - auto [colors, alphas, lastIds] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - gaussians.means2d, - gaussians.conics, - gaussians.colors, - gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, - tileSize, - tiles.tileOffsets, - tiles.tileGaussianIds); + auto [colors, alphas, lastIds] = fvdb::detail::ops::rasterize_screen_space_gaussians_fwd( + gaussians.means2d, + gaussians.conics, + gaussians.colors, + gaussians.opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), + tileSize, + tiles.tileOffsets, + tiles.tileGaussianIds); return {colors, alphas, lastIds}; } @@ -202,16 +202,16 @@ class GaussianTestHelper { int imageHeight, int tileSize) { auto [colors, alphas, lastIds] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( sparse.pixelsToRender, gaussians.means2d, gaussians.conics, gaussians.colors, gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), tileSize, tiles.tileOffsets, tiles.tileGaussianIds, @@ -301,15 +301,15 @@ class GaussianTestHelper { int tileSize, int64_t numSharedChannelsOverride = -1) { auto [dLossDMeans2dAbs, dLossDMeans2d, dLossDConics, dLossDColors, dLossDOpacities] = - fvdb::detail::ops::dispatchGaussianRasterizeBackward( + fvdb::detail::ops::rasterize_screen_space_gaussians_bwd( gaussians.means2d, gaussians.conics, gaussians.colors, gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), tileSize, tiles.tileOffsets, tiles.tileGaussianIds, @@ -334,16 +334,16 @@ class GaussianTestHelper { int tileSize, int64_t numSharedChannelsOverride = -1) { auto [dLossDMeans2dAbs, dLossDMeans2d, dLossDConics, dLossDColors, dLossDOpacities] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeBackward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_bwd( sparse.pixelsToRender, gaussians.means2d, gaussians.conics, gaussians.colors, gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), tileSize, tiles.tileOffsets, tiles.tileGaussianIds, @@ -633,24 +633,24 @@ TEST_F(GaussianRasterizeTestFixture, TestChunkedChannels) { TEST_F(GaussianRasterizeTestFixture, CPUThrows) { loadTestData("rasterize_backward_inputs.pt", "rasterize_backward_outputs.pt"); moveToDevice(torch::kCPU); - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianRasterizeBackward( - means2d, - conics, - colors, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds, - renderedAlphas, - lastGaussianIdsPerPixel, - dLossDRenderedColors, - dLossDRenderedAlphas, - false), - c10::NotImplementedError); + EXPECT_THROW( + fvdb::detail::ops::rasterize_screen_space_gaussians_bwd(means2d, + conics, + colors, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds, + renderedAlphas, + lastGaussianIdsPerPixel, + dLossDRenderedColors, + dLossDRenderedAlphas, + false), + c10::NotImplementedError); } TEST_F(GaussianRasterizeTestFixture, TestSparseBackwardRasterization) { @@ -663,7 +663,7 @@ TEST_F(GaussianRasterizeTestFixture, TestSparseBackwardRasterization) { // Compute sparse info for the pixels to render auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( tileSize, tileOffsets.size(2), tileOffsets.size(1), pixelsToRender); // Step 1: Run forward dense on the same scene to get dense rendered output @@ -1097,34 +1097,32 @@ TEST_F(GaussianRasterizeTestFixture, TestDenseBackwardWithBackgrounds) { // Run forward pass WITH backgrounds auto [colorsWithBg, alphasWithBg, lastIdsWithBg] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - gaussians.means2d, - gaussians.conics, - gaussians.colors, - gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, - tileSize, - tiles.tileOffsets, - tiles.tileGaussianIds, - backgrounds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(gaussians.means2d, + gaussians.conics, + gaussians.colors, + gaussians.opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), + tileSize, + tiles.tileOffsets, + tiles.tileGaussianIds, + backgrounds); // Run forward pass WITHOUT backgrounds auto [colorsNoBg, alphasNoBg, lastIdsNoBg] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - gaussians.means2d, - gaussians.conics, - gaussians.colors, - gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, - tileSize, - tiles.tileOffsets, - tiles.tileGaussianIds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(gaussians.means2d, + gaussians.conics, + gaussians.colors, + gaussians.opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), + tileSize, + tiles.tileOffsets, + tiles.tileGaussianIds); // Alphas and last IDs should be identical EXPECT_TRUE(torch::allclose(alphasWithBg, alphasNoBg)); @@ -1140,25 +1138,24 @@ TEST_F(GaussianRasterizeTestFixture, TestDenseBackwardWithBackgrounds) { dLossDConicsWithBg, dLossDColorsWithBg, dLossDOpacitiesWithBg] = - fvdb::detail::ops::dispatchGaussianRasterizeBackward( - gaussians.means2d, - gaussians.conics, - gaussians.colors, - gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, - tileSize, - tiles.tileOffsets, - tiles.tileGaussianIds, - alphasWithBg, - lastIdsWithBg, - gradColors, - gradAlphas, - false, - -1, - backgrounds); + fvdb::detail::ops::rasterize_screen_space_gaussians_bwd(gaussians.means2d, + gaussians.conics, + gaussians.colors, + gaussians.opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), + tileSize, + tiles.tileOffsets, + tiles.tileGaussianIds, + alphasWithBg, + lastIdsWithBg, + gradColors, + gradAlphas, + false, + -1, + backgrounds); // Run backward pass WITHOUT backgrounds auto [dLossDMeans2dAbsNoBg, @@ -1166,23 +1163,22 @@ TEST_F(GaussianRasterizeTestFixture, TestDenseBackwardWithBackgrounds) { dLossDConicsNoBg, dLossDColorsNoBg, dLossDOpacitiesNoBg] = - fvdb::detail::ops::dispatchGaussianRasterizeBackward( - gaussians.means2d, - gaussians.conics, - gaussians.colors, - gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, - tileSize, - tiles.tileOffsets, - tiles.tileGaussianIds, - alphasNoBg, - lastIdsNoBg, - gradColors, - gradAlphas, - false); + fvdb::detail::ops::rasterize_screen_space_gaussians_bwd(gaussians.means2d, + gaussians.conics, + gaussians.colors, + gaussians.opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), + tileSize, + tiles.tileOffsets, + tiles.tileGaussianIds, + alphasNoBg, + lastIdsNoBg, + gradColors, + gradAlphas, + false); // Gradients should be DIFFERENT when backgrounds are used // (because transparent pixels now have background contribution) @@ -1231,16 +1227,16 @@ TEST_F(GaussianRasterizeTestFixture, TestSparseBackwardWithBackgrounds) { // Run sparse forward pass WITH backgrounds auto [colorsWithBg, alphasWithBg, lastIdsWithBg] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( sparse.pixelsToRender, gaussians.means2d, gaussians.conics, gaussians.colors, gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), tileSize, tiles.tileOffsets, tiles.tileGaussianIds, @@ -1252,16 +1248,16 @@ TEST_F(GaussianRasterizeTestFixture, TestSparseBackwardWithBackgrounds) { // Run sparse forward pass WITHOUT backgrounds auto [colorsNoBg, alphasNoBg, lastIdsNoBg] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( sparse.pixelsToRender, gaussians.means2d, gaussians.conics, gaussians.colors, gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), tileSize, tiles.tileOffsets, tiles.tileGaussianIds, @@ -1288,16 +1284,16 @@ TEST_F(GaussianRasterizeTestFixture, TestSparseBackwardWithBackgrounds) { dLossDConicsWithBg, dLossDColorsWithBg, dLossDOpacitiesWithBg] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeBackward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_bwd( sparse.pixelsToRender, gaussians.means2d, gaussians.conics, gaussians.colors, gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), tileSize, tiles.tileOffsets, tiles.tileGaussianIds, @@ -1319,16 +1315,16 @@ TEST_F(GaussianRasterizeTestFixture, TestSparseBackwardWithBackgrounds) { dLossDConicsNoBg, dLossDColorsNoBg, dLossDOpacitiesNoBg] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeBackward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_bwd( sparse.pixelsToRender, gaussians.means2d, gaussians.conics, gaussians.colors, gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), tileSize, tiles.tileOffsets, tiles.tileGaussianIds, @@ -1400,15 +1396,15 @@ TEST_F(GaussianRasterizeTestFixture, TestPackedModeBackwardMultipleCameras) { expectedDConics, expectedDColors, expectedDOpacities] = - fvdb::detail::ops::dispatchGaussianRasterizeBackward( + fvdb::detail::ops::rasterize_screen_space_gaussians_bwd( gaussians.means2d, gaussians.conics, gaussians.colors, gaussians.opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(testImageWidth), - static_cast(testImageHeight), - static_cast(0), - static_cast(0)}, + static_cast(testImageWidth), + static_cast(testImageHeight), + static_cast(0), + static_cast(0), testTileSize, tiles.tileOffsets, tiles.tileGaussianIds, @@ -1434,15 +1430,15 @@ TEST_F(GaussianRasterizeTestFixture, TestPackedModeBackwardMultipleCameras) { outDConicsPacked, outDColorsPacked, outDOpacitiesPacked] = - fvdb::detail::ops::dispatchGaussianRasterizeBackward( + fvdb::detail::ops::rasterize_screen_space_gaussians_bwd( means2dPacked, conicsPacked, colorsPacked, opacitiesPacked, - fvdb::detail::ops::RenderWindow2D{static_cast(testImageWidth), - static_cast(testImageHeight), - static_cast(0), - static_cast(0)}, + static_cast(testImageWidth), + static_cast(testImageHeight), + static_cast(0), + static_cast(0), testTileSize, tiles.tileOffsets, // Still use [C, H, W] tile offsets tiles.tileGaussianIds, diff --git a/src/tests/GaussianRasterizeContributingGaussianIdsTest.cpp b/src/tests/GaussianRasterizeContributingGaussianIdsTest.cpp index f59fd178d..9f4799dad 100644 --- a/src/tests/GaussianRasterizeContributingGaussianIdsTest.cpp +++ b/src/tests/GaussianRasterizeContributingGaussianIdsTest.cpp @@ -3,9 +3,9 @@ #include "utils/Tensor.h" -#include -#include -#include +#include +#include +#include #include #include @@ -142,26 +142,33 @@ struct GaussianRasterizeContributingGaussianIdsTestFixture : public ::testing::T TEST_F(GaussianRasterizeContributingGaussianIdsTestFixture, TestBasicInputsAndOutputs) { loadTestData("gaussian_top_contributors_1point_input.pt"); - fvdb::detail::ops::RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.tileSize = tileSize; - // First compute the number of contributing Gaussians const auto [numContributingGaussians, alphas] = - fvdb::detail::ops::dispatchGaussianRasterizeNumContributingGaussians( - means2d, conics, opacities, tileOffsets, tileGaussianIds, settings); + fvdb::detail::ops::count_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + 0, + 0, + tileSize); // Then compute the IDs and weights const auto [outIds, outWeights] = - fvdb::detail::ops::dispatchGaussianRasterizeContributingGaussianIds( - means2d, - conics, - opacities, - tileOffsets, - tileGaussianIds, - settings, - numContributingGaussians); + fvdb::detail::ops::identify_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + 0, + 0, + tileSize, + -1, + numContributingGaussians); const int h = imageHeight; const int w = imageWidth; @@ -218,26 +225,33 @@ TEST_F(GaussianRasterizeContributingGaussianIdsTestFixture, TestBasicInputsAndOu TEST_F(GaussianRasterizeContributingGaussianIdsTestFixture, TestBasicInputsAndOutputsSparse) { loadTestData("gaussian_top_contributors_1point_input.pt"); - fvdb::detail::ops::RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.tileSize = tileSize; - // First compute the number of contributing Gaussians for dense rendering const auto [numContributingGaussians, alphas] = - fvdb::detail::ops::dispatchGaussianRasterizeNumContributingGaussians( - means2d, conics, opacities, tileOffsets, tileGaussianIds, settings); + fvdb::detail::ops::count_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + 0, + 0, + tileSize); // Then compute the IDs and weights for dense rendering const auto [outIds, outWeights] = - fvdb::detail::ops::dispatchGaussianRasterizeContributingGaussianIds( - means2d, - conics, - opacities, - tileOffsets, - tileGaussianIds, - settings, - numContributingGaussians); + fvdb::detail::ops::identify_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + 0, + 0, + tileSize, + -1, + numContributingGaussians); const int h = imageHeight; const int w = imageWidth; @@ -247,39 +261,46 @@ TEST_F(GaussianRasterizeContributingGaussianIdsTestFixture, TestBasicInputsAndOu fvdb::JaggedTensor pixelsToRender({pixelsToRenderTensor.unsqueeze(0)}); auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( tileSize, tileOffsets.size(2), tileOffsets.size(1), pixelsToRenderTensor); // Compute num contributing gaussians for sparse rendering const auto [numContributingGaussiansSparse, alphasSparse] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeNumContributingGaussians( - means2d, - conics, - opacities, - tileOffsets, - tileGaussianIds, - pixelsToRender, - activeTiles, - tilePixelMask, - tilePixelCumsum, - pixelMap, - settings); + fvdb::detail::ops::count_contributing_gaussians_sparse(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + pixelsToRender, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + imageWidth, + imageHeight, + 0, + 0, + tileSize); // Run the same scene with sparse sampling of only the center pixel const auto [outIdsSparse, outWeightsSparse] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeContributingGaussianIds( - means2d, - conics, - opacities, - tileOffsets, - tileGaussianIds, - pixelsToRender, - activeTiles, - tilePixelMask, - tilePixelCumsum, - pixelMap, - settings, - numContributingGaussiansSparse); + fvdb::detail::ops::identify_contributing_gaussians_sparse(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + pixelsToRender, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + imageWidth, + imageHeight, + 0, + 0, + tileSize, + -1, + numContributingGaussiansSparse); const int numGaussianLayers = 5; @@ -344,12 +365,15 @@ TEST_F(GaussianRasterizeContributingGaussianIdsTestFixture, CPUThrows) { loadTestData("gaussian_top_contributors_1point_input.pt"); moveToDevice(torch::kCPU); - fvdb::detail::ops::RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.tileSize = tileSize; - - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianRasterizeNumContributingGaussians( - means2d, conics, opacities, tileOffsets, tileGaussianIds, settings), + EXPECT_THROW(fvdb::detail::ops::count_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + 0, + 0, + tileSize), c10::NotImplementedError); } diff --git a/src/tests/GaussianRasterizeForwardTest.cpp b/src/tests/GaussianRasterizeForwardTest.cpp index 99e1099da..64cc86df5 100644 --- a/src/tests/GaussianRasterizeForwardTest.cpp +++ b/src/tests/GaussianRasterizeForwardTest.cpp @@ -4,9 +4,9 @@ #include "utils/ImageUtils.h" #include "utils/Tensor.h" -#include -#include -#include +#include +#include +#include #include @@ -343,25 +343,23 @@ TEST(GaussianRasterizeForwardMaskedEdgeTile, Child) { torch::TensorOptions().device(torch::kCUDA).dtype(torch::kBool)); masks[0][1][1] = false; // mask out bottom-right edge tile - auto [tileOffsets, tileGaussianIds] = - fvdb::detail::ops::dispatchGaussianTileIntersection( - means2d, radii, depths, at::nullopt, (uint32_t)C, tileSize, tileExtentH, tileExtentW); + auto [tileOffsets, tileGaussianIds] = fvdb::detail::ops::intersect_gaussian_tiles( + means2d, radii, depths, at::nullopt, (uint32_t)C, tileSize, tileExtentH, tileExtentW); auto [outFeatures, outAlphas, outLastIds] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2d, - conics, - features, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(0), - static_cast(0)}, - tileSize, - tileOffsets, - tileGaussianIds, - backgrounds, - masks); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + features, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(0), + static_cast(0), + tileSize, + tileOffsets, + tileGaussianIds, + backgrounds, + masks); (void)outLastIds; @@ -435,15 +433,15 @@ TEST_F(GaussianRasterizeForwardTestFixture, DISABLED_GenerateOutputData) { // Test with 3 channels { const auto [renderedColors, renderedAlphas, lastIds] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd( means2d, conics, colors, opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets, tileGaussianIds); @@ -460,15 +458,15 @@ TEST_F(GaussianRasterizeForwardTestFixture, DISABLED_GenerateOutputData) { auto colors_64 = catChannelsToDim(colors, 64); const auto [renderedColors, renderedAlphas, lastIds] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd( means2d, conics, colors_64, opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth / 2), - static_cast(imageHeight / 2), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth / 2), + static_cast(imageHeight / 2), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets, tileGaussianIds); @@ -484,18 +482,17 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestBasicInputsAndOutputs) { loadTestData("rasterize_forward_inputs.pt", "rasterize_forward_outputs.pt"); const auto [outColors, outAlphas, outLastIds] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2d, - conics, - colors, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + colors, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds); EXPECT_TRUE(torch::allclose(outColors, expectedRenderedColors)); EXPECT_TRUE(torch::allclose(outAlphas, expectedRenderedAlphas)); @@ -509,18 +506,17 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestConcatenatedChannels) { expectedRenderedColors = catChannelsToDim(expectedRenderedColors, 64); const auto [outColors, outAlphas, outLastIds] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2d, - conics, - colors, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + colors, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds); EXPECT_TRUE(torch::allclose(outColors, expectedRenderedColors)); EXPECT_TRUE(torch::allclose(outAlphas, expectedRenderedAlphas)); @@ -535,18 +531,17 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestMultipleCameras) { // run all 3 cameras at once const auto [outColorsAll, outAlphasAll, outLastIdsAll] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2d, - conics, - colors, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + colors, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds); // rasterize each camera individually std::vector outColorsList; @@ -576,15 +571,15 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestMultipleCameras) { // Kernel receives adjusted offsets and 0-based IDs for this camera auto [outColors, outAlphas, outLastIds] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd( means2d_1cam, conics_1cam, colors_1cam, opacities_1cam, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets_1cam, tileGaussianIds_1cam); @@ -669,34 +664,32 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestMultipleCamerasWithBackgrounds) // Render without background const auto [outColorsNoBackground, outAlphasNoBackground, outLastIdsNoBackground] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2d, - conics, - colors, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + colors, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds); // Render with different background per camera const auto [outColorsWithBackground, outAlphasWithBackground, outLastIdsWithBackground] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2d, - conics, - colors, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds, - backgrounds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + colors, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds, + backgrounds); // Alphas and last IDs should be identical regardless of background EXPECT_TRUE(torch::allclose(outAlphasNoBackground, outAlphasWithBackground)); @@ -729,20 +722,20 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestSparseRasterization) { auto const pixelsToRender = generateSparsePixelCoords(numCameras, 100).cuda(); auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( tileSize, tileOffsets.size(2), tileOffsets.size(1), pixelsToRender); const auto [outColorsSparse, outAlphasSparse, outLastIdsSparse] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( pixelsToRender, means2d, conics, colors, opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets, tileGaussianIds, @@ -771,20 +764,20 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestSparseRasterizationConcatenatedC auto const pixelsToRender = generateSparsePixelCoords(numCameras, 100).cuda(); auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( tileSize, tileOffsets.size(2), tileOffsets.size(1), pixelsToRender); const auto [outColorsSparse, outAlphasSparse, outLastIdsSparse] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( pixelsToRender, means2d, conics, colors, opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets, tileGaussianIds, @@ -811,35 +804,34 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestSparseRasterizationMultipleCamer auto const pixelsToRender = generateSparsePixelCoords(numCameras, 100).cuda(); auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( tileSize, tileOffsets.size(2), tileOffsets.size(1), pixelsToRender); // run all 3 cameras at once const auto [outColorsAll, outAlphasAll, outLastIdsAll] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2d, - conics, - colors, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + colors, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds); const auto [outColorsSparse, outAlphasSparse, outLastIdsSparse] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( pixelsToRender, means2d, conics, colors, opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets, tileGaussianIds, @@ -877,23 +869,23 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestSparseRasterizationMultipleCamer auto const pixelsToRender = generateSparsePixelCoords(numCameras, 100).cuda(); auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( tileSize, tileOffsets.size(2), tileOffsets.size(1), pixelsToRender); // Render sparse without background const auto [outColorsSparseNoBackground, outAlphasSparseNoBackground, outLastIdsSparseNoBackground] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( pixelsToRender, means2d, conics, colors, opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets, tileGaussianIds, @@ -906,16 +898,16 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestSparseRasterizationMultipleCamer const auto [outColorsSparseWithBackground, outAlphasSparseWithBackground, outLastIdsSparseWithBackground] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( pixelsToRender, means2d, conics, colors, opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets, tileGaussianIds, @@ -976,18 +968,17 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestPackedModeMultipleCameras) { // Step 1: Run non-packed rasterization to get expected results const auto [expectedColors, expectedAlphas, expectedLastIds] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2d, - conics, - colors, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + colors, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds); // Step 2: Reshape tensors to packed format [nnz, D] // The test data's tileGaussianIds already contains global indices (0 to C*N-1). @@ -1000,18 +991,17 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestPackedModeMultipleCameras) { // Step 3: Run packed rasterization with same tileOffsets and tileGaussianIds const auto [outColorsPacked, outAlphasPacked, outLastIdsPacked] = - fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2dPacked, - conicsPacked, - colorsPacked, - opacitiesPacked, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds); + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2dPacked, + conicsPacked, + colorsPacked, + opacitiesPacked, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds); // Step 4: Compare results // The output shapes should match: [C, H, W, D] for colors, [C, H, W, 1] for alphas @@ -1048,21 +1038,21 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestPackedModeSparseMultipleCameras) // Compute sparse info from pixels auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( tileSize, tileOffsets.size(2), tileOffsets.size(1), pixelsToRender); // Step 1: Run non-packed sparse rasterization to get expected results const auto [expectedColorsSparse, expectedAlphasSparse, expectedLastIdsSparse] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( pixelsToRender, means2d, conics, colors, opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets, tileGaussianIds, @@ -1082,16 +1072,16 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestPackedModeSparseMultipleCameras) // Step 3: Run packed sparse rasterization with same sparse info and same gaussian IDs const auto [outColorsPacked, outAlphasPacked, outLastIdsPacked] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeForward( + fvdb::detail::ops::rasterize_screen_space_gaussians_sparse_fwd( pixelsToRender, means2dPacked, conicsPacked, colorsPacked, opacitiesPacked, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), tileSize, tileOffsets, tileGaussianIds, @@ -1125,17 +1115,17 @@ TEST_F(GaussianRasterizeForwardTestFixture, TestPackedModeSparseMultipleCameras) TEST_F(GaussianRasterizeForwardTestFixture, CPUThrows) { loadTestData("rasterize_forward_inputs.pt", "rasterize_forward_outputs.pt"); moveToDevice(torch::kCPU); - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianRasterizeForward( - means2d, - conics, - colors, - opacities, - fvdb::detail::ops::RenderWindow2D{static_cast(imageWidth), - static_cast(imageHeight), - static_cast(imageOriginW), - static_cast(imageOriginH)}, - tileSize, - tileOffsets, - tileGaussianIds), - c10::NotImplementedError); + EXPECT_THROW( + fvdb::detail::ops::rasterize_screen_space_gaussians_fwd(means2d, + conics, + colors, + opacities, + static_cast(imageWidth), + static_cast(imageHeight), + static_cast(imageOriginW), + static_cast(imageOriginH), + tileSize, + tileOffsets, + tileGaussianIds), + c10::NotImplementedError); } diff --git a/src/tests/GaussianRasterizeTopContributorsTest.cpp b/src/tests/GaussianRasterizeTopContributorsTest.cpp index f80a1da1c..f9c7dfa82 100644 --- a/src/tests/GaussianRasterizeTopContributorsTest.cpp +++ b/src/tests/GaussianRasterizeTopContributorsTest.cpp @@ -3,8 +3,8 @@ #include "utils/Tensor.h" -#include -#include +#include +#include #include #include @@ -143,15 +143,18 @@ struct GaussianRasterizeTopContributorsTestFixture : public ::testing::Test { TEST_F(GaussianRasterizeTopContributorsTestFixture, TestBasicInputsAndOutputs) { loadTestData("gaussian_top_contributors_1point_input.pt"); - fvdb::detail::ops::RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.tileSize = tileSize; - settings.numDepthSamples = numDepthSamples; - const auto [outIds, outWeights] = - fvdb::detail::ops::dispatchGaussianRasterizeTopContributingGaussianIds( - means2d, conics, opacities, tileOffsets, tileGaussianIds, settings); + fvdb::detail::ops::identify_top_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + 0, + 0, + tileSize, + numDepthSamples); const int h = imageHeight; const int w = imageWidth; @@ -205,15 +208,18 @@ TEST_F(GaussianRasterizeTopContributorsTestFixture, TestBasicInputsAndOutputs) { TEST_F(GaussianRasterizeTopContributorsTestFixture, TestBasicInputsAndOutputsSparse) { loadTestData("gaussian_top_contributors_1point_input.pt"); - fvdb::detail::ops::RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.tileSize = tileSize; - settings.numDepthSamples = numDepthSamples; - const auto [outIds, outWeights] = - fvdb::detail::ops::dispatchGaussianRasterizeTopContributingGaussianIds( - means2d, conics, opacities, tileOffsets, tileGaussianIds, settings); + fvdb::detail::ops::identify_top_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + 0, + 0, + tileSize, + numDepthSamples); const int h = imageHeight; const int w = imageWidth; @@ -221,23 +227,27 @@ TEST_F(GaussianRasterizeTopContributorsTestFixture, TestBasicInputsAndOutputsSpa const auto pixelsToRender = torch::tensor({{h / 2 - 1, w / 2 - 1}}).cuda(); auto [activeTiles, activeTileMask, tilePixelMask, tilePixelCumsum, pixelMap] = - fvdb::detail::ops::computeSparseInfo( + fvdb::detail::ops::build_sparse_gaussian_tile_layout( tileSize, tileOffsets.size(2), tileOffsets.size(1), pixelsToRender); // Run the same scene with sparse sampling of only the center pixel const auto [outIdsSparse, outWeightsSparse] = - fvdb::detail::ops::dispatchGaussianSparseRasterizeTopContributingGaussianIds( - means2d, - conics, - opacities, - tileOffsets, - tileGaussianIds, - pixelsToRender, - activeTiles, - tilePixelMask, - tilePixelCumsum, - pixelMap, - settings); + fvdb::detail::ops::identify_top_contributing_gaussians_sparse(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + pixelsToRender, + activeTiles, + tilePixelMask, + tilePixelCumsum, + pixelMap, + imageWidth, + imageHeight, + 0, + 0, + tileSize, + numDepthSamples); const int numGaussianLayers = 5; @@ -288,14 +298,16 @@ TEST_F(GaussianRasterizeTopContributorsTestFixture, CPUThrows) { loadTestData("gaussian_top_contributors_1point_input.pt"); moveToDevice(torch::kCPU); - fvdb::detail::ops::RenderSettings settings; - settings.imageWidth = imageWidth; - settings.imageHeight = imageHeight; - settings.tileSize = tileSize; - settings.numDepthSamples = numDepthSamples; - - EXPECT_THROW( - fvdb::detail::ops::dispatchGaussianRasterizeTopContributingGaussianIds( - means2d, conics, opacities, tileOffsets, tileGaussianIds, settings), - c10::NotImplementedError); + EXPECT_THROW(fvdb::detail::ops::identify_top_contributing_gaussians(means2d, + conics, + opacities, + tileOffsets, + tileGaussianIds, + imageWidth, + imageHeight, + 0, + 0, + tileSize, + numDepthSamples), + c10::NotImplementedError); } diff --git a/src/tests/GaussianSphericalHarmonicsBackwardTest.cpp b/src/tests/GaussianSphericalHarmonicsBackwardTest.cpp index 0ebd57fb6..ad9d7fda9 100644 --- a/src/tests/GaussianSphericalHarmonicsBackwardTest.cpp +++ b/src/tests/GaussianSphericalHarmonicsBackwardTest.cpp @@ -3,7 +3,7 @@ #include "utils/Tensor.h" -#include +#include #include @@ -90,15 +90,14 @@ struct SphericalHarmonincsBackwardTestFixture : public ::testing::TestWithParam< const bool setZeroRadii = false) { { auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - true); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + true); if (setZeroRadii) { const auto dLdSh0Slice = dLossDSh0Coeffs.index({torch::indexing::Slice(0, -1, 2), @@ -124,15 +123,14 @@ struct SphericalHarmonincsBackwardTestFixture : public ::testing::TestWithParam< // We don't return view direction gradients if you don't ask for them { auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - false); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + false); if (setZeroRadii) { const auto dLdSh0Slice = dLossDSh0Coeffs.index({torch::indexing::Slice(0, -1, 2), torch::indexing::Slice(), @@ -174,15 +172,14 @@ struct SphericalHarmonincsBackwardTestFixture : public ::testing::TestWithParam< { auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - true); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + true); EXPECT_TRUE(dLossDSh0Coeffs.sizes() == expectedSh0Sizes); EXPECT_FALSE(dLossDShNCoeffs.defined()); EXPECT_FALSE(dLossDViewDirs.defined()); @@ -194,15 +191,14 @@ struct SphericalHarmonincsBackwardTestFixture : public ::testing::TestWithParam< { shNCoeffs = torch::Tensor(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - true); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + true); EXPECT_TRUE(dLossDSh0Coeffs.sizes() == expectedSh0Sizes); EXPECT_FALSE(dLossDShNCoeffs.defined()); EXPECT_FALSE(dLossDViewDirs.defined()); @@ -215,15 +211,14 @@ struct SphericalHarmonincsBackwardTestFixture : public ::testing::TestWithParam< shNCoeffs = torch::Tensor(); viewDirs = torch::Tensor(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - true); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + true); EXPECT_TRUE(dLossDSh0Coeffs.sizes() == expectedSh0Sizes); EXPECT_FALSE(dLossDShNCoeffs.defined()); EXPECT_FALSE(dLossDViewDirs.defined()); @@ -271,15 +266,14 @@ TEST_F(SphericalHarmonincsBackwardTestFixture, BenchmarkSh0) { for (int i = 0; i < 10; i += 1) { torch::cuda::synchronize(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - false); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + false); torch::cuda::synchronize(); } @@ -289,15 +283,14 @@ TEST_F(SphericalHarmonincsBackwardTestFixture, BenchmarkSh0) { torch::cuda::synchronize(); auto start = std::chrono::high_resolution_clock::now(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - false); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + false); torch::cuda::synchronize(); auto end = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(end - start); @@ -324,15 +317,14 @@ TEST_F(SphericalHarmonincsBackwardTestFixture, BenchmarkSh0WithViewDirGrad) { for (int i = 0; i < 10; i += 1) { torch::cuda::synchronize(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - false); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + false); torch::cuda::synchronize(); } @@ -342,15 +334,14 @@ TEST_F(SphericalHarmonincsBackwardTestFixture, BenchmarkSh0WithViewDirGrad) { torch::cuda::synchronize(); auto start = std::chrono::high_resolution_clock::now(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - true); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + true); torch::cuda::synchronize(); auto end = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(end - start); @@ -376,15 +367,14 @@ TEST_F(SphericalHarmonincsBackwardTestFixture, BenchmarkShNWithViewDirGrad) { for (int i = 0; i < 10; i += 1) { torch::cuda::synchronize(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - false); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + false); torch::cuda::synchronize(); } @@ -394,15 +384,14 @@ TEST_F(SphericalHarmonincsBackwardTestFixture, BenchmarkShNWithViewDirGrad) { torch::cuda::synchronize(); auto start = std::chrono::high_resolution_clock::now(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - true); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + true); torch::cuda::synchronize(); auto end = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(end - start); @@ -428,15 +417,14 @@ TEST_F(SphericalHarmonincsBackwardTestFixture, BenchmarkShNWithoutViewDirGrad) { for (int i = 0; i < 10; i += 1) { torch::cuda::synchronize(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - false); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + false); torch::cuda::synchronize(); } @@ -446,15 +434,14 @@ TEST_F(SphericalHarmonincsBackwardTestFixture, BenchmarkShNWithoutViewDirGrad) { torch::cuda::synchronize(); auto start = std::chrono::high_resolution_clock::now(); auto [dLossDSh0Coeffs, dLossDShNCoeffs, dLossDViewDirs] = - fvdb::detail::ops::dispatchSphericalHarmonicsBackward( - shDegreeToUse, - numCameras, - numGaussians, - viewDirs, - shNCoeffs, - dLossDRenderQuantities, - radii, - false); + fvdb::detail::ops::eval_gaussian_sh_bwd(shDegreeToUse, + numCameras, + numGaussians, + viewDirs, + shNCoeffs, + dLossDRenderQuantities, + radii, + false); torch::cuda::synchronize(); auto end = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(end - start); diff --git a/src/tests/GaussianSphericalHarmonicsForwardTest.cpp b/src/tests/GaussianSphericalHarmonicsForwardTest.cpp index 816dcb5bf..d2a588501 100644 --- a/src/tests/GaussianSphericalHarmonicsForwardTest.cpp +++ b/src/tests/GaussianSphericalHarmonicsForwardTest.cpp @@ -3,7 +3,7 @@ #include "utils/Tensor.h" -#include +#include #include @@ -105,7 +105,7 @@ struct SphericalHarmonicsForwardTestFixture : public ::testing::TestWithParam( + auto result = fvdb::detail::ops::eval_gaussian_sh_fwd( shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); EXPECT_TRUE(result.sizes() == torch::IntArrayRef({numCameras, numGaussians, numChannels})); @@ -114,7 +114,7 @@ TEST_P(SphericalHarmonicsForwardTestFixture, TestShForward) { { shNCoeffs = torch::Tensor(); - auto result = fvdb::detail::ops::dispatchSphericalHarmonicsForward( + auto result = fvdb::detail::ops::eval_gaussian_sh_fwd( shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); EXPECT_TRUE(result.sizes() == torch::IntArrayRef({numCameras, numGaussians, numChannels})); @@ -123,14 +123,14 @@ TEST_P(SphericalHarmonicsForwardTestFixture, TestShForward) { { viewDirs = torch::Tensor(); - auto result = fvdb::detail::ops::dispatchSphericalHarmonicsForward( + auto result = fvdb::detail::ops::eval_gaussian_sh_fwd( shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); EXPECT_TRUE(result.sizes() == torch::IntArrayRef({numCameras, numGaussians, numChannels})); EXPECT_TRUE(torch::allclose(result, expectedResult)); } } else { - auto result = fvdb::detail::ops::dispatchSphericalHarmonicsForward( + auto result = fvdb::detail::ops::eval_gaussian_sh_fwd( shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); EXPECT_TRUE(result.sizes() == torch::IntArrayRef({numCameras, numGaussians, numChannels})); EXPECT_TRUE(torch::allclose(result, expectedResult)); @@ -155,7 +155,7 @@ TEST_F(SphericalHarmonicsTestFixture, TestSh0Benchmark) { // Warm up for (int i = 0; i < 10; i += 1) { torch::cuda::synchronize(); - auto result = fvdb::detail::ops::dispatchSphericalHarmonicsForward( + auto result = fvdb::detail::ops::eval_gaussian_sh_fwd( shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); torch::cuda::synchronize(); } @@ -165,7 +165,7 @@ TEST_F(SphericalHarmonicsTestFixture, TestSh0Benchmark) { for (int i = 0; i < totalIters; i += 1) { torch::cuda::synchronize(); auto start = std::chrono::high_resolution_clock::now(); - auto result = fvdb::detail::ops::dispatchSphericalHarmonicsForward( + auto result = fvdb::detail::ops::eval_gaussian_sh_fwd( shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); torch::cuda::synchronize(); auto end = std::chrono::high_resolution_clock::now(); @@ -192,7 +192,7 @@ TEST_F(SphericalHarmonicsTestFixture, TestShNNBenchmark) { // Warm up for (int i = 0; i < 10; i += 1) { torch::cuda::synchronize(); - auto result = fvdb::detail::ops::dispatchSphericalHarmonicsForward( + auto result = fvdb::detail::ops::eval_gaussian_sh_fwd( shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); torch::cuda::synchronize(); } @@ -202,7 +202,7 @@ TEST_F(SphericalHarmonicsTestFixture, TestShNNBenchmark) { for (int i = 0; i < totalIters; i += 1) { torch::cuda::synchronize(); auto start = std::chrono::high_resolution_clock::now(); - auto result = fvdb::detail::ops::dispatchSphericalHarmonicsForward( + auto result = fvdb::detail::ops::eval_gaussian_sh_fwd( shDegreeToUse, numCameras, viewDirs, sh0Coeffs, shNCoeffs, radii); torch::cuda::synchronize(); auto end = std::chrono::high_resolution_clock::now(); diff --git a/src/tests/GaussianSplat3dCameraApiTest.cpp b/src/tests/GaussianSplat3dCameraApiTest.cpp deleted file mode 100644 index b8399df2d..000000000 --- a/src/tests/GaussianSplat3dCameraApiTest.cpp +++ /dev/null @@ -1,514 +0,0 @@ -// Copyright Contributors to the OpenVDB Project -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include -#include -#include - -#include - -#include -#include - -namespace { - -using CameraModel = fvdb::GaussianSplat3d::CameraModel; -using ProjectionMethod = fvdb::GaussianSplat3d::ProjectionMethod; - -template -void -expectTorchErrorContains(Fn &&fn, const std::string &messageSubstring) { - try { - fn(); - FAIL() << "Expected c10::Error containing: " << messageSubstring; - } catch (const c10::Error &e) { - EXPECT_NE(std::string(e.what()).find(messageSubstring), std::string::npos) - << "Actual error was: " << e.what(); - } -} - -struct GaussianSplat3dCameraApiTest : public ::testing::Test { - void - SetUp() override { - torch::manual_seed(0); - if (!torch::cuda::is_available()) { - GTEST_SKIP() << "CUDA is not available; skipping GaussianSplat3d camera API tests."; - } - } - - static constexpr int64_t kImageWidth = 32; - static constexpr int64_t kImageHeight = 24; - static constexpr float kNearPlane = 0.05f; - static constexpr float kFarPlane = 20.0f; - - static fvdb::GaussianSplat3d - makeSimpleGaussianSplat() { - auto opts = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32); - - const torch::Tensor means = - torch::tensor({{0.18f, -0.12f, 2.8f}, {-0.08f, 0.10f, 3.4f}}, opts); - const torch::Tensor quats = - torch::tensor({{1.0f, 0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 0.0f, 0.0f}}, opts); - const torch::Tensor logScales = - torch::log(torch::tensor({{0.06f, 0.05f, 0.04f}, {0.05f, 0.07f, 0.06f}}, opts)); - const torch::Tensor logitOpacities = torch::tensor({2.2f, 1.8f}, opts); - const torch::Tensor sh0 = - torch::tensor({{{0.7f, 0.1f, -0.2f}}, {{-0.3f, 0.5f, 0.4f}}}, opts); - const torch::Tensor shN = torch::empty({2, 0, 3}, opts); - - return fvdb::GaussianSplat3d( - means, quats, logScales, logitOpacities, sh0, shN, false, false, false); - } - - static torch::Tensor - makeWorldToCameraMatrices(const int64_t C) { - auto worldToCamera = - torch::eye(4, torch::TensorOptions().device(torch::kCPU).dtype(torch::kFloat32)) - .unsqueeze(0) - .repeat({C, 1, 1}); - auto acc = worldToCamera.accessor(); - for (int64_t c = 0; c < C; ++c) { - acc[c][0][3] = 0.03f * static_cast(c); - acc[c][1][3] = -0.02f * static_cast(c); - } - return worldToCamera.cuda().contiguous(); - } - - static torch::Tensor - makeProjectionMatrices(const int64_t C, const CameraModel cameraModel) { - auto projectionMatrices = torch::zeros( - {C, 3, 3}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kFloat32)); - auto acc = projectionMatrices.accessor(); - for (int64_t c = 0; c < C; ++c) { - const float fx = - cameraModel == CameraModel::ORTHOGRAPHIC ? 9.0f + 0.5f * c : 18.0f + 1.5f * c; - const float fy = - cameraModel == CameraModel::ORTHOGRAPHIC ? 8.5f + 0.5f * c : 17.0f + 1.25f * c; - acc[c][0][0] = fx; - acc[c][1][1] = fy; - acc[c][0][2] = (static_cast(kImageWidth) - 1.0f) / 2.0f + 0.3f * c; - acc[c][1][2] = (static_cast(kImageHeight) - 1.0f) / 2.0f - 0.2f * c; - acc[c][2][2] = 1.0f; - } - return projectionMatrices.cuda().contiguous(); - } - - static torch::Tensor - makeDistortionCoeffs(const int64_t C) { - auto distortionCoeffs = torch::zeros( - {C, 12}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kFloat32)); - auto acc = distortionCoeffs.accessor(); - for (int64_t c = 0; c < C; ++c) { - const float s = static_cast(c + 1); - acc[c][0] = 0.02f * s; - acc[c][1] = -0.004f * s; - acc[c][2] = 0.001f * s; - acc[c][6] = 0.0015f * s; - acc[c][7] = -0.0012f * s; - } - return distortionCoeffs.cuda().contiguous(); - } -}; - -TEST_F(GaussianSplat3dCameraApiTest, ProjectionMethodAutoResolvesByCameraModel) { - auto gs = makeSimpleGaussianSplat(); - auto worldToCam = makeWorldToCameraMatrices(1); - auto pinholeProj = makeProjectionMatrices(1, CameraModel::PINHOLE); - auto orthoProj = makeProjectionMatrices(1, CameraModel::ORTHOGRAPHIC); - auto distortion = makeDistortionCoeffs(1); - auto pinholeState = gs.projectGaussiansForImages(worldToCam, - pinholeProj, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::PINHOLE, - ProjectionMethod::AUTO, - std::nullopt, - 0); - auto orthoState = gs.projectGaussiansForImages(worldToCam, - orthoProj, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::ORTHOGRAPHIC, - ProjectionMethod::AUTO, - std::nullopt, - 0); - auto utState = gs.projectGaussiansForImages(worldToCam, - pinholeProj, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::PINHOLE, - ProjectionMethod::UNSCENTED, - std::nullopt, - 0); - auto opencvState = gs.projectGaussiansForImages(worldToCam, - pinholeProj, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::OPENCV_RADTAN_5, - ProjectionMethod::AUTO, - distortion, - 0); - - EXPECT_EQ(pinholeState.cameraModel(), CameraModel::PINHOLE); - EXPECT_EQ(pinholeState.projectionMethod(), ProjectionMethod::ANALYTIC); - EXPECT_EQ(orthoState.cameraModel(), CameraModel::ORTHOGRAPHIC); - EXPECT_EQ(orthoState.projectionMethod(), ProjectionMethod::ANALYTIC); - EXPECT_EQ(utState.cameraModel(), CameraModel::PINHOLE); - EXPECT_EQ(utState.projectionMethod(), ProjectionMethod::UNSCENTED); - EXPECT_EQ(opencvState.cameraModel(), CameraModel::OPENCV_RADTAN_5); - EXPECT_EQ(opencvState.projectionMethod(), ProjectionMethod::UNSCENTED); -} - -TEST_F(GaussianSplat3dCameraApiTest, CameraApiValidationRejectsInvalidArguments) { - auto gs = makeSimpleGaussianSplat(); - auto worldToCam = makeWorldToCameraMatrices(1); - auto projection = makeProjectionMatrices(1, CameraModel::PINHOLE); - auto distortion = makeDistortionCoeffs(1); - - expectTorchErrorContains( - [&]() { - gs.projectGaussiansForImages(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::OPENCV_RADTAN_5, - ProjectionMethod::AUTO, - std::nullopt, - 0); - }, - "distortionCoeffs must be provided"); - - expectTorchErrorContains( - [&]() { - gs.projectGaussiansForImages( - worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::OPENCV_RADTAN_5, - ProjectionMethod::AUTO, - distortion.index({torch::indexing::Slice(), torch::indexing::Slice(0, 5)}), - 0); - }, - "distortionCoeffs must have shape"); - - expectTorchErrorContains( - [&]() { - gs.renderImagesFromWorld(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::OPENCV_RADTAN_5, - ProjectionMethod::ANALYTIC, - distortion, - 0); - }, - "ProjectionMethod::UNSCENTED or AUTO"); - - expectTorchErrorContains( - [&]() { - gs.projectGaussiansForImages(worldToCam, - projection.transpose(1, 2), - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::PINHOLE, - ProjectionMethod::AUTO, - std::nullopt, - 0); - }, - "projectionMatrices must be contiguous"); - - expectTorchErrorContains( - [&]() { - gs.projectGaussiansForImages(worldToCam.transpose(1, 2), - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::PINHOLE, - ProjectionMethod::AUTO, - std::nullopt, - 0); - }, - "worldToCameraMatrices must be contiguous"); -} - -TEST_F(GaussianSplat3dCameraApiTest, CameraApiValidationRejectsEmptyCameraBatches) { - auto gs = makeSimpleGaussianSplat(); - - auto emptyWorldToCam = - torch::empty({0, 4, 4}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); - auto emptyProjection = - torch::empty({0, 3, 3}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); - - expectTorchErrorContains( - [&]() { - gs.projectGaussiansForImages(emptyWorldToCam, - emptyProjection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - CameraModel::PINHOLE, - ProjectionMethod::UNSCENTED, - std::nullopt, - 0); - }, - "At least one camera must be provided (got 0)"); -} - -TEST_F(GaussianSplat3dCameraApiTest, PinholeAndOrthographicIgnoreDistortionTensor) { - auto gs = makeSimpleGaussianSplat(); - auto distortion = makeDistortionCoeffs(1); - - for (const CameraModel cameraModel: {CameraModel::PINHOLE, CameraModel::ORTHOGRAPHIC}) { - SCOPED_TRACE(static_cast(cameraModel)); - const auto worldToCam = makeWorldToCameraMatrices(1); - const auto projection = makeProjectionMatrices(1, cameraModel); - const auto noDistortion = std::optional{}; - - const auto projectedDefault = gs.projectGaussiansForImages(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - noDistortion, - 0); - const auto projectedIgnored = gs.projectGaussiansForImages(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion, - 0); - EXPECT_TRUE( - torch::allclose(projectedDefault.means2d(), projectedIgnored.means2d(), 1e-6, 1e-6)); - - const auto [imagesDefault, imageAlphasDefault] = gs.renderImages(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - noDistortion, - 0); - const auto [imagesIgnored, imageAlphasIgnored] = gs.renderImages(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion, - 0); - EXPECT_TRUE(torch::allclose(imagesDefault, imagesIgnored, 1e-6, 1e-6)); - EXPECT_TRUE(torch::allclose(imageAlphasDefault, imageAlphasIgnored, 1e-6, 1e-6)); - - const auto [worldImagesDefault, worldImageAlphasDefault] = - gs.renderImagesFromWorld(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - noDistortion, - 0); - const auto [worldImagesIgnored, worldImageAlphasIgnored] = - gs.renderImagesFromWorld(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion, - 0); - EXPECT_TRUE(torch::allclose(worldImagesDefault, worldImagesIgnored, 1e-6, 1e-6)); - EXPECT_TRUE(torch::allclose(worldImageAlphasDefault, worldImageAlphasIgnored, 1e-6, 1e-6)); - } -} - -TEST_F(GaussianSplat3dCameraApiTest, ProjectedRenderMatchesDenseProjectedApis) { - auto gs = makeSimpleGaussianSplat(); - - for (const CameraModel cameraModel: - {CameraModel::PINHOLE, CameraModel::ORTHOGRAPHIC, CameraModel::OPENCV_RADTAN_5}) { - SCOPED_TRACE(static_cast(cameraModel)); - const auto worldToCam = makeWorldToCameraMatrices(1); - const auto projection = makeProjectionMatrices(1, cameraModel); - const auto distortion = cameraModel == CameraModel::OPENCV_RADTAN_5 - ? std::optional(makeDistortionCoeffs(1)) - : std::nullopt; - - const auto projectedImages = gs.projectGaussiansForImages(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion, - 0); - const auto [imagesFromProjected, imageAlphasFromProjected] = - gs.renderFromProjectedGaussians(projectedImages); - const auto [imagesFromDense, imageAlphasFromDense] = gs.renderImages(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion, - 0); - - EXPECT_TRUE(torch::allclose(imagesFromProjected, imagesFromDense, 1e-6, 1e-6)); - EXPECT_TRUE(torch::allclose(imageAlphasFromProjected, imageAlphasFromDense, 1e-6, 1e-6)); - - const auto projectedDepths = gs.projectGaussiansForDepths(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion); - const auto [depthsFromProjected, depthAlphasFromProjected] = - gs.renderFromProjectedGaussians(projectedDepths); - const auto [depthsFromDense, depthAlphasFromDense] = gs.renderDepths(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion); - - EXPECT_TRUE(torch::allclose(depthsFromProjected, depthsFromDense, 1e-6, 1e-6)); - EXPECT_TRUE(torch::allclose(depthAlphasFromProjected, depthAlphasFromDense, 1e-6, 1e-6)); - - const auto projectedRgbd = gs.projectGaussiansForImagesAndDepths(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion, - 0); - const auto [rgbdFromProjected, rgbdAlphasFromProjected] = - gs.renderFromProjectedGaussians(projectedRgbd); - const auto [rgbdFromDense, rgbdAlphasFromDense] = - gs.renderImagesAndDepths(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion, - 0); - - EXPECT_TRUE(torch::allclose(rgbdFromProjected, rgbdFromDense, 1e-6, 1e-6)); - EXPECT_TRUE(torch::allclose(rgbdAlphasFromProjected, rgbdAlphasFromDense, 1e-6, 1e-6)); - } -} - -TEST_F(GaussianSplat3dCameraApiTest, WorldSpaceRenderVariantsShareAlphaAndPacking) { - auto gs = makeSimpleGaussianSplat(); - - for (const CameraModel cameraModel: - {CameraModel::PINHOLE, CameraModel::ORTHOGRAPHIC, CameraModel::OPENCV_RADTAN_5}) { - SCOPED_TRACE(static_cast(cameraModel)); - const auto worldToCam = makeWorldToCameraMatrices(1); - const auto projection = makeProjectionMatrices(1, cameraModel); - const auto distortion = cameraModel == CameraModel::OPENCV_RADTAN_5 - ? std::optional(makeDistortionCoeffs(1)) - : std::nullopt; - - const auto [images, imageAlphas] = gs.renderImagesFromWorld(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion, - 0); - const auto [depths, depthAlphas] = gs.renderDepthsFromWorld(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion); - const auto [rgbd, rgbdAlphas] = gs.renderImagesAndDepthsFromWorld(worldToCam, - projection, - kImageWidth, - kImageHeight, - kNearPlane, - kFarPlane, - cameraModel, - ProjectionMethod::AUTO, - distortion, - 0); - - EXPECT_TRUE(torch::allclose(imageAlphas, depthAlphas, 1e-5, 1e-5)); - EXPECT_TRUE(torch::allclose(imageAlphas, rgbdAlphas, 1e-5, 1e-5)); - EXPECT_TRUE(torch::allclose(rgbd.index({torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(0, 3)}), - images, - 1e-5, - 1e-5)); - EXPECT_TRUE(torch::allclose(rgbd.index({torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(3, 4)}), - depths, - 1e-5, - 1e-5)); - } -} - -} // namespace diff --git a/src/tests/GaussianTileIntersectionTest.cpp b/src/tests/GaussianTileIntersectionTest.cpp index 3cfc5a5d4..d71586319 100644 --- a/src/tests/GaussianTileIntersectionTest.cpp +++ b/src/tests/GaussianTileIntersectionTest.cpp @@ -1,8 +1,8 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 -#include -#include +#include +#include #include @@ -216,15 +216,14 @@ class GaussianTileIntersectionTest : public ::testing::Test { TEST_F(GaussianTileIntersectionTest, CPUNotImplementedTest) { auto const [means2d, radii, depths] = createTestData(); - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianTileIntersection( - means2d, - radii, - depths, - /*camera_jidx=*/at::nullopt, - num_cameras, - tile_size, - num_tiles_h, - num_tiles_w), + EXPECT_THROW(fvdb::detail::ops::intersect_gaussian_tiles(means2d.cpu(), + radii.cpu(), + depths.cpu(), + /*camera_jidx=*/at::nullopt, + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w), c10::Error); } @@ -237,15 +236,14 @@ TEST_F(GaussianTileIntersectionTest, ZeroLengthGaussianTest) { auto depths = torch::empty({numCameras, 0}, torch::kFloat32); // [C, N=0] auto [tile_offsets, intersection_values] = - fvdb::detail::ops::dispatchGaussianTileIntersection( - means2d.cuda(), - radii.cuda(), - depths.cuda(), - /*camera_jidx=*/at::nullopt, - numCameras, - tile_size, - num_tiles_h, - num_tiles_w); + fvdb::detail::ops::intersect_gaussian_tiles(means2d.cuda(), + radii.cuda(), + depths.cuda(), + /*camera_jidx=*/at::nullopt, + numCameras, + tile_size, + num_tiles_h, + num_tiles_w); // Move results back to CPU for verification tile_offsets = tile_offsets.cpu(); @@ -304,27 +302,25 @@ TEST_F(GaussianTileIntersectionTest, BadInputsFailTest) { if (isPacked) { const auto cameraJidx = torch::zeros(config[3], torch::kInt32); - EXPECT_THROW( - fvdb::detail::ops::dispatchGaussianTileIntersection(means2d.cuda(), - radii.cuda(), - depths.cuda(), - cameraJidx.cuda(), - nc, - tile_size, - num_tiles_h, - num_tiles_w), - c10::ValueError); + EXPECT_THROW(fvdb::detail::ops::intersect_gaussian_tiles(means2d.cuda(), + radii.cuda(), + depths.cuda(), + cameraJidx.cuda(), + nc, + tile_size, + num_tiles_h, + num_tiles_w), + c10::ValueError); } else { - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianTileIntersection( - means2d.cuda(), - radii.cuda(), - depths.cuda(), - /*camera_jidx=*/at::nullopt, - nc, - tile_size, - num_tiles_h, - num_tiles_w), + EXPECT_THROW(fvdb::detail::ops::intersect_gaussian_tiles(means2d.cuda(), + radii.cuda(), + depths.cuda(), + /*camera_jidx=*/at::nullopt, + nc, + tile_size, + num_tiles_h, + num_tiles_w), c10::ValueError); } } @@ -336,15 +332,14 @@ TEST_F(GaussianTileIntersectionTest, ZeroRadiusGaussianTest) { radii.zero_(); // Set all radii to 0 auto [tile_offsets, intersection_values] = - fvdb::detail::ops::dispatchGaussianTileIntersection( - means2d.cuda(), - radii.cuda(), - depths.cuda(), - /*camera_jidx=*/at::nullopt, - num_cameras, - tile_size, - num_tiles_h, - num_tiles_w); + fvdb::detail::ops::intersect_gaussian_tiles(means2d.cuda(), + radii.cuda(), + depths.cuda(), + /*camera_jidx=*/at::nullopt, + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w); // Verify that there are no intersections EXPECT_EQ(tile_offsets.sum().item(), 0); @@ -354,15 +349,14 @@ TEST_F(GaussianTileIntersectionTest, BasicIntersectionTest) { auto [means2d, radii, depths] = createTestData(); auto [tile_offsets, intersection_values] = - fvdb::detail::ops::dispatchGaussianTileIntersection( - means2d.cuda(), - radii.cuda(), - depths.cuda(), - /*camera_jidx=*/at::nullopt, - num_cameras, - tile_size, - num_tiles_h, - num_tiles_w); + fvdb::detail::ops::intersect_gaussian_tiles(means2d.cuda(), + radii.cuda(), + depths.cuda(), + /*camera_jidx=*/at::nullopt, + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w); // Move results back to CPU for verification tile_offsets = tile_offsets.cpu(); @@ -388,14 +382,14 @@ TEST_F(GaussianTileIntersectionTest, PackedFormatTest) { torch::arange(num_cameras, torch::kInt32).repeat_interleave(num_gaussians); auto [tile_offsets, intersection_values] = - fvdb::detail::ops::dispatchGaussianTileIntersection(means2d_packed.cuda(), - radii_packed.cuda(), - depths_packed.cuda(), - camera_indices.cuda(), - num_cameras, - tile_size, - num_tiles_h, - num_tiles_w); + fvdb::detail::ops::intersect_gaussian_tiles(means2d_packed.cuda(), + radii_packed.cuda(), + depths_packed.cuda(), + camera_indices.cuda(), + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w); // Move results back to CPU for verification tile_offsets = tile_offsets.cpu(); @@ -425,17 +419,16 @@ TEST_F(GaussianTileIntersectionTest, DenseViaSparseTest) { auto active_tiles = tile_mask_to_active_tiles(tile_mask); auto [tile_offsets, intersection_values] = - fvdb::detail::ops::dispatchGaussianSparseTileIntersection( - means2d.cuda(), - radii.cuda(), - depths.cuda(), - tile_mask.cuda(), - active_tiles.cuda(), - /*camera_jidx=*/at::nullopt, - num_cameras, - tile_size, - num_tiles_h, - num_tiles_w); + fvdb::detail::ops::intersect_gaussian_tiles_sparse(means2d.cuda(), + radii.cuda(), + depths.cuda(), + tile_mask.cuda(), + active_tiles.cuda(), + /*camera_jidx=*/at::nullopt, + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w); // Move results back to CPU for verification tile_offsets = tile_offsets.cpu(); @@ -470,17 +463,16 @@ TEST_F(GaussianTileIntersectionTest, SparseIntersectionTest) { int32_t num_active_tiles = active_tiles.size(0); auto [tile_offsets, intersection_values] = - fvdb::detail::ops::dispatchGaussianSparseTileIntersection( - means2d.cuda(), - radii.cuda(), - depths.cuda(), - tile_mask.cuda(), - active_tiles.cuda(), - /*camera_jidx=*/at::nullopt, - num_cameras, - tile_size, - num_tiles_h, - num_tiles_w); + fvdb::detail::ops::intersect_gaussian_tiles_sparse(means2d.cuda(), + radii.cuda(), + depths.cuda(), + tile_mask.cuda(), + active_tiles.cuda(), + /*camera_jidx=*/at::nullopt, + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w); // Move results back to CPU for verification tile_offsets = tile_offsets.cpu(); @@ -508,17 +500,16 @@ TEST_F(GaussianTileIntersectionTest, SparseCPUNotImplementedTest) { auto tile_mask = torch::ones({1, num_tiles_h, num_tiles_w}, torch::kBool); auto active_tiles = tile_mask_to_active_tiles(tile_mask); - EXPECT_THROW(fvdb::detail::ops::dispatchGaussianSparseTileIntersection( - means2d, - radii, - depths, - tile_mask, - active_tiles, - /*camera_jidx=*/at::nullopt, - num_cameras, - tile_size, - num_tiles_h, - num_tiles_w), + EXPECT_THROW(fvdb::detail::ops::intersect_gaussian_tiles_sparse(means2d.cpu(), + radii.cpu(), + depths.cpu(), + tile_mask.cpu(), + active_tiles.cpu(), + /*camera_jidx=*/at::nullopt, + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w), c10::Error); } @@ -535,17 +526,16 @@ TEST_F(GaussianTileIntersectionTest, SparseZeroLengthGaussianTest) { auto active_tiles = tile_mask_to_active_tiles(tile_mask); auto [tile_offsets, intersection_values] = - fvdb::detail::ops::dispatchGaussianSparseTileIntersection( - means2d.cuda(), - radii.cuda(), - depths.cuda(), - tile_mask.cuda(), - active_tiles.cuda(), - /*camera_jidx=*/at::nullopt, - numCameras, - tile_size, - num_tiles_h, - num_tiles_w); + fvdb::detail::ops::intersect_gaussian_tiles_sparse(means2d.cuda(), + radii.cuda(), + depths.cuda(), + tile_mask.cuda(), + active_tiles.cuda(), + /*camera_jidx=*/at::nullopt, + numCameras, + tile_size, + num_tiles_h, + num_tiles_w); // Move results back to CPU for verification tile_offsets = tile_offsets.cpu(); @@ -574,17 +564,16 @@ TEST_F(GaussianTileIntersectionTest, SparsePackedFormatTest) { auto active_tiles = tile_mask_to_active_tiles(tile_mask); auto [tile_offsets, intersection_values] = - fvdb::detail::ops::dispatchGaussianSparseTileIntersection( - means2d_packed.cuda(), - radii_packed.cuda(), - depths_packed.cuda(), - tile_mask.cuda(), - active_tiles.cuda(), - camera_indices.cuda(), - num_cameras, - tile_size, - num_tiles_h, - num_tiles_w); + fvdb::detail::ops::intersect_gaussian_tiles_sparse(means2d_packed.cuda(), + radii_packed.cuda(), + depths_packed.cuda(), + tile_mask.cuda(), + active_tiles.cuda(), + camera_indices.cuda(), + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w); // Move results back to CPU for verification tile_offsets = tile_offsets.cpu(); @@ -621,17 +610,16 @@ TEST_F(GaussianTileIntersectionTest, RandomSparsePatternTest) { } #endif auto [tile_offsets, intersection_values] = - fvdb::detail::ops::dispatchGaussianSparseTileIntersection( - means2d.cuda(), - radii.cuda(), - depths.cuda(), - tile_mask.cuda(), - active_tiles.cuda(), - /*camera_jidx=*/at::nullopt, - num_cameras, - tile_size, - num_tiles_h, - num_tiles_w); + fvdb::detail::ops::intersect_gaussian_tiles_sparse(means2d.cuda(), + radii.cuda(), + depths.cuda(), + tile_mask.cuda(), + active_tiles.cuda(), + /*camera_jidx=*/at::nullopt, + num_cameras, + tile_size, + num_tiles_h, + num_tiles_w); // Move results back to CPU for verification tile_offsets = tile_offsets.cpu(); diff --git a/src/tests/GaussianUtilsTest.cu b/src/tests/GaussianUtilsTest.cu index cd847dcd2..a1bb8d8ff 100644 --- a/src/tests/GaussianUtilsTest.cu +++ b/src/tests/GaussianUtilsTest.cu @@ -1,7 +1,10 @@ // Copyright Contributors to the OpenVDB Project // SPDX-License-Identifier: Apache-2.0 -#include +#include +#include +#include +#include #include diff --git a/src/tests/ViewerTest.cpp b/src/tests/ViewerTest.cpp index 6b2c8bd76..ef7936089 100644 --- a/src/tests/ViewerTest.cpp +++ b/src/tests/ViewerTest.cpp @@ -24,10 +24,16 @@ TEST(Viewer, ViewerTest) { torch::Device device(torch::kCUDA); - std::vector< - std::tuple>> - loadedData; + // Each loaded entry is a tuple of (means, quats, logScales, logitOpacities, sh0, shN, metadata) + using LoadResult = + std::tuple>; + std::vector loadedData; for (size_t i = 0; i < std::size(ply_paths) && i < std::size(view_names); ++i) { const std::string &ply_path = ply_paths[i]; @@ -45,11 +51,11 @@ TEST(Viewer, ViewerTest) { const std::string &ply_path = ply_paths[i]; const std::string &view_name = view_names[i]; - auto [splats, metadata] = loadedData[i]; + auto &[means, quats, logScales, logitOpacities, sh0, shN, metadata] = loadedData[i]; printf("Adding splats from %s\n", ply_path.c_str()); - fvdb::detail::viewer::GaussianSplat3dView &view = - viewer.addGaussianSplat3d(view_name, splats); + fvdb::detail::viewer::GaussianSplat3dView &view = viewer.addGaussianSplat3dView( + view_name, view_name, means, quats, logScales, logitOpacities, sh0, shN); view.setShDegreeToUse(3); @@ -69,8 +75,19 @@ TEST(Viewer, ViewerTest) { imageSizes = torch::empty({0}, device); } - fvdb::detail::viewer::CameraView &cameraView = viewer.addCameraView( - view_name, cameraToWorld, projectionMat, imageSizes, 0.f, 0.5f); + fvdb::detail::viewer::CameraView &cameraView = viewer.addCameraView(view_name, + view_name, + cameraToWorld, + projectionMat, + imageSizes, + 0.f, + 0.5f, + 0.5f, + 0.0125f, + 2.0f, + 1.0f, + {1.0f, 1.0f, 1.0f}, + true); std::this_thread::sleep_for(std::chrono::seconds(5)); @@ -101,11 +118,8 @@ TEST(Viewer, ViewerTest) { torch::Tensor sh0 = torch::rand({N, 1, 3}, device); torch::Tensor shN = torch::rand({N, 15, 3}, device); - fvdb::GaussianSplat3d splats( - means, quats, logScales, logitOpacities, sh0, shN, false, false, false); - - fvdb::detail::viewer::GaussianSplat3dView &view = - viewer.addGaussianSplat3d("test_view", splats); + fvdb::detail::viewer::GaussianSplat3dView &view = viewer.addGaussianSplat3dView( + "test_scene", "test_view", means, quats, logScales, logitOpacities, sh0, shN); const float testEps2d = 0.5f; view.setEps2d(testEps2d); diff --git a/tests/unit/test_binding_renames.py b/tests/unit/test_binding_renames.py new file mode 100644 index 000000000..98d48c87a --- /dev/null +++ b/tests/unit/test_binding_renames.py @@ -0,0 +1,109 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Smoke tests that every renamed/new C++ binding is accessible and every +deleted binding is absent.""" + +import unittest + +from fvdb import _fvdb_cpp as _C + + +class TestBindingRenames(unittest.TestCase): + """Verify the Gaussian splat C++ bindings match the refactored names.""" + + EXPECTED_BINDINGS = [ + # Utility + "relocate_gaussians", + "add_noise_to_gaussian_means", + "save_gaussians_ply", + "load_gaussians_ply", + # Analysis + "count_contributing_gaussians", + "count_contributing_gaussians_sparse", + "identify_contributing_gaussians", + "identify_contributing_gaussians_sparse", + # Analytic projection fwd/bwd + "project_gaussians_analytic_fwd", + "project_gaussians_analytic_bwd", + "project_gaussians_analytic_jagged_fwd", + "project_gaussians_analytic_jagged_bwd", + # UT projection fwd + "project_gaussians_ut_fwd", + # SH evaluation fwd/bwd + "eval_gaussian_sh_fwd", + "eval_gaussian_sh_bwd", + # Dense rasterization fwd/bwd + "rasterize_screen_space_gaussians_fwd", + "rasterize_screen_space_gaussians_bwd", + # Sparse rasterization fwd/bwd + "rasterize_screen_space_gaussians_sparse_fwd", + "rasterize_screen_space_gaussians_sparse_bwd", + # World-space rasterization fwd/bwd + "rasterize_world_space_gaussians_fwd", + "rasterize_world_space_gaussians_bwd", + # Tile intersection + "intersect_gaussian_tiles", + "intersect_gaussian_tiles_sparse", + "build_sparse_gaussian_tile_layout", + ] + + DELETED_BINDINGS = [ + # Old gsplat_ prefixed names + "check_gaussian_state", + "gsplat_check_state", + "gsplat_projection_fwd", + "gsplat_projection_bwd", + "gsplat_projection_jagged_fwd", + "gsplat_projection_jagged_bwd", + "gsplat_rasterize_fwd", + "gsplat_rasterize_bwd", + "gsplat_rasterize_sparse_fwd", + "gsplat_rasterize_sparse_bwd", + "gsplat_rasterize_from_world_fwd", + "gsplat_rasterize_from_world_bwd", + "gsplat_tile_intersection", + "gsplat_sh_eval_fwd", + "gsplat_sh_eval_bwd", + "gsplat_render_crop_from_projected", + "render_crop_from_projected_gaussians", + "gsplat_render_num_contributing", + "gsplat_render_contributing_ids", + "gsplat_sparse_render_num_contributing", + "gsplat_sparse_render_contributing_ids", + "gsplat_load_ply", + "gsplat_save_ply", + "gsplat_relocate_gaussians", + "gsplat_add_noise_to_means", + # Deleted pipeline / utility bindings + "gsplat_eval_sh", + "gsplat_project_gaussians_analytic", + "gsplat_project_gaussians_ut", + "gsplat_project_gaussians_for_camera_with_accum", + "gsplat_sparse_project_gaussians_for_camera", + "gsplat_sparse_project_gaussians_ut", + "gsplat_sparse_render", + "gsplat_rasterize_from_world", + "gsplat_render_depth_from_world", + "evaluate_spherical_harmonics", + # Old query_ prefixed names + "query_num_contributing_gaussians", + "query_num_contributing_gaussians_sparse", + "query_contributing_gaussian_ids", + "query_contributing_gaussian_ids_sparse", + ] + + def test_expected_bindings_exist(self): + for name in self.EXPECTED_BINDINGS: + with self.subTest(name=name): + self.assertTrue(hasattr(_C, name), f"Binding '{name}' should exist on _fvdb_cpp but is missing") + self.assertTrue(callable(getattr(_C, name)), f"Binding '{name}' should be callable") + + def test_deleted_bindings_absent(self): + for name in self.DELETED_BINDINGS: + with self.subTest(name=name): + self.assertFalse(hasattr(_C, name), f"Binding '{name}' should have been deleted but still exists") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_decomposed_sparse.py b/tests/unit/test_decomposed_sparse.py new file mode 100644 index 000000000..0215370c8 --- /dev/null +++ b/tests/unit/test_decomposed_sparse.py @@ -0,0 +1,229 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +""" +Tests for the decomposed sparse rendering pipeline: intersect_gaussian_tiles_sparse +and rasterize_screen_space_gaussians_sparse through the functional API. +""" + +import math +import unittest + +import numpy as np +import torch + +from fvdb import GaussianSplat3d, JaggedTensor, _fvdb_cpp as _C +from fvdb.enums import CameraModel, GaussianRenderMode, ProjectionMethod +from fvdb.utils.tests import get_fvdb_test_data_path + +import fvdb.functional as F + + +def rgb_to_sh(rgb: torch.Tensor) -> torch.Tensor: + C0 = 0.28209479177387814 + return (rgb - 0.5) / C0 + + +class TestDecomposedSparse(unittest.TestCase): + """Validate the decomposed sparse rendering bindings and pipeline.""" + + def setUp(self): + torch.random.manual_seed(42) + np.random.seed(42) + self.device = "cuda:0" + + data_path = get_fvdb_test_data_path() / "gsplat" / "test_garden_cropped.npz" + data = np.load(data_path) + + means = torch.from_numpy(data["means3d"]).float().to(self.device) + quats = torch.from_numpy(data["quats"]).float().to(self.device) + scales = torch.from_numpy(data["scales"]).float().to(self.device) + opacities = torch.from_numpy(data["opacities"]).float().to(self.device) + colors = torch.from_numpy(data["colors"]).float().to(self.device) + + all_w2c = torch.from_numpy(data["viewmats"]).float().to(self.device) + all_proj = torch.from_numpy(data["Ks"]).float().to(self.device) + self.W = data["width"].item() + self.H = data["height"].item() + self.tile_size = 16 + + self.world_to_cam = all_w2c[0:1].contiguous() + self.projection_matrices = all_proj[0:1].contiguous() + + self.means = means + self.quats = quats + self.log_scales = torch.log(scales) + self.logit_opacities = torch.logit(opacities) + + N = means.shape[0] + sh_degree = 3 + sh_coeffs = torch.zeros((N, (sh_degree + 1) ** 2, 3), device=self.device) + sh_coeffs[:, 0, :] = rgb_to_sh(colors) + self.sh0 = sh_coeffs[:, 0, :].unsqueeze(1).clone() + self.shN = sh_coeffs[:, 1:, :].clone() + self.sh_degree = sh_degree + + def _project(self): + return F.project_gaussians( + means=self.means, + quats=self.quats, + log_scales=self.log_scales, + world_to_camera_matrices=self.world_to_cam, + projection_matrices=self.projection_matrices, + image_width=self.W, + image_height=self.H, + ) + + def _make_pixel_grid(self, step=8): + """Create a grid of pixel coordinates covering the image.""" + ys = torch.arange(0, self.H, step, device=self.device) + xs = torch.arange(0, self.W, step, device=self.device) + grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1).reshape(-1, 2) + return grid.int() + + def test_build_sparse_tile_layout_shapes(self): + """build_sparse_gaussian_tile_layout produces valid tile metadata tensors.""" + pixels = self._make_pixel_grid() + pixels_jt = JaggedTensor([pixels]) + + num_tiles_w = math.ceil(self.W / self.tile_size) + num_tiles_h = math.ceil(self.H / self.tile_size) + + active_tiles, active_tile_mask, tile_pixel_mask, tile_pixel_cumsum, pixel_map = ( + _C.build_sparse_gaussian_tile_layout(self.tile_size, num_tiles_w, num_tiles_h, pixels_jt._impl) + ) + + self.assertGreater(active_tiles.numel(), 0) + self.assertEqual(active_tile_mask.shape[0], 1) # C=1 + + def test_intersect_gaussian_tiles_sparse_shapes(self): + """intersect_gaussian_tiles_sparse produces valid tile offsets and IDs.""" + projected = self._project() + pixels = self._make_pixel_grid() + pixels_jt = JaggedTensor([pixels]) + + num_tiles_w = math.ceil(self.W / self.tile_size) + num_tiles_h = math.ceil(self.H / self.tile_size) + + active_tiles, active_tile_mask, _, _, _ = _C.build_sparse_gaussian_tile_layout( + self.tile_size, num_tiles_w, num_tiles_h, pixels_jt._impl + ) + + C = self.world_to_cam.size(0) + tile_offsets, tile_gaussian_ids = _C.intersect_gaussian_tiles_sparse( + projected.means2d, + projected.radii, + projected.depths, + active_tile_mask, + active_tiles, + C, + self.tile_size, + num_tiles_h, + num_tiles_w, + ) + + self.assertEqual(tile_offsets.dim(), 1) + self.assertEqual(tile_gaussian_ids.dim(), 1) + + def test_4stage_sparse_pipeline_produces_output(self): + """The 4-stage sparse pipeline produces non-trivial, finite rendered features.""" + pixels = self._make_pixel_grid(step=4) + pixels_jt = JaggedTensor([pixels]) + + projected = self._project() + features = F.evaluate_gaussian_sh( + self.means, + self.sh0, + self.shN, + self.world_to_cam, + projected, + sh_degree_to_use=self.sh_degree, + render_mode=GaussianRenderMode.FEATURES, + ) + sparse_tiles = F.intersect_gaussian_tiles_sparse( + pixels_jt, + projected, + tile_size=self.tile_size, + ) + features_jt, alphas_jt = F.rasterize_screen_space_gaussians_sparse( + projected, + features, + self.logit_opacities, + sparse_tiles, + ) + + self.assertTrue(torch.isfinite(features_jt.jdata).all()) + self.assertTrue(torch.isfinite(alphas_jt.jdata).all()) + self.assertGreater(features_jt.jdata.abs().sum().item(), 0) + + def test_4stage_sparse_matches_oo_sparse(self): + """4-stage sparse pipeline matches GaussianSplat3d.sparse_render_images.""" + pixels = self._make_pixel_grid(step=8) + pixels_jt = JaggedTensor([pixels]) + + gs3d = GaussianSplat3d.from_tensors( + means=self.means, + quats=self.quats, + log_scales=self.log_scales, + logit_opacities=self.logit_opacities, + sh0=self.sh0, + shN=self.shN, + ) + + oo_features, oo_alphas = gs3d.sparse_render_images( + pixels_to_render=pixels_jt, + world_to_camera_matrices=self.world_to_cam, + projection_matrices=self.projection_matrices, + image_width=self.W, + image_height=self.H, + near=0.01, + far=1e10, + sh_degree_to_use=self.sh_degree, + ) + + projected = self._project() + features = F.evaluate_gaussian_sh( + self.means, + self.sh0, + self.shN, + self.world_to_cam, + projected, + sh_degree_to_use=self.sh_degree, + render_mode=GaussianRenderMode.FEATURES, + ) + sparse_tiles = F.intersect_gaussian_tiles_sparse( + pixels_jt, + projected, + tile_size=self.tile_size, + ) + fn_features_jt, fn_alphas_jt = F.rasterize_screen_space_gaussians_sparse( + projected, + features, + self.logit_opacities, + sparse_tiles, + ) + + if sparse_tiles.has_duplicates: + from fvdb import JaggedTensor as JT + + fn_features_jt = JT(impl=fn_features_jt._impl[sparse_tiles.inverse_indices]) + fn_alphas_jt = JT(impl=fn_alphas_jt._impl[sparse_tiles.inverse_indices]) + + torch.testing.assert_close( + fn_features_jt.jdata, + oo_features.jdata, + atol=1e-4, + rtol=1e-4, + msg="Functional sparse features don't match OO sparse features", + ) + torch.testing.assert_close( + fn_alphas_jt.jdata, + oo_alphas.jdata, + atol=1e-4, + rtol=1e-4, + msg="Functional sparse alphas don't match OO sparse alphas", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_find_unique_pixels.py b/tests/unit/test_find_unique_pixels.py new file mode 100644 index 000000000..a563872be --- /dev/null +++ b/tests/unit/test_find_unique_pixels.py @@ -0,0 +1,140 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +"""Tests for find_unique_pixels, ported from DeduplicatePixelsTest.cpp.""" + +import unittest + +import torch + +from fvdb import JaggedTensor +from fvdb.functional._gaussian_tile_intersection import _find_unique_pixels as find_unique_pixels + +IMAGE_WIDTH = 64 +IMAGE_HEIGHT = 64 + + +class TestFindUniquePixels(unittest.TestCase): + + def _run(self, pixels_jt, w=IMAGE_WIDTH, h=IMAGE_HEIGHT): + return find_unique_pixels(pixels_jt, image_width=w, image_height=h) + + def test_empty(self): + pixels = JaggedTensor(torch.empty(0, 2, dtype=torch.int32, device="cuda")) + unique, inv, has_dups = self._run(pixels) + self.assertFalse(has_dups) + self.assertEqual(inv.size(0), 0) + + def test_single_pixel(self): + coords = torch.tensor([[5, 10]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([coords]) + unique, inv, has_dups = self._run(pixels) + self.assertFalse(has_dups) + self.assertEqual(len(unique.jdata), 1) + + def test_all_unique(self): + coords = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1], [2, 3]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([coords]) + unique, inv, has_dups = self._run(pixels) + self.assertFalse(has_dups) + self.assertEqual(len(unique.jdata), 5) + self.assertEqual(inv.size(0), 5) + + def test_some_duplicates(self): + coords = torch.tensor([[0, 0], [1, 1], [0, 0], [2, 2]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([coords]) + unique, inv, has_dups = self._run(pixels) + self.assertTrue(has_dups) + self.assertEqual(len(unique.jdata), 3) + self.assertEqual(inv.size(0), 4) + inv_cpu = inv.cpu() + self.assertEqual(inv_cpu[0].item(), inv_cpu[2].item()) + self.assertNotEqual(inv_cpu[1].item(), inv_cpu[3].item()) + + def test_all_same_pixel(self): + coords = torch.tensor([[5, 5], [5, 5], [5, 5], [5, 5]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([coords]) + unique, inv, has_dups = self._run(pixels) + self.assertTrue(has_dups) + self.assertEqual(len(unique.jdata), 1) + self.assertEqual(inv.size(0), 4) + inv_cpu = inv.cpu() + for i in range(4): + self.assertEqual(inv_cpu[i].item(), 0) + + def test_multi_batch_no_duplicates(self): + batch0 = torch.tensor([[0, 0], [1, 1]], dtype=torch.int32, device="cuda") + batch1 = torch.tensor([[0, 0], [2, 2]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([batch0, batch1]) + unique, inv, has_dups = self._run(pixels) + self.assertFalse(has_dups) + self.assertEqual(len(unique.jdata), 4) + self.assertEqual(len(unique), 2) + + def test_multi_batch_with_duplicates(self): + batch0 = torch.tensor([[0, 0], [1, 1], [0, 0]], dtype=torch.int32, device="cuda") + batch1 = torch.tensor([[0, 0], [3, 3]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([batch0, batch1]) + unique, inv, has_dups = self._run(pixels) + self.assertTrue(has_dups) + self.assertEqual(len(unique), 2) + self.assertEqual(len(unique.jdata), 4) + self.assertEqual(inv.size(0), 5) + inv_cpu = inv.cpu() + self.assertEqual(inv_cpu[0].item(), inv_cpu[2].item()) + + def test_multi_batch_all_same_pixel(self): + batch0 = torch.tensor([[1, 1], [1, 1], [1, 1]], dtype=torch.int32, device="cuda") + batch1 = torch.tensor([[2, 2], [2, 2]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([batch0, batch1]) + unique, inv, has_dups = self._run(pixels) + self.assertTrue(has_dups) + self.assertEqual(len(unique), 2) + self.assertEqual(len(unique.jdata), 2) + offsets = unique.joffsets.cpu() + self.assertEqual(offsets[0].item(), 0) + self.assertEqual(offsets[1].item(), 1) + self.assertEqual(offsets[2].item(), 2) + inv_cpu = inv.cpu() + self.assertEqual(inv_cpu[0].item(), inv_cpu[1].item()) + self.assertEqual(inv_cpu[0].item(), inv_cpu[2].item()) + self.assertEqual(inv_cpu[3].item(), inv_cpu[4].item()) + self.assertNotEqual(inv_cpu[0].item(), inv_cpu[3].item()) + + def test_round_trip_some_duplicates(self): + coords = torch.tensor([[3, 7], [1, 2], [3, 7], [5, 5], [1, 2], [9, 0]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([coords]) + unique, inv, has_dups = self._run(pixels) + self.assertTrue(has_dups) + self.assertEqual(len(unique.jdata), 4) + reconstructed = unique.jdata.index_select(0, inv) + self.assertTrue(torch.equal(reconstructed.cpu(), coords.cpu().to(reconstructed.dtype))) + + def test_round_trip_multi_batch(self): + batch0 = torch.tensor([[2, 3], [4, 5], [2, 3]], dtype=torch.int32, device="cuda") + batch1 = torch.tensor([[6, 7], [6, 7], [8, 9]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([batch0, batch1]) + unique, inv, has_dups = self._run(pixels) + self.assertTrue(has_dups) + original_jdata = pixels.jdata + reconstructed = unique.jdata.index_select(0, inv) + self.assertTrue(torch.equal(reconstructed.cpu(), original_jdata.cpu().to(reconstructed.dtype))) + + def test_jagged_tensor_offsets(self): + batch0 = torch.tensor([[0, 0], [0, 0], [1, 1]], dtype=torch.int32, device="cuda") + batch1 = torch.tensor([[2, 2]], dtype=torch.int32, device="cuda") + batch2 = torch.tensor([[3, 3], [4, 4], [3, 3], [4, 4]], dtype=torch.int32, device="cuda") + pixels = JaggedTensor([batch0, batch1, batch2]) + unique, inv, has_dups = self._run(pixels) + self.assertTrue(has_dups) + self.assertEqual(len(unique), 3) + self.assertEqual(len(unique.jdata), 5) + offsets = unique.joffsets.cpu() + self.assertEqual(offsets[0].item(), 0) + self.assertEqual(offsets[1].item(), 2) + self.assertEqual(offsets[2].item(), 3) + self.assertEqual(offsets[3].item(), 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_functional_splat_training.py b/tests/unit/test_functional_splat_training.py new file mode 100644 index 000000000..66b948a1a --- /dev/null +++ b/tests/unit/test_functional_splat_training.py @@ -0,0 +1,660 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +""" +Tests validating that the decomposed functional Gaussian splatting API can be +used to build a complete training loop WITHOUT GaussianSplat3d, and that +results match the OO API numerically. +""" + +import unittest + +import numpy as np +import torch +from fvdb.utils.tests import get_fvdb_test_data_path + +import fvdb.functional as F +from fvdb import GaussianSplat3d +from fvdb.enums import CameraModel, GaussianRenderMode + + +def rgb_to_sh(rgb: torch.Tensor) -> torch.Tensor: + C0 = 0.28209479177387814 + return (rgb - 0.5) / C0 + + +def _functional_render_4stage( + means, + quats, + log_scales, + logit_opacities, + sh0, + shN, + world_to_cam, + projection_matrices, + image_width, + image_height, + near=0.01, + far=1e10, + sh_degree_to_use=3, + tile_size=16, +): + """Full functional forward pass through the 4-stage pipeline.""" + projected = F.project_gaussians( + means, + quats, + log_scales, + world_to_cam, + projection_matrices, + image_width, + image_height, + eps_2d=0.3, + near=near, + far=far, + radius_clip=0.0, + antialias=False, + camera_model=CameraModel.PINHOLE, + ) + features = F.evaluate_gaussian_sh( + means, + sh0, + shN, + world_to_cam, + projected, + sh_degree_to_use=sh_degree_to_use, + render_mode=GaussianRenderMode.FEATURES, + ) + tiles = F.intersect_gaussian_tiles(projected, tile_size=tile_size) + images, alphas = F.rasterize_screen_space_gaussians( + projected, + features, + logit_opacities, + tiles, + ) + return images, alphas + + +class TestFunctionalSplatTraining(unittest.TestCase): + """Validate functional API forward, backward, and training loop.""" + + def setUp(self): + torch.random.manual_seed(0) + np.random.seed(0) + self.device = "cuda:0" + + data_path = get_fvdb_test_data_path() / "gsplat" / "test_garden_cropped.npz" + data = np.load(data_path) + + means = torch.from_numpy(data["means3d"]).float().to(self.device) + quats = torch.from_numpy(data["quats"]).float().to(self.device) + scales = torch.from_numpy(data["scales"]).float().to(self.device) + opacities = torch.from_numpy(data["opacities"]).float().to(self.device) + colors = torch.from_numpy(data["colors"]).float().to(self.device) + + all_w2c = torch.from_numpy(data["viewmats"]).float().to(self.device) + all_proj = torch.from_numpy(data["Ks"]).float().to(self.device) + self.W = data["width"].item() + self.H = data["height"].item() + + self.world_to_cam = all_w2c[0:1].contiguous() + self.projection_matrices = all_proj[0:1].contiguous() + + self.log_scales_data = torch.log(scales) + self.logit_opacities_data = torch.logit(opacities) + self.means_data = means + + N = means.shape[0] + sh_degree = 3 + sh_coeffs = torch.zeros((N, (sh_degree + 1) ** 2, 3), device=self.device) + sh_coeffs[:, 0, :] = rgb_to_sh(colors) + self.sh0_data = sh_coeffs[:, 0, :].unsqueeze(1).clone() + self.shN_data = sh_coeffs[:, 1:, :].clone() + + self.quats_data = quats + self.sh_degree_to_use = sh_degree + + # ------------------------------------------------------------------ + # Test 1: Forward pass numerical equivalence (4-stage vs OO) + # ------------------------------------------------------------------ + def test_functional_forward_matches_oo(self): + """Render one frame through both paths, assert images match.""" + means = self.means_data.detach().requires_grad_(True) + quats = self.quats_data.detach().requires_grad_(True) + log_scales = self.log_scales_data.detach().requires_grad_(True) + logit_opacities = self.logit_opacities_data.detach().requires_grad_(True) + sh0 = self.sh0_data.detach().requires_grad_(True) + shN = self.shN_data.detach().requires_grad_(True) + + images_fn, alphas_fn = _functional_render_4stage( + means, + quats, + log_scales, + logit_opacities, + sh0, + shN, + self.world_to_cam, + self.projection_matrices, + self.W, + self.H, + sh_degree_to_use=self.sh_degree_to_use, + ) + + gs3d = GaussianSplat3d.from_tensors( + means=means, + quats=quats, + log_scales=log_scales, + logit_opacities=logit_opacities, + sh0=sh0, + shN=shN, + ) + images_oo, alphas_oo = gs3d.render_images( + self.world_to_cam, + self.projection_matrices, + self.W, + self.H, + near=0.01, + far=1e10, + sh_degree_to_use=self.sh_degree_to_use, + ) + + torch.testing.assert_close(images_fn, images_oo, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(alphas_fn, alphas_oo, atol=1e-5, rtol=1e-5) + + # ------------------------------------------------------------------ + # Test 2: Backward pass gradient equivalence (4-stage vs OO) + # ------------------------------------------------------------------ + def test_functional_backward_matches_oo(self): + """Forward + backward through both paths, compare gradients.""" + means_fn = self.means_data.detach().clone().requires_grad_(True) + means_oo = self.means_data.detach().clone().requires_grad_(True) + quats_fn = self.quats_data.detach().clone().requires_grad_(True) + quats_oo = self.quats_data.detach().clone().requires_grad_(True) + log_scales_fn = self.log_scales_data.detach().clone().requires_grad_(True) + log_scales_oo = self.log_scales_data.detach().clone().requires_grad_(True) + logit_opacities_fn = self.logit_opacities_data.detach().clone().requires_grad_(True) + logit_opacities_oo = self.logit_opacities_data.detach().clone().requires_grad_(True) + sh0_fn = self.sh0_data.detach().clone().requires_grad_(True) + sh0_oo = self.sh0_data.detach().clone().requires_grad_(True) + shN_fn = self.shN_data.detach().clone().requires_grad_(True) + shN_oo = self.shN_data.detach().clone().requires_grad_(True) + + images_fn, _ = _functional_render_4stage( + means_fn, + quats_fn, + log_scales_fn, + logit_opacities_fn, + sh0_fn, + shN_fn, + self.world_to_cam, + self.projection_matrices, + self.W, + self.H, + sh_degree_to_use=self.sh_degree_to_use, + ) + images_fn.sum().backward() + + gs3d = GaussianSplat3d.from_tensors( + means=means_oo, + quats=quats_oo, + log_scales=log_scales_oo, + logit_opacities=logit_opacities_oo, + sh0=sh0_oo, + shN=shN_oo, + ) + gs3d.requires_grad = True + images_oo, _ = gs3d.render_images( + self.world_to_cam, + self.projection_matrices, + self.W, + self.H, + near=0.01, + far=1e10, + sh_degree_to_use=self.sh_degree_to_use, + ) + images_oo.sum().backward() + + fn_grads = { + "means": means_fn.grad, + "quats": quats_fn.grad, + "log_scales": log_scales_fn.grad, + "logit_opacities": logit_opacities_fn.grad, + "sh0": sh0_fn.grad, + "shN": shN_fn.grad, + } + for name, grad_fn in fn_grads.items(): + grad_oo = getattr(gs3d, name).grad + self.assertIsNotNone(grad_fn, f"Functional gradient for {name} is None") + self.assertIsNotNone(grad_oo, f"OO gradient for {name} is None") + torch.testing.assert_close( + grad_fn, + grad_oo, + atol=5e-3, + rtol=1e-4, + msg=f"Gradient mismatch for {name}", + ) + + # ------------------------------------------------------------------ + # Test 3: Training loop with Adam optimizer + # ------------------------------------------------------------------ + def test_functional_training_loop(self): + """5 steps of Adam on perturbed params, verify loss decreases.""" + means = self.means_data.detach().clone().requires_grad_(True) + quats = self.quats_data.detach().clone().requires_grad_(True) + log_scales = self.log_scales_data.detach().clone().requires_grad_(True) + logit_opacities = self.logit_opacities_data.detach().clone().requires_grad_(True) + sh0 = self.sh0_data.detach().clone().requires_grad_(True) + shN = self.shN_data.detach().clone().requires_grad_(True) + + params = [means, quats, log_scales, logit_opacities, sh0, shN] + param_names = ["means", "quats", "log_scales", "logit_opacities", "sh0", "shN"] + optimizer = torch.optim.Adam(params, lr=0.01) + + with torch.no_grad(): + target_images, _ = _functional_render_4stage( + means, + quats, + log_scales, + logit_opacities, + sh0, + shN, + self.world_to_cam, + self.projection_matrices, + self.W, + self.H, + sh_degree_to_use=self.sh_degree_to_use, + ) + + with torch.no_grad(): + means.add_(torch.randn_like(means) * 0.01) + + losses = [] + num_steps = 5 + for step in range(num_steps): + optimizer.zero_grad() + images, alphas = _functional_render_4stage( + means, + quats, + log_scales, + logit_opacities, + sh0, + shN, + self.world_to_cam, + self.projection_matrices, + self.W, + self.H, + sh_degree_to_use=self.sh_degree_to_use, + ) + loss = torch.nn.functional.l1_loss(images, target_images) + loss.backward() + + for param, name in zip(params, param_names): + grad = param.grad + assert grad is not None, f"Gradient for {name} is None at step {step}" + assert torch.isfinite(grad).all(), f"Non-finite gradient for {name} at step {step}" + assert grad.abs().sum() > 0, f"Zero gradient for {name} at step {step}" + + optimizer.step() + losses.append(loss.item()) + + self.assertLess( + losses[-1], + losses[0], + f"Loss did not decrease: {losses}", + ) + + +class TestFunctionalCrop(unittest.TestCase): + """Validate that the crop parameter produces correct sub-regions and raises on invalid inputs.""" + + def setUp(self): + torch.random.manual_seed(0) + np.random.seed(0) + self.device = "cuda:0" + + data_path = get_fvdb_test_data_path() / "gsplat" / "test_garden_cropped.npz" + data = np.load(data_path) + + self.means = torch.from_numpy(data["means3d"]).float().to(self.device) + self.quats = torch.from_numpy(data["quats"]).float().to(self.device) + self.log_scales = torch.log(torch.from_numpy(data["scales"]).float().to(self.device)) + self.logit_opacities = torch.logit(torch.from_numpy(data["opacities"]).float().to(self.device)) + colors = torch.from_numpy(data["colors"]).float().to(self.device) + + sh_coeffs = torch.zeros((self.means.shape[0], 16, 3), device=self.device) + sh_coeffs[:, 0, :] = rgb_to_sh(colors) + self.sh0 = sh_coeffs[:, 0, :].unsqueeze(1).clone() + self.shN = sh_coeffs[:, 1:, :].clone() + + self.world_to_cam = torch.from_numpy(data["viewmats"]).float().to(self.device)[0:1].contiguous() + self.projection_matrices = torch.from_numpy(data["Ks"]).float().to(self.device)[0:1].contiguous() + self.W = data["width"].item() + self.H = data["height"].item() + + self.projected = F.project_gaussians( + self.means, + self.quats, + self.log_scales, + self.world_to_cam, + self.projection_matrices, + self.W, + self.H, + eps_2d=0.3, + near=0.01, + far=1e10, + radius_clip=0.0, + antialias=False, + camera_model=CameraModel.PINHOLE, + ) + self.features = F.evaluate_gaussian_sh( + self.means, + self.sh0, + self.shN, + self.world_to_cam, + self.projected, + sh_degree_to_use=3, + render_mode=GaussianRenderMode.FEATURES, + ) + self.tiles = F.intersect_gaussian_tiles(self.projected, tile_size=16) + + def test_crop_none_matches_full_image(self): + """crop=None should be identical to not passing crop.""" + full, full_a = F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + ) + crop_none, crop_none_a = F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + crop=None, + ) + torch.testing.assert_close(full, crop_none) + torch.testing.assert_close(full_a, crop_none_a) + + def test_crop_full_image_matches_no_crop(self): + """crop=(0, 0, W, H) should match no-crop rendering.""" + full, full_a = F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + ) + cropped, cropped_a = F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + crop=(0, 0, self.W, self.H), + ) + self.assertEqual(cropped.shape, full.shape) + torch.testing.assert_close(full, cropped) + torch.testing.assert_close(full_a, cropped_a) + + def test_crop_sub_region_shape_and_content(self): + """A sub-region crop should have the right shape and match the corresponding slice of the full render.""" + full, _ = F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + ) + ox, oy, cw, ch = 16, 16, 64, 48 + cropped, _ = F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + crop=(ox, oy, cw, ch), + ) + self.assertEqual(cropped.shape[1], ch) + self.assertEqual(cropped.shape[2], cw) + expected = full[:, oy : oy + ch, ox : ox + cw, :] + torch.testing.assert_close(cropped, expected, atol=1e-5, rtol=1e-5) + + def test_crop_is_differentiable(self): + """Gradients should flow through the crop path.""" + logit_ops = self.logit_opacities.detach().clone().requires_grad_(True) + cropped, _ = F.rasterize_screen_space_gaussians( + self.projected, + self.features, + logit_ops, + self.tiles, + crop=(0, 0, 64, 48), + ) + cropped.sum().backward() + self.assertIsNotNone(logit_ops.grad) + self.assertTrue(logit_ops.grad.abs().sum() > 0) + + def test_crop_clamps_to_image_bounds(self): + """A crop extending beyond the image should be clamped, not error.""" + cropped, _ = F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + crop=(self.W - 32, self.H - 24, 128, 128), + ) + self.assertEqual(cropped.shape[2], 32) + self.assertEqual(cropped.shape[1], 24) + + def test_crop_rejects_negative_origin(self): + with self.assertRaises(ValueError, msg="negative origin_x"): + F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + crop=(-1, 0, 64, 48), + ) + + def test_crop_rejects_zero_size(self): + with self.assertRaises(ValueError, msg="zero width"): + F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + crop=(0, 0, 0, 48), + ) + + def test_crop_rejects_no_overlap(self): + with self.assertRaises(ValueError, msg="origin beyond image"): + F.rasterize_screen_space_gaussians( + self.projected, + self.features, + self.logit_opacities, + self.tiles, + crop=(self.W + 10, 0, 64, 48), + ) + + def test_count_contributing_crop_matches_full(self): + """count_contributing_gaussians with crop=(0,0,W,H) matches no-crop.""" + full_num, full_w = F.count_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + ) + cropped_num, cropped_w = F.count_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + crop=(0, 0, self.W, self.H), + ) + torch.testing.assert_close(full_num, cropped_num) + torch.testing.assert_close(full_w, cropped_w) + + def test_count_contributing_crop_sub_region(self): + """count_contributing_gaussians with a sub-region crop matches the slice of the full result.""" + full_num, _ = F.count_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + ) + ox, oy, cw, ch = 16, 16, 64, 48 + cropped_num, _ = F.count_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + crop=(ox, oy, cw, ch), + ) + self.assertEqual(cropped_num.shape[1], ch) + self.assertEqual(cropped_num.shape[2], cw) + expected = full_num[:, oy : oy + ch, ox : ox + cw] + torch.testing.assert_close(cropped_num, expected) + + def test_world_space_crop_full_matches_no_crop(self): + """rasterize_world_space_gaussians with crop=(0,0,W,H) matches no-crop.""" + distortion_coeffs = torch.empty( + self.world_to_cam.shape[0], + 0, + device=self.device, + dtype=torch.float32, + ) + full, full_a = F.rasterize_world_space_gaussians( + self.means, + self.quats, + self.log_scales, + self.projected, + self.features, + self.logit_opacities, + self.world_to_cam, + self.projection_matrices, + distortion_coeffs, + CameraModel.PINHOLE, + self.tiles, + ) + cropped, cropped_a = F.rasterize_world_space_gaussians( + self.means, + self.quats, + self.log_scales, + self.projected, + self.features, + self.logit_opacities, + self.world_to_cam, + self.projection_matrices, + distortion_coeffs, + CameraModel.PINHOLE, + self.tiles, + crop=(0, 0, self.W, self.H), + ) + torch.testing.assert_close(full, cropped) + torch.testing.assert_close(full_a, cropped_a) + + def test_world_space_crop_sub_region(self): + """rasterize_world_space_gaussians with a sub-region crop matches the slice of the full result.""" + distortion_coeffs = torch.empty( + self.world_to_cam.shape[0], + 0, + device=self.device, + dtype=torch.float32, + ) + full, _ = F.rasterize_world_space_gaussians( + self.means, + self.quats, + self.log_scales, + self.projected, + self.features, + self.logit_opacities, + self.world_to_cam, + self.projection_matrices, + distortion_coeffs, + CameraModel.PINHOLE, + self.tiles, + ) + ox, oy, cw, ch = 16, 16, 64, 48 + cropped, _ = F.rasterize_world_space_gaussians( + self.means, + self.quats, + self.log_scales, + self.projected, + self.features, + self.logit_opacities, + self.world_to_cam, + self.projection_matrices, + distortion_coeffs, + CameraModel.PINHOLE, + self.tiles, + crop=(ox, oy, cw, ch), + ) + self.assertEqual(cropped.shape[1], ch) + self.assertEqual(cropped.shape[2], cw) + expected = full[:, oy : oy + ch, ox : ox + cw, :] + torch.testing.assert_close(cropped, expected, atol=1e-5, rtol=1e-5) + + def test_identify_crop_full_matches_no_crop(self): + """identify_contributing_gaussians with crop=(0,0,W,H) matches no-crop.""" + full_ids, full_w = F.identify_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + top_k_contributors=5, + ) + cropped_ids, cropped_w = F.identify_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + top_k_contributors=5, + crop=(0, 0, self.W, self.H), + ) + torch.testing.assert_close(full_ids.jdata, cropped_ids.jdata) + torch.testing.assert_close(full_w.jdata, cropped_w.jdata) + torch.testing.assert_close(full_ids.joffsets, cropped_ids.joffsets) + + def test_identify_crop_sub_region(self): + """identify_contributing_gaussians with a sub-region crop selects the correct pixels.""" + full_ids, full_w = F.identify_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + top_k_contributors=5, + ) + ox, oy, cw, ch = 16, 16, 64, 48 + cropped_ids, cropped_w = F.identify_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + top_k_contributors=5, + crop=(ox, oy, cw, ch), + ) + + C = len(full_ids) + self.assertEqual(len(cropped_ids), C) + + for c in range(C): + full_cam = full_ids[c] + crop_cam = cropped_ids[c] + self.assertEqual(len(crop_cam), cw * ch) + + for dy in range(ch): + for dx in range(cw): + crop_pixel_idx = dy * cw + dx + full_pixel_idx = (oy + dy) * self.W + (ox + dx) + crop_pixel_data = crop_cam[crop_pixel_idx].jdata + full_pixel_data = full_cam[full_pixel_idx].jdata + torch.testing.assert_close(crop_pixel_data, full_pixel_data) + + def test_identify_crop_rejects_invalid(self): + """identify_contributing_gaussians raises on invalid crop inputs.""" + with self.assertRaises(ValueError): + F.identify_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + top_k_contributors=5, + crop=(-1, 0, 64, 48), + ) + with self.assertRaises(ValueError): + F.identify_contributing_gaussians( + self.projected, + self.logit_opacities, + self.tiles, + top_k_contributors=5, + crop=(0, 0, 0, 48), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_functional_splat_ut.py b/tests/unit/test_functional_splat_ut.py new file mode 100644 index 000000000..77cc0fac6 --- /dev/null +++ b/tests/unit/test_functional_splat_ut.py @@ -0,0 +1,194 @@ +# Copyright Contributors to the OpenVDB Project +# SPDX-License-Identifier: Apache-2.0 +# +""" +Tests for the Unscented Transform (UT) projection path through the +decomposed functional API. +""" + +import unittest + +import numpy as np +import torch +from fvdb.utils.tests import get_fvdb_test_data_path + +import fvdb.functional as F +from fvdb import GaussianSplat3d +from fvdb.enums import CameraModel, GaussianRenderMode, ProjectionMethod +from fvdb import _fvdb_cpp as _C + + +def rgb_to_sh(rgb: torch.Tensor) -> torch.Tensor: + C0 = 0.28209479177387814 + return (rgb - 0.5) / C0 + + +class TestUTProjection(unittest.TestCase): + """Validate the UT projection path through the functional API.""" + + def setUp(self): + torch.random.manual_seed(42) + np.random.seed(42) + self.device = "cuda:0" + + data_path = get_fvdb_test_data_path() / "gsplat" / "test_garden_cropped.npz" + data = np.load(data_path) + + means = torch.from_numpy(data["means3d"]).float().to(self.device) + quats = torch.from_numpy(data["quats"]).float().to(self.device) + scales = torch.from_numpy(data["scales"]).float().to(self.device) + opacities = torch.from_numpy(data["opacities"]).float().to(self.device) + colors = torch.from_numpy(data["colors"]).float().to(self.device) + + all_w2c = torch.from_numpy(data["viewmats"]).float().to(self.device) + all_proj = torch.from_numpy(data["Ks"]).float().to(self.device) + self.W = data["width"].item() + self.H = data["height"].item() + + self.world_to_cam = all_w2c[0:1].contiguous() + self.projection_matrices = all_proj[0:1].contiguous() + + self.log_scales = torch.log(scales) + self.logit_opacities = torch.logit(opacities) + self.means = means + self.quats = quats + + N = means.shape[0] + sh_degree = 3 + sh_coeffs = torch.zeros((N, (sh_degree + 1) ** 2, 3), device=self.device) + sh_coeffs[:, 0, :] = rgb_to_sh(colors) + self.sh0 = sh_coeffs[:, 0, :].unsqueeze(1).clone() + self.shN = sh_coeffs[:, 1:, :].clone() + self.sh_degree = sh_degree + + def test_ut_fwd_binding_produces_valid_shapes(self): + """project_gaussians_ut_fwd produces correctly shaped outputs.""" + C = self.world_to_cam.size(0) + N = self.means.size(0) + dc = torch.empty(C, 0, device=self.device, dtype=self.means.dtype) + + radii, means2d, depths, conics, compensations = _C.project_gaussians_ut_fwd( + self.means, + self.quats, + self.log_scales, + self.world_to_cam, + self.projection_matrices, + dc, + _C.CameraModel.PINHOLE, + self.W, + self.H, + 0.3, + 0.01, + 1e10, + 0.0, + False, + ) + + self.assertEqual(radii.shape, (C, N)) + self.assertEqual(means2d.shape, (C, N, 2)) + self.assertEqual(depths.shape, (C, N)) + self.assertEqual(conics.shape, (C, N, 3)) + + def test_project_gaussians_ut_returns_valid_projected(self): + """project_gaussians with UNSCENTED produces a valid ProjectedGaussians.""" + projected = F.project_gaussians( + means=self.means, + quats=self.quats, + log_scales=self.log_scales, + world_to_camera_matrices=self.world_to_cam, + projection_matrices=self.projection_matrices, + image_width=self.W, + image_height=self.H, + projection_method=ProjectionMethod.UNSCENTED, + ) + + self.assertIsNotNone(projected.means2d) + self.assertIsNotNone(projected.conics) + self.assertIsNotNone(projected.radii) + self.assertIsNotNone(projected.depths) + + def test_ut_is_not_differentiable_through_projection(self): + """UT projection does not produce gradients on means/quats/scales.""" + means = self.means.detach().requires_grad_(True) + quats = self.quats.detach().requires_grad_(True) + log_scales = self.log_scales.detach().requires_grad_(True) + + projected = F.project_gaussians( + means=means, + quats=quats, + log_scales=log_scales, + world_to_camera_matrices=self.world_to_cam, + projection_matrices=self.projection_matrices, + image_width=self.W, + image_height=self.H, + projection_method=ProjectionMethod.UNSCENTED, + ) + + self.assertIsNone(projected.means2d.grad_fn) + + def test_ut_and_analytic_produce_comparable_images(self): + """Both projection methods produce non-trivial, finite rendered images.""" + projected_analytic = F.project_gaussians( + means=self.means, + quats=self.quats, + log_scales=self.log_scales, + world_to_camera_matrices=self.world_to_cam, + projection_matrices=self.projection_matrices, + image_width=self.W, + image_height=self.H, + projection_method=ProjectionMethod.ANALYTIC, + ) + projected_ut = F.project_gaussians( + means=self.means, + quats=self.quats, + log_scales=self.log_scales, + world_to_camera_matrices=self.world_to_cam, + projection_matrices=self.projection_matrices, + image_width=self.W, + image_height=self.H, + projection_method=ProjectionMethod.UNSCENTED, + ) + + features_a = F.evaluate_gaussian_sh( + self.means, + self.sh0, + self.shN, + self.world_to_cam, + projected_analytic, + sh_degree_to_use=self.sh_degree, + render_mode=GaussianRenderMode.FEATURES, + ) + features_u = F.evaluate_gaussian_sh( + self.means, + self.sh0, + self.shN, + self.world_to_cam, + projected_ut, + sh_degree_to_use=self.sh_degree, + render_mode=GaussianRenderMode.FEATURES, + ) + + tiles_a = F.intersect_gaussian_tiles(projected_analytic) + tiles_u = F.intersect_gaussian_tiles(projected_ut) + + images_a, _ = F.rasterize_screen_space_gaussians( + projected_analytic, + features_a, + self.logit_opacities, + tiles_a, + ) + images_u, _ = F.rasterize_screen_space_gaussians( + projected_ut, + features_u, + self.logit_opacities, + tiles_u, + ) + + self.assertTrue(torch.isfinite(images_a).all(), "Analytic image has non-finite values") + self.assertTrue(torch.isfinite(images_u).all(), "UT image has non-finite values") + self.assertGreater(images_a.abs().sum().item(), 0, "Analytic image is all zeros") + self.assertGreater(images_u.abs().sum().item(), 0, "UT image is all zeros") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_gaussian_splat_3d.py b/tests/unit/test_gaussian_splat_3d.py index 89948cb93..2010bad36 100644 --- a/tests/unit/test_gaussian_splat_3d.py +++ b/tests/unit/test_gaussian_splat_3d.py @@ -27,6 +27,7 @@ evaluate_spherical_harmonics, gaussian_render_jagged, ) +from fvdb.enums import GaussianRenderMode def compare_images(pixels_or_path_a, pixels_or_path_b): @@ -180,6 +181,8 @@ def check_grad(self): else: self.assertEqual(self.gs3d1.accumulated_max_2d_radii, None) + assert self.gs3d1.accumulated_gradient_step_counts is not None + assert self.gs3d1.accumulated_mean_2d_gradient_norms is not None self.assertTrue(self.gs3d1.accumulated_gradient_step_counts.shape == (self.gs3d1.num_gaussians,)) self.assertTrue(self.gs3d1.accumulated_mean_2d_gradient_norms.shape == (self.gs3d1.num_gaussians,)) @@ -188,9 +191,17 @@ def check_basic(self, gs3d_cat, gs3d_list, acc2d_rad, acc_m2dgrad): self.assertTrue(gs3d_cat.num_gaussians == len(gs3d_list) * self.gs3d.num_gaussians) self.assertTrue(torch.equal(gs3d_cat.means, torch.cat([gs.means for gs in gs3d_list], dim=0))) self.assertTrue(torch.equal(gs3d_cat.quats, torch.cat([gs.quats for gs in gs3d_list], dim=0))) - self.assertTrue(torch.equal(gs3d_cat.log_scales, torch.cat([gs.log_scales for gs in gs3d_list], dim=0))) self.assertTrue( - torch.equal(gs3d_cat.logit_opacities, torch.cat([gs.logit_opacities for gs in gs3d_list], dim=0)) + torch.equal( + gs3d_cat.log_scales, + torch.cat([gs.log_scales for gs in gs3d_list], dim=0), + ) + ) + self.assertTrue( + torch.equal( + gs3d_cat.logit_opacities, + torch.cat([gs.logit_opacities for gs in gs3d_list], dim=0), + ) ) self.assertTrue(torch.equal(gs3d_cat.sh0, torch.cat([gs.sh0 for gs in gs3d_list], dim=0))) self.assertTrue(torch.equal(gs3d_cat.shN, torch.cat([gs.shN for gs in gs3d_list], dim=0))) @@ -198,6 +209,7 @@ def check_basic(self, gs3d_cat, gs3d_list, acc2d_rad, acc_m2dgrad): self.assertEqual(gs3d_cat.accumulate_mean_2d_gradients, acc_m2dgrad) if gs3d_cat.accumulate_max_2d_radii: + assert gs3d_cat.accumulated_max_2d_radii is not None self.assertTrue(gs3d_cat.accumulated_max_2d_radii.shape == (gs3d_cat.num_gaussians,)) self.assertTrue(gs3d_cat.accumulated_max_2d_radii.dtype == torch.int32) self.assertTrue(gs3d_cat.accumulated_max_2d_radii.device == self.device) @@ -205,6 +217,8 @@ def check_basic(self, gs3d_cat, gs3d_list, acc2d_rad, acc_m2dgrad): self.assertEqual(gs3d_cat.accumulated_max_2d_radii, None) if gs3d_cat.accumulate_max_2d_radii: + assert gs3d_cat.accumulated_gradient_step_counts is not None + assert gs3d_cat.accumulated_mean_2d_gradient_norms is not None self.assertTrue(gs3d_cat.accumulated_gradient_step_counts.shape == (gs3d_cat.num_gaussians,)) self.assertTrue(gs3d_cat.accumulated_gradient_step_counts.dtype == torch.int32) self.assertTrue(gs3d_cat.accumulated_gradient_step_counts.device == self.device) @@ -218,22 +232,32 @@ def check_basic(self, gs3d_cat, gs3d_list, acc2d_rad, acc_m2dgrad): def test_cat_basic(self): gs3d_cat = GaussianSplat3d.cat( - [self.gs3d1, self.gs3d2, self.gs3d3], accumulate_max_2d_radii=False, accumulate_mean_2d_gradients=False + [self.gs3d1, self.gs3d2, self.gs3d3], + accumulate_max_2d_radii=False, + accumulate_mean_2d_gradients=False, ) self.check_grad() self.check_basic(gs3d_cat, [self.gs3d1, self.gs3d2, self.gs3d3], False, False) def test_cat_track_state_no_backward_on_two_and_three(self): gs3d_cat = GaussianSplat3d.cat( - [self.gs3d1, self.gs3d2, self.gs3d3], accumulate_max_2d_radii=True, accumulate_mean_2d_gradients=True + [self.gs3d1, self.gs3d2, self.gs3d3], + accumulate_max_2d_radii=True, + accumulate_mean_2d_gradients=True, ) self.check_grad() self.check_basic(gs3d_cat, [self.gs3d1, self.gs3d2, self.gs3d3], True, True) + assert gs3d_cat.accumulated_gradient_step_counts is not None + assert gs3d_cat.accumulated_mean_2d_gradient_norms is not None + assert gs3d_cat.accumulated_max_2d_radii is not None self.assertTrue(gs3d_cat.accumulated_gradient_step_counts.shape, (gs3d_cat.num_gaussians,)) self.assertTrue(gs3d_cat.accumulated_mean_2d_gradient_norms.shape, (gs3d_cat.num_gaussians,)) if self.run_backward: + assert self.gs3d1.accumulated_gradient_step_counts is not None + assert self.gs3d1.accumulated_mean_2d_gradient_norms is not None + assert self.gs3d1.accumulated_max_2d_radii is not None step_counts = torch.cat( [ self.gs3d1.accumulated_gradient_step_counts, @@ -267,14 +291,20 @@ def test_cat_track_state_no_backward_on_two_and_three(self): self.assertTrue(torch.equal(gs3d_cat.accumulated_max_2d_radii, max_radii)) def test_cat_track_state_all_backward(self): - gs3d1_d, gs3d2_d, gs3d3_d = self.gs3d1.detach(), self.gs3d2.detach(), self.gs3d3.detach() + gs3d1_d, gs3d2_d, gs3d3_d = ( + self.gs3d1.detach(), + self.gs3d2.detach(), + self.gs3d3.detach(), + ) if self.run_backward: self.run_backward_on_gs3d(gs3d1_d) self.run_backward_on_gs3d(gs3d2_d) self.run_backward_on_gs3d(gs3d3_d) gs3d_cat = GaussianSplat3d.cat( - [gs3d1_d, gs3d2_d, gs3d3_d], accumulate_max_2d_radii=True, accumulate_mean_2d_gradients=True + [gs3d1_d, gs3d2_d, gs3d3_d], + accumulate_max_2d_radii=True, + accumulate_mean_2d_gradients=True, ) self.check_grad() @@ -283,10 +313,22 @@ def test_cat_track_state_all_backward(self): self.assertTrue(gs3d_cat.accumulate_max_2d_radii) self.assertTrue(gs3d_cat.accumulate_mean_2d_gradients) + assert gs3d_cat.accumulated_gradient_step_counts is not None + assert gs3d_cat.accumulated_mean_2d_gradient_norms is not None + assert gs3d_cat.accumulated_max_2d_radii is not None self.assertTrue(gs3d_cat.accumulated_gradient_step_counts.shape, (gs3d_cat.num_gaussians,)) self.assertTrue(gs3d_cat.accumulated_mean_2d_gradient_norms.shape, (gs3d_cat.num_gaussians,)) if self.run_backward: + assert gs3d1_d.accumulated_gradient_step_counts is not None + assert gs3d2_d.accumulated_gradient_step_counts is not None + assert gs3d3_d.accumulated_gradient_step_counts is not None + assert gs3d1_d.accumulated_mean_2d_gradient_norms is not None + assert gs3d2_d.accumulated_mean_2d_gradient_norms is not None + assert gs3d3_d.accumulated_mean_2d_gradient_norms is not None + assert gs3d1_d.accumulated_max_2d_radii is not None + assert gs3d2_d.accumulated_max_2d_radii is not None + assert gs3d3_d.accumulated_max_2d_radii is not None step_counts = torch.cat( [ gs3d1_d.accumulated_gradient_step_counts, @@ -364,6 +406,7 @@ def check_device_and_dtype(self, gs3d, device, dtype): self.assertTrue(gs3d.dtype == dtype) if gs3d.accumulated_gradient_step_counts is not None: assert self.run_backward, "accumulated_gradient_step_counts should only be set when run_backward is True" + assert self.gs3d.accumulated_gradient_step_counts is not None self.assertTrue( gs3d.accumulated_gradient_step_counts.shape == self.gs3d.accumulated_gradient_step_counts.shape ) @@ -374,6 +417,7 @@ def check_device_and_dtype(self, gs3d, device, dtype): self.assertEqual(self.gs3d.accumulated_gradient_step_counts, None) if gs3d.accumulated_mean_2d_gradient_norms is not None: assert self.run_backward, "accumulated_mean_2d_gradient_norms should only be set when run_backward is True" + assert self.gs3d.accumulated_mean_2d_gradient_norms is not None self.assertTrue( gs3d.accumulated_mean_2d_gradient_norms.shape == self.gs3d.accumulated_mean_2d_gradient_norms.shape ) @@ -385,6 +429,7 @@ def check_device_and_dtype(self, gs3d, device, dtype): if gs3d.accumulated_max_2d_radii is not None: assert self.run_backward, "accumulated_max_2d_radii should only be set when run_backward is True" + assert self.gs3d.accumulated_max_2d_radii is not None self.assertTrue(gs3d.accumulated_max_2d_radii.shape == self.gs3d.accumulated_max_2d_radii.shape) self.assertTrue(gs3d.accumulated_max_2d_radii.device == device) self.assertTrue(gs3d.accumulated_max_2d_radii.dtype == torch.int32) @@ -571,6 +616,10 @@ def compare_src_and_dst( # Check that both the source and destination Gaussian Splat get their accumulate # gradient state correctly set if src_acc_m2d_grads and dst_track_m2d_grads: + assert src.accumulated_gradient_step_counts is not None + assert dst.accumulated_gradient_step_counts is not None + assert src.accumulated_mean_2d_gradient_norms is not None + assert dst.accumulated_mean_2d_gradient_norms is not None assertfun( torch.equal( src.accumulated_gradient_step_counts, @@ -584,6 +633,8 @@ def compare_src_and_dst( ) ) if track_max_2d_radii: + assert src.accumulated_max_2d_radii is not None + assert dst.accumulated_max_2d_radii is not None assertfun( torch.equal( src.accumulated_max_2d_radii, @@ -591,6 +642,8 @@ def compare_src_and_dst( ) ) elif dst_track_m2d_grads and not src_acc_m2d_grads: + assert dst.accumulated_gradient_step_counts is not None + assert dst.accumulated_mean_2d_gradient_norms is not None assertfun( torch.equal( torch.zeros(src.num_gaussians).to(dst.accumulated_gradient_step_counts), @@ -604,6 +657,7 @@ def compare_src_and_dst( ) ) if track_max_2d_radii: + assert dst.accumulated_max_2d_radii is not None assertfun( torch.equal( torch.zeros(src.num_gaussians).to(dst.accumulated_max_2d_radii), @@ -615,12 +669,22 @@ def compare_src_and_dst( self.assertEqual(dst.accumulated_mean_2d_gradient_norms, None) self.assertEqual(dst.accumulated_gradient_step_counts, None) # Check that the destination Gaussian Splat has the same gradient shapes as before + assert src.accumulated_gradient_step_counts is not None + assert src.accumulated_mean_2d_gradient_norms is not None self.assertTrue(src.accumulated_gradient_step_counts.shape == (src.num_gaussians,)) self.assertTrue(src.accumulated_mean_2d_gradient_norms.shape == (src.num_gaussians,)) if track_max_2d_radii: + assert src.accumulated_max_2d_radii is not None self.assertTrue(src.accumulated_max_2d_radii.shape == (src.num_gaussians,)) - def _run_test(self, indices, src_requires_grad, dst_requires_grad, track_max_2d_radii, slicefun=None): + def _run_test( + self, + indices, + src_requires_grad, + dst_requires_grad, + track_max_2d_radii, + slicefun=None, + ): # Create the source and destination Gaussian Splats src, dst = self.make_src_and_dst( indices, @@ -828,6 +892,10 @@ def _check( if accumulate_mean_2d_gradients: # Ensure the gradients and accumulated gradient state match at every other Gaussian + assert selected.accumulated_gradient_step_counts is not None + assert dst.accumulated_gradient_step_counts is not None + assert selected.accumulated_mean_2d_gradient_norms is not None + assert dst.accumulated_mean_2d_gradient_norms is not None self.assertTrue( torch.equal( selected.accumulated_gradient_step_counts, @@ -841,6 +909,8 @@ def _check( ) ) if accumulate_max_2d_radii: + assert selected.accumulated_max_2d_radii is not None + assert dst.accumulated_max_2d_radii is not None self.assertTrue( torch.equal( selected.accumulated_max_2d_radii, @@ -849,7 +919,10 @@ def _check( ) def _make_gs3d( - self, accumulate_mean_2d_gradients: bool, accumulate_max_2d_radii: bool, empty_shN: bool + self, + accumulate_mean_2d_gradients: bool, + accumulate_max_2d_radii: bool, + empty_shN: bool, ) -> GaussianSplat3d: # Create a GaussianSplat3d instance with gradients that matches self.gs3d shN = torch.empty((self.gs3d.num_gaussians, 0, 3), device=self.device) if empty_shN else self.gs3d.shN @@ -880,9 +953,12 @@ def _make_gs3d( # Check that we tracked accumulated gradient state properly if accumulate_mean_2d_gradients: + assert gs3d.accumulated_gradient_step_counts is not None + assert gs3d.accumulated_mean_2d_gradient_norms is not None self.assertTrue(gs3d.accumulated_gradient_step_counts.shape == (gs3d.num_gaussians,)) self.assertTrue(gs3d.accumulated_mean_2d_gradient_norms.shape == (gs3d.num_gaussians,)) if accumulate_max_2d_radii: + assert gs3d.accumulated_max_2d_radii is not None self.assertTrue(gs3d.accumulated_max_2d_radii.shape == (gs3d.num_gaussians,)) return gs3d @@ -1292,13 +1368,21 @@ def test_save_and_load_ply_with_training_info(self): assert isinstance(training_info["camera_to_world_matrices"], torch.Tensor) assert isinstance(training_info["projection_types"], torch.Tensor) assert isinstance(training_info["projection_parameters"], torch.Tensor) - self.assertTrue(torch.allclose(training_info["normalization_transform"], normalization_tx.to(self.device))) + self.assertTrue( + torch.allclose( + training_info["normalization_transform"], + normalization_tx.to(self.device), + ) + ) self.assertTrue(torch.allclose(training_info["camera_to_world_matrices"], cam_to_worlds.to(self.device))) self.assertTrue(torch.equal(training_info["projection_types"], cam_types.to(self.device))) self.assertTrue(torch.allclose(training_info["projection_parameters"], proj_params.to(self.device))) self.assertEqual(training_info["float_param"], 0.121243243523524650345740953) self.assertEqual(training_info["int_param"], 8198767135) - self.assertEqual(training_info["string_parameter"], "The Quick brown fox jumps over the lazy dog") + self.assertEqual( + training_info["string_parameter"], + "The Quick brown fox jumps over the lazy dog", + ) def test_save_ply_only_string_keys(self): tf = tempfile.NamedTemporaryFile(delete=True, suffix=".ply") @@ -1368,7 +1452,12 @@ def test_save_and_load_ply_with_training_info_non_contiguous(self): assert isinstance(training_info["projection_types"], torch.Tensor) assert isinstance(training_info["projection_parameters"], torch.Tensor) self.assertTrue(torch.allclose(training_info["normalization_tx"], normalization_tx.to(self.device))) - self.assertTrue(torch.allclose(training_info["camera_to_world_matrices123"], cam_to_worlds.to(self.device))) + self.assertTrue( + torch.allclose( + training_info["camera_to_world_matrices123"], + cam_to_worlds.to(self.device), + ) + ) self.assertTrue(torch.equal(training_info["projection_types"], cam_types.to(self.device))) self.assertTrue(torch.allclose(training_info["projection_parameters"], proj_params.to(self.device))) self.assertEqual(training_info["version_string"], "my version") @@ -1381,7 +1470,7 @@ def setUp(self): super().setUp() def test_gaussian_projection(self): - proj_res = self.gs3d.project_gaussians_for_images_and_depths( + proj_res = self.gs3d.project_gaussians( self.cam_to_world_mats, self.projection_mats, self.width, @@ -1391,8 +1480,8 @@ def test_gaussian_projection(self): ) radii = proj_res.radii means2d = proj_res.means2d - depths = proj_res.render_quantities[..., -1] - conics = proj_res.inv_covar_2d + depths = proj_res.depths + conics = proj_res.conics if self.save_regression_data: torch.save(radii, "regression_radii.pt") @@ -1411,8 +1500,8 @@ def test_gaussian_projection(self): torch.testing.assert_close(depths[radii > 0], test_depths[radii > 0]) torch.testing.assert_close(conics[radii > 0], test_conics[radii > 0], atol=1e-5, rtol=1e-4) - def test_projection_camera_metadata(self): - projected = self.gs3d.project_gaussians_for_images( + def test_projection_returns_valid_projected_gaussians(self): + projected = self.gs3d.project_gaussians( self.cam_to_world_mats[:1], self.projection_mats[:1], self.width, @@ -1423,8 +1512,12 @@ def test_projection_camera_metadata(self): projection_method=ProjectionMethod.AUTO, ) - self.assertEqual(projected.camera_model, CameraModel.PINHOLE) - self.assertEqual(projected.projection_method, ProjectionMethod.ANALYTIC) + self.assertEqual(projected.image_width, self.width) + self.assertEqual(projected.image_height, self.height) + self.assertEqual(projected.means2d.shape[0], 1) + self.assertIsNotNone(projected.conics) + self.assertIsNotNone(projected.depths) + self.assertIsNotNone(projected.radii) def test_from_world_depth_and_rgbd_render(self): cam_mats = self.cam_to_world_mats[:1] @@ -1610,7 +1703,12 @@ def test_gaussians_center_render(self): # For image size 1024x512, principal point should be around (512, 256) focal_length = 18.0 # Reasonable focal length for this image size intrinsics = torch.tensor( - [[focal_length, 0.0, w / 2.0], [0.0, focal_length, h / 2.0], [0.0, 0.0, 1.0]], device=self.device + [ + [focal_length, 0.0, w / 2.0], + [0.0, focal_length, h / 2.0], + [0.0, 0.0, 1.0], + ], + device=self.device, ) means3d = torch.cat( @@ -1623,7 +1721,11 @@ def test_gaussians_center_render(self): opacities = torch.cat( [ - torch.full((means3d.shape[0] // num_gaussian_layers,), 0.4, device=means3d.device) + torch.full( + (means3d.shape[0] // num_gaussian_layers,), + 0.4, + device=means3d.device, + ) for _ in range(num_gaussian_layers) ], dim=0, @@ -1727,7 +1829,8 @@ def test_gaussians_center_render(self): # test the center pixel should have the correct number of contributing gaussians self.assertTrue( torch.equal( - sparse_num_contributing_gaussians.unbind()[0][0], num_contributing_gaussians[0][h // 2 - 1][w // 2 - 1] + sparse_num_contributing_gaussians.unbind()[0][0], + num_contributing_gaussians[0][h // 2 - 1][w // 2 - 1], ) ) @@ -1782,7 +1885,12 @@ def test_gaussians_grid_render(self): # For image size 1024x512, principal point should be around (512, 256) focal_length = 18.0 # Reasonable focal length for this image size intrinsics = torch.tensor( - [[focal_length, 0.0, w / 2.0], [0.0, focal_length, h / 2.0], [0.0, 0.0, 1.0]], device=self.device + [ + [focal_length, 0.0, w / 2.0], + [0.0, focal_length, h / 2.0], + [0.0, 0.0, 1.0], + ], + device=self.device, ) means3d = torch.cat( @@ -1797,7 +1905,11 @@ def test_gaussians_grid_render(self): opacities = torch.cat( [ - torch.full((means3d.shape[0] // num_gaussian_layers,), 0.4, device=means3d.device) + torch.full( + (means3d.shape[0] // num_gaussian_layers,), + 0.4, + device=means3d.device, + ) for _ in range(num_gaussian_layers) ], dim=0, @@ -1832,14 +1944,18 @@ def test_gaussians_grid_render(self): alphas_centers = alphas[0][::2, ::2] expected_num_gaussians_centers = torch.full( - (num_gaussians_centers.shape[0], num_gaussians_centers.shape[1]), num_gaussian_layers, device=self.device + (num_gaussians_centers.shape[0], num_gaussians_centers.shape[1]), + num_gaussian_layers, + device=self.device, ) self.assertTrue(torch.equal(num_gaussians_centers, expected_num_gaussians_centers)) # pixels directly under the centers of the gaussians should have the correct alpha expected_alpha = self.calculate_expected_alpha(opacities[0].item(), num_gaussian_layers) expected_alphas_centers = torch.full( - (alphas_centers.shape[0], alphas_centers.shape[1]), expected_alpha, device=self.device + (alphas_centers.shape[0], alphas_centers.shape[1]), + expected_alpha, + device=self.device, ) self.assertTrue(torch.allclose(alphas_centers, expected_alphas_centers, atol=1e-5, rtol=1e-8)) @@ -1954,8 +2070,14 @@ def test_gaussians_grid_render(self): 10000.0, ) - for pixels, test_num_contributing_gaussians, reference_num_contributing_gaussians in zip( - pixels_to_render.unbind(), sparse_num_contributing_gaussians.unbind(), num_contributing_gaussians + for ( + pixels, + test_num_contributing_gaussians, + reference_num_contributing_gaussians, + ) in zip( + pixels_to_render.unbind(), + sparse_num_contributing_gaussians.unbind(), + num_contributing_gaussians, ): assert isinstance(pixels, torch.Tensor) assert isinstance(test_num_contributing_gaussians, torch.Tensor) @@ -1963,7 +2085,12 @@ def test_gaussians_grid_render(self): x_coords = pixels[:, 1] # [num_pixels_to_render] # Index reference_num_contributing_gaussians using the coordinates selected_reference_num_contributing_gaussians = reference_num_contributing_gaussians[y_coords, x_coords] - self.assertTrue(torch.equal(test_num_contributing_gaussians, selected_reference_num_contributing_gaussians)) + self.assertTrue( + torch.equal( + test_num_contributing_gaussians, + selected_reference_num_contributing_gaussians, + ) + ) for pixels, sparse_alphas, reference_alphas in zip(pixels_to_render.unbind(), sparse_alphas.unbind(), alphas): assert isinstance(pixels, torch.Tensor) @@ -1986,8 +2113,18 @@ def test_gaussians_grid_render(self): ) # Compare sparse results with dense results - for pixels, sparse_camera_ids, sparse_camera_weights, reference_camera_ids, reference_camera_weights in zip( - pixels_to_render.unbind(), sparse_ids.unbind(), sparse_weights.unbind(), ids.unbind(), weights.unbind() + for ( + pixels, + sparse_camera_ids, + sparse_camera_weights, + reference_camera_ids, + reference_camera_weights, + ) in zip( + pixels_to_render.unbind(), + sparse_ids.unbind(), + sparse_weights.unbind(), + ids.unbind(), + weights.unbind(), ): assert isinstance(pixels, torch.Tensor) @@ -2051,15 +2188,23 @@ def test_gaussian_contributors_scene_render(self): prev_num_contributing_gaussians = num_contributing_gaussians if self.save_regression_data: - torch.save(num_contributing_gaussians, self.data_path / "regression_num_contributing_gaussians.pt") - torch.save(alphas, self.data_path / "regression_num_contributing_gaussians_alphas.pt") + torch.save( + num_contributing_gaussians, + self.data_path / "regression_num_contributing_gaussians.pt", + ) + torch.save( + alphas, + self.data_path / "regression_num_contributing_gaussians_alphas.pt", + ) # load the regression data num_contributing_gaussians_regression = torch.load( - self.data_path / "regression_num_contributing_gaussians.pt", weights_only=True + self.data_path / "regression_num_contributing_gaussians.pt", + weights_only=True, ) alphas_regression = torch.load( - self.data_path / "regression_num_contributing_gaussians_alphas.pt", weights_only=True + self.data_path / "regression_num_contributing_gaussians_alphas.pt", + weights_only=True, ) self.assertTrue(torch.equal(num_contributing_gaussians, num_contributing_gaussians_regression)) @@ -2080,9 +2225,13 @@ def test_gaussian_contributors_scene_render(self): torch.save(weights, self.data_path / "regression_contributing_gaussian_weights.pt") # load the regression data - ids_regression = torch.load(self.data_path / "regression_contributing_gaussian_ids.pt", weights_only=False) + ids_regression = torch.load( + self.data_path / "regression_contributing_gaussian_ids.pt", + weights_only=False, + ) weights_regression = torch.load( - self.data_path / "regression_contributing_gaussian_weights.pt", weights_only=False + self.data_path / "regression_contributing_gaussian_weights.pt", + weights_only=False, ) self.assertTrue(ids == ids_regression) @@ -2129,19 +2278,32 @@ def test_gaussian_contributors_scene_sparse_render(self): 10000.0, ) ) - self.assertTrue(torch.equal(num_contributing_gaussians.jdata, prev_num_contributing_gaussians.jdata)) + self.assertTrue( + torch.equal( + num_contributing_gaussians.jdata, + prev_num_contributing_gaussians.jdata, + ) + ) prev_num_contributing_gaussians = num_contributing_gaussians # load the regression data num_contributing_gaussians_regression = torch.load( - self.data_path / "regression_num_contributing_gaussians.pt", weights_only=True + self.data_path / "regression_num_contributing_gaussians.pt", + weights_only=True, ) alphas_regression = torch.load( - self.data_path / "regression_num_contributing_gaussians_alphas.pt", weights_only=True + self.data_path / "regression_num_contributing_gaussians_alphas.pt", + weights_only=True, ) - for pixels, sparse_num_contributing_gaussians, reference_num_contributing_gaussians in zip( - pixels_to_render.unbind(), num_contributing_gaussians.unbind(), num_contributing_gaussians_regression + for ( + pixels, + sparse_num_contributing_gaussians, + reference_num_contributing_gaussians, + ) in zip( + pixels_to_render.unbind(), + num_contributing_gaussians.unbind(), + num_contributing_gaussians_regression, ): assert isinstance(pixels, torch.Tensor) assert isinstance(sparse_num_contributing_gaussians, torch.Tensor) @@ -2150,11 +2312,16 @@ def test_gaussian_contributors_scene_sparse_render(self): # Index reference_num_contributing_gaussians using the coordinates selected_reference_num_contributing_gaussians = reference_num_contributing_gaussians[y_coords, x_coords] self.assertTrue( - torch.equal(sparse_num_contributing_gaussians, selected_reference_num_contributing_gaussians) + torch.equal( + sparse_num_contributing_gaussians, + selected_reference_num_contributing_gaussians, + ) ) for pixels, sparse_alphas, reference_alphas in zip( - pixels_to_render.unbind(), num_contributing_gaussians_alphas.unbind(), alphas_regression + pixels_to_render.unbind(), + num_contributing_gaussians_alphas.unbind(), + alphas_regression, ): assert isinstance(pixels, torch.Tensor) assert isinstance(sparse_alphas, torch.Tensor) @@ -2176,9 +2343,13 @@ def test_gaussian_contributors_scene_sparse_render(self): ) # load the regression data - ids_regression = torch.load(self.data_path / "regression_contributing_gaussian_ids.pt", weights_only=False) + ids_regression = torch.load( + self.data_path / "regression_contributing_gaussian_ids.pt", + weights_only=False, + ) weights_regression = torch.load( - self.data_path / "regression_contributing_gaussian_weights.pt", weights_only=False + self.data_path / "regression_contributing_gaussian_weights.pt", + weights_only=False, ) for pixels, image_sparse_ids, image_reference_ids in zip(pixels_to_render.unbind(), sparse_ids, ids_regression): @@ -2243,14 +2414,22 @@ def test_gaussian_contributors_scene_dense_pixels_sparse_render(self): # load the regression data num_contributing_gaussians_regression = torch.load( - self.data_path / "regression_num_contributing_gaussians.pt", weights_only=True + self.data_path / "regression_num_contributing_gaussians.pt", + weights_only=True, ) alphas_regression = torch.load( - self.data_path / "regression_num_contributing_gaussians_alphas.pt", weights_only=True + self.data_path / "regression_num_contributing_gaussians_alphas.pt", + weights_only=True, ) - for pixels, sparse_num_contributing_gaussians, reference_num_contributing_gaussians in zip( - pixels_to_render.unbind(), sparse_num_contributing_gaussians.unbind(), num_contributing_gaussians_regression + for ( + pixels, + sparse_num_contributing_gaussians, + reference_num_contributing_gaussians, + ) in zip( + pixels_to_render.unbind(), + sparse_num_contributing_gaussians.unbind(), + num_contributing_gaussians_regression, ): assert isinstance(pixels, torch.Tensor) assert isinstance(sparse_num_contributing_gaussians, torch.Tensor) @@ -2259,7 +2438,10 @@ def test_gaussian_contributors_scene_dense_pixels_sparse_render(self): # Index reference_num_contributing_gaussians using the coordinates selected_reference_num_contributing_gaussians = reference_num_contributing_gaussians[y_coords, x_coords] self.assertTrue( - torch.equal(sparse_num_contributing_gaussians, selected_reference_num_contributing_gaussians) + torch.equal( + sparse_num_contributing_gaussians, + selected_reference_num_contributing_gaussians, + ) ) for pixels, sparse_alphas, reference_alphas in zip( @@ -2285,9 +2467,13 @@ def test_gaussian_contributors_scene_dense_pixels_sparse_render(self): ) # load the regression data - ids_regression = torch.load(self.data_path / "regression_contributing_gaussian_ids.pt", weights_only=False) + ids_regression = torch.load( + self.data_path / "regression_contributing_gaussian_ids.pt", + weights_only=False, + ) weights_regression = torch.load( - self.data_path / "regression_contributing_gaussian_weights.pt", weights_only=False + self.data_path / "regression_contributing_gaussian_weights.pt", + weights_only=False, ) for pixels, image_sparse_ids, image_reference_ids in zip(pixels_to_render, sparse_ids, ids_regression): @@ -2436,7 +2622,12 @@ def test_gaussian_render_sparse_depth_backward(self): "Sparse log scales grad does not match dense log scales grad at specified pixels", ) self.assertTrue( - torch.allclose(sparse_logit_opacities_grad, dense_logit_opacities_grad, atol=1e-4, rtol=1e-8), + torch.allclose( + sparse_logit_opacities_grad, + dense_logit_opacities_grad, + atol=1e-4, + rtol=1e-8, + ), "Sparse logit opacities grad does not match dense logit opacities grad at specified pixels", ) @@ -2556,7 +2747,12 @@ def test_gaussian_render_sparse_features_backward(self): "Sparse log scales grad does not match dense log scales grad at specified pixels", ) self.assertTrue( - torch.allclose(sparse_logit_opacities_grad, dense_logit_opacities_grad, atol=1e-4, rtol=1e-8), + torch.allclose( + sparse_logit_opacities_grad, + dense_logit_opacities_grad, + atol=1e-4, + rtol=1e-8, + ), "Sparse logit opacities grad does not match dense logit opacities grad at specified pixels", ) self.assertTrue( @@ -2684,7 +2880,12 @@ def test_gaussian_render_sparse_features_and_depths_backward(self): "Sparse log scales grad does not match dense log scales grad at specified pixels", ) self.assertTrue( - torch.allclose(sparse_logit_opacities_grad, dense_logit_opacities_grad, atol=1e-4, rtol=1e-8), + torch.allclose( + sparse_logit_opacities_grad, + dense_logit_opacities_grad, + atol=1e-4, + rtol=1e-8, + ), "Sparse logit opacities grad does not match dense logit opacities grad at specified pixels", ) self.assertTrue( @@ -2764,12 +2965,23 @@ def test_render_with_different_backgrounds_per_camera(self): # Render without background colors_no_bg, alphas_no_bg = self.gs3d.render_images( - cam_mats, proj_mats, self.width, self.height, self.near_plane, self.far_plane + cam_mats, + proj_mats, + self.width, + self.height, + self.near_plane, + self.far_plane, ) # Render with different backgrounds colors_with_bg, alphas_with_bg = self.gs3d.render_images( - cam_mats, proj_mats, self.width, self.height, self.near_plane, self.far_plane, backgrounds=backgrounds + cam_mats, + proj_mats, + self.width, + self.height, + self.near_plane, + self.far_plane, + backgrounds=backgrounds, ) # Alphas should be identical @@ -2869,7 +3081,13 @@ def test_gradients_flow_with_backgrounds(self): # Render with background and compute loss colors, alphas = self.gs3d.render_images( - cam_mats, proj_mats, self.width, self.height, self.near_plane, self.far_plane, backgrounds=backgrounds + cam_mats, + proj_mats, + self.width, + self.height, + self.near_plane, + self.far_plane, + backgrounds=backgrounds, ) loss = colors.sum() @@ -2923,18 +3141,27 @@ def test_render_from_projected_gaussians_with_backgrounds(self): num_channels = 3 # Project gaussians - projected = self.gs3d.project_gaussians_for_images( - cam_mats, proj_mats, self.width, self.height, self.near_plane, self.far_plane + projected = self.gs3d.project_gaussians( + cam_mats, + proj_mats, + self.width, + self.height, + self.near_plane, + self.far_plane, ) # Create backgrounds backgrounds = torch.full((num_cameras, num_channels), 0.7, device=self.device, dtype=torch.float32) # Render without background - colors_no_bg, alphas_no_bg = self.gs3d.render_from_projected_gaussians(projected) + colors_no_bg, alphas_no_bg = self.gs3d.render_from_projected_gaussians( + projected, cam_mats, render_mode=GaussianRenderMode.FEATURES + ) # Render with background - colors_with_bg, alphas_with_bg = self.gs3d.render_from_projected_gaussians(projected, backgrounds=backgrounds) + colors_with_bg, alphas_with_bg = self.gs3d.render_from_projected_gaussians( + projected, cam_mats, render_mode=GaussianRenderMode.FEATURES, backgrounds=backgrounds + ) # Alphas should be identical self.assertTrue(torch.allclose(alphas_no_bg, alphas_with_bg)) @@ -3483,6 +3710,124 @@ def test_view_directions_not_prenormalized(self): self.assertFalse(torch.isnan(result).any()) self.assertFalse(torch.isinf(result).any()) + def test_unnormalized_matches_normalized(self): + """Regression: kernel normalizes internally, so scaled view_dirs must give identical results.""" + N = 50 + D = 3 + C = 2 + + torch.manual_seed(123) + sh0 = torch.randn(N, 1, D, device=self.device) + shN = torch.randn(N, 15, D, device=self.device) + radii = torch.ones(C, N, dtype=torch.int32, device=self.device) + + view_dirs = torch.randn(C, N, 3, device=self.device) + view_dirs_normalized = torch.nn.functional.normalize(view_dirs, dim=-1) + view_dirs_scaled = view_dirs * 100.0 + + result_norm = evaluate_spherical_harmonics( + sh_degree=3, num_cameras=C, sh0=sh0, radii=radii, shN=shN, view_directions=view_dirs_normalized + ) + result_unnorm = evaluate_spherical_harmonics( + sh_degree=3, num_cameras=C, sh0=sh0, radii=radii, shN=shN, view_directions=view_dirs + ) + result_scaled = evaluate_spherical_harmonics( + sh_degree=3, num_cameras=C, sh0=sh0, radii=radii, shN=shN, view_directions=view_dirs_scaled + ) + + self.assertTrue(torch.allclose(result_norm, result_unnorm, atol=1e-5)) + self.assertTrue(torch.allclose(result_norm, result_scaled, atol=1e-5)) + + def test_functional_evaluate_gaussian_sh_matches_standalone(self): + """Regression: evaluate_gaussian_sh (functional API) produces correct view-dependent SH.""" + from fvdb.functional import evaluate_gaussian_sh + from fvdb.functional._gaussian_projection import ProjectedGaussians + from fvdb.enums import CameraModel, ProjectionMethod + + N = 30 + D = 3 + C = 2 + + torch.manual_seed(42) + means = torch.randn(N, 3, device=self.device) + sh0 = torch.randn(N, 1, D, device=self.device, requires_grad=True) + shN = torch.randn(N, 8, D, device=self.device, requires_grad=True) + + w2c = torch.eye(4, device=self.device).unsqueeze(0).expand(C, -1, -1).contiguous() + w2c = w2c.clone() + w2c[0, 0, 3] = 1.0 + w2c[1, 1, 3] = -1.0 + + radii = torch.ones(C, N, dtype=torch.int32, device=self.device) + depths = torch.ones(C, N, device=self.device) + + projected = ProjectedGaussians( + means2d=torch.zeros(C, N, 2, device=self.device), + conics=torch.zeros(C, N, 3, device=self.device), + compensations=torch.zeros(C, N, device=self.device), + radii=radii, + depths=depths, + image_width=64, + image_height=64, + camera_model=CameraModel.PINHOLE, + projection_method=ProjectionMethod.ANALYTIC, + ) + + result = evaluate_gaussian_sh(means, sh0, shN, w2c, projected, sh_degree_to_use=2) + + self.assertEqual(result.shape, (C, N, D)) + self.assertFalse(torch.isnan(result).any()) + self.assertFalse(torch.allclose(result[0], result[1], atol=1e-3)) + + loss = result.sum() + loss.backward() + self.assertIsNotNone(sh0.grad) + self.assertIsNotNone(shN.grad) + self.assertTrue(torch.any(sh0.grad != 0)) + self.assertTrue(torch.any(shN.grad != 0)) + + def test_functional_evaluate_gaussian_sh_degree0_empty_view_dirs(self): + """Regression: degree-0 path passes an empty tensor for view_dirs (no allocation).""" + from fvdb.functional import evaluate_gaussian_sh + from fvdb.functional._gaussian_projection import ProjectedGaussians + from fvdb.enums import CameraModel, ProjectionMethod + + N = 20 + D = 3 + C = 2 + + torch.manual_seed(99) + means = torch.randn(N, 3, device=self.device) + sh0 = torch.randn(N, 1, D, device=self.device, requires_grad=True) + shN = torch.randn(N, 8, D, device=self.device) + + w2c = torch.eye(4, device=self.device).unsqueeze(0).expand(C, -1, -1).contiguous() + radii = torch.ones(C, N, dtype=torch.int32, device=self.device) + depths = torch.ones(C, N, device=self.device) + + projected = ProjectedGaussians( + means2d=torch.zeros(C, N, 2, device=self.device), + conics=torch.zeros(C, N, 3, device=self.device), + compensations=torch.zeros(C, N, device=self.device), + radii=radii, + depths=depths, + image_width=64, + image_height=64, + camera_model=CameraModel.PINHOLE, + projection_method=ProjectionMethod.ANALYTIC, + ) + + result = evaluate_gaussian_sh(means, sh0, shN, w2c, projected, sh_degree_to_use=0) + + self.assertEqual(result.shape, (C, N, D)) + self.assertFalse(torch.isnan(result).any()) + self.assertTrue(torch.allclose(result[0], result[1], atol=1e-6)) + + loss = result.sum() + loss.backward() + self.assertIsNotNone(sh0.grad) + self.assertTrue(torch.any(sh0.grad != 0)) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") class TestGaussianRenderMasks(BaseGaussianTestCase): @@ -3504,10 +3849,18 @@ def _all_zeros_pixel_mask(self, C): return torch.zeros((C, self.height, self.width), device=self.device, dtype=torch.bool) def _all_ones_tile_mask(self, C): - return torch.ones((C, self.num_tiles_h, self.num_tiles_w), device=self.device, dtype=torch.bool) + return torch.ones( + (C, self.num_tiles_h, self.num_tiles_w), + device=self.device, + dtype=torch.bool, + ) def _all_zeros_tile_mask(self, C): - return torch.zeros((C, self.num_tiles_h, self.num_tiles_w), device=self.device, dtype=torch.bool) + return torch.zeros( + (C, self.num_tiles_h, self.num_tiles_w), + device=self.device, + dtype=torch.bool, + ) # -- Dense: render_images ------------------------------------------------ @@ -3517,7 +3870,13 @@ def test_render_images_all_ones_mask_matches_no_mask(self): proj = self.projection_mats[:C] ref, ref_a = self.gs3d.render_images( - cam, proj, self.width, self.height, self.near_plane, self.far_plane, tile_size=self.tile_size + cam, + proj, + self.width, + self.height, + self.near_plane, + self.far_plane, + tile_size=self.tile_size, ) out, out_a = self.gs3d.render_images( cam, @@ -3574,6 +3933,7 @@ def test_render_images_backward_with_masks(self): loss = out.sum() + out_a.sum() loss.backward() self.assertIsNotNone(self.gs3d.means.grad) + assert self.gs3d.means.grad is not None self.assertGreater(torch.abs(self.gs3d.means.grad).sum().item(), 0) def test_render_images_all_zeros_mask_zero_grads(self): @@ -3596,6 +3956,7 @@ def test_render_images_all_zeros_mask_zero_grads(self): loss = out.sum() + out_a.sum() loss.backward() self.assertIsNotNone(self.gs3d.means.grad) + assert self.gs3d.means.grad is not None self.assertTrue(torch.equal(self.gs3d.means.grad, torch.zeros_like(self.gs3d.means.grad))) # -- Dense: render_depths ------------------------------------------------ @@ -3653,18 +4014,22 @@ def test_render_from_projected_gaussians_with_masks(self): D = 3 bg = torch.tensor([[0.7, 0.7, 0.7]], device=self.device, dtype=torch.float32) - projected = self.gs3d.project_gaussians_for_images( - cam, proj, self.width, self.height, self.near_plane, self.far_plane + projected = self.gs3d.project_gaussians(cam, proj, self.width, self.height, self.near_plane, self.far_plane) + ref, ref_a = self.gs3d.render_from_projected_gaussians( + projected, cam, render_mode=GaussianRenderMode.FEATURES, backgrounds=bg ) - ref, ref_a = self.gs3d.render_from_projected_gaussians(projected, backgrounds=bg) out, out_a = self.gs3d.render_from_projected_gaussians( - projected, backgrounds=bg, masks=self._all_ones_pixel_mask(C) + projected, cam, render_mode=GaussianRenderMode.FEATURES, backgrounds=bg, masks=self._all_ones_pixel_mask(C) ) self.assertTrue(torch.allclose(ref, out, atol=1e-5)) self.assertTrue(torch.allclose(ref_a, out_a, atol=1e-5)) out_z, out_z_a = self.gs3d.render_from_projected_gaussians( - projected, backgrounds=bg, masks=self._all_zeros_pixel_mask(C) + projected, + cam, + render_mode=GaussianRenderMode.FEATURES, + backgrounds=bg, + masks=self._all_zeros_pixel_mask(C), ) expected = bg.view(C, 1, 1, D).expand(C, self.height, self.width, D) self.assertTrue(torch.equal(out_z_a, torch.zeros_like(out_z_a))) @@ -3733,6 +4098,7 @@ def test_sparse_render_images_backward_with_backgrounds_and_masks(self): loss = out.jdata.sum() + out_a.jdata.sum() loss.backward() self.assertIsNotNone(self.gs3d.means.grad) + assert self.gs3d.means.grad is not None self.assertGreater(torch.abs(self.gs3d.means.grad).sum().item(), 0) # -- Sparse: sparse_render_depths ---------------------------------------- @@ -4000,6 +4366,9 @@ def test_sparse_render_depth_backward_with_duplicates(self): l1.backward() assert self.gs3d.means.grad is not None + assert self.gs3d.quats.grad is not None + assert self.gs3d.log_scales.grad is not None + assert self.gs3d.logit_opacities.grad is not None sparse_means_grad = self.gs3d.means.grad.clone() sparse_quats_grad = self.gs3d.quats.grad.clone() sparse_log_scales_grad = self.gs3d.log_scales.grad.clone() @@ -4026,6 +4395,10 @@ def test_sparse_render_depth_backward_with_duplicates(self): l2 = torch.mean(dense_depth_pixels) + dense_alphas_pixels.sum() l2.backward() + assert self.gs3d.means.grad is not None + assert self.gs3d.quats.grad is not None + assert self.gs3d.log_scales.grad is not None + assert self.gs3d.logit_opacities.grad is not None dense_means_grad = self.gs3d.means.grad.clone() dense_quats_grad = self.gs3d.quats.grad.clone() dense_log_scales_grad = self.gs3d.log_scales.grad.clone() @@ -4044,7 +4417,12 @@ def test_sparse_render_depth_backward_with_duplicates(self): "Sparse log_scales grad with duplicates does not match dense", ) self.assertTrue( - torch.allclose(sparse_logit_opacities_grad, dense_logit_opacities_grad, atol=1e-4, rtol=1e-8), + torch.allclose( + sparse_logit_opacities_grad, + dense_logit_opacities_grad, + atol=1e-4, + rtol=1e-8, + ), "Sparse logit_opacities grad with duplicates does not match dense", ) @@ -4099,7 +4477,10 @@ def test_sparse_render_contributing_ids_with_duplicates(self): mask = inverse == i id_vals = ids.jdata[mask] weight_vals = weights.jdata[mask] - self.assertTrue(torch.all(id_vals == id_vals[0:1]), "Duplicate pixels have different contributing IDs") + self.assertTrue( + torch.all(id_vals == id_vals[0:1]), + "Duplicate pixels have different contributing IDs", + ) self.assertTrue( torch.allclose(weight_vals, weight_vals[0:1], atol=1e-6, rtol=1e-8), "Duplicate pixels have different contributing weights", @@ -4169,13 +4550,29 @@ def setUp(self): self.device = "cuda:0" self.dtype = torch.float32 - means = torch.tensor([[0.18, -0.12, 2.8], [-0.08, 0.10, 3.4]], device=self.device, dtype=self.dtype) - quats = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], device=self.device, dtype=self.dtype) + means = torch.tensor( + [[0.18, -0.12, 2.8], [-0.08, 0.10, 3.4]], + device=self.device, + dtype=self.dtype, + ) + quats = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], + device=self.device, + dtype=self.dtype, + ) log_scales = torch.log( - torch.tensor([[0.06, 0.05, 0.04], [0.05, 0.07, 0.06]], device=self.device, dtype=self.dtype) + torch.tensor( + [[0.06, 0.05, 0.04], [0.05, 0.07, 0.06]], + device=self.device, + dtype=self.dtype, + ) ) logit_opacities = torch.tensor([2.2, 1.8], device=self.device, dtype=self.dtype) - sh0 = torch.tensor([[[0.7, 0.1, -0.2]], [[-0.3, 0.5, 0.4]]], device=self.device, dtype=self.dtype) + sh0 = torch.tensor( + [[[0.7, 0.1, -0.2]], [[-0.3, 0.5, 0.4]]], + device=self.device, + dtype=self.dtype, + ) shN = torch.empty((2, 0, 3), device=self.device, dtype=self.dtype) self.gs3d = GaussianSplat3d.from_tensors( @@ -4310,7 +4707,10 @@ def _make_structural_comparison_splat(self) -> GaussianSplat3d: dtype=self.dtype, ) ) - logit_opacities = torch.tensor([2.1, 1.8, 1.6, 2.0, 1.7, 1.9, 1.5, 2.2, 1.8, 1.7, 1.6, 1.9], device=self.device) + logit_opacities = torch.tensor( + [2.1, 1.8, 1.6, 2.0, 1.7, 1.9, 1.5, 2.2, 1.8, 1.7, 1.6, 1.9], + device=self.device, + ) sh0 = torch.tensor( [ [[0.70, -0.05, -0.20]], @@ -4402,7 +4802,10 @@ def _blurred_rgb_rmse(self, rgb_a: torch.Tensor, rgb_b: torch.Tensor, union_mask pooled_a = nnf.avg_pool2d(rgb_a.permute(2, 0, 1).unsqueeze(0), kernel_size=5, stride=1, padding=2) pooled_b = nnf.avg_pool2d(rgb_b.permute(2, 0, 1).unsqueeze(0), kernel_size=5, stride=1, padding=2) pooled_mask = nnf.avg_pool2d( - union_mask.to(dtype=rgb_a.dtype).unsqueeze(0).unsqueeze(0), kernel_size=5, stride=1, padding=2 + union_mask.to(dtype=rgb_a.dtype).unsqueeze(0).unsqueeze(0), + kernel_size=5, + stride=1, + padding=2, )[0, 0] if not bool((pooled_mask > 1.0e-3).any()): return 0.0 @@ -4442,7 +4845,11 @@ def _structural_comparison_metrics( world_depth_mass = world_rgbd[0, ..., -1].sum() projected_depth_mean = projected_depth_mass / torch.clamp(projected_alpha_2d.sum(), min=1.0e-8) world_depth_mean = world_depth_mass / torch.clamp(world_alpha_2d.sum(), min=1.0e-8) - depth_scale = max(abs(float(projected_depth_mean.item())), abs(float(world_depth_mean.item())), 1.0e-8) + depth_scale = max( + abs(float(projected_depth_mean.item())), + abs(float(world_depth_mean.item())), + 1.0e-8, + ) return { "support_iou": float( @@ -4464,22 +4871,49 @@ def _structural_comparison_metrics( def test_projection_method_resolution_and_metadata(self): cases = [ (CameraModel.PINHOLE, ProjectionMethod.AUTO, ProjectionMethod.ANALYTIC), - (CameraModel.ORTHOGRAPHIC, ProjectionMethod.AUTO, ProjectionMethod.ANALYTIC), - (CameraModel.PINHOLE, ProjectionMethod.UNSCENTED, ProjectionMethod.UNSCENTED), - (CameraModel.ORTHOGRAPHIC, ProjectionMethod.UNSCENTED, ProjectionMethod.UNSCENTED), - (CameraModel.OPENCV_RADTAN_5, ProjectionMethod.AUTO, ProjectionMethod.UNSCENTED), - (CameraModel.OPENCV_RATIONAL_8, ProjectionMethod.AUTO, ProjectionMethod.UNSCENTED), - (CameraModel.OPENCV_RADTAN_THIN_PRISM_9, ProjectionMethod.AUTO, ProjectionMethod.UNSCENTED), - (CameraModel.OPENCV_THIN_PRISM_12, ProjectionMethod.AUTO, ProjectionMethod.UNSCENTED), + ( + CameraModel.ORTHOGRAPHIC, + ProjectionMethod.AUTO, + ProjectionMethod.ANALYTIC, + ), + ( + CameraModel.PINHOLE, + ProjectionMethod.UNSCENTED, + ProjectionMethod.UNSCENTED, + ), + ( + CameraModel.ORTHOGRAPHIC, + ProjectionMethod.UNSCENTED, + ProjectionMethod.UNSCENTED, + ), + ( + CameraModel.OPENCV_RADTAN_5, + ProjectionMethod.AUTO, + ProjectionMethod.UNSCENTED, + ), + ( + CameraModel.OPENCV_RATIONAL_8, + ProjectionMethod.AUTO, + ProjectionMethod.UNSCENTED, + ), + ( + CameraModel.OPENCV_RADTAN_THIN_PRISM_9, + ProjectionMethod.AUTO, + ProjectionMethod.UNSCENTED, + ), + ( + CameraModel.OPENCV_THIN_PRISM_12, + ProjectionMethod.AUTO, + ProjectionMethod.UNSCENTED, + ), ] for camera_model, requested_method, expected_method in cases: with self.subTest(camera_model=camera_model, requested_method=requested_method): render_args = self._render_args(camera_model) project_args = self._with_overrides(render_args, projection_method=requested_method) - projected = self.gs3d.project_gaussians_for_images( + projected = self.gs3d.project_gaussians( **project_args, - sh_degree_to_use=0, ) self.assertEqual(projected.camera_model, camera_model) self.assertEqual(projected.projection_method, expected_method) @@ -4489,17 +4923,16 @@ def test_camera_api_validation_errors(self): opencv_args = self._render_args(CameraModel.OPENCV_RADTAN_5) with self.assertRaisesRegex(RuntimeError, "distortionCoeffs must be provided"): - self.gs3d.project_gaussians_for_images( + self.gs3d.project_gaussians( **self._with_overrides(opencv_args, distortion_coeffs=None), - sh_degree_to_use=0, ) with self.assertRaisesRegex(RuntimeError, "distortionCoeffs must have shape"): - self.gs3d.project_gaussians_for_images( + self.gs3d.project_gaussians( **self._with_overrides( - opencv_args, distortion_coeffs=opencv_args["distortion_coeffs"][:, :5].contiguous() + opencv_args, + distortion_coeffs=opencv_args["distortion_coeffs"][:, :5].contiguous(), ), - sh_degree_to_use=0, ) with self.assertRaisesRegex(RuntimeError, "ProjectionMethod::UNSCENTED or AUTO"): @@ -4509,17 +4942,31 @@ def test_camera_api_validation_errors(self): ) with self.assertRaisesRegex(RuntimeError, "projectionMatrices must be contiguous"): - self.gs3d.project_gaussians_for_images( + self.gs3d.project_gaussians( **self._with_overrides( pinhole_args, projection_matrices=pinhole_args["projection_matrices"].transpose(1, 2), ), - sh_degree_to_use=0, ) def test_pinhole_and_orthographic_ignore_distortion_coeffs_tensor(self): ignored_distortion = torch.tensor( - [[0.12, -0.03, 0.01, 0.0, 0.0, 0.0, 0.02, -0.015, 0.004, -0.003, 0.002, -0.001]], + [ + [ + 0.12, + -0.03, + 0.01, + 0.0, + 0.0, + 0.0, + 0.02, + -0.015, + 0.004, + -0.003, + 0.002, + -0.001, + ] + ], device=self.device, dtype=self.dtype, ) @@ -4530,10 +4977,10 @@ def test_pinhole_and_orthographic_ignore_distortion_coeffs_tensor(self): render_args = self._render_args(camera_model) ignored_args = self._with_overrides(render_args, distortion_coeffs=ignored_distortion) - projected_default = parity_gs3d.project_gaussians_for_images(**render_args, sh_degree_to_use=0) - projected_ignored = parity_gs3d.project_gaussians_for_images(**ignored_args, sh_degree_to_use=0) + projected_default = parity_gs3d.project_gaussians(**render_args) + projected_ignored = parity_gs3d.project_gaussians(**ignored_args) torch.testing.assert_close(projected_default.means2d, projected_ignored.means2d) - torch.testing.assert_close(projected_default.inv_covar_2d, projected_ignored.inv_covar_2d) + torch.testing.assert_close(projected_default.conics, projected_ignored.conics) images_default, alpha_default = parity_gs3d.render_images(**render_args, sh_degree_to_use=0) images_ignored, alpha_ignored = parity_gs3d.render_images(**ignored_args, sh_degree_to_use=0) @@ -4584,33 +5031,42 @@ def test_pinhole_and_orthographic_ignore_distortion_coeffs_tensor(self): torch.testing.assert_close(world_rgbd_alpha_default, world_rgbd_alpha_ignored) def test_projected_render_matches_from_world_for_stable_scene(self): - for camera_model in (CameraModel.PINHOLE, CameraModel.ORTHOGRAPHIC, CameraModel.OPENCV_RADTAN_5): + for camera_model in ( + CameraModel.PINHOLE, + CameraModel.ORTHOGRAPHIC, + CameraModel.OPENCV_RADTAN_5, + ): with self.subTest(camera_model=camera_model): parity_gs3d = self._make_tiny_parity_splat() render_args = self._render_args(camera_model) - projected_images = parity_gs3d.project_gaussians_for_images(**render_args, sh_degree_to_use=0) + projected_images = parity_gs3d.project_gaussians(**render_args) images_from_projection, alpha_from_projection = parity_gs3d.render_from_projected_gaussians( - projected_images + projected_images, + render_args["world_to_camera_matrices"], + render_mode=GaussianRenderMode.FEATURES, + sh_degree_to_use=0, ) images_from_dense, alpha_from_dense = parity_gs3d.render_images(**render_args, sh_degree_to_use=0) images_from_world, alpha_from_world = parity_gs3d.render_images_from_world( **render_args, sh_degree_to_use=0 ) - projected_depths = parity_gs3d.project_gaussians_for_depths(**render_args) + projected_depths = parity_gs3d.project_gaussians(**render_args) depths_from_projection, depth_alpha_from_projection = parity_gs3d.render_from_projected_gaussians( - projected_depths + projected_depths, + render_args["world_to_camera_matrices"], + render_mode=GaussianRenderMode.DEPTH, ) depths_from_dense, depth_alpha_from_dense = parity_gs3d.render_depths(**render_args) depths_from_world, depth_alpha_from_world = parity_gs3d.render_depths_from_world(**render_args) - projected_rgbd = parity_gs3d.project_gaussians_for_images_and_depths( - **render_args, - sh_degree_to_use=0, - ) + projected_rgbd = parity_gs3d.project_gaussians(**render_args) rgbd_from_projection, rgbd_alpha_from_projection = parity_gs3d.render_from_projected_gaussians( - projected_rgbd + projected_rgbd, + render_args["world_to_camera_matrices"], + render_mode=GaussianRenderMode.FEATURES_AND_DEPTH, + sh_degree_to_use=0, ) rgbd_from_dense, rgbd_alpha_from_dense = parity_gs3d.render_images_and_depths( **render_args, @@ -4624,9 +5080,19 @@ def test_projected_render_matches_from_world_for_stable_scene(self): torch.testing.assert_close(images_from_projection, images_from_dense, atol=1e-6, rtol=1e-6) torch.testing.assert_close(alpha_from_projection, alpha_from_dense, atol=1e-6, rtol=1e-6) torch.testing.assert_close(depths_from_projection, depths_from_dense, atol=1e-6, rtol=1e-6) - torch.testing.assert_close(depth_alpha_from_projection, depth_alpha_from_dense, atol=1e-6, rtol=1e-6) + torch.testing.assert_close( + depth_alpha_from_projection, + depth_alpha_from_dense, + atol=1e-6, + rtol=1e-6, + ) torch.testing.assert_close(rgbd_from_projection, rgbd_from_dense, atol=1e-6, rtol=1e-6) - torch.testing.assert_close(rgbd_alpha_from_projection, rgbd_alpha_from_dense, atol=1e-6, rtol=1e-6) + torch.testing.assert_close( + rgbd_alpha_from_projection, + rgbd_alpha_from_dense, + atol=1e-6, + rtol=1e-6, + ) torch.testing.assert_close(alpha_from_world, depth_alpha_from_world, atol=1e-5, rtol=1e-5) torch.testing.assert_close(alpha_from_world, rgbd_alpha_from_world, atol=1e-5, rtol=1e-5) @@ -4639,17 +5105,21 @@ def test_projected_render_matches_from_world_for_stable_scene(self): def test_structural_projected_render_matches_from_world_for_medium_scene(self): # The two rasterization paths diverge too much for stable pixelwise parity on richer scenes, # so this test checks that they preserve the same overall support, location, depth, and appearance. - for camera_model in (CameraModel.PINHOLE, CameraModel.ORTHOGRAPHIC, CameraModel.OPENCV_RADTAN_5): + for camera_model in ( + CameraModel.PINHOLE, + CameraModel.ORTHOGRAPHIC, + CameraModel.OPENCV_RADTAN_5, + ): with self.subTest(camera_model=camera_model): parity_gs3d = self._make_structural_comparison_splat() render_args = self._render_args(camera_model) - projected_rgbd = parity_gs3d.project_gaussians_for_images_and_depths( - **render_args, - sh_degree_to_use=0, - ) + projected_rgbd = parity_gs3d.project_gaussians(**render_args) rgbd_from_projection, alpha_from_projection = parity_gs3d.render_from_projected_gaussians( - projected_rgbd + projected_rgbd, + render_args["world_to_camera_matrices"], + render_mode=GaussianRenderMode.FEATURES_AND_DEPTH, + sh_degree_to_use=0, ) rgbd_from_world, alpha_from_world = parity_gs3d.render_images_and_depths_from_world( **render_args, @@ -4674,7 +5144,12 @@ def test_sparse_render_camera_args_match_dense_render(self): pixels = self._all_pixels(C=2) sparse_cases = [ - ("images", self.gs3d.render_images, self.gs3d.sparse_render_images, {"sh_degree_to_use": 0}), + ( + "images", + self.gs3d.render_images, + self.gs3d.sparse_render_images, + {"sh_degree_to_use": 0}, + ), ("depths", self.gs3d.render_depths, self.gs3d.sparse_render_depths, {}), ( "rgbd", @@ -4684,7 +5159,11 @@ def test_sparse_render_camera_args_match_dense_render(self): ), ] - for camera_model in (CameraModel.PINHOLE, CameraModel.ORTHOGRAPHIC, CameraModel.OPENCV_RADTAN_5): + for camera_model in ( + CameraModel.PINHOLE, + CameraModel.ORTHOGRAPHIC, + CameraModel.OPENCV_RADTAN_5, + ): render_args = self._render_args(camera_model, C=2) for name, dense_fn, sparse_fn, extra_kwargs in sparse_cases: with self.subTest(camera_model=camera_model, render_mode=name): @@ -4697,7 +5176,9 @@ def test_sparse_render_camera_args_match_dense_render(self): self._assert_sparse_matches_dense(dense_values, sparse_values, pixels) self._assert_sparse_matches_dense(dense_alphas, sparse_alphas, pixels) - def test_batched_opencv_render_uses_per_camera_intrinsics_distortion_backgrounds_and_masks(self): + def test_batched_opencv_render_uses_per_camera_intrinsics_distortion_backgrounds_and_masks( + self, + ): C = 2 render_args = self._render_args(CameraModel.OPENCV_RADTAN_5, C=C) backgrounds = torch.tensor([[0.1, -0.2, 0.3], [-0.4, 0.2, 0.1]], device=self.device, dtype=self.dtype) @@ -4726,8 +5207,18 @@ def test_batched_opencv_render_uses_per_camera_intrinsics_distortion_backgrounds backgrounds=backgrounds[cam_idx : cam_idx + 1].contiguous(), masks=masks[cam_idx : cam_idx + 1].contiguous(), ) - torch.testing.assert_close(batched_features[cam_idx : cam_idx + 1], single_features, atol=1e-5, rtol=1e-5) - torch.testing.assert_close(batched_alphas[cam_idx : cam_idx + 1], single_alphas, atol=1e-5, rtol=1e-5) + torch.testing.assert_close( + batched_features[cam_idx : cam_idx + 1], + single_features, + atol=1e-5, + rtol=1e-5, + ) + torch.testing.assert_close( + batched_alphas[cam_idx : cam_idx + 1], + single_alphas, + atol=1e-5, + rtol=1e-5, + ) class TestProjectionGradsMultiCamera(unittest.TestCase): @@ -4802,7 +5293,17 @@ def _make_test_data(self): K = torch.tensor([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], device=device) Ks = K.unsqueeze(0).expand(self.C, -1, -1).contiguous() - return means, quats, log_scales, logit_opacities, sh0, shN, sh_coeffs, viewmats, Ks + return ( + means, + quats, + log_scales, + logit_opacities, + sh0, + shN, + sh_coeffs, + viewmats, + Ks, + ) def _build_gs3d(self, means, quats, log_scales, logit_opacities, sh0, shN): gs3d = GaussianSplat3d.from_tensors( @@ -4818,7 +5319,17 @@ def _build_gs3d(self, means, quats, log_scales, logit_opacities, sh0, shN): def test_dense_projection_grads_multicamera(self): """Dense path: GaussianProjectionBackward.cu -- all parameter gradients.""" - means, quats, log_scales, logit_opacities, sh0, shN, _sh_coeffs, viewmats, Ks = self._make_test_data() + ( + means, + quats, + log_scales, + logit_opacities, + sh0, + shN, + _sh_coeffs, + viewmats, + Ks, + ) = self._make_test_data() gs3d = self._build_gs3d(means, quats, log_scales, logit_opacities, sh0, shN) images, _ = gs3d.render_images(viewmats, Ks, self.W, self.H, near=0.01, far=1e10) @@ -4849,7 +5360,17 @@ def test_dense_projection_grads_multicamera(self): def test_jagged_projection_grads_multicamera(self): """Jagged path: GaussianProjectionJaggedBackward.cu -- all parameter gradients.""" - means, quats, log_scales, logit_opacities, _sh0, _shN, sh_coeffs, viewmats, Ks = self._make_test_data() + ( + means, + quats, + log_scales, + logit_opacities, + _sh0, + _shN, + sh_coeffs, + viewmats, + Ks, + ) = self._make_test_data() scales = torch.exp(log_scales) opacities = torch.sigmoid(logit_opacities) diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 97d78f345..f916d59f2 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -5,7 +5,7 @@ import pytest import torch -from fvdb.utils.metrics import psnr, ssim +from fvdb.functional import psnr, ssim @pytest.mark.parametrize("padding", ["same", "valid"]) # fused-ssim supports these paddings @@ -99,3 +99,22 @@ def test_psnr_input_validation(): c = torch.zeros((1, 8, 8)) with pytest.raises(ValueError): _ = psnr(c, c) + + +def test_ssim_padding_validation(): + """Invalid padding values must raise ValueError (not silently pass via assert).""" + img = torch.rand(1, 1, 32, 32, device="cuda", dtype=torch.float32) + with pytest.raises(ValueError, match="padding"): + ssim(img, img, padding="reflect") + with pytest.raises(ValueError, match="padding"): + ssim(img, img, padding="") + + +def test_metrics_import_from_functional(): + """Ensure metrics can be imported from fvdb.functional without circular import issues.""" + import importlib + + mod = importlib.import_module("fvdb.functional._metrics") + assert hasattr(mod, "ssim") + assert hasattr(mod, "psnr") + assert not hasattr(mod, "fvdb"), "Should not import top-level fvdb package directly"