Skip to content

Commit 885cd5b

Browse files
committed
Update requirements.txt and format code
1 parent 94ddc23 commit 885cd5b

2 files changed

Lines changed: 43 additions & 30 deletions

File tree

mimic/model_infer/infer_VAR_bayes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def _run_inference_large(self, **kwargs) -> None:
279279
c2 = pm.InverseGamma("c2", 2, 8)
280280
tau = pm.HalfCauchy("tau", beta=tau0)
281281
lam = pm.HalfCauchy("lam", beta=1, shape=(ndim, ndim))
282-
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam * \
282+
A = pm.Normal('A', mu=A_prior_mu, sigma=tau * lam *
283283
at.sqrt(c2 / (c2 + tau**2 * lam**2)), shape=(ndim, ndim))
284284

285285
# If noise covariance is provided, use it as a prior
@@ -438,14 +438,14 @@ def _run_inference_large_xs(self, **kwargs) -> None:
438438
c2_A = pm.InverseGamma("c2_A", 2, 1)
439439
tau_A = pm.HalfCauchy("tau_A", beta=tau0_A)
440440
lam_A = pm.HalfCauchy("lam_A", beta=1, shape=(nX, nX))
441-
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A * \
441+
Ah = pm.Normal('Ah', mu=A_prior_mu, sigma=tau_A * lam_A *
442442
at.sqrt(c2_A / (c2_A + tau_A**2 * lam_A**2)), shape=(nX, nX))
443443

444444
tau0_B = (DB0 / (DB - DB0)) * 0.1 / np.sqrt(N)
445445
c2_B = pm.InverseGamma("c2_B", 2, 1)
446446
tau_B = pm.HalfCauchy("tau_B", beta=tau0_B)
447447
lam_B = pm.HalfCauchy("lam_B", beta=1, shape=(nS, nX))
448-
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B * \
448+
Bh = pm.Normal('Bh', mu=0, sigma=tau_B * lam_B *
449449
at.sqrt(c2_B / (c2_B + tau_B**2 * lam_B**2)), shape=(nS, nX))
450450

451451
if noise_cov_prior is not None:

mimic/model_infer/infer_gLV_bayes.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@ def plot_params(mu_h, M_h, e_h, nsp):
4444
plt.stem(np.arange(0, nsp), np.array(e_h), markerfmt="D")
4545

4646

47-
48-
49-
50-
5147
class infergLVbayes(BaseInfer):
5248
"""
5349
bayes_gLV class for Bayesian inference of gLV models without shrinkage priors
@@ -67,27 +63,34 @@ class infergLVbayes(BaseInfer):
6763
"""
6864

6965
def __init__(self,
70-
X = None,
71-
F = None,
72-
prior_mu_mean = None,
73-
prior_mu_sigma = None,
74-
prior_Mii_mean = None,
75-
prior_Mii_sigma = None,
76-
prior_Mij_sigma = None
66+
X=None,
67+
F=None,
68+
prior_mu_mean=None,
69+
prior_mu_sigma=None,
70+
prior_Mii_mean=None,
71+
prior_Mii_sigma=None,
72+
prior_Mij_sigma=None
7773
):
7874

7975
# self.data = data # data to do inference on
8076
self.X: Optional[np.ndarray] = X
8177
self.F: Optional[np.ndarray] = F
8278
self.mu: Optional[Union[int, float]] = None
8379
self.M: Optional[Union[int, float]] = None
84-
self.prior_mu_mean: Optional[Union[int, float, List[Union[int, float]]]] = prior_mu_mean
85-
self.prior_mu_sigma: Optional[Union[int, float, List[Union[int, float]]]] = prior_mu_sigma
86-
self.prior_Mii_mean: Optional[Union[int, float, List[Union[int, float]]]] = prior_Mii_mean
87-
self.prior_Mii_sigma: Optional[Union[int, float, List[Union[int, float]]]] = prior_Mii_sigma
88-
self.prior_Mij_sigma: Optional[Union[int, float, List[Union[int, float]]]] = prior_Mij_sigma
89-
self.prior_eps_mean: Optional[Union[int, float, List[Union[int, float]]]] = None
90-
self.prior_eps_sigma: Optional[Union[int, float, List[Union[int, float]]]] = None
80+
self.prior_mu_mean: Optional[Union[int, float,
81+
List[Union[int, float]]]] = prior_mu_mean
82+
self.prior_mu_sigma: Optional[Union[int, float,
83+
List[Union[int, float]]]] = prior_mu_sigma
84+
self.prior_Mii_mean: Optional[Union[int, float,
85+
List[Union[int, float]]]] = prior_Mii_mean
86+
self.prior_Mii_sigma: Optional[Union[int, float,
87+
List[Union[int, float]]]] = prior_Mii_sigma
88+
self.prior_Mij_sigma: Optional[Union[int, float,
89+
List[Union[int, float]]]] = prior_Mij_sigma
90+
self.prior_eps_mean: Optional[Union[int,
91+
float, List[Union[int, float]]]] = None
92+
self.prior_eps_sigma: Optional[Union[int,
93+
float, List[Union[int, float]]]] = None
9194
self.draws: Optional[int] = None
9295
self.tune: Optional[int] = None
9396
self.chains: Optional[int] = None
@@ -315,8 +318,12 @@ def run_inference(self) -> None:
315318
# print(f"Initial parameter values: {initial_values}")
316319

317320
# Posterior distribution
318-
idata = pm.sample(draws=draws, tune=tune, chains=chains, cores=cores, progressbar=True)
319-
321+
idata = pm.sample(
322+
draws=draws,
323+
tune=tune,
324+
chains=chains,
325+
cores=cores,
326+
progressbar=True)
320327

321328
return idata
322329

@@ -386,11 +393,19 @@ def run_inference_shrinkage(self) -> None:
386393
tau0 = (DA0 / (DA - DA0)) * noise_stddev / np.sqrt(N)
387394
c2 = pm.InverseGamma("c2", 2, 1)
388395
tau = pm.HalfCauchy("tau", beta=tau0)
389-
lam = pm.HalfCauchy("lam", beta=1, shape=(num_species, num_species - 1))
390-
M_ij_hat = pm.Normal('M_ij_hat', mu=0, sigma=tau * lam *
391-
at.sqrt(c2 / (c2 + tau ** 2 * lam ** 2)), shape=(num_species,
392-
num_species-1))
393-
#M_ij_hat = pm.Normal('M_ij_hat', mu=0, sigma=prior_Mij_sigma, shape=(num_species, num_species - 1)) # different shape for off-diagonal
396+
lam = pm.HalfCauchy(
397+
"lam", beta=1, shape=(
398+
num_species, num_species - 1))
399+
M_ij_hat = pm.Normal('M_ij_hat', mu=0, sigma=tau *
400+
lam *
401+
at.sqrt(c2 /
402+
(c2 +
403+
tau ** 2 *
404+
lam ** 2)), shape=(num_species, num_species -
405+
1))
406+
# M_ij_hat = pm.Normal('M_ij_hat', mu=0, sigma=prior_Mij_sigma,
407+
# shape=(num_species, num_species - 1)) # different shape for
408+
# off-diagonal
394409

395410
# Combine values
396411
# start with an all-zero matrix of the correct shape
@@ -730,8 +745,6 @@ def curve_compare(idata, F, times, yobs, init_species_start, sim_gLV_class):
730745
plot_fit_gLV(yobs, yobs_h, times)
731746

732747

733-
734-
735748
def param_data_compare_pert(
736749
idata,
737750
F,

0 commit comments

Comments (0)