@@ -132,7 +132,9 @@ def __call__(self, m, f=None):
132132
133133 values = []
134134 count = 0
135- for futures in self ._futures :
135+ for ind , futures in enumerate (self ._futures ):
136+
137+ priority = len (self ._futures ) - ind # reverse order for priority
136138 for objfct , worker in zip (futures , self ._workers , strict = True ):
137139
138140 if self .multipliers [count ] == 0.0 :
@@ -145,6 +147,7 @@ def __call__(self, m, f=None):
145147 self .multipliers [count ],
146148 m_future ,
147149 workers = worker ,
150+ priority = priority ,
148151 )
149152 )
150153 count += 1
@@ -193,11 +196,12 @@ def deriv(self, m, f=None):
193196 client = self .client
194197 m_future = self ._m_as_future
195198
196- derivs = 0.0
197199 count = 0
200+ future_deriv = []
201+ for ind , futures in enumerate (self ._futures ):
202+
203+ priority = len (self ._futures ) - ind # reverse order for priority
198204
199- for futures in self ._futures :
200- future_deriv = []
201205 for objfct , worker in zip (futures , self ._workers ):
202206 if self .multipliers [count ] == 0.0 : # don't evaluate the fct
203207 continue
@@ -209,13 +213,14 @@ def deriv(self, m, f=None):
209213 self .multipliers [count ],
210214 m_future ,
211215 workers = worker ,
216+ priority = priority ,
212217 )
213218 )
214219
215220 count += 1
216- future_deriv = client .gather (future_deriv )
221+ future_deriv = client .gather (future_deriv )
217222
218- derivs + = np .sum (future_deriv , axis = 0 )
223+ derivs = np .sum (future_deriv , axis = 0 )
219224
220225 return derivs
221226
@@ -234,12 +239,11 @@ def deriv2(self, m, v=None, f=None):
234239 m_future = self ._m_as_future
235240 [v_future ] = client .scatter ([v ], broadcast = True )
236241
237- derivs = 0.0
238242 count = 0
243+ future_derivs = []
244+ for ind , futures in enumerate (self ._futures ):
239245
240- for futures in self ._futures :
241-
242- future_derivs = []
246+ priority = len (self ._futures ) - ind # reverse order for priority
243247 for objfct , worker in zip (futures , self ._workers ):
244248 if self .multipliers [count ] == 0.0 : # don't evaluate the fct
245249 continue
@@ -253,12 +257,13 @@ def deriv2(self, m, v=None, f=None):
253257 v_future ,
254258 # field,
255259 workers = worker ,
260+ priority = priority ,
256261 )
257262 )
258263 count += 1
259264
260- future_derivs = self .client .gather (future_derivs )
261- derivs + = np .sum (future_derivs , axis = 0 )
265+ future_derivs = self .client .gather (future_derivs )
266+ derivs = np .sum (future_derivs , axis = 0 )
262267
263268 return derivs
264269
@@ -270,20 +275,22 @@ def get_dpred(self, m, f=None):
270275
271276 client = self .client
272277 m_future = self ._m_as_future
273- dpred = []
278+ # dpred = []
279+ future_preds = []
280+ for ind , futures in enumerate (self ._futures ):
274281
275- for futures in self ._futures :
276- future_preds = []
282+ priority = len (self ._futures ) - ind # reverse order for priority
277283 for objfct , worker in zip (futures , self ._workers ):
278284 future_preds .append (
279285 client .submit (
280286 _calc_dpred ,
281287 objfct ,
282288 m_future ,
283289 workers = worker ,
290+ priority = priority ,
284291 )
285292 )
286- dpred + = client .gather (future_preds )
293+ dpred = client .gather (future_preds )
287294
288295 return dpred
289296
@@ -294,25 +301,24 @@ def getJtJdiag(self, m, f=None):
294301 self .model = m
295302 m_future = self ._m_as_future
296303 if getattr (self , "_jtjdiag" , None ) is None :
297-
298- jtj_diag = 0.0
299304 client = self .client
305+ work = []
306+ for ind , futures in enumerate (self ._futures ):
300307
301- for futures in self ._futures :
302- work = []
303-
308+ priority = len (self ._futures ) - ind # reverse order for priority
304309 for objfct , worker in zip (futures , self ._workers ):
305310 work .append (
306311 client .submit (
307312 _get_jtj_diag ,
308313 objfct ,
309314 m_future ,
310315 workers = worker ,
316+ priority = priority ,
311317 )
312318 )
313319
314- work = client .gather (work )
315- jtj_diag + = np .sum (work , axis = 0 )
320+ work = client .gather (work )
321+ jtj_diag = np .sum (work , axis = 0 )
316322
317323 self ._jtjdiag = jtj_diag
318324
@@ -332,15 +338,18 @@ def fields(self, m):
332338 # The above should pass the model to all the internal simulations.
333339 f = []
334340
335- for futures in self ._futures :
336- f .append ([])
341+ for ind , futures in enumerate (self ._futures ):
342+
343+ priority = len (self ._futures ) - ind # reverse order for priority
344+
337345 for objfct , worker in zip (futures , self ._workers ):
338- f [ - 1 ] .append (
346+ f .append (
339347 client .submit (
340348 _calc_fields ,
341349 objfct ,
342350 m_future ,
343351 workers = worker ,
352+ priority = priority ,
344353 )
345354 )
346355 self ._stashed_fields = f
@@ -367,14 +376,17 @@ def model(self, value):
367376 [self ._m_as_future ] = client .scatter ([value ], broadcast = True )
368377
369378 stores = []
370- for futures in self ._futures :
379+ for ind , futures in enumerate (self ._futures ):
380+
381+ priority = len (self ._futures ) - ind # reverse order for priority
371382 for objfct , worker in zip (futures , self ._workers ):
372383 stores .append (
373384 client .submit (
374385 _store_model ,
375386 objfct ,
376387 self ._m_as_future ,
377388 workers = worker ,
389+ priority = priority ,
378390 )
379391 )
380392 self .client .gather (stores ) # blocking call to ensure all models were stored
@@ -418,19 +430,22 @@ def residuals(self, m, f=None):
418430 client = self .client
419431 m_future = self ._m_as_future
420432 residuals = []
433+ future_residuals = []
434+ for ind , futures in enumerate (self ._futures ):
435+
436+ priority = len (self ._futures ) - ind # reverse order for priority
421437
422- for futures in self ._futures :
423- future_residuals = []
424438 for objfct , worker in zip (futures , self ._workers ):
425439 future_residuals .append (
426440 client .submit (
427441 _calc_residual ,
428442 objfct ,
429443 m_future ,
430444 workers = worker ,
445+ priority = priority ,
431446 )
432447 )
433- residuals + = client .gather (future_residuals )
448+ residuals = client .gather (future_residuals )
434449
435450 return residuals
436451
0 commit comments