
Commit 220a9b0

Update docs based on benchmark results.
1 parent 6388c0f commit 220a9b0

1 file changed: docs/explanations/linear_predict.md (23 additions & 1 deletion)

@@ -12,7 +12,7 @@ via `functools.partial`. When the linear path is chosen, extra keyword arguments
 (`latent_factors`, `constant_factor_indices`, `n_all_factors`) are bound at setup time so
 the predict function has the same call signature as the unscented variant.

-## Why it is faster
+## Why it is faster and uses less memory

 The unscented predict generates $2n + 1$ sigma points (where $n$ is the number of latent
 factors), transforms each one through the transition function, then recovers predicted
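
As a minimal sketch of the setup-time binding this hunk describes: the three keyword names come from the docs, but the function name, its positional signature, and the bound values below are illustrative stand-ins, not the package's actual API.

```python
from functools import partial

def predict_linear(mean, chol, params, *, latent_factors,
                   constant_factor_indices, n_all_factors):
    """Placeholder for the linear predict; the body is elided here."""
    return mean, chol

# Binding the model-structure arguments once at setup time leaves a callable
# with the same (mean, chol, params) signature as the unscented predict.
predict = partial(
    predict_linear,
    latent_factors=("health", "cognition"),  # illustrative values
    constant_factor_indices=(2, 3),
    n_all_factors=4,
)
```
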
@@ -27,6 +27,14 @@ $(2n) \times n$ matrix: $n$ rows from the propagated Cholesky factor and $n$ row
 shocks. The reduction from $3n + 1$ to $2n$ rows speeds up the QR step and removes all
 sigma-point overhead.

+The memory savings can be more important than the speed gains. The unscented path
+materialises $2n + 1$ sigma points for every observation and mixture component, and
+JAX's automatic differentiation retains intermediate buffers for the backward pass. The
+linear path replaces all of this with a single matrix multiply whose memory footprint
+scales with $n^2$ rather than with the number of sigma points times the number of
+observations. On memory-constrained GPUs this can be the difference between fitting the
+model and running out of memory.
+
 ## Building F and c

 The linear predict assembles a transition matrix $F$ of shape
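
To see why the QR step in this hunk recovers the predicted covariance: writing $M$ for the $(2n) \times n$ stack with $(F L)^\top$ on top of $R^\top$, where $L$ is the current Cholesky factor and $R$ the shock Cholesky factor, $M^\top M = F L L^\top F^\top + R R^\top$ is exactly the predicted covariance, so the triangular factor of a QR decomposition of $M$ is a square root of it. Below is a hedged sketch of that update, assuming a square $F$ with no constant-factor columns and lower-triangular factors; the names are illustrative, not the package's API.

```python
import jax.numpy as jnp

def linear_predict_sqrt(mean, chol, F, c, shock_chol):
    # Predicted mean: a single matrix-vector product replaces the
    # propagation of 2n + 1 sigma points.
    new_mean = F @ mean + c

    # Stack n rows from the propagated Cholesky factor and n rows of shocks
    # into a (2n) x n matrix, then QR-decompose; M^T M equals the predicted
    # covariance, so R^T is a triangular square root of it (up to row signs,
    # which leave the covariance unchanged).
    stacked = jnp.vstack([(F @ chol).T, shock_chol.T])
    r = jnp.linalg.qr(stacked, mode="r")
    return new_mean, r.T

# Tiny usage example with made-up 4-factor inputs.
n = 4
mean, chol = jnp.zeros(n), jnp.eye(n)
F, c, shock_chol = 0.9 * jnp.eye(n), jnp.zeros(n), 0.1 * jnp.eye(n)
new_mean, new_chol = linear_predict_sqrt(mean, chol, F, c, shock_chol)
```
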
@@ -98,3 +106,17 @@ columns in $F$ and therefore influence the predicted mean through the matrix--ve
 product. However, they carry no uncertainty: their columns are excluded from the
 covariance propagation. This is why $G$ uses only the first $n_\text{latent}$ columns of
 $F$ rather than the full matrix.
+
+## Practical impact
+
+Benchmarks on a 4-factor linear model (`health-cognition`,
+`no_feedback_to_investments_linear`, 8 GiB GPU) show a modest ~6% speed-up on GPU
+(8.4 vs 8.9 s per optimizer iteration) and a negligible difference on CPU. The speed
+gain is small because with only 4 latent factors the unscented transform generates
+just 9 sigma points, a trivially cheap operation on modern hardware.
+
+The memory reduction is the more significant benefit. Under the same conditions the
+unscented path ran out of GPU memory when only ~5 GiB was free, while the linear path
+ran without issues. For models with more latent factors both advantages grow: the
+sigma-point count scales as $2n + 1$ and the QR matrix shrinks from $(3n + 1) \times n$
+to $(2n) \times n$.
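
To make the column-slicing from the hunk above concrete, here is a hypothetical sketch. The slice `F[:, :n_latent]` follows directly from the text; the overall shape of $F$ and the convention that latent factors occupy the leading columns are assumptions, since the diff truncates the sentence that states $F$'s shape.

```python
import jax.numpy as jnp

# Illustrative sizes: 4 latent factors plus 2 constant factors.
n_latent, n_all = 4, 6

F = 0.1 * jnp.ones((n_latent, n_all))  # transition matrix (shape assumed)
c = jnp.zeros(n_latent)                # intercept
state = jnp.ones(n_all)                # latent factors first, then constants

# Constant factors enter the predicted mean through their columns of F ...
new_mean = F @ state + c

# ... but carry no uncertainty, so covariance propagation drops their columns.
G = F[:, :n_latent]                    # first n_latent columns only
chol = jnp.eye(n_latent)               # current Cholesky factor (illustrative)
propagated = G @ chol                  # the propagated factor in the QR stack
```
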
