-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_cifti_utils.py
More file actions
138 lines (96 loc) · 4.69 KB
/
test_cifti_utils.py
File metadata and controls
138 lines (96 loc) · 4.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Unit tests for CIFTI utility helpers."""
from __future__ import annotations
import numpy as np
import pytest
from nibabel.cifti2.cifti2_axes import ScalarAxis
from utils import make_dscalar, make_parcels_axis, make_pconn, make_pscalar
from modelarrayio.utils.cifti import brain_names_to_dataframe, extract_cifti_scalar_data
class _FakeHeader:
def __init__(self, axes):
self._axes = axes
def get_axis(self, index: int):
return self._axes[index]
class _FakeCiftiImage:
    """Duck-typed stand-in for a CIFTI image assembled from explicit axes and data.

    Exposes only ``header`` (with ``get_axis``), ``ndim`` and ``get_fdata()``.
    """

    def __init__(self, axes, data: np.ndarray):
        self._data = data
        # ndim is derived from the number of axes supplied, not from data.ndim.
        self.ndim = len(axes)
        self.header = _FakeHeader(axes)

    def get_fdata(self) -> np.ndarray:
        """Return the stored array as-is (no dtype conversion, no copy)."""
        stored = self._data
        return stored
def test_extract_cifti_scalar_data_returns_data_and_names() -> None:
    """A dense scalar image yields its values and one name per selected voxel."""
    mask = np.zeros((2, 2, 2), dtype=bool)
    # Select two voxels: opposite corners of the 2x2x2 grid.
    mask[(0, 1), (0, 1), (0, 1)] = True
    dscalar = make_dscalar(mask, np.array([1.0, 2.0], dtype=np.float32))

    values, greyordinate_names = extract_cifti_scalar_data(dscalar)

    expected = np.array([1.0, 2.0], dtype=np.float32)
    np.testing.assert_array_equal(values, expected)
    assert greyordinate_names.shape == (2,)
def test_extract_cifti_scalar_data_rejects_wrong_axis_count() -> None:
    """An image with a single axis (instead of two) is rejected."""
    single_axis = [ScalarAxis(['a'])]
    one_value = np.array([1.0], dtype=np.float32)
    fake = _FakeCiftiImage(single_axis, one_value)

    with pytest.raises(ValueError, match='exactly 2 axes'):
        extract_cifti_scalar_data(fake)
def test_extract_cifti_scalar_data_rejects_unsupported_axes() -> None:
    """Two scalar axes are a pairing the extractor refuses to handle."""
    scalar_pair = [ScalarAxis(['a']), ScalarAxis(['b'])]
    one_by_one = np.array([[1.0]], dtype=np.float32)
    fake = _FakeCiftiImage(scalar_pair, one_by_one)

    with pytest.raises(ValueError, match='Unsupported CIFTI axis combination'):
        extract_cifti_scalar_data(fake)
def test_extract_cifti_scalar_data_rejects_inconsistent_reference_names() -> None:
    """Dense reference names are accepted when they match and rejected otherwise.

    The matching case is exercised first: without it, a validator that raised
    for *every* reference would still satisfy the ``pytest.raises`` check below.
    """
    mask = np.zeros((2, 2, 2), dtype=bool)
    mask[0, 0, 0] = True
    mask[1, 1, 1] = True
    image = make_dscalar(mask, np.array([1.0, 2.0], dtype=np.float32))

    # The image's own names are, by construction, a consistent reference.
    data, names = extract_cifti_scalar_data(image)
    checked_data, _ = extract_cifti_scalar_data(image, reference_brain_names=names)
    np.testing.assert_array_equal(checked_data, data)

    with pytest.raises(ValueError, match='Inconsistent greyordinate names'):
        extract_cifti_scalar_data(image, reference_brain_names=np.array(['wrong', 'names']))
def test_extract_cifti_scalar_data_pscalar_returns_data_and_names() -> None:
    """A parcellated scalar image yields its values and its parcel names in order."""
    labels = ['parcel_A', 'parcel_B', 'parcel_C']
    scalars = np.arange(1, 4, dtype=np.float32)  # [1.0, 2.0, 3.0]
    pscalar = make_pscalar(labels, scalars)

    out_data, out_names = extract_cifti_scalar_data(pscalar)

    np.testing.assert_array_equal(out_data, scalars)
    assert list(out_names) == labels
def test_extract_cifti_scalar_data_pscalar_validates_reference_names() -> None:
    """Parcel reference names are accepted when they match and rejected otherwise.

    The matching case guards against a validator that rejects every reference,
    which the mismatch check alone would not detect.
    """
    parcel_names = ['parcel_A', 'parcel_B', 'parcel_C']
    values = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    image = make_pscalar(parcel_names, values)

    # Matching names (the pscalar's own parcel names) must not raise and must
    # return the same values as an unchecked extraction.
    data, _ = extract_cifti_scalar_data(image, reference_brain_names=np.array(parcel_names))
    np.testing.assert_array_equal(data, values)

    with pytest.raises(ValueError, match='Inconsistent parcel names'):
        extract_cifti_scalar_data(image, reference_brain_names=np.array(['X', 'Y', 'Z']))
def test_extract_cifti_scalar_data_pconn_flattens_matrix() -> None:
    """A parcel-by-parcel matrix is flattened row by row, with names to match."""
    labels = ['parcel_A', 'parcel_B']
    size = len(labels)
    connectivity = np.arange(size * size, dtype=np.float32).reshape(size, size)
    pconn = make_pconn(labels, connectivity)

    flat, flat_names = extract_cifti_scalar_data(pconn)

    np.testing.assert_array_equal(flat, connectivity.ravel())
    assert len(flat_names) == size * size
    # Row i's parcel name labels the i-th consecutive run of `size` entries.
    for row, label in enumerate(labels):
        start = row * size
        assert list(flat_names[start:start + size]) == [label] * size
def test_extract_cifti_scalar_data_pconn_validates_reference_names() -> None:
    """Flattened pconn element names are checked against a reference.

    Using the image's own element names must succeed; altering a single entry
    must raise. The success case guards against a validator that rejects
    every reference, which the failure check alone would not catch.
    """
    parcel_names = ['parcel_A', 'parcel_B']
    n = len(parcel_names)
    matrix = np.zeros((n, n), dtype=np.float32)
    image = make_pconn(parcel_names, matrix)
    # Get the correct element names first.
    _, element_names = extract_cifti_scalar_data(image)

    # The image's own names must be accepted and yield the flattened matrix.
    data, _ = extract_cifti_scalar_data(image, reference_brain_names=element_names)
    np.testing.assert_array_equal(data, matrix.flatten())

    # Modify one name to trigger validation failure; copy first so the array
    # returned by the extractor is not mutated in place.
    bad_names = element_names.copy()
    bad_names[0] = 'wrong'
    with pytest.raises(ValueError, match='Inconsistent parcel names'):
        extract_cifti_scalar_data(image, reference_brain_names=bad_names)
def test_make_parcels_axis_produces_valid_axis() -> None:
    """Smoke test: make_parcels_axis returns a ParcelsAxis with one entry per name."""
    from nibabel.cifti2.cifti2_axes import ParcelsAxis

    labels = ['A', 'B', 'C']
    parcels_axis = make_parcels_axis(labels)

    assert isinstance(parcels_axis, ParcelsAxis)
    assert len(parcels_axis) == len(labels)
def test_brain_names_to_dataframe() -> None:
    """One row per brain name, sequential vertex ids, and the unique structures."""
    brain_names = np.array(['CORTEX_LEFT', 'CORTEX_LEFT', 'CORTEX_RIGHT'])

    frame, structures = brain_names_to_dataframe(brain_names)

    assert len(frame) == len(brain_names)
    for column in ('vertex_id', 'structure_id'):
        assert column in frame.columns
    assert frame['vertex_id'].tolist() == list(range(3))
    # Only two distinct structure names appear in the input.
    assert len(structures) == 2