@@ -433,8 +433,16 @@ def plot_cavs_similaritiy_heatmap(
433433
434434 heatmap_bbox = cm .ax_heatmap .get_position ()
435435 ax_logs = []
436+ log_ratios_by_attr = None
436437 if attributions is not None :
437- attributions = attributions if isinstance (attributions , List ) else [attributions , ]
438+ log_ratios_by_attr = {}
439+ attributions = (
440+ attributions
441+ if isinstance (attributions , (list , tuple ))
442+ else [
443+ attributions ,
444+ ]
445+ )
438446 for i , attrs in enumerate (attributions ):
439447 offset = 1 + i * 0.2
440448 ## plot log ratio plot
@@ -445,6 +453,7 @@ def plot_cavs_similaritiy_heatmap(
445453 self .tpcav_score_binary_log_ratio (cname , attrs )
446454 for cname in cavs_names_sorted
447455 ]
456+ log_ratios_by_attr [i ] = log_ratios_reordered
448457 sns .barplot (y = cavs_names_sorted , x = log_ratios_reordered , orient = "y" , ax = ax_log )
449458 # set color of bar by value
450459 for idx in range (len ((ax_log .containers [0 ]))):
@@ -482,6 +491,20 @@ def plot_cavs_similaritiy_heatmap(
482491 ax_logo .axis ('off' )
483492
484493 plt .savefig (output_path , dpi = 300 , bbox_inches = "tight" )
494+ row_reordered_ind = list (cm .dendrogram_row .reordered_ind )
495+ col_reordered_ind = list (cm .dendrogram_col .reordered_ind )
496+ matrix_similarity_sorted = matrix_similarity [np .ix_ (row_reordered_ind , col_reordered_ind )]
497+ return {
498+ "concept_names" : cavs_names_pass ,
499+ "matrix_similarity" : matrix_similarity ,
500+ "row_reordered_ind" : row_reordered_ind ,
501+ "col_reordered_ind" : col_reordered_ind ,
502+ "concept_names_sorted_rows" : cavs_names_sorted ,
503+ "concept_names_sorted_cols" : cavs_names_sorted ,
504+ "matrix_similarity_sorted" : matrix_similarity_sorted ,
505+ "log_ratios_by_attr" : log_ratios_by_attr ,
506+ "output_path" : str (output_path ),
507+ }
485508
486509def seq_logo (key , motif_meme_file , ax , max_len = 20 ):
487510 "plot a pwm logo"
0 commit comments