Skip to content

Commit ab5e269

Browse files
committed
feat: dynamic factor models and new examples
1 parent 478033f commit ab5e269

19 files changed

Lines changed: 1862 additions & 1 deletion

docs/api/factor.rst

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
Dynamic Factor Models
2+
=====================
3+
4+
Dynamic Factor Models (DFMs) reduce high-dimensional multivariate time
5+
series to a small number of latent factors. The observation matrix
6+
(loading matrix) maps latent factors to observed variables.
7+
8+
High-Level API
9+
--------------
10+
11+
.. autoclass:: dynaris.models.dfm_api.DFMModel
12+
:members:
13+
:show-inheritance:
14+
15+
Model Factory
16+
-------------
17+
18+
.. autofunction:: dynaris.models.factor.DynamicFactorModel
19+
20+
Estimation
21+
----------
22+
23+
.. autofunction:: dynaris.estimation.dfm.fit_dfm_em
24+
25+
.. autoclass:: dynaris.estimation.dfm.DFMResult
26+
:members:
27+
28+
Utilities
29+
---------
30+
31+
.. autofunction:: dynaris.models.factor.initialize_loadings_pca
32+
33+
.. autofunction:: dynaris.models.factor.rotate_loadings
34+
35+
.. autofunction:: dynaris.models.factor.apply_identification_constraints

docs/api/index.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Complete reference for all public classes and functions in dynaris.
1414
+------------------+-------------------------------------------------------------+
1515
| :doc:`models` | Built-in nonlinear models (stochastic vol, tracking, etc.) |
1616
+------------------+-------------------------------------------------------------+
17+
| :doc:`factor` | Dynamic Factor Models (DFMModel, loadings, rotation) |
18+
+------------------+-------------------------------------------------------------+
1719
| :doc:`core` | ``StateSpaceModel``, ``GaussianState``, result containers |
1820
+------------------+-------------------------------------------------------------+
1921
| :doc:`filters` | Kalman, EKF, UKF, and Particle filters |
@@ -39,6 +41,7 @@ Complete reference for all public classes and functions in dynaris.
3941
dlm
4042
components
4143
models
44+
factor
4245
core
4346
filters
4447
switching

