Skip to content

Commit df7c1bd

Browse files
eonofreymeta-codesync[bot]
authored andcommitted
Consolidate ax/early_stopping and ax/generation_strategy tests (#5010)
Summary: Pull Request resolved: #5010 Part of a 19-diff stack to consolidate repetitive tests across Ax using `subTest`. Consolidate 2 test files in ax/generation_strategy/ — adds subTest to best model selector criterion/aggregation tests and generation node input constructor tests. Differential Revision: D95604515
1 parent eb9a957 commit df7c1bd

2 files changed

Lines changed: 131 additions & 152 deletions

File tree

ax/generation_strategy/tests/test_best_model_selector.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -48,38 +48,37 @@ def test_user_input_error(self) -> None:
4848
criterion=ReductionCriterion.MEAN,
4949
)
5050

51-
def test_SingleDiagnosticBestModelSelector_min_mean(self) -> None:
52-
s = SingleDiagnosticBestModelSelector(
53-
diagnostic="Fisher exact test p",
54-
criterion=ReductionCriterion.MIN,
55-
metric_aggregation=ReductionCriterion.MEAN,
56-
)
57-
# Min/mean will pick index 1 since it has the lowest mean (0.1 vs 0.2 & 0.55).
58-
self.assertIs(
59-
s.best_model(generator_specs=self.generator_specs), self.generator_specs[1]
60-
)
61-
62-
def test_SingleDiagnosticBestModelSelector_min_min(self) -> None:
63-
s = SingleDiagnosticBestModelSelector(
64-
diagnostic="Fisher exact test p",
65-
criterion=ReductionCriterion.MIN,
66-
metric_aggregation=ReductionCriterion.MIN,
67-
)
68-
# Min/min will pick index 0 since it has the lowest min (0.0 vs 0.1 & 0.5).
69-
self.assertIs(
70-
s.best_model(generator_specs=self.generator_specs), self.generator_specs[0]
71-
)
72-
73-
def test_SingleDiagnosticBestModelSelector_max_mean(self) -> None:
74-
s = SingleDiagnosticBestModelSelector(
75-
diagnostic="Fisher exact test p",
76-
criterion=ReductionCriterion.MAX,
77-
metric_aggregation=ReductionCriterion.MEAN,
78-
)
79-
# Max/mean will pick index 2 since it has the largest mean (0.55 vs 0.1 & 0.2).
80-
self.assertIs(
81-
s.best_model(generator_specs=self.generator_specs), self.generator_specs[2]
82-
)
51+
def test_SingleDiagnosticBestModelSelector_criterion_aggregation(self) -> None:
52+
for label, criterion, aggregation, expected_index in [
53+
(
54+
"min_mean picks index 1 (lowest mean: 0.1)",
55+
ReductionCriterion.MIN,
56+
ReductionCriterion.MEAN,
57+
1,
58+
),
59+
(
60+
"min_min picks index 0 (lowest min: 0.0)",
61+
ReductionCriterion.MIN,
62+
ReductionCriterion.MIN,
63+
0,
64+
),
65+
(
66+
"max_mean picks index 2 (largest mean: 0.55)",
67+
ReductionCriterion.MAX,
68+
ReductionCriterion.MEAN,
69+
2,
70+
),
71+
]:
72+
with self.subTest(label):
73+
s = SingleDiagnosticBestModelSelector(
74+
diagnostic="Fisher exact test p",
75+
criterion=criterion,
76+
metric_aggregation=aggregation,
77+
)
78+
self.assertIs(
79+
s.best_model(generator_specs=self.generator_specs),
80+
self.generator_specs[expected_index],
81+
)
8382

