Skip to content

Commit f01f0f9

Browse files
Sunny Shenfacebook-github-bot
authored andcommitted
Add SoftKNNClassifierModel (facebook#5072)
Summary: X-link: meta-pytorch/botorch#3243 Add a differentiable Soft K-Nearest Neighbors classifier model for failure-aware Bayesian optimization. Unlike tree-based classifiers (RF, XGBoost), SoftKNN is fully differentiable, enabling gradient-based acquisition function optimization. The model uses Gaussian kernel weights: P(y=1|x) = sum(w_i * y_i) / sum(w_i) where w_i = exp(-||x - x_i||^2 / (2 * sigma^2)) Implements construct_inputs classmethod for seamless Ax integration. Differential Revision: D90894389
1 parent 358c4c6 commit f01f0f9

2 files changed

Lines changed: 13 additions & 0 deletions

File tree

ax/generators/torch/botorch_modular/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from botorch.exceptions.errors import BotorchError, CandidateGenerationError
3838
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
3939
from botorch.models import PairwiseLaplaceMarginalLogLikelihood
40+
from botorch.models.classifier import SoftKNNClassifierModel
4041
from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP
4142
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
4243
from botorch.models.gp_regression import SingleTaskGP
@@ -808,6 +809,15 @@ def _fit_botorch_model_fully_bayesian_nuts(
808809
fit_fully_bayesian_model_nuts(model, **mll_options)
809810

810811

812+
@fit_botorch_model.register(SoftKNNClassifierModel)
813+
def _fit_botorch_model_classifier(
814+
model: SoftKNNClassifierModel,
815+
mll_class: type[MarginalLogLikelihood],
816+
mll_options: dict[str, Any] | None = None,
817+
) -> None:
818+
"""Classifier models fit themselves in __init__(), so no-op here."""
819+
820+
811821
@fit_botorch_model.register(object)
812822
def _fit_botorch_model_not_implemented(
813823
model: Model,

ax/storage/botorch_modular_registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
AnalyticExpectedUtilityOfBestOption,
6161
qExpectedUtilityOfBestOption,
6262
)
63+
from botorch.models.classifier import SoftKNNClassifierModel
6364
from botorch.models.contextual import LCEAGP
6465
from botorch.models.fully_bayesian import (
6566
FullyBayesianLinearSingleTaskGP,
@@ -145,6 +146,8 @@
145146
AdditiveMapSaasSingleTaskGP: "AdditiveMapSaasSingleTaskGP",
146147
EnsembleMapSaasSingleTaskGP: "EnsembleMapSaasSingleTaskGP",
147148
HeterogeneousMTGP: "HeterogeneousMTGP",
149+
# Classifier models
150+
SoftKNNClassifierModel: "SoftKNNClassifierModel",
148151
}
149152

150153

0 commit comments

Comments
 (0)