Skip to content

Commit 94a7e8d

Browse files
committed
Added (Hyperhemi)sphericalGridDistribution
1 parent 76f1d88 commit 94a7e8d

12 files changed

Lines changed: 1663 additions & 26 deletions

pyrecest/_backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def get_backend_name():
198198
"vmap",
199199
"gammaln",
200200
"round",
201+
"array_equal",
201202
# For Riemannian score-based SDE
202203
"log1p"
203204
],

pyrecest/_backend/jax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
vstack,
7474
where,
7575
zeros_like,
76+
# For pyrecest
7677
diag,
7778
diff,
7879
apply_along_axis,
@@ -139,6 +140,7 @@
139140
linspace,
140141
ones,
141142
round,
143+
array_equal,
142144
# For Riemannian score-based SDE
143145
log1p,
144146
)

pyrecest/_backend/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
roll,
9999
dstack,
100100
round,
101+
array_equal,
101102
# For Riemannian score-based SDE
102103
log1p,
103104
)

pyrecest/_backend/pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272
# For Riemannian score-based SDE
7373
log1p,
7474
)
75+
from torch import equal as array_equal # For PyRecEst
76+
7577
from torch import broadcast_tensors as broadcast_arrays
7678
from torch import repeat_interleave as repeat
7779
from torch.special import gammaln as _gammaln

pyrecest/distributions/abstract_grid_distribution.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import warnings
33
from abc import abstractmethod
44

5-
import numpy as np
65
from beartype import beartype
76

87
from .abstract_distribution_type import AbstractDistributionType
98

9+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
10+
from pyrecest.backend import mean, abs, any
1011

