Skip to content

Commit ff82940

Browse files
Merge pull request #1175 from computational-cell-analytics/dev
Dev
2 parents 3741317 + b317328 commit ff82940

3 files changed

Lines changed: 30 additions & 7 deletions

File tree

micro_sam/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.7.4"
1+
__version__ = "1.7.5"

micro_sam/instance_segmentation.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -789,16 +789,21 @@ def get_unetr(
789789
unetr_state_dict = unetr.state_dict()
790790
for k, v in unetr_state_dict.items():
791791
if not k.startswith("encoder"):
792-
if flexible_load_checkpoint: # Whether allow reinitalization of params, if not found.
792+
# Whether to allow reinitialization of params, if not found or mismatched.
793+
if flexible_load_checkpoint:
793794
if k in decoder_state: # First check whether the key is available in the provided decoder state.
794-
unetr_state_dict[k] = decoder_state[k]
795+
if v.shape != decoder_state[k].shape: # Then check if the sizes mismatch.
796+
warnings.warn(f"Shape of '{k}' did not match. Hence, we reinitialize it.")
797+
unetr_state_dict[k] = v
798+
else:
799+
unetr_state_dict[k] = decoder_state[k]
795800
else: # Otherwise, allow it to initialize it.
796801
warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
797802
unetr_state_dict[k] = v
798803

799-
else: # Whether be strict on finding the parameter in the decoder state.
804+
else: # Be strict on finding the parameter in the decoder state.
800805
if k not in decoder_state:
801-
raise RuntimeError(f"The parameters for '{k}' could not be found.")
806+
raise RuntimeError(f"The parameters for '{k}' could not be found or has a size mismatch.")
802807
unetr_state_dict[k] = decoder_state[k]
803808

804809
unetr.load_state_dict(unetr_state_dict)
@@ -1143,6 +1148,12 @@ def generate(
11431148
else:
11441149
if halo is None:
11451150
raise ValueError("You must pass a value for halo if tile_shape is given.")
1151+
1152+
# Shards are not thread-safe for parallel writing! So if we have shards we have to use them for tiling.
1153+
# This is ok in terms of efficiency as GPU tiles are small; shards should still be manageable for the watershed.
1154+
if isinstance(segmentation, zarr.Array) and getattr(segmentation, "shards", None) is not None:
1155+
tile_shape = segmentation.shards
1156+
11461157
segmentation = _watershed_from_center_and_boundary_distances_parallel(
11471158
center_distances=self._center_distances,
11481159
boundary_distances=self._boundary_distances,

micro_sam/training/training.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def train_sam(
220220
verify_n_labels_in_loader: Optional[int] = 50,
221221
box_distortion_factor: Optional[float] = 0.025,
222222
overwrite_training: bool = True,
223+
strict_decoder_loading: bool = True,
223224
**model_kwargs,
224225
) -> None:
225226
"""Run training for a SAM model.
@@ -265,6 +266,10 @@ def train_sam(
265266
overwrite_training: Whether to overwrite the trained model stored at the same location.
266267
By default, overwrites the trained model at each run.
267268
If set to 'False', it will avoid retraining the model if the previous run was completed.
269+
strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
270+
exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output
271+
channels if they were pre-trained for a different task. If set to False, decoders with a different
272+
output dimension can be loaded; the output channels will be re-initialized.
268273
model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
269274
"""
270275
with _filter_warnings(ignore_warnings):
@@ -294,7 +299,8 @@ def train_sam(
294299

295300
# Get the UNETR.
296301
unetr = get_unetr(
297-
image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
302+
image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None),
303+
device=device, flexible_load_checkpoint=not strict_decoder_loading,
298304
)
299305

300306
# Get the parameters for SAM and the decoder from UNETR.
@@ -435,6 +441,7 @@ def train_instance_segmentation(
435441
peft_kwargs: Optional[Dict] = None,
436442
ignore_warnings: bool = True,
437443
overwrite_training: bool = True,
444+
strict_decoder_loading: bool = True,
438445
**model_kwargs,
439446
) -> None:
440447
"""Train a UNETR for instance segmentation using the SAM encoder as backbone.
@@ -481,6 +488,10 @@ def train_instance_segmentation(
481488
overwrite_training: Whether to overwrite the trained model stored at the same location.
482489
By default, overwrites the trained model at each run.
483490
If set to 'False', it will avoid retraining the model if the previous run was completed.
491+
strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
492+
exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output
493+
channels if they were pre-trained for a different task. If set to False, decoders with a different
494+
output dimension can be loaded; the output channels will be re-initialized.
484495
model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
485496
"""
486497

@@ -498,7 +509,8 @@ def train_instance_segmentation(
498509
)
499510
device = get_device(device)
500511
model = get_unetr(
501-
image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
512+
image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None),
513+
device=device, flexible_load_checkpoint=not strict_decoder_loading,
502514
)
503515

504516
optimizer, scheduler = _get_optimizer_and_scheduler(

0 commit comments

Comments
 (0)