"""Bayesian estimation — airline passengers with trend + seasonality.

Demonstrates full Bayesian inference for a composed DLM (local linear
trend + seasonal) using NumPyro's NUTS sampler. The airline passenger
data is log-transformed to convert multiplicative seasonality to additive.

This example:
1. Fits a trend + seasonal model to log(airline passengers) via MLE
2. Fits the same model via Bayesian MCMC with a normal prior
3. Compares posterior mean with MLE point estimates
4. Generates posterior predictive forecasts with credible intervals
5. Runs prior predictive checks
"""
| 14 | + |
| 15 | +import jax |
| 16 | +import jax.numpy as jnp |
| 17 | +import matplotlib.pyplot as plt |
| 18 | +import numpy as np |
| 19 | + |
| 20 | +from dynaris import LocalLinearTrend, Seasonal |
| 21 | +from dynaris.core.state_space import StateSpaceModel |
| 22 | +from dynaris.datasets import load_airline |
| 23 | +from dynaris.estimation import fit_bayesian, fit_mle |
| 24 | +from dynaris.estimation.predictive import posterior_predictive_forecast, prior_predictive |
| 25 | +from dynaris.estimation.priors import normal_log_prior |
| 26 | + |
| 27 | +# --- Data (log-transform for additive seasonality) --- |
| 28 | +y = load_airline() |
| 29 | +obs = jnp.log(jnp.array(y.values, dtype=jnp.float32)).reshape(-1, 1) |
| 30 | +print(f"Airline passengers: {len(obs)} monthly observations (log-transformed)") |
| 31 | + |
| 32 | + |
# --- Model factory: trend + seasonal (period=12) ---
# params: [log_sigma_level, log_sigma_slope, log_sigma_seasonal, log_sigma_obs]
def model_fn(params: jax.Array) -> StateSpaceModel:
    """Build the composed DLM from log-scale standard deviations.

    The four entries of ``params`` are exponentiated to recover positive
    noise scales; observation noise lives on the seasonal component, so
    the trend's own sigma_obs is pinned to zero.
    """
    sigmas = jnp.exp(params)
    trend_part = LocalLinearTrend(
        sigma_level=sigmas[0],
        sigma_slope=sigmas[1],
        sigma_obs=0.0,
    )
    seasonal_part = Seasonal(
        period=12,
        sigma_seasonal=sigmas[2],
        sigma_obs=sigmas[3],
    )
    return trend_part + seasonal_part
| 47 | + |
| 48 | + |
# --- MLE fit (baseline) ---
# Point estimates serve both as a sanity baseline and as the MCMC
# initialization below.
print("\nFitting MLE...")
init_params = jnp.array([0.0, -2.0, -2.0, -1.0])
mle_result = fit_mle(model_fn, obs, init_params)
rounded_params = np.round(np.asarray(mle_result.params), 3)
print(f" MLE params: {rounded_params}")
print(f" MLE log-likelihood: {mle_result.log_likelihood:.2f}")
# --- Bayesian fit ---
# NUTS sampling with a wide normal prior on the log-scale parameters,
# initialized at the MLE for faster warmup.
print("\nRunning Bayesian MCMC (NUTS)...")
prior = normal_log_prior(loc=0.0, scale=5.0)
bayes_result = fit_bayesian(
    model_fn,
    obs,
    init_params=mle_result.params,
    log_prior_fn=prior,
    n_warmup=500,
    n_samples=1000,
    key=jax.random.PRNGKey(42),
    param_names=(
        "log_sigma_level",
        "log_sigma_slope",
        "log_sigma_seasonal",
        "log_sigma_obs",
    ),
)

samples = bayes_result.samples
names = bayes_result.param_names or ()
print(" Posterior means:")
# Walk parameter columns alongside their names and MLE counterparts.
for name, draws, mle_val in zip(names, samples.T, mle_result.params):
    print(f" {name}: {float(jnp.mean(draws)):.3f} (MLE: {float(mle_val):.3f})")
