Skip to content

Commit ed4bad0

Browse files
pr/new enrml esmda commits (#143)
* Update adjoint handling * Fix logic for logging in LM-EnRML * Include ensemble_misfit for LM-EnRML * Include enAdj for ESMDA
1 parent 6072a12 commit ed4bad0

3 files changed

Lines changed: 69 additions & 93 deletions

File tree

pipt/misc_tools/data_tools.py

Lines changed: 31 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,12 @@
66
__all__ = [
77
'combine_ensemble_predictions',
88
'en_pred_to_pred_data',
9-
'melt_adjoint_to_sensitivity',
10-
'combine_ensemble_dataframes',
11-
'combine_adjoint_ensemble',
9+
'merge_dataframes',
10+
'multilevel_to_singlelevel_columns',
1211
'dataframe_to_series',
1312
'series_to_dataframe',
1413
'series_to_matrix',
15-
'dataframe_to_matrix',
16-
'multilevel_to_singlelevel_columns'
14+
'dataframe_to_matrix'
1715
]
1816

1917

@@ -140,42 +138,7 @@ def en_pred_to_pred_data(en_pred):
140138
return pred_data
141139

142140

143-
def melt_adjoint_to_sensitivity(adjoint: pd.DataFrame, datatype: list, idX: dict):
144-
145-
adj_datatype = adjoint.columns.levels[0]
146-
adj_params = adjoint.columns.levels[1]
147-
148-
adj_datatype = sorted(adj_datatype, key=lambda x: datatype.index(x))
149-
adj_params = sorted(adj_params, key=lambda x: list(idX.keys()).index(x))
150-
151-
sens = pd.DataFrame(columns=adj_datatype, index=adjoint.index)
152-
for idx in sens.index:
153-
for dkey in adj_datatype:
154-
arr = np.array([])
155-
for param in adj_params:
156-
157-
if not isinstance(adjoint.at[idx, (dkey, param)], np.ndarray):
158-
if np.isnan(adjoint.at[idx, (dkey, param)]):
159-
dim = idX[param]
160-
dim = dim[1] - dim[0]
161-
arr = np.append(arr, np.zeros(dim))
162-
else:
163-
arr = np.append(arr, np.array([adjoint.at[idx, (dkey, param)]]))
164-
165-
else:
166-
a = adjoint.at[idx, (dkey, param)]
167-
a = np.where(np.isnan(a), 0, a)
168-
arr = np.append(arr, a)
169-
170-
sens.at[idx, dkey] = arr
171-
172-
# Melt
173-
sens = sens.melt(ignore_index=False)
174-
sens.rename(columns={'variable': 'datatype', 'value': 'adjoint'}, inplace=True)
175-
return sens
176-
177-
178-
def combine_ensemble_dataframes(en_dfs: list):
141+
def merge_dataframes(en_dfs: list[pd.DataFrame]) -> pd.DataFrame:
179142
'''
180143
Combine a list of DataFrames (one per ensemble member) into a single DataFrame
181144
where each cell contains an array of ensemble values.
@@ -193,32 +156,36 @@ def combine_ensemble_dataframes(en_dfs: list):
193156
values = []
194157
for dfn in en_dfs:
195158
values.append(dfn.at[idx, col])
196-
df.at[idx, col] = np.array(values).squeeze()
197-
159+
df.at[idx, col] = np.array(values).squeeze().T
198160
return df
199161

200-
def combine_adjoint_ensemble(en_adj, datatype: list, idX: dict):
201-
202-
adjoints = [melt_adjoint_to_sensitivity(adj, datatype, idX) for adj in en_adj]
162+
def multilevel_to_singlelevel_columns(df: pd.DataFrame) -> pd.DataFrame:
163+
"""
164+
Convert a MultiIndex-column DataFrame with structure (key, param)
165+
into a DataFrame with one column per key, where the value is
166+
the concatenation of all param-arrays for that key.
167+
"""
168+
result = {}
203169

204-
index = adjoints[0].index
205-
index_name = adjoints[0].index.name
206-
keys = adjoints[0]['datatype'].values
207-
keys = sorted(keys, key=lambda x: datatype.index(x))
170+
# Top-level keys (level 0 of MultiIndex), preserving first appearance order
171+
keys = pd.Index(df.columns.get_level_values(0)).unique()
208172

209-
#df = pd.DataFrame(columns=['datatype', 'adjoint'], index=index, dtype=object)
173+
for key in keys:
174+
# Extract all columns for this key → list of arrays per row
175+
param_arrays = df[key] # this is a sub-dataframe for this key
210176

211-
data = {'datatype': [], 'adjoint': []}
212-
for i, idx in enumerate(index):
213-
data['datatype'].append(keys[i])
214-
matrix = []
215-
for adj in adjoints:
216-
matrix.append(adj.iloc[i]['adjoint'])
217-
data['adjoint'].append(np.array(matrix).T) # Transpose to get correct shape (n_param, n_ensembles)
177+
# For each row, concatenate arrays from all params
178+
concatenated = [
179+
np.concatenate(param_arrays.iloc[i].values)
180+
for i in range(len(df))
181+
]
218182

219-
df = pd.DataFrame(data, index=index)
220-
df.index.name = index_name
221-
return df
183+
result[key] = concatenated
184+
185+
df_new = pd.DataFrame(result, index=df.index)
186+
df_new.index.name = df.index.name
187+
return df_new
188+
222189

223190
def dataframe_to_series(df):
224191
mult_index = []
@@ -250,16 +217,10 @@ def dataframe_to_matrix(df):
250217
series = dataframe_to_series(df)
251218
return series_to_matrix(series)
252219

253-
def multilevel_to_singlelevel_columns(df):
254-
cols = df.columns.get_level_values(0).unique()
255-
parms = df.columns.get_level_values(1).unique()
256-
257-
df_new = pd.DataFrame(index=df.index)
258-
for col in cols:
259-
df_new[col] = np.concatenate([df[(col, param)].values for param in parms])
260-
261-
return df_new
262220

221+
222+
223+
263224

264225

265226

pipt/update_schemes/enrml.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def calc_analysis(self):
133133
)
134134

135135
# Store the (mean) data misfit (also for conv. check)
136+
self.ensemble_misfit = data_misfit
136137
self.data_misfit = np.mean(data_misfit)
137138
self.prior_data_misfit = np.mean(data_misfit)
138139
self.data_misfit_std = np.std(data_misfit)
@@ -149,8 +150,8 @@ def calc_analysis(self):
149150

150151
# Check for adjoint
151152
if hasattr(self, 'adjoints'):
152-
enAdj = dtools.combine_ensemble_dataframes(self.adjoints)
153-
enAdj = dtools.dataframe_to_matrix(enAdj) # Shape (nd, ne, nx)
153+
enAdj = dtools.merge_dataframes(self.adjoints)
154+
enAdj = dtools.dataframe_to_matrix(enAdj) # Shape (nd, nx, ne)
154155
else:
155156
enAdj = None
156157

@@ -209,7 +210,7 @@ def check_convergence(self):
209210
# data instead.
210211

211212
data_misfit = at.calc_objectivefun(self.enObs, enPred, self.cov_data)
212-
213+
self.ensemble_misfit = data_misfit
213214
self.data_misfit = np.mean(data_misfit)
214215
self.data_misfit_std = np.std(data_misfit)
215216

@@ -229,11 +230,13 @@ def check_convergence(self):
229230

230231
if self.data_misfit >= self.prev_data_misfit:
231232
success = False
233+
self.log_update(success=success)
232234
self.logger(
233235
f'Iterations have converged after {self.iteration} iterations. Objective function reduced '
234236
f'from {self.prior_data_misfit:0.1f} to {self.prev_data_misfit:0.1f}'
235237
)
236238
else:
239+
self.log_update(success=True)
237240
self.logger.info(
238241
f'Iterations have converged after {self.iteration} iterations. Objective function reduced '
239242
f'from {self.prior_data_misfit:0.1f} to {self.data_misfit:0.1f}'
@@ -249,20 +252,21 @@ def check_convergence(self):
249252
'prev_data_misfit': self.prev_data_misfit,
250253
'lambda': self.lam,
251254
'lambda_stop': self.lam >= self.lam_max}
252-
253-
# Log step
254-
self.log_update(success=success)
255+
255256

256257
###############################################
257258
##### update Lambda step-size values ##########
258259
###############################################
259260
# If reduction in mean data misfit, reduce damping param
260261
if self.data_misfit < self.prev_data_misfit and self.data_misfit_std < self.prev_data_misfit_std:
261-
# Reduce damping parameter (divide calculations for ANALYSISDEBUG purpose)
262+
263+
success = True
264+
self.log_update(success=success)
265+
266+
# Reduce damping parameter
262267
if self.lam > self.lam_min:
263268
self.lam = self.lam / self.gamma
264269
self.logger(f'λ reduced: {self.lam * self.gamma} ──> {self.lam}')
265-
success = True
266270

267271
# Update state ensemble
268272
self.enX = cp.deepcopy(self.enX_temp)
@@ -274,8 +278,10 @@ def check_convergence(self):
274278

275279

276280
elif self.data_misfit < self.prev_data_misfit and self.data_misfit_std >= self.prev_data_misfit_std:
281+
277282
# accept iteration, but keep lam the same
278283
success = True
284+
self.log_update(success=success)
279285

280286
# Update state ensemble
281287
self.enX = cp.deepcopy(self.enX_temp)
@@ -286,10 +292,11 @@ def check_convergence(self):
286292
self.current_W = cp.deepcopy(self.W)
287293

288294
else: # Reject iteration, and increase lam
289-
# Increase damping parameter (divide calculations for ANALYSISDEBUG purpose)
295+
success = False
296+
self.log_update(success=success)
290297
self.lam = self.lam * self.gamma
298+
# Increase damping parameter (divide calculations for ANALYSISDEBUG purpose)
291299
self.logger(f'Data misfit increased! λ increased: {self.lam / self.gamma} ──> {self.lam}')
292-
success = False
293300

294301
if not success:
295302
# Reset the objective function after report
@@ -303,21 +310,18 @@ def log_update(self, success, prior_run=False):
303310
'''
304311
Log the update results in a formatted table.
305312
'''
306-
log_data = {
307-
"Iteration": f'{0 if prior_run else self.iteration}',
308-
"Status": "Success" if (prior_run or success) else "Failed",
309-
"Data Misfit": self.data_misfit,
310-
"λ": self.lam
313+
info = {
314+
"Iteration" : f'{0 if prior_run else self.iteration}',
315+
"Status" : "Success" if (prior_run or success) else "Failed",
316+
"Data Misfit" : self.data_misfit,
317+
"Change (%)" : '',
318+
"λ" : self.lam
311319
}
312320
if not prior_run:
313-
if success:
314-
log_data["Reduction (%)"] = 100 * (1 - self.data_misfit / self.prev_data_misfit)
315-
else:
316-
log_data["Increase (%)"] = 100 * (self.data_misfit / self.prev_data_misfit - 1)
317-
else:
318-
log_data["Reduction (%)"] = 'N/A'
321+
delta = 100*(self.data_misfit / self.prev_data_misfit - 1)
322+
info["Change (%)"] = delta
319323

320-
self.logger(**log_data)
324+
self.logger(**info)
321325

322326

323327

pipt/update_schemes/esmda.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pipt.loop.ensemble import Ensemble
1313
import pipt.misc_tools.analysis_tools as at
1414
import pipt.misc_tools.ensemble_tools as entools
15+
import pipt.misc_tools.data_tools as dtools
1516

1617
# import update schemes
1718
from pipt.update_schemes.update_methods_ns.approx_update import approx_update
@@ -158,12 +159,22 @@ def calc_analysis(self):
158159
if 'localanalysis' in self.keys_da:
159160
self.local_analysis_update()
160161
else:
162+
163+
# Check for adjoint
164+
if hasattr(self, 'adjoints'):
165+
enAdj = dtools.merge_dataframes(self.adjoints)
166+
enAdj = dtools.dataframe_to_matrix(enAdj) # Shape (nd, nx, ne)
167+
else:
168+
enAdj = None
169+
161170
# Perform the update
162171
self.update(
163172
enX = self.enX,
164173
enY = self.enPred,
165174
enE = self.enObs,
166-
prior = self.prior_enX
175+
# kwargs
176+
prior = self.prior_enX,
177+
enAdj = enAdj
167178
)
168179

169180
# Update the state ensemble and weights

0 commit comments

Comments
 (0)