|
| 1 | +import os |
| 2 | +import os.path as op |
| 3 | +import csv |
| 4 | +import subprocess |
| 5 | +import sys |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import nibabel as nb |
| 9 | +import h5py |
| 10 | + |
| 11 | + |
def _make_nifti(data, affine=None):
    """Wrap *data* (cast to float32) in a Nifti1Image; identity affine by default."""
    img_affine = np.eye(4) if affine is None else affine
    return nb.Nifti1Image(data.astype(np.float32), img_affine)
| 16 | + |
| 17 | + |
| 18 | +def _ijk_value(i, j, k): |
| 19 | + return i * 100.0 + j * 10.0 + k * 1.0 |
| 20 | + |
| 21 | + |
def test_convoxel_cli_creates_expected_hdf5(tmp_path):
    """End-to-end check of the ``modelarrayio.voxels`` CLI.

    Builds a tiny synthetic cohort (a group mask plus two subjects, the
    second of which omits one voxel from its individual mask), runs the
    converter as a subprocess, and verifies the HDF5 layout:
    the ``voxels`` coordinate table, the per-scalar values dataset, the
    column names, and a spot-checked voxel value (NaN where masked out).
    """
    # Small synthetic volume with a sparse pattern of true voxels.
    shape = (5, 6, 7)
    group_mask = np.zeros(shape, dtype=bool)
    true_coords = [(0, 1, 1), (1, 2, 3), (2, 4, 5), (3, 0, 0), (4, 5, 6), (1, 1, 4), (2, 2, 2)]
    for i, j, k in true_coords:
        group_mask[i, j, k] = True

    # Save the group mask.
    group_mask_file = tmp_path / "group_mask.nii.gz"
    _make_nifti(group_mask.astype(np.uint8)).to_filename(group_mask_file)

    # Create two subjects with individual masks (subject 1 drops a voxel).
    subjects = []
    for sidx in range(2):
        # Scalar volume encodes f(i, j, k) plus a slight per-subject shift.
        scalar = np.zeros(shape, dtype=np.float32)
        for i, j, k in true_coords:
            scalar[i, j, k] = _ijk_value(i, j, k) + sidx

        # Individual mask: subject 1 omits one voxel.
        indiv_mask = group_mask.copy()
        if sidx == 1:
            indiv_mask[true_coords[1]] = False

        scalar_path = tmp_path / f"sub-{sidx + 1}_scalar.nii.gz"
        mask_path = tmp_path / f"sub-{sidx + 1}_mask.nii.gz"
        _make_nifti(scalar).to_filename(scalar_path)
        _make_nifti(indiv_mask.astype(np.uint8)).to_filename(mask_path)
        # Path.name is already a str; store filenames relative to tmp_path.
        subjects.append((scalar_path.name, mask_path.name))

    # Build the cohort CSV (paths relative to --relative-root).
    cohort_csv = tmp_path / "cohort.csv"
    with cohort_csv.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["scalar_name", "source_file", "source_mask_file"])
        writer.writeheader()
        for scalar_name, mask_name in subjects:
            writer.writerow({
                "scalar_name": "FA",
                "source_file": scalar_name,
                "source_mask_file": mask_name,
            })

    # Run the CLI via `python -m` to avoid PATH issues; cwd is tmp_path so
    # bare filenames resolve there.
    out_h5 = tmp_path / "out.h5"
    cmd = [
        sys.executable,
        "-m",
        "modelarrayio.voxels",
        "--group-mask-file", group_mask_file.name,
        "--cohort-file", cohort_csv.name,
        "--relative-root", str(tmp_path),
        "--output-hdf5", out_h5.name,
        "--backend", "hdf5",
        "--dtype", "float32",
        "--compression", "gzip",
        "--compression-level", "1",
        "--shuffle", "True",
        "--chunk-voxels", "0",
        "--target-chunk-mb", "1.0",
    ]
    proc = subprocess.run(
        cmd, cwd=str(tmp_path), env=os.environ.copy(), capture_output=True, text=True
    )
    assert proc.returncode == 0, f"convoxel failed: {proc.stdout}\n{proc.stderr}"
    assert op.exists(out_h5)

    # Validate HDF5 contents.
    with h5py.File(out_h5, "r") as h5:
        assert "voxels" in h5
        vox = np.array(h5["voxels"])  # stored as transposed table (3, N)
        assert vox.shape[0] == 3
        # np.nonzero enumerates voxels in C order: i, then j, then k.
        ijk = np.vstack(np.nonzero(group_mask))  # (3, N)
        assert vox.shape[1] == ijk.shape[1]

        # Coordinate ordering must match the nonzero order exactly.
        assert np.array_equal(vox, ijk)

        # Scalars dataset: one row per subject, one column per voxel.
        dset = h5["scalars/FA/values"]
        num_subjects, num_voxels = dset.shape
        assert num_subjects == 2
        assert num_voxels == ijk.shape[1]

        # Column names exist and match the subject count.
        grp = h5["scalars/FA"]
        assert "column_names" in grp
        colnames = [
            x.decode("utf-8") if isinstance(x, bytes) else str(x)
            for x in grp["column_names"][...]
        ]
        assert len(colnames) == 2

        # Spot-check a voxel mapping (pick the third voxel).
        vidx = 2
        i, j, k = (int(c) for c in ijk[:, vidx])
        expected_s0 = _ijk_value(i, j, k) + 0
        expected_s1 = _ijk_value(i, j, k) + 1
        v0 = float(dset[0, vidx])
        v1 = float(dset[1, vidx])
        assert np.isclose(v0, expected_s0, equal_nan=True)
        # Subject 1 omitted true_coords[1]; a masked-out voxel should
        # flatten to NaN in its row.
        if (i, j, k) == true_coords[1]:
            assert np.isnan(v1)
        else:
            assert np.isclose(v1, expected_s1, equal_nan=True)
| 135 | + |
| 136 | + |
0 commit comments