1616import pandas as pd
1717import seaborn as sns
1818import torch
19+ from copy import deepcopy
1920from scipy import stats
2021from sklearn .linear_model import SGDClassifier
2122from sklearn .metrics import precision_recall_fscore_support
3536logger = logging .getLogger (__name__ )
3637
3738
class _TorchLinear(torch.nn.Module):
    """Single linear layer trained with a hinge loss (a torch CAV classifier).

    Expects labels in {-1, +1}; callers with {0, 1} labels must remap them
    before calling ``fit`` (the hinge loss below assumes signed labels).
    """

    def __init__(self, input_dim, num_class=1, device='cuda:0'):
        """
        Args:
            input_dim: dimensionality of the input activation vectors.
            num_class: number of output units (1 for a binary CAV).
            device: device string used to place mini-batches in fit/eval.
        """
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, num_class)
        self.device = device

    def forward(self, avs):
        # Squeeze the trailing class dim so binary outputs are shape (N,).
        return self.linear(avs).squeeze(-1)

    def shuffle_tensor(self, *tensors):
        """Apply one shared random permutation to all tensors (keeps rows aligned)."""
        p = torch.randperm(len(tensors[0]))

        new_tensors = [t[p] for t in tensors]

        return new_tensors

    def fit(self, train_avs: np.ndarray, train_ls: np.ndarray,
            val_avs: np.ndarray, val_ls: np.ndarray,
            patience=10, lr=1e-2, weight_decay=1e-2, max_epochs=1000,
            batch_size=32):
        """Train with AdamW + hinge loss, early-stopping on validation loss.

        Args:
            train_avs, train_ls: training activations and {-1, +1} labels.
            val_avs, val_ls: validation activations and labels.
            patience: consecutive non-improving epochs tolerated before stopping.
            lr, weight_decay: AdamW hyper-parameters.
            max_epochs: hard cap on the number of epochs.
            batch_size: mini-batch size (previously hard-coded to 32).

        Returns:
            (best_state_dict, best_val_loss) from the best epoch seen.
        """
        # Cast to float32 so float64 numpy inputs match the layer's dtype
        # (torch.from_numpy alone would raise a dtype mismatch in the matmul).
        train_avs = torch.from_numpy(train_avs).float()
        train_ls = torch.from_numpy(train_ls).float()
        val_avs = torch.from_numpy(val_avs).float()
        val_ls = torch.from_numpy(val_ls).float()

        optimizer = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

        best_loss = None
        best_state_dict = None
        epoch = 0
        t = 0  # epochs since the last validation improvement
        while True:
            epoch += 1
            if epoch > max_epochs:
                break

            # Shuffle the tensors before every epoch.
            train_avs, train_ls = self.shuffle_tensor(train_avs, train_ls)

            self.train()
            for i in range(0, len(train_avs), batch_size):
                optimizer.zero_grad()

                avs = train_avs[i: (i + batch_size)]
                l = train_ls[i: (i + batch_size)]

                y_hat = self(avs.to(self.device))
                # Hinge loss: mean(max(0, 1 - y * f(x))); assumes y in {-1, +1}.
                loss = torch.mean(torch.clamp(1 - l.to(self.device) * y_hat, min=0))

                loss.backward()
                optimizer.step()

            # NOTE: this logs the loss of the *last* mini-batch only.
            logger.debug(f"Training loss at epoch {epoch}: {loss}")

            self.eval()
            val_loss_all = []
            # Validation needs no autograd graph — saves memory and time.
            with torch.no_grad():
                for i in range(0, len(val_avs), batch_size):
                    avs = val_avs[i: (i + batch_size)]
                    l = val_ls[i: (i + batch_size)]

                    y_hat = self(avs.to(self.device))
                    val_loss = torch.mean(torch.clamp(1 - l.to(self.device) * y_hat, min=0))

                    val_loss_all.append(val_loss.item())

            val_loss_all = np.mean(val_loss_all)

            if (best_loss is None) or (val_loss_all < best_loss):
                best_loss = val_loss_all
                best_state_dict = deepcopy(self.state_dict())
                # BUGFIX: reset the patience counter on improvement; it was
                # never reset, so any `patience` non-improving epochs spread
                # over the whole run would stop training prematurely.
                t = 0
            else:
                t += 1
                if t >= patience:
                    break

        return best_state_dict, best_loss
