Skip to content

Commit c555a2d

Browse files
committed
Next
1 parent 88b6c5b commit c555a2d

1 file changed

Lines changed: 10 additions & 11 deletions

File tree

pyrecest/distributions/hypersphere_subset/hyperhemispherical_grid_distribution.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from matplotlib.pyplot import grid
21
from .abstract_hypersphere_subset_grid_distribution import (
32
AbstractHypersphereSubsetGridDistribution,
43
)
@@ -20,16 +19,16 @@ class HyperhemisphericalGridDistribution(
2019
AbstractHypersphereSubsetGridDistribution, AbstractHyperhemisphericalDistribution
2120
):
2221

23-
def __init__(self, grid_, grid_values, enforce_pdf_nonnegative=True):
22+
def __init__(self, grid, grid_values, enforce_pdf_nonnegative=True):
2423
# Do not test norm precisely, only quick test if any coordinate exceeds 1.
25-
assert all(abs(grid_) <= 1 + 1e-12), (
24+
assert all(abs(grid) <= 1 + 1e-12), (
2625
"Grid points must not lie outside the unit ball."
2726
)
2827
assert all(
29-
grid_[:, -1] >= -1e-12
28+
grid[:, -1] >= -1e-12
3029
), "Always using upper hemisphere (along last dimension)."
3130

32-
super().__init__(grid_, grid_values, enforce_pdf_nonnegative)
31+
super().__init__(grid, grid_values, enforce_pdf_nonnegative)
3332

3433
# ------------------------------------------------------------------
3534
# Basic functionality
@@ -59,9 +58,9 @@ def to_full_sphere(self):
5958
hyperspherical distribution normalized.
6059
"""
6160
from .hyperspherical_grid_distribution import HypersphericalGridDistribution
62-
grid_ = vstack((self.grid, -self.grid))
61+
grid_full = vstack((self.grid, -self.grid))
6362
grid_values_ = 0.5 * concatenate((self.grid_values, self.grid_values))
64-
hgd = HypersphericalGridDistribution(grid_, grid_values_)
63+
hgd = HypersphericalGridDistribution(grid_full, grid_values_)
6564
return hgd
6665

6766
def plot(self):
@@ -119,8 +118,8 @@ def get_closest_point(self, xs):
119118
raise ValueError("xs must be a 1D or 2D array.")
120119

121120
# Distances to each grid point and its antipode.
122-
diff1 = xs[:, None, :] - grid[None, :, :] # (batch, n_grid, dim)
123-
diff2 = xs[:, None, :] + grid[None, :, :] # (batch, n_grid, dim)
121+
diff1 = xs[:, None, :] - self.grid[None, :, :] # (batch, n_grid, dim)
122+
diff2 = xs[:, None, :] + self.grid[None, :, :] # (batch, n_grid, dim)
124123

125124
dists1 = linalg.norm(diff1, axis=2) # (batch, n_grid)
126125
dists2 = linalg.norm(diff2, axis=2) # (batch, n_grid)
@@ -189,8 +188,8 @@ def multiply(self: "HyperhemisphericalGridDistribution", other: "Hyperhemispheri
189188
@staticmethod
190189
def eq_point_set_upper_half(dim, n_points):
191190
ls = LeopardiSampler()
192-
grid_, _ = ls.get_grid(n_points * 2, dim)
193-
grid_upper_half = grid_[:grid_.shape[0]//2]
191+
grid, _ = ls.get_grid(n_points * 2, dim)
192+
grid_upper_half = grid[:grid.shape[0]//2]
194193
# To have upper half along last dim instead of first
195194
grid_upper_half[:, [0, -1]] = grid_upper_half[:, [-1, 0]]
196195
return grid_upper_half

0 commit comments

Comments
 (0)