Skip to content

Commit b85287b

Browse files
authored
Merge branch 'develop' into GEOPY-789
2 parents 2dbf61c + 1fff506 commit b85287b

4 files changed

Lines changed: 25 additions & 27 deletions

File tree

recipe.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ schema_version: 1
22

33
context:
44
name: "mira-simpeg"
5-
version: "0.23.0.1a4"
5+
version: "0.23.0.2a1"
66
python_min: "3.10"
77

88
package:
@@ -27,7 +27,7 @@ requirements:
2727
run:
2828
- python >=${{ python_min }}
2929
# Mira packages
30-
- geoh5py >=0.11.0a, <0.12.dev
30+
- geoh5py >=0.12.0a, <0.13.dev
3131
# direct dependencies
3232
- discretize >=0.11
3333
- geoana >=0.7.0

simpeg/dask/electromagnetics/time_domain/simulation.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from dask import array, delayed
1212
from dask.distributed import get_client
1313

14+
from time import time
1415
from simpeg.dask.utils import get_parallel_blocks
1516
from simpeg.utils import mkvc
1617

@@ -405,10 +406,6 @@ def get_field_deriv_block(
405406
# Cut out early data
406407
time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind]
407408
local_ind = np.arange(rx_ind.shape[0])[time_check]
408-
409-
if len(local_ind) < 1:
410-
continue
411-
412409
indices.append(
413410
(np.arange(count, count + len(local_ind)), local_ind),
414411
)
@@ -425,14 +422,14 @@ def get_field_deriv_block(
425422

426423
stacked_blocks.append(stacked_block)
427424

428-
if len(stacked_blocks) > 0:
429-
blocks = np.hstack(stacked_blocks)
430-
425+
blocks = np.hstack(stacked_blocks)
426+
if blocks.ndim == 2 and blocks.shape[1] > 0:
431427
solve = (AdiagTinv * blocks).reshape(blocks.shape)
432428
else:
433429
solve = None
434430

435431
updated_ATinv_df_duT_v = []
432+
436433
for (_, arrays), field_deriv, ATinv_chunk, (columns, local_ind) in zip(
437434
block, field_derivs, ATinv_df_duT_v, indices, strict=True
438435
):
@@ -444,10 +441,9 @@ def get_field_deriv_block(
444441
)
445442
ATinv_chunk = np.zeros(shape, dtype=np.float32)
446443

447-
if solve is None:
448-
continue
444+
if solve is not None:
445+
ATinv_chunk[:, local_ind] = solve[:, columns]
449446

450-
ATinv_chunk[:, local_ind] = solve[:, columns]
451447
updated_ATinv_df_duT_v.append(ATinv_chunk)
452448

453449
return updated_ATinv_df_duT_v

simpeg/dask/objective_function.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,15 @@ def _validate_type_or_future_of_type(
7272
property_name, objects, obj_type, ensure_unique=True
7373
)
7474
workload = [[]]
75+
lookup = {}
7576
count = 0
7677
for obj in objects:
7778
if count == len(workers):
7879
count = 0
7980
workload.append([])
8081
obj.simulation.simulations[0].worker = workers[count]
8182
future = client.scatter([obj], workers=workers[count])[0]
82-
83+
lookup[obj] = (future, workers[count])
8384
if hasattr(obj, "name"):
8485
future.name = obj.name
8586

@@ -100,7 +101,7 @@ def _validate_type_or_future_of_type(
100101
raise TypeError(f"{property_name} futures must be an instance of {obj_type}")
101102

102103
if return_workers:
103-
return workload, workers
104+
return workload, workers, lookup
104105
else:
105106
return workload
106107

@@ -390,7 +391,7 @@ def objfcts(self):
390391
def objfcts(self, objfcts):
391392
client = self.client
392393

393-
futures, workers = _validate_type_or_future_of_type(
394+
futures, workers, lookup = _validate_type_or_future_of_type(
394395
"objfcts",
395396
objfcts,
396397
L2DataMisfit,
@@ -404,8 +405,8 @@ def objfcts(self, objfcts):
404405
self._workers = workers
405406

406407
self._lookup = {
407-
obj.simulation: (future, worker)
408-
for future, worker, obj in zip(futures[0], workers, objfcts)
408+
misfit.simulation: (future, worker)
409+
for misfit, (future, worker) in lookup.items()
409410
}
410411

411412
def residuals(self, m, f=None):
@@ -443,12 +444,12 @@ def broadcast_updates(self, updates: dict):
443444
if fun not in self._lookup:
444445
continue
445446

446-
objfct, worker = self._lookup[fun]
447+
future, worker = self._lookup[fun]
447448

448449
stores.append(
449450
client.submit(
450451
_setter_broadcast,
451-
objfct,
452+
future,
452453
key,
453454
value,
454455
workers=worker,

simpeg/directives/_regularization.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,19 +229,20 @@ def adjust_cooling_schedule(self):
229229
"""
230230
Adjust the cooling schedule based on the misfit.
231231
"""
232-
ratio = self.invProb.phi_d / self.misfit_from_chi_factor(self.chifact_target)
232+
if self.metrics.start_irls_iter is None:
233+
return
233234

234-
if (
235-
np.abs(1.0 - ratio) > self.misfit_tolerance
236-
and self.metrics.start_irls_iter is not None
237-
):
235+
ratio = self.invProb.phi_d / self.misfit_from_chi_factor(self.chifact_target)
236+
if np.abs(1.0 - ratio) > self.misfit_tolerance:
238237

239238
if ratio > 1:
240-
ratio = np.mean([2.0, ratio])
239+
update_ratio = 1 / np.mean([0.75, 1 / ratio])
241240
else:
242-
ratio = np.mean([0.75, ratio])
241+
update_ratio = 1 / np.mean([2.0, 1 / ratio])
243242

244-
self.cooling_factor = ratio
243+
self.cooling_factor = update_ratio
244+
else:
245+
self.cooling_factor = 1.0
245246

246247
def initialize(self):
247248
"""

0 commit comments

Comments
 (0)