Skip to content

Commit 1793b20

Browse files
committed
improve api
1 parent 51f2c6b commit 1793b20

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

tpcav/cavs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def tpcav_score_all_concepts_log_ratio(
384384

385385
def plot_cavs_similaritiy_heatmap(
386386
self,
387-
attributions: Optional[List[torch.Tensor]] = None,
387+
attributions: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
388388
concept_list: Optional[List[str]] = None,
389389
fscore_thresh=0.8,
390390
motif_meme_file: Optional[str] = None,
@@ -433,6 +433,7 @@ def plot_cavs_similaritiy_heatmap(
433433
heatmap_bbox = cm.ax_heatmap.get_position()
434434
ax_logs = []
435435
if attributions is not None:
436+
attributions = attributions if isinstance(attributions, List) else [attributions, ]
436437
for i, attrs in enumerate(attributions):
437438
offset = 1 + i*0.2
438439
## plot log ratio plot

0 commit comments

Comments
 (0)