1112
class AbstractGridDistribution(AbstractDistributionType):
1213
# pylint: disable=too-many-positional-arguments
@@ -23,10 +24,10 @@ def __init__(
2324
not grid_type == "custom" or grid is not None
2425
) # if grid_type is custom, grid needs to be given
2526
assert (
26-
grid is None or np.size(grid) == 0 or grid.shape[0] == grid_values.shape[0]
27+
grid is None or grid.shape == () or grid.shape[0] == grid_values.shape[0]
2728
)
2829
assert (
29-
grid is None or np.size(grid) == 0 or grid.ndim == 1 or grid.shape[1] == dim
30+
grid is None or grid.shape == () or grid.ndim == 1 or grid.shape[1] == dim
3031
)
3132
if grid is None or grid.ndim > 1 and grid.shape[0] < grid.shape[1]:
3233
warnings.warn(
@@ -37,7 +38,7 @@ def __init__(
3738
self.grid = grid
3839
self.enforce_pdf_nonnegative = enforce_pdf_nonnegative
3940
# Overwrite with more descriptive parameterization
40-
self.grid_density_description = {"n_grid_values": np.size(grid_values)}
41+
self.grid_density_description = {"n_grid_values": grid_values.shape[0], "grid_type": grid_type}
4142

4243
def pdf(self, xs):
4344
# Use nearest neighbor interpolation by default
@@ -47,7 +48,7 @@ def pdf(self, xs):
4748
@property
4849
def n_grid_points(self):
4950
# Overwrite if grid_values contains values that are not used as grid values
50-
return np.size(self.grid_values)
51+
return self.grid_values.shape[0]
5152

5253
@abstractmethod
5354
def get_closest_point(self, xs):
@@ -61,19 +62,19 @@ def integrate(self, integration_boundaries=None):
6162
assert (
6263
integration_boundaries is None
6364
), "Custom integration boundaries are currently not supported"
64-
return self.get_manifold_size() * np.mean(self.grid_values)
65+
return self.get_manifold_size() * mean(self.grid_values)
6566

6667
def normalize_in_place(self, tol=1e-4, warn_unnorm=True):
6768
int_val = self.integrate()
68-
if np.any(self.grid_values < 0):
69+
if any(self.grid_values < 0):
6970
warnings.warn(
7071
"Warning: There are negative values. This usually points to a user error."
7172
)
72-
elif np.abs(int_val) < 1e-200:
73+
elif abs(int_val) < 1e-200:
7374
raise ValueError(
7475
"Sum of grid values is too close to zero, this usually points to a user error."
7576
)
76-
elif np.abs(int_val - 1) > tol:
77+
elif abs(int_val - 1) > tol:
7778
if warn_unnorm:
7879
warnings.warn(
7980
"Warning: Grid values apparently do not belong to a normalized density. Normalizing..."
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import warnings
2+
3+
from ..abstract_grid_distribution import AbstractGridDistribution
4+
from .abstract_hypersphere_subset_distribution import AbstractHypersphereSubsetDistribution
5+
from .abstract_hyperhemispherical_distribution import AbstractHyperhemisphericalDistribution
6+
from .abstract_hyperspherical_distribution import AbstractHypersphericalDistribution
7+
from .von_mises_fisher_distribution import VonMisesFisherDistribution
8+
from .bingham_distribution import BinghamDistribution
9+
from .hyperspherical_mixture import HypersphericalMixture
10+
from .watson_distribution import WatsonDistribution
11+
from beartype import beartype
12+
13+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
14+
from pyrecest.backend import array_equal, argmax, sum
15+
16+
class AbstractHypersphereSubsetGridDistribution(AbstractGridDistribution, AbstractHypersphereSubsetDistribution):
17+
18+
def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True):
19+
# Check size consistency
20+
if grid.shape[0] != grid_values.shape[0]:
21+
raise ValueError("Grid size must match number of grid values.")
22+
23+
AbstractGridDistribution.__init__(self, grid_values, grid_type = "unknown", grid=grid, dim=grid.shape[1], enforce_pdf_nonnegative=enforce_pdf_nonnegative)
24+
AbstractHypersphereSubsetDistribution.__init__(self, dim=grid.shape[1])
25+
self.normalize()
26+
27+
def mean_direction(self):
28+
warnings.warn("For hyperhemispheres, this function yields the mode and not the mean.", UserWarning)
29+
# If we took the mean, it would be biased toward [0;...;0;1]
30+
# because the lower half is considered inexistant.
31+
index_max = argmax(self.grid_values)
32+
mu = self.get_grid_point(index_max)
33+
return mu
34+
35+
def moment(self):
36+
weights = self.grid_values / sum(self.grid_values) # (N,)
37+
38+
weighted_grid = self.get_grid() * weights
39+
40+
C = weighted_grid * (self.get_grid().T @ self.get_grid())
41+
return C
42+
43+
@beartype
44+
def multiply(self: "AbstractHypersphereSubsetGridDistribution", other: "AbstractHypersphereSubsetGridDistribution") -> "AbstractHypersphereSubsetGridDistribution":
45+
# Check for grid compatibility
46+
if not array_equal(self.get_grid(), other.get_grid()):
47+
raise ValueError("Can only multiply for equal grids. Grids are incompatible.")
48+
49+
# Delegates multiplication logic to AbstractGridDistribution
50+
return super().multiply(other)
51+
52+
@staticmethod
53+
def from_distribution(distribution, no_of_grid_points, grid_type='healpix'):
54+
# Import here to avoid circular imports
55+
from .hyperhemispherical_grid_distribution import HyperhemisphericalGridDistribution
56+
from .hyperspherical_grid_distribution import HypersphericalGridDistribution
57+
# pylint: disable=too-many-boolean-expressions
58+
if isinstance(distribution, AbstractHyperhemisphericalDistribution):
59+
fun = distribution.pdf
60+
elif (isinstance(distribution, (WatsonDistribution, BinghamDistribution)) or
61+
(isinstance(distribution, VonMisesFisherDistribution) and distribution.mu[-1] == 0) or
62+
(isinstance(distribution, HypersphericalMixture) and
63+
len(distribution.dists) == 2 and all(w == 0.5 for w in distribution.w) and
64+
array_equal(distribution.dists[1].mu, -distribution.dists[0].mu))):
65+
def fun(x):
66+
return 2 * distribution.pdf(x)
67+
elif isinstance(distribution, HypersphericalGridDistribution):
68+
raise ValueError('Converting a HypersphericalGridDistribution to a HyperhemisphericalGridDistribution is not supported')
69+
elif isinstance(distribution, AbstractHypersphericalDistribution):
70+
warnings.warn('Approximating a hyperspherical distribution on a hemisphere. The density may not be symmetric. Double check if this is intentional.',
71+
UserWarning)
72+
def fun(x):
73+
return 2 * distribution.pdf(x)
74+
else:
75+
raise ValueError('Distribution currently not supported.')
76+
77+
sgd = HyperhemisphericalGridDistribution.from_function(fun, no_of_grid_points, distribution.dim, grid_type)
78+
return sgd
79+

pyrecest/distributions/hypersphere_subset/bingham_distribution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# pylint: disable=redefined-builtin,no-name-in-module,no-member
44
# pylint: disable=no-name-in-module,no-member
5-
from pyrecest.backend import abs, all, argsort, diag, exp, eye, linalg, max, sum, zeros
5+
from pyrecest.backend import abs, all, argsort, diag, exp, eye, linalg, max, sum, zeros, array
66
from scipy.integrate import quad
77
from scipy.special import iv
88

@@ -20,7 +20,7 @@ def __init__(self, Z, M):
2020
assert all(Z[:-1] <= Z[1:]), "Values in Z have to be ascending"
2121

2222
# Verify that M is orthogonal
23-
epsilon = 0.001
23+
epsilon = array(0.001)
2424
assert max(abs(M @ M.T - eye(self.dim + 1))) < epsilon, "M is not orthogonal"
2525

2626
self.Z = Z

0 commit comments

Comments
 (0)