Skip to content

Commit 29243e7

Browse files
authored
Merge pull request #1418 from FlorianPfaff/megalinter-fixes
[MegaLinter] Apply linters automatic fixes
2 parents 4a27afb + a41e80e commit 29243e7

24 files changed

Lines changed: 373 additions & 217 deletions

pyrecest/distributions/abstract_grid_distribution.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import copy
22
import warnings
33
from abc import abstractmethod
4+
from math import prod
45

56
from beartype import beartype
67

78
# pylint: disable=redefined-builtin,no-name-in-module,no-member
89
from pyrecest.backend import abs, any, mean
9-
from math import prod
1010

1111
from .abstract_distribution_type import AbstractDistributionType
1212

@@ -27,7 +27,9 @@ def __init__(
2727
) # if grid_type is custom, grid needs to be given
2828
assert (
2929
# Use builtin prod because .shape is a tuple of ints
30-
grid is None or grid.shape == () or grid.shape[0] == prod(grid_values.shape)
30+
grid is None
31+
or grid.shape == ()
32+
or grid.shape[0] == prod(grid_values.shape)
3133
)
3234
assert (
3335
grid is None or grid.shape == () or grid.ndim == 1 or grid.shape[1] == dim
Lines changed: 91 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,137 @@
11
import warnings
22

3-
from ..abstract_grid_distribution import AbstractGridDistribution
4-
from .abstract_hypersphere_subset_distribution import AbstractHypersphereSubsetDistribution
5-
from .abstract_hyperhemispherical_distribution import AbstractHyperhemisphericalDistribution
3+
from beartype import beartype
4+
5+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
6+
from pyrecest.backend import argmax, array_equal, sum
7+
8+
from ..abstract_grid_distribution import AbstractGridDistribution
9+
from .abstract_hyperhemispherical_distribution import (
10+
AbstractHyperhemisphericalDistribution,
11+
)
12+
from .abstract_hypersphere_subset_distribution import (
13+
AbstractHypersphereSubsetDistribution,
14+
)
615
from .abstract_hyperspherical_distribution import AbstractHypersphericalDistribution
7-
from .von_mises_fisher_distribution import VonMisesFisherDistribution
816
from .bingham_distribution import BinghamDistribution
917
from .hyperspherical_mixture import HypersphericalMixture
18+
from .von_mises_fisher_distribution import VonMisesFisherDistribution
1019
from .watson_distribution import WatsonDistribution
11-
from beartype import beartype
1220

13-
# pylint: disable=redefined-builtin,no-name-in-module,no-member
14-
from pyrecest.backend import array_equal, argmax, sum
1521

16-
class AbstractHypersphereSubsetGridDistribution(AbstractGridDistribution, AbstractHypersphereSubsetDistribution):
17-
22+
class AbstractHypersphereSubsetGridDistribution(
23+
AbstractGridDistribution, AbstractHypersphereSubsetDistribution
24+
):
25+
1826
def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True):
1927
# Check size consistency
2028
if grid.shape[0] != grid_values.shape[0]:
2129
raise ValueError("Grid size must match number of grid values.")
22-
23-
AbstractGridDistribution.__init__(self, grid_values, grid_type = "unknown", grid=grid, dim=grid.shape[1], enforce_pdf_nonnegative=enforce_pdf_nonnegative)
30+
31+
AbstractGridDistribution.__init__(
32+
self,
33+
grid_values,
34+
grid_type="unknown",
35+
grid=grid,
36+
dim=grid.shape[1],
37+
enforce_pdf_nonnegative=enforce_pdf_nonnegative,
38+
)
2439
AbstractHypersphereSubsetDistribution.__init__(self, dim=grid.shape[1])
25-
self.normalize()
40+
self.normalize()
2641

2742
def mean_direction(self):
28-
warnings.warn("For hyperhemispheres, this function yields the mode and not the mean.", UserWarning)
43+
warnings.warn(
44+
"For hyperhemispheres, this function yields the mode and not the mean.",
45+
UserWarning,
46+
)
2947
# If we took the mean, it would be biased toward [0;...;0;1]
3048
# because the lower half is considered inexistant.
3149
index_max = argmax(self.grid_values)
3250
mu = self.get_grid_point(index_max)
3351
return mu
3452