mean_ll = float(jnp.mean(bayes_result.log_likelihood_samples))
print(f" Posterior mean log-likelihood: {mean_ll:.2f}")
if bayes_result.info:
    print(f" Divergences: {bayes_result.info.get('n_divergences', 'N/A')}")
| 86 | + |
# --- Posterior predictive forecast (24 months ahead) ---
# Propagates parameter uncertainty by forecasting under a subsample of
# posterior draws.
print("\nGenerating posterior predictive forecast (24 months)...")
fc = posterior_predictive_forecast(
    bayes_result, model_fn, obs, steps=24, n_posterior_samples=200
)
| 96 | + |
# --- Prior predictive ---
print("Running prior predictive check...")


def prior_sample_fn(key: jax.Array) -> jax.Array:
    """Draw one parameter vector: unit-scale normal centered on the MLE."""
    return mle_result.params + jax.random.normal(key, (4,))


prior_sims = prior_predictive(model_fn, prior_sample_fn, n_steps=144, n_samples=50)
| 106 | + |
# --- Plots ---
fig, axes = plt.subplots(2, 2, figsize=(14, 9))
# Name each panel up front instead of reassigning a shared `ax`.
(ax_post, ax_fc), (ax_prior, ax_trace) = axes

# Posterior parameter distributions, with MLE point estimates overlaid
# as dashed vertical lines in the matching color.
for j, name in enumerate(names):
    ax_post.hist(np.asarray(samples[:, j]), bins=25, alpha=0.6, label=name)
    ax_post.axvline(float(mle_result.params[j]), color=f"C{j}", linestyle="--", linewidth=1)
ax_post.set_xlabel("Parameter value (log scale)")
ax_post.set_ylabel("Count")
ax_post.set_title("Posterior Distributions (dashed = MLE)")
ax_post.legend(frameon=False, fontsize=7)

# Posterior predictive forecast, back-transformed with exp() to the
# original passenger scale; only the last 4 years of data are shown.
n_obs = len(obs)
time_obs = np.arange(n_obs)
time_fc = np.arange(n_obs, n_obs + 24)
ax_fc.plot(
    time_obs[-48:],
    np.exp(np.asarray(obs[-48:, 0])),
    "k.-",
    markersize=2,
    label="Observed",
)
ax_fc.plot(
    time_fc,
    np.exp(np.asarray(fc["mean"][:, 0])),
    "C0-",
    linewidth=1.5,
    label="Forecast mean",
)
ax_fc.fill_between(
    time_fc,
    np.exp(np.asarray(fc["lower"][:, 0])),
    np.exp(np.asarray(fc["upper"][:, 0])),
    alpha=0.3,
    color="C0",
    label="95% credible interval",
)
ax_fc.set_xlabel("Month index")
ax_fc.set_ylabel("Passengers")
ax_fc.set_title("Posterior Predictive Forecast (24 months)")
ax_fc.legend(frameon=False, fontsize=8)

# Prior predictive draws (log scale — simulations start from zero).
n_show = min(20, prior_sims.shape[0])
for sim in prior_sims[:n_show]:
    ax_prior.plot(np.asarray(sim[:, 0]), alpha=0.4, linewidth=0.5, color="gray")
ax_prior.axhline(float(jnp.mean(obs)), color="k", linestyle="--", linewidth=0.8, label="Data mean")
ax_prior.set_xlabel("Time step")
ax_prior.set_ylabel("log(passengers)")
ax_prior.set_title("Prior Predictive Samples (log scale)")
ax_prior.legend(frameon=False, fontsize=8)

# Raw MCMC traces for a quick visual mixing check.
for j, name in enumerate(names):
    ax_trace.plot(np.asarray(samples[:, j]), alpha=0.6, linewidth=0.4, label=name)
ax_trace.set_xlabel("Sample index")
ax_trace.set_ylabel("Value")
ax_trace.set_title("MCMC Traces")
ax_trace.legend(frameon=False, fontsize=7)

fig.suptitle("Bayesian Estimation — Airline Passengers (Trend + Seasonality)")
fig.tight_layout()
plt.show()