Skip to content

Commit 10ae3d3

Browse files
committed
Make as much processing as possible use Array API
1 parent f7224c4 commit 10ae3d3

7 files changed

Lines changed: 433 additions & 157 deletions

File tree

src/ezmsg/learn/dim_reduce/adaptive_decomp.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1+
"""Adaptive decomposition transformers (PCA, NMF).
2+
3+
.. note::
4+
This module supports the Array API standard via
5+
``array_api_compat.get_namespace()``. Reshaping and output allocation
6+
use Array API operations; a NumPy boundary is applied before sklearn
7+
``partial_fit``/``transform`` calls.
8+
"""
9+
10+
import math
111
import typing
212

313
import ezmsg.core as ez
414
import numpy as np
15+
from array_api_compat import get_namespace, is_numpy_array
516
from ezmsg.baseproc import (
617
BaseAdaptiveTransformer,
718
BaseAdaptiveTransformerUnit,
@@ -128,6 +139,8 @@ def _process(self, message: AxisArray) -> AxisArray:
128139
if in_dat.shape[ax_idx] == 0:
129140
return self._state.template
130141

142+
xp = get_namespace(in_dat)
143+
131144
# Re-order axes
132145
sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
133146
if message.dims != sorted_dims_exp:
@@ -137,16 +150,20 @@ def _process(self, message: AxisArray) -> AxisArray:
137150
pass
138151

139152
# fold [iter_axis] + off_targ_axes together and fold targ_axes together
140-
d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
141-
in_dat = in_dat.reshape((-1, d2))
153+
d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
154+
in_dat = xp.reshape(in_dat, (-1, d2))
142155

143156
replace_kwargs = {
144157
"axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
145158
}
146159

147-
# Transform data
160+
# Transform data — sklearn needs numpy
148161
if hasattr(self._state.estimator, "components_"):
149-
decomp_dat = self._state.estimator.transform(in_dat).reshape((-1,) + self._state.template.data.shape[1:])
162+
in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
163+
decomp_dat = self._state.estimator.transform(in_np)
164+
# Convert back to source namespace
165+
decomp_dat = xp.asarray(decomp_dat) if not is_numpy_array(in_dat) else decomp_dat
166+
decomp_dat = xp.reshape(decomp_dat, (-1,) + self._state.template.data.shape[1:])
150167
replace_kwargs["data"] = decomp_dat
151168

152169
return replace(self._state.template, **replace_kwargs)
@@ -165,18 +182,21 @@ def partial_fit(self, message: AxisArray) -> None:
165182
if in_dat.shape[ax_idx] == 0:
166183
return
167184

185+
xp = get_namespace(in_dat)
186+
168187
# Re-order axes if needed
169188
sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
170189
if message.dims != sorted_dims_exp:
171190
# TODO: Implement axes transposition if needed
172191
pass
173192

174193
# fold [iter_axis] + off_targ_axes together and fold targ_axes together
175-
d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
176-
in_dat = in_dat.reshape((-1, d2))
194+
d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
195+
in_dat = xp.reshape(in_dat, (-1, d2))
177196

178-
# Fit the estimator
179-
self._state.estimator.partial_fit(in_dat)
197+
# Fit the estimator — sklearn needs numpy
198+
in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
199+
self._state.estimator.partial_fit(in_np)
180200

181201

182202
class IncrementalPCASettings(AdaptiveDecompSettings):

src/ezmsg/learn/model/cca.py

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,29 @@
1+
"""Incremental Canonical Correlation Analysis (CCA).
2+
3+
.. note::
4+
This module supports the Array API standard via
5+
``array_api_compat.get_namespace()``. All linear algebra uses Array API
6+
operations; ``scipy.linalg.sqrtm`` is replaced by an eigendecomposition-
7+
based inverse square root (:func:`_inv_sqrtm_spd`).
8+
"""
9+
110
import numpy as np
2-
from scipy import linalg
11+
from array_api_compat import get_namespace
12+
from ezmsg.sigproc.util.array import array_device, xp_create
13+
14+
15+
def _inv_sqrtm_spd(xp, A):
16+
"""Inverse matrix square root for symmetric positive-definite matrices.
17+
18+
Computes ``inv(sqrtm(A)) = Q @ diag(1/sqrt(lambda)) @ Q^T`` using the
19+
eigendecomposition. This is more numerically stable than computing
20+
``inv(sqrtm(...))`` separately and uses only Array API operations.
21+
"""
22+
eigenvalues, eigenvectors = xp.linalg.eigh(A)
23+
eigenvalues = xp.clip(eigenvalues, 1e-12, None) # avoid div-by-zero
24+
inv_sqrt_eig = 1.0 / xp.sqrt(eigenvalues)
25+
# Q @ diag(v) == Q * v (broadcasting), then @ Q^T
26+
return (eigenvectors * inv_sqrt_eig) @ xp.linalg.matrix_transpose(eigenvectors)
327

428

529
class IncrementalCCA:
@@ -33,58 +57,74 @@ def __init__(
3357
self.adaptation_rate = adaptation_rate
3458
self.initialized = False
3559

36-
def initialize(self, d1, d2):
37-
"""Initialize the necessary matrices"""
60+
def initialize(self, d1, d2, *, ref_array=None):
    """Allocate the running correlation matrices.

    Args:
        d1: Feature dimensionality of the first view.
        d2: Feature dimensionality of the second view.
        ref_array: Optional array whose namespace and device determine
            where the state lives. NumPy on the host when ``None``.
    """
    self.d1 = d1
    self.d2 = d2

    if ref_array is None:
        xp, dev = np, None
    else:
        xp = get_namespace(ref_array)
        dev = array_device(ref_array)

    # Running (smoothed) second-moment estimates, all starting at zero.
    def _zeros(shape):
        return xp_create(xp.zeros, shape, dtype=xp.float64, device=dev)

    self.C11 = _zeros((d1, d1))
    self.C22 = _zeros((d2, d2))
    self.C12 = _zeros((d1, d2))

    self.initialized = True
4784

4885
def _compute_change_magnitude(self, C11_new, C22_new, C12_new):
    """Return a scalar summarizing how much the correlation structure moved.

    Each block's Frobenius norm of (new - old) is scaled by that block's
    element count, and the three scaled deltas are averaged into a plain
    Python float.
    """
    xp = get_namespace(self.C11)

    # Per-block Frobenius norms of the update deltas, normalized by size.
    delta_11 = xp.linalg.matrix_norm(C11_new - self.C11) / (self.d1 * self.d1)
    delta_22 = xp.linalg.matrix_norm(C22_new - self.C22) / (self.d2 * self.d2)
    delta_12 = xp.linalg.matrix_norm(C12_new - self.C12) / (self.d1 * self.d2)

    return float((delta_11 + delta_22 + delta_12) / 3)
61100

62101
def _adapt_smoothing(self, change_magnitude):
63-
"""Adapt smoothing factor based on detected changes"""
102+
"""Adapt smoothing factor based on detected changes."""
64103
# If change is large, decrease smoothing factor
65104
target_smoothing = self.base_smoothing * (1.0 - change_magnitude)
66-
target_smoothing = np.clip(
67-
target_smoothing, self.min_smoothing, self.max_smoothing
68-
)
105+
target_smoothing = max(self.min_smoothing, min(target_smoothing, self.max_smoothing))
69106

70107
# Smooth the adaptation itself
71108
self.current_smoothing = (
72109
1 - self.adaptation_rate
73110
) * self.current_smoothing + self.adaptation_rate * target_smoothing
74111

75112
def partial_fit(self, X1, X2, update_projections=True):
76-
"""Update the model with new samples using adaptive smoothing
77-
Assumes X1 and X2 are already centered and scaled"""
113+
"""Update the model with new samples using adaptive smoothing.
114+
Assumes X1 and X2 are already centered and scaled."""
115+
xp = get_namespace(X1, X2)
116+
_mT = xp.linalg.matrix_transpose
117+
78118
if not self.initialized:
79-
self.initialize(X1.shape[1], X2.shape[1])
119+
self.initialize(X1.shape[1], X2.shape[1], ref_array=X1)
80120

81121
# Compute new correlation matrices from current batch
82-
C11_new = X1.T @ X1 / X1.shape[0]
83-
C22_new = X2.T @ X2 / X2.shape[0]
84-
C12_new = X1.T @ X2 / X1.shape[0]
122+
C11_new = _mT(X1) @ X1 / X1.shape[0]
123+
C22_new = _mT(X2) @ X2 / X2.shape[0]
124+
C12_new = _mT(X1) @ X2 / X1.shape[0]
85125

86126
# Detect changes and adapt smoothing factor
87-
if self.C11.any(): # Skip first update
127+
if bool(xp.any(self.C11 != 0)): # Skip first update
88128
change_magnitude = self._compute_change_magnitude(C11_new, C22_new, C12_new)
89129
self._adapt_smoothing(change_magnitude)
90130

@@ -98,25 +138,26 @@ def partial_fit(self, X1, X2, update_projections=True):
98138
self._update_projections()
99139

100140
def _update_projections(self):
    """Recompute canonical weights and correlations from the running moments.

    Whitens each view via the inverse matrix square root of its
    (regularized) covariance, takes the SVD of the whitened
    cross-covariance, and keeps the leading ``n_components`` singular
    vectors as projection weights.
    """
    xp = get_namespace(self.C11)
    dev = array_device(self.C11)
    transpose = xp.linalg.matrix_transpose

    # Tikhonov regularization keeps the covariance blocks invertible.
    eps = 1e-8
    eye1 = xp_create(xp.eye, self.d1, dtype=self.C11.dtype, device=dev)
    eye2 = xp_create(xp.eye, self.d2, dtype=self.C22.dtype, device=dev)
    whiten1 = _inv_sqrtm_spd(xp, self.C11 + eps * eye1)
    whiten2 = _inv_sqrtm_spd(xp, self.C22 + eps * eye2)

    # SVD of the whitened cross-covariance yields the canonical structure.
    K = whiten1 @ self.C12 @ whiten2
    U, self.correlations_, Vh = xp.linalg.svd(K, full_matrices=False)

    k = self.n_components
    self.x_weights_ = whiten1 @ U[:, :k]
    self.y_weights_ = whiten2 @ transpose(Vh)[:, :k]
117158

118159
def transform(self, X1, X2):
    """Project both views onto their canonical components.

    Returns a ``(X1_proj, X2_proj)`` tuple of the two projections.
    """
    return X1 @ self.x_weights_, X2 @ self.y_weights_

0 commit comments

Comments (0)