Skip to content

Commit 9aeb549

Browse files
authored
Merge pull request #106 from ucl-cssb/CRM
CRM_fix
2 parents 2d86c83 + 806c3bf commit 9aeb549

6 files changed

Lines changed: 2842 additions & 94 deletions

File tree

examples/CRM/examples-bayes-CRM.ipynb

Lines changed: 2372 additions & 46 deletions
Large diffs are not rendered by default.

examples/CRM/examples-sim-CRM.ipynb

Lines changed: 29 additions & 15 deletions
Large diffs are not rendered by default.

examples/CRM/run-bayes-CRM.py

Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
from mimic.utilities import *
2+
from mimic.utilities.utilities import plot_CRM, plot_CRM_with_intervals
3+
4+
from mimic.model_infer.infer_CRM_bayes import *
5+
from mimic.model_infer import *
6+
from mimic.model_simulate import *
7+
from mimic.model_simulate.sim_CRM import *
8+
9+
import numpy as np
10+
import pandas as pd
11+
import seaborn as sns
12+
import matplotlib
13+
matplotlib.use('Agg') # Non-interactive backend for background running
14+
import matplotlib.pyplot as plt
15+
16+
import arviz as az
17+
import pymc as pm
18+
import pytensor.tensor as at
19+
import pickle
20+
import cloudpickle
21+
import os
22+
23+
from scipy import stats
24+
from scipy.integrate import odeint
25+
26+
import glob
27+
import shutil
28+
29+
# Anchor all relative data/output paths to the directory containing this
# script, so it behaves the same regardless of the launch directory.
script_dir = os.path.abspath(os.path.dirname(__file__))
os.chdir(script_dir)
34+
## Load the ground-truth CRM parameters saved by the simulation run.
with open("params-s2-r2.pkl", "rb") as f:
    params = pickle.load(f)

# Unpack the individual parameter values for convenient use below.
tau, m, r, w, K, c = (params[k] for k in ("tau", "m", "r", "w", "K", "c"))
## Read in the observed data.
# Case 1: both species and resource trajectories are observed.
data = pd.read_csv("data-s2-r2.csv")
times = data.iloc[:, 0].values      # first column holds the time points
yobs = data.iloc[:, 1:6].values     # remaining columns hold the trajectories

# Case 2: only species data for the same system (uncomment to use).
# data = pd.read_csv("data-s2-infer-r2.csv")
# times = data.iloc[:, 0].values
# yobs = data.iloc[:, 1:3].values

# All plots/tables for this run are written here; change per run.
output_folder = "s2_r2_inferC_1prior2"
os.makedirs(output_folder, exist_ok=True)
66+
## Problem dimensions and the parameters held fixed during inference.
num_species = 2
num_resources = 2

# Parameters fixed at their true values; comment one out (and supply a
# prior below) to infer it instead.
tau, m, r, w, K = (params[k] for k in ("tau", "m", "r", "w", "K"))
# c = params["c"]   # c is the parameter being inferred in this run
79+
## Priors for the inferred parameters.  Uncomment a mean/sigma pair to
## infer that parameter; the commented values are sensible defaults.

# prior_tau_mean = 0.7
# prior_tau_sigma = 0.2
# prior_w_mean = 0.55
# prior_w_sigma = 0.2

# Normal prior on each entry of the 2x2 consumption matrix c.
prior_c_mean = [[0.2, 0.1],
                [0.1, 0.2]]
prior_c_sigma = [[0.1, 0.1],
                 [0.1, 0.1]]

# prior_m_mean = 0.25
# prior_m_sigma = 0.1
# prior_r_mean = 0.4
# prior_r_sigma = 0.1
# prior_K_mean = 5.5
# prior_K_sigma = 0.5

# MCMC sampling configuration.
draws = 50     # posterior draws per chain
tune = 50      # warm-up iterations per chain
chains = 4
cores = 4
107+
# Record the run configuration to a text file so results are reproducible.
# globals().get(...) reports 'na' for any prior left commented out above,
# i.e. for parameters that were fixed rather than inferred.
conditions_text = f"""Model Conditions and Priors
============================

Sampling Conditions:
- draws: {draws}
- tune: {tune}
- chains: {chains}
- cores: {cores}

Number of species: {num_species}
Number of resources: {num_resources}

Prior Parameters:
- tau: mean = {globals().get('prior_tau_mean', 'na')}, sigma = {globals().get('prior_tau_sigma', 'na')}
- w: mean = {globals().get('prior_w_mean', 'na')}, sigma = {globals().get('prior_w_sigma', 'na')}
- c: mean = {globals().get('prior_c_mean', 'na')}, sigma = {globals().get('prior_c_sigma', 'na')}
- m: mean = {globals().get('prior_m_mean', 'na')}, sigma = {globals().get('prior_m_sigma', 'na')}
- r: mean = {globals().get('prior_r_mean', 'na')}, sigma = {globals().get('prior_r_sigma', 'na')}
- K: mean = {globals().get('prior_K_mean', 'na')}, sigma = {globals().get('prior_K_sigma', 'na')}
"""

