@@ -12,7 +12,7 @@ via `functools.partial`. When the linear path is chosen, extra keyword arguments
(`latent_factors`, `constant_factor_indices`, `n_all_factors`) are bound at setup time so
the predict function has the same call signature as the unscented variant.

- ## Why it is faster
+ ## Why it is faster and uses less memory

The unscented predict generates $2n + 1$ sigma points (where $n$ is the number of latent
factors), transforms each one through the transition function, then recovers predicted
@@ -27,6 +27,14 @@ $(2n) \times n$ matrix: $n$ rows from the propagated Cholesky factor and $n$ row
shocks. The reduction from $3n + 1$ to $2n$ rows speeds up the QR step and removes all
sigma-point overhead.

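To make the linear covariance update concrete, here is a minimal JAX sketch of the square-root propagation described above. The argument names are hypothetical and the shocks are assumed to be independent with standard deviations `shock_sds`; it illustrates the technique rather than reproducing the library's actual code.

```python
import jax.numpy as jnp


def propagate_cholesky_linear(chol, g, shock_sds):
    """Square-root predict step for a linear transition (illustrative sketch).

    chol:      (n, n) lower Cholesky factor of the current state covariance
    g:         (n, n) transition matrix acting on the latent factors
    shock_sds: (n,) standard deviations of independent transition shocks
    """
    # Stack the propagated factor and the shock factor into a (2n, n) matrix.
    stacked = jnp.concatenate([(g @ chol).T, jnp.diag(shock_sds)], axis=0)
    # stacked.T @ stacked equals G P G.T + Q, so the R factor of its QR
    # decomposition is (up to row signs) a Cholesky factor of the prediction.
    _, r = jnp.linalg.qr(stacked, mode="reduced")
    return r.T
```

The unscented path would instead build the $(3n + 1) \times n$ matrix mentioned above from the transformed sigma points, which is exactly the overhead the linear version avoids.
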
+ The memory savings can be more important than the speed gains. The unscented path
+ materialises $2n + 1$ sigma points for every observation and mixture component, and
+ JAX's automatic differentiation retains intermediate buffers for the backward pass. The
+ linear path replaces all of this with a single matrix multiply whose memory footprint
+ scales with $n^2$ rather than with the number of sigma points times the number of
+ observations. On memory-constrained GPUs this can be the difference between fitting the
+ model and running out of memory.
+
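As a rough illustration of that scaling, here is a back-of-envelope element count. The values of `n_obs` and `n_mixtures` are made up for the example; only the $2n + 1$ sigma-point count and the $n^2$ footprint come from the description above.

```python
# Illustrative element counts only; n_obs and n_mixtures are assumptions.
n, n_obs, n_mixtures = 4, 10_000, 2

sigma_point_elements = n_obs * n_mixtures * (2 * n + 1) * n  # unscented buffers
transition_matrix_elements = n * n                           # linear path

print(sigma_point_elements, transition_matrix_elements)  # -> 720000 16
```
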
## Building F and c

The linear predict assembles a transition matrix $F$ of shape
@@ -98,3 +106,17 @@ columns in $F$ and therefore influence the predicted mean through the matrix--ve
product. However, they carry no uncertainty: their columns are excluded from the
covariance propagation. This is why $G$ uses only the first $n_\text{latent}$ columns of
$F$ rather than the full matrix.
+
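The following sketch shows how this split plays out in code. The shape of $F$, the argument names, and the ordering of the state vector (latent factors first, constant factors last) are assumptions based on the description above, not the library's actual signatures.

```python
import jax.numpy as jnp


def linear_predict(states, chol, f, c, shock_sds, n_latent):
    """Illustrative linear predict with constant factors appended to the state.

    states: (n_all,) all factors, latent first, constants last (assumed ordering)
    chol:   (n_latent, n_latent) lower Cholesky factor of the latent covariance
    f:      (n_latent, n_all) transition matrix, c: (n_latent,) intercept
    """
    # Mean: every column of F contributes, including the constant-factor columns.
    predicted_mean = f @ states + c
    # Covariance: only the latent columns carry uncertainty, so G drops the rest.
    g = f[:, :n_latent]
    stacked = jnp.concatenate([(g @ chol).T, jnp.diag(shock_sds)], axis=0)
    _, r = jnp.linalg.qr(stacked, mode="reduced")  # same square-root update as above
    return predicted_mean, r.T
```
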
+ ## Practical impact
+
+ Benchmarks on a 4-factor linear model (`health-cognition`,
+ `no_feedback_to_investments_linear`, 8 GiB GPU) show a modest ~6% speed-up on GPU
+ (8.4 vs 8.9 s per optimizer iteration) and negligible difference on CPU. The speed gain
+ is small because with only 4 latent factors the unscented transform generates just 9
+ sigma points — a trivially cheap operation on modern hardware.
+
+ The memory reduction is the more significant benefit. Under the same conditions the
+ unscented path ran out of GPU memory when only ~5 GiB was free, while the linear path
+ ran without issues. For models with more latent factors both advantages grow: the
+ sigma-point count scales as $2n + 1$ and the QR matrix shrinks from $(3n + 1) \times n$
+ to $(2n) \times n$.