Skip to content

Commit 44a460d

Browse files
eonofreymeta-codesync[bot]
authored andcommitted
Consolidate ax/core tests (facebook#5009)
Summary: Pull Request resolved: facebook#5009 Part of a 19-diff stack to consolidate repetitive tests across Ax using `subTest`. Consolidate 12 test files in ax/core/ — adds subTest to experiment, metric, parameter constraint, runner, trial, and types tests. Differential Revision: D95604288
1 parent f588404 commit 44a460d

12 files changed

Lines changed: 352 additions & 380 deletions

ax/core/tests/test_batch_trial.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -530,34 +530,35 @@ def test_Repr(self) -> None:
530530
"BatchTrial(experiment_name='test', index=0, status=TrialStatus.CANDIDATE)",
531531
)
532532

533-
def test_TTL_status_property_check(self) -> None:
534-
"""Verify that TTL is checked on execution of the `status` property."""
535-
candidate_trial = self.experiment.new_batch_trial(ttl_seconds=1)
536-
self.assertTrue(candidate_trial.status.is_candidate)
537-
sleep(1) # Wait 1 second for trial TTL to elapse.
538-
# Candidate should become stale on expiration of TTL.
539-
self.assertTrue(candidate_trial.status.is_stale)
540-
self.assertIn(1, self.experiment.trial_indices_by_status[TrialStatus.STALE])
541-
542-
def test_TTL_trial_indices_by_status_check(self) -> None:
543-
"""Verify that TTL is checked on `experiment.trial_indices_by_status`."""
544-
candidate_trial = self.experiment.new_batch_trial(ttl_seconds=1)
545-
self.assertTrue(candidate_trial.status.is_candidate)
546-
sleep(1) # Wait 1 second for trial TTL to elapse.
547-
self.assertIn(1, self.experiment.trial_indices_by_status[TrialStatus.STALE])
548-
self.assertTrue(candidate_trial.status.is_stale)
549-
550-
def test_TTL_experiment_trials_check(self) -> None:
551-
"""Verify that TTL is checked on `experiment.trials`."""
552-
candidate_trial = self.experiment.new_batch_trial(ttl_seconds=1)
553-
self.assertTrue(candidate_trial.status.is_candidate)
554-
self.assertIn(1, self.experiment.trial_indices_by_status[TrialStatus.CANDIDATE])
555-
sleep(1) # Wait 1 second for trial TTL to elapse.
556-
self.experiment.trials
557-
# Check `_status`, not `status`, to ensure it's within `trials` that the status
558-
# was actually changed, not in `status`.
559-
self.assertEqual(candidate_trial._status, TrialStatus.STALE)
560-
self.assertIn(1, self.experiment.trial_indices_by_status[TrialStatus.STALE])
533+
def test_TTL_expiry_across_access_paths(self) -> None:
534+
"""Verify that TTL is checked via different access paths."""
535+
536+
def check_via_experiment_trials(trial: BatchTrial) -> bool:
537+
# Access experiment.trials first (this is the side-effect that
538+
# triggers the TTL status update), then separately verify
539+
# that the private _status field was updated.
540+
_ = self.experiment.trials
541+
return trial._status == TrialStatus.STALE
542+
543+
ttl_access_paths = {
544+
"status_property": lambda trial: trial.status.is_stale,
545+
"trial_indices_by_status": lambda trial: (
546+
1 in self.experiment.trial_indices_by_status[TrialStatus.STALE]
547+
),
548+
"experiment_trials": check_via_experiment_trials,
549+
}
550+
for access_path, check_stale in ttl_access_paths.items():
551+
with self.subTest(access_path=access_path):
552+
# Reset by creating a fresh experiment for each sub-test
553+
self.setUp()
554+
candidate_trial = self.experiment.new_batch_trial(ttl_seconds=1)
555+
self.assertTrue(candidate_trial.status.is_candidate)
556+
sleep(1) # Wait 1 second for trial TTL to elapse.
557+
self.assertTrue(check_stale(candidate_trial))
558+
self.assertIn(
559+
1,
560+
self.experiment.trial_indices_by_status[TrialStatus.STALE],
561+
)
561562

