Skip to content

Commit 7e7143e

Browse files
CopilotFlorianPfaff
andcommitted
Add ToroidalVMMatrixDistribution (bivariate von Mises, matrix version)
Port MATLAB ToroidalVMMatrixDistribution to Python. Implements: - PDF with kappa concentration and A correlation matrix - Series approximation normalization for low concentrations (n=7 terms) - Numerical normalization (dblquad) for high concentrations - multiply(): exact product of two distributions - marginalize_to_1d(): Bessel-function-based analytic marginal - shift(): shift mu parameters Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
1 parent c3471d3 commit 7e7143e

2 files changed

Lines changed: 414 additions & 0 deletions

File tree

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
2+
import copy
3+
from math import factorial
4+
5+
import numpy as np
6+
from pyrecest.backend import array, cos, exp, mod, pi, sin
7+
from scipy.integrate import dblquad
8+
from scipy.special import iv
9+
10+
from ..circle.custom_circular_distribution import CustomCircularDistribution
11+
from .abstract_toroidal_distribution import AbstractToroidalDistribution
12+
13+
14+
class ToroidalVMMatrixDistribution(AbstractToroidalDistribution):
15+
"""Bivariate von Mises distribution, matrix version.
16+
17+
See:
18+
- Mardia, K. V. Statistics of Directional Data. JRSS-B, 1975.
19+
- Mardia, K. V. & Jupp, P. E. Directional Statistics. Wiley, 1999.
20+
- Kurz, Hanebeck. Toroidal Information Fusion Based on the Bivariate
21+
von Mises Distribution. MFI 2015.
22+
"""
23+
24+
def __init__(self, mu, kappa, A):
25+
AbstractToroidalDistribution.__init__(self)
26+
assert mu.shape == (2,)
27+
assert kappa.shape == (2,)
28+
assert A.shape == (2, 2)
29+
assert kappa[0] > 0
30+
assert kappa[1] > 0
31+
32+
self.mu = mod(mu, 2.0 * pi)
33+
self.kappa = kappa
34+
self.A = A
35+
36+
A_np = np.array(A, dtype=float) if not isinstance(A, np.ndarray) else A.astype(float)
37+
use_numerical = (
38+
float(kappa[0]) > 1.5
39+
or float(kappa[1]) > 1.5
40+
or np.max(np.abs(A_np)) > 1.0
41+
)
42+
43+
if use_numerical:
44+
self.C = 1.0
45+
Cinv, _ = dblquad(
46+
lambda y, x: float(self.pdf(array([x, y]))),
47+
0.0,
48+
float(2 * pi),
49+
0.0,
50+
float(2 * pi),
51+
)
52+
self.C = 1.0 / Cinv
53+
else:
54+
self.C = self._norm_const_approx()
55+
56+
def pdf(self, xs):
57+
assert xs.shape[-1] == 2
58+
x1_mm = xs[..., 0] - self.mu[0]
59+
x2_mm = xs[..., 1] - self.mu[1]
60+
exponent = (
61+
self.kappa[0] * cos(x1_mm)
62+
+ self.kappa[1] * cos(x2_mm)
63+
+ cos(x1_mm) * self.A[0, 0] * cos(x2_mm)
64+
+ cos(x1_mm) * self.A[0, 1] * sin(x2_mm)
65+
+ sin(x1_mm) * self.A[1, 0] * cos(x2_mm)
66+
+ sin(x1_mm) * self.A[1, 1] * sin(x2_mm)
67+
)
68+
return self.C * exp(exponent)
69+
70+
def _norm_const_approx(self, n=8):
71+
"""Approximate normalization constant using Taylor series (up to n=8 summands)."""
72+
a11 = float(self.A[0, 0])
73+
a12 = float(self.A[0, 1])
74+
a21 = float(self.A[1, 0])
75+
a22 = float(self.A[1, 1])
76+
k1 = float(self.kappa[0])
77+
k2 = float(self.kappa[1])
78+
79+
total = 4 * np.pi**2 # n=0 term
80+
# n=1 term is zero
81+
if n >= 2:
82+
total += (
83+
(a11**2 + a12**2 + a21**2 + a22**2 + 2 * k1**2 + 2 * k2**2)
84+
* np.pi**2
85+
/ factorial(2)
86+
)
87+
if n >= 3:
88+
total += 6 * a11 * k1 * k2 * np.pi**2 / factorial(3)
89+
if n >= 4:
90+
total += (
91+
3
92+
/ 16
93+
* (
94+
3 * a11**4
95+
+ 3 * a12**4
96+
+ 3 * a21**4
97+
+ 8 * a11 * a12 * a21 * a22
98+
+ 6 * a21**2 * a22**2
99+
+ 3 * a22**4
100+
+ 8 * a21**2 * k1**2
101+
+ 8 * a22**2 * k1**2
102+
+ 8 * k1**4
103+
+ 8 * (3 * a21**2 + a22**2 + 4 * k1**2) * k2**2
104+
+ 8 * k2**4
105+
+ 2 * a11**2 * (3 * a12**2 + 3 * a21**2 + a22**2 + 12 * (k1**2 + k2**2))
106+
+ 2 * a12**2 * (a21**2 + 3 * a22**2 + 4 * (3 * k1**2 + k2**2))
107+
)
108+
* np.pi**2
109+
/ factorial(4)
110+
)
111+
if n >= 5:
112+
total += (
113+
15
114+
/ 4
115+
* np.pi**2
116+
* k1
117+
* k2
118+
* (
119+
3 * a11**3
120+
+ 3 * a11 * a12**2
121+
+ 3 * a11 * a21**2
122+
+ a11 * a22**2
123+
+ 4 * a11 * k1**2
124+
+ 4 * a11 * k2**2
125+
+ 2 * a12 * a21 * a22
126+
)
127+
/ factorial(5)
128+
)
129+
if n >= 6:
130+
total += (
131+
5
132+
/ 64
133+
* np.pi**2
134+
* (
135+
5 * a11**6
136+
+ 15 * a11**4 * a12**2
137+
+ 15 * a11**4 * a21**2
138+
+ 3 * a11**4 * a22**2
139+
+ 90 * a11**4 * k1**2
140+
+ 90 * a11**4 * k2**2
141+
+ 24 * a11**3 * a12 * a21 * a22
142+
+ 15 * a11**2 * a12**4
143+
+ 18 * a11**2 * a12**2 * a21**2
144+
+ 18 * a11**2 * a12**2 * a22**2
145+
+ 180 * a11**2 * a12**2 * k1**2
146+
+ 108 * a11**2 * a12**2 * k2**2
147+
+ 15 * a11**2 * a21**4
148+
+ 18 * a11**2 * a21**2 * a22**2
149+
+ 108 * a11**2 * a21**2 * k1**2
150+
+ 180 * a11**2 * a21**2 * k2**2
151+
+ 3 * a11**2 * a22**4
152+
+ 36 * a11**2 * a22**2 * k1**2
153+
+ 36 * a11**2 * a22**2 * k2**2
154+
+ 120 * a11**2 * k1**4
155+
+ 648 * a11**2 * k1**2 * k2**2
156+
+ 120 * a11**2 * k2**4
157+
+ 24 * a11 * a12**3 * a21 * a22
158+
+ 24 * a11 * a12 * a21**3 * a22
159+
+ 24 * a11 * a12 * a21 * a22**3
160+
+ 144 * a11 * a12 * a21 * a22 * k1**2
161+
+ 144 * a11 * a12 * a21 * a22 * k2**2
162+
+ 5 * a12**6
163+
+ 3 * a12**4 * a21**2
164+
+ 15 * a12**4 * a22**2
165+
+ 90 * a12**4 * k1**2
166+
+ 18 * a12**4 * k2**2
167+
+ 3 * a12**2 * a21**4
168+
+ 18 * a12**2 * a21**2 * a22**2
169+
+ 36 * a12**2 * a21**2 * k1**2
170+
+ 36 * a12**2 * a21**2 * k2**2
171+
+ 15 * a12**2 * a22**4
172+
+ 108 * a12**2 * a22**2 * k1**2
173+
+ 36 * a12**2 * a22**2 * k2**2
174+
+ 120 * a12**2 * k1**4
175+
+ 216 * a12**2 * k1**2 * k2**2
176+
+ 24 * a12**2 * k2**4
177+
+ 5 * a21**6
178+
+ 15 * a21**4 * a22**2
179+
+ 18 * a21**4 * k1**2
180+
+ 90 * a21**4 * k2**2
181+
+ 15 * a21**2 * a22**4
182+
+ 36 * a21**2 * a22**2 * k1**2
183+
+ 108 * a21**2 * a22**2 * k2**2
184+
+ 24 * a21**2 * k1**4
185+
+ 216 * a21**2 * k1**2 * k2**2
186+
+ 120 * a21**2 * k2**4
187+
+ 5 * a22**6
188+
+ 18 * a22**4 * k1**2
189+
+ 18 * a22**4 * k2**2
190+
+ 24 * a22**2 * k1**4
191+
+ 72 * a22**2 * k1**2 * k2**2
192+
+ 24 * a22**2 * k2**4
193+
+ 16 * k1**6
194+
+ 144 * k1**4 * k2**2
195+
+ 144 * k1**2 * k2**4
196+
+ 16 * k2**6
197+
)
198+
/ factorial(6)
199+
)
200+
if n >= 7:
201+
total += (
202+
105
203+
/ 32
204+
* k1
205+
* k2
206+
* np.pi**2
207+
* (
208+
5 * a11**5
209+
+ 10 * a11**3 * a12**2
210+
+ 10 * a11**3 * a21**2
211+
+ 2 * a11**3 * a22**2
212+
+ 20 * a11**3 * k1**2
213+
+ 20 * a11**3 * k2**2
214+
+ 12 * a11**2 * a12 * a21 * a22
215+
+ 5 * a11 * a12**4
216+
+ 6 * a11 * a12**2 * a21**2
217+
+ 6 * a11 * a12**2 * a22**2
218+
+ 20 * a11 * a12**2 * k1**2
219+
+ 12 * a11 * a12**2 * k2**2
220+
+ 5 * a11 * a21**4
221+
+ 6 * a11 * a21**2 * a22**2
222+
+ 12 * a11 * a21**2 * k1**2
223+
+ 20 * a11 * a21**2 * k2**2
224+
+ a11 * a22**4
225+
+ 4 * a11 * a22**2 * k1**2
226+
+ 4 * a11 * a22**2 * k2**2
227+
+ 8 * a11 * k1**4
228+
+ 24 * a11 * k1**2 * k2**2
229+
+ 8 * a11 * k2**4
230+
+ 4 * a12**3 * a21 * a22
231+
+ 4 * a12 * a21**3 * a22
232+
+ 4 * a12 * a21 * a22**3
233+
+ 8 * a12 * a21 * a22 * k1**2
234+
+ 8 * a12 * a21 * a22 * k2**2
235+
)
236+
/ factorial(7)
237+
)
238+
return 1.0 / total
239+
240+
def multiply(self, other):
241+
"""Multiply two ToroidalVMMatrixDistributions (exact product)."""
242+
assert isinstance(other, ToroidalVMMatrixDistribution)
243+
244+
C1 = float(self.kappa[0]) * np.cos(float(self.mu[0])) + float(other.kappa[0]) * np.cos(float(other.mu[0]))
245+
S1 = float(self.kappa[0]) * np.sin(float(self.mu[0])) + float(other.kappa[0]) * np.sin(float(other.mu[0]))
246+
C2 = float(self.kappa[1]) * np.cos(float(self.mu[1])) + float(other.kappa[1]) * np.cos(float(other.mu[1]))
247+
S2 = float(self.kappa[1]) * np.sin(float(self.mu[1])) + float(other.kappa[1]) * np.sin(float(other.mu[1]))
248+
249+
mu_new = array([np.arctan2(S1, C1) % (2 * np.pi), np.arctan2(S2, C2) % (2 * np.pi)])
250+
kappa_new = array([np.sqrt(C1**2 + S1**2), np.sqrt(C2**2 + S2**2)])
251+
252+
def _M(mu):
253+
c1, s1 = np.cos(float(mu[0])), np.sin(float(mu[0]))
254+
c2, s2 = np.cos(float(mu[1])), np.sin(float(mu[1]))
255+
return np.array([
256+
[ c1 * c2, -s1 * c2, -c1 * s2, s1 * s2],
257+
[ s1 * c2, c1 * c2, -s1 * s2, -c1 * s2],
258+
[ c1 * s2, -s1 * s2, c1 * c2, -s1 * c2],
259+
[ s1 * s2, c1 * s2, s1 * c2, c1 * c2],
260+
])
261+
262+
# Pack A columns as [A11; A21; A12; A22]
263+
A1 = np.array([[float(self.A[0, 0])], [float(self.A[1, 0])], [float(self.A[0, 1])], [float(self.A[1, 1])]])
264+
A2 = np.array([[float(other.A[0, 0])], [float(other.A[1, 0])], [float(other.A[0, 1])], [float(other.A[1, 1])]])
265+
b = _M(self.mu) @ A1 + _M(other.mu) @ A2
266+
a = np.linalg.solve(_M(mu_new), b).ravel()
267+
A_new = array([[a[0], a[2]], [a[1], a[3]]])
268+
269+
return ToroidalVMMatrixDistribution(mu_new, kappa_new, A_new)
270+
271+
def marginalize_to_1d(self, dimension):
272+
"""Get marginal distribution in the given dimension (0 or 1, 0-indexed).
273+
274+
Integrates out the *other* dimension analytically using the Bessel
275+
function identity for the von-Mises-type integral.
276+
"""
277+
assert dimension in (0, 1)
278+
other = 1 - dimension
279+
280+
mu_d = float(self.mu[dimension])
281+
mu_o = float(self.mu[other]) # noqa: F841 – retained for clarity
282+
k_d = float(self.kappa[dimension])
283+
k_o = float(self.kappa[other])
284+
a11 = float(self.A[0, 0])
285+
a12 = float(self.A[0, 1])
286+
a21 = float(self.A[1, 0])
287+
a22 = float(self.A[1, 1])
288+
C = float(self.C)
289+
290+
if dimension == 0:
291+
# Integrate over x2; x = x1
292+
def f(x):
293+
dx = x - mu_d
294+
alpha = k_o + np.cos(dx) * a11 + np.sin(dx) * a21
295+
beta = np.cos(dx) * a12 + np.sin(dx) * a22
296+
return 2 * np.pi * C * iv(0, np.sqrt(alpha**2 + beta**2)) * np.exp(k_d * np.cos(dx))
297+
else:
298+
# Integrate over x1; x = x2
299+
def f(x):
300+
dx = x - mu_d
301+
alpha = k_o + np.cos(dx) * a11 + np.sin(dx) * a12
302+
beta = np.cos(dx) * a21 + np.sin(dx) * a22
303+
return 2 * np.pi * C * iv(0, np.sqrt(alpha**2 + beta**2)) * np.exp(k_d * np.cos(dx))
304+
305+
return CustomCircularDistribution(f)
306+
307+
def shift(self, shift_angles):
308+
"""Return a copy of this distribution shifted by shift_angles."""
309+
assert shift_angles.shape == (2,)
310+
result = copy.copy(self)
311+
result.mu = mod(self.mu + shift_angles, 2.0 * pi)
312+
return result

0 commit comments

Comments
 (0)