From d3ab7dde9131b3ce35c8f3130da6cfcdd9cc1200 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Thu, 21 May 2026 20:06:38 -0300 Subject: [PATCH 1/3] Improve linalg guide and linop representations --- spacecore/linalg/_cg.py | 14 +- spacecore/linalg/_lsqr.py | 15 +- spacecore/linalg/_power.py | 15 +- spacecore/linalg/_utils.py | 30 ++ spacecore/linop/_base.py | 16 + spacecore/linop/_dense.py | 12 +- spacecore/linop/_sparse.py | 13 +- tests/linops/test_to_dense.py | 73 +++++ tutorials/8_Linalg_MatrixFree.ipynb | 473 ++++++++++++++++++++++++++++ 9 files changed, 656 insertions(+), 5 deletions(-) create mode 100644 tutorials/8_Linalg_MatrixFree.ipynb diff --git a/spacecore/linalg/_cg.py b/spacecore/linalg/_cg.py index 5e9569b..775b46c 100644 --- a/spacecore/linalg/_cg.py +++ b/spacecore/linalg/_cg.py @@ -5,7 +5,7 @@ from ..linop import LinOp from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter from ._utils import is_converged, real_inner, require_linop, require_square -from ._utils import safe_inverse, should_check_iteration, threshold +from ._utils import result_repr, safe_inverse, should_check_iteration, threshold class CGResult(NamedTuple): @@ -16,6 +16,18 @@ class CGResult(NamedTuple): num_iters: Any residual_norm: Any + def __repr__(self) -> str: + """Return a compact summary without printing the full solution array.""" + return result_repr( + "CGResult", + { + "converged": self.converged, + "num_iters": self.num_iters, + "residual_norm": self.residual_norm, + "x": self.x, + }, + ) + def cg( A: LinOp, diff --git a/spacecore/linalg/_lsqr.py b/spacecore/linalg/_lsqr.py index 95a3921..8494df0 100644 --- a/spacecore/linalg/_lsqr.py +++ b/spacecore/linalg/_lsqr.py @@ -5,7 +5,7 @@ from ..linop import LinOp from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter from ._utils import is_converged, require_linop, safe_inverse, should_check_iteration -from ._utils import threshold +from ._utils import 
result_repr, threshold class LSQRResult(NamedTuple): @@ -17,6 +17,19 @@ class LSQRResult(NamedTuple): residual_norm: Any normal_residual_norm: Any + def __repr__(self) -> str: + """Return a compact summary without printing the full solution array.""" + return result_repr( + "LSQRResult", + { + "converged": self.converged, + "num_iters": self.num_iters, + "residual_norm": self.residual_norm, + "normal_residual_norm": self.normal_residual_norm, + "x": self.x, + }, + ) + def lsqr( A: LinOp, diff --git a/spacecore/linalg/_power.py b/spacecore/linalg/_power.py index 8a64c12..5358b85 100644 --- a/spacecore/linalg/_power.py +++ b/spacecore/linalg/_power.py @@ -5,7 +5,7 @@ from ..linop import LinOp from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter from ._utils import default_initial_vector, is_converged, normalize, require_linop -from ._utils import require_square, should_check_iteration +from ._utils import require_square, result_repr, should_check_iteration class PowerIterationResult(NamedTuple): @@ -17,6 +17,19 @@ class PowerIterationResult(NamedTuple): num_iters: Any residual_norm: Any + def __repr__(self) -> str: + """Return a compact summary without printing the full eigenvector.""" + return result_repr( + "PowerIterationResult", + { + "converged": self.converged, + "num_iters": self.num_iters, + "eigenvalue": self.eigenvalue, + "residual_norm": self.residual_norm, + "eigenvector": self.eigenvector, + }, + ) + def power_iteration( A: LinOp, diff --git a/spacecore/linalg/_utils.py b/spacecore/linalg/_utils.py index 773f0fd..d9af97e 100644 --- a/spacecore/linalg/_utils.py +++ b/spacecore/linalg/_utils.py @@ -82,3 +82,33 @@ def default_initial_vector(A: LinOp) -> Any: size = prod(A.domain.shape) flat = A.ops.ones((size,), dtype=A.dtype) / A.ops.sqrt(A.ops.asarray(float(size))) return A.domain.unflatten(flat) + + +def summarize_value(value: Any) -> str: + """Return a compact representation for arrays, scalars, and pytrees.""" + shape = 
getattr(value, "shape", None) + dtype = getattr(value, "dtype", None) + if shape is not None: + shape_text = tuple(shape) + if shape_text == (): + dtype_text = str(dtype) + if dtype_text in {"bool", "bool_", "torch.bool"}: + try: + return repr(bool(value)) + except Exception: + return repr(value) + try: + return f"{float(value):.6g}" + except Exception: + return repr(value) + dtype_text = "" if dtype is None else f", dtype={dtype}" + return f"<array shape={shape_text}{dtype_text}>" + if isinstance(value, tuple): + return "(" + ", ".join(summarize_value(part) for part in value) + ")" + return repr(value) + + +def result_repr(name: str, fields: dict[str, Any]) -> str: + """Return a compact result-object representation.""" + body = ", ".join(f"{key}={summarize_value(value)}" for key, value in fields.items()) + return f"{name}({body})" diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index f006e9a..af9c05f 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -43,6 +43,22 @@ def codomain(self) -> Codomain: """Codomain space of this linear operator.""" return self.cod + @property + def A(self) -> Any: + """ + Native numerical representation of this operator. + + Concrete subclasses may choose the representation that best matches + their storage model: for example, dense operators return a dense array + while sparse operators return their sparse matrix. Matrix-free or lazy + operators generally do not have such a representation and should leave + this property unimplemented. Use :meth:`to_dense` when a dense tensor + materialization is explicitly required. + """ + raise NotImplementedError( + f"{type(self).__name__} does not define a native numerical representation."
+ ) + @abstractmethod def apply(self, x: Any) -> Any: """ diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index 9b36a09..87f38d5 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -40,7 +40,7 @@ def __init__(self, if tuple(A.shape) != expected: raise TypeError(f"Expected A.shape == cod.shape + dom.shape == {expected}, got {A.shape}") - self.A = A # No dtype conversion + self._A = A # No dtype conversion self._cod_size = prod(self.cod.shape) self._dom_size = prod(self.dom.shape) self._matrix_shape = (self._cod_size, self._dom_size) @@ -56,6 +56,16 @@ def __init__(self, self.apply = self._apply_unchecked self.rapply = self._rapply_unchecked + @property + def A(self) -> DenseArray: + """ + Stored dense tensor representation of this operator. + + The returned array has shape ``self.codomain.shape + self.domain.shape`` + and is the same object supplied at construction. + """ + return self._A + def apply(self, x: DenseArray) -> DenseArray: """ Forward action: y = A ⋅ x with y in cod.shape. diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index 7a8162e..cda5ba6 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -37,7 +37,7 @@ def __init__(self, if tuple(A.shape) != expected: raise TypeError(f"Expected A.shape == (prod(cod.shape), prod(dom.shape)) == {expected}, got {A.shape}") - self.A = A # No dtype conversion + self._A = A # No dtype conversion self._cod_size = expected[0] self._dom_size = expected[1] dtype = self.ops.get_dtype(self.A) @@ -52,6 +52,17 @@ def __init__(self, self.apply = self._apply_unchecked self.rapply = self._rapply_unchecked + @property + def A(self) -> SparseArray: + """ + Stored sparse matrix representation of this operator. + + The returned sparse matrix has shape + ``(prod(self.codomain.shape), prod(self.domain.shape))`` and is the + same object supplied at construction. 
+ """ + return self._A + def apply(self, x: DenseArray) -> DenseArray: """ Forward action: y = A ⋅ x with y in cod.shape. diff --git a/tests/linops/test_to_dense.py b/tests/linops/test_to_dense.py index 5ca33f2..a60097d 100644 --- a/tests/linops/test_to_dense.py +++ b/tests/linops/test_to_dense.py @@ -1,6 +1,7 @@ import importlib import numpy as np +import pytest import scipy.sparse as sps @@ -29,6 +30,17 @@ def test_dense_linop_to_dense_returns_stored_matrix_and_matches_apply(): _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) +def test_dense_linop_A_returns_stored_dense_representation(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(A, dom, cod, ctx) + + assert op.A is A + + def test_sparse_linop_to_dense_matches_apply(): sc = importlib.import_module("spacecore") ctx = _ctx() @@ -41,6 +53,17 @@ def test_sparse_linop_to_dense_matches_apply(): _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) +def test_sparse_linop_A_returns_stored_sparse_representation(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = sps.csr_matrix([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(A, dom, cod, ctx) + + assert op.A is A + + def test_identity_linop_to_dense_matches_apply(): sc = importlib.import_module("spacecore") ctx = _ctx() @@ -80,6 +103,56 @@ def test_matrix_free_linop_to_dense_matches_apply(): _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) +def test_matrix_free_linop_A_is_not_implemented_by_default(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.MatrixFreeLinOp( + lambda x: ctx.asarray(dense @ np.asarray(x)), + lambda y: ctx.asarray(dense.T @ 
np.asarray(y)), + dom, + cod, + ctx, + ) + + with pytest.raises(NotImplementedError, match="native numerical representation"): + _ = op.A + + +def test_custom_linop_can_define_A_representation(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + + class CustomLinOp(sc.LinOp): + @property + def A(self): + return {"backend": "custom", "data": dense} + + def apply(self, x): + return ctx.asarray(np.asarray(dense) @ np.asarray(x)) + + def rapply(self, y): + return ctx.asarray(np.asarray(dense).T @ np.asarray(y)) + + def tree_flatten(self): + return (), (self.domain, self.codomain, self.ctx) + + @classmethod + def tree_unflatten(cls, aux, children): + domain, codomain, ctx = aux + return cls(domain, codomain, ctx) + + op = CustomLinOp(dom, cod, ctx) + + assert op.A["backend"] == "custom" + assert op.A["data"] is dense + + def test_sum_linop_to_dense_matches_apply(): sc = importlib.import_module("spacecore") ctx = _ctx() diff --git a/tutorials/8_Linalg_MatrixFree.ipynb b/tutorials/8_Linalg_MatrixFree.ipynb new file mode 100644 index 0000000..f195685 --- /dev/null +++ b/tutorials/8_Linalg_MatrixFree.ipynb @@ -0,0 +1,473 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": [ + "# Matrix-free linear algebra\n", + "\n", + "This guide demonstrates SpaceCore's iterative linear algebra routines on `MatrixFreeLinOp` objects.\n", + "\n", + "A matrix-free operator is useful when the action of an operator is cheap, but building or storing the full matrix would be wasteful. In SpaceCore, the only requirements are:\n", + "\n", + "- a domain space,\n", + "- a codomain space,\n", + "- a forward action `apply(x)`,\n", + "- an adjoint action `rapply(y)`.\n", + "\n", + "The solvers below never need a dense matrix representation of the operator." 
+ ] + }, + { + "cell_type": "code", + "id": "imports", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.583864Z", + "start_time": "2026-05-21T15:58:04.539898Z" + } + }, + "source": [ + "import numpy as np\n", + "\n", + "import spacecore as sc" + ], + "outputs": [], + "execution_count": 50 + }, + { + "cell_type": "markdown", + "id": "context", + "metadata": {}, + "source": [ + "## Backend context and vector space\n", + "\n", + "We use NumPy here to keep the notebook easy to run. The same operators can be converted to other supported backends when their callbacks are written using backend-compatible operations." + ] + }, + { + "cell_type": "code", + "id": "setup", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.612618Z", + "start_time": "2026-05-21T15:58:04.600433Z" + } + }, + "source": [ + "ctx = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False)\n", + "n = 50000\n", + "X = sc.VectorSpace((n,), ctx)\n", + "\n", + "grid = np.linspace(0.0, np.pi, n)" + ], + "outputs": [], + "execution_count": 51 + }, + { + "cell_type": "markdown", + "id": "spd-op", + "metadata": {}, + "source": [ + "## A square Hermitian positive-definite operator\n", + "\n", + "This operator acts like a positive diagonal matrix, but we do not build a matrix object. The callback stores only the diagonal coefficients and multiplies elementwise.\n", + "\n", + "Because the operator is real and self-adjoint, the forward and adjoint callbacks are the same function." 
+ ] + }, + { + "cell_type": "code", + "id": "spd-op-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.696219Z", + "start_time": "2026-05-21T15:58:04.628355Z" + } + }, + "source": [ + "diag = ctx.asarray(np.concatenate(([0.25], np.linspace(1.0, 2.0, n - 2), [6.0])))\n", + "\n", + "def spd_apply(x):\n", + " return diag * x\n", + "\n", + "A = sc.MatrixFreeLinOp(spd_apply, spd_apply, X, X, ctx)\n", + "\n", + "A.domain.shape, A.codomain.shape" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "((50000,), (50000,))" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 52 + }, + { + "cell_type": "markdown", + "id": "cg", + "metadata": {}, + "source": [ + "## Conjugate gradient: solve `A x = b`\n", + "\n", + "`cg` is for square Hermitian positive-definite systems. We create a known solution, apply the operator to get the right-hand side, and ask CG to recover the solution.\n", + "\n", + "The result object summarizes convergence metadata and avoids printing the full solution vector in its `repr`." 
+ ] + }, + { + "cell_type": "code", + "id": "cg-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.725542Z", + "start_time": "2026-05-21T15:58:04.699252Z" + } + }, + "source": [ + "x_true = ctx.asarray(np.sin(grid) + 0.2 * np.cos(3.0 * grid))\n", + "b = A.apply(x_true)\n", + "\n", + "cg_result = sc.cg(A, b, tol=1e-8, maxiter=256)\n", + "cg_result" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "CGResult(converged=True, num_iters=64, residual_norm=1.18582e-08, x=<array shape=(50000,), dtype=float64>)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 53 + }, + { + "cell_type": "code", + "id": "cg-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.738009Z", + "start_time": "2026-05-21T15:58:04.726043Z" + } + }, + "source": [ + "relative_error = ctx.ops.norm(cg_result.x - x_true) / ctx.ops.norm(x_true)\n", + "float(relative_error)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "4.5131997046538686e-11" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 54 + }, + { + "cell_type": "markdown", + "id": "lsqr-op", + "metadata": {}, + "source": [ + "## A rectangular matrix-free operator\n", + "\n", + "`lsqr` works for rectangular least-squares problems. Here `B : R^n -> R^{2n}` maps a vector into two measurement channels:\n", + "\n", + "$$\n", + "B x = \\begin{bmatrix}x \\\\ w \\odot x\\end{bmatrix}.\n", + "$$\n", + "\n", + "The adjoint combines the two channels:\n", + "\n", + "$$\n", + "B^* y = y_1 + w \\odot y_2.\n", + "$$\n", + "\n", + "Again, no matrix is built."
+ ] + }, + { + "cell_type": "code", + "id": "lsqr-op-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.750694Z", + "start_time": "2026-05-21T15:58:04.739637Z" + } + }, + "source": [ + "Y = sc.VectorSpace((2 * n,), ctx)\n", + "weights = ctx.asarray(np.linspace(0.5, 1.5, n))\n", + "\n", + "def rectangular_apply(x):\n", + " return ctx.ops.concatenate((x, weights * x), axis=0)\n", + "\n", + "def rectangular_rapply(y):\n", + " return y[:n] + weights * y[n:]\n", + "\n", + "B = sc.MatrixFreeLinOp(rectangular_apply, rectangular_rapply, X, Y, ctx)\n", + "\n", + "B.domain.shape, B.codomain.shape" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "((50000,), (100000,))" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 55 + }, + { + "cell_type": "markdown", + "id": "lsqr", + "metadata": {}, + "source": [ + "## LSQR: solve a least-squares problem\n", + "\n", + "We add a small deterministic perturbation to the measurements so the problem is not just a perfectly consistent copy of the original vector. LSQR minimizes `||B x - data||` using `B.apply` and `B.H.apply` internally." 
+ ] + }, + { + "cell_type": "code", + "id": "lsqr-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.783046Z", + "start_time": "2026-05-21T15:58:04.754755Z" + } + }, + "source": [ + "x_ls_true = ctx.asarray(np.cos(2.0 * grid))\n", + "noise = 0.01 * ctx.asarray(np.sin(np.linspace(0.0, 4.0 * np.pi, 2 * n)))\n", + "data = B.apply(x_ls_true) + noise\n", + "\n", + "lsqr_result = sc.lsqr(B, data, tol=1e-10, maxiter=256)\n", + "lsqr_result" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "LSQRResult(converged=True, num_iters=64, residual_norm=0.304159, normal_residual_norm=8.83974e-14, x=<array shape=(50000,), dtype=float64>)" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 56 + }, + { + "cell_type": "code", + "id": "lsqr-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.795722Z", + "start_time": "2026-05-21T15:58:04.783535Z" + } + }, + "source": [ + "normal_residual = ctx.ops.norm(B.H.apply(B.apply(lsqr_result.x) - data))\n", + "float(normal_residual)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "8.839736157139945e-14" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 57 + }, + { + "cell_type": "markdown", + "id": "power", + "metadata": {}, + "source": [ + "## Power iteration: estimate the dominant eigenpair\n", + "\n", + "`power_iteration` estimates the largest-magnitude eigenvalue of a square operator. For our diagonal example, the answer should be the largest entry in `diag`."
+ ] + }, + { + "cell_type": "code", + "id": "power-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.812249Z", + "start_time": "2026-05-21T15:58:04.797109Z" + } + }, + "source": [ + "x0 = ctx.asarray(np.ones(n))\n", + "power_result = sc.power_iteration(A, x0=x0, tol=1e-10, maxiter=256)\n", + "power_result" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "PowerIterationResult(converged=True, num_iters=64, eigenvalue=6, residual_norm=3.25687e-29, eigenvector=<array shape=(50000,), dtype=float64>)" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 58 + }, + { + "cell_type": "code", + "id": "power-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.826966Z", + "start_time": "2026-05-21T15:58:04.812958Z" + } + }, + "source": [ + "float(power_result.eigenvalue), float(diag[-1])" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(6.0, 6.0)" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 59 + }, + { + "cell_type": "markdown", + "id": "lanczos", + "metadata": {}, + "source": [ + "## Stochastic Lanczos: estimate the smallest eigenpair\n", + "\n", + "`stochastic_lanczos` builds a Krylov subspace from an initial domain element and returns a Ritz approximation to the smallest eigenpair. The returned eigenvector is a member of `A.domain`, not a raw matrix column from an internal representation."
+ ] + }, + { + "cell_type": "code", + "id": "lanczos-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:05.233074Z", + "start_time": "2026-05-21T15:58:04.828930Z" + } + }, + "source": [ + "initial = ctx.asarray(np.ones(n))\n", + "smallest_eigenvalue, smallest_eigenvector = sc.stochastic_lanczos(\n", + " A,\n", + " initial,\n", + " max_iter=64,\n", + " tol=1e-10,\n", + ")\n", + "\n", + "float(smallest_eigenvalue), smallest_eigenvector.shape" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(0.25, (50000,))" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 60 + }, + { + "cell_type": "code", + "id": "lanczos-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:05.249991Z", + "start_time": "2026-05-21T15:58:05.233738Z" + } + }, + "source": [ + "ritz_residual = ctx.ops.norm(A.apply(smallest_eigenvector) - smallest_eigenvalue * smallest_eigenvector)\n", + "float(ritz_residual), float(diag[0])" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(6.179265594934942e-15, 0.25)" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 61 + }, + { + "cell_type": "markdown", + "id": "wrap", + "metadata": {}, + "source": [ + "## What to take away\n", + "\n", + "- `MatrixFreeLinOp` lets you use iterative algorithms without building a matrix.\n", + "- `cg` is for square Hermitian positive-definite solves.\n", + "- `lsqr` is for general rectangular least-squares problems.\n", + "- `power_iteration` gives a dominant eigenpair estimate.\n", + "- `stochastic_lanczos` gives a smallest Ritz eigenpair estimate from a Krylov subspace.\n", + "- Solver result objects are compact to display, while full arrays remain available as attributes such as `cg_result.x`." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From d28aa09dafaf48a666a76d50f426b36dd39c1351 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Thu, 21 May 2026 21:10:11 -0300 Subject: [PATCH 2/3] Add batched linop lifting --- README.md | 16 ++ docs/source/api/linops.rst | 10 ++ docs/source/api/spaces.rst | 10 ++ docs/source/index.rst | 16 ++ docs/source/tutorials/linops.rst | 33 +++++ spacecore/__init__.py | 4 + spacecore/backend/_ops.py | 71 +++++++++ spacecore/backend/jax/_ops.py | 9 ++ spacecore/backend/torch/_ops.py | 14 ++ spacecore/linop/__init__.py | 2 + spacecore/linop/_algebra.py | 124 +++++++++++++++- spacecore/linop/_base.py | 58 ++++++++ spacecore/linop/_dense.py | 39 +++++ spacecore/linop/_diagonal.py | 94 ++++++++++++ spacecore/linop/_sparse.py | 38 +++++ spacecore/linop/product/_block.py | 22 +++ spacecore/linop/product/_from_single.py | 24 +++ spacecore/linop/product/_to_single.py | 24 +++ spacecore/space/__init__.py | 2 + spacecore/space/_base.py | 12 ++ spacecore/space/_batch.py | 185 ++++++++++++++++++++++++ tests/linops/test_batched_lifting.py | 150 +++++++++++++++++++ tests/spaces/test_batch_space.py | 38 +++++ 23 files changed, 992 insertions(+), 3 deletions(-) create mode 100644 spacecore/linop/_diagonal.py create mode 100644 spacecore/space/_batch.py create mode 100644 tests/linops/test_batched_lifting.py create mode 100644 tests/spaces/test_batch_space.py diff --git a/README.md b/README.md index 2459fb9..7566010 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,8 @@ A `Space` describes the structure and geometry of values: - `VectorSpace` for Euclidean vectors and tensors; - `HermitianSpace` for Hermitian or symmetric matrices; - `ProductSpace` for Cartesian products of spaces. 
+- `BatchSpace` for batched elements such as `X.batch((B,), (0,))`, + representing `B` independent copies of `X`. Algorithms should use space methods such as `zeros`, `add`, `scale`, `axpy`, `inner`, `norm`, `flatten`, and `unflatten` instead of hard-coding backend array @@ -168,6 +170,20 @@ A `LinOp` represents a linear operator between spaces: Operators expose `apply` and `rapply`, so algorithms can use a linear map and its adjoint without depending on the storage format. +For batched inputs, `vapply(xs)` and `rvapply(ys)` lift the operator over the +leading batch axis: + +```python +XB = X.batch(batch_shape=(B,), batch_axes=(0,)) +YB = Y.batch(batch_shape=(B,), batch_axes=(0,)) + +ys = A.vapply(xs, batch_space=XB) # xs in XB, ys in YB +xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB +``` + +The fallback uses backend `vmap`; dense, sparse, diagonal, identity, zero, +algebraic, and product-structured operators provide specialized batched paths. + ## Who should use this? SpaceCore is aimed at people writing optimization, inverse-problem, optimal diff --git a/docs/source/api/linops.rst b/docs/source/api/linops.rst index ef963b4..2c59373 100644 --- a/docs/source/api/linops.rst +++ b/docs/source/api/linops.rst @@ -10,6 +10,7 @@ actions. spacecore.linop.LinOp spacecore.linop.ProductLinOp spacecore.linop.DenseLinOp + spacecore.linop.DiagonalLinOp spacecore.linop.SparseLinOp spacecore.linop.BlockDiagonalLinOp spacecore.linop.StackedLinOp @@ -42,6 +43,15 @@ DenseLinOp :inherited-members: :show-inheritance: +DiagonalLinOp +------------- + +.. autoclass:: spacecore.linop.DiagonalLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + SparseLinOp ----------- diff --git a/docs/source/api/spaces.rst b/docs/source/api/spaces.rst index 01075fe..ba94a95 100644 --- a/docs/source/api/spaces.rst +++ b/docs/source/api/spaces.rst @@ -10,6 +10,7 @@ Spaces define element structure, geometry, flattening, and validation. 
spacecore.space.VectorSpace spacecore.space.HermitianSpace spacecore.space.ProductSpace + spacecore.space.BatchSpace spacecore.space.SpaceCheck spacecore.space.SpaceValidationError @@ -49,6 +50,15 @@ ProductSpace :inherited-members: :show-inheritance: +BatchSpace +---------- + +.. autoclass:: spacecore.space.BatchSpace + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + Validation ---------- diff --git a/docs/source/index.rst b/docs/source/index.rst index 709aae8..01185b4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -157,6 +157,8 @@ A ``Space`` describes the structure and geometry of values: * ``VectorSpace`` for Euclidean vectors and tensors; * ``HermitianSpace`` for Hermitian or symmetric matrices; * ``ProductSpace`` for Cartesian products of spaces. +* ``BatchSpace`` for batched elements such as ``X.batch((B,), (0,))``, + representing ``B`` independent copies of ``X``. Algorithms should use space methods such as ``zeros``, ``add``, ``scale``, ``axpy``, ``inner``, ``norm``, ``flatten``, and ``unflatten`` instead of @@ -176,6 +178,20 @@ A ``LinOp`` represents a linear operator between spaces: Operators expose ``apply`` and ``rapply``, so algorithms can use a linear map and its adjoint without depending on the storage format. +For batched inputs, ``vapply(xs)`` and ``rvapply(ys)`` lift the operator over +the leading batch axis: + +.. code-block:: python + + XB = X.batch(batch_shape=(B,), batch_axes=(0,)) + YB = Y.batch(batch_shape=(B,), batch_axes=(0,)) + + ys = A.vapply(xs, batch_space=XB) # xs in XB, ys in YB + xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB + +The fallback uses backend ``vmap``; dense, sparse, diagonal, identity, zero, +algebraic, and product-structured operators provide specialized batched paths. + Who should use this? 
-------------------- diff --git a/docs/source/tutorials/linops.rst b/docs/source/tutorials/linops.rst index 72d73c4..43fa474 100644 --- a/docs/source/tutorials/linops.rst +++ b/docs/source/tutorials/linops.rst @@ -7,6 +7,7 @@ the abstraction for linear maps between spaces. Current implemented operator types are: * ``DenseLinOp`` +* ``DiagonalLinOp`` * ``SparseLinOp`` * ``BlockDiagonalLinOp`` * ``StackedLinOp`` @@ -60,6 +61,38 @@ operator :math:`A^* : Y \to X`, satisfying \langle Ax, y\rangle_Y = \langle x, A^*y\rangle_X. +Batched lifting +--------------- + +For a batch of elements, use ``Space.batch`` to describe the batched space and +``vapply`` or ``rvapply`` to lift the operator: + +.. math:: + + A : X \to Y, + \qquad + A^{(B)} : X^B \to Y^B. + +.. code-block:: python + + B = 8 + XB = X.batch(batch_shape=(B,), batch_axes=(0,)) + YB = Y.batch(batch_shape=(B,), batch_axes=(0,)) + + xs = ctx.asarray(np.ones((B,) + X.shape)) + ys = op.vapply(xs, batch_space=XB) + xs_back = op.rvapply(ys, batch_space=YB) + +This is equivalent to stacking scalar applications: + +.. code-block:: python + + ys_ref = ctx.ops.stack(tuple(op.apply(x) for x in xs), axis=0) + +The base fallback uses backend ``vmap``. Structured operators override this +path when they can use matrix multiplication, sparse multi-vector products, +broadcasting, or componentwise product-space batching. 
+ DenseLinOp ---------- diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 9e6384f..4925f7a 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -13,6 +13,7 @@ from .linop import ( BlockDiagonalLinOp, ComposedLinOp, + DiagonalLinOp, DenseLinOp, IdentityLinOp, LinOp, @@ -37,6 +38,7 @@ stochastic_lanczos, ) from .space import ( + BatchSpace, BackendCheck, DTypeCheck, HermitianCheck, @@ -72,6 +74,7 @@ "LinOp", "ComposedLinOp", + "DiagonalLinOp", "DenseLinOp", "IdentityLinOp", "MatrixFreeLinOp", @@ -100,6 +103,7 @@ "ProductComponentCheck", "ProductStructureCheck", "ShapeCheck", + "BatchSpace", "VectorSpace", "HermitianSpace", "ProductSpace", diff --git a/spacecore/backend/_ops.py b/spacecore/backend/_ops.py index 09e2f89..fd8eb31 100644 --- a/spacecore/backend/_ops.py +++ b/spacecore/backend/_ops.py @@ -399,6 +399,77 @@ def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray: """Stack arrays along a new axis (delegates to xp.stack).""" return self.xp.stack(tuple(arrays), axis=axis) + def vmap( + self, + fn: Callable, + in_axes: int | Sequence[int | None] | None = 0, + out_axes: int | Sequence[int | None] | None = 0, + ) -> Callable: + """Vectorize ``fn`` over array axes using a Python-loop fallback.""" + + def axis_for_arg(i: int) -> int | Sequence[int | None] | None: + if isinstance(in_axes, tuple) or isinstance(in_axes, list): + return in_axes[i] + return in_axes + + def normalize_axis(axis: int, ndim: int) -> int: + return axis + ndim if axis < 0 else axis + + def tree_size(x: Any, axis: Any) -> int | None: + if axis is None: + return None + if isinstance(x, tuple): + axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(x) + for xi, ai in zip(x, axes): + size = tree_size(xi, ai) + if size is not None: + return size + return None + shape = tuple(getattr(x, "shape", ())) + axis = normalize_axis(int(axis), len(shape)) + return int(shape[axis]) + + def tree_take(x: Any, axis: Any, i: int) -> Any: + if axis is 
None: + return x + if isinstance(x, tuple): + axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(x) + return tuple(tree_take(xi, ai, i) for xi, ai in zip(x, axes)) + shape = tuple(getattr(x, "shape", ())) + axis = normalize_axis(int(axis), len(shape)) + index = [slice(None)] * len(shape) + index[axis] = i + return x[tuple(index)] + + def tree_stack(xs: Sequence[Any], axis: Any) -> Any: + first = xs[0] + if isinstance(first, tuple): + axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(first) + return tuple( + tree_stack(tuple(x[i] for x in xs), ai) + for i, ai in enumerate(axes) + ) + if axis is None: + return first + return self.stack(xs, axis=int(axis)) + + def mapped(*args: Any) -> Any: + axes = tuple(axis_for_arg(i) for i in range(len(args))) + size = None + for arg, axis in zip(args, axes): + size = tree_size(arg, axis) + if size is not None: + break + if size is None: + return fn(*args) + outputs = tuple( + fn(*(tree_take(arg, axis, i) for arg, axis in zip(args, axes))) + for i in range(size) + ) + return tree_stack(outputs, out_axes) + + return mapped + def conj(self, x: DenseArray) -> DenseArray: """Complex conjugate of x (delegates to xp.conj).""" return self.xp.conj(x) diff --git a/spacecore/backend/jax/_ops.py b/spacecore/backend/jax/_ops.py index 04b2eab..d13f995 100644 --- a/spacecore/backend/jax/_ops.py +++ b/spacecore/backend/jax/_ops.py @@ -231,6 +231,15 @@ def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: """ return a @ b + def vmap( + self, + fn: Callable, + in_axes: int | Sequence[int | None] | None = 0, + out_axes: int | Sequence[int | None] | None = 0, + ) -> Callable: + """Vectorize a function using ``jax.vmap``.""" + return self.jax.vmap(fn, in_axes=in_axes, out_axes=out_axes) + def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, return_sign: bool = False, where: DenseArray | None = None) -> DenseArray | 
Tuple[DenseArray, DenseArray]: """ diff --git a/spacecore/backend/torch/_ops.py b/spacecore/backend/torch/_ops.py index f1413d7..dabf42d 100644 --- a/spacecore/backend/torch/_ops.py +++ b/spacecore/backend/torch/_ops.py @@ -413,6 +413,20 @@ def sparse_matmul( return self.torch.sparse.mm(a, b[:, None], **kwargs)[:, 0] return self.torch.sparse.mm(a, b, **kwargs) + def vmap( + self, + fn: Callable, + in_axes: int | Sequence[int | None] | None = 0, + out_axes: int | Sequence[int | None] | None = 0, + ) -> Callable: + """Vectorize a function using PyTorch's native vmap when available.""" + vmap = getattr(self.torch, "vmap", None) + if vmap is None and hasattr(self.torch, "func"): + vmap = getattr(self.torch.func, "vmap", None) + if vmap is None: + return super().vmap(fn, in_axes=in_axes, out_axes=out_axes) + return vmap(fn, in_dims=in_axes, out_dims=out_axes) + def eigh( self, x: DenseArray, diff --git a/spacecore/linop/__init__.py b/spacecore/linop/__init__.py index f810bb1..7159df4 100644 --- a/spacecore/linop/__init__.py +++ b/spacecore/linop/__init__.py @@ -11,12 +11,14 @@ make_sum, ) from ._dense import DenseLinOp +from ._diagonal import DiagonalLinOp from ._sparse import SparseLinOp from .product import ProductLinOp, StackedLinOp, SumToSingleLinOp, BlockDiagonalLinOp __all__ = [ "LinOp", "ComposedLinOp", + "DiagonalLinOp", "DenseLinOp", "IdentityLinOp", "MatrixFreeLinOp", diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index cebd8b7..3029fc3 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -198,6 +198,14 @@ def rapply(self, y: Any) -> Any: """Return ``conj(scalar) * op.rapply(y)``.""" return _conjugate_scalar(self.scalar) * self.op.rapply(y) + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``scalar * op.vapply(xs)``.""" + return self.scalar * self.op.vapply(xs, batch_space) + + def rvapply(self, ys: Any, batch_space=None) -> Any: + """Return ``conj(scalar) * op.rvapply(ys)``.""" + return 
def vapply(self, xs: Any, batch_space=None) -> Any:
    """Apply every summand to the batch ``xs`` and add up the results.

    The batched domain is resolved once and handed to each summand; the
    partial results are accumulated with the batched codomain's ``add``.
    """
    batched_domain = self._input_batch_space(self.domain, xs, batch_space)
    batched_codomain = self._output_batch_space(self.codomain, batched_domain)
    summands = self.ops_tuple
    total = summands[0].vapply(xs, batched_domain)
    for term in summands[1:]:
        contribution = term.vapply(xs, batched_domain)
        total = batched_codomain.add(total, contribution)
    return total
def vapply(self, xs: Any, batch_space=None) -> Any:
    """Map a batch of domain elements to the batched codomain zero.

    ``xs`` is only inspected for validation (when checks are enabled); the
    result is the ``zeros()`` element of the codomain batched like the input.
    """
    batched_domain = self._input_batch_space(self.domain, xs, batch_space)
    if self._enable_checks:
        batched_domain._check_member(xs)
    batched_codomain = self._output_batch_space(self.codomain, batched_domain)
    return batched_codomain.zeros()
def vapply(self, xs: Any, batch_space=None) -> Any:
    """Run the user-supplied batched callback, or defer to the parent class.

    Without a ``vapply_fn`` the parent ``vapply`` handles the call. With one,
    the input is validated against the batched domain and the callback's
    result against the batched codomain (both only when checks are enabled).
    """
    batched_fn = self.vapply_fn
    if batched_fn is None:
        return super().vapply(xs, batch_space)

    batched_domain = self._input_batch_space(self.domain, xs, batch_space)
    if self._enable_checks:
        batched_domain._check_member(xs)

    result = batched_fn(xs)

    if self._enable_checks:
        batched_codomain = self._output_batch_space(self.codomain, batched_domain)
        batched_codomain._check_member(result)
    return result
other.codomain and self.apply_fn is other.apply_fn and self.rapply_fn is other.rapply_fn + and self.vapply_fn is other.vapply_fn + and self.rvapply_fn is other.rvapply_fn ) return False def tree_flatten(self): children = () - aux = (self.apply_fn, self.rapply_fn, self.domain, self.codomain, self.ctx) + aux = ( + self.apply_fn, + self.rapply_fn, + self.domain, + self.codomain, + self.ctx, + self.vapply_fn, + self.rvapply_fn, + ) return children, aux @classmethod def tree_unflatten(cls, aux, children): - apply_fn, rapply_fn, domain, codomain, ctx = aux - return cls(apply_fn, rapply_fn, domain, codomain, ctx) + apply_fn, rapply_fn, domain, codomain, ctx, vapply_fn, rvapply_fn = aux + return cls(apply_fn, rapply_fn, domain, codomain, ctx, vapply_fn, rvapply_fn) def _convert(self, new_ctx: Context) -> MatrixFreeLinOp: return MatrixFreeLinOp( @@ -538,6 +646,8 @@ def _convert(self, new_ctx: Context) -> MatrixFreeLinOp: self.domain.convert(new_ctx), self.codomain.convert(new_ctx), new_ctx, + self.vapply_fn, + self.rvapply_fn, ) @@ -567,6 +677,14 @@ def rapply(self, x: Any) -> Any: """Return ``op.apply(x)``.""" return self.op.apply(x) + def vapply(self, ys: Any, batch_space=None) -> Any: + """Return ``op.rvapply(ys)`` over a batch.""" + return self.op.rvapply(ys, batch_space) + + def rvapply(self, xs: Any, batch_space=None) -> Any: + """Return ``op.vapply(xs)`` over a batch.""" + return self.op.vapply(xs, batch_space) + @property def H(self) -> LinOp[Domain, Codomain]: """Original operator viewed as the adjoint of this adjoint view.""" diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index af9c05f..d5aa793 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -87,6 +87,64 @@ def adjoint_apply(self, y: Any) -> Any: """Apply the adjoint of this linear operator to ``y``.""" return self.rapply(y) + def vapply(self, xs: Any, batch_space: Space | None = None) -> Any: + """Apply this operator independently over a batch of domain elements.""" + 
def _fallback_vapply(self, xs: Any, batch_space: "Space | None" = None) -> Any:
    """Apply ``self.apply`` to every element of a batch via ``ops.vmap``.

    One ``vmap`` is nested per batch axis of the (given or inferred) batched
    domain, so non-leading batch axes, several batch axes and an empty batch
    shape are all handled. If the batch space exposes no ``batch_axes`` the
    batch is assumed to live on axis 0.

    Args:
        xs: Batched domain element.
        batch_space: Optional batched domain; inferred from ``xs`` if omitted.

    Returns:
        The batched codomain element, with batch axes at the same positions.
    """
    in_space = self._input_batch_space(self.domain, xs, batch_space)
    if self._enable_checks:
        in_space._check_member(xs)
    batched_apply = self.apply
    # Wrap in ascending axis order: the outermost wrapper then strips the
    # largest axis first, leaving the positions of the smaller axes intact,
    # and results are re-stacked at the same positions on the way out.
    for axis in sorted(getattr(in_space, "batch_axes", (0,))):
        batched_apply = self.ops.vmap(batched_apply, in_axes=axis, out_axes=axis)
    ys = batched_apply(xs)
    if self._enable_checks:
        self._output_batch_space(self.codomain, in_space)._check_member(ys)
    return ys

def _fallback_rvapply(self, ys: Any, batch_space: "Space | None" = None) -> Any:
    """Apply ``self.rapply`` to every element of a batch via ``ops.vmap``.

    Mirrors ``_fallback_vapply`` with the roles of domain and codomain
    swapped: one ``vmap`` is nested per batch axis of the batched codomain.

    Args:
        ys: Batched codomain element.
        batch_space: Optional batched codomain; inferred from ``ys`` if omitted.

    Returns:
        The batched domain element, with batch axes at the same positions.
    """
    in_space = self._input_batch_space(self.codomain, ys, batch_space)
    if self._enable_checks:
        in_space._check_member(ys)
    batched_rapply = self.rapply
    # Same nesting order as in _fallback_vapply (largest axis outermost).
    for axis in sorted(getattr(in_space, "batch_axes", (0,))):
        batched_rapply = self.ops.vmap(batched_rapply, in_axes=axis, out_axes=axis)
    xs = batched_rapply(ys)
    if self._enable_checks:
        self._output_batch_space(self.domain, in_space)._check_member(xs)
    return xs
@jax_pytree_class
class DiagonalLinOp(LinOp[VectorSpace, VectorSpace]):
    """Coordinatewise diagonal linear operator on a vector space.

    ``apply`` multiplies its argument elementwise by ``diagonal``; ``rapply``
    multiplies by the conjugated diagonal when the diagonal's dtype is
    complex and by the diagonal itself otherwise. Domain and codomain are
    the same space.
    """

    def __init__(
        self,
        diagonal: DenseArray,
        space: VectorSpace | None = None,
        ctx: Context | str | None = None,
    ) -> None:
        """Build the operator from its diagonal entries.

        Args:
            diagonal: Dense array of diagonal entries, shaped like the space.
            space: Space the operator acts on; when omitted a ``VectorSpace``
                with the diagonal's shape is created.
            ctx: Context, resolved together with ``space`` by ``ctx_manager``.

        Raises:
            TypeError: If ``diagonal.shape`` differs from the space's shape.
        """
        ctx = ctx_manager.resolve_context_priority(ctx, space)
        ctx.assert_dense(diagonal)
        if space is None:
            space = VectorSpace(tuple(diagonal.shape), ctx)
        super().__init__(space, space, ctx)
        expected = tuple(self.domain.shape)
        if tuple(diagonal.shape) != expected:
            raise TypeError(f"Expected diagonal.shape == space.shape == {expected}, got {diagonal.shape}")
        self.diagonal = diagonal
        self._size = prod(self.domain.shape)
        self._is_flat = tuple(self.domain.shape) == (self._size,)
        self._diag_flat = diagonal if self._is_flat else diagonal.reshape((self._size,))
        dtype = self.ops.get_dtype(diagonal)
        self._is_complex = getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex")
        # The adjoint of a diagonal map is multiplication by the conjugate.
        self._diag_adjoint = self.ops.conj(diagonal) if self._is_complex else diagonal
        self._diag_adjoint_flat = (
            self._diag_adjoint if self._is_flat else self._diag_adjoint.reshape((self._size,))
        )

    @property
    def A(self) -> DenseArray:
        """Dense tensor of this operator (same value as ``to_dense()``)."""
        return self.to_dense()

    def _batched_diagonal(self, diag: DenseArray, batch_space: Any) -> DenseArray:
        """Return ``diag`` shaped so that it broadcasts against a batched element.

        With leading batch axes ordinary trailing-axis broadcasting already
        lines the diagonal up with the base axes, so ``diag`` is returned
        unchanged. Otherwise singleton dimensions are inserted at the batch
        axis positions of ``batch_space.shape``.
        """
        batch_axes = tuple(getattr(batch_space, "batch_axes", ()))
        if batch_axes == tuple(range(len(batch_axes))):
            return diag
        shape = tuple(
            1 if position in batch_axes else dim
            for position, dim in enumerate(batch_space.shape)
        )
        return self.ops.reshape(diag, shape)

    def apply(self, x: DenseArray) -> DenseArray:
        """Return ``diagonal * x`` (elementwise)."""
        if self._enable_checks:
            self.domain._check_member(x)
        return self.diagonal * x

    def rapply(self, y: DenseArray) -> DenseArray:
        """Return the adjoint diagonal times ``y`` (elementwise)."""
        if self._enable_checks:
            self.codomain._check_member(y)
        return self._diag_adjoint * y

    def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray:
        """Apply the operator to a batch ``xs`` by broadcasting the diagonal.

        The diagonal is aligned with the base axes of the batch space, so
        batch axes may sit at any position, not only in front.
        """
        in_space = self._input_batch_space(self.domain, xs, batch_space)
        if self._enable_checks:
            in_space._check_member(xs)
        ys = self._batched_diagonal(self.diagonal, in_space) * xs
        if self._enable_checks:
            self._output_batch_space(self.codomain, in_space)._check_member(ys)
        return ys

    def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray:
        """Apply the adjoint to a batch ``ys`` by broadcasting the adjoint diagonal."""
        in_space = self._input_batch_space(self.codomain, ys, batch_space)
        if self._enable_checks:
            in_space._check_member(ys)
        xs = self._batched_diagonal(self._diag_adjoint, in_space) * ys
        if self._enable_checks:
            self._output_batch_space(self.domain, in_space)._check_member(xs)
        return xs

    def to_dense(self) -> DenseArray:
        """Materialize the operator as a tensor of shape ``codomain.shape + domain.shape``."""
        matrix = self.ops.diag(self._diag_flat)
        return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape))

    def __eq__(self, other: Any) -> bool:
        if type(other) is type(self):
            return self.domain == other.domain and self.ops.allclose(self.diagonal, other.diagonal)
        return False

    def tree_flatten(self):
        """Split into pytree children (the diagonal) and static aux data."""
        children = (self.diagonal,)
        aux = (self.domain, self.ctx)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """Rebuild the operator from ``tree_flatten`` output."""
        domain, ctx = aux
        return cls(children[0], domain, ctx)

    def _convert(self, new_ctx: Context) -> DiagonalLinOp:
        return DiagonalLinOp(new_ctx.asarray(self.diagonal), self.domain.convert(new_ctx), new_ctx)
def vapply(self, x: Any, batch_space=None) -> Any:
    """Apply each diagonal block to its own component of the batched tuple ``x``.

    Every block receives its domain batched with the batch shape and axes
    of the overall input batch space; the results are returned as a tuple.
    """
    batched_domain = self._input_batch_space(self.domain, x, batch_space)
    if self._enable_checks:
        batched_domain._check_member(x)
    shape = batched_domain.batch_shape
    axes = batched_domain.batch_axes
    results = []
    for block_op, component in zip(self.parts, x):
        component_space = block_op.domain.batch(shape, axes)
        results.append(block_op.vapply(component, component_space))
    return tuple(results)
def rvapply(self, y: Any, batch_space=None) -> Any:
    """Sum the batched adjoint contributions of all parts.

    Each part's ``rvapply`` is applied to its component of the tuple ``y``
    (with that part's codomain batched like the input), and the results are
    accumulated with the batched domain's ``add``. Returns ``None`` when
    there are no parts.
    """
    batched_codomain = self._input_batch_space(self.codomain, y, batch_space)
    if self._enable_checks:
        batched_codomain._check_member(y)
    shape = batched_codomain.batch_shape
    axes = batched_codomain.batch_axes
    batched_domain = self.domain.batch(shape, axes)
    total = None
    for part, component in zip(self.parts, y):
        contribution = part.rvapply(component, part.codomain.batch(shape, axes))
        if total is None:
            total = contribution
        else:
            total = batched_domain.add(total, contribution)
    return total
batch_space=None) -> Any: + in_space = self._input_batch_space(self.codomain, y, batch_space) + if self._enable_checks: + in_space._check_member(y) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + return tuple( + op.rvapply(y, op.codomain.batch(batch_shape, batch_axes)) + for op in self.parts + ) + @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> SumToSingleLinOp: if not parts: diff --git a/spacecore/space/__init__.py b/spacecore/space/__init__.py index fbb7ada..4883640 100644 --- a/spacecore/space/__init__.py +++ b/spacecore/space/__init__.py @@ -10,11 +10,13 @@ SquareMatrixCheck, ) from ._base import Space +from ._batch import BatchSpace from ._herm import HermitianSpace from ._vector import VectorSpace from ._product import ProductSpace __all__ = [ + "BatchSpace", "BackendCheck", "DTypeCheck", "HermitianCheck", diff --git a/spacecore/space/_base.py b/spacecore/space/_base.py index d206bb4..53318d1 100644 --- a/spacecore/space/_base.py +++ b/spacecore/space/_base.py @@ -106,6 +106,18 @@ def unflatten(self, v: DenseArray) -> Any: """Inverse of flatten; returns an element in the requested representation.""" raise NotImplementedError + def batch( + self, + batch_shape: Tuple[int, ...], + batch_axes: Tuple[int, ...] 
def _batched_shape(
    base_shape: tuple[int, ...],
    batch_shape: tuple[int, ...],
    batch_axes: tuple[int, ...],
) -> tuple[int, ...]:
    """Interleave batch dimensions into a base shape.

    Each entry of ``batch_shape`` is placed at the matching position in
    ``batch_axes`` (negative axes count from the end of the batched shape);
    the dimensions of ``base_shape`` fill the remaining positions in order.

    Raises:
        ValueError: If the two batch tuples differ in length, if an axis is
            repeated, or if an axis lies outside the batched rank.
    """
    rank = len(base_shape) + len(batch_shape)
    resolved = tuple(ax + rank if ax < 0 else ax for ax in batch_axes)
    if len(batch_shape) != len(resolved):
        raise ValueError("batch_shape and batch_axes must have the same length.")
    if len(set(resolved)) != len(resolved):
        raise ValueError("batch_axes must be unique.")
    if not all(0 <= ax < rank for ax in resolved):
        raise ValueError(
            f"batch_axes must be valid axes for batched ndim {rank}, got {batch_axes}."
        )

    # Positions claimed by batch axes; every other position takes the next
    # unused base dimension.
    batch_dim_at = {ax: int(dim) for ax, dim in zip(resolved, batch_shape)}
    remaining_base = iter(base_shape)
    merged = []
    for position in range(rank):
        if position in batch_dim_at:
            merged.append(batch_dim_at[position])
        else:
            merged.append(int(next(remaining_base)))
    return tuple(merged)
It deliberately wraps the original space rather + than folding batch dimensions into the base ``Space`` instance. + """ + + def __init__( + self, + base: Space, + batch_shape: Tuple[int, ...], + batch_axes: Tuple[int, ...], + ctx: Context | str | None = None, + ) -> None: + ctx = base.ctx if ctx is None else ctx + super().__init__( + _batched_shape(tuple(base.shape), tuple(batch_shape), tuple(batch_axes)), + ctx, + ) + self.base = base.convert(self.ctx) + self.batch_shape = tuple(int(dim) for dim in batch_shape) + total_ndim = len(self.base.shape) + len(self.batch_shape) + self.batch_axes = tuple(axis + total_ndim if axis < 0 else axis for axis in batch_axes) + self._batch_size = prod(self.batch_shape) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, BatchSpace): + return ( + self.ctx == other.ctx + and self.base == other.base + and self.batch_shape == other.batch_shape + and self.batch_axes == other.batch_axes + ) + return False + + @property + def _is_product(self) -> bool: + return isinstance(self.base, ProductSpace) + + def _component_spaces(self) -> tuple[BatchSpace, ...]: + if not isinstance(self.base, ProductSpace): + raise TypeError("BatchSpace component spaces are available only for ProductSpace bases.") + return tuple(sp.batch(self.batch_shape, self.batch_axes) for sp in self.base.spaces) + + def _check_member(self, x: Any) -> None: + if isinstance(self.base, ProductSpace): + if not isinstance(x, tuple) or len(x) != self.base.arity: + raise TypeError( + f"BatchSpace over ProductSpace expects tuple length {self.base.arity}." 
+ ) + for space, component in zip(self._component_spaces(), x): + space.check_member(component) + return + + BackendCheck()(self, x) + ShapeCheck()(self, x) + DTypeCheck()(self, x) + for check in self.base.member_checks(): + if isinstance(check, (BackendCheck, ShapeCheck, DTypeCheck)): + continue + check(self, x) + + def zeros(self) -> Any: + if isinstance(self.base, ProductSpace): + return tuple(space.zeros() for space in self._component_spaces()) + return self.ops.zeros(self.shape, dtype=self.dtype) + + def add(self, x: Any, y: Any) -> Any: + if self._enable_checks: + self._check_member(x) + self._check_member(y) + if isinstance(self.base, ProductSpace): + return tuple(space.add(xi, yi) for space, xi, yi in zip(self._component_spaces(), x, y)) + return x + y + + def scale(self, a: Any, x: Any) -> Any: + if self._enable_checks: + self._check_member(x) + if isinstance(self.base, ProductSpace): + return tuple(space.scale(a, xi) for space, xi in zip(self._component_spaces(), x)) + return a * x + + def inner(self, x: Any, y: Any) -> Any: + if self._enable_checks: + self._check_member(x) + self._check_member(y) + if isinstance(self.base, ProductSpace): + acc = None + for space, xi, yi in zip(self._component_spaces(), x, y): + v = space.inner(xi, yi) + acc = v if acc is None else acc + v + return acc + return self.ops.vdot(x, y) + + def eigh(self, x: Any, k: int = None) -> Any: + raise TypeError(f"{type(self).__name__}.eigh is not defined for batched spaces.") + + def flatten(self, x: Any) -> DenseArray: + if self._enable_checks: + self._check_member(x) + if isinstance(self.base, ProductSpace): + parts = tuple(space.flatten(xi) for space, xi in zip(self._component_spaces(), x)) + return parts[0] if len(parts) == 1 else self.ops.concatenate(parts, axis=0) + return self.ops.reshape(x, (-1,)) + + def unflatten(self, v: DenseArray) -> Any: + vv = self.ctx.assert_dense(v) if self._enable_checks else v + if isinstance(self.base, ProductSpace): + if ( + tuple(getattr(vv, 
"shape", ())) == tuple(self.shape) + and self.batch_axes == tuple(range(len(self.batch_shape))) + ): + xs = [] + offset = 0 + for component, space in zip(self.base.spaces, self._component_spaces()): + size = prod(component.shape) + flat_component = vv[(..., slice(offset, offset + size))] + xs.append(space.unflatten(flat_component)) + offset += size + return tuple(xs) + xs = [] + offset = 0 + for space in self._component_spaces(): + size = prod(space.shape) + xs.append(space.unflatten(vv[offset : offset + size])) + offset += size + return tuple(xs) + return self.ops.reshape(vv, self.shape) + + def apply(self, x: Any, f: Callable) -> Any: + if self._enable_checks: + self._check_member(x) + if isinstance(self.base, ProductSpace): + return tuple(space.apply(xi, f) for space, xi in zip(self._component_spaces(), x)) + try: + y = f(x) + except Exception: + y = self.ops.vmap(lambda xi: self.base.apply(xi, f))(x) + if self._enable_checks: + self._check_member(y) + return y + + def _convert(self, new_ctx: Context) -> BatchSpace: + return BatchSpace(self.base.convert(new_ctx), self.batch_shape, self.batch_axes, new_ctx) diff --git a/tests/linops/test_batched_lifting.py b/tests/linops/test_batched_lifting.py new file mode 100644 index 0000000..2b3b107 --- /dev/null +++ b/tests/linops/test_batched_lifting.py @@ -0,0 +1,150 @@ +import importlib + +import numpy as np +import scipy.sparse as sps + + +def _ctx(): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=np.float64) + + +def _stack_apply(ctx, op, xs): + return ctx.ops.stack(tuple(op.apply(x) for x in xs), axis=0) + + +def _stack_rapply(ctx, op, ys): + return ctx.ops.stack(tuple(op.rapply(y) for y in ys), axis=0) + + +def test_dense_linop_vapply_and_rvapply_match_stacked_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + matrix = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = 
sc.DenseLinOp(matrix, dom, cod, ctx) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0], [0.5, 2.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(op.vapply(xs), _stack_apply(ctx, op, xs)) + assert np.allclose(op.rvapply(ys), _stack_rapply(ctx, op, ys)) + + +def test_sparse_linop_vapply_and_rvapply_match_stacked_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 0.0], [0.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(ctx.assparse(sps.csr_matrix(dense)), dom, cod, ctx) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0], [0.5, 2.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(op.vapply(xs), _stack_apply(ctx, op, xs)) + assert np.allclose(op.rvapply(ys), _stack_rapply(ctx, op, ys)) + + +def test_diagonal_identity_zero_sum_composed_and_adjoint_batched_lifting(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + d1 = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + d2 = sc.DiagonalLinOp(ctx.asarray([-1.0, 0.5, 4.0]), space, ctx) + identity = sc.IdentityLinOp(space, ctx) + zero = sc.ZeroLinOp(space, space, ctx) + summed = d1 + d2 + zero + composed = d1 @ (d2 + identity) + adjoint = composed.H + xs = ctx.asarray([[1.0, 2.0, 3.0], [4.0, -1.0, 0.0]]) + + for op in (d1, identity, zero, summed, composed): + assert np.allclose(op.vapply(xs), _stack_apply(ctx, op, xs)) + assert np.allclose(op.rvapply(xs), _stack_rapply(ctx, op, xs)) + assert np.allclose(adjoint.vapply(xs), _stack_apply(ctx, adjoint, xs)) + assert np.allclose(adjoint.rvapply(xs), _stack_rapply(ctx, adjoint, xs)) + + +def test_matrix_free_vapply_uses_callback_when_supplied_and_fallback_when_absent(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + matrix = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + calls = 
{"vapply": 0, "rvapply": 0} + + def apply(x): + return ctx.asarray(matrix @ np.asarray(x)) + + def rapply(y): + return ctx.asarray(matrix.T @ np.asarray(y)) + + def vapply(xs): + calls["vapply"] += 1 + return ctx.asarray(np.asarray(xs) @ matrix.T) + + def rvapply(ys): + calls["rvapply"] += 1 + return ctx.asarray(np.asarray(ys) @ matrix) + + with_callbacks = sc.MatrixFreeLinOp(apply, rapply, dom, cod, ctx, vapply, rvapply) + fallback = sc.MatrixFreeLinOp(apply, rapply, dom, cod, ctx) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(with_callbacks.vapply(xs), _stack_apply(ctx, with_callbacks, xs)) + assert np.allclose(with_callbacks.rvapply(ys), _stack_rapply(ctx, with_callbacks, ys)) + assert calls == {"vapply": 1, "rvapply": 1} + assert np.allclose(fallback.vapply(xs), _stack_apply(ctx, fallback, xs)) + assert np.allclose(fallback.rvapply(ys), _stack_rapply(ctx, fallback, ys)) + + +def test_product_linops_batched_lifting_matches_stacked_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + x0 = sc.VectorSpace((2,), ctx) + x1 = sc.VectorSpace((3,), ctx) + y0 = sc.VectorSpace((3,), ctx) + y1 = sc.VectorSpace((2,), ctx) + a0 = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), x0, y0, ctx) + a1 = sc.DenseLinOp(ctx.asarray([[2.0, -1.0, 0.5], [0.0, 3.0, 4.0]]), x1, y1, ctx) + s1 = sc.DenseLinOp(ctx.asarray([[2.0, -1.0], [0.0, 3.0]]), x0, y1, ctx) + block = sc.BlockDiagonalLinOp.from_operators((a0, a1)) + stacked = sc.StackedLinOp.from_operators((a0, s1)) + sum_to_single = sc.SumToSingleLinOp.from_operators((a0.H, s1.H)) + + xb = (ctx.asarray([[1.0, 2.0], [3.0, 4.0]]), ctx.asarray([[5.0, 6.0, 7.0], [1.0, -1.0, 0.5]])) + yb = (ctx.asarray([[1.0, 2.0, 3.0], [0.0, -1.0, 4.0]]), ctx.asarray([[2.0, 1.0], [3.0, -2.0]])) + single_x = ctx.asarray([[1.0, 2.0], [3.0, 4.0]]) + single_y = ctx.asarray([[1.0, 2.0], [3.0, -2.0]]) + + block_v = block.vapply(xb) + block_expected = 
tuple(ctx.ops.stack(tuple(block.apply((xb[0][i], xb[1][i]))[j] for i in range(2))) for j in range(2)) + assert np.allclose(block_v[0], block_expected[0]) + assert np.allclose(block_v[1], block_expected[1]) + + block_rv = block.rvapply(yb) + block_r_expected = tuple(ctx.ops.stack(tuple(block.rapply((yb[0][i], yb[1][i]))[j] for i in range(2))) for j in range(2)) + assert np.allclose(block_rv[0], block_r_expected[0]) + assert np.allclose(block_rv[1], block_r_expected[1]) + + stacked_v = stacked.vapply(single_x) + stacked_expected = tuple(ctx.ops.stack(tuple(stacked.apply(single_x[i])[j] for i in range(2))) for j in range(2)) + assert np.allclose(stacked_v[0], stacked_expected[0]) + assert np.allclose(stacked_v[1], stacked_expected[1]) + + assert np.allclose( + stacked.rvapply(yb), + ctx.ops.stack(tuple(stacked.rapply((yb[0][i], yb[1][i])) for i in range(2))), + ) + assert np.allclose( + sum_to_single.vapply(yb), + ctx.ops.stack(tuple(sum_to_single.apply((yb[0][i], yb[1][i])) for i in range(2))), + ) + sum_rv = sum_to_single.rvapply(single_y) + sum_r_expected = tuple( + ctx.ops.stack(tuple(sum_to_single.rapply(single_y[i])[j] for i in range(2))) + for j in range(2) + ) + assert np.allclose(sum_rv[0], sum_r_expected[0]) + assert np.allclose(sum_rv[1], sum_r_expected[1]) diff --git a/tests/spaces/test_batch_space.py b/tests/spaces/test_batch_space.py new file mode 100644 index 0000000..3053806 --- /dev/null +++ b/tests/spaces/test_batch_space.py @@ -0,0 +1,38 @@ +import importlib + +import numpy as np + + +def test_vector_space_batch_wrapper_shape_and_membership(): + sc = importlib.import_module("spacecore") + ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + x = sc.VectorSpace((2, 3), ctx) + + xb = x.batch(batch_shape=(4,), batch_axes=(0,)) + + assert isinstance(xb, sc.BatchSpace) + assert xb.base == x + assert xb.batch_shape == (4,) + assert xb.batch_axes == (0,) + assert xb.shape == (4, 2, 3) + xb.check_member(ctx.ops.zeros((4, 2, 3), dtype=ctx.dtype)) + + +def 
test_product_space_batch_wrapper_validates_component_batches(): + sc = importlib.import_module("spacecore") + ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + x0 = sc.VectorSpace((2,), ctx) + x1 = sc.VectorSpace((3,), ctx) + product = sc.ProductSpace((x0, x1), ctx) + batched = product.batch((5,), (0,)) + + value = ( + ctx.ops.zeros((5, 2), dtype=ctx.dtype), + ctx.ops.zeros((5, 3), dtype=ctx.dtype), + ) + + assert batched.shape == (5, 5) + batched.check_member(value) + zeros = batched.zeros() + assert np.allclose(zeros[0], np.zeros((5, 2))) + assert np.allclose(zeros[1], np.zeros((5, 3))) From 82ac1b999bd570ff87b6e609d7b4b667c9e85e4d Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Thu, 21 May 2026 22:49:54 -0300 Subject: [PATCH 3/3] Optimize batched dense and sparse linops --- spacecore/linop/_dense.py | 85 ++++++++++++++++++++++------ spacecore/linop/_sparse.py | 85 ++++++++++++++++++++++------ tests/linops/test_batched_lifting.py | 25 ++++++++ 3 files changed, 163 insertions(+), 32 deletions(-) diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index 79ce2dd..244f566 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -56,6 +56,8 @@ def __init__(self, if not self._enable_checks: self.apply = self._apply_unchecked self.rapply = self._rapply_unchecked + self.vapply = self._vapply_unchecked + self.rvapply = self._rvapply_unchecked @property def A(self) -> DenseArray: @@ -99,6 +101,71 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + @staticmethod + def _batch_shape_from_input(value: DenseArray, base_ndim: int) -> tuple[int, ...]: + shape = tuple(value.shape) + return shape if base_ndim == 0 else shape[:-base_ndim] + + @staticmethod + def _is_leading_batch(batch_space: Any) -> bool: + if batch_space is None: + return True + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = 
tuple(getattr(batch_space, "batch_axes", ())) + return batch_axes == tuple(range(len(batch_shape))) + + @staticmethod + def _batch_shape_from_space(batch_space: Any) -> tuple[int, ...]: + return tuple(getattr(batch_space, "batch_shape")) + + def _vapply_unchecked_leading( + self, + xs: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + xs2 = xs.reshape((-1, self._dom_size)) + ys2 = xs2 @ self._A2T + if self._cod_vector_fast_path: + if self._cod_is_flat and tuple(ys2.shape[:-1]) == batch_shape: + return ys2 + return ys2.reshape(batch_shape + tuple(self.cod.shape)) + ys_flat = ys2.reshape(batch_shape + (self._cod_size,)) + return self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + + def _rvapply_unchecked_leading( + self, + ys: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + ys2 = ys.reshape((-1, self._cod_size)) + xs2 = ys2 @ self._A2H.T + if self._dom_vector_fast_path: + if self._dom_is_flat and tuple(xs2.shape[:-1]) == batch_shape: + return xs2 + return xs2.reshape(batch_shape + tuple(self.dom.shape)) + xs_flat = xs2.reshape(batch_shape + (self._dom_size,)) + return self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + + def _vapply_unchecked(self, xs: DenseArray, batch_space=None) -> DenseArray: + if not self._is_leading_batch(batch_space): + return self._fallback_vapply(xs, batch_space) + batch_shape = ( + self._batch_shape_from_input(xs, len(self.domain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._vapply_unchecked_leading(xs, batch_shape) + + def _rvapply_unchecked(self, ys: DenseArray, batch_space=None) -> DenseArray: + if not self._is_leading_batch(batch_space): + return self._fallback_rvapply(ys, batch_space) + batch_shape = ( + self._batch_shape_from_input(ys, len(self.codomain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._rvapply_unchecked_leading(ys, 
batch_shape) + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: in_space = self._input_batch_space(self.domain, xs, batch_space) if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): @@ -106,14 +173,7 @@ def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: if self._enable_checks: in_space._check_member(xs) batch_shape = tuple(in_space.batch_shape) - xs2 = self.ops.reshape(xs, (-1, self._dom_size)) - ys2 = self.ops.matmul(xs2, self._A2T) - y_flat_shape = batch_shape + (self._cod_size,) - ys_flat = self.ops.reshape(ys2, y_flat_shape) - if self._cod_vector_fast_path: - ys = self.ops.reshape(ys2, batch_shape + tuple(self.cod.shape)) - else: - ys = self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + ys = self._vapply_unchecked_leading(xs, batch_shape) if self._enable_checks: self._output_batch_space(self.codomain, in_space)._check_member(ys) return ys @@ -125,14 +185,7 @@ def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: if self._enable_checks: in_space._check_member(ys) batch_shape = tuple(in_space.batch_shape) - ys2 = self.ops.reshape(ys, (-1, self._cod_size)) - xs2 = self.ops.matmul(ys2, self._A2H.T) - x_flat_shape = batch_shape + (self._dom_size,) - xs_flat = self.ops.reshape(xs2, x_flat_shape) - if self._dom_vector_fast_path: - xs = self.ops.reshape(xs2, batch_shape + tuple(self.dom.shape)) - else: - xs = self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + xs = self._rvapply_unchecked_leading(ys, batch_shape) if self._enable_checks: self._output_batch_space(self.domain, in_space)._check_member(xs) return xs diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index fca611f..ab4f306 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -51,6 +51,8 @@ def __init__(self, if not self._enable_checks: self.apply = self._apply_unchecked self.rapply = self._rapply_unchecked + self.vapply = 
self._vapply_unchecked + self.rvapply = self._rvapply_unchecked @property def A(self) -> SparseArray: @@ -98,6 +100,71 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + @staticmethod + def _batch_shape_from_input(value: DenseArray, base_ndim: int) -> tuple[int, ...]: + shape = tuple(value.shape) + return shape if base_ndim == 0 else shape[:-base_ndim] + + @staticmethod + def _is_leading_batch(batch_space: Any) -> bool: + if batch_space is None: + return True + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + return batch_axes == tuple(range(len(batch_shape))) + + @staticmethod + def _batch_shape_from_space(batch_space: Any) -> tuple[int, ...]: + return tuple(getattr(batch_space, "batch_shape")) + + def _vapply_unchecked_leading( + self, + xs: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + xs2 = xs.reshape((-1, self._dom_size)) + ys2 = (self.A @ xs2.T).T + if self._cod_vector_fast_path: + if self._cod_is_flat and tuple(ys2.shape[:-1]) == batch_shape: + return ys2 + return ys2.reshape(batch_shape + tuple(self.cod.shape)) + ys_flat = ys2.reshape(batch_shape + (self._cod_size,)) + return self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + + def _rvapply_unchecked_leading( + self, + ys: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + ys2 = ys.reshape((-1, self._cod_size)) + xs2 = (self._AH @ ys2.T).T + if self._dom_vector_fast_path: + if self._dom_is_flat and tuple(xs2.shape[:-1]) == batch_shape: + return xs2 + return xs2.reshape(batch_shape + tuple(self.dom.shape)) + xs_flat = xs2.reshape(batch_shape + (self._dom_size,)) + return self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + + def _vapply_unchecked(self, xs: DenseArray, batch_space=None) -> DenseArray: + if not 
self._is_leading_batch(batch_space): + return self._fallback_vapply(xs, batch_space) + batch_shape = ( + self._batch_shape_from_input(xs, len(self.domain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._vapply_unchecked_leading(xs, batch_shape) + + def _rvapply_unchecked(self, ys: DenseArray, batch_space=None) -> DenseArray: + if not self._is_leading_batch(batch_space): + return self._fallback_rvapply(ys, batch_space) + batch_shape = ( + self._batch_shape_from_input(ys, len(self.codomain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._rvapply_unchecked_leading(ys, batch_shape) + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: in_space = self._input_batch_space(self.domain, xs, batch_space) if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): @@ -105,14 +172,7 @@ def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: if self._enable_checks: in_space._check_member(xs) batch_shape = tuple(in_space.batch_shape) - xs2 = self.ops.reshape(xs, (-1, self._dom_size)) - ys_t = self.ops.sparse_matmul(self.A, self.ops.transpose(xs2)) - ys2 = self.ops.transpose(ys_t) - ys_flat = self.ops.reshape(ys2, batch_shape + (self._cod_size,)) - if self._cod_vector_fast_path: - ys = self.ops.reshape(ys2, batch_shape + tuple(self.cod.shape)) - else: - ys = self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + ys = self._vapply_unchecked_leading(xs, batch_shape) if self._enable_checks: self._output_batch_space(self.codomain, in_space)._check_member(ys) return ys @@ -124,14 +184,7 @@ def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: if self._enable_checks: in_space._check_member(ys) batch_shape = tuple(in_space.batch_shape) - ys2 = self.ops.reshape(ys, (-1, self._cod_size)) - xs_t = self.ops.sparse_matmul(self._AH, self.ops.transpose(ys2)) - xs2 = self.ops.transpose(xs_t) - xs_flat 
= self.ops.reshape(xs2, batch_shape + (self._dom_size,)) - if self._dom_vector_fast_path: - xs = self.ops.reshape(xs2, batch_shape + tuple(self.dom.shape)) - else: - xs = self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + xs = self._rvapply_unchecked_leading(ys, batch_shape) if self._enable_checks: self._output_batch_space(self.domain, in_space)._check_member(xs) return xs diff --git a/tests/linops/test_batched_lifting.py b/tests/linops/test_batched_lifting.py index 2b3b107..efb8f62 100644 --- a/tests/linops/test_batched_lifting.py +++ b/tests/linops/test_batched_lifting.py @@ -9,6 +9,11 @@ def _ctx(): return sc.Context(sc.NumpyOps(), dtype=np.float64) +def _unchecked_ctx(): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False) + + def _stack_apply(ctx, op, xs): return ctx.ops.stack(tuple(op.apply(x) for x in xs), axis=0) @@ -45,6 +50,26 @@ def test_sparse_linop_vapply_and_rvapply_match_stacked_apply(): assert np.allclose(op.rvapply(ys), _stack_rapply(ctx, op, ys)) +def test_dense_and_sparse_batched_lifting_fast_paths_without_checks(): + sc = importlib.import_module("spacecore") + ctx = _unchecked_ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + matrix = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + sparse = ctx.assparse(sps.csr_matrix([[1.0, 0.0], [0.0, 4.0], [5.0, 6.0]])) + dense_op = sc.DenseLinOp(matrix, dom, cod, ctx) + sparse_op = sc.SparseLinOp(sparse, dom, cod, ctx) + batch_dom = dom.batch((3,), (0,)) + batch_cod = cod.batch((2,), (0,)) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0], [0.5, 2.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(dense_op.vapply(xs, batch_dom), np.asarray(xs) @ np.asarray(matrix).T) + assert np.allclose(dense_op.rvapply(ys, batch_cod), np.asarray(ys) @ np.asarray(matrix)) + assert np.allclose(sparse_op.vapply(xs, batch_dom), (sparse @ np.asarray(xs).T).T) + assert 
np.allclose(sparse_op.rvapply(ys, batch_cod), (sparse.T @ np.asarray(ys).T).T) + + def test_diagonal_identity_zero_sum_composed_and_adjoint_batched_lifting(): sc = importlib.import_module("spacecore") ctx = _ctx()