Skip to content

Commit bca57a4

Browse files
committed
Added (Hyper)(hemi)sphericalGridDistribution
1 parent ad6def2 commit bca57a4

10 files changed

Lines changed: 1569 additions & 0 deletions

pyrecest/_backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def get_backend_name():
198198
"vmap",
199199
"gammaln",
200200
"round",
201+
"array_equal",
201202
# For Riemannian score-based SDE
202203
"log1p"
203204
],

pyrecest/_backend/jax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
vstack,
7474
where,
7575
zeros_like,
76+
# For pyrecest
7677
diag,
7778
diff,
7879
apply_along_axis,
@@ -139,6 +140,7 @@
139140
linspace,
140141
ones,
141142
round,
143+
array_equal,
142144
# For Riemannian score-based SDE
143145
log1p,
144146
)

pyrecest/_backend/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
roll,
9999
dstack,
100100
round,
101+
array_equal,
101102
# For Riemannian score-based SDE
102103
log1p,
103104
)

pyrecest/_backend/pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272
# For Riemannian score-based SDE
7373
log1p,
7474
)
75+
from torch import equal as array_equal # For PyRecEst
76+
7577
from torch import broadcast_tensors as broadcast_arrays
7678
from torch import repeat_interleave as repeat
7779
from torch.special import gammaln as _gammaln
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import warnings
2+
3+
from ..abstract_grid_distribution import AbstractGridDistribution
4+
from .abstract_hypersphere_subset_distribution import AbstractHypersphereSubsetDistribution
5+
from .abstract_hyperhemispherical_distribution import AbstractHyperhemisphericalDistribution
6+
from .abstract_hyperspherical_distribution import AbstractHypersphericalDistribution
7+
from .von_mises_fisher_distribution import VonMisesFisherDistribution
8+
from .bingham_distribution import BinghamDistribution
9+
from .hyperspherical_mixture import HypersphericalMixture
10+
from .watson_distribution import WatsonDistribution
11+
from beartype import beartype
12+
13+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
14+
from pyrecest.backend import array_equal, argmax, sum
15+
16+
class AbstractHypersphereSubsetGridDistribution(AbstractGridDistribution, AbstractHypersphereSubsetDistribution):
17+
18+
def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True):
19+
# Check size consistency
20+
if grid.shape[0] != grid_values.shape[0]:
21+
raise ValueError("Grid size must match number of grid values.")
22+
23+
AbstractGridDistribution.__init__(self, grid_values, grid_type = "unknown", grid=grid, dim=grid.shape[1], enforce_pdf_nonnegative=enforce_pdf_nonnegative)
24+
AbstractHypersphereSubsetDistribution.__init__(self, dim=grid.shape[1])
25+
self.normalize()
26+
27+
def mean_direction(self):
28+
warnings.warn("For hyperhemispheres, this function yields the mode and not the mean.", UserWarning)
29+
# If we took the mean, it would be biased toward [0;...;0;1]
30+
# because the lower half is considered inexistant.
31+
index_max = argmax(self.grid_values)
32+
mu = self.get_grid_point(index_max)
33+
return mu
34+
35+
def moment(self):
36+
weights = self.grid_values / sum(self.grid_values) # (N,)
37+
38+
weighted_grid = self.get_grid() * weights
39+
40+
C = weighted_grid * (self.get_grid().T @ self.get_grid())
41+
return C
42+
43+
@beartype
44+
def multiply(self: "AbstractHypersphereSubsetGridDistribution", other: "AbstractHypersphereSubsetGridDistribution") -> "AbstractHypersphereSubsetGridDistribution":
45+
# Check for grid compatibility
46+
if not array_equal(self.get_grid(), other.get_grid()):
47+
raise ValueError("Can only multiply for equal grids. Grids are incompatible.")
48+
49+
# Delegates multiplication logic to AbstractGridDistribution
50+
return super().multiply(other)
51+
52+
@staticmethod
53+
def from_distribution(distribution, no_of_grid_points, grid_type='healpix'):
54+
# Import here to avoid circular imports
55+
from .hyperhemispherical_grid_distribution import HyperhemisphericalGridDistribution
56+
from .hyperspherical_grid_distribution import HypersphericalGridDistribution
57+
# pylint: disable=too-many-boolean-expressions
58+
if isinstance(distribution, AbstractHyperhemisphericalDistribution):
59+
fun = distribution.pdf
60+
elif (isinstance(distribution, (WatsonDistribution, BinghamDistribution)) or
61+
(isinstance(distribution, VonMisesFisherDistribution) and distribution.mu[-1] == 0) or
62+
(isinstance(distribution, HypersphericalMixture) and
63+
len(distribution.dists) == 2 and all(w == 0.5 for w in distribution.w) and
64+
array_equal(distribution.dists[1].mu, -distribution.dists[0].mu))):
65+
def fun(x):
66+
return 2 * distribution.pdf(x)
67+
elif isinstance(distribution, HypersphericalGridDistribution):
68+
raise ValueError('Converting a HypersphericalGridDistribution to a HyperhemisphericalGridDistribution is not supported')
69+
elif isinstance(distribution, AbstractHypersphericalDistribution):
70+
warnings.warn('Approximating a hyperspherical distribution on a hemisphere. The density may not be symmetric. Double check if this is intentional.',
71+
UserWarning)
72+
def fun(x):
73+
return 2 * distribution.pdf(x)
74+
else:
75+
raise ValueError('Distribution currently not supported.')
76+
77+
sgd = HyperhemisphericalGridDistribution.from_function(fun, no_of_grid_points, distribution.dim, grid_type)
78+
return sgd
79+

0 commit comments

Comments
 (0)