1+ from dask import array , compute , delayed
12import numpy as np
23from ....potential_fields .magnetics import Simulation3DDifferential as Sim
34from ....utils import sdiag , mkvc
45
56
6- def distance_weights (locations , cell_centers , exponent = 3 , threshold = 1e-2 ):
7+ def distance_weights (locations , cell_centers , cell_volumes , exponent = 3 , threshold = 1e-2 ):
78 distance_weights = np .zeros (len (cell_centers ))
8- for ind , loc in enumerate ( locations ) :
9+ for loc in locations :
910 distance = np .linalg .norm (cell_centers - loc , axis = 1 )
10- distance_weights += (distance + threshold ) ** (- 2 * exponent )
11+ distance_weights += cell_volumes ** 2.0 * (distance + threshold ) ** (
12+ - 2 * exponent
13+ )
1114
1215 return distance_weights
1316
17+
1418def dask_getJtJdiag (self , m , W = None , f = None ):
1519 """
1620 Return the diagonal of JtJ
@@ -30,9 +34,11 @@ def dask_getJtJdiag(self, m, W=None, f=None):
3034
3135 chunks = np .array_split (self .survey .receiver_locations , n_threads )
3236 cell_centers = self .mesh .cell_centers .copy ()
37+ cell_volumes = self .mesh .cell_volumes .copy ()
3338
3439 if client :
3540 cell_centers = client .scatter (cell_centers , workers = worker )
41+ cell_volumes = client .scatter (cell_volumes , workers = worker )
3642 else :
3743 delayed_distance_weights = delayed (distance_weights )
3844
@@ -44,15 +50,17 @@ def dask_getJtJdiag(self, m, W=None, f=None):
4450 distance_weights ,
4551 block ,
4652 cell_centers ,
53+ cell_volumes ,
4754 workers = worker ,
4855 )
4956 )
5057 else :
5158 futures .append (
5259 array .from_delayed (
53- delayed_compute_rows (
60+ delayed_distance_weights (
5461 block ,
5562 cell_centers ,
63+ cell_volumes ,
5664 ),
5765 dtype = np .float32 ,
5866 shape = (
@@ -67,8 +75,8 @@ def dask_getJtJdiag(self, m, W=None, f=None):
6775 else :
6876 diag = compute (futures )
6977
70- diag = np .tile (np .vstack (diag ).sum (axis = 0 ) * self . mesh . cell_volumes ** 2. , 3 ) ** 0.5
78+ diag = np .tile (np .vstack (diag ).sum (axis = 0 ), 3 )
7179 return mkvc ((sdiag (np .sqrt (diag )) @ self .remDeriv ).power (2 ).sum (axis = 0 ))
7280
7381
74- Sim .getJtJdiag = dask_getJtJdiag
82+ Sim .getJtJdiag = dask_getJtJdiag
0 commit comments