Skip to content

Commit 825b0b7

Browse files
authored
Merge pull request #119 from MiraGeoscience/GEOPY-2466
GEOPY-2466: Parallelize the Sweeps
2 parents 6e8d69f + 3f33a28 commit 825b0b7

6 files changed

Lines changed: 112 additions & 122 deletions

File tree

simpeg/dask/electromagnetics/frequency_domain/simulation.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import scipy.sparse as sp
88

99
from dask import array, compute, delayed
10-
from dask.distributed import get_client
1110
from simpeg.dask.utils import get_parallel_blocks
1211
from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary
1312
import zarr
@@ -104,20 +103,17 @@ def getSourceTerm(self, freq, source=None):
104103

105104
if source is None:
106105

107-
try:
108-
client = get_client()
109-
sim = client.scatter(self, workers=self.worker)
110-
except ValueError:
111-
client = None
112-
sim = self
113-
106+
client, worker = self._get_client_worker()
114107
source_list = self.survey.get_sources_by_frequency(freq)
115108
source_blocks = np.array_split(
116-
np.arange(len(source_list)), self.n_threads(client=client)
109+
np.arange(len(source_list)), self.n_threads(client=client, worker=worker)
117110
)
118111

119112
if client:
120-
source_list = client.scatter(source_list, workers=self.worker)
113+
sim = client.scatter(self, workers=self.worker)
114+
source_list = client.scatter(source_list, workers=worker)
115+
else:
116+
sim = self
121117

122118
block_compute = []
123119

@@ -127,9 +123,7 @@ def getSourceTerm(self, freq, source=None):
127123

128124
if client:
129125
block_compute.append(
130-
client.submit(
131-
source_eval, sim, source_list, block, workers=self.worker
132-
)
126+
client.submit(source_eval, sim, source_list, block, workers=worker)
133127
)
134128
else:
135129
block_compute.append(source_eval(sim, source_list, block))
@@ -221,12 +215,7 @@ def compute_J(self, m, f=None):
221215
fields_array = f[:, self._solutionType]
222216
blocks_receiver_derivs = []
223217

224-
try:
225-
client = get_client()
226-
worker = self.worker
227-
except ValueError:
228-
client = None
229-
worker = None
218+
client, worker = self._get_client_worker()
230219

231220
if client:
232221
fields_array = client.scatter(f[:, self._solutionType], workers=worker)

simpeg/dask/electromagnetics/static/resistivity/simulation.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix
44

55
from .....utils import Zero
6-
from dask.distributed import get_client
6+
77
import dask.array as da
88
import numpy as np
99
from scipy import sparse as sp
@@ -163,22 +163,23 @@ def getSourceTerm(self):
163163
source_list = self.survey.source_list
164164

165165
indices = np.arange(len(source_list))
166-
try:
167166

168-
client = get_client()
169-
sim = client.scatter(self, workers=self.worker)
170-
future_list = client.scatter(source_list, workers=self.worker)
171-
indices = np.array_split(indices, self.n_threads(client=client))
167+
client, worker = self._get_client_worker()
168+
169+
if client:
170+
sim = client.scatter(self, workers=worker)
171+
future_list = client.scatter(source_list, workers=worker)
172+
indices = np.array_split(
173+
indices, self.n_threads(client=client, worker=worker)
174+
)
172175
blocks = []
173176
for ind in indices:
174177
blocks.append(
175-
client.submit(
176-
source_eval, sim, future_list, ind, workers=self.worker
177-
)
178+
client.submit(source_eval, sim, future_list, ind, workers=worker)
178179
)
179180

180181
blocks = sp.hstack(client.gather(blocks))
181-
except ValueError:
182+
else:
182183
blocks = source_eval(self, source_list, indices)
183184

184185
self._q = blocks

simpeg/dask/electromagnetics/time_domain/simulation.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
import scipy.sparse as sp
1111
from dask import array, delayed
12-
from dask.distributed import get_client
12+
1313

