22import matplotlib .pyplot as plt
33import numpy as np
44import pymc as pm
5+ import pandas as pd
6+ import seaborn as sns
57import pytensor .tensor as at
68import pickle
79import cloudpickle
10+ import os
11+ from typing import Optional , Union , List , Dict , Any
812
913from mimic .utilities import *
1014from mimic .model_simulate .sim_gLV import *
1115from mimic .model_infer .base_infer import BaseInfer
1216
13- from mimic .model_infer .base_infer import BaseInfer
14-
15- import os
16- from typing import Optional , Union , List , Dict , Any
17-
18-
19- import pandas as pd
20- import numpy as np
21- import seaborn as sns
22- import matplotlib .pyplot as plt
23-
2417
25- # Used in examples-Stein.ipynb
2618def plot_params (mu_h , M_h , e_h , nsp ):
2719 print ("\n inferred params:" )
2820 print ("mu_hat/mu:" )
@@ -69,8 +61,9 @@ def __init__(self,
6961 prior_mu_sigma = None ,
7062 prior_Mii_mean = None ,
7163 prior_Mii_sigma = None ,
72- prior_Mij_sigma = None
73- ):
64+ prior_Mij_sigma = None ):
65+
66+ super ().__init__ () # Call base class constructor
7467
7568 # self.data = data # data to do inference on
7669 self .X : Optional [np .ndarray ] = X
@@ -217,7 +210,7 @@ def calculate_DA0(self, num_species, proportion=0.15):
217210 DA0 = int (round (expected_non_zero_elements ))
218211 return max (DA0 , 1 )
219212
220- def run_inference (self ) -> None :
213+ def run_inference (self , ** kwargs ) -> None :
221214 """
222215 This function infers the parameters for the Bayesian gLV model
223216
@@ -314,9 +307,6 @@ def run_inference(self) -> None:
314307 # eg
315308 # print(f"mu_hat: {mu_hat.eval()}")
316309
317- # initial_values = bayes_model.initial_point()
318- # print(f"Initial parameter values: {initial_values}")
319-
320310 # Posterior distribution
321311 idata = pm .sample (
322312 draws = draws ,
@@ -327,7 +317,7 @@ def run_inference(self) -> None:
327317
328318 return idata
329319
330- def run_inference_shrinkage (self ) -> None :
320+ def run_inference_shrinkage (self , ** kwargs ) -> None :
331321 """
332322 This function infers the parameters for the Bayesian gLV model with Horseshoe prior for shrinkage
333323
@@ -426,14 +416,15 @@ def run_inference_shrinkage(self) -> None:
426416 Y_obs = pm .Normal ('Y_obs' , mu = model_mean , sigma = sigma , observed = F )
427417
428418 # For debugging:
419+ # print if `debug` is set to 'high' or 'low'
420+ if self .debug in ["high" , "low" ]:
421+ initial_values = bayes_model .initial_point ()
422+ print (f"Initial parameter values: { initial_values } " )
429423
430424 # As tensor objects are symbolic, if needed print using .eval()
431425 # eg
432426 # print(f"mu_hat: {mu_hat.eval()}")
433427
434- # initial_values = bayes_model.initial_point()
435- # print(f"Initial parameter values: {initial_values}")
436-
437428 # Posterior distribution
438429 idata = pm .sample (
439430 draws = draws ,
@@ -443,7 +434,7 @@ def run_inference_shrinkage(self) -> None:
443434
444435 return idata
445436
446- def run_inference_shrinkage_pert (self ) -> None :
437+ def run_inference_shrinkage_pert (self , ** kwargs ) -> None :
447438 """
448439 This function infers the parameters for the Bayesian gLV model with Horseshoe prior for shrinkage
449440
@@ -525,10 +516,8 @@ def run_inference_shrinkage_pert(self) -> None:
525516 lam = pm .HalfCauchy (
526517 "lam" , beta = 1 , shape = (
527518 num_species , num_species - 1 ))
528- M_ij_hat = pm .Normal ('M_ij_hat' , mu = prior_Mij_sigma , sigma = tau * lam *
529- at .sqrt (c2 / (c2 + tau ** 2 * lam ** 2 )),
530- shape = (num_species ,
531- num_species - 1 ))
519+ M_ij_hat = pm .Normal ('M_ij_hat' , mu = prior_Mij_sigma , sigma = tau * lam * at .sqrt (
520+ c2 / (c2 + tau ** 2 * lam ** 2 )), shape = (num_species , num_species - 1 ))
532521 # M_ij_hat = pm.Normal('M_ij_hat', mu=0, sigma=prior_Mij_sigma,
533522 # shape=(num_species, num_species - 1)) # different shape for
534523 # off-diagonal
@@ -553,14 +542,15 @@ def run_inference_shrinkage_pert(self) -> None:
553542 Y_obs = pm .Normal ('Y_obs' , mu = model_mean , sigma = sigma , observed = F )
554543
555544 # For debugging:
545+ # print if `debug` is set to 'high' or 'low'
546+ if self .debug in ["high" , "low" ]:
547+ initial_values = bayes_model .initial_point ()
548+ print (f"Initial parameter values: { initial_values } " )
556549
557550 # As tensor objects are symbolic, if needed print using .eval()
558551 # eg
559552 # print(f"mu_hat: {mu_hat.eval()}")
560553
561- # initial_values = bayes_model.initial_point()
562- # print(f"Initial parameter values: {initial_values}")
563-
564554 # Posterior distribution
565555 idata = pm .sample (
566556 draws = draws ,
@@ -587,17 +577,15 @@ def plot_posterior(self, idata):
587577 az .plot_posterior (
588578 idata ,
589579 var_names = ["mu_hat" ],
590- ref_val = mu_hat_np .tolist ()
591- )
580+ ref_val = mu_hat_np .tolist ())
592581 plt .savefig ("plot-posterior-mu.pdf" )
593582 plt .show ()
594583 plt .close ()
595584
596585 az .plot_posterior (
597586 idata ,
598587 var_names = ["M_ii_hat" ],
599- ref_val = np .diag (M_hat_np ).tolist ()
600- )
588+ ref_val = np .diag (M_hat_np ).tolist ())
601589 plt .savefig ("plot-posterior-Mii.pdf" )
602590 plt .show ()
603591 plt .close ()
@@ -607,8 +595,7 @@ def plot_posterior(self, idata):
607595 az .plot_posterior (
608596 idata ,
609597 var_names = ["M_ij_hat" ],
610- ref_val = M_ij .flatten ().tolist ()
611- )
598+ ref_val = M_ij .flatten ().tolist ())
612599 plt .savefig ("plot-posterior-Mij.pdf" )
613600 plt .show ()
614601 plt .close ()
@@ -676,6 +663,11 @@ def plot_interaction_matrix(self, M, M_h):
676663 color = 'white' )
677664
678665
666+ ######
667+ ######
668+ ######
669+
670+
679671def param_data_compare (
680672 idata ,
681673 F ,
@@ -735,26 +727,24 @@ def curve_compare(idata, F, times, yobs, init_species_start, sim_gLV_class):
735727 # mu_h = idata.posterior['mu_hat'].mean(dim=('chain', 'draw')).values.flatten()
736728 # M_h= idata.posterior['M_hat'].mean(dim=('chain', 'draw')).values
737729
738- predictor = sim_gLV (num_species = num_species ,
739- M = M_h .T ,
740- mu = mu_h
741- )
730+ predictor = sim_gLV (num_species = num_species , M = M_h .T , mu = mu_h )
742731 yobs_h , _ , _ , _ , _ = predictor .simulate (
743732 times = times , init_species = init_species )
744733
745734 plot_fit_gLV (yobs , yobs_h , times )
746735
747736
748737def param_data_compare_pert (
749- idata ,
750- F ,
751- mu ,
752- M ,
753- epsilon ,
754- num_perturbations ,
755- times ,
756- yobs ,
757- init_species_start ,
738+ idata ,
739+ F ,
740+ mu ,
741+ M ,
742+ u ,
743+ epsilon ,
744+ num_perturbations ,
745+ times ,
746+ yobs ,
747+ init_species_start ,
758748 sim_gLV_class ):
759749 # az.to_netcdf(idata, 'model_posterior.nc')
760750 # Compare model parameters to the data
@@ -791,9 +781,9 @@ def param_data_compare_pert(
791781 epsilon = epsilon )
792782
793783 yobs , init_species , mu , M , _ = simulator .simulate (
794- times = times , init_species = init_species , u = pert_fn )
784+ times = times , init_species = init_species , u = u )
795785 yobs_h , _ , _ , _ , _ = predictor .simulate (
796- times = times , init_species = init_species , u = pert_fn )
786+ times = times , init_species = init_species , u = u )
797787
798788 plot_fit_gLV (yobs , yobs_h , times )
799789
0 commit comments