Skip to content

Commit b694978

Browse files
committed
feat: add FLAML classifier wrapper and configuration
This commit introduces support for FLAML (Fast and Lightweight AutoML) within the ml_grid framework. Key changes:

- Added `FLAMLClassifierWrapper`: a scikit-learn compatible wrapper for `flaml.AutoML`. It handles model fitting, prediction, and probability estimation, with robust error handling for missing dependencies and runtime exceptions.
- Added `FLAMLClassifierClass`: a configuration class that defines the parameter space (specifically `time_budget`) for both grid search and Bayesian optimization modes.
- Added `tests/test_flaml_classifier.py`: comprehensive unit tests covering initialization, fitting, prediction, and configuration logic, using mocks for external dependencies.

This enables FLAML to be used as a standard classifier in the existing grid search and hyperparameter optimization pipelines.
1 parent 54fb97b commit b694978

3 files changed

Lines changed: 296 additions & 0 deletions

File tree

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""FLAML Classifier Wrapper.
2+
3+
This module provides a scikit-learn compatible wrapper for FLAML's AutoML.
4+
"""
5+
6+
import logging
7+
from typing import Union, List
8+
9+
import numpy as np
10+
import pandas as pd
11+
from sklearn.base import BaseEstimator, ClassifierMixin
12+
from sklearn.utils.validation import check_is_fitted
13+
14+
# Attempt to import FLAML
15+
try:
16+
from flaml import AutoML
17+
except ImportError:
18+
AutoML = None
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class FLAMLClassifierWrapper(BaseEstimator, ClassifierMixin):
    """A scikit-learn compatible wrapper for FLAML's AutoML classifier.

    Exposes ``flaml.AutoML`` through the standard estimator interface
    (``fit`` / ``predict`` / ``predict_proba``) so it can be used inside
    sklearn pipelines, grid searches, and cross-validation loops.

    Parameters
    ----------
    time_budget : int, default=60
        Search time budget, in seconds.
    metric : str, default="auto"
        Optimization metric forwarded to FLAML.
    task : str, default="classification"
        FLAML task type.
    n_jobs : int, default=-1
        Number of parallel jobs used by FLAML.
    eval_method : str, default="auto"
        Resampling strategy forwarded to FLAML ("holdout", "cv", or "auto").
    split_ratio : float, default=0.2
        Holdout validation fraction (used with holdout evaluation).
    n_splits : int, default=5
        Number of CV folds (used with cross-validation evaluation).
    log_file_name : str, default="flaml.log"
        Path of FLAML's search log file.
    seed : int, default=42
        Random seed for reproducibility.
    verbose : int, default=0
        FLAML verbosity level.
    estimator_list : str or list of str, default="auto"
        Candidate estimator names FLAML may search over.
    """

    def __init__(
        self,
        time_budget: int = 60,
        metric: str = "auto",
        task: str = "classification",
        n_jobs: int = -1,
        eval_method: str = "auto",
        split_ratio: float = 0.2,
        n_splits: int = 5,
        log_file_name: str = "flaml.log",
        seed: int = 42,
        verbose: int = 0,
        estimator_list: Union[str, List[str]] = "auto",
    ):
        # sklearn convention: __init__ only records hyperparameters verbatim
        # (no validation, no work) so get_params()/set_params() round-trip.
        self.time_budget = time_budget
        self.metric = metric
        self.task = task
        self.n_jobs = n_jobs
        self.eval_method = eval_method
        self.split_ratio = split_ratio
        self.n_splits = n_splits
        self.log_file_name = log_file_name
        self.seed = seed
        self.verbose = verbose
        self.estimator_list = estimator_list

        # Fitted AutoML instance; remains None until fit() is called.
        self.model_ = None

    def fit(
        self,
        X: Union[np.ndarray, pd.DataFrame],
        y: Union[np.ndarray, pd.Series],
        **kwargs,
    ) -> "FLAMLClassifierWrapper":
        """Run the FLAML AutoML search on the training data.

        Parameters
        ----------
        X : array-like or DataFrame
            Training features.
        y : array-like or Series
            Training labels.
        **kwargs
            Extra keyword arguments forwarded to ``AutoML.fit``.

        Returns
        -------
        FLAMLClassifierWrapper
            The fitted wrapper (``self``).

        Raises
        ------
        ImportError
            If the optional ``flaml`` dependency is not installed.
        RuntimeError
            If the FLAML search fails, or finds no usable model within
            the given time budget.
        """
        if AutoML is None:
            raise ImportError(
                "FLAML is not installed. Please install it to use FLAMLClassifierWrapper."
            )

        self.model_ = AutoML()

        try:
            self.model_.fit(
                X_train=X,
                y_train=y,
                time_budget=self.time_budget,
                metric=self.metric,
                task=self.task,
                n_jobs=self.n_jobs,
                eval_method=self.eval_method,
                split_ratio=self.split_ratio,
                n_splits=self.n_splits,
                log_file_name=self.log_file_name,
                seed=self.seed,
                verbose=self.verbose,
                estimator_list=self.estimator_list,
                **kwargs,
            )
        except StopIteration:
            # FLAML can raise StopIteration internally when used within
            # scikit-learn's cross-validation framework. Swallow it so it
            # does not crash the joblib parallel backend; the model is
            # still fitted at this point.
            logger.debug(
                "Caught StopIteration from FLAML, which is expected in some CV scenarios."
            )
        except Exception as e:
            # Catch any other errors during fit (e.g. AttributeError from
            # FLAML's runner) and surface them as a RuntimeError, keeping
            # the original exception chained for debugging.
            logger.error(f"FLAML fit failed: {e}")
            raise RuntimeError(f"FLAML fit failed: {e}") from e

        # After fitting, check if a model was actually found. This is crucial
        # because if the time_budget is too short, FLAML may not find any
        # valid model; fail loudly rather than predict garbage later.
        if self.model_.best_estimator is None:
            msg = (
                "FLAML failed to find a usable model within the given time_budget. "
                "This may be due to a time limit that is too short, or very complex data."
            )
            logger.error(msg)
            raise RuntimeError(msg)

        if hasattr(self.model_, "classes_"):
            self.classes_ = self.model_.classes_
        else:
            # If fit ended early (e.g. via the StopIteration path) before
            # FLAML set classes_, infer them from y for sklearn compatibility.
            # np.asarray handles ndarray, Series, and DataFrame inputs alike.
            self.classes_ = np.unique(np.asarray(y))
        return self

    def _fallback_predictions(self, X) -> np.ndarray:
        """Return a constant prediction of ``classes_[0]`` for every row.

        NOTE(review): this is the first (lowest-sorted) class label, not
        necessarily the true majority class, despite the log wording.
        """
        return np.full(len(X), self.classes_[0], dtype=self.classes_.dtype)

    def _uniform_probabilities(self, X) -> np.ndarray:
        """Return an (n_samples, n_classes) matrix of uniform probabilities."""
        n_classes = len(self.classes_)
        return np.full((len(X), n_classes), 1 / n_classes)

    def predict(self, X: Union[np.ndarray, pd.DataFrame]) -> np.ndarray:
        """Predict class labels; fall back to a constant prediction on failure.

        Parameters
        ----------
        X : array-like or DataFrame
            Samples to classify.

        Returns
        -------
        np.ndarray
            Predicted class labels, one per row of X.
        """
        check_is_fitted(self, ["model_", "classes_"])
        try:
            predictions = self.model_.predict(X)
        except Exception as e:
            logger.error(f"FLAML predict failed: {e}. Returning dummy predictions.")
            return self._fallback_predictions(X)
        if predictions is None:
            logger.warning(
                "FLAML predict() returned None. Returning dummy predictions (majority class)."
            )
            return self._fallback_predictions(X)
        return predictions

    def predict_proba(self, X: Union[np.ndarray, pd.DataFrame]) -> np.ndarray:
        """Predict class probabilities; fall back to uniform on failure.

        Parameters
        ----------
        X : array-like or DataFrame
            Samples to classify.

        Returns
        -------
        np.ndarray
            Array of shape (n_samples, n_classes) of class probabilities.
        """
        check_is_fitted(self, ["model_", "classes_"])
        try:
            probas = self.model_.predict_proba(X)
        except Exception as e:
            logger.error(
                f"FLAML predict_proba failed: {e}. Returning dummy probabilities."
            )
            return self._uniform_probabilities(X)
        if probas is None:
            logger.warning(
                "FLAML predict_proba() returned None. Returning dummy probabilities."
            )
            return self._uniform_probabilities(X)
        return probas
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""FLAML Classifier Configuration.
2+
3+
This module contains the FLAMLClassifierClass, which is a configuration
4+
class for the FLAMLClassifierWrapper.
5+
"""
6+
7+
import logging
8+
from typing import Any, Dict, List, Optional, Union
9+
10+
import pandas as pd
11+
from skopt.space import Integer
12+
13+
from ml_grid.model_classes.FLAMLClassifierWrapper import FLAMLClassifierWrapper
14+
from ml_grid.util.global_params import global_parameters
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class FLAMLClassifierClass:
    """Configuration class for FLAMLClassifierWrapper.

    Bundles the estimator instance, its display name, and the
    hyperparameter search space in the shape expected by the ml_grid
    search pipelines.
    """

    def __init__(
        self,
        X: Optional[pd.DataFrame] = None,
        y: Optional[pd.Series] = None,
        parameter_space_size: Optional[str] = None,
    ):
        """Build the configuration.

        Args:
            X: Optional feature frame (stored for interface parity).
            y: Optional target series (stored for interface parity).
            parameter_space_size: Accepted for interface parity; not used
                by this configuration.
        """
        self.X = X
        self.y = y
        self.algorithm_implementation = FLAMLClassifierWrapper()
        self.method_name = "FLAMLClassifier"

        self.parameter_space: Union[List[Dict[str, Any]], Dict[str, Any]]

        if global_parameters.bayessearch:
            # Bayesian optimisation expects a single dict of skopt dimensions.
            search_space: Dict[str, Any] = {"time_budget": Integer(1, 5)}
            self.parameter_space = search_space
        else:
            # Grid search expects a list of dicts mapping names to value lists.
            self.parameter_space = [{"time_budget": [1, 2]}]

