@@ -220,6 +220,7 @@ def train_sam(
220220 verify_n_labels_in_loader : Optional [int ] = 50 ,
221221 box_distortion_factor : Optional [float ] = 0.025 ,
222222 overwrite_training : bool = True ,
223+ strict_decoder_loading : bool = True ,
223224 ** model_kwargs ,
224225) -> None :
225226 """Run training for a SAM model.
@@ -265,6 +266,10 @@ def train_sam(
265266 overwrite_training: Whether to overwrite the trained model stored at the same location.
266267 By default, overwrites the trained model at each run.
267268 If set to 'False', it will avoid retraining the model if the previous run was completed.
269+ strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
270+ exactly matches the instance segmentation decoder. Decoders may have a mismatch in the number of
271+ output channels if they were pre-trained for a different task. If set to 'False', decoders with a
272+ different number of output channels can be loaded; the output channels will be re-initialized.
268273 model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
269274 """
270275 with _filter_warnings (ignore_warnings ):
@@ -294,7 +299,8 @@ def train_sam(
294299
295300 # Get the UNETR.
296301 unetr = get_unetr (
297- image_encoder = model .sam .image_encoder , decoder_state = state .get ("decoder_state" , None ), device = device ,
302+ image_encoder = model .sam .image_encoder , decoder_state = state .get ("decoder_state" , None ),
303+ device = device , flexible_load_checkpoint = not strict_decoder_loading ,
298304 )
299305
300306 # Get the parameters for SAM and the decoder from UNETR.
@@ -435,6 +441,7 @@ def train_instance_segmentation(
435441 peft_kwargs : Optional [Dict ] = None ,
436442 ignore_warnings : bool = True ,
437443 overwrite_training : bool = True ,
444+ strict_decoder_loading : bool = True ,
438445 ** model_kwargs ,
439446) -> None :
440447 """Train a UNETR for instance segmentation using the SAM encoder as backbone.
@@ -481,6 +488,10 @@ def train_instance_segmentation(
481488 overwrite_training: Whether to overwrite the trained model stored at the same location.
482489 By default, overwrites the trained model at each run.
483490 If set to 'False', it will avoid retraining the model if the previous run was completed.
491+ strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
492+ exactly matches the instance segmentation decoder. Decoders may have a mismatch in the number of
493+ output channels if they were pre-trained for a different task. If set to 'False', decoders with a
494+ different number of output channels can be loaded; the output channels will be re-initialized.
484495 model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
485496 """
486497
@@ -498,7 +509,8 @@ def train_instance_segmentation(
498509 )
499510 device = get_device (device )
500511 model = get_unetr (
501- image_encoder = sam_model .sam .image_encoder , decoder_state = state .get ("decoder_state" , None ), device = device ,
512+ image_encoder = sam_model .sam .image_encoder , decoder_state = state .get ("decoder_state" , None ),
513+ device = device , flexible_load_checkpoint = not strict_decoder_loading ,
502514 )
503515
504516 optimizer , scheduler = _get_optimizer_and_scheduler (
0 commit comments