Skip to content

Commit 6f313d7

Browse files
authored
new version of subspace methods (#154)
* new version of subspace methods * removed sklearn dependence
1 parent 409c9a6 commit 6f313d7

5 files changed

Lines changed: 228 additions & 55 deletions

File tree

src/pipt/update_schemes/enrml.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
import pipt.misc_tools.ensemble_tools as entools
88
import pipt.misc_tools.data_tools as dtools
99

10+
1011
from geostat.decomp import Cholesky
1112
from pipt.loop.ensemble import Ensemble
1213
from pipt.update_schemes.update_methods_ns.subspace_update import subspace_update
14+
from pipt.update_schemes.update_methods_ns.subspace2_update import subspace2_update
1315
from pipt.update_schemes.update_methods_ns.full_update import full_update
1416
from pipt.update_schemes.update_methods_ns.approx_update import approx_update
17+
from pipt.update_schemes.update_methods_ns.margIS_update import margIS_update
18+
1519
import sys
1620
import pkgutil
1721
import inspect
@@ -35,11 +39,11 @@
3539
# import standard libraries
3640

3741
# Check and import (if present) from other namespace packages
38-
if 'margIS_update' in [el[0] for el in tot_ns_pkg]: # only compare package name
39-
from pipt.update_schemes.update_methods_ns.margIS_update import margIS_update
40-
else:
41-
class margIS_update:
42-
pass
42+
#if 'margIS_update' in [el[0] for el in tot_ns_pkg]: # only compare package name
43+
# from pipt.update_schemes.update_methods_ns.margIS_update import margIS_update
44+
#else:
45+
# class margIS_update:
46+
# pass
4347

4448
# Internal imports
4549
from pipt.misc_tools.analysis_tools import aug_state
@@ -168,9 +172,26 @@ def calc_analysis(self):
168172
# Update the state ensemble and weights
169173
if hasattr(self, 'step'):
170174
self.enX_temp = self.enX + self.step
175+
# This is the vector update following e.g. Evensen et al 2019 update for subspace
171176
if hasattr(self, 'w_step'):
172177
self.W = self.current_W + self.w_step
173-
self.enX_temp = np.dot(self.prior_enX, (np.eye(self.ne) + self.W/np.sqrt(self.ne - 1)))
178+
self.enX_temp = np.dot(self.prior_enX, (np.eye(self.ne) + self.W / np.sqrt(self.ne - 1)))
179+
#This is the matrix update following e.g. Raanes et al 2019 update for subspace
180+
if hasattr(self, 'W_step'):
181+
self.W = self.current_W + self.W_step
182+
X_p = self.prior_enX @ self.proj * np.sqrt(self.ne - 1)
183+
self.enX_temp = np.mean(self.prior_enX, axis=1, keepdims=True) + np.dot(X_p, self.W)
184+
185+
if hasattr(self, 'sqrt_w_step'):
186+
self.w = self.current_w + self.sqrt_w_step
187+
Us, Ss, VsT = np.linalg.svd(self.S, full_matrices=False)
188+
eps = 1e-8 * Ss[0] # e.g., 1e-8 * largest
189+
s_inv = 1.0 / np.sqrt(np.maximum(Ss, eps))
190+
S_inv = np.diag(s_inv)
191+
self.W = Us @ S_inv @ Us.T
192+
X_p = self.prior_enX @ self.proj * np.sqrt(self.ne - 1)
193+
x = np.mean(self.prior_enX, axis=1) + X_p @ self.w
194+
self.enX_temp = np.repeat(x[:, None], self.ne, axis=1) + np.dot(X_p, self.W)
174195

175196

176197
# Ensure limits are respected
@@ -275,6 +296,8 @@ def check_convergence(self):
275296
# Update ensemble weights
276297
if hasattr(self, 'W'):
277298
self.current_W = cp.deepcopy(self.W)
299+
if hasattr(self, 'w'):
300+
self.current_w = cp.deepcopy(self.w)
278301

279302

280303
elif self.data_misfit < self.prev_data_misfit and self.data_misfit_std >= self.prev_data_misfit_std:
@@ -290,6 +313,8 @@ def check_convergence(self):
290313
# Update ensemble weights
291314
if hasattr(self, 'W'):
292315
self.current_W = cp.deepcopy(self.W)
316+
if hasattr(self, 'w'):
317+
self.current_w = cp.deepcopy(self.w)
293318

294319
else: # Reject iteration, and increase lam
295320
success = False
@@ -337,6 +362,12 @@ class lmenrml_full(lmenrmlMixIn, full_update):
337362
class lmenrml_subspace(lmenrmlMixIn, subspace_update):
338363
pass
339364

365+
class lmenrml_subspace2(lmenrmlMixIn, subspace2_update):
366+
pass
367+
368+
class lmenrml_margIS(lmenrmlMixIn, margIS_update):
369+
pass
370+
340371

341372
class gnenrmlMixIn(Ensemble):
342373
"""
@@ -662,6 +693,9 @@ class gnenrml_full(gnenrmlMixIn, full_update):
662693
class gnenrml_subspace(gnenrmlMixIn, subspace_update):
663694
pass
664695

696+
class gnenrml_subspace2(gnenrmlMixIn, subspace2_update):
697+
pass
698+
665699

666700
class gnenrml_margis(gnenrmlMixIn, margIS_update):
667701
'''

src/pipt/update_schemes/esmda.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pipt.update_schemes.update_methods_ns.approx_update import approx_update
1919
from pipt.update_schemes.update_methods_ns.full_update import full_update
2020
from pipt.update_schemes.update_methods_ns.subspace_update import subspace_update
21-
21+
from pipt.update_schemes.update_methods_ns.subspace2_update import subspace2_update
2222

2323
class esmdaMixIn(Ensemble):
2424
"""
@@ -180,9 +180,26 @@ def calc_analysis(self):
180180
# Update the state ensemble and weights
181181
if hasattr(self, 'step'):
182182
self.enX_temp = self.enX + self.step
183+
# This is the vector update following e.g. Evensen et al 2019 update for subspace
183184
if hasattr(self, 'w_step'):
184185
self.W = self.current_W + self.w_step
185-
self.enX_temp = np.dot(self.prior_enX, (np.eye(self.ne) + self.W/np.sqrt(self.ne - 1)))
186+
self.enX_temp = np.dot(self.prior_enX, (np.eye(self.ne) + self.W / np.sqrt(self.ne - 1)))
187+
# This is the matrix update following e.g. Raanes et al 2019 update for subspace
188+
if hasattr(self, 'W_step'):
189+
self.W = self.current_W + self.W_step
190+
X_p = self.prior_enX @ self.proj * np.sqrt(self.ne - 1)
191+
self.enX_temp = np.mean(self.prior_enX, axis=1, keepdims=True) + np.dot(X_p, self.W)
192+
193+
if hasattr(self, 'sqrt_w_step'):
194+
self.w = self.current_w + self.sqrt_w_step
195+
Us, Ss, VsT = np.linalg.svd(self.S, full_matrices=False)
196+
eps = 1e-8 * Ss[0] # e.g., 1e-8 * largest
197+
s_inv = 1.0 / np.sqrt(np.maximum(Ss, eps))
198+
S_inv = np.diag(s_inv)
199+
self.W = Us @ S_inv @ Us.T
200+
X_p = self.prior_enX @ self.proj * np.sqrt(self.ne - 1)
201+
x = np.mean(self.prior_enX, axis=1) + X_p @ self.w
202+
self.enX_temp = np.repeat(x[:, None], self.ne, axis=1) + np.dot(X_p, self.W)
186203

187204

188205
# Ensure limits are respected
@@ -367,6 +384,9 @@ class esmda_full(esmdaMixIn, full_update):
367384
class esmda_subspace(esmdaMixIn, subspace_update):
368385
pass
369386

387+
class esmda_subspace2(esmdaMixIn, subspace2_update):
388+
pass
389+
370390

371391
class esmda_geo(esmda_approx):
372392
"""

src/pipt/update_schemes/update_methods_ns/margIS_update.py

Lines changed: 92 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,101 @@
1+
"""Stochastic iterative ensemble smoother (IES, i.e. EnRML) with *subspace* implementation."""
2+
13
import numpy as np
2-
from scipy.linalg import solve
3-
import copy as cp
4-
from pipt.misc_tools import analysis_tools as at
4+
from scipy.linalg import solve, lu_solve, lu_factor, cho_solve
5+
6+
import pipt.misc_tools.analysis_tools as at
57

6-
class margIS_update():
78

9+
class margIS_update():
810
"""
9-
Placeholder for private margIS method
11+
MargIES update from Stordal et al.
12+
This is now implemented with perturbed observations, which means that we place a prior belief on the data uncertainty.
13+
Thus, the prior is an inverse chi-squared distribution, and after scaling the mean variance is 1.
1014
"""
11-
def update(self):
12-
if self.iteration == 1: # method requires some initialization
13-
self.aug_prior = cp.deepcopy(at.aug_state(self.prior_state, self.list_states))
14-
self.mean_prior = self.aug_prior.mean(axis=1)
15-
self.X = (self.aug_prior - np.dot(np.resize(self.mean_prior, (len(self.mean_prior), 1)),
16-
np.ones((1, self.ne))))
17-
self.W = np.eye(self.ne)
18-
self.current_w = np.zeros((self.ne,))
19-
self.E = np.dot(self.real_obs_data, self.proj)
20-
21-
M = len(self.real_obs_data)
22-
Ytmp = solve(self.W, self.proj)
23-
if len(self.scale_data.shape) == 1:
24-
Y = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1), np.ones((1, self.ne))) * \
25-
np.dot(self.aug_pred_data, Ytmp)
26-
else:
27-
Y = solve(self.scale_data, np.dot(self.aug_pred_data, Ytmp))
2815

29-
pred_data_mean = np.mean(self.aug_pred_data, 1)
30-
delta_d = (self.obs_data_vector - pred_data_mean)
16+
def update(self, enX, enY, enE, **kwargs):
17+
18+
if self.iteration == 1: # method requires some initialization
19+
self.current_W = np.eye(self.ne)
20+
self.current_w = np.zeros(self.ne)
21+
self.D = self.scale(enE, self.scale_data)
22+
# Scale everything so that data uncertainty is I
23+
24+
sY = self.scale(enY, self.scale_data) #Scaling is the same as with 'known' uncertainty, hence it makes sense to set s = 1
25+
self.S = 0
26+
27+
deltaD = 0
28+
deltaD_sqrt = 0
29+
30+
Y = np.linalg.solve(self.current_W.T, sY.T).T
31+
Y = Y @ self.proj * np.sqrt(self.ne - 1)
32+
index = np.arange(0, 70, 70) # Has to be specified via data types (or select each data)...
33+
M = 1 #Number of data points per type; computed from index
34+
s = 1 #should be default option with possibility to change in setup
35+
nu = self.ne-1 #should be default option with possibility to change in setup
36+
for j in range(70):
37+
38+
delta = self.D[index,:]-sY[index,:]
39+
Chi = np.sum(delta * delta, axis = 0)
40+
Chi = np.mean(Chi)
41+
Ratio = (M + nu) / (Chi + nu*s*s)
42+
#Ratio = 1
43+
#Gradient
44+
deltaD = deltaD + (Y[index,:] * Ratio).T @ delta
45+
deltaD_sqrt = deltaD_sqrt + np.mean((Y[index, :] * Ratio).T @ delta ,axis=1)
46+
# Hessian
47+
self.S = self.S + (Y[index,:] * Ratio).T @ Y[index,:]
48+
index += 1
49+
50+
deltaM = (self.ne-1)*(np.eye(self.ne)-self.current_W)
51+
deltaM_sqrt = (self.ne-1)*self.current_w
52+
self.S = self.S + np.eye(self.ne) * (self.ne - 1)
53+
Delta = deltaM + deltaD
54+
Delta_sqrt = deltaM_sqrt + deltaD_sqrt
3155

32-
if len(self.cov_data.shape) == 1:
33-
S = np.dot(delta_d, (self.cov_data**(-1)) * delta_d)
34-
Ratio = M / S
35-
grad_lklhd = np.dot(Y.T * Ratio, (self.cov_data**(-1)) * delta_d)
36-
grad_prior = (self.ne - 1) * self.current_w
37-
self.C_w = (np.dot(Ratio * Y.T, np.dot(np.diag(self.cov_data ** (-1)), Y)) + (self.ne - 1) * np.eye(self.ne))
56+
57+
self.W_step = np.linalg.solve(self.S, Delta) / (1 + self.lam)
58+
# self.sqrt_w_step = np.linalg.solve(self.S, Delta_sqrt) / (1 + self.lam)
59+
60+
def scale(self, data, scaling):
61+
"""
62+
Scale the data perturbations by the data error standard deviation.
63+
64+
Args:
65+
data (np.ndarray): data perturbations
66+
scaling (np.ndarray): data error standard deviation
67+
68+
Returns:
69+
np.ndarray: scaled data perturbations
70+
"""
71+
72+
if len(scaling.shape) == 1:
73+
return (scaling ** (-1))[:, None] * data
3874
else:
39-
S = np.dot(delta_d, solve(self.cov_data, delta_d))
40-
Ratio = M / S
41-
grad_lklhd = np.dot(Y.T * Ratio, solve(self.cov_data, delta_d))
42-
grad_prior = (self.ne - 1) * self.current_w
43-
self.C_w = (np.dot(Ratio * Y.T, solve(self.cov_data, Y)) + (self.ne - 1) * np.eye(self.ne))
75+
return solve(scaling, data)
76+
77+
78+
79+
80+
81+
82+
83+
84+
85+
86+
87+
88+
89+
90+
91+
92+
93+
94+
95+
96+
97+
98+
99+
100+
44101

45-
self.sqrt_w_step = solve(self.C_w, grad_prior + grad_lklhd)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Stochastic iterative ensemble smoother (IES, i.e. EnRML) with *subspace* implementation."""
2+
3+
import numpy as np
4+
from scipy.linalg import solve, lu_solve, lu_factor
5+
import pipt.misc_tools.analysis_tools as at
6+
7+
8+
9+
10+
11+
class subspace2_update():
12+
"""
13+
Ensemble subspace update, as described in Raanes, P. N., Stordal, A. S., &
14+
Evensen, G. (2019). Revising the stochastic iterative ensemble smoother.
15+
Nonlinear Processes in Geophysics, 26(3), 325–338. https://doi.org/10.5194/npg-26-325-2019
16+
17+
"""
18+
19+
def update(self, enX, enY, enE, **kwargs):
20+
21+
if self.iteration == 1: # method requires some initialization
22+
self.current_W = np.eye(self.ne)
23+
self.D = self.scale(enE, self.scale_data)
24+
# Scale everything so that data uncertainty is I
25+
sY = self.scale(enY, self.scale_data)
26+
Y = np.linalg.solve(self.current_W.T,sY.T).T #Raanes
27+
Y = np.dot(Y, self.proj) * np.sqrt(self.ne - 1) #Raanes
28+
29+
30+
#Gradients
31+
32+
deltaD = Y.T @ (self.D - sY)
33+
deltaM = (self.ne-1)*(np.eye(self.ne)-self.current_W)
34+
35+
#Hessian
36+
S = Y.T @ Y + np.eye(self.ne)*(self.ne-1)
37+
38+
self.W_step = np.linalg.solve(S , (deltaM + deltaD))/(1 + self.lam)
39+
40+
41+
def scale(self, data, scaling):
42+
"""
43+
Scale the data perturbations by the data error standard deviation.
44+
45+
Args:
46+
data (np.ndarray): data perturbations
47+
scaling (np.ndarray): data error standard deviation
48+
49+
Returns:
50+
np.ndarray: scaled data perturbations
51+
"""
52+
53+
if len(scaling.shape) == 1:
54+
return (scaling ** (-1))[:, None] * data
55+
else:
56+
return solve(scaling, data)

0 commit comments

Comments
 (0)