|
| 1 | +"""Unscented Kalman Filter for nonlinear state-space models. |
| 2 | +
|
| 3 | +Propagates sigma points through nonlinear transition and observation functions |
| 4 | +to capture the posterior mean and covariance without linearization. Uses the |
| 5 | +scaled unscented transform with configurable alpha, beta, kappa parameters. |
| 6 | +
|
| 7 | +References: |
| 8 | + Julier, S.J. and Uhlmann, J.K. (2004). "Unscented Filtering and |
| 9 | + Nonlinear Estimation." Proceedings of the IEEE, 92(3), 401-422. |
| 10 | +""" |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +from typing import NamedTuple |
| 15 | + |
| 16 | +import jax |
| 17 | +import jax.numpy as jnp |
| 18 | +from jax import Array |
| 19 | + |
| 20 | +from dynaris.core.nonlinear import NonlinearSSM |
| 21 | +from dynaris.core.results import FilterResult |
| 22 | +from dynaris.core.types import GaussianState |
| 23 | + |
| 24 | +# --------------------------------------------------------------------------- |
| 25 | +# Sigma-point weights |
| 26 | +# --------------------------------------------------------------------------- |
| 27 | + |
| 28 | + |
class SigmaWeights(NamedTuple):
    """Mean/covariance weights and scaling for the scaled unscented transform."""

    wm: Array  # mean weights, shape (2n+1,)
    wc: Array  # covariance weights, shape (2n+1,)
    lam: Array  # scaling parameter lambda (scalar array)


def compute_weights(
    n: int,
    alpha: float = 1e-3,
    beta: float = 2.0,
    kappa: float = 0.0,
) -> SigmaWeights:
    """Compute sigma-point weights for the scaled unscented transform.

    Args:
        n: State dimension.
        alpha: Spread of sigma points around the mean (typically 1e-4 to 1).
        beta: Prior knowledge of distribution (2.0 is optimal for Gaussian).
        kappa: Secondary scaling parameter (typically 0 or 3-n).

    Returns:
        SigmaWeights with mean weights, covariance weights, and lambda.
    """
    # lambda = alpha^2 (n + kappa) - n, per Julier & Uhlmann (2004).
    lam = alpha**2 * (n + kappa) - n
    denom = n + lam

    # All 2n non-central points share one weight; only the center differs.
    side = jnp.full(2 * n, 0.5 / denom)
    wm = jnp.concatenate([jnp.array([lam / denom]), side])
    # Center covariance weight carries the (1 - alpha^2 + beta) correction.
    wc = jnp.concatenate([jnp.array([lam / denom + (1.0 - alpha**2 + beta)]), side])

    return SigmaWeights(wm=wm, wc=wc, lam=jnp.array(lam))
| 63 | + |
| 64 | + |
| 65 | +# --------------------------------------------------------------------------- |
| 66 | +# Sigma-point generation |
| 67 | +# --------------------------------------------------------------------------- |
| 68 | + |
| 69 | + |
def sigma_points(state: GaussianState, lam: Array) -> Array:
    """Generate 2n+1 sigma points from a Gaussian state.

    Args:
        state: Gaussian belief with mean (n,) and covariance (n, n).
        lam: Scaling parameter lambda (scalar).

    Returns:
        Sigma points, shape (2n+1, n).
    """
    n = state.mean.shape[0]
    scaled_cov = (n + lam) * state.cov
    # Lower-triangular factor: L @ L.T == scaled_cov.
    L = jnp.linalg.cholesky(scaled_cov)  # (n, n)

    # BUG FIX: the offsets must be the *columns* of the Cholesky factor
    # (rows of L.T), since sum_j L[:, j] L[:, j]^T == L @ L.T == scaled_cov.
    # Using rows of L reconstructs L.T @ L instead, which differs whenever
    # the covariance is not diagonal.
    offsets = jnp.concatenate([
        jnp.zeros((1, n)),
        L.T,   # columns of L as positive offsets
        -L.T,  # columns of L as negative offsets
    ], axis=0)  # (2n+1, n)

    return state.mean[None, :] + offsets
| 92 | + |
| 93 | + |
| 94 | +# --------------------------------------------------------------------------- |
| 95 | +# Internal scan carry |
| 96 | +# --------------------------------------------------------------------------- |
| 97 | + |
| 98 | + |
class _ScanCarry(NamedTuple):
    """Loop-carried state for the forward filtering scan."""

    filtered: GaussianState  # belief after the most recent measurement update
    log_likelihood: Array  # scalar: running sum of per-step log-likelihoods
| 102 | + |
| 103 | + |
class _ScanOutput(NamedTuple):
    """Per-step outputs stacked by jax.lax.scan into (T, ...) arrays."""

    predicted_mean: Array  # prior mean before the measurement update
    predicted_cov: Array  # prior covariance before the measurement update
    filtered_mean: Array  # posterior mean after the measurement update
    filtered_cov: Array  # posterior covariance after the measurement update
