Skip to content

Commit 6e3201f

Browse files
committed
Initial profiling/export
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
1 parent 6448762 commit 6e3201f

5 files changed

Lines changed: 379 additions & 6 deletions

File tree

scripts/export.bash

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# python3 -m scripts.export --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz'
2+
3+
python3 -m scripts.export --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt [1] --save_mask true

scripts/export.py

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import os
14+
import sys
15+
from functools import partial
16+
17+
import monai
18+
import numpy as np
19+
import torch
20+
import torch.distributed as dist
21+
from monai import transforms
22+
from monai.apps.auto3dseg.auto_runner import logger
23+
from monai.auto3dseg.utils import datafold_read
24+
from monai.bundle import ConfigParser
25+
from monai.bundle.scripts import _pop_args, _update_args
26+
from monai.data import decollate_batch, list_data_collate, partition_dataset
27+
from monai.utils import optional_import
28+
29+
from vista3d import vista_model_registry
30+
31+
from .sliding_window import point_based_window_inferer, sliding_window_inference
32+
from .train import CONFIG
33+
from .utils.trans_utils import VistaPostTransform
34+
35+
rearrange, _ = optional_import("einops", name="rearrange")
36+
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
37+
IGNORE_PROMPT = set(
38+
[
39+
2, # kidney
40+
16, # prostate or uterus
41+
18, # rectum
42+
20, # lung
43+
21, # bone
44+
23, # lung tumor
45+
24, # pancreatic tumor
46+
25, # hepatic vessel
47+
26, # hepatic tumor
48+
27, # colon cancer primaries
49+
128, # bone lesion
50+
129, # kidney mass
51+
130, # liver tumor
52+
131, # vertebrae L6
53+
132,
54+
]
55+
) # airway
56+
EVERYTHING_PROMPT = list(set([i + 1 for i in range(133)]) - IGNORE_PROMPT)
57+
58+
59+
def infer_wrapper(inputs, model, **kwargs):
60+
outputs = model(input_images=inputs, **kwargs)
61+
return outputs.transpose(1, 0)
62+
63+
64+
class InferClass:
65+
def __init__(self, config_file="./configs/infer.yaml", **override):
66+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
67+
68+
_args = _update_args(config_file=config_file, **override)
69+
config_file_ = _pop_args(_args, "config_file")[0]
70+
71+
parser = ConfigParser()
72+
parser.read_config(config_file_)
73+
parser.update(pairs=_args)
74+
75+
# We do not use AMP for export
76+
self.amp = False # parser.get_parsed_content("amp")
77+
input_channels = parser.get_parsed_content("input_channels")
78+
patch_size = parser.get_parsed_content("patch_size")
79+
self.patch_size = patch_size
80+
81+
ckpt_name = parser.get_parsed_content("infer")["ckpt_name"]
82+
output_path = parser.get_parsed_content("infer")["output_path"]
83+
if not os.path.exists(output_path):
84+
os.makedirs(output_path, exist_ok=True)
85+
86+
CONFIG["handlers"]["file"]["filename"] = parser.get_parsed_content("infer")[
87+
"log_output_file"
88+
]
89+
logging.config.dictConfig(CONFIG)
90+
self.infer_transforms = parser.get_parsed_content("transforms_infer")
91+
92+
self.device = torch.device("cuda:0")
93+
model_registry = parser.get_parsed_content("model")
94+
model = vista_model_registry[model_registry](
95+
in_channels=input_channels, image_size=patch_size
96+
)
97+
self.model = model.to(self.device)
98+
99+
pretrained_ckpt = torch.load(ckpt_name, map_location=self.device)
100+
self.model.load_state_dict(pretrained_ckpt, strict=False)
101+
logger.debug(f"[debug] checkpoint {ckpt_name:s} loaded")
102+
post_transforms = [
103+
VistaPostTransform(keys="pred"),
104+
transforms.Invertd(
105+
keys="pred",
106+
transform=self.infer_transforms,
107+
orig_keys="image",
108+
meta_keys="pred_meta_dict",
109+
orig_meta_keys="image_meta_dict",
110+
meta_key_postfix="meta_dict",
111+
nearest_interp=True,
112+
to_tensor=True,
113+
),
114+
]
115+
116+
# For Vista3d, sigmoid is always used, but for visualization, argmax is needed
117+
save_transforms = [
118+
transforms.SaveImaged(
119+
keys="pred",
120+
meta_keys="pred_meta_dict",
121+
output_dir=output_path,
122+
output_postfix="seg",
123+
resample=False,
124+
data_root_dir=None,
125+
print_log=False,
126+
)
127+
]
128+
self.post_transforms = transforms.Compose(post_transforms)
129+
self.save_transforms = transforms.Compose(save_transforms)
130+
self.prev_mask = None
131+
self.batch_data = None
132+
return
133+
134+
def clear_cache(self):
135+
self.prev_mask = None
136+
self.batch_data = None
137+
138+
def transform_points(self, point, affine):
139+
"""transform point to the coordinates of the transformed image
140+
point: numpy array [bs, N, 3]
141+
"""
142+
bs, N = point.shape[:2]
143+
point = np.concatenate((point, np.ones((bs, N, 1))), axis=-1)
144+
point = rearrange(point, "b n d -> d (b n)")
145+
point = affine @ point
146+
point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
147+
return point
148+
149+
@torch.no_grad()
150+
def infer(
151+
self,
152+
image_file,
153+
point=None,
154+
point_label=None,
155+
label_prompt=None,
156+
prompt_class=None,
157+
save_mask=False,
158+
point_start=0,
159+
):
160+
"""Infer a single image_file. If save_mask is true, save the argmax prediction to disk. If false,
161+
do not save and return the probability maps (usually used by autorunner emsembler). point_start is
162+
used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save
163+
time and avoid repeated inference. This is by default disabled.
164+
"""
165+
self.model.eval()
166+
if not isinstance(image_file, dict):
167+
image_file = {"image": image_file}
168+
if self.batch_data is not None:
169+
batch_data = self.batch_data
170+
else:
171+
batch_data = self.infer_transforms(image_file)
172+
batch_data["label_prompt"] = label_prompt
173+
batch_data = list_data_collate([batch_data])
174+
self.batch_data = batch_data
175+
if point is not None:
176+
point = self.transform_points(
177+
point,
178+
np.linalg.inv(batch_data["image"].affine[0])
179+
@ batch_data["image"].meta["original_affine"][0].numpy(),
180+
)
181+
self.sliding_window_inferer = partial(
182+
point_based_window_inferer, point_start=point_start
183+
)
184+
else:
185+
self.sliding_window_inferer = sliding_window_inference
186+
device_list_input = [self.device, self.device, "cpu"]
187+
device_list_output = [self.device, "cpu", "cpu"]
188+
for _device_in, _device_out in zip(device_list_input, device_list_output):
189+
try:
190+
with torch.cuda.amp.autocast(enabled=self.amp):
191+
batch_data["pred"] = self.sliding_window_inferer(
192+
inputs=batch_data["image"].to(_device_in),
193+
roi_size=self.patch_size,
194+
sw_batch_size=1,
195+
predictor=partial(infer_wrapper, model=self.model),
196+
mode="gaussian",
197+
overlap=0.625,
198+
progress=True,
199+
sw_device=self.device,
200+
device=_device_out,
201+
point_coords=(
202+
torch.tensor(point).to(_device_in)
203+
if point is not None
204+
else None
205+
),
206+
point_labels=(
207+
torch.tensor(point_label).to(_device_in)
208+
if point_label is not None
209+
else None
210+
),
211+
class_vector=(
212+
torch.tensor(label_prompt).to(_device_in)
213+
if label_prompt is not None
214+
else None
215+
),
216+
prompt_class=(
217+
torch.tensor(prompt_class).to(_device_in)
218+
if prompt_class is not None
219+
else None
220+
),
221+
prev_mask=(
222+
torch.tensor(self.prev_mask).to(_device_in)
223+
if self.prev_mask is not None
224+
else None
225+
),
226+
)
227+
228+
if not hasattr(batch_data["pred"], "meta"):
229+
batch_data["pred"] = monai.data.MetaTensor(
230+
batch_data["pred"],
231+
affine=batch_data["image"].meta["affine"],
232+
meta=batch_data["image"].meta,
233+
)
234+
self.prev_mask = batch_data["pred"]
235+
batch_data["image"] = batch_data["image"].to("cpu")
236+
batch_data["pred"] = batch_data["pred"].to("cpu")
237+
torch.cuda.empty_cache()
238+
batch_data = [
239+
self.post_transforms(i) for i in decollate_batch(batch_data)
240+
]
241+
if save_mask:
242+
batch_data = [self.save_transforms(i) for i in batch_data]
243+
244+
finished = True
245+
except RuntimeError as e:
246+
if not any(x in str(e).lower() for x in ("memory", "cuda", "cudnn")):
247+
raise e
248+
finished = False
249+
if finished:
250+
break
251+
if not finished:
252+
raise RuntimeError("Infer not finished due to OOM.")
253+
return batch_data[0]["pred"]
254+
255+
@torch.no_grad()
256+
def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
257+
self.model.eval()
258+
device = f"cuda:{rank}"
259+
if not isinstance(image_file, dict):
260+
image_file = {"image": image_file}
261+
batch_data = self.infer_transforms(image_file)
262+
batch_data["label_prompt"] = label_prompt
263+
batch_data = list_data_collate([batch_data])
264+
device_list_input = [device, device, "cpu"]
265+
device_list_output = [device, "cpu", "cpu"]
266+
for _device_in, _device_out in zip(device_list_input, device_list_output):
267+
try:
268+
with torch.cuda.amp.autocast(enabled=self.amp):
269+
batch_data["pred"] = sliding_window_inference(
270+
inputs=batch_data["image"].to(_device_in),
271+
roi_size=self.patch_size,
272+
sw_batch_size=1,
273+
predictor=partial(infer_wrapper, model=self.model),
274+
mode="gaussian",
275+
overlap=0.625,
276+
sw_device=device,
277+
device=_device_out,
278+
class_vector=torch.tensor(label_prompt).to(_device_in),
279+
)
280+
if not hasattr(batch_data["pred"], "meta"):
281+
batch_data["pred"] = monai.data.MetaTensor(
282+
batch_data["pred"],
283+
affine=batch_data["image"].meta["affine"],
284+
meta=batch_data["image"].meta,
285+
)
286+
torch.cuda.empty_cache()
287+
batch_data = [
288+
self.post_transforms(i) for i in decollate_batch(batch_data)
289+
]
290+
batch_data = [self.save_transforms(i) for i in batch_data]
291+
finished = True
292+
except RuntimeError as e:
293+
if not any(x in str(e).lower() for x in ("memory", "cuda", "cudnn")):
294+
raise e
295+
finished = False
296+
if finished:
297+
break
298+
if not finished:
299+
raise RuntimeError("Infer not finished due to OOM.")
300+
301+
@torch.no_grad()
302+
def batch_infer_everything(self, datalist=str, basedir=str):
303+
train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=0)
304+
train_files = [_["image"] for _ in train_files]
305+
dist.init_process_group(backend="nccl", init_method="env://")
306+
world_size = dist.get_world_size()
307+
rank = dist.get_rank()
308+
# no need to wrap model with DistributedDataParallel
309+
self.model = self.model.to(f"cuda:{rank}")
310+
infer_files = partition_dataset(
311+
data=train_files,
312+
shuffle=False,
313+
num_partitions=world_size,
314+
even_divisible=False,
315+
)[rank]
316+
self.infer(infer_files, label_prompt=EVERYTHING_PROMPT, rank=rank)
317+
318+
319+
if __name__ == "__main__":
320+
fire, _ = optional_import("fire")
321+
fire.Fire(InferClass)

