@@ -238,6 +238,7 @@ def _train(
238238 penalty : str = "l2" ,
239239 backend : str = "sklearn" ,
240240 device = None ,
241+ name = None ,
241242) -> Tuple [float , torch .Tensor ]:
242243 """
243244 Train a binary CAV classifier for a concept vs cached control embeddings.
@@ -270,16 +271,18 @@ def _eval(avs, l, name: str):
270271 precision , recall , fscore , support = precision_recall_fscore_support (
271272 l , y_preds , average = "binary" , pos_label = 1
272273 )
273- logger .info ("[%s] Accuracy: %.4f" , name , acc )
274+ # logger.info("[%s] Accuracy: %.4f", name, acc)
274275 (output_dir / f"classifier_perform_on_{ name } .txt" ).write_text (
275276 f"Accuracy: { acc } \n "
276277 )
277278 return fscore
278279
279280 output_dir .mkdir (parents = True , exist_ok = True )
280- _eval (train_avs , train_l , "train" )
281+ train_fscore = _eval (train_avs , train_l , "train" )
281282 test_fscore = _eval (test_avs , test_l , "test" )
282283
284+ logger .info ("Concept %s: [train] F-score: %.4f, [test] F-score: %.4f" , name , train_fscore , test_fscore )
285+
283286 weights = clf .weights
284287 assert len (weights .shape ) == 2 and weights .shape [0 ] == 2
285288 torch .save (weights , output_dir / "classifier_weights.pt" )
@@ -357,11 +360,11 @@ def _cleanup_paths(paths: list[str]) -> None:
357360 pass
358361
359362 @classmethod
360- def _reap_done_futures (cls , futures : list ):
363+ def _reap_done_futures (cls , futures : list , results : list ):
361364 pending = []
362365 for name , fut , paths in futures :
363366 if fut .done ():
364- fut .result () # raises if worker failed
367+ results . append (( name , fut .result ()) ) # raises if worker failed
365368 cls ._cleanup_paths (paths )
366369 else :
367370 pending .append ((name , fut , paths ))
@@ -371,11 +374,12 @@ def _reap_done_futures(cls, futures: list):
371374 def _wait_for_capacity (
372375 cls ,
373376 futures : list ,
377+ results : list ,
374378 capacity : int ,
375379 sleep_s : int = 5 ,
376380 ):
377381 while True :
378- futures = cls ._reap_done_futures (futures )
382+ futures = cls ._reap_done_futures (futures , results )
379383 if len (futures ) < capacity :
380384 return futures
381385 time .sleep (sleep_s )
@@ -418,15 +422,16 @@ def train_concepts(
418422 concept_dir ,
419423 self .penalty ,
420424 backend = backend ,
421- device = device
425+ device = device ,
426+ name = c .name ,
422427 )
423428 self .cav_fscores [c .name ] = fscore
424429 self .cav_weights [c .name ] = weight
425430 self .cavs_list .append (weight )
426431
427432 self ._cleanup_paths ([str (concept_memmap_path )])
428433 else :
429- futures = []
434+ futures = []; results = []
430435 ctx = mp .get_context ("spawn" )
431436 with ProcessPoolExecutor (mp_context = ctx , max_workers = num_processes ) as executor :
432437 for c in concept_list :
@@ -441,7 +446,7 @@ def train_concepts(
441446
442447 # block the process to avoid too long queue
443448 futures = self ._wait_for_capacity (
444- futures , capacity = (max_pending + num_processes ), sleep_s = 5
449+ futures , results , capacity = (max_pending + num_processes ), sleep_s = 5
445450 )
446451
447452 future = executor .submit (
@@ -451,12 +456,12 @@ def train_concepts(
451456 concept_dir ,
452457 self .penalty ,
453458 backend = backend ,
454- device = device
459+ device = device ,
460+ name = c .name ,
455461 )
456462 logger .info ("Submitted CAV training for concept %s" , c .name )
457463 futures .append ((c .name , future , [str (concept_memmap_path )]))
458464
459- results = []
460465 for name , fut , paths in futures :
461466 results .append ((name , fut .result ()))
462467 self ._cleanup_paths (paths )
@@ -504,15 +509,16 @@ def train_concepts_pairs(self,
504509 concept_dir ,
505510 self .penalty ,
506511 backend = backend ,
507- device = device
512+ device = device ,
513+ name = c_test .name ,
508514 )
509515 self .cav_fscores [c_test .name ] = fscore
510516 self .cav_weights [c_test .name ] = weight
511517 self .cavs_list .append (weight )
512518
513519 self ._cleanup_paths ([str (concept_memmap_path ), str (control_memmap_path )])
514520 else :
515- futures = []
521+ futures = []; results = []
516522 with ProcessPoolExecutor (max_workers = num_processes ) as executor :
517523 for c_test , c_control in concept_pair_list :
518524 concept_embeddings = self .tpcav .concept_embeddings (
@@ -531,7 +537,7 @@ def train_concepts_pairs(self,
531537
532538 # block the process to avoid too long queue
533539 futures = self ._wait_for_capacity (
534- futures , capacity = (max_pending + num_processes ), sleep_s = 5
540+ futures , results , capacity = (max_pending + num_processes ), sleep_s = 5
535541 )
536542
537543 future = executor .submit (
@@ -541,7 +547,8 @@ def train_concepts_pairs(self,
541547 concept_dir ,
542548 self .penalty ,
543549 backend = backend ,
544- device = device
550+ device = device ,
551+ name = c_test .name ,
545552 )
546553 logger .info ("Submitted CAV training for concept %s" , c_test .name )
547554 futures .append (
@@ -552,7 +559,6 @@ def train_concepts_pairs(self,
552559 )
553560 )
554561
555- results = []
556562 for name , fut , paths in futures :
557563 results .append ((name , fut .result ()))
558564 self ._cleanup_paths (paths )
0 commit comments