Skip to content

Commit 2d86c83

Browse files
authored
Merge pull request #105 from ucl-cssb/CRM
CRM update
2 parents c069483 + 5476483 commit 2d86c83

4 files changed

Lines changed: 242 additions & 1081 deletions

File tree

examples/CRM/examples-bayes-CRM.ipynb

Lines changed: 150 additions & 1068 deletions
Large diffs are not rendered by default.

examples/CRM/examples-sim-CRM.ipynb

Lines changed: 8 additions & 8 deletions
Large diffs are not rendered by default.

mimic/model_infer/infer_VAR_bayes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _run_inference_large(self, **kwargs) -> None:
279279
c2 = pm.InverseGamma("c2", 2, 8)
280280
tau = pm.HalfCauchy("tau", beta=tau0)
281281
lam = pm.HalfCauchy("lam", beta=1, shape=(ndim, ndim))
282-
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam *
282+
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam * \
283283
at.sqrt(c2 / (c2 + tau**2 * lam**2)), shape=(ndim, ndim))
284284

285285
# If noise covariance is provided, use it as a prior
@@ -438,14 +438,14 @@ def _run_inference_large_xs(self, **kwargs) -> None:
438438
c2_A = pm.InverseGamma("c2_A", 2, 1)
439439
tau_A = pm.HalfCauchy("tau_A", beta=tau0_A)
440440
lam_A = pm.HalfCauchy("lam_A", beta=1, shape=(nX, nX))
441-
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A *
441+
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A * \
442442
at.sqrt(c2_A / (c2_A + tau_A**2 * lam_A**2)), shape=(nX, nX))
443443

444444
tau0_B = (DB0 / (DB - DB0)) * 0.1 / np.sqrt(N)
445445
c2_B = pm.InverseGamma("c2_B", 2, 1)
446446
tau_B = pm.HalfCauchy("tau_B", beta=tau0_B)
447447
lam_B = pm.HalfCauchy("lam_B", beta=1, shape=(nS, nX))
448-
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B *
448+
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B * \
449449
at.sqrt(c2_B / (c2_B + tau_B**2 * lam_B**2)), shape=(nS, nX))
450450

451451
if noise_cov_prior is not None:

mimic/utilities/utilities.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import random
55
import numpy as np
6+
import pandas as pd
67
import matplotlib.pyplot as plt
78

89
cols = ["red", "green", "blue", "royalblue", "orange", "black"]
@@ -45,8 +46,7 @@ def plot_CRM(observed_species, observed_resources, timepoints, csv_file=None):
4546
# Use a different color index for resources (continuing from where
4647
# species left off)
4748
color_idx = observed_species.shape[1] + resource_idx
48-
# Use modulo to cycle through colors if we have more entities than
49-
# colors
49+
5050
color_idx = color_idx % len(cols)
5151

5252
label = f'Resource {resource_idx + 1}'
@@ -109,6 +109,85 @@ def plot_CRM(observed_species, observed_resources, timepoints, csv_file=None):
109109
return fig, ax
110110

111111

112+
def plot_CRM_with_intervals(
113+
observed_species,
114+
observed_resources,
115+
species_lower,
116+
species_upper,
117+
resource_lower,
118+
resource_upper,
119+
times,
120+
filename=None):
121+
fig, ax = plt.subplots(figsize=(12, 8))
122+
123+
# Plot median trajectories
124+
for i in range(observed_species.shape[1]):
125+
ax.plot(times, observed_species[:, i],
126+
label=f'Species {i+1}', linewidth=2)
127+
128+
for i in range(observed_resources.shape[1]):
129+
ax.plot(times,
130+
observed_resources[:,
131+
i],
132+
label=f'Resource {i+1}',
133+
linewidth=2,
134+
linestyle='--')
135+
136+
# Add confidence ribbons
137+
for i in range(observed_species.shape[1]):
138+
ax.fill_between(times, species_lower[:, i], species_upper[:, i],
139+
alpha=0.2, color=plt.cm.tab10(i))
140+
141+
for i in range(observed_resources.shape[1]):
142+
ax.fill_between(times,
143+
resource_lower[:,
144+
i],
145+
resource_upper[:,
146+
i],
147+
alpha=0.2,
148+
color=plt.cm.tab10(i + observed_species.shape[1]))
149+
150+
if filename:
151+
true_data = pd.read_csv(filename)
152+
true_times = true_data['time'].values
153+
154+
for i in range(observed_species.shape[1]):
155+
col_name = f'species_{i+1}'
156+
if col_name in true_data.columns:
157+
ax.scatter(
158+
true_times,
159+
true_data[col_name],
160+
marker='o',
161+
s=30,
162+
color=plt.cm.tab10(i),
163+
label=f'True {col_name}')
164+
165+
for i in range(observed_resources.shape[1]):
166+
col_name = f'resource_{i+1}'
167+
if col_name in true_data.columns:
168+
ax.scatter(
169+
true_times,
170+
true_data[col_name],
171+
marker='s',
172+
s=30,
173+
color=plt.cm.tab10(
174+
i + observed_species.shape[1]),
175+
label=f'True {col_name}')
176+
177+
ax.set_xlabel('Time', fontsize=14)
178+
ax.set_ylabel('Concentration', fontsize=14)
179+
ax.set_title(
180+
'Consumer-Resource Model Dynamics with 95% Credible Intervals',
181+
fontsize=16)
182+
ax.legend(loc='best', fontsize=12)
183+
ax.grid(True, alpha=0.3)
184+
185+
plt.tight_layout()
186+
if filename:
187+
plt.savefig(f"{filename.split('.')[0]}_with_intervals.png", dpi=300)
188+
plt.show()
189+
190+
112191
def plot_gMLV(yobs, sobs, timepoints):
113192
# fig, axs = plt.subplots(1, 2, layout='constrained')
114193
fig, axs = plt.subplots(1, 2)

0 commit comments

Comments
 (0)