Skip to content

Commit 65ab292

Browse files
committed
Added HypersphericalParticleFilter
1 parent d277a66 commit 65ab292

3 files changed

Lines changed: 99 additions & 0 deletions

File tree

pyrecest/distributions/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@
147147
from pyrecest.distributions.hypersphere_subset.custom_hyperhemispherical_distribution import (
148148
CustomHyperhemisphericalDistribution,
149149
)
150+
from pyrecest.distributions.hypersphere_subset.hyperspherical_dirac_distribution import (
151+
HypersphericalDiracDistribution,
152+
)
150153
from pyrecest.distributions.hypersphere_subset.hyperhemispherical_dirac_distribution import (
151154
HyperhemisphericalDiracDistribution,
152155
)
@@ -254,6 +257,7 @@
254257
]
255258

256259
__all__ = aliases + [
260+
"HypersphericalDiracDistribution",
257261
"GeneralizedKSineSkewedVonMisesDistribution",
258262
"AbstractBoundedDomainDistribution",
259263
"AbstractBoundedNonPeriodicDistribution",
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from .abstract_hyperspherical_filter import AbstractHypersphericalFilter
2+
from pyrecest.distributions import VonMisesDistribution, AbstractHypersphericalDistribution
3+
from pyrecest.distributions.hypersphere_subset.hyperspherical_dirac_distribution import HypersphericalDiracDistribution
4+
import warnings
5+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
6+
from pyrecest.backend import tile, eye, sum, linalg
7+
8+
class HypersphericalParticleFilter(AbstractHypersphericalFilter):
9+
def __init__(self, n_particles, dim):
10+
self.wd = HypersphericalDiracDistribution(tile(eye(dim, 1), (1, n_particles)))
11+
12+
@property
13+
def filter_state(self):
14+
return self._filter_state
15+
16+
@filter_state.setter
17+
def filter_state(self, new_state):
18+
"""Sets the filter state to new_state if it is a type of AbstractHypersphericalDistribution."""
19+
if not isinstance(new_state, AbstractHypersphericalDistribution):
20+
raise TypeError("new_state must be an instance of AbstractHypersphericalDistribution")
21+
if not isinstance(new_state, HypersphericalDiracDistribution):
22+
new_state = HypersphericalDiracDistribution(new_state.sample(self._filter_state.d.shape[0]))
23+
self._filter_state = new_state
24+
25+
def predict_identity(self, noise_distribution):
26+
self.predict_nonlinear(lambda x: x, noise_distribution)
27+
28+
def update_identity(self, noise_distribution, z):
29+
assert isinstance(noise_distribution, VonMisesDistribution), "Currently, only VMF distributed noise terms are supported."
30+
if z is not None:
31+
noise_distribution.mu = z
32+
warnings.warn("Warning: update_identity: mu of noise_distribution is replaced by measurement...")
33+
34+
def update_nonlinear(self, likelihood, z=None):
35+
if z is None:
36+
self.wd = self.wd.reweigh(likelihood)
37+
else:
38+
self.wd = self.wd.reweigh(lambda x: likelihood(z, x))
39+
40+
def get_estimate_mean(self):
41+
vec_sum = sum(self.wd.d * tile(self.wd.w, (self.dim, 1)), axis=1)
42+
mean = vec_sum / linalg.norm(vec_sum)
43+
return mean
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import unittest
2+
import numpy as np
3+
import warnings
4+
from pyrecest.distributions import VonMisesFisherDistribution, HypersphericalDiracDistribution
5+
# from pyrecest.distributions
6+
from pyrecest.filters.hyperspherical_particle_filter import HypersphericalParticleFilter
7+
from pyrecest.backend import array, allclose, ones, linalg
8+
9+
class HypersphericalParticleFilterTest(unittest.TestCase):
10+
def test3D(self):
11+
np.random.seed(1)
12+
n_samples = 20000
13+
vmf_init = VonMisesFisherDistribution(array([1.0, 0.0, 0.0]), 10)
14+
vmf_sys = VonMisesFisherDistribution(array([0.0, 0.0, 1.0]), 10)
15+
vmf_meas = VonMisesFisherDistribution(array([0.0, 0.0, 1.0]), 3)
16+
17+
hpf = HypersphericalParticleFilter(n_samples, 3)
18+
hpf.filter_state = HypersphericalDiracDistribution(vmf_init.sample(n_samples))
19+
20+
est = hpf.get_estimate_mean()
21+
self.assertIsInstance(est, np.ndarray)
22+
vmf_init_mean = vmf_init.moment()
23+
hpf_mean = est
24+
self.assertTrue(np.allclose(vmf_init_mean, hpf_mean, atol=1e-2))
25+
26+
# Prediction step
27+
hpf.predict_identity(vmf_sys)
28+
29+
# Update steps
30+
with warnings.catch_warnings():
31+
warnings.simplefilter("ignore")
32+
hpf.update_identity(vmf_meas, array([0.0, 0.0, 1.0]))
33+
hpf.update_identity(vmf_meas, array([0.0, 0.0, 1.0]))
34+
hpf.update_identity(vmf_meas, array([0.0, 0.0, 1.0]))
35+
36+
hpf_est_mean = hpf.get_estimate_mean()
37+
self.assertEqual(hpf_est_mean.shape, (3,))
38+
expected_mean = array([0.0, 0.0, 1.0])
39+
self.assertTrue(allclose(hpf_est_mean, expected_mean, atol=0.05))
40+
41+
# Reset state
42+
hpf.filter_state = HypersphericalDiracDistribution(vmf_init.sample(n_samples))
43+
44+
# Nonlinear update with a function that returns ones
45+
def f(z, x):
46+
return ones(x.shape[1])
47+
48+
z = 3
49+
est = hpf.get_estimate_mean()
50+
self.assertAlmostEqual(linalg.norm(est), 1, delta=1e-10)
51+
hpf.update_nonlinear(f, z)
52+
self.assertTrue(allclose(est, hpf.get_estimate_mean(), atol=1e-2))

0 commit comments

Comments
 (0)