Skip to content

Commit 057130e

Browse files
committed
ADD: add multifile and multichannel flag to local datastore
Adds flags to the local datastore init signaling that the data is either multichannel or multi-file (each directory holds one sample with multiple volumes). Signed-off-by: Cavan Riley <cavan-riley@uiowa.edu>
1 parent 068fc81 commit 057130e

7 files changed

Lines changed: 134 additions & 21 deletions

File tree

monailabel/datastore/local.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,11 @@ def __init__(
102102
images_dir: str = ".",
103103
labels_dir: str = "labels",
104104
datastore_config: str = "datastore_v2.json",
105-
extensions=("*.nii.gz", "*.nii"),
105+
extensions=("*.nii.gz", "*.nii", "*.nrrd"),
106106
auto_reload=False,
107107
read_only=False,
108+
multichannel=False,
109+
multi_file=False,
108110
):
109111
"""
110112
Creates a `LocalDataset` object
@@ -124,6 +126,8 @@ def __init__(
124126
self._ignore_event_config = False
125127
self._config_ts = 0
126128
self._auto_reload = auto_reload
129+
self._multichannel = multichannel
130+
self._multi_file = multi_file
127131

128132
logging.getLogger("filelock").setLevel(logging.ERROR)
129133

@@ -256,6 +260,12 @@ def datalist(self, full_path=True) -> List[Dict[str, Any]]:
256260
ds = json.loads(json.dumps(ds).replace(f"{self._datastore_path.rstrip(os.pathsep)}{os.pathsep}", ""))
257261
return ds
258262

263+
def get_is_multichannel(self) -> bool:
    """
    Report whether this datastore was configured for multichannel data

    :return: the `multichannel` flag given to the constructor, coerced to a strict `bool`
    """
    # The constructor stores the flag exactly as passed; coerce so the declared return type always holds
    return bool(self._multichannel)
265+
266+
def get_is_multi_file(self) -> bool:
    """
    Report whether this datastore was configured for multi-file samples

    :return: the `multi_file` flag given to the constructor, coerced to a strict `bool`
    """
    # The constructor stores the flag exactly as passed; coerce so the declared return type always holds
    return bool(self._multi_file)
268+
259269
def get_image(self, image_id: str, params=None) -> Any:
260270
"""
261271
Retrieve image object based on image id
@@ -431,6 +441,29 @@ def refresh(self):
431441
"""
432442
self._reconcile_datastore()
433443

444+
def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str:
    """
    Copy a sample (a directory of volumes, or a single file) into the datastore and register it

    :param directory_id: id to register the sample under; if empty, the base name of `filename` is used
    :param filename: path to the directory (or file) that is copied into the datastore's image path
    :param info: additional info stored with the sample; the caller's dict is not modified
    :return: the id the sample was registered under
    """
    if not directory_id:
        # normpath drops a trailing separator, which would otherwise make basename return ""
        directory_id = os.path.basename(os.path.normpath(filename))

    logger.info(f"Adding Image: {directory_id} => {filename}")
    name = directory_id
    dest = os.path.realpath(os.path.join(self._datastore.image_path(), name))

    with FileLock(self._lock_file):
        logger.debug("Acquired the lock!")
        if os.path.isdir(filename):
            # shutil.copy only handles regular files; a sample directory needs a recursive copy
            shutil.copytree(filename, dest, dirs_exist_ok=True)
        else:
            shutil.copy(filename, dest)

        # Work on a copy so the caller's dict is left untouched
        entry_info = dict(info) if info else {}
        entry_info["ts"] = int(time.time())
        entry_info["name"] = name

        self._datastore.objects[directory_id] = ImageLabelModel(image=DataModel(info=entry_info, ext=""))
        # The lock is already held here, so the datastore file update must not re-acquire it
        self._update_datastore_file(lock=False)

    # Logged after the `with` block so the message is true when it is emitted
    logger.debug("Released the lock!")
    return directory_id
466+
434467
def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str:
435468
id, image_ext = self._to_id(os.path.basename(image_filename))
436469
if not image_id:
@@ -552,10 +585,15 @@ def _list_files(self, path, patterns):
552585
files = os.listdir(path)
553586

554587
filtered = dict()
555-
for pattern in patterns:
556-
matching = fnmatch.filter(files, pattern)
557-
for file in matching:
558-
filtered[os.path.basename(file)] = file
588+
if not self._multi_file:
589+
for pattern in patterns:
590+
matching = fnmatch.filter(files, pattern)
591+
for file in matching:
592+
filtered[os.path.basename(file)] = file
593+
else:
594+
for file in files:
595+
if file.lower() not in ["labels", ".lock", "datastore_v2.json"]:
596+
filtered[os.path.basename(file)] = file
559597
return filtered
560598

561599
def _reconcile_datastore(self):
@@ -585,23 +623,26 @@ def _add_non_existing_images(self) -> int:
585623
invalidate = 0
586624
self._init_from_datastore_file()
587625

588-
local_images = self._list_files(self._datastore.image_path(), self._extensions)
626+
local_files = self._list_files(self._datastore.image_path(), self._extensions)
589627

590-
image_ids = list(self._datastore.objects.keys())
591-
for image_file in local_images:
592-
image_id, image_ext = self._to_id(image_file)
593-
if image_id not in image_ids:
594-
logger.info(f"Adding New Image: {image_id} => {image_file}")
628+
ids = list(self._datastore.objects.keys())
629+
for file in local_files:
630+
if self._multi_file:
631+
# Directories have no extension — use the name as-is
632+
file_id = file
633+
file_ext_str = ""
634+
else:
635+
file_id, file_ext_str = self._to_id(file)
595636

596-
name = self._filename(image_id, image_ext)
597-
image_info = {
637+
if file_id not in ids:
638+
logger.info(f"Adding New Image: {file_id} => {file}")
639+
name = self._filename(file_id, file_ext_str)
640+
file_info = {
598641
"ts": int(time.time()),
599-
# "checksum": file_checksum(os.path.join(self._datastore.image_path(), name)),
600642
"name": name,
601643
}
602-
603644
invalidate += 1
604-
self._datastore.objects[image_id] = ImageLabelModel(image=DataModel(info=image_info, ext=image_ext))
645+
self._datastore.objects[file_id] = ImageLabelModel(name=DataModel(info=file_info, ext=file_ext_str))
605646

606647
return invalidate
607648

monailabel/endpoints/datastore.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def add_image(
6868
logger.info(f"Image: {image}; File: {file}; params: {params}")
6969
file_ext = "".join(pathlib.Path(file.filename).suffixes) if file.filename else ".nii.gz"
7070

71-
image_id = image if image else os.path.basename(file.filename).replace(file_ext, "")
71+
id = image if image else os.path.basename(file.filename).replace(file_ext, "")
7272
image_file = tempfile.NamedTemporaryFile(suffix=file_ext).name
7373

7474
with open(image_file, "wb") as buffer:
@@ -79,8 +79,12 @@ def add_image(
7979
save_params: Dict[str, Any] = json.loads(params) if params else {}
8080
if user:
8181
save_params["user"] = user
82-
image_id = instance.datastore().add_image(image_id, image_file, save_params)
83-
return {"image": image_id}
82+
if not instance.datastore().get_is_multi_file():
83+
image_id = instance.datastore().add_image(id, image_file, save_params)
84+
return {"image": image_id}
85+
else:
86+
directory_id = instance.datastore().add_directory(id, image_file, save_params)
87+
return {"image": directory_id}
8488

8589

8690
def remove_image(id: str, user: Optional[str] = None):

monailabel/interfaces/app.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def __init__(
9090
self.app_dir = app_dir
9191
self.studies = studies
9292
self.conf = conf if conf else {}
93-
93+
self.multichannel = conf.get("multichannel", False)
94+
self.multi_file = conf.get("multi_file", False)
95+
self.input_channels = conf.get("input_channels", False)
9496
self.name = name
9597
self.description = description
9698
self.version = version
@@ -142,6 +144,8 @@ def init_datastore(self) -> Datastore:
142144
extensions=settings.MONAI_LABEL_DATASTORE_FILE_EXT,
143145
auto_reload=settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD,
144146
read_only=settings.MONAI_LABEL_DATASTORE_READ_ONLY,
147+
multichannel=self.multichannel,
148+
multi_file=self.multi_file,
145149
)
146150

147151
def init_remote_datastore(self) -> Datastore:
@@ -277,6 +281,10 @@ def infer(self, request, datastore=None):
277281
f"Inference Task is not Initialized. There is no model '{model}' available",
278282
)
279283

284+
request["multi_file"] = self.multi_file
285+
request["multichannel"] = self.multichannel
286+
request["input_channels"] = self.input_channels
287+
280288
request = copy.deepcopy(request)
281289
request["description"] = task.description
282290

@@ -288,7 +296,7 @@ def infer(self, request, datastore=None):
288296
else:
289297
request["image"] = datastore.get_image_uri(request["image"])
290298

291-
if os.path.isdir(request["image"]):
299+
if os.path.isdir(request["image"]) and not self.multi_file:
292300
logger.info("Input is a Directory; Consider it as DICOM")
293301

294302
logger.debug(f"Image => {request['image']}")
@@ -419,6 +427,11 @@ def train(self, request):
419427
f"Train Task is not Initialized. There is no model '{model}' available; {request}",
420428
)
421429

430+
# 4D image support, send train task information regarding data
431+
request["multi_file"] = self.multi_file
432+
request["multichannel"] = self.multichannel
433+
request["input_channels"] = self.input_channels
434+
422435
request = copy.deepcopy(request)
423436
result = task(request, self.datastore())
424437

monailabel/interfaces/datastore.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,19 @@ def refresh(self) -> None:
201201
"""
202202
pass
203203