examples/bayesian_nile.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""Bayesian estimation — airline passengers with trend + seasonality.
2+
3+
Demonstrates full Bayesian inference for a composed DLM (local linear
4+
trend + seasonal) using NumPyro's NUTS sampler. The airline passenger
5+
data is log-transformed to convert multiplicative seasonality to additive.
6+
7+
This example:
8+
1. Fits a trend + seasonal model to log(airline passengers) via MLE
9+
2. Fits the same model via Bayesian MCMC with a normal prior
10+
3. Compares posterior mean with MLE point estimates
11+
4. Generates posterior predictive forecasts with credible intervals
12+
5. Runs prior predictive checks
13+
"""
14+
15+
import jax
16+
import jax.numpy as jnp
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
20+
from dynaris import LocalLinearTrend, Seasonal
21+
from dynaris.core.state_space import StateSpaceModel
22+
from dynaris.datasets import load_airline
23+
from dynaris.estimation import fit_bayesian, fit_mle
24+
from dynaris.estimation.predictive import posterior_predictive_forecast, prior_predictive
25+
from dynaris.estimation.priors import normal_log_prior
26+
27+
# --- Data (log-transform for additive seasonality) ---
28+
y = load_airline()
29+
obs = jnp.log(jnp.array(y.values, dtype=jnp.float32)).reshape(-1, 1)
30+
print(f"Airline passengers: {len(obs)} monthly observations (log-transformed)")
31+
32+
33+
# --- Model factory: trend + seasonal (period=12) ---
34+
# params: [log_sigma_level, log_sigma_slope, log_sigma_seasonal, log_sigma_obs]
35+
def model_fn(params: jax.Array) -> StateSpaceModel:
36+
trend = LocalLinearTrend(
37+
sigma_level=jnp.exp(params[0]),
38+
sigma_slope=jnp.exp(params[1]),
39+
sigma_obs=0.0,
40+
)
41+
seasonal = Seasonal(
42+
period=12,
43+
sigma_seasonal=jnp.exp(params[2]),
44+
sigma_obs=jnp.exp(params[3]),
45+
)
46+
return trend + seasonal
47+
48+
49+
# --- MLE fit (baseline) ---
50+
print("\nFitting MLE...")
51+
init_params = jnp.array([0.0, -2.0, -2.0, -1.0])
52+
mle_result = fit_mle(model_fn, obs, init_params)
53+
print(f" MLE params: {np.round(np.asarray(mle_result.params), 3)}")
54+
print(f" MLE log-likelihood: {mle_result.log_likelihood:.2f}")
55+
56+
# --- Bayesian fit ---
57+
print("\nRunning Bayesian MCMC (NUTS)...")
58+
prior = normal_log_prior(loc=0.0, scale=5.0)
59+
bayes_result = fit_bayesian(
60+
model_fn,
61+
obs,
62+
init_params=mle_result.params,
63+
log_prior_fn=prior,
64+
n_warmup=500,
65+
n_samples=1000,
66+
key=jax.random.PRNGKey(42),
67+
param_names=(
68+
"log_sigma_level",
69+
"log_sigma_slope",
70+
"log_sigma_seasonal",
71+
"log_sigma_obs",
72+
),
73+
)
74+
75+
samples = bayes_result.samples
76+
names = bayes_result.param_names or ()
77+
print(" Posterior means:")
78+
for j, name in enumerate(names):
79+
post_mean = float(jnp.mean(samples[:, j]))
80+
mle_val = float(mle_result.params[j])
81+
print(f" {name}: {post_mean:.3f} (MLE: {mle_val:.3f})")
82+
mean_ll = float(jnp.mean(bayes_result.log_likelihood_samples))
83+
print(f" Posterior mean log-likelihood: {mean_ll:.2f}")
84+
if bayes_result.info:
85+
print(f" Divergences: {bayes_result.info.get('n_divergences', 'N/A')}")
86+
87+
# --- Posterior predictive forecast (24 months ahead) ---
88+
print("\nGenerating posterior predictive forecast (24 months)...")
89+
fc = posterior_predictive_forecast(
90+
bayes_result,
91+
model_fn,
92+
obs,
93+
steps=24,
94+
n_posterior_samples=200,
95+
)
96+
97+
# --- Prior predictive ---
98+
print("Running prior predictive check...")
99+
100+
101+
def prior_sample_fn(key: jax.Array) -> jax.Array:
102+
return jax.random.normal(key, (4,)) * 1.0 + mle_result.params
103+
104+
105+
prior_sims = prior_predictive(model_fn, prior_sample_fn, n_steps=144, n_samples=50)
106+
107+
# --- Plots ---
108+
fig, axes = plt.subplots(2, 2, figsize=(14, 9))
109+
110+
# Posterior parameter distributions
111+
ax = axes[0, 0]
112+
for j, name in enumerate(names):
113+
ax.hist(np.asarray(samples[:, j]), bins=25, alpha=0.6, label=name)
114+
ax.axvline(float(mle_result.params[j]), color=f"C{j}", linestyle="--", linewidth=1)
115+
ax.set_xlabel("Parameter value (log scale)")
116+
ax.set_ylabel("Count")
117+
ax.set_title("Posterior Distributions (dashed = MLE)")
118+
ax.legend(frameon=False, fontsize=7)
119+
120+
# Posterior predictive forecast (back-transformed to passenger scale)
121+
ax = axes[0, 1]
122+
n_obs = len(obs)
123+
time_obs = np.arange(n_obs)
124+
time_fc = np.arange(n_obs, n_obs + 24)
125+
ax.plot(
126+
time_obs[-48:],
127+
np.exp(np.asarray(obs[-48:, 0])),
128+
"k.-",
129+
markersize=2,
130+
label="Observed",
131+
)
132+
ax.plot(
133+
time_fc,
134+
np.exp(np.asarray(fc["mean"][:, 0])),
135+
"C0-",
136+
linewidth=1.5,
137+
label="Forecast mean",
138+
)
139+
ax.fill_between(
140+
time_fc,
141+
np.exp(np.asarray(fc["lower"][:, 0])),
142+
np.exp(np.asarray(fc["upper"][:, 0])),
143+
alpha=0.3,
144+
color="C0",
145+
label="95% credible interval",
146+
)
147+
ax.set_xlabel("Month index")
148+
ax.set_ylabel("Passengers")
149+
ax.set_title("Posterior Predictive Forecast (24 months)")
150+
ax.legend(frameon=False, fontsize=8)
151+
152+
# Prior predictive (log scale — simulations start from zero)
153+
ax = axes[1, 0]
154+
for i in range(min(20, prior_sims.shape[0])):
155+
ax.plot(np.asarray(prior_sims[i, :, 0]), alpha=0.4, linewidth=0.5, color="gray")
156+
ax.axhline(float(jnp.mean(obs)), color="k", linestyle="--", linewidth=0.8, label="Data mean")
157+
ax.set_xlabel("Time step")
158+
ax.set_ylabel("log(passengers)")
159+
ax.set_title("Prior Predictive Samples (log scale)")
160+
ax.legend(frameon=False, fontsize=8)
161+
162+
# Trace plot
163+
ax = axes[1, 1]
164+
for j, name in enumerate(names):
165+
ax.plot(np.asarray(samples[:, j]), alpha=0.6, linewidth=0.4, label=name)
166+
ax.set_xlabel("Sample index")
167+
ax.set_ylabel("Value")
168+
ax.set_title("MCMC Traces")
169+
ax.legend(frameon=False, fontsize=7)
170+
171+
fig.suptitle("Bayesian Estimation — Airline Passengers (Trend + Seasonality)")
172+
fig.tight_layout()
173+
plt.show()

examples/bearings_tracking.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Bearings-only tracking — EKF on a nonlinear observation model.
2+
3+
A target moves with near-constant velocity in 2D, observed only via
4+
bearing (angle) from a fixed sensor. The observation function is
5+
nonlinear (atan2), making this a classic EKF/UKF benchmark.
6+
7+
This example:
8+
1. Simulates a target on a curved trajectory
9+
2. Generates noisy bearing measurements from a sensor at the origin
10+
3. Tracks the 4D state (x, vx, y, vy) using the EKF
11+
4. Plots the reconstructed trajectory vs ground truth
12+
"""
13+
14+
import jax
15+
import jax.numpy as jnp
16+
import matplotlib.pyplot as plt
17+
import numpy as np
18+
19+
from dynaris.core.types import GaussianState
20+
from dynaris.filters.ekf import ekf_filter
21+
from dynaris.models import BearingsTracking
22+
23+
# --- Model ---
24+
model = BearingsTracking(sensor_pos=(0.0, 0.0), dt=1.0, sigma_accel=0.01, sigma_bearing=0.05)
25+
26+
# --- Simulate target trajectory (gentle curve) ---
27+
key = jax.random.PRNGKey(7)
28+
n_steps = 100
29+
30+
true_state = jnp.array([5.0, 0.3, 5.0, 0.2])
31+
true_states = []
32+
observations = []
33+
for _ in range(n_steps):
34+
key, k_obs = jax.random.split(key)
35+
true_state = model.f(true_state)
36+
bearing = model.h(true_state)
37+
noisy_bearing = bearing + jax.random.normal(k_obs, (1,)) * 0.05
38+
true_states.append(true_state)
39+
observations.append(noisy_bearing)
40+
true_states = jnp.stack(true_states)
41+
observations = jnp.stack(observations)
42+
43+
# --- Track with EKF ---
44+
init = GaussianState(
45+
mean=jnp.array([5.0, 0.3, 5.0, 0.2]),
46+
cov=jnp.diag(jnp.array([1.0, 0.5, 1.0, 0.5])),
47+
)
48+
result = ekf_filter(model, observations, initial_state=init)
49+
50+
# --- Plot trajectory ---
51+
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
52+
53+
# 2D trajectory
54+
ax = axes[0]
55+
ax.plot(
56+
np.asarray(true_states[:, 0]),
57+
np.asarray(true_states[:, 2]),
58+
"k-",
59+
linewidth=1.5,
60+
label="True path",
61+
)
62+
ax.plot(
63+
np.asarray(result.filtered_states[:, 0]),
64+
np.asarray(result.filtered_states[:, 2]),
65+
"o-",
66+
markersize=2,
67+
linewidth=0.8,
68+
label="EKF estimate",
69+
)
70+
ax.plot(0, 0, "r^", markersize=10, label="Sensor")
71+
ax.set_xlabel("x")
72+
ax.set_ylabel("y")
73+
ax.set_title("2D Target Trajectory")
74+
ax.legend(frameon=False)
75+
ax.set_aspect("equal")
76+
77+
# Bearing observations
78+
ax = axes[1]
79+
ax.plot(np.asarray(observations[:, 0]), "k.", markersize=2, alpha=0.5, label="Observed bearings")
80+
true_bearings = np.asarray(jax.vmap(model.h)(true_states)[:, 0])
81+
ax.plot(true_bearings, "r-", linewidth=1, label="True bearing")
82+
ax.set_xlabel("Time step")
83+
ax.set_ylabel("Bearing (rad)")
84+
ax.set_title("Bearing Observations")
85+
ax.legend(frameon=False)
86+
87+
fig.suptitle("Bearings-Only Tracking with EKF")
88+
fig.tight_layout()
89+
plt.show()

