@@ -1121,79 +1121,74 @@ def test_immutable_search_space_and_opt_config(self) -> None:
11211121 )
11221122 self .assertTrue (immutable_exp_2 .immutable_search_space_and_opt_config )
11231123
1124- def test_attach_batch_trial_no_arm_names (self ) -> None :
1125- num_trials = len (self .experiment .trials )
1126-
1127- _ , trial_index = self .experiment .attach_trial (
1128- parameterizations = [
1129- {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1130- {"w" : 5.2 , "x" : 5 , "y" : "foo" , "z" : True , "d" : 11.4 },
1131- {"w" : 5.1 , "x" : 5 , "y" : "bar" , "z" : True , "d" : 11.2 },
1132- ],
1133- ttl_seconds = 3600 ,
1134- run_metadata = {"test_metadata_field" : 1 },
1135- )
1136-
1137- self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1138- self .assertEqual (
1139- len (set (self .experiment .trials [trial_index ].arms_by_name ) - {"status_quo" }),
1140- 3 ,
1141- )
1142- self .assertEqual (type (self .experiment .trials [trial_index ]), BatchTrial )
1143-
1144- def test_attach_batch_trial_with_arm_names (self ) -> None :
1145- num_trials = len (self .experiment .trials )
1146-
1147- _ , trial_index = self .experiment .attach_trial (
1148- parameterizations = [
1149- {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1150- {"w" : 5.2 , "x" : 5 , "y" : "foo" , "z" : True , "d" : 11.4 },
1151- {"w" : 5.1 , "x" : 5 , "y" : "bar" , "z" : True , "d" : 11.2 },
1152- ],
1153- arm_names = ["arm1" , "arm2" , "arm3" ],
1154- ttl_seconds = 3600 ,
1155- run_metadata = {"test_metadata_field" : 1 },
1156- )
1157-
1158- self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1159- self .assertEqual (
1160- len (set (self .experiment .trials [trial_index ].arms_by_name ) - {"status_quo" }),
1161- 3 ,
1162- )
1163- self .assertEqual (type (self .experiment .trials [trial_index ]), BatchTrial )
1164- self .assertEqual (
1165- {"arm1" , "arm2" , "arm3" },
1166- set (self .experiment .trials [trial_index ].arms_by_name ) - {"status_quo" },
1167- )
1168-
1169- def test_attach_single_arm_trial_no_arm_name (self ) -> None :
1124+ def _test_attach_trial (
1125+ self ,
1126+ parameterizations : list [dict [str , object ]],
1127+ arm_names : list [str ] | None ,
1128+ expected_type : type [BaseTrial ],
1129+ expected_names : set [str ] | str | None ,
1130+ ) -> None :
1131+ self .setUp ()
11701132 num_trials = len (self .experiment .trials )
1171-
1172- _ , trial_index = self .experiment .attach_trial (
1173- parameterizations = [{"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 }],
1174- ttl_seconds = 3600 ,
1175- run_metadata = {"test_metadata_field" : 1 },
1176- )
1133+ if arm_names is not None :
1134+ _ , trial_index = self .experiment .attach_trial (
1135+ parameterizations = parameterizations ,
1136+ arm_names = arm_names ,
1137+ ttl_seconds = 3600 ,
1138+ run_metadata = {"test_metadata_field" : 1 },
1139+ )
1140+ else :
1141+ _ , trial_index = self .experiment .attach_trial (
1142+ parameterizations = parameterizations ,
1143+ ttl_seconds = 3600 ,
1144+ run_metadata = {"test_metadata_field" : 1 },
1145+ )
11771146
11781147 self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1179- self .assertEqual (type (self .experiment .trials [trial_index ]), Trial )
1180-
1181- def test_attach_single_arm_trial_with_arm_name ( self ) -> None :
1182- num_trials = len ( self . experiment . trials )
1183-
1184- _ , trial_index = self . experiment . attach_trial (
1185- parameterizations = [{ "w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 }],
1186- arm_names = [ "arm1" ],
1187- ttl_seconds = 3600 ,
1188- run_metadata = { "test_metadata_field" : 1 } ,
1189- )
1148+ self .assertEqual (type (self .experiment .trials [trial_index ]), expected_type )
1149+ if isinstance ( expected_names , set ):
1150+ self . assertEqual (
1151+ expected_names ,
1152+ set ( self . experiment . trials [ trial_index ]. arms_by_name ) - { "status_quo" },
1153+ )
1154+ elif isinstance ( expected_names , str ):
1155+ self . assertEqual (
1156+ expected_names ,
1157+ self . experiment . trials [ trial_index ]. arm . name ,
1158+ )
11901159
1191- self .assertEqual (len (self .experiment .trials ), num_trials + 1 )
1192- self .assertEqual (type (self .experiment .trials [trial_index ]), Trial )
1193- self .assertEqual (
1194- "arm1" ,
1195- self .experiment .trials [trial_index ].arm .name ,
1196- )
1160+ def test_attach_trial (self ) -> None :
1161+ batch_params : list [dict [str , object ]] = [
1162+ {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1163+ {"w" : 5.2 , "x" : 5 , "y" : "foo" , "z" : True , "d" : 11.4 },
1164+ {"w" : 5.1 , "x" : 5 , "y" : "bar" , "z" : True , "d" : 11.2 },
1165+ ]
1166+ single_params : list [dict [str , object ]] = [
1167+ {"w" : 5.3 , "x" : 5 , "y" : "baz" , "z" : True , "d" : 11.6 },
1168+ ]
1169+ for label , params , arm_names , expected_type , expected_names in [
1170+ # Batch trial without explicit arm names -> auto-generated names
1171+ ("batch_no_arm_names" , batch_params , None , BatchTrial , None ),
1172+ # Batch trial with explicit arm names -> names preserved
1173+ (
1174+ "batch_with_arm_names" ,
1175+ batch_params ,
1176+ ["arm1" , "arm2" , "arm3" ],
1177+ BatchTrial ,
1178+ {"arm1" , "arm2" , "arm3" },
1179+ ),
1180+ # Single-arm trial without arm name -> creates Trial (not BatchTrial)
1181+ ("single_no_arm_name" , single_params , None , Trial , None ),
1182+ # Single-arm trial with explicit arm name -> name preserved
1183+ ("single_with_arm_name" , single_params , ["arm1" ], Trial , "arm1" ),
1184+ ]:
1185+ with self .subTest (label = label ):
1186+ self ._test_attach_trial (
1187+ parameterizations = params ,
1188+ arm_names = arm_names ,
1189+ expected_type = expected_type ,
1190+ expected_names = expected_names ,
1191+ )
11971192
11981193 def test_fetch_as_class (self ) -> None :
11991194 class MyMetric (Metric ):
@@ -2593,45 +2588,38 @@ def test_is_bope_problem(self) -> None:
25932588 )
25942589 self .assertFalse (experiment .is_bope_problem )
25952590
2596- def test_name_and_store_arm_if_not_exists_same_name_different_signature (
2597- self ,
2598- ) -> None :
2599- experiment = self .experiment
2600- shared_name = "shared_name"
2601-
2602- arm_1 = Arm ({"x1" : - 1.0 , "x2" : 1.0 }, name = shared_name )
2603- arm_2 = Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 })
2604- self .assertNotEqual (arm_1 .signature , arm_2 .signature )
2605-
2606- experiment ._register_arm (arm = arm_1 )
2607- with self .assertRaisesRegex (
2608- AxError ,
2609- f"Arm with name { shared_name } already exists on experiment "
2610- f"with different signature." ,
2611- ):
2612- experiment ._name_and_store_arm_if_not_exists (
2613- arm = arm_2 , proposed_name = shared_name
2614- )
2615-
2616- def test_name_and_store_arm_if_not_exists_same_proposed_name_different_signature (
2617- self ,
2618- ) -> None :
2619- experiment = self .experiment
2591+ def test_name_and_store_arm_if_not_exists_different_signature (self ) -> None :
26202592 shared_name = "shared_name"
2621-
26222593 arm_1 = Arm ({"x1" : - 1.0 , "x2" : 1.0 }, name = shared_name )
2623- arm_2 = Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 }, name = shared_name )
2624- self .assertNotEqual (arm_1 .signature , arm_2 .signature )
26252594
2626- experiment ._register_arm (arm = arm_1 )
2627- with self .assertRaisesRegex (
2628- AxError ,
2629- f"Arm with name { shared_name } already exists on experiment "
2630- f"with different signature." ,
2631- ):
2632- experiment ._name_and_store_arm_if_not_exists (
2633- arm = arm_2 , proposed_name = "different proposed name"
2634- )
2595+ cases = [
2596+ # Unnamed arm with proposed_name matching existing arm's name
2597+ (
2598+ "arm_without_name" ,
2599+ Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 }),
2600+ shared_name ,
2601+ ),
2602+ # Named arm matching existing arm but with a different proposed_name
2603+ (
2604+ "arm_with_shared_name" ,
2605+ Arm ({"x1" : - 1.7 , "x2" : 0.2 , "x3" : 1 }, name = shared_name ),
2606+ "different proposed name" ,
2607+ ),
2608+ ]
2609+ for label , arm_2 , proposed_name in cases :
2610+ with self .subTest (label = label ):
2611+ self .setUp ()
2612+ experiment = self .experiment
2613+ experiment ._register_arm (arm = arm_1 )
2614+ self .assertNotEqual (arm_1 .signature , arm_2 .signature )
2615+ with self .assertRaisesRegex (
2616+ AxError ,
2617+ f"Arm with name { shared_name } already exists on experiment "
2618+ f"with different signature." ,
2619+ ):
2620+ experiment ._name_and_store_arm_if_not_exists (
2621+ arm = arm_2 , proposed_name = proposed_name
2622+ )
26352623
26362624 def test_sorting_data_by_trial_index_and_arm_name (self ) -> None :
26372625 # test sorting data
0 commit comments