Skip to content

Commit 9db1795

Browse files
committed
Add TPOT classifier wrapper, configuration, and tests
- Implement `TPOTClassifierWrapper` in `ml_grid/model_classes/TPOTClassifierWrapper.py` to wrap `TPOTClassifier` for scikit-learn compatibility, handling fitting, prediction, and probability estimation. - Create `TPOTClassifierClass` in `ml_grid/model_classes/tpot_classifier_class.py` to define hyperparameter spaces for grid search and Bayesian optimization (using `skopt`). - Add unit tests in `tests/test_tpot_classifier.py` to verify initialization, fitting, prediction, and configuration structure. - Handle optional TPOT import to avoid hard dependency failures.
1 parent 604fc80 commit 9db1795

3 files changed

Lines changed: 275 additions & 0 deletions

File tree

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""TPOT Classifier Wrapper.
2+
3+
This module provides a scikit-learn compatible wrapper for TPOTClassifier.
4+
"""
5+
6+
import logging
7+
from typing import Optional, Union
8+
9+
import numpy as np
10+
import pandas as pd
11+
from sklearn.base import BaseEstimator, ClassifierMixin
12+
from sklearn.utils.validation import check_is_fitted
13+
14+
# Attempt to import TPOT
15+
try:
16+
from tpot import TPOTClassifier
17+
except ImportError:
18+
TPOTClassifier = None
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class TPOTClassifierWrapper(BaseEstimator, ClassifierMixin):
24+
"""A scikit-learn compatible wrapper for TPOTClassifier."""
25+
26+
def __init__(
27+
self,
28+
generations: int = 5,
29+
population_size: int = 20,
30+
offspring_size: Optional[int] = None,
31+
mutation_rate: float = 0.9,
32+
crossover_rate: float = 0.1,
33+
scoring: str = "accuracy",
34+
cv: int = 5,
35+
subsample: float = 1.0,
36+
n_jobs: int = -1,
37+
max_time_mins: Optional[int] = None,
38+
max_eval_time_mins: float = 5,
39+
random_state: int = 42,
40+
verbosity: int = 2,
41+
early_stop: Optional[int] = None,
42+
):
43+
self.generations = generations
44+
self.population_size = population_size
45+
self.offspring_size = offspring_size
46+
self.mutation_rate = mutation_rate
47+
self.crossover_rate = crossover_rate
48+
self.scoring = scoring
49+
self.cv = cv
50+
self.subsample = subsample
51+
self.n_jobs = n_jobs
52+
self.max_time_mins = max_time_mins
53+
self.max_eval_time_mins = max_eval_time_mins
54+
self.random_state = random_state
55+
self.verbosity = verbosity
56+
self.early_stop = early_stop
57+
58+
self.model_ = None
59+
60+
def fit(
61+
self,
62+
X: Union[np.ndarray, pd.DataFrame],
63+
y: Union[np.ndarray, pd.Series],
64+
**kwargs,
65+
) -> "TPOTClassifierWrapper":
66+
if TPOTClassifier is None:
67+
raise ImportError(
68+
"TPOT is not installed. Please install it to use TPOTClassifierWrapper."
69+
)
70+
71+
self.model_ = TPOTClassifier(
72+
generations=self.generations,
73+
population_size=self.population_size,
74+
offspring_size=self.offspring_size,
75+
mutation_rate=self.mutation_rate,
76+
crossover_rate=self.crossover_rate,
77+
scoring=self.scoring,
78+
cv=self.cv,
79+
subsample=self.subsample,
80+
n_jobs=self.n_jobs,
81+
max_time_mins=self.max_time_mins,
82+
max_eval_time_mins=self.max_eval_time_mins,
83+
random_state=self.random_state,
84+
verbosity=self.verbosity,
85+
early_stop=self.early_stop,
86+
disable_update_check=True,
87+
)
88+
89+
# TPOT can be slow. For quick checks, it's useful to see it has started.
90+
logger.info(
91+
f"Starting TPOT fit with generations={self.generations}, population_size={self.population_size}..."
92+
)
93+
94+
self.model_.fit(X, y, **kwargs)
95+
96+
# After fitting, TPOT stores the best pipeline in the `fitted_pipeline_` attribute.
97+
# We must set `classes_` for scikit-learn compatibility (e.g., for check_is_fitted).
98+
# While TPOT exposes `self.model_.classes_`, inferring from `y` is a more robust
99+
# fallback, consistent with other wrappers in this project.
100+
if hasattr(self.model_, "classes_"):
101+
self.classes_ = self.model_.classes_
102+
else:
103+
self.classes_ = np.unique(y)
104+
105+
logger.info("TPOT fit completed.")
106+
return self
107+
108+
def predict(self, X: Union[np.ndarray, pd.DataFrame]) -> np.ndarray:
109+
check_is_fitted(self, ["model_", "classes_"])
110+
return self.model_.predict(X)
111+
112+
def predict_proba(self, X: Union[np.ndarray, pd.DataFrame]) -> np.ndarray:
113+
check_is_fitted(self, ["model_", "classes_"])
114+
return self.model_.predict_proba(X)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""TPOT Classifier Configuration.
2+
3+
This module contains the TPOTClassifierClass, which is a configuration
4+
class for the TPOTClassifierWrapper. It provides parameter spaces for
5+
grid search and Bayesian optimization, with a focus on providing a fast
6+
default for unit testing.
7+
"""
8+
9+
import logging
10+
from typing import Any, Dict, List, Optional, Union
11+
12+
import pandas as pd
13+
from skopt.space import Categorical, Integer
14+
15+
from ml_grid.model_classes.TPOTClassifierWrapper import TPOTClassifierWrapper
16+
from ml_grid.util.global_params import global_parameters
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
class TPOTClassifierClass:
22+
"""Configuration class for TPOTClassifierWrapper."""
23+
24+
def __init__(
25+
self,
26+
X: Optional[pd.DataFrame] = None,
27+
y: Optional[pd.Series] = None,
28+
parameter_space_size: Optional[str] = None,
29+
):
30+
self.X = X
31+
self.y = y
32+
self.algorithm_implementation = TPOTClassifierWrapper()
33+
self.method_name = "TPOTClassifier"
34+
35+
self.parameter_space: Union[List[Dict[str, Any]], Dict[str, Any]]
36+
37+
if getattr(global_parameters, "test_mode", False):
38+
self.parameter_space = [
39+
{"generations": [2], "population_size": [5], "max_time_mins": [1]}
40+
]
41+
elif global_parameters.bayessearch:
42+
# A slightly larger space for Bayesian search, but still constrained
43+
self.parameter_space = {
44+
"generations": Integer(5, 100),
45+
"population_size": Integer(20, 100),
46+
"scoring": Categorical(
47+
["accuracy", "f1", "roc_auc", "precision", "recall"]
48+
),
49+
"max_time_mins": Integer(10, 120), # Time limit is crucial
50+
}
51+
else:
52+
# Expanded parameter space for grid search
53+
self.parameter_space = [
54+
{
55+
"generations": [5, 10, 20],
56+
"population_size": [20, 50, 100],
57+
"max_time_mins": [10, 30, 60],
58+
"scoring": ["accuracy", "f1", "roc_auc"],
59+
}
60+
]

