Skip to content

Commit 2f31e14

Browse files
committed
Clip out rows of zero gradients. Simplify get_regularization
1 parent ff56ac1 commit 2f31e14

2 files changed

Lines changed: 34 additions & 31 deletions

File tree

simpeg_drivers/driver.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -480,38 +480,31 @@ def get_regularization(self):
480480
weight = mapping * getattr(self.models, weight_name)
481481
norm = mapping * getattr(self.models, f"{comp}_norm")
482482

483-
if isinstance(fun, SparseSmoothness):
484-
if is_rotated:
485-
if forward_mesh is None:
486-
fun = set_rotated_operators(
487-
fun,
488-
neighbors,
489-
comp,
490-
self.models.gradient_dip,
491-
self.models.gradient_direction,
492-
)
493-
494-
else:
495-
weight = (
496-
getattr(
497-
reg_func.regularization_mesh,
498-
f"aveCC2F{fun.orientation}",
499-
)
500-
* weight
501-
)
502-
norm = (
503-
getattr(
504-
reg_func.regularization_mesh,
505-
f"aveCC2F{fun.orientation}",
506-
)
507-
* norm
483+
if not isinstance(fun, SparseSmoothness):
484+
fun.set_weights(**{comp: weight})
485+
fun.norm = norm
486+
functions.append(fun)
487+
continue
488+
489+
if is_rotated:
490+
if forward_mesh is None:
491+
fun = set_rotated_operators(
492+
fun,
493+
neighbors,
494+
comp,
495+
self.models.gradient_dip,
496+
self.models.gradient_direction,
508497
)
509498

510-
fun.set_weights(**{comp: weight})
511-
fun.norm = norm
499+
average_op = getattr(
500+
reg_func.regularization_mesh,
501+
f"aveCC2F{fun.orientation}",
502+
)
503+
fun.set_weights(**{comp: average_op @ weight})
504+
fun.norm = average_op @ norm
512505
functions.append(fun)
513506

514-
if isinstance(fun, SparseSmoothness) and is_rotated:
507+
if is_rotated:
515508
fun.gradient_type = "components"
516509
backward_fun = deepcopy(fun)
517510
setattr(backward_fun, "_regularization_mesh", backward_mesh)
@@ -526,6 +519,12 @@ def get_regularization(self):
526519
self.models.gradient_direction,
527520
forward=False,
528521
)
522+
average_op = getattr(
523+
backward_fun.regularization_mesh,
524+
f"aveCC2F{fun.orientation}",
525+
)
526+
backward_fun.set_weights(**{comp: average_op @ weight})
527+
backward_fun.norm = average_op @ norm
529528
functions.append(backward_fun)
530529

531530
# Will avoid recomputing operators if the regularization mesh is the same

simpeg_drivers/utils/regularization.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,16 +415,20 @@ def set_rotated_operators(
415415
grad_op = rotated_gradient(
416416
function.regularization_mesh.mesh, neighbors, axis, dip, direction, forward
417417
)
418+
grad_op_active = function.regularization_mesh.Pac.T @ (
419+
grad_op @ function.regularization_mesh.Pac
420+
)
421+
active_faces = grad_op_active.max(axis=1).toarray().ravel() > 0
422+
418423
setattr(
419424
function.regularization_mesh,
420425
f"_cell_gradient_{function.orientation}",
421-
function.regularization_mesh.Pac.T
422-
@ (grad_op @ function.regularization_mesh.Pac),
426+
grad_op_active[active_faces, :],
423427
)
424428
setattr(
425429
function.regularization_mesh,
426430
f"_aveCC2F{function.orientation}",
427-
sdiag(np.ones(function.regularization_mesh.n_cells)),
431+
sdiag(np.ones(function.regularization_mesh.n_cells))[active_faces, :],
428432
)
429433

430434
return function

0 commit comments

Comments
 (0)