Skip to content

Commit a817f4e

Browse files
eonofreymeta-codesync[bot]
authored andcommitted
Consolidate ax/core tests (#5009)
Summary: Pull Request resolved: #5009 Part of a 19-diff stack to consolidate repetitive tests across Ax using `subTest`. Consolidate 12 test files in ax/core/ — adds subTest to experiment, metric, parameter constraint, runner, trial, and types tests. Differential Revision: D95604288
1 parent 55d992b commit a817f4e

12 files changed

Lines changed: 366 additions & 380 deletions

ax/core/tests/test_batch_trial.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -530,34 +530,39 @@ def test_Repr(self) -> None:
530530
"BatchTrial(experiment_name='test', index=0, status=TrialStatus.CANDIDATE)",
531531
)
532532

533-
def test_TTL_status_property_check(self) -> None:
534-
"""Verify that TTL is checked on execution of the `status` property."""
535-
candidate_trial = self.experiment.new_batch_trial(ttl_seconds=1)
536-
self.assertTrue(candidate_trial.status.is_candidate)
537-
sleep(1) # Wait 1 second for trial TTL to elapse.
538-
# Candidate should become stale on expiration of TTL.
539-
self.assertTrue(candidate_trial.status.is_stale)
540-
self.assertIn(1, self.experiment.trial_indices_by_status[TrialStatus.STALE])
541-
542-
def test_TTL_trial_indices_by_status_check(self) -> None:
543-
"""Verify that TTL is checked on `experiment.trial_indices_by_status`."""
544-
candidate_trial = self.experiment.new_batch_trial(ttl_seconds=1)
545-
self.assertTrue(candidate_trial.status.is_candidate)
546-
sleep(1) # Wait 1 second for trial TTL to elapse.
547-
self.assertIn(1, self.experiment.trial_indices_by_status[TrialStatus.STALE])
548-
self.assertTrue(candidate_trial.status.is_stale)
549-
550-
def test_TTL_experiment_trials_check(self) -> None:
551-
"""Verify that TTL is checked on `experiment.trials`."""
552-
candidate_trial = self.experiment.new_batch_trial(ttl_seconds=1)
553-
self.assertTrue(candidate_trial.status.is_candidate)
554-
self.assertIn(1, self.experiment.trial_indices_by_status[TrialStatus.CANDIDATE])
555-
sleep(1) # Wait 1 second for trial TTL to elapse.
556-
self.experiment.trials
557-
# Check `_status`, not `status`, to ensure it's within `trials` that the status
558-
# was actually changed, not in `status`.
559-
self.assertEqual(candidate_trial._status, TrialStatus.STALE)
560-
self.assertIn(1, self.experiment.trial_indices_by_status[TrialStatus.STALE])
533+
def test_TTL_expiry_across_access_paths(self) -> None:
534+
"""Verify that TTL is checked via different access paths."""
535+
536+
def check_via_experiment_trials(trial: BatchTrial) -> bool:
537+
# Access experiment.trials first (this is the side-effect that
538+
# triggers the TTL status update), then separately verify
539+
# that the private _status field was updated.
540+
_ = self.experiment.trials
541+
return trial._status == TrialStatus.STALE
542+
543+
# Three different code paths that should all trigger TTL expiry:
544+
# 1. Accessing trial.status directly
545+
# 2. Querying experiment.trial_indices_by_status
546+
# 3. Accessing experiment.trials (updates _status as side effect)
547+
ttl_access_paths = {
548+
"status_property": lambda trial: trial.status.is_stale,
549+
"trial_indices_by_status": lambda trial: (
550+
1 in self.experiment.trial_indices_by_status[TrialStatus.STALE]
551+
),
552+
"experiment_trials": check_via_experiment_trials,
553+
}
554+
for access_path, check_stale in ttl_access_paths.items():
555+
with self.subTest(access_path=access_path):
556+
# Reset by creating a fresh experiment for each sub-test
557+
self.setUp()
558+
candidate_trial = self.experiment.new_batch_trial(ttl_seconds=1)
559+
self.assertTrue(candidate_trial.status.is_candidate)
560+
sleep(1) # Wait 1 second for trial TTL to elapse.
561+
self.assertTrue(check_stale(candidate_trial))
562+
self.assertIn(
563+
1,
564+
self.experiment.trial_indices_by_status[TrialStatus.STALE],
565+
)
561566