204+
# TODO: look into implementing for all datastore children and make abstract
205+
# @abstractmethod
206+
def add_directory(self, id: str, filename: str, info: Dict[str, Any]) -> str:
    """
    Save a directory for the given directory id and return the newly saved directory's id

    :param id: the directory id for the image; If None then base filename will be used
    :param filename: the path to the directory
    :param info: additional info for the directory
    :return: the directory id for the saved image filename
    :raises NotImplementedError: if the concrete datastore does not override this method
    """
    # Not abstract yet, so fail loudly instead of returning None where a str id is promised
    raise NotImplementedError(f"{type(self).__name__} does not support adding directories")
216+
204217
@abstractmethod
205218
def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str:
206219
"""
@@ -279,3 +292,19 @@ def json(self):
279292
Return json representation of datastore
280293
"""
281294
pass
295+
296+
# TODO: look into implementing for all datastore children and make abstract
297+
# @abstractmethod
298+
def get_is_multichannel(self) -> bool:
    """
    Returns whether the application's studies are directed at multichannel (4D) data

    :return: False unless the concrete datastore overrides this method
    """
    # Return an explicit bool rather than the implicit None of a bare `pass`
    return False
303+
304+
# TODO: look into implementing for all datastore children and make abstract
305+
# @abstractmethod
306+
def get_is_multi_file(self) -> bool:
    """
    Returns whether the application's studies are directed at directories containing multiple images per sample

    :return: False unless the concrete datastore overrides this method
    """
    # Return an explicit bool rather than the implicit None of a bare `pass`
    return False

