Skip to content

Commit 5464bd7

Browse files
committed
ADD: add BraTS segmentation sample application
BraTS sample application addition using BraTS 2020 multi-volume data Signed-off-by: Cavan Riley <cavan-riley@uiowa.edu>
1 parent 663b5a2 commit 5464bd7

5 files changed

Lines changed: 735 additions & 2 deletions

File tree

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import os
14+
from typing import Any, Dict, Optional, Union
15+
16+
import lib.infers
17+
import lib.trainers
18+
from monai.networks.nets import SegResNet
19+
from monai.utils import optional_import
20+
21+
from monailabel.interfaces.config import TaskConfig
22+
from monailabel.interfaces.tasks.infer_v2 import InferTask
23+
from monailabel.interfaces.tasks.train import TrainTask
24+
from monailabel.utils.others.generic import download_file, strtobool
25+
26+
_, has_cp = optional_import("cupy")
27+
_, has_cucim = optional_import("cucim")
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
class Segmentation(TaskConfig):
33+
def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs):
34+
"""Initializes the SegmentationBrats task."""
35+
super().init(name, model_dir, conf, planner, **kwargs)
36+
37+
# BraTS labels: 3 multi-label channels produced by ConvertToMultiChannelBasedOnBratsClassesd
38+
# Channel 0: TC - Tumor Core (label 2 OR label 3)
39+
# Channel 1: WT - Whole Tumor (label 1 OR label 2 OR label 3)
40+
# Channel 2: ET - Enhancing Tumor (label 2)
41+
self.labels = {
42+
"tumor core": 1, # Tumor Core
43+
"whole tumor": 2, # Whole Tumor
44+
"enhancing tumor": 3, # Enhancing Tumor
45+
}
46+
47+
# Model Files
48+
self.path = [
49+
os.path.join(self.model_dir, f"pretrained_{name}.pt"), # pretrained
50+
os.path.join(self.model_dir, f"{name}.pt"), # published
51+
]
52+
53+
# Download PreTrained Model (optional)
54+
if strtobool(self.conf.get("use_pretrained_model", "true")):
55+
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
56+
url = f"{url}/radiology_segmentation_segresnet_brats.pt"
57+
download_file(url, self.path[0])
58+
59+
# Spacing and ROI for BraTS (isotropic 1mm, large crop matching tutorial)
60+
self.target_spacing = (1.0, 1.0, 1.0)
61+
self.roi_size = (224, 224, 144)
62+
63+
# Number of input channels: 4 MRI modalities (FLAIR, T1, T1Gd, T2)
64+
# when multi_file=True the LoadDirectoryImagesd loader stacks them;
65+
# when multi_file=False the image file must already be a 4-channel volume.
66+
try:
67+
input_channels = int(self.conf.get("input_channels", 4))
68+
except (ValueError, TypeError):
69+
logger.warning("Could not parse input_channels, defaulting to 4")
70+
input_channels = 4
71+
72+
# Network
73+
self.network = SegResNet(
74+
blocks_down=(1, 2, 2, 4),
75+
blocks_up=(1, 1, 1),
76+
init_filters=16,
77+
in_channels=input_channels,
78+
out_channels=len(self.labels), # TC, WT, ET — sigmoid multilabel, no background channel
79+
dropout_prob=0.2,
80+
)
81+
82+
def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
83+
"""Creates the SegmentationBrats InferTask task."""
84+
task: InferTask = lib.infers.SegmentationBrats(
85+
path=self.path,
86+
network=self.network,
87+
roi_size=self.roi_size,
88+
target_spacing=self.target_spacing,
89+
labels=self.labels,
90+
preload=strtobool(self.conf.get("preload", "false")),
91+
config={"largest_cc": True if has_cp and has_cucim else False},
92+
)
93+
return task
94+
95+
def trainer(self) -> Optional[TrainTask]:
96+
"""Creates the SegmentationBrats Trainer task."""
97+
output_dir = os.path.join(self.model_dir, self.name)
98+
load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1]
99+
100+
task: TrainTask = lib.trainers.SegmentationBrats(
101+
model_dir=output_dir,
102+
network=self.network,
103+
roi_size=self.roi_size,
104+
target_spacing=self.target_spacing,
105+
load_path=load_path,
106+
publish_path=self.path[1],
107+
description="Train BraTS Segmentation Model (TC/WT/ET multilabel)",
108+
labels=self.labels,
109+
)
110+
return task
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Callable, Sequence
13+
14+
from lib.transforms.transforms import ConvertFromMultiChannelBasedOnBratsClassesd, GetCentroidsd, LoadDirectoryImagesd
15+
from monai.inferers import Inferer, SlidingWindowInferer
16+
from monai.transforms import (
17+
Activationsd,
18+
AsDiscreted,
19+
EnsureChannelFirstd,
20+
EnsureTyped,
21+
KeepLargestConnectedComponentd,
22+
LoadImaged,
23+
NormalizeIntensityd,
24+
Orientationd,
25+
Spacingd,
26+
)
27+
28+
from monailabel.interfaces.tasks.infer_v2 import InferType
29+
from monailabel.tasks.infer.basic_infer import BasicInferTask
30+
from monailabel.transform.post import Restored
31+
32+
33+
class SegmentationBrats(BasicInferTask):
34+
"""
35+
Inference Engine for BraTS brain tumour segmentation using a SegResNet.
36+
37+
The model outputs 3 channels (TC, WT, ET) with sigmoid activations — it is
38+
a multilabel task, NOT a softmax classification. Each channel is thresholded
39+
independently at 0.5 to produce binary maps.
40+
41+
Two image loading modes are supported (set via ``data["multi_file"]``):
42+
- False (default): the input image is a single 4-channel NIfTI volume.
43+
- True: ``data["image"]`` is a directory containing 4 single-
44+
modality NIfTI files; LoadDirectoryImagesd stacks them.
45+
"""
46+
47+
def __init__(
48+
self,
49+
path,
50+
network=None,
51+
target_spacing=(1.0, 1.0, 1.0),
52+
type=InferType.SEGMENTATION,
53+
labels=None,
54+
dimension=3,
55+
description="Pre-trained BraTS SegResNet — TC/WT/ET multilabel segmentation",
56+
**kwargs,
57+
):
58+
"""
59+
Args:
60+
path: path(s) to the model checkpoint(s).
61+
network: optional pre-instantiated network; if None the checkpoint
62+
is loaded directly.
63+
target_spacing: voxel spacing to resample images to before inference.
64+
type: inference type tag (default SEGMENTATION).
65+
labels: label name → integer index mapping.
66+
dimension: spatial dimension of the model (3 for volumetric).
67+
description: human-readable description surfaced in the REST API.
68+
**kwargs: forwarded to ``BasicInferTask``.
69+
"""
70+
super().__init__(
71+
path=path,
72+
network=network,
73+
type=type,
74+
labels=labels,
75+
dimension=dimension,
76+
description=description,
77+
load_strict=False,
78+
**kwargs,
79+
)
80+
self.target_spacing = target_spacing
81+
82+
def pre_transforms(self, data=None) -> Sequence[Callable]:
83+
"""
84+
Pre-processing pipeline matching the official MONAI BraTS tutorial.
85+
86+
NOTE: ScaleIntensityRangePercentilesd and CenterSpatialCropd from the
87+
original file have been removed — they are not part of the BraTS pipeline
88+
and would distort MRI intensity normalisation. NormalizeIntensityd with
89+
nonzero=True, channel_wise=True is the correct approach for multi-modal MRI.
90+
"""
91+
data = data or {}
92+
channels = data.get("input_channels", 4)
93+
t = [
94+
(
95+
LoadImaged(keys="image", reader="ITKReader", ensure_channel_first=True)
96+
if data.get("multi_file", False) is False
97+
else LoadDirectoryImagesd(
98+
keys="image",
99+
target_spacing=self.target_spacing,
100+
channels=channels,
101+
)
102+
),
103+
EnsureTyped(keys="image", device=data.get("device") if data else None),
104+
# EnsureChannelFirstd is safe to keep as a guard; if the channel dim is
105+
# already present (ITKReader + ensure_channel_first) it is a no-op.
106+
EnsureChannelFirstd(keys="image", channel_dim=0),
107+
Orientationd(keys="image", axcodes="RAS"),
108+
Spacingd(
109+
keys="image",
110+
pixdim=self.target_spacing,
111+
allow_missing_keys=True,
112+
),
113+
# Channel-wise intensity normalisation on non-zero voxels only.
114+
# This matches both the tutorial and the training pipeline exactly.
115+
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
116+
]
117+
return t
118+
119+
def inferer(self, data=None) -> Inferer:
120+
"""Return a SlidingWindowInferer configured for BraTS volumetric inference."""
121+
return SlidingWindowInferer(
122+
roi_size=self.roi_size,
123+
sw_batch_size=2,
124+
overlap=0.4,
125+
padding_mode="replicate",
126+
mode="gaussian",
127+
)
128+
129+
def inverse_transforms(self, data=None):
130+
"""No inverse transforms needed; Restored handles spatial restoration directly."""
131+
return []
132+
133+
def post_transforms(self, data=None) -> Sequence[Callable]:
134+
"""
135+
Post-processing for multilabel sigmoid output.
136+
137+
IMPORTANT differences from a softmax segmentation:
138+
- Activationsd uses sigmoid=True (not softmax=True).
139+
- AsDiscreted thresholds each channel at 0.5 independently
140+
(not argmax, because channels are not mutually exclusive).
141+
- KeepLargestConnectedComponentd is applied per-channel if available.
142+
"""
143+
data = data or {}
144+
t = [
145+
EnsureTyped(keys="pred", device=data.get("device") if data else None),
146+
# Sigmoid: each of the 3 channels (TC, WT, ET) is activated independently.
147+
Activationsd(keys="pred", sigmoid=True),
148+
# Threshold each channel at 0.5 to produce binary masks.
149+
AsDiscreted(keys="pred", threshold=0.5),
150+
]
151+
152+
if data and data.get("largest_cc", False):
153+
# Apply per-channel so TC, WT and ET are each cleaned independently.
154+
t.append(
155+
KeepLargestConnectedComponentd(
156+
keys="pred",
157+
independent=True, # treat each channel separately
158+
)
159+
)
160+
161+
t.extend(
162+
[
163+
# Merge 3 binary channels → single-channel integer label map
164+
# Must happen before Restored so spatial metadata is applied
165+
# to the final (1, H, W, D) output, not the intermediate (3, H, W, D).
166+
ConvertFromMultiChannelBasedOnBratsClassesd(keys="pred"),
167+
Restored(
168+
keys="pred",
169+
ref_image="image",
170+
config_labels=self.labels if data.get("restore_label_idx", False) else None,
171+
),
172+
GetCentroidsd(keys="pred", centroids_key="centroids"),
173+
]
174+
)
175+
return t

0 commit comments

Comments
 (0)