From 09efbcacbc57c3ffd94d980653b8581d4fd493c5 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 19 May 2026 15:53:45 -0300 Subject: [PATCH 01/44] Refactor backend ops to use Array API Compat - delegate common dense operations through AAC-compatible `xp` - use AAC namespaces for NumPy and Torch backends - keep Torch-specific semantics in `TorchOps` - make `eps` depend on backend default dtype - remove broad `TypeError` fallback dispatch - add delegation, dtype epsilon, and complex adjoint tests - document the Array API backend design --- docs/source/design/index.rst | 1 + docs/source/tutorials/backend_ops.rst | 8 + pyproject.toml | 1 + spacecore/backend/_ops.py | 1585 +++++----------- spacecore/backend/jax/_ops.py | 1655 +---------------- spacecore/backend/numpy/_ops.py | 1697 +----------------- spacecore/backend/torch/_ops.py | 1583 ++-------------- tests/backend/test_backend_ops_delegation.py | 183 ++ tests/backend/test_backend_registry.py | 58 +- 9 files changed, 892 insertions(+), 5879 deletions(-) create mode 100644 tests/backend/test_backend_ops_delegation.py diff --git a/docs/source/design/index.rst b/docs/source/design/index.rst index 1d9df1b..390b4ad 100644 --- a/docs/source/design/index.rst +++ b/docs/source/design/index.rst @@ -10,3 +10,4 @@ users should reason about them. conversion_policy dtype_policy checking_policy + backend_ops_array_api diff --git a/docs/source/tutorials/backend_ops.rst b/docs/source/tutorials/backend_ops.rst index c90a88e..47b58ce 100644 --- a/docs/source/tutorials/backend_ops.rst +++ b/docs/source/tutorials/backend_ops.rst @@ -52,6 +52,14 @@ What BackendOps signifies internals. It mostly wraps NumPy-like methods, while normalizing the minimal signatures that SpaceCore relies on. +Common dense-array methods are implemented once in ``BackendOps`` by delegating +to an Array API compatible ``xp`` namespace. NumPy and PyTorch use +``array-api-compat`` wrappers, while JAX uses ``jax.numpy``. Concrete backend +classes keep behavior that is genuinely backend-specific, such as dtype +sanitation, sparse conversion, indexed updates, device/autograd controls, and +control-flow primitives. ``ops.xp`` is available as an escape hatch, but +portable SpaceCore code should prefer explicit ``ops`` methods. + For example, NumPy and JAX expose different optional arguments for matrix multiplication, but SpaceCore's portable interface only needs the common core: diff --git a/pyproject.toml b/pyproject.toml index 198b1b6..2796cbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ authors = [ { name = "Pavlo Pelikh" } ] dependencies = [ + "array-api-compat>=1.14.0", "numpy>=2.0.0", "scipy>=1.17", ] diff --git a/spacecore/backend/_ops.py b/spacecore/backend/_ops.py index 5bca111..b2f3ad1 100644 --- a/spacecore/backend/_ops.py +++ b/spacecore/backend/_ops.py @@ -1,11 +1,28 @@ from __future__ import annotations from abc import ABC, abstractmethod +import importlib from typing import Any, Sequence, Tuple, Callable, Optional, Type, ClassVar from ..types import DenseArray, SparseArray, DType, ArrayLike, Index, T, X, Y, R, Carry +class LazyNamespace: + def __init__(self, module_name: str) -> None: + self.__name__ = module_name + self.__isabstractmethod__ = False + self._module_name = module_name + self._module: Any | None = None + + def _load(self) -> Any: + if self._module is None: + self._module = importlib.import_module(self._module_name) + return self._module + + def __getattr__(self, name: str) -> Any: + return getattr(self._load(), name) + + class BackendOps(ABC): """ Backend-agnostic numerical ops interface (portable core). @@ -18,6 +35,7 @@ class BackendOps(ABC): _family: ClassVar[str] _allow_sparse: ClassVar[bool] + xp: ClassVar[Any] @property def family(self) -> str: @@ -152,15 +170,15 @@ def is_array(self, x: Any) -> bool: return self.is_dense(x) or self.is_sparse(x) @abstractmethod - def get_dtype(self, x: Any) -> DType: + def assparse(self, x: Any, dtype: DType | None = None) -> SparseArray: """ - Generic backend-agnostic wrapper to return an array dtype. + Generic backend-agnostic wrapper to convert input to a sparse array. Input: - x: Dense or sparse backend array. + x: Dense, sparse, or array-like input plus sparse-format options. Output: - Backend dtype associated with x. + Sparse backend array. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. @@ -168,140 +186,179 @@ def get_dtype(self, x: Any) -> DType: ... @abstractmethod - def shape(self, x: Any) -> tuple[int, ...]: + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: """ - Generic backend-agnostic wrapper to return array shape metadata. + Generic backend-agnostic wrapper to multiply sparse and dense arrays. Input: - x: Dense or sparse backend array. + a: Sparse backend array; b: Dense backend array. Output: - Tuple describing the logical shape of x. + Dense backend array containing the product. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. """ + ... @abstractmethod - def ndim(self, x: Any) -> int: + def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, + keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: """ - Generic backend-agnostic wrapper to return array rank metadata. + Generic backend-agnostic wrapper to compute a stable log-sum-exp reduction. Input: - x: Dense or sparse backend array. + a: Dense backend array; axis, weights, and sign options control the reduction. Output: - Number of dimensions in x. + Dense backend array or tuple containing log-sum-exp results. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. """ + ... @abstractmethod - def size(self, x: Any) -> int: + def index_set( + self, + x: DenseArray, + index: Index, + values: ArrayLike, + *, + copy: bool = True, + ) -> DenseArray: """ - Generic backend-agnostic wrapper to return logical element count. + Generic backend-agnostic wrapper to set indexed values. Input: - x: Dense or sparse backend array. + x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. Output: - Total number of logical dense elements. + Dense backend array with indexed values set. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. """ - @property @abstractmethod - def inf(self) -> DenseArray: + def index_add( + self, + x: DenseArray, + index: Index, + values: DenseArray, + *, + copy: bool = True, + ) -> DenseArray: """ - Generic backend-agnostic wrapper to positive infinity scalar. + Generic backend-agnostic wrapper to add into indexed values. Input: - None. + x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. Output: - Backend scalar representing positive infinity. + Dense backend array with indexed values incremented. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. """ + ... - @property @abstractmethod - def nan(self) -> DenseArray: + def ix_(self, *args: Any) -> Any: """ - Generic backend-agnostic wrapper to access a NaN scalar. + Generic backend-agnostic wrapper to build open mesh index arrays. Input: - None. + args: One-dimensional index arrays or sequences. Output: - Backend scalar representing NaN. + Tuple of dense backend arrays usable for open-mesh indexing. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. """ + ... - @property @abstractmethod - def pi(self) -> DenseArray: + def fori_loop( + self, + lower: int, + upper: int, + body_fun: Callable[[int, T], T], + init_val: T, + ) -> T: """ - Generic backend-agnostic wrapper to pi scalar. + Generic backend-agnostic wrapper to run a counted loop primitive. Input: - None. + lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. Output: - Backend scalar representing pi. + Final carry value after loop execution. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. """ - @property @abstractmethod - def e(self) -> DenseArray: + def while_loop( + self, + cond_fun: Callable[[T], bool], + body_fun: Callable[[T], T], + init_val: T, + ) -> T: """ - Generic backend-agnostic wrapper to access Euler's number scalar. + Generic backend-agnostic wrapper to run a while-loop primitive. Input: - None. + cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. Output: - Backend scalar representing Euler's number. + Final carry value after loop execution. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. """ - @property @abstractmethod - def eps(self) -> DenseArray: + def scan( + self, + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + ) -> Tuple[Carry, Y]: """ - Generic backend-agnostic wrapper to machine epsilon scalar. + Generic backend-agnostic wrapper to run a scan primitive. Input: - None. + f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. Output: - Backend scalar for float64 machine epsilon. + Tuple of final carry and stacked outputs. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. """ @abstractmethod - def asarray(self, x: Any, dtype: DType | None = None) -> DenseArray: + def cond( + self, + pred: bool, + true_fun: Callable[[T], R], + false_fun: Callable[[T], R], + *operands: Any, + ) -> R: """ - Generic backend-agnostic wrapper to convert input to a dense array. + Generic backend-agnostic wrapper to run conditional branch selection. Input: - x/a: Array-like input and optional dtype or backend conversion parameters. + pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. Output: - Dense backend array. + Result returned by the selected branch. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. @@ -309,1163 +366,419 @@ def asarray(self, x: Any, dtype: DType | None = None) -> DenseArray: ... @abstractmethod - def astype(self, x: DenseArray, dtype: DType) -> DenseArray: + def allclose_sparse( + self, + a: SparseArray, + b: SparseArray, + rtol: float = 1e-5, + atol: float = 1e-8, + ) -> bool: """ - Generic backend-agnostic wrapper to cast an array to a dtype. + Generic backend-agnostic wrapper to compare sparse arrays elementwise within tolerances. Input: - x: Dense backend array; dtype: target dtype and optional casting controls. + a, b: Sparse backend arrays; rtol and atol configure comparison. Output: - Dense backend array with the requested dtype. + Boolean indicating whether sparse arrays are close. This declaration only specifies the portable SpaceCore interface. See the concrete backend implementation for backend-specific behavior. """ + ... - @abstractmethod - def assparse(self, x: Any, dtype: DType | None = None) -> SparseArray: - """ - Generic backend-agnostic wrapper to convert input to a sparse array. + def _dtype_arg(self, dtype: DType | None) -> DType | None: + return None if dtype is None else self.sanitize_dtype(dtype) - Input: - x: Dense, sparse, or array-like input plus sparse-format options. + def _to_axis_tuple(self, axis: int | Sequence[int] | None) -> int | tuple[int, ...] | None: + if axis is None or isinstance(axis, int): + return axis + return tuple(axis) - Output: - Sparse backend array. + def _permute_dims(self, x: DenseArray, axes: Sequence[int]) -> DenseArray: + axes = tuple(axes) + if hasattr(self.xp, "permute_dims"): + return self.xp.permute_dims(x, axes) + if hasattr(self.xp, "permute"): + return self.xp.permute(x, axes) + return self.xp.transpose(x, axes=axes) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def _move_axis_order( + self, + ndim: int, + source: int | Sequence[int], + destination: int | Sequence[int], + ) -> tuple[int, ...]: + src = (source,) if isinstance(source, int) else tuple(source) + dst = (destination,) if isinstance(destination, int) else tuple(destination) + src = tuple(axis + ndim if axis < 0 else axis for axis in src) + dst = tuple(axis + ndim if axis < 0 else axis for axis in dst) + order = [axis for axis in range(ndim) if axis not in src] + for dest, axis in sorted(zip(dst, src, strict=True)): + order.insert(dest, axis) + return tuple(order) - @abstractmethod - def empty(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create an uninitialized dense array. + @property + def inf(self) -> DenseArray: + return self.asarray(float("inf")) - Input: - shape: Output shape; dtype and placement options are backend-specific. + @property + def nan(self) -> DenseArray: + return self.asarray(float("nan")) - Output: - Dense backend array with uninitialized values. + @property + def pi(self) -> DenseArray: + return self.asarray(3.141592653589793) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + @property + def e(self) -> DenseArray: + return self.asarray(2.718281828459045) - @abstractmethod - def zeros(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a zero-filled dense array. + @property + def eps(self) -> DenseArray: + return self.asarray(self.xp.finfo(self.sanitize_dtype(None)).eps) - Input: - shape: Output shape; dtype and placement options are backend-specific. + def get_dtype(self, x: Any) -> DType: + if self.is_array(x): + return x.dtype + raise TypeError(f"Expected {self.family} array, got {type(x)}.") - Output: - Dense backend array filled with zeros. + def shape(self, x: Any) -> tuple[int, ...]: + return tuple(x.shape) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def ndim(self, x: Any) -> int: + return int(x.ndim) - @abstractmethod - def ones(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a one-filled dense array. + def size(self, x: Any) -> int: + result = 1 + for dim in self.shape(x): + result *= int(dim) + return result + + def asarray(self, x: Any, dtype: DType | None = None, **backend_kwargs: Any) -> DenseArray: + if self.is_sparse(x) and hasattr(x, "to_dense"): + x = x.to_dense() + dtype = self._dtype_arg(dtype) + if hasattr(self.xp, "asarray"): + return self.xp.asarray(x, dtype=dtype, **backend_kwargs) + return self.xp.as_tensor(x, dtype=dtype, **backend_kwargs) + + def astype(self, x: DenseArray, dtype: DType, **backend_kwargs: Any) -> DenseArray: + dtype = self.sanitize_dtype(dtype) + if hasattr(x, "astype"): + return x.astype(dtype, **backend_kwargs) + return x.to(dtype=dtype, **backend_kwargs) - Input: - shape: Output shape; dtype and placement options are backend-specific. + def empty(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: + return self.xp.empty(shape, dtype=self._dtype_arg(dtype)) - Output: - Dense backend array filled with ones. + def zeros(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: + return self.xp.zeros(shape, dtype=self._dtype_arg(dtype)) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def ones(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: + return self.xp.ones(shape, dtype=self._dtype_arg(dtype)) - @abstractmethod def zeros_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create zeros shaped like another array. - - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. - - Output: - Dense backend array of zeros. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + return self.xp.zeros_like(x, dtype=self._dtype_arg(dtype)) - @abstractmethod def ones_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create ones shaped like another array. + return self.xp.ones_like(x, dtype=self._dtype_arg(dtype)) - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + def full_like(self, x: DenseArray, value: Any, dtype: DType | None = None) -> DenseArray: + return self.xp.full_like(x, value, dtype=self._dtype_arg(dtype)) - Output: - Dense backend array of ones. + def arange( + self, + start: int, + stop: int | None = None, + step: int | None = None, + dtype: DType | None = None, + ) -> DenseArray: + dtype = self._dtype_arg(dtype) + if stop is None: + return self.xp.arange(start, dtype=dtype) + if step is None: + return self.xp.arange(start, stop, dtype=dtype) + return self.xp.arange(start, stop, step, dtype=dtype) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def full(self, shape: Tuple[int, ...], fill_value: Any, dtype: DType | None = None) -> DenseArray: + return self.xp.full(shape, fill_value, dtype=self._dtype_arg(dtype)) - @abstractmethod - def full_like(self, x: DenseArray, value: Any, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create filled values shaped like another array. + def eye(self, n: int, m: int | None = None, dtype: DType | None = None) -> DenseArray: + return self.xp.eye(n, m, dtype=self._dtype_arg(dtype)) - Input: - x: Prototype dense array; value/fill_value and dtype options are backend-specific. + def ravel(self, x: DenseArray) -> DenseArray: + if hasattr(self.xp, "ravel"): + return self.xp.ravel(x) + return self.reshape(x, (-1,)) - Output: - Dense backend array filled with the requested value. + def reshape(self, x: DenseArray, shape: Tuple[int, ...] | int) -> DenseArray: + shape_arg = (shape,) if isinstance(shape, int) else shape + return self.xp.reshape(x, shape_arg) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: + if axes is None: + axes = tuple(reversed(range(self.ndim(x)))) + return self._permute_dims(x, axes) - @abstractmethod - def arange(self, start: int, stop: int | None = None, step: int | None = None, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create evenly spaced integer-range values. + def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: + if hasattr(self.xp, "swapaxes"): + return self.xp.swapaxes(x, axis1, axis2) + axes = list(range(self.ndim(x))) + axes[axis1], axes[axis2] = axes[axis2], axes[axis1] + return self._permute_dims(x, axes) - Input: - start, stop, step: Range parameters; dtype and placement options are backend-specific. + def broadcast_to(self, x: DenseArray, shape: Tuple[int, ...]) -> DenseArray: + return self.xp.broadcast_to(x, shape) - Output: - One-dimensional dense backend array. + def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: + if isinstance(axis, int): + if hasattr(self.xp, "expand_dims"): + return self.xp.expand_dims(x, axis=axis) + return self.xp.unsqueeze(x, axis) + out = x + ndim = self.ndim(x) + len(axis) + for ax in sorted(a + ndim if a < 0 else a for a in axis): + out = self.expand_dims(out, ax) + return out - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: + if axis is None: + axis = tuple(i for i, dim in enumerate(self.shape(x)) if dim == 1) + if not axis: + return x + axis = (axis,) if isinstance(axis, int) else tuple(axis) + return self.xp.squeeze(x, axis=axis) - @abstractmethod - def full(self, shape: Tuple[int, ...], fill_value: Any, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a filled dense array. + def moveaxis( + self, + x: DenseArray, + source: int | Sequence[int], + destination: int | Sequence[int], + ) -> DenseArray: + if hasattr(self.xp, "moveaxis"): + return self.xp.moveaxis(x, source, destination) + return self._permute_dims(x, self._move_axis_order(self.ndim(x), source, destination)) - Input: - shape: Output shape; fill_value and dtype options are backend-specific. + def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray: + return self.xp.stack(tuple(arrays), axis=axis) - Output: - Dense backend array filled with fill_value. + def conj(self, x: DenseArray) -> DenseArray: + return self.xp.conj(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def real(self, x: DenseArray) -> DenseArray: + return self.xp.real(x) - @abstractmethod - def eye(self, n: int, m: int | None = None, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a dense identity-like matrix. + def imag(self, x: DenseArray) -> DenseArray: + return self.xp.imag(x) - Input: - n and optional m: Matrix dimensions; dtype and placement options are backend-specific. + def abs(self, x: DenseArray) -> DenseArray: + return self.xp.abs(x) - Output: - Two-dimensional dense backend array. + def sign(self, x: DenseArray) -> DenseArray: + return self.xp.sign(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def sqrt(self, x: DenseArray) -> DenseArray: + return self.xp.sqrt(x) - @abstractmethod - def ravel(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to flatten an array. - - Input: - x: Dense backend array plus optional order parameters. - - Output: - One-dimensional dense backend array view or copy. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def reshape(self, x: DenseArray, shape: Tuple[int, ...] | int) -> DenseArray: - """ - Generic backend-agnostic wrapper to reshape an array. - - Input: - x: Dense backend array; shape: New shape plus backend-specific options. - - Output: - Dense backend array with the requested shape. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to permute array axes. - - Input: - x: Dense backend array; axes: Optional axis order. - - Output: - Dense backend array with permuted axes. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Generic backend-agnostic wrapper to interchange two axes. - - Input: - x: Dense backend array; axis1 and axis2: Axes to swap. - - Output: - Dense backend array with the two axes exchanged. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def broadcast_to(self, x: DenseArray, shape: Tuple[int, ...]) -> DenseArray: - """ - Generic backend-agnostic wrapper to broadcast an array to a shape. - - Input: - x: Dense backend array; shape: Target broadcast shape. - - Output: - Dense backend array with broadcast shape. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Generic backend-agnostic wrapper to insert length-one axes. - - Input: - x: Dense backend array; axis: Position or positions to insert. - - Output: - Dense backend array with expanded rank. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to remove length-one axes. - - Input: - x: Dense backend array; axis: Optional axes to squeeze. - - Output: - Dense backend array with selected singleton dimensions removed. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def moveaxis( - self, - x: DenseArray, - source: int | Sequence[int], - destination: int | Sequence[int], - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to move axes to new positions. - - Input: - x: Dense backend array; source and destination: Axis positions. - - Output: - Dense backend array with moved axes. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray: - """ - Generic backend-agnostic wrapper to stack arrays along a new axis. - - Input: - arrays: Sequence of dense backend arrays; axis: New axis position. - - Output: - Dense backend array containing stacked inputs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def conj(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute complex conjugates. - - Input: - x: Dense backend array. - - Output: - Dense backend array with conjugated values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def real(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to extract real components. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing real components. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def imag(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to extract imaginary components. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing imaginary components. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def abs(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute absolute values. - - Input: - x: Dense backend array. - - Output: - Dense backend array of absolute values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sign(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute signs elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of signs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sqrt(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute square roots elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of square roots. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sum( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - dtype: DType | None = None, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by summation. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing sums. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def mean( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by arithmetic mean. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing means. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def min( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by minimum. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing minima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def max( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by maximum. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing maxima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def prod( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - dtype: DType | None = None, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by product. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing products. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def trace(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to sum diagonal entries. - - Input: - x: Dense backend array plus optional diagonal and axis controls. - - Output: - Dense backend array or scalar containing trace values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def argsort(self, x: DenseArray, axis: int = -1) -> DenseArray: - """ - Generic backend-agnostic wrapper to return sorting indices. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense integer backend array of indices. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sort(self, x: DenseArray, axis: int = -1) -> DenseArray: - """ - Generic backend-agnostic wrapper to sort values. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense backend array with sorted values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Generic backend-agnostic wrapper to return indices of minimum values. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Generic backend-agnostic wrapper to return indices of maximum values. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def vdot(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute a conjugating vector dot product. - - Input: - x, y: Dense backend arrays accepted by the backend vdot operation. - - Output: - Backend scalar or dense array containing the dot product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def matmul(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute matrix products. - - Input: - a, b: Dense backend arrays with matrix-multiplication-compatible shapes. - - Output: - Dense backend array containing the product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to multiply sparse and dense arrays. - - Input: - a: Sparse backend array; b: Dense backend array. - - Output: - Dense backend array containing the product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute a Kronecker product. - - Input: - a, b: Dense backend arrays. - - Output: - Dense backend array containing the Kronecker product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to evaluate an Einstein summation expression. - - Input: - subscripts: Einstein summation string; operands: Dense backend arrays. - - Output: - Dense backend array containing the contraction result. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def eigh(self, x: DenseArray) -> tuple[DenseArray, DenseArray]: - """ - Generic backend-agnostic wrapper to compute Hermitian eigenpairs. - - Input: - x: Dense Hermitian or symmetric backend array. - - Output: - Tuple of dense backend arrays containing eigenvalues and eigenvectors. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def norm( - self, - x: DenseArray, - ord: int | str | None = None, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute vector or matrix norms. - - Input: - x: Dense backend array; ord, axis, and keepdims select the norm. - - Output: - Dense backend array or scalar containing norm values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def solve(self, A: DenseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to solve dense linear systems. - - Input: - A: Dense coefficient array; b: Dense right-hand side array. - - Output: - Dense backend array solving A @ x = b. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def eigvalsh(self, A: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute Hermitian eigenvalues. - - Input: - A: Dense Hermitian or symmetric backend array. - - Output: - Dense backend array containing eigenvalues. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def svd(self, A: DenseArray, full_matrices: bool = True) -> tuple[DenseArray, DenseArray, DenseArray]: - """ - Generic backend-agnostic wrapper to compute singular value decompositions. - - Input: - A: Dense backend array plus SVD options. - - Output: - Dense backend arrays containing singular vectors and/or singular values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def cholesky(self, A: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute Cholesky factors. - - Input: - A: Dense Hermitian positive-definite backend array. - - Output: - Dense backend array containing a triangular factor. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, - keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: - """ - Generic backend-agnostic wrapper to compute a stable log-sum-exp reduction. - - Input: - a: Dense backend array; axis, weights, and sign options control the reduction. - - Output: - Dense backend array or tuple containing log-sum-exp results. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def exp(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute exponentials elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of exponentials. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def log(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute natural logarithms elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of logarithms. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def where(self, condition: DenseArray | bool, x: ArrayLike, y: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to select values by condition. - - Input: - condition: Boolean array or scalar; x and y: Values to choose between. - - Output: - Dense backend array containing selected values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def maximum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute elementwise maxima. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing maxima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def minimum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute elementwise minima. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing minima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def clip(self, x: DenseArray, a_min: ArrayLike, a_max: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to clip values into an interval. - - Input: - x: Dense backend array; a_min and a_max: Broadcastable bounds. - - Output: - Dense backend array with clipped values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def isfinite(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to test finiteness elementwise. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def isnan(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to test NaN values elementwise. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def concatenate(self, arrays: Sequence[DenseArray], axis: int = 0, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to join arrays along an existing axis. - - Input: - arrays: Sequence of dense backend arrays; axis and dtype options are backend-specific. - - Output: - Dense backend array containing concatenated inputs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to take values by integer indices. - - Input: - x: Dense backend array; indices: Integer indices; axis and mode options are backend-specific. - - Output: - Dense backend array containing selected values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def diag(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to extract or build a diagonal. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array containing a diagonal view/copy or matrix. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def diagonal(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to return selected diagonals. - - Input: - x: Dense backend array plus offset and axis controls. + def sum( + self, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + dtype: DType | None = None, + ) -> DenseArray: + return self.xp.sum( + x, + axis=self._to_axis_tuple(axis), + dtype=self._dtype_arg(dtype), + keepdims=keepdims, + ) - Output: - Dense backend array containing selected diagonals. + def mean( + self, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + ) -> DenseArray: + return self.xp.mean(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def min( + self, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + ) -> DenseArray: + return self.xp.min(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) - @abstractmethod - def tril(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to return lower-triangular values. + def max( + self, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + ) -> DenseArray: + return self.xp.max(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) - Input: - x: Dense backend array and optional diagonal offset. + def prod( + self, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + dtype: DType | None = None, + ) -> DenseArray: + return self.xp.prod( + x, + axis=self._to_axis_tuple(axis), + dtype=self._dtype_arg(dtype), + keepdims=keepdims, + ) - Output: - Dense backend array with upper entries zeroed. + def trace(self, x: DenseArray) -> DenseArray: + if hasattr(self.xp, "trace"): + return self.xp.trace(x) + return self.sum(self.diagonal(x)) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def argsort(self, x: DenseArray, axis: int = -1) -> DenseArray: + return self.xp.argsort(x, axis=axis) - @abstractmethod - def triu(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to return upper-triangular values. + def sort(self, x: DenseArray, axis: int = -1) -> DenseArray: + return self.xp.sort(x, axis=axis) - Input: - x: Dense backend array and optional diagonal offset. + def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: + return self.xp.argmin(x, axis=axis, keepdims=keepdims) - Output: - Dense backend array with lower entries zeroed. + def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: + return self.xp.argmax(x, axis=axis, keepdims=keepdims) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def vdot(self, x: DenseArray, y: DenseArray) -> DenseArray: + x_flat = self.ravel(x) + y_flat = self.ravel(y) + if hasattr(self.xp, "vdot"): + return self.xp.vdot(x_flat, y_flat) + return self.xp.vecdot(x_flat, y_flat) - @abstractmethod - def index_set( - self, - x: DenseArray, - index: Index, - values: ArrayLike, - *, - copy: bool = True, + def matmul( + self, + a: DenseArray, + b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to set indexed values. + return self.xp.matmul(a, b, **({} if backend_kwargs is None else backend_kwargs)) - Input: - x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. + def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: + return self.xp.kron(a, b) - Output: - Dense backend array with indexed values set. + def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: + return self.xp.einsum(subscripts, *operands) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def eigh( + self, + x: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> tuple[DenseArray, DenseArray]: + if self.is_sparse(x): + raise TypeError("eigh requires a dense array; sparse input is not supported.") + return self.xp.linalg.eigh(x, **({} if backend_kwargs is None else backend_kwargs)) - @abstractmethod - def index_add( - self, - x: DenseArray, - index: Index, - values: DenseArray, - *, - copy: bool = True, + def norm( + self, + x: DenseArray, + ord: int | str | None = None, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to add into indexed values. - - Input: - x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. - - Output: - Dense backend array with indexed values incremented. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def ix_(self, *args: Any) -> Any: - """ - Generic backend-agnostic wrapper to build open mesh index arrays. - - Input: - args: One-dimensional index arrays or sequences. - - Output: - Tuple of dense backend arrays usable for open-mesh indexing. + return self.xp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def solve( + self, + A: DenseArray, + b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + return self.xp.linalg.solve(A, b, **({} if backend_kwargs is None else backend_kwargs)) - @abstractmethod - def fori_loop( - self, - lower: int, - upper: int, - body_fun: Callable[[int, T], T], - init_val: T, - ) -> T: - """ - Generic backend-agnostic wrapper to run a counted loop primitive. + def eigvalsh( + self, + A: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + return self.xp.linalg.eigvalsh(A, **({} if backend_kwargs is None else backend_kwargs)) - Input: - lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. + def svd( + self, + A: DenseArray, + full_matrices: bool = True, + backend_kwargs: dict[str, Any] | None = None, + ) -> tuple[DenseArray, DenseArray, DenseArray]: + return self.xp.linalg.svd( + A, + full_matrices=full_matrices, + **({} if backend_kwargs is None else backend_kwargs), + ) + + def cholesky( + self, + A: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + return self.xp.linalg.cholesky(A, **({} if backend_kwargs is None else backend_kwargs)) - Output: - Final carry value after loop execution. + def exp(self, x: DenseArray) -> DenseArray: + return self.xp.exp(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def log(self, x: DenseArray) -> DenseArray: + return self.xp.log(x) - @abstractmethod - def while_loop( - self, - cond_fun: Callable[[T], bool], - body_fun: Callable[[T], T], - init_val: T, - ) -> T: - """ - Generic backend-agnostic wrapper to run a while-loop primitive. + def where(self, condition: DenseArray | bool, x: ArrayLike, y: ArrayLike) -> DenseArray: + return self.xp.where(condition, x, y) - Input: - cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. + def maximum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: + return self.xp.maximum(x, y) - Output: - Final carry value after loop execution. + def minimum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: + return self.xp.minimum(x, y) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def clip(self, x: DenseArray, a_min: ArrayLike, a_max: ArrayLike) -> DenseArray: + return self.xp.clip(x, a_min, a_max) - @abstractmethod - def scan( - self, - f: Callable[[Carry, X], Tuple[Carry, Y]], - init: Carry, - xs: X, - length: Optional[int] = None, - reverse: bool = False, - unroll: int = 1, - ) -> Tuple[Carry, Y]: - """ - Generic backend-agnostic wrapper to run a scan primitive. + def isfinite(self, x: DenseArray) -> DenseArray: + return self.xp.isfinite(x) - Input: - f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. + def isnan(self, x: DenseArray) -> DenseArray: + return self.xp.isnan(x) - Output: - Tuple of final carry and stacked outputs. + def concatenate( + self, + arrays: Sequence[DenseArray], + axis: int = 0, + dtype: DType | None = None, + ) -> DenseArray: + if hasattr(self.xp, "concat"): + result = self.xp.concat(tuple(arrays), axis=axis) + else: + result = self.xp.concatenate(tuple(arrays), axis=axis) + return self.astype(result, dtype) if dtype is not None else result - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def take( + self, + x: DenseArray, + indices: DenseArray, + axis: int | None = None, + ) -> DenseArray: + return self.xp.take(x, indices, axis=axis) - @abstractmethod - def cond( - self, - pred: bool, - true_fun: Callable[[T], R], - false_fun: Callable[[T], R], - *operands: Any, - ) -> R: - """ - Generic backend-agnostic wrapper to run conditional branch selection. + def diag(self, x: DenseArray) -> DenseArray: + return self.xp.diag(x) - Input: - pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. + def diagonal(self, x: DenseArray) -> DenseArray: + return self.xp.diagonal(x) - Output: - Result returned by the selected branch. + def tril(self, x: DenseArray) -> DenseArray: + return self.xp.tril(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def triu(self, x: DenseArray) -> DenseArray: + return self.xp.triu(x) - @abstractmethod def allclose( self, a: DenseArray, @@ -1474,41 +787,7 @@ def allclose( atol: float = 1e-8, equal_nan: bool = False, ) -> bool: - """ - Generic backend-agnostic wrapper to compare dense arrays elementwise within tolerances. - - Input: - a, b: Dense backend arrays; rtol, atol, and equal_nan configure comparison. - - Output: - Boolean indicating whether arrays are close. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def allclose_sparse( - self, - a: SparseArray, - b: SparseArray, - rtol: float = 1e-5, - atol: float = 1e-8, - ) -> bool: - """ - Generic backend-agnostic wrapper to compare sparse arrays elementwise within tolerances. - - Input: - a, b: Sparse backend arrays; rtol and atol configure comparison. - - Output: - Boolean indicating whether sparse arrays are close. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + return bool(self.xp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) def __repr__(self): return f"{type(self).__name__}" diff --git a/spacecore/backend/jax/_ops.py b/spacecore/backend/jax/_ops.py index 6ec1591..effce84 100644 --- a/spacecore/backend/jax/_ops.py +++ b/spacecore/backend/jax/_ops.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any, Sequence, Literal, Tuple, Callable, Optional, Type -import inspect from warnings import warn from .._family import BackendFamily @@ -57,21 +56,11 @@ class JaxOps(BackendOps): import jax import jax.numpy as jnp import jax.experimental.sparse as jsparse + xp = jnp _family = BackendFamily.jax.value.lower() _allow_sparse = True - def __init__(self) -> None: - self._reshape_supports_copy = "copy" in inspect.signature(self.jnp.reshape).parameters - self._reshape_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.reshape).parameters - self._ravel_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.ravel).parameters - self._zeros_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.zeros).parameters - self._empty_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.empty).parameters - self._zeros_like_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.zeros_like).parameters - self._ones_like_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.ones_like).parameters - self._broadcast_to_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.broadcast_to).parameters - - def sanitize_dtype(self, dtype: DType | None) -> DType: """ Normalize a dtype specifier using JAX. @@ -123,77 +112,6 @@ def sanitize_dtype(self, dtype: DType | None) -> DType: return dt - def get_dtype(self, x: Any) -> DType: - """ - Return an array dtype using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Backend dtype associated with x. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - """ - if self.is_dense(x): - return x.dtype - elif self.is_sparse(x): - return x.dtype - else: - raise TypeError(f'Expected Jax ndarray or BCOO/BCSR, got {type(x)}.') - - def shape(self, x: Any) -> tuple[int, ...]: - """ - Return array shape metadata using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Tuple describing the logical shape of x. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - """ - return tuple(x.shape) - - def ndim(self, x: Any) -> int: - """ - Return array rank metadata using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Number of dimensions in x. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - """ - return int(x.ndim) - - def size(self, x: Any) -> int: - """ - Return logical element count using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Total number of logical dense elements. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - - Backend-specific notes: - Shape-polymorphic dimensions may not be concrete Python integers inside traced code. - """ - result = 1 - for dim in self.shape(x): - result *= dim - return result - @property def dense_array(self) -> Type[Any]: """ @@ -220,121 +138,6 @@ def sparse_array(self) -> Tuple[Type[Any], ...]: """ return (self.jsparse.BCOO, self.jsparse.BCSR) - @property - def inf(self): - """ - Positive infinity scalar using JAX. - - Returns: - Backend scalar representing positive infinity. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.inf.html - """ - return self.jnp.array(self.jnp.inf) - - @property - def nan(self): - """ - NaN scalar using JAX. - - Input: - None. - - Output: - Backend scalar representing NaN. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.nan.html - """ - return self.jnp.array(self.jnp.nan) - - @property - def pi(self): - """ - Pi scalar using JAX. - - Input: - None. - - Output: - Backend scalar representing pi. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.pi.html - """ - return self.jnp.array(self.jnp.pi) - - @property - def e(self): - """ - Euler number scalar using JAX. - - Input: - None. - - Output: - Backend scalar representing Euler's number. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.e.html - """ - return self.jnp.array(self.jnp.e) - - @property - def eps(self): - """ - Machine epsilon scalar using JAX. - - Input: - None. - - Output: - Backend scalar for float64 machine epsilon. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.finfo.html - """ - return self.jnp.array(self.jnp.finfo(self.jnp.float64).eps) - - def asarray( - self, - a: Any, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] | None = None, - *, - copy: bool | None = None, - device: Any | None = None, - ) -> DenseArray: - """ - Convert input to a dense array using JAX. - - Input: - x/a: Array-like input and optional dtype or backend conversion parameters. - - Output: - Dense backend array. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.asarray.html - """ - return self.jnp.asarray(a, dtype=dtype, order=order, copy=copy, device=device) - - def astype(self, x: DenseArray, dtype: DType, copy: bool = True) -> DenseArray: - """ - Cast an array to a dtype using JAX. - - Input: - x: Dense backend array; dtype: target dtype and optional casting controls. - - Output: - Dense backend array with the requested dtype. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.Array.astype.html - """ - return x.astype(dtype, copy=copy) - def assparse( self, x: Any, @@ -407,1436 +210,192 @@ def assparse( raise ValueError(f"Unknown sparse format: {format!r}") - def empty( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: """ - Create an uninitialized dense array using JAX. + Multiply sparse and dense arrays using JAX. Input: - shape: Output shape; dtype and placement options are backend-specific. + a: Sparse backend array; b: Dense backend array. Output: - Dense backend array with uninitialized values. + Dense backend array containing the product. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.empty.html + https://docs.jax.dev/en/latest/jax.experimental.sparse.html Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + Uses JAX sparse matmul and returns a JAX array; sparse support remains experimental in JAX. """ - if self._empty_supports_out_sharding: - return self.jnp.empty(shape, dtype=dtype, device=device, out_sharding=out_sharding) - return self.jnp.empty(shape, dtype=dtype, device=device) + return a @ b - def zeros( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, + return_sign: bool = False, where: DenseArray | None = None) -> DenseArray | Tuple[DenseArray, DenseArray]: """ - Create a zero-filled dense array using JAX. + Compute a stable log-sum-exp reduction using JAX. Input: - shape: Output shape; dtype and placement options are backend-specific. + a: Dense backend array; axis, weights, and sign options control the reduction. Output: - Dense backend array filled with zeros. + Dense backend array or tuple containing log-sum-exp results. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.zeros.html - - Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + https://docs.jax.dev/en/latest/_autosummary/jax.scipy.special.logsumexp.html """ - if self._zeros_supports_out_sharding: - return self.jnp.zeros(shape, dtype=dtype, device=device, out_sharding=out_sharding) - return self.jnp.zeros(shape, dtype=dtype, device=device) + return self.jax.scipy.special.logsumexp(a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign, where=where) - def ones( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def index_set(self, x: DenseArray, index: Index, values: ArrayLike, *, copy: bool = True): """ - Create a one-filled dense array using JAX. + Set indexed values using JAX. Input: - shape: Output shape; dtype and placement options are backend-specific. + x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. Output: - Dense backend array filled with ones. + Dense backend array with indexed values set. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ones.html + https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html + + Backend-specific notes: + JAX arrays are immutable; copy=False raises NotImplementedError. """ - return self.jnp.ones(shape, dtype=dtype, device=device, out_sharding=out_sharding) + if not copy: + raise NotImplementedError( + "JAX arrays are immutable; copy=False is not supported." + ) + return x.at[index].set(values) - def zeros_like( - self, - x: DenseArray, - dtype: DType | None = None, - shape: Any = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def ix_(self, *args: Any) -> Any: """ - Create zeros shaped like another array using JAX. + Build open mesh index arrays using JAX. Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + args: One-dimensional index arrays or sequences. Output: - Dense backend array of zeros. + Tuple of dense backend arrays usable for open-mesh indexing. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.zeros_like.html - - Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ix\\_.html """ - kwargs: dict[str, Any] = {"dtype": dtype, "shape": shape, "device": device} - if self._zeros_like_supports_out_sharding: - kwargs["out_sharding"] = out_sharding - return self.jnp.zeros_like(x, **kwargs) + return self.jnp.ix_(*args) - def ones_like( + def fori_loop( self, - x: DenseArray, - dtype: DType | None = None, - shape: Any = None, + lower: int, + upper: int, + body_fun: Callable[[int, T], T], + init_val: T, *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + unroll: int | bool | None = None, + ) -> T: """ - Create ones shaped like another array using JAX. + Run a counted loop primitive using JAX. Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. Output: - Dense backend array of ones. + Final carry value after loop execution. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ones_like.html + https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + Loop bounds and unroll behavior follow JAX tracing and compilation rules. """ - kwargs: dict[str, Any] = {"dtype": dtype, "shape": shape, "device": device} - if self._ones_like_supports_out_sharding: - kwargs["out_sharding"] = out_sharding - return self.jnp.ones_like(x, **kwargs) + return self.jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll) - def full_like( + def while_loop( self, - x: DenseArray, - value: Any, - dtype: DType | None = None, - shape: Any = None, - *, - device: Any | None = None, - ) -> DenseArray: + cond_fun: Callable[[T], bool], + body_fun: Callable[[T], T], + init_val: T, + ) -> T: """ - Create filled values shaped like another array using JAX. + Run a while-loop primitive using JAX. Input: - x: Prototype dense array; value/fill_value and dtype options are backend-specific. + cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. Output: - Dense backend array filled with the requested value. + Final carry value after loop execution. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.full_like.html - """ - return self.jnp.full_like(x, value, dtype=dtype, shape=shape, device=device) - - def arange(self, - start: int, - stop: int | None = None, - step: int | None = None, - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Create evenly spaced integer-range values using JAX. - - Input: - start, stop, step: Range parameters; dtype and placement options are backend-specific. - - Output: - One-dimensional dense backend array. + https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.arange.html + Backend-specific notes: + Condition and body are staged according to JAX lax control-flow semantics. """ - return self.jnp.arange(start, stop, step, dtype=dtype, device=device, out_sharding=out_sharding) + return self.jax.lax.while_loop(cond_fun, body_fun, init_val) - def full( + def scan( self, - shape: int | Tuple[int, ...], - fill_value: Any, - dtype: DType | None = None, - *, - device: Any | None = None, - ) -> DenseArray: + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + _split_transpose: bool = False, + ) -> Tuple[Carry, Y]: """ - Create a filled dense array using JAX. + Run a scan primitive using JAX. Input: - shape: Output shape; fill_value and dtype options are backend-specific. + f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. Output: - Dense backend array filled with fill_value. + Tuple of final carry and stacked outputs. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.full.html - """ - return self.jnp.full(shape, fill_value, dtype=dtype, device=device) - - def eye( - self, - N: int, - M: int | None = None, - k: int = 0, - dtype: DType | None = None, - *, - device: Any | None = None, - ) -> DenseArray: - """ - Create a dense identity-like matrix using JAX. - - Input: - n and optional m: Matrix dimensions; dtype and placement options are backend-specific. - - Output: - Two-dimensional dense backend array. + https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.eye.html + Backend-specific notes: + Inputs and outputs may be pytrees and are staged according to JAX lax.scan semantics. """ - return self.jnp.eye(N=N, M=M, k=k, dtype=dtype, device=device) + return self.jax.lax.scan(f, init, xs, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose) - def ravel( - self, - a: DenseArray, - order: Literal["C", "F", "A", "K"] = "C", - *, - out_sharding: Any | None = None, - ) -> DenseArray: + def cond( + self, + pred: bool, + true_fun: Callable[[T], R], + false_fun: Callable[[T], R], + *operands: Any, + ) -> R: """ - Flatten an array using JAX. + Run conditional branch selection using JAX. Input: - x: Dense backend array plus optional order parameters. + pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. Output: - One-dimensional dense backend array view or copy. + Result returned by the selected branch. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ravel.html - """ - if order == "C" and out_sharding is None: - return self.jnp.ravel(a) - if self._ravel_supports_out_sharding: - return self.jnp.ravel(a, order=order, out_sharding=out_sharding) - return self.jnp.ravel(a, order=order) + https://docs.jax.dev/en/latest/_autosummary/jax.lax.cond.html - def reshape( - self, - a: DenseArray, - shape: int | Tuple[int, ...], - order: Literal["C", "F", "A"] = "C", - *, - copy: bool | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + Backend-specific notes: + Branches are staged according to JAX lax.cond semantics rather than Python eager branching. """ - Reshape an array using JAX. - - Input: - x: Dense backend array; shape: New shape plus backend-specific options. - - Output: - Dense backend array with the requested shape. + return self.jax.lax.cond(pred, true_fun, false_fun, *operands) - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.reshape.html - """ - if order == "C" and copy is None and out_sharding is None: - return self.jnp.reshape(a, shape) - kwargs: dict[str, Any] = {"order": order} - if self._reshape_supports_copy: - kwargs["copy"] = copy - if self._reshape_supports_out_sharding: - kwargs["out_sharding"] = out_sharding - return self.jnp.reshape(a, shape, **kwargs) - - def transpose( - self, - x: DenseArray, - axes: Sequence[int] | None = None, - ) -> DenseArray: + def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): """ - Permute array axes using JAX. + Add into indexed values using JAX. Input: - x: Dense backend array; axes: Optional axis order. + x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. Output: - Dense backend array with permuted axes. + Dense backend array with indexed values incremented. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.transpose.html - """ - return self.jnp.transpose(x, axes=axes) + https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html - def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Interchange two axes using JAX. - - Input: - x: Dense backend array; axis1 and axis2: Axes to swap. - - Output: - Dense backend array with the two axes exchanged. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.swapaxes.html - """ - return self.jnp.swapaxes(x, axis1, axis2) - - def broadcast_to( - self, - x: DenseArray, - shape: int | Tuple[int, ...], - *, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Broadcast an array to a shape using JAX. - - Input: - x: Dense backend array; shape: Target broadcast shape. - - Output: - Dense backend array with broadcast shape. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.broadcast_to.html - """ - if self._broadcast_to_supports_out_sharding: - return self.jnp.broadcast_to(x, shape, out_sharding=out_sharding) - return self.jnp.broadcast_to(x, shape) - - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Insert length-one axes using JAX. - - Input: - x: Dense backend array; axis: Position or positions to insert. - - Output: - Dense backend array with expanded rank. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.expand_dims.html - """ - return self.jnp.expand_dims(x, axis=axis) - - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Remove length-one axes using JAX. - - Input: - x: Dense backend array; axis: Optional axes to squeeze. - - Output: - Dense backend array with selected singleton dimensions removed. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.squeeze.html - """ - return self.jnp.squeeze(x, axis=axis) - - def moveaxis( - self, - x: DenseArray, - source: int | Sequence[int], - destination: int | Sequence[int], - ) -> DenseArray: - """ - Move axes to new positions using JAX. - - Input: - x: Dense backend array; source and destination: Axis positions. - - Output: - Dense backend array with moved axes. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.moveaxis.html - """ - return self.jnp.moveaxis(x, source=source, destination=destination) - - def stack( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - out: Any | None = None, - dtype: DType | None = None, - ) -> DenseArray: - """ - Stack arrays along a new axis using JAX. - - Input: - arrays: Sequence of dense backend arrays; axis: New axis position. - - Output: - Dense backend array containing stacked inputs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.stack.html - """ - return self.jnp.stack(arrays, axis=axis, out=out, dtype=dtype) - - def conj(self, x: DenseArray) -> DenseArray: - """ - Compute complex conjugates using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array with conjugated values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.conj.html - """ - return self.jnp.conj(x) - - def real(self, x: DenseArray) -> DenseArray: - """ - Extract real components using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing real components. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.real.html - """ - return self.jnp.real(x) - - def imag(self, x: DenseArray) -> DenseArray: - """ - Extract imaginary components using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing imaginary components. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.imag.html - """ - return self.jnp.imag(x) - - def abs(self, x: DenseArray) -> DenseArray: - """ - Compute absolute values using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of absolute values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.abs.html - """ - return self.jnp.abs(x) - - def sign(self, x: DenseArray) -> DenseArray: - """ - Compute signs elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of signs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sign.html - """ - return self.jnp.sign(x) - - def sqrt(self, x: DenseArray) -> DenseArray: - """ - Compute square roots elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of square roots. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sqrt.html - """ - return self.jnp.sqrt(x) - - def sum( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: Any | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - promote_integers: bool = True, - ) -> DenseArray: - """ - Reduce by summation using JAX. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing sums. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sum.html - """ - if out is None and initial is None and where is None and promote_integers: - if axis is None and dtype is None and not keepdims: - return self.jnp.sum(a) - return self.jnp.sum(a, axis=axis, dtype=dtype, keepdims=keepdims) - return self.jnp.sum( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - promote_integers=promote_integers, - ) - - def mean( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: None = None, - keepdims: bool = False, - *, - where: DenseArray | None = None, - ) -> DenseArray: - """ - Reduce by arithmetic mean using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing means. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.mean.html - """ - if out is None and where is None: - if axis is None and dtype is None and not keepdims: - return self.jnp.mean(a) - return self.jnp.mean(a, axis=axis, dtype=dtype, keepdims=keepdims) - return self.jnp.mean(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where) - - def min( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - ) -> DenseArray: - """ - Reduce by minimum using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing minima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.min.html - """ - if out is None and initial is None and where is None: - if axis is None and not keepdims: - return self.jnp.min(a) - return self.jnp.min(a, axis=axis, keepdims=keepdims) - return self.jnp.min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def max( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - ) -> DenseArray: - """ - Reduce by maximum using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing maxima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.max.html - """ - if out is None and initial is None and where is None: - if axis is None and not keepdims: - return self.jnp.max(a) - return self.jnp.max(a, axis=axis, keepdims=keepdims) - return self.jnp.max(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def prod( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: Any | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - promote_integers: bool = True, - ) -> DenseArray: - """ - Reduce by product using JAX. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing products. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.prod.html - """ - if out is None and initial is None and where is None and promote_integers: - if axis is None and dtype is None and not keepdims: - return self.jnp.prod(a) - return self.jnp.prod(a, axis=axis, dtype=dtype, keepdims=keepdims) - return self.jnp.prod( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - promote_integers=promote_integers, - ) - - def trace( - self, - a: DenseArray, - offset: int | Any = 0, - axis1: int = 0, - axis2: int = 1, - dtype: DType | None = None, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Sum diagonal entries using JAX. - - Input: - x: Dense backend array plus optional diagonal and axis controls. - - Output: - Dense backend array or scalar containing trace values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.trace.html - """ - return self.jnp.trace(a, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out) - - def argsort( - self, - a: DenseArray, - axis: int | None = -1, - kind: None = None, - order: None = None, - stable: bool = True, - descending: bool = False - ) -> DenseArray: - """ - Return sorting indices using JAX. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense integer backend array of indices. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argsort.html - """ - return self.jnp.argsort(a, axis=axis, kind=kind, order=order, stable=stable, descending=descending) - - def sort( - self, - a: DenseArray, - axis: int | None = -1, - kind: None = None, - order: None = None, - stable: bool = True, - descending: bool = False - ) -> DenseArray: - """ - Sort values using JAX. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense backend array with sorted values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sort.html - """ - return self.jnp.sort(a, axis=axis, kind=kind, order=order, stable=stable, descending=descending) - - def argmin( - self, - a: DenseArray, - axis: int | None = None, - out: Any | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of minimum values using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmin.html - """ - return self.jnp.argmin(a, axis=axis, out=out, keepdims=keepdims) - - def argmax( - self, - a: DenseArray, - axis: int | None = None, - out: Any | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of maximum values using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmax.html - """ - return self.jnp.argmax(a, axis=axis, out=out, keepdims=keepdims) - - def vdot( - self, - a: DenseArray, - b: DenseArray, - *, - precision: Any | None = None, - preferred_element_type: DType | None = None, - ) -> DenseArray: - """ - Compute a conjugating vector dot product using JAX. - - Input: - x, y: Dense backend arrays accepted by the backend vdot operation. - - Output: - Backend scalar or dense array containing the dot product. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.vdot.html - """ - if precision is None and preferred_element_type is None: - return self.jnp.vdot(a, b) - return self.jnp.vdot(a, b, precision=precision, preferred_element_type=preferred_element_type) - - def matmul( - self, - a: DenseArray, - b: DenseArray, - *, - precision: Any | None = None, - preferred_element_type: DType | None = None, - out_sharding: Any | None = None - ) -> DenseArray: - """ - Compute matrix products using JAX. - - Input: - a, b: Dense backend arrays with matrix-multiplication-compatible shapes. - - Output: - Dense backend array containing the product. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.matmul.html - """ - if precision is None and preferred_element_type is None and out_sharding is None: - return self.jnp.matmul(a, b) - return self.jnp.matmul( - a, - b, - precision=precision, - preferred_element_type=preferred_element_type, - out_sharding=out_sharding, - ) - - def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: - """ - Multiply sparse and dense arrays using JAX. - - Input: - a: Sparse backend array; b: Dense backend array. - - Output: - Dense backend array containing the product. - - See: - https://docs.jax.dev/en/latest/jax.experimental.sparse.html - - Backend-specific notes: - Uses JAX sparse matmul and returns a JAX array; sparse support remains experimental in JAX. - """ - return a @ b - - def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Compute a Kronecker product using JAX. - - Input: - a, b: Dense backend arrays. - - Output: - Dense backend array containing the Kronecker product. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.kron.html - """ - return self.jnp.kron(a, b) - - def einsum( - self, - subscripts: str, - /, - *operands: DenseArray, - out: Any | None = None, - optimize: str | bool | list[Tuple[int, ...]] = "auto", - precision: Any | None = None, - preferred_element_type: DType | None = None, - _dot_general: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Evaluate an Einstein summation expression using JAX. - - Input: - subscripts: Einstein summation string; operands: Dense backend arrays. - - Output: - Dense backend array containing the contraction result. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.einsum.html - """ - return self.jnp.einsum( - subscripts, - *operands, - out=out, - optimize=optimize, - precision=precision, - preferred_element_type=preferred_element_type, - # _dot_general=_dot_general, - out_sharding=out_sharding, - ) - - def eigh( - self, - x: DenseArray, - UPLO: Literal["L", "U"] = "L", - symmetrize_input: bool = True - ) -> Tuple[DenseArray, DenseArray]: - """ - Compute Hermitian eigenpairs using JAX. - - Input: - x: Dense Hermitian or symmetric backend array. - - Output: - Tuple of dense backend arrays containing eigenvalues and eigenvectors. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.eigh.html - - Backend-specific notes: - SpaceCore rejects sparse input before delegating to JAX dense linear algebra. - """ - if self.is_sparse(x): - raise TypeError("eigh requires a dense array; sparse input is not supported.") - return self.jnp.linalg.eigh(x, UPLO=UPLO, symmetrize_input=symmetrize_input) - - def norm( - self, - x: DenseArray, - ord: int | str | None = None, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Compute vector or matrix norms using JAX. - - Input: - x: Dense backend array; ord, axis, and keepdims select the norm. - - Output: - Dense backend array or scalar containing norm values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.norm.html - """ - return self.jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) - - def solve(self, A: DenseArray, b: DenseArray) -> DenseArray: - """ - Solve dense linear systems using JAX. - - Input: - A: Dense coefficient array; b: Dense right-hand side array. - - Output: - Dense backend array solving A @ x = b. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.solve.html - """ - return self.jnp.linalg.solve(A, b) - - def eigvalsh( - self, - A: DenseArray, - UPLO: Literal["L", "U"] = "L", - *, - symmetrize_input: bool = True, - ) -> DenseArray: - """ - Compute Hermitian eigenvalues using JAX. - - Input: - A: Dense Hermitian or symmetric backend array. - - Output: - Dense backend array containing eigenvalues. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.eigvalsh.html - """ - return self.jnp.linalg.eigvalsh(A, UPLO=UPLO, symmetrize_input=symmetrize_input) - - def svd( - self, - A: DenseArray, - full_matrices: bool = True, - compute_uv: bool = True, - hermitian: bool = False, - subset_by_index: tuple[int, int] | None = None, - ) -> DenseArray | Tuple[DenseArray, DenseArray, DenseArray]: - """ - Compute singular value decompositions using JAX. - - Input: - A: Dense backend array plus SVD options. - - Output: - Dense backend arrays containing singular vectors and/or singular values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html - """ - return self.jnp.linalg.svd( - A, - full_matrices=full_matrices, - compute_uv=compute_uv, - hermitian=hermitian, - subset_by_index=subset_by_index, - ) - - def cholesky( - self, - A: DenseArray, - *, - upper: bool = False, - symmetrize_input: bool = True, - ) -> DenseArray: - """ - Compute Cholesky factors using JAX. - - Input: - A: Dense Hermitian positive-definite backend array. - - Output: - Dense backend array containing a triangular factor. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.cholesky.html - """ - return self.jnp.linalg.cholesky(A, upper=upper, symmetrize_input=symmetrize_input) - - def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, - return_sign: bool = False, where: DenseArray | None = None) -> DenseArray | Tuple[DenseArray, DenseArray]: - """ - Compute a stable log-sum-exp reduction using JAX. - - Input: - a: Dense backend array; axis, weights, and sign options control the reduction. - - Output: - Dense backend array or tuple containing log-sum-exp results. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.scipy.special.logsumexp.html - """ - return self.jax.scipy.special.logsumexp(a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign, where=where) - - def exp(self, x: DenseArray) -> DenseArray: - """ - Compute exponentials elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of exponentials. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.exp.html - """ - return self.jnp.exp(x) - - def log(self, x: DenseArray) -> DenseArray: - """ - Compute natural logarithms elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of logarithms. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.log.html - """ - return self.jnp.log(x) - - def maximum(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Compute elementwise maxima using JAX. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing maxima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.maximum.html - """ - return self.jnp.maximum(x, y) - - def minimum(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Compute elementwise minima using JAX. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing minima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.minimum.html - """ - return self.jnp.minimum(x, y) - - def clip(self, x: DenseArray, a_min: DenseArray, a_max: DenseArray) -> DenseArray: - """ - Clip values into an interval using JAX. - - Input: - x: Dense backend array; a_min and a_max: Broadcastable bounds. - - Output: - Dense backend array with clipped values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.clip.html - """ - return self.jnp.clip(x, a_min, a_max) - - def isfinite(self, x: DenseArray) -> DenseArray: - """ - Test finiteness elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isfinite.html - """ - return self.jnp.isfinite(x) - - def isnan(self, x: DenseArray) -> DenseArray: - """ - Test NaN values elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isnan.html - """ - return self.jnp.isnan(x) - - def where(self, condition: DenseArray | bool, x: DenseArray | None = None, y: DenseArray | None = None, *, - size: int | None = None, fill_value: DenseArray | None = None) -> DenseArray: - """ - Select values by condition using JAX. - - Input: - condition: Boolean array or scalar; x and y: Values to choose between. - - Output: - Dense backend array containing selected values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.where.html - """ - return self.jnp.where(condition, x, y, size=size, fill_value=fill_value) - - def concatenate(self, arrays: Sequence[DenseArray], axis: int = 0, dtype: DType | None = None) -> DenseArray: - """ - Join arrays along an existing axis using JAX. - - Input: - arrays: Sequence of dense backend arrays; axis and dtype options are backend-specific. - - Output: - Dense backend array containing concatenated inputs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.concatenate.html - """ - if dtype is None: - return self.jnp.concatenate(arrays, axis=axis) - return self.jnp.concatenate(arrays, axis=axis, dtype=dtype) - - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - out: None = None, - mode: str | None = None, - unique_indices: bool = False, - indices_are_sorted: bool = False, - fill_value: Any | None = None, - ) -> DenseArray: - """ - Take values by integer indices using JAX. - - Input: - x: Dense backend array; indices: Integer indices; axis and mode options are backend-specific. - - Output: - Dense backend array containing selected values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.take.html - - Backend-specific notes: - Out-of-bounds and mode behavior follow JAX, which can differ from NumPy. - """ - return self.jnp.take( - x, - indices, - axis=axis, - out=out, - mode=mode, - unique_indices=unique_indices, - indices_are_sorted=indices_are_sorted, - fill_value=fill_value, - ) - - def diag(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Extract or build a diagonal using JAX. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array containing a diagonal view/copy or matrix. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.diag.html - """ - return self.jnp.diag(x, k=k) - - def diagonal( - self, - x: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - ) -> DenseArray: - """ - Return selected diagonals using JAX. - - Input: - x: Dense backend array plus offset and axis controls. - - Output: - Dense backend array containing selected diagonals. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.diagonal.html - """ - return self.jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) - - def tril(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return lower-triangular values using JAX. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with upper entries zeroed. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.tril.html - """ - return self.jnp.tril(x, k=k) - - def triu(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return upper-triangular values using JAX. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with lower entries zeroed. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.triu.html - """ - return self.jnp.triu(x, k=k) - - def index_set(self, x: DenseArray, index: Index, values: ArrayLike, *, copy: bool = True): - """ - Set indexed values using JAX. - - Input: - x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. - - Output: - Dense backend array with indexed values set. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html - - Backend-specific notes: - JAX arrays are immutable; copy=False raises NotImplementedError. - """ - if not copy: - raise NotImplementedError( - "JAX arrays are immutable; copy=False is not supported." - ) - return x.at[index].set(values) - - def ix_(self, *args: Any) -> Any: - """ - Build open mesh index arrays using JAX. - - Input: - args: One-dimensional index arrays or sequences. - - Output: - Tuple of dense backend arrays usable for open-mesh indexing. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ix\\_.html - """ - return self.jnp.ix_(*args) - - def fori_loop( - self, - lower: int, - upper: int, - body_fun: Callable[[int, T], T], - init_val: T, - *, - unroll: int | bool | None = None, - ) -> T: - """ - Run a counted loop primitive using JAX. - - Input: - lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html - - Backend-specific notes: - Loop bounds and unroll behavior follow JAX tracing and compilation rules. - """ - return self.jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll) - - def while_loop( - self, - cond_fun: Callable[[T], bool], - body_fun: Callable[[T], T], - init_val: T, - ) -> T: - """ - Run a while-loop primitive using JAX. - - Input: - cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html - - Backend-specific notes: - Condition and body are staged according to JAX lax control-flow semantics. - """ - return self.jax.lax.while_loop(cond_fun, body_fun, init_val) - - def scan( - self, - f: Callable[[Carry, X], Tuple[Carry, Y]], - init: Carry, - xs: X, - length: Optional[int] = None, - reverse: bool = False, - unroll: int = 1, - _split_transpose: bool = False, - ) -> Tuple[Carry, Y]: - """ - Run a scan primitive using JAX. - - Input: - f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. - - Output: - Tuple of final carry and stacked outputs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html - - Backend-specific notes: - Inputs and outputs may be pytrees and are staged according to JAX lax.scan semantics. - """ - return self.jax.lax.scan(f, init, xs, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose) - - def cond( - self, - pred: bool, - true_fun: Callable[[T], R], - false_fun: Callable[[T], R], - *operands: Any, - ) -> R: - """ - Run conditional branch selection using JAX. - - Input: - pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. - - Output: - Result returned by the selected branch. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.cond.html - - Backend-specific notes: - Branches are staged according to JAX lax.cond semantics rather than Python eager branching. - """ - return self.jax.lax.cond(pred, true_fun, false_fun, *operands) - - def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): - """ - Add into indexed values using JAX. - - Input: - x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. - - Output: - Dense backend array with indexed values incremented. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html - - Backend-specific notes: - JAX arrays are immutable; copy=False raises NotImplementedError and repeated indices follow JAX scatter-add semantics. + Backend-specific notes: + JAX arrays are immutable; copy=False raises NotImplementedError and repeated indices follow JAX scatter-add semantics. """ if not copy: raise NotImplementedError( @@ -1844,28 +403,6 @@ def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bo ) return x.at[index].add(values) - def allclose( - self, - a: DenseArray, - b: DenseArray, - rtol: float = 1e-5, - atol: float = 1e-8, - equal_nan: bool = False, - ) -> bool: - """ - Compare dense arrays elementwise within tolerances using JAX. - - Input: - a, b: Dense backend arrays; rtol, atol, and equal_nan configure comparison. - - Output: - Boolean indicating whether arrays are close. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.allclose.html - """ - return self.jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - def allclose_sparse( self, a: SparseArray, diff --git a/spacecore/backend/numpy/_ops.py b/spacecore/backend/numpy/_ops.py index 0a309ad..6fc6263 100644 --- a/spacecore/backend/numpy/_ops.py +++ b/spacecore/backend/numpy/_ops.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any, Sequence, Tuple, Literal, Callable, Optional, Type -import inspect from .._family import BackendFamily from .._ops import BackendOps @@ -49,6 +48,7 @@ class NumpyOps(BackendOps): """ import numpy as np import scipy as sp + import array_api_compat.numpy as xp _family = BackendFamily.numpy.value.lower() _allow_sparse = True @@ -86,1311 +86,86 @@ def sparse_array(self) -> Tuple[Type[Any], ...]: return tuple(types) - def __init__(self) -> None: - self._reshape_supports_copy = "copy" in inspect.signature(self.np.reshape).parameters - - def sanitize_dtype(self, dtype: DType | None) -> DType: - """ - Normalize a dtype specifier using NumPy. - - Input: - dtype: Optional dtype requested by SpaceCore or the caller. - - Output: - Backend dtype object accepted by array constructors. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.dtype.html - """ - - if dtype is None: - return self.np.float64 - return self.np.dtype(dtype) - - def get_dtype(self, x: Any) -> DType: - """ - Return an array dtype using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Backend dtype associated with x. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html - """ - if self.is_dense(x): - return x.dtype - elif self.is_sparse(x): - return x.dtype - else: - raise TypeError(f'Expected Numpy ndarray or SciPy sparse array, got {type(x)}.') - - def shape(self, x: Any) -> tuple[int, ...]: - """ - Return array shape metadata using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Tuple describing the logical shape of x. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html - """ - return tuple(x.shape) - - def ndim(self, x: Any) -> int: - """ - Return array rank metadata using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Number of dimensions in x. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html - """ - return int(x.ndim) - - def size(self, x: Any) -> int: - """ - Return logical element count using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Total number of logical dense elements. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.size.html - - Backend-specific notes: - SciPy sparse inputs are reported by logical dense size, not stored entries. - """ - return int(self.np.prod(self.shape(x), dtype=self.np.intp)) - - @property - def inf(self): - """ - Positive infinity scalar using NumPy. - - Returns: - Backend scalar representing positive infinity. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.inf) - - @property - def nan(self): - """ - NaN scalar using NumPy. - - Returns: - Backend scalar representing NaN. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.nan) - - @property - def pi(self): - """ - Pi scalar using NumPy. - - Returns: - Backend scalar representing pi. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.pi) - - @property - def e(self): - """ - Euler number scalar using NumPy. - - Returns: - Backend scalar representing Euler's number. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.e) - - @property - def eps(self): - """ - Machine epsilon scalar using NumPy. - - Returns: - Backend scalar for float64 machine epsilon. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.finfo.html - """ - return self.np.array(self.np.finfo(self.np.float64).eps) - - def asarray( - self, - a: Any, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] | None = None, - *, - device: str | None = None, - copy: bool | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Convert input to a dense array using NumPy. - - Input: - x/a: Array-like input and optional dtype or backend conversion parameters. - - Output: - Dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.asarray.html - """ - return self.np.asarray( - a, - dtype=dtype, - order=order, - device=device, - copy=copy, - like=like, - ) - - def astype( - self, - x: DenseArray, - dtype: DType, - order: Literal["C", "F", "A", "K"] = "K", - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "unsafe", - subok: bool = True, - copy: bool = True, - ) -> DenseArray: - """ - Cast an array to a dtype using NumPy. - - Input: - x: Dense backend array; dtype: target dtype and optional casting controls. - - Output: - Dense backend array with the requested dtype. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.astype.html - """ - return x.astype(dtype, order=order, casting=casting, subok=subok, copy=copy) - - def assparse(self, x: Any, *, format: Literal["csr", "csc", "coo"] = "csr", dtype: DType | None = None) -> SparseArray: - """ - Convert input to a sparse array using SciPy. - - Input: - x: Dense, sparse, or array-like input plus sparse-format options. - - Output: - Sparse backend array. - - See: - https://docs.scipy.org/doc/scipy/reference/sparse.html - - Backend-specific notes: - SpaceCore currently converts dense inputs to 2-D SciPy sparse matrices in the requested format. - """ - sparse = self.sp.sparse - - if self.is_sparse(x): - if format == "csr": - return x.tocsr() - if format == "csc": - return x.tocsc() - if format == "coo": - return x.tocoo() - raise ValueError(f"Unknown sparse format: {format!r}") - - x_arr = self.asarray(x) - - if x_arr.ndim != 2: - raise ValueError("NumPy/SciPy sparse conversion currently expects a 2D array.") - - if format == "csr": - return sparse.csr_matrix(x_arr) - if format == "csc": - return sparse.csc_matrix(x_arr) - if format == "coo": - return sparse.coo_matrix(x_arr) - - raise ValueError(f"Unknown sparse format: {format!r}") - - def empty( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = float, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create an uninitialized dense array using NumPy. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array with uninitialized values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.empty.html - """ - - return self.np.empty( - shape, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def zeros( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a zero-filled dense array using NumPy. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array filled with zeros. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.zeros.html - """ - return self.np.zeros( - shape, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def ones( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a one-filled dense array using NumPy. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array filled with ones. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ones.html - """ - return self.np.ones( - shape, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def zeros_like( - self, - x: DenseArray, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - subok: bool = True, - shape: int | Tuple[int, ...] | None = None, - ) -> DenseArray: - """ - Create zeros shaped like another array using NumPy. - - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. - - Output: - Dense backend array of zeros. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.zeros_like.html - """ - return self.np.zeros_like(x, dtype=dtype, order=order, subok=subok, shape=shape) - - def ones_like( - self, - x: DenseArray, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - subok: bool = True, - shape: int | Tuple[int, ...] | None = None, - ) -> DenseArray: - """ - Create ones shaped like another array using NumPy. - - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. - - Output: - Dense backend array of ones. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ones_like.html - """ - return self.np.ones_like(x, dtype=dtype, order=order, subok=subok, shape=shape) - - def full_like( - self, - x: DenseArray, - value: Any, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - subok: bool = True, - shape: int | Tuple[int, ...] | None = None, - ) -> DenseArray: - """ - Create filled values shaped like another array using NumPy. - - Input: - x: Prototype dense array; value/fill_value and dtype options are backend-specific. - - Output: - Dense backend array filled with the requested value. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.full_like.html - """ - return self.np.full_like(x, value, dtype=dtype, order=order, subok=subok, shape=shape) - - def arange(self, - start: int, stop: int | None = None, - step: int | None = None, - dtype: DType | None = None, - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create evenly spaced integer-range values using NumPy. - - Input: - start, stop, step: Range parameters; dtype and placement options are backend-specific. - - Output: - One-dimensional dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.arange.html - """ - return self.np.arange( - start, - stop, - step, - dtype=dtype, - device=device, - like=like, - ) - - def full( - self, - shape: int | Tuple[int, ...], - fill_value: Any, - dtype: DType | None = None, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a filled dense array using NumPy. - - Input: - shape: Output shape; fill_value and dtype options are backend-specific. - - Output: - Dense backend array filled with fill_value. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.full.html - """ - return self.np.full( - shape, - fill_value, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def eye( - self, - N: int, - M: int | None = None, - k: int = 0, - dtype: DType | None = float, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a dense identity-like matrix using NumPy. - - Input: - n and optional m: Matrix dimensions; dtype and placement options are backend-specific. - - Output: - Two-dimensional dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.eye.html - """ - return self.np.eye( - N, - M=M, - k=k, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def ravel(self, a: DenseArray, order: Literal["C", "F", "A", "K"] = "C") -> DenseArray: - """ - Flatten an array using NumPy. - - Input: - x: Dense backend array plus optional order parameters. - - Output: - One-dimensional dense backend array view or copy. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ravel.html - """ - return self.np.ravel(a, order=order) - - def reshape( - self, - a: DenseArray, - shape: int | Tuple[int, ...], - order: Literal["C", "F", "A", "K"] = "C", - copy: bool | None = None, - ) -> DenseArray: - """ - Reshape an array using NumPy. - - Input: - x: Dense backend array; shape: New shape plus backend-specific options. - - Output: - Dense backend array with the requested shape. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.reshape.html - """ - if self._reshape_supports_copy: - return self.np.reshape(a, shape, order=order, copy=copy) - if copy: - return self.np.array(a, copy=True).reshape(shape, order=order) - return self.np.reshape(a, shape, order=order) - - def transpose(self, a: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: - """ - Permute array axes using NumPy. - - Input: - x: Dense backend array; axes: Optional axis order. - - Output: - Dense backend array with permuted axes. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.transpose.html - """ - return self.np.transpose(a, axes=axes) - - def swapaxes(self, a: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Interchange two axes using NumPy. - - Input: - x: Dense backend array; axis1 and axis2: Axes to swap. - - Output: - Dense backend array with the two axes exchanged. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.swapaxes.html - """ - return self.np.swapaxes(a, axis1, axis2) - - def broadcast_to( - self, - x: DenseArray, - shape: int | Tuple[int, ...], - subok: bool = False, - ) -> DenseArray: - """ - Broadcast an array to a shape using NumPy. - - Input: - x: Dense backend array; shape: Target broadcast shape. - - Output: - Dense backend array with broadcast shape. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html - """ - return self.np.broadcast_to(x, shape, subok=subok) - - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Insert length-one axes using NumPy. - - Input: - x: Dense backend array; axis: Position or positions to insert. - - Output: - Dense backend array with expanded rank. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html - """ - return self.np.expand_dims(x, axis=axis) - - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Remove length-one axes using NumPy. - - Input: - x: Dense backend array; axis: Optional axes to squeeze. - - Output: - Dense backend array with selected singleton dimensions removed. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html - """ - return self.np.squeeze(x, axis=axis) - - def moveaxis( - self, - x: DenseArray, - source: int | Sequence[int], - destination: int | Sequence[int], - ) -> DenseArray: - """ - Move axes to new positions using NumPy. - - Input: - x: Dense backend array; source and destination: Axis positions. - - Output: - Dense backend array with moved axes. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.moveaxis.html - """ - return self.np.moveaxis(x, source=source, destination=destination) - - def stack( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - out: DenseArray | None = None, - *, - dtype: DType | None = None, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - ) -> DenseArray: - """ - Stack arrays along a new axis using NumPy. - - Input: - arrays: Sequence of dense backend arrays; axis: New axis position. - - Output: - Dense backend array containing stacked inputs. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.stack.html - """ - return self.np.stack(arrays, axis=axis, out=out, dtype=dtype, casting=casting) - - def conj( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute complex conjugates using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array with conjugated values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.conj.html - """ - return self.np.conj( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def abs( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute absolute values using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of absolute values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.abs.html - """ - return self.np.abs( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def sign( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute signs elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of signs. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sign.html - """ - return self.np.sign( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def sqrt( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute square roots elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of square roots. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html - """ - return self.np.sqrt( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def real(self, x: DenseArray) -> DenseArray: - """ - Extract real components using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing real components. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.real.html - """ - return self.np.real(x) - - def imag(self, x: DenseArray) -> DenseArray: - """ - Extract imaginary components using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing imaginary components. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.imag.html - """ - return self.np.imag(x) - - def sum( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by summation using NumPy. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing sums. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sum.html - """ - return self.np.sum( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - ) - - def mean( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by arithmetic mean using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing means. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.mean.html - """ - return self.np.mean(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where) - - def min( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by minimum using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing minima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.min.html - """ - return self.np.min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def max( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by maximum using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing maxima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.max.html - """ - return self.np.max(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def prod( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by product using NumPy. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing products. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.prod.html - """ - return self.np.prod( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - ) - - def trace( - self, - a: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - dtype: DType | None = None, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Sum diagonal entries using NumPy. - - Input: - x: Dense backend array plus optional diagonal and axis controls. - - Output: - Dense backend array or scalar containing trace values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.trace.html - """ - return self.np.trace( - a, - offset=offset, - axis1=axis1, - axis2=axis2, - dtype=dtype, - out=out, - ) - - def argsort( - self, - a: DenseArray, - axis: int = -1, - kind: Literal["quicksort", "mergesort", "heapsort", "stable"] | None = None, - order: str | Sequence[str] | None = None, - *, - stable: bool | None = None, - ) -> DenseArray: - """ - Return sorting indices using NumPy. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense integer backend array of indices. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.argsort.html - """ - return self.np.argsort(a, axis=axis, kind=kind, order=order, stable=stable) - - def sort( - self, - a: DenseArray, - axis: int = -1, - kind: Literal["quicksort", "mergesort", "heapsort", "stable"] | None = None, - order: str | Sequence[str] | None = None, - *, - stable: bool | None = None, - ) -> DenseArray: - """ - Sort values using NumPy. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense backend array with sorted values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sort.html - """ - return self.np.sort(a, axis=axis, kind=kind, order=order, stable=stable) - - def argmin( - self, - a: DenseArray, - axis: int | None = None, - out: DenseArray | None = None, - *, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of minimum values using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.argmin.html - """ - return self.np.argmin(a, axis=axis, out=out, keepdims=keepdims) - - def argmax( - self, - a: DenseArray, - axis: int | None = None, - out: DenseArray | None = None, - *, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of maximum values using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.argmax.html - """ - return self.np.argmax(a, axis=axis, out=out, keepdims=keepdims) - - def vdot(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Compute a conjugating vector dot product using NumPy. - - Input: - x, y: Dense backend arrays accepted by the backend vdot operation. - - Output: - Backend scalar or dense array containing the dot product. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.vdot.html - """ - return self.np.vdot(a, b) - - def matmul( - self, - a: DenseArray, - b: DenseArray, - /, - out: DenseArray | None = None, - *, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute matrix products using NumPy. - - Input: - a, b: Dense backend arrays with matrix-multiplication-compatible shapes. - - Output: - Dense backend array containing the product. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.matmul.html - """ - return self.np.matmul( - a, - b, - out=out, - casting=casting, - order=order, - dtype=dtype, - subok=subok - ) - - def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: - """ - Multiply sparse and dense arrays using SciPy. - - Input: - a: Sparse backend array; b: Dense backend array. - - Output: - Dense backend array containing the product. - - See: - https://docs.scipy.org/doc/scipy/reference/sparse.html - - Backend-specific notes: - Uses SciPy sparse multiplication before returning a dense NumPy result when applicable. - """ - if not self.is_sparse(a): - raise TypeError("sparse_matmul expects `a` to be a SciPy sparse matrix/array.") - if not self.is_dense(b): - raise TypeError("sparse_matmul expects `b` to be a Numpy dense object.") - return a @ b - - def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Compute a Kronecker product using NumPy. - - Input: - a, b: Dense backend arrays. - - Output: - Dense backend array containing the Kronecker product. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.kron.html - """ - return self.np.kron(a, b) - - def einsum( - self, - subscripts: str, - *operands: DenseArray, - out: DenseArray | None = None, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "safe", - optimize: bool | str | Sequence[Any] = False, - ) -> DenseArray: + def sanitize_dtype(self, dtype: DType | None) -> DType: """ - Evaluate an Einstein summation expression using NumPy. + Normalize a dtype specifier using NumPy. Input: - subscripts: Einstein summation string; operands: Dense backend arrays. + dtype: Optional dtype requested by SpaceCore or the caller. Output: - Dense backend array containing the contraction result. + Backend dtype object accepted by array constructors. See: - https://numpy.org/doc/stable/reference/generated/numpy.einsum.html - """ - return self.np.einsum( - subscripts, - *operands, - out=out, - dtype=dtype, - order=order, - casting=casting, - optimize=optimize, - ) - - def eigh(self, a: DenseArray, UPLO: Literal["L", "U"] = "L") -> Tuple[DenseArray, DenseArray]: + https://numpy.org/doc/stable/reference/generated/numpy.dtype.html """ - Compute Hermitian eigenpairs using NumPy. - - Input: - x: Dense Hermitian or symmetric backend array. - Output: - Tuple of dense backend arrays containing eigenvalues and eigenvectors. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.eigh.html - """ - if self.is_sparse(a): - raise TypeError("eigh requires a dense array; sparse input is not supported.") - return self.np.linalg.eigh(a, UPLO=UPLO) + if dtype is None: + return self.np.float64 + return self.np.dtype(dtype) - def norm( - self, - x: DenseArray, - ord: int | str | None = None, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: + def assparse(self, x: Any, *, format: Literal["csr", "csc", "coo"] = "csr", dtype: DType | None = None) -> SparseArray: """ - Compute vector or matrix norms using NumPy. + Convert input to a sparse array using SciPy. Input: - x: Dense backend array; ord, axis, and keepdims select the norm. + x: Dense, sparse, or array-like input plus sparse-format options. Output: - Dense backend array or scalar containing norm values. + Sparse backend array. See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html - """ - return self.np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) + https://docs.scipy.org/doc/scipy/reference/sparse.html - def solve(self, A: DenseArray, b: DenseArray) -> DenseArray: + Backend-specific notes: + SpaceCore currently converts dense inputs to 2-D SciPy sparse matrices in the requested format. """ - Solve dense linear systems using NumPy. - - Input: - A: Dense coefficient array; b: Dense right-hand side array. - - Output: - Dense backend array solving A @ x = b. + sparse = self.sp.sparse - See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html - """ - return self.np.linalg.solve(A, b) + if self.is_sparse(x): + if format == "csr": + return x.tocsr() + if format == "csc": + return x.tocsc() + if format == "coo": + return x.tocoo() + raise ValueError(f"Unknown sparse format: {format!r}") - def eigvalsh(self, A: DenseArray, UPLO: Literal["L", "U"] = "L") -> DenseArray: - """ - Compute Hermitian eigenvalues using NumPy. + x_arr = self.asarray(x) - Input: - A: Dense Hermitian or symmetric backend array. + if x_arr.ndim != 2: + raise ValueError("NumPy/SciPy sparse conversion currently expects a 2D array.") - Output: - Dense backend array containing eigenvalues. + if format == "csr": + return sparse.csr_matrix(x_arr) + if format == "csc": + return sparse.csc_matrix(x_arr) + if format == "coo": + return sparse.coo_matrix(x_arr) - See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.eigvalsh.html - """ - return self.np.linalg.eigvalsh(A, UPLO=UPLO) + raise ValueError(f"Unknown sparse format: {format!r}") - def svd( - self, - A: DenseArray, - full_matrices: bool = True, - compute_uv: bool = True, - hermitian: bool = False, - ) -> DenseArray | Tuple[DenseArray, DenseArray, DenseArray]: + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: """ - Compute singular value decompositions using NumPy. + Multiply sparse and dense arrays using SciPy. Input: - A: Dense backend array plus SVD options. + a: Sparse backend array; b: Dense backend array. Output: - Dense backend arrays containing singular vectors and/or singular values. + Dense backend array containing the product. See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html - """ - return self.np.linalg.svd( - A, - full_matrices=full_matrices, - compute_uv=compute_uv, - hermitian=hermitian, - ) - - def cholesky(self, A: DenseArray) -> DenseArray: - """ - Compute Cholesky factors using NumPy. - - Input: - A: Dense Hermitian positive-definite backend array. - - Output: - Dense backend array containing a triangular factor. + https://docs.scipy.org/doc/scipy/reference/sparse.html - See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.cholesky.html + Backend-specific notes: + Uses SciPy sparse multiplication before returning a dense NumPy result when applicable. """ - return self.np.linalg.cholesky(A) + if not self.is_sparse(a): + raise TypeError("sparse_matmul expects `a` to be a SciPy sparse matrix/array.") + if not self.is_dense(b): + raise TypeError("sparse_matmul expects `b` to be a Numpy dense object.") + return a @ b def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: @@ -1408,362 +183,6 @@ def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: D """ return self.sp.special.logsumexp(a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign) - def exp( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute exponentials elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of exponentials. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.exp.html - """ - return self.np.exp( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def log( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute natural logarithms elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of logarithms. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.log.html - """ - return self.np.log( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def maximum( - self, - x: DenseArray, - y: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute elementwise maxima using NumPy. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing maxima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.maximum.html - """ - return self.np.maximum( - x, - y, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def minimum( - self, - x: DenseArray, - y: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute elementwise minima using NumPy. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing minima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.minimum.html - """ - return self.np.minimum( - x, - y, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def clip( - self, - x: DenseArray, - a_min: DenseArray, - a_max: DenseArray, - out: DenseArray | None = None, - **kwargs: Any, - ) -> DenseArray: - """ - Clip values into an interval using NumPy. - - Input: - x: Dense backend array; a_min and a_max: Broadcastable bounds. - - Output: - Dense backend array with clipped values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.clip.html - """ - return self.np.clip(x, a_min, a_max, out=out, **kwargs) - - def isfinite( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Test finiteness elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html - """ - return self.np.isfinite( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def isnan( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Test NaN values elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.isnan.html - """ - return self.np.isnan( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def where(self, condition: DenseArray | bool, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Select values by condition using NumPy. - - Input: - condition: Boolean array or scalar; x and y: Values to choose between. - - Output: - Dense backend array containing selected values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.where.html - """ - return self.np.where(condition, x, y) - - def concatenate( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - out: DenseArray | None = None, - *, - dtype: DType | None = None, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - ) -> DenseArray: - """ - Join arrays along an existing axis using NumPy. - - Input: - arrays: Sequence of dense backend arrays; axis and dtype options are backend-specific. - - Output: - Dense backend array containing concatenated inputs. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html - """ - return self.np.concatenate(arrays, axis=axis, out=out, dtype=dtype, casting=casting) - - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - out: DenseArray | None = None, - mode: Literal["raise", "wrap", "clip"] = "raise", - ) -> DenseArray: - """ - Take values by integer indices using NumPy. - - Input: - x: Dense backend array; indices: Integer indices; axis and mode options are backend-specific. - - Output: - Dense backend array containing selected values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.take.html - """ - return self.np.take(x, indices, axis=axis, out=out, mode=mode) - - def diag(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Extract or build a diagonal using NumPy. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array containing a diagonal view/copy or matrix. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.diag.html - """ - return self.np.diag(x, k=k) - - def diagonal( - self, - x: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - ) -> DenseArray: - """ - Return selected diagonals using NumPy. - - Input: - x: Dense backend array plus offset and axis controls. - - Output: - Dense backend array containing selected diagonals. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.diagonal.html - """ - return self.np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) - - def tril(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return lower-triangular values using NumPy. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with upper entries zeroed. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.tril.html - """ - return self.np.tril(x, k=k) - - def triu(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return upper-triangular values using NumPy. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with lower entries zeroed. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.triu.html - """ - return self.np.triu(x, k=k) - def index_set(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): """ Set indexed values using NumPy. @@ -2035,28 +454,6 @@ def index_add( self.np.add.at(y, index, values) return y - def allclose( - self, - a: DenseArray, - b: DenseArray, - rtol: float = 1e-5, - atol: float = 1e-8, - equal_nan: bool = False, - ) -> bool: - """ - Compare dense arrays elementwise within tolerances using NumPy. - - Input: - a, b: Dense backend arrays; rtol, atol, and equal_nan configure comparison. - - Output: - Boolean indicating whether arrays are close. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.allclose.html - """ - return self.np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - def allclose_sparse( self, a: SparseArray, diff --git a/spacecore/backend/torch/_ops.py b/spacecore/backend/torch/_ops.py index cd6defb..82b2154 100644 --- a/spacecore/backend/torch/_ops.py +++ b/spacecore/backend/torch/_ops.py @@ -5,8 +5,8 @@ import numpy as np from .._family import BackendFamily -from .._ops import BackendOps -from ...types import ArrayLike, DenseArray, DType, Index, SparseArray, T, X, Y, R, Carry +from .._ops import BackendOps, LazyNamespace +from ...types import DenseArray, DType, Index, SparseArray, T, X, Y, R, Carry class TorchOps(BackendOps): @@ -52,6 +52,7 @@ class TorchOps(BackendOps): """ import torch + xp = LazyNamespace("array_api_compat.torch") _family = BackendFamily.torch.value.lower() _allow_sparse = True @@ -180,193 +181,6 @@ def sanitize_dtype(self, dtype: DType | None) -> DType: return mapping[np_dtype] raise TypeError(f"Dtype {np_dtype!r} is not supported by PyTorch.") - def get_dtype(self, x: Any) -> DType: - """ - Return a tensor dtype using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Backend dtype associated with x. - - See: - https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch-dtype - """ - if self.is_array(x): - return x.dtype - raise TypeError(f"Expected PyTorch tensor, got {type(x)}.") - - def shape(self, x: Any) -> tuple[int, ...]: - """ - Return tensor shape metadata using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Tuple describing the logical shape of x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.Tensor.shape.html - """ - return tuple(x.shape) - - def ndim(self, x: Any) -> int: - """ - Return tensor rank metadata using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Number of dimensions in x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.Tensor.ndim.html - """ - return int(x.ndim) - - def size(self, x: Any) -> int: - """ - Return logical element count using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Total number of logical dense elements. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.numel.html - """ - return int(x.numel()) - - @property - def inf(self): - """ - Positive infinity scalar using PyTorch. - - Returns: - Backend tensor scalar representing positive infinity. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(float("inf")) - - @property - def nan(self): - """ - NaN scalar using PyTorch. - - Returns: - Backend tensor scalar representing NaN. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(float("nan")) - - @property - def pi(self): - """ - Pi scalar using PyTorch. - - Returns: - Backend tensor scalar representing pi. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(np.pi) - - @property - def e(self): - """ - Euler number scalar using PyTorch. - - Returns: - Backend tensor scalar representing Euler's number. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(np.e) - - @property - def eps(self): - """ - Machine epsilon scalar using PyTorch. - - Returns: - Backend tensor scalar for float64 machine epsilon. - - See: - https://docs.pytorch.org/docs/stable/type_info.html#torch.finfo - """ - return self.torch.tensor(self.torch.finfo(self.torch.float64).eps) - - def asarray( - self, - x: Any, - dtype: DType | None = None, - *, - device: Any | None = None, - copy: bool | None = None, - ) -> DenseArray: - """ - Convert input to a dense tensor using PyTorch. - - Input: - x/a: Array-like input and optional dtype, device, or copy controls. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.as_tensor.html - - Backend-specific notes: - Sparse tensors are densified. Existing tensors keep autograd - metadata according to normal PyTorch conversion rules. - """ - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - if self.is_sparse(x): - x = x.to_dense() - out = self.torch.as_tensor(x, dtype=dtype, device=device) - if copy: - out = out.clone() - return out - - def astype( - self, - x: DenseArray, - dtype: DType, - copy: bool = True, - *, - non_blocking: bool = False, - memory_format: Any | None = None, - ) -> DenseArray: - """ - Cast a tensor to a dtype using PyTorch. - - Input: - x: Dense backend tensor; dtype: Target dtype; copy: Whether to force a copy. - - Output: - Tensor with the requested dtype. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html - """ - return x.to( - dtype=self.sanitize_dtype(dtype), - non_blocking=non_blocking, - copy=copy, - **self._defined_kwargs(memory_format=memory_format), - ) - def assparse( self, x: Any, @@ -427,72 +241,67 @@ def assparse( return out.to_sparse_csc() raise ValueError(f"Unknown sparse format: {format!r}") - def empty( + def asarray( self, - shape: int | Tuple[int, ...], + x: Any, dtype: DType | None = None, *, - out: DenseArray | None = None, - layout: Any | None = None, device: Any | None = None, - requires_grad: bool = False, - pin_memory: bool = False, - memory_format: Any | None = None, + copy: bool | None = None, + backend_kwargs: dict[str, Any] | None = None, ) -> DenseArray: - """ - Create an uninitialized dense tensor using PyTorch. - - Input: - shape: Output shape; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor with uninitialized values. + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + if device is not None: + kwargs["device"] = device + dtype = self.sanitize_dtype(dtype) if dtype is not None else None + if self.is_sparse(x): + x = x.to_dense() + out = self.torch.as_tensor(x, dtype=dtype, **kwargs) + return out.clone() if copy else out - See: - https://docs.pytorch.org/docs/stable/generated/torch.empty.html - """ - return self.torch.empty( - shape, - out=out, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - pin_memory=pin_memory, - **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), + def astype( + self, + x: DenseArray, + dtype: DType, + *, + copy: bool = True, + non_blocking: bool = False, + memory_format: Any | None = None, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(memory_format=memory_format)) + return x.to( + dtype=self.sanitize_dtype(dtype), + non_blocking=non_blocking, + copy=copy, + **kwargs, ) - def zeros( + def empty( self, - shape: int | Tuple[int, ...], + shape: Tuple[int, ...], dtype: DType | None = None, *, out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, + pin_memory: bool = False, + memory_format: Any | None = None, ) -> DenseArray: - """ - Create a dense tensor filled with zeros using PyTorch. - - Input: - shape: Output shape; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.zeros.html - """ - return self.torch.zeros( + return self.torch.empty( shape, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device), + pin_memory=pin_memory, + **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), ) - def ones( + def zeros( self, - shape: int | Tuple[int, ...], + shape: Tuple[int, ...], dtype: DType | None = None, *, out: DenseArray | None = None, @@ -500,19 +309,7 @@ def ones( device: Any | None = None, requires_grad: bool = False, ) -> DenseArray: - """ - Create a dense tensor filled with ones using PyTorch. - - Input: - shape: Output shape; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.ones.html - """ - return self.torch.ones( + return self.torch.zeros( shape, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, @@ -530,18 +327,6 @@ def zeros_like( requires_grad: bool = False, memory_format: Any | None = None, ) -> DenseArray: - """ - Create a zero tensor matching another tensor using PyTorch. - - Input: - x: Reference tensor; dtype and device: Optional overrides. - - Output: - Dense backend tensor with shape matching x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.zeros_like.html - """ return self.torch.zeros_like( x, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, @@ -549,867 +334,93 @@ def zeros_like( **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), ) - def ones_like( + def arange( self, - x: DenseArray, + start: int, + stop: int | None = None, + step: int | None = None, dtype: DType | None = None, *, + out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, - memory_format: Any | None = None, ) -> DenseArray: - """ - Create a one tensor matching another tensor using PyTorch. - - Input: - x: Reference tensor; dtype and device: Optional overrides. - - Output: - Dense backend tensor with shape matching x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.ones_like.html - """ - return self.torch.ones_like( - x, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), - ) + kwargs = self._defined_kwargs(out=out, layout=layout, device=device) + kwargs["requires_grad"] = requires_grad + dtype = self.sanitize_dtype(dtype) if dtype is not None else None + if stop is None: + return self.torch.arange(start, dtype=dtype, **kwargs) + if step is None: + return self.torch.arange(start, stop, dtype=dtype, **kwargs) + return self.torch.arange(start, stop, step, dtype=dtype, **kwargs) - def full_like( + def sum( self, x: DenseArray, - value: Any, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, dtype: DType | None = None, *, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - memory_format: Any | None = None, + out: DenseArray | None = None, + ) -> DenseArray: + kwargs = {"dim": self._to_axis_tuple(axis), "keepdim": keepdims} + if dtype is not None: + kwargs["dtype"] = self.sanitize_dtype(dtype) + if out is not None: + kwargs["out"] = out + return self.torch.sum(x, **kwargs) + + def matmul( + self, + a: DenseArray, + b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + *, + out: DenseArray | None = None, + ) -> DenseArray: + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + if out is not None: + kwargs["out"] = out + return self.torch.matmul(a, b, **kwargs) + + def sparse_matmul( + self, + a: SparseArray, + b: DenseArray, + *, + reduce: Literal["sum", "mean", "amax", "amin"] = "sum", ) -> DenseArray: """ - Create a filled tensor matching another tensor using PyTorch. + Matrix-multiply a sparse tensor by a dense tensor using PyTorch. Input: - x: Reference tensor; value: Fill value; dtype and device: Optional overrides. + a: Sparse backend tensor; b: Dense backend tensor or vector. Output: - Dense backend tensor with shape matching x. + Dense backend tensor. See: - https://docs.pytorch.org/docs/stable/generated/torch.full_like.html - """ - return self.torch.full_like( - x, - value, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), - ) - - def arange( - self, - start: int | float = 0, - stop: int | float | None = None, - step: int | float | None = None, - dtype: DType | None = None, - *, - out: DenseArray | None = None, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - ) -> DenseArray: - """ - Create a range tensor using PyTorch. - - Input: - start, stop, step: Range parameters; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor containing evenly spaced values. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.arange.html - """ - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - kwargs = self._defined_kwargs(out=out, layout=layout, device=device) - kwargs["requires_grad"] = requires_grad - if stop is None: - return self.torch.arange(start, dtype=dtype, **kwargs) - if step is None: - return self.torch.arange(start, stop, dtype=dtype, **kwargs) - return self.torch.arange(start, stop, step, dtype=dtype, **kwargs) - - def full( - self, - shape: int | Tuple[int, ...], - fill_value: Any, - dtype: DType | None = None, - *, - out: DenseArray | None = None, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - ) -> DenseArray: - """ - Create a filled dense tensor using PyTorch. - - Input: - shape: Output shape; fill_value: Fill value; dtype and device: Optional parameters. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.full.html - """ - return self.torch.full( - shape, - fill_value, - out=out, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device), - ) - - def eye( - self, - n: int, - m: int | None = None, - k: int = 0, - dtype: DType | None = None, - *, - out: DenseArray | None = None, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - ) -> DenseArray: - """ - Create a two-dimensional identity-like tensor using PyTorch. - - Input: - n, m: Matrix dimensions; k: Diagonal offset; dtype and device: Optional parameters. - - Output: - Dense backend tensor with ones on the requested diagonal. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.eye.html - - Backend-specific notes: - PyTorch ``torch.eye`` has no diagonal offset parameter, so SpaceCore - constructs the offset diagonal explicitly. - """ - m = n if m is None else m - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - if k == 0: - return self.torch.eye( - n, - m, - out=out, - dtype=dtype, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device), - ) - out = self.torch.zeros( - (n, m), - out=out, - dtype=dtype, - requires_grad=False, - **self._defined_kwargs(layout=layout, device=device), - ) - diag_len = min(n, m - k) if k > 0 else min(n + k, m) - if diag_len <= 0: - return out - rows = self.torch.arange(diag_len, device=device) - cols = rows + k - if k < 0: - rows = rows - k - cols = self.torch.arange(diag_len, device=device) - out[rows, cols] = 1 - if requires_grad: - out.requires_grad_() - return out - - def ravel(self, x: DenseArray) -> DenseArray: - """ - Flatten a tensor using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - One-dimensional tensor view or copy following PyTorch semantics. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.ravel.html - """ - return self.torch.ravel(x) - - def reshape(self, x: DenseArray, shape: int | Tuple[int, ...], *, copy: bool | None = None) -> DenseArray: - """ - Reshape a tensor using PyTorch. - - Input: - x: Dense backend tensor; shape: Target shape; copy: Whether to clone first. - - Output: - Reshaped dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.reshape.html - """ - if copy: - x = x.clone() - return self.torch.reshape(x, shape if isinstance(shape, tuple) else (shape,)) - - def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: - """ - Permute tensor axes using PyTorch. - - Input: - x: Dense backend tensor; axes: Optional axis order. - - Output: - Tensor with permuted axes. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.permute.html - """ - if axes is None: - axes = tuple(reversed(range(x.ndim))) - return x.permute(tuple(axes)) - - def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Swap two tensor axes using PyTorch. - - Input: - x: Dense backend tensor; axis1, axis2: Axes to swap. - - Output: - Tensor with the requested axes swapped. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.swapaxes.html - """ - return self.torch.swapaxes(x, axis1, axis2) - - def broadcast_to(self, x: DenseArray, shape: int | Tuple[int, ...]) -> DenseArray: - """ - Broadcast a tensor to a shape using PyTorch. - - Input: - x: Dense backend tensor; shape: Target broadcast shape. - - Output: - Broadcasted tensor view following PyTorch broadcasting rules. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.broadcast_to.html - """ - return self.torch.broadcast_to(x, shape) - - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Insert singleton dimensions using PyTorch. - - Input: - x: Dense backend tensor; axis: Axis or axes where dimensions are inserted. - - Output: - Tensor with inserted singleton dimensions. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.unsqueeze.html - """ - if isinstance(axis, int): - return self.torch.unsqueeze(x, axis) - ndim = x.ndim + len(axis) - axes = sorted(a + ndim if a < 0 else a for a in axis) - out = x - for ax in axes: - out = self.torch.unsqueeze(out, ax) - return out - - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Remove singleton dimensions using PyTorch. - - Input: - x: Dense backend tensor; axis: Optional axis or axes to squeeze. - - Output: - Tensor with singleton dimensions removed. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.squeeze.html - """ - if axis is None: - return self.torch.squeeze(x) - if isinstance(axis, int): - return self.torch.squeeze(x, dim=axis) - out = x - for ax in sorted(axis, reverse=True): - out = self.torch.squeeze(out, dim=ax) - return out - - def moveaxis(self, x: DenseArray, source: int | Sequence[int], destination: int | Sequence[int]) -> DenseArray: - """ - Move tensor axes to new positions using PyTorch. - - Input: - x: Dense backend tensor; source and destination: Axis positions. - - Output: - Tensor with axes moved. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.moveaxis.html - """ - return self.torch.moveaxis(x, source, destination) - - def stack( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Stack tensors along a new axis using PyTorch. - - Input: - arrays: Sequence of tensors; axis: New axis; out: Optional output tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.stack.html - """ - arrays = tuple(arrays) - if out is None: - return self.torch.stack(arrays, dim=axis) - return self.torch.stack(arrays, dim=axis, out=out) - - def conj(self, x: DenseArray) -> DenseArray: - """ - Return the complex conjugate using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Tensor containing complex conjugates. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.conj.html - """ - return self.torch.conj(x) - - def real(self, x: DenseArray) -> DenseArray: - """ - Return the real part of a tensor using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Tensor view or value containing real components. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.real.html - """ - return self.torch.real(x) - - def imag(self, x: DenseArray) -> DenseArray: - """ - Return the imaginary part of a tensor using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Tensor view or value containing imaginary components. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.imag.html - """ - return self.torch.imag(x) - - def abs( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise absolute value using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.abs.html - """ - if out is None: - return self.torch.abs(x) - return self.torch.abs(x, out=out) - - def sign( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise sign using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sign.html - """ - if out is None: - return self.torch.sign(x) - return self.torch.sign(x, out=out) - - def sqrt( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise square root using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sqrt.html - """ - if out is None: - return self.torch.sqrt(x) - return self.torch.sqrt(x, out=out) - - def sum( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Sum tensor elements using PyTorch. - - Input: - x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sum.html - """ - if dtype is None: - if out is None: - if axis is None and not keepdims: - return self.torch.sum(x) - return self.torch.sum(x, dim=axis, keepdim=keepdims) - return self.torch.sum(x, dim=axis, keepdim=keepdims, out=out) - dtype = self.sanitize_dtype(dtype) - if out is None: - if axis is None and not keepdims: - return self.torch.sum(x, dtype=dtype) - return self.torch.sum(x, dim=axis, keepdim=keepdims, dtype=dtype) - return self.torch.sum( - x, - dim=axis, - keepdim=keepdims, - dtype=dtype, - out=out, - ) - - def mean( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Average tensor elements using PyTorch. - - Input: - x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.mean.html - """ - if dtype is None: - if out is None: - if axis is None and not keepdims: - return self.torch.mean(x) - return self.torch.mean(x, dim=axis, keepdim=keepdims) - return self.torch.mean(x, dim=axis, keepdim=keepdims, out=out) - dtype = self.sanitize_dtype(dtype) - if out is None: - if axis is None and not keepdims: - return self.torch.mean(x, dtype=dtype) - return self.torch.mean(x, dim=axis, keepdim=keepdims, dtype=dtype) - return self.torch.mean( - x, - dim=axis, - keepdim=keepdims, - dtype=dtype, - out=out, - ) - - def min( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute minimum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.amin.html - """ - if out is None: - if axis is None and not keepdims: - return self.torch.amin(x) - return self.torch.amin(x, dim=axis, keepdim=keepdims) - return self.torch.amin(x, dim=axis, keepdim=keepdims, out=out) - - def max( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute maximum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.amax.html - """ - if out is None: - if axis is None and not keepdims: - return self.torch.amax(x) - return self.torch.amax(x, dim=axis, keepdim=keepdims) - return self.torch.amax(x, dim=axis, keepdim=keepdims, out=out) - - def prod( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Multiply tensor elements using PyTorch. - - Input: - x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.prod.html - - Backend-specific notes: - Multiple-axis products are applied one axis at a time because - PyTorch's ``torch.prod`` reduces a single dimension per call. - """ - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - if axis is None: - result = self.torch.prod(x) if dtype is None else self.torch.prod(x, dtype=dtype) - if out is not None: - out.copy_(result) - return out - return result - if isinstance(axis, int): - if out is None: - if dtype is None: - return self.torch.prod(x, dim=axis, keepdim=keepdims) - return self.torch.prod(x, dim=axis, dtype=dtype, keepdim=keepdims) - return self.torch.prod(x, dim=axis, dtype=dtype, keepdim=keepdims, out=out) - result = x - for ax in sorted(axis, reverse=True): - result = self.torch.prod(result, dim=ax, dtype=dtype, keepdim=keepdims) - if out is not None: - out.copy_(result) - return out - return result - - def trace( - self, - x: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - dtype: DType | None = None, - ) -> DenseArray: - """ - Sum diagonal entries using PyTorch. - - Input: - x: Dense backend tensor; offset, axis1, axis2, dtype: Diagonal controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.diagonal.html - """ - return self.sum(self.diagonal(x, offset=offset, axis1=axis1, axis2=axis2), dtype=dtype) - - def argsort( - self, - x: DenseArray, - axis: int = -1, - stable: bool = False, - descending: bool = False, - ) -> DenseArray: - """ - Return sorting indices using PyTorch. - - Input: - x: Dense backend tensor; axis, stable, descending: Sorting controls. - - Output: - Integer tensor of indices. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.argsort.html - """ - return self.torch.argsort(x, dim=axis, stable=stable, descending=descending) - - def sort( - self, - x: DenseArray, - axis: int = -1, - stable: bool = False, - descending: bool = False, - *, - out: tuple[DenseArray, DenseArray] | None = None, - ) -> DenseArray: - """ - Sort tensor values using PyTorch. - - Input: - x: Dense backend tensor; axis, stable, descending: Sorting controls. - - Output: - Dense backend tensor of sorted values. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sort.html - """ - return self.torch.sort(x, dim=axis, stable=stable, descending=descending, out=out).values - - def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Return indices of minimum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Integer tensor of indices. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.argmin.html - """ - return self.torch.argmin(x, dim=axis, keepdim=keepdims) - - def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Return indices of maximum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Integer tensor of indices. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.argmax.html - """ - return self.torch.argmax(x, dim=axis, keepdim=keepdims) - - def vdot( - self, - x: DenseArray, - y: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute conjugating vector dot product using PyTorch. - - Input: - x, y: Dense backend tensors. - - Output: - Scalar tensor containing the vector dot product. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.vdot.html - """ - x1 = x if x.ndim == 1 else self.torch.ravel(x) - y1 = y if y.ndim == 1 else self.torch.ravel(y) - if out is None: - return self.torch.vdot(x1, y1) - return self.torch.vdot(x1, y1, out=out) - - def matmul( - self, - a: DenseArray, - b: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Matrix-multiply tensors using PyTorch. - - Input: - a, b: Dense backend tensors. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.matmul.html - """ - if out is None: - return self.torch.matmul(a, b) - return self.torch.matmul(a, b, out=out) - - def sparse_matmul( - self, - a: SparseArray, - b: DenseArray, - *, - reduce: Literal["sum", "mean", "amax", "amin"] = "sum", - ) -> DenseArray: - """ - Matrix-multiply a sparse tensor by a dense tensor using PyTorch. - - Input: - a: Sparse backend tensor; b: Dense backend tensor or vector. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sparse.mm.html + https://docs.pytorch.org/docs/stable/generated/torch.sparse.mm.html """ kwargs = {"reduce": reduce} if reduce != "sum" else {} if b.ndim == 1: return self.torch.sparse.mm(a, b[:, None], **kwargs)[:, 0] return self.torch.sparse.mm(a, b, **kwargs) - def kron( - self, - a: DenseArray, - b: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute the Kronecker product using PyTorch. - - Input: - a, b: Dense backend tensors. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.kron.html - """ - return self.torch.kron(a, b, out=out) - - def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: - """ - Evaluate an Einstein summation using PyTorch. - - Input: - subscripts: Einsum expression; operands: Dense backend tensors. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.einsum.html - """ - return self.torch.einsum(subscripts, *operands) - def eigh( self, x: DenseArray, + backend_kwargs: dict[str, Any] | None = None, UPLO: Literal["L", "U"] = "L", *, out: tuple[DenseArray, DenseArray] | None = None, ) -> tuple[DenseArray, DenseArray]: - """ - Compute Hermitian eigenvalues and eigenvectors using PyTorch. - - Input: - x: Dense Hermitian or symmetric backend tensor. - - Output: - Tuple of eigenvalues and eigenvectors. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.eigh.html - """ if self.is_sparse(x): raise TypeError("eigh requires a dense array; sparse input is not supported.") - return self.torch.linalg.eigh(x, UPLO=UPLO, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(out=out)) + return self.torch.linalg.eigh(x, UPLO=UPLO, **kwargs) def norm( self, @@ -1421,18 +432,6 @@ def norm( dtype: DType | None = None, out: DenseArray | None = None, ) -> DenseArray: - """ - Compute vector or matrix norms using PyTorch. - - Input: - x: Dense backend tensor; ord, axis, keepdims: Norm controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.norm.html - """ return self.torch.linalg.norm( x, ord=ord, @@ -1446,97 +445,39 @@ def solve( self, A: DenseArray, b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, *, left: bool = True, out: DenseArray | None = None, ) -> DenseArray: - """ - Solve a linear system using PyTorch. - - Input: - A: Coefficient tensor; b: Right-hand side tensor. - - Output: - Dense backend tensor solving ``A @ x = b``. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.solve.html - """ - return self.torch.linalg.solve(A, b, left=left, out=out) - - def eigvalsh( - self, - A: DenseArray, - UPLO: Literal["L", "U"] = "L", - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute Hermitian eigenvalues using PyTorch. - - Input: - A: Dense Hermitian or symmetric backend tensor. - - Output: - Dense backend tensor of eigenvalues. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.eigvalsh.html - """ - return self.torch.linalg.eigvalsh(A, UPLO=UPLO, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(out=out)) + return self.torch.linalg.solve(A, b, left=left, **kwargs) def svd( self, A: DenseArray, full_matrices: bool = True, - compute_uv: bool = True, - hermitian: bool = False, + backend_kwargs: dict[str, Any] | None = None, *, driver: str | None = None, out: DenseArray | tuple[DenseArray, DenseArray, DenseArray] | None = None, ) -> DenseArray | tuple[DenseArray, DenseArray, DenseArray]: - """ - Compute singular value decomposition using PyTorch. - - Input: - A: Dense backend tensor; full_matrices, compute_uv, hermitian: SVD controls. - - Output: - Singular values or tuple ``(U, S, Vh)``. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.svd.html - - Backend-specific notes: - PyTorch does not expose a ``hermitian`` option for SVD. When - ``compute_uv`` is false, this delegates to ``torch.linalg.svdvals``. - """ - if hermitian: - raise NotImplementedError("PyTorch svd does not expose a hermitian option.") - if not compute_uv: - return self.torch.linalg.svdvals(A, driver=driver, out=out) - return self.torch.linalg.svd(A, full_matrices=full_matrices, driver=driver, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(driver=driver, out=out)) + return self.torch.linalg.svd(A, full_matrices=full_matrices, **kwargs) def cholesky( self, A: DenseArray, + backend_kwargs: dict[str, Any] | None = None, *, upper: bool = False, out: DenseArray | None = None, ) -> DenseArray: - """ - Compute a Cholesky factorization using PyTorch. - - Input: - A: Positive-definite dense backend tensor. - - Output: - Dense backend tensor containing the Cholesky factor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.cholesky.html - """ - return self.torch.linalg.cholesky(A, upper=upper, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(out=out)) + return self.torch.linalg.cholesky(A, upper=upper, **kwargs) def logsumexp( self, @@ -1583,188 +524,18 @@ def logsumexp( return out return result - def exp( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise exponential using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.exp.html - """ - if out is None: - return self.torch.exp(x) - return self.torch.exp(x, out=out) - - def log( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise natural logarithm using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.log.html - """ - if out is None: - return self.torch.log(x) - return self.torch.log(x, out=out) - def where( self, condition: DenseArray | bool, - x: ArrayLike | None = None, - y: ArrayLike | None = None, + x: DenseArray, + y: DenseArray, *, out: DenseArray | None = None, ) -> DenseArray: - """ - Select values conditionally using PyTorch. - - Input: - condition: Boolean tensor or scalar; x, y: Values to select. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.where.html - """ - if x is None and y is None: - return self.torch.where(condition) - if x is None or y is None: - raise TypeError("where requires both x and y when either is provided.") if out is None: return self.torch.where(condition, x, y) return self.torch.where(condition, x, y, out=out) - def maximum( - self, - x: ArrayLike, - y: ArrayLike, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise maximum using PyTorch. - - Input: - x, y: Array-like operands. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.maximum.html - """ - y = y if isinstance(y, self.torch.Tensor) else self.asarray(y, dtype=x.dtype, device=x.device) - if out is None: - return self.torch.maximum(x, y) - return self.torch.maximum( - x, - y, - out=out, - ) - - def minimum( - self, - x: ArrayLike, - y: ArrayLike, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise minimum using PyTorch. - - Input: - x, y: Array-like operands. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.minimum.html - """ - y = y if isinstance(y, self.torch.Tensor) else self.asarray(y, dtype=x.dtype, device=x.device) - if out is None: - return self.torch.minimum(x, y) - return self.torch.minimum( - x, - y, - out=out, - ) - - def clip( - self, - x: DenseArray, - a_min: ArrayLike | None = None, - a_max: ArrayLike | None = None, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Clip tensor values using PyTorch. - - Input: - x: Dense backend tensor; a_min, a_max: Lower and upper bounds. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.clamp.html - """ - if out is None: - return self.torch.clamp(x, min=a_min, max=a_max) - return self.torch.clamp(x, min=a_min, max=a_max, out=out) - - def isfinite(self, x: DenseArray) -> DenseArray: - """ - Test finiteness elementwise using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Boolean dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.isfinite.html - """ - return self.torch.isfinite(x) - - def isnan(self, x: DenseArray) -> DenseArray: - """ - Test NaN values elementwise using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Boolean dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.isnan.html - """ - return self.torch.isnan(x) - def concatenate( self, arrays: Sequence[DenseArray], @@ -1773,131 +544,12 @@ def concatenate( *, out: DenseArray | None = None, ) -> DenseArray: - """ - Concatenate tensors using PyTorch. - - Input: - arrays: Sequence of dense backend tensors; axis: Concatenation axis; dtype: Optional cast. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.cat.html - """ - arrays = tuple(arrays) if out is None: - result = self.torch.cat(arrays, dim=axis) + result = self.torch.cat(tuple(arrays), dim=axis) else: - result = self.torch.cat(arrays, dim=axis, out=out) + result = self.torch.cat(tuple(arrays), dim=axis, out=out) return self.astype(result, dtype) if dtype is not None else result - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Take tensor elements by index using PyTorch. - - Input: - x: Dense backend tensor; indices: Integer indices; axis: Optional selection axis. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.take.html - """ - if axis is None: - result = self.torch.take(x, indices) - if out is not None: - out.copy_(result) - return out - return result - return self.torch.index_select(x, dim=axis, index=indices, out=out) - - def diag( - self, - x: DenseArray, - k: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Extract or construct a diagonal tensor using PyTorch. - - Input: - x: Dense backend tensor; k: Diagonal offset. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.diag.html - """ - return self.torch.diag(x, diagonal=k, out=out) - - def diagonal(self, x: DenseArray, offset: int = 0, axis1: int = 0, axis2: int = 1) -> DenseArray: - """ - Return a tensor diagonal using PyTorch. - - Input: - x: Dense backend tensor; offset, axis1, axis2: Diagonal controls. - - Output: - Dense backend tensor view or value. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.diagonal.html - """ - return self.torch.diagonal(x, offset=offset, dim1=axis1, dim2=axis2) - - def tril( - self, - x: DenseArray, - k: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Return the lower triangular part using PyTorch. - - Input: - x: Dense backend tensor; k: Diagonal offset. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tril.html - """ - return self.torch.tril(x, diagonal=k, out=out) - - def triu( - self, - x: DenseArray, - k: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Return the upper triangular part using PyTorch. - - Input: - x: Dense backend tensor; k: Diagonal offset. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.triu.html - """ - return self.torch.triu(x, diagonal=k, out=out) - def index_set(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): """ Set indexed tensor values using PyTorch. @@ -2106,21 +758,6 @@ def cond(self, pred: bool, true_fun: Callable[[T], R], false_fun: Callable[[T], """ return true_fun(*operands) if bool(pred) else false_fun(*operands) - def allclose(self, a: DenseArray, b: DenseArray, rtol: float = 1e-5, atol: float = 1e-8, equal_nan: bool = False) -> bool: - """ - Compare dense tensors elementwise within tolerances using PyTorch. - - Input: - a, b: Dense backend tensors; rtol, atol, equal_nan: Comparison controls. - - Output: - Boolean indicating whether tensors are close. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.allclose.html - """ - return bool(self.torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) - def allclose_sparse(self, a: SparseArray, b: SparseArray, rtol: float = 1e-5, atol: float = 1e-8) -> bool: """ Compare sparse tensors elementwise within tolerances using PyTorch. diff --git a/tests/backend/test_backend_ops_delegation.py b/tests/backend/test_backend_ops_delegation.py new file mode 100644 index 0000000..c9e58d5 --- /dev/null +++ b/tests/backend/test_backend_ops_delegation.py @@ -0,0 +1,183 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import ( + has_jax, + has_torch, + jax_complex_dtype, + jax_real_dtype, + to_numpy, + torch_complex_dtype, + torch_real_dtype, +) + + +DELEGATED_METHODS = ( + "reshape", + "sum", + "eigh", + "trace", + "concatenate", + "transpose", + "matmul", +) + +TORCH_AAC_DELEGATED_METHODS = ( + "mean", + "prod", + "sort", + "argsort", + "argmin", + "argmax", + "clip", + "take", + "diagonal", + "squeeze", +) + + +def test_numpy_ops_inherits_common_delegated_methods(): + sc = importlib.import_module("spacecore") + + assert sc.NumpyOps.xp.__name__ == "array_api_compat.numpy" + for name in DELEGATED_METHODS: + assert getattr(sc.NumpyOps, name) is getattr(sc.BackendOps, name) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_ops_inherits_common_delegated_methods(): + sc = importlib.import_module("spacecore") + + assert sc.JaxOps.xp is sc.JaxOps.jnp + for name in DELEGATED_METHODS: + assert getattr(sc.JaxOps, name) is getattr(sc.BackendOps, name) + + +def test_torch_ops_uses_aac_namespace_when_available(): + sc = importlib.import_module("spacecore") + + assert sc.TorchOps.xp.__name__ == "array_api_compat.torch" + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_ops_inherits_aac_delegated_methods(): + sc = importlib.import_module("spacecore") + + for name in TORCH_AAC_DELEGATED_METHODS: + assert getattr(sc.TorchOps, name) is getattr(sc.BackendOps, name) + + +def _check_raw_delegated_ops(ops, dtype): + x = ops.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + y = ops.reshape(x, (3, 2)) + h = ops.asarray([[2.0, 1.0], [1.0, 3.0]], dtype=dtype) + cube = ops.reshape(ops.arange(24, dtype=dtype), (2, 3, 4)) + singleton = ops.reshape(ops.arange(6, dtype=dtype), (1, 2, 1, 3)) + + np.testing.assert_allclose(to_numpy(ops.matmul(h, ops.asarray([1.0, 2.0], dtype=dtype))), [4.0, 7.0]) + np.testing.assert_allclose(to_numpy(ops.reshape(x, (6,))), np.arange(1.0, 7.0)) + np.testing.assert_allclose(to_numpy(ops.sum(x, axis=0)), [5.0, 7.0, 9.0]) + np.testing.assert_allclose(to_numpy(ops.sum(cube, axis=(0, 2))), np.arange(24.0).reshape(2, 3, 4).sum(axis=(0, 2))) + np.testing.assert_allclose(to_numpy(ops.prod(cube + 1, axis=(0, 2))), (np.arange(24.0).reshape(2, 3, 4) + 1).prod(axis=(0, 2))) + np.testing.assert_allclose(to_numpy(ops.mean(cube, axis=(0, 2))), np.arange(24.0).reshape(2, 3, 4).mean(axis=(0, 2))) + assert ops.shape(ops.squeeze(singleton)) == (2, 3) + np.testing.assert_allclose(to_numpy(ops.trace(h)), 5.0) + np.testing.assert_allclose(to_numpy(ops.concatenate((x, x), axis=0)), np.concatenate((to_numpy(x), to_numpy(x)), axis=0)) + np.testing.assert_allclose(to_numpy(ops.transpose(y)), to_numpy(y).T) + + evals, evecs = ops.eigh(h) + np.testing.assert_allclose(to_numpy(evecs @ ops.diag(evals) @ ops.transpose(evecs)), to_numpy(h), atol=1e-6) + + +def test_numpy_raw_delegated_ops(): + sc = importlib.import_module("spacecore") + + _check_raw_delegated_ops(sc.NumpyOps(), np.float64) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_raw_delegated_ops(): + sc = importlib.import_module("spacecore") + + _check_raw_delegated_ops(sc.JaxOps(), jax_real_dtype()) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_raw_delegated_ops(): + sc = importlib.import_module("spacecore") + + _check_raw_delegated_ops(sc.TorchOps(), torch_real_dtype()) + + +def test_numpy_eps_uses_default_dtype(): + sc = importlib.import_module("spacecore") + ops = sc.NumpyOps() + + np.testing.assert_allclose(to_numpy(ops.eps), np.finfo(ops.sanitize_dtype(None)).eps) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_eps_uses_default_dtype(): + sc = importlib.import_module("spacecore") + ops = sc.JaxOps() + + np.testing.assert_allclose(to_numpy(ops.eps), np.finfo(jax_real_dtype()).eps) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_eps_uses_default_dtype(): + sc = importlib.import_module("spacecore") + import torch + + ops = sc.TorchOps() + np.testing.assert_allclose(to_numpy(ops.eps), torch.finfo(ops.sanitize_dtype(None)).eps) + + +def _check_complex_adjoint(ops, dtype): + sc = importlib.import_module("spacecore") + ctx = sc.Context(ops, dtype=dtype) + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = ctx.asarray( + [ + [1.0 + 2.0j, 3.0 - 1.0j], + [-2.0 + 0.5j, 0.25 + 4.0j], + [1.5 - 3.0j, -0.75 + 2.0j], + ] + ) + x = ctx.asarray([2.0 - 1.0j, -0.5 + 0.25j]) + y = ctx.asarray([1.0 + 0.5j, -2.0j, 0.75 - 1.25j]) + op = sc.DenseLinOp(A, dom, cod, ctx) + + lhs = ctx.ops.vdot(op.apply(x), y) + rhs = ctx.ops.vdot(x, op.rapply(y)) + np.testing.assert_allclose(to_numpy(lhs), to_numpy(rhs), rtol=1e-6, atol=1e-6) + + H = sc.HermitianSpace(2, ctx=ctx) + raw = ctx.asarray([[1.0 + 0.0j, 2.0 + 3.0j], [-1.0 + 4.0j, -0.5 + 0.0j]]) + herm = H.symmetrize(raw) + evals, evecs = H.eigh(herm) + rebuilt = (evecs * evals) @ ctx.ops.conj(ctx.ops.transpose(evecs)) + np.testing.assert_allclose(to_numpy(rebuilt), to_numpy(herm), rtol=1e-6, atol=1e-6) + + +def test_numpy_complex_adjoint_and_hermitian_eigh(): + sc = importlib.import_module("spacecore") + + _check_complex_adjoint(sc.NumpyOps(), np.complex128) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_complex_adjoint_and_hermitian_eigh(): + sc = importlib.import_module("spacecore") + + _check_complex_adjoint(sc.JaxOps(), jax_complex_dtype()) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_complex_adjoint_and_hermitian_eigh(): + sc = importlib.import_module("spacecore") + + _check_complex_adjoint(sc.TorchOps(), torch_complex_dtype()) diff --git a/tests/backend/test_backend_registry.py b/tests/backend/test_backend_registry.py index f55aea9..65ca47a 100644 --- a/tests/backend/test_backend_registry.py +++ b/tests/backend/test_backend_registry.py @@ -14,55 +14,23 @@ def test_register_ops_adds_backend(): sc = importlib.import_module("spacecore") class DummyOps(sc.BackendOps): + import array_api_compat.numpy as xp + _family = "dummy" _allow_sparse = False - _dense_array = np.ndarray - _sparse_array = object + + @property + def dense_array(self): + return np.ndarray + + @property + def sparse_array(self): + return None + def sanitize_dtype(self, dtype): return np.dtype(dtype) if dtype is not None else np.dtype(np.float64) - def asarray(self, x, dtype=None): return np.asarray(x, dtype=dtype) def assparse(self, x, dtype=None): raise NotImplementedError - def is_dense(self, x): return isinstance(x, np.ndarray) - def is_sparse(self, x): return False - def get_dtype(self, x): return x.dtype - def zeros(self, shape, dtype=None): return np.zeros(shape, dtype=dtype) - def ones(self, shape, dtype=None): return np.ones(shape, dtype=dtype) - def full(self, shape, fill_value, dtype=None): return np.full(shape, fill_value, dtype=dtype) - def empty(self, shape, dtype=None): return np.empty(shape, dtype=dtype) - def eye(self, n, dtype=None): return np.eye(n, dtype=dtype) - def arange(self, *args, **kwargs): return np.arange(*args, **kwargs) - def reshape(self, x, shape): return np.reshape(x, shape) - def ravel(self, x): return np.ravel(x) - def transpose(self, x, axes=None): return np.transpose(x, axes=axes) - def swapaxes(self, x, axis1, axis2): return np.swapaxes(x, axis1, axis2) - def conj(self, x): return np.conj(x) - def sum(self, x, axis=None, **kwargs): return np.sum(x, axis=axis) - def prod(self, x, axis=None, **kwargs): return np.prod(x, axis=axis) - def trace(self, x, **kwargs): return np.trace(x) - def argsort(self, x, **kwargs): return np.argsort(x) - def sort(self, x, **kwargs): return np.sort(x) - def argmin(self, x, **kwargs): return np.argmin(x) - def argmax(self, x, **kwargs): return np.argmax(x) - def vdot(self, a, b, **kwargs): return np.vdot(a, b) - def matmul(self, a, b, **kwargs): return a @ b def sparse_matmul(self, a, b): raise NotImplementedError - def kron(self, a, b): return np.kron(a, b) - def einsum(self, subscripts, *operands, **kwargs): return np.einsum(subscripts, *operands) - def eigh(self, x, **kwargs): return np.linalg.eigh(x) def logsumexp(self, *args, **kwargs): raise NotImplementedError - def exp(self, x): return np.exp(x) - def log(self, x): return np.log(x) - def maximum(self, x, y): return np.maximum(x, y) - def minimum(self, x, y): return np.minimum(x, y) - def where(self, condition, x=None, y=None, **kwargs): return np.where(condition, x, y) - def concatenate(self, arrays, axis=0, dtype=None): - out = np.concatenate(arrays, axis=axis) - return out.astype(dtype) if dtype is not None else out - def stack(self, arrays, axis=0): return np.stack(arrays, axis=axis) - def sqrt(self, x): return np.sqrt(x) - def abs(self, x): return np.abs(x) - def real(self, x): return np.real(x) - def imag(self, x): return np.imag(x) - def sign(self, x): return np.sign(x) def index_set(self, x, index, values, copy=True): y = np.array(x, copy=True) @@ -90,10 +58,12 @@ def index_add(self, x, index, values, copy=True): y=np.array(x, copy=True) y[index]+=values return y - def allclose(self, a, b, **kwargs): return np.allclose(a, b, **kwargs) def allclose_sparse(self, a, b, **kwargs): return False sc.register_ops(DummyOps) assert "dummy" in sc._contextual.manager.ctx_manager.available_ops + ops = DummyOps() + x = ops.reshape(ops.arange(6), (2, 3)) + assert np.allclose(ops.sum(x, axis=0), [3, 5, 7]) @pytest.mark.skipif(not has_torch(), reason="torch is not installed") From 80abf4d08ef8be9d24af64f1aa1e96ac16339a1f Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 19 May 2026 18:51:06 -0300 Subject: [PATCH 02/44] Polish BackendOps behavior for v0.2 - make backend ops instances hashable - improve backend ops repr with family information - make astype(x, None) a no-op across backends - cache scalar constants per backend ops instance - convert eps to an explicit dtype-based method - document and test complex vdot conjugation semantics - simplify BackendOps docstrings around the AAC delegation contract --- spacecore/backend/_ops.py | 370 +++++++------------ spacecore/backend/jax/_ops.py | 3 + spacecore/backend/numpy/_ops.py | 3 + spacecore/backend/torch/_ops.py | 7 +- tests/backend/test_backend_ops_delegation.py | 61 ++- tests/test_backend_ops_complex.py | 33 ++ 6 files changed, 230 insertions(+), 247 deletions(-) create mode 100644 tests/test_backend_ops_complex.py diff --git a/spacecore/backend/_ops.py b/spacecore/backend/_ops.py index b2f3ad1..09e2f89 100644 --- a/spacecore/backend/_ops.py +++ b/spacecore/backend/_ops.py @@ -22,51 +22,36 @@ def _load(self) -> Any: def __getattr__(self, name: str) -> Any: return getattr(self._load(), name) + @property + def is_loaded(self) -> bool: + return self._module is not None + class BackendOps(ABC): """ - Backend-agnostic numerical ops interface (portable core). + Public numerical contract for SpaceCore backends. - Contract: - - This base class exposes only the portable subset used by library internals. - - Concrete backends (NumPy/JAX/Torch) may extend these methods with additional - optional keyword parameters (e.g., `order=`, `out=`, `where=`, `like=`, ...). + Common dense-array operations delegate to the backend's Array API-compatible + ``xp`` namespace. Subclasses provide backend-specific sparse conversion, + dtype policy, indexing mutation, control flow, device/autograd behavior, and + operations not covered by the Array API. """ _family: ClassVar[str] _allow_sparse: ClassVar[bool] xp: ClassVar[Any] + def __init__(self) -> None: + self._constant_cache: dict[str, DenseArray] = {} + @property def family(self) -> str: - """ - Generic backend-agnostic wrapper to backend family identifier. - - Input: - None. - - Output: - String naming the concrete backend family. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Backend family identifier.""" return type(self)._family @property def allow_sparse(self) -> bool: - """ - Generic backend-agnostic wrapper to sparse-array support flag. - - Input: - None. - - Output: - Boolean indicating whether this backend supports sparse arrays. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Whether this backend supports sparse arrays.""" return self._allow_sparse def __eq__(self, other: Any) -> bool: @@ -74,148 +59,52 @@ def __eq__(self, other: Any) -> bool: return self.family == other.family return False + def __hash__(self) -> int: + return hash((type(self).__name__, self.family)) + @property @abstractmethod def dense_array(self) -> Type[Any]: - """ - Generic backend-agnostic wrapper to dense array type. - - Input: - None. - - Output: - Concrete dense array class accepted by this backend. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Dense array type accepted by this backend.""" ... @property @abstractmethod def sparse_array(self) -> Tuple[Type[Any], ...] | None: - """ - Generic backend-agnostic wrapper to sparse array type tuple. - - Input: - None. - - Output: - Concrete sparse array classes accepted by this backend, or None. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Sparse array types accepted by this backend, or None.""" ... @abstractmethod def sanitize_dtype(self, dtype: DType | None) -> DType: - """ - Generic backend-agnostic wrapper to normalize a dtype specifier. - - Input: - dtype: Optional dtype requested by SpaceCore or the caller. - - Output: - Backend dtype object accepted by array constructors. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Normalize a dtype specifier for this backend.""" ... def is_dense(self, x: Any) -> bool: - """ - Generic backend-agnostic wrapper to test for a dense backend array. - - Input: - x: Object to test. - - Output: - True when x is an instance of the backend dense array type. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return whether x is a dense array for this backend.""" return isinstance(x, self.dense_array) def is_sparse(self, x: Any) -> bool: - """ - Generic backend-agnostic wrapper to test for a sparse backend array. - - Input: - x: Object to test. - - Output: - True when x is an instance of a backend sparse array type. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return whether x is a sparse array for this backend.""" return self.sparse_array is not None and isinstance(x, self.sparse_array) def is_array(self, x: Any) -> bool: - """ - Generic backend-agnostic wrapper to test for any backend array. - - Input: - x: Object to test. - - Output: - True when x is dense or sparse for this backend. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return whether x is any array for this backend.""" return self.is_dense(x) or self.is_sparse(x) @abstractmethod def assparse(self, x: Any, dtype: DType | None = None) -> SparseArray: - """ - Generic backend-agnostic wrapper to convert input to a sparse array. - - Input: - x: Dense, sparse, or array-like input plus sparse-format options. - - Output: - Sparse backend array. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Convert input to a backend sparse array.""" ... @abstractmethod def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to multiply sparse and dense arrays. - - Input: - a: Sparse backend array; b: Dense backend array. - - Output: - Dense backend array containing the product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Multiply a sparse array by a dense array.""" ... @abstractmethod def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: - """ - Generic backend-agnostic wrapper to compute a stable log-sum-exp reduction. - - Input: - a: Dense backend array; axis, weights, and sign options control the reduction. - - Output: - Dense backend array or tuple containing log-sum-exp results. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Compute a stable log-sum-exp reduction.""" ... @abstractmethod @@ -227,18 +116,7 @@ def index_set( *, copy: bool = True, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to set indexed values. - - Input: - x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. - - Output: - Dense backend array with indexed values set. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Set indexed values using backend mutation semantics.""" @abstractmethod def index_add( @@ -249,34 +127,12 @@ def index_add( *, copy: bool = True, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to add into indexed values. - - Input: - x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. - - Output: - Dense backend array with indexed values incremented. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Add values into indexed positions using backend mutation semantics.""" ... @abstractmethod def ix_(self, *args: Any) -> Any: - """ - Generic backend-agnostic wrapper to build open mesh index arrays. - - Input: - args: One-dimensional index arrays or sequences. - - Output: - Tuple of dense backend arrays usable for open-mesh indexing. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Build open-mesh index arrays.""" ... @abstractmethod @@ -287,18 +143,7 @@ def fori_loop( body_fun: Callable[[int, T], T], init_val: T, ) -> T: - """ - Generic backend-agnostic wrapper to run a counted loop primitive. - - Input: - lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Run a counted loop primitive.""" @abstractmethod def while_loop( @@ -307,18 +152,7 @@ def while_loop( body_fun: Callable[[T], T], init_val: T, ) -> T: - """ - Generic backend-agnostic wrapper to run a while-loop primitive. - - Input: - cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Run a while-loop primitive.""" @abstractmethod def scan( @@ -330,18 +164,7 @@ def scan( reverse: bool = False, unroll: int = 1, ) -> Tuple[Carry, Y]: - """ - Generic backend-agnostic wrapper to run a scan primitive. - - Input: - f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. - - Output: - Tuple of final carry and stacked outputs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Run a scan primitive.""" @abstractmethod def cond( @@ -351,18 +174,7 @@ def cond( false_fun: Callable[[T], R], *operands: Any, ) -> R: - """ - Generic backend-agnostic wrapper to run conditional branch selection. - - Input: - pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. - - Output: - Result returned by the selected branch. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Run backend-compatible conditional branch selection.""" ... @abstractmethod @@ -373,18 +185,7 @@ def allclose_sparse( rtol: float = 1e-5, atol: float = 1e-8, ) -> bool: - """ - Generic backend-agnostic wrapper to compare sparse arrays elementwise within tolerances. - - Input: - a, b: Sparse backend arrays; rtol and atol configure comparison. - - Output: - Boolean indicating whether sparse arrays are close. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Compare sparse arrays elementwise within tolerances.""" ... def _dtype_arg(self, dtype: DType | None) -> DType | None: @@ -420,42 +221,56 @@ def _move_axis_order( @property def inf(self) -> DenseArray: - return self.asarray(float("inf")) + """Positive infinity as a cached backend scalar.""" + return self._constant("inf", float("inf")) @property def nan(self) -> DenseArray: - return self.asarray(float("nan")) + """NaN as a cached backend scalar.""" + return self._constant("nan", float("nan")) @property def pi(self) -> DenseArray: - return self.asarray(3.141592653589793) + """Pi as a cached backend scalar.""" + return self._constant("pi", 3.141592653589793) @property def e(self) -> DenseArray: - return self.asarray(2.718281828459045) + """Euler's number as a cached backend scalar.""" + return self._constant("e", 2.718281828459045) - @property - def eps(self) -> DenseArray: - return self.asarray(self.xp.finfo(self.sanitize_dtype(None)).eps) + def _constant(self, name: str, value: float) -> DenseArray: + if name not in self._constant_cache: + self._constant_cache[name] = self.asarray(value) + return self._constant_cache[name] + + def eps(self, dtype: DType) -> float: + """Machine epsilon for dtype.""" + return float(self.xp.finfo(self.sanitize_dtype(dtype)).eps) def get_dtype(self, x: Any) -> DType: + """Return x.dtype after verifying x is a backend array.""" if self.is_array(x): return x.dtype raise TypeError(f"Expected {self.family} array, got {type(x)}.") def shape(self, x: Any) -> tuple[int, ...]: + """Return x.shape as a tuple.""" return tuple(x.shape) def ndim(self, x: Any) -> int: + """Return the number of dimensions of x.""" return int(x.ndim) def size(self, x: Any) -> int: + """Return the total number of elements in x.""" result = 1 for dim in self.shape(x): result *= int(dim) return result def asarray(self, x: Any, dtype: DType | None = None, **backend_kwargs: Any) -> DenseArray: + """Convert input to a dense backend array (delegates to xp.asarray).""" if self.is_sparse(x) and hasattr(x, "to_dense"): x = x.to_dense() dtype = self._dtype_arg(dtype) @@ -463,28 +278,37 @@ def asarray(self, x: Any, dtype: DType | None = None, **backend_kwargs: Any) -> return self.xp.asarray(x, dtype=dtype, **backend_kwargs) return self.xp.as_tensor(x, dtype=dtype, **backend_kwargs) - def astype(self, x: DenseArray, dtype: DType, **backend_kwargs: Any) -> DenseArray: + def astype(self, x: DenseArray, dtype: DType | None, **backend_kwargs: Any) -> DenseArray: + """Cast x to dtype, returning x unchanged when dtype is None.""" + if dtype is None: + return x dtype = self.sanitize_dtype(dtype) if hasattr(x, "astype"): return x.astype(dtype, **backend_kwargs) return x.to(dtype=dtype, **backend_kwargs) def empty(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: + """Create an uninitialized array (delegates to xp.empty).""" return self.xp.empty(shape, dtype=self._dtype_arg(dtype)) def zeros(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: + """Create a zero-filled array (delegates to xp.zeros).""" return self.xp.zeros(shape, dtype=self._dtype_arg(dtype)) def ones(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: + """Create a one-filled array (delegates to xp.ones).""" return self.xp.ones(shape, dtype=self._dtype_arg(dtype)) def zeros_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: + """Create a zero-filled array like x (delegates to xp.zeros_like).""" return self.xp.zeros_like(x, dtype=self._dtype_arg(dtype)) def ones_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: + """Create a one-filled array like x (delegates to xp.ones_like).""" return self.xp.ones_like(x, dtype=self._dtype_arg(dtype)) def full_like(self, x: DenseArray, value: Any, dtype: DType | None = None) -> DenseArray: + """Create a value-filled array like x (delegates to xp.full_like).""" return self.xp.full_like(x, value, dtype=self._dtype_arg(dtype)) def arange( @@ -494,6 +318,7 @@ def arange( step: int | None = None, dtype: DType | None = None, ) -> DenseArray: + """Create an evenly spaced range (delegates to xp.arange).""" dtype = self._dtype_arg(dtype) if stop is None: return self.xp.arange(start, dtype=dtype) @@ -502,26 +327,32 @@ def arange( return self.xp.arange(start, stop, step, dtype=dtype) def full(self, shape: Tuple[int, ...], fill_value: Any, dtype: DType | None = None) -> DenseArray: + """Create a value-filled array (delegates to xp.full).""" return self.xp.full(shape, fill_value, dtype=self._dtype_arg(dtype)) def eye(self, n: int, m: int | None = None, dtype: DType | None = None) -> DenseArray: + """Create an identity-like matrix (delegates to xp.eye).""" return self.xp.eye(n, m, dtype=self._dtype_arg(dtype)) def ravel(self, x: DenseArray) -> DenseArray: + """Flatten x to one dimension.""" if hasattr(self.xp, "ravel"): return self.xp.ravel(x) return self.reshape(x, (-1,)) def reshape(self, x: DenseArray, shape: Tuple[int, ...] | int) -> DenseArray: + """Reshape x (delegates to xp.reshape).""" shape_arg = (shape,) if isinstance(shape, int) else shape return self.xp.reshape(x, shape_arg) def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: + """Permute dimensions of x.""" if axes is None: axes = tuple(reversed(range(self.ndim(x)))) return self._permute_dims(x, axes) def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: + """Swap two axes of x.""" if hasattr(self.xp, "swapaxes"): return self.xp.swapaxes(x, axis1, axis2) axes = list(range(self.ndim(x))) @@ -529,9 +360,11 @@ def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: return self._permute_dims(x, axes) def broadcast_to(self, x: DenseArray, shape: Tuple[int, ...]) -> DenseArray: + """Broadcast x to shape (delegates to xp.broadcast_to).""" return self.xp.broadcast_to(x, shape) def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: + """Insert singleton dimensions into x.""" if isinstance(axis, int): if hasattr(self.xp, "expand_dims"): return self.xp.expand_dims(x, axis=axis) @@ -543,6 +376,7 @@ def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: return out def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: + """Remove singleton dimensions from x.""" if axis is None: axis = tuple(i for i, dim in enumerate(self.shape(x)) if dim == 1) if not axis: @@ -556,29 +390,37 @@ def moveaxis( source: int | Sequence[int], destination: int | Sequence[int], ) -> DenseArray: + """Move axes of x to new positions.""" if hasattr(self.xp, "moveaxis"): return self.xp.moveaxis(x, source, destination) return self._permute_dims(x, self._move_axis_order(self.ndim(x), source, destination)) def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray: + """Stack arrays along a new axis (delegates to xp.stack).""" return self.xp.stack(tuple(arrays), axis=axis) def conj(self, x: DenseArray) -> DenseArray: + """Complex conjugate of x (delegates to xp.conj).""" return self.xp.conj(x) def real(self, x: DenseArray) -> DenseArray: + """Real component of x (delegates to xp.real).""" return self.xp.real(x) def imag(self, x: DenseArray) -> DenseArray: + """Imaginary component of x (delegates to xp.imag).""" return self.xp.imag(x) def abs(self, x: DenseArray) -> DenseArray: + """Absolute value of x (delegates to xp.abs).""" return self.xp.abs(x) def sign(self, x: DenseArray) -> DenseArray: + """Elementwise sign of x (delegates to xp.sign).""" return self.xp.sign(x) def sqrt(self, x: DenseArray) -> DenseArray: + """Elementwise square root of x (delegates to xp.sqrt).""" return self.xp.sqrt(x) def sum( @@ -588,6 +430,7 @@ def sum( keepdims: bool = False, dtype: DType | None = None, ) -> DenseArray: + """Sum over given axes (delegates to xp.sum).""" return self.xp.sum( x, axis=self._to_axis_tuple(axis), @@ -601,6 +444,7 @@ def mean( axis: int | Sequence[int] | None = None, keepdims: bool = False, ) -> DenseArray: + """Mean over given axes (delegates to xp.mean).""" return self.xp.mean(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) def min( @@ -609,6 +453,7 @@ def min( axis: int | Sequence[int] | None = None, keepdims: bool = False, ) -> DenseArray: + """Minimum over given axes (delegates to xp.min).""" return self.xp.min(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) def max( @@ -617,6 +462,7 @@ def max( axis: int | Sequence[int] | None = None, keepdims: bool = False, ) -> DenseArray: + """Maximum over given axes (delegates to xp.max).""" return self.xp.max(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) def prod( @@ -626,6 +472,7 @@ def prod( keepdims: bool = False, dtype: DType | None = None, ) -> DenseArray: + """Product over given axes (delegates to xp.prod).""" return self.xp.prod( x, axis=self._to_axis_tuple(axis), @@ -634,23 +481,32 @@ def prod( ) def trace(self, x: DenseArray) -> DenseArray: + """Trace of a matrix (delegates to xp.trace when available).""" if hasattr(self.xp, "trace"): return self.xp.trace(x) return self.sum(self.diagonal(x)) def argsort(self, x: DenseArray, axis: int = -1) -> DenseArray: + """Indices that sort x (delegates to xp.argsort).""" return self.xp.argsort(x, axis=axis) def sort(self, x: DenseArray, axis: int = -1) -> DenseArray: + """Sort x along an axis (delegates to xp.sort).""" return self.xp.sort(x, axis=axis) def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: + """Indices of minima (delegates to xp.argmin).""" return self.xp.argmin(x, axis=axis, keepdims=keepdims) def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: + """Indices of maxima (delegates to xp.argmax).""" return self.xp.argmax(x, axis=axis, keepdims=keepdims) def vdot(self, x: DenseArray, y: DenseArray) -> DenseArray: + """ + Returns sum(conj(x) * y). Matches numpy/jax/torch vdot and Array API + vecdot. DenseLinOp.rapply relies on this for complex inputs. + """ x_flat = self.ravel(x) y_flat = self.ravel(y) if hasattr(self.xp, "vdot"): @@ -663,12 +519,15 @@ def matmul( b: DenseArray, backend_kwargs: dict[str, Any] | None = None, ) -> DenseArray: + """Matrix product (delegates to xp.matmul).""" return self.xp.matmul(a, b, **({} if backend_kwargs is None else backend_kwargs)) def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: + """Kronecker product (delegates to xp.kron).""" return self.xp.kron(a, b) def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: + """Einstein summation (delegates to xp.einsum).""" return self.xp.einsum(subscripts, *operands) def eigh( @@ -676,6 +535,7 @@ def eigh( x: DenseArray, backend_kwargs: dict[str, Any] | None = None, ) -> tuple[DenseArray, DenseArray]: + """Eigenpairs of a Hermitian dense matrix (delegates to xp.linalg.eigh).""" if self.is_sparse(x): raise TypeError("eigh requires a dense array; sparse input is not supported.") return self.xp.linalg.eigh(x, **({} if backend_kwargs is None else backend_kwargs)) @@ -687,6 +547,7 @@ def norm( axis: int | Sequence[int] | None = None, keepdims: bool = False, ) -> DenseArray: + """Vector or matrix norm (delegates to xp.linalg.norm).""" return self.xp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) def solve( @@ -695,6 +556,7 @@ def solve( b: DenseArray, backend_kwargs: dict[str, Any] | None = None, ) -> DenseArray: + """Solve a dense linear system (delegates to xp.linalg.solve).""" return self.xp.linalg.solve(A, b, **({} if backend_kwargs is None else backend_kwargs)) def eigvalsh( @@ -702,6 +564,7 @@ def eigvalsh( A: DenseArray, backend_kwargs: dict[str, Any] | None = None, ) -> DenseArray: + """Eigenvalues of a Hermitian dense matrix (delegates to xp.linalg.eigvalsh).""" return self.xp.linalg.eigvalsh(A, **({} if backend_kwargs is None else backend_kwargs)) def svd( @@ -710,6 +573,7 @@ def svd( full_matrices: bool = True, backend_kwargs: dict[str, Any] | None = None, ) -> tuple[DenseArray, DenseArray, DenseArray]: + """Singular value decomposition (delegates to xp.linalg.svd).""" return self.xp.linalg.svd( A, full_matrices=full_matrices, @@ -721,30 +585,39 @@ def cholesky( A: DenseArray, backend_kwargs: dict[str, Any] | None = None, ) -> DenseArray: + """Cholesky factorization (delegates to xp.linalg.cholesky).""" return self.xp.linalg.cholesky(A, **({} if backend_kwargs is None else backend_kwargs)) def exp(self, x: DenseArray) -> DenseArray: + """Elementwise exponential (delegates to xp.exp).""" return self.xp.exp(x) def log(self, x: DenseArray) -> DenseArray: + """Elementwise natural logarithm (delegates to xp.log).""" return self.xp.log(x) def where(self, condition: DenseArray | bool, x: ArrayLike, y: ArrayLike) -> DenseArray: + """Select between x and y by condition (delegates to xp.where).""" return self.xp.where(condition, x, y) def maximum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: + """Elementwise maximum (delegates to xp.maximum).""" return self.xp.maximum(x, y) def minimum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: + """Elementwise minimum (delegates to xp.minimum).""" return self.xp.minimum(x, y) def clip(self, x: DenseArray, a_min: ArrayLike, a_max: ArrayLike) -> DenseArray: + """Clip x into [a_min, a_max] (delegates to xp.clip).""" return self.xp.clip(x, a_min, a_max) def isfinite(self, x: DenseArray) -> DenseArray: + """Elementwise finite check (delegates to xp.isfinite).""" return self.xp.isfinite(x) def isnan(self, x: DenseArray) -> DenseArray: + """Elementwise NaN check (delegates to xp.isnan).""" return self.xp.isnan(x) def concatenate( @@ -753,6 +626,7 @@ def concatenate( axis: int = 0, dtype: DType | None = None, ) -> DenseArray: + """Concatenate arrays along an existing axis (delegates to xp.concat).""" if hasattr(self.xp, "concat"): result = self.xp.concat(tuple(arrays), axis=axis) else: @@ -765,18 +639,23 @@ def take( indices: DenseArray, axis: int | None = None, ) -> DenseArray: + """Take entries from x by integer indices (delegates to xp.take).""" return self.xp.take(x, indices, axis=axis) def diag(self, x: DenseArray) -> DenseArray: + """Extract or construct a diagonal (delegates to xp.diag).""" return self.xp.diag(x) def diagonal(self, x: DenseArray) -> DenseArray: + """Return the main diagonal of x (delegates to xp.diagonal).""" return self.xp.diagonal(x) def tril(self, x: DenseArray) -> DenseArray: + """Lower triangle of x (delegates to xp.tril).""" return self.xp.tril(x) def triu(self, x: DenseArray) -> DenseArray: + """Upper triangle of x (delegates to xp.triu).""" return self.xp.triu(x) def allclose( @@ -787,7 +666,12 @@ def allclose( atol: float = 1e-8, equal_nan: bool = False, ) -> bool: + """Return whether dense arrays are close within tolerances.""" return bool(self.xp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) - def __repr__(self): - return f"{type(self).__name__}" + def __repr__(self) -> str: + xp = type(self).xp + xp_state = "" + if isinstance(xp, LazyNamespace): + xp_state = f", xp_loaded={xp.is_loaded!r}" + return f"{type(self).__name__}(family={self.family!r}{xp_state})" diff --git a/spacecore/backend/jax/_ops.py b/spacecore/backend/jax/_ops.py index effce84..04b2eab 100644 --- a/spacecore/backend/jax/_ops.py +++ b/spacecore/backend/jax/_ops.py @@ -61,6 +61,9 @@ class JaxOps(BackendOps): _family = BackendFamily.jax.value.lower() _allow_sparse = True + def __init__(self) -> None: + super().__init__() + def sanitize_dtype(self, dtype: DType | None) -> DType: """ Normalize a dtype specifier using JAX. diff --git a/spacecore/backend/numpy/_ops.py b/spacecore/backend/numpy/_ops.py index 6fc6263..9a919bc 100644 --- a/spacecore/backend/numpy/_ops.py +++ b/spacecore/backend/numpy/_ops.py @@ -53,6 +53,9 @@ class NumpyOps(BackendOps): _family = BackendFamily.numpy.value.lower() _allow_sparse = True + def __init__(self) -> None: + super().__init__() + @property def dense_array(self) -> Type[Any]: """ diff --git a/spacecore/backend/torch/_ops.py b/spacecore/backend/torch/_ops.py index 82b2154..f1413d7 100644 --- a/spacecore/backend/torch/_ops.py +++ b/spacecore/backend/torch/_ops.py @@ -65,6 +65,9 @@ class TorchOps(BackendOps): torch.sparse_bsc, ) + def __init__(self) -> None: + super().__init__() + @staticmethod def _defined_kwargs(**kwargs: Any) -> dict[str, Any]: return {key: value for key, value in kwargs.items() if value is not None} @@ -262,13 +265,15 @@ def asarray( def astype( self, x: DenseArray, - dtype: DType, + dtype: DType | None, *, copy: bool = True, non_blocking: bool = False, memory_format: Any | None = None, backend_kwargs: dict[str, Any] | None = None, ) -> DenseArray: + if dtype is None: + return x kwargs = {} if backend_kwargs is None else dict(backend_kwargs) kwargs.update(self._defined_kwargs(memory_format=memory_format)) return x.to( diff --git a/tests/backend/test_backend_ops_delegation.py b/tests/backend/test_backend_ops_delegation.py index c9e58d5..770ce32 100644 --- a/tests/backend/test_backend_ops_delegation.py +++ b/tests/backend/test_backend_ops_delegation.py @@ -115,7 +115,7 @@ def test_numpy_eps_uses_default_dtype(): sc = importlib.import_module("spacecore") ops = sc.NumpyOps() - np.testing.assert_allclose(to_numpy(ops.eps), np.finfo(ops.sanitize_dtype(None)).eps) + np.testing.assert_allclose(ops.eps(ops.sanitize_dtype(None)), np.finfo(ops.sanitize_dtype(None)).eps) @pytest.mark.skipif(not has_jax(), reason="jax is not installed") @@ -123,7 +123,7 @@ def test_jax_eps_uses_default_dtype(): sc = importlib.import_module("spacecore") ops = sc.JaxOps() - np.testing.assert_allclose(to_numpy(ops.eps), np.finfo(jax_real_dtype()).eps) + np.testing.assert_allclose(ops.eps(jax_real_dtype()), np.finfo(jax_real_dtype()).eps) @pytest.mark.skipif(not has_torch(), reason="torch is not installed") @@ -132,7 +132,62 @@ def test_torch_eps_uses_default_dtype(): import torch ops = sc.TorchOps() - np.testing.assert_allclose(to_numpy(ops.eps), torch.finfo(ops.sanitize_dtype(None)).eps) + np.testing.assert_allclose(ops.eps(ops.sanitize_dtype(None)), torch.finfo(ops.sanitize_dtype(None)).eps) + + +def test_backend_ops_hash_and_repr(): + sc = importlib.import_module("spacecore") + first = sc.NumpyOps() + second = sc.NumpyOps() + + assert hash(first) == hash(second) + assert {first: 1, second: 2} == {first: 2} + assert repr(first) == "NumpyOps(family='numpy')" + + +def test_numpy_eps_distinguishes_dtype_precision(): + sc = importlib.import_module("spacecore") + ops = sc.NumpyOps() + + assert ops.eps(np.float64) < ops.eps(np.float32) + + +def test_numpy_constants_are_cached_and_astype_none_is_noop(): + sc = importlib.import_module("spacecore") + ops = sc.NumpyOps() + x = ops.asarray([1.0, 2.0]) + + assert ops.inf is ops.inf + assert ops.nan is ops.nan + assert ops.pi is ops.pi + assert ops.e is ops.e + assert ops.astype(x, None) is x + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_constants_are_cached_and_astype_none_is_noop(): + sc = importlib.import_module("spacecore") + ops = sc.JaxOps() + x = ops.asarray([1.0, 2.0], dtype=jax_real_dtype()) + + assert ops.inf is ops.inf + assert ops.nan is ops.nan + assert ops.pi is ops.pi + assert ops.e is ops.e + assert ops.astype(x, None) is x + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_constants_are_cached_and_astype_none_is_noop(): + sc = importlib.import_module("spacecore") + ops = sc.TorchOps() + x = ops.asarray([1.0, 2.0], dtype=torch_real_dtype()) + + assert ops.inf is ops.inf + assert ops.nan is ops.nan + assert ops.pi is ops.pi + assert ops.e is ops.e + assert ops.astype(x, None) is x def _check_complex_adjoint(ops, dtype): diff --git a/tests/test_backend_ops_complex.py b/tests/test_backend_ops_complex.py new file mode 100644 index 0000000..c5b95e0 --- /dev/null +++ b/tests/test_backend_ops_complex.py @@ -0,0 +1,33 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_jax, has_torch, jax_complex_dtype, to_numpy, torch_complex_dtype + + +def _check_vdot_conjugates_first_argument(ops, dtype): + x = ops.asarray([1.0 + 2.0j, 3.0 + 4.0j], dtype=dtype) + y = ops.asarray([5.0 + 6.0j, 7.0 + 8.0j], dtype=dtype) + + np.testing.assert_allclose(to_numpy(ops.vdot(x, y)), 70.0 - 8.0j) + + +def test_numpy_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.NumpyOps(), np.complex128) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.JaxOps(), jax_complex_dtype()) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.TorchOps(), torch_complex_dtype()) From bae5446ad0d37fee88ccb1cdde2fb25b52e0ba3a Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:04:14 -0300 Subject: [PATCH 03/44] Add lazy LinOp algebra foundation --- spacecore/__init__.py | 27 +- spacecore/linop/__init__.py | 20 + spacecore/linop/_algebra.py | 568 +++++++++++++++++++++++++++ spacecore/linop/_base.py | 93 +++++ tests/integration/test_public_api.py | 3 + tests/linops/test_algebra_linop.py | 244 ++++++++++++ 6 files changed, 954 insertions(+), 1 deletion(-) create mode 100644 spacecore/linop/_algebra.py create mode 100644 tests/linops/test_algebra_linop.py diff --git a/spacecore/__init__.py b/spacecore/__init__.py index d130500..4371730 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -6,7 +6,23 @@ from .backend import TorchOps as TorchOps except ImportError: pass -from .linop import DenseLinOp, SparseLinOp, BlockDiagonalLinOp, SumToSingleLinOp, StackedLinOp, LinOp +from .linop import ( + BlockDiagonalLinOp, + ComposedLinOp, + DenseLinOp, + IdentityLinOp, + LinOp, + MatrixFreeLinOp, + ScaledLinOp, + SparseLinOp, + StackedLinOp, + SumLinOp, + SumToSingleLinOp, + ZeroLinOp, + make_composed, + make_scaled, + make_sum, +) from .space import ( BackendCheck, DTypeCheck, @@ -42,8 +58,17 @@ "NumpyOps", "LinOp", + "ComposedLinOp", "DenseLinOp", + "IdentityLinOp", + "MatrixFreeLinOp", + "ScaledLinOp", "SparseLinOp", + "SumLinOp", + "ZeroLinOp", + "make_composed", + "make_scaled", + "make_sum", "BlockDiagonalLinOp", "SumToSingleLinOp", "StackedLinOp", diff --git a/spacecore/linop/__init__.py b/spacecore/linop/__init__.py index 7a6f537..f810bb1 100644 --- a/spacecore/linop/__init__.py +++ b/spacecore/linop/__init__.py @@ -1,12 +1,32 @@ from ._base import LinOp +from ._algebra import ( + ComposedLinOp, + IdentityLinOp, + MatrixFreeLinOp, + ScaledLinOp, + SumLinOp, + ZeroLinOp, + make_composed, + make_scaled, + make_sum, +) from ._dense import DenseLinOp from ._sparse import SparseLinOp from .product import ProductLinOp, StackedLinOp, SumToSingleLinOp, BlockDiagonalLinOp __all__ = [ "LinOp", + "ComposedLinOp", "DenseLinOp", + "IdentityLinOp", + "MatrixFreeLinOp", + "ScaledLinOp", "SparseLinOp", + "SumLinOp", + "ZeroLinOp", + "make_composed", + "make_scaled", + "make_sum", "ProductLinOp", "SumToSingleLinOp", "BlockDiagonalLinOp", diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py new file mode 100644 index 0000000..f104c2f --- /dev/null +++ b/spacecore/linop/_algebra.py @@ -0,0 +1,568 @@ +from __future__ import annotations + +from numbers import Number +from typing import Any, Sequence + +from ._base import LinOp, Domain, Codomain +from ..backend import Context, jax_pytree_class + + +def is_scalar_like(value: Any) -> bool: + """Return whether ``value`` can be used as a scalar multiplier for a ``LinOp``.""" + if isinstance(value, Number): + return True + shape = getattr(value, "shape", None) + if shape is not None: + return tuple(shape) == () + ndim = getattr(value, "ndim", None) + return ndim == 0 + + +def _conjugate_scalar(value: Any) -> Any: + if hasattr(value, "conjugate"): + return value.conjugate() + if hasattr(value, "conj"): + return value.conj() + return value + + +def _same_context(left: LinOp, right: LinOp) -> bool: + return ( + left.ctx == right.ctx + and left.ctx.dtype == right.ctx.dtype + and left.ctx.enable_checks == right.ctx.enable_checks + ) + + +def _require_same_context(ops: Sequence[LinOp]) -> Context: + ctx = ops[0].ctx + for i, op in enumerate(ops[1:], start=1): + if not _same_context(ops[0], op): + raise ValueError( + "All LinOp operands in an algebraic expression must have the same ctx; " + f"operand 0 has ctx {ctx!r}, operand {i} has ctx {op.ctx!r}." + ) + return ctx + + +def _require_linop(op: Any, name: str) -> LinOp: + if not isinstance(op, LinOp): + raise TypeError(f"{name} must be a LinOp, got {type(op).__name__}.") + return op + + +def _scalar_equal(value: Any, target: Any) -> bool: + try: + return bool(value == target) + except Exception: + return False + + +def _is_zero_scalar(value: Any) -> bool: + return _scalar_equal(value, 0) + + +def _is_one_scalar(value: Any) -> bool: + return _scalar_equal(value, 1) + + +def _flatten_sum_terms(ops: Sequence[LinOp]) -> tuple[LinOp, ...]: + terms: list[LinOp] = [] + for i, op in enumerate(ops): + op = _require_linop(op, f"ops[{i}]") + if isinstance(op, SumLinOp): + terms.extend(_flatten_sum_terms(op.parts)) + else: + terms.append(op) + return tuple(terms) + + +def make_sum(ops: Sequence[LinOp]) -> LinOp: + """ + Return a locally simplified lazy sum of linear operators. + + This factory performs only local algebraic canonicalization: nested + ``SumLinOp`` nodes are flattened and ``ZeroLinOp`` terms are removed. It + does not collect like terms, reorder operands, or attempt full symbolic + optimization. All operands must have the same context, domain, and codomain + before a simplified operator is returned. + """ + if not ops: + raise ValueError("make_sum requires a nonempty sequence of LinOp operands.") + + terms = _flatten_sum_terms(ops) + ctx = _require_same_context(terms) + domain = terms[0].domain + codomain = terms[0].codomain + for i, op in enumerate(terms[1:], start=1): + if op.domain != domain or op.codomain != codomain: + raise ValueError( + "All SumLinOp operands must have the same domain and codomain; " + f"operand 0 maps {domain!r} -> {codomain!r}, " + f"operand {i} maps {op.domain!r} -> {op.codomain!r}." + ) + + nonzero_terms = tuple(op for op in terms if not isinstance(op, ZeroLinOp)) + if not nonzero_terms: + return ZeroLinOp(domain, codomain, ctx) + if len(nonzero_terms) == 1: + return nonzero_terms[0] + return SumLinOp(nonzero_terms) + + +def make_scaled(scalar: Any, op: LinOp) -> LinOp: + """ + Return a locally simplified scalar multiple of a linear operator. + + This factory performs only local algebraic canonicalization: zero and unit + scalars are simplified, and nested ``ScaledLinOp`` nodes are folded into one + scalar. It does not distribute scaling over sums or perform full symbolic + optimization. Complex scalars retain the usual conjugated coefficient in + ``rapply`` through ``ScaledLinOp``. + """ + op = _require_linop(op, "op") + if not is_scalar_like(scalar): + raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.") + + if _is_zero_scalar(scalar): + return ZeroLinOp(op.domain, op.codomain, op.ctx) + if _is_one_scalar(scalar): + return op + if isinstance(op, ScaledLinOp): + return make_scaled(scalar * op.scalar, op.op) + return ScaledLinOp(scalar, op) + + +def make_composed(left: LinOp, right: LinOp) -> LinOp: + """ + Return a locally simplified composition of two linear operators. + + This factory performs only local algebraic canonicalization: identity + factors are removed and compositions with zero maps become zero maps. It + preserves the binary ``ComposedLinOp`` representation and does not flatten + multi-factor chains or attempt full symbolic optimization. Operands must + have the same context and compatible middle spaces before a simplified + operator is returned. + """ + left = _require_linop(left, "left") + right = _require_linop(right, "right") + _require_same_context((left, right)) + if right.codomain != left.domain: + raise ValueError( + "ComposedLinOp requires right.codomain == left.domain; " + f"got {right.codomain!r} and {left.domain!r}." + ) + + if isinstance(right, IdentityLinOp): + return left + if isinstance(left, IdentityLinOp): + return right + if isinstance(left, ZeroLinOp): + return ZeroLinOp(right.domain, left.codomain, left.ctx) + if isinstance(right, ZeroLinOp): + return ZeroLinOp(right.domain, left.codomain, left.ctx) + return ComposedLinOp(left, right) + + +@jax_pytree_class +class ScaledLinOp(LinOp[Domain, Codomain]): + """ + Lazy scalar multiple of a linear operator. + + ``ScaledLinOp(alpha, A)`` represents the mathematical operator + ``alpha * A``. Its context is exactly ``A.ctx``; its domain is ``A.domain`` + and its codomain is ``A.codomain``. No dense matrix representation is + formed. + + The forward action is ``apply(x) = alpha * A.apply(x)`` for + ``x in A.domain``. The reverse action is + ``rapply(y) = conj(alpha) * A.rapply(y)`` for ``y in A.codomain``, so + complex scalars use the conjugated coefficient. + """ + + def __init__(self, scalar: Any, op: LinOp[Domain, Codomain]) -> None: + op = _require_linop(op, "op") + if not is_scalar_like(scalar): + raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.") + super().__init__(op.domain, op.codomain, op.ctx) + self.scalar = scalar + self.op = op + + def apply(self, x: Any) -> Any: + """Return ``scalar * op.apply(x)``.""" + return self.scalar * self.op.apply(x) + + def rapply(self, y: Any) -> Any: + """Return ``conj(scalar) * op.rapply(y)``.""" + return _conjugate_scalar(self.scalar) * self.op.rapply(y) + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.scalar == other.scalar and self.op == other.op + return False + + def tree_flatten(self): + children = (self.scalar, self.op) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + scalar, op = children + return cls(scalar, op) + + def _convert(self, new_ctx: Context) -> ScaledLinOp: + return ScaledLinOp(self.scalar, self.op.convert(new_ctx)) + + +@jax_pytree_class +class SumLinOp(LinOp[Domain, Codomain]): + """ + Lazy finite sum of linear operators with common spaces. + + ``SumLinOp((A1, ..., Ak))`` represents ``A1 + ... + Ak`` for a nonempty + sequence of ``LinOp`` instances. All operands must have the same ``ctx``, + the same domain, and the same codomain before construction. The resulting + operator has that shared context, domain, and codomain. + + The forward action is ``apply(x) = sum_i Ai.apply(x)`` for the shared + domain element ``x``. The reverse action is + ``rapply(y) = sum_i Ai.rapply(y)`` for the shared codomain element ``y``. + """ + + def __init__(self, ops: Sequence[LinOp[Domain, Codomain]]) -> None: + if not ops: + raise ValueError("SumLinOp requires a nonempty sequence of LinOp operands.") + parts = tuple(_require_linop(op, f"ops[{i}]") for i, op in enumerate(ops)) + ctx = _require_same_context(parts) + domain = parts[0].domain + codomain = parts[0].codomain + for i, op in enumerate(parts[1:], start=1): + if op.domain != domain or op.codomain != codomain: + raise ValueError( + "All SumLinOp operands must have the same domain and codomain; " + f"operand 0 maps {domain!r} -> {codomain!r}, " + f"operand {i} maps {op.domain!r} -> {op.codomain!r}." + ) + super().__init__(domain, codomain, ctx) + self.ops_tuple = parts + + @property + def parts(self) -> tuple[LinOp[Domain, Codomain], ...]: + """Operators in this lazy sum.""" + return self.ops_tuple + + def apply(self, x: Any) -> Any: + """Return ``sum_i ops[i].apply(x)``.""" + acc = self.ops_tuple[0].apply(x) + for op in self.ops_tuple[1:]: + acc = self.codomain.add(acc, op.apply(x)) + return acc + + def rapply(self, y: Any) -> Any: + """Return ``sum_i ops[i].rapply(y)``.""" + acc = self.ops_tuple[0].rapply(y) + for op in self.ops_tuple[1:]: + acc = self.domain.add(acc, op.rapply(y)) + return acc + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.ops_tuple == other.ops_tuple + return False + + def tree_flatten(self): + children = self.ops_tuple + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + return cls(tuple(children)) + + def _convert(self, new_ctx: Context) -> SumLinOp: + return SumLinOp(tuple(op.convert(new_ctx) for op in self.ops_tuple)) + + +@jax_pytree_class +class ComposedLinOp(LinOp[Domain, Codomain]): + """ + Lazy composition of two linear operators. + + ``ComposedLinOp(A, B)`` represents ``A @ B = A circ B``. The operands must + have the same ``ctx`` before construction, and ``B.codomain`` must equal + ``A.domain``. The resulting operator has domain ``B.domain`` and codomain + ``A.codomain``. + + The forward action is ``apply(x) = A.apply(B.apply(x))`` for + ``x in B.domain``. The reverse action is ``rapply(z) = B.rapply(A.rapply(z))`` + for ``z in A.codomain``. + """ + + def __init__(self, left: LinOp, right: LinOp) -> None: + left = _require_linop(left, "left") + right = _require_linop(right, "right") + _require_same_context((left, right)) + if right.codomain != left.domain: + raise ValueError( + "ComposedLinOp requires right.codomain == left.domain; " + f"got {right.codomain!r} and {left.domain!r}." + ) + super().__init__(right.domain, left.codomain, left.ctx) + self.left = left + self.right = right + + def apply(self, x: Any) -> Any: + """Return ``left.apply(right.apply(x))``.""" + return self.left.apply(self.right.apply(x)) + + def rapply(self, z: Any) -> Any: + """Return ``right.rapply(left.rapply(z))``.""" + return self.right.rapply(self.left.rapply(z)) + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.left == other.left and self.right == other.right + return False + + def tree_flatten(self): + children = (self.left, self.right) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + left, right = children + return cls(left, right) + + def _convert(self, new_ctx: Context) -> ComposedLinOp: + return ComposedLinOp(self.left.convert(new_ctx), self.right.convert(new_ctx)) + + +@jax_pytree_class +class ZeroLinOp(LinOp[Domain, Codomain]): + """ + Lazy zero map between two spaces. + + ``ZeroLinOp(X, Y)`` represents the linear map ``0 : X -> Y``. The context is + resolved from the optional ``ctx`` argument and the two spaces, then both + spaces are converted to that context. Its domain is ``X`` and its codomain + is ``Y`` in the resolved context. + + The forward action is ``apply(x) = 0_Y`` for ``x in X``. The reverse action + is ``rapply(y) = 0_X`` for ``y in Y``. + """ + + def __init__( + self, + dom: Domain, + cod: Codomain, + ctx: Context | str | None = None, + ) -> None: + super().__init__(dom, cod, ctx) + + def apply(self, x: Any) -> Any: + """Return the zero element of the codomain.""" + if self._enable_checks: + self.domain._check_member(x) + return self.codomain.zeros() + + def rapply(self, y: Any) -> Any: + """Return the zero element of the domain.""" + if self._enable_checks: + self.codomain._check_member(y) + return self.domain.zeros() + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.domain == other.domain and self.codomain == other.codomain + return False + + def tree_flatten(self): + children = () + aux = (self.domain, self.codomain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + domain, codomain, ctx = aux + return cls(domain, codomain, ctx) + + def _convert(self, new_ctx: Context) -> ZeroLinOp: + return ZeroLinOp(self.domain.convert(new_ctx), self.codomain.convert(new_ctx), new_ctx) + + +@jax_pytree_class +class IdentityLinOp(LinOp[Domain, Domain]): + """ + Lazy identity map on a space. + + ``IdentityLinOp(X)`` represents the identity operator ``I_X : X -> X``. The + context is resolved from the optional ``ctx`` argument and the space, and the + resulting operator has domain and codomain equal to ``X`` in that context. + + The forward action is ``apply(x) = x`` for ``x in X``. The reverse action is + ``rapply(x) = x`` for ``x in X``. + """ + + def __init__(self, space: Domain, ctx: Context | str | None = None) -> None: + super().__init__(space, space, ctx) + + def apply(self, x: Any) -> Any: + """Return ``x`` after domain validation.""" + if self._enable_checks: + self.domain._check_member(x) + return x + + def rapply(self, x: Any) -> Any: + """Return ``x`` after codomain validation.""" + if self._enable_checks: + self.codomain._check_member(x) + return x + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.domain == other.domain + return False + + def tree_flatten(self): + children = () + aux = (self.domain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + domain, ctx = aux + return cls(domain, ctx) + + def _convert(self, new_ctx: Context) -> IdentityLinOp: + return IdentityLinOp(self.domain.convert(new_ctx), new_ctx) + + +@jax_pytree_class +class MatrixFreeLinOp(LinOp[Domain, Codomain]): + """ + Linear operator defined by user-supplied forward and reverse callables. + + ``MatrixFreeLinOp(apply, rapply, X, Y)`` represents a matrix-free map + ``A : X -> Y`` without storing or materializing a matrix. The context is + resolved from the optional ``ctx`` argument and the spaces, then the spaces + are converted to that context. + + The forward action is ``apply(x) = apply_fn(x)`` for ``x in X``. The reverse + action is ``rapply(y) = rapply_fn(y)`` for ``y in Y``. When checks are + enabled, inputs and callable outputs are validated against the corresponding + domain and codomain. + """ + + def __init__( + self, + apply: Any, + rapply: Any, + dom: Domain, + cod: Codomain, + ctx: Context | str | None = None, + ) -> None: + if not callable(apply): + raise TypeError(f"apply must be callable, got {type(apply).__name__}.") + if not callable(rapply): + raise TypeError(f"rapply must be callable, got {type(rapply).__name__}.") + super().__init__(dom, cod, ctx) + self.apply_fn = apply + self.rapply_fn = rapply + + def apply(self, x: Any) -> Any: + """Return ``apply_fn(x)``.""" + if self._enable_checks: + self.domain._check_member(x) + y = self.apply_fn(x) + if self._enable_checks: + self.codomain._check_member(y) + return y + + def rapply(self, y: Any) -> Any: + """Return ``rapply_fn(y)``.""" + if self._enable_checks: + self.codomain._check_member(y) + x = self.rapply_fn(y) + if self._enable_checks: + self.domain._check_member(x) + return x + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return ( + self.domain == other.domain + and self.codomain == other.codomain + and self.apply_fn is other.apply_fn + and self.rapply_fn is other.rapply_fn + ) + return False + + def tree_flatten(self): + children = () + aux = (self.apply_fn, self.rapply_fn, self.domain, self.codomain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + apply_fn, rapply_fn, domain, codomain, ctx = aux + return cls(apply_fn, rapply_fn, domain, codomain, ctx) + + def _convert(self, new_ctx: Context) -> MatrixFreeLinOp: + return MatrixFreeLinOp( + self.apply_fn, + self.rapply_fn, + self.domain.convert(new_ctx), + self.codomain.convert(new_ctx), + new_ctx, + ) + + +@jax_pytree_class +class _AdjointViewLinOp(LinOp[Codomain, Domain]): + """ + Hermitian-adjoint view of a linear operator. + + ``A.H`` represents the adjoint view ``A*``. Its context is exactly + ``A.ctx``; its domain is ``A.codomain`` and its codomain is ``A.domain``. + ``A.H.H`` returns ``A`` rather than constructing another wrapper. + + The forward action is ``apply(y) = A.rapply(y)`` for ``y in A.codomain``. + The reverse action is ``rapply(x) = A.apply(x)`` for ``x in A.domain``. + """ + + def __init__(self, op: LinOp[Domain, Codomain]) -> None: + op = _require_linop(op, "op") + super().__init__(op.codomain, op.domain, op.ctx) + self.op = op + + def apply(self, y: Any) -> Any: + """Return ``op.rapply(y)``.""" + return self.op.rapply(y) + + def rapply(self, x: Any) -> Any: + """Return ``op.apply(x)``.""" + return self.op.apply(x) + + @property + def H(self) -> LinOp[Domain, Codomain]: + """Original operator viewed as the adjoint of this adjoint view.""" + return self.op + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.op == other.op + return False + + def tree_flatten(self): + children = (self.op,) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + return cls(children[0]) + + def _convert(self, new_ctx: Context) -> _AdjointViewLinOp: + return _AdjointViewLinOp(self.op.convert(new_ctx)) diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index dfc06b9..05707d5 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import abstractmethod +from numbers import Number from typing import Any, Generic, TypeVar from ..space import Space @@ -31,6 +32,16 @@ def __init__(self, dom: Domain, cod: Codomain, ctx: Context | str | None = None) self.cod = cod.convert(self.ctx) self._enable_checks = self.ctx.enable_checks + @property + def domain(self) -> Domain: + """Domain space of this linear operator.""" + return self.dom + + @property + def codomain(self) -> Codomain: + """Codomain space of this linear operator.""" + return self.cod + @abstractmethod def apply(self, x: Any) -> Any: """ @@ -52,8 +63,90 @@ def rapply(self, y: Any) -> Any: """ def __call__(self, x: Any) -> Any: + """Apply this linear operator to ``x``.""" return self.apply(x) + def adjoint_apply(self, y: Any) -> Any: + """Apply the adjoint of this linear operator to ``y``.""" + return self.rapply(y) + + @property + def H(self) -> LinOp: + """Hermitian-adjoint view of this linear operator.""" + from ._algebra import _AdjointViewLinOp + + return _AdjointViewLinOp(self) + + def __add__(self, other: Any) -> LinOp: + """Return the lazy sum ``self + other`` of two compatible operators.""" + from ._algebra import make_sum + + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((self, other)) + + def __radd__(self, other: Any) -> LinOp: + """Return the lazy sum ``other + self`` of two compatible operators.""" + from ._algebra import make_sum + + if isinstance(other, Number) and other == 0: + return self + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((other, self)) + + def __neg__(self) -> LinOp: + """Return the lazy negation ``-self``.""" + from ._algebra import make_scaled + + return make_scaled(-1, self) + + def __sub__(self, other: Any) -> LinOp: + """Return the lazy difference ``self - other`` of two compatible operators.""" + from ._algebra import make_scaled, make_sum + + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((self, make_scaled(-1, other))) + + def __rsub__(self, other: Any) -> LinOp: + """Return the lazy difference ``other - self`` of two compatible operators.""" + from ._algebra import make_scaled, make_sum + + if isinstance(other, Number) and other == 0: + return make_scaled(-1, self) + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((other, make_scaled(-1, self))) + + def __mul__(self, scalar: Any) -> LinOp: + """Return the lazy right scalar multiple ``self * scalar``.""" + from ._algebra import is_scalar_like, make_scaled + + if not is_scalar_like(scalar): + return NotImplemented + return make_scaled(scalar, self) + + def __rmul__(self, scalar: Any) -> LinOp: + """Return the lazy left scalar multiple ``scalar * self``.""" + from ._algebra import is_scalar_like, make_scaled + + if not is_scalar_like(scalar): + return NotImplemented + return make_scaled(scalar, self) + + def __matmul__(self, other: Any) -> LinOp: + """Return the lazy composition ``self @ other`` of two compatible operators.""" + from ._algebra import make_composed + + if not isinstance(other, LinOp): + return NotImplemented + return make_composed(self, other) + + def adjoint(self) -> LinOp: + """Return the Hermitian-adjoint view of this linear operator.""" + return self.H + def assert_domain(self, x: Any) -> None: self.dom.check_member(x) diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index 26ceb65..45b6f4a 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -19,6 +19,9 @@ def test_expected_names_are_exported(): sc = importlib.import_module("spacecore") expected = { "Context", "BackendOps", "NumpyOps", "DenseLinOp", "SparseLinOp", + "ScaledLinOp", "SumLinOp", "ComposedLinOp", "ZeroLinOp", + "IdentityLinOp", "MatrixFreeLinOp", "make_sum", "make_scaled", + "make_composed", "BlockDiagonalLinOp", "StackedLinOp", "SumToSingleLinOp", "VectorSpace", "HermitianSpace", "ProductSpace", "Space", "DenseArray", "SparseArray", "ArrayLike", diff --git a/tests/linops/test_algebra_linop.py b/tests/linops/test_algebra_linop.py new file mode 100644 index 0000000..fd2a8f8 --- /dev/null +++ b/tests/linops/test_algebra_linop.py @@ -0,0 +1,244 @@ +import importlib + +import numpy as np +import pytest + + +def _ctx(dtype=np.float64, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def _op(matrix, dom_shape, cod_shape, ctx=None): + sc = importlib.import_module("spacecore") + ctx = ctx or _ctx() + dom = sc.VectorSpace(dom_shape, ctx) + cod = sc.VectorSpace(cod_shape, ctx) + return sc.DenseLinOp(ctx.asarray(matrix), dom, cod, ctx) + + +def test_algebra_linops_inherit_from_linop(): + sc = importlib.import_module("spacecore") + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,)) + + assert isinstance(2.0 * A, sc.LinOp) + assert isinstance(A + A, sc.LinOp) + assert isinstance(A @ A, sc.LinOp) + assert isinstance(A.H, sc.LinOp) + assert isinstance(sc.ZeroLinOp(A.domain, A.codomain, A.ctx), sc.LinOp) + assert isinstance(sc.IdentityLinOp(A.domain, A.ctx), sc.LinOp) + assert isinstance(sc.MatrixFreeLinOp(A.apply, A.rapply, A.domain, A.codomain, A.ctx), sc.LinOp) + assert issubclass(sc.ScaledLinOp, sc.LinOp) + assert issubclass(sc.SumLinOp, sc.LinOp) + assert issubclass(sc.ComposedLinOp, sc.LinOp) + assert issubclass(sc.ZeroLinOp, sc.LinOp) + assert issubclass(sc.IdentityLinOp, sc.LinOp) + assert issubclass(sc.MatrixFreeLinOp, sc.LinOp) + assert not hasattr(sc, "AdjointLinOp") + + +def test_context_mismatch_raises_clear_error(): + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), _ctx(enable_checks=True)) + B = _op([[5.0, 6.0], [7.0, 8.0]], (2,), (2,), _ctx(enable_checks=False)) + + with pytest.raises(ValueError, match="same ctx"): + _ = A + B + with pytest.raises(ValueError, match="same ctx"): + _ = A @ B + + +def test_sum_requires_matching_domain_and_codomain(): + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,)) + bad_cod = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), A.ctx) + bad_dom = _op([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], (3,), (2,), A.ctx) + + with pytest.raises(ValueError, match="same domain and codomain"): + _ = A + bad_cod + with pytest.raises(ValueError, match="same domain and codomain"): + _ = A + bad_dom + + +def test_composition_requires_matching_middle_space(): + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,)) + B = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), A.ctx) + C = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), A.ctx) + + assert (A @ B).domain == B.domain + assert (A @ B).codomain == A.codomain + with pytest.raises(ValueError, match="right.codomain == left.domain"): + _ = A @ C + + +def test_scaled_sum_subtraction_and_negation_are_numerically_correct(): + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), ctx) + B = _op([[5.0, 1.0], [-2.0, 3.0]], (2,), (2,), ctx) + x = ctx.asarray([2.0, -1.0]) + y = ctx.asarray([1.0, 3.0]) + dense_a = np.array([[1.0, 2.0], [3.0, 4.0]]) + dense_b = np.array([[5.0, 1.0], [-2.0, 3.0]]) + + expr = 2.0 * A + B - (-A) + + assert expr.domain == A.domain + assert expr.codomain == A.codomain + assert np.allclose(expr.apply(x), (3.0 * dense_a + dense_b) @ np.asarray(x)) + assert np.allclose(expr.rapply(y), (3.0 * dense_a + dense_b).T @ np.asarray(y)) + assert np.allclose((-A).apply(x), -dense_a @ np.asarray(x)) + assert np.allclose((A * 3.0).apply(x), 3.0 * dense_a @ np.asarray(x)) + + +def test_complex_scaled_adjoint_conjugates_scalar(): + ctx = _ctx(np.complex128) + A = _op([[1.0 + 1.0j, 2.0], [3.0j, 4.0 - 2.0j]], (2,), (2,), ctx) + y = ctx.asarray([1.0 - 1.0j, 2.0 + 3.0j]) + dense = np.array([[1.0 + 1.0j, 2.0], [3.0j, 4.0 - 2.0j]]) + alpha = 2.0 + 3.0j + + op = alpha * A + + assert np.allclose(op.rapply(y), np.conj(alpha) * dense.conj().T @ np.asarray(y)) + + +def test_composition_apply_and_adjoint_are_numerically_correct(): + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), ctx) + B = _op([[2.0, -1.0], [0.5, 3.0]], (2,), (2,), ctx) + x = ctx.asarray([4.0, -2.0]) + z = ctx.asarray([1.0, -1.0, 2.0]) + dense_a = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + dense_b = np.array([[2.0, -1.0], [0.5, 3.0]]) + + op = A @ B + + assert op.domain == B.domain + assert op.codomain == A.codomain + assert np.allclose(op.apply(x), dense_a @ dense_b @ np.asarray(x)) + assert np.allclose(op.rapply(z), dense_b.T @ dense_a.T @ np.asarray(z)) + + +def test_H_swaps_spaces_and_double_H_returns_original(): + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), ctx) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + + AH = A.H + AHH = AH.H + + assert AH.ctx == A.ctx + assert AH.domain == A.codomain + assert AH.codomain == A.domain + assert np.allclose(AH.apply(y), A.rapply(y)) + assert np.allclose(AH.rapply(x), A.apply(x)) + assert AHH is A + assert np.allclose(AHH.apply(x), A.apply(x)) + assert np.allclose(AHH.rapply(y), A.rapply(y)) + + +def test_zero_identity_and_matrix_free_rapply_are_numerically_correct(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + + zero = sc.ZeroLinOp(dom, cod, ctx) + identity = sc.IdentityLinOp(dom, ctx) + matrix_free = sc.MatrixFreeLinOp( + lambda v: ctx.asarray(dense @ np.asarray(v)), + lambda w: ctx.asarray(dense.T @ np.asarray(w)), + dom, + cod, + ctx, + ) + + assert np.allclose(zero.apply(x), np.zeros(3)) + assert np.allclose(zero.rapply(y), np.zeros(2)) + assert np.allclose(identity.apply(x), np.asarray(x)) + assert np.allclose(identity.rapply(x), np.asarray(x)) + assert np.allclose(matrix_free.apply(x), dense @ np.asarray(x)) + assert np.allclose(matrix_free.rapply(y), dense.T @ np.asarray(y)) + + +def test_sum_factory_flattens_nested_sums_and_removes_zero_terms(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), ctx) + B = _op([[5.0, 1.0], [-2.0, 3.0]], (2,), (2,), ctx) + Z = sc.ZeroLinOp(A.domain, A.codomain, ctx) + x = ctx.asarray([2.0, -1.0]) + y = ctx.asarray([1.0, 3.0]) + + nested = sc.SumLinOp((A, B)) + simplified = nested + Z + zero_sum = Z + Z + + assert isinstance(simplified, sc.SumLinOp) + assert simplified.parts == (A, B) + assert A + Z is A + assert Z + A is A + assert isinstance(zero_sum, sc.ZeroLinOp) + assert zero_sum.domain == A.domain + assert zero_sum.codomain == A.codomain + + unsimplified = sc.SumLinOp((nested, Z)) + assert np.allclose(simplified.apply(x), unsimplified.apply(x)) + assert np.allclose(simplified.rapply(y), unsimplified.rapply(y)) + + +def test_scaling_factory_simplifies_zero_one_and_nested_scaling(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), ctx) + x = ctx.asarray([2.0, -1.0]) + y = ctx.asarray([1.0, 3.0]) + dense = np.array([[1.0, 2.0], [3.0, 4.0]]) + + zero = 0 * A + unit = 1 * A + nested = 2 * (3 * A) + + assert isinstance(zero, sc.ZeroLinOp) + assert unit is A + assert isinstance(nested, sc.ScaledLinOp) + assert nested.scalar == 6 + assert nested.op is A + assert np.allclose(zero.apply(x), np.zeros(2)) + assert np.allclose(zero.rapply(y), np.zeros(2)) + assert np.allclose(nested.apply(x), 6 * dense @ np.asarray(x)) + assert np.allclose(nested.rapply(y), 6 * dense.T @ np.asarray(y)) + + +def test_composition_factory_simplifies_identity_and_zero_factors(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), ctx) + id_domain = sc.IdentityLinOp(A.domain, ctx) + id_codomain = sc.IdentityLinOp(A.codomain, ctx) + left_zero = sc.ZeroLinOp(A.codomain, sc.VectorSpace((4,), ctx), ctx) + right_zero = sc.ZeroLinOp(sc.VectorSpace((5,), ctx), A.domain, ctx) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + + assert A @ id_domain is A + assert id_codomain @ A is A + + left_simplified = left_zero @ A + right_simplified = A @ right_zero + + assert isinstance(left_simplified, sc.ZeroLinOp) + assert left_simplified.domain == A.domain + assert left_simplified.codomain == left_zero.codomain + assert isinstance(right_simplified, sc.ZeroLinOp) + assert right_simplified.domain == right_zero.domain + assert right_simplified.codomain == A.codomain + + unsimplified_left = sc.ComposedLinOp(left_zero, A) + assert np.allclose((A @ id_domain).apply(x), dense @ np.asarray(x)) + assert np.allclose((id_codomain @ A).rapply(y), dense.T @ np.asarray(y)) + assert np.allclose(left_simplified.apply(x), unsimplified_left.apply(x)) + assert np.allclose(left_simplified.rapply(ctx.asarray([1.0, 2.0, 3.0, 4.0])), np.zeros(2)) From d5cf13982eaa88e8c3643b4445af082af29cbdb2 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:04:50 -0300 Subject: [PATCH 04/44] Fix LinOp base equality protocol --- spacecore/linop/_base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index 05707d5..6cd45aa 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -153,12 +153,14 @@ def assert_domain(self, x: Any) -> None: def assert_codomain(self, y: Any) -> None: self.cod.check_member(y) - def __eq__(self, x: Any) -> bool: - raise NotImplementedError() + def __eq__(self, other: Any) -> bool: + return NotImplemented + @abstractmethod def tree_flatten(self): - raise NotImplementedError() + ... @classmethod + @abstractmethod def tree_unflatten(cls, aux, children): - raise NotImplementedError() + ... From c454332b4f02d19cae37909c41c41ff596833201 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:05:37 -0300 Subject: [PATCH 05/44] Document Space membership check convention --- spacecore/space/_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/spacecore/space/_base.py b/spacecore/space/_base.py index 26eaf6f..d206bb4 100644 --- a/spacecore/space/_base.py +++ b/spacecore/space/_base.py @@ -16,6 +16,11 @@ class Space(ContextBound): A Space owns the *geometry* (inner product, norm) and the basic linear structure (add/scale/axpy) for its elements. + Membership validation is exposed through ``check_member``, which respects + the space's ``enable_checks`` policy. Internal code paths that have already + checked that policy may call ``_check_member`` to run the concrete checks + exactly once. + Solvers should use only this API. """ From 2d54d6458f066e16a3156b4510fb974df432a26f Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:06:04 -0300 Subject: [PATCH 06/44] Simplify scaled zero LinOps --- spacecore/linop/_algebra.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index f104c2f..a1e7cc8 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -128,6 +128,8 @@ def make_scaled(scalar: Any, op: LinOp) -> LinOp: return ZeroLinOp(op.domain, op.codomain, op.ctx) if _is_one_scalar(scalar): return op + if isinstance(op, ZeroLinOp): + return op if isinstance(op, ScaledLinOp): return make_scaled(scalar * op.scalar, op.op) return ScaledLinOp(scalar, op) From 57283f806fad0a05644884cf6b0209b663c17da7 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:07:59 -0300 Subject: [PATCH 07/44] Add dense materialization for LinOps --- spacecore/linop/_algebra.py | 20 ++++++++ spacecore/linop/_base.py | 23 +++++++++ spacecore/linop/_dense.py | 8 +++ tests/linops/test_to_dense.py | 93 +++++++++++++++++++++++++++++++++++ 4 files changed, 144 insertions(+) create mode 100644 tests/linops/test_to_dense.py diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index a1e7cc8..cebd8b7 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -375,6 +375,14 @@ def rapply(self, y: Any) -> Any: self.codomain._check_member(y) return self.domain.zeros() + def to_dense(self) -> Any: + """ + Return the dense tensor representation of the zero map. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + return self.ops.zeros(tuple(self.codomain.shape) + tuple(self.domain.shape), dtype=self.dtype) + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.domain == other.domain and self.codomain == other.codomain @@ -422,6 +430,18 @@ def rapply(self, x: Any) -> Any: self.codomain._check_member(x) return x + def to_dense(self) -> Any: + """ + Return the dense tensor representation of this identity map. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + size = 1 + for dim in self.domain.shape: + size *= dim + eye = self.ops.eye(size, dtype=self.dtype) + return self.ops.reshape(eye, tuple(self.codomain.shape) + tuple(self.domain.shape)) + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.domain == other.domain diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index 6cd45aa..f59d02e 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import abstractmethod +from math import prod from numbers import Number from typing import Any, Generic, TypeVar @@ -147,6 +148,28 @@ def adjoint(self) -> LinOp: """Return the Hermitian-adjoint view of this linear operator.""" return self.H + def to_dense(self) -> Any: + """ + Materialize this operator as a dense backend array. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + The default implementation applies the operator to each standard basis + vector of the domain, stacks the flattened outputs as matrix columns, + and reshapes the result back to tensor-operator form. Subclasses that + already store the matrix should override this method for efficiency. + """ + domain_size = prod(self.domain.shape) + codomain_size = prod(self.codomain.shape) + zero = self.ops.zeros((domain_size,), dtype=self.dtype) + columns = [] + for i in range(domain_size): + basis_vector = self.ops.index_set(zero, i, 1, copy=True) + x = self.domain.unflatten(basis_vector) + y = self.apply(x) + columns.append(self.codomain.flatten(y)) + matrix = self.ops.stack(tuple(columns), axis=1) + return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape)) + def assert_domain(self, x: Any) -> None: self.dom.check_member(x) diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index a751799..9b36a09 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -88,6 +88,14 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + def to_dense(self) -> DenseArray: + """ + Return the stored dense tensor representation of this operator. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + return self.A + def __eq__(self, x: Any) -> bool: if type(x) is type(self): return (self.dom == x.dom diff --git a/tests/linops/test_to_dense.py b/tests/linops/test_to_dense.py new file mode 100644 index 0000000..5ca33f2 --- /dev/null +++ b/tests/linops/test_to_dense.py @@ -0,0 +1,93 @@ +import importlib + +import numpy as np +import scipy.sparse as sps + + +def _ctx(): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=np.float64) + + +def _assert_to_dense_matches_apply(op, x): + dense = op.to_dense() + matrix = dense.reshape((np.prod(op.codomain.shape), np.prod(op.domain.shape))) + y_from_dense = matrix @ op.domain.flatten(x) + y_from_apply = op.codomain.flatten(op.apply(x)) + assert np.allclose(y_from_dense, y_from_apply) + + +def test_dense_linop_to_dense_returns_stored_matrix_and_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(A, dom, cod, ctx) + + assert op.to_dense() is A + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_sparse_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(sps.csr_matrix(dense), dom, cod, ctx) + + assert np.allclose(op.to_dense(), dense) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_identity_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2, 2), ctx) + op = sc.IdentityLinOp(space, ctx) + + assert np.allclose(op.to_dense().reshape((4, 4)), np.eye(4)) + _assert_to_dense_matches_apply(op, ctx.asarray([[1.0, 2.0], [3.0, 4.0]])) + + +def test_zero_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + op = sc.ZeroLinOp(dom, cod, ctx) + + assert np.allclose(op.to_dense(), np.zeros((3, 2))) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_matrix_free_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.MatrixFreeLinOp( + lambda x: ctx.asarray(dense @ np.asarray(x)), + lambda y: ctx.asarray(dense.T @ np.asarray(y)), + dom, + cod, + ctx, + ) + + assert np.allclose(op.to_dense(), dense) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_sum_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), dom, cod, ctx) + B = sc.DenseLinOp(ctx.asarray([[0.5, 1.0], [-1.0, 2.0], [3.0, -0.5]]), dom, cod, ctx) + op = A + B + + assert np.allclose(op.to_dense(), A.to_dense() + B.to_dense()) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) From 89f007bc3dc38ca59912d132ba96a09800cb547e Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:11:35 -0300 Subject: [PATCH 08/44] Add algebra layer regression tests --- spacecore/linop/_base.py | 7 +- tests/linops/test_algebra.py | 285 +++++++++++++++++++++++++++++++++++ 2 files changed, 290 insertions(+), 2 deletions(-) create mode 100644 tests/linops/test_algebra.py diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index f59d02e..d1a05be 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -76,7 +76,11 @@ def H(self) -> LinOp: """Hermitian-adjoint view of this linear operator.""" from ._algebra import _AdjointViewLinOp - return _AdjointViewLinOp(self) + view = getattr(self, "_adjoint_view", None) + if view is None: + view = _AdjointViewLinOp(self) + self._adjoint_view = view + return view def __add__(self, other: Any) -> LinOp: """Return the lazy sum ``self + other`` of two compatible operators.""" @@ -159,7 +163,6 @@ def to_dense(self) -> Any: already store the matrix should override this method for efficiency. """ domain_size = prod(self.domain.shape) - codomain_size = prod(self.codomain.shape) zero = self.ops.zeros((domain_size,), dtype=self.dtype) columns = [] for i in range(domain_size): diff --git a/tests/linops/test_algebra.py b/tests/linops/test_algebra.py new file mode 100644 index 0000000..770b20c --- /dev/null +++ b/tests/linops/test_algebra.py @@ -0,0 +1,285 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_jax, has_torch, jax_complex_dtype, jax_real_dtype +from tests._helpers import to_numpy, torch_complex_dtype + + +def _backend_params(): + params = [pytest.param("numpy", np.complex128, id="numpy")] + params.append( + pytest.param( + "jax", + jax_complex_dtype(), + marks=pytest.mark.skipif(not has_jax(), reason="jax is not installed"), + id="jax", + ) + ) + params.append( + pytest.param( + "torch", + torch_complex_dtype(), + marks=pytest.mark.skipif(not has_torch(), reason="torch is not installed"), + id="torch", + ) + ) + return params + + +def _ops_for_backend(name): + sc = importlib.import_module("spacecore") + if name == "numpy": + return sc.NumpyOps() + if name == "jax": + return sc.JaxOps() + if name == "torch": + return sc.TorchOps() + raise ValueError(f"Unknown backend {name!r}.") + + +def _ctx(dtype=np.complex128, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def _spaces(ctx): + sc = importlib.import_module("spacecore") + return sc.VectorSpace((2,), ctx), sc.VectorSpace((3,), ctx) + + +def _matrix(): + return np.array( + [ + [1.0 + 2.0j, 3.0 - 1.0j], + [-2.0 + 0.5j, 0.25 + 4.0j], + [1.5 - 3.0j, -0.75 + 2.0j], + ] + ) + + +def _square_matrix(): + return np.array([[2.0 - 1.0j, -0.5 + 0.25j], [1.25 + 2.0j, -3.0 + 0.5j]]) + + +def _dense_linop(ctx): + sc = importlib.import_module("spacecore") + dom, cod = _spaces(ctx) + return sc.DenseLinOp(ctx.asarray(_matrix()), dom, cod, ctx) + + +def _dense_same_shape(ctx, scale=1.0): + sc = importlib.import_module("spacecore") + dom, cod = _spaces(ctx) + return sc.DenseLinOp(ctx.asarray(scale * _matrix()), dom, cod, ctx) + + +def _dense_square(ctx): + sc = importlib.import_module("spacecore") + dom = sc.VectorSpace((2,), ctx) + return sc.DenseLinOp(ctx.asarray(_square_matrix()), dom, dom, ctx) + + +def _xy(ctx): + x = ctx.asarray([2.0 - 1.0j, -0.5 + 0.25j]) + y = ctx.asarray([1.0 + 0.5j, -2.0j, 0.75 - 1.25j]) + return x, y + + +def _assert_adjoint_identity(op, x, y, ctx): + lhs = ctx.ops.vdot(op.apply(x), y) + rhs = ctx.ops.vdot(x, op.rapply(y)) + np.testing.assert_allclose(to_numpy(lhs), to_numpy(rhs), rtol=1e-6, atol=1e-6) + + +def _adjoint_cases(ctx): + sc = importlib.import_module("spacecore") + A = _dense_linop(ctx) + B = _dense_same_shape(ctx, scale=0.5 - 0.25j) + C = _dense_square(ctx) + dom, cod = _spaces(ctx) + x, y = _xy(ctx) + z = ctx.asarray([-1.0 + 0.5j, 2.0 - 0.25j]) + + matrix = ctx.asarray(_matrix()) + matrix_free = sc.MatrixFreeLinOp( + lambda v: matrix @ v, + lambda w: ctx.ops.conj(ctx.ops.transpose(matrix)) @ w, + dom, + cod, + ctx, + ) + + return [ + ((2.0 + 3.0j) * A, x, y), + (A + B, x, y), + (A @ C, z, y), + (sc.ZeroLinOp(dom, cod, ctx), x, y), + (sc.IdentityLinOp(dom, ctx), x, x), + (matrix_free, x, y), + (A.H, y, x), + ] + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +@pytest.mark.parametrize("case_index", range(7)) +def test_complex_adjoint_identity_for_algebra_classes(backend_name, dtype, case_index): + sc = importlib.import_module("spacecore") + ctx = sc.Context(_ops_for_backend(backend_name), dtype=dtype) + op, x, y = _adjoint_cases(ctx)[case_index] + + _assert_adjoint_identity(op, x, y, ctx) + + +def test_simplification_canonicalizations(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _dense_linop(ctx) + B = _dense_same_shape(ctx, scale=2.0) + C = _dense_same_shape(ctx, scale=-1.0) + Z = sc.ZeroLinOp(A.domain, A.codomain, ctx) + + assert sc.make_sum((A, Z)) is A + assert isinstance(sc.make_sum((Z, Z)), sc.ZeroLinOp) + assert sc.make_sum((A,)) is A + flattened = sc.make_sum((sc.make_sum((A, B)), C)) + assert isinstance(flattened, sc.SumLinOp) + assert flattened.parts == (A, B, C) + + scaled_zero = sc.make_scaled(0, A) + assert isinstance(scaled_zero, sc.ZeroLinOp) + assert scaled_zero.domain == A.domain + assert scaled_zero.codomain == A.codomain + assert sc.make_scaled(1, A) is A + assert sc.make_scaled(7.0, Z) is Z + folded = sc.make_scaled(2, sc.make_scaled(3, A)) + assert isinstance(folded, sc.ScaledLinOp) + assert folded.scalar == 6 + assert folded.op is A + + I_dom = sc.IdentityLinOp(A.domain, ctx) + I_cod = sc.IdentityLinOp(A.codomain, ctx) + assert sc.make_composed(I_cod, A) is A + assert sc.make_composed(A, I_dom) is A + + out = sc.VectorSpace((4,), ctx) + left_zero = sc.ZeroLinOp(A.codomain, out, ctx) + composed_zero = sc.make_composed(left_zero, A) + assert isinstance(composed_zero, sc.ZeroLinOp) + assert composed_zero.domain == A.domain + assert composed_zero.codomain == out + + +@pytest.mark.parametrize("case_index", range(7)) +def test_double_adjoint_view_returns_literal_original(case_index): + ctx = _ctx() + op, _, _ = _adjoint_cases(ctx)[case_index] + + assert op.H.H is op + + +def test_identity_linop_apply_is_literal_input_when_checks_disabled(): + sc = importlib.import_module("spacecore") + ctx = _ctx(enable_checks=False) + space = sc.VectorSpace((2,), ctx) + op = sc.IdentityLinOp(space, ctx) + x = ctx.asarray([1.0 + 2.0j, 3.0 - 4.0j]) + + assert op.apply(x) is x + assert op.rapply(x) is x + + +def test_identity_linop_apply_equals_input_when_checks_enabled(): + sc = importlib.import_module("spacecore") + ctx = _ctx(enable_checks=True) + space = sc.VectorSpace((2,), ctx) + op = sc.IdentityLinOp(space, ctx) + x = ctx.asarray([1.0 + 2.0j, 3.0 - 4.0j]) + + np.testing.assert_allclose(op.apply(x), x) + np.testing.assert_allclose(op.rapply(x), x) + + +def test_python_sum_starts_from_zero_and_accumulates_linops(): + ctx = _ctx() + A = _dense_same_shape(ctx, scale=1.0) + B = _dense_same_shape(ctx, scale=0.5) + C = _dense_same_shape(ctx, scale=-2.0) + x, _ = _xy(ctx) + + op = sum([A, B, C]) + expected = A.apply(x) + B.apply(x) + C.apply(x) + + np.testing.assert_allclose(op.apply(x), expected) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +@pytest.mark.parametrize("case_index", range(7)) +def test_jax_pytree_roundtrip_for_algebra_classes(case_index): + import jax + + ctx = _ctx() + op, _, _ = _adjoint_cases(ctx)[case_index] + leaves, treedef = jax.tree.flatten(op) + rebuilt = jax.tree.unflatten(treedef, leaves) + + assert rebuilt == op + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_jit_algebra_expression_matches_eager(): + import jax + + sc = importlib.import_module("spacecore") + ctx = sc.Context(sc.JaxOps(), dtype=jax_real_dtype(), enable_checks=False) + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), X, Y, ctx) + B = sc.DenseLinOp(ctx.asarray([[0.5, -1.0], [2.0, 1.0], [-0.5, 3.0]]), X, Y, ctx) + C = sc.DenseLinOp(ctx.asarray([[2.0, -1.0], [0.25, 1.5]]), X, X, ctx) + expr = (2 * A + B) @ C + x = ctx.asarray([1.0, -2.0]) + + apply_jit = jax.jit(lambda op, z: op.apply(z)) + + np.testing.assert_allclose(to_numpy(apply_jit(expr, x)), to_numpy(expr.apply(x))) + + +def test_factories_enforce_same_context_dtype(): + sc = importlib.import_module("spacecore") + ctx32 = sc.Context(sc.NumpyOps(), dtype=np.float32) + ctx64 = sc.Context(sc.NumpyOps(), dtype=np.float64) + X32 = sc.VectorSpace((2,), ctx32) + Y32 = sc.VectorSpace((2,), ctx32) + X64 = sc.VectorSpace((2,), ctx64) + Y64 = sc.VectorSpace((2,), ctx64) + A32 = sc.DenseLinOp(ctx32.asarray([[1.0, 2.0], [3.0, 4.0]]), X32, Y32, ctx32) + A64 = sc.DenseLinOp(ctx64.asarray([[1.0, 2.0], [3.0, 4.0]]), X64, Y64, ctx64) + + with pytest.raises(ValueError, match="same ctx"): + sc.make_sum((A32, A64)) + with pytest.raises(ValueError, match="same ctx"): + sc.make_composed(A32, A64) + + +def test_factories_enforce_domain_and_codomain_compatibility(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.float64) + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + Z = sc.VectorSpace((4,), ctx) + A = sc.DenseLinOp(ctx.asarray(np.ones((3, 2))), X, Y, ctx) + B = sc.DenseLinOp(ctx.asarray(np.ones((4, 2))), X, Z, ctx) + + with pytest.raises(ValueError, match="same domain and codomain"): + sc.make_sum((A, B)) + with pytest.raises(ValueError, match="right.codomain == left.domain"): + sc.make_composed(A, B) + + +def test_base_linop_equality_protocol_does_not_raise(): + A = _dense_linop(_ctx()) + + assert (A == None) is False # noqa: E711 + assert A in [A] From 5e816579facf561931ca8a64251c09110555d051 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:22:00 -0300 Subject: [PATCH 09/44] Polish LinOp dense materialization --- spacecore/linop/_base.py | 4 ++-- spacecore/linop/_sparse.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index d1a05be..f006e9a 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -163,10 +163,10 @@ def to_dense(self) -> Any: already store the matrix should override this method for efficiency. """ domain_size = prod(self.domain.shape) - zero = self.ops.zeros((domain_size,), dtype=self.dtype) + eye = self.ops.eye(domain_size, dtype=self.dtype) columns = [] for i in range(domain_size): - basis_vector = self.ops.index_set(zero, i, 1, copy=True) + basis_vector = eye[:, i] x = self.domain.unflatten(basis_vector) y = self.apply(x) columns.append(self.codomain.flatten(y)) diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index 20780ab..7a8162e 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -87,6 +87,22 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + def to_dense(self) -> DenseArray: + """ + Materialize the stored sparse matrix as a dense operator tensor. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + if hasattr(self.A, "toarray"): + dense = self.A.toarray() + elif hasattr(self.A, "todense"): + dense = self.A.todense() + elif hasattr(self.A, "to_dense"): + dense = self.A.to_dense() + else: + dense = super().to_dense().reshape((self._cod_size, self._dom_size)) + return self.ops.reshape(dense, tuple(self.codomain.shape) + tuple(self.domain.shape)) + def __eq__(self, x: Any) -> bool: if type(x) is type(self): return (self.dom == x.dom From ecbdd29358466286aa3b490d363d7d79a74abe5f Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 01:57:35 -0300 Subject: [PATCH 10/44] Add CuPy backend support --- pyproject.toml | 3 + spacecore/__init__.py | 6 + spacecore/_contextual/contextual.py | 6 + spacecore/backend/__init__.py | 7 + spacecore/backend/_family.py | 1 + spacecore/backend/cupy/__init__.py | 3 + spacecore/backend/cupy/_ops.py | 288 +++++++++++++++++++++++++ tests/_helpers.py | 41 ++++ tests/backend/test_backend_registry.py | 10 +- tests/backend/test_cupy_ops.py | 66 ++++++ tests/integration/test_public_api.py | 6 +- tests/test_backend_ops_complex.py | 10 +- 12 files changed, 444 insertions(+), 3 deletions(-) create mode 100644 spacecore/backend/cupy/__init__.py create mode 100644 spacecore/backend/cupy/_ops.py create mode 100644 tests/backend/test_cupy_ops.py diff --git a/pyproject.toml b/pyproject.toml index 2796cbb..91f9bc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ jax = [ torch = [ "torch>=2.0", ] +cupy = [ + "cupy>=13.0", +] examples = [ "matplotlib>=3.8", "optax>=0.2", diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 4371730..6d3fbf0 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -2,6 +2,10 @@ from .backend import Context, BackendOps, JaxOps, NumpyOps, jax_pytree_class +try: + from .backend import CuPyOps as CuPyOps +except ImportError: + pass try: from .backend import TorchOps as TorchOps except ImportError: @@ -104,3 +108,5 @@ if "TorchOps" in globals(): __all__.append("TorchOps") +if "CuPyOps" in globals(): + __all__.append("CuPyOps") diff --git a/spacecore/_contextual/contextual.py b/spacecore/_contextual/contextual.py index 67efe45..aa21109 100644 --- a/spacecore/_contextual/contextual.py +++ b/spacecore/_contextual/contextual.py @@ -6,6 +6,10 @@ from ..types import DType from ..backend import Context, NumpyOps, JaxOps, BackendFamily, BackendOps +try: + from ..backend import CuPyOps +except ImportError: + pass try: from ..backend import TorchOps except ImportError: @@ -98,6 +102,8 @@ def __init__(self, self._backend_key(NumpyOps): NumpyOps, self._backend_key(JaxOps): JaxOps, } + if "CuPyOps" in globals(): + self._available_ops[self._backend_key(CuPyOps)] = CuPyOps if "TorchOps" in globals(): self._available_ops[self._backend_key(TorchOps)] = TorchOps diff --git a/spacecore/backend/__init__.py b/spacecore/backend/__init__.py index f45af32..99c97f9 100644 --- a/spacecore/backend/__init__.py +++ b/spacecore/backend/__init__.py @@ -3,6 +3,11 @@ from ._family import BackendFamily from .jax import JaxOps, jax_pytree_class from .numpy import NumpyOps +try: + from .cupy import CuPyOps as CuPyOps +except ModuleNotFoundError as exc: + if exc.name != "cupy": + raise try: from .torch import TorchOps as TorchOps @@ -19,5 +24,7 @@ "NumpyOps", ] +if "CuPyOps" in globals(): + __all__.append("CuPyOps") if "TorchOps" in globals(): __all__.append("TorchOps") diff --git a/spacecore/backend/_family.py b/spacecore/backend/_family.py index 90a1183..17de281 100644 --- a/spacecore/backend/_family.py +++ b/spacecore/backend/_family.py @@ -5,3 +5,4 @@ class BackendFamily(StrEnum): numpy = auto() jax = auto() torch = auto() + cupy = auto() diff --git a/spacecore/backend/cupy/__init__.py b/spacecore/backend/cupy/__init__.py new file mode 100644 index 0000000..8908ad9 --- /dev/null +++ b/spacecore/backend/cupy/__init__.py @@ -0,0 +1,3 @@ +from ._ops import CuPyOps as CuPyOps + +__all__ = ["CuPyOps"] diff --git a/spacecore/backend/cupy/_ops.py b/spacecore/backend/cupy/_ops.py new file mode 100644 index 0000000..0d75c2d --- /dev/null +++ b/spacecore/backend/cupy/_ops.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +from typing import Any, Callable, Literal, Optional, Sequence, Tuple, Type + +from .._family import BackendFamily +from .._ops import BackendOps +from ...types import ArrayLike, Carry, DenseArray, DType, Index, R, SparseArray, T, X, Y + + +class CuPyOps(BackendOps): + """ + BackendOps implementation for CuPy GPU arrays. + + This backend uses CuPy for dense array operations and ``cupyx.scipy.sparse`` + for sparse arrays. Most operations follow CuPy's NumPy-compatible API and + execute on the active CUDA device. + + Dense arrays + ``cupy.ndarray`` + + Sparse arrays + ``cupyx.scipy.sparse`` matrix types such as CSR, CSC, and COO. + """ + + import cupy as cp + import cupyx.scipy as cpx_scipy + import cupyx.scipy.sparse as cpx_sparse + + xp = cp + + _family = BackendFamily.cupy.value.lower() + _allow_sparse = True + + @property + def dense_array(self) -> Type[Any]: + """Dense CuPy array type.""" + return self.cp.ndarray + + @property + def sparse_array(self) -> Tuple[Type[Any], ...]: + """Sparse CuPy array type tuple.""" + sparse = self.cpx_sparse + types: list[type[Any]] = [] + for name in ("spmatrix", "csr_matrix", "csc_matrix", "coo_matrix"): + typ = getattr(sparse, name, None) + if typ is not None: + types.append(typ) + return tuple(types) + + def sanitize_dtype(self, dtype: DType | None) -> DType: + """ + Normalize a dtype specifier using CuPy. + + ``None`` follows NumPy/CuPy's float64 default. + """ + if dtype is None: + return self.cp.float64 + return self.cp.dtype(dtype) + + def assparse( + self, + x: Any, + *, + format: Literal["csr", "csc", "coo"] = "csr", + dtype: DType | None = None, + ) -> SparseArray: + """ + Convert input to a CuPy sparse matrix. + + Dense inputs must be two-dimensional. Existing sparse inputs are + converted to the requested sparse format. + """ + sparse = self.cpx_sparse + + if self.is_sparse(x): + if format == "csr": + return x.tocsr() + if format == "csc": + return x.tocsc() + if format == "coo": + return x.tocoo() + raise ValueError(f"Unknown sparse format: {format!r}") + + x_arr = self.asarray(x, dtype=dtype) + if x_arr.ndim != 2: + raise ValueError("CuPy sparse conversion currently expects a 2D array.") + + if format == "csr": + return sparse.csr_matrix(x_arr) + if format == "csc": + return sparse.csc_matrix(x_arr) + if format == "coo": + return sparse.coo_matrix(x_arr) + raise ValueError(f"Unknown sparse format: {format!r}") + + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: + """Multiply a CuPy sparse matrix by a CuPy dense array.""" + if not self.is_sparse(a): + raise TypeError("sparse_matmul expects a CuPy sparse matrix.") + if not self.is_dense(b): + raise TypeError("sparse_matmul expects a CuPy dense array.") + return a @ b + + def logsumexp( + self, + a: DenseArray, + axis: int | Sequence[int] | None = None, + b: DenseArray | None = None, + keepdims: bool = False, + return_sign: bool = False, + ) -> DenseArray | Tuple[DenseArray, DenseArray]: + """Compute log-sum-exp using ``cupyx.scipy.special``.""" + return self.cpx_scipy.special.logsumexp( + a, + axis=axis, + b=b, + keepdims=keepdims, + return_sign=return_sign, + ) + + def index_set( + self, + x: DenseArray, + index: Index, + values: ArrayLike, + *, + copy: bool = True, + ) -> DenseArray: + """Set indexed values in a CuPy array.""" + y = x.copy() if copy else x + y[index] = values + return y + + def index_add( + self, + x: DenseArray, + index: Index, + values: DenseArray, + *, + copy: bool = True, + ) -> DenseArray: + """Add values into indexed entries of a CuPy array.""" + y = x.copy() if copy else x + self.cp.add.at(y, index, values) + return y + + def ix_(self, *args: Any) -> Any: + """Build open-mesh indices using CuPy.""" + return self.cp.ix_(*args) + + def fori_loop( + self, + lower: int, + upper: int, + body_fun: Callable[[int, T], T], + init_val: T, + ) -> T: + """Run a counted loop eagerly in Python for CuPy.""" + val = init_val + for i in range(int(lower), int(upper)): + val = body_fun(i, val) + return val + + def while_loop( + self, + cond_fun: Callable[[T], bool], + body_fun: Callable[[T], T], + init_val: T, + ) -> T: + """Run a while loop eagerly in Python for CuPy.""" + val = init_val + while bool(cond_fun(val)): + val = body_fun(val) + return val + + def _tree_map(self, f: Callable[[Any], Any], tree: Any) -> Any: + if isinstance(tree, dict): + return {k: self._tree_map(f, v) for k, v in tree.items()} + if isinstance(tree, tuple): + return tuple(self._tree_map(f, v) for v in tree) + if isinstance(tree, list): + return [self._tree_map(f, v) for v in tree] + return f(tree) + + def _tree_multimap(self, f: Callable[..., Any], *trees: Any) -> Any: + t0 = trees[0] + if isinstance(t0, dict): + return {k: self._tree_multimap(f, *(t[k] for t in trees)) for k in t0.keys()} + if isinstance(t0, tuple): + return tuple(self._tree_multimap(f, *(t[i] for t in trees)) for i in range(len(t0))) + if isinstance(t0, list): + return [self._tree_multimap(f, *(t[i] for t in trees)) for i in range(len(t0))] + return f(*trees) + + def _tree_take0(self, xs: Any) -> Any: + if isinstance(xs, dict): + return self._tree_take0(next(iter(xs.values()))) + if isinstance(xs, (tuple, list)): + return self._tree_take0(xs[0]) + return xs + + def _tree_index(self, xs: Any, i: int) -> Any: + def _idx(a: Any) -> Any: + try: + return a[i] + except Exception: + return a + + return self._tree_map(_idx, xs) + + def _tree_stack(self, ys_list: Sequence[Any]) -> Any: + if not ys_list: + return () + + def _stack_leaves(*leaves: Any) -> Any: + try: + return self.cp.stack(leaves, axis=0) + except Exception: + return self.cp.asarray(leaves) + + return self._tree_multimap(_stack_leaves, *ys_list) + + def scan( + self, + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + ) -> Tuple[Carry, Y]: + """Run a scan loop eagerly in Python for CuPy.""" + carry = init + if xs is None: + if length is None: + raise ValueError("scan(xs=None) requires an explicit `length`.") + n = int(length) + indices = range(n - 1, -1, -1) if reverse else range(n) + ys_steps: list[Any] = [] + for _i in indices: + carry, y = f(carry, None) # type: ignore[arg-type] + ys_steps.append(y) + if reverse: + ys_steps.reverse() + return carry, self._tree_stack(ys_steps) + + if length is None: + leaf0 = self._tree_take0(xs) + try: + n = int(leaf0.shape[0]) + except Exception as e: + raise ValueError( + "Could not infer scan length from `xs`; pass `length=` explicitly." + ) from e + else: + n = int(length) + + indices = range(n - 1, -1, -1) if reverse else range(n) + ys_steps = [] + for i in indices: + x_i = self._tree_index(xs, i) + carry, y = f(carry, x_i) + ys_steps.append(y) + if reverse: + ys_steps.reverse() + return carry, self._tree_stack(ys_steps) + + def cond( + self, + pred: bool, + true_fun: Callable[[T], R], + false_fun: Callable[[T], R], + *operands: Any, + ) -> R: + """Run conditional branch selection eagerly in Python for CuPy.""" + return true_fun(*operands) if bool(pred) else false_fun(*operands) + + def allclose_sparse( + self, + a: SparseArray, + b: SparseArray, + rtol: float = 1e-5, + atol: float = 1e-8, + ) -> bool: + """Compare two CuPy sparse matrices by dense values.""" + if not self.is_sparse(a) or not self.is_sparse(b): + raise TypeError("allclose_sparse expects two CuPy sparse matrices.") + return bool(self.cp.asnumpy(self.cp.allclose(a.toarray(), b.toarray(), rtol=rtol, atol=atol))) diff --git a/tests/_helpers.py b/tests/_helpers.py index 267769c..d06511c 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,5 +1,6 @@ from __future__ import annotations import importlib.util +from functools import lru_cache import numpy as np @@ -11,6 +12,18 @@ def has_torch() -> bool: return importlib.util.find_spec("torch") is not None +@lru_cache +def has_cupy() -> bool: + if importlib.util.find_spec("cupy") is None: + return False + try: + import cupy + cupy.asarray([0]).sum() + except Exception: + return False + return True + + def jax_real_dtype(): if not has_jax(): return np.float32 @@ -36,9 +49,37 @@ def torch_complex_dtype(): return torch.complex128 if torch.get_default_dtype() == torch.float64 else torch.complex64 +def cupy_real_dtype(): + return np.float64 + + +def cupy_complex_dtype(): + return np.complex128 + + def to_numpy(x): if isinstance(x, tuple): return tuple(to_numpy(xi) for xi in x) + if has_cupy(): + import cupy + if isinstance(x, cupy.ndarray): + return cupy.asnumpy(x) + try: + import cupyx.scipy.sparse as cupy_sparse + sparse_types = tuple( + typ + for typ in ( + getattr(cupy_sparse, "spmatrix", None), + getattr(cupy_sparse, "csr_matrix", None), + getattr(cupy_sparse, "csc_matrix", None), + getattr(cupy_sparse, "coo_matrix", None), + ) + if typ is not None + ) + if sparse_types and isinstance(x, sparse_types): + return cupy.asnumpy(x.toarray()) + except Exception: + pass if has_torch(): import torch if isinstance(x, torch.Tensor): diff --git a/tests/backend/test_backend_registry.py b/tests/backend/test_backend_registry.py index 65ca47a..2ec1348 100644 --- a/tests/backend/test_backend_registry.py +++ b/tests/backend/test_backend_registry.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from tests._helpers import has_torch +from tests._helpers import has_cupy, has_torch def test_builtin_backends_are_usable(): @@ -73,3 +73,11 @@ def test_torch_backend_aliases_resolve_when_available(): assert isinstance(sc.TorchOps(), sc.BackendOps) assert sc.VectorSpace((1,), "torch").ctx.ops.family == "torch" assert sc.VectorSpace((1,), "pytorch").ctx.ops.family == "torch" + + +@pytest.mark.skipif(not has_cupy(), reason="cupy is not installed") +def test_cupy_backend_alias_resolves_when_available(): + sc = importlib.import_module("spacecore") + + assert isinstance(sc.CuPyOps(), sc.BackendOps) + assert sc.VectorSpace((1,), "cupy").ctx.ops.family == "cupy" diff --git a/tests/backend/test_cupy_ops.py b/tests/backend/test_cupy_ops.py new file mode 100644 index 0000000..dc547ba --- /dev/null +++ b/tests/backend/test_cupy_ops.py @@ -0,0 +1,66 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_cupy, to_numpy + + +pytestmark = pytest.mark.skipif(not has_cupy(), reason="cupy is not installed") + + +def _ctx(dtype=np.float64): + sc = importlib.import_module("spacecore") + return sc.Context(sc.CuPyOps(), dtype=dtype) + + +def test_cupy_ops_dense_creation_and_indexing(): + sc = importlib.import_module("spacecore") + ops = sc.CuPyOps() + x = ops.asarray([1.0, 2.0, 3.0], dtype=np.float64) + y = ops.index_set(x, 1, ops.asarray(5.0), copy=True) + z = ops.index_add(y, 0, ops.asarray(2.0), copy=True) + + assert ops.family == "cupy" + assert ops.is_dense(x) + np.testing.assert_allclose(to_numpy(x), [1.0, 2.0, 3.0]) + np.testing.assert_allclose(to_numpy(y), [1.0, 5.0, 3.0]) + np.testing.assert_allclose(to_numpy(z), [3.0, 5.0, 3.0]) + + +def test_cupy_sparse_conversion_and_matmul(): + ctx = _ctx() + dense = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + sparse = ctx.assparse(dense) + x = ctx.asarray([7.0, 8.0]) + + assert ctx.ops.is_sparse(sparse) + np.testing.assert_allclose(to_numpy(ctx.ops.sparse_matmul(sparse, x)), [23.0, 53.0, 83.0]) + assert ctx.ops.allclose_sparse(sparse, ctx.assparse(dense)) + + +def test_cupy_dense_linop_apply_and_rapply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(ctx.asarray(dense), dom, cod, ctx) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + + np.testing.assert_allclose(to_numpy(op.apply(x)), dense @ np.asarray([7.0, 8.0])) + np.testing.assert_allclose(to_numpy(op.rapply(y)), dense.T @ np.asarray([1.0, -1.0, 2.0])) + + +def test_cupy_sparse_linop_apply_and_to_dense(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(ctx.assparse(dense), dom, cod, ctx) + x = ctx.asarray([7.0, 8.0]) + + np.testing.assert_allclose(to_numpy(op.apply(x)), dense @ np.asarray([7.0, 8.0])) + np.testing.assert_allclose(to_numpy(op.to_dense()), dense) diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index 45b6f4a..99b28eb 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -2,7 +2,7 @@ import tomllib from pathlib import Path -from tests._helpers import has_jax, has_torch +from tests._helpers import has_cupy, has_jax, has_torch ROOT = Path(__file__).resolve().parents[2] @@ -31,6 +31,8 @@ def test_expected_names_are_exported(): } if has_jax(): expected |= {"JaxOps", "jax_pytree_class"} + if has_cupy(): + expected |= {"CuPyOps"} if has_torch(): expected |= {"TorchOps"} assert expected.issubset(set(sc.__all__)) @@ -45,6 +47,8 @@ def test_top_level_objects_match_source_modules(): assert sc.Context is backend.Context assert sc.NumpyOps is backend.NumpyOps + if has_cupy(): + assert sc.CuPyOps is backend.CuPyOps if has_torch(): assert sc.TorchOps is backend.TorchOps assert sc.Space is space.Space diff --git a/tests/test_backend_ops_complex.py b/tests/test_backend_ops_complex.py index c5b95e0..512b194 100644 --- a/tests/test_backend_ops_complex.py +++ b/tests/test_backend_ops_complex.py @@ -3,7 +3,8 @@ import numpy as np import pytest -from tests._helpers import has_jax, has_torch, jax_complex_dtype, to_numpy, torch_complex_dtype +from tests._helpers import has_cupy, has_jax, has_torch, jax_complex_dtype, to_numpy +from tests._helpers import torch_complex_dtype def _check_vdot_conjugates_first_argument(ops, dtype): @@ -31,3 +32,10 @@ def test_torch_vdot_conjugates_first_argument(): sc = importlib.import_module("spacecore") _check_vdot_conjugates_first_argument(sc.TorchOps(), torch_complex_dtype()) + + +@pytest.mark.skipif(not has_cupy(), reason="cupy is not installed") +def test_cupy_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.CuPyOps(), np.complex128) From 2778c5d7e68269e6caa17d119452ef5bd4f1bd55 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Wed, 20 May 2026 02:01:05 -0300 Subject: [PATCH 11/44] Add backend loop tests --- tests/backend/test_backend_loops.py | 119 ++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 tests/backend/test_backend_loops.py diff --git a/tests/backend/test_backend_loops.py b/tests/backend/test_backend_loops.py new file mode 100644 index 0000000..3b3be1a --- /dev/null +++ b/tests/backend/test_backend_loops.py @@ -0,0 +1,119 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_cupy, has_jax, has_torch, jax_real_dtype, to_numpy +from tests._helpers import torch_real_dtype + + +def _backend_params(): + return [ + pytest.param("numpy", np.float64, id="numpy"), + pytest.param( + "jax", + jax_real_dtype(), + marks=pytest.mark.skipif(not has_jax(), reason="jax is not installed"), + id="jax", + ), + pytest.param( + "torch", + torch_real_dtype(), + marks=pytest.mark.skipif(not has_torch(), reason="torch is not installed"), + id="torch", + ), + pytest.param( + "cupy", + np.float64, + marks=pytest.mark.skipif(not has_cupy(), reason="cupy is not installed"), + id="cupy", + ), + ] + + +def _ops_for_backend(name): + sc = importlib.import_module("spacecore") + if name == "numpy": + return sc.NumpyOps() + if name == "jax": + return sc.JaxOps() + if name == "torch": + return sc.TorchOps() + if name == "cupy": + return sc.CuPyOps() + raise ValueError(f"Unknown backend {name!r}.") + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_fori_loop_accumulates_indices(backend_name, dtype): + ops = _ops_for_backend(backend_name) + + def body_fun(i, carry): + return carry + ops.asarray(i, dtype=dtype) + + out = ops.fori_loop(0, 5, body_fun, ops.asarray(0.0, dtype=dtype)) + + np.testing.assert_allclose(to_numpy(out), 10.0) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_while_loop_reaches_terminal_state(backend_name, dtype): + ops = _ops_for_backend(backend_name) + limit = ops.asarray(4.0, dtype=dtype) + + def cond_fun(carry): + return carry < limit + + def body_fun(carry): + return carry + ops.asarray(1.0, dtype=dtype) + + out = ops.while_loop(cond_fun, body_fun, ops.asarray(0.0, dtype=dtype)) + + np.testing.assert_allclose(to_numpy(out), 4.0) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_scan_accumulates_and_stacks_outputs(backend_name, dtype): + ops = _ops_for_backend(backend_name) + xs = ops.asarray([1.0, 2.0, 3.0, 4.0], dtype=dtype) + + def body_fun(carry, x): + new_carry = carry + x + return new_carry, new_carry * ops.asarray(2.0, dtype=dtype) + + final, ys = ops.scan(body_fun, ops.asarray(0.0, dtype=dtype), xs) + + np.testing.assert_allclose(to_numpy(final), 10.0) + np.testing.assert_allclose(to_numpy(ys), [2.0, 6.0, 12.0, 20.0]) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_scan_without_xs_uses_explicit_length(backend_name, dtype): + ops = _ops_for_backend(backend_name) + + def body_fun(carry, _): + new_carry = carry + ops.asarray(2.0, dtype=dtype) + return new_carry, new_carry + + final, ys = ops.scan(body_fun, ops.asarray(1.0, dtype=dtype), None, length=3) + + np.testing.assert_allclose(to_numpy(final), 7.0) + np.testing.assert_allclose(to_numpy(ys), [3.0, 5.0, 7.0]) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_cond_selects_expected_branch(backend_name, dtype): + ops = _ops_for_backend(backend_name) + x = ops.asarray(3.0, dtype=dtype) + + def true_fun(value): + return value + ops.asarray(10.0, dtype=dtype) + + def false_fun(value): + return value - ops.asarray(10.0, dtype=dtype) + + true_out = ops.cond(True, true_fun, false_fun, x) + false_out = ops.cond(False, true_fun, false_fun, x) + + np.testing.assert_allclose(to_numpy(true_out), 13.0) + np.testing.assert_allclose(to_numpy(false_out), -7.0) From bf00af785fb5028870589df45712851452076150 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Thu, 21 May 2026 00:24:50 -0300 Subject: [PATCH 12/44] Add jittable linalg solvers --- spacecore/__init__.py | 17 ++ spacecore/linalg/__init__.py | 16 ++ spacecore/linalg/_cg.py | 85 ++++++++++ spacecore/linalg/_krylov.py | 16 ++ spacecore/linalg/_lanczos.py | 188 ++++++++++++++++++++++ spacecore/linalg/_lsqr.py | 130 ++++++++++++++++ spacecore/linalg/_power.py | 85 ++++++++++ spacecore/linalg/_utils.py | 84 ++++++++++ tests/linalg/__init__.py | 1 + tests/linalg/test_krylov.py | 293 +++++++++++++++++++++++++++++++++++ 10 files changed, 915 insertions(+) create mode 100644 spacecore/linalg/__init__.py create mode 100644 spacecore/linalg/_cg.py create mode 100644 spacecore/linalg/_krylov.py create mode 100644 spacecore/linalg/_lanczos.py create mode 100644 spacecore/linalg/_lsqr.py create mode 100644 spacecore/linalg/_power.py create mode 100644 spacecore/linalg/_utils.py create mode 100644 tests/linalg/__init__.py create mode 100644 tests/linalg/test_krylov.py diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 6d3fbf0..9e6384f 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -27,6 +27,15 @@ make_scaled, make_sum, ) +from .linalg import ( + CGResult, + LSQRResult, + PowerIterationResult, + cg, + lsqr, + power_iteration, + stochastic_lanczos, +) from .space import ( BackendCheck, DTypeCheck, @@ -77,6 +86,14 @@ "SumToSingleLinOp", "StackedLinOp", + "CGResult", + "LSQRResult", + "PowerIterationResult", + "cg", + "lsqr", + "power_iteration", + "stochastic_lanczos", + "BackendCheck", "DTypeCheck", "HermitianCheck", diff --git a/spacecore/linalg/__init__.py b/spacecore/linalg/__init__.py new file mode 100644 index 0000000..147ffb2 --- /dev/null +++ b/spacecore/linalg/__init__.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from ._cg import CGResult, cg +from ._lanczos import stochastic_lanczos +from ._lsqr import LSQRResult, lsqr +from ._power import PowerIterationResult, power_iteration + +__all__ = [ + "CGResult", + "LSQRResult", + "PowerIterationResult", + "cg", + "lsqr", + "power_iteration", + "stochastic_lanczos", +] diff --git a/spacecore/linalg/_cg.py b/spacecore/linalg/_cg.py new file mode 100644 index 0000000..5e9569b --- /dev/null +++ b/spacecore/linalg/_cg.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from typing import Any, NamedTuple + +from ..linop import LinOp +from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter +from ._utils import is_converged, real_inner, require_linop, require_square +from ._utils import safe_inverse, should_check_iteration, threshold + + +class CGResult(NamedTuple): + """Result returned by :func:`cg`.""" + + x: Any + converged: Any + num_iters: Any + residual_norm: Any + + +def cg( + A: LinOp, + b: Any, + *, + x0: Any | None = None, + tol: float = 1e-6, + atol: float = 0.0, + maxiter: int | None = None, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> CGResult: + """ + Solve ``A x = b`` by conjugate gradients. + + ``A`` must be a square symmetric/Hermitian positive-definite ``LinOp``. + The implementation uses only ``A.apply`` and the domain-space inner product; + it never materializes a dense matrix. The residual norm is compared with + ``atol + tol * ||b||`` only every ``check_every`` iterations, and always on + the final iteration. This avoids checking the stopping criterion on every + step while remaining compatible with JAX JIT control flow. + """ + A = require_linop(A) + require_square(A, "cg") + A.codomain.check_member(b) + maxiter = check_maxiter(maxiter, A) + check_every = check_interval(check_every) + + x = A.domain.zeros() if x0 is None else x0 + A.domain.check_member(x) + r = A.codomain.add(b, A.codomain.scale(-1.0, A.apply(x))) + p = r + rs = real_inner(A.domain, r, r) + residual_norm = A.domain.norm(r) + threshold_value = threshold(A.codomain.norm(b), tol, atol) + eps = A.ops.asarray(A.ops.eps(A.dtype), dtype=A.dtype) + + def cond_fun(carry: tuple[Any, Any, Any, Any, Any, int]) -> Any: + _x, _r, _p, _rs, res_norm, k = carry + return (k < maxiter) & (res_norm > threshold_value) + + def body_fun(carry: tuple[Any, Any, Any, Any, Any, int]) -> tuple[Any, Any, Any, Any, Any, int]: + x, r, p, rs, _residual_norm, k = carry + Ap = A.apply(p) + pAp = real_inner(A.domain, p, Ap) + active = (rs > eps) & (pAp > eps) + alpha = A.ops.where(active, rs * safe_inverse(A.ops, pAp), A.ops.zeros_like(rs)) + x_next = A.domain.axpy(alpha, p, x) + r_next = A.codomain.axpy(-alpha, Ap, r) + rs_next = real_inner(A.domain, r_next, r_next) + beta = A.ops.where(active, rs_next * safe_inverse(A.ops, rs), A.ops.zeros_like(rs_next)) + p_next = A.domain.axpy(beta, p, r_next) + k_next = k + 1 + current_residual_norm = A.domain.norm(r_next) + residual_norm_next = A.ops.cond( + should_check_iteration(k_next, maxiter, check_every), + lambda _: current_residual_norm, + lambda _: _residual_norm, + A.ops.asarray(0.0, dtype=A.dtype), + ) + return x_next, r_next, p_next, rs_next, residual_norm_next, k_next + + x, _r, _p, _rs, residual_norm, num_iters = A.ops.while_loop( + cond_fun, + body_fun, + (x, r, p, rs, residual_norm, 0), + ) + return CGResult(x, is_converged(residual_norm, threshold_value), num_iters, residual_norm) diff --git a/spacecore/linalg/_krylov.py b/spacecore/linalg/_krylov.py new file mode 100644 index 0000000..147ffb2 --- /dev/null +++ b/spacecore/linalg/_krylov.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from ._cg import CGResult, cg +from ._lanczos import stochastic_lanczos +from ._lsqr import LSQRResult, lsqr +from ._power import PowerIterationResult, power_iteration + +__all__ = [ + "CGResult", + "LSQRResult", + "PowerIterationResult", + "cg", + "lsqr", + "power_iteration", + "stochastic_lanczos", +] diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py new file mode 100644 index 0000000..acf02aa --- /dev/null +++ b/spacecore/linalg/_lanczos.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from typing import Any + +from ..linop import LinOp +from ..types import DenseArray +from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval +from ._utils import require_linop, require_square, should_check_iteration + + +def _check_lanczos_max_iter(max_iter: int) -> int: + max_iter = int(max_iter) + if max_iter < 1: + raise ValueError("max_iter must be positive.") + return max_iter + + +def stochastic_lanczos( + A: LinOp, + initial_vector: Any, + *, + max_iter: int = 100, + tol: float = 1e-6, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> tuple[DenseArray, Any]: + r"""Approximate the smallest eigenpair of a Hermitian operator. + + The operator is supplied as a square ``LinOp`` and ``initial_vector`` is an + element of ``A.domain``. The implementation keeps fixed-size coordinate + arrays for JAX compatibility, safely handles zero initial vectors, and + refines the returned eigenvalue with the Rayleigh quotient of the + reconstructed Ritz vector in the original space. + + Mathematically, Lanczos builds an orthonormal Krylov basis ``V`` for + ``span{v, T v, T^2 v, ...}`` and a tridiagonal projection + :math:`T_k = V^\dagger T V`. The returned vector is the Ritz vector + reconstructed in the original coordinates, and the returned scalar is the + Rayleigh quotient + :math:`(x^\dagger T x) / (x^\dagger x)`. + + Args: + A: Square Hermitian linear operator. + initial_vector: Starting vector in ``A.domain``. + max_iter: Maximum number of Lanczos steps. + tol: Breakdown tolerance for the off-diagonal Lanczos coefficient. + check_every: Refresh the breakdown-based stopping decision only every + this many iterations, and always on the final iteration. + + Returns: + A pair ``(eigenvalue, eigenvector)`` for the smallest approximated + eigenpair. + """ + A = require_linop(A) + require_square(A, "stochastic_lanczos") + max_iter = _check_lanczos_max_iter(max_iter) + check_every = check_interval(check_every) + A.domain.check_member(initial_vector) + ops = A.ops + ctx = A.ctx + + v0 = A.domain.flatten(initial_vector) + v0 = ctx.assert_dense(v0) + n = v0.shape[0] + + V = ops.zeros((max_iter + 1, n), dtype=ctx.dtype) + alphas = ops.zeros((max_iter,), dtype=ctx.dtype) + betas = ops.zeros((max_iter + 1,), dtype=ctx.dtype) + + tol_s = ctx.asarray(tol) + eps_s = ctx.asarray(1e-12) + + v0_norm = ops.sqrt(ops.real(ops.vdot(v0, v0))) + + e0 = ops.zeros((n,), dtype=ctx.dtype) + e0 = ops.index_set(e0, (0,), ctx.asarray(1.0), copy=True) + + v0_unit = ops.cond( + v0_norm > eps_s, + lambda _: v0 / v0_norm, + lambda _: e0, + ctx.asarray(0.0), + ) + V = ops.index_set(V, (0, slice(None)), v0_unit, copy=True) + + beta0 = ctx.asarray(1.0) + i0 = 0 + keep_going0 = ops.asarray(True) + + full_indices = ops.arange(max_iter + 1) + idx = ops.arange(max_iter) + + def cond_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> Any: + i, _V, _alphas, _betas, _beta, keep_going = state + return (i < max_iter) & keep_going + + def body_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> tuple[Any, Any, Any, Any, Any, Any]: + i, V_, alphas_, betas_, beta, keep_going = state + + v_i = V_[i] + w = A.codomain.flatten(A.apply(A.domain.unflatten(v_i))) + w = ctx.assert_dense(w) + + alpha = ops.real(ops.vdot(v_i, w)) + alphas_ = ops.index_set(alphas_, (i,), alpha, copy=True) + + w = ops.cond( + i == 0, + lambda w_in: w_in - alpha * v_i, + lambda w_in: w_in - alpha * v_i - betas_[i] * V_[i - 1], + w, + ) + + valid = full_indices < (i + 1) + mask = ops.where(valid, ctx.asarray(1.0), ctx.asarray(0.0)) + mask = ops.astype(mask, w.dtype) + + coeffs_full = ops.einsum("jn,n->j", ops.conj(V_), w) + coeffs_valid = coeffs_full * mask + proj = ops.sum(coeffs_valid[:, None] * V_, axis=0) + w = w - proj + + beta_new = ops.sqrt(ops.real(ops.vdot(w, w))) + betas_ = ops.index_set(betas_, (i + 1,), beta_new, copy=True) + + def set_next(V_in: DenseArray) -> DenseArray: + return ops.index_set(V_in, (i + 1, slice(None)), w / beta_new, copy=True) + + V_ = ops.cond(beta_new >= tol_s, set_next, lambda V_in: V_in, V_) + i_next = i + 1 + keep_going_next = ops.cond( + should_check_iteration(i_next, max_iter, check_every), + lambda _: beta_new >= tol_s, + lambda _: keep_going, + ctx.asarray(0.0), + ) + + return i_next, V_, alphas_, betas_, beta_new, keep_going_next + + i_final, V, alphas, betas, _beta_final, _keep_going = ops.while_loop( + cond_fun, body_fun, (i0, V, alphas, betas, beta0, keep_going0) + ) + m = i_final + + mask_alpha = idx < m + alphas_full = ops.where(mask_alpha, alphas, ctx.asarray(1e10)) + betas_full = ops.where(full_indices == m, ctx.asarray(0.0), betas) + + T = ops.zeros((max_iter, max_iter), dtype=ctx.dtype) + + def fill_diag(ii: int, T_in: DenseArray) -> DenseArray: + return ops.index_set(T_in, (ii, ii), alphas_full[ii], copy=True) + + T = ops.fori_loop(0, max_iter, fill_diag, T) + + def fill_off(ii: int, T_in: DenseArray) -> DenseArray: + b = betas_full[ii + 1] + T_in = ops.index_set(T_in, (ii, ii + 1), b, copy=True) + T_in = ops.index_set(T_in, (ii + 1, ii), b, copy=True) + return T_in + + T = ops.fori_loop(0, max_iter - 1, fill_off, T) + + _eigvals, eigvecs = ops.eigh(T) + y_full = eigvecs[:, 0] + + mask_y = ops.where(idx < m, ctx.asarray(1.0), ctx.asarray(0.0)) + mask_y = ops.astype(mask_y, y_full.dtype) + y_valid = y_full * mask_y + + V_reduced = V[:max_iter, :] + x_flat = ops.einsum("j,jn->n", y_valid, V_reduced) + + x_norm = ops.sqrt(ops.real(ops.vdot(x_flat, x_flat))) + x_flat = ops.cond( + x_norm > eps_s, + lambda _: x_flat / x_norm, + lambda _: e0, + ctx.asarray(0.0), + ) + + x = A.domain.unflatten(x_flat) + Ax = A.apply(x) + + num = ops.real(A.domain.inner(x, Ax)) + den = ops.real(A.domain.inner(x, x)) + lam = num / den + + return lam, x diff --git a/spacecore/linalg/_lsqr.py b/spacecore/linalg/_lsqr.py new file mode 100644 index 0000000..95a3921 --- /dev/null +++ b/spacecore/linalg/_lsqr.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from typing import Any, NamedTuple + +from ..linop import LinOp +from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter +from ._utils import is_converged, require_linop, safe_inverse, should_check_iteration +from ._utils import threshold + + +class LSQRResult(NamedTuple): + """Result returned by :func:`lsqr`.""" + + x: Any + converged: Any + num_iters: Any + residual_norm: Any + normal_residual_norm: Any + + +def lsqr( + A: LinOp, + b: Any, + *, + x0: Any | None = None, + tol: float = 1e-6, + atol: float = 0.0, + maxiter: int | None = None, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> LSQRResult: + """ + Solve ``min_x ||A x - b||_2`` by the LSQR Krylov iteration. + + The operator may be rectangular or square. The method uses ``A.apply`` for + forward products and ``A.H.apply`` for adjoint products, so the normal + equations are represented implicitly and no dense matrix is formed. + Convergence is tested against ``atol + tol * ||b||`` using + ``||A.H @ (A x - b)||``. That normal-equation residual is refreshed only + every ``check_every`` iterations, and always on the final iteration, so the + expensive stopping diagnostic is not evaluated on every Krylov step. + """ + A = require_linop(A) + A.codomain.check_member(b) + maxiter = check_maxiter(maxiter, A) + check_every = check_interval(check_every) + + x = A.domain.zeros() if x0 is None else x0 + A.domain.check_member(x) + residual = A.codomain.add(b, A.codomain.scale(-1.0, A.apply(x))) + beta = A.codomain.norm(residual) + normal_residual_norm = A.domain.norm(A.H.apply(residual)) + u = residual + u = A.codomain.scale(safe_inverse(A.ops, beta), u) + v = A.H.apply(u) + alpha = A.domain.norm(v) + v = A.domain.scale(safe_inverse(A.ops, alpha), v) + w = v + phi_bar = beta + rho_bar = alpha + residual_norm = beta + threshold_value = threshold(A.codomain.norm(b), tol, atol) + + def cond_fun(carry: tuple[Any, ...]) -> Any: + _x, _u, _v, _w, _alpha, _beta, _rho_bar, _phi_bar, _res_norm, norm_res, k = carry + return (k < maxiter) & (norm_res > threshold_value) + + def body_fun(carry: tuple[Any, ...]) -> tuple[Any, ...]: + x, u, v, w, alpha, _beta, rho_bar, phi_bar, _residual_norm, _normal_residual, k = carry + u_next = A.codomain.axpy(-alpha, u, A.apply(v)) + beta_next = A.codomain.norm(u_next) + u_next = A.codomain.scale(safe_inverse(A.ops, beta_next), u_next) + + v_next = A.domain.axpy(-beta_next, v, A.H.apply(u_next)) + alpha_next = A.domain.norm(v_next) + v_next = A.domain.scale(safe_inverse(A.ops, alpha_next), v_next) + + rho = A.ops.sqrt(rho_bar * rho_bar + beta_next * beta_next) + inv_rho = safe_inverse(A.ops, rho) + c = rho_bar * inv_rho + s = beta_next * inv_rho + theta = s * alpha_next + rho_bar_next = -c * alpha_next + phi = c * phi_bar + phi_bar_next = s * phi_bar + + x_next = A.domain.axpy(phi * inv_rho, w, x) + w_next = A.domain.axpy(-(theta * inv_rho), w, v_next) + k_next = k + 1 + + def refresh_residuals(payload: tuple[Any, Any, Any]) -> tuple[Any, Any]: + x_candidate, _old_residual_norm, _old_normal_residual = payload + residual_next = A.codomain.add(A.apply(x_candidate), A.codomain.scale(-1.0, b)) + return A.codomain.norm(residual_next), A.domain.norm(A.H.apply(residual_next)) + + def keep_residuals(payload: tuple[Any, Any, Any]) -> tuple[Any, Any]: + _x_candidate, old_residual_norm, old_normal_residual = payload + return old_residual_norm, old_normal_residual + + residual_norm_next, normal_residual_norm_next = A.ops.cond( + should_check_iteration(k_next, maxiter, check_every), + refresh_residuals, + keep_residuals, + (x_next, _residual_norm, _normal_residual), + ) + return ( + x_next, + u_next, + v_next, + w_next, + alpha_next, + beta_next, + rho_bar_next, + phi_bar_next, + residual_norm_next, + normal_residual_norm_next, + k_next, + ) + + x, *_rest, residual_norm, normal_residual_norm, num_iters = A.ops.while_loop( + cond_fun, + body_fun, + (x, u, v, w, alpha, beta, rho_bar, phi_bar, residual_norm, normal_residual_norm, 0), + ) + return LSQRResult( + x, + is_converged(normal_residual_norm, threshold_value), + num_iters, + residual_norm, + normal_residual_norm, + ) diff --git a/spacecore/linalg/_power.py b/spacecore/linalg/_power.py new file mode 100644 index 0000000..8a64c12 --- /dev/null +++ b/spacecore/linalg/_power.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from typing import Any, NamedTuple + +from ..linop import LinOp +from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter +from ._utils import default_initial_vector, is_converged, normalize, require_linop +from ._utils import require_square, should_check_iteration + + +class PowerIterationResult(NamedTuple): + """Result returned by :func:`power_iteration`.""" + + eigenvalue: Any + eigenvector: Any + converged: Any + num_iters: Any + residual_norm: Any + + +def power_iteration( + A: LinOp, + *, + x0: Any | None = None, + tol: float = 1e-6, + maxiter: int | None = None, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> PowerIterationResult: + """ + Estimate the dominant eigenpair of a square ``LinOp`` by power iteration. + + The method uses only ``A.apply`` and domain-space operations. It returns the + Rayleigh quotient for the current normalized iterate, the eigenvector + estimate, and the residual norm ``||A x - lambda x||``. The residual-based + stopping criterion is refreshed only every ``check_every`` iterations, and + always on the final iteration. For spectral-norm estimates of a rectangular + operator, call this on ``A.H @ A``. + """ + A = require_linop(A) + require_square(A, "power_iteration") + maxiter = check_maxiter(maxiter, A) + check_every = check_interval(check_every) + + x = default_initial_vector(A) if x0 is None else x0 + A.domain.check_member(x) + x, _ = normalize(A.domain, x) + zero = A.ops.asarray(0.0, dtype=A.dtype) + residual_norm = A.domain.norm(x) + float("inf") + + def cond_fun(carry: tuple[Any, Any, Any, int]) -> Any: + _eigenvalue, _x, res_norm, k = carry + return (k < maxiter) & (res_norm > tol) + + def body_fun(carry: tuple[Any, Any, Any, int]) -> tuple[Any, Any, Any, int]: + _eigenvalue, x, _residual_norm, k = carry + y = A.apply(x) + x_next, _norm_y = normalize(A.domain, y) + y_next = A.apply(x_next) + eigenvalue_next = A.domain.inner(x_next, y_next) + k_next = k + 1 + + def refresh_residual(_: Any) -> Any: + residual = A.codomain.axpy(-eigenvalue_next, x_next, y_next) + return A.codomain.norm(residual) + + residual_norm_next = A.ops.cond( + should_check_iteration(k_next, maxiter, check_every), + refresh_residual, + lambda _: _residual_norm, + A.ops.asarray(0.0, dtype=A.dtype), + ) + return eigenvalue_next, x_next, residual_norm_next, k_next + + eigenvalue, eigenvector, residual_norm, num_iters = A.ops.while_loop( + cond_fun, + body_fun, + (zero, x, residual_norm, 0), + ) + return PowerIterationResult( + eigenvalue, + eigenvector, + is_converged(residual_norm, tol), + num_iters, + residual_norm, + ) diff --git a/spacecore/linalg/_utils.py b/spacecore/linalg/_utils.py new file mode 100644 index 0000000..773f0fd --- /dev/null +++ b/spacecore/linalg/_utils.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from math import prod +from typing import Any + +from ..linop import LinOp + +DEFAULT_CONVERGENCE_CHECK_INTERVAL = 64 + + +def require_linop(A: Any) -> LinOp: + """Return ``A`` as a ``LinOp`` or raise a clear type error.""" + if not isinstance(A, LinOp): + raise TypeError(f"A must be a LinOp, got {type(A).__name__}.") + return A + + +def require_square(A: LinOp, name: str) -> None: + """Raise if ``A`` is not a square operator.""" + if A.domain != A.codomain: + raise ValueError(f"{name} requires a square LinOp; got {A.domain!r} -> {A.codomain!r}.") + + +def default_maxiter(A: LinOp) -> int: + """Return the default Krylov iteration count for ``A``.""" + return max(1, prod(A.domain.shape)) + + +def check_maxiter(maxiter: int | None, A: LinOp) -> int: + """Validate an optional iteration count.""" + if maxiter is None: + return default_maxiter(A) + maxiter = int(maxiter) + if maxiter < 0: + raise ValueError("maxiter must be nonnegative.") + return maxiter + + +def check_interval(interval: int) -> int: + """Validate a convergence-check interval.""" + interval = int(interval) + if interval < 1: + raise ValueError("check_every must be positive.") + return interval + + +def should_check_iteration(k: Any, maxiter: int, interval: int) -> Any: + """Return whether iteration ``k`` should refresh convergence diagnostics.""" + return (k >= maxiter) | ((k % interval) == 0) + + +def threshold(norm_b: Any, tol: float, atol: float) -> Any: + """Return the absolute-plus-relative convergence threshold.""" + return max(float(atol), 0.0) + max(float(tol), 0.0) * norm_b + + +def real_inner(space: Any, x: Any, y: Any) -> Any: + """Return the real part of ``space.inner(x, y)``.""" + return space.ops.real(space.inner(x, y)) + + +def is_converged(residual_norm: Any, threshold_value: Any) -> Any: + """Return backend-compatible convergence predicate.""" + return residual_norm <= threshold_value + + +def safe_inverse(ops: Any, value: Any) -> Any: + """Return ``1 / value`` where positive and zero otherwise.""" + positive = value > 0 + safe_value = ops.where(positive, value, ops.ones_like(value)) + return ops.where(positive, 1.0 / safe_value, ops.zeros_like(value)) + + +def normalize(space: Any, x: Any) -> tuple[Any, Any]: + """Normalize a space member and return ``(unit, norm)``.""" + norm = space.norm(x) + return space.scale(safe_inverse(space.ops, norm), x), norm + + +def default_initial_vector(A: LinOp) -> Any: + """Return a deterministic nonzero initial vector for ``A.domain``.""" + size = prod(A.domain.shape) + flat = A.ops.ones((size,), dtype=A.dtype) / A.ops.sqrt(A.ops.asarray(float(size))) + return A.domain.unflatten(flat) diff --git a/tests/linalg/__init__.py b/tests/linalg/__init__.py new file mode 100644 index 0000000..cfe7516 --- /dev/null +++ b/tests/linalg/__init__.py @@ -0,0 +1 @@ +"""Tests for SpaceCore linear algebra routines.""" diff --git a/tests/linalg/test_krylov.py b/tests/linalg/test_krylov.py new file mode 100644 index 0000000..25958f4 --- /dev/null +++ b/tests/linalg/test_krylov.py @@ -0,0 +1,293 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_cupy, has_jax, has_torch, jax_real_dtype, to_numpy +from tests._helpers import torch_real_dtype + + +def _backend_params(): + return [ + pytest.param("numpy", np.float64, id="numpy"), + pytest.param( + "jax", + jax_real_dtype(), + marks=pytest.mark.skipif(not has_jax(), reason="jax is not installed"), + id="jax", + ), + pytest.param( + "torch", + torch_real_dtype(), + marks=pytest.mark.skipif(not has_torch(), reason="torch is not installed"), + id="torch", + ), + pytest.param( + "cupy", + np.float64, + marks=pytest.mark.skipif(not has_cupy(), reason="cupy is not installed"), + id="cupy", + ), + ] + + +def _ops_for_backend(name): + sc = importlib.import_module("spacecore") + if name == "numpy": + return sc.NumpyOps() + if name == "jax": + return sc.JaxOps() + if name == "torch": + return sc.TorchOps() + if name == "cupy": + return sc.CuPyOps() + raise ValueError(f"Unknown backend {name!r}.") + + +def _ctx(backend_name="numpy", dtype=np.float64): + sc = importlib.import_module("spacecore") + return sc.Context(_ops_for_backend(backend_name), dtype=dtype, enable_checks=False) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_cg_solves_spd_system(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + space = sc.VectorSpace((2,), ctx) + A = sc.DenseLinOp(ctx.asarray([[4.0, 1.0], [1.0, 3.0]]), space, space, ctx) + b = ctx.asarray([1.0, 2.0]) + + result = sc.cg(A, b, tol=1e-7, maxiter=10) + + np.testing.assert_allclose( + to_numpy(result.x), + np.linalg.solve(np.array([[4.0, 1.0], [1.0, 3.0]]), np.array([1.0, 2.0])), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose(to_numpy(A.apply(result.x)), to_numpy(b), rtol=1e-5, atol=1e-5) + assert bool(to_numpy(result.converged)) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_lsqr_solves_rectangular_least_squares(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + matrix = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + A = sc.DenseLinOp(ctx.asarray(matrix), domain, codomain, ctx) + b = ctx.asarray([1.0, 2.0, 4.0]) + + result = sc.lsqr(A, b, tol=1e-7, maxiter=10) + + expected, *_ = np.linalg.lstsq(matrix, np.array([1.0, 2.0, 4.0]), rcond=None) + np.testing.assert_allclose(to_numpy(result.x), expected, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(to_numpy(A.H.apply(A.apply(result.x) - b)), [0.0, 0.0], atol=1e-5) + assert bool(to_numpy(result.converged)) + + +def test_lsqr_works_with_matrix_free_linop_and_uses_rapply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + matrix = ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + calls = {"rapply": 0} + + def apply(x): + return matrix @ x + + def rapply(y): + calls["rapply"] += 1 + return matrix.T @ y + + A = sc.MatrixFreeLinOp(apply, rapply, domain, codomain, ctx) + b = ctx.asarray([1.0, 2.0, 3.0]) + + result = sc.lsqr(A, b, tol=1e-8, maxiter=10) + + np.testing.assert_allclose(result.x, [1.0, 2.0], rtol=1e-6, atol=1e-6) + assert calls["rapply"] > 0 + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_power_iteration_estimates_dominant_eigenpair(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + space = sc.VectorSpace((2,), ctx) + A = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + x0 = ctx.asarray([1.0, 1.0]) + + result = sc.power_iteration(A, x0=x0, tol=1e-5, maxiter=60) + + np.testing.assert_allclose(to_numpy(result.eigenvalue), 5.0, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + np.abs(to_numpy(result.eigenvector)), + [0.0, 1.0], + rtol=1e-4, + atol=1e-4, + ) + assert bool(to_numpy(result.converged)) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_stochastic_lanczos_approximates_smallest_eigenpair(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + initial = ctx.asarray([1.0, 1.0]) + + eigenvalue, eigenvector = sc.stochastic_lanczos( + op, + initial, + max_iter=2, + tol=1e-8, + ) + + np.testing.assert_allclose(to_numpy(eigenvalue), 2.0, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + np.abs(to_numpy(eigenvector)), + [1.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + + +def test_stochastic_lanczos_uses_e0_for_zero_initial_vector(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + initial = ctx.asarray([0.0, 0.0]) + + eigenvalue, eigenvector = sc.stochastic_lanczos( + op, + initial, + max_iter=2, + tol=1e-8, + ) + + np.testing.assert_allclose(eigenvalue, 2.0, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(eigenvector, [1.0, 0.0], rtol=1e-6, atol=1e-6) + + +def test_stochastic_lanczos_rejects_invalid_max_iter(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((1,), ctx) + op = sc.IdentityLinOp(space, ctx) + + with pytest.raises(ValueError, match="max_iter"): + sc.stochastic_lanczos(op, ctx.asarray([1.0]), max_iter=0) + + +def test_iterative_solvers_poll_convergence_on_check_interval(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + spd = sc.DenseLinOp(ctx.asarray([[4.0, 1.0], [1.0, 3.0]]), space, space, ctx) + rectangular = sc.DenseLinOp( + ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), + space, + sc.VectorSpace((3,), ctx), + ctx, + ) + diagonal = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + + cg_result = sc.cg(spd, ctx.asarray([1.0, 2.0]), maxiter=65) + lsqr_result = sc.lsqr(rectangular, ctx.asarray([1.0, 2.0, 3.0]), maxiter=65) + power_result = sc.power_iteration(diagonal, x0=ctx.asarray([1.0, 1.0]), maxiter=65) + + assert cg_result.num_iters == 64 + assert lsqr_result.num_iters == 64 + assert power_result.num_iters == 64 + np.testing.assert_allclose(cg_result.residual_norm, 0.0, atol=1e-12) + np.testing.assert_allclose(lsqr_result.normal_residual_norm, 0.0, atol=1e-12) + np.testing.assert_allclose(power_result.residual_norm, 0.0, atol=1e-12) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_cg_jit_compiles_with_operator_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[4.0, 1.0], [1.0, 3.0]]), space, space, ctx) + + solve = jax.jit(lambda A, b: sc.cg(A, b, maxiter=10).x) + x = solve(op, ctx.asarray([1.0, 2.0])) + + np.testing.assert_allclose(to_numpy(x), [0.09090909, 0.63636364], rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_lsqr_jit_compiles_with_operator_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + op = sc.DenseLinOp(ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), domain, codomain, ctx) + + solve = jax.jit(lambda A, b: sc.lsqr(A, b, maxiter=10).x) + x = solve(op, ctx.asarray([1.0, 2.0, 4.0])) + + np.testing.assert_allclose(to_numpy(x), [1.33333333, 2.33333333], rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_power_iteration_jit_compiles_with_operator_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + + run = jax.jit(lambda A, x: sc.power_iteration(A, x0=x, maxiter=60).eigenvalue) + eigenvalue = run(op, ctx.asarray([1.0, 1.0])) + + np.testing.assert_allclose(to_numpy(eigenvalue), 5.0, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_stochastic_lanczos_jit_compiles_with_operator_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + + def run(A, initial): + return sc.stochastic_lanczos( + A, + initial, + max_iter=2, + tol=1e-8, + ) + + eigenvalue, eigenvector = jax.jit(run)(op, ctx.asarray([1.0, 1.0])) + + np.testing.assert_allclose(to_numpy(eigenvalue), 2.0, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + np.abs(to_numpy(eigenvector)), + [1.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + + +def test_cg_and_power_iteration_reject_rectangular_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), domain, codomain, ctx) + + with pytest.raises(ValueError, match="square LinOp"): + sc.cg(A, ctx.asarray([1.0, 2.0, 3.0])) + with pytest.raises(ValueError, match="square LinOp"): + sc.power_iteration(A) + with pytest.raises(ValueError, match="square LinOp"): + sc.stochastic_lanczos(A, ctx.asarray([1.0, 2.0])) From d3ab7dde9131b3ce35c8f3130da6cfcdd9cc1200 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Thu, 21 May 2026 20:06:38 -0300 Subject: [PATCH 13/44] Improve linalg guide and linop representations --- spacecore/linalg/_cg.py | 14 +- spacecore/linalg/_lsqr.py | 15 +- spacecore/linalg/_power.py | 15 +- spacecore/linalg/_utils.py | 30 ++ spacecore/linop/_base.py | 16 + spacecore/linop/_dense.py | 12 +- spacecore/linop/_sparse.py | 13 +- tests/linops/test_to_dense.py | 73 +++++ tutorials/8_Linalg_MatrixFree.ipynb | 473 ++++++++++++++++++++++++++++ 9 files changed, 656 insertions(+), 5 deletions(-) create mode 100644 tutorials/8_Linalg_MatrixFree.ipynb diff --git a/spacecore/linalg/_cg.py b/spacecore/linalg/_cg.py index 5e9569b..775b46c 100644 --- a/spacecore/linalg/_cg.py +++ b/spacecore/linalg/_cg.py @@ -5,7 +5,7 @@ from ..linop import LinOp from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter from ._utils import is_converged, real_inner, require_linop, require_square -from ._utils import safe_inverse, should_check_iteration, threshold +from ._utils import result_repr, safe_inverse, should_check_iteration, threshold class CGResult(NamedTuple): @@ -16,6 +16,18 @@ class CGResult(NamedTuple): num_iters: Any residual_norm: Any + def __repr__(self) -> str: + """Return a compact summary without printing the full solution array.""" + return result_repr( + "CGResult", + { + "converged": self.converged, + "num_iters": self.num_iters, + "residual_norm": self.residual_norm, + "x": self.x, + }, + ) + def cg( A: LinOp, diff --git a/spacecore/linalg/_lsqr.py b/spacecore/linalg/_lsqr.py index 95a3921..8494df0 100644 --- a/spacecore/linalg/_lsqr.py +++ b/spacecore/linalg/_lsqr.py @@ -5,7 +5,7 @@ from ..linop import LinOp from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter from ._utils import is_converged, require_linop, safe_inverse, should_check_iteration -from ._utils import threshold +from ._utils import result_repr, threshold class LSQRResult(NamedTuple): @@ -17,6 +17,19 @@ class LSQRResult(NamedTuple): residual_norm: Any normal_residual_norm: Any + def __repr__(self) -> str: + """Return a compact summary without printing the full solution array.""" + return result_repr( + "LSQRResult", + { + "converged": self.converged, + "num_iters": self.num_iters, + "residual_norm": self.residual_norm, + "normal_residual_norm": self.normal_residual_norm, + "x": self.x, + }, + ) + def lsqr( A: LinOp, diff --git a/spacecore/linalg/_power.py b/spacecore/linalg/_power.py index 8a64c12..5358b85 100644 --- a/spacecore/linalg/_power.py +++ b/spacecore/linalg/_power.py @@ -5,7 +5,7 @@ from ..linop import LinOp from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter from ._utils import default_initial_vector, is_converged, normalize, require_linop -from ._utils import require_square, should_check_iteration +from ._utils import require_square, result_repr, should_check_iteration class PowerIterationResult(NamedTuple): @@ -17,6 +17,19 @@ class PowerIterationResult(NamedTuple): num_iters: Any residual_norm: Any + def __repr__(self) -> str: + """Return a compact summary without printing the full eigenvector.""" + return result_repr( + "PowerIterationResult", + { + "converged": self.converged, + "num_iters": self.num_iters, + "eigenvalue": self.eigenvalue, + "residual_norm": self.residual_norm, + "eigenvector": self.eigenvector, + }, + ) + def power_iteration( A: LinOp, diff --git a/spacecore/linalg/_utils.py b/spacecore/linalg/_utils.py index 773f0fd..d9af97e 100644 --- a/spacecore/linalg/_utils.py +++ b/spacecore/linalg/_utils.py @@ -82,3 +82,33 @@ def default_initial_vector(A: LinOp) -> Any: size = prod(A.domain.shape) flat = A.ops.ones((size,), dtype=A.dtype) / A.ops.sqrt(A.ops.asarray(float(size))) return A.domain.unflatten(flat) + + +def summarize_value(value: Any) -> str: + """Return a compact representation for arrays, scalars, and pytrees.""" + shape = getattr(value, "shape", None) + dtype = getattr(value, "dtype", None) + if shape is not None: + shape_text = tuple(shape) + if shape_text == (): + dtype_text = str(dtype) + if dtype_text in {"bool", "bool_", "torch.bool"}: + try: + return repr(bool(value)) + except Exception: + return repr(value) + try: + return f"{float(value):.6g}" + except Exception: + return repr(value) + dtype_text = "" if dtype is None else f", dtype={dtype}" + return f"" + if isinstance(value, tuple): + return "(" + ", ".join(summarize_value(part) for part in value) + ")" + return repr(value) + + +def result_repr(name: str, fields: dict[str, Any]) -> str: + """Return a compact result-object representation.""" + body = ", ".join(f"{key}={summarize_value(value)}" for key, value in fields.items()) + return f"{name}({body})" diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index f006e9a..af9c05f 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -43,6 +43,22 @@ def codomain(self) -> Codomain: """Codomain space of this linear operator.""" return self.cod + @property + def A(self) -> Any: + """ + Native numerical representation of this operator. + + Concrete subclasses may choose the representation that best matches + their storage model: for example, dense operators return a dense array + while sparse operators return their sparse matrix. Matrix-free or lazy + operators generally do not have such a representation and should leave + this property unimplemented. Use :meth:`to_dense` when a dense tensor + materialization is explicitly required. + """ + raise NotImplementedError( + f"{type(self).__name__} does not define a native numerical representation." + ) + @abstractmethod def apply(self, x: Any) -> Any: """ diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index 9b36a09..87f38d5 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -40,7 +40,7 @@ def __init__(self, if tuple(A.shape) != expected: raise TypeError(f"Expected A.shape == cod.shape + dom.shape == {expected}, got {A.shape}") - self.A = A # No dtype conversion + self._A = A # No dtype conversion self._cod_size = prod(self.cod.shape) self._dom_size = prod(self.dom.shape) self._matrix_shape = (self._cod_size, self._dom_size) @@ -56,6 +56,16 @@ def __init__(self, self.apply = self._apply_unchecked self.rapply = self._rapply_unchecked + @property + def A(self) -> DenseArray: + """ + Stored dense tensor representation of this operator. + + The returned array has shape ``self.codomain.shape + self.domain.shape`` + and is the same object supplied at construction. + """ + return self._A + def apply(self, x: DenseArray) -> DenseArray: """ Forward action: y = A ⋅ x with y in cod.shape. diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index 7a8162e..cda5ba6 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -37,7 +37,7 @@ def __init__(self, if tuple(A.shape) != expected: raise TypeError(f"Expected A.shape == (prod(cod.shape), prod(dom.shape)) == {expected}, got {A.shape}") - self.A = A # No dtype conversion + self._A = A # No dtype conversion self._cod_size = expected[0] self._dom_size = expected[1] dtype = self.ops.get_dtype(self.A) @@ -52,6 +52,17 @@ def __init__(self, self.apply = self._apply_unchecked self.rapply = self._rapply_unchecked + @property + def A(self) -> SparseArray: + """ + Stored sparse matrix representation of this operator. + + The returned sparse matrix has shape + ``(prod(self.codomain.shape), prod(self.domain.shape))`` and is the + same object supplied at construction. + """ + return self._A + def apply(self, x: DenseArray) -> DenseArray: """ Forward action: y = A ⋅ x with y in cod.shape. diff --git a/tests/linops/test_to_dense.py b/tests/linops/test_to_dense.py index 5ca33f2..a60097d 100644 --- a/tests/linops/test_to_dense.py +++ b/tests/linops/test_to_dense.py @@ -1,6 +1,7 @@ import importlib import numpy as np +import pytest import scipy.sparse as sps @@ -29,6 +30,17 @@ def test_dense_linop_to_dense_returns_stored_matrix_and_matches_apply(): _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) +def test_dense_linop_A_returns_stored_dense_representation(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(A, dom, cod, ctx) + + assert op.A is A + + def test_sparse_linop_to_dense_matches_apply(): sc = importlib.import_module("spacecore") ctx = _ctx() @@ -41,6 +53,17 @@ def test_sparse_linop_to_dense_matches_apply(): _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) +def test_sparse_linop_A_returns_stored_sparse_representation(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = sps.csr_matrix([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(A, dom, cod, ctx) + + assert op.A is A + + def test_identity_linop_to_dense_matches_apply(): sc = importlib.import_module("spacecore") ctx = _ctx() @@ -80,6 +103,56 @@ def test_matrix_free_linop_to_dense_matches_apply(): _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) +def test_matrix_free_linop_A_is_not_implemented_by_default(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.MatrixFreeLinOp( + lambda x: ctx.asarray(dense @ np.asarray(x)), + lambda y: ctx.asarray(dense.T @ np.asarray(y)), + dom, + cod, + ctx, + ) + + with pytest.raises(NotImplementedError, match="native numerical representation"): + _ = op.A + + +def test_custom_linop_can_define_A_representation(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + + class CustomLinOp(sc.LinOp): + @property + def A(self): + return {"backend": "custom", "data": dense} + + def apply(self, x): + return ctx.asarray(np.asarray(dense) @ np.asarray(x)) + + def rapply(self, y): + return ctx.asarray(np.asarray(dense).T @ np.asarray(y)) + + def tree_flatten(self): + return (), (self.domain, self.codomain, self.ctx) + + @classmethod + def tree_unflatten(cls, aux, children): + domain, codomain, ctx = aux + return cls(domain, codomain, ctx) + + op = CustomLinOp(dom, cod, ctx) + + assert op.A["backend"] == "custom" + assert op.A["data"] is dense + + def test_sum_linop_to_dense_matches_apply(): sc = importlib.import_module("spacecore") ctx = _ctx() diff --git a/tutorials/8_Linalg_MatrixFree.ipynb b/tutorials/8_Linalg_MatrixFree.ipynb new file mode 100644 index 0000000..f195685 --- /dev/null +++ b/tutorials/8_Linalg_MatrixFree.ipynb @@ -0,0 +1,473 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": [ + "# Matrix-free linear algebra\n", + "\n", + "This guide demonstrates SpaceCore's iterative linear algebra routines on `MatrixFreeLinOp` objects.\n", + "\n", + "A matrix-free operator is useful when the action of an operator is cheap, but building or storing the full matrix would be wasteful. In SpaceCore, the only requirements are:\n", + "\n", + "- a domain space,\n", + "- a codomain space,\n", + "- a forward action `apply(x)`,\n", + "- an adjoint action `rapply(y)`.\n", + "\n", + "The solvers below never need a dense matrix representation of the operator." + ] + }, + { + "cell_type": "code", + "id": "imports", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.583864Z", + "start_time": "2026-05-21T15:58:04.539898Z" + } + }, + "source": [ + "import numpy as np\n", + "\n", + "import spacecore as sc" + ], + "outputs": [], + "execution_count": 50 + }, + { + "cell_type": "markdown", + "id": "context", + "metadata": {}, + "source": [ + "## Backend context and vector space\n", + "\n", + "We use NumPy here to keep the notebook easy to run. The same operators can be converted to other supported backends when their callbacks are written using backend-compatible operations." + ] + }, + { + "cell_type": "code", + "id": "setup", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.612618Z", + "start_time": "2026-05-21T15:58:04.600433Z" + } + }, + "source": [ + "ctx = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False)\n", + "n = 50000\n", + "X = sc.VectorSpace((n,), ctx)\n", + "\n", + "grid = np.linspace(0.0, np.pi, n)" + ], + "outputs": [], + "execution_count": 51 + }, + { + "cell_type": "markdown", + "id": "spd-op", + "metadata": {}, + "source": [ + "## A square Hermitian positive-definite operator\n", + "\n", + "This operator acts like a positive diagonal matrix, but we do not build a matrix object. The callback stores only the diagonal coefficients and multiplies elementwise.\n", + "\n", + "Because the operator is real and self-adjoint, the forward and adjoint callbacks are the same function." + ] + }, + { + "cell_type": "code", + "id": "spd-op-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.696219Z", + "start_time": "2026-05-21T15:58:04.628355Z" + } + }, + "source": [ + "diag = ctx.asarray(np.concatenate(([0.25], np.linspace(1.0, 2.0, n - 2), [6.0])))\n", + "\n", + "def spd_apply(x):\n", + " return diag * x\n", + "\n", + "A = sc.MatrixFreeLinOp(spd_apply, spd_apply, X, X, ctx)\n", + "\n", + "A.domain.shape, A.codomain.shape" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "((50000,), (50000,))" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 52 + }, + { + "cell_type": "markdown", + "id": "cg", + "metadata": {}, + "source": [ + "## Conjugate gradient: solve `A x = b`\n", + "\n", + "`cg` is for square Hermitian positive-definite systems. We create a known solution, apply the operator to get the right-hand side, and ask CG to recover the solution.\n", + "\n", + "The result object summarizes convergence metadata and avoids printing the full solution vector in its `repr`." + ] + }, + { + "cell_type": "code", + "id": "cg-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.725542Z", + "start_time": "2026-05-21T15:58:04.699252Z" + } + }, + "source": [ + "x_true = ctx.asarray(np.sin(grid) + 0.2 * np.cos(3.0 * grid))\n", + "b = A.apply(x_true)\n", + "\n", + "cg_result = sc.cg(A, b, tol=1e-8, maxiter=256)\n", + "cg_result" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "CGResult(converged=True, num_iters=64, residual_norm=1.18582e-08, x=)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 53 + }, + { + "cell_type": "code", + "id": "cg-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.738009Z", + "start_time": "2026-05-21T15:58:04.726043Z" + } + }, + "source": [ + "relative_error = ctx.ops.norm(cg_result.x - x_true) / ctx.ops.norm(x_true)\n", + "float(relative_error)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "4.5131997046538686e-11" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 54 + }, + { + "cell_type": "markdown", + "id": "lsqr-op", + "metadata": {}, + "source": [ + "## A rectangular matrix-free operator\n", + "\n", + "`lsqr` works for rectangular least-squares problems. Here `B : R^n -> R^{2n}` maps a vector into two measurement channels:\n", + "\n", + "$$\n", + "B x = \\begin{bmatrix}x \\\\ w \\odot x\\end{bmatrix}.\n", + "$$\n", + "\n", + "The adjoint combines the two channels:\n", + "\n", + "$$\n", + "B^* y = y_1 + w \\odot y_2.\n", + "$$\n", + "\n", + "Again, no matrix is built." + ] + }, + { + "cell_type": "code", + "id": "lsqr-op-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.750694Z", + "start_time": "2026-05-21T15:58:04.739637Z" + } + }, + "source": [ + "Y = sc.VectorSpace((2 * n,), ctx)\n", + "weights = ctx.asarray(np.linspace(0.5, 1.5, n))\n", + "\n", + "def rectangular_apply(x):\n", + " return ctx.ops.concatenate((x, weights * x), axis=0)\n", + "\n", + "def rectangular_rapply(y):\n", + " return y[:n] + weights * y[n:]\n", + "\n", + "B = sc.MatrixFreeLinOp(rectangular_apply, rectangular_rapply, X, Y, ctx)\n", + "\n", + "B.domain.shape, B.codomain.shape" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "((50000,), (100000,))" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 55 + }, + { + "cell_type": "markdown", + "id": "lsqr", + "metadata": {}, + "source": [ + "## LSQR: solve a least-squares problem\n", + "\n", + "We add a small deterministic perturbation to the measurements so the problem is not just a perfectly consistent copy of the original vector. LSQR minimizes `||B x - data||` using `B.apply` and `B.H.apply` internally." + ] + }, + { + "cell_type": "code", + "id": "lsqr-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.783046Z", + "start_time": "2026-05-21T15:58:04.754755Z" + } + }, + "source": [ + "x_ls_true = ctx.asarray(np.cos(2.0 * grid))\n", + "noise = 0.01 * ctx.asarray(np.sin(np.linspace(0.0, 4.0 * np.pi, 2 * n)))\n", + "data = B.apply(x_ls_true) + noise\n", + "\n", + "lsqr_result = sc.lsqr(B, data, tol=1e-10, maxiter=256)\n", + "lsqr_result" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "LSQRResult(converged=True, num_iters=64, residual_norm=0.304159, normal_residual_norm=8.83974e-14, x=)" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 56 + }, + { + "cell_type": "code", + "id": "lsqr-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.795722Z", + "start_time": "2026-05-21T15:58:04.783535Z" + } + }, + "source": [ + "normal_residual = ctx.ops.norm(B.H.apply(B.apply(lsqr_result.x) - data))\n", + "float(normal_residual)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "8.839736157139945e-14" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 57 + }, + { + "cell_type": "markdown", + "id": "power", + "metadata": {}, + "source": [ + "## Power iteration: estimate the dominant eigenpair\n", + "\n", + "`power_iteration` estimates the largest-magnitude eigenvalue of a square operator. For our diagonal example, the answer should be the largest entry in `diag`." + ] + }, + { + "cell_type": "code", + "id": "power-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.812249Z", + "start_time": "2026-05-21T15:58:04.797109Z" + } + }, + "source": [ + "x0 = ctx.asarray(np.ones(n))\n", + "power_result = sc.power_iteration(A, x0=x0, tol=1e-10, maxiter=256)\n", + "power_result" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "PowerIterationResult(converged=True, num_iters=64, eigenvalue=6, residual_norm=3.25687e-29, eigenvector=)" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 58 + }, + { + "cell_type": "code", + "id": "power-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.826966Z", + "start_time": "2026-05-21T15:58:04.812958Z" + } + }, + "source": [ + "float(power_result.eigenvalue), float(diag[-1])" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(6.0, 6.0)" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 59 + }, + { + "cell_type": "markdown", + "id": "lanczos", + "metadata": {}, + "source": [ + "## Stochastic Lanczos: estimate the smallest eigenpair\n", + "\n", + "`stochastic_lanczos` builds a Krylov subspace from an initial domain element and returns a Ritz approximation to the smallest eigenpair. The returned eigenvector is a member of `A.domain`, not a raw matrix column from an internal representation." + ] + }, + { + "cell_type": "code", + "id": "lanczos-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:05.233074Z", + "start_time": "2026-05-21T15:58:04.828930Z" + } + }, + "source": [ + "initial = ctx.asarray(np.ones(n))\n", + "smallest_eigenvalue, smallest_eigenvector = sc.stochastic_lanczos(\n", + " A,\n", + " initial,\n", + " max_iter=64,\n", + " tol=1e-10,\n", + ")\n", + "\n", + "float(smallest_eigenvalue), smallest_eigenvector.shape" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(0.25, (50000,))" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 60 + }, + { + "cell_type": "code", + "id": "lanczos-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:05.249991Z", + "start_time": "2026-05-21T15:58:05.233738Z" + } + }, + "source": [ + "ritz_residual = ctx.ops.norm(A.apply(smallest_eigenvector) - smallest_eigenvalue * smallest_eigenvector)\n", + "float(ritz_residual), float(diag[0])" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(6.179265594934942e-15, 0.25)" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 61 + }, + { + "cell_type": "markdown", + "id": "wrap", + "metadata": {}, + "source": [ + "## What to take away\n", + "\n", + "- `MatrixFreeLinOp` lets you use iterative algorithms without building a matrix.\n", + "- `cg` is for square Hermitian positive-definite solves.\n", + "- `lsqr` is for general rectangular least-squares problems.\n", + "- `power_iteration` gives a dominant eigenpair estimate.\n", + "- `stochastic_lanczos` gives a smallest Ritz eigenpair estimate from a Krylov subspace.\n", + "- Solver result objects are compact to display, while full arrays remain available as attributes such as `cg_result.x`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From d28aa09dafaf48a666a76d50f426b36dd39c1351 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Thu, 21 May 2026 21:10:11 -0300 Subject: [PATCH 14/44] Add batched linop lifting --- README.md | 16 ++ docs/source/api/linops.rst | 10 ++ docs/source/api/spaces.rst | 10 ++ docs/source/index.rst | 16 ++ docs/source/tutorials/linops.rst | 33 +++++ spacecore/__init__.py | 4 + spacecore/backend/_ops.py | 71 +++++++++ spacecore/backend/jax/_ops.py | 9 ++ spacecore/backend/torch/_ops.py | 14 ++ spacecore/linop/__init__.py | 2 + spacecore/linop/_algebra.py | 124 +++++++++++++++- spacecore/linop/_base.py | 58 ++++++++ spacecore/linop/_dense.py | 39 +++++ spacecore/linop/_diagonal.py | 94 ++++++++++++ spacecore/linop/_sparse.py | 38 +++++ spacecore/linop/product/_block.py | 22 +++ spacecore/linop/product/_from_single.py | 24 +++ spacecore/linop/product/_to_single.py | 24 +++ spacecore/space/__init__.py | 2 + spacecore/space/_base.py | 12 ++ spacecore/space/_batch.py | 185 ++++++++++++++++++++++++ tests/linops/test_batched_lifting.py | 150 +++++++++++++++++++ tests/spaces/test_batch_space.py | 38 +++++ 23 files changed, 992 insertions(+), 3 deletions(-) create mode 100644 spacecore/linop/_diagonal.py create mode 100644 spacecore/space/_batch.py create mode 100644 tests/linops/test_batched_lifting.py create mode 100644 tests/spaces/test_batch_space.py diff --git a/README.md b/README.md index 2459fb9..7566010 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,8 @@ A `Space` describes the structure and geometry of values: - `VectorSpace` for Euclidean vectors and tensors; - `HermitianSpace` for Hermitian or symmetric matrices; - `ProductSpace` for Cartesian products of spaces. +- `BatchSpace` for batched elements such as `X.batch((B,), (0,))`, + representing `B` independent copies of `X`. Algorithms should use space methods such as `zeros`, `add`, `scale`, `axpy`, `inner`, `norm`, `flatten`, and `unflatten` instead of hard-coding backend array @@ -168,6 +170,20 @@ A `LinOp` represents a linear operator between spaces: Operators expose `apply` and `rapply`, so algorithms can use a linear map and its adjoint without depending on the storage format. +For batched inputs, `vapply(xs)` and `rvapply(ys)` lift the operator over the +leading batch axis: + +```python +XB = X.batch(batch_shape=(B,), batch_axes=(0,)) +YB = Y.batch(batch_shape=(B,), batch_axes=(0,)) + +ys = A.vapply(xs, batch_space=XB) # xs in XB, ys in YB +xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB +``` + +The fallback uses backend `vmap`; dense, sparse, diagonal, identity, zero, +algebraic, and product-structured operators provide specialized batched paths. + ## Who should use this? SpaceCore is aimed at people writing optimization, inverse-problem, optimal diff --git a/docs/source/api/linops.rst b/docs/source/api/linops.rst index ef963b4..2c59373 100644 --- a/docs/source/api/linops.rst +++ b/docs/source/api/linops.rst @@ -10,6 +10,7 @@ actions. spacecore.linop.LinOp spacecore.linop.ProductLinOp spacecore.linop.DenseLinOp + spacecore.linop.DiagonalLinOp spacecore.linop.SparseLinOp spacecore.linop.BlockDiagonalLinOp spacecore.linop.StackedLinOp @@ -42,6 +43,15 @@ DenseLinOp :inherited-members: :show-inheritance: +DiagonalLinOp +------------- + +.. autoclass:: spacecore.linop.DiagonalLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + SparseLinOp ----------- diff --git a/docs/source/api/spaces.rst b/docs/source/api/spaces.rst index 01075fe..ba94a95 100644 --- a/docs/source/api/spaces.rst +++ b/docs/source/api/spaces.rst @@ -10,6 +10,7 @@ Spaces define element structure, geometry, flattening, and validation. spacecore.space.VectorSpace spacecore.space.HermitianSpace spacecore.space.ProductSpace + spacecore.space.BatchSpace spacecore.space.SpaceCheck spacecore.space.SpaceValidationError @@ -49,6 +50,15 @@ ProductSpace :inherited-members: :show-inheritance: +BatchSpace +---------- + +.. autoclass:: spacecore.space.BatchSpace + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + Validation ---------- diff --git a/docs/source/index.rst b/docs/source/index.rst index 709aae8..01185b4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -157,6 +157,8 @@ A ``Space`` describes the structure and geometry of values: * ``VectorSpace`` for Euclidean vectors and tensors; * ``HermitianSpace`` for Hermitian or symmetric matrices; * ``ProductSpace`` for Cartesian products of spaces. +* ``BatchSpace`` for batched elements such as ``X.batch((B,), (0,))``, + representing ``B`` independent copies of ``X``. Algorithms should use space methods such as ``zeros``, ``add``, ``scale``, ``axpy``, ``inner``, ``norm``, ``flatten``, and ``unflatten`` instead of @@ -176,6 +178,20 @@ A ``LinOp`` represents a linear operator between spaces: Operators expose ``apply`` and ``rapply``, so algorithms can use a linear map and its adjoint without depending on the storage format. +For batched inputs, ``vapply(xs)`` and ``rvapply(ys)`` lift the operator over +the leading batch axis: + +.. code-block:: python + + XB = X.batch(batch_shape=(B,), batch_axes=(0,)) + YB = Y.batch(batch_shape=(B,), batch_axes=(0,)) + + ys = A.vapply(xs, batch_space=XB) # xs in XB, ys in YB + xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB + +The fallback uses backend ``vmap``; dense, sparse, diagonal, identity, zero, +algebraic, and product-structured operators provide specialized batched paths. + Who should use this? -------------------- diff --git a/docs/source/tutorials/linops.rst b/docs/source/tutorials/linops.rst index 72d73c4..43fa474 100644 --- a/docs/source/tutorials/linops.rst +++ b/docs/source/tutorials/linops.rst @@ -7,6 +7,7 @@ the abstraction for linear maps between spaces. Current implemented operator types are: * ``DenseLinOp`` +* ``DiagonalLinOp`` * ``SparseLinOp`` * ``BlockDiagonalLinOp`` * ``StackedLinOp`` @@ -60,6 +61,38 @@ operator :math:`A^* : Y \to X`, satisfying \langle Ax, y\rangle_Y = \langle x, A^*y\rangle_X. +Batched lifting +--------------- + +For a batch of elements, use ``Space.batch`` to describe the batched space and +``vapply`` or ``rvapply`` to lift the operator: + +.. math:: + + A : X \to Y, + \qquad + A^{(B)} : X^B \to Y^B. + +.. code-block:: python + + B = 8 + XB = X.batch(batch_shape=(B,), batch_axes=(0,)) + YB = Y.batch(batch_shape=(B,), batch_axes=(0,)) + + xs = ctx.asarray(np.ones((B,) + X.shape)) + ys = op.vapply(xs, batch_space=XB) + xs_back = op.rvapply(ys, batch_space=YB) + +This is equivalent to stacking scalar applications: + +.. code-block:: python + + ys_ref = ctx.ops.stack(tuple(op.apply(x) for x in xs), axis=0) + +The base fallback uses backend ``vmap``. Structured operators override this +path when they can use matrix multiplication, sparse multi-vector products, +broadcasting, or componentwise product-space batching. + DenseLinOp ---------- diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 9e6384f..4925f7a 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -13,6 +13,7 @@ from .linop import ( BlockDiagonalLinOp, ComposedLinOp, + DiagonalLinOp, DenseLinOp, IdentityLinOp, LinOp, @@ -37,6 +38,7 @@ stochastic_lanczos, ) from .space import ( + BatchSpace, BackendCheck, DTypeCheck, HermitianCheck, @@ -72,6 +74,7 @@ "LinOp", "ComposedLinOp", + "DiagonalLinOp", "DenseLinOp", "IdentityLinOp", "MatrixFreeLinOp", @@ -100,6 +103,7 @@ "ProductComponentCheck", "ProductStructureCheck", "ShapeCheck", + "BatchSpace", "VectorSpace", "HermitianSpace", "ProductSpace", diff --git a/spacecore/backend/_ops.py b/spacecore/backend/_ops.py index 09e2f89..fd8eb31 100644 --- a/spacecore/backend/_ops.py +++ b/spacecore/backend/_ops.py @@ -399,6 +399,77 @@ def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray: """Stack arrays along a new axis (delegates to xp.stack).""" return self.xp.stack(tuple(arrays), axis=axis) + def vmap( + self, + fn: Callable, + in_axes: int | Sequence[int | None] | None = 0, + out_axes: int | Sequence[int | None] | None = 0, + ) -> Callable: + """Vectorize ``fn`` over array axes using a Python-loop fallback.""" + + def axis_for_arg(i: int) -> int | Sequence[int | None] | None: + if isinstance(in_axes, tuple) or isinstance(in_axes, list): + return in_axes[i] + return in_axes + + def normalize_axis(axis: int, ndim: int) -> int: + return axis + ndim if axis < 0 else axis + + def tree_size(x: Any, axis: Any) -> int | None: + if axis is None: + return None + if isinstance(x, tuple): + axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(x) + for xi, ai in zip(x, axes): + size = tree_size(xi, ai) + if size is not None: + return size + return None + shape = tuple(getattr(x, "shape", ())) + axis = normalize_axis(int(axis), len(shape)) + return int(shape[axis]) + + def tree_take(x: Any, axis: Any, i: int) -> Any: + if axis is None: + return x + if isinstance(x, tuple): + axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(x) + return tuple(tree_take(xi, ai, i) for xi, ai in zip(x, axes)) + shape = tuple(getattr(x, "shape", ())) + axis = normalize_axis(int(axis), len(shape)) + index = [slice(None)] * len(shape) + index[axis] = i + return x[tuple(index)] + + def tree_stack(xs: Sequence[Any], axis: Any) -> Any: + first = xs[0] + if isinstance(first, tuple): + axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(first) + return tuple( + tree_stack(tuple(x[i] for x in xs), ai) + for i, ai in enumerate(axes) + ) + if axis is None: + return first + return self.stack(xs, axis=int(axis)) + + def mapped(*args: Any) -> Any: + axes = tuple(axis_for_arg(i) for i in range(len(args))) + size = None + for arg, axis in zip(args, axes): + size = tree_size(arg, axis) + if size is not None: + break + if size is None: + return fn(*args) + outputs = tuple( + fn(*(tree_take(arg, axis, i) for arg, axis in zip(args, axes))) + for i in range(size) + ) + return tree_stack(outputs, out_axes) + + return mapped + def conj(self, x: DenseArray) -> DenseArray: """Complex conjugate of x (delegates to xp.conj).""" return self.xp.conj(x) diff --git a/spacecore/backend/jax/_ops.py b/spacecore/backend/jax/_ops.py index 04b2eab..d13f995 100644 --- a/spacecore/backend/jax/_ops.py +++ b/spacecore/backend/jax/_ops.py @@ -231,6 +231,15 @@ def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: """ return a @ b + def vmap( + self, + fn: Callable, + in_axes: int | Sequence[int | None] | None = 0, + out_axes: int | Sequence[int | None] | None = 0, + ) -> Callable: + """Vectorize a function using ``jax.vmap``.""" + return self.jax.vmap(fn, in_axes=in_axes, out_axes=out_axes) + def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, return_sign: bool = False, where: DenseArray | None = None) -> DenseArray | Tuple[DenseArray, DenseArray]: """ diff --git a/spacecore/backend/torch/_ops.py b/spacecore/backend/torch/_ops.py index f1413d7..dabf42d 100644 --- a/spacecore/backend/torch/_ops.py +++ b/spacecore/backend/torch/_ops.py @@ -413,6 +413,20 @@ def sparse_matmul( return self.torch.sparse.mm(a, b[:, None], **kwargs)[:, 0] return self.torch.sparse.mm(a, b, **kwargs) + def vmap( + self, + fn: Callable, + in_axes: int | Sequence[int | None] | None = 0, + out_axes: int | Sequence[int | None] | None = 0, + ) -> Callable: + """Vectorize a function using PyTorch's native vmap when available.""" + vmap = getattr(self.torch, "vmap", None) + if vmap is None and hasattr(self.torch, "func"): + vmap = getattr(self.torch.func, "vmap", None) + if vmap is None: + return super().vmap(fn, in_axes=in_axes, out_axes=out_axes) + return vmap(fn, in_dims=in_axes, out_dims=out_axes) + def eigh( self, x: DenseArray, diff --git a/spacecore/linop/__init__.py b/spacecore/linop/__init__.py index f810bb1..7159df4 100644 --- a/spacecore/linop/__init__.py +++ b/spacecore/linop/__init__.py @@ -11,12 +11,14 @@ make_sum, ) from ._dense import DenseLinOp +from ._diagonal import DiagonalLinOp from ._sparse import SparseLinOp from .product import ProductLinOp, StackedLinOp, SumToSingleLinOp, BlockDiagonalLinOp __all__ = [ "LinOp", "ComposedLinOp", + "DiagonalLinOp", "DenseLinOp", "IdentityLinOp", "MatrixFreeLinOp", diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index cebd8b7..3029fc3 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -198,6 +198,14 @@ def rapply(self, y: Any) -> Any: """Return ``conj(scalar) * op.rapply(y)``.""" return _conjugate_scalar(self.scalar) * self.op.rapply(y) + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``scalar * op.vapply(xs)``.""" + return self.scalar * self.op.vapply(xs, batch_space) + + def rvapply(self, ys: Any, batch_space=None) -> Any: + """Return ``conj(scalar) * op.rvapply(ys)``.""" + return _conjugate_scalar(self.scalar) * self.op.rvapply(ys, batch_space) + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.scalar == other.scalar and self.op == other.op @@ -268,6 +276,24 @@ def rapply(self, y: Any) -> Any: acc = self.domain.add(acc, op.rapply(y)) return acc + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``sum_i ops[i].vapply(xs)``.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + out_space = self._output_batch_space(self.codomain, in_space) + acc = self.ops_tuple[0].vapply(xs, in_space) + for op in self.ops_tuple[1:]: + acc = out_space.add(acc, op.vapply(xs, in_space)) + return acc + + def rvapply(self, ys: Any, batch_space=None) -> Any: + """Return ``sum_i ops[i].rvapply(ys)``.""" + in_space = self._input_batch_space(self.codomain, ys, batch_space) + out_space = self._output_batch_space(self.domain, in_space) + acc = self.ops_tuple[0].rvapply(ys, in_space) + for op in self.ops_tuple[1:]: + acc = out_space.add(acc, op.rvapply(ys, in_space)) + return acc + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.ops_tuple == other.ops_tuple @@ -322,6 +348,18 @@ def rapply(self, z: Any) -> Any: """Return ``right.rapply(left.rapply(z))``.""" return self.right.rapply(self.left.rapply(z)) + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``left.vapply(right.vapply(xs))``.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + middle = self.right.codomain.batch(in_space.batch_shape, in_space.batch_axes) + return self.left.vapply(self.right.vapply(xs, in_space), middle) + + def rvapply(self, zs: Any, batch_space=None) -> Any: + """Return ``right.rvapply(left.rvapply(zs))``.""" + in_space = self._input_batch_space(self.codomain, zs, batch_space) + middle = self.left.domain.batch(in_space.batch_shape, in_space.batch_axes) + return self.right.rvapply(self.left.rvapply(zs, in_space), middle) + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.left == other.left and self.right == other.right @@ -375,6 +413,20 @@ def rapply(self, y: Any) -> Any: self.codomain._check_member(y) return self.domain.zeros() + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return the batched zero element of the codomain.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + return self._output_batch_space(self.codomain, in_space).zeros() + + def rvapply(self, ys: Any, batch_space=None) -> Any: + """Return the batched zero element of the domain.""" + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + return self._output_batch_space(self.domain, in_space).zeros() + def to_dense(self) -> Any: """ Return the dense tensor representation of the zero map. @@ -430,6 +482,20 @@ def rapply(self, x: Any) -> Any: self.codomain._check_member(x) return x + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``xs`` after batched domain validation.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + return xs + + def rvapply(self, xs: Any, batch_space=None) -> Any: + """Return ``xs`` after batched codomain validation.""" + in_space = self._input_batch_space(self.codomain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + return xs + def to_dense(self) -> Any: """ Return the dense tensor representation of this identity map. @@ -484,14 +550,22 @@ def __init__( dom: Domain, cod: Codomain, ctx: Context | str | None = None, + vapply: Any | None = None, + rvapply: Any | None = None, ) -> None: if not callable(apply): raise TypeError(f"apply must be callable, got {type(apply).__name__}.") if not callable(rapply): raise TypeError(f"rapply must be callable, got {type(rapply).__name__}.") + if vapply is not None and not callable(vapply): + raise TypeError(f"vapply must be callable, got {type(vapply).__name__}.") + if rvapply is not None and not callable(rvapply): + raise TypeError(f"rvapply must be callable, got {type(rvapply).__name__}.") super().__init__(dom, cod, ctx) self.apply_fn = apply self.rapply_fn = rapply + self.vapply_fn = vapply + self.rvapply_fn = rvapply def apply(self, x: Any) -> Any: """Return ``apply_fn(x)``.""" @@ -511,6 +585,30 @@ def rapply(self, y: Any) -> Any: self.domain._check_member(x) return x + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``vapply_fn(xs)`` when supplied, otherwise use fallback batching.""" + if self.vapply_fn is None: + return super().vapply(xs, batch_space) + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + ys = self.vapply_fn(xs) + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def rvapply(self, ys: Any, batch_space=None) -> Any: + """Return ``rvapply_fn(ys)`` when supplied, otherwise use fallback batching.""" + if self.rvapply_fn is None: + return super().rvapply(ys, batch_space) + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + xs = self.rvapply_fn(ys) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return ( @@ -518,18 +616,28 @@ def __eq__(self, other: Any) -> bool: and self.codomain == other.codomain and self.apply_fn is other.apply_fn and self.rapply_fn is other.rapply_fn + and self.vapply_fn is other.vapply_fn + and self.rvapply_fn is other.rvapply_fn ) return False def tree_flatten(self): children = () - aux = (self.apply_fn, self.rapply_fn, self.domain, self.codomain, self.ctx) + aux = ( + self.apply_fn, + self.rapply_fn, + self.domain, + self.codomain, + self.ctx, + self.vapply_fn, + self.rvapply_fn, + ) return children, aux @classmethod def tree_unflatten(cls, aux, children): - apply_fn, rapply_fn, domain, codomain, ctx = aux - return cls(apply_fn, rapply_fn, domain, codomain, ctx) + apply_fn, rapply_fn, domain, codomain, ctx, vapply_fn, rvapply_fn = aux + return cls(apply_fn, rapply_fn, domain, codomain, ctx, vapply_fn, rvapply_fn) def _convert(self, new_ctx: Context) -> MatrixFreeLinOp: return MatrixFreeLinOp( @@ -538,6 +646,8 @@ def _convert(self, new_ctx: Context) -> MatrixFreeLinOp: self.domain.convert(new_ctx), self.codomain.convert(new_ctx), new_ctx, + self.vapply_fn, + self.rvapply_fn, ) @@ -567,6 +677,14 @@ def rapply(self, x: Any) -> Any: """Return ``op.apply(x)``.""" return self.op.apply(x) + def vapply(self, ys: Any, batch_space=None) -> Any: + """Return ``op.rvapply(ys)`` over a batch.""" + return self.op.rvapply(ys, batch_space) + + def rvapply(self, xs: Any, batch_space=None) -> Any: + """Return ``op.vapply(xs)`` over a batch.""" + return self.op.vapply(xs, batch_space) + @property def H(self) -> LinOp[Domain, Codomain]: """Original operator viewed as the adjoint of this adjoint view.""" diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index af9c05f..d5aa793 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -87,6 +87,64 @@ def adjoint_apply(self, y: Any) -> Any: """Apply the adjoint of this linear operator to ``y``.""" return self.rapply(y) + def vapply(self, xs: Any, batch_space: Space | None = None) -> Any: + """Apply this operator independently over a batch of domain elements.""" + return self._fallback_vapply(xs, batch_space) + + def rvapply(self, ys: Any, batch_space: Space | None = None) -> Any: + """Apply the adjoint independently over a batch of codomain elements.""" + return self._fallback_rvapply(ys, batch_space) + + def _infer_batch_shape(self, space: Space, value: Any) -> tuple[int, ...]: + if hasattr(space, "spaces") and isinstance(value, tuple) and value: + return self._infer_batch_shape(space.spaces[0], value[0]) + shape = tuple(getattr(value, "shape", ())) + base_shape = tuple(space.shape) + if not base_shape: + return shape + if len(shape) < len(base_shape) or shape[-len(base_shape):] != base_shape: + raise ValueError( + f"Cannot infer leading batch shape for value shape {shape} " + f"and base space shape {base_shape}." + ) + return shape[: len(shape) - len(base_shape)] + + def _input_batch_space( + self, + space: Space, + value: Any, + batch_space: Space | None, + ) -> Space: + if batch_space is not None: + return batch_space + batch_shape = self._infer_batch_shape(space, value) + return space.batch(batch_shape, tuple(range(len(batch_shape)))) + + def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space: + batch_shape = getattr(input_batch_space, "batch_shape", None) + batch_axes = getattr(input_batch_space, "batch_axes", None) + if batch_shape is None or batch_axes is None: + raise TypeError("batch_space must be a BatchSpace-compatible object.") + return space.batch(tuple(batch_shape), tuple(batch_axes)) + + def _fallback_vapply(self, xs: Any, batch_space: Space | None = None) -> Any: + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + ys = self.ops.vmap(self.apply, in_axes=0, out_axes=0)(xs) + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def _fallback_rvapply(self, ys: Any, batch_space: Space | None = None) -> Any: + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + xs = self.ops.vmap(self.rapply, in_axes=0, out_axes=0)(ys) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + @property def H(self) -> LinOp: """Hermitian-adjoint view of this linear operator.""" diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index 87f38d5..79ce2dd 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -47,6 +47,7 @@ def __init__(self, self._A2 = self.A.reshape(self._matrix_shape) dtype = self.ops.get_dtype(self.A) is_complex = getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex") + self._A2T = self._A2.T self._A2H = self._A2.T.conj() if is_complex else self._A2.T self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,) self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,) @@ -98,6 +99,44 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: + in_space = self._input_batch_space(self.domain, xs, batch_space) + if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): + return self._fallback_vapply(xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + batch_shape = tuple(in_space.batch_shape) + xs2 = self.ops.reshape(xs, (-1, self._dom_size)) + ys2 = self.ops.matmul(xs2, self._A2T) + y_flat_shape = batch_shape + (self._cod_size,) + ys_flat = self.ops.reshape(ys2, y_flat_shape) + if self._cod_vector_fast_path: + ys = self.ops.reshape(ys2, batch_shape + tuple(self.cod.shape)) + else: + ys = self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): + return self._fallback_rvapply(ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + batch_shape = tuple(in_space.batch_shape) + ys2 = self.ops.reshape(ys, (-1, self._cod_size)) + xs2 = self.ops.matmul(ys2, self._A2H.T) + x_flat_shape = batch_shape + (self._dom_size,) + xs_flat = self.ops.reshape(xs2, x_flat_shape) + if self._dom_vector_fast_path: + xs = self.ops.reshape(xs2, batch_shape + tuple(self.dom.shape)) + else: + xs = self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + def to_dense(self) -> DenseArray: """ Return the stored dense tensor representation of this operator. diff --git a/spacecore/linop/_diagonal.py b/spacecore/linop/_diagonal.py new file mode 100644 index 0000000..9ba6ea7 --- /dev/null +++ b/spacecore/linop/_diagonal.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from math import prod +from typing import Any + +from ._base import LinOp +from ..backend import Context, jax_pytree_class +from ..space import VectorSpace +from ..types import DenseArray +from .._contextual.manager import ctx_manager + + +@jax_pytree_class +class DiagonalLinOp(LinOp[VectorSpace, VectorSpace]): + """Coordinatewise diagonal linear operator on a vector space.""" + + def __init__( + self, + diagonal: DenseArray, + space: VectorSpace | None = None, + ctx: Context | str | None = None, + ) -> None: + ctx = ctx_manager.resolve_context_priority(ctx, space) + ctx.assert_dense(diagonal) + if space is None: + space = VectorSpace(tuple(diagonal.shape), ctx) + super().__init__(space, space, ctx) + expected = tuple(self.domain.shape) + if tuple(diagonal.shape) != expected: + raise TypeError(f"Expected diagonal.shape == space.shape == {expected}, got {diagonal.shape}") + self.diagonal = diagonal + self._size = prod(self.domain.shape) + self._is_flat = tuple(self.domain.shape) == (self._size,) + self._diag_flat = diagonal if self._is_flat else diagonal.reshape((self._size,)) + dtype = self.ops.get_dtype(diagonal) + self._is_complex = getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex") + self._diag_adjoint = self.ops.conj(diagonal) if self._is_complex else diagonal + self._diag_adjoint_flat = ( + self._diag_adjoint if self._is_flat else self._diag_adjoint.reshape((self._size,)) + ) + + @property + def A(self) -> DenseArray: + return self.to_dense() + + def apply(self, x: DenseArray) -> DenseArray: + if self._enable_checks: + self.domain._check_member(x) + return self.diagonal * x + + def rapply(self, y: DenseArray) -> DenseArray: + if self._enable_checks: + self.codomain._check_member(y) + return self._diag_adjoint * y + + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + ys = self.diagonal * xs + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + xs = self._diag_adjoint * ys + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + + def to_dense(self) -> DenseArray: + matrix = self.ops.diag(self._diag_flat) + return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape)) + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.domain == other.domain and self.ops.allclose(self.diagonal, other.diagonal) + return False + + def tree_flatten(self): + children = (self.diagonal,) + aux = (self.domain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + domain, ctx = aux + return cls(children[0], domain, ctx) + + def _convert(self, new_ctx: Context) -> DiagonalLinOp: + return DiagonalLinOp(new_ctx.asarray(self.diagonal), self.domain.convert(new_ctx), new_ctx) diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index cda5ba6..fca611f 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -98,6 +98,44 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: + in_space = self._input_batch_space(self.domain, xs, batch_space) + if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): + return self._fallback_vapply(xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + batch_shape = tuple(in_space.batch_shape) + xs2 = self.ops.reshape(xs, (-1, self._dom_size)) + ys_t = self.ops.sparse_matmul(self.A, self.ops.transpose(xs2)) + ys2 = self.ops.transpose(ys_t) + ys_flat = self.ops.reshape(ys2, batch_shape + (self._cod_size,)) + if self._cod_vector_fast_path: + ys = self.ops.reshape(ys2, batch_shape + tuple(self.cod.shape)) + else: + ys = self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): + return self._fallback_rvapply(ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + batch_shape = tuple(in_space.batch_shape) + ys2 = self.ops.reshape(ys, (-1, self._cod_size)) + xs_t = self.ops.sparse_matmul(self._AH, self.ops.transpose(ys2)) + xs2 = self.ops.transpose(xs_t) + xs_flat = self.ops.reshape(xs2, batch_shape + (self._dom_size,)) + if self._dom_vector_fast_path: + xs = self.ops.reshape(xs2, batch_shape + tuple(self.dom.shape)) + else: + xs = self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + def to_dense(self) -> DenseArray: """ Materialize the stored sparse matrix as a dense operator tensor. diff --git a/spacecore/linop/product/_block.py b/spacecore/linop/product/_block.py index a2f7fe8..4992c6a 100644 --- a/spacecore/linop/product/_block.py +++ b/spacecore/linop/product/_block.py @@ -53,6 +53,28 @@ def _rapply_unchecked(self, y: Any) -> Any: return self._rapply_parts[0](y[0]), self._rapply_parts[1](y[1]) return tuple(rapply(yi) for rapply, yi in zip(self._rapply_parts, y)) + def vapply(self, x: Any, batch_space=None) -> Any: + in_space = self._input_batch_space(self.domain, x, batch_space) + if self._enable_checks: + in_space._check_member(x) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + return tuple( + op.vapply(xi, op.domain.batch(batch_shape, batch_axes)) + for op, xi in zip(self.parts, x) + ) + + def rvapply(self, y: Any, batch_space=None) -> Any: + in_space = self._input_batch_space(self.codomain, y, batch_space) + if self._enable_checks: + in_space._check_member(y) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + return tuple( + op.rvapply(yi, op.codomain.batch(batch_shape, batch_axes)) + for op, yi in zip(self.parts, y) + ) + @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> BlockDiagonalLinOp: if not parts: diff --git a/spacecore/linop/product/_from_single.py b/spacecore/linop/product/_from_single.py index abef674..fabea53 100644 --- a/spacecore/linop/product/_from_single.py +++ b/spacecore/linop/product/_from_single.py @@ -61,6 +61,30 @@ def _rapply_unchecked(self, y: Any) -> Any: acc = xi if acc is None else (acc + xi if use_direct_add else self.dom.add(xi, acc)) return acc + def vapply(self, x: Any, batch_space=None) -> Any: + in_space = self._input_batch_space(self.domain, x, batch_space) + if self._enable_checks: + in_space._check_member(x) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + return tuple( + op.vapply(x, op.domain.batch(batch_shape, batch_axes)) + for op in self.parts + ) + + def rvapply(self, y: Any, batch_space=None) -> Any: + in_space = self._input_batch_space(self.codomain, y, batch_space) + if self._enable_checks: + in_space._check_member(y) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + out_space = self.domain.batch(batch_shape, batch_axes) + acc = None + for op, yi in zip(self.parts, y): + xi = op.rvapply(yi, op.codomain.batch(batch_shape, batch_axes)) + acc = xi if acc is None else out_space.add(acc, xi) + return acc + @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> StackedLinOp: if not parts: diff --git a/spacecore/linop/product/_to_single.py b/spacecore/linop/product/_to_single.py index 806f94e..c5ab082 100644 --- a/spacecore/linop/product/_to_single.py +++ b/spacecore/linop/product/_to_single.py @@ -61,6 +61,30 @@ def _rapply_unchecked(self, y: Any) -> Any: return self._rapply_parts[0](y), self._rapply_parts[1](y) return tuple(rapply(y) for rapply in self._rapply_parts) + def vapply(self, x: Any, batch_space=None) -> Any: + in_space = self._input_batch_space(self.domain, x, batch_space) + if self._enable_checks: + in_space._check_member(x) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + out_space = self.codomain.batch(batch_shape, batch_axes) + acc = None + for op, xi in zip(self.parts, x): + yi = op.vapply(xi, op.domain.batch(batch_shape, batch_axes)) + acc = yi if acc is None else out_space.add(acc, yi) + return acc + + def rvapply(self, y: Any, batch_space=None) -> Any: + in_space = self._input_batch_space(self.codomain, y, batch_space) + if self._enable_checks: + in_space._check_member(y) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + return tuple( + op.rvapply(y, op.codomain.batch(batch_shape, batch_axes)) + for op in self.parts + ) + @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> SumToSingleLinOp: if not parts: diff --git a/spacecore/space/__init__.py b/spacecore/space/__init__.py index fbb7ada..4883640 100644 --- a/spacecore/space/__init__.py +++ b/spacecore/space/__init__.py @@ -10,11 +10,13 @@ SquareMatrixCheck, ) from ._base import Space +from ._batch import BatchSpace from ._herm import HermitianSpace from ._vector import VectorSpace from ._product import ProductSpace __all__ = [ + "BatchSpace", "BackendCheck", "DTypeCheck", "HermitianCheck", diff --git a/spacecore/space/_base.py b/spacecore/space/_base.py index d206bb4..53318d1 100644 --- a/spacecore/space/_base.py +++ b/spacecore/space/_base.py @@ -106,6 +106,18 @@ def unflatten(self, v: DenseArray) -> Any: """Inverse of flatten; returns an element in the requested representation.""" raise NotImplementedError + def batch( + self, + batch_shape: Tuple[int, ...], + batch_axes: Tuple[int, ...] | None = None, + ) -> Space: + """Return a wrapper representing a batch/product of this space.""" + from ._batch import BatchSpace + + if batch_axes is None: + batch_axes = tuple(range(len(batch_shape))) + return BatchSpace(self, batch_shape, batch_axes) + def _convert(self, new_ctx: Context) -> Space: raise NotImplementedError() diff --git a/spacecore/space/_batch.py b/spacecore/space/_batch.py new file mode 100644 index 0000000..b227c0d --- /dev/null +++ b/spacecore/space/_batch.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +from math import prod +from typing import Any, Callable, Tuple + +from ._base import Space +from ._checks import BackendCheck, DTypeCheck, ShapeCheck +from ._product import ProductSpace +from ..backend import Context +from ..types import DenseArray + + +def _batched_shape( + base_shape: tuple[int, ...], + batch_shape: tuple[int, ...], + batch_axes: tuple[int, ...], +) -> tuple[int, ...]: + total_ndim = len(base_shape) + len(batch_shape) + axes = tuple(axis + total_ndim if axis < 0 else axis for axis in batch_axes) + if len(batch_shape) != len(axes): + raise ValueError("batch_shape and batch_axes must have the same length.") + if len(set(axes)) != len(axes): + raise ValueError("batch_axes must be unique.") + if any(axis < 0 or axis >= total_ndim for axis in axes): + raise ValueError( + f"batch_axes must be valid axes for batched ndim {total_ndim}, got {batch_axes}." + ) + + out: list[int | None] = [None] * total_ndim + for axis, dim in zip(axes, batch_shape): + out[axis] = int(dim) + + base_iter = iter(int(dim) for dim in base_shape) + for i, dim in enumerate(out): + if dim is None: + out[i] = next(base_iter) + return tuple(dim for dim in out if dim is not None) + + +class BatchSpace(Space): + """ + Wrapper space representing a batch of elements from a base space. + + ``BatchSpace(X, batch_shape, batch_axes)`` represents ``X`` repeated over + the given batch dimensions. It deliberately wraps the original space rather + than folding batch dimensions into the base ``Space`` instance. + """ + + def __init__( + self, + base: Space, + batch_shape: Tuple[int, ...], + batch_axes: Tuple[int, ...], + ctx: Context | str | None = None, + ) -> None: + ctx = base.ctx if ctx is None else ctx + super().__init__( + _batched_shape(tuple(base.shape), tuple(batch_shape), tuple(batch_axes)), + ctx, + ) + self.base = base.convert(self.ctx) + self.batch_shape = tuple(int(dim) for dim in batch_shape) + total_ndim = len(self.base.shape) + len(self.batch_shape) + self.batch_axes = tuple(axis + total_ndim if axis < 0 else axis for axis in batch_axes) + self._batch_size = prod(self.batch_shape) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, BatchSpace): + return ( + self.ctx == other.ctx + and self.base == other.base + and self.batch_shape == other.batch_shape + and self.batch_axes == other.batch_axes + ) + return False + + @property + def _is_product(self) -> bool: + return isinstance(self.base, ProductSpace) + + def _component_spaces(self) -> tuple[BatchSpace, ...]: + if not isinstance(self.base, ProductSpace): + raise TypeError("BatchSpace component spaces are available only for ProductSpace bases.") + return tuple(sp.batch(self.batch_shape, self.batch_axes) for sp in self.base.spaces) + + def _check_member(self, x: Any) -> None: + if isinstance(self.base, ProductSpace): + if not isinstance(x, tuple) or len(x) != self.base.arity: + raise TypeError( + f"BatchSpace over ProductSpace expects tuple length {self.base.arity}." + ) + for space, component in zip(self._component_spaces(), x): + space.check_member(component) + return + + BackendCheck()(self, x) + ShapeCheck()(self, x) + DTypeCheck()(self, x) + for check in self.base.member_checks(): + if isinstance(check, (BackendCheck, ShapeCheck, DTypeCheck)): + continue + check(self, x) + + def zeros(self) -> Any: + if isinstance(self.base, ProductSpace): + return tuple(space.zeros() for space in self._component_spaces()) + return self.ops.zeros(self.shape, dtype=self.dtype) + + def add(self, x: Any, y: Any) -> Any: + if self._enable_checks: + self._check_member(x) + self._check_member(y) + if isinstance(self.base, ProductSpace): + return tuple(space.add(xi, yi) for space, xi, yi in zip(self._component_spaces(), x, y)) + return x + y + + def scale(self, a: Any, x: Any) -> Any: + if self._enable_checks: + self._check_member(x) + if isinstance(self.base, ProductSpace): + return tuple(space.scale(a, xi) for space, xi in zip(self._component_spaces(), x)) + return a * x + + def inner(self, x: Any, y: Any) -> Any: + if self._enable_checks: + self._check_member(x) + self._check_member(y) + if isinstance(self.base, ProductSpace): + acc = None + for space, xi, yi in zip(self._component_spaces(), x, y): + v = space.inner(xi, yi) + acc = v if acc is None else acc + v + return acc + return self.ops.vdot(x, y) + + def eigh(self, x: Any, k: int = None) -> Any: + raise TypeError(f"{type(self).__name__}.eigh is not defined for batched spaces.") + + def flatten(self, x: Any) -> DenseArray: + if self._enable_checks: + self._check_member(x) + if isinstance(self.base, ProductSpace): + parts = tuple(space.flatten(xi) for space, xi in zip(self._component_spaces(), x)) + return parts[0] if len(parts) == 1 else self.ops.concatenate(parts, axis=0) + return self.ops.reshape(x, (-1,)) + + def unflatten(self, v: DenseArray) -> Any: + vv = self.ctx.assert_dense(v) if self._enable_checks else v + if isinstance(self.base, ProductSpace): + if ( + tuple(getattr(vv, "shape", ())) == tuple(self.shape) + and self.batch_axes == tuple(range(len(self.batch_shape))) + ): + xs = [] + offset = 0 + for component, space in zip(self.base.spaces, self._component_spaces()): + size = prod(component.shape) + flat_component = vv[(..., slice(offset, offset + size))] + xs.append(space.unflatten(flat_component)) + offset += size + return tuple(xs) + xs = [] + offset = 0 + for space in self._component_spaces(): + size = prod(space.shape) + xs.append(space.unflatten(vv[offset : offset + size])) + offset += size + return tuple(xs) + return self.ops.reshape(vv, self.shape) + + def apply(self, x: Any, f: Callable) -> Any: + if self._enable_checks: + self._check_member(x) + if isinstance(self.base, ProductSpace): + return tuple(space.apply(xi, f) for space, xi in zip(self._component_spaces(), x)) + try: + y = f(x) + except Exception: + y = self.ops.vmap(lambda xi: self.base.apply(xi, f))(x) + if self._enable_checks: + self._check_member(y) + return y + + def _convert(self, new_ctx: Context) -> BatchSpace: + return BatchSpace(self.base.convert(new_ctx), self.batch_shape, self.batch_axes, new_ctx) diff --git a/tests/linops/test_batched_lifting.py b/tests/linops/test_batched_lifting.py new file mode 100644 index 0000000..2b3b107 --- /dev/null +++ b/tests/linops/test_batched_lifting.py @@ -0,0 +1,150 @@ +import importlib + +import numpy as np +import scipy.sparse as sps + + +def _ctx(): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=np.float64) + + +def _stack_apply(ctx, op, xs): + return ctx.ops.stack(tuple(op.apply(x) for x in xs), axis=0) + + +def _stack_rapply(ctx, op, ys): + return ctx.ops.stack(tuple(op.rapply(y) for y in ys), axis=0) + + +def test_dense_linop_vapply_and_rvapply_match_stacked_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + matrix = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(matrix, dom, cod, ctx) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0], [0.5, 2.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(op.vapply(xs), _stack_apply(ctx, op, xs)) + assert np.allclose(op.rvapply(ys), _stack_rapply(ctx, op, ys)) + + +def test_sparse_linop_vapply_and_rvapply_match_stacked_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 0.0], [0.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(ctx.assparse(sps.csr_matrix(dense)), dom, cod, ctx) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0], [0.5, 2.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(op.vapply(xs), _stack_apply(ctx, op, xs)) + assert np.allclose(op.rvapply(ys), _stack_rapply(ctx, op, ys)) + + +def test_diagonal_identity_zero_sum_composed_and_adjoint_batched_lifting(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + d1 = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + d2 = sc.DiagonalLinOp(ctx.asarray([-1.0, 0.5, 4.0]), space, ctx) + identity = sc.IdentityLinOp(space, ctx) + zero = sc.ZeroLinOp(space, space, ctx) + summed = d1 + d2 + zero + composed = d1 @ (d2 + identity) + adjoint = composed.H + xs = ctx.asarray([[1.0, 2.0, 3.0], [4.0, -1.0, 0.0]]) + + for op in (d1, identity, zero, summed, composed): + assert np.allclose(op.vapply(xs), _stack_apply(ctx, op, xs)) + assert np.allclose(op.rvapply(xs), _stack_rapply(ctx, op, xs)) + assert np.allclose(adjoint.vapply(xs), _stack_apply(ctx, adjoint, xs)) + assert np.allclose(adjoint.rvapply(xs), _stack_rapply(ctx, adjoint, xs)) + + +def test_matrix_free_vapply_uses_callback_when_supplied_and_fallback_when_absent(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + matrix = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + calls = {"vapply": 0, "rvapply": 0} + + def apply(x): + return ctx.asarray(matrix @ np.asarray(x)) + + def rapply(y): + return ctx.asarray(matrix.T @ np.asarray(y)) + + def vapply(xs): + calls["vapply"] += 1 + return ctx.asarray(np.asarray(xs) @ matrix.T) + + def rvapply(ys): + calls["rvapply"] += 1 + return ctx.asarray(np.asarray(ys) @ matrix) + + with_callbacks = sc.MatrixFreeLinOp(apply, rapply, dom, cod, ctx, vapply, rvapply) + fallback = sc.MatrixFreeLinOp(apply, rapply, dom, cod, ctx) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(with_callbacks.vapply(xs), _stack_apply(ctx, with_callbacks, xs)) + assert np.allclose(with_callbacks.rvapply(ys), _stack_rapply(ctx, with_callbacks, ys)) + assert calls == {"vapply": 1, "rvapply": 1} + assert np.allclose(fallback.vapply(xs), _stack_apply(ctx, fallback, xs)) + assert np.allclose(fallback.rvapply(ys), _stack_rapply(ctx, fallback, ys)) + + +def test_product_linops_batched_lifting_matches_stacked_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + x0 = sc.VectorSpace((2,), ctx) + x1 = sc.VectorSpace((3,), ctx) + y0 = sc.VectorSpace((3,), ctx) + y1 = sc.VectorSpace((2,), ctx) + a0 = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), x0, y0, ctx) + a1 = sc.DenseLinOp(ctx.asarray([[2.0, -1.0, 0.5], [0.0, 3.0, 4.0]]), x1, y1, ctx) + s1 = sc.DenseLinOp(ctx.asarray([[2.0, -1.0], [0.0, 3.0]]), x0, y1, ctx) + block = sc.BlockDiagonalLinOp.from_operators((a0, a1)) + stacked = sc.StackedLinOp.from_operators((a0, s1)) + sum_to_single = sc.SumToSingleLinOp.from_operators((a0.H, s1.H)) + + xb = (ctx.asarray([[1.0, 2.0], [3.0, 4.0]]), ctx.asarray([[5.0, 6.0, 7.0], [1.0, -1.0, 0.5]])) + yb = (ctx.asarray([[1.0, 2.0, 3.0], [0.0, -1.0, 4.0]]), ctx.asarray([[2.0, 1.0], [3.0, -2.0]])) + single_x = ctx.asarray([[1.0, 2.0], [3.0, 4.0]]) + single_y = ctx.asarray([[1.0, 2.0], [3.0, -2.0]]) + + block_v = block.vapply(xb) + block_expected = tuple(ctx.ops.stack(tuple(block.apply((xb[0][i], xb[1][i]))[j] for i in range(2))) for j in range(2)) + assert np.allclose(block_v[0], block_expected[0]) + assert np.allclose(block_v[1], block_expected[1]) + + block_rv = block.rvapply(yb) + block_r_expected = tuple(ctx.ops.stack(tuple(block.rapply((yb[0][i], yb[1][i]))[j] for i in range(2))) for j in range(2)) + assert np.allclose(block_rv[0], block_r_expected[0]) + assert np.allclose(block_rv[1], block_r_expected[1]) + + stacked_v = stacked.vapply(single_x) + stacked_expected = tuple(ctx.ops.stack(tuple(stacked.apply(single_x[i])[j] for i in range(2))) for j in range(2)) + assert np.allclose(stacked_v[0], stacked_expected[0]) + assert np.allclose(stacked_v[1], stacked_expected[1]) + + assert np.allclose( + stacked.rvapply(yb), + ctx.ops.stack(tuple(stacked.rapply((yb[0][i], yb[1][i])) for i in range(2))), + ) + assert np.allclose( + sum_to_single.vapply(yb), + ctx.ops.stack(tuple(sum_to_single.apply((yb[0][i], yb[1][i])) for i in range(2))), + ) + sum_rv = sum_to_single.rvapply(single_y) + sum_r_expected = tuple( + ctx.ops.stack(tuple(sum_to_single.rapply(single_y[i])[j] for i in range(2))) + for j in range(2) + ) + assert np.allclose(sum_rv[0], sum_r_expected[0]) + assert np.allclose(sum_rv[1], sum_r_expected[1]) diff --git a/tests/spaces/test_batch_space.py b/tests/spaces/test_batch_space.py new file mode 100644 index 0000000..3053806 --- /dev/null +++ b/tests/spaces/test_batch_space.py @@ -0,0 +1,38 @@ +import importlib + +import numpy as np + + +def test_vector_space_batch_wrapper_shape_and_membership(): + sc = importlib.import_module("spacecore") + ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + x = sc.VectorSpace((2, 3), ctx) + + xb = x.batch(batch_shape=(4,), batch_axes=(0,)) + + assert isinstance(xb, sc.BatchSpace) + assert xb.base == x + assert xb.batch_shape == (4,) + assert xb.batch_axes == (0,) + assert xb.shape == (4, 2, 3) + xb.check_member(ctx.ops.zeros((4, 2, 3), dtype=ctx.dtype)) + + +def test_product_space_batch_wrapper_validates_component_batches(): + sc = importlib.import_module("spacecore") + ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + x0 = sc.VectorSpace((2,), ctx) + x1 = sc.VectorSpace((3,), ctx) + product = sc.ProductSpace((x0, x1), ctx) + batched = product.batch((5,), (0,)) + + value = ( + ctx.ops.zeros((5, 2), dtype=ctx.dtype), + ctx.ops.zeros((5, 3), dtype=ctx.dtype), + ) + + assert batched.shape == (5, 5) + batched.check_member(value) + zeros = batched.zeros() + assert np.allclose(zeros[0], np.zeros((5, 2))) + assert np.allclose(zeros[1], np.zeros((5, 3))) From 82ac1b999bd570ff87b6e609d7b4b667c9e85e4d Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Thu, 21 May 2026 22:49:54 -0300 Subject: [PATCH 15/44] Optimize batched dense and sparse linops --- spacecore/linop/_dense.py | 85 ++++++++++++++++++++++------ spacecore/linop/_sparse.py | 85 ++++++++++++++++++++++------ tests/linops/test_batched_lifting.py | 25 ++++++++ 3 files changed, 163 insertions(+), 32 deletions(-) diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index 79ce2dd..244f566 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -56,6 +56,8 @@ def __init__(self, if not self._enable_checks: self.apply = self._apply_unchecked self.rapply = self._rapply_unchecked + self.vapply = self._vapply_unchecked + self.rvapply = self._rvapply_unchecked @property def A(self) -> DenseArray: @@ -99,6 +101,71 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + @staticmethod + def _batch_shape_from_input(value: DenseArray, base_ndim: int) -> tuple[int, ...]: + shape = tuple(value.shape) + return shape if base_ndim == 0 else shape[:-base_ndim] + + @staticmethod + def _is_leading_batch(batch_space: Any) -> bool: + if batch_space is None: + return True + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + return batch_axes == tuple(range(len(batch_shape))) + + @staticmethod + def _batch_shape_from_space(batch_space: Any) -> tuple[int, ...]: + return tuple(getattr(batch_space, "batch_shape")) + + def _vapply_unchecked_leading( + self, + xs: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + xs2 = xs.reshape((-1, self._dom_size)) + ys2 = xs2 @ self._A2T + if self._cod_vector_fast_path: + if self._cod_is_flat and tuple(ys2.shape[:-1]) == batch_shape: + return ys2 + return ys2.reshape(batch_shape + tuple(self.cod.shape)) + ys_flat = ys2.reshape(batch_shape + (self._cod_size,)) + return self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + + def _rvapply_unchecked_leading( + self, + ys: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + ys2 = ys.reshape((-1, self._cod_size)) + xs2 = ys2 @ self._A2H.T + if self._dom_vector_fast_path: + if self._dom_is_flat and tuple(xs2.shape[:-1]) == batch_shape: + return xs2 + return xs2.reshape(batch_shape + tuple(self.dom.shape)) + xs_flat = xs2.reshape(batch_shape + (self._dom_size,)) + return self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + + def _vapply_unchecked(self, xs: DenseArray, batch_space=None) -> DenseArray: + if not self._is_leading_batch(batch_space): + return self._fallback_vapply(xs, batch_space) + batch_shape = ( + self._batch_shape_from_input(xs, len(self.domain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._vapply_unchecked_leading(xs, batch_shape) + + def _rvapply_unchecked(self, ys: DenseArray, batch_space=None) -> DenseArray: + if not self._is_leading_batch(batch_space): + return self._fallback_rvapply(ys, batch_space) + batch_shape = ( + self._batch_shape_from_input(ys, len(self.codomain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._rvapply_unchecked_leading(ys, batch_shape) + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: in_space = self._input_batch_space(self.domain, xs, batch_space) if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): @@ -106,14 +173,7 @@ def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: if self._enable_checks: in_space._check_member(xs) batch_shape = tuple(in_space.batch_shape) - xs2 = self.ops.reshape(xs, (-1, self._dom_size)) - ys2 = self.ops.matmul(xs2, self._A2T) - y_flat_shape = batch_shape + (self._cod_size,) - ys_flat = self.ops.reshape(ys2, y_flat_shape) - if self._cod_vector_fast_path: - ys = self.ops.reshape(ys2, batch_shape + tuple(self.cod.shape)) - else: - ys = self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + ys = self._vapply_unchecked_leading(xs, batch_shape) if self._enable_checks: self._output_batch_space(self.codomain, in_space)._check_member(ys) return ys @@ -125,14 +185,7 @@ def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: if self._enable_checks: in_space._check_member(ys) batch_shape = tuple(in_space.batch_shape) - ys2 = self.ops.reshape(ys, (-1, self._cod_size)) - xs2 = self.ops.matmul(ys2, self._A2H.T) - x_flat_shape = batch_shape + (self._dom_size,) - xs_flat = self.ops.reshape(xs2, x_flat_shape) - if self._dom_vector_fast_path: - xs = self.ops.reshape(xs2, batch_shape + tuple(self.dom.shape)) - else: - xs = self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + xs = self._rvapply_unchecked_leading(ys, batch_shape) if self._enable_checks: self._output_batch_space(self.domain, in_space)._check_member(xs) return xs diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index fca611f..ab4f306 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -51,6 +51,8 @@ def __init__(self, if not self._enable_checks: self.apply = self._apply_unchecked self.rapply = self._rapply_unchecked + self.vapply = self._vapply_unchecked + self.rvapply = self._rvapply_unchecked @property def A(self) -> SparseArray: @@ -98,6 +100,71 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + @staticmethod + def _batch_shape_from_input(value: DenseArray, base_ndim: int) -> tuple[int, ...]: + shape = tuple(value.shape) + return shape if base_ndim == 0 else shape[:-base_ndim] + + @staticmethod + def _is_leading_batch(batch_space: Any) -> bool: + if batch_space is None: + return True + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + return batch_axes == tuple(range(len(batch_shape))) + + @staticmethod + def _batch_shape_from_space(batch_space: Any) -> tuple[int, ...]: + return tuple(getattr(batch_space, "batch_shape")) + + def _vapply_unchecked_leading( + self, + xs: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + xs2 = xs.reshape((-1, self._dom_size)) + ys2 = (self.A @ xs2.T).T + if self._cod_vector_fast_path: + if self._cod_is_flat and tuple(ys2.shape[:-1]) == batch_shape: + return ys2 + return ys2.reshape(batch_shape + tuple(self.cod.shape)) + ys_flat = ys2.reshape(batch_shape + (self._cod_size,)) + return self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + + def _rvapply_unchecked_leading( + self, + ys: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + ys2 = ys.reshape((-1, self._cod_size)) + xs2 = (self._AH @ ys2.T).T + if self._dom_vector_fast_path: + if self._dom_is_flat and tuple(xs2.shape[:-1]) == batch_shape: + return xs2 + return xs2.reshape(batch_shape + tuple(self.dom.shape)) + xs_flat = xs2.reshape(batch_shape + (self._dom_size,)) + return self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + + def _vapply_unchecked(self, xs: DenseArray, batch_space=None) -> DenseArray: + if not self._is_leading_batch(batch_space): + return self._fallback_vapply(xs, batch_space) + batch_shape = ( + self._batch_shape_from_input(xs, len(self.domain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._vapply_unchecked_leading(xs, batch_shape) + + def _rvapply_unchecked(self, ys: DenseArray, batch_space=None) -> DenseArray: + if not self._is_leading_batch(batch_space): + return self._fallback_rvapply(ys, batch_space) + batch_shape = ( + self._batch_shape_from_input(ys, len(self.codomain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._rvapply_unchecked_leading(ys, batch_shape) + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: in_space = self._input_batch_space(self.domain, xs, batch_space) if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): @@ -105,14 +172,7 @@ def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: if self._enable_checks: in_space._check_member(xs) batch_shape = tuple(in_space.batch_shape) - xs2 = self.ops.reshape(xs, (-1, self._dom_size)) - ys_t = self.ops.sparse_matmul(self.A, self.ops.transpose(xs2)) - ys2 = self.ops.transpose(ys_t) - ys_flat = self.ops.reshape(ys2, batch_shape + (self._cod_size,)) - if self._cod_vector_fast_path: - ys = self.ops.reshape(ys2, batch_shape + tuple(self.cod.shape)) - else: - ys = self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + ys = self._vapply_unchecked_leading(xs, batch_shape) if self._enable_checks: self._output_batch_space(self.codomain, in_space)._check_member(ys) return ys @@ -124,14 +184,7 @@ def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: if self._enable_checks: in_space._check_member(ys) batch_shape = tuple(in_space.batch_shape) - ys2 = self.ops.reshape(ys, (-1, self._cod_size)) - xs_t = self.ops.sparse_matmul(self._AH, self.ops.transpose(ys2)) - xs2 = self.ops.transpose(xs_t) - xs_flat = self.ops.reshape(xs2, batch_shape + (self._dom_size,)) - if self._dom_vector_fast_path: - xs = self.ops.reshape(xs2, batch_shape + tuple(self.dom.shape)) - else: - xs = self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + xs = self._rvapply_unchecked_leading(ys, batch_shape) if self._enable_checks: self._output_batch_space(self.domain, in_space)._check_member(xs) return xs diff --git a/tests/linops/test_batched_lifting.py b/tests/linops/test_batched_lifting.py index 2b3b107..efb8f62 100644 --- a/tests/linops/test_batched_lifting.py +++ b/tests/linops/test_batched_lifting.py @@ -9,6 +9,11 @@ def _ctx(): return sc.Context(sc.NumpyOps(), dtype=np.float64) +def _unchecked_ctx(): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False) + + def _stack_apply(ctx, op, xs): return ctx.ops.stack(tuple(op.apply(x) for x in xs), axis=0) @@ -45,6 +50,26 @@ def test_sparse_linop_vapply_and_rvapply_match_stacked_apply(): assert np.allclose(op.rvapply(ys), _stack_rapply(ctx, op, ys)) +def test_dense_and_sparse_batched_lifting_fast_paths_without_checks(): + sc = importlib.import_module("spacecore") + ctx = _unchecked_ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + matrix = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + sparse = ctx.assparse(sps.csr_matrix([[1.0, 0.0], [0.0, 4.0], [5.0, 6.0]])) + dense_op = sc.DenseLinOp(matrix, dom, cod, ctx) + sparse_op = sc.SparseLinOp(sparse, dom, cod, ctx) + batch_dom = dom.batch((3,), (0,)) + batch_cod = cod.batch((2,), (0,)) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0], [0.5, 2.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(dense_op.vapply(xs, batch_dom), np.asarray(xs) @ np.asarray(matrix).T) + assert np.allclose(dense_op.rvapply(ys, batch_cod), np.asarray(ys) @ np.asarray(matrix)) + assert np.allclose(sparse_op.vapply(xs, batch_dom), (sparse @ np.asarray(xs).T).T) + assert np.allclose(sparse_op.rvapply(ys, batch_cod), (sparse.T @ np.asarray(ys).T).T) + + def test_diagonal_identity_zero_sum_composed_and_adjoint_batched_lifting(): sc = importlib.import_module("spacecore") ctx = _ctx() From cff5d83c73e986e15ef36969fe6b3a09d037c7e4 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Thu, 21 May 2026 23:26:55 -0300 Subject: [PATCH 16/44] Add functional abstractions --- README.md | 14 ++- docs/source/api/functionals.rst | 60 ++++++++++ docs/source/api/index.rst | 1 + docs/source/index.rst | 16 ++- spacecore/__init__.py | 15 +++ spacecore/functional/__init__.py | 12 ++ spacecore/functional/_base.py | 130 +++++++++++++++++++++ spacecore/functional/_linear.py | 167 +++++++++++++++++++++++++++ spacecore/functional/_quadratic.py | 117 +++++++++++++++++++ tests/functional/test_functional.py | 117 +++++++++++++++++++ tests/integration/test_public_api.py | 5 + 11 files changed, 652 insertions(+), 2 deletions(-) create mode 100644 docs/source/api/functionals.rst create mode 100644 spacecore/functional/__init__.py create mode 100644 spacecore/functional/_base.py create mode 100644 spacecore/functional/_linear.py create mode 100644 spacecore/functional/_quadratic.py create mode 100644 tests/functional/test_functional.py diff --git a/README.md b/README.md index 7566010..ae4e09a 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ mathematical objects: - a `Space` knows the structure and geometry of its elements; - a `LinOp` maps one space to another; +- a `Functional` maps a space element to a scalar; - backend-specific array creation and operations live behind `BackendOps`. The result is ordinary Python code whose core numerical logic is not tied to @@ -28,7 +29,7 @@ one array library. Mental model: ```text -BackendOps -> Context -> Space/LinOp -> Algorithm +BackendOps -> Context -> Space/LinOp/Functional -> Algorithm ``` ## Write once, run twice @@ -184,6 +185,17 @@ xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB The fallback uses backend `vmap`; dense, sparse, diagonal, identity, zero, algebraic, and product-structured operators provide specialized batched paths. +### `Functional` + +A `Functional` represents a scalar-valued map on a space. `LinearFunctional` +covers maps such as ``, `MatrixFreeLinearFunctional` wraps a callable +without storing a representer, and `LinOpQuadraticForm` represents objectives +such as `0.5 * + ell(x) + a`. + +For batched inputs, `vvalue(xs)` evaluates independently over leading batch +axes. Quadratic forms that define gradients also expose `grad(x)` and +`vgrad(xs)`. + ## Who should use this? SpaceCore is aimed at people writing optimization, inverse-problem, optimal diff --git a/docs/source/api/functionals.rst b/docs/source/api/functionals.rst new file mode 100644 index 0000000..147e6bf --- /dev/null +++ b/docs/source/api/functionals.rst @@ -0,0 +1,60 @@ +Functionals API +=============== + +Functionals represent scalar-valued maps on spaces, including linear +functionals and quadratic forms. + +.. autosummary:: + :nosignatures: + + spacecore.functional.Functional + spacecore.functional.LinearFunctional + spacecore.functional.InnerProductFunctional + spacecore.functional.MatrixFreeLinearFunctional + spacecore.functional.QuadraticForm + spacecore.functional.LinOpQuadraticForm + +Functional +---------- + +.. autoclass:: spacecore.functional.Functional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +Linear functionals +------------------ + +.. autoclass:: spacecore.functional.LinearFunctional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.functional.InnerProductFunctional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.functional.MatrixFreeLinearFunctional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +Quadratic forms +--------------- + +.. autoclass:: spacecore.functional.QuadraticForm + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.functional.LinOpQuadraticForm + :members: + :undoc-members: + :inherited-members: + :show-inheritance: diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 259ac9d..654fbb8 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -11,3 +11,4 @@ directives for public objects instead of dumping entire modules. context spaces linops + functionals diff --git a/docs/source/index.rst b/docs/source/index.rst index 01185b4..470a78d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,7 @@ mathematical objects: * a ``Space`` knows the structure and geometry of its elements; * a ``LinOp`` maps one space to another; +* a ``Functional`` maps a space element to a scalar; * backend-specific array creation and operations live behind ``BackendOps``. The result is ordinary Python code whose core numerical logic is not tied to @@ -31,7 +32,7 @@ Mental model: .. code-block:: text - BackendOps -> Context -> Space/LinOp -> Algorithm + BackendOps -> Context -> Space/LinOp/Functional -> Algorithm Write once, run twice --------------------- @@ -192,6 +193,19 @@ the leading batch axis: The fallback uses backend ``vmap``; dense, sparse, diagonal, identity, zero, algebraic, and product-structured operators provide specialized batched paths. +``Functional`` +~~~~~~~~~~~~~~ + +A ``Functional`` represents a scalar-valued map on a space. +``LinearFunctional`` covers maps such as ````, +``MatrixFreeLinearFunctional`` wraps a callable without storing a representer, +and ``LinOpQuadraticForm`` represents objectives such as +``0.5 * + ell(x) + a``. + +For batched inputs, ``vvalue(xs)`` evaluates independently over leading batch +axes. Quadratic forms that define gradients also expose ``grad(x)`` and +``vgrad(xs)``. + Who should use this? -------------------- diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 4925f7a..fa702ed 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -28,6 +28,14 @@ make_scaled, make_sum, ) +from .functional import ( + Functional, + InnerProductFunctional, + LinearFunctional, + LinOpQuadraticForm, + MatrixFreeLinearFunctional, + QuadraticForm, +) from .linalg import ( CGResult, LSQRResult, @@ -89,6 +97,13 @@ "SumToSingleLinOp", "StackedLinOp", + "Functional", + "LinearFunctional", + "InnerProductFunctional", + "MatrixFreeLinearFunctional", + "QuadraticForm", + "LinOpQuadraticForm", + "CGResult", "LSQRResult", "PowerIterationResult", diff --git a/spacecore/functional/__init__.py b/spacecore/functional/__init__.py new file mode 100644 index 0000000..19bae03 --- /dev/null +++ b/spacecore/functional/__init__.py @@ -0,0 +1,12 @@ +from ._base import Functional +from ._linear import InnerProductFunctional, LinearFunctional, MatrixFreeLinearFunctional +from ._quadratic import LinOpQuadraticForm, QuadraticForm + +__all__ = [ + "Functional", + "InnerProductFunctional", + "LinearFunctional", + "LinOpQuadraticForm", + "MatrixFreeLinearFunctional", + "QuadraticForm", +] diff --git a/spacecore/functional/_base.py b/spacecore/functional/_base.py new file mode 100644 index 0000000..185a827 --- /dev/null +++ b/spacecore/functional/_base.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Generic, TypeVar + +from .._contextual import ContextBound +from .._contextual.manager import ctx_manager +from ..backend import Context +from ..space import Space + + +Domain = TypeVar("Domain", bound=Space) + + +class Functional(ContextBound, Generic[Domain]): + """ + Scalar-valued map on a space. + + ``Functional`` represents a map ``F : X -> K`` without assuming any storage + model. It mirrors the minimal ``LinOp`` contract: the domain is converted + into the resolved context, value checks follow ``ctx.enable_checks``, and + batched evaluation is implemented by a backend ``vmap`` fallback. + """ + + def __init__(self, dom: Domain, ctx: Context | str | None = None) -> None: + ctx = ctx_manager.resolve_context_priority(ctx, dom) + super().__init__(ctx) + self.dom = dom.convert(self.ctx) + self._enable_checks = self.ctx.enable_checks + + @property + def domain(self) -> Domain: + """Domain space of this scalar-valued map.""" + return self.dom + + @abstractmethod + def value(self, x: Any) -> Any: + """ + Evaluate this functional at ``x``. + + Contract: + - x is an element of ``self.domain``; + - the return value is scalar-like in the functional context. + """ + + def __call__(self, x: Any) -> Any: + """Evaluate this functional at ``x``.""" + return self.value(x) + + def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + """Evaluate this functional independently over leading batch axes.""" + return self._fallback_vvalue(xs, batch_space) + + def _infer_batch_shape(self, space: Space, value: Any) -> tuple[int, ...]: + if hasattr(space, "spaces") and isinstance(value, tuple) and value: + return self._infer_batch_shape(space.spaces[0], value[0]) + shape = tuple(getattr(value, "shape", ())) + base_shape = tuple(space.shape) + if not base_shape: + return shape + if len(shape) < len(base_shape) or shape[-len(base_shape):] != base_shape: + raise ValueError( + f"Cannot infer leading batch shape for value shape {shape} " + f"and base space shape {base_shape}." + ) + return shape[: len(shape) - len(base_shape)] + + def _input_batch_space( + self, + space: Space, + value: Any, + batch_space: Space | None, + ) -> Space: + if batch_space is not None: + return batch_space + batch_shape = self._infer_batch_shape(space, value) + return space.batch(batch_shape, tuple(range(len(batch_shape)))) + + def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space: + batch_shape = getattr(input_batch_space, "batch_shape", None) + batch_axes = getattr(input_batch_space, "batch_axes", None) + if batch_shape is None or batch_axes is None: + raise TypeError("batch_space must be a BatchSpace-compatible object.") + return space.batch(tuple(batch_shape), tuple(batch_axes)) + + def _require_leading_batch_axes(self, batch_space: Space) -> tuple[int, ...]: + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + expected_axes = tuple(range(len(batch_shape))) + if batch_axes != expected_axes: + raise ValueError( + "Functional batching currently expects leading batch axes; " + f"got batch_axes={batch_axes}, expected {expected_axes}." + ) + return batch_shape + + def _vmap_leading(self, fn: Any, batch_ndim: int) -> Any: + mapped = fn + for _ in range(batch_ndim): + mapped = self.ops.vmap(mapped, in_axes=0, out_axes=0) + return mapped + + def _check_scalar_batch(self, values: Any, batch_shape: tuple[int, ...]) -> None: + shape = tuple(getattr(values, "shape", ())) + if shape != batch_shape: + raise ValueError( + f"Expected scalar batch output with shape {batch_shape}, got {shape}." + ) + + def _fallback_vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + in_space = self._input_batch_space(self.domain, xs, batch_space) + batch_shape = self._require_leading_batch_axes(in_space) + if self._enable_checks: + in_space._check_member(xs) + values = self._vmap_leading(self.value, len(batch_shape))(xs) + if self._enable_checks: + self._check_scalar_batch(values, batch_shape) + return values + + def assert_domain(self, x: Any) -> None: + self.dom.check_member(x) + + @abstractmethod + def tree_flatten(self): + ... + + @classmethod + @abstractmethod + def tree_unflatten(cls, aux, children): + ... diff --git a/spacecore/functional/_linear.py b/spacecore/functional/_linear.py new file mode 100644 index 0000000..ca349a3 --- /dev/null +++ b/spacecore/functional/_linear.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any + +from ._base import Domain, Functional +from ..backend import Context, jax_pytree_class +from ..space import Space + + +def _convert_space_element(space: Space, value: Any) -> Any: + if hasattr(space, "spaces") and isinstance(value, tuple): + if len(value) != len(space.spaces): + raise ValueError( + f"Expected tuple of length {len(space.spaces)}, got {len(value)}." + ) + return tuple( + _convert_space_element(component_space, component) + for component_space, component in zip(space.spaces, value) + ) + return space.ctx.asarray(value) + + +class LinearFunctional(Functional[Domain]): + """Linear scalar-valued map ``ell : X -> K``.""" + + @property + @abstractmethod + def representer(self) -> Any: + """ + Riesz representer of this functional when one is explicitly available. + + Matrix-free functionals may not have a stored representer and should + raise ``NotImplementedError``. + """ + + +@jax_pytree_class +class InnerProductFunctional(LinearFunctional[Domain]): + """ + Linear functional represented by a domain element. + + ``InnerProductFunctional(c, X)`` evaluates ``ell_c(x) = _X``. + """ + + def __init__( + self, + c: Any, + dom: Domain, + ctx: Context | str | None = None, + ) -> None: + super().__init__(dom, ctx) + self._c = _convert_space_element(self.domain, c) + if self._enable_checks: + self.domain._check_member(self._c) + + @property + def representer(self) -> Any: + """Stored domain element ``c`` defining ``ell_c(x) = ``.""" + return self._c + + def value(self, x: Any) -> Any: + """Return ``domain.inner(representer, x)``.""" + if self._enable_checks: + self.domain._check_member(x) + return self.domain.inner(self._c, x) + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.domain == other.domain and self.ops.allclose( + self.domain.flatten(self._c), + other.domain.flatten(other._c), + ) + return False + + def tree_flatten(self): + children = (self._c,) + aux = (self.domain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + domain, ctx = aux + c = children[0] + return cls(c, domain, ctx) + + def _convert(self, new_ctx: Context) -> InnerProductFunctional: + return InnerProductFunctional(self._c, self.domain.convert(new_ctx), new_ctx) + + +@jax_pytree_class +class MatrixFreeLinearFunctional(LinearFunctional[Domain]): + """ + Linear functional defined by user-supplied evaluation callables. + + No representer is stored or materialized. + """ + + def __init__( + self, + value: Any, + dom: Domain, + ctx: Context | str | None = None, + vvalue: Any | None = None, + ) -> None: + if not callable(value): + raise TypeError(f"value must be callable, got {type(value).__name__}.") + if vvalue is not None and not callable(vvalue): + raise TypeError(f"vvalue must be callable, got {type(vvalue).__name__}.") + super().__init__(dom, ctx) + self.value_fn = value + self.vvalue_fn = vvalue + + @property + def representer(self) -> Any: + raise NotImplementedError( + f"{type(self).__name__} does not store a Riesz representer." + ) + + def value(self, x: Any) -> Any: + """Return ``value_fn(x)``.""" + if self._enable_checks: + self.domain._check_member(x) + y = self.value_fn(x) + if self._enable_checks: + self._check_scalar_batch(y, ()) + return y + + def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + """Return ``vvalue_fn(xs)`` when supplied, otherwise use fallback batching.""" + if self.vvalue_fn is None: + return super().vvalue(xs, batch_space) + in_space = self._input_batch_space(self.domain, xs, batch_space) + batch_shape = self._require_leading_batch_axes(in_space) + if self._enable_checks: + in_space._check_member(xs) + values = self.vvalue_fn(xs) + if self._enable_checks: + self._check_scalar_batch(values, batch_shape) + return values + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return ( + self.domain == other.domain + and self.value_fn is other.value_fn + and self.vvalue_fn is other.vvalue_fn + ) + return False + + def tree_flatten(self): + children = () + aux = (self.value_fn, self.domain, self.ctx, self.vvalue_fn) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + value_fn, domain, ctx, vvalue_fn = aux + return cls(value_fn, domain, ctx, vvalue_fn) + + def _convert(self, new_ctx: Context) -> MatrixFreeLinearFunctional: + return MatrixFreeLinearFunctional( + self.value_fn, + self.domain.convert(new_ctx), + new_ctx, + self.vvalue_fn, + ) diff --git a/spacecore/functional/_quadratic.py b/spacecore/functional/_quadratic.py new file mode 100644 index 0000000..484b3e9 --- /dev/null +++ b/spacecore/functional/_quadratic.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from typing import Any + +from ._base import Domain, Functional +from ._linear import LinearFunctional +from .._contextual.manager import ctx_manager +from ..backend import Context, jax_pytree_class +from ..linop import LinOp +from ..space import Space + + +class QuadraticForm(Functional[Domain]): + """Scalar quadratic objective on a space.""" + + def grad(self, x: Any) -> Any: + """Gradient at ``x`` when available.""" + raise NotImplementedError(f"{type(self).__name__} does not define grad.") + + def vgrad(self, xs: Any, batch_space: Space | None = None) -> Any: + """Evaluate ``grad`` independently over leading batch axes.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + batch_shape = self._require_leading_batch_axes(in_space) + if self._enable_checks: + in_space._check_member(xs) + grads = self._vmap_leading(self.grad, len(batch_shape))(xs) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(grads) + return grads + + +@jax_pytree_class +class LinOpQuadraticForm(QuadraticForm[Domain]): + """ + Quadratic form backed by a linear operator. + + ``q(x) = 1/2 * + linear(x) + a`` with ``Q : X -> X``. + """ + + def __init__( + self, + Q: LinOp[Domain, Domain], + linear: LinearFunctional[Domain] | None = None, + a: Any = 0, + ctx: Context | str | None = None, + ) -> None: + if not isinstance(Q, LinOp): + raise TypeError(f"Q must be a LinOp, got {type(Q).__name__}.") + if linear is not None and not isinstance(linear, LinearFunctional): + raise TypeError( + f"linear must be a LinearFunctional or None, got {type(linear).__name__}." + ) + + resolved_ctx = ctx_manager.resolve_context_priority(ctx, Q.domain, Q, linear) + Q = Q.convert(resolved_ctx) + if Q.domain != Q.codomain: + raise ValueError("LinOpQuadraticForm requires Q.domain == Q.codomain.") + if linear is not None: + linear = linear.convert(resolved_ctx) + if linear.domain != Q.domain: + raise ValueError("linear.domain must match Q.domain.") + + super().__init__(Q.domain, resolved_ctx) + self.Q = Q + self.linear = linear + self.a = self.ctx.asarray(a) + if self._enable_checks: + self._check_scalar_batch(self.a, ()) + + def value(self, x: Any) -> Any: + """Return ``1/2 * + linear(x) + a``.""" + if self._enable_checks: + self.domain._check_member(x) + qx = self.Q.apply(x) + value = 0.5 * self.domain.inner(x, qx) + if self.linear is not None: + value = value + self.linear.value(x) + return value + self.a + + def grad(self, x: Any) -> Any: + """ + Return the Euclidean/Riesz gradient. + + The quadratic part uses the symmetric adjoint part ``(Q + Q*) / 2``. + For self-adjoint ``Q`` this is exactly ``Qx``. + """ + if self._enable_checks: + self.domain._check_member(x) + qx = self.Q.apply(x) + qhx = self.Q.rapply(x) + grad = self.domain.scale(0.5, self.domain.add(qx, qhx)) + if self.linear is not None: + grad = self.domain.add(grad, self.linear.representer) + return grad + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return ( + self.Q == other.Q + and self.linear == other.linear + and self.ops.allclose(self.a, other.a) + ) + return False + + def tree_flatten(self): + children = (self.Q, self.linear, self.a) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + Q, linear, a = children + return cls(Q, linear, a, Q.ctx) + + def _convert(self, new_ctx: Context) -> LinOpQuadraticForm: + linear = None if self.linear is None else self.linear.convert(new_ctx) + return LinOpQuadraticForm(self.Q.convert(new_ctx), linear, self.a, new_ctx) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py new file mode 100644 index 0000000..b923a83 --- /dev/null +++ b/tests/functional/test_functional.py @@ -0,0 +1,117 @@ +import importlib + +import numpy as np +import pytest + + +def _ctx(dtype=np.float64, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def _quadratic_problem(ctx): + sc = importlib.import_module("spacecore") + dom = sc.VectorSpace((2,), ctx) + Q = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 4.0]]), dom, dom, ctx) + linear = sc.InnerProductFunctional(ctx.asarray([1.0, -1.0]), dom, ctx) + return sc.LinOpQuadraticForm(Q, linear, 3.0, ctx) + + +def test_explicit_context_overrides_inferred_contexts(): + sc = importlib.import_module("spacecore") + inferred = _ctx(np.float32, enable_checks=True) + explicit = _ctx(np.float64, enable_checks=False) + dom = sc.VectorSpace((2,), inferred) + Q = sc.DenseLinOp(inferred.asarray([[1.0, 0.0], [0.0, 1.0]]), dom, dom, inferred) + linear = sc.InnerProductFunctional(inferred.asarray([1.0, 2.0]), dom) + + functional = sc.InnerProductFunctional(inferred.asarray([1.0, 2.0]), dom, explicit) + quadratic = sc.LinOpQuadraticForm(Q, linear, 0.0, explicit) + + assert functional.ctx == explicit + assert functional.dtype == np.dtype(np.float64) + assert functional.domain.ctx == explicit + assert quadratic.ctx == explicit + assert quadratic.Q.ctx == explicit + assert quadratic.linear.ctx == explicit + + +def test_domain_conversion_and_membership_checks_work(): + sc = importlib.import_module("spacecore") + source = _ctx(np.float32, enable_checks=True) + explicit = _ctx(np.float64, enable_checks=True) + dom = sc.VectorSpace((2,), source) + functional = sc.InnerProductFunctional(source.asarray([1.0, 2.0]), dom, explicit) + + assert functional.ctx == explicit + assert functional.dtype == np.dtype(np.float64) + assert functional.domain.ctx == explicit + assert functional.domain.ctx.enable_checks is True + assert np.allclose(functional.value(functional.domain.ctx.asarray([3.0, 4.0])), 11.0) + with pytest.raises(Exception): + functional.value(explicit.asarray([1.0, 2.0, 3.0])) + + +def test_call_matches_value(): + ctx = _ctx() + q = _quadratic_problem(ctx) + x = ctx.asarray([2.0, -1.0]) + assert np.allclose(q(x), q.value(x)) + + +def test_inner_product_functional_matches_domain_inner(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + c = ctx.asarray([1.0, -2.0]) + x = ctx.asarray([3.0, 4.0]) + functional = sc.InnerProductFunctional(c, dom, ctx) + + assert np.allclose(functional.value(x), dom.inner(c, x)) + + +def test_matrix_free_linear_functional_has_no_representer(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + c = ctx.asarray([2.0, 3.0]) + x = ctx.asarray([4.0, 5.0]) + functional = sc.MatrixFreeLinearFunctional(lambda y: dom.inner(c, y), dom, ctx) + + assert np.allclose(functional.value(x), 23.0) + with pytest.raises(NotImplementedError): + functional.representer + + +def test_linop_quadratic_value_and_gradient_match_euclidean_hand_computation(): + ctx = _ctx() + q = _quadratic_problem(ctx) + x = ctx.asarray([2.0, -1.0]) + + assert np.allclose(q.value(x), 12.0) + assert np.allclose(q.grad(x), [5.0, -5.0]) + + +def test_vvalue_and_vgrad_match_elementwise_loops(): + ctx = _ctx() + q = _quadratic_problem(ctx) + xs = ctx.asarray([[2.0, -1.0], [0.0, 3.0], [1.5, 2.0]]) + + expected_values = ctx.ops.stack(tuple(q.value(x) for x in xs), axis=0) + expected_grads = ctx.ops.stack(tuple(q.grad(x) for x in xs), axis=0) + + assert np.allclose(q.vvalue(xs), expected_values) + assert np.allclose(q.vgrad(xs), expected_grads) + + +def test_bad_shapes_raise_when_checks_are_enabled(): + ctx = _ctx(enable_checks=True) + q = _quadratic_problem(ctx) + bad = ctx.asarray([1.0, 2.0, 3.0]) + + with pytest.raises(Exception): + q.value(bad) + with pytest.raises(Exception): + q.grad(bad) + with pytest.raises(Exception): + q.vvalue(ctx.asarray([[1.0, 2.0, 3.0]])) diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index 99b28eb..b9ff355 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -23,6 +23,8 @@ def test_expected_names_are_exported(): "IdentityLinOp", "MatrixFreeLinOp", "make_sum", "make_scaled", "make_composed", "BlockDiagonalLinOp", "StackedLinOp", "SumToSingleLinOp", + "Functional", "LinearFunctional", "InnerProductFunctional", + "MatrixFreeLinearFunctional", "QuadraticForm", "LinOpQuadraticForm", "VectorSpace", "HermitianSpace", "ProductSpace", "Space", "DenseArray", "SparseArray", "ArrayLike", "set_context", "get_context", "resolve_context_priority", "register_ops", @@ -43,6 +45,7 @@ def test_top_level_objects_match_source_modules(): backend = importlib.import_module("spacecore.backend") space = importlib.import_module("spacecore.space") linop = importlib.import_module("spacecore.linop") + functional = importlib.import_module("spacecore.functional") manager = importlib.import_module("spacecore._contextual.manager") assert sc.Context is backend.Context @@ -54,6 +57,8 @@ def test_top_level_objects_match_source_modules(): assert sc.Space is space.Space assert sc.VectorSpace is space.VectorSpace assert sc.DenseLinOp is linop.DenseLinOp + assert sc.Functional is functional.Functional + assert sc.InnerProductFunctional is functional.InnerProductFunctional assert sc.get_context is manager.get_context assert sc.resolve_context_priority is manager.resolve_context_priority From a37b3c35e86b763d7e99145ec1b1f7c0686b3d00 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Thu, 21 May 2026 23:45:07 -0300 Subject: [PATCH 17/44] Generalize power iteration dispatch --- spacecore/__init__.py | 2 + spacecore/functional/_quadratic.py | 12 ++++ spacecore/linalg/__init__.py | 3 +- spacecore/linalg/_lanczos.py | 29 +++++++-- spacecore/linalg/_power.py | 89 +++++++++++++++++++++------- tests/functional/test_functional.py | 1 + tests/integration/test_public_api.py | 3 + tests/linalg/test_krylov.py | 86 +++++++++++++++++++++++++++ 8 files changed, 197 insertions(+), 28 deletions(-) diff --git a/spacecore/__init__.py b/spacecore/__init__.py index fa702ed..91e63c3 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -40,6 +40,7 @@ CGResult, LSQRResult, PowerIterationResult, + StochasticLanczosResult, cg, lsqr, power_iteration, @@ -107,6 +108,7 @@ "CGResult", "LSQRResult", "PowerIterationResult", + "StochasticLanczosResult", "cg", "lsqr", "power_iteration", diff --git a/spacecore/functional/_quadratic.py b/spacecore/functional/_quadratic.py index 484b3e9..ebbe79f 100644 --- a/spacecore/functional/_quadratic.py +++ b/spacecore/functional/_quadratic.py @@ -13,6 +13,10 @@ class QuadraticForm(Functional[Domain]): """Scalar quadratic objective on a space.""" + def hess_apply(self, x: Any) -> Any: + """Apply the Hessian action at ``x`` when available.""" + raise NotImplementedError(f"{type(self).__name__} does not define hess_apply.") + def grad(self, x: Any) -> Any: """Gradient at ``x`` when available.""" raise NotImplementedError(f"{type(self).__name__} does not define grad.") @@ -93,6 +97,14 @@ def grad(self, x: Any) -> Any: grad = self.domain.add(grad, self.linear.representer) return grad + def hess_apply(self, x: Any) -> Any: + """Return the self-adjoint Hessian action ``(Q + Q*) x / 2``.""" + if self._enable_checks: + self.domain._check_member(x) + qx = self.Q.apply(x) + qhx = self.Q.rapply(x) + return self.domain.scale(0.5, self.domain.add(qx, qhx)) + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return ( diff --git a/spacecore/linalg/__init__.py b/spacecore/linalg/__init__.py index 147ffb2..06be398 100644 --- a/spacecore/linalg/__init__.py +++ b/spacecore/linalg/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from ._cg import CGResult, cg -from ._lanczos import stochastic_lanczos +from ._lanczos import StochasticLanczosResult, stochastic_lanczos from ._lsqr import LSQRResult, lsqr from ._power import PowerIterationResult, power_iteration @@ -9,6 +9,7 @@ "CGResult", "LSQRResult", "PowerIterationResult", + "StochasticLanczosResult", "cg", "lsqr", "power_iteration", diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index acf02aa..2f1ec3e 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -1,11 +1,29 @@ from __future__ import annotations -from typing import Any +from typing import Any, NamedTuple from ..linop import LinOp from ..types import DenseArray from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval from ._utils import require_linop, require_square, should_check_iteration +from ._utils import result_repr + + +class StochasticLanczosResult(NamedTuple): + """Result returned by :func:`stochastic_lanczos`.""" + + eigenvalue: Any + eigenvector: Any + + def __repr__(self) -> str: + """Return a compact summary without printing the full eigenvector.""" + return result_repr( + "StochasticLanczosResult", + { + "eigenvalue": self.eigenvalue, + "eigenvector": self.eigenvector, + }, + ) def _check_lanczos_max_iter(max_iter: int) -> int: @@ -22,7 +40,7 @@ def stochastic_lanczos( max_iter: int = 100, tol: float = 1e-6, check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, -) -> tuple[DenseArray, Any]: +) -> StochasticLanczosResult: r"""Approximate the smallest eigenpair of a Hermitian operator. The operator is supplied as a square ``LinOp`` and ``initial_vector`` is an @@ -47,8 +65,9 @@ def stochastic_lanczos( this many iterations, and always on the final iteration. Returns: - A pair ``(eigenvalue, eigenvector)`` for the smallest approximated - eigenpair. + ``StochasticLanczosResult`` containing the smallest approximated + eigenpair. The result supports tuple unpacking as + ``eigenvalue, eigenvector``. """ A = require_linop(A) require_square(A, "stochastic_lanczos") @@ -185,4 +204,4 @@ def fill_off(ii: int, T_in: DenseArray) -> DenseArray: den = ops.real(A.domain.inner(x, x)) lam = num / den - return lam, x + return StochasticLanczosResult(lam, x) diff --git a/spacecore/linalg/_power.py b/spacecore/linalg/_power.py index 5358b85..2d40f29 100644 --- a/spacecore/linalg/_power.py +++ b/spacecore/linalg/_power.py @@ -1,8 +1,12 @@ from __future__ import annotations +from collections.abc import Callable from typing import Any, NamedTuple +from ..backend import Context +from ..functional import QuadraticForm from ..linop import LinOp +from ..space import Space from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter from ._utils import default_initial_vector, is_converged, normalize, require_linop from ._utils import require_square, result_repr, should_check_iteration @@ -31,8 +35,32 @@ def __repr__(self) -> str: ) +class _SelfAdjointAction(NamedTuple): + apply: Callable[[Any], Any] + domain: Space + ctx: Context + + @property + def ops(self) -> Any: + return self.ctx.ops + + @property + def dtype(self) -> Any: + return self.ctx.dtype + + +def _action_from_linop(A: LinOp) -> _SelfAdjointAction: + A = require_linop(A) + require_square(A, "power_iteration") + return _SelfAdjointAction(A.apply, A.domain, A.ctx) + + +def _action_from_quadratic_form(q: QuadraticForm) -> _SelfAdjointAction: + return _SelfAdjointAction(q.hess_apply, q.domain, q.ctx) + + def power_iteration( - A: LinOp, + A: LinOp | QuadraticForm, *, x0: Any | None = None, tol: float = 1e-6, @@ -40,25 +68,42 @@ def power_iteration( check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, ) -> PowerIterationResult: """ - Estimate the dominant eigenpair of a square ``LinOp`` by power iteration. + Estimate the dominant eigenpair of a square ``LinOp`` or Hessian action. - The method uses only ``A.apply`` and domain-space operations. It returns the - Rayleigh quotient for the current normalized iterate, the eigenvector + ``A`` may be a square ``LinOp`` or a ``QuadraticForm`` that exposes + ``hess_apply``. Public dispatch converts either input into a fixed + self-adjoint action before entering the numerical loop. The method returns + the Rayleigh quotient for the current normalized iterate, the eigenvector estimate, and the residual norm ``||A x - lambda x||``. The residual-based stopping criterion is refreshed only every ``check_every`` iterations, and always on the final iteration. For spectral-norm estimates of a rectangular operator, call this on ``A.H @ A``. """ - A = require_linop(A) - require_square(A, "power_iteration") - maxiter = check_maxiter(maxiter, A) + if isinstance(A, QuadraticForm): + action = _action_from_quadratic_form(A) + elif isinstance(A, LinOp): + action = _action_from_linop(A) + else: + raise TypeError(f"A must be a LinOp or QuadraticForm, got {type(A).__name__}.") + + maxiter = check_maxiter(maxiter, action) check_every = check_interval(check_every) - x = default_initial_vector(A) if x0 is None else x0 - A.domain.check_member(x) - x, _ = normalize(A.domain, x) - zero = A.ops.asarray(0.0, dtype=A.dtype) - residual_norm = A.domain.norm(x) + float("inf") + x = default_initial_vector(action) if x0 is None else x0 + action.domain.check_member(x) + return PowerIterationResult(*_power_iteration_core(action, x, tol, maxiter, check_every)) + + +def _power_iteration_core( + action: _SelfAdjointAction, + x: Any, + tol: float, + maxiter: int, + check_every: int, +) -> tuple[Any, Any, Any, Any, Any]: + x, _ = normalize(action.domain, x) + zero = action.ops.asarray(0.0, dtype=action.dtype) + residual_norm = action.domain.norm(x) + float("inf") def cond_fun(carry: tuple[Any, Any, Any, int]) -> Any: _eigenvalue, _x, res_norm, k = carry @@ -66,30 +111,30 @@ def cond_fun(carry: tuple[Any, Any, Any, int]) -> Any: def body_fun(carry: tuple[Any, Any, Any, int]) -> tuple[Any, Any, Any, int]: _eigenvalue, x, _residual_norm, k = carry - y = A.apply(x) - x_next, _norm_y = normalize(A.domain, y) - y_next = A.apply(x_next) - eigenvalue_next = A.domain.inner(x_next, y_next) + y = action.apply(x) + x_next, _norm_y = normalize(action.domain, y) + y_next = action.apply(x_next) + eigenvalue_next = action.domain.inner(x_next, y_next) k_next = k + 1 def refresh_residual(_: Any) -> Any: - residual = A.codomain.axpy(-eigenvalue_next, x_next, y_next) - return A.codomain.norm(residual) + residual = action.domain.axpy(-eigenvalue_next, x_next, y_next) + return action.domain.norm(residual) - residual_norm_next = A.ops.cond( + residual_norm_next = action.ops.cond( should_check_iteration(k_next, maxiter, check_every), refresh_residual, lambda _: _residual_norm, - A.ops.asarray(0.0, dtype=A.dtype), + action.ops.asarray(0.0, dtype=action.dtype), ) return eigenvalue_next, x_next, residual_norm_next, k_next - eigenvalue, eigenvector, residual_norm, num_iters = A.ops.while_loop( + eigenvalue, eigenvector, residual_norm, num_iters = action.ops.while_loop( cond_fun, body_fun, (zero, x, residual_norm, 0), ) - return PowerIterationResult( + return ( eigenvalue, eigenvector, is_converged(residual_norm, tol), diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index b923a83..3623b9b 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -90,6 +90,7 @@ def test_linop_quadratic_value_and_gradient_match_euclidean_hand_computation(): assert np.allclose(q.value(x), 12.0) assert np.allclose(q.grad(x), [5.0, -5.0]) + assert np.allclose(q.hess_apply(x), [4.0, -4.0]) def test_vvalue_and_vgrad_match_elementwise_loops(): diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index b9ff355..f352934 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -30,6 +30,7 @@ def test_expected_names_are_exported(): "set_context", "get_context", "resolve_context_priority", "register_ops", "set_resolution_policy", "set_dtype_resolution_policy", "get_resolution_policy", "get_dtype_resolution_policy", + "StochasticLanczosResult", } if has_jax(): expected |= {"JaxOps", "jax_pytree_class"} @@ -46,6 +47,7 @@ def test_top_level_objects_match_source_modules(): space = importlib.import_module("spacecore.space") linop = importlib.import_module("spacecore.linop") functional = importlib.import_module("spacecore.functional") + linalg = importlib.import_module("spacecore.linalg") manager = importlib.import_module("spacecore._contextual.manager") assert sc.Context is backend.Context @@ -59,6 +61,7 @@ def test_top_level_objects_match_source_modules(): assert sc.DenseLinOp is linop.DenseLinOp assert sc.Functional is functional.Functional assert sc.InnerProductFunctional is functional.InnerProductFunctional + assert sc.StochasticLanczosResult is linalg.StochasticLanczosResult assert sc.get_context is manager.get_context assert sc.resolve_context_priority is manager.resolve_context_priority diff --git a/tests/linalg/test_krylov.py b/tests/linalg/test_krylov.py index 25958f4..40d14d2 100644 --- a/tests/linalg/test_krylov.py +++ b/tests/linalg/test_krylov.py @@ -1,4 +1,5 @@ import importlib +import inspect import numpy as np import pytest @@ -131,6 +132,62 @@ def test_power_iteration_estimates_dominant_eigenpair(backend_name, dtype): assert bool(to_numpy(result.converged)) +def test_power_iteration_accepts_quadratic_form_hessian_action(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + x0 = ctx.asarray([1.0, 1.0]) + + op_result = sc.power_iteration(op, x0=x0, tol=1e-5, maxiter=60) + q_result = sc.power_iteration(q, x0=x0, tol=1e-5, maxiter=60) + + np.testing.assert_allclose(to_numpy(q_result.eigenvalue), to_numpy(op_result.eigenvalue)) + np.testing.assert_allclose( + np.abs(to_numpy(q_result.eigenvector)), + np.abs(to_numpy(op_result.eigenvector)), + rtol=1e-6, + atol=1e-6, + ) + + +def test_power_iteration_dispatches_quadratic_form_before_core(monkeypatch): + sc = importlib.import_module("spacecore") + power_mod = importlib.import_module("spacecore.linalg._power") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + x0 = ctx.asarray([1.0, 0.0]) + captured = {} + + def fake_core(action, x, tol, maxiter, check_every): + captured["action"] = action + captured["x"] = x + return ctx.asarray(0.0), x, ctx.asarray(True), 0, ctx.asarray(0.0) + + monkeypatch.setattr(power_mod, "_power_iteration_core", fake_core) + result = power_mod.power_iteration(q, x0=x0, maxiter=1) + + assert result.eigenvector is x0 + assert isinstance(captured["action"], power_mod._SelfAdjointAction) + assert captured["action"].domain == q.domain + x = ctx.asarray([1.0, 2.0]) + np.testing.assert_allclose(captured["action"].apply(x), q.hess_apply(x)) + + +def test_power_iteration_core_has_no_dispatch_logic(): + power_mod = importlib.import_module("spacecore.linalg._power") + source = inspect.getsource(power_mod._power_iteration_core) + + assert "isinstance" not in source + assert "hasattr" not in source + assert "getattr" not in source + assert "_SelfAdjointAction(" not in source + assert "PowerIterationResult(" not in source + + @pytest.mark.parametrize("backend_name,dtype", _backend_params()) def test_stochastic_lanczos_approximates_smallest_eigenpair(backend_name, dtype): sc = importlib.import_module("spacecore") @@ -155,6 +212,20 @@ def test_stochastic_lanczos_approximates_smallest_eigenpair(backend_name, dtype) ) +def test_stochastic_lanczos_returns_result_object(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + + result = sc.stochastic_lanczos(op, ctx.asarray([1.0, 1.0]), max_iter=2, tol=1e-8) + eigenvalue, eigenvector = result + + assert isinstance(result, sc.StochasticLanczosResult) + np.testing.assert_allclose(eigenvalue, result.eigenvalue) + np.testing.assert_allclose(eigenvector, result.eigenvector) + + def test_stochastic_lanczos_uses_e0_for_zero_initial_vector(): sc = importlib.import_module("spacecore") ctx = _ctx() @@ -251,6 +322,21 @@ def test_power_iteration_jit_compiles_with_operator_argument(): np.testing.assert_allclose(to_numpy(eigenvalue), 5.0, rtol=1e-5, atol=1e-5) +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_power_iteration_jit_compiles_with_quadratic_form_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + + run = jax.jit(lambda quad, x: sc.power_iteration(quad, x0=x, maxiter=60).eigenvalue) + eigenvalue = run(q, ctx.asarray([1.0, 1.0])) + + np.testing.assert_allclose(to_numpy(eigenvalue), 5.0, rtol=1e-5, atol=1e-5) + + @pytest.mark.skipif(not has_jax(), reason="jax is not installed") def test_stochastic_lanczos_jit_compiles_with_operator_argument(): jax = pytest.importorskip("jax") From d0f58bff3b8eb21b4662100fc3c73997dc70b8a0 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Fri, 22 May 2026 01:50:26 -0300 Subject: [PATCH 18/44] Refactor contextual checks and Lanczos geometry --- spacecore/__init__.py | 8 +- spacecore/_checks.py | 31 +++++ spacecore/_contextual/__init__.py | 48 +++++-- spacecore/_contextual/{bound.py => _bound.py} | 6 +- .../_contextual/{manager.py => _manager.py} | 41 ++++-- spacecore/_contextual/_policies.py | 63 +++++++++ .../_contextual/{contextual.py => _state.py} | 70 ++-------- spacecore/functional/_base.py | 5 +- spacecore/functional/_linear.py | 7 +- spacecore/functional/_quadratic.py | 59 +++++--- spacecore/linalg/_lanczos.py | 101 ++++++++++---- spacecore/linop/_algebra.py | 55 +++++--- spacecore/linop/_base.py | 8 +- spacecore/linop/_dense.py | 19 +-- spacecore/linop/_diagonal.py | 14 +- spacecore/linop/_sparse.py | 19 +-- spacecore/linop/product/_base.py | 5 - spacecore/linop/product/_block.py | 7 +- spacecore/linop/product/_from_single.py | 7 +- spacecore/linop/product/_to_single.py | 7 +- spacecore/space/_product.py | 4 +- tests/backend/test_backend_registry.py | 2 +- tests/context/test_checked_method.py | 103 ++++++++++++++ tests/context/test_context_resolution.py | 49 +++++++ tests/context/test_enable_checks.py | 2 +- tests/functional/test_functional.py | 40 ++++++ tests/integration/test_public_api.py | 6 +- tests/linalg/test_krylov.py | 131 ++++++++++++++++++ tests/linops/test_linop_jit.py | 20 +++ 29 files changed, 719 insertions(+), 218 deletions(-) create mode 100644 spacecore/_checks.py rename spacecore/_contextual/{bound.py => _bound.py} (86%) rename spacecore/_contextual/{manager.py => _manager.py} (82%) create mode 100644 spacecore/_contextual/_policies.py rename spacecore/_contextual/{contextual.py => _state.py} (92%) create mode 100644 tests/context/test_checked_method.py diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 91e63c3..e190147 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -64,13 +64,14 @@ ) from .types import DenseArray, SparseArray, ArrayLike -from ._contextual import ContextBound -from ._contextual.manager import ( +from ._checks import checked_method +from ._contextual import ( + ContextBound, set_context, get_context, resolve_context_priority, register_ops, set_resolution_policy, set_dtype_resolution_policy, - get_resolution_policy, get_dtype_resolution_policy + get_resolution_policy, get_dtype_resolution_policy, ) __all__ = [ @@ -133,6 +134,7 @@ "SparseArray", "ArrayLike", + "checked_method", "ContextBound", "set_context", "get_context", diff --git a/spacecore/_checks.py b/spacecore/_checks.py new file mode 100644 index 0000000..d1b84f9 --- /dev/null +++ b/spacecore/_checks.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from functools import wraps +from typing import Any, Callable + + +def checked_method( + *, + in_space: str | None = None, + out_space: str | None = None, + arg_pos: int = 0, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Decorate methods with optional Space membership checks.""" + + def decorate(method: Callable[..., Any]) -> Callable[..., Any]: + @wraps(method) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + if self._enable_checks and in_space is not None: + x = args[arg_pos] + getattr(self, in_space)._check_member(x) + + y = method(self, *args, **kwargs) + + if self._enable_checks and out_space is not None: + getattr(self, out_space)._check_member(y) + + return y + + return wrapper + + return decorate diff --git a/spacecore/_contextual/__init__.py b/spacecore/_contextual/__init__.py index 45cebdc..bb86e5d 100644 --- a/spacecore/_contextual/__init__.py +++ b/spacecore/_contextual/__init__.py @@ -1,11 +1,43 @@ -from .bound import ContextBound as ContextBound -from .manager import ( - ctx_manager as ctx_manager, - set_context as set_context, - resolve_context_priority as resolve_context_priority, +from ._bound import ContextBound as ContextBound +from ._manager import ( + enforce_convert_policy as enforce_convert_policy, + get_context as get_context, + get_dtype_resolution_policy as get_dtype_resolution_policy, + get_resolution_policy as get_resolution_policy, + normalize_context as normalize_context, register_ops as register_ops, - set_resolution_policy as set_resolution_policy, + resolve_context_priority as resolve_context_priority, + set_context as set_context, set_dtype_resolution_policy as set_dtype_resolution_policy, - get_resolution_policy as get_resolution_policy, - get_dtype_resolution_policy as get_dtype_resolution_policy, + set_resolution_policy as set_resolution_policy, +) +from ._policies import ( + ContextConflictError as ContextConflictError, + ContextConversionError as ContextConversionError, + ContextError as ContextError, + ContextInferenceError as ContextInferenceError, + ContextPolicy as ContextPolicy, + DtypePreservePolicy as DtypePreservePolicy, + UnknownBackendError as UnknownBackendError, ) + +__all__ = [ + "ContextBound", + "ContextConflictError", + "ContextConversionError", + "ContextError", + "ContextInferenceError", + "ContextPolicy", + "DtypePreservePolicy", + "UnknownBackendError", + "enforce_convert_policy", + "get_context", + "get_dtype_resolution_policy", + "get_resolution_policy", + "normalize_context", + "register_ops", + "resolve_context_priority", + "set_context", + "set_dtype_resolution_policy", + "set_resolution_policy", +] diff --git a/spacecore/_contextual/bound.py b/spacecore/_contextual/_bound.py similarity index 86% rename from spacecore/_contextual/bound.py rename to spacecore/_contextual/_bound.py index eea87ba..658d22d 100644 --- a/spacecore/_contextual/bound.py +++ b/spacecore/_contextual/_bound.py @@ -5,7 +5,7 @@ from ..backend import Context, BackendOps, BackendFamily from ..types import DType -from .manager import ctx_manager +from ._manager import enforce_convert_policy, normalize_context def _same_effective_context(left: Context, right: Context) -> bool: @@ -18,7 +18,7 @@ def _same_effective_context(left: Context, right: Context) -> bool: class ContextBound(ABC): def __init__(self, ctx: Context | str | None = None): - ctx = ctx_manager.normalize_context(ctx) + ctx = normalize_context(ctx) self._ctx = ctx @property @@ -37,7 +37,7 @@ def _convert(self, new_ctx: Context) -> Self: raise NotImplementedError() def convert(self, new_ctx: Context | BackendFamily | str | None = None) -> Self: - _, new_ctx = ctx_manager.enforce_convert_policy(self, new_ctx) + _, new_ctx = enforce_convert_policy(self, new_ctx) if _same_effective_context(self.ctx, new_ctx): return self return self._convert(new_ctx) diff --git a/spacecore/_contextual/manager.py b/spacecore/_contextual/_manager.py similarity index 82% rename from spacecore/_contextual/manager.py rename to spacecore/_contextual/_manager.py index f655cb8..104f0a4 100644 --- a/spacecore/_contextual/manager.py +++ b/spacecore/_contextual/_manager.py @@ -1,13 +1,11 @@ from typing import Any from ..backend import Context, BackendOps -from .contextual import Contextual, ContextPolicy, DtypePreservePolicy +from ._policies import ContextPolicy, DtypePreservePolicy +from ._state import _contextual from ..backend import BackendFamily -ctx_manager = Contextual() - - def set_context( ctx: Context | BackendFamily | str | None = None, dtype: Any = None, @@ -35,8 +33,8 @@ def set_context( Objects created without an explicit context use this default context. Existing spaces, operators, and contexts are not modified. """ - ctx = ctx_manager.normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) - ctx_manager.default_ctx = ctx + ctx = _contextual.normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) + _contextual.default_ctx = ctx def get_context() -> Context: @@ -49,7 +47,7 @@ def get_context() -> Context: The default context used by constructors when no explicit context can be inferred or provided. """ - return ctx_manager.default_ctx + return _contextual.default_ctx def resolve_context_priority( @@ -80,7 +78,7 @@ def resolve_context_priority( User code should call this function instead of accessing the internal context manager singleton. """ - return ctx_manager.resolve_context_priority(priority_ctx, *other_ctx) + return _contextual.resolve_context_priority(priority_ctx, *other_ctx) def register_ops(ops: type[BackendOps]) -> type[BackendOps]: @@ -115,7 +113,24 @@ def register_ops(ops: type[BackendOps]) -> type[BackendOps]: class MyOps(BackendOps): ... """ - return ctx_manager.register_ops(ops) + return _contextual.register_ops(ops) + + +def normalize_context( + ctx: Context | BackendFamily | str | None = None, + dtype: Any = None, + enable_checks: bool | None = None, +) -> Context: + """Normalize a context specification through the process-wide state.""" + return _contextual.normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) + + +def enforce_convert_policy( + x: Any, + to: Context | BackendFamily | str | None = None, +) -> tuple[Any, Context]: + """Resolve a conversion target and enforce the configured policy.""" + return _contextual.enforce_convert_policy(x, to) def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None: @@ -141,7 +156,7 @@ def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None: * ``"error"``: reject backend conversion. * ``"silent"``: allow backend conversion without warning. """ - ctx_manager.resolution_policy = policy + _contextual.resolution_policy = policy def get_resolution_policy() -> str: @@ -153,7 +168,7 @@ def get_resolution_policy() -> str: str Policy name, one of ``"warning"``, ``"error"``, or ``"silent"``. """ - return ctx_manager.resolution_policy.value + return _contextual.resolution_policy.value def set_dtype_resolution_policy( @@ -181,7 +196,7 @@ def set_dtype_resolution_policy( equivalent dtype in the target backend. * ``"convert"``: use the dtype provided by the resolved target context. """ - ctx_manager.dtype_resolution_policy = policy + _contextual.dtype_resolution_policy = policy def get_dtype_resolution_policy() -> str: @@ -193,4 +208,4 @@ def get_dtype_resolution_policy() -> str: str Policy name, one of ``"keep_native"`` or ``"convert"``. """ - return ctx_manager.dtype_resolution_policy.value + return _contextual.dtype_resolution_policy.value diff --git a/spacecore/_contextual/_policies.py b/spacecore/_contextual/_policies.py new file mode 100644 index 0000000..3bf21f2 --- /dev/null +++ b/spacecore/_contextual/_policies.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from enum import StrEnum, auto + + +class ContextPolicy(StrEnum): + """ + Policy for backend-incompatible context conversion. + + Values + ------ + warning: + Allow conversion to a different backend family and issue a warning. + This is the default. + error: + Reject conversion to a different backend family. Use this when + accidental backend migration should be forbidden. + silent: + Allow conversion to a different backend family without warning. Use + this when automatic conversion is expected and controlled. + """ + + warning = auto() + error = auto() + silent = auto() + + +class DtypePreservePolicy(StrEnum): + """ + Policy for dtype handling during context conversion. + + Values + ------ + keep_native: + Preserve the source object's dtype where possible by converting it to an + equivalent dtype in the target backend. This is the default. + convert: + Use the dtype provided by the resolved target context. This prioritizes + dtype unification under the target context. + """ + + keep_native = auto() + convert = auto() + + +class ContextError(RuntimeError): + pass + + +class ContextInferenceError(ContextError): + pass + + +class ContextConflictError(ContextError): + pass + + +class UnknownBackendError(ContextError): + pass + + +class ContextConversionError(ContextError): + pass diff --git a/spacecore/_contextual/contextual.py b/spacecore/_contextual/_state.py similarity index 92% rename from spacecore/_contextual/contextual.py rename to spacecore/_contextual/_state.py index aa21109..e3ec419 100644 --- a/spacecore/_contextual/contextual.py +++ b/spacecore/_contextual/_state.py @@ -1,11 +1,18 @@ from __future__ import annotations from typing import Dict, Any, Iterable, Tuple -from enum import StrEnum, auto from warnings import warn from ..types import DType from ..backend import Context, NumpyOps, JaxOps, BackendFamily, BackendOps +from ._policies import ( + ContextConflictError, + ContextConversionError, + ContextInferenceError, + ContextPolicy, + DtypePreservePolicy, + UnknownBackendError, +) try: from ..backend import CuPyOps except ImportError: @@ -16,64 +23,6 @@ pass -class ContextPolicy(StrEnum): - """ - Policy for backend-incompatible context conversion. - - Values - ------ - warning: - Allow conversion to a different backend family and issue a warning. - This is the default. - error: - Reject conversion to a different backend family. Use this when - accidental backend migration should be forbidden. - silent: - Allow conversion to a different backend family without warning. Use - this when automatic conversion is expected and controlled. - """ - - warning = auto() - error = auto() - silent = auto() - -class DtypePreservePolicy(StrEnum): - """ - Policy for dtype handling during context conversion. - - Values - ------ - keep_native: - Preserve the source object's dtype where possible by converting it to an - equivalent dtype in the target backend. This is the default. - convert: - Use the dtype provided by the resolved target context. This prioritizes - dtype unification under the target context. - """ - - keep_native = auto() - convert = auto() - - -class ContextError(RuntimeError): - pass - - -class ContextInferenceError(ContextError): - pass - - -class ContextConflictError(ContextError): - pass - - -class UnknownBackendError(ContextError): - pass - -class ContextConversionError(ContextError): - pass - - class Contextual: """ Backend resolver. @@ -559,3 +508,6 @@ def _join_dtypes(self, ops: BackendOps, *dtypes: DType | None) -> DType | None: np_ops = NumpyOps() joined = np_ops.np.result_type(*clean) return ops.sanitize_dtype(joined) + + +_contextual = Contextual() diff --git a/spacecore/functional/_base.py b/spacecore/functional/_base.py index 185a827..95c2f7a 100644 --- a/spacecore/functional/_base.py +++ b/spacecore/functional/_base.py @@ -3,8 +3,7 @@ from abc import abstractmethod from typing import Any, Generic, TypeVar -from .._contextual import ContextBound -from .._contextual.manager import ctx_manager +from .._contextual import ContextBound, resolve_context_priority from ..backend import Context from ..space import Space @@ -23,7 +22,7 @@ class Functional(ContextBound, Generic[Domain]): """ def __init__(self, dom: Domain, ctx: Context | str | None = None) -> None: - ctx = ctx_manager.resolve_context_priority(ctx, dom) + ctx = resolve_context_priority(ctx, dom) super().__init__(ctx) self.dom = dom.convert(self.ctx) self._enable_checks = self.ctx.enable_checks diff --git a/spacecore/functional/_linear.py b/spacecore/functional/_linear.py index ca349a3..e231220 100644 --- a/spacecore/functional/_linear.py +++ b/spacecore/functional/_linear.py @@ -4,6 +4,7 @@ from typing import Any from ._base import Domain, Functional +from .._checks import checked_method from ..backend import Context, jax_pytree_class from ..space import Space @@ -59,10 +60,9 @@ def representer(self) -> Any: """Stored domain element ``c`` defining ``ell_c(x) = ``.""" return self._c + @checked_method(in_space="domain") def value(self, x: Any) -> Any: """Return ``domain.inner(representer, x)``.""" - if self._enable_checks: - self.domain._check_member(x) return self.domain.inner(self._c, x) def __eq__(self, other: Any) -> bool: @@ -117,10 +117,9 @@ def representer(self) -> Any: f"{type(self).__name__} does not store a Riesz representer." ) + @checked_method(in_space="domain") def value(self, x: Any) -> Any: """Return ``value_fn(x)``.""" - if self._enable_checks: - self.domain._check_member(x) y = self.value_fn(x) if self._enable_checks: self._check_scalar_batch(y, ()) diff --git a/spacecore/functional/_quadratic.py b/spacecore/functional/_quadratic.py index ebbe79f..6f77f5d 100644 --- a/spacecore/functional/_quadratic.py +++ b/spacecore/functional/_quadratic.py @@ -4,9 +4,10 @@ from ._base import Domain, Functional from ._linear import LinearFunctional -from .._contextual.manager import ctx_manager +from .._checks import checked_method +from .._contextual import resolve_context_priority from ..backend import Context, jax_pytree_class -from ..linop import LinOp +from ..linop import DenseLinOp, DiagonalLinOp, LinOp from ..space import Space @@ -36,9 +37,19 @@ def vgrad(self, xs: Any, batch_space: Space | None = None) -> Any: @jax_pytree_class class LinOpQuadraticForm(QuadraticForm[Domain]): """ - Quadratic form backed by a linear operator. + Quadratic form f(x) = 0.5 * . - ``q(x) = 1/2 * + linear(x) + a`` with ``Q : X -> X``. + Assumption: + Q is Hermitian/self-adjoint. Under this assumption, + grad f(x) = Q x. + + Non-Hermitian operators are not supported here. If users need the + Hermitian part, they must construct 0.5 * (Q + Q.H) explicitly. + + The full objective is ``q(x) = 1/2 * + linear(x) + a`` with + ``Q : X -> X``. Structurally available dense and diagonal operators are + checked at construction. Matrix-free operators are not validated; correctness + is the caller's responsibility. """ def __init__( @@ -55,10 +66,11 @@ def __init__( f"linear must be a LinearFunctional or None, got {type(linear).__name__}." ) - resolved_ctx = ctx_manager.resolve_context_priority(ctx, Q.domain, Q, linear) + resolved_ctx = resolve_context_priority(ctx, Q.domain, Q, linear) Q = Q.convert(resolved_ctx) if Q.domain != Q.codomain: raise ValueError("LinOpQuadraticForm requires Q.domain == Q.codomain.") + self._check_hermitian_structure(Q) if linear is not None: linear = linear.convert(resolved_ctx) if linear.domain != Q.domain: @@ -71,39 +83,46 @@ def __init__( if self._enable_checks: self._check_scalar_batch(self.a, ()) + @staticmethod + def _check_hermitian_structure(Q: LinOp[Domain, Domain]) -> None: + try: + if isinstance(Q, DenseLinOp): + is_hermitian = Q.ops.allclose(Q._A2, Q._A2H) + elif isinstance(Q, DiagonalLinOp): + is_hermitian = Q.ops.allclose(Q.diagonal, Q._diag_adjoint) + else: + return + except Exception: + return + if not is_hermitian: + raise ValueError("LinOpQuadraticForm requires Q to be Hermitian/self-adjoint.") + + @checked_method(in_space="domain") def value(self, x: Any) -> Any: """Return ``1/2 * + linear(x) + a``.""" - if self._enable_checks: - self.domain._check_member(x) qx = self.Q.apply(x) value = 0.5 * self.domain.inner(x, qx) if self.linear is not None: value = value + self.linear.value(x) return value + self.a + @checked_method(in_space="domain", out_space="domain") def grad(self, x: Any) -> Any: """ Return the Euclidean/Riesz gradient. - The quadratic part uses the symmetric adjoint part ``(Q + Q*) / 2``. - For self-adjoint ``Q`` this is exactly ``Qx``. + ``LinOpQuadraticForm`` assumes ``Q`` is Hermitian/self-adjoint, so the + quadratic contribution is exactly ``Q.apply(x)``. """ - if self._enable_checks: - self.domain._check_member(x) - qx = self.Q.apply(x) - qhx = self.Q.rapply(x) - grad = self.domain.scale(0.5, self.domain.add(qx, qhx)) + grad = self.Q.apply(x) if self.linear is not None: grad = self.domain.add(grad, self.linear.representer) return grad + @checked_method(in_space="domain", out_space="domain") def hess_apply(self, x: Any) -> Any: - """Return the self-adjoint Hessian action ``(Q + Q*) x / 2``.""" - if self._enable_checks: - self.domain._check_member(x) - qx = self.Q.apply(x) - qhx = self.Q.rapply(x) - return self.domain.scale(0.5, self.domain.add(qx, qhx)) + """Return the Hessian action ``Q x`` under the Hermitian assumption.""" + return self.Q.apply(x) def __eq__(self, other: Any) -> bool: if type(other) is type(self): diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index 2f1ec3e..b14322f 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -2,10 +2,12 @@ from typing import Any, NamedTuple +import numpy as np + from ..linop import LinOp from ..types import DenseArray from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval -from ._utils import require_linop, require_square, should_check_iteration +from ._utils import require_linop, require_square, safe_inverse, should_check_iteration from ._utils import result_repr @@ -33,6 +35,21 @@ def _check_lanczos_max_iter(max_iter: int) -> int: return max_iter +def _real_dtype(ctx: Any) -> Any: + dtype_text = str(ctx.dtype) + if "complex128" in dtype_text: + return ctx.ops.sanitize_dtype(np.float64) + if "complex64" in dtype_text: + return ctx.ops.sanitize_dtype(np.float32) + try: + dtype = np.dtype(ctx.dtype) + except TypeError: + return ctx.dtype + if dtype.kind == "c": + return ctx.ops.sanitize_dtype(np.float64 if dtype.itemsize > 8 else np.float32) + return ctx.dtype + + def stochastic_lanczos( A: LinOp, initial_vector: Any, @@ -76,32 +93,38 @@ def stochastic_lanczos( A.domain.check_member(initial_vector) ops = A.ops ctx = A.ctx + real_dtype = _real_dtype(ctx) v0 = A.domain.flatten(initial_vector) v0 = ctx.assert_dense(v0) n = v0.shape[0] V = ops.zeros((max_iter + 1, n), dtype=ctx.dtype) - alphas = ops.zeros((max_iter,), dtype=ctx.dtype) - betas = ops.zeros((max_iter + 1,), dtype=ctx.dtype) + alphas = ops.zeros((max_iter,), dtype=real_dtype) + betas = ops.zeros((max_iter + 1,), dtype=real_dtype) - tol_s = ctx.asarray(tol) - eps_s = ctx.asarray(1e-12) + tol_s = ops.asarray(tol, dtype=real_dtype) + eps_s = ops.asarray(1e-12, dtype=real_dtype) - v0_norm = ops.sqrt(ops.real(ops.vdot(v0, v0))) + v0_norm = A.domain.norm(initial_vector) e0 = ops.zeros((n,), dtype=ctx.dtype) e0 = ops.index_set(e0, (0,), ctx.asarray(1.0), copy=True) + e0_member = A.domain.unflatten(e0) + e0_norm = A.domain.norm(e0_member) + e0_unit = A.domain.flatten(A.domain.scale(safe_inverse(ops, e0_norm), e0_member)) v0_unit = ops.cond( v0_norm > eps_s, - lambda _: v0 / v0_norm, - lambda _: e0, - ctx.asarray(0.0), + lambda _: A.domain.flatten( + A.domain.scale(safe_inverse(ops, v0_norm), initial_vector) + ), + lambda _: e0_unit, + ops.asarray(0.0, dtype=real_dtype), ) V = ops.index_set(V, (0, slice(None)), v0_unit, copy=True) - beta0 = ctx.asarray(1.0) + beta0 = ops.asarray(1.0, dtype=real_dtype) i0 = 0 keep_going0 = ops.asarray(True) @@ -116,10 +139,12 @@ def body_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> tuple[Any, Any, Any, i, V_, alphas_, betas_, beta, keep_going = state v_i = V_[i] - w = A.codomain.flatten(A.apply(A.domain.unflatten(v_i))) + v_i_member = A.domain.unflatten(v_i) + w_member = A.apply(v_i_member) + w = A.codomain.flatten(w_member) w = ctx.assert_dense(w) - alpha = ops.real(ops.vdot(v_i, w)) + alpha = ops.real(A.domain.inner(v_i_member, w_member)) alphas_ = ops.index_set(alphas_, (i,), alpha, copy=True) w = ops.cond( @@ -129,20 +154,34 @@ def body_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> tuple[Any, Any, Any, w, ) + w_member = A.domain.unflatten(w) valid = full_indices < (i + 1) - mask = ops.where(valid, ctx.asarray(1.0), ctx.asarray(0.0)) - mask = ops.astype(mask, w.dtype) + mask = ops.where( + valid, + ops.asarray(1.0, dtype=real_dtype), + ops.asarray(0.0, dtype=real_dtype), + ) + mask = ops.astype(mask, ctx.dtype) - coeffs_full = ops.einsum("jn,n->j", ops.conj(V_), w) + coeffs_full = ops.zeros((max_iter + 1,), dtype=ctx.dtype) + + def fill_coeff(j: int, coeffs_in: DenseArray) -> DenseArray: + v_j_member = A.domain.unflatten(V_[j]) + coeff = A.domain.inner(v_j_member, w_member) + return ops.index_set(coeffs_in, (j,), coeff, copy=True) + + coeffs_full = ops.fori_loop(0, max_iter + 1, fill_coeff, coeffs_full) coeffs_valid = coeffs_full * mask proj = ops.sum(coeffs_valid[:, None] * V_, axis=0) w = w - proj - beta_new = ops.sqrt(ops.real(ops.vdot(w, w))) + w_member = A.domain.unflatten(w) + beta_new = A.domain.norm(w_member) betas_ = ops.index_set(betas_, (i + 1,), beta_new, copy=True) def set_next(V_in: DenseArray) -> DenseArray: - return ops.index_set(V_in, (i + 1, slice(None)), w / beta_new, copy=True) + w_unit = A.domain.flatten(A.domain.scale(safe_inverse(ops, beta_new), w_member)) + return ops.index_set(V_in, (i + 1, slice(None)), w_unit, copy=True) V_ = ops.cond(beta_new >= tol_s, set_next, lambda V_in: V_in, V_) i_next = i + 1 @@ -150,7 +189,7 @@ def set_next(V_in: DenseArray) -> DenseArray: should_check_iteration(i_next, max_iter, check_every), lambda _: beta_new >= tol_s, lambda _: keep_going, - ctx.asarray(0.0), + ops.asarray(0.0, dtype=real_dtype), ) return i_next, V_, alphas_, betas_, beta_new, keep_going_next @@ -161,10 +200,15 @@ def set_next(V_in: DenseArray) -> DenseArray: m = i_final mask_alpha = idx < m - alphas_full = ops.where(mask_alpha, alphas, ctx.asarray(1e10)) - betas_full = ops.where(full_indices == m, ctx.asarray(0.0), betas) + inactive_sentinel = ( + ops.max(ops.abs(alphas)) + + 2.0 * ops.max(ops.abs(betas)) + + ops.asarray(1.0, dtype=real_dtype) + ) + alphas_full = ops.where(mask_alpha, alphas, inactive_sentinel) + betas_full = ops.where(full_indices == m, ops.asarray(0.0, dtype=real_dtype), betas) - T = ops.zeros((max_iter, max_iter), dtype=ctx.dtype) + T = ops.zeros((max_iter, max_iter), dtype=real_dtype) def fill_diag(ii: int, T_in: DenseArray) -> DenseArray: return ops.index_set(T_in, (ii, ii), alphas_full[ii], copy=True) @@ -182,19 +226,24 @@ def fill_off(ii: int, T_in: DenseArray) -> DenseArray: _eigvals, eigvecs = ops.eigh(T) y_full = eigvecs[:, 0] - mask_y = ops.where(idx < m, ctx.asarray(1.0), ctx.asarray(0.0)) + mask_y = ops.where( + idx < m, + ops.asarray(1.0, dtype=real_dtype), + ops.asarray(0.0, dtype=real_dtype), + ) mask_y = ops.astype(mask_y, y_full.dtype) y_valid = y_full * mask_y V_reduced = V[:max_iter, :] x_flat = ops.einsum("j,jn->n", y_valid, V_reduced) - x_norm = ops.sqrt(ops.real(ops.vdot(x_flat, x_flat))) + x_member = A.domain.unflatten(x_flat) + x_norm = A.domain.norm(x_member) x_flat = ops.cond( x_norm > eps_s, - lambda _: x_flat / x_norm, - lambda _: e0, - ctx.asarray(0.0), + lambda _: A.domain.flatten(A.domain.scale(safe_inverse(ops, x_norm), x_member)), + lambda _: e0_unit, + ops.asarray(0.0, dtype=real_dtype), ) x = A.domain.unflatten(x_flat) diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index 3029fc3..8b9c3b4 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -4,6 +4,7 @@ from typing import Any, Sequence from ._base import LinOp, Domain, Codomain +from .._checks import checked_method from ..backend import Context, jax_pytree_class @@ -190,10 +191,12 @@ def __init__(self, scalar: Any, op: LinOp[Domain, Codomain]) -> None: self.scalar = scalar self.op = op + @checked_method(in_space="domain", out_space="codomain") def apply(self, x: Any) -> Any: """Return ``scalar * op.apply(x)``.""" return self.scalar * self.op.apply(x) + @checked_method(in_space="codomain", out_space="domain") def rapply(self, y: Any) -> Any: """Return ``conj(scalar) * op.rapply(y)``.""" return _conjugate_scalar(self.scalar) * self.op.rapply(y) @@ -262,6 +265,7 @@ def parts(self) -> tuple[LinOp[Domain, Codomain], ...]: """Operators in this lazy sum.""" return self.ops_tuple + @checked_method(in_space="domain", out_space="codomain") def apply(self, x: Any) -> Any: """Return ``sum_i ops[i].apply(x)``.""" acc = self.ops_tuple[0].apply(x) @@ -269,6 +273,7 @@ def apply(self, x: Any) -> Any: acc = self.codomain.add(acc, op.apply(x)) return acc + @checked_method(in_space="codomain", out_space="domain") def rapply(self, y: Any) -> Any: """Return ``sum_i ops[i].rapply(y)``.""" acc = self.ops_tuple[0].rapply(y) @@ -340,10 +345,12 @@ def __init__(self, left: LinOp, right: LinOp) -> None: self.left = left self.right = right + @checked_method(in_space="domain", out_space="codomain") def apply(self, x: Any) -> Any: """Return ``left.apply(right.apply(x))``.""" return self.left.apply(self.right.apply(x)) + @checked_method(in_space="codomain", out_space="domain") def rapply(self, z: Any) -> Any: """Return ``right.rapply(left.rapply(z))``.""" return self.right.rapply(self.left.rapply(z)) @@ -401,16 +408,20 @@ def __init__( ) -> None: super().__init__(dom, cod, ctx) + @checked_method(in_space="domain", out_space="codomain") def apply(self, x: Any) -> Any: """Return the zero element of the codomain.""" - if self._enable_checks: - self.domain._check_member(x) + return self._apply_unchecked(x) + + def _apply_unchecked(self, x: Any) -> Any: return self.codomain.zeros() + @checked_method(in_space="codomain", out_space="domain") def rapply(self, y: Any) -> Any: """Return the zero element of the domain.""" - if self._enable_checks: - self.codomain._check_member(y) + return self._rapply_unchecked(y) + + def _rapply_unchecked(self, y: Any) -> Any: return self.domain.zeros() def vapply(self, xs: Any, batch_space=None) -> Any: @@ -470,16 +481,20 @@ class IdentityLinOp(LinOp[Domain, Domain]): def __init__(self, space: Domain, ctx: Context | str | None = None) -> None: super().__init__(space, space, ctx) + @checked_method(in_space="domain", out_space="codomain") def apply(self, x: Any) -> Any: """Return ``x`` after domain validation.""" - if self._enable_checks: - self.domain._check_member(x) + return self._apply_unchecked(x) + + def _apply_unchecked(self, x: Any) -> Any: return x + @checked_method(in_space="codomain", out_space="domain") def rapply(self, x: Any) -> Any: """Return ``x`` after codomain validation.""" - if self._enable_checks: - self.codomain._check_member(x) + return self._rapply_unchecked(x) + + def _rapply_unchecked(self, x: Any) -> Any: return x def vapply(self, xs: Any, batch_space=None) -> Any: @@ -567,23 +582,21 @@ def __init__( self.vapply_fn = vapply self.rvapply_fn = rvapply + @checked_method(in_space="domain", out_space="codomain") def apply(self, x: Any) -> Any: """Return ``apply_fn(x)``.""" - if self._enable_checks: - self.domain._check_member(x) - y = self.apply_fn(x) - if self._enable_checks: - self.codomain._check_member(y) - return y + return self._apply_unchecked(x) + def _apply_unchecked(self, x: Any) -> Any: + return self.apply_fn(x) + + @checked_method(in_space="codomain", out_space="domain") def rapply(self, y: Any) -> Any: """Return ``rapply_fn(y)``.""" - if self._enable_checks: - self.codomain._check_member(y) - x = self.rapply_fn(y) - if self._enable_checks: - self.domain._check_member(x) - return x + return self._rapply_unchecked(y) + + def _rapply_unchecked(self, y: Any) -> Any: + return self.rapply_fn(y) def vapply(self, xs: Any, batch_space=None) -> Any: """Return ``vapply_fn(xs)`` when supplied, otherwise use fallback batching.""" @@ -669,10 +682,12 @@ def __init__(self, op: LinOp[Domain, Codomain]) -> None: super().__init__(op.codomain, op.domain, op.ctx) self.op = op + @checked_method(in_space="domain", out_space="codomain") def apply(self, y: Any) -> Any: """Return ``op.rapply(y)``.""" return self.op.rapply(y) + @checked_method(in_space="codomain", out_space="domain") def rapply(self, x: Any) -> Any: """Return ``op.apply(x)``.""" return self.op.apply(x) diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index d5aa793..dda40ed 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -1,14 +1,14 @@ from __future__ import annotations from abc import abstractmethod +from functools import cached_property from math import prod from numbers import Number from typing import Any, Generic, TypeVar from ..space import Space from ..backend import Context -from .._contextual import ContextBound -from .._contextual.manager import ctx_manager +from .._contextual import ContextBound, resolve_context_priority Domain = TypeVar('Domain', bound=Space) Codomain = TypeVar('Codomain', bound=Space) @@ -26,7 +26,7 @@ class LinOp(ContextBound, Generic[Domain, Codomain]): """ def __init__(self, dom: Domain, cod: Codomain, ctx: Context | str | None = None): - ctx = ctx_manager.resolve_context_priority(ctx, dom, cod) + ctx = resolve_context_priority(ctx, dom, cod) super(LinOp, self).__init__(ctx) self.dom = dom.convert(self.ctx) @@ -43,7 +43,7 @@ def codomain(self) -> Codomain: """Codomain space of this linear operator.""" return self.cod - @property + @cached_property def A(self) -> Any: """ Native numerical representation of this operator. diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index 244f566..d2fa2ce 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -1,13 +1,15 @@ from __future__ import annotations +from functools import cached_property from math import prod from typing import Any from ._base import LinOp, Domain, Codomain +from .._checks import checked_method from ..space import VectorSpace from ..types import DenseArray from ..backend import jax_pytree_class, Context -from .._contextual.manager import ctx_manager +from .._contextual import resolve_context_priority @jax_pytree_class @@ -27,7 +29,7 @@ def __init__(self, cod: Codomain | None = None, ctx: Context | str | None = None ) -> None: - ctx = ctx_manager.resolve_context_priority(ctx, dom, cod) + ctx = resolve_context_priority(ctx, dom, cod) ctx.assert_dense(A) # Check if A is ndarray of ctx if cod is None: @@ -53,13 +55,8 @@ def __init__(self, self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,) self._dom_vector_fast_path = type(self.dom) is VectorSpace self._cod_vector_fast_path = type(self.cod) is VectorSpace - if not self._enable_checks: - self.apply = self._apply_unchecked - self.rapply = self._rapply_unchecked - self.vapply = self._vapply_unchecked - self.rvapply = self._rvapply_unchecked - @property + @cached_property def A(self) -> DenseArray: """ Stored dense tensor representation of this operator. @@ -69,12 +66,11 @@ def A(self) -> DenseArray: """ return self._A + @checked_method(in_space="dom", out_space="cod") def apply(self, x: DenseArray) -> DenseArray: """ Forward action: y = A ⋅ x with y in cod.shape. """ - if self._enable_checks: - self.dom._check_member(x) return self._apply_unchecked(x) def _apply_unchecked(self, x: DenseArray) -> DenseArray: @@ -84,14 +80,13 @@ def _apply_unchecked(self, x: DenseArray) -> DenseArray: return y1 if self._cod_is_flat else y1.reshape(self.cod.shape) return self.cod.unflatten(y1) + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: DenseArray) -> DenseArray: """ Adjoint action: x = A^* ⋅ y with x in dom.shape. For complex A, uses conjugate-transpose of the 2D reshaped matrix. """ - if self._enable_checks: - self.cod._check_member(y) return self._rapply_unchecked(y) def _rapply_unchecked(self, y: DenseArray) -> DenseArray: diff --git a/spacecore/linop/_diagonal.py b/spacecore/linop/_diagonal.py index 9ba6ea7..a27abab 100644 --- a/spacecore/linop/_diagonal.py +++ b/spacecore/linop/_diagonal.py @@ -1,13 +1,15 @@ from __future__ import annotations +from functools import cached_property from math import prod from typing import Any from ._base import LinOp +from .._checks import checked_method from ..backend import Context, jax_pytree_class from ..space import VectorSpace from ..types import DenseArray -from .._contextual.manager import ctx_manager +from .._contextual import resolve_context_priority @jax_pytree_class @@ -20,7 +22,7 @@ def __init__( space: VectorSpace | None = None, ctx: Context | str | None = None, ) -> None: - ctx = ctx_manager.resolve_context_priority(ctx, space) + ctx = resolve_context_priority(ctx, space) ctx.assert_dense(diagonal) if space is None: space = VectorSpace(tuple(diagonal.shape), ctx) @@ -39,18 +41,16 @@ def __init__( self._diag_adjoint if self._is_flat else self._diag_adjoint.reshape((self._size,)) ) - @property + @cached_property def A(self) -> DenseArray: return self.to_dense() + @checked_method(in_space="domain", out_space="codomain") def apply(self, x: DenseArray) -> DenseArray: - if self._enable_checks: - self.domain._check_member(x) return self.diagonal * x + @checked_method(in_space="codomain", out_space="domain") def rapply(self, y: DenseArray) -> DenseArray: - if self._enable_checks: - self.codomain._check_member(y) return self._diag_adjoint * y def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index ab4f306..a691bf3 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -1,13 +1,15 @@ from __future__ import annotations +from functools import cached_property from math import prod from typing import Any from ._base import LinOp, Domain, Codomain +from .._checks import checked_method from ..space import VectorSpace from ..types import DenseArray, SparseArray from ..backend import jax_pytree_class, Context -from .._contextual.manager import ctx_manager +from .._contextual import resolve_context_priority @jax_pytree_class @@ -28,7 +30,7 @@ def __init__(self, cod: Codomain, ctx: Context | str | None = None ) -> None: - ctx = ctx_manager.resolve_context_priority(ctx, dom, cod) + ctx = resolve_context_priority(ctx, dom, cod) ctx.assert_sparse(A) # Check if A is sparse array of ctx super(SparseLinOp, self).__init__(dom, cod, ctx) @@ -48,13 +50,8 @@ def __init__(self, self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,) self._dom_vector_fast_path = type(self.dom) is VectorSpace self._cod_vector_fast_path = type(self.cod) is VectorSpace - if not self._enable_checks: - self.apply = self._apply_unchecked - self.rapply = self._rapply_unchecked - self.vapply = self._vapply_unchecked - self.rvapply = self._rvapply_unchecked - @property + @cached_property def A(self) -> SparseArray: """ Stored sparse matrix representation of this operator. @@ -65,14 +62,13 @@ def A(self) -> SparseArray: """ return self._A + @checked_method(in_space="dom", out_space="cod") def apply(self, x: DenseArray) -> DenseArray: """ Forward action: y = A ⋅ x with y in cod.shape. x must have shape dom.shape (dense). """ - if self._enable_checks: - self.dom._check_member(x) return self._apply_unchecked(x) def _apply_unchecked(self, x: DenseArray) -> DenseArray: @@ -82,14 +78,13 @@ def _apply_unchecked(self, x: DenseArray) -> DenseArray: return y1 if self._cod_is_flat else y1.reshape(self.cod.shape) return self.cod.unflatten(y1) + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: DenseArray) -> DenseArray: """ Adjoint action: x = A^* ⋅ y with x in dom.shape. y must have shape cod.shape (dense). """ - if self._enable_checks: - self.cod._check_member(y) return self._rapply_unchecked(y) def _rapply_unchecked(self, y: DenseArray) -> DenseArray: diff --git a/spacecore/linop/product/_base.py b/spacecore/linop/product/_base.py index 7e4d4d9..1decf86 100644 --- a/spacecore/linop/product/_base.py +++ b/spacecore/linop/product/_base.py @@ -31,11 +31,6 @@ def __init__(self, self._apply_parts = tuple(getattr(op, "_apply_unchecked", op.apply) for op in self.parts) self._rapply_parts = tuple(getattr(op, "_rapply_unchecked", op.rapply) for op in self.parts) self._check_layout() - unchecked_apply = getattr(self, "_apply_unchecked", None) - unchecked_rapply = getattr(self, "_rapply_unchecked", None) - if not self._enable_checks and unchecked_apply is not None and unchecked_rapply is not None: - self.apply = unchecked_apply - self.rapply = unchecked_rapply @abstractmethod def _check_layout(self) -> None: diff --git a/spacecore/linop/product/_block.py b/spacecore/linop/product/_block.py index 4992c6a..8d7ea19 100644 --- a/spacecore/linop/product/_block.py +++ b/spacecore/linop/product/_block.py @@ -4,6 +4,7 @@ from ._base import ProductLinOp from .._base import LinOp +from ..._checks import checked_method from ... import Context from ...space import ProductSpace from ...backend import jax_pytree_class @@ -33,9 +34,8 @@ def _check_layout(self) -> None: else: raise TypeError(f"Component op {i} has incompatible dom/cod spaces.") + @checked_method(in_space="dom", out_space="cod") def apply(self, x: Any) -> Any: - if self._enable_checks: - self.dom._check_member(x) return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: @@ -43,9 +43,8 @@ def _apply_unchecked(self, x: Any) -> Any: return self._apply_parts[0](x[0]), self._apply_parts[1](x[1]) return tuple(apply(xi) for apply, xi in zip(self._apply_parts, x)) + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: Any) -> Any: - if self._enable_checks: - self.cod._check_member(y) return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: diff --git a/spacecore/linop/product/_from_single.py b/spacecore/linop/product/_from_single.py index fabea53..2b65689 100644 --- a/spacecore/linop/product/_from_single.py +++ b/spacecore/linop/product/_from_single.py @@ -4,6 +4,7 @@ from ._base import ProductLinOp from .._base import LinOp, Domain +from ..._checks import checked_method from ...space import ProductSpace, VectorSpace from ...backend import jax_pytree_class, Context @@ -34,9 +35,8 @@ def _check_layout(self) -> None: else: raise TypeError(f"Component op {i} must map dom -> cod.spaces[{i}].") + @checked_method(in_space="dom", out_space="cod") def apply(self, x: Any) -> Any: - if self._enable_checks: - self.dom._check_member(x) return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: @@ -44,9 +44,8 @@ def _apply_unchecked(self, x: Any) -> Any: return self._apply_parts[0](x), self._apply_parts[1](x) return tuple(apply(x) for apply in self._apply_parts) + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: Any) -> Any: - if self._enable_checks: - self.cod._check_member(y) return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: diff --git a/spacecore/linop/product/_to_single.py b/spacecore/linop/product/_to_single.py index c5ab082..60c91fa 100644 --- a/spacecore/linop/product/_to_single.py +++ b/spacecore/linop/product/_to_single.py @@ -4,6 +4,7 @@ from ._base import ProductLinOp from .._base import LinOp, Codomain +from ..._checks import checked_method from ...space import ProductSpace, VectorSpace from ...backend import jax_pytree_class, Context @@ -34,9 +35,8 @@ def _check_layout(self) -> None: else: raise TypeError(f"Component op {i} must map dom.spaces[{i}] -> cod.") + @checked_method(in_space="dom", out_space="cod") def apply(self, x: Any) -> Any: - if self._enable_checks: - self.dom._check_member(x) return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: @@ -51,9 +51,8 @@ def _apply_unchecked(self, x: Any) -> Any: acc = yi if acc is None else (acc + yi if use_direct_add else self.cod.add(yi, acc)) return acc + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: Any) -> Any: - if self._enable_checks: - self.cod._check_member(y) return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: diff --git a/spacecore/space/_product.py b/spacecore/space/_product.py index f97966d..4862634 100644 --- a/spacecore/space/_product.py +++ b/spacecore/space/_product.py @@ -8,7 +8,7 @@ from ..types import DenseArray from ..backend import Context -from .._contextual.manager import ctx_manager +from .._contextual import resolve_context_priority def _prod_int(shape: Tuple[int, ...]) -> int: @@ -47,7 +47,7 @@ def __init__(self, spaces: Tuple[Space, ...], ctx: Context | str | None = None) raise ValueError("ProductSpace requires at least one subspace.") spaces = self._validate_spaces(spaces) - ctx = ctx_manager.resolve_context_priority(ctx, *spaces) + ctx = resolve_context_priority(ctx, *spaces) dims = tuple(_prod_int(s.shape) for s in spaces) offsets: List[int] = [0] diff --git a/tests/backend/test_backend_registry.py b/tests/backend/test_backend_registry.py index 2ec1348..51a20b9 100644 --- a/tests/backend/test_backend_registry.py +++ b/tests/backend/test_backend_registry.py @@ -60,7 +60,7 @@ def index_add(self, x, index, values, copy=True): return y def allclose_sparse(self, a, b, **kwargs): return False sc.register_ops(DummyOps) - assert "dummy" in sc._contextual.manager.ctx_manager.available_ops + assert sc.VectorSpace((1,), "dummy").ctx.ops.family == "dummy" ops = DummyOps() x = ops.reshape(ops.arange(6), (2, 3)) assert np.allclose(ops.sum(x, axis=0), [3, 5, 7]) diff --git a/tests/context/test_checked_method.py b/tests/context/test_checked_method.py new file mode 100644 index 0000000..1d56ac2 --- /dev/null +++ b/tests/context/test_checked_method.py @@ -0,0 +1,103 @@ +import pytest + +from spacecore import checked_method + + +class _RecordingSpace: + def __init__(self, valid): + self.valid = valid + self.calls = [] + + def _check_member(self, value): + self.calls.append(value) + if value != self.valid: + raise ValueError(f"invalid member: {value!r}") + + +class _CheckedDemo: + def __init__(self, enable_checks=True): + self._enable_checks = enable_checks + self.dom = _RecordingSpace("x") + self.cod = _RecordingSpace("y") + self.space = _RecordingSpace("z") + self.apply_result = "y" + self.rapply_result = "x" + self.value_result = 1.0 + self.grad_result = "z" + + @checked_method(in_space="dom", out_space="cod") + def apply(self, x): + """Apply docstring.""" + return self.apply_result + + @checked_method(in_space="cod", out_space="dom") + def rapply(self, y): + return self.rapply_result + + @checked_method(in_space="space") + def value(self, x): + return self.value_result + + @checked_method(in_space="space", out_space="space") + def grad(self, x): + return self.grad_result + + +def test_checked_method_validates_apply_input_and_output(): + demo = _CheckedDemo() + + assert demo.apply("x") == "y" + assert demo.dom.calls == ["x"] + assert demo.cod.calls == ["y"] + + +def test_checked_method_validates_rapply_input_and_output(): + demo = _CheckedDemo() + + assert demo.rapply("y") == "x" + assert demo.cod.calls == ["y"] + assert demo.dom.calls == ["x"] + + +def test_checked_method_validates_value_input(): + demo = _CheckedDemo() + + assert demo.value("z") == 1.0 + assert demo.space.calls == ["z"] + + +def test_checked_method_validates_grad_input_and_output(): + demo = _CheckedDemo() + + assert demo.grad("z") == "z" + assert demo.space.calls == ["z", "z"] + + +def test_checked_method_invalid_input_raises_when_enabled(): + demo = _CheckedDemo(enable_checks=True) + + with pytest.raises(ValueError, match="invalid member"): + demo.apply("bad") + + +def test_checked_method_invalid_output_raises_when_enabled(): + demo = _CheckedDemo(enable_checks=True) + demo.apply_result = "bad" + + with pytest.raises(ValueError, match="invalid member"): + demo.apply("x") + + +def test_checked_method_skips_checks_when_disabled(): + demo = _CheckedDemo(enable_checks=False) + demo.apply_result = "bad" + + assert demo.apply("bad") == "bad" + assert demo.dom.calls == [] + assert demo.cod.calls == [] + + +def test_checked_method_preserves_metadata(): + assert _CheckedDemo.apply.__name__ == "apply" + assert _CheckedDemo.apply.__doc__ == "Apply docstring." + assert _CheckedDemo.apply.__wrapped__ is not None diff --git a/tests/context/test_context_resolution.py b/tests/context/test_context_resolution.py index 60e72fc..28e4214 100644 --- a/tests/context/test_context_resolution.py +++ b/tests/context/test_context_resolution.py @@ -32,6 +32,55 @@ def test_public_resolve_context_priority_wrapper(): sc.set_context(original) +def test_resolve_context_priority_uses_explicit_ctx_before_inferred_ctx(): + sc = importlib.import_module("spacecore") + original = sc.get_context() + default = sc.Context(sc.NumpyOps(), dtype=np.float16, enable_checks=True) + inferred = sc.Context(sc.NumpyOps(), dtype=np.float32, enable_checks=True) + explicit = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False) + try: + sc.set_context(default) + X = sc.VectorSpace((2,), inferred) + + resolved = sc.resolve_context_priority(explicit, X) + + assert resolved == explicit + assert resolved.dtype == np.dtype(np.float64) + assert resolved.enable_checks is False + finally: + sc.set_context(original) + + +def test_resolve_context_priority_uses_inferred_ctx_only_without_explicit_ctx(): + sc = importlib.import_module("spacecore") + original = sc.get_context() + default = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=True) + inferred = sc.Context(sc.NumpyOps(), dtype=np.float32, enable_checks=False) + try: + sc.set_context(default) + X = sc.VectorSpace((2,), inferred) + + resolved = sc.resolve_context_priority(None, X) + + assert resolved.ops.family == inferred.ops.family + assert resolved.dtype == np.dtype(np.float32) + assert resolved.enable_checks is False + finally: + sc.set_context(original) + + +def test_resolve_context_priority_uses_default_only_without_explicit_or_inferred_ctx(): + sc = importlib.import_module("spacecore") + original = sc.get_context() + default = sc.Context(sc.NumpyOps(), dtype=np.float32, enable_checks=True) + try: + sc.set_context(default) + + assert sc.resolve_context_priority(None) == default + finally: + sc.set_context(original) + + def test_default_context_used_when_none_given(): sc = importlib.import_module("spacecore") original = sc.get_context() diff --git a/tests/context/test_enable_checks.py b/tests/context/test_enable_checks.py index 66e7686..c187116 100644 --- a/tests/context/test_enable_checks.py +++ b/tests/context/test_enable_checks.py @@ -2,7 +2,7 @@ import pytest import spacecore as sc -from spacecore._contextual.contextual import ContextConversionError +from spacecore._contextual import ContextConversionError from tests._helpers import has_jax, jax_real_dtype diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 3623b9b..d33314f 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -93,6 +93,46 @@ def test_linop_quadratic_value_and_gradient_match_euclidean_hand_computation(): assert np.allclose(q.hess_apply(x), [4.0, -4.0]) +def test_linop_quadratic_form_hermitian_gradient_is_q_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + Q = sc.DenseLinOp(ctx.asarray([[2.0, 1.0], [1.0, 4.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(Q, ctx=ctx) + x = ctx.asarray([2.0, -1.0]) + + np.testing.assert_allclose(q.grad(x), Q.apply(x)) + np.testing.assert_allclose(q.hess_apply(x), Q.apply(x)) + + +def test_linop_quadratic_form_rejects_non_hermitian_dense_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + Q = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [0.0, 3.0]]), space, space, ctx) + + with pytest.raises(ValueError, match="Hermitian"): + sc.LinOpQuadraticForm(Q, ctx=ctx) + + +def test_linop_quadratic_form_does_not_validate_matrix_free_hermitian_assumption(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + + def apply(x): + return ctx.asarray([x[0] + 2.0 * x[1], 3.0 * x[1]]) + + def rapply(y): + return ctx.asarray([y[0], 2.0 * y[0] + 3.0 * y[1]]) + + Q = sc.MatrixFreeLinOp(apply, rapply, space, space, ctx) + q = sc.LinOpQuadraticForm(Q, ctx=ctx) + x = ctx.asarray([1.0, 2.0]) + + np.testing.assert_allclose(q.grad(x), Q.apply(x)) + + def test_vvalue_and_vgrad_match_elementwise_loops(): ctx = _ctx() q = _quadratic_problem(ctx) diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index f352934..4df7035 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -48,7 +48,7 @@ def test_top_level_objects_match_source_modules(): linop = importlib.import_module("spacecore.linop") functional = importlib.import_module("spacecore.functional") linalg = importlib.import_module("spacecore.linalg") - manager = importlib.import_module("spacecore._contextual.manager") + contextual = importlib.import_module("spacecore._contextual") assert sc.Context is backend.Context assert sc.NumpyOps is backend.NumpyOps @@ -62,8 +62,8 @@ def test_top_level_objects_match_source_modules(): assert sc.Functional is functional.Functional assert sc.InnerProductFunctional is functional.InnerProductFunctional assert sc.StochasticLanczosResult is linalg.StochasticLanczosResult - assert sc.get_context is manager.get_context - assert sc.resolve_context_priority is manager.resolve_context_priority + assert sc.get_context is contextual.get_context + assert sc.resolve_context_priority is contextual.resolve_context_priority def test_package_version_matches_project_metadata(): diff --git a/tests/linalg/test_krylov.py b/tests/linalg/test_krylov.py index 40d14d2..07b8626 100644 --- a/tests/linalg/test_krylov.py +++ b/tests/linalg/test_krylov.py @@ -112,6 +112,45 @@ def rapply(y): assert calls["rapply"] > 0 +def test_cg_solves_complex_hermitian_positive_definite_system(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + space = sc.VectorSpace((2,), ctx) + matrix = np.array([[4.0, 1.0 + 1.0j], [1.0 - 1.0j, 3.0]], dtype=np.complex128) + A = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + b = ctx.asarray([1.0 + 2.0j, 3.0 - 1.0j]) + + result = sc.cg(A, b, tol=1e-10, maxiter=10) + + np.testing.assert_allclose(to_numpy(result.x), np.linalg.solve(matrix, to_numpy(b)), rtol=1e-8) + np.testing.assert_allclose(to_numpy(A.apply(result.x)), to_numpy(b), rtol=1e-8, atol=1e-8) + assert bool(to_numpy(result.converged)) + + +def test_lsqr_solves_complex_least_squares(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + matrix = np.array( + [[1.0 + 1.0j, 0.0], [0.0, 2.0 - 1.0j], [1.0, 1.0j]], + dtype=np.complex128, + ) + A = sc.DenseLinOp(ctx.asarray(matrix), domain, codomain, ctx) + b = ctx.asarray([1.0 - 1.0j, 2.0 + 0.5j, 3.0j]) + + result = sc.lsqr(A, b, tol=1e-10, maxiter=20) + + expected, *_ = np.linalg.lstsq(matrix, to_numpy(b), rcond=None) + np.testing.assert_allclose(to_numpy(result.x), expected, rtol=1e-7, atol=1e-7) + np.testing.assert_allclose( + to_numpy(A.H.apply(A.apply(result.x) - b)), + np.zeros(2, dtype=np.complex128), + atol=1e-7, + ) + assert bool(to_numpy(result.converged)) + + @pytest.mark.parametrize("backend_name,dtype", _backend_params()) def test_power_iteration_estimates_dominant_eigenpair(backend_name, dtype): sc = importlib.import_module("spacecore") @@ -254,6 +293,98 @@ def test_stochastic_lanczos_rejects_invalid_max_iter(): sc.stochastic_lanczos(op, ctx.asarray([1.0]), max_iter=0) +def test_stochastic_lanczos_handles_eigenvalues_larger_than_1e10(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + matrix = np.diag([2.0e12, 3.0e12]) + op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + eigenvalue, eigenvector = sc.stochastic_lanczos( + op, + ctx.asarray([1.0, 1.0]), + max_iter=4, + tol=1e-8, + ) + + np.testing.assert_allclose(to_numpy(eigenvalue), 2.0e12, rtol=1e-6) + np.testing.assert_allclose(np.abs(to_numpy(eigenvector)), [1.0, 0.0], atol=1e-5) + + +def test_stochastic_lanczos_handles_complex_hermitian_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + space = sc.VectorSpace((2,), ctx) + matrix = np.array([[2.0, 1.0 + 2.0j], [1.0 - 2.0j, 5.0]], dtype=np.complex128) + op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + eigenvalue, eigenvector = sc.stochastic_lanczos( + op, + ctx.asarray([1.0 + 0.0j, 1.0j]), + max_iter=2, + tol=1e-10, + ) + + expected = np.linalg.eigvalsh(matrix)[0] + np.testing.assert_allclose(to_numpy(eigenvalue), expected, rtol=1e-7, atol=1e-7) + np.testing.assert_allclose( + to_numpy(op.apply(eigenvector)), + to_numpy(eigenvalue) * to_numpy(eigenvector), + rtol=1e-6, + atol=1e-6, + ) + + +def test_stochastic_lanczos_uses_domain_geometry_for_weighted_inner_product(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + + class WeightedVectorSpace(sc.VectorSpace): + def __init__(self, weights, ctx): + weights = ctx.asarray(weights) + super().__init__(tuple(weights.shape), ctx) + self.weights = weights + + def inner(self, x, y): + if self._enable_checks: + self._check_member(x) + self._check_member(y) + return self.ops.vdot(x, self.weights * y) + + def _convert(self, new_ctx): + return WeightedVectorSpace(new_ctx.asarray(self.weights), new_ctx) + + space = WeightedVectorSpace([1.0, 4.0], ctx) + matrix = np.array([[2.0, 1.0], [0.25, 0.75]]) + op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + eigenvalue, eigenvector = sc.stochastic_lanczos( + op, + ctx.asarray([1.0, 1.0]), + max_iter=2, + tol=1e-12, + ) + + expected = np.min(np.linalg.eigvals(matrix).real) + np.testing.assert_allclose(to_numpy(eigenvalue), expected, rtol=1e-7, atol=1e-7) + np.testing.assert_allclose( + to_numpy(op.apply(eigenvector)), + to_numpy(eigenvalue) * to_numpy(eigenvector), + rtol=1e-6, + atol=1e-6, + ) + + +def test_safe_inverse_returns_reciprocal_for_positive_values_only(): + sc = importlib.import_module("spacecore") + utils = importlib.import_module("spacecore.linalg._utils") + ctx = _ctx() + + values = ctx.asarray([-2.0, 0.0, 4.0]) + + np.testing.assert_allclose(to_numpy(utils.safe_inverse(sc.NumpyOps(), values)), [0.0, 0.0, 0.25]) + + def test_iterative_solvers_poll_convergence_on_check_interval(): sc = importlib.import_module("spacecore") ctx = _ctx() diff --git a/tests/linops/test_linop_jit.py b/tests/linops/test_linop_jit.py index 2f30207..32452d6 100644 --- a/tests/linops/test_linop_jit.py +++ b/tests/linops/test_linop_jit.py @@ -38,6 +38,26 @@ def test_dense_linop_jit_apply_and_rapply_with_operator_argument(): np.testing.assert_allclose(to_numpy(rapply_jit(op, y)), [8., 10.]) +def test_decorated_apply_rapply_value_and_grad_jit_compile(): + jax = pytest.importorskip("jax") + + ctx = _jax_ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 1.0], [1.0, 4.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + x = ctx.asarray([3.0, -1.0]) + + apply_jit = jax.jit(lambda A, z: A.apply(z)) + rapply_jit = jax.jit(lambda A, z: A.rapply(z)) + value_jit = jax.jit(lambda functional, z: functional.value(z)) + grad_jit = jax.jit(lambda functional, z: functional.grad(z)) + + np.testing.assert_allclose(to_numpy(apply_jit(op, x)), to_numpy(op.apply(x))) + np.testing.assert_allclose(to_numpy(rapply_jit(op, x)), to_numpy(op.rapply(x))) + np.testing.assert_allclose(to_numpy(value_jit(q, x)), to_numpy(q.value(x))) + np.testing.assert_allclose(to_numpy(grad_jit(q, x)), to_numpy(q.grad(x))) + + def test_tensor_dense_linop_jit_preserves_shapes(): jax = pytest.importorskip("jax") From 21c4d7180e1e96c60dc174bb2ca21bfc4b08b230 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Fri, 22 May 2026 01:51:16 -0300 Subject: [PATCH 19/44] Refactor contextual checks and Lanczos geometry --- spacecore/linalg/_krylov.py | 16 ---------------- tests/linops/test_to_dense.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 16 deletions(-) delete mode 100644 spacecore/linalg/_krylov.py diff --git a/spacecore/linalg/_krylov.py b/spacecore/linalg/_krylov.py deleted file mode 100644 index 147ffb2..0000000 --- a/spacecore/linalg/_krylov.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -from ._cg import CGResult, cg -from ._lanczos import stochastic_lanczos -from ._lsqr import LSQRResult, lsqr -from ._power import PowerIterationResult, power_iteration - -__all__ = [ - "CGResult", - "LSQRResult", - "PowerIterationResult", - "cg", - "lsqr", - "power_iteration", - "stochastic_lanczos", -] diff --git a/tests/linops/test_to_dense.py b/tests/linops/test_to_dense.py index a60097d..a38ec49 100644 --- a/tests/linops/test_to_dense.py +++ b/tests/linops/test_to_dense.py @@ -153,6 +153,18 @@ def tree_unflatten(cls, aux, children): assert op.A["data"] is dense +def test_diagonal_linop_A_is_cached(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + + A = op.A + + assert op.A is A + assert np.allclose(A, np.diag([1.0, 2.0, 3.0])) + + def test_sum_linop_to_dense_matches_apply(): sc = importlib.import_module("spacecore") ctx = _ctx() From f958cc09ed91758b4a962c34241dcbc6751d238e Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Fri, 22 May 2026 02:10:50 -0300 Subject: [PATCH 20/44] Enrich docstrings and backend docs --- docs/source/api/backend.rst | 23 ++++ docs/source/api/linops.rst | 63 ++++++++++ docs/source/index.rst | 33 +++-- docs/source/tutorials/backend_ops.rst | 43 +++++-- docs/source/tutorials/linops.rst | 42 +++++++ spacecore/_checks.py | 23 +++- spacecore/_contextual/_bound.py | 126 +++++++++++++++++++ spacecore/functional/_linear.py | 109 +++++++++++++++- spacecore/linop/_algebra.py | 174 ++++++++++++++++++++++++-- 9 files changed, 601 insertions(+), 35 deletions(-) diff --git a/docs/source/api/backend.rst b/docs/source/api/backend.rst index 1b466cf..811b38a 100644 --- a/docs/source/api/backend.rst +++ b/docs/source/api/backend.rst @@ -41,6 +41,29 @@ JaxOps :show-inheritance: :exclude-members: jax, jnp, jsparse +CuPyOps +------- + +``CuPyOps`` is the optional CuPy backend implementation for GPU arrays and +``cupyx.scipy.sparse`` matrices. It is exported as ``spacecore.backend.CuPyOps`` +only when CuPy is installed in the environment. + +Install the optional backend before using it: + +.. code-block:: bash + + pip install spacecore[cupy] + +Use it through a normal SpaceCore context: + +.. code-block:: python + + import numpy as np + import spacecore as sc + + ctx = sc.Context(sc.CuPyOps(), dtype=np.float64) + x = ctx.asarray([1.0, 2.0, 3.0]) + TorchOps -------- diff --git a/docs/source/api/linops.rst b/docs/source/api/linops.rst index 2c59373..f0703f6 100644 --- a/docs/source/api/linops.rst +++ b/docs/source/api/linops.rst @@ -12,9 +12,18 @@ actions. spacecore.linop.DenseLinOp spacecore.linop.DiagonalLinOp spacecore.linop.SparseLinOp + spacecore.linop.MatrixFreeLinOp + spacecore.linop.IdentityLinOp + spacecore.linop.ZeroLinOp + spacecore.linop.ScaledLinOp + spacecore.linop.SumLinOp + spacecore.linop.ComposedLinOp spacecore.linop.BlockDiagonalLinOp spacecore.linop.StackedLinOp spacecore.linop.SumToSingleLinOp + spacecore.linop.make_scaled + spacecore.linop.make_sum + spacecore.linop.make_composed LinOp ----- @@ -61,6 +70,60 @@ SparseLinOp :inherited-members: :show-inheritance: +MatrixFreeLinOp +--------------- + +.. autoclass:: spacecore.linop.MatrixFreeLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +IdentityLinOp +------------- + +.. autoclass:: spacecore.linop.IdentityLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +ZeroLinOp +--------- + +.. autoclass:: spacecore.linop.ZeroLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +Algebraic operators +------------------- + +.. autoclass:: spacecore.linop.ScaledLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.linop.SumLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.linop.ComposedLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autofunction:: spacecore.linop.make_scaled + +.. autofunction:: spacecore.linop.make_sum + +.. autofunction:: spacecore.linop.make_composed + Product-structured operators ---------------------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index 470a78d..8a94a88 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -4,18 +4,19 @@ SpaceCore SpaceCore exists for writing numerical algorithms once, independently of the array backend. -For example, the same algorithm can run with NumPy for debugging, JAX for -JIT/autodiff, and Torch for tensor workflows, while preserving the same +For example, the same algorithm can run with NumPy for debugging, CuPy for +eager GPU execution, JAX for JIT/autodiff, and Torch for tensor workflows, +while preserving the same mathematical spaces and linear operators. What problem does SpaceCore solve? ---------------------------------- Numerical algorithms often start as clear NumPy code and later need to move to -JAX, Torch, or another array system. Without a backend boundary, that migration -usually leaks through the whole implementation: array constructors, dtype -handling, inner products, sparse support, and linear-operator conventions all -become backend-specific. +CuPy, JAX, Torch, or another array system. Without a backend boundary, that +migration usually leaks through the whole implementation: array constructors, +dtype handling, inner products, sparse support, and linear-operator conventions +all become backend-specific. SpaceCore keeps those choices in a ``Context``, while algorithms work with mathematical objects: @@ -38,7 +39,8 @@ Write once, run twice --------------------- This gradient descent loop uses only the ``Space`` and ``LinOp`` APIs. It does -not know whether the arrays are NumPy arrays, JAX arrays, or Torch tensors. +not know whether the arrays are NumPy arrays, CuPy arrays, JAX arrays, or Torch +tensors. .. code-block:: python @@ -141,7 +143,7 @@ Core concepts A ``Context`` specifies how objects are represented: -* backend operations (``NumpyOps``, ``JaxOps``, ``TorchOps``, etc.); +* backend operations (``NumpyOps``, ``CuPyOps``, ``JaxOps``, ``TorchOps``, etc.); * default dtype; * runtime validation behavior. @@ -171,7 +173,12 @@ hard-coding backend array operations. A ``LinOp`` represents a linear operator between spaces: * ``DenseLinOp`` for dense matrix or tensor operators; +* ``DiagonalLinOp`` for coordinatewise diagonal operators; * ``SparseLinOp`` for sparse operators; +* ``MatrixFreeLinOp`` for callable-backed operators without stored matrices; +* ``IdentityLinOp`` and ``ZeroLinOp`` for canonical identity and zero maps; +* ``ScaledLinOp``, ``SumLinOp``, and ``ComposedLinOp`` for lazy operator + algebra; * ``BlockDiagonalLinOp`` for block-diagonal product-space operators; * ``StackedLinOp`` for operators from one space into a product space; * ``SumToSingleLinOp`` for operators from a product space into one space. @@ -231,6 +238,12 @@ With JAX support: pip install "spacecore[jax]" +With CuPy support: + +.. code-block:: bash + + pip install "spacecore[cupy]" + With PyTorch support: .. code-block:: bash @@ -240,6 +253,10 @@ With PyTorch support: * ``spacecore[jax]`` installs optional JAX support. * GPU users should install the appropriate CUDA-enabled JAX build first, following the official JAX installation guide. +* ``spacecore[cupy]`` installs optional CuPy support for ``cupy.ndarray`` and + ``cupyx.scipy.sparse`` backends. +* GPU users should install the appropriate CUDA-enabled CuPy package first, + following the official CuPy installation guide. * ``spacecore[torch]`` installs optional PyTorch support for ``torch.Tensor`` backends. * GPU users should install the appropriate CUDA-enabled PyTorch build first, diff --git a/docs/source/tutorials/backend_ops.rst b/docs/source/tutorials/backend_ops.rst index 47b58ce..b0f6e83 100644 --- a/docs/source/tutorials/backend_ops.rst +++ b/docs/source/tutorials/backend_ops.rst @@ -5,9 +5,15 @@ This tutorial follows ``tutorials/1_BackendOps.ipynb``. It explains what ``BackendOps`` represents in SpaceCore, how it relates to ``Context``, and how to use the predefined backends. -Current predefined implementations are ``NumpyOps``, ``JaxOps``, and -``TorchOps``. ``TorchOps`` is optional and is available after installing the -PyTorch extra: +Current predefined implementations are ``NumpyOps``, ``JaxOps``, ``CuPyOps``, +and ``TorchOps``. ``CuPyOps`` and ``TorchOps`` are optional and are available +after installing their backend extras: + +.. code-block:: bash + + pip install spacecore[cupy] + +Install PyTorch support with: .. code-block:: bash @@ -31,9 +37,10 @@ SpaceCore separates two concerns: * the numerical backend used to store and compute with them. The same mathematical object may be represented using NumPy arrays for eager CPU -work, JAX arrays for JIT compilation and automatic differentiation, or PyTorch -tensors for eager CPU/CUDA execution and autograd. Without a backend abstraction, spaces and -operators would need backend-specific branches throughout their implementations. +work, CuPy arrays for eager GPU execution, JAX arrays for JIT compilation and +automatic differentiation, or PyTorch tensors for eager CPU/CUDA execution and +autograd. Without a backend abstraction, spaces and operators would need +backend-specific branches throughout their implementations. The design is: @@ -54,11 +61,12 @@ signatures that SpaceCore relies on. Common dense-array methods are implemented once in ``BackendOps`` by delegating to an Array API compatible ``xp`` namespace. NumPy and PyTorch use -``array-api-compat`` wrappers, while JAX uses ``jax.numpy``. Concrete backend -classes keep behavior that is genuinely backend-specific, such as dtype -sanitation, sparse conversion, indexed updates, device/autograd controls, and -control-flow primitives. ``ops.xp`` is available as an escape hatch, but -portable SpaceCore code should prefer explicit ``ops`` methods. +``array-api-compat`` wrappers, while CuPy uses ``cupy`` and JAX uses +``jax.numpy``. Concrete backend classes keep behavior that is genuinely +backend-specific, such as dtype sanitation, sparse conversion, indexed updates, +device/autograd controls, and control-flow primitives. ``ops.xp`` is available +as an escape hatch, but portable SpaceCore code should prefer explicit ``ops`` +methods. For example, NumPy and JAX expose different optional arguments for matrix multiplication, but SpaceCore's portable interface only needs the common core: @@ -147,6 +155,19 @@ Use ``JaxOps`` for the JAX execution model: JIT compilation, automatic differentiation, accelerator execution, and JAX sparse compatibility. JAX dtype behavior depends on local JAX configuration, especially ``jax_enable_x64``. +Use ``CuPyOps`` for eager GPU-backed CuPy arrays and ``cupyx.scipy.sparse`` +matrices. The backend is optional and is exported only when CuPy is installed. +It follows CuPy's NumPy-compatible dtype behavior and keeps arrays on the CuPy +device where they were created. + +.. code-block:: python + + import numpy as np + from spacecore.backend import Context, CuPyOps + + ctx_cupy = Context(CuPyOps(), dtype=np.float64) + x = ctx_cupy.asarray([1.0, 2.0, 3.0]) + Use ``TorchOps`` for PyTorch tensors. The backend can be requested by either ``"torch"`` or ``"pytorch"`` where SpaceCore accepts backend names. diff --git a/docs/source/tutorials/linops.rst b/docs/source/tutorials/linops.rst index 43fa474..2005b20 100644 --- a/docs/source/tutorials/linops.rst +++ b/docs/source/tutorials/linops.rst @@ -9,6 +9,12 @@ Current implemented operator types are: * ``DenseLinOp`` * ``DiagonalLinOp`` * ``SparseLinOp`` +* ``MatrixFreeLinOp`` +* ``IdentityLinOp`` +* ``ZeroLinOp`` +* ``ScaledLinOp`` +* ``SumLinOp`` +* ``ComposedLinOp`` * ``BlockDiagonalLinOp`` * ``StackedLinOp`` * ``SumToSingleLinOp`` @@ -140,6 +146,42 @@ of the operator structure. op_sparse = sc.SparseLinOp(A_sparse, X, Y, ctx=ctx) +MatrixFreeLinOp +--------------- + +``MatrixFreeLinOp`` stores callables for forward and adjoint actions instead +of matrix entries. Use it when a linear map has a fast procedural +implementation or when materializing a matrix is too expensive. + +.. code-block:: python + + def apply(x): + return ctx.asarray([x[0] + x[1], x[0] - x[1]]) + + def rapply(y): + return ctx.asarray([y[0] + y[1], y[0] - y[1]]) + + op_free = sc.MatrixFreeLinOp(apply, rapply, X, X, ctx=ctx) + +Canonical and algebraic operators +--------------------------------- + +``IdentityLinOp`` and ``ZeroLinOp`` represent the canonical identity and zero +maps on spaces. Operator algebra creates lazy operators without immediately +materializing dense storage: + +.. code-block:: python + + I = sc.IdentityLinOp(X, ctx=ctx) + Z = sc.ZeroLinOp(X, Y, ctx=ctx) + + scaled = 2.0 * I # ScaledLinOp + summed = I + scaled # SumLinOp + composed = summed @ I # ComposedLinOp + +The helper constructors ``make_scaled``, ``make_sum``, and ``make_composed`` +perform the same simplifications used by the Python operators. + Product operators ----------------- diff --git a/spacecore/_checks.py b/spacecore/_checks.py index d1b84f9..6bae056 100644 --- a/spacecore/_checks.py +++ b/spacecore/_checks.py @@ -10,7 +10,28 @@ def checked_method( out_space: str | None = None, arg_pos: int = 0, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - """Decorate methods with optional Space membership checks.""" + """ + Build a decorator that validates method inputs and outputs against spaces. + + Parameters + ---------- + in_space: + Name of the attribute on ``self`` containing the input + :class:`~spacecore.space.Space`, or ``None`` to skip input validation. + out_space: + Name of the attribute on ``self`` containing the output + :class:`~spacecore.space.Space`, or ``None`` to skip output validation. + arg_pos: + Zero-based position in ``*args`` of the input value that should be + checked against ``in_space``. + + Returns + ------- + Callable[[Callable[..., Any]], Callable[..., Any]] + Decorator that wraps a method, performs Python-level checks when + ``self._enable_checks`` is true, and otherwise forwards directly to the + wrapped method. + """ def decorate(method: Callable[..., Any]) -> Callable[..., Any]: @wraps(method) diff --git a/spacecore/_contextual/_bound.py b/spacecore/_contextual/_bound.py index 658d22d..9d61334 100644 --- a/spacecore/_contextual/_bound.py +++ b/spacecore/_contextual/_bound.py @@ -9,6 +9,22 @@ def _same_effective_context(left: Context, right: Context) -> bool: + """ + Compare two contexts for conversion equivalence. + + Parameters + ---------- + left: + First context to compare. + right: + Second context to compare. + + Returns + ------- + bool + ``True`` when both contexts use the same backend operations, dtype, and + check policy; otherwise ``False``. + """ return ( left.ops == right.ops and left.dtype == right.dtype @@ -17,26 +33,136 @@ def _same_effective_context(left: Context, right: Context) -> bool: class ContextBound(ABC): + """ + Base class for objects bound to a SpaceCore execution context. + + ``ContextBound`` normalizes and stores a :class:`~spacecore.backend.Context` + for subclasses such as spaces, linear operators, and functionals. It also + provides convenience access to the context's backend operations and dtype, + plus a common ``convert`` workflow that respects the global context + conversion policy. + + Subclasses that own backend arrays or nested context-bound objects must + implement :meth:`_convert` to rebuild themselves in a target context. + + Parameters + ---------- + ctx: + Context specification passed to :meth:`__init__`. This may be a + concrete :class:`~spacecore.backend.Context`, a backend-family string, + or ``None`` to use the current default context. + + Returns + ------- + ContextBound + A context-aware object whose concrete type is provided by a subclass. + """ + def __init__(self, ctx: Context | str | None = None): + """ + Initialize this object with a normalized context. + + Parameters + ---------- + ctx: + Context specification for the object. This may be a concrete + :class:`~spacecore.backend.Context`, a backend-family string, or + ``None`` to use the current default context. + + Returns + ------- + None + The initializer stores the normalized context on ``self``. + """ ctx = normalize_context(ctx) self._ctx = ctx @property def ops(self) -> BackendOps: + """ + Return backend operations associated with this object's context. + + Parameters + ---------- + None + + Returns + ------- + BackendOps + Backend operation object used by this instance. + """ return self.ctx.ops @property def dtype(self) -> DType: + """ + Return the default dtype associated with this object's context. + + Parameters + ---------- + None + + Returns + ------- + DType + Backend-normalized dtype stored in the bound context. + """ return self.ctx.dtype @property def ctx(self) -> Context: + """ + Return the execution context bound to this object. + + Parameters + ---------- + None + + Returns + ------- + Context + Context that controls backend operations, dtype, and validation + policy for this instance. + """ return self._ctx def _convert(self, new_ctx: Context) -> Self: + """ + Rebuild this object in ``new_ctx``. + + Subclasses implement this hook with their concrete conversion logic. + The public :meth:`convert` method handles policy enforcement and skips + conversion when the target context is effectively identical. + + Parameters + ---------- + new_ctx: + Concrete target context in which the subclass should rebuild its + owned arrays, spaces, operators, or nested context-bound objects. + + Returns + ------- + Self + New object of the subclass type represented in ``new_ctx``. + """ raise NotImplementedError() def convert(self, new_ctx: Context | BackendFamily | str | None = None) -> Self: + """ + Return this object represented in ``new_ctx``. + + Parameters + ---------- + new_ctx: + Target context specification. ``None`` resolves according to the + current conversion policy and default context. + + Returns + ------- + Self + ``self`` when no effective context change is needed; otherwise a + converted object produced by :meth:`_convert`. + """ _, new_ctx = enforce_convert_policy(self, new_ctx) if _same_effective_context(self.ctx, new_ctx): return self diff --git a/spacecore/functional/_linear.py b/spacecore/functional/_linear.py index e231220..487349f 100644 --- a/spacecore/functional/_linear.py +++ b/spacecore/functional/_linear.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any +from typing import Any, Callable from ._base import Domain, Functional from .._checks import checked_method @@ -93,16 +93,60 @@ class MatrixFreeLinearFunctional(LinearFunctional[Domain]): """ Linear functional defined by user-supplied evaluation callables. - No representer is stored or materialized. + ``MatrixFreeLinearFunctional(value, X)`` represents a linear scalar-valued + map on ``X`` without storing or materializing a Riesz representer. + + Parameters + ---------- + value: + Callable with signature ``value(x: Any) -> Any`` accepting an element of + ``dom`` and returning a scalar-like backend value. + dom: + Domain space of the functional. + ctx: + Optional context specification. An explicit context wins over inferred + and default contexts. + vvalue: + Optional callable with signature ``vvalue(xs: Any) -> Any`` for batched + evaluation. If omitted, backend ``vmap`` fallback is used. + + Returns + ------- + MatrixFreeLinearFunctional + Functional using the supplied callable for scalar evaluation and, + optionally, batched scalar evaluation. """ def __init__( self, - value: Any, + value: Callable[[Any], Any], dom: Domain, ctx: Context | str | None = None, - vvalue: Any | None = None, + vvalue: Callable[[Any], Any] | None = None, ) -> None: + """ + Initialize a matrix-free linear functional. + + Parameters + ---------- + value: + Callable ``value(x)`` accepting an element of ``dom`` and returning + a scalar-like value. + dom: + Domain space of the functional. + ctx: + Optional context specification for the functional and converted + domain. + vvalue: + Optional callable ``vvalue(xs)`` accepting a batch of domain + elements and returning a batch of scalar-like values. + + Returns + ------- + None + The initializer stores the callables and converted domain on + ``self``. + """ if not callable(value): raise TypeError(f"value must be callable, got {type(value).__name__}.") if vvalue is not None and not callable(vvalue): @@ -113,20 +157,59 @@ def __init__( @property def representer(self) -> Any: + """ + Raise because matrix-free functionals do not store a representer. + + Parameters + ---------- + None + + Returns + ------- + Any + This property never returns; it raises ``NotImplementedError``. + """ raise NotImplementedError( f"{type(self).__name__} does not store a Riesz representer." ) @checked_method(in_space="domain") def value(self, x: Any) -> Any: - """Return ``value_fn(x)``.""" + """ + Evaluate the scalar functional. + + Parameters + ---------- + x: + Element of ``self.domain`` passed to ``value_fn``. + + Returns + ------- + Any + Scalar-like backend value returned by ``value_fn``. + """ y = self.value_fn(x) if self._enable_checks: self._check_scalar_batch(y, ()) return y def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: - """Return ``vvalue_fn(xs)`` when supplied, otherwise use fallback batching.""" + """ + Evaluate the scalar functional over a batch of domain elements. + + Parameters + ---------- + xs: + Batched element of ``self.domain``. + batch_space: + Optional batch-space descriptor for ``xs``. + + Returns + ------- + Any + Backend array of scalar-like values with shape matching the leading + batch shape. + """ if self.vvalue_fn is None: return super().vvalue(xs, batch_space) in_space = self._input_batch_space(self.domain, xs, batch_space) @@ -158,6 +241,20 @@ def tree_unflatten(cls, aux, children): return cls(value_fn, domain, ctx, vvalue_fn) def _convert(self, new_ctx: Context) -> MatrixFreeLinearFunctional: + """ + Convert this functional to ``new_ctx``. + + Parameters + ---------- + new_ctx: + Concrete target context for the converted domain. + + Returns + ------- + MatrixFreeLinearFunctional + Functional with converted domain and the same user-supplied + callables. + """ return MatrixFreeLinearFunctional( self.value_fn, self.domain.convert(new_ctx), diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index 8b9c3b4..16d04b0 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -1,7 +1,7 @@ from __future__ import annotations from numbers import Number -from typing import Any, Sequence +from typing import Any, Callable, Sequence from ._base import LinOp, Domain, Codomain from .._checks import checked_method @@ -556,18 +556,80 @@ class MatrixFreeLinOp(LinOp[Domain, Codomain]): action is ``rapply(y) = rapply_fn(y)`` for ``y in Y``. When checks are enabled, inputs and callable outputs are validated against the corresponding domain and codomain. + + Parameters + ---------- + apply: + Callable with signature ``apply(x: Any) -> Any`` implementing the + forward map from ``dom`` to ``cod``. + rapply: + Callable with signature ``rapply(y: Any) -> Any`` implementing the + adjoint map from ``cod`` back to ``dom``. + dom: + Domain space containing valid inputs for ``apply`` and outputs from + ``rapply``. + cod: + Codomain space containing outputs from ``apply`` and valid inputs for + ``rapply``. + ctx: + Optional context specification. An explicit context wins over inferred + contexts from ``dom`` and ``cod``. + vapply: + Optional callable with signature ``vapply(xs: Any) -> Any`` for batched + forward application. If omitted, backend ``vmap`` fallback is used. + rvapply: + Optional callable with signature ``rvapply(ys: Any) -> Any`` for + batched adjoint application. If omitted, backend ``vmap`` fallback is + used. + + Returns + ------- + MatrixFreeLinOp + Operator using the supplied callables for forward, adjoint, and + optionally batched actions. """ def __init__( self, - apply: Any, - rapply: Any, + apply: Callable[[Any], Any], + rapply: Callable[[Any], Any], dom: Domain, cod: Codomain, ctx: Context | str | None = None, - vapply: Any | None = None, - rvapply: Any | None = None, + vapply: Callable[[Any], Any] | None = None, + rvapply: Callable[[Any], Any] | None = None, ) -> None: + """ + Initialize a matrix-free linear operator. + + Parameters + ---------- + apply: + Callable ``apply(x)`` that accepts an element of ``dom`` and returns + an element of ``cod``. + rapply: + Callable ``rapply(y)`` that accepts an element of ``cod`` and + returns an element of ``dom``. + dom: + Domain space of the operator. + cod: + Codomain space of the operator. + ctx: + Optional context specification for the operator and converted + spaces. + vapply: + Optional callable for batched forward application over ``dom`` + batches. + rvapply: + Optional callable for batched adjoint application over ``cod`` + batches. + + Returns + ------- + None + The initializer stores the callables and converted spaces on + ``self``. + """ if not callable(apply): raise TypeError(f"apply must be callable, got {type(apply).__name__}.") if not callable(rapply): @@ -584,22 +646,87 @@ def __init__( @checked_method(in_space="domain", out_space="codomain") def apply(self, x: Any) -> Any: - """Return ``apply_fn(x)``.""" + """ + Apply the forward callable. + + Parameters + ---------- + x: + Element of ``self.domain`` passed to ``apply_fn``. + + Returns + ------- + Any + Element of ``self.codomain`` returned by ``apply_fn``. + """ return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: + """ + Apply ``apply_fn`` without membership checks. + + Parameters + ---------- + x: + Value accepted by the user-supplied forward callable. + + Returns + ------- + Any + Raw forward-callable output. + """ return self.apply_fn(x) @checked_method(in_space="codomain", out_space="domain") def rapply(self, y: Any) -> Any: - """Return ``rapply_fn(y)``.""" + """ + Apply the adjoint callable. + + Parameters + ---------- + y: + Element of ``self.codomain`` passed to ``rapply_fn``. + + Returns + ------- + Any + Element of ``self.domain`` returned by ``rapply_fn``. + """ return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: + """ + Apply ``rapply_fn`` without membership checks. + + Parameters + ---------- + y: + Value accepted by the user-supplied adjoint callable. + + Returns + ------- + Any + Raw adjoint-callable output. + """ return self.rapply_fn(y) def vapply(self, xs: Any, batch_space=None) -> Any: - """Return ``vapply_fn(xs)`` when supplied, otherwise use fallback batching.""" + """ + Apply this operator to a batch of domain elements. + + Parameters + ---------- + xs: + Batched element of ``self.domain``. + batch_space: + Optional batch-space descriptor for ``xs``. + + Returns + ------- + Any + Batched element of ``self.codomain`` produced by ``vapply_fn`` or + by the fallback batching implementation. + """ if self.vapply_fn is None: return super().vapply(xs, batch_space) in_space = self._input_batch_space(self.domain, xs, batch_space) @@ -611,7 +738,22 @@ def vapply(self, xs: Any, batch_space=None) -> Any: return ys def rvapply(self, ys: Any, batch_space=None) -> Any: - """Return ``rvapply_fn(ys)`` when supplied, otherwise use fallback batching.""" + """ + Apply the adjoint operator to a batch of codomain elements. + + Parameters + ---------- + ys: + Batched element of ``self.codomain``. + batch_space: + Optional batch-space descriptor for ``ys``. + + Returns + ------- + Any + Batched element of ``self.domain`` produced by ``rvapply_fn`` or by + the fallback batching implementation. + """ if self.rvapply_fn is None: return super().rvapply(ys, batch_space) in_space = self._input_batch_space(self.codomain, ys, batch_space) @@ -653,6 +795,20 @@ def tree_unflatten(cls, aux, children): return cls(apply_fn, rapply_fn, domain, codomain, ctx, vapply_fn, rvapply_fn) def _convert(self, new_ctx: Context) -> MatrixFreeLinOp: + """ + Convert this matrix-free operator to ``new_ctx``. + + Parameters + ---------- + new_ctx: + Concrete target context for converted domain and codomain spaces. + + Returns + ------- + MatrixFreeLinOp + Operator with converted spaces and the same user-supplied + callables. + """ return MatrixFreeLinOp( self.apply_fn, self.rapply_fn, From 9197c23d618c79376a5f300bc6163cc18711a7c1 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Fri, 22 May 2026 02:31:19 -0300 Subject: [PATCH 21/44] Fix contextual import cycle --- spacecore/_contextual/__init__.py | 2 ++ spacecore/_contextual/_bound.py | 6 ++-- spacecore/_contextual/_manager.py | 49 +++++++++++++++++++++---------- spacecore/backend/_context.py | 8 +++-- 4 files changed, 46 insertions(+), 19 deletions(-) diff --git a/spacecore/_contextual/__init__.py b/spacecore/_contextual/__init__.py index bb86e5d..5d319e1 100644 --- a/spacecore/_contextual/__init__.py +++ b/spacecore/_contextual/__init__.py @@ -5,6 +5,7 @@ get_dtype_resolution_policy as get_dtype_resolution_policy, get_resolution_policy as get_resolution_policy, normalize_context as normalize_context, + normalize_ops as normalize_ops, register_ops as register_ops, resolve_context_priority as resolve_context_priority, set_context as set_context, @@ -35,6 +36,7 @@ "get_dtype_resolution_policy", "get_resolution_policy", "normalize_context", + "normalize_ops", "register_ops", "resolve_context_priority", "set_context", diff --git a/spacecore/_contextual/_bound.py b/spacecore/_contextual/_bound.py index 9d61334..2548bfa 100644 --- a/spacecore/_contextual/_bound.py +++ b/spacecore/_contextual/_bound.py @@ -1,12 +1,14 @@ from __future__ import annotations from abc import ABC -from typing import Self +from typing import TYPE_CHECKING, Self -from ..backend import Context, BackendOps, BackendFamily from ..types import DType from ._manager import enforce_convert_policy, normalize_context +if TYPE_CHECKING: + from ..backend import BackendFamily, BackendOps, Context + def _same_effective_context(left: Context, right: Context) -> bool: """ diff --git a/spacecore/_contextual/_manager.py b/spacecore/_contextual/_manager.py index 104f0a4..9e0efc9 100644 --- a/spacecore/_contextual/_manager.py +++ b/spacecore/_contextual/_manager.py @@ -1,9 +1,19 @@ -from typing import Any +from __future__ import annotations -from ..backend import Context, BackendOps +from typing import TYPE_CHECKING, Any + +from ..backend._family import BackendFamily +from ..backend._ops import BackendOps from ._policies import ContextPolicy, DtypePreservePolicy -from ._state import _contextual -from ..backend import BackendFamily + +if TYPE_CHECKING: + from ..backend._context import Context + + +def _state(): + from ._state import _contextual + + return _contextual def set_context( @@ -33,8 +43,8 @@ def set_context( Objects created without an explicit context use this default context. Existing spaces, operators, and contexts are not modified. """ - ctx = _contextual.normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) - _contextual.default_ctx = ctx + ctx = _state().normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) + _state().default_ctx = ctx def get_context() -> Context: @@ -47,7 +57,7 @@ def get_context() -> Context: The default context used by constructors when no explicit context can be inferred or provided. """ - return _contextual.default_ctx + return _state().default_ctx def resolve_context_priority( @@ -78,7 +88,7 @@ def resolve_context_priority( User code should call this function instead of accessing the internal context manager singleton. """ - return _contextual.resolve_context_priority(priority_ctx, *other_ctx) + return _state().resolve_context_priority(priority_ctx, *other_ctx) def register_ops(ops: type[BackendOps]) -> type[BackendOps]: @@ -113,7 +123,7 @@ def register_ops(ops: type[BackendOps]) -> type[BackendOps]: class MyOps(BackendOps): ... """ - return _contextual.register_ops(ops) + return _state().register_ops(ops) def normalize_context( @@ -122,7 +132,16 @@ def normalize_context( enable_checks: bool | None = None, ) -> Context: """Normalize a context specification through the process-wide state.""" - return _contextual.normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) + return _state().normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) + + +def normalize_ops( + ops: str | BackendFamily | BackendOps | type[BackendOps] | Context +) -> BackendOps: + """Normalize backend operations through the process-wide state.""" + if isinstance(ops, BackendOps): + return ops + return _state().get_ops(ops) def enforce_convert_policy( @@ -130,7 +149,7 @@ def enforce_convert_policy( to: Context | BackendFamily | str | None = None, ) -> tuple[Any, Context]: """Resolve a conversion target and enforce the configured policy.""" - return _contextual.enforce_convert_policy(x, to) + return _state().enforce_convert_policy(x, to) def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None: @@ -156,7 +175,7 @@ def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None: * ``"error"``: reject backend conversion. * ``"silent"``: allow backend conversion without warning. """ - _contextual.resolution_policy = policy + _state().resolution_policy = policy def get_resolution_policy() -> str: @@ -168,7 +187,7 @@ def get_resolution_policy() -> str: str Policy name, one of ``"warning"``, ``"error"``, or ``"silent"``. """ - return _contextual.resolution_policy.value + return _state().resolution_policy.value def set_dtype_resolution_policy( @@ -196,7 +215,7 @@ def set_dtype_resolution_policy( equivalent dtype in the target backend. * ``"convert"``: use the dtype provided by the resolved target context. """ - _contextual.dtype_resolution_policy = policy + _state().dtype_resolution_policy = policy def get_dtype_resolution_policy() -> str: @@ -208,4 +227,4 @@ def get_dtype_resolution_policy() -> str: str Policy name, one of ``"keep_native"`` or ``"convert"``. """ - return _contextual.dtype_resolution_policy.value + return _state().dtype_resolution_policy.value diff --git a/spacecore/backend/_context.py b/spacecore/backend/_context.py index 92b5fd6..689aa43 100644 --- a/spacecore/backend/_context.py +++ b/spacecore/backend/_context.py @@ -3,6 +3,7 @@ from ._ops import BackendOps from ..types import DenseArray, SparseArray, DType, ArrayLike +from .._contextual import normalize_ops @dataclass(frozen=True, slots=True) @@ -52,8 +53,11 @@ def __post_init__(self): TypeError If ``ops`` is not a :class:`BackendOps` instance. """ - if not isinstance(self.ops, BackendOps): - raise TypeError("ops must be a BackendOps") + try: + ops = normalize_ops(self.ops) + except TypeError: + raise TypeError("Unknown ops type.") + object.__setattr__(self, "ops", ops) sanitized = self.ops.sanitize_dtype(self.dtype) object.__setattr__(self, "dtype", sanitized) From f7eb7c887485c7445a3f8f62fe5f55a26bfd8d05 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Fri, 22 May 2026 02:38:57 -0300 Subject: [PATCH 22/44] Fix contextual import cycle --- spacecore/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spacecore/__init__.py b/spacecore/__init__.py index e190147..b755cc5 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -72,6 +72,7 @@ register_ops, set_resolution_policy, set_dtype_resolution_policy, get_resolution_policy, get_dtype_resolution_policy, + normalize_ops, normalize_context, ) __all__ = [ @@ -144,6 +145,8 @@ "set_dtype_resolution_policy", "get_resolution_policy", "get_dtype_resolution_policy", + "normalize_ops", + "normalize_context", ] if "TorchOps" in globals(): From b3ab5ad744c03576ed1337f2b40cc84492d11d26 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 01:57:50 -0300 Subject: [PATCH 23/44] Prepare v0.2 API cleanup --- .github/workflows/ci.yml | 2 +- pyproject.toml | 5 +- spacecore/__init__.py | 15 +++- spacecore/_contextual/_bound.py | 36 +++++++- spacecore/_contextual/_manager.py | 9 +- spacecore/backend/_ops.py | 43 ++++++++++ spacecore/functional/__init__.py | 3 + spacecore/functional/_base.py | 23 +++++- spacecore/functional/_composed.py | 99 ++++++++++++++++++++++ spacecore/functional/_quadratic.py | 17 +--- spacecore/linalg/__init__.py | 4 +- spacecore/linalg/_cg.py | 6 +- spacecore/linalg/_lanczos.py | 93 +++++++++++++-------- spacecore/linalg/_lsqr.py | 12 +-- spacecore/linalg/_utils.py | 20 +++-- spacecore/linop/_algebra.py | 63 +++++++++++--- spacecore/linop/_base.py | 12 +++ spacecore/linop/_dense.py | 19 ++++- spacecore/linop/_diagonal.py | 48 ++++++++--- spacecore/linop/_sparse.py | 19 ++++- tests/functional/test_functional.py | 85 +++++++++++++++++++ tests/integration/test_public_api.py | 7 +- tests/linalg/test_krylov.py | 79 ++++++++++-------- tests/linops/test_algebra.py | 21 +++++ tests/linops/test_algebra_linop.py | 8 +- tests/linops/test_diagonal_linop.py | 118 +++++++++++++++++++++++++++ 26 files changed, 729 insertions(+), 137 deletions(-) create mode 100644 spacecore/functional/_composed.py create mode 100644 tests/linops/test_diagonal_linop.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 413e173..ca7ab8a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,7 @@ jobs: - run: python -m pip install --upgrade pip - run: pip install -e ".[jax,torch,dev]" - - run: pytest + - run: pytest --cov=spacecore --cov-report=term-missing --cov-fail-under=70 - run: ruff check . publish: diff --git a/pyproject.toml b/pyproject.toml index 91f9bc4..a76d717 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "spacecore" -version = "0.1.4" +version = "0.2.0" description = "Backend-agnostic vector spaces and linear operators." readme = "README.md" requires-python = ">=3.11" @@ -60,7 +60,8 @@ docs = [ ] dev = [ "pytest>=8.0", - "ruff>=0.6", + "pytest-cov>=5", + "ruff>=0.6", ] [tool.setuptools] diff --git a/spacecore/__init__.py b/spacecore/__init__.py index b755cc5..16d922f 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -1,4 +1,9 @@ -__version__ = "0.1.4" +from importlib.metadata import version as _version + +try: + __version__ = _version("spacecore") +except Exception: + __version__ = "0.0.0+unknown" from .backend import Context, BackendOps, JaxOps, NumpyOps, jax_pytree_class @@ -29,19 +34,23 @@ make_sum, ) from .functional import ( + ComposedFunctional, Functional, InnerProductFunctional, LinearFunctional, LinOpQuadraticForm, MatrixFreeLinearFunctional, QuadraticForm, + make_functional_composed, ) from .linalg import ( CGResult, + LanczosResult, LSQRResult, PowerIterationResult, StochasticLanczosResult, cg, + lanczos_smallest, lsqr, power_iteration, stochastic_lanczos, @@ -100,18 +109,22 @@ "SumToSingleLinOp", "StackedLinOp", + "ComposedFunctional", "Functional", "LinearFunctional", "InnerProductFunctional", "MatrixFreeLinearFunctional", "QuadraticForm", "LinOpQuadraticForm", + "make_functional_composed", "CGResult", + "LanczosResult", "LSQRResult", "PowerIterationResult", "StochasticLanczosResult", "cg", + "lanczos_smallest", "lsqr", "power_iteration", "stochastic_lanczos", diff --git a/spacecore/_contextual/_bound.py b/spacecore/_contextual/_bound.py index 2548bfa..cce24e0 100644 --- a/spacecore/_contextual/_bound.py +++ b/spacecore/_contextual/_bound.py @@ -10,9 +10,9 @@ from ..backend import BackendFamily, BackendOps, Context -def _same_effective_context(left: Context, right: Context) -> bool: +def _same_context_for_conversion(left: Context, right: Context) -> bool: """ - Compare two contexts for conversion equivalence. + Compare contexts for conversion equivalence. Parameters ---------- @@ -26,6 +26,12 @@ def _same_effective_context(left: Context, right: Context) -> bool: bool ``True`` when both contexts use the same backend operations, dtype, and check policy; otherwise ``False``. + + Notes + ----- + This predicate is used by ``convert()`` and intentionally includes + ``enable_checks`` because a converted object with different runtime checks + is operationally different. """ return ( left.ops == right.ops @@ -34,6 +40,30 @@ def _same_effective_context(left: Context, right: Context) -> bool: ) +def _same_context_for_algebra(left: Context, right: Context) -> bool: + """ + Compare contexts for algebraic compatibility. + + Parameters + ---------- + left: + First context to compare. + right: + Second context to compare. + + Returns + ------- + bool + ``True`` when both contexts use the same backend operations and dtype. + + Notes + ----- + Algebraic combinators ignore ``enable_checks`` because validation policy is + operational, not mathematical. + """ + return left.ops == right.ops and left.dtype == right.dtype + + class ContextBound(ABC): """ Base class for objects bound to a SpaceCore execution context. @@ -166,6 +196,6 @@ def convert(self, new_ctx: Context | BackendFamily | str | None = None) -> Self: converted object produced by :meth:`_convert`. """ _, new_ctx = enforce_convert_policy(self, new_ctx) - if _same_effective_context(self.ctx, new_ctx): + if _same_context_for_conversion(self.ctx, new_ctx): return self return self._convert(new_ctx) diff --git a/spacecore/_contextual/_manager.py b/spacecore/_contextual/_manager.py index 9e0efc9..efea35f 100644 --- a/spacecore/_contextual/_manager.py +++ b/spacecore/_contextual/_manager.py @@ -10,10 +10,17 @@ from ..backend._context import Context +_cached_state = None + + def _state(): + global _cached_state + if _cached_state is not None: + return _cached_state from ._state import _contextual - return _contextual + _cached_state = _contextual + return _cached_state def set_context( diff --git a/spacecore/backend/_ops.py b/spacecore/backend/_ops.py index fd8eb31..f5362f7 100644 --- a/spacecore/backend/_ops.py +++ b/spacecore/backend/_ops.py @@ -248,6 +248,49 @@ def eps(self, dtype: DType) -> float: """Machine epsilon for dtype.""" return float(self.xp.finfo(self.sanitize_dtype(dtype)).eps) + def is_complex_dtype(self, dtype: DType) -> bool: + """ + Return whether ``dtype`` is a complex floating type. + + Parameters + ---------- + dtype: + Backend or portable dtype specifier to inspect. + + Returns + ------- + bool + ``True`` when ``dtype`` represents complex floating values. + """ + dtype = self.sanitize_dtype(dtype) + return getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex") + + def real_dtype(self, dtype: DType) -> DType: + """ + Return the real floating dtype with the same precision as ``dtype``. + + Parameters + ---------- + dtype: + Backend or portable dtype specifier. + + Returns + ------- + DType + ``dtype`` itself when it is already real-valued; otherwise + ``float32`` for complex64 and ``float64`` for complex128. + """ + dtype = self.sanitize_dtype(dtype) + if not self.is_complex_dtype(dtype): + return dtype + itemsize = getattr(dtype, "itemsize", None) + if itemsize is None: + dtype_text = str(dtype) + if "complex64" in dtype_text: + return self.sanitize_dtype("float32") + return self.sanitize_dtype("float64") + return self.sanitize_dtype("float32" if itemsize <= 8 else "float64") + def get_dtype(self, x: Any) -> DType: """Return x.dtype after verifying x is a backend array.""" if self.is_array(x): diff --git a/spacecore/functional/__init__.py b/spacecore/functional/__init__.py index 19bae03..8209cf1 100644 --- a/spacecore/functional/__init__.py +++ b/spacecore/functional/__init__.py @@ -1,12 +1,15 @@ from ._base import Functional +from ._composed import ComposedFunctional, make_functional_composed from ._linear import InnerProductFunctional, LinearFunctional, MatrixFreeLinearFunctional from ._quadratic import LinOpQuadraticForm, QuadraticForm __all__ = [ + "ComposedFunctional", "Functional", "InnerProductFunctional", "LinearFunctional", "LinOpQuadraticForm", "MatrixFreeLinearFunctional", "QuadraticForm", + "make_functional_composed", ] diff --git a/spacecore/functional/_base.py b/spacecore/functional/_base.py index 95c2f7a..1ef833e 100644 --- a/spacecore/functional/_base.py +++ b/spacecore/functional/_base.py @@ -1,12 +1,15 @@ from __future__ import annotations from abc import abstractmethod -from typing import Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from .._contextual import ContextBound, resolve_context_priority from ..backend import Context from ..space import Space +if TYPE_CHECKING: + from ..linop import LinOp + Domain = TypeVar("Domain", bound=Space) @@ -46,6 +49,24 @@ def __call__(self, x: Any) -> Any: """Evaluate this functional at ``x``.""" return self.value(x) + def compose(self, A: "LinOp") -> "Functional": + """ + Return the pull-back ``self o A``. + + Parameters + ---------- + A: + Linear operator whose codomain matches this functional's domain. + + Returns + ------- + Functional + Functional on ``A.domain`` evaluating ``self.value(A.apply(x))``. + """ + from ._composed import make_functional_composed + + return make_functional_composed(self, A) + def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: """Evaluate this functional independently over leading batch axes.""" return self._fallback_vvalue(xs, batch_space) diff --git a/spacecore/functional/_composed.py b/spacecore/functional/_composed.py new file mode 100644 index 0000000..d018176 --- /dev/null +++ b/spacecore/functional/_composed.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from typing import Any + +from ._base import Functional +from ._linear import InnerProductFunctional +from ._quadratic import LinOpQuadraticForm +from .._checks import checked_method +from ..backend import Context, jax_pytree_class +from ..linop import LinOp + + +def _require_composable(F: Functional, A: LinOp) -> None: + if not isinstance(F, Functional): + raise TypeError(f"F must be a Functional, got {type(F).__name__}.") + if not isinstance(A, LinOp): + raise TypeError(f"A must be a LinOp, got {type(A).__name__}.") + if A.codomain != F.domain: + raise ValueError( + "Functional composition requires A.codomain == F.domain; " + f"got {A.codomain!r} and {F.domain!r}." + ) + + +def make_functional_composed(F: Functional, A: LinOp) -> Functional: + """ + Return the pull-back ``F o A`` with local specializations. + + Parameters + ---------- + F: + Functional defined on ``A.codomain``. + A: + Linear operator whose codomain is ``F.domain``. + + Returns + ------- + Functional + Specialized pull-back when available, otherwise + :class:`ComposedFunctional`. + """ + _require_composable(F, A) + if isinstance(F, InnerProductFunctional): + return InnerProductFunctional(A.H.apply(F.representer), A.domain, A.ctx) + if isinstance(F, LinOpQuadraticForm): + Q = A.H @ F.Q @ A + linear = None if F.linear is None else F.linear.compose(A) + return LinOpQuadraticForm(Q, linear, F.a, A.ctx) + return ComposedFunctional(F, A) + + +@jax_pytree_class +class ComposedFunctional(Functional): + """ + Generic pull-back of a functional through a linear operator. + + ``ComposedFunctional(F, A)`` represents ``x -> F(A x)`` on ``A.domain``. + """ + + def __init__(self, F: Functional, A: LinOp) -> None: + _require_composable(F, A) + super().__init__(A.domain, A.ctx) + self.F = F.convert(A.ctx) + self.A = A + + @checked_method(in_space="domain") + def value(self, x: Any) -> Any: + """ + Evaluate ``F(A x)``. + + Parameters + ---------- + x: + Element of ``A.domain``. + + Returns + ------- + Any + Scalar-like value returned by the composed functional. + """ + return self.F.value(self.A.apply(x)) + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.F == other.F and self.A == other.A + return False + + def tree_flatten(self): + children = (self.F, self.A) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + F, A = children + return cls(F, A) + + def _convert(self, new_ctx: Context) -> ComposedFunctional: + return ComposedFunctional(self.F.convert(new_ctx), self.A.convert(new_ctx)) diff --git a/spacecore/functional/_quadratic.py b/spacecore/functional/_quadratic.py index 6f77f5d..5387f75 100644 --- a/spacecore/functional/_quadratic.py +++ b/spacecore/functional/_quadratic.py @@ -7,7 +7,7 @@ from .._checks import checked_method from .._contextual import resolve_context_priority from ..backend import Context, jax_pytree_class -from ..linop import DenseLinOp, DiagonalLinOp, LinOp +from ..linop import LinOp from ..space import Space @@ -80,21 +80,12 @@ def __init__( self.Q = Q self.linear = linear self.a = self.ctx.asarray(a) - if self._enable_checks: - self._check_scalar_batch(self.a, ()) + self._check_scalar_batch(self.a, ()) @staticmethod def _check_hermitian_structure(Q: LinOp[Domain, Domain]) -> None: - try: - if isinstance(Q, DenseLinOp): - is_hermitian = Q.ops.allclose(Q._A2, Q._A2H) - elif isinstance(Q, DiagonalLinOp): - is_hermitian = Q.ops.allclose(Q.diagonal, Q._diag_adjoint) - else: - return - except Exception: - return - if not is_hermitian: + result = Q.is_hermitian() + if result is False: raise ValueError("LinOpQuadraticForm requires Q to be Hermitian/self-adjoint.") @checked_method(in_space="domain") diff --git a/spacecore/linalg/__init__.py b/spacecore/linalg/__init__.py index 06be398..03837cc 100644 --- a/spacecore/linalg/__init__.py +++ b/spacecore/linalg/__init__.py @@ -1,16 +1,18 @@ from __future__ import annotations from ._cg import CGResult, cg -from ._lanczos import StochasticLanczosResult, stochastic_lanczos +from ._lanczos import LanczosResult, StochasticLanczosResult, lanczos_smallest, stochastic_lanczos from ._lsqr import LSQRResult, lsqr from ._power import PowerIterationResult, power_iteration __all__ = [ "CGResult", + "LanczosResult", "LSQRResult", "PowerIterationResult", "StochasticLanczosResult", "cg", + "lanczos_smallest", "lsqr", "power_iteration", "stochastic_lanczos", diff --git a/spacecore/linalg/_cg.py b/spacecore/linalg/_cg.py index 775b46c..a29d157 100644 --- a/spacecore/linalg/_cg.py +++ b/spacecore/linalg/_cg.py @@ -5,7 +5,7 @@ from ..linop import LinOp from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter from ._utils import is_converged, real_inner, require_linop, require_square -from ._utils import result_repr, safe_inverse, should_check_iteration, threshold +from ._utils import result_repr, safe_inverse_nonneg, should_check_iteration, threshold class CGResult(NamedTuple): @@ -73,11 +73,11 @@ def body_fun(carry: tuple[Any, Any, Any, Any, Any, int]) -> tuple[Any, Any, Any, Ap = A.apply(p) pAp = real_inner(A.domain, p, Ap) active = (rs > eps) & (pAp > eps) - alpha = A.ops.where(active, rs * safe_inverse(A.ops, pAp), A.ops.zeros_like(rs)) + alpha = A.ops.where(active, rs * safe_inverse_nonneg(A.ops, pAp), A.ops.zeros_like(rs)) x_next = A.domain.axpy(alpha, p, x) r_next = A.codomain.axpy(-alpha, Ap, r) rs_next = real_inner(A.domain, r_next, r_next) - beta = A.ops.where(active, rs_next * safe_inverse(A.ops, rs), A.ops.zeros_like(rs_next)) + beta = A.ops.where(active, rs_next * safe_inverse_nonneg(A.ops, rs), A.ops.zeros_like(rs_next)) p_next = A.domain.axpy(beta, p, r_next) k_next = k + 1 current_residual_norm = A.domain.norm(r_next) diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index b14322f..0f91802 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -1,33 +1,42 @@ from __future__ import annotations from typing import Any, NamedTuple +from warnings import warn -import numpy as np from ..linop import LinOp from ..types import DenseArray from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval -from ._utils import require_linop, require_square, safe_inverse, should_check_iteration +from ._utils import require_linop, require_square, safe_inverse_nonneg, should_check_iteration from ._utils import result_repr -class StochasticLanczosResult(NamedTuple): - """Result returned by :func:`stochastic_lanczos`.""" +class LanczosResult(NamedTuple): + """Result returned by :func:`lanczos_smallest`.""" eigenvalue: Any eigenvector: Any + residual_norm: Any + krylov_dim: Any + converged: Any def __repr__(self) -> str: """Return a compact summary without printing the full eigenvector.""" return result_repr( - "StochasticLanczosResult", + "LanczosResult", { "eigenvalue": self.eigenvalue, "eigenvector": self.eigenvector, + "residual_norm": self.residual_norm, + "krylov_dim": self.krylov_dim, + "converged": self.converged, }, ) +StochasticLanczosResult = LanczosResult + + def _check_lanczos_max_iter(max_iter: int) -> int: max_iter = int(max_iter) if max_iter < 1: @@ -35,29 +44,14 @@ def _check_lanczos_max_iter(max_iter: int) -> int: return max_iter -def _real_dtype(ctx: Any) -> Any: - dtype_text = str(ctx.dtype) - if "complex128" in dtype_text: - return ctx.ops.sanitize_dtype(np.float64) - if "complex64" in dtype_text: - return ctx.ops.sanitize_dtype(np.float32) - try: - dtype = np.dtype(ctx.dtype) - except TypeError: - return ctx.dtype - if dtype.kind == "c": - return ctx.ops.sanitize_dtype(np.float64 if dtype.itemsize > 8 else np.float32) - return ctx.dtype - - -def stochastic_lanczos( +def lanczos_smallest( A: LinOp, initial_vector: Any, *, max_iter: int = 100, tol: float = 1e-6, check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, -) -> StochasticLanczosResult: +) -> LanczosResult: r"""Approximate the smallest eigenpair of a Hermitian operator. The operator is supplied as a square ``LinOp`` and ``initial_vector`` is an @@ -82,18 +76,21 @@ def stochastic_lanczos( this many iterations, and always on the final iteration. Returns: - ``StochasticLanczosResult`` containing the smallest approximated - eigenpair. The result supports tuple unpacking as - ``eigenvalue, eigenvector``. + ``LanczosResult`` containing the smallest approximated eigenpair, the + standard Ritz residual estimate ``beta[m] * abs(y[m - 1])``, the + Krylov dimension reached, and a convergence flag. The residual estimate + is computed from the tridiagonal recurrence; callers that need the true + residual can evaluate ``A.apply(eigenvector) - eigenvalue * eigenvector`` + once more in the original space. """ A = require_linop(A) - require_square(A, "stochastic_lanczos") + require_square(A, "lanczos_smallest") max_iter = _check_lanczos_max_iter(max_iter) check_every = check_interval(check_every) A.domain.check_member(initial_vector) ops = A.ops ctx = A.ctx - real_dtype = _real_dtype(ctx) + real_dtype = ops.real_dtype(ctx.dtype) v0 = A.domain.flatten(initial_vector) v0 = ctx.assert_dense(v0) @@ -112,12 +109,12 @@ def stochastic_lanczos( e0 = ops.index_set(e0, (0,), ctx.asarray(1.0), copy=True) e0_member = A.domain.unflatten(e0) e0_norm = A.domain.norm(e0_member) - e0_unit = A.domain.flatten(A.domain.scale(safe_inverse(ops, e0_norm), e0_member)) + e0_unit = A.domain.flatten(A.domain.scale(safe_inverse_nonneg(ops, e0_norm), e0_member)) v0_unit = ops.cond( v0_norm > eps_s, lambda _: A.domain.flatten( - A.domain.scale(safe_inverse(ops, v0_norm), initial_vector) + A.domain.scale(safe_inverse_nonneg(ops, v0_norm), initial_vector) ), lambda _: e0_unit, ops.asarray(0.0, dtype=real_dtype), @@ -180,7 +177,7 @@ def fill_coeff(j: int, coeffs_in: DenseArray) -> DenseArray: betas_ = ops.index_set(betas_, (i + 1,), beta_new, copy=True) def set_next(V_in: DenseArray) -> DenseArray: - w_unit = A.domain.flatten(A.domain.scale(safe_inverse(ops, beta_new), w_member)) + w_unit = A.domain.flatten(A.domain.scale(safe_inverse_nonneg(ops, beta_new), w_member)) return ops.index_set(V_in, (i + 1, slice(None)), w_unit, copy=True) V_ = ops.cond(beta_new >= tol_s, set_next, lambda V_in: V_in, V_) @@ -225,6 +222,8 @@ def fill_off(ii: int, T_in: DenseArray) -> DenseArray: _eigvals, eigvecs = ops.eigh(T) y_full = eigvecs[:, 0] + residual_norm = betas[m] * ops.abs(y_full[m - 1]) + converged = residual_norm < tol_s mask_y = ops.where( idx < m, @@ -241,7 +240,7 @@ def fill_off(ii: int, T_in: DenseArray) -> DenseArray: x_norm = A.domain.norm(x_member) x_flat = ops.cond( x_norm > eps_s, - lambda _: A.domain.flatten(A.domain.scale(safe_inverse(ops, x_norm), x_member)), + lambda _: A.domain.flatten(A.domain.scale(safe_inverse_nonneg(ops, x_norm), x_member)), lambda _: e0_unit, ops.asarray(0.0, dtype=real_dtype), ) @@ -253,4 +252,34 @@ def fill_off(ii: int, T_in: DenseArray) -> DenseArray: den = ops.real(A.domain.inner(x, x)) lam = num / den - return StochasticLanczosResult(lam, x) + return LanczosResult(lam, x, residual_norm, m, converged) + + +def stochastic_lanczos( + A: LinOp, + initial_vector: Any, + *, + max_iter: int = 100, + tol: float = 1e-6, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> LanczosResult: + """ + Deprecated alias for :func:`lanczos_smallest`. + + Returns + ------- + LanczosResult + Result from :func:`lanczos_smallest`. + """ + warn( + "stochastic_lanczos is deprecated; use lanczos_smallest instead.", + DeprecationWarning, + stacklevel=2, + ) + return lanczos_smallest( + A, + initial_vector, + max_iter=max_iter, + tol=tol, + check_every=check_every, + ) diff --git a/spacecore/linalg/_lsqr.py b/spacecore/linalg/_lsqr.py index 8494df0..cdd73c0 100644 --- a/spacecore/linalg/_lsqr.py +++ b/spacecore/linalg/_lsqr.py @@ -4,7 +4,7 @@ from ..linop import LinOp from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter -from ._utils import is_converged, require_linop, safe_inverse, should_check_iteration +from ._utils import is_converged, require_linop, safe_inverse_nonneg, should_check_iteration from ._utils import result_repr, threshold @@ -63,10 +63,10 @@ def lsqr( beta = A.codomain.norm(residual) normal_residual_norm = A.domain.norm(A.H.apply(residual)) u = residual - u = A.codomain.scale(safe_inverse(A.ops, beta), u) + u = A.codomain.scale(safe_inverse_nonneg(A.ops, beta), u) v = A.H.apply(u) alpha = A.domain.norm(v) - v = A.domain.scale(safe_inverse(A.ops, alpha), v) + v = A.domain.scale(safe_inverse_nonneg(A.ops, alpha), v) w = v phi_bar = beta rho_bar = alpha @@ -81,14 +81,14 @@ def body_fun(carry: tuple[Any, ...]) -> tuple[Any, ...]: x, u, v, w, alpha, _beta, rho_bar, phi_bar, _residual_norm, _normal_residual, k = carry u_next = A.codomain.axpy(-alpha, u, A.apply(v)) beta_next = A.codomain.norm(u_next) - u_next = A.codomain.scale(safe_inverse(A.ops, beta_next), u_next) + u_next = A.codomain.scale(safe_inverse_nonneg(A.ops, beta_next), u_next) v_next = A.domain.axpy(-beta_next, v, A.H.apply(u_next)) alpha_next = A.domain.norm(v_next) - v_next = A.domain.scale(safe_inverse(A.ops, alpha_next), v_next) + v_next = A.domain.scale(safe_inverse_nonneg(A.ops, alpha_next), v_next) rho = A.ops.sqrt(rho_bar * rho_bar + beta_next * beta_next) - inv_rho = safe_inverse(A.ops, rho) + inv_rho = safe_inverse_nonneg(A.ops, rho) c = rho_bar * inv_rho s = beta_next * inv_rho theta = s * alpha_next diff --git a/spacecore/linalg/_utils.py b/spacecore/linalg/_utils.py index d9af97e..32a75d3 100644 --- a/spacecore/linalg/_utils.py +++ b/spacecore/linalg/_utils.py @@ -64,8 +64,14 @@ def is_converged(residual_norm: Any, threshold_value: Any) -> Any: return residual_norm <= threshold_value -def safe_inverse(ops: Any, value: Any) -> Any: - """Return ``1 / value`` where positive and zero otherwise.""" +def safe_inverse_nonneg(ops: Any, value: Any) -> Any: + """ + Return ``1 / value`` where ``value > 0`` and zero otherwise. + + This helper is intended for norms and nonnegative residual magnitudes. It + is not a general scalar inverse: for example, ``-2`` maps to ``0``, not + ``-0.5``. + """ positive = value > 0 safe_value = ops.where(positive, value, ops.ones_like(value)) return ops.where(positive, 1.0 / safe_value, ops.zeros_like(value)) @@ -74,14 +80,16 @@ def safe_inverse(ops: Any, value: Any) -> Any: def normalize(space: Any, x: Any) -> tuple[Any, Any]: """Normalize a space member and return ``(unit, norm)``.""" norm = space.norm(x) - return space.scale(safe_inverse(space.ops, norm), x), norm + return space.scale(safe_inverse_nonneg(space.ops, norm), x), norm def default_initial_vector(A: LinOp) -> Any: - """Return a deterministic nonzero initial vector for ``A.domain``.""" + """Return a deterministic unit vector in ``A.domain`` using its geometry.""" size = prod(A.domain.shape) - flat = A.ops.ones((size,), dtype=A.dtype) / A.ops.sqrt(A.ops.asarray(float(size))) - return A.domain.unflatten(flat) + flat = A.ops.ones((size,), dtype=A.dtype) + v = A.domain.unflatten(flat) + norm = A.domain.norm(v) + return A.domain.scale(safe_inverse_nonneg(A.ops, norm), v) def summarize_value(value: Any) -> str: diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index 16d04b0..4cf6846 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -5,6 +5,7 @@ from ._base import LinOp, Domain, Codomain from .._checks import checked_method +from .._contextual._bound import _same_context_for_algebra from ..backend import Context, jax_pytree_class @@ -27,18 +28,10 @@ def _conjugate_scalar(value: Any) -> Any: return value -def _same_context(left: LinOp, right: LinOp) -> bool: - return ( - left.ctx == right.ctx - and left.ctx.dtype == right.ctx.dtype - and left.ctx.enable_checks == right.ctx.enable_checks - ) - - def _require_same_context(ops: Sequence[LinOp]) -> Context: ctx = ops[0].ctx for i, op in enumerate(ops[1:], start=1): - if not _same_context(ops[0], op): + if not _same_context_for_algebra(ops[0].ctx, op.ctx): raise ValueError( "All LinOp operands in an algebraic expression must have the same ctx; " f"operand 0 has ctx {ctx!r}, operand {i} has ctx {op.ctx!r}." @@ -46,6 +39,22 @@ def _require_same_context(ops: Sequence[LinOp]) -> Context: return ctx +def _same_space_for_algebra(left: Any, right: Any) -> bool: + if type(left) is not type(right): + return False + if tuple(left.shape) != tuple(right.shape): + return False + if not _same_context_for_algebra(left.ctx, right.ctx): + return False + left_parts = getattr(left, "spaces", None) + right_parts = getattr(right, "spaces", None) + if left_parts is not None or right_parts is not None: + if left_parts is None or right_parts is None or len(left_parts) != len(right_parts): + return False + return all(_same_space_for_algebra(a, b) for a, b in zip(left_parts, right_parts)) + return True + + def _require_linop(op: Any, name: str) -> LinOp: if not isinstance(op, LinOp): raise TypeError(f"{name} must be a LinOp, got {type(op).__name__}.") @@ -96,7 +105,10 @@ def make_sum(ops: Sequence[LinOp]) -> LinOp: domain = terms[0].domain codomain = terms[0].codomain for i, op in enumerate(terms[1:], start=1): - if op.domain != domain or op.codomain != codomain: + if ( + not _same_space_for_algebra(op.domain, domain) + or not _same_space_for_algebra(op.codomain, codomain) + ): raise ValueError( "All SumLinOp operands must have the same domain and codomain; " f"operand 0 maps {domain!r} -> {codomain!r}, " @@ -150,7 +162,7 @@ def make_composed(left: LinOp, right: LinOp) -> LinOp: left = _require_linop(left, "left") right = _require_linop(right, "right") _require_same_context((left, right)) - if right.codomain != left.domain: + if not _same_space_for_algebra(right.codomain, left.domain): raise ValueError( "ComposedLinOp requires right.codomain == left.domain; " f"got {right.codomain!r} and {left.domain!r}." @@ -251,7 +263,10 @@ def __init__(self, ops: Sequence[LinOp[Domain, Codomain]]) -> None: domain = parts[0].domain codomain = parts[0].codomain for i, op in enumerate(parts[1:], start=1): - if op.domain != domain or op.codomain != codomain: + if ( + not _same_space_for_algebra(op.domain, domain) + or not _same_space_for_algebra(op.codomain, codomain) + ): raise ValueError( "All SumLinOp operands must have the same domain and codomain; " f"operand 0 maps {domain!r} -> {codomain!r}, " @@ -336,7 +351,7 @@ def __init__(self, left: LinOp, right: LinOp) -> None: left = _require_linop(left, "left") right = _require_linop(right, "right") _require_same_context((left, right)) - if right.codomain != left.domain: + if not _same_space_for_algebra(right.codomain, left.domain): raise ValueError( "ComposedLinOp requires right.codomain == left.domain; " f"got {right.codomain!r} and {left.domain!r}." @@ -446,6 +461,17 @@ def to_dense(self) -> Any: """ return self.ops.zeros(tuple(self.codomain.shape) + tuple(self.domain.shape), dtype=self.dtype) + def is_hermitian(self) -> bool: + """ + Return whether the zero map is Hermitian. + + Returns + ------- + bool + ``True`` exactly when domain and codomain are the same space. + """ + return self.domain == self.codomain + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.domain == other.domain and self.codomain == other.codomain @@ -523,6 +549,17 @@ def to_dense(self) -> Any: eye = self.ops.eye(size, dtype=self.dtype) return self.ops.reshape(eye, tuple(self.codomain.shape) + tuple(self.domain.shape)) + def is_hermitian(self) -> bool: + """ + Return whether this identity operator is Hermitian. + + Returns + ------- + bool + Always ``True``. + """ + return True + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.domain == other.domain diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index dda40ed..4cff12d 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -87,6 +87,18 @@ def adjoint_apply(self, y: Any) -> Any: """Apply the adjoint of this linear operator to ``y``.""" return self.rapply(y) + def is_hermitian(self) -> bool | None: + """ + Return whether this operator is structurally Hermitian when known. + + Returns + ------- + bool | None + ``True`` or ``False`` when the subclass can verify the structure + cheaply, otherwise ``None`` for unknown or matrix-free operators. + """ + return None + def vapply(self, xs: Any, batch_space: Space | None = None) -> Any: """Apply this operator independently over a batch of domain elements.""" return self._fallback_vapply(xs, batch_space) diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index d2fa2ce..b133aed 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -48,7 +48,7 @@ def __init__(self, self._matrix_shape = (self._cod_size, self._dom_size) self._A2 = self.A.reshape(self._matrix_shape) dtype = self.ops.get_dtype(self.A) - is_complex = getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex") + is_complex = self.ops.is_complex_dtype(dtype) self._A2T = self._A2.T self._A2H = self._A2.T.conj() if is_complex else self._A2.T self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,) @@ -193,6 +193,23 @@ def to_dense(self) -> DenseArray: """ return self.A + def is_hermitian(self) -> bool | None: + """ + Return whether this dense operator is structurally Hermitian. + + Returns + ------- + bool + ``True`` when the operator is square and its flattened matrix + equals its conjugate transpose. + """ + if self.dom != self.cod: + return False + try: + return bool(self.ops.allclose(self._A2, self._A2H)) + except Exception: + return None + def __eq__(self, x: Any) -> bool: if type(x) is type(self): return (self.dom == x.dom diff --git a/spacecore/linop/_diagonal.py b/spacecore/linop/_diagonal.py index a27abab..389c9cb 100644 --- a/spacecore/linop/_diagonal.py +++ b/spacecore/linop/_diagonal.py @@ -31,14 +31,9 @@ def __init__( if tuple(diagonal.shape) != expected: raise TypeError(f"Expected diagonal.shape == space.shape == {expected}, got {diagonal.shape}") self.diagonal = diagonal - self._size = prod(self.domain.shape) - self._is_flat = tuple(self.domain.shape) == (self._size,) - self._diag_flat = diagonal if self._is_flat else diagonal.reshape((self._size,)) dtype = self.ops.get_dtype(diagonal) - self._is_complex = getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex") - self._diag_adjoint = self.ops.conj(diagonal) if self._is_complex else diagonal - self._diag_adjoint_flat = ( - self._diag_adjoint if self._is_flat else self._diag_adjoint.reshape((self._size,)) + self._diag_adjoint = ( + self.ops.conj(diagonal) if self.ops.is_complex_dtype(dtype) else diagonal ) @cached_property @@ -53,11 +48,22 @@ def apply(self, x: DenseArray) -> DenseArray: def rapply(self, y: DenseArray) -> DenseArray: return self._diag_adjoint * y + def _reshape_diagonal_for_batch(self, diagonal: DenseArray, batch_space: Any) -> DenseArray: + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + total_ndim = len(self.domain.shape) + len(batch_shape) + base_axes = [axis for axis in range(total_ndim) if axis not in batch_axes] + shape = [1] * total_ndim + for axis, dim in zip(base_axes, self.domain.shape, strict=True): + shape[axis] = dim + return self.ops.reshape(diagonal, tuple(shape)) + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: in_space = self._input_batch_space(self.domain, xs, batch_space) if self._enable_checks: in_space._check_member(xs) - ys = self.diagonal * xs + diagonal = self._reshape_diagonal_for_batch(self.diagonal, in_space) + ys = diagonal * xs if self._enable_checks: self._output_batch_space(self.codomain, in_space)._check_member(ys) return ys @@ -66,15 +72,31 @@ def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: in_space = self._input_batch_space(self.codomain, ys, batch_space) if self._enable_checks: in_space._check_member(ys) - xs = self._diag_adjoint * ys + diagonal = self._reshape_diagonal_for_batch(self._diag_adjoint, in_space) + xs = diagonal * ys if self._enable_checks: self._output_batch_space(self.domain, in_space)._check_member(xs) return xs def to_dense(self) -> DenseArray: - matrix = self.ops.diag(self._diag_flat) + flat = self.diagonal.reshape((prod(self.domain.shape),)) + matrix = self.ops.diag(flat) return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape)) + def is_hermitian(self) -> bool | None: + """ + Return whether this diagonal operator is structurally Hermitian. + + Returns + ------- + bool + ``True`` when the diagonal equals its complex conjugate. + """ + try: + return bool(self.ops.allclose(self.diagonal, self._diag_adjoint)) + except Exception: + return None + def __eq__(self, other: Any) -> bool: if type(other) is type(self): return self.domain == other.domain and self.ops.allclose(self.diagonal, other.diagonal) @@ -91,4 +113,8 @@ def tree_unflatten(cls, aux, children): return cls(children[0], domain, ctx) def _convert(self, new_ctx: Context) -> DiagonalLinOp: - return DiagonalLinOp(new_ctx.asarray(self.diagonal), self.domain.convert(new_ctx), new_ctx) + return DiagonalLinOp( + new_ctx.asarray(self.diagonal), + VectorSpace(tuple(self.domain.shape), new_ctx), + new_ctx, + ) diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index a691bf3..236418c 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -43,7 +43,7 @@ def __init__(self, self._cod_size = expected[0] self._dom_size = expected[1] dtype = self.ops.get_dtype(self.A) - self._A_is_complex = getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex") + self._A_is_complex = self.ops.is_complex_dtype(dtype) self._AT = self.A.T self._AH = self._AT.conj() if self._A_is_complex else self._AT self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,) @@ -200,6 +200,23 @@ def to_dense(self) -> DenseArray: dense = super().to_dense().reshape((self._cod_size, self._dom_size)) return self.ops.reshape(dense, tuple(self.codomain.shape) + tuple(self.domain.shape)) + def is_hermitian(self) -> bool | None: + """ + Return whether this sparse operator is structurally Hermitian. + + Returns + ------- + bool + ``True`` when the operator is square and its sparse matrix equals + its conjugate transpose within backend tolerances. + """ + if self.dom != self.cod: + return False + try: + return bool(self.ops.allclose_sparse(self.A, self._AH)) + except Exception: + return None + def __eq__(self, x: Any) -> bool: if type(x) is type(self): return (self.dom == x.dom diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index d33314f..af826d8 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -83,6 +83,81 @@ def test_matrix_free_linear_functional_has_no_representer(): functional.representer +def test_linear_functional_compose_specializes_to_inner_product_functional(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]), X, Y, ctx) + c = ctx.asarray([2.0, -1.0, 0.5]) + F = sc.InnerProductFunctional(c, Y, ctx) + pullback = F.compose(A) + x = ctx.asarray([4.0, -2.0]) + + assert isinstance(pullback, sc.InnerProductFunctional) + np.testing.assert_allclose(pullback.representer, A.H.apply(c)) + np.testing.assert_allclose(pullback.value(x), F.value(A.apply(x))) + + +def test_quadratic_form_compose_specializes_quadratic_and_linear_terms(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]), X, Y, ctx) + Q = sc.IdentityLinOp(Y, ctx) + linear = sc.InnerProductFunctional(ctx.asarray([1.0, -2.0, 0.5]), Y, ctx) + F = sc.LinOpQuadraticForm(Q, linear, 1.25, ctx) + pullback = F.compose(A) + x = ctx.asarray([0.5, -1.5]) + + assert isinstance(pullback, sc.LinOpQuadraticForm) + np.testing.assert_allclose(pullback.value(x), F.value(A.apply(x))) + np.testing.assert_allclose(pullback.grad(x), A.H.apply(F.grad(A.apply(x)))) + + +def test_generic_functional_compose_forwards_value(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((2,), ctx) + A = sc.DiagonalLinOp(ctx.asarray([2.0, -1.0]), X, ctx) + + class SumSquares(sc.Functional): + def value(self, x): + return self.ops.sum(x * x) + + def tree_flatten(self): + return (), (self.domain, self.ctx) + + @classmethod + def tree_unflatten(cls, aux, children): + domain, ctx = aux + return cls(domain, ctx) + + def _convert(self, new_ctx): + return SumSquares(self.domain.convert(new_ctx), new_ctx) + + F = SumSquares(Y, ctx) + pullback = F.compose(A) + x = ctx.asarray([3.0, 4.0]) + + assert isinstance(pullback, sc.ComposedFunctional) + np.testing.assert_allclose(pullback.value(x), F.value(A.apply(x))) + + +def test_functional_compose_rejects_incompatible_codomain(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + A = sc.IdentityLinOp(X, ctx) + F = sc.InnerProductFunctional(ctx.asarray([1.0, 2.0, 3.0]), Y, ctx) + + with pytest.raises(ValueError, match="A.codomain == F.domain"): + F.compose(A) + + def test_linop_quadratic_value_and_gradient_match_euclidean_hand_computation(): ctx = _ctx() q = _quadratic_problem(ctx) @@ -133,6 +208,16 @@ def rapply(y): np.testing.assert_allclose(q.grad(x), Q.apply(x)) +def test_linop_quadratic_form_always_rejects_nonscalar_constant(): + sc = importlib.import_module("spacecore") + ctx = _ctx(enable_checks=False) + space = sc.VectorSpace((2,), ctx) + Q = sc.IdentityLinOp(space, ctx) + + with pytest.raises(ValueError, match="scalar batch"): + sc.LinOpQuadraticForm(Q, a=ctx.asarray([0.0, 0.0]), ctx=ctx) + + def test_vvalue_and_vgrad_match_elementwise_loops(): ctx = _ctx() q = _quadratic_problem(ctx) diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index 4df7035..31778f2 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -24,13 +24,14 @@ def test_expected_names_are_exported(): "make_composed", "BlockDiagonalLinOp", "StackedLinOp", "SumToSingleLinOp", "Functional", "LinearFunctional", "InnerProductFunctional", - "MatrixFreeLinearFunctional", "QuadraticForm", "LinOpQuadraticForm", + "ComposedFunctional", "MatrixFreeLinearFunctional", "QuadraticForm", + "LinOpQuadraticForm", "make_functional_composed", "VectorSpace", "HermitianSpace", "ProductSpace", "Space", "DenseArray", "SparseArray", "ArrayLike", "set_context", "get_context", "resolve_context_priority", "register_ops", "set_resolution_policy", "set_dtype_resolution_policy", "get_resolution_policy", "get_dtype_resolution_policy", - "StochasticLanczosResult", + "LanczosResult", "StochasticLanczosResult", "lanczos_smallest", } if has_jax(): expected |= {"JaxOps", "jax_pytree_class"} @@ -60,7 +61,9 @@ def test_top_level_objects_match_source_modules(): assert sc.VectorSpace is space.VectorSpace assert sc.DenseLinOp is linop.DenseLinOp assert sc.Functional is functional.Functional + assert sc.ComposedFunctional is functional.ComposedFunctional assert sc.InnerProductFunctional is functional.InnerProductFunctional + assert sc.LanczosResult is linalg.LanczosResult assert sc.StochasticLanczosResult is linalg.StochasticLanczosResult assert sc.get_context is contextual.get_context assert sc.resolve_context_priority is contextual.resolve_context_priority diff --git a/tests/linalg/test_krylov.py b/tests/linalg/test_krylov.py index 07b8626..a6fd825 100644 --- a/tests/linalg/test_krylov.py +++ b/tests/linalg/test_krylov.py @@ -228,97 +228,107 @@ def test_power_iteration_core_has_no_dispatch_logic(): @pytest.mark.parametrize("backend_name,dtype", _backend_params()) -def test_stochastic_lanczos_approximates_smallest_eigenpair(backend_name, dtype): +def test_lanczos_smallest_approximates_smallest_eigenpair(backend_name, dtype): sc = importlib.import_module("spacecore") ctx = _ctx(backend_name, dtype) space = sc.VectorSpace((2,), ctx) op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) initial = ctx.asarray([1.0, 1.0]) - eigenvalue, eigenvector = sc.stochastic_lanczos( + result = sc.lanczos_smallest( op, initial, max_iter=2, tol=1e-8, ) - np.testing.assert_allclose(to_numpy(eigenvalue), 2.0, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(to_numpy(result.eigenvalue), 2.0, rtol=1e-5, atol=1e-5) np.testing.assert_allclose( - np.abs(to_numpy(eigenvector)), + np.abs(to_numpy(result.eigenvector)), [1.0, 0.0], rtol=1e-5, atol=1e-5, ) + assert bool(to_numpy(result.converged)) + np.testing.assert_allclose(to_numpy(result.residual_norm), 0.0, atol=1e-5) + assert int(to_numpy(result.krylov_dim)) == 2 -def test_stochastic_lanczos_returns_result_object(): +def test_lanczos_smallest_returns_result_object_and_deprecated_alias_warns(): sc = importlib.import_module("spacecore") ctx = _ctx() space = sc.VectorSpace((2,), ctx) op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) - result = sc.stochastic_lanczos(op, ctx.asarray([1.0, 1.0]), max_iter=2, tol=1e-8) - eigenvalue, eigenvector = result + result = sc.lanczos_smallest(op, ctx.asarray([1.0, 1.0]), max_iter=2, tol=1e-8) + assert isinstance(result, sc.LanczosResult) assert isinstance(result, sc.StochasticLanczosResult) - np.testing.assert_allclose(eigenvalue, result.eigenvalue) - np.testing.assert_allclose(eigenvector, result.eigenvector) + with pytest.warns(DeprecationWarning, match="lanczos_smallest"): + alias_result = sc.stochastic_lanczos( + op, + ctx.asarray([1.0, 1.0]), + max_iter=2, + tol=1e-8, + ) + np.testing.assert_allclose(alias_result.eigenvalue, result.eigenvalue) + np.testing.assert_allclose(alias_result.eigenvector, result.eigenvector) -def test_stochastic_lanczos_uses_e0_for_zero_initial_vector(): +def test_lanczos_smallest_uses_e0_for_zero_initial_vector(): sc = importlib.import_module("spacecore") ctx = _ctx() space = sc.VectorSpace((2,), ctx) op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) initial = ctx.asarray([0.0, 0.0]) - eigenvalue, eigenvector = sc.stochastic_lanczos( + result = sc.lanczos_smallest( op, initial, max_iter=2, tol=1e-8, ) - np.testing.assert_allclose(eigenvalue, 2.0, rtol=1e-6, atol=1e-6) - np.testing.assert_allclose(eigenvector, [1.0, 0.0], rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(result.eigenvalue, 2.0, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(result.eigenvector, [1.0, 0.0], rtol=1e-6, atol=1e-6) -def test_stochastic_lanczos_rejects_invalid_max_iter(): +def test_lanczos_smallest_rejects_invalid_max_iter(): sc = importlib.import_module("spacecore") ctx = _ctx() space = sc.VectorSpace((1,), ctx) op = sc.IdentityLinOp(space, ctx) with pytest.raises(ValueError, match="max_iter"): - sc.stochastic_lanczos(op, ctx.asarray([1.0]), max_iter=0) + sc.lanczos_smallest(op, ctx.asarray([1.0]), max_iter=0) -def test_stochastic_lanczos_handles_eigenvalues_larger_than_1e10(): +def test_lanczos_smallest_handles_eigenvalues_larger_than_1e10(): sc = importlib.import_module("spacecore") ctx = _ctx() space = sc.VectorSpace((2,), ctx) matrix = np.diag([2.0e12, 3.0e12]) op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) - eigenvalue, eigenvector = sc.stochastic_lanczos( + result = sc.lanczos_smallest( op, ctx.asarray([1.0, 1.0]), max_iter=4, tol=1e-8, ) - np.testing.assert_allclose(to_numpy(eigenvalue), 2.0e12, rtol=1e-6) - np.testing.assert_allclose(np.abs(to_numpy(eigenvector)), [1.0, 0.0], atol=1e-5) + np.testing.assert_allclose(to_numpy(result.eigenvalue), 2.0e12, rtol=1e-6) + np.testing.assert_allclose(np.abs(to_numpy(result.eigenvector)), [1.0, 0.0], atol=1e-5) -def test_stochastic_lanczos_handles_complex_hermitian_operator(): +def test_lanczos_smallest_handles_complex_hermitian_operator(): sc = importlib.import_module("spacecore") ctx = _ctx(dtype=np.complex128) space = sc.VectorSpace((2,), ctx) matrix = np.array([[2.0, 1.0 + 2.0j], [1.0 - 2.0j, 5.0]], dtype=np.complex128) op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) - eigenvalue, eigenvector = sc.stochastic_lanczos( + result = sc.lanczos_smallest( op, ctx.asarray([1.0 + 0.0j, 1.0j]), max_iter=2, @@ -326,16 +336,16 @@ def test_stochastic_lanczos_handles_complex_hermitian_operator(): ) expected = np.linalg.eigvalsh(matrix)[0] - np.testing.assert_allclose(to_numpy(eigenvalue), expected, rtol=1e-7, atol=1e-7) + np.testing.assert_allclose(to_numpy(result.eigenvalue), expected, rtol=1e-7, atol=1e-7) np.testing.assert_allclose( - to_numpy(op.apply(eigenvector)), - to_numpy(eigenvalue) * to_numpy(eigenvector), + to_numpy(op.apply(result.eigenvector)), + to_numpy(result.eigenvalue) * to_numpy(result.eigenvector), rtol=1e-6, atol=1e-6, ) -def test_stochastic_lanczos_uses_domain_geometry_for_weighted_inner_product(): +def test_lanczos_smallest_uses_domain_geometry_for_weighted_inner_product(): sc = importlib.import_module("spacecore") ctx = _ctx() @@ -358,7 +368,7 @@ def _convert(self, new_ctx): matrix = np.array([[2.0, 1.0], [0.25, 0.75]]) op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) - eigenvalue, eigenvector = sc.stochastic_lanczos( + result = sc.lanczos_smallest( op, ctx.asarray([1.0, 1.0]), max_iter=2, @@ -366,23 +376,23 @@ def _convert(self, new_ctx): ) expected = np.min(np.linalg.eigvals(matrix).real) - np.testing.assert_allclose(to_numpy(eigenvalue), expected, rtol=1e-7, atol=1e-7) + np.testing.assert_allclose(to_numpy(result.eigenvalue), expected, rtol=1e-7, atol=1e-7) np.testing.assert_allclose( - to_numpy(op.apply(eigenvector)), - to_numpy(eigenvalue) * to_numpy(eigenvector), + to_numpy(op.apply(result.eigenvector)), + to_numpy(result.eigenvalue) * to_numpy(result.eigenvector), rtol=1e-6, atol=1e-6, ) -def test_safe_inverse_returns_reciprocal_for_positive_values_only(): +def test_safe_inverse_nonneg_returns_reciprocal_for_positive_values_only(): sc = importlib.import_module("spacecore") utils = importlib.import_module("spacecore.linalg._utils") ctx = _ctx() values = ctx.asarray([-2.0, 0.0, 4.0]) - np.testing.assert_allclose(to_numpy(utils.safe_inverse(sc.NumpyOps(), values)), [0.0, 0.0, 0.25]) + np.testing.assert_allclose(to_numpy(utils.safe_inverse_nonneg(sc.NumpyOps(), values)), [0.0, 0.0, 0.25]) def test_iterative_solvers_poll_convergence_on_check_interval(): @@ -469,7 +479,7 @@ def test_power_iteration_jit_compiles_with_quadratic_form_argument(): @pytest.mark.skipif(not has_jax(), reason="jax is not installed") -def test_stochastic_lanczos_jit_compiles_with_operator_argument(): +def test_lanczos_smallest_jit_compiles_with_operator_argument(): jax = pytest.importorskip("jax") sc = importlib.import_module("spacecore") ctx = _ctx("jax", jax_real_dtype()) @@ -477,12 +487,13 @@ def test_stochastic_lanczos_jit_compiles_with_operator_argument(): op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) def run(A, initial): - return sc.stochastic_lanczos( + result = sc.lanczos_smallest( A, initial, max_iter=2, tol=1e-8, ) + return result.eigenvalue, result.eigenvector eigenvalue, eigenvector = jax.jit(run)(op, ctx.asarray([1.0, 1.0])) @@ -507,4 +518,4 @@ def test_cg_and_power_iteration_reject_rectangular_operator(): with pytest.raises(ValueError, match="square LinOp"): sc.power_iteration(A) with pytest.raises(ValueError, match="square LinOp"): - sc.stochastic_lanczos(A, ctx.asarray([1.0, 2.0])) + sc.lanczos_smallest(A, ctx.asarray([1.0, 2.0])) diff --git a/tests/linops/test_algebra.py b/tests/linops/test_algebra.py index 770b20c..e8cccdb 100644 --- a/tests/linops/test_algebra.py +++ b/tests/linops/test_algebra.py @@ -263,6 +263,27 @@ def test_factories_enforce_same_context_dtype(): sc.make_composed(A32, A64) +def test_factories_ignore_enable_checks_when_context_dtype_matches(): + sc = importlib.import_module("spacecore") + checked = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=True) + unchecked = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False) + X_checked = sc.VectorSpace((2,), checked) + X_unchecked = sc.VectorSpace((2,), unchecked) + A = sc.DenseLinOp(checked.asarray([[1.0, 0.0], [0.0, 1.0]]), X_checked, X_checked, checked) + B = sc.DenseLinOp( + unchecked.asarray([[2.0, 0.0], [0.0, 3.0]]), + X_unchecked, + X_unchecked, + unchecked, + ) + + summed = sc.make_sum((A, B)) + composed = sc.make_composed(A, B) + + assert isinstance(summed, sc.SumLinOp) + assert isinstance(composed, sc.ComposedLinOp) + + def test_factories_enforce_domain_and_codomain_compatibility(): sc = importlib.import_module("spacecore") ctx = _ctx(dtype=np.float64) diff --git a/tests/linops/test_algebra_linop.py b/tests/linops/test_algebra_linop.py index fd2a8f8..4c2ee85 100644 --- a/tests/linops/test_algebra_linop.py +++ b/tests/linops/test_algebra_linop.py @@ -37,14 +37,12 @@ def test_algebra_linops_inherit_from_linop(): assert not hasattr(sc, "AdjointLinOp") -def test_context_mismatch_raises_clear_error(): +def test_check_policy_mismatch_does_not_block_algebra(): A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), _ctx(enable_checks=True)) B = _op([[5.0, 6.0], [7.0, 8.0]], (2,), (2,), _ctx(enable_checks=False)) - with pytest.raises(ValueError, match="same ctx"): - _ = A + B - with pytest.raises(ValueError, match="same ctx"): - _ = A @ B + assert isinstance(A + B, importlib.import_module("spacecore").SumLinOp) + assert isinstance(A @ B, importlib.import_module("spacecore").ComposedLinOp) def test_sum_requires_matching_domain_and_codomain(): diff --git a/tests/linops/test_diagonal_linop.py b/tests/linops/test_diagonal_linop.py new file mode 100644 index 0000000..e200a1f --- /dev/null +++ b/tests/linops/test_diagonal_linop.py @@ -0,0 +1,118 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_jax, to_numpy + + +def _ctx(dtype=np.float64, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def test_apply_and_rapply_flat_diagonal(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + x = ctx.asarray([4.0, -1.0, 0.5]) + + np.testing.assert_allclose(op.apply(x), [4.0, -2.0, 1.5]) + np.testing.assert_allclose(op.rapply(x), [4.0, -2.0, 1.5]) + + +def test_apply_and_rapply_tensor_shaped_diagonal(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2, 2), ctx) + op = sc.DiagonalLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0]]), space, ctx) + x = ctx.asarray([[2.0, -1.0], [0.5, 3.0]]) + + np.testing.assert_allclose(op.apply(x), [[2.0, -2.0], [1.5, 12.0]]) + np.testing.assert_allclose(op.rapply(x), [[2.0, -2.0], [1.5, 12.0]]) + + +def test_complex_diagonal_satisfies_adjoint_identity_and_hermitian_predicate(): + sc = importlib.import_module("spacecore") + ctx = _ctx(np.complex128) + space = sc.VectorSpace((2,), ctx) + hermitian = sc.DiagonalLinOp(ctx.asarray([1.0 + 0.0j, 2.0 + 0.0j]), space, ctx) + non_hermitian = sc.DiagonalLinOp(ctx.asarray([1.0 + 2.0j, 3.0 - 1.0j]), space, ctx) + u = ctx.asarray([2.0 - 1.0j, -0.5 + 0.25j]) + v = ctx.asarray([1.5 + 0.5j, -2.0j]) + + lhs = space.inner(non_hermitian.apply(u), v) + rhs = space.inner(u, non_hermitian.rapply(v)) + + np.testing.assert_allclose(to_numpy(lhs), to_numpy(rhs)) + assert hermitian.is_hermitian() is True + assert non_hermitian.is_hermitian() is False + + +def test_vapply_and_rvapply_with_leading_batch_space(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + xs = ctx.asarray([[1.0, 2.0, 3.0], [4.0, -1.0, 0.5]]) + batch_space = space.batch((2,), (0,)) + + expected = ctx.asarray([[1.0, 4.0, 9.0], [4.0, -2.0, 1.5]]) + np.testing.assert_allclose(op.vapply(xs, batch_space), expected) + np.testing.assert_allclose(op.rvapply(xs, batch_space), expected) + + +def test_vapply_and_rvapply_with_non_leading_batch_space(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + xs = ctx.asarray([[1.0, 4.0], [2.0, -1.0], [3.0, 0.5]]) + batch_space = space.batch((2,), (1,)) + + expected = ctx.asarray([[1.0, 4.0], [4.0, -2.0], [9.0, 1.5]]) + np.testing.assert_allclose(op.vapply(xs, batch_space), expected) + np.testing.assert_allclose(op.rvapply(xs, batch_space), expected) + + +def test_to_dense_matches_numpy_diagonal_for_tensor_space(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2, 2), ctx) + diagonal = ctx.asarray([[1.0, 2.0], [3.0, 4.0]]) + op = sc.DiagonalLinOp(diagonal, space, ctx) + + expected = np.diag(np.asarray(diagonal).reshape((4,))).reshape((2, 2, 2, 2)) + + np.testing.assert_allclose(op.to_dense(), expected) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_pytree_flatten_unflatten_round_trip(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + + leaves, treedef = jax.tree_util.tree_flatten(op) + restored = jax.tree_util.tree_unflatten(treedef, leaves) + + assert restored == op + np.testing.assert_allclose(restored.apply(ctx.asarray([1.0, 1.0, 1.0])), [1.0, 2.0, 3.0]) + + +def test_convert_changes_context_dtype(): + sc = importlib.import_module("spacecore") + ctx64 = _ctx(np.float64) + ctx32 = _ctx(np.float32) + space = sc.VectorSpace((3,), ctx64) + op = sc.DiagonalLinOp(ctx64.asarray([1.0, 2.0, 3.0]), space, ctx64) + + converted = op._convert(ctx32) + + assert converted.ctx == ctx32 + assert converted.domain.ctx == ctx32 + assert converted.diagonal.dtype == np.dtype(np.float32) + np.testing.assert_allclose(converted.apply(ctx32.asarray([1.0, 1.0, 1.0])), [1.0, 2.0, 3.0]) From cb13ecbad0a3d6449ec981b9fe3628a588002b0b Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 02:11:23 -0300 Subject: [PATCH 24/44] Document cacheability audit --- audit_caching.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 audit_caching.md diff --git a/audit_caching.md b/audit_caching.md new file mode 100644 index 0000000..1531165 --- /dev/null +++ b/audit_caching.md @@ -0,0 +1,36 @@ +# Cacheability Audit + +This audit covers the requested hot paths and separates construction-time state +from per-call work. Evidence comes from direct source inspection with line +numbers and grep checks such as: + +```text +grep -R "tuple(self\\.dom\\.shape) ==\\|tuple(self\\.cod\\.shape) ==\\|getattr(batch_space\\|coeffs_full = ops.zeros\\|fori_loop(0, max_iter + 1" -n spacecore/linop spacecore/linalg +``` + +## Candidates + +| Location | Per-call work | Evidence | Cost of caching | Benefit | Recommendation | +| --- | --- | --- | --- | --- | --- | +| `spacecore/linop/_dense.py:76-81` | Reshapes `x` when the domain is not flat, multiplies by cached `_A2`, and reshapes/unflattens output. | `_A2`, `_A2T`, `_A2H`, `_dom_is_flat`, `_cod_is_flat`, sizes are already constructed at lines 46-57. | Additional caching would duplicate shape tuples only; < 100 bytes. | Negligible. The expensive matrix multiply dominates. | Don't cache more. Existing construction-time cache is appropriate. | +| `spacecore/linop/_dense.py:92-97` | Reshapes `y`, multiplies by cached `_A2H`, reshapes/unflattens output. | `_A2H` is already cached at line 53; flat flags are cached at lines 54-55. | Additional shape tuple cache only; < 100 bytes. | Negligible relative to matvec. | Don't cache more. | +| `spacecore/linop/_dense.py:116-142` | Batched reshape and batched dense matmul; for non-plain `VectorSpace`, constructs a `BatchSpace` for unflattening. | Calls `self.cod.batch(batch_shape, tuple(range(len(batch_shape))))` at line 128 and domain equivalent at line 142. | A general cache keyed by `batch_shape` would be an unbounded dict and would not be a pytree leaf. | Low unless repeatedly using custom non-vector spaces with the same batch shape. For plain `VectorSpace`, no `BatchSpace` is created. | Don't cache. Avoid unbounded mutable instance state for a narrow path. | +| `spacecore/linop/_dense.py:164-186` | Reflects `in_space.batch_axes` and builds `tuple(range(...))` to identify leading batches. | Lines 166 and 178. | Could cache a tuple per batch rank, but batch rank is input-dependent. | Very low; one tuple/reflection check per batched call. | Don't cache. | +| `spacecore/linop/_sparse.py:74-96` | Reshapes vectors, sparse matvec with cached `_AH`, output reshape/unflatten. | `_AH` and flat flags are cached at lines 47-52. | Additional cache would be shape tuples only. | Negligible relative to sparse matvec. | Don't cache more. | +| `spacecore/linop/_sparse.py:115-141` | Batched sparse matmul; may construct `BatchSpace` for non-vector spaces. | Same pattern as dense: lines 127 and 141. | Unbounded dict if keyed by `batch_shape`; not JAX-friendly. | Low for the common vector-space fast path. | Don't cache. | +| `spacecore/linop/_diagonal.py:51-59` | Builds `batch_shape`, `batch_axes`, `base_axes`, and reshape tuple on every `vapply`/`rvapply`. | Lines 52-59; called by `vapply` at line 65 and `rvapply` at line 75. | A cache would store small tuples keyed by `(batch_shape, batch_axes)`; memory tiny per key but unbounded and mutable. | Low: the elementwise multiply dominates for large arrays; for small arrays, Python overhead exists but caching adds mutable state and pytree concerns. | Don't cache. Keep stateless and JAX-safe. | +| `spacecore/linop/_algebra.py:283-297` | `SumLinOp.apply/rapply` loops over operands and accumulates results. | Lines 286-288 and 294-296. | No reusable derived object; caching partial sums would depend on input. | None. Work is mathematical operator application. | Don't cache. | +| `spacecore/linop/_algebra.py:363-383` | `ComposedLinOp.apply/rapply/vapply/rvapply` delegates through left/right operators; batched paths create middle `BatchSpace`. | Lines 366, 371, 376, 382. | Could cache middle batch spaces by input batch signature; unbounded mutable dict. | Low and only for repeated batched composition with the same batch axes. | Don't cache. | +| `spacecore/linop/_algebra.py:878-894` | `_AdjointViewLinOp` delegates to the wrapped operator. | Lines 881, 886, 890, 894. | Nothing useful to cache. | None. Delegation is already minimal. | Don't cache. | +| `spacecore/linalg/_lanczos.py:108-112` | Builds `e0` and normalized `e0_unit` once per `lanczos_smallest` call. | Lines 108-112. | Caching on `A.domain` would add mutable state to spaces and complicate JAX pytree semantics. Memory is `O(n)`. | Low: once per solver call, not per iteration. | Don't cache. | +| `spacecore/linalg/_lanczos.py:163-170` | Allocates `coeffs_full = zeros(max_iter + 1)` inside every Lanczos iteration, then fills it with a `fori_loop`. | Grep shows `coeffs_full = ops.zeros((max_iter + 1,), dtype=ctx.dtype)` at line 163 inside `body_fun`, followed by `ops.fori_loop(0, max_iter + 1, ...)` at line 170. | One extra closure-captured zero vector of length `max_iter + 1`; bytes are about `(max_iter + 1) * itemsize`, usually a few hundred bytes. | Moderate: avoids allocating the same small vector `m` times per Krylov call. This is also the path Task 3 will revisit for trace cleanliness. | Cache/hoist a zero template outside `body_fun` and reuse it as the initial coefficient vector. | + +## Recommended + +- Hoist Lanczos `coeffs_full` zero-vector allocation out of `body_fun`. + This removes one small allocation per Krylov iteration with negligible memory + cost and no API change. + +Everything else is either already cached (`_A2`, `_A2H`, `_AH`, flat flags) or +would require mutable shape caches for small tuple/reflection work. Those are +not worth the added state, especially for JAX pytree compatibility. From 68a3268d90c510de748d275b428c3772ba7e6c32 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 02:12:21 -0300 Subject: [PATCH 25/44] Hoist Lanczos coefficient zero template --- spacecore/linalg/_lanczos.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index 0f91802..900556d 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -127,6 +127,7 @@ def lanczos_smallest( full_indices = ops.arange(max_iter + 1) idx = ops.arange(max_iter) + coeffs_zero = ops.zeros((max_iter + 1,), dtype=ctx.dtype) def cond_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> Any: i, _V, _alphas, _betas, _beta, keep_going = state @@ -160,7 +161,7 @@ def body_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> tuple[Any, Any, Any, ) mask = ops.astype(mask, ctx.dtype) - coeffs_full = ops.zeros((max_iter + 1,), dtype=ctx.dtype) + coeffs_full = coeffs_zero def fill_coeff(j: int, coeffs_in: DenseArray) -> DenseArray: v_j_member = A.domain.unflatten(V_[j]) From a4ddaf26e9fa6a31780094e829443a36a674af54 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 02:15:05 -0300 Subject: [PATCH 26/44] Add JIT traceability audit --- audit_jit.md | 103 +++++ scripts/jit_audit.py | 206 +++++++++ tests/fixtures/jaxpr_lanczos_smallest.txt | 484 ++++++++++++++++++++++ 3 files changed, 793 insertions(+) create mode 100644 audit_jit.md create mode 100644 scripts/jit_audit.py create mode 100644 tests/fixtures/jaxpr_lanczos_smallest.txt diff --git a/audit_jit.md b/audit_jit.md new file mode 100644 index 0000000..8066ea6 --- /dev/null +++ b/audit_jit.md @@ -0,0 +1,103 @@ +# JIT Traceability Audit + +This audit was generated with `scripts/jit_audit.py`. The script: + +- wraps each solver with `jax.jit`; +- calls each jitted wrapper twice with shape/dtype-stable values; +- calls it again with a changed static iteration argument; +- calls it again with a changed operator/domain shape; +- writes a `jax.make_jaxpr` fixture for `lanczos_smallest` to + `tests/fixtures/jaxpr_lanczos_smallest.txt`. + +The script also enables `jax_log_compiles`. In addition to JAX's compile logs, +it uses a trace-time counter, which increments only when JAX retraces the Python +wrapper. + +## Summary + +| Solver | Traces cleanly | Recompiles/retraces on same shape/dtype values | Retraces when expected | Evidence | +| --- | --- | --- | --- | --- | +| `cg` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `maxiter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output: `{'solver': 'cg', 'traces_after_two_same_shape_calls': 1, 'traces_after_static_change': 2, 'traces_after_shape_change': 3, ...}` | +| `lsqr` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `maxiter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output for `lsqr`. | +| `lanczos_smallest` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `max_iter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output for `lanczos_smallest`. | +| `power_iteration` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `maxiter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output for `power_iteration`. | +| `expm_multiply` | Not audited yet | Not applicable | Not applicable | Not implemented before Task 1. The script reports `{'solver': 'expm_multiply', 'status': 'not available before Task 1'}` and should be rerun after Task 1. | + +## Findings + +### 1. Lanczos full reorthogonalization lowers to an inner scan + +`tests/fixtures/jaxpr_lanczos_smallest.txt` captures the current JAXPR for +`lanczos_smallest(max_iter=3)`. Grep evidence: + +```text +grep -n "scan\\|while\\|scatter\\|dot_general" tests/fixtures/jaxpr_lanczos_smallest.txt +``` + +The fixture shows: + +- a top-level `while` at line 61 for the Krylov iteration; +- a nested `scan` at line 143 corresponding to + `ops.fori_loop(0, max_iter + 1, fill_coeff, coeffs_full)`; +- repeated `scatter` operations inside that scan. + +This is correct, but for the common exact `VectorSpace` case it is more IR than +needed. Mathematically, Euclidean reorthogonalization coefficients are +`conj(V) @ w`, which can be one `einsum`/matmul node. + +Important constraint: this replacement is **not valid for arbitrary +`Space.inner`**. A weighted space, RKHS, or any custom geometry must keep using +`Space.inner(v_j, w)`. Therefore the optimization should be guarded to the exact +`VectorSpace` type only, not subclasses. + +### 2. `cg` and `lsqr` use `ops.cond` correctly for periodic diagnostics + +Both solvers trace without errors and do not retrace on value-only changes. The +`ops.cond(..., lambda _: ..., lambda _: ..., operand)` pattern is valid under +the installed JAX version. Both branches return matching shapes and dtypes. + +### 3. `power_iteration` dispatch is trace-time, not data-dependent + +`power_iteration` branches on whether the first argument is a `LinOp` or +`QuadraticForm`. That branch is Python-level and happens at trace time. This is +acceptable because the object type is static in the pytree structure; changing +from a `LinOp` to a `QuadraticForm` should retrace. + +### 4. Algebra construction inside JIT is possible but theoretical + +The algebra factories use Python `isinstance` checks for symbolic simplification. +Those checks are not data-dependent on traced arrays. If users construct new +operator expressions inside a jitted function, that Python algebra executes at +trace time and contributes to trace cost. This is a usage-pattern concern, not a +solver bug: normal use passes already-built operators into jitted numerical +kernels. + +### 5. Constants are mostly static by design + +Iteration counts (`maxiter`, `max_iter`) are static in the audit wrappers. +Changing them retraces, which is expected because loop bounds and fixed-size +work arrays change. Scalars such as tolerances are currently Python arguments +converted through `ops.asarray`; changing them may retrace unless callers pass +them through a wrapper as array values. This is acceptable for the current API. + +## Recommended Implementation + +- Add a Euclidean fast path for Lanczos reorthogonalization when + `type(A.domain) is VectorSpace`: compute all coefficients with + `ops.einsum("jn,n->j", ops.conj(V_), w)`. +- Keep the existing `Space.inner` loop for all non-exact `VectorSpace` domains + to preserve Space geometry. + +This is the only validated change from this audit. The broader replacement +`V @ w` is rejected for non-Euclidean spaces because it would regress the +geometry-correct Lanczos recurrence. + +## Follow-Up TODO + +1. Rerun `scripts/jit_audit.py` after `expm_multiply` lands and update this + document with its trace counts. +2. Consider a benchmark for exact `VectorSpace` Lanczos before/after the + reorthogonalization fast path. +3. If users report trace-time issues from constructing algebra expressions + inside `jax.jit`, document that operators should be built outside the jitted + function. diff --git a/scripts/jit_audit.py b/scripts/jit_audit.py new file mode 100644 index 0000000..cd41afd --- /dev/null +++ b/scripts/jit_audit.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable + +import numpy as np + + +ROOT = Path(__file__).resolve().parents[1] +FIXTURE = ROOT / "tests" / "fixtures" / "jaxpr_lanczos_smallest.txt" + + +def _ctx(): + import jax + import spacecore as sc + + dtype = np.float64 if jax.config.read("jax_enable_x64") else np.float32 + return sc.Context(sc.JaxOps(), dtype=dtype, enable_checks=False) + + +def _spd_operator(n: int): + import spacecore as sc + + ctx = _ctx() + space = sc.VectorSpace((n,), ctx) + matrix = np.diag(np.arange(2.0, 2.0 + n)) + matrix += 0.05 * np.ones((n, n)) + return sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + +def _rect_operator(): + import spacecore as sc + + ctx = _ctx() + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + matrix = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, -1.0]]) + return sc.DenseLinOp(ctx.asarray(matrix), domain, codomain, ctx) + + +def _same_shape_inputs(A: Any) -> tuple[Any, Any]: + ctx = A.ctx + x0 = ctx.asarray(np.linspace(1.0, 2.0, A.domain.shape[0])) + x1 = ctx.asarray(np.linspace(2.0, 3.0, A.domain.shape[0])) + return x0, x1 + + +def _audit_solver( + name: str, + fn_factory: Callable[[dict[str, int]], Callable[..., Any]], + A: Any, + first_rhs: Any, + second_rhs: Any, + shape_changed_A: Any, + shape_changed_rhs: Any, + static_name: str, +) -> dict[str, Any]: + import jax + + traces = {"count": 0} + fn = fn_factory(traces) + jitted = jax.jit(fn, static_argnames=(static_name,)) + + out0 = jitted(A, first_rhs, **{static_name: 4}) + out1 = jitted(A, second_rhs, **{static_name: 4}) + same_shape_traces = traces["count"] + + out2 = jitted(A, first_rhs, **{static_name: 5}) + static_changed_traces = traces["count"] + + out3 = jitted(shape_changed_A, shape_changed_rhs, **{static_name: 4}) + shape_changed_traces = traces["count"] + + for out in (out0, out1, out2, out3): + jax.block_until_ready(out) + + return { + "solver": name, + "traces_after_two_same_shape_calls": same_shape_traces, + "traces_after_static_change": static_changed_traces, + "traces_after_shape_change": shape_changed_traces, + "stable_values_retraced": same_shape_traces > 1, + "static_change_retraced": static_changed_traces > same_shape_traces, + "shape_change_retraced": shape_changed_traces > static_changed_traces, + } + + +def main() -> None: + import jax + import spacecore as sc + + jax.config.update("jax_log_compiles", True) + + A2 = _spd_operator(2) + A3 = _spd_operator(3) + x2a, x2b = _same_shape_inputs(A2) + x3a, _ = _same_shape_inputs(A3) + R2 = _rect_operator() + R3 = sc.DenseLinOp( + _ctx().asarray([[1.0, 0.0, 0.5], [0.0, 2.0, -1.0], [1.0, -1.0, 0.25], [0.5, 0.0, 1.0]]), + sc.VectorSpace((3,), _ctx()), + sc.VectorSpace((4,), _ctx()), + _ctx(), + ) + b2a = R2.codomain.ctx.asarray([1.0, 2.0, 3.0]) + b2b = R2.codomain.ctx.asarray([3.0, 2.0, 1.0]) + b4 = R3.codomain.ctx.asarray([1.0, 2.0, 3.0, 4.0]) + + audits = [ + _audit_solver( + "cg", + lambda traces: ( + lambda A, b, maxiter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.cg(A, b, maxiter=maxiter).x + ) + ), + A2, + x2a, + x2b, + A3, + x3a, + "maxiter", + ), + _audit_solver( + "lsqr", + lambda traces: ( + lambda A, b, maxiter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.lsqr(A, b, maxiter=maxiter).x + ) + ), + R2, + b2a, + b2b, + R3, + b4, + "maxiter", + ), + _audit_solver( + "lanczos_smallest", + lambda traces: ( + lambda A, x, max_iter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.lanczos_smallest(A, x, max_iter=max_iter).eigenvalue + ) + ), + A2, + x2a, + x2b, + A3, + x3a, + "max_iter", + ), + _audit_solver( + "power_iteration", + lambda traces: ( + lambda A, x, maxiter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.power_iteration(A, x0=x, maxiter=maxiter).eigenvalue + ) + ), + A2, + x2a, + x2b, + A3, + x3a, + "maxiter", + ), + ] + + if hasattr(sc, "expm_multiply"): + audits.append( + _audit_solver( + "expm_multiply", + lambda traces: ( + lambda A, x, max_iter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.expm_multiply(A, x, max_iter=max_iter).result + ) + ), + A2, + x2a, + x2b, + A3, + x3a, + "max_iter", + ) + ) + else: + audits.append({"solver": "expm_multiply", "status": "not available before Task 1"}) + + FIXTURE.parent.mkdir(parents=True, exist_ok=True) + jaxpr = jax.make_jaxpr( + lambda A, x: sc.lanczos_smallest(A, x, max_iter=3, check_every=1).eigenvalue + )(A2, x2a) + FIXTURE.write_text(str(jaxpr)) + + print("JIT audit summary") + for item in audits: + print(item) + print(f"wrote {FIXTURE.relative_to(ROOT)}") + + +if __name__ == "__main__": + main() diff --git a/tests/fixtures/jaxpr_lanczos_smallest.txt b/tests/fixtures/jaxpr_lanczos_smallest.txt new file mode 100644 index 0000000..8361375 --- /dev/null +++ b/tests/fixtures/jaxpr_lanczos_smallest.txt @@ -0,0 +1,484 @@ +let _where = { lambda ; a:bool[] b:f32[] c:f32[]. let + d:f32[] = select_n a c b + in (d,) } in +{ lambda ; e:f32[2,2] f:f32[2]. let + _:f32[2,2] = transpose[permutation=(1, 0)] e + _:f32[2,2] = transpose[permutation=(1, 0)] e + g:f32[4,2] = broadcast_in_dim 0.0:f32[] + h:f32[3] = broadcast_in_dim 0.0:f32[] + i:f32[4] = broadcast_in_dim 0.0:f32[] + j:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] f f + k:f32[] = sqrt j + l:f32[2] = broadcast_in_dim 0.0:f32[] + m:i32[1] = broadcast_in_dim 0:i32[] + n:f32[2] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] l m 1.0:f32[] + o:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] n n + p:f32[] = sqrt o + q:bool[] = gt p 0.0:f32[] + r:f32[] = jit[name=_where jaxpr=_where] q p 1.0:f32[] + s:f32[] = div 1.0:f32[] r + t:f32[] = jit[name=_where jaxpr=_where] q s 0.0:f32[] + u:f32[2] = mul t n + v:bool[] = gt k 9.999999960041972e-13:f32[] + w:i32[] = convert_element_type[new_dtype=int32 weak_type=False] v + x:f32[2] = cond[ + branches=( + { lambda ; y:f32[] z:f32[2] ba:f32[2] bb:f32[]. let in (ba,) } + { lambda ; bc:f32[] bd:f32[2] be:f32[2] bf:f32[]. let + bg:bool[] = gt bc 0.0:f32[] + bh:f32[] = jit[name=_where jaxpr=_where] bg bc 1.0:f32[] + bi:f32[] = div 1.0:f32[] bh + bj:f32[] = jit[name=_where jaxpr=_where] bg bi 0.0:f32[] + bk:f32[2] = mul bj bd + in (bk,) } + ) + ] w k f u 0.0:f32[] + bl:i32[1] = broadcast_in_dim 0:i32[] + bm:f32[4,2] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] g bl x + bn:i32[4] = iota[dimension=0 dtype=int32 shape=(4,) sharding=None] + bo:i32[3] = iota[dimension=0 dtype=int32 shape=(3,) sharding=None] + bp:f32[4] = broadcast_in_dim 0.0:f32[] + bq:i32[] br:f32[4,2] bs:f32[3] bt:f32[4] _:f32[] _:bool[] = while[ + body_jaxpr={ lambda ; bu:f32[2,2] bv:i32[4] bw:f32[4] bx:f32[] by:i32[] bz:f32[4,2] + ca:f32[3] cb:f32[4] cc:f32[] cd:bool[]. let + ce:i32[] = convert_element_type[new_dtype=int32 weak_type=False] by + cf:bool[] = lt ce 0:i32[] + cg:i32[] = add ce 4:i32[] + ch:i32[] = select_n cf ce cg + ci:bool[] = lt 0:i32[] 0:i32[] + cj:i32[] = add 0:i32[] 2:i32[] + ck:i32[] = select_n ci 0:i32[] cj + cl:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] bz ch ck + cm:f32[2] = squeeze[dimensions=(0,)] cl + cn:f32[2] = dot_general[ + dimension_numbers=(([1], [0]), ([], [])) + preferred_element_type=float32 + ] bu cm + co:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] cm cn + cp:bool[] = lt by 0:i32[] + cq:i32[] = add by 3:i32[] + cr:i32[] = select_n cp by cq + cs:i32[] = convert_element_type[new_dtype=int32 weak_type=False] cr + ct:i32[1] = broadcast_in_dim cs + cu:f32[3] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] ca ct co + cv:bool[] = eq by 0:i32[] + cw:i32[] = convert_element_type[new_dtype=int32 weak_type=False] cv + cx:f32[2] = cond[ + branches=( + { lambda ; cy:f32[] cz:f32[2] da:i32[] db:f32[4] dc:f32[4,2] dd:f32[2]. let + de:f32[2] = mul cy cz + df:f32[2] = sub dd de + dg:bool[] = lt da 0:i32[] + dh:i32[] = convert_element_type[ + new_dtype=int32 + weak_type=False + ] da + di:i32[] = add dh 4:i32[] + dj:i32[] = select_n dg da di + dk:f32[1] = dynamic_slice[slice_sizes=(1,)] db dj + dl:f32[] = squeeze[dimensions=(0,)] dk + dm:i32[] = sub da 1:i32[] + dn:i32[] = convert_element_type[ + new_dtype=int32 + weak_type=False + ] dm + do:bool[] = lt dn 0:i32[] + dp:i32[] = add dn 4:i32[] + dq:i32[] = select_n do dn dp + dr:bool[] = lt 0:i32[] 0:i32[] + ds:i32[] = add 0:i32[] 2:i32[] + dt:i32[] = select_n dr 0:i32[] ds + du:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] dc dq dt + dv:f32[2] = squeeze[dimensions=(0,)] du + dw:f32[2] = mul dl dv + dx:f32[2] = sub df dw + in (dx,) } + { lambda ; dy:f32[] dz:f32[2] ea:i32[] eb:f32[4] ec:f32[4,2] ed:f32[2]. let + ee:f32[2] = mul dy dz + ef:f32[2] = sub ed ee + in (ef,) } + ) + ] cw co cm by cb bz cn + eg:i32[] = add by 1:i32[] + eh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] eg + ei:bool[4] = lt bv eh + ej:f32[4] = jit[ + name=_where + jaxpr={ lambda ; ei:bool[4] ek:f32[] el:f32[]. let + em:f32[4] = broadcast_in_dim ek + en:f32[4] = broadcast_in_dim el + ej:f32[4] = select_n ei en em + in (ej,) } + ] ei 1.0:f32[] 0.0:f32[] + _:i32[] eo:f32[4] = scan[ + _split_transpose=False + jaxpr={ lambda ; ep:f32[4,2] eq:f32[2] er:i32[] es:f32[4]. let + et:i32[] = add er 1:i32[] + eu:i32[] = convert_element_type[new_dtype=int32 weak_type=False] er + ev:bool[] = lt eu 0:i32[] + ew:i32[] = add eu 4:i32[] + ex:i32[] = select_n ev eu ew + ey:bool[] = lt 0:i32[] 0:i32[] + ez:i32[] = add 0:i32[] 2:i32[] + fa:i32[] = select_n ey 0:i32[] ez + fb:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] ep ex fa + fc:f32[2] = squeeze[dimensions=(0,)] fb + fd:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] fc eq + fe:bool[] = lt er 0:i32[] + ff:i32[] = add er 4:i32[] + fg:i32[] = select_n fe er ff + fh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fg + fi:i32[1] = broadcast_in_dim fh + fj:f32[4] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] es fi fd + in (et, fj) } + length=4 + linear=(False, False, False, False) + num_carry=2 + num_consts=2 + reverse=False + unroll=1 + ] bz cx 0:i32[] bw + fk:f32[4] = mul eo ej + fl:f32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,)] fk + fm:f32[4,2] = mul fl bz + fn:f32[2] = reduce_sum[axes=(0,) out_sharding=None] fm + fo:f32[2] = sub cx fn + fp:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] fo fo + fq:f32[] = sqrt fp + fr:i32[] = add by 1:i32[] + fs:bool[] = lt fr 0:i32[] + ft:i32[] = add fr 4:i32[] + fu:i32[] = select_n fs fr ft + fv:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fu + fw:i32[1] = broadcast_in_dim fv + fx:f32[4] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] cb fw fq + fy:bool[] = ge fq bx + fz:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fy + ga:f32[4,2] = cond[ + branches=( + { lambda ; gb:f32[] gc:f32[2] gd:i32[] ge:f32[4,2]. let + + in (ge,) } + { lambda ; gf:f32[] gg:f32[2] gh:i32[] gi:f32[4,2]. let + gj:bool[] = gt gf 0.0:f32[] + gk:f32[] = jit[name=_where jaxpr=_where] gj gf 1.0:f32[] + gl:f32[] = div 1.0:f32[] gk + gm:f32[] = jit[name=_where jaxpr=_where] gj gl 0.0:f32[] + gn:f32[2] = mul gm gg + go:i32[] = add gh 1:i32[] + gp:bool[] = lt go 0:i32[] + gq:i32[] = add go 4:i32[] + gr:i32[] = select_n gp go gq + gs:i32[] = convert_element_type[ + new_dtype=int32 + weak_type=False + ] gr + gt:i32[1] = broadcast_in_dim gs + gu:f32[4,2] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] gi gt gn + in (gu,) } + ) + ] fz fq fo by bz + gv:i32[] = add by 1:i32[] + gw:bool[] = ge gv 3:i32[] + gx:i32[] = jit[ + name=remainder + jaxpr={ lambda ; gv:i32[] gy:i32[]. let + gz:bool[] = eq gy 0:i32[] + ha:i32[] = jit[ + name=_where + jaxpr={ lambda ; gz:bool[] hb:i32[] gy:i32[]. let + ha:i32[] = select_n gz gy hb + in (ha,) } + ] gz 1:i32[] gy + hc:i32[] = rem gv ha + hd:bool[] = ne hc 0:i32[] + he:bool[] = lt hc 0:i32[] + hf:bool[] = lt ha 0:i32[] + hg:bool[] = ne he hf + hh:bool[] = and hg hd + hi:i32[] = add hc ha + gx:i32[] = select_n hh hc hi + in (gx,) } + ] gv 1:i32[] + hj:bool[] = eq gx 0:i32[] + hk:bool[] = convert_element_type[new_dtype=bool weak_type=False] gw + hl:bool[] = convert_element_type[new_dtype=bool weak_type=False] hj + hm:bool[] = or hk hl + hn:i32[] = convert_element_type[new_dtype=int32 weak_type=False] hm + ho:bool[] = cond[ + branches=( + { lambda ; hp:f32[] hq:f32[] hr:bool[] hs:f32[]. let in (hr,) } + { lambda ; ht:f32[] hu:f32[] hv:bool[] hw:f32[]. let + hx:bool[] = ge ht hu + in (hx,) } + ) + ] hn fq bx cd 0.0:f32[] + in (gv, ga, cu, fx, fq, ho) } + body_nconsts=4 + cond_jaxpr={ lambda ; hy:i32[] hz:f32[4,2] ia:f32[3] ib:f32[4] ic:f32[] id:bool[]. let + ie:bool[] = lt hy 3:i32[] + if:bool[] = convert_element_type[new_dtype=bool weak_type=False] ie + ig:bool[] = and if id + in (ig,) } + cond_nconsts=0 + ] e bn bp 9.999999974752427e-07:f32[] 0:i32[] bm h i 1.0:f32[] True:bool[] + ih:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq + ii:bool[3] = lt bo ih + ij:f32[3] = abs bs + ik:f32[] = reduce_max[axes=(0,)] ij + il:f32[4] = abs bt + im:f32[] = reduce_max[axes=(0,)] il + in:f32[] = mul 2.0:f32[] im + io:f32[] = add ik in + ip:f32[] = add io 1.0:f32[] + iq:f32[3] = jit[ + name=_where + jaxpr={ lambda ; ii:bool[3] bs:f32[3] ip:f32[]. let + ir:f32[3] = broadcast_in_dim ip + iq:f32[3] = select_n ii ir bs + in (iq,) } + ] ii bs ip + is:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq + it:bool[4] = eq bn is + iu:f32[4] = jit[ + name=_where + jaxpr={ lambda ; it:bool[4] iv:f32[] bt:f32[4]. let + iw:f32[4] = broadcast_in_dim iv + iu:f32[4] = select_n it bt iw + in (iu,) } + ] it 0.0:f32[] bt + ix:f32[3,3] = broadcast_in_dim 0.0:f32[] + _:i32[] iy:f32[3,3] = scan[ + _split_transpose=False + jaxpr={ lambda ; iz:f32[3] ja:i32[] jb:f32[3,3]. let + jc:i32[] = add ja 1:i32[] + jd:bool[] = lt ja 0:i32[] + je:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ja + jf:i32[] = add je 3:i32[] + jg:i32[] = select_n jd ja jf + jh:f32[1] = dynamic_slice[slice_sizes=(1,)] iz jg + ji:f32[] = squeeze[dimensions=(0,)] jh + jj:bool[] = lt ja 0:i32[] + jk:i32[] = add ja 3:i32[] + jl:i32[] = select_n jj ja jk + jm:bool[] = lt ja 0:i32[] + jn:i32[] = add ja 3:i32[] + jo:i32[] = select_n jm ja jn + jp:i32[] = convert_element_type[new_dtype=int32 weak_type=False] jl + jq:i32[] = convert_element_type[new_dtype=int32 weak_type=False] jo + jr:i32[1] = broadcast_in_dim jp + js:i32[1] = broadcast_in_dim jq + jt:i32[2] = concatenate[dimension=0] jr js + ju:f32[3,3] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] jb jt ji + in (jc, ju) } + length=3 + linear=(False, False, False) + num_carry=2 + num_consts=1 + reverse=False + unroll=1 + ] iq 0:i32[] ix + _:i32[] jv:f32[3,3] = scan[ + _split_transpose=False + jaxpr={ lambda ; jw:f32[4] jx:i32[] jy:f32[3,3]. let + jz:i32[] = add jx 1:i32[] + ka:i32[] = add jx 1:i32[] + kb:bool[] = lt ka 0:i32[] + kc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ka + kd:i32[] = add kc 4:i32[] + ke:i32[] = select_n kb ka kd + kf:f32[1] = dynamic_slice[slice_sizes=(1,)] jw ke + kg:f32[] = squeeze[dimensions=(0,)] kf + kh:i32[] = add jx 1:i32[] + ki:bool[] = lt jx 0:i32[] + kj:i32[] = add jx 3:i32[] + kk:i32[] = select_n ki jx kj + kl:bool[] = lt kh 0:i32[] + km:i32[] = add kh 3:i32[] + kn:i32[] = select_n kl kh km + ko:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kk + kp:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kn + kq:i32[1] = broadcast_in_dim ko + kr:i32[1] = broadcast_in_dim kp + ks:i32[2] = concatenate[dimension=0] kq kr + kt:f32[3,3] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] jy ks kg + ku:i32[] = add jx 1:i32[] + kv:bool[] = lt ku 0:i32[] + kw:i32[] = add ku 3:i32[] + kx:i32[] = select_n kv ku kw + ky:bool[] = lt jx 0:i32[] + kz:i32[] = add jx 3:i32[] + la:i32[] = select_n ky jx kz + lb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kx + lc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] la + ld:i32[1] = broadcast_in_dim lb + le:i32[1] = broadcast_in_dim lc + lf:i32[2] = concatenate[dimension=0] ld le + lg:f32[3,3] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] kt lf kg + in (jz, lg) } + length=2 + linear=(False, False, False) + num_carry=2 + num_consts=1 + reverse=False + unroll=1 + ] iu 0:i32[] iy + _:f32[3] lh:f32[3,3] = jit[ + name=eigh + jaxpr={ lambda ; jv:f32[3,3]. let + li:f32[3,3] = transpose[permutation=(1, 0)] jv + lj:f32[3,3] = add jv li + lk:f32[3,3] = div lj 2.0:f32[] + lh:f32[3,3] _:f32[3] = eigh[ + algorithm=None + lower=True + sort_eigenvalues=True + subset_by_index=None + ] lk + in (_, lh) } + ] jv + ll:f32[3,1] = slice[limit_indices=(3, 1) start_indices=(0, 0) strides=None] lh + lm:f32[3] = squeeze[dimensions=(1,)] ll + ln:bool[] = lt bq 0:i32[] + lo:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq + lp:i32[] = add lo 4:i32[] + lq:i32[] = select_n ln bq lp + lr:f32[1] = dynamic_slice[slice_sizes=(1,)] bt lq + ls:f32[] = squeeze[dimensions=(0,)] lr + lt:i32[] = sub bq 1:i32[] + lu:bool[] = lt lt 0:i32[] + lv:i32[] = convert_element_type[new_dtype=int32 weak_type=False] lt + lw:i32[] = add lv 3:i32[] + lx:i32[] = select_n lu lt lw + ly:f32[1] = dynamic_slice[slice_sizes=(1,)] lm lx + lz:f32[] = squeeze[dimensions=(0,)] ly + ma:f32[] = abs lz + mb:f32[] = mul ls ma + _:bool[] = lt mb 9.999999974752427e-07:f32[] + mc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq + md:bool[3] = lt bo mc + me:f32[3] = jit[ + name=_where + jaxpr={ lambda ; md:bool[3] mf:f32[] mg:f32[]. let + mh:f32[3] = broadcast_in_dim mf + mi:f32[3] = broadcast_in_dim mg + me:f32[3] = select_n md mi mh + in (me,) } + ] md 1.0:f32[] 0.0:f32[] + mj:f32[3] = mul lm me + mk:f32[3,2] = slice[limit_indices=(3, 2) start_indices=(0, 0) strides=None] br + ml:f32[2] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] mj mk + mm:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] ml ml + mn:f32[] = sqrt mm + mo:bool[] = gt mn 9.999999960041972e-13:f32[] + mp:i32[] = convert_element_type[new_dtype=int32 weak_type=False] mo + mq:f32[2] = cond[ + branches=( + { lambda ; mr:f32[] ms:f32[2] mt:f32[2] mu:f32[]. let in (mt,) } + { lambda ; mv:f32[] mw:f32[2] mx:f32[2] my:f32[]. let + mz:bool[] = gt mv 0.0:f32[] + na:f32[] = jit[name=_where jaxpr=_where] mz mv 1.0:f32[] + nb:f32[] = div 1.0:f32[] na + nc:f32[] = jit[name=_where jaxpr=_where] mz nb 0.0:f32[] + nd:f32[2] = mul nc mw + in (nd,) } + ) + ] mp mn ml u 0.0:f32[] + ne:f32[2] = dot_general[ + dimension_numbers=(([1], [0]), ([], [])) + preferred_element_type=float32 + ] e mq + nf:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] mq ne + ng:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] mq mq + nh:f32[] = div nf ng + in (nh,) } \ No newline at end of file From e90caee1e4d836ce6b2f279c63d7415fb98d9418 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 02:15:41 -0300 Subject: [PATCH 27/44] Vectorize Lanczos reorthogonalization for vector spaces --- spacecore/linalg/_lanczos.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index 900556d..f34b353 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -5,6 +5,7 @@ from ..linop import LinOp +from ..space import VectorSpace from ..types import DenseArray from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval from ._utils import require_linop, require_square, safe_inverse_nonneg, should_check_iteration @@ -91,6 +92,7 @@ def lanczos_smallest( ops = A.ops ctx = A.ctx real_dtype = ops.real_dtype(ctx.dtype) + use_euclidean_reorth = type(A.domain) is VectorSpace v0 = A.domain.flatten(initial_vector) v0 = ctx.assert_dense(v0) @@ -161,14 +163,17 @@ def body_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> tuple[Any, Any, Any, ) mask = ops.astype(mask, ctx.dtype) - coeffs_full = coeffs_zero + if use_euclidean_reorth: + coeffs_full = ops.einsum("jn,n->j", ops.conj(V_), w) + else: + coeffs_full = coeffs_zero - def fill_coeff(j: int, coeffs_in: DenseArray) -> DenseArray: - v_j_member = A.domain.unflatten(V_[j]) - coeff = A.domain.inner(v_j_member, w_member) - return ops.index_set(coeffs_in, (j,), coeff, copy=True) + def fill_coeff(j: int, coeffs_in: DenseArray) -> DenseArray: + v_j_member = A.domain.unflatten(V_[j]) + coeff = A.domain.inner(v_j_member, w_member) + return ops.index_set(coeffs_in, (j,), coeff, copy=True) - coeffs_full = ops.fori_loop(0, max_iter + 1, fill_coeff, coeffs_full) + coeffs_full = ops.fori_loop(0, max_iter + 1, fill_coeff, coeffs_full) coeffs_valid = coeffs_full * mask proj = ops.sum(coeffs_valid[:, None] * V_, axis=0) w = w - proj From a32432c08f0861d6f80f82c12016eee169760900 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 02:22:37 -0300 Subject: [PATCH 28/44] Add Krylov expm_multiply --- spacecore/__init__.py | 4 + spacecore/linalg/__init__.py | 3 + spacecore/linalg/_expm.py | 102 ++++++++++++++ spacecore/linalg/_lanczos.py | 173 ++++++++++++++---------- tests/integration/test_public_api.py | 3 + tests/linalg/test_expm.py | 192 +++++++++++++++++++++++++++ 6 files changed, 410 insertions(+), 67 deletions(-) create mode 100644 spacecore/linalg/_expm.py create mode 100644 tests/linalg/test_expm.py diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 16d922f..19690b6 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -45,11 +45,13 @@ ) from .linalg import ( CGResult, + ExpmMultiplyResult, LanczosResult, LSQRResult, PowerIterationResult, StochasticLanczosResult, cg, + expm_multiply, lanczos_smallest, lsqr, power_iteration, @@ -119,11 +121,13 @@ "make_functional_composed", "CGResult", + "ExpmMultiplyResult", "LanczosResult", "LSQRResult", "PowerIterationResult", "StochasticLanczosResult", "cg", + "expm_multiply", "lanczos_smallest", "lsqr", "power_iteration", diff --git a/spacecore/linalg/__init__.py b/spacecore/linalg/__init__.py index 03837cc..f3a4a7e 100644 --- a/spacecore/linalg/__init__.py +++ b/spacecore/linalg/__init__.py @@ -1,17 +1,20 @@ from __future__ import annotations from ._cg import CGResult, cg +from ._expm import ExpmMultiplyResult, expm_multiply from ._lanczos import LanczosResult, StochasticLanczosResult, lanczos_smallest, stochastic_lanczos from ._lsqr import LSQRResult, lsqr from ._power import PowerIterationResult, power_iteration __all__ = [ "CGResult", + "ExpmMultiplyResult", "LanczosResult", "LSQRResult", "PowerIterationResult", "StochasticLanczosResult", "cg", + "expm_multiply", "lanczos_smallest", "lsqr", "power_iteration", diff --git a/spacecore/linalg/_expm.py b/spacecore/linalg/_expm.py new file mode 100644 index 0000000..40b68a0 --- /dev/null +++ b/spacecore/linalg/_expm.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from typing import Any, NamedTuple + +from ..linop import LinOp +from ._lanczos import _check_lanczos_max_iter, _lanczos_basis_and_tridiag +from ._utils import require_linop, require_square, result_repr + + +class ExpmMultiplyResult(NamedTuple): + """Result returned by :func:`expm_multiply`. + + Attributes + ---------- + result: + Vector in the domain of the input operator approximating + ``exp(t * A) @ v``. + krylov_dim: + Actual Krylov dimension reached before breakdown or ``max_iter``. + residual_estimate: + Projected exponential residual estimate + ``abs(beta[m] * phi[m - 1])``. + converged: + Boolean indicating whether ``residual_estimate < tol``. + """ + + result: Any + krylov_dim: Any + residual_estimate: Any + converged: Any + + def __repr__(self) -> str: + """Return a compact summary without printing the full vector.""" + return result_repr( + "ExpmMultiplyResult", + { + "converged": self.converged, + "krylov_dim": self.krylov_dim, + "residual_estimate": self.residual_estimate, + "result": self.result, + }, + ) + + +def expm_multiply( + A: LinOp, + v: Any, + t: float | complex = 1.0, + *, + max_iter: int = 30, + tol: float = 1e-10, +) -> ExpmMultiplyResult: + """ + Compute ``exp(t * A) @ v`` for a Hermitian operator via Krylov projection. + + Parameters + ---------- + A: + Square Hermitian linear operator. + v: + Initial vector in ``A.domain``. + t: + Scalar time/scale multiplying ``A``. Complex values are supported for + complex-valued contexts, for example Schrödinger evolution. + max_iter: + Maximum Krylov dimension. Values around 20-50 are usually sufficient + when ``abs(t) * ||A||`` is moderate. + tol: + Breakdown tolerance for Lanczos and threshold for the projected + exponential residual estimate. + + Returns + ------- + ExpmMultiplyResult + Result vector in ``A.domain``, the Krylov dimension used, the standard + estimate ``abs(beta[m] * phi[m - 1])``, and a convergence flag. + """ + A = require_linop(A) + require_square(A, "expm_multiply") + if A.is_hermitian() is False: + raise ValueError("expm_multiply requires A to be Hermitian/self-adjoint.") + max_iter = _check_lanczos_max_iter(max_iter) + A.domain.check_member(v) + + ops = A.ops + ctx = A.ctx + real_dtype = ops.real_dtype(ctx.dtype) + basis = _lanczos_basis_and_tridiag(A, v, max_iter, tol, real_dtype, check_every=1) + + eigvals, eigvecs = ops.eigh(basis.T) + exp_eigs = ops.exp(t * eigvals) + expT_e1 = eigvecs @ (exp_eigs * eigvecs[0, :]) + + V_reduced = basis.V[:max_iter, :] + result_flat = basis.initial_norm * ops.einsum("j,jn->n", expT_e1, V_reduced) + result = A.domain.unflatten(result_flat) + + last_coeff = ops.abs(expT_e1[basis.krylov_dim - 1]) + residual_estimate = basis.betas[basis.krylov_dim] * last_coeff + converged = residual_estimate < basis.tol + + return ExpmMultiplyResult(result, basis.krylov_dim, residual_estimate, converged) diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index f34b353..1051dbe 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -38,6 +38,17 @@ def __repr__(self) -> str: StochasticLanczosResult = LanczosResult +class _LanczosBasisResult(NamedTuple): + V: DenseArray + T: DenseArray + alphas: DenseArray + betas: DenseArray + krylov_dim: Any + initial_norm: Any + tol: Any + e0_unit: DenseArray + + def _check_lanczos_max_iter(max_iter: int) -> int: max_iter = int(max_iter) if max_iter < 1: @@ -45,53 +56,51 @@ def _check_lanczos_max_iter(max_iter: int) -> int: return max_iter -def lanczos_smallest( - A: LinOp, - initial_vector: Any, - *, - max_iter: int = 100, - tol: float = 1e-6, - check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, -) -> LanczosResult: - r"""Approximate the smallest eigenpair of a Hermitian operator. +def _build_tridiagonal( + ops: Any, + alphas: DenseArray, + betas: DenseArray, + max_iter: int, + m: Any, + real_dtype: Any, +) -> DenseArray: + idx = ops.arange(max_iter) + full_indices = ops.arange(max_iter + 1) + mask_alpha = idx < m + inactive_sentinel = ( + ops.max(ops.abs(alphas)) + + 2.0 * ops.max(ops.abs(betas)) + + ops.asarray(1.0, dtype=real_dtype) + ) + alphas_full = ops.where(mask_alpha, alphas, inactive_sentinel) + betas_full = ops.where(full_indices == m, ops.asarray(0.0, dtype=real_dtype), betas) - The operator is supplied as a square ``LinOp`` and ``initial_vector`` is an - element of ``A.domain``. The implementation keeps fixed-size coordinate - arrays for JAX compatibility, safely handles zero initial vectors, and - refines the returned eigenvalue with the Rayleigh quotient of the - reconstructed Ritz vector in the original space. + T = ops.zeros((max_iter, max_iter), dtype=real_dtype) - Mathematically, Lanczos builds an orthonormal Krylov basis ``V`` for - ``span{v, T v, T^2 v, ...}`` and a tridiagonal projection - :math:`T_k = V^\dagger T V`. The returned vector is the Ritz vector - reconstructed in the original coordinates, and the returned scalar is the - Rayleigh quotient - :math:`(x^\dagger T x) / (x^\dagger x)`. + def fill_diag(ii: int, T_in: DenseArray) -> DenseArray: + return ops.index_set(T_in, (ii, ii), alphas_full[ii], copy=True) - Args: - A: Square Hermitian linear operator. - initial_vector: Starting vector in ``A.domain``. - max_iter: Maximum number of Lanczos steps. - tol: Breakdown tolerance for the off-diagonal Lanczos coefficient. - check_every: Refresh the breakdown-based stopping decision only every - this many iterations, and always on the final iteration. + T = ops.fori_loop(0, max_iter, fill_diag, T) - Returns: - ``LanczosResult`` containing the smallest approximated eigenpair, the - standard Ritz residual estimate ``beta[m] * abs(y[m - 1])``, the - Krylov dimension reached, and a convergence flag. The residual estimate - is computed from the tridiagonal recurrence; callers that need the true - residual can evaluate ``A.apply(eigenvector) - eigenvalue * eigenvector`` - once more in the original space. - """ - A = require_linop(A) - require_square(A, "lanczos_smallest") - max_iter = _check_lanczos_max_iter(max_iter) - check_every = check_interval(check_every) - A.domain.check_member(initial_vector) + def fill_off(ii: int, T_in: DenseArray) -> DenseArray: + b = betas_full[ii + 1] + T_in = ops.index_set(T_in, (ii, ii + 1), b, copy=True) + T_in = ops.index_set(T_in, (ii + 1, ii), b, copy=True) + return T_in + + return ops.fori_loop(0, max_iter - 1, fill_off, T) + + +def _lanczos_basis_and_tridiag( + A: LinOp, + initial_vector: Any, + max_iter: int, + tol: float, + real_dtype: Any, + check_every: int, +) -> _LanczosBasisResult: ops = A.ops ctx = A.ctx - real_dtype = ops.real_dtype(ctx.dtype) use_euclidean_reorth = type(A.domain) is VectorSpace v0 = A.domain.flatten(initial_vector) @@ -128,7 +137,6 @@ def lanczos_smallest( keep_going0 = ops.asarray(True) full_indices = ops.arange(max_iter + 1) - idx = ops.arange(max_iter) coeffs_zero = ops.zeros((max_iter + 1,), dtype=ctx.dtype) def cond_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> Any: @@ -200,36 +208,67 @@ def set_next(V_in: DenseArray) -> DenseArray: i_final, V, alphas, betas, _beta_final, _keep_going = ops.while_loop( cond_fun, body_fun, (i0, V, alphas, betas, beta0, keep_going0) ) - m = i_final + T = _build_tridiagonal(ops, alphas, betas, max_iter, i_final, real_dtype) + return _LanczosBasisResult(V, T, alphas, betas, i_final, v0_norm, tol_s, e0_unit) - mask_alpha = idx < m - inactive_sentinel = ( - ops.max(ops.abs(alphas)) - + 2.0 * ops.max(ops.abs(betas)) - + ops.asarray(1.0, dtype=real_dtype) - ) - alphas_full = ops.where(mask_alpha, alphas, inactive_sentinel) - betas_full = ops.where(full_indices == m, ops.asarray(0.0, dtype=real_dtype), betas) - T = ops.zeros((max_iter, max_iter), dtype=real_dtype) +def lanczos_smallest( + A: LinOp, + initial_vector: Any, + *, + max_iter: int = 100, + tol: float = 1e-6, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> LanczosResult: + r"""Approximate the smallest eigenpair of a Hermitian operator. - def fill_diag(ii: int, T_in: DenseArray) -> DenseArray: - return ops.index_set(T_in, (ii, ii), alphas_full[ii], copy=True) + The operator is supplied as a square ``LinOp`` and ``initial_vector`` is an + element of ``A.domain``. The implementation keeps fixed-size coordinate + arrays for JAX compatibility, safely handles zero initial vectors, and + refines the returned eigenvalue with the Rayleigh quotient of the + reconstructed Ritz vector in the original space. - T = ops.fori_loop(0, max_iter, fill_diag, T) + Mathematically, Lanczos builds an orthonormal Krylov basis ``V`` for + ``span{v, T v, T^2 v, ...}`` and a tridiagonal projection + :math:`T_k = V^\dagger T V`. The returned vector is the Ritz vector + reconstructed in the original coordinates, and the returned scalar is the + Rayleigh quotient + :math:`(x^\dagger T x) / (x^\dagger x)`. - def fill_off(ii: int, T_in: DenseArray) -> DenseArray: - b = betas_full[ii + 1] - T_in = ops.index_set(T_in, (ii, ii + 1), b, copy=True) - T_in = ops.index_set(T_in, (ii + 1, ii), b, copy=True) - return T_in + Args: + A: Square Hermitian linear operator. + initial_vector: Starting vector in ``A.domain``. + max_iter: Maximum number of Lanczos steps. + tol: Breakdown tolerance for the off-diagonal Lanczos coefficient. + check_every: Refresh the breakdown-based stopping decision only every + this many iterations, and always on the final iteration. - T = ops.fori_loop(0, max_iter - 1, fill_off, T) + Returns: + ``LanczosResult`` containing the smallest approximated eigenpair, the + standard Ritz residual estimate ``beta[m] * abs(y[m - 1])``, the + Krylov dimension reached, and a convergence flag. The residual estimate + is computed from the tridiagonal recurrence; callers that need the true + residual can evaluate ``A.apply(eigenvector) - eigenvalue * eigenvector`` + once more in the original space. + """ + A = require_linop(A) + require_square(A, "lanczos_smallest") + max_iter = _check_lanczos_max_iter(max_iter) + check_every = check_interval(check_every) + A.domain.check_member(initial_vector) + ops = A.ops + ctx = A.ctx + real_dtype = ops.real_dtype(ctx.dtype) + idx = ops.arange(max_iter) + basis = _lanczos_basis_and_tridiag( + A, initial_vector, max_iter, tol, real_dtype, check_every + ) - _eigvals, eigvecs = ops.eigh(T) + m = basis.krylov_dim + _eigvals, eigvecs = ops.eigh(basis.T) y_full = eigvecs[:, 0] - residual_norm = betas[m] * ops.abs(y_full[m - 1]) - converged = residual_norm < tol_s + residual_norm = basis.betas[m] * ops.abs(y_full[m - 1]) + converged = residual_norm < basis.tol mask_y = ops.where( idx < m, @@ -239,15 +278,15 @@ def fill_off(ii: int, T_in: DenseArray) -> DenseArray: mask_y = ops.astype(mask_y, y_full.dtype) y_valid = y_full * mask_y - V_reduced = V[:max_iter, :] + V_reduced = basis.V[:max_iter, :] x_flat = ops.einsum("j,jn->n", y_valid, V_reduced) x_member = A.domain.unflatten(x_flat) x_norm = A.domain.norm(x_member) x_flat = ops.cond( - x_norm > eps_s, + x_norm > ops.asarray(1e-12, dtype=real_dtype), lambda _: A.domain.flatten(A.domain.scale(safe_inverse_nonneg(ops, x_norm), x_member)), - lambda _: e0_unit, + lambda _: basis.e0_unit, ops.asarray(0.0, dtype=real_dtype), ) diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index 31778f2..1de2587 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -32,6 +32,7 @@ def test_expected_names_are_exported(): "set_resolution_policy", "set_dtype_resolution_policy", "get_resolution_policy", "get_dtype_resolution_policy", "LanczosResult", "StochasticLanczosResult", "lanczos_smallest", + "ExpmMultiplyResult", "expm_multiply", } if has_jax(): expected |= {"JaxOps", "jax_pytree_class"} @@ -65,6 +66,8 @@ def test_top_level_objects_match_source_modules(): assert sc.InnerProductFunctional is functional.InnerProductFunctional assert sc.LanczosResult is linalg.LanczosResult assert sc.StochasticLanczosResult is linalg.StochasticLanczosResult + assert sc.ExpmMultiplyResult is linalg.ExpmMultiplyResult + assert sc.expm_multiply is linalg.expm_multiply assert sc.get_context is contextual.get_context assert sc.resolve_context_priority is contextual.resolve_context_priority diff --git a/tests/linalg/test_expm.py b/tests/linalg/test_expm.py new file mode 100644 index 0000000..55dc504 --- /dev/null +++ b/tests/linalg/test_expm.py @@ -0,0 +1,192 @@ +import importlib + +import numpy as np +import pytest +import scipy.linalg + +from tests._helpers import has_jax, has_torch, jax_complex_dtype, jax_real_dtype, to_numpy +from tests._helpers import torch_real_dtype + + +def _ops_for_backend(name): + sc = importlib.import_module("spacecore") + if name == "numpy": + return sc.NumpyOps() + if name == "jax": + return sc.JaxOps() + if name == "torch": + return sc.TorchOps() + raise ValueError(f"Unknown backend {name!r}.") + + +def _backend_params(): + return [ + pytest.param("numpy", np.float64, id="numpy"), + pytest.param( + "jax", + jax_real_dtype(), + marks=pytest.mark.skipif(not has_jax(), reason="jax is not installed"), + id="jax", + ), + pytest.param( + "torch", + torch_real_dtype(), + marks=pytest.mark.skipif(not has_torch(), reason="torch is not installed"), + id="torch", + ), + ] + + +def _ctx(backend_name="numpy", dtype=np.float64, enable_checks=False): + sc = importlib.import_module("spacecore") + return sc.Context(_ops_for_backend(backend_name), dtype=dtype, enable_checks=enable_checks) + + +def _operator(ctx, matrix): + sc = importlib.import_module("spacecore") + space = sc.VectorSpace((matrix.shape[0],), ctx) + return sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + +def _ground_truth(matrix, vector, t): + return scipy.linalg.expm(t * matrix) @ vector + + +def test_expm_multiply_t_zero_returns_input(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + matrix = np.array([[2.0, 0.5], [0.5, 3.0]]) + A = _operator(ctx, matrix) + v = ctx.asarray([1.0, -2.0]) + + result = sc.expm_multiply(A, v, t=0.0, max_iter=4) + + np.testing.assert_allclose(result.result, v, rtol=1e-12, atol=1e-12) + assert isinstance(result, sc.ExpmMultiplyResult) + + +def test_expm_multiply_rejects_structurally_non_hermitian_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _operator(ctx, np.array([[1.0, 2.0], [0.0, 3.0]])) + v = ctx.asarray([1.0, -2.0]) + + with pytest.raises(ValueError, match="Hermitian"): + sc.expm_multiply(A, v, t=0.1, max_iter=4) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_expm_multiply_matches_dense_ground_truth(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + matrix = np.array([[1.0, 0.25, 0.0], [0.25, 2.0, -0.5], [0.0, -0.5, 3.0]]) + A = _operator(ctx, matrix) + v_np = np.array([1.0, -2.0, 0.5]) + v = ctx.asarray(v_np) + + result = sc.expm_multiply(A, v, t=-0.2, max_iter=8, tol=1e-12) + + np.testing.assert_allclose( + to_numpy(result.result), + _ground_truth(matrix, v_np, -0.2), + rtol=1e-5, + atol=1e-5, + ) + assert bool(to_numpy(result.converged)) + + +def test_expm_multiply_is_linear_in_vector(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _operator(ctx, np.array([[1.0, 0.5], [0.5, 2.0]])) + v1 = ctx.asarray([1.0, -1.0]) + v2 = ctx.asarray([0.5, 2.0]) + alpha = 1.5 + beta = -0.25 + + combined = sc.expm_multiply(A, alpha * v1 + beta * v2, t=0.3, max_iter=6).result + expected = ( + alpha * sc.expm_multiply(A, v1, t=0.3, max_iter=6).result + + beta * sc.expm_multiply(A, v2, t=0.3, max_iter=6).result + ) + + np.testing.assert_allclose(combined, expected, rtol=1e-10, atol=1e-10) + + +def test_expm_multiply_group_property(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _operator(ctx, np.array([[1.0, 0.5], [0.5, 2.0]])) + v = ctx.asarray([1.0, -1.0]) + t1 = 0.2 + t2 = -0.35 + + first = sc.expm_multiply(A, v, t=t1, max_iter=6).result + sequential = sc.expm_multiply(A, first, t=t2, max_iter=6).result + direct = sc.expm_multiply(A, v, t=t1 + t2, max_iter=6).result + + np.testing.assert_allclose(sequential, direct, rtol=1e-10, atol=1e-10) + + +def test_expm_multiply_complex_time_is_unitary_for_hermitian_generator(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + matrix = np.array([[1.0, 0.5 - 0.25j], [0.5 + 0.25j, 2.0]], dtype=np.complex128) + A = _operator(ctx, matrix) + v = ctx.asarray([1.0 + 0.5j, -0.25 + 0.75j]) + + result = sc.expm_multiply(A, v, t=-0.5j, max_iter=6, tol=1e-12).result + + np.testing.assert_allclose( + to_numpy(A.domain.norm(result)), + to_numpy(A.domain.norm(v)), + rtol=1e-10, + atol=1e-10, + ) + + +def test_expm_multiply_complex_time_matches_dense_truth(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + matrix = np.array([[1.0, 0.25], [0.25, 2.0]], dtype=np.complex128) + A = _operator(ctx, matrix) + v_np = np.array([1.0 - 0.5j, 0.25 + 0.75j]) + v = ctx.asarray(v_np) + + result = sc.expm_multiply(A, v, t=-0.5j, max_iter=6, tol=1e-12) + + np.testing.assert_allclose( + to_numpy(result.result), + _ground_truth(matrix, v_np, -0.5j), + rtol=1e-10, + atol=1e-10, + ) + + +def test_expm_multiply_residual_estimate_decreases_with_more_iterations(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + matrix = np.array([[1.0, 0.25, 0.1], [0.25, 2.0, -0.5], [0.1, -0.5, 4.0]]) + A = _operator(ctx, matrix) + v = ctx.asarray([1.0, -2.0, 0.5]) + + low = sc.expm_multiply(A, v, t=0.4, max_iter=1, tol=1e-12) + high = sc.expm_multiply(A, v, t=0.4, max_iter=3, tol=1e-12) + + assert to_numpy(high.residual_estimate) <= to_numpy(low.residual_estimate) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_expm_multiply_jit_matches_eager(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_complex_dtype()) + matrix = np.array([[1.0, 0.25], [0.25, 2.0]], dtype=np.complex128) + A = _operator(ctx, matrix) + v = ctx.asarray([1.0 - 0.5j, 0.25 + 0.75j]) + + eager = sc.expm_multiply(A, v, t=-0.5j, max_iter=6).result + run = jax.jit(lambda op, x: sc.expm_multiply(op, x, t=-0.5j, max_iter=6).result) + compiled = run(A, v) + + np.testing.assert_allclose(to_numpy(compiled), to_numpy(eager), rtol=1e-6, atol=1e-6) From 727eb92f2245fee69ab6e5a06c644e2816826b10 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 02:23:03 -0300 Subject: [PATCH 29/44] Update JIT audit for expm multiply --- audit_jit.md | 23 +- tests/fixtures/jaxpr_lanczos_smallest.txt | 700 +++++++++++----------- 2 files changed, 348 insertions(+), 375 deletions(-) diff --git a/audit_jit.md b/audit_jit.md index 8066ea6..ced8e85 100644 --- a/audit_jit.md +++ b/audit_jit.md @@ -21,11 +21,11 @@ wrapper. | `lsqr` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `maxiter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output for `lsqr`. | | `lanczos_smallest` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `max_iter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output for `lanczos_smallest`. | | `power_iteration` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `maxiter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output for `power_iteration`. | -| `expm_multiply` | Not audited yet | Not applicable | Not applicable | Not implemented before Task 1. The script reports `{'solver': 'expm_multiply', 'status': 'not available before Task 1'}` and should be rerun after Task 1. | +| `expm_multiply` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `max_iter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output: `{'solver': 'expm_multiply', 'traces_after_two_same_shape_calls': 1, 'traces_after_static_change': 2, 'traces_after_shape_change': 3, ...}` | ## Findings -### 1. Lanczos full reorthogonalization lowers to an inner scan +### 1. Lanczos full reorthogonalization now uses a vectorized exact-VectorSpace path `tests/fixtures/jaxpr_lanczos_smallest.txt` captures the current JAXPR for `lanczos_smallest(max_iter=3)`. Grep evidence: @@ -34,7 +34,7 @@ wrapper. grep -n "scan\\|while\\|scatter\\|dot_general" tests/fixtures/jaxpr_lanczos_smallest.txt ``` -The fixture shows: +The initial audit fixture showed: - a top-level `while` at line 61 for the Krylov iteration; - a nested `scan` at line 143 corresponding to @@ -50,6 +50,11 @@ Important constraint: this replacement is **not valid for arbitrary `Space.inner(v_j, w)`. Therefore the optimization should be guarded to the exact `VectorSpace` type only, not subclasses. +The implemented exact-`VectorSpace` path now lowers the coefficient computation +to a single `dot_general` at fixture line 143. Remaining `scan` operations in +the fixture are the fixed-size tridiagonal construction loops, not the +reorthogonalization coefficient loop. + ### 2. `cg` and `lsqr` use `ops.cond` correctly for periodic diagnostics Both solvers trace without errors and do not retrace on value-only changes. The @@ -80,12 +85,12 @@ work arrays change. Scalars such as tolerances are currently Python arguments converted through `ops.asarray`; changing them may retrace unless callers pass them through a wrapper as array values. This is acceptable for the current API. -## Recommended Implementation +## Implemented Change -- Add a Euclidean fast path for Lanczos reorthogonalization when +- Added a Euclidean fast path for Lanczos reorthogonalization when `type(A.domain) is VectorSpace`: compute all coefficients with `ops.einsum("jn,n->j", ops.conj(V_), w)`. -- Keep the existing `Space.inner` loop for all non-exact `VectorSpace` domains +- Kept the existing `Space.inner` loop for all non-exact `VectorSpace` domains to preserve Space geometry. This is the only validated change from this audit. The broader replacement @@ -94,10 +99,10 @@ geometry-correct Lanczos recurrence. ## Follow-Up TODO -1. Rerun `scripts/jit_audit.py` after `expm_multiply` lands and update this - document with its trace counts. -2. Consider a benchmark for exact `VectorSpace` Lanczos before/after the +1. Consider a benchmark for exact `VectorSpace` Lanczos before/after the reorthogonalization fast path. +2. Consider vectorizing fixed-size tridiagonal construction if the remaining + construction `scan` nodes show up in JAX profiling. 3. If users report trace-time issues from constructing algebra expressions inside `jax.jit`, document that operators should be built outside the jitted function. diff --git a/tests/fixtures/jaxpr_lanczos_smallest.txt b/tests/fixtures/jaxpr_lanczos_smallest.txt index 8361375..c3849c6 100644 --- a/tests/fixtures/jaxpr_lanczos_smallest.txt +++ b/tests/fixtures/jaxpr_lanczos_smallest.txt @@ -4,481 +4,449 @@ let _where = { lambda ; a:bool[] b:f32[] c:f32[]. let { lambda ; e:f32[2,2] f:f32[2]. let _:f32[2,2] = transpose[permutation=(1, 0)] e _:f32[2,2] = transpose[permutation=(1, 0)] e - g:f32[4,2] = broadcast_in_dim 0.0:f32[] - h:f32[3] = broadcast_in_dim 0.0:f32[] - i:f32[4] = broadcast_in_dim 0.0:f32[] - j:f32[] = dot_general[ + g:i32[3] = iota[dimension=0 dtype=int32 shape=(3,) sharding=None] + h:f32[4,2] = broadcast_in_dim 0.0:f32[] + i:f32[3] = broadcast_in_dim 0.0:f32[] + j:f32[4] = broadcast_in_dim 0.0:f32[] + k:f32[] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 ] f f - k:f32[] = sqrt j - l:f32[2] = broadcast_in_dim 0.0:f32[] - m:i32[1] = broadcast_in_dim 0:i32[] - n:f32[2] = scatter[ + l:f32[] = sqrt k + m:f32[2] = broadcast_in_dim 0.0:f32[] + n:i32[1] = broadcast_in_dim 0:i32[] + o:f32[2] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) indices_are_sorted=True mode=GatherScatterMode.FILL_OR_DROP unique_indices=True update_consts=() update_jaxpr=None - ] l m 1.0:f32[] - o:f32[] = dot_general[ + ] m n 1.0:f32[] + p:f32[] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 - ] n n - p:f32[] = sqrt o - q:bool[] = gt p 0.0:f32[] - r:f32[] = jit[name=_where jaxpr=_where] q p 1.0:f32[] - s:f32[] = div 1.0:f32[] r - t:f32[] = jit[name=_where jaxpr=_where] q s 0.0:f32[] - u:f32[2] = mul t n - v:bool[] = gt k 9.999999960041972e-13:f32[] - w:i32[] = convert_element_type[new_dtype=int32 weak_type=False] v - x:f32[2] = cond[ + ] o o + q:f32[] = sqrt p + r:bool[] = gt q 0.0:f32[] + s:f32[] = jit[name=_where jaxpr=_where] r q 1.0:f32[] + t:f32[] = div 1.0:f32[] s + u:f32[] = jit[name=_where jaxpr=_where] r t 0.0:f32[] + v:f32[2] = mul u o + w:bool[] = gt l 9.999999960041972e-13:f32[] + x:i32[] = convert_element_type[new_dtype=int32 weak_type=False] w + y:f32[2] = cond[ branches=( - { lambda ; y:f32[] z:f32[2] ba:f32[2] bb:f32[]. let in (ba,) } - { lambda ; bc:f32[] bd:f32[2] be:f32[2] bf:f32[]. let - bg:bool[] = gt bc 0.0:f32[] - bh:f32[] = jit[name=_where jaxpr=_where] bg bc 1.0:f32[] - bi:f32[] = div 1.0:f32[] bh - bj:f32[] = jit[name=_where jaxpr=_where] bg bi 0.0:f32[] - bk:f32[2] = mul bj bd - in (bk,) } + { lambda ; z:f32[] ba:f32[2] bb:f32[2] bc:f32[]. let in (bb,) } + { lambda ; bd:f32[] be:f32[2] bf:f32[2] bg:f32[]. let + bh:bool[] = gt bd 0.0:f32[] + bi:f32[] = jit[name=_where jaxpr=_where] bh bd 1.0:f32[] + bj:f32[] = div 1.0:f32[] bi + bk:f32[] = jit[name=_where jaxpr=_where] bh bj 0.0:f32[] + bl:f32[2] = mul bk be + in (bl,) } ) - ] w k f u 0.0:f32[] - bl:i32[1] = broadcast_in_dim 0:i32[] - bm:f32[4,2] = scatter[ + ] x l f v 0.0:f32[] + bm:i32[1] = broadcast_in_dim 0:i32[] + bn:f32[4,2] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) indices_are_sorted=True mode=GatherScatterMode.FILL_OR_DROP unique_indices=True update_consts=() update_jaxpr=None - ] g bl x - bn:i32[4] = iota[dimension=0 dtype=int32 shape=(4,) sharding=None] - bo:i32[3] = iota[dimension=0 dtype=int32 shape=(3,) sharding=None] - bp:f32[4] = broadcast_in_dim 0.0:f32[] - bq:i32[] br:f32[4,2] bs:f32[3] bt:f32[4] _:f32[] _:bool[] = while[ - body_jaxpr={ lambda ; bu:f32[2,2] bv:i32[4] bw:f32[4] bx:f32[] by:i32[] bz:f32[4,2] - ca:f32[3] cb:f32[4] cc:f32[] cd:bool[]. let - ce:i32[] = convert_element_type[new_dtype=int32 weak_type=False] by - cf:bool[] = lt ce 0:i32[] - cg:i32[] = add ce 4:i32[] - ch:i32[] = select_n cf ce cg - ci:bool[] = lt 0:i32[] 0:i32[] - cj:i32[] = add 0:i32[] 2:i32[] - ck:i32[] = select_n ci 0:i32[] cj - cl:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] bz ch ck - cm:f32[2] = squeeze[dimensions=(0,)] cl - cn:f32[2] = dot_general[ + ] h bm y + bo:i32[4] = iota[dimension=0 dtype=int32 shape=(4,) sharding=None] + _:f32[4] = broadcast_in_dim 0.0:f32[] + bp:i32[] bq:f32[4,2] br:f32[3] bs:f32[4] _:f32[] _:bool[] = while[ + body_jaxpr={ lambda ; bt:f32[2,2] bu:i32[4] bv:f32[] bw:i32[] bx:f32[4,2] by:f32[3] + bz:f32[4] ca:f32[] cb:bool[]. let + cc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bw + cd:bool[] = lt cc 0:i32[] + ce:i32[] = add cc 4:i32[] + cf:i32[] = select_n cd cc ce + cg:bool[] = lt 0:i32[] 0:i32[] + ch:i32[] = add 0:i32[] 2:i32[] + ci:i32[] = select_n cg 0:i32[] ch + cj:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] bx cf ci + ck:f32[2] = squeeze[dimensions=(0,)] cj + cl:f32[2] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 - ] bu cm - co:f32[] = dot_general[ + ] bt ck + cm:f32[] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 - ] cm cn - cp:bool[] = lt by 0:i32[] - cq:i32[] = add by 3:i32[] - cr:i32[] = select_n cp by cq - cs:i32[] = convert_element_type[new_dtype=int32 weak_type=False] cr - ct:i32[1] = broadcast_in_dim cs - cu:f32[3] = scatter[ + ] ck cl + cn:bool[] = lt bw 0:i32[] + co:i32[] = add bw 3:i32[] + cp:i32[] = select_n cn bw co + cq:i32[] = convert_element_type[new_dtype=int32 weak_type=False] cp + cr:i32[1] = broadcast_in_dim cq + cs:f32[3] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) indices_are_sorted=True mode=GatherScatterMode.FILL_OR_DROP unique_indices=True update_consts=() update_jaxpr=None - ] ca ct co - cv:bool[] = eq by 0:i32[] - cw:i32[] = convert_element_type[new_dtype=int32 weak_type=False] cv - cx:f32[2] = cond[ + ] by cr cm + ct:bool[] = eq bw 0:i32[] + cu:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ct + cv:f32[2] = cond[ branches=( - { lambda ; cy:f32[] cz:f32[2] da:i32[] db:f32[4] dc:f32[4,2] dd:f32[2]. let - de:f32[2] = mul cy cz - df:f32[2] = sub dd de - dg:bool[] = lt da 0:i32[] - dh:i32[] = convert_element_type[ + { lambda ; cw:f32[] cx:f32[2] cy:i32[] cz:f32[4] da:f32[4,2] db:f32[2]. let + dc:f32[2] = mul cw cx + dd:f32[2] = sub db dc + de:bool[] = lt cy 0:i32[] + df:i32[] = convert_element_type[ new_dtype=int32 weak_type=False - ] da - di:i32[] = add dh 4:i32[] - dj:i32[] = select_n dg da di - dk:f32[1] = dynamic_slice[slice_sizes=(1,)] db dj - dl:f32[] = squeeze[dimensions=(0,)] dk - dm:i32[] = sub da 1:i32[] - dn:i32[] = convert_element_type[ + ] cy + dg:i32[] = add df 4:i32[] + dh:i32[] = select_n de cy dg + di:f32[1] = dynamic_slice[slice_sizes=(1,)] cz dh + dj:f32[] = squeeze[dimensions=(0,)] di + dk:i32[] = sub cy 1:i32[] + dl:i32[] = convert_element_type[ new_dtype=int32 weak_type=False - ] dm - do:bool[] = lt dn 0:i32[] - dp:i32[] = add dn 4:i32[] - dq:i32[] = select_n do dn dp - dr:bool[] = lt 0:i32[] 0:i32[] - ds:i32[] = add 0:i32[] 2:i32[] - dt:i32[] = select_n dr 0:i32[] ds - du:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] dc dq dt - dv:f32[2] = squeeze[dimensions=(0,)] du - dw:f32[2] = mul dl dv - dx:f32[2] = sub df dw - in (dx,) } - { lambda ; dy:f32[] dz:f32[2] ea:i32[] eb:f32[4] ec:f32[4,2] ed:f32[2]. let - ee:f32[2] = mul dy dz - ef:f32[2] = sub ed ee - in (ef,) } + ] dk + dm:bool[] = lt dl 0:i32[] + dn:i32[] = add dl 4:i32[] + do:i32[] = select_n dm dl dn + dp:bool[] = lt 0:i32[] 0:i32[] + dq:i32[] = add 0:i32[] 2:i32[] + dr:i32[] = select_n dp 0:i32[] dq + ds:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] da do dr + dt:f32[2] = squeeze[dimensions=(0,)] ds + du:f32[2] = mul dj dt + dv:f32[2] = sub dd du + in (dv,) } + { lambda ; dw:f32[] dx:f32[2] dy:i32[] dz:f32[4] ea:f32[4,2] eb:f32[2]. let + ec:f32[2] = mul dw dx + ed:f32[2] = sub eb ec + in (ed,) } ) - ] cw co cm by cb bz cn - eg:i32[] = add by 1:i32[] - eh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] eg - ei:bool[4] = lt bv eh - ej:f32[4] = jit[ + ] cu cm ck bw bz bx cl + ee:i32[] = add bw 1:i32[] + ef:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ee + eg:bool[4] = lt bu ef + eh:f32[4] = jit[ name=_where - jaxpr={ lambda ; ei:bool[4] ek:f32[] el:f32[]. let - em:f32[4] = broadcast_in_dim ek - en:f32[4] = broadcast_in_dim el - ej:f32[4] = select_n ei en em - in (ej,) } - ] ei 1.0:f32[] 0.0:f32[] - _:i32[] eo:f32[4] = scan[ - _split_transpose=False - jaxpr={ lambda ; ep:f32[4,2] eq:f32[2] er:i32[] es:f32[4]. let - et:i32[] = add er 1:i32[] - eu:i32[] = convert_element_type[new_dtype=int32 weak_type=False] er - ev:bool[] = lt eu 0:i32[] - ew:i32[] = add eu 4:i32[] - ex:i32[] = select_n ev eu ew - ey:bool[] = lt 0:i32[] 0:i32[] - ez:i32[] = add 0:i32[] 2:i32[] - fa:i32[] = select_n ey 0:i32[] ez - fb:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] ep ex fa - fc:f32[2] = squeeze[dimensions=(0,)] fb - fd:f32[] = dot_general[ - dimension_numbers=(([0], [0]), ([], [])) - preferred_element_type=float32 - ] fc eq - fe:bool[] = lt er 0:i32[] - ff:i32[] = add er 4:i32[] - fg:i32[] = select_n fe er ff - fh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fg - fi:i32[1] = broadcast_in_dim fh - fj:f32[4] = scatter[ - dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) - indices_are_sorted=True - mode=GatherScatterMode.FILL_OR_DROP - unique_indices=True - update_consts=() - update_jaxpr=None - ] es fi fd - in (et, fj) } - length=4 - linear=(False, False, False, False) - num_carry=2 - num_consts=2 - reverse=False - unroll=1 - ] bz cx 0:i32[] bw - fk:f32[4] = mul eo ej - fl:f32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,)] fk - fm:f32[4,2] = mul fl bz - fn:f32[2] = reduce_sum[axes=(0,) out_sharding=None] fm - fo:f32[2] = sub cx fn - fp:f32[] = dot_general[ + jaxpr={ lambda ; eg:bool[4] ei:f32[] ej:f32[]. let + ek:f32[4] = broadcast_in_dim ei + el:f32[4] = broadcast_in_dim ej + eh:f32[4] = select_n eg el ek + in (eh,) } + ] eg 1.0:f32[] 0.0:f32[] + em:f32[4] = dot_general[ + dimension_numbers=(([1], [0]), ([], [])) + preferred_element_type=float32 + ] bx cv + en:f32[4] = mul em eh + eo:f32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,)] en + ep:f32[4,2] = mul eo bx + eq:f32[2] = reduce_sum[axes=(0,) out_sharding=None] ep + er:f32[2] = sub cv eq + es:f32[] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 - ] fo fo - fq:f32[] = sqrt fp - fr:i32[] = add by 1:i32[] - fs:bool[] = lt fr 0:i32[] - ft:i32[] = add fr 4:i32[] - fu:i32[] = select_n fs fr ft - fv:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fu - fw:i32[1] = broadcast_in_dim fv - fx:f32[4] = scatter[ + ] er er + et:f32[] = sqrt es + eu:i32[] = add bw 1:i32[] + ev:bool[] = lt eu 0:i32[] + ew:i32[] = add eu 4:i32[] + ex:i32[] = select_n ev eu ew + ey:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ex + ez:i32[1] = broadcast_in_dim ey + fa:f32[4] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) indices_are_sorted=True mode=GatherScatterMode.FILL_OR_DROP unique_indices=True update_consts=() update_jaxpr=None - ] cb fw fq - fy:bool[] = ge fq bx - fz:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fy - ga:f32[4,2] = cond[ + ] bz ez et + fb:bool[] = ge et bv + fc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fb + fd:f32[4,2] = cond[ branches=( - { lambda ; gb:f32[] gc:f32[2] gd:i32[] ge:f32[4,2]. let + { lambda ; fe:f32[] ff:f32[2] fg:i32[] fh:f32[4,2]. let - in (ge,) } - { lambda ; gf:f32[] gg:f32[2] gh:i32[] gi:f32[4,2]. let - gj:bool[] = gt gf 0.0:f32[] - gk:f32[] = jit[name=_where jaxpr=_where] gj gf 1.0:f32[] - gl:f32[] = div 1.0:f32[] gk - gm:f32[] = jit[name=_where jaxpr=_where] gj gl 0.0:f32[] - gn:f32[2] = mul gm gg - go:i32[] = add gh 1:i32[] - gp:bool[] = lt go 0:i32[] - gq:i32[] = add go 4:i32[] - gr:i32[] = select_n gp go gq - gs:i32[] = convert_element_type[ + in (fh,) } + { lambda ; fi:f32[] fj:f32[2] fk:i32[] fl:f32[4,2]. let + fm:bool[] = gt fi 0.0:f32[] + fn:f32[] = jit[name=_where jaxpr=_where] fm fi 1.0:f32[] + fo:f32[] = div 1.0:f32[] fn + fp:f32[] = jit[name=_where jaxpr=_where] fm fo 0.0:f32[] + fq:f32[2] = mul fp fj + fr:i32[] = add fk 1:i32[] + fs:bool[] = lt fr 0:i32[] + ft:i32[] = add fr 4:i32[] + fu:i32[] = select_n fs fr ft + fv:i32[] = convert_element_type[ new_dtype=int32 weak_type=False - ] gr - gt:i32[1] = broadcast_in_dim gs - gu:f32[4,2] = scatter[ + ] fu + fw:i32[1] = broadcast_in_dim fv + fx:f32[4,2] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) indices_are_sorted=True mode=GatherScatterMode.FILL_OR_DROP unique_indices=True update_consts=() update_jaxpr=None - ] gi gt gn - in (gu,) } + ] fl fw fq + in (fx,) } ) - ] fz fq fo by bz - gv:i32[] = add by 1:i32[] - gw:bool[] = ge gv 3:i32[] - gx:i32[] = jit[ + ] fc et er bw bx + fy:i32[] = add bw 1:i32[] + fz:bool[] = ge fy 3:i32[] + ga:i32[] = jit[ name=remainder - jaxpr={ lambda ; gv:i32[] gy:i32[]. let - gz:bool[] = eq gy 0:i32[] - ha:i32[] = jit[ + jaxpr={ lambda ; fy:i32[] gb:i32[]. let + gc:bool[] = eq gb 0:i32[] + gd:i32[] = jit[ name=_where - jaxpr={ lambda ; gz:bool[] hb:i32[] gy:i32[]. let - ha:i32[] = select_n gz gy hb - in (ha,) } - ] gz 1:i32[] gy - hc:i32[] = rem gv ha - hd:bool[] = ne hc 0:i32[] - he:bool[] = lt hc 0:i32[] - hf:bool[] = lt ha 0:i32[] - hg:bool[] = ne he hf - hh:bool[] = and hg hd - hi:i32[] = add hc ha - gx:i32[] = select_n hh hc hi - in (gx,) } - ] gv 1:i32[] - hj:bool[] = eq gx 0:i32[] - hk:bool[] = convert_element_type[new_dtype=bool weak_type=False] gw - hl:bool[] = convert_element_type[new_dtype=bool weak_type=False] hj - hm:bool[] = or hk hl - hn:i32[] = convert_element_type[new_dtype=int32 weak_type=False] hm - ho:bool[] = cond[ + jaxpr={ lambda ; gc:bool[] ge:i32[] gb:i32[]. let + gd:i32[] = select_n gc gb ge + in (gd,) } + ] gc 1:i32[] gb + gf:i32[] = rem fy gd + gg:bool[] = ne gf 0:i32[] + gh:bool[] = lt gf 0:i32[] + gi:bool[] = lt gd 0:i32[] + gj:bool[] = ne gh gi + gk:bool[] = and gj gg + gl:i32[] = add gf gd + ga:i32[] = select_n gk gf gl + in (ga,) } + ] fy 1:i32[] + gm:bool[] = eq ga 0:i32[] + gn:bool[] = convert_element_type[new_dtype=bool weak_type=False] fz + go:bool[] = convert_element_type[new_dtype=bool weak_type=False] gm + gp:bool[] = or gn go + gq:i32[] = convert_element_type[new_dtype=int32 weak_type=False] gp + gr:bool[] = cond[ branches=( - { lambda ; hp:f32[] hq:f32[] hr:bool[] hs:f32[]. let in (hr,) } - { lambda ; ht:f32[] hu:f32[] hv:bool[] hw:f32[]. let - hx:bool[] = ge ht hu - in (hx,) } + { lambda ; gs:f32[] gt:f32[] gu:bool[] gv:f32[]. let in (gu,) } + { lambda ; gw:f32[] gx:f32[] gy:bool[] gz:f32[]. let + ha:bool[] = ge gw gx + in (ha,) } ) - ] hn fq bx cd 0.0:f32[] - in (gv, ga, cu, fx, fq, ho) } - body_nconsts=4 - cond_jaxpr={ lambda ; hy:i32[] hz:f32[4,2] ia:f32[3] ib:f32[4] ic:f32[] id:bool[]. let - ie:bool[] = lt hy 3:i32[] - if:bool[] = convert_element_type[new_dtype=bool weak_type=False] ie - ig:bool[] = and if id - in (ig,) } + ] gq et bv cb 0.0:f32[] + in (fy, fd, cs, fa, et, gr) } + body_nconsts=3 + cond_jaxpr={ lambda ; hb:i32[] hc:f32[4,2] hd:f32[3] he:f32[4] hf:f32[] hg:bool[]. let + hh:bool[] = lt hb 3:i32[] + hi:bool[] = convert_element_type[new_dtype=bool weak_type=False] hh + hj:bool[] = and hi hg + in (hj,) } cond_nconsts=0 - ] e bn bp 9.999999974752427e-07:f32[] 0:i32[] bm h i 1.0:f32[] True:bool[] - ih:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq - ii:bool[3] = lt bo ih - ij:f32[3] = abs bs - ik:f32[] = reduce_max[axes=(0,)] ij - il:f32[4] = abs bt - im:f32[] = reduce_max[axes=(0,)] il - in:f32[] = mul 2.0:f32[] im - io:f32[] = add ik in - ip:f32[] = add io 1.0:f32[] - iq:f32[3] = jit[ + ] e bo 9.999999974752427e-07:f32[] 0:i32[] bn i j 1.0:f32[] True:bool[] + hk:i32[3] = iota[dimension=0 dtype=int32 shape=(3,) sharding=None] + hl:i32[4] = iota[dimension=0 dtype=int32 shape=(4,) sharding=None] + hm:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bp + hn:bool[3] = lt hk hm + ho:f32[3] = abs br + hp:f32[] = reduce_max[axes=(0,)] ho + hq:f32[4] = abs bs + hr:f32[] = reduce_max[axes=(0,)] hq + hs:f32[] = mul 2.0:f32[] hr + ht:f32[] = add hp hs + hu:f32[] = add ht 1.0:f32[] + hv:f32[3] = jit[ name=_where - jaxpr={ lambda ; ii:bool[3] bs:f32[3] ip:f32[]. let - ir:f32[3] = broadcast_in_dim ip - iq:f32[3] = select_n ii ir bs - in (iq,) } - ] ii bs ip - is:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq - it:bool[4] = eq bn is - iu:f32[4] = jit[ + jaxpr={ lambda ; hn:bool[3] br:f32[3] hu:f32[]. let + hw:f32[3] = broadcast_in_dim hu + hv:f32[3] = select_n hn hw br + in (hv,) } + ] hn br hu + hx:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bp + hy:bool[4] = eq hl hx + hz:f32[4] = jit[ name=_where - jaxpr={ lambda ; it:bool[4] iv:f32[] bt:f32[4]. let - iw:f32[4] = broadcast_in_dim iv - iu:f32[4] = select_n it bt iw - in (iu,) } - ] it 0.0:f32[] bt - ix:f32[3,3] = broadcast_in_dim 0.0:f32[] - _:i32[] iy:f32[3,3] = scan[ + jaxpr={ lambda ; hy:bool[4] ia:f32[] bs:f32[4]. let + ib:f32[4] = broadcast_in_dim ia + hz:f32[4] = select_n hy bs ib + in (hz,) } + ] hy 0.0:f32[] bs + ic:f32[3,3] = broadcast_in_dim 0.0:f32[] + _:i32[] id:f32[3,3] = scan[ _split_transpose=False - jaxpr={ lambda ; iz:f32[3] ja:i32[] jb:f32[3,3]. let - jc:i32[] = add ja 1:i32[] - jd:bool[] = lt ja 0:i32[] - je:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ja - jf:i32[] = add je 3:i32[] - jg:i32[] = select_n jd ja jf - jh:f32[1] = dynamic_slice[slice_sizes=(1,)] iz jg - ji:f32[] = squeeze[dimensions=(0,)] jh - jj:bool[] = lt ja 0:i32[] - jk:i32[] = add ja 3:i32[] - jl:i32[] = select_n jj ja jk - jm:bool[] = lt ja 0:i32[] - jn:i32[] = add ja 3:i32[] - jo:i32[] = select_n jm ja jn - jp:i32[] = convert_element_type[new_dtype=int32 weak_type=False] jl - jq:i32[] = convert_element_type[new_dtype=int32 weak_type=False] jo - jr:i32[1] = broadcast_in_dim jp - js:i32[1] = broadcast_in_dim jq - jt:i32[2] = concatenate[dimension=0] jr js - ju:f32[3,3] = scatter[ + jaxpr={ lambda ; ie:f32[3] if:i32[] ig:f32[3,3]. let + ih:i32[] = add if 1:i32[] + ii:bool[] = lt if 0:i32[] + ij:i32[] = convert_element_type[new_dtype=int32 weak_type=False] if + ik:i32[] = add ij 3:i32[] + il:i32[] = select_n ii if ik + im:f32[1] = dynamic_slice[slice_sizes=(1,)] ie il + in:f32[] = squeeze[dimensions=(0,)] im + io:bool[] = lt if 0:i32[] + ip:i32[] = add if 3:i32[] + iq:i32[] = select_n io if ip + ir:bool[] = lt if 0:i32[] + is:i32[] = add if 3:i32[] + it:i32[] = select_n ir if is + iu:i32[] = convert_element_type[new_dtype=int32 weak_type=False] iq + iv:i32[] = convert_element_type[new_dtype=int32 weak_type=False] it + iw:i32[1] = broadcast_in_dim iu + ix:i32[1] = broadcast_in_dim iv + iy:i32[2] = concatenate[dimension=0] iw ix + iz:f32[3,3] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1), operand_batching_dims=(), scatter_indices_batching_dims=()) indices_are_sorted=True mode=GatherScatterMode.FILL_OR_DROP unique_indices=True update_consts=() update_jaxpr=None - ] jb jt ji - in (jc, ju) } + ] ig iy in + in (ih, iz) } length=3 linear=(False, False, False) num_carry=2 num_consts=1 reverse=False unroll=1 - ] iq 0:i32[] ix - _:i32[] jv:f32[3,3] = scan[ + ] hv 0:i32[] ic + _:i32[] ja:f32[3,3] = scan[ _split_transpose=False - jaxpr={ lambda ; jw:f32[4] jx:i32[] jy:f32[3,3]. let - jz:i32[] = add jx 1:i32[] - ka:i32[] = add jx 1:i32[] - kb:bool[] = lt ka 0:i32[] - kc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ka - kd:i32[] = add kc 4:i32[] - ke:i32[] = select_n kb ka kd - kf:f32[1] = dynamic_slice[slice_sizes=(1,)] jw ke - kg:f32[] = squeeze[dimensions=(0,)] kf - kh:i32[] = add jx 1:i32[] - ki:bool[] = lt jx 0:i32[] - kj:i32[] = add jx 3:i32[] - kk:i32[] = select_n ki jx kj - kl:bool[] = lt kh 0:i32[] - km:i32[] = add kh 3:i32[] - kn:i32[] = select_n kl kh km - ko:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kk - kp:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kn - kq:i32[1] = broadcast_in_dim ko - kr:i32[1] = broadcast_in_dim kp - ks:i32[2] = concatenate[dimension=0] kq kr - kt:f32[3,3] = scatter[ + jaxpr={ lambda ; jb:f32[4] jc:i32[] jd:f32[3,3]. let + je:i32[] = add jc 1:i32[] + jf:i32[] = add jc 1:i32[] + jg:bool[] = lt jf 0:i32[] + jh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] jf + ji:i32[] = add jh 4:i32[] + jj:i32[] = select_n jg jf ji + jk:f32[1] = dynamic_slice[slice_sizes=(1,)] jb jj + jl:f32[] = squeeze[dimensions=(0,)] jk + jm:i32[] = add jc 1:i32[] + jn:bool[] = lt jc 0:i32[] + jo:i32[] = add jc 3:i32[] + jp:i32[] = select_n jn jc jo + jq:bool[] = lt jm 0:i32[] + jr:i32[] = add jm 3:i32[] + js:i32[] = select_n jq jm jr + jt:i32[] = convert_element_type[new_dtype=int32 weak_type=False] jp + ju:i32[] = convert_element_type[new_dtype=int32 weak_type=False] js + jv:i32[1] = broadcast_in_dim jt + jw:i32[1] = broadcast_in_dim ju + jx:i32[2] = concatenate[dimension=0] jv jw + jy:f32[3,3] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1), operand_batching_dims=(), scatter_indices_batching_dims=()) indices_are_sorted=True mode=GatherScatterMode.FILL_OR_DROP unique_indices=True update_consts=() update_jaxpr=None - ] jy ks kg - ku:i32[] = add jx 1:i32[] - kv:bool[] = lt ku 0:i32[] - kw:i32[] = add ku 3:i32[] - kx:i32[] = select_n kv ku kw - ky:bool[] = lt jx 0:i32[] - kz:i32[] = add jx 3:i32[] - la:i32[] = select_n ky jx kz - lb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kx - lc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] la - ld:i32[1] = broadcast_in_dim lb - le:i32[1] = broadcast_in_dim lc - lf:i32[2] = concatenate[dimension=0] ld le - lg:f32[3,3] = scatter[ + ] jd jx jl + jz:i32[] = add jc 1:i32[] + ka:bool[] = lt jz 0:i32[] + kb:i32[] = add jz 3:i32[] + kc:i32[] = select_n ka jz kb + kd:bool[] = lt jc 0:i32[] + ke:i32[] = add jc 3:i32[] + kf:i32[] = select_n kd jc ke + kg:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kc + kh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kf + ki:i32[1] = broadcast_in_dim kg + kj:i32[1] = broadcast_in_dim kh + kk:i32[2] = concatenate[dimension=0] ki kj + kl:f32[3,3] = scatter[ dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1), operand_batching_dims=(), scatter_indices_batching_dims=()) indices_are_sorted=True mode=GatherScatterMode.FILL_OR_DROP unique_indices=True update_consts=() update_jaxpr=None - ] kt lf kg - in (jz, lg) } + ] jy kk jl + in (je, kl) } length=2 linear=(False, False, False) num_carry=2 num_consts=1 reverse=False unroll=1 - ] iu 0:i32[] iy - _:f32[3] lh:f32[3,3] = jit[ + ] hz 0:i32[] id + _:f32[3] km:f32[3,3] = jit[ name=eigh - jaxpr={ lambda ; jv:f32[3,3]. let - li:f32[3,3] = transpose[permutation=(1, 0)] jv - lj:f32[3,3] = add jv li - lk:f32[3,3] = div lj 2.0:f32[] - lh:f32[3,3] _:f32[3] = eigh[ + jaxpr={ lambda ; ja:f32[3,3]. let + kn:f32[3,3] = transpose[permutation=(1, 0)] ja + ko:f32[3,3] = add ja kn + kp:f32[3,3] = div ko 2.0:f32[] + km:f32[3,3] _:f32[3] = eigh[ algorithm=None lower=True sort_eigenvalues=True subset_by_index=None - ] lk - in (_, lh) } - ] jv - ll:f32[3,1] = slice[limit_indices=(3, 1) start_indices=(0, 0) strides=None] lh - lm:f32[3] = squeeze[dimensions=(1,)] ll - ln:bool[] = lt bq 0:i32[] - lo:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq - lp:i32[] = add lo 4:i32[] - lq:i32[] = select_n ln bq lp - lr:f32[1] = dynamic_slice[slice_sizes=(1,)] bt lq - ls:f32[] = squeeze[dimensions=(0,)] lr - lt:i32[] = sub bq 1:i32[] - lu:bool[] = lt lt 0:i32[] - lv:i32[] = convert_element_type[new_dtype=int32 weak_type=False] lt - lw:i32[] = add lv 3:i32[] - lx:i32[] = select_n lu lt lw - ly:f32[1] = dynamic_slice[slice_sizes=(1,)] lm lx - lz:f32[] = squeeze[dimensions=(0,)] ly - ma:f32[] = abs lz - mb:f32[] = mul ls ma - _:bool[] = lt mb 9.999999974752427e-07:f32[] - mc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bq - md:bool[3] = lt bo mc - me:f32[3] = jit[ + ] kp + in (_, km) } + ] ja + kq:f32[3,1] = slice[limit_indices=(3, 1) start_indices=(0, 0) strides=None] km + kr:f32[3] = squeeze[dimensions=(1,)] kq + ks:bool[] = lt bp 0:i32[] + kt:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bp + ku:i32[] = add kt 4:i32[] + kv:i32[] = select_n ks bp ku + kw:f32[1] = dynamic_slice[slice_sizes=(1,)] bs kv + kx:f32[] = squeeze[dimensions=(0,)] kw + ky:i32[] = sub bp 1:i32[] + kz:bool[] = lt ky 0:i32[] + la:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ky + lb:i32[] = add la 3:i32[] + lc:i32[] = select_n kz ky lb + ld:f32[1] = dynamic_slice[slice_sizes=(1,)] kr lc + le:f32[] = squeeze[dimensions=(0,)] ld + lf:f32[] = abs le + lg:f32[] = mul kx lf + _:bool[] = lt lg 9.999999974752427e-07:f32[] + lh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bp + li:bool[3] = lt g lh + lj:f32[3] = jit[ name=_where - jaxpr={ lambda ; md:bool[3] mf:f32[] mg:f32[]. let - mh:f32[3] = broadcast_in_dim mf - mi:f32[3] = broadcast_in_dim mg - me:f32[3] = select_n md mi mh - in (me,) } - ] md 1.0:f32[] 0.0:f32[] - mj:f32[3] = mul lm me - mk:f32[3,2] = slice[limit_indices=(3, 2) start_indices=(0, 0) strides=None] br - ml:f32[2] = dot_general[ + jaxpr={ lambda ; li:bool[3] lk:f32[] ll:f32[]. let + lm:f32[3] = broadcast_in_dim lk + ln:f32[3] = broadcast_in_dim ll + lj:f32[3] = select_n li ln lm + in (lj,) } + ] li 1.0:f32[] 0.0:f32[] + lo:f32[3] = mul kr lj + lp:f32[3,2] = slice[limit_indices=(3, 2) start_indices=(0, 0) strides=None] bq + lq:f32[2] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 - ] mj mk - mm:f32[] = dot_general[ + ] lo lp + lr:f32[] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 - ] ml ml - mn:f32[] = sqrt mm - mo:bool[] = gt mn 9.999999960041972e-13:f32[] - mp:i32[] = convert_element_type[new_dtype=int32 weak_type=False] mo - mq:f32[2] = cond[ + ] lq lq + ls:f32[] = sqrt lr + lt:bool[] = gt ls 9.999999960041972e-13:f32[] + lu:i32[] = convert_element_type[new_dtype=int32 weak_type=False] lt + lv:f32[2] = cond[ branches=( - { lambda ; mr:f32[] ms:f32[2] mt:f32[2] mu:f32[]. let in (mt,) } - { lambda ; mv:f32[] mw:f32[2] mx:f32[2] my:f32[]. let - mz:bool[] = gt mv 0.0:f32[] - na:f32[] = jit[name=_where jaxpr=_where] mz mv 1.0:f32[] - nb:f32[] = div 1.0:f32[] na - nc:f32[] = jit[name=_where jaxpr=_where] mz nb 0.0:f32[] - nd:f32[2] = mul nc mw - in (nd,) } + { lambda ; lw:f32[] lx:f32[2] ly:f32[2] lz:f32[]. let in (ly,) } + { lambda ; ma:f32[] mb:f32[2] mc:f32[2] md:f32[]. let + me:bool[] = gt ma 0.0:f32[] + mf:f32[] = jit[name=_where jaxpr=_where] me ma 1.0:f32[] + mg:f32[] = div 1.0:f32[] mf + mh:f32[] = jit[name=_where jaxpr=_where] me mg 0.0:f32[] + mi:f32[2] = mul mh mb + in (mi,) } ) - ] mp mn ml u 0.0:f32[] - ne:f32[2] = dot_general[ + ] lu ls lq v 0.0:f32[] + mj:f32[2] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 - ] e mq - nf:f32[] = dot_general[ + ] e lv + mk:f32[] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 - ] mq ne - ng:f32[] = dot_general[ + ] lv mj + ml:f32[] = dot_general[ dimension_numbers=(([0], [0]), ([], [])) preferred_element_type=float32 - ] mq mq - nh:f32[] = div nf ng - in (nh,) } \ No newline at end of file + ] lv lv + mm:f32[] = div mk ml + in (mm,) } \ No newline at end of file From ce375f4f76f6994ba3f2c3eef67d4dd08dfa6d33 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 12:36:00 -0300 Subject: [PATCH 30/44] Pin Lanczos weighted-space slow path --- tests/linalg/test_krylov.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/linalg/test_krylov.py b/tests/linalg/test_krylov.py index a6fd825..9b663ba 100644 --- a/tests/linalg/test_krylov.py +++ b/tests/linalg/test_krylov.py @@ -365,6 +365,8 @@ def _convert(self, new_ctx): return WeightedVectorSpace(new_ctx.asarray(self.weights), new_ctx) space = WeightedVectorSpace([1.0, 4.0], ctx) + assert type(space) is not sc.VectorSpace + matrix = np.array([[2.0, 1.0], [0.25, 0.75]]) op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) From 3ae3d8569648711829787363448fef853fbc29ff Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 12:36:45 -0300 Subject: [PATCH 31/44] Run JIT audit in CI --- .github/workflows/ci.yml | 1 + scripts/jit_audit.py | 59 ++++++++++++++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ca7ab8a..567cfc1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,6 +25,7 @@ jobs: - run: python -m pip install --upgrade pip - run: pip install -e ".[jax,torch,dev]" - run: pytest --cov=spacecore --cov-report=term-missing --cov-fail-under=70 + - run: python scripts/jit_audit.py --check - run: ruff check . publish: diff --git a/scripts/jit_audit.py b/scripts/jit_audit.py index cd41afd..a822b2c 100644 --- a/scripts/jit_audit.py +++ b/scripts/jit_audit.py @@ -1,5 +1,6 @@ from __future__ import annotations +import argparse from pathlib import Path from typing import Any, Callable @@ -85,11 +86,44 @@ def _audit_solver( } +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Audit JAX trace stability for SpaceCore solvers.") + parser.add_argument( + "--check", + action="store_true", + help="Exit non-zero if any solver retraces on same-shape inputs.", + ) + parser.add_argument( + "--log-compiles", + action="store_true", + help="Enable jax_log_compiles for manual inspection.", + ) + parser.add_argument( + "--write-fixture", + action="store_true", + help="Write the lanczos_smallest JAXPR fixture. Disabled by default in --check mode.", + ) + return parser.parse_args() + + +def _audit_failed(item: dict[str, Any]) -> bool: + if "status" in item: + return True + return ( + item["traces_after_two_same_shape_calls"] != 1 + or item["stable_values_retraced"] + or not item["static_change_retraced"] + or not item["shape_change_retraced"] + ) + + def main() -> None: import jax import spacecore as sc - jax.config.update("jax_log_compiles", True) + args = _parse_args() + if args.log_compiles: + jax.config.update("jax_log_compiles", True) A2 = _spd_operator(2) A3 = _spd_operator(3) @@ -190,16 +224,25 @@ def main() -> None: else: audits.append({"solver": "expm_multiply", "status": "not available before Task 1"}) - FIXTURE.parent.mkdir(parents=True, exist_ok=True) - jaxpr = jax.make_jaxpr( - lambda A, x: sc.lanczos_smallest(A, x, max_iter=3, check_every=1).eigenvalue - )(A2, x2a) - FIXTURE.write_text(str(jaxpr)) - print("JIT audit summary") for item in audits: print(item) - print(f"wrote {FIXTURE.relative_to(ROOT)}") + + if args.write_fixture or not args.check: + FIXTURE.parent.mkdir(parents=True, exist_ok=True) + jaxpr = jax.make_jaxpr( + lambda A, x: sc.lanczos_smallest(A, x, max_iter=3, check_every=1).eigenvalue + )(A2, x2a) + FIXTURE.write_text(str(jaxpr)) + print(f"wrote {FIXTURE.relative_to(ROOT)}") + + if args.check: + failures = [item for item in audits if _audit_failed(item)] + if failures: + print("JIT audit check failed") + for item in failures: + print(item) + raise SystemExit(1) if __name__ == "__main__": From 0586f860958a5962fcf95a79252d268dc0b5c846 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 12:38:41 -0300 Subject: [PATCH 32/44] Document JAX integration guidance Add JAX JIT usage notes for operator algebra and link them from the docs index and README. --- README.md | 2 ++ docs/source/design/index.rst | 1 + docs/source/design/jax_integration.rst | 44 ++++++++++++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 docs/source/design/jax_integration.rst diff --git a/README.md b/README.md index ae4e09a..b445a65 100644 --- a/README.md +++ b/README.md @@ -248,6 +248,8 @@ its [notebook](https://github.com/Pavlo3P/SpaceCore/blob/master/tutorials/6_Regu ## Documentation The hosted documentation is available [here](https://pavlo3p.github.io/SpaceCore/). +JAX integration notes are available +[here](https://pavlo3p.github.io/SpaceCore/design/jax_integration.html). The documentation website is built with Sphinx from `docs/source`. diff --git a/docs/source/design/index.rst b/docs/source/design/index.rst index 390b4ad..f722ba5 100644 --- a/docs/source/design/index.rst +++ b/docs/source/design/index.rst @@ -10,4 +10,5 @@ users should reason about them. conversion_policy dtype_policy checking_policy + jax_integration backend_ops_array_api diff --git a/docs/source/design/jax_integration.rst b/docs/source/design/jax_integration.rst new file mode 100644 index 0000000..ef99186 --- /dev/null +++ b/docs/source/design/jax_integration.rst @@ -0,0 +1,44 @@ +JAX integration +=============== + +JIT usage notes +--------------- + +SpaceCore's numerical kernels are written to run under ``jax.jit`` when values +live in a JAX-backed ``Context``. The object model remains ordinary Python: +spaces, operators, and functionals are assembled before the numerical kernel is +traced, then passed into the jitted function. + +Operator algebra such as ``A @ B`` and ``A + B`` executes Python-level +simplification rules at construction time. For maximum JIT efficiency: + +* construct operator expressions outside the JIT-decorated function; +* pass the assembled operator as an argument to the jitted function; +* avoid calling ``make_sum`` or ``make_composed`` from inside a ``jax.jit`` + body. + +This is a trace-time concern rather than a correctness concern. The algebra is +correct either way, but composing inside ``jax.jit`` means the simplification +runs once per trace. For repeatedly invoked code with stable operator +structure, build the expression once outside the jitted function. + +Example: + +.. code-block:: python + + import jax + import spacecore as sc + + ctx = sc.Context(sc.JaxOps(), dtype="float32") + X = sc.VectorSpace((128,), ctx) + A = build_operator(X) + B = build_preconditioner(X) + + # Build algebra outside the JIT boundary. + system = B.H @ A @ B + 0.01 * sc.IdentityLinOp(X, ctx) + + @jax.jit + def solve(op, rhs): + return sc.cg(op, rhs, maxiter=50).x + + x = solve(system, rhs) From d7188b680ffdaaa1173f8037cb8aa558daf69f94 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 12:46:11 -0300 Subject: [PATCH 33/44] Use checked_method for Space membership checks --- docs/source/design/checking_policy.rst | 12 +++++++++ spacecore/_checks.py | 37 ++++++++++++++++++++------ spacecore/space/_batch.py | 20 +++++--------- spacecore/space/_herm.py | 7 ++--- spacecore/space/_product.py | 20 +++++--------- spacecore/space/_vector.py | 18 +++++-------- tests/context/test_checked_method.py | 30 +++++++++++++++++++++ 7 files changed, 93 insertions(+), 51 deletions(-) diff --git a/docs/source/design/checking_policy.rst b/docs/source/design/checking_policy.rst index 1b4976e..bfb8ca7 100644 --- a/docs/source/design/checking_policy.rst +++ b/docs/source/design/checking_policy.rst @@ -36,6 +36,18 @@ before ``apply`` and ``rapply`` when checking is enabled. For exploratory use, enabled checks produce clearer errors. For tight numerical loops, disabled checks reduce validation overhead. +Implementation convention +------------------------- + +Methods that perform simple membership validation should use +``@checked_method`` rather than inline ``if self._enable_checks`` branches. This +keeps validation policy visible at the method signature and avoids duplicating +the same guard throughout spaces, operators, and functionals. + +Inline ``if self._enable_checks`` blocks are reserved for checks that are not +plain membership checks, such as dense-array assertions, custom output-shape +comparisons, or the implementation of ``_check_member`` itself. + Inferred checking policy ------------------------ diff --git a/spacecore/_checks.py b/spacecore/_checks.py index 6bae056..dbfc965 100644 --- a/spacecore/_checks.py +++ b/spacecore/_checks.py @@ -4,11 +4,26 @@ from typing import Any, Callable +def _as_positions(arg_pos: int | None, arg_positions: int | tuple[int, ...] | None) -> tuple[int, ...]: + if arg_pos is not None and arg_positions is not None: + raise TypeError("Use either arg_pos or arg_positions, not both.") + if arg_positions is None: + return (0,) if arg_pos is None else (arg_pos,) + if isinstance(arg_positions, int): + return (arg_positions,) + return tuple(arg_positions) + + +def _space_target(self: Any, space_name: str) -> Any: + return self if space_name == "self" else getattr(self, space_name) + + def checked_method( *, in_space: str | None = None, out_space: str | None = None, - arg_pos: int = 0, + arg_pos: int | None = None, + arg_positions: int | tuple[int, ...] | None = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ Build a decorator that validates method inputs and outputs against spaces. @@ -17,13 +32,17 @@ def checked_method( ---------- in_space: Name of the attribute on ``self`` containing the input - :class:`~spacecore.space.Space`, or ``None`` to skip input validation. + :class:`~spacecore.space.Space`, ``"self"`` to validate against the + receiver itself, or ``None`` to skip input validation. out_space: Name of the attribute on ``self`` containing the output - :class:`~spacecore.space.Space`, or ``None`` to skip output validation. + :class:`~spacecore.space.Space`, ``"self"`` to validate against the + receiver itself, or ``None`` to skip output validation. arg_pos: - Zero-based position in ``*args`` of the input value that should be - checked against ``in_space``. + Deprecated alias for a single entry in ``arg_positions``. + arg_positions: + Zero-based positions in ``*args`` of input values that should be checked + against ``in_space``. Defaults to ``(0,)``. Returns ------- @@ -32,18 +51,20 @@ def checked_method( ``self._enable_checks`` is true, and otherwise forwards directly to the wrapped method. """ + positions = _as_positions(arg_pos, arg_positions) def decorate(method: Callable[..., Any]) -> Callable[..., Any]: @wraps(method) def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: if self._enable_checks and in_space is not None: - x = args[arg_pos] - getattr(self, in_space)._check_member(x) + check_target = _space_target(self, in_space) + for pos in positions: + check_target._check_member(args[pos]) y = method(self, *args, **kwargs) if self._enable_checks and out_space is not None: - getattr(self, out_space)._check_member(y) + _space_target(self, out_space)._check_member(y) return y diff --git a/spacecore/space/_batch.py b/spacecore/space/_batch.py index b227c0d..12f2971 100644 --- a/spacecore/space/_batch.py +++ b/spacecore/space/_batch.py @@ -6,6 +6,7 @@ from ._base import Space from ._checks import BackendCheck, DTypeCheck, ShapeCheck from ._product import ProductSpace +from .._checks import checked_method from ..backend import Context from ..types import DenseArray @@ -106,25 +107,20 @@ def zeros(self) -> Any: return tuple(space.zeros() for space in self._component_spaces()) return self.ops.zeros(self.shape, dtype=self.dtype) + @checked_method(in_space="self", arg_positions=(0, 1)) def add(self, x: Any, y: Any) -> Any: - if self._enable_checks: - self._check_member(x) - self._check_member(y) if isinstance(self.base, ProductSpace): return tuple(space.add(xi, yi) for space, xi, yi in zip(self._component_spaces(), x, y)) return x + y + @checked_method(in_space="self", arg_positions=(1,)) def scale(self, a: Any, x: Any) -> Any: - if self._enable_checks: - self._check_member(x) if isinstance(self.base, ProductSpace): return tuple(space.scale(a, xi) for space, xi in zip(self._component_spaces(), x)) return a * x + @checked_method(in_space="self", arg_positions=(0, 1)) def inner(self, x: Any, y: Any) -> Any: - if self._enable_checks: - self._check_member(x) - self._check_member(y) if isinstance(self.base, ProductSpace): acc = None for space, xi, yi in zip(self._component_spaces(), x, y): @@ -136,9 +132,8 @@ def inner(self, x: Any, y: Any) -> Any: def eigh(self, x: Any, k: int = None) -> Any: raise TypeError(f"{type(self).__name__}.eigh is not defined for batched spaces.") + @checked_method(in_space="self") def flatten(self, x: Any) -> DenseArray: - if self._enable_checks: - self._check_member(x) if isinstance(self.base, ProductSpace): parts = tuple(space.flatten(xi) for space, xi in zip(self._component_spaces(), x)) return parts[0] if len(parts) == 1 else self.ops.concatenate(parts, axis=0) @@ -168,17 +163,14 @@ def unflatten(self, v: DenseArray) -> Any: return tuple(xs) return self.ops.reshape(vv, self.shape) + @checked_method(in_space="self", out_space="self") def apply(self, x: Any, f: Callable) -> Any: - if self._enable_checks: - self._check_member(x) if isinstance(self.base, ProductSpace): return tuple(space.apply(xi, f) for space, xi in zip(self._component_spaces(), x)) try: y = f(x) except Exception: y = self.ops.vmap(lambda xi: self.base.apply(xi, f))(x) - if self._enable_checks: - self._check_member(y) return y def _convert(self, new_ctx: Context) -> BatchSpace: diff --git a/spacecore/space/_herm.py b/spacecore/space/_herm.py index ef6418b..e8117af 100644 --- a/spacecore/space/_herm.py +++ b/spacecore/space/_herm.py @@ -4,6 +4,7 @@ from ._checks import HermitianCheck, SquareMatrixCheck from ._vector import VectorSpace +from .._checks import checked_method from ..types import DenseArray from ..backend import Context @@ -70,8 +71,8 @@ def symmetrize(self, x: DenseArray) -> DenseArray: """Project onto the Hermitian cone: (X + X^H)/2.""" return (x + x.T.conj()) * 0.5 + @checked_method(in_space="self") def eigh(self, x: DenseArray, k: int = None) -> Tuple[DenseArray, DenseArray]: - self.check_member(x) return self.ops.eigh(x) def unflatten(self, v: DenseArray) -> DenseArray: @@ -79,8 +80,8 @@ def unflatten(self, v: DenseArray) -> DenseArray: X = vv.reshape(self.shape) return self.symmetrize(X) + @checked_method(in_space="self") def psd_proj(self, x: DenseArray) -> DenseArray: - self.check_member(x) evals, evecs = self.ops.eigh(x) evals = self.ops.maximum(evals, 0.) return self.eig_to_dense(evals, evecs) @@ -95,6 +96,7 @@ def eig_to_dense(self, evals: DenseArray, evecs: DenseArray) -> DenseArray: def _convert(self, new_ctx: Context) -> HermitianSpace: return HermitianSpace(self.n, self.atol, self.rtol, self.enforce_herm, new_ctx) + @checked_method(in_space="self") def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: r""" Apply a scalar function to a Hermitian matrix via spectral calculus. @@ -149,7 +151,6 @@ def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseAr then the eigenvectors are preserved and only the eigenvalues are transformed. """ - self.check_member(x) evals, evecs = self.ops.eigh(x) fevals = self._apply_entrywise(evals, f) diff --git a/spacecore/space/_product.py b/spacecore/space/_product.py index 4862634..6e52b5e 100644 --- a/spacecore/space/_product.py +++ b/spacecore/space/_product.py @@ -5,6 +5,7 @@ from ._base import Space from ._checks import ProductComponentCheck, ProductStructureCheck from ._vector import VectorSpace +from .._checks import checked_method from ..types import DenseArray from ..backend import Context @@ -113,22 +114,16 @@ def arity(self) -> int: def zeros(self) -> Tuple[Any, ...]: return tuple(s.zeros() for s in self.spaces) + @checked_method(in_space="self", arg_positions=(0, 1)) def add(self, x: Tuple[Any, ...], y: Tuple[Any, ...]) -> Tuple[Any, ...]: - if self._enable_checks: - self._check_member(x) - self._check_member(y) return tuple(s.add(xi, yi) for s, xi, yi in zip(self.spaces, x, y)) + @checked_method(in_space="self", arg_positions=(1,)) def scale(self, a: Any, x: Tuple[Any, ...]) -> Tuple[Any, ...]: - if self._enable_checks: - self._check_member(x) return tuple(s.scale(a, xi) for s, xi in zip(self.spaces, x)) + @checked_method(in_space="self", arg_positions=(0, 1)) def inner(self, x: Tuple[Any, ...], y: Tuple[Any, ...]) -> Any: - if self._enable_checks: - self._check_member(x) - self._check_member(y) - # Accumulate via backend ops (vdot works for scalars too, but sum is enough) acc = None for s, xi, yi in zip(self.spaces, x, y): @@ -142,10 +137,8 @@ def eigh(self, x: Any, k: int = None) -> Any: "Call eigh on a specific component space, or define a custom convention." ) + @checked_method(in_space="self") def flatten(self, x: Tuple[Any, ...]) -> DenseArray: - if self._enable_checks: - self._check_member(x) - if self._vector_fast_path: if self._arity == 1: return x[0] if self._component_is_flat[0] else x[0].reshape((-1,)) @@ -210,6 +203,7 @@ def unflatten(self, v: DenseArray) -> Tuple[Any, ...]: return tuple(xs) + @checked_method(in_space="self", out_space="self") def apply(self, x: Tuple[Any, ...], f: Callable[[Any], Any]) -> Tuple[Any, ...]: r""" Apply a function to each component of a product-space element. @@ -257,8 +251,6 @@ def apply(self, x: Tuple[Any, ...], f: Callable[[Any], Any]) -> Tuple[Any, ...]: product space. It applies the existing functional calculus of each factor space independently, component by component. """ - if self._enable_checks: - self._check_member(x) if self._arity == 2: return self.spaces[0].apply(x[0], f), self.spaces[1].apply(x[1], f) return tuple(s.apply(xi, f) for s, xi in zip(self.spaces, x)) diff --git a/spacecore/space/_vector.py b/spacecore/space/_vector.py index 2f23b50..47f25df 100644 --- a/spacecore/space/_vector.py +++ b/spacecore/space/_vector.py @@ -5,6 +5,7 @@ from ._base import Space from ._checks import BackendCheck, DTypeCheck, ShapeCheck +from .._checks import checked_method from ..types import DenseArray from ..backend import Context @@ -33,21 +34,16 @@ def _local_checks(self): def zeros(self) -> DenseArray: return self.ops.zeros(self.shape, dtype=self.dtype) + @checked_method(in_space="self", arg_positions=(0, 1)) def add(self, x: Any, y: Any) -> DenseArray: - if self._enable_checks: - self._check_member(x) - self._check_member(y) return x + y + @checked_method(in_space="self", arg_positions=(1,)) def scale(self, a: Any, x: Any) -> DenseArray: - if self._enable_checks: - self._check_member(x) return a * x + @checked_method(in_space="self", arg_positions=(0, 1)) def inner(self, x: Any, y: Any) -> Any: - if self._enable_checks: - self._check_member(x) - self._check_member(y) return self.ops.vdot(x, y) def eigh(self, x: Any, k: int = None) -> Any: @@ -55,9 +51,8 @@ def eigh(self, x: Any, k: int = None) -> Any: f"{type(self).__name__}.eigh is not defined for vector spaces." ) + @checked_method(in_space="self") def flatten(self, X: DenseArray) -> DenseArray: - if self._enable_checks: - self._check_member(X) return X if self._is_flat_shape else X.reshape((-1,)) def unflatten(self, v: DenseArray) -> DenseArray: @@ -78,6 +73,7 @@ def _apply_entrywise(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) raise ValueError("Function application changed shape.") return y + @checked_method(in_space="self", out_space="self") def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: r""" Apply a scalar function to a vector-space element entrywise. @@ -121,7 +117,5 @@ def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseAr application is performed entrywise in the distinguished coordinate representation. """ - if self._enable_checks: - self._check_member(x) y = self._apply_entrywise(x, f) return y diff --git a/tests/context/test_checked_method.py b/tests/context/test_checked_method.py index 1d56ac2..9ac3d12 100644 --- a/tests/context/test_checked_method.py +++ b/tests/context/test_checked_method.py @@ -42,6 +42,17 @@ def value(self, x): def grad(self, x): return self.grad_result + @checked_method(in_space="self", arg_positions=(0, 1)) + def combine(self, x, y): + return "combined" + + @checked_method(in_space="self", arg_pos=0) + def legacy_single_arg(self, x): + return "legacy" + + def _check_member(self, value): + self.space._check_member(value) + def test_checked_method_validates_apply_input_and_output(): demo = _CheckedDemo() @@ -101,3 +112,22 @@ def test_checked_method_preserves_metadata(): assert _CheckedDemo.apply.__name__ == "apply" assert _CheckedDemo.apply.__doc__ == "Apply docstring." assert _CheckedDemo.apply.__wrapped__ is not None + + +def test_checked_method_supports_self_target_and_multiple_input_args(): + demo = _CheckedDemo() + + assert demo.combine("z", "z") == "combined" + assert demo.space.calls == ["z", "z"] + + +def test_checked_method_arg_pos_alias_still_works(): + demo = _CheckedDemo() + + assert demo.legacy_single_arg("z") == "legacy" + assert demo.space.calls == ["z"] + + +def test_checked_method_rejects_arg_pos_and_arg_positions_together(): + with pytest.raises(TypeError, match="arg_pos"): + checked_method(in_space="space", arg_pos=0, arg_positions=(0,)) From 7d543b077942c500acf1c228ebba5947b731c373 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 21:52:56 -0300 Subject: [PATCH 34/44] Add docstring migration tooling baseline --- .github/workflows/ci.yml | 2 + MIGRATION.md | 20 ++++++++ pyproject.toml | 36 ++++++++++++++- scripts/docstring_audit.py | 95 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 MIGRATION.md create mode 100644 scripts/docstring_audit.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 567cfc1..1803dfc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,6 +27,8 @@ jobs: - run: pytest --cov=spacecore --cov-report=term-missing --cov-fail-under=70 - run: python scripts/jit_audit.py --check - run: ruff check . + - run: ruff check --select D spacecore/ || true + - run: python scripts/docstring_audit.py || true publish: if: startsWith(github.ref, 'refs/tags/') diff --git a/MIGRATION.md b/MIGRATION.md new file mode 100644 index 0000000..b083497 --- /dev/null +++ b/MIGRATION.md @@ -0,0 +1,20 @@ +# Docstring migration progress + +Baseline (2026-05-27): + +- Ruff pydocstyle: 60 `D` violations from `ruff check --select D spacecore/`. +- Numpydoc validation: 133 actionable issues from `python scripts/docstring_audit.py` + after the Phase 0 allow-list (`ES01`, `EX01`, `SA01`, `GL08`). +- Numpydoc validation, raw: 306 issues from + `python scripts/docstring_audit.py --include-allowed`. +- Doctest: 0 doctest examples collected from + `pytest --doctest-modules spacecore/ -x` under the initial ignore list. + +Notes: + +- The installed `numpydoc` command validates import paths, not package + directories, so `scripts/docstring_audit.py` records and reports the public + SpaceCore API baseline. +- Ruff docstring rules are run as a separate non-blocking CI warning while the + migration is in progress, so the existing strict `ruff check .` step remains + unchanged. diff --git a/pyproject.toml b/pyproject.toml index a76d717..0f531f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ dev = [ "pytest>=8.0", "pytest-cov>=5", "ruff>=0.6", + "numpydoc>=1.7", ] [tool.setuptools] @@ -72,12 +73,45 @@ where = ["."] include = ["spacecore*"] [tool.pytest.ini_options] -testpaths = ["tests"] +addopts = [ + "--doctest-modules", + "--ignore-glob=spacecore/_contextual/*", + "--ignore-glob=spacecore/backend/cupy/*", + "--ignore-glob=spacecore/backend/torch/*", +] +testpaths = ["spacecore", "tests"] +doctest_optionflags = ["NORMALIZE_WHITESPACE", "ELLIPSIS", "NUMBER"] [tool.ruff] line-length = 100 target-version = "py311" +[tool.ruff.lint] +ignore = [ + "D100", + "D104", + "D203", + "D213", +] + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + [tool.ruff.lint.per-file-ignores] "tutorials/*.ipynb" = ["E741"] "tests/*.py" = ["E701"] +"tests/*" = ["D"] + +[tool.numpydoc_validation] +checks = [ + "all", + "ES01", + "EX01", + "SA01", + "GL08", +] +exclude = [ + '\._', + '\.tests\.', + 'test_', +] diff --git a/scripts/docstring_audit.py b/scripts/docstring_audit.py new file mode 100644 index 0000000..e83916d --- /dev/null +++ b/scripts/docstring_audit.py @@ -0,0 +1,95 @@ +"""Report numpydoc validation issues for SpaceCore's public API.""" + +from __future__ import annotations + +import argparse +import inspect +from collections.abc import Iterable +from dataclasses import dataclass + +from numpydoc.validate import validate + +import spacecore + +ALLOWED_CODES = frozenset({"ES01", "EX01", "SA01", "GL08"}) + + +@dataclass(frozen=True) +class ValidationIssue: + """A single numpydoc validation issue.""" + + target: str + code: str + message: str + + +def _iter_public_targets() -> Iterable[str]: + """Yield public import paths exported by the top-level package.""" + for name in getattr(spacecore, "__all__", ()): + if name.startswith("_"): + continue + target = f"spacecore.{name}" + try: + obj = getattr(spacecore, name) + except AttributeError: + continue + if inspect.ismodule(obj): + continue + yield target + + +def _validate_target(target: str, *, include_allowed: bool) -> list[ValidationIssue]: + """Validate one import path and normalize numpydoc's result shape.""" + try: + result = validate(target) + except Exception as exc: # pragma: no cover - defensive reporting path + return [ValidationIssue(target, "IMPORT", f"{type(exc).__name__}: {exc}")] + + issues = [] + for code, message in result.get("errors", []): + if not include_allowed and code in ALLOWED_CODES: + continue + issues.append(ValidationIssue(target, code, message)) + return issues + + +def collect_issues(*, include_allowed: bool = False) -> list[ValidationIssue]: + """Collect numpydoc issues for exported public symbols.""" + issues: list[ValidationIssue] = [] + for target in sorted(set(_iter_public_targets())): + issues.extend(_validate_target(target, include_allowed=include_allowed)) + return issues + + +def main() -> int: + """Run the audit command.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--check", + action="store_true", + help="exit non-zero when validation issues are present", + ) + parser.add_argument( + "--max-lines", + type=int, + default=200, + help="maximum number of individual issues to print", + ) + parser.add_argument( + "--include-allowed", + action="store_true", + help="include issues allowed during the migration baseline", + ) + args = parser.parse_args() + + issues = collect_issues(include_allowed=args.include_allowed) + for issue in issues[: args.max_lines]: + print(f"{issue.target}:{issue.code}:{issue.message}") + if len(issues) > args.max_lines: + print(f"... {len(issues) - args.max_lines} more issues omitted") + print(f"numpydoc issues: {len(issues)}") + return 1 if args.check and issues else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 5b99e646f858650150c7f102045698fba5e1380b Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 22:14:14 -0300 Subject: [PATCH 35/44] Improve public API docstrings --- spacecore/__init__.py | 2 + spacecore/_checks.py | 15 ++- spacecore/_contextual/_bound.py | 16 +-- spacecore/_contextual/_manager.py | 49 ++++++-- spacecore/_contextual/_state.py | 5 +- spacecore/backend/__init__.py | 2 + spacecore/backend/_context.py | 28 ++++- spacecore/backend/_ops.py | 13 ++- spacecore/backend/cupy/__init__.py | 2 + spacecore/backend/jax/__init__.py | 4 +- spacecore/backend/jax/_ops.py | 11 +- spacecore/backend/jax/_pytree.py | 20 +++- spacecore/backend/numpy/__init__.py | 4 +- spacecore/backend/numpy/_ops.py | 26 ++--- spacecore/backend/torch/__init__.py | 2 + spacecore/backend/torch/_ops.py | 8 +- spacecore/functional/__init__.py | 2 + spacecore/functional/_base.py | 36 ++++-- spacecore/functional/_composed.py | 16 ++- spacecore/functional/_linear.py | 46 ++++++-- spacecore/functional/_quadratic.py | 41 ++++++- spacecore/linalg/__init__.py | 2 + spacecore/linalg/_cg.py | 107 +++++++++++++++-- spacecore/linalg/_expm.py | 77 ++++++++++--- spacecore/linalg/_lanczos.py | 141 +++++++++++++++++++---- spacecore/linalg/_lsqr.py | 110 ++++++++++++++++-- spacecore/linalg/_power.py | 101 +++++++++++++++-- spacecore/linop/__init__.py | 2 + spacecore/linop/_algebra.py | 145 ++++++++++++++++++++++-- spacecore/linop/_base.py | 85 ++++++++++---- spacecore/linop/_dense.py | 66 ++++++++--- spacecore/linop/_diagonal.py | 45 +++++++- spacecore/linop/_sparse.py | 43 +++++-- spacecore/linop/product/__init__.py | 4 +- spacecore/linop/product/_base.py | 22 +++- spacecore/linop/product/_block.py | 28 ++++- spacecore/linop/product/_from_single.py | 31 +++-- spacecore/linop/product/_to_single.py | 31 +++-- spacecore/space/__init__.py | 2 + spacecore/space/_base.py | 52 ++++++--- spacecore/space/_batch.py | 34 +++++- spacecore/space/_checks.py | 83 ++++++++++++++ spacecore/space/_herm.py | 32 +++++- spacecore/space/_product.py | 48 ++++++-- spacecore/space/_vector.py | 50 ++++++-- spacecore/types/__init__.py | 2 + spacecore/types/_array.py | 34 +++++- 47 files changed, 1445 insertions(+), 280 deletions(-) diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 19690b6..f8b9d55 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -1,3 +1,5 @@ +"""Backend-agnostic vector spaces, linear operators, and solvers.""" + from importlib.metadata import version as _version try: diff --git a/spacecore/_checks.py b/spacecore/_checks.py index dbfc965..09ba022 100644 --- a/spacecore/_checks.py +++ b/spacecore/_checks.py @@ -4,7 +4,11 @@ from typing import Any, Callable -def _as_positions(arg_pos: int | None, arg_positions: int | tuple[int, ...] | None) -> tuple[int, ...]: +def _as_positions( + arg_pos: int | None, + arg_positions: int | tuple[int, ...] | None, +) -> tuple[int, ...]: + """Normalize legacy and multi-position argument selectors.""" if arg_pos is not None and arg_positions is not None: raise TypeError("Use either arg_pos or arg_positions, not both.") if arg_positions is None: @@ -15,6 +19,7 @@ def _as_positions(arg_pos: int | None, arg_positions: int | tuple[int, ...] | No def _space_target(self: Any, space_name: str) -> Any: + """Return the space object named by ``space_name``.""" return self if space_name == "self" else getattr(self, space_name) @@ -30,17 +35,17 @@ def checked_method( Parameters ---------- - in_space: + in_space : str or None, optional Name of the attribute on ``self`` containing the input :class:`~spacecore.space.Space`, ``"self"`` to validate against the receiver itself, or ``None`` to skip input validation. - out_space: + out_space : str or None, optional Name of the attribute on ``self`` containing the output :class:`~spacecore.space.Space`, ``"self"`` to validate against the receiver itself, or ``None`` to skip output validation. - arg_pos: + arg_pos : int or None, optional Deprecated alias for a single entry in ``arg_positions``. - arg_positions: + arg_positions : int, tuple of int, or None, optional Zero-based positions in ``*args`` of input values that should be checked against ``in_space``. Defaults to ``(0,)``. diff --git a/spacecore/_contextual/_bound.py b/spacecore/_contextual/_bound.py index cce24e0..3a21bad 100644 --- a/spacecore/_contextual/_bound.py +++ b/spacecore/_contextual/_bound.py @@ -16,9 +16,9 @@ def _same_context_for_conversion(left: Context, right: Context) -> bool: Parameters ---------- - left: + left : Context First context to compare. - right: + right : Context Second context to compare. Returns @@ -46,9 +46,9 @@ def _same_context_for_algebra(left: Context, right: Context) -> bool: Parameters ---------- - left: + left : Context First context to compare. - right: + right : Context Second context to compare. Returns @@ -79,7 +79,7 @@ class ContextBound(ABC): Parameters ---------- - ctx: + ctx : Context, str, or None, optional Context specification passed to :meth:`__init__`. This may be a concrete :class:`~spacecore.backend.Context`, a backend-family string, or ``None`` to use the current default context. @@ -96,7 +96,7 @@ def __init__(self, ctx: Context | str | None = None): Parameters ---------- - ctx: + ctx : Context, str, or None, optional Context specification for the object. This may be a concrete :class:`~spacecore.backend.Context`, a backend-family string, or ``None`` to use the current default context. @@ -168,7 +168,7 @@ def _convert(self, new_ctx: Context) -> Self: Parameters ---------- - new_ctx: + new_ctx : Context Concrete target context in which the subclass should rebuild its owned arrays, spaces, operators, or nested context-bound objects. @@ -185,7 +185,7 @@ def convert(self, new_ctx: Context | BackendFamily | str | None = None) -> Self: Parameters ---------- - new_ctx: + new_ctx : Context, BackendFamily, str, or None, optional Target context specification. ``None`` resolves according to the current conversion policy and default context. diff --git a/spacecore/_contextual/_manager.py b/spacecore/_contextual/_manager.py index efea35f..a9b3232 100644 --- a/spacecore/_contextual/_manager.py +++ b/spacecore/_contextual/_manager.py @@ -14,6 +14,7 @@ def _state(): + """Return the cached contextual singleton.""" global _cached_state if _cached_state is not None: return _cached_state @@ -33,14 +34,14 @@ def set_context( Parameters ---------- - ctx: + ctx : Context, BackendFamily, str, or None, optional Context specification to make default. This may be a concrete :class:`spacecore.backend.Context`, a backend family enum, a backend family string such as ``"numpy"`` or ``"jax"``, or ``None``. - dtype: + dtype : dtype-like, optional Optional dtype used when ``ctx`` is a backend family string or enum. Ignored when ``ctx`` is ``None`` or already a concrete ``Context``. - enable_checks: + enable_checks : bool or None, optional Optional validation flag used when constructing a context from a backend family. Ignored when ``ctx`` is ``None`` or already a concrete ``Context``. @@ -76,10 +77,10 @@ def resolve_context_priority( Parameters ---------- - priority_ctx: + priority_ctx : Context, BackendFamily, str, or None, optional Explicit context supplied by the caller. If this is not ``None``, it wins over every inferred context. - *other_ctx: + *other_ctx : object Objects that may carry a ``ctx`` attribute or be backend-native arrays. These are used for context inference when no explicit context is supplied. @@ -104,7 +105,7 @@ def register_ops(ops: type[BackendOps]) -> type[BackendOps]: Parameters ---------- - ops: + ops : type[BackendOps] Backend operations class to register. It must be a subclass of :class:`spacecore.backend.BackendOps` and define a unique backend family key. @@ -138,14 +139,42 @@ def normalize_context( dtype: Any = None, enable_checks: bool | None = None, ) -> Context: - """Normalize a context specification through the process-wide state.""" + """ + Normalize a context specification through the process-wide state. + + Parameters + ---------- + ctx : Context, BackendFamily, str, or None, optional + Context specification to normalize. + dtype : dtype-like, optional + Optional dtype used when constructing a context from backend family. + enable_checks : bool or None, optional + Optional validation flag. + + Returns + ------- + Context + Normalized context. + """ return _state().normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) def normalize_ops( ops: str | BackendFamily | BackendOps | type[BackendOps] | Context ) -> BackendOps: - """Normalize backend operations through the process-wide state.""" + """ + Normalize backend operations through the process-wide state. + + Parameters + ---------- + ops : str, BackendFamily, BackendOps, type[BackendOps], or Context + Backend operations specification. + + Returns + ------- + BackendOps + Backend operations instance. + """ if isinstance(ops, BackendOps): return ops return _state().get_ops(ops) @@ -165,7 +194,7 @@ def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None: Parameters ---------- - policy: + policy : ContextPolicy, str, or None, optional Conversion policy to use. Accepted values are ``"warning"``, ``"error"``, ``"silent"``, matching :class:`ContextPolicy`, or ``None`` to restore the default policy. @@ -205,7 +234,7 @@ def set_dtype_resolution_policy( Parameters ---------- - policy: + policy : DtypePreservePolicy, str, or None, optional Dtype policy to use. Accepted values are ``"keep_native"`` and ``"convert"``, matching :class:`DtypePreservePolicy`, or ``None`` to restore the default policy. diff --git a/spacecore/_contextual/_state.py b/spacecore/_contextual/_state.py index e3ec419..438ab11 100644 --- a/spacecore/_contextual/_state.py +++ b/spacecore/_contextual/_state.py @@ -24,9 +24,8 @@ class Contextual: - """ - Backend resolver. - """ + """Resolve contexts, backend registrations, and conversion policies.""" + _default_ctx: Context _available_ops: Dict[str, type[BackendOps]] _resolution_policy: ContextPolicy diff --git a/spacecore/backend/__init__.py b/spacecore/backend/__init__.py index 99c97f9..5ccc2e0 100644 --- a/spacecore/backend/__init__.py +++ b/spacecore/backend/__init__.py @@ -1,3 +1,5 @@ +"""Backend contexts and operation implementations.""" + from ._context import Context from ._ops import BackendOps from ._family import BackendFamily diff --git a/spacecore/backend/_context.py b/spacecore/backend/_context.py index 689aa43..4d9bd29 100644 --- a/spacecore/backend/_context.py +++ b/spacecore/backend/_context.py @@ -9,7 +9,7 @@ @dataclass(frozen=True, slots=True) class Context: """ - Backend execution context for SpaceCore objects. + Select backend operations, dtype, and validation policy. A context collects the backend operations object, default dtype, and runtime validation policy used by spaces, linear operators, and context-bound @@ -19,18 +19,27 @@ class Context: Parameters ---------- - ops: + ops : BackendOps Backend operations implementation. This must be an instance of :class:`spacecore.backend.BackendOps`, such as :class:`spacecore.backend.NumpyOps` or :class:`spacecore.backend.JaxOps`. - dtype: + dtype : dtype-like or None, optional Default dtype used by :meth:`asarray` and :meth:`assparse`. The value is normalized through ``ops.sanitize_dtype`` during initialization. - enable_checks: + enable_checks : bool, optional Whether spaces and linear operators using this context should perform membership and compatibility checks before operations. + Attributes + ---------- + ops : BackendOps + Normalized backend operations instance. + dtype : dtype-like + Backend-native dtype used by array constructors. + enable_checks : bool + Runtime validation flag propagated to spaces and operators. + Notes ----- ``Context`` is frozen and slot-based. Methods that convert values return new @@ -38,6 +47,17 @@ class Context: Equality compares backend family and ``enable_checks``. It currently does not compare ``dtype``. + + Examples + -------- + Create a NumPy context and convert a Python list to a backend array. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> x = ctx.asarray([1.0, 2.0]) + >>> x.dtype == np.dtype("float64") + True """ ops: BackendOps diff --git a/spacecore/backend/_ops.py b/spacecore/backend/_ops.py index f5362f7..bfa88ac 100644 --- a/spacecore/backend/_ops.py +++ b/spacecore/backend/_ops.py @@ -601,7 +601,7 @@ def trace(self, x: DenseArray) -> DenseArray: return self.sum(self.diagonal(x)) def argsort(self, x: DenseArray, axis: int = -1) -> DenseArray: - """Indices that sort x (delegates to xp.argsort).""" + """Return indices that sort ``x`` along an axis.""" return self.xp.argsort(x, axis=axis) def sort(self, x: DenseArray, axis: int = -1) -> DenseArray: @@ -609,17 +609,18 @@ def sort(self, x: DenseArray, axis: int = -1) -> DenseArray: return self.xp.sort(x, axis=axis) def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """Indices of minima (delegates to xp.argmin).""" + """Return indices of minima along an axis.""" return self.xp.argmin(x, axis=axis, keepdims=keepdims) def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """Indices of maxima (delegates to xp.argmax).""" + """Return indices of maxima along an axis.""" return self.xp.argmax(x, axis=axis, keepdims=keepdims) def vdot(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Returns sum(conj(x) * y). Matches numpy/jax/torch vdot and Array API - vecdot. DenseLinOp.rapply relies on this for complex inputs. + """Return ``sum(conj(x) * y)`` over flattened inputs. + + Matches NumPy, JAX, and Torch ``vdot`` semantics. ``DenseLinOp.rapply`` + relies on this convention for complex inputs. """ x_flat = self.ravel(x) y_flat = self.ravel(y) diff --git a/spacecore/backend/cupy/__init__.py b/spacecore/backend/cupy/__init__.py index 8908ad9..c32df52 100644 --- a/spacecore/backend/cupy/__init__.py +++ b/spacecore/backend/cupy/__init__.py @@ -1,3 +1,5 @@ +"""CuPy backend implementation.""" + from ._ops import CuPyOps as CuPyOps __all__ = ["CuPyOps"] diff --git a/spacecore/backend/jax/__init__.py b/spacecore/backend/jax/__init__.py index afd0428..b64d6f7 100644 --- a/spacecore/backend/jax/__init__.py +++ b/spacecore/backend/jax/__init__.py @@ -1,2 +1,4 @@ +"""JAX backend implementation and pytree registration helpers.""" + from ._ops import JaxOps as JaxOps -from ._pytree import jax_pytree_class as jax_pytree_class \ No newline at end of file +from ._pytree import jax_pytree_class as jax_pytree_class diff --git a/spacecore/backend/jax/_ops.py b/spacecore/backend/jax/_ops.py index d13f995..000bc6c 100644 --- a/spacecore/backend/jax/_ops.py +++ b/spacecore/backend/jax/_ops.py @@ -24,6 +24,7 @@ class JaxOps(BackendOps): jax.experimental.sparse.BCSR Methods + ------- Most methods mirror the corresponding JAX public API signatures and delegate to `jax.numpy`, `jax.numpy.linalg`, `jax.scipy`, or `jax.experimental.sparse`. Backend-specific behavior, tracing rules, @@ -45,6 +46,7 @@ class JaxOps(BackendOps): through instances as `ops.jsparse`. Notes + ----- Code intended to remain backend-portable should prefer `BackendOps` methods. Direct use of `ops.jax`, `ops.jnp`, or `ops.jsparse` is an explicit JAX-specific escape hatch. @@ -53,6 +55,7 @@ class JaxOps(BackendOps): JAX ignores them. Array-creation routines may expose `device` and `out_sharding` for explicit placement or sharding. """ + import jax import jax.numpy as jnp import jax.experimental.sparse as jsparse @@ -120,7 +123,8 @@ def dense_array(self) -> Type[Any]: """ Dense array type using JAX. - Returns: + Returns + ------- Concrete dense array class accepted by this backend. See: @@ -133,7 +137,8 @@ def sparse_array(self) -> Tuple[Type[Any], ...]: """ Sparse array type tuple using JAX. - Returns: + Returns + ------- Concrete sparse array classes accepted by this backend, or None. See: @@ -279,7 +284,7 @@ def index_set(self, x: DenseArray, index: Index, values: ArrayLike, *, copy: boo return x.at[index].set(values) def ix_(self, *args: Any) -> Any: - """ + r""" Build open mesh index arrays using JAX. Input: diff --git a/spacecore/backend/jax/_pytree.py b/spacecore/backend/jax/_pytree.py index 55f5fa6..df68672 100644 --- a/spacecore/backend/jax/_pytree.py +++ b/spacecore/backend/jax/_pytree.py @@ -1,20 +1,32 @@ from __future__ import annotations + from typing import TypeVar T = TypeVar("T") -def jax_pytree_class(cls: T) -> T: + +def jax_pytree_class(klass: T) -> T: """ Mark a class as a JAX PyTree node, if JAX is available. Safe to import without JAX installed. + + Parameters + ---------- + klass : type + Class implementing JAX pytree methods. + + Returns + ------- + type + Registered class when JAX is available, otherwise ``klass`` unchanged. """ try: from jax import tree_util except Exception: - return cls + return klass try: - tree_util.register_pytree_node_class(cls) + tree_util.register_pytree_node_class(klass) except Exception: pass - return cls + return klass diff --git a/spacecore/backend/numpy/__init__.py b/spacecore/backend/numpy/__init__.py index 9528a76..0bfc215 100644 --- a/spacecore/backend/numpy/__init__.py +++ b/spacecore/backend/numpy/__init__.py @@ -1 +1,3 @@ -from ._ops import NumpyOps as NumpyOps \ No newline at end of file +"""NumPy backend implementation.""" + +from ._ops import NumpyOps as NumpyOps diff --git a/spacecore/backend/numpy/_ops.py b/spacecore/backend/numpy/_ops.py index 9a919bc..efc9929 100644 --- a/spacecore/backend/numpy/_ops.py +++ b/spacecore/backend/numpy/_ops.py @@ -22,6 +22,7 @@ class NumpyOps(BackendOps): scipy.sparse.sparray Methods + ------- Most methods mirror the corresponding NumPy or SciPy signatures and delegate directly to NumPy/SciPy implementations. Backend-specific behavior, dtype promotion, broadcasting, memory layout, and error modes @@ -38,6 +39,7 @@ class NumpyOps(BackendOps): `ops.sp`. Advanced users may use it for SciPy-specific functionality. Notes + ----- Code intended to remain backend-portable should prefer `BackendOps` methods. Direct use of `ops.np` or `ops.sp` is an explicit NumPy/SciPy-specific escape hatch. @@ -46,6 +48,7 @@ class NumpyOps(BackendOps): When supplied, it must be `"cpu"` or `None`; see the corresponding NumPy documentation for each method. """ + import numpy as np import scipy as sp import array_api_compat.numpy as xp @@ -61,7 +64,8 @@ def dense_array(self) -> Type[Any]: """ Dense array type using NumPy. - Returns: + Returns + ------- Concrete dense array class accepted by this backend. See: @@ -74,7 +78,8 @@ def sparse_array(self) -> Tuple[Type[Any], ...]: """ Sparse array type tuple using SciPy. - Returns: + Returns + ------- Concrete sparse array classes accepted by this backend, or None. See: @@ -102,7 +107,6 @@ def sanitize_dtype(self, dtype: DType | None) -> DType: See: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html """ - if dtype is None: return self.np.float64 return self.np.dtype(dtype) @@ -211,7 +215,7 @@ def index_set(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bo return x def ix_(self, *args: Any) -> Any: - """ + r""" Build open mesh index arrays using NumPy. Input: @@ -299,9 +303,7 @@ def _tree_multimap(self, f: Callable[..., Any], *trees: Any) -> Any: return f(*trees) def _tree_take0(self, xs: Any) -> Any: - """ - Grab a representative leaf to infer leading length. - """ + """Grab a representative leaf to infer leading length.""" if isinstance(xs, dict): return self._tree_take0(next(iter(xs.values()))) if isinstance(xs, (tuple, list)): @@ -309,9 +311,7 @@ def _tree_take0(self, xs: Any) -> Any: return xs def _tree_index(self, xs: Any, i: int) -> Any: - """ - Take per-step slice xs[i] along axis=0 for each leaf. - """ + """Take per-step slice ``xs[i]`` along axis 0 for each leaf.""" def _idx(a: Any) -> Any: # If it's an ndarray-like with leading axis, slice it; else treat as scalar leaf. @@ -323,9 +323,9 @@ def _idx(a: Any) -> Any: return self._tree_map(_idx, xs) def _tree_stack(self, ys_list: Sequence[Any]) -> Any: - """ - Stack a list of per-step outputs into a single pytree of arrays - by stacking leaves along axis=0. + """Stack per-step outputs into a single pytree of arrays. + + Leaves are stacked along axis 0. """ if not ys_list: # JAX would return empty stacked outputs when length == 0 diff --git a/spacecore/backend/torch/__init__.py b/spacecore/backend/torch/__init__.py index 6bca321..b395f7b 100644 --- a/spacecore/backend/torch/__init__.py +++ b/spacecore/backend/torch/__init__.py @@ -1,3 +1,5 @@ +"""PyTorch backend implementation.""" + from ._ops import TorchOps diff --git a/spacecore/backend/torch/_ops.py b/spacecore/backend/torch/_ops.py index dabf42d..299648d 100644 --- a/spacecore/backend/torch/_ops.py +++ b/spacecore/backend/torch/_ops.py @@ -22,6 +22,7 @@ class TorchOps(BackendOps): torch.Tensor with a PyTorch sparse layout Methods + ------- Most methods mirror the corresponding PyTorch public API signatures and delegate to ``torch`` or ``torch.linalg``. Backend-specific behavior, dtype promotion, broadcasting, device placement, autograd tracking, and @@ -34,6 +35,7 @@ class TorchOps(BackendOps): portable API does not expose a required PyTorch feature. Notes + ----- Code intended to remain backend-portable should prefer ``BackendOps`` methods. Direct use of ``ops.torch`` is an explicit PyTorch-specific escape hatch. @@ -77,7 +79,8 @@ def dense_array(self) -> Type[Any]: """ Dense array type using PyTorch. - Returns: + Returns + ------- Concrete dense tensor class accepted by this backend. See: @@ -90,7 +93,8 @@ def sparse_array(self) -> Tuple[Type[Any], ...]: """ Sparse array type tuple using PyTorch. - Returns: + Returns + ------- Tensor class accepted by this backend for sparse tensor layouts. See: diff --git a/spacecore/functional/__init__.py b/spacecore/functional/__init__.py index 8209cf1..ef2022d 100644 --- a/spacecore/functional/__init__.py +++ b/spacecore/functional/__init__.py @@ -1,3 +1,5 @@ +"""Scalar-valued functionals and composition helpers.""" + from ._base import Functional from ._composed import ComposedFunctional, make_functional_composed from ._linear import InnerProductFunctional, LinearFunctional, MatrixFreeLinearFunctional diff --git a/spacecore/functional/_base.py b/spacecore/functional/_base.py index 1ef833e..9cf86ce 100644 --- a/spacecore/functional/_base.py +++ b/spacecore/functional/_base.py @@ -15,13 +15,27 @@ class Functional(ContextBound, Generic[Domain]): - """ + r""" Scalar-valued map on a space. ``Functional`` represents a map ``F : X -> K`` without assuming any storage model. It mirrors the minimal ``LinOp`` contract: the domain is converted into the resolved context, value checks follow ``ctx.enable_checks``, and batched evaluation is implemented by a backend ``vmap`` fallback. + + Parameters + ---------- + dom : Space + Domain space ``X``. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom``. + + Attributes + ---------- + dom : Space + Domain space converted to ``ctx``. + ctx : Context + Resolved backend context. """ def __init__(self, dom: Domain, ctx: Context | str | None = None) -> None: @@ -37,13 +51,7 @@ def domain(self) -> Domain: @abstractmethod def value(self, x: Any) -> Any: - """ - Evaluate this functional at ``x``. - - Contract: - - x is an element of ``self.domain``; - - the return value is scalar-like in the functional context. - """ + """Evaluate this functional at an element of ``self.domain``.""" def __call__(self, x: Any) -> Any: """Evaluate this functional at ``x``.""" @@ -55,7 +63,7 @@ def compose(self, A: "LinOp") -> "Functional": Parameters ---------- - A: + A : LinOp Linear operator whose codomain matches this functional's domain. Returns @@ -72,6 +80,7 @@ def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: return self._fallback_vvalue(xs, batch_space) def _infer_batch_shape(self, space: Space, value: Any) -> tuple[int, ...]: + """Infer leading batch dimensions from a value and base space.""" if hasattr(space, "spaces") and isinstance(value, tuple) and value: return self._infer_batch_shape(space.spaces[0], value[0]) shape = tuple(getattr(value, "shape", ())) @@ -91,12 +100,14 @@ def _input_batch_space( value: Any, batch_space: Space | None, ) -> Space: + """Return the batch space used to validate batched inputs.""" if batch_space is not None: return batch_space batch_shape = self._infer_batch_shape(space, value) return space.batch(batch_shape, tuple(range(len(batch_shape)))) def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space: + """Return the batch space corresponding to a batched output.""" batch_shape = getattr(input_batch_space, "batch_shape", None) batch_axes = getattr(input_batch_space, "batch_axes", None) if batch_shape is None or batch_axes is None: @@ -104,6 +115,7 @@ def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space: return space.batch(tuple(batch_shape), tuple(batch_axes)) def _require_leading_batch_axes(self, batch_space: Space) -> tuple[int, ...]: + """Return batch shape or raise when batch axes are not leading.""" batch_shape = tuple(getattr(batch_space, "batch_shape", ())) batch_axes = tuple(getattr(batch_space, "batch_axes", ())) expected_axes = tuple(range(len(batch_shape))) @@ -115,12 +127,14 @@ def _require_leading_batch_axes(self, batch_space: Space) -> tuple[int, ...]: return batch_shape def _vmap_leading(self, fn: Any, batch_ndim: int) -> Any: + """Vectorize ``fn`` over ``batch_ndim`` leading axes.""" mapped = fn for _ in range(batch_ndim): mapped = self.ops.vmap(mapped, in_axes=0, out_axes=0) return mapped def _check_scalar_batch(self, values: Any, batch_shape: tuple[int, ...]) -> None: + """Raise if scalar batch output does not have ``batch_shape``.""" shape = tuple(getattr(values, "shape", ())) if shape != batch_shape: raise ValueError( @@ -128,6 +142,7 @@ def _check_scalar_batch(self, values: Any, batch_shape: tuple[int, ...]) -> None ) def _fallback_vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + """Evaluate this functional over a leading batch with backend ``vmap``.""" in_space = self._input_batch_space(self.domain, xs, batch_space) batch_shape = self._require_leading_batch_axes(in_space) if self._enable_checks: @@ -138,13 +153,16 @@ def _fallback_vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: return values def assert_domain(self, x: Any) -> None: + """Raise if ``x`` is not in the domain.""" self.dom.check_member(x) @abstractmethod def tree_flatten(self): + """Flatten this functional for pytree registration.""" ... @classmethod @abstractmethod def tree_unflatten(cls, aux, children): + """Rebuild this functional from pytree data.""" ... diff --git a/spacecore/functional/_composed.py b/spacecore/functional/_composed.py index d018176..8555b5a 100644 --- a/spacecore/functional/_composed.py +++ b/spacecore/functional/_composed.py @@ -11,6 +11,7 @@ def _require_composable(F: Functional, A: LinOp) -> None: + """Raise unless ``F`` can be composed with ``A``.""" if not isinstance(F, Functional): raise TypeError(f"F must be a Functional, got {type(F).__name__}.") if not isinstance(A, LinOp): @@ -28,9 +29,9 @@ def make_functional_composed(F: Functional, A: LinOp) -> Functional: Parameters ---------- - F: + F : Functional Functional defined on ``A.codomain``. - A: + A : LinOp Linear operator whose codomain is ``F.domain``. Returns @@ -55,6 +56,13 @@ class ComposedFunctional(Functional): Generic pull-back of a functional through a linear operator. ``ComposedFunctional(F, A)`` represents ``x -> F(A x)`` on ``A.domain``. + + Parameters + ---------- + F : Functional + Functional defined on ``A.codomain``. + A : LinOp + Linear operator whose codomain is ``F.domain``. """ def __init__(self, F: Functional, A: LinOp) -> None: @@ -81,19 +89,23 @@ def value(self, x: Any) -> Any: return self.F.value(self.A.apply(x)) def __eq__(self, other: Any) -> bool: + """Return whether another composed functional has the same operands.""" if type(other) is type(self): return self.F == other.F and self.A == other.A return False def tree_flatten(self): + """Flatten this functional for pytree registration.""" children = (self.F, self.A) aux = () return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this functional from pytree data.""" F, A = children return cls(F, A) def _convert(self, new_ctx: Context) -> ComposedFunctional: + """Convert the composed functional and operator to ``new_ctx``.""" return ComposedFunctional(self.F.convert(new_ctx), self.A.convert(new_ctx)) diff --git a/spacecore/functional/_linear.py b/spacecore/functional/_linear.py index 487349f..1dbc887 100644 --- a/spacecore/functional/_linear.py +++ b/spacecore/functional/_linear.py @@ -10,6 +10,7 @@ def _convert_space_element(space: Space, value: Any) -> Any: + """Convert a value recursively into a possibly product-valued space.""" if hasattr(space, "spaces") and isinstance(value, tuple): if len(value) != len(space.spaces): raise ValueError( @@ -23,7 +24,16 @@ def _convert_space_element(space: Space, value: Any) -> Any: class LinearFunctional(Functional[Domain]): - """Linear scalar-valued map ``ell : X -> K``.""" + r""" + Represent a linear scalar-valued map. + + Parameters + ---------- + dom : Space + Domain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom``. + """ @property @abstractmethod @@ -38,10 +48,25 @@ def representer(self) -> Any: @jax_pytree_class class InnerProductFunctional(LinearFunctional[Domain]): - """ + r""" Linear functional represented by a domain element. - ``InnerProductFunctional(c, X)`` evaluates ``ell_c(x) = _X``. + ``InnerProductFunctional(c, X)`` evaluates + :math:`\ell_c(x) = \langle c, x\rangle_X`. + + Parameters + ---------- + c : array-like + Riesz representer in ``dom``. + dom : Space + Domain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom``. + + Attributes + ---------- + representer : array-like + Stored domain element ``c``. """ def __init__( @@ -66,6 +91,7 @@ def value(self, x: Any) -> Any: return self.domain.inner(self._c, x) def __eq__(self, other: Any) -> bool: + """Return whether another inner-product functional has the same representer.""" if type(other) is type(self): return self.domain == other.domain and self.ops.allclose( self.domain.flatten(self._c), @@ -74,17 +100,20 @@ def __eq__(self, other: Any) -> bool: return False def tree_flatten(self): + """Flatten this functional for pytree registration.""" children = (self._c,) aux = (self.domain, self.ctx) return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this functional from pytree data.""" domain, ctx = aux c = children[0] return cls(c, domain, ctx) def _convert(self, new_ctx: Context) -> InnerProductFunctional: + """Convert the domain and representer to ``new_ctx``.""" return InnerProductFunctional(self._c, self.domain.convert(new_ctx), new_ctx) @@ -98,15 +127,15 @@ class MatrixFreeLinearFunctional(LinearFunctional[Domain]): Parameters ---------- - value: + value : callable Callable with signature ``value(x: Any) -> Any`` accepting an element of ``dom`` and returning a scalar-like backend value. - dom: + dom : Space Domain space of the functional. - ctx: + ctx : Context, str, or None, optional Optional context specification. An explicit context wins over inferred and default contexts. - vvalue: + vvalue : callable or None, optional Optional callable with signature ``vvalue(xs: Any) -> Any`` for batched evaluation. If omitted, backend ``vmap`` fallback is used. @@ -222,6 +251,7 @@ def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: return values def __eq__(self, other: Any) -> bool: + """Return whether another matrix-free functional uses the same callables.""" if type(other) is type(self): return ( self.domain == other.domain @@ -231,12 +261,14 @@ def __eq__(self, other: Any) -> bool: return False def tree_flatten(self): + """Flatten this functional for pytree registration.""" children = () aux = (self.value_fn, self.domain, self.ctx, self.vvalue_fn) return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this functional from pytree data.""" value_fn, domain, ctx, vvalue_fn = aux return cls(value_fn, domain, ctx, vvalue_fn) diff --git a/spacecore/functional/_quadratic.py b/spacecore/functional/_quadratic.py index 5387f75..2e0d889 100644 --- a/spacecore/functional/_quadratic.py +++ b/spacecore/functional/_quadratic.py @@ -12,7 +12,16 @@ class QuadraticForm(Functional[Domain]): - """Scalar quadratic objective on a space.""" + """ + Represent a scalar quadratic objective on a space. + + Parameters + ---------- + dom : Space + Domain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom``. + """ def hess_apply(self, x: Any) -> Any: """Apply the Hessian action at ``x`` when available.""" @@ -36,8 +45,8 @@ def vgrad(self, xs: Any, batch_space: Space | None = None) -> Any: @jax_pytree_class class LinOpQuadraticForm(QuadraticForm[Domain]): - """ - Quadratic form f(x) = 0.5 * . + r""" + Represent a quadratic form backed by a linear operator. Assumption: Q is Hermitian/self-adjoint. Under this assumption, @@ -50,6 +59,27 @@ class LinOpQuadraticForm(QuadraticForm[Domain]): ``Q : X -> X``. Structurally available dense and diagonal operators are checked at construction. Matrix-free operators are not validated; correctness is the caller's responsibility. + + Parameters + ---------- + Q : LinOp + Hermitian operator from a space to itself. + linear : LinearFunctional or None, optional + Optional linear term on ``Q.domain``. + a : scalar-like, optional + Constant scalar offset. Default is 0. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``Q`` and + ``linear``. + + Attributes + ---------- + Q : LinOp + Stored Hermitian operator. + linear : LinearFunctional or None + Stored linear term. + a : scalar-like + Stored scalar offset. """ def __init__( @@ -84,6 +114,7 @@ def __init__( @staticmethod def _check_hermitian_structure(Q: LinOp[Domain, Domain]) -> None: + """Raise when ``Q`` is structurally known to be non-Hermitian.""" result = Q.is_hermitian() if result is False: raise ValueError("LinOpQuadraticForm requires Q to be Hermitian/self-adjoint.") @@ -116,6 +147,7 @@ def hess_apply(self, x: Any) -> Any: return self.Q.apply(x) def __eq__(self, other: Any) -> bool: + """Return whether another quadratic form has the same stored terms.""" if type(other) is type(self): return ( self.Q == other.Q @@ -125,15 +157,18 @@ def __eq__(self, other: Any) -> bool: return False def tree_flatten(self): + """Flatten this quadratic form for pytree registration.""" children = (self.Q, self.linear, self.a) aux = () return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this quadratic form from pytree data.""" Q, linear, a = children return cls(Q, linear, a, Q.ctx) def _convert(self, new_ctx: Context) -> LinOpQuadraticForm: + """Convert stored terms to ``new_ctx``.""" linear = None if self.linear is None else self.linear.convert(new_ctx) return LinOpQuadraticForm(self.Q.convert(new_ctx), linear, self.a, new_ctx) diff --git a/spacecore/linalg/__init__.py b/spacecore/linalg/__init__.py index f3a4a7e..c80d084 100644 --- a/spacecore/linalg/__init__.py +++ b/spacecore/linalg/__init__.py @@ -1,3 +1,5 @@ +"""Iterative linear algebra solvers and Krylov algorithms.""" + from __future__ import annotations from ._cg import CGResult, cg diff --git a/spacecore/linalg/_cg.py b/spacecore/linalg/_cg.py index a29d157..eb9ac49 100644 --- a/spacecore/linalg/_cg.py +++ b/spacecore/linalg/_cg.py @@ -9,7 +9,20 @@ class CGResult(NamedTuple): - """Result returned by :func:`cg`.""" + """ + Store the result returned by :func:`cg`. + + Parameters + ---------- + x : array-like + Approximate solution in ``A.domain``. + converged : bool-like + Whether the final residual norm satisfied the requested tolerance. + num_iters : int-like + Number of conjugate-gradient iterations executed. + residual_norm : scalar + Norm of the final residual in ``A.codomain``. + """ x: Any converged: Any @@ -39,15 +52,89 @@ def cg( maxiter: int | None = None, check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, ) -> CGResult: - """ - Solve ``A x = b`` by conjugate gradients. - - ``A`` must be a square symmetric/Hermitian positive-definite ``LinOp``. - The implementation uses only ``A.apply`` and the domain-space inner product; - it never materializes a dense matrix. The residual norm is compared with - ``atol + tol * ||b||`` only every ``check_every`` iterations, and always on - the final iteration. This avoids checking the stopping criterion on every - step while remaining compatible with JAX JIT control flow. + r""" + Solve :math:`A x = b` by conjugate gradients. + + Require ``A`` to be a square Hermitian positive-definite :class:`LinOp`. + The implementation uses only :meth:`LinOp.apply` and the domain-space inner + product; it never materializes a dense matrix. + + Parameters + ---------- + A : LinOp + Hermitian positive-definite linear operator. + b : array-like + Right-hand side in ``A.codomain``. + x0 : array-like or None, optional + Initial guess in ``A.domain``. Default is the zero vector. + tol : float, optional + Relative residual tolerance. Default is 1e-6. + atol : float, optional + Absolute residual tolerance. Default is 0.0. + maxiter : int or None, optional + Maximum number of iterations. Default is ``prod(A.domain.shape)``. + check_every : int, optional + Refresh convergence diagnostics every this many iterations and always + on the final iteration. Default is + ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``. + + Returns + ------- + CGResult + Named tuple with fields: + + - ``x``: approximate solution in ``A.domain`` + - ``converged``: whether the requested tolerance was met + - ``num_iters``: number of iterations executed + - ``residual_norm``: final residual norm + + Raises + ------ + TypeError + If ``A`` is not a :class:`LinOp`. + ValueError + If ``A`` is not square or if iteration parameters are invalid. + + See Also + -------- + lsqr : Solve least-squares systems for rectangular operators. + lanczos_smallest : Approximate the smallest eigenpair of a Hermitian + operator. + + Notes + ----- + The residual norm is compared with + :math:`\text{atol} + \text{tol} \| b \|` only every ``check_every`` + iterations, and always on the final iteration. This keeps convergence + checks out of the hot loop while remaining compatible with JAX JIT control + flow. ``maxiter`` and ``check_every`` should be treated as static JAX + arguments. + + Works on real and complex operators. For complex operators, the method uses + the domain inner product convention implemented by ``A.domain``. + + References + ---------- + .. [1] Hestenes, M. R. and Stiefel, E., "Methods of Conjugate Gradients + for Solving Linear Systems," J. Res. Natl. Bur. Stand., 49 (1952), + 409-436. + + Examples + -------- + Solve a small positive-definite system. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((3,), ctx) + >>> M = ctx.asarray([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]]) + >>> A = sc.DenseLinOp(M, X, X, ctx) + >>> b = ctx.asarray([1.0, 2.0, 3.0]) + >>> result = sc.cg(A, b, tol=1e-10) + >>> bool(result.converged) + True + >>> np.allclose(A.apply(result.x), b) + True """ A = require_linop(A) require_square(A, "cg") diff --git a/spacecore/linalg/_expm.py b/spacecore/linalg/_expm.py index 40b68a0..fe60a2a 100644 --- a/spacecore/linalg/_expm.py +++ b/spacecore/linalg/_expm.py @@ -8,19 +8,20 @@ class ExpmMultiplyResult(NamedTuple): - """Result returned by :func:`expm_multiply`. + """ + Store the result returned by :func:`expm_multiply`. - Attributes + Parameters ---------- - result: + result : array-like Vector in the domain of the input operator approximating ``exp(t * A) @ v``. - krylov_dim: + krylov_dim : int-like Actual Krylov dimension reached before breakdown or ``max_iter``. - residual_estimate: + residual_estimate : scalar Projected exponential residual estimate ``abs(beta[m] * phi[m - 1])``. - converged: + converged : bool-like Boolean indicating whether ``residual_estimate < tol``. """ @@ -50,30 +51,74 @@ def expm_multiply( max_iter: int = 30, tol: float = 1e-10, ) -> ExpmMultiplyResult: - """ - Compute ``exp(t * A) @ v`` for a Hermitian operator via Krylov projection. + r""" + Compute :math:`\exp(t A) v` by Krylov projection. + + Require ``A`` to be Hermitian or structurally unknown. The method builds a + Lanczos basis and applies the exponential of the small tridiagonal + projection, avoiding dense materialization of ``A``. Parameters ---------- - A: + A : LinOp Square Hermitian linear operator. - v: + v : array-like Initial vector in ``A.domain``. - t: + t : float or complex, optional Scalar time/scale multiplying ``A``. Complex values are supported for - complex-valued contexts, for example Schrödinger evolution. - max_iter: + complex-valued contexts, for example Schrodinger evolution. Default is + 1.0. + max_iter : int, optional Maximum Krylov dimension. Values around 20-50 are usually sufficient - when ``abs(t) * ||A||`` is moderate. - tol: + when :math:`|t|\|A\|` is moderate. Default is 30. + tol : float, optional Breakdown tolerance for Lanczos and threshold for the projected - exponential residual estimate. + exponential residual estimate. Default is 1e-10. Returns ------- ExpmMultiplyResult Result vector in ``A.domain``, the Krylov dimension used, the standard estimate ``abs(beta[m] * phi[m - 1])``, and a convergence flag. + + Raises + ------ + TypeError + If ``A`` is not a :class:`LinOp`. + ValueError + If ``A`` is not square, is known to be non-Hermitian, or if + ``max_iter`` is invalid. + + See Also + -------- + lanczos_smallest : Build the related Hermitian Krylov projection. + power_iteration : Estimate a dominant eigenpair. + + Notes + ----- + The projected exponential is computed as + :math:`\exp(t T) e_0` using an eigendecomposition of the small real + symmetric tridiagonal matrix ``T``. This is JIT-compatible on the JAX + backend when ``max_iter`` is static. + + The returned residual estimate is + :math:`|\beta_m \phi_{m-1}|`, where ``phi`` is the projected exponential + vector. Callers that need the true residual can perform one additional + operator application. + + Examples + -------- + Apply the exponential of a diagonal operator. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> A = sc.DiagonalLinOp(ctx.asarray([0.0, 1.0]), X, ctx) + >>> v = ctx.asarray([2.0, 3.0]) + >>> result = sc.expm_multiply(A, v, t=0.5, max_iter=2) + >>> np.allclose(result.result, [2.0, 3.0 * np.exp(0.5)]) + True """ A = require_linop(A) require_square(A, "expm_multiply") diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index 1051dbe..0d58c64 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -13,7 +13,22 @@ class LanczosResult(NamedTuple): - """Result returned by :func:`lanczos_smallest`.""" + """ + Store the result returned by :func:`lanczos_smallest`. + + Parameters + ---------- + eigenvalue : scalar + Ritz approximation to the smallest eigenvalue. + eigenvector : array-like + Ritz vector in ``A.domain``. + residual_norm : scalar + Standard Ritz residual estimate. + krylov_dim : int-like + Krylov dimension reached before breakdown or ``max_iter``. + converged : bool-like + Whether ``residual_norm < tol``. + """ eigenvalue: Any eigenvector: Any @@ -39,6 +54,8 @@ def __repr__(self) -> str: class _LanczosBasisResult(NamedTuple): + """Store fixed-size Lanczos basis data and tridiagonal projection.""" + V: DenseArray T: DenseArray alphas: DenseArray @@ -50,6 +67,7 @@ class _LanczosBasisResult(NamedTuple): def _check_lanczos_max_iter(max_iter: int) -> int: + """Validate and normalize the maximum Lanczos iteration count.""" max_iter = int(max_iter) if max_iter < 1: raise ValueError("max_iter must be positive.") @@ -64,6 +82,7 @@ def _build_tridiagonal( m: Any, real_dtype: Any, ) -> DenseArray: + """Build the fixed-size tridiagonal Lanczos projection.""" idx = ops.arange(max_iter) full_indices = ops.arange(max_iter + 1) mask_alpha = idx < m @@ -99,6 +118,7 @@ def _lanczos_basis_and_tridiag( real_dtype: Any, check_every: int, ) -> _LanczosBasisResult: + """Build a Lanczos basis and tridiagonal projection.""" ops = A.ops ctx = A.ctx use_euclidean_reorth = type(A.domain) is VectorSpace @@ -220,7 +240,8 @@ def lanczos_smallest( tol: float = 1e-6, check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, ) -> LanczosResult: - r"""Approximate the smallest eigenpair of a Hermitian operator. + r""" + Approximate the smallest eigenpair of a Hermitian operator. The operator is supplied as a square ``LinOp`` and ``initial_vector`` is an element of ``A.domain``. The implementation keeps fixed-size coordinate @@ -229,27 +250,81 @@ def lanczos_smallest( reconstructed Ritz vector in the original space. Mathematically, Lanczos builds an orthonormal Krylov basis ``V`` for - ``span{v, T v, T^2 v, ...}`` and a tridiagonal projection - :math:`T_k = V^\dagger T V`. The returned vector is the Ritz vector - reconstructed in the original coordinates, and the returned scalar is the - Rayleigh quotient - :math:`(x^\dagger T x) / (x^\dagger x)`. - - Args: - A: Square Hermitian linear operator. - initial_vector: Starting vector in ``A.domain``. - max_iter: Maximum number of Lanczos steps. - tol: Breakdown tolerance for the off-diagonal Lanczos coefficient. - check_every: Refresh the breakdown-based stopping decision only every - this many iterations, and always on the final iteration. - - Returns: - ``LanczosResult`` containing the smallest approximated eigenpair, the - standard Ritz residual estimate ``beta[m] * abs(y[m - 1])``, the - Krylov dimension reached, and a convergence flag. The residual estimate - is computed from the tridiagonal recurrence; callers that need the true - residual can evaluate ``A.apply(eigenvector) - eigenvalue * eigenvector`` - once more in the original space. + ``span{v, A v, A^2 v, ...}`` and a tridiagonal projection + :math:`T_k = V^* A V`. The returned vector is the Ritz vector reconstructed + in the original coordinates, and the returned scalar is the Rayleigh + quotient :math:`\langle x, A x \rangle_X / \langle x, x \rangle_X`. + + Parameters + ---------- + A : LinOp + Square Hermitian linear operator. + initial_vector : array-like + Starting vector in ``A.domain``. If it is numerically zero, the + algorithm falls back to a deterministic coordinate vector. + max_iter : int, optional + Maximum Krylov dimension. Default is 100. + tol : float, optional + Breakdown tolerance for the off-diagonal Lanczos coefficient. Default + is 1e-6. + check_every : int, optional + Refresh the breakdown-based stopping decision every this many + iterations and always on the final iteration. Default is + ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``. + + Returns + ------- + LanczosResult + Named tuple with fields: + + - ``eigenvalue``: smallest Ritz eigenvalue estimate + - ``eigenvector``: associated Ritz vector in ``A.domain`` + - ``residual_norm``: standard Ritz residual estimate + - ``krylov_dim``: actual Krylov dimension reached + - ``converged``: whether ``residual_norm < tol`` + + Raises + ------ + TypeError + If ``A`` is not a :class:`LinOp`. + ValueError + If ``A`` is not square or if ``max_iter`` is invalid. + + See Also + -------- + power_iteration : Estimate the dominant eigenpair. + expm_multiply : Apply a matrix exponential using the Lanczos basis. + + Notes + ----- + The residual estimate is computed from the tridiagonal recurrence as + :math:`\beta_m |y_{m-1}|`. Callers that need the true residual can evaluate + ``A.apply(eigenvector) - eigenvalue * eigenvector`` once more in the + original space. + + This function is JIT-compatible on the JAX backend when ``max_iter`` and + ``check_every`` are static arguments. For plain :class:`VectorSpace` + domains, Euclidean reorthogonalization is vectorized; custom spaces use + :meth:`Space.inner` to preserve the declared geometry. + + References + ---------- + .. [1] Lanczos, C., "An Iteration Method for the Solution of the Eigenvalue + Problem of Linear Differential and Integral Operators," J. Res. Natl. + Bur. Stand., 45 (1950), 255-282. + + Examples + -------- + Approximate the smallest eigenpair of a diagonal operator. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((3,), ctx) + >>> A = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 4.0]), X, ctx) + >>> result = sc.lanczos_smallest(A, ctx.asarray([1.0, 1.0, 1.0]), max_iter=3) + >>> np.allclose(result.eigenvalue, 1.0) + True """ A = require_linop(A) require_square(A, "lanczos_smallest") @@ -309,12 +384,30 @@ def stochastic_lanczos( check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, ) -> LanczosResult: """ - Deprecated alias for :func:`lanczos_smallest`. + Call :func:`lanczos_smallest` through a deprecated alias. + + Parameters + ---------- + A : LinOp + Square Hermitian linear operator. + initial_vector : array-like + Starting vector in ``A.domain``. + max_iter : int, optional + Maximum Krylov dimension. Default is 100. + tol : float, optional + Breakdown tolerance. Default is 1e-6. + check_every : int, optional + Iteration interval for convergence checks. Returns ------- LanczosResult Result from :func:`lanczos_smallest`. + + Warns + ----- + DeprecationWarning + Always emitted because this alias will be removed in a future release. """ warn( "stochastic_lanczos is deprecated; use lanczos_smallest instead.", diff --git a/spacecore/linalg/_lsqr.py b/spacecore/linalg/_lsqr.py index cdd73c0..8907158 100644 --- a/spacecore/linalg/_lsqr.py +++ b/spacecore/linalg/_lsqr.py @@ -9,7 +9,22 @@ class LSQRResult(NamedTuple): - """Result returned by :func:`lsqr`.""" + """ + Store the result returned by :func:`lsqr`. + + Parameters + ---------- + x : array-like + Approximate least-squares solution in ``A.domain``. + converged : bool-like + Whether the normal-equation residual satisfied the requested tolerance. + num_iters : int-like + Number of LSQR iterations executed. + residual_norm : scalar + Norm of ``A x - b`` in ``A.codomain``. + normal_residual_norm : scalar + Norm of ``A.H @ (A x - b)`` in ``A.domain``. + """ x: Any converged: Any @@ -41,16 +56,89 @@ def lsqr( maxiter: int | None = None, check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, ) -> LSQRResult: - """ - Solve ``min_x ||A x - b||_2`` by the LSQR Krylov iteration. - - The operator may be rectangular or square. The method uses ``A.apply`` for - forward products and ``A.H.apply`` for adjoint products, so the normal - equations are represented implicitly and no dense matrix is formed. - Convergence is tested against ``atol + tol * ||b||`` using - ``||A.H @ (A x - b)||``. That normal-equation residual is refreshed only - every ``check_every`` iterations, and always on the final iteration, so the - expensive stopping diagnostic is not evaluated on every Krylov step. + r""" + Solve :math:`\min_x \|A x - b\|` by LSQR. + + Allow ``A`` to be rectangular or square. The method uses + :meth:`LinOp.apply` for forward products and ``A.H.apply`` for adjoint + products, so the normal equations are represented implicitly and no dense + matrix is formed. + + Parameters + ---------- + A : LinOp + Linear operator defining the least-squares problem. + b : array-like + Right-hand side in ``A.codomain``. + x0 : array-like or None, optional + Initial guess in ``A.domain``. Default is the zero vector. + tol : float, optional + Relative tolerance for the normal-equation residual. Default is 1e-6. + atol : float, optional + Absolute tolerance for the normal-equation residual. Default is 0.0. + maxiter : int or None, optional + Maximum number of iterations. Default is ``prod(A.domain.shape)``. + check_every : int, optional + Refresh residual diagnostics every this many iterations and always on + the final iteration. Default is + ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``. + + Returns + ------- + LSQRResult + Named tuple with fields: + + - ``x``: approximate least-squares solution in ``A.domain`` + - ``converged``: whether the requested tolerance was met + - ``num_iters``: number of iterations executed + - ``residual_norm``: final residual norm + - ``normal_residual_norm``: final normal-equation residual norm + + Raises + ------ + TypeError + If ``A`` is not a :class:`LinOp`. + ValueError + If iteration parameters are invalid. + + See Also + -------- + cg : Solve square Hermitian positive-definite systems. + power_iteration : Estimate a dominant eigenpair. + + Notes + ----- + Convergence is tested using + :math:`\|A^*(A x - b)\| < \text{atol} + \text{tol}\|b\|`. + The normal-equation residual is refreshed only every ``check_every`` + iterations, and always on the final iteration. This function is + JIT-compatible on the JAX backend when ``maxiter`` and ``check_every`` are + static arguments. + + Works on real and complex operators. For complex operators, ``A.H`` uses + the conjugate adjoint. + + References + ---------- + .. [1] Paige, C. C. and Saunders, M. A., "LSQR: An Algorithm for Sparse + Linear Equations and Sparse Least Squares," ACM Trans. Math. Soft., + 8 (1982), 43-71. + + Examples + -------- + Solve a small overdetermined least-squares problem. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> Y = sc.VectorSpace((3,), ctx) + >>> M = ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + >>> A = sc.DenseLinOp(M, X, Y, ctx) + >>> b = ctx.asarray([1.0, 2.0, 3.0]) + >>> result = sc.lsqr(A, b, tol=1e-10) + >>> np.allclose(result.x, [1.0, 2.0]) + True """ A = require_linop(A) A.codomain.check_member(b) diff --git a/spacecore/linalg/_power.py b/spacecore/linalg/_power.py index 2d40f29..2c50588 100644 --- a/spacecore/linalg/_power.py +++ b/spacecore/linalg/_power.py @@ -13,7 +13,22 @@ class PowerIterationResult(NamedTuple): - """Result returned by :func:`power_iteration`.""" + """ + Store the result returned by :func:`power_iteration`. + + Parameters + ---------- + eigenvalue : scalar + Rayleigh-quotient estimate of the dominant eigenvalue. + eigenvector : array-like + Normalized eigenvector estimate in the operator domain. + converged : bool-like + Whether the residual norm satisfied ``tol``. + num_iters : int-like + Number of power iterations executed. + residual_norm : scalar + Norm of ``A x - eigenvalue * x``. + """ eigenvalue: Any eigenvector: Any @@ -36,26 +51,32 @@ def __repr__(self) -> str: class _SelfAdjointAction(NamedTuple): + """Store the callable action used by power iteration.""" + apply: Callable[[Any], Any] domain: Space ctx: Context @property def ops(self) -> Any: + """Backend operations for this action.""" return self.ctx.ops @property def dtype(self) -> Any: + """Default dtype for this action.""" return self.ctx.dtype def _action_from_linop(A: LinOp) -> _SelfAdjointAction: + """Normalize a square linear operator into a self-adjoint action.""" A = require_linop(A) require_square(A, "power_iteration") return _SelfAdjointAction(A.apply, A.domain, A.ctx) def _action_from_quadratic_form(q: QuadraticForm) -> _SelfAdjointAction: + """Normalize a quadratic form into its Hessian action.""" return _SelfAdjointAction(q.hess_apply, q.domain, q.ctx) @@ -67,17 +88,76 @@ def power_iteration( maxiter: int | None = None, check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, ) -> PowerIterationResult: - """ - Estimate the dominant eigenpair of a square ``LinOp`` or Hessian action. + r""" + Estimate the dominant eigenpair of a self-adjoint action. - ``A`` may be a square ``LinOp`` or a ``QuadraticForm`` that exposes + Accept a square :class:`LinOp` or a :class:`QuadraticForm` exposing ``hess_apply``. Public dispatch converts either input into a fixed - self-adjoint action before entering the numerical loop. The method returns - the Rayleigh quotient for the current normalized iterate, the eigenvector - estimate, and the residual norm ``||A x - lambda x||``. The residual-based - stopping criterion is refreshed only every ``check_every`` iterations, and - always on the final iteration. For spectral-norm estimates of a rectangular - operator, call this on ``A.H @ A``. + self-adjoint action before entering the numerical loop. + + Parameters + ---------- + A : LinOp or QuadraticForm + Square operator or quadratic form whose dominant eigenpair is sought. + For spectral-norm estimates of a rectangular operator, pass + ``A.H @ A``. + x0 : array-like or None, optional + Initial vector in the action domain. Default is a normalized all-ones + vector in the domain geometry. + tol : float, optional + Residual-norm tolerance. Default is 1e-6. + maxiter : int or None, optional + Maximum number of iterations. Default is ``prod(A.domain.shape)``. + check_every : int, optional + Refresh residual diagnostics every this many iterations and always on + the final iteration. Default is + ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``. + + Returns + ------- + PowerIterationResult + Named tuple with fields: + + - ``eigenvalue``: Rayleigh-quotient eigenvalue estimate + - ``eigenvector``: normalized eigenvector estimate + - ``converged``: whether ``residual_norm < tol`` + - ``num_iters``: number of iterations executed + - ``residual_norm``: norm of ``A x - eigenvalue * x`` + + Raises + ------ + TypeError + If ``A`` is neither a :class:`LinOp` nor a :class:`QuadraticForm`. + ValueError + If a linear-operator input is not square or if iteration parameters are + invalid. + + See Also + -------- + lanczos_smallest : Approximate the smallest eigenpair of a Hermitian + operator. + cg : Solve Hermitian positive-definite systems. + + Notes + ----- + The residual-based stopping criterion uses + :math:`\|A x - \lambda x\|` and is refreshed only every ``check_every`` + iterations, and always on the final iteration. This function is + JIT-compatible on the JAX backend when ``maxiter`` and ``check_every`` are + static arguments. + + Examples + -------- + Estimate the largest eigenvalue of a diagonal operator. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((3,), ctx) + >>> A = sc.DiagonalLinOp(ctx.asarray([1.0, 3.0, 2.0]), X, ctx) + >>> result = sc.power_iteration(A, maxiter=20, tol=1e-10) + >>> np.allclose(result.eigenvalue, 3.0) + True """ if isinstance(A, QuadraticForm): action = _action_from_quadratic_form(A) @@ -101,6 +181,7 @@ def _power_iteration_core( maxiter: int, check_every: int, ) -> tuple[Any, Any, Any, Any, Any]: + """Run the backend-loop implementation of power iteration.""" x, _ = normalize(action.domain, x) zero = action.ops.asarray(0.0, dtype=action.dtype) residual_norm = action.domain.norm(x) + float("inf") diff --git a/spacecore/linop/__init__.py b/spacecore/linop/__init__.py index 7159df4..b26c10b 100644 --- a/spacecore/linop/__init__.py +++ b/spacecore/linop/__init__.py @@ -1,3 +1,5 @@ +"""Linear operator abstractions, concrete operators, and algebra helpers.""" + from ._base import LinOp from ._algebra import ( ComposedLinOp, diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py index 4cf6846..d748880 100644 --- a/spacecore/linop/_algebra.py +++ b/spacecore/linop/_algebra.py @@ -21,6 +21,7 @@ def is_scalar_like(value: Any) -> bool: def _conjugate_scalar(value: Any) -> Any: + """Return the scalar conjugate when the value supports conjugation.""" if hasattr(value, "conjugate"): return value.conjugate() if hasattr(value, "conj"): @@ -29,6 +30,7 @@ def _conjugate_scalar(value: Any) -> Any: def _require_same_context(ops: Sequence[LinOp]) -> Context: + """Return the common context for algebra operands or raise.""" ctx = ops[0].ctx for i, op in enumerate(ops[1:], start=1): if not _same_context_for_algebra(ops[0].ctx, op.ctx): @@ -40,6 +42,7 @@ def _require_same_context(ops: Sequence[LinOp]) -> Context: def _same_space_for_algebra(left: Any, right: Any) -> bool: + """Return whether two spaces are compatible for algebraic composition.""" if type(left) is not type(right): return False if tuple(left.shape) != tuple(right.shape): @@ -56,12 +59,14 @@ def _same_space_for_algebra(left: Any, right: Any) -> bool: def _require_linop(op: Any, name: str) -> LinOp: + """Return ``op`` as a linear operator or raise a typed error.""" if not isinstance(op, LinOp): raise TypeError(f"{name} must be a LinOp, got {type(op).__name__}.") return op def _scalar_equal(value: Any, target: Any) -> bool: + """Return whether two scalar-like values compare equal.""" try: return bool(value == target) except Exception: @@ -69,14 +74,17 @@ def _scalar_equal(value: Any, target: Any) -> bool: def _is_zero_scalar(value: Any) -> bool: + """Return whether ``value`` is scalar-like zero.""" return _scalar_equal(value, 0) def _is_one_scalar(value: Any) -> bool: + """Return whether ``value`` is scalar-like one.""" return _scalar_equal(value, 1) def _flatten_sum_terms(ops: Sequence[LinOp]) -> tuple[LinOp, ...]: + """Flatten nested lazy sums into a tuple of terms.""" terms: list[LinOp] = [] for i, op in enumerate(ops): op = _require_linop(op, f"ops[{i}]") @@ -96,6 +104,16 @@ def make_sum(ops: Sequence[LinOp]) -> LinOp: does not collect like terms, reorder operands, or attempt full symbolic optimization. All operands must have the same context, domain, and codomain before a simplified operator is returned. + + Parameters + ---------- + ops : sequence of LinOp + Nonempty sequence of operators with common domain and codomain. + + Returns + ------- + LinOp + Simplified lazy sum, a single operand, or a zero operator. """ if not ops: raise ValueError("make_sum requires a nonempty sequence of LinOp operands.") @@ -132,6 +150,18 @@ def make_scaled(scalar: Any, op: LinOp) -> LinOp: scalar. It does not distribute scaling over sums or perform full symbolic optimization. Complex scalars retain the usual conjugated coefficient in ``rapply`` through ``ScaledLinOp``. + + Parameters + ---------- + scalar : scalar-like + Scalar coefficient multiplying ``op``. + op : LinOp + Operator to scale. + + Returns + ------- + LinOp + Simplified scalar multiple. """ op = _require_linop(op, "op") if not is_scalar_like(scalar): @@ -158,6 +188,18 @@ def make_composed(left: LinOp, right: LinOp) -> LinOp: multi-factor chains or attempt full symbolic optimization. Operands must have the same context and compatible middle spaces before a simplified operator is returned. + + Parameters + ---------- + left : LinOp + Operator applied second. + right : LinOp + Operator applied first. + + Returns + ------- + LinOp + Simplified lazy composition representing ``left @ right``. """ left = _require_linop(left, "left") right = _require_linop(right, "right") @@ -181,7 +223,7 @@ def make_composed(left: LinOp, right: LinOp) -> LinOp: @jax_pytree_class class ScaledLinOp(LinOp[Domain, Codomain]): - """ + r""" Lazy scalar multiple of a linear operator. ``ScaledLinOp(alpha, A)`` represents the mathematical operator @@ -193,6 +235,20 @@ class ScaledLinOp(LinOp[Domain, Codomain]): ``x in A.domain``. The reverse action is ``rapply(y) = conj(alpha) * A.rapply(y)`` for ``y in A.codomain``, so complex scalars use the conjugated coefficient. + + Parameters + ---------- + scalar : scalar-like + Scalar multiplier. + op : LinOp + Operator being scaled. + + Attributes + ---------- + scalar : scalar-like + Stored scalar multiplier. + op : LinOp + Stored operand. """ def __init__(self, scalar: Any, op: LinOp[Domain, Codomain]) -> None: @@ -222,27 +278,31 @@ def rvapply(self, ys: Any, batch_space=None) -> Any: return _conjugate_scalar(self.scalar) * self.op.rvapply(ys, batch_space) def __eq__(self, other: Any) -> bool: + """Return whether another scaled operator has the same scalar and operand.""" if type(other) is type(self): return self.scalar == other.scalar and self.op == other.op return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" children = (self.scalar, self.op) aux = () return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" scalar, op = children return cls(scalar, op) def _convert(self, new_ctx: Context) -> ScaledLinOp: + """Convert the operand to ``new_ctx`` while preserving the scalar.""" return ScaledLinOp(self.scalar, self.op.convert(new_ctx)) @jax_pytree_class class SumLinOp(LinOp[Domain, Codomain]): - """ + r""" Lazy finite sum of linear operators with common spaces. ``SumLinOp((A1, ..., Ak))`` represents ``A1 + ... + Ak`` for a nonempty @@ -253,6 +313,17 @@ class SumLinOp(LinOp[Domain, Codomain]): The forward action is ``apply(x) = sum_i Ai.apply(x)`` for the shared domain element ``x``. The reverse action is ``rapply(y) = sum_i Ai.rapply(y)`` for the shared codomain element ``y``. + + Parameters + ---------- + ops : sequence of LinOp + Nonempty sequence of operators with common context, domain, and + codomain. + + Attributes + ---------- + parts : tuple of LinOp + Stored operands in the lazy sum. """ def __init__(self, ops: Sequence[LinOp[Domain, Codomain]]) -> None: @@ -315,26 +386,30 @@ def rvapply(self, ys: Any, batch_space=None) -> Any: return acc def __eq__(self, other: Any) -> bool: + """Return whether another sum has the same operands.""" if type(other) is type(self): return self.ops_tuple == other.ops_tuple return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" children = self.ops_tuple aux = () return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" return cls(tuple(children)) def _convert(self, new_ctx: Context) -> SumLinOp: + """Convert all operands to ``new_ctx``.""" return SumLinOp(tuple(op.convert(new_ctx) for op in self.ops_tuple)) @jax_pytree_class class ComposedLinOp(LinOp[Domain, Codomain]): - """ + r""" Lazy composition of two linear operators. ``ComposedLinOp(A, B)`` represents ``A @ B = A circ B``. The operands must @@ -345,6 +420,20 @@ class ComposedLinOp(LinOp[Domain, Codomain]): The forward action is ``apply(x) = A.apply(B.apply(x))`` for ``x in B.domain``. The reverse action is ``rapply(z) = B.rapply(A.rapply(z))`` for ``z in A.codomain``. + + Parameters + ---------- + left : LinOp + Operator applied second. + right : LinOp + Operator applied first. + + Attributes + ---------- + left : LinOp + Left operand. + right : LinOp + Right operand. """ def __init__(self, left: LinOp, right: LinOp) -> None: @@ -383,27 +472,31 @@ def rvapply(self, zs: Any, batch_space=None) -> Any: return self.right.rvapply(self.left.rvapply(zs, in_space), middle) def __eq__(self, other: Any) -> bool: + """Return whether another composition has the same operands.""" if type(other) is type(self): return self.left == other.left and self.right == other.right return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" children = (self.left, self.right) aux = () return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" left, right = children return cls(left, right) def _convert(self, new_ctx: Context) -> ComposedLinOp: + """Convert both operands to ``new_ctx``.""" return ComposedLinOp(self.left.convert(new_ctx), self.right.convert(new_ctx)) @jax_pytree_class class ZeroLinOp(LinOp[Domain, Codomain]): - """ + r""" Lazy zero map between two spaces. ``ZeroLinOp(X, Y)`` represents the linear map ``0 : X -> Y``. The context is @@ -413,6 +506,15 @@ class ZeroLinOp(LinOp[Domain, Codomain]): The forward action is ``apply(x) = 0_Y`` for ``x in X``. The reverse action is ``rapply(y) = 0_X`` for ``y in Y``. + + Parameters + ---------- + dom : Space + Domain space. + cod : Space + Codomain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from the spaces. """ def __init__( @@ -429,6 +531,7 @@ def apply(self, x: Any) -> Any: return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: + """Return the codomain zero without membership checks.""" return self.codomain.zeros() @checked_method(in_space="codomain", out_space="domain") @@ -437,6 +540,7 @@ def rapply(self, y: Any) -> Any: return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: + """Return the domain zero without membership checks.""" return self.domain.zeros() def vapply(self, xs: Any, batch_space=None) -> Any: @@ -473,27 +577,31 @@ def is_hermitian(self) -> bool: return self.domain == self.codomain def __eq__(self, other: Any) -> bool: + """Return whether another zero map has the same spaces.""" if type(other) is type(self): return self.domain == other.domain and self.codomain == other.codomain return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" children = () aux = (self.domain, self.codomain, self.ctx) return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" domain, codomain, ctx = aux return cls(domain, codomain, ctx) def _convert(self, new_ctx: Context) -> ZeroLinOp: + """Convert domain and codomain spaces to ``new_ctx``.""" return ZeroLinOp(self.domain.convert(new_ctx), self.codomain.convert(new_ctx), new_ctx) @jax_pytree_class class IdentityLinOp(LinOp[Domain, Domain]): - """ + r""" Lazy identity map on a space. ``IdentityLinOp(X)`` represents the identity operator ``I_X : X -> X``. The @@ -502,6 +610,13 @@ class IdentityLinOp(LinOp[Domain, Domain]): The forward action is ``apply(x) = x`` for ``x in X``. The reverse action is ``rapply(x) = x`` for ``x in X``. + + Parameters + ---------- + space : Space + Domain and codomain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``space``. """ def __init__(self, space: Domain, ctx: Context | str | None = None) -> None: @@ -513,6 +628,7 @@ def apply(self, x: Any) -> Any: return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: + """Return ``x`` without membership checks.""" return x @checked_method(in_space="codomain", out_space="domain") @@ -521,6 +637,7 @@ def rapply(self, x: Any) -> Any: return self._rapply_unchecked(x) def _rapply_unchecked(self, x: Any) -> Any: + """Return ``x`` without membership checks.""" return x def vapply(self, xs: Any, batch_space=None) -> Any: @@ -561,21 +678,25 @@ def is_hermitian(self) -> bool: return True def __eq__(self, other: Any) -> bool: + """Return whether another identity map has the same space.""" if type(other) is type(self): return self.domain == other.domain return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" children = () aux = (self.domain, self.ctx) return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" domain, ctx = aux return cls(domain, ctx) def _convert(self, new_ctx: Context) -> IdentityLinOp: + """Convert the identity space to ``new_ctx``.""" return IdentityLinOp(self.domain.convert(new_ctx), new_ctx) @@ -596,25 +717,25 @@ class MatrixFreeLinOp(LinOp[Domain, Codomain]): Parameters ---------- - apply: + apply : callable Callable with signature ``apply(x: Any) -> Any`` implementing the forward map from ``dom`` to ``cod``. - rapply: + rapply : callable Callable with signature ``rapply(y: Any) -> Any`` implementing the adjoint map from ``cod`` back to ``dom``. - dom: + dom : Space Domain space containing valid inputs for ``apply`` and outputs from ``rapply``. - cod: + cod : Space Codomain space containing outputs from ``apply`` and valid inputs for ``rapply``. - ctx: + ctx : Context, str, or None, optional Optional context specification. An explicit context wins over inferred contexts from ``dom`` and ``cod``. - vapply: + vapply : callable or None, optional Optional callable with signature ``vapply(xs: Any) -> Any`` for batched forward application. If omitted, backend ``vmap`` fallback is used. - rvapply: + rvapply : callable or None, optional Optional callable with signature ``rvapply(ys: Any) -> Any`` for batched adjoint application. If omitted, backend ``vmap`` fallback is used. diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index 4cff12d..4589e4c 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -14,15 +14,47 @@ Codomain = TypeVar('Codomain', bound=Space) class LinOp(ContextBound, Generic[Domain, Codomain]): - """ - Minimal linear operator (morphism) between two spaces. - - This class is intentionally small. It defines no matrix semantics, - arithmetic, or storage assumptions. - - Its sole purpose is to represent a linear map - ``A : dom -> cod`` - with access to both forward and adjoint actions. + r""" + Represent a linear map between two spaces. + + This class is intentionally small. It defines no storage assumptions and + requires subclasses to provide forward and adjoint actions. + + The adjoint :math:`A^*` satisfies + :math:`\langle A x, y\rangle_Y = \langle x, A^* y\rangle_X` for + :math:`x \in X` and :math:`y \in Y`. For complex operators this is the + conjugate adjoint. + + Parameters + ---------- + dom : Space + Domain space ``X``. + cod : Space + Codomain space ``Y``. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom`` and + ``cod``. + + Attributes + ---------- + dom : Space + Domain space converted to ``ctx``. + cod : Space + Codomain space converted to ``ctx``. + ctx : Context + Resolved backend context. + + Examples + -------- + Use a concrete dense operator as a :class:`LinOp`. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> A = sc.DenseLinOp(ctx.asarray([[1.0, 0.0], [0.0, 2.0]]), X, X, ctx) + >>> A.apply(ctx.asarray([3.0, 4.0])) + array([3., 8.]) """ def __init__(self, dom: Domain, cod: Codomain, ctx: Context | str | None = None): @@ -61,23 +93,11 @@ def A(self) -> Any: @abstractmethod def apply(self, x: Any) -> Any: - """ - Forward application: y = A x - - Contract: - - x is an element of self.dom - - return value is an element of self.cod - """ + """Apply the forward map to an element of ``self.domain``.""" @abstractmethod def rapply(self, y: Any) -> Any: - """ - Adjoint application: x = A^* y - - Contract: - - y is an element of self.cod - - return value is an element of self.dom - """ + """Apply the adjoint map to an element of ``self.codomain``.""" def __call__(self, x: Any) -> Any: """Apply this linear operator to ``x``.""" @@ -108,6 +128,7 @@ def rvapply(self, ys: Any, batch_space: Space | None = None) -> Any: return self._fallback_rvapply(ys, batch_space) def _infer_batch_shape(self, space: Space, value: Any) -> tuple[int, ...]: + """Infer leading batch dimensions from a value and base space.""" if hasattr(space, "spaces") and isinstance(value, tuple) and value: return self._infer_batch_shape(space.spaces[0], value[0]) shape = tuple(getattr(value, "shape", ())) @@ -127,12 +148,14 @@ def _input_batch_space( value: Any, batch_space: Space | None, ) -> Space: + """Return the batch space used to validate batched inputs.""" if batch_space is not None: return batch_space batch_shape = self._infer_batch_shape(space, value) return space.batch(batch_shape, tuple(range(len(batch_shape)))) def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space: + """Return the batch space corresponding to a batched output.""" batch_shape = getattr(input_batch_space, "batch_shape", None) batch_axes = getattr(input_batch_space, "batch_axes", None) if batch_shape is None or batch_axes is None: @@ -140,6 +163,7 @@ def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space: return space.batch(tuple(batch_shape), tuple(batch_axes)) def _fallback_vapply(self, xs: Any, batch_space: Space | None = None) -> Any: + """Apply ``self.apply`` over a leading batch with backend ``vmap``.""" in_space = self._input_batch_space(self.domain, xs, batch_space) if self._enable_checks: in_space._check_member(xs) @@ -149,6 +173,7 @@ def _fallback_vapply(self, xs: Any, batch_space: Space | None = None) -> Any: return ys def _fallback_rvapply(self, ys: Any, batch_space: Space | None = None) -> Any: + """Apply ``self.rapply`` over a leading batch with backend ``vmap``.""" in_space = self._input_batch_space(self.codomain, ys, batch_space) if self._enable_checks: in_space._check_member(ys) @@ -159,7 +184,14 @@ def _fallback_rvapply(self, ys: Any, batch_space: Space | None = None) -> Any: @property def H(self) -> LinOp: - """Hermitian-adjoint view of this linear operator.""" + r"""Hermitian-adjoint view of this linear operator. + + Returns + ------- + LinOp + Adjoint view satisfying + :math:`\langle A x, y\rangle_Y = \langle x, A^* y\rangle_X`. + """ from ._algebra import _AdjointViewLinOp view = getattr(self, "_adjoint_view", None) @@ -260,19 +292,24 @@ def to_dense(self) -> Any: return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape)) def assert_domain(self, x: Any) -> None: + """Raise if ``x`` is not in the domain.""" self.dom.check_member(x) def assert_codomain(self, y: Any) -> None: + """Raise if ``y`` is not in the codomain.""" self.cod.check_member(y) def __eq__(self, other: Any) -> bool: + """Return structural equality when implemented by a subclass.""" return NotImplemented @abstractmethod def tree_flatten(self): + """Flatten this operator for backend pytree registration.""" ... @classmethod @abstractmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from backend pytree data.""" ... diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index b133aed..4d0dbb4 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -14,13 +14,41 @@ @jax_pytree_class class DenseLinOp(LinOp[VectorSpace, VectorSpace]): - """ - Dense linear operator defined by an array A with shape: - - A.shape == cod.shape + dom.shape - - apply: y = A ⋅ x (contract over dom axes) - rapply: x = A^* ⋅ y (contract over cod axes) + r""" + Represent a dense tensor-backed linear operator. + + ``DenseLinOp(A, dom, cod)`` represents a linear map + :math:`A \colon X \to Y` where the stored dense array has shape + ``cod.shape + dom.shape``. Forward application contracts over the domain + axes; adjoint application uses the conjugate transpose of the flattened + matrix representation. + + Parameters + ---------- + A : DenseArray + Dense backend array with shape ``cod.shape + dom.shape``. + dom : Space + Domain space. + cod : Space or None, optional + Codomain space. If omitted, it is inferred from the leading axes of + ``A``. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from the spaces. + + Attributes + ---------- + A : DenseArray + Stored dense operator tensor. + + Examples + -------- + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> A = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 3.0]]), X, X, ctx) + >>> A.apply(ctx.asarray([1.0, 2.0])) + array([2., 6.]) """ def __init__(self, @@ -68,12 +96,11 @@ def A(self) -> DenseArray: @checked_method(in_space="dom", out_space="cod") def apply(self, x: DenseArray) -> DenseArray: - """ - Forward action: y = A ⋅ x with y in cod.shape. - """ + """Apply the dense operator to ``x``.""" return self._apply_unchecked(x) def _apply_unchecked(self, x: DenseArray) -> DenseArray: + """Apply the flattened dense matrix without membership checks.""" x1 = x if self._dom_is_flat else x.reshape((self._dom_size,)) y1 = self._A2 @ x1 if self._cod_vector_fast_path: @@ -82,14 +109,14 @@ def _apply_unchecked(self, x: DenseArray) -> DenseArray: @checked_method(in_space="cod", out_space="dom") def rapply(self, y: DenseArray) -> DenseArray: - """ - Adjoint action: x = A^* ⋅ y with x in dom.shape. + r"""Apply the adjoint dense operator to ``y``. - For complex A, uses conjugate-transpose of the 2D reshaped matrix. + For complex ``A``, use the conjugate transpose of the flattened matrix. """ return self._rapply_unchecked(y) def _rapply_unchecked(self, y: DenseArray) -> DenseArray: + """Apply the flattened adjoint matrix without membership checks.""" y1 = y if self._cod_is_flat else y.reshape((self._cod_size,)) x1 = self._A2H @ y1 if self._dom_vector_fast_path: @@ -98,11 +125,13 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: @staticmethod def _batch_shape_from_input(value: DenseArray, base_ndim: int) -> tuple[int, ...]: + """Infer leading batch dimensions from an input array.""" shape = tuple(value.shape) return shape if base_ndim == 0 else shape[:-base_ndim] @staticmethod def _is_leading_batch(batch_space: Any) -> bool: + """Return whether a batch space uses leading batch axes.""" if batch_space is None: return True batch_shape = tuple(getattr(batch_space, "batch_shape", ())) @@ -111,6 +140,7 @@ def _is_leading_batch(batch_space: Any) -> bool: @staticmethod def _batch_shape_from_space(batch_space: Any) -> tuple[int, ...]: + """Return the explicit batch shape from a batch-space object.""" return tuple(getattr(batch_space, "batch_shape")) def _vapply_unchecked_leading( @@ -118,6 +148,7 @@ def _vapply_unchecked_leading( xs: DenseArray, batch_shape: tuple[int, ...], ) -> DenseArray: + """Apply the dense operator over leading batch axes.""" xs2 = xs.reshape((-1, self._dom_size)) ys2 = xs2 @ self._A2T if self._cod_vector_fast_path: @@ -132,6 +163,7 @@ def _rvapply_unchecked_leading( ys: DenseArray, batch_shape: tuple[int, ...], ) -> DenseArray: + """Apply the dense adjoint over leading batch axes.""" ys2 = ys.reshape((-1, self._cod_size)) xs2 = ys2 @ self._A2H.T if self._dom_vector_fast_path: @@ -142,6 +174,7 @@ def _rvapply_unchecked_leading( return self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) def _vapply_unchecked(self, xs: DenseArray, batch_space=None) -> DenseArray: + """Apply a batch using the fast leading-axis path when possible.""" if not self._is_leading_batch(batch_space): return self._fallback_vapply(xs, batch_space) batch_shape = ( @@ -152,6 +185,7 @@ def _vapply_unchecked(self, xs: DenseArray, batch_space=None) -> DenseArray: return self._vapply_unchecked_leading(xs, batch_shape) def _rvapply_unchecked(self, ys: DenseArray, batch_space=None) -> DenseArray: + """Apply adjoints over a batch using the fast leading-axis path when possible.""" if not self._is_leading_batch(batch_space): return self._fallback_rvapply(ys, batch_space) batch_shape = ( @@ -162,6 +196,7 @@ def _rvapply_unchecked(self, ys: DenseArray, batch_space=None) -> DenseArray: return self._rvapply_unchecked_leading(ys, batch_shape) def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: + """Apply this operator independently over a batch of domain elements.""" in_space = self._input_batch_space(self.domain, xs, batch_space) if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): return self._fallback_vapply(xs, batch_space) @@ -174,6 +209,7 @@ def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: return ys def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: + """Apply the adjoint independently over a batch of codomain elements.""" in_space = self._input_batch_space(self.codomain, ys, batch_space) if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): return self._fallback_rvapply(ys, batch_space) @@ -211,6 +247,7 @@ def is_hermitian(self) -> bool | None: return None def __eq__(self, x: Any) -> bool: + """Return whether another dense operator has the same spaces and values.""" if type(x) is type(self): return (self.dom == x.dom and self.cod == x.cod @@ -219,17 +256,20 @@ def __eq__(self, x: Any) -> bool: return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" aux = (self.dom, self.cod, self.ctx) children = (self.A,) return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" dom, cod, ctx = aux A = children[0] return cls(A, dom, cod, ctx) def _convert(self, new_ctx: Context) -> DenseLinOp: + """Convert spaces and stored dense tensor to ``new_ctx``.""" new_dom = self.dom.convert(new_ctx) new_cod = self.cod.convert(new_ctx) new_A = new_ctx.asarray(self.A) diff --git a/spacecore/linop/_diagonal.py b/spacecore/linop/_diagonal.py index 389c9cb..b5c45f2 100644 --- a/spacecore/linop/_diagonal.py +++ b/spacecore/linop/_diagonal.py @@ -14,7 +14,39 @@ @jax_pytree_class class DiagonalLinOp(LinOp[VectorSpace, VectorSpace]): - """Coordinatewise diagonal linear operator on a vector space.""" + r""" + Represent a coordinatewise diagonal linear operator. + + ``DiagonalLinOp(diagonal, space)`` maps ``x`` to ``diagonal * x`` on a + :class:`VectorSpace`. The adjoint uses the complex conjugate of the + diagonal, so complex-valued diagonals follow the SpaceCore adjoint + convention. + + Parameters + ---------- + diagonal : DenseArray + Dense backend array with shape ``space.shape``. + space : VectorSpace or None, optional + Domain and codomain space. If omitted, a vector space is inferred from + ``diagonal.shape``. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``space``. + + Attributes + ---------- + diagonal : DenseArray + Stored diagonal values. + + Examples + -------- + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> D = sc.DiagonalLinOp(ctx.asarray([2.0, 3.0]), X, ctx) + >>> D.apply(ctx.asarray([4.0, 5.0])) + array([ 8., 15.]) + """ def __init__( self, @@ -38,17 +70,21 @@ def __init__( @cached_property def A(self) -> DenseArray: + """Dense tensor representation of this diagonal operator.""" return self.to_dense() @checked_method(in_space="domain", out_space="codomain") def apply(self, x: DenseArray) -> DenseArray: + """Apply the diagonal operator to ``x``.""" return self.diagonal * x @checked_method(in_space="codomain", out_space="domain") def rapply(self, y: DenseArray) -> DenseArray: + """Apply the adjoint diagonal operator to ``y``.""" return self._diag_adjoint * y def _reshape_diagonal_for_batch(self, diagonal: DenseArray, batch_space: Any) -> DenseArray: + """Broadcast diagonal values over a batch space.""" batch_shape = tuple(getattr(batch_space, "batch_shape", ())) batch_axes = tuple(getattr(batch_space, "batch_axes", ())) total_ndim = len(self.domain.shape) + len(batch_shape) @@ -59,6 +95,7 @@ def _reshape_diagonal_for_batch(self, diagonal: DenseArray, batch_space: Any) -> return self.ops.reshape(diagonal, tuple(shape)) def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: + """Apply this diagonal operator over a batch of domain elements.""" in_space = self._input_batch_space(self.domain, xs, batch_space) if self._enable_checks: in_space._check_member(xs) @@ -69,6 +106,7 @@ def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: return ys def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: + """Apply the adjoint over a batch of codomain elements.""" in_space = self._input_batch_space(self.codomain, ys, batch_space) if self._enable_checks: in_space._check_member(ys) @@ -79,6 +117,7 @@ def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: return xs def to_dense(self) -> DenseArray: + """Return a dense tensor representation of this diagonal operator.""" flat = self.diagonal.reshape((prod(self.domain.shape),)) matrix = self.ops.diag(flat) return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape)) @@ -98,21 +137,25 @@ def is_hermitian(self) -> bool | None: return None def __eq__(self, other: Any) -> bool: + """Return whether another diagonal operator has the same space and values.""" if type(other) is type(self): return self.domain == other.domain and self.ops.allclose(self.diagonal, other.diagonal) return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" children = (self.diagonal,) aux = (self.domain, self.ctx) return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" domain, ctx = aux return cls(children[0], domain, ctx) def _convert(self, new_ctx: Context) -> DiagonalLinOp: + """Convert the stored diagonal and space to ``new_ctx``.""" return DiagonalLinOp( new_ctx.asarray(self.diagonal), VectorSpace(tuple(self.domain.shape), new_ctx), diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index 236418c..5e9cd8e 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -14,14 +14,39 @@ @jax_pytree_class class SparseLinOp(LinOp): - """ - Sparse linear operator implementing the tensor map A : dom -> cod where - conceptually A has shape cod.shape + dom.shape, but stored as a 2D sparse matrix: - - A2.shape == (prod(cod.shape), prod(dom.shape)) - - apply: y = A ⋅ x (contract over dom axes) - rapply: x = A^* ⋅ y (contract over cod axes) + r""" + Represent a sparse matrix-backed linear operator. + + ``SparseLinOp(A, dom, cod)`` represents a tensor map whose conceptual shape + is ``cod.shape + dom.shape`` while storage uses a two-dimensional sparse + matrix with shape ``(prod(cod.shape), prod(dom.shape))``. + + Parameters + ---------- + A : SparseArray + Sparse backend matrix with shape ``(prod(cod.shape), prod(dom.shape))``. + dom : Space + Domain space. + cod : Space + Codomain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from the spaces. + + Attributes + ---------- + A : SparseArray + Stored sparse matrix representation. + + Examples + -------- + >>> import numpy as np + >>> import scipy.sparse as sps + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> A = sc.SparseLinOp(ctx.assparse(sps.eye(2)), X, X, ctx) + >>> A.apply(ctx.asarray([1.0, 2.0])) + array([1., 2.]) """ def __init__(self, @@ -72,6 +97,7 @@ def apply(self, x: DenseArray) -> DenseArray: return self._apply_unchecked(x) def _apply_unchecked(self, x: DenseArray) -> DenseArray: + """Apply the stored sparse matrix without membership checks.""" x1 = x if self._dom_is_flat else x.reshape((self._dom_size,)) y1 = self.A @ x1 # (m,) if self._cod_vector_fast_path: @@ -88,6 +114,7 @@ def rapply(self, y: DenseArray) -> DenseArray: return self._rapply_unchecked(y) def _rapply_unchecked(self, y: DenseArray) -> DenseArray: + """Apply the stored sparse adjoint without membership checks.""" y1 = y if self._cod_is_flat else y.reshape((self._cod_size,)) x1 = self._AH @ y1 diff --git a/spacecore/linop/product/__init__.py b/spacecore/linop/product/__init__.py index 4c534ab..1d81f74 100644 --- a/spacecore/linop/product/__init__.py +++ b/spacecore/linop/product/__init__.py @@ -1,3 +1,5 @@ +"""Linear operators that map to or from product spaces.""" + from ._base import ProductLinOp from ._block import BlockDiagonalLinOp from ._from_single import StackedLinOp @@ -8,4 +10,4 @@ "BlockDiagonalLinOp", "StackedLinOp", "SumToSingleLinOp", -] \ No newline at end of file +] diff --git a/spacecore/linop/product/_base.py b/spacecore/linop/product/_base.py index 1decf86..b7a2567 100644 --- a/spacecore/linop/product/_base.py +++ b/spacecore/linop/product/_base.py @@ -10,7 +10,19 @@ @jax_pytree_class class ProductLinOp(LinOp[Domain, Codomain]): """ - Base class for linear operators assembled from component operators. + Define a base class for operators assembled from component operators. + + Parameters + ---------- + dom : Space + Domain space of the assembled operator. + cod : Space + Codomain space of the assembled operator. + parts : sequence of LinOp + Nonempty sequence of component operators. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom`` and + ``cod``. """ parts: Tuple[LinOp, ...] @@ -34,17 +46,17 @@ def __init__(self, @abstractmethod def _check_layout(self) -> None: - """ - Check incidence compatibility between self.parts and self.dom/self.cod. - """ + """Check incidence compatibility between parts and endpoint spaces.""" raise NotImplementedError @classmethod @abstractmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> ProductLinOp: + """Build a product operator from component operators.""" ... def __eq__(self, x: Any) -> bool: + """Return whether another product operator has the same layout.""" if type(x) is type(self): return (self.dom == x.dom and self.cod == x.cod @@ -54,11 +66,13 @@ def __eq__(self, x: Any) -> bool: return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" children = self.parts aux = (self.dom, self.cod, self.ctx) return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" dom, cod, ctx = aux return cls(dom, cod, tuple(children), ctx) diff --git a/spacecore/linop/product/_block.py b/spacecore/linop/product/_block.py index 8d7ea19..9303225 100644 --- a/spacecore/linop/product/_block.py +++ b/spacecore/linop/product/_block.py @@ -12,16 +12,26 @@ @jax_pytree_class class BlockDiagonalLinOp(ProductLinOp[ProductSpace, ProductSpace]): - """ + r""" Block-diagonal operator between product spaces. - dom = X1 × ... × Xk - cod = Y1 × ... × Yk - - ops[i] : Xi -> Yi + If ``dom = X1 x ... x Xk`` and ``cod = Y1 x ... x Yk``, component + ``parts[i]`` maps ``Xi`` to ``Yi``. + + Parameters + ---------- + dom : ProductSpace + Product domain. + cod : ProductSpace + Product codomain. + parts : sequence of LinOp + Component operators with matching product incidence. + ctx : Context, str, or None, optional + Backend context specification. """ def _check_layout(self) -> None: + """Check that each component maps the matching product component.""" if not isinstance(self.dom, ProductSpace) or not isinstance(self.cod, ProductSpace): raise TypeError("BlockDiagonalLinOp expects dom and cod to be ProductSpace.") @@ -36,23 +46,28 @@ def _check_layout(self) -> None: @checked_method(in_space="dom", out_space="cod") def apply(self, x: Any) -> Any: + """Apply each block to the matching product component.""" return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: + """Apply each block without membership checks.""" if self._num_parts == 2: return self._apply_parts[0](x[0]), self._apply_parts[1](x[1]) return tuple(apply(xi) for apply, xi in zip(self._apply_parts, x)) @checked_method(in_space="cod", out_space="dom") def rapply(self, y: Any) -> Any: + """Apply each adjoint block to the matching product component.""" return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: + """Apply each adjoint block without membership checks.""" if self._num_parts == 2: return self._rapply_parts[0](y[0]), self._rapply_parts[1](y[1]) return tuple(rapply(yi) for rapply, yi in zip(self._rapply_parts, y)) def vapply(self, x: Any, batch_space=None) -> Any: + """Apply this block-diagonal operator over a product batch.""" in_space = self._input_batch_space(self.domain, x, batch_space) if self._enable_checks: in_space._check_member(x) @@ -64,6 +79,7 @@ def vapply(self, x: Any, batch_space=None) -> Any: ) def rvapply(self, y: Any, batch_space=None) -> Any: + """Apply the adjoint over a product batch.""" in_space = self._input_batch_space(self.codomain, y, batch_space) if self._enable_checks: in_space._check_member(y) @@ -76,6 +92,7 @@ def rvapply(self, y: Any, batch_space=None) -> Any: @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> BlockDiagonalLinOp: + """Build a block-diagonal operator from component operators.""" if not parts: raise ValueError("Parts must be non-empty.") @@ -84,6 +101,7 @@ def from_operators(cls, parts: Tuple[LinOp, ...]) -> BlockDiagonalLinOp: return cls(dom, cod, parts) def _convert(self, new_ctx: Context) -> BlockDiagonalLinOp: + """Convert spaces and component operators to ``new_ctx``.""" new_dom = self.dom.convert(new_ctx) new_cod = self.cod.convert(new_ctx) new_parts = [op.convert(new_ctx) for op in self.parts] diff --git a/spacecore/linop/product/_from_single.py b/spacecore/linop/product/_from_single.py index 2b65689..408098f 100644 --- a/spacecore/linop/product/_from_single.py +++ b/spacecore/linop/product/_from_single.py @@ -11,18 +11,27 @@ @jax_pytree_class class StackedLinOp(ProductLinOp[Domain, ProductSpace]): - """ + r""" Stack of operators from a single domain into a product codomain. - dom = X - cod = Y1 × ... × Yk - - ``ops[i] : X -> Yi`` - ``apply(x) = (ops[i](x))_i`` - ``rapply(y) = sum_i ops[i]^*(y_i)`` + If ``dom = X`` and ``cod = Y1 x ... x Yk``, component ``parts[i]`` maps + ``X`` to ``Yi``. Forward application returns a tuple of component outputs; + adjoint application sums component adjoints in ``X``. + + Parameters + ---------- + dom : Space + Shared component domain. + cod : ProductSpace + Product codomain. + parts : sequence of LinOp + Operators from ``dom`` to each component of ``cod``. + ctx : Context, str, or None, optional + Backend context specification. """ def _check_layout(self) -> None: + """Check that every component maps the shared domain to one codomain part.""" if not isinstance(self.cod, ProductSpace): raise TypeError("StackedLinOp expects cod to be ProductSpace.") @@ -37,18 +46,22 @@ def _check_layout(self) -> None: @checked_method(in_space="dom", out_space="cod") def apply(self, x: Any) -> Any: + """Apply each component operator to the same input.""" return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: + """Apply component operators without membership checks.""" if self._num_parts == 2: return self._apply_parts[0](x), self._apply_parts[1](x) return tuple(apply(x) for apply in self._apply_parts) @checked_method(in_space="cod", out_space="dom") def rapply(self, y: Any) -> Any: + """Apply component adjoints and sum in the shared domain.""" return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: + """Apply component adjoints without membership checks.""" if self._num_parts == 2: x0 = self._rapply_parts[0](y[0]) x1 = self._rapply_parts[1](y[1]) @@ -61,6 +74,7 @@ def _rapply_unchecked(self, y: Any) -> Any: return acc def vapply(self, x: Any, batch_space=None) -> Any: + """Apply this stacked operator over a batch.""" in_space = self._input_batch_space(self.domain, x, batch_space) if self._enable_checks: in_space._check_member(x) @@ -72,6 +86,7 @@ def vapply(self, x: Any, batch_space=None) -> Any: ) def rvapply(self, y: Any, batch_space=None) -> Any: + """Apply the adjoint stacked operator over a product batch.""" in_space = self._input_batch_space(self.codomain, y, batch_space) if self._enable_checks: in_space._check_member(y) @@ -86,6 +101,7 @@ def rvapply(self, y: Any, batch_space=None) -> Any: @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> StackedLinOp: + """Build a stacked operator from component operators.""" if not parts: raise ValueError("Parts must be non-empty.") @@ -95,6 +111,7 @@ def from_operators(cls, parts: Tuple[LinOp, ...]) -> StackedLinOp: return cls(dom, cod, parts) def _convert(self, new_ctx: Context) -> StackedLinOp: + """Convert spaces and component operators to ``new_ctx``.""" new_dom = self.dom.convert(new_ctx) new_cod = self.cod.convert(new_ctx) new_parts = [op.convert(new_ctx) for op in self.parts] diff --git a/spacecore/linop/product/_to_single.py b/spacecore/linop/product/_to_single.py index 60c91fa..61f2824 100644 --- a/spacecore/linop/product/_to_single.py +++ b/spacecore/linop/product/_to_single.py @@ -11,18 +11,27 @@ @jax_pytree_class class SumToSingleLinOp(ProductLinOp[ProductSpace, Codomain]): - """ + r""" Sum of component operators from a product domain into a single codomain. - dom = X1 × ... × Xk - cod = Y - - ``ops[i] : Xi -> Y`` - ``apply(x) = sum_i ops[i](x_i)`` - ``rapply(y) = (ops[i]^*(y))_i`` + If ``dom = X1 x ... x Xk`` and ``cod = Y``, component ``parts[i]`` maps + ``Xi`` to ``Y``. Forward application sums component outputs in ``Y``; + adjoint application returns the tuple of component adjoints. + + Parameters + ---------- + dom : ProductSpace + Product domain. + cod : Space + Shared codomain. + parts : sequence of LinOp + Operators from each product component to ``cod``. + ctx : Context, str, or None, optional + Backend context specification. """ def _check_layout(self) -> None: + """Check that every component maps one product part to the shared codomain.""" if not isinstance(self.dom, ProductSpace): raise TypeError("SumToSingleLinOp expects dom to be ProductSpace.") @@ -37,9 +46,11 @@ def _check_layout(self) -> None: @checked_method(in_space="dom", out_space="cod") def apply(self, x: Any) -> Any: + """Apply component operators and sum in the codomain.""" return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: + """Apply component operators without membership checks.""" if self._num_parts == 2: y0 = self._apply_parts[0](x[0]) y1 = self._apply_parts[1](x[1]) @@ -53,14 +64,17 @@ def _apply_unchecked(self, x: Any) -> Any: @checked_method(in_space="cod", out_space="dom") def rapply(self, y: Any) -> Any: + """Apply each component adjoint to the shared codomain element.""" return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: + """Apply component adjoints without membership checks.""" if self._num_parts == 2: return self._rapply_parts[0](y), self._rapply_parts[1](y) return tuple(rapply(y) for rapply in self._rapply_parts) def vapply(self, x: Any, batch_space=None) -> Any: + """Apply this sum-to-single operator over a product batch.""" in_space = self._input_batch_space(self.domain, x, batch_space) if self._enable_checks: in_space._check_member(x) @@ -74,6 +88,7 @@ def vapply(self, x: Any, batch_space=None) -> Any: return acc def rvapply(self, y: Any, batch_space=None) -> Any: + """Apply the adjoint over a codomain batch.""" in_space = self._input_batch_space(self.codomain, y, batch_space) if self._enable_checks: in_space._check_member(y) @@ -86,6 +101,7 @@ def rvapply(self, y: Any, batch_space=None) -> Any: @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> SumToSingleLinOp: + """Build a sum-to-single operator from component operators.""" if not parts: raise ValueError("Parts must be non-empty.") @@ -95,6 +111,7 @@ def from_operators(cls, parts: Tuple[LinOp, ...]) -> SumToSingleLinOp: return cls(dom, cod, parts) def _convert(self, new_ctx: Context) -> SumToSingleLinOp: + """Convert spaces and component operators to ``new_ctx``.""" new_dom = self.dom.convert(new_ctx) new_cod = self.cod.convert(new_ctx) new_parts = [op.convert(new_ctx) for op in self.parts] diff --git a/spacecore/space/__init__.py b/spacecore/space/__init__.py index 4883640..b9fb285 100644 --- a/spacecore/space/__init__.py +++ b/spacecore/space/__init__.py @@ -1,3 +1,5 @@ +"""Vector space abstractions, concrete spaces, and validation checks.""" + from ._checks import ( BackendCheck, DTypeCheck, diff --git a/spacecore/space/_base.py b/spacecore/space/_base.py index 53318d1..348b554 100644 --- a/spacecore/space/_base.py +++ b/spacecore/space/_base.py @@ -11,9 +11,9 @@ class Space(ContextBound): """ - Abstract Space. + Define the geometry and linear structure of a vector space. - A Space owns the *geometry* (inner product, norm) and the basic linear + A space owns the geometry (inner product, norm) and the basic linear structure (add/scale/axpy) for its elements. Membership validation is exposed through ``check_member``, which respects @@ -21,7 +21,35 @@ class Space(ContextBound): checked that policy may call ``_check_member`` to run the concrete checks exactly once. - Solvers should use only this API. + Parameters + ---------- + shape : tuple of int + Canonical coordinate shape for elements of the space. + ctx : Context, str, or None, optional + Backend context specification. Default resolves to the global context. + + Attributes + ---------- + shape : tuple of int + Canonical element shape. + ctx : Context + Resolved backend context inherited from :class:`ContextBound`. + + Notes + ----- + Solvers use only this API. Concrete spaces define storage constraints, + membership checks, and flattening rules. + + Examples + -------- + Instantiate the concrete :class:`VectorSpace` subclass. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> X.shape + (2,) """ checks: ClassVar[tuple[SpaceCheck, ...]] = () @@ -46,15 +74,7 @@ def member_checks(self) -> tuple[SpaceCheck, ...]: return tuple(checks) def _check_member(self, x: Any) -> None: - """ - Raise if `x` is not a valid element of this space. - - Typical checks: - - x.space is self (if your elements carry a .space) - - backend family consistency (via ctx) - - representation is supported - - shape/structure constraints (Hermitian, block sizes, etc.) - """ + """Raise if ``x`` is not a valid element of this space.""" for check in self.member_checks(): check(self, x) @@ -80,18 +100,16 @@ def axpy(self, a: Any, x: Any, y: Any) -> Any: @abstractmethod def inner(self, x: Any, y: Any) -> Any: - """ - Inner product ⟨x, y⟩ for elements of this space. - """ + r"""Return :math:`\langle x, y \rangle_X` for elements of this space.""" def norm(self, x: Any) -> Any: - """Induced norm ||x|| = sqrt(real(⟨x,x⟩)). Override if you can do better.""" + r"""Return the induced norm :math:`\sqrt{\operatorname{Re}\langle x, x\rangle_X}`.""" v = self.ctx.ops.real(self.inner(x, x)) return self.ctx.ops.sqrt(v) @abstractmethod def eigh(self, x: Any, k: int = None) -> Any: - """Eigendecomposition of x (if applicable).)""" + """Return an eigendecomposition of ``x`` when the space defines one.""" @abstractmethod def flatten(self, x: Any) -> DenseArray: diff --git a/spacecore/space/_batch.py b/spacecore/space/_batch.py index 12f2971..08f8d4c 100644 --- a/spacecore/space/_batch.py +++ b/spacecore/space/_batch.py @@ -16,6 +16,7 @@ def _batched_shape( batch_shape: tuple[int, ...], batch_axes: tuple[int, ...], ) -> tuple[int, ...]: + """Interleave base and batch dimensions according to ``batch_axes``.""" total_ndim = len(base_shape) + len(batch_shape) axes = tuple(axis + total_ndim if axis < 0 else axis for axis in batch_axes) if len(batch_shape) != len(axes): @@ -40,11 +41,31 @@ def _batched_shape( class BatchSpace(Space): """ - Wrapper space representing a batch of elements from a base space. + Represent a batch of elements from a base space. ``BatchSpace(X, batch_shape, batch_axes)`` represents ``X`` repeated over the given batch dimensions. It deliberately wraps the original space rather than folding batch dimensions into the base ``Space`` instance. + + Parameters + ---------- + base : Space + Space whose elements are batched. + batch_shape : tuple of int + Sizes of batch dimensions. + batch_axes : tuple of int + Axes occupied by batch dimensions in the batched representation. + ctx : Context, str, or None, optional + Backend context specification. Default is ``base.ctx``. + + Attributes + ---------- + base : Space + Converted base space. + batch_shape : tuple of int + Batch dimension sizes. + batch_axes : tuple of int + Batch axis positions. """ def __init__( @@ -80,11 +101,13 @@ def _is_product(self) -> bool: return isinstance(self.base, ProductSpace) def _component_spaces(self) -> tuple[BatchSpace, ...]: + """Return batched component spaces for product-space bases.""" if not isinstance(self.base, ProductSpace): raise TypeError("BatchSpace component spaces are available only for ProductSpace bases.") return tuple(sp.batch(self.batch_shape, self.batch_axes) for sp in self.base.spaces) def _check_member(self, x: Any) -> None: + """Raise if ``x`` is not a valid batched element.""" if isinstance(self.base, ProductSpace): if not isinstance(x, tuple) or len(x) != self.base.arity: raise TypeError( @@ -103,24 +126,28 @@ def _check_member(self, x: Any) -> None: check(self, x) def zeros(self) -> Any: + """Return the batched zero element.""" if isinstance(self.base, ProductSpace): return tuple(space.zeros() for space in self._component_spaces()) return self.ops.zeros(self.shape, dtype=self.dtype) @checked_method(in_space="self", arg_positions=(0, 1)) def add(self, x: Any, y: Any) -> Any: + """Return the batched sum ``x + y``.""" if isinstance(self.base, ProductSpace): return tuple(space.add(xi, yi) for space, xi, yi in zip(self._component_spaces(), x, y)) return x + y @checked_method(in_space="self", arg_positions=(1,)) def scale(self, a: Any, x: Any) -> Any: + """Return the batched scalar product ``a * x``.""" if isinstance(self.base, ProductSpace): return tuple(space.scale(a, xi) for space, xi in zip(self._component_spaces(), x)) return a * x @checked_method(in_space="self", arg_positions=(0, 1)) def inner(self, x: Any, y: Any) -> Any: + r"""Return :math:`\langle x, y\rangle` over the batched space.""" if isinstance(self.base, ProductSpace): acc = None for space, xi, yi in zip(self._component_spaces(), x, y): @@ -130,16 +157,19 @@ def inner(self, x: Any, y: Any) -> Any: return self.ops.vdot(x, y) def eigh(self, x: Any, k: int = None) -> Any: + """Raise because batched spaces do not define eigendecomposition.""" raise TypeError(f"{type(self).__name__}.eigh is not defined for batched spaces.") @checked_method(in_space="self") def flatten(self, x: Any) -> DenseArray: + """Flatten a batched element into dense coordinates.""" if isinstance(self.base, ProductSpace): parts = tuple(space.flatten(xi) for space, xi in zip(self._component_spaces(), x)) return parts[0] if len(parts) == 1 else self.ops.concatenate(parts, axis=0) return self.ops.reshape(x, (-1,)) def unflatten(self, v: DenseArray) -> Any: + """Convert dense batched coordinates into a batched element.""" vv = self.ctx.assert_dense(v) if self._enable_checks else v if isinstance(self.base, ProductSpace): if ( @@ -165,6 +195,7 @@ def unflatten(self, v: DenseArray) -> Any: @checked_method(in_space="self", out_space="self") def apply(self, x: Any, f: Callable) -> Any: + """Apply a function over batched elements using base-space semantics.""" if isinstance(self.base, ProductSpace): return tuple(space.apply(xi, f) for space, xi in zip(self._component_spaces(), x)) try: @@ -174,4 +205,5 @@ def apply(self, x: Any, f: Callable) -> Any: return y def _convert(self, new_ctx: Context) -> BatchSpace: + """Convert the base space to ``new_ctx``.""" return BatchSpace(self.base.convert(new_ctx), self.batch_shape, self.batch_axes, new_ctx) diff --git a/spacecore/space/_checks.py b/spacecore/space/_checks.py index 622ce32..4ec8035 100644 --- a/spacecore/space/_checks.py +++ b/spacecore/space/_checks.py @@ -10,6 +10,7 @@ class SpaceValidationError(ValueError, TypeError): def _shape_of(space: Any, x: Any) -> tuple[int, ...] | None: + """Return the backend-visible shape of ``x`` when available.""" try: return tuple(space.ops.shape(x)) except Exception: @@ -18,6 +19,7 @@ def _shape_of(space: Any, x: Any) -> tuple[int, ...] | None: def _dtype_of(space: Any, x: Any) -> Any: + """Return the backend-visible dtype of ``x`` when available.""" try: return space.ops.get_dtype(x) except Exception: @@ -26,23 +28,44 @@ def _dtype_of(space: Any, x: Any) -> Any: @dataclass(frozen=True) class SpaceCheck(ABC): + """ + Define a membership check for :class:`Space` objects. + + Parameters + ---------- + name : str + Human-readable check name used in diagnostics. + """ + name: str def __call__(self, space: Any, x: Any) -> None: + """Raise :class:`SpaceValidationError` when ``x`` is invalid.""" if not self.is_valid(space, x): raise SpaceValidationError(self.error_message(space, x)) @abstractmethod def is_valid(self, space: Any, x: Any) -> bool: + """Return whether ``x`` is valid for ``space``.""" ... @abstractmethod def error_message(self, space: Any, x: Any) -> str: + """Return a diagnostic for an invalid ``x``.""" ... @dataclass(frozen=True) class BackendCheck(SpaceCheck): + """ + Check that a value is a dense array for a space backend. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"backend"``. + """ + name: str = "backend" def is_valid(self, space: Any, x: Any) -> bool: @@ -54,6 +77,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class ShapeCheck(SpaceCheck): + """ + Check that a value has the canonical shape of a space. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"shape"``. + """ + name: str = "shape" def is_valid(self, space: Any, x: Any) -> bool: @@ -65,6 +97,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class DTypeCheck(SpaceCheck): + """ + Check that a value has the dtype required by a space context. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"dtype"``. + """ + name: str = "dtype" def is_valid(self, space: Any, x: Any) -> bool: @@ -76,6 +117,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class SquareMatrixCheck(SpaceCheck): + """ + Check that a value has square trailing matrix axes. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"square_matrix"``. + """ + name: str = "square_matrix" def is_valid(self, space: Any, x: Any) -> bool: @@ -88,6 +138,21 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class HermitianCheck(SpaceCheck): + """ + Check that a value is Hermitian within tolerances. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"hermitian"``. + atol : float, optional + Absolute tolerance for Hermitian comparison. + rtol : float, optional + Relative tolerance for Hermitian comparison. + enforce : bool, optional + Whether to enforce the Hermitian comparison. + """ + name: str = "hermitian" atol: float = 1e-8 rtol: float = 1e-8 @@ -113,6 +178,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class ProductStructureCheck(SpaceCheck): + """ + Check that a product-space value is a tuple of the right length. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"product_structure"``. + """ + name: str = "product_structure" def is_valid(self, space: Any, x: Any) -> bool: @@ -126,6 +200,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class ProductComponentCheck(SpaceCheck): + """ + Check each component of a product-space value. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"product_components"``. + """ + name: str = "product_components" def is_valid(self, space: Any, x: Any) -> bool: diff --git a/spacecore/space/_herm.py b/spacecore/space/_herm.py index e8117af..0c9a71d 100644 --- a/spacecore/space/_herm.py +++ b/spacecore/space/_herm.py @@ -10,8 +10,8 @@ class HermitianSpace(VectorSpace): - """ - Space of dense n×n Hermitian matrices. + r""" + Represent dense Hermitian matrices with Frobenius geometry. Elements are backend-native dense arrays with shape ``(n, n)``. Membership enforces Hermitian structure up to tolerances. @@ -19,6 +19,24 @@ class HermitianSpace(VectorSpace): The inner product is Frobenius / Hilbert-Schmidt: `` = vdot(vec(X), vec(Y))``, where ``vdot`` conjugates the first argument according to backend rules. + + Parameters + ---------- + n : int + Matrix dimension. + atol : float, optional + Absolute tolerance for Hermitian membership checks. + rtol : float, optional + Relative tolerance for Hermitian membership checks. + enforce_herm : bool, optional + Whether membership checks enforce Hermitian structure. + ctx : Context, str, or None, optional + Backend context specification. + + Attributes + ---------- + n : int + Matrix dimension. """ def __init__(self, @@ -48,9 +66,11 @@ def __eq__(self, other: Any) -> bool: @property def n(self) -> int: + """Matrix dimension of this Hermitian space.""" return self.shape[0] def _local_checks(self): + """Return membership checks local to Hermitian spaces.""" return ( SquareMatrixCheck(), HermitianCheck( @@ -61,6 +81,7 @@ def _local_checks(self): ) def is_hermitian(self, x: DenseArray) -> bool: + """Return whether ``x`` satisfies this space's Hermitian check.""" return HermitianCheck( atol=self.atol, rtol=self.rtol, @@ -68,25 +89,29 @@ def is_hermitian(self, x: DenseArray) -> bool: ).is_valid(self, x) def symmetrize(self, x: DenseArray) -> DenseArray: - """Project onto the Hermitian cone: (X + X^H)/2.""" + r"""Project ``x`` onto the Hermitian subspace as :math:`(X + X^*) / 2`.""" return (x + x.T.conj()) * 0.5 @checked_method(in_space="self") def eigh(self, x: DenseArray, k: int = None) -> Tuple[DenseArray, DenseArray]: + """Return the eigendecomposition of a Hermitian element.""" return self.ops.eigh(x) def unflatten(self, v: DenseArray) -> DenseArray: + """Reshape dense coordinates and symmetrize the result.""" vv = self.ctx.assert_dense(v) if self._enable_checks else v X = vv.reshape(self.shape) return self.symmetrize(X) @checked_method(in_space="self") def psd_proj(self, x: DenseArray) -> DenseArray: + """Project a Hermitian element onto the positive semidefinite cone.""" evals, evecs = self.ops.eigh(x) evals = self.ops.maximum(evals, 0.) return self.eig_to_dense(evals, evecs) def eig_to_dense(self, evals: DenseArray, evecs: DenseArray) -> DenseArray: + """Reconstruct a Hermitian matrix from eigenvalues and eigenvectors.""" self.ctx.assert_dense(evals) self.ctx.assert_dense(evecs) X = (evecs * evals) @ evecs.T.conj() @@ -94,6 +119,7 @@ def eig_to_dense(self, evals: DenseArray, evecs: DenseArray) -> DenseArray: return X def _convert(self, new_ctx: Context) -> HermitianSpace: + """Convert this Hermitian space to ``new_ctx``.""" return HermitianSpace(self.n, self.atol, self.rtol, self.enforce_herm, new_ctx) @checked_method(in_space="self") diff --git a/spacecore/space/_product.py b/spacecore/space/_product.py index 6e52b5e..991a1b4 100644 --- a/spacecore/space/_product.py +++ b/spacecore/space/_product.py @@ -13,6 +13,7 @@ def _prod_int(shape: Tuple[int, ...]) -> int: + """Return the integer product of a shape tuple.""" p = 1 for d in shape: p *= int(d) @@ -20,27 +21,41 @@ def _prod_int(shape: Tuple[int, ...]) -> int: class ProductSpace(Space): - """ - Cartesian product space X = X1 × ... × Xk. - - Elements are tuples: - x = (x1, ..., xk) with xi ∈ Xi - - Canonical dense coordinates: - flatten(x) = concat(flatten_i(xi)) - - Notes: - - `shape` for this space is the *1D coordinate length* of the concatenated flattening. - - `eigh` has no canonical meaning here and raises by default. + r""" + Represent a Cartesian product of spaces. + + Elements are tuples ``(x1, ..., xk)`` with ``xi`` in ``spaces[i]``. + Dense coordinates concatenate the flattened coordinates of each component. + + Parameters + ---------- + spaces : tuple of Space + Nonempty tuple of component spaces. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from components. + + Attributes + ---------- + spaces : tuple of Space + Component spaces converted to ``ctx``. + arity : int + Number of component spaces. + + Notes + ----- + ``shape`` is the one-dimensional coordinate length of the concatenated + flattening. ``eigh`` has no canonical meaning and raises by default. """ def _convert(self, new_ctx: Context) -> Space: + """Convert all component spaces to ``new_ctx``.""" new_spaces = [] for sp in self.spaces: new_spaces.append(sp.convert(new_ctx)) return ProductSpace(tuple(new_spaces), new_ctx) def _local_checks(self): + """Return membership checks local to product spaces.""" return ProductStructureCheck(), ProductComponentCheck() def __init__(self, spaces: Tuple[Space, ...], ctx: Context | str | None = None) -> None: @@ -96,6 +111,7 @@ def __init__(self, spaces: Tuple[Space, ...], ctx: Context | str | None = None) self._is_flat1 = self._component_is_flat[1] def _validate_spaces(self, spaces: Any) -> Tuple[Space, ...]: + """Validate and normalize product component spaces.""" if isinstance(spaces, Sequence): spaces = tuple(spaces) for i, sp in enumerate(spaces): @@ -109,21 +125,26 @@ def _validate_spaces(self, spaces: Any) -> Tuple[Space, ...]: @property def arity(self) -> int: + """Number of component spaces.""" return self._arity def zeros(self) -> Tuple[Any, ...]: + """Return the product-space zero tuple.""" return tuple(s.zeros() for s in self.spaces) @checked_method(in_space="self", arg_positions=(0, 1)) def add(self, x: Tuple[Any, ...], y: Tuple[Any, ...]) -> Tuple[Any, ...]: + """Return the componentwise product-space sum.""" return tuple(s.add(xi, yi) for s, xi, yi in zip(self.spaces, x, y)) @checked_method(in_space="self", arg_positions=(1,)) def scale(self, a: Any, x: Tuple[Any, ...]) -> Tuple[Any, ...]: + """Return the componentwise scalar product.""" return tuple(s.scale(a, xi) for s, xi in zip(self.spaces, x)) @checked_method(in_space="self", arg_positions=(0, 1)) def inner(self, x: Tuple[Any, ...], y: Tuple[Any, ...]) -> Any: + r"""Return the sum of component inner products.""" # Accumulate via backend ops (vdot works for scalars too, but sum is enough) acc = None for s, xi, yi in zip(self.spaces, x, y): @@ -132,6 +153,7 @@ def inner(self, x: Tuple[Any, ...], y: Tuple[Any, ...]) -> Any: return acc def eigh(self, x: Any, k: int = None) -> Any: + """Raise because product spaces do not define a canonical eigendecomposition.""" raise NotImplementedError( "ProductSpace.eigh is not defined. " "Call eigh on a specific component space, or define a custom convention." @@ -139,6 +161,7 @@ def eigh(self, x: Any, k: int = None) -> Any: @checked_method(in_space="self") def flatten(self, x: Tuple[Any, ...]) -> DenseArray: + """Concatenate component coordinate vectors into one dense vector.""" if self._vector_fast_path: if self._arity == 1: return x[0] if self._component_is_flat[0] else x[0].reshape((-1,)) @@ -171,6 +194,7 @@ def flatten(self, x: Tuple[Any, ...]) -> DenseArray: return self._concatenate(parts, axis=0) def unflatten(self, v: DenseArray) -> Tuple[Any, ...]: + """Split dense coordinates into component-space elements.""" if self._enable_checks: v = self.ctx.assert_dense(v) v1 = v if tuple(getattr(v, "shape", ())) == self.shape else v.reshape((-1,)) diff --git a/spacecore/space/_vector.py b/spacecore/space/_vector.py index 47f25df..65e985b 100644 --- a/spacecore/space/_vector.py +++ b/spacecore/space/_vector.py @@ -11,16 +11,36 @@ class VectorSpace(Space): - """ - Dense vector space R^{n1, ..., nK} or C^{n1, ..., nK}. - - Elements: - - backend-native dense arrays; - - canonical shape is (n1, ..., nK). - - Geometry: - - Euclidean / ℓ2 inner product - ⟨x, y⟩ = vdot(x, y). + r""" + Represent dense backend arrays with Euclidean geometry. + + Elements are backend-native dense arrays with canonical shape ``shape``. + The inner product is :math:`\langle x, y\rangle_X = \operatorname{vdot}(x,y)`, + where the backend conjugates the first argument for complex arrays. + + Parameters + ---------- + shape : tuple of int + Canonical coordinate shape for elements of the space. + ctx : Context, str, or None, optional + Backend context specification. Default resolves to the global context. + + Attributes + ---------- + shape : tuple of int + Canonical element shape. + ctx : Context + Resolved backend context. + + Examples + -------- + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> x = ctx.asarray([1.0, 2.0]) + >>> X.inner(x, x) + np.float64(5.0) """ def __init__(self, shape: Tuple[int, ...], ctx: Context | str | None = None) -> None: @@ -29,40 +49,50 @@ def __init__(self, shape: Tuple[int, ...], ctx: Context | str | None = None) -> self._is_flat_shape = self.shape == (self._size,) def _local_checks(self): + """Return membership checks local to dense vector spaces.""" return BackendCheck(), ShapeCheck(), DTypeCheck() def zeros(self) -> DenseArray: + """Return the zero vector in this space.""" return self.ops.zeros(self.shape, dtype=self.dtype) @checked_method(in_space="self", arg_positions=(0, 1)) def add(self, x: Any, y: Any) -> DenseArray: + """Return the vector-space sum ``x + y``.""" return x + y @checked_method(in_space="self", arg_positions=(1,)) def scale(self, a: Any, x: Any) -> DenseArray: + """Return the scalar product ``a * x``.""" return a * x @checked_method(in_space="self", arg_positions=(0, 1)) def inner(self, x: Any, y: Any) -> Any: + r"""Return :math:`\langle x, y\rangle_X` using backend ``vdot``.""" return self.ops.vdot(x, y) def eigh(self, x: Any, k: int = None) -> Any: + """Raise because vector elements do not have a canonical eigendecomposition.""" raise TypeError( f"{type(self).__name__}.eigh is not defined for vector spaces." ) @checked_method(in_space="self") def flatten(self, X: DenseArray) -> DenseArray: + """Return ``X`` as a dense one-dimensional coordinate vector.""" return X if self._is_flat_shape else X.reshape((-1,)) def unflatten(self, v: DenseArray) -> DenseArray: + """Reshape a flat coordinate vector into this space's canonical shape.""" V = self.ctx.assert_dense(v) if self._enable_checks else v return V if self._is_flat_shape else V.reshape(self.shape) def _convert(self, new_ctx: Context) -> VectorSpace: + """Convert this vector space to ``new_ctx`` without changing shape.""" return VectorSpace(self.shape, new_ctx) def _apply_entrywise(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: + """Apply ``f`` entrywise and verify that shape is preserved.""" try: y = f(x) except Exception: diff --git a/spacecore/types/__init__.py b/spacecore/types/__init__.py index 65738fd..b91a969 100644 --- a/spacecore/types/__init__.py +++ b/spacecore/types/__init__.py @@ -1,3 +1,5 @@ +"""Common typing aliases and protocols used by SpaceCore.""" + from ._array import ArrayLike, DenseArray, SparseArray from ._dtype import DType from ._misc import Index, T, X, Y, R, Carry diff --git a/spacecore/types/_array.py b/spacecore/types/_array.py index ec4b03f..ce8501e 100644 --- a/spacecore/types/_array.py +++ b/spacecore/types/_array.py @@ -9,11 +9,20 @@ @runtime_checkable class ArrayLike(Protocol): - """Minimal array-like object accepted by public backend helpers. + """ + Define the minimal array-like object accepted by backend helpers. This intentionally only models common metadata. NumPy arrays, JAX arrays, PyTorch tensors, sparse arrays, scalar-like backend arrays, and array wrappers can satisfy this without implementing every dense-array method. + + Parameters + ---------- + *args : Any + Construction arguments accepted by concrete array implementations. + **kwargs : Any + Keyword construction arguments accepted by concrete array + implementations. """ @property @@ -24,12 +33,22 @@ def dtype(self) -> DType: ... class SparseArray(ArrayLike, Protocol): - """Portable sparse-array surface used by sparse linear operators. + """ + Define the portable sparse-array surface used by sparse operators. Backend-specific sparse APIs such as SciPy ``tocsr()``, JAX sparse ``indices``/``data``, and Torch ``to_dense()`` are intentionally not part of this protocol. Concrete backends may use those after checking that the object belongs to their sparse family. + + Parameters + ---------- + *args : Any + Construction arguments accepted by concrete sparse array + implementations. + **kwargs : Any + Keyword construction arguments accepted by concrete sparse array + implementations. """ @property @@ -41,12 +60,21 @@ def __matmul__(self, other: Any) -> Any: ... class DenseArray(ArrayLike, Protocol): - """Portable dense-array surface covering NumPy, JAX, and PyTorch arrays. + """ + Define the portable dense-array surface used by core abstractions. The protocol includes only operations that SpaceCore core abstractions use directly on dense arrays. Backend-specific metadata such as device, sharding, layout, strides, and gradient state belongs to concrete backend implementations, not to this portable type. + + Parameters + ---------- + *args : Any + Construction arguments accepted by concrete dense array implementations. + **kwargs : Any + Keyword construction arguments accepted by concrete dense array + implementations. """ @property From b2be42160ad09979f6554e6d8a8020aafc993b8d Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 22:21:45 -0300 Subject: [PATCH 36/44] Add linalg API docs page --- docs/source/api/index.rst | 1 + docs/source/api/linalg.rst | 68 ++++++++++++++++++++++++++++++++++++ spacecore/linalg/_cg.py | 6 ++-- spacecore/linalg/_lanczos.py | 6 ++-- spacecore/linalg/_lsqr.py | 6 ++-- 5 files changed, 78 insertions(+), 9 deletions(-) create mode 100644 docs/source/api/linalg.rst diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 654fbb8..b7636b8 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -12,3 +12,4 @@ directives for public objects instead of dumping entire modules. spaces linops functionals + linalg diff --git a/docs/source/api/linalg.rst b/docs/source/api/linalg.rst new file mode 100644 index 0000000..3ae5531 --- /dev/null +++ b/docs/source/api/linalg.rst @@ -0,0 +1,68 @@ +Linear algebra API +================== + +Linear algebra routines solve systems, estimate eigenpairs, and apply matrix +functions through :class:`~spacecore.linop.LinOp` objects. They use +space-aware vector operations and avoid materializing dense operators unless a +method explicitly projects to a small Krylov subspace. + +.. autosummary:: + :nosignatures: + + spacecore.linalg.cg + spacecore.linalg.lsqr + spacecore.linalg.lanczos_smallest + spacecore.linalg.stochastic_lanczos + spacecore.linalg.power_iteration + spacecore.linalg.expm_multiply + spacecore.linalg.CGResult + spacecore.linalg.LSQRResult + spacecore.linalg.LanczosResult + spacecore.linalg.StochasticLanczosResult + spacecore.linalg.PowerIterationResult + spacecore.linalg.ExpmMultiplyResult + +Solvers +------- + +.. autofunction:: spacecore.linalg.cg + +.. autoclass:: spacecore.linalg.CGResult + :members: + :undoc-members: + +.. autofunction:: spacecore.linalg.lsqr + +.. autoclass:: spacecore.linalg.LSQRResult + :members: + :undoc-members: + +Eigenvalue algorithms +--------------------- + +.. autofunction:: spacecore.linalg.lanczos_smallest + +.. autofunction:: spacecore.linalg.stochastic_lanczos + +.. autoclass:: spacecore.linalg.LanczosResult + :members: + :undoc-members: + +.. autoclass:: spacecore.linalg.StochasticLanczosResult + :members: + :undoc-members: + +.. autofunction:: spacecore.linalg.power_iteration + +.. autoclass:: spacecore.linalg.PowerIterationResult + :members: + :undoc-members: + +Matrix functions +---------------- + +.. autofunction:: spacecore.linalg.expm_multiply + +.. autoclass:: spacecore.linalg.ExpmMultiplyResult + :members: + :undoc-members: diff --git a/spacecore/linalg/_cg.py b/spacecore/linalg/_cg.py index eb9ac49..a717ba9 100644 --- a/spacecore/linalg/_cg.py +++ b/spacecore/linalg/_cg.py @@ -115,9 +115,9 @@ def cg( References ---------- - .. [1] Hestenes, M. R. and Stiefel, E., "Methods of Conjugate Gradients - for Solving Linear Systems," J. Res. Natl. Bur. Stand., 49 (1952), - 409-436. + Hestenes, M. R. and Stiefel, E., "Methods of Conjugate Gradients + for Solving Linear Systems," J. Res. Natl. Bur. Stand., 49 (1952), + 409-436. Examples -------- diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index 0d58c64..60ebc17 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -309,9 +309,9 @@ def lanczos_smallest( References ---------- - .. [1] Lanczos, C., "An Iteration Method for the Solution of the Eigenvalue - Problem of Linear Differential and Integral Operators," J. Res. Natl. - Bur. Stand., 45 (1950), 255-282. + Lanczos, C., "An Iteration Method for the Solution of the Eigenvalue + Problem of Linear Differential and Integral Operators," J. Res. Natl. + Bur. Stand., 45 (1950), 255-282. Examples -------- diff --git a/spacecore/linalg/_lsqr.py b/spacecore/linalg/_lsqr.py index 8907158..d32ac39 100644 --- a/spacecore/linalg/_lsqr.py +++ b/spacecore/linalg/_lsqr.py @@ -120,9 +120,9 @@ def lsqr( References ---------- - .. [1] Paige, C. C. and Saunders, M. A., "LSQR: An Algorithm for Sparse - Linear Equations and Sparse Least Squares," ACM Trans. Math. Soft., - 8 (1982), 43-71. + Paige, C. C. and Saunders, M. A., "LSQR: An Algorithm for Sparse + Linear Equations and Sparse Least Squares," ACM Trans. Math. Soft., + 8 (1982), 43-71. Examples -------- From 7dde7b81bebc91b682adcda46a2d3b424fcb29d2 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 22:44:20 -0300 Subject: [PATCH 37/44] Clarify linalg solver contracts --- spacecore/linalg/_cg.py | 24 +++++++++++++------ spacecore/linalg/_expm.py | 40 ++++++++++++++++++++++---------- spacecore/linalg/_lanczos.py | 45 ++++++++++++++++++++++++++++-------- spacecore/linalg/_lsqr.py | 20 +++++++++++----- spacecore/linalg/_power.py | 19 ++++++++++++--- spacecore/linop/_dense.py | 10 +++++--- spacecore/linop/_sparse.py | 10 +++++--- tests/linalg/test_krylov.py | 10 ++++++++ 8 files changed, 134 insertions(+), 44 deletions(-) diff --git a/spacecore/linalg/_cg.py b/spacecore/linalg/_cg.py index a717ba9..06dcf2f 100644 --- a/spacecore/linalg/_cg.py +++ b/spacecore/linalg/_cg.py @@ -55,20 +55,28 @@ def cg( r""" Solve :math:`A x = b` by conjugate gradients. - Require ``A`` to be a square Hermitian positive-definite :class:`LinOp`. - The implementation uses only :meth:`LinOp.apply` and the domain-space inner - product; it never materializes a dense matrix. + Require ``A`` to be square in the SpaceCore sense + (``A.domain == A.codomain``), Hermitian, and positive-definite with respect + to ``A.domain.inner``. The implementation uses only :meth:`LinOp.apply` and + the domain-space inner product; it never materializes a dense matrix. Parameters ---------- A : LinOp - Hermitian positive-definite linear operator. + Linear operator that must be Hermitian positive-definite with respect + to ``A.domain.inner``. ``A.domain`` must equal ``A.codomain``, + including the underlying space type and inner-product geometry. + Hermiticity and positive-definiteness are not validated by ``cg``; + indefinite or non-Hermitian operators can diverge or produce NaN + outputs without an explicit error. b : array-like Right-hand side in ``A.codomain``. x0 : array-like or None, optional Initial guess in ``A.domain``. Default is the zero vector. tol : float, optional - Relative residual tolerance. Default is 1e-6. + Relative tolerance on the linear-system residual. ``result.converged`` + is ``True`` when the residual norm is below + ``atol + tol * norm(b)``. Default is 1e-6. atol : float, optional Absolute residual tolerance. Default is 0.0. maxiter : int or None, optional @@ -110,8 +118,10 @@ def cg( flow. ``maxiter`` and ``check_every`` should be treated as static JAX arguments. - Works on real and complex operators. For complex operators, the method uses - the domain inner product convention implemented by ``A.domain``. + For complex operators, residual norms and step sizes are computed from the + real part of ``A.domain.inner(x, y)``. SpaceCore's complex inner-product + convention conjugates the first argument; custom :class:`Space` subclasses + must follow that convention for CG to converge correctly. References ---------- diff --git a/spacecore/linalg/_expm.py b/spacecore/linalg/_expm.py index fe60a2a..5b7d87a 100644 --- a/spacecore/linalg/_expm.py +++ b/spacecore/linalg/_expm.py @@ -54,26 +54,37 @@ def expm_multiply( r""" Compute :math:`\exp(t A) v` by Krylov projection. - Require ``A`` to be Hermitian or structurally unknown. The method builds a - Lanczos basis and applies the exponential of the small tridiagonal - projection, avoiding dense materialization of ``A``. + Require ``A`` to be square in the SpaceCore sense + (``A.domain == A.codomain``) and Hermitian with respect to + ``A.domain.inner``. The method builds a Lanczos basis and applies the + exponential of the small tridiagonal projection, avoiding dense + materialization of ``A``. Parameters ---------- A : LinOp - Square Hermitian linear operator. + Linear operator that must be Hermitian/self-adjoint with respect to + ``A.domain.inner``. ``A.domain`` must equal ``A.codomain``, including + the underlying space type and inner-product geometry. Operators with + structurally unknown Hermiticity (``A.is_hermitian()`` returns + ``None``) are accepted on trust; the caller is responsible for ensuring + Hermiticity. Non-Hermitian inputs produce undefined results. v : array-like Initial vector in ``A.domain``. t : float or complex, optional - Scalar time/scale multiplying ``A``. Complex values are supported for - complex-valued contexts, for example Schrodinger evolution. Default is - 1.0. + Scalar multiplier on ``A``. Complex values require a complex-valued + ``ctx.dtype`` such as ``complex64`` or ``complex128``. Using a complex + ``t`` with a real-valued context produces backend-dependent results. + Default is 1.0. max_iter : int, optional Maximum Krylov dimension. Values around 20-50 are usually sufficient - when :math:`|t|\|A\|` is moderate. Default is 30. + when :math:`|t|\|A\|` is moderate. Must be a Python ``int`` rather + than a traced JAX scalar; under ``jax.jit`` it is treated as a static + argument and changing it triggers retracing. Default is 30. tol : float, optional - Breakdown tolerance for Lanczos and threshold for the projected - exponential residual estimate. Default is 1e-10. + Tolerance used both for Lanczos breakdown and for the convergence flag: + ``result.converged`` is ``True`` when the projected exponential + residual estimate is below ``tol``. Default is 1e-10. Returns ------- @@ -101,6 +112,11 @@ def expm_multiply( symmetric tridiagonal matrix ``T``. This is JIT-compatible on the JAX backend when ``max_iter`` is static. + Hermiticity is enforced only when it can be structurally verified: known + non-Hermitian operators raise ``ValueError``. Operators with unknown + structure, such as many matrix-free operators and operators on custom + spaces, are trusted. + The returned residual estimate is :math:`|\beta_m \phi_{m-1}|`, where ``phi`` is the projected exponential vector. Callers that need the true residual can perform one additional @@ -116,8 +132,8 @@ def expm_multiply( >>> X = sc.VectorSpace((2,), ctx) >>> A = sc.DiagonalLinOp(ctx.asarray([0.0, 1.0]), X, ctx) >>> v = ctx.asarray([2.0, 3.0]) - >>> result = sc.expm_multiply(A, v, t=0.5, max_iter=2) - >>> np.allclose(result.result, [2.0, 3.0 * np.exp(0.5)]) + >>> result = sc.expm_multiply(A, v, t=0.5, max_iter=5) + >>> np.allclose(result.result, [2.0, 3.0 * np.exp(0.5)], atol=1e-10) True """ A = require_linop(A) diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index 60ebc17..3595fcf 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -243,11 +243,12 @@ def lanczos_smallest( r""" Approximate the smallest eigenpair of a Hermitian operator. - The operator is supplied as a square ``LinOp`` and ``initial_vector`` is an - element of ``A.domain``. The implementation keeps fixed-size coordinate - arrays for JAX compatibility, safely handles zero initial vectors, and - refines the returned eigenvalue with the Rayleigh quotient of the - reconstructed Ritz vector in the original space. + The operator is supplied as a square ``LinOp`` in the SpaceCore sense + (``A.domain == A.codomain``), and ``initial_vector`` is an element of + ``A.domain``. The implementation keeps fixed-size coordinate arrays for JAX + compatibility, safely handles zero initial vectors, and refines the + returned eigenvalue with the Rayleigh quotient of the reconstructed Ritz + vector in the original space. Mathematically, Lanczos builds an orthonormal Krylov basis ``V`` for ``span{v, A v, A^2 v, ...}`` and a tridiagonal projection @@ -258,15 +259,24 @@ def lanczos_smallest( Parameters ---------- A : LinOp - Square Hermitian linear operator. + Linear operator that must be Hermitian/self-adjoint with respect to + ``A.domain.inner``. ``A.domain`` must equal ``A.codomain``, including + the underlying space type and inner-product geometry. Operators with + structurally unknown Hermiticity (``A.is_hermitian()`` returns + ``None``) are accepted on trust; the caller is responsible for ensuring + Hermiticity. Non-Hermitian inputs produce undefined results. initial_vector : array-like Starting vector in ``A.domain``. If it is numerically zero, the algorithm falls back to a deterministic coordinate vector. max_iter : int, optional - Maximum Krylov dimension. Default is 100. + Maximum Krylov dimension. Must be a Python ``int`` rather than a + traced JAX scalar; under ``jax.jit`` it is treated as a static argument + and changing it triggers retracing. Default is 100. tol : float, optional - Breakdown tolerance for the off-diagonal Lanczos coefficient. Default - is 1e-6. + Tolerance used for two purposes. Iteration stops at a check point when + the off-diagonal Lanczos coefficient falls below ``tol``; the returned + ``converged`` flag is ``True`` when the Ritz residual estimate is below + ``tol``. Default is 1e-6. check_every : int, optional Refresh the breakdown-based stopping decision every this many iterations and always on the final iteration. Default is @@ -288,7 +298,8 @@ def lanczos_smallest( TypeError If ``A`` is not a :class:`LinOp`. ValueError - If ``A`` is not square or if ``max_iter`` is invalid. + If ``A`` is not square, is known to be non-Hermitian, or if + ``max_iter`` is invalid. See Also -------- @@ -302,6 +313,18 @@ def lanczos_smallest( ``A.apply(eigenvector) - eigenvalue * eigenvector`` once more in the original space. + The "smallest Ritz value" is the smallest eigenvalue of the projected + tridiagonal matrix, not necessarily a good approximation of the smallest + eigenvalue of ``A``. Convergence to the actual smallest eigenvalue requires + the bottom of the spectrum to be separated and the initial vector to have + nonzero projection onto the corresponding eigenspace. For clustered low + eigenvalues, increase ``max_iter`` or use multiple initial vectors. + + Hermiticity is enforced only when it can be structurally verified: known + non-Hermitian operators raise ``ValueError``. Operators with unknown + structure, such as many matrix-free operators and operators on custom + spaces, are trusted. + This function is JIT-compatible on the JAX backend when ``max_iter`` and ``check_every`` are static arguments. For plain :class:`VectorSpace` domains, Euclidean reorthogonalization is vectorized; custom spaces use @@ -328,6 +351,8 @@ def lanczos_smallest( """ A = require_linop(A) require_square(A, "lanczos_smallest") + if A.is_hermitian() is False: + raise ValueError("lanczos_smallest requires A to be Hermitian/self-adjoint.") max_iter = _check_lanczos_max_iter(max_iter) check_every = check_interval(check_every) A.domain.check_member(initial_vector) diff --git a/spacecore/linalg/_lsqr.py b/spacecore/linalg/_lsqr.py index d32ac39..7e975dc 100644 --- a/spacecore/linalg/_lsqr.py +++ b/spacecore/linalg/_lsqr.py @@ -59,21 +59,25 @@ def lsqr( r""" Solve :math:`\min_x \|A x - b\|` by LSQR. - Allow ``A`` to be rectangular or square. The method uses - :meth:`LinOp.apply` for forward products and ``A.H.apply`` for adjoint - products, so the normal equations are represented implicitly and no dense - matrix is formed. + Allow ``A`` to map between distinct ``domain`` and ``codomain`` spaces. + The method uses :meth:`LinOp.apply` for forward products and ``A.H.apply`` + for adjoint products, so the normal equations are represented implicitly + and no dense matrix is formed. Parameters ---------- A : LinOp - Linear operator defining the least-squares problem. + Linear operator with possibly distinct ``domain`` and ``codomain``. + For square ``A`` (``A.domain == A.codomain``), :func:`cg` is usually + preferred when ``A`` is also Hermitian positive-definite. b : array-like Right-hand side in ``A.codomain``. x0 : array-like or None, optional Initial guess in ``A.domain``. Default is the zero vector. tol : float, optional - Relative tolerance for the normal-equation residual. Default is 1e-6. + Relative tolerance for the normal-equation residual + ``norm(A.H @ (A @ x - b))``. ``result.converged`` is ``True`` when that + residual is below ``atol + tol * norm(b)``. Default is 1e-6. atol : float, optional Absolute tolerance for the normal-equation residual. Default is 0.0. maxiter : int or None, optional @@ -115,6 +119,10 @@ def lsqr( JIT-compatible on the JAX backend when ``maxiter`` and ``check_every`` are static arguments. + The normal-equation residual can be much smaller than the solution error + for ill-conditioned ``A``. For ill-conditioned problems, use a tighter + ``tol`` or check the residual and solution quality directly. + Works on real and complex operators. For complex operators, ``A.H`` uses the conjugate adjoint. diff --git a/spacecore/linalg/_power.py b/spacecore/linalg/_power.py index 2c50588..1bdd0ce 100644 --- a/spacecore/linalg/_power.py +++ b/spacecore/linalg/_power.py @@ -93,19 +93,25 @@ def power_iteration( Accept a square :class:`LinOp` or a :class:`QuadraticForm` exposing ``hess_apply``. Public dispatch converts either input into a fixed - self-adjoint action before entering the numerical loop. + self-adjoint action before entering the numerical loop. "Dominant" means + largest eigenvalue in absolute value, not necessarily the largest positive + eigenvalue. Parameters ---------- A : LinOp or QuadraticForm - Square operator or quadratic form whose dominant eigenpair is sought. + Square operator or quadratic form whose dominant eigenpair, largest in + absolute value, is sought. Linear-operator inputs must satisfy + ``A.domain == A.codomain``; this includes the underlying space type and + inner-product geometry. For spectral-norm estimates of a rectangular operator, pass ``A.H @ A``. x0 : array-like or None, optional Initial vector in the action domain. Default is a normalized all-ones vector in the domain geometry. tol : float, optional - Residual-norm tolerance. Default is 1e-6. + Residual-norm tolerance. ``result.converged`` is ``True`` when + ``norm(A @ x - lambda * x) < tol``. Default is 1e-6. maxiter : int or None, optional Maximum number of iterations. Default is ``prod(A.domain.shape)``. check_every : int, optional @@ -146,6 +152,13 @@ def power_iteration( JIT-compatible on the JAX backend when ``maxiter`` and ``check_every`` are static arguments. + For operators with eigenvalues of mixed sign, the dominant eigenvalue is + the one with largest absolute value, which may be negative. Convergence + requires that this eigenvalue be separated from the rest in absolute value. + If the dominant modulus is degenerate, for example both ``lambda`` and + ``-lambda`` have maximum modulus, the iteration may oscillate between + subspaces. + Examples -------- Estimate the largest eigenvalue of a diagonal operator. diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index 4d0dbb4..95040ab 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -235,12 +235,16 @@ def is_hermitian(self) -> bool | None: Returns ------- - bool - ``True`` when the operator is square and its flattened matrix - equals its conjugate transpose. + bool or None + ``True`` or ``False`` for plain :class:`VectorSpace` domains, where + Hermiticity is checked against the Euclidean flattened matrix. + ``None`` for custom geometries whose inner product may differ from + the Euclidean coordinate product. """ if self.dom != self.cod: return False + if type(self.dom) is not VectorSpace: + return None try: return bool(self.ops.allclose(self._A2, self._A2H)) except Exception: diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index 5e9cd8e..dfffe44 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -233,12 +233,16 @@ def is_hermitian(self) -> bool | None: Returns ------- - bool - ``True`` when the operator is square and its sparse matrix equals - its conjugate transpose within backend tolerances. + bool or None + ``True`` or ``False`` for plain :class:`VectorSpace` domains, where + Hermiticity is checked against the Euclidean sparse matrix. + ``None`` for custom geometries whose inner product may differ from + the Euclidean coordinate product. """ if self.dom != self.cod: return False + if type(self.dom) is not VectorSpace: + return None try: return bool(self.ops.allclose_sparse(self.A, self._AH)) except Exception: diff --git a/tests/linalg/test_krylov.py b/tests/linalg/test_krylov.py index 9b663ba..c4dc2f2 100644 --- a/tests/linalg/test_krylov.py +++ b/tests/linalg/test_krylov.py @@ -303,6 +303,16 @@ def test_lanczos_smallest_rejects_invalid_max_iter(): sc.lanczos_smallest(op, ctx.asarray([1.0]), max_iter=0) +def test_lanczos_smallest_rejects_structurally_non_hermitian_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [0.0, 3.0]]), space, space, ctx) + + with pytest.raises(ValueError, match="Hermitian"): + sc.lanczos_smallest(op, ctx.asarray([1.0, 1.0]), max_iter=2) + + def test_lanczos_smallest_handles_eigenvalues_larger_than_1e10(): sc = importlib.import_module("spacecore") ctx = _ctx() From 93cb7f4b82996de11fdb17cf2ad3d475d5d17b35 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 22:55:28 -0300 Subject: [PATCH 38/44] Update README and Release Notes --- README.md | 331 ++++++++++++++-------------------- audit_caching.md | 36 ---- audit_jit.md | 108 ----------- docs/source/release_notes.rst | 193 ++++++++++++++++++++ 4 files changed, 324 insertions(+), 344 deletions(-) delete mode 100644 audit_caching.md delete mode 100644 audit_jit.md diff --git a/README.md b/README.md index b445a65..3d30ab1 100644 --- a/README.md +++ b/README.md @@ -1,275 +1,206 @@ # SpaceCore -SpaceCore exists for writing numerical algorithms once, independently of the -array backend. +[![CI](https://github.com/Pavlo3P/SpaceCore/actions/workflows/ci.yml/badge.svg)](https://github.com/Pavlo3P/SpaceCore/actions/workflows/ci.yml) +[![PyPI](https://img.shields.io/pypi/v/spacecore.svg)](https://pypi.org/project/spacecore/) +[![Python](https://img.shields.io/pypi/pyversions/spacecore.svg)](https://pypi.org/project/spacecore/) +[![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](LICENSE) -For example, the same algorithm can run with NumPy for debugging, JAX for -JIT/autodiff, and Torch for tensor workflows, while preserving the same -mathematical spaces and linear operators. +**Backend-agnostic vector spaces, linear operators, and iterative solvers for scientific computing.** -## What problem does SpaceCore solve? +Write your algorithm once. Run it on NumPy for development, JAX for GPU acceleration and autodiff, or PyTorch for ML pipelines — without changing a line. -Numerical algorithms often start as clear NumPy code and later need to move to -JAX, Torch, or another array system. Without a backend boundary, that migration -usually leaks through the whole implementation: array constructors, dtype -handling, inner products, sparse support, and linear-operator conventions all -become backend-specific. - -SpaceCore keeps those choices in a `Context`, while algorithms work with -mathematical objects: - -- a `Space` knows the structure and geometry of its elements; -- a `LinOp` maps one space to another; -- a `Functional` maps a space element to a scalar; -- backend-specific array creation and operations live behind `BackendOps`. - -The result is ordinary Python code whose core numerical logic is not tied to -one array library. +```python +import spacecore as sc +import numpy as np -Mental model: +# Define a space, a linear operator, and solve Ax = b +ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) +X = sc.VectorSpace((100,), ctx) +A = sc.DenseLinOp(np.random.randn(100, 100) @ np.random.randn(100, 100).T + np.eye(100), X, X, ctx) +b = ctx.asarray(np.random.randn(100)) -```text -BackendOps -> Context -> Space/LinOp/Functional -> Algorithm +result = sc.cg(A, b, tol=1e-8) +print(f"Converged in {result.num_iters} iterations.") ``` -## Write once, run twice - -This gradient descent loop uses only the `Space` and `LinOp` APIs. It does not -know whether the arrays are NumPy arrays, JAX arrays, or Torch tensors. +Same code on JAX with GPU? ```python -import numpy as np -import spacecore as sc - +ctx = sc.Context(sc.JaxOps(), dtype=jnp.float64) +# ... build A and b the same way using jax arrays ... +result = sc.cg(A, b, tol=1e-8) # runs on GPU, JIT-compiled +``` -def as_numpy(x): - if hasattr(x, "detach"): - return x.detach().cpu().numpy() - return np.asarray(x) +## Install +```bash +pip install spacecore # core (numpy only) +pip install "spacecore[jax]" # add JAX backend +pip install "spacecore[torch]" # add PyTorch backend +pip install "spacecore[jax,torch]" # both +``` -def make_problem(ctx): - X = sc.VectorSpace((3,), ctx) - Y = sc.VectorSpace((2,), ctx) +Python 3.11+. Built on the [Python Array API](https://data-apis.org/array-api/) standard. - A = sc.DenseLinOp( - ctx.asarray([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]), - dom=X, - cod=Y, - ctx=ctx, - ) - x = ctx.asarray([1.0, 0.0, -1.0]) - b = ctx.asarray([0.5, 0.25]) - return X, Y, A, x, b +## What is SpaceCore for? +SpaceCore is for people writing numerical algorithms — optimization, inverse problems, eigensolvers, quantum simulation, computational geometry — who don't want to choose between NumPy, JAX, and PyTorch. -def gradient_step(X, A, x, b, eta): - r = A.apply(x) - b - grad = A.rapply(r) - return X.axpy(-eta, grad, x) +### Three things SpaceCore does well +**1. Matrix-free linear operators with algebra.** Write your operator once as `apply` and `adjoint` callables, then compose them: -def run_gradient_descent(X, A, x, b, eta, steps): - for _ in range(steps): - x = gradient_step(X, A, x, b, eta) - return x -``` +```python +# An FFT-based convolution operator, never materialized as a matrix +K = sc.MatrixFreeLinOp(apply=fft_convolve, rapply=fft_convolve_adjoint, dom=X, cod=X, ctx=ctx) +grad = sc.MatrixFreeLinOp(apply=finite_diff, rapply=neg_div, dom=X, cod=Y, ctx=ctx) -Run it with NumPy: +# Build the regularized system operator using algebra +lam = 0.01 +system = K.H @ K + lam * grad.H @ grad # SumLinOp of ComposedLinOps +rhs = K.H.apply(b) -```python -np_ctx = sc.Context(sc.NumpyOps(), dtype="float64") -X, Y, A, x, b = make_problem(np_ctx) -x_numpy = run_gradient_descent(X, A, x, b, eta=0.1, steps=5) -print(as_numpy(x_numpy)) +# Solve — no matrices were assembled +solution = sc.cg(system, rhs).x ``` -Later, run the same problem and the same `run_gradient_descent` with JAX: +**2. Cross-backend iterative solvers.** CG, LSQR, Lanczos, power iteration — all work uniformly across NumPy, JAX, and PyTorch. JAX backends JIT-compile: ```python -import jax - -jax.config.update("jax_enable_x64", True) +ctx = sc.Context(sc.JaxOps(), dtype=jnp.complex128) +A = build_hermitian_operator(ctx) -jax_ctx = sc.Context(sc.JaxOps(), dtype="float64") -X, Y, A, x, b = make_problem(jax_ctx) -x_jax = run_gradient_descent(X, A, x, b, eta=0.1, steps=5) -print(as_numpy(x_jax)) - -print(np.allclose(as_numpy(x_numpy), as_numpy(x_jax))) +# Find the smallest eigenpair via Lanczos +result = sc.lanczos_smallest(A, initial_vector, max_iter=50) +print(f"E_0 = {result.eigenvalue}, converged={result.converged}") ``` -Run it the same way with Torch: +**3. Custom Hilbert spaces with non-Euclidean geometry.** Subclass `VectorSpace`, override `inner`, and every solver respects your geometry: ```python -torch_ctx = sc.Context(sc.TorchOps(), dtype="float64") -X, Y, A, x, b = make_problem(torch_ctx) -x_torch = run_gradient_descent(X, A, x, b, eta=0.1, steps=5) -print(as_numpy(x_torch)) - -print(np.allclose(as_numpy(x_numpy), as_numpy(x_torch))) -``` +class WeightedL2(sc.VectorSpace): + def __init__(self, shape, weights, ctx=None): + super().__init__(shape, ctx) + self.weights = self.ctx.asarray(weights) -All three backends produce the same result: + def inner(self, x, y): + return self.ops.vdot(x, self.weights * y) -```text -[ 1.184125 0.3411875 -0.447625 ] -[ 1.184125 0.3411875 -0.447625 ] -True -[ 1.184125 0.3411875 -0.447625 ] -True +# CG, LSQR, Lanczos all use this inner product automatically ``` -If you do not want to enable JAX 64-bit mode, use a supported dtype such as -`"float32"`. - -## What SpaceCore is not - -SpaceCore is not an optimizer and not a NumPy/JAX/Torch replacement. It provides -backend-aware spaces, operators, and context handling so you can write your own -algorithms without wiring them to one array library. - -## Core concepts - -### `Context` - -A `Context` specifies how objects are represented: - -- backend operations (`NumpyOps`, `JaxOps`, `TorchOps`, etc.); -- default dtype; -- runtime validation behavior. +This is the basis for RKHS spaces, truncated Fock spaces (quantum many-body), function spaces with quadrature, and anything else where the geometry isn't `sum(x * y)`. -Constructors resolve contexts in priority order: explicit `ctx=...`, then -contexts inferred from inputs, then the global default context. Advanced code -that needs this resolution step directly can call -`spacecore.resolve_context_priority(...)`. +## Quick examples -### `Space` +### Conjugate gradient on a symmetric positive-definite system -A `Space` describes the structure and geometry of values: - -- `VectorSpace` for Euclidean vectors and tensors; -- `HermitianSpace` for Hermitian or symmetric matrices; -- `ProductSpace` for Cartesian products of spaces. -- `BatchSpace` for batched elements such as `X.batch((B,), (0,))`, - representing `B` independent copies of `X`. - -Algorithms should use space methods such as `zeros`, `add`, `scale`, `axpy`, -`inner`, `norm`, `flatten`, and `unflatten` instead of hard-coding backend array -operations. +```python +import spacecore as sc -### `LinOp` +ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) +X = sc.VectorSpace((1000,), ctx) +A = sc.DenseLinOp(make_spd_matrix(), X, X, ctx) +b = ctx.asarray(rhs) -A `LinOp` represents a linear operator between spaces: +result = sc.cg(A, b, tol=1e-10, maxiter=500) +print(f"x = {result.x}, residual = {result.residual_norm}") +``` -- `DenseLinOp` for dense matrix or tensor operators; -- `SparseLinOp` for sparse operators; -- `BlockDiagonalLinOp` for block-diagonal product-space operators; -- `StackedLinOp` for operators from one space into a product space; -- `SumToSingleLinOp` for operators from a product space into one space. +### Least-squares with regularization -Operators expose `apply` and `rapply`, so algorithms can use a linear map and -its adjoint without depending on the storage format. +```python +# min ||Ax - b||^2 + λ||x||^2 via normal equations +I = sc.IdentityLinOp(X) +system = A.H @ A + lam * I +rhs = A.H.apply(b) +x_hat = sc.cg(system, rhs).x +``` -For batched inputs, `vapply(xs)` and `rvapply(ys)` lift the operator over the -leading batch axis: +### Smallest eigenpair of a Hermitian operator ```python -XB = X.batch(batch_shape=(B,), batch_axes=(0,)) -YB = Y.batch(batch_shape=(B,), batch_axes=(0,)) - -ys = A.vapply(xs, batch_space=XB) # xs in XB, ys in YB -xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB +result = sc.lanczos_smallest(A, initial_vector, max_iter=100) +print(f"E_0 ≈ {result.eigenvalue}") +print(f"Krylov dimension used: {result.krylov_dim}") +print(f"Converged: {result.converged}") ``` -The fallback uses backend `vmap`; dense, sparse, diagonal, identity, zero, -algebraic, and product-structured operators provide specialized batched paths. +### Building a custom operator -### `Functional` +```python +class Convolution(sc.LinOp): + def __init__(self, kernel, space, ctx): + super().__init__(space, space, ctx) + self.kernel = kernel -A `Functional` represents a scalar-valued map on a space. `LinearFunctional` -covers maps such as ``, `MatrixFreeLinearFunctional` wraps a callable -without storing a representer, and `LinOpQuadraticForm` represents objectives -such as `0.5 * + ell(x) + a`. + def apply(self, x): + return self.ops.real(self.ops.fft.ifft(self.ops.fft.fft(x) * self.ops.fft.fft(self.kernel))) -For batched inputs, `vvalue(xs)` evaluates independently over leading batch -axes. Quadratic forms that define gradients also expose `grad(x)` and -`vgrad(xs)`. + def rapply(self, y): + return self.ops.real(self.ops.fft.ifft(self.ops.fft.fft(y) * self.ops.conj(self.ops.fft.fft(self.kernel)))) +``` -## Who should use this? +This operator works on NumPy, JAX, and PyTorch backends without modification. -SpaceCore is aimed at people writing optimization, inverse-problem, optimal -transport, semidefinite programming, or scientific ML algorithms that should not -be tied to one backend. +## How is SpaceCore different from...? -It is most useful when you want the mathematical model to stay stable while the -execution backend changes. +**...`scipy.sparse.linalg`?** SciPy's iterative solvers are great but tied to NumPy/SciPy. SpaceCore gives you the same algorithms across NumPy, JAX, and PyTorch, plus operator algebra (`A @ B + lam * I` actually returns a usable operator), plus first-class custom Hilbert spaces. -## Installation +**...PyLops?** PyLops is excellent for inverse problems but assumes Euclidean vectors and is tied to NumPy/CuPy. SpaceCore handles non-Euclidean geometry (RKHS, weighted spaces, function spaces) and works on JAX/PyTorch for autodiff and ML pipelines. -Base install: +**...QuTiP?** QuTiP is the standard for quantum optics on top of SciPy. SpaceCore lets you build the same quantum operators on JAX or PyTorch for GPU acceleration and gradient-based parameter learning. Less prebuilt, more composable. -```bash -pip install spacecore -``` +**...`array_api_compat`?** That package gives you portable arrays. SpaceCore builds on top of it to give you portable *vector spaces, linear operators, and iterative algorithms* — the abstractions one level up from arrays. -With JAX support: +## Documentation -```bash -pip install "spacecore[jax]" -``` +[//]: # (- **[Quick Start](https://pavlo3p.github.io/SpaceCore/quickstart.html)** — 20-line introduction) +[//]: # (- **[Concepts](https://pavlo3p.github.io/SpaceCore/concepts.html)** — Spaces, operators, contexts) +[//]: # (- **[Tutorials](https://pavlo3p.github.io/SpaceCore/tutorials/index.html)** — Image deblurring, Jaynes-Cummings model, kernel ridge regression) +- **[API Reference](https://pavlo3p.github.io/SpaceCore/api/index.html)** — Full documentation -With PyTorch support: +## Features at a glance -```bash -pip install "spacecore[torch]" -``` +**Spaces.** `VectorSpace`, `HermitianSpace`, `ProductSpace`, `BatchSpace`. All easy to subclass for custom geometry. -- `spacecore[jax]` installs optional JAX support. -- GPU users should install the appropriate CUDA-enabled JAX build first, - following the official JAX installation guide. -- `spacecore[torch]` installs optional PyTorch support for `torch.Tensor` - backends. -- GPU users should install the appropriate CUDA-enabled PyTorch build first, - following the official PyTorch installation guide. +**Linear operators.** `DenseLinOp`, `SparseLinOp`, `DiagonalLinOp`, `MatrixFreeLinOp`, plus operator algebra (`A @ B`, `A + B`, `2 * A`, `A.H`, `IdentityLinOp`, `ZeroLinOp`). -For local development: +**Functionals.** `LinearFunctional`, `QuadraticForm`, with `value`, `grad`, `hess_apply`, and `compose(linop)` for pull-back. -```bash -python -m pip install -e ".[dev]" -``` +**Iterative solvers.** `cg`, `lsqr`, `lanczos_smallest`, `power_iteration`. -## Full example +**Backends.** NumPy (always), JAX (`spacecore[jax]`), PyTorch (`spacecore[torch]`), CuPy (`spacecore[cupy]`). Adding a backend is ~100 LOC; the registry is public. -For a complete example of regularized optimal transport problem, [see](https://pavlo3p.github.io/SpaceCore/tutorials/regularized_ot.html) -the model is written once and solved with NumPy/JAX backends and -its [notebook](https://github.com/Pavlo3P/SpaceCore/blob/master/tutorials/6_Regularized_Opt_Transport.ipynb). +## Project status -## Documentation +**v0.2 alpha.** API may still change in minor ways. Core abstractions are stable. Suitable for research code; not yet recommended for production deployment. -The hosted documentation is available [here](https://pavlo3p.github.io/SpaceCore/). -JAX integration notes are available -[here](https://pavlo3p.github.io/SpaceCore/design/jax_integration.html). +The library is being developed in the open and is looking for early users and feedback. If you try it on your problem, please open an issue with what worked and what didn't — that's the single most valuable contribution right now. -The documentation website is built with Sphinx from `docs/source`. +## Contributing -Install the documentation dependencies: +Bug reports, feature requests, and PRs welcome. See [CONTRIBUTING.md](CONTRIBUTING.md). -```bash -python -m pip install -e ".[docs]" -``` +Specific areas where help is wanted: -Build the local HTML documentation: +- **Tutorials.** If SpaceCore solves your problem, a notebook example helps everyone. +- **Backends.** CuPy and Dask integration is partial; adding a new backend is well-scoped (~100 LOC). +- **Performance.** Cross-backend benchmarks on real workloads. +- **Documentation.** Concept pages, FAQ, gotchas. -```bash -sphinx-build -b html docs/source docs/build/html -``` +## License -## Status +Apache 2.0. See [LICENSE](LICENSE). -SpaceCore is currently experimental and under active development. The public API -may still evolve. +## Citation -## License +If SpaceCore is useful in your research, a citation is appreciated: -Apache License 2.0 +```bibtex +@software{spacecore, + author = {Pavlo, Pelikh}, + title = {SpaceCore: Backend-agnostic vector spaces and linear operators}, + url = {https://github.com/Pavlo3P/SpaceCore}, + year = {2026}, +} diff --git a/audit_caching.md b/audit_caching.md deleted file mode 100644 index 1531165..0000000 --- a/audit_caching.md +++ /dev/null @@ -1,36 +0,0 @@ -# Cacheability Audit - -This audit covers the requested hot paths and separates construction-time state -from per-call work. Evidence comes from direct source inspection with line -numbers and grep checks such as: - -```text -grep -R "tuple(self\\.dom\\.shape) ==\\|tuple(self\\.cod\\.shape) ==\\|getattr(batch_space\\|coeffs_full = ops.zeros\\|fori_loop(0, max_iter + 1" -n spacecore/linop spacecore/linalg -``` - -## Candidates - -| Location | Per-call work | Evidence | Cost of caching | Benefit | Recommendation | -| --- | --- | --- | --- | --- | --- | -| `spacecore/linop/_dense.py:76-81` | Reshapes `x` when the domain is not flat, multiplies by cached `_A2`, and reshapes/unflattens output. | `_A2`, `_A2T`, `_A2H`, `_dom_is_flat`, `_cod_is_flat`, sizes are already constructed at lines 46-57. | Additional caching would duplicate shape tuples only; < 100 bytes. | Negligible. The expensive matrix multiply dominates. | Don't cache more. Existing construction-time cache is appropriate. | -| `spacecore/linop/_dense.py:92-97` | Reshapes `y`, multiplies by cached `_A2H`, reshapes/unflattens output. | `_A2H` is already cached at line 53; flat flags are cached at lines 54-55. | Additional shape tuple cache only; < 100 bytes. | Negligible relative to matvec. | Don't cache more. | -| `spacecore/linop/_dense.py:116-142` | Batched reshape and batched dense matmul; for non-plain `VectorSpace`, constructs a `BatchSpace` for unflattening. | Calls `self.cod.batch(batch_shape, tuple(range(len(batch_shape))))` at line 128 and domain equivalent at line 142. | A general cache keyed by `batch_shape` would be an unbounded dict and would not be a pytree leaf. | Low unless repeatedly using custom non-vector spaces with the same batch shape. For plain `VectorSpace`, no `BatchSpace` is created. | Don't cache. Avoid unbounded mutable instance state for a narrow path. | -| `spacecore/linop/_dense.py:164-186` | Reflects `in_space.batch_axes` and builds `tuple(range(...))` to identify leading batches. | Lines 166 and 178. | Could cache a tuple per batch rank, but batch rank is input-dependent. | Very low; one tuple/reflection check per batched call. | Don't cache. | -| `spacecore/linop/_sparse.py:74-96` | Reshapes vectors, sparse matvec with cached `_AH`, output reshape/unflatten. | `_AH` and flat flags are cached at lines 47-52. | Additional cache would be shape tuples only. | Negligible relative to sparse matvec. | Don't cache more. | -| `spacecore/linop/_sparse.py:115-141` | Batched sparse matmul; may construct `BatchSpace` for non-vector spaces. | Same pattern as dense: lines 127 and 141. | Unbounded dict if keyed by `batch_shape`; not JAX-friendly. | Low for the common vector-space fast path. | Don't cache. | -| `spacecore/linop/_diagonal.py:51-59` | Builds `batch_shape`, `batch_axes`, `base_axes`, and reshape tuple on every `vapply`/`rvapply`. | Lines 52-59; called by `vapply` at line 65 and `rvapply` at line 75. | A cache would store small tuples keyed by `(batch_shape, batch_axes)`; memory tiny per key but unbounded and mutable. | Low: the elementwise multiply dominates for large arrays; for small arrays, Python overhead exists but caching adds mutable state and pytree concerns. | Don't cache. Keep stateless and JAX-safe. | -| `spacecore/linop/_algebra.py:283-297` | `SumLinOp.apply/rapply` loops over operands and accumulates results. | Lines 286-288 and 294-296. | No reusable derived object; caching partial sums would depend on input. | None. Work is mathematical operator application. | Don't cache. | -| `spacecore/linop/_algebra.py:363-383` | `ComposedLinOp.apply/rapply/vapply/rvapply` delegates through left/right operators; batched paths create middle `BatchSpace`. | Lines 366, 371, 376, 382. | Could cache middle batch spaces by input batch signature; unbounded mutable dict. | Low and only for repeated batched composition with the same batch axes. | Don't cache. | -| `spacecore/linop/_algebra.py:878-894` | `_AdjointViewLinOp` delegates to the wrapped operator. | Lines 881, 886, 890, 894. | Nothing useful to cache. | None. Delegation is already minimal. | Don't cache. | -| `spacecore/linalg/_lanczos.py:108-112` | Builds `e0` and normalized `e0_unit` once per `lanczos_smallest` call. | Lines 108-112. | Caching on `A.domain` would add mutable state to spaces and complicate JAX pytree semantics. Memory is `O(n)`. | Low: once per solver call, not per iteration. | Don't cache. | -| `spacecore/linalg/_lanczos.py:163-170` | Allocates `coeffs_full = zeros(max_iter + 1)` inside every Lanczos iteration, then fills it with a `fori_loop`. | Grep shows `coeffs_full = ops.zeros((max_iter + 1,), dtype=ctx.dtype)` at line 163 inside `body_fun`, followed by `ops.fori_loop(0, max_iter + 1, ...)` at line 170. | One extra closure-captured zero vector of length `max_iter + 1`; bytes are about `(max_iter + 1) * itemsize`, usually a few hundred bytes. | Moderate: avoids allocating the same small vector `m` times per Krylov call. This is also the path Task 3 will revisit for trace cleanliness. | Cache/hoist a zero template outside `body_fun` and reuse it as the initial coefficient vector. | - -## Recommended - -- Hoist Lanczos `coeffs_full` zero-vector allocation out of `body_fun`. - This removes one small allocation per Krylov iteration with negligible memory - cost and no API change. - -Everything else is either already cached (`_A2`, `_A2H`, `_AH`, flat flags) or -would require mutable shape caches for small tuple/reflection work. Those are -not worth the added state, especially for JAX pytree compatibility. diff --git a/audit_jit.md b/audit_jit.md deleted file mode 100644 index ced8e85..0000000 --- a/audit_jit.md +++ /dev/null @@ -1,108 +0,0 @@ -# JIT Traceability Audit - -This audit was generated with `scripts/jit_audit.py`. The script: - -- wraps each solver with `jax.jit`; -- calls each jitted wrapper twice with shape/dtype-stable values; -- calls it again with a changed static iteration argument; -- calls it again with a changed operator/domain shape; -- writes a `jax.make_jaxpr` fixture for `lanczos_smallest` to - `tests/fixtures/jaxpr_lanczos_smallest.txt`. - -The script also enables `jax_log_compiles`. In addition to JAX's compile logs, -it uses a trace-time counter, which increments only when JAX retraces the Python -wrapper. - -## Summary - -| Solver | Traces cleanly | Recompiles/retraces on same shape/dtype values | Retraces when expected | Evidence | -| --- | --- | --- | --- | --- | -| `cg` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `maxiter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output: `{'solver': 'cg', 'traces_after_two_same_shape_calls': 1, 'traces_after_static_change': 2, 'traces_after_shape_change': 3, ...}` | -| `lsqr` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `maxiter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output for `lsqr`. | -| `lanczos_smallest` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `max_iter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output for `lanczos_smallest`. | -| `power_iteration` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `maxiter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output for `power_iteration`. | -| `expm_multiply` | Yes | No. Trace count stayed at 1 after two calls. | Yes. Static `max_iter` change raised trace count to 2; shape change raised it to 3. | `scripts/jit_audit.py` output: `{'solver': 'expm_multiply', 'traces_after_two_same_shape_calls': 1, 'traces_after_static_change': 2, 'traces_after_shape_change': 3, ...}` | - -## Findings - -### 1. Lanczos full reorthogonalization now uses a vectorized exact-VectorSpace path - -`tests/fixtures/jaxpr_lanczos_smallest.txt` captures the current JAXPR for -`lanczos_smallest(max_iter=3)`. Grep evidence: - -```text -grep -n "scan\\|while\\|scatter\\|dot_general" tests/fixtures/jaxpr_lanczos_smallest.txt -``` - -The initial audit fixture showed: - -- a top-level `while` at line 61 for the Krylov iteration; -- a nested `scan` at line 143 corresponding to - `ops.fori_loop(0, max_iter + 1, fill_coeff, coeffs_full)`; -- repeated `scatter` operations inside that scan. - -This is correct, but for the common exact `VectorSpace` case it is more IR than -needed. Mathematically, Euclidean reorthogonalization coefficients are -`conj(V) @ w`, which can be one `einsum`/matmul node. - -Important constraint: this replacement is **not valid for arbitrary -`Space.inner`**. A weighted space, RKHS, or any custom geometry must keep using -`Space.inner(v_j, w)`. Therefore the optimization should be guarded to the exact -`VectorSpace` type only, not subclasses. - -The implemented exact-`VectorSpace` path now lowers the coefficient computation -to a single `dot_general` at fixture line 143. Remaining `scan` operations in -the fixture are the fixed-size tridiagonal construction loops, not the -reorthogonalization coefficient loop. - -### 2. `cg` and `lsqr` use `ops.cond` correctly for periodic diagnostics - -Both solvers trace without errors and do not retrace on value-only changes. The -`ops.cond(..., lambda _: ..., lambda _: ..., operand)` pattern is valid under -the installed JAX version. Both branches return matching shapes and dtypes. - -### 3. `power_iteration` dispatch is trace-time, not data-dependent - -`power_iteration` branches on whether the first argument is a `LinOp` or -`QuadraticForm`. That branch is Python-level and happens at trace time. This is -acceptable because the object type is static in the pytree structure; changing -from a `LinOp` to a `QuadraticForm` should retrace. - -### 4. Algebra construction inside JIT is possible but theoretical - -The algebra factories use Python `isinstance` checks for symbolic simplification. -Those checks are not data-dependent on traced arrays. If users construct new -operator expressions inside a jitted function, that Python algebra executes at -trace time and contributes to trace cost. This is a usage-pattern concern, not a -solver bug: normal use passes already-built operators into jitted numerical -kernels. - -### 5. Constants are mostly static by design - -Iteration counts (`maxiter`, `max_iter`) are static in the audit wrappers. -Changing them retraces, which is expected because loop bounds and fixed-size -work arrays change. Scalars such as tolerances are currently Python arguments -converted through `ops.asarray`; changing them may retrace unless callers pass -them through a wrapper as array values. This is acceptable for the current API. - -## Implemented Change - -- Added a Euclidean fast path for Lanczos reorthogonalization when - `type(A.domain) is VectorSpace`: compute all coefficients with - `ops.einsum("jn,n->j", ops.conj(V_), w)`. -- Kept the existing `Space.inner` loop for all non-exact `VectorSpace` domains - to preserve Space geometry. - -This is the only validated change from this audit. The broader replacement -`V @ w` is rejected for non-Euclidean spaces because it would regress the -geometry-correct Lanczos recurrence. - -## Follow-Up TODO - -1. Consider a benchmark for exact `VectorSpace` Lanczos before/after the - reorthogonalization fast path. -2. Consider vectorizing fixed-size tridiagonal construction if the remaining - construction `scan` nodes show up in JAX profiling. -3. If users report trace-time issues from constructing algebra expressions - inside `jax.jit`, document that operators should be built outside the jitted - function. diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index 35613fe..3fd7d4e 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -1,6 +1,199 @@ Release notes ============= +Version 0.2.0 +------------- + +SpaceCore 0.2.0 is a major API expansion after +``3a1e382f13ef0496f8f54dc55db81aea95444775``. It migrates backend operations +toward the Array API, adds lazy linear-operator algebra, introduces functionals +and iterative solvers, broadens batching support, and substantially improves +documentation and validation. + +Changes +~~~~~~~ + +Backend support +^^^^^^^^^^^^^^^ + +* Refactored ``BackendOps`` around ``array-api-compat`` so backend-specific + code shares a common Array API-oriented implementation. +* Added optional CuPy support through ``CuPyOps`` and the ``cupy`` backend + family. +* Broadened backend operation coverage for array creation, dtype conversion, + sparse conversion, indexing, reductions, linear algebra, loop primitives, + tree helpers, and vectorized mapping. +* Added backend loop tests for ``fori_loop``, ``while_loop``, ``cond``, and + tree/stack behavior used by JIT-compatible algorithms. +* Added complex-dtype helpers on backend ops, including centralized complex + dtype detection and real-dtype extraction. +* Added JAX pytree registration support for new operator, space, and + functional objects. + +Context and checking +^^^^^^^^^^^^^^^^^^^^ + +* Fixed contextual import cycles by moving contextual implementation details + behind private modules while preserving public exports. +* Added ``ContextBound`` support for context-aware conversion and object + binding. +* Centralized conversion, context normalization, backend registration, and + context-resolution helpers in the contextual manager. +* Extended ``checked_method`` to support validation against ``self`` and + multiple input argument positions. +* Replaced repeated manual ``if self._enable_checks`` membership checks in + spaces with ``checked_method`` where the decorator fits. +* Added and documented reusable space-validation checks, including backend, + dtype, shape, Hermitian, square-matrix, product-structure, and + product-component checks. +* Improved ``enable_checks`` behavior and test coverage across spaces, + operators, context conversion, and functionals. + +Spaces +^^^^^^ + +* Added ``BatchSpace`` for batched elements with explicit batch shape and batch + axis metadata. +* Improved ``VectorSpace``, ``HermitianSpace``, and ``ProductSpace`` docstrings, + conversion behavior, validation, and batching support. +* Made the default linalg initial vector space-aware by normalizing with + ``A.domain.norm`` instead of assuming a Euclidean inner product. + +Linear operators +^^^^^^^^^^^^^^^^ + +* Added lazy linear-operator algebra: + + * ``A @ B`` and ``make_composed`` for composition. + * ``A + B`` and ``make_sum`` for sums. + * scalar multiplication and ``make_scaled`` for scaled operators. + * ``IdentityLinOp``, ``ZeroLinOp``, ``ScaledLinOp``, ``SumLinOp``, and + ``ComposedLinOp``. + +* Added dense materialization via ``to_dense`` for core linear operators and + operator algebra. +* Added ``DiagonalLinOp`` and improved dense, sparse, and diagonal operator + handling for complex adjoints. +* Added public ``LinOp.is_hermitian()`` structural checks where verification is + cheap and reliable. +* Made ``DenseLinOp.is_hermitian()`` and ``SparseLinOp.is_hermitian()`` return + ``None`` for custom space geometries instead of applying an incorrect + Euclidean matrix-symmetry test. +* Added product-structured operators and batched lifting: + + * ``ProductLinOp`` + * ``BlockDiagonalLinOp`` + * ``StackedLinOp`` + * ``SumToSingleLinOp`` + * ``vapply`` and ``rvapply`` paths for batched operator application. + +* Optimized dense, sparse, block-diagonal, stacked, sum-to-single, and + product-operator batched paths. +* Improved linear-operator equality, representation, conversion, and JAX + pytree behavior. + +Functionals +^^^^^^^^^^^ + +* Added the ``Functional`` abstraction for scalar-valued maps on spaces. +* Added linear functional implementations: + + * ``LinearFunctional`` + * ``InnerProductFunctional`` + * ``MatrixFreeLinearFunctional`` + +* Added quadratic forms: + + * ``QuadraticForm`` + * ``LinOpQuadraticForm`` + +* Added ``Functional.compose`` and ``ComposedFunctional`` for pull-backs along + linear operators. +* Added ``LinOpQuadraticForm`` Hermiticity validation through the public + ``LinOp.is_hermitian()`` contract instead of reaching into private operator + attributes. +* Added value, gradient, batched value, batched gradient, conversion, and + pytree coverage for functionals. + +Linear algebra +^^^^^^^^^^^^^^ + +* Added JIT-compatible iterative solvers: + + * ``cg`` for Hermitian positive-definite systems. + * ``lsqr`` for rectangular least-squares problems. + * ``power_iteration`` for dominant eigenpair estimates. + * ``lanczos_smallest`` for smallest Ritz eigenpair estimates. + +* Renamed ``stochastic_lanczos`` to ``lanczos_smallest`` and kept + ``stochastic_lanczos`` as a deprecated alias. +* Added ``LanczosResult`` with residual estimate, Krylov dimension, and + convergence flag. +* Added ``expm_multiply`` for Krylov matrix-exponential actions + ``exp(t * A) @ v`` on Hermitian operators. +* Added ``ExpmMultiplyResult`` with result vector, Krylov dimension, projected + residual estimate, and convergence flag. +* Factored shared Lanczos basis and tridiagonal construction for reuse by + ``lanczos_smallest`` and ``expm_multiply``. +* Hoisted the Lanczos coefficient zero template out of the loop body. +* Vectorized Lanczos reorthogonalization for plain ``VectorSpace`` domains while + preserving the slower geometry-aware path for custom spaces. +* Added a weighted-inner-product regression test to prevent applying the + Euclidean Lanczos fast path to non-Euclidean spaces. +* Made ``lanczos_smallest`` reject operators that are structurally known to be + non-Hermitian. +* Clarified solver contracts for ``domain == codomain`` square requirements, + Hermiticity enforcement, tolerance semantics, JAX static arguments, complex + scalar behavior, ill-conditioning caveats, and power-iteration convergence + caveats. +* Renamed the linalg helper ``safe_inverse`` to ``safe_inverse_nonneg`` to make + its nonnegative-domain semantics explicit. + +Documentation +^^^^^^^^^^^^^ + +* Added API reference pages for backend ops, spaces, linear operators, + functionals, and linear algebra. +* Added the linear algebra API page covering solvers, eigenvalue algorithms, + matrix-function routines, and result types. +* Added a JAX integration design note documenting trace-time operator algebra + and recommended JIT usage. +* Added and updated tutorials for backend operations, linear operators, and + matrix-free linalg workflows. +* Added cacheability and JIT traceability audit documents. +* Added a committed JAXPR fixture for ``lanczos_smallest`` regression tracking. +* Added docstring migration tooling, a migration baseline, and broad NumPy-style + docstring coverage for public APIs. +* Added ``numpydoc`` validation configuration and doctest integration. +* Reworked public docstrings for solvers, spaces, operators, functionals, + backends, and contextual helpers. + +Testing and CI +^^^^^^^^^^^^^^ + +* Added broad tests for backend ops delegation, backend loop primitives, CuPy + ops, context resolution, ``checked_method``, functionals, linalg solvers, + operator algebra, batched lifting, dense materialization, diagonal operators, + and JAX pytree/JIT behavior. +* Added cross-backend linalg tests for NumPy, JAX, Torch, and optional CuPy. +* Added ``expm_multiply`` tests against dense SciPy ground truth, complex time, + group behavior, linearity, residual estimates, and JAX JIT. +* Added CI execution of the JIT audit script in ``--check`` mode. +* Added nonblocking documentation lint/audit steps for the docstring migration. +* Added development dependencies for testing, coverage, linting, and docstring + validation. + +Packaging +^^^^^^^^^ + +* Bumped the package version to ``0.2.0``. +* Made the top-level ``__version__`` resolve from package metadata instead of a + hand-maintained constant. +* Added optional dependency groups for JAX, Torch, CuPy, examples, docs, and + development tooling. +* Updated top-level exports for new backends, operators, functionals, solvers, + result types, validation checks, and contextual helpers. + Version 0.1.4 ------------- From 7a7a99dc21983701572bd64683e646eaa3bcfacd Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 23:12:56 -0300 Subject: [PATCH 39/44] Update release notes --- docs/source/release_notes.rst | 341 +++++++++++++++++++--------------- 1 file changed, 191 insertions(+), 150 deletions(-) diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index 3fd7d4e..100a9e1 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -4,195 +4,236 @@ Release notes Version 0.2.0 ------------- -SpaceCore 0.2.0 is a major API expansion after -``3a1e382f13ef0496f8f54dc55db81aea95444775``. It migrates backend operations -toward the Array API, adds lazy linear-operator algebra, introduces functionals -and iterative solvers, broadens batching support, and substantially improves -documentation and validation. - -Changes +SpaceCore 0.2.0 is a major API expansion. The backend layer now sits on the +Array API standard. Operators gained a lazy algebra with adjoint views, +composition, sums, and scaling. A new :class:`Functional` hierarchy provides +scalar-valued maps with gradients and pull-backs. A new :mod:`spacecore.linalg` +module ships four JIT-compatible iterative solvers. Spaces, operators, and +functionals share a single validation pattern via ``checked_method``, and the +public API is documented to numpydoc standard with doctest coverage. + +This release introduces breaking renames; see :ref:`migration-0-2`. + +Highlights +~~~~~~~~~~ + +* Array API backend layer with optional CuPy support. +* Lazy operator algebra: ``A @ B``, ``A + B``, ``A.H``, plus + :class:`IdentityLinOp`, :class:`ZeroLinOp`, :class:`MatrixFreeLinOp`, and + ``make_*`` factories with algebraic simplification. +* :class:`Functional` hierarchy with linear and quadratic forms, plus + ``Functional.compose`` for pull-back along linear operators. +* New :mod:`spacecore.linalg` module with iterative solvers: :func:`cg`, + :func:`lsqr`, :func:`power_iteration`, :func:`lanczos_smallest`, + :func:`expm_multiply`. +* Geometry-aware solvers honor the declared ``Space.inner`` instead of assuming + Euclidean. +* Unified ``checked_method`` decorator across :class:`Space`, :class:`LinOp`, + and :class:`Functional`. +* Comprehensive numpydoc-style docstrings, doctests, and a JAX integration + design note. + +Backend ~~~~~~~ -Backend support -^^^^^^^^^^^^^^^ +* Migrated :class:`BackendOps` to the Array API standard via + ``array-api-compat``. +* Added :class:`CuPyOps` and the ``cupy`` backend family as an optional install + (``pip install 'spacecore[cupy]'``). +* Centralized complex-dtype handling on :class:`BackendOps`: -* Refactored ``BackendOps`` around ``array-api-compat`` so backend-specific - code shares a common Array API-oriented implementation. -* Added optional CuPy support through ``CuPyOps`` and the ``cupy`` backend - family. -* Broadened backend operation coverage for array creation, dtype conversion, - sparse conversion, indexing, reductions, linear algebra, loop primitives, - tree helpers, and vectorized mapping. -* Added backend loop tests for ``fori_loop``, ``while_loop``, ``cond``, and - tree/stack behavior used by JIT-compatible algorithms. -* Added complex-dtype helpers on backend ops, including centralized complex - dtype detection and real-dtype extraction. -* Added JAX pytree registration support for new operator, space, and - functional objects. + * :meth:`BackendOps.is_complex_dtype` for backend-aware complex detection. + * :meth:`BackendOps.real_dtype` for extracting the real dtype matching a + complex one. + +* Broadened backend coverage for array creation, dtype conversion, sparse + conversion, indexing, reductions, linear algebra, loop primitives + (``fori_loop``, ``while_loop``, ``cond``), tree helpers, and vectorized + mapping. +* Registered JAX pytrees for operator, space, and functional types so they pass + through ``jax.jit``, ``jax.vmap``, and ``jax.grad`` boundaries. Context and checking -^^^^^^^^^^^^^^^^^^^^ - -* Fixed contextual import cycles by moving contextual implementation details - behind private modules while preserving public exports. -* Added ``ContextBound`` support for context-aware conversion and object - binding. -* Centralized conversion, context normalization, backend registration, and - context-resolution helpers in the contextual manager. -* Extended ``checked_method`` to support validation against ``self`` and - multiple input argument positions. -* Replaced repeated manual ``if self._enable_checks`` membership checks in - spaces with ``checked_method`` where the decorator fits. -* Added and documented reusable space-validation checks, including backend, - dtype, shape, Hermitian, square-matrix, product-structure, and - product-component checks. -* Improved ``enable_checks`` behavior and test coverage across spaces, - operators, context conversion, and functionals. +~~~~~~~~~~~~~~~~~~~~ + +* Restructured ``_contextual`` to hide implementation details while keeping the + public free-function API (:func:`set_context`, :func:`get_context`, + :func:`resolve_context_priority`, :func:`register_ops`, and the + resolution-policy accessors). +* Extended :func:`~spacecore._checks.checked_method` to support validation + against ``self`` and multiple input argument positions. +* Replaced manual ``if self._enable_checks`` guards with ``checked_method`` + across :class:`Space`, :class:`LinOp`, and :class:`Functional`. Inline guards + are now reserved for non-membership checks such as dense-array assertions and + custom output-shape checks. +* Added reusable space-validation checks documented at + ``docs/source/design/checking_policy.rst``: backend, dtype, shape, Hermitian, + square-matrix, product-structure, and product-component checks. Spaces -^^^^^^ +~~~~~~ -* Added ``BatchSpace`` for batched elements with explicit batch shape and batch - axis metadata. -* Improved ``VectorSpace``, ``HermitianSpace``, and ``ProductSpace`` docstrings, - conversion behavior, validation, and batching support. -* Made the default linalg initial vector space-aware by normalizing with - ``A.domain.norm`` instead of assuming a Euclidean inner product. +* Added :class:`BatchSpace` for batched elements with explicit batch shape and + batch-axis metadata. +* Improved :class:`VectorSpace`, :class:`HermitianSpace`, and + :class:`ProductSpace` conversion behavior, validation, batching support, and + docstrings. Linear operators -^^^^^^^^^^^^^^^^ +~~~~~~~~~~~~~~~~ + +* **Lazy operator algebra.** Added composition, addition, scaling, and adjoint + view with algebraic simplification: + + * ``A @ B`` composes operators. + * ``A + B`` sums operators. + * ``alpha * A`` scales an operator. + * ``A.H`` returns a cached adjoint view satisfying ``A.H.H is A``. -* Added lazy linear-operator algebra: - - * ``A @ B`` and ``make_composed`` for composition. - * ``A + B`` and ``make_sum`` for sums. - * scalar multiplication and ``make_scaled`` for scaled operators. - * ``IdentityLinOp``, ``ZeroLinOp``, ``ScaledLinOp``, ``SumLinOp``, and - ``ComposedLinOp``. - -* Added dense materialization via ``to_dense`` for core linear operators and - operator algebra. -* Added ``DiagonalLinOp`` and improved dense, sparse, and diagonal operator - handling for complex adjoints. -* Added public ``LinOp.is_hermitian()`` structural checks where verification is - cheap and reliable. -* Made ``DenseLinOp.is_hermitian()`` and ``SparseLinOp.is_hermitian()`` return - ``None`` for custom space geometries instead of applying an incorrect - Euclidean matrix-symmetry test. + Simplification rules eliminate ``I``, ``Zero``, ``alpha = 0``, ``alpha = 1``, + and flatten nested sums. + +* Added :class:`IdentityLinOp`, :class:`ZeroLinOp`, :class:`MatrixFreeLinOp`, + and :class:`DiagonalLinOp`. +* Added structural :meth:`LinOp.is_hermitian` reporting ``True``, ``False``, + or ``None`` (unknown) without applying incorrect Euclidean assumptions for + custom space geometries. +* Added :meth:`LinOp.to_dense` for materializing operators as backend arrays. * Added product-structured operators and batched lifting: - * ``ProductLinOp`` - * ``BlockDiagonalLinOp`` - * ``StackedLinOp`` - * ``SumToSingleLinOp`` - * ``vapply`` and ``rvapply`` paths for batched operator application. + * :class:`ProductLinOp` + * :class:`BlockDiagonalLinOp` + * :class:`StackedLinOp` + * :class:`SumToSingleLinOp` + * ``vapply`` / ``rvapply`` paths for batched operator application. -* Optimized dense, sparse, block-diagonal, stacked, sum-to-single, and - product-operator batched paths. * Improved linear-operator equality, representation, conversion, and JAX pytree behavior. Functionals -^^^^^^^^^^^ +~~~~~~~~~~~ -* Added the ``Functional`` abstraction for scalar-valued maps on spaces. +* Added :class:`Functional` as an abstract base for scalar-valued maps on + spaces, with :meth:`value`, :meth:`grad`, :meth:`hess_apply`, and batched + counterparts. * Added linear functional implementations: - * ``LinearFunctional`` - * ``InnerProductFunctional`` - * ``MatrixFreeLinearFunctional`` + * :class:`LinearFunctional` + * :class:`InnerProductFunctional` + * :class:`MatrixFreeLinearFunctional` * Added quadratic forms: - * ``QuadraticForm`` - * ``LinOpQuadraticForm`` + * :class:`QuadraticForm` + * :class:`LinOpQuadraticForm` -* Added ``Functional.compose`` and ``ComposedFunctional`` for pull-backs along - linear operators. -* Added ``LinOpQuadraticForm`` Hermiticity validation through the public - ``LinOp.is_hermitian()`` contract instead of reaching into private operator - attributes. -* Added value, gradient, batched value, batched gradient, conversion, and - pytree coverage for functionals. +* Added :meth:`Functional.compose` and :class:`ComposedFunctional` for + pull-backs along linear operators, with specializations that preserve the + concrete functional type when possible. Linear algebra -^^^^^^^^^^^^^^ - -* Added JIT-compatible iterative solvers: - - * ``cg`` for Hermitian positive-definite systems. - * ``lsqr`` for rectangular least-squares problems. - * ``power_iteration`` for dominant eigenpair estimates. - * ``lanczos_smallest`` for smallest Ritz eigenpair estimates. - -* Renamed ``stochastic_lanczos`` to ``lanczos_smallest`` and kept - ``stochastic_lanczos`` as a deprecated alias. -* Added ``LanczosResult`` with residual estimate, Krylov dimension, and - convergence flag. -* Added ``expm_multiply`` for Krylov matrix-exponential actions - ``exp(t * A) @ v`` on Hermitian operators. -* Added ``ExpmMultiplyResult`` with result vector, Krylov dimension, projected - residual estimate, and convergence flag. -* Factored shared Lanczos basis and tridiagonal construction for reuse by - ``lanczos_smallest`` and ``expm_multiply``. -* Hoisted the Lanczos coefficient zero template out of the loop body. -* Vectorized Lanczos reorthogonalization for plain ``VectorSpace`` domains while - preserving the slower geometry-aware path for custom spaces. -* Added a weighted-inner-product regression test to prevent applying the - Euclidean Lanczos fast path to non-Euclidean spaces. -* Made ``lanczos_smallest`` reject operators that are structurally known to be - non-Hermitian. -* Clarified solver contracts for ``domain == codomain`` square requirements, - Hermiticity enforcement, tolerance semantics, JAX static arguments, complex - scalar behavior, ill-conditioning caveats, and power-iteration convergence - caveats. -* Renamed the linalg helper ``safe_inverse`` to ``safe_inverse_nonneg`` to make - its nonnegative-domain semantics explicit. +~~~~~~~~~~~~~~ + +The :mod:`spacecore.linalg` module is new in 0.2.0. It provides +JIT-compatible iterative solvers and structured result types. + +* Added iterative solvers: + + * :func:`cg` for Hermitian positive-definite systems. + * :func:`lsqr` for rectangular least-squares problems. + * :func:`power_iteration` for dominant-eigenpair estimates of a + :class:`LinOp` or :class:`QuadraticForm`. + * :func:`lanczos_smallest` for smallest-Ritz-eigenpair estimates of + Hermitian operators. + * :func:`expm_multiply` for Krylov matrix-exponential actions + ``exp(t A) v`` on Hermitian operators, with complex ``t`` supported for + Schrodinger-type evolution. + +* Added structured result types :class:`CGResult`, :class:`LSQRResult`, + :class:`PowerIterationResult`, :class:`LanczosResult`, and + :class:`ExpmMultiplyResult`, each carrying convergence diagnostics and a + compact ``__repr__``. +* Solvers are geometry-aware: norms, inner products, and the default initial + vector use ``Space.inner`` and ``Space.norm`` rather than assuming Euclidean + geometry. This makes the solvers correct on custom inner products such as + RKHS or weighted spaces. Documentation -^^^^^^^^^^^^^ +~~~~~~~~~~~~~ +* Reworked public docstrings to numpydoc standard with runnable doctests for + solvers, spaces, operators, functionals, backends, and contextual helpers. +* Clarified solver contracts: ``domain == codomain`` square requirements, + Hermiticity enforcement, tolerance semantics, JAX static arguments, complex + scalar behavior, ill-conditioning caveats, and convergence assumptions. * Added API reference pages for backend ops, spaces, linear operators, functionals, and linear algebra. -* Added the linear algebra API page covering solvers, eigenvalue algorithms, - matrix-function routines, and result types. * Added a JAX integration design note documenting trace-time operator algebra - and recommended JIT usage. -* Added and updated tutorials for backend operations, linear operators, and - matrix-free linalg workflows. -* Added cacheability and JIT traceability audit documents. -* Added a committed JAXPR fixture for ``lanczos_smallest`` regression tracking. -* Added docstring migration tooling, a migration baseline, and broad NumPy-style - docstring coverage for public APIs. -* Added ``numpydoc`` validation configuration and doctest integration. -* Reworked public docstrings for solvers, spaces, operators, functionals, - backends, and contextual helpers. + and recommended JIT usage at + ``docs/source/design/jax_integration.rst``. +* Added tutorials for backend operations, linear operators, and matrix-free + linalg workflows. Testing and CI -^^^^^^^^^^^^^^ - -* Added broad tests for backend ops delegation, backend loop primitives, CuPy - ops, context resolution, ``checked_method``, functionals, linalg solvers, - operator algebra, batched lifting, dense materialization, diagonal operators, - and JAX pytree/JIT behavior. -* Added cross-backend linalg tests for NumPy, JAX, Torch, and optional CuPy. -* Added ``expm_multiply`` tests against dense SciPy ground truth, complex time, - group behavior, linearity, residual estimates, and JAX JIT. -* Added CI execution of the JIT audit script in ``--check`` mode. -* Added nonblocking documentation lint/audit steps for the docstring migration. -* Added development dependencies for testing, coverage, linting, and docstring - validation. +~~~~~~~~~~~~~~ + +* Added cross-backend tests covering NumPy, JAX, Torch, and optional CuPy. +* Added tests for backend ops delegation, backend loop primitives, CuPy ops, + context resolution, ``checked_method``, functionals, linalg solvers, + operator algebra, batched lifting, dense materialization, diagonal + operators, and JAX pytree/JIT behavior. +* Added CI execution of a JIT-traceability audit script in ``--check`` mode + and a coverage floor of 70% via ``pytest-cov``. +* Added nonblocking documentation lint and audit steps for the docstring + migration. Packaging -^^^^^^^^^ +~~~~~~~~~ * Bumped the package version to ``0.2.0``. -* Made the top-level ``__version__`` resolve from package metadata instead of a - hand-maintained constant. -* Added optional dependency groups for JAX, Torch, CuPy, examples, docs, and - development tooling. -* Updated top-level exports for new backends, operators, functionals, solvers, - result types, validation checks, and contextual helpers. +* ``spacecore.__version__`` now resolves from package metadata via + ``importlib.metadata`` instead of a hand-maintained constant. +* Added optional dependency groups: ``[jax]``, ``[torch]``, ``[cupy]``, + ``[examples]``, ``[docs]``, ``[dev]``. +* Added an explicit ``__all__`` at the top level covering new backends, + operators, functionals, solvers, result types, validation checks, and + contextual helpers. + +.. _migration-0-2: + +Migration from 0.1.x +~~~~~~~~~~~~~~~~~~~~ + +* ``BackendOps.eps`` is now a method ``eps(dtype)`` rather than a property. + Callers must pass a dtype, typically ``ctx.dtype``. +* The implementation attribute ``DenseLinOp.A`` is now a + :class:`functools.cached_property` backed by ``_A``. The public attribute + access ``op.A`` is unchanged. +* :meth:`LinOp.__eq__` now returns ``NotImplemented`` instead of raising + ``NotImplementedError`` on the base class, so ``op == None`` and + ``op in some_list`` no longer raise. +* Several module-internal helpers in ``spacecore._contextual`` moved to + private modules. Use the public functions re-exported from + :mod:`spacecore._contextual` (``set_context``, ``get_context``, + ``resolve_context_priority``, ``register_ops``, ``set_resolution_policy``, + and the dtype-policy accessors) rather than importing from internal modules. + +Known limitations +~~~~~~~~~~~~~~~~~ + +* :func:`cg`, :func:`lsqr`, and :func:`power_iteration` do not structurally + validate operator properties (positive-definiteness, full Hermiticity) and + may silently produce incorrect results on inputs that violate their + preconditions. See each function's ``Notes`` section for details. +* Operator algebra runs Python-level simplification at construction time. For + maximum JIT efficiency, assemble operator expressions outside the + ``jax.jit`` boundary; see the JAX integration design note. +* :class:`MatrixFreeLinOp` stores its callables in pytree auxiliary data. + Constructing one inside a JIT-traced function with a new lambda each call + triggers retracing. Construct outside the traced region with a stable + callable reference. +* The CuPy backend is provided as a preview. Coverage of non-standard + operations and sparse handling may evolve in a subsequent release. Version 0.1.4 ------------- From 4c26c91479213643567fa6704eebf74756d68a79 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 23:13:38 -0300 Subject: [PATCH 40/44] Remove stochastic Lanczos alias --- docs/source/api/linalg.rst | 8 ------ spacecore/linalg/__init__.py | 4 +-- spacecore/linalg/_lanczos.py | 51 ------------------------------------ 3 files changed, 1 insertion(+), 62 deletions(-) diff --git a/docs/source/api/linalg.rst b/docs/source/api/linalg.rst index 3ae5531..f76e39b 100644 --- a/docs/source/api/linalg.rst +++ b/docs/source/api/linalg.rst @@ -12,13 +12,11 @@ method explicitly projects to a small Krylov subspace. spacecore.linalg.cg spacecore.linalg.lsqr spacecore.linalg.lanczos_smallest - spacecore.linalg.stochastic_lanczos spacecore.linalg.power_iteration spacecore.linalg.expm_multiply spacecore.linalg.CGResult spacecore.linalg.LSQRResult spacecore.linalg.LanczosResult - spacecore.linalg.StochasticLanczosResult spacecore.linalg.PowerIterationResult spacecore.linalg.ExpmMultiplyResult @@ -42,16 +40,10 @@ Eigenvalue algorithms .. autofunction:: spacecore.linalg.lanczos_smallest -.. autofunction:: spacecore.linalg.stochastic_lanczos - .. autoclass:: spacecore.linalg.LanczosResult :members: :undoc-members: -.. autoclass:: spacecore.linalg.StochasticLanczosResult - :members: - :undoc-members: - .. autofunction:: spacecore.linalg.power_iteration .. autoclass:: spacecore.linalg.PowerIterationResult diff --git a/spacecore/linalg/__init__.py b/spacecore/linalg/__init__.py index c80d084..49a0509 100644 --- a/spacecore/linalg/__init__.py +++ b/spacecore/linalg/__init__.py @@ -4,7 +4,7 @@ from ._cg import CGResult, cg from ._expm import ExpmMultiplyResult, expm_multiply -from ._lanczos import LanczosResult, StochasticLanczosResult, lanczos_smallest, stochastic_lanczos +from ._lanczos import LanczosResult, lanczos_smallest from ._lsqr import LSQRResult, lsqr from ._power import PowerIterationResult, power_iteration @@ -14,11 +14,9 @@ "LanczosResult", "LSQRResult", "PowerIterationResult", - "StochasticLanczosResult", "cg", "expm_multiply", "lanczos_smallest", "lsqr", "power_iteration", - "stochastic_lanczos", ] diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index 3595fcf..98d8005 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -50,9 +50,6 @@ def __repr__(self) -> str: ) -StochasticLanczosResult = LanczosResult - - class _LanczosBasisResult(NamedTuple): """Store fixed-size Lanczos basis data and tridiagonal projection.""" @@ -398,51 +395,3 @@ def lanczos_smallest( lam = num / den return LanczosResult(lam, x, residual_norm, m, converged) - - -def stochastic_lanczos( - A: LinOp, - initial_vector: Any, - *, - max_iter: int = 100, - tol: float = 1e-6, - check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, -) -> LanczosResult: - """ - Call :func:`lanczos_smallest` through a deprecated alias. - - Parameters - ---------- - A : LinOp - Square Hermitian linear operator. - initial_vector : array-like - Starting vector in ``A.domain``. - max_iter : int, optional - Maximum Krylov dimension. Default is 100. - tol : float, optional - Breakdown tolerance. Default is 1e-6. - check_every : int, optional - Iteration interval for convergence checks. - - Returns - ------- - LanczosResult - Result from :func:`lanczos_smallest`. - - Warns - ----- - DeprecationWarning - Always emitted because this alias will be removed in a future release. - """ - warn( - "stochastic_lanczos is deprecated; use lanczos_smallest instead.", - DeprecationWarning, - stacklevel=2, - ) - return lanczos_smallest( - A, - initial_vector, - max_iter=max_iter, - tol=tol, - check_every=check_every, - ) From 33d19859c2e1795d8a43ccac5fb5d9b9322be39d Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 23:17:39 -0300 Subject: [PATCH 41/44] Add changelog --- CHANGELOG.md | 184 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..d29efa0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,184 @@ +# Changelog + +All notable changes to SpaceCore are documented in this file. + +The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and the project adheres to [Semantic Versioning](https://semver.org/). + +## [0.2.0] + +SpaceCore 0.2.0 is a major API expansion. The backend layer now sits on the +Array API standard. Operators gained a lazy algebra with adjoint views, +composition, sums, and scaling. A new `Functional` hierarchy provides +scalar-valued maps with gradients and pull-backs. A new `spacecore.linalg` +module ships four JIT-compatible iterative solvers. Spaces, operators, and +functionals share a single validation pattern via `checked_method`, and the +public API is documented to numpydoc standard with doctest coverage. + +This release introduces breaking changes; see [Migration](#migration-from-01x). + +### Added + +#### Backend + +- Migrated `BackendOps` to the Array API standard via `array-api-compat`. +- `CuPyOps` and the `cupy` backend family as an optional install + (`pip install 'spacecore[cupy]'`). +- `BackendOps.is_complex_dtype` for backend-aware complex detection. +- `BackendOps.real_dtype` for extracting the real dtype matching a complex one. +- Broadened backend coverage for array creation, dtype conversion, sparse + conversion, indexing, reductions, linear algebra, loop primitives + (`fori_loop`, `while_loop`, `cond`), tree helpers, and vectorized mapping. +- JAX pytree registration for operator, space, and functional types so they + pass through `jax.jit`, `jax.vmap`, and `jax.grad` boundaries. + +#### Context and checking + +- Public free-function API in `spacecore._contextual`: `set_context`, + `get_context`, `resolve_context_priority`, `register_ops`, and the + resolution-policy accessors. +- Extended `checked_method` to support validation against `self` and multiple + input argument positions. +- Reusable space-validation checks: backend, dtype, shape, Hermitian, + square-matrix, product-structure, and product-component checks. Documented + at `docs/source/design/checking_policy.rst`. + +#### Spaces + +- `BatchSpace` for batched elements with explicit batch shape and batch-axis + metadata. + +#### Linear operators + +- Lazy operator algebra: + - `A @ B` composes operators. + - `A + B` sums operators. + - `alpha * A` scales an operator. + - `A.H` returns a cached adjoint view satisfying `A.H.H is A`. + - Algebraic simplification eliminates `I`, `Zero`, `alpha = 0`, `alpha = 1`, + and flattens nested sums. +- New operator types: `IdentityLinOp`, `ZeroLinOp`, `MatrixFreeLinOp`, + `DiagonalLinOp`. +- Structural `LinOp.is_hermitian()` reporting `True`, `False`, or `None` + (unknown) without applying incorrect Euclidean assumptions for custom space + geometries. +- `LinOp.to_dense()` for materializing operators as backend arrays. +- Product-structured operators and batched lifting: + - `ProductLinOp` + - `BlockDiagonalLinOp` + - `StackedLinOp` + - `SumToSingleLinOp` + - `vapply` / `rvapply` paths for batched operator application. + +#### Functionals + +- `Functional` as an abstract base for scalar-valued maps on spaces, with + `value`, `grad`, `hess_apply`, and batched counterparts. +- Linear functionals: `LinearFunctional`, `InnerProductFunctional`, + `MatrixFreeLinearFunctional`. +- Quadratic forms: `QuadraticForm`, `LinOpQuadraticForm`. +- `Functional.compose` and `ComposedFunctional` for pull-backs along linear + operators, with specializations that preserve the concrete functional type + when possible. + +#### Linear algebra + +The `spacecore.linalg` module is new in 0.2.0. It provides JIT-compatible +iterative solvers and structured result types. + +- Iterative solvers: + - `cg` for Hermitian positive-definite systems. + - `lsqr` for rectangular least-squares problems. + - `power_iteration` for dominant-eigenpair estimates of a `LinOp` or + `QuadraticForm`. + - `lanczos_smallest` for smallest-Ritz-eigenpair estimates of Hermitian + operators. + - `expm_multiply` for Krylov matrix-exponential actions `exp(t A) v` on + Hermitian operators, with complex `t` supported for Schrodinger-type + evolution. +- Structured result types `CGResult`, `LSQRResult`, `PowerIterationResult`, + `LanczosResult`, and `ExpmMultiplyResult`, each carrying convergence + diagnostics and a compact `__repr__`. +- Solvers are geometry-aware: norms, inner products, and the default initial + vector use `Space.inner` and `Space.norm` rather than assuming Euclidean + geometry. This makes the solvers correct on custom inner products such as + RKHS or weighted spaces. + +#### Documentation + +- Numpydoc-standard public docstrings with runnable doctests for solvers, + spaces, operators, functionals, backends, and contextual helpers. +- API reference pages for backend ops, spaces, linear operators, functionals, + and linear algebra. +- JAX integration design note at `docs/source/design/jax_integration.rst` + covering trace-time operator algebra and recommended JIT usage. +- Tutorials for backend operations, linear operators, and matrix-free linalg + workflows. + +#### Tooling + +- Optional dependency groups: `[jax]`, `[torch]`, `[cupy]`, `[examples]`, + `[docs]`, `[dev]`. +- Explicit `__all__` at the top level covering new backends, operators, + functionals, solvers, result types, validation checks, and contextual + helpers. +- CI runs a JIT-traceability audit in `--check` mode and enforces a 70% + coverage floor via `pytest-cov`. +- Cross-backend tests covering NumPy, JAX, Torch, and optional CuPy. + +### Changed + +- Restructured `_contextual` to hide implementation details while preserving + the public API via free functions. +- Replaced manual `if self._enable_checks` guards with `checked_method` across + `Space`, `LinOp`, and `Functional`. Inline guards are now reserved for + non-membership checks such as dense-array assertions and custom output-shape + checks. +- Improved `VectorSpace`, `HermitianSpace`, and `ProductSpace` conversion + behavior, validation, batching support, and docstrings. +- Improved linear-operator equality, representation, conversion, and JAX + pytree behavior. +- `spacecore.__version__` now resolves from package metadata via + `importlib.metadata` instead of a hand-maintained constant. +- Bumped the package version to `0.2.0`. + +### Fixed + +- `LinOp.__eq__` returns `NotImplemented` instead of raising + `NotImplementedError` on the base class, so `op == None` and + `op in some_list` no longer raise. +- `DenseLinOp.is_hermitian` and `SparseLinOp.is_hermitian` return `None` for + custom space geometries instead of applying an incorrect Euclidean + matrix-symmetry test. + +### Migration from 0.1.x + +- `BackendOps.eps` is now a method `eps(dtype)` rather than a property. + Callers must pass a dtype, typically `ctx.dtype`. +- The implementation attribute `DenseLinOp.A` is now a `cached_property` + backed by `_A`. The public attribute access `op.A` is unchanged. +- `LinOp.__eq__` returns `NotImplemented` rather than raising; downstream code + relying on the exception should be updated to handle the new behavior. +- Several module-internal helpers in `spacecore._contextual` moved to private + modules. Use the public functions re-exported from `spacecore._contextual` + (`set_context`, `get_context`, `resolve_context_priority`, `register_ops`, + `set_resolution_policy`, and the dtype-policy accessors) rather than + importing from internal modules. + +### Known limitations + +- `cg`, `lsqr`, and `power_iteration` do not structurally validate operator + properties (positive-definiteness, full Hermiticity) and may silently + produce incorrect results on inputs that violate their preconditions. See + each function's `Notes` section for details. +- Operator algebra runs Python-level simplification at construction time. For + maximum JIT efficiency, assemble operator expressions outside the + `jax.jit` boundary; see the JAX integration design note. +- `MatrixFreeLinOp` stores its callables in pytree auxiliary data. + Constructing one inside a JIT-traced function with a new lambda each call + triggers retracing. Construct outside the traced region with a stable + callable reference. +- The CuPy backend is provided as a preview. Coverage of non-standard + operations and sparse handling may evolve in a subsequent release. + +[0.2.0]: https://github.com/Pavlo3P/SpaceCore/releases/tag/v0.2.0 \ No newline at end of file From cb32ce905a37a635a2c9e4136be19ea44bfe8960 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 23:19:00 -0300 Subject: [PATCH 42/44] Remove old files --- MIGRATION.md | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 MIGRATION.md diff --git a/MIGRATION.md b/MIGRATION.md deleted file mode 100644 index b083497..0000000 --- a/MIGRATION.md +++ /dev/null @@ -1,20 +0,0 @@ -# Docstring migration progress - -Baseline (2026-05-27): - -- Ruff pydocstyle: 60 `D` violations from `ruff check --select D spacecore/`. -- Numpydoc validation: 133 actionable issues from `python scripts/docstring_audit.py` - after the Phase 0 allow-list (`ES01`, `EX01`, `SA01`, `GL08`). -- Numpydoc validation, raw: 306 issues from - `python scripts/docstring_audit.py --include-allowed`. -- Doctest: 0 doctest examples collected from - `pytest --doctest-modules spacecore/ -x` under the initial ignore list. - -Notes: - -- The installed `numpydoc` command validates import paths, not package - directories, so `scripts/docstring_audit.py` records and reports the public - SpaceCore API baseline. -- Ruff docstring rules are run as a separate non-blocking CI warning while the - migration is in progress, so the existing strict `ruff check .` step remains - unchanged. From 965817912cff80d80e8ac4502ed7829405f72b37 Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 23:28:21 -0300 Subject: [PATCH 43/44] Remove stochastic Lanczos --- spacecore/__init__.py | 4 ---- tests/integration/test_public_api.py | 3 +-- tests/linalg/test_krylov.py | 13 ++----------- 3 files changed, 3 insertions(+), 17 deletions(-) diff --git a/spacecore/__init__.py b/spacecore/__init__.py index f8b9d55..d94cd90 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -51,13 +51,11 @@ LanczosResult, LSQRResult, PowerIterationResult, - StochasticLanczosResult, cg, expm_multiply, lanczos_smallest, lsqr, power_iteration, - stochastic_lanczos, ) from .space import ( BatchSpace, @@ -127,13 +125,11 @@ "LanczosResult", "LSQRResult", "PowerIterationResult", - "StochasticLanczosResult", "cg", "expm_multiply", "lanczos_smallest", "lsqr", "power_iteration", - "stochastic_lanczos", "BackendCheck", "DTypeCheck", diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index 1de2587..78d59b9 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -31,7 +31,7 @@ def test_expected_names_are_exported(): "set_context", "get_context", "resolve_context_priority", "register_ops", "set_resolution_policy", "set_dtype_resolution_policy", "get_resolution_policy", "get_dtype_resolution_policy", - "LanczosResult", "StochasticLanczosResult", "lanczos_smallest", + "LanczosResult", "lanczos_smallest", "ExpmMultiplyResult", "expm_multiply", } if has_jax(): @@ -65,7 +65,6 @@ def test_top_level_objects_match_source_modules(): assert sc.ComposedFunctional is functional.ComposedFunctional assert sc.InnerProductFunctional is functional.InnerProductFunctional assert sc.LanczosResult is linalg.LanczosResult - assert sc.StochasticLanczosResult is linalg.StochasticLanczosResult assert sc.ExpmMultiplyResult is linalg.ExpmMultiplyResult assert sc.expm_multiply is linalg.expm_multiply assert sc.get_context is contextual.get_context diff --git a/tests/linalg/test_krylov.py b/tests/linalg/test_krylov.py index c4dc2f2..a85c2fe 100644 --- a/tests/linalg/test_krylov.py +++ b/tests/linalg/test_krylov.py @@ -254,7 +254,7 @@ def test_lanczos_smallest_approximates_smallest_eigenpair(backend_name, dtype): assert int(to_numpy(result.krylov_dim)) == 2 -def test_lanczos_smallest_returns_result_object_and_deprecated_alias_warns(): +def test_lanczos_smallest_returns_result_object(): sc = importlib.import_module("spacecore") ctx = _ctx() space = sc.VectorSpace((2,), ctx) @@ -263,16 +263,7 @@ def test_lanczos_smallest_returns_result_object_and_deprecated_alias_warns(): result = sc.lanczos_smallest(op, ctx.asarray([1.0, 1.0]), max_iter=2, tol=1e-8) assert isinstance(result, sc.LanczosResult) - assert isinstance(result, sc.StochasticLanczosResult) - with pytest.warns(DeprecationWarning, match="lanczos_smallest"): - alias_result = sc.stochastic_lanczos( - op, - ctx.asarray([1.0, 1.0]), - max_iter=2, - tol=1e-8, - ) - np.testing.assert_allclose(alias_result.eigenvalue, result.eigenvalue) - np.testing.assert_allclose(alias_result.eigenvector, result.eigenvector) + np.testing.assert_allclose(result.eigenvalue, 2.0) def test_lanczos_smallest_uses_e0_for_zero_initial_vector(): From de34032862f9404dfd8310d93b15962ff245a8fc Mon Sep 17 00:00:00 2001 From: Pavlo Pelikh Date: Tue, 26 May 2026 23:31:24 -0300 Subject: [PATCH 44/44] Remove stochastic Lanczos --- spacecore/linalg/_lanczos.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index 98d8005..b1e1d13 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any, NamedTuple -from warnings import warn from ..linop import LinOp