Skip to content

Commit 51f05b1

Browse files
committed
Change implementation of rescaling
1 parent 1cf096d commit 51f05b1

1 file changed

Lines changed: 132 additions & 30 deletions

File tree

simpeg/directives/_directives.py

Lines changed: 132 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABCMeta, abstractmethod
2-
from typing import TYPE_CHECKING
2+
from typing import Iterable, TYPE_CHECKING
33

44
from datetime import datetime
55
import pathlib
@@ -2655,6 +2655,16 @@ def endIter(self):
26552655
self.opt.xc = m
26562656

26572657

2658+
def flatten(nested_iterable):
    """
    Lazily yield the scalar leaves of an arbitrarily nested iterable.

    Nested ``list`` items are traversed recursively. ``numpy.ndarray`` items
    of any dimension (including 0-d) contribute their elements in row-major
    order, converted to Python scalars. Every other item, tuples included,
    is yielded unchanged as a single leaf.

    Parameters
    ----------
    nested_iterable : iterable
        Iterable whose items may be lists, numpy arrays or plain values.

    Yields
    ------
    object
        The leaves, in depth-first order.
    """
    for item in nested_iterable:
        if isinstance(item, list):
            yield from flatten(item)
        elif isinstance(item, np.ndarray):
            # ravel() first: tolist() alone yields nested row lists for N-D
            # arrays and a bare (non-iterable) scalar for 0-d arrays.
            yield from item.ravel().tolist()
        else:
            yield item
2666+
2667+
26582668
class ScaleMisfitMultipliers(InversionDirective):
26592669
"""
26602670
Scale the misfits by the relative chi-factors of multiple misfit functions.
@@ -2670,9 +2680,17 @@ class ScaleMisfitMultipliers(InversionDirective):
26702680
Path to save the chi-factors log file.
26712681
"""
26722682

2673-
def __init__(self, path: pathlib.Path | None = None, **kwargs):
2683+
def __init__(
2684+
self,
2685+
path: pathlib.Path | None = None,
2686+
nesting: list[list] | None = None,
2687+
target_chi: float = 1.0,
2688+
**kwargs,
2689+
):
26742690
self.last_beta = None
26752691
self.chi_factors = None
2692+
self.target_chi = target_chi
2693+
self.nesting = nesting
26762694

26772695
if path is None:
26782696
path = pathlib.Path("./")
@@ -2684,42 +2702,77 @@ def __init__(self, path: pathlib.Path | None = None, **kwargs):
26842702
def initialize(self):
26852703
self.last_beta = self.invProb.beta
26862704
self.multipliers = self.invProb.dmisfit.multipliers
2687-
self.scalings = np.ones_like(self.multipliers)
2688-
with open(self.filepath, "w", encoding="utf-8") as f:
2689-
f.write("Logging of [scaling * chi factor] per misfit function.\n\n")
2690-
f.write(
2691-
"Iterations\t"
2692-
+ "\t".join(
2693-
f"[{objfct.name}]" for objfct in self.invProb.dmisfit.objfcts
2694-
)
2695-
)
2696-
f.write("\n")
2705+
self.scalings = np.ones_like(self.multipliers) # Everyone gets a fair chance
2706+
self.misfit_tree_indices = self.parse_by_nested_levels(self.nesting)
2707+
# def append_labels(_, depth):
2708+
# return f"[{depth}]"
2709+
#
2710+
# with open(self.filepath, "w", encoding="utf-8") as f:
2711+
# f.write("Logging of [scaling * chi factor] per misfit function.\n\n")
2712+
#
2713+
# header = "Iterations\t"
2714+
# for elem in recursions(self.nested_misfits, append_labels):
2715+
# header += "\t".join(f"Misfit [{elem}]")
2716+
#
2717+
# f.write("\n")
2718+
2719+
def scalings_by_level(
2720+
self, nested_values, nested_indices, ratio, scaling_vector: np.ndarray | None
2721+
):
2722+
"""
2723+
Recursively compute scaling factors for each level of the nested misfit structure.
2724+
2725+
The maximum chi-factor at each level is used to determine scaling factors
2726+
for the misfit functions at that level. The scaling factors are then propagated
2727+
down to the next level of the nested structure.
2728+
"""
2729+
if scaling_vector is None:
2730+
scaling_vector = np.ones(len(self.invProb.dmisfit.multipliers))
26972731

2698-
def endIter(self):
2699-
ratio = self.invProb.beta / self.last_beta
27002732
chi_factors = []
2701-
for residual in self.invProb.residuals:
2702-
phi_d = np.vdot(residual, residual)
2703-
chi_factors.append(phi_d / len(residual))
2733+
flat_indices = []
2734+
for elem, indices in zip(nested_values, nested_indices):
27042735