8483
def test_SingleDiagnosticBestModelSelector_cv_kwargs(self) -> None:
8584
s = SingleDiagnosticBestModelSelector(

ax/generation_strategy/tests/test_generation_node_input_constructors.py

Lines changed: 100 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -149,132 +149,112 @@ def test_repeat_arm_n_constructor_return_0(self) -> None:
149149
self.assertEqual(small_n, 0)
150150
self.assertTrue(self.sobol_generation_node._should_skip)
151151

152-
def test_remaining_n_constructor_expect_1(self) -> None:
153-
"""Test that the remaining_n_constructor returns the remaining n."""
154-
# should return 1 because 4 arms already exist and 5 are requested
155-
expect_1 = NodeInputConstructors.REMAINING_N(
156-
previous_node=None,
157-
next_node=self.sobol_generation_node,
158-
gs_gen_call_kwargs={"n": 5, "grs_this_gen": self.grs},
159-
experiment=self.experiment,
160-
)
161-
self.assertEqual(expect_1, 1)
162-
163-
def test_remaining_n_constructor_expect_0(self) -> None:
164-
# should return 0 because 4 arms already exist and 4 are requested
165-
expect_0 = NodeInputConstructors.REMAINING_N(
166-
previous_node=None,
167-
next_node=self.sobol_generation_node,
168-
gs_gen_call_kwargs={"n": 4, "grs_this_gen": self.grs},
169-
experiment=self.experiment,
170-
)
171-
self.assertEqual(expect_0, 0)
172-
173-
def test_remaining_n_constructor_cap_at_zero(self) -> None:
174-
# should return 0 because 4 arms already exist and 3 are requested
175-
# this is a bad state that should never be hit, but ensuring proper
176-
# handling here feels like a valid edge case
177-
expect_0 = NodeInputConstructors.REMAINING_N(
178-
previous_node=None,
179-
next_node=self.sobol_generation_node,
180-
gs_gen_call_kwargs={"n": 3, "grs_this_gen": self.grs},
181-
experiment=self.experiment,
182-
)
183-
self.assertEqual(expect_0, 0)
152+
def test_remaining_n_constructor(self) -> None:
153+
"""Test that the remaining_n_constructor returns the correct remaining n."""
154+
for label, n, expected in [
155+
("returns 1 when 4 arms exist and 5 requested", 5, 1),
156+
("returns 0 when 4 arms exist and 4 requested", 4, 0),
157+
("caps at 0 when 4 arms exist and 3 requested", 3, 0),
158+
]:
159+
with self.subTest(label):
160+
result = NodeInputConstructors.REMAINING_N(
161+
previous_node=None,
162+
next_node=self.sobol_generation_node,
163+
gs_gen_call_kwargs={"n": n, "grs_this_gen": self.grs},
164+
experiment=self.experiment,
165+
)
166+
self.assertEqual(result, expected)
184167

185168
def test_no_n_provided_all_n(self) -> None:
186-
num_to_gen = NodeInputConstructors.ALL_N(
187-
previous_node=None,
188-
next_node=self.sobol_generation_node,
189-
gs_gen_call_kwargs={},
190-
experiment=self.experiment,
191-
)
192-
self.assertEqual(num_to_gen, 10)
193-
194-
def test_no_n_provided_all_n_with_exp_prop(self) -> None:
195-
self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 12
196-
num_to_gen = NodeInputConstructors.ALL_N(
197-
previous_node=None,
198-
next_node=self.sobol_generation_node,
199-
gs_gen_call_kwargs={},
200-
experiment=self.experiment,
201-
)
202-
self.assertEqual(num_to_gen, 12)
203-
204-
def test_no_n_provided_all_n_with_exp_prop_long_run(self) -> None:
205-
self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 13
206-
self.sobol_generation_node._trial_type = Keys.LONG_RUN
207-
num_to_gen = NodeInputConstructors.ALL_N(
208-
previous_node=None,
209-
next_node=self.sobol_generation_node,
210-
gs_gen_call_kwargs={},
211-
experiment=self.experiment,
212-
)
213-
self.assertEqual(num_to_gen, 7)
214-
215-
def test_no_n_provided_all_n_with_exp_prop_short_run(self) -> None:
216-
self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 13
217-
self.sobol_generation_node._trial_type = Keys.SHORT_RUN
218-
num_to_gen = NodeInputConstructors.ALL_N(
219-
previous_node=None,
220-
next_node=self.sobol_generation_node,
221-
gs_gen_call_kwargs={},
222-
experiment=self.experiment,
223-
)
224-
self.assertEqual(num_to_gen, 6)
169+
for label, exp_prop, trial_type, expected in [
170+
("default returns 10", None, None, 10),
171+
("with exp prop 12", 12, None, 12),
172+
("with exp prop 13 and long_run", 13, Keys.LONG_RUN, 7),
173+
("with exp prop 13 and short_run", 13, Keys.SHORT_RUN, 6),
174+
]:
175+
with self.subTest(label):
176+
# Reset state for each subtest
177+
experiment = get_branin_experiment()
178+
node = GenerationNode(
179+
name="test",
180+
generator_specs=[self.sobol_generator_spec],
181+
)
182+
if exp_prop is not None:
183+
experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = (
184+
exp_prop
185+
)
186+
if trial_type is not None:
187+
node._trial_type = trial_type
188+
num_to_gen = NodeInputConstructors.ALL_N(
189+
previous_node=None,
190+
next_node=node,
191+
gs_gen_call_kwargs={},
192+
experiment=experiment,
193+
)
194+
self.assertEqual(num_to_gen, expected)
225195

226196
def test_no_n_provided_repeat_n(self) -> None:
227-
num_to_gen = NodeInputConstructors.REPEAT_N(
228-
previous_node=None,
229-
next_node=self.sobol_generation_node,
230-
gs_gen_call_kwargs={},
231-
experiment=self.experiment,
232-
)
233-
self.assertEqual(num_to_gen, 1)
234-
235-
def test_no_n_provided_repeat_n_with_exp_prop(self) -> None:
236-
self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 18
237-
num_to_gen = NodeInputConstructors.REPEAT_N(
238-
previous_node=None,
239-
next_node=self.sobol_generation_node,
240-
gs_gen_call_kwargs={},
241-
experiment=self.experiment,
242-
)
243-
self.assertEqual(num_to_gen, 2)
244-
245-
def test_no_n_provided_repeat_n_with_exp_prop_long_run(self) -> None:
246-
self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 18
247-
self.sobol_generation_node._trial_type = Keys.SHORT_RUN
248-
num_to_gen = NodeInputConstructors.REPEAT_N(
249-
previous_node=None,
250-
next_node=self.sobol_generation_node,
251-
gs_gen_call_kwargs={},
252-
experiment=self.experiment,
253-
)
254-
# expect 1 arm here because total concurrent arms is 18, and we have a trial
255-
# type (short run), so we'll take the floor of 18/2 = 9 to be used in the
256-
# logic for repeat arms which says if we have less than 10 requested arms we
257-
# should get 1 repeat arm.
258-
self.assertEqual(num_to_gen, 1)
197+
for label, exp_prop, trial_type, expected in [
198+
("default returns 1", None, None, 1),
199+
("with exp prop 18 returns 2", 18, None, 2),
200+
(
201+
"with exp prop 18 and short_run returns 1",
202+
18,
203+
Keys.SHORT_RUN,
204+
# expect 1 because total concurrent arms is 18, and we have a trial
205+
# type (short run), so we'll take the floor of 18/2 = 9 to be used
206+
# in the logic for repeat arms which says if we have less than 10
207+
# requested arms we should get 1 repeat arm.
208+
1,
209+
),
210+
]:
211+
with self.subTest(label):
212+
experiment = get_branin_experiment()
213+
node = GenerationNode(
214+
name="test",
215+
generator_specs=[self.sobol_generator_spec],
216+
)
217+
if exp_prop is not None:
218+
experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = (
219+
exp_prop
220+
)
221+
if trial_type is not None:
222+
node._trial_type = trial_type
223+
num_to_gen = NodeInputConstructors.REPEAT_N(
224+
previous_node=None,
225+
next_node=node,
226+
gs_gen_call_kwargs={},
227+
experiment=experiment,
228+
)
229+
self.assertEqual(num_to_gen, expected)
259230

260231
def test_no_n_provided_remaining_n(self) -> None:
261-
num_to_gen = NodeInputConstructors.REMAINING_N(
262-
previous_node=None,
263-
next_node=self.sobol_generation_node,
264-
gs_gen_call_kwargs={},
265-
experiment=self.experiment,
266-
)
267-
self.assertEqual(num_to_gen, 10)
268-
269-
def test_no_n_provided_remaining_n_with_exp_prop(self) -> None:
270-
self.experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = 8
271-
num_to_gen = NodeInputConstructors.REMAINING_N(
272-
previous_node=None,
273-
next_node=self.sobol_generation_node,
274-
gs_gen_call_kwargs={"grs_this_gen": self.grs},
275-
experiment=self.experiment,
276-
)
277-
self.assertEqual(num_to_gen, 4)
232+
for label, exp_prop, gs_gen_call_kwargs, expected in [
233+
("default returns 10", None, {}, 10),
234+
(
235+
"with exp prop 8 and existing grs returns 4",
236+
8,
237+
{"grs_this_gen": self.grs},
238+
4,
239+
),
240+
]:
241+
with self.subTest(label):
242+
experiment = get_branin_experiment()
243+
node = GenerationNode(
244+
name="test",
245+
generator_specs=[self.sobol_generator_spec],
246+
)
247+
if exp_prop is not None:
248+
experiment._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS] = (
249+
exp_prop
250+
)
251+
num_to_gen = NodeInputConstructors.REMAINING_N(
252+
previous_node=None,
253+
next_node=node,
254+
gs_gen_call_kwargs=gs_gen_call_kwargs,
255+
experiment=experiment,
256+
)
257+
self.assertEqual(num_to_gen, expected)
278258

279259
def test_set_target_trial_long_run_wins(self) -> None:
280260
for num_arms, trial_type in zip((1, 3), (Keys.LONG_RUN, Keys.SHORT_RUN)):

0 commit comments

Comments
 (0)