@@ -130,7 +130,7 @@ def fit(self, train_val_avs: np.ndarray, train_val_ls: np.ndarray):
130130
131131 best_state_dict = None ; best_loss = None
132132 for w in self .weight_decay_search :
133- model = _TorchLinear (self .input_dim , self .num_class ).to (self .device )
133+ model = _TorchLinear (self .input_dim , self .num_class , device = self . device ).to (self .device )
134134 state_dict , loss = model .fit (train_avs , train_ls , val_avs , val_ls , lr = self .lr , weight_decay = w )
135135 if (best_loss is None ) or (loss < best_loss ):
136136 best_loss = loss
@@ -237,6 +237,7 @@ def _train(
237237 output_dir : str ,
238238 penalty : str = "l2" ,
239239 backend : str = "sklearn" ,
240+ device = None ,
240241) -> Tuple [float , torch .Tensor ]:
241242 """
242243 Train a binary CAV classifier for a concept vs cached control embeddings.
@@ -255,8 +256,9 @@ def _train(
255256 # replace label 0 as -1 to accomodate hinge loss
256257 train_l [train_l == 0 ] = - 1
257258 test_l [test_l == 0 ] = - 1
258-
259- clf = _TorchLinearWrapper (input_dim = train_avs .shape [1 ])
259+
260+ device = device or ('cuda:0' if torch .cuda .is_available () else 'cpu' ) # must CALL is_available(); the bare function object is always truthy and would force 'cuda:0' on CPU-only hosts
261+ clf = _TorchLinearWrapper (input_dim = train_avs .shape [1 ], device = device )
260262 clf .fit (train_avs , train_l )
261263
262264 #breakpoint()
@@ -348,6 +350,7 @@ def train_concepts(
348350 num_processes : int = 1 ,
349351 max_pending : int = 8 ,
350352 backend = 'sklearn' ,
353+ device = None
351354 ):
352355 "Train concepts with a fixed control set by self.set_control()"
353356 if self .control_embeddings is None :
@@ -367,7 +370,8 @@ def train_concepts(
367370 self .control_embeddings .cpu (),
368371 Path (output_dir ) / c .name ,
369372 self .penalty ,
370- backend = backend
373+ backend = backend ,
374+ device = device
371375 )
372376 self .cav_fscores [c .name ] = fscore
373377 self .cav_weights [c .name ] = weight
@@ -399,7 +403,8 @@ def train_concepts(
399403 self .control_embeddings ,
400404 Path (output_dir ) / c .name ,
401405 self .penalty ,
402- backend = backend
406+ backend = backend ,
407+ device = device
403408 )
404409 logger .info ("Submitted CAV training for concept %s" , c .name )
405410 futures .append ((c .name , future ))
@@ -416,7 +421,8 @@ def train_concepts_pairs(self,
416421 output_dir : str ,
417422 num_processes : int = 1 ,
418423 max_pending : int = 8 ,
419- backend = 'sklearn' ):
424+ backend = 'sklearn' ,
425+ device = None ):
420426 """Train concept pairs (test concept, control concept)
421427
422428 Note: It would compute embeddings on every control concept, use self.train_concepts if control concept is fixed
@@ -435,7 +441,8 @@ def train_concepts_pairs(self,
435441 control_embeddings .cpu (),
436442 Path (output_dir ) / c_test .name ,
437443 self .penalty ,
438- backend = backend
444+ backend = backend ,
445+ device = device
439446 )
440447 self .cav_fscores [c_test .name ] = fscore
441448 self .cav_weights [c_test .name ] = weight
@@ -469,7 +476,8 @@ def train_concepts_pairs(self,
469476 control_embeddings .cpu (),
470477 Path (output_dir ) / c_test .name ,
471478 self .penalty ,
472- backend = backend
479+ backend = backend ,
480+ device = device
473481 )
474482 logger .info ("Submitted CAV training for concept %s" , c_test .name )
475483 futures .append ((c_test .name , future ))
@@ -802,6 +810,7 @@ def run_tpcav(
802810 html_report_fscore_thresh = 0.9 ,
803811 seed = 1001 ,
804812 backend = 'sklearn' ,
813+ device = None ,
805814):
806815 """
807816 One-stop function to compute CAVs on motif concepts and bed concepts, compute AUC of motif concept f-scores after correction
@@ -889,13 +898,13 @@ def run_tpcav(
889898 cav_trainer .train_concepts_pairs (motif_concepts_pairs [nm ],
890899 num_samples_for_cav ,
891900 output_dir = str (output_path / f"cavs_{ nm } _motifs/" ),
892- num_processes = p , max_pending = max_pending_jobs , backend = backend )
901+ num_processes = p , max_pending = max_pending_jobs , backend = backend , device = device )
893902 else :
894903 cav_trainer .set_control (motif_concept_builders [nm ].control_concepts [0 ], num_samples = num_samples_for_cav )
895904 cav_trainer .train_concepts ([c for c , _ in motif_concepts_pairs [nm ]],
896905 num_samples_for_cav ,
897906 output_dir = str (output_path / f"cavs_{ nm } _motifs/" ),
898- num_processes = p , max_pending = max_pending_jobs , backend = backend )
907+ num_processes = p , max_pending = max_pending_jobs , backend = backend , device = device )
899908 if save_cav_trainer :
900909 torch .save (cav_trainer , str (output_path / f"cavs_{ nm } _motifs/cav_trainer.pt" ))
901910 motif_cav_trainers [nm ] = cav_trainer
@@ -909,7 +918,8 @@ def run_tpcav(
909918 num_samples_for_cav ,
910919 output_dir = str (output_path / f"cavs_bed_concepts/" ),
911920 num_processes = p ,
912- backend = backend
921+ backend = backend ,
922+ device = device
913923 )
914924 if save_cav_trainer :
915925 torch .save (bed_cav_trainer , str (output_path / f"cavs_bed_concepts/cav_trainer.pt" ))
0 commit comments