|
1 | 1 | from ..inverse_problem import BaseInvProblem |
2 | 2 | import numpy as np |
3 | 3 |
|
4 | | -from .objective_function import DaskComboMisfits |
| 4 | +from .objective_function import DistributedComboMisfits |
5 | 5 | from scipy.sparse.linalg import LinearOperator |
6 | 6 | from ..regularization import WeightedLeastSquares, Sparse |
7 | 7 | from ..objective_function import ComboObjectiveFunction |
8 | 8 | from simpeg.utils import call_hooks |
9 | 9 | from simpeg.version import __version__ as simpeg_version |
10 | 10 |
|
11 | 11 |
|
12 | | -def get_dpred(self, m, f=None): |
| 12 | +def get_nested_predicted(objfcts, m, f=None, return_residuals=False): |
13 | 13 | dpreds = [] |
| 14 | + residuals = [] |
| 15 | + for objfct in objfcts: |
| 16 | + |
| 17 | + if isinstance(objfct, ComboObjectiveFunction): |
| 18 | + nesting = get_nested_predicted( |
| 19 | + objfct.objfcts, m, f=f, return_residuals=return_residuals |
| 20 | + ) |
14 | 21 |
|
15 | | - if isinstance(self.dmisfit, DaskComboMisfits): |
16 | | - return self.dmisfit.get_dpred(m, f=f) |
| 22 | + if return_residuals: |
| 23 | + dpreds += nesting[0] |
| 24 | + residuals += nesting[1] |
| 25 | + else: |
| 26 | + dpreds += nesting |
| 27 | + else: |
| 28 | + dpred = objfct.simulation.dpred(m, f=f) |
| 29 | + dpreds += [np.asarray(dpred)] |
17 | 30 |
|
18 | | - for objfct in self.dmisfit.objfcts: |
19 | | - dpred = objfct.simulation.dpred(m, f=f) |
20 | | - dpreds += [np.asarray(dpred)] |
| 31 | + if return_residuals: |
| 32 | + residual = objfct.W * (objfct.data.dobs - dpred) |
| 33 | + residuals += [np.asarray(residual)] |
21 | 34 |
|
| 35 | + if return_residuals: |
| 36 | + return dpreds, residuals |
22 | 37 | return dpreds |
23 | 38 |
|
24 | 39 |
|
| 40 | +def get_dpred(self, m, f=None, return_residuals=False): |
| 41 | + if isinstance(self.dmisfit, DistributedComboMisfits): |
| 42 | + results = self.dmisfit.get_dpred(m, f=f, return_residuals=return_residuals) |
| 43 | + else: |
| 44 | + results = get_nested_predicted( |
| 45 | + self.dmisfit.objfcts, m, f=f, return_residuals=return_residuals |
| 46 | + ) |
| 47 | + |
| 48 | + if return_residuals: |
| 49 | + return np.hstack(results[0]), np.hstack(results[1]) |
| 50 | + |
| 51 | + return np.hstack(results) |
| 52 | + |
| 53 | + |
25 | 54 | BaseInvProblem.get_dpred = get_dpred |
26 | 55 |
|
27 | 56 |
|
28 | 57 | def dask_evalFunction(self, m, return_g=True, return_H=True): |
29 | 58 | """evalFunction(m, return_g=True, return_H=True)""" |
30 | | - self.model = m |
31 | | - self.dpred = self.get_dpred(m) |
32 | | - residuals = [] |
33 | 59 |
|
34 | | - if isinstance(self.dmisfit, DaskComboMisfits): |
35 | | - residuals = self.dmisfit.residuals(m) |
36 | | - else: |
37 | | - for (_, objfct), pred in zip(self.dmisfit, self.dpred): |
38 | | - residuals.append(objfct.W * (objfct.data.dobs - pred)) |
| 60 | + if not np.allclose(self.model, m): |
| 61 | + self.model = m |
| 62 | + self.dpred, self.residuals = self.get_dpred(m, return_residuals=True) |
39 | 63 |
|
40 | | - phi_d = 0.0 |
41 | | - for residual in residuals: |
42 | | - phi_d += np.vdot(residual, residual) |
| 64 | + phi_d = np.vdot(self.residuals, self.residuals) |
43 | 65 |
|
44 | 66 | reg2Deriv = [] |
45 | 67 | if isinstance(self.reg, ComboObjectiveFunction): |
|
0 commit comments