88from typing import Any , Callable , Dict , List , Optional , Tuple , Union
99
1010import imageio .v3 as imageio
11-
11+ import numpy as np
1212import torch
1313from torch .optim import Optimizer
1414from torch .utils .data import random_split
@@ -538,10 +538,12 @@ def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
538538 path = raw_paths
539539 else :
540540 path = raw_paths [0 ]
541- assert isinstance (path , (str , os .PathLike ))
541+ assert isinstance (path , (str , os .PathLike , np . ndarray , torch . Tensor ))
542542
543543 # Check the underlying data dimensionality.
544- if raw_key is None : # If no key is given then we assume it's an image file.
544+ if raw_key is None and isinstance (path , (np .ndarray , torch .Tensor )):
545+ ndim = path .ndim
546+ elif raw_key is None : # If no key is given and this is not a tensor/array then we assume it's an image file.
545547 ndim = imageio .imread (path ).ndim
546548 else : # Otherwise we try to open the file from key.
547549 try : # First try to open it with elf.
@@ -566,10 +568,53 @@ def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
566568 return patch_shape
567569
568570
571+ def _determine_dataset_and_channels (raw_paths , raw_key , label_paths , label_key , with_channels , ** kwargs ):
572+ # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
573+ is_seg_dataset = kwargs .pop ("is_seg_dataset" , None )
574+
575+ # Check if the input data is in numpy or torch. In this case determine we use
576+ # the image collection dataset heuristic (torch_em will figure it out),
577+ # and we determine the number of channels.
578+ if isinstance (raw_paths , list ) and isinstance (raw_paths [0 ], (np .ndarray , torch .Tensor )):
579+ is_seg_dataset = False
580+ if with_channels is None :
581+ with_channels = raw_paths [0 ].ndim == 3 and raw_paths .shape [- 1 ] == 3
582+ return is_seg_dataset , with_channels
583+
584+ # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
585+ # Get valid raw paths to make checks possible.
586+ if raw_key and "*" in raw_key : # Use the wildcard pattern to find the filepath to only one image.
587+ rpath = glob (os .path .join (raw_paths if isinstance (raw_paths , str ) else raw_paths [0 ], raw_key ))[0 ]
588+ else : # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath.
589+ rpath = raw_paths if isinstance (raw_paths , str ) else raw_paths [0 ]
590+
591+ # Load one of the raw inputs to validate whether it is RGB or not.
592+ test_raw_inputs = load_data (path = rpath , key = raw_key if raw_key and "*" not in raw_key else None )
593+ if test_raw_inputs .ndim == 3 :
594+ if test_raw_inputs .shape [- 1 ] == 3 : # i.e. if it is an RGB image and has channels last.
595+ is_seg_dataset = False # we use 'ImageCollectionDataset' in this case.
596+ # We need to provide a list of inputs to 'ImageCollectionDataset'.
597+ raw_paths = [raw_paths ] if isinstance (raw_paths , str ) else raw_paths
598+ label_paths = [label_paths ] if isinstance (label_paths , str ) else label_paths
599+
600+ # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
601+ with_channels = False if with_channels is None else with_channels
602+
603+ elif test_raw_inputs .shape [0 ] == 3 : # i.e. if it is a RGB image and has 3 channels first.
604+ # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
605+ with_channels = True if with_channels is None else with_channels
606+
607+ # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'
608+ # Otherwise, let the user make the choice as priority, else set this to our suggested default.
609+ with_channels = False if with_channels is None else with_channels
610+
611+ return is_seg_dataset , with_channels
612+
613+
569614def default_sam_dataset (
570- raw_paths : Union [List [FilePath ], FilePath ],
615+ raw_paths : Union [List [Union [ np . ndarray , torch . Tensor ]], List [ FilePath ], FilePath ],
571616 raw_key : Optional [str ],
572- label_paths : Union [List [FilePath ], FilePath ],
617+ label_paths : Union [List [Union [ np . ndarray , torch . Tensor ]], List [ FilePath ], FilePath ],
573618 label_key : Optional [str ],
574619 patch_shape : Tuple [int ],
575620 with_segmentation_decoder : bool ,
@@ -590,12 +635,16 @@ def default_sam_dataset(
590635 Args:
591636 raw_paths: The path(s) to the image data used for training.
592637 Can either be multiple 2D images or volumetric data.
638+ The data can also be passed as a list of numpy arrays or torch tensors.
593639 raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
594640 or a glob pattern for selecting multiple files.
641+ Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
595642 label_paths: The path(s) to the label data used for training.
596643 Can either be multiple 2D images or volumetric data.
644+ The data can also be passed as a list of numpy arrays or torch tensors.
597645 label_key: The key for accessing the label data. Internal filepath for hdf5-like input
598646 or a glob pattern for selecting multiple files.
647+ Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
599648 patch_shape: The shape for training patches.
600649 with_segmentation_decoder: Whether to train with additional segmentation decoder.
601650 with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
@@ -634,35 +683,9 @@ def default_sam_dataset(
634683 if sampler is None and not train_instance_segmentation_only :
635684 sampler = torch_em .data .sampler .MinInstanceSampler (2 , min_size = min_size )
636685
637- # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
638- is_seg_dataset = kwargs .pop ("is_seg_dataset" , None )
639-
640- # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
641- # Get valid raw paths to make checks possible.
642- if raw_key and "*" in raw_key : # Use the wildcard pattern to find the filepath to only one image.
643- rpath = glob (os .path .join (raw_paths if isinstance (raw_paths , str ) else raw_paths [0 ], raw_key ))[0 ]
644- else : # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath.
645- rpath = raw_paths if isinstance (raw_paths , str ) else raw_paths [0 ]
646-
647- # Load one of the raw inputs to validate whether it is RGB or not.
648- test_raw_inputs = load_data (path = rpath , key = raw_key if raw_key and "*" not in raw_key else None )
649- if test_raw_inputs .ndim == 3 :
650- if test_raw_inputs .shape [- 1 ] == 3 : # i.e. if it is an RGB image and has channels last.
651- is_seg_dataset = False # we use 'ImageCollectionDataset' in this case.
652- # We need to provide a list of inputs to 'ImageCollectionDataset'.
653- raw_paths = [raw_paths ] if isinstance (raw_paths , str ) else raw_paths
654- label_paths = [label_paths ] if isinstance (label_paths , str ) else label_paths
655-
656- # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
657- with_channels = False if with_channels is None else with_channels
658-
659- elif test_raw_inputs .shape [0 ] == 3 : # i.e. if it is a RGB image and has 3 channels first.
660- # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
661- with_channels = True if with_channels is None else with_channels
662-
663- # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'
664- # Otherwise, let the user make the choice as priority, else set this to our suggested default.
665- with_channels = False if with_channels is None else with_channels
686+ is_seg_dataset , with_channels = _determine_dataset_and_channels (
687+ raw_paths , raw_key , label_paths , label_key , with_channels , ** kwargs
688+ )
666689
667690 # Set the data transformations.
668691 if raw_transform is None :
0 commit comments