562563
def test_get_candidate_metadata_from_all_generator_runs(self) -> None:
563564
self.assertEqual(self.batch.generation_method_str, MANUAL_GENERATION_METHOD_STR)

ax/core/tests/test_data.py

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -793,47 +793,43 @@ def setUp(self) -> None:
793793
)
794794

795795
def test_relativize_data(self) -> None:
796-
data = Data(df=self.df)
797-
expected_relativized_data = Data(df=self.expected_relativized_df)
796+
for sem_label, modify_sem in [
797+
("with_sem", False),
798+
("no_sem", True),
799+
]:
800+
with self.subTest(sem=sem_label):
801+
df = self.df.copy()
802+
expected_relativized_df = self.expected_relativized_df.copy()
803+
expected_relativized_df_with_sq = (
804+
self.expected_relativized_df_with_sq.copy()
805+
)
798806

799-
expected_relativized_data_with_sq = Data(
800-
df=self.expected_relativized_df_with_sq
801-
)
807+
if modify_sem:
808+
df["sem"] = np.nan
809+
expected_relativized_df["sem"] = np.nan
810+
expected_relativized_df_with_sq.loc[
811+
expected_relativized_df_with_sq["arm_name"] != "status_quo",
812+
"sem",
813+
] = np.nan
802814

803-
actual_relativized_data = data.relativize()
804-
self.assertEqual(expected_relativized_data, actual_relativized_data)
815+
data = Data(df=df)
816+
expected_relativized_data = Data(df=expected_relativized_df)
817+
expected_relativized_data_with_sq = Data(
818+
df=expected_relativized_df_with_sq
819+
)
805820

806-
actual_relativized_data_with_sq = data.relativize(include_sq=True)
807-
self.assertEqual(
808-
expected_relativized_data_with_sq, actual_relativized_data_with_sq
809-
)
821+
actual_relativized_data = data.relativize()
822+
self.assertEqual(expected_relativized_data, actual_relativized_data)
823+
824+
actual_relativized_data_with_sq = data.relativize(include_sq=True)
825+
self.assertEqual(
826+
expected_relativized_data_with_sq,
827+
actual_relativized_data_with_sq,
828+
)
810829

811830
with self.subTest("step column not supported"):
812831
data = Data(df=self.df.assign(step=0))
813832
with self.assertRaisesRegex(
814833
NotImplementedError, "Relativization is not supported"
815834
):
816835
data.relativize()
817-
818-
def test_relativize_data_no_sem(self) -> None:
819-
df = self.df.copy()
820-
df["sem"] = np.nan
821-
data = Data(df=df)
822-
823-
expected_relativized_df = self.expected_relativized_df.copy()
824-
expected_relativized_df["sem"] = np.nan
825-
expected_relativized_data = Data(df=expected_relativized_df)
826-
827-
expected_relativized_df_with_sq = self.expected_relativized_df_with_sq.copy()
828-
expected_relativized_df_with_sq.loc[
829-
expected_relativized_df_with_sq["arm_name"] != "status_quo", "sem"
830-
] = np.nan
831-
expected_relativized_data_with_sq = Data(df=expected_relativized_df_with_sq)
832-
833-
actual_relativized_data = data.relativize()
834-
self.assertEqual(expected_relativized_data, actual_relativized_data)
835-
836-
actual_relativized_data_with_sq = data.relativize(include_sq=True)
837-
self.assertEqual(
838-
expected_relativized_data_with_sq, actual_relativized_data_with_sq
839-
)

ax/core/tests/test_experiment.py

Lines changed: 87 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,79 +1102,70 @@ def test_immutable_search_space_and_opt_config(self) -> None:
11021102
)
11031103
self.assertTrue(immutable_exp_2.immutable_search_space_and_opt_config)
11041104

