|
1 | 1 | """Kalman filter operations for state estimation using the square-root form.""" |
2 | 2 |
|
3 | | -from collections.abc import Callable |
| 3 | +from collections.abc import Callable, Mapping |
4 | 4 |
|
5 | 5 | import jax |
6 | 6 | import jax.numpy as jnp |
7 | 7 | from jax import Array |
8 | 8 |
|
9 | 9 | from skillmodels.qr import qr_gpu |
10 | 10 |
|
LINEAR_FUNCTION_NAMES = frozenset({"linear", "constant"})


def is_all_linear(function_names: Mapping[str, str]) -> bool:
    """Check whether every factor's transition function is linear or constant.

    Args:
        function_names: Maps each factor name to the name of its transition
            function.

    Returns:
        True if no factor uses a transition function outside of
        ``LINEAR_FUNCTION_NAMES``.

    """
    nonlinear = set(function_names.values()) - LINEAR_FUNCTION_NAMES
    return not nonlinear
| 17 | + |
| 18 | + |
11 | 19 | array_qr_jax = ( |
12 | 20 | jax.vmap(jax.vmap(qr_gpu)) |
13 | 21 | if jax.default_backend() == "gpu" |
@@ -228,6 +236,141 @@ def kalman_predict( |
228 | 236 | return predicted_states, predicted_covs |
229 | 237 |
|
230 | 238 |
|
def linear_kalman_predict(
    transition_func: Callable | None,  # noqa: ARG001
    states: Array,
    upper_chols: Array,
    sigma_scaling_factor: float,  # noqa: ARG001
    sigma_weights: Array,  # noqa: ARG001
    trans_coeffs: dict[str, Array],
    shock_sds: Array,
    anchoring_scaling_factors: Array,
    anchoring_constants: Array,
    observed_factors: Array,
    *,
    latent_factors: tuple[str, ...],
    constant_factor_indices: frozenset[int],
    n_all_factors: int,
) -> tuple[Array, Array]:
    """Kalman predict step specialized to purely linear transition equations.

    Skips sigma point generation and transformation entirely, which makes it
    much cheaper than the unscented predict. It is only valid when every
    factor has a `linear` or `constant` transition function.

    `transition_func`, `sigma_scaling_factor` and `sigma_weights` exist only
    so the signature matches `kalman_predict`; they are never read.

    Args:
        transition_func: Ignored (kept for signature compatibility).
        states: Array of shape (n_obs, n_mixtures, n_states).
        upper_chols: Array of shape (n_obs, n_mixtures, n_states, n_states).
        sigma_scaling_factor: Ignored.
        sigma_weights: Ignored.
        trans_coeffs: Dict mapping factor name to 1d coefficient array.
        shock_sds: 1d array of length n_states.
        anchoring_scaling_factors: Array of shape (2, n_states).
        anchoring_constants: Array of shape (2, n_states).
        observed_factors: Array of shape (n_obs, n_observed_factors).
        latent_factors: Tuple of latent factor names.
        constant_factor_indices: Indices of factors with `constant` transition.
        n_all_factors: Total number of factors (latent + observed).

    Returns:
        Predicted states, same shape as states.
        Predicted upper_chols, same shape as upper_chols.

    """
    n_obs, n_mixtures, _ = states.shape
    n_latent = len(latent_factors)

    f_mat, c_vec = _build_f_and_c(
        latent_factors, constant_factor_indices, n_all_factors, trans_coeffs
    )

    # Anchoring parameters of the input period (row 0) and output period (row 1).
    scale_in = anchoring_scaling_factors[0][:n_latent]
    scale_out = anchoring_scaling_factors[1][:n_latent]
    const_in = anchoring_constants[0][:n_latent]
    const_out = anchoring_constants[1][:n_latent]

    # Mean prediction: anchor the latent states, append the observed factors,
    # apply the linear transition, then un-anchor with the output period's
    # parameters.
    observed = jnp.broadcast_to(
        observed_factors[:, None, :],
        (n_obs, n_mixtures, observed_factors.shape[1]),
    )
    full_states = jnp.concatenate(
        [states * scale_in + const_in, observed], axis=-1
    )
    predicted_states = (full_states @ f_mat.T + c_vec - const_out) / scale_out

    # Covariance prediction in square-root form. The effective transition on
    # the unanchored latent states is
    # G = diag(1 / scale_out) @ F[:, :n_latent] @ diag(scale_in).
    g_mat = (f_mat[:, :n_latent] * scale_in) / scale_out[:, None]

    # QR of the stacked matrix [R @ G.T ; diag(shock_sds / scale_out)] yields
    # the predicted upper cholesky factor.
    transformed_chols = upper_chols @ g_mat.T  # (n_obs, n_mixtures, n_latent, n_latent)
    shock_block = jnp.diag(shock_sds / scale_out)
    stacked = jnp.concatenate(
        [transformed_chols, jnp.broadcast_to(shock_block, transformed_chols.shape)],
        axis=-2,
    )  # (n_obs, n_mixtures, 2 * n_latent, n_latent)

    predicted_covs = array_qr_jax(stacked)[1][:, :, :n_latent]

    return predicted_states, predicted_covs
| 325 | + |
| 326 | + |
| 327 | +def _build_f_and_c( |
| 328 | + latent_factors: tuple[str, ...], |
| 329 | + constant_factor_indices: frozenset[int], |
| 330 | + n_all_factors: int, |
| 331 | + trans_coeffs: dict[str, Array], |
| 332 | +) -> tuple[Array, Array]: |
| 333 | + """Build F matrix and c vector from transition coefficients. |
| 334 | +
|
| 335 | + Stack all coefficient arrays, build identity rows for constant factors, |
| 336 | + and select via a boolean mask. |
| 337 | +
|
| 338 | + Args: |
| 339 | + latent_factors: Tuple of latent factor names. |
| 340 | + constant_factor_indices: Indices of factors with `constant` transition. |
| 341 | + n_all_factors: Total number of factors (latent + observed). |
| 342 | + trans_coeffs: Dict mapping factor name to 1d coefficient array. |
| 343 | +
|
| 344 | + Returns: |
| 345 | + f_mat: Array of shape (n_latent, n_all_factors). |
| 346 | + c_vec: Array of shape (n_latent,). |
| 347 | +
|
| 348 | + """ |
| 349 | + n_latent = len(latent_factors) |
| 350 | + identity = jnp.eye(n_latent, n_all_factors) |
| 351 | + |
| 352 | + # Will be of shape (n_latent, n_all+1) |
| 353 | + all_coeffs = jnp.stack( |
| 354 | + [ |
| 355 | + trans_coeffs[f] |
| 356 | + if i not in constant_factor_indices |
| 357 | + else jnp.zeros(n_all_factors + 1) |
| 358 | + for i, f in enumerate(latent_factors) |
| 359 | + ] |
| 360 | + ) |
| 361 | + |
| 362 | + f_from_coeffs = all_coeffs[:, :-1] |
| 363 | + c_from_coeffs = all_coeffs[:, -1] |
| 364 | + |
| 365 | + is_constant = jnp.array([i in constant_factor_indices for i in range(n_latent)]) |
| 366 | + mask = is_constant[:, None] # (n_latent, 1) |
| 367 | + |
| 368 | + f_mat = jnp.where(mask, identity, f_from_coeffs) |
| 369 | + c_vec = jnp.where(is_constant, 0.0, c_from_coeffs) |
| 370 | + |
| 371 | + return f_mat, c_vec |
| 372 | + |
| 373 | + |
231 | 374 | def _calculate_sigma_points( |
232 | 375 | states: Array, |
233 | 376 | upper_chols: Array, |
|
0 commit comments