1818import numpy as np
1919import skimage .filters
2020import skimage .measure
21+ import trimesh
2122import vtk
2223from packaging .version import parse
2324from scipy .interpolate import LinearNDInterpolator
2425from scipy .ndimage import distance_transform_edt
25- from vtk .util .numpy_support import numpy_to_vtk
26+ from vtk .util .numpy_support import numpy_to_vtk , vtk_to_numpy
2627
27- from openlifu .geo import cartesian_to_spherical
28+ from openlifu .geo import cartesian_to_spherical , spherical_to_cartesian_vectorized
2829
2930
3031def apply_affine_to_polydata (affine :np .ndarray , polydata :vtk .vtkPolyData ) -> vtk .vtkPolyData :
@@ -282,8 +283,9 @@ def spherical_interpolator_from_mesh(
282283 surface_mesh : vtk .vtkPolyData ,
283284 origin : Tuple [float , float , float ] = (0. ,0. ,0. ),
284285 xyz_direction_columns : np .ndarray | None = None ,
285- dist_tolerance : float = 0.0001
286- ) -> Callable [[float , float ], float ]:
286+ use_embree : bool | None = None ,
287+ dist_tolerance : float = 0.0001 ,
288+ ) -> Callable :
287289 """Create a spherical interpolator from a vtkPolyData.
288290
289291 Here a "spherical interpolator" is a function that maps angles from a spherical coordinate system
@@ -300,10 +302,12 @@ def spherical_interpolator_from_mesh(
300302 of how the spherical angles relate to the x, y, and z axes. If not provided, the xyz_direction_columns will
301303 be an identity matrix, which means that the coordinates in which surface_mesh is given will directly be
302304 interpreted as the x,y,z upon which a spherical coordinate system will be based.
305+ use_embree: Use an alternative algorithm that uses embree CPU raytracing. Defaults to True only if embree is available;
306+ it requires x86 architecture.
303307 dist_tolerance: A vertex of the surface_mesh will only be included if it is the furthest point from the origin
304308 that is on the mesh along the ray emanating from the origin and passing through the vertex. The
305309 dist_tolerance is the threshold for determining whether an intersection of the ray with the mesh
306- counts as being a distinct further out point from the vertex.
310+ counts as being a distinct further out point from the vertex. This parameter only matters if use_embree is off.
307311
308312 Returns:
309313 A spherical interpolator, which is a callable that maps (theta,phi) pairs of spherical coordinates (phi being azimuthal)
@@ -329,6 +333,9 @@ def spherical_interpolator_from_mesh(
329333 if xyz_direction_columns is None :
330334 xyz_direction_columns = np .eye (3 , dtype = float )
331335
336+ if use_embree is None :
337+ use_embree = trimesh .ray .has_embree
338+
332339 xyz_affine = np .eye (4 )
333340 xyz_affine [:3 ,:3 ] = xyz_direction_columns
334341 xyz_affine [:3 ,3 ] = origin
@@ -342,20 +349,29 @@ def spherical_interpolator_from_mesh(
342349 transform_filter = vtk .vtkTransformPolyDataFilter ()
343350 transform_filter .SetTransform (xyz_inverse_transform )
344351 transform_filter .SetInputData (surface_mesh )
345- transform_filter .Update ()
346- surface_mesh_transformed = transform_filter .GetOutput ()
352+ triangle_filter = vtk .vtkTriangleFilter ()
353+ triangle_filter .SetInputConnection (transform_filter .GetOutputPort ())
354+ triangle_filter .Update ()
355+ surface_mesh_transformed = triangle_filter .GetOutput ()
356+
357+ if use_embree :
358+ return _spherical_interpolator_from_mesh_embree (surface_mesh_transformed )
359+ else :
360+ return _spherical_interpolator_from_mesh_cell_locator (surface_mesh_transformed , dist_tolerance )
361+
362+ def _spherical_interpolator_from_mesh_cell_locator (surface_mesh : vtk .vtkPolyData , dist_tolerance :float ) -> Callable :
347363
348364 spherical_coords_on_mesh : List [Tuple [float ,float ,float ]] = []
349365
350- points = surface_mesh_transformed .GetPoints ()
366+ points = surface_mesh .GetPoints ()
351367
352368 # The farthest point from the origin is this far out:
353369 r_max = np .max ([np .sqrt (np .sum (np .array (points .GetPoint (i ))** 2 )) for i in range (points .GetNumberOfPoints ())])
354370
355371 sqdist_tolerance = dist_tolerance ** 2
356372
357373 locator = vtk .vtkCellLocator () # Tried vtkOBBTree and it seems vtkCellLocator is much faster for this application
358- locator .SetDataSet (surface_mesh_transformed )
374+ locator .SetDataSet (surface_mesh )
359375 locator .BuildLocator ()
360376
361377 for i in range (points .GetNumberOfPoints ()):
@@ -422,3 +438,45 @@ def spherical_interpolator_from_mesh(
422438 )
423439
424440 return interpolator
441+
442+ def _spherical_interpolator_from_mesh_embree (surface_mesh : vtk .vtkPolyData ) -> Callable :
443+ vtk_points = surface_mesh .GetPoints ()
444+ points_np = vtk_to_numpy (vtk_points .GetData ()).astype (np .float64 ) # (N,3)
445+ polys = surface_mesh .GetPolys ()
446+ polys_np = vtk_to_numpy (polys .GetData ()) # flat array [3,i0,i1,i2,3,i0,i1,i2,...]
447+ if polys_np .size == 0 :
448+ raise RuntimeError ("Input mesh has no polygons after transformation/triangulation." )
449+ polys_np = polys_np .reshape (- 1 , 4 ) # (M, 4)
450+ faces_np = polys_np [:, 1 :4 ].astype (np .int64 ) # (M, 3)
451+
452+ r_squared = np .sum (points_np ** 2 , axis = 1 )
453+
454+ # The farthest point from the origin is this far out:
455+ r_max = float (np .sqrt (r_squared .max ()))
456+
457+ tm = trimesh .Trimesh (vertices = points_np , faces = faces_np , process = False )
458+ intersector = trimesh .ray .ray_pyembree .RayMeshIntersector (tm )
459+
460+ def interpolator (* args ):
461+ if len (args )== 2 :
462+ arr = np .array (args )
463+ elif len (args )== 1 and isinstance (args [0 ], np .ndarray ):
464+ arr = args [0 ] # expected shape (...,2)
465+ if arr .shape [- 1 ] != 2 :
466+ msg = f"Interpolator expects array of shape (...,2). Got shape { arr .shape } "
467+ raise ValueError (msg )
468+ else :
469+ raise ValueError ("Interpolator expects either two args (theta, phi) or a single numpy array arg shaped (...,2)" )
470+
471+ origins = spherical_to_cartesian_vectorized (
472+ np .concatenate ([np .full (arr .shape [:- 1 ] + (1 ,), r_max + 1 ),arr ], axis = - 1 ) # add r coordinate, giving shape (...,3)
473+ )
474+
475+ # intersects_id will expect shape (N,3), but we want to support (...,3), so reshape if needed:
476+ batch_shape = origins .shape [:- 1 ]
477+ origins = origins .reshape ((- 1 ,3 ))
478+
479+ _ , _ , hit_locations = intersector .intersects_id (origins , - origins , multiple_hits = False , return_locations = True )
480+ return np .linalg .norm (hit_locations , axis = - 1 ).reshape (batch_shape )
481+
482+ return interpolator
0 commit comments