116+
class _TorchLinearWrapper:
    """sklearn-style wrapper around _TorchLinear.

    Runs a small grid search over weight decay and keeps the model with the
    best validation loss as ``self.best_model``.
    """

    def __init__(self, input_dim, num_class=1, lr=1e-2,
                 weight_decay_search=None, device="cuda:0"):
        """
        Args:
            input_dim: dimensionality of the input activation vectors.
            num_class: number of output units (1 for a binary CAV).
            lr: learning rate forwarded to _TorchLinear.fit.
            weight_decay_search: weight-decay values to grid-search
                (defaults to [1e-2, 1e-4, 1e-6]).
            device: device the models and batches are placed on.
        """
        super().__init__()
        self.input_dim = input_dim
        self.num_class = num_class
        self.lr = lr
        # BUGFIX: avoid a mutable default argument shared across instances.
        if weight_decay_search is None:
            weight_decay_search = [1e-2, 1e-4, 1e-6]
        self.weight_decay_search = weight_decay_search
        self.device = device

    def fit(self, train_val_avs: np.ndarray, train_val_ls: np.ndarray):
        """Grid-search weight decay; store the best model in self.best_model."""

        train_avs, val_avs, train_ls, val_ls = train_test_split(
            train_val_avs, train_val_ls, test_size=0.1)

        best_state_dict = None
        best_loss = None
        for w in self.weight_decay_search:
            # BUGFIX: pass device= so _TorchLinear.fit moves its batches to
            # the same device as the parameters; previously it always used
            # the class default 'cuda:0', breaking any non-'cuda:0' run.
            model = _TorchLinear(self.input_dim, self.num_class,
                                 device=self.device).to(self.device)
            state_dict, loss = model.fit(train_avs, train_ls, val_avs, val_ls,
                                         lr=self.lr, weight_decay=w)
            if (best_loss is None) or (loss < best_loss):
                best_loss = loss
                best_state_dict = state_dict

        self.best_model = _TorchLinear(self.input_dim, self.num_class,
                                       device=self.device)
        self.best_model.load_state_dict(best_state_dict)
        self.best_model.to(self.device)

    def predict(self, avs: np.ndarray):
        """Return hard {-1, +1} predictions for ``avs`` as a numpy array."""
        self.best_model.eval()
        # Inference only: skip autograd graph construction. The float() cast
        # guards against float64 numpy inputs vs. float32 parameters.
        with torch.no_grad():
            y_hat = self.best_model(
                torch.from_numpy(avs).float().to(self.device))

        y_hat[y_hat >= 0] = 1
        y_hat[y_hat < 0] = -1

        return y_hat.cpu().numpy()

    @property
    def weights(self):
        """Two-row weight matrix mimicking a 2-class sklearn ``coef_`` layout:
        row 0 is the negative class, row 1 the positive class."""
        linear_weight = self.best_model.linear.weight.detach().cpu()[0]
        return torch.stack([-1 * linear_weight, linear_weight])

    @property
    def classes_(self):
        # NOTE(review): sklearn's classes_ is an array of class labels; here
        # it is the number of output units — confirm downstream usage.
        return self.num_class
38158
39159class _SGDWrapper :
40160 """Lightweight SGD concept classifier."""
@@ -111,19 +231,30 @@ def _train(
111231 control_embeddings : torch .Tensor ,
112232 output_dir : str ,
113233 penalty : str = "l2" ,
234+ backend : str = "sklearn" ,
114235) -> Tuple [float , torch .Tensor ]:
115236 """
116237 Train a binary CAV classifier for a concept vs cached control embeddings.
117238
118239 Requires set_control to have been called beforehand.
119240 """
241+ assert backend in ["sklearn" , "torch" ]
242+
120243 output_dir = Path (output_dir )
121244
122245 train_avs , train_l , test_avs , test_l = prepare_xy (concept_embeddings , control_embeddings )
123246
124- clf = _SGDWrapper (penalty = penalty )
125- clf .fit (train_avs , train_l )
247+ if backend == "sklearn" :
248+ clf = _SGDWrapper (penalty = penalty )
249+ else :
250+ # replace label 0 as -1 to accomodate hinge loss
251+ train_l [train_l == 0 ] = - 1
252+ test_l [test_l == 0 ] = - 1
126253
254+ clf = _TorchLinearWrapper (input_dim = train_avs .shape [1 ])
255+ clf .fit (train_avs , train_l )
256+
257+ #breakpoint()
127258 def _eval (avs , l , name : str ):
128259
129260 y_preds = clf .predict (avs )
@@ -204,7 +335,8 @@ def train_concepts(
204335 num_samples : int ,
205336 output_dir : str ,
206337 num_processes : int = 1 ,
207- max_pending : int = 8
338+ max_pending : int = 8 ,
339+ backend = 'sklearn' ,
208340 ):
209341 "Train concepts with a fixed control set by self.set_control()"
210342 if self .control_embeddings is None :
@@ -224,6 +356,7 @@ def train_concepts(
224356 self .control_embeddings .cpu (),
225357 Path (output_dir ) / c .name ,
226358 self .penalty ,
359+ backend = backend
227360 )
228361 self .cav_fscores [c .name ] = fscore
229362 self .cav_weights [c .name ] = weight
@@ -255,6 +388,7 @@ def train_concepts(
255388 self .control_embeddings ,
256389 Path (output_dir ) / c .name ,
257390 self .penalty ,
391+ backend = backend
258392 )
259393 logger .info ("Submitted CAV training for concept %s" , c .name )
260394 futures .append ((c .name , future ))
@@ -270,7 +404,8 @@ def train_concepts_pairs(self,
270404 num_samples : int ,
271405 output_dir : str ,
272406 num_processes : int = 1 ,
273- max_pending : int = 8 ):
407+ max_pending : int = 8 ,
408+ backend = 'sklearn' ):
274409 """Train concept pairs (test concept, control concept)
275410
276411 Note: It would compute embeddings on every control concept, use self.train_concepts if control concept is fixed
@@ -289,6 +424,7 @@ def train_concepts_pairs(self,
289424 control_embeddings .cpu (),
290425 Path (output_dir ) / c_test .name ,
291426 self .penalty ,
427+ backend = backend
292428 )
293429 self .cav_fscores [c_test .name ] = fscore
294430 self .cav_weights [c_test .name ] = weight
@@ -322,6 +458,7 @@ def train_concepts_pairs(self,
322458 control_embeddings .cpu (),
323459 Path (output_dir ) / c_test .name ,
324460 self .penalty ,
461+ backend = backend
325462 )
326463 logger .info ("Submitted CAV training for concept %s" , c_test .name )
327464 futures .append ((c_test .name , future ))
0 commit comments