@@ -35,9 +35,7 @@ def create_mock_trace(dim, chains=2, draws=500):
3535 )
3636
3737 posterior = xr .Dataset ({"x0" : x0 , "A" : A })
38-
39- trace = az .InferenceData (posterior = posterior )
40- return trace
38+ return az .InferenceData (posterior = posterior )
4139
4240
4341def create_mock_trace_xs (nX , nS , chains = 2 , draws = 500 ):
@@ -55,9 +53,7 @@ def create_mock_trace_xs(nX, nS, chains=2, draws=500):
5553 )
5654
5755 posterior = xr .Dataset ({"Ah" : Ah , "Bh" : Bh })
58-
59- trace = az .InferenceData (posterior = posterior )
60- return trace
56+ return az .InferenceData (posterior = posterior )
6157
6258
6359def test_initialization (example_data ):
@@ -76,11 +72,9 @@ def test_initialization(example_data):
7672def test_run_inference (example_data ):
7773 # Initialize the model with data
7874 model = infer_VAR (data = example_data )
79-
80- # Determine the dimension based on example_data
8175 dim = example_data .shape [1 ]
8276
83- # Create a mock trace object with the correct dimensions
77+ # Create a mock trace object
8478 mock_trace = create_mock_trace (dim = dim )
8579
8680 # Mock pymc's sample method to return the mock trace
@@ -98,7 +92,6 @@ def test_run_inference_large(example_data):
9892
9993 # Initialize the model with data
10094 model = infer_VAR (data = data )
101-
10295 dim = data .shape [1 ]
10396
10497 # Create a mock trace object
@@ -109,14 +102,14 @@ def test_run_inference_large(example_data):
109102 patch ('mimic.model_infer.infer_VAR_bayes.plt.savefig' ), \
110103 patch ('mimic.model_infer.infer_VAR_bayes.az.to_netcdf' ), \
111104 patch ('mimic.model_infer.infer_VAR_bayes.np.savez' ):
112- model .run_inference_large (samples = 1000 , tune = 500 , cores = 4 )
105+ # Call the new single method with method='large'
106+ model .run_inference (method = 'large' , samples = 1000 , tune = 500 , cores = 4 )
113107 mock_sample .assert_called_once_with (draws = 1000 , tune = 500 , cores = 4 )
114108
115109
116110def test_run_inference_xs (example_data , example_metabolite_data ):
117111 # Initialize the model with data and metabolite data
118112 model = infer_VAR (data = example_data , dataS = example_metabolite_data )
119-
120113 nX = example_data .shape [1 ]
121114 nS = example_metabolite_data .shape [1 ]
122115
@@ -128,13 +121,13 @@ def test_run_inference_xs(example_data, example_metabolite_data):
128121 patch ('mimic.model_infer.infer_VAR_bayes.plt.savefig' ), \
129122 patch ('mimic.model_infer.infer_VAR_bayes.az.to_netcdf' ), \
130123 patch ('mimic.model_infer.infer_VAR_bayes.np.savez' ):
131- model .run_inference_xs ( samples = 500 , tune = 200 , cores = 2 )
124+ model .run_inference ( method = 'xs' , samples = 500 , tune = 200 , cores = 2 )
132125 mock_sample .assert_called_once_with (draws = 500 , tune = 200 , cores = 2 )
133126
134127 # Test for missing metabolite data
135128 model .dataS = None
136129 with pytest .raises (ValueError ):
137- model .run_inference_xs ( )
130+ model .run_inference ( method = 'xs' )
138131
139132
140133def test_posterior_analysis (mocker , example_data , example_metabolite_data ):
@@ -156,7 +149,8 @@ def test_posterior_analysis(mocker, example_data, example_metabolite_data):
156149 # Mock np.load to return a context manager with example data
157150 mock_np_load = MagicMock ()
158151 mock_np_load .__enter__ .return_value = {
159- 'dataX' : example_data , 'dataS' : example_metabolite_data }
152+ 'dataX' : example_data , 'dataS' : example_metabolite_data
153+ }
160154 mocker .patch ('numpy.load' , return_value = mock_np_load )
161155
162156 # Ensure that the trace exists in model to avoid loading errors
0 commit comments