@@ -1102,79 +1102,74 @@ def test_immutable_search_space_and_opt_config(self) -> None:
11021102 )
11031103 self .assertTrue (immutable_exp_2 .immutable_search_space_and_opt_config )
11041104
1105- def test_attach_batch_trial_no_arm_names (self ) -> None :
1106- num_trials = len (self .experiment .trials )
1107-
1108- _ , trial_index = self .experiment .attach_trial (
1109- parameterizations = [
1110- {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1111- {"w" : 5.2 , "x" : 5 , "y" : "foo" , "z" : True , "d" : 11.4 },
1112- {"w" : 5.1 , "x" : 5 , "y" : "bar" , "z" : True , "d" : 11.2 },
1113- ],
1114- ttl_seconds = 3600 ,
1115- run_metadata = {"test_metadata_field" : 1 },
1116- )
1117-
1118- self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1119- self .assertEqual (
1120- len (set (self .experiment .trials [trial_index ].arms_by_name ) - {"status_quo" }),
1121- 3 ,
1122- )
1123- self .assertEqual (type (self .experiment .trials [trial_index ]), BatchTrial )
1124-
1125- def test_attach_batch_trial_with_arm_names (self ) -> None :
1126- num_trials = len (self .experiment .trials )
1127-
1128- _ , trial_index = self .experiment .attach_trial (
1129- parameterizations = [
1130- {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1131- {"w" : 5.2 , "x" : 5 , "y" : "foo" , "z" : True , "d" : 11.4 },
1132- {"w" : 5.1 , "x" : 5 , "y" : "bar" , "z" : True , "d" : 11.2 },
1133- ],
1134- arm_names = ["arm1" , "arm2" , "arm3" ],
1135- ttl_seconds = 3600 ,
1136- run_metadata = {"test_metadata_field" : 1 },
1137- )
1138-
1139- self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1140- self .assertEqual (
1141- len (set (self .experiment .trials [trial_index ].arms_by_name ) - {"status_quo" }),
1142- 3 ,
1143- )
1144- self .assertEqual (type (self .experiment .trials [trial_index ]), BatchTrial )
1145- self .assertEqual (
1146- {"arm1" , "arm2" , "arm3" },
1147- set (self .experiment .trials [trial_index ].arms_by_name ) - {"status_quo" },
1148- )
1149-
1150- def test_attach_single_arm_trial_no_arm_name (self ) -> None :
1105+ def _test_attach_trial (
1106+ self ,
1107+ parameterizations : list [dict [str , object ]],
1108+ arm_names : list [str ] | None ,
1109+ expected_type : type [BaseTrial ],
1110+ expected_names : set [str ] | str | None ,
1111+ ) -> None :
1112+ self .setUp ()
11511113 num_trials = len (self .experiment .trials )
1152-
1153- _ , trial_index = self .experiment .attach_trial (
1154- parameterizations = [{"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 }],
1155- ttl_seconds = 3600 ,
1156- run_metadata = {"test_metadata_field" : 1 },
1157- )
1114+ if arm_names is not None :
1115+ _ , trial_index = self .experiment .attach_trial (
1116+ parameterizations = parameterizations ,
1117+ arm_names = arm_names ,
1118+ ttl_seconds = 3600 ,
1119+ run_metadata = {"test_metadata_field" : 1 },
1120+ )
1121+ else :
1122+ _ , trial_index = self .experiment .attach_trial (
1123+ parameterizations = parameterizations ,
1124+ ttl_seconds = 3600 ,
1125+ run_metadata = {"test_metadata_field" : 1 },
1126+ )
11581127
11591128 self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1160- self .assertEqual (type (self .experiment .trials [trial_index ]), Trial )
1161-
1162- def test_attach_single_arm_trial_with_arm_name ( self ) -> None :
1163- num_trials = len ( self . experiment . trials )
1164-
1165- _ , trial_index = self . experiment . attach_trial (
1166- parameterizations = [{ "w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 }],
1167- arm_names = [ "arm1" ],
1168- ttl_seconds = 3600 ,
1169- run_metadata = { "test_metadata_field" : 1 } ,
1170- )
1129+ self .assertEqual (type (self .experiment .trials [trial_index ]), expected_type )
1130+ if isinstance ( expected_names , set ):
1131+ self . assertEqual (
1132+ expected_names ,
1133+ set ( self . experiment . trials [ trial_index ]. arms_by_name ) - { "status_quo" },
1134+ )
1135+ elif isinstance ( expected_names , str ):
1136+ self . assertEqual (
1137+ expected_names ,
1138+ self . experiment . trials [ trial_index ]. arm . name ,
1139+ )
11711140
1172- self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1173- self .assertEqual (type (self .experiment .trials [trial_index ]), Trial )
1174- self .assertEqual (
1175- "arm1" ,
1176- self .experiment .trials [trial_index ].arm .name ,
1177- )
1141+ def test_attach_trial (self ) -> None :
1142+ batch_params : list [dict [str , object ]] = [
1143+ {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1144+ {"w" : 5.2 , "x" : 5 , "y" : "foo" , "z" : True , "d" : 11.4 },
1145+ {"w" : 5.1 , "x" : 5 , "y" : "bar" , "z" : True , "d" : 11.2 },
1146+ ]
1147+ single_params : list [dict [str , object ]] = [
1148+ {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1149+ ]
1150+ for label , params , arm_names , expected_type , expected_names in [
1151+ # Batch trial without explicit arm names -> auto-generated names
1152+ ("batch_no_arm_names" , batch_params , None , BatchTrial , None ),
1153+ # Batch trial with explicit arm names -> names preserved
1154+ (
1155+ "batch_with_arm_names" ,
1156+ batch_params ,
1157+ ["arm1" , "arm2" , "arm3" ],
1158+ BatchTrial ,
1159+ {"arm1" , "arm2" , "arm3" },
1160+ ),
1161+ # Single-arm trial without arm name -> creates Trial (not BatchTrial)
1162+ ("single_no_arm_name" , single_params , None , Trial , None ),
1163+ # Single-arm trial with explicit arm name -> name preserved
1164+ ("single_with_arm_name" , single_params , ["arm1" ], Trial , "arm1" ),
1165+ ]:
1166+ with self .subTest (label = label ):
1167+ self ._test_attach_trial (
1168+ parameterizations = params ,
1169+ arm_names = arm_names ,
1170+ expected_type = expected_type ,
1171+ expected_names = expected_names ,
1172+ )
11781173
11791174 def test_fetch_as_class (self ) -> None :
11801175 class MyMetric (Metric ):
@@ -2530,45 +2525,38 @@ def test_is_bope_problem(self) -> None:
25302525 )
25312526 self .assertFalse (experiment .is_bope_problem )
25322527
2533- def test_name_and_store_arm_if_not_exists_same_name_different_signature (
2534- self ,
2535- ) -> None :
2536- experiment = self .experiment
2537- shared_name = "shared_name"
2538-
2539- arm_1 = Arm ({"x1" : - 1.0 , "x2" : 1.0 }, name = shared_name )
2540- arm_2 = Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 })
2541- self .assertNotEqual (arm_1 .signature , arm_2 .signature )
2542-
2543- experiment ._register_arm (arm = arm_1 )
2544- with self .assertRaisesRegex (
2545- AxError ,
2546- f"Arm with name { shared_name } already exists on experiment "
2547- f"with different signature." ,
2548- ):
2549- experiment ._name_and_store_arm_if_not_exists (
2550- arm = arm_2 , proposed_name = shared_name
2551- )
2552-
2553- def test_name_and_store_arm_if_not_exists_same_proposed_name_different_signature (
2554- self ,
2555- ) -> None :
2556- experiment = self .experiment
2528+ def test_name_and_store_arm_if_not_exists_different_signature (self ) -> None :
25572529 shared_name = "shared_name"
2558-
25592530 arm_1 = Arm ({"x1" : - 1.0 , "x2" : 1.0 }, name = shared_name )
2560- arm_2 = Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 }, name = shared_name )
2561- self .assertNotEqual (arm_1 .signature , arm_2 .signature )
25622531
2563- experiment ._register_arm (arm = arm_1 )
2564- with self .assertRaisesRegex (
2565- AxError ,
2566- f"Arm with name { shared_name } already exists on experiment "
2567- f"with different signature." ,
2568- ):
2569- experiment ._name_and_store_arm_if_not_exists (
2570- arm = arm_2 , proposed_name = "different proposed name"
2571- )
2532+ cases = [
2533+ # Unnamed arm with proposed_name matching existing arm's name
2534+ (
2535+ "arm_without_name" ,
2536+ Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 }),
2537+ shared_name ,
2538+ ),
2539+ # Named arm matching existing arm but with a different proposed_name
2540+ (
2541+ "arm_with_shared_name" ,
2542+ Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 }, name = shared_name ),
2543+ "different proposed name" ,
2544+ ),
2545+ ]
2546+ for label , arm_2 , proposed_name in cases :
2547+ with self .subTest (label = label ):
2548+ self .setUp ()
2549+ experiment = self .experiment
2550+ experiment ._register_arm (arm = arm_1 )
2551+ self .assertNotEqual (arm_1 .signature , arm_2 .signature )
2552+ with self .assertRaisesRegex (
2553+ AxError ,
2554+ f"Arm with name { shared_name } already exists on experiment "
2555+ f"with different signature." ,
2556+ ):
2557+ experiment ._name_and_store_arm_if_not_exists (
2558+ arm = arm_2 , proposed_name = proposed_name
2559+ )
25722560
25732561 def test_sorting_data_by_trial_index_and_arm_name (self ) -> None :
25742562 # test sorting data
0 commit comments