11from abc import ABCMeta , abstractmethod
2- from typing import TYPE_CHECKING
2+ from typing import Iterable , TYPE_CHECKING
33
44from datetime import datetime
55import pathlib
@@ -2655,6 +2655,16 @@ def endIter(self):
26552655 self .opt .xc = m
26562656
26572657
2658+ def flatten (nested_iterable ):
2659+ for item in nested_iterable :
2660+ if isinstance (item , list ):
2661+ yield from flatten (item )
2662+ elif isinstance (item , np .ndarray ):
2663+ yield from item .tolist ()
2664+ else :
2665+ yield item
2666+
2667+
26582668class ScaleMisfitMultipliers (InversionDirective ):
26592669 """
26602670 Scale the misfits by the relative chi-factors of multiple misfit functions.
@@ -2670,9 +2680,17 @@ class ScaleMisfitMultipliers(InversionDirective):
26702680 Path to save the chi-factors log file.
26712681 """
26722682
2673- def __init__ (self , path : pathlib .Path | None = None , ** kwargs ):
2683+ def __init__ (
2684+ self ,
2685+ path : pathlib .Path | None = None ,
2686+ nesting : list [list ] | None = None ,
2687+ target_chi : float = 1.0 ,
2688+ ** kwargs ,
2689+ ):
26742690 self .last_beta = None
26752691 self .chi_factors = None
2692+ self .target_chi = target_chi
2693+ self .nesting = nesting
26762694
26772695 if path is None :
26782696 path = pathlib .Path ("./" )
@@ -2684,42 +2702,77 @@ def __init__(self, path: pathlib.Path | None = None, **kwargs):
26842702 def initialize (self ):
26852703 self .last_beta = self .invProb .beta
26862704 self .multipliers = self .invProb .dmisfit .multipliers
2687- self .scalings = np .ones_like (self .multipliers )
2688- with open (self .filepath , "w" , encoding = "utf-8" ) as f :
2689- f .write ("Logging of [scaling * chi factor] per misfit function.\n \n " )
2690- f .write (
2691- "Iterations\t "
2692- + "\t " .join (
2693- f"[{ objfct .name } ]" for objfct in self .invProb .dmisfit .objfcts
2694- )
2695- )
2696- f .write ("\n " )
2705+ self .scalings = np .ones_like (self .multipliers ) # Everyone gets a fair chance
2706+ self .misfit_tree_indices = self .parse_by_nested_levels (self .nesting )
2707+ # def append_labels(_, depth):
2708+ # return f"[{depth}]"
2709+ #
2710+ # with open(self.filepath, "w", encoding="utf-8") as f:
2711+ # f.write("Logging of [scaling * chi factor] per misfit function.\n\n")
2712+ #
2713+ # header = "Iterations\t"
2714+ # for elem in recursions(self.nested_misfits, append_labels):
2715+ # header += "\t".join(f"Misfit [{elem}]")
2716+ #
2717+ # f.write("\n")
2718+
2719+ def scalings_by_level (
2720+ self , nested_values , nested_indices , ratio , scaling_vector : np .ndarray | None
2721+ ):
2722+ """
2723+ Recursively compute scaling factors for each level of the nested misfit structure.
2724+
2725+ The maximum chi-factor at each level is used to determine scaling factors
2726+ for the misfit functions at that level. The scaling factors are then propagated
2727+ down to the next level of the nested structure.
2728+ """
2729+ if scaling_vector is None :
2730+ scaling_vector = np .ones (len (self .invProb .dmisfit .multipliers ))
26972731
2698- def endIter (self ):
2699- ratio = self .invProb .beta / self .last_beta
27002732 chi_factors = []
2701- for residual in self .invProb .residuals :
2702- phi_d = np .vdot (residual , residual )
2703- chi_factors .append (phi_d / len (residual ))
2733+ flat_indices = []
2734+ for elem , indices in zip (nested_values , nested_indices ):
27042735
2705- self .chi_factors = np .asarray (chi_factors )
2736+ # Reach the outer most level
2737+ if not isinstance (indices , list ) or (
2738+ len (indices ) == 1 and not isinstance (indices [0 ], list )
2739+ ):
2740+ return scaling_vector
27062741
2707- if np . all ( self . chi_factors < 1 ) or ratio >= 1 :
2708- self . last_beta = self . invProb . beta
2709- self . write_log ( )
2710- return
2742+ flat_indices . append ( np . asarray ( list ( flatten ( indices ))))
2743+ residuals = np . asarray ( list ( flatten ( elem )))
2744+ phi_d = np . vdot ( residuals , residuals )
2745+ chi_factors . append ( phi_d / len ( residuals ))
27112746
2712- # Normalize scaling between [ratio, 1]
2747+ chi_factors = np . hstack ( chi_factors )
27132748 scalings = (
2714- 1
2715- - (1 - ratio )
2716- * (self .chi_factors .max () - self .chi_factors )
2717- / self .chi_factors .max ()
2749+ 1 - (1 - ratio ) * (chi_factors .max () - chi_factors ) / chi_factors .max ()
27182750 )
27192751
27202752 # Force the ones that overshot target
2721- scalings [self .chi_factors < 1 ] = (
2722- ratio # * self.chi_factors[self.chi_factors < 1]
2753+ scalings [chi_factors < self .target_chi ] = ratio
2754+
2755+ for elem , indices , scale , group_ind in zip (
2756+ nested_values , nested_indices , scalings , flat_indices
2757+ ):
2758+ # Scale everything below same as super group
2759+ scaling_vector [group_ind ] = np .maximum (
2760+ ratio , scale * scaling_vector [group_ind ]
2761+ )
2762+ scaling_vector = self .scalings_by_level (
2763+ elem , indices , ratio , scaling_vector
2764+ )
2765+
2766+ return scaling_vector
2767+
2768+ def endIter (self ):
2769+ ratio = self .invProb .beta / self .last_beta
2770+ nested_residuals = self .parse_by_nested_levels (
2771+ self .nesting , self .invProb .residuals
2772+ )
2773+
2774+ scalings = self .scalings_by_level (
2775+ nested_residuals , self .misfit_tree_indices , ratio , None
27232776 )
27242777
27252778 # Update the scaling
@@ -2728,7 +2781,56 @@ def endIter(self):
27282781 # Normalize total phi_d with scalings
27292782 self .invProb .dmisfit .multipliers = self .multipliers * self .scalings
27302783 self .last_beta = self .invProb .beta
2731- self .write_log ()
2784+ # self.write_log()
2785+
2786+ def parse_by_nested_levels (
2787+ self , nesting : list [Iterable ], values : Iterable | None = None
2788+ ) -> Iterable :
2789+ """
2790+ Replace leaf elements of `nesting` with values from `flat` (in order).
2791+ Assumes the number of leaf positions equals len(flat).
2792+
2793+ Parameters:
2794+ - nesting: arbitrarily nested list structure; leaves are non-list values
2795+ - flat: flat iterable whose values will fill the leaves in order
2796+
2797+ Returns:
2798+ - A new nested structure with leaves replaced by values from `flat`.
2799+
2800+ Raises:
2801+ - ValueError if `flat` has fewer or more elements than required by `nesting`.
2802+ """
2803+ indices = np .arange (len (self .invProb .dmisfit .objfcts ))
2804+ if nesting is None :
2805+ if values is not None :
2806+ return values
2807+ return indices .tolist ()
2808+
2809+ it = iter (indices )
2810+
2811+ def _fill (node : Iterable ) -> Iterable :
2812+ if isinstance (node , list ):
2813+ return [_fill (child ) for child in node ]
2814+ elif isinstance (node , dict ):
2815+ return [_fill (child ) for child in node .values ()]
2816+ # leaf: consume a value
2817+ try :
2818+ if values is not None :
2819+ return values [next (it )]
2820+ return next (it )
2821+ except StopIteration :
2822+ raise ValueError ("Not enough elements in `flat` to fill `nesting`." )
2823+
2824+ result = _fill (nesting )
2825+
2826+ # ensure no extra elements left
2827+ try :
2828+ next (it )
2829+ raise ValueError ("Too many elements in `flat` for the given `nesting`." )
2830+ except StopIteration :
2831+ pass
2832+
2833+ return result
27322834
27332835 def write_log (self ):
27342836 """
0 commit comments