@@ -204,6 +204,22 @@ def inference(self, img_bgr):
204204 predictions = self .forward ([inputs ])[0 ]
205205 return predictions
206206
@torch.no_grad()
def inference_batch(self, img_bgr_list):
    """Run inference on a list of BGR images in a single forward pass.

    Mirrors ``inference`` but preprocesses every image up front and pushes
    the whole batch through one ``forward`` call instead of one per image.

    Args:
        img_bgr_list: list of HxWxC uint8 images in BGR channel order
            (e.g. as read by cv2). -- assumes 3-channel HWC input; TODO confirm

    Returns:
        Whatever ``forward`` returns for the batched inputs — presumably one
        prediction dict per input image; verify against ``forward``.
    """
    batched_inputs = []
    for img_bgr in img_bgr_list:
        frame = img_bgr.copy()
        # Flip BGR -> RGB when the model was trained on RGB inputs.
        if self.input_format == "RGB":
            frame = frame[:, :, ::-1]
        h, w = frame.shape[:2]
        # Apply the same resize/transform used at training time, then
        # convert HWC uint8 -> CHW float32 tensor.
        resized = self.aug.apply_image(frame)
        tensor = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
        batched_inputs.append({"image": tensor, "height": h, "width": w})
    return self.forward(batched_inputs)
222+
207223 def forward (self , batched_inputs ) -> dict :
208224 """
209225 Forward pass of the PerspectiveFields model.
@@ -249,5 +265,8 @@ def forward(self, batched_inputs) -> dict:
249265 param ["pred_rel_cx" ] = torch .zeros_like (param ["pred_vfov" ])
250266 if "pred_rel_cy" not in param .keys ():
251267 param ["pred_rel_cy" ] = torch .zeros_like (param ["pred_vfov" ])
252- processed_results [0 ].update (param )
268+ assert len (processed_results ) == len (param ["pred_vfov" ])
269+ for i in range (len (processed_results )):
270+ param_tmp = {k : v [i ] for k , v in param .items ()}
271+ processed_results [i ].update (param_tmp )
253272 return processed_results
0 commit comments