Skip to content

Commit bd49915

Browse files
committed
Replace Brandts sorting with LAPACK DTRSEN, stabilize Gram-Schmidt
- Replace 624-line Python Brandts sorting (_sort_real_schur.py) with a single LAPACK DTRSEN call in sorted_brandts_schur(). Same Bai-Demmel algorithm, compiled Fortran instead of Python loops. Fixes subspace angle failures for n >= 100.
- Replace single-pass Modified Gram-Schmidt in _gram_schmidt_mod() with Householder QR (np.linalg.qr, LAPACK DGEQRF). Single-pass MGS lost orthogonality for ill-conditioned inputs (kappa up to 1e16). QR is backward stable regardless of conditioning. Same O(nm^2) complexity, faster in practice (compiled LAPACK vs Python loops).
- Replace fragile constant-vector detection (element-wise constancy check with tight tolerance) with cosine similarity against sqrt(eta). The old check failed when Schur vectors were only approximately constant due to numerical noise from different eigenvalue orderings.
- Recompute R in _do_schur() after re-orthogonalization to maintain the Schur relation in the new basis.
- Update tests accordingly.
1 parent 773c463 commit bd49915

4 files changed

Lines changed: 104 additions & 717 deletions

File tree

pygpcca/_gpcca.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def _gram_schmidt_mod(X: ArrayLike, eta: ArrayLike) -> ArrayLike:
9292
r"""
9393
:math:`\eta`-orthonormalize Schur vectors.
9494
95-
This uses a modified, numerically stable version of Gram-Schmidt
96-
Orthonormalization.
95+
This uses Householder QR decomposition (LAPACK DGEQRF via
96+
:func:`numpy.linalg.qr`) for backward-stable orthonormalization.
9797
9898
Parameters
9999
----------
@@ -111,18 +111,18 @@ def _gram_schmidt_mod(X: ArrayLike, eta: ArrayLike) -> ArrayLike:
111111
# Keep copy of the original (Schur) vectors for later sanity check.
112112
Xc = np.copy(X)
113113

114-
# Initialize matrices.
115114
n, m = X.shape
116-
Q = np.zeros((n, m))
117-
R = np.zeros((m, m))
118115

119-
# Search for the constant (Schur) vector, if explicitly present.
120-
max_i = 0
121-
for i in range(m):
122-
vsum = np.sum(X[:, i])
123-
dummy = np.ones(X[:, i].shape) * (vsum / n)
124-
if np.allclose(X[:, i], dummy, rtol=1e-6, atol=1e-5):
125-
max_i = i # TODO: check, if more than one vec fulfills this
116+
# Find the column most aligned with sqrt(eta) (i.e., the stationary vector)
117+
# by cosine similarity. This is more robust than checking element-wise
118+
# constancy, which can fail when Schur vectors are only approximately
119+
# constant due to numerical noise.
120+
sqrt_eta = np.sqrt(eta)
121+
sqrt_eta_normed = sqrt_eta / np.linalg.norm(sqrt_eta)
122+
col_norms = np.linalg.norm(X, axis=0)
123+
col_norms = np.where(col_norms > 0, col_norms, 1.0)
124+
cosines = np.abs((X / col_norms).T @ sqrt_eta_normed)
125+
max_i = int(np.argmax(cosines))
126126

127127
# Shift non-constant first (Schur) vector to the right.
128128
X[:, max_i] = X[:, 0]
@@ -142,14 +142,11 @@ def _gram_schmidt_mod(X: ArrayLike, eta: ArrayLike) -> ArrayLike:
142142
f"Number of clusters: {m}."
143143
)
144144

145-
# eta-orthonormalization
146-
for j in range(m):
147-
v = X[:, j]
148-
for i in range(j):
149-
R[i, j] = np.dot(Q[:, i].conj(), v)
150-
v = v - np.dot(R[i, j], Q[:, i])
151-
R[j, j] = np.linalg.norm(v)
152-
Q[:, j] = np.true_divide(v, R[j, j])
145+
# Orthonormalize via Householder QR (backward stable, LAPACK DGEQRF).
146+
Q, _ = np.linalg.qr(X)
147+
# QR may flip the sign of columns; ensure column 0 aligns with sqrt(eta).
148+
if Q[:, 0] @ np.sqrt(eta) < 0:
149+
Q[:, 0] = -Q[:, 0]
153150

154151
# Raise, if the subspace changed!
155152
dummy = subspace_angles(Q, Xc)
@@ -258,18 +255,21 @@ def _do_schur(
258255
if not np.allclose(Q.T.dot(Q * eta[:, None]), np.eye(Q.shape[1]), rtol=1e6 * EPS, atol=1e6 * EPS):
259256
logging.debug("The Schur vectors aren't D-orthogonal so they are D-orthogonalized.")
260257
Q = _gram_schmidt_mod(Q, eta)
258+
# Recompute R in the new orthonormal basis to maintain P_bar @ Q ≈ Q @ R.
259+
P_bar_dense = P_bar.toarray() if issparse(P_bar) else P_bar
260+
R = Q.T @ P_bar_dense @ Q
261261
# Transform the orthonormalized Schur vectors of P_bar back
262262
# to orthonormalized Schur vectors X of P.
263263
X = np.true_divide(Q, np.sqrt(eta)[:, None])
264264
else:
265-
# Search for the constant (Schur) vector, if explicitly present.
265+
# Find the column most aligned with sqrt(eta) (i.e., the stationary vector).
266266
n, m = Q.shape
267-
max_i = 0
268-
for i in range(m):
269-
vsum = np.sum(Q[:, i])
270-
dummy = np.ones(Q[:, i].shape) * (vsum / n)
271-
if np.allclose(Q[:, i], dummy, rtol=1e-6, atol=1e-5):
272-
max_i = i # TODO: check, if more than one vec fulfills this
267+
sqrt_eta = np.sqrt(eta)
268+
sqrt_eta_normed = sqrt_eta / np.linalg.norm(sqrt_eta)
269+
col_norms = np.linalg.norm(Q, axis=0)
270+
col_norms = np.where(col_norms > 0, col_norms, 1.0)
271+
cosines = np.abs((Q / col_norms).T @ sqrt_eta_normed)
272+
max_i = int(np.argmax(cosines))
273273

274274
# Shift non-constant first (Schur) vector to the right.
275275
Q[:, max_i] = Q[:, 0]

0 commit comments

Comments
 (0)