Skip to content

Commit 5a2b2ea

Browse files
Fix R0801 duplicate-code: move shared logic to AbstractConditionalDistribution
Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/c2dbd81e-252c-43e0-a0b9-ccda405fe1d6 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
1 parent 5d64c77 commit 5a2b2ea

3 files changed

Lines changed: 172 additions & 205 deletions

File tree

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,165 @@
1+
import copy
2+
import warnings
13
from abc import ABC
24

5+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
6+
from pyrecest.backend import (
7+
abs,
8+
any,
9+
arange,
10+
argmin,
11+
array_equal,
12+
linalg,
13+
meshgrid,
14+
)
15+
316

417
class AbstractConditionalDistribution(ABC):
5-
pass
18+
"""Abstract base class for conditional grid distributions on manifolds.
19+
20+
Subclasses represent distributions of the form f(a | b) where both a and b
21+
live on the same manifold. The joint state is stored as a square matrix
22+
``grid_values`` where ``grid_values[i, j] = f(grid[i] | grid[j])``.
23+
"""
24+
25+
def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True):
26+
"""Common initialisation for conditional grid distributions.
27+
28+
Parameters
29+
----------
30+
grid : array of shape (n_points, d)
31+
Grid points on the individual manifold.
32+
grid_values : array of shape (n_points, n_points)
33+
Conditional pdf values; ``grid_values[i, j] = f(grid[i] | grid[j])``.
34+
enforce_pdf_nonnegative : bool
35+
Whether to require non-negative ``grid_values``.
36+
"""
37+
if grid.ndim != 2:
38+
raise ValueError("grid must be a 2D array of shape (n_points, d).")
39+
40+
n_points, d = grid.shape
41+
42+
if grid_values.ndim != 2 or grid_values.shape != (n_points, n_points):
43+
raise ValueError(
44+
f"grid_values must be a square 2D array of shape ({n_points}, {n_points})."
45+
)
46+
47+
if enforce_pdf_nonnegative and any(grid_values < 0):
48+
raise ValueError("grid_values must be non-negative.")
49+
50+
self.grid = grid
51+
self.grid_values = grid_values
52+
self.enforce_pdf_nonnegative = enforce_pdf_nonnegative
53+
# Embedding dimension of the Cartesian product space (convention from
54+
# libDirectional: dim = 2 * dim_of_individual_manifold).
55+
self.dim = 2 * d
56+
57+
# ------------------------------------------------------------------
58+
# Normalization
59+
# ------------------------------------------------------------------
60+
61+
def normalize(self):
62+
"""No-op – returns ``self`` for compatibility."""
63+
return self
64+
65+
# ------------------------------------------------------------------
66+
# Arithmetic
67+
# ------------------------------------------------------------------
68+
69+
def multiply(self, other):
70+
"""Element-wise multiply two conditional grid distributions.
71+
72+
The resulting distribution is *not* normalized.
73+
74+
Parameters
75+
----------
76+
other : AbstractConditionalDistribution
77+
Must be defined on the same grid.
78+
79+
Returns
80+
-------
81+
AbstractConditionalDistribution
82+
Same concrete type as ``self``.
83+
"""
84+
if not array_equal(self.grid, other.grid):
85+
raise ValueError(
86+
"Multiply:IncompatibleGrid: Can only multiply distributions "
87+
"defined on identical grids."
88+
)
89+
warnings.warn(
90+
"Multiply:UnnormalizedResult: Multiplication does not yield a "
91+
"normalized result.",
92+
UserWarning,
93+
)
94+
result = copy.deepcopy(self)
95+
result.grid_values = result.grid_values * other.grid_values
96+
return result
97+
98+
# ------------------------------------------------------------------
99+
# Protected helpers
100+
# ------------------------------------------------------------------
101+
102+
def _get_grid_slice(self, first_or_second, point):
103+
"""Return the ``grid_values`` slice for a fixed grid point.
104+
105+
Parameters
106+
----------
107+
first_or_second : int (1 or 2)
108+
Which variable to fix.
109+
point : array of shape (d,)
110+
Must be an existing grid point.
111+
112+
Returns
113+
-------
114+
array of shape (n_points,)
115+
"""
116+
d = self.grid.shape[1]
117+
if point.shape[0] != d:
118+
raise ValueError(
119+
f"point must have length {d} (grid dimension)."
120+
)
121+
diffs = linalg.norm(self.grid - point[None, :], axis=1)
122+
locb = argmin(diffs)
123+
if diffs[locb] > 1e-10:
124+
raise ValueError(
125+
"Cannot fix value at this point because it is not on the grid."
126+
)
127+
if first_or_second == 1:
128+
return self.grid_values[locb, :]
129+
if first_or_second == 2:
130+
return self.grid_values[:, locb]
131+
raise ValueError("first_or_second must be 1 or 2.")
132+
133+
@staticmethod
134+
def _evaluate_on_grid(fun, grid, n, fun_does_cartesian_product):
135+
"""Evaluate ``fun`` on all grid point pairs and return an (n, n) array.
136+
137+
Parameters
138+
----------
139+
fun : callable
140+
``f(a, b)`` with the semantics described in ``from_function``.
141+
grid : array of shape (n, d)
142+
Grid points on the individual manifold.
143+
n : int
144+
Number of grid points (``grid.shape[0]``).
145+
fun_does_cartesian_product : bool
146+
Whether *fun* handles all grid combinations internally.
147+
148+
Returns
149+
-------
150+
array of shape (n, n)
151+
"""
152+
if fun_does_cartesian_product:
153+
fvals = fun(grid, grid)
154+
return fvals.reshape(n, n)
155+
idx_a, idx_b = meshgrid(arange(n), arange(n), indexing="ij")
156+
grid_a = grid[idx_a.ravel()]
157+
grid_b = grid[idx_b.ravel()]
158+
fvals = fun(grid_a, grid_b)
159+
if fvals.shape == (n**2, n**2):
160+
raise ValueError(
161+
"Function apparently performs the Cartesian product itself. "
162+
"Set fun_does_cartesian_product=True."
163+
)
164+
return fvals.reshape(n, n)
165+