with open(os.path.join(output_folder, 'model_conditions.txt'), 'w') as f:
    f.write(conditions_text)
print(f"Saved model conditions to {output_folder}/model_conditions.txt")
133+
## Bayesian inference for the consumption matrix c.
# Pass a parameter as a plain keyword (e.g. tau=tau) to hold it fixed, or
# pass prior_*_mean / prior_*_sigma keywords to infer it instead.
inference = inferCRMbayes()
inference.set_parameters(
    times=times, yobs=yobs,
    num_species=num_species, num_resources=num_resources,
    tau=tau, w=w, m=m, r=r, K=K,
    prior_c_mean=prior_c_mean, prior_c_sigma=prior_c_sigma,
    draws=draws, tune=tune, chains=chains, cores=cores)

idata = inference.run_inference()

# Posterior plots are written to the working directory by default;
# collect them into the run's output folder afterwards.
inference.plot_posterior(idata, true_params=params)
posterior_files = glob.glob("plot-posterior-*.pdf")
for pdf_name in posterior_files:
    shutil.move(pdf_name, os.path.join(output_folder, pdf_name))

print(f"Moved {len(posterior_files)} posterior plots to output folder")
158+
## Posterior summaries, saved samples, and trace plots.
# Extend var_names (e.g. ["tau_hat", "w_hat", "c_hat", "m_hat", "r_hat",
# "K_hat", "sigma"]) when more parameters are inferred.
summary = az.summary(idata, var_names=["c_hat", "sigma"])
summary_cols = summary[["mean", "sd", "r_hat"]]
print("Summary Statistics:")
print(summary_cols)

# Persist the summary table and the full posterior for later analysis.
summary_cols.to_csv(os.path.join(output_folder, 'summary_statistics.txt'), sep='\t')
print("Saved summary statistics to summary_statistics.txt")
az.to_netcdf(idata, os.path.join(output_folder, 'model_posterior.nc'))

# Trace plot for convergence diagnostics.
az.plot_trace(idata, var_names=["c_hat", "sigma"])
plt.savefig(os.path.join(output_folder, 'posterior-trace.jpg'), dpi=300, bbox_inches='tight')
plt.close()
print("Saved posterior trace plot")
179+
## Compare inferred (posterior-median) parameters against the truth.

# Initial condition: every species and resource starts at 10.
init_species = 10 * np.ones(num_species + num_resources)

# Posterior medians of the inferred parameters.  Add entries here
# (tau_h, w_h, m_h, r_h, K_h) when more parameters are inferred.
c_h = np.median(idata.posterior["c_hat"].values, axis=(0, 1))

# (name, true value, inferred value) triples to plot, one file each.
params_to_compare = [('c', c, c_h)]

for param_name, true_val, pred_val in params_to_compare:
    compare_params(**{param_name: (true_val, pred_val)})
    out_path = os.path.join(output_folder, f'parameter_comparison_{param_name}.jpg')
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved {param_name} parameter comparison plot")
204+
## Forward-simulate the CRM with the posterior-median c and compare
## against the observed data.
predictor = sim_CRM()
predictor.set_parameters(
    num_species=num_species,
    num_resources=num_resources,
    tau=tau, w=w, c=c_h, m=m, r=r, K=K)

# predictor.print_parameters()

observed_species, observed_resources = predictor.simulate(times, init_species)
observed_data = np.hstack((observed_species, observed_resources))

# Plot predicted species and resource dynamics against the observed data.
plot_CRM(observed_species, observed_resources, times, 'data-s2-r2.csv')
plt.savefig(os.path.join(output_folder, 'CRM_prediction.jpg'), dpi=300, bbox_inches='tight')
plt.close()
print("Saved CRM prediction plot")
228+
## Plot CRM with confidence intervals

