Skip to content

Commit 968da0c

Browse files
authored
Merge branch 'master' into CRM
2 parents cdba9e9 + 2d86c83 commit 968da0c

5 files changed

Lines changed: 76 additions & 95 deletions

File tree

examples/gLV/examples-Rutter-Dekker.ipynb

Lines changed: 6 additions & 36 deletions
Large diffs are not rendered by default.

examples/gLV/examples-bayes-gLV.ipynb

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
}
9191
],
9292
"source": [
93-
"# read in pickled simulated parameters, mu, M, epsilon\n",
93+
"# read in pickled simulated parameters, mu, M, epsilon, created in examples-sim-gLV.ipynb\n",
9494
"num_species = 5\n",
9595
"with open(\"params-s5.pkl\", \"rb\") as f:\n",
9696
" params = pickle.load(f)\n",
@@ -661,20 +661,6 @@
661661
"#inference.plot_posterior(idata)\n",
662662
"\n",
663663
"\n",
664-
"\n",
665-
"\n",
666-
"#nX = num_species\n",
667-
"#n_obs = times.shape[0] - 1\n",
668-
"#noise_stddev = 0.1\n",
669-
"\n",
670-
"# Params for shrinkage on M_ij (non diagonal elements)\n",
671-
"#DA = nX*nX - nX\n",
672-
"#DA0 = 3 # expected number of non zero entries in M_ij\n",
673-
"#N = n_obs - 2\n",
674-
"\n",
675-
"#inference = infergLVbayes(X, F, mu_prior, M_prior, DA=DA, DA0=DA0, N=N, noise_stddev=noise_stddev)\n",
676-
"#idata = inference.run_bayes_gLV_shrinkage()\n",
677-
"\n",
678664
"# print summary\n",
679665
"summary = az.summary(idata, var_names=[\"mu_hat\", \"M_ii_hat\", \"M_ij_hat\", \"M_hat\", \"sigma\"])\n",
680666
"print(summary[[\"mean\", \"sd\", \"r_hat\"]])\n",
@@ -705,7 +691,7 @@
705691
},
706692
{
707693
"cell_type": "code",
708-
"execution_count": 8,
694+
"execution_count": null,
709695
"id": "c6d6c2df",
710696
"metadata": {
711697
"ExecuteTime": {
@@ -951,20 +937,6 @@
951937
"#inference.plot_posterior_pert(idata)\n",
952938
"\n",
953939
"\n",
954-
"\n",
955-
"\n",
956-
"#nX = num_species\n",
957-
"#n_obs = times.shape[0] - 1\n",
958-
"#noise_stddev = 0.1\n",
959-
"\n",
960-
"# Params for shrinkage on M_ij (non diagonal elements)\n",
961-
"#DA = nX*nX - nX\n",
962-
"#DA0 = 3 # expected number of non zero entries in M_ij\n",
963-
"#N = n_obs - 2\n",
964-
"\n",
965-
"#inference = infergLVbayes(X, F, mu_prior, M_prior, DA=DA, DA0=DA0, N=N, noise_stddev=noise_stddev, epsilon=epsilon)\n",
966-
"#idata = inference.run_bayes_gLV_shrinkage_pert()\n",
967-
"\n",
968940
"# print summary\n",
969941
"summary = az.summary(idata, var_names=[\"mu_hat\", \"M_ii_hat\", \"M_ij_hat\", \"M_hat\", \"epsilon_hat\", \"sigma\"])\n",
970942
"print(summary[[\"mean\", \"sd\", \"r_hat\"]])\n",

mimic/model_infer/infer_CRM_bayes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,10 @@ def run_inference(self) -> None:
459459
with bayes_model:
460460
# Priors for unknown model parameters
461461

462+
462463
sigma = pm.HalfNormal('sigma', sigma = 0.1, shape=(1,)) # Same sigma for all responses
463464

465+
464466
# Conditionally define parameters based on whether priors are
465467
# provided
466468

mimic/model_infer/infer_VAR_bayes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _run_inference_large(self, **kwargs) -> None:
279279
c2 = pm.InverseGamma("c2", 2, 8)
280280
tau = pm.HalfCauchy("tau", beta=tau0)
281281
lam = pm.HalfCauchy("lam", beta=1, shape=(ndim, ndim))
282-
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam *
282+
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam * \
283283
at.sqrt(c2 / (c2 + tau**2 * lam**2)), shape=(ndim, ndim))
284284

285285
# If noise covariance is provided, use it as a prior
@@ -438,14 +438,14 @@ def _run_inference_large_xs(self, **kwargs) -> None:
438438
c2_A = pm.InverseGamma("c2_A", 2, 1)
439439
tau_A = pm.HalfCauchy("tau_A", beta=tau0_A)
440440
lam_A = pm.HalfCauchy("lam_A", beta=1, shape=(nX, nX))
441-
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A *
441+
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A * \
442442
at.sqrt(c2_A / (c2_A + tau_A**2 * lam_A**2)), shape=(nX, nX))
443443

444444
tau0_B = (DB0 / (DB - DB0)) * 0.1 / np.sqrt(N)
445445
c2_B = pm.InverseGamma("c2_B", 2, 1)
446446
tau_B = pm.HalfCauchy("tau_B", beta=tau0_B)
447447
lam_B = pm.HalfCauchy("lam_B", beta=1, shape=(nS, nX))
448-
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B *
448+
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B * \
449449
at.sqrt(c2_B / (c2_B + tau_B**2 * lam_B**2)), shape=(nS, nX))
450450

