@@ -1102,79 +1102,70 @@ def test_immutable_search_space_and_opt_config(self) -> None:
11021102 )
11031103 self .assertTrue (immutable_exp_2 .immutable_search_space_and_opt_config )
11041104
1105- def test_attach_batch_trial_no_arm_names (self ) -> None :
1106- num_trials = len (self .experiment .trials )
1107-
1108- _ , trial_index = self .experiment .attach_trial (
1109- parameterizations = [
1110- {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1111- {"w" : 5.2 , "x" : 5 , "y" : "foo" , "z" : True , "d" : 11.4 },
1112- {"w" : 5.1 , "x" : 5 , "y" : "bar" , "z" : True , "d" : 11.2 },
1113- ],
1114- ttl_seconds = 3600 ,
1115- run_metadata = {"test_metadata_field" : 1 },
1116- )
1117-
1118- self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1119- self .assertEqual (
1120- len (set (self .experiment .trials [trial_index ].arms_by_name ) - {"status_quo" }),
1121- 3 ,
1122- )
1123- self .assertEqual (type (self .experiment .trials [trial_index ]), BatchTrial )
1124-
1125- def test_attach_batch_trial_with_arm_names (self ) -> None :
1126- num_trials = len (self .experiment .trials )
1127-
1128- _ , trial_index = self .experiment .attach_trial (
1129- parameterizations = [
1130- {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1131- {"w" : 5.2 , "x" : 5 , "y" : "foo" , "z" : True , "d" : 11.4 },
1132- {"w" : 5.1 , "x" : 5 , "y" : "bar" , "z" : True , "d" : 11.2 },
1133- ],
1134- arm_names = ["arm1" , "arm2" , "arm3" ],
1135- ttl_seconds = 3600 ,
1136- run_metadata = {"test_metadata_field" : 1 },
1137- )
1138-
1139- self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1140- self .assertEqual (
1141- len (set (self .experiment .trials [trial_index ].arms_by_name ) - {"status_quo" }),
1142- 3 ,
1143- )
1144- self .assertEqual (type (self .experiment .trials [trial_index ]), BatchTrial )
1145- self .assertEqual (
1146- {"arm1" , "arm2" , "arm3" },
1147- set (self .experiment .trials [trial_index ].arms_by_name ) - {"status_quo" },
1148- )
1149-
1150- def test_attach_single_arm_trial_no_arm_name (self ) -> None :
1105+ def _test_attach_trial (
1106+ self ,
1107+ parameterizations : list [dict [str , object ]],
1108+ arm_names : list [str ] | None ,
1109+ expected_type : type [BaseTrial ],
1110+ expected_names : set [str ] | str | None ,
1111+ ) -> None :
1112+ self .setUp ()
11511113 num_trials = len (self .experiment .trials )
1152-
1153- _ , trial_index = self .experiment .attach_trial (
1154- parameterizations = [{"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 }],
1155- ttl_seconds = 3600 ,
1156- run_metadata = {"test_metadata_field" : 1 },
1157- )
1114+ if arm_names is not None :
1115+ _ , trial_index = self .experiment .attach_trial (
1116+ parameterizations = parameterizations ,
1117+ arm_names = arm_names ,
1118+ ttl_seconds = 3600 ,
1119+ run_metadata = {"test_metadata_field" : 1 },
1120+ )
1121+ else :
1122+ _ , trial_index = self .experiment .attach_trial (
1123+ parameterizations = parameterizations ,
1124+ ttl_seconds = 3600 ,
1125+ run_metadata = {"test_metadata_field" : 1 },
1126+ )
11581127
11591128 self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1160- self .assertEqual (type (self .experiment .trials [trial_index ]), Trial )
1161-
1162- def test_attach_single_arm_trial_with_arm_name ( self ) -> None :
1163- num_trials = len ( self . experiment . trials )
1164-
1165- _ , trial_index = self . experiment . attach_trial (
1166- parameterizations = [{ "w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 }],
1167- arm_names = [ "arm1" ],
1168- ttl_seconds = 3600 ,
1169- run_metadata = { "test_metadata_field" : 1 } ,
1170- )
1129+ self .assertEqual (type (self .experiment .trials [trial_index ]), expected_type )
1130+ if isinstance ( expected_names , set ):
1131+ self . assertEqual (
1132+ expected_names ,
1133+ set ( self . experiment . trials [ trial_index ]. arms_by_name ) - { "status_quo" },
1134+ )
1135+ elif isinstance ( expected_names , str ):
1136+ self . assertEqual (
1137+ expected_names ,
1138+ self . experiment . trials [ trial_index ]. arm . name ,
1139+ )
11711140
1172- self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1173- self .assertEqual (type (self .experiment .trials [trial_index ]), Trial )
1174- self .assertEqual (
1175- "arm1" ,
1176- self .experiment .trials [trial_index ].arm .name ,
1177- )
1141+ def test_attach_trial (self ) -> None :
1142+ batch_params : list [dict [str , object ]] = [
1143+ {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1144+ {"w" : 5.2 , "x" : 5 , "y" : "foo" , "z" : True , "d" : 11.4 },
1145+ {"w" : 5.1 , "x" : 5 , "y" : "bar" , "z" : True , "d" : 11.2 },
1146+ ]
1147+ single_params : list [dict [str , object ]] = [
1148+ {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1149+ ]
1150+ for label , params , arm_names , expected_type , expected_names in [
1151+ ("batch_no_arm_names" , batch_params , None , BatchTrial , None ),
1152+ (
1153+ "batch_with_arm_names" ,
1154+ batch_params ,
1155+ ["arm1" , "arm2" , "arm3" ],
1156+ BatchTrial ,
1157+ {"arm1" , "arm2" , "arm3" },
1158+ ),
1159+ ("single_no_arm_name" , single_params , None , Trial , None ),
1160+ ("single_with_arm_name" , single_params , ["arm1" ], Trial , "arm1" ),
1161+ ]:
1162+ with self .subTest (label = label ):
1163+ self ._test_attach_trial (
1164+ parameterizations = params ,
1165+ arm_names = arm_names ,
1166+ expected_type = expected_type ,
1167+ expected_names = expected_names ,
1168+ )
11781169
11791170 def test_fetch_as_class (self ) -> None :
11801171 class MyMetric (Metric ):
@@ -2530,45 +2521,36 @@ def test_is_bope_problem(self) -> None:
25302521 )
25312522 self .assertFalse (experiment .is_bope_problem )
25322523
2533- def test_name_and_store_arm_if_not_exists_same_name_different_signature (
2534- self ,
2535- ) -> None :
2536- experiment = self .experiment
2537- shared_name = "shared_name"
2538-
2539- arm_1 = Arm ({"x1" : - 1.0 , "x2" : 1.0 }, name = shared_name )
2540- arm_2 = Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 })
2541- self .assertNotEqual (arm_1 .signature , arm_2 .signature )
2542-
2543- experiment ._register_arm (arm = arm_1 )
2544- with self .assertRaisesRegex (
2545- AxError ,
2546- f"Arm with name { shared_name } already exists on experiment "
2547- f"with different signature." ,
2548- ):
2549- experiment ._name_and_store_arm_if_not_exists (
2550- arm = arm_2 , proposed_name = shared_name
2551- )
2552-
2553- def test_name_and_store_arm_if_not_exists_same_proposed_name_different_signature (
2554- self ,
2555- ) -> None :
2556- experiment = self .experiment
2524+ def test_name_and_store_arm_if_not_exists_different_signature (self ) -> None :
25572525 shared_name = "shared_name"
2558-
25592526 arm_1 = Arm ({"x1" : - 1.0 , "x2" : 1.0 }, name = shared_name )
2560- arm_2 = Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 }, name = shared_name )
2561- self .assertNotEqual (arm_1 .signature , arm_2 .signature )
25622527
2563- experiment ._register_arm (arm = arm_1 )
2564- with self .assertRaisesRegex (
2565- AxError ,
2566- f"Arm with name { shared_name } already exists on experiment "
2567- f"with different signature." ,
2568- ):
2569- experiment ._name_and_store_arm_if_not_exists (
2570- arm = arm_2 , proposed_name = "different proposed name"
2571- )
2528+ cases = [
2529+ (
2530+ "arm_without_name" ,
2531+ Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 }),
2532+ shared_name ,
2533+ ),
2534+ (
2535+ "arm_with_shared_name" ,
2536+ Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 }, name = shared_name ),
2537+ "different proposed name" ,
2538+ ),
2539+ ]
2540+ for label , arm_2 , proposed_name in cases :
2541+ with self .subTest (label = label ):
2542+ self .setUp ()
2543+ experiment = self .experiment
2544+ experiment ._register_arm (arm = arm_1 )
2545+ self .assertNotEqual (arm_1 .signature , arm_2 .signature )
2546+ with self .assertRaisesRegex (
2547+ AxError ,
2548+ f"Arm with name { shared_name } already exists on experiment "
2549+ f"with different signature." ,
2550+ ):
2551+ experiment ._name_and_store_arm_if_not_exists (
2552+ arm = arm_2 , proposed_name = proposed_name
2553+ )
25722554
25732555 def test_sorting_data_by_trial_index_and_arm_name (self ) -> None :
25742556 # test sorting data
0 commit comments