11import argparse
22import os
3- from collections import defaultdict
3+ from collections import defaultdict , OrderedDict
44import os .path as op
55from concurrent .futures import ThreadPoolExecutor , as_completed
66import numpy as np
1313from .tiledb_storage import (
1414 create_empty_scalar_matrix_array as tdb_create_empty ,
1515 write_rows_in_column_stripes as tdb_write_stripes ,
16+ write_column_names as tdb_write_column_names ,
1617)
17- from .parser import add_relative_root_arg , add_output_hdf5_arg , add_cohort_arg , add_storage_args , add_backend_arg , add_output_tiledb_arg , add_tiledb_storage_args
18+ from .parser import add_relative_root_arg , add_output_hdf5_arg , add_cohort_arg , add_storage_args , add_backend_arg , add_output_tiledb_arg , add_tiledb_storage_args , add_scalar_columns_arg
19+
20+
def _cohort_to_long_dataframe(cohort_df, scalar_columns=None):
    """Normalize a cohort table to long format.

    The result always has exactly two columns, in this order:
    ``scalar_name`` and ``source_file``.

    Parameters
    ----------
    cohort_df : pandas.DataFrame
        Cohort table. Treated as wide-format when ``scalar_columns`` is
        given, otherwise it must already contain ``scalar_name`` and
        ``source_file`` columns.
    scalar_columns : list of str, optional
        Names of the wide-format columns whose cells hold source file
        paths. Falsy entries are ignored; if nothing remains, the table is
        treated as long-format.

    Returns
    -------
    pandas.DataFrame
        One row per (scalar, source file) pair with missing and blank
        entries removed and surrounding whitespace stripped.

    Raises
    ------
    ValueError
        If a requested scalar column is absent (wide-format) or the
        required long-format columns are absent.
    """
    # A list (not a set) so the output column order is deterministic and
    # identical for both the wide and the long code path.
    output_columns = ["scalar_name", "source_file"]

    scalar_columns = [col for col in (scalar_columns or []) if col]
    if scalar_columns:
        missing = [col for col in scalar_columns if col not in cohort_df.columns]
        if missing:
            raise ValueError(f"Wide-format cohort is missing scalar columns: {missing}")
        records = []
        # Row-major traversal: all scalars of the first row, then the second, ...
        for _, row in cohort_df.iterrows():
            for scalar_col in scalar_columns:
                source_val = row[scalar_col]
                # pd.isna covers None as well as NaN/NA.
                if pd.isna(source_val):
                    continue
                source_str = str(source_val).strip()
                if not source_str:
                    continue
                records.append({"scalar_name": scalar_col, "source_file": source_str})
        return pd.DataFrame.from_records(records, columns=output_columns)

    missing = [col for col in output_columns if col not in cohort_df.columns]
    if missing:
        raise ValueError(
            f"Cohort file must contain columns {output_columns} when --scalar-columns is not used."
        )

    long_df = cohort_df[output_columns].copy()
    long_df = long_df.dropna(subset=output_columns)
    for col in output_columns:
        long_df[col] = long_df[col].astype(str).str.strip()
    # Discard rows that became empty after stripping whitespace.
    long_df = long_df[(long_df["scalar_name"] != "") & (long_df["source_file"] != "")]
    return long_df.reset_index(drop=True)
50+
51+
def _build_scalar_sources(long_df):
    """Group source files by scalar name.

    Parameters
    ----------
    long_df : pandas.DataFrame
        Long-format table exposing ``scalar_name`` and ``source_file``
        columns.

    Returns
    -------
    collections.OrderedDict
        Maps each scalar name (as ``str``) to the list of its source files
        (as ``str``). Keys appear in first-seen order and each list keeps
        the row order of ``long_df``. Rows whose name or file converts to
        an empty string are skipped.
    """
    grouped = OrderedDict()
    for record in long_df.itertuples(index=False):
        name, path = str(record.scalar_name), str(record.source_file)
        if name and path:
            grouped.setdefault(name, []).append(path)
    return grouped
