Skip to content

Commit 54fb97b

Browse files
committed
Add AutoGluonClassifier wrapper, configuration, and tests
Implemented AutoGluonClassifier in ml_grid/model_classes/AutoGluonClassifier.py, providing a scikit-learn compatible wrapper for AutoGluon's TabularPredictor. Added AutoGluonClassifierClass in ml_grid/model_classes/auto_gluon_classifier_class.py to define parameter spaces for grid search and Bayesian optimization. Created unit tests in tests/test_auto_gluon_classifier.py to verify initialization, fitting, prediction, and configuration structure. The wrapper includes handling for temporary directories, time limits with safety buffers, and exclusion of specific models (e.g., NN_TORCH) for stability.
1 parent 9db1795 commit 54fb97b

3 files changed

Lines changed: 444 additions & 0 deletions

File tree

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
"""AutoGluon Classifier Wrapper.
2+
3+
This module provides a scikit-learn compatible wrapper for AutoGluon's TabularPredictor.
4+
"""
5+
6+
import logging
import os
import shutil
import tempfile
import uuid
from typing import List, Optional

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted
17+
18+
# Attempt to import AutoGluon.
# NOTE(review): all three imports share one ``try`` — a failure of ANY of
# them (including the ml_grid import) disables the whole wrapper by leaving
# TabularPredictor as None. Confirm that is intended.
try:
    from autogluon.tabular import TabularPredictor
    from autogluon.core.utils.exceptions import TimeLimitExceeded
    from ml_grid.util.global_params import global_parameters
except ImportError:
    # Sentinel checked in fit(): allows a clear ImportError at call time
    # instead of a NameError at import time.
    TabularPredictor = None
    # Fall back to the stdlib TimeoutError so ``except TimeLimitExceeded``
    # clauses remain syntactically and semantically valid.
    TimeLimitExceeded = TimeoutError

    # Mock object to avoid errors if autogluon is not installed
    class MockGlobalParams:
        pass

    global_parameters = MockGlobalParams()