1105-
def test_attach_batch_trial_no_arm_names(self) -> None:
1106-
num_trials = len(self.experiment.trials)
1107-
1108-
_, trial_index = self.experiment.attach_trial(
1109-
parameterizations=[
1110-
{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6},
1111-
{"w": 5.2, "x": 5, "y": "foo", "z": True, "d": 11.4},
1112-
{"w": 5.1, "x": 5, "y": "bar", "z": True, "d": 11.2},
1113-
],
1114-
ttl_seconds=3600,
1115-
run_metadata={"test_metadata_field": 1},
1116-
)
1117-
1118-
self.assertEqual(len(self.experiment.trials), num_trials + 1)
1119-
self.assertEqual(
1120-
len(set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"}),
1121-
3,
1122-
)
1123-
self.assertEqual(type(self.experiment.trials[trial_index]), BatchTrial)
1124-
1125-
def test_attach_batch_trial_with_arm_names(self) -> None:
1126-
num_trials = len(self.experiment.trials)
1127-
1128-
_, trial_index = self.experiment.attach_trial(
1129-
parameterizations=[
1130-
{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6},
1131-
{"w": 5.2, "x": 5, "y": "foo", "z": True, "d": 11.4},
1132-
{"w": 5.1, "x": 5, "y": "bar", "z": True, "d": 11.2},
1133-
],
1134-
arm_names=["arm1", "arm2", "arm3"],
1135-
ttl_seconds=3600,
1136-
run_metadata={"test_metadata_field": 1},
1137-
)
1138-
1139-
self.assertEqual(len(self.experiment.trials), num_trials + 1)
1140-
self.assertEqual(
1141-
len(set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"}),
1142-
3,
1143-
)
1144-
self.assertEqual(type(self.experiment.trials[trial_index]), BatchTrial)
1145-
self.assertEqual(
1146-
{"arm1", "arm2", "arm3"},
1147-
set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"},
1148-
)
1149-
1150-
def test_attach_single_arm_trial_no_arm_name(self) -> None:
1105+
def _test_attach_trial(
1106+
self,
1107+
parameterizations: list[dict[str, object]],
1108+
arm_names: list[str] | None,
1109+
expected_type: type[BaseTrial],
1110+
expected_names: set[str] | str | None,
1111+
) -> None:
1112+
self.setUp()
11511113
num_trials = len(self.experiment.trials)
1152-
1153-
_, trial_index = self.experiment.attach_trial(
1154-
parameterizations=[{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6}],
1155-
ttl_seconds=3600,
1156-
run_metadata={"test_metadata_field": 1},
1157-
)
1114+
if arm_names is not None:
1115+
_, trial_index = self.experiment.attach_trial(
1116+
parameterizations=parameterizations,
1117+
arm_names=arm_names,
1118+
ttl_seconds=3600,
1119+
run_metadata={"test_metadata_field": 1},
1120+
)
1121+
else:
1122+
_, trial_index = self.experiment.attach_trial(
1123+
parameterizations=parameterizations,
1124+
ttl_seconds=3600,
1125+
run_metadata={"test_metadata_field": 1},
1126+
)
11581127

11591128
self.assertEqual(len(self.experiment.trials), num_trials + 1)
1160-
self.assertEqual(type(self.experiment.trials[trial_index]), Trial)
1161-
1162-
def test_attach_single_arm_trial_with_arm_name(self) -> None:
1163-
num_trials = len(self.experiment.trials)
1164-
1165-
_, trial_index = self.experiment.attach_trial(
1166-
parameterizations=[{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6}],
1167-
arm_names=["arm1"],
1168-
ttl_seconds=3600,
1169-
run_metadata={"test_metadata_field": 1},
1170-
)
1129+
self.assertEqual(type(self.experiment.trials[trial_index]), expected_type)
1130+
if isinstance(expected_names, set):
1131+
self.assertEqual(
1132+
expected_names,
1133+
set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"},
1134+
)
1135+
elif isinstance(expected_names, str):
1136+
self.assertEqual(
1137+
expected_names,
1138+
self.experiment.trials[trial_index].arm.name,
1139+
)
11711140

1172-
self.assertEqual(len(self.experiment.trials), num_trials + 1)
1173-
self.assertEqual(type(self.experiment.trials[trial_index]), Trial)
1174-
self.assertEqual(
1175-
"arm1",
1176-
self.experiment.trials[trial_index].arm.name,
1177-
)
1141+
def test_attach_trial(self) -> None:
1142+
batch_params: list[dict[str, object]] = [
1143+
{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6},
1144+
{"w": 5.2, "x": 5, "y": "foo", "z": True, "d": 11.4},
1145+
{"w": 5.1, "x": 5, "y": "bar", "z": True, "d": 11.2},
1146+
]
1147+
single_params: list[dict[str, object]] = [
1148+
{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6},
1149+
]
1150+
for label, params, arm_names, expected_type, expected_names in [
1151+
("batch_no_arm_names", batch_params, None, BatchTrial, None),
1152+
(
1153+
"batch_with_arm_names",
1154+
batch_params,
1155+
["arm1", "arm2", "arm3"],
1156+
BatchTrial,
1157+
{"arm1", "arm2", "arm3"},
1158+
),
1159+
("single_no_arm_name", single_params, None, Trial, None),
1160+
("single_with_arm_name", single_params, ["arm1"], Trial, "arm1"),
1161+
]:
1162+
with self.subTest(label=label):
1163+
self._test_attach_trial(
1164+
parameterizations=params,
1165+
arm_names=arm_names,
1166+
expected_type=expected_type,
1167+
expected_names=expected_names,
1168+
)
11781169

11791170
def test_fetch_as_class(self) -> None:
11801171
class MyMetric(Metric):
@@ -2530,45 +2521,36 @@ def test_is_bope_problem(self) -> None:
25302521
)
25312522
self.assertFalse(experiment.is_bope_problem)
25322523

2533-
def test_name_and_store_arm_if_not_exists_same_name_different_signature(
2534-
self,
2535-
) -> None:
2536-
experiment = self.experiment
2537-
shared_name = "shared_name"
2538-
2539-
arm_1 = Arm({"x1": -1.0, "x2": 1.0}, name=shared_name)
2540-
arm_2 = Arm({"x1": -1.7, "x2": 0.2, "x3": 1})
2541-
self.assertNotEqual(arm_1.signature, arm_2.signature)
2542-
2543-
experiment._register_arm(arm=arm_1)
2544-
with self.assertRaisesRegex(
2545-
AxError,
2546-
f"Arm with name {shared_name} already exists on experiment "
2547-
f"with different signature.",
2548-
):
2549-
experiment._name_and_store_arm_if_not_exists(
2550-
arm=arm_2, proposed_name=shared_name
2551-
)
2552-
2553-
def test_name_and_store_arm_if_not_exists_same_proposed_name_different_signature(
2554-
self,
2555-
) -> None:
2556-
experiment = self.experiment
2524+
def test_name_and_store_arm_if_not_exists_different_signature(self) -> None:
25572525
shared_name = "shared_name"
2558-
25592526
arm_1 = Arm({"x1": -1.0, "x2": 1.0}, name=shared_name)
2560-
arm_2 = Arm({"x1": -1.7, "x2": 0.2, "x3": 1}, name=shared_name)
2561-
self.assertNotEqual(arm_1.signature, arm_2.signature)
25622527

2563-
experiment._register_arm(arm=arm_1)
2564-
with self.assertRaisesRegex(
2565-
AxError,
2566-
f"Arm with name {shared_name} already exists on experiment "
2567-
f"with different signature.",
2568-
):
2569-
experiment._name_and_store_arm_if_not_exists(
2570-
arm=arm_2, proposed_name="different proposed name"
2571-
)
2528+
cases = [
2529+
(
2530+
"arm_without_name",
2531+
Arm({"x1": -1.7, "x2": 0.2, "x3": 1}),
2532+
shared_name,
2533+
),
2534+
(
2535+
"arm_with_shared_name",
2536+
Arm({"x1": -1.7, "x2": 0.2, "x3": 1}, name=shared_name),
2537+
"different proposed name",
2538+
),
2539+
]
2540+
for label, arm_2, proposed_name in cases:
2541+
with self.subTest(label=label):
2542+
self.setUp()
2543+
experiment = self.experiment
2544+
experiment._register_arm(arm=arm_1)
2545+
self.assertNotEqual(arm_1.signature, arm_2.signature)
2546+
with self.assertRaisesRegex(
2547+
AxError,
2548+
f"Arm with name {shared_name} already exists on experiment "
2549+
f"with different signature.",
2550+
):
2551+
experiment._name_and_store_arm_if_not_exists(
2552+
arm=arm_2, proposed_name=proposed_name
2553+
)
25722554

25732555
def test_sorting_data_by_trial_index_and_arm_name(self) -> None:
25742556
# test sorting data

0 commit comments

Comments
 (0)