Skip to content

Commit 04096f0

Browse files
committed
Add Pytorch backend
1 parent d9910b1 commit 04096f0

2 files changed

Lines changed: 144 additions & 4 deletions

File tree

test/test_cav_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ def test_all(self):
278278
cav_trainer = CavTrainer(tpcav_model, penalty="l2")
279279
cav_trainer.set_control(builder.control_concepts[0], num_samples=100)
280280

281+
cav_trainer.train_concepts(
282+
builder.concepts, 100, output_dir="data/cavs/", num_processes=1, backend='torch'
283+
)
281284
cav_trainer.train_concepts(
282285
builder.concepts, 100, output_dir="data/cavs/", num_processes=2
283286
)

tpcav/cavs.py

Lines changed: 141 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pandas as pd
1717
import seaborn as sns
1818
import torch
19+
from copy import deepcopy
1920
from scipy import stats
2021
from sklearn.linear_model import SGDClassifier
2122
from sklearn.metrics import precision_recall_fscore_support
@@ -35,6 +36,125 @@
3536
logger = logging.getLogger(__name__)
3637

3738

39+
class _TorchLinear(torch.nn.Module):
40+
"""Torch linear layer classifier"""
41+
42+
def __init__(self, input_dim, num_class=1, device='cuda:0'):
43+
super().__init__()
44+
self.linear = torch.nn.Linear(input_dim, num_class)
45+
self.device = device
46+
47+
def forward(self, avs):
48+
return self.linear(avs).squeeze(-1)
49+
50+
def shuffle_tensor(self, *tensors):
51+
p = torch.randperm(len(tensors[0]))
52+
53+
new_tensors = [t[p] for t in tensors]
54+
55+
return new_tensors
56+
57+
def fit(self, train_avs: np.ndarray, train_ls: np.ndarray,
58+
val_avs: np.ndarray, val_ls: np.ndarray,
59+
patience=10, lr=1e-2, weight_decay=1e-2, max_epochs=1000):
60+
train_avs = torch.from_numpy(train_avs)
61+
train_ls = torch.from_numpy(train_ls)
62+
val_avs = torch.from_numpy(val_avs)
63+
val_ls = torch.from_numpy(val_ls)
64+
65+
optimizer = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
66+
67+
best_loss = None
68+
best_state_dict = None
69+
epoch = 0
70+
t = 0
71+
while True:
72+
epoch += 1
73+
if epoch > max_epochs: break
74+
75+
#shuffle the tensors before every epoch
76+
train_avs, train_ls = self.shuffle_tensor(train_avs, train_ls)
77+
78+
self.train()
79+
for i in range(0, len(train_avs), 32):
80+
optimizer.zero_grad()
81+
82+
avs = train_avs[i: (i+32)]
83+
l = train_ls[i: (i+32)]
84+
85+
y_hat = self(avs.to(self.device))
86+
loss = torch.mean(torch.clamp(1 - l.to(self.device) * y_hat, min=0))
87+
88+
loss.backward()
89+
optimizer.step()
90+
91+
logger.debug(f"Training loss at epoch {epoch}: {loss}")
92+
93+
self.eval()
94+
val_loss_all = []
95+
for i in range(0, len(val_avs), 32):
96+
avs = val_avs[i: (i+32)]
97+
l = val_ls[i: (i+32)]
98+
99+
y_hat = self(avs.to(self.device))
100+
val_loss = torch.mean(torch.clamp(1 - l.to(self.device) * y_hat, min=0))
101+
102+
val_loss_all.append(val_loss.item())
103+
104+
val_loss_all = np.mean(val_loss_all)
105+
106+
#logger.debug(f"Validation loss at epoch {epoch}: {val_loss_all}, patience {t} out of {patience}")
107+
108+
if (best_loss is None) or (val_loss_all < best_loss):
109+
best_loss = val_loss_all
110+
best_state_dict = deepcopy(self.state_dict())
111+
else:
112+
t += 1
113+
if t >= patience: break
114+
115+
return best_state_dict, best_loss
116+
117+
class _TorchLinearWrapper:
118+
def __init__(self, input_dim, num_class=1, lr=1e-2, weight_decay_search = [1e-2, 1e-4, 1e-6], device="cuda:0"):
119+
super().__init__()
120+
self.input_dim = input_dim
121+
self.num_class = num_class
122+
self.lr = lr
123+
self.weight_decay_search = weight_decay_search
124+
self.device = device
125+
126+
def fit(self, train_val_avs: np.ndarray, train_val_ls: np.ndarray):
127+
128+
train_avs, val_avs, train_ls, val_ls = train_test_split(train_val_avs, train_val_ls, test_size=0.1)
129+
130+
best_state_dict = None; best_loss = None
131+
for w in self.weight_decay_search:
132+
model = _TorchLinear(self.input_dim, self.num_class).to(self.device)
133+
state_dict, loss = model.fit(train_avs, train_ls, val_avs, val_ls, lr=self.lr, weight_decay=w)
134+
if (best_loss is None) or (loss < best_loss):
135+
best_loss = loss
136+
best_state_dict = state_dict
137+
138+
self.best_model = _TorchLinear(self.input_dim, self.num_class)
139+
self.best_model.load_state_dict(best_state_dict)
140+
self.best_model.to(self.device)
141+
142+
def predict(self, avs: np.ndarray):
143+
y_hat = self.best_model(torch.from_numpy(avs).to(self.device))
144+
145+
y_hat[y_hat>=0] = 1
146+
y_hat[y_hat<0] = -1
147+
148+
return y_hat.detach().cpu().numpy()
149+
150+
@property
151+
def weights(self):
152+
linear_weight = self.best_model.linear.weight.detach().cpu()[0]
153+
return torch.stack([-1 * linear_weight, linear_weight])
154+
155+
@property
156+
def classes_(self):
157+
return self.num_class
38158

