1- from matplotlib .pyplot import grid
21from .abstract_hypersphere_subset_grid_distribution import (
32 AbstractHypersphereSubsetGridDistribution ,
43)
@@ -20,16 +19,16 @@ class HyperhemisphericalGridDistribution(
2019 AbstractHypersphereSubsetGridDistribution , AbstractHyperhemisphericalDistribution
2120):
2221
23- def __init__ (self , grid_ , grid_values , enforce_pdf_nonnegative = True ):
22+ def __init__ (self , grid , grid_values , enforce_pdf_nonnegative = True ):
2423 # Do not test norm precisely, only quick test if any coordinate exceeds 1.
25- assert all (abs (grid_ ) <= 1 + 1e-12 ), (
24+ assert all (abs (grid ) <= 1 + 1e-12 ), (
2625 "Grid points must not lie outside the unit ball."
2726 )
2827 assert all (
29- grid_ [:, - 1 ] >= - 1e-12
28+ grid [:, - 1 ] >= - 1e-12
3029 ), "Always using upper hemisphere (along last dimension)."
3130
32- super ().__init__ (grid_ , grid_values , enforce_pdf_nonnegative )
31+ super ().__init__ (grid , grid_values , enforce_pdf_nonnegative )
3332
3433 # ------------------------------------------------------------------
3534 # Basic functionality
@@ -59,9 +58,9 @@ def to_full_sphere(self):
5958 hyperspherical distribution normalized.
6059 """
6160 from .hyperspherical_grid_distribution import HypersphericalGridDistribution
62- grid_ = vstack ((self .grid , - self .grid ))
61+ grid_full = vstack ((self .grid , - self .grid ))
6362 grid_values_ = 0.5 * concatenate ((self .grid_values , self .grid_values ))
64- hgd = HypersphericalGridDistribution (grid_ , grid_values_ )
63+ hgd = HypersphericalGridDistribution (grid_full , grid_values_ )
6564 return hgd
6665
6766 def plot (self ):
@@ -119,8 +118,8 @@ def get_closest_point(self, xs):
119118 raise ValueError ("xs must be a 1D or 2D array." )
120119
121120 # Distances to each grid point and its antipode.
122- diff1 = xs [:, None , :] - grid [None , :, :] # (batch, n_grid, dim)
123- diff2 = xs [:, None , :] + grid [None , :, :] # (batch, n_grid, dim)
121+ diff1 = xs [:, None , :] - self . grid [None , :, :] # (batch, n_grid, dim)
122+ diff2 = xs [:, None , :] + self . grid [None , :, :] # (batch, n_grid, dim)
124123
125124 dists1 = linalg .norm (diff1 , axis = 2 ) # (batch, n_grid)
126125 dists2 = linalg .norm (diff2 , axis = 2 ) # (batch, n_grid)
@@ -189,8 +188,8 @@ def multiply(self: "HyperhemisphericalGridDistribution", other: "Hyperhemispheri
189188 @staticmethod
190189 def eq_point_set_upper_half (dim , n_points ):
191190 ls = LeopardiSampler ()
192- grid_ , _ = ls .get_grid (n_points * 2 , dim )
193- grid_upper_half = grid_ [: grid_ .shape [0 ]// 2 ]
191+ grid , _ = ls .get_grid (n_points * 2 , dim )
192+ grid_upper_half = grid [: grid .shape [0 ]// 2 ]
194193 # To have upper half along last dim instead of first
195194 grid_upper_half [:, [0 , - 1 ]] = grid_upper_half [:, [- 1 , 0 ]]
196195 return grid_upper_half
0 commit comments