1818import imageio .v3 as imageio
1919
2020from torch_em .util .debug import check_loader
21- from torch_em .util .util import get_random_colors
2221
2322import micro_sam .training as sam_training
2423from micro_sam .training .util import normalize_to_8bit
@@ -129,8 +128,8 @@ def run_finetuning(model_name, train_loader, val_loader, model_type, overwrite,
129128 best_checkpoint = os .path .join (save_root , "checkpoints" , model_name , "best.pt" )
130129 if os .path .exists (best_checkpoint ) and not overwrite :
131130 print (
132- "It looks like the training has completed. You must pass the argument '--overwrite' to overwrite "
133- "the already finetuned model (or provide a new filepath at '--save_root' for training new models )."
131+ "It looks like the training has completed. Pass the argument '--overwrite' to overwrite "
132+ "the already finetuned model (or train a model with a new name using the argument '--model_name' )."
134133 )
135134 return best_checkpoint
136135
@@ -148,45 +147,41 @@ def run_finetuning(model_name, train_loader, val_loader, model_type, overwrite,
148147 return best_checkpoint
149148
150149
151- # TODO
152150def run_instance_segmentation_with_decoder (input_root , model_type , checkpoint ):
153- """Run automatic instance segmentation (AIS) .
151+ """Run automatic instance segmentation with the fine-tuned model .
154152
155153 Args:
156154 test_image_paths: List of filepaths for the test image data.
157155 model_type: The choice of Segment Anything model (connotated by the size of image encoder).
158156 checkpoint: Filepath to the finetuned model checkpoints.
159- device: The torch device used for inference.
160157 """
161158 assert os .path .exists (checkpoint ), "Please train the model first to run inference on the finetuned model."
162159
160+ # CHANGE: For your own data you have to adapt these folder names.
161+ folder = os .path .join (input_root , "hpa/test/images" )
162+ image_paths = glob (os .path .join (folder , "*.tif" ))
163+ print (len (image_paths ))
164+
163165 # Get the 'predictor' and 'segmenter' to perform automatic instance segmentation.
164- predictor , segmenter = get_predictor_and_segmenter (
165- model_type = model_type , checkpoint = checkpoint , device = device , is_tiled = True
166- )
166+ predictor , segmenter = get_predictor_and_segmenter (model_type = model_type , checkpoint = checkpoint , is_tiled = True )
167167
168- for image_path in test_image_paths :
168+ # Iterate over the training images to run the segmentation and visualize the result in napari.
169+ for image_path in image_paths :
169170 image = imageio .imread (image_path )
170171 image = normalize_to_8bit (image )
171172
172- # Predicting the instances.
173+ # CHANGE: Update the values for 'tile_shape' and 'halo' so that they match the patch_shape used in training,
174+ # according to: patch_shape = tile_shape + 2 * halo.
175+ # For example, if you used patch_shape = (512, 512), you can use tile_shape = (384, 384) and halo = (64, 64)
173176 prediction = automatic_instance_segmentation (
174177 predictor = predictor , segmenter = segmenter , input_path = image , ndim = 2 , tile_shape = (768 , 768 ), halo = (128 , 128 )
175178 )
176179
177- # Visualize the predictions
178- fig , ax = plt .subplots (1 , 2 , figsize = (10 , 10 ))
179-
180- ax [0 ].imshow (image )
181- ax [0 ].axis ("off" )
182- ax [0 ].set_title ("Input Image" )
183-
184- ax [1 ].imshow (prediction , cmap = get_random_colors (prediction ), interpolation = "nearest" )
185- ax [1 ].axis ("off" )
186- ax [1 ].set_title ("Predictions (AIS)" )
187-
188- plt .show ()
189- plt .close ()
180+ import napari
181+ v = napari .Viewer ()
182+ v .add_image (image )
183+ v .add_labels (prediction )
184+ napari .run ()
190185
191186 break # comment this out in case you want to run inference for all images.
192187
@@ -224,6 +219,7 @@ def main():
224219 # - vit_b_histopathology: For nucleus segmentation in histopathology.
225220 model_type = "vit_b_lm"
226221
222+ # Get the data loaders and run the training.
227223 train_loader , val_loader = get_dataloaders (args .input_path , view = args .view )
228224 checkpoint_path = run_finetuning (
229225 model_name = args .model_name ,
@@ -233,11 +229,12 @@ def main():
233229 overwrite = args .overwrite ,
234230 n_epochs = args .n_epochs ,
235231 )
232+ assert os .path .exists (checkpoint_path ), checkpoint_path
236233
237- # TODO
238- # run_instance_segmentation_with_decoder(
239- # test_image_paths=test_image_paths, model_type=model_type, checkpoint=checkpoint_path, device=device,
240- # )
234+ # Use the fine-tuned model for instance segmentation on test data to verify that it worked.
235+ # Note: You can also use the fine-tuned model within the micro_sam napari plugin
236+ # or within other python functions from the micro_sam library.
237+ run_instance_segmentation_with_decoder ( args . input_path , model_type = model_type , checkpoint = checkpoint_path )
241238
242239
243240if __name__ == "__main__" :
0 commit comments