Skip to content

Commit 3c040e8

Browse files
authored
Merge pull request #4 from quant-sci/implement-ukf
feat: add Unscented Kalman Filter (UKF)
2 parents a1ff847 + 38d29bc commit 3c040e8

4 files changed

Lines changed: 776 additions & 1 deletion

File tree

src/dynaris/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818
Regression,
1919
Seasonal,
2020
)
21-
from dynaris.filters import ExtendedKalmanFilter, KalmanFilter, ekf_filter, kalman_filter
21+
from dynaris.filters import (
22+
ExtendedKalmanFilter,
23+
KalmanFilter,
24+
UnscentedKalmanFilter,
25+
ekf_filter,
26+
kalman_filter,
27+
ukf_filter,
28+
)
2229
from dynaris.smoothers import RTSSmoother, rts_smooth
2330

2431
__version__ = "0.1.0"
@@ -41,8 +48,10 @@
4148
"SmootherProtocol",
4249
"SmootherResult",
4350
"StateSpaceModel",
51+
"UnscentedKalmanFilter",
4452
"__version__",
4553
"ekf_filter",
4654
"kalman_filter",
4755
"rts_smooth",
56+
"ukf_filter",
4857
]

src/dynaris/filters/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
from dynaris.filters.ekf import ExtendedKalmanFilter, ekf_filter
44
from dynaris.filters.kalman import KalmanFilter, kalman_filter
5+
from dynaris.filters.ukf import UnscentedKalmanFilter, ukf_filter
56

67
__all__ = [
78
"ExtendedKalmanFilter",
89
"KalmanFilter",
10+
"UnscentedKalmanFilter",
911
"ekf_filter",
1012
"kalman_filter",
13+
"ukf_filter",
1114
]

src/dynaris/filters/ukf.py

