55from concurrent .futures import ThreadPoolExecutor , CancelledError
66import traceback
77from fnmatch import fnmatch
8- from collections import namedtuple , Counter
8+ from collections import namedtuple
99import math
1010from contextlib import contextmanager
11+ import os
1112
1213
1314def _sleep (seconds ):
@@ -23,10 +24,10 @@ def _sleep(seconds):
2324
2425
# One analysis to run: which gear, how to build its job inputs, the label
# to submit under, and the label (or callable matcher) used by find() to
# detect an existing run of this operation.
FlywheelAnalysisOperation = namedtuple(
    'FlywheelAnalysisOperation',
    ['gear_name', 'create_inputs', 'label', 'label_matcher'],
)
2728
2829
def define_analysis(gear_name, create_inputs, label=None, label_matcher=None):
    '''Defines an analysis operation that can be passed to run(...).

    Parameters:
        gear_name - name of the Flywheel gear to run.
        create_inputs - callable invoked with keyword arguments
            (analyses, acquisitions, session); returns the job inputs, or
            a (job_inputs, config_overrides) tuple to override the gear's
            default config. May raise SkipOperation to skip the operation.
        label - label the analysis is submitted under; defaults to
            gear_name.
        label_matcher - value (or callable predicate) passed to find() to
            detect an existing analysis for this operation; defaults to
            label.

    Raises:
        ValueError - if label_matcher would not match an analysis
            submitted under `label`, since such an operation could never
            be detected as already run.
    '''
    label = label or gear_name
    label_matcher = label_matcher or label
    # Validate eagerly with a real exception (not `assert`, which is
    # stripped under python -O) that the matcher recognizes the label
    # this operation will be submitted with.
    if not find([dict(label=label)], label=label_matcher):
        raise ValueError(
            'Label matcher for operation {} does not detect this operation.'.format(label))
    return FlywheelAnalysisOperation(gear_name, create_inputs, label, label_matcher)
4246
4347
4448class FlywheelFileContainer (dict ):
@@ -105,11 +109,24 @@ def find(items, _constructor_=FlywheelFileContainer, **kwargs):
105109 # TODO make this have better errors messages for missing files
106110 result = next ((
107111 item for item in items
108- if all (item [k ] == v for k , v in kwargs .iteritems ())
112+ if all (
113+ v (item [k ]) if callable (v ) else item [k ] == v
114+ for k , v in kwargs .iteritems ()
115+ )
109116 ), None )
110117 return result and _constructor_ (result )
111118
112119
def find_required_input_source(items, **kwargs):
    '''Like find(), but the match is mandatory: when no item in `items`
    matches `kwargs`, raises SkipOperation so the current operation is
    skipped rather than submitted with incomplete inputs.
    '''
    match = find(items, **kwargs)
    if not match:
        raise SkipOperation('could not find match to {}'.format(kwargs))
    return match
