Skip to content

Commit bff3f88

Browse files
Use embree raytracing for virtual fitting (#367)
This greatly speeds up building the spherical interpolator for virtual fitting, but only for intel architecture.
1 parent 12b9454 commit bff3f88

3 files changed

Lines changed: 71 additions & 10 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ dependencies = [
4949
"OpenEXR",
5050
"scikit-image",
5151
"trimesh",
52+
"embreex; platform_machine=='x86_64'",
5253
"onnxruntime==1.18.0",
5354
]
5455

src/openlifu/seg/skinseg.py

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
import numpy as np
1919
import skimage.filters
2020
import skimage.measure
21+
import trimesh
2122
import vtk
2223
from packaging.version import parse
2324
from scipy.interpolate import LinearNDInterpolator
2425
from scipy.ndimage import distance_transform_edt
25-
from vtk.util.numpy_support import numpy_to_vtk
26+
from vtk.util.numpy_support import numpy_to_vtk, vtk_to_numpy
2627

27-
from openlifu.geo import cartesian_to_spherical
28+
from openlifu.geo import cartesian_to_spherical, spherical_to_cartesian_vectorized
2829

2930

3031
def apply_affine_to_polydata(affine:np.ndarray, polydata:vtk.vtkPolyData) -> vtk.vtkPolyData:
@@ -282,8 +283,9 @@ def spherical_interpolator_from_mesh(
282283
surface_mesh: vtk.vtkPolyData,
283284
origin: Tuple[float, float, float] = (0.,0.,0.),
284285
xyz_direction_columns: np.ndarray | None = None,
285-
dist_tolerance: float = 0.0001
286-
) -> Callable[[float, float], float]:
286+
use_embree: bool|None = None,
287+
dist_tolerance: float = 0.0001,
288+
) -> Callable:
287289
"""Create a spherical interpolator from a vtkPolyData.
288290
289291
Here a "spherical interpolator" is a function that maps angles from a spherical coordinate system
@@ -300,10 +302,12 @@ def spherical_interpolator_from_mesh(
300302
of how the spherical angles relate to the x, y, and z axes. If not provided, the xyz_direction_columns will
301303
be an identity matrix, which means that the coordinates in which surface_mesh is given will directly be
302304
interpreted as the x,y,z upon which a spherical coordinate system will be based.
305+
use_embree: Use an alternative algorithm that uses embree CPU raytracing. Defaults to True only if embree is available;
306+
it requires x86 architecture.
303307
dist_tolerance: A vertex of the surface_mesh will only be included if it is the furthest point from the origin
304308
that is on the mesh along the ray emanating from the origin and passing through the vertex. The
305309
dist_tolerance is the threshold for determining whether an intersection of the ray with the mesh
306-
counts as being a distinct further out point from the vertex.
310+
counts as being a distinct further out point from the vertex. This parameter only matters if use_embree is off.
307311
308312
Returns:
309313
A spherical interpolator, which is a callable that maps (theta,phi) pairs of spherical coordinates (phi being azimuthal)
@@ -329,6 +333,9 @@ def spherical_interpolator_from_mesh(
329333
if xyz_direction_columns is None:
330334
xyz_direction_columns = np.eye(3, dtype=float)
331335

336+
if use_embree is None:
337+
use_embree = trimesh.ray.has_embree
338+
332339
xyz_affine = np.eye(4)
333340
xyz_affine[:3,:3] = xyz_direction_columns
334341
xyz_affine[:3,3] = origin
@@ -342,20 +349,29 @@ def spherical_interpolator_from_mesh(
342349
transform_filter = vtk.vtkTransformPolyDataFilter()
343350
transform_filter.SetTransform(xyz_inverse_transform)
344351
transform_filter.SetInputData(surface_mesh)
345-
transform_filter.Update()
346-
surface_mesh_transformed = transform_filter.GetOutput()
352+
triangle_filter = vtk.vtkTriangleFilter()
353+
triangle_filter.SetInputConnection(transform_filter.GetOutputPort())
354+
triangle_filter.Update()
355+
surface_mesh_transformed = triangle_filter.GetOutput()
356+
357+
if use_embree:
358+
return _spherical_interpolator_from_mesh_embree(surface_mesh_transformed)
359+
else:
360+
return _spherical_interpolator_from_mesh_cell_locator(surface_mesh_transformed, dist_tolerance)
361+
362+
def _spherical_interpolator_from_mesh_cell_locator(surface_mesh : vtk.vtkPolyData, dist_tolerance:float) -> Callable:
347363

348364
spherical_coords_on_mesh : List[Tuple[float,float,float]] = []
349365

350-
points = surface_mesh_transformed.GetPoints()
366+
points = surface_mesh.GetPoints()
351367

352368
# The farthest point from the origin is this far out:
353369
r_max = np.max([np.sqrt(np.sum(np.array(points.GetPoint(i))**2)) for i in range(points.GetNumberOfPoints())])
354370

355371
sqdist_tolerance = dist_tolerance**2
356372

357373
locator = vtk.vtkCellLocator() # Tried vtkOBBTree and it seems vtkCellLocator is much faster for this application
358-
locator.SetDataSet(surface_mesh_transformed)
374+
locator.SetDataSet(surface_mesh)
359375
locator.BuildLocator()
360376

361377
for i in range(points.GetNumberOfPoints()):
@@ -422,3 +438,45 @@ def spherical_interpolator_from_mesh(
422438
)
423439

424440
return interpolator
441+
442+
def _spherical_interpolator_from_mesh_embree(surface_mesh : vtk.vtkPolyData) -> Callable:
443+
vtk_points = surface_mesh.GetPoints()
444+
points_np = vtk_to_numpy(vtk_points.GetData()).astype(np.float64) # (N,3)
445+
polys = surface_mesh.GetPolys()
446+
polys_np = vtk_to_numpy(polys.GetData()) # flat array [3,i0,i1,i2,3,i0,i1,i2,...]
447+
if polys_np.size == 0:
448+
raise RuntimeError("Input mesh has no polygons after transformation/triangulation.")
449+
polys_np = polys_np.reshape(-1, 4) # (M, 4)
450+
faces_np = polys_np[:, 1:4].astype(np.int64) # (M, 3)
451+
452+
r_squared = np.sum(points_np**2, axis=1)
453+
454+
# The farthest point from the origin is this far out:
455+
r_max = float(np.sqrt(r_squared.max()))
456+
457+
tm = trimesh.Trimesh(vertices=points_np, faces=faces_np, process=False)
458+
intersector = trimesh.ray.ray_pyembree.RayMeshIntersector(tm)
459+
460+
def interpolator(*args):
461+
if len(args)==2:
462+
arr = np.array(args)
463+
elif len(args)==1 and isinstance(args[0], np.ndarray):
464+
arr = args[0] # expected shape (...,2)
465+
if arr.shape[-1] != 2:
466+
msg = f"Interpolator expects array of shape (...,2). Got shape {arr.shape}"
467+
raise ValueError(msg)
468+
else:
469+
raise ValueError("Interpolator expects either two args (theta, phi) or a single numpy array arg shaped (...,2)")
470+
471+
origins = spherical_to_cartesian_vectorized(
472+
np.concatenate([np.full(arr.shape[:-1] + (1,), r_max+1),arr], axis=-1) # add r coordinate, giving shape (...,3)
473+
)
474+
475+
# intersects_id will expect shape (N,3), but we want to support (...,3), so reshape if needed:
476+
batch_shape = origins.shape[:-1]
477+
origins = origins.reshape((-1,3))
478+
479+
_, _, hit_locations = intersector.intersects_id(origins, -origins, multiple_hits=False, return_locations=True)
480+
return np.linalg.norm(hit_locations, axis=-1).reshape(batch_shape)
481+
482+
return interpolator

tests/test_skinseg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def test_create_closed_surface_from_labelmap():
111111
# isn't supposed to add a colormap or anything like that
112112
assert surface.GetPointData().GetScalars() is None
113113

114-
def test_spherical_interpolator_from_mesh():
114+
@pytest.mark.parametrize("use_embree", [True, False])
115+
def test_spherical_interpolator_from_mesh(use_embree):
115116
"""Check using a torus that the spherical interpolator behaves reasonably"""
116117
parametric_torus = vtk.vtkParametricTorus()
117118
parametric_torus.SetRingRadius(12.)
@@ -131,6 +132,7 @@ def test_spherical_interpolator_from_mesh():
131132
surface_mesh = torus_polydata,
132133
origin = origin,
133134
xyz_direction_columns = xyz_direction_columns,
135+
use_embree=use_embree
134136
)
135137

136138
sphere_source = vtk.vtkSphereSource()

0 commit comments

Comments
 (0)