Skip to content

Commit a28d832

Browse files
committed
Bring back getJtJdiag for mag to deal with spherical
1 parent dd10bdf commit a28d832

1 file changed

Lines changed: 36 additions & 2 deletions

File tree

simpeg/dask/potential_fields/magnetics/simulation.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,40 @@
1+
import numpy as np
12
from ....potential_fields.magnetics import Simulation3DIntegral as Sim
2-
from ...simulation import getJtJdiag
3+
from ....utils import sdiag, mkvc
4+
5+
6+
def dask_getJtJdiag(self, m, W=None, f=None):
7+
"""
8+
Return the diagonal of JtJ
9+
"""
10+
11+
self.model = m
12+
13+
self.model = m
14+
if W is None:
15+
W = np.ones(self.Jmatrix.shape[0])
16+
else:
17+
W = W.diagonal()
18+
19+
if getattr(self, "_gtg_diagonal", None) is None:
20+
if not self.is_amplitude_data:
21+
diag = np.asarray(np.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix))
22+
else:
23+
ampDeriv = self.ampDeriv
24+
J = (
25+
ampDeriv[0, :, None] * self.Jmatrix[::3]
26+
+ ampDeriv[1, :, None] * self.Jmatrix[1::3]
27+
+ ampDeriv[2, :, None] * self.Jmatrix[2::3]
28+
)
29+
diag = ((W[:, None] * J) ** 2).sum(axis=0).compute()
30+
self._gtg_diagonal = diag
31+
else:
32+
diag = self._gtg_diagonal
33+
34+
return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0))
35+
36+
37+
Sim.getJtJdiag = dask_getJtJdiag
338

439

540
@property
@@ -14,5 +49,4 @@ def G(self):
1449

1550

1651
Sim._delete_on_model_update = []
17-
Sim.getJtJdiag = getJtJdiag
1852
Sim.G = G

0 commit comments

Comments
 (0)