Commit d7ba962

Idiomatic Jax for linear filter, though no speed difference. Added a note.
1 parent 7eb35a9 commit d7ba962

2 files changed

Lines changed: 63 additions & 17 deletions

docs/explanations/linear_predict.md

Lines changed: 13 additions & 0 deletions
@@ -50,6 +50,19 @@ For each latent factor $i$:
 - **Constant factor**: row $i$ of $F$ is the unit vector $e_i$ (identity row) and
   $c_i = 0$, so the factor value is simply carried forward.

+The implementation uses a stack-then-mask approach: all coefficient arrays are stacked
+into a single matrix (with zero-padded rows for constant factors), an identity matrix
+provides the constant-factor rows, and `jnp.where` selects between them using a boolean
+mask. This avoids per-element `.at[i].set()` calls and conditional branching, producing
+a cleaner trace for JAX's compiler.
+
+Three construction strategies were benchmarked (loop with conditional `.at[i].set()`,
+stack-then-mask with `jnp.where`, and index-scatter with pre-separated sub-matrices).
+All three produced identical XLA graphs and showed no meaningful runtime difference
+(~6.3--6.7 ms per call on CPU, 4-factor model, 5000 observations), confirming that the
+construction is fully resolved at trace time. The stack-then-mask variant was kept for
+its cleaner, more idiomatic JAX style.
+
 ## Mean prediction

 The mean prediction incorporates anchoring, which rescales factors to a common metric
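
The stack-then-mask idea described in the added documentation can be illustrated on a toy problem. This is a minimal sketch, not the project's code: the sizes, the coefficient values, and the names `coeffs` and `constant_idx` are made up for illustration. Only `jnp.eye`, `jnp.array`, and `jnp.where` are used, which are standard `jax.numpy` functions.

```python
import jax.numpy as jnp

# Hypothetical toy setup: 3 latent factors, 4 factors in total;
# the factor at index 1 has a constant transition.
n_latent, n_all = 3, 4
constant_idx = frozenset({1})

# One coefficient row per latent factor: n_all slopes plus a trailing intercept.
# The row for the constant factor is only a zero placeholder.
coeffs = jnp.array(
    [
        [0.5, 0.1, 0.0, 0.2, 1.0],
        [0.0, 0.0, 0.0, 0.0, 0.0],
        [0.3, 0.0, 0.6, 0.0, -1.0],
    ]
)

identity = jnp.eye(n_latent, n_all)  # row i is the unit vector e_i
is_constant = jnp.array([i in constant_idx for i in range(n_latent)])

# The (n_latent, 1) mask broadcasts across columns, so whole rows are selected.
f_mat = jnp.where(is_constant[:, None], identity, coeffs[:, :-1])
c_vec = jnp.where(is_constant, 0.0, coeffs[:, -1])
```

Because the mask is built from Python-level set membership, it is a concrete array at trace time, which is consistent with the note that the construction strategy does not affect the compiled graph.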

src/skillmodels/kalman_filters.py

Lines changed: 50 additions & 17 deletions
@@ -284,23 +284,9 @@ def linear_kalman_predict(
     """
     n_latent = len(latent_factors)

-    # Build F (n_latent x n_all) and c (n_latent,) from trans_coeffs.
-    # linear factor i: F[i] = trans_coeffs[factor_i][:-1], c[i] = last element
-    # constant factor i: F[i] = e_i (unit vector), c[i] = 0
-    f_rows = []
-    c_vals = []
-    for i, factor in enumerate(latent_factors):
-        if i in constant_factor_indices:
-            row = jnp.zeros(n_all_factors).at[i].set(1.0)
-            f_rows.append(row)
-            c_vals.append(0.0)
-        else:
-            coeffs = trans_coeffs[factor]
-            f_rows.append(coeffs[:-1])
-            c_vals.append(coeffs[-1])
-
-    f_mat = jnp.stack(f_rows)  # (n_latent, n_all)
-    c_vec = jnp.array(c_vals)  # (n_latent,)
+    f_mat, c_vec = _build_f_and_c(
+        latent_factors, constant_factor_indices, n_all_factors, trans_coeffs
+    )

     s_in = anchoring_scaling_factors[0][:n_latent]  # (n_latent,) for input period
     s_out = anchoring_scaling_factors[1][:n_latent]  # (n_latent,) for output period
@@ -338,6 +324,53 @@ def linear_kalman_predict(
     return predicted_states, predicted_covs


+def _build_f_and_c(
+    latent_factors: tuple[str, ...],
+    constant_factor_indices: frozenset[int],
+    n_all_factors: int,
+    trans_coeffs: dict[str, Array],
+) -> tuple[Array, Array]:
+    """Build F matrix and c vector from transition coefficients.
+
+    Stack all coefficient arrays, build identity rows for constant factors,
+    and select via a boolean mask.
+
+    Args:
+        latent_factors: Tuple of latent factor names.
+        constant_factor_indices: Indices of factors with `constant` transition.
+        n_all_factors: Total number of factors (latent + observed).
+        trans_coeffs: Dict mapping factor name to 1d coefficient array.
+
+    Returns:
+        f_mat: Array of shape (n_latent, n_all_factors).
+        c_vec: Array of shape (n_latent,).
+
+    """
+    n_latent = len(latent_factors)
+    identity = jnp.eye(n_latent, n_all_factors)
+
+    # Will be of shape (n_latent, n_all+1)
+    all_coeffs = jnp.stack(
+        [
+            trans_coeffs[f]
+            if i not in constant_factor_indices
+            else jnp.zeros(n_all_factors + 1)
+            for i, f in enumerate(latent_factors)
+        ]
+    )
+
+    f_from_coeffs = all_coeffs[:, :-1]
+    c_from_coeffs = all_coeffs[:, -1]
+
+    is_constant = jnp.array([i in constant_factor_indices for i in range(n_latent)])
+    mask = is_constant[:, None]  # (n_latent, 1)
+
+    f_mat = jnp.where(mask, identity, f_from_coeffs)
+    c_vec = jnp.where(is_constant, 0.0, c_from_coeffs)
+
+    return f_mat, c_vec
+
+
 def _calculate_sigma_points(
     states: Array,
     upper_chols: Array,
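
One way to check that the refactor preserves behavior is to run the removed loop and the new stack-then-mask logic on the same inputs and compare the results. The sketch below is an illustration under stated assumptions, not part of the commit: `build_loop` and `build_mask` are hypothetical standalone copies of the two variants, and the factor names and coefficient values are made up.

```python
import jax.numpy as jnp


def build_loop(latent_factors, constant_idx, n_all, trans_coeffs):
    """Loop variant, mirroring the lines removed in this commit."""
    f_rows, c_vals = [], []
    for i, factor in enumerate(latent_factors):
        if i in constant_idx:
            f_rows.append(jnp.zeros(n_all).at[i].set(1.0))
            c_vals.append(0.0)
        else:
            coeffs = trans_coeffs[factor]
            f_rows.append(coeffs[:-1])
            c_vals.append(coeffs[-1])
    return jnp.stack(f_rows), jnp.array(c_vals)


def build_mask(latent_factors, constant_idx, n_all, trans_coeffs):
    """Stack-then-mask variant, mirroring the added helper."""
    n_latent = len(latent_factors)
    all_coeffs = jnp.stack(
        [
            trans_coeffs[f] if i not in constant_idx else jnp.zeros(n_all + 1)
            for i, f in enumerate(latent_factors)
        ]
    )
    is_constant = jnp.array([i in constant_idx for i in range(n_latent)])
    f_mat = jnp.where(
        is_constant[:, None], jnp.eye(n_latent, n_all), all_coeffs[:, :-1]
    )
    c_vec = jnp.where(is_constant, 0.0, all_coeffs[:, -1])
    return f_mat, c_vec


# Made-up inputs: two latent factors plus one observed factor (n_all = 3).
# "health" is constant, so it needs no entry in the coefficient dict.
factors = ("skill", "health")
const_idx = frozenset({1})
coeffs = {"skill": jnp.array([0.8, 0.1, 0.05, 0.3])}

f_loop, c_loop = build_loop(factors, const_idx, 3, coeffs)
f_mask, c_mask = build_mask(factors, const_idx, 3, coeffs)
same_f = bool(jnp.allclose(f_loop, f_mask))
same_c = bool(jnp.allclose(c_loop, c_mask))
```

Both variants only look up `trans_coeffs` for non-constant factors, so a constant factor without coefficients is handled the same way before and after the change.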
