Skip to content

Commit 5003135

Browse files
mvdocclaude
andcommitted
Switch participant dataset plot to compact dot matrix style
Replace filled square heatmap with colored dots for each dataset. Add legend at bottom with one column, left-aligned with y-axis labels. Use distinct colors per dataset (blue, orange, green) for clarity. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent fcaa7f2 commit 5003135

1 file changed

Lines changed: 27 additions & 23 deletions

File tree

scripts/qa/qa-plot-participant-datasets.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,34 +39,38 @@ def main():
3939
matrix[:, 1] = (df["budapest"] == "Yes").astype(int)
4040
matrix[:, 2] = (df["identity_decoding"] == "Yes").astype(int)
4141

42-
# Create figure with square cells
43-
fig, ax = plt.subplots(figsize=(2, n_participants * 0.25))
44-
45-
# Plot heatmap using pcolormesh for precise cell edges
46-
cmap = plt.cm.colors.ListedColormap(["white", "tab:blue"])
47-
ax.pcolormesh(
48-
matrix,
49-
cmap=cmap,
50-
vmin=0,
51-
vmax=1,
52-
edgecolors="lightgray",
53-
linewidth=0.3,
54-
)
55-
ax.set_aspect("equal")
56-
ax.invert_yaxis()
42+
# Create figure
43+
fig, ax = plt.subplots(figsize=(0.8, n_participants * 0.2))
44+
45+
# Colors for each dataset
46+
colors = ["tab:blue", "tab:orange", "tab:green"]
47+
labels = ["hyperface", "budapest", "identity decoding"]
48+
49+
# Plot dots for each dataset
50+
for j in range(3):
51+
for i in range(n_participants):
52+
if matrix[i, j] == 1:
53+
ax.scatter(j, i, color=colors[j], s=20, marker="o")
54+
else:
55+
ax.scatter(j, i, color="lightgray", s=20, marker="o", facecolors="none")
56+
57+
# Add legend handles
58+
for j in range(3):
59+
ax.scatter([], [], color=colors[j], s=20, marker="o", label=labels[j])
60+
61+
ax.set_xlim(-0.5, 2.5)
62+
ax.set_ylim(n_participants - 0.5, -0.5)
5763

5864
# Labels
59-
ax.set_xticks(np.arange(3) + 0.5)
60-
ax.set_xticklabels(
61-
["hyperface", "budapest", "identity\ndecoding"], rotation=45, ha="left"
62-
)
63-
ax.xaxis.tick_top()
64-
ax.set_yticks(np.arange(n_participants) + 0.5)
65+
ax.set_xticks([])
66+
ax.set_yticks(range(n_participants))
6567
ax.set_yticklabels(df["participant_id"].str.replace("sub-", ""))
6668

67-
ax.set_ylabel("Participant")
69+
ax.tick_params(axis="both", length=0)
70+
ax.spines[["top", "right", "bottom", "left"]].set_visible(False)
6871

69-
plt.tight_layout()
72+
# Legend at bottom, left-aligned with y-axis labels
73+
ax.legend(loc="upper left", bbox_to_anchor=(-0.6, -0.02), ncol=1, frameon=False)
7074

7175
# Save figure
7276
output_dir = config.paths.qa_base_dir / "figures"

0 commit comments

Comments
 (0)