Commit d854bc5

Working TRT wrappers for encoder and class head
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Parent: a1da5e2

7 files changed: 1017 additions & 47 deletions

scripts/export.bash

Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
-# python3 -m scripts.export --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz'
+python3 -m scripts.export --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz'
 
-python3 -m scripts.export --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt [1] --save_mask true
+# python3 -m scripts.export --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt [1] --save_mask true
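The `- infer_everything` syntax in these commands comes from Python Fire, which export.py uses as its entry point (`fire.Fire(InferClass)` at the bottom of the file): flags before the `-` separator go to the constructor, and the token after it names the method to call. A minimal sketch of that mapping with a stubbed-down class (the class and method names here are illustrative, not from this commit):

    # demo.py: `python3 demo.py --config_file cfg.yaml - infer --image_file img.nii.gz`
    # constructs Demo(config_file="cfg.yaml") and then calls .infer(image_file="img.nii.gz").
    import fire

    class Demo:
        def __init__(self, config_file="./configs/infer.yaml"):
            self.config_file = config_file

        def infer(self, image_file):
            print(f"infer({image_file}) with {self.config_file}")

    if __name__ == "__main__":
        fire.Fire(Demo)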

scripts/export.py

Lines changed: 26 additions & 2 deletions
@@ -31,6 +31,8 @@
 from .sliding_window import point_based_window_inferer, sliding_window_inference
 from .train import CONFIG
 from .utils.trans_utils import VistaPostTransform
+from .utils.trt_utils import ExportWrapper, TRTWrapper
+import time
 
 rearrange, _ = optional_import("einops", name="rearrange")
 sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
@@ -60,7 +62,6 @@ def infer_wrapper(inputs, model, **kwargs):
     outputs = model(input_images=inputs, **kwargs)
     return outputs.transpose(1, 0)
 
-
 class InferClass:
     def __init__(self, config_file="./configs/infer.yaml", **override):
         logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -73,7 +74,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
         parser.update(pairs=_args)
 
         # We do not use AMP for export
-        self.amp = False  # parser.get_parsed_content("amp")
+        self.amp = parser.get_parsed_content("amp")
         input_channels = parser.get_parsed_content("input_channels")
         patch_size = parser.get_parsed_content("patch_size")
         self.patch_size = patch_size
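Re-enabling `self.amp` from the config matters downstream: an `amp` flag in this kind of inference code typically gates an autocast context around the model call. A hedged sketch of that pattern, since the actual call site is outside this diff (the model and helper here are stand-ins):

    import torch

    def run_model(model, x, amp: bool):
        # Hypothetical: gate mixed-precision inference on a config-driven flag,
        # mirroring how self.amp is usually consumed.
        with torch.cuda.amp.autocast(enabled=amp):
            return model(x)

    model = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(2, 8, device="cuda")
    out = run_model(model, x, amp=True)  # matmul runs in fp16 under autocast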
@@ -129,6 +130,17 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
         self.save_transforms = transforms.Compose(save_transforms)
         self.prev_mask = None
         self.batch_data = None
+
+        en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder,
+                                        input_names=['x'], output_names=['x_out'])
+        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, use_cuda_graph=False)
+        # self.model.image_encoder.encoder.load_engine()
+
+        cls_wrapper = ExportWrapper.wrap(self.model.class_head,
+                                         input_names=['src', 'class_vector'], output_names=['masks', 'class_embedding'])
+        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper, use_cuda_graph=False)
+        # self.model.class_head.load_engine()
+
         return
 
     def clear_cache(self):
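This hunk is the core of the commit: each submodule is wrapped once at init time, so the rest of the pipeline calls it unchanged. A generalized sketch of the pattern, assuming only the trt_utils API visible above (`ExportWrapper.wrap`, the `TRTWrapper` constructor, and the commented-out `load_engine`); the helper name is hypothetical:

    # Hypothetical helper generalizing the wrapping pattern above. Assumes
    # scripts/utils/trt_utils provides ExportWrapper.wrap(module, input_names, output_names),
    # TRTWrapper(name, wrapper, use_cuda_graph=...) and TRTWrapper.load_engine(),
    # exactly as used in this diff.
    from .utils.trt_utils import ExportWrapper, TRTWrapper

    def trt_wrap_attr(parent, attr, name, input_names, output_names, load=False):
        wrapped = ExportWrapper.wrap(getattr(parent, attr),
                                     input_names=input_names, output_names=output_names)
        trt = TRTWrapper(name, wrapped, use_cuda_graph=False)
        if load:
            trt.load_engine()  # attach a prebuilt engine, per the commented lines
        setattr(parent, attr, trt)

    # trt_wrap_attr(self.model.image_encoder, "encoder", "Encoder", ['x'], ['x_out'])
    # trt_wrap_attr(self.model, "class_head", "ClassHead",
    #               ['src', 'class_vector'], ['masks', 'class_embedding'])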
@@ -162,6 +174,7 @@ def infer(
         used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save
         time and avoid repeated inference. This is by default disabled.
         """
+        time00 = time.time()
         self.model.eval()
         if not isinstance(image_file, dict):
             image_file = {"image": image_file}
@@ -248,12 +261,15 @@ def infer(
                 finished = False
             if finished:
                 break
+        print(f"Infer Time: {time.time() - time00}")
+
         if not finished:
             raise RuntimeError("Infer not finished due to OOM.")
         return batch_data[0]["pred"]
 
     @torch.no_grad()
     def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
+        time00 = time.time()
         self.model.eval()
         device = f"cuda:{rank}"
         if not isinstance(image_file, dict):
@@ -295,6 +311,8 @@ def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
                 finished = False
             if finished:
                 break
+        print(f"InferEverything Time: {time.time() - time00}")
+
         if not finished:
             raise RuntimeError("Infer not finished due to OOM.")
 
@@ -317,5 +335,11 @@ def batch_infer_everything(self, datalist=str, basedir=str):
 
 
 if __name__ == "__main__":
+    try:
+        # import torch_onnx
+        # torch_onnx.patch_torch(error_report=True)
+        print("patch succeeded")
+    except Exception:
+        pass
     fire, _ = optional_import("fire")
     fire.Fire(InferClass)

scripts/utils/cast_utils.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+
+import torch
+
+
+def avoid_bfloat16_autocast_context():
+    """
+    If the current autocast context is bfloat16,
+    cast it to float32
+    """
+    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16:
+        return torch.cuda.amp.autocast(dtype=torch.float32)
+    else:
+        return nullcontext()
+
+
+def avoid_float16_autocast_context():
+    """
+    If the current autocast context is float16, cast it to bfloat16
+    if available (unless we're in jit) or float32
+    """
+    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            return torch.cuda.amp.autocast(dtype=torch.float32)
+        if torch.cuda.is_bf16_supported():
+            return torch.cuda.amp.autocast(dtype=torch.bfloat16)
+        else:
+            return torch.cuda.amp.autocast(dtype=torch.float32)
+    else:
+        return nullcontext()
+
+
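These helpers return context managers, so call sites can guard numerically sensitive regions without inspecting autocast state themselves. A hedged usage sketch (the LayerNorm is a stand-in; the import path follows this commit's layout):

    import torch
    from scripts.utils.cast_utils import avoid_float16_autocast_context

    layer = torch.nn.LayerNorm(16).cuda()
    x = torch.randn(4, 16, device="cuda")

    with torch.cuda.amp.autocast(dtype=torch.float16):
        with avoid_float16_autocast_context():
            # runs in bf16 (if supported) or fp32 instead of fp16
            y = layer(x)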
+def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    return x.to(dtype=to_dtype) if x.dtype == from_dtype else x
+
+
+def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    if isinstance(x, torch.Tensor):
+        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
+    else:
+        if isinstance(x, dict):
+            new_dict = {}
+            for k in x.keys():
+                new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
+            return new_dict
+        elif isinstance(x, tuple):
+            return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
+
+
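`cast_all` recurses through dicts and tuples of tensors; note that in the version above any other type (a list, a plain float, None) falls through both branches and comes back as None, so callers are expected to pass only tensors, dicts, or tuples. A quick usage sketch:

    import torch
    from scripts.utils.cast_utils import cast_all  # path per this commit

    batch = {
        "image": torch.randn(2, 3, dtype=torch.float16),
        "meta": (torch.zeros(1, dtype=torch.float16), torch.ones(1, dtype=torch.float32)),
    }
    out = cast_all(batch, from_dtype=torch.float16, to_dtype=torch.float32)
    print(out["image"].dtype)    # torch.float32 (was float16)
    print(out["meta"][1].dtype)  # torch.float32 (already float32, left unchanged)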
+class CastToFloat(torch.nn.Module):
+    def __init__(self, mod):
+        super(CastToFloat, self).__init__()
+        self.mod = mod
+
+    def forward(self, x):
+        with torch.cuda.amp.autocast(enabled=False):
+            ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
+        return ret
+
+
+class CastToFloatAll(torch.nn.Module):
+    def __init__(self, mod):
+        super(CastToFloatAll, self).__init__()
+        self.mod = mod
+
+    def forward(self, *args):
+        from_dtype = args[0].dtype
+        with torch.cuda.amp.autocast(enabled=False):
+            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+        return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
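These wrappers are the usual trick for keeping precision-sensitive submodules in fp32 during mixed-precision inference or export: swap the module for a wrapped copy and everything upstream stays unchanged. A hedged example (the choice of LayerNorm is illustrative, not taken from this commit):

    import torch
    from scripts.utils.cast_utils import CastToFloat  # path per this commit

    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.LayerNorm(16),
    ).cuda()

    # Force the LayerNorm to run in fp32 even under fp16 autocast.
    model[1] = CastToFloat(model[1])

    x = torch.randn(2, 16, device="cuda")
    with torch.cuda.amp.autocast(dtype=torch.float16):
        y = model(x)  # LayerNorm computes in fp32, output cast back to fp16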
