Skip to content

Commit a652abe

Browse files
committed
improve html report
1 parent 2d43327 commit a652abe

3 files changed

Lines changed: 897 additions & 70 deletions

File tree

test/test_cav_trainer.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,33 @@ def test_run_tpcav_random_control(self):
118118
output_dir="data/test_run_tpcav_output/",
119119
)
120120

121+
random_regions_1 = helper.random_regions_dataframe(
122+
"data/hg38.analysisSet.fa.fai", 1024, 100, seed=1
123+
)
124+
random_regions_2 = helper.random_regions_dataframe(
125+
"data/hg38.analysisSet.fa.fai", 1024, 100, seed=2
126+
)
127+
128+
def pack_data_iters(df):
129+
seq_fasta_iter = helper.dataframe_to_fasta_iter(
130+
df, "data/hg38.analysisSet.fa", batch_size=8
131+
)
132+
seq_one_hot_iter = (
133+
helper.fasta_to_one_hot_sequences(seq_fasta)
134+
for seq_fasta in seq_fasta_iter
135+
)
136+
chrom_iter = helper.dataframe_to_chrom_tracks_iter(df, None, batch_size=8)
137+
return zip(
138+
seq_one_hot_iter,
139+
)
140+
141+
attributions = bed_cav_trainer.tpcav.layer_attributions(
142+
pack_data_iters(random_regions_1), pack_data_iters(random_regions_2)
143+
).cpu()
144+
121145
report.generate_tcav_html_report("data/test_html.html", motif_cav_trainers,
122146
extra_cav_trainers = {'repeats': bed_cav_trainer},
147+
attributions = [attributions, ] * 3,
123148
motif_file=motif_path, fscore_thresh=0.1)
124149

125150
def test_run_tpcav_consensus_random_control(self):

tpcav/cavs.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,16 @@ def plot_cavs_similaritiy_heatmap(
433433

434434
heatmap_bbox = cm.ax_heatmap.get_position()
435435
ax_logs = []
436+
log_ratios_by_attr = None
436437
if attributions is not None:
437-
attributions = attributions if isinstance(attributions, List) else [attributions, ]
438+
log_ratios_by_attr = {}
439+
attributions = (
440+
attributions
441+
if isinstance(attributions, (list, tuple))
442+
else [
443+
attributions,
444+
]
445+
)
438446
for i, attrs in enumerate(attributions):
439447
offset = 1 + i*0.2
440448
## plot log ratio plot
@@ -445,6 +453,7 @@ def plot_cavs_similaritiy_heatmap(
445453
self.tpcav_score_binary_log_ratio(cname, attrs)
446454
for cname in cavs_names_sorted
447455
]
456+
log_ratios_by_attr[i] = log_ratios_reordered
448457
sns.barplot(y=cavs_names_sorted, x=log_ratios_reordered, orient="y", ax=ax_log)
449458
# set color of bar by value
450459
for idx in range(len((ax_log.containers[0]))):
@@ -482,6 +491,20 @@ def plot_cavs_similaritiy_heatmap(
482491
ax_logo.axis('off')
483492

484493
plt.savefig(output_path, dpi=300, bbox_inches="tight")
494+
row_reordered_ind = list(cm.dendrogram_row.reordered_ind)
495+
col_reordered_ind = list(cm.dendrogram_col.reordered_ind)
496+
matrix_similarity_sorted = matrix_similarity[np.ix_(row_reordered_ind, col_reordered_ind)]
497+
return {
498+
"concept_names": cavs_names_pass,
499+
"matrix_similarity": matrix_similarity,
500+
"row_reordered_ind": row_reordered_ind,
501+
"col_reordered_ind": col_reordered_ind,
502+
"concept_names_sorted_rows": cavs_names_sorted,
503+
"concept_names_sorted_cols": cavs_names_sorted,
504+
"matrix_similarity_sorted": matrix_similarity_sorted,
505+
"log_ratios_by_attr": log_ratios_by_attr,
506+
"output_path": str(output_path),
507+
}
485508

486509
def seq_logo(key, motif_meme_file, ax, max_len=20):
487510
"plot a pwm logo"

0 commit comments

Comments
 (0)