Commit 60abdb7

Add Array API documentation
1 parent 10ae3d3 commit 60abdb7

2 files changed

Lines changed: 247 additions & 0 deletions

docs/source/guides/array_api.rst

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
1+
Array API Compatibility
2+
=======================
3+
4+
ezmsg-learn uses the `Array API standard <https://data-apis.org/array-api/latest/>`_
5+
to allow processors to operate on arrays from different backends — NumPy, CuPy,
6+
PyTorch, and others — without code changes.
7+
8+
.. contents:: On this page
9+
:local:
10+
:depth: 2
11+
12+
13+
How It Works
14+
------------
15+
16+
Modules that support the Array API derive the array namespace from their input
17+
data using ``array_api_compat.get_namespace()``:
18+
19+
.. code-block:: python
20+
21+
from array_api_compat import get_namespace
22+
23+
def process(self, data):
24+
    xp = get_namespace(data)  # numpy, cupy, torch, etc.
25+
    result = xp.linalg.inv(data)  # dispatches to the right backend
26+
    return result
27+
28+
This means that if you pass a CuPy array, all computation stays on the GPU.
29+
If you pass a NumPy array, it behaves exactly as before.
30+
31+
Helper utilities from ``ezmsg.sigproc.util.array`` handle device placement
32+
and creation functions portably:
33+
34+
- ``array_device(x)`` — returns the device of an array, or ``None``
35+
- ``xp_create(fn, *args, dtype=None, device=None)`` — calls creation
36+
functions (``zeros``, ``eye``) with optional device
37+
- ``xp_asarray(xp, obj, dtype=None, device=None)`` — portable ``asarray``
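
To show why such helpers are useful, here is a rough stand-in for two of them. These are not the ``ezmsg.sigproc.util.array`` implementations: the ``_sketch`` names are hypothetical and the logic is an assumption about what portable helpers of this kind need to do, namely tolerate arrays without a ``device`` attribute and avoid passing ``device`` to creation functions that do not accept it (for example NumPy before 2.0).

```python
import numpy as np


def array_device_sketch(x):
    """Return ``x.device`` when the object exposes one, else ``None``."""
    return getattr(x, "device", None)


def xp_create_sketch(fn, *args, dtype=None, device=None):
    """Call a creation function, forwarding only the keywords that were set."""
    kwargs = {}
    if dtype is not None:
        kwargs["dtype"] = dtype
    if device is not None:
        kwargs["device"] = device
    return fn(*args, **kwargs)


z = xp_create_sketch(np.zeros, (2, 3), dtype=np.float32)
eye = xp_create_sketch(np.eye, 3)
```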
38+
39+
40+
Module Compatibility
41+
--------------------
42+
43+
The tables below summarise the Array API status of each module, grouped by level of compatibility.
44+
45+
Fully compatible
46+
^^^^^^^^^^^^^^^^
47+
48+
These modules perform all computation in the source array namespace.
49+
50+
.. list-table::
51+
:header-rows: 1
52+
:widths: 35 65
53+
54+
* - Module
55+
- Notes
56+
* - ``process.ssr``
57+
- LRR / self-supervised regression. Full Array API.
58+
* - ``model.cca``
59+
- Incremental CCA. Replaced ``scipy.linalg.sqrtm`` with an
60+
eigendecomposition-based inverse square root using only Array API ops.
61+
* - ``process.rnn``
62+
- PyTorch-native; operates on ``torch.Tensor`` throughout.
63+
64+
Mostly compatible (with NumPy boundaries)
65+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
66+
67+
These modules use the Array API for data manipulation but fall back to NumPy
68+
at specific points where a dependency requires it.
69+
70+
.. list-table::
71+
:header-rows: 1
72+
:widths: 25 35 40
73+
74+
* - Module
75+
- NumPy boundary
76+
- Reason
77+
* - ``model.refit_kalman``
78+
- ``_compute_gain()``
79+
- ``scipy.linalg.solve_discrete_are`` has no Array API equivalent.
80+
Matrices are converted to NumPy for the DARE solver, then converted back.
81+
* - ``model.refit_kalman``
82+
- ``refit()`` mutation loop
83+
- Per-sample velocity remapping uses ``np.linalg.norm`` on small vectors
84+
and scalar element assignment.
85+
* - ``process.refit_kalman``
86+
- Inherits boundaries from model
87+
- State init and output arrays use the source namespace.
88+
* - ``process.slda``
89+
- ``predict_proba``
90+
- sklearn ``LinearDiscriminantAnalysis`` requires NumPy input.
91+
* - ``process.adaptive_linear_regressor``
92+
- ``partial_fit`` / ``predict``
93+
- sklearn and river models require NumPy / pandas input.
94+
* - ``dim_reduce.adaptive_decomp``
95+
- ``partial_fit`` / ``transform``
96+
- sklearn ``IncrementalPCA`` and ``MiniBatchNMF`` require NumPy input.
97+
98+
Not converted
99+
^^^^^^^^^^^^^
100+
101+
These modules use NumPy directly. Conversion would provide little benefit
102+
because the underlying estimator is the bottleneck.
103+
104+
.. list-table::
105+
:header-rows: 1
106+
:widths: 25 75
107+
108+
* - Module
109+
- Reason
110+
* - ``process.linear_regressor``
111+
- Thin wrapper around sklearn ``LinearModel.predict``.
112+
Could be made compatible if sklearn's ``array_api_dispatch`` is enabled
113+
(see below).
114+
* - ``process.sgd``
115+
- sklearn ``SGDClassifier`` has no Array API support.
116+
* - ``process.sklearn``
117+
- Generic wrapper for arbitrary models; cannot assume Array API support.
118+
* - ``dim_reduce.incremental_decomp``
119+
- Delegates to ``adaptive_decomp``; trivial NumPy usage (``np.prod`` on
120+
Python tuples).
121+
122+
123+
sklearn Array API Dispatch
124+
--------------------------
125+
126+
scikit-learn 1.8+ has experimental support for Array API dispatch on a subset
127+
of estimators. Two estimators used in ezmsg-learn are on the supported list:
128+
129+
.. list-table::
130+
:header-rows: 1
131+
:widths: 30 30 40
132+
133+
* - Estimator
134+
- Used in
135+
- Constraint
136+
* - ``LinearDiscriminantAnalysis``
137+
- ``process.slda``
138+
- Requires ``solver="svd"`` (the ``"lsqr"`` solver with ``shrinkage``
139+
is not supported)
140+
* - ``Ridge``
141+
- ``process.linear_regressor``
142+
- Requires ``solver="svd"``
143+
144+
To use dispatch, enable it before creating the estimator:
145+
146+
.. code-block:: python
147+
148+
from sklearn import set_config
149+
set_config(array_api_dispatch=True)
150+
151+
.. warning::
152+
153+
- ``array_api_dispatch`` is marked **experimental** in sklearn.
154+
- Solver constraints (``solver="svd"``) may produce slightly different
155+
numerical results compared to other solvers.
156+
- Enabling dispatch globally may affect other sklearn estimators in the
157+
same process.
158+
- ezmsg-learn does **not** enable dispatch by default.
159+
160+
Estimators that do **not** support Array API dispatch:
161+
162+
- ``IncrementalPCA``, ``MiniBatchNMF`` — only batch ``PCA`` is supported
163+
- ``SGDClassifier``, ``SGDRegressor``, ``PassiveAggressiveRegressor``
164+
- All river models
165+
166+
167+
Writing Array API Compatible Code
168+
----------------------------------
169+
170+
When adding or modifying processors in ezmsg-learn, follow these patterns.
171+
172+
Deriving the namespace
173+
^^^^^^^^^^^^^^^^^^^^^^
174+
175+
Always derive ``xp`` from the input data, not from a hardcoded ``numpy``:
176+
177+
.. code-block:: python
178+
179+
from array_api_compat import get_namespace
180+
from ezmsg.sigproc.util.array import array_device, xp_create
181+
182+
def _process(self, message):
183+
    xp = get_namespace(message.data)
184+
    dev = array_device(message.data)
185+
186+
Transposing matrices
187+
^^^^^^^^^^^^^^^^^^^^
188+
189+
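The Array API standard defines ``.T`` only for two-dimensional arrays, so it cannot be relied on for stacked (batched) matrices. Use ``xp.linalg.matrix_transpose()``, which swaps the last two axes: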
190+
191+
.. code-block:: python
192+
193+
# Before (numpy-only)
194+
result = A.T @ B
195+
196+
# After (Array API)
197+
_mT = xp.linalg.matrix_transpose
198+
result = _mT(A) @ B
199+
200+
Creating arrays
201+
^^^^^^^^^^^^^^^
202+
203+
Use ``xp_create`` to handle device placement portably:
204+
205+
.. code-block:: python
206+
207+
# Before
208+
I = np.eye(n)
209+
z = np.zeros((m, n), dtype=np.float64)
210+
211+
# After
212+
I = xp_create(xp.eye, n, device=dev)
213+
z = xp_create(xp.zeros, (m, n), dtype=xp.float64, device=dev)
214+
215+
Handling sklearn boundaries
216+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
217+
218+
When calling into sklearn (or other NumPy-only libraries), convert at the
219+
boundary and convert back:
220+
221+
.. code-block:: python
222+
223+
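import numpy as np
from array_api_compat import is_numpy_array
224+
225+
# Convert to NumPy for sklearn. np.asarray only works for host (CPU) arrays;
# device arrays such as CuPy must be copied to the host first.
226+
X_np = np.asarray(X) if not is_numpy_array(X) else X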
227+
result_np = estimator.predict(X_np)
228+
229+
# Convert back to source namespace
230+
result = xp.asarray(result_np) if not is_numpy_array(X) else result_np
231+
232+
Checking for NaN
233+
^^^^^^^^^^^^^^^^
234+
235+
Use ``xp.isnan`` instead of ``np.isnan``:
236+
237+
.. code-block:: python
238+
239+
if xp.any(xp.isnan(message.data)):
240+
    return
241+
242+
Norms
243+
^^^^^
244+
245+
Use ``xp.linalg.matrix_norm`` (Frobenius by default) instead of
246+
``np.linalg.norm`` for matrices. For vectors, use ``xp.linalg.vector_norm``.

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ For general ezmsg tutorials and guides, visit `ezmsg.org <https://www.ezmsg.org>
54 54
:caption: Contents:
55 55

56 56
guides/classification
57+
guides/array_api
57 58
api/index
58 59

59 60
