
Commit 1076906

Make code more Pythonic (#30)
1 parent 862a2d1 commit 1076906

24 files changed: 863 additions & 619 deletions

.circleci/config.yml

Lines changed: 0 additions & 108 deletions
This file was deleted.

.github/workflows/lint.yml

Lines changed: 2 additions & 2 deletions
@@ -24,8 +24,8 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v6
-      - run: pipx run ruff check .
-      - run: pipx run ruff format --diff .
+      - run: pipx run ruff check src test
+      - run: pipx run ruff format --diff src test

   codespell:
     name: Check for spelling errors

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -100,6 +100,7 @@ exclude = ".*"
 [tool.ruff]
 line-length = 99
 src = ["src"]
+extend-exclude = ["src/modelarrayio/__about__.py"]

 [tool.ruff.lint]
 extend-select = [
@@ -155,6 +156,7 @@ inline-quotes = "single"
 [tool.ruff.lint.extend-per-file-ignores]
 "*/test_*.py" = ["S101"]
 "docs/conf.py" = ["A001"]
+"src/modelarrayio/__about__.py" = ["Q000", "RUF022"]

 [tool.ruff.format]
 quote-style = "single"
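Both new entries target the generated version file: Q000 flags double-quoted strings when the formatter is set to single quotes, and RUF022 flags an unsorted __all__. Excluding and ignoring src/modelarrayio/__about__.py avoids hand-editing a file the build tooling regenerates. A hypothetical sketch of such a generated file (its real contents are not shown in this commit):

# Hypothetical stand-in for src/modelarrayio/__about__.py; the actual file is
# not part of this diff. Generated version files commonly use double quotes and
# a declaration-order __all__, which is exactly what Q000 and RUF022 would flag.
__all__ = ["__version_tuple__", "__version__"]

__version__ = "0.1.0"
__version_tuple__ = (0, 1, 0)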

src/modelarrayio/cli/cifti_to_h5.py

Lines changed: 66 additions & 84 deletions
@@ -1,14 +1,18 @@
 """Convert CIFTI2 dscalar data to an HDF5 file."""

+from __future__ import annotations
+
 import argparse
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path

 import h5py
 import pandas as pd
 from tqdm import tqdm

+from modelarrayio.cli import utils as cli_utils
 from modelarrayio.cli.parser_utils import (
     add_backend_arg,
     add_cohort_arg,
@@ -17,7 +21,6 @@
     add_scalar_columns_arg,
     add_storage_args,
 )
-from modelarrayio.storage import h5_storage, tiledb_storage
 from modelarrayio.utils.cifti import (
     _build_scalar_sources,
     _cohort_to_long_dataframe,
@@ -82,6 +85,7 @@ def cifti_to_h5(
     """
     cohort_df = pd.read_csv(cohort_file)
     cohort_long = _cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns)
+    output_path = Path(output)
     if cohort_long.empty:
         raise ValueError('Cohort file does not contain any scalar entries after normalization.')
     scalar_sources = _build_scalar_sources(cohort_long)
@@ -90,101 +94,82 @@ def cifti_to_h5(

     if backend == 'hdf5':
         scalars, last_brain_names = _load_cohort_cifti(cohort_long, s3_workers)
-
-        f = h5py.File(output, 'w')
-
         greyordinate_table, structure_names = brain_names_to_dataframe(last_brain_names)
-        greyordinatesh5 = f.create_dataset(
-            name='greyordinates', data=greyordinate_table.to_numpy().T
-        )
-        greyordinatesh5.attrs['column_names'] = list(greyordinate_table.columns)
-        greyordinatesh5.attrs['structure_names'] = structure_names
-
-        for scalar_name in scalars.keys():
-            num_subjects = len(scalars[scalar_name])
-            num_items = scalars[scalar_name][0].shape[0] if num_subjects > 0 else 0
-            dset = h5_storage.create_empty_scalar_matrix_dataset(
-                f,
-                f'scalars/{scalar_name}/values',
-                num_subjects,
-                num_items,
+        output_path = cli_utils.prepare_output_parent(output_path)
+        with h5py.File(output_path, 'w') as h5_file:
+            cli_utils.write_table_dataset(
+                h5_file,
+                'greyordinates',
+                greyordinate_table,
+                extra_attrs={'structure_names': structure_names},
+            )
+            cli_utils.write_hdf5_scalar_matrices(
+                h5_file,
+                scalars,
+                scalar_sources,
                 storage_dtype=storage_dtype,
                 compression=compression,
                 compression_level=compression_level,
                 shuffle=shuffle,
                 chunk_voxels=chunk_voxels,
                 target_chunk_mb=target_chunk_mb,
-                sources_list=scalar_sources[scalar_name],
             )
+        return int(not output_path.exists())

-            h5_storage.write_rows_in_column_stripes(dset, scalars[scalar_name])
-        f.close()
-        return int(not os.path.exists(output))
-    else:
-        os.makedirs(output, exist_ok=True)
-        if not scalar_sources:
-            return 0
-
-        # Establish a reference brain axis once to ensure consistent ordering across workers.
-        _first_scalar, first_sources = next(iter(scalar_sources.items()))
-        first_path = first_sources[0]
-        _, reference_brain_names = extract_cifti_scalar_data(first_path)
-
-        def _process_scalar_job(scalar_name, source_files):
-            dataset_path = f'scalars/{scalar_name}/values'
-            rows = []
-            for source_file in source_files:
-                cifti_data, _ = extract_cifti_scalar_data(
-                    source_file, reference_brain_names=reference_brain_names
-                )
-                rows.append(cifti_data)
-
-            num_subjects = len(rows)
-            if num_subjects == 0:
-                return scalar_name
-            num_items = rows[0].shape[0]
-            tiledb_storage.create_empty_scalar_matrix_array(
-                output,
-                dataset_path,
-                num_subjects,
-                num_items,
+    output_path.mkdir(parents=True, exist_ok=True)
+    if not scalar_sources:
+        return 0
+
+    _first_scalar, first_sources = next(iter(scalar_sources.items()))
+    first_path = first_sources[0]
+    _, reference_brain_names = extract_cifti_scalar_data(first_path)
+
+    def _process_scalar_job(scalar_name, source_files):
+        rows = []
+        for source_file in source_files:
+            cifti_data, _ = extract_cifti_scalar_data(
+                source_file, reference_brain_names=reference_brain_names
+            )
+            rows.append(cifti_data)
+
+        if rows:
+            cli_utils.write_tiledb_scalar_matrices(
+                output_path,
+                {scalar_name: rows},
+                {scalar_name: source_files},
                 storage_dtype=storage_dtype,
                 compression=compression,
                 compression_level=compression_level,
                 shuffle=shuffle,
-                tile_voxels=chunk_voxels,
-                target_tile_mb=target_chunk_mb,
-                sources_list=source_files,
+                chunk_voxels=chunk_voxels,
+                target_chunk_mb=target_chunk_mb,
+                write_column_name_arrays=True,
             )
-            # write column names array for ModelArray compatibility
-            tiledb_storage.write_column_names(output, scalar_name, source_files)
-            uri = os.path.join(output, dataset_path)
-            tiledb_storage.write_rows_in_column_stripes(uri, rows)
         return scalar_name

-        scalar_names = list(scalar_sources.keys())
-        worker_count = workers if isinstance(workers, int) and workers > 0 else None
-        if worker_count is None:
-            cpu_count = os.cpu_count() or 1
-            worker_count = min(len(scalar_names), max(1, cpu_count))
-        else:
-            worker_count = min(len(scalar_names), worker_count)
-
-        if worker_count <= 1:
-            for scalar_name in scalar_names:
-                _process_scalar_job(scalar_name, scalar_sources[scalar_name])
-        else:
-            desc = 'TileDB scalars'
-            with ThreadPoolExecutor(max_workers=worker_count) as executor:
-                futures = {
-                    executor.submit(
-                        _process_scalar_job, scalar_name, scalar_sources[scalar_name]
-                    ): scalar_name
-                    for scalar_name in scalar_names
-                }
-                for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
-                    future.result()
-        return 0
+    scalar_names = list(scalar_sources.keys())
+    worker_count = workers if isinstance(workers, int) and workers > 0 else None
+    if worker_count is None:
+        cpu_count = os.cpu_count() or 1
+        worker_count = min(len(scalar_names), max(1, cpu_count))
+    else:
+        worker_count = min(len(scalar_names), worker_count)
+
+    if worker_count <= 1:
+        for scalar_name in scalar_names:
+            _process_scalar_job(scalar_name, scalar_sources[scalar_name])
+    else:
+        desc = 'TileDB scalars'
+        with ThreadPoolExecutor(max_workers=worker_count) as executor:
+            futures = {
+                executor.submit(_process_scalar_job, scalar_name, scalar_sources[scalar_name]): (
+                    scalar_name
+                )
+                for scalar_name in scalar_names
+            }
+            for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
+                future.result()
+    return 0


 def cifti_to_h5_main(
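The worker-pool sizing in the TileDB branch above honors an explicit positive workers value but caps it at the number of scalar metrics, and otherwise falls back to the machine's CPU count. A standalone restatement of that rule for illustration (not part of the commit; the helper name is made up):

import os


def resolve_worker_count(workers, num_scalars):
    # A positive integer is honored but never exceeds the number of scalars;
    # anything else (None, 0, a non-int) falls back to the CPU count.
    worker_count = workers if isinstance(workers, int) and workers > 0 else None
    if worker_count is None:
        cpu_count = os.cpu_count() or 1
        worker_count = min(num_scalars, max(1, cpu_count))
    else:
        worker_count = min(num_scalars, worker_count)
    return worker_count


# e.g. resolve_worker_count(8, 3) == 3, while resolve_worker_count(None, 3)
# uses at most 3 threads even on a machine with many cores.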
@@ -203,10 +188,7 @@ def cifti_to_h5_main(
     log_level='INFO',
 ):
     """Entry point for the ``modelarrayio cifti-to-h5`` command."""
-    logging.basicConfig(
-        level=getattr(logging, str(log_level).upper(), logging.INFO),
-        format='[%(levelname)s] %(name)s: %(message)s',
-    )
+    cli_utils.configure_logging(log_level)
     return cifti_to_h5(
         cohort_file=cohort_file,
         backend=backend,
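The refactor delegates boilerplate to modelarrayio.cli.utils (imported as cli_utils), whose implementations are not included in this excerpt. A minimal sketch of what two of the helpers plausibly do, inferred only from how they are called above; treat the bodies as assumptions rather than the project's actual code:

import logging
from pathlib import Path


def configure_logging(log_level='INFO'):
    # Presumably mirrors the basicConfig call that was removed from cifti_to_h5_main.
    logging.basicConfig(
        level=getattr(logging, str(log_level).upper(), logging.INFO),
        format='[%(levelname)s] %(name)s: %(message)s',
    )


def prepare_output_parent(output_path):
    # Presumably ensures the directory that will hold the HDF5 file exists and
    # returns the path so it can feed straight into h5py.File(output_path, 'w').
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    return output_path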
