Skip to content

Commit 7fab50f

Browse files
committed
minor fix
1 parent f9ecb6d commit 7fab50f

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

tpcav/cavs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def _train(
238238
239239
Requires set_control to have been called beforehand.
240240
"""
241-
assert backend in ["sklearn", "torch"]
241+
assert backend in ["sklearn", "torch"], "Backend has to be either sklearn or torch!"
242242

243243
output_dir = Path(output_dir)
244244

@@ -790,6 +790,7 @@ def run_tpcav(
790790
generate_html_report=True,
791791
html_report_fscore_thresh=0.9,
792792
seed=1001,
793+
backend='sklearn',
793794
):
794795
"""
795796
One-stop function to compute CAVs on motif concepts and bed concepts, compute AUC of motif concept f-scores after correction
@@ -877,13 +878,13 @@ def run_tpcav(
877878
cav_trainer.train_concepts_pairs(motif_concepts_pairs[nm],
878879
num_samples_for_cav,
879880
output_dir=str(output_path / f"cavs_{nm}_motifs/"),
880-
num_processes=p, max_pending=max_pending_jobs)
881+
num_processes=p, max_pending=max_pending_jobs, backend=backend)
881882
else:
882883
cav_trainer.set_control(motif_concept_builders[nm].control_concepts[0], num_samples=num_samples_for_cav)
883884
cav_trainer.train_concepts([c for c, _ in motif_concepts_pairs[nm]],
884885
num_samples_for_cav,
885886
output_dir=str(output_path / f"cavs_{nm}_motifs/"),
886-
num_processes=p, max_pending=max_pending_jobs)
887+
num_processes=p, max_pending=max_pending_jobs, backend=backend)
887888
if save_cav_trainer:
888889
torch.save(cav_trainer, str(output_path / f"cavs_{nm}_motifs/cav_trainer.pt"))
889890
motif_cav_trainers[nm] = cav_trainer
@@ -897,6 +898,7 @@ def run_tpcav(
897898
num_samples_for_cav,
898899
output_dir=str(output_path / f"cavs_bed_concepts/"),
899900
num_processes=p,
901+
backend=backend
900902
)
901903
if save_cav_trainer:
902904
torch.save(bed_cav_trainer, str(output_path / f"cavs_bed_concepts/cav_trainer.pt"))

0 commit comments

Comments
 (0)