Skip to content

Commit 8ae8ccc

Browse files
committed
add figure output to auc f-score compute function
1 parent ae39e7a commit 8ae8ccc

1 file changed

Lines changed: 20 additions & 4 deletions

File tree

tpcav/cavs.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pandas as pd
1717
import seaborn as sns
1818
import torch
19+
from scipy import stats
1920
from sklearn.linear_model import SGDClassifier
2021
from sklearn.metrics import precision_recall_fscore_support
2122
from sklearn.metrics.pairwise import cosine_similarity
@@ -527,9 +528,14 @@ def load_motifs_from_custom_motif(motif_file):
527528

528529
return motifs_dict
529530

531+
def plot_reg(data, x, y, ax=None):
    """Plot ``y`` against ``x`` with a fitted regression line and annotate the fit.

    Args:
        data: Table indexable by column name (``data[x]`` and ``data[y]``
            are passed to the regression); presumably a pandas DataFrame.
        x: Name of the column used as the independent variable.
        y: Name of the column used as the dependent variable.
        ax: Optional matplotlib axes forwarded to ``sns.regplot``; the axes
            returned by seaborn is the one that gets annotated.

    Returns:
        The result object of ``scipy.stats.linregress(data[x], data[y])``.
    """
    target_ax = sns.regplot(x=x, y=y, data=data, ax=ax)
    res = stats.linregress(data[x], data[y])

    # Anchor the label in axes-fraction coordinates so it stays in the
    # upper-left corner regardless of the data range.
    annotation = f"R^2: {res.rvalue**2:.4f}\nP value: {res.pvalue}"
    target_ax.text(0.05, 0.9, annotation, transform=target_ax.transAxes)

    return res
530536

531537
def compute_motif_auc_fscore(num_motif_insertions: List[int], cav_trainers: List[CavTrainer], motif_file: Optional[str] = None,
532-
motif_file_fmt: str = 'meme'):
538+
motif_file_fmt: str = 'meme', figure_path: Optional[str]=None):
533539

534540
assert motif_file_fmt in ['meme', 'consensus']
535541

@@ -544,7 +550,8 @@ def compute_auc_fscore(row):
544550

545551
cavs_fscores_df["AUC_fscores"] = cavs_fscores_df.apply(compute_auc_fscore, axis=1)
546552

547-
# if motif instances are provided, fit linear regression curve to remove the dependency of f-scores on information content and motif lengthj
553+
# if motif instances are provided, fit a linear regression to remove the dependence of the f-scores on information_content_GC (meme format) or on average motif GC content (consensus format); regressions against the motif properties are also plotted
554+
fig, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(15, 4))
548555
if motif_file is not None:
549556
if motif_file_fmt == 'meme':
550557
motifs_dict = load_motifs_from_meme(motif_file)
@@ -559,6 +566,9 @@ def load_meme_motif_info(key):
559566
y_pred = model.predict(cavs_fscores_df[['information_content_GC',]].to_numpy())
560567
residuals = cavs_fscores_df['AUC_fscores'].to_numpy() - y_pred.flatten()
561568
cavs_fscores_df['AUC_fscores_residual'] = residuals
569+
plot_reg(data=cavs_fscores_df, x='information_content_GC', y='AUC_fscores', ax=axes[0])
570+
plot_reg(data=cavs_fscores_df, x='information_content', y='AUC_fscores', ax=axes[1])
571+
plot_reg(data=cavs_fscores_df, x='motif_len', y='AUC_fscores', ax=axes[2])
562572
else:
563573
motifs_dict = load_motifs_from_custom_motif(motif_file)
564574
def load_custom_motif_info(key):
@@ -570,15 +580,21 @@ def load_custom_motif_info(key):
570580
cavs_fscores_df[['avg_len', 'avg_gc']] = cavs_fscores_df.apply(lambda x: load_custom_motif_info(x['concept']), axis=1, result_type='expand')
571581

572582
model = LinearRegression()
573-
model.fit(cavs_fscores_df[['avg_len', 'avg_gc']].to_numpy(), cavs_fscores_df['AUC_fscores'].to_numpy()[:, np.newaxis])
583+
model.fit(cavs_fscores_df[['avg_gc',]].to_numpy(), cavs_fscores_df['AUC_fscores'].to_numpy()[:, np.newaxis])
574584

575-
y_pred = model.predict(cavs_fscores_df[['avg_len', 'avg_gc']].to_numpy())
585+
y_pred = model.predict(cavs_fscores_df[['avg_gc',]].to_numpy())
576586
residuals = cavs_fscores_df['AUC_fscores'].to_numpy() - y_pred.flatten()
577587
cavs_fscores_df['AUC_fscores_residual'] = residuals
588+
plot_reg(data=cavs_fscores_df, x='avg_len', y='AUC_fscores', ax=axes[0])
589+
plot_reg(data=cavs_fscores_df, x='avg_gc', y='AUC_fscores', ax=axes[1])
590+
axes[2].set_axis('off')
578591

579592
cavs_fscores_df.sort_values("AUC_fscores_residual", ascending=False, inplace=True)
580593
else:
581594
cavs_fscores_df.sort_values("AUC_fscores", ascending=False, inplace=True)
595+
596+
if figure_path is not None:
597+
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
582598

583599
return cavs_fscores_df
584600

0 commit comments

Comments
 (0)