@@ -193,13 +193,15 @@ def fit(
193193 elif self ._filter_name == "ekf" :
194194 self ._filter_result = ekf_filter (self ._model , obs , initial_state = initial_state )
195195 elif self ._filter_name == "ukf" :
196+ assert isinstance (self ._model , NonlinearSSM )
196197 self ._filter_result = ukf_filter (
197198 self ._model ,
198199 obs ,
199200 initial_state = initial_state ,
200201 ** self ._filter_kwargs ,
201202 )
202203 elif self ._filter_name == "particle" :
204+ assert isinstance (self ._model , NonlinearSSM )
203205 key = self ._key if self ._key is not None else jax .random .PRNGKey (0 )
204206 self ._filter_result = particle_filter (
205207 self ._model ,
@@ -209,6 +211,7 @@ def fit(
209211 ** self ._filter_kwargs ,
210212 )
211213 elif self ._filter_name == "hamilton" :
214+ assert isinstance (self ._model , MarkovSwitchingSSM )
212215 self ._filter_result = hamilton_filter (self ._model , obs , initial_state = initial_state )
213216
214217 self ._is_fitted = True
@@ -229,9 +232,11 @@ def residuals(self) -> Array:
229232 if isinstance (self ._model , StateSpaceModel ):
230233 from dynaris .estimation .diagnostics import standardized_residuals
231234
235+ assert isinstance (fr , FilterResult )
232236 return standardized_residuals (fr , self ._model )
233237
234238 # Nonlinear: compute y - h(predicted_state)
239+ assert isinstance (self ._model , NonlinearSSM )
235240 predicted_obs = jax .vmap (self ._model .h )(fr .predicted_states )
236241 return fr .observations - predicted_obs
237242
@@ -289,11 +294,13 @@ def _plot_filtered(self, **kwargs: Any) -> Any:
289294 if isinstance (self ._model , StateSpaceModel ):
290295 from dynaris .plotting .plots import plot_filtered
291296
297+ assert isinstance (fr , FilterResult )
292298 return plot_filtered (fr , self ._model , ** kwargs )
293299
294300 # Nonlinear: compute observation-space predictions
295301 import matplotlib .pyplot as plt
296302
303+ assert isinstance (self ._model , NonlinearSSM )
297304 filtered_obs = jax .vmap (self ._model .h )(fr .filtered_states )
298305 obs = np .asarray (fr .observations )
299306 filt = np .asarray (filtered_obs )
0 commit comments