
Commit 1e7f3cc

Replace 2pt5 model files on monailabel side (#12)
Replace the old 2pt5 model files with the latest trained model networks:
- model.py
- vista_image_encoder
- vista_prompt_encoder

---------

Signed-off-by: tangy5 <yucheng.tang@vanderbilt.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9fc976b commit 1e7f3cc

6 files changed

Lines changed: 361 additions & 1052 deletions


monailabel/monaivista/lib/configs/vista_point_2pt5.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 
 import lib.infers
 import lib.trainers
-from lib.model.vista_point_2pt5.models_samm2pt5d import sam_model_registry
+from lib.model.vista_point_2pt5.model_2pt5 import sam_model_registry
 from monailabel.interfaces.config import TaskConfig
 from monailabel.interfaces.tasks.infer_v2 import InferTask
 from monailabel.interfaces.tasks.train import TrainTask
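Only the import path changes here; the registry keys and builder arguments stay the same. A minimal sketch of building a network through the renamed module (checkpoint is a placeholder):

    from lib.model.vista_point_2pt5.model_2pt5 import sam_model_registry

    # Registry keys: "default", "vit_h", "vit_l", "vit_b" (see the renamed model file below).
    model = sam_model_registry["vit_b"](checkpoint=None, image_size=1024, encoder_in_chans=3)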
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

monailabel/monaivista/lib/model/vista_point_2pt5/models_samm2pt5d.py renamed to monailabel/monaivista/lib/model/vista_point_2pt5/model_2pt5.py

Lines changed: 64 additions & 85 deletions
@@ -1,3 +1,14 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
@@ -7,24 +18,25 @@
 from functools import partial
 from typing import Any, Dict, List, Tuple
 
+import monai
 import torch
+from segment_anything.modeling import TwoWayTransformer
+from segment_anything.modeling.mask_decoder import MaskDecoder
 from torch import nn
 from torch.nn import functional as F
 
-from .segment_anything.modeling import TwoWayTransformer
-from .segment_anything.modeling.image_encoder import ImageEncoderViT
-from .segment_anything.modeling.mask_decoder import MaskDecoder
-from .segment_anything.modeling.prompt_encoder import PromptEncoder
+from .vista_2pt5_image_encoder import VistaImageEncoderViT
+from .vista_2pt5_prompt_encoder import VistaPromptEncoder
 
 
-class Samm2pt5D(nn.Module):
+class Vista2pt5D(nn.Module):
     mask_threshold: float = 0.5
     image_format: str = "RGB"
 
     def __init__(
         self,
-        image_encoder: ImageEncoderViT,
-        prompt_encoder: PromptEncoder,
+        image_encoder: VistaImageEncoderViT,
+        prompt_encoder: VistaPromptEncoder,
         mask_decoder: MaskDecoder,
         pixel_mean: List[float] = [123.675, 116.28, 103.53],
         pixel_std: List[float] = [58.395, 57.12, 57.375],
@@ -67,7 +79,6 @@ def get_mask_prediction(
         for image_record, curr_embedding in zip(batched_input, image_embeddings):
             if "point_coords" in image_record:
                 points = (image_record["point_coords"], image_record["point_labels"])
-                # raise NotImplementedError
             else:
                 points = None
             sparse_embeddings, dense_embeddings = self.prompt_encoder(
@@ -86,7 +97,6 @@ def get_mask_prediction(
 
             high_res_masks = self.postprocess_masks(
                 low_res_masks,
-                # input_size=image_record["image"].shape[-2:],
                 original_size=image_record["original_size"],
             )
             masks = high_res_masks > self.mask_threshold
@@ -124,6 +134,8 @@ def forward(
                 input frame of the model.
               'point_labels': (torch.Tensor) Batched labels for point prompts,
                 with shape BxN.
+              'labels': (torch.Tensor) Batched labels for class-label prompt,
+                with shape BxN.
               'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                 Already transformed to the input frame of the model.
               'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
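The docstring now also covers the class-label prompt. For shape reference, a hypothetical single batched_input record (mirroring the dummy inputs removed at the bottom of this file) could be:

    import torch

    # One record of batched_input for Vista2pt5D.forward; keys follow the docstring above.
    record = {
        "image": torch.rand(3, 176, 345),     # CxHxW, already transformed to the input frame
        "original_size": (176, 345),          # (H, W) before transformation
        "point_coords": torch.rand(3, 5, 2),  # BxNx2 batched point prompts
        "point_labels": torch.ones(3, 5),     # BxN labels for the point prompts
        "labels": torch.ones(3, 1).long(),    # class-label prompt added in this commit
    }
    # outputs = model([record])  # with a model built via sam_model_registry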
@@ -151,7 +163,6 @@ def forward(
         for image_record, curr_embedding in zip(batched_input, image_embeddings):
             if "point_coords" in image_record:
                 points = (image_record["point_coords"], image_record["point_labels"])
-                # raise NotImplementedError
             else:
                 points = None
             sparse_embeddings, dense_embeddings = self.prompt_encoder(
@@ -194,7 +205,6 @@ def forward(
     def postprocess_masks(
         self,
         masks: torch.Tensor,
-        # input_size: Tuple[int, ...],
         original_size: Tuple[int, ...],
     ) -> torch.Tensor:
         """
@@ -261,21 +271,23 @@ def preprocess(self, x: torch.Tensor, is_input=True) -> torch.Tensor:
         return x
 
 
-def _build_sam2pt5d(
+def _build_vista2pt5d(
     encoder_in_chans,
     encoder_embed_dim,
     encoder_depth,
     encoder_num_heads,
     encoder_global_attn_indexes,
     checkpoint=None,
     image_size=1024,
+    clip_class_label_prompt=False,
+    patch_embed_3d=False,
 ):
     prompt_embed_dim = 256
-    image_size = image_size  # TODO: Shall we try to adapt model to 512x512 ?
+    image_size = image_size
     vit_patch_size = 16
     image_embedding_size = image_size // vit_patch_size
-    sam = Samm2pt5D(
-        image_encoder=ImageEncoderViT(
+    sam = Vista2pt5D(
+        image_encoder=VistaImageEncoderViT(
             in_chans=encoder_in_chans,
             depth=encoder_depth,
             embed_dim=encoder_embed_dim,
@@ -289,12 +301,14 @@ def _build_sam2pt5d(
             global_attn_indexes=encoder_global_attn_indexes,
             window_size=14,
             out_chans=prompt_embed_dim,
+            patch_embed_3d=patch_embed_3d,
         ),
-        prompt_encoder=PromptEncoder(
+        prompt_encoder=VistaPromptEncoder(
             embed_dim=prompt_embed_dim,
             image_embedding_size=(image_embedding_size, image_embedding_size),
             input_image_size=(image_size, image_size),
             mask_in_chans=16,
+            clip_class_label_prompt=clip_class_label_prompt,
         ),
         mask_decoder=MaskDecoder(
             num_multimask_outputs=3,  # TODO: only predict one binary mask
@@ -315,22 +329,26 @@ def _build_sam2pt5d(
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
             state_dict = torch.load(f)
+
         if image_size == 1024:
             # we try to use all pretrained weights
             new_dict = state_dict
-            if encoder_in_chans != 3:
-                new_dict.pop("image_encoder.patch_embed.proj.weight")
         else:
             new_dict = {}
             for k, v in state_dict.items():
-                # skip weights in prompt_encoder and mask_decoder
-                if k.startswith("prompt_encoder") or k.startswith("mask_decoder"):
-                    continue
                 # skip weights in position embedding and learned relative positional embeddings
-                elif "pos_embed" in k or "attn.rel_pos" in k:
+                # due to the change of input size
+                if ("pos_embed" in k and k.startswith("image_encoder")) or (
+                    "attn.rel_pos" in k and k.startswith("image_encoder")
+                ):
                     continue
                 else:
                     new_dict[k] = v
+
+        if encoder_in_chans != 3:
+            new_dict.pop("image_encoder.patch_embed.proj.weight")
+            new_dict.pop("image_encoder.patch_embed.proj.bias")
+
         sam.load_state_dict(new_dict, strict=False)
         print(f"Load {len(new_dict)} keys from checkpoint {checkpoint}, current model has {len(sam.state_dict())} keys")
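The filtering above is also reordered: the patch-embed weights are now dropped for non-RGB inputs regardless of image size, and only resolution-dependent image-encoder weights are skipped. A small self-contained sketch of the new skip rule (key names illustrative, integers stand in for tensors):

    # Reproduce the new skip rule for image_size != 1024 on illustrative keys.
    state_dict = {
        "image_encoder.pos_embed": 1,                   # skipped: depends on input size
        "image_encoder.blocks.0.attn.rel_pos_h": 2,     # skipped: depends on input size
        "prompt_encoder.point_embeddings.0.weight": 3,  # kept (no longer filtered out)
        "mask_decoder.iou_token.weight": 4,             # kept (no longer filtered out)
    }
    new_dict = {
        k: v
        for k, v in state_dict.items()
        if not (("pos_embed" in k or "attn.rel_pos" in k) and k.startswith("image_encoder"))
    }
    assert set(new_dict) == {"prompt_encoder.point_embeddings.0.weight", "mask_decoder.iou_token.weight"}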

@@ -355,106 +373,67 @@ def _build_sam2pt5d(
         f"{sum(mask_decoder_params) * 1.e-6:.2f} M params in mask decoder."
     )
 
-    # comment to unfreeze all encoder layers
-    # for name, param in sam.named_parameters():
-    #     if name.startswith("image_encoder"):
-    #         if image_size == 1024:
-    #             if "pos_embed" in name or "patch_embed" in name or "blocks.0" in name:
-    #                 # we only retrain layers before blocks.1 in image_encoder
-    #                 continue
-    #             # if "pos_embed" in name or "patch_embed" in name:
-    #             #     # we only retrain pos_embed and patch_embed
-    #             #     continue
-    #         else:
-    #             if "pos_embed" in name or "attn.rel_pos" in name or \
-    #                 "patch_embed" in name or "blocks.0" in name or "neck" in name:
-    #                 # we only train pos_embed, patch_embed, blocks.0, attn.rel_pos (due res change)
-    #                 # and neck (a few conv layers for outputs) in image_encoder
-    #                 continue
-    #
-    #     # we freeze all other layers in image_encoder
-    #     param.requires_grad = False
-
     total_trainable_params = sum(p.numel() if p.requires_grad else 0 for p in sam.parameters())
     print(f"{sam.__class__.__name__} has {total_trainable_params * 1.e-6:.2f} M trainable params.")
     return sam
 
 
-def build_samm2pt5d_vit_h(checkpoint=None, image_size=1024, encoder_in_chans=3):
-    return _build_sam2pt5d(
+def build_vista2pt5d_vit_h(
+    checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False
+):
+    return _build_vista2pt5d(
         encoder_in_chans=encoder_in_chans,
         encoder_embed_dim=1280,
         encoder_depth=32,
         encoder_num_heads=16,
         encoder_global_attn_indexes=[7, 15, 23, 31],
         checkpoint=checkpoint,
         image_size=image_size,
+        clip_class_label_prompt=clip_class_label_prompt,
+        patch_embed_3d=patch_embed_3d,
     )
 
 
-def build_samm2pt5d_vit_l(checkpoint=None, image_size=1024, encoder_in_chans=3):
-    return _build_sam2pt5d(
+def build_vista2pt5d_vit_l(
+    checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False
+):
+    return _build_vista2pt5d(
         encoder_in_chans=encoder_in_chans,
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        image_size=image_size,
+        clip_class_label_prompt=clip_class_label_prompt,
+        patch_embed_3d=patch_embed_3d,
    )
 
 
-def build_samm2pt5d_vit_b(checkpoint=None, image_size=1024, encoder_in_chans=3):
-    return _build_sam2pt5d(
+def build_vista2pt5d_vit_b(
+    checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False
+):
+    return _build_vista2pt5d(
         encoder_in_chans=encoder_in_chans,
         encoder_embed_dim=768,
         encoder_depth=12,
         encoder_num_heads=12,
         encoder_global_attn_indexes=[2, 5, 8, 11],
         checkpoint=checkpoint,
         image_size=image_size,
+        clip_class_label_prompt=clip_class_label_prompt,
+        patch_embed_3d=patch_embed_3d,
     )
 
 
 sam_model_registry = {
-    "default": build_samm2pt5d_vit_h,
-    "vit_h": build_samm2pt5d_vit_h,
-    "vit_l": build_samm2pt5d_vit_l,
-    "vit_b": build_samm2pt5d_vit_b,
+    "default": build_vista2pt5d_vit_h,
+    "vit_h": build_vista2pt5d_vit_h,
+    "vit_l": build_vista2pt5d_vit_l,
+    "vit_b": build_vista2pt5d_vit_b,
 }
 
+
 if __name__ == "__main__":
-    model = build_samm2pt5d_vit_b()
+    model = build_vista2pt5d_vit_b()
     model.cuda()
-    #
-    # dummy_input = [{"image": torch.rand(3, 176, 345).cuda(), "original_size": (176, 345),
-    #                 "point_coords": torch.rand(3, 5, 2).cuda(), "point_labels": torch.ones(3, 5).cuda(),
-    #                 "labels": torch.ones(3, 1).long().cuda()},
-    #                {"image": torch.rand(3, 128, 365).cuda(), "original_size": (128, 365),
-    #                 "point_coords": torch.rand(1, 3, 2).cuda(), "point_labels": torch.ones(1, 3).cuda(),
-    #                 "labels": torch.ones(1, 1).long().cuda()}
-    #                ]
-    # # dummy_input = [{"image": torch.rand(3, 176, 345).cuda(), "original_size": (256, 512),
-    # #                 "point_coords": torch.rand(3, 5, 2).cuda(), "point_labels": torch.ones(3, 5).cuda()}]
-    # outputs = model(dummy_input)
-
-    # test if postprocessing can inverse preprocess
-    # path = "/home/pengfeig/Downloads/fffabebf-74fd3a1f-673b6b41-96ec0ac9-2ab69818.jpg"
-    # from PIL import Image
-    # import numpy as np
-    # import matplotlib.pyplot as plt
-    #
-    # image = np.array(Image.open(path)).transpose(2, 0, 1)[:, :365, :256].astype(np.float32)
-    # plt.imshow(image.transpose(1, 2, 0).astype(np.uint8))
-    # plt.show()
-    # dummy_tensor = torch.from_numpy(image).cuda()
-    # tmp = model.preprocess(dummy_tensor)
-    # plt.imshow(tmp.cpu().numpy().transpose(1, 2, 0).astype(np.uint8))
-    # plt.show()
-    # inverse_tensor = model.postprocess_masks(tmp.unsqueeze(0), (365, 256)).squeeze(0)
-    # print(torch.sum(torch.abs(inverse_tensor-dummy_tensor)))
-    # print("dummy_tensor", torch.min(dummy_tensor), torch.max(dummy_tensor))
-    # print("inverse_tensor", torch.min(inverse_tensor), torch.max(inverse_tensor))
-    # plt.imshow(inverse_tensor.cpu().numpy().transpose(1, 2, 0).astype(np.uint8))
-    # plt.show()
-    # print()
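With the rename, the remaining smoke test reduces to a registry call; a runnable sketch with the two new flags (checkpoint hypothetical, CUDA needed only for .cuda()):

    from lib.model.vista_point_2pt5.model_2pt5 import sam_model_registry

    model = sam_model_registry["vit_b"](
        checkpoint=None,                # or a path to pretrained SAM/VISTA weights
        image_size=1024,
        encoder_in_chans=3,
        clip_class_label_prompt=False,  # forwarded to VistaPromptEncoder
        patch_embed_3d=False,           # forwarded to VistaImageEncoderViT
    )
    print(type(model).__name__)  # -> Vista2pt5D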