562567
def test_get_candidate_metadata_from_all_generator_runs(self) -> None:
563568
self.assertEqual(self.batch.generation_method_str, MANUAL_GENERATION_METHOD_STR)

ax/core/tests/test_data.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -793,47 +793,44 @@ def setUp(self) -> None:
793793
)
794794

795795
def test_relativize_data(self) -> None:
796-
data = Data(df=self.df)
797-
expected_relativized_data = Data(df=self.expected_relativized_df)
796+
# Test relativization both with valid SEM values and with NaN SEMs
797+
for sem_label, modify_sem in [
798+
("with_sem", False),
799+
("no_sem", True),
800+
]:
801+
with self.subTest(sem=sem_label):
802+
df = self.df.copy()
803+
expected_relativized_df = self.expected_relativized_df.copy()
804+
expected_relativized_df_with_sq = (
805+
self.expected_relativized_df_with_sq.copy()
806+
)
798807

799-
expected_relativized_data_with_sq = Data(
800-
df=self.expected_relativized_df_with_sq
801-
)
808+
if modify_sem:
809+
df["sem"] = np.nan
810+
expected_relativized_df["sem"] = np.nan
811+
expected_relativized_df_with_sq.loc[
812+
expected_relativized_df_with_sq["arm_name"] != "status_quo",
813+
"sem",
814+
] = np.nan
802815

803-
actual_relativized_data = data.relativize()
804-
self.assertEqual(expected_relativized_data, actual_relativized_data)
816+
data = Data(df=df)
817+
expected_relativized_data = Data(df=expected_relativized_df)
818+
expected_relativized_data_with_sq = Data(
819+
df=expected_relativized_df_with_sq
820+
)
805821

806-
actual_relativized_data_with_sq = data.relativize(include_sq=True)
807-
self.assertEqual(
808-
expected_relativized_data_with_sq, actual_relativized_data_with_sq
809-
)
822+
actual_relativized_data = data.relativize()
823+
self.assertEqual(expected_relativized_data, actual_relativized_data)
824+
825+
actual_relativized_data_with_sq = data.relativize(include_sq=True)
826+
self.assertEqual(
827+
expected_relativized_data_with_sq,
828+
actual_relativized_data_with_sq,
829+
)
810830

811831
with self.subTest("step column not supported"):
812832
data = Data(df=self.df.assign(step=0))
813833
with self.assertRaisesRegex(
814834
NotImplementedError, "Relativization is not supported"
815835
):
816836
data.relativize()
817-
818-
def test_relativize_data_no_sem(self) -> None:
819-
df = self.df.copy()
820-
df["sem"] = np.nan
821-
data = Data(df=df)
822-
823-
expected_relativized_df = self.expected_relativized_df.copy()
824-
expected_relativized_df["sem"] = np.nan
825-
expected_relativized_data = Data(df=expected_relativized_df)
826-
827-
expected_relativized_df_with_sq = self.expected_relativized_df_with_sq.copy()
828-
expected_relativized_df_with_sq.loc[
829-
expected_relativized_df_with_sq["arm_name"] != "status_quo", "sem"
830-
] = np.nan
831-
expected_relativized_data_with_sq = Data(df=expected_relativized_df_with_sq)
832-
833-
actual_relativized_data = data.relativize()
834-
self.assertEqual(expected_relativized_data, actual_relativized_data)
835-
836-
actual_relativized_data_with_sq = data.relativize(include_sq=True)
837-
self.assertEqual(
838-
expected_relativized_data_with_sq, actual_relativized_data_with_sq
839-
)

ax/core/tests/test_experiment.py

Lines changed: 93 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,79 +1102,74 @@ def test_immutable_search_space_and_opt_config(self) -> None:
11021102
)
11031103
self.assertTrue(immutable_exp_2.immutable_search_space_and_opt_config)
11041104

