diff --git a/panoptic_segmentation/mask_pls/data/collation.py b/panoptic_segmentation/mask_pls/data/collation.py index 739ac13..346df12 100644 --- a/panoptic_segmentation/mask_pls/data/collation.py +++ b/panoptic_segmentation/mask_pls/data/collation.py @@ -48,7 +48,7 @@ def __call__(self, data) -> Dict: class fVDBSemanticSegmentationDatasetCollation: """ A data collation class for semantic segmentation datasets using fVDB. - This class handles the conversion of point cloud data into VDBTensors. + This class handles the conversion of point cloud data into sparse grid features. Parameters ---------- device : torch.device, optional @@ -56,7 +56,7 @@ class fVDBSemanticSegmentationDatasetCollation: Methods ------- __call__(data: dict) -> dict: - Transforms input point cloud data into VDBTensors. + Transforms input point cloud data into fVDB grid and feature representations. Parameters: data (dict): Dictionary containing: - xyz: List of point coordinates in world space @@ -65,7 +65,8 @@ class fVDBSemanticSegmentationDatasetCollation: Returns: dict: Original dictionary updated with: - xyz: JaggedTensor of point coordinates - - vdbtensor: VDBTensor containing the structured volumetric data + - grid: GridBatch containing the sparse grid topology + - features: JaggedTensor containing the per-voxel features """ def __init__(self, device=torch.device("cuda:0")): @@ -76,21 +77,18 @@ def __call__(self, data): # xyz world space point positions data["xyz"] = fvdb.JaggedTensor([torch.tensor(c, device=self.device) for c in data["xyz"]]) - grid = fvdb.gridbatch_from_points(data["xyz"], voxel_sizes=[n.tolist() for n in data["voxel_size"]]) + grid = fvdb.GridBatch.from_points(data["xyz"], voxel_sizes=[n.tolist() for n in data["voxel_size"]]) # get mapping of the coordinates to the grid for feature mapping - coord_ijks = grid.world_to_grid(data["xyz"]).round().int() + coord_ijks = grid.world_to_voxel(data["xyz"]).round().int() inv_idx = grid.ijk_to_inv_index(coord_ijks, cumulative=True) - # assert(torch.all(grid.ijk.jdata == coord_ijks.jdata[inv_idx.jdata])) - jfeats = torch.cat([torch.tensor(f, device=self.device).unsqueeze(-1) for f in data["intensity"]]) jfeats = grid.jagged_like(jfeats[inv_idx.jdata]) jfeats = fvdb.jcat([grid.ijk.float(), jfeats], dim=1) - vdbtensor = fvdb.nn.VDBTensor(grid, jfeats) - - data["vdbtensor"] = vdbtensor + data["grid"] = grid + data["features"] = jfeats return data diff --git a/panoptic_segmentation/mask_pls/maskpls_environment.yml b/panoptic_segmentation/mask_pls/maskpls_environment.yml index f6204f5..934da88 100644 --- a/panoptic_segmentation/mask_pls/maskpls_environment.yml +++ b/panoptic_segmentation/mask_pls/maskpls_environment.yml @@ -3,14 +3,13 @@ channels: - conda-forge - nodefaults dependencies: - - python=3.11 - - pytorch-gpu=2.4.1[build=cuda120*] + - python=3.12 + - pytorch-gpu=2.8.0[build=cuda129*] - pip - git - gitpython - ipython - tqdm - - numpy<2 - tyro - scikit-learn - py-opencv diff --git a/panoptic_segmentation/mask_pls/models/mask_pls/backbone.py b/panoptic_segmentation/mask_pls/models/mask_pls/backbone.py index 35dc364..49a2937 100644 --- a/panoptic_segmentation/mask_pls/models/mask_pls/backbone.py +++ b/panoptic_segmentation/mask_pls/models/mask_pls/backbone.py @@ -1,12 +1,13 @@ # Copyright Contributors to the OpenVDB Project # SPDX-License-Identifier: Apache-2.0 # -from typing import List, Type +from typing import List, Tuple import torch import fvdb import fvdb.nn +from fvdb import ConvolutionPlan, GridBatch, JaggedTensor from .blocks import BasicConvolutionBlock, BasicDeconvolutionBlock, ResidualBlock @@ -20,151 +21,79 @@ def __init__( input_dim: int = 4, stem_blocks: int = 1, output_feature_levels: List[int] = [3], - conv_deconv_non_lin: Type = fvdb.nn.ReLU, bn_momentum: float = 0.02, ): super().__init__() self.output_feature_levels = output_feature_levels down_res_blocks = [2, 3, 4, 6] - self.stem = [ - fvdb.nn.SparseConv3d(input_dim, self.channels[0], kernel_size=3), - fvdb.nn.BatchNorm(self.channels[0], momentum=bn_momentum), - fvdb.nn.ReLU(inplace=True), - ] + # Stem: stride=1, ks=3 convolutions + self.stem_convs = torch.nn.ModuleList() + self.stem_bns = torch.nn.ModuleList() + self.stem_convs.append(fvdb.nn.SparseConv3d(input_dim, self.channels[0], kernel_size=3)) + self.stem_bns.append(fvdb.nn.BatchNorm(self.channels[0], momentum=bn_momentum)) for _ in range(1, stem_blocks): - self.stem.extend( - [ - fvdb.nn.SparseConv3d(self.channels[0], self.channels[0], kernel_size=3), - fvdb.nn.BatchNorm(self.channels[0], momentum=bn_momentum), - fvdb.nn.ReLU(inplace=True), - ] - ) - self.stem = torch.nn.Sequential(*self.stem) - - self.stage1 = [ - BasicConvolutionBlock( - self.channels[0], self.channels[0], ks=2, stride=2, bn_mom=bn_momentum, non_lin=conv_deconv_non_lin - ), + self.stem_convs.append(fvdb.nn.SparseConv3d(self.channels[0], self.channels[0], kernel_size=3)) + self.stem_bns.append(fvdb.nn.BatchNorm(self.channels[0], momentum=bn_momentum)) + + # Encoder stages: each starts with a stride-2 downsample then residual blocks + self.stage1 = torch.nn.ModuleList([ + BasicConvolutionBlock(self.channels[0], self.channels[0], ks=2, stride=2, bn_mom=bn_momentum), ResidualBlock(self.channels[0], self.channels[1], ks=3, bn_mom=bn_momentum), - ] - self.stage1.extend( - [ - ResidualBlock(self.channels[1], self.channels[1], ks=3, bn_mom=bn_momentum) - for _ in range(1, down_res_blocks[0]) - ] - ) - self.stage1 = torch.nn.Sequential(*self.stage1) + ] + [ + ResidualBlock(self.channels[1], self.channels[1], ks=3, bn_mom=bn_momentum) + for _ in range(1, down_res_blocks[0]) + ]) - self.stage2 = [ - BasicConvolutionBlock( - self.channels[1], self.channels[1], ks=2, stride=2, bn_mom=bn_momentum, non_lin=conv_deconv_non_lin - ), + self.stage2 = torch.nn.ModuleList([ + BasicConvolutionBlock(self.channels[1], self.channels[1], ks=2, stride=2, bn_mom=bn_momentum), ResidualBlock(self.channels[1], self.channels[2], ks=3, bn_mom=bn_momentum), - ] - self.stage2.extend( - [ - ResidualBlock(self.channels[2], self.channels[2], ks=3, bn_mom=bn_momentum) - for _ in range(1, down_res_blocks[1]) - ] - ) - self.stage2 = torch.nn.Sequential(*self.stage2) + ] + [ + ResidualBlock(self.channels[2], self.channels[2], ks=3, bn_mom=bn_momentum) + for _ in range(1, down_res_blocks[1]) + ]) - self.stage3 = [ - BasicConvolutionBlock( - self.channels[2], self.channels[2], ks=2, stride=2, bn_mom=bn_momentum, non_lin=conv_deconv_non_lin - ), + self.stage3 = torch.nn.ModuleList([ + BasicConvolutionBlock(self.channels[2], self.channels[2], ks=2, stride=2, bn_mom=bn_momentum), ResidualBlock(self.channels[2], self.channels[3], ks=3, bn_mom=bn_momentum), - ] - self.stage3.extend( - [ - ResidualBlock(self.channels[3], self.channels[3], ks=3, bn_mom=bn_momentum) - for _ in range(1, down_res_blocks[2]) - ] - ) - self.stage3 = torch.nn.Sequential(*self.stage3) + ] + [ + ResidualBlock(self.channels[3], self.channels[3], ks=3, bn_mom=bn_momentum) + for _ in range(1, down_res_blocks[2]) + ]) - self.stage4 = [ - BasicConvolutionBlock( - self.channels[3], self.channels[3], ks=2, stride=2, bn_mom=bn_momentum, non_lin=conv_deconv_non_lin - ), + self.stage4 = torch.nn.ModuleList([ + BasicConvolutionBlock(self.channels[3], self.channels[3], ks=2, stride=2, bn_mom=bn_momentum), ResidualBlock(self.channels[3], self.channels[4], ks=3, bn_mom=bn_momentum), - ] - self.stage4.extend( - [ - ResidualBlock(self.channels[4], self.channels[4], ks=3, bn_mom=bn_momentum) - for _ in range(1, down_res_blocks[3]) - ] - ) - self.stage4 = torch.nn.Sequential(*self.stage4) - - self.up1 = torch.nn.ModuleList( - [ - BasicDeconvolutionBlock( - self.channels[4], - self.channels[5], - ks=2, - stride=2, - bn_mom=bn_momentum, - ), - torch.nn.Sequential( - ResidualBlock(self.channels[5] + self.channels[3], self.channels[5], ks=3, bn_mom=bn_momentum), - ResidualBlock(self.channels[5], self.channels[5], ks=3, bn_mom=bn_momentum), - ), - ] - ) - - self.up2 = torch.nn.ModuleList( - [ - BasicDeconvolutionBlock( - self.channels[5], - self.channels[6], - ks=2, - stride=2, - bn_mom=bn_momentum, - ), - torch.nn.Sequential( - ResidualBlock(self.channels[6] + self.channels[2], self.channels[6], ks=3, bn_mom=bn_momentum), - ResidualBlock(self.channels[6], self.channels[6], ks=3, bn_mom=bn_momentum), - ), - ] - ) - - self.up3 = torch.nn.ModuleList( - [ - BasicDeconvolutionBlock( - self.channels[6], - self.channels[7], - ks=2, - stride=2, - bn_mom=bn_momentum, - ), - torch.nn.Sequential( - ResidualBlock(self.channels[7] + self.channels[1], self.channels[7], ks=3, bn_mom=bn_momentum), - ResidualBlock(self.channels[7], self.channels[7], ks=3, bn_mom=bn_momentum), - ), - ] - ) - - self.up4 = torch.nn.ModuleList( - [ - BasicDeconvolutionBlock( - self.channels[7], - self.channels[8], - ks=2, - stride=2, - bn_mom=bn_momentum, - ), - torch.nn.Sequential( - ResidualBlock(self.channels[8] + self.channels[0], self.channels[8], ks=3, bn_mom=bn_momentum), - ResidualBlock(self.channels[8], self.channels[8], ks=3, bn_mom=bn_momentum), - ), - ] - ) + ] + [ + ResidualBlock(self.channels[4], self.channels[4], ks=3, bn_mom=bn_momentum) + for _ in range(1, down_res_blocks[3]) + ]) + + # Decoder: each level has a deconv block + residual blocks after skip concatenation + self.up1_deconv = BasicDeconvolutionBlock(self.channels[4], self.channels[5], ks=2, stride=2, bn_mom=bn_momentum) + self.up1_res = torch.nn.ModuleList([ + ResidualBlock(self.channels[5] + self.channels[3], self.channels[5], ks=3, bn_mom=bn_momentum), + ResidualBlock(self.channels[5], self.channels[5], ks=3, bn_mom=bn_momentum), + ]) + + self.up2_deconv = BasicDeconvolutionBlock(self.channels[5], self.channels[6], ks=2, stride=2, bn_mom=bn_momentum) + self.up2_res = torch.nn.ModuleList([ + ResidualBlock(self.channels[6] + self.channels[2], self.channels[6], ks=3, bn_mom=bn_momentum), + ResidualBlock(self.channels[6], self.channels[6], ks=3, bn_mom=bn_momentum), + ]) + + self.up3_deconv = BasicDeconvolutionBlock(self.channels[6], self.channels[7], ks=2, stride=2, bn_mom=bn_momentum) + self.up3_res = torch.nn.ModuleList([ + ResidualBlock(self.channels[7] + self.channels[1], self.channels[7], ks=3, bn_mom=bn_momentum), + ResidualBlock(self.channels[7], self.channels[7], ks=3, bn_mom=bn_momentum), + ]) + + self.up4_deconv = BasicDeconvolutionBlock(self.channels[7], self.channels[8], ks=2, stride=2, bn_mom=bn_momentum) + self.up4_res = torch.nn.ModuleList([ + ResidualBlock(self.channels[8] + self.channels[0], self.channels[8], ks=3, bn_mom=bn_momentum), + ResidualBlock(self.channels[8], self.channels[8], ks=3, bn_mom=bn_momentum), + ]) - levels = [self.channels[-i] for i in range(4, 0, -1)] - - # conv mask projection self.mask_feat = fvdb.nn.SparseConv3d( self.channels[-1], self.channels[-1], @@ -172,43 +101,76 @@ def __init__( stride=1, ) - self.out_bnorm = torch.nn.ModuleList([torch.nn.Sequential() for _ in levels]) - - def forward(self, x) -> List[fvdb.nn.VDBTensor]: - - sparse_input = x["vdbtensor"] - - x0 = self.stem(sparse_input) # type: ignore - x1 = self.stage1(x0) # type: ignore - x2 = self.stage2(x1) # type: ignore - x3 = self.stage3(x2) # type: ignore - x4 = self.stage4(x3) # type: ignore + def _run_stage( + self, stage: torch.nn.ModuleList, data: JaggedTensor, grid: GridBatch + ) -> Tuple[JaggedTensor, GridBatch]: + for block in stage: + data, grid = block(data, grid) + return data, grid - y1 = self.up1[0](x4, out_grid=x3.grid) - y1 = fvdb.jcat([y1, x3], dim=1) - y1 = self.up1[1](y1) - - y2 = self.up2[0](y1, out_grid=x2.grid) - y2 = fvdb.jcat([y2, x2], dim=1) - y2 = self.up2[1](y2) - - y3 = self.up3[0](y2, out_grid=x1.grid) - y3 = fvdb.jcat([y3, x1], dim=1) - y3 = self.up3[1](y3) - - y4 = self.up4[0](y3, out_grid=x0.grid) - y4 = fvdb.jcat([y4, x0], dim=1) - y4 = self.up4[1](y4) + def _run_decoder_level( + self, + deconv: BasicDeconvolutionBlock, + res_blocks: torch.nn.ModuleList, + data: JaggedTensor, + source_grid: GridBatch, + skip_data: JaggedTensor, + skip_grid: GridBatch, + ) -> Tuple[JaggedTensor, GridBatch]: + data = deconv(data, source_grid, skip_grid) + data = fvdb.jcat([data, skip_data], dim=1) + grid = skip_grid + for block in res_blocks: + data, grid = block(data, grid) + return data, grid + + def forward(self, x) -> List[Tuple[JaggedTensor, GridBatch]]: + data: JaggedTensor = x["features"] + grid: GridBatch = x["grid"] + + # Stem: stride=1, ks=3 convolutions (grid topology unchanged) + stem_plan = ConvolutionPlan.from_grid_batch(kernel_size=3, stride=1, source_grid=grid, target_grid=grid) + for conv, bn in zip(self.stem_convs, self.stem_bns): + data = conv(data, stem_plan) + data = bn(data, grid) + data = fvdb.relu(data) + x0_data, x0_grid = data, grid + + # Encoder + x1_data, x1_grid = self._run_stage(self.stage1, x0_data, x0_grid) + x2_data, x2_grid = self._run_stage(self.stage2, x1_data, x1_grid) + x3_data, x3_grid = self._run_stage(self.stage3, x2_data, x2_grid) + x4_data, x4_grid = self._run_stage(self.stage4, x3_data, x3_grid) + + # Decoder + y1_data, y1_grid = self._run_decoder_level( + self.up1_deconv, self.up1_res, x4_data, x4_grid, x3_data, x3_grid + ) + y2_data, y2_grid = self._run_decoder_level( + self.up2_deconv, self.up2_res, y1_data, y1_grid, x2_data, x2_grid + ) + y3_data, y3_grid = self._run_decoder_level( + self.up3_deconv, self.up3_res, y2_data, y2_grid, x1_data, x1_grid + ) + y4_data, y4_grid = self._run_decoder_level( + self.up4_deconv, self.up4_res, y3_data, y3_grid, x0_data, x0_grid + ) - out_feats = [y1, y2, y3, y4] + out_feats = [ + (y1_data, y1_grid), + (y2_data, y2_grid), + (y3_data, y3_grid), + (y4_data, y4_grid), + ] feat_levels = self.output_feature_levels + [3] + out_feats = [out_feats[i] for i in feat_levels] - out_feats = [out_feats[feats] for feats in feat_levels] - - out_feats[-1] = self.mask_feat(out_feats[-1]) - - # batch norm - out_feats = [bn(feat) for feat, bn in zip(out_feats, self.out_bnorm)] + # Apply mask projection conv to the last feature level + last_data, last_grid = out_feats[-1] + mask_plan = ConvolutionPlan.from_grid_batch( + kernel_size=3, stride=1, source_grid=last_grid, target_grid=last_grid + ) + out_feats[-1] = (self.mask_feat(last_data, mask_plan), last_grid) return out_feats diff --git a/panoptic_segmentation/mask_pls/models/mask_pls/blocks.py b/panoptic_segmentation/mask_pls/models/mask_pls/blocks.py index 6c4f4cb..23348bb 100644 --- a/panoptic_segmentation/mask_pls/models/mask_pls/blocks.py +++ b/panoptic_segmentation/mask_pls/models/mask_pls/blocks.py @@ -1,13 +1,14 @@ # Copyright Contributors to the OpenVDB Project # SPDX-License-Identifier: Apache-2.0 # -from typing import Optional +from typing import Optional, Tuple import torch import torch.nn.functional as F import fvdb import fvdb.nn +from fvdb import ConvolutionPlan, GridBatch, JaggedTensor class SelfAttentionLayer(torch.nn.Module): @@ -120,19 +121,14 @@ def forward(self, tgt): class MLP(torch.nn.Module): - def __init__(self, input_dim, hidden_dim_list, output_dim, use_fvdb: bool = False): + def __init__(self, input_dim, hidden_dim_list, output_dim): super().__init__() - if use_fvdb: - linear_cls = fvdb.nn.Linear - relu_cls = fvdb.nn.ReLU - else: - linear_cls = torch.nn.Linear - relu_cls = torch.nn.ReLU - self.num_layers = len(hidden_dim_list) + 1 h = hidden_dim_list - self.layers = torch.nn.ModuleList(linear_cls(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) - self.relu = relu_cls() + self.layers = torch.nn.ModuleList( + torch.nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.relu = torch.nn.ReLU() def forward(self, x): for i, layer in enumerate(self.layers): @@ -143,46 +139,42 @@ def forward(self, x): class BasicConvolutionBlock(torch.nn.Module): - def __init__( - self, - inc, - outc, - ks=3, - stride=1, - dilation=1, - bn_mom=0.1, - non_lin=fvdb.nn.ReLU, - ): + def __init__(self, inc, outc, ks=3, stride=1, dilation=1, bn_mom=0.1): super().__init__() if dilation != 1: raise NotImplementedError("Dilation not implemented for fVDB SparseConv3d") - self.net = torch.nn.Sequential( - fvdb.nn.SparseConv3d(inc, outc, kernel_size=ks, stride=stride), - fvdb.nn.BatchNorm(outc, momentum=bn_mom), - non_lin(inplace=True), + self.ks = ks + self.stride = stride + self.conv = fvdb.nn.SparseConv3d(inc, outc, kernel_size=ks, stride=stride) + self.bn = fvdb.nn.BatchNorm(outc, momentum=bn_mom) + + def forward(self, data: JaggedTensor, grid: GridBatch) -> Tuple[JaggedTensor, GridBatch]: + target_grid = grid if self.stride == 1 else grid.conv_grid(kernel_size=self.ks, stride=self.stride) + plan = ConvolutionPlan.from_grid_batch( + kernel_size=self.ks, stride=self.stride, source_grid=grid, target_grid=target_grid ) - - def forward(self, x): - out = self.net(x) - return out + data = self.conv(data, plan) + data = self.bn(data, target_grid) + data = fvdb.relu(data) + return data, target_grid class BasicDeconvolutionBlock(torch.nn.Module): - def __init__(self, inc, outc, ks=3, stride=1, bn_mom=0.1, non_lin=fvdb.nn.LeakyReLU): + def __init__(self, inc, outc, ks=3, stride=1, bn_mom=0.1): super().__init__() - self.net = torch.nn.Sequential( - fvdb.nn.SparseConv3d(inc, outc, kernel_size=ks, stride=stride, transposed=True), - fvdb.nn.BatchNorm(outc, momentum=bn_mom), - non_lin(inplace=True), + self.ks = ks + self.stride = stride + self.deconv = fvdb.nn.SparseConvTranspose3d(inc, outc, kernel_size=ks, stride=stride) + self.bn = fvdb.nn.BatchNorm(outc, momentum=bn_mom) + + def forward(self, data: JaggedTensor, source_grid: GridBatch, target_grid: GridBatch) -> JaggedTensor: + plan = ConvolutionPlan.from_grid_batch_transposed( + kernel_size=self.ks, stride=self.stride, source_grid=source_grid, target_grid=target_grid ) - - def forward(self, x, out_grid=None): - for module in self.net: - if isinstance(module, fvdb.nn.SparseConv3d): - x = module(x, out_grid=out_grid) - else: - x = module(x) - return x + data = self.deconv(data, plan) + data = self.bn(data, target_grid) + data = target_grid.jagged_like(F.leaky_relu(data.jdata)) + return data class ResidualBlock(torch.nn.Module): @@ -190,25 +182,45 @@ def __init__(self, inc, outc, ks=3, stride=1, dilation=1, bn_mom=0.1): super().__init__() if dilation != 1: raise NotImplementedError("Dilation not implemented for fVDB SparseConv3d") - self.net = torch.nn.Sequential( - fvdb.nn.SparseConv3d(inc, outc, kernel_size=ks, stride=stride), - fvdb.nn.BatchNorm(outc, momentum=bn_mom), - fvdb.nn.ReLU(inplace=True), - fvdb.nn.SparseConv3d(outc, outc, kernel_size=ks, stride=1), - fvdb.nn.BatchNorm(outc, momentum=bn_mom), + self.ks = ks + self.stride = stride + + self.conv1 = fvdb.nn.SparseConv3d(inc, outc, kernel_size=ks, stride=stride) + self.bn1 = fvdb.nn.BatchNorm(outc, momentum=bn_mom) + self.conv2 = fvdb.nn.SparseConv3d(outc, outc, kernel_size=ks, stride=1) + self.bn2 = fvdb.nn.BatchNorm(outc, momentum=bn_mom) + + if inc == outc and stride == 1: + self.downsample_conv = None + self.downsample_bn = None + else: + self.downsample_conv = fvdb.nn.SparseConv3d(inc, outc, kernel_size=1, stride=stride) + self.downsample_bn = fvdb.nn.BatchNorm(outc, momentum=bn_mom) + + def forward(self, data: JaggedTensor, grid: GridBatch) -> Tuple[JaggedTensor, GridBatch]: + target_grid = grid if self.stride == 1 else grid.conv_grid(kernel_size=self.ks, stride=self.stride) + + plan1 = ConvolutionPlan.from_grid_batch( + kernel_size=self.ks, stride=self.stride, source_grid=grid, target_grid=target_grid ) + out = self.conv1(data, plan1) + out = self.bn1(out, target_grid) + out = fvdb.relu(out) - self.downsample = ( - torch.nn.Sequential() - if (inc == outc and stride == 1) - else torch.nn.Sequential( - fvdb.nn.SparseConv3d(inc, outc, kernel_size=1, stride=stride), - fvdb.nn.BatchNorm(outc, momentum=bn_mom), - ) + plan2 = ConvolutionPlan.from_grid_batch( + kernel_size=self.ks, stride=1, source_grid=target_grid, target_grid=target_grid ) + out = self.conv2(out, plan2) + out = self.bn2(out, target_grid) - self.relu = fvdb.nn.ReLU(inplace=True) + if self.downsample_conv is not None: + ds_plan = ConvolutionPlan.from_grid_batch( + kernel_size=1, stride=self.stride, source_grid=grid, target_grid=target_grid + ) + residual = self.downsample_conv(data, ds_plan) + residual = self.downsample_bn(residual, target_grid) + else: + residual = data - def forward(self, x): - out = self.relu(self.net(x) + self.downsample(x)) - return out + out = fvdb.relu(out + residual) + return out, target_grid diff --git a/panoptic_segmentation/mask_pls/models/mask_pls/loss.py b/panoptic_segmentation/mask_pls/models/mask_pls/loss.py index 2cab16a..0136cb6 100644 --- a/panoptic_segmentation/mask_pls/models/mask_pls/loss.py +++ b/panoptic_segmentation/mask_pls/models/mask_pls/loss.py @@ -38,14 +38,14 @@ def forward(self, outputs, targets) -> dict: if self.input_mode == MaskPLS.DecoderInputMode.GRID: # If the input to the loss function (which is the same as the input/output from the decoder) is the grid centers, # (i.e. not the original xyz coordinates), we need to convert the targets to the grid centers as well. - input_vdbtensor = targets["vdbtensor"] + input_grid = targets["grid"] # map target semantic labels to the grid points = targets["xyz"] # get mapping of the coordinates to the grid for feature mapping - coord_ijks = input_vdbtensor.grid.world_to_grid(points).round().int() - inv_idx = input_vdbtensor.grid.ijk_to_inv_index(coord_ijks, cumulative=True) + coord_ijks = input_grid.world_to_voxel(points).round().int() + inv_idx = input_grid.ijk_to_inv_index(coord_ijks, cumulative=True) sem_labels = sem_labels[inv_idx.jdata] sem_targets = sem_labels diff --git a/panoptic_segmentation/mask_pls/models/mask_pls/mask_model.py b/panoptic_segmentation/mask_pls/models/mask_pls/mask_model.py index 28986ed..19fd0f0 100644 --- a/panoptic_segmentation/mask_pls/models/mask_pls/mask_model.py +++ b/panoptic_segmentation/mask_pls/models/mask_pls/mask_model.py @@ -7,9 +7,6 @@ import torch import torch.nn -import fvdb -import fvdb.nn - from .backbone import MaskPLSEncoderDecoder from .blocks import MLP from .decoder import MaskedTransformerDecoder @@ -47,11 +44,7 @@ def __init__( self.backbone = MaskPLSEncoderDecoder(output_feature_levels=[3]) - self.sem_head = ( - fvdb.nn.Linear(self.backbone.channels[-1], num_classes) - if self.decoder_input_mode == MaskPLS.DecoderInputMode.GRID - else torch.nn.Linear(self.backbone.channels[-1], num_classes) - ) + self.sem_head = torch.nn.Linear(self.backbone.channels[-1], num_classes) self.semantic_embedding_distil = False if self.semantic_embedding_distil: @@ -60,7 +53,6 @@ def __init__( self.backbone.channels[-1], semantic_embedding_hidden_dims[:-1], semantic_embedding_hidden_dims[-1], - use_fvdb=(self.decoder_input_mode == MaskPLS.DecoderInputMode.GRID), ) if not self.segmentation_only: @@ -71,36 +63,36 @@ def __init__( def forward(self, x: Dict): outputs = {} + logits_sem_embed_grid = None + ###### Backbone ###### out_feats_grids = self.backbone(x) - # out_feats_grids is a List[fvdb.nn.VDBTensor] - # where each VDBTensor corresponds to the `ouput_feature_levels` + # out_feats_grids is a List[Tuple[JaggedTensor, GridBatch]] + # where each tuple corresponds to the `output_feature_levels` # plus 1 additional entry which is the last/full-resolution feature level run through the conv mask projection ###### v2p ###### # NOTE: Matching MaskPLS paper which performs v2p before sem_head # In SAL, features are at voxel centers throughout, so we provide an option to try either if self.decoder_input_mode == MaskPLS.DecoderInputMode.XYZ: - # If decoder inputs are the original points, we need to sample the features in the grid and pad them for form - # a minibatch for the semantic head and decoder xyz = x["xyz"] - feats = [feats_grid.sample_trilinear(xyz).unbind() for feats_grid in out_feats_grids] + feats = [grid.sample_trilinear(xyz, data).unbind() for data, grid in out_feats_grids] # pad batch feats, coords, pad_masks = pad_batch(feats, [xyz.unbind() for _ in feats]) # type: ignore - else: - feats = out_feats_grids - logits = [self.sem_head(feats[-1])] + logits = [self.sem_head(feats[-1])] + else: + # GRID mode: apply sem_head to the raw JaggedTensor features, then unpack into padded batches + last_data, last_grid = out_feats_grids[-1] + logits_jt = last_grid.jagged_like(self.sem_head(last_data.jdata)) - if self.semantic_embedding_distil: - logits_sem_embed_grid = self.sem_embed(feats[-1]) + if self.semantic_embedding_distil: + logits_sem_embed_grid = last_grid.jagged_like(self.sem_embed(last_data.jdata)) - if self.decoder_input_mode == MaskPLS.DecoderInputMode.GRID: - # produce a padded batch for the decoder and loss - coords = [feat.grid.grid_to_world(feat.ijk.float()).unbind() for feat in out_feats_grids] - feats = [feat.data.unbind() for feat in out_feats_grids] - logits = [ls.data.unbind() for ls in logits] + coords = [grid.voxel_to_world(grid.ijk.float()).unbind() for data, grid in out_feats_grids] + feats = [data.unbind() for data, grid in out_feats_grids] + logits = [logits_jt.unbind()] feats, coords, pad_masks, logits = pad_batch(feats, coords, additional_feats=logits) # type: ignore ###### Decoder ######