|
| 1 | +import warnings |
| 2 | + |
| 3 | +from ..abstract_grid_distribution import AbstractGridDistribution |
| 4 | +from .abstract_hypersphere_subset_distribution import AbstractHypersphereSubsetDistribution |
| 5 | +from .abstract_hyperhemispherical_distribution import AbstractHyperhemisphericalDistribution |
| 6 | +from .abstract_hyperspherical_distribution import AbstractHypersphericalDistribution |
| 7 | +from .von_mises_fisher_distribution import VonMisesFisherDistribution |
| 8 | +from .bingham_distribution import BinghamDistribution |
| 9 | +from .hyperspherical_mixture import HypersphericalMixture |
| 10 | +from .watson_distribution import WatsonDistribution |
| 11 | +from beartype import beartype |
| 12 | + |
| 13 | +# pylint: disable=redefined-builtin,no-name-in-module,no-member |
| 14 | +from pyrecest.backend import array_equal, argmax, sum |
| 15 | + |
| 16 | +class AbstractHypersphereSubsetGridDistribution(AbstractGridDistribution, AbstractHypersphereSubsetDistribution): |
| 17 | + |
| 18 | + def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True): |
| 19 | + # Check size consistency |
| 20 | + if grid.shape[0] != grid_values.shape[0]: |
| 21 | + raise ValueError("Grid size must match number of grid values.") |
| 22 | + |
| 23 | + AbstractGridDistribution.__init__(self, grid_values, grid_type = "unknown", grid=grid, dim=grid.shape[1], enforce_pdf_nonnegative=enforce_pdf_nonnegative) |
| 24 | + AbstractHypersphereSubsetDistribution.__init__(self, dim=grid.shape[1]) |
| 25 | + self.normalize() |
| 26 | + |
| 27 | + def mean_direction(self): |
| 28 | + warnings.warn("For hyperhemispheres, this function yields the mode and not the mean.", UserWarning) |
| 29 | + # If we took the mean, it would be biased toward [0;...;0;1] |
| 30 | + # because the lower half is considered inexistant. |
| 31 | + index_max = argmax(self.grid_values) |
| 32 | + mu = self.get_grid_point(index_max) |
| 33 | + return mu |
| 34 | + |
| 35 | + def moment(self): |
| 36 | + weights = self.grid_values / sum(self.grid_values) # (N,) |
| 37 | + |
| 38 | + weighted_grid = self.get_grid() * weights |
| 39 | + |
| 40 | + C = weighted_grid * (self.get_grid().T @ self.get_grid()) |
| 41 | + return C |
| 42 | + |
| 43 | + @beartype |
| 44 | + def multiply(self: "AbstractHypersphereSubsetGridDistribution", other: "AbstractHypersphereSubsetGridDistribution") -> "AbstractHypersphereSubsetGridDistribution": |
| 45 | + # Check for grid compatibility |
| 46 | + if not array_equal(self.get_grid(), other.get_grid()): |
| 47 | + raise ValueError("Can only multiply for equal grids. Grids are incompatible.") |
| 48 | + |
| 49 | + # Delegates multiplication logic to AbstractGridDistribution |
| 50 | + return super().multiply(other) |
| 51 | + |
| 52 | + @staticmethod |
| 53 | + def from_distribution(distribution, no_of_grid_points, grid_type='healpix'): |
| 54 | + # Import here to avoid circular imports |
| 55 | + from .hyperhemispherical_grid_distribution import HyperhemisphericalGridDistribution |
| 56 | + from .hyperspherical_grid_distribution import HypersphericalGridDistribution |
| 57 | + # pylint: disable=too-many-boolean-expressions |
| 58 | + if isinstance(distribution, AbstractHyperhemisphericalDistribution): |
| 59 | + fun = distribution.pdf |
| 60 | + elif (isinstance(distribution, (WatsonDistribution, BinghamDistribution)) or |
| 61 | + (isinstance(distribution, VonMisesFisherDistribution) and distribution.mu[-1] == 0) or |
| 62 | + (isinstance(distribution, HypersphericalMixture) and |
| 63 | + len(distribution.dists) == 2 and all(w == 0.5 for w in distribution.w) and |
| 64 | + array_equal(distribution.dists[1].mu, -distribution.dists[0].mu))): |
| 65 | + def fun(x): |
| 66 | + return 2 * distribution.pdf(x) |
| 67 | + elif isinstance(distribution, HypersphericalGridDistribution): |
| 68 | + raise ValueError('Converting a HypersphericalGridDistribution to a HyperhemisphericalGridDistribution is not supported') |
| 69 | + elif isinstance(distribution, AbstractHypersphericalDistribution): |
| 70 | + warnings.warn('Approximating a hyperspherical distribution on a hemisphere. The density may not be symmetric. Double check if this is intentional.', |
| 71 | + UserWarning) |
| 72 | + def fun(x): |
| 73 | + return 2 * distribution.pdf(x) |
| 74 | + else: |
| 75 | + raise ValueError('Distribution currently not supported.') |
| 76 | + |
| 77 | + sgd = HyperhemisphericalGridDistribution.from_function(fun, no_of_grid_points, distribution.dim, grid_type) |
| 78 | + return sgd |
| 79 | + |
0 commit comments