Skip to content

Commit 2d43327

Browse files
committed
add html report
1 parent 8cfc153 commit 2d43327

3 files changed

Lines changed: 397 additions & 6 deletions

File tree

test/test_cav_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from Bio import motifs as Bio_motifs
77
from captum.attr import DeepLift
88

9-
from tpcav import helper, run_tpcav, utils
9+
from tpcav import helper, run_tpcav, utils, report
1010
from tpcav.cavs import CavTrainer
1111
from tpcav.concepts import ConceptBuilder
1212
from tpcav.tpcav_model import TPCAV, _abs_attribution_func
@@ -118,6 +118,10 @@ def test_run_tpcav_random_control(self):
118118
output_dir="data/test_run_tpcav_output/",
119119
)
120120

121+
report.generate_tcav_html_report("data/test_html.html", motif_cav_trainers,
122+
extra_cav_trainers = {'repeats': bed_cav_trainer},
123+
motif_file=motif_path, fscore_thresh=0.1)
124+
121125
def test_run_tpcav_consensus_random_control(self):
122126
motif_path = Path("data") / "custom_motifs_alan.tsv"
123127
genome_fasta = "data/hg38.analysisSet.fa"

tpcav/cavs.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def plot_cavs_similaritiy_heatmap(
419419
cm = sns.clustermap(
420420
matrix_similarity,
421421
xticklabels=False,
422-
yticklabels=False,
422+
yticklabels=False if attributions is not None else cavs_names_pass,
423423
cmap="bwr",
424424
vmin=-1,
425425
vmax=1,
@@ -462,15 +462,20 @@ def plot_cavs_similaritiy_heatmap(
462462

463463
# plot motif logo if provided meme file, try to look for pwm for every concept in the file
464464
if motif_meme_file is not None:
465-
ax_logs[-1].tick_params(
466-
axis="y", which="major", pad=cm.figure.get_size_inches()[0] * 0.2 * 72 # leave space for motif logos
467-
)
465+
if attributions is not None:
466+
ax_logs[-1].tick_params(
467+
axis="y", which="major", pad=cm.figure.get_size_inches()[0] * 0.3 * 72 # leave space for motif logos
468+
)
469+
else:
470+
cm.ax_heatmap.tick_params(
471+
axis="y", which="major", pad=cm.figure.get_size_inches()[0] * 0.3 * 72 # leave space for motif logos
472+
)
468473
gs_logo = gridspec.GridSpec(len(cavs_names_pass), 1)
469474

470475
logo_height = heatmap_bbox.height/len(cavs_names_pass)
471476
for i, (cav_key, g) in enumerate(zip(cavs_names_sorted[::-1], gs_logo)):
472477
ax_logo = plt.subplot(g)
473-
ax_logo.set_position([1+len(ax_logs)*0.2+0.01, heatmap_bbox.y0+i*logo_height, 0.2+0.01, logo_height])
478+
ax_logo.set_position([1+len(ax_logs)*0.2+0.01, heatmap_bbox.y0+i*logo_height, 0.3+0.01, logo_height])
474479
if cav_key is not None:
475480
seq_logo(cav_key, motif_meme_file=motif_meme_file, ax=ax_logo)
476481
else:

0 commit comments

Comments
 (0)