|
1 | | -import copy |
2 | 1 | import warnings |
3 | 2 |
|
4 | 3 | # pylint: disable=redefined-builtin,no-name-in-module,no-member |
5 | 4 | from pyrecest.backend import ( |
6 | 5 | abs, |
7 | 6 | all, |
8 | 7 | any, |
9 | | - arange, |
10 | | - argmin, |
11 | | - array_equal, |
12 | | - linalg, |
13 | 8 | mean, |
14 | | - meshgrid, |
15 | 9 | sum, |
16 | 10 | ) |
17 | 11 | from pyrecest.distributions.hypersphere_subset.abstract_hypersphere_subset_distribution import ( |
@@ -50,31 +44,11 @@ def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True): |
50 | 44 | enforce_pdf_nonnegative : bool |
51 | 45 | Whether non-negativity of ``grid_values`` is required. |
52 | 46 | """ |
53 | | - if grid.ndim != 2: |
54 | | - raise ValueError("grid must be a 2D array of shape (n_points, d).") |
55 | | - |
56 | | - n_points, d = grid.shape |
57 | | - |
58 | | - if grid_values.ndim != 2 or grid_values.shape != (n_points, n_points): |
59 | | - raise ValueError( |
60 | | - f"grid_values must be a square 2D array of shape ({n_points}, {n_points})." |
61 | | - ) |
62 | | - |
63 | | - if any(abs(grid) > 1 + 1e-12): |
| 47 | + super().__init__(grid, grid_values, enforce_pdf_nonnegative) |
| 48 | + if any(abs(self.grid) > 1 + 1e-12): |
64 | 49 | raise ValueError( |
65 | 50 | "Grid points must have coordinates in [-1, 1] (unit sphere)." |
66 | 51 | ) |
67 | | - |
68 | | - if enforce_pdf_nonnegative and any(grid_values < 0): |
69 | | - raise ValueError("grid_values must be non-negative.") |
70 | | - |
71 | | - self.grid = grid |
72 | | - self.grid_values = grid_values |
73 | | - self.enforce_pdf_nonnegative = enforce_pdf_nonnegative |
74 | | - # Embedding dimension of the Cartesian product space (convention from |
75 | | - # libDirectional: dim = 2 * embedding_dim_of_individual_sphere). |
76 | | - self.dim = 2 * d |
77 | | - |
78 | 52 | self._check_normalization() |
79 | 53 |
|
80 | 54 | # ------------------------------------------------------------------ |
@@ -107,43 +81,6 @@ def _check_normalization(self, tol=0.01): |
107 | 81 | UserWarning, |
108 | 82 | ) |
109 | 83 |
|
110 | | - def normalize(self): |
111 | | - """No-op – returns ``self`` for compatibility.""" |
112 | | - return self |
113 | | - |
114 | | - # ------------------------------------------------------------------ |
115 | | - # Arithmetic |
116 | | - # ------------------------------------------------------------------ |
117 | | - |
118 | | - def multiply(self, other): |
119 | | - """ |
120 | | - Element-wise multiply two conditional grid distributions. |
121 | | -
|
122 | | - The resulting distribution is *not* normalized. |
123 | | -
|
124 | | - Parameters |
125 | | - ---------- |
126 | | - other : SdCondSdGridDistribution |
127 | | - Must be defined on the same grid. |
128 | | -
|
129 | | - Returns |
130 | | - ------- |
131 | | - SdCondSdGridDistribution |
132 | | - """ |
133 | | - if not array_equal(self.grid, other.grid): |
134 | | - raise ValueError( |
135 | | - "Multiply:IncompatibleGrid: Can only multiply distributions " |
136 | | - "defined on identical grids." |
137 | | - ) |
138 | | - warnings.warn( |
139 | | - "Multiply:UnnormalizedResult: Multiplication does not yield a " |
140 | | - "normalized result.", |
141 | | - UserWarning, |
142 | | - ) |
143 | | - result = copy.deepcopy(self) |
144 | | - result.grid_values = result.grid_values * other.grid_values |
145 | | - return result |
146 | | - |
147 | 84 | # ------------------------------------------------------------------ |
148 | 85 | # Marginalisation and conditioning |
149 | 86 | # ------------------------------------------------------------------ |
@@ -201,26 +138,7 @@ def fix_dim(self, first_or_second, point): |
201 | 138 | HypersphericalGridDistribution, |
202 | 139 | ) |
203 | 140 |
|
204 | | - d = self.grid.shape[1] |
205 | | - if point.shape[0] != d: |
206 | | - raise ValueError( |
207 | | - f"point must have length {d} (embedding dimension of the sphere)." |
208 | | - ) |
209 | | - |
210 | | - diffs = linalg.norm(self.grid - point[None, :], axis=1) |
211 | | - locb = argmin(diffs) |
212 | | - if diffs[locb] > 1e-10: |
213 | | - raise ValueError( |
214 | | - "Cannot fix value at this point because it is not on the grid." |
215 | | - ) |
216 | | - |
217 | | - if first_or_second == 1: |
218 | | - grid_values_slice = self.grid_values[locb, :] |
219 | | - elif first_or_second == 2: |
220 | | - grid_values_slice = self.grid_values[:, locb] |
221 | | - else: |
222 | | - raise ValueError("first_or_second must be 1 or 2.") |
223 | | - |
| 141 | + grid_values_slice = self._get_grid_slice(first_or_second, point) |
224 | 142 | return HypersphericalGridDistribution(self.grid, grid_values_slice) |
225 | 143 |
|
226 | 144 | # ------------------------------------------------------------------ |
@@ -276,24 +194,8 @@ def from_function( |
276 | 194 | # manifold dim: embedding_dim = dim // 2, manifold_dim = embedding_dim - 1. |
277 | 195 | manifold_dim = dim // 2 - 1 |
278 | 196 | grid, _ = get_grid_hypersphere(grid_type, n, manifold_dim) |
279 | | - # grid is (n, dim//2) |
280 | | - |
281 | | - if fun_does_cartesian_product: |
282 | | - fvals = fun(grid, grid) |
283 | | - grid_values = fvals.reshape(n, n) |
284 | | - else: |
285 | | - # Build index pairs: idx_a[i, j] = i, idx_b[i, j] = j |
286 | | - idx_a, idx_b = meshgrid(arange(n), arange(n), indexing="ij") |
287 | | - grid_a = grid[idx_a.ravel()] # (n*n, d) |
288 | | - grid_b = grid[idx_b.ravel()] # (n*n, d) |
289 | | - fvals = fun(grid_a, grid_b) # (n*n,) |
290 | | - |
291 | | - if fvals.shape == (n**2, n**2): |
292 | | - raise ValueError( |
293 | | - "Function apparently performs the Cartesian product itself. " |
294 | | - "Set fun_does_cartesian_product=True." |
295 | | - ) |
296 | | - |
297 | | - grid_values = fvals.reshape(n, n) |
298 | 197 |
|
| 198 | + grid_values = SdCondSdGridDistribution._evaluate_on_grid( |
| 199 | + fun, grid, n, fun_does_cartesian_product |
| 200 | + ) |
299 | 201 | return SdCondSdGridDistribution(grid, grid_values) |
0 commit comments