From a4023fc6f6760efb3402832d6706c0acee88d20e Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Mon, 15 Dec 2025 11:53:20 -0800 Subject: [PATCH 1/4] This commit fixes some mismatch with official PTV3. I currently believe the implementation matches with official ptv3 very well. Signed-off-by: Hexu Zhao --- .../models/point_transformer_v3m1_fvdb.py | 5 +--- .../fvdb_extensions/models/ptv3_fvdb.py | 28 ++++++++++--------- point_transformer_v3/requirements.txt | 9 +++--- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/point_transformer_v3/fvdb_extensions/models/point_transformer_v3m1_fvdb.py b/point_transformer_v3/fvdb_extensions/models/point_transformer_v3m1_fvdb.py index e8f3886..652deda 100644 --- a/point_transformer_v3/fvdb_extensions/models/point_transformer_v3m1_fvdb.py +++ b/point_transformer_v3/fvdb_extensions/models/point_transformer_v3m1_fvdb.py @@ -63,9 +63,7 @@ def create_grid_from_points( coords_jagged = fvdb.JaggedTensor(coords_list) grid = fvdb.GridBatch.from_ijk( - coords_jagged, - voxel_sizes=[[voxel_size, voxel_size, voxel_size]] * len(coords_list), - origins=[0.0] * 3, + coords_jagged ) feats_jagged = fvdb.JaggedTensor(feats_list) @@ -195,7 +193,6 @@ def forward(self, data_dict: dict) -> torch.Tensor: grid, jfeats, original_coord_to_voxel_idx = create_grid_from_points( grid_coord, feat, offset, voxel_size=0.02 ) - # import pdb; pdb.set_trace() # TODO: check the downsampling behavior is the same or not? assert ( grid_coord.shape == grid.ijk.jdata.shape diff --git a/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py b/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py index 334c2ba..fc985ca 100644 --- a/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py +++ b/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py @@ -161,10 +161,10 @@ def __init__( self.out_channels = out_channels self.proj = FJTM(torch.nn.Linear(in_channels, out_channels)) - self.norm = FJTM(norm_layer_module(out_channels)) - self.act_layer = FJTM(torch.nn.GELU()) self.proj_skip = FJTM(torch.nn.Linear(skip_channels, out_channels)) + self.norm = FJTM(norm_layer_module(out_channels)) self.norm_skip = FJTM(norm_layer_module(out_channels)) + self.act_layer = FJTM(torch.nn.GELU()) self.act_layer_skip = FJTM(torch.nn.GELU()) def __call__( @@ -180,7 +180,7 @@ def forward( # The conversion is to avoid the bug when enabled AMP, # despite both feats.jdata and linear.weights are float32, # the output becomes float16 which causes the subsequent convolution operation to fail. - feats = self.proj(feats).to(torch.float32) + feats = self.proj(feats) #.to(torch.float32) feats = self.norm(feats) feats = self.act_layer(feats) @@ -387,7 +387,7 @@ def __init__( sliding_window_attention, order_index, order_types, - ) + ) # temporary disable attention self.norm2 = FJTM(torch.nn.LayerNorm(hidden_size)) self.order_index = order_index self.mlp = PTV3_MLP(hidden_size, proj_drop) @@ -397,16 +397,18 @@ def forward(self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch) -> fvdb.Jagged assert isinstance(feats, fvdb.JaggedTensor), "Input feats must be a JaggedTensor" assert isinstance(grid, fvdb.GridBatch), "Input grid must be a GridBatch" with NVTXRange("PTV3_Block"): + short_cut = feats feats = self.cpe(feats, grid) + feats = fvdb.add(short_cut, feats) short_cut = feats feats = self.norm1(feats) - feats = self.attn(feats, grid) + feats = self.attn(feats, grid) # temporary disable attention + feats = self.drop_path(feats) # temporary disable attention # The drop_path is applied to each point independently. - feats = self.drop_path(feats) feats = fvdb.add(short_cut, feats) - short_cut = feats + short_cut = feats feats = self.norm2(feats) feats = self.mlp(feats) feats = self.drop_path(feats) @@ -634,23 +636,23 @@ def __init__( ) ) - def _shuffle_order(self): + def _shuffle_order(self, shuffled_order): """ Randomly shuffle the order tuple to create variation across forward passes. Returns a new shuffled tuple of order types. """ if self.shuffle_orders: - indices = torch.randperm(len(self.order_type)) - return tuple(self.order_type[i] for i in indices) + indices = torch.randperm(len(shuffled_order)) + return tuple(shuffled_order[i] for i in indices) else: - return self.order_type + return shuffled_order def forward(self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch) -> fvdb.JaggedTensor: original_grid = grid with NVTXRange("PTV3_Forward"): # Shuffle order at the beginning of forward pass (matching reference implementation) - shuffled_order = self._shuffle_order() + shuffled_order = self._shuffle_order(self.order_type) # Store shuffled order in grid metadata so all blocks can access it grid._shuffled_order = shuffled_order # type: ignore @@ -670,7 +672,7 @@ def forward(self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch) -> fvdb.Jagged feats, grid = pooler(feats, grid) # Shuffle order after pooling for the next (downsampled) stage - shuffled_order = self._shuffle_order() + shuffled_order = self._shuffle_order(shuffled_order) grid._shuffled_order = shuffled_order # type: ignore layer_id += 1 with NVTXRange(f"PTV3_Encoder_{layer_id}"): diff --git a/point_transformer_v3/requirements.txt b/point_transformer_v3/requirements.txt index e54dda0..215ab79 100644 --- a/point_transformer_v3/requirements.txt +++ b/point_transformer_v3/requirements.txt @@ -1,7 +1,7 @@ -# fvdb requirements ---extra-index-url https://download.pytorch.org/whl/cu129 ---extra-index-url https://d36m13axqqhiit.cloudfront.net/simple -fvdb-core==0.3.0+pt28.cu129 +# # fvdb requirements +# --extra-index-url https://download.pytorch.org/whl/cu129 +# --extra-index-url https://d36m13axqqhiit.cloudfront.net/simple +# fvdb-core==0.3.0+pt28.cu129 # Core dependencies for PT-v3 FVDB implementation timm @@ -11,6 +11,7 @@ peft wandb tensorboard tensorboardx +yapf # flash-attn is only needed when patch_size > 0 (default config uses patch_size=1024) # While PyTorch 2.8+ has built-in flash attention, flash-attn provides optimized varlen functions From d0d30531e5f214336c2d607985cf679f5200d844 Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Mon, 15 Dec 2025 11:59:48 -0800 Subject: [PATCH 2/4] fix requirements Signed-off-by: Hexu Zhao --- point_transformer_v3/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/point_transformer_v3/requirements.txt b/point_transformer_v3/requirements.txt index 215ab79..4afc9d6 100644 --- a/point_transformer_v3/requirements.txt +++ b/point_transformer_v3/requirements.txt @@ -1,7 +1,7 @@ # # fvdb requirements -# --extra-index-url https://download.pytorch.org/whl/cu129 -# --extra-index-url https://d36m13axqqhiit.cloudfront.net/simple -# fvdb-core==0.3.0+pt28.cu129 +--extra-index-url https://download.pytorch.org/whl/cu129 +--extra-index-url https://d36m13axqqhiit.cloudfront.net/simple +fvdb-core==0.3.0+pt28.cu129 # Core dependencies for PT-v3 FVDB implementation timm From 17763e0fdf906223eb5dc0451ab2f8dbb4c5d97c Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Mon, 15 Dec 2025 12:02:12 -0800 Subject: [PATCH 3/4] format Signed-off-by: Hexu Zhao --- .../fvdb_extensions/models/point_transformer_v3m1_fvdb.py | 4 +--- point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py | 8 ++++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/point_transformer_v3/fvdb_extensions/models/point_transformer_v3m1_fvdb.py b/point_transformer_v3/fvdb_extensions/models/point_transformer_v3m1_fvdb.py index 652deda..2f67100 100644 --- a/point_transformer_v3/fvdb_extensions/models/point_transformer_v3m1_fvdb.py +++ b/point_transformer_v3/fvdb_extensions/models/point_transformer_v3m1_fvdb.py @@ -62,9 +62,7 @@ def create_grid_from_points( coords_jagged = fvdb.JaggedTensor(coords_list) - grid = fvdb.GridBatch.from_ijk( - coords_jagged - ) + grid = fvdb.GridBatch.from_ijk(coords_jagged) feats_jagged = fvdb.JaggedTensor(feats_list) feats_vdb_order = grid.inject_from_ijk(coords_jagged, feats_jagged) # diff --git a/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py b/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py index fc985ca..5f2bb20 100644 --- a/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py +++ b/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py @@ -180,7 +180,7 @@ def forward( # The conversion is to avoid the bug when enabled AMP, # despite both feats.jdata and linear.weights are float32, # the output becomes float16 which causes the subsequent convolution operation to fail. - feats = self.proj(feats) #.to(torch.float32) + feats = self.proj(feats) # .to(torch.float32) feats = self.norm(feats) feats = self.act_layer(feats) @@ -387,7 +387,7 @@ def __init__( sliding_window_attention, order_index, order_types, - ) # temporary disable attention + ) # temporary disable attention self.norm2 = FJTM(torch.nn.LayerNorm(hidden_size)) self.order_index = order_index self.mlp = PTV3_MLP(hidden_size, proj_drop) @@ -403,8 +403,8 @@ def forward(self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch) -> fvdb.Jagged short_cut = feats feats = self.norm1(feats) - feats = self.attn(feats, grid) # temporary disable attention - feats = self.drop_path(feats) # temporary disable attention + feats = self.attn(feats, grid) # temporary disable attention + feats = self.drop_path(feats) # temporary disable attention # The drop_path is applied to each point independently. feats = fvdb.add(short_cut, feats) From f59d47ae249aa065f1efbd39015cbf31e75cc1f3 Mon Sep 17 00:00:00 2001 From: Hexu Zhao Date: Tue, 16 Dec 2025 03:50:10 +0000 Subject: [PATCH 4/4] Update fvdb and ptv3. Signed-off-by: Hexu Zhao --- .../fvdb_extensions/models/fvdb_utils.py | 27 +++++++++++-------- .../fvdb_extensions/models/ptv3_fvdb.py | 19 ++++++++++--- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/point_transformer_v3/fvdb_extensions/models/fvdb_utils.py b/point_transformer_v3/fvdb_extensions/models/fvdb_utils.py index 7e9a9d6..551b6a7 100644 --- a/point_transformer_v3/fvdb_extensions/models/fvdb_utils.py +++ b/point_transformer_v3/fvdb_extensions/models/fvdb_utils.py @@ -62,28 +62,30 @@ def jagged_cumulative_argsort(unsorted_jt: fvdb.JaggedTensor) -> fvdb.JaggedTens def morton_from_jagged_ijk(jagged_ijk: fvdb.JaggedTensor) -> fvdb.JaggedTensor: - ijk_j = jagged_ijk.jdata - morton_j = fvdb.morton(ijk_j) + ijk_j: torch.Tensor = jagged_ijk.jdata + kji_j = ijk_j[:, [2, 1, 0]].contiguous() + morton_j = fvdb.morton(kji_j) return jagged_ijk.jagged_like(morton_j) def morton_flipped_from_jagged_ijk(jagged_ijk: fvdb.JaggedTensor) -> fvdb.JaggedTensor: ijk_j: torch.Tensor = jagged_ijk.jdata - kji_j = ijk_j.flip(dims=[-1]) - morton_j = fvdb.morton(kji_j) + kij_j = ijk_j[:, [2, 0, 1]].contiguous() + morton_j = fvdb.morton(kij_j) return jagged_ijk.jagged_like(morton_j) def hilbert_from_jagged_ijk(jagged_ijk: fvdb.JaggedTensor) -> fvdb.JaggedTensor: - ijk_j = jagged_ijk.jdata - hilbert_j = fvdb.hilbert(ijk_j) + ijk_j: torch.Tensor = jagged_ijk.jdata + jki_j = ijk_j[:, [1, 2, 0]].contiguous() + hilbert_j = fvdb.hilbert(jki_j) return jagged_ijk.jagged_like(hilbert_j) def hilbert_flipped_from_jagged_ijk(jagged_ijk: fvdb.JaggedTensor) -> fvdb.JaggedTensor: ijk_j: torch.Tensor = jagged_ijk.jdata - kji_j = ijk_j.flip(dims=[-1]) - hilbert_j = fvdb.hilbert(kji_j) + ikj_j = ijk_j[:, [0, 2, 1]].contiguous() + hilbert_j = fvdb.hilbert(ikj_j) return jagged_ijk.jagged_like(hilbert_j) @@ -259,7 +261,10 @@ def jagged_attention( out_b = cast( Any, flash_attn.flash_attn_qkvpacked_func( - qkv_b.half(), dropout_p=0.0, softmax_scale=scale, window_size=window_size + qkv_b.half(), + dropout_p=0.0, + softmax_scale=scale, + window_size=window_size, ), ).reshape( Li, hidden_size @@ -303,7 +308,7 @@ def jagged_attention( feats_out_j = cast( Any, flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_j.half(), + qkv_j.to(dtype=torch.bfloat16), cu_seqlens, max_seqlen=patch_size, dropout_p=0.0, # TODO: implement attention dropout in the future. By default, it is 0. @@ -311,7 +316,7 @@ def jagged_attention( ), ).reshape( num_voxels, hidden_size - ) # dtype: float16 + ) # dtype: bfloat16 feats_out_j = feats_out_j.to(feats_j.dtype) else: diff --git a/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py b/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py index 5f2bb20..a06916f 100644 --- a/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py +++ b/point_transformer_v3/fvdb_extensions/models/ptv3_fvdb.py @@ -168,13 +168,21 @@ def __init__( self.act_layer_skip = FJTM(torch.nn.GELU()) def __call__( - self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch, last_feats: fvdb.JaggedTensor, last_grid: fvdb.GridBatch + self, + feats: fvdb.JaggedTensor, + grid: fvdb.GridBatch, + last_feats: fvdb.JaggedTensor, + last_grid: fvdb.GridBatch, ) -> tuple[fvdb.JaggedTensor, fvdb.GridBatch]: """Override __call__ to preserve type hints from forward.""" return super().__call__(feats, grid, last_feats, last_grid) def forward( - self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch, last_feats: fvdb.JaggedTensor, last_grid: fvdb.GridBatch + self, + feats: fvdb.JaggedTensor, + grid: fvdb.GridBatch, + last_feats: fvdb.JaggedTensor, + last_grid: fvdb.GridBatch, ) -> tuple[fvdb.JaggedTensor, fvdb.GridBatch]: with NVTXRange("PTV3_Unpooling"): # The conversion is to avoid the bug when enabled AMP, @@ -296,7 +304,12 @@ def forward(self, feats: fvdb.JaggedTensor, grid: fvdb.GridBatch) -> fvdb.Jagged class PTV3_CPE(FVDBGridModule): - def __init__(self, hidden_size: int, no_conv_in_cpe: bool = False, shared_plan_cache: dict | None = None): + def __init__( + self, + hidden_size: int, + no_conv_in_cpe: bool = False, + shared_plan_cache: dict | None = None, + ): """ Args: hidden_size (int): Number of channels in the input features.