1861
1962
2063def extract_cifti_scalar_data (cifti_file , reference_brain_names = None ):
@@ -103,7 +146,8 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
103146 tdb_shuffle = True ,
104147 tdb_tile_voxels = 0 ,
105148 tdb_target_tile_mb = 2.0 ,
106- tdb_workers = None ):
149+ tdb_workers = None ,
150+ scalar_columns = None ):
107151 """
108152 Load all fixeldb data.
109153 Parameters
@@ -120,23 +164,30 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
120164 path to which index_file, directions_file and cohort_file (and its contents) are relative
121165 """
122166
123- # gather cohort data
124- cohort_df = pd .read_csv (op .join (relative_root , cohort_file ))
125-
126- # upload each cohort's data
127- scalars = defaultdict (list )
128- sources_lists = defaultdict (list )
129- last_brain_names = None
130- for ix , row in tqdm (cohort_df .iterrows (), total = cohort_df .shape [0 ]): # ix: index of row (start from 0); row: one row of data
131- scalar_file = op .join (relative_root , row ['source_file' ])
132- cifti_data , brain_names = extract_cifti_scalar_data (
133- scalar_file , reference_brain_names = last_brain_names )
134- last_brain_names = brain_names .copy ()
135- scalars [row ['scalar_name' ]].append (cifti_data ) # append to specific scalar_name
136- sources_lists [row ['scalar_name' ]].append (row ['source_file' ]) # append source mif filename to specific scalar_name
137-
138- # Write the output
167+ cohort_path = op .join (relative_root , cohort_file )
168+ cohort_df = pd .read_csv (cohort_path )
169+ cohort_long = _cohort_to_long_dataframe (cohort_df , scalar_columns = scalar_columns )
170+ if cohort_long .empty :
171+ raise ValueError ("Cohort file does not contain any scalar entries after normalization." )
172+ scalar_sources = _build_scalar_sources (cohort_long )
173+ if not scalar_sources :
174+ raise ValueError ("Unable to derive scalar sources from cohort file." )
175+
139176 if backend == 'hdf5' :
177+ scalars = defaultdict (list )
178+ last_brain_names = None
179+ for row in tqdm (
180+ cohort_long .itertuples (index = False ),
181+ total = cohort_long .shape [0 ],
182+ desc = "Loading CIFTI scalars" ,
183+ ):
184+ scalar_file = op .join (relative_root , row .source_file )
185+ cifti_data , brain_names = extract_cifti_scalar_data (
186+ scalar_file , reference_brain_names = last_brain_names
187+ )
188+ last_brain_names = brain_names .copy ()
189+ scalars [row .scalar_name ].append (cifti_data )
190+
140191 output_file = op .join (relative_root , output_h5 )
141192 f = h5py .File (output_file , "w" )
142193
@@ -159,19 +210,36 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
159210 shuffle = shuffle ,
160211 chunk_voxels = chunk_voxels ,
161212 target_chunk_mb = target_chunk_mb ,
162- sources_list = sources_lists [scalar_name ])
213+ sources_list = scalar_sources [scalar_name ])
163214
164215 write_rows_in_column_stripes (dset , scalars [scalar_name ])
165216 f .close ()
166217 return int (not op .exists (output_file ))
167218 else :
168219 base_uri = op .join (relative_root , output_tdb )
169220 os .makedirs (base_uri , exist_ok = True )
170- scalar_names = list (scalars .keys ())
171- for scalar_name in scalar_names :
172- num_subjects = len (scalars [scalar_name ])
173- num_items = scalars [scalar_name ][0 ].shape [0 ] if num_subjects > 0 else 0
221+ if not scalar_sources :
222+ return 0
223+
224+ # Establish a reference brain axis once to ensure consistent ordering across workers.
225+ first_scalar , first_sources = next (iter (scalar_sources .items ()))
226+ first_path = op .join (relative_root , first_sources [0 ])
227+ _ , reference_brain_names = extract_cifti_scalar_data (first_path )
228+
229+ def _process_scalar_job (scalar_name , source_files ):
174230 dataset_path = f'scalars/{ scalar_name } /values'
231+ rows = []
232+ for source_file in source_files :
233+ scalar_file = op .join (relative_root , source_file )
234+ cifti_data , _ = extract_cifti_scalar_data (
235+ scalar_file , reference_brain_names = reference_brain_names
236+ )
237+ rows .append (cifti_data )
238+
239+ num_subjects = len (rows )
240+ if num_subjects == 0 :
241+ return scalar_name
242+ num_items = rows [0 ].shape [0 ]
175243 tdb_create_empty (
176244 base_uri ,
177245 dataset_path ,
@@ -183,18 +251,15 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
183251 shuffle = tdb_shuffle ,
184252 tile_voxels = tdb_tile_voxels ,
185253 target_tile_mb = tdb_target_tile_mb ,
186- sources_list = sources_lists [ scalar_name ] ,
254+ sources_list = source_files ,
187255 )
188-
189- def _write_scalar_to_tdb (scalar_name ):
190- dataset_path = f'scalars/{ scalar_name } /values'
256+ # write column names array for ModelArray compatibility
257+ tdb_write_column_names (base_uri , scalar_name , source_files )
191258 uri = op .join (base_uri , dataset_path )
192- tdb_write_stripes (uri , scalars [scalar_name ])
193-
194- if not scalar_names :
195- return 0
259+ tdb_write_stripes (uri , rows )
260+ return scalar_name
196261
197- # Determine worker count: explicit value takes precedence; fallback to CPU count.
262+ scalar_names = list ( scalar_sources . keys ())
198263 worker_count = tdb_workers if isinstance (tdb_workers , int ) and tdb_workers > 0 else None
199264 if worker_count is None :
200265 cpu_count = os .cpu_count () or 1
@@ -204,12 +269,12 @@ def _write_scalar_to_tdb(scalar_name):
204269
205270 if worker_count <= 1 :
206271 for scalar_name in scalar_names :
207- _write_scalar_to_tdb (scalar_name )
272+ _process_scalar_job (scalar_name , scalar_sources [ scalar_name ] )
208273 else :
209274 desc = "TileDB scalars"
210275 with ThreadPoolExecutor (max_workers = worker_count ) as executor :
211276 futures = {
212- executor .submit (_write_scalar_to_tdb , scalar_name ): scalar_name
277+ executor .submit (_process_scalar_job , scalar_name , scalar_sources [ scalar_name ] ): scalar_name
213278 for scalar_name in scalar_names
214279 }
215280 for future in tqdm (as_completed (futures ), total = len (futures ), desc = desc ):
@@ -221,6 +286,7 @@ def get_parser():
221286 parser = argparse .ArgumentParser (
222287 description = "Create a hdf5 file of CIDTI2 dscalar data" )
223288 add_cohort_arg (parser )
289+ add_scalar_columns_arg (parser )
224290 add_relative_root_arg (parser )
225291 add_output_hdf5_arg (parser , default_name = "fixelarray.h5" )
226292 add_output_tiledb_arg (parser , default_name = "arraydb.tdb" )
@@ -252,7 +318,8 @@ def main():
252318 tdb_shuffle = args .tdb_shuffle ,
253319 tdb_tile_voxels = args .tdb_tile_voxels ,
254320 tdb_target_tile_mb = args .tdb_target_tile_mb ,
255- tdb_workers = args .tdb_workers )
321+ tdb_workers = args .tdb_workers ,
322+ scalar_columns = args .scalar_columns )
256323 return status
257324
258325
0 commit comments