Skip to content

Commit 5706297

Browse files
FlorianPfaffgithub-actions[bot]
authored andcommitted
[MegaLinter] Apply linters automatic fixes
1 parent 47dbe3b commit 5706297

3 files changed

Lines changed: 25 additions & 26 deletions

File tree

pyrecest/distributions/cart_prod/state_space_subdivision_distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def __init__(self, gd, linear_distributions):
2626
One distribution per grid point representing the conditional
2727
distribution of the linear state given that grid point.
2828
"""
29-
assert gd.n_grid_points == len(linear_distributions), (
30-
"Number of grid points in gd must match length of linear_distributions."
31-
)
29+
assert gd.n_grid_points == len(
30+
linear_distributions
31+
), "Number of grid points in gd must match length of linear_distributions."
3232
self.gd = copy.deepcopy(gd)
3333
self.linear_distributions = list(copy.deepcopy(linear_distributions))
3434

pyrecest/distributions/cart_prod/state_space_subdivision_gaussian_distribution.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
# pylint: disable=no-name-in-module,no-member
55
from pyrecest.backend import (
66
allclose,
7-
any as backend_any,
7+
)
8+
from pyrecest.backend import any as backend_any
9+
from pyrecest.backend import (
810
argmax,
911
array,
1012
asarray,
1113
concatenate,
1214
stack,
13-
sum as backend_sum,
15+
)
16+
from pyrecest.backend import sum as backend_sum
17+
from pyrecest.backend import (
1418
zeros,
1519
)
1620

@@ -39,9 +43,9 @@ def __init__(self, gd, gaussians):
3943
gaussians : list of GaussianDistribution
4044
One Gaussian per grid point of *gd*.
4145
"""
42-
assert all(isinstance(g, GaussianDistribution) for g in gaussians), (
43-
"All elements of gaussians must be GaussianDistribution instances."
44-
)
46+
assert all(
47+
isinstance(g, GaussianDistribution) for g in gaussians
48+
), "All elements of gaussians must be GaussianDistribution instances."
4549
super().__init__(gd, gaussians)
4650

4751
# ------------------------------------------------------------------
@@ -132,9 +136,7 @@ def multiply(self, other):
132136
)
133137
self_grid = asarray(self.gd.get_grid())
134138
other_grid = asarray(other.gd.get_grid())
135-
assert allclose(self_grid, other_grid), (
136-
"Can only multiply for equal grids."
137-
)
139+
assert allclose(self_grid, other_grid), "Can only multiply for equal grids."
138140

139141
n = len(self.linear_distributions)
140142
new_linear_distributions = []
@@ -148,7 +150,9 @@ def multiply(self, other):
148150
# N(mu_self_i, C_self_i + C_other_i) evaluated at mu_other_i.
149151
# This is equivalent to N(0, C_self_i + C_other_i) at 0.
150152
combined_cov = ld_self.C + ld_other.C
151-
temp_g = GaussianDistribution(ld_other.mu, combined_cov, check_validity=False)
153+
temp_g = GaussianDistribution(
154+
ld_other.mu, combined_cov, check_validity=False
155+
)
152156
pdf_values.append(temp_g.pdf(ld_self.mu))
153157

154158
new_linear_distributions.append(ld_self.multiply(ld_other))
@@ -215,7 +219,7 @@ def mode(self):
215219

216220
# Remove the maximum entry to check for multimodality
217221
remaining = concatenate(
218-
[fun_vals_joint[:index], fun_vals_joint[index + 1:]] # noqa: E203
222+
[fun_vals_joint[:index], fun_vals_joint[index + 1 :]] # noqa: E203
219223
)
220224
if len(remaining) > 0 and (
221225
backend_any((max_val - remaining) < 1e-15)

pyrecest/tests/distributions/test_state_space_subdivision_gaussian_distribution.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ def test_multiply_s1_x_r1_identical_precise(self):
3333
gd = HypertoroidalGridDistribution.from_distribution(
3434
CircularUniformDistribution(), (n,)
3535
)
36-
gaussians = [GaussianDistribution(array([0.0]), array([[1.0]])) for _ in range(n)]
36+
gaussians = [
37+
GaussianDistribution(array([0.0]), array([[1.0]])) for _ in range(n)
38+
]
3739
rbd1 = StateSpaceSubdivisionGaussianDistribution(gd, gaussians)
3840

3941
gaussians2 = [
@@ -95,9 +97,7 @@ def test_hybrid_mean(self):
9597
gd = HypertoroidalGridDistribution.from_distribution(
9698
VonMisesDistribution(mu_periodic, 1.0), (n,)
9799
)
98-
gaussians = [
99-
GaussianDistribution(mu_linear, 1000.0 * eye(3)) for _ in range(n)
100-
]
100+
gaussians = [GaussianDistribution(mu_linear, 1000.0 * eye(3)) for _ in range(n)]
101101
rbd = StateSpaceSubdivisionGaussianDistribution(gd, gaussians)
102102
npt.assert_allclose(
103103
rbd.hybrid_mean(),
@@ -112,9 +112,7 @@ def test_linear_mean(self):
112112
gd = HypertoroidalGridDistribution.from_distribution(
113113
VonMisesDistribution(4.0, 1.0), (n,)
114114
)
115-
gaussians = [
116-
GaussianDistribution(mu_linear, 1000.0 * eye(3)) for _ in range(n)
117-
]
115+
gaussians = [GaussianDistribution(mu_linear, 1000.0 * eye(3)) for _ in range(n)]
118116
rbd = StateSpaceSubdivisionGaussianDistribution(gd, gaussians)
119117
npt.assert_allclose(rbd.linear_mean(), mu_linear, rtol=5e-7)
120118

@@ -125,9 +123,7 @@ def test_mode_warning_uniform(self):
125123
gd = HypertoroidalGridDistribution.from_distribution(
126124
CircularUniformDistribution(), (n,)
127125
)
128-
gaussians = [
129-
GaussianDistribution(mu_linear, 1000.0 * eye(3)) for _ in range(n)
130-
]
126+
gaussians = [GaussianDistribution(mu_linear, 1000.0 * eye(3)) for _ in range(n)]
131127
rbd = StateSpaceSubdivisionGaussianDistribution(gd, gaussians)
132128
with self.assertWarns(UserWarning):
133129
rbd.mode()
@@ -140,14 +136,13 @@ def test_mode(self):
140136
gd = HypertoroidalGridDistribution.from_distribution(
141137
VonMisesDistribution(mu_periodic, 10.0), (n,)
142138
)
143-
gaussians = [
144-
GaussianDistribution(mu_linear, eye(3)) for _ in range(n)
145-
]
139+
gaussians = [GaussianDistribution(mu_linear, eye(3)) for _ in range(n)]
146140
rbd = StateSpaceSubdivisionGaussianDistribution(gd, gaussians)
147141

148142
# Should not warn
149143
with self.assertNoLogs(level="WARNING"):
150144
import warnings as _warnings
145+
151146
with _warnings.catch_warnings():
152147
_warnings.simplefilter("error", UserWarning)
153148
m = rbd.mode()

0 commit comments

Comments
 (0)