Commit 588b299

Update readme

Signed-off-by: heyufan1995 <heyufan1995@gmail.com>

1 parent 7cfd7d9 commit 588b299
2 files changed: 5 additions & 2 deletions

vista3d/cvpr_workshop/README.md (0 additions & 2 deletions)
````diff
@@ -15,7 +15,6 @@ limitations under the License.
 This repository is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation" ([link](https://www.codabench.org/competitions/5263/)) challenge. It
 is based on MONAI 1.4. Many of the functions in the main VISTA3D repository are moved to MONAI 1.4 and this simplified folder will directly use components from MONAI.
 
-
 It is simplified to train interactive segmentation models across different modalities. The sophisticated transforms and recipes used for VISTA3D are removed. The finetuned VISTA3D checkpoint on the challenge subsets is available [here](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing)
 
 # Setup
@@ -38,4 +37,3 @@ docker save -o vista3d.tar.gz vista3d:latest
 ```
 
 
-
````

vista3d/cvpr_workshop/train_cvpr.py (5 additions & 0 deletions)
````diff
@@ -104,12 +104,15 @@ def __getitem__(self, idx):
         return data
 # Training function
 def train():
+    json_file = "subset.json" # Update with your JSON file
     json_file = "subset.json" # Update with your JSON file
     epoch_number = 100
     start_epoch = 0
     lr = 2e-5
     checkpoint_dir = "checkpoints"
     start_checkpoint = '/workspace/CPRR25_vista3D_model_final_10percent_data.pth'
+    start_checkpoint = '/workspace/CPRR25_vista3D_model_final_10percent_data.pth'
+
     os.makedirs(checkpoint_dir, exist_ok=True)
     dist.init_process_group(backend="nccl")
     world_size = int(os.environ["WORLD_SIZE"])
@@ -122,6 +125,8 @@ def train():
     model = vista3d132(in_channels=1).to(device)
     pretrained_ckpt = torch.load(start_checkpoint, map_location=device)
     # pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
+    pretrained_ckpt = torch.load(start_checkpoint, map_location=device)
+    # pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
     model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
     model.load_state_dict(pretrained_ckpt['model'], strict=True)
     optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05)
````
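The resume path touched by the second hunk is the standard PyTorch pattern: `torch.load` the checkpoint onto the target device, then restore it with `load_state_dict(..., strict=True)` so any key mismatch fails loudly. A minimal sketch of that round trip, with a toy `nn.Linear` standing in for `vista3d132` and the DDP wrapping omitted (both are assumptions for illustration, not the repository's actual setup):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy stand-in for vista3d132; the real model comes from MONAI.
model = nn.Linear(4, 2)

# Save a checkpoint in the same {'model': state_dict} layout train_cvpr.py expects.
ckpt_path = os.path.join(tempfile.mkdtemp(), "model_epoch0.pth")
torch.save({"model": model.state_dict()}, ckpt_path)

# Resume: map tensors onto the current device, then restore weights strictly.
pretrained_ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(pretrained_ckpt["model"], strict=True)
```

Note that in the multi-GPU script the `map_location=device` argument keeps each rank from loading the checkpoint onto GPU 0.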
