Skip to content

Commit ca2e965

Browse files
Update finetuning example for workshop
1 parent 219852b commit ca2e965

1 file changed

Lines changed: 25 additions & 28 deletions

File tree

workshops/epfl_2026/finetune_sam.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import imageio.v3 as imageio
1919

2020
from torch_em.util.debug import check_loader
21-
from torch_em.util.util import get_random_colors
2221

2322
import micro_sam.training as sam_training
2423
from micro_sam.training.util import normalize_to_8bit
@@ -129,8 +128,8 @@ def run_finetuning(model_name, train_loader, val_loader, model_type, overwrite,
129128
best_checkpoint = os.path.join(save_root, "checkpoints", model_name, "best.pt")
130129
if os.path.exists(best_checkpoint) and not overwrite:
131130
print(
132-
"It looks like the training has completed. You must pass the argument '--overwrite' to overwrite "
133-
"the already finetuned model (or provide a new filepath at '--save_root' for training new models)."
131+
"It looks like the training has completed. Pass the argument '--overwrite' to overwrite "
132+
"the already finetuned model (or train a model with a new name using the argument '--model_name')."
134133
)
135134
return best_checkpoint
136135

@@ -148,45 +147,41 @@ def run_finetuning(model_name, train_loader, val_loader, model_type, overwrite,
148147
return best_checkpoint
149148

150149

151-
# TODO
152150
def run_instance_segmentation_with_decoder(input_root, model_type, checkpoint):
153-
"""Run automatic instance segmentation (AIS).
151+
"""Run automatic instance segmentation with the fine-tuned model.
154152
155153
Args:
156154
test_image_paths: List of filepaths for the test image data.
157155
model_type: The choice of Segment Anything model (connotated by the size of image encoder).
158156
checkpoint: Filepath to the finetuned model checkpoints.
159-
device: The torch device used for inference.
160157
"""
161158
assert os.path.exists(checkpoint), "Please train the model first to run inference on the finetuned model."
162159

160+
# CHANGE: For your own data you have to adapt these folder names.
161+
folder = os.path.join(input_root, "hpa/test/images")
162+
image_paths = glob(os.path.join(folder, "*.tif"))
163+
print(len(image_paths))
164+
163165
# Get the 'predictor' and 'segmenter' to perform automatic instance segmentation.
164-
predictor, segmenter = get_predictor_and_segmenter(
165-
model_type=model_type, checkpoint=checkpoint, device=device, is_tiled=True
166-
)
166+
predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, checkpoint=checkpoint, is_tiled=True)
167167

168-
for image_path in test_image_paths:
168+
# Iterate over the training images to run the segmentation and visualize the result in napari.
169+
for image_path in image_paths:
169170
image = imageio.imread(image_path)
170171
image = normalize_to_8bit(image)
171172

172-
# Predicting the instances.
173+
# CHANGE: Update the values for 'tile_shape' and 'halo' so that they match the patch_shape used in training,
174+
# according to: patch_shape = tile_shape + 2 * halo.
175+
# For example, if you used patch_shape = (512, 512), you can use tile_shape = (384, 384) and halo = (64, 64)
173176
prediction = automatic_instance_segmentation(
174177
predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, tile_shape=(768, 768), halo=(128, 128)
175178
)
176179

177-
# Visualize the predictions
178-
fig, ax = plt.subplots(1, 2, figsize=(10, 10))
179-
180-
ax[0].imshow(image)
181-
ax[0].axis("off")
182-
ax[0].set_title("Input Image")
183-
184-
ax[1].imshow(prediction, cmap=get_random_colors(prediction), interpolation="nearest")
185-
ax[1].axis("off")
186-
ax[1].set_title("Predictions (AIS)")
187-
188-
plt.show()
189-
plt.close()
180+
import napari
181+
v = napari.Viewer()
182+
v.add_image(image)
183+
v.add_labels(prediction)
184+
napari.run()
190185

191186
break # comment this out in case you want to run inference for all images.
192187

@@ -224,6 +219,7 @@ def main():
224219
# - vit_b_histopathology: For nucleus segmentation in histopathology.
225220
model_type = "vit_b_lm"
226221

222+
# Get the data loaders and run the training.
227223
train_loader, val_loader = get_dataloaders(args.input_path, view=args.view)
228224
checkpoint_path = run_finetuning(
229225
model_name=args.model_name,
@@ -233,11 +229,12 @@ def main():
233229
overwrite=args.overwrite,
234230
n_epochs=args.n_epochs,
235231
)
232+
assert os.path.exists(checkpoint_path), checkpoint_path
236233

237-
# TODO
238-
# run_instance_segmentation_with_decoder(
239-
# test_image_paths=test_image_paths, model_type=model_type, checkpoint=checkpoint_path, device=device,
240-
# )
234+
# Use the fine-tuned model for instance segmentation on test data to verify that it worked.
235+
# Note: You can also use the fine-tuned model within the micro_sam napari plugin
236+
# or within other python functions from the micro_sam library.
237+
run_instance_segmentation_with_decoder(args.input_path, model_type=model_type, checkpoint=checkpoint_path)
241238

242239

243240
if __name__ == "__main__":

0 commit comments

Comments
 (0)