| 109 | + |
| 110 | + |
| 111 | +# --------------------------------------------------------------------------- |
| 112 | +# Pure-function predict and update steps |
| 113 | +# --------------------------------------------------------------------------- |
| 114 | + |
| 115 | + |
def predict(
    state: GaussianState,
    model: NonlinearSSM,
    weights: SigmaWeights,
) -> GaussianState:
    """UKF predict step (time update).

    Draws sigma points from *state*, pushes each through the transition
    function, and refits a Gaussian to the weighted ensemble.

    Args:
        state: Current filtered belief.
        model: Nonlinear model providing transition ``f`` and process noise ``Q``.
        weights: Precomputed unscented-transform weights.

    Returns:
        Predicted (prior) belief for the next time step.
    """
    chis = sigma_points(state, weights.lam)  # (2n+1, n)

    # Apply the transition function to every sigma point.
    chis_next = jax.vmap(model.f)(chis)  # (2n+1, n)

    # Weighted ensemble mean.
    mean_pred = jnp.sum(weights.wm[:, None] * chis_next, axis=0)  # (n,)

    # Weighted sum of residual outer products, plus process noise.
    resid = chis_next - mean_pred[None, :]  # (2n+1, n)
    outer = resid[:, :, None] * resid[:, None, :]  # (2n+1, n, n)
    cov_pred = jnp.sum(weights.wc[:, None, None] * outer, axis=0) + model.Q

    return GaussianState(mean=mean_pred, cov=cov_pred)
| 140 | + |
| 141 | + |
def update(
    predicted: GaussianState,
    observation: Array,
    model: NonlinearSSM,
    weights: SigmaWeights,
) -> tuple[GaussianState, Array]:
    """UKF update step (measurement update).

    Generates sigma points from the predicted state, propagates through the
    observation function, and computes the Kalman gain.

    Returns the filtered state and the log-likelihood contribution.
    Handles missing observations (NaN) by skipping the update.
    """
    y = observation
    nan_mask = jnp.isnan(y)
    obs_valid = ~jnp.any(nan_mask)

    pts = sigma_points(predicted, weights.lam)  # (2n+1, n)

    # Propagate through observation function
    pts_obs = jax.vmap(model.h)(pts)  # (2n+1, m)

    # Predicted observation mean
    y_pred = jnp.sum(weights.wm[:, None] * pts_obs, axis=0)  # (m,)

    # Innovation covariance S = sum wc * (y_diff)(y_diff)' + R
    y_diff = pts_obs - y_pred[None, :]  # (2n+1, m)
    S = jnp.sum(weights.wc[:, None, None] * (y_diff[:, :, None] * y_diff[:, None, :]), axis=0)
    S = S + model.R  # (m, m)

    # Cross-covariance P_xy = sum wc * (x_diff)(y_diff)'
    x_diff = pts - predicted.mean[None, :]  # (2n+1, n)
    P_xy = jnp.sum(weights.wc[:, None, None] * (x_diff[:, :, None] * y_diff[:, None, :]), axis=0)
    # (n, m)

    # Kalman gain K = P_xy @ S^{-1}, via a solve instead of an explicit inverse
    K = jnp.linalg.solve(S.T, P_xy.T).T  # (n, m)

    # BUG FIX: sanitize NaN (missing) observation entries *before* any
    # arithmetic. Masking afterwards with jnp.where hides the NaN in the
    # value but still poisons reverse-mode gradients (the documented JAX
    # "NaN in where" pitfall: 0 * NaN = NaN in the cotangent). Substituting
    # y_pred makes the innovation exactly zero, so per-step values are
    # unchanged while gradients stay finite.
    y_safe = jnp.where(nan_mask, y_pred, y)
    e = y_safe - y_pred  # (m,) — zero at missing entries

    filtered_mean = predicted.mean + K @ e
    filtered_cov = predicted.cov - K @ S @ K.T

    # Log-likelihood: log N(e; 0, S)
    m = observation.shape[-1]
    log_det = jnp.linalg.slogdet(S)[1]
    mahal = e @ jnp.linalg.solve(S, e)
    ll = -0.5 * (m * jnp.log(2.0 * jnp.pi) + log_det + mahal)

    # With a fully/partially missing observation, fall back to the prior and
    # contribute nothing to the total log-likelihood.
    filtered_mean = jnp.where(obs_valid, filtered_mean, predicted.mean)
    filtered_cov = jnp.where(obs_valid, filtered_cov, predicted.cov)
    ll = jnp.where(obs_valid, ll, 0.0)

    filtered = GaussianState(mean=filtered_mean, cov=filtered_cov)
    return filtered, ll
