
Commit 65badb2

Merge pull request #1171 from computational-cell-analytics/dev

Bump version

2 parents a4a4eaf + 71a4053, commit 65badb2

11 files changed: 128 additions & 43 deletions


doc/faq.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -187,6 +187,11 @@ You can then use those models with the custom checkpoint option, see answer 15 f
 ### 18. I would like to evaluate the instance segmentation quantitatively. Can you suggest how to do that?
 `micro-sam` supports a `micro_sam.evaluate` CLI, which computes the mean segmentation accuracy (introduced in the Pascal VOC challenge) of the predicted instance segmentation against the corresponding ground-truth annotations. Please see our paper (`Methods` -> `Inference and Evaluation`) and `$ micro_sam.evaluate -h` for more details about the evaluation CLI.
 
+### 19. I get `RuntimeError: GET was unable to find an engine to execute this computation` on a V100 GPU (*or any older GPU*).
+This is a known issue for the combination of older-generation GPUs (e.g. V100s) and PyTorch compiled with a recent CUDA toolkit (e.g. CUDA 12.9 with PyTorch 2.8 has been confirmed to throw this error on V100s).
+Here is what you can do to solve this issue:
+- Use a PyTorch/CUDA build that is known to work with V100s, for example CUDA 12.1 or 11.8 with a compatible PyTorch version (please check your installed CUDA drivers).
+- Run on the CPU (slower, but works).
 
 ## Fine-tuning questions
```
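Note: to check whether a setup falls into the affected combination described in the new FAQ entry, a small diagnostic along these lines can help. This is an illustrative sketch using standard PyTorch APIs, not part of the FAQ text itself:

```python
import torch

# Report the GPU architecture and the CUDA toolkit PyTorch was built against.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU: {torch.cuda.get_device_name()} (compute capability {major}.{minor})")
    print(f"PyTorch {torch.__version__}, built against CUDA {torch.version.cuda}")
    # V100s report compute capability 7.0; per the FAQ entry above, recent
    # builds (e.g. PyTorch 2.8 with CUDA 12.9) can fail on them at runtime.
else:
    print("No CUDA device detected; computation will fall back to the CPU.")
```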

micro_sam/__version__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1 +1 @@
-__version__ = "1.7.1"
+__version__ = "1.7.2"
```

micro_sam/automatic_segmentation.py

Lines changed: 15 additions & 9 deletions
```diff
@@ -127,10 +127,9 @@ def automatic_tracking(
         The lineages representing cell divisions, stored as a dictionary.
     """
     # Load the input image file.
-    if isinstance(input_path, np.ndarray):
-        image_data = input_path
-    else:
-        image_data = util.load_image_data(input_path, key)
+    # We assume that it has to be read from file if it is a str or pathlike.
+    # Otherwise we assume it is a numpy array like object.
+    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path
 
     if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
         raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")
```
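Note: the one-liner replaces the explicit `if`/`else` with a single dispatch on the input type, so file paths and in-memory arrays flow through the same line. A minimal self-contained sketch of the idiom, with a stand-in loader (since `micro_sam.util.load_image_data` is not shown in this diff, a simple `imageio` read is assumed here):

```python
import os

import numpy as np
import imageio.v3 as imageio


def load_image_data(path, key=None):
    # Stand-in for micro_sam.util.load_image_data; handles plain image files only.
    return imageio.imread(path)


def resolve_input(input_path, key=None):
    # Read from file for str/os.PathLike inputs, pass array-likes through unchanged.
    return load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path


volume = resolve_input(np.zeros((8, 64, 64)))  # array input: passed through as-is
# volume = resolve_input("stack.tif")          # hypothetical path: read from disk
```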
```diff
@@ -168,7 +167,9 @@ def automatic_instance_segmentation(
     input_path: Union[Union[os.PathLike, str], np.ndarray],
     output_path: Optional[Union[os.PathLike, str]] = None,
     embedding_path: Optional[Union[os.PathLike, str]] = None,
+    mask_path: Optional[Union[Union[os.PathLike, str], np.ndarray]] = None,
     key: Optional[str] = None,
+    mask_key: Optional[str] = None,
     ndim: Optional[int] = None,
     tile_shape: Optional[Tuple[int, int]] = None,
     halo: Optional[Tuple[int, int]] = None,
```
```diff
@@ -187,8 +188,10 @@ def automatic_instance_segmentation(
             or a container file (e.g. hdf5 or zarr).
         output_path: The output path where the instance segmentations will be saved.
         embedding_path: The path where the embeddings are cached already / will be saved.
+        mask_path: The path to an optional foreground mask. Areas outside of the foreground will not be processed.
         key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
             or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
+        mask_key: The key to the (optional) foreground mask.
         ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
             If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
         tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
```
```diff
@@ -212,11 +215,9 @@ def automatic_instance_segmentation(
         print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
         return
 
-    # Load the input image file.
-    if isinstance(input_path, np.ndarray):
-        image_data = input_path
-    else:
-        image_data = util.load_image_data(input_path, key)
+    # We assume that it has to be read from file if it is a str or pathlike.
+    # Otherwise we assume it is a numpy array like object.
+    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path
 
     ndim = image_data.ndim if ndim is None else ndim
 
```
```diff
@@ -244,6 +245,11 @@ def automatic_instance_segmentation(
         generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
     initialize_kwargs["batch_size"] = batch_size
 
+    # Load the mask defining foreground if it was given.
+    if mask_path is not None:
+        mask = util.load_image_data(mask_path, mask_key) if isinstance(mask_path, (str, os.PathLike)) else mask_path
+        initialize_kwargs["mask"] = mask
+
     segmenter.initialize(**initialize_kwargs)
     instances = segmenter.generate(**generate_kwargs)
```
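Note: a hedged sketch of how the new mask arguments could be used. Only `mask_path`/`mask_key` come from this diff; the `get_predictor_and_segmenter` helper and all file names are assumptions:

```python
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

# Assumed helper from the same module; loads the SAM model and a matching segmenter.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")

# The foreground mask can be a file path (with a key for hdf5/zarr containers)
# or an array; regions outside the mask are not processed.
automatic_instance_segmentation(
    predictor=predictor, segmenter=segmenter,
    input_path="cells.tif",                             # hypothetical input image
    output_path="instances.tif",                        # hypothetical output file
    mask_path="annotations.h5", mask_key="foreground",  # hypothetical mask container + key
)
```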

micro_sam/instance_segmentation.py

Lines changed: 70 additions & 20 deletions
```diff
@@ -4,17 +4,25 @@
 """
 
 import os
+import shutil
+import tempfile
 import warnings
 from abc import ABC
+from contextlib import contextmanager
 from copy import deepcopy
 from collections import OrderedDict
 from typing import Any, Dict, Literal, List, Optional, Tuple, Union
 
-import vigra
 import numpy as np
+import zarr
 from skimage.measure import regionprops
 from skimage.segmentation import find_boundaries
 
+try:
+    import fastfilters as filter_impl
+except ImportError:
+    import vigra.filters as filter_impl
+
 import torch
 from torchvision.ops.boxes import batched_nms, box_area
 
@@ -23,6 +31,8 @@
 
 import elf.parallel as parallel_impl
 from elf.parallel.filters import apply_filter
+from elf.wrapper.base import MultiTransformationWrapper
+from elf.wrapper.generic import ThresholdWrapper
 
 from nifty.tools import blocking
```

```diff
@@ -853,6 +863,23 @@ def get_predictor_and_decoder(
     return predictor, decoder
 
 
+@contextmanager
+def _array_or_zarr(shape, dtype, chunks, use_zarr=False):
+    if not use_zarr:
+        yield np.zeros(shape, dtype=dtype)
+        return
+
+    tmpdir = tempfile.mkdtemp(prefix="tmp-zarr-")
+    try:
+        store_path = os.path.join(tmpdir, "tmp.zarr")
+        root = zarr.open_group(store_path, mode="w")
+        arr = root.create_dataset(name="data", shape=shape, dtype=dtype, chunks=chunks)
+        yield arr
+
+    finally:
+        shutil.rmtree(tmpdir, ignore_errors=True)
+
+
 def _watershed_from_center_and_boundary_distances_parallel(
     center_distances,
     boundary_distances,
```
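Note: going by the definition above, the helper yields a plain numpy array by default and a chunked, on-disk zarr array when `use_zarr=True`, removing the temporary store when the context exits. A usage sketch (shapes and chunking are illustrative):

```python
import numpy as np

# With use_zarr=True the buffer lives in a temporary zarr store on disk,
# which keeps peak RAM low for large label volumes.
with _array_or_zarr(shape=(64, 1024, 1024), dtype="uint64", chunks=(16, 256, 256), use_zarr=True) as markers:
    markers[0, :128, :128] = 1        # writes go through to the temporary store
    unique_ids = np.unique(markers[0])  # reads come back as numpy arrays
# On exit the temporary directory (and the zarr store inside it) is deleted.
```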
```diff
@@ -866,6 +893,8 @@ def _watershed_from_center_and_boundary_distances_parallel(
     halo,
     n_threads,
     verbose=False,
+    optimize_memory=False,
+    segmentation=None,
 ):
     center_distances = apply_filter(
         center_distances, "gaussianSmoothing", sigma=distance_smoothing,
@@ -876,30 +905,45 @@
         block_shape=tile_shape, n_threads=n_threads
     )
 
-    fg_mask = foreground_map > foreground_threshold
+    fg_mask = ThresholdWrapper(foreground_map, foreground_threshold, operator=np.greater)
 
-    marker_map = np.logical_and(
-        center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold
+    marker_map = MultiTransformationWrapper(
+        np.logical_and,
+        ThresholdWrapper(center_distances, center_distance_threshold, operator=np.less),
+        ThresholdWrapper(boundary_distances, boundary_distance_threshold, operator=np.less),
     )
-    marker_map[~fg_mask] = 0
+    marker_map = MultiTransformationWrapper(np.logical_and, marker_map, fg_mask)
 
-    markers = np.zeros(marker_map.shape, dtype="uint64")
-    markers = parallel_impl.label(
-        marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose,
-    )
+    with _array_or_zarr(marker_map.shape, dtype="uint64", chunks=tile_shape, use_zarr=optimize_memory) as markers:
+        markers = parallel_impl.label(
+            marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose,
+        )
 
-    seg = np.zeros_like(markers, dtype="uint64")
-    seg = parallel_impl.seeded_watershed(
-        boundary_distances, seeds=markers, out=seg, block_shape=tile_shape,
-        halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask,
-    )
+        if segmentation is None:
+            segmentation = np.zeros(markers.shape, dtype="uint64")
+        segmentation = parallel_impl.seeded_watershed(
+            boundary_distances, seeds=markers, out=segmentation, block_shape=tile_shape,
+            halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask,
+        )
 
-    out = np.zeros_like(seg, dtype="uint64")
-    out = parallel_impl.size_filter(
-        seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose
-    )
+    if min_size > 0:
+        segmentation = parallel_impl.size_filter(
+            segmentation, out=segmentation, min_size=min_size,
+            block_shape=tile_shape, n_threads=n_threads, verbose=verbose
+        )
 
-    return out
+    return segmentation
+
+
+def _apply_smoothing(foreground, foreground_smoothing, tile_shape, n_threads):
+    if tile_shape is None:
+        foreground = filter_impl.gaussianSmoothing(foreground, foreground_smoothing)
+    else:
+        foreground = apply_filter(
+            foreground, "gaussianSmoothing", sigma=foreground_smoothing,
+            block_shape=tile_shape, n_threads=n_threads
+        )
+    return foreground
 
 
 class InstanceSegmentationWithDecoder:
```
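Note: the rewrite swaps eagerly materialized boolean arrays (`foreground_map > threshold`, `np.logical_and(...)`) for lazy `elf` wrappers that evaluate per block inside the blockwise label/watershed calls, so no full-size intermediate masks are allocated. A conceptual stand-in for what `ThresholdWrapper` does (a sketch, not `elf`'s actual implementation):

```python
import numpy as np


class LazyThreshold:
    """Conceptual stand-in for elf.wrapper.generic.ThresholdWrapper: the
    comparison runs per requested block instead of over the full volume."""

    def __init__(self, volume, threshold, operator=np.greater):
        self.volume, self.threshold, self.operator = volume, threshold, operator
        self.shape = volume.shape
        self.dtype = np.dtype("bool")

    def __getitem__(self, index):
        # Only the requested block is loaded and compared.
        return self.operator(self.volume[index], self.threshold)


data = np.random.rand(4, 512, 512)
mask = LazyThreshold(data, 0.5)
block = mask[0, :64, :64]  # boolean block computed on access, nothing precomputed
```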
```diff
@@ -1042,6 +1086,8 @@ def generate(
         tile_shape: Optional[Tuple[int, int]] = None,
         halo: Optional[Tuple[int, int]] = None,
         n_threads: Optional[int] = None,
+        optimize_memory: bool = False,
+        segmentation: Optional[np.ndarray] = None,
     ) -> Union[List[Dict[str, Any]], np.ndarray]:
         """Generate instance segmentation for the currently initialized image.
 
@@ -1067,6 +1113,8 @@ def generate(
                 If not given then post-processing will not be parallelized.
             halo: Halo for parallel post-processing. See also `tile_shape`.
             n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
+            optimize_memory: Whether to reduce the memory consumption by storing intermediate results in temporary files.
+            segmentation: Optional pre-allocated array for the segmentation result.
 
         Returns:
             The segmentation masks.
@@ -1075,7 +1123,7 @@ def generate(
             raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
 
         if foreground_smoothing > 0:
-            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
+            foreground = _apply_smoothing(self._foreground, foreground_smoothing, tile_shape, n_threads)
         else:
             foreground = self._foreground
 
```
```diff
@@ -1106,6 +1154,8 @@ def generate(
             halo=halo,
             n_threads=n_threads,
             verbose=False,
+            optimize_memory=optimize_memory,
+            segmentation=segmentation,
         )
 
         if output_mode != "instance_segmentation":
```
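Note: a hedged sketch of the new `generate` arguments in use, assuming an already initialized `segmenter` (an `InstanceSegmentationWithDecoder` instance); the shape and tiling values are illustrative, and the argument names are taken from the diff:

```python
import numpy as np

# Pre-allocating the output lets callers pass a reusable buffer (or a
# memory-mapped / zarr array) instead of having a new one allocated.
seg_out = np.zeros((2048, 2048), dtype="uint64")  # illustrative image shape

instances = segmenter.generate(
    output_mode="instance_segmentation",
    tile_shape=(512, 512), halo=(64, 64), n_threads=8,
    optimize_memory=True,  # intermediate marker volume goes to a temporary zarr store
    segmentation=seg_out,  # new: write the result into this pre-allocated array
)
```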

micro_sam/sam_annotator/_annotator.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -34,14 +34,14 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None):
                 widgets._validation_window_for_missing_layer("current_object")
             self._viewer.add_labels(data=dummy_data, name="current_object")
             if image_scale is not None:
-                self.layers["current_objects"].scale = image_scale
+                self._viewer.layers["current_object"].scale = image_scale
 
         if "auto_segmentation" not in self._viewer.layers:
             if layer_choices and "auto_segmentation" in layer_choices:  # Check at 'commit' call button.
                 widgets._validation_window_for_missing_layer("auto_segmentation")
             self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
             if image_scale is not None:
-                self.layers["auto_segmentation"].scale = image_scale
+                self._viewer.layers["auto_segmentation"].scale = image_scale
 
         if "committed_objects" not in self._viewer.layers:
 
@@ -50,7 +50,7 @@ def _require_layers(self, layer_choices: Optional[List[str]] = None):
             # Randomize colors so it is easy to see when object committed.
             self._viewer.layers["committed_objects"].new_colormap()
             if image_scale is not None:
-                self.layers["committed_objects"].scale = image_scale
+                self._viewer.layers["committed_objects"].scale = image_scale
 
         # Add the point layer for point prompts.
         self._point_labels = ["positive", "negative"]
```
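Note: the fix matters because the annotator widget has no `layers` attribute; layers live on the napari viewer's `LayerList`. The first occurrence also indexed a misspelled layer name (`"current_objects"` instead of the `"current_object"` layer added just above). A minimal napari sketch of the corrected access pattern:

```python
import numpy as np
import napari

viewer = napari.Viewer(show=False)
viewer.add_labels(np.zeros((256, 256), dtype="uint32"), name="committed_objects")

# Layers are looked up by name on the viewer's LayerList, not on a widget.
viewer.layers["committed_objects"].scale = (0.5, 0.5)  # e.g. pixel size per axis
```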

micro_sam/sam_annotator/_state.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -94,6 +94,7 @@ def initialize_predictor(
         predictor=None,
         decoder=None,
         checkpoint_path=None,
+        decoder_path=None,
         tile_shape=None,
         halo=None,
         precompute_amg_state=False,
@@ -113,7 +114,7 @@ def progress_bar_factory(model_type):
 
         self.predictor, state = util.get_sam_model(
             device=device, model_type=model_type,
-            checkpoint_path=checkpoint_path, return_state=True,
+            checkpoint_path=checkpoint_path, decoder_path=decoder_path, return_state=True,
             progress_bar_factory=None if use_cli else progress_bar_factory,
         )
         if prefer_decoder and "decoder_state" in state and model_type != "vit_b_medical_imaging":
```
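Note: `decoder_path` is simply threaded through to `util.get_sam_model`, whose signature must accept it per the updated call above. A hedged sketch of setting it from user code; the module path is taken from this diff and the checkpoint file names are hypothetical:

```python
import numpy as np
from micro_sam.sam_annotator._state import AnnotatorState  # module path assumed from this diff

image = np.random.randint(0, 255, (512, 512), dtype="uint8")  # placeholder image

state = AnnotatorState()
state.initialize_predictor(
    image, model_type="vit_b_lm", ndim=2,
    checkpoint_path="finetuned_sam.pt",   # hypothetical SAM checkpoint
    decoder_path="finetuned_decoder.pt",  # new: custom decoder checkpoint
)
```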

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -48,6 +48,7 @@ def annotator_2d(
     viewer: Optional["napari.viewer.Viewer"] = None,
     precompute_amg_state: bool = False,
     checkpoint_path: Optional[str] = None,
+    decoder_path: Optional[str] = None,
     device: Optional[Union[str, torch.device]] = None,
     prefer_decoder: bool = True,
 ) -> Optional["napari.viewer.Viewer"]:
@@ -73,6 +74,7 @@ def annotator_2d(
             This will take more time when precomputing embeddings, but will then make
             automatic mask generation much faster. By default, set to 'False'.
         checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
+        decoder_path: Path to a custom decoder checkpoint from which to load the `micro-sam` decoder.
         device: The computational device to use for the SAM model.
             By default, automatically chooses the best available device.
         prefer_decoder: Whether to use decoder based instance segmentation if
@@ -89,7 +91,8 @@ def annotator_2d(
     state.initialize_predictor(
         image, model_type=model_type, save_path=embedding_path,
         halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
-        ndim=2, checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
+        ndim=2, checkpoint_path=checkpoint_path, decoder_path=decoder_path,
+        device=device, prefer_decoder=prefer_decoder,
         skip_load=False, use_cli=True,
     )
 
@@ -137,5 +140,5 @@ def main():
         segmentation_result=segmentation_result,
         model_type=args.model_type, tile_shape=args.tile_shape, halo=args.halo,
         precompute_amg_state=args.precompute_amg_state, checkpoint_path=args.checkpoint,
-        device=args.device, prefer_decoder=args.prefer_decoder,
+        decoder_path=args.decoder_path, device=args.device, prefer_decoder=args.prefer_decoder,
     )
```
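Note: with the plumbing above, a custom decoder can now be passed when starting the annotator. A hedged sketch (file names are hypothetical; `annotator_2d` and its arguments are as documented in the diff):

```python
import imageio.v3 as imageio
from micro_sam.sam_annotator import annotator_2d

image = imageio.imread("example_image.tif")  # hypothetical input image

annotator_2d(
    image, model_type="vit_b_lm",
    checkpoint_path="finetuned_sam.pt",   # custom SAM checkpoint (hypothetical)
    decoder_path="finetuned_decoder.pt",  # new in this commit (hypothetical file)
)
```

The `main()` changes indicate a matching CLI argument (parsed into `args.decoder_path`), though the flag name itself is not shown in this hunk.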

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -58,6 +58,7 @@ def annotator_3d(
     viewer: Optional["napari.viewer.Viewer"] = None,
     precompute_amg_state: bool = False,
     checkpoint_path: Optional[str] = None,
+    decoder_path: Optional[str] = None,
     device: Optional[Union[str, torch.device]] = None,
     prefer_decoder: bool = True,
 ) -> Optional["napari.viewer.Viewer"]:
@@ -83,6 +84,7 @@ def annotator_3d(
             This will take more time when precomputing embeddings, but will then make
             automatic mask generation much faster. By default, set to 'False'.
         checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
+        decoder_path: Path to a custom decoder checkpoint from which to load the `micro-sam` decoder.
         device: The computational device to use for the SAM model.
             By default, automatically chooses the best available device.
         prefer_decoder: Whether to use decoder based instance segmentation if
@@ -99,7 +101,8 @@ def annotator_3d(
     state.initialize_predictor(
         image, model_type=model_type, save_path=embedding_path,
         halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
-        checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
+        checkpoint_path=checkpoint_path, decoder_path=decoder_path,
+        device=device, prefer_decoder=prefer_decoder,
         use_cli=True,
     )
 
@@ -148,4 +151,5 @@ def main():
         model_type=args.model_type, tile_shape=args.tile_shape, halo=args.halo,
         checkpoint_path=args.checkpoint, device=args.device,
         precompute_amg_state=args.precompute_amg_state, prefer_decoder=args.prefer_decoder,
+        decoder_path=args.decoder_path,
     )
```
