1- """Unit tests for CIFTI validation helpers."""
1+ """Unit tests for CIFTI utility helpers."""
22
33from __future__ import annotations
44
5- import nibabel as nb
65import numpy as np
76import pytest
8- from nibabel .cifti2 .cifti2_axes import BrainModelAxis , ParcelsAxis , ScalarAxis
7+ from nibabel .cifti2 .cifti2_axes import ScalarAxis
8+ from utils import make_dscalar , make_parcels_axis , make_pconn , make_pscalar
99
10- from modelarrayio .utils .cifti import extract_cifti_scalar_data
11-
12-
13- def _make_scalar_cifti (mask_bool : np .ndarray , values : np .ndarray ) -> nb .Cifti2Image :
14- scalar_axis = ScalarAxis (['synthetic' ])
15- brain_axis = BrainModelAxis .from_mask (mask_bool )
16- header = nb .cifti2 .Cifti2Header .from_axes ((scalar_axis , brain_axis ))
17- return nb .Cifti2Image (values .reshape (1 , - 1 ).astype (np .float32 ), header = header )
18-
19-
20- def _make_parcels_axis (parcel_names : list [str ]) -> ParcelsAxis :
21- """Create a minimal surface-only ParcelsAxis for testing."""
22- # One vertex per parcel on the left cortex
23- n = len (parcel_names )
24- nvertices = {'CIFTI_STRUCTURE_CORTEX_LEFT' : n }
25- vox_dtype = np .dtype ([('ijk' , '<i4' , (3 ,))])
26- voxels = [np .array ([], dtype = vox_dtype ) for _ in range (n )]
27- vertices = [{'CIFTI_STRUCTURE_CORTEX_LEFT' : np .array ([i ], dtype = np .int32 )} for i in range (n )]
28- affine = np .eye (4 )
29- volume_shape = (10 , 10 , 10 )
30- return ParcelsAxis (parcel_names , voxels , vertices , affine , volume_shape , nvertices )
31-
32-
33- def _make_pscalar_cifti (parcel_names : list [str ], values : np .ndarray ) -> nb .Cifti2Image :
34- scalar_axis = ScalarAxis (['synthetic' ])
35- parcels_axis = _make_parcels_axis (parcel_names )
36- header = nb .cifti2 .Cifti2Header .from_axes ((scalar_axis , parcels_axis ))
37- return nb .Cifti2Image (values .reshape (1 , - 1 ).astype (np .float32 ), header = header )
38-
39-
40- def _make_pconn_cifti (parcel_names : list [str ], values : np .ndarray ) -> nb .Cifti2Image :
41- parcels_axis = _make_parcels_axis (parcel_names )
42- header = nb .cifti2 .Cifti2Header .from_axes ((parcels_axis , parcels_axis ))
43- n = len (parcel_names )
44- return nb .Cifti2Image (values .reshape (n , n ).astype (np .float32 ), header = header )
10+ from modelarrayio .utils .cifti import brain_names_to_dataframe , extract_cifti_scalar_data
4511
4612
4713class _FakeHeader :
@@ -66,7 +32,7 @@ def test_extract_cifti_scalar_data_returns_data_and_names() -> None:
6632 mask = np .zeros ((2 , 2 , 2 ), dtype = bool )
6733 mask [0 , 0 , 0 ] = True
6834 mask [1 , 1 , 1 ] = True
69- image = _make_scalar_cifti (mask , np .array ([1.0 , 2.0 ], dtype = np .float32 ))
35+ image = make_dscalar (mask , np .array ([1.0 , 2.0 ], dtype = np .float32 ))
7036
7137 data , names = extract_cifti_scalar_data (image )
7238
@@ -95,7 +61,7 @@ def test_extract_cifti_scalar_data_rejects_inconsistent_reference_names() -> Non
9561 mask = np .zeros ((2 , 2 , 2 ), dtype = bool )
9662 mask [0 , 0 , 0 ] = True
9763 mask [1 , 1 , 1 ] = True
98- image = _make_scalar_cifti (mask , np .array ([1.0 , 2.0 ], dtype = np .float32 ))
64+ image = make_dscalar (mask , np .array ([1.0 , 2.0 ], dtype = np .float32 ))
9965
10066 with pytest .raises (ValueError , match = 'Inconsistent greyordinate names' ):
10167 extract_cifti_scalar_data (image , reference_brain_names = np .array (['wrong' , 'names' ]))
@@ -104,7 +70,7 @@ def test_extract_cifti_scalar_data_rejects_inconsistent_reference_names() -> Non
10470def test_extract_cifti_scalar_data_pscalar_returns_data_and_names () -> None :
10571 parcel_names = ['parcel_A' , 'parcel_B' , 'parcel_C' ]
10672 values = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
107- image = _make_pscalar_cifti (parcel_names , values )
73+ image = make_pscalar (parcel_names , values )
10874
10975 data , names = extract_cifti_scalar_data (image )
11076
@@ -115,7 +81,7 @@ def test_extract_cifti_scalar_data_pscalar_returns_data_and_names() -> None:
11581def test_extract_cifti_scalar_data_pscalar_validates_reference_names () -> None :
11682 parcel_names = ['parcel_A' , 'parcel_B' , 'parcel_C' ]
11783 values = np .array ([1.0 , 2.0 , 3.0 ], dtype = np .float32 )
118- image = _make_pscalar_cifti (parcel_names , values )
84+ image = make_pscalar (parcel_names , values )
11985
12086 with pytest .raises (ValueError , match = 'Inconsistent parcel names' ):
12187 extract_cifti_scalar_data (image , reference_brain_names = np .array (['X' , 'Y' , 'Z' ]))
@@ -125,7 +91,7 @@ def test_extract_cifti_scalar_data_pconn_flattens_matrix() -> None:
12591 parcel_names = ['parcel_A' , 'parcel_B' ]
12692 n = len (parcel_names )
12793 matrix = np .arange (n * n , dtype = np .float32 ).reshape (n , n )
128- image = _make_pconn_cifti (parcel_names , matrix )
94+ image = make_pconn (parcel_names , matrix )
12995
13096 data , names = extract_cifti_scalar_data (image )
13197
@@ -140,7 +106,7 @@ def test_extract_cifti_scalar_data_pconn_validates_reference_names() -> None:
140106 parcel_names = ['parcel_A' , 'parcel_B' ]
141107 n = len (parcel_names )
142108 matrix = np .zeros ((n , n ), dtype = np .float32 )
143- image = _make_pconn_cifti (parcel_names , matrix )
109+ image = make_pconn (parcel_names , matrix )
144110
145111 # Get the correct element_names first
146112 _ , element_names = extract_cifti_scalar_data (image )
@@ -150,3 +116,23 @@ def test_extract_cifti_scalar_data_pconn_validates_reference_names() -> None:
150116 bad_names [0 ] = 'wrong'
151117 with pytest .raises (ValueError , match = 'Inconsistent parcel names' ):
152118 extract_cifti_scalar_data (image , reference_brain_names = bad_names )
119+
120+
121+ def test_make_parcels_axis_produces_valid_axis () -> None :
122+ """Smoke test: make_parcels_axis should return a ParcelsAxis with correct length."""
123+ from nibabel .cifti2 .cifti2_axes import ParcelsAxis
124+
125+ names = ['A' , 'B' , 'C' ]
126+ axis = make_parcels_axis (names )
127+ assert isinstance (axis , ParcelsAxis )
128+ assert len (axis ) == len (names )
129+
130+
131+ def test_brain_names_to_dataframe () -> None :
132+ names = np .array (['CORTEX_LEFT' , 'CORTEX_LEFT' , 'CORTEX_RIGHT' ])
133+ gdf , struct_strings = brain_names_to_dataframe (names )
134+ assert len (gdf ) == 3
135+ assert 'vertex_id' in gdf .columns
136+ assert 'structure_id' in gdf .columns
137+ assert gdf ['vertex_id' ].tolist () == [0 , 1 , 2 ]
138+ assert len (struct_strings ) == 2 # factorize unique structures
0 commit comments