Skip to content

Commit 8da2d66

Browse files
authored
All tests passed!
Improve bayes g lv examples
2 parents 6080583 + 35c5af0 commit 8da2d66

5 files changed

Lines changed: 482 additions & 410 deletions

File tree

examples/gLV/examples-Rutter-Dekker.ipynb

Lines changed: 145 additions & 47 deletions
Large diffs are not rendered by default.

examples/gLV/examples-bayes-gLV.ipynb

Lines changed: 281 additions & 298 deletions
Large diffs are not rendered by default.

mimic/model_infer/infer_VAR_bayes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _run_inference_large(self, **kwargs) -> None:
279279
c2 = pm.InverseGamma("c2", 2, 8)
280280
tau = pm.HalfCauchy("tau", beta=tau0)
281281
lam = pm.HalfCauchy("lam", beta=1, shape=(ndim, ndim))
282-
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam *
282+
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam * \
283283
at.sqrt(c2 / (c2 + tau**2 * lam**2)), shape=(ndim, ndim))
284284

285285
# If noise covariance is provided, use it as a prior
@@ -438,14 +438,14 @@ def _run_inference_large_xs(self, **kwargs) -> None:
438438
c2_A = pm.InverseGamma("c2_A", 2, 1)
439439
tau_A = pm.HalfCauchy("tau_A", beta=tau0_A)
440440
lam_A = pm.HalfCauchy("lam_A", beta=1, shape=(nX, nX))
441-
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A *
441+
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A * \
442442
at.sqrt(c2_A / (c2_A + tau_A**2 * lam_A**2)), shape=(nX, nX))
443443

444444
tau0_B = (DB0 / (DB - DB0)) * 0.1 / np.sqrt(N)
445445
c2_B = pm.InverseGamma("c2_B", 2, 1)
446446
tau_B = pm.HalfCauchy("tau_B", beta=tau0_B)
447447
lam_B = pm.HalfCauchy("lam_B", beta=1, shape=(nS, nX))
448-
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B *
448+
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B * \
449449
at.sqrt(c2_B / (c2_B + tau_B**2 * lam_B**2)), shape=(nS, nX))
450450

451451
if noise_cov_prior is not None:

mimic/model_infer/infer_gLV_bayes.py

Lines changed: 41 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,19 @@
22
import matplotlib.pyplot as plt
33
import numpy as np
44
import pymc as pm
5+
import pandas as pd
6+
import seaborn as sns
57
import pytensor.tensor as at
68
import pickle
79
import cloudpickle
10+
import os
11+
from typing import Optional, Union, List, Dict, Any
812

913
from mimic.utilities import *
1014
from mimic.model_simulate.sim_gLV import *
1115
from mimic.model_infer.base_infer import BaseInfer
1216

13-
from mimic.model_infer.base_infer import BaseInfer
14-
15-
import os
16-
from typing import Optional, Union, List, Dict, Any
17-
18-
19-
import pandas as pd
20-
import numpy as np
21-
import seaborn as sns
22-
import matplotlib.pyplot as plt
23-
2417

25-
# Used in examples-Stein.ipynb
2618
def plot_params(mu_h, M_h, e_h, nsp):
2719
print("\ninferred params:")
2820
print("mu_hat/mu:")
@@ -69,8 +61,9 @@ def __init__(self,
6961
prior_mu_sigma=None,
7062
prior_Mii_mean=None,
7163
prior_Mii_sigma=None,
72-
prior_Mij_sigma=None
73-
):
64+
prior_Mij_sigma=None):
65+
66+
super().__init__() # Call base class constructor
7467

7568
# self.data = data # data to do inference on
7669
self.X: Optional[np.ndarray] = X
@@ -217,7 +210,7 @@ def calculate_DA0(self, num_species, proportion=0.15):
217210
DA0 = int(round(expected_non_zero_elements))
218211
return max(DA0, 1)
219212

220-
def run_inference(self) -> None:
213+
def run_inference(self, **kwargs) -> None:
221214
"""
222215
This function infers the parameters for the Bayesian gLV model
223216
@@ -314,9 +307,6 @@ def run_inference(self) -> None:
314307
# eg
315308
# print(f"mu_hat: {mu_hat.eval()}")
316309

317-
# initial_values = bayes_model.initial_point()
318-
# print(f"Initial parameter values: {initial_values}")
319-
320310
# Posterior distribution
321311
idata = pm.sample(
322312
draws=draws,
@@ -327,7 +317,7 @@ def run_inference(self) -> None:
327317

328318
return idata
329319

330-
def run_inference_shrinkage(self) -> None:
320+
def run_inference_shrinkage(self, **kwargs) -> None:
331321
"""
332322
This function infers the parameters for the Bayesian gLV model with Horseshoe prior for shrinkage
333323
@@ -426,14 +416,15 @@ def run_inference_shrinkage(self) -> None:
426416
Y_obs = pm.Normal('Y_obs', mu=model_mean, sigma=sigma, observed=F)
427417

