1+ from bayesian_filters .kalman import UnscentedKalmanFilter as BayesianFiltersUKF
2+ from bayesian_filters .kalman import MerweScaledSigmaPoints
3+
4+ from .manifold_mixins import EuclideanFilterMixin
5+ from pyrecest .distributions import GaussianDistribution
6+
7+ # pylint: disable=redefined-builtin,no-name-in-module,no-member
8+ from pyrecest .backend import atleast_1d
9+ import numbers
10+ from typing import Callable
11+
12+ class UnscentedKalmanFilter (EuclideanFilterMixin ):
13+ def __init__ (self , initial_state : GaussianDistribution | tuple ,
14+ dt : numbers .Real = 1 ,
15+ fx : Callable = lambda x , dt : x ,
16+ hx : Callable = lambda x : x ,
17+ points = None ,
18+ ):
19+ if isinstance (initial_state , GaussianDistribution ):
20+ dim_x = initial_state .dim
21+ elif isinstance (initial_state , tuple ) and len (initial_state ) == 2 :
22+ dim_x = len (initial_state [0 ])
23+ else :
24+ raise ValueError (
25+ "initial_state must be a GaussianDistribution or a tuple of (mean, covariance)"
26+ )
27+
28+ if points is None :
29+ # Standard settings for Gaussian approximations
30+ points = MerweScaledSigmaPoints (dim_x , alpha = 0.001 , beta = 2 , kappa = 0 )
31+
32+ # Initialize filterpy UKF
33+ # Note: We initialize dim_z to dim_x as a default, but this can be
34+ # overridden dynamically in update() by providing R and hx
35+ self ._filter_state = BayesianFiltersUKF (dim_x = dim_x , dim_z = dim_x , dt = dt , hx = hx , fx = fx , points = points )
36+
37+ # Set initial state
38+ if isinstance (initial_state , GaussianDistribution ):
39+ self ._filter_state .x = initial_state .mu
40+ self ._filter_state .P = initial_state .C
41+ else :
42+ self ._filter_state .x = initial_state [0 ]
43+ self ._filter_state .P = initial_state [1 ]
44+
45+ self ._filter_state .x_prior = self ._filter_state .x .copy ()
46+ self ._filter_state .P_prior = self ._filter_state .P .copy ()
47+
48+ @property
49+ def filter_state (
50+ self ,
51+ ) -> (
52+ GaussianDistribution | tuple
53+ ): # It can only return GaussianDistribution, this just serves to prevent mypy linter warnings
54+ return GaussianDistribution (self ._filter_state .x , self ._filter_state .P )
55+
56+ @filter_state .setter
57+ def filter_state (
58+ self , new_state : GaussianDistribution | tuple
59+ ):
60+ """
61+ Set the filter state.
62+
63+ :param new_state: Provide GaussianDistribution or mean and covariance as state.
64+ """
65+ if isinstance (new_state , GaussianDistribution ):
66+ self ._filter_state .x = new_state .mu
67+ self ._filter_state .P = new_state .C
68+ elif isinstance (new_state , tuple ) and len (new_state ) == 2 :
69+ self ._filter_state .x = new_state [0 ]
70+ self ._filter_state .P = new_state [1 ]
71+ else :
72+ raise ValueError (
73+ "new_state must be a GaussianDistribution or a tuple of (mean, covariance)"
74+ )
75+
76+ def predict_nonlinear (self , fx , sys_noise_cov , dt = None , ** fx_args ):
77+ """
78+ :param fx: Function with signature fx(x, dt, **fx_args)
79+ :param sys_noise_cov: Process noise matrix Q
80+ """
81+ # FIX: FilterPy UKF uses member variable Q, not an argument to predict()
82+ self ._filter_state .Q = sys_noise_cov
83+ self ._filter_state .predict (dt = dt , fx = fx , ** fx_args )
84+
85+ def update_nonlinear (self , measurement , hx , cov_mat_meas , ** hx_args ):
86+ # Update allows R to be passed as argument
87+ self ._filter_state .update (z = measurement , hx = hx , R = cov_mat_meas , ** hx_args )
88+
89+ def predict_identity (self , sys_noise_cov , dt = None ):
90+ self ._filter_state .Q = sys_noise_cov
91+ def fx (x , dt ):
92+ return x
93+
94+ self ._filter_state .predict (dt = dt , fx = fx )
95+
96+ def predict_linear (self , system_matrix , sys_noise_cov , sys_input = None , dt = None ):
97+ self ._filter_state .Q = sys_noise_cov
98+
99+ if sys_input is None :
100+ # F * x
101+ fx = lambda x , _dt : system_matrix @ x
102+ else :
103+ # F * x + B * u (Assuming B is Identity based on your previous code)
104+ # ideally sys_input should already be B*u or you should pass B explicitly
105+ fx = lambda x , _dt : system_matrix @ x + sys_input
106+
107+ self ._filter_state .predict (dt = dt , fx = fx )
108+
109+ def update_identity (self , meas , meas_cov ):
110+ # h(x) = x
111+ hx = lambda x : x
112+ self ._filter_state .update (z = atleast_1d (meas ), R = meas_cov , hx = hx )
113+
114+ def update_linear (self , measurement , measurement_matrix , cov_mat_meas ):
115+ # h(x) = H * x
116+ # Note: hx must return a 1D array for filterpy
117+ hx = lambda x : measurement_matrix @ x
118+
119+ self ._filter_state .update (z = measurement , R = cov_mat_meas , hx = hx )
120+
121+ def get_estimate (self ):
122+ return GaussianDistribution (self ._filter_state .x , self ._filter_state .P )
0 commit comments