Skip to content

Commit be8caf3

Browse files
committed
inference batch images
1 parent ee6428b commit be8caf3

3 files changed

Lines changed: 28 additions & 2 deletions

File tree

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ pf_model = PerspectiveFields(version).eval().cuda()
106106
img_bgr = cv2.imread('assets/imgs/cityscape.jpg')
107107
# inference
108108
predictions = pf_model.inference(img_bgr=img_bgr)
109+
110+
# alternatively, run inference on a batch of images
111+
predictions = pf_model.inference_batch(img_bgr_list=[img_bgr_0, img_bgr_1, img_bgr_2])
109112
```
110113
- Or checkout [Live Demo 🤗](https://huggingface.co/spaces/jinlinyi/PerspectiveFields).
111114
- Notebook to [Predict Perspective Fields](./notebooks/predict_perspective_fields.ipynb).

demo/demo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,8 @@ def resize_fix_aspect_ratio(img, field, target_width=None, target_height=None):
158158
pitch: 48.88
159159
vfov: 52.82
160160
cx: 0.00
161-
cy: 0.00""")
161+
cy: 0.00""")
162+
163+
print("Alternatively, inference a batch of images")
164+
predictions = pf_model.inference_batch(img_bgr_list=[img_bgr, img_bgr, img_bgr])
165+
breakpoint()

perspective2d/perspectivefields.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,22 @@ def inference(self, img_bgr):
204204
predictions = self.forward([inputs])[0]
205205
return predictions
206206

207+
@torch.no_grad()
208+
def inference_batch(self, img_bgr_list):
209+
input_list = []
210+
for img_bgr in img_bgr_list:
211+
original_image = img_bgr.copy()
212+
if self.input_format == "RGB":
213+
# convert BGR -> RGB when the model expects RGB input
214+
original_image = original_image[:, :, ::-1]
215+
height, width = original_image.shape[:2]
216+
image = self.aug.apply_image(original_image)
217+
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
218+
inputs = {"image": image, "height": height, "width": width}
219+
input_list.append(inputs)
220+
predictions = self.forward(input_list)
221+
return predictions
222+
207223
def forward(self, batched_inputs) -> dict:
208224
"""
209225
Forward pass of the PerspectiveFields model.
@@ -249,5 +265,8 @@ def forward(self, batched_inputs) -> dict:
249265
param["pred_rel_cx"] = torch.zeros_like(param["pred_vfov"])
250266
if "pred_rel_cy" not in param.keys():
251267
param["pred_rel_cy"] = torch.zeros_like(param["pred_vfov"])
252-
processed_results[0].update(param)
268+
assert len(processed_results) == len(param["pred_vfov"])
269+
for i in range(len(processed_results)):
270+
param_tmp = {k: v[i] for k, v in param.items()}
271+
processed_results[i].update(param_tmp)
253272
return processed_results

0 commit comments

Comments
 (0)