Skip to content

Commit 6b8af8a

Browse files
authored
Merge branch 'develop' into GEOPY-2561
2 parents 8c375d9 + 7c38e72 commit 6b8af8a

6 files changed

Lines changed: 28 additions & 17 deletions

File tree

simpeg_drivers/components/factories/directives_factory.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def save_directives(self):
144144
]:
145145
save_directive = getattr(self, directive)
146146
if save_directive is not None:
147-
directives_list.append(getattr(self, directive))
147+
directives_list.append(save_directive)
148148

149149
if (
150150
isinstance(save_directive, directives.SaveDataGeoH5)
@@ -200,8 +200,10 @@ def save_property_group(self):
200200
@property
201201
def save_sensitivities_directive(self):
202202
""""""
203-
if self._save_sensitivities_directive is None and isinstance(
204-
self.params, BaseInversionOptions
203+
if (
204+
self._save_sensitivities_directive is None
205+
and isinstance(self.params, BaseInversionOptions)
206+
and self.params.directives.save_sensitivities
205207
):
206208
self._save_sensitivities_directive = SaveSensitivitiesGeoh5Factory(
207209
self.params

simpeg_drivers/components/factories/simulation_factory.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,19 +121,17 @@ def assemble_arguments(
121121
survey=None,
122122
mesh=None,
123123
models=None,
124+
**kwargs,
124125
):
125126
if "1d" in self.factory_type:
126127
return ()
127128

128129
return [mesh]
129130

130-
def assemble_keyword_arguments(
131-
self,
132-
survey=None,
133-
mesh=None,
134-
models=None,
135-
):
136-
kwargs = {}
131+
def assemble_keyword_arguments(self, survey=None, mesh=None, models=None, **kwargs):
132+
if not kwargs:
133+
kwargs = {}
134+
137135
kwargs["survey"] = survey
138136
kwargs["max_chunk_size"] = self.params.compute.max_chunk_size
139137
kwargs["store_sensitivities"] = (

simpeg_drivers/components/models.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -637,10 +637,9 @@ def _get_value(self, model: float | NumericData) -> np.ndarray:
637637
if isinstance(model, NumericData):
638638
model = self.obj_2_mesh(model, self.driver.inversion_mesh.entity)
639639
model = (self.driver.inversion_mesh.permutation @ model).astype(model.dtype)
640-
else:
640+
elif isinstance(model, int | float):
641641
nc = self.driver.inversion_mesh.mesh.n_cells
642-
if isinstance(model, int | float):
643-
model *= np.ones(nc)
642+
model *= np.ones(nc)
644643

645644
return model
646645

simpeg_drivers/electromagnetics/base_1d_driver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def simulation(self):
9797
mesh=self.inversion_mesh.mesh,
9898
models=self.models,
9999
survey=self.inversion_data.survey,
100+
topo=[0, 0, -np.inf], # Bypass check for global simulation
100101
)
101102

102103
self._simulation.mesh = self.inversion_mesh.mesh

simpeg_drivers/options.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy as np
1919
from geoapps_utils.base import Options
20+
from geoapps_utils.utils.numerical import weighted_average
2021
from geoh5py.data import (
2122
BooleanData,
2223
DataAssociationEnum,
@@ -29,7 +30,6 @@
2930
from geoh5py.objects import DrapeModel, Grid2D, Octree, Points
3031
from geoh5py.objects.surveys.electromagnetics.base import BaseEMSurvey
3132
from geoh5py.ui_json import InputFile
32-
from geoh5py.ui_json.utils import fetch_active_workspace
3333
from pydantic import (
3434
AliasChoices,
3535
BaseModel,
@@ -285,6 +285,8 @@ class ModelOptions(BaseModel):
285285
y_norm: float | FloatData | None = 2.0
286286
z_norm: float | FloatData = 2.0
287287

288+
_gradient_orientations: np.ndarray | None = None
289+
288290
@property
289291
def gradient_direction(self) -> np.ndarray:
290292
if self.gradient_orientations is None:
@@ -306,12 +308,20 @@ def gradient_orientations(self) -> tuple(float, float):
306308
and clockwise from horizontal for dip.
307309
"""
308310

309-
if self.gradient_rotation is not None:
311+
if self._gradient_orientations is None and self.gradient_rotation is not None:
310312
orientations = direction_and_dip(self.gradient_rotation)
311313

312-
return np.deg2rad(orientations)
314+
angles = np.deg2rad(orientations)
315+
# Deal with aircells here
316+
orientations = weighted_average(
317+
self.gradient_rotation.parent.centroids,
318+
self.gradient_rotation.parent.centroids,
319+
[angles[:, 0], angles[:, 1]],
320+
)
321+
322+
self._gradient_orientations = np.vstack(orientations).T
313323

314-
return None
324+
return self._gradient_orientations
315325

316326

317327
class ConductivityModelOptions(ModelOptions):

tests/run_tests/driver_mt_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def test_magnetotellurics_run(tmp_path: Path, max_iterations=1, pytest=True):
178178
driver.params.geoh5.h5file, driver.params.out_group.uid
179179
)
180180
output["data"] = orig_zyy_real_1
181+
assert not run_ws.get_entity("Iteration_0_sensitivities")[0]
181182
if pytest:
182183
check_target(output, target_run, tolerance=0.2)
183184
nan_ind = np.isnan(run_ws.get_entity("Iteration_0_model")[0].values)

0 commit comments

Comments
 (0)