scripts/utils/trans_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,8 @@ def __call__(
349349
pred += 0.5 # inplace mapping to avoid cloning pred
350350
for i in range(1, object_num + 1):
351351
frac = i + 0.5
352-
pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype)
352+
pred[pred == frac] = torch.tensor(data["label_prompt"][i - 1]).to(pred.dtype)
353+
# pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype)
353354
pred[pred == 0.5] = 0.0
354355
data[keys] = pred
355356
return data

vista3d/modeling/segresnetds.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ def _forward(self, x: torch.Tensor) -> list[torch.Tensor]:
238238

239239
if self.head_module is not None:
240240
outputs = self.head_module(outputs)
241-
242241
return outputs
243242

244243
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
@@ -464,7 +463,7 @@ def is_valid_shape(self, x):
464463

465464
def _forward(
466465
self, x: torch.Tensor, with_point, with_label
467-
) -> Union[None, torch.Tensor, list[torch.Tensor]]:
466+
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
468467
if self.preprocess is not None:
469468
x = self.preprocess(x)
470469

@@ -521,8 +520,8 @@ def _forward(
521520
return outputs, outputs_auto
522521

523522
def forward(
524-
self, x: torch.Tensor, with_point=True, with_label=True, **kwargs
525-
) -> Union[None, torch.Tensor, list[torch.Tensor]]:
523+
self, x: torch.Tensor, with_point=True, with_label=True, # **kwargs
524+
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
526525
return self._forward(x, with_point, with_label)
527526

528527
def set_auto_grad(self, auto_freeze=False, point_freeze=False):

0 commit comments

Comments
 (0)