|
| 1 | +import numpy as np |
| 2 | +import warnings |
| 3 | +from .abstract_grid_filter import AbstractGridFilter |
| 4 | +from .abstract_hyperhemispherical_filter import AbstractHyperhemisphericalFilter |
| 5 | +from pyrecest.distributions.hypersphere_subset.hyperhemispherical_grid_distribution import HyperhemisphericalGridDistribution |
| 6 | +from pyrecest.distributions.conditional.sd_half_cond_sd_half_grid_distribution import SdHalfCondSdHalfGridDistribution |
| 7 | +from pyrecest.distributions import BinghamDistribution |
| 8 | +from pyrecest.distributions import WatsonDistribution |
| 9 | +from pyrecest.distributions import VonMisesFisherDistribution |
| 10 | +from pyrecest.distributions import HypersphericalMixture |
| 11 | +from pyrecest.distributions import HyperhemisphericalUniformDistribution |
| 12 | +from pyrecest.distributions import AbstractHyperhemisphericalDistribution |
| 13 | +from pyrecest.distributions import HyperhemisphericalWatsonDistribution |
| 14 | + |
| 15 | +class HyperhemisphericalGridFilter(AbstractGridFilter, AbstractHyperhemisphericalFilter): |
| 16 | + def __init__(self, no_of_coefficients, dim, grid_type='eq_point_set_symm'): |
| 17 | + self.gd = HyperhemisphericalGridDistribution.from_distribution( |
| 18 | + HyperhemisphericalUniformDistribution(dim), no_of_coefficients, grid_type) |
| 19 | + |
| 20 | + def set_state(self, new_state): |
| 21 | + assert self.dim == new_state.dim |
| 22 | + assert isinstance(new_state, AbstractHyperhemisphericalDistribution) |
| 23 | + self.gd = new_state |
| 24 | + |
| 25 | + def predict_identity(self, d_sys): |
| 26 | + assert isinstance(d_sys, AbstractHyperhemisphericalDistribution) |
| 27 | + sd_half_cond_sd_half = HyperhemisphericalGridFilter.sys_noise_to_transition_density( |
| 28 | + d_sys, self.gd.grid_values.shape[0]) |
| 29 | + self.predict_nonlinear_via_transition_density(sd_half_cond_sd_half) |
| 30 | + |
| 31 | + def update_identity(self, meas_noise, z=None): |
| 32 | + assert isinstance(meas_noise, AbstractHyperhemisphericalDistribution) |
| 33 | + if not z==None: |
| 34 | + measNoise = measNoise.setMode(z) |
| 35 | + curr_grid = self.gd.get_grid() |
| 36 | + self.gd = self.gd.multiply(HyperhemisphericalGridDistribution(curr_grid, 2 * meas_noise.pdf(curr_grid).T)) |
| 37 | + |
| 38 | + def update_nonlinear(self, likelihood, z): |
| 39 | + self.gd.grid_values = self.gd.grid_values * likelihood(z, self.gd.get_grid()).T |
| 40 | + with warnings.catch_warnings(): |
| 41 | + warnings.simplefilter("ignore", category=RuntimeWarning) |
| 42 | + self.gd = self.gd.normalize() |
| 43 | + |
| 44 | + def predict_nonlinear_via_transition_density(self, f_trans): |
| 45 | + assert np.array_equal(self.gd.get_grid(), f_trans.get_grid()), \ |
| 46 | + "fTrans is using an incompatible grid." |
| 47 | + self.gd = self.gd.normalize() |
| 48 | + grid_values_new = self.gd.get_manifold_size() / self.gd.grid_values.shape[0] * f_trans.grid_values.dot( |
| 49 | + self.gd.grid_values) |
| 50 | + self.gd = HyperhemisphericalGridDistribution(self.gd.get_grid(), grid_values_new) |
| 51 | + |
| 52 | + def get_point_estimate(self): |
| 53 | + gd_full_sphere = self.gd.to_full_sphere() |
| 54 | + p = BinghamDistribution.fit(gd_full_sphere.get_grid(), gd_full_sphere.grid_values.T / np.sum( |
| 55 | + gd_full_sphere.grid_values)).mode() |
| 56 | + if p[-1] < 0: |
| 57 | + p = -p |
| 58 | + return p |
| 59 | + |
| 60 | + @staticmethod |
| 61 | + def sys_noise_to_transition_density(d_sys, no_grid_points): |
| 62 | + if isinstance(d_sys, AbstractDistribution): |
| 63 | + if isinstance(d_sys, (HyperhemisphericalWatsonDistribution, WatsonDistribution)): |
| 64 | + def trans(xkk, xk): |
| 65 | + return np.array([2 * WatsonDistribution(xk[:, i], d_sys.kappa).pdf(xkk) for i in range(xk.shape[1])]).T |
| 66 | + |
| 67 | + elif (isinstance(d_sys, HypersphericalMixture) and len(d_sys.dists) == 2 and |
| 68 | + np.all(d_sys.w == 0.5) and np.array_equal(d_sys.dists[0].mu, -d_sys.dists[1].mu) and |
| 69 | + d_sys.dists[0].kappa == d_sys.dists[1].kappa): |
| 70 | + def trans(xkk, xk): |
| 71 | + return np.array([(VonMisesFisherDistribution(xk[:, i], d_sys.dists[0].kappa).pdf(xkk) + |
| 72 | + VonMisesFisherDistribution(xk[:, i], d_sys.dists[0].kappa).pdf(-xkk)) |
| 73 | + for i in range(xk.shape[1])]).T |
| 74 | + else: |
| 75 | + raise ValueError("Distribution not supported for predict identity. Must be zonal (rotationally symmetric around last dimension)") |
| 76 | + |
| 77 | + print("PredictIdentity:Inefficient - Using inefficient prediction. Consider precalculating the SdHalfCondSdHalfGridDistribution and using predictNonlinearViaTransitionDensity.") |
| 78 | + sd_half_cond_sd_half = SdHalfCondSdHalfGridDistribution.from_function(trans, no_grid_points, True, 'eq_point_set_symm', 2 * d_sys.dim) |
| 79 | + return sd_half_cond_sd_half |
| 80 | + |
| 81 | + else: |
| 82 | + raise TypeError("d_sys must be an instance of AbstractDistribution") |
| 83 | + |
0 commit comments