Skip to content

Support custom BaseSearchCV subclasses by accepting param_grid or param_distributions#8

Open
axif0 wants to merge 1 commit into
openml:mainfrom
axif0:warning
Open

Support custom BaseSearchCV subclasses by accepting param_grid or param_distributions#8
axif0 wants to merge 1 commit into
openml:mainfrom
axif0:warning

Conversation

@axif0
Copy link
Copy Markdown

@axif0 axif0 commented Nov 20, 2025

Removes the hard-coded check for only GridSearchCV and RandomizedSearchCV.
Now, any BaseSearchCV subclass is supported as long as it provides either param_distributions or param_grid.
The n_jobs safety check remains unchanged.

I tried this code, to reproduce the warning -

class CustomSuccessiveHalving(sklearn.model_selection._search.BaseSearchCV):
    def __init__(self, estimator, param_distributions):
        self.param_distributions = param_distributions
        self.estimator = estimator
        super().__init__(
            estimator=estimator,
            scoring=None,
            n_jobs=None,
            refit=True,
            cv=3,
            verbose=0,
            pre_dispatch='2*n_jobs',
            error_score='raise',
            return_train_score=False,
        )
    
    def _run_search(self, evaluate_candidates):
        from sklearn.model_selection import ParameterGrid
        param_grid = ParameterGrid(self.param_distributions)
        evaluate_candidates(param_grid)

if __name__ == "__main__":
    extension = SklearnExtension()
    
    base_estimator = sklearn.ensemble.RandomForestClassifier()
    param_distributions = {
        "max_depth": [3, 5],
        "n_estimators": [10, 20]
    }
    
    custom_hpo = CustomSuccessiveHalving(
        estimator=base_estimator,
        param_distributions=param_distributions
    )
    
    extension._prevent_optimize_n_jobs(custom_hpo)

Closes: #6

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Run BaseSearchCV instances which are not scikit-learn builtins

1 participant