-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcompute_stats_xarray.py
More file actions
158 lines (123 loc) · 5.35 KB
/
compute_stats_xarray.py
File metadata and controls
158 lines (123 loc) · 5.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Compute pipeline variability statistics (IQR, STD, mean r, median r) from
xarray connectivity data using vectorised operations. Works for both ds_fm and oasis datasets.
All four statistics are computed across the iteration dimension and saved as a
NetCDF file.
Examples:
python compute_stats_xarray.py data/ds_fm25.nc --output-dir results
python compute_stats_xarray.py ds_oasis.nc --output-dir results --use-dask --n-workers 8
"""
import argparse
import time
from pathlib import Path
import xarray as xr
def _stats_path(output_dir, dataset_name):
return Path(output_dir) / f"{dataset_name}_stats.nc"
def compute_stats(ds):
    """Reduce the 'iteration' dimension of ``ds.functional_connectivity`` to four
    variability statistics and return them as a new Dataset.

    Parameters
    ----------
    ds : xr.Dataset
        Must contain a ``functional_connectivity`` variable with an
        ``iteration`` dimension.

    Returns
    -------
    xr.Dataset
        Variables ``iqr``, ``std``, ``mean_r`` and ``median_r``, each with the
        ``iteration`` dimension reduced away. Eagerly computed (no Dask graph).
    """
    fc = ds.functional_connectivity
    print(f"Computing statistics over {dict(ds.sizes)} ...")
    start = time.perf_counter()
    # One quantile pass yields both the IQR endpoints and the median.
    q = fc.quantile([0.25, 0.5, 0.75], dim="iteration")
    lower = q.sel(quantile=0.25, drop=True)
    upper = q.sel(quantile=0.75, drop=True)
    stats = xr.Dataset(
        {
            "iqr": upper - lower,
            "std": fc.std(dim="iteration", skipna=True),
            "mean_r": fc.mean(dim="iteration", skipna=True),
            "median_r": q.sel(quantile=0.5, drop=True),
        }
    )
    # Force computation if backed by Dask (a plain in-memory Dataset also
    # exposes .compute(), where it is a cheap no-op).
    if hasattr(stats, "compute"):
        stats = stats.compute()
    print(f"Computation complete in {time.perf_counter() - start:.1f}s.")
    return stats
def save_stats(stats_ds, output_dir, dataset_name):
    """Write *stats_ds* to ``<output_dir>/<dataset_name>_stats.nc`` and return the path.

    The output directory is created if it does not yet exist.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target = _stats_path(target_dir, dataset_name)
    stats_ds.to_netcdf(target)
    print(f"Saved statistics: {target} {dict(stats_ds.sizes)}")
    return target
def load_stats(output_dir, dataset_name):
    """Open and return the previously saved statistics Dataset.

    Raises
    ------
    FileNotFoundError
        If no statistics file exists for *dataset_name* in *output_dir*.
    """
    path = _stats_path(output_dir, dataset_name)
    if path.exists():
        stats_ds = xr.open_dataset(path)
        print(f"Loaded statistics: {path} {dict(stats_ds.sizes)}")
        return stats_ds
    raise FileNotFoundError(f"Statistics file not found: {path}")
def parse_args(argv=None):
    """Parse command-line arguments for the statistics pipeline.

    Parameters
    ----------
    argv : list[str] | None
        Argument list to parse. ``None`` (the default, and the previous
        behavior) parses ``sys.argv[1:]``; passing an explicit list makes
        the CLI unit-testable.

    Returns
    -------
    argparse.Namespace
        Attributes: ``input`` (Path), ``output_dir`` (Path),
        ``dataset_name`` (str | None), ``use_dask`` (bool), ``n_workers`` (int).
    """
    parser = argparse.ArgumentParser(
        description="Compute pipeline variability statistics and save to NetCDF.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("input", type=Path, help="Input NetCDF (.nc) file")
    parser.add_argument("--output-dir", type=Path, default=Path("results"),
                        help="Directory for output NetCDF (default: results)")
    parser.add_argument("--dataset-name", type=str, default=None,
                        help="Name prefix for output file (default: input stem)")
    parser.add_argument("--use-dask", action="store_true",
                        help="Use Dask for out-of-core computation (for large datasets)")
    parser.add_argument("--n-workers", type=int, default=4,
                        help="Number of Dask workers (default: 4, only with --use-dask)")
    return parser.parse_args(argv)
def main():
    """CLI entry point: load a dataset, compute statistics, save them, and
    run sanity checks on the result.

    Returns
    -------
    int
        0 on success (raises on any failure).

    Fix: the original leaked resources — the opened dataset and, with
    ``--use-dask``, the ``Client``/``LocalCluster`` were never closed, so
    worker processes could outlive the run. All are now released in a
    ``finally`` block.
    """
    args = parse_args()
    input_path = args.input
    if not input_path.exists():
        raise FileNotFoundError(f"Input file does not exist: {input_path}")
    dataset_name = args.dataset_name or input_path.stem
    print(f"Loading dataset from {input_path}")
    client = None
    cluster = None
    if args.use_dask:
        # Local import: dask is only required when --use-dask is given.
        from dask.distributed import Client, LocalCluster
        cluster = LocalCluster(
            n_workers=args.n_workers,
            threads_per_worker=4,
            processes=True,
            dashboard_address="localhost:0",
        )
        client = Client(cluster)
        print(f"Dask cluster: {args.n_workers} workers, dashboard: {client.dashboard_link}")
        # iteration kept whole (reduced over); other dims chunked automatically.
        ds = xr.open_dataset(
            input_path,
            chunks={"iteration": -1, "subject": "auto", "pipeline": "auto", "cell": "auto"},
        )
    else:
        ds = xr.open_dataset(input_path)
    try:
        subjects = ds.subject.values
        pipelines = ds.pipeline.values
        print(f"Found {len(subjects)} subjects and {len(pipelines)} pipelines")
        stats_ds = compute_stats(ds)
        output_path = save_stats(stats_ds, args.output_dir, dataset_name)
        # --- Sanity checks ---
        expected_vars = {"iqr", "std", "mean_r", "median_r"}
        actual_vars = set(stats_ds.data_vars)
        assert expected_vars == actual_vars, (
            f"Expected variables {expected_vars}, got {actual_vars}"
        )
        assert "subject" in stats_ds.dims, "Missing 'subject' dimension"
        assert "pipeline" in stats_ds.dims, "Missing 'pipeline' dimension"
        assert "cell" in stats_ds.dims, "Missing 'cell' dimension"
        assert "iteration" not in stats_ds.dims, (
            "'iteration' dim should have been reduced"
        )
        assert stats_ds.sizes["subject"] == len(subjects), "Subject count mismatch"
        assert stats_ds.sizes["pipeline"] == len(pipelines), "Pipeline count mismatch"
        # Check no all-NaN variables
        for var in expected_vars:
            assert not stats_ds[var].isnull().all(), f"Variable '{var}' is entirely NaN"
        # Round-trip: reload and verify shapes match
        reloaded = xr.open_dataset(output_path)
        assert dict(reloaded.sizes) == dict(stats_ds.sizes), (
            f"Round-trip size mismatch: {dict(reloaded.sizes)} vs {dict(stats_ds.sizes)}"
        )
        reloaded.close()
        print("All checks passed.")
        print("Done.")
        return 0
    finally:
        # Release the dataset and any Dask resources even on failure.
        ds.close()
        if client is not None:
            client.close()
        if cluster is not None:
            cluster.close()
# Script entry point: propagate main()'s return code as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())