39159
class _SGDWrapper:
40160
"""Lightweight SGD concept classifier."""
@@ -111,19 +231,30 @@ def _train(
111231
control_embeddings: torch.Tensor,
112232
output_dir: str,
113233
penalty: str = "l2",
234+
backend: str = "sklearn",
114235
) -> Tuple[float, torch.Tensor]:
115236
"""
116237
Train a binary CAV classifier for a concept vs cached control embeddings.
117238
118239
Requires set_control to have been called beforehand.
119240
"""
241+
assert backend in ["sklearn", "torch"]
242+
120243
output_dir = Path(output_dir)
121244

122245
train_avs, train_l, test_avs, test_l = prepare_xy(concept_embeddings, control_embeddings)
123246

124-
clf = _SGDWrapper(penalty=penalty)
125-
clf.fit(train_avs, train_l)
247+
if backend == "sklearn":
248+
clf = _SGDWrapper(penalty=penalty)
249+
else:
250+
# Replace label 0 with -1 to accommodate the hinge loss
251+
train_l[train_l==0] = -1
252+
test_l[test_l==0] = -1
126253

254+
clf = _TorchLinearWrapper(input_dim= train_avs.shape[1])
255+
clf.fit(train_avs, train_l)
256+
257+
#breakpoint()
127258
def _eval(avs, l, name: str):
128259

129260
y_preds = clf.predict(avs)
@@ -204,7 +335,8 @@ def train_concepts(
204335
num_samples: int,
205336
output_dir: str,
206337
num_processes: int = 1,
207-
max_pending: int = 8
338+
max_pending: int = 8,
339+
backend='sklearn',
208340
):
209341
"Train concepts with a fixed control set by self.set_control()"
210342
if self.control_embeddings is None:
@@ -224,6 +356,7 @@ def train_concepts(
224356
self.control_embeddings.cpu(),
225357
Path(output_dir) / c.name,
226358
self.penalty,
359+
backend=backend
227360
)
228361
self.cav_fscores[c.name] = fscore
229362
self.cav_weights[c.name] = weight
@@ -255,6 +388,7 @@ def train_concepts(
255388
self.control_embeddings,
256389
Path(output_dir) / c.name,
257390
self.penalty,
391+
backend=backend
258392
)
259393
logger.info("Submitted CAV training for concept %s", c.name)
260394
futures.append((c.name, future))
@@ -270,7 +404,8 @@ def train_concepts_pairs(self,
270404
num_samples: int,
271405
output_dir: str,
272406
num_processes: int = 1,
273-
max_pending: int = 8):
407+
max_pending: int = 8,
408+
backend='sklearn'):
274409
"""Train concept pairs (test concept, control concept)
275410
276411
Note: It would compute embeddings on every control concept, use self.train_concepts if control concept is fixed
@@ -289,6 +424,7 @@ def train_concepts_pairs(self,
289424
control_embeddings.cpu(),
290425
Path(output_dir) / c_test.name,
291426
self.penalty,
427+
backend=backend
292428
)
293429
self.cav_fscores[c_test.name] = fscore
294430
self.cav_weights[c_test.name] = weight
@@ -322,6 +458,7 @@ def train_concepts_pairs(self,
322458
control_embeddings.cpu(),
323459
Path(output_dir) / c_test.name,
324460
self.penalty,
461+
backend=backend
325462
)
326463
logger.info("Submitted CAV training for concept %s", c_test.name)
327464
futures.append((c_test.name, future))

0 commit comments

Comments
 (0)