Skip to content

Commit b47a375

Browse files
committed
Add backward diff gradient regularization if rotated
1 parent 0edf62a commit b47a375

1 file changed

Lines changed: 53 additions & 36 deletions

File tree

simpeg_drivers/driver.py

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -442,74 +442,91 @@ def get_regularization(self):
442442
return BaseRegularization(mesh=self.inversion_mesh.mesh)
443443

444444
reg_funcs = []
445+
total_reg_funcs = []
445446
is_rotated = self.params.gradient_rotation is not None
446447
for mapping in self.mapping:
447-
reg = Sparse(
448-
self.inversion_mesh.mesh,
449-
active_cells=self.models.active_cells,
450-
mapping=mapping,
451-
reference_model=self.models.reference,
448+
reg_funcs.append(
449+
Sparse(
450+
self.inversion_mesh.mesh,
451+
active_cells=self.models.active_cells,
452+
mapping=mapping,
453+
reference_model=self.models.reference,
454+
)
452455
)
453456

454457
if is_rotated:
455-
neighbors = cell_neighbors(reg.regularization_mesh.mesh)
458+
reg_funcs.append(
459+
Sparse(
460+
self.inversion_mesh.mesh,
461+
active_cells=self.models.active_cells,
462+
mapping=mapping,
463+
reference_model=self.models.reference,
464+
)
465+
)
466+
neighbors = cell_neighbors(reg_funcs[0].regularization_mesh.mesh)
456467

457468
# Adjustment for 2D versus 3D problems
458469
comps = "sxz" if "2d" in self.params.inversion_type else "sxyz"
459470
avg_comps = "sxy" if "2d" in self.params.inversion_type else "sxyz"
460471
weights = ["alpha_s"] + [f"length_scale_{k}" for k in comps[1:]]
461472

462473
for comp, avg_comp, objfct, weight in zip(
463-
comps, avg_comps, reg.objfcts, weights
474+
comps, avg_comps, reg_funcs[0].objfcts, weights
464475
):
465476
if getattr(self.models, weight) is None:
466-
setattr(reg, weight, 0.0)
477+
setattr(reg_funcs[0], weight, 0.0)
467478
continue
468479

469480
weight = mapping * getattr(self.models, weight)
470481
norm = mapping * getattr(self.models, f"{comp}_norm")
471482
if comp in "xyz":
472483
if is_rotated:
473-
grad_op = rotated_gradient(
474-
mesh=reg.regularization_mesh.mesh,
475-
neighbors=neighbors,
476-
axis=comp,
477-
dip=self.models.gradient_dip,
478-
direction=self.models.gradient_direction,
479-
)
480-
setattr(
481-
reg.regularization_mesh,
482-
f"_cell_gradient_{comp}",
483-
reg.regularization_mesh.Pac.T
484-
@ (grad_op @ reg.regularization_mesh.Pac),
485-
)
486-
setattr(
487-
reg.regularization_mesh,
488-
f"_aveCC2F{avg_comp}",
489-
sdiag(np.ones(reg.regularization_mesh.n_cells)),
490-
)
491-
484+
for reg, forward in zip(reg_funcs, [True, False]):
485+
grad_op = rotated_gradient(
486+
mesh=reg.regularization_mesh.mesh,
487+
neighbors=neighbors,
488+
axis=comp,
489+
dip=self.models.gradient_dip,
490+
direction=self.models.gradient_direction,
491+
forward=forward,
492+
)
493+
setattr(
494+
reg.regularization_mesh,
495+
f"_cell_gradient_{comp}",
496+
reg.regularization_mesh.Pac.T
497+
@ (grad_op @ reg.regularization_mesh.Pac),
498+
)
499+
setattr(
500+
reg.regularization_mesh,
501+
f"_aveCC2F{avg_comp}",
502+
sdiag(np.ones(reg.regularization_mesh.n_cells)),
503+
)
492504
else:
493505
weight = (
494-
getattr(reg.regularization_mesh, f"aveCC2F{avg_comp}")
506+
getattr(
507+
reg_funcs[0].regularization_mesh, f"aveCC2F{avg_comp}"
508+
)
495509
* weight
496510
)
497511
norm = (
498-
getattr(reg.regularization_mesh, f"aveCC2F{avg_comp}")
512+
getattr(
513+
reg_funcs[0].regularization_mesh, f"aveCC2F{avg_comp}"
514+
)
499515
* norm
500516
)
501517

502518
objfct.set_weights(**{comp: weight})
503519
objfct.norm = norm
504520

505521
if getattr(self.params, "gradient_type") is not None:
506-
setattr(
507-
reg,
508-
"gradient_type",
509-
getattr(self.params, "gradient_type"),
510-
)
511-
512-
reg_funcs.append(reg)
522+
for reg in reg_funcs:
523+
setattr(
524+
reg,
525+
"gradient_type",
526+
getattr(self.params, "gradient_type"),
527+
)
528+
529+
total_reg_funcs += reg_funcs
513530

514531
return objective_function.ComboObjectiveFunction(objfcts=reg_funcs)
515532

0 commit comments

Comments
 (0)