Skip to content

Commit 65d3a17

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0a04161 commit 65d3a17

3 files changed

Lines changed: 26 additions & 15 deletions

File tree

vista3d/README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ mv model-zoo/models/vista3d vista3dbundle & rm -rf model-zoo
9696
cd vista3dbundle
9797
mkdir models
9898
# minor model weights naming conversion due to monai version change
99-
wget -O models/model.pt https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_vista3d.pt
99+
wget -O models/model.pt https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_vista3d.pt
100100
```
101-
MONAI bundle accepts multiple json config files and input arguments. The latter configs/arguments will override the previous configs/arguments if they have overlapping keys.
101+
MONAI bundle accepts multiple json config files and input arguments. The latter configs/arguments will override the previous configs/arguments if they have overlapping keys.
102102
```python
103103
# Automatic Segment everything
104104
python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz'}"
@@ -108,7 +108,7 @@ python -m monai.bundle run --config_file configs/inference.json --input_dict "{'
108108
python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz','label_prompt':[3]}"
109109
```
110110
```python
111-
# Interactive segmentation
111+
# Interactive segmentation
112112
# Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]]. Point labels can only be -1(ignore), 0(negative), 1(positive) and 2(negative for special overlapped class like tumor), 3(positive for special class). Only supporting 1 class per inference. The output 255 represents a NaN value, which means the region was not processed.
113113
python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz','points':[[128,128,16], [100,100,16]],'point_labels':[1, 0]}"
114114
```
@@ -158,7 +158,7 @@ python -m monai.bundle run --config_file="['configs/inference.json', 'configs/ba
158158
### 1.1 Overlapped classes and postprocessing with [ShapeKit](https://arxiv.org/pdf/2506.24003)
159159
VISTA3D is trained with binary segmentation, and may produce false positives due to weak false positive supervision. ShapeKit solves this problem with sophisticated postprocessing. ShapeKit requires a segmentation mask for each class. VISTA3D by default performs argmax and collapses overlapping classes. Change the `monai.apps.vista3d.transforms.VistaPostTransformd` in `inference.json` to save each class segmentation as a separate channel. Then follow [ShapeKit](https://github.com/BodyMaps/ShapeKit) codebase for processing.
160160
```json
161-
{
161+
{
162162
"_target_": "Activationsd",
163163
"sigmoid": true,
164164
"keys": "pred"
@@ -180,7 +180,7 @@ To segment everything, run
180180
```bash
181181
export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz'
182182
```
183-
To segment based on point clicks, provide `point` and `point_label`.
183+
To segment based on point clicks, provide `point` and `point_label`.
184184
```bash
185185
export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --point "[[128,128,16],[100,100,6]]" --point_label "[1,0]" --save_mask true
186186
```

vista3d/cvpr_workshop/infer_cvpr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from train_cvpr import ROI_SIZE
1818

19+
1920
def convert_clicks(alldata):
2021
# indexes = list(alldata.keys())
2122
# data = [alldata[i] for i in indexes]

vista3d/cvpr_workshop/train_cvpr.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import matplotlib.pyplot as plt
2323

2424
NUM_PATCHES_PER_IMAGE = 2
25-
ROI_SIZE= [128, 128, 128]
25+
ROI_SIZE = [128, 128, 128]
26+
2627

2728
def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs):
2829
"""
@@ -109,7 +110,7 @@ def __getitem__(self, idx):
109110
keys=["image", "label"],
110111
label_key="label",
111112
num_classes=label.max() + 1,
112-
ratios=tuple(float(i > 0) for i in range(label.max()+1)),
113+
ratios=tuple(float(i > 0) for i in range(label.max() + 1)),
113114
num_samples=NUM_PATCHES_PER_IMAGE,
114115
),
115116
monai.transforms.RandScaleIntensityd(
@@ -137,17 +138,19 @@ def __getitem__(self, idx):
137138
mode=["constant", "constant"],
138139
keys=["image", "label"],
139140
spatial_size=ROI_SIZE,
140-
)
141+
),
141142
]
142143
)
143144
data = transforms(data)
144145
return data
145146

147+
146148
import re
147149

150+
148151
def get_latest_epoch(directory):
149152
# Pattern to match filenames like 'model_epoch<number>.pth'
150-
pattern = re.compile(r'model_epoch(\d+)\.pth')
153+
pattern = re.compile(r"model_epoch(\d+)\.pth")
151154
max_epoch = -1
152155

153156
for filename in os.listdir(directory):
@@ -159,6 +162,7 @@ def get_latest_epoch(directory):
159162

160163
return max_epoch if max_epoch != -1 else None
161164

165+
162166
# Training function
163167
def train():
164168
json_file = "allset.json" # Update with your JSON file
@@ -169,7 +173,6 @@ def train():
169173
start_epoch = get_latest_epoch(checkpoint_dir)
170174
start_checkpoint = "./CPRR25_vista3D_model_final_10percent_data.pth"
171175

172-
173176
os.makedirs(checkpoint_dir, exist_ok=True)
174177
dist.init_process_group(backend="nccl")
175178
world_size = int(os.environ["WORLD_SIZE"])
@@ -189,11 +192,12 @@ def train():
189192
model.load_state_dict(pretrained_ckpt, strict=True)
190193
else:
191194
print(f"Resuming from epoch {start_epoch}")
192-
pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
193-
model.load_state_dict(pretrained_ckpt['model'], strict=True)
195+
pretrained_ckpt = torch.load(
196+
os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")
197+
)
198+
model.load_state_dict(pretrained_ckpt["model"], strict=True)
194199
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
195200

196-
197201
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05)
198202
lr_scheduler = monai.optimizers.WarmupCosineSchedule(
199203
optimizer=optimizer,
@@ -265,10 +269,16 @@ def train():
265269
if local_rank == 0:
266270
writer.add_scalar("loss", loss.item(), step)
267271
if local_rank == 0 and (epoch + 1) % save_interval == 0:
268-
checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch + 1}.pth")
272+
checkpoint_path = os.path.join(
273+
checkpoint_dir, f"model_epoch{epoch + 1}.pth"
274+
)
269275
if world_size > 1:
270276
torch.save(
271-
{"model": model.module.state_dict(), "epoch": epoch + 1, "step": step},
277+
{
278+
"model": model.module.state_dict(),
279+
"epoch": epoch + 1,
280+
"step": step,
281+
},
272282
checkpoint_path,
273283
)
274284
print(

0 commit comments

Comments
 (0)