Skip to content

Commit 82896b0

Browse files
authored
SMIT CT Lung GTV segmentation model (#108)
Submitting the SMIT CT Lung GTV segmentation model for mHub
1 parent ef742c0 commit 82896b0

24 files changed

Lines changed: 22224 additions & 0 deletions
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
general:
2+
data_base_dir: /app/data
3+
version: 1.0.0
4+
description: Default configuration for SMIT model (dicom to dicom)
5+
6+
execute:
7+
- DicomImporter
8+
- NiftiConverter
9+
- SMITRunner
10+
- DsegConverter
11+
- DataOrganizer
12+
13+
modules:
14+
DicomImporter:
15+
source_dir: input_data
16+
import_dir: sorted_data
17+
sort_data: true
18+
meta:
19+
mod: '%Modality'
20+
21+
NiftiConverter:
22+
engine: dcm2niix
23+
24+
DsegConverter:
25+
model_name: SMIT
26+
body_part_examined: CHEST
27+
source_segs: nifti:mod=seg
28+
skip_empty_slices: true
29+
30+
DataOrganizer:
31+
targets:
32+
- dicomseg:mod=seg-->[i:sid]/msk_smit_lung_gtv.seg.dcm
33+
34+
sample:
35+
input:
36+
dicom/: Folder with DICOM files of one or more CT scans.
37+
output:
38+
1.3.6.1.4.1.14519.5.2.1.7311.5101.160028252338004527274326500702/msk_smit_lung_gtv.seg.dcm: The DICOM SEG file with Lung GTV segmentation (arbitrary series ID foldername).
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
# mHub models are required to build on the shared mhubai base image.
FROM mhubai/base:latest

# Update authors label
LABEL authors="jiangj1@mskcc.org,aptea@mskcc.org,deasyj@mskcc.org,iyera@mskcc.org,locastre@mskcc.org"

# bash is required for `source` below; -o pipefail makes piped RUN commands
# fail on the first error instead of masking it.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]

# NOTE: the original standalone `RUN apt update` was removed — no OS packages
# are installed beyond the base image, so it only baked a stale package index
# into a layer (hadolint DL3009/DL3027).

# Import the mHub model definition (config, src, weight-fetch script).
ARG MHUB_MODELS_REPO
ENV MODEL_NAME=msk_smit_lung_gtv
RUN buildutils/import_mhub_model.sh ${MODEL_NAME} ${MHUB_MODELS_REPO}

# Download and unpack the model weights into the model source tree.
RUN source /app/models/msk_smit_lung_gtv/src/get_weights.sh

# SMIT requires Python 3.9 with torch 1.12.1 (cu116); install into a dedicated
# uv-managed virtual environment so the base image's Python is untouched.
RUN uv venv --python-preference only-managed -p 3.9 .venv39
RUN uv pip install -n -p .venv39 --extra-index-url https://download.pytorch.org/whl/cu116 torch==1.12.1+cu116
RUN uv pip install -n -p .venv39 \
      Cmake \
      einops==0.8.1 \
      imagecodecs \
      ml-collections==0.1.1 \
      monai==0.8.0 \
      nibabel==4.0.2 \
      pytorch-ignite==0.4.8 \
      scikit-image==0.19.3 \
      simpleitk==2.2.1 \
      timm==0.6.11 \
      torchaudio==0.12.1
# Pin numpy last so it overrides any newer version pulled in transitively.
RUN uv pip install -n -p .venv39 numpy==1.23.4

ENTRYPOINT ["mhub.run"]
CMD ["--config", "/app/models/msk_smit_lung_gtv/config/default.yml"]

models/msk_smit_lung_gtv/meta.json

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
{
2+
"id": "",
3+
"name": "msk_smit_lung_gtv",
4+
"title": "Self-supervised 3D segmentation using self-distilled masked image transformer for Lung GTV Segmentation",
5+
"summary": {
6+
"description": "A Lung GTV segmentation model, fine-tuned from a foundation model pretrained with 10K CT scans",
7+
"inputs": [
8+
{
9+
"label": "Input Image",
10+
"description": "The CT scan of a patient.",
11+
"format": "NIFTI",
12+
"modality": "CT",
13+
"bodypartexamined": "Chest",
14+
"slicethickness": "5mm",
15+
"contrast": true,
16+
"non-contrast": true
17+
}
18+
],
19+
"outputs": [
20+
{
21+
"label": "Segmentation",
22+
"description": "Segmentation of the lung GTV for input CT images.",
23+
"type": "Segmentation",
24+
"classes": [
25+
"LUNG+NEOPLASM_MALIGNANT_PRIMARY"
26+
]
27+
}
28+
],
29+
"model": {
30+
"architecture": "Swin3D Transformer",
31+
"training": "supervised",
32+
"cmpapproach": "3D"
33+
},
34+
"data": {
35+
"training": {
36+
"vol_samples": 377
37+
},
38+
"evaluation": {
39+
"vol_samples": 139
40+
},
41+
"public": true,
42+
"external": false
43+
}
44+
},
45+
"details": {
46+
"name": "Self-supervised 3D anatomy segmentation using self-distilled masked image transformer (SMIT)",
47+
"version": "1.0.0",
48+
"devteam": "",
49+
"authors": ["Jue Jiang, Harini Veeraraghavan"],
50+
"type": "it is a 3D Swin transformer based segmentation net, which was pretrained with 10K CT data and then finetuned for Lung GTV Segmentation",
51+
"date": {
52+
"code": "11.03.2025",
53+
"weights": "11.03.2025",
54+
"pub": "15.07.2024"
55+
},
56+
"cite": "Jiang, Jue, and Harini Veeraraghavan. Self-supervised pretraining in the wild imparts image acquisition robustness to medical image transformers: an application to lung cancer segmentation. Proceedings of machine learning research 250 (2024): 708.",
57+
"license": {
58+
"code": "GNU General Public License",
59+
"weights": "GNU General Public License"
60+
},
61+
"publications": [
62+
{
63+
"title": "Self-supervised pretraining in the wild imparts image acquisition robustness to medical image transformers: an application to lung cancer segmentation",
64+
"uri": "https://openreview.net/pdf?id=G9Te2IevNm"
65+
},
66+
{
67+
"title":"Self-supervised 3D anatomy segmentation using self-distilled masked image transformer (SMIT)",
68+
"uri":"https://link.springer.com/chapter/10.1007/978-3-031-16440-8_53"
69+
}
70+
],
71+
"github": "https://github.com/The-Veeraraghavan-Lab/CTRobust_Transformers.git"
72+
},
73+
"info": {
74+
"use": {
75+
"title": "Intended use",
76+
"text": "This model is intended to be used on CT images (with or without contrast)",
77+
"references": [],
78+
"tables": []
79+
80+
},
81+
"evaluation": {
82+
"title": "Evaluation data",
83+
"text": "To assess the model's segmentation performance in the NSCLC Radiogenomics dataset, we considered that the original input data is a full 3D volume. The model segmented not only the labeled tumor but also tumors that were not manually annotated. Therefore, we evaluated the model based on the manually labeled tumors. After applying the segmentation model, we extracted a 128*128*128 cubic region containing the manual segmentation to assess the model’s performance.",
84+
"references": [],
85+
"tables": [],
86+
"limitations": "The model might produce minor false positives, but these could be easily removed by post-processing, such as constraining the tumor segmentation to lung slices only."
87+
},
88+
"training": {
89+
"title": "Training data",
90+
"text": "Training data was from 377 data in the TCIA NSCLC-Radiomics data, references: Aerts, H. J. W. L., Wee, L., Rios Velazquez, E., Leijenaar, R. T. H., Parmar, C., Grossmann, P., Carvalho, S., Bussink, J., Monshouwer, R., Haibe-Kains, B., Rietveld, D., Hoebers, F., Rietbergen, M. M., Leemans, C. R., Dekker, A., Quackenbush, J., Gillies, R. J., Lambin, P. (2014). Data From NSCLC-Radiomics (version 4) [Data set]. The Cancer Imaging Archive."
91+
92+
},
93+
"analyses": {
94+
"title": "Evaluation",
95+
"text": "Evaluation was determined with the DICE score. See the paper (Methods, Section 4.2, section on Experiments and evaluation metrics; Results, Section 5.1, Table 2) for additional details.",
96+
"references": [
97+
{
98+
"label": "Self-supervised pretraining in the wild imparts image acquisition robustness to medical image transformers: an application to lung cancer segmentation",
99+
"uri": "https://proceedings.mlr.press/v250/jiang24b.html"
100+
}
101+
],
102+
"tables": [
103+
{
104+
"label": "Dice scores",
105+
"entries": {
106+
"From Scratch": "0.54 ±0.31",
107+
"This model": "0.69 ±0.18"
108+
}
109+
}
110+
111+
]
112+
},
113+
"limitations": {
114+
"title": "Limitations",
115+
"text": "The model might produce minor false positives, but these could be easily removed by post-processing, such as constraining the tumor segmentation to lung slices only."
116+
}
117+
}
118+
}

models/msk_smit_lung_gtv/mhub.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# mHub deployment metadata — presumably the reference test data consumed by
# the mHub model test workflow (TODO confirm against mHub tooling docs).
[model.deployment]
test = "https://zenodo.org/records/15270887/files/msk_smit_lung_gtv.test.zip"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
## References
2+
[1] Jiang, Jue, and Harini Veeraraghavan. "Self-supervised pretraining in the wild imparts image acquisition robustness to medical image transformers: an application to lung cancer segmentation." In Medical Imaging with Deep Learning. 2024.
3+
4+
[2] Jiang, Jue, Neelam Tyagi, Kathryn Tringale, Christopher Crane, and Harini Veeraraghavan. "Self-supervised 3D anatomy segmentation using self-distilled masked image transformer (SMIT)." In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 556-566. Cham: Springer Nature Switzerland, 2022.
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Any, Callable, List, Sequence, Tuple, Union
13+
14+
import torch
15+
import torch.nn.functional as F
16+
17+
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
18+
from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option, optional_import
19+
20+
import time
21+
22+
tqdm, _ = optional_import("tqdm", name="tqdm")
23+
24+
__all__ = ["sliding_window_inference"]
25+
26+
27+
def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., torch.Tensor],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    *args: Any,
    **kwargs: Any,
) -> torch.Tensor:
    """
    Sliding window inference on `inputs` with `predictor`.

    When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor `patch_data` in shape NCHW[D], `predictor(patch_data)`
            should return a prediction with the same spatial shape and batch_size, i.e. NMHW[D];
            where HW[D] represents the patch spatial size, M is the number of output channels, N is `sw_batch_size`.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/nn.functional.html#pad
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.
        - NOTE(review): unlike upstream MONAI, this version accumulates the stitched
          output and count buffers on the CPU (see the ``device='cpu'`` allocations
          below), so the returned tensor is on the CPU regardless of ``device``;
          ``device`` here only affects where ``predictor``'s output lands before
          being moved to CPU.

    """
    # Number of spatial dims, excluding the leading batch (N) and channel (C) dims.
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise AssertionError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    image_size_ = list(inputs.shape[2:])
    batch_size = inputs.shape[0]

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    # Replace non-positive/None roi_size components with the image dimension.
    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    # F.pad expects pad amounts for the LAST dim first, hence the reversed loop.
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)

    # Step between adjacent windows along each spatial dim, derived from overlap.
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    importance_map = compute_importance_map(
        get_valid_patch_size(image_size, roi_size), mode=mode, sigma_scale=sigma_scale, device=device
    )
    # Blending weights are moved to CPU because accumulation below uses CPU buffers.
    importance_map=importance_map.cpu()
    # Perform predictions
    # Placeholder scalars; the real buffers are allocated lazily on the first
    # iteration, once the predictor's channel count (M) is known.
    output_image, count_map = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)
    _initialized = False
    for slice_g in range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        # Flat window index -> [batch slice, all channels, spatial window slices].
        unravel_slice = [
            [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
        seg_prob = predictor(window_data, *args, **kwargs).to(device)  # batched patch segmentation

        if not _initialized:  # init. buffer at the first iteration
            output_classes = seg_prob.shape[1]
            output_shape = [batch_size, output_classes] + list(image_size)
            # allocate memory to store the full output and the count for overlapping parts
            #output_image = torch.zeros(output_shape, dtype=torch.float32, device=device)
            #count_map = torch.zeros(output_shape, dtype=torch.float32, device=device)

            # Full-volume buffers are kept on the CPU to bound GPU memory use.
            output_image = torch.zeros(output_shape, dtype=torch.float32, device='cpu')
            count_map = torch.zeros(output_shape, dtype=torch.float32, device='cpu')

            _initialized = True

        # store the result in the proper location of the full output. Apply weights from importance map.
        for idx, original_idx in zip(slice_range, unravel_slice):
            output_image[original_idx] += importance_map * seg_prob[idx - slice_g].cpu()
            count_map[original_idx] += importance_map

    # account for any overlapping sections
    output_image = output_image / count_map

    # Crop away the padding added above so the output matches the input's
    # original spatial size (slices are built last-dim-first, then reversed).
    final_slicing: List[slice] = []
    for sp in range(num_spatial_dims):
        slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
        final_slicing.insert(0, slice_dim)
    while len(final_slicing) < len(output_image.shape):
        final_slicing.insert(0, slice(None))
    return output_image[final_slicing]
163+
164+
165+
def _get_scan_interval(
166+
image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
167+
) -> Tuple[int, ...]:
168+
"""
169+
Compute scan interval according to the image size, roi size and overlap.
170+
Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
171+
use 1 instead to make sure sliding window works.
172+
173+
"""
174+
if len(image_size) != num_spatial_dims:
175+
raise ValueError("image coord different from spatial dims.")
176+
if len(roi_size) != num_spatial_dims:
177+
raise ValueError("roi coord different from spatial dims.")
178+
179+
scan_interval = []
180+
for i in range(num_spatial_dims):
181+
if roi_size[i] == image_size[i]:
182+
scan_interval.append(int(roi_size[i]))
183+
else:
184+
interval = int(roi_size[i] * (1 - overlap))
185+
scan_interval.append(interval if interval > 0 else 1)
186+
return tuple(scan_interval)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#!/bin/bash
#
# Download and unpack the msk_smit_lung_gtv model weights.
#
# Usage: get_weights.sh [DEST]
#   DEST  installation prefix (defaults to /app); the archive is extracted
#         into $DEST/models/msk_smit_lung_gtv/src.
#
# The download URL is stored gzip-compressed and base64-encoded in
# WEIGHTS_HASH and decoded at run time.

# Default the destination prefix when no argument (or an empty one) is given.
DEST="${1:-/app}"

MODEL_NAME=msk_smit_lung_gtv
WEIGHTS_HASH=H4sIADC3/mcAAwXByRGAIAwAwL/FkJFDwW4wqDCAYQwPtHp3Y++NN4DKGVHsNARSBY7+OQJw9z0hUDG30lnyG42S2UZnSwrz6tTQJy1NXN/0A15deWNIAAAA

# Decode the URL; $(...) instead of legacy backticks, quoted to be safe.
WEIGHTS_URL="$(base64 -d <<<"${WEIGHTS_HASH}" | gunzip)"

# Chain with && so a failed download never feeds a truncated archive to tar,
# and the archive is only removed after a successful extraction.
wget "${WEIGHTS_URL}" -O weights.tar.gz \
  && tar xvf weights.tar.gz -C "${DEST}/models/${MODEL_NAME}/src" \
  && rm weights.tar.gz

0 commit comments

Comments
 (0)