tests/test_flaml_classifier.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
import pandas as pd
4+
import numpy as np
5+
6+
from ml_grid.model_classes.FLAMLClassifierWrapper import FLAMLClassifierWrapper
7+
from ml_grid.model_classes.flaml_classifier_class import FLAMLClassifierClass
8+
9+
10+
class TestFLAMLClassifier(unittest.TestCase):
    """Unit tests for FLAMLClassifierWrapper with FLAML mocked out."""

    def setUp(self):
        # Small, deterministic toy dataset shared across the tests.
        features = {
            "feature_0": [1.0, 2.0, 3.0, 4.0],
            "feature_1": [4.0, 3.0, 2.0, 1.0],
        }
        self.X = pd.DataFrame(features)
        self.y = pd.Series([0, 1, 0, 1], name="target")

    def test_init(self):
        # The constructor should record hyperparameters without fitting.
        wrapper = FLAMLClassifierWrapper(time_budget=120, metric="roc_auc")
        self.assertEqual(wrapper.time_budget, 120)
        self.assertEqual(wrapper.metric, "roc_auc")
        self.assertIsNone(wrapper.model_)

    @patch("ml_grid.model_classes.FLAMLClassifierWrapper.AutoML")
    def test_fit(self, automl_factory):
        # Replace the AutoML class with a factory returning a mock instance.
        automl = MagicMock()
        automl_factory.return_value = automl

        wrapper = FLAMLClassifierWrapper(time_budget=60)
        wrapper.fit(self.X, self.y)

        # Exactly one AutoML instance should be created ...
        automl_factory.assert_called_once()
        # ... and fitted exactly once, with our hyperparameters forwarded.
        automl.fit.assert_called_once()
        _, fit_kwargs = automl.fit.call_args
        self.assertEqual(fit_kwargs["time_budget"], 60)
        self.assertEqual(fit_kwargs["task"], "classification")

        # The fitted model must be retained on the wrapper.
        self.assertIsNotNone(wrapper.model_)

    @patch("ml_grid.model_classes.FLAMLClassifierWrapper.AutoML")
    def test_predict(self, automl_factory):
        automl = MagicMock()
        automl_factory.return_value = automl
        automl.predict.return_value = np.array([0, 1, 0, 1])

        wrapper = FLAMLClassifierWrapper()
        wrapper.fit(self.X, self.y)
        predictions = wrapper.predict(self.X)

        self.assertIsInstance(predictions, np.ndarray)
        np.testing.assert_array_equal(predictions, np.array([0, 1, 0, 1]))
        # The wrapper must delegate to the underlying model exactly once.
        automl.predict.assert_called_once_with(self.X)

    def test_missing_flaml(self):
        # With AutoML patched to None (simulating a missing install),
        # fit() must raise ImportError rather than fail obscurely.
        with patch("ml_grid.model_classes.FLAMLClassifierWrapper.AutoML", None):
            wrapper = FLAMLClassifierWrapper()
            with self.assertRaises(ImportError):
                wrapper.fit(self.X, self.y)
class TestFLAMLClassifierClass(unittest.TestCase):
    """Tests for the FLAMLClassifierClass configuration object."""

    def test_structure(self):
        # Patch global_parameters so the test controls the bayessearch flag.
        target = "ml_grid.model_classes.flaml_classifier_class.global_parameters"
        with patch(target) as fake_globals:
            # Grid-search mode: the parameter space is a list of dicts.
            fake_globals.bayessearch = False
            config = FLAMLClassifierClass()
            self.assertEqual(config.method_name, "FLAMLClassifier")
            self.assertIsInstance(
                config.algorithm_implementation, FLAMLClassifierWrapper
            )
            self.assertIsInstance(config.parameter_space, list)

            # Bayesian mode: the parameter space is a dict of skopt dimensions.
            fake_globals.bayessearch = True
            config = FLAMLClassifierClass()
            self.assertIsInstance(config.parameter_space, dict)
            self.assertIn("time_budget", config.parameter_space)
# Allow running this test module directly (e.g. `python tests/test_flaml_classifier.py`).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)