Skip to content

Commit 77dea94

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/master'
* refs/remotes/origin/master: Fix accuracy statistics to report per-subject averages Add QA script for task accuracy visualization
2 parents 357435e + 4422073 commit 77dea94

3 files changed

Lines changed: 312 additions & 0 deletions

File tree

scripts/qa/qa-plot-accuracy.py

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
#!/usr/bin/env python
2+
"""Plot task accuracy per run for each participant.
3+
4+
Extracts accuracy from events.tsv files (visualmemory task only) and generates:
5+
- A bar chart showing mean accuracy per subject with individual run values as scatter
6+
- A text summary file with accuracy statistics
7+
8+
Outputs:
9+
- desc-accuracy_barplot.png: Bar chart of accuracy per subject
10+
- accuracy_summary.txt: Text summary of accuracy statistics
11+
12+
Usage:
13+
python scripts/qa/qa-plot-accuracy.py
14+
python scripts/qa/qa-plot-accuracy.py --subjects sub-sid000005 sub-sid000009
15+
"""
16+
17+
import re
18+
from pathlib import Path
19+
20+
import matplotlib.pyplot as plt
21+
import numpy as np
22+
import pandas as pd
23+
import seaborn as sns
24+
25+
from hyperface.qa import create_qa_argument_parser, discover_subjects, get_config
26+
27+
# Plot settings
28+
DPI = 300
29+
PRIMARY_COLOR = "steelblue"
30+
EDGE_COLOR = "darkslategray"
31+
SCATTER_COLOR = "darkred"
32+
33+
34+
def extract_accuracy_from_events(events_file: Path) -> int | None:
35+
"""Extract accuracy percentage from an events.tsv file.
36+
37+
Parameters
38+
----------
39+
events_file : Path
40+
Path to the events.tsv file.
41+
42+
Returns
43+
-------
44+
int or None
45+
Accuracy percentage (0-100), or None if not found.
46+
"""
47+
df = pd.read_csv(events_file, sep="\t")
48+
49+
# Look for accuracy_XX pattern in trial_type column
50+
for trial_type in df["trial_type"].values:
51+
if isinstance(trial_type, str) and trial_type.startswith("accuracy_"):
52+
match = re.match(r"accuracy_(\d+)", trial_type)
53+
if match:
54+
return int(match.group(1))
55+
56+
return None
57+
58+
59+
def collect_accuracy_data(
    data_dir: Path, subjects: list[str]
) -> dict[str, dict[str, int]]:
    """Gather per-run accuracy values for each subject.

    For every subject directory under ``data_dir``, visualmemory events files
    are located under ``ses-*/func`` and their accuracy is extracted.

    Parameters
    ----------
    data_dir : Path
        Root BIDS data directory.
    subjects : list[str]
        Subject IDs to process (visited in sorted order).

    Returns
    -------
    dict
        Mapping of subject ID to a ``{"ses-X_run-Y": accuracy}`` dict.
        Subjects whose directory is missing, or for which no accuracy value
        could be extracted, are omitted.
    """
    collected = {}

    for sub_id in sorted(subjects):
        sub_root = data_dir / sub_id
        if not sub_root.exists():
            continue

        per_run = {}
        candidates = sub_root.glob("ses-*/func/*_task-visualmemory_run-*_events.tsv")
        for tsv_path in sorted(candidates):
            # Session and run labels are parsed from the file name itself.
            session = re.search(r"ses-(\d+)", tsv_path.name)
            run = re.search(r"run-(\d+)", tsv_path.name)
            if session is None or run is None:
                continue

            score = extract_accuracy_from_events(tsv_path)
            if score is None:
                continue

            per_run[f"ses-{session.group(1)}_run-{run.group(1)}"] = score

        # Only keep subjects that produced at least one accuracy value.
        if per_run:
            collected[sub_id] = per_run

    return collected
