Skip to content

Commit 9415434

Browse files
committed
add save_cav_trainer option
1 parent 98d8a72 commit 9415434

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

tpcav/cavs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,7 @@ def run_tpcav(
604604
num_pc: Union[str,int]='full',
605605
p=1,
606606
max_pending_jobs=4,
607+
save_cav_trainer=True
607608
):
608609
"""
609610
One-stop function to compute CAVs on motif concepts and bed concepts, compute AUC of motif concept f-scores after correction
@@ -696,6 +697,8 @@ def run_tpcav(
696697
num_samples_for_cav,
697698
output_dir=str(output_path / f"cavs_{nm}_motifs/"),
698699
num_processes=p, max_pending=max_pending_jobs)
700+
if save_cav_trainer:
701+
torch.save(cav_trainer, str(output_path / f"cavs_{nm}_motifs/cav_trainer.pt"))
699702
motif_cav_trainers[nm] = cav_trainer
700703
if bed_builder is not None:
701704
bed_cav_trainer = CavTrainer(tpcav_model, penalty="l2")
@@ -708,6 +711,8 @@ def run_tpcav(
708711
output_dir=str(output_path / f"cavs_bed_concepts/"),
709712
num_processes=p,
710713
)
714+
if save_cav_trainer:
715+
torch.save(bed_cav_trainer, str(output_path / f"cavs_bed_concepts/cav_trainer.pt"))
711716
else:
712717
bed_cav_trainer = None
713718

0 commit comments

Comments
 (0)