 
 import argparse
 import logging
-import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 
 import h5py
 import pandas as pd
 from tqdm import tqdm
 
 from modelarrayio.cli import utils as cli_utils
-from modelarrayio.cli.parser_utils import add_scalar_columns_arg, add_to_modelarray_args
+from modelarrayio.cli.parser_utils import add_to_modelarray_args
 from modelarrayio.utils.cifti import (
-    _build_scalar_sources,
-    _cohort_to_long_dataframe,
-    _load_cohort_cifti,
     brain_names_to_dataframe,
     extract_cifti_scalar_data,
+    load_cohort_cifti,
 )
+from modelarrayio.utils.misc import build_scalar_sources, cohort_to_long_dataframe
 
 logger = logging.getLogger(__name__)
 
@@ -35,7 +33,7 @@ def cifti_to_h5(
     shuffle=True,
     chunk_voxels=0,
     target_chunk_mb=2.0,
-    workers=None,
+    workers=1,
     s3_workers=1,
     scalar_columns=None,
 ):
@@ -47,7 +45,7 @@ def cifti_to_h5(
         Path to a csv with demographic info and paths to data
     backend : :obj:`str`
         Backend to use for storage (``'hdf5'`` or ``'tiledb'``)
-    output : :obj:`str`
+    output : :obj:`pathlib.Path`
         Output path. For the hdf5 backend, path to an .h5 file;
         for the tiledb backend, path to a .tdb directory.
     storage_dtype : :obj:`str`
@@ -64,7 +62,7 @@ def cifti_to_h5(
     target_chunk_mb : :obj:`float`
         Target chunk/tile size in MiB when auto-computing the spatial axis length
     workers : :obj:`int`
-        Maximum number of parallel TileDB write workers (``None`` = auto).
+        Maximum number of parallel TileDB write workers. Defaults to ``1``.
         Has no effect when ``backend='hdf5'``.
     s3_workers : :obj:`int`
         Number of workers for parallel S3 downloads
@@ -77,19 +75,49 @@ def cifti_to_h5(
         0 if successful, 1 if failed.
     """
     cohort_df = pd.read_csv(cohort_file)
-    cohort_long = _cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns)
-    output_path = Path(output)
+    cohort_long = cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns)
     if cohort_long.empty:
         raise ValueError('Cohort file does not contain any scalar entries after normalization.')
-    scalar_sources = _build_scalar_sources(cohort_long)
+    scalar_sources = build_scalar_sources(cohort_long)
     if not scalar_sources:
         raise ValueError('Unable to derive scalar sources from cohort file.')
+    scalar_names = list(scalar_sources.keys())
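+    # One output file per scalar only when explicit scalar columns were requested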
+    split_scalar_outputs = bool(scalar_columns)
 
     if backend == 'hdf5':
-        scalars, last_brain_names = _load_cohort_cifti(cohort_long, s3_workers)
+        if split_scalar_outputs:
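+            # Load every CIFTI once, then fan the scalars out to separate HDF5 files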
+            scalars, last_brain_names = load_cohort_cifti(cohort_long, s3_workers)
+            greyordinate_table, structure_names = brain_names_to_dataframe(last_brain_names)
+            outputs: list[Path] = []
+            for scalar_name in scalar_names:
+                scalar_output = cli_utils.prepare_output_parent(
+                    cli_utils.prefixed_output_path(output, scalar_name)
+                )
+                with h5py.File(scalar_output, 'w') as h5_file:
+                    cli_utils.write_table_dataset(
+                        h5_file,
+                        'greyordinates',
+                        greyordinate_table,
+                        extra_attrs={'structure_names': structure_names},
+                    )
+                    cli_utils.write_hdf5_scalar_matrices(
+                        h5_file,
+                        {scalar_name: scalars[scalar_name]},
+                        {scalar_name: scalar_sources[scalar_name]},
+                        storage_dtype=storage_dtype,
+                        compression=compression,
+                        compression_level=compression_level,
+                        shuffle=shuffle,
+                        chunk_voxels=chunk_voxels,
+                        target_chunk_mb=target_chunk_mb,
+                    )
+                outputs.append(scalar_output)
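+            # Succeed (return 0) only if every per-scalar file exists on disk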
+            return int(not all(path.exists() for path in outputs))
+
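+        # Default single-file mode: every scalar lands in one HDF5 file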
+        scalars, last_brain_names = load_cohort_cifti(cohort_long, s3_workers)
         greyordinate_table, structure_names = brain_names_to_dataframe(last_brain_names)
-        output_path = cli_utils.prepare_output_parent(output_path)
-        with h5py.File(output_path, 'w') as h5_file:
+        output = cli_utils.prepare_output_parent(output)
+        with h5py.File(output, 'w') as h5_file:
             cli_utils.write_table_dataset(
                 h5_file,
                 'greyordinates',
@@ -107,9 +135,8 @@ def cifti_to_h5(
                 chunk_voxels=chunk_voxels,
                 target_chunk_mb=target_chunk_mb,
             )
-        return int(not output_path.exists())
+        return int(not output.exists())
 
-    output_path.mkdir(parents=True, exist_ok=True)
     if not scalar_sources:
         return 0
 
@@ -126,8 +153,13 @@ def _process_scalar_job(scalar_name, source_files):
             rows.append(cifti_data)
 
         if rows:
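+            # Split mode gives each scalar its own prefixed path; otherwise all share one store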
+            scalar_output = (
+                cli_utils.prefixed_output_path(output, scalar_name)
+                if split_scalar_outputs
+                else output
+            )
             cli_utils.write_tiledb_scalar_matrices(
-                output_path,
+                scalar_output,
                 {scalar_name: rows},
                 {scalar_name: source_files},
                 storage_dtype=storage_dtype,
@@ -140,13 +172,7 @@ def _process_scalar_job(scalar_name, source_files):
             )
         return scalar_name
 
-    scalar_names = list(scalar_sources.keys())
-    worker_count = workers if isinstance(workers, int) and workers > 0 else None
-    if worker_count is None:
-        cpu_count = os.cpu_count() or 1
-        worker_count = min(len(scalar_names), max(1, cpu_count))
-    else:
-        worker_count = min(len(scalar_names), worker_count)
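+    # Never start more write workers than there are scalar jobs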
+    worker_count = min(len(scalar_names), workers)
 
     if worker_count <= 1:
         for scalar_name in scalar_names:
@@ -178,5 +204,4 @@ def _parse_cifti_to_h5():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     add_to_modelarray_args(parser, default_output='greyordinatearray.h5')
-    add_scalar_columns_arg(parser)
     return parser