2222)
2323from pyrecest .distributions import (
2424 AbstractSphericalDistribution ,
25- HypersphericalUniformDistribution ,
2625 HyperhemisphericalUniformDistribution ,
26+ HypersphericalUniformDistribution ,
2727)
2828
2929from .abstract_sampler import AbstractSampler
@@ -53,12 +53,16 @@ def get_grid_hypersphere(method: str, grid_density_parameter: int, dim: int):
5353
5454 return samples , grid_specific_description
5555
56+
5657def get_grid_sphere (method : str , grid_density_parameter : int ):
5758 return get_grid_hypersphere (method , grid_density_parameter , dim = 2 )
5859
60+
5961def get_grid_hyperhemisphere (method : str , grid_density_parameter : int , dim : int ):
6062 if method == "leopardi" :
61- ls = SymmetricLeopardiSampler (original_code_column_order = True , delete_half = True , symmetry_type = "plane" )
63+ ls = SymmetricLeopardiSampler (
64+ original_code_column_order = True , delete_half = True , symmetry_type = "plane"
65+ )
6266 samples , _ = ls .get_grid (grid_density_parameter * 2 , dim )
6367 # To have upper half along last dim instead of first
6468 grid_specific_description = {
@@ -70,14 +74,16 @@ def get_grid_hyperhemisphere(method: str, grid_density_parameter: int, dim: int)
7074
7175 return samples , grid_specific_description
7276
77+
7378class AbstractHypersphericalUniformSampler (AbstractSampler ):
7479 def sample_stochastic (self , n_samples : int , dim : int ):
7580 return HypersphericalUniformDistribution (dim ).sample (n_samples )
7681
7782 @abstractmethod
7883 def get_grid (self , grid_density_parameter , dim : int ):
7984 raise NotImplementedError ()
80-
85+
86+
8187class AbstractHyperhemisphericalUniformSampler (AbstractSampler ):
8288 def sample_stochastic (self , n_samples : int , dim : int ):
8389 return HyperhemisphericalUniformDistribution (dim ).sample (n_samples )
@@ -155,7 +161,9 @@ def __init__(self, original_code_column_order=True):
155161
156162 def get_grid (self , grid_density_parameter , dim : int ):
157163 # Use [::-1] due to different convention
158- grid_eucl = get_partition_points_cartesian (dim , grid_density_parameter , delete_half = False , symmetry_type = "asymm" )
164+ grid_eucl = get_partition_points_cartesian (
165+ dim , grid_density_parameter , delete_half = False , symmetry_type = "asymm"
166+ )
159167
160168 if self .original_code_column_order :
161169 grid_eucl = flip (grid_eucl , axis = 1 )
@@ -166,18 +174,26 @@ def get_grid(self, grid_density_parameter, dim: int):
166174 "n_side" : grid_density_parameter ,
167175 }
168176 return grid_eucl , grid_specific_description
169-
177+
178+
170179class SymmetricLeopardiSampler (AbstractHypersphericalUniformSampler ):
171- def __init__ (self , original_code_column_order = True , delete_half = False , symmetry_type = 'plane' ):
180+ def __init__ (
181+ self , original_code_column_order = True , delete_half = False , symmetry_type = "plane"
182+ ):
172183 self .original_code_column_order = original_code_column_order
173184 self .delete_half = delete_half
174185 self .symmetry_type = symmetry_type
175186 assert backend .__backend_name__ != "jax" , "Backend unsupported"
176187
177188 def get_grid (self , grid_density_parameter , dim : int ):
178189 # Use [::-1] due to different convention
179- grid_eucl = get_partition_points_cartesian (dim , grid_density_parameter , delete_half = self .delete_half , symmetry_type = self .symmetry_type )
180-
190+ grid_eucl = get_partition_points_cartesian (
191+ dim ,
192+ grid_density_parameter ,
193+ delete_half = self .delete_half ,
194+ symmetry_type = self .symmetry_type ,
195+ )
196+
181197 if self .original_code_column_order :
182198 grid_eucl = flip (grid_eucl , axis = 1 )
183199 grid_eucl [:, [0 , 1 ]] = grid_eucl [:, [1 , 0 ]]
0 commit comments