examples/lorenz_tracking.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Lorenz attractor tracking — nonlinear filter comparison.
2+
3+
Demonstrates EKF, UKF, and Particle Filter on the chaotic Lorenz system.
4+
The Lorenz attractor is a classic benchmark for nonlinear filtering,
5+
with three coupled differential equations producing deterministic chaos.
6+
7+
This example:
8+
1. Simulates a trajectory from the Lorenz system
9+
2. Generates noisy partial observations (x and y only)
10+
3. Tracks the full 3D state using EKF, UKF, and Particle Filter
11+
4. Compares filter accuracy via correlation with the true state
12+
"""
13+
14+
import jax
15+
import jax.numpy as jnp
16+
import matplotlib.pyplot as plt
17+
import numpy as np
18+
19+
from dynaris.core.types import GaussianState
20+
from dynaris.filters.ekf import ekf_filter
21+
from dynaris.filters.particle import particle_filter
22+
from dynaris.filters.ukf import ukf_filter
23+
from dynaris.models import LorenzAttractor
24+
25+
# --- Model ---
26+
model = LorenzAttractor(dt=0.01, process_noise=0.5, obs_noise=2.0, obs_dims=2)
27+
28+
# --- Simulate true trajectory ---
29+
key = jax.random.PRNGKey(42)
30+
k1, k2, k_pf = jax.random.split(key, 3)
31+
n_steps = 500
32+
33+
state = jnp.array([1.0, 1.0, 1.0])
34+
true_states = []
35+
observations = []
36+
for _ in range(n_steps):
37+
k1, k_state, k_obs = jax.random.split(k1, 3)
38+
state = model.f(state) + jax.random.normal(k_state, (3,)) * 0.5
39+
obs = model.h(state) + jax.random.normal(k_obs, (2,)) * 2.0
40+
true_states.append(state)
41+
observations.append(obs)
42+
true_states = jnp.stack(true_states)
43+
observations = jnp.stack(observations)
44+
45+
# --- Initial state ---
46+
init = GaussianState(mean=jnp.array([1.0, 1.0, 1.0]), cov=jnp.eye(3) * 10.0)
47+
48+
# --- Run filters ---
49+
print("Running EKF...")
50+
ekf_result = ekf_filter(model, observations, initial_state=init)
51+
52+
print("Running UKF...")
53+
ukf_result = ukf_filter(model, observations, initial_state=init, alpha=1.0)
54+
55+
print("Running Particle Filter (1000 particles)...")
56+
pf_result = particle_filter(model, observations, n_particles=1000, key=k_pf, initial_state=init)
57+
58+
# --- Compare accuracy ---
59+
print("\nFilter accuracy (correlation with true x-component):")
60+
for name, result in [("EKF", ekf_result), ("UKF", ukf_result), ("PF", pf_result)]:
61+
corr = float(jnp.corrcoef(jnp.stack([result.filtered_states[:, 0], true_states[:, 0]]))[0, 1])
62+
print(f" {name}: r = {corr:.4f}")
63+
64+
# --- Plot ---
65+
fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
66+
labels = ["x", "y", "z"]
67+
for i, ax in enumerate(axes):
68+
ax.plot(np.asarray(true_states[:, i]), "k-", alpha=0.4, linewidth=0.8, label="True")
69+
ax.plot(np.asarray(ekf_result.filtered_states[:, i]), label="EKF", linewidth=1.0)
70+
ax.plot(np.asarray(ukf_result.filtered_states[:, i]), label="UKF", linewidth=1.0)
71+
ax.plot(np.asarray(pf_result.filtered_states[:, i]), label="PF", linewidth=1.0)
72+
ax.set_ylabel(labels[i])
73+
if i == 0:
74+
ax.legend(frameon=False, ncol=4)
75+
axes[-1].set_xlabel("Time step")
76+
fig.suptitle("Lorenz Attractor Tracking — EKF vs UKF vs Particle Filter")
77+
fig.tight_layout()
78+
plt.show()

0 commit comments

Comments
 (0)