| 198 | + |
| 199 | + |
| 200 | +# --------------------------------------------------------------------------- |
| 201 | +# Full forward pass via lax.scan |
| 202 | +# --------------------------------------------------------------------------- |
| 203 | + |
| 204 | + |
class UnscentedKalmanFilter:
    """Unscented Kalman Filter for nonlinear state-space models.

    Avoids Jacobian computation entirely: the scaled unscented transform
    pushes a deterministic set of sigma points through the nonlinear
    transition and observation functions and refits a Gaussian to the result.

    Args:
        alpha: Spread of sigma points (default 1e-3).
        beta: Prior distribution parameter (default 2.0, optimal for Gaussian).
        kappa: Secondary scaling parameter (default 0.0).
    """

    def __init__(
        self,
        alpha: float = 1e-3,
        beta: float = 2.0,
        kappa: float = 0.0,
    ) -> None:
        self.alpha = alpha
        self.beta = beta
        self.kappa = kappa

    def _weights(self, model: NonlinearSSM) -> SigmaWeights:
        # Weights depend only on the state dimension and the UT parameters.
        return compute_weights(model.state_dim, self.alpha, self.beta, self.kappa)

    def predict(self, state: GaussianState, model: NonlinearSSM) -> GaussianState:
        """UKF predict step (time update)."""
        return predict(state, model, self._weights(model))

    def update(
        self,
        predicted: GaussianState,
        observation: Array,
        model: NonlinearSSM,
    ) -> GaussianState:
        """UKF update step (measurement update); drops the log-likelihood term."""
        filtered, _unused_ll = update(predicted, observation, model, self._weights(model))
        return filtered

    def scan(
        self,
        model: NonlinearSSM,
        observations: Array,
        initial_state: GaussianState | None = None,
    ) -> FilterResult:
        """Run the full forward UKF via jax.lax.scan."""
        return _ukf_filter_impl(
            model,
            observations,
            initial_state,
            self.alpha,
            self.beta,
            self.kappa,
        )
| 254 | + |
| 255 | + |
def ukf_filter(
    model: NonlinearSSM,
    observations: Array,
    initial_state: GaussianState | None = None,
    *,
    alpha: float = 1e-3,
    beta: float = 2.0,
    kappa: float = 0.0,
) -> FilterResult:
    """Run the forward Unscented Kalman Filter over an observation sequence.

    Propagates sigma points of the scaled unscented transform through the
    model's nonlinear transition and observation functions at every step.

    Args:
        model: Nonlinear state-space model with callable f and h.
        observations: Observation sequence, shape (T, obs_dim).
        initial_state: Initial state belief. Defaults to diffuse prior.
        alpha: Spread of sigma points around the mean (default 1e-3).
        beta: Prior distribution parameter (default 2.0, optimal for Gaussian).
        kappa: Secondary scaling parameter (default 0.0).

    Returns:
        FilterResult with filtered/predicted states and log-likelihood.

    Example::

        import jax.numpy as jnp
        from dynaris.core.nonlinear import NonlinearSSM
        from dynaris.filters.ukf import ukf_filter

        model = NonlinearSSM(
            transition_fn=lambda x: x,
            observation_fn=lambda x: x,
            transition_cov=jnp.eye(1),
            observation_cov=jnp.eye(1),
            state_dim=1, obs_dim=1,
        )
        result = ukf_filter(model, observations)
    """
    # Thin convenience wrapper around the shared implementation.
    return _ukf_filter_impl(
        model, observations, initial_state, alpha, beta, kappa,
    )
| 298 | + |
| 299 | + |
def _ukf_filter_impl(
    model: NonlinearSSM,
    observations: Array,
    initial_state: GaussianState | None,
    alpha: float,
    beta: float,
    kappa: float,
) -> FilterResult:
    """Internal implementation — weights computed before the JIT boundary.

    Resolving the default prior and the sigma-point weights here keeps the
    jitted scan free of Python-level branching.
    """
    state0 = model.initial_state() if initial_state is None else initial_state
    w = compute_weights(model.state_dim, alpha, beta, kappa)
    return _ukf_scan(model, observations, state0, w)
| 314 | + |
| 315 | + |
@jax.jit
def _ukf_scan(
    model: NonlinearSSM,
    observations: Array,
    initial_state: GaussianState,
    weights: SigmaWeights,
) -> FilterResult:
    """JIT-compiled scan loop for UKF.

    NOTE(review): jitting with ``model`` as a traced argument assumes
    NonlinearSSM (which holds callables f/h) is a registered JAX pytree or
    otherwise hashable-static — confirm against dynaris.core.nonlinear.
    """

    def step(carry: _ScanCarry, obs: Array) -> tuple[_ScanCarry, _ScanOutput]:
        # One predict/update cycle; accumulate the log-likelihood in the carry.
        prior = predict(carry.filtered, model, weights)
        posterior, ll_t = update(prior, obs, model, weights)
        next_carry = _ScanCarry(
            filtered=posterior,
            log_likelihood=carry.log_likelihood + ll_t,
        )
        per_step = _ScanOutput(
            predicted_mean=prior.mean,
            predicted_cov=prior.cov,
            filtered_mean=posterior.mean,
            filtered_cov=posterior.cov,
        )
        return next_carry, per_step

    carry0 = _ScanCarry(filtered=initial_state, log_likelihood=jnp.array(0.0))
    last, history = jax.lax.scan(step, carry0, observations)

    return FilterResult(
        filtered_states=history.filtered_mean,
        filtered_covariances=history.filtered_cov,
        predicted_states=history.predicted_mean,
        predicted_covariances=history.predicted_cov,
        log_likelihood=last.log_likelihood,
        observations=observations,
    )
0 commit comments