2705-
self.chi_factors = np.asarray(chi_factors)
2736+
# Reached the deepest (leaf) level: no sub-groups left to scale
2737+
if not isinstance(indices, list) or (
2738+
len(indices) == 1 and not isinstance(indices[0], list)
2739+
):
2740+
return scaling_vector
27062741

2707-
if np.all(self.chi_factors < 1) or ratio >= 1:
2708-
self.last_beta = self.invProb.beta
2709-
self.write_log()
2710-
return
2742+
flat_indices.append(np.asarray(list(flatten(indices))))
2743+
residuals = np.asarray(list(flatten(elem)))
2744+
phi_d = np.vdot(residuals, residuals)
2745+
chi_factors.append(phi_d / len(residuals))
27112746

2712-
# Normalize scaling between [ratio, 1]
2747+
chi_factors = np.hstack(chi_factors)
27132748
scalings = (
2714-
1
2715-
- (1 - ratio)
2716-
* (self.chi_factors.max() - self.chi_factors)
2717-
/ self.chi_factors.max()
2749+
1 - (1 - ratio) * (chi_factors.max() - chi_factors) / chi_factors.max()
27182750
)
27192751

27202752
# Force the ones that overshot target
2721-
scalings[self.chi_factors < 1] = (
2722-
ratio # * self.chi_factors[self.chi_factors < 1]
2753+
scalings[chi_factors < self.target_chi] = ratio
2754+
2755+
for elem, indices, scale, group_ind in zip(
2756+
nested_values, nested_indices, scalings, flat_indices
2757+
):
2758+
# Scale everything below same as super group
2759+
scaling_vector[group_ind] = np.maximum(
2760+
ratio, scale * scaling_vector[group_ind]
2761+
)
2762+
scaling_vector = self.scalings_by_level(
2763+
elem, indices, ratio, scaling_vector
2764+
)
2765+
2766+
return scaling_vector
2767+
2768+
def endIter(self):
2769+
ratio = self.invProb.beta / self.last_beta
2770+
nested_residuals = self.parse_by_nested_levels(
2771+
self.nesting, self.invProb.residuals
2772+
)
2773+
2774+
scalings = self.scalings_by_level(
2775+
nested_residuals, self.misfit_tree_indices, ratio, None
27232776
)
27242777

27252778
# Update the scaling
@@ -2728,7 +2781,56 @@ def endIter(self):
27282781
# Normalize total phi_d with scalings
27292782
self.invProb.dmisfit.multipliers = self.multipliers * self.scalings
27302783
self.last_beta = self.invProb.beta
2731-
self.write_log()
2784+
# self.write_log()
2785+
2786+
def parse_by_nested_levels(
2787+
self, nesting: list[Iterable], values: Iterable | None = None
2788+
) -> Iterable:
2789+
"""
2790+
Replace leaf elements of `nesting` with values from `flat` (in order).
2791+
Assumes the number of leaf positions equals len(flat).
2792+
2793+
Parameters:
2794+
- nesting: arbitrarily nested list structure; leaves are non-list values
2795+
- flat: flat iterable whose values will fill the leaves in order
2796+
2797+
Returns:
2798+
- A new nested structure with leaves replaced by values from `flat`.
2799+
2800+
Raises:
2801+
- ValueError if `flat` has fewer or more elements than required by `nesting`.
2802+
"""
2803+
indices = np.arange(len(self.invProb.dmisfit.objfcts))
2804+
if nesting is None:
2805+
if values is not None:
2806+
return values
2807+
return indices.tolist()
2808+
2809+
it = iter(indices)
2810+
2811+
def _fill(node: Iterable) -> Iterable:
2812+
if isinstance(node, list):
2813+
return [_fill(child) for child in node]
2814+
elif isinstance(node, dict):
2815+
return [_fill(child) for child in node.values()]
2816+
# leaf: consume a value
2817+
try:
2818+
if values is not None:
2819+
return values[next(it)]
2820+
return next(it)
2821+
except StopIteration:
2822+
raise ValueError("Not enough elements in `flat` to fill `nesting`.")
2823+
2824+
result = _fill(nesting)
2825+
2826+
# ensure no extra elements left
2827+
try:
2828+
next(it)
2829+
raise ValueError("Too many elements in `flat` for the given `nesting`.")
2830+
except StopIteration:
2831+
pass
2832+
2833+
return result
27322834

27332835
def write_log(self):
27342836
"""

0 commit comments

Comments
 (0)