Skip to content

Commit 092f610

Browse files
committed
Review distance weights comps
1 parent 89c6ac1 commit 092f610

1 file changed

Lines changed: 14 additions & 6 deletions

File tree

simpeg/dask/potential_fields/magnetics/simulation_pde.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
from dask import array, compute, delayed
12
import numpy as np
23
from ....potential_fields.magnetics import Simulation3DDifferential as Sim
34
from ....utils import sdiag, mkvc
45

56

6-
def distance_weights(locations, cell_centers, exponent=3, threshold=1e-2):
7+
def distance_weights(locations, cell_centers, cell_volumes, exponent=3, threshold=1e-2):
78
distance_weights = np.zeros(len(cell_centers))
8-
for ind, loc in enumerate(locations):
9+
for loc in locations:
910
distance = np.linalg.norm(cell_centers - loc, axis=1)
10-
distance_weights += (distance + threshold) ** (-2 * exponent)
11+
distance_weights += cell_volumes**2.0 * (distance + threshold) ** (
12+
-2 * exponent
13+
)
1114

1215
return distance_weights
1316

17+
1418
def dask_getJtJdiag(self, m, W=None, f=None):
1519
"""
1620
Return the diagonal of JtJ
@@ -30,9 +34,11 @@ def dask_getJtJdiag(self, m, W=None, f=None):
3034

3135
chunks = np.array_split(self.survey.receiver_locations, n_threads)
3236
cell_centers = self.mesh.cell_centers.copy()
37+
cell_volumes = self.mesh.cell_volumes.copy()
3338

3439
if client:
3540
cell_centers = client.scatter(cell_centers, workers=worker)
41+
cell_volumes = client.scatter(cell_volumes, workers=worker)
3642
else:
3743
delayed_distance_weights = delayed(distance_weights)
3844

@@ -44,15 +50,17 @@ def dask_getJtJdiag(self, m, W=None, f=None):
4450
distance_weights,
4551
block,
4652
cell_centers,
53+
cell_volumes,
4754
workers=worker,
4855
)
4956
)
5057
else:
5158
futures.append(
5259
array.from_delayed(
53-
delayed_compute_rows(
60+
delayed_distance_weights(
5461
block,
5562
cell_centers,
63+
cell_volumes,
5664
),
5765
dtype=np.float32,
5866
shape=(
@@ -67,8 +75,8 @@ def dask_getJtJdiag(self, m, W=None, f=None):
6775
else:
6876
diag = compute(futures)
6977

70-
diag = np.tile(np.vstack(diag).sum(axis=0) * self.mesh.cell_volumes**2.,3)**0.5
78+
diag = np.tile(np.vstack(diag).sum(axis=0), 3)
7179
return mkvc((sdiag(np.sqrt(diag)) @ self.remDeriv).power(2).sum(axis=0))
7280

7381

74-
Sim.getJtJdiag = dask_getJtJdiag
82+
Sim.getJtJdiag = dask_getJtJdiag

0 commit comments

Comments
 (0)