File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -49,7 +49,12 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
4949 Returns:
5050 List[Any]: A list of instantiated model class objects.
5151 """
52+ # Get the parameter space size, defaulting to 'small' if not provided.
53+ # This prevents errors when the key is missing from the configuration.
5254 parameter_space_size = ml_grid_object .local_param_dict .get ("param_space_size" )
55+ if parameter_space_size is None :
56+ parameter_space_size = "small"
57+
5358 model_class_dict : Optional [Dict [str , bool ]] = ml_grid_object .model_class_dict
5459
5560 if model_class_dict is None :
@@ -80,7 +85,15 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
8085
8186 for class_name , include in model_class_dict .items ():
8287 if include :
83- model_class = eval (class_name )
88+ # Try the exact name first, then try with '_class' appended for convenience
89+ try :
90+ model_class = eval (class_name )
91+ except NameError :
92+ class_name_with_suffix = f"{ class_name } _class"
93+ try :
94+ model_class = eval (class_name_with_suffix )
95+ except NameError :
96+ raise NameError (f"Could not find model class '{ class_name } ' or '{ class_name_with_suffix } '. Please check the name and ensure it's imported." )
8497 model_instance = model_class (
8598 X = ml_grid_object .X_train ,
8699 y = ml_grid_object .y_train ,
You can’t perform that action at this time.
0 commit comments