1616import pandas as pd
1717import seaborn as sns
1818import torch
19+ from scipy import stats
1920from sklearn .linear_model import SGDClassifier
2021from sklearn .metrics import precision_recall_fscore_support
2122from sklearn .metrics .pairwise import cosine_similarity
@@ -527,9 +528,14 @@ def load_motifs_from_custom_motif(motif_file):
527528
528529 return motifs_dict
529530
def plot_reg(data, x, y, ax=None):
    """Draw a regression plot of data[y] against data[x] and annotate the fit.

    Renders a seaborn ``regplot`` on ``ax`` (seaborn creates a new axes when
    ``ax`` is None) and writes the R^2 and p-value from a
    ``scipy.stats.linregress`` fit into the upper-left corner of the axes.

    Parameters
    ----------
    data : pandas.DataFrame with columns ``x`` and ``y``.
    x, y : column names of the independent and dependent variables.
    ax : matplotlib axes to draw on, or None for a fresh one.

    Returns
    -------
    The ``LinregressResult`` from scipy (slope, intercept, rvalue, pvalue, ...).
    """
    axis = sns.regplot(data=data, x=x, y=y, ax=ax)
    fit = stats.linregress(data[x], data[y])
    annotation = f"R^2: {fit.rvalue ** 2:.4f}\nP value: {fit.pvalue}"
    axis.text(0.05, 0.9, annotation, transform=axis.transAxes)
    return fit
530536
531537def compute_motif_auc_fscore (num_motif_insertions : List [int ], cav_trainers : List [CavTrainer ], motif_file : Optional [str ] = None ,
532- motif_file_fmt : str = 'meme' ):
538+ motif_file_fmt : str = 'meme' , figure_path : Optional [ str ] = None ):
533539
534540 assert motif_file_fmt in ['meme' , 'consensus' ]
535541
@@ -544,7 +550,8 @@ def compute_auc_fscore(row):
544550
545551 cavs_fscores_df ["AUC_fscores" ] = cavs_fscores_df .apply (compute_auc_fscore , axis = 1 )
546552
547- # if motif instances are provided, fit linear regression curve to remove the dependency of f-scores on information content and motif lengthj
553+     # if motif instances are provided, fit a linear regression to remove the dependency of F-scores on information_content_GC (MEME motifs) or on average GC content (custom motifs)
554+ fig , axes = plt .subplots (nrows = 1 , ncols = 3 , sharey = True , figsize = (15 , 4 ))
548555 if motif_file is not None :
549556 if motif_file_fmt == 'meme' :
550557 motifs_dict = load_motifs_from_meme (motif_file )
@@ -559,6 +566,9 @@ def load_meme_motif_info(key):
559566 y_pred = model .predict (cavs_fscores_df [['information_content_GC' ,]].to_numpy ())
560567 residuals = cavs_fscores_df ['AUC_fscores' ].to_numpy () - y_pred .flatten ()
561568 cavs_fscores_df ['AUC_fscores_residual' ] = residuals
569+ plot_reg (data = cavs_fscores_df , x = 'information_content_GC' , y = 'AUC_fscores' , ax = axes [0 ])
570+ plot_reg (data = cavs_fscores_df , x = 'information_content' , y = 'AUC_fscores' , ax = axes [1 ])
571+ plot_reg (data = cavs_fscores_df , x = 'motif_len' , y = 'AUC_fscores' , ax = axes [2 ])
562572 else :
563573 motifs_dict = load_motifs_from_custom_motif (motif_file )
564574 def load_custom_motif_info (key ):
@@ -570,15 +580,21 @@ def load_custom_motif_info(key):
570580 cavs_fscores_df [['avg_len' , 'avg_gc' ]] = cavs_fscores_df .apply (lambda x : load_custom_motif_info (x ['concept' ]), axis = 1 , result_type = 'expand' )
571581
572582 model = LinearRegression ()
573- model .fit (cavs_fscores_df [['avg_len' , ' avg_gc' ]].to_numpy (), cavs_fscores_df ['AUC_fscores' ].to_numpy ()[:, np .newaxis ])
583+ model .fit (cavs_fscores_df [['avg_gc' , ]].to_numpy (), cavs_fscores_df ['AUC_fscores' ].to_numpy ()[:, np .newaxis ])
574584
575- y_pred = model .predict (cavs_fscores_df [['avg_len' , ' avg_gc' ]].to_numpy ())
585+ y_pred = model .predict (cavs_fscores_df [['avg_gc' , ]].to_numpy ())
576586 residuals = cavs_fscores_df ['AUC_fscores' ].to_numpy () - y_pred .flatten ()
577587 cavs_fscores_df ['AUC_fscores_residual' ] = residuals
588+ plot_reg (data = cavs_fscores_df , x = 'avg_len' , y = 'AUC_fscores' , ax = axes [0 ])
589+ plot_reg (data = cavs_fscores_df , x = 'avg_gc' , y = 'AUC_fscores' , ax = axes [1 ])
590+ axes [2 ].set_axis ('off' )
578591
579592 cavs_fscores_df .sort_values ("AUC_fscores_residual" , ascending = False , inplace = True )
580593 else :
581594 cavs_fscores_df .sort_values ("AUC_fscores" , ascending = False , inplace = True )
595+
596+ if figure_path is not None :
597+ plt .savefig (figure_path , dpi = 300 , bbox_inches = 'tight' )
582598
583599 return cavs_fscores_df
584600
0 commit comments