128+
129+
113130def find_project (** kwargs ):
114131 '''Finds a project that matches the key, value pairs in `kwargs`.
115132
@@ -123,6 +140,19 @@ class ShuttingDownException(Exception):
123140 shutting_down = False
124141
125142
class SkipOperation(Exception):
    '''Raised from a `create_inputs` function to skip running that operation.

    This allows operation graphs to be built dynamically by discarding
    nodes at runtime. For example, when each session has a variable number
    of functional acquisitions to process, define operations for the
    maximum per-session count and raise SkipOperation from the operations
    whose acquisitions are missing for a given session.
    '''
154+
155+
126156def request (* args , ** kwargs ):
127157 # HACK client is a module variable for now. In the future, we should pass client around.
128158 assert 'client' in state , 'client must be installed in state before using request. See `installed_client`.'
@@ -187,8 +217,8 @@ def _analyze_session(operations, gears_by_name, session):
187217 acquisitions = None
188218 session_id = session ['_id' ]
189219 analyses = _get_analyses (session_id )
190- for gear_name , create_inputs , label in operations :
191- analysis = find (analyses , label = label )
220+ for gear_name , create_inputs , label , label_matcher in operations :
221+ analysis = find (analyses , label = label_matcher )
192222
193223 # skip this analysis if we've already done it
194224 if analysis and analysis ['job' ]['state' ] == 'complete' :
@@ -201,7 +231,12 @@ def _analyze_session(operations, gears_by_name, session):
201231 # have completed analysis
202232 if not acquisitions :
203233 acquisitions = request ('sessions/{}/acquisitions' .format (session_id ))
204- job_inputs = create_inputs (analyses = analyses , acquisitions = acquisitions )
234+ try :
235+ job_inputs = create_inputs (analyses = analyses , acquisitions = acquisitions , session = session )
236+ except SkipOperation :
237+ # we skip to the next operation
238+ continue
239+
205240 job_config = _defaults_for_gear (gears_by_name [gear_name ])
206241
207242 # When create_inputs returns a tuple, we unpack it into job_inputs and job_config.
@@ -211,7 +246,7 @@ def _analyze_session(operations, gears_by_name, session):
211246 job_inputs , job_config = job_inputs [0 ], dict (job_config , ** job_inputs [1 ])
212247 _submit_analysis (session_id , gear_name , job_inputs , job_config , label )
213248
214- analyses = _wait_for_analysis (session_id , label )
249+ analyses = _wait_for_analysis (session_id , label_matcher )
215250 print (session_id , 'all analysis complete' )
216251
217252
@@ -225,6 +260,7 @@ def done(f):
225260 except (ShuttingDownException , CancelledError ):
226261 pass
227262 except Exception :
263+ print ('error with {}' .format (f .name ))
228264 traceback .print_exc ()
229265
230266 for future in futures :
@@ -270,13 +306,23 @@ def run(operations, project=None, max_workers=10, session_limit=None):
270306 will use and how many CPUs you can use from your Flywheel Engine instance.
271307 session_limit - Used to test pipelines out by limiting the number of sessions
272308 the pipeline code will run on.
309+
310+ Enabling status mode - By setting the environment variable
311+ FLYWHEEL_ANALYZER_STATUS to `true`, this method will only print the status
312+ of this pipeline. It will not run anything.
273313 """
274314 gears = [g ['gear' ] for g in request ('gears' , params = dict (fields = 'all' ))]
275315 gears_by_name = {
276316 gear ['name' ]: gear
277317 for gear in gears
278318 }
279319
320+ # HACK this is seriously a total hack, but is a nice way to see the status
321+ # of a pipeline without editing code.
322+ if os .environ .get ('FLYWHEEL_ANALYZER_STATUS' , '' ).lower () == 'true' :
323+ status (operations , project )
324+ return
325+
280326 for operation in operations :
281327 assert operation .gear_name in gears_by_name ,\
282328 'operation(name={}, label={}) has an invalid name.' .format (
@@ -295,28 +341,37 @@ def run(operations, project=None, max_workers=10, session_limit=None):
295341 sessions = sessions [:session_limit ]
296342
297343 with ThreadPoolExecutor (max_workers = max_workers ) as executor :
298- futures = [
299- executor .submit (_analyze_session , operations , gears_by_name , session )
300- for session in sessions
301- ]
344+ futures = []
345+ for session in sessions :
346+ f = executor .submit (_analyze_session , operations , gears_by_name , session )
347+ f .name = 'session {}' .format (session ['_id' ])
348+ futures .append (f )
302349 _wait_for_futures (futures )
303350
304351
def _session_status(operations, session):
    '''Summarizes how far `session` has progressed through `operations`.

    Returns 'not started' when no operation has a matching analysis,
    'complete' when every operation has a completed analysis, and an
    'in progress (x of y done)' message otherwise.
    '''
    analyses = _get_analyses(session['_id'])

    started = set()
    completed = set()
    expected = set()
    for op in operations:
        expected.add(op.label)
        match = find(analyses, label=op.label_matcher)
        if not match:
            continue
        started.add(op.label)
        if match['job']['state'] == 'complete':
            completed.add(op.label)

    if not started:
        return 'not started'
    if completed == expected:
        return 'complete'
    return 'in progress ({} of {} done)'.format(len(completed), len(expected))
320375
321376
def status(operations, project=None, detail=False):
    '''Prints, per status bucket, how many of the project's sessions are
    in that state with respect to `operations`.

    detail - When true, some session IDs for each status are logged.
    '''
    sessions = request('projects/{}/sessions'.format(project['_id']))
    by_status = {}
    for sess in sessions:
        stat = _session_status(operations, sess)
        by_status.setdefault(stat, []).append(sess['_id'])
    for stat, session_ids in sorted(by_status.iteritems()):
        extra = ['some IDs:', session_ids[:4]] if detail else []
        print(len(session_ids), stat, *extra)
0 commit comments