Skip to content

Commit ba1d5be

Browse files
committed
Add VF option to specify number of transforms returned (#416)
1 parent 1c2fbfe commit ba1d5be

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

src/openlifu/virtual_fit.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ class VirtualFitOptions(DictMixin):
9090
planefit_dpitch_step: Annotated[float, OpenLIFUFieldData("Plane fit pitch step", "Local pitch axis step size to use when constructing plane fitting grids. In spatial units of `units`")] = 3
9191
"""Local pitch axis step size to use when constructing plane fitting grids. In spatial units of `units`."""
9292

93+
top_n_candidates: Annotated[int, OpenLIFUFieldData("No. of candidates returned", "Sets the limit for the number of transducer transform candidates returned by the algorithm.")] = 4
94+
"""Sets the limit for the number of transducer transform candidates returned by the algorithm."""
95+
9396
def __post_init__(self):
9497
if not isinstance(self.units, str):
9598
raise TypeError("Units must be a string")
@@ -135,6 +138,10 @@ def __post_init__(self):
135138
raise ValueError("Plane fit pitch extent must be greater than 0")
136139
if not isinstance(self.planefit_dpitch_step, int | float):
137140
raise TypeError("Plane fit pitch step must be a number")
141+
if not isinstance(self.top_n_candidates, int ):
142+
raise TypeError("Number of transducer transform candidates returned must be an integer")
143+
if self.top_n_candidates <= 0:
144+
raise TypeError("Number of transducer transform candidates returned must be greater than 0")
138145

139146
def to_units(self, target_units: str) -> VirtualFitOptions:
140147
"""Do unit conversion and return a version of this VirtualFitOptions that uses
@@ -152,6 +159,7 @@ def to_units(self, target_units: str) -> VirtualFitOptions:
152159
planefit_dyaw_step = conversion_factor * self.planefit_dyaw_step,
153160
planefit_dpitch_extent = conversion_factor * self.planefit_dpitch_extent,
154161
planefit_dpitch_step = conversion_factor * self.planefit_dpitch_step,
162+
top_n_candidates = self.top_n_candidates,
155163
)
156164

157165
@staticmethod
@@ -275,6 +283,7 @@ def progress_callback(progress_percent : int, step_description : str): # noqa: A
275283
planefit_dyaw_step = options.planefit_dyaw_step
276284
planefit_dpitch_extent = options.planefit_dpitch_extent
277285
planefit_dpitch_step = options.planefit_dpitch_step
286+
top_n_candidates = options.top_n_candidates
278287

279288

280289
if skin_mesh is None:
@@ -443,7 +452,7 @@ def progress_callback(progress_percent : int, step_description : str): # noqa: A
443452

444453
progress_callback(100, "Complete")
445454
return (
446-
sorted_transforms,
455+
sorted_transforms[:top_n_candidates],
447456
VirtualFitDebugInfo(
448457
skin_mesh = skin_mesh,
449458
spherically_interpolated_mesh = interpolator_mesh,
@@ -455,4 +464,4 @@ def progress_callback(progress_percent : int, step_description : str): # noqa: A
455464
)
456465

457466
progress_callback(100, "Complete")
458-
return sorted_transforms
467+
return sorted_transforms[:top_n_candidates]

0 commit comments

Comments
 (0)