Skip to content

Commit e08c0e6

Browse files
authored
Merge pull request #96 from ucl-cssb/Applying-fix-84-to-gLV
This PR looks great
2 parents e1b5745 + fd5623e commit e08c0e6

5 files changed

Lines changed: 352 additions & 276 deletions

File tree

examples/gLV/examples-Rutter-Dekker.ipynb

Lines changed: 177 additions & 145 deletions
Large diffs are not rendered by default.

mimic/model_infer/infer_gLV_bayes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from mimic.utilities import *
1010
from mimic.model_simulate.sim_gLV import *
11+
from mimic.model_infer.base_infer import BaseInfer
1112

1213

1314
import pandas as pd
@@ -83,7 +84,7 @@ def plot_growth_curves(data):
8384
plt.show()
8485

8586

86-
class infergLVbayes:
87+
class infergLVbayes(BaseInfer):
8788
"""
8889
bayes_gLV class for Bayesian inference of gLV models without shrinkage priors
8990
@@ -248,7 +249,7 @@ def calculate_DA0(self, num_species, proportion=0.15):
248249
DA0 = int(round(expected_non_zero_elements))
249250
return max(DA0, 1)
250251

251-
def run_bayes_gLV(self) -> None:
252+
def run_inference(self) -> None:
252253
"""
253254
This function infers the parameters for the Bayesian gLV model
254255

tests/test_base_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def test_print_parameters(mock_model, capsys):
9999
"""Test printing of model parameters."""
100100
mock_model.model = "TestModel"
101101
mock_model.parameters = {"param1": 1, "param2": np.array([1.0, 2.0])}
102+
mock_model.debug = "low" # <--- make sure it actually prints
102103

103104
mock_model.print_parameters(precision=2)
104105

tests/test_infer_VAR_bayes.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ def create_mock_trace(dim, chains=2, draws=500):
3535
)
3636

3737
posterior = xr.Dataset({"x0": x0, "A": A})
38-
39-
trace = az.InferenceData(posterior=posterior)
40-
return trace
38+
return az.InferenceData(posterior=posterior)
4139

4240

4341
def create_mock_trace_xs(nX, nS, chains=2, draws=500):
@@ -55,9 +53,7 @@ def create_mock_trace_xs(nX, nS, chains=2, draws=500):
5553
)
5654

5755
posterior = xr.Dataset({"Ah": Ah, "Bh": Bh})
58-
59-
trace = az.InferenceData(posterior=posterior)
60-
return trace
56+
return az.InferenceData(posterior=posterior)
6157

6258

6359
def test_initialization(example_data):
@@ -76,11 +72,9 @@ def test_initialization(example_data):
7672
def test_run_inference(example_data):
7773
# Initialize the model with data
7874
model = infer_VAR(data=example_data)
79-
80-
# Determine the dimension based on example_data
8175
dim = example_data.shape[1]
8276

83-
# Create a mock trace object with the correct dimensions
77+
# Create a mock trace object
8478
mock_trace = create_mock_trace(dim=dim)
8579

8680
# Mock pymc's sample method to return the mock trace
@@ -98,7 +92,6 @@ def test_run_inference_large(example_data):
9892

9993
# Initialize the model with data
10094
model = infer_VAR(data=data)
101-
10295
dim = data.shape[1]
10396

10497
# Create a mock trace object
@@ -109,14 +102,14 @@ def test_run_inference_large(example_data):
109102
patch('mimic.model_infer.infer_VAR_bayes.plt.savefig'), \
110103
patch('mimic.model_infer.infer_VAR_bayes.az.to_netcdf'), \
111104
patch('mimic.model_infer.infer_VAR_bayes.np.savez'):
112-
model.run_inference_large(samples=1000, tune=500, cores=4)
105+
# Call the new single method with method='large'
106+
model.run_inference(method='large', samples=1000, tune=500, cores=4)
113107
mock_sample.assert_called_once_with(draws=1000, tune=500, cores=4)
114108

115109

116110
def test_run_inference_xs(example_data, example_metabolite_data):
117111
# Initialize the model with data and metabolite data
118112
model = infer_VAR(data=example_data, dataS=example_metabolite_data)
119-
120113
nX = example_data.shape[1]
121114
nS = example_metabolite_data.shape[1]
122115

@@ -128,13 +121,13 @@ def test_run_inference_xs(example_data, example_metabolite_data):
128121
patch('mimic.model_infer.infer_VAR_bayes.plt.savefig'), \
129122
patch('mimic.model_infer.infer_VAR_bayes.az.to_netcdf'), \
130123
patch('mimic.model_infer.infer_VAR_bayes.np.savez'):
131-
model.run_inference_xs(samples=500, tune=200, cores=2)
124+
model.run_inference(method='xs', samples=500, tune=200, cores=2)
132125
mock_sample.assert_called_once_with(draws=500, tune=200, cores=2)
133126

134127
# Test for missing metabolite data
135128
model.dataS = None
136129
with pytest.raises(ValueError):
137-
model.run_inference_xs()
130+
model.run_inference(method='xs')
138131

139132

140133
def test_posterior_analysis(mocker, example_data, example_metabolite_data):
@@ -156,7 +149,8 @@ def test_posterior_analysis(mocker, example_data, example_metabolite_data):
156149
# Mock np.load to return a context manager with example data
157150
mock_np_load = MagicMock()
158151
mock_np_load.__enter__.return_value = {
159-
'dataX': example_data, 'dataS': example_metabolite_data}
152+
'dataX': example_data, 'dataS': example_metabolite_data
153+
}
160154
mocker.patch('numpy.load', return_value=mock_np_load)
161155

162156
# Ensure that the trace exists in model to avoid loading errors

0 commit comments

Comments
 (0)