Skip to content

Commit a3996d2

Browse files
authored
Merge pull request #213 from MiraGeoscience/GEOPY-2193
GEOPY-2193: DC 2D inversion does not honour cooling schedule
2 parents f910d89 + c151e56 commit a3996d2

4 files changed

Lines changed: 52 additions & 8 deletions

File tree

simpeg_drivers/components/factories/directives_factory.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from abc import ABC
18+
from logging import getLogger
1819
from typing import TYPE_CHECKING
1920

2021
import numpy as np
@@ -29,6 +30,8 @@
2930
if TYPE_CHECKING:
3031
from simpeg_drivers.driver import InversionDriver
3132

33+
logger = getLogger(__name__)
34+
3235

3336
class DirectivesFactory:
3437
def __init__(self, driver: InversionDriver):
@@ -263,7 +266,15 @@ def scale_misfits(self):
263266
def update_irls_directive(self):
264267
"""Directive to update IRLS."""
265268
if self._update_irls_directive is None:
266-
has_chi_start = self.params.starting_chi_factor is not None
269+
start_chi_fact = self.params.starting_chi_factor
270+
271+
if start_chi_fact is not None and self.params.chi_factor > start_chi_fact:
272+
logger.warning(
273+
"Starting chi factor is greater than target chi factor.\n"
274+
"Setting the target chi factor to the starting chi factor."
275+
)
276+
start_chi_fact = self.params.chi_factor
277+
267278
self._update_irls_directive = directives.UpdateIRLS(
268279
f_min_change=self.params.f_min_change,
269280
max_irls_iterations=self.params.max_irls_iterations,
@@ -272,11 +283,7 @@ def update_irls_directive(self):
272283
cooling_rate=self.params.cooling_rate,
273284
cooling_factor=self.params.cooling_factor,
274285
irls_cooling_factor=self.params.epsilon_cooling_factor,
275-
chifact_start=(
276-
self.params.starting_chi_factor
277-
if has_chi_start
278-
else self.params.chi_factor
279-
),
286+
chifact_start=start_chi_fact or self.params.chi_factor,
280287
chifact_target=self.params.chi_factor,
281288
)
282289
return self._update_irls_directive

simpeg_drivers/joint/options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class BaseJointOptions(BaseData):
8484
initial_beta: float | None = None
8585
cooling_factor: float = 2.0
8686

87-
cooling_rate: float = 1.0
87+
cooling_rate: int = 1
8888
max_global_iterations: int = 50
8989
max_line_search_iterations: int = 20
9090
max_cg_iterations: int = 30

simpeg_drivers/options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ class BaseInversionOptions(CoreOptions):
362362
initial_beta: float | None = None
363363
cooling_factor: float = 2.0
364364

365-
cooling_rate: float = 1.0
365+
cooling_rate: int = 1
366366
max_global_iterations: int = 50
367367
max_line_search_iterations: int = 20
368368
max_cg_iterations: int = 30

tests/driver_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,40 @@ def test_smallness_terms(tmp_path: Path):
5757
params.alpha_s = None
5858
driver = GravityInversionDriver(params)
5959
assert driver.regularization.objfcts[0].alpha_s == 0.0
60+
61+
62+
def test_target_chi(tmp_path: Path, caplog):
63+
n_grid_points = 2
64+
refinement = (2,)
65+
66+
geoh5, _, model, survey, topography = setup_inversion_workspace(
67+
tmp_path,
68+
background=0.0,
69+
anomaly=0.75,
70+
n_electrodes=n_grid_points,
71+
n_lines=n_grid_points,
72+
refinement=refinement,
73+
flatten=False,
74+
)
75+
76+
with geoh5.open():
77+
gz = survey.add_data({"gz": {"values": np.ones(survey.n_vertices)}})
78+
mesh = model.parent
79+
active_cells = ActiveCellsOptions(topography_object=topography)
80+
params = GravityInversionOptions(
81+
geoh5=geoh5,
82+
mesh=mesh,
83+
active_cells=active_cells,
84+
data_object=gz.parent,
85+
gz_channel=gz,
86+
gz_uncertainty=2e-3,
87+
starting_model=1e-4,
88+
starting_chi_factor=1.0,
89+
chi_factor=2.0,
90+
)
91+
driver = GravityInversionDriver(params)
92+
93+
with caplog.at_level("WARNING"):
94+
assert driver.directives.update_irls_directive.chifact_start == 2.0
95+
96+
assert "Starting chi factor is greater" in caplog.text

0 commit comments

Comments
 (0)