Skip to content

Commit 4424230

Browse files
committed
Cleanup and using backend
1 parent e5e0b20 commit 4424230

1 file changed

Lines changed: 10 additions & 27 deletions

File tree

pyrecest/tests/distributions/test_s2_cond_s2_grid_distribution.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
1-
"""
2-
Tests for S2CondS2GridDistribution.
3-
4-
These tests mirror the MATLAB test class S2CondS2GridDistributionTest.
5-
"""
61
import unittest
72
import warnings
83

4+
from matplotlib.pylab import column_stack
95
import numpy.testing as npt
10-
import pyrecest
6+
import pyrecest.backend import array, zeros, pi, ones, sum
117

12-
from pyrecest.backend import array, ones
138
from pyrecest.distributions.conditional.s2_cond_s2_grid_distribution import (
149
S2CondS2GridDistribution,
1510
)
@@ -62,14 +57,12 @@ def test_wrong_grid_dim_raises(self):
6257
# Build a 2-sphere grid and misshape it to 4D
6358
grid, _ = get_grid_hypersphere("leopardi", 10, 2)
6459
n = grid.shape[0]
65-
import numpy as np
6660

67-
surface = 4 * np.pi
61+
surface = 4 * pi
6862
gv = ones((n, n)) / surface
6963
# Simulate a non-S2 grid (embed in 4D instead of 3D) - should raise
70-
import numpy as np
7164

72-
grid_4d = np.column_stack([grid, np.zeros(n)])
65+
grid_4d = column_stack([grid, zeros(n)])
7366
with self.assertRaises(ValueError):
7467
S2CondS2GridDistribution(grid_4d, gv)
7568

@@ -84,9 +77,7 @@ def test_warning_free_normalized_vmf(self):
8477

8578
def trans(xkk, xk):
8679
# xkk: (n1, 3), xk: (n2, 3) -> (n1, n2)
87-
import numpy as np
88-
89-
result = np.zeros((xkk.shape[0], xk.shape[0]))
80+
result = zeros((xkk.shape[0], xk.shape[0]))
9081
for i in range(xk.shape[0]):
9182
vmf = VonMisesFisherDistribution(xk[i], 1.0)
9283
result[:, i] = vmf.pdf(xkk)
@@ -109,7 +100,7 @@ def trans(xkk, xk):
109100
# xkk, xk both (n_pairs, 3) when fun_does_cartesian_product=False
110101
D = array([0.1, 0.15, 1.0])
111102
diff = (xkk - xk) * D[None, :]
112-
return 1.0 / (np.sum(diff**2, axis=1) + 0.01)
103+
return 1.0 / (sum(diff**2, axis=1) + 0.01)
113104

114105
with self.assertWarns(UserWarning):
115106
S2CondS2GridDistribution.from_function(
@@ -134,7 +125,7 @@ def trans_unnorm(pts, fixed):
134125
diff = (pts - fixed[None, :]) * D[None, :]
135126
return 1.0 / (np.sum(diff**2, axis=1) + 0.01)
136127

137-
p = np.zeros((xkk.shape[0], xk.shape[0]))
128+
p = zeros((xkk.shape[0], xk.shape[0]))
138129
for i in range(xk.shape[0]):
139130
chd = CustomHypersphericalDistribution(
140131
lambda pts, fi=xk[i]: trans_unnorm(pts, fi), 2
@@ -156,8 +147,6 @@ def test_equal_with_and_without_cart(self):
156147
dist = VonMisesFisherDistribution(array([0.0, -1.0, 0.0]), 100.0)
157148

158149
def f_trans1(xkk, xk):
159-
import numpy as np
160-
161150
vals = dist.pdf(xkk) # (n1,)
162151
return np.tile(vals[:, None], (1, xk.shape[0])) # (n1, n2)
163152

@@ -178,9 +167,7 @@ def test_fix_dim_returns_spherical_grid_distribution(self):
178167
no_grid_points = 50
179168

180169
def trans(xkk, xk):
181-
import numpy as np
182-
183-
result = np.zeros((xkk.shape[0], xk.shape[0]))
170+
result = zeros((xkk.shape[0], xk.shape[0]))
184171
for i in range(xk.shape[0]):
185172
vmf = VonMisesFisherDistribution(xk[i], 1.0)
186173
result[:, i] = vmf.pdf(xkk)
@@ -205,9 +192,7 @@ def test_fix_dim_mean_direction(self):
205192
no_grid_points = 112
206193

207194
def trans(xkk, xk):
208-
import numpy as np
209-
210-
result = np.zeros((xkk.shape[0], xk.shape[0]))
195+
result = zeros((xkk.shape[0], xk.shape[0]))
211196
for i in range(xk.shape[0]):
212197
vmf = VonMisesFisherDistribution(xk[i], 1.0)
213198
result[:, i] = vmf.pdf(xkk)
@@ -227,9 +212,7 @@ def test_marginalize_out_returns_spherical_grid_distribution(self):
227212
no_grid_points = 50
228213

229214
def trans(xkk, xk):
230-
import numpy as np
231-
232-
result = np.zeros((xkk.shape[0], xk.shape[0]))
215+
result = zeros((xkk.shape[0], xk.shape[0]))
233216
for i in range(xk.shape[0]):
234217
vmf = VonMisesFisherDistribution(xk[i], 1.0)
235218
result[:, i] = vmf.pdf(xkk)

0 commit comments

Comments
 (0)