1- """
2- Tests for S2CondS2GridDistribution.
3-
4- These tests mirror the MATLAB test class S2CondS2GridDistributionTest.
5- """
61import unittest
72import warnings
83
4+ from matplotlib .pylab import column_stack
95import numpy .testing as npt
10- import pyrecest
6+ import pyrecest . backend import array , zeros , pi , ones , sum
117
12- from pyrecest .backend import array , ones
138from pyrecest .distributions .conditional .s2_cond_s2_grid_distribution import (
149 S2CondS2GridDistribution ,
1510)
@@ -62,14 +57,12 @@ def test_wrong_grid_dim_raises(self):
6257 # Build a 2-sphere grid and misshape it to 4D
6358 grid , _ = get_grid_hypersphere ("leopardi" , 10 , 2 )
6459 n = grid .shape [0 ]
65- import numpy as np
6660
67- surface = 4 * np . pi
61+ surface = 4 * pi
6862 gv = ones ((n , n )) / surface
6963 # Simulate a non-S2 grid (embed in 4D instead of 3D) - should raise
70- import numpy as np
7164
72- grid_4d = np . column_stack ([grid , np . zeros (n )])
65+ grid_4d = column_stack ([grid , zeros (n )])
7366 with self .assertRaises (ValueError ):
7467 S2CondS2GridDistribution (grid_4d , gv )
7568
@@ -84,9 +77,7 @@ def test_warning_free_normalized_vmf(self):
8477
8578 def trans (xkk , xk ):
8679 # xkk: (n1, 3), xk: (n2, 3) -> (n1, n2)
87- import numpy as np
88-
89- result = np .zeros ((xkk .shape [0 ], xk .shape [0 ]))
80+ result = zeros ((xkk .shape [0 ], xk .shape [0 ]))
9081 for i in range (xk .shape [0 ]):
9182 vmf = VonMisesFisherDistribution (xk [i ], 1.0 )
9283 result [:, i ] = vmf .pdf (xkk )
@@ -109,7 +100,7 @@ def trans(xkk, xk):
109100 # xkk, xk both (n_pairs, 3) when fun_does_cartesian_product=False
110101 D = array ([0.1 , 0.15 , 1.0 ])
111102 diff = (xkk - xk ) * D [None , :]
112- return 1.0 / (np . sum (diff ** 2 , axis = 1 ) + 0.01 )
103+ return 1.0 / (sum (diff ** 2 , axis = 1 ) + 0.01 )
113104
114105 with self .assertWarns (UserWarning ):
115106 S2CondS2GridDistribution .from_function (
@@ -134,7 +125,7 @@ def trans_unnorm(pts, fixed):
134125 diff = (pts - fixed [None , :]) * D [None , :]
135126 return 1.0 / (np .sum (diff ** 2 , axis = 1 ) + 0.01 )
136127
137- p = np . zeros ((xkk .shape [0 ], xk .shape [0 ]))
128+ p = zeros ((xkk .shape [0 ], xk .shape [0 ]))
138129 for i in range (xk .shape [0 ]):
139130 chd = CustomHypersphericalDistribution (
140131 lambda pts , fi = xk [i ]: trans_unnorm (pts , fi ), 2
@@ -156,8 +147,6 @@ def test_equal_with_and_without_cart(self):
156147 dist = VonMisesFisherDistribution (array ([0.0 , - 1.0 , 0.0 ]), 100.0 )
157148
158149 def f_trans1 (xkk , xk ):
159- import numpy as np
160-
161150 vals = dist .pdf (xkk ) # (n1,)
162151 return np .tile (vals [:, None ], (1 , xk .shape [0 ])) # (n1, n2)
163152
@@ -178,9 +167,7 @@ def test_fix_dim_returns_spherical_grid_distribution(self):
178167 no_grid_points = 50
179168
180169 def trans (xkk , xk ):
181- import numpy as np
182-
183- result = np .zeros ((xkk .shape [0 ], xk .shape [0 ]))
170+ result = zeros ((xkk .shape [0 ], xk .shape [0 ]))
184171 for i in range (xk .shape [0 ]):
185172 vmf = VonMisesFisherDistribution (xk [i ], 1.0 )
186173 result [:, i ] = vmf .pdf (xkk )
@@ -205,9 +192,7 @@ def test_fix_dim_mean_direction(self):
205192 no_grid_points = 112
206193
207194 def trans (xkk , xk ):
208- import numpy as np
209-
210- result = np .zeros ((xkk .shape [0 ], xk .shape [0 ]))
195+ result = zeros ((xkk .shape [0 ], xk .shape [0 ]))
211196 for i in range (xk .shape [0 ]):
212197 vmf = VonMisesFisherDistribution (xk [i ], 1.0 )
213198 result [:, i ] = vmf .pdf (xkk )
@@ -227,9 +212,7 @@ def test_marginalize_out_returns_spherical_grid_distribution(self):
227212 no_grid_points = 50
228213
229214 def trans (xkk , xk ):
230- import numpy as np
231-
232- result = np .zeros ((xkk .shape [0 ], xk .shape [0 ]))
215+ result = zeros ((xkk .shape [0 ], xk .shape [0 ]))
233216 for i in range (xk .shape [0 ]):
234217 vmf = VonMisesFisherDistribution (xk [i ], 1.0 )
235218 result [:, i ] = vmf .pdf (xkk )
0 commit comments