|
1 | 1 | import warnings |
2 | 2 |
|
3 | | -from ..abstract_grid_distribution import AbstractGridDistribution |
4 | | -from .abstract_hypersphere_subset_distribution import AbstractHypersphereSubsetDistribution |
5 | | -from .abstract_hyperhemispherical_distribution import AbstractHyperhemisphericalDistribution |
| 3 | +from beartype import beartype |
| 4 | + |
| 5 | +# pylint: disable=redefined-builtin,no-name-in-module,no-member |
| 6 | +from pyrecest.backend import argmax, array_equal, sum |
| 7 | + |
| 8 | +from ..abstract_grid_distribution import AbstractGridDistribution |
| 9 | +from .abstract_hyperhemispherical_distribution import ( |
| 10 | + AbstractHyperhemisphericalDistribution, |
| 11 | +) |
| 12 | +from .abstract_hypersphere_subset_distribution import ( |
| 13 | + AbstractHypersphereSubsetDistribution, |
| 14 | +) |
6 | 15 | from .abstract_hyperspherical_distribution import AbstractHypersphericalDistribution |
7 | | -from .von_mises_fisher_distribution import VonMisesFisherDistribution |
8 | 16 | from .bingham_distribution import BinghamDistribution |
9 | 17 | from .hyperspherical_mixture import HypersphericalMixture |
| 18 | +from .von_mises_fisher_distribution import VonMisesFisherDistribution |
10 | 19 | from .watson_distribution import WatsonDistribution |
11 | | -from beartype import beartype |
12 | 20 |
|
13 | | -# pylint: disable=redefined-builtin,no-name-in-module,no-member |
14 | | -from pyrecest.backend import array_equal, argmax, sum |
15 | 21 |
|
16 | | -class AbstractHypersphereSubsetGridDistribution(AbstractGridDistribution, AbstractHypersphereSubsetDistribution): |
17 | | - |
| 22 | +class AbstractHypersphereSubsetGridDistribution( |
| 23 | + AbstractGridDistribution, AbstractHypersphereSubsetDistribution |
| 24 | +): |
| 25 | + |
18 | 26 | def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True): |
19 | 27 | # Check size consistency |
20 | 28 | if grid.shape[0] != grid_values.shape[0]: |
21 | 29 | raise ValueError("Grid size must match number of grid values.") |
22 | | - |
23 | | - AbstractGridDistribution.__init__(self, grid_values, grid_type = "unknown", grid=grid, dim=grid.shape[1], enforce_pdf_nonnegative=enforce_pdf_nonnegative) |
| 30 | + |
| 31 | + AbstractGridDistribution.__init__( |
| 32 | + self, |
| 33 | + grid_values, |
| 34 | + grid_type="unknown", |
| 35 | + grid=grid, |
| 36 | + dim=grid.shape[1], |
| 37 | + enforce_pdf_nonnegative=enforce_pdf_nonnegative, |
| 38 | + ) |
24 | 39 | AbstractHypersphereSubsetDistribution.__init__(self, dim=grid.shape[1]) |
25 | | - self.normalize() |
| 40 | + self.normalize() |
26 | 41 |
|
27 | 42 | def mean_direction(self): |
28 | | - warnings.warn("For hyperhemispheres, this function yields the mode and not the mean.", UserWarning) |
| 43 | + warnings.warn( |
| 44 | + "For hyperhemispheres, this function yields the mode and not the mean.", |
| 45 | + UserWarning, |
| 46 | + ) |
29 | 47 | # If we took the mean, it would be biased toward [0;...;0;1] |
30 | 48 | # because the lower half is considered inexistant. |
31 | 49 | index_max = argmax(self.grid_values) |
32 | 50 | mu = self.get_grid_point(index_max) |
33 | 51 | return mu |
34 | 52 |
|
35 | 53 | def moment(self): |
36 | | - weights = self.grid_values / sum(self.grid_values) # (N,) |
37 | | - |
38 | | - weighted_grid = self.get_grid() * weights |
| 54 | + weights = self.grid_values / sum(self.grid_values) # (N,) |
| 55 | + |
| 56 | + weighted_grid = self.get_grid() * weights |
39 | 57 |
|
40 | 58 | C = weighted_grid * (self.get_grid().T @ self.get_grid()) |
41 | 59 | return C |
42 | 60 |
|
43 | 61 | @beartype |
44 | | - def multiply(self: "AbstractHypersphereSubsetGridDistribution", other: "AbstractHypersphereSubsetGridDistribution") -> "AbstractHypersphereSubsetGridDistribution": |
| 62 | + def multiply( |
| 63 | + self: "AbstractHypersphereSubsetGridDistribution", |
| 64 | + other: "AbstractHypersphereSubsetGridDistribution", |
| 65 | + ) -> "AbstractHypersphereSubsetGridDistribution": |
45 | 66 | # Check for grid compatibility |
46 | 67 | if not array_equal(self.get_grid(), other.get_grid()): |
47 | | - raise ValueError("Can only multiply for equal grids. Grids are incompatible.") |
48 | | - |
| 68 | + raise ValueError( |
| 69 | + "Can only multiply for equal grids. Grids are incompatible." |
| 70 | + ) |
| 71 | + |
49 | 72 | # Delegates multiplication logic to AbstractGridDistribution |
50 | 73 | return super().multiply(other) |
51 | 74 |
|
52 | 75 | @classmethod |
53 | | - def from_distribution(cls, distribution, no_of_grid_points, grid_type, enforce_pdf_nonnegative=True): |
| 76 | + def from_distribution( |
| 77 | + cls, distribution, no_of_grid_points, grid_type, enforce_pdf_nonnegative=True |
| 78 | + ): |
| 79 | + from .hyperhemispherical_grid_distribution import ( |
| 80 | + HyperhemisphericalGridDistribution, |
| 81 | + ) |
54 | 82 | from .hyperspherical_grid_distribution import HypersphericalGridDistribution |
55 | | - from .hyperhemispherical_grid_distribution import HyperhemisphericalGridDistribution |
| 83 | + |
56 | 84 | # pylint: disable=too-many-boolean-expressions |
57 | 85 | if isinstance(distribution, AbstractHypersphereSubsetGridDistribution): |
58 | | - raise ValueError('Already a grid distribution. Use directly instead of converting.') |
59 | | - |
60 | | - if isinstance(distribution, AbstractHypersphericalDistribution) and issubclass(cls, HypersphericalGridDistribution)\ |
61 | | - or isinstance(distribution, AbstractHyperhemisphericalDistribution) and issubclass(cls, HyperhemisphericalGridDistribution): |
| 86 | + raise ValueError( |
| 87 | + "Already a grid distribution. Use directly instead of converting." |
| 88 | + ) |
| 89 | + |
| 90 | + if ( |
| 91 | + isinstance(distribution, AbstractHypersphericalDistribution) |
| 92 | + and issubclass(cls, HypersphericalGridDistribution) |
| 93 | + or isinstance(distribution, AbstractHyperhemisphericalDistribution) |
| 94 | + and issubclass(cls, HyperhemisphericalGridDistribution) |
| 95 | + ): |
62 | 96 | # sphere -> sphere or hemisphere -> hemisphere |
63 | 97 | fun = distribution.pdf |
64 | | - elif issubclass(cls, HyperhemisphericalGridDistribution) and (isinstance(distribution, (WatsonDistribution, BinghamDistribution)) or |
65 | | - (isinstance(distribution, VonMisesFisherDistribution) and distribution.mu[-1] == 0) or |
66 | | - (isinstance(distribution, HypersphericalMixture) and |
67 | | - len(distribution.dists) == 2 and all(w == 0.5 for w in distribution.w) and |
68 | | - array_equal(distribution.dists[1].mu, -distribution.dists[0].mu))): |
| 98 | + elif issubclass(cls, HyperhemisphericalGridDistribution) and ( |
| 99 | + isinstance(distribution, (WatsonDistribution, BinghamDistribution)) |
| 100 | + or ( |
| 101 | + isinstance(distribution, VonMisesFisherDistribution) |
| 102 | + and distribution.mu[-1] == 0 |
| 103 | + ) |
| 104 | + or ( |
| 105 | + isinstance(distribution, HypersphericalMixture) |
| 106 | + and len(distribution.dists) == 2 |
| 107 | + and all(w == 0.5 for w in distribution.w) |
| 108 | + and array_equal(distribution.dists[1].mu, -distribution.dists[0].mu) |
| 109 | + ) |
| 110 | + ): |
69 | 111 | # sphere -> hemisphere for symmetric distributions |
70 | 112 | def fun(x): |
71 | 113 | return 2 * distribution.pdf(x) |
72 | | - elif isinstance(distribution, AbstractHypersphericalDistribution) and issubclass(cls, HyperhemisphericalGridDistribution): |
| 114 | + |
| 115 | + elif isinstance( |
| 116 | + distribution, AbstractHypersphericalDistribution |
| 117 | + ) and issubclass(cls, HyperhemisphericalGridDistribution): |
73 | 118 | # sphere -> hemisphere for general distributions, which we do not know to be symmetric |
74 | | - warnings.warn('Approximating a hyperspherical distribution on a hemisphere. The density may not be symmetric. Double check if this is intentional.', |
75 | | - UserWarning) |
| 119 | + warnings.warn( |
| 120 | + "Approximating a hyperspherical distribution on a hemisphere. The density may not be symmetric. Double check if this is intentional.", |
| 121 | + UserWarning, |
| 122 | + ) |
| 123 | + |
76 | 124 | def fun(x): |
77 | 125 | return 2 * distribution.pdf(x) |
| 126 | + |
78 | 127 | else: |
79 | | - raise ValueError('Distribution currently not supported.') |
80 | | - |
81 | | - sgd = cls.from_function(fun, no_of_grid_points, distribution.dim, grid_type, enforce_pdf_nonnegative=enforce_pdf_nonnegative) |
82 | | - return sgd |
| 128 | + raise ValueError("Distribution currently not supported.") |
83 | 129 |
|
| 130 | + sgd = cls.from_function( |
| 131 | + fun, |
| 132 | + no_of_grid_points, |
| 133 | + distribution.dim, |
| 134 | + grid_type, |
| 135 | + enforce_pdf_nonnegative=enforce_pdf_nonnegative, |
| 136 | + ) |
| 137 | + return sgd |
0 commit comments