Skip to content

Commit 89c6ac1

Browse files
committed
Add dask method for getjtjdiag for mvi pde
1 parent 6eecc1c commit 89c6ac1

3 files changed

Lines changed: 77 additions & 2 deletions

File tree

simpeg/dask/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import simpeg.dask.potential_fields.base
1212
import simpeg.dask.potential_fields.gravity.simulation
1313
import simpeg.dask.potential_fields.magnetics.simulation
14+
import simpeg.dask.potential_fields.magnetics.simulation_pde
1415
import simpeg.dask.simulation
1516
import simpeg.dask.inverse_problem
1617
import simpeg.dask.objective_function
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import numpy as np
2+
from ....potential_fields.magnetics import Simulation3DDifferential as Sim
3+
from ....utils import sdiag, mkvc
4+
5+
6+
def distance_weights(locations, cell_centers, exponent=3, threshold=1e-2):
7+
distance_weights = np.zeros(len(cell_centers))
8+
for ind, loc in enumerate(locations):
9+
distance = np.linalg.norm(cell_centers - loc, axis=1)
10+
distance_weights += (distance + threshold) ** (-2 * exponent)
11+
12+
return distance_weights
13+
14+
def dask_getJtJdiag(self, m, W=None, f=None):
15+
"""
16+
Return the diagonal of JtJ
17+
"""
18+
19+
self.model = m
20+
21+
self.model = m
22+
if W is None:
23+
W = np.ones(self.Jmatrix.shape[0])
24+
else:
25+
W = W.diagonal()
26+
27+
client, worker = self._get_client_worker()
28+
29+
n_threads = self.n_threads(client=client, worker=worker)
30+
31+
chunks = np.array_split(self.survey.receiver_locations, n_threads)
32+
cell_centers = self.mesh.cell_centers.copy()
33+
34+
if client:
35+
cell_centers = client.scatter(cell_centers, workers=worker)
36+
else:
37+
delayed_distance_weights = delayed(distance_weights)
38+
39+
futures = []
40+
for block in chunks:
41+
if client:
42+
futures.append(
43+
client.submit(
44+
distance_weights,
45+
block,
46+
cell_centers,
47+
workers=worker,
48+
)
49+
)
50+
else:
51+
futures.append(
52+
array.from_delayed(
53+
delayed_compute_rows(
54+
block,
55+
cell_centers,
56+
),
57+
dtype=np.float32,
58+
shape=(
59+
len(block),
60+
self.active_cells.sum(),
61+
),
62+
)
63+
)
64+
65+
if client:
66+
diag = client.gather(futures)
67+
else:
68+
diag = compute(futures)
69+
70+
diag = np.tile(np.vstack(diag).sum(axis=0) * self.mesh.cell_volumes**2.,3)**0.5
71+
return mkvc((sdiag(np.sqrt(diag)) @ self.remDeriv).power(2).sum(axis=0))
72+
73+
74+
Sim.getJtJdiag = dask_getJtJdiag

simpeg/directives/_vector_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class VectorInversion(InversionDirective):
9191
chifact_target = 1.0
9292
reference_model = None
9393
mode = "cartesian"
94-
inversion_type = "mvis"
94+
inversion_type = "magnetic vector"
9595
norms = []
9696
alphas = []
9797
cartesian_model = None
@@ -162,7 +162,7 @@ def endIter(self):
162162

163163
if (
164164
self.invProb.phi_d < self.target
165-
) and self.mode == "cartesian": # and self.inversion_type == 'mvis':
165+
) and self.mode == "cartesian" and self.inversion_type == "magnetic vector":
166166
print("Switching MVI to spherical coordinates")
167167
self.mode = "spherical"
168168
self.cartesian_model = model

0 commit comments

Comments
 (0)