107+
108+
109+
def plot_accuracy_figure(
    accuracy_data: dict[str, dict[str, int]], output_path: Path
) -> None:
    """Create a bar chart with scatter overlay showing accuracy per subject.

    Each subject gets one bar at the mean of its run accuracies, with the
    individual run values drawn on top as horizontally jittered points. The
    figure is written to ``output_path`` and then closed.

    Parameters
    ----------
    accuracy_data : dict
        Dictionary mapping subject IDs to dict of run -> accuracy.
    output_path : Path
        Path to save the figure.
    """
    # Prepare data for plotting (subjects in sorted order along the x axis)
    subjects = sorted(accuracy_data)
    all_run_values = [list(accuracy_data[subject].values()) for subject in subjects]
    mean_accuracies = [np.mean(values) for values in all_run_values]

    # Widen the figure when there are many subjects so tick labels stay legible
    fig, ax = plt.subplots(figsize=(max(12, len(subjects) * 0.5), 6))

    # Bar plot for mean accuracy
    x_positions = np.arange(len(subjects))
    ax.bar(
        x_positions,
        mean_accuracies,
        color=PRIMARY_COLOR,
        edgecolor=EDGE_COLOR,
        linewidth=1,
        alpha=0.7,
        label="Mean accuracy",
    )

    # Scatter plot for individual run values. A fixed-seed generator keeps the
    # jitter (and therefore the saved image) identical across invocations.
    rng = np.random.default_rng(0)
    for x_pos, values in zip(x_positions, all_run_values):
        jitter = rng.uniform(-0.15, 0.15, len(values))
        ax.scatter(
            x_pos + jitter,
            values,
            color=SCATTER_COLOR,
            s=50,
            zorder=5,
            alpha=0.8,
            edgecolors="white",
            linewidths=0.5,
        )

    # Empty scatter so "Individual runs" appears exactly once in the legend
    ax.scatter([], [], color=SCATTER_COLOR, s=50, label="Individual runs")

    # Styling
    ax.set_xticks(x_positions)
    # Shorten subject labels for readability
    short_labels = [s.replace("sub-sid", "s") for s in subjects]
    ax.set_xticklabels(short_labels, rotation=45, ha="right", fontsize=9)
    ax.set_ylabel("Accuracy (%)", fontsize=13)
    ax.set_xlabel("Subject", fontsize=13)
    ax.set_title(
        "Task Accuracy - Visual Memory", fontsize=15, fontweight="bold", pad=12
    )
    # Headroom above 100 so points at the ceiling are not clipped
    ax.set_ylim(0, 105)
    ax.axhline(y=100, color="gray", linestyle="--", alpha=0.5, linewidth=1)
    ax.legend(loc="lower right")
    ax.grid(True, axis="y", alpha=0.3)
    ax.set_axisbelow(True)
    sns.despine(ax=ax)

    plt.tight_layout()
    fig.savefig(output_path, dpi=DPI, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    print(f"Saved: {output_path}")
185+
186+
187+
def format_accuracy_summary(accuracy_data: dict[str, dict[str, int]]) -> str:
    """Format accuracy summary text.

    Summary statistics (mean, median, min, max) are computed over the
    per-subject mean accuracies, so every subject contributes equally
    regardless of how many runs it has.

    Parameters
    ----------
    accuracy_data : dict
        Dictionary mapping subject IDs to dict of run -> accuracy.

    Returns
    -------
    str
        Formatted summary string, or "No accuracy data found." when
        ``accuracy_data`` is empty.
    """
    if not accuracy_data:
        return "No accuracy data found."

    per_subject_runs = list(accuracy_data.values())
    subject_means = [np.mean(list(runs.values())) for runs in per_subject_runs]
    perfect_subjects = sum(
        all(v == 100 for v in runs.values()) for runs in per_subject_runs
    )
    n_subjects = len(accuracy_data)

    # Subjects may differ in how many runs have accuracy data; report the
    # range in that case rather than the count of whichever subject is first.
    run_counts = [len(runs) for runs in per_subject_runs]
    fewest, most = min(run_counts), max(run_counts)
    n_runs_per_subject = str(fewest) if fewest == most else f"{fewest}-{most}"

    # Compute each statistic once; it is used in both the table and the prose.
    mean_acc = np.mean(subject_means)
    median_acc = np.median(subject_means)
    min_acc = np.min(subject_means)
    max_acc = np.max(subject_means)

    lines = [
        "=" * 60,
        "ACCURACY SUMMARY - Visual Memory Task",
        "=" * 60,
        "",
        f"Number of subjects: {n_subjects}",
        f"Number of runs per subject: {n_runs_per_subject}",
        "",
        "Accuracy Statistics (per-subject averages):",
        f" Mean: {mean_acc:.1f}%",
        f" Median: {median_acc:.1f}%",
        f" Min: {min_acc:.1f}%",
        f" Max: {max_acc:.1f}%",
        f" Subjects with 100% accuracy (all runs): {perfect_subjects}/{n_subjects}",
        "",
        "Per-subject breakdown:",
    ]

    for subject in sorted(accuracy_data):
        run_str = ", ".join(
            f"{run}: {acc}%" for run, acc in sorted(accuracy_data[subject].items())
        )
        lines.append(f" {subject}: {run_str}")

    lines.extend(
        [
            "",
            "-" * 60,
            "",
            "Paper-ready text:",
            f" Participants achieved a mean accuracy of {mean_acc:.1f}% "
            f"(median {median_acc:.1f}%, min {min_acc:.1f}%, "
            f"max {max_acc:.1f}%) on the visual memory task. "
            f"{perfect_subjects} out of {n_subjects} participants achieved "
            f"100% accuracy across all runs.",
            "",
        ]
    )

    return "\n".join(lines)
259+
260+
261+
def main():
    """Run the accuracy QA: collect data, save the figure and text summary.

    Returns 1 when no accuracy data is found for any subject, 0 otherwise.
    """
    args = create_qa_argument_parser(
        description="Plot task accuracy per run for each participant.",
        include_subjects=True,
    ).parse_args()

    # Resolve input and output locations from the QA configuration
    paths = get_config(config_path=args.config, data_dir=args.data_dir).paths
    raw_root = paths.data_dir
    out_root = paths.accuracy_dir

    # Discover subjects from raw data directory
    subject_ids = discover_subjects(raw_root, args.subjects)
    print(f"Processing {len(subject_ids)} subjects...")

    results = collect_accuracy_data(raw_root, subject_ids)
    if not results:
        print("No accuracy data found in events.tsv files.")
        return 1

    print(f"Found accuracy data for {len(results)} subjects")

    # Figure goes under <accuracy_dir>/figures, created on demand
    fig_root = out_root / "figures"
    fig_root.mkdir(parents=True, exist_ok=True)
    plot_accuracy_figure(results, fig_root / "desc-accuracy_barplot.png")

    # Echo the text summary to stdout and persist it next to the figures dir
    report = format_accuracy_summary(results)
    print(report)

    report_file = out_root / "accuracy_summary.txt"
    report_file.write_text(report)
    print(f"Saved: {report_file}")

    return 0
302+
303+
304+
if __name__ == "__main__":
305+
raise SystemExit(main())

src/hyperface/assets/qa_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ directories:
2121
motion: "motion"
2222
isc: "isc"
2323
stimuli: "stimuli"
24+
accuracy: "accuracy"
2425

2526
# Stimuli labels directory (under derivatives_dir)
2627
stimuli_labels: "stimuli/labels"

src/hyperface/qa/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class QAPaths:
3333
Inter-subject correlation output directory.
3434
stimuli_dir : Path
3535
Stimuli QA output directory.
36+
accuracy_dir : Path
37+
Task accuracy QA output directory.
3638
stimuli_labels_dir : Path
3739
Stimuli labels input directory.
3840
"""
@@ -46,6 +48,7 @@ class QAPaths:
4648
motion_dir: Path
4749
isc_dir: Path
4850
stimuli_dir: Path
51+
accuracy_dir: Path
4952
stimuli_labels_dir: Path
5053

5154
@classmethod
@@ -121,11 +124,13 @@ def from_config(cls, config: dict, base_dir: Path | None = None) -> "QAPaths":
121124
motion_subdir = qa_config.get("motion", "motion")
122125
isc_subdir = qa_config.get("isc", "isc")
123126
stimuli_subdir = qa_config.get("stimuli", "stimuli")
127+
accuracy_subdir = qa_config.get("accuracy", "accuracy")
124128
else:
125129
tsnr_subdir = "tsnr"
126130
motion_subdir = "motion"
127131
isc_subdir = "isc"
128132
stimuli_subdir = "stimuli"
133+
accuracy_subdir = "accuracy"
129134

130135
# Stimuli labels directory (input)
131136
stimuli_labels_subdir = dirs.get("stimuli_labels", "stimuli/labels")
@@ -140,6 +145,7 @@ def from_config(cls, config: dict, base_dir: Path | None = None) -> "QAPaths":
140145
motion_dir=(qa_base_dir / motion_subdir).resolve(),
141146
isc_dir=(qa_base_dir / isc_subdir).resolve(),
142147
stimuli_dir=(qa_base_dir / stimuli_subdir).resolve(),
148+
accuracy_dir=(qa_base_dir / accuracy_subdir).resolve(),
143149
stimuli_labels_dir=(derivatives_dir / stimuli_labels_subdir).resolve(),
144150
)
145151

0 commit comments

Comments
 (0)