"""Incremental Canonical Correlation Analysis (CCA).

.. note::
    This module supports the Array API standard via
    ``array_api_compat.get_namespace()``. All linear algebra uses Array API
    operations; the former ``inv(scipy.linalg.sqrtm(...))`` computation is
    replaced by an eigendecomposition-based inverse square root
    (:func:`_inv_sqrtm_spd`).
"""
9+
110import numpy as np
2- from scipy import linalg
11+ from array_api_compat import get_namespace
12+ from ezmsg .sigproc .util .array import array_device , xp_create
13+
14+
15+ def _inv_sqrtm_spd (xp , A ):
16+ """Inverse matrix square root for symmetric positive-definite matrices.
17+
18+ Computes ``inv(sqrtm(A)) = Q @ diag(1/sqrt(lambda)) @ Q^T`` using the
19+ eigendecomposition. This is more numerically stable than computing
20+ ``inv(sqrtm(...))`` separately and uses only Array API operations.
21+ """
22+ eigenvalues , eigenvectors = xp .linalg .eigh (A )
23+ eigenvalues = xp .clip (eigenvalues , 1e-12 , None ) # avoid div-by-zero
24+ inv_sqrt_eig = 1.0 / xp .sqrt (eigenvalues )
25+ # Q @ diag(v) == Q * v (broadcasting), then @ Q^T
26+ return (eigenvectors * inv_sqrt_eig ) @ xp .linalg .matrix_transpose (eigenvectors )
327
428
529class IncrementalCCA :
@@ -33,58 +57,74 @@ def __init__(
3357 self .adaptation_rate = adaptation_rate
3458 self .initialized = False
3559
def initialize(self, d1, d2, *, ref_array=None):
    """Allocate zeroed correlation accumulators and mark the model initialized.

    Args:
        d1: Dimensionality of the first dataset.
        d2: Dimensionality of the second dataset.
        ref_array: Optional reference array from which the array namespace
            and device are derived. If ``None``, NumPy is used and no
            device is specified.
    """
    self.d1 = d1
    self.d2 = d2

    # Pick the array backend: follow the reference array when one is given,
    # otherwise fall back to NumPy with no explicit device.
    if ref_array is not None:
        xp = get_namespace(ref_array)
        dev = array_device(ref_array)
    else:
        xp, dev = np, None

    # Correlation accumulators start at zero and are always float64,
    # independent of the dtype of ``ref_array``.
    # NOTE(review): xp_create is a project helper; presumably it forwards
    # ``device`` only when the backend accepts it -- confirm.
    self.C11 = xp_create(xp.zeros, (d1, d1), dtype=xp.float64, device=dev)
    self.C22 = xp_create(xp.zeros, (d2, d2), dtype=xp.float64, device=dev)
    self.C12 = xp_create(xp.zeros, (d1, d2), dtype=xp.float64, device=dev)

    self.initialized = True
4784
def _compute_change_magnitude(self, C11_new, C22_new, C12_new):
    """Compute the magnitude of change in correlation structure.

    Args:
        C11_new: Batch estimate to compare against ``self.C11``.
        C22_new: Batch estimate to compare against ``self.C22``.
        C12_new: Batch estimate to compare against ``self.C12``.

    Returns:
        A Python ``float``: the mean of the three matrix norms of the
        differences, each divided by the element count of its matrix
        (``d1 * d1``, ``d2 * d2`` and ``d1 * d2`` respectively).
    """
    xp = get_namespace(self.C11)

    # Frobenius norm of differences (the default ``ord`` of matrix_norm).
    diff11 = xp.linalg.matrix_norm(C11_new - self.C11)
    diff22 = xp.linalg.matrix_norm(C22_new - self.C22)
    diff12 = xp.linalg.matrix_norm(C12_new - self.C12)

    # Normalize by matrix sizes. Out-of-place division is used rather than
    # ``/=`` so the code does not rely on in-place updates of 0-d arrays.
    diff11 = diff11 / (self.d1 * self.d1)
    diff22 = diff22 / (self.d2 * self.d2)
    diff12 = diff12 / (self.d1 * self.d2)

    # Convert the 0-d array to a plain Python float for the caller.
    return float((diff11 + diff22 + diff12) / 3)
61100
62101 def _adapt_smoothing (self , change_magnitude ):
63- """Adapt smoothing factor based on detected changes"""
102+ """Adapt smoothing factor based on detected changes. """
64103 # If change is large, decrease smoothing factor
65104 target_smoothing = self .base_smoothing * (1.0 - change_magnitude )
66- target_smoothing = np .clip (
67- target_smoothing , self .min_smoothing , self .max_smoothing
68- )
105+ target_smoothing = max (self .min_smoothing , min (target_smoothing , self .max_smoothing ))
69106
70107 # Smooth the adaptation itself
71108 self .current_smoothing = (
72109 1 - self .adaptation_rate
73110 ) * self .current_smoothing + self .adaptation_rate * target_smoothing
74111
75112 def partial_fit (self , X1 , X2 , update_projections = True ):
76- """Update the model with new samples using adaptive smoothing
77- Assumes X1 and X2 are already centered and scaled"""
113+ """Update the model with new samples using adaptive smoothing.
114+ Assumes X1 and X2 are already centered and scaled."""
115+ xp = get_namespace (X1 , X2 )
116+ _mT = xp .linalg .matrix_transpose
117+
78118 if not self .initialized :
79- self .initialize (X1 .shape [1 ], X2 .shape [1 ])
119+ self .initialize (X1 .shape [1 ], X2 .shape [1 ], ref_array = X1 )
80120
81121 # Compute new correlation matrices from current batch
82- C11_new = X1 . T @ X1 / X1 .shape [0 ]
83- C22_new = X2 . T @ X2 / X2 .shape [0 ]
84- C12_new = X1 . T @ X2 / X1 .shape [0 ]
122+ C11_new = _mT ( X1 ) @ X1 / X1 .shape [0 ]
123+ C22_new = _mT ( X2 ) @ X2 / X2 .shape [0 ]
124+ C12_new = _mT ( X1 ) @ X2 / X1 .shape [0 ]
85125
86126 # Detect changes and adapt smoothing factor
87- if self . C11 . any (): # Skip first update
127+ if bool ( xp . any (self . C11 != 0 ) ): # Skip first update
88128 change_magnitude = self ._compute_change_magnitude (C11_new , C22_new , C12_new )
89129 self ._adapt_smoothing (change_magnitude )
90130
@@ -98,25 +138,26 @@ def partial_fit(self, X1, X2, update_projections=True):
98138 self ._update_projections ()
99139
def _update_projections(self):
    """Update canonical vectors and correlations.

    Reads ``self.C11``, ``self.C22``, ``self.C12``, ``self.d1``, ``self.d2``
    and ``self.n_components``. Writes ``self.correlations_`` (all singular
    values returned by the SVD), ``self.x_weights_`` and ``self.y_weights_``
    (each limited to the first ``n_components`` columns).
    """
    xp = get_namespace(self.C11)
    dev = array_device(self.C11)
    _mT = xp.linalg.matrix_transpose

    # Add a small multiple of the identity to each auto-correlation matrix
    # before taking its inverse square root.
    eps = 1e-8
    C11_reg = self.C11 + eps * xp_create(xp.eye, self.d1, dtype=self.C11.dtype, device=dev)
    C22_reg = self.C22 + eps * xp_create(xp.eye, self.d2, dtype=self.C22.dtype, device=dev)

    # Each inverse square root is computed once and reused for both the
    # matrix K and the final weights.
    inv_sqrt_C11 = _inv_sqrtm_spd(xp, C11_reg)
    inv_sqrt_C22 = _inv_sqrtm_spd(xp, C22_reg)

    K = inv_sqrt_C11 @ self.C12 @ inv_sqrt_C22
    # Reduced SVD; the third output is V^H, so it is transposed below before
    # its leading columns are selected.
    U, self.correlations_, Vh = xp.linalg.svd(K, full_matrices=False)

    self.x_weights_ = inv_sqrt_C11 @ U[:, : self.n_components]
    self.y_weights_ = inv_sqrt_C22 @ _mT(Vh)[:, : self.n_components]
117158
def transform(self, X1, X2):
    """Project data onto canonical components.

    Args:
        X1: Array right-multiplied by ``self.x_weights_``.
        X2: Array right-multiplied by ``self.y_weights_``.

    Returns:
        Tuple ``(X1 @ x_weights_, X2 @ y_weights_)``.

    Raises:
        AttributeError: If ``x_weights_`` / ``y_weights_`` have not been set
            on the instance yet.
    """
    X1_proj = X1 @ self.x_weights_
    X2_proj = X2 @ self.y_weights_
    return X1_proj, X2_proj
0 commit comments