Skip to content

Commit cd49b3a

Browse files
authored
Merge pull request #10 from scitran/cgc/label-matcher
Match analysis more flexibly.
2 parents 3b40324 + 3db7d7a commit cd49b3a

7 files changed

Lines changed: 208 additions & 50 deletions

File tree

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
test:
2-
# hack to avoid needing to install this package
2+
# hack to avoid needing to install this client as a package
33
# from http://stackoverflow.com/a/34140498
4-
python -m pytest tests
4+
python -m pytest tests -v
55

66
lint:
77
flake8 examples scitran_client

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ Lint your code with
5454
make lint
5555
```
5656

57+
Test your code with
58+
```bash
59+
make test
60+
```
61+
5762
Publish a new version of the docs with
5863
```bash
5964
make publish_docs

examples/flywheel_analyzer_afq.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
import scitran_client.flywheel_analyzer as fa
2+
from scitran_client import ScitranClient
3+
4+
client = ScitranClient('https://flywheel-cni.scitran.stanford.edu')
5+
6+
7+
def prefix_matcher(prefix):
8+
# doing this funny prefix matching to catch both "afq" and "afq 2017-01-01..."
9+
return lambda val: val == prefix or val.startswith(prefix + ' ')
10+
11+
dtiinit_matcher = prefix_matcher('dtiinit')
12+
afq_matcher = prefix_matcher('afq')
213

314

415
def dtiinit_inputs(acquisitions, **kwargs):
5-
diffusion = fa.find(acquisitions, measurement='diffusion')
16+
diffusion = fa.find(acquisitions, label='DTI 2mm b1250 84dir(axial)')
617

