Skip to content

Commit d80b3b9

Browse files
committed
Added mean axis for dirac on hypersphere
1 parent 175d2c1 commit d80b3b9

3 files changed

Lines changed: 59 additions & 3 deletions

File tree

pyrecest/distributions/hypersphere_subset/abstract_hypersphere_subset_dirac_distribution.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# pylint: disable=redefined-builtin,no-name-in-module,no-member
22
# pylint: disable=no-name-in-module,no-member
3-
from pyrecest.backend import log, outer, sum, zeros
3+
from pyrecest.backend import log, outer, sum, zeros, reshape, linalg
44

55
from ..abstract_dirac_distribution import AbstractDiracDistribution
66
from .abstract_hypersphere_subset_distribution import (
@@ -31,3 +31,36 @@ def entropy(self):
3131

3232
def integrate(self, integration_boundaries=None):
3333
raise NotImplementedError()
34+
35+
def mean_axis(self):
36+
"""
37+
Returns the principal axis of the Dirac mixture on the hypersphere.
38+
Because ±v represent the same axis, the sign of the returned vector
39+
is arbitrary.
40+
"""
41+
# Column vector of weights, shape (n, 1)
42+
w_col = reshape(self.w, (-1, 1)) # or self.w[:, None]
43+
44+
# Weighted second-moment matrix: S = Σ w_i d_i d_i^T
45+
# d has shape (n, D), so (d * w_col) is (n, D), then transpose @ d -> (D, D)
46+
S = (self.d * w_col).T @ self.d
47+
48+
# Normalize in case weights don't sum to 1
49+
S = S / sum(self.w)
50+
51+
# Enforce symmetry (numerical safety)
52+
S = 0.5 * (S + S.T)
53+
54+
# Eigen-decomposition of symmetric S
55+
D, V = linalg.eig(S)
56+
57+
# Index of largest eigenvalue
58+
# If you don't have argmax in the backend, use argsort instead
59+
idx = D.argmax() # or idx = argsort(D)[-1]
60+
61+
axis = V[:, idx]
62+
63+
# Normalize to unit length (should already be, but just in case)
64+
axis = axis / linalg.norm(axis)
65+
66+
return axis

pyrecest/distributions/hypersphere_subset/hyperhemispherical_dirac_distribution.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,8 @@
99
class HyperhemisphericalDiracDistribution(
1010
AbstractHypersphereSubsetDiracDistribution, AbstractHyperhemisphericalDistribution
1111
):
12-
pass
12+
def mean_axis(self):
13+
axis = super().mean_axis()
14+
if axis[-1] < 0:
15+
axis = -axis
16+
return axis

pyrecest/tests/distributions/test_hyperspherical_dirac_distribution.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
# pylint: disable=redefined-builtin,no-name-in-module,no-member
1010
# pylint: disable=no-name-in-module,no-member
11-
from pyrecest.backend import array, linalg, mod, ones, random, sqrt, sum
11+
from pyrecest.backend import array, linalg, mod, ones, random, sqrt, sum, allclose
1212
from pyrecest.distributions import VonMisesFisherDistribution
1313
from pyrecest.distributions.hypersphere_subset.hyperspherical_dirac_distribution import (
1414
HypersphericalDiracDistribution,
@@ -79,6 +79,25 @@ def test_from_distribution(self):
7979
dirac_dist.mean_direction(), vmf.mean_direction(), decimal=2
8080
)
8181

82+
def test_mean_axis_symmetric_two_point_distribution(self):
83+
# Two antipodal points on S^2: ±e_x
84+
d = array([
85+
[1.0, 0.0, 0.0],
86+
[-1.0, 0.0, 0.0],
87+
])
88+
w = array([0.5, 0.5])
89+
90+
dist = HypersphericalDiracDistribution(d, w)
91+
92+
axis = dist.mean_axis()
93+
94+
# 1) axis should be unit length
95+
assert allclose(linalg.norm(axis), 1.0, atol=1e-7)
96+
97+
# 2) axis should be parallel to (1, 0, 0), i.e. |dot(axis, e_x)| ≈ 1
98+
v = array([1.0, 0.0, 0.0])
99+
dot = float(axis @ v)
100+
assert abs(dot) > 1.0 - 1e-6
82101

83102
if __name__ == "__main__":
84103
unittest.main()

0 commit comments

Comments
 (0)