1010import numpy as np
1111from scipy import linalg
1212
13- from . import LSSM
1413from . import PrepModel
15- from .PSID import blkhankskip , projOrth , getHSize
16-
17-
18- def transposeIf (Y ):
19- """Transposes Y itself if Y is an array or each element of Y if it is a list/tuple of arrays.
20-
21- Args:
22- Y (np.array or list or tuple): input data or list of input data arrays.
23-
24- Returns:
25- np.array or list or tuple: transposed Y or list of transposed arrays.
26- """
27- if Y is None :
28- return None
29- elif isinstance (Y , (list , tuple )):
30- return [transposeIf (YThis ) for YThis in Y ]
31- else :
32- return Y .T
33-
34-
35- def catIf (Y , axis = None ):
36- """If Y is a list of arrays, will concatenate them otherwise returns Y
37-
38- Args:
39- Y (np.array or list or tuple): input data or list of input data arrays.
40-
41- Returns:
42- np.array or list or tuple: transposed Y or list of transposed arrays.
43- """
44- if Y is None :
45- return None
46- elif isinstance (Y , (list , tuple )):
47- return np .concatenate (Y , axis = axis )
48- else :
49- return Y
14+ from .LSSM import LSSM
15+ from .PSID import blkhankskip , getHSize , projOrth
16+ from .tools import transposeIf , catIf
5017
5118
5219def removeProjOrth (A , B ):
@@ -101,8 +68,12 @@ def recomputeObsAndStates(A, C, i, YHat, YHatMinus):
10168 2) Xk_Plus1: recomputed states at next time step
10269 """
10370 Oy , Oy_Minus = computeObsFromAC (A , C , i )
104- Xk = np .linalg .pinv (Oy ) @ YHat
105- Xk_Plus1 = np .linalg .pinv (Oy_Minus ) @ YHatMinus
71+ Xk = (
72+ np .linalg .pinv (Oy ) @ YHat
73+ ) # VODM Book p131, Step 5 => NOTE that when there is input, this is XUk (or Zk per VODM) rather than Xk
74+ Xk_Plus1 = (
75+ np .linalg .pinv (Oy_Minus ) @ YHatMinus
76+ ) # VODM Book p131, Step 5 => NOTE that when there is input, this is UXk_Plus1 rather than Xk_Plus1
10677 return Xk , Xk_Plus1
10778
10879
@@ -117,7 +88,7 @@ def computeBD(A, C, Yii, Xk_Plus1, Xk, i, nu, Uf):
11788 # Find B and D
11889 Oy , Oy_Minus = computeObsFromAC (A , C , i )
11990
120- # See ref. 40 pages 125-127
91+ # See VODM book pages 125-127
12192 PP = np .concatenate ((Xk_Plus1 - A @ Xk , Yii - C @ Xk ))
12293
12394 L1 = A @ np .linalg .pinv (Oy )
@@ -128,11 +99,11 @@ def computeBD(A, C, Yii, Xk_Plus1, Xk, i, nu, Uf):
12899
129100 ZM = np .concatenate ((np .zeros ((nx , ny )), np .linalg .pinv (Oy_Minus )), axis = 1 )
130101
131- # LHS * DB = PP
102+ # LHS * DB = PP # VODM (4.61)
132103 LHS = np .zeros ((PP .size , (nx + ny ) * nu ))
133104 RMul = linalg .block_diag (np .eye (ny ), Oy_Minus )
134105
135- NNAll = [] # ref. 40 (4.54), (4.57),.., (4.59)
106+ NNAll = [] # VODM (4.54) && VODM (4.57) .. (4.59)
136107 # Plug in the terms into NN
137108 for ii in range (i ):
138109 NN = np .zeros (((nx + ny ), i * ny ))
@@ -146,7 +117,9 @@ def computeBD(A, C, Yii, Xk_Plus1, Xk, i, nu, Uf):
146117 LHS = LHS + np .kron (Uf [(ii * nu ) : (ii * nu + nu ), :].T , NN @ RMul )
147118 NNAll .append (NN )
148119
149- DBVec = np .linalg .lstsq (LHS , PP .flatten (order = "F" ), rcond = None )[0 ]
120+ DBVec = np .linalg .lstsq (LHS , PP .flatten (order = "F" ), rcond = None )[
121+ 0
122+ ] # In MATLAB: LHS \ PP(:)
150123 DB = np .reshape (DBVec , [nx + ny , nu ], order = "F" )
151124 D = DB [:ny , :]
152125 B = DB [ny : (ny + nx ), :]
@@ -265,7 +238,7 @@ def combineIdSysWithEps(s, s3, missing_marker):
265238 "Syz" : s .Syz ,
266239 "Rz" : s .Rz ,
267240 }
268- newSys = LSSM . LSSM (params = new_params )
241+ newSys = LSSM (params = new_params )
269242
270243 return newSys
271244
@@ -291,6 +264,9 @@ def IPSID(
291264 remove_nonYrelated_fromX1 = False ,
292265 n_pre = np .inf ,
293266 n3 = 0 ,
267+ force_stable_if_not = True ,
268+ force_stable_stage1 = False ,
269+ force_stable_stage2 = False ,
294270) -> LSSM :
295271 """
296272 IPSID: Input Preferential Subspace Identification Algorithm
@@ -403,7 +379,13 @@ def IPSID(
403379 If n_pre=0, Additional steps 1 and 2 won't happen and x3 won't be learned
404380 (remove_nonYrelated_fromX1 will be set to False, n3 will be 0).
405381 - (20) n3: number of latent states x3(k) in the optional additional step 2.
406-
382+ - (21) force_stable_if_not (default: True): If True, will run a second
383+ pass to learn a model with stable A-KC if the original learned model
384+ has unstable A-KC.
385+ - (22/23) force_stable_stage1/force_stable_stage2 (default: False): These are
386+ internal and should never be set to True by the user. It may be
387+ automatically used in a second pass to enforce stability. If used when
388+ unnecessary will cause unnecessary error in the model.
407389 Outputs:
408390 - (1) idSys: an LSSM object with the system parameters for
409391 the identified system. Will have the following
@@ -429,6 +411,7 @@ def IPSID(
429411 a special case of IPSID. To do so, simply set Z=None and n1=0.
430412 (6) NDM (or SID, i.e., Standard Subspace Identification without input U, unsupervised by Z) can be performed as
431413 a special case of IPSID. To do so, simply set Z=None, U=None and n1=0.
414+ (7) VODM: The Van Overschee and De Moor subspace identification book, which is ref. 40 in (Vahidi, Sani, et al).
432415
433416 Usage example:
434417 idSys = IPSID(Y, Z, U, nx=nx, n1=n1, i=i); # With external input
@@ -466,11 +449,21 @@ def IPSID(
466449 U = UPrepModel .apply (U , time_first = time_first )
467450
468451 ny , ySamples , N , y1 , NTot = getHSize (Y , iMax , time_first = time_first )
469- if Z is not None :
452+ if Z is not None and (
453+ not isinstance (Z , (list , tuple ))
454+ and Z .size > 0
455+ or isinstance (Z , (list , tuple ))
456+ and len (Z ) > 0
457+ ):
470458 nz , zSamples , _ , z1 , NTot = getHSize (Z , iMax , time_first = time_first )
471459 else :
472460 nz , zSamples = 0 , 0
473- if U is not None :
461+ if U is not None and (
462+ not isinstance (U , (list , tuple ))
463+ and U .size > 0
464+ or isinstance (U , (list , tuple ))
465+ and len (U ) > 0
466+ ):
474467 nu , uSamples , _ , u1 , NTot = getHSize (U , iMax , time_first = time_first )
475468 else :
476469 nu = 0
@@ -547,6 +540,7 @@ def IPSID(
547540 ): # Due to provided settings, preprocessing step is disabled and X3 won't be learned.
548541 remove_nonYrelated_fromX1 , n_pre , n3 = False , 0 , 0
549542
543+ # Stage 1
550544 if n1 > 0 and nz > 0 :
551545 if n1 > iZ * nz :
552546 raise (
@@ -623,6 +617,16 @@ def IPSID(
623617 Oz = Uz @ Sz ** (1 / 2 )
624618 Oz_Minus = Oz [:- nz , :]
625619
620+ if force_stable_stage1 :
621+ # Modified analogously to the heuristic in VODM book page 129
622+ # Add zeros to Oz_Minus to get Oz_0
623+ Oz_0 = np .concatenate ((Oz_Minus , np .zeros ((nz , n1 ))), axis = 0 )
624+ Cz_temp = Oz [:nz , :]
625+ A1_temp = np .linalg .pinv (Oz ) @ Oz_0
626+ # # Recompute Oz and Oz_Minus with the new A1_temp, Cz_temp
627+ # Oz = np.concatenate([Cz_temp@A1_temp**p for p in range(iZ)], axis=0)
628+ # Oz_Minus = Oz[:-nz, :]
629+
626630 Xk = np .linalg .pinv (Oz ) @ WS ["ZHat" ]
627631 # Eq. (24)
628632 Xk_Plus1 = np .linalg .pinv (Oz_Minus ) @ WS ["ZHatMinus" ]
@@ -703,6 +707,16 @@ def IPSID(
703707 Oy = U2 @ S2 ** (1 / 2 )
704708 Oy_Minus = Oy [:- ny , :]
705709
710+ if force_stable_stage2 :
711+ # Modified analogously to the heuristic in VODM book page 129
712+ # Add zeros to Oy_Minus to get Oy_0
713+ Oy_0 = np .concatenate ((Oy_Minus , np .zeros ((ny , n2 ))), axis = 0 )
714+ Cy_temp = Oy [:ny , :]
715+ A2_temp = np .linalg .pinv (Oy ) @ Oy_0
716+ # # Recompute Oy and Oy_Minus with the new A2_temp, Cy_temp
717+ # Oy = np.concatenate([Cy_temp@A2_temp**p for p in range(iY)], axis=0)
718+ # Oy_Minus = Oy[:-ny, :]
719+
706720 Xk2 = np .linalg .pinv (Oy ) @ WS ["YHat" ] # Eq.(28)
707721 Xk2_Plus1 = np .linalg .pinv (Oy_Minus ) @ WS ["YHatMinus" ]
708722
@@ -717,6 +731,8 @@ def IPSID(
717731 Xk_Plus1 [:n1 , :], np .concatenate ((Xk [:n1 , :], WS ["Uf" ]))
718732 ) # Eq.(29)
719733 A = A1Tmp [:n1 , :n1 ]
734+ if force_stable_stage1 :
735+ A = A1_temp
720736 w = Xk_Plus1 [:n1 , :] - XkP1Hat [:n1 , :] # Eq.(33)
721737 else :
722738 A = np .empty ([0 , 0 ])
@@ -728,6 +744,8 @@ def IPSID(
728744 Xk_Plus1 [n1 :, :], np .concatenate ((Xk , WS ["Uf" ]))
729745 ) # Eq.(30)
730746 A23 = A23Tmp [:, :nx ]
747+ if force_stable_stage2 :
748+ A23 = np .concatenate ((A23Tmp [:, :n1 ], A2_temp ), axis = 1 )
731749 if n1 > 0 :
732750 A10 = np .concatenate ((A , np .zeros ([n1 , n2 ])), axis = 1 )
733751 A = np .concatenate ((A10 , A23 ))
@@ -738,12 +756,16 @@ def IPSID(
738756 if nz > 0 :
739757 ZiiHat , CzTmp = projOrth (WS ["Zii" ], np .concatenate ((Xk , WS ["Uf" ]))) # Eq.(32)
740758 Cz = CzTmp [:, :nx ]
759+ if force_stable_stage1 and n1 > 0 :
760+ Cz = Cz_temp
741761 e = WS ["Zii" ] - ZiiHat
742762 else :
743763 Cz = np .empty ([0 , nx ])
744764
745765 YiiHat , CyTmp = projOrth (WS ["Yii" ], np .concatenate ((Xk , WS ["Uf" ]))) # Eq.(31)
746766 Cy = CyTmp [:, :nx ]
767+ if force_stable_stage2 and n2 > 0 :
768+ Cy = Cy_temp
747769 v = WS ["Yii" ] - YiiHat # Eq.(35)
748770
749771 # Compute noise covariances
@@ -762,15 +784,85 @@ def IPSID(
762784 Rz = (e @ e .T ) / NA
763785 params ["Rz" ] = (Rz + Rz .T ) / 2 # Make precisely symmetric
764786
765- s = LSSM . LSSM (params = params )
766- if np .any (np .isnan (s .Pp )): # Riccati did not have a solution.
787+ s = LSSM (params = params , missing_marker = missing_marker )
788+ if np .any (np .isnan (s .Pp )): # Riccati did not have a solution. Enforce stability.
767789 warnings .warn (
768790 "The learned model did not have a solution for the Riccati equation."
769791 )
770-
792+ if force_stable_if_not and not (force_stable_stage1 or force_stable_stage2 ):
793+ eigVals1 , eigVecs1 = np .linalg .eig (s .A [:n1 , :n1 ])
794+ stage1_is_unstable = np .any (np .abs (eigVals1 )) > 1
795+ eigVals2 , eigVecs2 = np .linalg .eig (s .A [(n1 + 1 ) :, (n1 + 1 ) :])
796+ stage2_is_unstable = np .any (np .abs (eigVals2 )) > 1
797+ isOk = False
798+ cnt = 0
799+ newNx = nx
800+ newN1 = n1
801+ while not isOk and newNx >= 2 :
802+ cnt += 1
803+ newNx = int (newNx / 2 )
804+ print (
805+ "Attempt #{} to refit the model (with nx={}, n1={}) while enforcing stability." .format (
806+ cnt , newNx , newN1
807+ )
808+ )
809+ s_tmp = PSID (
810+ Y ,
811+ Z = Z ,
812+ U = U ,
813+ nx = newNx ,
814+ n1 = newN1 ,
815+ i = i ,
816+ fit_Cz_via_KF = False ,
817+ time_first = time_first , # Prep is already done
818+ remove_mean_Y = False ,
819+ remove_mean_Z = False ,
820+ remove_mean_U = False ,
821+ zscore_Y = False ,
822+ zscore_Z = False ,
823+ zscore_U = False ,
824+ force_stable_stage1 = stage1_is_unstable ,
825+ force_stable_stage2 = stage2_is_unstable ,
826+ missing_marker = missing_marker ,
827+ )
828+ isOk = np .all (~ np .isnan (s_tmp .Pp ))
829+ if isOk :
830+ # Attach additional zero states to the temp model
831+ s_tmp .changeParams (
832+ {
833+ "A" : np .block (
834+ [
835+ [s_tmp .A , np .zeros ((newNx , nx - newNx ))],
836+ [
837+ np .zeros ((newNx , nx - newNx )),
838+ np .zeros ((newNx , newNx )),
839+ ],
840+ ]
841+ ),
842+ "C" : np .concatenate (
843+ (s_tmp .C , np .zeros ((ny , nx - newNx ))), axis = 1
844+ ),
845+ "Q" : np .block (
846+ [
847+ [s_tmp .Q , np .zeros ((newNx , nx - newNx ))],
848+ [
849+ np .zeros ((newNx , nx - newNx )),
850+ np .zeros ((newNx , newNx )),
851+ ],
852+ ]
853+ ),
854+ "S" : np .concatenate (
855+ (s_tmp .S , np .zeros ((nx - newNx , ny ))), axis = 0
856+ ),
857+ "Sxz" : np .concatenate (
858+ (s_tmp .Sxz , np .zeros ((nx - newNx , nz ))), axis = 0
859+ ),
860+ }
861+ )
862+ s = s_tmp
771863 if (
772864 nu > 0
773- ): # Following a procedure similar to ref. 40 in (Vahidi, Sani, et al) , pages 125-127 to find the least squares solution for the model parameters B and Dy
865+ ): # Following a procedure similar to VODM , pages 125-127 to find the least squares solution for the model parameters B and Dy
774866 RR = np .triu (
775867 np .linalg .qr (
776868 np .concatenate ((WS ["Up" ], WS ["Uf" ], WS ["Yp" ], WS ["Yf" ])).T / np .sqrt (NA )
@@ -854,7 +946,7 @@ def IPSID(
854946 "Rz" : s3 .R ,
855947 }
856948
857- s3 = LSSM . LSSM (params = params3 )
949+ s3 = LSSM (params = params3 )
858950 s = combineIdSysWithEps (
859951 s , s3 , missing_marker
860952 ) # Combining model parametrs learned for [X1,X2] and [X3] in a single model
0 commit comments