718
return dict(
819
bvec=diffusion.find_file('*.bvec'),
@@ -12,15 +23,19 @@ def dtiinit_inputs(acquisitions, **kwargs):
1223

1324

1425
def afq_inputs(analyses, **kwargs):
15-
dtiinit = fa.find(analyses, label='dtiinit')
26+
dtiinit = fa.find(analyses, label=dtiinit_matcher)
1627

1728
return dict(
1829
dtiInit_Archive=dtiinit.find_file('dtiInit_*.zip'),
1930
)
2031

2132
if __name__ == '__main__':
22-
with fa.installed_client():
33+
with fa.installed_client(client):
2334
fa.run([
24-
fa.define_analysis('dtiinit', dtiinit_inputs),
25-
fa.define_analysis('afq', afq_inputs),
35+
fa.define_analysis(
36+
'dtiinit', dtiinit_inputs,
37+
label_matcher=dtiinit_matcher),
38+
fa.define_analysis(
39+
'afq', afq_inputs,
40+
label_matcher=afq_matcher),
2641
], project=fa.find_project(label='ENGAGE'))

examples/flywheel_analyzer_engage.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,26 @@
11
import scitran_client.flywheel_analyzer as fa
22
from scitran_client import ScitranClient
33

4+
client = ScitranClient('https://flywheel-cni.scitran.stanford.edu')
5+
with fa.installed_client(client):
6+
project = fa.find_project(label='ENGAGE')
7+
sessions = client.request('projects/{}/sessions'.format(project['_id'])).json()
8+
session_by_subject = {}
9+
second_to_first_visit_id = {}
10+
for s in sessions:
11+
subject = s['subject']['code'][:7].upper()
12+
session_by_subject.setdefault(subject, []).append(s)
13+
for subject, subject_sessions in session_by_subject.iteritems():
14+
# we need at least two sessions
15+
if len(subject_sessions) < 2:
16+
continue
17+
subject_sessions.sort(key=lambda s: s['timestamp'])
18+
# HACK this is a bit of a heuristic. this will certainly fail
19+
# for some folks that skipped a BV, or folks that missed the 2mo
20+
# but hopefully, those folks will otherwise have data that is just
21+
# fine.
22+
second_to_first_visit_id[subject_sessions[1]['_id']] = subject_sessions[0]['_id']
23+
424

525
# XXX at least make this be just the first thing without ' 2'?
626
label_to_task_type = {
@@ -39,18 +59,28 @@ def define_analysis(gear_name, acquisition_label, create_inputs):
3959
label=analysis_label(gear_name, acquisition_label))
4060

4161

42-
def reactivity_inputs(acquisition_label, acquisitions, **kwargs):
43-
functional = fa.find(acquisitions, label=acquisition_label)
62+
def reactivity_inputs(acquisition_label, acquisitions, session, **kwargs):
63+
functional = fa.find_required_input_source(acquisitions, label=acquisition_label)
64+
# using plain find() here b/c this T1w might be missing
4465
structural = fa.find(acquisitions, label='T1w 1mm')
66+
if not structural:
67+
assert session['_id'] in second_to_first_visit_id,\
68+
'the only sessions that should be missing T1w are second visits. {} was missing a T1w'\
69+
.format(session['_id'])
70+
first_visit_session_id = second_to_first_visit_id[session['_id']]
71+
first_visit_acquisitions = client.request(
72+
'sessions/{}/acquisitions'.format(first_visit_session_id)).json()
73+
structural = fa.find(first_visit_acquisitions, label='T1w 1mm')
74+
assert structural, 'Session {} is missing a structural.'.format(session['_id'])
4575

4676
return dict(
4777
functional=functional.find_file('*.nii.gz'),
4878
structural=structural.find_file('*.nii.gz'),
4979
)
5080

5181

52-
def connectivity_inputs(acquisition_label, analyses, acquisitions):
53-
reactivity = fa.find(
82+
def connectivity_inputs(acquisition_label, analyses, **kwargs):
83+
reactivity = fa.find_required_input_source(
5484
analyses, label=analysis_label('reactivity-preprocessing', acquisition_label))
5585

5686
return dict(
@@ -60,12 +90,12 @@ def connectivity_inputs(acquisition_label, analyses, acquisitions):
6090
)
6191

6292

63-
def first_level_model_inputs(acquisition_label, analyses, acquisitions):
64-
reactivity = fa.find(
93+
def first_level_model_inputs(acquisition_label, analyses, acquisitions, **kwargs):
94+
reactivity = fa.find_required_input_source(
6595
analyses, label=analysis_label('reactivity-preprocessing', acquisition_label))
66-
connectivity = fa.find(
96+
connectivity = fa.find_required_input_source(
6797
analyses, label=analysis_label('connectivity-preprocessing', acquisition_label))
68-
behavioral = fa.find(
98+
behavioral = fa.find_required_input_source(
6999
acquisitions, label='Behavioral and Physiological')
70100

71101
return dict(
@@ -81,7 +111,7 @@ def first_level_model_inputs(acquisition_label, analyses, acquisitions):
81111
), dict(task_type=label_to_task_type[acquisition_label])
82112

83113
if __name__ == '__main__':
84-
with fa.installed_client(ScitranClient('https://flywheel-cni.scitran.stanford.edu')):
114+
with fa.installed_client(client):
85115
fa.run([
86116
define_analysis('reactivity-preprocessing', 'go-no-go 2', reactivity_inputs),
87117
define_analysis('connectivity-preprocessing', 'go-no-go 2', connectivity_inputs),
@@ -94,4 +124,8 @@ def first_level_model_inputs(acquisition_label, analyses, acquisitions):
94124
define_analysis('reactivity-preprocessing', 'nonconscious 2', reactivity_inputs),
95125
define_analysis('connectivity-preprocessing', 'nonconscious 2', connectivity_inputs),
96126
define_analysis('first-level-models', 'nonconscious 2', first_level_model_inputs),
97-
], project=fa.find_project(label='ENGAGE'), session_limit=1)
127+
128+
define_analysis('reactivity-preprocessing', 'EmoReg', reactivity_inputs),
129+
define_analysis('connectivity-preprocessing', 'EmoReg', connectivity_inputs),
130+
# define_analysis('first-level-models', 'EmoReg', first_level_model_inputs),
131+
], project=project)

scitran_client/flywheel_analyzer.py

Lines changed: 84 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from concurrent.futures import ThreadPoolExecutor, CancelledError
66
import traceback
77
from fnmatch import fnmatch
8-
from collections import namedtuple, Counter
8+
from collections import namedtuple
99
import math
1010
from contextlib import contextmanager
11+
import os
1112

1213

1314
def _sleep(seconds):
@@ -23,10 +24,10 @@ def _sleep(seconds):
2324

2425

2526
FlywheelAnalysisOperation = namedtuple('FlywheelAnalysisOperation', [
26-
'gear_name', 'create_inputs', 'label'])
27+
'gear_name', 'create_inputs', 'label', 'label_matcher'])
2728

2829

29-
def define_analysis(gear_name, create_inputs, label=None):
30+
def define_analysis(gear_name, create_inputs, label=None, label_matcher=None):
3031
'''Defines an analysis operation that can be passed to run(...).
3132
3233
An analysis operation has a gear name, label (which defaults to
@@ -38,7 +39,10 @@ def define_analysis(gear_name, create_inputs, label=None):
3839
inputs (to override the default config).
3940
'''
4041
label = label or gear_name
41-
return FlywheelAnalysisOperation(gear_name, create_inputs, label)
42+
label_matcher = label_matcher or label
43+
assert find([dict(label=label)], label=label_matcher),\
44+
'Label matcher for operation {} does not detect this operation.'.format(label)
45+
return FlywheelAnalysisOperation(gear_name, create_inputs, label, label_matcher)
4246

4347

4448
class FlywheelFileContainer(dict):
@@ -105,11 +109,24 @@ def find(items, _constructor_=FlywheelFileContainer, **kwargs):
105109
# TODO make this have better errors messages for missing files
106110
result = next((
107111
item for item in items
108-
if all(item[k] == v for k, v in kwargs.iteritems())
112+
if all(
113+
v(item[k]) if callable(v) else item[k] == v
114+
for k, v in kwargs.iteritems()
115+
)
109116
), None)
110117
return result and _constructor_(result)
111118

112119

120+
def find_required_input_source(items, **kwargs):
121+
'''Finds a match to `kwargs` in `items` by using `find()`. If this match is not
122+
found, the current operation will be skipped.
123+
'''
124+
result = find(items, **kwargs)
125+
if not result:
126+
raise SkipOperation('could not find match to {}'.format(kwargs))
127+
return result
128+
129+
113130
def find_project(**kwargs):
114131
'''Finds a project that matches the key, value pairs in `kwargs`.
115132
@@ -123,6 +140,19 @@ class ShuttingDownException(Exception):
123140
shutting_down = False
124141

125142

143+
class SkipOperation(Exception):
144+
'''
145+
SkipOperation can be thrown from a `create_inputs` function to skip the execution of that
146+
operation. This is a way to more dynamically create operation graphs by discarding nodes
147+
at runtime.
148+
149+
For example, if every session has a variable number of functional acquisitions that need to be
150+
processed, you can define operations for the max number of per-session functional acquisitions,
151+
and throw SkipOperation for all operations corresponding to acquisitions missing for a session.
152+
'''
153+
pass
154+
155+
126156
def request(*args, **kwargs):
127157
# HACK client is a module variable for now. In the future, we should pass client around.
128158
assert 'client' in state, 'client must be installed in state before using request. See `installed_client`.'
@@ -187,8 +217,8 @@ def _analyze_session(operations, gears_by_name, session):
187217
acquisitions = None
188218
session_id = session['_id']
189219
analyses = _get_analyses(session_id)
190-
for gear_name, create_inputs, label in operations:
191-
analysis = find(analyses, label=label)
220+
for gear_name, create_inputs, label, label_matcher in operations:
221+
analysis = find(analyses, label=label_matcher)
192222

193223
# skip this analysis if we've already done it
194224
if analysis and analysis['job']['state'] == 'complete':
@@ -201,7 +231,12 @@ def _analyze_session(operations, gears_by_name, session):
201231
# have completed analysis
202232
if not acquisitions:
203233
acquisitions = request('sessions/{}/acquisitions'.format(session_id))
204-
job_inputs = create_inputs(analyses=analyses, acquisitions=acquisitions)
234+
try:
235+
job_inputs = create_inputs(analyses=analyses, acquisitions=acquisitions, session=session)
236+
except SkipOperation:
237+
# we skip to the next operation
238+
continue
239+
205240
job_config = _defaults_for_gear(gears_by_name[gear_name])
206241

207242
# When create_inputs returns a tuple, we unpack it into job_inputs and job_config.
@@ -211,7 +246,7 @@ def _analyze_session(operations, gears_by_name, session):
211246
job_inputs, job_config = job_inputs[0], dict(job_config, **job_inputs[1])
212247
_submit_analysis(session_id, gear_name, job_inputs, job_config, label)
213248

214-
analyses = _wait_for_analysis(session_id, label)
249+
analyses = _wait_for_analysis(session_id, label_matcher)
215250
print(session_id, 'all analysis complete')
216251

217252

@@ -225,6 +260,7 @@ def done(f):
225260
except (ShuttingDownException, CancelledError):
226261
pass
227262
except Exception:
263+
print('error with {}'.format(f.name))
228264
traceback.print_exc()
229265

230266
for future in futures:
@@ -270,13 +306,23 @@ def run(operations, project=None, max_workers=10, session_limit=None):
270306
will use and how many CPUs you can use from your Flywheel Engine instance.
271307
session_limit - Used to test pipelines out by limiting the number of sessions
272308
the pipeline code will run on.
309+
310+
Enabling status mode - By setting the environment variable
311+
FLYWHEEL_ANALYZER_STATUS to `true`, this method will only print the status
312+
of this pipeline. It will not run anything.
273313
"""
274314
gears = [g['gear'] for g in request('gears', params=dict(fields='all'))]
275315
gears_by_name = {
276316
gear['name']: gear
277317
for gear in gears
278318
}
279319

320+
# HACK this is seriously a total hack, but is a nice way to see the status
321+
# of a pipeline without editing code.
322+
if os.environ.get('FLYWHEEL_ANALYZER_STATUS', '').lower() == 'true':
323+
status(operations, project)
324+
return
325+
280326
for operation in operations:
281327
assert operation.gear_name in gears_by_name,\
282328
'operation(name={}, label={}) has an invalid name.'.format(
@@ -295,28 +341,37 @@ def run(operations, project=None, max_workers=10, session_limit=None):
295341
sessions = sessions[:session_limit]
296342

297343
with ThreadPoolExecutor(max_workers=max_workers) as executor:
298-
futures = [
299-
executor.submit(_analyze_session, operations, gears_by_name, session)
300-
for session in sessions
301-
]
344+
futures = []
345+
for session in sessions:
346+
f = executor.submit(_analyze_session, operations, gears_by_name, session)
347+
f.name = 'session {}'.format(session['_id'])
348+
futures.append(f)
302349
_wait_for_futures(futures)
303350

304351

305-
def _session_status(expected_ops, session):
352+
def _session_status(operations, session):
306353
analyses = _get_analyses(session['_id'])
307354

308-
started_ops = {
309-
a['label'] for a in analyses}
355+
started_ops = set()
356+
completed_ops = set()
357+
expected_ops = set()
358+
359+
for op in operations:
360+
a = find(analyses, label=op.label_matcher)
361+
if a:
362+
started_ops.add(op.label)
363+
if a['job']['state'] == 'complete':
364+
completed_ops.add(op.label)
365+
expected_ops.add(op.label)
366+
310367
if not started_ops:
311368
return 'not started'
312369

313-
completed_ops = {
314-
a['label'] for a in analyses
315-
if a['job']['state'] == 'complete'}
316370
if completed_ops == expected_ops:
317371
return 'complete'
318372
else:
319-
return 'in progress'
373+
return 'in progress ({} of {} done)'.format(
374+
len(completed_ops), len(expected_ops))
320375

321376

322377
def status(operations, project=None, detail=False):
@@ -325,13 +380,12 @@ def status(operations, project=None, detail=False):
325380
detail - When true, some session IDs for each status are logged.
326381
'''
327382
sessions = request('projects/{}/sessions'.format(project['_id']))
328-
expected_ops = {op.label for op in operations}
329-
statuses = [(s, _session_status(expected_ops, s)) for s in sessions]
330-
if detail:
331-
result = {}
332-
for sess, stat in statuses:
333-
result.setdefault(stat, []).append(sess['_id'])
334-
for stat, session_ids in result.iteritems():
335-
print(stat, len(session_ids), 'some IDs:', session_ids[:4])
336-
else:
337-
print(Counter(stat for _, stat in statuses))
383+
statuses = [(s, _session_status(operations, s)) for s in sessions]
384+
result = {}
385+
for sess, stat in statuses:
386+
result.setdefault(stat, []).append(sess['_id'])
387+
for stat, session_ids in sorted(result.iteritems()):
388+
msg = []
389+
if detail:
390+
msg = ['some IDs:', session_ids[:4]]
391+
print(len(session_ids), stat, *msg)

scitran_client/st_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,12 +278,12 @@ def download_file(
278278
desc = tqdm_kwargs.pop('desc', file_name)
279279
leave = tqdm_kwargs.pop('leave', False)
280280
with open(abs_file_path, 'wb') as fd:
281-
content = response.iter_content()
281+
content = response.iter_content(4096)
282282
if not tqdm_disable:
283283
content = tqdm(
284-
response.iter_content(),
284+
content,
285285
desc=desc, leave=leave,
286-
unit_scale=True, unit='B',
286+
unit=' 4KB',
287287
**tqdm_kwargs
288288
)
289289
for chunk in content:

0 commit comments

Comments
 (0)