1105-
def test_attach_batch_trial_no_arm_names(self) -> None:
1106-
num_trials = len(self.experiment.trials)
1107-
1108-
_, trial_index = self.experiment.attach_trial(
1109-
parameterizations=[
1110-
{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6},
1111-
{"w": 5.2, "x": 5, "y": "foo", "z": True, "d": 11.4},
1112-
{"w": 5.1, "x": 5, "y": "bar", "z": True, "d": 11.2},
1113-
],
1114-
ttl_seconds=3600,
1115-
run_metadata={"test_metadata_field": 1},
1116-
)
1117-
1118-
self.assertEqual(len(self.experiment.trials), num_trials + 1)
1119-
self.assertEqual(
1120-
len(set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"}),
1121-
3,
1122-
)
1123-
self.assertEqual(type(self.experiment.trials[trial_index]), BatchTrial)
1124-
1125-
def test_attach_batch_trial_with_arm_names(self) -> None:
1126-
num_trials = len(self.experiment.trials)
1127-
1128-
_, trial_index = self.experiment.attach_trial(
1129-
parameterizations=[
1130-
{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6},
1131-
{"w": 5.2, "x": 5, "y": "foo", "z": True, "d": 11.4},
1132-
{"w": 5.1, "x": 5, "y": "bar", "z": True, "d": 11.2},
1133-
],
1134-
arm_names=["arm1", "arm2", "arm3"],
1135-
ttl_seconds=3600,
1136-
run_metadata={"test_metadata_field": 1},
1137-
)
1138-
1139-
self.assertEqual(len(self.experiment.trials), num_trials + 1)
1140-
self.assertEqual(
1141-
len(set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"}),
1142-
3,
1143-
)
1144-
self.assertEqual(type(self.experiment.trials[trial_index]), BatchTrial)
1145-
self.assertEqual(
1146-
{"arm1", "arm2", "arm3"},
1147-
set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"},
1148-
)
1149-
1150-
def test_attach_single_arm_trial_no_arm_name(self) -> None:
1105+
def _test_attach_trial(
1106+
self,
1107+
parameterizations: list[dict[str, object]],
1108+
arm_names: list[str] | None,
1109+
expected_type: type[BaseTrial],
1110+
expected_names: set[str] | str | None,
1111+
) -> None:
1112+
self.setUp()
11511113
num_trials = len(self.experiment.trials)
1152-
1153-
_, trial_index = self.experiment.attach_trial(
1154-
parameterizations=[{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6}],
1155-
ttl_seconds=3600,
1156-
run_metadata={"test_metadata_field": 1},
1157-
)
1114+
if arm_names is not None:
1115+
_, trial_index = self.experiment.attach_trial(
1116+
parameterizations=parameterizations,
1117+
arm_names=arm_names,
1118+
ttl_seconds=3600,
1119+
run_metadata={"test_metadata_field": 1},
1120+
)
1121+
else:
1122+
_, trial_index = self.experiment.attach_trial(
1123+
parameterizations=parameterizations,
1124+
ttl_seconds=3600,
1125+
run_metadata={"test_metadata_field": 1},
1126+
)
11581127

11591128
self.assertEqual(len(self.experiment.trials), num_trials + 1)
1160-
self.assertEqual(type(self.experiment.trials[trial_index]), Trial)
1161-
1162-
def test_attach_single_arm_trial_with_arm_name(self) -> None:
1163-
num_trials = len(self.experiment.trials)
1164-
1165-
_, trial_index = self.experiment.attach_trial(
1166-
parameterizations=[{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6}],
1167-
arm_names=["arm1"],
1168-
ttl_seconds=3600,
1169-
run_metadata={"test_metadata_field": 1},
1170-
)
1129+
self.assertEqual(type(self.experiment.trials[trial_index]), expected_type)
1130+
if isinstance(expected_names, set):
1131+
self.assertEqual(
1132+
expected_names,
1133+
set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"},
1134+
)
1135+
elif isinstance(expected_names, str):
1136+
self.assertEqual(
1137+
expected_names,
1138+
self.experiment.trials[trial_index].arm.name,
1139+
)
11711140

1172-
self.assertEqual(len(self.experiment.trials), num_trials + 1)
1173-
self.assertEqual(type(self.experiment.trials[trial_index]), Trial)
1174-
self.assertEqual(
1175-
"arm1",
1176-
self.experiment.trials[trial_index].arm.name,
1177-
)
1141+
def test_attach_trial(self) -> None:
1142+
batch_params: list[dict[str, object]] = [
1143+
{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6},
1144+
{"w": 5.2, "x": 5, "y": "foo", "z": True, "d": 11.4},
1145+
{"w": 5.1, "x": 5, "y": "bar", "z": True, "d": 11.2},
1146+
]
1147+
single_params: list[dict[str, object]] = [
1148+
{"w": 5.3, "x": 5, "y": "baz", "z": True, "d": 11.6},
1149+
]
1150+
for label, params, arm_names, expected_type, expected_names in [
1151+
# Batch trial without explicit arm names -> auto-generated names
1152+
("batch_no_arm_names", batch_params, None, BatchTrial, None),
1153+
# Batch trial with explicit arm names -> names preserved
1154+
(
1155+
"batch_with_arm_names",
1156+
batch_params,
1157+
["arm1", "arm2", "arm3"],
1158+
BatchTrial,
1159+
{"arm1", "arm2", "arm3"},
1160+
),
1161+
# Single-arm trial without arm name -> creates Trial (not BatchTrial)
1162+
("single_no_arm_name", single_params, None, Trial, None),
1163+
# Single-arm trial with explicit arm name -> name preserved
1164+
("single_with_arm_name", single_params, ["arm1"], Trial, "arm1"),
1165+
]:
1166+
with self.subTest(label=label):
1167+
self._test_attach_trial(
1168+
parameterizations=params,
1169+
arm_names=arm_names,
1170+
expected_type=expected_type,
1171+
expected_names=expected_names,
1172+
)
11781173

11791174
def test_fetch_as_class(self) -> None:
11801175
class MyMetric(Metric):
@@ -2530,45 +2525,38 @@ def test_is_bope_problem(self) -> None:
25302525
)
25312526
self.assertFalse(experiment.is_bope_problem)
25322527

2533-
def test_name_and_store_arm_if_not_exists_same_name_different_signature(
2534-
self,
2535-
) -> None:
2536-
experiment = self.experiment
2537-
shared_name = "shared_name"
2538-
2539-
arm_1 = Arm({"x1": -1.0, "x2": 1.0}, name=shared_name)
2540-
arm_2 = Arm({"x1": -1.7, "x2": 0.2, "x3": 1})
2541-
self.assertNotEqual(arm_1.signature, arm_2.signature)
2542-
2543-
experiment._register_arm(arm=arm_1)
2544-
with self.assertRaisesRegex(
2545-
AxError,
2546-
f"Arm with name {shared_name} already exists on experiment "
2547-
f"with different signature.",
2548-
):
2549-
experiment._name_and_store_arm_if_not_exists(
2550-
arm=arm_2, proposed_name=shared_name
2551-
)
2552-
2553-
def test_name_and_store_arm_if_not_exists_same_proposed_name_different_signature(
2554-
self,
2555-
) -> None:
2556-
experiment = self.experiment
2528+
def test_name_and_store_arm_if_not_exists_different_signature(self) -> None:
25572529
shared_name = "shared_name"
2558-
25592530
arm_1 = Arm({"x1": -1.0, "x2": 1.0}, name=shared_name)
2560-
arm_2 = Arm({"x1": -1.7, "x2": 0.2, "x3": 1}, name=shared_name)
2561-
self.assertNotEqual(arm_1.signature, arm_2.signature)
25622531

2563-
experiment._register_arm(arm=arm_1)
2564-
with self.assertRaisesRegex(
2565-
AxError,
2566-
f"Arm with name {shared_name} already exists on experiment "
2567-
f"with different signature.",
2568-
):
2569-
experiment._name_and_store_arm_if_not_exists(
2570-
arm=arm_2, proposed_name="different proposed name"
2571-
)
2532+
cases = [
2533+
# Unnamed arm with proposed_name matching existing arm's name
2534+
(
2535+
"arm_without_name",
2536+
Arm({"x1": -1.7, "x2": 0.2, "x3": 1}),
2537+
shared_name,
2538+
),
2539+
# Named arm matching existing arm but with a different proposed_name
2540+
(
2541+
"arm_with_shared_name",
2542+
Arm({"x1": -1.7, "x2": 0.2, "x3": 1}, name=shared_name),
2543+
"different proposed name",
2544+
),
2545+
]
2546+
for label, arm_2, proposed_name in cases:
2547+
with self.subTest(label=label):
2548+
self.setUp()
2549+
experiment = self.experiment
2550+
experiment._register_arm(arm=arm_1)
2551+
self.assertNotEqual(arm_1.signature, arm_2.signature)
2552+
with self.assertRaisesRegex(
2553+
AxError,
2554+
f"Arm with name {shared_name} already exists on experiment "
2555+
f"with different signature.",
2556+
):
2557+
experiment._name_and_store_arm_if_not_exists(
2558+
arm=arm_2, proposed_name=proposed_name
2559+
)
25722560

25732561
def test_sorting_data_by_trial_index_and_arm_name(self) -> None:
25742562
# test sorting data

0 commit comments

Comments
 (0)