451451
if noise_cov_prior is not None:

mimic/utilities/utilities.py

Lines changed: 63 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,18 @@
1010

1111

1212
def plot_gLV(yobs, timepoints):
13-
# fig, axs = plt.subplots(1, 2, layout='constrained')
13+
# fig, axs = plt.subplots(1, 2, layout='constrained') # Optional
14+
# alternative
1415
fig, axs = plt.subplots(1, 1)
1516
for species_idx in range(yobs.shape[1]):
16-
axs.plot(timepoints, yobs[:, species_idx], color=cols[species_idx])
17+
label = f'Species {species_idx + 1}' # Add a label for each species
18+
axs.plot(timepoints, yobs[:, species_idx],
19+
color=cols[species_idx], label=label)
20+
1721
axs.set_xlabel('time')
1822
axs.set_ylabel('[species]')
23+
axs.legend() # Ensure the legend is called on the correct axes
24+
plt.show()
1925

2026

2127
def plot_CRM(observed_species, observed_resources, timepoints, csv_file=None):
@@ -102,49 +108,80 @@ def plot_CRM(observed_species, observed_resources, timepoints, csv_file=None):
102108

103109
return fig, ax
104110

105-
def plot_CRM_with_intervals(observed_species, observed_resources, species_lower, species_upper,
106-
resource_lower, resource_upper, times, filename=None):
111+
112+
def plot_CRM_with_intervals(
113+
observed_species,
114+
observed_resources,
115+
species_lower,
116+
species_upper,
117+
resource_lower,
118+
resource_upper,
119+
times,
120+
filename=None):
107121
fig, ax = plt.subplots(figsize=(12, 8))
108-
122+
109123
# Plot median trajectories
110124
for i in range(observed_species.shape[1]):
111-
ax.plot(times, observed_species[:, i], label=f'Species {i+1}', linewidth=2)
112-
125+
ax.plot(times, observed_species[:, i],
126+
label=f'Species {i+1}', linewidth=2)
127+
113128
for i in range(observed_resources.shape[1]):
114-
ax.plot(times, observed_resources[:, i], label=f'Resource {i+1}', linewidth=2, linestyle='--')
115-
116-
# Add confidence ribbons
129+
ax.plot(times,
130+
observed_resources[:,
131+
i],
132+
label=f'Resource {i+1}',
133+
linewidth=2,
134+
linestyle='--')
135+
136+
# Add confidence ribbons
117137
for i in range(observed_species.shape[1]):
118-
ax.fill_between(times, species_lower[:, i], species_upper[:, i],
119-
alpha=0.2, color=plt.cm.tab10(i))
120-
138+
ax.fill_between(times, species_lower[:, i], species_upper[:, i],
139+
alpha=0.2, color=plt.cm.tab10(i))
140+
121141
for i in range(observed_resources.shape[1]):
122-
ax.fill_between(times, resource_lower[:, i], resource_upper[:, i],
123-
alpha=0.2, color=plt.cm.tab10(i + observed_species.shape[1]))
124-
142+
ax.fill_between(times,
143+
resource_lower[:,
144+
i],
145+
resource_upper[:,
146+
i],
147+
alpha=0.2,
148+
color=plt.cm.tab10(i + observed_species.shape[1]))
149+
125150
if filename:
126151
true_data = pd.read_csv(filename)
127152
true_times = true_data['time'].values
128-
153+
129154
for i in range(observed_species.shape[1]):
130155
col_name = f'species_{i+1}'
131156
if col_name in true_data.columns:
132-
ax.scatter(true_times, true_data[col_name],
133-
marker='o', s=30, color=plt.cm.tab10(i), label=f'True {col_name}')
134-
157+
ax.scatter(
158+
true_times,
159+
true_data[col_name],
160+
marker='o',
161+
s=30,
162+
color=plt.cm.tab10(i),
163+
label=f'True {col_name}')
164+
135165
for i in range(observed_resources.shape[1]):
136166
col_name = f'resource_{i+1}'
137167
if col_name in true_data.columns:
138-
ax.scatter(true_times, true_data[col_name],
139-
marker='s', s=30, color=plt.cm.tab10(i + observed_species.shape[1]),
140-
label=f'True {col_name}')
141-
168+
ax.scatter(
169+
true_times,
170+
true_data[col_name],
171+
marker='s',
172+
s=30,
173+
color=plt.cm.tab10(
174+
i + observed_species.shape[1]),
175+
label=f'True {col_name}')
176+
142177
ax.set_xlabel('Time', fontsize=14)
143178
ax.set_ylabel('Concentration', fontsize=14)
144-
ax.set_title('Consumer-Resource Model Dynamics with 95% Credible Intervals', fontsize=16)
179+
ax.set_title(
180+
'Consumer-Resource Model Dynamics with 95% Credible Intervals',
181+
fontsize=16)
145182
ax.legend(loc='best', fontsize=12)
146183
ax.grid(True, alpha=0.3)
147-
184+
148185
plt.tight_layout()
149186
if filename:
150187
plt.savefig(f"{filename.split('.')[0]}_with_intervals.png", dpi=300)

0 commit comments

Comments
 (0)