1414
from time import time
1515
from simpeg.dask.utils import get_parallel_blocks
@@ -89,10 +89,7 @@ def compute_J(self, m, f=None):
8989
if f is None:
9090
f, Ainv = self.fields(m=m, return_Ainv=True)
9191

92-
try:
93-
client = get_client()
94-
except ValueError:
95-
client = None
92+
client, worker = self._get_client_worker()
9693

9794
ftype = self._fieldType + "Solution"
9895
sens_name = self.sensitivity_path[:-5]
@@ -118,22 +115,22 @@ def compute_J(self, m, f=None):
118115
blocks = get_parallel_blocks(
119116
self.survey.source_list,
120117
compute_row_size,
121-
thread_count=self.n_threads(client=client),
118+
thread_count=self.n_threads(client=client, worker=worker),
122119
)
123120
fields_array = f[:, ftype, :]
124121

125122
if len(self.survey.source_list) == 1:
126123
fields_array = fields_array[:, np.newaxis, :]
127124

128125
times_field_derivs, Jmatrix = compute_field_derivs(
129-
self, f, blocks, Jmatrix, fields_array.shape, client
126+
self, f, blocks, Jmatrix, fields_array.shape
130127
)
131128

132129
ATinv_df_duT_v = [[] for _ in blocks]
133130

134131
if client:
135-
fields_array = client.scatter(fields_array, workers=self.worker)
136-
sim = client.scatter(self, workers=self.worker)
132+
fields_array = client.scatter(fields_array, workers=worker)
133+
sim = client.scatter(self, workers=worker)
137134
else:
138135
delayed_compute_rows = delayed(compute_rows)
139136
sim = self
@@ -161,7 +158,7 @@ def compute_J(self, m, f=None):
161158
)
162159

163160
if client:
164-
field_derivatives = client.scatter(ATinv_df_duT_v, workers=self.worker)
161+
field_derivatives = client.scatter(ATinv_df_duT_v, workers=worker)
165162
else:
166163
field_derivatives = ATinv_df_duT_v
167164

@@ -181,7 +178,7 @@ def compute_J(self, m, f=None):
181178
field_derivatives,
182179
fields_array,
183180
time_mask,
184-
workers=self.worker,
181+
workers=worker,
185182
)
186183
)
187184
else:
@@ -267,17 +264,19 @@ def evaluate_receivers(block, mesh, time_mesh, fields, fields_array):
267264
return np.hstack(data)
268265

269266

270-
def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client):
267+
def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape):
271268
"""
272269
Compute the derivative of the fields
273270
"""
274271
delayed_chunks = []
275272

273+
client, worker = self._get_client_worker()
274+
276275
if client:
277-
mesh = client.scatter(self.mesh, workers=self.worker)
278-
time_mesh = client.scatter(self.time_mesh, workers=self.worker)
279-
fields = client.scatter(fields, workers=self.worker)
280-
source_list = client.scatter(self.survey.source_list, workers=self.worker)
276+
mesh = client.scatter(self.mesh, workers=worker)
277+
time_mesh = client.scatter(self.time_mesh, workers=worker)
278+
fields = client.scatter(fields, workers=worker)
279+
source_list = client.scatter(self.survey.source_list, workers=worker)
281280
else:
282281
mesh = self.mesh
283282
time_mesh = self.time_mesh
@@ -300,7 +299,7 @@ def compute_field_derivs(self, fields, blocks, Jmatrix, fields_shape, client):
300299
time_mesh,
301300
fields,
302301
self.model.size,
303-
workers=self.worker,
302+
workers=worker,
304303
)
305304
)
306305
else:
@@ -537,24 +536,22 @@ def dpred(self, m=None, f=None):
537536
"simulation.survey = survey"
538537
)
539538

540-
try:
541-
client = get_client()
542-
except ValueError:
543-
client = None
539+
client, worker = self._get_client_worker()
544540

545541
if f is None:
546542
f = self.fields(m)
547543

548544
delayed_chunks = []
549545

