@@ -302,15 +302,16 @@ def forward(
302302 ):
303303 out , out_auto = self .image_embeddings , None
304304 else :
305- # print(input_images.dtype)
306- self .image_encoder .encoder .build_and_save (
307- (input_images ,),
308- dynamo = False ,
309- verbose = False ,
310- fp16 = True , tf32 = True ,
311- builder_optimization_level = 5 ,
312- enable_all_tactics = True
313- )
305+ # Support for TRT wrapping
306+ if hasattr (self .image_encoder .encoder , "build_and_save" ):
307+ self .image_encoder .encoder .build_and_save (
308+ (input_images ,),
309+ dynamo = False ,
310+ verbose = False ,
311+ fp16 = True , tf32 = True ,
312+ builder_optimization_level = 5 ,
313+ enable_all_tactics = True
314+ )
314315
315316 time0 = time .time ()
316317 out , out_auto = self .image_encoder (
@@ -325,19 +326,20 @@ def forward(
325326 # force releasing memories that set to None
326327 torch .cuda .empty_cache ()
327328 if class_vector is not None :
328- self .class_head .build_and_save (
329- (out_auto , class_vector ,),
330- fp16 = True , tf32 = True ,
331- dynamo = False ,
332- verbose = False ,
333- )
334- time2 = time .time ()
329+ if hasattr (self .class_head , "build_and_save" ):
330+ self .class_head .build_and_save (
331+ (out_auto , class_vector ,),
332+ fp16 = True , tf32 = True ,
333+ dynamo = False ,
334+ verbose = False ,
335+ )
336+ # time2 = time.time()
335337 logits , _ = self .class_head (src = out_auto , class_vector = class_vector )
336338 # torch.cuda.synchronize()
337339 # print(f"Class Head Time: {time.time() - time2}")
338340
339341 if point_coords is not None :
340- time3 = time .time ()
342+ # time3 = time.time()
341343 point_logits = self .point_head (
342344 out , point_coords , point_labels , class_vector = prompt_class
343345 )
@@ -376,8 +378,8 @@ def forward(
376378 mapping_index ,
377379 )
378380
379- torch .cuda .synchronize ()
380- # print(f"Head time: {time.time() - time1}, total time : {time.time() - time00} shape : {logits.shape}")
381+ # torch.cuda.synchronize()
382+ # print(f"Total time : {time.time() - time00} shape : {logits.shape}")
381383
382384 if kwargs .get ("keep_cache" , False ) and class_vector is None :
383385 self .image_embeddings = out .detach ()
0 commit comments