We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 51f2c6b commit 1793b20Copy full SHA for 1793b20
1 file changed
tpcav/cavs.py
@@ -384,7 +384,7 @@ def tpcav_score_all_concepts_log_ratio(
384
385
def plot_cavs_similaritiy_heatmap(
386
self,
387
- attributions: Optional[List[torch.Tensor]] = None,
+ attributions: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
388
concept_list: Optional[List[str]] = None,
389
fscore_thresh=0.8,
390
motif_meme_file: Optional[str] = None,
@@ -433,6 +433,7 @@ def plot_cavs_similaritiy_heatmap(
433
heatmap_bbox = cm.ax_heatmap.get_position()
434
ax_logs = []
435
if attributions is not None:
436
+ attributions = attributions if isinstance(attributions, List) else [attributions, ]
437
for i, attrs in enumerate(attributions):
438
offset = 1 + i*0.2
439
## plot log ratio plot
0 commit comments