1+ import functools
2+
13import jax
24import jax .numpy as jnp
35
911# ======================================================================================
1012
1113
14+ @functools .partial (jax .checkpoint , prevent_cse = False )
1215def kalman_update (
1316 states ,
1417 upper_chols ,
@@ -152,12 +155,13 @@ def calculate_sigma_scaling_factor_and_weights(n_states, kappa=2):
152155 return scaling_factor , weights
153156
154157
158+ @functools .partial (jax .checkpoint , static_argnums = 0 , prevent_cse = False )
155159def kalman_predict (
160+ transition_func ,
156161 states ,
157162 upper_chols ,
158163 sigma_scaling_factor ,
159164 sigma_weights ,
160- transition_info ,
161165 trans_coeffs ,
162166 shock_sds ,
163167 anchoring_scaling_factors ,
@@ -167,6 +171,7 @@ def kalman_predict(
167171 """Make a unscented Kalman predict.
168172
169173 Args:
174+ transition_func (Callable): The transition function.
170175 states (jax.numpy.array): Array of shape (n_obs, n_mixtures, n_states) with
171176 pre-update states estimates.
172177 upper_chols (jax.numpy.array): Array of shape (n_obs, n_mixtures, n_states,
@@ -177,9 +182,6 @@ def kalman_predict(
177182 the sigma_point algorithm chosen.
178183 sigma_weights (jax.numpy.array): 1d array of length n_sigma with non-negative
179184 sigma weights.
180- transition_info (dict): Dict with the entries "func" (the actual transition
181- function) and "columns" (a dictionary mapping factors that are needed
182- as individual columns to positions in the factor array).
183185 trans_coeffs (tuple): Tuple of 1d jax.numpy.arrays with transition parameters.
184186 anchoring_scaling_factors (jax.numpy.array): Array of shape (2, n_fac) with
185187 the scaling factors for anchoring. The first row corresponds to the input
@@ -203,7 +205,7 @@ def kalman_predict(
203205 )
204206 transformed = transform_sigma_points (
205207 sigma_points ,
206- transition_info ,
208+ transition_func ,
207209 trans_coeffs ,
208210 anchoring_scaling_factors ,
209211 anchoring_constants ,
@@ -225,6 +227,7 @@ def kalman_predict(
225227 return predicted_states , predicted_covs
226228
227229
230+ @functools .partial (jax .checkpoint , prevent_cse = False )
228231def _calculate_sigma_points (states , upper_chols , scaling_factor , observed_factors ):
229232 """Calculate the array of sigma_points for the unscented transform.
230233
@@ -272,7 +275,7 @@ def _calculate_sigma_points(states, upper_chols, scaling_factor, observed_factor
272275
273276def transform_sigma_points (
274277 sigma_points ,
275- transition_info ,
278+ transition_func ,
276279 trans_coeffs ,
277280 anchoring_scaling_factors ,
278281 anchoring_constants ,
@@ -281,9 +284,7 @@ def transform_sigma_points(
281284
282285 Args:
283286 sigma_points (jax.numpy.array) of shape n_obs, n_mixtures, n_sigma, n_fac.
284- transition_info (dict): Dict with the entries "func" (the actual transition
285- function) and "columns" (a dictionary mapping factors that are needed
286- as individual columns to positions in the factor array).
287+ transition_func (Callable): The transition function.
287288 trans_coeffs (tuple): Tuple of 1d jax.numpy.arrays with transition parameters.
288289 anchoring_scaling_factors (jax.numpy.array): Array of shape (2, n_states) with
289290 the scaling factors for anchoring. The first row corresponds to the input
@@ -303,9 +304,7 @@ def transform_sigma_points(
303304
304305 anchored = flat_sigma_points * anchoring_scaling_factors [0 ] + anchoring_constants [0 ]
305306
306- transition_function = transition_info ["func" ]
307-
308- transformed_anchored = transition_function (trans_coeffs , anchored )
307+ transformed_anchored = transition_func (trans_coeffs , anchored )
309308
310309 n_observed = transformed_anchored .shape [- 1 ]
311310
0 commit comments