3553
def moment(self):
36-
weights = self.grid_values / sum(self.grid_values) # (N,)
37-
38-
weighted_grid = self.get_grid() * weights
54+
weights = self.grid_values / sum(self.grid_values) # (N,)
55+
56+
weighted_grid = self.get_grid() * weights
3957

4058
C = weighted_grid * (self.get_grid().T @ self.get_grid())
4159
return C
4260

4361
@beartype
44-
def multiply(self: "AbstractHypersphereSubsetGridDistribution", other: "AbstractHypersphereSubsetGridDistribution") -> "AbstractHypersphereSubsetGridDistribution":
62+
def multiply(
63+
self: "AbstractHypersphereSubsetGridDistribution",
64+
other: "AbstractHypersphereSubsetGridDistribution",
65+
) -> "AbstractHypersphereSubsetGridDistribution":
4566
# Check for grid compatibility
4667
if not array_equal(self.get_grid(), other.get_grid()):
47-
raise ValueError("Can only multiply for equal grids. Grids are incompatible.")
48-
68+
raise ValueError(
69+
"Can only multiply for equal grids. Grids are incompatible."
70+
)
71+
4972
# Delegates multiplication logic to AbstractGridDistribution
5073
return super().multiply(other)
5174

5275
@classmethod
53-
def from_distribution(cls, distribution, no_of_grid_points, grid_type, enforce_pdf_nonnegative=True):
76+
def from_distribution(
77+
cls, distribution, no_of_grid_points, grid_type, enforce_pdf_nonnegative=True
78+
):
79+
from .hyperhemispherical_grid_distribution import (
80+
HyperhemisphericalGridDistribution,
81+
)
5482
from .hyperspherical_grid_distribution import HypersphericalGridDistribution
55-
from .hyperhemispherical_grid_distribution import HyperhemisphericalGridDistribution
83+
5684
# pylint: disable=too-many-boolean-expressions
5785
if isinstance(distribution, AbstractHypersphereSubsetGridDistribution):
58-
raise ValueError('Already a grid distribution. Use directly instead of converting.')
59-
60-
if isinstance(distribution, AbstractHypersphericalDistribution) and issubclass(cls, HypersphericalGridDistribution)\
61-
or isinstance(distribution, AbstractHyperhemisphericalDistribution) and issubclass(cls, HyperhemisphericalGridDistribution):
86+
raise ValueError(
87+
"Already a grid distribution. Use directly instead of converting."
88+
)
89+
90+
if (
91+
isinstance(distribution, AbstractHypersphericalDistribution)
92+
and issubclass(cls, HypersphericalGridDistribution)
93+
or isinstance(distribution, AbstractHyperhemisphericalDistribution)
94+
and issubclass(cls, HyperhemisphericalGridDistribution)
95+
):
6296
# sphere -> sphere or hemisphere -> hemisphere
6397
fun = distribution.pdf
64-
elif issubclass(cls, HyperhemisphericalGridDistribution) and (isinstance(distribution, (WatsonDistribution, BinghamDistribution)) or
65-
(isinstance(distribution, VonMisesFisherDistribution) and distribution.mu[-1] == 0) or
66-
(isinstance(distribution, HypersphericalMixture) and
67-
len(distribution.dists) == 2 and all(w == 0.5 for w in distribution.w) and
68-
array_equal(distribution.dists[1].mu, -distribution.dists[0].mu))):
98+
elif issubclass(cls, HyperhemisphericalGridDistribution) and (
99+
isinstance(distribution, (WatsonDistribution, BinghamDistribution))
100+
or (
101+
isinstance(distribution, VonMisesFisherDistribution)
102+
and distribution.mu[-1] == 0
103+
)
104+
or (
105+
isinstance(distribution, HypersphericalMixture)
106+
and len(distribution.dists) == 2
107+
and all(w == 0.5 for w in distribution.w)
108+
and array_equal(distribution.dists[1].mu, -distribution.dists[0].mu)
109+
)
110+
):
69111
# sphere -> hemisphere for symmetric distributions
70112
def fun(x):
71113
return 2 * distribution.pdf(x)
72-
elif isinstance(distribution, AbstractHypersphericalDistribution) and issubclass(cls, HyperhemisphericalGridDistribution):
114+
115+
elif isinstance(
116+
distribution, AbstractHypersphericalDistribution
117+
) and issubclass(cls, HyperhemisphericalGridDistribution):
73118
# sphere -> hemisphere for general distributions, which we do not know to be symmetric
74-
warnings.warn('Approximating a hyperspherical distribution on a hemisphere. The density may not be symmetric. Double check if this is intentional.',
75-
UserWarning)
119+
warnings.warn(
120+
"Approximating a hyperspherical distribution on a hemisphere. The density may not be symmetric. Double check if this is intentional.",
121+
UserWarning,
122+
)
123+
76124
def fun(x):
77125
return 2 * distribution.pdf(x)
126+
78127
else:
79-
raise ValueError('Distribution currently not supported.')
80-
81-
sgd = cls.from_function(fun, no_of_grid_points, distribution.dim, grid_type, enforce_pdf_nonnegative=enforce_pdf_nonnegative)
82-
return sgd
128+
raise ValueError("Distribution currently not supported.")
83129

