|
2 | 2 | from pathlib import Path |
3 | 3 | from time import time |
4 | 4 | from typing import Union |
| 5 | +import subprocess |
| 6 | +import zipfile |
5 | 7 |
|
6 | 8 | from totalsegmentator.libs import ( |
7 | | - download_pretrained_weights, |
| 9 | + # download_pretrained_weights, |
8 | 10 | nostdout, |
9 | 11 | setup_nnunet, |
10 | 12 | ) |
@@ -67,9 +69,12 @@ def run_segmentation( |
67 | 69 | task_id = [251] |
68 | 70 |
|
69 | 71 | setup_nnunet() |
70 | | - for task_id in [251]: |
71 | | - download_pretrained_weights(task_id) |
| 72 | + # for task_id in [251]: |
| 73 | + # download_pretrained_weights(task_id) |
72 | 74 |
|
| 75 | + # download with weight for id 251 |
| 76 | + self.download_pretrained_weights_updated(task_id[0]) |
| 77 | + |
73 | 78 | from totalsegmentator.nnunet import nnUNet_predict_image |
74 | 79 |
|
75 | 80 | with nostdout(): |
@@ -130,7 +135,40 @@ def run_segmentation( |
130 | 135 |
|
131 | 136 | # return seg, img |
132 | 137 | return seg, img |
133 | | - |
| 138 | + |
| 139 | + def download_pretrained_weights_updated(self, task_id): |
| 140 | + ''' |
| 141 | + Download the weights with curl to resolve problems |
| 142 | + with downloading from Zenodo |
| 143 | + ''' |
| 144 | + home_path = Path(os.environ["SCRATCH"]) |
| 145 | + config_dir = home_path / ".totalsegmentator/nnunet/results/nnUNet" |
| 146 | + (config_dir / "3d_fullres").mkdir(exist_ok=True, parents=True) |
| 147 | + (config_dir / "2d").mkdir(exist_ok=True, parents=True) |
| 148 | + |
| 149 | + url = "https://zenodo.org/records/6802342/files/Task251_TotalSegmentator_part1_organs_1139subj.zip?download=1" |
| 150 | + config_dir = config_dir / "3d_fullres" |
| 151 | + weights_path = config_dir / "Task251_TotalSegmentator_part1_organs_1139subj" |
| 152 | + tempfile = config_dir / "tmp_download_file.zip" |
| 153 | + |
| 154 | + if not weights_path.exists(): |
| 155 | + print('Downloading weights..') |
| 156 | + subprocess.run( |
| 157 | + ["curl", "-L", url, "-o", tempfile], |
| 158 | + check=True |
| 159 | + ) |
| 160 | + |
| 161 | + print('Unzipping..') |
| 162 | + with zipfile.ZipFile(config_dir / "tmp_download_file.zip", 'r') as zip_f: |
| 163 | + zip_f.extractall(config_dir) |
| 164 | + # print(f" downloaded in {time.time()-st:.2f}s") |
| 165 | + if tempfile.exists(): |
| 166 | + os.remove(tempfile) |
| 167 | + print('Done.') |
| 168 | + else: |
| 169 | + print('Weights are already downloaded') |
| 170 | + |
| 171 | + |
134 | 172 | def convertNibToNumpy(self, TSNib, ImageNib): |
135 | 173 | """Convert nifti to numpy array. |
136 | 174 |
|
|
0 commit comments