Skip to content

Commit 2aec2e6

Browse files
Sunny Shen and facebook-github-bot
authored and committed
Add SoftKNNClassifierModel (meta-pytorch#3243)
Summary: X-link: facebook/Ax#5072 Add a differentiable Soft K-Nearest Neighbors classifier model for failure-aware Bayesian optimization. Unlike tree-based classifiers (RF, XGBoost), SoftKNN is fully differentiable, enabling gradient-based acquisition function optimization. The model uses Gaussian kernel weights: P(y=1|x) = sum(w_i * y_i) / sum(w_i) where w_i = exp(-||x - x_i||^2 / (2 * sigma^2)) Implements construct_inputs classmethod for seamless Ax integration. Differential Revision: D90894389
1 parent 9a296b6 commit 2aec2e6

4 files changed

Lines changed: 383 additions & 0 deletions

File tree

botorch/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ApproximateGPyTorchModel,
99
SingleTaskVariationalGP,
1010
)
11+
from botorch.models.classifier import SoftKNNClassifierModel
1112
from botorch.models.cost import AffineFidelityCostModel
1213
from botorch.models.deterministic import (
1314
AffineDeterministicModel,
@@ -52,4 +53,5 @@
5253
"SingleTaskGP",
5354
"SingleTaskMultiFidelityGP",
5455
"SingleTaskVariationalGP",
56+
"SoftKNNClassifierModel",
5557
]

