Skip to content

Commit 2cd2a62

Browse files
committed
optimize data cleaning for training cavs
1 parent ed3fde0 commit 2cd2a62

2 files changed

Lines changed: 43 additions & 41 deletions

File tree

tpcav/cavs.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
from sklearn.model_selection import GridSearchCV
2424
from torch.utils.data import DataLoader, TensorDataset, random_split
2525
from sklearn.linear_model import LinearRegression
26+
from sklearn.model_selection import train_test_split
2627
import logomaker
28+
import multiprocessing as mp
2729

2830
from . import helper, utils, report
2931
from .concepts import ConceptBuilder
@@ -33,16 +35,6 @@
3335
logger = logging.getLogger(__name__)
3436

3537

36-
def _load_all_tensors_to_numpy(dataloaders: Iterable[DataLoader]):
37-
if not isinstance(dataloaders, list):
38-
dataloaders = [dataloaders]
39-
avs, ls = [], []
40-
for dataloader in dataloaders:
41-
for av, l in dataloader:
42-
avs.append(av.cpu().numpy())
43-
ls.append(l.cpu().numpy())
44-
return np.concatenate(avs), np.concatenate(ls)
45-
4638

4739
class _SGDWrapper:
4840
"""Lightweight SGD concept classifier."""
@@ -66,8 +58,7 @@ def __init__(self, penalty: str = "l2", n_jobs: int = -1):
6658
raise ValueError(f"Unexpected penalty type {penalty}")
6759
self.search = GridSearchCV(self.lm, params)
6860

69-
def fit(self, train_dl: DataLoader, val_dl: DataLoader):
70-
train_avs, train_ls = _load_all_tensors_to_numpy([train_dl, val_dl])
61+
def fit(self, train_avs: np.ndarray, train_ls: np.ndarray):
7162
self.search.fit(train_avs, train_ls)
7263
self.lm = self.search.best_estimator_
7364
logger.info(
@@ -89,6 +80,31 @@ def classes_(self):
8980
def predict(self, x: np.ndarray) -> np.ndarray:
9081
return self.lm.predict(x)
9182

83+
def prepare_xy(concept_embeddings, control_embeddings, seed=42):
84+
# defensively move tensors to CPU and convert to numpy (inputs may already be detached)
85+
concept = concept_embeddings.detach().cpu().numpy()
86+
control = control_embeddings.detach().cpu().numpy()
87+
88+
# labels
89+
y_concept = np.zeros(len(concept), dtype=np.int64)
90+
y_control = np.ones(len(control), dtype=np.int64)
91+
92+
# combine
93+
X = np.concatenate([concept, control], axis=0)
94+
y = np.concatenate([y_concept, y_control], axis=0)
95+
96+
# flatten if needed (SGDClassifier expects 2D)
97+
X = X.reshape(X.shape[0], -1)
98+
99+
# split: train vs test (90/10, stratified; no separate validation set)
100+
X_train, X_test, y_train, y_test = train_test_split(
101+
X, y,
102+
test_size=0.1,
103+
random_state=seed,
104+
stratify=y, # keeps class balance
105+
)
106+
107+
return X_train, y_train, X_test, y_test
92108

93109
def _train(
94110
concept_embeddings: torch.Tensor,
@@ -102,34 +118,19 @@ def _train(
102118
Requires set_control to have been called beforehand.
103119
"""
104120
output_dir = Path(output_dir)
121+
122+
train_avs, train_l, test_avs, test_l = prepare_xy(concept_embeddings, control_embeddings)
123+
124+
clf = _SGDWrapper(penalty=penalty)
125+
clf.fit(train_avs, train_l)
105126

106-
avd = TensorDataset(
107-
concept_embeddings, torch.full((concept_embeddings.shape[0],), 0)
108-
)
109-
cvd = TensorDataset(
110-
control_embeddings, torch.full((control_embeddings.shape[0],), 1)
111-
)
112-
train_ds, val_ds, test_ds = random_split(avd, [0.8, 0.1, 0.1])
113-
c_train, c_val, c_test = random_split(cvd, [0.8, 0.1, 0.1])
114-
115-
train_dl = DataLoader(train_ds + c_train, batch_size=32, shuffle=True)
116-
val_dl = DataLoader(val_ds + c_val, batch_size=32)
117-
test_dl = DataLoader(test_ds + c_test, batch_size=32)
127+
def _eval(avs, l, name: str):
128+
129+
y_preds = clf.predict(avs)
118130

119-
clf = _SGDWrapper(penalty=penalty)
120-
clf.fit(train_dl, val_dl)
121-
122-
def _eval(split_dl: DataLoader, name: str):
123-
y_preds, y_trues = [], []
124-
for x, y in split_dl:
125-
y_pred = clf.predict(x.cpu().numpy())
126-
y_preds.append(y_pred)
127-
y_trues.append(y.cpu().numpy())
128-
y_preds = np.concatenate(y_preds)
129-
y_trues = np.concatenate(y_trues)
130-
acc = (y_preds == y_trues).sum() / len(y_trues)
131+
acc = (y_preds == l).sum() / len(l)
131132
precision, recall, fscore, support = precision_recall_fscore_support(
132-
y_trues, y_preds, average="binary", pos_label=1
133+
l, y_preds, average="binary", pos_label=1
133134
)
134135
logger.info("[%s] Accuracy: %.4f", name, acc)
135136
(output_dir / f"classifier_perform_on_{name}.txt").write_text(
@@ -138,9 +139,8 @@ def _eval(split_dl: DataLoader, name: str):
138139
return fscore
139140

140141
output_dir.mkdir(parents=True, exist_ok=True)
141-
_eval(train_dl, "train")
142-
_eval(val_dl, "val")
143-
test_fscore = _eval(test_dl, "test")
142+
_eval(train_avs, train_l, "train")
143+
test_fscore = _eval(test_avs, test_l, "test")
144144

145145
weights = clf.weights
146146
assert len(weights.shape) == 2 and weights.shape[0] == 2
@@ -230,7 +230,8 @@ def train_concepts(
230230
self.cavs_list.append(weight)
231231
else:
232232
futures = []
233-
with ProcessPoolExecutor(max_workers=num_processes) as executor:
233+
ctx = mp.get_context("spawn")
234+
with ProcessPoolExecutor(mp_context=ctx, max_workers=num_processes) as executor:
234235
for c in concept_list:
235236
concept_embeddings = self.tpcav.concept_embeddings(
236237
c, num_samples=num_samples

tpcav/tpcav_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def fit_pca(
9292
num_pc: Optional[Union[int, str]] = None,
9393
) -> Dict[str, torch.Tensor]:
9494
"""Sample activations, compute PCA, and attach buffers to the model."""
95+
9596
logger.info("Start building PCA transformation.")
9697

9798
sampled_avs = []

0 commit comments

Comments
 (0)