130+
sgd = cls.from_function(
131+
fun,
132+
no_of_grid_points,
133+
distribution.dim,
134+
grid_type,
135+
enforce_pdf_nonnegative=enforce_pdf_nonnegative,
136+
)
137+
return sgd

pyrecest/distributions/hypersphere_subset/hyperhemispherical_grid_distribution.py

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,40 @@
1-
from .abstract_hypersphere_subset_grid_distribution import (
2-
AbstractHypersphereSubsetGridDistribution,
1+
import warnings
2+
3+
from beartype import beartype
4+
5+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
6+
from pyrecest.backend import (
7+
abs,
8+
all,
9+
allclose,
10+
argmax,
11+
argmin,
12+
concatenate,
13+
linalg,
14+
minimum,
15+
vstack,
316
)
17+
18+
from ...sampling.hyperspherical_sampler import get_grid_hyperhemisphere
419
from .abstract_hyperhemispherical_distribution import (
520
AbstractHyperhemisphericalDistribution,
621
)
22+
from .abstract_hypersphere_subset_grid_distribution import (
23+
AbstractHypersphereSubsetGridDistribution,
24+
)
725
from .custom_hyperhemispherical_distribution import CustomHyperhemisphericalDistribution
826
from .hyperspherical_dirac_distribution import HypersphericalDiracDistribution
9-
from ...sampling.hyperspherical_sampler import get_grid_hyperhemisphere
1027

11-
import warnings
12-
13-
# pylint: disable=redefined-builtin,no-name-in-module,no-member
14-
from pyrecest.backend import all, abs, argmax, concatenate, vstack, linalg, allclose, minimum, argmin
15-
from beartype import beartype
1628

1729
class HyperhemisphericalGridDistribution(
1830
AbstractHypersphereSubsetGridDistribution, AbstractHyperhemisphericalDistribution
1931
):
2032

2133
def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True):
2234
# Do not test norm precisely, only quick test if any coordinate exceeds 1.
23-
assert all(abs(grid) <= 1 + 1e-12), (
24-
"Grid points must not lie outside the unit cube."
25-
)
35+
assert all(
36+
abs(grid) <= 1 + 1e-12
37+
), "Grid points must not lie outside the unit cube."
2638
assert all(
2739
grid[:, -1] >= -1e-12
2840
), "Always using upper hemisphere (along last dimension)."
@@ -62,6 +74,7 @@ def to_full_sphere(self, method="antipodal"):
6274
"full-sphere grid distributions."
6375
)
6476
from .hyperspherical_grid_distribution import HypersphericalGridDistribution
77+
6578
grid_full = vstack((self.grid, -self.grid))
6679
grid_values_ = 0.5 * concatenate((self.grid_values, self.grid_values))
6780
hgd = HypersphericalGridDistribution(grid_full, grid_values_)
@@ -74,13 +87,14 @@ def plot(self):
7487

