Skip to content

Commit 8c0a297

Browse files
Add support for training from in-memory data (#1135)
1 parent 3f1540c commit 8c0a297

3 files changed

Lines changed: 120 additions & 35 deletions

File tree

environment.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ dependencies:
2323
- pytorch >=2.5
2424
- segment-anything
2525
- torchvision
26-
- torch_em >=0.7.10
26+
- torch_em >=0.8
2727
- tqdm
2828
- timm
2929
- trackastra

micro_sam/training/training.py

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
99

1010
import imageio.v3 as imageio
11-
11+
import numpy as np
1212
import torch
1313
from torch.optim import Optimizer
1414
from torch.utils.data import random_split
@@ -538,10 +538,12 @@ def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
538538
path = raw_paths
539539
else:
540540
path = raw_paths[0]
541-
assert isinstance(path, (str, os.PathLike))
541+
assert isinstance(path, (str, os.PathLike, np.ndarray, torch.Tensor))
542542

543543
# Check the underlying data dimensionality.
544-
if raw_key is None: # If no key is given then we assume it's an image file.
544+
if raw_key is None and isinstance(path, (np.ndarray, torch.Tensor)):
545+
ndim = path.ndim
546+
elif raw_key is None: # If no key is given and this is not a tensor/array then we assume it's an image file.
545547
ndim = imageio.imread(path).ndim
546548
else: # Otherwise we try to open the file from key.
547549
try: # First try to open it with elf.
@@ -566,10 +568,53 @@ def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
566568
return patch_shape
567569

568570

571+
def _determine_dataset_and_channels(raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs):
572+
# By default, let the 'default_segmentation_dataset' heuristic decide for itself.
573+
is_seg_dataset = kwargs.pop("is_seg_dataset", None)
574+
575+
# Check if the input data is in numpy or torch. In this case determine we use
576+
# the image collection dataset heuristic (torch_em will figure it out),
577+
# and we determine the number of channels.
578+
if isinstance(raw_paths, list) and isinstance(raw_paths[0], (np.ndarray, torch.Tensor)):
579+
is_seg_dataset = False
580+
if with_channels is None:
581+
with_channels = raw_paths[0].ndim == 3 and raw_paths.shape[-1] == 3
582+
return is_seg_dataset, with_channels
583+
584+
# Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
585+
# Get valid raw paths to make checks possible.
586+
if raw_key and "*" in raw_key: # Use the wildcard pattern to find the filepath to only one image.
587+
rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
588+
else: # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath.
589+
rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
590+
591+
# Load one of the raw inputs to validate whether it is RGB or not.
592+
test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
593+
if test_raw_inputs.ndim == 3:
594+
if test_raw_inputs.shape[-1] == 3: # i.e. if it is an RGB image and has channels last.
595+
is_seg_dataset = False # we use 'ImageCollectionDataset' in this case.
596+
# We need to provide a list of inputs to 'ImageCollectionDataset'.
597+
raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
598+
label_paths = [label_paths] if isinstance(label_paths, str) else label_paths
599+
600+
# This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
601+
with_channels = False if with_channels is None else with_channels
602+
603+
elif test_raw_inputs.shape[0] == 3: # i.e. if it is a RGB image and has 3 channels first.
604+
# This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
605+
with_channels = True if with_channels is None else with_channels
606+
607+
# Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'
608+
# Otherwise, let the user make the choice as priority, else set this to our suggested default.
609+
with_channels = False if with_channels is None else with_channels
610+
611+
return is_seg_dataset, with_channels
612+
613+
569614
def default_sam_dataset(
570-
raw_paths: Union[List[FilePath], FilePath],
615+
raw_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
571616
raw_key: Optional[str],
572-
label_paths: Union[List[FilePath], FilePath],
617+
label_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
573618
label_key: Optional[str],
574619
patch_shape: Tuple[int],
575620
with_segmentation_decoder: bool,
@@ -590,12 +635,16 @@ def default_sam_dataset(
590635
Args:
591636
raw_paths: The path(s) to the image data used for training.
592637
Can either be multiple 2D images or volumetric data.
638+
The data can also be passed as a list of numpy arrays or torch tensors.
593639
raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
594640
or a glob pattern for selecting multiple files.
641+
Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
595642
label_paths: The path(s) to the label data used for training.
596643
Can either be multiple 2D images or volumetric data.
644+
The data can also be passed as a list of numpy arrays or torch tensors.
597645
label_key: The key for accessing the label data. Internal filepath for hdf5-like input
598646
or a glob pattern for selecting multiple files.
647+
Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
599648
patch_shape: The shape for training patches.
600649
with_segmentation_decoder: Whether to train with additional segmentation decoder.
601650
with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
@@ -634,35 +683,9 @@ def default_sam_dataset(
634683
if sampler is None and not train_instance_segmentation_only:
635684
sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size)
636685

637-
# By default, let the 'default_segmentation_dataset' heuristic decide for itself.
638-
is_seg_dataset = kwargs.pop("is_seg_dataset", None)
639-
640-
# Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
641-
# Get valid raw paths to make checks possible.
642-
if raw_key and "*" in raw_key: # Use the wildcard pattern to find the filepath to only one image.
643-
rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
644-
else: # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath.
645-
rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
646-
647-
# Load one of the raw inputs to validate whether it is RGB or not.
648-
test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
649-
if test_raw_inputs.ndim == 3:
650-
if test_raw_inputs.shape[-1] == 3: # i.e. if it is an RGB image and has channels last.
651-
is_seg_dataset = False # we use 'ImageCollectionDataset' in this case.
652-
# We need to provide a list of inputs to 'ImageCollectionDataset'.
653-
raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
654-
label_paths = [label_paths] if isinstance(label_paths, str) else label_paths
655-
656-
# This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
657-
with_channels = False if with_channels is None else with_channels
658-
659-
elif test_raw_inputs.shape[0] == 3: # i.e. if it is a RGB image and has 3 channels first.
660-
# This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
661-
with_channels = True if with_channels is None else with_channels
662-
663-
# Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'
664-
# Otherwise, let the user make the choice as priority, else set this to our suggested default.
665-
with_channels = False if with_channels is None else with_channels
686+
is_seg_dataset, with_channels = _determine_dataset_and_channels(
687+
raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs
688+
)
666689

667690
# Set the data transformations.
668691
if raw_transform is None:

test/test_training.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,68 @@
1010
from micro_sam.util import VIT_T_SUPPORT, get_sam_model, SamPredictor
1111

1212

13+
class TestDataset(unittest.TestCase):
    """Tests for 'default_sam_dataset' with on-disk images and in-memory data."""
    tmp_folder = "./tmp-dataset"

    def setUp(self):
        # Write a small collection of synthetic image/label pairs to disk.
        self.image_dir = os.path.join(self.tmp_folder, "synthetic-data", "images")
        self.label_dir = os.path.join(self.tmp_folder, "synthetic-data", "labels")
        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(self.label_dir, exist_ok=True)

        shape = (512, 512)
        for idx in range(5):
            image, labels = synthetic_data(shape)
            imageio.imwrite(os.path.join(self.image_dir, f"data-{idx}.tif"), image)
            imageio.imwrite(os.path.join(self.label_dir, f"data-{idx}.tif"), labels)

    def tearDown(self):
        # Best-effort removal of the temporary data folder.
        try:
            rmtree(self.tmp_folder)
        except OSError:
            pass

    def _check_dataset(self, ds, patch_shape, exp_type):
        # The dataset must be of the expected torch_em type and yield 2d patches.
        self.assertIsInstance(ds, exp_type)
        self.assertEqual(ds._ndim, 2)

        # Images carry one channel; labels carry four
        # (presumably the segmentation-decoder targets — TODO confirm).
        expected_im_shape = (1,) + patch_shape
        expected_label_shape = (4,) + patch_shape
        for sample_id in range(5):
            image_patch, label_patch = ds[sample_id]
            self.assertEqual(image_patch.shape, expected_im_shape)
            self.assertEqual(label_patch.shape, expected_label_shape)

    def test_default_sam_dataset(self):
        from micro_sam.training.training import default_sam_dataset
        from torch_em.data import SegmentationDataset

        patch_shape = (512, 512)
        ds = default_sam_dataset(
            self.image_dir, "*.tif", self.label_dir, "*.tif", patch_shape, with_segmentation_decoder=True
        )
        self._check_dataset(ds, patch_shape, SegmentationDataset)

    def test_default_sam_dataset_with_numpy_data(self):
        from micro_sam.training.training import default_sam_dataset
        from torch_em.data import TensorDataset

        patch_shape = (512, 512)
        images = [imageio.imread(path) for path in sorted(glob(os.path.join(self.image_dir, "*.tif")))]
        labels = [imageio.imread(path) for path in sorted(glob(os.path.join(self.label_dir, "*.tif")))]
        ds = default_sam_dataset(
            images, None, labels, None, patch_shape, with_segmentation_decoder=True
        )
        self._check_dataset(ds, patch_shape, TensorDataset)
73+
74+
1375
@unittest.skip("Not working in CI")
1476
@unittest.skipUnless(VIT_T_SUPPORT, "Integration test is only run with vit_t support, otherwise it takes too long.")
1577
class TestTraining(unittest.TestCase):

0 commit comments

Comments
 (0)