Skip to content

Commit d54be73

Browse files
committed
fix key not found; fix general_vfov_to_focal when batched input
1 parent be8caf3 commit d54be73

4 files changed

Lines changed: 15 additions & 13 deletions

File tree

demo/demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,4 @@ def resize_fix_aspect_ratio(img, field, target_width=None, target_height=None):
162162

163163
print("Alternatively, inference a batch of images")
164164
predictions = pf_model.inference_batch(img_bgr_list=[img_bgr, img_bgr, img_bgr])
165-
breakpoint()
165+
breakpoint()

perspective2d/modeling/param_network/param_network.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,13 @@ def forward(self, predictions, batched_inputs=None):
210210
param["pred_general_vfov"] = param["pred_vfov"]
211211
if "pred_rel_focal" not in param:
212212
param["pred_rel_focal"] = torch.FloatTensor(
213-
[
214-
general_vfov_to_focal(
215-
to_numpy(param["pred_rel_cx"]),
216-
to_numpy(param["pred_rel_cy"]),
217-
1,
218-
to_numpy(param["pred_general_vfov"]),
219-
degree=True,
220-
)
221-
]
213+
general_vfov_to_focal(
214+
to_numpy(param["pred_rel_cx"]),
215+
to_numpy(param["pred_rel_cy"]),
216+
1,
217+
to_numpy(param["pred_general_vfov"]),
218+
degree=True,
219+
)
222220
)
223221
return param
224222

perspective2d/perspectivefields.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def forward(self, batched_inputs) -> dict:
265265
param["pred_rel_cx"] = torch.zeros_like(param["pred_vfov"])
266266
if "pred_rel_cy" not in param.keys():
267267
param["pred_rel_cy"] = torch.zeros_like(param["pred_vfov"])
268-
assert len(processed_results) == len(param["pred_vfov"])
268+
assert len(processed_results) == len(param["pred_general_vfov"])
269269
for i in range(len(processed_results)):
270270
param_tmp = {k: v[i] for k, v in param.items()}
271271
processed_results[i].update(param_tmp)

perspective2d/utils/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,14 @@ def fun(focal, *args):
7979
q_sqr = (focal / h) ** 2 + d_cx**2 + (d_cy - 0.5) ** 2
8080
cos_FoV = (p_sqr + q_sqr - 1) / 2 / np.sqrt(p_sqr) / np.sqrt(q_sqr)
8181
return cos_FoV - target_cos_FoV
82-
8382
if degree:
8483
gvfov = np.radians(gvfov)
85-
focal = scipy.optimize.fsolve(fun, 1.5, args=(h, rel_cx, rel_cy, np.cos(gvfov)))[0]
84+
if type(rel_cx) != np.ndarray:
85+
# if input is float
86+
focal = scipy.optimize.fsolve(fun, 1.5, args=(h, rel_cx, rel_cy, np.cos(gvfov)))[0]
87+
else:
88+
# if input is numpy array
89+
focal = scipy.optimize.fsolve(fun, np.ones(len(rel_cx)) * 1.5, args=(h, rel_cx, rel_cy, np.cos(gvfov)))
8690
focal = np.abs(focal)
8791
return focal
8892

0 commit comments

Comments
 (0)