botorch/models/classifier.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Classifier-based models for constraint boundaries and deterministic feasibility.
9+
10+
These models wrap classifiers as BoTorch deterministic models,
11+
enabling them to be used for modeling binary constraints, feasibility, and other
12+
discontinuous outputs where traditional GP models fail due to smoothness assumptions.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
from typing import Any
18+
19+
import torch
20+
from botorch.models.deterministic import GenericDeterministicModel
21+
from botorch.models.transforms.input import InputTransform
22+
from botorch.utils.datasets import SupervisedDataset
23+
from torch import Tensor
24+
25+
26+
class SoftKNNClassifierModel(GenericDeterministicModel):
    """Soft K-Nearest Neighbors classifier wrapped as a BoTorch deterministic model.

    The model predicts soft class probabilities via Gaussian-kernel-weighted
    voting over the training points:

        P(y=1 | x) = sum_i(w_i * y_i) / sum_i(w_i),
        w_i = exp(-||x - x_i||^2 / (2 * sigma^2))

    Because the prediction is a smooth function of ``x``, the model is fully
    differentiable (unlike tree-based classifiers), which enables gradient-based
    acquisition-function optimization. Supports both a fixed scalar bandwidth
    ``sigma`` and a learnable per-dimension bandwidth trained via leave-one-out
    (LOO) cross-validation.

    Example:
        >>> from botorch.models.classifier import SoftKNNClassifierModel
        >>> from botorch.utils.datasets import SupervisedDataset
        >>> import torch
        >>>
        >>> X = torch.randn(100, 5)
        >>> y = torch.randint(0, 2, (100, 1), dtype=torch.float64)
        >>> dataset = SupervisedDataset(X=X, Y=y)
        >>>
        >>> # Fixed sigma
        >>> model_inputs = SoftKNNClassifierModel.construct_inputs(
        ...     training_data=dataset,
        ...     sigma=0.3,
        ... )
        >>> model = SoftKNNClassifierModel(**model_inputs)
        >>>
        >>> # Learnable per-dimension sigma
        >>> model_inputs = SoftKNNClassifierModel.construct_inputs(
        ...     training_data=dataset,
        ...     learnable_sigma=True,
        ...     sigma_epochs=100,
        ... )
        >>> model = SoftKNNClassifierModel(**model_inputs)
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        sigma: float = 0.1,
        learnable_sigma: bool = False,
        sigma_lr: float = 0.1,
        sigma_epochs: int = 100,
        input_transform: InputTransform | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize SoftKNNClassifierModel.

        Args:
            train_X: Training features tensor of shape (n, d).
            train_Y: Training labels tensor of shape (n,) or (n, 1), binary (0 or 1).
            sigma: Initial Gaussian kernel bandwidth (default: 0.1).
            learnable_sigma: If True, learn per-dimension sigma via LOO
                cross-validation (default: False).
            sigma_lr: Learning rate for sigma optimization (default: 0.1).
            sigma_epochs: Training epochs for sigma (default: 100).
            input_transform: Optional InputTransform applied to both training
                and test inputs before distance computation.
            **kwargs: Additional arguments (ignored).
        """
        # Ensure train_Y is 1D.
        train_Y = train_Y.view(-1)

        # Transform the training data up front so that it lives in the same
        # space as test inputs, which are transformed via
        # Model.transform_inputs in posterior().
        train_X_t = train_X if input_transform is None else input_transform(train_X)

        # Learn a per-dimension bandwidth, or keep the fixed scalar.
        learned_sigma_tensor: Tensor | None = None
        sigma_final: Tensor | float
        if learnable_sigma:
            sigma_final = self._fit_sigma_loo(
                train_X_t=train_X_t,
                train_Y=train_Y,
                sigma=sigma,
                sigma_lr=sigma_lr,
                sigma_epochs=sigma_epochs,
            )
            learned_sigma_tensor = sigma_final
        else:
            sigma_final = sigma

        # Boolean class-1 indicator, computed once; the prediction closure
        # converts it to the query dtype/device per call.
        class1_labels = train_Y == 1.0  # [n], bool

        def predict_proba_fn(X: Tensor) -> Tensor:
            """Return P(y=1 | x) with output shape (*batch, 1)."""
            batch_shape = X.shape[:-1]
            # X has already been transformed via Model.transform_inputs
            # if an input_transform is registered.
            X_flat = X.reshape(-1, X.shape[-1])

            diffs = X_flat.unsqueeze(1) - train_X_t.to(X_flat).unsqueeze(0)

            if isinstance(sigma_final, Tensor):
                # Per-dimension bandwidth (learned).
                dists = torch.sum((diffs**2) / (sigma_final.to(X_flat) ** 2), dim=2)
                weights = torch.exp(-dists / 2)
            else:
                # Scalar bandwidth (fixed).
                dists = torch.sum(diffs**2, dim=2)
                weights = torch.exp(-dists / (2 * sigma_final**2))

            weighted_class1 = torch.matmul(
                weights, class1_labels.to(dtype=X_flat.dtype, device=X_flat.device)
            )
            total_weights = torch.sum(weights, dim=1)
            # Epsilon guards against division by zero when all kernel weights
            # underflow (query far from every training point).
            probs_flat = weighted_class1 / (total_weights + 1e-12)

            return probs_flat.reshape(*batch_shape, 1)

        # Initialize the parent deterministic model with the prediction closure.
        super().__init__(f=predict_proba_fn, num_outputs=1)

        # Register input_transform as a submodule so posterior() applies it.
        if input_transform is not None:
            self.input_transform = input_transform

        # Expose the learned sigma (None when using a fixed scalar) for
        # inspection.
        self.learned_sigma = learned_sigma_tensor

    @staticmethod
    def _fit_sigma_loo(
        train_X_t: Tensor,
        train_Y: Tensor,
        sigma: float,
        sigma_lr: float,
        sigma_epochs: int,
    ) -> Tensor:
        """Learn a per-dimension kernel bandwidth by minimizing the
        leave-one-out binary cross-entropy on the training data.

        Args:
            train_X_t: (Transformed) training features of shape (n, d).
            train_Y: 1D binary training labels of shape (n,).
            sigma: Initial bandwidth used for all dimensions.
            sigma_lr: Adam learning rate.
            sigma_epochs: Number of optimization steps.

        Returns:
            The learned sigma vector of shape (d,), detached.
        """
        n, d = train_X_t.shape[-2], train_X_t.shape[-1]

        # Optimize in log-space so sigma stays positive. Fill with sigma and
        # take the log afterwards: torch.full documents fill_value as a
        # scalar, so passing a tensor is not portable across PyTorch versions.
        log_sigma = torch.nn.Parameter(
            torch.log(
                torch.full(
                    (d,), sigma, dtype=train_X_t.dtype, device=train_X_t.device
                )
            )
        )
        optimizer = torch.optim.Adam([log_sigma], lr=sigma_lr, foreach=True)

        train_Y_float = train_Y.to(dtype=train_X_t.dtype)

        # Loop-invariant quantities, hoisted out of the training loop:
        # squared pairwise coordinate differences [n, n, d], the LOO mask that
        # zeroes each point's self-weight, and the class-1 indicator.
        sq_diffs = (train_X_t.unsqueeze(1) - train_X_t.unsqueeze(0)) ** 2
        loo_mask = ~torch.eye(n, dtype=torch.bool, device=train_X_t.device)
        class1 = (train_Y_float == 1.0).to(dtype=train_X_t.dtype)

        eps = 1e-7
        for _ in range(sigma_epochs):
            optimizer.zero_grad()
            sigma_vec = log_sigma.exp()  # [d]

            # Scaled squared distances: sum_j (x_i - x_k)_j^2 / sigma_j^2.
            dists = torch.sum(sq_diffs / (sigma_vec**2), dim=2)  # [n, n]

            # LOO: the mask excludes each point's own (diagonal) weight.
            weights = torch.exp(-dists / 2) * loo_mask

            weighted_class1 = torch.sum(weights * class1, dim=1)
            total_weights = torch.sum(weights, dim=1)
            prob_class1 = weighted_class1 / (total_weights + 1e-12)

            # Binary cross-entropy loss, clamped away from {0, 1} for
            # numerical stability of the logs.
            prob_class1_clamped = prob_class1.clamp(eps, 1 - eps)
            loss = -torch.mean(
                train_Y_float * torch.log(prob_class1_clamped)
                + (1 - train_Y_float) * torch.log(1 - prob_class1_clamped)
            )
            loss.backward()
            optimizer.step()

        # Detach the learned sigma for inference.
        return log_sigma.exp().detach()

    @classmethod
    def construct_inputs(
        cls,
        training_data: SupervisedDataset,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Construct inputs for SoftKNNClassifierModel from training data.

        This method extracts training data and parameters that will be passed
        to __init__, where the input_transform is applied and the prediction
        closure is created. This ensures compatibility with Ax's model bridge,
        which adds input_transform after calling construct_inputs.

        Args:
            training_data: SupervisedDataset with X (features) and Y (labels).
            sigma: Initial Gaussian kernel bandwidth (default: 0.1).
            learnable_sigma: If True, learn per-dimension sigma via LOO
                cross-validation (default: False).
            sigma_lr: Learning rate for sigma optimization (default: 0.1).
            sigma_epochs: Training epochs for sigma (default: 100).
            input_transform: Optional InputTransform applied to both training
                and test inputs before distance computation.

        Returns:
            Dictionary with training data and model parameters.
        """
        return {
            "train_X": training_data.X.detach().clone(),
            "train_Y": training_data.Y.detach().clone(),
            "sigma": kwargs.get("sigma", 0.1),
            "learnable_sigma": kwargs.get("learnable_sigma", False),
            "sigma_lr": kwargs.get("sigma_lr", 0.1),
            "sigma_epochs": kwargs.get("sigma_epochs", 100),
            "input_transform": kwargs.get("input_transform", None),
        }

sphinx/source/models.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ Additive GP Models
3939
.. automodule:: botorch.models.additive_gp
4040
:members:
4141

42+
Classifier Models
43+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
44+
.. automodule:: botorch.models.classifier
45+
:members:
46+
4247
Cost Models (for cost-aware optimization)
4348
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4449
.. automodule:: botorch.models.cost

0 commit comments

Comments
 (0)