|
| 1 | +import copy |
| 2 | +import warnings |
| 3 | + |
| 4 | +# pylint: disable=no-name-in-module,no-member |
| 5 | +from pyrecest.backend import ( |
| 6 | + allclose, |
| 7 | + any as backend_any, |
| 8 | + argmax, |
| 9 | + array, |
| 10 | + asarray, |
| 11 | + concatenate, |
| 12 | + stack, |
| 13 | + sum as backend_sum, |
| 14 | + zeros, |
| 15 | +) |
| 16 | + |
| 17 | +from ..nonperiodic.gaussian_distribution import GaussianDistribution |
| 18 | +from ..nonperiodic.gaussian_mixture import GaussianMixture |
| 19 | +from .state_space_subdivision_distribution import StateSpaceSubdivisionDistribution |
| 20 | + |
| 21 | + |
class StateSpaceSubdivisionGaussianDistribution(StateSpaceSubdivisionDistribution):
    """
    Joint distribution over a Cartesian product of a grid-based
    (periodic/bounded) space and a linear space where every conditional
    linear distribution is a Gaussian.

    The periodic part is a grid distribution (e.g. HypertoroidalGridDistribution
    or HyperhemisphericalGridDistribution). The linear part is a list of
    GaussianDistribution objects, one per grid point.
    """

    def __init__(self, gd, gaussians):
        """
        Parameters
        ----------
        gd : AbstractGridDistribution
            Grid-based distribution for the periodic/bounded part.
        gaussians : list of GaussianDistribution
            One Gaussian per grid point of *gd*.
        """
        assert all(isinstance(g, GaussianDistribution) for g in gaussians), (
            "All elements of gaussians must be GaussianDistribution instances."
        )
        super().__init__(gd, gaussians)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _normalized_grid_weights(self):
        """Return the grid values divided by their sum (mixture weights)."""
        grid_values = self.gd.grid_values
        return grid_values / backend_sum(grid_values)

    def _linear_moments(self):
        """
        Moment-match the conditional Gaussians, weighted by the normalised
        grid values.

        Returns whatever
        ``GaussianMixture.mixture_parameters_to_gaussian_parameters`` returns
        for the stacked means, stacked covariances and weights; the callers
        unpack it as ``(mean, covariance)``.
        """
        means = array([ld.mu for ld in self.linear_distributions])  # (n, lin_dim)
        covs = stack(
            [ld.C for ld in self.linear_distributions], axis=2
        )  # (lin_dim, lin_dim, n)
        return GaussianMixture.mixture_parameters_to_gaussian_parameters(
            means, covs, self._normalized_grid_weights()
        )

    @staticmethod
    def _as_flat_vector(scalars):
        """
        Pack single-element values into one 1-D backend array.

        pdf() may return shape () or (1,) depending on backend and Gaussian
        dimension; each value is reshaped to (1,) before concatenating so the
        result always has shape (len(scalars),).
        """
        return concatenate([asarray(value).reshape((1,)) for value in scalars])

    # ------------------------------------------------------------------
    # Marginalisation
    # ------------------------------------------------------------------

    def marginalize_linear(self):
        """Return the grid distribution (marginalised over the linear part)."""
        return copy.deepcopy(self.gd)

    def marginalize_periodic(self):
        """
        Marginalise over the periodic/bounded dimensions.

        Returns a GaussianMixture whose components are the conditional
        Gaussians and whose weights are the (normalised) grid values.
        """
        return GaussianMixture(
            list(self.linear_distributions), self._normalized_grid_weights()
        )

    # ------------------------------------------------------------------
    # Linear moments
    # ------------------------------------------------------------------

    def linear_mean(self):
        """
        Compute the mean of the marginal linear distribution by treating
        the state as a Gaussian mixture.

        Returns
        -------
        mu : array, shape (lin_dim,)
        """
        mu, _ = self._linear_moments()
        return mu

    def linear_covariance(self):
        """
        Compute the covariance of the marginal linear distribution by treating
        the state as a Gaussian mixture.

        Returns
        -------
        C : array, shape (lin_dim, lin_dim)
        """
        _, C = self._linear_moments()
        return C

    # ------------------------------------------------------------------
    # Multiplication
    # ------------------------------------------------------------------

    def multiply(self, other):
        """
        Multiply two StateSpaceSubdivisionGaussianDistributions.

        Both operands must be defined on the same grid. For each grid point
        the conditional Gaussians are multiplied (Bayesian update). The grid
        weights are updated by the likelihood factors that arise from the
        overlap of the two conditional Gaussians.

        Parameters
        ----------
        other : StateSpaceSubdivisionGaussianDistribution

        Returns
        -------
        StateSpaceSubdivisionGaussianDistribution
            A new object; neither operand is modified.

        Raises
        ------
        AssertionError
            If *other* has the wrong type, a different number of grid
            points, or a different grid.
        """
        assert isinstance(other, StateSpaceSubdivisionGaussianDistribution)
        assert self.gd.n_grid_points == other.gd.n_grid_points, (
            "Can only multiply distributions defined on grids with the same "
            "number of grid points."
        )
        self_grid = asarray(self.gd.get_grid())
        other_grid = asarray(other.gd.get_grid())
        assert allclose(self_grid, other_grid), (
            "Can only multiply for equal grids."
        )

        new_linear_distributions = []
        pdf_values = []

        for i, ld_self in enumerate(self.linear_distributions):
            ld_other = other.linear_distributions[i]

            # The likelihood factor for grid point i is the pdf of
            # N(mu_other_i, C_self_i + C_other_i) evaluated at mu_self_i.
            # The expression is symmetric in the two means: it equals
            # N(0, C_self_i + C_other_i) evaluated at mu_self_i - mu_other_i.
            combined_cov = ld_self.C + ld_other.C
            temp_g = GaussianDistribution(
                ld_other.mu, combined_cov, check_validity=False
            )
            pdf_values.append(temp_g.pdf(ld_self.mu))

            new_linear_distributions.append(ld_self.multiply(ld_other))

        factors_linear = self._as_flat_vector(pdf_values)  # shape (n,)

        # deepcopy(self) already yields an independent copy of the grid
        # distribution, so result.gd can be modified in place without
        # touching self.gd.
        result = copy.deepcopy(self)
        result.linear_distributions = new_linear_distributions
        result.gd.grid_values = (
            self.gd.grid_values * other.gd.grid_values * factors_linear
        )
        result.gd.normalize_in_place(warn_unnorm=False)
        return result

    # ------------------------------------------------------------------
    # Mode
    # ------------------------------------------------------------------

    def mode(self):
        """
        Compute the (approximate) joint mode.

        The mode is found by maximising the product of the conditional
        Gaussian peak value and the grid weight at each grid point. Only
        the discrete grid is searched (no interpolation).

        Returns
        -------
        m : array, shape (bound_dim + lin_dim,)
            Concatenation of the periodic mode (grid point) and the linear
            mode (mean of the conditional Gaussian at that grid point).

        Warns
        -----
        UserWarning
            If the density appears multimodal (i.e. another grid point has a
            joint value within a factor of 1.001 of the maximum).
        """
        lin_dim = self.linear_distributions[0].dim
        zeros_d = zeros(lin_dim)

        # Peak value of N(mu_i, C_i) depends only on C_i; it equals
        # N(0 | 0, C_i). A zero-mean Gaussian with the same covariance is
        # therefore evaluated at the origin. The values are packed the same
        # way as in multiply() because pdf() may return shape () or (1,).
        peak_vals = self._as_flat_vector(
            [
                GaussianDistribution(zeros_d, ld.C, check_validity=False).pdf(
                    zeros_d
                )
                for ld in self.linear_distributions
            ]
        )

        fun_vals_joint = peak_vals * asarray(self.gd.grid_values)
        index = int(argmax(fun_vals_joint))
        max_val = float(fun_vals_joint[index])

        # Remove the maximum entry to check for multimodality
        remaining = concatenate(
            [fun_vals_joint[:index], fun_vals_joint[index + 1:]]  # noqa: E203
        )
        # "max_val / remaining < 1.001" is written as a multiplication so that
        # grid points with value zero do not trigger a division by zero.
        if len(remaining) > 0 and (
            backend_any((max_val - remaining) < 1e-15)
            or backend_any(remaining * 1.001 > max_val)
        ):
            warnings.warn(
                "Density may not be unimodal. However, this can also be caused "
                "by a high grid resolution and thus very similar function values "
                "at the grid points.",
                UserWarning,
                stacklevel=2,
            )

        periodic_mode = self.gd.get_grid_point(index)  # shape (bound_dim,)
        linear_mode = self.linear_distributions[index].mu  # shape (lin_dim,)
        return concatenate([periodic_mode.reshape(-1), linear_mode.reshape(-1)])

    # ------------------------------------------------------------------
    # Unsupported operations
    # ------------------------------------------------------------------

    def convolve(self, _other):
        """Always raise NotImplementedError: convolution is not supported."""
        raise NotImplementedError(
            "convolve is not supported for "
            "StateSpaceSubdivisionGaussianDistribution."
        )
0 commit comments