@@ -75,28 +75,27 @@ def CRM_inf_func(y, t, p):
7575 N_safe = at .maximum (N , eps )
7676 R_safe = at .maximum (R , eps )
7777
78-
7978 # Species growth equation (dN)
8079 growth_term = at .dot (c , w * R_safe ) # Matrix multiplication as tensor
8180 dN = (N_safe / tau ) * (growth_term - m ) # Species growth equation
8281
8382 # Resource consumption equation (dR)
8483 consumption_term = at .dot (N_safe , c ) # Matrix multiplication as tensor
85- dR = (1 / (r * K )) * (K - R_safe ) * R_safe - consumption_term * R_safe # Resource consumption equation
86-
84+ dR = (1 / (r * K )) * (K - R_safe ) * R_safe - \
85+ consumption_term * R_safe # Resource consumption equation
8786
88- # If species population or resource concentration is smaller than eps *and* decreasing,
89- # then set rate of change to zero to prevent negative values in the next step
87+ # If species population or resource concentration is smaller than eps *and* decreasing,
88+ # then set rate of change to zero to prevent negative values in the next
89+ # step
9090 dN = at .where ((N < eps ) & (dN < 0 ), 0.0 , dN )
9191 dR = at .where ((R < eps ) & (dR < 0 ), 0.0 , dR )
9292
93-
9493 # Flatten array to 1D for concatenation
9594 dN_flat = at .flatten (dN )
9695 dR_flat = at .flatten (dR )
9796
9897 # Combine dN and dR into a single 1D array
99- #derivatives = [dN[0], dN[1], dR[0], dR[1]] # 1D array
98+ # derivatives = [dN[0], dN[1], dR[0], dR[1]] # 1D array
10099 derivatives = at .concatenate ([dN_flat , dR_flat ]) # Concatenate species and
101100 # resource derivatives
102101
@@ -443,7 +442,7 @@ def run_inference(self) -> None:
443442 n_states = nsp + nr
444443 n_theta = 2 + (2 * nsp ) + (3 * nr ) + (nsp * nr )
445444
446- yobs_species_only = yobs [:, :nsp ]
445+ yobs_species_only = yobs [:, :nsp ]
447446
448447 # Define the DifferentialEquation model
449448 crm_model = DifferentialEquation (
@@ -459,32 +458,53 @@ def run_inference(self) -> None:
459458 with bayes_model :
460459 # Priors for unknown model parameters
461460
462-
463- sigma = pm . HalfNormal ( 'sigma' , sigma = 0.1 , shape = (1 ,)) # Same sigma for all responses
464-
461+ sigma = pm . HalfNormal (
462+ 'sigma' , sigma = 0.1 , shape = (
463+ 1 ,)) # Same sigma for all responses
465464
466465 # Conditionally define parameters based on whether priors are
467466 # provided
468467
469468 # For tau parameter
470469 if prior_tau_mean is not None and prior_tau_sigma is not None :
471- tau_hat = pm .TruncatedNormal ('tau_hat' ,mu = prior_tau_mean ,sigma = prior_tau_sigma ,lower = 0 ,shape = (nsp ,))
470+ tau_hat = pm .TruncatedNormal (
471+ 'tau_hat' ,
472+ mu = prior_tau_mean ,
473+ sigma = prior_tau_sigma ,
474+ lower = 0 ,
475+ shape = (
476+ nsp ,
477+ ))
472478 print ("tau_hat is inferred" )
473479 else :
474480 tau_hat = at .as_tensor_variable (tau )
475481 print ("tau_hat is fixed" )
476482
477483 # For w parameter
478484 if prior_w_mean is not None and prior_w_sigma is not None :
479- w_hat = pm .TruncatedNormal ('w_hat' ,mu = prior_w_mean ,sigma = prior_w_sigma ,lower = 0 ,shape = (nr ,))
485+ w_hat = pm .TruncatedNormal (
486+ 'w_hat' ,
487+ mu = prior_w_mean ,
488+ sigma = prior_w_sigma ,
489+ lower = 0 ,
490+ shape = (
491+ nr ,
492+ ))
480493 print ("w_hat is inferred" )
481494 else :
482495 w_hat = at .as_tensor_variable (w )
483496 print ("w_hat is fixed" )
484497
485498 # For c parameter
486499 if prior_c_mean is not None and prior_c_sigma is not None :
487- c_hat_vals = pm .TruncatedNormal ('c_hat_vals' ,mu = prior_c_mean ,sigma = prior_c_sigma ,lower = 0 ,shape = (nsp ,nr ))
500+ c_hat_vals = pm .TruncatedNormal (
501+ 'c_hat_vals' ,
502+ mu = prior_c_mean ,
503+ sigma = prior_c_sigma ,
504+ lower = 0 ,
505+ shape = (
506+ nsp ,
507+ nr ))
488508 c_hat = pm .Deterministic ('c_hat' , c_hat_vals )
489509 print ("c_hat is inferred" )
490510 else :
@@ -493,23 +513,44 @@ def run_inference(self) -> None:
493513
494514 # For m parameter
495515 if prior_m_mean is not None and prior_m_sigma is not None :
496- m_hat = pm .TruncatedNormal ('m_hat' ,mu = prior_m_mean ,sigma = prior_m_sigma ,lower = 0 ,shape = (nsp , ))
516+ m_hat = pm .TruncatedNormal (
517+ 'm_hat' ,
518+ mu = prior_m_mean ,
519+ sigma = prior_m_sigma ,
520+ lower = 0 ,
521+ shape = (
522+ nsp ,
523+ ))
497524 print ("m_hat is inferred" )
498525 else :
499526 m_hat = at .as_tensor_variable (m )
500527 print ("m_hat is fixed" )
501528
502529 # For r parameter
503530 if prior_r_mean is not None and prior_r_sigma is not None :
504- r_hat = pm .TruncatedNormal ('r_hat' ,mu = prior_r_mean ,sigma = prior_r_sigma ,lower = 0 ,shape = (nr ,))
531+ r_hat = pm .TruncatedNormal (
532+ 'r_hat' ,
533+ mu = prior_r_mean ,
534+ sigma = prior_r_sigma ,
535+ lower = 0 ,
536+ shape = (
537+ nr ,
538+ ))
505539 print ("r_hat is inferred" )
506540 else :
507541 r_hat = at .as_tensor_variable (r )
508542 print ("r_hat is fixed" )
509543
510544 # For K parameter
511545 if prior_K_mean is not None and prior_K_sigma is not None :
512- K_hat = pm .TruncatedNormal ('K_hat' , mu = prior_K_mean , sigma = prior_K_sigma ,lower = 0 ,shape = (nr ,))
546+ K_hat = pm .TruncatedNormal (
547+ 'K_hat' ,
548+ mu = prior_K_mean ,
549+ sigma = prior_K_sigma ,
550+ lower = 0 ,
551+ shape = (
552+ nr ,
553+ ))
513554 print ("K_hat is inferred" )
514555 else :
515556 K_hat = at .as_tensor_variable (K )
@@ -518,23 +559,23 @@ def run_inference(self) -> None:
518559 # Flatten to read into CRM_inf_func as a single vector
519560 nsp_tensor = at .as_tensor_variable ([nsp ])
520561 nr_tensor = at .as_tensor_variable ([nr ])
521-
522562
523- theta = at .concatenate ([nsp_tensor , nr_tensor , tau_hat , w_hat , c_hat .flatten (), m_hat , r_hat , K_hat ])
563+ theta = at .concatenate (
564+ [nsp_tensor , nr_tensor , tau_hat , w_hat , c_hat .flatten (), m_hat , r_hat , K_hat ])
524565
525566 print ("=== RSME ===" )
526567 try :
527568 y0 = np .full (n_states , 10.0 )
528569 test_curves = crm_model (y0 = y0 , theta = theta )
529570 test_pred = test_curves .eval ()
530-
571+
531572 rmse = np .sqrt (np .mean ((test_pred - yobs )** 2 ))
532573 print (f"RMSE: { rmse :.6f} " )
533574 print (f"Data scale: { np .mean (yobs ):.3f} " )
534575 print (f"Model scale: { np .mean (test_pred ):.3f} " )
535576 print (f"First few predictions: { test_pred [:3 ]} " )
536577 print (f"First few observations: { yobs [:3 ]} " )
537-
578+
538579 except Exception as e :
539580 print (f"MODEL FAILED: { e } " )
540581
@@ -547,8 +588,8 @@ def run_inference(self) -> None:
547588 # Initial conditions for the ODE
548589 # initial_conditions = np.concatenate([(yobs[0,:nsp]), np.array([10.0, 10.0])])
549590 # Initial species and resource populations
550- #y0 = np.concatenate([np.ones(nsp), np.ones(nr)])
551- #y0 = yobs[0, :]
591+ # y0 = np.concatenate([np.ones(nsp), np.ones(nr)])
592+ # y0 = yobs[0, :]
552593 y0 = np .full (nsp + nr , 10.0 )
553594 print (f"Initial conditions (y0): { y0 } " )
554595 # y0 = np.array([10.0, 10.0, 10.0, 10.0])
@@ -558,9 +599,10 @@ def run_inference(self) -> None:
558599 crm_curves = crm_model (y0 = y0 , theta = theta )
559600
560601 # Define the log-normal likelihood with log-transformed observed data
561- #Y = pm.Lognormal( "Y",mu=at.log(crm_curves),sigma=sigma, observed=yobs)
602+ # Y = pm.Lognormal( "Y",mu=at.log(crm_curves),sigma=sigma, observed=yobs)
562603 Y = pm .Normal ("Y" , mu = crm_curves , sigma = sigma , observed = yobs )
563- #Y = pm.Normal("Y", mu=crm_curves[:, :nsp], sigma=sigma, observed=yobs_species_only) # species only
604+ # Y = pm.Normal("Y", mu=crm_curves[:, :nsp], sigma=sigma,
605+ # observed=yobs_species_only) # species only
564606
565607 # For debugging:
566608 # print if `debug` is set to 'high' or 'low'
@@ -580,7 +622,12 @@ def run_inference(self) -> None:
580622 print ("Shape of crm_curves:" , crm_curves .shape .eval ())
581623
582624 # Sample the posterior
583- idata = pm .sample (draws = draws ,tune = tune ,chains = chains ,cores = cores ,progressbar = True )
625+ idata = pm .sample (
626+ draws = draws ,
627+ tune = tune ,
628+ chains = chains ,
629+ cores = cores ,
630+ progressbar = True )
584631
585632 return idata
586633
@@ -602,7 +649,7 @@ def plot_posterior(self, idata, true_params=None):
602649 for i , param in enumerate (param_names ):
603650 if param in available_vars :
604651 print (f"Plotting posterior for { param } " )
605-
652+
606653 # Extract the posterior mean for the parameter (as before)
607654 if param == "c_hat" :
608655 # Special handling for c_hat due to its shape
@@ -613,35 +660,40 @@ def plot_posterior(self, idata, true_params=None):
613660 param_np = idata .posterior [param ].mean (
614661 dim = ('chain' , 'draw' )).values .flatten ()
615662 ref_val = param_np .tolist ()
616-
663+
617664 # Plot the posterior distribution (original behavior)
618665 az .plot_posterior (
619666 idata ,
620667 var_names = [param ],
621668 ref_val = ref_val
622669 )
623-
670+
624671 # Add true value as a vertical line if available
625672 true_param_name = true_param_names [i ]
626673 if true_params and true_param_name in true_params :
627674 true_val = true_params [true_param_name ]
628-
675+
629676 # Flatten the true values to match the subplot structure
630677 true_vals = true_val .flatten ()
631-
678+
632679 # Get current axes
633680 axes = plt .gcf ().get_axes ()
634681 for j , ax in enumerate (axes ):
635682 if j < len (true_vals ):
636- ax .axvline (true_vals [j ], color = 'red' , linestyle = '--' , linewidth = 2 ,
637- label = f'True value' )
683+ ax .axvline (
684+ true_vals [j ],
685+ color = 'red' ,
686+ linestyle = '--' ,
687+ linewidth = 2 ,
688+ label = f'True value' )
638689 ax .legend ()
639-
690+
640691 print (f"Added true value line for { param } : { true_val } " )
641-
692+
642693 # Save the plot
643694 plt .savefig (f"plot-posterior-{ param } .pdf" )
644695 plt .show ()
645696 plt .close ()
646697 else :
647- print (f"Parameter { param } not found in posterior samples, skipping plot." )
698+ print (
699+ f"Parameter { param } not found in posterior samples, skipping plot." )
0 commit comments