1+ import numpy as np
12from ....potential_fields .magnetics import Simulation3DIntegral as Sim
2- from ...simulation import getJtJdiag
3+ from ....utils import sdiag , mkvc
4+
5+
6+ def dask_getJtJdiag (self , m , W = None , f = None ):
7+ """
8+ Return the diagonal of JtJ
9+ """
10+
11+ self .model = m
12+
13+ self .model = m
14+ if W is None :
15+ W = np .ones (self .Jmatrix .shape [0 ])
16+ else :
17+ W = W .diagonal ()
18+
19+ if getattr (self , "_gtg_diagonal" , None ) is None :
20+ if not self .is_amplitude_data :
21+ diag = np .asarray (np .einsum ("i,ij,ij->j" , W ** 2 , self .Jmatrix , self .Jmatrix ))
22+ else :
23+ ampDeriv = self .ampDeriv
24+ J = (
25+ ampDeriv [0 , :, None ] * self .Jmatrix [::3 ]
26+ + ampDeriv [1 , :, None ] * self .Jmatrix [1 ::3 ]
27+ + ampDeriv [2 , :, None ] * self .Jmatrix [2 ::3 ]
28+ )
29+ diag = ((W [:, None ] * J ) ** 2 ).sum (axis = 0 ).compute ()
30+ self ._gtg_diagonal = diag
31+ else :
32+ diag = self ._gtg_diagonal
33+
34+ return mkvc ((sdiag (np .sqrt (diag )) @ self .chiDeriv ).power (2 ).sum (axis = 0 ))
35+
36+
37+ Sim .getJtJdiag = dask_getJtJdiag
338
439
540@property
@@ -14,5 +49,4 @@ def G(self):
1449
1550
1651Sim ._delete_on_model_update = []
17- Sim .getJtJdiag = getJtJdiag
1852Sim .G = G
0 commit comments