# Get posterior samples for c_hat.  The array keeps the (chain, draw)
# dimensions first, followed by the parameter's own shape (the quantile
# code below strips the first two axes accordingly).
# Uncomment the matching lines when more parameters are inferred:
# tau_posterior_samples = idata.posterior["tau_hat"].values
# w_posterior_samples = idata.posterior["w_hat"].values
c_posterior_samples = idata.posterior["c_hat"].values
# m_posterior_samples = idata.posterior["m_hat"].values
# r_posterior_samples = idata.posterior["r_hat"].values
# K_posterior_samples = idata.posterior["K_hat"].values
239+
# Tabulate the posterior quartiles (25/50/75%) of each inferred parameter,
# one row per scalar element, keyed like 'c_hat[0, 1]'.
# Generalized with np.ndindex so parameters of any dimensionality are
# handled; the previous explicit 0-D/1-D/2-D branches silently skipped
# anything with more than two dimensions.
quantiles_dict = {}

# Extend when more parameters are inferred, e.g.
# param_names = ["tau_hat", "w_hat", "c_hat", "m_hat", "r_hat", "K_hat", "sigma"]
param_names = ["c_hat", "sigma"]

for param_name in param_names:
    if param_name not in idata.posterior.data_vars:
        continue
    param_samples = idata.posterior[param_name].values
    # Leading axes are (chain, draw); the rest is the parameter's shape.
    param_shape = param_samples.shape[2:]

    # np.ndindex(()) yields a single empty tuple, so scalars fall out
    # naturally with the bare parameter name as their key.
    for idx in np.ndindex(param_shape):
        if idx:
            param_key = f'{param_name}[{", ".join(str(i) for i in idx)}]'
        else:
            param_key = param_name
        samples = param_samples[(slice(None), slice(None)) + idx].flatten()
        q25, q50, q75 = np.percentile(samples, [25, 50, 75])
        quantiles_dict[param_key] = {'0.25': q25, '0.5': q50, '0.75': q75}

# Create DataFrame and save to text file.
quantiles_df = pd.DataFrame(quantiles_dict).T
quantiles_df.to_csv(os.path.join(output_folder, 'parameter_quantiles.txt'), sep='\t', float_format='%.4f')
print("Saved parameter quantiles to parameter_quantiles.txt")
print(quantiles_df)
285+
286+
## Posterior-predictive trajectories for the confidence-interval plot.

lower_percentile = 2.5
upper_percentile = 97.5

n_samples = 50  # number of posterior draws to simulate; adjust as necessary

# Collected trajectories, one entry per simulated posterior draw.
all_species_trajectories = []
all_resource_trajectories = []

# Simulate with parameter values drawn uniformly (with replacement) from
# the posterior.  Fix: the original also built an unused `random_indices`
# array via np.random.choice; that dead statement has been removed.
n_chains = c_posterior_samples.shape[0]
n_draws = c_posterior_samples.shape[1]
for _ in range(n_samples):
    chain_idx = np.random.randint(0, n_chains)
    draw_idx = np.random.randint(0, n_draws)

    # Inferred parameter(s): add tau_sample, w_sample, ... here when more
    # parameters are inferred.
    c_sample = c_posterior_samples[chain_idx, draw_idx]

    sample_predictor = sim_CRM()
    sample_predictor.set_parameters(num_species=num_species,
                                    num_resources=num_resources,
                                    tau=tau,
                                    w=w,
                                    c=c_sample,
                                    m=m,
                                    r=r,
                                    K=K)

    sample_species, sample_resources = sample_predictor.simulate(times, init_species)

    all_species_trajectories.append(sample_species)
    all_resource_trajectories.append(sample_resources)

# Stack into arrays with the sample index along axis 0.
all_species_trajectories = np.array(all_species_trajectories)
all_resource_trajectories = np.array(all_resource_trajectories)
329+
# Pointwise envelope (2.5%, median, 97.5%) across the sampled
# trajectories, for each time point and each species/resource.
species_lower, species_upper = (
    np.percentile(all_species_trajectories, p, axis=0)
    for p in (lower_percentile, upper_percentile)
)
species_median = np.median(all_species_trajectories, axis=0)
resource_lower, resource_upper = (
    np.percentile(all_resource_trajectories, p, axis=0)
    for p in (lower_percentile, upper_percentile)
)
resource_median = np.median(all_resource_trajectories, axis=0)

# Overlay the median prediction and its interval on the observed data.
plot_CRM_with_intervals(species_median, resource_median,
                        species_lower, species_upper,
                        resource_lower, resource_upper,
                        times, 'data-s2-r2.csv')
plt.savefig(os.path.join(output_folder, 'CRM_with_confidence_intervals.jpg'),
            dpi=300, bbox_inches='tight')
plt.close()
print("Saved CRM confidence intervals plot")

0 commit comments

Comments
 (0)