|
| 1 | +# pylint: disable=redefined-builtin,no-name-in-module,no-member |
| 2 | +import copy |
| 3 | +from math import factorial |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from pyrecest.backend import array, cos, exp, mod, pi, sin |
| 7 | +from scipy.integrate import dblquad |
| 8 | +from scipy.special import iv |
| 9 | + |
| 10 | +from ..circle.custom_circular_distribution import CustomCircularDistribution |
| 11 | +from .abstract_toroidal_distribution import AbstractToroidalDistribution |
| 12 | + |
| 13 | + |
| 14 | +class ToroidalVMMatrixDistribution(AbstractToroidalDistribution): |
| 15 | + """Bivariate von Mises distribution, matrix version. |
| 16 | +
|
| 17 | + See: |
| 18 | + - Mardia, K. V. Statistics of Directional Data. JRSS-B, 1975. |
| 19 | + - Mardia, K. V. & Jupp, P. E. Directional Statistics. Wiley, 1999. |
| 20 | + - Kurz, Hanebeck. Toroidal Information Fusion Based on the Bivariate |
| 21 | + von Mises Distribution. MFI 2015. |
| 22 | + """ |
| 23 | + |
| 24 | + def __init__(self, mu, kappa, A): |
| 25 | + AbstractToroidalDistribution.__init__(self) |
| 26 | + assert mu.shape == (2,) |
| 27 | + assert kappa.shape == (2,) |
| 28 | + assert A.shape == (2, 2) |
| 29 | + assert kappa[0] > 0 |
| 30 | + assert kappa[1] > 0 |
| 31 | + |
| 32 | + self.mu = mod(mu, 2.0 * pi) |
| 33 | + self.kappa = kappa |
| 34 | + self.A = A |
| 35 | + |
| 36 | + A_np = np.array(A, dtype=float) if not isinstance(A, np.ndarray) else A.astype(float) |
| 37 | + use_numerical = ( |
| 38 | + float(kappa[0]) > 1.5 |
| 39 | + or float(kappa[1]) > 1.5 |
| 40 | + or np.max(np.abs(A_np)) > 1.0 |
| 41 | + ) |
| 42 | + |
| 43 | + if use_numerical: |
| 44 | + self.C = 1.0 |
| 45 | + Cinv, _ = dblquad( |
| 46 | + lambda y, x: float(self.pdf(array([x, y]))), |
| 47 | + 0.0, |
| 48 | + float(2 * pi), |
| 49 | + 0.0, |
| 50 | + float(2 * pi), |
| 51 | + ) |
| 52 | + self.C = 1.0 / Cinv |
| 53 | + else: |
| 54 | + self.C = self._norm_const_approx() |
| 55 | + |
| 56 | + def pdf(self, xs): |
| 57 | + assert xs.shape[-1] == 2 |
| 58 | + x1_mm = xs[..., 0] - self.mu[0] |
| 59 | + x2_mm = xs[..., 1] - self.mu[1] |
| 60 | + exponent = ( |
| 61 | + self.kappa[0] * cos(x1_mm) |
| 62 | + + self.kappa[1] * cos(x2_mm) |
| 63 | + + cos(x1_mm) * self.A[0, 0] * cos(x2_mm) |
| 64 | + + cos(x1_mm) * self.A[0, 1] * sin(x2_mm) |
| 65 | + + sin(x1_mm) * self.A[1, 0] * cos(x2_mm) |
| 66 | + + sin(x1_mm) * self.A[1, 1] * sin(x2_mm) |
| 67 | + ) |
| 68 | + return self.C * exp(exponent) |
| 69 | + |
| 70 | + def _norm_const_approx(self, n=8): |
| 71 | + """Approximate normalization constant using Taylor series (up to n=8 summands).""" |
| 72 | + a11 = float(self.A[0, 0]) |
| 73 | + a12 = float(self.A[0, 1]) |
| 74 | + a21 = float(self.A[1, 0]) |
| 75 | + a22 = float(self.A[1, 1]) |
| 76 | + k1 = float(self.kappa[0]) |
| 77 | + k2 = float(self.kappa[1]) |
| 78 | + |
| 79 | + total = 4 * np.pi**2 # n=0 term |
| 80 | + # n=1 term is zero |
| 81 | + if n >= 2: |
| 82 | + total += ( |
| 83 | + (a11**2 + a12**2 + a21**2 + a22**2 + 2 * k1**2 + 2 * k2**2) |
| 84 | + * np.pi**2 |
| 85 | + / factorial(2) |
| 86 | + ) |
| 87 | + if n >= 3: |
| 88 | + total += 6 * a11 * k1 * k2 * np.pi**2 / factorial(3) |
| 89 | + if n >= 4: |
| 90 | + total += ( |
| 91 | + 3 |
| 92 | + / 16 |
| 93 | + * ( |
| 94 | + 3 * a11**4 |
| 95 | + + 3 * a12**4 |
| 96 | + + 3 * a21**4 |
| 97 | + + 8 * a11 * a12 * a21 * a22 |
| 98 | + + 6 * a21**2 * a22**2 |
| 99 | + + 3 * a22**4 |
| 100 | + + 8 * a21**2 * k1**2 |
| 101 | + + 8 * a22**2 * k1**2 |
| 102 | + + 8 * k1**4 |
| 103 | + + 8 * (3 * a21**2 + a22**2 + 4 * k1**2) * k2**2 |
| 104 | + + 8 * k2**4 |
| 105 | + + 2 * a11**2 * (3 * a12**2 + 3 * a21**2 + a22**2 + 12 * (k1**2 + k2**2)) |
| 106 | + + 2 * a12**2 * (a21**2 + 3 * a22**2 + 4 * (3 * k1**2 + k2**2)) |
| 107 | + ) |
| 108 | + * np.pi**2 |
| 109 | + / factorial(4) |
| 110 | + ) |
| 111 | + if n >= 5: |
| 112 | + total += ( |
| 113 | + 15 |
| 114 | + / 4 |
| 115 | + * np.pi**2 |
| 116 | + * k1 |
| 117 | + * k2 |
| 118 | + * ( |
| 119 | + 3 * a11**3 |
| 120 | + + 3 * a11 * a12**2 |
| 121 | + + 3 * a11 * a21**2 |
| 122 | + + a11 * a22**2 |
| 123 | + + 4 * a11 * k1**2 |
| 124 | + + 4 * a11 * k2**2 |
| 125 | + + 2 * a12 * a21 * a22 |
| 126 | + ) |
| 127 | + / factorial(5) |
| 128 | + ) |
| 129 | + if n >= 6: |
| 130 | + total += ( |
| 131 | + 5 |
| 132 | + / 64 |
| 133 | + * np.pi**2 |
| 134 | + * ( |
| 135 | + 5 * a11**6 |
| 136 | + + 15 * a11**4 * a12**2 |
| 137 | + + 15 * a11**4 * a21**2 |
| 138 | + + 3 * a11**4 * a22**2 |
| 139 | + + 90 * a11**4 * k1**2 |
| 140 | + + 90 * a11**4 * k2**2 |
| 141 | + + 24 * a11**3 * a12 * a21 * a22 |
| 142 | + + 15 * a11**2 * a12**4 |
| 143 | + + 18 * a11**2 * a12**2 * a21**2 |
| 144 | + + 18 * a11**2 * a12**2 * a22**2 |
| 145 | + + 180 * a11**2 * a12**2 * k1**2 |
| 146 | + + 108 * a11**2 * a12**2 * k2**2 |
| 147 | + + 15 * a11**2 * a21**4 |
| 148 | + + 18 * a11**2 * a21**2 * a22**2 |
| 149 | + + 108 * a11**2 * a21**2 * k1**2 |
| 150 | + + 180 * a11**2 * a21**2 * k2**2 |
| 151 | + + 3 * a11**2 * a22**4 |
| 152 | + + 36 * a11**2 * a22**2 * k1**2 |
| 153 | + + 36 * a11**2 * a22**2 * k2**2 |
| 154 | + + 120 * a11**2 * k1**4 |
| 155 | + + 648 * a11**2 * k1**2 * k2**2 |
| 156 | + + 120 * a11**2 * k2**4 |
| 157 | + + 24 * a11 * a12**3 * a21 * a22 |
| 158 | + + 24 * a11 * a12 * a21**3 * a22 |
| 159 | + + 24 * a11 * a12 * a21 * a22**3 |
| 160 | + + 144 * a11 * a12 * a21 * a22 * k1**2 |
| 161 | + + 144 * a11 * a12 * a21 * a22 * k2**2 |
| 162 | + + 5 * a12**6 |
| 163 | + + 3 * a12**4 * a21**2 |
| 164 | + + 15 * a12**4 * a22**2 |
| 165 | + + 90 * a12**4 * k1**2 |
| 166 | + + 18 * a12**4 * k2**2 |
| 167 | + + 3 * a12**2 * a21**4 |
| 168 | + + 18 * a12**2 * a21**2 * a22**2 |
| 169 | + + 36 * a12**2 * a21**2 * k1**2 |
| 170 | + + 36 * a12**2 * a21**2 * k2**2 |
| 171 | + + 15 * a12**2 * a22**4 |
| 172 | + + 108 * a12**2 * a22**2 * k1**2 |
| 173 | + + 36 * a12**2 * a22**2 * k2**2 |
| 174 | + + 120 * a12**2 * k1**4 |
| 175 | + + 216 * a12**2 * k1**2 * k2**2 |
| 176 | + + 24 * a12**2 * k2**4 |
| 177 | + + 5 * a21**6 |
| 178 | + + 15 * a21**4 * a22**2 |
| 179 | + + 18 * a21**4 * k1**2 |
| 180 | + + 90 * a21**4 * k2**2 |
| 181 | + + 15 * a21**2 * a22**4 |
| 182 | + + 36 * a21**2 * a22**2 * k1**2 |
| 183 | + + 108 * a21**2 * a22**2 * k2**2 |
| 184 | + + 24 * a21**2 * k1**4 |
| 185 | + + 216 * a21**2 * k1**2 * k2**2 |
| 186 | + + 120 * a21**2 * k2**4 |
| 187 | + + 5 * a22**6 |
| 188 | + + 18 * a22**4 * k1**2 |
| 189 | + + 18 * a22**4 * k2**2 |
| 190 | + + 24 * a22**2 * k1**4 |
| 191 | + + 72 * a22**2 * k1**2 * k2**2 |
| 192 | + + 24 * a22**2 * k2**4 |
| 193 | + + 16 * k1**6 |
| 194 | + + 144 * k1**4 * k2**2 |
| 195 | + + 144 * k1**2 * k2**4 |
| 196 | + + 16 * k2**6 |
| 197 | + ) |
| 198 | + / factorial(6) |
| 199 | + ) |
| 200 | + if n >= 7: |
| 201 | + total += ( |
| 202 | + 105 |
| 203 | + / 32 |
| 204 | + * k1 |
| 205 | + * k2 |
| 206 | + * np.pi**2 |
| 207 | + * ( |
| 208 | + 5 * a11**5 |
| 209 | + + 10 * a11**3 * a12**2 |
| 210 | + + 10 * a11**3 * a21**2 |
| 211 | + + 2 * a11**3 * a22**2 |
| 212 | + + 20 * a11**3 * k1**2 |
| 213 | + + 20 * a11**3 * k2**2 |
| 214 | + + 12 * a11**2 * a12 * a21 * a22 |
| 215 | + + 5 * a11 * a12**4 |
| 216 | + + 6 * a11 * a12**2 * a21**2 |
| 217 | + + 6 * a11 * a12**2 * a22**2 |
| 218 | + + 20 * a11 * a12**2 * k1**2 |
| 219 | + + 12 * a11 * a12**2 * k2**2 |
| 220 | + + 5 * a11 * a21**4 |
| 221 | + + 6 * a11 * a21**2 * a22**2 |
| 222 | + + 12 * a11 * a21**2 * k1**2 |
| 223 | + + 20 * a11 * a21**2 * k2**2 |
| 224 | + + a11 * a22**4 |
| 225 | + + 4 * a11 * a22**2 * k1**2 |
| 226 | + + 4 * a11 * a22**2 * k2**2 |
| 227 | + + 8 * a11 * k1**4 |
| 228 | + + 24 * a11 * k1**2 * k2**2 |
| 229 | + + 8 * a11 * k2**4 |
| 230 | + + 4 * a12**3 * a21 * a22 |
| 231 | + + 4 * a12 * a21**3 * a22 |
| 232 | + + 4 * a12 * a21 * a22**3 |
| 233 | + + 8 * a12 * a21 * a22 * k1**2 |
| 234 | + + 8 * a12 * a21 * a22 * k2**2 |
| 235 | + ) |
| 236 | + / factorial(7) |
| 237 | + ) |
| 238 | + return 1.0 / total |
| 239 | + |
| 240 | + def multiply(self, other): |
| 241 | + """Multiply two ToroidalVMMatrixDistributions (exact product).""" |
| 242 | + assert isinstance(other, ToroidalVMMatrixDistribution) |
| 243 | + |
| 244 | + C1 = float(self.kappa[0]) * np.cos(float(self.mu[0])) + float(other.kappa[0]) * np.cos(float(other.mu[0])) |
| 245 | + S1 = float(self.kappa[0]) * np.sin(float(self.mu[0])) + float(other.kappa[0]) * np.sin(float(other.mu[0])) |
| 246 | + C2 = float(self.kappa[1]) * np.cos(float(self.mu[1])) + float(other.kappa[1]) * np.cos(float(other.mu[1])) |
| 247 | + S2 = float(self.kappa[1]) * np.sin(float(self.mu[1])) + float(other.kappa[1]) * np.sin(float(other.mu[1])) |
| 248 | + |
| 249 | + mu_new = array([np.arctan2(S1, C1) % (2 * np.pi), np.arctan2(S2, C2) % (2 * np.pi)]) |
| 250 | + kappa_new = array([np.sqrt(C1**2 + S1**2), np.sqrt(C2**2 + S2**2)]) |
| 251 | + |
| 252 | + def _M(mu): |
| 253 | + c1, s1 = np.cos(float(mu[0])), np.sin(float(mu[0])) |
| 254 | + c2, s2 = np.cos(float(mu[1])), np.sin(float(mu[1])) |
| 255 | + return np.array([ |
| 256 | + [ c1 * c2, -s1 * c2, -c1 * s2, s1 * s2], |
| 257 | + [ s1 * c2, c1 * c2, -s1 * s2, -c1 * s2], |
| 258 | + [ c1 * s2, -s1 * s2, c1 * c2, -s1 * c2], |
| 259 | + [ s1 * s2, c1 * s2, s1 * c2, c1 * c2], |
| 260 | + ]) |
| 261 | + |
| 262 | + # Pack A columns as [A11; A21; A12; A22] |
| 263 | + A1 = np.array([[float(self.A[0, 0])], [float(self.A[1, 0])], [float(self.A[0, 1])], [float(self.A[1, 1])]]) |
| 264 | + A2 = np.array([[float(other.A[0, 0])], [float(other.A[1, 0])], [float(other.A[0, 1])], [float(other.A[1, 1])]]) |
| 265 | + b = _M(self.mu) @ A1 + _M(other.mu) @ A2 |
| 266 | + a = np.linalg.solve(_M(mu_new), b).ravel() |
| 267 | + A_new = array([[a[0], a[2]], [a[1], a[3]]]) |
| 268 | + |
| 269 | + return ToroidalVMMatrixDistribution(mu_new, kappa_new, A_new) |
| 270 | + |
| 271 | + def marginalize_to_1d(self, dimension): |
| 272 | + """Get marginal distribution in the given dimension (0 or 1, 0-indexed). |
| 273 | +
|
| 274 | + Integrates out the *other* dimension analytically using the Bessel |
| 275 | + function identity for the von-Mises-type integral. |
| 276 | + """ |
| 277 | + assert dimension in (0, 1) |
| 278 | + other = 1 - dimension |
| 279 | + |
| 280 | + mu_d = float(self.mu[dimension]) |
| 281 | + mu_o = float(self.mu[other]) # noqa: F841 – retained for clarity |
| 282 | + k_d = float(self.kappa[dimension]) |
| 283 | + k_o = float(self.kappa[other]) |
| 284 | + a11 = float(self.A[0, 0]) |
| 285 | + a12 = float(self.A[0, 1]) |
| 286 | + a21 = float(self.A[1, 0]) |
| 287 | + a22 = float(self.A[1, 1]) |
| 288 | + C = float(self.C) |
| 289 | + |
| 290 | + if dimension == 0: |
| 291 | + # Integrate over x2; x = x1 |
| 292 | + def f(x): |
| 293 | + dx = x - mu_d |
| 294 | + alpha = k_o + np.cos(dx) * a11 + np.sin(dx) * a21 |
| 295 | + beta = np.cos(dx) * a12 + np.sin(dx) * a22 |
| 296 | + return 2 * np.pi * C * iv(0, np.sqrt(alpha**2 + beta**2)) * np.exp(k_d * np.cos(dx)) |
| 297 | + else: |
| 298 | + # Integrate over x1; x = x2 |
| 299 | + def f(x): |
| 300 | + dx = x - mu_d |
| 301 | + alpha = k_o + np.cos(dx) * a11 + np.sin(dx) * a12 |
| 302 | + beta = np.cos(dx) * a21 + np.sin(dx) * a22 |
| 303 | + return 2 * np.pi * C * iv(0, np.sqrt(alpha**2 + beta**2)) * np.exp(k_d * np.cos(dx)) |
| 304 | + |
| 305 | + return CustomCircularDistribution(f) |
| 306 | + |
| 307 | + def shift(self, shift_angles): |
| 308 | + """Return a copy of this distribution shifted by shift_angles.""" |
| 309 | + assert shift_angles.shape == (2,) |
| 310 | + result = copy.copy(self) |
| 311 | + result.mu = mod(self.mu + shift_angles, 2.0 * pi) |
| 312 | + return result |
0 commit comments