Skip to content

Commit 7009326

Browse files
committed
updated weights download for contrast phase algorithm
1 parent 7de895f commit 7009326

3 files changed

Lines changed: 45 additions & 5 deletions

File tree

comp2comp/contrast_phase/contrast_inf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,9 @@ def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False):
454454
text_file.write('{},{:.3f}\n'.format(phase_dict[i], y_pred_proba[i]))
455455

456456
print('Predicted phase: ' + pred_phase)
457+
print('\nProbabilities:')
457458
for i in range(len(y_pred_proba)):
458-
print('{},{:.3f}'.format(phase_dict[i], y_pred_proba[i]))
459+
print('{:<20}{:.3f}'.format(phase_dict[i], y_pred_proba[i]))
459460

460461
output_path_images = os.path.join(outputPath, "images")
461462
if not os.path.exists(output_path_images):

comp2comp/contrast_phase/contrast_phase.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
from pathlib import Path
33
from time import time
44
from typing import Union
5+
import subprocess
6+
import zipfile
57

68
from totalsegmentator.libs import (
7-
download_pretrained_weights,
9+
# download_pretrained_weights,
810
nostdout,
911
setup_nnunet,
1012
)
@@ -67,9 +69,12 @@ def run_segmentation(
6769
task_id = [251]
6870

6971
setup_nnunet()
70-
for task_id in [251]:
71-
download_pretrained_weights(task_id)
72+
# for task_id in [251]:
73+
# download_pretrained_weights(task_id)
7274

75+
# download with weight for id 251
76+
self.download_pretrained_weights_updated(task_id[0])
77+
7378
from totalsegmentator.nnunet import nnUNet_predict_image
7479

7580
with nostdout():
@@ -130,7 +135,40 @@ def run_segmentation(
130135

131136
# return seg, img
132137
return seg, img
133-
138+
139+
def download_pretrained_weights_updated(self, task_id):
140+
'''
141+
Download the weights with curl to resolve problems
142+
with downloading from Zenodo
143+
'''
144+
home_path = Path(os.environ["SCRATCH"])
145+
config_dir = home_path / ".totalsegmentator/nnunet/results/nnUNet"
146+
(config_dir / "3d_fullres").mkdir(exist_ok=True, parents=True)
147+
(config_dir / "2d").mkdir(exist_ok=True, parents=True)
148+
149+
url = "https://zenodo.org/records/6802342/files/Task251_TotalSegmentator_part1_organs_1139subj.zip?download=1"
150+
config_dir = config_dir / "3d_fullres"
151+
weights_path = config_dir / "Task251_TotalSegmentator_part1_organs_1139subj"
152+
tempfile = config_dir / "tmp_download_file.zip"
153+
154+
if not weights_path.exists():
155+
print('Downloading weights..')
156+
subprocess.run(
157+
["curl", "-L", url, "-o", tempfile],
158+
check=True
159+
)
160+
161+
print('Unzipping..')
162+
with zipfile.ZipFile(config_dir / "tmp_download_file.zip", 'r') as zip_f:
163+
zip_f.extractall(config_dir)
164+
# print(f" downloaded in {time.time()-st:.2f}s")
165+
if tempfile.exists():
166+
os.remove(tempfile)
167+
print('Done.')
168+
else:
169+
print('Weights are already downloaded')
170+
171+
134172
def convertNibToNumpy(self, TSNib, ImageNib):
135173
"""Convert nifti to numpy array.
136174

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def get_version():
6565
"dicom2nifti<2.6",
6666
"torch==2.5.1",
6767
"torchvision==0.20.1",
68+
"xgboost",
6869
],
6970
extras_require={
7071
"all": ["shapely", "psutil"],

0 commit comments

Comments
 (0)