Skip to content

Commit cd7b5e9

Browse files
committed
minor changes
1 parent 5610828 commit cd7b5e9

1 file changed

Lines changed: 14 additions & 1 deletion

File tree

ml_grid/pipeline/model_class_list.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
4949
Returns:
5050
List[Any]: A list of instantiated model class objects.
5151
"""
52+
# Get the parameter space size, defaulting to 'small' if not provided.
53+
# This prevents errors when the key is missing from the configuration.
5254
parameter_space_size = ml_grid_object.local_param_dict.get("param_space_size")
55+
if parameter_space_size is None:
56+
parameter_space_size = "small"
57+
5358
model_class_dict: Optional[Dict[str, bool]] = ml_grid_object.model_class_dict
5459

5560
if model_class_dict is None:
@@ -80,7 +85,15 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
8085

8186
for class_name, include in model_class_dict.items():
8287
if include:
83-
model_class = eval(class_name)
88+
# Try the exact name first, then try with '_class' appended for convenience
89+
try:
90+
model_class = eval(class_name)
91+
except NameError:
92+
class_name_with_suffix = f"{class_name}_class"
93+
try:
94+
model_class = eval(class_name_with_suffix)
95+
except NameError:
96+
raise NameError(f"Could not find model class '{class_name}' or '{class_name_with_suffix}'. Please check the name and ensure it's imported.")
8497
model_instance = model_class(
8598
X=ml_grid_object.X_train,
8699
y=ml_grid_object.y_train,

0 commit comments

Comments
 (0)