7588
def plot_interpolated(self):
7689
hdgd = self.to_full_sphere()
90+
7791
def pdf_doubled(x):
7892
return 2 * hdgd.pdf(x)
79-
hhgd_interp = CustomHyperhemisphericalDistribution(
80-
pdf_doubled, 3
81-
)
93+
94+
hhgd_interp = CustomHyperhemisphericalDistribution(pdf_doubled, 3)
8295
h = hhgd_interp.plot()
8396
return h
97+
8498
# ------------------------------------------------------------------
8599
# Grid geometry utilities
86100
# ------------------------------------------------------------------
@@ -105,19 +119,15 @@ def get_closest_point(self, xs):
105119

106120
if xs.ndim == 1:
107121
if xs.shape[0] != self.dim:
108-
raise ValueError(
109-
f"xs must have length {self.dim}, got {xs.shape[0]}."
110-
)
122+
raise ValueError(f"xs must have length {self.dim}, got {xs.shape[0]}.")
111123
xs = xs[None, :] # (1, dim)
112124
elif xs.ndim == 2:
113125
if xs.shape[1] == self.dim:
114126
pass # already (batch, dim)
115127
elif xs.shape[0] == self.dim:
116128
xs = xs.T # (batch, dim)
117129
else:
118-
raise ValueError(
119-
f"xs must have shape (n, dim) with dim={self.dim}."
120-
)
130+
raise ValueError(f"xs must have shape (n, dim) with dim={self.dim}.")
121131
else:
122132
raise ValueError("xs must be a 1D or 2D array.")
123133

@@ -139,15 +149,18 @@ def get_closest_point(self, xs):
139149
indices = indices[0]
140150

141151
return points, indices
142-
152+
143153
def get_manifold_size(self):
144154
return AbstractHyperhemisphericalDistribution.get_manifold_size(self)
145155

146156
# ------------------------------------------------------------------
147157
# Multiplication on the hemisphere
148158
# ------------------------------------------------------------------
149159
@beartype
150-
def multiply(self: "HyperhemisphericalGridDistribution", other: "HyperhemisphericalGridDistribution") -> "HyperhemisphericalGridDistribution":
160+
def multiply(
161+
self: "HyperhemisphericalGridDistribution",
162+
other: "HyperhemisphericalGridDistribution",
163+
) -> "HyperhemisphericalGridDistribution":
151164
"""
152165
Multiply two hyperhemispherical grid distributions that share the same grid.
153166
@@ -194,7 +207,13 @@ def multiply(self: "HyperhemisphericalGridDistribution", other: "Hyperhemispheri
194207
# ------------------------------------------------------------------
195208
# pylint: disable=too-many-locals
196209
@staticmethod
197-
def from_function(fun, no_of_grid_points, dim=2, grid_type="leopardi_symm", enforce_pdf_nonnegative=True):
210+
def from_function(
211+
fun,
212+
no_of_grid_points,
213+
dim=2,
214+
grid_type="leopardi_symm",
215+
enforce_pdf_nonnegative=True,
216+
):
198217
"""
199218
Construct a hyperhemispherical grid distribution from a callable.
200219
@@ -216,13 +235,16 @@ def from_function(fun, no_of_grid_points, dim=2, grid_type="leopardi_symm", enfo
216235
depends on (dim, no_of_grid_points, grid_type). For 'healpix',
217236
only dim == 3 is supported and `healpy` must be installed.
218237
"""
219-
assert grid_type in ("leopardi_symm", "healpix"), (
220-
"For hyperhemispheres, use one of the symmetric grid types 'leopardi_symm' or 'healpix'."
221-
)
238+
assert grid_type in (
239+
"leopardi_symm",
240+
"healpix",
241+
), "For hyperhemispheres, use one of the symmetric grid types 'leopardi_symm' or 'healpix'."
222242
grid, _ = get_grid_hyperhemisphere(grid_type, no_of_grid_points, dim=dim)
223243

224244
grid_values = fun(grid)
225245

226-
sgd = HyperhemisphericalGridDistribution(grid, grid_values, enforce_pdf_nonnegative=enforce_pdf_nonnegative)
246+
sgd = HyperhemisphericalGridDistribution(
247+
grid, grid_values, enforce_pdf_nonnegative=enforce_pdf_nonnegative
248+
)
227249
sgd.grid_type = grid_type
228250
return sgd

0 commit comments

Comments
 (0)