1919import torch
2020from copy import deepcopy
2121from scipy import stats
22+ import uuid
2223from sklearn .linear_model import SGDClassifier
2324from sklearn .metrics import precision_recall_fscore_support
2425from sklearn .metrics .pairwise import cosine_similarity
@@ -202,9 +203,12 @@ def predict(self, x: np.ndarray) -> np.ndarray:
202203 return self .lm .predict (x )
203204
204205def prepare_xy (concept_embeddings , control_embeddings , seed = 42 ):
206+ def _to_numpy (emb ):
207+ return np .load (str (emb ), mmap_mode = "r" )
208+
205209 # move to CPU + numpy, just double confirm
206- concept = concept_embeddings . detach (). cpu (). numpy ( )
207- control = control_embeddings . detach (). cpu (). numpy ( )
210+ concept = _to_numpy ( concept_embeddings )
211+ control = _to_numpy ( control_embeddings )
208212
209213 # labels
210214 y_concept = np .zeros (len (concept ), dtype = np .int64 )
@@ -228,8 +232,8 @@ def prepare_xy(concept_embeddings, control_embeddings, seed=42):
228232 return X_train , y_train , X_test , y_test
229233
230234def _train (
231- concept_embeddings : torch . Tensor ,
232- control_embeddings : torch . Tensor ,
235+ concept_embeddings : str ,
236+ control_embeddings : str ,
233237 output_dir : str ,
234238 penalty : str = "l2" ,
235239 backend : str = "sklearn" ,
@@ -338,6 +342,44 @@ def set_control(self, control_concept, num_samples: int) -> torch.Tensor:
338342 )
339343 return self .control_embeddings
340344
345+ @staticmethod
346+ def _save_tensor_npy (path : Path , tensor : torch .Tensor ) -> str :
347+ path .parent .mkdir (parents = True , exist_ok = True )
348+ np .save (path , tensor .detach ().cpu ().numpy ())
349+ return str (path )
350+
351+ @staticmethod
352+ def _cleanup_paths (paths : list [str ]) -> None :
353+ for p in paths :
354+ try :
355+ Path (p ).unlink (missing_ok = True )
356+ except Exception :
357+ pass
358+
359+ @classmethod
360+ def _reap_done_futures (cls , futures : list ):
361+ pending = []
362+ for name , fut , paths in futures :
363+ if fut .done ():
364+ fut .result () # raises if worker failed
365+ cls ._cleanup_paths (paths )
366+ else :
367+ pending .append ((name , fut , paths ))
368+ return pending
369+
370+ @classmethod
371+ def _wait_for_capacity (
372+ cls ,
373+ futures : list ,
374+ capacity : int ,
375+ sleep_s : int = 5 ,
376+ ):
377+ while True :
378+ futures = cls ._reap_done_futures (futures )
379+ if len (futures ) < capacity :
380+ return futures
381+ time .sleep (sleep_s )
382+
341383 def train_concepts (
342384 self ,
343385 concept_list ,
@@ -356,22 +398,33 @@ def train_concepts(
356398 else :
357399 self .control_embeddings = self .control_embeddings .cpu ()
358400
401+ output_dir_path = Path (output_dir )
402+ output_dir_path .mkdir (parents = True , exist_ok = True )
403+ control_memmap_path = output_dir_path / f"_control_embeddings_{ uuid .uuid4 ().hex } .npy"
404+ self ._save_tensor_npy (control_memmap_path , self .control_embeddings )
405+
359406 if num_processes == 1 :
360407 for c in concept_list :
361408 concept_embeddings = self .tpcav .concept_embeddings (
362409 c , num_samples = num_samples
363410 )
411+ concept_dir = output_dir_path / c .name
412+ concept_dir .mkdir (parents = True , exist_ok = True )
413+ concept_memmap_path = concept_dir / "concept_embeddings.npy"
414+ self ._save_tensor_npy (concept_memmap_path , concept_embeddings )
364415 fscore , weight = _train (
365- concept_embeddings . cpu ( ),
366- self . control_embeddings . cpu ( ),
367- Path ( output_dir ) / c . name ,
416+ str ( concept_memmap_path ),
417+ str ( control_memmap_path ),
418+ concept_dir ,
368419 self .penalty ,
369420 backend = backend ,
370421 device = device
371422 )
372423 self .cav_fscores [c .name ] = fscore
373424 self .cav_weights [c .name ] = weight
374425 self .cavs_list .append (weight )
426+
427+ self ._cleanup_paths ([str (concept_memmap_path )])
375428 else :
376429 futures = []
377430 ctx = mp .get_context ("spawn" )
@@ -381,36 +434,39 @@ def train_concepts(
381434 c , num_samples = num_samples
382435 )
383436
384- # block the process to avoid too long queue
385- while True :
386- done = [f for (_ , f ) in futures if f .done ()]
387- for f in done :
388- f .result () # raises if worker failed
389-
390- pending = [f for (_ , f ) in futures if not f .done ()]
391- if len (pending ) < (max_pending + num_processes ):
392- break
437+ concept_dir = output_dir_path / c .name
438+ concept_dir .mkdir (parents = True , exist_ok = True )
439+ concept_memmap_path = concept_dir / "concept_embeddings.npy"
440+ self ._save_tensor_npy (concept_memmap_path , concept_embeddings )
393441
394- time .sleep (5 )
442+ # block the process to avoid too long queue
443+ futures = self ._wait_for_capacity (
444+ futures , capacity = (max_pending + num_processes ), sleep_s = 5
445+ )
395446
396447 future = executor .submit (
397448 _train ,
398- concept_embeddings . cpu ( ),
399- self . control_embeddings ,
400- Path ( output_dir ) / c . name ,
449+ str ( concept_memmap_path ),
450+ str ( control_memmap_path ) ,
451+ concept_dir ,
401452 self .penalty ,
402453 backend = backend ,
403454 device = device
404455 )
405456 logger .info ("Submitted CAV training for concept %s" , c .name )
406- futures .append ((c .name , future ))
457+ futures .append ((c .name , future , [ str ( concept_memmap_path )] ))
407458
408- results = [(name , f .result ()) for name , f in futures ]
459+ results = []
460+ for name , fut , paths in futures :
461+ results .append ((name , fut .result ()))
462+ self ._cleanup_paths (paths )
409463 for name , (fscore , weight ) in results :
410464 self .cav_fscores [name ] = fscore
411465 self .cav_weights [name ] = weight
412466 self .cavs_list .append (weight )
413467
468+ self ._cleanup_paths ([str (control_memmap_path )])
469+
414470 def train_concepts_pairs (self ,
415471 concept_pair_list ,
416472 num_samples : int ,
@@ -423,6 +479,9 @@ def train_concepts_pairs(self,
423479
424480 Note: It would compute embeddings on every control concept, use self.train_concepts if control concept is fixed
425481 """
482+ output_dir_path = Path (output_dir )
483+ output_dir_path .mkdir (parents = True , exist_ok = True )
484+
426485 if num_processes == 1 :
427486 for c_test , c_control in concept_pair_list :
428487 concept_embeddings = self .tpcav .concept_embeddings (
@@ -432,17 +491,26 @@ def train_concepts_pairs(self,
432491 c_control , num_samples = num_samples
433492 )
434493
494+ concept_dir = output_dir_path / c_test .name
495+ concept_dir .mkdir (parents = True , exist_ok = True )
496+ concept_memmap_path = concept_dir / "concept_embeddings.npy"
497+ control_memmap_path = concept_dir / "control_embeddings.npy"
498+ self ._save_tensor_npy (concept_memmap_path , concept_embeddings )
499+ self ._save_tensor_npy (control_memmap_path , control_embeddings )
500+
435501 fscore , weight = _train (
436- concept_embeddings . cpu ( ),
437- control_embeddings . cpu ( ),
438- Path ( output_dir ) / c_test . name ,
502+ str ( concept_memmap_path ),
503+ str ( control_memmap_path ),
504+ concept_dir ,
439505 self .penalty ,
440506 backend = backend ,
441507 device = device
442508 )
443509 self .cav_fscores [c_test .name ] = fscore
444510 self .cav_weights [c_test .name ] = weight
445511 self .cavs_list .append (weight )
512+
513+ self ._cleanup_paths ([str (concept_memmap_path ), str (control_memmap_path )])
446514 else :
447515 futures = []
448516 with ProcessPoolExecutor (max_workers = num_processes ) as executor :
@@ -454,31 +522,40 @@ def train_concepts_pairs(self,
454522 c_control , num_samples = num_samples
455523 )
456524
457- # block the process to avoid too long queue
458- while True :
459- done = [f for (_ , f ) in futures if f .done ()]
460- for f in done :
461- f .result () # raises if worker failed
462-
463- pending = [f for (_ , f ) in futures if not f .done ()]
464- if len (pending ) < (max_pending + num_processes ):
465- break
525+ concept_dir = output_dir_path / c_test .name
526+ concept_dir .mkdir (parents = True , exist_ok = True )
527+ concept_memmap_path = concept_dir / "concept_embeddings.npy"
528+ control_memmap_path = concept_dir / "control_embeddings.npy"
529+ self ._save_tensor_npy (concept_memmap_path , concept_embeddings )
530+ self ._save_tensor_npy (control_memmap_path , control_embeddings )
466531
467- time .sleep (5 )
532+ # block the process to avoid too long queue
533+ futures = self ._wait_for_capacity (
534+ futures , capacity = (max_pending + num_processes ), sleep_s = 5
535+ )
468536
469537 future = executor .submit (
470538 _train ,
471- concept_embeddings . cpu ( ),
472- control_embeddings . cpu ( ),
473- Path ( output_dir ) / c_test . name ,
539+ str ( concept_memmap_path ),
540+ str ( control_memmap_path ),
541+ concept_dir ,
474542 self .penalty ,
475543 backend = backend ,
476544 device = device
477545 )
478546 logger .info ("Submitted CAV training for concept %s" , c_test .name )
479- futures .append ((c_test .name , future ))
547+ futures .append (
548+ (
549+ c_test .name ,
550+ future ,
551+ [str (concept_memmap_path ), str (control_memmap_path )],
552+ )
553+ )
480554
481- results = [(name , f .result ()) for name , f in futures ]
555+ results = []
556+ for name , fut , paths in futures :
557+ results .append ((name , fut .result ()))
558+ self ._cleanup_paths (paths )
482559 for name , (fscore , weight ) in results :
483560 self .cav_fscores [name ] = fscore
484561 self .cav_weights [name ] = weight
0 commit comments