[Paper] [Demo] [Checkpoint]
VISTA3D is a foundation model trained systematically on 11,454 volumes covering 127 types of human anatomical structures and various lesions. It provides accurate out-of-the-box segmentation that matches state-of-the-art supervised models trained on each individual dataset. The model also achieves state-of-the-art zero-shot interactive segmentation in 3D, representing a promising step toward a versatile medical image foundation model.
For the 127 supported classes, the model performs highly accurate out-of-the-box segmentation. The fully automated process uses patch-based sliding-window inference and only requires a class prompt. Compared with supervised segmentation models trained separately on each dataset, VISTA3D shows comparable out-of-the-box performance and strong generalizability ('VISTA3D auto' in Table 1).
Interactive segmentation is based on user-provided clicks, where each click affects a local 3D patch. Users can either refine the automatic results with clicks ('VISTA3D auto+point' in Table 1) or simply provide a click without specifying the target class ('VISTA3D point' in Table 1).
VISTA3D is built to produce visually plausible segmentations on previously unseen classes. This capability makes the model even more flexible and accelerates practical segmentation data curation processes.
The VISTA3D checkpoint shows improvements when fine-tuned in few-shot settings. Once a few annotated examples are available, users can start fine-tuning from the VISTA3D checkpoint.
To perform inference locally with the debugger GUI, install the repository and its requirements:
git clone https://github.com/Project-MONAI/VISTA.git;
cd ./VISTA/vista3d;
pip install -r requirements.txt
Download the model checkpoint and save it at ./models/model.pt.
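A quick sanity check, as a minimal sketch assuming the checkpoint is a standard PyTorch file, is to load it once before running inference:

import torch

# Verify the downloaded checkpoint deserializes (its exact keys depend on how it was saved).
ckpt = torch.load("./models/model.pt", map_location="cpu")
print(type(ckpt))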
The NIM Demo (VISTA3D NVIDIA Inference Microservices) does not support uploading medical data due to legal concerns, so we provide scripts for running inference locally. The automatic segmentation label definitions can be found in label_dict.
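As a small sketch of looking up a class prompt (assuming label_dict is shipped as a JSON file under data/jsons mapping class names to global indexes, which is an assumption about its location and format):

import json

# Assumed location/format: a JSON mapping of class names to global indexes, e.g. {"liver": 1, ...}.
with open("data/jsons/label_dict.json") as f:
    label_dict = json.load(f)
print(label_dict["liver"])  # 1, the value passed to --label_prompt below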
- We provide the infer.py script and its light-weight front-end debugger.py. Users can directly launch a local interface for both automatic and interactive segmentation.
python -m scripts.debugger run
or call infer.py directly to generate an automatic segmentation. For example, to segment the liver (label_prompt=1 as defined in label_dict), run
export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt "[1]" --save_mask true
To segment everything, run
export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz'
The output path and other settings can be changed in configs/infer.yaml.
- The MONAI bundle wraps VISTA3D and provides a unified API for inference, and the NIM Demo deploys the bundle with an interactive front-end. Although the NIM Demo cannot run locally, the bundle itself can; the runtime environment requires the MONAI Docker image. The MONAI bundle is better suited to automatic segmentation in batch mode (a usage sketch follows the docker command below).
docker pull projectmonai/monai:1.3.2
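Inside that container, a hedged sketch of driving the bundle through the monai.bundle Python API (the bundle name and config path below are assumptions; consult the bundle's own documentation for the exact entry points) could look like:

from monai.bundle import download, run

# Hypothetical bundle identifier; check the MONAI model zoo for the exact name/version.
download(name="vista3d", bundle_dir="./bundles")
# Hypothetical config path; the bundle's docs define the real inference entry point.
run(config_file="./bundles/vista3d/configs/inference.json")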
Every dataset must have a JSON data list file. We provide the JSON lists for all of our training data in data/jsons; more details can be found here. For the datasets used in VISTA3D training, we already include the JSON splits, register their dataset-specific label indexes to the global indexes as label_mapping, and hard-code their data paths in ./data/datasets.py. The supported global class indexes are defined in label_dict. To generate supervoxels, refer to the instruction.
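As a minimal sketch of the expected structure (field names follow the usual MONAI data-list convention and should be checked against the provided files in data/jsons; all paths below are hypothetical), a data list for a new dataset could be written as:

import json

# Hypothetical data list following the MONAI convention: image/label pairs plus a fold index.
datalist = {
    "training": [
        {"image": "imagesTr/case_0001.nii.gz", "label": "labelsTr/case_0001.nii.gz", "fold": 0},
        {"image": "imagesTr/case_0002.nii.gz", "label": "labelsTr/case_0002.nii.gz", "fold": 1},
    ],
    "testing": [
        {"image": "imagesTs/case_0003.nii.gz", "label": "labelsTs/case_0003.nii.gz"},
    ],
}
with open("data/jsons/my_dataset.json", "w") as f:
    json.dump(datalist, f, indent=4)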
VISTA3D is trained in four stages. The configurations reflect the training procedure but may not fully reproduce the released VISTA3D weights, since each stage consists of multiple rounds with slightly varying configurations.
export CUDA_VISIBLE_DEVICES=0; python -m scripts.train run --config_file "['configs/train/hyper_parameters_stage1.yaml']"
Execute multi-GPU model training (the codebase also supports multi-node training):
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.train run --config_file "['configs/train/hyper_parameters_stage1.yaml']"
We provide code for Dice-score evaluation of the supported classes: fully automatic (val_multigpu_point_patch with the auto config), point click only (val_multigpu_point_patch with the point config), and auto + point (val_multigpu_autopoint_patch).
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.validation.val_multigpu_point_patch run --config_file "['configs/supported_eval/infer_patch_auto.yaml']" --dataset_name 'xxxx'
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.validation.val_multigpu_point_patch run --config_file "['configs/supported_eval/infer_patch_point.yaml']" --dataset_name 'xxxx'
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.validation.val_multigpu_autopoint_patch run --config_file "['configs/supported_eval/infer_patch_autopoint.yaml']" --dataset_name 'xxxx'
For zero-shot evaluation, we perform iterative point sampling. To create a new zero-shot evaluation dataset, users only need to change label_set in the evaluation config so that it matches the class indexes in the original groundtruth (a sketch of this change follows the command below).
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.validation.val_multigpu_point_iterative run --config_file "['configs/zeroshot_eval/infer_iter_point_hcc.yaml']"
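As a minimal sketch of that change (the new dataset, its groundtruth index, and the output file name are hypothetical, and label_set is assumed to be a top-level key in the config), the provided HCC config can be copied and updated programmatically:

import yaml

# Start from the provided zero-shot config and point it at a new dataset (hypothetical values).
with open("configs/zeroshot_eval/infer_iter_point_hcc.yaml") as f:
    cfg = yaml.safe_load(f)
cfg["label_set"] = [1]  # the index used for the target structure in the new groundtruth files
with open("configs/zeroshot_eval/infer_iter_point_mydata.yaml", "w") as f:
    yaml.safe_dump(cfg, f)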
For fine-tuning, users need to change label_set and mapped_label_set in the fine-tuning config, where label_set matches the index values in the groundtruth files. The mapped_label_set can be selected freely, but we recommend picking the most closely related global indexes defined in label_dict (see the sketch after the fine-tuning command). Users should adjust the transforms, resolution, patch size, etc. to their dataset for optimal fine-tuning performance; we recommend using configurations generated by auto3dseg. A learning rate of 5e-5 should be sufficient for fine-tuning.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.train_finetune run --config_file "['configs/finetune/train_finetune_word.yaml']"
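As a minimal sketch (the new dataset, its groundtruth indexes, and the chosen global indexes are hypothetical, and both fields are assumed to be top-level keys), the provided WORD config can be copied and the two label lists adapted:

import yaml

# Copy the provided fine-tuning config and adapt it to a new dataset (hypothetical values).
with open("configs/finetune/train_finetune_word.yaml") as f:
    cfg = yaml.safe_load(f)
cfg["label_set"] = [1, 2]          # index values as stored in the new groundtruth files
cfg["mapped_label_set"] = [1, 16]  # closest global indexes from label_dict (1 = liver; 16 is a placeholder)
with open("configs/finetune/train_finetune_mydata.yaml", "w") as f:
    yaml.safe_dump(cfg, f)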
We provide scripts to run SAM2 evaluation. This requires the SAM2 package to be installed, and the SAM2 source code must be modified to support background removal: add a z_slice argument to init_state in sam2_video_predictor.py, as shown below.
@torch.inference_mode()
def init_state(
    self,
    video_path,
    offload_video_to_cpu=False,
    offload_state_to_cpu=False,
    async_loading_frames=False,
    z_slice=None,
):
    """Initialize an inference state."""
    images, video_height, video_width = load_video_frames(
        video_path=video_path,
        image_size=self.image_size,
        offload_video_to_cpu=offload_video_to_cpu,
        async_loading_frames=async_loading_frames,
    )
    # Added for VISTA3D evaluation: keep only the axial slices (video frames) of interest.
    if z_slice is not None:
        images = images[z_slice]
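With this change in place, a hedged usage sketch (the config name, checkpoint path, frame directory, and slice range below are assumptions about a standard SAM2 setup) restricts inference to a sub-range of axial slices:

import torch
from sam2.build_sam import build_sam2_video_predictor

# Hypothetical SAM2 config/checkpoint; use the files from your local SAM2 installation.
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt")
with torch.inference_mode():
    # Treat the exported CT slices as a "video" and keep only slices 40-79 (background removal).
    state = predictor.init_state(video_path="./ct_slices_jpeg", z_slice=slice(40, 80))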
Run the evaluation:
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.validation.val_multigpu_sam2_point_iterative run --config_file "['configs/supported_eval/infer_sam2_point.yaml']" --saliency False --dataset_name 'Task06'
Join the conversation on Twitter @ProjectMONAI or join our Slack channel.
Ask and answer questions on MONAI VISTA's GitHub discussions tab.
The codebase is released under the Apache 2.0 License. The model weights are released under the NVIDIA OneWay Noncommercial License.
@article{he2024vista3d,
title={VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography},
author={He, Yufan and Guo, Pengfei and Tang, Yucheng and Myronenko, Andriy and Nath, Vishwesh and Xu, Ziyue and Yang, Dong and Zhao, Can and Simon, Benjamin and Belue, Mason and others},
journal={arXiv preprint arXiv:2406.05285},
year={2024}
}