550546
source_block = np.array_split(
551-
np.arange(len(self.survey.source_list)), self.n_threads(client=client)
547+
np.arange(len(self.survey.source_list)),
548+
self.n_threads(client=client, worker=worker),
552549
)
553550
if client:
554-
mesh = client.scatter(self.mesh, workers=self.worker)
555-
time_mesh = client.scatter(self.time_mesh, workers=self.worker)
556-
fields = client.scatter(f, workers=self.worker)
557-
source_list = client.scatter(self.survey.source_list, workers=self.worker)
551+
mesh = client.scatter(self.mesh, workers=worker)
552+
time_mesh = client.scatter(self.time_mesh, workers=worker)
553+
fields = client.scatter(f, workers=worker)
554+
source_list = client.scatter(self.survey.source_list, workers=worker)
558555
else:
559556
mesh = self.mesh
560557
time_mesh = self.time_mesh
@@ -575,7 +572,7 @@ def dpred(self, m=None, f=None):
575572
mesh,
576573
time_mesh,
577574
fields,
578-
workers=self.worker,
575+
workers=worker,
579576
)
580577
)
581578
else:

simpeg/dask/objective_function.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Callable
88
import numpy as np
99

10-
from dask.distributed import Client, Future
10+
from dask.distributed import Client, Future, get_client
1111
from ..data_misfit import L2DataMisfit
1212

1313
from simpeg.utils import validate_list_of_types
@@ -145,9 +145,7 @@ def _validate_type_or_future_of_type(
145145
):
146146

147147
if workers is None:
148-
workers = [
149-
(worker.worker_address,) for worker in client.cluster.workers.values()
150-
]
148+
workers = [(worker,) for worker in client.nthreads()]
151149

152150
objects = validate_list_of_types(
153151
property_name, objects, obj_type, ensure_unique=True
@@ -262,6 +260,9 @@ def client(self):
262260

263261
@client.setter
264262
def client(self, client):
263+
if client is None:
264+
client = get_client()
265+
265266
if not isinstance(client, Client):
266267
raise TypeError("client must be a dask.distributed.Client")
267268

@@ -279,6 +280,23 @@ def workers(self, workers):
279280
if not isinstance(workers, list | type(None)):
280281
raise TypeError("workers must be a list of strings")
281282

283+
available_workers = [(worker,) for worker in self.client.nthreads()]
284+
285+
if workers is None:
286+
workers = available_workers
287+
288+
if not isinstance(workers, list) or not all(
289+
isinstance(w, tuple) for w in workers
290+
):
291+
raise TypeError("Workers must be a list of tuple[str].")
292+
293+
invalid_workers = [w for w in workers if w not in available_workers]
294+
if invalid_workers:
295+
raise ValueError(
296+
f"The following workers are not available: {invalid_workers}. "
297+
f"Available workers are: {available_workers}."
298+
)
299+
282300
self._workers = workers
283301

284302
def deriv(self, m, f=None):

simpeg/dask/potential_fields/base.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33
from ...potential_fields.base import BasePFSimulation as Sim
4-
from dask.distributed import get_client
4+
55
import os
66
from dask import delayed, array, config
77
from dask.diagnostics import ProgressBar
@@ -60,13 +60,10 @@ def linear_operator(self):
6060
)
6161
block_split = np.array_split(self.survey.receiver_locations, n_blocks)
6262

63-
try:
64-
client = get_client()
65-
except ValueError:
66-
client = None
63+
client, worker = self._get_client_worker()
6764

6865
if client:
69-
sim = client.scatter(self, workers=self.worker)
66+
sim = client.scatter(self, workers=worker)
7067
else:
7168
delayed_compute = delayed(block_compute)
7269

@@ -79,7 +76,7 @@ def linear_operator(self):
7976
sim,
8077
block,
8178
self.survey.components,
82-
workers=self.worker,
79+
workers=worker,
8380
)
8481
)
8582
else:

0 commit comments

Comments
 (0)