@@ -154,7 +154,7 @@ class CavTrainer:
154154 def __init__ (self , tpcav : TPCAV , penalty : str = "l2" ) -> None :
155155 self .tpcav = tpcav
156156 self .penalty = penalty
157- self .cavs_fscores = {}
157+ self .cav_fscores = {}
158158 self .cav_weights = {}
159159 self .control_embeddings : Optional [torch .Tensor ] = None
160160 self .cavs_list : List [torch .Tensor ] = []
@@ -165,7 +165,7 @@ def save_state(self, output_path: str = "cav_trainer_state.pt"):
165165 """
166166 state = {
167167 "penalty" : self .penalty ,
168- "cavs_fscores " : self .cavs_fscores ,
168+ "cav_fscores " : self .cav_fscores ,
169169 "cav_weights" : self .cav_weights ,
170170 "control_embeddings" : self .control_embeddings ,
171171 "cavs_list" : self .cavs_list ,
@@ -180,7 +180,7 @@ def load_state(tpcav_model: TPCAV, state_path: str = "cav_trainer_state.pt"):
180180 state = torch .load (state_path , map_location = "cpu" )
181181 cav_trainer = CavTrainer (tpcav_model , penalty = state ["penalty" ])
182182
183- cav_trainer .cavs_fscores = state ["cavs_fscores " ]
183+ cav_trainer .cav_fscores = state ["cav_fscores " ]
184184 cav_trainer .cav_weights = state ["cav_weights" ]
185185 cav_trainer .control_embeddings = state ["control_embeddings" ]
186186 cav_trainer .cavs_list = state ["cavs_list" ]
@@ -224,7 +224,7 @@ def train_concepts(
224224 Path (output_dir ) / c .name ,
225225 self .penalty ,
226226 )
227- self .cavs_fscores [c .name ] = fscore
227+ self .cav_fscores [c .name ] = fscore
228228 self .cav_weights [c .name ] = weight
229229 self .cavs_list .append (weight )
230230 else :
@@ -259,7 +259,7 @@ def train_concepts(
259259
260260 results = [(name , f .result ()) for name , f in futures ]
261261 for name , (fscore , weight ) in results :
262- self .cavs_fscores [name ] = fscore
262+ self .cav_fscores [name ] = fscore
263263 self .cav_weights [name ] = weight
264264 self .cavs_list .append (weight )
265265
@@ -288,7 +288,7 @@ def train_concepts_pairs(self,
288288 Path (output_dir ) / c_test .name ,
289289 self .penalty ,
290290 )
291- self .cavs_fscores [c_test .name ] = fscore
291+ self .cav_fscores [c_test .name ] = fscore
292292 self .cav_weights [c_test .name ] = weight
293293 self .cavs_list .append (weight )
294294 else :
@@ -326,7 +326,7 @@ def train_concepts_pairs(self,
326326
327327 results = [(name , f .result ()) for name , f in futures ]
328328 for name , (fscore , weight ) in results :
329- self .cavs_fscores [name ] = fscore
329+ self .cav_fscores [name ] = fscore
330330 self .cav_weights [name ] = weight
331331 self .cavs_list .append (weight )
332332
@@ -397,14 +397,14 @@ def plot_cavs_similaritiy_heatmap(
397397 cavs_pass = []
398398 cavs_names_pass = []
399399 for cname in cavs_names :
400- if self .cavs_fscores [cname ] >= fscore_thresh :
400+ if self .cav_fscores [cname ] >= fscore_thresh :
401401 cavs_pass .append (self .cav_weights [cname ].cpu ().numpy ())
402402 cavs_names_pass .append (cname )
403403 else :
404404 logger .info (
405405 "Skipping CAV %s with F-score %.3f below threshold %.3f" ,
406406 cname ,
407- self .cavs_fscores [cname ],
407+ self .cav_fscores [cname ],
408408 fscore_thresh ,
409409 )
410410 if len (cavs_pass ) == 0 :
@@ -533,8 +533,8 @@ def compute_motif_auc_fscore(num_motif_insertions: List[int], cav_trainers: List
533533
534534 assert motif_file_fmt in ['meme' , 'consensus' ]
535535
536- cavs_fscores_df = pd .DataFrame ({nm : cav_trainer .cavs_fscores for nm , cav_trainer in zip (num_motif_insertions , cav_trainers )})
537- cavs_fscores_df ['concept' ] = list (cav_trainers [0 ].cavs_fscores .keys ())
536+ cavs_fscores_df = pd .DataFrame ({nm : cav_trainer .cav_fscores for nm , cav_trainer in zip (num_motif_insertions , cav_trainers )})
537+ cavs_fscores_df ['concept' ] = list (cav_trainers [0 ].cav_fscores .keys ())
538538
539539 def compute_auc_fscore (row ):
540540 y = [row [nm ] for nm in num_motif_insertions ]
0 commit comments