Skip to content

Commit 876a348

Browse files
committed
Added SdHalfCondSdHalfGridDistribution
1 parent faf8e3d commit 876a348

1 file changed

Lines changed: 50 additions & 0 deletions

File tree

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
from .abstract_conditional_distribution import AbstractConditionalDistribution
3+
from ..abstract_grid_distribution import AbstractGridDistribution
4+
from ..hypersphere_subset.hyperhemispherical_grid_distribution import HyperhemisphericalGridDistribution
5+
6+
class SdHalfCondSdHalfGridDistribution(AbstractConditionalDistribution, AbstractGridDistribution):
7+
def __init__(self, grid_, grid_values_, enforce_pdf_nonnegative=True):
8+
assert np.all(grid_[-1] >= 0), "Always using upper hemisphere (along last dimension)."
9+
self.dim = 2 * grid_.shape[0]
10+
assert grid_values_.shape[0] == grid_values_.shape[1]
11+
assert grid_.shape[1] == grid_values_.shape[0]
12+
self.grid = grid_
13+
self.grid_values = grid_values_
14+
self.enforce_pdf_nonnegative = enforce_pdf_nonnegative
15+
self.normalize()
16+
17+
def normalize(self, tol=0.01):
18+
ints = np.mean(self.grid_values, axis=1) * 0.5 * self.compute_unit_sphere_surface(self.dim // 2)
19+
if any(np.abs(ints - 1) > tol):
20+
if all(np.abs(ints - 1) <= tol):
21+
raise ValueError("Not normalized but would be normalized if order of the spheres were swapped. Check input.")
22+
else:
23+
print("When conditioning values for first sphere on second, normalization is not ensured. One reason may be that you are approximating a density on the entire sphere that is not symmetrical. You can try to increase tolerance.")
24+
25+
def multiply(self, other):
26+
assert np.array_equal(self.grid, other.grid), "Can only multiply for equal grids."
27+
print("Multiplication does not yield normalized result.")
28+
self.grid_values = self.grid_values * other.grid_values
29+
30+
def marginalize_out(self, first_or_second):
31+
if first_or_second == 1:
32+
grid_values_sgd = np.sum(self.grid_values, axis=1).T
33+
elif first_or_second == 2:
34+
grid_values_sgd = np.sum(self.grid_values, axis=0)
35+
else:
36+
raise ValueError("Invalid value for first_or_second. Must be 1 or 2.")
37+
return HyperhemisphericalGridDistribution(self.grid, grid_values_sgd)
38+
39+
def fix_dim(self, first_or_second, point):
40+
assert point.shape[0] == self.dim // 2
41+
lia, locb = ismember(point.T, self.grid.T, "rows")
42+
if not lia:
43+
raise ValueError("Cannot fix value at this point because it is not on the grid")
44+
if first_or_second == 1:
45+
grid_values_slice = self.grid_values[locb, :].T
46+
elif first_or_second == 2:
47+
grid_values_slice = self.grid_values[:, locb]
48+
else:
49+
raise ValueError("Invalid value for first_or_second. Must be 1 or 2.")
50+
return HyperhemisphericalGridDistribution(self.grid, grid_values_slice)

0 commit comments

Comments
 (0)