Skip to content

Commit 5476483

Browse files
committed
Update requirements.txt and format code
1 parent 63d23e3 commit 5476483

3 files changed

Lines changed: 59 additions & 28 deletions

File tree

mimic/model_infer/infer_CRM_bayes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def run_inference(self) -> None:
441441
# Priors for unknown model parameters
442442

443443
sigma = pm.HalfNormal(
444-
'sigma', sigma = 0.5, shape=(
444+
'sigma', sigma=0.5, shape=(
445445
1,)) # Same sigma for all responses
446446

447447
# Conditionally define parameters based on whether priors are

mimic/model_infer/infer_VAR_bayes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _run_inference_large(self, **kwargs) -> None:
279279
c2 = pm.InverseGamma("c2", 2, 8)
280280
tau = pm.HalfCauchy("tau", beta=tau0)
281281
lam = pm.HalfCauchy("lam", beta=1, shape=(ndim, ndim))
282-
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam *
282+
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam * \
283283
at.sqrt(c2 / (c2 + tau**2 * lam**2)), shape=(ndim, ndim))
284284

285285
# If noise covariance is provided, use it as a prior
@@ -438,14 +438,14 @@ def _run_inference_large_xs(self, **kwargs) -> None:
438438
c2_A = pm.InverseGamma("c2_A", 2, 1)
439439
tau_A = pm.HalfCauchy("tau_A", beta=tau0_A)
440440
lam_A = pm.HalfCauchy("lam_A", beta=1, shape=(nX, nX))
441-
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A *
441+
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A * \
442442
at.sqrt(c2_A / (c2_A + tau_A**2 * lam_A**2)), shape=(nX, nX))
443443

444444
tau0_B = (DB0 / (DB - DB0)) * 0.1 / np.sqrt(N)
445445
c2_B = pm.InverseGamma("c2_B", 2, 1)
446446
tau_B = pm.HalfCauchy("tau_B", beta=tau0_B)
447447
lam_B = pm.HalfCauchy("lam_B", beta=1, shape=(nS, nX))
448-
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B *
448+
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B * \
449449
at.sqrt(c2_B / (c2_B + tau_B**2 * lam_B**2)), shape=(nS, nX))
450450

451451
if noise_cov_prior is not None:

mimic/utilities/utilities.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -102,49 +102,80 @@ def plot_CRM(observed_species, observed_resources, timepoints, csv_file=None):
102102

103103
return fig, ax
104104

105-
def plot_CRM_with_intervals(observed_species, observed_resources, species_lower, species_upper,
106-
resource_lower, resource_upper, times, filename=None):
105+
106+
def plot_CRM_with_intervals(
107+
observed_species,
108+
observed_resources,
109+
species_lower,
110+
species_upper,
111+
resource_lower,
112+
resource_upper,
113+
times,
114+
filename=None):
107115
fig, ax = plt.subplots(figsize=(12, 8))
108-
116+
109117
# Plot median trajectories
110118
for i in range(observed_species.shape[1]):
111-
ax.plot(times, observed_species[:, i], label=f'Species {i+1}', linewidth=2)
112-
119+
ax.plot(times, observed_species[:, i],
120+
label=f'Species {i+1}', linewidth=2)
121+
113122
for i in range(observed_resources.shape[1]):
114-
ax.plot(times, observed_resources[:, i], label=f'Resource {i+1}', linewidth=2, linestyle='--')
115-
116-
# Add confidence ribbons
123+
ax.plot(times,
124+
observed_resources[:,
125+
i],
126+
label=f'Resource {i+1}',
127+
linewidth=2,
128+
linestyle='--')
129+
130+
# Add confidence ribbons
117131
for i in range(observed_species.shape[1]):
118-
ax.fill_between(times, species_lower[:, i], species_upper[:, i],
119-
alpha=0.2, color=plt.cm.tab10(i))
120-
132+
ax.fill_between(times, species_lower[:, i], species_upper[:, i],
133+
alpha=0.2, color=plt.cm.tab10(i))
134+
121135
for i in range(observed_resources.shape[1]):
122-
ax.fill_between(times, resource_lower[:, i], resource_upper[:, i],
123-
alpha=0.2, color=plt.cm.tab10(i + observed_species.shape[1]))
124-
136+
ax.fill_between(times,
137+
resource_lower[:,
138+
i],
139+
resource_upper[:,
140+
i],
141+
alpha=0.2,
142+
color=plt.cm.tab10(i + observed_species.shape[1]))
143+
125144
if filename:
126145
true_data = pd.read_csv(filename)
127146
true_times = true_data['time'].values
128-
147+
129148
for i in range(observed_species.shape[1]):
130149
col_name = f'species_{i+1}'
131150
if col_name in true_data.columns:
132-
ax.scatter(true_times, true_data[col_name],
133-
marker='o', s=30, color=plt.cm.tab10(i), label=f'True {col_name}')
134-
151+
ax.scatter(
152+
true_times,
153+
true_data[col_name],
154+
marker='o',
155+
s=30,
156+
color=plt.cm.tab10(i),
157+
label=f'True {col_name}')
158+
135159
for i in range(observed_resources.shape[1]):
136160
col_name = f'resource_{i+1}'
137161
if col_name in true_data.columns:
138-
ax.scatter(true_times, true_data[col_name],
139-
marker='s', s=30, color=plt.cm.tab10(i + observed_species.shape[1]),
140-
label=f'True {col_name}')
141-
162+
ax.scatter(
163+
true_times,
164+
true_data[col_name],
165+
marker='s',
166+
s=30,
167+
color=plt.cm.tab10(
168+
i + observed_species.shape[1]),
169+
label=f'True {col_name}')
170+
142171
ax.set_xlabel('Time', fontsize=14)
143172
ax.set_ylabel('Concentration', fontsize=14)
144-
ax.set_title('Consumer-Resource Model Dynamics with 95% Credible Intervals', fontsize=16)
173+
ax.set_title(
174+
'Consumer-Resource Model Dynamics with 95% Credible Intervals',
175+
fontsize=16)
145176
ax.legend(loc='best', fontsize=12)
146177
ax.grid(True, alpha=0.3)
147-
178+
148179
plt.tight_layout()
149180
if filename:
150181
plt.savefig(f"{filename.split('.')[0]}_with_intervals.png", dpi=300)

0 commit comments

Comments
 (0)