@@ -2987,23 +2987,29 @@ class SaveIterationsGeoH5(InversionDirective):
29872987 Saves inversion results to a geoh5 file
29882988 """
29892989
2990- def __init__ (self , h5_object , ** kwargs ):
2990+ def __init__ (
2991+ self , h5_object , dmisfit = None , attribute_type : str = "model" , ** kwargs
2992+ ):
29912993 self .data_type = {}
29922994 self ._association = None
2993- self .attribute_type = "model"
2995+ self .attribute_type = attribute_type
29942996 self ._label = None
29952997 self .channels = ["" ]
29962998 self .components = ["" ]
2997- self ._h5_object = None
2998- self ._workspace = None
29992999 self ._transforms : list = []
30003000 self .save_objective_function = False
30013001 self .sorting = None
30023002 self ._reshape = None
30033003 self .h5_object = h5_object
30043004 self ._joint_index = None
3005+
3006+ if attribute_type == "sensitivities" and dmisfit is None :
3007+ raise ValueError (
3008+ "To save sensitivities, the data misfit object must be provided."
3009+ )
3010+
30053011 super ().__init__ (
3006- inversion = None , dmisfit = None , reg = None , verbose = False , ** kwargs
3012+ inversion = None , dmisfit = dmisfit , reg = None , verbose = False , ** kwargs
30073013 )
30083014
30093015 def initialize (self ):
@@ -3085,9 +3091,10 @@ def get_values(self, values: list[np.ndarray] | None):
30853091
30863092 prop = self .stack_channels (dpred )
30873093 elif self .attribute_type == "sensitivities" :
3088- for directive in self .inversion .directiveList .dList :
3089- if isinstance (directive , directives .UpdateSensitivityWeights ):
3090- prop = self .reshape (np .sum (directive .JtJdiag , axis = 0 ) ** 0.5 )
3094+
3095+ prop = np .zeros_like (self .invProb .model )
3096+ for fun in self .dmisfit .objfcts :
3097+ prop += fun .getJtJdiag (self .invProb .model )
30913098
30923099 return prop
30933100
0 commit comments