Skip to content

Commit 84ad3a3

Browse files
committed
add option to select GPU for classifier training
1 parent 014383b commit 84ad3a3

2 files changed

Lines changed: 22 additions & 12 deletions

File tree

test/test_cav_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def test_all(self):
279279
cav_trainer.set_control(builder.control_concepts[0], num_samples=100)
280280

281281
cav_trainer.train_concepts(
282-
builder.concepts, 100, output_dir="data/cavs/", num_processes=2, backend='torch'
282+
builder.concepts, 100, output_dir="data/cavs/", num_processes=2, backend='torch', device='cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0',
283283
)
284284
cav_trainer.train_concepts(
285285
builder.concepts, 100, output_dir="data/cavs/", num_processes=2

tpcav/cavs.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def fit(self, train_val_avs: np.ndarray, train_val_ls: np.ndarray):
130130

131131
best_state_dict = None; best_loss = None
132132
for w in self.weight_decay_search:
133-
model = _TorchLinear(self.input_dim, self.num_class).to(self.device)
133+
model = _TorchLinear(self.input_dim, self.num_class, device=self.device).to(self.device)
134134
state_dict, loss = model.fit(train_avs, train_ls, val_avs, val_ls, lr=self.lr, weight_decay=w)
135135
if (best_loss is None) or (loss < best_loss):
136136
best_loss = loss
@@ -237,6 +237,7 @@ def _train(
237237
output_dir: str,
238238
penalty: str = "l2",
239239
backend: str = "sklearn",
240+
device=None,
240241
) -> Tuple[float, torch.Tensor]:
241242
"""
242243
Train a binary CAV classifier for a concept vs cached control embeddings.
@@ -255,8 +256,9 @@ def _train(
255256
# replace label 0 with -1 to accommodate hinge loss
256257
train_l[train_l==0] = -1
257258
test_l[test_l==0] = -1
258-
259-
clf = _TorchLinearWrapper(input_dim= train_avs.shape[1])
259+
260+
# Default to the first GPU only when CUDA is actually usable. is_available must be
# *called*: the bare function object is always truthy, which made 'cpu' unreachable.
device = device or ('cuda:0' if torch.cuda.is_available() else 'cpu')
261+
clf = _TorchLinearWrapper(input_dim= train_avs.shape[1], device=device)
260262
clf.fit(train_avs, train_l)
261263

262264
#breakpoint()
@@ -348,6 +350,7 @@ def train_concepts(
348350
num_processes: int = 1,
349351
max_pending: int = 8,
350352
backend='sklearn',
353+
device=None
351354
):
352355
"Train concepts with a fixed control set by self.set_control()"
353356
if self.control_embeddings is None:
@@ -367,7 +370,8 @@ def train_concepts(
367370
self.control_embeddings.cpu(),
368371
Path(output_dir) / c.name,
369372
self.penalty,
370-
backend=backend
373+
backend=backend,
374+
device=device
371375
)
372376
self.cav_fscores[c.name] = fscore
373377
self.cav_weights[c.name] = weight
@@ -399,7 +403,8 @@ def train_concepts(
399403
self.control_embeddings,
400404
Path(output_dir) / c.name,
401405
self.penalty,
402-
backend=backend
406+
backend=backend,
407+
device=device
403408
)
404409
logger.info("Submitted CAV training for concept %s", c.name)
405410
futures.append((c.name, future))
@@ -416,7 +421,8 @@ def train_concepts_pairs(self,
416421
output_dir: str,
417422
num_processes: int = 1,
418423
max_pending: int = 8,
419-
backend='sklearn'):
424+
backend='sklearn',
425+
device=None):
420426
"""Train concept pairs (test concept, control concept)
421427
422428
Note: This computes embeddings for every control concept; use self.train_concepts if the control concept is fixed.
@@ -435,7 +441,8 @@ def train_concepts_pairs(self,
435441
control_embeddings.cpu(),
436442
Path(output_dir) / c_test.name,
437443
self.penalty,
438-
backend=backend
444+
backend=backend,
445+
device=device
439446
)
440447
self.cav_fscores[c_test.name] = fscore
441448
self.cav_weights[c_test.name] = weight
@@ -469,7 +476,8 @@ def train_concepts_pairs(self,
469476
control_embeddings.cpu(),
470477
Path(output_dir) / c_test.name,
471478
self.penalty,
472-
backend=backend
479+
backend=backend,
480+
device=device
473481
)
474482
logger.info("Submitted CAV training for concept %s", c_test.name)
475483
futures.append((c_test.name, future))
@@ -802,6 +810,7 @@ def run_tpcav(
802810
html_report_fscore_thresh=0.9,
803811
seed=1001,
804812
backend='sklearn',
813+
device=None,
805814
):
806815
"""
807816
One-stop function to compute CAVs on motif concepts and bed concepts, compute AUC of motif concept f-scores after correction
@@ -889,13 +898,13 @@ def run_tpcav(
889898
cav_trainer.train_concepts_pairs(motif_concepts_pairs[nm],
890899
num_samples_for_cav,
891900
output_dir=str(output_path / f"cavs_{nm}_motifs/"),
892-
num_processes=p, max_pending=max_pending_jobs, backend=backend)
901+
num_processes=p, max_pending=max_pending_jobs, backend=backend, device=device)
893902
else:
894903
cav_trainer.set_control(motif_concept_builders[nm].control_concepts[0], num_samples=num_samples_for_cav)
895904
cav_trainer.train_concepts([c for c, _ in motif_concepts_pairs[nm]],
896905
num_samples_for_cav,
897906
output_dir=str(output_path / f"cavs_{nm}_motifs/"),
898-
num_processes=p, max_pending=max_pending_jobs, backend=backend)
907+
num_processes=p, max_pending=max_pending_jobs, backend=backend, device=device)
899908
if save_cav_trainer:
900909
torch.save(cav_trainer, str(output_path / f"cavs_{nm}_motifs/cav_trainer.pt"))
901910
motif_cav_trainers[nm] = cav_trainer
@@ -909,7 +918,8 @@ def run_tpcav(
909918
num_samples_for_cav,
910919
output_dir=str(output_path / f"cavs_bed_concepts/"),
911920
num_processes=p,
912-
backend=backend
921+
backend=backend,
922+
device=device
913923
)
914924
if save_cav_trainer:
915925
torch.save(bed_cav_trainer, str(output_path / f"cavs_bed_concepts/cav_trainer.pt"))

0 commit comments

Comments
 (0)