
Commit 818a548

Cleaned up, working TRT wrapping

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

1 parent: a82ce56

2 files changed: 25 additions & 23 deletions


scripts/export.py

Lines changed: 4 additions & 4 deletions
@@ -133,13 +133,13 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
 
         en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder,
                                         input_names = ['x'], output_names = ['x_out'])
-        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, use_cuda_graph=False)
-        # self.model.image_encoder.encoder.load_engine()
+        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper)
+        self.model.image_encoder.encoder.load_engine()
 
         cls_wrapper = ExportWrapper.wrap(self.model.class_head,
                                         input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding'])
-        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper, use_cuda_graph=False)
-        # self.model.class_head.load_engine()
+        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper)
+        self.model.class_head.load_engine()
 
         return
 
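The export-time change swaps each submodule for a TRTWrapper and now loads a prebuilt engine eagerly instead of leaving load_engine() commented out. Below is a minimal, hypothetical sketch of that wrap-and-fall-back pattern; LazyEngineWrapper and its torch.jit.trace "engine" are illustrative stand-ins, not the repository's actual TRTWrapper/ExportWrapper API:

import torch
import torch.nn as nn

class LazyEngineWrapper(nn.Module):
    """Illustrative stand-in: forwards to the original module until an
    engine has been built, mirroring the wrapper pattern used above."""

    def __init__(self, name: str, module: nn.Module):
        super().__init__()
        self.name = name
        self.module = module   # original PyTorch submodule (fallback path)
        self.engine = None     # placeholder for a compiled engine

    def load_engine(self):
        # Real code would deserialize a prebuilt TensorRT plan here; this
        # sketch just keeps the PyTorch fallback when none exists.
        print(f"[{self.name}] no prebuilt engine found, using PyTorch fallback")

    def build_and_save(self, inputs, **build_flags):
        # Stand-in for ONNX export + TensorRT build; torch.jit.trace is used
        # only to demonstrate "compile once on the first real input".
        if self.engine is None:
            self.engine = torch.jit.trace(self.module, inputs)
            print(f"[{self.name}] built engine with flags {build_flags}")

    def forward(self, *args):
        runner = self.engine if self.engine is not None else self.module
        return runner(*args)

encoder = LazyEngineWrapper("Encoder", nn.Conv3d(1, 8, kernel_size=3, padding=1))
encoder.load_engine()                    # mirrors the now-uncommented load_engine()
x = torch.randn(1, 1, 16, 16, 16)
encoder.build_and_save((x,), fp16=True)  # mirrors the guarded build in forward()
print(encoder(x).shape)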
vista3d/modeling/vista3d.py

Lines changed: 21 additions & 19 deletions
@@ -302,15 +302,16 @@ def forward(
         ):
             out, out_auto = self.image_embeddings, None
         else:
-            # print(input_images.dtype)
-            self.image_encoder.encoder.build_and_save(
-                (input_images,),
-                dynamo=False,
-                verbose=False,
-                fp16=True, tf32=True,
-                builder_optimization_level=5,
-                enable_all_tactics=True
-            )
+            # Support for TRT wrapping
+            if hasattr(self.image_encoder.encoder, "build_and_save"):
+                self.image_encoder.encoder.build_and_save(
+                    (input_images,),
+                    dynamo=False,
+                    verbose=False,
+                    fp16=True, tf32=True,
+                    builder_optimization_level=5,
+                    enable_all_tactics=True
+                )
 
             time0 = time.time()
             out, out_auto = self.image_encoder(
@@ -325,19 +326,20 @@ def forward(
             # force releasing memories that set to None
             torch.cuda.empty_cache()
         if class_vector is not None:
-            self.class_head.build_and_save(
-                (out_auto, class_vector,),
-                fp16=True, tf32=True,
-                dynamo=False,
-                verbose=False,
-            )
-            time2 = time.time()
+            if hasattr(self.class_head, "build_and_save"):
+                self.class_head.build_and_save(
+                    (out_auto, class_vector,),
+                    fp16=True, tf32=True,
+                    dynamo=False,
+                    verbose=False,
+                )
+            # time2 = time.time()
             logits, _ = self.class_head(src=out_auto, class_vector=class_vector)
             # torch.cuda.synchronize()
             # print(f"Class Head Time: {time.time() - time2}")
 
         if point_coords is not None:
-            time3 = time.time()
+            # time3 = time.time()
             point_logits = self.point_head(
                 out, point_coords, point_labels, class_vector=prompt_class
             )
@@ -376,8 +378,8 @@ def forward(
                 mapping_index,
             )
 
-        torch.cuda.synchronize()
-        # print(f"Head time: {time.time() - time1}, total time : {time.time() - time00} shape : {logits.shape}")
+        # torch.cuda.synchronize()
+        # print(f"Total time : {time.time() - time00} shape : {logits.shape}")
 
         if kwargs.get("keep_cache", False) and class_vector is None:
             self.image_embeddings = out.detach()
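The core of the vista3d.py change is the hasattr guard: forward() now calls build_and_save() only when a submodule was actually replaced by a TRT wrapper, so the same forward path runs wrapped or unwrapped. A minimal sketch of that duck-typing guard; maybe_build, PlainHead, and WrappedHead are illustrative names, not part of the repository:

import torch
import torch.nn as nn

def maybe_build(module: nn.Module, example_inputs: tuple, **build_flags) -> None:
    # The hasattr check is the whole trick: plain nn.Modules are skipped,
    # wrapped ones get their engine built lazily before the first real call.
    if hasattr(module, "build_and_save"):
        module.build_and_save(example_inputs, **build_flags)

class PlainHead(nn.Module):
    def forward(self, src):
        return src * 2

class WrappedHead(nn.Module):
    def build_and_save(self, example_inputs, **build_flags):
        print(f"building engine with flags {build_flags}")

    def forward(self, src):
        return src * 2

x = torch.ones(2)
for head in (PlainHead(), WrappedHead()):
    maybe_build(head, (x,), fp16=True, tf32=True)  # no-op for PlainHead
    print(head(x))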
