Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions panoptic_segmentation/mask_pls/data/collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ def __call__(self, data) -> Dict:
class fVDBSemanticSegmentationDatasetCollation:
"""
A data collation class for semantic segmentation datasets using fVDB.
This class handles the conversion of point cloud data into VDBTensors.
This class handles the conversion of point cloud data into sparse grid features.
Parameters
----------
device : torch.device, optional
The device on which to perform computations (default is "cuda:0")
Methods
-------
__call__(data: dict) -> dict:
Transforms input point cloud data into VDBTensors.
Transforms input point cloud data into fVDB grid and feature representations.
Parameters:
data (dict): Dictionary containing:
- xyz: List of point coordinates in world space
Expand All @@ -65,7 +65,8 @@ class fVDBSemanticSegmentationDatasetCollation:
Returns:
dict: Original dictionary updated with:
- xyz: JaggedTensor of point coordinates
- vdbtensor: VDBTensor containing the structured volumetric data
- grid: GridBatch containing the sparse grid topology
- features: JaggedTensor containing the per-voxel features
"""

def __init__(self, device=torch.device("cuda:0")):
Expand All @@ -76,21 +77,18 @@ def __call__(self, data):
# xyz world space point positions
data["xyz"] = fvdb.JaggedTensor([torch.tensor(c, device=self.device) for c in data["xyz"]])

grid = fvdb.gridbatch_from_points(data["xyz"], voxel_sizes=[n.tolist() for n in data["voxel_size"]])
grid = fvdb.GridBatch.from_points(data["xyz"], voxel_sizes=[n.tolist() for n in data["voxel_size"]])

# get mapping of the coordinates to the grid for feature mapping
coord_ijks = grid.world_to_grid(data["xyz"]).round().int()
coord_ijks = grid.world_to_voxel(data["xyz"]).round().int()
inv_idx = grid.ijk_to_inv_index(coord_ijks, cumulative=True)

# assert(torch.all(grid.ijk.jdata == coord_ijks.jdata[inv_idx.jdata]))

jfeats = torch.cat([torch.tensor(f, device=self.device).unsqueeze(-1) for f in data["intensity"]])
jfeats = grid.jagged_like(jfeats[inv_idx.jdata])

jfeats = fvdb.jcat([grid.ijk.float(), jfeats], dim=1)

vdbtensor = fvdb.nn.VDBTensor(grid, jfeats)

data["vdbtensor"] = vdbtensor
data["grid"] = grid
data["features"] = jfeats

return data
5 changes: 2 additions & 3 deletions panoptic_segmentation/mask_pls/maskpls_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ channels:
- conda-forge
- nodefaults
dependencies:
- python=3.11
- pytorch-gpu=2.4.1[build=cuda120*]
- python=3.12
- pytorch-gpu=2.8.0[build=cuda129*]
- pip
- git
- gitpython
- ipython
- tqdm
- numpy<2
- tyro
- scikit-learn
- py-opencv
Expand Down
Loading