@@ -284,23 +284,9 @@ def linear_kalman_predict(
284284 """
285285 n_latent = len (latent_factors )
286286
287- # Build F (n_latent x n_all) and c (n_latent,) from trans_coeffs.
288- # linear factor i: F[i] = trans_coeffs[factor_i][:-1], c[i] = last element
289- # constant factor i: F[i] = e_i (unit vector), c[i] = 0
290- f_rows = []
291- c_vals = []
292- for i , factor in enumerate (latent_factors ):
293- if i in constant_factor_indices :
294- row = jnp .zeros (n_all_factors ).at [i ].set (1.0 )
295- f_rows .append (row )
296- c_vals .append (0.0 )
297- else :
298- coeffs = trans_coeffs [factor ]
299- f_rows .append (coeffs [:- 1 ])
300- c_vals .append (coeffs [- 1 ])
301-
302- f_mat = jnp .stack (f_rows ) # (n_latent, n_all)
303- c_vec = jnp .array (c_vals ) # (n_latent,)
287+ f_mat , c_vec = _build_f_and_c (
288+ latent_factors , constant_factor_indices , n_all_factors , trans_coeffs
289+ )
304290
305291 s_in = anchoring_scaling_factors [0 ][:n_latent ] # (n_latent,) for input period
306292 s_out = anchoring_scaling_factors [1 ][:n_latent ] # (n_latent,) for output period
@@ -338,6 +324,53 @@ def linear_kalman_predict(
338324 return predicted_states , predicted_covs
339325
340326
327+ def _build_f_and_c (
328+ latent_factors : tuple [str , ...],
329+ constant_factor_indices : frozenset [int ],
330+ n_all_factors : int ,
331+ trans_coeffs : dict [str , Array ],
332+ ) -> tuple [Array , Array ]:
333+ """Build F matrix and c vector from transition coefficients.
334+
335+ Stack all coefficient arrays, build identity rows for constant factors,
336+ and select via a boolean mask.
337+
338+ Args:
339+ latent_factors: Tuple of latent factor names.
340+ constant_factor_indices: Indices of factors with `constant` transition.
341+ n_all_factors: Total number of factors (latent + observed).
342+ trans_coeffs: Dict mapping factor name to 1d coefficient array.
343+
344+ Returns:
345+ f_mat: Array of shape (n_latent, n_all_factors).
346+ c_vec: Array of shape (n_latent,).
347+
348+ """
349+ n_latent = len (latent_factors )
350+ identity = jnp .eye (n_latent , n_all_factors )
351+
352+ # Will be of shape (n_latent, n_all+1)
353+ all_coeffs = jnp .stack (
354+ [
355+ trans_coeffs [f ]
356+ if i not in constant_factor_indices
357+ else jnp .zeros (n_all_factors + 1 )
358+ for i , f in enumerate (latent_factors )
359+ ]
360+ )
361+
362+ f_from_coeffs = all_coeffs [:, :- 1 ]
363+ c_from_coeffs = all_coeffs [:, - 1 ]
364+
365+ is_constant = jnp .array ([i in constant_factor_indices for i in range (n_latent )])
366+ mask = is_constant [:, None ] # (n_latent, 1)
367+
368+ f_mat = jnp .where (mask , identity , f_from_coeffs )
369+ c_vec = jnp .where (is_constant , 0.0 , c_from_coeffs )
370+
371+ return f_mat , c_vec
372+
373+
341374def _calculate_sigma_points (
342375 states : Array ,
343376 upper_chols : Array ,
0 commit comments