22
33import numpy as np
44import torch
5+ from scipy .ndimage import shift
56from skimage .draw import disk
67from skimage .measure import label as connected_components
78
89import micro_sam .util as util
910
11+ try :
12+ from trackastra .model import Trackastra # noqa
13+ WITH_TRACKASTRA = True
14+ except ImportError :
15+ WITH_TRACKASTRA = False
16+
1017HAVE_CUDA = torch .cuda .is_available ()
1118
1219
@@ -43,9 +50,15 @@ def write_object(center, radius):
4350 def _get_3d_inputs (cls , shape ):
4451 mask , image = cls ._get_2d_inputs (shape [- 2 :])
4552
46- # Create volumes by stacking the input image and respective mask.
47- volume = np .stack ([image ] * shape [0 ])
48- labels = np .stack ([mask ] * shape [0 ])
53+ # Create volumes by stacking the input image and respective mask with small shifts.
54+ labels , volume = [mask ], [image ]
55+ for _ in range (shape [0 ] - 1 ):
56+ shift_vector = (np .random .randint (1 , 4 ), np .random .randint (1 , 4 ))
57+ labels .append (shift (mask , shift = shift_vector , order = 0 , mode = "constant" , cval = 0 ))
58+ volume .append (shift (image , shift = shift_vector , order = 0 , mode = "constant" , cval = 0 ))
59+
60+ volume = np .stack (volume )
61+ labels = np .stack (labels )
4962 return labels , volume
5063
5164 @classmethod
@@ -104,6 +117,7 @@ def test_instance_segmentation_with_decoder_2d(self):
104117 predictor = predictor , segmenter = segmenter , input_path = image , ndim = 2 ,
105118 )
106119 self .assertEqual (mask .shape , instances .shape )
120+ self .assertGreater (instances .max (), 0 )
107121
108122 def test_tiled_instance_segmentation_with_decoder_2d (self ):
109123 from micro_sam .automatic_segmentation import automatic_instance_segmentation , get_predictor_and_segmenter
@@ -118,6 +132,7 @@ def test_tiled_instance_segmentation_with_decoder_2d(self):
118132 batch_size = 2 ,
119133 )
120134 self .assertEqual (mask .shape , instances .shape )
135+ self .assertGreater (instances .max (), 0 )
121136
122137 @unittest .skipUnless (HAVE_CUDA , "Skipping long running tests unless we have a GPU." )
123138 def test_automatic_mask_generator_3d (self ):
@@ -131,6 +146,7 @@ def test_automatic_mask_generator_3d(self):
131146 predictor = predictor , segmenter = segmenter , input_path = volume , ndim = 3 ,
132147 )
133148 self .assertEqual (labels .shape , instances .shape )
149+ self .assertGreater (instances .max (), 0 )
134150
135151 @unittest .skipUnless (HAVE_CUDA , "Skipping long running tests unless we have a GPU." )
136152 def test_tiled_automatic_mask_generator_3d (self ):
@@ -145,6 +161,7 @@ def test_tiled_automatic_mask_generator_3d(self):
145161 ndim = 3 , tile_shape = self .tile_shape , halo = self .halo ,
146162 )
147163 self .assertEqual (labels .shape , instances .shape )
164+ self .assertGreater (instances .max (), 0 )
148165
149166 def test_instance_segmentation_with_decoder_3d (self ):
150167 from micro_sam .automatic_segmentation import automatic_instance_segmentation , get_predictor_and_segmenter
@@ -157,6 +174,19 @@ def test_instance_segmentation_with_decoder_3d(self):
157174 predictor = predictor , segmenter = segmenter , input_path = volume , ndim = 3 ,
158175 )
159176 self .assertEqual (labels .shape , instances .shape )
177+ self .assertGreater (instances .max (), 0 )
178+
179+ @unittest .skipUnless (WITH_TRACKASTRA , "Needs trackastra" )
180+ def test_automatic_tracking (self ):
181+ from micro_sam .automatic_segmentation import automatic_tracking , get_predictor_and_segmenter
182+
183+ labels , volume = self .labels , self .volume
184+ predictor , segmenter = get_predictor_and_segmenter (
185+ model_type = self .model_type_ais , segmentation_mode = "ais" , is_tiled = False
186+ )
187+ instances , _ = automatic_tracking (predictor = predictor , segmenter = segmenter , input_path = volume )
188+ self .assertEqual (labels .shape , instances .shape )
189+ self .assertGreater (instances .max (), 0 )
160190
161191 def test_tiled_instance_segmentation_with_decoder_3d (self ):
162192 from micro_sam .automatic_segmentation import automatic_instance_segmentation , get_predictor_and_segmenter
@@ -170,6 +200,7 @@ def test_tiled_instance_segmentation_with_decoder_3d(self):
170200 ndim = 3 , tile_shape = self .tile_shape , halo = self .halo ,
171201 )
172202 self .assertEqual (labels .shape , instances .shape )
203+ self .assertGreater (instances .max (), 0 )
173204
174205
175206if __name__ == "__main__" :
0 commit comments