|
| 1 | +from pyrecest.distributions.cart_prod.state_space_subdivision_distribution import StateSpaceSubdivisionDistribution |
| 2 | +from pyrecest.distributions.cart_prod.abstract_hypercylindrical_distribution import AbstractHypercylindricalDistribution |
| 3 | +from pyrecest.distributions.nonperiodic.custom_linear_distribution import CustomLinearDistribution |
| 4 | +from pyrecest.distributions.circle.circular_uniform_distribution import CircularUniformDistribution |
| 5 | +from scipy.integrate import quad |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +class HypercylindricalStateSpaceSubdivisionDistribution(StateSpaceSubdivisionDistribution, AbstractHypercylindricalDistribution): |
| 9 | + |
| 10 | + def __init__(self, gd_, lin_distributions): |
| 11 | + StateSpaceSubdivisionDistribution.__init__(self, gd_, lin_distributions) |
| 12 | + |
| 13 | + def plot(self, interpolate=False): |
| 14 | + if interpolate: |
| 15 | + return AbstractHypercylindricalDistribution.plot(self) |
| 16 | + else: |
| 17 | + return StateSpaceSubdivisionDistribution.plot(self) |
| 18 | + |
| 19 | + def plot_interpolated(self): |
| 20 | + return self.plot(interpolate=True) |
| 21 | + |
| 22 | + def mode(self): |
| 23 | + return StateSpaceSubdivisionDistribution.mode(self) |
| 24 | + |
| 25 | + @staticmethod |
| 26 | + def from_distribution(distribution, no_of_grid_points, grid_type='CartesianProd'): |
| 27 | + return HypercylindricalStateSpaceSubdivisionDistribution.from_function( |
| 28 | + distribution.pdf, no_of_grid_points, distribution.linD, distribution.boundD, grid_type) |
| 29 | + |
| 30 | + @staticmethod |
| 31 | + def from_function(fun, no_of_grid_points, dim_lin, dim_bound=1, grid_type='CartesianProd', int_range=(-np.inf , np.inf)): |
| 32 | + assert dim_lin == 1, 'Currently, bounded dimension must be 1.' |
| 33 | + |
| 34 | + gd = CircularGridDistribution.from_distribution(CircularUniformDistribution(), no_of_grid_points) |
| 35 | + grid = gd.get_grid() |
| 36 | + cds = [None] * no_of_grid_points |
| 37 | + |
| 38 | + for i in range(no_of_grid_points): |
| 39 | + fun_curr = lambda y: np.reshape(fun(np.vstack((grid[i] * np.ones_like(y), y))), np.shape(y)) |
| 40 | + |
| 41 | + # Obtain grid value via integral |
| 42 | + gd.grid_values[i], _ = quad(fun_curr, int_range[0], int_range[1]) |
| 43 | + |
| 44 | + # Original function divided by grid value is linear |
| 45 | + cds[i] = CustomLinearDistribution(lambda x: fun_curr(x) / gd.grid_values[i], 1) |
| 46 | + |
| 47 | + return HypercylindricalStateSpaceSubdivisionDistribution(gd, cds) |
0 commit comments