Skip to content

Commit 4560400

Browse files
committed
Added UnscentedKalmanFilter
1 parent a7d8a9c commit 4560400

2 files changed

Lines changed: 163 additions & 0 deletions

File tree

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from bayesian_filters.kalman import UnscentedKalmanFilter as BayesianFiltersUKF
2+
from bayesian_filters.kalman import MerweScaledSigmaPoints
3+
4+
from .manifold_mixins import EuclideanFilterMixin
5+
from pyrecest.distributions import GaussianDistribution
6+
7+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
8+
from pyrecest.backend import atleast_1d
9+
import numbers
10+
from typing import Callable
11+
12+
class UnscentedKalmanFilter(EuclideanFilterMixin):
13+
def __init__(self, initial_state: GaussianDistribution | tuple,
14+
dt: numbers.Real = 1,
15+
fx: Callable = lambda x, dt: x,
16+
hx: Callable = lambda x: x,
17+
points=None,
18+
):
19+
if isinstance(initial_state, GaussianDistribution):
20+
dim_x = initial_state.dim
21+
elif isinstance(initial_state, tuple) and len(initial_state) == 2:
22+
dim_x = len(initial_state[0])
23+
else:
24+
raise ValueError(
25+
"initial_state must be a GaussianDistribution or a tuple of (mean, covariance)"
26+
)
27+
28+
if points is None:
29+
# Standard settings for Gaussian approximations
30+
points = MerweScaledSigmaPoints(dim_x, alpha=0.001, beta=2, kappa=0)
31+
32+
# Initialize filterpy UKF
33+
# Note: We initialize dim_z to dim_x as a default, but this can be
34+
# overridden dynamically in update() by providing R and hx
35+
self._filter_state = BayesianFiltersUKF(dim_x=dim_x, dim_z=dim_x, dt=dt, hx=hx, fx=fx, points=points)
36+
37+
# Set initial state
38+
if isinstance(initial_state, GaussianDistribution):
39+
self._filter_state.x = initial_state.mu
40+
self._filter_state.P = initial_state.C
41+
else:
42+
self._filter_state.x = initial_state[0]
43+
self._filter_state.P = initial_state[1]
44+
45+
self._filter_state.x_prior = self._filter_state.x.copy()
46+
self._filter_state.P_prior = self._filter_state.P.copy()
47+
48+
@property
49+
def filter_state(
50+
self,
51+
) -> (
52+
GaussianDistribution | tuple
53+
): # It can only return GaussianDistribution, this just serves to prevent mypy linter warnings
54+
return GaussianDistribution(self._filter_state.x, self._filter_state.P)
55+
56+
@filter_state.setter
57+
def filter_state(
58+
self, new_state: GaussianDistribution | tuple
59+
):
60+
"""
61+
Set the filter state.
62+
63+
:param new_state: Provide GaussianDistribution or mean and covariance as state.
64+
"""
65+
if isinstance(new_state, GaussianDistribution):
66+
self._filter_state.x = new_state.mu
67+
self._filter_state.P = new_state.C
68+
elif isinstance(new_state, tuple) and len(new_state) == 2:
69+
self._filter_state.x = new_state[0]
70+
self._filter_state.P = new_state[1]
71+
else:
72+
raise ValueError(
73+
"new_state must be a GaussianDistribution or a tuple of (mean, covariance)"
74+
)
75+
76+
def predict_nonlinear(self, fx, sys_noise_cov, dt=None, **fx_args):
77+
"""
78+
:param fx: Function with signature fx(x, dt, **fx_args)
79+
:param sys_noise_cov: Process noise matrix Q
80+
"""
81+
# FIX: FilterPy UKF uses member variable Q, not an argument to predict()
82+
self._filter_state.Q = sys_noise_cov
83+
self._filter_state.predict(dt=dt, fx=fx, **fx_args)
84+
85+
def update_nonlinear(self, measurement, hx, cov_mat_meas, **hx_args):
86+
# Update allows R to be passed as argument
87+
self._filter_state.update(z=measurement, hx=hx, R=cov_mat_meas, **hx_args)
88+
89+
def predict_identity(self, sys_noise_cov, dt=None):
90+
self._filter_state.Q = sys_noise_cov
91+
def fx(x, dt):
92+
return x
93+
94+
self._filter_state.predict(dt=dt, fx=fx)
95+
96+
def predict_linear(self, system_matrix, sys_noise_cov, sys_input=None, dt=None):
97+
self._filter_state.Q = sys_noise_cov
98+
99+
if sys_input is None:
100+
# F * x
101+
fx = lambda x, _dt: system_matrix @ x
102+
else:
103+
# F * x + B * u (Assuming B is Identity based on your previous code)
104+
# ideally sys_input should already be B*u or you should pass B explicitly
105+
fx = lambda x, _dt: system_matrix @ x + sys_input
106+
107+
self._filter_state.predict(dt=dt, fx=fx)
108+
109+
def update_identity(self, meas, meas_cov):
110+
# h(x) = x
111+
hx = lambda x: x
112+
self._filter_state.update(z=atleast_1d(meas), R=meas_cov, hx=hx)
113+
114+
def update_linear(self, measurement, measurement_matrix, cov_mat_meas):
115+
# h(x) = H * x
116+
# Note: hx must return a 1D array for filterpy
117+
hx = lambda x: measurement_matrix @ x
118+
119+
self._filter_state.update(z=measurement, R=cov_mat_meas, hx=hx)
120+
121+
def get_estimate(self):
122+
return GaussianDistribution(self._filter_state.x, self._filter_state.P)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import unittest
2+
from pyrecest.filters.unscented_kalman_filter import UnscentedKalmanFilter
3+
from pyrecest.distributions import GaussianDistribution
4+
import copy
5+
from pyrecest.backend import array, diag, eye, allclose
6+
import numpy.testing as npt
7+
8+
class UnscentedKalmanFilterTest(unittest.TestCase):
9+
10+
def test_initialization(self):
11+
filter_custom = UnscentedKalmanFilter(GaussianDistribution(array([1.0]), array([[10000.0]])))
12+
npt.assert_allclose(filter_custom.get_point_estimate(), 1.0)
13+
14+
def test_initialization_gauss(self):
15+
filter_custom = UnscentedKalmanFilter(GaussianDistribution(array([4.0]), array([[10000.0]])))
16+
npt.assert_allclose(filter_custom.get_point_estimate(), 4)
17+
18+
def test_update_linear_1d(self):
19+
kf = UnscentedKalmanFilter(GaussianDistribution(array([0.0]), array([[1.0]])))
20+
kf.update_identity(array([3.0]), array([[1.0]]))
21+
npt.assert_allclose(kf.get_point_estimate(), 1.5)
22+
23+
def test_update_linear_2d(self):
24+
filter_add = UnscentedKalmanFilter(GaussianDistribution(array([0.0, 1.0]), diag(array([1.0, 2.0]))))
25+
filter_id = copy.deepcopy(filter_add)
26+
gauss = GaussianDistribution(array([1.0, 0.0]), diag(array([2.0, 1.0])))
27+
filter_add.update_linear(gauss.mu, eye(2), gauss.C)
28+
filter_id.update_identity(gauss.mu, gauss.C)
29+
self.assertTrue(allclose(filter_add.get_point_estimate(), filter_id.get_point_estimate()))
30+
self.assertTrue(allclose(filter_add.filter_state.covariance(), filter_id.filter_state.covariance()))
31+
32+
def test_predict_linear_2d(self):
33+
kf = UnscentedKalmanFilter(GaussianDistribution(array([0.0, 1.0]), diag(array(array([1.0, 2.0])))))
34+
kf.predict_linear(diag(array([1.0, 2.0])), diag(array([2.0, 1.0])))
35+
self.assertTrue(allclose(kf.get_point_estimate(), array([0.0, 2.0])))
36+
self.assertTrue(allclose(kf.filter_state.covariance(), diag(array([3.0, 9.0]))))
37+
kf.predict_linear(diag(array([1.0, 2.0])), diag(array([2.0, 1.0])), array([2.0, -2.0]))
38+
self.assertTrue(allclose(kf.get_point_estimate(), array([2.0, 2.0])))
39+
40+
if __name__ == "__main__":
41+
unittest.main()

0 commit comments

Comments
 (0)