2323from sklearn .model_selection import GridSearchCV
2424from torch .utils .data import DataLoader , TensorDataset , random_split
2525from sklearn .linear_model import LinearRegression
26+ from sklearn .model_selection import train_test_split
2627import logomaker
28+ import multiprocessing as mp
2729
2830from . import helper , utils , report
2931from .concepts import ConceptBuilder
3335logger = logging .getLogger (__name__ )
3436
3537
36- def _load_all_tensors_to_numpy (dataloaders : Iterable [DataLoader ]):
37- if not isinstance (dataloaders , list ):
38- dataloaders = [dataloaders ]
39- avs , ls = [], []
40- for dataloader in dataloaders :
41- for av , l in dataloader :
42- avs .append (av .cpu ().numpy ())
43- ls .append (l .cpu ().numpy ())
44- return np .concatenate (avs ), np .concatenate (ls )
45-
4638
4739class _SGDWrapper :
4840 """Lightweight SGD concept classifier."""
@@ -66,8 +58,7 @@ def __init__(self, penalty: str = "l2", n_jobs: int = -1):
6658 raise ValueError (f"Unexpected penalty type { penalty } " )
6759 self .search = GridSearchCV (self .lm , params )
6860
69- def fit (self , train_dl : DataLoader , val_dl : DataLoader ):
70- train_avs , train_ls = _load_all_tensors_to_numpy ([train_dl , val_dl ])
61+ def fit (self , train_avs : np .ndarray , train_ls : np .ndarray ):
7162 self .search .fit (train_avs , train_ls )
7263 self .lm = self .search .best_estimator_
7364 logger .info (
@@ -89,6 +80,31 @@ def classes_(self):
    def predict(self, x: np.ndarray) -> np.ndarray:
        """Return label predictions for the rows of ``x``.

        Delegates to the wrapped linear model (after ``fit`` this is the
        grid search's best estimator).
        """
        return self.lm.predict(x)
9182
def prepare_xy(concept_embeddings, control_embeddings, seed=42):
    """Build a stratified train/test split for the concept classifier.

    Concept rows are labelled ``0`` and control rows ``1`` (the scoring
    code uses ``pos_label=1``, i.e. the control class). Features are
    flattened to 2-D, as the downstream SGD classifier requires.

    Parameters
    ----------
    concept_embeddings, control_embeddings : torch.Tensor
        Embedding batches for the two classes; detached and moved to CPU
        before conversion to numpy.
    seed : int
        Random state for the split, keeps runs reproducible.

    Returns
    -------
    (X_train, y_train, X_test, y_test) numpy arrays.
    """
    concept_x = concept_embeddings.detach().cpu().numpy()
    control_x = control_embeddings.detach().cpu().numpy()

    labels = np.concatenate([
        np.zeros(len(concept_x), dtype=np.int64),  # concept -> class 0
        np.ones(len(control_x), dtype=np.int64),   # control -> class 1
    ])
    features = np.concatenate([concept_x, control_x], axis=0)
    # Flatten trailing dimensions: the classifier expects 2-D input.
    features = features.reshape(features.shape[0], -1)

    # Single stratified 90/10 train/test split — there is no separate
    # validation set here.
    X_train, X_test, y_train, y_test = train_test_split(
        features,
        labels,
        test_size=0.1,
        random_state=seed,
        stratify=labels,  # preserve the class balance in both splits
    )

    return X_train, y_train, X_test, y_test
92108
93109def _train (
94110 concept_embeddings : torch .Tensor ,
@@ -102,34 +118,19 @@ def _train(
102118 Requires set_control to have been called beforehand.
103119 """
104120 output_dir = Path (output_dir )
121+
122+ train_avs , train_l , test_avs , test_l = prepare_xy (concept_embeddings , control_embeddings )
123+
124+ clf = _SGDWrapper (penalty = penalty )
125+ clf .fit (train_avs , train_l )
105126
106- avd = TensorDataset (
107- concept_embeddings , torch .full ((concept_embeddings .shape [0 ],), 0 )
108- )
109- cvd = TensorDataset (
110- control_embeddings , torch .full ((control_embeddings .shape [0 ],), 1 )
111- )
112- train_ds , val_ds , test_ds = random_split (avd , [0.8 , 0.1 , 0.1 ])
113- c_train , c_val , c_test = random_split (cvd , [0.8 , 0.1 , 0.1 ])
114-
115- train_dl = DataLoader (train_ds + c_train , batch_size = 32 , shuffle = True )
116- val_dl = DataLoader (val_ds + c_val , batch_size = 32 )
117- test_dl = DataLoader (test_ds + c_test , batch_size = 32 )
127+ def _eval (avs , l , name : str ):
128+
129+ y_preds = clf .predict (avs )
118130
119- clf = _SGDWrapper (penalty = penalty )
120- clf .fit (train_dl , val_dl )
121-
122- def _eval (split_dl : DataLoader , name : str ):
123- y_preds , y_trues = [], []
124- for x , y in split_dl :
125- y_pred = clf .predict (x .cpu ().numpy ())
126- y_preds .append (y_pred )
127- y_trues .append (y .cpu ().numpy ())
128- y_preds = np .concatenate (y_preds )
129- y_trues = np .concatenate (y_trues )
130- acc = (y_preds == y_trues ).sum () / len (y_trues )
131+ acc = (y_preds == l ).sum () / len (l )
131132 precision , recall , fscore , support = precision_recall_fscore_support (
132- y_trues , y_preds , average = "binary" , pos_label = 1
133+ l , y_preds , average = "binary" , pos_label = 1
133134 )
134135 logger .info ("[%s] Accuracy: %.4f" , name , acc )
135136 (output_dir / f"classifier_perform_on_{ name } .txt" ).write_text (
@@ -138,9 +139,8 @@ def _eval(split_dl: DataLoader, name: str):
138139 return fscore
139140
140141 output_dir .mkdir (parents = True , exist_ok = True )
141- _eval (train_dl , "train" )
142- _eval (val_dl , "val" )
143- test_fscore = _eval (test_dl , "test" )
142+ _eval (train_avs , train_l , "train" )
143+ test_fscore = _eval (test_avs , test_l , "test" )
144144
145145 weights = clf .weights
146146 assert len (weights .shape ) == 2 and weights .shape [0 ] == 2
@@ -230,7 +230,8 @@ def train_concepts(
230230 self .cavs_list .append (weight )
231231 else :
232232 futures = []
233- with ProcessPoolExecutor (max_workers = num_processes ) as executor :
233+ ctx = mp .get_context ("spawn" )
234+ with ProcessPoolExecutor (mp_context = ctx , max_workers = num_processes ) as executor :
234235 for c in concept_list :
235236 concept_embeddings = self .tpcav .concept_embeddings (
236237 c , num_samples = num_samples
0 commit comments