|
11 | 11 |
|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
| 14 | +from collections.abc import Sequence |
14 | 15 | from copy import deepcopy |
15 | 16 | from typing import TYPE_CHECKING |
16 | 17 |
|
|
43 | 44 | from simpeg_drivers.driver import InversionDriver |
44 | 45 |
|
45 | 46 |
|
| 47 | +def octree_extents(octree: Octree) -> np.ndarray: |
| 48 | + """ |
| 49 | + Get the true extents of an octree (min/max of the perimeter). |
| 50 | +
|
| 51 | + The octree.extents property returns min/max of the centroids |
| 52 | +
|
| 53 | + :param octree: Octree mesh object. |
| 54 | +
|
| 55 | + :returns: Array of [xmin, xmax, ymin, ymax]. |
| 56 | + """ |
| 57 | + |
| 58 | + def half_cell(axis): |
| 59 | + return ( |
| 60 | + getattr(octree, f"{axis}_cell_size") * octree.octree_cells["NCells"] |
| 61 | + ) / 2 |
| 62 | + |
| 63 | + xmin = (octree.centroids[:, 0] - half_cell("u")).min() |
| 64 | + xmax = (octree.centroids[:, 0] + half_cell("u")).max() |
| 65 | + ymin = (octree.centroids[:, 1] - half_cell("v")).min() |
| 66 | + ymax = (octree.centroids[:, 1] + half_cell("v")).max() |
| 67 | + |
| 68 | + return np.array([xmin, xmax, ymin, ymax]) |
| 69 | + |
| 70 | + |
| 71 | +def mask_vertices_and_cells( |
| 72 | + extent: Sequence, vertices: np.ndarray, cells: np.ndarray | None |
| 73 | +) -> tuple[np.ndarray, np.ndarray]: |
| 74 | + """ |
| 75 | + Mask vertices and remove cells whose vertices are all outside the extent. |
| 76 | +
|
| 77 | + :param extent: Array-like object of [xmin, xmax, ymin, ymax]. |
| 78 | + :param vertices: Array of shape (n_vertices, 3) containing the x, y, z coordinates. |
| 79 | + :param cells: Array of shape (n_cells, 3) containing the indices of the vertices |
| 80 | + that make up each cell. |
| 81 | + """ |
| 82 | + |
| 83 | + vertex_mask = ( |
| 84 | + (vertices[:, 0] >= extent[0]) |
| 85 | + & (vertices[:, 0] <= extent[1]) |
| 86 | + & (vertices[:, 1] >= extent[2]) |
| 87 | + & (vertices[:, 1] <= extent[3]) |
| 88 | + ) |
| 89 | + if cells is None: |
| 90 | + return vertices[vertex_mask], None |
| 91 | + |
| 92 | + cell_mask = np.any(vertex_mask[cells], axis=1) |
| 93 | + vertex_mask = np.zeros_like(vertex_mask, dtype=bool) |
| 94 | + vertex_mask[cells[cell_mask].flatten()] = True |
| 95 | + |
| 96 | + new_cells = cells.copy()[cell_mask] |
| 97 | + cell_map = np.arange(len(vertices))[vertex_mask] |
| 98 | + new_cells = np.searchsorted(cell_map, new_cells) |
| 99 | + |
| 100 | + return vertices[vertex_mask], new_cells |
| 101 | + |
| 102 | + |
46 | 103 | def calculate_2D_trend( |
47 | 104 | points: np.ndarray, values: np.ndarray, order: int = 0, method: str = "all" |
48 | 105 | ): |
@@ -495,7 +552,7 @@ def active_from_xyz( |
495 | 552 | raise ValueError("'grid_reference' must be one of 'center', 'top', or 'bottom'") |
496 | 553 |
|
497 | 554 | # Return the active cell array |
498 | | - return mask_under_horizon(locations, topo, triangulation=triangulation) |
| 555 | + return mask_under_horizon(locations, horizon=topo, triangulation=triangulation) |
499 | 556 |
|
500 | 557 |
|
501 | 558 | def truncate_locs_depths(locs: np.ndarray, depth_core: float) -> np.ndarray: |
|
0 commit comments