logger = logging.getLogger(__name__)
34+
35+
36+
class AutoGluonClassifier(BaseEstimator, ClassifierMixin):
    """A scikit-learn compatible wrapper for AutoGluon TabularPredictor.

    ``fit`` trains an AutoGluon ensemble and exposes the usual scikit-learn
    ``predict`` / ``predict_proba`` interface. Artifacts are written to a
    private temporary directory (removed when the wrapper is garbage
    collected) unless an explicit ``path`` is supplied.

    Args:
        time_limit: Total training budget in seconds. A safety buffer is
            subtracted internally so AutoGluon returns before any external
            timeout fires.
        presets: AutoGluon preset name (e.g. ``"medium_quality"``).
        eval_metric: Metric AutoGluon optimizes during training.
        problem_type: AutoGluon problem type; inferred by AutoGluon if None.
        seed: Seed forwarded to AutoGluon's HPO searcher for reproducibility.
        verbosity: AutoGluon verbosity level.
        path: Directory for AutoGluon artifacts; a temp dir is used if None.
        excluded_model_types: Model families AutoGluon must not train. The
            supplied list is never mutated by this class.
        hyperparameters: Explicit AutoGluon ``hyperparameters`` dict; takes
            precedence over ``presets``.
    """

    def __init__(
        self,
        time_limit: int = 120,
        presets: Optional[str] = None,
        eval_metric: str = "accuracy",
        problem_type: Optional[str] = None,
        seed: int = 42,
        verbosity: int = 2,
        path: Optional[str] = None,
        excluded_model_types: Optional[List[str]] = None,
        hyperparameters: Optional[dict] = None,
    ):
        # Per scikit-learn convention, __init__ only stores the parameters.
        self.time_limit = time_limit
        self.presets = presets
        self.eval_metric = eval_metric
        self.problem_type = problem_type
        self.seed = seed
        self.verbosity = verbosity
        self.path = path
        self.excluded_model_types = excluded_model_types
        self.hyperparameters = hyperparameters

        # Runtime state. NOTE: because ``classes_`` exists (as None) before
        # fit, ``check_is_fitted`` alone cannot detect an unfitted estimator;
        # predict/predict_proba therefore also verify ``predictor_`` below.
        self.predictor_ = None
        self.classes_ = None
        self._temp_dir = None
        self.model_id = None  # For compatibility with internal logging if needed
        self.timed_out_ = False

    @staticmethod
    def _as_frame(X) -> pd.DataFrame:
        """Coerce array-like input to a DataFrame with generated column names."""
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
            X.columns = [f"feature_{i}" for i in range(X.shape[1])]
        return X

    def _resolve_excluded_models(self) -> List[str]:
        """Return the model exclusion list WITHOUT mutating the user's list.

        Fix: the original aliased ``self.excluded_model_types`` and appended
        to it in place, mutating the caller's list as a side effect of fit().
        Copy unconditionally instead.
        """
        excluded = list(self.excluded_model_types) if self.excluded_model_types is not None else []
        # Check for FastAI and exclude if not installed to prevent ImportErrors.
        try:
            import fastai  # noqa: F401, E402
        except ImportError:
            if "FASTAI" not in excluded:
                excluded.append("FASTAI")
        # Exclude NeuralNetTorch (NN_TORCH) by default for stability in unit
        # tests, as it can be resource-intensive and prone to filesystem
        # errors with Ray's checkpointing.
        if "NN_TORCH" not in excluded:
            excluded.append("NN_TORCH")
        return excluded

    def _safe_time_limit(self) -> Optional[int]:
        """Apply a safety buffer so AutoGluon returns before external timeouts.

        AutoGluon attempts to stop training by the limit, but saving/cleanup
        adds overhead; reserve 10% with a floor of 15s and a ceiling of 60s.
        """
        limit = self.time_limit
        if limit and limit > 20:
            buffer = min(60, max(15, int(limit * 0.10)))
            safe = max(limit - buffer, 10)
            logger.info(
                f"Reduced AutoGluon time_limit from {limit}s to {safe}s to allow for overhead."
            )
            return safe
        return limit

    def _check_fitted(self) -> None:
        """Raise NotFittedError when fit() has not completed successfully."""
        check_is_fitted(self, "classes_")
        # ``classes_`` is pre-set to None in __init__, so the attribute check
        # above always passes; also verify a predictor was actually trained.
        if self.predictor_ is None:
            raise NotFittedError(
                "This AutoGluonClassifier instance is not fitted yet. "
                "Call 'fit' before using this estimator."
            )

    def fit(self, X: pd.DataFrame, y: pd.Series, **kwargs) -> "AutoGluonClassifier":
        """Train an AutoGluon TabularPredictor on (X, y).

        Args:
            X: Feature matrix (DataFrame or array-like).
            y: Target labels (Series or array-like).
            **kwargs: Extra keyword arguments forwarded to
                ``TabularPredictor.fit``.

        Returns:
            self, with ``predictor_``, ``classes_`` and ``model_id`` set.

        Raises:
            ImportError: If AutoGluon is not installed.
            RuntimeError: If AutoGluon trained no models at all.
        """
        if TabularPredictor is None:
            raise ImportError(
                "AutoGluon is not installed. Please install it to use AutoGluonClassifier."
            )

        # Validate inputs.
        X = self._as_frame(X)
        if not isinstance(y, pd.Series):
            y = pd.Series(y, name="target")
        if y.name is None:
            y.name = "target"
        label_column = y.name

        # Prepare training data: features plus the label column.
        train_data = X.copy()
        train_data[label_column] = y.values

        # Handle output path.
        if self.path is None:
            self._temp_dir = tempfile.mkdtemp(prefix="autogluon_")
            # AutoGluon warns if the directory exists. Since mkdtemp creates
            # it, remove it so AutoGluon can recreate it without warning.
            shutil.rmtree(self._temp_dir)
            model_path = self._temp_dir
        else:
            model_path = self.path

        excluded_models = self._resolve_excluded_models()

        # Initialize predictor.
        self.predictor_ = TabularPredictor(
            label=label_column,
            problem_type=self.problem_type,
            eval_metric=self.eval_metric,
            path=model_path,
            verbosity=self.verbosity,
        )

        # The seed for AutoGluon's HPO search is passed in
        # hyperparameter_tune_kwargs, ensuring reproducibility of the internal
        # model selection and tuning process.
        hyperparameter_tune_kwargs = {
            "searcher": "random",  # Default searcher
            "scheduler": "local",  # Default scheduler
            "searcher_options": {"seed": self.seed},
        }

        safe_time_limit = self._safe_time_limit()

        # Set up arguments for AutoGluon's fit method.
        fit_args = kwargs.copy()
        fit_args.update(
            {
                "time_limit": safe_time_limit,
                "hyperparameter_tune_kwargs": hyperparameter_tune_kwargs,
                "excluded_model_types": excluded_models,
                "dynamic_stacking": False,
            }
        )

        # Prioritize hyperparameters, then presets. If neither, use a fast default for tests.
        if self.hyperparameters:
            fit_args["hyperparameters"] = self.hyperparameters
        elif self.presets:
            fit_args["presets"] = self.presets
        else:
            logger.info(
                "No presets or hyperparameters specified. Using fast default for unit testing: {'GBM': {}}"
            )
            fit_args["hyperparameters"] = {"GBM": {}}

        # Log configuration to assist with debugging silent/long runs.
        logger.info(f"Starting AutoGluon fit. Path: {model_path}")
        logger.info(
            f"Time limit: {safe_time_limit}s (Effective: {self.time_limit}s)"
        )
        logger.info(f"Verbosity: {self.verbosity}")
        if fit_args.get("presets"):
            logger.info(f"Presets: {fit_args['presets']}")
        if fit_args.get("hyperparameters"):
            # Log keys only to avoid flooding logs if hyperparameters are large.
            logger.info(
                f"Hyperparameters keys: {list(fit_args['hyperparameters'].keys()) if isinstance(fit_args['hyperparameters'], dict) else 'custom'}"
            )

        # Mitigate nested parallelism when running inside a joblib worker:
        # the JOBLIB_SPAWNED_PROCESS env var marks a worker process, and
        # constraining num_cpus prevents resource over-subscription.
        if "JOBLIB_SPAWNED_PROCESS" in os.environ:
            logger.info(
                "Detected execution within a joblib worker. Constraining AutoGluon to use 1 CPU core."
            )
            if self.verbosity > 0:
                logger.warning(
                    "Running inside joblib worker. AutoGluon output may be captured/suppressed by the parent process."
                )
            fit_args["num_cpus"] = 1

        # Fit predictor; tolerate a time-limit overrun if partial models exist.
        try:
            self.predictor_.fit(train_data, **fit_args)
        except TimeLimitExceeded:
            self.timed_out_ = True
            logger.warning(
                "AutoGluon TimeLimitExceeded during fit. Checking if any models were trained..."
            )
            if self.predictor_.model_names():
                logger.info(
                    "At least one model was trained. Continuing with partial fit."
                )
            else:
                raise
        except Exception as e:
            logger.error(f"AutoGluon fit failed with error: {e}")
            raise

        # Check if any models were actually trained.
        if not self.predictor_.model_names():
            msg = "AutoGluon failed to train any models."
            logger.error(msg)
            raise RuntimeError(msg)

        self.classes_ = np.array(self.predictor_.class_labels)
        self.model_id = f"autogluon_{uuid.uuid4().hex}"

        return self

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        """Return predicted class labels for X."""
        self._check_fitted()
        X = self._as_frame(X)
        return self.predictor_.predict(X).values

    def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
        """Return class probabilities for X, columns ordered as ``classes_``."""
        self._check_fitted()
        X = self._as_frame(X)

        # AutoGluon returns a DataFrame with class labels as columns.
        probas_df = self.predictor_.predict_proba(X)

        # Ensure we return columns in the same order as self.classes_.
        if self.classes_ is not None:
            return probas_df[self.classes_].values
        return probas_df.values

    def __del__(self):
        # Best-effort cleanup of the temporary artifact directory.
        if self._temp_dir and os.path.exists(self._temp_dir):
            try:
                shutil.rmtree(self._temp_dir)
            except Exception:
                pass
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""AutoGluon Classifier Configuration.
2+
3+
This module contains the AutoGluonClassifierClass, which is a configuration
4+
class for the AutoGluonClassifier. It provides parameter spaces for
5+
grid search and Bayesian optimization.
6+
"""
7+
8+
import logging
9+
from typing import Any, Dict, List, Optional, Union
10+
11+
import pandas as pd
12+
from skopt.space import Categorical, Integer
13+
14+
from ml_grid.model_classes.AutoGluonClassifier import AutoGluonClassifier
15+
from ml_grid.util.global_params import global_parameters
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class AutoGluonClassifierClass:
    """Configuration class for AutoGluonClassifier.

    Selects one of three hyperparameter-space layouts based on the global
    configuration flags: a minimal space in test mode, skopt dimensions for
    Bayesian optimization, or a plain grid otherwise.
    """

    def __init__(
        self,
        X: Optional[pd.DataFrame] = None,
        y: Optional[pd.Series] = None,
        parameter_space_size: Optional[str] = None,
    ):
        self.X = X
        self.y = y
        self.algorithm_implementation = AutoGluonClassifier()
        self.method_name = "AutoGluonClassifier"

        self.parameter_space: Union[List[Dict[str, Any]], Dict[str, Any]]

        if getattr(global_parameters, "test_mode", False):
            # Minimal budget so unit tests finish quickly.
            self.parameter_space = [
                {"time_limit": [5], "presets": ["medium_quality"]}
            ]
        elif global_parameters.bayessearch:
            # skopt dimension objects for Bayesian search.
            self.parameter_space = {
                "time_limit": Integer(120, 240),
                "presets": Categorical(["medium_quality"]),
            }
        else:
            # Plain value lists for exhaustive grid search.
            self.parameter_space = [
                {"time_limit": [120, 180], "presets": ["medium_quality"]}
            ]

0 commit comments

Comments
 (0)