Skip to content

Commit 2b1ffb6

Browse files
committed
feat: Integrate AutoML libraries and manage availability
This commit introduces support for several popular AutoML libraries into the ml_grid framework: AutoGluon, TPOT, FLAML, and AutoKeras.

Key changes:
- The model class list now dynamically checks whether each of these libraries is available and enables the corresponding classifier only if its library is installed.
- Added the new classifier classes to the central `MODEL_CLASS_MAP` for robust lookup.
- Updated the default model dictionary to include the new AutoML models, with their inclusion controlled by their availability.
- Updated the CI configuration to disable these resource-intensive models during automated test runs.

Additionally, the deprecated `ShapeDTW` classifier has been removed from the time-series model list to align with recent changes in the `aeon` library.
1 parent 9fe5b4b commit 2b1ffb6

2 files changed

Lines changed: 51 additions & 3 deletions

File tree

ml_grid/pipeline/model_class_list.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111

1212
# Import all model classes to make them available for eval()
1313
from ml_grid.model_classes.adaboost_classifier_class import AdaBoostClassifierClass
14+
from ml_grid.model_classes.auto_gluon_classifier_class import AutoGluonClassifierClass
1415
from ml_grid.model_classes.catboost_classifier_class import CatBoostClassifierClass
16+
from ml_grid.model_classes.tpot_classifier_class import TPOTClassifierClass
17+
from ml_grid.model_classes.flaml_classifier_class import FLAMLClassifierClass
18+
from ml_grid.model_classes.auto_keras_classifier_class import AutoKerasClassifierClass
1519
from ml_grid.model_classes.gaussiannb_class import (
1620
GaussianNBClassifierClass,
1721
)
@@ -57,7 +61,6 @@
5761
from ml_grid.model_classes.xgb_classifier_class import XGBClassifierClass
5862
from ml_grid.model_classes.tabpfn_classifier_class import TabPFNClassifierClass
5963

60-
6164
# --- ROBUST MAPPING of config names to class objects ---
6265
# This dictionary provides a direct, secure, and explicit mapping from the
6366
# string names used in the YAML config files to the actual imported Python classes.
@@ -81,6 +84,10 @@
8184
"SVCClass": SVCClass,
8285
"NeuralNetworkClassifier_class": NeuralNetworkClassifier_class, # Corrected mapping
8386
"TabPFNClassifierClass": TabPFNClassifierClass,
87+
"AutoGluonClassifierClass": AutoGluonClassifierClass,
88+
"TPOTClassifierClass": TPOTClassifierClass,
89+
"FLAMLClassifierClass": FLAMLClassifierClass,
90+
"AutoKerasClassifierClass": AutoKerasClassifierClass,
8491
# GPU specific
8592
"KerasClassifierClass": KerasClassifierClass,
8693
# "KNNGpuWrapperClass": KNNGpuWrapperClass, #deprecated by python 3.12 and simsig dependency
@@ -132,6 +139,38 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
132139
if parameter_space_size is None:
133140
parameter_space_size = "small"
134141

142+
# Check for AutoGluon availability
143+
try:
144+
import autogluon.tabular # noqa: F401
145+
146+
autogluon_available = True
147+
except ImportError:
148+
autogluon_available = False
149+
150+
# Check for TPOT availability
151+
try:
152+
import tpot # noqa: F401
153+
154+
tpot_available = True
155+
except ImportError:
156+
tpot_available = False
157+
158+
# Check for FLAML availability
159+
try:
160+
import flaml # noqa: F401
161+
162+
flaml_available = True
163+
except ImportError:
164+
flaml_available = False
165+
166+
# Check for AutoKeras availability
167+
try:
168+
import autokeras # noqa: F401
169+
170+
autokeras_available = True
171+
except ImportError:
172+
autokeras_available = False
173+
135174
model_class_dict: Optional[Dict[str, bool]] = ml_grid_object.model_class_dict
136175

137176
if model_class_dict is None:
@@ -166,6 +205,10 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
166205
"H2O_StackedEnsemble_class": True, # H2O Stacked Ensemble
167206
"H2O_GAM_class": True, # H2O Generalized Additive Models
168207
"TabPFNClassifierClass": False, # requires hf token and agreement
208+
"AutoGluonClassifierClass": autogluon_available,
209+
"TPOTClassifierClass": tpot_available,
210+
"FLAMLClassifierClass": flaml_available,
211+
"AutoKerasClassifierClass": autokeras_available,
169212
}
170213

171214
# If running in a CI environment, explicitly disable resource-intensive models
@@ -188,6 +231,10 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
188231
"H2O_GAM_class",
189232
"TabTransformerClass",
190233
"TabPFNClassifierClass",
234+
"AutoGluonClassifierClass",
235+
"TPOTClassifierClass",
236+
"FLAMLClassifierClass",
237+
"AutoKerasClassifierClass",
191238
]
192239
for model_name in models_to_disable:
193240
if model_name in model_class_dict:

ml_grid/pipeline/model_class_list_ts.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
from ml_grid.model_classes_time_series.rocketClassifier_module import (
4343
RocketClassifier_class,
4444
)
45-
from ml_grid.model_classes_time_series.shapeDTWClassifier_module import ShapeDTW_class
45+
46+
# from ml_grid.model_classes_time_series.shapeDTWClassifier_module import ShapeDTW_class # deprecated
4647
from ml_grid.model_classes_time_series.SignatureClassifier_module import (
4748
SignatureClassifier_class,
4849
)
@@ -92,7 +93,7 @@ def get_model_class_list_ts(ml_grid_object: pipe) -> List[Any]:
9293
OrdinalTDE_class(ml_grid_object),
9394
ResNetClassifier_class(ml_grid_object),
9495
RocketClassifier_class(ml_grid_object),
95-
ShapeDTW_class(ml_grid_object),
96+
# ShapeDTW_class(ml_grid_object), # deprecated in newer aeon versions
9697
SignatureClassifier_class(ml_grid_object),
9798
SummaryClassifier_class(ml_grid_object),
9899
TapNetClassifier_class(ml_grid_object),

0 commit comments

Comments
 (0)