88atol = _common .atol
99
1010
11+ def _transpose (array ):
12+ axes = list (range (0 , array .ndim ))
13+ axes [- 2 ], axes [- 1 ] = axes [- 1 ], axes [- 2 ]
14+ return _np .transpose (array , axes = axes )
15+
16+
1117def _is_symmetric (x , tol = atol ):
12- new_x = _to_ndarray (x , to_ndim = 3 )
13- return (_np .abs (new_x - _np .transpose (new_x , axes = (0 , 2 , 1 ))) < tol ).all ()
18+ return (_np .abs (x - _transpose (x )) < tol ).all ()
1419
1520
1621def _is_hermitian (x , tol = atol ):
17- new_x = _to_ndarray (x , to_ndim = 3 )
18- return (_np .abs (new_x - _np .conj (_np .transpose (new_x , axes = (0 , 2 , 1 )))) < tol ).all ()
22+ return (_np .abs (x - _np .conj (_transpose (x ))) < tol ).all ()
1923
2024
2125_diag_vec = _np .vectorize (_np .diag , signature = "(n)->(n,n)" )
@@ -26,38 +30,30 @@ def _is_hermitian(x, tol=atol):
2630
2731
2832def logm (x ):
29- ndim = x .ndim
30- new_x = _to_ndarray (x , to_ndim = 3 )
31-
32- if _is_symmetric (new_x ) and new_x .dtype not in [_np .complex64 , _np .complex128 ]:
33- eigvals , eigvecs = _np .linalg .eigh (new_x )
33+ if _is_symmetric (x ) and x .dtype not in [_np .complex64 , _np .complex128 ]:
34+ eigvals , eigvecs = _np .linalg .eigh (x )
3435 if (eigvals > 0 ).all ():
3536 eigvals = _np .log (eigvals )
3637 eigvals = _diag_vec (eigvals )
37- transp_eigvecs = _np . transpose (eigvecs , axes = ( 0 , 2 , 1 ) )
38+ transp_eigvecs = _transpose (eigvecs )
3839 result = _np .matmul (eigvecs , eigvals )
3940 result = _np .matmul (result , transp_eigvecs )
4041 else :
41- result = _logm_vec (new_x )
42+ result = _logm_vec (x )
4243 else :
43- result = _logm_vec (new_x )
44+ result = _logm_vec (x )
4445
45- if ndim == 2 :
46- return result [0 ]
4746 return result
4847
4948
5049def solve_sylvester (a , b , q , tol = atol ):
5150 if a .shape == b .shape :
52- axes = (0 , 2 , 1 ) if a .ndim == 3 else (1 , 0 )
53- if _np .all (_np .isclose (a , b )) and _np .all (
54- _np .abs (a - _np .transpose (a , axes )) < tol
55- ):
51+ if _np .all (_np .isclose (a , b )) and _np .all (_np .abs (a - _transpose (a )) < tol ):
5652 eigvals , eigvecs = _np .linalg .eigh (a )
5753 if _np .all (eigvals >= tol ):
58- tilde_q = _np . transpose (eigvecs , axes ) @ q @ eigvecs
54+ tilde_q = _transpose (eigvecs ) @ q @ eigvecs
5955 tilde_x = tilde_q / (eigvals [..., :, None ] + eigvals [..., None , :])
60- return eigvecs @ tilde_x @ _np . transpose (eigvecs , axes )
56+ return eigvecs @ tilde_x @ _transpose (eigvecs )
6157
6258 return _np .vectorize (
6359 _scipy .linalg .solve_sylvester , signature = "(m,m),(n,n),(m,n)->(m,n)"
@@ -102,4 +98,41 @@ def fractional_matrix_power(A, t):
10298 if A .ndim == 2 :
10399 return _scipy .linalg .fractional_matrix_power (A , t )
104100
105- return _np .stack ([_scipy .linalg .fractional_matrix_power (A_ , t ) for A_ in A ])
101+ return _np .stack ([_scipy .linalg .fractional_matrix_power (A_ , t ) for A_ in A ])
102+
103+
104+ def polar (* args , ** kwargs ):
105+ """Polar decomposition of a matrix."""
106+ return _np .vectorize (
107+ _scipy .linalg .polar , signature = "(n,n)->(n,n),(n,n)" , excluded = ["side" ]
108+ )(* args , ** kwargs )
109+
110+
111+ def solve (a , b ):
112+ """
113+ Solve a linear matrix equation, or system of linear scalar equations.
114+
115+ Computes the "exact" solution, `x`, of the well-determined, i.e., full
116+ rank, linear matrix equation `ax = b`.
117+
118+ Parameters
119+ ----------
120+ a : array-like, shape=[..., M, M]
121+ Coefficient matrix.
122+ b : array-like, shape=[..., M]
123+ Ordinate or "dependent variable" values".
124+
125+ Returns
126+ -------
127+ x : array-like, shape=[..., M]
128+ Solution to the system a x = b.
129+ """
130+ batch_shape = a .shape [:- 2 ]
131+ if batch_shape :
132+ b = _np .expand_dims (b , axis = - 1 )
133+
134+ res = _np .linalg .solve (a , b )
135+ if batch_shape :
136+ return res [..., 0 ]
137+
138+ return res
0 commit comments