@@ -442,74 +442,91 @@ def get_regularization(self):
442442 return BaseRegularization (mesh = self .inversion_mesh .mesh )
443443
444444 reg_funcs = []
445+ total_reg_funcs = []
445446 is_rotated = self .params .gradient_rotation is not None
446447 for mapping in self .mapping :
447- reg = Sparse (
448- self .inversion_mesh .mesh ,
449- active_cells = self .models .active_cells ,
450- mapping = mapping ,
451- reference_model = self .models .reference ,
448+ reg_funcs .append (
449+ Sparse (
450+ self .inversion_mesh .mesh ,
451+ active_cells = self .models .active_cells ,
452+ mapping = mapping ,
453+ reference_model = self .models .reference ,
454+ )
452455 )
453456
454457 if is_rotated :
455- neighbors = cell_neighbors (reg .regularization_mesh .mesh )
458+ reg_funcs .append (
459+ Sparse (
460+ self .inversion_mesh .mesh ,
461+ active_cells = self .models .active_cells ,
462+ mapping = mapping ,
463+ reference_model = self .models .reference ,
464+ )
465+ )
466+ neighbors = cell_neighbors (reg_funcs [0 ].regularization_mesh .mesh )
456467
457468 # Adjustment for 2D versus 3D problems
458469 comps = "sxz" if "2d" in self .params .inversion_type else "sxyz"
459470 avg_comps = "sxy" if "2d" in self .params .inversion_type else "sxyz"
460471 weights = ["alpha_s" ] + [f"length_scale_{ k } " for k in comps [1 :]]
461472
462473 for comp , avg_comp , objfct , weight in zip (
463- comps , avg_comps , reg .objfcts , weights
474+ comps , avg_comps , reg_funcs [ 0 ] .objfcts , weights
464475 ):
465476 if getattr (self .models , weight ) is None :
466- setattr (reg , weight , 0.0 )
477+ setattr (reg_funcs [ 0 ] , weight , 0.0 )
467478 continue
468479
469480 weight = mapping * getattr (self .models , weight )
470481 norm = mapping * getattr (self .models , f"{ comp } _norm" )
471482 if comp in "xyz" :
472483 if is_rotated :
473- grad_op = rotated_gradient (
474- mesh = reg .regularization_mesh .mesh ,
475- neighbors = neighbors ,
476- axis = comp ,
477- dip = self .models .gradient_dip ,
478- direction = self .models .gradient_direction ,
479- )
480- setattr (
481- reg .regularization_mesh ,
482- f"_cell_gradient_{ comp } " ,
483- reg .regularization_mesh .Pac .T
484- @ (grad_op @ reg .regularization_mesh .Pac ),
485- )
486- setattr (
487- reg .regularization_mesh ,
488- f"_aveCC2F{ avg_comp } " ,
489- sdiag (np .ones (reg .regularization_mesh .n_cells )),
490- )
491-
484+ for reg , forward in zip (reg_funcs , [True , False ]):
485+ grad_op = rotated_gradient (
486+ mesh = reg .regularization_mesh .mesh ,
487+ neighbors = neighbors ,
488+ axis = comp ,
489+ dip = self .models .gradient_dip ,
490+ direction = self .models .gradient_direction ,
491+ forward = forward ,
492+ )
493+ setattr (
494+ reg .regularization_mesh ,
495+ f"_cell_gradient_{ comp } " ,
496+ reg .regularization_mesh .Pac .T
497+ @ (grad_op @ reg .regularization_mesh .Pac ),
498+ )
499+ setattr (
500+ reg .regularization_mesh ,
501+ f"_aveCC2F{ avg_comp } " ,
502+ sdiag (np .ones (reg .regularization_mesh .n_cells )),
503+ )
492504 else :
493505 weight = (
494- getattr (reg .regularization_mesh , f"aveCC2F{ avg_comp } " )
506+ getattr (
507+ reg_funcs [0 ].regularization_mesh , f"aveCC2F{ avg_comp } "
508+ )
495509 * weight
496510 )
497511 norm = (
498- getattr (reg .regularization_mesh , f"aveCC2F{ avg_comp } " )
512+ getattr (
513+ reg_funcs [0 ].regularization_mesh , f"aveCC2F{ avg_comp } "
514+ )
499515 * norm
500516 )
501517
502518 objfct .set_weights (** {comp : weight })
503519 objfct .norm = norm
504520
505521 if getattr (self .params , "gradient_type" ) is not None :
506- setattr (
507- reg ,
508- "gradient_type" ,
509- getattr (self .params , "gradient_type" ),
510- )
511-
512- reg_funcs .append (reg )
522+ for reg in reg_funcs :
523+ setattr (
524+ reg ,
525+ "gradient_type" ,
526+ getattr (self .params , "gradient_type" ),
527+ )
528+
529+ total_reg_funcs += reg_funcs
513530
514531 return objective_function .ComboObjectiveFunction (objfcts = reg_funcs )
515532
0 commit comments