 from .sliding_window import point_based_window_inferer, sliding_window_inference
 from .train import CONFIG
 from .utils.trans_utils import VistaPostTransform
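+# TensorRT wrapper utilities; used in InferClass.__init__ below to wrap the image encoder and class head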
+from .utils.trt_utils import ExportWrapper, TRTWrapper
+import time
 
 rearrange, _ = optional_import("einops", name="rearrange")
 sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
@@ -60,7 +62,6 @@ def infer_wrapper(inputs, model, **kwargs):
     outputs = model(input_images=inputs, **kwargs)
     return outputs.transpose(1, 0)
 
-
 class InferClass:
     def __init__(self, config_file="./configs/infer.yaml", **override):
         logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -73,7 +74,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
         parser.update(pairs=_args)
 
-        # We do not use AMP for export
-        self.amp = False  # parser.get_parsed_content("amp")
+        # Read the AMP setting from the config (no longer forced off for export)
+        self.amp = parser.get_parsed_content("amp")
         input_channels = parser.get_parsed_content("input_channels")
         patch_size = parser.get_parsed_content("patch_size")
         self.patch_size = patch_size
@@ -129,6 +130,17 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
         self.save_transforms = transforms.Compose(save_transforms)
         self.prev_mask = None
         self.batch_data = None
+
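+        # Wrap the image encoder for TensorRT execution via TRTWrapper (CUDA graphs disabled);
+        # loading a prebuilt engine is left commented out below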
+        en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder,
+                                        input_names=['x'], output_names=['x_out'])
+        self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper, use_cuda_graph=False)
+        # self.model.image_encoder.encoder.load_engine()
+
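+        # Wrap the class head the same way, exposing 'masks' and 'class_embedding' as outputs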
+        cls_wrapper = ExportWrapper.wrap(self.model.class_head,
+                                         input_names=['src', 'class_vector'], output_names=['masks', 'class_embedding'])
+        self.model.class_head = TRTWrapper("ClassHead", cls_wrapper, use_cuda_graph=False)
+        # self.model.class_head.load_engine()
+
         return
 
     def clear_cache(self):
@@ -162,6 +174,7 @@ def infer(
                 used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save
                 time and avoid repeated inference. This is by default disabled.
         """
+        time00 = time.time()
         self.model.eval()
         if not isinstance(image_file, dict):
             image_file = {"image": image_file}
@@ -248,12 +261,15 @@ def infer(
                 finished = False
             if finished:
                 break
+        print(f"Infer Time: {time.time() - time00}")
+
         if not finished:
             raise RuntimeError("Infer not finished due to OOM.")
         return batch_data[0]["pred"]
 
     @torch.no_grad()
     def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
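+        # Same wall-clock timing for the everything-prompt path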
+        time00 = time.time()
         self.model.eval()
         device = f"cuda:{rank}"
         if not isinstance(image_file, dict):
@@ -295,6 +311,8 @@ def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
                 finished = False
             if finished:
                 break
+        print(f"InferEverything Time: {time.time() - time00}")
+
         if not finished:
             raise RuntimeError("Infer not finished due to OOM.")
 
@@ -317,5 +335,11 @@ def batch_infer_everything(self, datalist=str, basedir=str):
 
 
 if __name__ == "__main__":
+    try:
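+        # Optional torch_onnx patching for ONNX export, left disabled here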
+        # import torch_onnx
+        # torch_onnx.patch_torch(error_report=True)
+        print("patch succeeded")
+    except Exception:
+        pass
     fire, _ = optional_import("fire")
     fire.Fire(InferClass)