Lines changed: 355 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,355 @@
1+
"""Unscented Kalman Filter for nonlinear state-space models.
2+
3+
Propagates sigma points through nonlinear transition and observation functions
4+
to capture the posterior mean and covariance without linearization. Uses the
5+
scaled unscented transform with configurable alpha, beta, kappa parameters.
6+
7+
References:
8+
Julier, S.J. and Uhlmann, J.K. (2004). "Unscented Filtering and
9+
Nonlinear Estimation." Proceedings of the IEEE, 92(3), 401-422.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
from typing import NamedTuple
15+
16+
import jax
17+
import jax.numpy as jnp
18+
from jax import Array
19+
20+
from dynaris.core.nonlinear import NonlinearSSM
21+
from dynaris.core.results import FilterResult
22+
from dynaris.core.types import GaussianState
23+
24+
# ---------------------------------------------------------------------------
25+
# Sigma-point weights
26+
# ---------------------------------------------------------------------------
27+
28+
29+
class SigmaWeights(NamedTuple):
    """Mean/covariance weights and scaling parameter for the unscented transform."""

    # Weights applied when recovering a mean from sigma points, shape (2n+1,).
    wm: Array
    # Weights applied when recovering a covariance, shape (2n+1,).
    wc: Array
    # Scaling parameter lambda, stored as a scalar array.
    lam: Array


def compute_weights(
    n: int,
    alpha: float = 1e-3,
    beta: float = 2.0,
    kappa: float = 0.0,
) -> SigmaWeights:
    """Compute sigma-point weights for the scaled unscented transform.

    Args:
        n: State dimension.
        alpha: Spread of sigma points around the mean (typically 1e-4 to 1).
        beta: Prior knowledge of distribution (2.0 is optimal for Gaussian).
        kappa: Secondary scaling parameter (typically 0 or 3-n).

    Returns:
        SigmaWeights with mean weights, covariance weights, and lambda.
    """
    lam = alpha**2 * (n + kappa) - n
    denom = n + lam  # common normalizer, equals alpha**2 * (n + kappa)

    # All 2n non-central points share one weight; only the central one differs.
    tail = jnp.full(2 * n, 0.5 / denom)
    wm = jnp.concatenate([jnp.array([lam / denom]), tail])
    # The central covariance weight carries the (1 - alpha^2 + beta) correction.
    wc = jnp.concatenate([jnp.array([lam / denom + (1.0 - alpha**2 + beta)]), tail])

    return SigmaWeights(wm=wm, wc=wc, lam=jnp.array(lam))
63+
64+
65+
# ---------------------------------------------------------------------------
66+
# Sigma-point generation
67+
# ---------------------------------------------------------------------------
68+
69+
70+
def sigma_points(state: GaussianState, lam: Array) -> Array:
    """Generate 2n+1 sigma points from a Gaussian state.

    Args:
        state: Gaussian belief with mean (n,) and covariance (n, n).
        lam: Scaling parameter lambda (scalar).

    Returns:
        Sigma points, shape (2n+1, n).
    """
    n = state.mean.shape[0]
    scaled_cov = (n + lam) * state.cov
    L = jnp.linalg.cholesky(scaled_cov)  # lower triangular, L @ L.T == scaled_cov

    # Offsets must be the COLUMNS of the Cholesky factor (i.e. the rows of
    # L.T): only then does the sum of offset outer products reproduce
    # L @ L.T == scaled_cov. Using the rows of L would instead yield
    # L.T @ L, which is wrong whenever the covariance is not diagonal.
    offsets = jnp.concatenate([
        jnp.zeros((1, n)),
        L.T,   # +column_i(L) as row i
        -L.T,  # -column_i(L)
    ], axis=0)  # (2n+1, n)

    return state.mean[None, :] + offsets
92+
93+
94+
# ---------------------------------------------------------------------------
95+
# Internal scan carry
96+
# ---------------------------------------------------------------------------
97+
98+
99+
class _ScanCarry(NamedTuple):
    """Loop-carried state threaded through the UKF forward scan."""

    # Current posterior belief after the most recent measurement update.
    filtered: GaussianState
    # Running sum of per-step log-likelihood contributions (scalar array).
    log_likelihood: Array  # scalar


class _ScanOutput(NamedTuple):
    """Per-timestep quantities stacked along axis 0 by jax.lax.scan."""

    # One-step-ahead (prior) mean, before the measurement update.
    predicted_mean: Array
    # One-step-ahead (prior) covariance.
    predicted_cov: Array
    # Posterior mean after incorporating the observation.
    filtered_mean: Array
    # Posterior covariance after incorporating the observation.
    filtered_cov: Array
111+
# ---------------------------------------------------------------------------
112+
# Pure-function predict and update steps
113+
# ---------------------------------------------------------------------------
114+
115+
116+
def predict(
    state: GaussianState,
    model: NonlinearSSM,
    weights: SigmaWeights,
) -> GaussianState:
    """UKF predict step (time update).

    Draws sigma points from *state*, pushes each one through the model's
    transition function, and reassembles the predicted Gaussian belief
    (process noise Q added to the covariance).
    """
    chi = sigma_points(state, weights.lam)  # (2n+1, n)

    # Push every sigma point through the transition function at once.
    chi_next = jax.vmap(model.f)(chi)  # (2n+1, n)

    # Weighted average of the propagated points gives the predicted mean.
    mu = jnp.sum(weights.wm[:, None] * chi_next, axis=0)  # (n,)

    # Weighted outer products of the deviations give the covariance.
    dev = chi_next - mu[None, :]  # (2n+1, n)
    outer = dev[:, :, None] * dev[:, None, :]
    P = jnp.sum(weights.wc[:, None, None] * outer, axis=0) + model.Q

    return GaussianState(mean=mu, cov=P)
140+
141+
142+
def update(
    predicted: GaussianState,
    observation: Array,
    model: NonlinearSSM,
    weights: SigmaWeights,
) -> tuple[GaussianState, Array]:
    """UKF update step (measurement update).

    Generates sigma points from the predicted state, propagates through the
    observation function, and computes the Kalman gain.

    Returns the filtered state and the log-likelihood contribution.
    Handles missing observations (NaN) by skipping the update.
    """
    y = observation
    pts = sigma_points(predicted, weights.lam)  # (2n+1, n)

    # Propagate through observation function
    pts_obs = jax.vmap(model.h)(pts)  # (2n+1, m)

    # Predicted observation mean
    y_pred = jnp.sum(weights.wm[:, None] * pts_obs, axis=0)  # (m,)

    # Replace NaN (missing) entries with the predicted observation BEFORE any
    # arithmetic uses them. The jnp.where masking at the end is enough for the
    # forward pass, but NaNs in the discarded branch still poison reverse-mode
    # gradients (the standard JAX "double-where" pitfall), so we sanitize here.
    obs_valid = ~jnp.any(jnp.isnan(y))
    y_safe = jnp.where(jnp.isnan(y), y_pred, y)

    # Innovation covariance S = sum wc * (y_diff)(y_diff)' + R
    y_diff = pts_obs - y_pred[None, :]  # (2n+1, m)
    S = jnp.sum(weights.wc[:, None, None] * (y_diff[:, :, None] * y_diff[:, None, :]), axis=0)
    S = S + model.R  # (m, m)

    # Cross-covariance P_xy = sum wc * (x_diff)(y_diff)'
    x_diff = pts - predicted.mean[None, :]  # (2n+1, n)
    P_xy = jnp.sum(weights.wc[:, None, None] * (x_diff[:, :, None] * y_diff[:, None, :]), axis=0)
    # (n, m)

    # Kalman gain K = P_xy @ S^{-1}, via a solve to avoid forming S^{-1}
    K = jnp.linalg.solve(S.T, P_xy.T).T  # (n, m)

    # Innovation (zero in any NaN dimension thanks to y_safe)
    e = y_safe - y_pred  # (m,)

    filtered_mean = predicted.mean + K @ e
    filtered_cov = predicted.cov - K @ S @ K.T

    # Log-likelihood: log N(e; 0, S)
    m = observation.shape[-1]
    log_det = jnp.linalg.slogdet(S)[1]
    mahal = e @ jnp.linalg.solve(S, e)
    ll = -0.5 * (m * jnp.log(2.0 * jnp.pi) + log_det + mahal)

    # Fall back to the prediction when the observation is missing
    filtered_mean = jnp.where(obs_valid, filtered_mean, predicted.mean)
    filtered_cov = jnp.where(obs_valid, filtered_cov, predicted.cov)
    ll = jnp.where(obs_valid, ll, 0.0)

    filtered = GaussianState(mean=filtered_mean, cov=filtered_cov)
    return filtered, ll
198+
199+
200+
# ---------------------------------------------------------------------------
201+
# Full forward pass via lax.scan
202+
# ---------------------------------------------------------------------------
203+
204+
205+
class UnscentedKalmanFilter:
    """Unscented Kalman Filter for nonlinear state-space models.

    Sigma points are pushed through the model's nonlinear transition and
    observation functions via the scaled unscented transform, so Jacobians
    are never required.

    Args:
        alpha: Spread of sigma points (default 1e-3).
        beta: Prior distribution parameter (default 2.0, optimal for Gaussian).
        kappa: Secondary scaling parameter (default 0.0).
    """

    def __init__(
        self,
        alpha: float = 1e-3,
        beta: float = 2.0,
        kappa: float = 0.0,
    ) -> None:
        self.alpha = alpha
        self.beta = beta
        self.kappa = kappa

    def _weights(self, model: NonlinearSSM) -> SigmaWeights:
        # Derive the transform weights for this model's state dimension.
        return compute_weights(model.state_dim, self.alpha, self.beta, self.kappa)

    def predict(self, state: GaussianState, model: NonlinearSSM) -> GaussianState:
        """UKF predict step (time update)."""
        return predict(state, model, self._weights(model))

    def update(
        self,
        predicted: GaussianState,
        observation: Array,
        model: NonlinearSSM,
    ) -> GaussianState:
        """UKF update step (measurement update)."""
        posterior, _unused_ll = update(predicted, observation, model, self._weights(model))
        return posterior

    def scan(
        self,
        model: NonlinearSSM,
        observations: Array,
        initial_state: GaussianState | None = None,
    ) -> FilterResult:
        """Run the full forward UKF over an observation sequence via jax.lax.scan."""
        return _ukf_filter_impl(
            model,
            observations,
            initial_state,
            self.alpha,
            self.beta,
            self.kappa,
        )
254+
255+
256+
def ukf_filter(
    model: NonlinearSSM,
    observations: Array,
    initial_state: GaussianState | None = None,
    *,
    alpha: float = 1e-3,
    beta: float = 2.0,
    kappa: float = 0.0,
) -> FilterResult:
    """Run the Unscented Kalman Filter forward pass over *observations*.

    Sigma points generated by the scaled unscented transform (with the
    given alpha/beta/kappa) are propagated through the model's nonlinear
    transition and observation functions at every step.

    Args:
        model: Nonlinear state-space model with callable f and h.
        observations: Observation sequence, shape (T, obs_dim).
        initial_state: Initial state belief. Defaults to diffuse prior.
        alpha: Spread of sigma points around the mean (default 1e-3).
        beta: Prior distribution parameter (default 2.0, optimal for Gaussian).
        kappa: Secondary scaling parameter (default 0.0).

    Returns:
        FilterResult with filtered/predicted states and log-likelihood.

    Example::

        import jax.numpy as jnp
        from dynaris.core.nonlinear import NonlinearSSM
        from dynaris.filters.ukf import ukf_filter

        model = NonlinearSSM(
            transition_fn=lambda x: x,
            observation_fn=lambda x: x,
            transition_cov=jnp.eye(1),
            observation_cov=jnp.eye(1),
            state_dim=1, obs_dim=1,
        )
        result = ukf_filter(model, observations)
    """
    # Thin convenience wrapper over the shared implementation.
    return _ukf_filter_impl(model, observations, initial_state, alpha, beta, kappa)
298+
299+
300+
def _ukf_filter_impl(
    model: NonlinearSSM,
    observations: Array,
    initial_state: GaussianState | None,
    alpha: float,
    beta: float,
    kappa: float,
) -> FilterResult:
    """Shared driver behind ukf_filter and UnscentedKalmanFilter.scan.

    The default initial state and the transform weights are resolved here,
    outside the JIT boundary, so the compiled scan sees concrete inputs.
    """
    state0 = model.initial_state() if initial_state is None else initial_state
    return _ukf_scan(
        model,
        observations,
        state0,
        compute_weights(model.state_dim, alpha, beta, kappa),
    )
314+
315+
316+
@jax.jit
def _ukf_scan(
    model: NonlinearSSM,
    observations: Array,
    initial_state: GaussianState,
    weights: SigmaWeights,
) -> FilterResult:
    """JIT-compiled scan loop for UKF."""

    def _step(carry: _ScanCarry, y_t: Array) -> tuple[_ScanCarry, _ScanOutput]:
        # Time update followed by measurement update for one observation.
        prior = predict(carry.filtered, model, weights)
        posterior, ll_t = update(prior, y_t, model, weights)
        next_carry = _ScanCarry(
            filtered=posterior,
            log_likelihood=carry.log_likelihood + ll_t,
        )
        emitted = _ScanOutput(
            predicted_mean=prior.mean,
            predicted_cov=prior.cov,
            filtered_mean=posterior.mean,
            filtered_cov=posterior.cov,
        )
        return next_carry, emitted

    carry0 = _ScanCarry(filtered=initial_state, log_likelihood=jnp.array(0.0))
    final_carry, stacked = jax.lax.scan(_step, carry0, observations)

    return FilterResult(
        filtered_states=stacked.filtered_mean,
        filtered_covariances=stacked.filtered_cov,
        predicted_states=stacked.predicted_mean,
        predicted_covariances=stacked.predicted_cov,
        log_likelihood=final_carry.log_likelihood,
        observations=observations,
    )

0 commit comments

Comments
 (0)