-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathplot_hexbin.py
More file actions
249 lines (202 loc) · 9.84 KB
/
plot_hexbin.py
File metadata and controls
249 lines (202 loc) · 9.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""
Hexbin plot of IQR variability vs mean (or median) correlation strength, per edge.
Examples:
# From raw connectivity data (small datasets)
python plot_hexbin.py data/ds_fm25_2.nc --all-pipelines --all-cells
# From pre-computed stats NetCDF (large datasets like OASIS)
python plot_hexbin.py --stats-nc results/ds_oasis_stats.nc --all-pipelines --all-cells
"""
import argparse
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from utils import slugify
# Default ``mincnt`` passed to matplotlib's ``Axes.hexbin``: hexagons holding fewer
# points than this are not drawn. Overridable via ``--hexbin-mincnt``.
DEFAULT_HEXBIN_MINCNT = 1
def compute_edge_stats(ds, pipeline_name, all_cells=False):
    """Compute mean r and IQR per edge for one pipeline.

    Both statistics are taken over the ``iteration`` dimension of
    ``ds.functional_connectivity`` after selecting ``pipeline_name``.

    all_cells=False (default): per-subject mean r and IQR are averaged across
        subjects; returns two arrays of shape (n_cells,).
    all_cells=True: one point per (subject, cell) pair; returns two flattened
        arrays of length n_subjects * n_cells.

    NaN entries are preserved and filtered at plot time.

    Returns
    -------
    (mean_r, iqr) : tuple of numpy.ndarray
    """
    fc = ds.functional_connectivity.sel(pipeline=pipeline_name)
    with warnings.catch_warnings():
        # All-NaN cells are expected (they are dropped at plot time), so the
        # resulting numpy RuntimeWarnings are just noise.
        warnings.filterwarnings("ignore", message="All-NaN slice", category=RuntimeWarning)
        warnings.filterwarnings("ignore", message="Mean of empty slice", category=RuntimeWarning)
        mean_r_by_subject = fc.mean(dim="iteration")
        # Requesting both quartiles in one call partitions the data once
        # instead of twice; the values are identical to two separate calls.
        quartiles = fc.quantile([0.25, 0.75], dim="iteration")
        iqr_by_subject = quartiles.sel(quantile=0.75) - quartiles.sel(quantile=0.25)
        if all_cells:
            return mean_r_by_subject.values.ravel(), iqr_by_subject.values.ravel()
        return (
            mean_r_by_subject.mean(dim="subject").values,
            iqr_by_subject.mean(dim="subject").values,
        )
def retrieve_stats(stats_ds, pipeline_name, x_var="mean_r", all_cells=False):
    """Fetch ``(x, iqr)`` for one pipeline from a pre-computed stats dataset.

    ``x_var`` names the x-axis variable (e.g. "mean_r" or "median_r"); the
    second element always comes from the ``iqr`` variable. ``stats_ds`` is
    expected to have dims ``(subject, pipeline, cell)`` as written by
    compute_stats_xarray.py.

    With ``all_cells=True`` every (subject, cell) value is returned flattened;
    otherwise each variable is averaged over ``subject``.
    """
    def _pick(var_name):
        selected = stats_ds[var_name].sel(pipeline=pipeline_name)
        if all_cells:
            return selected.values.ravel()
        return selected.mean(dim="subject").values

    return _pick(x_var), _pick("iqr")
def _draw_hexbin(ax, mean_r, iqr, pipeline, mincnt=DEFAULT_HEXBIN_MINCNT,
                 gridsize=80, xlim=(-1, 1), ylim=(0, 2)):
    """Draw a single hexbin panel of ``mean_r`` (x) vs ``iqr`` (y) on ``ax``.

    Non-finite (x, y) pairs are dropped before binning. The hexagon grid is
    laid over the fixed axis window ``xlim`` x ``ylim`` rather than over the
    data range. Hexagons therefore have the same size in every panel, and
    counts are comparable between pipelines.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; hidden when there is nothing to draw.
    mean_r, iqr : array-like
        Equal-length 1-D sequences (``mean_r`` may also hold median r).
    pipeline : str
        Panel title and log label.
    mincnt : int
        Hexagons with fewer points than this are not drawn.
    gridsize : int
        Number of hexagons across the x-range (default 80).
    xlim, ylim : (float, float)
        Axis limits and binning extent (defaults: r in [-1, 1], IQR in [0, 2]).

    Returns
    -------
    PolyCollection or None
        The hexbin artist, or None if no finite pairs remain.
    """
    mean_r = np.asarray(mean_r, dtype=float)
    iqr = np.asarray(iqr, dtype=float)
    mask = np.isfinite(mean_r) & np.isfinite(iqr)
    x, y = mean_r[mask], iqr[mask]
    if x.size == 0:
        ax.set_visible(False)
        return None
    print(f" Plotting {x.size:,} valid points for {pipeline} (hexbin mincnt={mincnt})")
    hb = ax.hexbin(
        x, y, gridsize=gridsize, cmap="viridis", mincnt=mincnt,
        extent=(xlim[0], xlim[1], ylim[0], ylim[1]),
    )
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)  # IQR
    ax.set_title(pipeline, fontsize=9)
    ax.spines[["top", "right"]].set_visible(False)
    return hb
def _plot_single_vs_iqr(x, iqr, pipeline, n_subjects, output_dir, suffix, mincnt,
                        x_label, file_stem):
    """Draw one hexbin of ``x`` vs IQR, save it as PNG and return the saved path.

    The file is written to ``output_dir / f"{file_stem}_{slug}{suffix}.png"``;
    ``output_dir`` is created if missing. Returns None (nothing saved) when no
    finite (x, iqr) pair exists. The figure is closed even if saving fails.
    """
    output_dir = Path(output_dir)
    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        hb = _draw_hexbin(ax, x, iqr, pipeline, mincnt=mincnt)
        if hb is None:
            print(f" No valid edges for {pipeline}, skipping.")
            return None
        cb = fig.colorbar(hb, ax=ax, shrink=0.8, label="Count")
        # Force the display threshold to appear as a tick on the colorbar.
        cb.set_ticks([t for t in sorted({mincnt, *cb.get_ticks()}) if cb.vmin <= t <= cb.vmax])
        ax.set_xlabel(x_label)
        ax.set_ylabel("IQR")
        ax.set_title(f"{pipeline} ({n_subjects} subjects)", fontsize=12)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{file_stem}_{slugify(pipeline)}{suffix}.png"
        fig.savefig(output_path, dpi=200, bbox_inches="tight", facecolor="white")
    finally:
        plt.close(fig)
    print(f" Saved: {output_path}")
    return output_path


def plot_mean_vs_iqr(mean_r, iqr, pipeline, n_subjects, output_dir, suffix="", mincnt=DEFAULT_HEXBIN_MINCNT):
    """Create a hexbin density plot of mean r vs IQR and save it to disk.

    Saved as ``mean_vs_iqr_<slug(pipeline)><suffix>.png`` in ``output_dir``.
    Returns the saved path, or None when there are no finite points.
    """
    return _plot_single_vs_iqr(
        mean_r, iqr, pipeline, n_subjects, output_dir, suffix, mincnt,
        x_label="Mean correlation (r)", file_stem="mean_vs_iqr",
    )


def plot_median_vs_iqr(median_r, iqr, pipeline, n_subjects, output_dir, suffix="", mincnt=DEFAULT_HEXBIN_MINCNT):
    """Create a hexbin density plot of median r vs IQR and save it to disk.

    Saved as ``median_vs_iqr_<slug(pipeline)><suffix>.png`` in ``output_dir``.
    Returns the saved path, or None when there are no finite points.
    """
    return _plot_single_vs_iqr(
        median_r, iqr, pipeline, n_subjects, output_dir, suffix, mincnt,
        x_label="Median correlation (r)", file_stem="median_vs_iqr",
    )
def plot_combined(results, n_subjects, output_dir, suffix="", x_label="Mean r",
                  mincnt=DEFAULT_HEXBIN_MINCNT, ncols=5):
    """Plot all pipelines side by side in a grid that shares one colour scale.

    Parameters
    ----------
    results : sequence of (pipeline, x, iqr)
        One entry per pipeline; ``x`` and ``iqr`` are equal-length 1-D arrays.
    n_subjects : int
        Subject count shown in the figure title.
    output_dir : path-like
        Directory for the PNG (created if missing).
    suffix : str
        Appended to the output file stem.
    x_label : str
        X-axis label; its first word ("Mean"/"Median") also names the file.
    mincnt : int
        Hexagons with fewer points are not drawn; also forced onto the colorbar ticks.
    ncols : int
        Maximum number of panels per row (default 5).

    Returns
    -------
    pathlib.Path or None
        Path of the saved figure, or None when ``results`` is empty.
    """
    n = len(results)
    if n == 0:
        print(" No pipelines to plot, skipping combined figure.")
        return None
    output_dir = Path(output_dir)
    # Don't reserve empty columns when there are fewer panels than ncols.
    ncols = max(1, min(ncols, n))
    nrows = int(np.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3.5 * nrows), squeeze=False)

    hexbins = []
    for idx, (pipeline, x, iqr) in enumerate(results):
        row, col = divmod(idx, ncols)
        ax = axes[row, col]
        hb = _draw_hexbin(ax, x, iqr, pipeline, mincnt=mincnt)
        if hb is not None:
            hexbins.append(hb)
        if col == 0:
            ax.set_ylabel("IQR")
        # Label every panel that has no panel beneath it; this includes the
        # row above an incomplete last row.
        if idx + ncols >= n:
            ax.set_xlabel(x_label)

    # Hide unused panels
    for idx in range(n, nrows * ncols):
        row, col = divmod(idx, ncols)
        axes[row, col].set_visible(False)

    # Each hexbin autoscales its own colour limits. Put all panels on a common
    # scale so the single shared colorbar is valid for every panel.
    counts = []
    for hb in hexbins:
        arr = hb.get_array()
        if arr is None:
            continue
        finite = np.ma.compressed(np.ma.masked_invalid(arr))
        if finite.size:
            counts.append(finite)
    if counts:
        vmin = min(c.min() for c in counts)
        vmax = max(c.max() for c in counts)
        for hb in hexbins:
            hb.set_clim(vmin, vmax)

    x_short = x_label.split()[0]  # "Mean" or "Median"
    fig.suptitle(f"IQR vs {x_short} r ({n_subjects} subjects)",
                 fontsize=14, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 0.92, 0.95])
    if hexbins:
        cbar_ax = fig.add_axes([0.93, 0.08, 0.015, 0.82])
        cb = fig.colorbar(hexbins[-1], cax=cbar_ax, label="Count")
        cb.set_ticks([t for t in sorted({mincnt, *cb.get_ticks()}) if cb.vmin <= t <= cb.vmax])

    output_dir.mkdir(parents=True, exist_ok=True)
    tag = slugify(x_short.lower())
    output_path = output_dir / f"ALL_{tag}_vs_iqr{suffix}.png"
    try:
        fig.savefig(output_path, dpi=200, bbox_inches="tight", facecolor="white")
    finally:
        plt.close(fig)
    print(f" Saved combined: {output_path}")
    return output_path
def _build_parser():
    """Return the argument parser for this script's command-line interface."""
    parser = argparse.ArgumentParser(
        description="Hexbin scatterplot: IQR variability vs mean correlation per edge.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    add = parser.add_argument
    add("input", type=Path, nargs="?", default=None,
        help="Input raw connectivity NetCDF (not needed with --stats-nc)")
    add("--stats-nc", type=Path, default=None,
        help="Pre-computed stats NetCDF from compute_stats_xarray.py")
    add("--pipeline", type=str, default=None, help="Pipeline name (default: first)")
    add("--all-pipelines", action="store_true", help="Iterate over all pipelines")
    add("--all-cells", action="store_true",
        help="Plot one point per (subject, cell) instead of averaging across subjects")
    add("--median", action="store_true",
        help="Plot median r vs IQR instead of mean r vs IQR (requires --stats-nc)")
    add("--save-all-plots", action="store_true",
        help="Also save individual per-pipeline plots; otherwise only save the combined plot")
    add("--hexbin-mincnt", type=int, default=DEFAULT_HEXBIN_MINCNT,
        help="Minimum count per occupied hexbin to display (default: 1)")
    add("--output-dir", type=Path, default=Path("plots_exploratory"))
    return parser


def _select_pipelines(parser, available, args):
    """Return the pipeline names to process; exits via ``parser.error`` on a bad name."""
    if args.all_pipelines:
        return list(available)
    if not args.pipeline:
        if not available:
            parser.error("Dataset contains no pipelines.")
        return [available[0]]
    if args.pipeline not in available:
        parser.error(
            f"Unknown pipeline {args.pipeline!r}; available: {', '.join(available)}"
        )
    return [args.pipeline]


def _collect_results(pipeline_names, load_pair, n_subjects, args, plot_fn, suffix,
                     source_tag=""):
    """Load ``(x, iqr)`` per pipeline, optionally save single plots, return result tuples."""
    # Individual plots are always saved when only one pipeline is processed,
    # because no combined figure is produced in that case.
    save_individual = args.save_all_plots or len(pipeline_names) == 1
    results = []
    for pipeline_name in pipeline_names:
        print(f"Processing: {pipeline_name} ({n_subjects} subjects{source_tag})")
        x, iqr = load_pair(pipeline_name)
        if save_individual:
            plot_fn(
                x, iqr, pipeline_name, n_subjects, args.output_dir, suffix,
                mincnt=args.hexbin_mincnt,
            )
        results.append((pipeline_name, x, iqr))
    return results


def main(argv=None):
    """Command-line entry point; returns the process exit status.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour) and can be
    passed explicitly for programmatic use.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    if args.stats_nc is None and args.input is None:
        parser.error("Provide either a raw NetCDF input file or --stats-nc.")
    if args.median and args.stats_nc is None:
        # The raw-data path only computes mean r. Without this check, mean
        # values would be plotted under "median" labels and filenames.
        parser.error("--median requires --stats-nc (raw input only yields mean r).")

    suffix = "_all_cells" if args.all_cells else ""
    x_var = "median_r" if args.median else "mean_r"
    plot_fn = plot_median_vs_iqr if args.median else plot_mean_vs_iqr

    # The datasets are opened as context managers so their file handles are released.
    # All values are materialised as numpy arrays inside the block.
    if args.stats_nc is not None:
        print(f"Loading pre-computed stats from {args.stats_nc}")
        with xr.open_dataset(args.stats_nc) as stats_ds:
            available = [str(p) for p in stats_ds.pipeline.values]
            n_subjects = stats_ds.sizes["subject"]
            results = _collect_results(
                _select_pipelines(parser, available, args),
                lambda name: retrieve_stats(
                    stats_ds, name, x_var=x_var, all_cells=args.all_cells
                ),
                n_subjects, args, plot_fn, suffix, source_tag=", from stats NC",
            )
    else:
        with xr.open_dataset(args.input) as ds:
            available = [str(p) for p in ds.pipeline.values]
            n_subjects = ds.sizes["subject"]
            results = _collect_results(
                _select_pipelines(parser, available, args),
                lambda name: compute_edge_stats(ds, name, args.all_cells),
                n_subjects, args, plot_mean_vs_iqr, suffix,
            )

    if len(results) > 1:
        x_label = "Median r" if args.median else "Mean r"
        plot_combined(
            results, n_subjects, args.output_dir, suffix,
            x_label=x_label, mincnt=args.hexbin_mincnt,
        )
    print("Done.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())