Skip to content

Commit 13bf9ce

Browse files
committed
Use priority to send queue all jobs on workers
1 parent 74c3cb8 commit 13bf9ce

1 file changed

Lines changed: 45 additions & 30 deletions

File tree

simpeg/dask/objective_function.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,9 @@ def __call__(self, m, f=None):
132132

133133
values = []
134134
count = 0
135-
for futures in self._futures:
135+
for ind, futures in enumerate(self._futures):
136+
137+
priority = len(self._futures) - ind # reverse order for priority
136138
for objfct, worker in zip(futures, self._workers, strict=True):
137139

138140
if self.multipliers[count] == 0.0:
@@ -145,6 +147,7 @@ def __call__(self, m, f=None):
145147
self.multipliers[count],
146148
m_future,
147149
workers=worker,
150+
priority=priority,
148151
)
149152
)
150153
count += 1
@@ -193,11 +196,12 @@ def deriv(self, m, f=None):
193196
client = self.client
194197
m_future = self._m_as_future
195198

196-
derivs = 0.0
197199
count = 0
200+
future_deriv = []
201+
for ind, futures in enumerate(self._futures):
202+
203+
priority = len(self._futures) - ind # reverse order for priority
198204

199-
for futures in self._futures:
200-
future_deriv = []
201205
for objfct, worker in zip(futures, self._workers):
202206
if self.multipliers[count] == 0.0: # don't evaluate the fct
203207
continue
@@ -209,13 +213,14 @@ def deriv(self, m, f=None):
209213
self.multipliers[count],
210214
m_future,
211215
workers=worker,
216+
priority=priority,
212217
)
213218
)
214219

215220
count += 1
216-
future_deriv = client.gather(future_deriv)
221+
future_deriv = client.gather(future_deriv)
217222

218-
derivs += np.sum(future_deriv, axis=0)
223+
derivs = np.sum(future_deriv, axis=0)
219224

220225
return derivs
221226

@@ -234,12 +239,11 @@ def deriv2(self, m, v=None, f=None):
234239
m_future = self._m_as_future
235240
[v_future] = client.scatter([v], broadcast=True)
236241

237-
derivs = 0.0
238242
count = 0
243+
future_derivs = []
244+
for ind, futures in enumerate(self._futures):
239245

240-
for futures in self._futures:
241-
242-
future_derivs = []
246+
priority = len(self._futures) - ind # reverse order for priority
243247
for objfct, worker in zip(futures, self._workers):
244248
if self.multipliers[count] == 0.0: # don't evaluate the fct
245249
continue
@@ -253,12 +257,13 @@ def deriv2(self, m, v=None, f=None):
253257
v_future,
254258
# field,
255259
workers=worker,
260+
priority=priority,
256261
)
257262
)
258263
count += 1
259264

260-
future_derivs = self.client.gather(future_derivs)
261-
derivs += np.sum(future_derivs, axis=0)
265+
future_derivs = self.client.gather(future_derivs)
266+
derivs = np.sum(future_derivs, axis=0)
262267

263268
return derivs
264269

@@ -270,20 +275,22 @@ def get_dpred(self, m, f=None):
270275

271276
client = self.client
272277
m_future = self._m_as_future
273-
dpred = []
278+
# dpred = []
279+
future_preds = []
280+
for ind, futures in enumerate(self._futures):
274281

275-
for futures in self._futures:
276-
future_preds = []
282+
priority = len(self._futures) - ind # reverse order for priority
277283
for objfct, worker in zip(futures, self._workers):
278284
future_preds.append(
279285
client.submit(
280286
_calc_dpred,
281287
objfct,
282288
m_future,
283289
workers=worker,
290+
priority=priority,
284291
)
285292
)
286-
dpred += client.gather(future_preds)
293+
dpred = client.gather(future_preds)
287294

288295
return dpred
289296

@@ -294,25 +301,24 @@ def getJtJdiag(self, m, f=None):
294301
self.model = m
295302
m_future = self._m_as_future
296303
if getattr(self, "_jtjdiag", None) is None:
297-
298-
jtj_diag = 0.0
299304
client = self.client
305+
work = []
306+
for ind, futures in enumerate(self._futures):
300307

301-
for futures in self._futures:
302-
work = []
303-
308+
priority = len(self._futures) - ind # reverse order for priority
304309
for objfct, worker in zip(futures, self._workers):
305310
work.append(
306311
client.submit(
307312
_get_jtj_diag,
308313
objfct,
309314
m_future,
310315
workers=worker,
316+
priority=priority,
311317
)
312318
)
313319

314-
work = client.gather(work)
315-
jtj_diag += np.sum(work, axis=0)
320+
work = client.gather(work)
321+
jtj_diag = np.sum(work, axis=0)
316322

317323
self._jtjdiag = jtj_diag
318324

@@ -332,15 +338,18 @@ def fields(self, m):
332338
# The above should pass the model to all the internal simulations.
333339
f = []
334340

335-
for futures in self._futures:
336-
f.append([])
341+
for ind, futures in enumerate(self._futures):
342+
343+
priority = len(self._futures) - ind # reverse order for priority
344+
337345
for objfct, worker in zip(futures, self._workers):
338-
f[-1].append(
346+
f.append(
339347
client.submit(
340348
_calc_fields,
341349
objfct,
342350
m_future,
343351
workers=worker,
352+
priority=priority,
344353
)
345354
)
346355
self._stashed_fields = f
@@ -367,14 +376,17 @@ def model(self, value):
367376
[self._m_as_future] = client.scatter([value], broadcast=True)
368377

369378
stores = []
370-
for futures in self._futures:
379+
for ind, futures in enumerate(self._futures):
380+
381+
priority = len(self._futures) - ind # reverse order for priority
371382
for objfct, worker in zip(futures, self._workers):
372383
stores.append(
373384
client.submit(
374385
_store_model,
375386
objfct,
376387
self._m_as_future,
377388
workers=worker,
389+
priority=priority,
378390
)
379391
)
380392
self.client.gather(stores) # blocking call to ensure all models were stored
@@ -418,19 +430,22 @@ def residuals(self, m, f=None):
418430
client = self.client
419431
m_future = self._m_as_future
420432
residuals = []
433+
future_residuals = []
434+
for ind, futures in enumerate(self._futures):
435+
436+
priority = len(self._futures) - ind # reverse order for priority
421437

422-
for futures in self._futures:
423-
future_residuals = []
424438
for objfct, worker in zip(futures, self._workers):
425439
future_residuals.append(
426440
client.submit(
427441
_calc_residual,
428442
objfct,
429443
m_future,
430444
workers=worker,
445+
priority=priority,
431446
)
432447
)
433-
residuals += client.gather(future_residuals)
448+
residuals = client.gather(future_residuals)
434449

435450
return residuals
436451

0 commit comments

Comments
 (0)