428418
# For debugging:
419+
# print if `debug` is set to 'high' or 'low'
420+
if self.debug in ["high", "low"]:
421+
initial_values = bayes_model.initial_point()
422+
print(f"Initial parameter values: {initial_values}")
429423

430424
# As tensor objects are symbolic, if needed print using .eval()
431425
# eg
432426
# print(f"mu_hat: {mu_hat.eval()}")
433427

434-
# initial_values = bayes_model.initial_point()
435-
# print(f"Initial parameter values: {initial_values}")
436-
437428
# Posterior distribution
438429
idata = pm.sample(
439430
draws=draws,
@@ -443,7 +434,7 @@ def run_inference_shrinkage(self) -> None:
443434

444435
return idata
445436

446-
def run_inference_shrinkage_pert(self) -> None:
437+
def run_inference_shrinkage_pert(self, **kwargs) -> None:
447438
"""
448439
This function infers the parameters for the Bayesian gLV model with Horseshoe prior for shrinkage
449440
@@ -525,10 +516,8 @@ def run_inference_shrinkage_pert(self) -> None:
525516
lam = pm.HalfCauchy(
526517
"lam", beta=1, shape=(
527518
num_species, num_species - 1))
528-
M_ij_hat = pm.Normal('M_ij_hat', mu=prior_Mij_sigma, sigma=tau * lam *
529-
at.sqrt(c2 / (c2 + tau ** 2 * lam ** 2)),
530-
shape=(num_species,
531-
num_species - 1))
519+
M_ij_hat = pm.Normal('M_ij_hat', mu=prior_Mij_sigma, sigma=tau * lam * at.sqrt(
520+
c2 / (c2 + tau ** 2 * lam ** 2)), shape=(num_species, num_species - 1))
532521
# M_ij_hat = pm.Normal('M_ij_hat', mu=0, sigma=prior_Mij_sigma,
533522
# shape=(num_species, num_species - 1)) # different shape for
534523
# off-diagonal
@@ -553,14 +542,15 @@ def run_inference_shrinkage_pert(self) -> None:
553542
Y_obs = pm.Normal('Y_obs', mu=model_mean, sigma=sigma, observed=F)
554543

555544
# For debugging:
545+
# print if `debug` is set to 'high' or 'low'
546+
if self.debug in ["high", "low"]:
547+
initial_values = bayes_model.initial_point()
548+
print(f"Initial parameter values: {initial_values}")
556549

557550
# As tensor objects are symbolic, if needed print using .eval()
558551
# eg
559552
# print(f"mu_hat: {mu_hat.eval()}")
560553

561-
# initial_values = bayes_model.initial_point()
562-
# print(f"Initial parameter values: {initial_values}")
563-
564554
# Posterior distribution
565555
idata = pm.sample(
566556
draws=draws,
@@ -587,17 +577,15 @@ def plot_posterior(self, idata):
587577
az.plot_posterior(
588578
idata,
589579
var_names=["mu_hat"],
590-
ref_val=mu_hat_np.tolist()
591-
)
580+
ref_val=mu_hat_np.tolist())
592581
plt.savefig("plot-posterior-mu.pdf")
593582
plt.show()
594583
plt.close()
595584

596585
az.plot_posterior(
597586
idata,
598587
var_names=["M_ii_hat"],
599-
ref_val=np.diag(M_hat_np).tolist()
600-
)
588+
ref_val=np.diag(M_hat_np).tolist())
601589
plt.savefig("plot-posterior-Mii.pdf")
602590
plt.show()
603591
plt.close()
@@ -607,8 +595,7 @@ def plot_posterior(self, idata):
607595
az.plot_posterior(
608596
idata,
609597
var_names=["M_ij_hat"],
610-
ref_val=M_ij.flatten().tolist()
611-
)
598+
ref_val=M_ij.flatten().tolist())
612599
plt.savefig("plot-posterior-Mij.pdf")
613600
plt.show()
614601
plt.close()
@@ -676,6 +663,11 @@ def plot_interaction_matrix(self, M, M_h):
676663
color='white')
677664

678665

666+
######
667+
######
668+
######
669+
670+
679671
def param_data_compare(
680672
idata,
681673
F,
@@ -735,26 +727,24 @@ def curve_compare(idata, F, times, yobs, init_species_start, sim_gLV_class):
735727
# mu_h = idata.posterior['mu_hat'].mean(dim=('chain', 'draw')).values.flatten()
736728
# M_h= idata.posterior['M_hat'].mean(dim=('chain', 'draw')).values
737729

738-
predictor = sim_gLV(num_species=num_species,
739-
M=M_h.T,
740-
mu=mu_h
741-
)
730+
predictor = sim_gLV(num_species=num_species, M=M_h.T, mu=mu_h)
742731
yobs_h, _, _, _, _ = predictor.simulate(
743732
times=times, init_species=init_species)
744733

745734
plot_fit_gLV(yobs, yobs_h, times)
746735

747736

748737
def param_data_compare_pert(
749-
idata,
750-
F,
751-
mu,
752-
M,
753-
epsilon,
754-
num_perturbations,
755-
times,
756-
yobs,
757-
init_species_start,
738+
idata,
739+
F,
740+
mu,
741+
M,
742+
u,
743+
epsilon,
744+
num_perturbations,
745+
times,
746+
yobs,
747+
init_species_start,
758748
sim_gLV_class):
759749
# az.to_netcdf(idata, 'model_posterior.nc')
760750
# Compare model parameters to the data
@@ -791,9 +781,9 @@ def param_data_compare_pert(
791781
epsilon=epsilon)
792782

793783
yobs, init_species, mu, M, _ = simulator.simulate(
794-
times=times, init_species=init_species, u=pert_fn)
784+
times=times, init_species=init_species, u=u)
795785
yobs_h, _, _, _, _ = predictor.simulate(
796-
times=times, init_species=init_species, u=pert_fn)
786+
times=times, init_species=init_species, u=u)
797787

798788
plot_fit_gLV(yobs, yobs_h, times)
799789

tests/test_infer_gLV_bayes.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ def setup_data(request):
2020
X = np.random.randn(100, num_species+1) # Random design matrix for 100 samples and 2 species
2121
F = np.random.randn(100, num_species) # Random data matrix for 100 samples and 2 species
2222

23-
# Add a condition for different shapes
24-
if 'test_run_inference' in request.node.nodeid or 'test_run_gLV_bayes_shrinkage' in request.node.nodeid:
25-
X = np.random.randn(100, num_species + 1) # Shape (100, num_species + 1)
26-
elif 'test_run_bayes_gLV_shrinkage_pert' in request.node.nodeid or 'test_plot_posterior_pert' in request.node.nodeid:
27-
X = np.random.randn(100, num_species + 2) # Shape (100, num_species + 2)
23+
# Add a condition for different shapes depending on number of parameters
24+
test_name = request.node.nodeid.split("::")[-1]
25+
if test_name in ["test_run_inference", "test_run_inference_shrinkage"]:
26+
X = np.random.randn(100, num_species + 1)
27+
elif test_name in ["test_run_inference_shrinkage_pert", "test_plot_posterior_pert"]:
28+
X = np.random.randn(100, num_species + 2)
2829

2930

3031
prior_mu_mean = 0
@@ -117,12 +118,12 @@ def test_run_inference(bayes_gLV_instance):
117118
assert len(idata.posterior["M_hat"]) > 0, "'M_hat' has no samples."
118119

119120

120-
def test_run_bayes_gLV_shrinkage(bayes_gLV_instance):
121+
def test_run_inference_shrinkage(bayes_gLV_instance):
121122
"""
122123
Test the `run_bayes_gLV_shrinkage` function to check if it returns the correct output.
123124
"""
124125
# Call the method to test
125-
idata = bayes_gLV_instance.run_bayes_gLV_shrinkage()
126+
idata = bayes_gLV_instance.run_inference_shrinkage()
126127

127128
# Check that the output is an ArviZ InferenceData object
128129
assert isinstance(idata, az.InferenceData), "The output is not an InferenceData object."
@@ -132,12 +133,12 @@ def test_run_bayes_gLV_shrinkage(bayes_gLV_instance):
132133
assert "M_hat" in idata.posterior, "'M_hat' is not in the posterior."
133134

134135

135-
def test_run_bayes_gLV_shrinkage_pert(bayes_gLV_instance):
136+
def test_run_inference_shrinkage_pert(bayes_gLV_instance):
136137
"""
137138
Test the `run_bayes_gLV_shrinkage_pert` function to check if it returns the correct output.
138139
"""
139140
# Call the method to test
140-
idata = bayes_gLV_instance.run_bayes_gLV_shrinkage_pert()
141+
idata = bayes_gLV_instance.run_inference_shrinkage_pert()
141142

142143
# Check that the output is an ArviZ InferenceData object
143144
assert isinstance(idata, az.InferenceData), "The output is not an InferenceData object."
@@ -153,7 +154,7 @@ def test_plot_posterior(bayes_gLV_instance):
153154
Test the `plot_posterior` function to ensure it runs without errors.
154155
"""
155156
# Generate the InferenceData from the method
156-
idata = bayes_gLV_instance.run_bayes_gLV_shrinkage()
157+
idata = bayes_gLV_instance.run_inference_shrinkage()
157158

158159
# Try to call the plot_posterior method
159160
try:
@@ -167,7 +168,7 @@ def test_plot_posterior_pert(bayes_gLV_instance):
167168
Test the `plot_posterior_pert` function to ensure it produces the correct plot.
168169
"""
169170
# First, generate the InferenceData from the method
170-
idata = bayes_gLV_instance.run_bayes_gLV_shrinkage_pert()
171+
idata = bayes_gLV_instance.run_inference_shrinkage_pert()
171172

172173
# Try to call the plot_posterior_pert method
173174
try:

0 commit comments

Comments
 (0)