
Commit 2a77ee4

hmgaudecker and claude authored
Linear predict (#87)
* First shot at fixing #36.
* Bug fix in kalman_filters.py — added [:n_latent] slice to s_in and c_in. Add tests.
* Fix ty errors and add docs on linear predict.
* Update docs based on benchmark results.
* Idiomatic Jax for linear filter, though no speed difference. Added a note.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 023775e commit 2a77ee4

8 files changed

Lines changed: 595 additions & 34 deletions

explanations/linear_predict.md

Lines changed: 135 additions & 0 deletions
# Linear predict optimization

## When the linear predict is used

At model setup time, `is_all_linear` checks whether every latent factor's transition
function name belongs to `{"linear", "constant"}`. This is an all-or-nothing decision:
if even one factor uses a nonlinear transition (e.g. `translog`), the entire model falls
back to the unscented predict.

The check happens in `get_maximization_inputs`, where the predict function is selected
via `functools.partial`. When the linear path is chosen, extra keyword arguments
(`latent_factors`, `constant_factor_indices`, `n_all_factors`) are bound at setup time so
the predict function has the same call signature as the unscented variant.
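The selection logic can be sketched as follows. This is illustrative only: `is_all_linear` mirrors the source, but `select_predict` and its arguments are a hypothetical stand-in for the binding that `get_maximization_inputs` performs.

```python
import functools

LINEAR_FUNCTION_NAMES = frozenset({"linear", "constant"})


def is_all_linear(function_names):
    """True if every factor's transition function is linear or constant."""
    return all(name in LINEAR_FUNCTION_NAMES for name in function_names.values())


def select_predict(
    function_names,
    linear_predict,
    unscented_predict,
    latent_factors,
    constant_factor_indices,
    n_all_factors,
):
    # Hypothetical helper: bind the linear-only kwargs at setup time so both
    # paths expose the same call signature to the rest of the filter loop.
    if is_all_linear(function_names):
        return functools.partial(
            linear_predict,
            latent_factors=latent_factors,
            constant_factor_indices=constant_factor_indices,
            n_all_factors=n_all_factors,
        )
    return unscented_predict
```

One nonlinear entry (e.g. `"translog"`) is enough to route the whole model to the unscented path.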
## Why it is faster and uses less memory

The unscented predict generates $2n + 1$ sigma points (where $n$ is the number of latent
factors), transforms each one through the transition function, then recovers predicted
means and covariances from weighted statistics. Its QR decomposition operates on a matrix
of shape $(3n + 1) \times n$: the $2n + 1$ weighted deviation rows plus $n$ rows for the
shock standard deviations.

The linear predict skips sigma-point generation entirely. Because the transition is
linear, the predicted mean is just a matrix--vector product, and the predicted covariance
follows from the standard linear Gaussian formula. Its QR decomposition operates on a
$(2n) \times n$ matrix: $n$ rows from the propagated Cholesky factor and $n$ rows for the
shocks. The reduction from $3n + 1$ to $2n$ rows speeds up the QR step and removes all
sigma-point overhead.

The memory savings can be more important than the speed gains. The unscented path
materialises $2n + 1$ sigma points for every observation and mixture component, and
JAX's automatic differentiation retains intermediate buffers for the backward pass. The
linear path replaces all of this with a single matrix multiply whose memory footprint
scales with $n^2$ rather than with the number of sigma points times the number of
observations. On memory-constrained GPUs this can be the difference between fitting the
model and running out of memory.
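The shape arithmetic is easy to verify for a concrete case. A minimal sketch (helper names are made up for illustration):

```python
def unscented_shapes(n):
    """Sigma-point count and QR input shape for the unscented predict."""
    return 2 * n + 1, (3 * n + 1, n)


def linear_shapes(n):
    """QR input shape for the linear predict (no sigma points needed)."""
    return (2 * n, n)


n = 4
print(unscented_shapes(n))  # (9, (13, 4)): 9 sigma points, QR on a 13x4 matrix
print(linear_shapes(n))     # (8, 4): QR on an 8x4 matrix
```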
## Building F and c

The linear predict assembles a transition matrix $F$ of shape
$(n_\text{latent}, n_\text{all})$ and a constant vector $c$ of length $n_\text{latent}$
from the `trans_coeffs` dictionary. Here $n_\text{all}$ includes both latent and observed
factors.

For each latent factor $i$:

- **Linear factor**: `trans_coeffs[factor]` is a 1-d array whose last element is the
  intercept and whose preceding elements are the coefficients on all factors (latent and
  observed). Row $i$ of $F$ is set to `coeffs[:-1]` and $c_i$ is set to `coeffs[-1]`.
- **Constant factor**: row $i$ of $F$ is the unit vector $e_i$ (identity row) and
  $c_i = 0$, so the factor value is simply carried forward.

The implementation uses a stack-then-mask approach: all coefficient arrays are stacked
into a single matrix (with zero-padded rows for constant factors), an identity matrix
provides the constant-factor rows, and `jnp.where` selects between them using a boolean
mask. This avoids per-element `.at[i].set()` calls and conditional branching, producing
a cleaner trace for JAX's compiler.

Three construction strategies were benchmarked (loop with conditional `.at[i].set()`,
stack-then-mask with `jnp.where`, and index-scatter with pre-separated sub-matrices).
All three produced identical XLA graphs and showed no meaningful runtime difference
(~6.3--6.7 ms per call on CPU, 4-factor model, 5000 observations), confirming that the
construction is fully resolved at trace time. The stack-then-mask variant was kept for
its cleaner, more idiomatic JAX style.
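A toy construction with two latent factors (one `linear`, one `constant`) and one observed factor makes the layout concrete. NumPy stands in for `jax.numpy`; the factor names and coefficient values are made up, but the stack-then-mask logic mirrors `_build_f_and_c`:

```python
import numpy as np

latent_factors = ("cog", "noncog")   # "noncog" uses the constant transition
constant_factor_indices = frozenset({1})
n_latent, n_all = 2, 3               # 2 latent factors + 1 observed factor
# For the linear factor: 3 coefficients (one per factor) plus an intercept.
trans_coeffs = {"cog": np.array([0.5, 0.2, 0.1, 0.3])}

identity = np.eye(n_latent, n_all)
# Stack coefficient rows; constant factors get a zero-padded placeholder row.
all_coeffs = np.stack(
    [
        trans_coeffs[f] if i not in constant_factor_indices else np.zeros(n_all + 1)
        for i, f in enumerate(latent_factors)
    ]
)
is_constant = np.array([i in constant_factor_indices for i in range(n_latent)])
# Select identity rows for constant factors, coefficient rows otherwise.
f_mat = np.where(is_constant[:, None], identity, all_coeffs[:, :-1])
c_vec = np.where(is_constant, 0.0, all_coeffs[:, -1])

print(f_mat)  # [[0.5 0.2 0.1]
              #  [0.  1.  0. ]]
print(c_vec)  # [0.3 0. ]
```

Row 0 carries the linear coefficients, row 1 is the identity row $e_1$ with $c_1 = 0$, exactly as described above.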
## Mean prediction

The mean prediction incorporates anchoring, which rescales factors to a common metric
across periods. Let $s^{\text{in}}$ and $c^{\text{in}}$ be the input-period scaling
factors and constants, and $s^{\text{out}}$ and $c^{\text{out}}$ the output-period
counterparts. The steps are:

1. **Anchor** the input states: $x^a = x \odot s^{\text{in}} + c^{\text{in}}$.
2. **Concatenate** observed factors to form the full state vector
   $\tilde{x} = [x^a, x^{\text{obs}}]$.
3. **Apply the linear transition**: $y^a = \tilde{x}\, F^\top + c$.
4. **Un-anchor** to get the predicted states:
   $\hat{x} = (y^a - c^{\text{out}}) \oslash s^{\text{out}}$.
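The four steps can be sketched numerically (NumPy stand-in, a single observation and mixture component; the anchoring values and coefficients are made up for illustration):

```python
import numpy as np

# Toy model: 2 latent factors, 1 observed factor (matches the F built above).
f_mat = np.array([[0.5, 0.2, 0.1], [0.0, 1.0, 0.0]])
c = np.array([0.3, 0.0])
s_in, c_in = np.array([2.0, 1.0]), np.array([0.1, 0.0])
s_out, c_out = np.array([2.0, 1.0]), np.array([0.1, 0.0])

x = np.array([1.0, 4.0])   # latent states on the internal (un-anchored) scale
x_obs = np.array([0.5])    # observed factor value, known from data

x_a = x * s_in + c_in                  # 1. anchor the input states
x_full = np.concatenate([x_a, x_obs])  # 2. append observed factors
y_a = x_full @ f_mat.T + c             # 3. apply the linear transition
x_hat = (y_a - c_out) / s_out          # 4. un-anchor the result
print(x_hat)  # [1.05 4.  ]
```

Note that the constant factor (second entry) is carried forward unchanged, while the linear factor mixes all three inputs plus the intercept.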
## Covariance prediction (square-root form)

skillmodels maintains covariances in square-root (upper Cholesky) form throughout. Let
$R$ denote the current upper Cholesky factor so that $P = R^\top R$. The linear predict
propagates $R$ as follows.

Define the effective transition matrix

$$
G = \operatorname{diag}(1 / s^{\text{out}})\; F_{\text{latent}}\;
\operatorname{diag}(s^{\text{in}})
$$

where $F_{\text{latent}}$ is the first $n_\text{latent}$ columns of $F$ (the columns
corresponding to latent factors). $G$ folds the anchoring scales into the transition so
that the covariance update works directly in the un-anchored (internal) scale.

The predicted covariance satisfies

$$
\hat{P} = G\, P\, G^\top + Q
$$

where $Q = \operatorname{diag}(\sigma / s^{\text{out}})^2$ and $\sigma$ is the vector of
shock standard deviations. In square-root form, the upper Cholesky factor $\hat{R}$ of
$\hat{P}$ is obtained via a single QR decomposition of the stacked matrix

$$
S = \begin{bmatrix} R\, G^\top \\ \operatorname{diag}(\sigma / s^{\text{out}})
\end{bmatrix}
$$

which has shape $(2n) \times n$. The upper-triangular $R$-factor of $S$ (its first $n$
rows) gives $\hat{R}$.
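The square-root update can be checked against the full-covariance formula directly. A NumPy sketch with random $G$, $R$, and $\sigma$ (unit output scales for simplicity); this verifies $\hat{R}^\top \hat{R} = S^\top S = G P G^\top + Q$:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
g = rng.normal(size=(n, n))               # effective transition matrix G
r = np.triu(rng.normal(size=(n, n)))      # upper Cholesky factor, P = R^T R
sigma = np.abs(rng.normal(size=n)) + 0.1  # shock standard deviations
s_out = np.ones(n)                        # unit output scales for simplicity

# Square-root update: one QR decomposition of the stacked (2n, n) matrix S.
stack = np.vstack([r @ g.T, np.diag(sigma / s_out)])
r_hat = np.linalg.qr(stack, mode="r")[:n]

# Full-covariance reference: P_hat = G P G^T + Q.
p_full = g @ (r.T @ r) @ g.T + np.diag((sigma / s_out) ** 2)
assert np.allclose(r_hat.T @ r_hat, p_full)
```

The agreement holds even though QR determines $\hat{R}$ only up to row signs, because $\hat{R}^\top \hat{R}$ is invariant to those signs.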
## Observed factors

Observed factors (e.g. investment measures whose values are known from data) appear as
columns in $F$ and therefore influence the predicted mean through the matrix--vector
product. However, they carry no uncertainty: their columns are excluded from the
covariance propagation. This is why $G$ uses only the first $n_\text{latent}$ columns of
$F$ rather than the full matrix.
## Practical impact

Benchmarks on a 4-factor linear model (`health-cognition`,
`no_feedback_to_investments_linear`, 8 GiB GPU) show a modest ~6% speed-up on GPU
(8.4 vs 8.9 s per optimizer iteration) and a negligible difference on CPU. The speed
gain is small because with only 4 latent factors the unscented transform generates just
9 sigma points, a trivially cheap operation on modern hardware.

The memory reduction is the more significant benefit. Under the same conditions the
unscented path ran out of GPU memory when only ~5 GiB was free, while the linear path
ran without issues. For models with more latent factors both advantages grow: the
sigma-point count scales as $2n + 1$ and the QR matrix shrinks from $(3n + 1) \times n$
to $(2n) \times n$.

docs/myst.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -40,6 +40,7 @@ project:
       children:
         - file: explanations/names_and_concepts.md
         - file: explanations/notes_on_factor_scales.md
+        - file: explanations/linear_predict.md
   - title: Reference Guides
       children:
         - file: reference_guides/transition_functions.md
```

pixi.lock

Lines changed: 1 addition & 1 deletion
(Generated lockfile; diff not rendered.)

src/skillmodels/kalman_filters.py

Lines changed: 144 additions & 1 deletion

```diff
@@ -1,13 +1,21 @@
 """Kalman filter operations for state estimation using the square-root form."""
 
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
 
 import jax
 import jax.numpy as jnp
 from jax import Array
 
 from skillmodels.qr import qr_gpu
 
+LINEAR_FUNCTION_NAMES = frozenset({"linear", "constant"})
+
+
+def is_all_linear(function_names: Mapping[str, str]) -> bool:
+    """Return True if every factor uses a linear or constant transition function."""
+    return all(name in LINEAR_FUNCTION_NAMES for name in function_names.values())
+
+
 array_qr_jax = (
     jax.vmap(jax.vmap(qr_gpu))
     if jax.default_backend() == "gpu"
@@ -228,6 +236,141 @@ def kalman_predict(
     return predicted_states, predicted_covs
 
 
+def linear_kalman_predict(
+    transition_func: Callable | None,  # noqa: ARG001
+    states: Array,
+    upper_chols: Array,
+    sigma_scaling_factor: float,  # noqa: ARG001
+    sigma_weights: Array,  # noqa: ARG001
+    trans_coeffs: dict[str, Array],
+    shock_sds: Array,
+    anchoring_scaling_factors: Array,
+    anchoring_constants: Array,
+    observed_factors: Array,
+    *,
+    latent_factors: tuple[str, ...],
+    constant_factor_indices: frozenset[int],
+    n_all_factors: int,
+) -> tuple[Array, Array]:
+    """Make a linear Kalman predict (square-root form).
+
+    Much cheaper than the unscented predict because it avoids sigma point
+    generation and transformation. Only valid when every factor uses a `linear`
+    or `constant` transition function.
+
+    The positional parameters `transition_func`, `sigma_scaling_factor` and
+    `sigma_weights` are accepted for signature compatibility with
+    `kalman_predict` but are ignored.
+
+    Args:
+        transition_func: Ignored (kept for signature compatibility).
+        states: Array of shape (n_obs, n_mixtures, n_states).
+        upper_chols: Array of shape (n_obs, n_mixtures, n_states, n_states).
+        sigma_scaling_factor: Ignored.
+        sigma_weights: Ignored.
+        trans_coeffs: Dict mapping factor name to 1d coefficient array.
+        shock_sds: 1d array of length n_states.
+        anchoring_scaling_factors: Array of shape (2, n_states).
+        anchoring_constants: Array of shape (2, n_states).
+        observed_factors: Array of shape (n_obs, n_observed_factors).
+        latent_factors: Tuple of latent factor names.
+        constant_factor_indices: Indices of factors with `constant` transition.
+        n_all_factors: Total number of factors (latent + observed).
+
+    Returns:
+        Predicted states, same shape as states.
+        Predicted upper_chols, same shape as upper_chols.
+
+    """
+    n_latent = len(latent_factors)
+
+    f_mat, c_vec = _build_f_and_c(
+        latent_factors, constant_factor_indices, n_all_factors, trans_coeffs
+    )
+
+    s_in = anchoring_scaling_factors[0][:n_latent]  # (n_latent,) for input period
+    s_out = anchoring_scaling_factors[1][:n_latent]  # (n_latent,) for output period
+    c_in = anchoring_constants[0][:n_latent]  # (n_latent,)
+    c_out = anchoring_constants[1][:n_latent]  # (n_latent,)
+
+    # Mean prediction
+    anchored_states = states * s_in + c_in  # (n_obs, n_mix, n_latent)
+    # Concatenate with observed factors to get full state vector
+    n_obs, n_mix, _ = states.shape
+    obs_expanded = jnp.broadcast_to(
+        observed_factors[:, jnp.newaxis, :], (n_obs, n_mix, observed_factors.shape[1])
+    )
+    full_states = jnp.concatenate([anchored_states, obs_expanded], axis=-1)
+
+    predicted_anchored = full_states @ f_mat.T + c_vec  # (n_obs, n_mix, n_latent)
+    predicted_states = (predicted_anchored - c_out) / s_out
+
+    # Covariance prediction (square-root form)
+    # G = diag(1/s_out) @ F_latent @ diag(s_in) where F_latent is the first
+    # n_latent columns of F
+    f_latent = f_mat[:, :n_latent]  # (n_latent, n_latent)
+    g_mat = (f_latent * s_in) / s_out[:, jnp.newaxis]  # (n_latent, n_latent)
+
+    # Stack: [upper_chol @ G.T ; diag(shock_sds / s_out)]
+    chol_g = upper_chols @ g_mat.T  # (n_obs, n_mix, n_latent, n_latent)
+    shock_diag = jnp.diag(shock_sds / s_out)  # (n_latent, n_latent)
+
+    stack = jnp.concatenate(
+        [chol_g, jnp.broadcast_to(shock_diag, chol_g.shape)], axis=-2
+    )  # (n_obs, n_mix, 2*n_latent, n_latent)
+
+    predicted_covs = array_qr_jax(stack)[1][:, :, :n_latent]
+
+    return predicted_states, predicted_covs
+
+
+def _build_f_and_c(
+    latent_factors: tuple[str, ...],
+    constant_factor_indices: frozenset[int],
+    n_all_factors: int,
+    trans_coeffs: dict[str, Array],
+) -> tuple[Array, Array]:
+    """Build F matrix and c vector from transition coefficients.
+
+    Stack all coefficient arrays, build identity rows for constant factors,
+    and select via a boolean mask.
+
+    Args:
+        latent_factors: Tuple of latent factor names.
+        constant_factor_indices: Indices of factors with `constant` transition.
+        n_all_factors: Total number of factors (latent + observed).
+        trans_coeffs: Dict mapping factor name to 1d coefficient array.
+
+    Returns:
+        f_mat: Array of shape (n_latent, n_all_factors).
+        c_vec: Array of shape (n_latent,).
+
+    """
+    n_latent = len(latent_factors)
+    identity = jnp.eye(n_latent, n_all_factors)
+
+    # Will be of shape (n_latent, n_all+1)
+    all_coeffs = jnp.stack(
+        [
+            trans_coeffs[f]
+            if i not in constant_factor_indices
+            else jnp.zeros(n_all_factors + 1)
+            for i, f in enumerate(latent_factors)
+        ]
+    )
+
+    f_from_coeffs = all_coeffs[:, :-1]
+    c_from_coeffs = all_coeffs[:, -1]
+
+    is_constant = jnp.array([i in constant_factor_indices for i in range(n_latent)])
+    mask = is_constant[:, None]  # (n_latent, 1)
+
+    f_mat = jnp.where(mask, identity, f_from_coeffs)
+    c_vec = jnp.where(is_constant, 0.0, c_from_coeffs)
+
+    return f_mat, c_vec
+
+
 def _calculate_sigma_points(
     states: Array,
     upper_chols: Array,
```
