Skip to content

Commit 62354b0

Browse files
committed
keep dpred and residuals as list for joint problems
1 parent 8829a55 commit 62354b0

4 files changed

Lines changed: 29 additions & 20 deletions

File tree

simpeg/dask/inverse_problem.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def get_dpred(self, m, f=None, return_residuals=False):
4646
)
4747

4848
if return_residuals:
49-
return np.hstack(results[0]), np.hstack(results[1])
49+
return results[0], results[1]
5050

51-
return np.hstack(results)
51+
return results
5252

5353

5454
BaseInvProblem.get_dpred = get_dpred
@@ -61,7 +61,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True):
6161
self.model = m
6262
self.dpred, self.residuals = self.get_dpred(m, return_residuals=True)
6363

64-
phi_d = np.vdot(self.residuals, self.residuals)
64+
phi_d = (np.hstack(self.residuals) ** 2.0).sum()
6565

6666
reg2Deriv = []
6767
if isinstance(self.reg, ComboObjectiveFunction):

simpeg/dask/objective_function.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -149,39 +149,49 @@ def _validate_type_or_future_of_type(
149149
objects = validate_list_of_types(
150150
property_name, objects, obj_type, ensure_unique=True
151151
)
152-
workload = [[]]
152+
workloads = {}
153+
for worker in workers:
154+
workloads[worker] = []
153155

154156
count = 0
155-
for obj in objects:
156-
if count == len(workers):
157-
count = 0
158-
workload.append([])
157+
for ii, obj in enumerate(objects):
158+
count = ii % len(workers)
159159

160160
if isinstance(obj, Future):
161161
future = obj
162+
count = workers.index(client.who_has(obj)[obj.key])
162163
else:
163164
future = client.scatter([obj], workers=workers[count])[0]
164165

165-
workload[-1].append(future)
166-
count += 1
166+
workloads[workers[count]].append(future)
167167

168168
futures = []
169169
assignments = []
170-
for work in workload:
171-
for obj, worker in zip(work, workers):
170+
for worker, work in workloads.items():
171+
for future in work:
172172
futures.append(
173173
client.submit(
174-
lambda v: not isinstance(v, obj_type), obj, workers=worker
174+
lambda v: not isinstance(v, obj_type), future, workers=worker
175175
)
176176
)
177-
assignments.append(client.submit(_set_worker, obj, worker, workers=worker))
177+
assignments.append(
178+
client.submit(_set_worker, future, worker, workers=worker)
179+
)
178180

179181
client.gather(assignments)
180182

181183
is_not_obj = np.array(client.gather(futures))
182184
if np.any(is_not_obj):
183185
raise TypeError(f"{property_name} futures must be an instance of {obj_type}")
184186

187+
# Re-distribute the workload to ensure all workers are equally loaded
188+
workload = []
189+
for work in workloads.values():
190+
for ii, future in enumerate(work):
191+
if len(workload) <= ii:
192+
workload.append([])
193+
workload[ii].append(future)
194+
185195
if return_workers:
186196
return workload, workers
187197
else:
@@ -382,9 +392,9 @@ def get_dpred(self, m, f=None, return_residuals=False):
382392
dpred += [result]
383393

384394
if return_residuals:
385-
return np.hstack(dpred), np.hstack(residuals)
395+
return dpred, residuals
386396

387-
return np.hstack(dpred)
397+
return dpred
388398

389399
def getJtJdiag(self, m, f=None):
390400
"""

simpeg/directives/_regularization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def misfit_from_chi_factor(self, chi_factor: float) -> float:
218218
chi_factor : float
219219
Chi factor to compute the target misfit from.
220220
"""
221-
return self.invProb.dpred.shape[0] * chi_factor
221+
return np.hstack(self.invProb.dpred).shape[0] * chi_factor
222222

223223
def adjust_cooling_schedule(self):
224224
"""

simpeg/directives/directives.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3058,10 +3058,9 @@ def initialize(self):
30583058
def endIter(self):
30593059
ratio = self.invProb.beta / self.last_beta
30603060
chi_factors = []
3061-
for objfct, pred in zip(self.invProb.dmisfit.objfcts, self.invProb.dpred):
3062-
residual = objfct.W * (objfct.data.dobs - pred)
3061+
for residual in self.invProb.residuals:
30633062
phi_d = np.vdot(residual, residual)
3064-
chi_factors.append(phi_d / objfct.nD)
3063+
chi_factors.append(phi_d / len(residual))
30653064

30663065
self.chi_factors = np.asarray(chi_factors)
30673066

0 commit comments

Comments
 (0)