Skip to content

Commit 3430f93

Browse files
authored
Merge pull request #137 from MiraGeoscience/GEOPY-2799
GEOPY-2799
2 parents f7dd0d3 + 465bc8d commit 3430f93

6 files changed

Lines changed: 106 additions & 7 deletions

File tree

simpeg/dask/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import simpeg.dask.potential_fields.base
1212
import simpeg.dask.potential_fields.gravity.simulation
1313
import simpeg.dask.potential_fields.magnetics.simulation
14+
import simpeg.dask.potential_fields.magnetics.simulation_pde
1415
import simpeg.dask.simulation
1516
import simpeg.dask.inverse_problem
1617
import simpeg.dask.objective_function
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from dask import array, compute, delayed
2+
import numpy as np
3+
from ....potential_fields.magnetics import Simulation3DDifferential as Sim
4+
from ....utils import sdiag, mkvc
5+
6+
7+
def distance_weights(locations, cell_centers, cell_volumes, exponent=3, threshold=1e-2):
8+
weights = np.zeros(len(cell_centers))
9+
for loc in locations:
10+
distance = np.linalg.norm(cell_centers - loc, axis=1)
11+
weights += cell_volumes**2.0 * (distance + threshold) ** (
12+
-2 * exponent
13+
)
14+
15+
return weights
16+
17+
18+
def dask_getJtJdiag(self, m, W=None, f=None):
19+
"""
20+
Return the diagonal of JtJ
21+
"""
22+
23+
self.model = m
24+
25+
self.model = m
26+
if W is None:
27+
W = np.ones(self.Jmatrix.shape[0])
28+
else:
29+
W = W.diagonal()
30+
31+
client, worker = self._get_client_worker()
32+
33+
n_threads = self.n_threads(client=client, worker=worker)
34+
35+
chunks = np.array_split(self.survey.receiver_locations, n_threads)
36+
cell_centers = self.mesh.cell_centers.copy()
37+
cell_volumes = self.mesh.cell_volumes.copy()
38+
39+
if client:
40+
cell_centers = client.scatter(cell_centers, workers=worker)
41+
cell_volumes = client.scatter(cell_volumes, workers=worker)
42+
else:
43+
delayed_distance_weights = delayed(distance_weights)
44+
45+
futures = []
46+
for block in chunks:
47+
if client:
48+
futures.append(
49+
client.submit(
50+
distance_weights,
51+
block,
52+
cell_centers,
53+
cell_volumes,
54+
workers=worker,
55+
)
56+
)
57+
else:
58+
futures.append(
59+
array.from_delayed(
60+
delayed_distance_weights(
61+
block,
62+
cell_centers,
63+
cell_volumes,
64+
),
65+
dtype=np.float32,
66+
shape=(
67+
len(block),
68+
len(cell_centers),
69+
),
70+
)
71+
)
72+
73+
if client:
74+
diag = client.gather(futures)
75+
else:
76+
diag = compute(futures)[0]
77+
78+
diag = np.tile(np.vstack(diag).sum(axis=0), 3)
79+
return mkvc((sdiag(np.sqrt(diag)) @ self.remDeriv).power(2).sum(axis=0))
80+
81+
82+
Sim.getJtJdiag = dask_getJtJdiag

simpeg/directives/_save_geoh5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def write(self, iteration: int, **_):
486486
if (channel_name in child.name and isinstance(child, FloatData))
487487
]
488488

489-
if children[0] is not None:
489+
if children:
490490
properties += children
491491

492492
if len(properties) == 0:

simpeg/directives/_vector_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class VectorInversion(InversionDirective):
9191
chifact_target = 1.0
9292
reference_model = None
9393
mode = "cartesian"
94-
inversion_type = "mvis"
94+
inversion_type = "magnetic vector"
9595
norms = []
9696
alphas = []
9797
cartesian_model = None
@@ -162,7 +162,7 @@ def endIter(self):
162162

163163
if (
164164
self.invProb.phi_d < self.target
165-
) and self.mode == "cartesian": # and self.inversion_type == 'mvis':
165+
) and self.mode == "cartesian" and self.inversion_type == "magnetic vector":
166166
print("Switching MVI to spherical coordinates")
167167
self.mode = "spherical"
168168
self.cartesian_model = model

simpeg/maps/_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,10 @@ def __init__(self, *args):
11501150
for arg in args:
11511151

11521152
if isinstance(arg[1], (int, np.integer)):
1153+
1154+
if not getattr(self, "_nP", None):
1155+
self._nP = int(np.sum([w[1] for w in args]))
1156+
11531157
wire = Projection(self.nP, slice(start, start + arg[1]))
11541158
start += arg[1]
11551159
else:

simpeg/potential_fields/magnetics/simulation.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1862,7 +1862,7 @@ def _getRHS(self, m):
18621862
).diagonal()
18631863
)
18641864

1865-
return rhs
1865+
return rhs.astype(self.solver_dtype)
18661866

18671867
def _getA(self):
18681868
A = self._Div * self.MfMuiI * self._DivT
@@ -1991,13 +1991,23 @@ def _Jtvec(self, m, v, f):
19911991
if v is None:
19921992
v = np.eye(Q.shape[0])
19931993
divt_solve_q = (
1994-
self._DivT * (self._Ainv * ((Q * self.MfMuiI * -self._DivT).T * v))
1994+
self._DivT
1995+
* (
1996+
self._Ainv
1997+
* ((Q * self.MfMuiI * -self._DivT).T * v).astype(self.solver_dtype)
1998+
)
19951999
+ Q.T * v
19962000
)
19972001
del v
19982002
else:
19992003
divt_solve_q = (
2000-
self._DivT * (self._Ainv * ((-self._Div * (self.MfMuiI.T * (Q.T * v)))))
2004+
self._DivT
2005+
* (
2006+
self._Ainv
2007+
* ((-self._Div * (self.MfMuiI.T * (Q.T * v)))).astype(
2008+
self.solver_dtype
2009+
)
2010+
)
20012011
+ Q.T * v
20022012
)
20032013

@@ -2071,7 +2081,9 @@ def _Jvec(self, m, v, f):
20712081
self.MfMuiI * Mf_r_mui_deriv * v
20722082
)
20732083

2074-
Ainv_Ddm = self._Ainv * (self._Div * (-dCmu_dm + db_dm))
2084+
Ainv_Ddm = self._Ainv * (self._Div * (-dCmu_dm + db_dm)).astype(
2085+
self.solver_dtype
2086+
)
20752087

20762088
Jv = Q * (C * Ainv_Ddm + (-dCmu_dm + db_dm))
20772089

0 commit comments

Comments
 (0)