+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
 from functools import partial
 from typing import Any, Dict, List, Tuple
 
+import monai
 import torch
+from segment_anything.modeling import TwoWayTransformer
+from segment_anything.modeling.mask_decoder import MaskDecoder
 from torch import nn
 from torch.nn import functional as F
 
-from .segment_anything.modeling import TwoWayTransformer
-from .segment_anything.modeling.image_encoder import ImageEncoderViT
-from .segment_anything.modeling.mask_decoder import MaskDecoder
-from .segment_anything.modeling.prompt_encoder import PromptEncoder
+from .vista_2pt5_image_encoder import VistaImageEncoderViT
+from .vista_2pt5_prompt_encoder import VistaPromptEncoder
 
 
-class Samm2pt5D(nn.Module):
+class Vista2pt5D(nn.Module):
     mask_threshold: float = 0.5
     image_format: str = "RGB"
 
     def __init__(
         self,
-        image_encoder: ImageEncoderViT,
-        prompt_encoder: PromptEncoder,
+        image_encoder: VistaImageEncoderViT,
+        prompt_encoder: VistaPromptEncoder,
         mask_decoder: MaskDecoder,
         pixel_mean: List[float] = [123.675, 116.28, 103.53],
         pixel_std: List[float] = [58.395, 57.12, 57.375],
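
For context, the pixel_mean / pixel_std defaults above are the standard ImageNet channel statistics scaled to the 0-255 range; SAM-style preprocessing normalizes each channel with them before padding to a square input. A minimal sketch of that normalization (the input shape is illustrative):

import torch

# Defaults from the constructor above (ImageNet statistics in the 0-255 range).
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

x = torch.rand(3, 480, 640) * 255.0  # hypothetical CxHxW input image
x = (x - pixel_mean) / pixel_std     # per-channel normalization
# The model's preprocess step then pads H and W up to the square encoder input size.
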
@@ -67,7 +79,6 @@ def get_mask_prediction(
         for image_record, curr_embedding in zip(batched_input, image_embeddings):
             if "point_coords" in image_record:
                 points = (image_record["point_coords"], image_record["point_labels"])
-                # raise NotImplementedError
             else:
                 points = None
             sparse_embeddings, dense_embeddings = self.prompt_encoder(
@@ -86,7 +97,6 @@ def get_mask_prediction(
 
             high_res_masks = self.postprocess_masks(
                 low_res_masks,
-                # input_size=image_record["image"].shape[-2:],
                 original_size=image_record["original_size"],
             )
             masks = high_res_masks > self.mask_threshold
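
The thresholding above follows postprocess_masks, which upsamples the decoder's low-resolution logits back to the original image size. A minimal sketch of the equivalent operation, with shapes assumed for illustration:

import torch
from torch.nn import functional as F

low_res_masks = torch.randn(2, 1, 256, 256)  # hypothetical decoder output, Bx1xHxW
original_size = (512, 512)                   # hypothetical original image size

# Upsample the logits to the original resolution, then binarize with the
# class-level mask_threshold (0.5 above).
high_res_masks = F.interpolate(low_res_masks, size=original_size, mode="bilinear", align_corners=False)
masks = high_res_masks > 0.5
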
@@ -124,6 +134,8 @@ def forward(
             input frame of the model.
           'point_labels': (torch.Tensor) Batched labels for point prompts,
             with shape BxN.
+          'labels': (torch.Tensor) Batched class-label prompts,
+            with shape BxN.
           'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
             Already transformed to the input frame of the model.
           'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
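
Putting the documented prompt keys together, one record of batched_input might look like the sketch below; the shapes follow the BxNx2 / BxN conventions above and all values are illustrative:

import torch

image_record = {
    "image": torch.rand(3, 176, 345).cuda(),     # CxHxW, already in the model's input frame
    "original_size": (176, 345),                 # (H, W) before any transform
    "point_coords": torch.rand(3, 5, 2).cuda(),  # BxNx2 point prompts
    "point_labels": torch.ones(3, 5).cuda(),     # BxN point labels
    "labels": torch.ones(3, 1).long().cuda(),    # class-label prompts
}
batched_input = [image_record]
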
@@ -151,7 +163,6 @@ def forward(
         for image_record, curr_embedding in zip(batched_input, image_embeddings):
             if "point_coords" in image_record:
                 points = (image_record["point_coords"], image_record["point_labels"])
-                # raise NotImplementedError
             else:
                 points = None
             sparse_embeddings, dense_embeddings = self.prompt_encoder(
@@ -194,7 +205,6 @@ def forward(
     def postprocess_masks(
         self,
         masks: torch.Tensor,
-        # input_size: Tuple[int, ...],
         original_size: Tuple[int, ...],
     ) -> torch.Tensor:
         """
@@ -261,21 +271,23 @@ def preprocess(self, x: torch.Tensor, is_input=True) -> torch.Tensor:
         return x
 
 
-def _build_sam2pt5d(
+def _build_vista2pt5d(
     encoder_in_chans,
     encoder_embed_dim,
     encoder_depth,
     encoder_num_heads,
     encoder_global_attn_indexes,
     checkpoint=None,
     image_size=1024,
+    clip_class_label_prompt=False,
+    patch_embed_3d=False,
 ):
     prompt_embed_dim = 256
-    image_size = image_size  # TODO: Shall we try to adapt model to 512x512?
+    image_size = image_size
     vit_patch_size = 16
     image_embedding_size = image_size // vit_patch_size
-    sam = Samm2pt5D(
-        image_encoder=ImageEncoderViT(
+    sam = Vista2pt5D(
+        image_encoder=VistaImageEncoderViT(
             in_chans=encoder_in_chans,
             depth=encoder_depth,
             embed_dim=encoder_embed_dim,
@@ -289,12 +301,14 @@ def _build_sam2pt5d(
             global_attn_indexes=encoder_global_attn_indexes,
             window_size=14,
             out_chans=prompt_embed_dim,
+            patch_embed_3d=patch_embed_3d,
         ),
-        prompt_encoder=PromptEncoder(
+        prompt_encoder=VistaPromptEncoder(
             embed_dim=prompt_embed_dim,
             image_embedding_size=(image_embedding_size, image_embedding_size),
             input_image_size=(image_size, image_size),
             mask_in_chans=16,
+            clip_class_label_prompt=clip_class_label_prompt,
         ),
         mask_decoder=MaskDecoder(
             num_multimask_outputs=3,  # TODO: only predict one binary mask
@@ -315,22 +329,26 @@ def _build_sam2pt5d(
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
             state_dict = torch.load(f)
+
         if image_size == 1024:
             # we try to use all pretrained weights
             new_dict = state_dict
-            if encoder_in_chans != 3:
-                new_dict.pop("image_encoder.patch_embed.proj.weight")
         else:
             new_dict = {}
             for k, v in state_dict.items():
-                # skip weights in prompt_encoder and mask_decoder
-                if k.startswith("prompt_encoder") or k.startswith("mask_decoder"):
-                    continue
                 # skip weights in position embedding and learned relative positional embeddings
-                elif "pos_embed" in k or "attn.rel_pos" in k:
+                # due to the change of input size
+                if ("pos_embed" in k and k.startswith("image_encoder")) or (
+                    "attn.rel_pos" in k and k.startswith("image_encoder")
+                ):
                     continue
                 else:
                     new_dict[k] = v
+
+        if encoder_in_chans != 3:
+            new_dict.pop("image_encoder.patch_embed.proj.weight")
+            new_dict.pop("image_encoder.patch_embed.proj.bias")
+
         sam.load_state_dict(new_dict, strict=False)
         print(f"Loaded {len(new_dict)} keys from checkpoint {checkpoint}; current model has {len(sam.state_dict())} keys")
 
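
In short: at a non-default input size the image encoder's absolute and relative position embeddings no longer match the checkpoint and are skipped, and a non-RGB channel count additionally drops the patch-embedding projection; load_state_dict(..., strict=False) then leaves every skipped weight at its fresh initialization. A minimal sketch of the same filtering pattern on a hypothetical state dict:

import torch

state_dict = torch.load("sam_checkpoint.pth")  # hypothetical checkpoint path

skipped = ("pos_embed", "attn.rel_pos")
new_dict = {
    k: v
    for k, v in state_dict.items()
    if not (k.startswith("image_encoder") and any(s in k for s in skipped))
}
# strict=False tolerates the missing keys, which keep their random initialization:
# sam.load_state_dict(new_dict, strict=False)
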
@@ -355,106 +373,67 @@ def _build_sam2pt5d(
355373 f"{ sum (mask_decoder_params ) * 1.e-6 :.2f} M params in mask decoder."
356374 )
357375
358- # comment to unfreeze all encoder layers
359- # for name, param in sam.named_parameters():
360- # if name.startswith("image_encoder"):
361- # if image_size == 1024:
362- # if "pos_embed" in name or "patch_embed" in name or "blocks.0" in name:
363- # # we only retrain layers before blocks.1 in image_encoder
364- # continue
365- # # if "pos_embed" in name or "patch_embed" in name:
366- # # # we only retrain pos_embed and patch_embed
367- # # continue
368- # else:
369- # if "pos_embed" in name or "attn.rel_pos" in name or \
370- # "patch_embed" in name or "blocks.0" in name or "neck" in name:
371- # # we only train pos_embed, patch_embed, blocks.0, attn.rel_pos (due res change)
372- # # and neck (a few conv layers for outputs) in image_encoder
373- # continue
374- #
375- # # we freeze all other layers in image_encoder
376- # param.requires_grad = False
377-
378376 total_trainable_params = sum (p .numel () if p .requires_grad else 0 for p in sam .parameters ())
379377 print (f"{ sam .__class__ .__name__ } has { total_trainable_params * 1.e-6 :.2f} M trainable params." )
380378 return sam
381379
382380
-def build_samm2pt5d_vit_h(checkpoint=None, image_size=1024, encoder_in_chans=3):
-    return _build_sam2pt5d(
+def build_vista2pt5d_vit_h(
+    checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False
+):
+    return _build_vista2pt5d(
         encoder_in_chans=encoder_in_chans,
         encoder_embed_dim=1280,
         encoder_depth=32,
         encoder_num_heads=16,
         encoder_global_attn_indexes=[7, 15, 23, 31],
         checkpoint=checkpoint,
         image_size=image_size,
+        clip_class_label_prompt=clip_class_label_prompt,
+        patch_embed_3d=patch_embed_3d,
     )
 
 
-def build_samm2pt5d_vit_l(checkpoint=None, image_size=1024, encoder_in_chans=3):
-    return _build_sam2pt5d(
+def build_vista2pt5d_vit_l(
+    checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False
+):
+    return _build_vista2pt5d(
         encoder_in_chans=encoder_in_chans,
         encoder_embed_dim=1024,
         encoder_depth=24,
         encoder_num_heads=16,
         encoder_global_attn_indexes=[5, 11, 17, 23],
         checkpoint=checkpoint,
         image_size=image_size,
+        clip_class_label_prompt=clip_class_label_prompt,
+        patch_embed_3d=patch_embed_3d,
     )
 
 
-def build_samm2pt5d_vit_b(checkpoint=None, image_size=1024, encoder_in_chans=3):
-    return _build_sam2pt5d(
+def build_vista2pt5d_vit_b(
+    checkpoint=None, image_size=1024, encoder_in_chans=3, clip_class_label_prompt=False, patch_embed_3d=False
+):
+    return _build_vista2pt5d(
         encoder_in_chans=encoder_in_chans,
         encoder_embed_dim=768,
         encoder_depth=12,
         encoder_num_heads=12,
         encoder_global_attn_indexes=[2, 5, 8, 11],
         checkpoint=checkpoint,
         image_size=image_size,
+        clip_class_label_prompt=clip_class_label_prompt,
+        patch_embed_3d=patch_embed_3d,
     )
 
 
 sam_model_registry = {
-    "default": build_samm2pt5d_vit_h,
-    "vit_h": build_samm2pt5d_vit_h,
-    "vit_l": build_samm2pt5d_vit_l,
-    "vit_b": build_samm2pt5d_vit_b,
+    "default": build_vista2pt5d_vit_h,
+    "vit_h": build_vista2pt5d_vit_h,
+    "vit_l": build_vista2pt5d_vit_l,
+    "vit_b": build_vista2pt5d_vit_b,
 }
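
A typical way to build a model from this registry; the checkpoint path and the 2.5D channel count below are illustrative assumptions, not values fixed by this file:

model = sam_model_registry["vit_b"](
    checkpoint="sam_vit_b.pth",  # hypothetical pretrained SAM checkpoint
    image_size=512,
    encoder_in_chans=27,         # e.g. stacked neighboring slices for 2.5D input (assumed)
    clip_class_label_prompt=False,
    patch_embed_3d=False,
)
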
 
+
 if __name__ == "__main__":
-    model = build_samm2pt5d_vit_b()
+    model = build_vista2pt5d_vit_b()
     model.cuda()
-    #
-    # dummy_input = [{"image": torch.rand(3, 176, 345).cuda(), "original_size": (176, 345),
-    #                 "point_coords": torch.rand(3, 5, 2).cuda(), "point_labels": torch.ones(3, 5).cuda(),
-    #                 "labels": torch.ones(3, 1).long().cuda()},
-    #                {"image": torch.rand(3, 128, 365).cuda(), "original_size": (128, 365),
-    #                 "point_coords": torch.rand(1, 3, 2).cuda(), "point_labels": torch.ones(1, 3).cuda(),
-    #                 "labels": torch.ones(1, 1).long().cuda()}
-    #                ]
-    # # dummy_input = [{"image": torch.rand(3, 176, 345).cuda(), "original_size": (256, 512),
-    # #                 "point_coords": torch.rand(3, 5, 2).cuda(), "point_labels": torch.ones(3, 5).cuda()}]
-    # outputs = model(dummy_input)
-
-    # test if postprocessing can inverse preprocess
-    # path = "/home/pengfeig/Downloads/fffabebf-74fd3a1f-673b6b41-96ec0ac9-2ab69818.jpg"
-    # from PIL import Image
-    # import numpy as np
-    # import matplotlib.pyplot as plt
-    #
-    # image = np.array(Image.open(path)).transpose(2, 0, 1)[:, :365, :256].astype(np.float32)
-    # plt.imshow(image.transpose(1, 2, 0).astype(np.uint8))
-    # plt.show()
-    # dummy_tensor = torch.from_numpy(image).cuda()
-    # tmp = model.preprocess(dummy_tensor)
-    # plt.imshow(tmp.cpu().numpy().transpose(1, 2, 0).astype(np.uint8))
-    # plt.show()
-    # inverse_tensor = model.postprocess_masks(tmp.unsqueeze(0), (365, 256)).squeeze(0)
-    # print(torch.sum(torch.abs(inverse_tensor - dummy_tensor)))
-    # print("dummy_tensor", torch.min(dummy_tensor), torch.max(dummy_tensor))
-    # print("inverse_tensor", torch.min(inverse_tensor), torch.max(inverse_tensor))
-    # plt.imshow(inverse_tensor.cpu().numpy().transpose(1, 2, 0).astype(np.uint8))
-    # plt.show()
-    # print()