Skip to content

Commit 884b59b

Browse files
committed
fix types
1 parent c504f48 commit 884b59b

11 files changed

Lines changed: 25 additions & 18 deletions

File tree

src/dynaris/core/ssm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,15 @@ def fit(
193193
elif self._filter_name == "ekf":
194194
self._filter_result = ekf_filter(self._model, obs, initial_state=initial_state)
195195
elif self._filter_name == "ukf":
196+
assert isinstance(self._model, NonlinearSSM)
196197
self._filter_result = ukf_filter(
197198
self._model,
198199
obs,
199200
initial_state=initial_state,
200201
**self._filter_kwargs,
201202
)
202203
elif self._filter_name == "particle":
204+
assert isinstance(self._model, NonlinearSSM)
203205
key = self._key if self._key is not None else jax.random.PRNGKey(0)
204206
self._filter_result = particle_filter(
205207
self._model,
@@ -209,6 +211,7 @@ def fit(
209211
**self._filter_kwargs,
210212
)
211213
elif self._filter_name == "hamilton":
214+
assert isinstance(self._model, MarkovSwitchingSSM)
212215
self._filter_result = hamilton_filter(self._model, obs, initial_state=initial_state)
213216

214217
self._is_fitted = True
@@ -229,9 +232,11 @@ def residuals(self) -> Array:
229232
if isinstance(self._model, StateSpaceModel):
230233
from dynaris.estimation.diagnostics import standardized_residuals
231234

235+
assert isinstance(fr, FilterResult)
232236
return standardized_residuals(fr, self._model)
233237

234238
# Nonlinear: compute y - h(predicted_state)
239+
assert isinstance(self._model, NonlinearSSM)
235240
predicted_obs = jax.vmap(self._model.h)(fr.predicted_states)
236241
return fr.observations - predicted_obs
237242

@@ -289,11 +294,13 @@ def _plot_filtered(self, **kwargs: Any) -> Any:
289294
if isinstance(self._model, StateSpaceModel):
290295
from dynaris.plotting.plots import plot_filtered
291296

297+
assert isinstance(fr, FilterResult)
292298
return plot_filtered(fr, self._model, **kwargs)
293299

294300
# Nonlinear: compute observation-space predictions
295301
import matplotlib.pyplot as plt
296302

303+
assert isinstance(self._model, NonlinearSSM)
297304
filtered_obs = jax.vmap(self._model.h)(fr.filtered_states)
298305
obs = np.asarray(fr.observations)
299306
filt = np.asarray(filtered_obs)

src/dynaris/core/switching.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,22 @@ def obs_dim(self) -> int:
125125
@property
126126
def G_stack(self) -> Array: # noqa: N802
127127
"""Stacked system matrices, shape (K, n, n)."""
128-
return self._G_stack # type: ignore[attr-defined]
128+
return self._G_stack # type: ignore[attr-defined, no-any-return]
129129

130130
@property
131131
def F_stack(self) -> Array: # noqa: N802
132132
"""Stacked observation matrices, shape (K, m, n)."""
133-
return self._F_stack # type: ignore[attr-defined]
133+
return self._F_stack # type: ignore[attr-defined, no-any-return]
134134

135135
@property
136136
def W_stack(self) -> Array: # noqa: N802
137137
"""Stacked evolution covariances, shape (K, n, n)."""
138-
return self._W_stack # type: ignore[attr-defined]
138+
return self._W_stack # type: ignore[attr-defined, no-any-return]
139139

140140
@property
141141
def V_stack(self) -> Array: # noqa: N802
142142
"""Stacked observation covariances, shape (K, m, m)."""
143-
return self._V_stack # type: ignore[attr-defined]
143+
return self._V_stack # type: ignore[attr-defined, no-any-return]
144144

145145
# --- Factory methods ---
146146

@@ -185,7 +185,7 @@ def tree_flatten(self) -> tuple[list[Array], dict[str, object]]:
185185
self.transition_matrix,
186186
self.initial_probs,
187187
]
188-
aux = {
188+
aux: dict[str, object] = {
189189
"n_regimes": self.n_regimes,
190190
"state_dim": self.state_dim,
191191
"obs_dim": self.obs_dim,

src/dynaris/estimation/bayesian.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def model_fn(params):
125125
def _log_density(params: Array) -> Array:
126126
model = model_fn(params)
127127
fr = kalman_filter(model, observations)
128-
return fr.log_likelihood + log_prior_fn(params)
128+
return fr.log_likelihood + log_prior_fn(params) # type: ignore[no-any-return]
129129

130130
# NumPyro model: sample unconstrained params, factor by log-density
131131
def _numpyro_model() -> None:
@@ -149,7 +149,7 @@ def _numpyro_model() -> None:
149149
@jax.jit
150150
def _compute_ll(params: Array) -> Array:
151151
model = model_fn(params)
152-
return kalman_filter(model, observations).log_likelihood
152+
return kalman_filter(model, observations).log_likelihood # type: ignore[no-any-return]
153153

154154
log_lls = jax.vmap(_compute_ll)(samples)
155155

src/dynaris/estimation/comparison.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def compute_waic(
7171

7272
def _pw_ll(params: Array) -> Array:
7373
model = model_fn(params)
74-
return _pointwise_log_likelihood(model, observations)
74+
return _pointwise_log_likelihood(model, observations) # type: ignore[no-any-return]
7575

7676
pw_lls = jax.vmap(_pw_ll)(result.samples) # (n_samples, T)
7777

@@ -109,7 +109,7 @@ def compute_loo(
109109

110110
def _pw_ll(params: Array) -> Array:
111111
model = model_fn(params)
112-
return _pointwise_log_likelihood(model, observations)
112+
return _pointwise_log_likelihood(model, observations) # type: ignore[no-any-return]
113113

114114
pw_lls = jax.vmap(_pw_ll)(result.samples) # (n_samples, T)
115115
pw_lls_np = np.asarray(pw_lls)[np.newaxis, :, :] # (1, n_samples, T)
@@ -158,7 +158,7 @@ def to_arviz(
158158

159159
def _pw_ll(params: Array) -> Array:
160160
model = model_fn(params)
161-
return _pointwise_log_likelihood(model, observations)
161+
return _pointwise_log_likelihood(model, observations) # type: ignore[no-any-return]
162162

163163
pw_lls = jax.vmap(_pw_ll)(result.samples)
164164

src/dynaris/estimation/parallel_mcmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def fit_bayesian_parallel(
8686
def _log_density(params: Array) -> Array:
8787
model = model_fn(params)
8888
fr = kalman_filter(model, observations)
89-
return fr.log_likelihood + log_prior_fn(params)
89+
return fr.log_likelihood + log_prior_fn(params) # type: ignore[no-any-return]
9090

9191
def _numpyro_model() -> None:
9292
params = numpyro.sample(
@@ -119,7 +119,7 @@ def _numpyro_model() -> None:
119119
@jax.jit
120120
def _compute_ll(params: Array) -> Array:
121121
model = model_fn(params)
122-
return kalman_filter(model, observations).log_likelihood
122+
return kalman_filter(model, observations).log_likelihood # type: ignore[no-any-return]
123123

124124
log_lls = jax.vmap(_compute_ll)(all_samples)
125125

src/dynaris/estimation/predictive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _forecast_one(params: Array) -> Array:
102102
cov=fr.filtered_covariances[-1],
103103
)
104104
fc = forecast(model, last_state, steps)
105-
return fc.mean # (steps, obs_dim)
105+
return fc.mean # type: ignore[no-any-return] # (steps, obs_dim)
106106

107107
all_forecasts = jax.vmap(_forecast_one)(samples) # (n, steps, obs_dim)
108108

src/dynaris/filters/ekf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def scan(
131131
initial_state: GaussianState | None = None,
132132
) -> FilterResult:
133133
"""Run full forward EKF via jax.lax.scan."""
134-
return ekf_filter(model, observations, initial_state)
134+
return ekf_filter(model, observations, initial_state) # type: ignore[no-any-return]
135135

136136

137137
@jax.jit

src/dynaris/filters/hamilton.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def hamilton_filter(
230230
if initial_state is None:
231231
initial_state = model.initial_state()
232232

233-
return _hamilton_scan(model, observations, initial_state)
233+
return _hamilton_scan(model, observations, initial_state) # type: ignore[no-any-return]
234234

235235

236236
@jax.jit

src/dynaris/filters/particle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def compute_log_weights(particles: Array, observation: Array, model: NonlinearSS
181181
log_det = jnp.linalg.slogdet(model.R)[1]
182182
r_inv_diff = jnp.linalg.solve(model.R, diff.T).T # (N, m)
183183
mahal = jnp.sum(diff * r_inv_diff, axis=-1) # (N,)
184-
return -0.5 * (m * jnp.log(2.0 * jnp.pi) + log_det + mahal)
184+
return -0.5 * (m * jnp.log(2.0 * jnp.pi) + log_det + mahal) # type: ignore[no-any-return]
185185

186186

187187
def _normalize_log_weights(log_weights: Array) -> tuple[Array, Array]:

src/dynaris/filters/ukf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def _ukf_filter_impl(
317317
initial_state = model.initial_state()
318318

319319
weights = compute_weights(model.state_dim, alpha, beta, kappa)
320-
return _ukf_scan(model, observations, initial_state, weights)
320+
return _ukf_scan(model, observations, initial_state, weights) # type: ignore[no-any-return]
321321

322322

323323
@jax.jit

0 commit comments

Comments
 (0)