Skip to content

Commit 375e994

Browse files
committed
Update tiledb parallel
1 parent 544e782 commit 375e994

5 files changed

Lines changed: 313 additions & 3 deletions

File tree

.github/workflows/ci.yml

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,32 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
pull_request:
6+
7+
jobs:
8+
test:
9+
runs-on: ubuntu-latest
10+
strategy:
11+
fail-fast: false
12+
matrix:
13+
python-version: ["3.11", "3.12"]
14+
steps:
15+
- name: Checkout
16+
uses: actions/checkout@v4
17+
18+
- name: Set up Python
19+
uses: actions/setup-python@v5
20+
with:
21+
python-version: ${{ matrix.python-version }}
22+
23+
- name: Install dependencies
24+
run: |
25+
python -m pip install --upgrade pip
26+
pip install -e .
27+
pip install pytest
28+
29+
- name: Run tests
30+
run: pytest -q
31+
32+

modelarrayio/cifti.py

Lines changed: 34 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import os
33
from collections import defaultdict
44
import os.path as op
5+
from concurrent.futures import ThreadPoolExecutor, as_completed
56
import numpy as np
67
import nibabel as nb
78
import pandas as pd
@@ -101,7 +102,8 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
101102
tdb_compression_level=5,
102103
tdb_shuffle=True,
103104
tdb_tile_voxels=0,
104-
tdb_target_tile_mb=2.0):
105+
tdb_target_tile_mb=2.0,
106+
tdb_workers=None):
105107
"""
106108
Load all fixeldb data.
107109
Parameters
@@ -165,7 +167,8 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
165167
else:
166168
base_uri = op.join(relative_root, output_tdb)
167169
os.makedirs(base_uri, exist_ok=True)
168-
for scalar_name in scalars.keys():
170+
scalar_names = list(scalars.keys())
171+
for scalar_name in scalar_names:
169172
num_subjects = len(scalars[scalar_name])
170173
num_items = scalars[scalar_name][0].shape[0] if num_subjects > 0 else 0
171174
dataset_path = f'scalars/{scalar_name}/values'
@@ -182,8 +185,35 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
182185
target_tile_mb=tdb_target_tile_mb,
183186
sources_list=sources_lists[scalar_name],
184187
)
188+
189+
def _write_scalar_to_tdb(scalar_name):
190+
dataset_path = f'scalars/{scalar_name}/values'
185191
uri = op.join(base_uri, dataset_path)
186192
tdb_write_stripes(uri, scalars[scalar_name])
193+
194+
if not scalar_names:
195+
return 0
196+
197+
# Determine worker count: an explicit value takes precedence; otherwise fall back to the CPU count.
198+
worker_count = tdb_workers if isinstance(tdb_workers, int) and tdb_workers > 0 else None
199+
if worker_count is None:
200+
cpu_count = os.cpu_count() or 1
201+
worker_count = min(len(scalar_names), max(1, cpu_count))
202+
else:
203+
worker_count = min(len(scalar_names), worker_count)
204+
205+
if worker_count <= 1:
206+
for scalar_name in scalar_names:
207+
_write_scalar_to_tdb(scalar_name)
208+
else:
209+
desc = "TileDB scalars"
210+
with ThreadPoolExecutor(max_workers=worker_count) as executor:
211+
futures = {
212+
executor.submit(_write_scalar_to_tdb, scalar_name): scalar_name
213+
for scalar_name in scalar_names
214+
}
215+
for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
216+
future.result()
187217
return 0
188218

189219

@@ -221,7 +251,8 @@ def main():
221251
tdb_compression_level=args.tdb_compression_level,
222252
tdb_shuffle=args.tdb_shuffle,
223253
tdb_tile_voxels=args.tdb_tile_voxels,
224-
tdb_target_tile_mb=args.tdb_target_tile_mb)
254+
tdb_target_tile_mb=args.tdb_target_tile_mb,
255+
tdb_workers=args.tdb_workers)
225256
return status
226257

227258

modelarrayio/parser.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,14 @@ def add_tiledb_storage_args(parser):
123123
type=float,
124124
help="Target tile size in MiB when auto-computing item tile length. Default 2.0",
125125
default=2.0)
126+
parser.add_argument(
127+
"--tdb-workers", "--tdb_workers",
128+
type=int,
129+
help=(
130+
"Maximum number of TileDB write workers. Default 0 (auto, uses CPU count). "
131+
"Set to 1 to disable parallel writes."
132+
),
133+
default=0)
126134
return parser
127135

128136

tests/test_cifti_cli.py

Lines changed: 103 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,103 @@
1+
import os
2+
import os.path as op
3+
import csv
4+
import subprocess
5+
import sys
6+
7+
import numpy as np
8+
import nibabel as nb
9+
from nibabel.cifti2.cifti2_axes import ScalarAxis, BrainModelAxis
10+
import h5py
11+
12+
13+
def _make_synthetic_cifti_dscalar(mask_bool: np.ndarray, values: np.ndarray) -> nb.Cifti2Image:
14+
# Build axes: single scalar and a brain model from a volumetric mask
15+
scalar_axis = ScalarAxis(["synthetic"]) # one scalar map
16+
brain_axis = BrainModelAxis.from_mask(mask_bool)
17+
header = nb.cifti2.Cifti2Header.from_axes((scalar_axis, brain_axis))
18+
# Data must be 2D: (nmaps, ngrayordinates)
19+
data_2d = values.reshape(1, -1).astype(np.float32)
20+
return nb.Cifti2Image(data_2d, header=header)
21+
22+
23+
def test_concifti_cli_creates_expected_hdf5(tmp_path):
24+
# Create a small volumetric mask for brain model axis
25+
vol_shape = (3, 3, 3)
26+
mask = np.zeros(vol_shape, dtype=bool)
27+
true_vox = [(0, 0, 0), (0, 1, 2), (1, 1, 1), (2, 2, 0), (2, 1, 2)]
28+
for ijk in true_vox:
29+
mask[ijk] = True
30+
n_grayordinates = int(mask.sum())
31+
32+
# Create two subjects with simple sequences
33+
subjects = []
34+
for sidx in range(2):
35+
vals = np.arange(n_grayordinates, dtype=np.float32) + sidx
36+
img = _make_synthetic_cifti_dscalar(mask, vals)
37+
path = tmp_path / f"sub-{sidx+1}.dscalar.nii"
38+
img.to_filename(path)
39+
subjects.append(str(path.name))
40+
41+
# Build cohort CSV
42+
cohort_csv = tmp_path / "cohort_cifti.csv"
43+
with cohort_csv.open("w", newline="") as f:
44+
writer = csv.DictWriter(f, fieldnames=["scalar_name", "source_file"])
45+
writer.writeheader()
46+
for sname in subjects:
47+
writer.writerow({
48+
"scalar_name": "THICK",
49+
"source_file": sname,
50+
})
51+
52+
out_h5 = tmp_path / "out_cifti.h5"
53+
cmd = [
54+
sys.executable,
55+
"-m",
56+
"modelarrayio.cifti",
57+
"--cohort-file", str(cohort_csv.name),
58+
"--relative-root", str(tmp_path),
59+
"--output-hdf5", str(out_h5.name),
60+
"--backend", "hdf5",
61+
"--dtype", "float32",
62+
"--compression", "gzip",
63+
"--compression-level", "1",
64+
"--shuffle", "True",
65+
"--chunk-voxels", "0",
66+
"--target-chunk-mb", "1.0",
67+
]
68+
env = os.environ.copy()
69+
proc = subprocess.run(cmd, cwd=str(tmp_path), env=env, capture_output=True, text=True)
70+
assert proc.returncode == 0, f"concifti failed: {proc.stdout}\n{proc.stderr}"
71+
assert op.exists(out_h5)
72+
73+
# Validate HDF5 contents
74+
with h5py.File(out_h5, "r") as h5:
75+
assert "greyordinates" in h5
76+
grey = np.array(h5["greyordinates"]) # stored as transposed table (2, N)
77+
assert grey.shape[0] == 2 # vertex_id, structure_id
78+
n = grey.shape[1]
79+
assert n == n_grayordinates
80+
81+
# structure_names present
82+
g = h5["greyordinates"]
83+
assert "structure_names" in g.attrs
84+
struct_names = g.attrs["structure_names"]
85+
assert len(struct_names) >= 1
86+
87+
# Scalars dataset
88+
dset = h5["scalars/THICK/values"]
89+
num_subjects, num_items = dset.shape
90+
assert num_subjects == 2
91+
assert num_items == n_grayordinates
92+
93+
# Column names exist and match subjects count
94+
grp = h5["scalars/THICK"]
95+
assert "column_names" in grp
96+
colnames = list(map(lambda x: x.decode("utf-8") if isinstance(x, bytes) else str(x), grp["column_names"][...]))
97+
assert len(colnames) == 2
98+
99+
# Spot-check a couple values
100+
assert np.isclose(float(dset[0, 0]), 0.0)
101+
assert np.isclose(float(dset[1, 0]), 1.0)
102+
103+

tests/test_voxels_cli.py

Lines changed: 136 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,136 @@
1+
import os
2+
import os.path as op
3+
import csv
4+
import subprocess
5+
import sys
6+
7+
import numpy as np
8+
import nibabel as nb
9+
import h5py
10+
11+
12+
def _make_nifti(data, affine=None):
13+
if affine is None:
14+
affine = np.eye(4)
15+
return nb.Nifti1Image(data.astype(np.float32), affine)
16+
17+
18+
def _ijk_value(i, j, k):
19+
return i * 100.0 + j * 10.0 + k * 1.0
20+
21+
22+
def test_convoxel_cli_creates_expected_hdf5(tmp_path):
23+
# Small synthetic volume
24+
shape = (5, 6, 7)
25+
group_mask = np.zeros(shape, dtype=bool)
26+
# Create a sparse pattern of true voxels
27+
true_coords = [(0, 1, 1), (1, 2, 3), (2, 4, 5), (3, 0, 0), (4, 5, 6), (1, 1, 4), (2, 2, 2)]
28+
for (i, j, k) in true_coords:
29+
group_mask[i, j, k] = True
30+
31+
# Save group mask
32+
group_mask_img = _make_nifti(group_mask.astype(np.uint8))
33+
group_mask_file = tmp_path / "group_mask.nii.gz"
34+
group_mask_img.to_filename(group_mask_file)
35+
36+
# Create two subjects with individual masks (one drops a voxel)
37+
subjects = []
38+
for sidx in range(2):
39+
# Scalar volume encodes f(i,j,k)
40+
scalar = np.zeros(shape, dtype=np.float32)
41+
for (i, j, k) in true_coords:
42+
scalar[i, j, k] = _ijk_value(i, j, k) + sidx # slight per-subject shift
43+
44+
# Individual mask: subject 1 omits one voxel
45+
indiv_mask = group_mask.copy()
46+
if sidx == 1:
47+
omit = true_coords[1]
48+
indiv_mask[omit] = False
49+
50+
scalar_img = _make_nifti(scalar)
51+
mask_img = _make_nifti(indiv_mask.astype(np.uint8))
52+
53+
scalar_path = tmp_path / f"sub-{sidx+1}_scalar.nii.gz"
54+
mask_path = tmp_path / f"sub-{sidx+1}_mask.nii.gz"
55+
scalar_img.to_filename(scalar_path)
56+
mask_img.to_filename(mask_path)
57+
subjects.append((str(scalar_path.name), str(mask_path.name)))
58+
59+
# Build cohort CSV (relative paths)
60+
cohort_csv = tmp_path / "cohort.csv"
61+
with cohort_csv.open("w", newline="") as f:
62+
writer = csv.DictWriter(f, fieldnames=["scalar_name", "source_file", "source_mask_file"])
63+
writer.writeheader()
64+
for sidx, (scalar_name, mask_name) in enumerate(subjects):
65+
writer.writerow({
66+
"scalar_name": "FA",
67+
"source_file": scalar_name,
68+
"source_mask_file": mask_name,
69+
})
70+
71+
# Run CLI using module to avoid PATH issues
72+
out_h5 = tmp_path / "out.h5"
73+
cmd = [
74+
sys.executable,
75+
"-m",
76+
"modelarrayio.voxels",
77+
"--group-mask-file", str(group_mask_file.name),
78+
"--cohort-file", str(cohort_csv.name),
79+
"--relative-root", str(tmp_path),
80+
"--output-hdf5", str(out_h5.name),
81+
"--backend", "hdf5",
82+
"--dtype", "float32",
83+
"--compression", "gzip",
84+
"--compression-level", "1",
85+
"--shuffle", "True",
86+
"--chunk-voxels", "0",
87+
"--target-chunk-mb", "1.0",
88+
]
89+
env = os.environ.copy()
90+
proc = subprocess.run(cmd, cwd=str(tmp_path), env=env, capture_output=True, text=True)
91+
assert proc.returncode == 0, f"convoxel failed: {proc.stdout}\n{proc.stderr}"
92+
assert op.exists(out_h5)
93+
94+
# Validate HDF5 contents
95+
with h5py.File(out_h5, "r") as h5:
96+
assert "voxels" in h5
97+
vox = np.array(h5["voxels"]) # stored as transposed table (3, N)
98+
assert vox.shape[0] == 3
99+
ijk = np.vstack(np.nonzero(group_mask)) # (3, N) ordered by i, then j, then k
100+
assert vox.shape[1] == ijk.shape[1]
101+
102+
# Check ordering matches nonzero order (allow exact match)
103+
assert np.array_equal(vox, ijk)
104+
105+
# Scalars dataset
106+
dset = h5["scalars/FA/values"]
107+
num_subjects, num_voxels = dset.shape
108+
assert num_subjects == 2
109+
assert num_voxels == ijk.shape[1]
110+
111+
# Column names exist and match subjects count
112+
grp = h5["scalars/FA"]
113+
assert "column_names" in grp
114+
colnames = list(map(lambda x: x.decode("utf-8") if isinstance(x, bytes) else str(x), grp["column_names"][...]))
115+
assert len(colnames) == 2
116+
117+
# Spot-check a voxel mapping (pick the third voxel)
118+
vidx = 2
119+
i, j, k = int(ijk[0, vidx]), int(ijk[1, vidx]), int(ijk[2, vidx])
120+
expected_s0 = _ijk_value(i, j, k) + 0
121+
expected_s1 = _ijk_value(i, j, k) + 1
122+
# If subject 1 omitted that voxel, it should be NaN (masked out becomes NaN on flatten)
123+
v0 = float(dset[0, vidx])
124+
v1 = float(dset[1, vidx])
125+
assert np.isclose(v0, expected_s0, equal_nan=True)
126+
# Determine whether subject 1 omitted this voxel
127+
omitted = False
128+
omit = true_coords[1]
129+
if (i, j, k) == omit:
130+
omitted = True
131+
if omitted:
132+
assert np.isnan(v1)
133+
else:
134+
assert np.isclose(v1, expected_s1, equal_nan=True)
135+
136+

0 commit comments

Comments (0)