tests/test_tpot_classifier.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
import pandas as pd
4+
import numpy as np
5+
6+
from ml_grid.model_classes.TPOTClassifierWrapper import TPOTClassifierWrapper
7+
from ml_grid.model_classes.tpot_classifier_class import TPOTClassifierClass
8+
9+
10+
class TestTPOTClassifier(unittest.TestCase):
11+
def setUp(self):
12+
self.X = pd.DataFrame(
13+
{"feature_0": [1.0, 2.0, 3.0, 4.0], "feature_1": [4.0, 3.0, 2.0, 1.0]}
14+
)
15+
self.y = pd.Series([0, 1, 0, 1], name="target")
16+
17+
def test_init(self):
18+
clf = TPOTClassifierWrapper(generations=10, population_size=50)
19+
self.assertEqual(clf.generations, 10)
20+
self.assertEqual(clf.population_size, 50)
21+
self.assertIsNone(clf.model_)
22+
23+
@patch("ml_grid.model_classes.TPOTClassifierWrapper.TPOTClassifier")
24+
def test_fit(self, mock_tpot_cls):
25+
# Setup mocks
26+
mock_tpot_instance = MagicMock()
27+
mock_tpot_cls.return_value = mock_tpot_instance
28+
29+
clf = TPOTClassifierWrapper(generations=5, population_size=20)
30+
31+
# Test fit
32+
clf.fit(self.X, self.y)
33+
34+
# Verify TPOTClassifier init
35+
mock_tpot_cls.assert_called_once()
36+
_, kwargs = mock_tpot_cls.call_args
37+
self.assertEqual(kwargs["generations"], 5)
38+
self.assertEqual(kwargs["population_size"], 20)
39+
self.assertEqual(kwargs["disable_update_check"], True)
40+
41+
# Verify fit call
42+
mock_tpot_instance.fit.assert_called_once_with(self.X, self.y)
43+
44+
# Verify attributes set
45+
self.assertIsNotNone(clf.model_)
46+
47+
@patch("ml_grid.model_classes.TPOTClassifierWrapper.TPOTClassifier")
48+
def test_predict(self, mock_tpot_cls):
49+
# Setup mock
50+
mock_tpot_instance = MagicMock()
51+
mock_tpot_cls.return_value = mock_tpot_instance
52+
53+
# Mock predict return
54+
mock_tpot_instance.predict.return_value = np.array([0, 1, 0, 1])
55+
56+
clf = TPOTClassifierWrapper()
57+
clf.fit(self.X, self.y)
58+
59+
preds = clf.predict(self.X)
60+
61+
self.assertIsInstance(preds, np.ndarray)
62+
np.testing.assert_array_equal(preds, np.array([0, 1, 0, 1]))
63+
64+
# Verify predict called on internal model
65+
mock_tpot_instance.predict.assert_called_once_with(self.X)
66+
67+
def test_missing_tpot(self):
68+
# Simulate missing tpot by patching TPOTClassifier to None
69+
with patch("ml_grid.model_classes.TPOTClassifierWrapper.TPOTClassifier", None):
70+
clf = TPOTClassifierWrapper()
71+
with self.assertRaises(ImportError):
72+
clf.fit(self.X, self.y)
73+
74+
75+
class TestTPOTClassifierClass(unittest.TestCase):
76+
def test_structure(self):
77+
# Mock global_parameters to control bayessearch flag
78+
with patch(
79+
"ml_grid.model_classes.tpot_classifier_class.global_parameters"
80+
) as mock_globals:
81+
# Case 1: Grid Search (bayessearch = False)
82+
mock_globals.bayessearch = False
83+
84+
config = TPOTClassifierClass()
85+
self.assertEqual(config.method_name, "TPOTClassifier")
86+
self.assertIsInstance(
87+
config.algorithm_implementation, TPOTClassifierWrapper
88+
)
89+
self.assertIsInstance(config.parameter_space, list)
90+
91+
# Case 2: Bayes Search (bayessearch = True)
92+
mock_globals.bayessearch = True
93+
94+
config = TPOTClassifierClass()
95+
self.assertIsInstance(config.parameter_space, dict)
96+
self.assertIn("generations", config.parameter_space)
97+
self.assertIn("population_size", config.parameter_space)
98+
99+
100+
if __name__ == "__main__":
101+
unittest.main()

0 commit comments

Comments
 (0)