Skip to content

Commit c1d8fa5

Browse files
Some updates to tracking with trackastra and add test for automatic tracking (#1185)
* Some updates to tracking with trackastra and add test for automatic tracking * Fix failing tests * Add trackastra version constraint and pin version
1 parent 548139e commit c1d8fa5

4 files changed

Lines changed: 50 additions & 7 deletions

File tree

environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies:
2626
- torch_em >=0.8
2727
- tqdm
2828
- timm
29-
- trackastra
29+
- trackastra >=0.5.3
3030
- xarray
3131
- zarr
3232
- pip:

micro_sam/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.7.5"
1+
__version__ = "1.7.6"

micro_sam/multi_dimensional_segmentation.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import os
55
import multiprocessing as mp
6+
import warnings
67
from concurrent import futures
78
from typing import Dict, List, Optional, Union, Tuple
89

@@ -569,8 +570,19 @@ def _filter_lineages(lineages, tracking_result):
569570
def _tracking_impl(timeseries, segmentation, mode, min_time_extent, output_folder=None):
570571
device = "cuda" if torch.cuda.is_available() else "cpu"
571572
model = Trackastra.from_pretrained("general_2d", device=device)
572-
lineage_graph, _ = model.track(timeseries, segmentation, mode=mode)
573+
result = model.track(timeseries, segmentation, mode=mode)
574+
try:
575+
lineage_graph, _ = result
576+
except ValueError:
577+
lineage_graph = result
578+
573579
track_data, parent_graph, _ = graph_to_napari_tracks(lineage_graph)
580+
if track_data.size == 0:
581+
warnings.warn("Tracking result is empty.")
582+
tracking_result = np.zeros_like(segmentation)
583+
lineages = []
584+
return tracking_result, lineages
585+
574586
node_to_track, lineages = _extract_tracks_and_lineages(segmentation, track_data, parent_graph)
575587
tracking_result = recolor_segmentation(segmentation, node_to_track)
576588

@@ -625,7 +637,7 @@ def track_across_frames(
625637
"""
626638
if Trackastra is None:
627639
raise RuntimeError(
628-
"The automatic tracking functionality requires trackastra. You can install it via 'pip install trackastra'."
640+
"Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
629641
)
630642

631643
_, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

test/test_automatic_segmentation.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,18 @@
22

33
import numpy as np
44
import torch
5+
from scipy.ndimage import shift
56
from skimage.draw import disk
67
from skimage.measure import label as connected_components
78

89
import micro_sam.util as util
910

11+
try:
12+
from trackastra.model import Trackastra # noqa
13+
WITH_TRACKASTRA = True
14+
except ImportError:
15+
WITH_TRACKASTRA = False
16+
1017
HAVE_CUDA = torch.cuda.is_available()
1118

1219

@@ -43,9 +50,15 @@ def write_object(center, radius):
4350
def _get_3d_inputs(cls, shape):
4451
mask, image = cls._get_2d_inputs(shape[-2:])
4552

46-
# Create volumes by stacking the input image and respective mask.
47-
volume = np.stack([image] * shape[0])
48-
labels = np.stack([mask] * shape[0])
53+
# Create volumes by stacking the input image and respective mask with small shifts.
54+
labels, volume = [mask], [image]
55+
for _ in range(shape[0] - 1):
56+
shift_vector = (np.random.randint(1, 4), np.random.randint(1, 4))
57+
labels.append(shift(mask, shift=shift_vector, order=0, mode="constant", cval=0))
58+
volume.append(shift(image, shift=shift_vector, order=0, mode="constant", cval=0))
59+
60+
volume = np.stack(volume)
61+
labels = np.stack(labels)
4962
return labels, volume
5063

5164
@classmethod
@@ -104,6 +117,7 @@ def test_instance_segmentation_with_decoder_2d(self):
104117
predictor=predictor, segmenter=segmenter, input_path=image, ndim=2,
105118
)
106119
self.assertEqual(mask.shape, instances.shape)
120+
self.assertGreater(instances.max(), 0)
107121

108122
def test_tiled_instance_segmentation_with_decoder_2d(self):
109123
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
@@ -118,6 +132,7 @@ def test_tiled_instance_segmentation_with_decoder_2d(self):
118132
batch_size=2,
119133
)
120134
self.assertEqual(mask.shape, instances.shape)
135+
self.assertGreater(instances.max(), 0)
121136

122137
@unittest.skipUnless(HAVE_CUDA, "Skipping long running tests unless we have a GPU.")
123138
def test_automatic_mask_generator_3d(self):
@@ -131,6 +146,7 @@ def test_automatic_mask_generator_3d(self):
131146
predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3,
132147
)
133148
self.assertEqual(labels.shape, instances.shape)
149+
self.assertGreater(instances.max(), 0)
134150

135151
@unittest.skipUnless(HAVE_CUDA, "Skipping long running tests unless we have a GPU.")
136152
def test_tiled_automatic_mask_generator_3d(self):
@@ -145,6 +161,7 @@ def test_tiled_automatic_mask_generator_3d(self):
145161
ndim=3, tile_shape=self.tile_shape, halo=self.halo,
146162
)
147163
self.assertEqual(labels.shape, instances.shape)
164+
self.assertGreater(instances.max(), 0)
148165

149166
def test_instance_segmentation_with_decoder_3d(self):
150167
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
@@ -157,6 +174,19 @@ def test_instance_segmentation_with_decoder_3d(self):
157174
predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3,
158175
)
159176
self.assertEqual(labels.shape, instances.shape)
177+
self.assertGreater(instances.max(), 0)
178+
179+
@unittest.skipUnless(WITH_TRACKASTRA, "Needs trackastra")
180+
def test_automatic_tracking(self):
181+
from micro_sam.automatic_segmentation import automatic_tracking, get_predictor_and_segmenter
182+
183+
labels, volume = self.labels, self.volume
184+
predictor, segmenter = get_predictor_and_segmenter(
185+
model_type=self.model_type_ais, segmentation_mode="ais", is_tiled=False
186+
)
187+
instances, _ = automatic_tracking(predictor=predictor, segmenter=segmenter, input_path=volume)
188+
self.assertEqual(labels.shape, instances.shape)
189+
self.assertGreater(instances.max(), 0)
160190

161191
def test_tiled_instance_segmentation_with_decoder_3d(self):
162192
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
@@ -170,6 +200,7 @@ def test_tiled_instance_segmentation_with_decoder_3d(self):
170200
ndim=3, tile_shape=self.tile_shape, halo=self.halo,
171201
)
172202
self.assertEqual(labels.shape, instances.shape)
203+
self.assertGreater(instances.max(), 0)
173204

174205

175206
if __name__ == "__main__":

0 commit comments

Comments
 (0)