Skip to content

Commit 47dbe3b

Browse files
authored
Merge pull request #1583 from FlorianPfaff/copilot/add-state-space-subdivision-gaussian
Add StateSpaceSubdivisionGaussianDistribution for grid-based hybrid state spaces
2 parents 0c6e25c + e5d6792 commit 47dbe3b

4 files changed

Lines changed: 476 additions & 2 deletions

File tree

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import copy
from abc import ABC, abstractmethod
3+
4+
5+
class StateSpaceSubdivisionDistribution(ABC):
    """
    Represents a joint distribution over a Cartesian product of a grid-based
    (periodic/bounded) space and a linear space, where the linear part is
    represented as a collection of distributions conditioned on each grid point
    of the periodic/bounded part.

    The periodic part is stored as an AbstractGridDistribution, which holds
    grid_values (unnormalized weights) at each grid point. The linear part is
    stored as a list of distributions, one per grid point.

    The class derives from ``ABC`` so that the ``@abstractmethod`` markers on
    ``marginalize_linear`` and ``marginalize_periodic`` are actually enforced:
    only subclasses implementing both can be instantiated.
    """

    def __init__(self, gd, linear_distributions):
        """
        Parameters
        ----------
        gd : AbstractGridDistribution
            Grid-based distribution for the periodic/bounded part. Its
            grid_values represent (unnormalized) marginal weights over the
            grid points.
        linear_distributions : list
            One distribution per grid point representing the conditional
            distribution of the linear state given that grid point.

        Raises
        ------
        AssertionError
            If ``gd.n_grid_points`` differs from ``len(linear_distributions)``.
        """
        assert gd.n_grid_points == len(linear_distributions), (
            "Number of grid points in gd must match length of linear_distributions."
        )
        # Deep copies decouple this object from later mutation of the
        # caller's grid distribution and conditional distributions.
        self.gd = copy.deepcopy(gd)
        self.linear_distributions = list(copy.deepcopy(linear_distributions))

    @property
    def bound_dim(self):
        """Dimension of the periodic/bounded space (ambient dimension of grid points)."""
        return self.gd.dim

    @property
    def lin_dim(self):
        """Dimension of the linear space (read from the first conditional)."""
        return self.linear_distributions[0].dim

    def hybrid_mean(self):
        """
        Returns the hybrid mean, i.e. the concatenation of the mean direction
        of the periodic part and the mean of the linear marginal.
        """
        # Imported at call time, as in the original implementation.
        # pylint: disable=no-name-in-module,no-member
        from pyrecest.backend import concatenate

        periodic_mean = self.gd.mean_direction()
        linear_mean_val = self.marginalize_periodic().mean()
        # Flatten both parts so they can be joined into a single 1-D vector.
        return concatenate([periodic_mean.reshape(-1), linear_mean_val.reshape(-1)])

    @abstractmethod
    def marginalize_linear(self):
        """Marginalise out the linear dimensions, returning a distribution over
        the periodic/bounded part only."""

    @abstractmethod
    def marginalize_periodic(self):
        """Marginalise out the periodic/bounded dimensions, returning a
        distribution over the linear part only."""
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
import copy
2+
import warnings
3+
4+
# pylint: disable=no-name-in-module,no-member
5+
from pyrecest.backend import (
6+
allclose,
7+
any as backend_any,
8+
argmax,
9+
array,
10+
asarray,
11+
concatenate,
12+
stack,
13+
sum as backend_sum,
14+
zeros,
15+
)
16+
17+
from ..nonperiodic.gaussian_distribution import GaussianDistribution
18+
from ..nonperiodic.gaussian_mixture import GaussianMixture
19+
from .state_space_subdivision_distribution import StateSpaceSubdivisionDistribution
20+
21+
22+
class StateSpaceSubdivisionGaussianDistribution(StateSpaceSubdivisionDistribution):
    """
    Joint distribution over a Cartesian product of a grid-based
    (periodic/bounded) space and a linear space where every conditional
    linear distribution is a Gaussian.

    The periodic part is a grid distribution (e.g. HypertoroidalGridDistribution
    or HyperhemisphericalGridDistribution).  The linear part is a list of
    GaussianDistribution objects, one per grid point.
    """

    def __init__(self, gd, gaussians):
        """
        Parameters
        ----------
        gd : AbstractGridDistribution
            Grid-based distribution for the periodic/bounded part.
        gaussians : list of GaussianDistribution
            One Gaussian per grid point of *gd*.

        Raises
        ------
        AssertionError
            If any element of *gaussians* is not a GaussianDistribution.
        """
        assert all(isinstance(g, GaussianDistribution) for g in gaussians), (
            "All elements of gaussians must be GaussianDistribution instances."
        )
        super().__init__(gd, gaussians)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _normalized_grid_weights(self):
        """Return the grid values divided by their sum."""
        return self.gd.grid_values / backend_sum(self.gd.grid_values)

    def _linear_moments(self):
        """
        Moment-match the linear marginal, treated as a Gaussian mixture.

        Returns
        -------
        (mu, C)
            The pair returned by
            ``GaussianMixture.mixture_parameters_to_gaussian_parameters``.
        """
        means = array([ld.mu for ld in self.linear_distributions])  # (n, lin_dim)
        covs = stack(
            [ld.C for ld in self.linear_distributions], axis=2
        )  # (lin_dim, lin_dim, n)
        return GaussianMixture.mixture_parameters_to_gaussian_parameters(
            means, covs, self._normalized_grid_weights()
        )

    # ------------------------------------------------------------------
    # Marginalisation
    # ------------------------------------------------------------------

    def marginalize_linear(self):
        """Return the grid distribution (marginalised over the linear part)."""
        return copy.deepcopy(self.gd)

    def marginalize_periodic(self):
        """
        Marginalise over the periodic/bounded dimensions.

        Returns a GaussianMixture whose components are the conditional
        Gaussians and whose weights are the (normalised) grid values.
        """
        return GaussianMixture(
            list(self.linear_distributions), self._normalized_grid_weights()
        )

    # ------------------------------------------------------------------
    # Linear moments
    # ------------------------------------------------------------------

    def linear_mean(self):
        """
        Compute the mean of the marginal linear distribution by treating
        the state as a Gaussian mixture.

        Returns
        -------
        mu : array, shape (lin_dim,)
        """
        mu, _ = self._linear_moments()
        return mu

    def linear_covariance(self):
        """
        Compute the covariance of the marginal linear distribution by treating
        the state as a Gaussian mixture.

        Returns
        -------
        C : array, shape (lin_dim, lin_dim)
        """
        _, C = self._linear_moments()
        return C

    # ------------------------------------------------------------------
    # Multiplication
    # ------------------------------------------------------------------

    def multiply(self, other):
        """
        Multiply two StateSpaceSubdivisionGaussianDistributions.

        Both operands must be defined on the same grid.  For each grid point
        the conditional Gaussians are multiplied (Bayesian update).  The grid
        weights are updated by the likelihood factors that arise from the
        overlap of the two conditional Gaussians.

        Parameters
        ----------
        other : StateSpaceSubdivisionGaussianDistribution

        Returns
        -------
        StateSpaceSubdivisionGaussianDistribution
            A new object; neither operand is modified.

        Raises
        ------
        AssertionError
            If *other* has a different type, a different number of grid
            points, or a grid that is not (numerically) equal to this one.
        """
        assert isinstance(other, StateSpaceSubdivisionGaussianDistribution)
        assert self.gd.n_grid_points == other.gd.n_grid_points, (
            "Can only multiply distributions defined on grids with the same "
            "number of grid points."
        )
        self_grid = asarray(self.gd.get_grid())
        other_grid = asarray(other.gd.get_grid())
        assert allclose(self_grid, other_grid), (
            "Can only multiply for equal grids."
        )

        new_linear_distributions = []
        pdf_values = []

        for ld_self, ld_other in zip(
            self.linear_distributions, other.linear_distributions
        ):
            # The likelihood factor for this grid point is the overlap
            # integral of the two conditional Gaussians, which equals the pdf
            # of N(mu_other, C_self + C_other) evaluated at mu_self.  The
            # expression is symmetric in the two means, i.e. it is the
            # zero-mean Gaussian with the summed covariance evaluated at
            # mu_self - mu_other.
            combined_cov = ld_self.C + ld_other.C
            temp_g = GaussianDistribution(
                ld_other.mu, combined_cov, check_validity=False
            )
            pdf_values.append(temp_g.pdf(ld_self.mu))

            new_linear_distributions.append(ld_self.multiply(ld_other))

        # Build a 1-D factors array.  pdf() may return shape () or (1,) depending
        # on backend and Gaussian dimension; reshape each value to (1,) before
        # concatenating so the result is always shape (n,).
        factors_linear = concatenate([asarray(v).reshape((1,)) for v in pdf_values])

        # deepcopy(self) already yields an independent copy of the grid
        # distribution, so result.gd can be updated in place without touching
        # either operand.
        result = copy.deepcopy(self)
        result.linear_distributions = new_linear_distributions
        result.gd.grid_values = (
            self.gd.grid_values * other.gd.grid_values * factors_linear
        )
        result.gd.normalize_in_place(warn_unnorm=False)
        return result

    # ------------------------------------------------------------------
    # Mode
    # ------------------------------------------------------------------

    def mode(self):
        """
        Compute the (approximate) joint mode.

        The mode is found by maximising the product of the conditional
        Gaussian peak value and the grid weight at each grid point.  Only
        the discrete grid is searched (no interpolation).

        Returns
        -------
        m : array, shape (bound_dim + lin_dim,)
            Concatenation of the periodic mode (grid point) and the linear
            mode (mean of the conditional Gaussian at that grid point).

        Warns
        -----
        UserWarning
            If the density appears multimodal (i.e. another grid point has a
            joint value within a factor of 1.001 of the maximum).
        """
        zeros_d = zeros(self.lin_dim)

        # The peak value of N(mu_i, C_i) depends only on C_i, so it is
        # obtained from a zero-mean Gaussian with the same covariance
        # evaluated at the origin.
        peak_vals = array(
            [
                float(
                    GaussianDistribution(zeros_d, ld.C, check_validity=False).pdf(
                        zeros_d
                    )
                )
                for ld in self.linear_distributions
            ]
        )

        fun_vals_joint = peak_vals * asarray(self.gd.grid_values)
        index = int(argmax(fun_vals_joint))
        max_val = float(fun_vals_joint[index])

        # Remove the maximum entry to check for multimodality
        remaining = concatenate(
            [fun_vals_joint[:index], fun_vals_joint[index + 1:]]  # noqa: E203
        )
        # The ratio test "max_val / remaining < 1.001" is written as a product
        # so that grid points with zero joint value do not cause a division by
        # zero (inf/nan and backend runtime warnings).
        if len(remaining) > 0 and (
            backend_any((max_val - remaining) < 1e-15)
            or backend_any(remaining * 1.001 > max_val)
        ):
            warnings.warn(
                "Density may not be unimodal. However, this can also be caused "
                "by a high grid resolution and thus very similar function values "
                "at the grid points.",
                UserWarning,
                stacklevel=2,
            )

        periodic_mode = self.gd.get_grid_point(index)  # shape (bound_dim,)
        linear_mode = self.linear_distributions[index].mu  # shape (lin_dim,)
        return concatenate([periodic_mode.reshape(-1), linear_mode.reshape(-1)])

    # ------------------------------------------------------------------
    # Unsupported operations
    # ------------------------------------------------------------------

    def convolve(self, _other):
        """Always raises: convolution is not implemented for this class."""
        raise NotImplementedError(
            "convolve is not supported for "
            "StateSpaceSubdivisionGaussianDistribution."
        )

pyrecest/distributions/nonperiodic/gaussian_mixture.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# pylint: disable=redefined-builtin,no-name-in-module,no-member
22
# pylint: disable=no-name-in-module,no-member
3-
from pyrecest.backend import array, dot, ones, stack, sum
3+
from pyrecest.backend import array, ones, reshape, stack, sum
44

55
from .abstract_linear_distribution import AbstractLinearDistribution
66
from .gaussian_distribution import GaussianDistribution
@@ -15,7 +15,8 @@ def __init__(self, dists: list[GaussianDistribution], w):
1515

1616
def mean(self):
1717
gauss_array = self.dists
18-
return dot(array([g.mu for g in gauss_array]), self.w)
18+
means = array([g.mu for g in gauss_array]) # shape (n, dim)
19+
return sum(means * reshape(self.w, (-1, 1)), axis=0)
1920

2021
def set_mean(self, new_mean):
2122
mean_offset = new_mean - self.mean()

0 commit comments

Comments
 (0)