Skip to content

Commit ca706ef

Browse files
authored
Merge pull request #103 from ucl-cssb/improve-bayes-gLV-examples
Improve bayes gLV examples
2 parents 8da2d66 + 1eb50f7 commit ca706ef

4 files changed

Lines changed: 19 additions & 71 deletions

File tree

examples/gLV/examples-Rutter-Dekker.ipynb

Lines changed: 6 additions & 36 deletions
Large diffs are not rendered by default.

examples/gLV/examples-bayes-gLV.ipynb

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@
247247
}
248248
],
249249
"source": [
250-
"# read in pickled simulated parameters, mu, M, epsilon\n",
250+
"# read in pickled simulated parameters, mu, M, epsilon, created in examples-sim-gLV.ipynb\n",
251251
"num_species = 5\n",
252252
"with open(\"params-s5.pkl\", \"rb\") as f:\n",
253253
" params = pickle.load(f)\n",
@@ -818,20 +818,6 @@
818818
"#inference.plot_posterior(idata)\n",
819819
"\n",
820820
"\n",
821-
"\n",
822-
"\n",
823-
"#nX = num_species\n",
824-
"#n_obs = times.shape[0] - 1\n",
825-
"#noise_stddev = 0.1\n",
826-
"\n",
827-
"# Params for shrinkage on M_ij (non diagonal elements)\n",
828-
"#DA = nX*nX - nX\n",
829-
"#DA0 = 3 # expected number of non zero entries in M_ij\n",
830-
"#N = n_obs - 2\n",
831-
"\n",
832-
"#inference = infergLVbayes(X, F, mu_prior, M_prior, DA=DA, DA0=DA0, N=N, noise_stddev=noise_stddev)\n",
833-
"#idata = inference.run_bayes_gLV_shrinkage()\n",
834-
"\n",
835821
"# print summary\n",
836822
"summary = az.summary(idata, var_names=[\"mu_hat\", \"M_ii_hat\", \"M_ij_hat\", \"M_hat\", \"sigma\"])\n",
837823
"print(summary[[\"mean\", \"sd\", \"r_hat\"]])\n",
@@ -862,7 +848,7 @@
862848
},
863849
{
864850
"cell_type": "code",
865-
"execution_count": 8,
851+
"execution_count": null,
866852
"id": "c6d6c2df",
867853
"metadata": {
868854
"ExecuteTime": {
@@ -1108,20 +1094,6 @@
11081094
"#inference.plot_posterior_pert(idata)\n",
11091095
"\n",
11101096
"\n",
1111-
"\n",
1112-
"\n",
1113-
"#nX = num_species\n",
1114-
"#n_obs = times.shape[0] - 1\n",
1115-
"#noise_stddev = 0.1\n",
1116-
"\n",
1117-
"# Params for shrinkage on M_ij (non diagonal elements)\n",
1118-
"#DA = nX*nX - nX\n",
1119-
"#DA0 = 3 # expected number of non zero entries in M_ij\n",
1120-
"#N = n_obs - 2\n",
1121-
"\n",
1122-
"#inference = infergLVbayes(X, F, mu_prior, M_prior, DA=DA, DA0=DA0, N=N, noise_stddev=noise_stddev, epsilon=epsilon)\n",
1123-
"#idata = inference.run_bayes_gLV_shrinkage_pert()\n",
1124-
"\n",
11251097
"# print summary\n",
11261098
"summary = az.summary(idata, var_names=[\"mu_hat\", \"M_ii_hat\", \"M_ij_hat\", \"M_hat\", \"epsilon_hat\", \"sigma\"])\n",
11271099
"print(summary[[\"mean\", \"sd\", \"r_hat\"]])\n",

mimic/model_infer/infer_VAR_bayes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _run_inference_large(self, **kwargs) -> None:
279279
c2 = pm.InverseGamma("c2", 2, 8)
280280
tau = pm.HalfCauchy("tau", beta=tau0)
281281
lam = pm.HalfCauchy("lam", beta=1, shape=(ndim, ndim))
282-
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam * \
282+
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam *
283283
at.sqrt(c2 / (c2 + tau**2 * lam**2)), shape=(ndim, ndim))
284284

285285
# If noise covariance is provided, use it as a prior
@@ -438,14 +438,14 @@ def _run_inference_large_xs(self, **kwargs) -> None:
438438
c2_A = pm.InverseGamma("c2_A", 2, 1)
439439
tau_A = pm.HalfCauchy("tau_A", beta=tau0_A)
440440
lam_A = pm.HalfCauchy("lam_A", beta=1, shape=(nX, nX))
441-
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A * \
441+
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A *
442442
at.sqrt(c2_A / (c2_A + tau_A**2 * lam_A**2)), shape=(nX, nX))
443443

444444
tau0_B = (DB0 / (DB - DB0)) * 0.1 / np.sqrt(N)
445445
c2_B = pm.InverseGamma("c2_B", 2, 1)
446446
tau_B = pm.HalfCauchy("tau_B", beta=tau0_B)
447447
lam_B = pm.HalfCauchy("lam_B", beta=1, shape=(nS, nX))
448-
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B * \
448+
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B *
449449
at.sqrt(c2_B / (c2_B + tau_B**2 * lam_B**2)), shape=(nS, nX))
450450

451451
if noise_cov_prior is not None:

mimic/utilities/utilities.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,18 @@
99

1010

1111
def plot_gLV(yobs, timepoints):
12-
# fig, axs = plt.subplots(1, 2, layout='constrained')
12+
# fig, axs = plt.subplots(1, 2, layout='constrained') # Optional
13+
# alternative
1314
fig, axs = plt.subplots(1, 1)
1415
for species_idx in range(yobs.shape[1]):
15-
axs.plot(timepoints, yobs[:, species_idx], color=cols[species_idx])
16+
label = f'Species {species_idx + 1}' # Add a label for each species
17+
axs.plot(timepoints, yobs[:, species_idx],
18+
color=cols[species_idx], label=label)
19+
1620
axs.set_xlabel('time')
1721
axs.set_ylabel('[species]')
22+
axs.legend() # Ensure the legend is called on the correct axes
23+
plt.show()
1824

1925

2026
def plot_gMLV(yobs, sobs, timepoints):

0 commit comments

Comments
 (0)