Commit dc7dd0b

Merge pull request #9 from ezmsg-org/array_api
Extend Array API compliance
2 parents f7224c4 + 9c12e32 commit dc7dd0b

10 files changed

Lines changed: 684 additions & 161 deletions

docs/source/guides/array_api.rst

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
Array API Compatibility
=======================

ezmsg-learn uses the `Array API standard <https://data-apis.org/array-api/latest/>`_
to allow processors to operate on arrays from different backends — NumPy, CuPy,
PyTorch, and others — without code changes.

.. contents:: On this page
   :local:
   :depth: 2


How It Works
------------

Modules that support the Array API derive the array namespace from their input
data using ``array_api_compat.get_namespace()``:

.. code-block:: python

   from array_api_compat import get_namespace

   def process(self, data):
       xp = get_namespace(data)      # numpy, cupy, torch, etc.
       result = xp.linalg.inv(data)  # dispatches to the right backend
       return result

This means that if you pass a CuPy array, all computation stays on the GPU.
If you pass a NumPy array, it behaves exactly as before.

Helper utilities from ``ezmsg.sigproc.util.array`` handle device placement
and creation functions portably:

- ``array_device(x)`` — returns the device of an array, or ``None``
- ``xp_create(fn, *args, dtype=None, device=None)`` — calls creation
  functions (``zeros``, ``eye``) with optional device
- ``xp_asarray(xp, obj, dtype=None, device=None)`` — portable ``asarray``

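As a rough illustration of the namespace-derivation pattern, here is a minimal sketch. The ``get_namespace`` stand-in below handles only NumPy; the real ``array_api_compat.get_namespace()`` also recognizes CuPy, PyTorch, and other backends:

```python
import numpy as np

def get_namespace(x):
    # Minimal stand-in for array_api_compat.get_namespace(): NumPy only.
    # The real helper inspects the array's type and returns the matching
    # Array API namespace (cupy, torch, etc.).
    return np

def zscore(data):
    # Namespace-generic: every operation is looked up on xp, so the same
    # code runs unchanged on any Array API backend.
    xp = get_namespace(data)
    mu = xp.mean(data, axis=0)
    sd = xp.std(data, axis=0)
    return (data - mu) / sd

x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
z = zscore(x)
```

Because ``zscore`` never names ``np`` directly inside its body, swapping the input for a CuPy or PyTorch array (with the real ``get_namespace``) keeps the computation on that backend.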
Module Compatibility
--------------------

The table below summarises the Array API status of each module.

Fully compatible
^^^^^^^^^^^^^^^^

These modules perform all computation in the source array namespace.

.. list-table::
   :header-rows: 1
   :widths: 35 65

   * - Module
     - Notes
   * - ``process.ssr``
     - LRR / self-supervised regression. Full Array API.
   * - ``model.cca``
     - Incremental CCA. Replaced ``scipy.linalg.sqrtm`` with an
       eigendecomposition-based inverse square root using only Array API ops.
   * - ``process.rnn``
     - PyTorch-native; operates on ``torch.Tensor`` throughout.

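The eigendecomposition trick mentioned for ``model.cca`` can be sketched as follows. The ``inv_sqrtm`` helper is hypothetical (not the actual library code), and ``np.swapaxes`` stands in for ``xp.linalg.matrix_transpose`` so the sketch also runs on NumPy versions that predate it:

```python
import numpy as np

def inv_sqrtm(a, xp=np):
    # Inverse square root of a symmetric positive-definite matrix via
    # eigendecomposition: A^(-1/2) = V @ diag(1/sqrt(w)) @ V^T.
    # Only eigh, sqrt, and matmul are needed, all of which are in the
    # Array API standard (unlike scipy.linalg.sqrtm).
    w, v = xp.linalg.eigh(a)
    vt = np.swapaxes(v, -1, -2)  # xp.linalg.matrix_transpose(v) in Array API code
    return (v * (1.0 / xp.sqrt(w))) @ vt

a = np.array([[2.0, 0.0], [0.0, 8.0]])
b = inv_sqrtm(a)
```

Multiplying ``b @ a @ b`` should recover the identity, which makes the helper easy to sanity-check.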
Mostly compatible (with NumPy boundaries)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

These modules use the Array API for data manipulation but fall back to NumPy
at specific points where a dependency requires it.

.. list-table::
   :header-rows: 1
   :widths: 25 35 40

   * - Module
     - NumPy boundary
     - Reason
   * - ``model.refit_kalman``
     - ``_compute_gain()``
     - ``scipy.linalg.solve_discrete_are`` has no Array API equivalent.
       Matrices are converted to NumPy for the DARE solver, then converted back.
   * - ``model.refit_kalman``
     - ``refit()`` mutation loop
     - Per-sample velocity remapping uses ``np.linalg.norm`` on small vectors
       and scalar element assignment.
   * - ``process.refit_kalman``
     - Inherits boundaries from model
     - State init and output arrays use the source namespace.
   * - ``process.slda``
     - ``predict_proba``
     - sklearn ``LinearDiscriminantAnalysis`` requires NumPy input.
   * - ``process.adaptive_linear_regressor``
     - ``partial_fit`` / ``predict``
     - sklearn and river models require NumPy / pandas input.
   * - ``dim_reduce.adaptive_decomp``
     - ``partial_fit`` / ``transform``
     - sklearn ``IncrementalPCA`` and ``MiniBatchNMF`` require NumPy input.

Not converted
^^^^^^^^^^^^^

These modules use NumPy directly. Conversion would provide little benefit
because the underlying estimator is the bottleneck.

.. list-table::
   :header-rows: 1
   :widths: 25 75

   * - Module
     - Reason
   * - ``process.linear_regressor``
     - Thin wrapper around sklearn ``LinearModel.predict``.
       Could be made compatible if sklearn's ``array_api_dispatch`` is enabled
       (see below).
   * - ``process.sgd``
     - sklearn ``SGDClassifier`` has no Array API support.
   * - ``process.sklearn``
     - Generic wrapper for arbitrary models; cannot assume Array API support.
   * - ``dim_reduce.incremental_decomp``
     - Delegates to ``adaptive_decomp``; trivial numpy usage (``np.prod`` on
       Python tuples).

sklearn Array API Dispatch
--------------------------

scikit-learn 1.8+ has experimental support for Array API dispatch on a subset
of estimators. Two estimators used in ezmsg-learn are on the supported list:

.. list-table::
   :header-rows: 1
   :widths: 30 30 40

   * - Estimator
     - Used in
     - Constraint
   * - ``LinearDiscriminantAnalysis``
     - ``process.slda``
     - Requires ``solver="svd"`` (the ``"lsqr"`` solver with ``shrinkage``
       is not supported)
   * - ``Ridge``
     - ``process.linear_regressor``
     - Requires ``solver="svd"``

To use dispatch, enable it before creating the estimator:

.. code-block:: python

   from sklearn import set_config
   set_config(array_api_dispatch=True)

.. warning::

   - ``array_api_dispatch`` is marked **experimental** in sklearn.
   - Solver constraints (``solver="svd"``) may produce slightly different
     numerical results compared to other solvers.
   - Enabling dispatch globally may affect other sklearn estimators in the
     same process.
   - ezmsg-learn does **not** enable dispatch by default.

Estimators that do **not** support Array API dispatch:

- ``IncrementalPCA``, ``MiniBatchNMF`` — only batch ``PCA`` is supported
- ``SGDClassifier``, ``SGDRegressor``, ``PassiveAggressiveRegressor``
- All river models

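For illustration, here is ``Ridge`` fit with the dispatch-compatible solver on made-up NumPy data. This is a sketch of the solver constraint only; actually enabling ``array_api_dispatch`` additionally requires ``array_api_compat`` to be installed and a supported sklearn version:

```python
import numpy as np
from sklearn.linear_model import Ridge

# solver="svd" is the solver required for Array API dispatch.
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([1.0, 3.0, 5.0, 7.0])  # exactly y = 2x + 1

# Tiny alpha so the fit is effectively ordinary least squares.
model = Ridge(alpha=1e-8, solver="svd").fit(X, y)
pred = model.predict(np.array([[4.0]]))
```

With dispatch enabled and the same ``solver="svd"``, ``X`` and ``y`` could instead be arrays from another Array API backend.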
Writing Array API Compatible Code
----------------------------------

When adding or modifying processors in ezmsg-learn, follow these patterns.

Deriving the namespace
^^^^^^^^^^^^^^^^^^^^^^

Always derive ``xp`` from the input data, not from a hardcoded ``numpy``:

.. code-block:: python

   from array_api_compat import get_namespace
   from ezmsg.sigproc.util.array import array_device, xp_create

   def _process(self, message):
       xp = get_namespace(message.data)
       dev = array_device(message.data)

Transposing matrices
^^^^^^^^^^^^^^^^^^^^

The Array API specifies ``.T`` only for 2-D arrays. Use
``xp.linalg.matrix_transpose()``, which also handles batched matrices:

.. code-block:: python

   # Before (numpy-only)
   result = A.T @ B

   # After (Array API)
   _mT = xp.linalg.matrix_transpose
   result = _mT(A) @ B

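To confirm the two forms agree, here is a NumPy check; ``np.swapaxes(..., -1, -2)`` plays the role of ``xp.linalg.matrix_transpose`` on NumPy versions that predate it:

```python
import numpy as np

A = np.arange(6.0).reshape(2, 3)
B = np.arange(8.0).reshape(2, 4)

r_numpy = A.T @ B                  # numpy-only form
r_xp = np.swapaxes(A, -1, -2) @ B  # Array API-style transpose of the last two axes
```

For 2-D inputs the results are identical; the swap-last-two-axes form additionally generalizes to stacks of matrices.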
Creating arrays
^^^^^^^^^^^^^^^

Use ``xp_create`` to handle device placement portably:

.. code-block:: python

   # Before
   I = np.eye(n)
   z = np.zeros((m, n), dtype=np.float64)

   # After
   I = xp_create(xp.eye, n, device=dev)
   z = xp_create(xp.zeros, (m, n), dtype=xp.float64, device=dev)

Handling sklearn boundaries
^^^^^^^^^^^^^^^^^^^^^^^^^^^

When calling into sklearn (or other NumPy-only libraries), convert at the
boundary and convert back:

.. code-block:: python

   from array_api_compat import is_numpy_array

   # Convert to numpy for sklearn
   X_np = np.asarray(X) if not is_numpy_array(X) else X
   result_np = estimator.predict(X_np)

   # Convert back to source namespace
   result = xp.asarray(result_np) if not is_numpy_array(X) else result_np

Checking for NaN
^^^^^^^^^^^^^^^^

Use ``xp.isnan`` instead of ``np.isnan``:

.. code-block:: python

   if xp.any(xp.isnan(message.data)):
       return

Norms
^^^^^

Use ``xp.linalg.matrix_norm`` (Frobenius by default) instead of
``np.linalg.norm`` for matrices. For vectors, use ``xp.linalg.vector_norm``.
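As a quick sanity check of what the default matrix norm computes, the Frobenius norm is just the square root of the sum of squared entries. The explicit formula below matches ``xp.linalg.matrix_norm``'s default and is written with basic ops so it also runs on NumPy versions that lack ``matrix_norm``:

```python
import numpy as np

def frobenius_norm(x, xp=np):
    # Square root of the sum of squared entries: the "fro" default of
    # xp.linalg.matrix_norm, spelled out with basic Array API ops.
    return xp.sqrt(xp.sum(x * x))

a = np.array([[3.0, 0.0], [0.0, 4.0]])
```

Here ``frobenius_norm(a)`` is ``sqrt(9 + 16) = 5.0``, matching ``np.linalg.norm(a)`` for a 2-D input.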

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ For general ezmsg tutorials and guides, visit `ezmsg.org <https://www.ezmsg.org>
    :caption: Contents:
 
    guides/classification
+   guides/array_api
    api/index
 
 
pyproject.toml

Lines changed: 4 additions & 4 deletions

@@ -9,8 +9,9 @@ license = "MIT"
 requires-python = ">=3.10.15"
 dynamic = ["version"]
 dependencies = [
-    "ezmsg-baseproc>=1.4.0",
-    "ezmsg-sigproc>=2.15.0",
+    "ezmsg>=3.7.3",
+    "ezmsg-baseproc>=1.5.1",
+    "ezmsg-sigproc>=2.17.0",
     "river>=0.22.0",
     "scikit-learn>=1.6.0",
     "torch>=2.6.0",
@@ -73,5 +74,4 @@ known-third-party = ["ezmsg", "ezmsg.baseproc", "ezmsg.sigproc"]
 
 [tool.uv.sources]
 # Uncomment to use development version of ezmsg from git
-#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
-#ezmsg-sigproc = { path = "../ezmsg-sigproc", editable = true }
+#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }

src/ezmsg/learn/dim_reduce/adaptive_decomp.py

Lines changed: 28 additions & 8 deletions

@@ -1,7 +1,18 @@
+"""Adaptive decomposition transformers (PCA, NMF).
+
+.. note::
+    This module supports the Array API standard via
+    ``array_api_compat.get_namespace()``. Reshaping and output allocation
+    use Array API operations; a NumPy boundary is applied before sklearn
+    ``partial_fit``/``transform`` calls.
+"""
+
+import math
 import typing
 
 import ezmsg.core as ez
 import numpy as np
+from array_api_compat import get_namespace, is_numpy_array
 from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
@@ -128,6 +139,8 @@ def _process(self, message: AxisArray) -> AxisArray:
         if in_dat.shape[ax_idx] == 0:
             return self._state.template
 
+        xp = get_namespace(in_dat)
+
         # Re-order axes
         sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
         if message.dims != sorted_dims_exp:
@@ -137,16 +150,20 @@ def _process(self, message: AxisArray) -> AxisArray:
             pass
 
         # fold [iter_axis] + off_targ_axes together and fold targ_axes together
-        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
-        in_dat = in_dat.reshape((-1, d2))
+        d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+        in_dat = xp.reshape(in_dat, (-1, d2))
 
         replace_kwargs = {
             "axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
         }
 
-        # Transform data
+        # Transform data — sklearn needs numpy
         if hasattr(self._state.estimator, "components_"):
-            decomp_dat = self._state.estimator.transform(in_dat).reshape((-1,) + self._state.template.data.shape[1:])
+            in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
+            decomp_dat = self._state.estimator.transform(in_np)
+            # Convert back to source namespace
+            decomp_dat = xp.asarray(decomp_dat) if not is_numpy_array(in_dat) else decomp_dat
+            decomp_dat = xp.reshape(decomp_dat, (-1,) + self._state.template.data.shape[1:])
             replace_kwargs["data"] = decomp_dat
 
         return replace(self._state.template, **replace_kwargs)
@@ -165,18 +182,21 @@ def partial_fit(self, message: AxisArray) -> None:
         if in_dat.shape[ax_idx] == 0:
             return
 
+        xp = get_namespace(in_dat)
+
         # Re-order axes if needed
         sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
         if message.dims != sorted_dims_exp:
             # TODO: Implement axes transposition if needed
             pass
 
         # fold [iter_axis] + off_targ_axes together and fold targ_axes together
-        d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
-        in_dat = in_dat.reshape((-1, d2))
+        d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+        in_dat = xp.reshape(in_dat, (-1, d2))
 
-        # Fit the estimator
-        self._state.estimator.partial_fit(in_dat)
+        # Fit the estimator — sklearn needs numpy
+        in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
+        self._state.estimator.partial_fit(in_np)
 
 
 class IncrementalPCASettings(AdaptiveDecompSettings):
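The fold-then-fit pattern in the ``partial_fit`` hunk above can be distilled into a standalone sketch. The function and estimator below are hypothetical, and ``is_numpy_array`` is replaced by a trivial ``isinstance`` check so the snippet has no dependency on ``array_api_compat``:

```python
import math
import numpy as np

def fold_and_fit(in_dat, estimator, n_off_targ):
    # Fold all trailing target axes into one dimension, as partial_fit()
    # does; math.prod avoids np.prod on a plain Python tuple.
    d2 = math.prod(in_dat.shape[n_off_targ + 1:])
    in_dat = np.reshape(in_dat, (-1, d2))
    # NumPy boundary just before the sklearn-style call:
    in_np = in_dat if isinstance(in_dat, np.ndarray) else np.asarray(in_dat)
    estimator.partial_fit(in_np)
    return in_np.shape

class DummyEstimator:
    # Stand-in for sklearn's IncrementalPCA / MiniBatchNMF.
    def partial_fit(self, X):
        self.seen_shape = X.shape

est = DummyEstimator()
shape = fold_and_fit(np.zeros((4, 2, 3, 5)), est, n_off_targ=1)
```

With an input of shape ``(4, 2, 3, 5)`` and one off-target axis, the two trailing target axes fold into ``3 * 5 = 15`` features, so the estimator sees a 2-D ``(8, 15)`` array.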
