@@ -51,8 +51,8 @@ def plot_growth_curves(data, ax=None):
5151
5252def CRM_inf_func (y , t , p ):
5353 # Unpack parameters from the vector p
54- nr = p [0 ].astype ("int32" ) # Number of resources
55- nsp = p [1 ].astype ("int32" ) # Number of species
54+ nsp = p [0 ].astype ("int32" ) # Number of resources
55+ nr = p [1 ].astype ("int32" ) # Number of species
5656 tau = p [2 :2 + nsp ] # Species time scales
5757 w = p [2 + nsp :2 + nsp + nr ] # Resource quality
5858 # Flattened resource preferences
@@ -69,18 +69,23 @@ def CRM_inf_func(y, t, p):
6969 N = y [:nsp ] # Species populations
7070 R = y [nsp :] # Resource availability
7171
72+
7273 # Species growth equation (dN)
7374 growth_term = at .dot (c , w * R ) # Matrix multiplication as tensor
7475 dN = (N / tau ) * (growth_term - m ) # Species growth equation
7576
7677 # Resource consumption equation (dR)
7778 consumption_term = at .dot (N , c ) # Matrix multiplication as tensor
78- dR = (1 / (r * K )) * (K - R ) * R - consumption_term * \
79- R # Resource consumption equation
79+ dR = (1 / (r * K )) * (K - R ) * R - consumption_term * R # Resource consumption equation
80+
81+
82+ # Flatten array to 1D for concatenation
83+ dN_flat = at .flatten (dN )
84+ dR_flat = at .flatten (dR )
8085
8186 # Combine dN and dR into a single 1D array
82- derivatives = [dN [0 ], dN [1 ], dR [0 ], dR [1 ]] # 1D array
83- # derivatives = np .concatenate([dN, dR ]) # Concatenate species and
87+ # derivatives = [dN[0], dN[1], dR[0], dR[1]] # 1D array
88+ derivatives = at .concatenate ([dN_flat , dR_flat ]) # Concatenate species and
8489 # resource derivatives
8590
8691 # Return the derivatives for both species and resources as a single array
@@ -426,6 +431,8 @@ def run_inference(self) -> None:
426431 n_states = nsp + nr
427432 n_theta = 2 + (2 * nsp ) + (3 * nr ) + (nsp * nr )
428433
434+ yobs_species_only = yobs [:, :nsp ]
435+
429436 # Define the DifferentialEquation model
430437 crm_model = DifferentialEquation (
431438 func = CRM_inf_func , # The ODE function
@@ -440,53 +447,30 @@ def run_inference(self) -> None:
440447 with bayes_model :
441448 # Priors for unknown model parameters
442449
443- sigma = pm .HalfNormal (
444- 'sigma' , sigma = 0.1 , shape = (
445- 1 ,)) # Same sigma for all responses
450+ sigma = pm .HalfNormal ('sigma' , sigma = 0.1 , shape = (1 ,)) # Same sigma for all responses
446451
447452 # Conditionally define parameters based on whether priors are
448453 # provided
449454
450455 # For tau parameter
451456 if prior_tau_mean is not None and prior_tau_sigma is not None :
452- tau_hat = pm .TruncatedNormal (
453- 'tau_hat' ,
454- mu = prior_tau_mean ,
455- sigma = prior_tau_sigma ,
456- lower = 0.1 ,
457- shape = (
458- nsp ,
459- ))
457+ tau_hat = pm .TruncatedNormal ('tau_hat' ,mu = prior_tau_mean ,sigma = prior_tau_sigma ,lower = 0 ,shape = (nsp ,))
460458 print ("tau_hat is inferred" )
461459 else :
462460 tau_hat = at .as_tensor_variable (tau )
463461 print ("tau_hat is fixed" )
464462
465463 # For w parameter
466464 if prior_w_mean is not None and prior_w_sigma is not None :
467- w_hat = pm .TruncatedNormal (
468- 'w_hat' ,
469- mu = prior_w_mean ,
470- sigma = prior_w_sigma ,
471- lower = 0.1 ,
472- shape = (
473- nr ,
474- ))
465+ w_hat = pm .TruncatedNormal ('w_hat' ,mu = prior_w_mean ,sigma = prior_w_sigma ,lower = 0 ,shape = (nr ,))
475466 print ("w_hat is inferred" )
476467 else :
477468 w_hat = at .as_tensor_variable (w )
478469 print ("w_hat is fixed" )
479470
480471 # For c parameter
481472 if prior_c_mean is not None and prior_c_sigma is not None :
482- c_hat_vals = pm .TruncatedNormal (
483- 'c_hat_vals' ,
484- mu = prior_c_mean ,
485- sigma = prior_c_sigma ,
486- lower = 0 ,
487- shape = (
488- nsp ,
489- nr ))
473+ c_hat_vals = pm .TruncatedNormal ('c_hat_vals' ,mu = prior_c_mean ,sigma = prior_c_sigma ,lower = 0 ,shape = (nsp ,nr ))
490474 c_hat = pm .Deterministic ('c_hat' , c_hat_vals )
491475 print ("c_hat is inferred" )
492476 else :
@@ -495,60 +479,64 @@ def run_inference(self) -> None:
495479
496480 # For m parameter
497481 if prior_m_mean is not None and prior_m_sigma is not None :
498- m_hat = pm .TruncatedNormal (
499- 'm_hat' ,
500- mu = prior_m_mean ,
501- sigma = prior_m_sigma ,
502- lower = 0.1 ,
503- shape = (
504- nsp ,
505- ))
482+ m_hat = pm .TruncatedNormal ('m_hat' ,mu = prior_m_mean ,sigma = prior_m_sigma ,lower = 0 ,shape = (nsp , ))
506483 print ("m_hat is inferred" )
507484 else :
508485 m_hat = at .as_tensor_variable (m )
509486 print ("m_hat is fixed" )
510487
511488 # For r parameter
512489 if prior_r_mean is not None and prior_r_sigma is not None :
513- r_hat = pm .TruncatedNormal (
514- 'r_hat' ,
515- mu = prior_r_mean ,
516- sigma = prior_r_sigma ,
517- lower = 0 ,
518- shape = (
519- nr ,
520- ))
490+ r_hat = pm .TruncatedNormal ('r_hat' ,mu = prior_r_mean ,sigma = prior_r_sigma ,lower = 0 ,shape = (nr ,))
521491 print ("r_hat is inferred" )
522492 else :
523493 r_hat = at .as_tensor_variable (r )
524494 print ("r_hat is fixed" )
525495
526496 # For K parameter
527497 if prior_K_mean is not None and prior_K_sigma is not None :
528- K_hat = pm .TruncatedNormal (
529- 'K_hat' ,
530- mu = prior_K_mean ,
531- sigma = prior_K_sigma ,
532- lower = 1.0 ,
533- shape = (
534- nr ,
535- ))
498+ K_hat = pm .TruncatedNormal ('K_hat' , mu = prior_K_mean , sigma = prior_K_sigma ,lower = 0 ,shape = (nr ,))
536499 print ("K_hat is inferred" )
537500 else :
538501 K_hat = at .as_tensor_variable (K )
539502 print ("K_hat is fixed" )
540503
541504 # Flatten to read into CRM_inf_func as a single vector
542- nr_tensor = at .as_tensor_variable ([nr ])
543505 nsp_tensor = at .as_tensor_variable ([nsp ])
544-
545- theta = at .concatenate (
546- [nr_tensor , nsp_tensor , tau_hat , w_hat , c_hat .flatten (), m_hat , r_hat , K_hat ])
506+ nr_tensor = at .as_tensor_variable ([nr ])
507+
508+
509+ theta = at .concatenate ([nsp_tensor , nr_tensor , tau_hat , w_hat , c_hat .flatten (), m_hat , r_hat , K_hat ])
510+
511+ print ("=== CRITICAL: TESTING IF MODEL STRUCTURE IS CORRECT ===" )
512+ try :
513+ y0 = np .full (n_states , 10.0 )
514+ test_curves = crm_model (y0 = y0 , theta = theta )
515+ test_pred = test_curves .eval ()
516+
517+ rmse = np .sqrt (np .mean ((test_pred - yobs )** 2 ))
518+ print (f"RMSE with near-true parameters: { rmse :.6f} " )
519+ print (f"Data scale: { np .mean (yobs ):.3f} " )
520+ print (f"Model scale: { np .mean (test_pred ):.3f} " )
521+ print (f"First few predictions: { test_pred [:3 ]} " )
522+ print (f"First few observations: { yobs [:3 ]} " )
523+
524+ except Exception as e :
525+ print (f"MODEL FAILED: { e } " )
526+
527+ print (f"nsp_tensor: { nsp_tensor .eval ()} , nr_tensor: { nr_tensor .eval ()} " )
528+ print (f"tau_hat: { tau_hat .eval ()} , w_hat: { w_hat .eval ()} " )
529+ print (f"c_hat: { c_hat .eval ()} , m_hat: { m_hat .eval ()} " )
530+ print (f"r_hat: { r_hat .eval ()} , K_hat: { K_hat .eval ()} " )
531+ print (f"theta: { theta .eval ()} " )
547532
548533 # Initial conditions for the ODE
549534 # initial_conditions = np.concatenate([(yobs[0,:nsp]), np.array([10.0, 10.0])])
550535 # Initial species and resource populations
551- y0 = np .concatenate ([np .ones (nsp ), np .ones (nr )])
536+ #y0 = np.concatenate([np.ones(nsp), np.ones(nr)])
537+ #y0 = yobs[0, :]
538+ y0 = np .full (nsp + nr , 10.0 )
539+ print (f"Initial conditions (y0): { y0 } " )
552540 # y0 = np.array([10.0, 10.0, 10.0, 10.0])
553541 # y0 = np.full(n_states, 10.0)
554542
@@ -557,11 +545,9 @@ def run_inference(self) -> None:
557545
558546 # Define the log-normal likelihood with log-transformed observed data
559547 # Y = pm.Lognormal("Y", mu=pm.math.log(crm_curves), sigma=sigma, observed=yobs)
560- Y = pm .Lognormal (
561- "Y" ,
562- mu = at .log (crm_curves ),
563- sigma = sigma ,
564- observed = yobs )
548+ #Y = pm.Lognormal( "Y",mu=at.log(crm_curves),sigma=sigma, observed=yobs)
549+ #Y = pm.Normal("Y", mu=crm_curves, sigma=sigma, observed=yobs)
550+ Y = pm .Normal ("Y" , mu = crm_curves [:, :nsp ], sigma = sigma , observed = yobs_species_only ) # species only
565551
566552 # For debugging:
567553 # print if `debug` is set to 'high' or 'low'
@@ -581,12 +567,7 @@ def run_inference(self) -> None:
581567 print ("Shape of crm_curves:" , crm_curves .shape .eval ())
582568
583569 # Sample the posterior
584- idata = pm .sample (
585- draws = draws ,
586- tune = tune ,
587- chains = chains ,
588- cores = cores ,
589- progressbar = True )
570+ idata = pm .sample (draws = draws ,tune = tune ,chains = chains ,cores = cores ,progressbar = True )
590571
591572 return idata
592573
0 commit comments