11from abc import ABCMeta , abstractmethod
2- from typing import TYPE_CHECKING
2+ from typing import Iterable , TYPE_CHECKING
33
44from datetime import datetime
55import pathlib
@@ -2655,6 +2655,16 @@ def endIter(self):
26552655 self .opt .xc = m
26562656
26572657
2658+ def flatten (nested_iterable ):
2659+ for item in nested_iterable :
2660+ if isinstance (item , list ):
2661+ yield from flatten (item )
2662+ elif isinstance (item , np .ndarray ):
2663+ yield from item .tolist ()
2664+ else :
2665+ yield item
2666+
2667+
26582668class ScaleMisfitMultipliers (InversionDirective ):
26592669 """
26602670 Scale the misfits by the relative chi-factors of multiple misfit functions.
@@ -2670,9 +2680,19 @@ class ScaleMisfitMultipliers(InversionDirective):
26702680 Path to save the chi-factors log file.
26712681 """
26722682
2673- def __init__ (self , path : pathlib .Path | None = None , ** kwargs ):
2683+ def __init__ (
2684+ self ,
2685+ path : pathlib .Path | None = None ,
2686+ nesting : list [list ] | None = None ,
2687+ target_chi : float = 1.0 ,
2688+ headers : list [str ] | None = None ,
2689+ ** kwargs ,
2690+ ):
26742691 self .last_beta = None
26752692 self .chi_factors = None
2693+ self .target_chi = target_chi
2694+ self .nesting = nesting
2695+ self .headers = headers
26762696
26772697 if path is None :
26782698 path = pathlib .Path ("./" )
@@ -2681,45 +2701,112 @@ def __init__(self, path: pathlib.Path | None = None, **kwargs):
26812701
26822702 super ().__init__ (** kwargs )
26832703
2704+ self ._log_array : np .ndarray | None = None
2705+
2706+ @property
2707+ def log_array (self , headers : list [str ] | None = None ):
2708+ if self ._log_array is None :
2709+ if self .headers is None :
2710+
2711+ def append_sub_indices (elements , header ):
2712+ values = []
2713+ for ii , elem in enumerate (elements ):
2714+ heads = header + f"[{ ii } ]"
2715+ if isinstance (elem , Iterable ):
2716+ values += append_sub_indices (elem , heads )
2717+ else :
2718+ values += [heads ]
2719+ return values
2720+
2721+ headers = []
2722+ for ii , elem in enumerate (self .misfit_tree_indices ):
2723+ headers += append_sub_indices (elem , f"[{ ii } ]" )
2724+ self .headers = headers
2725+
2726+ dtype = np .dtype (
2727+ [("Iterations" , np .int32 )] + [(h , np .float32 ) for h in self .headers ]
2728+ )
2729+ self ._log_array = np .rec .fromrecords ((), dtype = dtype )
2730+
2731+ return self ._log_array
2732+
26842733 def initialize (self ):
26852734 self .last_beta = self .invProb .beta
26862735 self .multipliers = self .invProb .dmisfit .multipliers
2687- self .scalings = np .ones_like (self .multipliers )
2688- with open (self .filepath , "w" , encoding = "utf-8" ) as f :
2689- f .write ("Logging of [scaling * chi factor] per misfit function.\n \n " )
2690- f .write (
2691- "Iterations\t "
2692- + "\t " .join (
2693- f"[{ objfct .name } ]" for objfct in self .invProb .dmisfit .objfcts
2694- )
2695- )
2696- f .write ("\n " )
2736+ self .scalings = np .ones_like (self .multipliers ) # Everyone gets a fair chance
2737+ self .misfit_tree_indices = self .parse_by_nested_levels (self .nesting )
26972738
2698- def endIter (self ):
2699- ratio = self .invProb .beta / self .last_beta
2700- chi_factors = []
2701- for residual in self .invProb .residuals :
2702- phi_d = np .vdot (residual , residual )
2703- chi_factors .append (phi_d / len (residual ))
2739+ self .write_log ()
27042740
2705- self .chi_factors = np .asarray (chi_factors )
2741+ def scale_by_level (
2742+ self , nested_values , nested_indices , ratio , scaling_vector : np .ndarray | None
2743+ ):
2744+ """
2745+ Recursively compute scaling factors for each level of the nested misfit structure.
27062746
2707- if np .all (self .chi_factors < 1 ) or ratio >= 1 :
2708- self .last_beta = self .invProb .beta
2709- self .write_log ()
2710- return
2747+ The maximum chi-factor at each level is used to determine scaling factors
2748+ for the misfit functions at that level. The scaling factors are then propagated
2749+ down to the next level of the nested structure.
2750+
2751+ Parameters
2752+ ----------
2753+ nested_values : list
2754+ Nested list of misfit residuals.
2755+
2756+ nested_indices : list
2757+ Nested list of indices corresponding to the misfit residuals.
27112758
2712- # Normalize scaling between [ratio, 1]
2759+ ratio : float
2760+ Ratio of current beta to last beta.
2761+
2762+ scaling_vector : np.ndarray, optional
2763+ Current scaling vector to be updated.
2764+ """
2765+ if scaling_vector is None :
2766+ scaling_vector = np .ones (len (self .invProb .dmisfit .multipliers ))
2767+
2768+ chi_factors = []
2769+ flat_indices = []
2770+ for elem , indices in zip (nested_values , nested_indices ):
2771+ flat_indices .append (np .asarray (list (flatten (indices ))))
2772+ residuals = np .asarray (list (flatten (elem )))
2773+ phi_d = np .vdot (residuals , residuals )
2774+ chi_factors .append (phi_d / len (residuals ))
2775+
2776+ chi_factors = np .hstack (chi_factors )
27132777 scalings = (
2714- 1
2715- - (1 - ratio )
2716- * (self .chi_factors .max () - self .chi_factors )
2717- / self .chi_factors .max ()
2778+ 1 - (1 - ratio ) * (chi_factors .max () - chi_factors ) / chi_factors .max ()
27182779 )
27192780
27202781 # Force the ones that overshot target
2721- scalings [self .chi_factors < 1 ] = (
2722- ratio # * self.chi_factors[self.chi_factors < 1]
2782+ scalings [chi_factors < self .target_chi ] = ratio
2783+
2784+ for elem , indices , scale , group_ind in zip (
2785+ nested_values , nested_indices , scalings , flat_indices
2786+ ):
2787+ # Scale everything below same as super group
2788+ scaling_vector [group_ind ] = np .maximum (
2789+ ratio , scale * scaling_vector [group_ind ]
2790+ )
2791+
2792+ # Continue one level deeper if more nesting
2793+ if isinstance (indices , list ) and (
2794+ len (indices ) > 1 and isinstance (indices [0 ], list )
2795+ ):
2796+ scaling_vector = self .scale_by_level (
2797+ elem , indices , ratio , scaling_vector
2798+ )
2799+
2800+ return scaling_vector
2801+
2802+ def endIter (self ):
2803+ ratio = self .invProb .beta / self .last_beta
2804+ nested_residuals = self .parse_by_nested_levels (
2805+ self .nesting , self .invProb .residuals
2806+ )
2807+
2808+ scalings = self .scale_by_level (
2809+ nested_residuals , self .misfit_tree_indices , ratio , None
27232810 )
27242811
27252812 # Update the scaling
@@ -2728,22 +2815,76 @@ def endIter(self):
27282815 # Normalize total phi_d with scalings
27292816 self .invProb .dmisfit .multipliers = self .multipliers * self .scalings
27302817 self .last_beta = self .invProb .beta
2818+
2819+ # Log the scaling factors
27312820 self .write_log ()
27322821
2822+ def parse_by_nested_levels (
2823+ self , nesting : list [Iterable ], values : Iterable | None = None
2824+ ) -> Iterable :
2825+ """
2826+ Replace leaf elements of `nesting` with values from `values` (in order).
2827+ Assumes the number of leaf positions equals len(values).
2828+
2829+ Parameters:
2830+ - nesting: arbitrarily nested list structure; leaves are non-list values
2831+ - values: flat iterable whose values will fill the leaves in order
2832+
2833+ Returns:
2834+ - A new nested structure with leaves replaced by values from `values`.
2835+
2836+ Raises:
2837+ - ValueError if `values` has fewer or more elements than required by `nesting`.
2838+ """
2839+ indices = np .arange (len (self .invProb .dmisfit .objfcts ))
2840+ if nesting is None :
2841+ if values is not None :
2842+ return values
2843+ return indices .tolist ()
2844+
2845+ it = iter (indices )
2846+
2847+ def _fill (node : Iterable ) -> Iterable :
2848+ if isinstance (node , list ):
2849+ return [_fill (child ) for child in node ]
2850+ elif isinstance (node , dict ):
2851+ return [_fill (child ) for child in node .values ()]
2852+ # leaf: consume a value
2853+ try :
2854+ if values is not None :
2855+ return values [next (it )]
2856+ return next (it )
2857+ except StopIteration :
2858+ raise ValueError ("Not enough elements in `flat` to fill `nesting`." )
2859+
2860+ result = _fill (nesting )
2861+
2862+ # ensure no extra elements left
2863+ try :
2864+ next (it )
2865+ raise ValueError ("Too many elements in `flat` for the given `nesting`." )
2866+ except StopIteration :
2867+ pass
2868+
2869+ return result
2870+
27332871 def write_log (self ):
27342872 """
27352873 Write the scaling factors to the log file.
27362874 """
2737- with open (self .filepath , "a" , encoding = "utf-8" ) as f :
2738- f .write (
2739- f"{ self .opt .iter } \t "
2740- + "\t " .join (
2741- f"{ multi :.2e} * { chi :.2e} "
2742- for multi , chi in zip (
2743- self .invProb .dmisfit .multipliers , self .chi_factors
2744- )
2745- )
2746- + "\n "
2875+ self ._log_array = np .append (
2876+ self .log_array ,
2877+ np .rec .fromrecords (
2878+ tuple ([getattr (self .opt , "iter" , 0 )] + self .scalings .tolist ()),
2879+ dtype = self .log_array .dtype ,
2880+ ),
2881+ )
2882+ with open (self .filepath , "w" , encoding = "utf-8" ) as f :
2883+ np .savetxt (
2884+ f ,
2885+ self .log_array ,
2886+ header = "Iterations - Scaling per misfit" ,
2887+ fmt = ["%d" ] + ["%0.2e" ] * (len (self ._log_array .dtype ) - 1 ),
27472888 )
27482889
27492890
0 commit comments