@@ -238,7 +238,7 @@ def _train(
238238
239239 Requires set_control to have been called beforehand.
240240 """
241- assert backend in ["sklearn" , "torch" ]
241+ assert backend in ["sklearn" , "torch" ], "Backend has to be either sklearn or torch!"
242242
243243 output_dir = Path (output_dir )
244244
@@ -790,6 +790,7 @@ def run_tpcav(
790790 generate_html_report = True ,
791791 html_report_fscore_thresh = 0.9 ,
792792 seed = 1001 ,
793+ backend = 'sklearn' ,
793794):
794795 """
795796 One-stop function to compute CAVs on motif concepts and bed concepts, compute AUC of motif concept f-scores after correction
@@ -877,13 +878,13 @@ def run_tpcav(
877878 cav_trainer .train_concepts_pairs (motif_concepts_pairs [nm ],
878879 num_samples_for_cav ,
879880 output_dir = str (output_path / f"cavs_{ nm } _motifs/" ),
880- num_processes = p , max_pending = max_pending_jobs )
881+ num_processes = p , max_pending = max_pending_jobs , backend = backend )
881882 else :
882883 cav_trainer .set_control (motif_concept_builders [nm ].control_concepts [0 ], num_samples = num_samples_for_cav )
883884 cav_trainer .train_concepts ([c for c , _ in motif_concepts_pairs [nm ]],
884885 num_samples_for_cav ,
885886 output_dir = str (output_path / f"cavs_{ nm } _motifs/" ),
886- num_processes = p , max_pending = max_pending_jobs )
887+ num_processes = p , max_pending = max_pending_jobs , backend = backend )
887888 if save_cav_trainer :
888889 torch .save (cav_trainer , str (output_path / f"cavs_{ nm } _motifs/cav_trainer.pt" ))
889890 motif_cav_trainers [nm ] = cav_trainer
@@ -897,6 +898,7 @@ def run_tpcav(
897898 num_samples_for_cav ,
898899 output_dir = str (output_path / f"cavs_bed_concepts/" ),
899900 num_processes = p ,
901+ backend = backend
900902 )
901903 if save_cav_trainer :
902904 torch .save (bed_cav_trainer , str (output_path / f"cavs_bed_concepts/cav_trainer.pt" ))
0 commit comments