@@ -44,10 +44,6 @@ def plot_params(mu_h, M_h, e_h, nsp):
4444 plt .stem (np .arange (0 , nsp ), np .array (e_h ), markerfmt = "D" )
4545
4646
47-
48-
49-
50-
5147class infergLVbayes (BaseInfer ):
5248 """
5349 bayes_gLV class for Bayesian inference of gLV models without shrinkage priors
@@ -67,27 +63,34 @@ class infergLVbayes(BaseInfer):
6763 """
6864
6965 def __init__ (self ,
70- X = None ,
71- F = None ,
72- prior_mu_mean = None ,
73- prior_mu_sigma = None ,
74- prior_Mii_mean = None ,
75- prior_Mii_sigma = None ,
76- prior_Mij_sigma = None
66+ X = None ,
67+ F = None ,
68+ prior_mu_mean = None ,
69+ prior_mu_sigma = None ,
70+ prior_Mii_mean = None ,
71+ prior_Mii_sigma = None ,
72+ prior_Mij_sigma = None
7773 ):
7874
7975 # self.data = data # data to do inference on
8076 self .X : Optional [np .ndarray ] = X
8177 self .F : Optional [np .ndarray ] = F
8278 self .mu : Optional [Union [int , float ]] = None
8379 self .M : Optional [Union [int , float ]] = None
84- self .prior_mu_mean : Optional [Union [int , float , List [Union [int , float ]]]] = prior_mu_mean
85- self .prior_mu_sigma : Optional [Union [int , float , List [Union [int , float ]]]] = prior_mu_sigma
86- self .prior_Mii_mean : Optional [Union [int , float , List [Union [int , float ]]]] = prior_Mii_mean
87- self .prior_Mii_sigma : Optional [Union [int , float , List [Union [int , float ]]]] = prior_Mii_sigma
88- self .prior_Mij_sigma : Optional [Union [int , float , List [Union [int , float ]]]] = prior_Mij_sigma
89- self .prior_eps_mean : Optional [Union [int , float , List [Union [int , float ]]]] = None
90- self .prior_eps_sigma : Optional [Union [int , float , List [Union [int , float ]]]] = None
80+ self .prior_mu_mean : Optional [Union [int , float ,
81+ List [Union [int , float ]]]] = prior_mu_mean
82+ self .prior_mu_sigma : Optional [Union [int , float ,
83+ List [Union [int , float ]]]] = prior_mu_sigma
84+ self .prior_Mii_mean : Optional [Union [int , float ,
85+ List [Union [int , float ]]]] = prior_Mii_mean
86+ self .prior_Mii_sigma : Optional [Union [int , float ,
87+ List [Union [int , float ]]]] = prior_Mii_sigma
88+ self .prior_Mij_sigma : Optional [Union [int , float ,
89+ List [Union [int , float ]]]] = prior_Mij_sigma
90+ self .prior_eps_mean : Optional [Union [int ,
91+ float , List [Union [int , float ]]]] = None
92+ self .prior_eps_sigma : Optional [Union [int ,
93+ float , List [Union [int , float ]]]] = None
9194 self .draws : Optional [int ] = None
9295 self .tune : Optional [int ] = None
9396 self .chains : Optional [int ] = None
@@ -315,8 +318,12 @@ def run_inference(self) -> None:
315318 # print(f"Initial parameter values: {initial_values}")
316319
317320 # Posterior distribution
318- idata = pm .sample (draws = draws , tune = tune , chains = chains , cores = cores , progressbar = True )
319-
321+ idata = pm .sample (
322+ draws = draws ,
323+ tune = tune ,
324+ chains = chains ,
325+ cores = cores ,
326+ progressbar = True )
320327
321328 return idata
322329
@@ -386,11 +393,19 @@ def run_inference_shrinkage(self) -> None:
386393 tau0 = (DA0 / (DA - DA0 )) * noise_stddev / np .sqrt (N )
387394 c2 = pm .InverseGamma ("c2" , 2 , 1 )
388395 tau = pm .HalfCauchy ("tau" , beta = tau0 )
389- lam = pm .HalfCauchy ("lam" , beta = 1 , shape = (num_species , num_species - 1 ))
390- M_ij_hat = pm .Normal ('M_ij_hat' , mu = 0 , sigma = tau * lam *
391- at .sqrt (c2 / (c2 + tau ** 2 * lam ** 2 )), shape = (num_species ,
392- num_species - 1 ))
393- #M_ij_hat = pm.Normal('M_ij_hat', mu=0, sigma=prior_Mij_sigma, shape=(num_species, num_species - 1)) # different shape for off-diagonal
396+ lam = pm .HalfCauchy (
397+ "lam" , beta = 1 , shape = (
398+ num_species , num_species - 1 ))
399+ M_ij_hat = pm .Normal ('M_ij_hat' , mu = 0 , sigma = tau *
400+ lam *
401+ at .sqrt (c2 /
402+ (c2 +
403+ tau ** 2 *
404+ lam ** 2 )), shape = (num_species , num_species -
405+ 1 ))
406+ # M_ij_hat = pm.Normal('M_ij_hat', mu=0, sigma=prior_Mij_sigma,
407+ # shape=(num_species, num_species - 1)) # different shape for
408+ # off-diagonal
394409
395410 # Combine values
396411 # start with an all-zero matrix of the correct shape
@@ -730,8 +745,6 @@ def curve_compare(idata, F, times, yobs, init_species_start, sim_gLV_class):
730745 plot_fit_gLV (yobs , yobs_h , times )
731746
732747
733-
734-
735748def param_data_compare_pert (
736749 idata ,
737750 F ,
0 commit comments