Skip to content

Commit d58c919

Browse files
committed
Added HyperhemisphericalGridFilter
1 parent a15ecdc commit d58c919

1 file changed

Lines changed: 83 additions & 0 deletions

File tree

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import numpy as np
2+
import warnings
3+
from .abstract_grid_filter import AbstractGridFilter
4+
from .abstract_hyperhemispherical_filter import AbstractHyperhemisphericalFilter
5+
from pyrecest.distributions.hypersphere_subset.hyperhemispherical_grid_distribution import HyperhemisphericalGridDistribution
6+
from pyrecest.distributions.conditional.sd_half_cond_sd_half_grid_distribution import SdHalfCondSdHalfGridDistribution
7+
from pyrecest.distributions import BinghamDistribution
8+
from pyrecest.distributions import WatsonDistribution
9+
from pyrecest.distributions import VonMisesFisherDistribution
10+
from pyrecest.distributions import HypersphericalMixture
11+
from pyrecest.distributions import HyperhemisphericalUniformDistribution
12+
from pyrecest.distributions import AbstractHyperhemisphericalDistribution
13+
from pyrecest.distributions import HyperhemisphericalWatsonDistribution
14+
15+
class HyperhemisphericalGridFilter(AbstractGridFilter, AbstractHyperhemisphericalFilter):
16+
def __init__(self, no_of_coefficients, dim, grid_type='eq_point_set_symm'):
17+
self.gd = HyperhemisphericalGridDistribution.from_distribution(
18+
HyperhemisphericalUniformDistribution(dim), no_of_coefficients, grid_type)
19+
20+
def set_state(self, new_state):
21+
assert self.dim == new_state.dim
22+
assert isinstance(new_state, AbstractHyperhemisphericalDistribution)
23+
self.gd = new_state
24+
25+
def predict_identity(self, d_sys):
26+
assert isinstance(d_sys, AbstractHyperhemisphericalDistribution)
27+
sd_half_cond_sd_half = HyperhemisphericalGridFilter.sys_noise_to_transition_density(
28+
d_sys, self.gd.grid_values.shape[0])
29+
self.predict_nonlinear_via_transition_density(sd_half_cond_sd_half)
30+
31+
def update_identity(self, meas_noise, z=None):
32+
assert isinstance(meas_noise, AbstractHyperhemisphericalDistribution)
33+
if not z==None:
34+
measNoise = measNoise.setMode(z)
35+
curr_grid = self.gd.get_grid()
36+
self.gd = self.gd.multiply(HyperhemisphericalGridDistribution(curr_grid, 2 * meas_noise.pdf(curr_grid).T))
37+
38+
def update_nonlinear(self, likelihood, z):
39+
self.gd.grid_values = self.gd.grid_values * likelihood(z, self.gd.get_grid()).T
40+
with warnings.catch_warnings():
41+
warnings.simplefilter("ignore", category=RuntimeWarning)
42+
self.gd = self.gd.normalize()
43+
44+
def predict_nonlinear_via_transition_density(self, f_trans):
45+
assert np.array_equal(self.gd.get_grid(), f_trans.get_grid()), \
46+
"fTrans is using an incompatible grid."
47+
self.gd = self.gd.normalize()
48+
grid_values_new = self.gd.get_manifold_size() / self.gd.grid_values.shape[0] * f_trans.grid_values.dot(
49+
self.gd.grid_values)
50+
self.gd = HyperhemisphericalGridDistribution(self.gd.get_grid(), grid_values_new)
51+
52+
def get_point_estimate(self):
53+
gd_full_sphere = self.gd.to_full_sphere()
54+
p = BinghamDistribution.fit(gd_full_sphere.get_grid(), gd_full_sphere.grid_values.T / np.sum(
55+
gd_full_sphere.grid_values)).mode()
56+
if p[-1] < 0:
57+
p = -p
58+
return p
59+
60+
@staticmethod
61+
def sys_noise_to_transition_density(d_sys, no_grid_points):
62+
if isinstance(d_sys, AbstractDistribution):
63+
if isinstance(d_sys, (HyperhemisphericalWatsonDistribution, WatsonDistribution)):
64+
def trans(xkk, xk):
65+
return np.array([2 * WatsonDistribution(xk[:, i], d_sys.kappa).pdf(xkk) for i in range(xk.shape[1])]).T
66+
67+
elif (isinstance(d_sys, HypersphericalMixture) and len(d_sys.dists) == 2 and
68+
np.all(d_sys.w == 0.5) and np.array_equal(d_sys.dists[0].mu, -d_sys.dists[1].mu) and
69+
d_sys.dists[0].kappa == d_sys.dists[1].kappa):
70+
def trans(xkk, xk):
71+
return np.array([(VonMisesFisherDistribution(xk[:, i], d_sys.dists[0].kappa).pdf(xkk) +
72+
VonMisesFisherDistribution(xk[:, i], d_sys.dists[0].kappa).pdf(-xkk))
73+
for i in range(xk.shape[1])]).T
74+
else:
75+
raise ValueError("Distribution not supported for predict identity. Must be zonal (rotationally symmetric around last dimension)")
76+
77+
print("PredictIdentity:Inefficient - Using inefficient prediction. Consider precalculating the SdHalfCondSdHalfGridDistribution and using predictNonlinearViaTransitionDensity.")
78+
sd_half_cond_sd_half = SdHalfCondSdHalfGridDistribution.from_function(trans, no_grid_points, True, 'eq_point_set_symm', 2 * d_sys.dim)
79+
return sd_half_cond_sd_half
80+
81+
else:
82+
raise TypeError("d_sys must be an instance of AbstractDistribution")
83+

0 commit comments

Comments
 (0)