pyrecest/distributions/conditional/sd_cond_sd_grid_distribution.py

Lines changed: 6 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
1-
import copy
21
import warnings
32

43
# pylint: disable=redefined-builtin,no-name-in-module,no-member
54
from pyrecest.backend import (
65
abs,
76
all,
87
any,
9-
arange,
10-
argmin,
11-
array_equal,
12-
linalg,
138
mean,
14-
meshgrid,
159
sum,
1610
)
1711
from pyrecest.distributions.hypersphere_subset.abstract_hypersphere_subset_distribution import (
@@ -50,31 +44,11 @@ def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True):
5044
enforce_pdf_nonnegative : bool
5145
Whether non-negativity of ``grid_values`` is required.
5246
"""
53-
if grid.ndim != 2:
54-
raise ValueError("grid must be a 2D array of shape (n_points, d).")
55-
56-
n_points, d = grid.shape
57-
58-
if grid_values.ndim != 2 or grid_values.shape != (n_points, n_points):
59-
raise ValueError(
60-
f"grid_values must be a square 2D array of shape ({n_points}, {n_points})."
61-
)
62-
63-
if any(abs(grid) > 1 + 1e-12):
47+
super().__init__(grid, grid_values, enforce_pdf_nonnegative)
48+
if any(abs(self.grid) > 1 + 1e-12):
6449
raise ValueError(
6550
"Grid points must have coordinates in [-1, 1] (unit sphere)."
6651
)
67-
68-
if enforce_pdf_nonnegative and any(grid_values < 0):
69-
raise ValueError("grid_values must be non-negative.")
70-
71-
self.grid = grid
72-
self.grid_values = grid_values
73-
self.enforce_pdf_nonnegative = enforce_pdf_nonnegative
74-
# Embedding dimension of the Cartesian product space (convention from
75-
# libDirectional: dim = 2 * embedding_dim_of_individual_sphere).
76-
self.dim = 2 * d
77-
7852
self._check_normalization()
7953

8054
# ------------------------------------------------------------------
@@ -107,43 +81,6 @@ def _check_normalization(self, tol=0.01):
10781
UserWarning,
10882
)
10983

110-
def normalize(self):
111-
"""No-op – returns ``self`` for compatibility."""
112-
return self
113-
114-
# ------------------------------------------------------------------
115-
# Arithmetic
116-
# ------------------------------------------------------------------
117-
118-
def multiply(self, other):
119-
"""
120-
Element-wise multiply two conditional grid distributions.
121-
122-
The resulting distribution is *not* normalized.
123-
124-
Parameters
125-
----------
126-
other : SdCondSdGridDistribution
127-
Must be defined on the same grid.
128-
129-
Returns
130-
-------
131-
SdCondSdGridDistribution
132-
"""
133-
if not array_equal(self.grid, other.grid):
134-
raise ValueError(
135-
"Multiply:IncompatibleGrid: Can only multiply distributions "
136-
"defined on identical grids."
137-
)
138-
warnings.warn(
139-
"Multiply:UnnormalizedResult: Multiplication does not yield a "
140-
"normalized result.",
141-
UserWarning,
142-
)
143-
result = copy.deepcopy(self)
144-
result.grid_values = result.grid_values * other.grid_values
145-
return result
146-
14784
# ------------------------------------------------------------------
14885
# Marginalisation and conditioning
14986
# ------------------------------------------------------------------
@@ -201,26 +138,7 @@ def fix_dim(self, first_or_second, point):
201138
HypersphericalGridDistribution,
202139
)
203140

204-
d = self.grid.shape[1]
205-
if point.shape[0] != d:
206-
raise ValueError(
207-
f"point must have length {d} (embedding dimension of the sphere)."
208-
)
209-
210-
diffs = linalg.norm(self.grid - point[None, :], axis=1)
211-
locb = argmin(diffs)
212-
if diffs[locb] > 1e-10:
213-
raise ValueError(
214-
"Cannot fix value at this point because it is not on the grid."
215-
)
216-
217-
if first_or_second == 1:
218-
grid_values_slice = self.grid_values[locb, :]
219-
elif first_or_second == 2:
220-
grid_values_slice = self.grid_values[:, locb]
221-
else:
222-
raise ValueError("first_or_second must be 1 or 2.")
223-
141+
grid_values_slice = self._get_grid_slice(first_or_second, point)
224142
return HypersphericalGridDistribution(self.grid, grid_values_slice)
225143

226144
# ------------------------------------------------------------------
@@ -276,24 +194,8 @@ def from_function(
276194
# manifold dim: embedding_dim = dim // 2, manifold_dim = embedding_dim - 1.
277195
manifold_dim = dim // 2 - 1
278196
grid, _ = get_grid_hypersphere(grid_type, n, manifold_dim)
279-
# grid is (n, dim//2)
280-
281-
if fun_does_cartesian_product:
282-
fvals = fun(grid, grid)
283-
grid_values = fvals.reshape(n, n)
284-
else:
285-
# Build index pairs: idx_a[i, j] = i, idx_b[i, j] = j
286-
idx_a, idx_b = meshgrid(arange(n), arange(n), indexing="ij")
287-
grid_a = grid[idx_a.ravel()] # (n*n, d)
288-
grid_b = grid[idx_b.ravel()] # (n*n, d)
289-
fvals = fun(grid_a, grid_b) # (n*n,)
290-
291-
if fvals.shape == (n**2, n**2):
292-
raise ValueError(
293-
"Function apparently performs the Cartesian product itself. "
294-
"Set fun_does_cartesian_product=True."
295-
)
296-
297-
grid_values = fvals.reshape(n, n)
298197

198+
grid_values = SdCondSdGridDistribution._evaluate_on_grid(
199+
fun, grid, n, fun_does_cartesian_product
200+
)
299201
return SdCondSdGridDistribution(grid, grid_values)

0 commit comments

Comments
 (0)