monailabel/tasks/activelearning/first.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,13 @@ def __call__(self, request, datastore: Datastore):
3535
images.sort()
3636
image = images[0]
3737

38+
# If the datastore contains 4d images send the multichannel flag to ensure images are loaded as sequences
39+
if datastore.get_is_multichannel():
40+
return {"id": image, "multichannel": True}
41+
42+
# If the datastore is multi_file, each sample has a directory with multiple images
43+
if datastore.get_is_multi_file():
44+
return {"id": image, "multi_file": True}
45+
3846
logger.info(f"First: Selected Image: {image}")
3947
return {"id": image}

monailabel/tasks/activelearning/random.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,17 @@ def __call__(self, request, datastore: Datastore):
4545
image = random.choices(images, weights=weights)[0]
4646
logger.debug(f"Random: Images: {images}; Weight: {weights}")
4747
logger.info(f"Random: Selected Image: {image}; Weight: {weights[0]}")
48+
49+
# If the datastore contains 4d images send the multichannel flag to ensure images are loaded as sequences
50+
if datastore.get_is_multichannel():
51+
return {"id": image, "weight": weights[0], "multichannel": True}
52+
53+
# If the datastore is multi_file, each sample has a directory with multiple images
54+
if datastore.get_is_multi_file():
55+
return {
56+
"id": image,
57+
"weight": weights[0],
58+
"multi_file": True,
59+
} # this will send the directory and we will walk it later on
60+
4861
return {"id": image, "weight": weights[0]}

monailabel/tasks/train/basic_train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def __init__(self):
8383
self.multi_gpu = False # multi gpu enabled
8484
self.local_rank = 0 # local rank in case of multi gpu
8585
self.world_size = 0 # world size in case of multi gpu
86+
self.input_channels = 1
87+
self.multi_file = False
8688

8789
self.request = None
8890
self.trainer = None
@@ -490,6 +492,9 @@ def train(self, rank, world_size, request, datalist):
490492

491493
context.run_id = request["run_id"]
492494
context.multi_gpu = request["multi_gpu"]
495+
context.multi_file = request.get("multi_file", False)
496+
context.input_channels = request.get("input_channels", 1)
497+
493498
if context.multi_gpu:
494499
os.environ["LOCAL_RANK"] = str(context.local_rank)
495500

0 commit comments

Comments
 (0)