@@ -792,7 +792,20 @@ def main(par: dict) -> pd.DataFrame:
792792 """
793793
794794 # Load data
795- print ("\n [1/5] Loading data..." )
795+ pathway_files = {}
796+ geneset_mapping = {
797+ 'geneset_hallmark_2020' : 'hallmark_2020' ,
798+ 'geneset_kegg_2021' : 'kegg_2021' ,
799+ 'geneset_reactome_2022' : 'reactome_2022' ,
800+ 'geneset_go_bp_2023' : 'go_bp_2023' ,
801+ 'geneset_bioplanet_2019' : 'bioplanet_2019' ,
802+ 'geneset_wikipathways_2019' : 'wikipathways_2019' ,
803+ }
804+
805+ for arg_name , geneset_name in geneset_mapping .items ():
806+ pathway_files [geneset_name ] = par [arg_name ]
807+
808+ par ['pathway_files' ] = pathway_files
796809 evaluation_data = ad .read_h5ad (par ['evaluation_data' ], backed = 'r' )
797810 all_genes = set (evaluation_data .var_names .tolist ())
798811 prediction = read_prediction (par )
@@ -846,21 +859,25 @@ def main(par: dict) -> pd.DataFrame:
846859 all_results .append (result_dict )
847860
848861
849- final_dict = {}
862+ detailed_dict = {}
850863 for result in all_results :
851864 geneset_name = result ['geneset_name' ]
852- final_dict [f'{ geneset_name } _gs_precision' ] = result ['precision' ]
853- final_dict [f'{ geneset_name } _gs_recall' ] = result ['recall' ]
854- final_dict [f'{ geneset_name } _gs_f1' ] = result ['f1' ]
855- final_dict [f'{ geneset_name } _gs_n_active' ] = result ['n_active_pathways' ]
865+ detailed_dict [f'{ geneset_name } _gs_precision' ] = result ['precision' ]
866+ detailed_dict [f'{ geneset_name } _gs_recall' ] = result ['recall' ]
867+ detailed_dict [f'{ geneset_name } _gs_f1' ] = result ['f1' ]
868+ detailed_dict [f'{ geneset_name } _gs_n_active' ] = result ['n_active_pathways' ]
856869
857870 # Calculate mean across all gene sets
871+ short_dict = {}
858872 if all_results :
859- final_dict ['gs_precision' ] = np .mean ([r ['precision' ] for r in all_results ])
860- final_dict ['gs_recall' ] = np .mean ([r ['recall' ] for r in all_results ])
861- final_dict ['gs_f1' ] = np .mean ([r ['f1' ] for r in all_results ])
862- final_dict ['gs_n_active' ] = np .mean ([r ['n_active_pathways' ] for r in all_results ])
863-
873+ short_dict ['gs_precision' ] = np .mean ([r ['precision' ] for r in all_results ])
874+ short_dict ['gs_recall' ] = np .mean ([r ['recall' ] for r in all_results ])
875+ short_dict ['gs_f1' ] = np .mean ([r ['f1' ] for r in all_results ])
876+ short_dict ['gs_n_active' ] = np .mean ([r ['n_active_pathways' ] for r in all_results ])
877+ if par ['output_detailed_metrics' ]:
878+ final_dict = {** short_dict , ** detailed_dict }
879+ else :
880+ final_dict = short_dict
864881 summary_df = pd .DataFrame ([final_dict ])
865882 print (summary_df )
866883 return summary_df
0 commit comments