@@ -480,38 +480,31 @@ def get_regularization(self):
480480 weight = mapping * getattr (self .models , weight_name )
481481 norm = mapping * getattr (self .models , f"{ comp } _norm" )
482482
483- if isinstance (fun , SparseSmoothness ):
484- if is_rotated :
485- if forward_mesh is None :
486- fun = set_rotated_operators (
487- fun ,
488- neighbors ,
489- comp ,
490- self .models .gradient_dip ,
491- self .models .gradient_direction ,
492- )
493-
494- else :
495- weight = (
496- getattr (
497- reg_func .regularization_mesh ,
498- f"aveCC2F{ fun .orientation } " ,
499- )
500- * weight
501- )
502- norm = (
503- getattr (
504- reg_func .regularization_mesh ,
505- f"aveCC2F{ fun .orientation } " ,
506- )
507- * norm
483+ if not isinstance (fun , SparseSmoothness ):
484+ fun .set_weights (** {comp : weight })
485+ fun .norm = norm
486+ functions .append (fun )
487+ continue
488+
489+ if is_rotated :
490+ if forward_mesh is None :
491+ fun = set_rotated_operators (
492+ fun ,
493+ neighbors ,
494+ comp ,
495+ self .models .gradient_dip ,
496+ self .models .gradient_direction ,
508497 )
509498
510- fun .set_weights (** {comp : weight })
511- fun .norm = norm
499+ average_op = getattr (
500+ reg_func .regularization_mesh ,
501+ f"aveCC2F{ fun .orientation } " ,
502+ )
503+ fun .set_weights (** {comp : average_op @ weight })
504+ fun .norm = average_op @ norm
512505 functions .append (fun )
513506
514- if isinstance ( fun , SparseSmoothness ) and is_rotated :
507+ if is_rotated :
515508 fun .gradient_type = "components"
516509 backward_fun = deepcopy (fun )
517510 setattr (backward_fun , "_regularization_mesh" , backward_mesh )
@@ -526,6 +519,12 @@ def get_regularization(self):
526519 self .models .gradient_direction ,
527520 forward = False ,
528521 )
522+ average_op = getattr (
523+ backward_fun .regularization_mesh ,
524+ f"aveCC2F{ fun .orientation } " ,
525+ )
526+ backward_fun .set_weights (** {comp : average_op @ weight })
527+ backward_fun .norm = average_op @ norm
529528 functions .append (backward_fun )
530529
531530 # Will avoid recomputing operators if the regularization mesh is the same
0 commit comments