Commit ad91358

add column names

1 parent 375e994 commit ad91358

3 files changed
Lines changed: 138 additions & 42 deletions

modelarrayio/cifti.py

Lines changed: 103 additions & 36 deletions
```diff
@@ -1,6 +1,6 @@
 import argparse
 import os
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 import os.path as op
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import numpy as np
@@ -13,8 +13,51 @@
 from .tiledb_storage import (
     create_empty_scalar_matrix_array as tdb_create_empty,
     write_rows_in_column_stripes as tdb_write_stripes,
+    write_column_names as tdb_write_column_names,
 )
-from .parser import add_relative_root_arg, add_output_hdf5_arg, add_cohort_arg, add_storage_args, add_backend_arg, add_output_tiledb_arg, add_tiledb_storage_args
+from .parser import add_relative_root_arg, add_output_hdf5_arg, add_cohort_arg, add_storage_args, add_backend_arg, add_output_tiledb_arg, add_tiledb_storage_args, add_scalar_columns_arg
+
+
+def _cohort_to_long_dataframe(cohort_df, scalar_columns=None):
+    scalar_columns = [col for col in (scalar_columns or []) if col]
+    if scalar_columns:
+        missing = [col for col in scalar_columns if col not in cohort_df.columns]
+        if missing:
+            raise ValueError(f"Wide-format cohort is missing scalar columns: {missing}")
+        records = []
+        for _, row in cohort_df.iterrows():
+            for scalar_col in scalar_columns:
+                source_val = row[scalar_col]
+                if pd.isna(source_val) or source_val is None:
+                    continue
+                source_str = str(source_val).strip()
+                if not source_str:
+                    continue
+                records.append({"scalar_name": scalar_col, "source_file": source_str})
+        return pd.DataFrame.from_records(records, columns=["scalar_name", "source_file"])
+
+    required = {"scalar_name", "source_file"}
+    missing = required - set(cohort_df.columns)
+    if missing:
+        raise ValueError(f"Cohort file must contain columns {sorted(required)} when --scalar-columns is not used.")
+
+    long_df = cohort_df[list(required)].copy()
+    long_df = long_df.dropna(subset=["scalar_name", "source_file"])
+    long_df["scalar_name"] = long_df["scalar_name"].astype(str).str.strip()
+    long_df["source_file"] = long_df["source_file"].astype(str).str.strip()
+    long_df = long_df[(long_df["scalar_name"] != "") & (long_df["source_file"] != "")]
+    return long_df.reset_index(drop=True)
+
+
+def _build_scalar_sources(long_df):
+    scalar_sources = OrderedDict()
+    for row in long_df.itertuples(index=False):
+        scalar = str(row.scalar_name)
+        source = str(row.source_file)
+        if not scalar or not source:
+            continue
+        scalar_sources.setdefault(scalar, []).append(source)
+    return scalar_sources
 
 
 def extract_cifti_scalar_data(cifti_file, reference_brain_names=None):
@@ -103,7 +146,8 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
                   tdb_shuffle=True,
                   tdb_tile_voxels=0,
                   tdb_target_tile_mb=2.0,
-                  tdb_workers=None):
+                  tdb_workers=None,
+                  scalar_columns=None):
     """
     Load all fixeldb data.
     Parameters
@@ -120,23 +164,30 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
         path to which index_file, directions_file and cohort_file (and its contents) are relative
     """
 
-    # gather cohort data
-    cohort_df = pd.read_csv(op.join(relative_root, cohort_file))
-
-    # upload each cohort's data
-    scalars = defaultdict(list)
-    sources_lists = defaultdict(list)
-    last_brain_names = None
-    for ix, row in tqdm(cohort_df.iterrows(), total=cohort_df.shape[0]):  # ix: index of row (start from 0); row: one row of data
-        scalar_file = op.join(relative_root, row['source_file'])
-        cifti_data, brain_names = extract_cifti_scalar_data(
-            scalar_file, reference_brain_names=last_brain_names)
-        last_brain_names = brain_names.copy()
-        scalars[row['scalar_name']].append(cifti_data)  # append to specific scalar_name
-        sources_lists[row['scalar_name']].append(row['source_file'])  # append source mif filename to specific scalar_name
-
-    # Write the output
+    cohort_path = op.join(relative_root, cohort_file)
+    cohort_df = pd.read_csv(cohort_path)
+    cohort_long = _cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns)
+    if cohort_long.empty:
+        raise ValueError("Cohort file does not contain any scalar entries after normalization.")
+    scalar_sources = _build_scalar_sources(cohort_long)
+    if not scalar_sources:
+        raise ValueError("Unable to derive scalar sources from cohort file.")
+
     if backend == 'hdf5':
+        scalars = defaultdict(list)
+        last_brain_names = None
+        for row in tqdm(
+            cohort_long.itertuples(index=False),
+            total=cohort_long.shape[0],
+            desc="Loading CIFTI scalars",
+        ):
+            scalar_file = op.join(relative_root, row.source_file)
+            cifti_data, brain_names = extract_cifti_scalar_data(
+                scalar_file, reference_brain_names=last_brain_names
+            )
+            last_brain_names = brain_names.copy()
+            scalars[row.scalar_name].append(cifti_data)
+
         output_file = op.join(relative_root, output_h5)
         f = h5py.File(output_file, "w")
 
@@ -159,19 +210,36 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
                 shuffle=shuffle,
                 chunk_voxels=chunk_voxels,
                 target_chunk_mb=target_chunk_mb,
-                sources_list=sources_lists[scalar_name])
+                sources_list=scalar_sources[scalar_name])
 
             write_rows_in_column_stripes(dset, scalars[scalar_name])
         f.close()
         return int(not op.exists(output_file))
     else:
         base_uri = op.join(relative_root, output_tdb)
         os.makedirs(base_uri, exist_ok=True)
-        scalar_names = list(scalars.keys())
-        for scalar_name in scalar_names:
-            num_subjects = len(scalars[scalar_name])
-            num_items = scalars[scalar_name][0].shape[0] if num_subjects > 0 else 0
+        if not scalar_sources:
+            return 0
+
+        # Establish a reference brain axis once to ensure consistent ordering across workers.
+        first_scalar, first_sources = next(iter(scalar_sources.items()))
+        first_path = op.join(relative_root, first_sources[0])
+        _, reference_brain_names = extract_cifti_scalar_data(first_path)
+
+        def _process_scalar_job(scalar_name, source_files):
             dataset_path = f'scalars/{scalar_name}/values'
+            rows = []
+            for source_file in source_files:
+                scalar_file = op.join(relative_root, source_file)
+                cifti_data, _ = extract_cifti_scalar_data(
+                    scalar_file, reference_brain_names=reference_brain_names
+                )
+                rows.append(cifti_data)
+
+            num_subjects = len(rows)
+            if num_subjects == 0:
+                return scalar_name
+            num_items = rows[0].shape[0]
             tdb_create_empty(
                 base_uri,
                 dataset_path,
@@ -183,18 +251,15 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
                 shuffle=tdb_shuffle,
                 tile_voxels=tdb_tile_voxels,
                 target_tile_mb=tdb_target_tile_mb,
-                sources_list=sources_lists[scalar_name],
+                sources_list=source_files,
             )
-
-        def _write_scalar_to_tdb(scalar_name):
-            dataset_path = f'scalars/{scalar_name}/values'
+            # write column names array for ModelArray compatibility
+            tdb_write_column_names(base_uri, scalar_name, source_files)
             uri = op.join(base_uri, dataset_path)
-            tdb_write_stripes(uri, scalars[scalar_name])
-
-        if not scalar_names:
-            return 0
+            tdb_write_stripes(uri, rows)
+            return scalar_name
 
-        # Determine worker count: explicit value takes precedence; fallback to CPU count.
+        scalar_names = list(scalar_sources.keys())
         worker_count = tdb_workers if isinstance(tdb_workers, int) and tdb_workers > 0 else None
         if worker_count is None:
             cpu_count = os.cpu_count() or 1
@@ -204,12 +269,12 @@ def _write_scalar_to_tdb(scalar_name):
 
         if worker_count <= 1:
            for scalar_name in scalar_names:
-                _write_scalar_to_tdb(scalar_name)
+                _process_scalar_job(scalar_name, scalar_sources[scalar_name])
         else:
             desc = "TileDB scalars"
             with ThreadPoolExecutor(max_workers=worker_count) as executor:
                 futures = {
-                    executor.submit(_write_scalar_to_tdb, scalar_name): scalar_name
+                    executor.submit(_process_scalar_job, scalar_name, scalar_sources[scalar_name]): scalar_name
                     for scalar_name in scalar_names
                 }
                 for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
@@ -221,6 +286,7 @@ def get_parser():
     parser = argparse.ArgumentParser(
         description="Create a hdf5 file of CIDTI2 dscalar data")
     add_cohort_arg(parser)
+    add_scalar_columns_arg(parser)
     add_relative_root_arg(parser)
     add_output_hdf5_arg(parser, default_name="fixelarray.h5")
     add_output_tiledb_arg(parser, default_name="arraydb.tdb")
@@ -252,7 +318,8 @@ def main():
                           tdb_shuffle=args.tdb_shuffle,
                           tdb_tile_voxels=args.tdb_tile_voxels,
                           tdb_target_tile_mb=args.tdb_target_tile_mb,
-                          tdb_workers=args.tdb_workers)
+                          tdb_workers=args.tdb_workers,
+                          scalar_columns=args.scalar_columns)
     return status
 
 
```
modelarrayio/parser.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -134,4 +134,16 @@ def add_tiledb_storage_args(parser):
     return parser
 
 
+def add_scalar_columns_arg(parser):
+    parser.add_argument(
+        "--scalar-columns", "--scalar_columns",
+        nargs="+",
+        help=(
+            "Column names containing scalar file paths when the cohort table is in wide format. "
+            "If omitted, the cohort file must include 'scalar_name' and 'source_file' columns."
+        ),
+    )
+    return parser
+
+
 
```
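As a quick sanity check on the new helper, the sketch below wires `add_scalar_columns_arg` into a bare `ArgumentParser`. Note that argparse derives the destination from the first option string, so both spellings land in `args.scalar_columns`:

```python
import argparse

from modelarrayio.parser import add_scalar_columns_arg

parser = argparse.ArgumentParser()
add_scalar_columns_arg(parser)

# nargs="+" collects one or more column names into a list.
args = parser.parse_args(["--scalar-columns", "FA", "MD"])
print(args.scalar_columns)  # ['FA', 'MD']

# The underscore alias parses to the same destination.
args = parser.parse_args(["--scalar_columns", "FA"])
print(args.scalar_columns)  # ['FA']

# When omitted, the attribute is None, which write_storage treats as
# "expect long-format scalar_name/source_file columns".
args = parser.parse_args([])
print(args.scalar_columns)  # None
```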
modelarrayio/tiledb_storage.py

Lines changed: 23 additions & 6 deletions
```diff
@@ -222,15 +222,32 @@ def write_rows_in_column_stripes(uri: str, rows: Sequence[np.ndarray]):
 
 def write_column_names(base_uri: str, scalar: str, sources: Sequence[str]):
     """
-    Store column names as metadata on the TileDB group for the given scalar.
-    This mirrors HDF5's practice of storing names alongside the data.
+    Store column names as a 1D dense TileDB array for the given scalar.
+    This mirrors the HDF5 dataset approach and scales to large cohorts.
     """
+    sources = list(map(str, sources))
+    uri = os.path.join(base_uri, "scalars", scalar, "column_names")
+    _ensure_parent_group(uri)
+
+    n = len(sources)
+    dim_idx = tiledb.Dim(name="idx", domain=(0, max(n - 1, 0)), tile=max(1, min(n, 1024)), dtype=np.int64)
+    dom = tiledb.Domain(dim_idx)
+    attr_values = tiledb.Attr(name="values", dtype=np.unicode_)
+    schema = tiledb.ArraySchema(domain=dom, attrs=[attr_values], sparse=False)
+
+    if tiledb.object_type(uri):
+        tiledb.remove(uri)
+    tiledb.Array.create(uri, schema)
+
+    with tiledb.open(uri, "w") as A:
+        A[:] = {"values": np.array(sources, dtype=object)}
+
+    # Also write metadata on the parent group for quick discovery (optional)
     group_uri = os.path.join(base_uri, "scalars", scalar)
-    if not tiledb.object_type(group_uri):
-        tiledb.group_create(group_uri)
-    with tiledb.Group(group_uri, "w") as G:
+    if tiledb.object_type(group_uri):
         try:
-            G.meta["column_names"] = json.dumps(list(map(str, sources)))
+            with tiledb.Group(group_uri, "w") as G:
+                G.meta["column_names"] = json.dumps(sources)
         except Exception:
             logger.warning("Failed to write column_names metadata for group %s", group_uri)
 
```
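Reading the names back is symmetric. Here is a sketch against the layout this commit writes; the `arraydb.tdb` path and the `FA` scalar name are placeholders, not values from the commit:

```python
import json
import os

import tiledb

base_uri = "arraydb.tdb"  # placeholder TileDB output directory
scalar = "FA"             # placeholder scalar name

# Primary source: the dense 1D column_names array written above.
names_uri = os.path.join(base_uri, "scalars", scalar, "column_names")
with tiledb.open(names_uri) as A:
    column_names = [str(v) for v in A[:]["values"]]

# Fallback: the JSON blob mirrored into the group metadata.
# (That write is best-effort in the commit, so this may raise KeyError
# if the metadata step was skipped.)
with tiledb.Group(os.path.join(base_uri, "scalars", scalar)) as G:
    column_names_meta = json.loads(G.meta["column_names"])
```

Storing the names as a dense array rather than only as group metadata keeps very large cohorts out of the metadata key-value store, at the cost of one extra array per scalar.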