@@ -419,7 +419,7 @@ def plot_cavs_similaritiy_heatmap(
419419 cm = sns .clustermap (
420420 matrix_similarity ,
421421 xticklabels = False ,
422- yticklabels = False ,
422+ yticklabels = False if attributions is not None else cavs_names_pass ,
423423 cmap = "bwr" ,
424424 vmin = - 1 ,
425425 vmax = 1 ,
@@ -462,15 +462,20 @@ def plot_cavs_similaritiy_heatmap(
462462
463463 # plot motif logo if provided meme file, try to look for pwm for every concept in the file
464464 if motif_meme_file is not None :
465- ax_logs [- 1 ].tick_params (
466- axis = "y" , which = "major" , pad = cm .figure .get_size_inches ()[0 ] * 0.2 * 72 # leave space for motif logos
467- )
465+ if attributions is not None :
466+ ax_logs [- 1 ].tick_params (
467+ axis = "y" , which = "major" , pad = cm .figure .get_size_inches ()[0 ] * 0.3 * 72 # leave space for motif logos
468+ )
469+ else :
470+ cm .ax_heatmap .tick_params (
471+ axis = "y" , which = "major" , pad = cm .figure .get_size_inches ()[0 ] * 0.3 * 72 # leave space for motif logos
472+ )
468473 gs_logo = gridspec .GridSpec (len (cavs_names_pass ), 1 )
469474
470475 logo_height = heatmap_bbox .height / len (cavs_names_pass )
471476 for i , (cav_key , g ) in enumerate (zip (cavs_names_sorted [::- 1 ], gs_logo )):
472477 ax_logo = plt .subplot (g )
473- ax_logo .set_position ([1 + len (ax_logs )* 0.2 + 0.01 , heatmap_bbox .y0 + i * logo_height , 0.2 + 0.01 , logo_height ])
478+ ax_logo .set_position ([1 + len (ax_logs )* 0.2 + 0.01 , heatmap_bbox .y0 + i * logo_height , 0.3 + 0.01 , logo_height ])
474479 if cav_key is not None :
475480 seq_logo (cav_key , motif_meme_file = motif_meme_file , ax = ax_logo )
476481 else :
0 commit comments