Skip to content

Commit b94d873

Browse files
committed
Added HypercylindricalStateSpaceSubdivisionDistribution
1 parent faf8e3d commit b94d873

1 file changed

Lines changed: 47 additions & 0 deletions

File tree

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from pyrecest.distributions.cart_prod.state_space_subdivision_distribution import StateSpaceSubdivisionDistribution
2+
from pyrecest.distributions.cart_prod.abstract_hypercylindrical_distribution import AbstractHypercylindricalDistribution
3+
from pyrecest.distributions.nonperiodic.custom_linear_distribution import CustomLinearDistribution
4+
from pyrecest.distributions.circle.circular_uniform_distribution import CircularUniformDistribution
5+
from scipy.integrate import quad
6+
import numpy as np
7+
8+
class HypercylindricalStateSpaceSubdivisionDistribution(StateSpaceSubdivisionDistribution, AbstractHypercylindricalDistribution):
9+
10+
def __init__(self, gd_, lin_distributions):
11+
StateSpaceSubdivisionDistribution.__init__(self, gd_, lin_distributions)
12+
13+
def plot(self, interpolate=False):
14+
if interpolate:
15+
return AbstractHypercylindricalDistribution.plot(self)
16+
else:
17+
return StateSpaceSubdivisionDistribution.plot(self)
18+
19+
def plot_interpolated(self):
20+
return self.plot(interpolate=True)
21+
22+
def mode(self):
23+
return StateSpaceSubdivisionDistribution.mode(self)
24+
25+
@staticmethod
26+
def from_distribution(distribution, no_of_grid_points, grid_type='CartesianProd'):
27+
return HypercylindricalStateSpaceSubdivisionDistribution.from_function(
28+
distribution.pdf, no_of_grid_points, distribution.linD, distribution.boundD, grid_type)
29+
30+
@staticmethod
31+
def from_function(fun, no_of_grid_points, dim_lin, dim_bound=1, grid_type='CartesianProd', int_range=(-np.inf , np.inf)):
32+
assert dim_lin == 1, 'Currently, bounded dimension must be 1.'
33+
34+
gd = CircularGridDistribution.from_distribution(CircularUniformDistribution(), no_of_grid_points)
35+
grid = gd.get_grid()
36+
cds = [None] * no_of_grid_points
37+
38+
for i in range(no_of_grid_points):
39+
fun_curr = lambda y: np.reshape(fun(np.vstack((grid[i] * np.ones_like(y), y))), np.shape(y))
40+
41+
# Obtain grid value via integral
42+
gd.grid_values[i], _ = quad(fun_curr, int_range[0], int_range[1])
43+
44+
# Original function divided by grid value is linear
45+
cds[i] = CustomLinearDistribution(lambda x: fun_curr(x) / gd.grid_values[i], 1)
46+
47+
return HypercylindricalStateSpaceSubdivisionDistribution(gd, cds)

0 commit comments

Comments
 (0)