@@ -122,8 +122,7 @@ def test_find_breakpoints_with_noisy_data(self):
122122 3.69484446e-01 , 2.19359553e-01 , 1.19059738e-01 , 5.64115725e-02 ,
123123 2.30604686e-02 , 9.14406238e-03 , 4.24754874e-03 , 1.61814681e-03 ])
124124 spacecraft_potential , core_halo_breakpoint = find_breakpoints (
125- xs , avg_flux , 10 , 80 ,
126- 11 , 81 , config )
125+ xs , avg_flux , [10 , 10 , 10 ], [80 , 80 , 80 ], config )
127126 self .assertAlmostEqual (11.1 , spacecraft_potential , 1 )
128127 self .assertAlmostEqual (81.1 , core_halo_breakpoint , 1 )
129128
@@ -146,18 +145,18 @@ def test_find_breakpoints_with_synthetic_data(self):
146145 noise_floor = 1
147146 avg_flux += noise_floor
148147 spacecraft_potential , core_halo_breakpoint = find_breakpoints (
149- xs , avg_flux , 10 , 80 ,
150- 11 , 82 , config )
148+ xs , avg_flux , [ 10 , 10 , 10 ], [ 80 , 80 , 80 ] ,
149+ config )
151150 self .assertAlmostEqual (expected_potential , spacecraft_potential , 2 )
152151 self .assertAlmostEqual (expected_core_halo , core_halo_breakpoint , 0 )
153152
154153 def test_find_breakpoints_using_initial_guess (self ):
155154 config = build_swe_configuration ()
156155
157156 cases = [
158- (4 , 40 , 4 , 50 ),
159- (10 , 80 , 12 , 100 ),
160- (12 , 60 , 10 , 80 ),
157+ (4 , 40 , [ 4 , 4 , 4 ], [ 50 , 50 , 50 ] ),
158+ (10 , 80 , [ 12 , 12 , 12 ], [ 100 , 100 , 100 ] ),
159+ (12 , 60 , [ 10 , 10 , 10 ], [ 80 , 80 , 80 ] ),
161160 ]
162161 for case in cases :
163162 with self .subTest (case ):
@@ -172,7 +171,7 @@ def test_find_breakpoints_using_initial_guess(self):
172171 avg_flux += noise_floor
173172
174173 spacecraft_potential , core_halo_breakpoint = find_breakpoints (
175- xs , avg_flux , guess_potential , guess_halo , 10 , 80 , config )
174+ xs , avg_flux , guess_potential , guess_halo , config )
176175 self .assertAlmostEqual (expected_potential , spacecraft_potential , 2 )
177176 self .assertAlmostEqual (expected_core_halo , core_halo_breakpoint , 0 )
178177
@@ -199,9 +198,8 @@ def test_find_breakpoints_determines_b_deltas_correctly(self, mock_try_curve_fit
199198 avg_flux = np .exp (log_flux )
200199
201200 result = find_breakpoints (
202- xs , avg_flux , 10 , 80 , 15 ,
203- 90 , config )
204- mock_try_curve_fit_until_valid .assert_called_with (ANY , ANY , ANY , 15 , 90 , expected_b2_delta ,
201+ xs , avg_flux , [10 , 10 , 10 ], [80 , 80 , 80 ], config )
202+ mock_try_curve_fit_until_valid .assert_called_with (ANY , ANY , ANY , 10 , 80 , expected_b2_delta ,
205203 expected_b4_delta )
206204
207205 self .assertEqual (mock_try_curve_fit_until_valid .return_value , result )
@@ -233,8 +231,8 @@ def test_find_breakpoints_uses_config_for_slope_guesses(self, mock_curve_fit):
233231 avg_flux = np .exp (log_flux )
234232
235233 spacecraft_potential , core_halo_breakpoint = find_breakpoints (
236- xs , avg_flux , 10 , 80 ,
237- 11 , 81 , config )
234+ xs , avg_flux , [ 10 , 10 , 10 ], [ 80 , 80 , 80 ] ,
235+ config )
238236 expected_guesses = [ANY , b1 , 10 , b3 , 80 , b5 ]
239237 rounded_actuals = [round (x , 6 ) for x in mock_curve_fit .call_args .args [3 ]]
240238 self .assertEqual (expected_guesses , rounded_actuals )
@@ -264,8 +262,7 @@ def test_find_breakpoints_uses_config_for_slope_ratio(self, mock_curve_fit):
264262 avg_flux = np .exp (log_flux )
265263
266264 spacecraft_potential , core_halo_breakpoint = find_breakpoints (
267- xs , avg_flux , 10 , 80 ,
268- 11 , 81 , config )
265+ xs , avg_flux , [10 , 10 , 10 ], [80 , 80 , 80 ], config )
269266
270267 np .testing .assert_almost_equal (mock_curve_fit .call_args .args [1 ], xs [:data_length ])
271268 np .testing .assert_almost_equal (mock_curve_fit .call_args .args [2 ], log_flux [:data_length ])
0 commit comments