11"""Convert CIFTI2 dscalar data to an HDF5 file."""
22
3+ from __future__ import annotations
4+
35import argparse
46import logging
57import os
68from concurrent .futures import ThreadPoolExecutor , as_completed
9+ from pathlib import Path
710
811import h5py
912import pandas as pd
1013from tqdm import tqdm
1114
15+ from modelarrayio .cli import utils as cli_utils
1216from modelarrayio .cli .parser_utils import (
1317 add_backend_arg ,
1418 add_cohort_arg ,
1721 add_scalar_columns_arg ,
1822 add_storage_args ,
1923)
20- from modelarrayio .storage import h5_storage , tiledb_storage
2124from modelarrayio .utils .cifti import (
2225 _build_scalar_sources ,
2326 _cohort_to_long_dataframe ,
@@ -82,6 +85,7 @@ def cifti_to_h5(
8285 """
8386 cohort_df = pd .read_csv (cohort_file )
8487 cohort_long = _cohort_to_long_dataframe (cohort_df , scalar_columns = scalar_columns )
88+ output_path = Path (output )
8589 if cohort_long .empty :
8690 raise ValueError ('Cohort file does not contain any scalar entries after normalization.' )
8791 scalar_sources = _build_scalar_sources (cohort_long )
@@ -90,101 +94,82 @@ def cifti_to_h5(
9094
9195 if backend == 'hdf5' :
9296 scalars , last_brain_names = _load_cohort_cifti (cohort_long , s3_workers )
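+        # scalars maps scalar_name -> list of per-subject arrays (one row per
+        # subject); last_brain_names carries the greyordinate axis of the last
+        # file read, used to label the greyordinates table below.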
-
-        f = h5py.File(output, 'w')
-
        greyordinate_table, structure_names = brain_names_to_dataframe(last_brain_names)
-        greyordinatesh5 = f.create_dataset(
-            name='greyordinates', data=greyordinate_table.to_numpy().T
-        )
-        greyordinatesh5.attrs['column_names'] = list(greyordinate_table.columns)
-        greyordinatesh5.attrs['structure_names'] = structure_names
-
-        for scalar_name in scalars.keys():
-            num_subjects = len(scalars[scalar_name])
-            num_items = scalars[scalar_name][0].shape[0] if num_subjects > 0 else 0
-            dset = h5_storage.create_empty_scalar_matrix_dataset(
-                f,
-                f'scalars/{scalar_name}/values',
-                num_subjects,
-                num_items,
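+        # prepare_output_parent is assumed to create the parent directory of
+        # output_path (and return the normalized path) before the file is opened.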
+        output_path = cli_utils.prepare_output_parent(output_path)
+        with h5py.File(output_path, 'w') as h5_file:
+            cli_utils.write_table_dataset(
+                h5_file,
+                'greyordinates',
+                greyordinate_table,
+                extra_attrs={'structure_names': structure_names},
+            )
+            cli_utils.write_hdf5_scalar_matrices(
+                h5_file,
+                scalars,
+                scalar_sources,
                storage_dtype=storage_dtype,
                compression=compression,
                compression_level=compression_level,
                shuffle=shuffle,
                chunk_voxels=chunk_voxels,
                target_chunk_mb=target_chunk_mb,
-                sources_list=scalar_sources[scalar_name],
            )
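+        # Layout sketch, assuming write_hdf5_scalar_matrices keeps the
+        # per-scalar dataset paths used by the old code above:
+        #   /greyordinates                  greyordinate table (+ structure_names attr)
+        #   /scalars/<scalar_name>/values   one subjects-by-greyordinates matrix each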
+        return int(not output_path.exists())

-            h5_storage.write_rows_in_column_stripes(dset, scalars[scalar_name])
-        f.close()
-        return int(not os.path.exists(output))
-    else:
-        os.makedirs(output, exist_ok=True)
-        if not scalar_sources:
-            return 0
-
-        # Establish a reference brain axis once to ensure consistent ordering across workers.
-        _first_scalar, first_sources = next(iter(scalar_sources.items()))
-        first_path = first_sources[0]
-        _, reference_brain_names = extract_cifti_scalar_data(first_path)
-
-        def _process_scalar_job(scalar_name, source_files):
-            dataset_path = f'scalars/{scalar_name}/values'
-            rows = []
-            for source_file in source_files:
-                cifti_data, _ = extract_cifti_scalar_data(
-                    source_file, reference_brain_names=reference_brain_names
-                )
-                rows.append(cifti_data)
-
-            num_subjects = len(rows)
-            if num_subjects == 0:
-                return scalar_name
-            num_items = rows[0].shape[0]
-            tiledb_storage.create_empty_scalar_matrix_array(
-                output,
-                dataset_path,
-                num_subjects,
-                num_items,
+    output_path.mkdir(parents=True, exist_ok=True)
+    if not scalar_sources:
+        return 0
+
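+    # Establish a reference brain axis once to ensure consistent ordering across workers.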
+    _first_scalar, first_sources = next(iter(scalar_sources.items()))
+    first_path = first_sources[0]
+    _, reference_brain_names = extract_cifti_scalar_data(first_path)
+
+    def _process_scalar_job(scalar_name, source_files):
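+        # Load every CIFTI for this scalar, re-ordered onto the reference
+        # brain axis, then write the stacked rows as one TileDB matrix.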
+        rows = []
+        for source_file in source_files:
+            cifti_data, _ = extract_cifti_scalar_data(
+                source_file, reference_brain_names=reference_brain_names
+            )
+            rows.append(cifti_data)
+
+        if rows:
+            cli_utils.write_tiledb_scalar_matrices(
+                output_path,
+                {scalar_name: rows},
+                {scalar_name: source_files},
                storage_dtype=storage_dtype,
                compression=compression,
                compression_level=compression_level,
                shuffle=shuffle,
-                tile_voxels=chunk_voxels,
-                target_tile_mb=target_chunk_mb,
-                sources_list=source_files,
+                chunk_voxels=chunk_voxels,
+                target_chunk_mb=target_chunk_mb,
+                write_column_name_arrays=True,
            )
-            # write column names array for ModelArray compatibility
-            tiledb_storage.write_column_names(output, scalar_name, source_files)
-            uri = os.path.join(output, dataset_path)
-            tiledb_storage.write_rows_in_column_stripes(uri, rows)
        return scalar_name

-        scalar_names = list(scalar_sources.keys())
-        worker_count = workers if isinstance(workers, int) and workers > 0 else None
-        if worker_count is None:
-            cpu_count = os.cpu_count() or 1
-            worker_count = min(len(scalar_names), max(1, cpu_count))
-        else:
-            worker_count = min(len(scalar_names), worker_count)
-
-        if worker_count <= 1:
-            for scalar_name in scalar_names:
-                _process_scalar_job(scalar_name, scalar_sources[scalar_name])
-        else:
-            desc = 'TileDB scalars'
-            with ThreadPoolExecutor(max_workers=worker_count) as executor:
-                futures = {
-                    executor.submit(
-                        _process_scalar_job, scalar_name, scalar_sources[scalar_name]
-                    ): scalar_name
-                    for scalar_name in scalar_names
-                }
-                for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
-                    future.result()
-        return 0
+    scalar_names = list(scalar_sources.keys())
+    worker_count = workers if isinstance(workers, int) and workers > 0 else None
+    if worker_count is None:
+        cpu_count = os.cpu_count() or 1
+        worker_count = min(len(scalar_names), max(1, cpu_count))
+    else:
+        worker_count = min(len(scalar_names), worker_count)
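+    # Worked example: 10 scalars on a 4-CPU host with workers=None gives
+    # worker_count = min(10, max(1, 4)) = 4; an explicit workers=16 is clamped to 10.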
+
+    if worker_count <= 1:
+        for scalar_name in scalar_names:
+            _process_scalar_job(scalar_name, scalar_sources[scalar_name])
+    else:
+        desc = 'TileDB scalars'
+        with ThreadPoolExecutor(max_workers=worker_count) as executor:
+            futures = {
+                executor.submit(_process_scalar_job, scalar_name, scalar_sources[scalar_name]): (
+                    scalar_name
+                )
+                for scalar_name in scalar_names
+            }
+            for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
+                future.result()
+    return 0


def cifti_to_h5_main(
@@ -203,10 +188,7 @@ def cifti_to_h5_main(
    log_level='INFO',
):
    """Entry point for the ``modelarrayio cifti-to-h5`` command."""
-    logging.basicConfig(
-        level=getattr(logging, str(log_level).upper(), logging.INFO),
-        format='[%(levelname)s] %(name)s: %(message)s',
-    )
+    cli_utils.configure_logging(log_level)
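+    # configure_logging is assumed to reproduce the old basicConfig setup
+    # (level parsed from log_level, '[%(levelname)s] %(name)s: %(message)s' format).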
    return cifti_to_h5(
        cohort_file=cohort_file,
        backend=backend,