@@ -149,132 +149,112 @@ def test_repeat_arm_n_constructor_return_0(self) -> None:
149149 self .assertEqual (small_n , 0 )
150150 self .assertTrue (self .sobol_generation_node ._should_skip )
151151
152- def test_remaining_n_constructor_expect_1 (self ) -> None :
153- """Test that the remaining_n_constructor returns the remaining n."""
154- # should return 1 because 4 arms already exist and 5 are requested
155- expect_1 = NodeInputConstructors .REMAINING_N (
156- previous_node = None ,
157- next_node = self .sobol_generation_node ,
158- gs_gen_call_kwargs = {"n" : 5 , "grs_this_gen" : self .grs },
159- experiment = self .experiment ,
160- )
161- self .assertEqual (expect_1 , 1 )
162-
163- def test_remaining_n_constructor_expect_0 (self ) -> None :
164- # should return 0 because 4 arms already exist and 4 are requested
165- expect_0 = NodeInputConstructors .REMAINING_N (
166- previous_node = None ,
167- next_node = self .sobol_generation_node ,
168- gs_gen_call_kwargs = {"n" : 4 , "grs_this_gen" : self .grs },
169- experiment = self .experiment ,
170- )
171- self .assertEqual (expect_0 , 0 )
172-
173- def test_remaining_n_constructor_cap_at_zero (self ) -> None :
174- # should return 0 because 4 arms already exist and 3 are requested
175- # this is a bad state that should never be hit, but ensuring proper
176- # handling here feels like a valid edge case
177- expect_0 = NodeInputConstructors .REMAINING_N (
178- previous_node = None ,
179- next_node = self .sobol_generation_node ,
180- gs_gen_call_kwargs = {"n" : 3 , "grs_this_gen" : self .grs },
181- experiment = self .experiment ,
182- )
183- self .assertEqual (expect_0 , 0 )
152+ def test_remaining_n_constructor (self ) -> None :
153+ """Test that the remaining_n_constructor returns the correct remaining n."""
154+ for label , n , expected in [
155+ ("returns 1 when 4 arms exist and 5 requested" , 5 , 1 ),
156+ ("returns 0 when 4 arms exist and 4 requested" , 4 , 0 ),
157+ ("caps at 0 when 4 arms exist and 3 requested" , 3 , 0 ),
158+ ]:
159+ with self .subTest (label ):
160+ result = NodeInputConstructors .REMAINING_N (
161+ previous_node = None ,
162+ next_node = self .sobol_generation_node ,
163+ gs_gen_call_kwargs = {"n" : n , "grs_this_gen" : self .grs },
164+ experiment = self .experiment ,
165+ )
166+ self .assertEqual (result , expected )
184167
185168 def test_no_n_provided_all_n (self ) -> None :
186- num_to_gen = NodeInputConstructors .ALL_N (
187- previous_node = None ,
188- next_node = self .sobol_generation_node ,
189- gs_gen_call_kwargs = {},
190- experiment = self .experiment ,
191- )
192- self .assertEqual (num_to_gen , 10 )
193-
194- def test_no_n_provided_all_n_with_exp_prop (self ) -> None :
195- self .experiment ._properties [Keys .EXPERIMENT_TOTAL_CONCURRENT_ARMS ] = 12
196- num_to_gen = NodeInputConstructors .ALL_N (
197- previous_node = None ,
198- next_node = self .sobol_generation_node ,
199- gs_gen_call_kwargs = {},
200- experiment = self .experiment ,
201- )
202- self .assertEqual (num_to_gen , 12 )
203-
204- def test_no_n_provided_all_n_with_exp_prop_long_run (self ) -> None :
205- self .experiment ._properties [Keys .EXPERIMENT_TOTAL_CONCURRENT_ARMS ] = 13
206- self .sobol_generation_node ._trial_type = Keys .LONG_RUN
207- num_to_gen = NodeInputConstructors .ALL_N (
208- previous_node = None ,
209- next_node = self .sobol_generation_node ,
210- gs_gen_call_kwargs = {},
211- experiment = self .experiment ,
212- )
213- self .assertEqual (num_to_gen , 7 )
214-
215- def test_no_n_provided_all_n_with_exp_prop_short_run (self ) -> None :
216- self .experiment ._properties [Keys .EXPERIMENT_TOTAL_CONCURRENT_ARMS ] = 13
217- self .sobol_generation_node ._trial_type = Keys .SHORT_RUN
218- num_to_gen = NodeInputConstructors .ALL_N (
219- previous_node = None ,
220- next_node = self .sobol_generation_node ,
221- gs_gen_call_kwargs = {},
222- experiment = self .experiment ,
223- )
224- self .assertEqual (num_to_gen , 6 )
169+ for label , exp_prop , trial_type , expected in [
170+ ("default returns 10" , None , None , 10 ),
171+ ("with exp prop 12" , 12 , None , 12 ),
172+ ("with exp prop 13 and long_run" , 13 , Keys .LONG_RUN , 7 ),
173+ ("with exp prop 13 and short_run" , 13 , Keys .SHORT_RUN , 6 ),
174+ ]:
175+ with self .subTest (label ):
176+ # Reset state for each subtest
177+ experiment = get_branin_experiment ()
178+ node = GenerationNode (
179+ name = "test" ,
180+ generator_specs = [self .sobol_generator_spec ],
181+ )
182+ if exp_prop is not None :
183+ experiment ._properties [Keys .EXPERIMENT_TOTAL_CONCURRENT_ARMS ] = (
184+ exp_prop
185+ )
186+ if trial_type is not None :
187+ node ._trial_type = trial_type
188+ num_to_gen = NodeInputConstructors .ALL_N (
189+ previous_node = None ,
190+ next_node = node ,
191+ gs_gen_call_kwargs = {},
192+ experiment = experiment ,
193+ )
194+ self .assertEqual (num_to_gen , expected )
225195
226196 def test_no_n_provided_repeat_n (self ) -> None :
227- num_to_gen = NodeInputConstructors .REPEAT_N (
228- previous_node = None ,
229- next_node = self .sobol_generation_node ,
230- gs_gen_call_kwargs = {},
231- experiment = self .experiment ,
232- )
233- self .assertEqual (num_to_gen , 1 )
234-
235- def test_no_n_provided_repeat_n_with_exp_prop (self ) -> None :
236- self .experiment ._properties [Keys .EXPERIMENT_TOTAL_CONCURRENT_ARMS ] = 18
237- num_to_gen = NodeInputConstructors .REPEAT_N (
238- previous_node = None ,
239- next_node = self .sobol_generation_node ,
240- gs_gen_call_kwargs = {},
241- experiment = self .experiment ,
242- )
243- self .assertEqual (num_to_gen , 2 )
244-
245- def test_no_n_provided_repeat_n_with_exp_prop_long_run (self ) -> None :
246- self .experiment ._properties [Keys .EXPERIMENT_TOTAL_CONCURRENT_ARMS ] = 18
247- self .sobol_generation_node ._trial_type = Keys .SHORT_RUN
248- num_to_gen = NodeInputConstructors .REPEAT_N (
249- previous_node = None ,
250- next_node = self .sobol_generation_node ,
251- gs_gen_call_kwargs = {},
252- experiment = self .experiment ,
253- )
254- # expect 1 arm here because total concurrent arms is 18, and we have a trial
255- # type (short run), so we'll take the floor of 18/2 = 9 to be used in the
256- # logic for repeat arms which says if we have less than 10 requested arms we
257- # should get 1 repeat arm.
258- self .assertEqual (num_to_gen , 1 )
197+ for label , exp_prop , trial_type , expected in [
198+ ("default returns 1" , None , None , 1 ),
199+ ("with exp prop 18 returns 2" , 18 , None , 2 ),
200+ (
201+ "with exp prop 18 and short_run returns 1" ,
202+ 18 ,
203+ Keys .SHORT_RUN ,
204+ # expect 1 because total concurrent arms is 18, and we have a trial
205+ # type (short run), so we'll take the floor of 18/2 = 9 to be used
206+ # in the logic for repeat arms which says if we have less than 10
207+ # requested arms we should get 1 repeat arm.
208+ 1 ,
209+ ),
210+ ]:
211+ with self .subTest (label ):
212+ experiment = get_branin_experiment ()
213+ node = GenerationNode (
214+ name = "test" ,
215+ generator_specs = [self .sobol_generator_spec ],
216+ )
217+ if exp_prop is not None :
218+ experiment ._properties [Keys .EXPERIMENT_TOTAL_CONCURRENT_ARMS ] = (
219+ exp_prop
220+ )
221+ if trial_type is not None :
222+ node ._trial_type = trial_type
223+ num_to_gen = NodeInputConstructors .REPEAT_N (
224+ previous_node = None ,
225+ next_node = node ,
226+ gs_gen_call_kwargs = {},
227+ experiment = experiment ,
228+ )
229+ self .assertEqual (num_to_gen , expected )
259230
260231 def test_no_n_provided_remaining_n (self ) -> None :
261- num_to_gen = NodeInputConstructors .REMAINING_N (
262- previous_node = None ,
263- next_node = self .sobol_generation_node ,
264- gs_gen_call_kwargs = {},
265- experiment = self .experiment ,
266- )
267- self .assertEqual (num_to_gen , 10 )
268-
269- def test_no_n_provided_remaining_n_with_exp_prop (self ) -> None :
270- self .experiment ._properties [Keys .EXPERIMENT_TOTAL_CONCURRENT_ARMS ] = 8
271- num_to_gen = NodeInputConstructors .REMAINING_N (
272- previous_node = None ,
273- next_node = self .sobol_generation_node ,
274- gs_gen_call_kwargs = {"grs_this_gen" : self .grs },
275- experiment = self .experiment ,
276- )
277- self .assertEqual (num_to_gen , 4 )
232+ for label , exp_prop , gs_gen_call_kwargs , expected in [
233+ ("default returns 10" , None , {}, 10 ),
234+ (
235+ "with exp prop 8 and existing grs returns 4" ,
236+ 8 ,
237+ {"grs_this_gen" : self .grs },
238+ 4 ,
239+ ),
240+ ]:
241+ with self .subTest (label ):
242+ experiment = get_branin_experiment ()
243+ node = GenerationNode (
244+ name = "test" ,
245+ generator_specs = [self .sobol_generator_spec ],
246+ )
247+ if exp_prop is not None :
248+ experiment ._properties [Keys .EXPERIMENT_TOTAL_CONCURRENT_ARMS ] = (
249+ exp_prop
250+ )
251+ num_to_gen = NodeInputConstructors .REMAINING_N (
252+ previous_node = None ,
253+ next_node = node ,
254+ gs_gen_call_kwargs = gs_gen_call_kwargs ,
255+ experiment = experiment ,
256+ )
257+ self .assertEqual (num_to_gen , expected )
278258
279259 def test_set_target_trial_long_run_wins (self ) -> None :
280260 for num_arms , trial_type in zip ((1 , 3 ), (Keys .LONG_RUN , Keys .SHORT_RUN )):
0 commit comments