diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 413e173..1803dfc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,8 +24,11 @@ jobs: - run: python -m pip install --upgrade pip - run: pip install -e ".[jax,torch,dev]" - - run: pytest + - run: pytest --cov=spacecore --cov-report=term-missing --cov-fail-under=70 + - run: python scripts/jit_audit.py --check - run: ruff check . + - run: ruff check --select D spacecore/ || true + - run: python scripts/docstring_audit.py || true publish: if: startsWith(github.ref, 'refs/tags/') diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..d29efa0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,184 @@ +# Changelog + +All notable changes to SpaceCore are documented in this file. + +The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and the project adheres to [Semantic Versioning](https://semver.org/). + +## [0.2.0] + +SpaceCore 0.2.0 is a major API expansion. The backend layer now sits on the +Array API standard. Operators gained a lazy algebra with adjoint views, +composition, sums, and scaling. A new `Functional` hierarchy provides +scalar-valued maps with gradients and pull-backs. A new `spacecore.linalg` +module ships four JIT-compatible iterative solvers. Spaces, operators, and +functionals share a single validation pattern via `checked_method`, and the +public API is documented to numpydoc standard with doctest coverage. + +This release introduces breaking changes; see [Migration](#migration-from-01x). + +### Added + +#### Backend + +- Migrated `BackendOps` to the Array API standard via `array-api-compat`. +- `CuPyOps` and the `cupy` backend family as an optional install + (`pip install 'spacecore[cupy]'`). +- `BackendOps.is_complex_dtype` for backend-aware complex detection. +- `BackendOps.real_dtype` for extracting the real dtype matching a complex one. +- Broadened backend coverage for array creation, dtype conversion, sparse + conversion, indexing, reductions, linear algebra, loop primitives + (`fori_loop`, `while_loop`, `cond`), tree helpers, and vectorized mapping. +- JAX pytree registration for operator, space, and functional types so they + pass through `jax.jit`, `jax.vmap`, and `jax.grad` boundaries. + +#### Context and checking + +- Public free-function API in `spacecore._contextual`: `set_context`, + `get_context`, `resolve_context_priority`, `register_ops`, and the + resolution-policy accessors. +- Extended `checked_method` to support validation against `self` and multiple + input argument positions. +- Reusable space-validation checks: backend, dtype, shape, Hermitian, + square-matrix, product-structure, and product-component checks. Documented + at `docs/source/design/checking_policy.rst`. + +#### Spaces + +- `BatchSpace` for batched elements with explicit batch shape and batch-axis + metadata. + +#### Linear operators + +- Lazy operator algebra: + - `A @ B` composes operators. + - `A + B` sums operators. + - `alpha * A` scales an operator. + - `A.H` returns a cached adjoint view satisfying `A.H.H is A`. + - Algebraic simplification eliminates `I`, `Zero`, `alpha = 0`, `alpha = 1`, + and flattens nested sums. +- New operator types: `IdentityLinOp`, `ZeroLinOp`, `MatrixFreeLinOp`, + `DiagonalLinOp`. +- Structural `LinOp.is_hermitian()` reporting `True`, `False`, or `None` + (unknown) without applying incorrect Euclidean assumptions for custom space + geometries. +- `LinOp.to_dense()` for materializing operators as backend arrays. +- Product-structured operators and batched lifting: + - `ProductLinOp` + - `BlockDiagonalLinOp` + - `StackedLinOp` + - `SumToSingleLinOp` + - `vapply` / `rvapply` paths for batched operator application. + +#### Functionals + +- `Functional` as an abstract base for scalar-valued maps on spaces, with + `value`, `grad`, `hess_apply`, and batched counterparts. +- Linear functionals: `LinearFunctional`, `InnerProductFunctional`, + `MatrixFreeLinearFunctional`. +- Quadratic forms: `QuadraticForm`, `LinOpQuadraticForm`. +- `Functional.compose` and `ComposedFunctional` for pull-backs along linear + operators, with specializations that preserve the concrete functional type + when possible. + +#### Linear algebra + +The `spacecore.linalg` module is new in 0.2.0. It provides JIT-compatible +iterative solvers and structured result types. + +- Iterative solvers: + - `cg` for Hermitian positive-definite systems. + - `lsqr` for rectangular least-squares problems. + - `power_iteration` for dominant-eigenpair estimates of a `LinOp` or + `QuadraticForm`. + - `lanczos_smallest` for smallest-Ritz-eigenpair estimates of Hermitian + operators. + - `expm_multiply` for Krylov matrix-exponential actions `exp(t A) v` on + Hermitian operators, with complex `t` supported for Schrodinger-type + evolution. +- Structured result types `CGResult`, `LSQRResult`, `PowerIterationResult`, + `LanczosResult`, and `ExpmMultiplyResult`, each carrying convergence + diagnostics and a compact `__repr__`. +- Solvers are geometry-aware: norms, inner products, and the default initial + vector use `Space.inner` and `Space.norm` rather than assuming Euclidean + geometry. This makes the solvers correct on custom inner products such as + RKHS or weighted spaces. + +#### Documentation + +- Numpydoc-standard public docstrings with runnable doctests for solvers, + spaces, operators, functionals, backends, and contextual helpers. +- API reference pages for backend ops, spaces, linear operators, functionals, + and linear algebra. +- JAX integration design note at `docs/source/design/jax_integration.rst` + covering trace-time operator algebra and recommended JIT usage. +- Tutorials for backend operations, linear operators, and matrix-free linalg + workflows. + +#### Tooling + +- Optional dependency groups: `[jax]`, `[torch]`, `[cupy]`, `[examples]`, + `[docs]`, `[dev]`. +- Explicit `__all__` at the top level covering new backends, operators, + functionals, solvers, result types, validation checks, and contextual + helpers. +- CI runs a JIT-traceability audit in `--check` mode and enforces a 70% + coverage floor via `pytest-cov`. +- Cross-backend tests covering NumPy, JAX, Torch, and optional CuPy. + +### Changed + +- Restructured `_contextual` to hide implementation details while preserving + the public API via free functions. +- Replaced manual `if self._enable_checks` guards with `checked_method` across + `Space`, `LinOp`, and `Functional`. Inline guards are now reserved for + non-membership checks such as dense-array assertions and custom output-shape + checks. +- Improved `VectorSpace`, `HermitianSpace`, and `ProductSpace` conversion + behavior, validation, batching support, and docstrings. +- Improved linear-operator equality, representation, conversion, and JAX + pytree behavior. +- `spacecore.__version__` now resolves from package metadata via + `importlib.metadata` instead of a hand-maintained constant. +- Bumped the package version to `0.2.0`. + +### Fixed + +- `LinOp.__eq__` returns `NotImplemented` instead of raising + `NotImplementedError` on the base class, so `op == None` and + `op in some_list` no longer raise. +- `DenseLinOp.is_hermitian` and `SparseLinOp.is_hermitian` return `None` for + custom space geometries instead of applying an incorrect Euclidean + matrix-symmetry test. + +### Migration from 0.1.x + +- `BackendOps.eps` is now a method `eps(dtype)` rather than a property. + Callers must pass a dtype, typically `ctx.dtype`. +- The implementation attribute `DenseLinOp.A` is now a `cached_property` + backed by `_A`. The public attribute access `op.A` is unchanged. +- `LinOp.__eq__` returns `NotImplemented` rather than raising; downstream code + relying on the exception should be updated to handle the new behavior. +- Several module-internal helpers in `spacecore._contextual` moved to private + modules. Use the public functions re-exported from `spacecore._contextual` + (`set_context`, `get_context`, `resolve_context_priority`, `register_ops`, + `set_resolution_policy`, and the dtype-policy accessors) rather than + importing from internal modules. + +### Known limitations + +- `cg`, `lsqr`, and `power_iteration` do not structurally validate operator + properties (positive-definiteness, full Hermiticity) and may silently + produce incorrect results on inputs that violate their preconditions. See + each function's `Notes` section for details. +- Operator algebra runs Python-level simplification at construction time. For + maximum JIT efficiency, assemble operator expressions outside the + `jax.jit` boundary; see the JAX integration design note. +- `MatrixFreeLinOp` stores its callables in pytree auxiliary data. + Constructing one inside a JIT-traced function with a new lambda each call + triggers retracing. Construct outside the traced region with a stable + callable reference. +- The CuPy backend is provided as a preview. Coverage of non-standard + operations and sparse handling may evolve in a subsequent release. + +[0.2.0]: https://github.com/Pavlo3P/SpaceCore/releases/tag/v0.2.0 \ No newline at end of file diff --git a/README.md b/README.md index 2459fb9..3d30ab1 100644 --- a/README.md +++ b/README.md @@ -1,245 +1,206 @@ # SpaceCore -SpaceCore exists for writing numerical algorithms once, independently of the -array backend. +[![CI](https://github.com/Pavlo3P/SpaceCore/actions/workflows/ci.yml/badge.svg)](https://github.com/Pavlo3P/SpaceCore/actions/workflows/ci.yml) +[![PyPI](https://img.shields.io/pypi/v/spacecore.svg)](https://pypi.org/project/spacecore/) +[![Python](https://img.shields.io/pypi/pyversions/spacecore.svg)](https://pypi.org/project/spacecore/) +[![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](LICENSE) -For example, the same algorithm can run with NumPy for debugging, JAX for -JIT/autodiff, and Torch for tensor workflows, while preserving the same -mathematical spaces and linear operators. +**Backend-agnostic vector spaces, linear operators, and iterative solvers for scientific computing.** -## What problem does SpaceCore solve? +Write your algorithm once. Run it on NumPy for development, JAX for GPU acceleration and autodiff, or PyTorch for ML pipelines — without changing a line. -Numerical algorithms often start as clear NumPy code and later need to move to -JAX, Torch, or another array system. Without a backend boundary, that migration -usually leaks through the whole implementation: array constructors, dtype -handling, inner products, sparse support, and linear-operator conventions all -become backend-specific. - -SpaceCore keeps those choices in a `Context`, while algorithms work with -mathematical objects: - -- a `Space` knows the structure and geometry of its elements; -- a `LinOp` maps one space to another; -- backend-specific array creation and operations live behind `BackendOps`. - -The result is ordinary Python code whose core numerical logic is not tied to -one array library. +```python +import spacecore as sc +import numpy as np -Mental model: +# Define a space, a linear operator, and solve Ax = b +ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) +X = sc.VectorSpace((100,), ctx) +A = sc.DenseLinOp(np.random.randn(100, 100) @ np.random.randn(100, 100).T + np.eye(100), X, X, ctx) +b = ctx.asarray(np.random.randn(100)) -```text -BackendOps -> Context -> Space/LinOp -> Algorithm +result = sc.cg(A, b, tol=1e-8) +print(f"Converged in {result.num_iters} iterations.") ``` -## Write once, run twice - -This gradient descent loop uses only the `Space` and `LinOp` APIs. It does not -know whether the arrays are NumPy arrays, JAX arrays, or Torch tensors. +Same code on JAX with GPU? ```python -import numpy as np -import spacecore as sc - +ctx = sc.Context(sc.JaxOps(), dtype=jnp.float64) +# ... build A and b the same way using jax arrays ... +result = sc.cg(A, b, tol=1e-8) # runs on GPU, JIT-compiled +``` -def as_numpy(x): - if hasattr(x, "detach"): - return x.detach().cpu().numpy() - return np.asarray(x) +## Install +```bash +pip install spacecore # core (numpy only) +pip install "spacecore[jax]" # add JAX backend +pip install "spacecore[torch]" # add PyTorch backend +pip install "spacecore[jax,torch]" # both +``` -def make_problem(ctx): - X = sc.VectorSpace((3,), ctx) - Y = sc.VectorSpace((2,), ctx) +Python 3.11+. Built on the [Python Array API](https://data-apis.org/array-api/) standard. - A = sc.DenseLinOp( - ctx.asarray([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]), - dom=X, - cod=Y, - ctx=ctx, - ) - x = ctx.asarray([1.0, 0.0, -1.0]) - b = ctx.asarray([0.5, 0.25]) - return X, Y, A, x, b +## What is SpaceCore for? +SpaceCore is for people writing numerical algorithms — optimization, inverse problems, eigensolvers, quantum simulation, computational geometry — who don't want to choose between NumPy, JAX, and PyTorch. -def gradient_step(X, A, x, b, eta): - r = A.apply(x) - b - grad = A.rapply(r) - return X.axpy(-eta, grad, x) +### Three things SpaceCore does well +**1. Matrix-free linear operators with algebra.** Write your operator once as `apply` and `adjoint` callables, then compose them: -def run_gradient_descent(X, A, x, b, eta, steps): - for _ in range(steps): - x = gradient_step(X, A, x, b, eta) - return x -``` +```python +# An FFT-based convolution operator, never materialized as a matrix +K = sc.MatrixFreeLinOp(apply=fft_convolve, rapply=fft_convolve_adjoint, dom=X, cod=X, ctx=ctx) +grad = sc.MatrixFreeLinOp(apply=finite_diff, rapply=neg_div, dom=X, cod=Y, ctx=ctx) -Run it with NumPy: +# Build the regularized system operator using algebra +lam = 0.01 +system = K.H @ K + lam * grad.H @ grad # SumLinOp of ComposedLinOps +rhs = K.H.apply(b) -```python -np_ctx = sc.Context(sc.NumpyOps(), dtype="float64") -X, Y, A, x, b = make_problem(np_ctx) -x_numpy = run_gradient_descent(X, A, x, b, eta=0.1, steps=5) -print(as_numpy(x_numpy)) +# Solve — no matrices were assembled +solution = sc.cg(system, rhs).x ``` -Later, run the same problem and the same `run_gradient_descent` with JAX: +**2. Cross-backend iterative solvers.** CG, LSQR, Lanczos, power iteration — all work uniformly across NumPy, JAX, and PyTorch. JAX backends JIT-compile: ```python -import jax - -jax.config.update("jax_enable_x64", True) - -jax_ctx = sc.Context(sc.JaxOps(), dtype="float64") -X, Y, A, x, b = make_problem(jax_ctx) -x_jax = run_gradient_descent(X, A, x, b, eta=0.1, steps=5) -print(as_numpy(x_jax)) +ctx = sc.Context(sc.JaxOps(), dtype=jnp.complex128) +A = build_hermitian_operator(ctx) -print(np.allclose(as_numpy(x_numpy), as_numpy(x_jax))) +# Find the smallest eigenpair via Lanczos +result = sc.lanczos_smallest(A, initial_vector, max_iter=50) +print(f"E_0 = {result.eigenvalue}, converged={result.converged}") ``` -Run it the same way with Torch: +**3. Custom Hilbert spaces with non-Euclidean geometry.** Subclass `VectorSpace`, override `inner`, and every solver respects your geometry: ```python -torch_ctx = sc.Context(sc.TorchOps(), dtype="float64") -X, Y, A, x, b = make_problem(torch_ctx) -x_torch = run_gradient_descent(X, A, x, b, eta=0.1, steps=5) -print(as_numpy(x_torch)) +class WeightedL2(sc.VectorSpace): + def __init__(self, shape, weights, ctx=None): + super().__init__(shape, ctx) + self.weights = self.ctx.asarray(weights) -print(np.allclose(as_numpy(x_numpy), as_numpy(x_torch))) -``` - -All three backends produce the same result: + def inner(self, x, y): + return self.ops.vdot(x, self.weights * y) -```text -[ 1.184125 0.3411875 -0.447625 ] -[ 1.184125 0.3411875 -0.447625 ] -True -[ 1.184125 0.3411875 -0.447625 ] -True +# CG, LSQR, Lanczos all use this inner product automatically ``` -If you do not want to enable JAX 64-bit mode, use a supported dtype such as -`"float32"`. - -## What SpaceCore is not +This is the basis for RKHS spaces, truncated Fock spaces (quantum many-body), function spaces with quadrature, and anything else where the geometry isn't `sum(x * y)`. -SpaceCore is not an optimizer and not a NumPy/JAX/Torch replacement. It provides -backend-aware spaces, operators, and context handling so you can write your own -algorithms without wiring them to one array library. +## Quick examples -## Core concepts +### Conjugate gradient on a symmetric positive-definite system -### `Context` - -A `Context` specifies how objects are represented: +```python +import spacecore as sc -- backend operations (`NumpyOps`, `JaxOps`, `TorchOps`, etc.); -- default dtype; -- runtime validation behavior. +ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) +X = sc.VectorSpace((1000,), ctx) +A = sc.DenseLinOp(make_spd_matrix(), X, X, ctx) +b = ctx.asarray(rhs) -Constructors resolve contexts in priority order: explicit `ctx=...`, then -contexts inferred from inputs, then the global default context. Advanced code -that needs this resolution step directly can call -`spacecore.resolve_context_priority(...)`. +result = sc.cg(A, b, tol=1e-10, maxiter=500) +print(f"x = {result.x}, residual = {result.residual_norm}") +``` -### `Space` +### Least-squares with regularization -A `Space` describes the structure and geometry of values: +```python +# min ||Ax - b||^2 + λ||x||^2 via normal equations +I = sc.IdentityLinOp(X) +system = A.H @ A + lam * I +rhs = A.H.apply(b) +x_hat = sc.cg(system, rhs).x +``` -- `VectorSpace` for Euclidean vectors and tensors; -- `HermitianSpace` for Hermitian or symmetric matrices; -- `ProductSpace` for Cartesian products of spaces. +### Smallest eigenpair of a Hermitian operator -Algorithms should use space methods such as `zeros`, `add`, `scale`, `axpy`, -`inner`, `norm`, `flatten`, and `unflatten` instead of hard-coding backend array -operations. +```python +result = sc.lanczos_smallest(A, initial_vector, max_iter=100) +print(f"E_0 ≈ {result.eigenvalue}") +print(f"Krylov dimension used: {result.krylov_dim}") +print(f"Converged: {result.converged}") +``` -### `LinOp` +### Building a custom operator -A `LinOp` represents a linear operator between spaces: +```python +class Convolution(sc.LinOp): + def __init__(self, kernel, space, ctx): + super().__init__(space, space, ctx) + self.kernel = kernel -- `DenseLinOp` for dense matrix or tensor operators; -- `SparseLinOp` for sparse operators; -- `BlockDiagonalLinOp` for block-diagonal product-space operators; -- `StackedLinOp` for operators from one space into a product space; -- `SumToSingleLinOp` for operators from a product space into one space. + def apply(self, x): + return self.ops.real(self.ops.fft.ifft(self.ops.fft.fft(x) * self.ops.fft.fft(self.kernel))) -Operators expose `apply` and `rapply`, so algorithms can use a linear map and -its adjoint without depending on the storage format. + def rapply(self, y): + return self.ops.real(self.ops.fft.ifft(self.ops.fft.fft(y) * self.ops.conj(self.ops.fft.fft(self.kernel)))) +``` -## Who should use this? +This operator works on NumPy, JAX, and PyTorch backends without modification. -SpaceCore is aimed at people writing optimization, inverse-problem, optimal -transport, semidefinite programming, or scientific ML algorithms that should not -be tied to one backend. +## How is SpaceCore different from...? -It is most useful when you want the mathematical model to stay stable while the -execution backend changes. +**...`scipy.sparse.linalg`?** SciPy's iterative solvers are great but tied to NumPy/SciPy. SpaceCore gives you the same algorithms across NumPy, JAX, and PyTorch, plus operator algebra (`A @ B + lam * I` actually returns a usable operator), plus first-class custom Hilbert spaces. -## Installation +**...PyLops?** PyLops is excellent for inverse problems but assumes Euclidean vectors and is tied to NumPy/CuPy. SpaceCore handles non-Euclidean geometry (RKHS, weighted spaces, function spaces) and works on JAX/PyTorch for autodiff and ML pipelines. -Base install: +**...QuTiP?** QuTiP is the standard for quantum optics on top of SciPy. SpaceCore lets you build the same quantum operators on JAX or PyTorch for GPU acceleration and gradient-based parameter learning. Less prebuilt, more composable. -```bash -pip install spacecore -``` +**...`array_api_compat`?** That package gives you portable arrays. SpaceCore builds on top of it to give you portable *vector spaces, linear operators, and iterative algorithms* — the abstractions one level up from arrays. -With JAX support: +## Documentation -```bash -pip install "spacecore[jax]" -``` +[//]: # (- **[Quick Start](https://pavlo3p.github.io/SpaceCore/quickstart.html)** — 20-line introduction) +[//]: # (- **[Concepts](https://pavlo3p.github.io/SpaceCore/concepts.html)** — Spaces, operators, contexts) +[//]: # (- **[Tutorials](https://pavlo3p.github.io/SpaceCore/tutorials/index.html)** — Image deblurring, Jaynes-Cummings model, kernel ridge regression) +- **[API Reference](https://pavlo3p.github.io/SpaceCore/api/index.html)** — Full documentation -With PyTorch support: +## Features at a glance -```bash -pip install "spacecore[torch]" -``` +**Spaces.** `VectorSpace`, `HermitianSpace`, `ProductSpace`, `BatchSpace`. All easy to subclass for custom geometry. -- `spacecore[jax]` installs optional JAX support. -- GPU users should install the appropriate CUDA-enabled JAX build first, - following the official JAX installation guide. -- `spacecore[torch]` installs optional PyTorch support for `torch.Tensor` - backends. -- GPU users should install the appropriate CUDA-enabled PyTorch build first, - following the official PyTorch installation guide. +**Linear operators.** `DenseLinOp`, `SparseLinOp`, `DiagonalLinOp`, `MatrixFreeLinOp`, plus operator algebra (`A @ B`, `A + B`, `2 * A`, `A.H`, `IdentityLinOp`, `ZeroLinOp`). -For local development: +**Functionals.** `LinearFunctional`, `QuadraticForm`, with `value`, `grad`, `hess_apply`, and `compose(linop)` for pull-back. -```bash -python -m pip install -e ".[dev]" -``` +**Iterative solvers.** `cg`, `lsqr`, `lanczos_smallest`, `power_iteration`. -## Full example +**Backends.** NumPy (always), JAX (`spacecore[jax]`), PyTorch (`spacecore[torch]`), CuPy (`spacecore[cupy]`). Adding a backend is ~100 LOC; the registry is public. -For a complete example of regularized optimal transport problem, [see](https://pavlo3p.github.io/SpaceCore/tutorials/regularized_ot.html) -the model is written once and solved with NumPy/JAX backends and -its [notebook](https://github.com/Pavlo3P/SpaceCore/blob/master/tutorials/6_Regularized_Opt_Transport.ipynb). +## Project status -## Documentation +**v0.2 alpha.** API may still change in minor ways. Core abstractions are stable. Suitable for research code; not yet recommended for production deployment. -The hosted documentation is available [here](https://pavlo3p.github.io/SpaceCore/). +The library is being developed in the open and is looking for early users and feedback. If you try it on your problem, please open an issue with what worked and what didn't — that's the single most valuable contribution right now. -The documentation website is built with Sphinx from `docs/source`. +## Contributing -Install the documentation dependencies: +Bug reports, feature requests, and PRs welcome. See [CONTRIBUTING.md](CONTRIBUTING.md). -```bash -python -m pip install -e ".[docs]" -``` +Specific areas where help is wanted: -Build the local HTML documentation: +- **Tutorials.** If SpaceCore solves your problem, a notebook example helps everyone. +- **Backends.** CuPy and Dask integration is partial; adding a new backend is well-scoped (~100 LOC). +- **Performance.** Cross-backend benchmarks on real workloads. +- **Documentation.** Concept pages, FAQ, gotchas. -```bash -sphinx-build -b html docs/source docs/build/html -``` +## License -## Status +Apache 2.0. See [LICENSE](LICENSE). -SpaceCore is currently experimental and under active development. The public API -may still evolve. +## Citation -## License +If SpaceCore is useful in your research, a citation is appreciated: -Apache License 2.0 +```bibtex +@software{spacecore, + author = {Pavlo, Pelikh}, + title = {SpaceCore: Backend-agnostic vector spaces and linear operators}, + url = {https://github.com/Pavlo3P/SpaceCore}, + year = {2026}, +} diff --git a/docs/source/api/backend.rst b/docs/source/api/backend.rst index 1b466cf..811b38a 100644 --- a/docs/source/api/backend.rst +++ b/docs/source/api/backend.rst @@ -41,6 +41,29 @@ JaxOps :show-inheritance: :exclude-members: jax, jnp, jsparse +CuPyOps +------- + +``CuPyOps`` is the optional CuPy backend implementation for GPU arrays and +``cupyx.scipy.sparse`` matrices. It is exported as ``spacecore.backend.CuPyOps`` +only when CuPy is installed in the environment. + +Install the optional backend before using it: + +.. code-block:: bash + + pip install spacecore[cupy] + +Use it through a normal SpaceCore context: + +.. code-block:: python + + import numpy as np + import spacecore as sc + + ctx = sc.Context(sc.CuPyOps(), dtype=np.float64) + x = ctx.asarray([1.0, 2.0, 3.0]) + TorchOps -------- diff --git a/docs/source/api/functionals.rst b/docs/source/api/functionals.rst new file mode 100644 index 0000000..147e6bf --- /dev/null +++ b/docs/source/api/functionals.rst @@ -0,0 +1,60 @@ +Functionals API +=============== + +Functionals represent scalar-valued maps on spaces, including linear +functionals and quadratic forms. + +.. autosummary:: + :nosignatures: + + spacecore.functional.Functional + spacecore.functional.LinearFunctional + spacecore.functional.InnerProductFunctional + spacecore.functional.MatrixFreeLinearFunctional + spacecore.functional.QuadraticForm + spacecore.functional.LinOpQuadraticForm + +Functional +---------- + +.. autoclass:: spacecore.functional.Functional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +Linear functionals +------------------ + +.. autoclass:: spacecore.functional.LinearFunctional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.functional.InnerProductFunctional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.functional.MatrixFreeLinearFunctional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +Quadratic forms +--------------- + +.. autoclass:: spacecore.functional.QuadraticForm + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.functional.LinOpQuadraticForm + :members: + :undoc-members: + :inherited-members: + :show-inheritance: diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 259ac9d..b7636b8 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -11,3 +11,5 @@ directives for public objects instead of dumping entire modules. context spaces linops + functionals + linalg diff --git a/docs/source/api/linalg.rst b/docs/source/api/linalg.rst new file mode 100644 index 0000000..f76e39b --- /dev/null +++ b/docs/source/api/linalg.rst @@ -0,0 +1,60 @@ +Linear algebra API +================== + +Linear algebra routines solve systems, estimate eigenpairs, and apply matrix +functions through :class:`~spacecore.linop.LinOp` objects. They use +space-aware vector operations and avoid materializing dense operators unless a +method explicitly projects to a small Krylov subspace. + +.. autosummary:: + :nosignatures: + + spacecore.linalg.cg + spacecore.linalg.lsqr + spacecore.linalg.lanczos_smallest + spacecore.linalg.power_iteration + spacecore.linalg.expm_multiply + spacecore.linalg.CGResult + spacecore.linalg.LSQRResult + spacecore.linalg.LanczosResult + spacecore.linalg.PowerIterationResult + spacecore.linalg.ExpmMultiplyResult + +Solvers +------- + +.. autofunction:: spacecore.linalg.cg + +.. autoclass:: spacecore.linalg.CGResult + :members: + :undoc-members: + +.. autofunction:: spacecore.linalg.lsqr + +.. autoclass:: spacecore.linalg.LSQRResult + :members: + :undoc-members: + +Eigenvalue algorithms +--------------------- + +.. autofunction:: spacecore.linalg.lanczos_smallest + +.. autoclass:: spacecore.linalg.LanczosResult + :members: + :undoc-members: + +.. autofunction:: spacecore.linalg.power_iteration + +.. autoclass:: spacecore.linalg.PowerIterationResult + :members: + :undoc-members: + +Matrix functions +---------------- + +.. autofunction:: spacecore.linalg.expm_multiply + +.. autoclass:: spacecore.linalg.ExpmMultiplyResult + :members: + :undoc-members: diff --git a/docs/source/api/linops.rst b/docs/source/api/linops.rst index ef963b4..f0703f6 100644 --- a/docs/source/api/linops.rst +++ b/docs/source/api/linops.rst @@ -10,10 +10,20 @@ actions. spacecore.linop.LinOp spacecore.linop.ProductLinOp spacecore.linop.DenseLinOp + spacecore.linop.DiagonalLinOp spacecore.linop.SparseLinOp + spacecore.linop.MatrixFreeLinOp + spacecore.linop.IdentityLinOp + spacecore.linop.ZeroLinOp + spacecore.linop.ScaledLinOp + spacecore.linop.SumLinOp + spacecore.linop.ComposedLinOp spacecore.linop.BlockDiagonalLinOp spacecore.linop.StackedLinOp spacecore.linop.SumToSingleLinOp + spacecore.linop.make_scaled + spacecore.linop.make_sum + spacecore.linop.make_composed LinOp ----- @@ -42,6 +52,15 @@ DenseLinOp :inherited-members: :show-inheritance: +DiagonalLinOp +------------- + +.. autoclass:: spacecore.linop.DiagonalLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + SparseLinOp ----------- @@ -51,6 +70,60 @@ SparseLinOp :inherited-members: :show-inheritance: +MatrixFreeLinOp +--------------- + +.. autoclass:: spacecore.linop.MatrixFreeLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +IdentityLinOp +------------- + +.. autoclass:: spacecore.linop.IdentityLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +ZeroLinOp +--------- + +.. autoclass:: spacecore.linop.ZeroLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +Algebraic operators +------------------- + +.. autoclass:: spacecore.linop.ScaledLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.linop.SumLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.linop.ComposedLinOp + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autofunction:: spacecore.linop.make_scaled + +.. autofunction:: spacecore.linop.make_sum + +.. autofunction:: spacecore.linop.make_composed + Product-structured operators ---------------------------- diff --git a/docs/source/api/spaces.rst b/docs/source/api/spaces.rst index 01075fe..ba94a95 100644 --- a/docs/source/api/spaces.rst +++ b/docs/source/api/spaces.rst @@ -10,6 +10,7 @@ Spaces define element structure, geometry, flattening, and validation. spacecore.space.VectorSpace spacecore.space.HermitianSpace spacecore.space.ProductSpace + spacecore.space.BatchSpace spacecore.space.SpaceCheck spacecore.space.SpaceValidationError @@ -49,6 +50,15 @@ ProductSpace :inherited-members: :show-inheritance: +BatchSpace +---------- + +.. autoclass:: spacecore.space.BatchSpace + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + Validation ---------- diff --git a/docs/source/design/checking_policy.rst b/docs/source/design/checking_policy.rst index 1b4976e..bfb8ca7 100644 --- a/docs/source/design/checking_policy.rst +++ b/docs/source/design/checking_policy.rst @@ -36,6 +36,18 @@ before ``apply`` and ``rapply`` when checking is enabled. For exploratory use, enabled checks produce clearer errors. For tight numerical loops, disabled checks reduce validation overhead. +Implementation convention +------------------------- + +Methods that perform simple membership validation should use +``@checked_method`` rather than inline ``if self._enable_checks`` branches. This +keeps validation policy visible at the method signature and avoids duplicating +the same guard throughout spaces, operators, and functionals. + +Inline ``if self._enable_checks`` blocks are reserved for checks that are not +plain membership checks, such as dense-array assertions, custom output-shape +comparisons, or the implementation of ``_check_member`` itself. + Inferred checking policy ------------------------ diff --git a/docs/source/design/index.rst b/docs/source/design/index.rst index 1d9df1b..f722ba5 100644 --- a/docs/source/design/index.rst +++ b/docs/source/design/index.rst @@ -10,3 +10,5 @@ users should reason about them. conversion_policy dtype_policy checking_policy + jax_integration + backend_ops_array_api diff --git a/docs/source/design/jax_integration.rst b/docs/source/design/jax_integration.rst new file mode 100644 index 0000000..ef99186 --- /dev/null +++ b/docs/source/design/jax_integration.rst @@ -0,0 +1,44 @@ +JAX integration +=============== + +JIT usage notes +--------------- + +SpaceCore's numerical kernels are written to run under ``jax.jit`` when values +live in a JAX-backed ``Context``. The object model remains ordinary Python: +spaces, operators, and functionals are assembled before the numerical kernel is +traced, then passed into the jitted function. + +Operator algebra such as ``A @ B`` and ``A + B`` executes Python-level +simplification rules at construction time. For maximum JIT efficiency: + +* construct operator expressions outside the JIT-decorated function; +* pass the assembled operator as an argument to the jitted function; +* avoid calling ``make_sum`` or ``make_composed`` from inside a ``jax.jit`` + body. + +This is a trace-time concern rather than a correctness concern. The algebra is +correct either way, but composing inside ``jax.jit`` means the simplification +runs once per trace. For repeatedly invoked code with stable operator +structure, build the expression once outside the jitted function. + +Example: + +.. code-block:: python + + import jax + import spacecore as sc + + ctx = sc.Context(sc.JaxOps(), dtype="float32") + X = sc.VectorSpace((128,), ctx) + A = build_operator(X) + B = build_preconditioner(X) + + # Build algebra outside the JIT boundary. + system = B.H @ A @ B + 0.01 * sc.IdentityLinOp(X, ctx) + + @jax.jit + def solve(op, rhs): + return sc.cg(op, rhs, maxiter=50).x + + x = solve(system, rhs) diff --git a/docs/source/index.rst b/docs/source/index.rst index 709aae8..8a94a88 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -4,24 +4,26 @@ SpaceCore SpaceCore exists for writing numerical algorithms once, independently of the array backend. -For example, the same algorithm can run with NumPy for debugging, JAX for -JIT/autodiff, and Torch for tensor workflows, while preserving the same +For example, the same algorithm can run with NumPy for debugging, CuPy for +eager GPU execution, JAX for JIT/autodiff, and Torch for tensor workflows, +while preserving the same mathematical spaces and linear operators. What problem does SpaceCore solve? ---------------------------------- Numerical algorithms often start as clear NumPy code and later need to move to -JAX, Torch, or another array system. Without a backend boundary, that migration -usually leaks through the whole implementation: array constructors, dtype -handling, inner products, sparse support, and linear-operator conventions all -become backend-specific. +CuPy, JAX, Torch, or another array system. Without a backend boundary, that +migration usually leaks through the whole implementation: array constructors, +dtype handling, inner products, sparse support, and linear-operator conventions +all become backend-specific. SpaceCore keeps those choices in a ``Context``, while algorithms work with mathematical objects: * a ``Space`` knows the structure and geometry of its elements; * a ``LinOp`` maps one space to another; +* a ``Functional`` maps a space element to a scalar; * backend-specific array creation and operations live behind ``BackendOps``. The result is ordinary Python code whose core numerical logic is not tied to @@ -31,13 +33,14 @@ Mental model: .. code-block:: text - BackendOps -> Context -> Space/LinOp -> Algorithm + BackendOps -> Context -> Space/LinOp/Functional -> Algorithm Write once, run twice --------------------- This gradient descent loop uses only the ``Space`` and ``LinOp`` APIs. It does -not know whether the arrays are NumPy arrays, JAX arrays, or Torch tensors. +not know whether the arrays are NumPy arrays, CuPy arrays, JAX arrays, or Torch +tensors. .. code-block:: python @@ -140,7 +143,7 @@ Core concepts A ``Context`` specifies how objects are represented: -* backend operations (``NumpyOps``, ``JaxOps``, ``TorchOps``, etc.); +* backend operations (``NumpyOps``, ``CuPyOps``, ``JaxOps``, ``TorchOps``, etc.); * default dtype; * runtime validation behavior. @@ -157,6 +160,8 @@ A ``Space`` describes the structure and geometry of values: * ``VectorSpace`` for Euclidean vectors and tensors; * ``HermitianSpace`` for Hermitian or symmetric matrices; * ``ProductSpace`` for Cartesian products of spaces. +* ``BatchSpace`` for batched elements such as ``X.batch((B,), (0,))``, + representing ``B`` independent copies of ``X``. Algorithms should use space methods such as ``zeros``, ``add``, ``scale``, ``axpy``, ``inner``, ``norm``, ``flatten``, and ``unflatten`` instead of @@ -168,7 +173,12 @@ hard-coding backend array operations. A ``LinOp`` represents a linear operator between spaces: * ``DenseLinOp`` for dense matrix or tensor operators; +* ``DiagonalLinOp`` for coordinatewise diagonal operators; * ``SparseLinOp`` for sparse operators; +* ``MatrixFreeLinOp`` for callable-backed operators without stored matrices; +* ``IdentityLinOp`` and ``ZeroLinOp`` for canonical identity and zero maps; +* ``ScaledLinOp``, ``SumLinOp``, and ``ComposedLinOp`` for lazy operator + algebra; * ``BlockDiagonalLinOp`` for block-diagonal product-space operators; * ``StackedLinOp`` for operators from one space into a product space; * ``SumToSingleLinOp`` for operators from a product space into one space. @@ -176,6 +186,33 @@ A ``LinOp`` represents a linear operator between spaces: Operators expose ``apply`` and ``rapply``, so algorithms can use a linear map and its adjoint without depending on the storage format. +For batched inputs, ``vapply(xs)`` and ``rvapply(ys)`` lift the operator over +the leading batch axis: + +.. code-block:: python + + XB = X.batch(batch_shape=(B,), batch_axes=(0,)) + YB = Y.batch(batch_shape=(B,), batch_axes=(0,)) + + ys = A.vapply(xs, batch_space=XB) # xs in XB, ys in YB + xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB + +The fallback uses backend ``vmap``; dense, sparse, diagonal, identity, zero, +algebraic, and product-structured operators provide specialized batched paths. + +``Functional`` +~~~~~~~~~~~~~~ + +A ``Functional`` represents a scalar-valued map on a space. +``LinearFunctional`` covers maps such as ````, +``MatrixFreeLinearFunctional`` wraps a callable without storing a representer, +and ``LinOpQuadraticForm`` represents objectives such as +``0.5 * + ell(x) + a``. + +For batched inputs, ``vvalue(xs)`` evaluates independently over leading batch +axes. Quadratic forms that define gradients also expose ``grad(x)`` and +``vgrad(xs)``. + Who should use this? -------------------- @@ -201,6 +238,12 @@ With JAX support: pip install "spacecore[jax]" +With CuPy support: + +.. code-block:: bash + + pip install "spacecore[cupy]" + With PyTorch support: .. code-block:: bash @@ -210,6 +253,10 @@ With PyTorch support: * ``spacecore[jax]`` installs optional JAX support. * GPU users should install the appropriate CUDA-enabled JAX build first, following the official JAX installation guide. +* ``spacecore[cupy]`` installs optional CuPy support for ``cupy.ndarray`` and + ``cupyx.scipy.sparse`` backends. +* GPU users should install the appropriate CUDA-enabled CuPy package first, + following the official CuPy installation guide. * ``spacecore[torch]`` installs optional PyTorch support for ``torch.Tensor`` backends. * GPU users should install the appropriate CUDA-enabled PyTorch build first, diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index 35613fe..100a9e1 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -1,6 +1,240 @@ Release notes ============= +Version 0.2.0 +------------- + +SpaceCore 0.2.0 is a major API expansion. The backend layer now sits on the +Array API standard. Operators gained a lazy algebra with adjoint views, +composition, sums, and scaling. A new :class:`Functional` hierarchy provides +scalar-valued maps with gradients and pull-backs. A new :mod:`spacecore.linalg` +module ships four JIT-compatible iterative solvers. Spaces, operators, and +functionals share a single validation pattern via ``checked_method``, and the +public API is documented to numpydoc standard with doctest coverage. + +This release introduces breaking renames; see :ref:`migration-0-2`. + +Highlights +~~~~~~~~~~ + +* Array API backend layer with optional CuPy support. +* Lazy operator algebra: ``A @ B``, ``A + B``, ``A.H``, plus + :class:`IdentityLinOp`, :class:`ZeroLinOp`, :class:`MatrixFreeLinOp`, and + ``make_*`` factories with algebraic simplification. +* :class:`Functional` hierarchy with linear and quadratic forms, plus + ``Functional.compose`` for pull-back along linear operators. +* New :mod:`spacecore.linalg` module with iterative solvers: :func:`cg`, + :func:`lsqr`, :func:`power_iteration`, :func:`lanczos_smallest`, + :func:`expm_multiply`. +* Geometry-aware solvers honor the declared ``Space.inner`` instead of assuming + Euclidean. +* Unified ``checked_method`` decorator across :class:`Space`, :class:`LinOp`, + and :class:`Functional`. +* Comprehensive numpydoc-style docstrings, doctests, and a JAX integration + design note. + +Backend +~~~~~~~ + +* Migrated :class:`BackendOps` to the Array API standard via + ``array-api-compat``. +* Added :class:`CuPyOps` and the ``cupy`` backend family as an optional install + (``pip install 'spacecore[cupy]'``). +* Centralized complex-dtype handling on :class:`BackendOps`: + + * :meth:`BackendOps.is_complex_dtype` for backend-aware complex detection. + * :meth:`BackendOps.real_dtype` for extracting the real dtype matching a + complex one. + +* Broadened backend coverage for array creation, dtype conversion, sparse + conversion, indexing, reductions, linear algebra, loop primitives + (``fori_loop``, ``while_loop``, ``cond``), tree helpers, and vectorized + mapping. +* Registered JAX pytrees for operator, space, and functional types so they pass + through ``jax.jit``, ``jax.vmap``, and ``jax.grad`` boundaries. + +Context and checking +~~~~~~~~~~~~~~~~~~~~ + +* Restructured ``_contextual`` to hide implementation details while keeping the + public free-function API (:func:`set_context`, :func:`get_context`, + :func:`resolve_context_priority`, :func:`register_ops`, and the + resolution-policy accessors). +* Extended :func:`~spacecore._checks.checked_method` to support validation + against ``self`` and multiple input argument positions. +* Replaced manual ``if self._enable_checks`` guards with ``checked_method`` + across :class:`Space`, :class:`LinOp`, and :class:`Functional`. Inline guards + are now reserved for non-membership checks such as dense-array assertions and + custom output-shape checks. +* Added reusable space-validation checks documented at + ``docs/source/design/checking_policy.rst``: backend, dtype, shape, Hermitian, + square-matrix, product-structure, and product-component checks. + +Spaces +~~~~~~ + +* Added :class:`BatchSpace` for batched elements with explicit batch shape and + batch-axis metadata. +* Improved :class:`VectorSpace`, :class:`HermitianSpace`, and + :class:`ProductSpace` conversion behavior, validation, batching support, and + docstrings. + +Linear operators +~~~~~~~~~~~~~~~~ + +* **Lazy operator algebra.** Added composition, addition, scaling, and adjoint + view with algebraic simplification: + + * ``A @ B`` composes operators. + * ``A + B`` sums operators. + * ``alpha * A`` scales an operator. + * ``A.H`` returns a cached adjoint view satisfying ``A.H.H is A``. + + Simplification rules eliminate ``I``, ``Zero``, ``alpha = 0``, ``alpha = 1``, + and flatten nested sums. + +* Added :class:`IdentityLinOp`, :class:`ZeroLinOp`, :class:`MatrixFreeLinOp`, + and :class:`DiagonalLinOp`. +* Added structural :meth:`LinOp.is_hermitian` reporting ``True``, ``False``, + or ``None`` (unknown) without applying incorrect Euclidean assumptions for + custom space geometries. +* Added :meth:`LinOp.to_dense` for materializing operators as backend arrays. +* Added product-structured operators and batched lifting: + + * :class:`ProductLinOp` + * :class:`BlockDiagonalLinOp` + * :class:`StackedLinOp` + * :class:`SumToSingleLinOp` + * ``vapply`` / ``rvapply`` paths for batched operator application. + +* Improved linear-operator equality, representation, conversion, and JAX + pytree behavior. + +Functionals +~~~~~~~~~~~ + +* Added :class:`Functional` as an abstract base for scalar-valued maps on + spaces, with :meth:`value`, :meth:`grad`, :meth:`hess_apply`, and batched + counterparts. +* Added linear functional implementations: + + * :class:`LinearFunctional` + * :class:`InnerProductFunctional` + * :class:`MatrixFreeLinearFunctional` + +* Added quadratic forms: + + * :class:`QuadraticForm` + * :class:`LinOpQuadraticForm` + +* Added :meth:`Functional.compose` and :class:`ComposedFunctional` for + pull-backs along linear operators, with specializations that preserve the + concrete functional type when possible. + +Linear algebra +~~~~~~~~~~~~~~ + +The :mod:`spacecore.linalg` module is new in 0.2.0. It provides +JIT-compatible iterative solvers and structured result types. + +* Added iterative solvers: + + * :func:`cg` for Hermitian positive-definite systems. + * :func:`lsqr` for rectangular least-squares problems. + * :func:`power_iteration` for dominant-eigenpair estimates of a + :class:`LinOp` or :class:`QuadraticForm`. + * :func:`lanczos_smallest` for smallest-Ritz-eigenpair estimates of + Hermitian operators. + * :func:`expm_multiply` for Krylov matrix-exponential actions + ``exp(t A) v`` on Hermitian operators, with complex ``t`` supported for + Schrodinger-type evolution. + +* Added structured result types :class:`CGResult`, :class:`LSQRResult`, + :class:`PowerIterationResult`, :class:`LanczosResult`, and + :class:`ExpmMultiplyResult`, each carrying convergence diagnostics and a + compact ``__repr__``. +* Solvers are geometry-aware: norms, inner products, and the default initial + vector use ``Space.inner`` and ``Space.norm`` rather than assuming Euclidean + geometry. This makes the solvers correct on custom inner products such as + RKHS or weighted spaces. + +Documentation +~~~~~~~~~~~~~ + +* Reworked public docstrings to numpydoc standard with runnable doctests for + solvers, spaces, operators, functionals, backends, and contextual helpers. +* Clarified solver contracts: ``domain == codomain`` square requirements, + Hermiticity enforcement, tolerance semantics, JAX static arguments, complex + scalar behavior, ill-conditioning caveats, and convergence assumptions. +* Added API reference pages for backend ops, spaces, linear operators, + functionals, and linear algebra. +* Added a JAX integration design note documenting trace-time operator algebra + and recommended JIT usage at + ``docs/source/design/jax_integration.rst``. +* Added tutorials for backend operations, linear operators, and matrix-free + linalg workflows. + +Testing and CI +~~~~~~~~~~~~~~ + +* Added cross-backend tests covering NumPy, JAX, Torch, and optional CuPy. +* Added tests for backend ops delegation, backend loop primitives, CuPy ops, + context resolution, ``checked_method``, functionals, linalg solvers, + operator algebra, batched lifting, dense materialization, diagonal + operators, and JAX pytree/JIT behavior. +* Added CI execution of a JIT-traceability audit script in ``--check`` mode + and a coverage floor of 70% via ``pytest-cov``. +* Added nonblocking documentation lint and audit steps for the docstring + migration. + +Packaging +~~~~~~~~~ + +* Bumped the package version to ``0.2.0``. +* ``spacecore.__version__`` now resolves from package metadata via + ``importlib.metadata`` instead of a hand-maintained constant. +* Added optional dependency groups: ``[jax]``, ``[torch]``, ``[cupy]``, + ``[examples]``, ``[docs]``, ``[dev]``. +* Added an explicit ``__all__`` at the top level covering new backends, + operators, functionals, solvers, result types, validation checks, and + contextual helpers. + +.. _migration-0-2: + +Migration from 0.1.x +~~~~~~~~~~~~~~~~~~~~ + +* ``BackendOps.eps`` is now a method ``eps(dtype)`` rather than a property. + Callers must pass a dtype, typically ``ctx.dtype``. +* The implementation attribute ``DenseLinOp.A`` is now a + :class:`functools.cached_property` backed by ``_A``. The public attribute + access ``op.A`` is unchanged. +* :meth:`LinOp.__eq__` now returns ``NotImplemented`` instead of raising + ``NotImplementedError`` on the base class, so ``op == None`` and + ``op in some_list`` no longer raise. +* Several module-internal helpers in ``spacecore._contextual`` moved to + private modules. Use the public functions re-exported from + :mod:`spacecore._contextual` (``set_context``, ``get_context``, + ``resolve_context_priority``, ``register_ops``, ``set_resolution_policy``, + and the dtype-policy accessors) rather than importing from internal modules. + +Known limitations +~~~~~~~~~~~~~~~~~ + +* :func:`cg`, :func:`lsqr`, and :func:`power_iteration` do not structurally + validate operator properties (positive-definiteness, full Hermiticity) and + may silently produce incorrect results on inputs that violate their + preconditions. See each function's ``Notes`` section for details. +* Operator algebra runs Python-level simplification at construction time. For + maximum JIT efficiency, assemble operator expressions outside the + ``jax.jit`` boundary; see the JAX integration design note. +* :class:`MatrixFreeLinOp` stores its callables in pytree auxiliary data. + Constructing one inside a JIT-traced function with a new lambda each call + triggers retracing. Construct outside the traced region with a stable + callable reference. +* The CuPy backend is provided as a preview. Coverage of non-standard + operations and sparse handling may evolve in a subsequent release. + Version 0.1.4 ------------- diff --git a/docs/source/tutorials/backend_ops.rst b/docs/source/tutorials/backend_ops.rst index c90a88e..b0f6e83 100644 --- a/docs/source/tutorials/backend_ops.rst +++ b/docs/source/tutorials/backend_ops.rst @@ -5,9 +5,15 @@ This tutorial follows ``tutorials/1_BackendOps.ipynb``. It explains what ``BackendOps`` represents in SpaceCore, how it relates to ``Context``, and how to use the predefined backends. -Current predefined implementations are ``NumpyOps``, ``JaxOps``, and -``TorchOps``. ``TorchOps`` is optional and is available after installing the -PyTorch extra: +Current predefined implementations are ``NumpyOps``, ``JaxOps``, ``CuPyOps``, +and ``TorchOps``. ``CuPyOps`` and ``TorchOps`` are optional and are available +after installing their backend extras: + +.. code-block:: bash + + pip install spacecore[cupy] + +Install PyTorch support with: .. code-block:: bash @@ -31,9 +37,10 @@ SpaceCore separates two concerns: * the numerical backend used to store and compute with them. The same mathematical object may be represented using NumPy arrays for eager CPU -work, JAX arrays for JIT compilation and automatic differentiation, or PyTorch -tensors for eager CPU/CUDA execution and autograd. Without a backend abstraction, spaces and -operators would need backend-specific branches throughout their implementations. +work, CuPy arrays for eager GPU execution, JAX arrays for JIT compilation and +automatic differentiation, or PyTorch tensors for eager CPU/CUDA execution and +autograd. Without a backend abstraction, spaces and operators would need +backend-specific branches throughout their implementations. The design is: @@ -52,6 +59,15 @@ What BackendOps signifies internals. It mostly wraps NumPy-like methods, while normalizing the minimal signatures that SpaceCore relies on. +Common dense-array methods are implemented once in ``BackendOps`` by delegating +to an Array API compatible ``xp`` namespace. NumPy and PyTorch use +``array-api-compat`` wrappers, while CuPy uses ``cupy`` and JAX uses +``jax.numpy``. Concrete backend classes keep behavior that is genuinely +backend-specific, such as dtype sanitation, sparse conversion, indexed updates, +device/autograd controls, and control-flow primitives. ``ops.xp`` is available +as an escape hatch, but portable SpaceCore code should prefer explicit ``ops`` +methods. + For example, NumPy and JAX expose different optional arguments for matrix multiplication, but SpaceCore's portable interface only needs the common core: @@ -139,6 +155,19 @@ Use ``JaxOps`` for the JAX execution model: JIT compilation, automatic differentiation, accelerator execution, and JAX sparse compatibility. JAX dtype behavior depends on local JAX configuration, especially ``jax_enable_x64``. +Use ``CuPyOps`` for eager GPU-backed CuPy arrays and ``cupyx.scipy.sparse`` +matrices. The backend is optional and is exported only when CuPy is installed. +It follows CuPy's NumPy-compatible dtype behavior and keeps arrays on the CuPy +device where they were created. + +.. code-block:: python + + import numpy as np + from spacecore.backend import Context, CuPyOps + + ctx_cupy = Context(CuPyOps(), dtype=np.float64) + x = ctx_cupy.asarray([1.0, 2.0, 3.0]) + Use ``TorchOps`` for PyTorch tensors. The backend can be requested by either ``"torch"`` or ``"pytorch"`` where SpaceCore accepts backend names. diff --git a/docs/source/tutorials/linops.rst b/docs/source/tutorials/linops.rst index 72d73c4..2005b20 100644 --- a/docs/source/tutorials/linops.rst +++ b/docs/source/tutorials/linops.rst @@ -7,7 +7,14 @@ the abstraction for linear maps between spaces. Current implemented operator types are: * ``DenseLinOp`` +* ``DiagonalLinOp`` * ``SparseLinOp`` +* ``MatrixFreeLinOp`` +* ``IdentityLinOp`` +* ``ZeroLinOp`` +* ``ScaledLinOp`` +* ``SumLinOp`` +* ``ComposedLinOp`` * ``BlockDiagonalLinOp`` * ``StackedLinOp`` * ``SumToSingleLinOp`` @@ -60,6 +67,38 @@ operator :math:`A^* : Y \to X`, satisfying \langle Ax, y\rangle_Y = \langle x, A^*y\rangle_X. +Batched lifting +--------------- + +For a batch of elements, use ``Space.batch`` to describe the batched space and +``vapply`` or ``rvapply`` to lift the operator: + +.. math:: + + A : X \to Y, + \qquad + A^{(B)} : X^B \to Y^B. + +.. code-block:: python + + B = 8 + XB = X.batch(batch_shape=(B,), batch_axes=(0,)) + YB = Y.batch(batch_shape=(B,), batch_axes=(0,)) + + xs = ctx.asarray(np.ones((B,) + X.shape)) + ys = op.vapply(xs, batch_space=XB) + xs_back = op.rvapply(ys, batch_space=YB) + +This is equivalent to stacking scalar applications: + +.. code-block:: python + + ys_ref = ctx.ops.stack(tuple(op.apply(x) for x in xs), axis=0) + +The base fallback uses backend ``vmap``. Structured operators override this +path when they can use matrix multiplication, sparse multi-vector products, +broadcasting, or componentwise product-space batching. + DenseLinOp ---------- @@ -107,6 +146,42 @@ of the operator structure. op_sparse = sc.SparseLinOp(A_sparse, X, Y, ctx=ctx) +MatrixFreeLinOp +--------------- + +``MatrixFreeLinOp`` stores callables for forward and adjoint actions instead +of matrix entries. Use it when a linear map has a fast procedural +implementation or when materializing a matrix is too expensive. + +.. code-block:: python + + def apply(x): + return ctx.asarray([x[0] + x[1], x[0] - x[1]]) + + def rapply(y): + return ctx.asarray([y[0] + y[1], y[0] - y[1]]) + + op_free = sc.MatrixFreeLinOp(apply, rapply, X, X, ctx=ctx) + +Canonical and algebraic operators +--------------------------------- + +``IdentityLinOp`` and ``ZeroLinOp`` represent the canonical identity and zero +maps on spaces. Operator algebra creates lazy operators without immediately +materializing dense storage: + +.. code-block:: python + + I = sc.IdentityLinOp(X, ctx=ctx) + Z = sc.ZeroLinOp(X, Y, ctx=ctx) + + scaled = 2.0 * I # ScaledLinOp + summed = I + scaled # SumLinOp + composed = summed @ I # ComposedLinOp + +The helper constructors ``make_scaled``, ``make_sum``, and ``make_composed`` +perform the same simplifications used by the Python operators. + Product operators ----------------- diff --git a/pyproject.toml b/pyproject.toml index 198b1b6..0f531f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "spacecore" -version = "0.1.4" +version = "0.2.0" description = "Backend-agnostic vector spaces and linear operators." readme = "README.md" requires-python = ">=3.11" @@ -14,6 +14,7 @@ authors = [ { name = "Pavlo Pelikh" } ] dependencies = [ + "array-api-compat>=1.14.0", "numpy>=2.0.0", "scipy>=1.17", ] @@ -43,6 +44,9 @@ jax = [ torch = [ "torch>=2.0", ] +cupy = [ + "cupy>=13.0", +] examples = [ "matplotlib>=3.8", "optax>=0.2", @@ -56,7 +60,9 @@ docs = [ ] dev = [ "pytest>=8.0", - "ruff>=0.6", + "pytest-cov>=5", + "ruff>=0.6", + "numpydoc>=1.7", ] [tool.setuptools] @@ -67,12 +73,45 @@ where = ["."] include = ["spacecore*"] [tool.pytest.ini_options] -testpaths = ["tests"] +addopts = [ + "--doctest-modules", + "--ignore-glob=spacecore/_contextual/*", + "--ignore-glob=spacecore/backend/cupy/*", + "--ignore-glob=spacecore/backend/torch/*", +] +testpaths = ["spacecore", "tests"] +doctest_optionflags = ["NORMALIZE_WHITESPACE", "ELLIPSIS", "NUMBER"] [tool.ruff] line-length = 100 target-version = "py311" +[tool.ruff.lint] +ignore = [ + "D100", + "D104", + "D203", + "D213", +] + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + [tool.ruff.lint.per-file-ignores] "tutorials/*.ipynb" = ["E741"] "tests/*.py" = ["E701"] +"tests/*" = ["D"] + +[tool.numpydoc_validation] +checks = [ + "all", + "ES01", + "EX01", + "SA01", + "GL08", +] +exclude = [ + '\._', + '\.tests\.', + 'test_', +] diff --git a/scripts/docstring_audit.py b/scripts/docstring_audit.py new file mode 100644 index 0000000..e83916d --- /dev/null +++ b/scripts/docstring_audit.py @@ -0,0 +1,95 @@ +"""Report numpydoc validation issues for SpaceCore's public API.""" + +from __future__ import annotations + +import argparse +import inspect +from collections.abc import Iterable +from dataclasses import dataclass + +from numpydoc.validate import validate + +import spacecore + +ALLOWED_CODES = frozenset({"ES01", "EX01", "SA01", "GL08"}) + + +@dataclass(frozen=True) +class ValidationIssue: + """A single numpydoc validation issue.""" + + target: str + code: str + message: str + + +def _iter_public_targets() -> Iterable[str]: + """Yield public import paths exported by the top-level package.""" + for name in getattr(spacecore, "__all__", ()): + if name.startswith("_"): + continue + target = f"spacecore.{name}" + try: + obj = getattr(spacecore, name) + except AttributeError: + continue + if inspect.ismodule(obj): + continue + yield target + + +def _validate_target(target: str, *, include_allowed: bool) -> list[ValidationIssue]: + """Validate one import path and normalize numpydoc's result shape.""" + try: + result = validate(target) + except Exception as exc: # pragma: no cover - defensive reporting path + return [ValidationIssue(target, "IMPORT", f"{type(exc).__name__}: {exc}")] + + issues = [] + for code, message in result.get("errors", []): + if not include_allowed and code in ALLOWED_CODES: + continue + issues.append(ValidationIssue(target, code, message)) + return issues + + +def collect_issues(*, include_allowed: bool = False) -> list[ValidationIssue]: + """Collect numpydoc issues for exported public symbols.""" + issues: list[ValidationIssue] = [] + for target in sorted(set(_iter_public_targets())): + issues.extend(_validate_target(target, include_allowed=include_allowed)) + return issues + + +def main() -> int: + """Run the audit command.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--check", + action="store_true", + help="exit non-zero when validation issues are present", + ) + parser.add_argument( + "--max-lines", + type=int, + default=200, + help="maximum number of individual issues to print", + ) + parser.add_argument( + "--include-allowed", + action="store_true", + help="include issues allowed during the migration baseline", + ) + args = parser.parse_args() + + issues = collect_issues(include_allowed=args.include_allowed) + for issue in issues[: args.max_lines]: + print(f"{issue.target}:{issue.code}:{issue.message}") + if len(issues) > args.max_lines: + print(f"... {len(issues) - args.max_lines} more issues omitted") + print(f"numpydoc issues: {len(issues)}") + return 1 if args.check and issues else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/jit_audit.py b/scripts/jit_audit.py new file mode 100644 index 0000000..a822b2c --- /dev/null +++ b/scripts/jit_audit.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any, Callable + +import numpy as np + + +ROOT = Path(__file__).resolve().parents[1] +FIXTURE = ROOT / "tests" / "fixtures" / "jaxpr_lanczos_smallest.txt" + + +def _ctx(): + import jax + import spacecore as sc + + dtype = np.float64 if jax.config.read("jax_enable_x64") else np.float32 + return sc.Context(sc.JaxOps(), dtype=dtype, enable_checks=False) + + +def _spd_operator(n: int): + import spacecore as sc + + ctx = _ctx() + space = sc.VectorSpace((n,), ctx) + matrix = np.diag(np.arange(2.0, 2.0 + n)) + matrix += 0.05 * np.ones((n, n)) + return sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + +def _rect_operator(): + import spacecore as sc + + ctx = _ctx() + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + matrix = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, -1.0]]) + return sc.DenseLinOp(ctx.asarray(matrix), domain, codomain, ctx) + + +def _same_shape_inputs(A: Any) -> tuple[Any, Any]: + ctx = A.ctx + x0 = ctx.asarray(np.linspace(1.0, 2.0, A.domain.shape[0])) + x1 = ctx.asarray(np.linspace(2.0, 3.0, A.domain.shape[0])) + return x0, x1 + + +def _audit_solver( + name: str, + fn_factory: Callable[[dict[str, int]], Callable[..., Any]], + A: Any, + first_rhs: Any, + second_rhs: Any, + shape_changed_A: Any, + shape_changed_rhs: Any, + static_name: str, +) -> dict[str, Any]: + import jax + + traces = {"count": 0} + fn = fn_factory(traces) + jitted = jax.jit(fn, static_argnames=(static_name,)) + + out0 = jitted(A, first_rhs, **{static_name: 4}) + out1 = jitted(A, second_rhs, **{static_name: 4}) + same_shape_traces = traces["count"] + + out2 = jitted(A, first_rhs, **{static_name: 5}) + static_changed_traces = traces["count"] + + out3 = jitted(shape_changed_A, shape_changed_rhs, **{static_name: 4}) + shape_changed_traces = traces["count"] + + for out in (out0, out1, out2, out3): + jax.block_until_ready(out) + + return { + "solver": name, + "traces_after_two_same_shape_calls": same_shape_traces, + "traces_after_static_change": static_changed_traces, + "traces_after_shape_change": shape_changed_traces, + "stable_values_retraced": same_shape_traces > 1, + "static_change_retraced": static_changed_traces > same_shape_traces, + "shape_change_retraced": shape_changed_traces > static_changed_traces, + } + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Audit JAX trace stability for SpaceCore solvers.") + parser.add_argument( + "--check", + action="store_true", + help="Exit non-zero if any solver retraces on same-shape inputs.", + ) + parser.add_argument( + "--log-compiles", + action="store_true", + help="Enable jax_log_compiles for manual inspection.", + ) + parser.add_argument( + "--write-fixture", + action="store_true", + help="Write the lanczos_smallest JAXPR fixture. Disabled by default in --check mode.", + ) + return parser.parse_args() + + +def _audit_failed(item: dict[str, Any]) -> bool: + if "status" in item: + return True + return ( + item["traces_after_two_same_shape_calls"] != 1 + or item["stable_values_retraced"] + or not item["static_change_retraced"] + or not item["shape_change_retraced"] + ) + + +def main() -> None: + import jax + import spacecore as sc + + args = _parse_args() + if args.log_compiles: + jax.config.update("jax_log_compiles", True) + + A2 = _spd_operator(2) + A3 = _spd_operator(3) + x2a, x2b = _same_shape_inputs(A2) + x3a, _ = _same_shape_inputs(A3) + R2 = _rect_operator() + R3 = sc.DenseLinOp( + _ctx().asarray([[1.0, 0.0, 0.5], [0.0, 2.0, -1.0], [1.0, -1.0, 0.25], [0.5, 0.0, 1.0]]), + sc.VectorSpace((3,), _ctx()), + sc.VectorSpace((4,), _ctx()), + _ctx(), + ) + b2a = R2.codomain.ctx.asarray([1.0, 2.0, 3.0]) + b2b = R2.codomain.ctx.asarray([3.0, 2.0, 1.0]) + b4 = R3.codomain.ctx.asarray([1.0, 2.0, 3.0, 4.0]) + + audits = [ + _audit_solver( + "cg", + lambda traces: ( + lambda A, b, maxiter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.cg(A, b, maxiter=maxiter).x + ) + ), + A2, + x2a, + x2b, + A3, + x3a, + "maxiter", + ), + _audit_solver( + "lsqr", + lambda traces: ( + lambda A, b, maxiter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.lsqr(A, b, maxiter=maxiter).x + ) + ), + R2, + b2a, + b2b, + R3, + b4, + "maxiter", + ), + _audit_solver( + "lanczos_smallest", + lambda traces: ( + lambda A, x, max_iter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.lanczos_smallest(A, x, max_iter=max_iter).eigenvalue + ) + ), + A2, + x2a, + x2b, + A3, + x3a, + "max_iter", + ), + _audit_solver( + "power_iteration", + lambda traces: ( + lambda A, x, maxiter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.power_iteration(A, x0=x, maxiter=maxiter).eigenvalue + ) + ), + A2, + x2a, + x2b, + A3, + x3a, + "maxiter", + ), + ] + + if hasattr(sc, "expm_multiply"): + audits.append( + _audit_solver( + "expm_multiply", + lambda traces: ( + lambda A, x, max_iter: ( + traces.__setitem__("count", traces["count"] + 1) + or sc.expm_multiply(A, x, max_iter=max_iter).result + ) + ), + A2, + x2a, + x2b, + A3, + x3a, + "max_iter", + ) + ) + else: + audits.append({"solver": "expm_multiply", "status": "not available before Task 1"}) + + print("JIT audit summary") + for item in audits: + print(item) + + if args.write_fixture or not args.check: + FIXTURE.parent.mkdir(parents=True, exist_ok=True) + jaxpr = jax.make_jaxpr( + lambda A, x: sc.lanczos_smallest(A, x, max_iter=3, check_every=1).eigenvalue + )(A2, x2a) + FIXTURE.write_text(str(jaxpr)) + print(f"wrote {FIXTURE.relative_to(ROOT)}") + + if args.check: + failures = [item for item in audits if _audit_failed(item)] + if failures: + print("JIT audit check failed") + for item in failures: + print(item) + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/spacecore/__init__.py b/spacecore/__init__.py index d130500..d94cd90 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -1,13 +1,64 @@ -__version__ = "0.1.4" +"""Backend-agnostic vector spaces, linear operators, and solvers.""" + +from importlib.metadata import version as _version + +try: + __version__ = _version("spacecore") +except Exception: + __version__ = "0.0.0+unknown" from .backend import Context, BackendOps, JaxOps, NumpyOps, jax_pytree_class +try: + from .backend import CuPyOps as CuPyOps +except ImportError: + pass try: from .backend import TorchOps as TorchOps except ImportError: pass -from .linop import DenseLinOp, SparseLinOp, BlockDiagonalLinOp, SumToSingleLinOp, StackedLinOp, LinOp +from .linop import ( + BlockDiagonalLinOp, + ComposedLinOp, + DiagonalLinOp, + DenseLinOp, + IdentityLinOp, + LinOp, + MatrixFreeLinOp, + ScaledLinOp, + SparseLinOp, + StackedLinOp, + SumLinOp, + SumToSingleLinOp, + ZeroLinOp, + make_composed, + make_scaled, + make_sum, +) +from .functional import ( + ComposedFunctional, + Functional, + InnerProductFunctional, + LinearFunctional, + LinOpQuadraticForm, + MatrixFreeLinearFunctional, + QuadraticForm, + make_functional_composed, +) +from .linalg import ( + CGResult, + ExpmMultiplyResult, + LanczosResult, + LSQRResult, + PowerIterationResult, + cg, + expm_multiply, + lanczos_smallest, + lsqr, + power_iteration, +) from .space import ( + BatchSpace, BackendCheck, DTypeCheck, HermitianCheck, @@ -24,13 +75,15 @@ ) from .types import DenseArray, SparseArray, ArrayLike -from ._contextual import ContextBound -from ._contextual.manager import ( +from ._checks import checked_method +from ._contextual import ( + ContextBound, set_context, get_context, resolve_context_priority, register_ops, set_resolution_policy, set_dtype_resolution_policy, - get_resolution_policy, get_dtype_resolution_policy + get_resolution_policy, get_dtype_resolution_policy, + normalize_ops, normalize_context, ) __all__ = [ @@ -42,18 +95,49 @@ "NumpyOps", "LinOp", + "ComposedLinOp", + "DiagonalLinOp", "DenseLinOp", + "IdentityLinOp", + "MatrixFreeLinOp", + "ScaledLinOp", "SparseLinOp", + "SumLinOp", + "ZeroLinOp", + "make_composed", + "make_scaled", + "make_sum", "BlockDiagonalLinOp", "SumToSingleLinOp", "StackedLinOp", + "ComposedFunctional", + "Functional", + "LinearFunctional", + "InnerProductFunctional", + "MatrixFreeLinearFunctional", + "QuadraticForm", + "LinOpQuadraticForm", + "make_functional_composed", + + "CGResult", + "ExpmMultiplyResult", + "LanczosResult", + "LSQRResult", + "PowerIterationResult", + "cg", + "expm_multiply", + "lanczos_smallest", + "lsqr", + "power_iteration", + "BackendCheck", "DTypeCheck", "HermitianCheck", "ProductComponentCheck", "ProductStructureCheck", "ShapeCheck", + "BatchSpace", "VectorSpace", "HermitianSpace", "ProductSpace", @@ -66,6 +150,7 @@ "SparseArray", "ArrayLike", + "checked_method", "ContextBound", "set_context", "get_context", @@ -75,7 +160,11 @@ "set_dtype_resolution_policy", "get_resolution_policy", "get_dtype_resolution_policy", + "normalize_ops", + "normalize_context", ] if "TorchOps" in globals(): __all__.append("TorchOps") +if "CuPyOps" in globals(): + __all__.append("CuPyOps") diff --git a/spacecore/_checks.py b/spacecore/_checks.py new file mode 100644 index 0000000..09ba022 --- /dev/null +++ b/spacecore/_checks.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from functools import wraps +from typing import Any, Callable + + +def _as_positions( + arg_pos: int | None, + arg_positions: int | tuple[int, ...] | None, +) -> tuple[int, ...]: + """Normalize legacy and multi-position argument selectors.""" + if arg_pos is not None and arg_positions is not None: + raise TypeError("Use either arg_pos or arg_positions, not both.") + if arg_positions is None: + return (0,) if arg_pos is None else (arg_pos,) + if isinstance(arg_positions, int): + return (arg_positions,) + return tuple(arg_positions) + + +def _space_target(self: Any, space_name: str) -> Any: + """Return the space object named by ``space_name``.""" + return self if space_name == "self" else getattr(self, space_name) + + +def checked_method( + *, + in_space: str | None = None, + out_space: str | None = None, + arg_pos: int | None = None, + arg_positions: int | tuple[int, ...] | None = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Build a decorator that validates method inputs and outputs against spaces. + + Parameters + ---------- + in_space : str or None, optional + Name of the attribute on ``self`` containing the input + :class:`~spacecore.space.Space`, ``"self"`` to validate against the + receiver itself, or ``None`` to skip input validation. + out_space : str or None, optional + Name of the attribute on ``self`` containing the output + :class:`~spacecore.space.Space`, ``"self"`` to validate against the + receiver itself, or ``None`` to skip output validation. + arg_pos : int or None, optional + Deprecated alias for a single entry in ``arg_positions``. + arg_positions : int, tuple of int, or None, optional + Zero-based positions in ``*args`` of input values that should be checked + against ``in_space``. Defaults to ``(0,)``. + + Returns + ------- + Callable[[Callable[..., Any]], Callable[..., Any]] + Decorator that wraps a method, performs Python-level checks when + ``self._enable_checks`` is true, and otherwise forwards directly to the + wrapped method. + """ + positions = _as_positions(arg_pos, arg_positions) + + def decorate(method: Callable[..., Any]) -> Callable[..., Any]: + @wraps(method) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + if self._enable_checks and in_space is not None: + check_target = _space_target(self, in_space) + for pos in positions: + check_target._check_member(args[pos]) + + y = method(self, *args, **kwargs) + + if self._enable_checks and out_space is not None: + _space_target(self, out_space)._check_member(y) + + return y + + return wrapper + + return decorate diff --git a/spacecore/_contextual/__init__.py b/spacecore/_contextual/__init__.py index 45cebdc..5d319e1 100644 --- a/spacecore/_contextual/__init__.py +++ b/spacecore/_contextual/__init__.py @@ -1,11 +1,45 @@ -from .bound import ContextBound as ContextBound -from .manager import ( - ctx_manager as ctx_manager, - set_context as set_context, - resolve_context_priority as resolve_context_priority, +from ._bound import ContextBound as ContextBound +from ._manager import ( + enforce_convert_policy as enforce_convert_policy, + get_context as get_context, + get_dtype_resolution_policy as get_dtype_resolution_policy, + get_resolution_policy as get_resolution_policy, + normalize_context as normalize_context, + normalize_ops as normalize_ops, register_ops as register_ops, - set_resolution_policy as set_resolution_policy, + resolve_context_priority as resolve_context_priority, + set_context as set_context, set_dtype_resolution_policy as set_dtype_resolution_policy, - get_resolution_policy as get_resolution_policy, - get_dtype_resolution_policy as get_dtype_resolution_policy, + set_resolution_policy as set_resolution_policy, +) +from ._policies import ( + ContextConflictError as ContextConflictError, + ContextConversionError as ContextConversionError, + ContextError as ContextError, + ContextInferenceError as ContextInferenceError, + ContextPolicy as ContextPolicy, + DtypePreservePolicy as DtypePreservePolicy, + UnknownBackendError as UnknownBackendError, ) + +__all__ = [ + "ContextBound", + "ContextConflictError", + "ContextConversionError", + "ContextError", + "ContextInferenceError", + "ContextPolicy", + "DtypePreservePolicy", + "UnknownBackendError", + "enforce_convert_policy", + "get_context", + "get_dtype_resolution_policy", + "get_resolution_policy", + "normalize_context", + "normalize_ops", + "register_ops", + "resolve_context_priority", + "set_context", + "set_dtype_resolution_policy", + "set_resolution_policy", +] diff --git a/spacecore/_contextual/_bound.py b/spacecore/_contextual/_bound.py new file mode 100644 index 0000000..3a21bad --- /dev/null +++ b/spacecore/_contextual/_bound.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Self + +from ..types import DType +from ._manager import enforce_convert_policy, normalize_context + +if TYPE_CHECKING: + from ..backend import BackendFamily, BackendOps, Context + + +def _same_context_for_conversion(left: Context, right: Context) -> bool: + """ + Compare contexts for conversion equivalence. + + Parameters + ---------- + left : Context + First context to compare. + right : Context + Second context to compare. + + Returns + ------- + bool + ``True`` when both contexts use the same backend operations, dtype, and + check policy; otherwise ``False``. + + Notes + ----- + This predicate is used by ``convert()`` and intentionally includes + ``enable_checks`` because a converted object with different runtime checks + is operationally different. + """ + return ( + left.ops == right.ops + and left.dtype == right.dtype + and left.enable_checks == right.enable_checks + ) + + +def _same_context_for_algebra(left: Context, right: Context) -> bool: + """ + Compare contexts for algebraic compatibility. + + Parameters + ---------- + left : Context + First context to compare. + right : Context + Second context to compare. + + Returns + ------- + bool + ``True`` when both contexts use the same backend operations and dtype. + + Notes + ----- + Algebraic combinators ignore ``enable_checks`` because validation policy is + operational, not mathematical. + """ + return left.ops == right.ops and left.dtype == right.dtype + + +class ContextBound(ABC): + """ + Base class for objects bound to a SpaceCore execution context. + + ``ContextBound`` normalizes and stores a :class:`~spacecore.backend.Context` + for subclasses such as spaces, linear operators, and functionals. It also + provides convenience access to the context's backend operations and dtype, + plus a common ``convert`` workflow that respects the global context + conversion policy. + + Subclasses that own backend arrays or nested context-bound objects must + implement :meth:`_convert` to rebuild themselves in a target context. + + Parameters + ---------- + ctx : Context, str, or None, optional + Context specification passed to :meth:`__init__`. This may be a + concrete :class:`~spacecore.backend.Context`, a backend-family string, + or ``None`` to use the current default context. + + Returns + ------- + ContextBound + A context-aware object whose concrete type is provided by a subclass. + """ + + def __init__(self, ctx: Context | str | None = None): + """ + Initialize this object with a normalized context. + + Parameters + ---------- + ctx : Context, str, or None, optional + Context specification for the object. This may be a concrete + :class:`~spacecore.backend.Context`, a backend-family string, or + ``None`` to use the current default context. + + Returns + ------- + None + The initializer stores the normalized context on ``self``. + """ + ctx = normalize_context(ctx) + self._ctx = ctx + + @property + def ops(self) -> BackendOps: + """ + Return backend operations associated with this object's context. + + Parameters + ---------- + None + + Returns + ------- + BackendOps + Backend operation object used by this instance. + """ + return self.ctx.ops + + @property + def dtype(self) -> DType: + """ + Return the default dtype associated with this object's context. + + Parameters + ---------- + None + + Returns + ------- + DType + Backend-normalized dtype stored in the bound context. + """ + return self.ctx.dtype + + @property + def ctx(self) -> Context: + """ + Return the execution context bound to this object. + + Parameters + ---------- + None + + Returns + ------- + Context + Context that controls backend operations, dtype, and validation + policy for this instance. + """ + return self._ctx + + def _convert(self, new_ctx: Context) -> Self: + """ + Rebuild this object in ``new_ctx``. + + Subclasses implement this hook with their concrete conversion logic. + The public :meth:`convert` method handles policy enforcement and skips + conversion when the target context is effectively identical. + + Parameters + ---------- + new_ctx : Context + Concrete target context in which the subclass should rebuild its + owned arrays, spaces, operators, or nested context-bound objects. + + Returns + ------- + Self + New object of the subclass type represented in ``new_ctx``. + """ + raise NotImplementedError() + + def convert(self, new_ctx: Context | BackendFamily | str | None = None) -> Self: + """ + Return this object represented in ``new_ctx``. + + Parameters + ---------- + new_ctx : Context, BackendFamily, str, or None, optional + Target context specification. ``None`` resolves according to the + current conversion policy and default context. + + Returns + ------- + Self + ``self`` when no effective context change is needed; otherwise a + converted object produced by :meth:`_convert`. + """ + _, new_ctx = enforce_convert_policy(self, new_ctx) + if _same_context_for_conversion(self.ctx, new_ctx): + return self + return self._convert(new_ctx) diff --git a/spacecore/_contextual/manager.py b/spacecore/_contextual/_manager.py similarity index 63% rename from spacecore/_contextual/manager.py rename to spacecore/_contextual/_manager.py index f655cb8..a9b3232 100644 --- a/spacecore/_contextual/manager.py +++ b/spacecore/_contextual/_manager.py @@ -1,11 +1,27 @@ -from typing import Any +from __future__ import annotations -from ..backend import Context, BackendOps -from .contextual import Contextual, ContextPolicy, DtypePreservePolicy -from ..backend import BackendFamily +from typing import TYPE_CHECKING, Any +from ..backend._family import BackendFamily +from ..backend._ops import BackendOps +from ._policies import ContextPolicy, DtypePreservePolicy -ctx_manager = Contextual() +if TYPE_CHECKING: + from ..backend._context import Context + + +_cached_state = None + + +def _state(): + """Return the cached contextual singleton.""" + global _cached_state + if _cached_state is not None: + return _cached_state + from ._state import _contextual + + _cached_state = _contextual + return _cached_state def set_context( @@ -18,14 +34,14 @@ def set_context( Parameters ---------- - ctx: + ctx : Context, BackendFamily, str, or None, optional Context specification to make default. This may be a concrete :class:`spacecore.backend.Context`, a backend family enum, a backend family string such as ``"numpy"`` or ``"jax"``, or ``None``. - dtype: + dtype : dtype-like, optional Optional dtype used when ``ctx`` is a backend family string or enum. Ignored when ``ctx`` is ``None`` or already a concrete ``Context``. - enable_checks: + enable_checks : bool or None, optional Optional validation flag used when constructing a context from a backend family. Ignored when ``ctx`` is ``None`` or already a concrete ``Context``. @@ -35,8 +51,8 @@ def set_context( Objects created without an explicit context use this default context. Existing spaces, operators, and contexts are not modified. """ - ctx = ctx_manager.normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) - ctx_manager.default_ctx = ctx + ctx = _state().normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) + _state().default_ctx = ctx def get_context() -> Context: @@ -49,7 +65,7 @@ def get_context() -> Context: The default context used by constructors when no explicit context can be inferred or provided. """ - return ctx_manager.default_ctx + return _state().default_ctx def resolve_context_priority( @@ -61,10 +77,10 @@ def resolve_context_priority( Parameters ---------- - priority_ctx: + priority_ctx : Context, BackendFamily, str, or None, optional Explicit context supplied by the caller. If this is not ``None``, it wins over every inferred context. - *other_ctx: + *other_ctx : object Objects that may carry a ``ctx`` attribute or be backend-native arrays. These are used for context inference when no explicit context is supplied. @@ -80,7 +96,7 @@ def resolve_context_priority( User code should call this function instead of accessing the internal context manager singleton. """ - return ctx_manager.resolve_context_priority(priority_ctx, *other_ctx) + return _state().resolve_context_priority(priority_ctx, *other_ctx) def register_ops(ops: type[BackendOps]) -> type[BackendOps]: @@ -89,7 +105,7 @@ def register_ops(ops: type[BackendOps]) -> type[BackendOps]: Parameters ---------- - ops: + ops : type[BackendOps] Backend operations class to register. It must be a subclass of :class:`spacecore.backend.BackendOps` and define a unique backend family key. @@ -115,7 +131,61 @@ def register_ops(ops: type[BackendOps]) -> type[BackendOps]: class MyOps(BackendOps): ... """ - return ctx_manager.register_ops(ops) + return _state().register_ops(ops) + + +def normalize_context( + ctx: Context | BackendFamily | str | None = None, + dtype: Any = None, + enable_checks: bool | None = None, +) -> Context: + """ + Normalize a context specification through the process-wide state. + + Parameters + ---------- + ctx : Context, BackendFamily, str, or None, optional + Context specification to normalize. + dtype : dtype-like, optional + Optional dtype used when constructing a context from backend family. + enable_checks : bool or None, optional + Optional validation flag. + + Returns + ------- + Context + Normalized context. + """ + return _state().normalize_context(ctx, dtype=dtype, enable_checks=enable_checks) + + +def normalize_ops( + ops: str | BackendFamily | BackendOps | type[BackendOps] | Context +) -> BackendOps: + """ + Normalize backend operations through the process-wide state. + + Parameters + ---------- + ops : str, BackendFamily, BackendOps, type[BackendOps], or Context + Backend operations specification. + + Returns + ------- + BackendOps + Backend operations instance. + """ + if isinstance(ops, BackendOps): + return ops + return _state().get_ops(ops) + + +def enforce_convert_policy( + x: Any, + to: Context | BackendFamily | str | None = None, +) -> tuple[Any, Context]: + """Resolve a conversion target and enforce the configured policy.""" + return _state().enforce_convert_policy(x, to) def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None: @@ -124,7 +194,7 @@ def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None: Parameters ---------- - policy: + policy : ContextPolicy, str, or None, optional Conversion policy to use. Accepted values are ``"warning"``, ``"error"``, ``"silent"``, matching :class:`ContextPolicy`, or ``None`` to restore the default policy. @@ -141,7 +211,7 @@ def set_resolution_policy(policy: ContextPolicy | str | None = None) -> None: * ``"error"``: reject backend conversion. * ``"silent"``: allow backend conversion without warning. """ - ctx_manager.resolution_policy = policy + _state().resolution_policy = policy def get_resolution_policy() -> str: @@ -153,7 +223,7 @@ def get_resolution_policy() -> str: str Policy name, one of ``"warning"``, ``"error"``, or ``"silent"``. """ - return ctx_manager.resolution_policy.value + return _state().resolution_policy.value def set_dtype_resolution_policy( @@ -164,7 +234,7 @@ def set_dtype_resolution_policy( Parameters ---------- - policy: + policy : DtypePreservePolicy, str, or None, optional Dtype policy to use. Accepted values are ``"keep_native"`` and ``"convert"``, matching :class:`DtypePreservePolicy`, or ``None`` to restore the default policy. @@ -181,7 +251,7 @@ def set_dtype_resolution_policy( equivalent dtype in the target backend. * ``"convert"``: use the dtype provided by the resolved target context. """ - ctx_manager.dtype_resolution_policy = policy + _state().dtype_resolution_policy = policy def get_dtype_resolution_policy() -> str: @@ -193,4 +263,4 @@ def get_dtype_resolution_policy() -> str: str Policy name, one of ``"keep_native"`` or ``"convert"``. """ - return ctx_manager.dtype_resolution_policy.value + return _state().dtype_resolution_policy.value diff --git a/spacecore/_contextual/_policies.py b/spacecore/_contextual/_policies.py new file mode 100644 index 0000000..3bf21f2 --- /dev/null +++ b/spacecore/_contextual/_policies.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from enum import StrEnum, auto + + +class ContextPolicy(StrEnum): + """ + Policy for backend-incompatible context conversion. + + Values + ------ + warning: + Allow conversion to a different backend family and issue a warning. + This is the default. + error: + Reject conversion to a different backend family. Use this when + accidental backend migration should be forbidden. + silent: + Allow conversion to a different backend family without warning. Use + this when automatic conversion is expected and controlled. + """ + + warning = auto() + error = auto() + silent = auto() + + +class DtypePreservePolicy(StrEnum): + """ + Policy for dtype handling during context conversion. + + Values + ------ + keep_native: + Preserve the source object's dtype where possible by converting it to an + equivalent dtype in the target backend. This is the default. + convert: + Use the dtype provided by the resolved target context. This prioritizes + dtype unification under the target context. + """ + + keep_native = auto() + convert = auto() + + +class ContextError(RuntimeError): + pass + + +class ContextInferenceError(ContextError): + pass + + +class ContextConflictError(ContextError): + pass + + +class UnknownBackendError(ContextError): + pass + + +class ContextConversionError(ContextError): + pass diff --git a/spacecore/_contextual/contextual.py b/spacecore/_contextual/_state.py similarity index 92% rename from spacecore/_contextual/contextual.py rename to spacecore/_contextual/_state.py index 67efe45..438ab11 100644 --- a/spacecore/_contextual/contextual.py +++ b/spacecore/_contextual/_state.py @@ -1,79 +1,31 @@ from __future__ import annotations from typing import Dict, Any, Iterable, Tuple -from enum import StrEnum, auto from warnings import warn from ..types import DType from ..backend import Context, NumpyOps, JaxOps, BackendFamily, BackendOps +from ._policies import ( + ContextConflictError, + ContextConversionError, + ContextInferenceError, + ContextPolicy, + DtypePreservePolicy, + UnknownBackendError, +) try: - from ..backend import TorchOps + from ..backend import CuPyOps except ImportError: pass - - -class ContextPolicy(StrEnum): - """ - Policy for backend-incompatible context conversion. - - Values - ------ - warning: - Allow conversion to a different backend family and issue a warning. - This is the default. - error: - Reject conversion to a different backend family. Use this when - accidental backend migration should be forbidden. - silent: - Allow conversion to a different backend family without warning. Use - this when automatic conversion is expected and controlled. - """ - - warning = auto() - error = auto() - silent = auto() - -class DtypePreservePolicy(StrEnum): - """ - Policy for dtype handling during context conversion. - - Values - ------ - keep_native: - Preserve the source object's dtype where possible by converting it to an - equivalent dtype in the target backend. This is the default. - convert: - Use the dtype provided by the resolved target context. This prioritizes - dtype unification under the target context. - """ - - keep_native = auto() - convert = auto() - - -class ContextError(RuntimeError): - pass - - -class ContextInferenceError(ContextError): - pass - - -class ContextConflictError(ContextError): - pass - - -class UnknownBackendError(ContextError): - pass - -class ContextConversionError(ContextError): +try: + from ..backend import TorchOps +except ImportError: pass class Contextual: - """ - Backend resolver. - """ + """Resolve contexts, backend registrations, and conversion policies.""" + _default_ctx: Context _available_ops: Dict[str, type[BackendOps]] _resolution_policy: ContextPolicy @@ -98,6 +50,8 @@ def __init__(self, self._backend_key(NumpyOps): NumpyOps, self._backend_key(JaxOps): JaxOps, } + if "CuPyOps" in globals(): + self._available_ops[self._backend_key(CuPyOps)] = CuPyOps if "TorchOps" in globals(): self._available_ops[self._backend_key(TorchOps)] = TorchOps @@ -553,3 +507,6 @@ def _join_dtypes(self, ops: BackendOps, *dtypes: DType | None) -> DType | None: np_ops = NumpyOps() joined = np_ops.np.result_type(*clean) return ops.sanitize_dtype(joined) + + +_contextual = Contextual() diff --git a/spacecore/_contextual/bound.py b/spacecore/_contextual/bound.py deleted file mode 100644 index eea87ba..0000000 --- a/spacecore/_contextual/bound.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import Self - -from ..backend import Context, BackendOps, BackendFamily -from ..types import DType -from .manager import ctx_manager - - -def _same_effective_context(left: Context, right: Context) -> bool: - return ( - left.ops == right.ops - and left.dtype == right.dtype - and left.enable_checks == right.enable_checks - ) - - -class ContextBound(ABC): - def __init__(self, ctx: Context | str | None = None): - ctx = ctx_manager.normalize_context(ctx) - self._ctx = ctx - - @property - def ops(self) -> BackendOps: - return self.ctx.ops - - @property - def dtype(self) -> DType: - return self.ctx.dtype - - @property - def ctx(self) -> Context: - return self._ctx - - def _convert(self, new_ctx: Context) -> Self: - raise NotImplementedError() - - def convert(self, new_ctx: Context | BackendFamily | str | None = None) -> Self: - _, new_ctx = ctx_manager.enforce_convert_policy(self, new_ctx) - if _same_effective_context(self.ctx, new_ctx): - return self - return self._convert(new_ctx) diff --git a/spacecore/backend/__init__.py b/spacecore/backend/__init__.py index f45af32..5ccc2e0 100644 --- a/spacecore/backend/__init__.py +++ b/spacecore/backend/__init__.py @@ -1,8 +1,15 @@ +"""Backend contexts and operation implementations.""" + from ._context import Context from ._ops import BackendOps from ._family import BackendFamily from .jax import JaxOps, jax_pytree_class from .numpy import NumpyOps +try: + from .cupy import CuPyOps as CuPyOps +except ModuleNotFoundError as exc: + if exc.name != "cupy": + raise try: from .torch import TorchOps as TorchOps @@ -19,5 +26,7 @@ "NumpyOps", ] +if "CuPyOps" in globals(): + __all__.append("CuPyOps") if "TorchOps" in globals(): __all__.append("TorchOps") diff --git a/spacecore/backend/_context.py b/spacecore/backend/_context.py index 92b5fd6..4d9bd29 100644 --- a/spacecore/backend/_context.py +++ b/spacecore/backend/_context.py @@ -3,12 +3,13 @@ from ._ops import BackendOps from ..types import DenseArray, SparseArray, DType, ArrayLike +from .._contextual import normalize_ops @dataclass(frozen=True, slots=True) class Context: """ - Backend execution context for SpaceCore objects. + Select backend operations, dtype, and validation policy. A context collects the backend operations object, default dtype, and runtime validation policy used by spaces, linear operators, and context-bound @@ -18,18 +19,27 @@ class Context: Parameters ---------- - ops: + ops : BackendOps Backend operations implementation. This must be an instance of :class:`spacecore.backend.BackendOps`, such as :class:`spacecore.backend.NumpyOps` or :class:`spacecore.backend.JaxOps`. - dtype: + dtype : dtype-like or None, optional Default dtype used by :meth:`asarray` and :meth:`assparse`. The value is normalized through ``ops.sanitize_dtype`` during initialization. - enable_checks: + enable_checks : bool, optional Whether spaces and linear operators using this context should perform membership and compatibility checks before operations. + Attributes + ---------- + ops : BackendOps + Normalized backend operations instance. + dtype : dtype-like + Backend-native dtype used by array constructors. + enable_checks : bool + Runtime validation flag propagated to spaces and operators. + Notes ----- ``Context`` is frozen and slot-based. Methods that convert values return new @@ -37,6 +47,17 @@ class Context: Equality compares backend family and ``enable_checks``. It currently does not compare ``dtype``. + + Examples + -------- + Create a NumPy context and convert a Python list to a backend array. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> x = ctx.asarray([1.0, 2.0]) + >>> x.dtype == np.dtype("float64") + True """ ops: BackendOps @@ -52,8 +73,11 @@ def __post_init__(self): TypeError If ``ops`` is not a :class:`BackendOps` instance. """ - if not isinstance(self.ops, BackendOps): - raise TypeError("ops must be a BackendOps") + try: + ops = normalize_ops(self.ops) + except TypeError: + raise TypeError("Unknown ops type.") + object.__setattr__(self, "ops", ops) sanitized = self.ops.sanitize_dtype(self.dtype) object.__setattr__(self, "dtype", sanitized) diff --git a/spacecore/backend/_family.py b/spacecore/backend/_family.py index 90a1183..17de281 100644 --- a/spacecore/backend/_family.py +++ b/spacecore/backend/_family.py @@ -5,3 +5,4 @@ class BackendFamily(StrEnum): numpy = auto() jax = auto() torch = auto() + cupy = auto() diff --git a/spacecore/backend/_ops.py b/spacecore/backend/_ops.py index 5bca111..bfa88ac 100644 --- a/spacecore/backend/_ops.py +++ b/spacecore/backend/_ops.py @@ -1,54 +1,57 @@ from __future__ import annotations from abc import ABC, abstractmethod +import importlib from typing import Any, Sequence, Tuple, Callable, Optional, Type, ClassVar from ..types import DenseArray, SparseArray, DType, ArrayLike, Index, T, X, Y, R, Carry +class LazyNamespace: + def __init__(self, module_name: str) -> None: + self.__name__ = module_name + self.__isabstractmethod__ = False + self._module_name = module_name + self._module: Any | None = None + + def _load(self) -> Any: + if self._module is None: + self._module = importlib.import_module(self._module_name) + return self._module + + def __getattr__(self, name: str) -> Any: + return getattr(self._load(), name) + + @property + def is_loaded(self) -> bool: + return self._module is not None + + class BackendOps(ABC): """ - Backend-agnostic numerical ops interface (portable core). + Public numerical contract for SpaceCore backends. - Contract: - - This base class exposes only the portable subset used by library internals. - - Concrete backends (NumPy/JAX/Torch) may extend these methods with additional - optional keyword parameters (e.g., `order=`, `out=`, `where=`, `like=`, ...). + Common dense-array operations delegate to the backend's Array API-compatible + ``xp`` namespace. Subclasses provide backend-specific sparse conversion, + dtype policy, indexing mutation, control flow, device/autograd behavior, and + operations not covered by the Array API. """ _family: ClassVar[str] _allow_sparse: ClassVar[bool] + xp: ClassVar[Any] + + def __init__(self) -> None: + self._constant_cache: dict[str, DenseArray] = {} @property def family(self) -> str: - """ - Generic backend-agnostic wrapper to backend family identifier. - - Input: - None. - - Output: - String naming the concrete backend family. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Backend family identifier.""" return type(self)._family @property def allow_sparse(self) -> bool: - """ - Generic backend-agnostic wrapper to sparse-array support flag. - - Input: - None. - - Output: - Boolean indicating whether this backend supports sparse arrays. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Whether this backend supports sparse arrays.""" return self._allow_sparse def __eq__(self, other: Any) -> bool: @@ -56,671 +59,484 @@ def __eq__(self, other: Any) -> bool: return self.family == other.family return False + def __hash__(self) -> int: + return hash((type(self).__name__, self.family)) + @property @abstractmethod def dense_array(self) -> Type[Any]: - """ - Generic backend-agnostic wrapper to dense array type. - - Input: - None. - - Output: - Concrete dense array class accepted by this backend. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Dense array type accepted by this backend.""" ... @property @abstractmethod def sparse_array(self) -> Tuple[Type[Any], ...] | None: - """ - Generic backend-agnostic wrapper to sparse array type tuple. - - Input: - None. - - Output: - Concrete sparse array classes accepted by this backend, or None. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Sparse array types accepted by this backend, or None.""" ... @abstractmethod def sanitize_dtype(self, dtype: DType | None) -> DType: - """ - Generic backend-agnostic wrapper to normalize a dtype specifier. - - Input: - dtype: Optional dtype requested by SpaceCore or the caller. - - Output: - Backend dtype object accepted by array constructors. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Normalize a dtype specifier for this backend.""" ... def is_dense(self, x: Any) -> bool: - """ - Generic backend-agnostic wrapper to test for a dense backend array. - - Input: - x: Object to test. - - Output: - True when x is an instance of the backend dense array type. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return whether x is a dense array for this backend.""" return isinstance(x, self.dense_array) def is_sparse(self, x: Any) -> bool: - """ - Generic backend-agnostic wrapper to test for a sparse backend array. - - Input: - x: Object to test. - - Output: - True when x is an instance of a backend sparse array type. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return whether x is a sparse array for this backend.""" return self.sparse_array is not None and isinstance(x, self.sparse_array) def is_array(self, x: Any) -> bool: - """ - Generic backend-agnostic wrapper to test for any backend array. - - Input: - x: Object to test. - - Output: - True when x is dense or sparse for this backend. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return whether x is any array for this backend.""" return self.is_dense(x) or self.is_sparse(x) @abstractmethod - def get_dtype(self, x: Any) -> DType: - """ - Generic backend-agnostic wrapper to return an array dtype. - - Input: - x: Dense or sparse backend array. - - Output: - Backend dtype associated with x. + def assparse(self, x: Any, dtype: DType | None = None) -> SparseArray: + """Convert input to a backend sparse array.""" + ... - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + @abstractmethod + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: + """Multiply a sparse array by a dense array.""" ... @abstractmethod - def shape(self, x: Any) -> tuple[int, ...]: - """ - Generic backend-agnostic wrapper to return array shape metadata. + def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, + keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: + """Compute a stable log-sum-exp reduction.""" + ... - Input: - x: Dense or sparse backend array. + @abstractmethod + def index_set( + self, + x: DenseArray, + index: Index, + values: ArrayLike, + *, + copy: bool = True, + ) -> DenseArray: + """Set indexed values using backend mutation semantics.""" - Output: - Tuple describing the logical shape of x. + @abstractmethod + def index_add( + self, + x: DenseArray, + index: Index, + values: DenseArray, + *, + copy: bool = True, + ) -> DenseArray: + """Add values into indexed positions using backend mutation semantics.""" + ... - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + @abstractmethod + def ix_(self, *args: Any) -> Any: + """Build open-mesh index arrays.""" + ... @abstractmethod - def ndim(self, x: Any) -> int: - """ - Generic backend-agnostic wrapper to return array rank metadata. + def fori_loop( + self, + lower: int, + upper: int, + body_fun: Callable[[int, T], T], + init_val: T, + ) -> T: + """Run a counted loop primitive.""" - Input: - x: Dense or sparse backend array. + @abstractmethod + def while_loop( + self, + cond_fun: Callable[[T], bool], + body_fun: Callable[[T], T], + init_val: T, + ) -> T: + """Run a while-loop primitive.""" - Output: - Number of dimensions in x. + @abstractmethod + def scan( + self, + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + ) -> Tuple[Carry, Y]: + """Run a scan primitive.""" - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + @abstractmethod + def cond( + self, + pred: bool, + true_fun: Callable[[T], R], + false_fun: Callable[[T], R], + *operands: Any, + ) -> R: + """Run backend-compatible conditional branch selection.""" + ... @abstractmethod - def size(self, x: Any) -> int: - """ - Generic backend-agnostic wrapper to return logical element count. + def allclose_sparse( + self, + a: SparseArray, + b: SparseArray, + rtol: float = 1e-5, + atol: float = 1e-8, + ) -> bool: + """Compare sparse arrays elementwise within tolerances.""" + ... - Input: - x: Dense or sparse backend array. + def _dtype_arg(self, dtype: DType | None) -> DType | None: + return None if dtype is None else self.sanitize_dtype(dtype) - Output: - Total number of logical dense elements. + def _to_axis_tuple(self, axis: int | Sequence[int] | None) -> int | tuple[int, ...] | None: + if axis is None or isinstance(axis, int): + return axis + return tuple(axis) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def _permute_dims(self, x: DenseArray, axes: Sequence[int]) -> DenseArray: + axes = tuple(axes) + if hasattr(self.xp, "permute_dims"): + return self.xp.permute_dims(x, axes) + if hasattr(self.xp, "permute"): + return self.xp.permute(x, axes) + return self.xp.transpose(x, axes=axes) + + def _move_axis_order( + self, + ndim: int, + source: int | Sequence[int], + destination: int | Sequence[int], + ) -> tuple[int, ...]: + src = (source,) if isinstance(source, int) else tuple(source) + dst = (destination,) if isinstance(destination, int) else tuple(destination) + src = tuple(axis + ndim if axis < 0 else axis for axis in src) + dst = tuple(axis + ndim if axis < 0 else axis for axis in dst) + order = [axis for axis in range(ndim) if axis not in src] + for dest, axis in sorted(zip(dst, src, strict=True)): + order.insert(dest, axis) + return tuple(order) @property - @abstractmethod def inf(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to positive infinity scalar. - - Input: - None. - - Output: - Backend scalar representing positive infinity. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Positive infinity as a cached backend scalar.""" + return self._constant("inf", float("inf")) @property - @abstractmethod def nan(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to access a NaN scalar. - - Input: - None. - - Output: - Backend scalar representing NaN. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """NaN as a cached backend scalar.""" + return self._constant("nan", float("nan")) @property - @abstractmethod def pi(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to pi scalar. - - Input: - None. - - Output: - Backend scalar representing pi. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Pi as a cached backend scalar.""" + return self._constant("pi", 3.141592653589793) @property - @abstractmethod def e(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to access Euler's number scalar. - - Input: - None. - - Output: - Backend scalar representing Euler's number. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Euler's number as a cached backend scalar.""" + return self._constant("e", 2.718281828459045) - @property - @abstractmethod - def eps(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to machine epsilon scalar. - - Input: - None. + def _constant(self, name: str, value: float) -> DenseArray: + if name not in self._constant_cache: + self._constant_cache[name] = self.asarray(value) + return self._constant_cache[name] - Output: - Backend scalar for float64 machine epsilon. + def eps(self, dtype: DType) -> float: + """Machine epsilon for dtype.""" + return float(self.xp.finfo(self.sanitize_dtype(dtype)).eps) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. + def is_complex_dtype(self, dtype: DType) -> bool: """ + Return whether ``dtype`` is a complex floating type. - @abstractmethod - def asarray(self, x: Any, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to convert input to a dense array. - - Input: - x/a: Array-like input and optional dtype or backend conversion parameters. - - Output: - Dense backend array. + Parameters + ---------- + dtype: + Backend or portable dtype specifier to inspect. - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. + Returns + ------- + bool + ``True`` when ``dtype`` represents complex floating values. """ - ... + dtype = self.sanitize_dtype(dtype) + return getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex") - @abstractmethod - def astype(self, x: DenseArray, dtype: DType) -> DenseArray: + def real_dtype(self, dtype: DType) -> DType: """ - Generic backend-agnostic wrapper to cast an array to a dtype. + Return the real floating dtype with the same precision as ``dtype``. - Input: - x: Dense backend array; dtype: target dtype and optional casting controls. + Parameters + ---------- + dtype: + Backend or portable dtype specifier. - Output: - Dense backend array with the requested dtype. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. + Returns + ------- + DType + ``dtype`` itself when it is already real-valued; otherwise + ``float32`` for complex64 and ``float64`` for complex128. """ + dtype = self.sanitize_dtype(dtype) + if not self.is_complex_dtype(dtype): + return dtype + itemsize = getattr(dtype, "itemsize", None) + if itemsize is None: + dtype_text = str(dtype) + if "complex64" in dtype_text: + return self.sanitize_dtype("float32") + return self.sanitize_dtype("float64") + return self.sanitize_dtype("float32" if itemsize <= 8 else "float64") - @abstractmethod - def assparse(self, x: Any, dtype: DType | None = None) -> SparseArray: - """ - Generic backend-agnostic wrapper to convert input to a sparse array. + def get_dtype(self, x: Any) -> DType: + """Return x.dtype after verifying x is a backend array.""" + if self.is_array(x): + return x.dtype + raise TypeError(f"Expected {self.family} array, got {type(x)}.") - Input: - x: Dense, sparse, or array-like input plus sparse-format options. + def shape(self, x: Any) -> tuple[int, ...]: + """Return x.shape as a tuple.""" + return tuple(x.shape) - Output: - Sparse backend array. + def ndim(self, x: Any) -> int: + """Return the number of dimensions of x.""" + return int(x.ndim) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def size(self, x: Any) -> int: + """Return the total number of elements in x.""" + result = 1 + for dim in self.shape(x): + result *= int(dim) + return result + + def asarray(self, x: Any, dtype: DType | None = None, **backend_kwargs: Any) -> DenseArray: + """Convert input to a dense backend array (delegates to xp.asarray).""" + if self.is_sparse(x) and hasattr(x, "to_dense"): + x = x.to_dense() + dtype = self._dtype_arg(dtype) + if hasattr(self.xp, "asarray"): + return self.xp.asarray(x, dtype=dtype, **backend_kwargs) + return self.xp.as_tensor(x, dtype=dtype, **backend_kwargs) + + def astype(self, x: DenseArray, dtype: DType | None, **backend_kwargs: Any) -> DenseArray: + """Cast x to dtype, returning x unchanged when dtype is None.""" + if dtype is None: + return x + dtype = self.sanitize_dtype(dtype) + if hasattr(x, "astype"): + return x.astype(dtype, **backend_kwargs) + return x.to(dtype=dtype, **backend_kwargs) - @abstractmethod def empty(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create an uninitialized dense array. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array with uninitialized values. + """Create an uninitialized array (delegates to xp.empty).""" + return self.xp.empty(shape, dtype=self._dtype_arg(dtype)) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def zeros(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a zero-filled dense array. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array filled with zeros. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Create a zero-filled array (delegates to xp.zeros).""" + return self.xp.zeros(shape, dtype=self._dtype_arg(dtype)) - @abstractmethod def ones(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a one-filled dense array. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array filled with ones. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Create a one-filled array (delegates to xp.ones).""" + return self.xp.ones(shape, dtype=self._dtype_arg(dtype)) - @abstractmethod def zeros_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create zeros shaped like another array. - - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + """Create a zero-filled array like x (delegates to xp.zeros_like).""" + return self.xp.zeros_like(x, dtype=self._dtype_arg(dtype)) - Output: - Dense backend array of zeros. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def ones_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create ones shaped like another array. + """Create a one-filled array like x (delegates to xp.ones_like).""" + return self.xp.ones_like(x, dtype=self._dtype_arg(dtype)) - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + def full_like(self, x: DenseArray, value: Any, dtype: DType | None = None) -> DenseArray: + """Create a value-filled array like x (delegates to xp.full_like).""" + return self.xp.full_like(x, value, dtype=self._dtype_arg(dtype)) - Output: - Dense backend array of ones. + def arange( + self, + start: int, + stop: int | None = None, + step: int | None = None, + dtype: DType | None = None, + ) -> DenseArray: + """Create an evenly spaced range (delegates to xp.arange).""" + dtype = self._dtype_arg(dtype) + if stop is None: + return self.xp.arange(start, dtype=dtype) + if step is None: + return self.xp.arange(start, stop, dtype=dtype) + return self.xp.arange(start, stop, step, dtype=dtype) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def full(self, shape: Tuple[int, ...], fill_value: Any, dtype: DType | None = None) -> DenseArray: + """Create a value-filled array (delegates to xp.full).""" + return self.xp.full(shape, fill_value, dtype=self._dtype_arg(dtype)) - @abstractmethod - def full_like(self, x: DenseArray, value: Any, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create filled values shaped like another array. + def eye(self, n: int, m: int | None = None, dtype: DType | None = None) -> DenseArray: + """Create an identity-like matrix (delegates to xp.eye).""" + return self.xp.eye(n, m, dtype=self._dtype_arg(dtype)) - Input: - x: Prototype dense array; value/fill_value and dtype options are backend-specific. + def ravel(self, x: DenseArray) -> DenseArray: + """Flatten x to one dimension.""" + if hasattr(self.xp, "ravel"): + return self.xp.ravel(x) + return self.reshape(x, (-1,)) - Output: - Dense backend array filled with the requested value. + def reshape(self, x: DenseArray, shape: Tuple[int, ...] | int) -> DenseArray: + """Reshape x (delegates to xp.reshape).""" + shape_arg = (shape,) if isinstance(shape, int) else shape + return self.xp.reshape(x, shape_arg) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: + """Permute dimensions of x.""" + if axes is None: + axes = tuple(reversed(range(self.ndim(x)))) + return self._permute_dims(x, axes) - @abstractmethod - def arange(self, start: int, stop: int | None = None, step: int | None = None, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create evenly spaced integer-range values. + def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: + """Swap two axes of x.""" + if hasattr(self.xp, "swapaxes"): + return self.xp.swapaxes(x, axis1, axis2) + axes = list(range(self.ndim(x))) + axes[axis1], axes[axis2] = axes[axis2], axes[axis1] + return self._permute_dims(x, axes) - Input: - start, stop, step: Range parameters; dtype and placement options are backend-specific. + def broadcast_to(self, x: DenseArray, shape: Tuple[int, ...]) -> DenseArray: + """Broadcast x to shape (delegates to xp.broadcast_to).""" + return self.xp.broadcast_to(x, shape) - Output: - One-dimensional dense backend array. + def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: + """Insert singleton dimensions into x.""" + if isinstance(axis, int): + if hasattr(self.xp, "expand_dims"): + return self.xp.expand_dims(x, axis=axis) + return self.xp.unsqueeze(x, axis) + out = x + ndim = self.ndim(x) + len(axis) + for ax in sorted(a + ndim if a < 0 else a for a in axis): + out = self.expand_dims(out, ax) + return out - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: + """Remove singleton dimensions from x.""" + if axis is None: + axis = tuple(i for i, dim in enumerate(self.shape(x)) if dim == 1) + if not axis: + return x + axis = (axis,) if isinstance(axis, int) else tuple(axis) + return self.xp.squeeze(x, axis=axis) - @abstractmethod - def full(self, shape: Tuple[int, ...], fill_value: Any, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a filled dense array. + def moveaxis( + self, + x: DenseArray, + source: int | Sequence[int], + destination: int | Sequence[int], + ) -> DenseArray: + """Move axes of x to new positions.""" + if hasattr(self.xp, "moveaxis"): + return self.xp.moveaxis(x, source, destination) + return self._permute_dims(x, self._move_axis_order(self.ndim(x), source, destination)) - Input: - shape: Output shape; fill_value and dtype options are backend-specific. + def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray: + """Stack arrays along a new axis (delegates to xp.stack).""" + return self.xp.stack(tuple(arrays), axis=axis) - Output: - Dense backend array filled with fill_value. + def vmap( + self, + fn: Callable, + in_axes: int | Sequence[int | None] | None = 0, + out_axes: int | Sequence[int | None] | None = 0, + ) -> Callable: + """Vectorize ``fn`` over array axes using a Python-loop fallback.""" + + def axis_for_arg(i: int) -> int | Sequence[int | None] | None: + if isinstance(in_axes, tuple) or isinstance(in_axes, list): + return in_axes[i] + return in_axes + + def normalize_axis(axis: int, ndim: int) -> int: + return axis + ndim if axis < 0 else axis + + def tree_size(x: Any, axis: Any) -> int | None: + if axis is None: + return None + if isinstance(x, tuple): + axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(x) + for xi, ai in zip(x, axes): + size = tree_size(xi, ai) + if size is not None: + return size + return None + shape = tuple(getattr(x, "shape", ())) + axis = normalize_axis(int(axis), len(shape)) + return int(shape[axis]) + + def tree_take(x: Any, axis: Any, i: int) -> Any: + if axis is None: + return x + if isinstance(x, tuple): + axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(x) + return tuple(tree_take(xi, ai, i) for xi, ai in zip(x, axes)) + shape = tuple(getattr(x, "shape", ())) + axis = normalize_axis(int(axis), len(shape)) + index = [slice(None)] * len(shape) + index[axis] = i + return x[tuple(index)] + + def tree_stack(xs: Sequence[Any], axis: Any) -> Any: + first = xs[0] + if isinstance(first, tuple): + axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(first) + return tuple( + tree_stack(tuple(x[i] for x in xs), ai) + for i, ai in enumerate(axes) + ) + if axis is None: + return first + return self.stack(xs, axis=int(axis)) + + def mapped(*args: Any) -> Any: + axes = tuple(axis_for_arg(i) for i in range(len(args))) + size = None + for arg, axis in zip(args, axes): + size = tree_size(arg, axis) + if size is not None: + break + if size is None: + return fn(*args) + outputs = tuple( + fn(*(tree_take(arg, axis, i) for arg, axis in zip(args, axes))) + for i in range(size) + ) + return tree_stack(outputs, out_axes) + + return mapped - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def conj(self, x: DenseArray) -> DenseArray: + """Complex conjugate of x (delegates to xp.conj).""" + return self.xp.conj(x) - @abstractmethod - def eye(self, n: int, m: int | None = None, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a dense identity-like matrix. + def real(self, x: DenseArray) -> DenseArray: + """Real component of x (delegates to xp.real).""" + return self.xp.real(x) - Input: - n and optional m: Matrix dimensions; dtype and placement options are backend-specific. + def imag(self, x: DenseArray) -> DenseArray: + """Imaginary component of x (delegates to xp.imag).""" + return self.xp.imag(x) - Output: - Two-dimensional dense backend array. + def abs(self, x: DenseArray) -> DenseArray: + """Absolute value of x (delegates to xp.abs).""" + return self.xp.abs(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def sign(self, x: DenseArray) -> DenseArray: + """Elementwise sign of x (delegates to xp.sign).""" + return self.xp.sign(x) - @abstractmethod - def ravel(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to flatten an array. - - Input: - x: Dense backend array plus optional order parameters. - - Output: - One-dimensional dense backend array view or copy. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def reshape(self, x: DenseArray, shape: Tuple[int, ...] | int) -> DenseArray: - """ - Generic backend-agnostic wrapper to reshape an array. - - Input: - x: Dense backend array; shape: New shape plus backend-specific options. - - Output: - Dense backend array with the requested shape. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to permute array axes. - - Input: - x: Dense backend array; axes: Optional axis order. - - Output: - Dense backend array with permuted axes. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Generic backend-agnostic wrapper to interchange two axes. - - Input: - x: Dense backend array; axis1 and axis2: Axes to swap. - - Output: - Dense backend array with the two axes exchanged. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def broadcast_to(self, x: DenseArray, shape: Tuple[int, ...]) -> DenseArray: - """ - Generic backend-agnostic wrapper to broadcast an array to a shape. - - Input: - x: Dense backend array; shape: Target broadcast shape. - - Output: - Dense backend array with broadcast shape. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Generic backend-agnostic wrapper to insert length-one axes. - - Input: - x: Dense backend array; axis: Position or positions to insert. - - Output: - Dense backend array with expanded rank. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to remove length-one axes. - - Input: - x: Dense backend array; axis: Optional axes to squeeze. - - Output: - Dense backend array with selected singleton dimensions removed. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def moveaxis( - self, - x: DenseArray, - source: int | Sequence[int], - destination: int | Sequence[int], - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to move axes to new positions. - - Input: - x: Dense backend array; source and destination: Axis positions. - - Output: - Dense backend array with moved axes. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray: - """ - Generic backend-agnostic wrapper to stack arrays along a new axis. - - Input: - arrays: Sequence of dense backend arrays; axis: New axis position. - - Output: - Dense backend array containing stacked inputs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def conj(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute complex conjugates. - - Input: - x: Dense backend array. - - Output: - Dense backend array with conjugated values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def real(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to extract real components. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing real components. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def imag(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to extract imaginary components. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing imaginary components. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def abs(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute absolute values. - - Input: - x: Dense backend array. - - Output: - Dense backend array of absolute values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sign(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute signs elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of signs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def sqrt(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute square roots elementwise. - - Input: - x: Dense backend array. + """Elementwise square root of x (delegates to xp.sqrt).""" + return self.xp.sqrt(x) - Output: - Dense backend array of square roots. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def sum( self, x: DenseArray, @@ -728,81 +544,41 @@ def sum( keepdims: bool = False, dtype: DType | None = None, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by summation. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing sums. + """Sum over given axes (delegates to xp.sum).""" + return self.xp.sum( + x, + axis=self._to_axis_tuple(axis), + dtype=self._dtype_arg(dtype), + keepdims=keepdims, + ) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def mean( self, x: DenseArray, axis: int | Sequence[int] | None = None, keepdims: bool = False, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by arithmetic mean. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing means. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Mean over given axes (delegates to xp.mean).""" + return self.xp.mean(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) - @abstractmethod def min( self, x: DenseArray, axis: int | Sequence[int] | None = None, keepdims: bool = False, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by minimum. - - Input: - x: Dense backend array; axis and keepdims control the reduction. + """Minimum over given axes (delegates to xp.min).""" + return self.xp.min(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) - Output: - Dense backend array or scalar containing minima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def max( self, x: DenseArray, axis: int | Sequence[int] | None = None, keepdims: bool = False, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by maximum. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing maxima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Maximum over given axes (delegates to xp.max).""" + return self.xp.max(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) - @abstractmethod def prod( self, x: DenseArray, @@ -810,197 +586,75 @@ def prod( keepdims: bool = False, dtype: DType | None = None, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by product. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. + """Product over given axes (delegates to xp.prod).""" + return self.xp.prod( + x, + axis=self._to_axis_tuple(axis), + dtype=self._dtype_arg(dtype), + keepdims=keepdims, + ) - Output: - Dense backend array or scalar containing products. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def trace(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to sum diagonal entries. - - Input: - x: Dense backend array plus optional diagonal and axis controls. - - Output: - Dense backend array or scalar containing trace values. + """Trace of a matrix (delegates to xp.trace when available).""" + if hasattr(self.xp, "trace"): + return self.xp.trace(x) + return self.sum(self.diagonal(x)) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def argsort(self, x: DenseArray, axis: int = -1) -> DenseArray: - """ - Generic backend-agnostic wrapper to return sorting indices. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. + """Return indices that sort ``x`` along an axis.""" + return self.xp.argsort(x, axis=axis) - Output: - Dense integer backend array of indices. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def sort(self, x: DenseArray, axis: int = -1) -> DenseArray: - """ - Generic backend-agnostic wrapper to sort values. + """Sort x along an axis (delegates to xp.sort).""" + return self.xp.sort(x, axis=axis) - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense backend array with sorted values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Generic backend-agnostic wrapper to return indices of minimum values. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Return indices of minima along an axis.""" + return self.xp.argmin(x, axis=axis, keepdims=keepdims) - @abstractmethod def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Generic backend-agnostic wrapper to return indices of maximum values. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Return indices of maxima along an axis.""" + return self.xp.argmax(x, axis=axis, keepdims=keepdims) - @abstractmethod def vdot(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute a conjugating vector dot product. - - Input: - x, y: Dense backend arrays accepted by the backend vdot operation. - - Output: - Backend scalar or dense array containing the dot product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def matmul(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute matrix products. - - Input: - a, b: Dense backend arrays with matrix-multiplication-compatible shapes. - - Output: - Dense backend array containing the product. + """Return ``sum(conj(x) * y)`` over flattened inputs. - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: + Matches NumPy, JAX, and Torch ``vdot`` semantics. ``DenseLinOp.rapply`` + relies on this convention for complex inputs. """ - Generic backend-agnostic wrapper to multiply sparse and dense arrays. - - Input: - a: Sparse backend array; b: Dense backend array. + x_flat = self.ravel(x) + y_flat = self.ravel(y) + if hasattr(self.xp, "vdot"): + return self.xp.vdot(x_flat, y_flat) + return self.xp.vecdot(x_flat, y_flat) - Output: - Dense backend array containing the product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def matmul( + self, + a: DenseArray, + b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + """Matrix product (delegates to xp.matmul).""" + return self.xp.matmul(a, b, **({} if backend_kwargs is None else backend_kwargs)) - @abstractmethod def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute a Kronecker product. - - Input: - a, b: Dense backend arrays. - - Output: - Dense backend array containing the Kronecker product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Kronecker product (delegates to xp.kron).""" + return self.xp.kron(a, b) - @abstractmethod def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to evaluate an Einstein summation expression. - - Input: - subscripts: Einstein summation string; operands: Dense backend arrays. - - Output: - Dense backend array containing the contraction result. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def eigh(self, x: DenseArray) -> tuple[DenseArray, DenseArray]: - """ - Generic backend-agnostic wrapper to compute Hermitian eigenpairs. + """Einstein summation (delegates to xp.einsum).""" + return self.xp.einsum(subscripts, *operands) - Input: - x: Dense Hermitian or symmetric backend array. - - Output: - Tuple of dense backend arrays containing eigenvalues and eigenvectors. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def eigh( + self, + x: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> tuple[DenseArray, DenseArray]: + """Eigenpairs of a Hermitian dense matrix (delegates to xp.linalg.eigh).""" + if self.is_sparse(x): + raise TypeError("eigh requires a dense array; sparse input is not supported.") + return self.xp.linalg.eigh(x, **({} if backend_kwargs is None else backend_kwargs)) - @abstractmethod def norm( self, x: DenseArray, @@ -1008,464 +662,117 @@ def norm( axis: int | Sequence[int] | None = None, keepdims: bool = False, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute vector or matrix norms. - - Input: - x: Dense backend array; ord, axis, and keepdims select the norm. - - Output: - Dense backend array or scalar containing norm values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def solve(self, A: DenseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to solve dense linear systems. + """Vector or matrix norm (delegates to xp.linalg.norm).""" + return self.xp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) - Input: - A: Dense coefficient array; b: Dense right-hand side array. - - Output: - Dense backend array solving A @ x = b. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def eigvalsh(self, A: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute Hermitian eigenvalues. - - Input: - A: Dense Hermitian or symmetric backend array. - - Output: - Dense backend array containing eigenvalues. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def svd(self, A: DenseArray, full_matrices: bool = True) -> tuple[DenseArray, DenseArray, DenseArray]: - """ - Generic backend-agnostic wrapper to compute singular value decompositions. - - Input: - A: Dense backend array plus SVD options. - - Output: - Dense backend arrays containing singular vectors and/or singular values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def cholesky(self, A: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute Cholesky factors. - - Input: - A: Dense Hermitian positive-definite backend array. - - Output: - Dense backend array containing a triangular factor. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, - keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: - """ - Generic backend-agnostic wrapper to compute a stable log-sum-exp reduction. - - Input: - a: Dense backend array; axis, weights, and sign options control the reduction. + def solve( + self, + A: DenseArray, + b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + """Solve a dense linear system (delegates to xp.linalg.solve).""" + return self.xp.linalg.solve(A, b, **({} if backend_kwargs is None else backend_kwargs)) - Output: - Dense backend array or tuple containing log-sum-exp results. + def eigvalsh( + self, + A: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + """Eigenvalues of a Hermitian dense matrix (delegates to xp.linalg.eigvalsh).""" + return self.xp.linalg.eigvalsh(A, **({} if backend_kwargs is None else backend_kwargs)) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def svd( + self, + A: DenseArray, + full_matrices: bool = True, + backend_kwargs: dict[str, Any] | None = None, + ) -> tuple[DenseArray, DenseArray, DenseArray]: + """Singular value decomposition (delegates to xp.linalg.svd).""" + return self.xp.linalg.svd( + A, + full_matrices=full_matrices, + **({} if backend_kwargs is None else backend_kwargs), + ) + + def cholesky( + self, + A: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + """Cholesky factorization (delegates to xp.linalg.cholesky).""" + return self.xp.linalg.cholesky(A, **({} if backend_kwargs is None else backend_kwargs)) - @abstractmethod def exp(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute exponentials elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of exponentials. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Elementwise exponential (delegates to xp.exp).""" + return self.xp.exp(x) - @abstractmethod def log(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute natural logarithms elementwise. + """Elementwise natural logarithm (delegates to xp.log).""" + return self.xp.log(x) - Input: - x: Dense backend array. - - Output: - Dense backend array of logarithms. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def where(self, condition: DenseArray | bool, x: ArrayLike, y: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to select values by condition. - - Input: - condition: Boolean array or scalar; x and y: Values to choose between. + """Select between x and y by condition (delegates to xp.where).""" + return self.xp.where(condition, x, y) - Output: - Dense backend array containing selected values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def maximum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute elementwise maxima. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. + """Elementwise maximum (delegates to xp.maximum).""" + return self.xp.maximum(x, y) - Output: - Dense backend array containing maxima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def minimum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute elementwise minima. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing minima. + """Elementwise minimum (delegates to xp.minimum).""" + return self.xp.minimum(x, y) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def clip(self, x: DenseArray, a_min: ArrayLike, a_max: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to clip values into an interval. - - Input: - x: Dense backend array; a_min and a_max: Broadcastable bounds. - - Output: - Dense backend array with clipped values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Clip x into [a_min, a_max] (delegates to xp.clip).""" + return self.xp.clip(x, a_min, a_max) - @abstractmethod def isfinite(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to test finiteness elementwise. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. + """Elementwise finite check (delegates to xp.isfinite).""" + return self.xp.isfinite(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def isnan(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to test NaN values elementwise. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def concatenate(self, arrays: Sequence[DenseArray], axis: int = 0, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to join arrays along an existing axis. - - Input: - arrays: Sequence of dense backend arrays; axis and dtype options are backend-specific. + """Elementwise NaN check (delegates to xp.isnan).""" + return self.xp.isnan(x) - Output: - Dense backend array containing concatenated inputs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def concatenate( + self, + arrays: Sequence[DenseArray], + axis: int = 0, + dtype: DType | None = None, + ) -> DenseArray: + """Concatenate arrays along an existing axis (delegates to xp.concat).""" + if hasattr(self.xp, "concat"): + result = self.xp.concat(tuple(arrays), axis=axis) + else: + result = self.xp.concatenate(tuple(arrays), axis=axis) + return self.astype(result, dtype) if dtype is not None else result - @abstractmethod def take( self, x: DenseArray, indices: DenseArray, axis: int | None = None, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to take values by integer indices. + """Take entries from x by integer indices (delegates to xp.take).""" + return self.xp.take(x, indices, axis=axis) - Input: - x: Dense backend array; indices: Integer indices; axis and mode options are backend-specific. - - Output: - Dense backend array containing selected values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def diag(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to extract or build a diagonal. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array containing a diagonal view/copy or matrix. + """Extract or construct a diagonal (delegates to xp.diag).""" + return self.xp.diag(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def diagonal(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to return selected diagonals. - - Input: - x: Dense backend array plus offset and axis controls. - - Output: - Dense backend array containing selected diagonals. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return the main diagonal of x (delegates to xp.diagonal).""" + return self.xp.diagonal(x) - @abstractmethod def tril(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to return lower-triangular values. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with upper entries zeroed. + """Lower triangle of x (delegates to xp.tril).""" + return self.xp.tril(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def triu(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to return upper-triangular values. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with lower entries zeroed. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def index_set( - self, - x: DenseArray, - index: Index, - values: ArrayLike, - *, - copy: bool = True, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to set indexed values. - - Input: - x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. - - Output: - Dense backend array with indexed values set. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def index_add( - self, - x: DenseArray, - index: Index, - values: DenseArray, - *, - copy: bool = True, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to add into indexed values. - - Input: - x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. - - Output: - Dense backend array with indexed values incremented. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def ix_(self, *args: Any) -> Any: - """ - Generic backend-agnostic wrapper to build open mesh index arrays. - - Input: - args: One-dimensional index arrays or sequences. - - Output: - Tuple of dense backend arrays usable for open-mesh indexing. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def fori_loop( - self, - lower: int, - upper: int, - body_fun: Callable[[int, T], T], - init_val: T, - ) -> T: - """ - Generic backend-agnostic wrapper to run a counted loop primitive. - - Input: - lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def while_loop( - self, - cond_fun: Callable[[T], bool], - body_fun: Callable[[T], T], - init_val: T, - ) -> T: - """ - Generic backend-agnostic wrapper to run a while-loop primitive. - - Input: - cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def scan( - self, - f: Callable[[Carry, X], Tuple[Carry, Y]], - init: Carry, - xs: X, - length: Optional[int] = None, - reverse: bool = False, - unroll: int = 1, - ) -> Tuple[Carry, Y]: - """ - Generic backend-agnostic wrapper to run a scan primitive. - - Input: - f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. - - Output: - Tuple of final carry and stacked outputs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Upper triangle of x (delegates to xp.triu).""" + return self.xp.triu(x) - @abstractmethod - def cond( - self, - pred: bool, - true_fun: Callable[[T], R], - false_fun: Callable[[T], R], - *operands: Any, - ) -> R: - """ - Generic backend-agnostic wrapper to run conditional branch selection. - - Input: - pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. - - Output: - Result returned by the selected branch. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def allclose( self, a: DenseArray, @@ -1474,41 +781,12 @@ def allclose( atol: float = 1e-8, equal_nan: bool = False, ) -> bool: - """ - Generic backend-agnostic wrapper to compare dense arrays elementwise within tolerances. - - Input: - a, b: Dense backend arrays; rtol, atol, and equal_nan configure comparison. - - Output: - Boolean indicating whether arrays are close. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def allclose_sparse( - self, - a: SparseArray, - b: SparseArray, - rtol: float = 1e-5, - atol: float = 1e-8, - ) -> bool: - """ - Generic backend-agnostic wrapper to compare sparse arrays elementwise within tolerances. - - Input: - a, b: Sparse backend arrays; rtol and atol configure comparison. - - Output: - Boolean indicating whether sparse arrays are close. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - def __repr__(self): - return f"{type(self).__name__}" + """Return whether dense arrays are close within tolerances.""" + return bool(self.xp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) + + def __repr__(self) -> str: + xp = type(self).xp + xp_state = "" + if isinstance(xp, LazyNamespace): + xp_state = f", xp_loaded={xp.is_loaded!r}" + return f"{type(self).__name__}(family={self.family!r}{xp_state})" diff --git a/spacecore/backend/cupy/__init__.py b/spacecore/backend/cupy/__init__.py new file mode 100644 index 0000000..c32df52 --- /dev/null +++ b/spacecore/backend/cupy/__init__.py @@ -0,0 +1,5 @@ +"""CuPy backend implementation.""" + +from ._ops import CuPyOps as CuPyOps + +__all__ = ["CuPyOps"] diff --git a/spacecore/backend/cupy/_ops.py b/spacecore/backend/cupy/_ops.py new file mode 100644 index 0000000..0d75c2d --- /dev/null +++ b/spacecore/backend/cupy/_ops.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +from typing import Any, Callable, Literal, Optional, Sequence, Tuple, Type + +from .._family import BackendFamily +from .._ops import BackendOps +from ...types import ArrayLike, Carry, DenseArray, DType, Index, R, SparseArray, T, X, Y + + +class CuPyOps(BackendOps): + """ + BackendOps implementation for CuPy GPU arrays. + + This backend uses CuPy for dense array operations and ``cupyx.scipy.sparse`` + for sparse arrays. Most operations follow CuPy's NumPy-compatible API and + execute on the active CUDA device. + + Dense arrays + ``cupy.ndarray`` + + Sparse arrays + ``cupyx.scipy.sparse`` matrix types such as CSR, CSC, and COO. + """ + + import cupy as cp + import cupyx.scipy as cpx_scipy + import cupyx.scipy.sparse as cpx_sparse + + xp = cp + + _family = BackendFamily.cupy.value.lower() + _allow_sparse = True + + @property + def dense_array(self) -> Type[Any]: + """Dense CuPy array type.""" + return self.cp.ndarray + + @property + def sparse_array(self) -> Tuple[Type[Any], ...]: + """Sparse CuPy array type tuple.""" + sparse = self.cpx_sparse + types: list[type[Any]] = [] + for name in ("spmatrix", "csr_matrix", "csc_matrix", "coo_matrix"): + typ = getattr(sparse, name, None) + if typ is not None: + types.append(typ) + return tuple(types) + + def sanitize_dtype(self, dtype: DType | None) -> DType: + """ + Normalize a dtype specifier using CuPy. + + ``None`` follows NumPy/CuPy's float64 default. + """ + if dtype is None: + return self.cp.float64 + return self.cp.dtype(dtype) + + def assparse( + self, + x: Any, + *, + format: Literal["csr", "csc", "coo"] = "csr", + dtype: DType | None = None, + ) -> SparseArray: + """ + Convert input to a CuPy sparse matrix. + + Dense inputs must be two-dimensional. Existing sparse inputs are + converted to the requested sparse format. + """ + sparse = self.cpx_sparse + + if self.is_sparse(x): + if format == "csr": + return x.tocsr() + if format == "csc": + return x.tocsc() + if format == "coo": + return x.tocoo() + raise ValueError(f"Unknown sparse format: {format!r}") + + x_arr = self.asarray(x, dtype=dtype) + if x_arr.ndim != 2: + raise ValueError("CuPy sparse conversion currently expects a 2D array.") + + if format == "csr": + return sparse.csr_matrix(x_arr) + if format == "csc": + return sparse.csc_matrix(x_arr) + if format == "coo": + return sparse.coo_matrix(x_arr) + raise ValueError(f"Unknown sparse format: {format!r}") + + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: + """Multiply a CuPy sparse matrix by a CuPy dense array.""" + if not self.is_sparse(a): + raise TypeError("sparse_matmul expects a CuPy sparse matrix.") + if not self.is_dense(b): + raise TypeError("sparse_matmul expects a CuPy dense array.") + return a @ b + + def logsumexp( + self, + a: DenseArray, + axis: int | Sequence[int] | None = None, + b: DenseArray | None = None, + keepdims: bool = False, + return_sign: bool = False, + ) -> DenseArray | Tuple[DenseArray, DenseArray]: + """Compute log-sum-exp using ``cupyx.scipy.special``.""" + return self.cpx_scipy.special.logsumexp( + a, + axis=axis, + b=b, + keepdims=keepdims, + return_sign=return_sign, + ) + + def index_set( + self, + x: DenseArray, + index: Index, + values: ArrayLike, + *, + copy: bool = True, + ) -> DenseArray: + """Set indexed values in a CuPy array.""" + y = x.copy() if copy else x + y[index] = values + return y + + def index_add( + self, + x: DenseArray, + index: Index, + values: DenseArray, + *, + copy: bool = True, + ) -> DenseArray: + """Add values into indexed entries of a CuPy array.""" + y = x.copy() if copy else x + self.cp.add.at(y, index, values) + return y + + def ix_(self, *args: Any) -> Any: + """Build open-mesh indices using CuPy.""" + return self.cp.ix_(*args) + + def fori_loop( + self, + lower: int, + upper: int, + body_fun: Callable[[int, T], T], + init_val: T, + ) -> T: + """Run a counted loop eagerly in Python for CuPy.""" + val = init_val + for i in range(int(lower), int(upper)): + val = body_fun(i, val) + return val + + def while_loop( + self, + cond_fun: Callable[[T], bool], + body_fun: Callable[[T], T], + init_val: T, + ) -> T: + """Run a while loop eagerly in Python for CuPy.""" + val = init_val + while bool(cond_fun(val)): + val = body_fun(val) + return val + + def _tree_map(self, f: Callable[[Any], Any], tree: Any) -> Any: + if isinstance(tree, dict): + return {k: self._tree_map(f, v) for k, v in tree.items()} + if isinstance(tree, tuple): + return tuple(self._tree_map(f, v) for v in tree) + if isinstance(tree, list): + return [self._tree_map(f, v) for v in tree] + return f(tree) + + def _tree_multimap(self, f: Callable[..., Any], *trees: Any) -> Any: + t0 = trees[0] + if isinstance(t0, dict): + return {k: self._tree_multimap(f, *(t[k] for t in trees)) for k in t0.keys()} + if isinstance(t0, tuple): + return tuple(self._tree_multimap(f, *(t[i] for t in trees)) for i in range(len(t0))) + if isinstance(t0, list): + return [self._tree_multimap(f, *(t[i] for t in trees)) for i in range(len(t0))] + return f(*trees) + + def _tree_take0(self, xs: Any) -> Any: + if isinstance(xs, dict): + return self._tree_take0(next(iter(xs.values()))) + if isinstance(xs, (tuple, list)): + return self._tree_take0(xs[0]) + return xs + + def _tree_index(self, xs: Any, i: int) -> Any: + def _idx(a: Any) -> Any: + try: + return a[i] + except Exception: + return a + + return self._tree_map(_idx, xs) + + def _tree_stack(self, ys_list: Sequence[Any]) -> Any: + if not ys_list: + return () + + def _stack_leaves(*leaves: Any) -> Any: + try: + return self.cp.stack(leaves, axis=0) + except Exception: + return self.cp.asarray(leaves) + + return self._tree_multimap(_stack_leaves, *ys_list) + + def scan( + self, + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + ) -> Tuple[Carry, Y]: + """Run a scan loop eagerly in Python for CuPy.""" + carry = init + if xs is None: + if length is None: + raise ValueError("scan(xs=None) requires an explicit `length`.") + n = int(length) + indices = range(n - 1, -1, -1) if reverse else range(n) + ys_steps: list[Any] = [] + for _i in indices: + carry, y = f(carry, None) # type: ignore[arg-type] + ys_steps.append(y) + if reverse: + ys_steps.reverse() + return carry, self._tree_stack(ys_steps) + + if length is None: + leaf0 = self._tree_take0(xs) + try: + n = int(leaf0.shape[0]) + except Exception as e: + raise ValueError( + "Could not infer scan length from `xs`; pass `length=` explicitly." + ) from e + else: + n = int(length) + + indices = range(n - 1, -1, -1) if reverse else range(n) + ys_steps = [] + for i in indices: + x_i = self._tree_index(xs, i) + carry, y = f(carry, x_i) + ys_steps.append(y) + if reverse: + ys_steps.reverse() + return carry, self._tree_stack(ys_steps) + + def cond( + self, + pred: bool, + true_fun: Callable[[T], R], + false_fun: Callable[[T], R], + *operands: Any, + ) -> R: + """Run conditional branch selection eagerly in Python for CuPy.""" + return true_fun(*operands) if bool(pred) else false_fun(*operands) + + def allclose_sparse( + self, + a: SparseArray, + b: SparseArray, + rtol: float = 1e-5, + atol: float = 1e-8, + ) -> bool: + """Compare two CuPy sparse matrices by dense values.""" + if not self.is_sparse(a) or not self.is_sparse(b): + raise TypeError("allclose_sparse expects two CuPy sparse matrices.") + return bool(self.cp.asnumpy(self.cp.allclose(a.toarray(), b.toarray(), rtol=rtol, atol=atol))) diff --git a/spacecore/backend/jax/__init__.py b/spacecore/backend/jax/__init__.py index afd0428..b64d6f7 100644 --- a/spacecore/backend/jax/__init__.py +++ b/spacecore/backend/jax/__init__.py @@ -1,2 +1,4 @@ +"""JAX backend implementation and pytree registration helpers.""" + from ._ops import JaxOps as JaxOps -from ._pytree import jax_pytree_class as jax_pytree_class \ No newline at end of file +from ._pytree import jax_pytree_class as jax_pytree_class diff --git a/spacecore/backend/jax/_ops.py b/spacecore/backend/jax/_ops.py index 6ec1591..000bc6c 100644 --- a/spacecore/backend/jax/_ops.py +++ b/spacecore/backend/jax/_ops.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any, Sequence, Literal, Tuple, Callable, Optional, Type -import inspect from warnings import warn from .._family import BackendFamily @@ -25,6 +24,7 @@ class JaxOps(BackendOps): jax.experimental.sparse.BCSR Methods + ------- Most methods mirror the corresponding JAX public API signatures and delegate to `jax.numpy`, `jax.numpy.linalg`, `jax.scipy`, or `jax.experimental.sparse`. Backend-specific behavior, tracing rules, @@ -46,6 +46,7 @@ class JaxOps(BackendOps): through instances as `ops.jsparse`. Notes + ----- Code intended to remain backend-portable should prefer `BackendOps` methods. Direct use of `ops.jax`, `ops.jnp`, or `ops.jsparse` is an explicit JAX-specific escape hatch. @@ -54,23 +55,17 @@ class JaxOps(BackendOps): JAX ignores them. Array-creation routines may expose `device` and `out_sharding` for explicit placement or sharding. """ + import jax import jax.numpy as jnp import jax.experimental.sparse as jsparse + xp = jnp _family = BackendFamily.jax.value.lower() _allow_sparse = True def __init__(self) -> None: - self._reshape_supports_copy = "copy" in inspect.signature(self.jnp.reshape).parameters - self._reshape_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.reshape).parameters - self._ravel_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.ravel).parameters - self._zeros_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.zeros).parameters - self._empty_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.empty).parameters - self._zeros_like_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.zeros_like).parameters - self._ones_like_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.ones_like).parameters - self._broadcast_to_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.broadcast_to).parameters - + super().__init__() def sanitize_dtype(self, dtype: DType | None) -> DType: """ @@ -123,83 +118,13 @@ def sanitize_dtype(self, dtype: DType | None) -> DType: return dt - def get_dtype(self, x: Any) -> DType: - """ - Return an array dtype using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Backend dtype associated with x. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - """ - if self.is_dense(x): - return x.dtype - elif self.is_sparse(x): - return x.dtype - else: - raise TypeError(f'Expected Jax ndarray or BCOO/BCSR, got {type(x)}.') - - def shape(self, x: Any) -> tuple[int, ...]: - """ - Return array shape metadata using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Tuple describing the logical shape of x. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - """ - return tuple(x.shape) - - def ndim(self, x: Any) -> int: - """ - Return array rank metadata using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Number of dimensions in x. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - """ - return int(x.ndim) - - def size(self, x: Any) -> int: - """ - Return logical element count using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Total number of logical dense elements. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - - Backend-specific notes: - Shape-polymorphic dimensions may not be concrete Python integers inside traced code. - """ - result = 1 - for dim in self.shape(x): - result *= dim - return result - @property def dense_array(self) -> Type[Any]: """ Dense array type using JAX. - Returns: + Returns + ------- Concrete dense array class accepted by this backend. See: @@ -212,7 +137,8 @@ def sparse_array(self) -> Tuple[Type[Any], ...]: """ Sparse array type tuple using JAX. - Returns: + Returns + ------- Concrete sparse array classes accepted by this backend, or None. See: @@ -220,121 +146,6 @@ def sparse_array(self) -> Tuple[Type[Any], ...]: """ return (self.jsparse.BCOO, self.jsparse.BCSR) - @property - def inf(self): - """ - Positive infinity scalar using JAX. - - Returns: - Backend scalar representing positive infinity. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.inf.html - """ - return self.jnp.array(self.jnp.inf) - - @property - def nan(self): - """ - NaN scalar using JAX. - - Input: - None. - - Output: - Backend scalar representing NaN. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.nan.html - """ - return self.jnp.array(self.jnp.nan) - - @property - def pi(self): - """ - Pi scalar using JAX. - - Input: - None. - - Output: - Backend scalar representing pi. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.pi.html - """ - return self.jnp.array(self.jnp.pi) - - @property - def e(self): - """ - Euler number scalar using JAX. - - Input: - None. - - Output: - Backend scalar representing Euler's number. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.e.html - """ - return self.jnp.array(self.jnp.e) - - @property - def eps(self): - """ - Machine epsilon scalar using JAX. - - Input: - None. - - Output: - Backend scalar for float64 machine epsilon. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.finfo.html - """ - return self.jnp.array(self.jnp.finfo(self.jnp.float64).eps) - - def asarray( - self, - a: Any, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] | None = None, - *, - copy: bool | None = None, - device: Any | None = None, - ) -> DenseArray: - """ - Convert input to a dense array using JAX. - - Input: - x/a: Array-like input and optional dtype or backend conversion parameters. - - Output: - Dense backend array. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.asarray.html - """ - return self.jnp.asarray(a, dtype=dtype, order=order, copy=copy, device=device) - - def astype(self, x: DenseArray, dtype: DType, copy: bool = True) -> DenseArray: - """ - Cast an array to a dtype using JAX. - - Input: - x: Dense backend array; dtype: target dtype and optional casting controls. - - Output: - Dense backend array with the requested dtype. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.Array.astype.html - """ - return x.astype(dtype, copy=copy) - def assparse( self, x: Any, @@ -407,1433 +218,198 @@ def assparse( raise ValueError(f"Unknown sparse format: {format!r}") - def empty( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: """ - Create an uninitialized dense array using JAX. + Multiply sparse and dense arrays using JAX. Input: - shape: Output shape; dtype and placement options are backend-specific. + a: Sparse backend array; b: Dense backend array. Output: - Dense backend array with uninitialized values. + Dense backend array containing the product. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.empty.html + https://docs.jax.dev/en/latest/jax.experimental.sparse.html Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + Uses JAX sparse matmul and returns a JAX array; sparse support remains experimental in JAX. """ - if self._empty_supports_out_sharding: - return self.jnp.empty(shape, dtype=dtype, device=device, out_sharding=out_sharding) - return self.jnp.empty(shape, dtype=dtype, device=device) + return a @ b - def zeros( + def vmap( self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Create a zero-filled dense array using JAX. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array filled with zeros. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.zeros.html + fn: Callable, + in_axes: int | Sequence[int | None] | None = 0, + out_axes: int | Sequence[int | None] | None = 0, + ) -> Callable: + """Vectorize a function using ``jax.vmap``.""" + return self.jax.vmap(fn, in_axes=in_axes, out_axes=out_axes) - Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. - """ - if self._zeros_supports_out_sharding: - return self.jnp.zeros(shape, dtype=dtype, device=device, out_sharding=out_sharding) - return self.jnp.zeros(shape, dtype=dtype, device=device) - - def ones( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, + return_sign: bool = False, where: DenseArray | None = None) -> DenseArray | Tuple[DenseArray, DenseArray]: """ - Create a one-filled dense array using JAX. + Compute a stable log-sum-exp reduction using JAX. Input: - shape: Output shape; dtype and placement options are backend-specific. + a: Dense backend array; axis, weights, and sign options control the reduction. Output: - Dense backend array filled with ones. + Dense backend array or tuple containing log-sum-exp results. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ones.html + https://docs.jax.dev/en/latest/_autosummary/jax.scipy.special.logsumexp.html """ - return self.jnp.ones(shape, dtype=dtype, device=device, out_sharding=out_sharding) + return self.jax.scipy.special.logsumexp(a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign, where=where) - def zeros_like( - self, - x: DenseArray, - dtype: DType | None = None, - shape: Any = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def index_set(self, x: DenseArray, index: Index, values: ArrayLike, *, copy: bool = True): """ - Create zeros shaped like another array using JAX. + Set indexed values using JAX. Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. Output: - Dense backend array of zeros. + Dense backend array with indexed values set. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.zeros_like.html + https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + JAX arrays are immutable; copy=False raises NotImplementedError. """ - kwargs: dict[str, Any] = {"dtype": dtype, "shape": shape, "device": device} - if self._zeros_like_supports_out_sharding: - kwargs["out_sharding"] = out_sharding - return self.jnp.zeros_like(x, **kwargs) + if not copy: + raise NotImplementedError( + "JAX arrays are immutable; copy=False is not supported." + ) + return x.at[index].set(values) - def ones_like( - self, - x: DenseArray, - dtype: DType | None = None, - shape: Any = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Create ones shaped like another array using JAX. + def ix_(self, *args: Any) -> Any: + r""" + Build open mesh index arrays using JAX. Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + args: One-dimensional index arrays or sequences. Output: - Dense backend array of ones. + Tuple of dense backend arrays usable for open-mesh indexing. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ones_like.html - - Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ix\\_.html """ - kwargs: dict[str, Any] = {"dtype": dtype, "shape": shape, "device": device} - if self._ones_like_supports_out_sharding: - kwargs["out_sharding"] = out_sharding - return self.jnp.ones_like(x, **kwargs) + return self.jnp.ix_(*args) - def full_like( + def fori_loop( self, - x: DenseArray, - value: Any, - dtype: DType | None = None, - shape: Any = None, + lower: int, + upper: int, + body_fun: Callable[[int, T], T], + init_val: T, *, - device: Any | None = None, - ) -> DenseArray: + unroll: int | bool | None = None, + ) -> T: """ - Create filled values shaped like another array using JAX. + Run a counted loop primitive using JAX. Input: - x: Prototype dense array; value/fill_value and dtype options are backend-specific. + lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. Output: - Dense backend array filled with the requested value. + Final carry value after loop execution. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.full_like.html - """ - return self.jnp.full_like(x, value, dtype=dtype, shape=shape, device=device) - - def arange(self, - start: int, - stop: int | None = None, - step: int | None = None, - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Create evenly spaced integer-range values using JAX. - - Input: - start, stop, step: Range parameters; dtype and placement options are backend-specific. - - Output: - One-dimensional dense backend array. + https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.arange.html + Backend-specific notes: + Loop bounds and unroll behavior follow JAX tracing and compilation rules. """ - return self.jnp.arange(start, stop, step, dtype=dtype, device=device, out_sharding=out_sharding) + return self.jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll) - def full( + def while_loop( self, - shape: int | Tuple[int, ...], - fill_value: Any, - dtype: DType | None = None, - *, - device: Any | None = None, - ) -> DenseArray: + cond_fun: Callable[[T], bool], + body_fun: Callable[[T], T], + init_val: T, + ) -> T: """ - Create a filled dense array using JAX. + Run a while-loop primitive using JAX. Input: - shape: Output shape; fill_value and dtype options are backend-specific. + cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. Output: - Dense backend array filled with fill_value. + Final carry value after loop execution. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.full.html - """ - return self.jnp.full(shape, fill_value, dtype=dtype, device=device) - - def eye( - self, - N: int, - M: int | None = None, - k: int = 0, - dtype: DType | None = None, - *, - device: Any | None = None, - ) -> DenseArray: - """ - Create a dense identity-like matrix using JAX. - - Input: - n and optional m: Matrix dimensions; dtype and placement options are backend-specific. - - Output: - Two-dimensional dense backend array. + https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.eye.html + Backend-specific notes: + Condition and body are staged according to JAX lax control-flow semantics. """ - return self.jnp.eye(N=N, M=M, k=k, dtype=dtype, device=device) + return self.jax.lax.while_loop(cond_fun, body_fun, init_val) - def ravel( + def scan( self, - a: DenseArray, - order: Literal["C", "F", "A", "K"] = "C", - *, - out_sharding: Any | None = None, - ) -> DenseArray: + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + _split_transpose: bool = False, + ) -> Tuple[Carry, Y]: """ - Flatten an array using JAX. + Run a scan primitive using JAX. Input: - x: Dense backend array plus optional order parameters. + f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. Output: - One-dimensional dense backend array view or copy. + Tuple of final carry and stacked outputs. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ravel.html - """ - if order == "C" and out_sharding is None: - return self.jnp.ravel(a) - if self._ravel_supports_out_sharding: - return self.jnp.ravel(a, order=order, out_sharding=out_sharding) - return self.jnp.ravel(a, order=order) + https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html - def reshape( - self, - a: DenseArray, - shape: int | Tuple[int, ...], - order: Literal["C", "F", "A"] = "C", - *, - copy: bool | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + Backend-specific notes: + Inputs and outputs may be pytrees and are staged according to JAX lax.scan semantics. """ - Reshape an array using JAX. - - Input: - x: Dense backend array; shape: New shape plus backend-specific options. - - Output: - Dense backend array with the requested shape. + return self.jax.lax.scan(f, init, xs, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose) - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.reshape.html - """ - if order == "C" and copy is None and out_sharding is None: - return self.jnp.reshape(a, shape) - kwargs: dict[str, Any] = {"order": order} - if self._reshape_supports_copy: - kwargs["copy"] = copy - if self._reshape_supports_out_sharding: - kwargs["out_sharding"] = out_sharding - return self.jnp.reshape(a, shape, **kwargs) - - def transpose( - self, - x: DenseArray, - axes: Sequence[int] | None = None, - ) -> DenseArray: + def cond( + self, + pred: bool, + true_fun: Callable[[T], R], + false_fun: Callable[[T], R], + *operands: Any, + ) -> R: """ - Permute array axes using JAX. + Run conditional branch selection using JAX. Input: - x: Dense backend array; axes: Optional axis order. + pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. Output: - Dense backend array with permuted axes. + Result returned by the selected branch. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.transpose.html + https://docs.jax.dev/en/latest/_autosummary/jax.lax.cond.html + + Backend-specific notes: + Branches are staged according to JAX lax.cond semantics rather than Python eager branching. """ - return self.jnp.transpose(x, axes=axes) + return self.jax.lax.cond(pred, true_fun, false_fun, *operands) - def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: + def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): """ - Interchange two axes using JAX. + Add into indexed values using JAX. Input: - x: Dense backend array; axis1 and axis2: Axes to swap. + x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. Output: - Dense backend array with the two axes exchanged. + Dense backend array with indexed values incremented. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.swapaxes.html - """ - return self.jnp.swapaxes(x, axis1, axis2) - - def broadcast_to( - self, - x: DenseArray, - shape: int | Tuple[int, ...], - *, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Broadcast an array to a shape using JAX. - - Input: - x: Dense backend array; shape: Target broadcast shape. - - Output: - Dense backend array with broadcast shape. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.broadcast_to.html - """ - if self._broadcast_to_supports_out_sharding: - return self.jnp.broadcast_to(x, shape, out_sharding=out_sharding) - return self.jnp.broadcast_to(x, shape) - - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Insert length-one axes using JAX. - - Input: - x: Dense backend array; axis: Position or positions to insert. - - Output: - Dense backend array with expanded rank. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.expand_dims.html - """ - return self.jnp.expand_dims(x, axis=axis) - - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Remove length-one axes using JAX. - - Input: - x: Dense backend array; axis: Optional axes to squeeze. - - Output: - Dense backend array with selected singleton dimensions removed. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.squeeze.html - """ - return self.jnp.squeeze(x, axis=axis) - - def moveaxis( - self, - x: DenseArray, - source: int | Sequence[int], - destination: int | Sequence[int], - ) -> DenseArray: - """ - Move axes to new positions using JAX. - - Input: - x: Dense backend array; source and destination: Axis positions. - - Output: - Dense backend array with moved axes. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.moveaxis.html - """ - return self.jnp.moveaxis(x, source=source, destination=destination) - - def stack( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - out: Any | None = None, - dtype: DType | None = None, - ) -> DenseArray: - """ - Stack arrays along a new axis using JAX. - - Input: - arrays: Sequence of dense backend arrays; axis: New axis position. - - Output: - Dense backend array containing stacked inputs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.stack.html - """ - return self.jnp.stack(arrays, axis=axis, out=out, dtype=dtype) - - def conj(self, x: DenseArray) -> DenseArray: - """ - Compute complex conjugates using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array with conjugated values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.conj.html - """ - return self.jnp.conj(x) - - def real(self, x: DenseArray) -> DenseArray: - """ - Extract real components using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing real components. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.real.html - """ - return self.jnp.real(x) - - def imag(self, x: DenseArray) -> DenseArray: - """ - Extract imaginary components using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing imaginary components. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.imag.html - """ - return self.jnp.imag(x) - - def abs(self, x: DenseArray) -> DenseArray: - """ - Compute absolute values using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of absolute values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.abs.html - """ - return self.jnp.abs(x) - - def sign(self, x: DenseArray) -> DenseArray: - """ - Compute signs elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of signs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sign.html - """ - return self.jnp.sign(x) - - def sqrt(self, x: DenseArray) -> DenseArray: - """ - Compute square roots elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of square roots. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sqrt.html - """ - return self.jnp.sqrt(x) - - def sum( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: Any | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - promote_integers: bool = True, - ) -> DenseArray: - """ - Reduce by summation using JAX. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing sums. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sum.html - """ - if out is None and initial is None and where is None and promote_integers: - if axis is None and dtype is None and not keepdims: - return self.jnp.sum(a) - return self.jnp.sum(a, axis=axis, dtype=dtype, keepdims=keepdims) - return self.jnp.sum( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - promote_integers=promote_integers, - ) - - def mean( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: None = None, - keepdims: bool = False, - *, - where: DenseArray | None = None, - ) -> DenseArray: - """ - Reduce by arithmetic mean using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing means. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.mean.html - """ - if out is None and where is None: - if axis is None and dtype is None and not keepdims: - return self.jnp.mean(a) - return self.jnp.mean(a, axis=axis, dtype=dtype, keepdims=keepdims) - return self.jnp.mean(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where) - - def min( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - ) -> DenseArray: - """ - Reduce by minimum using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing minima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.min.html - """ - if out is None and initial is None and where is None: - if axis is None and not keepdims: - return self.jnp.min(a) - return self.jnp.min(a, axis=axis, keepdims=keepdims) - return self.jnp.min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def max( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - ) -> DenseArray: - """ - Reduce by maximum using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing maxima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.max.html - """ - if out is None and initial is None and where is None: - if axis is None and not keepdims: - return self.jnp.max(a) - return self.jnp.max(a, axis=axis, keepdims=keepdims) - return self.jnp.max(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def prod( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: Any | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - promote_integers: bool = True, - ) -> DenseArray: - """ - Reduce by product using JAX. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing products. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.prod.html - """ - if out is None and initial is None and where is None and promote_integers: - if axis is None and dtype is None and not keepdims: - return self.jnp.prod(a) - return self.jnp.prod(a, axis=axis, dtype=dtype, keepdims=keepdims) - return self.jnp.prod( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - promote_integers=promote_integers, - ) - - def trace( - self, - a: DenseArray, - offset: int | Any = 0, - axis1: int = 0, - axis2: int = 1, - dtype: DType | None = None, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Sum diagonal entries using JAX. - - Input: - x: Dense backend array plus optional diagonal and axis controls. - - Output: - Dense backend array or scalar containing trace values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.trace.html - """ - return self.jnp.trace(a, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out) - - def argsort( - self, - a: DenseArray, - axis: int | None = -1, - kind: None = None, - order: None = None, - stable: bool = True, - descending: bool = False - ) -> DenseArray: - """ - Return sorting indices using JAX. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense integer backend array of indices. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argsort.html - """ - return self.jnp.argsort(a, axis=axis, kind=kind, order=order, stable=stable, descending=descending) - - def sort( - self, - a: DenseArray, - axis: int | None = -1, - kind: None = None, - order: None = None, - stable: bool = True, - descending: bool = False - ) -> DenseArray: - """ - Sort values using JAX. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense backend array with sorted values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sort.html - """ - return self.jnp.sort(a, axis=axis, kind=kind, order=order, stable=stable, descending=descending) - - def argmin( - self, - a: DenseArray, - axis: int | None = None, - out: Any | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of minimum values using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmin.html - """ - return self.jnp.argmin(a, axis=axis, out=out, keepdims=keepdims) - - def argmax( - self, - a: DenseArray, - axis: int | None = None, - out: Any | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of maximum values using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmax.html - """ - return self.jnp.argmax(a, axis=axis, out=out, keepdims=keepdims) - - def vdot( - self, - a: DenseArray, - b: DenseArray, - *, - precision: Any | None = None, - preferred_element_type: DType | None = None, - ) -> DenseArray: - """ - Compute a conjugating vector dot product using JAX. - - Input: - x, y: Dense backend arrays accepted by the backend vdot operation. - - Output: - Backend scalar or dense array containing the dot product. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.vdot.html - """ - if precision is None and preferred_element_type is None: - return self.jnp.vdot(a, b) - return self.jnp.vdot(a, b, precision=precision, preferred_element_type=preferred_element_type) - - def matmul( - self, - a: DenseArray, - b: DenseArray, - *, - precision: Any | None = None, - preferred_element_type: DType | None = None, - out_sharding: Any | None = None - ) -> DenseArray: - """ - Compute matrix products using JAX. - - Input: - a, b: Dense backend arrays with matrix-multiplication-compatible shapes. - - Output: - Dense backend array containing the product. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.matmul.html - """ - if precision is None and preferred_element_type is None and out_sharding is None: - return self.jnp.matmul(a, b) - return self.jnp.matmul( - a, - b, - precision=precision, - preferred_element_type=preferred_element_type, - out_sharding=out_sharding, - ) - - def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: - """ - Multiply sparse and dense arrays using JAX. - - Input: - a: Sparse backend array; b: Dense backend array. - - Output: - Dense backend array containing the product. - - See: - https://docs.jax.dev/en/latest/jax.experimental.sparse.html - - Backend-specific notes: - Uses JAX sparse matmul and returns a JAX array; sparse support remains experimental in JAX. - """ - return a @ b - - def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Compute a Kronecker product using JAX. - - Input: - a, b: Dense backend arrays. - - Output: - Dense backend array containing the Kronecker product. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.kron.html - """ - return self.jnp.kron(a, b) - - def einsum( - self, - subscripts: str, - /, - *operands: DenseArray, - out: Any | None = None, - optimize: str | bool | list[Tuple[int, ...]] = "auto", - precision: Any | None = None, - preferred_element_type: DType | None = None, - _dot_general: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Evaluate an Einstein summation expression using JAX. - - Input: - subscripts: Einstein summation string; operands: Dense backend arrays. - - Output: - Dense backend array containing the contraction result. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.einsum.html - """ - return self.jnp.einsum( - subscripts, - *operands, - out=out, - optimize=optimize, - precision=precision, - preferred_element_type=preferred_element_type, - # _dot_general=_dot_general, - out_sharding=out_sharding, - ) - - def eigh( - self, - x: DenseArray, - UPLO: Literal["L", "U"] = "L", - symmetrize_input: bool = True - ) -> Tuple[DenseArray, DenseArray]: - """ - Compute Hermitian eigenpairs using JAX. - - Input: - x: Dense Hermitian or symmetric backend array. - - Output: - Tuple of dense backend arrays containing eigenvalues and eigenvectors. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.eigh.html - - Backend-specific notes: - SpaceCore rejects sparse input before delegating to JAX dense linear algebra. - """ - if self.is_sparse(x): - raise TypeError("eigh requires a dense array; sparse input is not supported.") - return self.jnp.linalg.eigh(x, UPLO=UPLO, symmetrize_input=symmetrize_input) - - def norm( - self, - x: DenseArray, - ord: int | str | None = None, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Compute vector or matrix norms using JAX. - - Input: - x: Dense backend array; ord, axis, and keepdims select the norm. - - Output: - Dense backend array or scalar containing norm values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.norm.html - """ - return self.jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) - - def solve(self, A: DenseArray, b: DenseArray) -> DenseArray: - """ - Solve dense linear systems using JAX. - - Input: - A: Dense coefficient array; b: Dense right-hand side array. - - Output: - Dense backend array solving A @ x = b. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.solve.html - """ - return self.jnp.linalg.solve(A, b) - - def eigvalsh( - self, - A: DenseArray, - UPLO: Literal["L", "U"] = "L", - *, - symmetrize_input: bool = True, - ) -> DenseArray: - """ - Compute Hermitian eigenvalues using JAX. - - Input: - A: Dense Hermitian or symmetric backend array. - - Output: - Dense backend array containing eigenvalues. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.eigvalsh.html - """ - return self.jnp.linalg.eigvalsh(A, UPLO=UPLO, symmetrize_input=symmetrize_input) - - def svd( - self, - A: DenseArray, - full_matrices: bool = True, - compute_uv: bool = True, - hermitian: bool = False, - subset_by_index: tuple[int, int] | None = None, - ) -> DenseArray | Tuple[DenseArray, DenseArray, DenseArray]: - """ - Compute singular value decompositions using JAX. - - Input: - A: Dense backend array plus SVD options. - - Output: - Dense backend arrays containing singular vectors and/or singular values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html - """ - return self.jnp.linalg.svd( - A, - full_matrices=full_matrices, - compute_uv=compute_uv, - hermitian=hermitian, - subset_by_index=subset_by_index, - ) - - def cholesky( - self, - A: DenseArray, - *, - upper: bool = False, - symmetrize_input: bool = True, - ) -> DenseArray: - """ - Compute Cholesky factors using JAX. - - Input: - A: Dense Hermitian positive-definite backend array. - - Output: - Dense backend array containing a triangular factor. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.cholesky.html - """ - return self.jnp.linalg.cholesky(A, upper=upper, symmetrize_input=symmetrize_input) - - def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, - return_sign: bool = False, where: DenseArray | None = None) -> DenseArray | Tuple[DenseArray, DenseArray]: - """ - Compute a stable log-sum-exp reduction using JAX. - - Input: - a: Dense backend array; axis, weights, and sign options control the reduction. - - Output: - Dense backend array or tuple containing log-sum-exp results. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.scipy.special.logsumexp.html - """ - return self.jax.scipy.special.logsumexp(a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign, where=where) - - def exp(self, x: DenseArray) -> DenseArray: - """ - Compute exponentials elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of exponentials. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.exp.html - """ - return self.jnp.exp(x) - - def log(self, x: DenseArray) -> DenseArray: - """ - Compute natural logarithms elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of logarithms. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.log.html - """ - return self.jnp.log(x) - - def maximum(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Compute elementwise maxima using JAX. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing maxima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.maximum.html - """ - return self.jnp.maximum(x, y) - - def minimum(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Compute elementwise minima using JAX. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing minima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.minimum.html - """ - return self.jnp.minimum(x, y) - - def clip(self, x: DenseArray, a_min: DenseArray, a_max: DenseArray) -> DenseArray: - """ - Clip values into an interval using JAX. - - Input: - x: Dense backend array; a_min and a_max: Broadcastable bounds. - - Output: - Dense backend array with clipped values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.clip.html - """ - return self.jnp.clip(x, a_min, a_max) - - def isfinite(self, x: DenseArray) -> DenseArray: - """ - Test finiteness elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isfinite.html - """ - return self.jnp.isfinite(x) - - def isnan(self, x: DenseArray) -> DenseArray: - """ - Test NaN values elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isnan.html - """ - return self.jnp.isnan(x) - - def where(self, condition: DenseArray | bool, x: DenseArray | None = None, y: DenseArray | None = None, *, - size: int | None = None, fill_value: DenseArray | None = None) -> DenseArray: - """ - Select values by condition using JAX. - - Input: - condition: Boolean array or scalar; x and y: Values to choose between. - - Output: - Dense backend array containing selected values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.where.html - """ - return self.jnp.where(condition, x, y, size=size, fill_value=fill_value) - - def concatenate(self, arrays: Sequence[DenseArray], axis: int = 0, dtype: DType | None = None) -> DenseArray: - """ - Join arrays along an existing axis using JAX. - - Input: - arrays: Sequence of dense backend arrays; axis and dtype options are backend-specific. - - Output: - Dense backend array containing concatenated inputs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.concatenate.html - """ - if dtype is None: - return self.jnp.concatenate(arrays, axis=axis) - return self.jnp.concatenate(arrays, axis=axis, dtype=dtype) - - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - out: None = None, - mode: str | None = None, - unique_indices: bool = False, - indices_are_sorted: bool = False, - fill_value: Any | None = None, - ) -> DenseArray: - """ - Take values by integer indices using JAX. - - Input: - x: Dense backend array; indices: Integer indices; axis and mode options are backend-specific. - - Output: - Dense backend array containing selected values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.take.html - - Backend-specific notes: - Out-of-bounds and mode behavior follow JAX, which can differ from NumPy. - """ - return self.jnp.take( - x, - indices, - axis=axis, - out=out, - mode=mode, - unique_indices=unique_indices, - indices_are_sorted=indices_are_sorted, - fill_value=fill_value, - ) - - def diag(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Extract or build a diagonal using JAX. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array containing a diagonal view/copy or matrix. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.diag.html - """ - return self.jnp.diag(x, k=k) - - def diagonal( - self, - x: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - ) -> DenseArray: - """ - Return selected diagonals using JAX. - - Input: - x: Dense backend array plus offset and axis controls. - - Output: - Dense backend array containing selected diagonals. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.diagonal.html - """ - return self.jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) - - def tril(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return lower-triangular values using JAX. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with upper entries zeroed. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.tril.html - """ - return self.jnp.tril(x, k=k) - - def triu(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return upper-triangular values using JAX. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with lower entries zeroed. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.triu.html - """ - return self.jnp.triu(x, k=k) - - def index_set(self, x: DenseArray, index: Index, values: ArrayLike, *, copy: bool = True): - """ - Set indexed values using JAX. - - Input: - x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. - - Output: - Dense backend array with indexed values set. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html - - Backend-specific notes: - JAX arrays are immutable; copy=False raises NotImplementedError. - """ - if not copy: - raise NotImplementedError( - "JAX arrays are immutable; copy=False is not supported." - ) - return x.at[index].set(values) - - def ix_(self, *args: Any) -> Any: - """ - Build open mesh index arrays using JAX. - - Input: - args: One-dimensional index arrays or sequences. - - Output: - Tuple of dense backend arrays usable for open-mesh indexing. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ix\\_.html - """ - return self.jnp.ix_(*args) - - def fori_loop( - self, - lower: int, - upper: int, - body_fun: Callable[[int, T], T], - init_val: T, - *, - unroll: int | bool | None = None, - ) -> T: - """ - Run a counted loop primitive using JAX. - - Input: - lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html - - Backend-specific notes: - Loop bounds and unroll behavior follow JAX tracing and compilation rules. - """ - return self.jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll) - - def while_loop( - self, - cond_fun: Callable[[T], bool], - body_fun: Callable[[T], T], - init_val: T, - ) -> T: - """ - Run a while-loop primitive using JAX. - - Input: - cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html - - Backend-specific notes: - Condition and body are staged according to JAX lax control-flow semantics. - """ - return self.jax.lax.while_loop(cond_fun, body_fun, init_val) - - def scan( - self, - f: Callable[[Carry, X], Tuple[Carry, Y]], - init: Carry, - xs: X, - length: Optional[int] = None, - reverse: bool = False, - unroll: int = 1, - _split_transpose: bool = False, - ) -> Tuple[Carry, Y]: - """ - Run a scan primitive using JAX. - - Input: - f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. - - Output: - Tuple of final carry and stacked outputs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html - - Backend-specific notes: - Inputs and outputs may be pytrees and are staged according to JAX lax.scan semantics. - """ - return self.jax.lax.scan(f, init, xs, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose) - - def cond( - self, - pred: bool, - true_fun: Callable[[T], R], - false_fun: Callable[[T], R], - *operands: Any, - ) -> R: - """ - Run conditional branch selection using JAX. - - Input: - pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. - - Output: - Result returned by the selected branch. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.cond.html - - Backend-specific notes: - Branches are staged according to JAX lax.cond semantics rather than Python eager branching. - """ - return self.jax.lax.cond(pred, true_fun, false_fun, *operands) - - def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): - """ - Add into indexed values using JAX. - - Input: - x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. - - Output: - Dense backend array with indexed values incremented. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html + https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html Backend-specific notes: JAX arrays are immutable; copy=False raises NotImplementedError and repeated indices follow JAX scatter-add semantics. @@ -1844,28 +420,6 @@ def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bo ) return x.at[index].add(values) - def allclose( - self, - a: DenseArray, - b: DenseArray, - rtol: float = 1e-5, - atol: float = 1e-8, - equal_nan: bool = False, - ) -> bool: - """ - Compare dense arrays elementwise within tolerances using JAX. - - Input: - a, b: Dense backend arrays; rtol, atol, and equal_nan configure comparison. - - Output: - Boolean indicating whether arrays are close. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.allclose.html - """ - return self.jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - def allclose_sparse( self, a: SparseArray, diff --git a/spacecore/backend/jax/_pytree.py b/spacecore/backend/jax/_pytree.py index 55f5fa6..df68672 100644 --- a/spacecore/backend/jax/_pytree.py +++ b/spacecore/backend/jax/_pytree.py @@ -1,20 +1,32 @@ from __future__ import annotations + from typing import TypeVar T = TypeVar("T") -def jax_pytree_class(cls: T) -> T: + +def jax_pytree_class(klass: T) -> T: """ Mark a class as a JAX PyTree node, if JAX is available. Safe to import without JAX installed. + + Parameters + ---------- + klass : type + Class implementing JAX pytree methods. + + Returns + ------- + type + Registered class when JAX is available, otherwise ``klass`` unchanged. """ try: from jax import tree_util except Exception: - return cls + return klass try: - tree_util.register_pytree_node_class(cls) + tree_util.register_pytree_node_class(klass) except Exception: pass - return cls + return klass diff --git a/spacecore/backend/numpy/__init__.py b/spacecore/backend/numpy/__init__.py index 9528a76..0bfc215 100644 --- a/spacecore/backend/numpy/__init__.py +++ b/spacecore/backend/numpy/__init__.py @@ -1 +1,3 @@ -from ._ops import NumpyOps as NumpyOps \ No newline at end of file +"""NumPy backend implementation.""" + +from ._ops import NumpyOps as NumpyOps diff --git a/spacecore/backend/numpy/_ops.py b/spacecore/backend/numpy/_ops.py index 0a309ad..efc9929 100644 --- a/spacecore/backend/numpy/_ops.py +++ b/spacecore/backend/numpy/_ops.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any, Sequence, Tuple, Literal, Callable, Optional, Type -import inspect from .._family import BackendFamily from .._ops import BackendOps @@ -23,6 +22,7 @@ class NumpyOps(BackendOps): scipy.sparse.sparray Methods + ------- Most methods mirror the corresponding NumPy or SciPy signatures and delegate directly to NumPy/SciPy implementations. Backend-specific behavior, dtype promotion, broadcasting, memory layout, and error modes @@ -39,6 +39,7 @@ class NumpyOps(BackendOps): `ops.sp`. Advanced users may use it for SciPy-specific functionality. Notes + ----- Code intended to remain backend-portable should prefer `BackendOps` methods. Direct use of `ops.np` or `ops.sp` is an explicit NumPy/SciPy-specific escape hatch. @@ -47,1350 +48,131 @@ class NumpyOps(BackendOps): When supplied, it must be `"cpu"` or `None`; see the corresponding NumPy documentation for each method. """ + import numpy as np import scipy as sp + import array_api_compat.numpy as xp _family = BackendFamily.numpy.value.lower() _allow_sparse = True - @property - def dense_array(self) -> Type[Any]: - """ - Dense array type using NumPy. - - Returns: - Concrete dense array class accepted by this backend. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html - """ - return self.np.ndarray - - @property - def sparse_array(self) -> Tuple[Type[Any], ...]: - """ - Sparse array type tuple using SciPy. - - Returns: - Concrete sparse array classes accepted by this backend, or None. - - See: - https://docs.scipy.org/doc/scipy/reference/sparse.html - """ - sparse = self.sp.sparse - types: list[type[Any]] = [] - if hasattr(sparse, "spmatrix"): - types.append(sparse.spmatrix) - if hasattr(sparse, "sparray"): - types.append(sparse.sparray) - return tuple(types) - - - def __init__(self) -> None: - self._reshape_supports_copy = "copy" in inspect.signature(self.np.reshape).parameters - - def sanitize_dtype(self, dtype: DType | None) -> DType: - """ - Normalize a dtype specifier using NumPy. - - Input: - dtype: Optional dtype requested by SpaceCore or the caller. - - Output: - Backend dtype object accepted by array constructors. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.dtype.html - """ - - if dtype is None: - return self.np.float64 - return self.np.dtype(dtype) - - def get_dtype(self, x: Any) -> DType: - """ - Return an array dtype using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Backend dtype associated with x. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html - """ - if self.is_dense(x): - return x.dtype - elif self.is_sparse(x): - return x.dtype - else: - raise TypeError(f'Expected Numpy ndarray or SciPy sparse array, got {type(x)}.') - - def shape(self, x: Any) -> tuple[int, ...]: - """ - Return array shape metadata using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Tuple describing the logical shape of x. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html - """ - return tuple(x.shape) - - def ndim(self, x: Any) -> int: - """ - Return array rank metadata using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Number of dimensions in x. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html - """ - return int(x.ndim) - - def size(self, x: Any) -> int: - """ - Return logical element count using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Total number of logical dense elements. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.size.html - - Backend-specific notes: - SciPy sparse inputs are reported by logical dense size, not stored entries. - """ - return int(self.np.prod(self.shape(x), dtype=self.np.intp)) - - @property - def inf(self): - """ - Positive infinity scalar using NumPy. - - Returns: - Backend scalar representing positive infinity. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.inf) - - @property - def nan(self): - """ - NaN scalar using NumPy. - - Returns: - Backend scalar representing NaN. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.nan) - - @property - def pi(self): - """ - Pi scalar using NumPy. - - Returns: - Backend scalar representing pi. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.pi) - - @property - def e(self): - """ - Euler number scalar using NumPy. - - Returns: - Backend scalar representing Euler's number. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.e) - - @property - def eps(self): - """ - Machine epsilon scalar using NumPy. - - Returns: - Backend scalar for float64 machine epsilon. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.finfo.html - """ - return self.np.array(self.np.finfo(self.np.float64).eps) - - def asarray( - self, - a: Any, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] | None = None, - *, - device: str | None = None, - copy: bool | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Convert input to a dense array using NumPy. - - Input: - x/a: Array-like input and optional dtype or backend conversion parameters. - - Output: - Dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.asarray.html - """ - return self.np.asarray( - a, - dtype=dtype, - order=order, - device=device, - copy=copy, - like=like, - ) - - def astype( - self, - x: DenseArray, - dtype: DType, - order: Literal["C", "F", "A", "K"] = "K", - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "unsafe", - subok: bool = True, - copy: bool = True, - ) -> DenseArray: - """ - Cast an array to a dtype using NumPy. - - Input: - x: Dense backend array; dtype: target dtype and optional casting controls. - - Output: - Dense backend array with the requested dtype. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.astype.html - """ - return x.astype(dtype, order=order, casting=casting, subok=subok, copy=copy) - - def assparse(self, x: Any, *, format: Literal["csr", "csc", "coo"] = "csr", dtype: DType | None = None) -> SparseArray: - """ - Convert input to a sparse array using SciPy. - - Input: - x: Dense, sparse, or array-like input plus sparse-format options. - - Output: - Sparse backend array. - - See: - https://docs.scipy.org/doc/scipy/reference/sparse.html - - Backend-specific notes: - SpaceCore currently converts dense inputs to 2-D SciPy sparse matrices in the requested format. - """ - sparse = self.sp.sparse - - if self.is_sparse(x): - if format == "csr": - return x.tocsr() - if format == "csc": - return x.tocsc() - if format == "coo": - return x.tocoo() - raise ValueError(f"Unknown sparse format: {format!r}") - - x_arr = self.asarray(x) - - if x_arr.ndim != 2: - raise ValueError("NumPy/SciPy sparse conversion currently expects a 2D array.") - - if format == "csr": - return sparse.csr_matrix(x_arr) - if format == "csc": - return sparse.csc_matrix(x_arr) - if format == "coo": - return sparse.coo_matrix(x_arr) - - raise ValueError(f"Unknown sparse format: {format!r}") - - def empty( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = float, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create an uninitialized dense array using NumPy. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array with uninitialized values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.empty.html - """ - - return self.np.empty( - shape, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def zeros( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a zero-filled dense array using NumPy. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array filled with zeros. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.zeros.html - """ - return self.np.zeros( - shape, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def ones( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a one-filled dense array using NumPy. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array filled with ones. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ones.html - """ - return self.np.ones( - shape, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def zeros_like( - self, - x: DenseArray, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - subok: bool = True, - shape: int | Tuple[int, ...] | None = None, - ) -> DenseArray: - """ - Create zeros shaped like another array using NumPy. - - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. - - Output: - Dense backend array of zeros. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.zeros_like.html - """ - return self.np.zeros_like(x, dtype=dtype, order=order, subok=subok, shape=shape) - - def ones_like( - self, - x: DenseArray, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - subok: bool = True, - shape: int | Tuple[int, ...] | None = None, - ) -> DenseArray: - """ - Create ones shaped like another array using NumPy. - - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. - - Output: - Dense backend array of ones. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ones_like.html - """ - return self.np.ones_like(x, dtype=dtype, order=order, subok=subok, shape=shape) - - def full_like( - self, - x: DenseArray, - value: Any, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - subok: bool = True, - shape: int | Tuple[int, ...] | None = None, - ) -> DenseArray: - """ - Create filled values shaped like another array using NumPy. - - Input: - x: Prototype dense array; value/fill_value and dtype options are backend-specific. - - Output: - Dense backend array filled with the requested value. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.full_like.html - """ - return self.np.full_like(x, value, dtype=dtype, order=order, subok=subok, shape=shape) - - def arange(self, - start: int, stop: int | None = None, - step: int | None = None, - dtype: DType | None = None, - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create evenly spaced integer-range values using NumPy. - - Input: - start, stop, step: Range parameters; dtype and placement options are backend-specific. - - Output: - One-dimensional dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.arange.html - """ - return self.np.arange( - start, - stop, - step, - dtype=dtype, - device=device, - like=like, - ) - - def full( - self, - shape: int | Tuple[int, ...], - fill_value: Any, - dtype: DType | None = None, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a filled dense array using NumPy. - - Input: - shape: Output shape; fill_value and dtype options are backend-specific. - - Output: - Dense backend array filled with fill_value. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.full.html - """ - return self.np.full( - shape, - fill_value, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def eye( - self, - N: int, - M: int | None = None, - k: int = 0, - dtype: DType | None = float, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a dense identity-like matrix using NumPy. - - Input: - n and optional m: Matrix dimensions; dtype and placement options are backend-specific. - - Output: - Two-dimensional dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.eye.html - """ - return self.np.eye( - N, - M=M, - k=k, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def ravel(self, a: DenseArray, order: Literal["C", "F", "A", "K"] = "C") -> DenseArray: - """ - Flatten an array using NumPy. - - Input: - x: Dense backend array plus optional order parameters. - - Output: - One-dimensional dense backend array view or copy. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ravel.html - """ - return self.np.ravel(a, order=order) - - def reshape( - self, - a: DenseArray, - shape: int | Tuple[int, ...], - order: Literal["C", "F", "A", "K"] = "C", - copy: bool | None = None, - ) -> DenseArray: - """ - Reshape an array using NumPy. - - Input: - x: Dense backend array; shape: New shape plus backend-specific options. - - Output: - Dense backend array with the requested shape. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.reshape.html - """ - if self._reshape_supports_copy: - return self.np.reshape(a, shape, order=order, copy=copy) - if copy: - return self.np.array(a, copy=True).reshape(shape, order=order) - return self.np.reshape(a, shape, order=order) - - def transpose(self, a: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: - """ - Permute array axes using NumPy. - - Input: - x: Dense backend array; axes: Optional axis order. - - Output: - Dense backend array with permuted axes. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.transpose.html - """ - return self.np.transpose(a, axes=axes) - - def swapaxes(self, a: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Interchange two axes using NumPy. - - Input: - x: Dense backend array; axis1 and axis2: Axes to swap. - - Output: - Dense backend array with the two axes exchanged. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.swapaxes.html - """ - return self.np.swapaxes(a, axis1, axis2) - - def broadcast_to( - self, - x: DenseArray, - shape: int | Tuple[int, ...], - subok: bool = False, - ) -> DenseArray: - """ - Broadcast an array to a shape using NumPy. - - Input: - x: Dense backend array; shape: Target broadcast shape. - - Output: - Dense backend array with broadcast shape. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html - """ - return self.np.broadcast_to(x, shape, subok=subok) - - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Insert length-one axes using NumPy. - - Input: - x: Dense backend array; axis: Position or positions to insert. - - Output: - Dense backend array with expanded rank. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html - """ - return self.np.expand_dims(x, axis=axis) - - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Remove length-one axes using NumPy. - - Input: - x: Dense backend array; axis: Optional axes to squeeze. - - Output: - Dense backend array with selected singleton dimensions removed. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html - """ - return self.np.squeeze(x, axis=axis) - - def moveaxis( - self, - x: DenseArray, - source: int | Sequence[int], - destination: int | Sequence[int], - ) -> DenseArray: - """ - Move axes to new positions using NumPy. - - Input: - x: Dense backend array; source and destination: Axis positions. - - Output: - Dense backend array with moved axes. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.moveaxis.html - """ - return self.np.moveaxis(x, source=source, destination=destination) - - def stack( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - out: DenseArray | None = None, - *, - dtype: DType | None = None, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - ) -> DenseArray: - """ - Stack arrays along a new axis using NumPy. - - Input: - arrays: Sequence of dense backend arrays; axis: New axis position. - - Output: - Dense backend array containing stacked inputs. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.stack.html - """ - return self.np.stack(arrays, axis=axis, out=out, dtype=dtype, casting=casting) - - def conj( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute complex conjugates using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array with conjugated values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.conj.html - """ - return self.np.conj( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def abs( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute absolute values using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of absolute values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.abs.html - """ - return self.np.abs( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def sign( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute signs elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of signs. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sign.html - """ - return self.np.sign( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def sqrt( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute square roots elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of square roots. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html - """ - return self.np.sqrt( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def real(self, x: DenseArray) -> DenseArray: - """ - Extract real components using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing real components. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.real.html - """ - return self.np.real(x) - - def imag(self, x: DenseArray) -> DenseArray: - """ - Extract imaginary components using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing imaginary components. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.imag.html - """ - return self.np.imag(x) - - def sum( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by summation using NumPy. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing sums. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sum.html - """ - return self.np.sum( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - ) - - def mean( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by arithmetic mean using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing means. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.mean.html - """ - return self.np.mean(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where) - - def min( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by minimum using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing minima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.min.html - """ - return self.np.min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def max( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by maximum using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing maxima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.max.html - """ - return self.np.max(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def prod( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by product using NumPy. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing products. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.prod.html - """ - return self.np.prod( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - ) - - def trace( - self, - a: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - dtype: DType | None = None, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Sum diagonal entries using NumPy. - - Input: - x: Dense backend array plus optional diagonal and axis controls. - - Output: - Dense backend array or scalar containing trace values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.trace.html - """ - return self.np.trace( - a, - offset=offset, - axis1=axis1, - axis2=axis2, - dtype=dtype, - out=out, - ) - - def argsort( - self, - a: DenseArray, - axis: int = -1, - kind: Literal["quicksort", "mergesort", "heapsort", "stable"] | None = None, - order: str | Sequence[str] | None = None, - *, - stable: bool | None = None, - ) -> DenseArray: - """ - Return sorting indices using NumPy. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense integer backend array of indices. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.argsort.html - """ - return self.np.argsort(a, axis=axis, kind=kind, order=order, stable=stable) - - def sort( - self, - a: DenseArray, - axis: int = -1, - kind: Literal["quicksort", "mergesort", "heapsort", "stable"] | None = None, - order: str | Sequence[str] | None = None, - *, - stable: bool | None = None, - ) -> DenseArray: - """ - Sort values using NumPy. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense backend array with sorted values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sort.html - """ - return self.np.sort(a, axis=axis, kind=kind, order=order, stable=stable) - - def argmin( - self, - a: DenseArray, - axis: int | None = None, - out: DenseArray | None = None, - *, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of minimum values using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.argmin.html - """ - return self.np.argmin(a, axis=axis, out=out, keepdims=keepdims) - - def argmax( - self, - a: DenseArray, - axis: int | None = None, - out: DenseArray | None = None, - *, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of maximum values using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.argmax.html - """ - return self.np.argmax(a, axis=axis, out=out, keepdims=keepdims) - - def vdot(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Compute a conjugating vector dot product using NumPy. - - Input: - x, y: Dense backend arrays accepted by the backend vdot operation. - - Output: - Backend scalar or dense array containing the dot product. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.vdot.html - """ - return self.np.vdot(a, b) - - def matmul( - self, - a: DenseArray, - b: DenseArray, - /, - out: DenseArray | None = None, - *, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute matrix products using NumPy. - - Input: - a, b: Dense backend arrays with matrix-multiplication-compatible shapes. - - Output: - Dense backend array containing the product. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.matmul.html - """ - return self.np.matmul( - a, - b, - out=out, - casting=casting, - order=order, - dtype=dtype, - subok=subok - ) - - def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: - """ - Multiply sparse and dense arrays using SciPy. - - Input: - a: Sparse backend array; b: Dense backend array. - - Output: - Dense backend array containing the product. - - See: - https://docs.scipy.org/doc/scipy/reference/sparse.html - - Backend-specific notes: - Uses SciPy sparse multiplication before returning a dense NumPy result when applicable. - """ - if not self.is_sparse(a): - raise TypeError("sparse_matmul expects `a` to be a SciPy sparse matrix/array.") - if not self.is_dense(b): - raise TypeError("sparse_matmul expects `b` to be a Numpy dense object.") - return a @ b - - def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Compute a Kronecker product using NumPy. - - Input: - a, b: Dense backend arrays. - - Output: - Dense backend array containing the Kronecker product. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.kron.html - """ - return self.np.kron(a, b) + def __init__(self) -> None: + super().__init__() - def einsum( - self, - subscripts: str, - *operands: DenseArray, - out: DenseArray | None = None, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "safe", - optimize: bool | str | Sequence[Any] = False, - ) -> DenseArray: + @property + def dense_array(self) -> Type[Any]: """ - Evaluate an Einstein summation expression using NumPy. - - Input: - subscripts: Einstein summation string; operands: Dense backend arrays. + Dense array type using NumPy. - Output: - Dense backend array containing the contraction result. + Returns + ------- + Concrete dense array class accepted by this backend. See: - https://numpy.org/doc/stable/reference/generated/numpy.einsum.html - """ - return self.np.einsum( - subscripts, - *operands, - out=out, - dtype=dtype, - order=order, - casting=casting, - optimize=optimize, - ) - - def eigh(self, a: DenseArray, UPLO: Literal["L", "U"] = "L") -> Tuple[DenseArray, DenseArray]: + https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html """ - Compute Hermitian eigenpairs using NumPy. + return self.np.ndarray - Input: - x: Dense Hermitian or symmetric backend array. + @property + def sparse_array(self) -> Tuple[Type[Any], ...]: + """ + Sparse array type tuple using SciPy. - Output: - Tuple of dense backend arrays containing eigenvalues and eigenvectors. + Returns + ------- + Concrete sparse array classes accepted by this backend, or None. See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.eigh.html + https://docs.scipy.org/doc/scipy/reference/sparse.html """ - if self.is_sparse(a): - raise TypeError("eigh requires a dense array; sparse input is not supported.") - return self.np.linalg.eigh(a, UPLO=UPLO) + sparse = self.sp.sparse + types: list[type[Any]] = [] + if hasattr(sparse, "spmatrix"): + types.append(sparse.spmatrix) + if hasattr(sparse, "sparray"): + types.append(sparse.sparray) + return tuple(types) - def norm( - self, - x: DenseArray, - ord: int | str | None = None, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: + + def sanitize_dtype(self, dtype: DType | None) -> DType: """ - Compute vector or matrix norms using NumPy. + Normalize a dtype specifier using NumPy. Input: - x: Dense backend array; ord, axis, and keepdims select the norm. + dtype: Optional dtype requested by SpaceCore or the caller. Output: - Dense backend array or scalar containing norm values. + Backend dtype object accepted by array constructors. See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html + https://numpy.org/doc/stable/reference/generated/numpy.dtype.html """ - return self.np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) + if dtype is None: + return self.np.float64 + return self.np.dtype(dtype) - def solve(self, A: DenseArray, b: DenseArray) -> DenseArray: + def assparse(self, x: Any, *, format: Literal["csr", "csc", "coo"] = "csr", dtype: DType | None = None) -> SparseArray: """ - Solve dense linear systems using NumPy. + Convert input to a sparse array using SciPy. Input: - A: Dense coefficient array; b: Dense right-hand side array. + x: Dense, sparse, or array-like input plus sparse-format options. Output: - Dense backend array solving A @ x = b. + Sparse backend array. See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html - """ - return self.np.linalg.solve(A, b) + https://docs.scipy.org/doc/scipy/reference/sparse.html - def eigvalsh(self, A: DenseArray, UPLO: Literal["L", "U"] = "L") -> DenseArray: + Backend-specific notes: + SpaceCore currently converts dense inputs to 2-D SciPy sparse matrices in the requested format. """ - Compute Hermitian eigenvalues using NumPy. - - Input: - A: Dense Hermitian or symmetric backend array. + sparse = self.sp.sparse - Output: - Dense backend array containing eigenvalues. + if self.is_sparse(x): + if format == "csr": + return x.tocsr() + if format == "csc": + return x.tocsc() + if format == "coo": + return x.tocoo() + raise ValueError(f"Unknown sparse format: {format!r}") - See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.eigvalsh.html - """ - return self.np.linalg.eigvalsh(A, UPLO=UPLO) + x_arr = self.asarray(x) - def svd( - self, - A: DenseArray, - full_matrices: bool = True, - compute_uv: bool = True, - hermitian: bool = False, - ) -> DenseArray | Tuple[DenseArray, DenseArray, DenseArray]: - """ - Compute singular value decompositions using NumPy. + if x_arr.ndim != 2: + raise ValueError("NumPy/SciPy sparse conversion currently expects a 2D array.") - Input: - A: Dense backend array plus SVD options. + if format == "csr": + return sparse.csr_matrix(x_arr) + if format == "csc": + return sparse.csc_matrix(x_arr) + if format == "coo": + return sparse.coo_matrix(x_arr) - Output: - Dense backend arrays containing singular vectors and/or singular values. + raise ValueError(f"Unknown sparse format: {format!r}") - See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html - """ - return self.np.linalg.svd( - A, - full_matrices=full_matrices, - compute_uv=compute_uv, - hermitian=hermitian, - ) - - def cholesky(self, A: DenseArray) -> DenseArray: + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: """ - Compute Cholesky factors using NumPy. + Multiply sparse and dense arrays using SciPy. Input: - A: Dense Hermitian positive-definite backend array. + a: Sparse backend array; b: Dense backend array. Output: - Dense backend array containing a triangular factor. + Dense backend array containing the product. See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.cholesky.html + https://docs.scipy.org/doc/scipy/reference/sparse.html + + Backend-specific notes: + Uses SciPy sparse multiplication before returning a dense NumPy result when applicable. """ - return self.np.linalg.cholesky(A) + if not self.is_sparse(a): + raise TypeError("sparse_matmul expects `a` to be a SciPy sparse matrix/array.") + if not self.is_dense(b): + raise TypeError("sparse_matmul expects `b` to be a Numpy dense object.") + return a @ b def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: @@ -1408,362 +190,6 @@ def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: D """ return self.sp.special.logsumexp(a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign) - def exp( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute exponentials elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of exponentials. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.exp.html - """ - return self.np.exp( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def log( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute natural logarithms elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of logarithms. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.log.html - """ - return self.np.log( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def maximum( - self, - x: DenseArray, - y: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute elementwise maxima using NumPy. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing maxima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.maximum.html - """ - return self.np.maximum( - x, - y, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def minimum( - self, - x: DenseArray, - y: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute elementwise minima using NumPy. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing minima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.minimum.html - """ - return self.np.minimum( - x, - y, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def clip( - self, - x: DenseArray, - a_min: DenseArray, - a_max: DenseArray, - out: DenseArray | None = None, - **kwargs: Any, - ) -> DenseArray: - """ - Clip values into an interval using NumPy. - - Input: - x: Dense backend array; a_min and a_max: Broadcastable bounds. - - Output: - Dense backend array with clipped values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.clip.html - """ - return self.np.clip(x, a_min, a_max, out=out, **kwargs) - - def isfinite( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Test finiteness elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html - """ - return self.np.isfinite( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def isnan( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Test NaN values elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.isnan.html - """ - return self.np.isnan( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def where(self, condition: DenseArray | bool, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Select values by condition using NumPy. - - Input: - condition: Boolean array or scalar; x and y: Values to choose between. - - Output: - Dense backend array containing selected values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.where.html - """ - return self.np.where(condition, x, y) - - def concatenate( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - out: DenseArray | None = None, - *, - dtype: DType | None = None, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - ) -> DenseArray: - """ - Join arrays along an existing axis using NumPy. - - Input: - arrays: Sequence of dense backend arrays; axis and dtype options are backend-specific. - - Output: - Dense backend array containing concatenated inputs. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html - """ - return self.np.concatenate(arrays, axis=axis, out=out, dtype=dtype, casting=casting) - - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - out: DenseArray | None = None, - mode: Literal["raise", "wrap", "clip"] = "raise", - ) -> DenseArray: - """ - Take values by integer indices using NumPy. - - Input: - x: Dense backend array; indices: Integer indices; axis and mode options are backend-specific. - - Output: - Dense backend array containing selected values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.take.html - """ - return self.np.take(x, indices, axis=axis, out=out, mode=mode) - - def diag(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Extract or build a diagonal using NumPy. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array containing a diagonal view/copy or matrix. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.diag.html - """ - return self.np.diag(x, k=k) - - def diagonal( - self, - x: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - ) -> DenseArray: - """ - Return selected diagonals using NumPy. - - Input: - x: Dense backend array plus offset and axis controls. - - Output: - Dense backend array containing selected diagonals. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.diagonal.html - """ - return self.np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) - - def tril(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return lower-triangular values using NumPy. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with upper entries zeroed. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.tril.html - """ - return self.np.tril(x, k=k) - - def triu(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return upper-triangular values using NumPy. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with lower entries zeroed. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.triu.html - """ - return self.np.triu(x, k=k) - def index_set(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): """ Set indexed values using NumPy. @@ -1789,7 +215,7 @@ def index_set(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bo return x def ix_(self, *args: Any) -> Any: - """ + r""" Build open mesh index arrays using NumPy. Input: @@ -1877,9 +303,7 @@ def _tree_multimap(self, f: Callable[..., Any], *trees: Any) -> Any: return f(*trees) def _tree_take0(self, xs: Any) -> Any: - """ - Grab a representative leaf to infer leading length. - """ + """Grab a representative leaf to infer leading length.""" if isinstance(xs, dict): return self._tree_take0(next(iter(xs.values()))) if isinstance(xs, (tuple, list)): @@ -1887,9 +311,7 @@ def _tree_take0(self, xs: Any) -> Any: return xs def _tree_index(self, xs: Any, i: int) -> Any: - """ - Take per-step slice xs[i] along axis=0 for each leaf. - """ + """Take per-step slice ``xs[i]`` along axis 0 for each leaf.""" def _idx(a: Any) -> Any: # If it's an ndarray-like with leading axis, slice it; else treat as scalar leaf. @@ -1901,9 +323,9 @@ def _idx(a: Any) -> Any: return self._tree_map(_idx, xs) def _tree_stack(self, ys_list: Sequence[Any]) -> Any: - """ - Stack a list of per-step outputs into a single pytree of arrays - by stacking leaves along axis=0. + """Stack per-step outputs into a single pytree of arrays. + + Leaves are stacked along axis 0. """ if not ys_list: # JAX would return empty stacked outputs when length == 0 @@ -2035,28 +457,6 @@ def index_add( self.np.add.at(y, index, values) return y - def allclose( - self, - a: DenseArray, - b: DenseArray, - rtol: float = 1e-5, - atol: float = 1e-8, - equal_nan: bool = False, - ) -> bool: - """ - Compare dense arrays elementwise within tolerances using NumPy. - - Input: - a, b: Dense backend arrays; rtol, atol, and equal_nan configure comparison. - - Output: - Boolean indicating whether arrays are close. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.allclose.html - """ - return self.np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - def allclose_sparse( self, a: SparseArray, diff --git a/spacecore/backend/torch/__init__.py b/spacecore/backend/torch/__init__.py index 6bca321..b395f7b 100644 --- a/spacecore/backend/torch/__init__.py +++ b/spacecore/backend/torch/__init__.py @@ -1,3 +1,5 @@ +"""PyTorch backend implementation.""" + from ._ops import TorchOps diff --git a/spacecore/backend/torch/_ops.py b/spacecore/backend/torch/_ops.py index cd6defb..299648d 100644 --- a/spacecore/backend/torch/_ops.py +++ b/spacecore/backend/torch/_ops.py @@ -5,8 +5,8 @@ import numpy as np from .._family import BackendFamily -from .._ops import BackendOps -from ...types import ArrayLike, DenseArray, DType, Index, SparseArray, T, X, Y, R, Carry +from .._ops import BackendOps, LazyNamespace +from ...types import DenseArray, DType, Index, SparseArray, T, X, Y, R, Carry class TorchOps(BackendOps): @@ -22,6 +22,7 @@ class TorchOps(BackendOps): torch.Tensor with a PyTorch sparse layout Methods + ------- Most methods mirror the corresponding PyTorch public API signatures and delegate to ``torch`` or ``torch.linalg``. Backend-specific behavior, dtype promotion, broadcasting, device placement, autograd tracking, and @@ -34,6 +35,7 @@ class TorchOps(BackendOps): portable API does not expose a required PyTorch feature. Notes + ----- Code intended to remain backend-portable should prefer ``BackendOps`` methods. Direct use of ``ops.torch`` is an explicit PyTorch-specific escape hatch. @@ -52,6 +54,7 @@ class TorchOps(BackendOps): """ import torch + xp = LazyNamespace("array_api_compat.torch") _family = BackendFamily.torch.value.lower() _allow_sparse = True @@ -64,6 +67,9 @@ class TorchOps(BackendOps): torch.sparse_bsc, ) + def __init__(self) -> None: + super().__init__() + @staticmethod def _defined_kwargs(**kwargs: Any) -> dict[str, Any]: return {key: value for key, value in kwargs.items() if value is not None} @@ -73,7 +79,8 @@ def dense_array(self) -> Type[Any]: """ Dense array type using PyTorch. - Returns: + Returns + ------- Concrete dense tensor class accepted by this backend. See: @@ -86,7 +93,8 @@ def sparse_array(self) -> Tuple[Type[Any], ...]: """ Sparse array type tuple using PyTorch. - Returns: + Returns + ------- Tensor class accepted by this backend for sparse tensor layouts. See: @@ -180,193 +188,6 @@ def sanitize_dtype(self, dtype: DType | None) -> DType: return mapping[np_dtype] raise TypeError(f"Dtype {np_dtype!r} is not supported by PyTorch.") - def get_dtype(self, x: Any) -> DType: - """ - Return a tensor dtype using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Backend dtype associated with x. - - See: - https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch-dtype - """ - if self.is_array(x): - return x.dtype - raise TypeError(f"Expected PyTorch tensor, got {type(x)}.") - - def shape(self, x: Any) -> tuple[int, ...]: - """ - Return tensor shape metadata using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Tuple describing the logical shape of x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.Tensor.shape.html - """ - return tuple(x.shape) - - def ndim(self, x: Any) -> int: - """ - Return tensor rank metadata using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Number of dimensions in x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.Tensor.ndim.html - """ - return int(x.ndim) - - def size(self, x: Any) -> int: - """ - Return logical element count using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Total number of logical dense elements. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.numel.html - """ - return int(x.numel()) - - @property - def inf(self): - """ - Positive infinity scalar using PyTorch. - - Returns: - Backend tensor scalar representing positive infinity. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(float("inf")) - - @property - def nan(self): - """ - NaN scalar using PyTorch. - - Returns: - Backend tensor scalar representing NaN. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(float("nan")) - - @property - def pi(self): - """ - Pi scalar using PyTorch. - - Returns: - Backend tensor scalar representing pi. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(np.pi) - - @property - def e(self): - """ - Euler number scalar using PyTorch. - - Returns: - Backend tensor scalar representing Euler's number. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(np.e) - - @property - def eps(self): - """ - Machine epsilon scalar using PyTorch. - - Returns: - Backend tensor scalar for float64 machine epsilon. - - See: - https://docs.pytorch.org/docs/stable/type_info.html#torch.finfo - """ - return self.torch.tensor(self.torch.finfo(self.torch.float64).eps) - - def asarray( - self, - x: Any, - dtype: DType | None = None, - *, - device: Any | None = None, - copy: bool | None = None, - ) -> DenseArray: - """ - Convert input to a dense tensor using PyTorch. - - Input: - x/a: Array-like input and optional dtype, device, or copy controls. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.as_tensor.html - - Backend-specific notes: - Sparse tensors are densified. Existing tensors keep autograd - metadata according to normal PyTorch conversion rules. - """ - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - if self.is_sparse(x): - x = x.to_dense() - out = self.torch.as_tensor(x, dtype=dtype, device=device) - if copy: - out = out.clone() - return out - - def astype( - self, - x: DenseArray, - dtype: DType, - copy: bool = True, - *, - non_blocking: bool = False, - memory_format: Any | None = None, - ) -> DenseArray: - """ - Cast a tensor to a dtype using PyTorch. - - Input: - x: Dense backend tensor; dtype: Target dtype; copy: Whether to force a copy. - - Output: - Tensor with the requested dtype. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html - """ - return x.to( - dtype=self.sanitize_dtype(dtype), - non_blocking=non_blocking, - copy=copy, - **self._defined_kwargs(memory_format=memory_format), - ) - def assparse( self, x: Any, @@ -427,72 +248,69 @@ def assparse( return out.to_sparse_csc() raise ValueError(f"Unknown sparse format: {format!r}") - def empty( + def asarray( self, - shape: int | Tuple[int, ...], + x: Any, dtype: DType | None = None, *, - out: DenseArray | None = None, - layout: Any | None = None, device: Any | None = None, - requires_grad: bool = False, - pin_memory: bool = False, - memory_format: Any | None = None, + copy: bool | None = None, + backend_kwargs: dict[str, Any] | None = None, ) -> DenseArray: - """ - Create an uninitialized dense tensor using PyTorch. - - Input: - shape: Output shape; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor with uninitialized values. + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + if device is not None: + kwargs["device"] = device + dtype = self.sanitize_dtype(dtype) if dtype is not None else None + if self.is_sparse(x): + x = x.to_dense() + out = self.torch.as_tensor(x, dtype=dtype, **kwargs) + return out.clone() if copy else out - See: - https://docs.pytorch.org/docs/stable/generated/torch.empty.html - """ - return self.torch.empty( - shape, - out=out, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - pin_memory=pin_memory, - **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), + def astype( + self, + x: DenseArray, + dtype: DType | None, + *, + copy: bool = True, + non_blocking: bool = False, + memory_format: Any | None = None, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + if dtype is None: + return x + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(memory_format=memory_format)) + return x.to( + dtype=self.sanitize_dtype(dtype), + non_blocking=non_blocking, + copy=copy, + **kwargs, ) - def zeros( + def empty( self, - shape: int | Tuple[int, ...], + shape: Tuple[int, ...], dtype: DType | None = None, *, out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, + pin_memory: bool = False, + memory_format: Any | None = None, ) -> DenseArray: - """ - Create a dense tensor filled with zeros using PyTorch. - - Input: - shape: Output shape; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.zeros.html - """ - return self.torch.zeros( + return self.torch.empty( shape, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device), + pin_memory=pin_memory, + **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), ) - def ones( + def zeros( self, - shape: int | Tuple[int, ...], + shape: Tuple[int, ...], dtype: DType | None = None, *, out: DenseArray | None = None, @@ -500,19 +318,7 @@ def ones( device: Any | None = None, requires_grad: bool = False, ) -> DenseArray: - """ - Create a dense tensor filled with ones using PyTorch. - - Input: - shape: Output shape; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.ones.html - """ - return self.torch.ones( + return self.torch.zeros( shape, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, @@ -530,18 +336,6 @@ def zeros_like( requires_grad: bool = False, memory_format: Any | None = None, ) -> DenseArray: - """ - Create a zero tensor matching another tensor using PyTorch. - - Input: - x: Reference tensor; dtype and device: Optional overrides. - - Output: - Dense backend tensor with shape matching x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.zeros_like.html - """ return self.torch.zeros_like( x, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, @@ -549,71 +343,11 @@ def zeros_like( **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), ) - def ones_like( - self, - x: DenseArray, - dtype: DType | None = None, - *, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - memory_format: Any | None = None, - ) -> DenseArray: - """ - Create a one tensor matching another tensor using PyTorch. - - Input: - x: Reference tensor; dtype and device: Optional overrides. - - Output: - Dense backend tensor with shape matching x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.ones_like.html - """ - return self.torch.ones_like( - x, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), - ) - - def full_like( - self, - x: DenseArray, - value: Any, - dtype: DType | None = None, - *, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - memory_format: Any | None = None, - ) -> DenseArray: - """ - Create a filled tensor matching another tensor using PyTorch. - - Input: - x: Reference tensor; value: Fill value; dtype and device: Optional overrides. - - Output: - Dense backend tensor with shape matching x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.full_like.html - """ - return self.torch.full_like( - x, - value, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), - ) - def arange( self, - start: int | float = 0, - stop: int | float | None = None, - step: int | float | None = None, + start: int, + stop: int | None = None, + step: int | None = None, dtype: DType | None = None, *, out: DenseArray | None = None, @@ -621,712 +355,43 @@ def arange( device: Any | None = None, requires_grad: bool = False, ) -> DenseArray: - """ - Create a range tensor using PyTorch. - - Input: - start, stop, step: Range parameters; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor containing evenly spaced values. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.arange.html - """ - dtype = self.sanitize_dtype(dtype) if dtype is not None else None kwargs = self._defined_kwargs(out=out, layout=layout, device=device) kwargs["requires_grad"] = requires_grad + dtype = self.sanitize_dtype(dtype) if dtype is not None else None if stop is None: return self.torch.arange(start, dtype=dtype, **kwargs) if step is None: return self.torch.arange(start, stop, dtype=dtype, **kwargs) return self.torch.arange(start, stop, step, dtype=dtype, **kwargs) - def full( + def sum( self, - shape: int | Tuple[int, ...], - fill_value: Any, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, dtype: DType | None = None, *, out: DenseArray | None = None, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - ) -> DenseArray: - """ - Create a filled dense tensor using PyTorch. - - Input: - shape: Output shape; fill_value: Fill value; dtype and device: Optional parameters. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.full.html - """ - return self.torch.full( - shape, - fill_value, - out=out, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device), - ) - - def eye( - self, - n: int, - m: int | None = None, - k: int = 0, - dtype: DType | None = None, - *, - out: DenseArray | None = None, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - ) -> DenseArray: - """ - Create a two-dimensional identity-like tensor using PyTorch. - - Input: - n, m: Matrix dimensions; k: Diagonal offset; dtype and device: Optional parameters. - - Output: - Dense backend tensor with ones on the requested diagonal. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.eye.html - - Backend-specific notes: - PyTorch ``torch.eye`` has no diagonal offset parameter, so SpaceCore - constructs the offset diagonal explicitly. - """ - m = n if m is None else m - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - if k == 0: - return self.torch.eye( - n, - m, - out=out, - dtype=dtype, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device), - ) - out = self.torch.zeros( - (n, m), - out=out, - dtype=dtype, - requires_grad=False, - **self._defined_kwargs(layout=layout, device=device), - ) - diag_len = min(n, m - k) if k > 0 else min(n + k, m) - if diag_len <= 0: - return out - rows = self.torch.arange(diag_len, device=device) - cols = rows + k - if k < 0: - rows = rows - k - cols = self.torch.arange(diag_len, device=device) - out[rows, cols] = 1 - if requires_grad: - out.requires_grad_() - return out - - def ravel(self, x: DenseArray) -> DenseArray: - """ - Flatten a tensor using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - One-dimensional tensor view or copy following PyTorch semantics. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.ravel.html - """ - return self.torch.ravel(x) - - def reshape(self, x: DenseArray, shape: int | Tuple[int, ...], *, copy: bool | None = None) -> DenseArray: - """ - Reshape a tensor using PyTorch. - - Input: - x: Dense backend tensor; shape: Target shape; copy: Whether to clone first. - - Output: - Reshaped dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.reshape.html - """ - if copy: - x = x.clone() - return self.torch.reshape(x, shape if isinstance(shape, tuple) else (shape,)) - - def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: - """ - Permute tensor axes using PyTorch. - - Input: - x: Dense backend tensor; axes: Optional axis order. - - Output: - Tensor with permuted axes. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.permute.html - """ - if axes is None: - axes = tuple(reversed(range(x.ndim))) - return x.permute(tuple(axes)) - - def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Swap two tensor axes using PyTorch. - - Input: - x: Dense backend tensor; axis1, axis2: Axes to swap. - - Output: - Tensor with the requested axes swapped. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.swapaxes.html - """ - return self.torch.swapaxes(x, axis1, axis2) - - def broadcast_to(self, x: DenseArray, shape: int | Tuple[int, ...]) -> DenseArray: - """ - Broadcast a tensor to a shape using PyTorch. - - Input: - x: Dense backend tensor; shape: Target broadcast shape. - - Output: - Broadcasted tensor view following PyTorch broadcasting rules. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.broadcast_to.html - """ - return self.torch.broadcast_to(x, shape) - - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Insert singleton dimensions using PyTorch. - - Input: - x: Dense backend tensor; axis: Axis or axes where dimensions are inserted. - - Output: - Tensor with inserted singleton dimensions. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.unsqueeze.html - """ - if isinstance(axis, int): - return self.torch.unsqueeze(x, axis) - ndim = x.ndim + len(axis) - axes = sorted(a + ndim if a < 0 else a for a in axis) - out = x - for ax in axes: - out = self.torch.unsqueeze(out, ax) - return out - - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Remove singleton dimensions using PyTorch. - - Input: - x: Dense backend tensor; axis: Optional axis or axes to squeeze. - - Output: - Tensor with singleton dimensions removed. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.squeeze.html - """ - if axis is None: - return self.torch.squeeze(x) - if isinstance(axis, int): - return self.torch.squeeze(x, dim=axis) - out = x - for ax in sorted(axis, reverse=True): - out = self.torch.squeeze(out, dim=ax) - return out - - def moveaxis(self, x: DenseArray, source: int | Sequence[int], destination: int | Sequence[int]) -> DenseArray: - """ - Move tensor axes to new positions using PyTorch. - - Input: - x: Dense backend tensor; source and destination: Axis positions. - - Output: - Tensor with axes moved. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.moveaxis.html - """ - return self.torch.moveaxis(x, source, destination) - - def stack( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Stack tensors along a new axis using PyTorch. - - Input: - arrays: Sequence of tensors; axis: New axis; out: Optional output tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.stack.html - """ - arrays = tuple(arrays) - if out is None: - return self.torch.stack(arrays, dim=axis) - return self.torch.stack(arrays, dim=axis, out=out) - - def conj(self, x: DenseArray) -> DenseArray: - """ - Return the complex conjugate using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Tensor containing complex conjugates. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.conj.html - """ - return self.torch.conj(x) - - def real(self, x: DenseArray) -> DenseArray: - """ - Return the real part of a tensor using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Tensor view or value containing real components. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.real.html - """ - return self.torch.real(x) - - def imag(self, x: DenseArray) -> DenseArray: - """ - Return the imaginary part of a tensor using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Tensor view or value containing imaginary components. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.imag.html - """ - return self.torch.imag(x) - - def abs( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise absolute value using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.abs.html - """ - if out is None: - return self.torch.abs(x) - return self.torch.abs(x, out=out) - - def sign( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise sign using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sign.html - """ - if out is None: - return self.torch.sign(x) - return self.torch.sign(x, out=out) - - def sqrt( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise square root using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sqrt.html - """ - if out is None: - return self.torch.sqrt(x) - return self.torch.sqrt(x, out=out) - - def sum( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Sum tensor elements using PyTorch. - - Input: - x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sum.html - """ - if dtype is None: - if out is None: - if axis is None and not keepdims: - return self.torch.sum(x) - return self.torch.sum(x, dim=axis, keepdim=keepdims) - return self.torch.sum(x, dim=axis, keepdim=keepdims, out=out) - dtype = self.sanitize_dtype(dtype) - if out is None: - if axis is None and not keepdims: - return self.torch.sum(x, dtype=dtype) - return self.torch.sum(x, dim=axis, keepdim=keepdims, dtype=dtype) - return self.torch.sum( - x, - dim=axis, - keepdim=keepdims, - dtype=dtype, - out=out, - ) - - def mean( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Average tensor elements using PyTorch. - - Input: - x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.mean.html - """ - if dtype is None: - if out is None: - if axis is None and not keepdims: - return self.torch.mean(x) - return self.torch.mean(x, dim=axis, keepdim=keepdims) - return self.torch.mean(x, dim=axis, keepdim=keepdims, out=out) - dtype = self.sanitize_dtype(dtype) - if out is None: - if axis is None and not keepdims: - return self.torch.mean(x, dtype=dtype) - return self.torch.mean(x, dim=axis, keepdim=keepdims, dtype=dtype) - return self.torch.mean( - x, - dim=axis, - keepdim=keepdims, - dtype=dtype, - out=out, - ) - - def min( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute minimum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.amin.html - """ - if out is None: - if axis is None and not keepdims: - return self.torch.amin(x) - return self.torch.amin(x, dim=axis, keepdim=keepdims) - return self.torch.amin(x, dim=axis, keepdim=keepdims, out=out) - - def max( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute maximum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.amax.html - """ - if out is None: - if axis is None and not keepdims: - return self.torch.amax(x) - return self.torch.amax(x, dim=axis, keepdim=keepdims) - return self.torch.amax(x, dim=axis, keepdim=keepdims, out=out) - - def prod( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, ) -> DenseArray: - """ - Multiply tensor elements using PyTorch. - - Input: - x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.prod.html - - Backend-specific notes: - Multiple-axis products are applied one axis at a time because - PyTorch's ``torch.prod`` reduces a single dimension per call. - """ - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - if axis is None: - result = self.torch.prod(x) if dtype is None else self.torch.prod(x, dtype=dtype) - if out is not None: - out.copy_(result) - return out - return result - if isinstance(axis, int): - if out is None: - if dtype is None: - return self.torch.prod(x, dim=axis, keepdim=keepdims) - return self.torch.prod(x, dim=axis, dtype=dtype, keepdim=keepdims) - return self.torch.prod(x, dim=axis, dtype=dtype, keepdim=keepdims, out=out) - result = x - for ax in sorted(axis, reverse=True): - result = self.torch.prod(result, dim=ax, dtype=dtype, keepdim=keepdims) + kwargs = {"dim": self._to_axis_tuple(axis), "keepdim": keepdims} + if dtype is not None: + kwargs["dtype"] = self.sanitize_dtype(dtype) if out is not None: - out.copy_(result) - return out - return result - - def trace( - self, - x: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - dtype: DType | None = None, - ) -> DenseArray: - """ - Sum diagonal entries using PyTorch. - - Input: - x: Dense backend tensor; offset, axis1, axis2, dtype: Diagonal controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.diagonal.html - """ - return self.sum(self.diagonal(x, offset=offset, axis1=axis1, axis2=axis2), dtype=dtype) - - def argsort( - self, - x: DenseArray, - axis: int = -1, - stable: bool = False, - descending: bool = False, - ) -> DenseArray: - """ - Return sorting indices using PyTorch. - - Input: - x: Dense backend tensor; axis, stable, descending: Sorting controls. - - Output: - Integer tensor of indices. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.argsort.html - """ - return self.torch.argsort(x, dim=axis, stable=stable, descending=descending) - - def sort( - self, - x: DenseArray, - axis: int = -1, - stable: bool = False, - descending: bool = False, - *, - out: tuple[DenseArray, DenseArray] | None = None, - ) -> DenseArray: - """ - Sort tensor values using PyTorch. - - Input: - x: Dense backend tensor; axis, stable, descending: Sorting controls. - - Output: - Dense backend tensor of sorted values. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sort.html - """ - return self.torch.sort(x, dim=axis, stable=stable, descending=descending, out=out).values - - def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Return indices of minimum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Integer tensor of indices. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.argmin.html - """ - return self.torch.argmin(x, dim=axis, keepdim=keepdims) - - def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Return indices of maximum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Integer tensor of indices. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.argmax.html - """ - return self.torch.argmax(x, dim=axis, keepdim=keepdims) - - def vdot( - self, - x: DenseArray, - y: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute conjugating vector dot product using PyTorch. - - Input: - x, y: Dense backend tensors. - - Output: - Scalar tensor containing the vector dot product. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.vdot.html - """ - x1 = x if x.ndim == 1 else self.torch.ravel(x) - y1 = y if y.ndim == 1 else self.torch.ravel(y) - if out is None: - return self.torch.vdot(x1, y1) - return self.torch.vdot(x1, y1, out=out) + kwargs["out"] = out + return self.torch.sum(x, **kwargs) def matmul( self, a: DenseArray, b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, *, out: DenseArray | None = None, ) -> DenseArray: - """ - Matrix-multiply tensors using PyTorch. - - Input: - a, b: Dense backend tensors. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.matmul.html - """ - if out is None: - return self.torch.matmul(a, b) - return self.torch.matmul(a, b, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + if out is not None: + kwargs["out"] = out + return self.torch.matmul(a, b, **kwargs) def sparse_matmul( self, @@ -1352,64 +417,33 @@ def sparse_matmul( return self.torch.sparse.mm(a, b[:, None], **kwargs)[:, 0] return self.torch.sparse.mm(a, b, **kwargs) - def kron( + def vmap( self, - a: DenseArray, - b: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute the Kronecker product using PyTorch. - - Input: - a, b: Dense backend tensors. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.kron.html - """ - return self.torch.kron(a, b, out=out) - - def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: - """ - Evaluate an Einstein summation using PyTorch. - - Input: - subscripts: Einsum expression; operands: Dense backend tensors. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.einsum.html - """ - return self.torch.einsum(subscripts, *operands) + fn: Callable, + in_axes: int | Sequence[int | None] | None = 0, + out_axes: int | Sequence[int | None] | None = 0, + ) -> Callable: + """Vectorize a function using PyTorch's native vmap when available.""" + vmap = getattr(self.torch, "vmap", None) + if vmap is None and hasattr(self.torch, "func"): + vmap = getattr(self.torch.func, "vmap", None) + if vmap is None: + return super().vmap(fn, in_axes=in_axes, out_axes=out_axes) + return vmap(fn, in_dims=in_axes, out_dims=out_axes) def eigh( self, x: DenseArray, + backend_kwargs: dict[str, Any] | None = None, UPLO: Literal["L", "U"] = "L", *, out: tuple[DenseArray, DenseArray] | None = None, ) -> tuple[DenseArray, DenseArray]: - """ - Compute Hermitian eigenvalues and eigenvectors using PyTorch. - - Input: - x: Dense Hermitian or symmetric backend tensor. - - Output: - Tuple of eigenvalues and eigenvectors. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.eigh.html - """ if self.is_sparse(x): raise TypeError("eigh requires a dense array; sparse input is not supported.") - return self.torch.linalg.eigh(x, UPLO=UPLO, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(out=out)) + return self.torch.linalg.eigh(x, UPLO=UPLO, **kwargs) def norm( self, @@ -1421,18 +455,6 @@ def norm( dtype: DType | None = None, out: DenseArray | None = None, ) -> DenseArray: - """ - Compute vector or matrix norms using PyTorch. - - Input: - x: Dense backend tensor; ord, axis, keepdims: Norm controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.norm.html - """ return self.torch.linalg.norm( x, ord=ord, @@ -1446,97 +468,39 @@ def solve( self, A: DenseArray, b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, *, left: bool = True, out: DenseArray | None = None, ) -> DenseArray: - """ - Solve a linear system using PyTorch. - - Input: - A: Coefficient tensor; b: Right-hand side tensor. - - Output: - Dense backend tensor solving ``A @ x = b``. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.solve.html - """ - return self.torch.linalg.solve(A, b, left=left, out=out) - - def eigvalsh( - self, - A: DenseArray, - UPLO: Literal["L", "U"] = "L", - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute Hermitian eigenvalues using PyTorch. - - Input: - A: Dense Hermitian or symmetric backend tensor. - - Output: - Dense backend tensor of eigenvalues. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.eigvalsh.html - """ - return self.torch.linalg.eigvalsh(A, UPLO=UPLO, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(out=out)) + return self.torch.linalg.solve(A, b, left=left, **kwargs) def svd( self, A: DenseArray, full_matrices: bool = True, - compute_uv: bool = True, - hermitian: bool = False, + backend_kwargs: dict[str, Any] | None = None, *, driver: str | None = None, out: DenseArray | tuple[DenseArray, DenseArray, DenseArray] | None = None, ) -> DenseArray | tuple[DenseArray, DenseArray, DenseArray]: - """ - Compute singular value decomposition using PyTorch. - - Input: - A: Dense backend tensor; full_matrices, compute_uv, hermitian: SVD controls. - - Output: - Singular values or tuple ``(U, S, Vh)``. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.svd.html - - Backend-specific notes: - PyTorch does not expose a ``hermitian`` option for SVD. When - ``compute_uv`` is false, this delegates to ``torch.linalg.svdvals``. - """ - if hermitian: - raise NotImplementedError("PyTorch svd does not expose a hermitian option.") - if not compute_uv: - return self.torch.linalg.svdvals(A, driver=driver, out=out) - return self.torch.linalg.svd(A, full_matrices=full_matrices, driver=driver, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(driver=driver, out=out)) + return self.torch.linalg.svd(A, full_matrices=full_matrices, **kwargs) def cholesky( self, A: DenseArray, + backend_kwargs: dict[str, Any] | None = None, *, upper: bool = False, out: DenseArray | None = None, ) -> DenseArray: - """ - Compute a Cholesky factorization using PyTorch. - - Input: - A: Positive-definite dense backend tensor. - - Output: - Dense backend tensor containing the Cholesky factor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.cholesky.html - """ - return self.torch.linalg.cholesky(A, upper=upper, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(out=out)) + return self.torch.linalg.cholesky(A, upper=upper, **kwargs) def logsumexp( self, @@ -1583,188 +547,18 @@ def logsumexp( return out return result - def exp( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise exponential using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.exp.html - """ - if out is None: - return self.torch.exp(x) - return self.torch.exp(x, out=out) - - def log( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise natural logarithm using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.log.html - """ - if out is None: - return self.torch.log(x) - return self.torch.log(x, out=out) - def where( self, condition: DenseArray | bool, - x: ArrayLike | None = None, - y: ArrayLike | None = None, + x: DenseArray, + y: DenseArray, *, out: DenseArray | None = None, ) -> DenseArray: - """ - Select values conditionally using PyTorch. - - Input: - condition: Boolean tensor or scalar; x, y: Values to select. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.where.html - """ - if x is None and y is None: - return self.torch.where(condition) - if x is None or y is None: - raise TypeError("where requires both x and y when either is provided.") if out is None: return self.torch.where(condition, x, y) return self.torch.where(condition, x, y, out=out) - def maximum( - self, - x: ArrayLike, - y: ArrayLike, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise maximum using PyTorch. - - Input: - x, y: Array-like operands. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.maximum.html - """ - y = y if isinstance(y, self.torch.Tensor) else self.asarray(y, dtype=x.dtype, device=x.device) - if out is None: - return self.torch.maximum(x, y) - return self.torch.maximum( - x, - y, - out=out, - ) - - def minimum( - self, - x: ArrayLike, - y: ArrayLike, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise minimum using PyTorch. - - Input: - x, y: Array-like operands. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.minimum.html - """ - y = y if isinstance(y, self.torch.Tensor) else self.asarray(y, dtype=x.dtype, device=x.device) - if out is None: - return self.torch.minimum(x, y) - return self.torch.minimum( - x, - y, - out=out, - ) - - def clip( - self, - x: DenseArray, - a_min: ArrayLike | None = None, - a_max: ArrayLike | None = None, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Clip tensor values using PyTorch. - - Input: - x: Dense backend tensor; a_min, a_max: Lower and upper bounds. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.clamp.html - """ - if out is None: - return self.torch.clamp(x, min=a_min, max=a_max) - return self.torch.clamp(x, min=a_min, max=a_max, out=out) - - def isfinite(self, x: DenseArray) -> DenseArray: - """ - Test finiteness elementwise using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Boolean dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.isfinite.html - """ - return self.torch.isfinite(x) - - def isnan(self, x: DenseArray) -> DenseArray: - """ - Test NaN values elementwise using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Boolean dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.isnan.html - """ - return self.torch.isnan(x) - def concatenate( self, arrays: Sequence[DenseArray], @@ -1773,131 +567,12 @@ def concatenate( *, out: DenseArray | None = None, ) -> DenseArray: - """ - Concatenate tensors using PyTorch. - - Input: - arrays: Sequence of dense backend tensors; axis: Concatenation axis; dtype: Optional cast. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.cat.html - """ - arrays = tuple(arrays) if out is None: - result = self.torch.cat(arrays, dim=axis) + result = self.torch.cat(tuple(arrays), dim=axis) else: - result = self.torch.cat(arrays, dim=axis, out=out) + result = self.torch.cat(tuple(arrays), dim=axis, out=out) return self.astype(result, dtype) if dtype is not None else result - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Take tensor elements by index using PyTorch. - - Input: - x: Dense backend tensor; indices: Integer indices; axis: Optional selection axis. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.take.html - """ - if axis is None: - result = self.torch.take(x, indices) - if out is not None: - out.copy_(result) - return out - return result - return self.torch.index_select(x, dim=axis, index=indices, out=out) - - def diag( - self, - x: DenseArray, - k: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Extract or construct a diagonal tensor using PyTorch. - - Input: - x: Dense backend tensor; k: Diagonal offset. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.diag.html - """ - return self.torch.diag(x, diagonal=k, out=out) - - def diagonal(self, x: DenseArray, offset: int = 0, axis1: int = 0, axis2: int = 1) -> DenseArray: - """ - Return a tensor diagonal using PyTorch. - - Input: - x: Dense backend tensor; offset, axis1, axis2: Diagonal controls. - - Output: - Dense backend tensor view or value. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.diagonal.html - """ - return self.torch.diagonal(x, offset=offset, dim1=axis1, dim2=axis2) - - def tril( - self, - x: DenseArray, - k: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Return the lower triangular part using PyTorch. - - Input: - x: Dense backend tensor; k: Diagonal offset. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tril.html - """ - return self.torch.tril(x, diagonal=k, out=out) - - def triu( - self, - x: DenseArray, - k: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Return the upper triangular part using PyTorch. - - Input: - x: Dense backend tensor; k: Diagonal offset. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.triu.html - """ - return self.torch.triu(x, diagonal=k, out=out) - def index_set(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): """ Set indexed tensor values using PyTorch. @@ -2106,21 +781,6 @@ def cond(self, pred: bool, true_fun: Callable[[T], R], false_fun: Callable[[T], """ return true_fun(*operands) if bool(pred) else false_fun(*operands) - def allclose(self, a: DenseArray, b: DenseArray, rtol: float = 1e-5, atol: float = 1e-8, equal_nan: bool = False) -> bool: - """ - Compare dense tensors elementwise within tolerances using PyTorch. - - Input: - a, b: Dense backend tensors; rtol, atol, equal_nan: Comparison controls. - - Output: - Boolean indicating whether tensors are close. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.allclose.html - """ - return bool(self.torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) - def allclose_sparse(self, a: SparseArray, b: SparseArray, rtol: float = 1e-5, atol: float = 1e-8) -> bool: """ Compare sparse tensors elementwise within tolerances using PyTorch. diff --git a/spacecore/functional/__init__.py b/spacecore/functional/__init__.py new file mode 100644 index 0000000..ef2022d --- /dev/null +++ b/spacecore/functional/__init__.py @@ -0,0 +1,17 @@ +"""Scalar-valued functionals and composition helpers.""" + +from ._base import Functional +from ._composed import ComposedFunctional, make_functional_composed +from ._linear import InnerProductFunctional, LinearFunctional, MatrixFreeLinearFunctional +from ._quadratic import LinOpQuadraticForm, QuadraticForm + +__all__ = [ + "ComposedFunctional", + "Functional", + "InnerProductFunctional", + "LinearFunctional", + "LinOpQuadraticForm", + "MatrixFreeLinearFunctional", + "QuadraticForm", + "make_functional_composed", +] diff --git a/spacecore/functional/_base.py b/spacecore/functional/_base.py new file mode 100644 index 0000000..9cf86ce --- /dev/null +++ b/spacecore/functional/_base.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from .._contextual import ContextBound, resolve_context_priority +from ..backend import Context +from ..space import Space + +if TYPE_CHECKING: + from ..linop import LinOp + + +Domain = TypeVar("Domain", bound=Space) + + +class Functional(ContextBound, Generic[Domain]): + r""" + Scalar-valued map on a space. + + ``Functional`` represents a map ``F : X -> K`` without assuming any storage + model. It mirrors the minimal ``LinOp`` contract: the domain is converted + into the resolved context, value checks follow ``ctx.enable_checks``, and + batched evaluation is implemented by a backend ``vmap`` fallback. + + Parameters + ---------- + dom : Space + Domain space ``X``. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom``. + + Attributes + ---------- + dom : Space + Domain space converted to ``ctx``. + ctx : Context + Resolved backend context. + """ + + def __init__(self, dom: Domain, ctx: Context | str | None = None) -> None: + ctx = resolve_context_priority(ctx, dom) + super().__init__(ctx) + self.dom = dom.convert(self.ctx) + self._enable_checks = self.ctx.enable_checks + + @property + def domain(self) -> Domain: + """Domain space of this scalar-valued map.""" + return self.dom + + @abstractmethod + def value(self, x: Any) -> Any: + """Evaluate this functional at an element of ``self.domain``.""" + + def __call__(self, x: Any) -> Any: + """Evaluate this functional at ``x``.""" + return self.value(x) + + def compose(self, A: "LinOp") -> "Functional": + """ + Return the pull-back ``self o A``. + + Parameters + ---------- + A : LinOp + Linear operator whose codomain matches this functional's domain. + + Returns + ------- + Functional + Functional on ``A.domain`` evaluating ``self.value(A.apply(x))``. + """ + from ._composed import make_functional_composed + + return make_functional_composed(self, A) + + def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + """Evaluate this functional independently over leading batch axes.""" + return self._fallback_vvalue(xs, batch_space) + + def _infer_batch_shape(self, space: Space, value: Any) -> tuple[int, ...]: + """Infer leading batch dimensions from a value and base space.""" + if hasattr(space, "spaces") and isinstance(value, tuple) and value: + return self._infer_batch_shape(space.spaces[0], value[0]) + shape = tuple(getattr(value, "shape", ())) + base_shape = tuple(space.shape) + if not base_shape: + return shape + if len(shape) < len(base_shape) or shape[-len(base_shape):] != base_shape: + raise ValueError( + f"Cannot infer leading batch shape for value shape {shape} " + f"and base space shape {base_shape}." + ) + return shape[: len(shape) - len(base_shape)] + + def _input_batch_space( + self, + space: Space, + value: Any, + batch_space: Space | None, + ) -> Space: + """Return the batch space used to validate batched inputs.""" + if batch_space is not None: + return batch_space + batch_shape = self._infer_batch_shape(space, value) + return space.batch(batch_shape, tuple(range(len(batch_shape)))) + + def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space: + """Return the batch space corresponding to a batched output.""" + batch_shape = getattr(input_batch_space, "batch_shape", None) + batch_axes = getattr(input_batch_space, "batch_axes", None) + if batch_shape is None or batch_axes is None: + raise TypeError("batch_space must be a BatchSpace-compatible object.") + return space.batch(tuple(batch_shape), tuple(batch_axes)) + + def _require_leading_batch_axes(self, batch_space: Space) -> tuple[int, ...]: + """Return batch shape or raise when batch axes are not leading.""" + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + expected_axes = tuple(range(len(batch_shape))) + if batch_axes != expected_axes: + raise ValueError( + "Functional batching currently expects leading batch axes; " + f"got batch_axes={batch_axes}, expected {expected_axes}." + ) + return batch_shape + + def _vmap_leading(self, fn: Any, batch_ndim: int) -> Any: + """Vectorize ``fn`` over ``batch_ndim`` leading axes.""" + mapped = fn + for _ in range(batch_ndim): + mapped = self.ops.vmap(mapped, in_axes=0, out_axes=0) + return mapped + + def _check_scalar_batch(self, values: Any, batch_shape: tuple[int, ...]) -> None: + """Raise if scalar batch output does not have ``batch_shape``.""" + shape = tuple(getattr(values, "shape", ())) + if shape != batch_shape: + raise ValueError( + f"Expected scalar batch output with shape {batch_shape}, got {shape}." + ) + + def _fallback_vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + """Evaluate this functional over a leading batch with backend ``vmap``.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + batch_shape = self._require_leading_batch_axes(in_space) + if self._enable_checks: + in_space._check_member(xs) + values = self._vmap_leading(self.value, len(batch_shape))(xs) + if self._enable_checks: + self._check_scalar_batch(values, batch_shape) + return values + + def assert_domain(self, x: Any) -> None: + """Raise if ``x`` is not in the domain.""" + self.dom.check_member(x) + + @abstractmethod + def tree_flatten(self): + """Flatten this functional for pytree registration.""" + ... + + @classmethod + @abstractmethod + def tree_unflatten(cls, aux, children): + """Rebuild this functional from pytree data.""" + ... diff --git a/spacecore/functional/_composed.py b/spacecore/functional/_composed.py new file mode 100644 index 0000000..8555b5a --- /dev/null +++ b/spacecore/functional/_composed.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import Any + +from ._base import Functional +from ._linear import InnerProductFunctional +from ._quadratic import LinOpQuadraticForm +from .._checks import checked_method +from ..backend import Context, jax_pytree_class +from ..linop import LinOp + + +def _require_composable(F: Functional, A: LinOp) -> None: + """Raise unless ``F`` can be composed with ``A``.""" + if not isinstance(F, Functional): + raise TypeError(f"F must be a Functional, got {type(F).__name__}.") + if not isinstance(A, LinOp): + raise TypeError(f"A must be a LinOp, got {type(A).__name__}.") + if A.codomain != F.domain: + raise ValueError( + "Functional composition requires A.codomain == F.domain; " + f"got {A.codomain!r} and {F.domain!r}." + ) + + +def make_functional_composed(F: Functional, A: LinOp) -> Functional: + """ + Return the pull-back ``F o A`` with local specializations. + + Parameters + ---------- + F : Functional + Functional defined on ``A.codomain``. + A : LinOp + Linear operator whose codomain is ``F.domain``. + + Returns + ------- + Functional + Specialized pull-back when available, otherwise + :class:`ComposedFunctional`. + """ + _require_composable(F, A) + if isinstance(F, InnerProductFunctional): + return InnerProductFunctional(A.H.apply(F.representer), A.domain, A.ctx) + if isinstance(F, LinOpQuadraticForm): + Q = A.H @ F.Q @ A + linear = None if F.linear is None else F.linear.compose(A) + return LinOpQuadraticForm(Q, linear, F.a, A.ctx) + return ComposedFunctional(F, A) + + +@jax_pytree_class +class ComposedFunctional(Functional): + """ + Generic pull-back of a functional through a linear operator. + + ``ComposedFunctional(F, A)`` represents ``x -> F(A x)`` on ``A.domain``. + + Parameters + ---------- + F : Functional + Functional defined on ``A.codomain``. + A : LinOp + Linear operator whose codomain is ``F.domain``. + """ + + def __init__(self, F: Functional, A: LinOp) -> None: + _require_composable(F, A) + super().__init__(A.domain, A.ctx) + self.F = F.convert(A.ctx) + self.A = A + + @checked_method(in_space="domain") + def value(self, x: Any) -> Any: + """ + Evaluate ``F(A x)``. + + Parameters + ---------- + x: + Element of ``A.domain``. + + Returns + ------- + Any + Scalar-like value returned by the composed functional. + """ + return self.F.value(self.A.apply(x)) + + def __eq__(self, other: Any) -> bool: + """Return whether another composed functional has the same operands.""" + if type(other) is type(self): + return self.F == other.F and self.A == other.A + return False + + def tree_flatten(self): + """Flatten this functional for pytree registration.""" + children = (self.F, self.A) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this functional from pytree data.""" + F, A = children + return cls(F, A) + + def _convert(self, new_ctx: Context) -> ComposedFunctional: + """Convert the composed functional and operator to ``new_ctx``.""" + return ComposedFunctional(self.F.convert(new_ctx), self.A.convert(new_ctx)) diff --git a/spacecore/functional/_linear.py b/spacecore/functional/_linear.py new file mode 100644 index 0000000..1dbc887 --- /dev/null +++ b/spacecore/functional/_linear.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Callable + +from ._base import Domain, Functional +from .._checks import checked_method +from ..backend import Context, jax_pytree_class +from ..space import Space + + +def _convert_space_element(space: Space, value: Any) -> Any: + """Convert a value recursively into a possibly product-valued space.""" + if hasattr(space, "spaces") and isinstance(value, tuple): + if len(value) != len(space.spaces): + raise ValueError( + f"Expected tuple of length {len(space.spaces)}, got {len(value)}." + ) + return tuple( + _convert_space_element(component_space, component) + for component_space, component in zip(space.spaces, value) + ) + return space.ctx.asarray(value) + + +class LinearFunctional(Functional[Domain]): + r""" + Represent a linear scalar-valued map. + + Parameters + ---------- + dom : Space + Domain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom``. + """ + + @property + @abstractmethod + def representer(self) -> Any: + """ + Riesz representer of this functional when one is explicitly available. + + Matrix-free functionals may not have a stored representer and should + raise ``NotImplementedError``. + """ + + +@jax_pytree_class +class InnerProductFunctional(LinearFunctional[Domain]): + r""" + Linear functional represented by a domain element. + + ``InnerProductFunctional(c, X)`` evaluates + :math:`\ell_c(x) = \langle c, x\rangle_X`. + + Parameters + ---------- + c : array-like + Riesz representer in ``dom``. + dom : Space + Domain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom``. + + Attributes + ---------- + representer : array-like + Stored domain element ``c``. + """ + + def __init__( + self, + c: Any, + dom: Domain, + ctx: Context | str | None = None, + ) -> None: + super().__init__(dom, ctx) + self._c = _convert_space_element(self.domain, c) + if self._enable_checks: + self.domain._check_member(self._c) + + @property + def representer(self) -> Any: + """Stored domain element ``c`` defining ``ell_c(x) = ``.""" + return self._c + + @checked_method(in_space="domain") + def value(self, x: Any) -> Any: + """Return ``domain.inner(representer, x)``.""" + return self.domain.inner(self._c, x) + + def __eq__(self, other: Any) -> bool: + """Return whether another inner-product functional has the same representer.""" + if type(other) is type(self): + return self.domain == other.domain and self.ops.allclose( + self.domain.flatten(self._c), + other.domain.flatten(other._c), + ) + return False + + def tree_flatten(self): + """Flatten this functional for pytree registration.""" + children = (self._c,) + aux = (self.domain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this functional from pytree data.""" + domain, ctx = aux + c = children[0] + return cls(c, domain, ctx) + + def _convert(self, new_ctx: Context) -> InnerProductFunctional: + """Convert the domain and representer to ``new_ctx``.""" + return InnerProductFunctional(self._c, self.domain.convert(new_ctx), new_ctx) + + +@jax_pytree_class +class MatrixFreeLinearFunctional(LinearFunctional[Domain]): + """ + Linear functional defined by user-supplied evaluation callables. + + ``MatrixFreeLinearFunctional(value, X)`` represents a linear scalar-valued + map on ``X`` without storing or materializing a Riesz representer. + + Parameters + ---------- + value : callable + Callable with signature ``value(x: Any) -> Any`` accepting an element of + ``dom`` and returning a scalar-like backend value. + dom : Space + Domain space of the functional. + ctx : Context, str, or None, optional + Optional context specification. An explicit context wins over inferred + and default contexts. + vvalue : callable or None, optional + Optional callable with signature ``vvalue(xs: Any) -> Any`` for batched + evaluation. If omitted, backend ``vmap`` fallback is used. + + Returns + ------- + MatrixFreeLinearFunctional + Functional using the supplied callable for scalar evaluation and, + optionally, batched scalar evaluation. + """ + + def __init__( + self, + value: Callable[[Any], Any], + dom: Domain, + ctx: Context | str | None = None, + vvalue: Callable[[Any], Any] | None = None, + ) -> None: + """ + Initialize a matrix-free linear functional. + + Parameters + ---------- + value: + Callable ``value(x)`` accepting an element of ``dom`` and returning + a scalar-like value. + dom: + Domain space of the functional. + ctx: + Optional context specification for the functional and converted + domain. + vvalue: + Optional callable ``vvalue(xs)`` accepting a batch of domain + elements and returning a batch of scalar-like values. + + Returns + ------- + None + The initializer stores the callables and converted domain on + ``self``. + """ + if not callable(value): + raise TypeError(f"value must be callable, got {type(value).__name__}.") + if vvalue is not None and not callable(vvalue): + raise TypeError(f"vvalue must be callable, got {type(vvalue).__name__}.") + super().__init__(dom, ctx) + self.value_fn = value + self.vvalue_fn = vvalue + + @property + def representer(self) -> Any: + """ + Raise because matrix-free functionals do not store a representer. + + Parameters + ---------- + None + + Returns + ------- + Any + This property never returns; it raises ``NotImplementedError``. + """ + raise NotImplementedError( + f"{type(self).__name__} does not store a Riesz representer." + ) + + @checked_method(in_space="domain") + def value(self, x: Any) -> Any: + """ + Evaluate the scalar functional. + + Parameters + ---------- + x: + Element of ``self.domain`` passed to ``value_fn``. + + Returns + ------- + Any + Scalar-like backend value returned by ``value_fn``. + """ + y = self.value_fn(x) + if self._enable_checks: + self._check_scalar_batch(y, ()) + return y + + def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + """ + Evaluate the scalar functional over a batch of domain elements. + + Parameters + ---------- + xs: + Batched element of ``self.domain``. + batch_space: + Optional batch-space descriptor for ``xs``. + + Returns + ------- + Any + Backend array of scalar-like values with shape matching the leading + batch shape. + """ + if self.vvalue_fn is None: + return super().vvalue(xs, batch_space) + in_space = self._input_batch_space(self.domain, xs, batch_space) + batch_shape = self._require_leading_batch_axes(in_space) + if self._enable_checks: + in_space._check_member(xs) + values = self.vvalue_fn(xs) + if self._enable_checks: + self._check_scalar_batch(values, batch_shape) + return values + + def __eq__(self, other: Any) -> bool: + """Return whether another matrix-free functional uses the same callables.""" + if type(other) is type(self): + return ( + self.domain == other.domain + and self.value_fn is other.value_fn + and self.vvalue_fn is other.vvalue_fn + ) + return False + + def tree_flatten(self): + """Flatten this functional for pytree registration.""" + children = () + aux = (self.value_fn, self.domain, self.ctx, self.vvalue_fn) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this functional from pytree data.""" + value_fn, domain, ctx, vvalue_fn = aux + return cls(value_fn, domain, ctx, vvalue_fn) + + def _convert(self, new_ctx: Context) -> MatrixFreeLinearFunctional: + """ + Convert this functional to ``new_ctx``. + + Parameters + ---------- + new_ctx: + Concrete target context for the converted domain. + + Returns + ------- + MatrixFreeLinearFunctional + Functional with converted domain and the same user-supplied + callables. + """ + return MatrixFreeLinearFunctional( + self.value_fn, + self.domain.convert(new_ctx), + new_ctx, + self.vvalue_fn, + ) diff --git a/spacecore/functional/_quadratic.py b/spacecore/functional/_quadratic.py new file mode 100644 index 0000000..2e0d889 --- /dev/null +++ b/spacecore/functional/_quadratic.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from typing import Any + +from ._base import Domain, Functional +from ._linear import LinearFunctional +from .._checks import checked_method +from .._contextual import resolve_context_priority +from ..backend import Context, jax_pytree_class +from ..linop import LinOp +from ..space import Space + + +class QuadraticForm(Functional[Domain]): + """ + Represent a scalar quadratic objective on a space. + + Parameters + ---------- + dom : Space + Domain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom``. + """ + + def hess_apply(self, x: Any) -> Any: + """Apply the Hessian action at ``x`` when available.""" + raise NotImplementedError(f"{type(self).__name__} does not define hess_apply.") + + def grad(self, x: Any) -> Any: + """Gradient at ``x`` when available.""" + raise NotImplementedError(f"{type(self).__name__} does not define grad.") + + def vgrad(self, xs: Any, batch_space: Space | None = None) -> Any: + """Evaluate ``grad`` independently over leading batch axes.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + batch_shape = self._require_leading_batch_axes(in_space) + if self._enable_checks: + in_space._check_member(xs) + grads = self._vmap_leading(self.grad, len(batch_shape))(xs) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(grads) + return grads + + +@jax_pytree_class +class LinOpQuadraticForm(QuadraticForm[Domain]): + r""" + Represent a quadratic form backed by a linear operator. + + Assumption: + Q is Hermitian/self-adjoint. Under this assumption, + grad f(x) = Q x. + + Non-Hermitian operators are not supported here. If users need the + Hermitian part, they must construct 0.5 * (Q + Q.H) explicitly. + + The full objective is ``q(x) = 1/2 * + linear(x) + a`` with + ``Q : X -> X``. Structurally available dense and diagonal operators are + checked at construction. Matrix-free operators are not validated; correctness + is the caller's responsibility. + + Parameters + ---------- + Q : LinOp + Hermitian operator from a space to itself. + linear : LinearFunctional or None, optional + Optional linear term on ``Q.domain``. + a : scalar-like, optional + Constant scalar offset. Default is 0. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``Q`` and + ``linear``. + + Attributes + ---------- + Q : LinOp + Stored Hermitian operator. + linear : LinearFunctional or None + Stored linear term. + a : scalar-like + Stored scalar offset. + """ + + def __init__( + self, + Q: LinOp[Domain, Domain], + linear: LinearFunctional[Domain] | None = None, + a: Any = 0, + ctx: Context | str | None = None, + ) -> None: + if not isinstance(Q, LinOp): + raise TypeError(f"Q must be a LinOp, got {type(Q).__name__}.") + if linear is not None and not isinstance(linear, LinearFunctional): + raise TypeError( + f"linear must be a LinearFunctional or None, got {type(linear).__name__}." + ) + + resolved_ctx = resolve_context_priority(ctx, Q.domain, Q, linear) + Q = Q.convert(resolved_ctx) + if Q.domain != Q.codomain: + raise ValueError("LinOpQuadraticForm requires Q.domain == Q.codomain.") + self._check_hermitian_structure(Q) + if linear is not None: + linear = linear.convert(resolved_ctx) + if linear.domain != Q.domain: + raise ValueError("linear.domain must match Q.domain.") + + super().__init__(Q.domain, resolved_ctx) + self.Q = Q + self.linear = linear + self.a = self.ctx.asarray(a) + self._check_scalar_batch(self.a, ()) + + @staticmethod + def _check_hermitian_structure(Q: LinOp[Domain, Domain]) -> None: + """Raise when ``Q`` is structurally known to be non-Hermitian.""" + result = Q.is_hermitian() + if result is False: + raise ValueError("LinOpQuadraticForm requires Q to be Hermitian/self-adjoint.") + + @checked_method(in_space="domain") + def value(self, x: Any) -> Any: + """Return ``1/2 * + linear(x) + a``.""" + qx = self.Q.apply(x) + value = 0.5 * self.domain.inner(x, qx) + if self.linear is not None: + value = value + self.linear.value(x) + return value + self.a + + @checked_method(in_space="domain", out_space="domain") + def grad(self, x: Any) -> Any: + """ + Return the Euclidean/Riesz gradient. + + ``LinOpQuadraticForm`` assumes ``Q`` is Hermitian/self-adjoint, so the + quadratic contribution is exactly ``Q.apply(x)``. + """ + grad = self.Q.apply(x) + if self.linear is not None: + grad = self.domain.add(grad, self.linear.representer) + return grad + + @checked_method(in_space="domain", out_space="domain") + def hess_apply(self, x: Any) -> Any: + """Return the Hessian action ``Q x`` under the Hermitian assumption.""" + return self.Q.apply(x) + + def __eq__(self, other: Any) -> bool: + """Return whether another quadratic form has the same stored terms.""" + if type(other) is type(self): + return ( + self.Q == other.Q + and self.linear == other.linear + and self.ops.allclose(self.a, other.a) + ) + return False + + def tree_flatten(self): + """Flatten this quadratic form for pytree registration.""" + children = (self.Q, self.linear, self.a) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this quadratic form from pytree data.""" + Q, linear, a = children + return cls(Q, linear, a, Q.ctx) + + def _convert(self, new_ctx: Context) -> LinOpQuadraticForm: + """Convert stored terms to ``new_ctx``.""" + linear = None if self.linear is None else self.linear.convert(new_ctx) + return LinOpQuadraticForm(self.Q.convert(new_ctx), linear, self.a, new_ctx) diff --git a/spacecore/linalg/__init__.py b/spacecore/linalg/__init__.py new file mode 100644 index 0000000..49a0509 --- /dev/null +++ b/spacecore/linalg/__init__.py @@ -0,0 +1,22 @@ +"""Iterative linear algebra solvers and Krylov algorithms.""" + +from __future__ import annotations + +from ._cg import CGResult, cg +from ._expm import ExpmMultiplyResult, expm_multiply +from ._lanczos import LanczosResult, lanczos_smallest +from ._lsqr import LSQRResult, lsqr +from ._power import PowerIterationResult, power_iteration + +__all__ = [ + "CGResult", + "ExpmMultiplyResult", + "LanczosResult", + "LSQRResult", + "PowerIterationResult", + "cg", + "expm_multiply", + "lanczos_smallest", + "lsqr", + "power_iteration", +] diff --git a/spacecore/linalg/_cg.py b/spacecore/linalg/_cg.py new file mode 100644 index 0000000..06dcf2f --- /dev/null +++ b/spacecore/linalg/_cg.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from typing import Any, NamedTuple + +from ..linop import LinOp +from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter +from ._utils import is_converged, real_inner, require_linop, require_square +from ._utils import result_repr, safe_inverse_nonneg, should_check_iteration, threshold + + +class CGResult(NamedTuple): + """ + Store the result returned by :func:`cg`. + + Parameters + ---------- + x : array-like + Approximate solution in ``A.domain``. + converged : bool-like + Whether the final residual norm satisfied the requested tolerance. + num_iters : int-like + Number of conjugate-gradient iterations executed. + residual_norm : scalar + Norm of the final residual in ``A.codomain``. + """ + + x: Any + converged: Any + num_iters: Any + residual_norm: Any + + def __repr__(self) -> str: + """Return a compact summary without printing the full solution array.""" + return result_repr( + "CGResult", + { + "converged": self.converged, + "num_iters": self.num_iters, + "residual_norm": self.residual_norm, + "x": self.x, + }, + ) + + +def cg( + A: LinOp, + b: Any, + *, + x0: Any | None = None, + tol: float = 1e-6, + atol: float = 0.0, + maxiter: int | None = None, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> CGResult: + r""" + Solve :math:`A x = b` by conjugate gradients. + + Require ``A`` to be square in the SpaceCore sense + (``A.domain == A.codomain``), Hermitian, and positive-definite with respect + to ``A.domain.inner``. The implementation uses only :meth:`LinOp.apply` and + the domain-space inner product; it never materializes a dense matrix. + + Parameters + ---------- + A : LinOp + Linear operator that must be Hermitian positive-definite with respect + to ``A.domain.inner``. ``A.domain`` must equal ``A.codomain``, + including the underlying space type and inner-product geometry. + Hermiticity and positive-definiteness are not validated by ``cg``; + indefinite or non-Hermitian operators can diverge or produce NaN + outputs without an explicit error. + b : array-like + Right-hand side in ``A.codomain``. + x0 : array-like or None, optional + Initial guess in ``A.domain``. Default is the zero vector. + tol : float, optional + Relative tolerance on the linear-system residual. ``result.converged`` + is ``True`` when the residual norm is below + ``atol + tol * norm(b)``. Default is 1e-6. + atol : float, optional + Absolute residual tolerance. Default is 0.0. + maxiter : int or None, optional + Maximum number of iterations. Default is ``prod(A.domain.shape)``. + check_every : int, optional + Refresh convergence diagnostics every this many iterations and always + on the final iteration. Default is + ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``. + + Returns + ------- + CGResult + Named tuple with fields: + + - ``x``: approximate solution in ``A.domain`` + - ``converged``: whether the requested tolerance was met + - ``num_iters``: number of iterations executed + - ``residual_norm``: final residual norm + + Raises + ------ + TypeError + If ``A`` is not a :class:`LinOp`. + ValueError + If ``A`` is not square or if iteration parameters are invalid. + + See Also + -------- + lsqr : Solve least-squares systems for rectangular operators. + lanczos_smallest : Approximate the smallest eigenpair of a Hermitian + operator. + + Notes + ----- + The residual norm is compared with + :math:`\text{atol} + \text{tol} \| b \|` only every ``check_every`` + iterations, and always on the final iteration. This keeps convergence + checks out of the hot loop while remaining compatible with JAX JIT control + flow. ``maxiter`` and ``check_every`` should be treated as static JAX + arguments. + + For complex operators, residual norms and step sizes are computed from the + real part of ``A.domain.inner(x, y)``. SpaceCore's complex inner-product + convention conjugates the first argument; custom :class:`Space` subclasses + must follow that convention for CG to converge correctly. + + References + ---------- + Hestenes, M. R. and Stiefel, E., "Methods of Conjugate Gradients + for Solving Linear Systems," J. Res. Natl. Bur. Stand., 49 (1952), + 409-436. + + Examples + -------- + Solve a small positive-definite system. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((3,), ctx) + >>> M = ctx.asarray([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]]) + >>> A = sc.DenseLinOp(M, X, X, ctx) + >>> b = ctx.asarray([1.0, 2.0, 3.0]) + >>> result = sc.cg(A, b, tol=1e-10) + >>> bool(result.converged) + True + >>> np.allclose(A.apply(result.x), b) + True + """ + A = require_linop(A) + require_square(A, "cg") + A.codomain.check_member(b) + maxiter = check_maxiter(maxiter, A) + check_every = check_interval(check_every) + + x = A.domain.zeros() if x0 is None else x0 + A.domain.check_member(x) + r = A.codomain.add(b, A.codomain.scale(-1.0, A.apply(x))) + p = r + rs = real_inner(A.domain, r, r) + residual_norm = A.domain.norm(r) + threshold_value = threshold(A.codomain.norm(b), tol, atol) + eps = A.ops.asarray(A.ops.eps(A.dtype), dtype=A.dtype) + + def cond_fun(carry: tuple[Any, Any, Any, Any, Any, int]) -> Any: + _x, _r, _p, _rs, res_norm, k = carry + return (k < maxiter) & (res_norm > threshold_value) + + def body_fun(carry: tuple[Any, Any, Any, Any, Any, int]) -> tuple[Any, Any, Any, Any, Any, int]: + x, r, p, rs, _residual_norm, k = carry + Ap = A.apply(p) + pAp = real_inner(A.domain, p, Ap) + active = (rs > eps) & (pAp > eps) + alpha = A.ops.where(active, rs * safe_inverse_nonneg(A.ops, pAp), A.ops.zeros_like(rs)) + x_next = A.domain.axpy(alpha, p, x) + r_next = A.codomain.axpy(-alpha, Ap, r) + rs_next = real_inner(A.domain, r_next, r_next) + beta = A.ops.where(active, rs_next * safe_inverse_nonneg(A.ops, rs), A.ops.zeros_like(rs_next)) + p_next = A.domain.axpy(beta, p, r_next) + k_next = k + 1 + current_residual_norm = A.domain.norm(r_next) + residual_norm_next = A.ops.cond( + should_check_iteration(k_next, maxiter, check_every), + lambda _: current_residual_norm, + lambda _: _residual_norm, + A.ops.asarray(0.0, dtype=A.dtype), + ) + return x_next, r_next, p_next, rs_next, residual_norm_next, k_next + + x, _r, _p, _rs, residual_norm, num_iters = A.ops.while_loop( + cond_fun, + body_fun, + (x, r, p, rs, residual_norm, 0), + ) + return CGResult(x, is_converged(residual_norm, threshold_value), num_iters, residual_norm) diff --git a/spacecore/linalg/_expm.py b/spacecore/linalg/_expm.py new file mode 100644 index 0000000..5b7d87a --- /dev/null +++ b/spacecore/linalg/_expm.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from typing import Any, NamedTuple + +from ..linop import LinOp +from ._lanczos import _check_lanczos_max_iter, _lanczos_basis_and_tridiag +from ._utils import require_linop, require_square, result_repr + + +class ExpmMultiplyResult(NamedTuple): + """ + Store the result returned by :func:`expm_multiply`. + + Parameters + ---------- + result : array-like + Vector in the domain of the input operator approximating + ``exp(t * A) @ v``. + krylov_dim : int-like + Actual Krylov dimension reached before breakdown or ``max_iter``. + residual_estimate : scalar + Projected exponential residual estimate + ``abs(beta[m] * phi[m - 1])``. + converged : bool-like + Boolean indicating whether ``residual_estimate < tol``. + """ + + result: Any + krylov_dim: Any + residual_estimate: Any + converged: Any + + def __repr__(self) -> str: + """Return a compact summary without printing the full vector.""" + return result_repr( + "ExpmMultiplyResult", + { + "converged": self.converged, + "krylov_dim": self.krylov_dim, + "residual_estimate": self.residual_estimate, + "result": self.result, + }, + ) + + +def expm_multiply( + A: LinOp, + v: Any, + t: float | complex = 1.0, + *, + max_iter: int = 30, + tol: float = 1e-10, +) -> ExpmMultiplyResult: + r""" + Compute :math:`\exp(t A) v` by Krylov projection. + + Require ``A`` to be square in the SpaceCore sense + (``A.domain == A.codomain``) and Hermitian with respect to + ``A.domain.inner``. The method builds a Lanczos basis and applies the + exponential of the small tridiagonal projection, avoiding dense + materialization of ``A``. + + Parameters + ---------- + A : LinOp + Linear operator that must be Hermitian/self-adjoint with respect to + ``A.domain.inner``. ``A.domain`` must equal ``A.codomain``, including + the underlying space type and inner-product geometry. Operators with + structurally unknown Hermiticity (``A.is_hermitian()`` returns + ``None``) are accepted on trust; the caller is responsible for ensuring + Hermiticity. Non-Hermitian inputs produce undefined results. + v : array-like + Initial vector in ``A.domain``. + t : float or complex, optional + Scalar multiplier on ``A``. Complex values require a complex-valued + ``ctx.dtype`` such as ``complex64`` or ``complex128``. Using a complex + ``t`` with a real-valued context produces backend-dependent results. + Default is 1.0. + max_iter : int, optional + Maximum Krylov dimension. Values around 20-50 are usually sufficient + when :math:`|t|\|A\|` is moderate. Must be a Python ``int`` rather + than a traced JAX scalar; under ``jax.jit`` it is treated as a static + argument and changing it triggers retracing. Default is 30. + tol : float, optional + Tolerance used both for Lanczos breakdown and for the convergence flag: + ``result.converged`` is ``True`` when the projected exponential + residual estimate is below ``tol``. Default is 1e-10. + + Returns + ------- + ExpmMultiplyResult + Result vector in ``A.domain``, the Krylov dimension used, the standard + estimate ``abs(beta[m] * phi[m - 1])``, and a convergence flag. + + Raises + ------ + TypeError + If ``A`` is not a :class:`LinOp`. + ValueError + If ``A`` is not square, is known to be non-Hermitian, or if + ``max_iter`` is invalid. + + See Also + -------- + lanczos_smallest : Build the related Hermitian Krylov projection. + power_iteration : Estimate a dominant eigenpair. + + Notes + ----- + The projected exponential is computed as + :math:`\exp(t T) e_0` using an eigendecomposition of the small real + symmetric tridiagonal matrix ``T``. This is JIT-compatible on the JAX + backend when ``max_iter`` is static. + + Hermiticity is enforced only when it can be structurally verified: known + non-Hermitian operators raise ``ValueError``. Operators with unknown + structure, such as many matrix-free operators and operators on custom + spaces, are trusted. + + The returned residual estimate is + :math:`|\beta_m \phi_{m-1}|`, where ``phi`` is the projected exponential + vector. Callers that need the true residual can perform one additional + operator application. + + Examples + -------- + Apply the exponential of a diagonal operator. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> A = sc.DiagonalLinOp(ctx.asarray([0.0, 1.0]), X, ctx) + >>> v = ctx.asarray([2.0, 3.0]) + >>> result = sc.expm_multiply(A, v, t=0.5, max_iter=5) + >>> np.allclose(result.result, [2.0, 3.0 * np.exp(0.5)], atol=1e-10) + True + """ + A = require_linop(A) + require_square(A, "expm_multiply") + if A.is_hermitian() is False: + raise ValueError("expm_multiply requires A to be Hermitian/self-adjoint.") + max_iter = _check_lanczos_max_iter(max_iter) + A.domain.check_member(v) + + ops = A.ops + ctx = A.ctx + real_dtype = ops.real_dtype(ctx.dtype) + basis = _lanczos_basis_and_tridiag(A, v, max_iter, tol, real_dtype, check_every=1) + + eigvals, eigvecs = ops.eigh(basis.T) + exp_eigs = ops.exp(t * eigvals) + expT_e1 = eigvecs @ (exp_eigs * eigvecs[0, :]) + + V_reduced = basis.V[:max_iter, :] + result_flat = basis.initial_norm * ops.einsum("j,jn->n", expT_e1, V_reduced) + result = A.domain.unflatten(result_flat) + + last_coeff = ops.abs(expT_e1[basis.krylov_dim - 1]) + residual_estimate = basis.betas[basis.krylov_dim] * last_coeff + converged = residual_estimate < basis.tol + + return ExpmMultiplyResult(result, basis.krylov_dim, residual_estimate, converged) diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py new file mode 100644 index 0000000..b1e1d13 --- /dev/null +++ b/spacecore/linalg/_lanczos.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +from typing import Any, NamedTuple + + +from ..linop import LinOp +from ..space import VectorSpace +from ..types import DenseArray +from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval +from ._utils import require_linop, require_square, safe_inverse_nonneg, should_check_iteration +from ._utils import result_repr + + +class LanczosResult(NamedTuple): + """ + Store the result returned by :func:`lanczos_smallest`. + + Parameters + ---------- + eigenvalue : scalar + Ritz approximation to the smallest eigenvalue. + eigenvector : array-like + Ritz vector in ``A.domain``. + residual_norm : scalar + Standard Ritz residual estimate. + krylov_dim : int-like + Krylov dimension reached before breakdown or ``max_iter``. + converged : bool-like + Whether ``residual_norm < tol``. + """ + + eigenvalue: Any + eigenvector: Any + residual_norm: Any + krylov_dim: Any + converged: Any + + def __repr__(self) -> str: + """Return a compact summary without printing the full eigenvector.""" + return result_repr( + "LanczosResult", + { + "eigenvalue": self.eigenvalue, + "eigenvector": self.eigenvector, + "residual_norm": self.residual_norm, + "krylov_dim": self.krylov_dim, + "converged": self.converged, + }, + ) + + +class _LanczosBasisResult(NamedTuple): + """Store fixed-size Lanczos basis data and tridiagonal projection.""" + + V: DenseArray + T: DenseArray + alphas: DenseArray + betas: DenseArray + krylov_dim: Any + initial_norm: Any + tol: Any + e0_unit: DenseArray + + +def _check_lanczos_max_iter(max_iter: int) -> int: + """Validate and normalize the maximum Lanczos iteration count.""" + max_iter = int(max_iter) + if max_iter < 1: + raise ValueError("max_iter must be positive.") + return max_iter + + +def _build_tridiagonal( + ops: Any, + alphas: DenseArray, + betas: DenseArray, + max_iter: int, + m: Any, + real_dtype: Any, +) -> DenseArray: + """Build the fixed-size tridiagonal Lanczos projection.""" + idx = ops.arange(max_iter) + full_indices = ops.arange(max_iter + 1) + mask_alpha = idx < m + inactive_sentinel = ( + ops.max(ops.abs(alphas)) + + 2.0 * ops.max(ops.abs(betas)) + + ops.asarray(1.0, dtype=real_dtype) + ) + alphas_full = ops.where(mask_alpha, alphas, inactive_sentinel) + betas_full = ops.where(full_indices == m, ops.asarray(0.0, dtype=real_dtype), betas) + + T = ops.zeros((max_iter, max_iter), dtype=real_dtype) + + def fill_diag(ii: int, T_in: DenseArray) -> DenseArray: + return ops.index_set(T_in, (ii, ii), alphas_full[ii], copy=True) + + T = ops.fori_loop(0, max_iter, fill_diag, T) + + def fill_off(ii: int, T_in: DenseArray) -> DenseArray: + b = betas_full[ii + 1] + T_in = ops.index_set(T_in, (ii, ii + 1), b, copy=True) + T_in = ops.index_set(T_in, (ii + 1, ii), b, copy=True) + return T_in + + return ops.fori_loop(0, max_iter - 1, fill_off, T) + + +def _lanczos_basis_and_tridiag( + A: LinOp, + initial_vector: Any, + max_iter: int, + tol: float, + real_dtype: Any, + check_every: int, +) -> _LanczosBasisResult: + """Build a Lanczos basis and tridiagonal projection.""" + ops = A.ops + ctx = A.ctx + use_euclidean_reorth = type(A.domain) is VectorSpace + + v0 = A.domain.flatten(initial_vector) + v0 = ctx.assert_dense(v0) + n = v0.shape[0] + + V = ops.zeros((max_iter + 1, n), dtype=ctx.dtype) + alphas = ops.zeros((max_iter,), dtype=real_dtype) + betas = ops.zeros((max_iter + 1,), dtype=real_dtype) + + tol_s = ops.asarray(tol, dtype=real_dtype) + eps_s = ops.asarray(1e-12, dtype=real_dtype) + + v0_norm = A.domain.norm(initial_vector) + + e0 = ops.zeros((n,), dtype=ctx.dtype) + e0 = ops.index_set(e0, (0,), ctx.asarray(1.0), copy=True) + e0_member = A.domain.unflatten(e0) + e0_norm = A.domain.norm(e0_member) + e0_unit = A.domain.flatten(A.domain.scale(safe_inverse_nonneg(ops, e0_norm), e0_member)) + + v0_unit = ops.cond( + v0_norm > eps_s, + lambda _: A.domain.flatten( + A.domain.scale(safe_inverse_nonneg(ops, v0_norm), initial_vector) + ), + lambda _: e0_unit, + ops.asarray(0.0, dtype=real_dtype), + ) + V = ops.index_set(V, (0, slice(None)), v0_unit, copy=True) + + beta0 = ops.asarray(1.0, dtype=real_dtype) + i0 = 0 + keep_going0 = ops.asarray(True) + + full_indices = ops.arange(max_iter + 1) + coeffs_zero = ops.zeros((max_iter + 1,), dtype=ctx.dtype) + + def cond_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> Any: + i, _V, _alphas, _betas, _beta, keep_going = state + return (i < max_iter) & keep_going + + def body_fun(state: tuple[Any, Any, Any, Any, Any, Any]) -> tuple[Any, Any, Any, Any, Any, Any]: + i, V_, alphas_, betas_, beta, keep_going = state + + v_i = V_[i] + v_i_member = A.domain.unflatten(v_i) + w_member = A.apply(v_i_member) + w = A.codomain.flatten(w_member) + w = ctx.assert_dense(w) + + alpha = ops.real(A.domain.inner(v_i_member, w_member)) + alphas_ = ops.index_set(alphas_, (i,), alpha, copy=True) + + w = ops.cond( + i == 0, + lambda w_in: w_in - alpha * v_i, + lambda w_in: w_in - alpha * v_i - betas_[i] * V_[i - 1], + w, + ) + + w_member = A.domain.unflatten(w) + valid = full_indices < (i + 1) + mask = ops.where( + valid, + ops.asarray(1.0, dtype=real_dtype), + ops.asarray(0.0, dtype=real_dtype), + ) + mask = ops.astype(mask, ctx.dtype) + + if use_euclidean_reorth: + coeffs_full = ops.einsum("jn,n->j", ops.conj(V_), w) + else: + coeffs_full = coeffs_zero + + def fill_coeff(j: int, coeffs_in: DenseArray) -> DenseArray: + v_j_member = A.domain.unflatten(V_[j]) + coeff = A.domain.inner(v_j_member, w_member) + return ops.index_set(coeffs_in, (j,), coeff, copy=True) + + coeffs_full = ops.fori_loop(0, max_iter + 1, fill_coeff, coeffs_full) + coeffs_valid = coeffs_full * mask + proj = ops.sum(coeffs_valid[:, None] * V_, axis=0) + w = w - proj + + w_member = A.domain.unflatten(w) + beta_new = A.domain.norm(w_member) + betas_ = ops.index_set(betas_, (i + 1,), beta_new, copy=True) + + def set_next(V_in: DenseArray) -> DenseArray: + w_unit = A.domain.flatten(A.domain.scale(safe_inverse_nonneg(ops, beta_new), w_member)) + return ops.index_set(V_in, (i + 1, slice(None)), w_unit, copy=True) + + V_ = ops.cond(beta_new >= tol_s, set_next, lambda V_in: V_in, V_) + i_next = i + 1 + keep_going_next = ops.cond( + should_check_iteration(i_next, max_iter, check_every), + lambda _: beta_new >= tol_s, + lambda _: keep_going, + ops.asarray(0.0, dtype=real_dtype), + ) + + return i_next, V_, alphas_, betas_, beta_new, keep_going_next + + i_final, V, alphas, betas, _beta_final, _keep_going = ops.while_loop( + cond_fun, body_fun, (i0, V, alphas, betas, beta0, keep_going0) + ) + T = _build_tridiagonal(ops, alphas, betas, max_iter, i_final, real_dtype) + return _LanczosBasisResult(V, T, alphas, betas, i_final, v0_norm, tol_s, e0_unit) + + +def lanczos_smallest( + A: LinOp, + initial_vector: Any, + *, + max_iter: int = 100, + tol: float = 1e-6, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> LanczosResult: + r""" + Approximate the smallest eigenpair of a Hermitian operator. + + The operator is supplied as a square ``LinOp`` in the SpaceCore sense + (``A.domain == A.codomain``), and ``initial_vector`` is an element of + ``A.domain``. The implementation keeps fixed-size coordinate arrays for JAX + compatibility, safely handles zero initial vectors, and refines the + returned eigenvalue with the Rayleigh quotient of the reconstructed Ritz + vector in the original space. + + Mathematically, Lanczos builds an orthonormal Krylov basis ``V`` for + ``span{v, A v, A^2 v, ...}`` and a tridiagonal projection + :math:`T_k = V^* A V`. The returned vector is the Ritz vector reconstructed + in the original coordinates, and the returned scalar is the Rayleigh + quotient :math:`\langle x, A x \rangle_X / \langle x, x \rangle_X`. + + Parameters + ---------- + A : LinOp + Linear operator that must be Hermitian/self-adjoint with respect to + ``A.domain.inner``. ``A.domain`` must equal ``A.codomain``, including + the underlying space type and inner-product geometry. Operators with + structurally unknown Hermiticity (``A.is_hermitian()`` returns + ``None``) are accepted on trust; the caller is responsible for ensuring + Hermiticity. Non-Hermitian inputs produce undefined results. + initial_vector : array-like + Starting vector in ``A.domain``. If it is numerically zero, the + algorithm falls back to a deterministic coordinate vector. + max_iter : int, optional + Maximum Krylov dimension. Must be a Python ``int`` rather than a + traced JAX scalar; under ``jax.jit`` it is treated as a static argument + and changing it triggers retracing. Default is 100. + tol : float, optional + Tolerance used for two purposes. Iteration stops at a check point when + the off-diagonal Lanczos coefficient falls below ``tol``; the returned + ``converged`` flag is ``True`` when the Ritz residual estimate is below + ``tol``. Default is 1e-6. + check_every : int, optional + Refresh the breakdown-based stopping decision every this many + iterations and always on the final iteration. Default is + ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``. + + Returns + ------- + LanczosResult + Named tuple with fields: + + - ``eigenvalue``: smallest Ritz eigenvalue estimate + - ``eigenvector``: associated Ritz vector in ``A.domain`` + - ``residual_norm``: standard Ritz residual estimate + - ``krylov_dim``: actual Krylov dimension reached + - ``converged``: whether ``residual_norm < tol`` + + Raises + ------ + TypeError + If ``A`` is not a :class:`LinOp`. + ValueError + If ``A`` is not square, is known to be non-Hermitian, or if + ``max_iter`` is invalid. + + See Also + -------- + power_iteration : Estimate the dominant eigenpair. + expm_multiply : Apply a matrix exponential using the Lanczos basis. + + Notes + ----- + The residual estimate is computed from the tridiagonal recurrence as + :math:`\beta_m |y_{m-1}|`. Callers that need the true residual can evaluate + ``A.apply(eigenvector) - eigenvalue * eigenvector`` once more in the + original space. + + The "smallest Ritz value" is the smallest eigenvalue of the projected + tridiagonal matrix, not necessarily a good approximation of the smallest + eigenvalue of ``A``. Convergence to the actual smallest eigenvalue requires + the bottom of the spectrum to be separated and the initial vector to have + nonzero projection onto the corresponding eigenspace. For clustered low + eigenvalues, increase ``max_iter`` or use multiple initial vectors. + + Hermiticity is enforced only when it can be structurally verified: known + non-Hermitian operators raise ``ValueError``. Operators with unknown + structure, such as many matrix-free operators and operators on custom + spaces, are trusted. + + This function is JIT-compatible on the JAX backend when ``max_iter`` and + ``check_every`` are static arguments. For plain :class:`VectorSpace` + domains, Euclidean reorthogonalization is vectorized; custom spaces use + :meth:`Space.inner` to preserve the declared geometry. + + References + ---------- + Lanczos, C., "An Iteration Method for the Solution of the Eigenvalue + Problem of Linear Differential and Integral Operators," J. Res. Natl. + Bur. Stand., 45 (1950), 255-282. + + Examples + -------- + Approximate the smallest eigenpair of a diagonal operator. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((3,), ctx) + >>> A = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 4.0]), X, ctx) + >>> result = sc.lanczos_smallest(A, ctx.asarray([1.0, 1.0, 1.0]), max_iter=3) + >>> np.allclose(result.eigenvalue, 1.0) + True + """ + A = require_linop(A) + require_square(A, "lanczos_smallest") + if A.is_hermitian() is False: + raise ValueError("lanczos_smallest requires A to be Hermitian/self-adjoint.") + max_iter = _check_lanczos_max_iter(max_iter) + check_every = check_interval(check_every) + A.domain.check_member(initial_vector) + ops = A.ops + ctx = A.ctx + real_dtype = ops.real_dtype(ctx.dtype) + idx = ops.arange(max_iter) + basis = _lanczos_basis_and_tridiag( + A, initial_vector, max_iter, tol, real_dtype, check_every + ) + + m = basis.krylov_dim + _eigvals, eigvecs = ops.eigh(basis.T) + y_full = eigvecs[:, 0] + residual_norm = basis.betas[m] * ops.abs(y_full[m - 1]) + converged = residual_norm < basis.tol + + mask_y = ops.where( + idx < m, + ops.asarray(1.0, dtype=real_dtype), + ops.asarray(0.0, dtype=real_dtype), + ) + mask_y = ops.astype(mask_y, y_full.dtype) + y_valid = y_full * mask_y + + V_reduced = basis.V[:max_iter, :] + x_flat = ops.einsum("j,jn->n", y_valid, V_reduced) + + x_member = A.domain.unflatten(x_flat) + x_norm = A.domain.norm(x_member) + x_flat = ops.cond( + x_norm > ops.asarray(1e-12, dtype=real_dtype), + lambda _: A.domain.flatten(A.domain.scale(safe_inverse_nonneg(ops, x_norm), x_member)), + lambda _: basis.e0_unit, + ops.asarray(0.0, dtype=real_dtype), + ) + + x = A.domain.unflatten(x_flat) + Ax = A.apply(x) + + num = ops.real(A.domain.inner(x, Ax)) + den = ops.real(A.domain.inner(x, x)) + lam = num / den + + return LanczosResult(lam, x, residual_norm, m, converged) diff --git a/spacecore/linalg/_lsqr.py b/spacecore/linalg/_lsqr.py new file mode 100644 index 0000000..7e975dc --- /dev/null +++ b/spacecore/linalg/_lsqr.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +from typing import Any, NamedTuple + +from ..linop import LinOp +from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter +from ._utils import is_converged, require_linop, safe_inverse_nonneg, should_check_iteration +from ._utils import result_repr, threshold + + +class LSQRResult(NamedTuple): + """ + Store the result returned by :func:`lsqr`. + + Parameters + ---------- + x : array-like + Approximate least-squares solution in ``A.domain``. + converged : bool-like + Whether the normal-equation residual satisfied the requested tolerance. + num_iters : int-like + Number of LSQR iterations executed. + residual_norm : scalar + Norm of ``A x - b`` in ``A.codomain``. + normal_residual_norm : scalar + Norm of ``A.H @ (A x - b)`` in ``A.domain``. + """ + + x: Any + converged: Any + num_iters: Any + residual_norm: Any + normal_residual_norm: Any + + def __repr__(self) -> str: + """Return a compact summary without printing the full solution array.""" + return result_repr( + "LSQRResult", + { + "converged": self.converged, + "num_iters": self.num_iters, + "residual_norm": self.residual_norm, + "normal_residual_norm": self.normal_residual_norm, + "x": self.x, + }, + ) + + +def lsqr( + A: LinOp, + b: Any, + *, + x0: Any | None = None, + tol: float = 1e-6, + atol: float = 0.0, + maxiter: int | None = None, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> LSQRResult: + r""" + Solve :math:`\min_x \|A x - b\|` by LSQR. + + Allow ``A`` to map between distinct ``domain`` and ``codomain`` spaces. + The method uses :meth:`LinOp.apply` for forward products and ``A.H.apply`` + for adjoint products, so the normal equations are represented implicitly + and no dense matrix is formed. + + Parameters + ---------- + A : LinOp + Linear operator with possibly distinct ``domain`` and ``codomain``. + For square ``A`` (``A.domain == A.codomain``), :func:`cg` is usually + preferred when ``A`` is also Hermitian positive-definite. + b : array-like + Right-hand side in ``A.codomain``. + x0 : array-like or None, optional + Initial guess in ``A.domain``. Default is the zero vector. + tol : float, optional + Relative tolerance for the normal-equation residual + ``norm(A.H @ (A @ x - b))``. ``result.converged`` is ``True`` when that + residual is below ``atol + tol * norm(b)``. Default is 1e-6. + atol : float, optional + Absolute tolerance for the normal-equation residual. Default is 0.0. + maxiter : int or None, optional + Maximum number of iterations. Default is ``prod(A.domain.shape)``. + check_every : int, optional + Refresh residual diagnostics every this many iterations and always on + the final iteration. Default is + ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``. + + Returns + ------- + LSQRResult + Named tuple with fields: + + - ``x``: approximate least-squares solution in ``A.domain`` + - ``converged``: whether the requested tolerance was met + - ``num_iters``: number of iterations executed + - ``residual_norm``: final residual norm + - ``normal_residual_norm``: final normal-equation residual norm + + Raises + ------ + TypeError + If ``A`` is not a :class:`LinOp`. + ValueError + If iteration parameters are invalid. + + See Also + -------- + cg : Solve square Hermitian positive-definite systems. + power_iteration : Estimate a dominant eigenpair. + + Notes + ----- + Convergence is tested using + :math:`\|A^*(A x - b)\| < \text{atol} + \text{tol}\|b\|`. + The normal-equation residual is refreshed only every ``check_every`` + iterations, and always on the final iteration. This function is + JIT-compatible on the JAX backend when ``maxiter`` and ``check_every`` are + static arguments. + + The normal-equation residual can be much smaller than the solution error + for ill-conditioned ``A``. For ill-conditioned problems, use a tighter + ``tol`` or check the residual and solution quality directly. + + Works on real and complex operators. For complex operators, ``A.H`` uses + the conjugate adjoint. + + References + ---------- + Paige, C. C. and Saunders, M. A., "LSQR: An Algorithm for Sparse + Linear Equations and Sparse Least Squares," ACM Trans. Math. Soft., + 8 (1982), 43-71. + + Examples + -------- + Solve a small overdetermined least-squares problem. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> Y = sc.VectorSpace((3,), ctx) + >>> M = ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + >>> A = sc.DenseLinOp(M, X, Y, ctx) + >>> b = ctx.asarray([1.0, 2.0, 3.0]) + >>> result = sc.lsqr(A, b, tol=1e-10) + >>> np.allclose(result.x, [1.0, 2.0]) + True + """ + A = require_linop(A) + A.codomain.check_member(b) + maxiter = check_maxiter(maxiter, A) + check_every = check_interval(check_every) + + x = A.domain.zeros() if x0 is None else x0 + A.domain.check_member(x) + residual = A.codomain.add(b, A.codomain.scale(-1.0, A.apply(x))) + beta = A.codomain.norm(residual) + normal_residual_norm = A.domain.norm(A.H.apply(residual)) + u = residual + u = A.codomain.scale(safe_inverse_nonneg(A.ops, beta), u) + v = A.H.apply(u) + alpha = A.domain.norm(v) + v = A.domain.scale(safe_inverse_nonneg(A.ops, alpha), v) + w = v + phi_bar = beta + rho_bar = alpha + residual_norm = beta + threshold_value = threshold(A.codomain.norm(b), tol, atol) + + def cond_fun(carry: tuple[Any, ...]) -> Any: + _x, _u, _v, _w, _alpha, _beta, _rho_bar, _phi_bar, _res_norm, norm_res, k = carry + return (k < maxiter) & (norm_res > threshold_value) + + def body_fun(carry: tuple[Any, ...]) -> tuple[Any, ...]: + x, u, v, w, alpha, _beta, rho_bar, phi_bar, _residual_norm, _normal_residual, k = carry + u_next = A.codomain.axpy(-alpha, u, A.apply(v)) + beta_next = A.codomain.norm(u_next) + u_next = A.codomain.scale(safe_inverse_nonneg(A.ops, beta_next), u_next) + + v_next = A.domain.axpy(-beta_next, v, A.H.apply(u_next)) + alpha_next = A.domain.norm(v_next) + v_next = A.domain.scale(safe_inverse_nonneg(A.ops, alpha_next), v_next) + + rho = A.ops.sqrt(rho_bar * rho_bar + beta_next * beta_next) + inv_rho = safe_inverse_nonneg(A.ops, rho) + c = rho_bar * inv_rho + s = beta_next * inv_rho + theta = s * alpha_next + rho_bar_next = -c * alpha_next + phi = c * phi_bar + phi_bar_next = s * phi_bar + + x_next = A.domain.axpy(phi * inv_rho, w, x) + w_next = A.domain.axpy(-(theta * inv_rho), w, v_next) + k_next = k + 1 + + def refresh_residuals(payload: tuple[Any, Any, Any]) -> tuple[Any, Any]: + x_candidate, _old_residual_norm, _old_normal_residual = payload + residual_next = A.codomain.add(A.apply(x_candidate), A.codomain.scale(-1.0, b)) + return A.codomain.norm(residual_next), A.domain.norm(A.H.apply(residual_next)) + + def keep_residuals(payload: tuple[Any, Any, Any]) -> tuple[Any, Any]: + _x_candidate, old_residual_norm, old_normal_residual = payload + return old_residual_norm, old_normal_residual + + residual_norm_next, normal_residual_norm_next = A.ops.cond( + should_check_iteration(k_next, maxiter, check_every), + refresh_residuals, + keep_residuals, + (x_next, _residual_norm, _normal_residual), + ) + return ( + x_next, + u_next, + v_next, + w_next, + alpha_next, + beta_next, + rho_bar_next, + phi_bar_next, + residual_norm_next, + normal_residual_norm_next, + k_next, + ) + + x, *_rest, residual_norm, normal_residual_norm, num_iters = A.ops.while_loop( + cond_fun, + body_fun, + (x, u, v, w, alpha, beta, rho_bar, phi_bar, residual_norm, normal_residual_norm, 0), + ) + return LSQRResult( + x, + is_converged(normal_residual_norm, threshold_value), + num_iters, + residual_norm, + normal_residual_norm, + ) diff --git a/spacecore/linalg/_power.py b/spacecore/linalg/_power.py new file mode 100644 index 0000000..1bdd0ce --- /dev/null +++ b/spacecore/linalg/_power.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, NamedTuple + +from ..backend import Context +from ..functional import QuadraticForm +from ..linop import LinOp +from ..space import Space +from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter +from ._utils import default_initial_vector, is_converged, normalize, require_linop +from ._utils import require_square, result_repr, should_check_iteration + + +class PowerIterationResult(NamedTuple): + """ + Store the result returned by :func:`power_iteration`. + + Parameters + ---------- + eigenvalue : scalar + Rayleigh-quotient estimate of the dominant eigenvalue. + eigenvector : array-like + Normalized eigenvector estimate in the operator domain. + converged : bool-like + Whether the residual norm satisfied ``tol``. + num_iters : int-like + Number of power iterations executed. + residual_norm : scalar + Norm of ``A x - eigenvalue * x``. + """ + + eigenvalue: Any + eigenvector: Any + converged: Any + num_iters: Any + residual_norm: Any + + def __repr__(self) -> str: + """Return a compact summary without printing the full eigenvector.""" + return result_repr( + "PowerIterationResult", + { + "converged": self.converged, + "num_iters": self.num_iters, + "eigenvalue": self.eigenvalue, + "residual_norm": self.residual_norm, + "eigenvector": self.eigenvector, + }, + ) + + +class _SelfAdjointAction(NamedTuple): + """Store the callable action used by power iteration.""" + + apply: Callable[[Any], Any] + domain: Space + ctx: Context + + @property + def ops(self) -> Any: + """Backend operations for this action.""" + return self.ctx.ops + + @property + def dtype(self) -> Any: + """Default dtype for this action.""" + return self.ctx.dtype + + +def _action_from_linop(A: LinOp) -> _SelfAdjointAction: + """Normalize a square linear operator into a self-adjoint action.""" + A = require_linop(A) + require_square(A, "power_iteration") + return _SelfAdjointAction(A.apply, A.domain, A.ctx) + + +def _action_from_quadratic_form(q: QuadraticForm) -> _SelfAdjointAction: + """Normalize a quadratic form into its Hessian action.""" + return _SelfAdjointAction(q.hess_apply, q.domain, q.ctx) + + +def power_iteration( + A: LinOp | QuadraticForm, + *, + x0: Any | None = None, + tol: float = 1e-6, + maxiter: int | None = None, + check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, +) -> PowerIterationResult: + r""" + Estimate the dominant eigenpair of a self-adjoint action. + + Accept a square :class:`LinOp` or a :class:`QuadraticForm` exposing + ``hess_apply``. Public dispatch converts either input into a fixed + self-adjoint action before entering the numerical loop. "Dominant" means + largest eigenvalue in absolute value, not necessarily the largest positive + eigenvalue. + + Parameters + ---------- + A : LinOp or QuadraticForm + Square operator or quadratic form whose dominant eigenpair, largest in + absolute value, is sought. Linear-operator inputs must satisfy + ``A.domain == A.codomain``; this includes the underlying space type and + inner-product geometry. + For spectral-norm estimates of a rectangular operator, pass + ``A.H @ A``. + x0 : array-like or None, optional + Initial vector in the action domain. Default is a normalized all-ones + vector in the domain geometry. + tol : float, optional + Residual-norm tolerance. ``result.converged`` is ``True`` when + ``norm(A @ x - lambda * x) < tol``. Default is 1e-6. + maxiter : int or None, optional + Maximum number of iterations. Default is ``prod(A.domain.shape)``. + check_every : int, optional + Refresh residual diagnostics every this many iterations and always on + the final iteration. Default is + ``DEFAULT_CONVERGENCE_CHECK_INTERVAL``. + + Returns + ------- + PowerIterationResult + Named tuple with fields: + + - ``eigenvalue``: Rayleigh-quotient eigenvalue estimate + - ``eigenvector``: normalized eigenvector estimate + - ``converged``: whether ``residual_norm < tol`` + - ``num_iters``: number of iterations executed + - ``residual_norm``: norm of ``A x - eigenvalue * x`` + + Raises + ------ + TypeError + If ``A`` is neither a :class:`LinOp` nor a :class:`QuadraticForm`. + ValueError + If a linear-operator input is not square or if iteration parameters are + invalid. + + See Also + -------- + lanczos_smallest : Approximate the smallest eigenpair of a Hermitian + operator. + cg : Solve Hermitian positive-definite systems. + + Notes + ----- + The residual-based stopping criterion uses + :math:`\|A x - \lambda x\|` and is refreshed only every ``check_every`` + iterations, and always on the final iteration. This function is + JIT-compatible on the JAX backend when ``maxiter`` and ``check_every`` are + static arguments. + + For operators with eigenvalues of mixed sign, the dominant eigenvalue is + the one with largest absolute value, which may be negative. Convergence + requires that this eigenvalue be separated from the rest in absolute value. + If the dominant modulus is degenerate, for example both ``lambda`` and + ``-lambda`` have maximum modulus, the iteration may oscillate between + subspaces. + + Examples + -------- + Estimate the largest eigenvalue of a diagonal operator. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((3,), ctx) + >>> A = sc.DiagonalLinOp(ctx.asarray([1.0, 3.0, 2.0]), X, ctx) + >>> result = sc.power_iteration(A, maxiter=20, tol=1e-10) + >>> np.allclose(result.eigenvalue, 3.0) + True + """ + if isinstance(A, QuadraticForm): + action = _action_from_quadratic_form(A) + elif isinstance(A, LinOp): + action = _action_from_linop(A) + else: + raise TypeError(f"A must be a LinOp or QuadraticForm, got {type(A).__name__}.") + + maxiter = check_maxiter(maxiter, action) + check_every = check_interval(check_every) + + x = default_initial_vector(action) if x0 is None else x0 + action.domain.check_member(x) + return PowerIterationResult(*_power_iteration_core(action, x, tol, maxiter, check_every)) + + +def _power_iteration_core( + action: _SelfAdjointAction, + x: Any, + tol: float, + maxiter: int, + check_every: int, +) -> tuple[Any, Any, Any, Any, Any]: + """Run the backend-loop implementation of power iteration.""" + x, _ = normalize(action.domain, x) + zero = action.ops.asarray(0.0, dtype=action.dtype) + residual_norm = action.domain.norm(x) + float("inf") + + def cond_fun(carry: tuple[Any, Any, Any, int]) -> Any: + _eigenvalue, _x, res_norm, k = carry + return (k < maxiter) & (res_norm > tol) + + def body_fun(carry: tuple[Any, Any, Any, int]) -> tuple[Any, Any, Any, int]: + _eigenvalue, x, _residual_norm, k = carry + y = action.apply(x) + x_next, _norm_y = normalize(action.domain, y) + y_next = action.apply(x_next) + eigenvalue_next = action.domain.inner(x_next, y_next) + k_next = k + 1 + + def refresh_residual(_: Any) -> Any: + residual = action.domain.axpy(-eigenvalue_next, x_next, y_next) + return action.domain.norm(residual) + + residual_norm_next = action.ops.cond( + should_check_iteration(k_next, maxiter, check_every), + refresh_residual, + lambda _: _residual_norm, + action.ops.asarray(0.0, dtype=action.dtype), + ) + return eigenvalue_next, x_next, residual_norm_next, k_next + + eigenvalue, eigenvector, residual_norm, num_iters = action.ops.while_loop( + cond_fun, + body_fun, + (zero, x, residual_norm, 0), + ) + return ( + eigenvalue, + eigenvector, + is_converged(residual_norm, tol), + num_iters, + residual_norm, + ) diff --git a/spacecore/linalg/_utils.py b/spacecore/linalg/_utils.py new file mode 100644 index 0000000..32a75d3 --- /dev/null +++ b/spacecore/linalg/_utils.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from math import prod +from typing import Any + +from ..linop import LinOp + +DEFAULT_CONVERGENCE_CHECK_INTERVAL = 64 + + +def require_linop(A: Any) -> LinOp: + """Return ``A`` as a ``LinOp`` or raise a clear type error.""" + if not isinstance(A, LinOp): + raise TypeError(f"A must be a LinOp, got {type(A).__name__}.") + return A + + +def require_square(A: LinOp, name: str) -> None: + """Raise if ``A`` is not a square operator.""" + if A.domain != A.codomain: + raise ValueError(f"{name} requires a square LinOp; got {A.domain!r} -> {A.codomain!r}.") + + +def default_maxiter(A: LinOp) -> int: + """Return the default Krylov iteration count for ``A``.""" + return max(1, prod(A.domain.shape)) + + +def check_maxiter(maxiter: int | None, A: LinOp) -> int: + """Validate an optional iteration count.""" + if maxiter is None: + return default_maxiter(A) + maxiter = int(maxiter) + if maxiter < 0: + raise ValueError("maxiter must be nonnegative.") + return maxiter + + +def check_interval(interval: int) -> int: + """Validate a convergence-check interval.""" + interval = int(interval) + if interval < 1: + raise ValueError("check_every must be positive.") + return interval + + +def should_check_iteration(k: Any, maxiter: int, interval: int) -> Any: + """Return whether iteration ``k`` should refresh convergence diagnostics.""" + return (k >= maxiter) | ((k % interval) == 0) + + +def threshold(norm_b: Any, tol: float, atol: float) -> Any: + """Return the absolute-plus-relative convergence threshold.""" + return max(float(atol), 0.0) + max(float(tol), 0.0) * norm_b + + +def real_inner(space: Any, x: Any, y: Any) -> Any: + """Return the real part of ``space.inner(x, y)``.""" + return space.ops.real(space.inner(x, y)) + + +def is_converged(residual_norm: Any, threshold_value: Any) -> Any: + """Return backend-compatible convergence predicate.""" + return residual_norm <= threshold_value + + +def safe_inverse_nonneg(ops: Any, value: Any) -> Any: + """ + Return ``1 / value`` where ``value > 0`` and zero otherwise. + + This helper is intended for norms and nonnegative residual magnitudes. It + is not a general scalar inverse: for example, ``-2`` maps to ``0``, not + ``-0.5``. + """ + positive = value > 0 + safe_value = ops.where(positive, value, ops.ones_like(value)) + return ops.where(positive, 1.0 / safe_value, ops.zeros_like(value)) + + +def normalize(space: Any, x: Any) -> tuple[Any, Any]: + """Normalize a space member and return ``(unit, norm)``.""" + norm = space.norm(x) + return space.scale(safe_inverse_nonneg(space.ops, norm), x), norm + + +def default_initial_vector(A: LinOp) -> Any: + """Return a deterministic unit vector in ``A.domain`` using its geometry.""" + size = prod(A.domain.shape) + flat = A.ops.ones((size,), dtype=A.dtype) + v = A.domain.unflatten(flat) + norm = A.domain.norm(v) + return A.domain.scale(safe_inverse_nonneg(A.ops, norm), v) + + +def summarize_value(value: Any) -> str: + """Return a compact representation for arrays, scalars, and pytrees.""" + shape = getattr(value, "shape", None) + dtype = getattr(value, "dtype", None) + if shape is not None: + shape_text = tuple(shape) + if shape_text == (): + dtype_text = str(dtype) + if dtype_text in {"bool", "bool_", "torch.bool"}: + try: + return repr(bool(value)) + except Exception: + return repr(value) + try: + return f"{float(value):.6g}" + except Exception: + return repr(value) + dtype_text = "" if dtype is None else f", dtype={dtype}" + return f"" + if isinstance(value, tuple): + return "(" + ", ".join(summarize_value(part) for part in value) + ")" + return repr(value) + + +def result_repr(name: str, fields: dict[str, Any]) -> str: + """Return a compact result-object representation.""" + body = ", ".join(f"{key}={summarize_value(value)}" for key, value in fields.items()) + return f"{name}({body})" diff --git a/spacecore/linop/__init__.py b/spacecore/linop/__init__.py index 7a6f537..b26c10b 100644 --- a/spacecore/linop/__init__.py +++ b/spacecore/linop/__init__.py @@ -1,12 +1,36 @@ +"""Linear operator abstractions, concrete operators, and algebra helpers.""" + from ._base import LinOp +from ._algebra import ( + ComposedLinOp, + IdentityLinOp, + MatrixFreeLinOp, + ScaledLinOp, + SumLinOp, + ZeroLinOp, + make_composed, + make_scaled, + make_sum, +) from ._dense import DenseLinOp +from ._diagonal import DiagonalLinOp from ._sparse import SparseLinOp from .product import ProductLinOp, StackedLinOp, SumToSingleLinOp, BlockDiagonalLinOp __all__ = [ "LinOp", + "ComposedLinOp", + "DiagonalLinOp", "DenseLinOp", + "IdentityLinOp", + "MatrixFreeLinOp", + "ScaledLinOp", "SparseLinOp", + "SumLinOp", + "ZeroLinOp", + "make_composed", + "make_scaled", + "make_sum", "ProductLinOp", "SumToSingleLinOp", "BlockDiagonalLinOp", diff --git a/spacecore/linop/_algebra.py b/spacecore/linop/_algebra.py new file mode 100644 index 0000000..d748880 --- /dev/null +++ b/spacecore/linop/_algebra.py @@ -0,0 +1,1037 @@ +from __future__ import annotations + +from numbers import Number +from typing import Any, Callable, Sequence + +from ._base import LinOp, Domain, Codomain +from .._checks import checked_method +from .._contextual._bound import _same_context_for_algebra +from ..backend import Context, jax_pytree_class + + +def is_scalar_like(value: Any) -> bool: + """Return whether ``value`` can be used as a scalar multiplier for a ``LinOp``.""" + if isinstance(value, Number): + return True + shape = getattr(value, "shape", None) + if shape is not None: + return tuple(shape) == () + ndim = getattr(value, "ndim", None) + return ndim == 0 + + +def _conjugate_scalar(value: Any) -> Any: + """Return the scalar conjugate when the value supports conjugation.""" + if hasattr(value, "conjugate"): + return value.conjugate() + if hasattr(value, "conj"): + return value.conj() + return value + + +def _require_same_context(ops: Sequence[LinOp]) -> Context: + """Return the common context for algebra operands or raise.""" + ctx = ops[0].ctx + for i, op in enumerate(ops[1:], start=1): + if not _same_context_for_algebra(ops[0].ctx, op.ctx): + raise ValueError( + "All LinOp operands in an algebraic expression must have the same ctx; " + f"operand 0 has ctx {ctx!r}, operand {i} has ctx {op.ctx!r}." + ) + return ctx + + +def _same_space_for_algebra(left: Any, right: Any) -> bool: + """Return whether two spaces are compatible for algebraic composition.""" + if type(left) is not type(right): + return False + if tuple(left.shape) != tuple(right.shape): + return False + if not _same_context_for_algebra(left.ctx, right.ctx): + return False + left_parts = getattr(left, "spaces", None) + right_parts = getattr(right, "spaces", None) + if left_parts is not None or right_parts is not None: + if left_parts is None or right_parts is None or len(left_parts) != len(right_parts): + return False + return all(_same_space_for_algebra(a, b) for a, b in zip(left_parts, right_parts)) + return True + + +def _require_linop(op: Any, name: str) -> LinOp: + """Return ``op`` as a linear operator or raise a typed error.""" + if not isinstance(op, LinOp): + raise TypeError(f"{name} must be a LinOp, got {type(op).__name__}.") + return op + + +def _scalar_equal(value: Any, target: Any) -> bool: + """Return whether two scalar-like values compare equal.""" + try: + return bool(value == target) + except Exception: + return False + + +def _is_zero_scalar(value: Any) -> bool: + """Return whether ``value`` is scalar-like zero.""" + return _scalar_equal(value, 0) + + +def _is_one_scalar(value: Any) -> bool: + """Return whether ``value`` is scalar-like one.""" + return _scalar_equal(value, 1) + + +def _flatten_sum_terms(ops: Sequence[LinOp]) -> tuple[LinOp, ...]: + """Flatten nested lazy sums into a tuple of terms.""" + terms: list[LinOp] = [] + for i, op in enumerate(ops): + op = _require_linop(op, f"ops[{i}]") + if isinstance(op, SumLinOp): + terms.extend(_flatten_sum_terms(op.parts)) + else: + terms.append(op) + return tuple(terms) + + +def make_sum(ops: Sequence[LinOp]) -> LinOp: + """ + Return a locally simplified lazy sum of linear operators. + + This factory performs only local algebraic canonicalization: nested + ``SumLinOp`` nodes are flattened and ``ZeroLinOp`` terms are removed. It + does not collect like terms, reorder operands, or attempt full symbolic + optimization. All operands must have the same context, domain, and codomain + before a simplified operator is returned. + + Parameters + ---------- + ops : sequence of LinOp + Nonempty sequence of operators with common domain and codomain. + + Returns + ------- + LinOp + Simplified lazy sum, a single operand, or a zero operator. + """ + if not ops: + raise ValueError("make_sum requires a nonempty sequence of LinOp operands.") + + terms = _flatten_sum_terms(ops) + ctx = _require_same_context(terms) + domain = terms[0].domain + codomain = terms[0].codomain + for i, op in enumerate(terms[1:], start=1): + if ( + not _same_space_for_algebra(op.domain, domain) + or not _same_space_for_algebra(op.codomain, codomain) + ): + raise ValueError( + "All SumLinOp operands must have the same domain and codomain; " + f"operand 0 maps {domain!r} -> {codomain!r}, " + f"operand {i} maps {op.domain!r} -> {op.codomain!r}." + ) + + nonzero_terms = tuple(op for op in terms if not isinstance(op, ZeroLinOp)) + if not nonzero_terms: + return ZeroLinOp(domain, codomain, ctx) + if len(nonzero_terms) == 1: + return nonzero_terms[0] + return SumLinOp(nonzero_terms) + + +def make_scaled(scalar: Any, op: LinOp) -> LinOp: + """ + Return a locally simplified scalar multiple of a linear operator. + + This factory performs only local algebraic canonicalization: zero and unit + scalars are simplified, and nested ``ScaledLinOp`` nodes are folded into one + scalar. It does not distribute scaling over sums or perform full symbolic + optimization. Complex scalars retain the usual conjugated coefficient in + ``rapply`` through ``ScaledLinOp``. + + Parameters + ---------- + scalar : scalar-like + Scalar coefficient multiplying ``op``. + op : LinOp + Operator to scale. + + Returns + ------- + LinOp + Simplified scalar multiple. + """ + op = _require_linop(op, "op") + if not is_scalar_like(scalar): + raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.") + + if _is_zero_scalar(scalar): + return ZeroLinOp(op.domain, op.codomain, op.ctx) + if _is_one_scalar(scalar): + return op + if isinstance(op, ZeroLinOp): + return op + if isinstance(op, ScaledLinOp): + return make_scaled(scalar * op.scalar, op.op) + return ScaledLinOp(scalar, op) + + +def make_composed(left: LinOp, right: LinOp) -> LinOp: + """ + Return a locally simplified composition of two linear operators. + + This factory performs only local algebraic canonicalization: identity + factors are removed and compositions with zero maps become zero maps. It + preserves the binary ``ComposedLinOp`` representation and does not flatten + multi-factor chains or attempt full symbolic optimization. Operands must + have the same context and compatible middle spaces before a simplified + operator is returned. + + Parameters + ---------- + left : LinOp + Operator applied second. + right : LinOp + Operator applied first. + + Returns + ------- + LinOp + Simplified lazy composition representing ``left @ right``. + """ + left = _require_linop(left, "left") + right = _require_linop(right, "right") + _require_same_context((left, right)) + if not _same_space_for_algebra(right.codomain, left.domain): + raise ValueError( + "ComposedLinOp requires right.codomain == left.domain; " + f"got {right.codomain!r} and {left.domain!r}." + ) + + if isinstance(right, IdentityLinOp): + return left + if isinstance(left, IdentityLinOp): + return right + if isinstance(left, ZeroLinOp): + return ZeroLinOp(right.domain, left.codomain, left.ctx) + if isinstance(right, ZeroLinOp): + return ZeroLinOp(right.domain, left.codomain, left.ctx) + return ComposedLinOp(left, right) + + +@jax_pytree_class +class ScaledLinOp(LinOp[Domain, Codomain]): + r""" + Lazy scalar multiple of a linear operator. + + ``ScaledLinOp(alpha, A)`` represents the mathematical operator + ``alpha * A``. Its context is exactly ``A.ctx``; its domain is ``A.domain`` + and its codomain is ``A.codomain``. No dense matrix representation is + formed. + + The forward action is ``apply(x) = alpha * A.apply(x)`` for + ``x in A.domain``. The reverse action is + ``rapply(y) = conj(alpha) * A.rapply(y)`` for ``y in A.codomain``, so + complex scalars use the conjugated coefficient. + + Parameters + ---------- + scalar : scalar-like + Scalar multiplier. + op : LinOp + Operator being scaled. + + Attributes + ---------- + scalar : scalar-like + Stored scalar multiplier. + op : LinOp + Stored operand. + """ + + def __init__(self, scalar: Any, op: LinOp[Domain, Codomain]) -> None: + op = _require_linop(op, "op") + if not is_scalar_like(scalar): + raise TypeError(f"scalar must be scalar-like, got {type(scalar).__name__}.") + super().__init__(op.domain, op.codomain, op.ctx) + self.scalar = scalar + self.op = op + + @checked_method(in_space="domain", out_space="codomain") + def apply(self, x: Any) -> Any: + """Return ``scalar * op.apply(x)``.""" + return self.scalar * self.op.apply(x) + + @checked_method(in_space="codomain", out_space="domain") + def rapply(self, y: Any) -> Any: + """Return ``conj(scalar) * op.rapply(y)``.""" + return _conjugate_scalar(self.scalar) * self.op.rapply(y) + + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``scalar * op.vapply(xs)``.""" + return self.scalar * self.op.vapply(xs, batch_space) + + def rvapply(self, ys: Any, batch_space=None) -> Any: + """Return ``conj(scalar) * op.rvapply(ys)``.""" + return _conjugate_scalar(self.scalar) * self.op.rvapply(ys, batch_space) + + def __eq__(self, other: Any) -> bool: + """Return whether another scaled operator has the same scalar and operand.""" + if type(other) is type(self): + return self.scalar == other.scalar and self.op == other.op + return False + + def tree_flatten(self): + """Flatten this operator for pytree registration.""" + children = (self.scalar, self.op) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" + scalar, op = children + return cls(scalar, op) + + def _convert(self, new_ctx: Context) -> ScaledLinOp: + """Convert the operand to ``new_ctx`` while preserving the scalar.""" + return ScaledLinOp(self.scalar, self.op.convert(new_ctx)) + + +@jax_pytree_class +class SumLinOp(LinOp[Domain, Codomain]): + r""" + Lazy finite sum of linear operators with common spaces. + + ``SumLinOp((A1, ..., Ak))`` represents ``A1 + ... + Ak`` for a nonempty + sequence of ``LinOp`` instances. All operands must have the same ``ctx``, + the same domain, and the same codomain before construction. The resulting + operator has that shared context, domain, and codomain. + + The forward action is ``apply(x) = sum_i Ai.apply(x)`` for the shared + domain element ``x``. The reverse action is + ``rapply(y) = sum_i Ai.rapply(y)`` for the shared codomain element ``y``. + + Parameters + ---------- + ops : sequence of LinOp + Nonempty sequence of operators with common context, domain, and + codomain. + + Attributes + ---------- + parts : tuple of LinOp + Stored operands in the lazy sum. + """ + + def __init__(self, ops: Sequence[LinOp[Domain, Codomain]]) -> None: + if not ops: + raise ValueError("SumLinOp requires a nonempty sequence of LinOp operands.") + parts = tuple(_require_linop(op, f"ops[{i}]") for i, op in enumerate(ops)) + ctx = _require_same_context(parts) + domain = parts[0].domain + codomain = parts[0].codomain + for i, op in enumerate(parts[1:], start=1): + if ( + not _same_space_for_algebra(op.domain, domain) + or not _same_space_for_algebra(op.codomain, codomain) + ): + raise ValueError( + "All SumLinOp operands must have the same domain and codomain; " + f"operand 0 maps {domain!r} -> {codomain!r}, " + f"operand {i} maps {op.domain!r} -> {op.codomain!r}." + ) + super().__init__(domain, codomain, ctx) + self.ops_tuple = parts + + @property + def parts(self) -> tuple[LinOp[Domain, Codomain], ...]: + """Operators in this lazy sum.""" + return self.ops_tuple + + @checked_method(in_space="domain", out_space="codomain") + def apply(self, x: Any) -> Any: + """Return ``sum_i ops[i].apply(x)``.""" + acc = self.ops_tuple[0].apply(x) + for op in self.ops_tuple[1:]: + acc = self.codomain.add(acc, op.apply(x)) + return acc + + @checked_method(in_space="codomain", out_space="domain") + def rapply(self, y: Any) -> Any: + """Return ``sum_i ops[i].rapply(y)``.""" + acc = self.ops_tuple[0].rapply(y) + for op in self.ops_tuple[1:]: + acc = self.domain.add(acc, op.rapply(y)) + return acc + + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``sum_i ops[i].vapply(xs)``.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + out_space = self._output_batch_space(self.codomain, in_space) + acc = self.ops_tuple[0].vapply(xs, in_space) + for op in self.ops_tuple[1:]: + acc = out_space.add(acc, op.vapply(xs, in_space)) + return acc + + def rvapply(self, ys: Any, batch_space=None) -> Any: + """Return ``sum_i ops[i].rvapply(ys)``.""" + in_space = self._input_batch_space(self.codomain, ys, batch_space) + out_space = self._output_batch_space(self.domain, in_space) + acc = self.ops_tuple[0].rvapply(ys, in_space) + for op in self.ops_tuple[1:]: + acc = out_space.add(acc, op.rvapply(ys, in_space)) + return acc + + def __eq__(self, other: Any) -> bool: + """Return whether another sum has the same operands.""" + if type(other) is type(self): + return self.ops_tuple == other.ops_tuple + return False + + def tree_flatten(self): + """Flatten this operator for pytree registration.""" + children = self.ops_tuple + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" + return cls(tuple(children)) + + def _convert(self, new_ctx: Context) -> SumLinOp: + """Convert all operands to ``new_ctx``.""" + return SumLinOp(tuple(op.convert(new_ctx) for op in self.ops_tuple)) + + +@jax_pytree_class +class ComposedLinOp(LinOp[Domain, Codomain]): + r""" + Lazy composition of two linear operators. + + ``ComposedLinOp(A, B)`` represents ``A @ B = A circ B``. The operands must + have the same ``ctx`` before construction, and ``B.codomain`` must equal + ``A.domain``. The resulting operator has domain ``B.domain`` and codomain + ``A.codomain``. + + The forward action is ``apply(x) = A.apply(B.apply(x))`` for + ``x in B.domain``. The reverse action is ``rapply(z) = B.rapply(A.rapply(z))`` + for ``z in A.codomain``. + + Parameters + ---------- + left : LinOp + Operator applied second. + right : LinOp + Operator applied first. + + Attributes + ---------- + left : LinOp + Left operand. + right : LinOp + Right operand. + """ + + def __init__(self, left: LinOp, right: LinOp) -> None: + left = _require_linop(left, "left") + right = _require_linop(right, "right") + _require_same_context((left, right)) + if not _same_space_for_algebra(right.codomain, left.domain): + raise ValueError( + "ComposedLinOp requires right.codomain == left.domain; " + f"got {right.codomain!r} and {left.domain!r}." + ) + super().__init__(right.domain, left.codomain, left.ctx) + self.left = left + self.right = right + + @checked_method(in_space="domain", out_space="codomain") + def apply(self, x: Any) -> Any: + """Return ``left.apply(right.apply(x))``.""" + return self.left.apply(self.right.apply(x)) + + @checked_method(in_space="codomain", out_space="domain") + def rapply(self, z: Any) -> Any: + """Return ``right.rapply(left.rapply(z))``.""" + return self.right.rapply(self.left.rapply(z)) + + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``left.vapply(right.vapply(xs))``.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + middle = self.right.codomain.batch(in_space.batch_shape, in_space.batch_axes) + return self.left.vapply(self.right.vapply(xs, in_space), middle) + + def rvapply(self, zs: Any, batch_space=None) -> Any: + """Return ``right.rvapply(left.rvapply(zs))``.""" + in_space = self._input_batch_space(self.codomain, zs, batch_space) + middle = self.left.domain.batch(in_space.batch_shape, in_space.batch_axes) + return self.right.rvapply(self.left.rvapply(zs, in_space), middle) + + def __eq__(self, other: Any) -> bool: + """Return whether another composition has the same operands.""" + if type(other) is type(self): + return self.left == other.left and self.right == other.right + return False + + def tree_flatten(self): + """Flatten this operator for pytree registration.""" + children = (self.left, self.right) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" + left, right = children + return cls(left, right) + + def _convert(self, new_ctx: Context) -> ComposedLinOp: + """Convert both operands to ``new_ctx``.""" + return ComposedLinOp(self.left.convert(new_ctx), self.right.convert(new_ctx)) + + +@jax_pytree_class +class ZeroLinOp(LinOp[Domain, Codomain]): + r""" + Lazy zero map between two spaces. + + ``ZeroLinOp(X, Y)`` represents the linear map ``0 : X -> Y``. The context is + resolved from the optional ``ctx`` argument and the two spaces, then both + spaces are converted to that context. Its domain is ``X`` and its codomain + is ``Y`` in the resolved context. + + The forward action is ``apply(x) = 0_Y`` for ``x in X``. The reverse action + is ``rapply(y) = 0_X`` for ``y in Y``. + + Parameters + ---------- + dom : Space + Domain space. + cod : Space + Codomain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from the spaces. + """ + + def __init__( + self, + dom: Domain, + cod: Codomain, + ctx: Context | str | None = None, + ) -> None: + super().__init__(dom, cod, ctx) + + @checked_method(in_space="domain", out_space="codomain") + def apply(self, x: Any) -> Any: + """Return the zero element of the codomain.""" + return self._apply_unchecked(x) + + def _apply_unchecked(self, x: Any) -> Any: + """Return the codomain zero without membership checks.""" + return self.codomain.zeros() + + @checked_method(in_space="codomain", out_space="domain") + def rapply(self, y: Any) -> Any: + """Return the zero element of the domain.""" + return self._rapply_unchecked(y) + + def _rapply_unchecked(self, y: Any) -> Any: + """Return the domain zero without membership checks.""" + return self.domain.zeros() + + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return the batched zero element of the codomain.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + return self._output_batch_space(self.codomain, in_space).zeros() + + def rvapply(self, ys: Any, batch_space=None) -> Any: + """Return the batched zero element of the domain.""" + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + return self._output_batch_space(self.domain, in_space).zeros() + + def to_dense(self) -> Any: + """ + Return the dense tensor representation of the zero map. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + return self.ops.zeros(tuple(self.codomain.shape) + tuple(self.domain.shape), dtype=self.dtype) + + def is_hermitian(self) -> bool: + """ + Return whether the zero map is Hermitian. + + Returns + ------- + bool + ``True`` exactly when domain and codomain are the same space. + """ + return self.domain == self.codomain + + def __eq__(self, other: Any) -> bool: + """Return whether another zero map has the same spaces.""" + if type(other) is type(self): + return self.domain == other.domain and self.codomain == other.codomain + return False + + def tree_flatten(self): + """Flatten this operator for pytree registration.""" + children = () + aux = (self.domain, self.codomain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" + domain, codomain, ctx = aux + return cls(domain, codomain, ctx) + + def _convert(self, new_ctx: Context) -> ZeroLinOp: + """Convert domain and codomain spaces to ``new_ctx``.""" + return ZeroLinOp(self.domain.convert(new_ctx), self.codomain.convert(new_ctx), new_ctx) + + +@jax_pytree_class +class IdentityLinOp(LinOp[Domain, Domain]): + r""" + Lazy identity map on a space. + + ``IdentityLinOp(X)`` represents the identity operator ``I_X : X -> X``. The + context is resolved from the optional ``ctx`` argument and the space, and the + resulting operator has domain and codomain equal to ``X`` in that context. + + The forward action is ``apply(x) = x`` for ``x in X``. The reverse action is + ``rapply(x) = x`` for ``x in X``. + + Parameters + ---------- + space : Space + Domain and codomain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``space``. + """ + + def __init__(self, space: Domain, ctx: Context | str | None = None) -> None: + super().__init__(space, space, ctx) + + @checked_method(in_space="domain", out_space="codomain") + def apply(self, x: Any) -> Any: + """Return ``x`` after domain validation.""" + return self._apply_unchecked(x) + + def _apply_unchecked(self, x: Any) -> Any: + """Return ``x`` without membership checks.""" + return x + + @checked_method(in_space="codomain", out_space="domain") + def rapply(self, x: Any) -> Any: + """Return ``x`` after codomain validation.""" + return self._rapply_unchecked(x) + + def _rapply_unchecked(self, x: Any) -> Any: + """Return ``x`` without membership checks.""" + return x + + def vapply(self, xs: Any, batch_space=None) -> Any: + """Return ``xs`` after batched domain validation.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + return xs + + def rvapply(self, xs: Any, batch_space=None) -> Any: + """Return ``xs`` after batched codomain validation.""" + in_space = self._input_batch_space(self.codomain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + return xs + + def to_dense(self) -> Any: + """ + Return the dense tensor representation of this identity map. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + size = 1 + for dim in self.domain.shape: + size *= dim + eye = self.ops.eye(size, dtype=self.dtype) + return self.ops.reshape(eye, tuple(self.codomain.shape) + tuple(self.domain.shape)) + + def is_hermitian(self) -> bool: + """ + Return whether this identity operator is Hermitian. + + Returns + ------- + bool + Always ``True``. + """ + return True + + def __eq__(self, other: Any) -> bool: + """Return whether another identity map has the same space.""" + if type(other) is type(self): + return self.domain == other.domain + return False + + def tree_flatten(self): + """Flatten this operator for pytree registration.""" + children = () + aux = (self.domain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" + domain, ctx = aux + return cls(domain, ctx) + + def _convert(self, new_ctx: Context) -> IdentityLinOp: + """Convert the identity space to ``new_ctx``.""" + return IdentityLinOp(self.domain.convert(new_ctx), new_ctx) + + +@jax_pytree_class +class MatrixFreeLinOp(LinOp[Domain, Codomain]): + """ + Linear operator defined by user-supplied forward and reverse callables. + + ``MatrixFreeLinOp(apply, rapply, X, Y)`` represents a matrix-free map + ``A : X -> Y`` without storing or materializing a matrix. The context is + resolved from the optional ``ctx`` argument and the spaces, then the spaces + are converted to that context. + + The forward action is ``apply(x) = apply_fn(x)`` for ``x in X``. The reverse + action is ``rapply(y) = rapply_fn(y)`` for ``y in Y``. When checks are + enabled, inputs and callable outputs are validated against the corresponding + domain and codomain. + + Parameters + ---------- + apply : callable + Callable with signature ``apply(x: Any) -> Any`` implementing the + forward map from ``dom`` to ``cod``. + rapply : callable + Callable with signature ``rapply(y: Any) -> Any`` implementing the + adjoint map from ``cod`` back to ``dom``. + dom : Space + Domain space containing valid inputs for ``apply`` and outputs from + ``rapply``. + cod : Space + Codomain space containing outputs from ``apply`` and valid inputs for + ``rapply``. + ctx : Context, str, or None, optional + Optional context specification. An explicit context wins over inferred + contexts from ``dom`` and ``cod``. + vapply : callable or None, optional + Optional callable with signature ``vapply(xs: Any) -> Any`` for batched + forward application. If omitted, backend ``vmap`` fallback is used. + rvapply : callable or None, optional + Optional callable with signature ``rvapply(ys: Any) -> Any`` for + batched adjoint application. If omitted, backend ``vmap`` fallback is + used. + + Returns + ------- + MatrixFreeLinOp + Operator using the supplied callables for forward, adjoint, and + optionally batched actions. + """ + + def __init__( + self, + apply: Callable[[Any], Any], + rapply: Callable[[Any], Any], + dom: Domain, + cod: Codomain, + ctx: Context | str | None = None, + vapply: Callable[[Any], Any] | None = None, + rvapply: Callable[[Any], Any] | None = None, + ) -> None: + """ + Initialize a matrix-free linear operator. + + Parameters + ---------- + apply: + Callable ``apply(x)`` that accepts an element of ``dom`` and returns + an element of ``cod``. + rapply: + Callable ``rapply(y)`` that accepts an element of ``cod`` and + returns an element of ``dom``. + dom: + Domain space of the operator. + cod: + Codomain space of the operator. + ctx: + Optional context specification for the operator and converted + spaces. + vapply: + Optional callable for batched forward application over ``dom`` + batches. + rvapply: + Optional callable for batched adjoint application over ``cod`` + batches. + + Returns + ------- + None + The initializer stores the callables and converted spaces on + ``self``. + """ + if not callable(apply): + raise TypeError(f"apply must be callable, got {type(apply).__name__}.") + if not callable(rapply): + raise TypeError(f"rapply must be callable, got {type(rapply).__name__}.") + if vapply is not None and not callable(vapply): + raise TypeError(f"vapply must be callable, got {type(vapply).__name__}.") + if rvapply is not None and not callable(rvapply): + raise TypeError(f"rvapply must be callable, got {type(rvapply).__name__}.") + super().__init__(dom, cod, ctx) + self.apply_fn = apply + self.rapply_fn = rapply + self.vapply_fn = vapply + self.rvapply_fn = rvapply + + @checked_method(in_space="domain", out_space="codomain") + def apply(self, x: Any) -> Any: + """ + Apply the forward callable. + + Parameters + ---------- + x: + Element of ``self.domain`` passed to ``apply_fn``. + + Returns + ------- + Any + Element of ``self.codomain`` returned by ``apply_fn``. + """ + return self._apply_unchecked(x) + + def _apply_unchecked(self, x: Any) -> Any: + """ + Apply ``apply_fn`` without membership checks. + + Parameters + ---------- + x: + Value accepted by the user-supplied forward callable. + + Returns + ------- + Any + Raw forward-callable output. + """ + return self.apply_fn(x) + + @checked_method(in_space="codomain", out_space="domain") + def rapply(self, y: Any) -> Any: + """ + Apply the adjoint callable. + + Parameters + ---------- + y: + Element of ``self.codomain`` passed to ``rapply_fn``. + + Returns + ------- + Any + Element of ``self.domain`` returned by ``rapply_fn``. + """ + return self._rapply_unchecked(y) + + def _rapply_unchecked(self, y: Any) -> Any: + """ + Apply ``rapply_fn`` without membership checks. + + Parameters + ---------- + y: + Value accepted by the user-supplied adjoint callable. + + Returns + ------- + Any + Raw adjoint-callable output. + """ + return self.rapply_fn(y) + + def vapply(self, xs: Any, batch_space=None) -> Any: + """ + Apply this operator to a batch of domain elements. + + Parameters + ---------- + xs: + Batched element of ``self.domain``. + batch_space: + Optional batch-space descriptor for ``xs``. + + Returns + ------- + Any + Batched element of ``self.codomain`` produced by ``vapply_fn`` or + by the fallback batching implementation. + """ + if self.vapply_fn is None: + return super().vapply(xs, batch_space) + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + ys = self.vapply_fn(xs) + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def rvapply(self, ys: Any, batch_space=None) -> Any: + """ + Apply the adjoint operator to a batch of codomain elements. + + Parameters + ---------- + ys: + Batched element of ``self.codomain``. + batch_space: + Optional batch-space descriptor for ``ys``. + + Returns + ------- + Any + Batched element of ``self.domain`` produced by ``rvapply_fn`` or by + the fallback batching implementation. + """ + if self.rvapply_fn is None: + return super().rvapply(ys, batch_space) + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + xs = self.rvapply_fn(ys) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return ( + self.domain == other.domain + and self.codomain == other.codomain + and self.apply_fn is other.apply_fn + and self.rapply_fn is other.rapply_fn + and self.vapply_fn is other.vapply_fn + and self.rvapply_fn is other.rvapply_fn + ) + return False + + def tree_flatten(self): + children = () + aux = ( + self.apply_fn, + self.rapply_fn, + self.domain, + self.codomain, + self.ctx, + self.vapply_fn, + self.rvapply_fn, + ) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + apply_fn, rapply_fn, domain, codomain, ctx, vapply_fn, rvapply_fn = aux + return cls(apply_fn, rapply_fn, domain, codomain, ctx, vapply_fn, rvapply_fn) + + def _convert(self, new_ctx: Context) -> MatrixFreeLinOp: + """ + Convert this matrix-free operator to ``new_ctx``. + + Parameters + ---------- + new_ctx: + Concrete target context for converted domain and codomain spaces. + + Returns + ------- + MatrixFreeLinOp + Operator with converted spaces and the same user-supplied + callables. + """ + return MatrixFreeLinOp( + self.apply_fn, + self.rapply_fn, + self.domain.convert(new_ctx), + self.codomain.convert(new_ctx), + new_ctx, + self.vapply_fn, + self.rvapply_fn, + ) + + +@jax_pytree_class +class _AdjointViewLinOp(LinOp[Codomain, Domain]): + """ + Hermitian-adjoint view of a linear operator. + + ``A.H`` represents the adjoint view ``A*``. Its context is exactly + ``A.ctx``; its domain is ``A.codomain`` and its codomain is ``A.domain``. + ``A.H.H`` returns ``A`` rather than constructing another wrapper. + + The forward action is ``apply(y) = A.rapply(y)`` for ``y in A.codomain``. + The reverse action is ``rapply(x) = A.apply(x)`` for ``x in A.domain``. + """ + + def __init__(self, op: LinOp[Domain, Codomain]) -> None: + op = _require_linop(op, "op") + super().__init__(op.codomain, op.domain, op.ctx) + self.op = op + + @checked_method(in_space="domain", out_space="codomain") + def apply(self, y: Any) -> Any: + """Return ``op.rapply(y)``.""" + return self.op.rapply(y) + + @checked_method(in_space="codomain", out_space="domain") + def rapply(self, x: Any) -> Any: + """Return ``op.apply(x)``.""" + return self.op.apply(x) + + def vapply(self, ys: Any, batch_space=None) -> Any: + """Return ``op.rvapply(ys)`` over a batch.""" + return self.op.rvapply(ys, batch_space) + + def rvapply(self, xs: Any, batch_space=None) -> Any: + """Return ``op.vapply(xs)`` over a batch.""" + return self.op.vapply(xs, batch_space) + + @property + def H(self) -> LinOp[Domain, Codomain]: + """Original operator viewed as the adjoint of this adjoint view.""" + return self.op + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.op == other.op + return False + + def tree_flatten(self): + children = (self.op,) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + return cls(children[0]) + + def _convert(self, new_ctx: Context) -> _AdjointViewLinOp: + return _AdjointViewLinOp(self.op.convert(new_ctx)) diff --git a/spacecore/linop/_base.py b/spacecore/linop/_base.py index dfc06b9..4589e4c 100644 --- a/spacecore/linop/_base.py +++ b/spacecore/linop/_base.py @@ -1,71 +1,315 @@ from __future__ import annotations from abc import abstractmethod +from functools import cached_property +from math import prod +from numbers import Number from typing import Any, Generic, TypeVar from ..space import Space from ..backend import Context -from .._contextual import ContextBound -from .._contextual.manager import ctx_manager +from .._contextual import ContextBound, resolve_context_priority Domain = TypeVar('Domain', bound=Space) Codomain = TypeVar('Codomain', bound=Space) class LinOp(ContextBound, Generic[Domain, Codomain]): - """ - Minimal linear operator (morphism) between two spaces. + r""" + Represent a linear map between two spaces. + + This class is intentionally small. It defines no storage assumptions and + requires subclasses to provide forward and adjoint actions. + + The adjoint :math:`A^*` satisfies + :math:`\langle A x, y\rangle_Y = \langle x, A^* y\rangle_X` for + :math:`x \in X` and :math:`y \in Y`. For complex operators this is the + conjugate adjoint. + + Parameters + ---------- + dom : Space + Domain space ``X``. + cod : Space + Codomain space ``Y``. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom`` and + ``cod``. - This class is intentionally small. It defines no matrix semantics, - arithmetic, or storage assumptions. + Attributes + ---------- + dom : Space + Domain space converted to ``ctx``. + cod : Space + Codomain space converted to ``ctx``. + ctx : Context + Resolved backend context. - Its sole purpose is to represent a linear map - ``A : dom -> cod`` - with access to both forward and adjoint actions. + Examples + -------- + Use a concrete dense operator as a :class:`LinOp`. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> A = sc.DenseLinOp(ctx.asarray([[1.0, 0.0], [0.0, 2.0]]), X, X, ctx) + >>> A.apply(ctx.asarray([3.0, 4.0])) + array([3., 8.]) """ def __init__(self, dom: Domain, cod: Codomain, ctx: Context | str | None = None): - ctx = ctx_manager.resolve_context_priority(ctx, dom, cod) + ctx = resolve_context_priority(ctx, dom, cod) super(LinOp, self).__init__(ctx) self.dom = dom.convert(self.ctx) self.cod = cod.convert(self.ctx) self._enable_checks = self.ctx.enable_checks - @abstractmethod - def apply(self, x: Any) -> Any: + @property + def domain(self) -> Domain: + """Domain space of this linear operator.""" + return self.dom + + @property + def codomain(self) -> Codomain: + """Codomain space of this linear operator.""" + return self.cod + + @cached_property + def A(self) -> Any: """ - Forward application: y = A x + Native numerical representation of this operator. - Contract: - - x is an element of self.dom - - return value is an element of self.cod + Concrete subclasses may choose the representation that best matches + their storage model: for example, dense operators return a dense array + while sparse operators return their sparse matrix. Matrix-free or lazy + operators generally do not have such a representation and should leave + this property unimplemented. Use :meth:`to_dense` when a dense tensor + materialization is explicitly required. """ + raise NotImplementedError( + f"{type(self).__name__} does not define a native numerical representation." + ) + + @abstractmethod + def apply(self, x: Any) -> Any: + """Apply the forward map to an element of ``self.domain``.""" @abstractmethod def rapply(self, y: Any) -> Any: + """Apply the adjoint map to an element of ``self.codomain``.""" + + def __call__(self, x: Any) -> Any: + """Apply this linear operator to ``x``.""" + return self.apply(x) + + def adjoint_apply(self, y: Any) -> Any: + """Apply the adjoint of this linear operator to ``y``.""" + return self.rapply(y) + + def is_hermitian(self) -> bool | None: """ - Adjoint application: x = A^* y + Return whether this operator is structurally Hermitian when known. - Contract: - - y is an element of self.cod - - return value is an element of self.dom + Returns + ------- + bool | None + ``True`` or ``False`` when the subclass can verify the structure + cheaply, otherwise ``None`` for unknown or matrix-free operators. """ + return None - def __call__(self, x: Any) -> Any: - return self.apply(x) + def vapply(self, xs: Any, batch_space: Space | None = None) -> Any: + """Apply this operator independently over a batch of domain elements.""" + return self._fallback_vapply(xs, batch_space) + + def rvapply(self, ys: Any, batch_space: Space | None = None) -> Any: + """Apply the adjoint independently over a batch of codomain elements.""" + return self._fallback_rvapply(ys, batch_space) + + def _infer_batch_shape(self, space: Space, value: Any) -> tuple[int, ...]: + """Infer leading batch dimensions from a value and base space.""" + if hasattr(space, "spaces") and isinstance(value, tuple) and value: + return self._infer_batch_shape(space.spaces[0], value[0]) + shape = tuple(getattr(value, "shape", ())) + base_shape = tuple(space.shape) + if not base_shape: + return shape + if len(shape) < len(base_shape) or shape[-len(base_shape):] != base_shape: + raise ValueError( + f"Cannot infer leading batch shape for value shape {shape} " + f"and base space shape {base_shape}." + ) + return shape[: len(shape) - len(base_shape)] + + def _input_batch_space( + self, + space: Space, + value: Any, + batch_space: Space | None, + ) -> Space: + """Return the batch space used to validate batched inputs.""" + if batch_space is not None: + return batch_space + batch_shape = self._infer_batch_shape(space, value) + return space.batch(batch_shape, tuple(range(len(batch_shape)))) + + def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space: + """Return the batch space corresponding to a batched output.""" + batch_shape = getattr(input_batch_space, "batch_shape", None) + batch_axes = getattr(input_batch_space, "batch_axes", None) + if batch_shape is None or batch_axes is None: + raise TypeError("batch_space must be a BatchSpace-compatible object.") + return space.batch(tuple(batch_shape), tuple(batch_axes)) + + def _fallback_vapply(self, xs: Any, batch_space: Space | None = None) -> Any: + """Apply ``self.apply`` over a leading batch with backend ``vmap``.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + ys = self.ops.vmap(self.apply, in_axes=0, out_axes=0)(xs) + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def _fallback_rvapply(self, ys: Any, batch_space: Space | None = None) -> Any: + """Apply ``self.rapply`` over a leading batch with backend ``vmap``.""" + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + xs = self.ops.vmap(self.rapply, in_axes=0, out_axes=0)(ys) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + + @property + def H(self) -> LinOp: + r"""Hermitian-adjoint view of this linear operator. + + Returns + ------- + LinOp + Adjoint view satisfying + :math:`\langle A x, y\rangle_Y = \langle x, A^* y\rangle_X`. + """ + from ._algebra import _AdjointViewLinOp + + view = getattr(self, "_adjoint_view", None) + if view is None: + view = _AdjointViewLinOp(self) + self._adjoint_view = view + return view + + def __add__(self, other: Any) -> LinOp: + """Return the lazy sum ``self + other`` of two compatible operators.""" + from ._algebra import make_sum + + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((self, other)) + + def __radd__(self, other: Any) -> LinOp: + """Return the lazy sum ``other + self`` of two compatible operators.""" + from ._algebra import make_sum + + if isinstance(other, Number) and other == 0: + return self + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((other, self)) + + def __neg__(self) -> LinOp: + """Return the lazy negation ``-self``.""" + from ._algebra import make_scaled + + return make_scaled(-1, self) + + def __sub__(self, other: Any) -> LinOp: + """Return the lazy difference ``self - other`` of two compatible operators.""" + from ._algebra import make_scaled, make_sum + + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((self, make_scaled(-1, other))) + + def __rsub__(self, other: Any) -> LinOp: + """Return the lazy difference ``other - self`` of two compatible operators.""" + from ._algebra import make_scaled, make_sum + + if isinstance(other, Number) and other == 0: + return make_scaled(-1, self) + if not isinstance(other, LinOp): + return NotImplemented + return make_sum((other, make_scaled(-1, self))) + + def __mul__(self, scalar: Any) -> LinOp: + """Return the lazy right scalar multiple ``self * scalar``.""" + from ._algebra import is_scalar_like, make_scaled + + if not is_scalar_like(scalar): + return NotImplemented + return make_scaled(scalar, self) + + def __rmul__(self, scalar: Any) -> LinOp: + """Return the lazy left scalar multiple ``scalar * self``.""" + from ._algebra import is_scalar_like, make_scaled + + if not is_scalar_like(scalar): + return NotImplemented + return make_scaled(scalar, self) + + def __matmul__(self, other: Any) -> LinOp: + """Return the lazy composition ``self @ other`` of two compatible operators.""" + from ._algebra import make_composed + + if not isinstance(other, LinOp): + return NotImplemented + return make_composed(self, other) + + def adjoint(self) -> LinOp: + """Return the Hermitian-adjoint view of this linear operator.""" + return self.H + + def to_dense(self) -> Any: + """ + Materialize this operator as a dense backend array. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + The default implementation applies the operator to each standard basis + vector of the domain, stacks the flattened outputs as matrix columns, + and reshapes the result back to tensor-operator form. Subclasses that + already store the matrix should override this method for efficiency. + """ + domain_size = prod(self.domain.shape) + eye = self.ops.eye(domain_size, dtype=self.dtype) + columns = [] + for i in range(domain_size): + basis_vector = eye[:, i] + x = self.domain.unflatten(basis_vector) + y = self.apply(x) + columns.append(self.codomain.flatten(y)) + matrix = self.ops.stack(tuple(columns), axis=1) + return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape)) def assert_domain(self, x: Any) -> None: + """Raise if ``x`` is not in the domain.""" self.dom.check_member(x) def assert_codomain(self, y: Any) -> None: + """Raise if ``y`` is not in the codomain.""" self.cod.check_member(y) - def __eq__(self, x: Any) -> bool: - raise NotImplementedError() + def __eq__(self, other: Any) -> bool: + """Return structural equality when implemented by a subclass.""" + return NotImplemented + @abstractmethod def tree_flatten(self): - raise NotImplementedError() + """Flatten this operator for backend pytree registration.""" + ... @classmethod + @abstractmethod def tree_unflatten(cls, aux, children): - raise NotImplementedError() + """Rebuild this operator from backend pytree data.""" + ... diff --git a/spacecore/linop/_dense.py b/spacecore/linop/_dense.py index a751799..95040ab 100644 --- a/spacecore/linop/_dense.py +++ b/spacecore/linop/_dense.py @@ -1,24 +1,54 @@ from __future__ import annotations +from functools import cached_property from math import prod from typing import Any from ._base import LinOp, Domain, Codomain +from .._checks import checked_method from ..space import VectorSpace from ..types import DenseArray from ..backend import jax_pytree_class, Context -from .._contextual.manager import ctx_manager +from .._contextual import resolve_context_priority @jax_pytree_class class DenseLinOp(LinOp[VectorSpace, VectorSpace]): - """ - Dense linear operator defined by an array A with shape: + r""" + Represent a dense tensor-backed linear operator. + + ``DenseLinOp(A, dom, cod)`` represents a linear map + :math:`A \colon X \to Y` where the stored dense array has shape + ``cod.shape + dom.shape``. Forward application contracts over the domain + axes; adjoint application uses the conjugate transpose of the flattened + matrix representation. + + Parameters + ---------- + A : DenseArray + Dense backend array with shape ``cod.shape + dom.shape``. + dom : Space + Domain space. + cod : Space or None, optional + Codomain space. If omitted, it is inferred from the leading axes of + ``A``. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from the spaces. - A.shape == cod.shape + dom.shape + Attributes + ---------- + A : DenseArray + Stored dense operator tensor. - apply: y = A ⋅ x (contract over dom axes) - rapply: x = A^* ⋅ y (contract over cod axes) + Examples + -------- + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> A = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 3.0]]), X, X, ctx) + >>> A.apply(ctx.asarray([1.0, 2.0])) + array([2., 6.]) """ def __init__(self, @@ -27,7 +57,7 @@ def __init__(self, cod: Codomain | None = None, ctx: Context | str | None = None ) -> None: - ctx = ctx_manager.resolve_context_priority(ctx, dom, cod) + ctx = resolve_context_priority(ctx, dom, cod) ctx.assert_dense(A) # Check if A is ndarray of ctx if cod is None: @@ -40,55 +70,188 @@ def __init__(self, if tuple(A.shape) != expected: raise TypeError(f"Expected A.shape == cod.shape + dom.shape == {expected}, got {A.shape}") - self.A = A # No dtype conversion + self._A = A # No dtype conversion self._cod_size = prod(self.cod.shape) self._dom_size = prod(self.dom.shape) self._matrix_shape = (self._cod_size, self._dom_size) self._A2 = self.A.reshape(self._matrix_shape) dtype = self.ops.get_dtype(self.A) - is_complex = getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex") + is_complex = self.ops.is_complex_dtype(dtype) + self._A2T = self._A2.T self._A2H = self._A2.T.conj() if is_complex else self._A2.T self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,) self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,) self._dom_vector_fast_path = type(self.dom) is VectorSpace self._cod_vector_fast_path = type(self.cod) is VectorSpace - if not self._enable_checks: - self.apply = self._apply_unchecked - self.rapply = self._rapply_unchecked - def apply(self, x: DenseArray) -> DenseArray: + @cached_property + def A(self) -> DenseArray: """ - Forward action: y = A ⋅ x with y in cod.shape. + Stored dense tensor representation of this operator. + + The returned array has shape ``self.codomain.shape + self.domain.shape`` + and is the same object supplied at construction. """ - if self._enable_checks: - self.dom._check_member(x) + return self._A + + @checked_method(in_space="dom", out_space="cod") + def apply(self, x: DenseArray) -> DenseArray: + """Apply the dense operator to ``x``.""" return self._apply_unchecked(x) def _apply_unchecked(self, x: DenseArray) -> DenseArray: + """Apply the flattened dense matrix without membership checks.""" x1 = x if self._dom_is_flat else x.reshape((self._dom_size,)) y1 = self._A2 @ x1 if self._cod_vector_fast_path: return y1 if self._cod_is_flat else y1.reshape(self.cod.shape) return self.cod.unflatten(y1) + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: DenseArray) -> DenseArray: - """ - Adjoint action: x = A^* ⋅ y with x in dom.shape. + r"""Apply the adjoint dense operator to ``y``. - For complex A, uses conjugate-transpose of the 2D reshaped matrix. + For complex ``A``, use the conjugate transpose of the flattened matrix. """ - if self._enable_checks: - self.cod._check_member(y) return self._rapply_unchecked(y) def _rapply_unchecked(self, y: DenseArray) -> DenseArray: + """Apply the flattened adjoint matrix without membership checks.""" y1 = y if self._cod_is_flat else y.reshape((self._cod_size,)) x1 = self._A2H @ y1 if self._dom_vector_fast_path: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + @staticmethod + def _batch_shape_from_input(value: DenseArray, base_ndim: int) -> tuple[int, ...]: + """Infer leading batch dimensions from an input array.""" + shape = tuple(value.shape) + return shape if base_ndim == 0 else shape[:-base_ndim] + + @staticmethod + def _is_leading_batch(batch_space: Any) -> bool: + """Return whether a batch space uses leading batch axes.""" + if batch_space is None: + return True + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + return batch_axes == tuple(range(len(batch_shape))) + + @staticmethod + def _batch_shape_from_space(batch_space: Any) -> tuple[int, ...]: + """Return the explicit batch shape from a batch-space object.""" + return tuple(getattr(batch_space, "batch_shape")) + + def _vapply_unchecked_leading( + self, + xs: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + """Apply the dense operator over leading batch axes.""" + xs2 = xs.reshape((-1, self._dom_size)) + ys2 = xs2 @ self._A2T + if self._cod_vector_fast_path: + if self._cod_is_flat and tuple(ys2.shape[:-1]) == batch_shape: + return ys2 + return ys2.reshape(batch_shape + tuple(self.cod.shape)) + ys_flat = ys2.reshape(batch_shape + (self._cod_size,)) + return self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + + def _rvapply_unchecked_leading( + self, + ys: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + """Apply the dense adjoint over leading batch axes.""" + ys2 = ys.reshape((-1, self._cod_size)) + xs2 = ys2 @ self._A2H.T + if self._dom_vector_fast_path: + if self._dom_is_flat and tuple(xs2.shape[:-1]) == batch_shape: + return xs2 + return xs2.reshape(batch_shape + tuple(self.dom.shape)) + xs_flat = xs2.reshape(batch_shape + (self._dom_size,)) + return self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + + def _vapply_unchecked(self, xs: DenseArray, batch_space=None) -> DenseArray: + """Apply a batch using the fast leading-axis path when possible.""" + if not self._is_leading_batch(batch_space): + return self._fallback_vapply(xs, batch_space) + batch_shape = ( + self._batch_shape_from_input(xs, len(self.domain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._vapply_unchecked_leading(xs, batch_shape) + + def _rvapply_unchecked(self, ys: DenseArray, batch_space=None) -> DenseArray: + """Apply adjoints over a batch using the fast leading-axis path when possible.""" + if not self._is_leading_batch(batch_space): + return self._fallback_rvapply(ys, batch_space) + batch_shape = ( + self._batch_shape_from_input(ys, len(self.codomain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._rvapply_unchecked_leading(ys, batch_shape) + + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: + """Apply this operator independently over a batch of domain elements.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): + return self._fallback_vapply(xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + batch_shape = tuple(in_space.batch_shape) + ys = self._vapply_unchecked_leading(xs, batch_shape) + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: + """Apply the adjoint independently over a batch of codomain elements.""" + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): + return self._fallback_rvapply(ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + batch_shape = tuple(in_space.batch_shape) + xs = self._rvapply_unchecked_leading(ys, batch_shape) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + + def to_dense(self) -> DenseArray: + """ + Return the stored dense tensor representation of this operator. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + return self.A + + def is_hermitian(self) -> bool | None: + """ + Return whether this dense operator is structurally Hermitian. + + Returns + ------- + bool or None + ``True`` or ``False`` for plain :class:`VectorSpace` domains, where + Hermiticity is checked against the Euclidean flattened matrix. + ``None`` for custom geometries whose inner product may differ from + the Euclidean coordinate product. + """ + if self.dom != self.cod: + return False + if type(self.dom) is not VectorSpace: + return None + try: + return bool(self.ops.allclose(self._A2, self._A2H)) + except Exception: + return None + def __eq__(self, x: Any) -> bool: + """Return whether another dense operator has the same spaces and values.""" if type(x) is type(self): return (self.dom == x.dom and self.cod == x.cod @@ -97,17 +260,20 @@ def __eq__(self, x: Any) -> bool: return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" aux = (self.dom, self.cod, self.ctx) children = (self.A,) return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" dom, cod, ctx = aux A = children[0] return cls(A, dom, cod, ctx) def _convert(self, new_ctx: Context) -> DenseLinOp: + """Convert spaces and stored dense tensor to ``new_ctx``.""" new_dom = self.dom.convert(new_ctx) new_cod = self.cod.convert(new_ctx) new_A = new_ctx.asarray(self.A) diff --git a/spacecore/linop/_diagonal.py b/spacecore/linop/_diagonal.py new file mode 100644 index 0000000..b5c45f2 --- /dev/null +++ b/spacecore/linop/_diagonal.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +from functools import cached_property +from math import prod +from typing import Any + +from ._base import LinOp +from .._checks import checked_method +from ..backend import Context, jax_pytree_class +from ..space import VectorSpace +from ..types import DenseArray +from .._contextual import resolve_context_priority + + +@jax_pytree_class +class DiagonalLinOp(LinOp[VectorSpace, VectorSpace]): + r""" + Represent a coordinatewise diagonal linear operator. + + ``DiagonalLinOp(diagonal, space)`` maps ``x`` to ``diagonal * x`` on a + :class:`VectorSpace`. The adjoint uses the complex conjugate of the + diagonal, so complex-valued diagonals follow the SpaceCore adjoint + convention. + + Parameters + ---------- + diagonal : DenseArray + Dense backend array with shape ``space.shape``. + space : VectorSpace or None, optional + Domain and codomain space. If omitted, a vector space is inferred from + ``diagonal.shape``. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``space``. + + Attributes + ---------- + diagonal : DenseArray + Stored diagonal values. + + Examples + -------- + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> D = sc.DiagonalLinOp(ctx.asarray([2.0, 3.0]), X, ctx) + >>> D.apply(ctx.asarray([4.0, 5.0])) + array([ 8., 15.]) + """ + + def __init__( + self, + diagonal: DenseArray, + space: VectorSpace | None = None, + ctx: Context | str | None = None, + ) -> None: + ctx = resolve_context_priority(ctx, space) + ctx.assert_dense(diagonal) + if space is None: + space = VectorSpace(tuple(diagonal.shape), ctx) + super().__init__(space, space, ctx) + expected = tuple(self.domain.shape) + if tuple(diagonal.shape) != expected: + raise TypeError(f"Expected diagonal.shape == space.shape == {expected}, got {diagonal.shape}") + self.diagonal = diagonal + dtype = self.ops.get_dtype(diagonal) + self._diag_adjoint = ( + self.ops.conj(diagonal) if self.ops.is_complex_dtype(dtype) else diagonal + ) + + @cached_property + def A(self) -> DenseArray: + """Dense tensor representation of this diagonal operator.""" + return self.to_dense() + + @checked_method(in_space="domain", out_space="codomain") + def apply(self, x: DenseArray) -> DenseArray: + """Apply the diagonal operator to ``x``.""" + return self.diagonal * x + + @checked_method(in_space="codomain", out_space="domain") + def rapply(self, y: DenseArray) -> DenseArray: + """Apply the adjoint diagonal operator to ``y``.""" + return self._diag_adjoint * y + + def _reshape_diagonal_for_batch(self, diagonal: DenseArray, batch_space: Any) -> DenseArray: + """Broadcast diagonal values over a batch space.""" + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + total_ndim = len(self.domain.shape) + len(batch_shape) + base_axes = [axis for axis in range(total_ndim) if axis not in batch_axes] + shape = [1] * total_ndim + for axis, dim in zip(base_axes, self.domain.shape, strict=True): + shape[axis] = dim + return self.ops.reshape(diagonal, tuple(shape)) + + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: + """Apply this diagonal operator over a batch of domain elements.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + diagonal = self._reshape_diagonal_for_batch(self.diagonal, in_space) + ys = diagonal * xs + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: + """Apply the adjoint over a batch of codomain elements.""" + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + diagonal = self._reshape_diagonal_for_batch(self._diag_adjoint, in_space) + xs = diagonal * ys + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + + def to_dense(self) -> DenseArray: + """Return a dense tensor representation of this diagonal operator.""" + flat = self.diagonal.reshape((prod(self.domain.shape),)) + matrix = self.ops.diag(flat) + return self.ops.reshape(matrix, tuple(self.codomain.shape) + tuple(self.domain.shape)) + + def is_hermitian(self) -> bool | None: + """ + Return whether this diagonal operator is structurally Hermitian. + + Returns + ------- + bool + ``True`` when the diagonal equals its complex conjugate. + """ + try: + return bool(self.ops.allclose(self.diagonal, self._diag_adjoint)) + except Exception: + return None + + def __eq__(self, other: Any) -> bool: + """Return whether another diagonal operator has the same space and values.""" + if type(other) is type(self): + return self.domain == other.domain and self.ops.allclose(self.diagonal, other.diagonal) + return False + + def tree_flatten(self): + """Flatten this operator for pytree registration.""" + children = (self.diagonal,) + aux = (self.domain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" + domain, ctx = aux + return cls(children[0], domain, ctx) + + def _convert(self, new_ctx: Context) -> DiagonalLinOp: + """Convert the stored diagonal and space to ``new_ctx``.""" + return DiagonalLinOp( + new_ctx.asarray(self.diagonal), + VectorSpace(tuple(self.domain.shape), new_ctx), + new_ctx, + ) diff --git a/spacecore/linop/_sparse.py b/spacecore/linop/_sparse.py index 20780ab..dfffe44 100644 --- a/spacecore/linop/_sparse.py +++ b/spacecore/linop/_sparse.py @@ -1,25 +1,52 @@ from __future__ import annotations +from functools import cached_property from math import prod from typing import Any from ._base import LinOp, Domain, Codomain +from .._checks import checked_method from ..space import VectorSpace from ..types import DenseArray, SparseArray from ..backend import jax_pytree_class, Context -from .._contextual.manager import ctx_manager +from .._contextual import resolve_context_priority @jax_pytree_class class SparseLinOp(LinOp): - """ - Sparse linear operator implementing the tensor map A : dom -> cod where - conceptually A has shape cod.shape + dom.shape, but stored as a 2D sparse matrix: + r""" + Represent a sparse matrix-backed linear operator. + + ``SparseLinOp(A, dom, cod)`` represents a tensor map whose conceptual shape + is ``cod.shape + dom.shape`` while storage uses a two-dimensional sparse + matrix with shape ``(prod(cod.shape), prod(dom.shape))``. - A2.shape == (prod(cod.shape), prod(dom.shape)) + Parameters + ---------- + A : SparseArray + Sparse backend matrix with shape ``(prod(cod.shape), prod(dom.shape))``. + dom : Space + Domain space. + cod : Space + Codomain space. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from the spaces. - apply: y = A ⋅ x (contract over dom axes) - rapply: x = A^* ⋅ y (contract over cod axes) + Attributes + ---------- + A : SparseArray + Stored sparse matrix representation. + + Examples + -------- + >>> import numpy as np + >>> import scipy.sparse as sps + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> A = sc.SparseLinOp(ctx.assparse(sps.eye(2)), X, X, ctx) + >>> A.apply(ctx.asarray([1.0, 2.0])) + array([1., 2.]) """ def __init__(self, @@ -28,7 +55,7 @@ def __init__(self, cod: Codomain, ctx: Context | str | None = None ) -> None: - ctx = ctx_manager.resolve_context_priority(ctx, dom, cod) + ctx = resolve_context_priority(ctx, dom, cod) ctx.assert_sparse(A) # Check if A is sparse array of ctx super(SparseLinOp, self).__init__(dom, cod, ctx) @@ -37,49 +64,57 @@ def __init__(self, if tuple(A.shape) != expected: raise TypeError(f"Expected A.shape == (prod(cod.shape), prod(dom.shape)) == {expected}, got {A.shape}") - self.A = A # No dtype conversion + self._A = A # No dtype conversion self._cod_size = expected[0] self._dom_size = expected[1] dtype = self.ops.get_dtype(self.A) - self._A_is_complex = getattr(dtype, "kind", None) == "c" or str(dtype).startswith("torch.complex") + self._A_is_complex = self.ops.is_complex_dtype(dtype) self._AT = self.A.T self._AH = self._AT.conj() if self._A_is_complex else self._AT self._dom_is_flat = tuple(self.dom.shape) == (self._dom_size,) self._cod_is_flat = tuple(self.cod.shape) == (self._cod_size,) self._dom_vector_fast_path = type(self.dom) is VectorSpace self._cod_vector_fast_path = type(self.cod) is VectorSpace - if not self._enable_checks: - self.apply = self._apply_unchecked - self.rapply = self._rapply_unchecked + @cached_property + def A(self) -> SparseArray: + """ + Stored sparse matrix representation of this operator. + + The returned sparse matrix has shape + ``(prod(self.codomain.shape), prod(self.domain.shape))`` and is the + same object supplied at construction. + """ + return self._A + + @checked_method(in_space="dom", out_space="cod") def apply(self, x: DenseArray) -> DenseArray: """ Forward action: y = A ⋅ x with y in cod.shape. x must have shape dom.shape (dense). """ - if self._enable_checks: - self.dom._check_member(x) return self._apply_unchecked(x) def _apply_unchecked(self, x: DenseArray) -> DenseArray: + """Apply the stored sparse matrix without membership checks.""" x1 = x if self._dom_is_flat else x.reshape((self._dom_size,)) y1 = self.A @ x1 # (m,) if self._cod_vector_fast_path: return y1 if self._cod_is_flat else y1.reshape(self.cod.shape) return self.cod.unflatten(y1) + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: DenseArray) -> DenseArray: """ Adjoint action: x = A^* ⋅ y with x in dom.shape. y must have shape cod.shape (dense). """ - if self._enable_checks: - self.cod._check_member(y) return self._rapply_unchecked(y) def _rapply_unchecked(self, y: DenseArray) -> DenseArray: + """Apply the stored sparse adjoint without membership checks.""" y1 = y if self._cod_is_flat else y.reshape((self._cod_size,)) x1 = self._AH @ y1 @@ -87,6 +122,132 @@ def _rapply_unchecked(self, y: DenseArray) -> DenseArray: return x1 if self._dom_is_flat else x1.reshape(self.dom.shape) return self.dom.unflatten(x1) + @staticmethod + def _batch_shape_from_input(value: DenseArray, base_ndim: int) -> tuple[int, ...]: + shape = tuple(value.shape) + return shape if base_ndim == 0 else shape[:-base_ndim] + + @staticmethod + def _is_leading_batch(batch_space: Any) -> bool: + if batch_space is None: + return True + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + return batch_axes == tuple(range(len(batch_shape))) + + @staticmethod + def _batch_shape_from_space(batch_space: Any) -> tuple[int, ...]: + return tuple(getattr(batch_space, "batch_shape")) + + def _vapply_unchecked_leading( + self, + xs: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + xs2 = xs.reshape((-1, self._dom_size)) + ys2 = (self.A @ xs2.T).T + if self._cod_vector_fast_path: + if self._cod_is_flat and tuple(ys2.shape[:-1]) == batch_shape: + return ys2 + return ys2.reshape(batch_shape + tuple(self.cod.shape)) + ys_flat = ys2.reshape(batch_shape + (self._cod_size,)) + return self.cod.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(ys_flat) + + def _rvapply_unchecked_leading( + self, + ys: DenseArray, + batch_shape: tuple[int, ...], + ) -> DenseArray: + ys2 = ys.reshape((-1, self._cod_size)) + xs2 = (self._AH @ ys2.T).T + if self._dom_vector_fast_path: + if self._dom_is_flat and tuple(xs2.shape[:-1]) == batch_shape: + return xs2 + return xs2.reshape(batch_shape + tuple(self.dom.shape)) + xs_flat = xs2.reshape(batch_shape + (self._dom_size,)) + return self.dom.batch(batch_shape, tuple(range(len(batch_shape)))).unflatten(xs_flat) + + def _vapply_unchecked(self, xs: DenseArray, batch_space=None) -> DenseArray: + if not self._is_leading_batch(batch_space): + return self._fallback_vapply(xs, batch_space) + batch_shape = ( + self._batch_shape_from_input(xs, len(self.domain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._vapply_unchecked_leading(xs, batch_shape) + + def _rvapply_unchecked(self, ys: DenseArray, batch_space=None) -> DenseArray: + if not self._is_leading_batch(batch_space): + return self._fallback_rvapply(ys, batch_space) + batch_shape = ( + self._batch_shape_from_input(ys, len(self.codomain.shape)) + if batch_space is None + else self._batch_shape_from_space(batch_space) + ) + return self._rvapply_unchecked_leading(ys, batch_shape) + + def vapply(self, xs: DenseArray, batch_space=None) -> DenseArray: + in_space = self._input_batch_space(self.domain, xs, batch_space) + if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): + return self._fallback_vapply(xs, batch_space) + if self._enable_checks: + in_space._check_member(xs) + batch_shape = tuple(in_space.batch_shape) + ys = self._vapply_unchecked_leading(xs, batch_shape) + if self._enable_checks: + self._output_batch_space(self.codomain, in_space)._check_member(ys) + return ys + + def rvapply(self, ys: DenseArray, batch_space=None) -> DenseArray: + in_space = self._input_batch_space(self.codomain, ys, batch_space) + if tuple(getattr(in_space, "batch_axes", ())) != tuple(range(len(in_space.batch_shape))): + return self._fallback_rvapply(ys, batch_space) + if self._enable_checks: + in_space._check_member(ys) + batch_shape = tuple(in_space.batch_shape) + xs = self._rvapply_unchecked_leading(ys, batch_shape) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(xs) + return xs + + def to_dense(self) -> DenseArray: + """ + Materialize the stored sparse matrix as a dense operator tensor. + + The returned array has shape ``self.codomain.shape + self.domain.shape``. + """ + if hasattr(self.A, "toarray"): + dense = self.A.toarray() + elif hasattr(self.A, "todense"): + dense = self.A.todense() + elif hasattr(self.A, "to_dense"): + dense = self.A.to_dense() + else: + dense = super().to_dense().reshape((self._cod_size, self._dom_size)) + return self.ops.reshape(dense, tuple(self.codomain.shape) + tuple(self.domain.shape)) + + def is_hermitian(self) -> bool | None: + """ + Return whether this sparse operator is structurally Hermitian. + + Returns + ------- + bool or None + ``True`` or ``False`` for plain :class:`VectorSpace` domains, where + Hermiticity is checked against the Euclidean sparse matrix. + ``None`` for custom geometries whose inner product may differ from + the Euclidean coordinate product. + """ + if self.dom != self.cod: + return False + if type(self.dom) is not VectorSpace: + return None + try: + return bool(self.ops.allclose_sparse(self.A, self._AH)) + except Exception: + return None + def __eq__(self, x: Any) -> bool: if type(x) is type(self): return (self.dom == x.dom diff --git a/spacecore/linop/product/__init__.py b/spacecore/linop/product/__init__.py index 4c534ab..1d81f74 100644 --- a/spacecore/linop/product/__init__.py +++ b/spacecore/linop/product/__init__.py @@ -1,3 +1,5 @@ +"""Linear operators that map to or from product spaces.""" + from ._base import ProductLinOp from ._block import BlockDiagonalLinOp from ._from_single import StackedLinOp @@ -8,4 +10,4 @@ "BlockDiagonalLinOp", "StackedLinOp", "SumToSingleLinOp", -] \ No newline at end of file +] diff --git a/spacecore/linop/product/_base.py b/spacecore/linop/product/_base.py index 7e4d4d9..b7a2567 100644 --- a/spacecore/linop/product/_base.py +++ b/spacecore/linop/product/_base.py @@ -10,7 +10,19 @@ @jax_pytree_class class ProductLinOp(LinOp[Domain, Codomain]): """ - Base class for linear operators assembled from component operators. + Define a base class for operators assembled from component operators. + + Parameters + ---------- + dom : Space + Domain space of the assembled operator. + cod : Space + Codomain space of the assembled operator. + parts : sequence of LinOp + Nonempty sequence of component operators. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from ``dom`` and + ``cod``. """ parts: Tuple[LinOp, ...] @@ -31,25 +43,20 @@ def __init__(self, self._apply_parts = tuple(getattr(op, "_apply_unchecked", op.apply) for op in self.parts) self._rapply_parts = tuple(getattr(op, "_rapply_unchecked", op.rapply) for op in self.parts) self._check_layout() - unchecked_apply = getattr(self, "_apply_unchecked", None) - unchecked_rapply = getattr(self, "_rapply_unchecked", None) - if not self._enable_checks and unchecked_apply is not None and unchecked_rapply is not None: - self.apply = unchecked_apply - self.rapply = unchecked_rapply @abstractmethod def _check_layout(self) -> None: - """ - Check incidence compatibility between self.parts and self.dom/self.cod. - """ + """Check incidence compatibility between parts and endpoint spaces.""" raise NotImplementedError @classmethod @abstractmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> ProductLinOp: + """Build a product operator from component operators.""" ... def __eq__(self, x: Any) -> bool: + """Return whether another product operator has the same layout.""" if type(x) is type(self): return (self.dom == x.dom and self.cod == x.cod @@ -59,11 +66,13 @@ def __eq__(self, x: Any) -> bool: return False def tree_flatten(self): + """Flatten this operator for pytree registration.""" children = self.parts aux = (self.dom, self.cod, self.ctx) return children, aux @classmethod def tree_unflatten(cls, aux, children): + """Rebuild this operator from pytree data.""" dom, cod, ctx = aux return cls(dom, cod, tuple(children), ctx) diff --git a/spacecore/linop/product/_block.py b/spacecore/linop/product/_block.py index a2f7fe8..9303225 100644 --- a/spacecore/linop/product/_block.py +++ b/spacecore/linop/product/_block.py @@ -4,6 +4,7 @@ from ._base import ProductLinOp from .._base import LinOp +from ..._checks import checked_method from ... import Context from ...space import ProductSpace from ...backend import jax_pytree_class @@ -11,16 +12,26 @@ @jax_pytree_class class BlockDiagonalLinOp(ProductLinOp[ProductSpace, ProductSpace]): - """ + r""" Block-diagonal operator between product spaces. - dom = X1 × ... × Xk - cod = Y1 × ... × Yk - - ops[i] : Xi -> Yi + If ``dom = X1 x ... x Xk`` and ``cod = Y1 x ... x Yk``, component + ``parts[i]`` maps ``Xi`` to ``Yi``. + + Parameters + ---------- + dom : ProductSpace + Product domain. + cod : ProductSpace + Product codomain. + parts : sequence of LinOp + Component operators with matching product incidence. + ctx : Context, str, or None, optional + Backend context specification. """ def _check_layout(self) -> None: + """Check that each component maps the matching product component.""" if not isinstance(self.dom, ProductSpace) or not isinstance(self.cod, ProductSpace): raise TypeError("BlockDiagonalLinOp expects dom and cod to be ProductSpace.") @@ -33,28 +44,55 @@ def _check_layout(self) -> None: else: raise TypeError(f"Component op {i} has incompatible dom/cod spaces.") + @checked_method(in_space="dom", out_space="cod") def apply(self, x: Any) -> Any: - if self._enable_checks: - self.dom._check_member(x) + """Apply each block to the matching product component.""" return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: + """Apply each block without membership checks.""" if self._num_parts == 2: return self._apply_parts[0](x[0]), self._apply_parts[1](x[1]) return tuple(apply(xi) for apply, xi in zip(self._apply_parts, x)) + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: Any) -> Any: - if self._enable_checks: - self.cod._check_member(y) + """Apply each adjoint block to the matching product component.""" return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: + """Apply each adjoint block without membership checks.""" if self._num_parts == 2: return self._rapply_parts[0](y[0]), self._rapply_parts[1](y[1]) return tuple(rapply(yi) for rapply, yi in zip(self._rapply_parts, y)) + def vapply(self, x: Any, batch_space=None) -> Any: + """Apply this block-diagonal operator over a product batch.""" + in_space = self._input_batch_space(self.domain, x, batch_space) + if self._enable_checks: + in_space._check_member(x) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + return tuple( + op.vapply(xi, op.domain.batch(batch_shape, batch_axes)) + for op, xi in zip(self.parts, x) + ) + + def rvapply(self, y: Any, batch_space=None) -> Any: + """Apply the adjoint over a product batch.""" + in_space = self._input_batch_space(self.codomain, y, batch_space) + if self._enable_checks: + in_space._check_member(y) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + return tuple( + op.rvapply(yi, op.codomain.batch(batch_shape, batch_axes)) + for op, yi in zip(self.parts, y) + ) + @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> BlockDiagonalLinOp: + """Build a block-diagonal operator from component operators.""" if not parts: raise ValueError("Parts must be non-empty.") @@ -63,6 +101,7 @@ def from_operators(cls, parts: Tuple[LinOp, ...]) -> BlockDiagonalLinOp: return cls(dom, cod, parts) def _convert(self, new_ctx: Context) -> BlockDiagonalLinOp: + """Convert spaces and component operators to ``new_ctx``.""" new_dom = self.dom.convert(new_ctx) new_cod = self.cod.convert(new_ctx) new_parts = [op.convert(new_ctx) for op in self.parts] diff --git a/spacecore/linop/product/_from_single.py b/spacecore/linop/product/_from_single.py index abef674..408098f 100644 --- a/spacecore/linop/product/_from_single.py +++ b/spacecore/linop/product/_from_single.py @@ -4,24 +4,34 @@ from ._base import ProductLinOp from .._base import LinOp, Domain +from ..._checks import checked_method from ...space import ProductSpace, VectorSpace from ...backend import jax_pytree_class, Context @jax_pytree_class class StackedLinOp(ProductLinOp[Domain, ProductSpace]): - """ + r""" Stack of operators from a single domain into a product codomain. - dom = X - cod = Y1 × ... × Yk - - ``ops[i] : X -> Yi`` - ``apply(x) = (ops[i](x))_i`` - ``rapply(y) = sum_i ops[i]^*(y_i)`` + If ``dom = X`` and ``cod = Y1 x ... x Yk``, component ``parts[i]`` maps + ``X`` to ``Yi``. Forward application returns a tuple of component outputs; + adjoint application sums component adjoints in ``X``. + + Parameters + ---------- + dom : Space + Shared component domain. + cod : ProductSpace + Product codomain. + parts : sequence of LinOp + Operators from ``dom`` to each component of ``cod``. + ctx : Context, str, or None, optional + Backend context specification. """ def _check_layout(self) -> None: + """Check that every component maps the shared domain to one codomain part.""" if not isinstance(self.cod, ProductSpace): raise TypeError("StackedLinOp expects cod to be ProductSpace.") @@ -34,22 +44,24 @@ def _check_layout(self) -> None: else: raise TypeError(f"Component op {i} must map dom -> cod.spaces[{i}].") + @checked_method(in_space="dom", out_space="cod") def apply(self, x: Any) -> Any: - if self._enable_checks: - self.dom._check_member(x) + """Apply each component operator to the same input.""" return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: + """Apply component operators without membership checks.""" if self._num_parts == 2: return self._apply_parts[0](x), self._apply_parts[1](x) return tuple(apply(x) for apply in self._apply_parts) + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: Any) -> Any: - if self._enable_checks: - self.cod._check_member(y) + """Apply component adjoints and sum in the shared domain.""" return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: + """Apply component adjoints without membership checks.""" if self._num_parts == 2: x0 = self._rapply_parts[0](y[0]) x1 = self._rapply_parts[1](y[1]) @@ -61,8 +73,35 @@ def _rapply_unchecked(self, y: Any) -> Any: acc = xi if acc is None else (acc + xi if use_direct_add else self.dom.add(xi, acc)) return acc + def vapply(self, x: Any, batch_space=None) -> Any: + """Apply this stacked operator over a batch.""" + in_space = self._input_batch_space(self.domain, x, batch_space) + if self._enable_checks: + in_space._check_member(x) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + return tuple( + op.vapply(x, op.domain.batch(batch_shape, batch_axes)) + for op in self.parts + ) + + def rvapply(self, y: Any, batch_space=None) -> Any: + """Apply the adjoint stacked operator over a product batch.""" + in_space = self._input_batch_space(self.codomain, y, batch_space) + if self._enable_checks: + in_space._check_member(y) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + out_space = self.domain.batch(batch_shape, batch_axes) + acc = None + for op, yi in zip(self.parts, y): + xi = op.rvapply(yi, op.codomain.batch(batch_shape, batch_axes)) + acc = xi if acc is None else out_space.add(acc, xi) + return acc + @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> StackedLinOp: + """Build a stacked operator from component operators.""" if not parts: raise ValueError("Parts must be non-empty.") @@ -72,6 +111,7 @@ def from_operators(cls, parts: Tuple[LinOp, ...]) -> StackedLinOp: return cls(dom, cod, parts) def _convert(self, new_ctx: Context) -> StackedLinOp: + """Convert spaces and component operators to ``new_ctx``.""" new_dom = self.dom.convert(new_ctx) new_cod = self.cod.convert(new_ctx) new_parts = [op.convert(new_ctx) for op in self.parts] diff --git a/spacecore/linop/product/_to_single.py b/spacecore/linop/product/_to_single.py index 806f94e..61f2824 100644 --- a/spacecore/linop/product/_to_single.py +++ b/spacecore/linop/product/_to_single.py @@ -4,24 +4,34 @@ from ._base import ProductLinOp from .._base import LinOp, Codomain +from ..._checks import checked_method from ...space import ProductSpace, VectorSpace from ...backend import jax_pytree_class, Context @jax_pytree_class class SumToSingleLinOp(ProductLinOp[ProductSpace, Codomain]): - """ + r""" Sum of component operators from a product domain into a single codomain. - dom = X1 × ... × Xk - cod = Y - - ``ops[i] : Xi -> Y`` - ``apply(x) = sum_i ops[i](x_i)`` - ``rapply(y) = (ops[i]^*(y))_i`` + If ``dom = X1 x ... x Xk`` and ``cod = Y``, component ``parts[i]`` maps + ``Xi`` to ``Y``. Forward application sums component outputs in ``Y``; + adjoint application returns the tuple of component adjoints. + + Parameters + ---------- + dom : ProductSpace + Product domain. + cod : Space + Shared codomain. + parts : sequence of LinOp + Operators from each product component to ``cod``. + ctx : Context, str, or None, optional + Backend context specification. """ def _check_layout(self) -> None: + """Check that every component maps one product part to the shared codomain.""" if not isinstance(self.dom, ProductSpace): raise TypeError("SumToSingleLinOp expects dom to be ProductSpace.") @@ -34,12 +44,13 @@ def _check_layout(self) -> None: else: raise TypeError(f"Component op {i} must map dom.spaces[{i}] -> cod.") + @checked_method(in_space="dom", out_space="cod") def apply(self, x: Any) -> Any: - if self._enable_checks: - self.dom._check_member(x) + """Apply component operators and sum in the codomain.""" return self._apply_unchecked(x) def _apply_unchecked(self, x: Any) -> Any: + """Apply component operators without membership checks.""" if self._num_parts == 2: y0 = self._apply_parts[0](x[0]) y1 = self._apply_parts[1](x[1]) @@ -51,18 +62,46 @@ def _apply_unchecked(self, x: Any) -> Any: acc = yi if acc is None else (acc + yi if use_direct_add else self.cod.add(yi, acc)) return acc + @checked_method(in_space="cod", out_space="dom") def rapply(self, y: Any) -> Any: - if self._enable_checks: - self.cod._check_member(y) + """Apply each component adjoint to the shared codomain element.""" return self._rapply_unchecked(y) def _rapply_unchecked(self, y: Any) -> Any: + """Apply component adjoints without membership checks.""" if self._num_parts == 2: return self._rapply_parts[0](y), self._rapply_parts[1](y) return tuple(rapply(y) for rapply in self._rapply_parts) + def vapply(self, x: Any, batch_space=None) -> Any: + """Apply this sum-to-single operator over a product batch.""" + in_space = self._input_batch_space(self.domain, x, batch_space) + if self._enable_checks: + in_space._check_member(x) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + out_space = self.codomain.batch(batch_shape, batch_axes) + acc = None + for op, xi in zip(self.parts, x): + yi = op.vapply(xi, op.domain.batch(batch_shape, batch_axes)) + acc = yi if acc is None else out_space.add(acc, yi) + return acc + + def rvapply(self, y: Any, batch_space=None) -> Any: + """Apply the adjoint over a codomain batch.""" + in_space = self._input_batch_space(self.codomain, y, batch_space) + if self._enable_checks: + in_space._check_member(y) + batch_shape = in_space.batch_shape + batch_axes = in_space.batch_axes + return tuple( + op.rvapply(y, op.codomain.batch(batch_shape, batch_axes)) + for op in self.parts + ) + @classmethod def from_operators(cls, parts: Tuple[LinOp, ...]) -> SumToSingleLinOp: + """Build a sum-to-single operator from component operators.""" if not parts: raise ValueError("Parts must be non-empty.") @@ -72,6 +111,7 @@ def from_operators(cls, parts: Tuple[LinOp, ...]) -> SumToSingleLinOp: return cls(dom, cod, parts) def _convert(self, new_ctx: Context) -> SumToSingleLinOp: + """Convert spaces and component operators to ``new_ctx``.""" new_dom = self.dom.convert(new_ctx) new_cod = self.cod.convert(new_ctx) new_parts = [op.convert(new_ctx) for op in self.parts] diff --git a/spacecore/space/__init__.py b/spacecore/space/__init__.py index fbb7ada..b9fb285 100644 --- a/spacecore/space/__init__.py +++ b/spacecore/space/__init__.py @@ -1,3 +1,5 @@ +"""Vector space abstractions, concrete spaces, and validation checks.""" + from ._checks import ( BackendCheck, DTypeCheck, @@ -10,11 +12,13 @@ SquareMatrixCheck, ) from ._base import Space +from ._batch import BatchSpace from ._herm import HermitianSpace from ._vector import VectorSpace from ._product import ProductSpace __all__ = [ + "BatchSpace", "BackendCheck", "DTypeCheck", "HermitianCheck", diff --git a/spacecore/space/_base.py b/spacecore/space/_base.py index 26eaf6f..348b554 100644 --- a/spacecore/space/_base.py +++ b/spacecore/space/_base.py @@ -11,12 +11,45 @@ class Space(ContextBound): """ - Abstract Space. + Define the geometry and linear structure of a vector space. - A Space owns the *geometry* (inner product, norm) and the basic linear + A space owns the geometry (inner product, norm) and the basic linear structure (add/scale/axpy) for its elements. - Solvers should use only this API. + Membership validation is exposed through ``check_member``, which respects + the space's ``enable_checks`` policy. Internal code paths that have already + checked that policy may call ``_check_member`` to run the concrete checks + exactly once. + + Parameters + ---------- + shape : tuple of int + Canonical coordinate shape for elements of the space. + ctx : Context, str, or None, optional + Backend context specification. Default resolves to the global context. + + Attributes + ---------- + shape : tuple of int + Canonical element shape. + ctx : Context + Resolved backend context inherited from :class:`ContextBound`. + + Notes + ----- + Solvers use only this API. Concrete spaces define storage constraints, + membership checks, and flattening rules. + + Examples + -------- + Instantiate the concrete :class:`VectorSpace` subclass. + + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> X.shape + (2,) """ checks: ClassVar[tuple[SpaceCheck, ...]] = () @@ -41,15 +74,7 @@ def member_checks(self) -> tuple[SpaceCheck, ...]: return tuple(checks) def _check_member(self, x: Any) -> None: - """ - Raise if `x` is not a valid element of this space. - - Typical checks: - - x.space is self (if your elements carry a .space) - - backend family consistency (via ctx) - - representation is supported - - shape/structure constraints (Hermitian, block sizes, etc.) - """ + """Raise if ``x`` is not a valid element of this space.""" for check in self.member_checks(): check(self, x) @@ -75,18 +100,16 @@ def axpy(self, a: Any, x: Any, y: Any) -> Any: @abstractmethod def inner(self, x: Any, y: Any) -> Any: - """ - Inner product ⟨x, y⟩ for elements of this space. - """ + r"""Return :math:`\langle x, y \rangle_X` for elements of this space.""" def norm(self, x: Any) -> Any: - """Induced norm ||x|| = sqrt(real(⟨x,x⟩)). Override if you can do better.""" + r"""Return the induced norm :math:`\sqrt{\operatorname{Re}\langle x, x\rangle_X}`.""" v = self.ctx.ops.real(self.inner(x, x)) return self.ctx.ops.sqrt(v) @abstractmethod def eigh(self, x: Any, k: int = None) -> Any: - """Eigendecomposition of x (if applicable).)""" + """Return an eigendecomposition of ``x`` when the space defines one.""" @abstractmethod def flatten(self, x: Any) -> DenseArray: @@ -101,6 +124,18 @@ def unflatten(self, v: DenseArray) -> Any: """Inverse of flatten; returns an element in the requested representation.""" raise NotImplementedError + def batch( + self, + batch_shape: Tuple[int, ...], + batch_axes: Tuple[int, ...] | None = None, + ) -> Space: + """Return a wrapper representing a batch/product of this space.""" + from ._batch import BatchSpace + + if batch_axes is None: + batch_axes = tuple(range(len(batch_shape))) + return BatchSpace(self, batch_shape, batch_axes) + def _convert(self, new_ctx: Context) -> Space: raise NotImplementedError() diff --git a/spacecore/space/_batch.py b/spacecore/space/_batch.py new file mode 100644 index 0000000..08f8d4c --- /dev/null +++ b/spacecore/space/_batch.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from math import prod +from typing import Any, Callable, Tuple + +from ._base import Space +from ._checks import BackendCheck, DTypeCheck, ShapeCheck +from ._product import ProductSpace +from .._checks import checked_method +from ..backend import Context +from ..types import DenseArray + + +def _batched_shape( + base_shape: tuple[int, ...], + batch_shape: tuple[int, ...], + batch_axes: tuple[int, ...], +) -> tuple[int, ...]: + """Interleave base and batch dimensions according to ``batch_axes``.""" + total_ndim = len(base_shape) + len(batch_shape) + axes = tuple(axis + total_ndim if axis < 0 else axis for axis in batch_axes) + if len(batch_shape) != len(axes): + raise ValueError("batch_shape and batch_axes must have the same length.") + if len(set(axes)) != len(axes): + raise ValueError("batch_axes must be unique.") + if any(axis < 0 or axis >= total_ndim for axis in axes): + raise ValueError( + f"batch_axes must be valid axes for batched ndim {total_ndim}, got {batch_axes}." + ) + + out: list[int | None] = [None] * total_ndim + for axis, dim in zip(axes, batch_shape): + out[axis] = int(dim) + + base_iter = iter(int(dim) for dim in base_shape) + for i, dim in enumerate(out): + if dim is None: + out[i] = next(base_iter) + return tuple(dim for dim in out if dim is not None) + + +class BatchSpace(Space): + """ + Represent a batch of elements from a base space. + + ``BatchSpace(X, batch_shape, batch_axes)`` represents ``X`` repeated over + the given batch dimensions. It deliberately wraps the original space rather + than folding batch dimensions into the base ``Space`` instance. + + Parameters + ---------- + base : Space + Space whose elements are batched. + batch_shape : tuple of int + Sizes of batch dimensions. + batch_axes : tuple of int + Axes occupied by batch dimensions in the batched representation. + ctx : Context, str, or None, optional + Backend context specification. Default is ``base.ctx``. + + Attributes + ---------- + base : Space + Converted base space. + batch_shape : tuple of int + Batch dimension sizes. + batch_axes : tuple of int + Batch axis positions. + """ + + def __init__( + self, + base: Space, + batch_shape: Tuple[int, ...], + batch_axes: Tuple[int, ...], + ctx: Context | str | None = None, + ) -> None: + ctx = base.ctx if ctx is None else ctx + super().__init__( + _batched_shape(tuple(base.shape), tuple(batch_shape), tuple(batch_axes)), + ctx, + ) + self.base = base.convert(self.ctx) + self.batch_shape = tuple(int(dim) for dim in batch_shape) + total_ndim = len(self.base.shape) + len(self.batch_shape) + self.batch_axes = tuple(axis + total_ndim if axis < 0 else axis for axis in batch_axes) + self._batch_size = prod(self.batch_shape) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, BatchSpace): + return ( + self.ctx == other.ctx + and self.base == other.base + and self.batch_shape == other.batch_shape + and self.batch_axes == other.batch_axes + ) + return False + + @property + def _is_product(self) -> bool: + return isinstance(self.base, ProductSpace) + + def _component_spaces(self) -> tuple[BatchSpace, ...]: + """Return batched component spaces for product-space bases.""" + if not isinstance(self.base, ProductSpace): + raise TypeError("BatchSpace component spaces are available only for ProductSpace bases.") + return tuple(sp.batch(self.batch_shape, self.batch_axes) for sp in self.base.spaces) + + def _check_member(self, x: Any) -> None: + """Raise if ``x`` is not a valid batched element.""" + if isinstance(self.base, ProductSpace): + if not isinstance(x, tuple) or len(x) != self.base.arity: + raise TypeError( + f"BatchSpace over ProductSpace expects tuple length {self.base.arity}." + ) + for space, component in zip(self._component_spaces(), x): + space.check_member(component) + return + + BackendCheck()(self, x) + ShapeCheck()(self, x) + DTypeCheck()(self, x) + for check in self.base.member_checks(): + if isinstance(check, (BackendCheck, ShapeCheck, DTypeCheck)): + continue + check(self, x) + + def zeros(self) -> Any: + """Return the batched zero element.""" + if isinstance(self.base, ProductSpace): + return tuple(space.zeros() for space in self._component_spaces()) + return self.ops.zeros(self.shape, dtype=self.dtype) + + @checked_method(in_space="self", arg_positions=(0, 1)) + def add(self, x: Any, y: Any) -> Any: + """Return the batched sum ``x + y``.""" + if isinstance(self.base, ProductSpace): + return tuple(space.add(xi, yi) for space, xi, yi in zip(self._component_spaces(), x, y)) + return x + y + + @checked_method(in_space="self", arg_positions=(1,)) + def scale(self, a: Any, x: Any) -> Any: + """Return the batched scalar product ``a * x``.""" + if isinstance(self.base, ProductSpace): + return tuple(space.scale(a, xi) for space, xi in zip(self._component_spaces(), x)) + return a * x + + @checked_method(in_space="self", arg_positions=(0, 1)) + def inner(self, x: Any, y: Any) -> Any: + r"""Return :math:`\langle x, y\rangle` over the batched space.""" + if isinstance(self.base, ProductSpace): + acc = None + for space, xi, yi in zip(self._component_spaces(), x, y): + v = space.inner(xi, yi) + acc = v if acc is None else acc + v + return acc + return self.ops.vdot(x, y) + + def eigh(self, x: Any, k: int = None) -> Any: + """Raise because batched spaces do not define eigendecomposition.""" + raise TypeError(f"{type(self).__name__}.eigh is not defined for batched spaces.") + + @checked_method(in_space="self") + def flatten(self, x: Any) -> DenseArray: + """Flatten a batched element into dense coordinates.""" + if isinstance(self.base, ProductSpace): + parts = tuple(space.flatten(xi) for space, xi in zip(self._component_spaces(), x)) + return parts[0] if len(parts) == 1 else self.ops.concatenate(parts, axis=0) + return self.ops.reshape(x, (-1,)) + + def unflatten(self, v: DenseArray) -> Any: + """Convert dense batched coordinates into a batched element.""" + vv = self.ctx.assert_dense(v) if self._enable_checks else v + if isinstance(self.base, ProductSpace): + if ( + tuple(getattr(vv, "shape", ())) == tuple(self.shape) + and self.batch_axes == tuple(range(len(self.batch_shape))) + ): + xs = [] + offset = 0 + for component, space in zip(self.base.spaces, self._component_spaces()): + size = prod(component.shape) + flat_component = vv[(..., slice(offset, offset + size))] + xs.append(space.unflatten(flat_component)) + offset += size + return tuple(xs) + xs = [] + offset = 0 + for space in self._component_spaces(): + size = prod(space.shape) + xs.append(space.unflatten(vv[offset : offset + size])) + offset += size + return tuple(xs) + return self.ops.reshape(vv, self.shape) + + @checked_method(in_space="self", out_space="self") + def apply(self, x: Any, f: Callable) -> Any: + """Apply a function over batched elements using base-space semantics.""" + if isinstance(self.base, ProductSpace): + return tuple(space.apply(xi, f) for space, xi in zip(self._component_spaces(), x)) + try: + y = f(x) + except Exception: + y = self.ops.vmap(lambda xi: self.base.apply(xi, f))(x) + return y + + def _convert(self, new_ctx: Context) -> BatchSpace: + """Convert the base space to ``new_ctx``.""" + return BatchSpace(self.base.convert(new_ctx), self.batch_shape, self.batch_axes, new_ctx) diff --git a/spacecore/space/_checks.py b/spacecore/space/_checks.py index 622ce32..4ec8035 100644 --- a/spacecore/space/_checks.py +++ b/spacecore/space/_checks.py @@ -10,6 +10,7 @@ class SpaceValidationError(ValueError, TypeError): def _shape_of(space: Any, x: Any) -> tuple[int, ...] | None: + """Return the backend-visible shape of ``x`` when available.""" try: return tuple(space.ops.shape(x)) except Exception: @@ -18,6 +19,7 @@ def _shape_of(space: Any, x: Any) -> tuple[int, ...] | None: def _dtype_of(space: Any, x: Any) -> Any: + """Return the backend-visible dtype of ``x`` when available.""" try: return space.ops.get_dtype(x) except Exception: @@ -26,23 +28,44 @@ def _dtype_of(space: Any, x: Any) -> Any: @dataclass(frozen=True) class SpaceCheck(ABC): + """ + Define a membership check for :class:`Space` objects. + + Parameters + ---------- + name : str + Human-readable check name used in diagnostics. + """ + name: str def __call__(self, space: Any, x: Any) -> None: + """Raise :class:`SpaceValidationError` when ``x`` is invalid.""" if not self.is_valid(space, x): raise SpaceValidationError(self.error_message(space, x)) @abstractmethod def is_valid(self, space: Any, x: Any) -> bool: + """Return whether ``x`` is valid for ``space``.""" ... @abstractmethod def error_message(self, space: Any, x: Any) -> str: + """Return a diagnostic for an invalid ``x``.""" ... @dataclass(frozen=True) class BackendCheck(SpaceCheck): + """ + Check that a value is a dense array for a space backend. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"backend"``. + """ + name: str = "backend" def is_valid(self, space: Any, x: Any) -> bool: @@ -54,6 +77,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class ShapeCheck(SpaceCheck): + """ + Check that a value has the canonical shape of a space. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"shape"``. + """ + name: str = "shape" def is_valid(self, space: Any, x: Any) -> bool: @@ -65,6 +97,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class DTypeCheck(SpaceCheck): + """ + Check that a value has the dtype required by a space context. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"dtype"``. + """ + name: str = "dtype" def is_valid(self, space: Any, x: Any) -> bool: @@ -76,6 +117,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class SquareMatrixCheck(SpaceCheck): + """ + Check that a value has square trailing matrix axes. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"square_matrix"``. + """ + name: str = "square_matrix" def is_valid(self, space: Any, x: Any) -> bool: @@ -88,6 +138,21 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class HermitianCheck(SpaceCheck): + """ + Check that a value is Hermitian within tolerances. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"hermitian"``. + atol : float, optional + Absolute tolerance for Hermitian comparison. + rtol : float, optional + Relative tolerance for Hermitian comparison. + enforce : bool, optional + Whether to enforce the Hermitian comparison. + """ + name: str = "hermitian" atol: float = 1e-8 rtol: float = 1e-8 @@ -113,6 +178,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class ProductStructureCheck(SpaceCheck): + """ + Check that a product-space value is a tuple of the right length. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"product_structure"``. + """ + name: str = "product_structure" def is_valid(self, space: Any, x: Any) -> bool: @@ -126,6 +200,15 @@ def error_message(self, space: Any, x: Any) -> str: @dataclass(frozen=True) class ProductComponentCheck(SpaceCheck): + """ + Check each component of a product-space value. + + Parameters + ---------- + name : str, optional + Check name. Default is ``"product_components"``. + """ + name: str = "product_components" def is_valid(self, space: Any, x: Any) -> bool: diff --git a/spacecore/space/_herm.py b/spacecore/space/_herm.py index ef6418b..0c9a71d 100644 --- a/spacecore/space/_herm.py +++ b/spacecore/space/_herm.py @@ -4,13 +4,14 @@ from ._checks import HermitianCheck, SquareMatrixCheck from ._vector import VectorSpace +from .._checks import checked_method from ..types import DenseArray from ..backend import Context class HermitianSpace(VectorSpace): - """ - Space of dense n×n Hermitian matrices. + r""" + Represent dense Hermitian matrices with Frobenius geometry. Elements are backend-native dense arrays with shape ``(n, n)``. Membership enforces Hermitian structure up to tolerances. @@ -18,6 +19,24 @@ class HermitianSpace(VectorSpace): The inner product is Frobenius / Hilbert-Schmidt: `` = vdot(vec(X), vec(Y))``, where ``vdot`` conjugates the first argument according to backend rules. + + Parameters + ---------- + n : int + Matrix dimension. + atol : float, optional + Absolute tolerance for Hermitian membership checks. + rtol : float, optional + Relative tolerance for Hermitian membership checks. + enforce_herm : bool, optional + Whether membership checks enforce Hermitian structure. + ctx : Context, str, or None, optional + Backend context specification. + + Attributes + ---------- + n : int + Matrix dimension. """ def __init__(self, @@ -47,9 +66,11 @@ def __eq__(self, other: Any) -> bool: @property def n(self) -> int: + """Matrix dimension of this Hermitian space.""" return self.shape[0] def _local_checks(self): + """Return membership checks local to Hermitian spaces.""" return ( SquareMatrixCheck(), HermitianCheck( @@ -60,6 +81,7 @@ def _local_checks(self): ) def is_hermitian(self, x: DenseArray) -> bool: + """Return whether ``x`` satisfies this space's Hermitian check.""" return HermitianCheck( atol=self.atol, rtol=self.rtol, @@ -67,25 +89,29 @@ def is_hermitian(self, x: DenseArray) -> bool: ).is_valid(self, x) def symmetrize(self, x: DenseArray) -> DenseArray: - """Project onto the Hermitian cone: (X + X^H)/2.""" + r"""Project ``x`` onto the Hermitian subspace as :math:`(X + X^*) / 2`.""" return (x + x.T.conj()) * 0.5 + @checked_method(in_space="self") def eigh(self, x: DenseArray, k: int = None) -> Tuple[DenseArray, DenseArray]: - self.check_member(x) + """Return the eigendecomposition of a Hermitian element.""" return self.ops.eigh(x) def unflatten(self, v: DenseArray) -> DenseArray: + """Reshape dense coordinates and symmetrize the result.""" vv = self.ctx.assert_dense(v) if self._enable_checks else v X = vv.reshape(self.shape) return self.symmetrize(X) + @checked_method(in_space="self") def psd_proj(self, x: DenseArray) -> DenseArray: - self.check_member(x) + """Project a Hermitian element onto the positive semidefinite cone.""" evals, evecs = self.ops.eigh(x) evals = self.ops.maximum(evals, 0.) return self.eig_to_dense(evals, evecs) def eig_to_dense(self, evals: DenseArray, evecs: DenseArray) -> DenseArray: + """Reconstruct a Hermitian matrix from eigenvalues and eigenvectors.""" self.ctx.assert_dense(evals) self.ctx.assert_dense(evecs) X = (evecs * evals) @ evecs.T.conj() @@ -93,8 +119,10 @@ def eig_to_dense(self, evals: DenseArray, evecs: DenseArray) -> DenseArray: return X def _convert(self, new_ctx: Context) -> HermitianSpace: + """Convert this Hermitian space to ``new_ctx``.""" return HermitianSpace(self.n, self.atol, self.rtol, self.enforce_herm, new_ctx) + @checked_method(in_space="self") def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: r""" Apply a scalar function to a Hermitian matrix via spectral calculus. @@ -149,7 +177,6 @@ def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseAr then the eigenvectors are preserved and only the eigenvalues are transformed. """ - self.check_member(x) evals, evecs = self.ops.eigh(x) fevals = self._apply_entrywise(evals, f) diff --git a/spacecore/space/_product.py b/spacecore/space/_product.py index f97966d..991a1b4 100644 --- a/spacecore/space/_product.py +++ b/spacecore/space/_product.py @@ -5,13 +5,15 @@ from ._base import Space from ._checks import ProductComponentCheck, ProductStructureCheck from ._vector import VectorSpace +from .._checks import checked_method from ..types import DenseArray from ..backend import Context -from .._contextual.manager import ctx_manager +from .._contextual import resolve_context_priority def _prod_int(shape: Tuple[int, ...]) -> int: + """Return the integer product of a shape tuple.""" p = 1 for d in shape: p *= int(d) @@ -19,27 +21,41 @@ def _prod_int(shape: Tuple[int, ...]) -> int: class ProductSpace(Space): - """ - Cartesian product space X = X1 × ... × Xk. - - Elements are tuples: - x = (x1, ..., xk) with xi ∈ Xi - - Canonical dense coordinates: - flatten(x) = concat(flatten_i(xi)) - - Notes: - - `shape` for this space is the *1D coordinate length* of the concatenated flattening. - - `eigh` has no canonical meaning here and raises by default. + r""" + Represent a Cartesian product of spaces. + + Elements are tuples ``(x1, ..., xk)`` with ``xi`` in ``spaces[i]``. + Dense coordinates concatenate the flattened coordinates of each component. + + Parameters + ---------- + spaces : tuple of Space + Nonempty tuple of component spaces. + ctx : Context, str, or None, optional + Backend context specification. Default is resolved from components. + + Attributes + ---------- + spaces : tuple of Space + Component spaces converted to ``ctx``. + arity : int + Number of component spaces. + + Notes + ----- + ``shape`` is the one-dimensional coordinate length of the concatenated + flattening. ``eigh`` has no canonical meaning and raises by default. """ def _convert(self, new_ctx: Context) -> Space: + """Convert all component spaces to ``new_ctx``.""" new_spaces = [] for sp in self.spaces: new_spaces.append(sp.convert(new_ctx)) return ProductSpace(tuple(new_spaces), new_ctx) def _local_checks(self): + """Return membership checks local to product spaces.""" return ProductStructureCheck(), ProductComponentCheck() def __init__(self, spaces: Tuple[Space, ...], ctx: Context | str | None = None) -> None: @@ -47,7 +63,7 @@ def __init__(self, spaces: Tuple[Space, ...], ctx: Context | str | None = None) raise ValueError("ProductSpace requires at least one subspace.") spaces = self._validate_spaces(spaces) - ctx = ctx_manager.resolve_context_priority(ctx, *spaces) + ctx = resolve_context_priority(ctx, *spaces) dims = tuple(_prod_int(s.shape) for s in spaces) offsets: List[int] = [0] @@ -95,6 +111,7 @@ def __init__(self, spaces: Tuple[Space, ...], ctx: Context | str | None = None) self._is_flat1 = self._component_is_flat[1] def _validate_spaces(self, spaces: Any) -> Tuple[Space, ...]: + """Validate and normalize product component spaces.""" if isinstance(spaces, Sequence): spaces = tuple(spaces) for i, sp in enumerate(spaces): @@ -108,27 +125,26 @@ def _validate_spaces(self, spaces: Any) -> Tuple[Space, ...]: @property def arity(self) -> int: + """Number of component spaces.""" return self._arity def zeros(self) -> Tuple[Any, ...]: + """Return the product-space zero tuple.""" return tuple(s.zeros() for s in self.spaces) + @checked_method(in_space="self", arg_positions=(0, 1)) def add(self, x: Tuple[Any, ...], y: Tuple[Any, ...]) -> Tuple[Any, ...]: - if self._enable_checks: - self._check_member(x) - self._check_member(y) + """Return the componentwise product-space sum.""" return tuple(s.add(xi, yi) for s, xi, yi in zip(self.spaces, x, y)) + @checked_method(in_space="self", arg_positions=(1,)) def scale(self, a: Any, x: Tuple[Any, ...]) -> Tuple[Any, ...]: - if self._enable_checks: - self._check_member(x) + """Return the componentwise scalar product.""" return tuple(s.scale(a, xi) for s, xi in zip(self.spaces, x)) + @checked_method(in_space="self", arg_positions=(0, 1)) def inner(self, x: Tuple[Any, ...], y: Tuple[Any, ...]) -> Any: - if self._enable_checks: - self._check_member(x) - self._check_member(y) - + r"""Return the sum of component inner products.""" # Accumulate via backend ops (vdot works for scalars too, but sum is enough) acc = None for s, xi, yi in zip(self.spaces, x, y): @@ -137,15 +153,15 @@ def inner(self, x: Tuple[Any, ...], y: Tuple[Any, ...]) -> Any: return acc def eigh(self, x: Any, k: int = None) -> Any: + """Raise because product spaces do not define a canonical eigendecomposition.""" raise NotImplementedError( "ProductSpace.eigh is not defined. " "Call eigh on a specific component space, or define a custom convention." ) + @checked_method(in_space="self") def flatten(self, x: Tuple[Any, ...]) -> DenseArray: - if self._enable_checks: - self._check_member(x) - + """Concatenate component coordinate vectors into one dense vector.""" if self._vector_fast_path: if self._arity == 1: return x[0] if self._component_is_flat[0] else x[0].reshape((-1,)) @@ -178,6 +194,7 @@ def flatten(self, x: Tuple[Any, ...]) -> DenseArray: return self._concatenate(parts, axis=0) def unflatten(self, v: DenseArray) -> Tuple[Any, ...]: + """Split dense coordinates into component-space elements.""" if self._enable_checks: v = self.ctx.assert_dense(v) v1 = v if tuple(getattr(v, "shape", ())) == self.shape else v.reshape((-1,)) @@ -210,6 +227,7 @@ def unflatten(self, v: DenseArray) -> Tuple[Any, ...]: return tuple(xs) + @checked_method(in_space="self", out_space="self") def apply(self, x: Tuple[Any, ...], f: Callable[[Any], Any]) -> Tuple[Any, ...]: r""" Apply a function to each component of a product-space element. @@ -257,8 +275,6 @@ def apply(self, x: Tuple[Any, ...], f: Callable[[Any], Any]) -> Tuple[Any, ...]: product space. It applies the existing functional calculus of each factor space independently, component by component. """ - if self._enable_checks: - self._check_member(x) if self._arity == 2: return self.spaces[0].apply(x[0], f), self.spaces[1].apply(x[1], f) return tuple(s.apply(xi, f) for s, xi in zip(self.spaces, x)) diff --git a/spacecore/space/_vector.py b/spacecore/space/_vector.py index 2f23b50..65e985b 100644 --- a/spacecore/space/_vector.py +++ b/spacecore/space/_vector.py @@ -5,21 +5,42 @@ from ._base import Space from ._checks import BackendCheck, DTypeCheck, ShapeCheck +from .._checks import checked_method from ..types import DenseArray from ..backend import Context class VectorSpace(Space): - """ - Dense vector space R^{n1, ..., nK} or C^{n1, ..., nK}. - - Elements: - - backend-native dense arrays; - - canonical shape is (n1, ..., nK). - - Geometry: - - Euclidean / ℓ2 inner product - ⟨x, y⟩ = vdot(x, y). + r""" + Represent dense backend arrays with Euclidean geometry. + + Elements are backend-native dense arrays with canonical shape ``shape``. + The inner product is :math:`\langle x, y\rangle_X = \operatorname{vdot}(x,y)`, + where the backend conjugates the first argument for complex arrays. + + Parameters + ---------- + shape : tuple of int + Canonical coordinate shape for elements of the space. + ctx : Context, str, or None, optional + Backend context specification. Default resolves to the global context. + + Attributes + ---------- + shape : tuple of int + Canonical element shape. + ctx : Context + Resolved backend context. + + Examples + -------- + >>> import numpy as np + >>> import spacecore as sc + >>> ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + >>> X = sc.VectorSpace((2,), ctx) + >>> x = ctx.asarray([1.0, 2.0]) + >>> X.inner(x, x) + np.float64(5.0) """ def __init__(self, shape: Tuple[int, ...], ctx: Context | str | None = None) -> None: @@ -28,46 +49,50 @@ def __init__(self, shape: Tuple[int, ...], ctx: Context | str | None = None) -> self._is_flat_shape = self.shape == (self._size,) def _local_checks(self): + """Return membership checks local to dense vector spaces.""" return BackendCheck(), ShapeCheck(), DTypeCheck() def zeros(self) -> DenseArray: + """Return the zero vector in this space.""" return self.ops.zeros(self.shape, dtype=self.dtype) + @checked_method(in_space="self", arg_positions=(0, 1)) def add(self, x: Any, y: Any) -> DenseArray: - if self._enable_checks: - self._check_member(x) - self._check_member(y) + """Return the vector-space sum ``x + y``.""" return x + y + @checked_method(in_space="self", arg_positions=(1,)) def scale(self, a: Any, x: Any) -> DenseArray: - if self._enable_checks: - self._check_member(x) + """Return the scalar product ``a * x``.""" return a * x + @checked_method(in_space="self", arg_positions=(0, 1)) def inner(self, x: Any, y: Any) -> Any: - if self._enable_checks: - self._check_member(x) - self._check_member(y) + r"""Return :math:`\langle x, y\rangle_X` using backend ``vdot``.""" return self.ops.vdot(x, y) def eigh(self, x: Any, k: int = None) -> Any: + """Raise because vector elements do not have a canonical eigendecomposition.""" raise TypeError( f"{type(self).__name__}.eigh is not defined for vector spaces." ) + @checked_method(in_space="self") def flatten(self, X: DenseArray) -> DenseArray: - if self._enable_checks: - self._check_member(X) + """Return ``X`` as a dense one-dimensional coordinate vector.""" return X if self._is_flat_shape else X.reshape((-1,)) def unflatten(self, v: DenseArray) -> DenseArray: + """Reshape a flat coordinate vector into this space's canonical shape.""" V = self.ctx.assert_dense(v) if self._enable_checks else v return V if self._is_flat_shape else V.reshape(self.shape) def _convert(self, new_ctx: Context) -> VectorSpace: + """Convert this vector space to ``new_ctx`` without changing shape.""" return VectorSpace(self.shape, new_ctx) def _apply_entrywise(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: + """Apply ``f`` entrywise and verify that shape is preserved.""" try: y = f(x) except Exception: @@ -78,6 +103,7 @@ def _apply_entrywise(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) raise ValueError("Function application changed shape.") return y + @checked_method(in_space="self", out_space="self") def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseArray: r""" Apply a scalar function to a vector-space element entrywise. @@ -121,7 +147,5 @@ def apply(self, x: DenseArray, f: Callable[[DenseArray], DenseArray]) -> DenseAr application is performed entrywise in the distinguished coordinate representation. """ - if self._enable_checks: - self._check_member(x) y = self._apply_entrywise(x, f) return y diff --git a/spacecore/types/__init__.py b/spacecore/types/__init__.py index 65738fd..b91a969 100644 --- a/spacecore/types/__init__.py +++ b/spacecore/types/__init__.py @@ -1,3 +1,5 @@ +"""Common typing aliases and protocols used by SpaceCore.""" + from ._array import ArrayLike, DenseArray, SparseArray from ._dtype import DType from ._misc import Index, T, X, Y, R, Carry diff --git a/spacecore/types/_array.py b/spacecore/types/_array.py index ec4b03f..ce8501e 100644 --- a/spacecore/types/_array.py +++ b/spacecore/types/_array.py @@ -9,11 +9,20 @@ @runtime_checkable class ArrayLike(Protocol): - """Minimal array-like object accepted by public backend helpers. + """ + Define the minimal array-like object accepted by backend helpers. This intentionally only models common metadata. NumPy arrays, JAX arrays, PyTorch tensors, sparse arrays, scalar-like backend arrays, and array wrappers can satisfy this without implementing every dense-array method. + + Parameters + ---------- + *args : Any + Construction arguments accepted by concrete array implementations. + **kwargs : Any + Keyword construction arguments accepted by concrete array + implementations. """ @property @@ -24,12 +33,22 @@ def dtype(self) -> DType: ... class SparseArray(ArrayLike, Protocol): - """Portable sparse-array surface used by sparse linear operators. + """ + Define the portable sparse-array surface used by sparse operators. Backend-specific sparse APIs such as SciPy ``tocsr()``, JAX sparse ``indices``/``data``, and Torch ``to_dense()`` are intentionally not part of this protocol. Concrete backends may use those after checking that the object belongs to their sparse family. + + Parameters + ---------- + *args : Any + Construction arguments accepted by concrete sparse array + implementations. + **kwargs : Any + Keyword construction arguments accepted by concrete sparse array + implementations. """ @property @@ -41,12 +60,21 @@ def __matmul__(self, other: Any) -> Any: ... class DenseArray(ArrayLike, Protocol): - """Portable dense-array surface covering NumPy, JAX, and PyTorch arrays. + """ + Define the portable dense-array surface used by core abstractions. The protocol includes only operations that SpaceCore core abstractions use directly on dense arrays. Backend-specific metadata such as device, sharding, layout, strides, and gradient state belongs to concrete backend implementations, not to this portable type. + + Parameters + ---------- + *args : Any + Construction arguments accepted by concrete dense array implementations. + **kwargs : Any + Keyword construction arguments accepted by concrete dense array + implementations. """ @property diff --git a/tests/_helpers.py b/tests/_helpers.py index 267769c..d06511c 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,5 +1,6 @@ from __future__ import annotations import importlib.util +from functools import lru_cache import numpy as np @@ -11,6 +12,18 @@ def has_torch() -> bool: return importlib.util.find_spec("torch") is not None +@lru_cache +def has_cupy() -> bool: + if importlib.util.find_spec("cupy") is None: + return False + try: + import cupy + cupy.asarray([0]).sum() + except Exception: + return False + return True + + def jax_real_dtype(): if not has_jax(): return np.float32 @@ -36,9 +49,37 @@ def torch_complex_dtype(): return torch.complex128 if torch.get_default_dtype() == torch.float64 else torch.complex64 +def cupy_real_dtype(): + return np.float64 + + +def cupy_complex_dtype(): + return np.complex128 + + def to_numpy(x): if isinstance(x, tuple): return tuple(to_numpy(xi) for xi in x) + if has_cupy(): + import cupy + if isinstance(x, cupy.ndarray): + return cupy.asnumpy(x) + try: + import cupyx.scipy.sparse as cupy_sparse + sparse_types = tuple( + typ + for typ in ( + getattr(cupy_sparse, "spmatrix", None), + getattr(cupy_sparse, "csr_matrix", None), + getattr(cupy_sparse, "csc_matrix", None), + getattr(cupy_sparse, "coo_matrix", None), + ) + if typ is not None + ) + if sparse_types and isinstance(x, sparse_types): + return cupy.asnumpy(x.toarray()) + except Exception: + pass if has_torch(): import torch if isinstance(x, torch.Tensor): diff --git a/tests/backend/test_backend_loops.py b/tests/backend/test_backend_loops.py new file mode 100644 index 0000000..3b3be1a --- /dev/null +++ b/tests/backend/test_backend_loops.py @@ -0,0 +1,119 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_cupy, has_jax, has_torch, jax_real_dtype, to_numpy +from tests._helpers import torch_real_dtype + + +def _backend_params(): + return [ + pytest.param("numpy", np.float64, id="numpy"), + pytest.param( + "jax", + jax_real_dtype(), + marks=pytest.mark.skipif(not has_jax(), reason="jax is not installed"), + id="jax", + ), + pytest.param( + "torch", + torch_real_dtype(), + marks=pytest.mark.skipif(not has_torch(), reason="torch is not installed"), + id="torch", + ), + pytest.param( + "cupy", + np.float64, + marks=pytest.mark.skipif(not has_cupy(), reason="cupy is not installed"), + id="cupy", + ), + ] + + +def _ops_for_backend(name): + sc = importlib.import_module("spacecore") + if name == "numpy": + return sc.NumpyOps() + if name == "jax": + return sc.JaxOps() + if name == "torch": + return sc.TorchOps() + if name == "cupy": + return sc.CuPyOps() + raise ValueError(f"Unknown backend {name!r}.") + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_fori_loop_accumulates_indices(backend_name, dtype): + ops = _ops_for_backend(backend_name) + + def body_fun(i, carry): + return carry + ops.asarray(i, dtype=dtype) + + out = ops.fori_loop(0, 5, body_fun, ops.asarray(0.0, dtype=dtype)) + + np.testing.assert_allclose(to_numpy(out), 10.0) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_while_loop_reaches_terminal_state(backend_name, dtype): + ops = _ops_for_backend(backend_name) + limit = ops.asarray(4.0, dtype=dtype) + + def cond_fun(carry): + return carry < limit + + def body_fun(carry): + return carry + ops.asarray(1.0, dtype=dtype) + + out = ops.while_loop(cond_fun, body_fun, ops.asarray(0.0, dtype=dtype)) + + np.testing.assert_allclose(to_numpy(out), 4.0) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_scan_accumulates_and_stacks_outputs(backend_name, dtype): + ops = _ops_for_backend(backend_name) + xs = ops.asarray([1.0, 2.0, 3.0, 4.0], dtype=dtype) + + def body_fun(carry, x): + new_carry = carry + x + return new_carry, new_carry * ops.asarray(2.0, dtype=dtype) + + final, ys = ops.scan(body_fun, ops.asarray(0.0, dtype=dtype), xs) + + np.testing.assert_allclose(to_numpy(final), 10.0) + np.testing.assert_allclose(to_numpy(ys), [2.0, 6.0, 12.0, 20.0]) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_scan_without_xs_uses_explicit_length(backend_name, dtype): + ops = _ops_for_backend(backend_name) + + def body_fun(carry, _): + new_carry = carry + ops.asarray(2.0, dtype=dtype) + return new_carry, new_carry + + final, ys = ops.scan(body_fun, ops.asarray(1.0, dtype=dtype), None, length=3) + + np.testing.assert_allclose(to_numpy(final), 7.0) + np.testing.assert_allclose(to_numpy(ys), [3.0, 5.0, 7.0]) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_cond_selects_expected_branch(backend_name, dtype): + ops = _ops_for_backend(backend_name) + x = ops.asarray(3.0, dtype=dtype) + + def true_fun(value): + return value + ops.asarray(10.0, dtype=dtype) + + def false_fun(value): + return value - ops.asarray(10.0, dtype=dtype) + + true_out = ops.cond(True, true_fun, false_fun, x) + false_out = ops.cond(False, true_fun, false_fun, x) + + np.testing.assert_allclose(to_numpy(true_out), 13.0) + np.testing.assert_allclose(to_numpy(false_out), -7.0) diff --git a/tests/backend/test_backend_ops_delegation.py b/tests/backend/test_backend_ops_delegation.py new file mode 100644 index 0000000..770ce32 --- /dev/null +++ b/tests/backend/test_backend_ops_delegation.py @@ -0,0 +1,238 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import ( + has_jax, + has_torch, + jax_complex_dtype, + jax_real_dtype, + to_numpy, + torch_complex_dtype, + torch_real_dtype, +) + + +DELEGATED_METHODS = ( + "reshape", + "sum", + "eigh", + "trace", + "concatenate", + "transpose", + "matmul", +) + +TORCH_AAC_DELEGATED_METHODS = ( + "mean", + "prod", + "sort", + "argsort", + "argmin", + "argmax", + "clip", + "take", + "diagonal", + "squeeze", +) + + +def test_numpy_ops_inherits_common_delegated_methods(): + sc = importlib.import_module("spacecore") + + assert sc.NumpyOps.xp.__name__ == "array_api_compat.numpy" + for name in DELEGATED_METHODS: + assert getattr(sc.NumpyOps, name) is getattr(sc.BackendOps, name) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_ops_inherits_common_delegated_methods(): + sc = importlib.import_module("spacecore") + + assert sc.JaxOps.xp is sc.JaxOps.jnp + for name in DELEGATED_METHODS: + assert getattr(sc.JaxOps, name) is getattr(sc.BackendOps, name) + + +def test_torch_ops_uses_aac_namespace_when_available(): + sc = importlib.import_module("spacecore") + + assert sc.TorchOps.xp.__name__ == "array_api_compat.torch" + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_ops_inherits_aac_delegated_methods(): + sc = importlib.import_module("spacecore") + + for name in TORCH_AAC_DELEGATED_METHODS: + assert getattr(sc.TorchOps, name) is getattr(sc.BackendOps, name) + + +def _check_raw_delegated_ops(ops, dtype): + x = ops.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + y = ops.reshape(x, (3, 2)) + h = ops.asarray([[2.0, 1.0], [1.0, 3.0]], dtype=dtype) + cube = ops.reshape(ops.arange(24, dtype=dtype), (2, 3, 4)) + singleton = ops.reshape(ops.arange(6, dtype=dtype), (1, 2, 1, 3)) + + np.testing.assert_allclose(to_numpy(ops.matmul(h, ops.asarray([1.0, 2.0], dtype=dtype))), [4.0, 7.0]) + np.testing.assert_allclose(to_numpy(ops.reshape(x, (6,))), np.arange(1.0, 7.0)) + np.testing.assert_allclose(to_numpy(ops.sum(x, axis=0)), [5.0, 7.0, 9.0]) + np.testing.assert_allclose(to_numpy(ops.sum(cube, axis=(0, 2))), np.arange(24.0).reshape(2, 3, 4).sum(axis=(0, 2))) + np.testing.assert_allclose(to_numpy(ops.prod(cube + 1, axis=(0, 2))), (np.arange(24.0).reshape(2, 3, 4) + 1).prod(axis=(0, 2))) + np.testing.assert_allclose(to_numpy(ops.mean(cube, axis=(0, 2))), np.arange(24.0).reshape(2, 3, 4).mean(axis=(0, 2))) + assert ops.shape(ops.squeeze(singleton)) == (2, 3) + np.testing.assert_allclose(to_numpy(ops.trace(h)), 5.0) + np.testing.assert_allclose(to_numpy(ops.concatenate((x, x), axis=0)), np.concatenate((to_numpy(x), to_numpy(x)), axis=0)) + np.testing.assert_allclose(to_numpy(ops.transpose(y)), to_numpy(y).T) + + evals, evecs = ops.eigh(h) + np.testing.assert_allclose(to_numpy(evecs @ ops.diag(evals) @ ops.transpose(evecs)), to_numpy(h), atol=1e-6) + + +def test_numpy_raw_delegated_ops(): + sc = importlib.import_module("spacecore") + + _check_raw_delegated_ops(sc.NumpyOps(), np.float64) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_raw_delegated_ops(): + sc = importlib.import_module("spacecore") + + _check_raw_delegated_ops(sc.JaxOps(), jax_real_dtype()) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_raw_delegated_ops(): + sc = importlib.import_module("spacecore") + + _check_raw_delegated_ops(sc.TorchOps(), torch_real_dtype()) + + +def test_numpy_eps_uses_default_dtype(): + sc = importlib.import_module("spacecore") + ops = sc.NumpyOps() + + np.testing.assert_allclose(ops.eps(ops.sanitize_dtype(None)), np.finfo(ops.sanitize_dtype(None)).eps) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_eps_uses_default_dtype(): + sc = importlib.import_module("spacecore") + ops = sc.JaxOps() + + np.testing.assert_allclose(ops.eps(jax_real_dtype()), np.finfo(jax_real_dtype()).eps) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_eps_uses_default_dtype(): + sc = importlib.import_module("spacecore") + import torch + + ops = sc.TorchOps() + np.testing.assert_allclose(ops.eps(ops.sanitize_dtype(None)), torch.finfo(ops.sanitize_dtype(None)).eps) + + +def test_backend_ops_hash_and_repr(): + sc = importlib.import_module("spacecore") + first = sc.NumpyOps() + second = sc.NumpyOps() + + assert hash(first) == hash(second) + assert {first: 1, second: 2} == {first: 2} + assert repr(first) == "NumpyOps(family='numpy')" + + +def test_numpy_eps_distinguishes_dtype_precision(): + sc = importlib.import_module("spacecore") + ops = sc.NumpyOps() + + assert ops.eps(np.float64) < ops.eps(np.float32) + + +def test_numpy_constants_are_cached_and_astype_none_is_noop(): + sc = importlib.import_module("spacecore") + ops = sc.NumpyOps() + x = ops.asarray([1.0, 2.0]) + + assert ops.inf is ops.inf + assert ops.nan is ops.nan + assert ops.pi is ops.pi + assert ops.e is ops.e + assert ops.astype(x, None) is x + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_constants_are_cached_and_astype_none_is_noop(): + sc = importlib.import_module("spacecore") + ops = sc.JaxOps() + x = ops.asarray([1.0, 2.0], dtype=jax_real_dtype()) + + assert ops.inf is ops.inf + assert ops.nan is ops.nan + assert ops.pi is ops.pi + assert ops.e is ops.e + assert ops.astype(x, None) is x + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_constants_are_cached_and_astype_none_is_noop(): + sc = importlib.import_module("spacecore") + ops = sc.TorchOps() + x = ops.asarray([1.0, 2.0], dtype=torch_real_dtype()) + + assert ops.inf is ops.inf + assert ops.nan is ops.nan + assert ops.pi is ops.pi + assert ops.e is ops.e + assert ops.astype(x, None) is x + + +def _check_complex_adjoint(ops, dtype): + sc = importlib.import_module("spacecore") + ctx = sc.Context(ops, dtype=dtype) + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = ctx.asarray( + [ + [1.0 + 2.0j, 3.0 - 1.0j], + [-2.0 + 0.5j, 0.25 + 4.0j], + [1.5 - 3.0j, -0.75 + 2.0j], + ] + ) + x = ctx.asarray([2.0 - 1.0j, -0.5 + 0.25j]) + y = ctx.asarray([1.0 + 0.5j, -2.0j, 0.75 - 1.25j]) + op = sc.DenseLinOp(A, dom, cod, ctx) + + lhs = ctx.ops.vdot(op.apply(x), y) + rhs = ctx.ops.vdot(x, op.rapply(y)) + np.testing.assert_allclose(to_numpy(lhs), to_numpy(rhs), rtol=1e-6, atol=1e-6) + + H = sc.HermitianSpace(2, ctx=ctx) + raw = ctx.asarray([[1.0 + 0.0j, 2.0 + 3.0j], [-1.0 + 4.0j, -0.5 + 0.0j]]) + herm = H.symmetrize(raw) + evals, evecs = H.eigh(herm) + rebuilt = (evecs * evals) @ ctx.ops.conj(ctx.ops.transpose(evecs)) + np.testing.assert_allclose(to_numpy(rebuilt), to_numpy(herm), rtol=1e-6, atol=1e-6) + + +def test_numpy_complex_adjoint_and_hermitian_eigh(): + sc = importlib.import_module("spacecore") + + _check_complex_adjoint(sc.NumpyOps(), np.complex128) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_complex_adjoint_and_hermitian_eigh(): + sc = importlib.import_module("spacecore") + + _check_complex_adjoint(sc.JaxOps(), jax_complex_dtype()) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_complex_adjoint_and_hermitian_eigh(): + sc = importlib.import_module("spacecore") + + _check_complex_adjoint(sc.TorchOps(), torch_complex_dtype()) diff --git a/tests/backend/test_backend_registry.py b/tests/backend/test_backend_registry.py index f55aea9..51a20b9 100644 --- a/tests/backend/test_backend_registry.py +++ b/tests/backend/test_backend_registry.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from tests._helpers import has_torch +from tests._helpers import has_cupy, has_torch def test_builtin_backends_are_usable(): @@ -14,55 +14,23 @@ def test_register_ops_adds_backend(): sc = importlib.import_module("spacecore") class DummyOps(sc.BackendOps): + import array_api_compat.numpy as xp + _family = "dummy" _allow_sparse = False - _dense_array = np.ndarray - _sparse_array = object + + @property + def dense_array(self): + return np.ndarray + + @property + def sparse_array(self): + return None + def sanitize_dtype(self, dtype): return np.dtype(dtype) if dtype is not None else np.dtype(np.float64) - def asarray(self, x, dtype=None): return np.asarray(x, dtype=dtype) def assparse(self, x, dtype=None): raise NotImplementedError - def is_dense(self, x): return isinstance(x, np.ndarray) - def is_sparse(self, x): return False - def get_dtype(self, x): return x.dtype - def zeros(self, shape, dtype=None): return np.zeros(shape, dtype=dtype) - def ones(self, shape, dtype=None): return np.ones(shape, dtype=dtype) - def full(self, shape, fill_value, dtype=None): return np.full(shape, fill_value, dtype=dtype) - def empty(self, shape, dtype=None): return np.empty(shape, dtype=dtype) - def eye(self, n, dtype=None): return np.eye(n, dtype=dtype) - def arange(self, *args, **kwargs): return np.arange(*args, **kwargs) - def reshape(self, x, shape): return np.reshape(x, shape) - def ravel(self, x): return np.ravel(x) - def transpose(self, x, axes=None): return np.transpose(x, axes=axes) - def swapaxes(self, x, axis1, axis2): return np.swapaxes(x, axis1, axis2) - def conj(self, x): return np.conj(x) - def sum(self, x, axis=None, **kwargs): return np.sum(x, axis=axis) - def prod(self, x, axis=None, **kwargs): return np.prod(x, axis=axis) - def trace(self, x, **kwargs): return np.trace(x) - def argsort(self, x, **kwargs): return np.argsort(x) - def sort(self, x, **kwargs): return np.sort(x) - def argmin(self, x, **kwargs): return np.argmin(x) - def argmax(self, x, **kwargs): return np.argmax(x) - def vdot(self, a, b, **kwargs): return np.vdot(a, b) - def matmul(self, a, b, **kwargs): return a @ b def sparse_matmul(self, a, b): raise NotImplementedError - def kron(self, a, b): return np.kron(a, b) - def einsum(self, subscripts, *operands, **kwargs): return np.einsum(subscripts, *operands) - def eigh(self, x, **kwargs): return np.linalg.eigh(x) def logsumexp(self, *args, **kwargs): raise NotImplementedError - def exp(self, x): return np.exp(x) - def log(self, x): return np.log(x) - def maximum(self, x, y): return np.maximum(x, y) - def minimum(self, x, y): return np.minimum(x, y) - def where(self, condition, x=None, y=None, **kwargs): return np.where(condition, x, y) - def concatenate(self, arrays, axis=0, dtype=None): - out = np.concatenate(arrays, axis=axis) - return out.astype(dtype) if dtype is not None else out - def stack(self, arrays, axis=0): return np.stack(arrays, axis=axis) - def sqrt(self, x): return np.sqrt(x) - def abs(self, x): return np.abs(x) - def real(self, x): return np.real(x) - def imag(self, x): return np.imag(x) - def sign(self, x): return np.sign(x) def index_set(self, x, index, values, copy=True): y = np.array(x, copy=True) @@ -90,10 +58,12 @@ def index_add(self, x, index, values, copy=True): y=np.array(x, copy=True) y[index]+=values return y - def allclose(self, a, b, **kwargs): return np.allclose(a, b, **kwargs) def allclose_sparse(self, a, b, **kwargs): return False sc.register_ops(DummyOps) - assert "dummy" in sc._contextual.manager.ctx_manager.available_ops + assert sc.VectorSpace((1,), "dummy").ctx.ops.family == "dummy" + ops = DummyOps() + x = ops.reshape(ops.arange(6), (2, 3)) + assert np.allclose(ops.sum(x, axis=0), [3, 5, 7]) @pytest.mark.skipif(not has_torch(), reason="torch is not installed") @@ -103,3 +73,11 @@ def test_torch_backend_aliases_resolve_when_available(): assert isinstance(sc.TorchOps(), sc.BackendOps) assert sc.VectorSpace((1,), "torch").ctx.ops.family == "torch" assert sc.VectorSpace((1,), "pytorch").ctx.ops.family == "torch" + + +@pytest.mark.skipif(not has_cupy(), reason="cupy is not installed") +def test_cupy_backend_alias_resolves_when_available(): + sc = importlib.import_module("spacecore") + + assert isinstance(sc.CuPyOps(), sc.BackendOps) + assert sc.VectorSpace((1,), "cupy").ctx.ops.family == "cupy" diff --git a/tests/backend/test_cupy_ops.py b/tests/backend/test_cupy_ops.py new file mode 100644 index 0000000..dc547ba --- /dev/null +++ b/tests/backend/test_cupy_ops.py @@ -0,0 +1,66 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_cupy, to_numpy + + +pytestmark = pytest.mark.skipif(not has_cupy(), reason="cupy is not installed") + + +def _ctx(dtype=np.float64): + sc = importlib.import_module("spacecore") + return sc.Context(sc.CuPyOps(), dtype=dtype) + + +def test_cupy_ops_dense_creation_and_indexing(): + sc = importlib.import_module("spacecore") + ops = sc.CuPyOps() + x = ops.asarray([1.0, 2.0, 3.0], dtype=np.float64) + y = ops.index_set(x, 1, ops.asarray(5.0), copy=True) + z = ops.index_add(y, 0, ops.asarray(2.0), copy=True) + + assert ops.family == "cupy" + assert ops.is_dense(x) + np.testing.assert_allclose(to_numpy(x), [1.0, 2.0, 3.0]) + np.testing.assert_allclose(to_numpy(y), [1.0, 5.0, 3.0]) + np.testing.assert_allclose(to_numpy(z), [3.0, 5.0, 3.0]) + + +def test_cupy_sparse_conversion_and_matmul(): + ctx = _ctx() + dense = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + sparse = ctx.assparse(dense) + x = ctx.asarray([7.0, 8.0]) + + assert ctx.ops.is_sparse(sparse) + np.testing.assert_allclose(to_numpy(ctx.ops.sparse_matmul(sparse, x)), [23.0, 53.0, 83.0]) + assert ctx.ops.allclose_sparse(sparse, ctx.assparse(dense)) + + +def test_cupy_dense_linop_apply_and_rapply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(ctx.asarray(dense), dom, cod, ctx) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + + np.testing.assert_allclose(to_numpy(op.apply(x)), dense @ np.asarray([7.0, 8.0])) + np.testing.assert_allclose(to_numpy(op.rapply(y)), dense.T @ np.asarray([1.0, -1.0, 2.0])) + + +def test_cupy_sparse_linop_apply_and_to_dense(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(ctx.assparse(dense), dom, cod, ctx) + x = ctx.asarray([7.0, 8.0]) + + np.testing.assert_allclose(to_numpy(op.apply(x)), dense @ np.asarray([7.0, 8.0])) + np.testing.assert_allclose(to_numpy(op.to_dense()), dense) diff --git a/tests/context/test_checked_method.py b/tests/context/test_checked_method.py new file mode 100644 index 0000000..9ac3d12 --- /dev/null +++ b/tests/context/test_checked_method.py @@ -0,0 +1,133 @@ +import pytest + +from spacecore import checked_method + + +class _RecordingSpace: + def __init__(self, valid): + self.valid = valid + self.calls = [] + + def _check_member(self, value): + self.calls.append(value) + if value != self.valid: + raise ValueError(f"invalid member: {value!r}") + + +class _CheckedDemo: + def __init__(self, enable_checks=True): + self._enable_checks = enable_checks + self.dom = _RecordingSpace("x") + self.cod = _RecordingSpace("y") + self.space = _RecordingSpace("z") + self.apply_result = "y" + self.rapply_result = "x" + self.value_result = 1.0 + self.grad_result = "z" + + @checked_method(in_space="dom", out_space="cod") + def apply(self, x): + """Apply docstring.""" + return self.apply_result + + @checked_method(in_space="cod", out_space="dom") + def rapply(self, y): + return self.rapply_result + + @checked_method(in_space="space") + def value(self, x): + return self.value_result + + @checked_method(in_space="space", out_space="space") + def grad(self, x): + return self.grad_result + + @checked_method(in_space="self", arg_positions=(0, 1)) + def combine(self, x, y): + return "combined" + + @checked_method(in_space="self", arg_pos=0) + def legacy_single_arg(self, x): + return "legacy" + + def _check_member(self, value): + self.space._check_member(value) + + +def test_checked_method_validates_apply_input_and_output(): + demo = _CheckedDemo() + + assert demo.apply("x") == "y" + assert demo.dom.calls == ["x"] + assert demo.cod.calls == ["y"] + + +def test_checked_method_validates_rapply_input_and_output(): + demo = _CheckedDemo() + + assert demo.rapply("y") == "x" + assert demo.cod.calls == ["y"] + assert demo.dom.calls == ["x"] + + +def test_checked_method_validates_value_input(): + demo = _CheckedDemo() + + assert demo.value("z") == 1.0 + assert demo.space.calls == ["z"] + + +def test_checked_method_validates_grad_input_and_output(): + demo = _CheckedDemo() + + assert demo.grad("z") == "z" + assert demo.space.calls == ["z", "z"] + + +def test_checked_method_invalid_input_raises_when_enabled(): + demo = _CheckedDemo(enable_checks=True) + + with pytest.raises(ValueError, match="invalid member"): + demo.apply("bad") + + +def test_checked_method_invalid_output_raises_when_enabled(): + demo = _CheckedDemo(enable_checks=True) + demo.apply_result = "bad" + + with pytest.raises(ValueError, match="invalid member"): + demo.apply("x") + + +def test_checked_method_skips_checks_when_disabled(): + demo = _CheckedDemo(enable_checks=False) + demo.apply_result = "bad" + + assert demo.apply("bad") == "bad" + assert demo.dom.calls == [] + assert demo.cod.calls == [] + + +def test_checked_method_preserves_metadata(): + assert _CheckedDemo.apply.__name__ == "apply" + assert _CheckedDemo.apply.__doc__ == "Apply docstring." + assert _CheckedDemo.apply.__wrapped__ is not None + + +def test_checked_method_supports_self_target_and_multiple_input_args(): + demo = _CheckedDemo() + + assert demo.combine("z", "z") == "combined" + assert demo.space.calls == ["z", "z"] + + +def test_checked_method_arg_pos_alias_still_works(): + demo = _CheckedDemo() + + assert demo.legacy_single_arg("z") == "legacy" + assert demo.space.calls == ["z"] + + +def test_checked_method_rejects_arg_pos_and_arg_positions_together(): + with pytest.raises(TypeError, match="arg_pos"): + checked_method(in_space="space", arg_pos=0, arg_positions=(0,)) diff --git a/tests/context/test_context_resolution.py b/tests/context/test_context_resolution.py index 60e72fc..28e4214 100644 --- a/tests/context/test_context_resolution.py +++ b/tests/context/test_context_resolution.py @@ -32,6 +32,55 @@ def test_public_resolve_context_priority_wrapper(): sc.set_context(original) +def test_resolve_context_priority_uses_explicit_ctx_before_inferred_ctx(): + sc = importlib.import_module("spacecore") + original = sc.get_context() + default = sc.Context(sc.NumpyOps(), dtype=np.float16, enable_checks=True) + inferred = sc.Context(sc.NumpyOps(), dtype=np.float32, enable_checks=True) + explicit = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False) + try: + sc.set_context(default) + X = sc.VectorSpace((2,), inferred) + + resolved = sc.resolve_context_priority(explicit, X) + + assert resolved == explicit + assert resolved.dtype == np.dtype(np.float64) + assert resolved.enable_checks is False + finally: + sc.set_context(original) + + +def test_resolve_context_priority_uses_inferred_ctx_only_without_explicit_ctx(): + sc = importlib.import_module("spacecore") + original = sc.get_context() + default = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=True) + inferred = sc.Context(sc.NumpyOps(), dtype=np.float32, enable_checks=False) + try: + sc.set_context(default) + X = sc.VectorSpace((2,), inferred) + + resolved = sc.resolve_context_priority(None, X) + + assert resolved.ops.family == inferred.ops.family + assert resolved.dtype == np.dtype(np.float32) + assert resolved.enable_checks is False + finally: + sc.set_context(original) + + +def test_resolve_context_priority_uses_default_only_without_explicit_or_inferred_ctx(): + sc = importlib.import_module("spacecore") + original = sc.get_context() + default = sc.Context(sc.NumpyOps(), dtype=np.float32, enable_checks=True) + try: + sc.set_context(default) + + assert sc.resolve_context_priority(None) == default + finally: + sc.set_context(original) + + def test_default_context_used_when_none_given(): sc = importlib.import_module("spacecore") original = sc.get_context() diff --git a/tests/context/test_enable_checks.py b/tests/context/test_enable_checks.py index 66e7686..c187116 100644 --- a/tests/context/test_enable_checks.py +++ b/tests/context/test_enable_checks.py @@ -2,7 +2,7 @@ import pytest import spacecore as sc -from spacecore._contextual.contextual import ContextConversionError +from spacecore._contextual import ContextConversionError from tests._helpers import has_jax, jax_real_dtype diff --git a/tests/fixtures/jaxpr_lanczos_smallest.txt b/tests/fixtures/jaxpr_lanczos_smallest.txt new file mode 100644 index 0000000..c3849c6 --- /dev/null +++ b/tests/fixtures/jaxpr_lanczos_smallest.txt @@ -0,0 +1,452 @@ +let _where = { lambda ; a:bool[] b:f32[] c:f32[]. let + d:f32[] = select_n a c b + in (d,) } in +{ lambda ; e:f32[2,2] f:f32[2]. let + _:f32[2,2] = transpose[permutation=(1, 0)] e + _:f32[2,2] = transpose[permutation=(1, 0)] e + g:i32[3] = iota[dimension=0 dtype=int32 shape=(3,) sharding=None] + h:f32[4,2] = broadcast_in_dim 0.0:f32[] + i:f32[3] = broadcast_in_dim 0.0:f32[] + j:f32[4] = broadcast_in_dim 0.0:f32[] + k:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] f f + l:f32[] = sqrt k + m:f32[2] = broadcast_in_dim 0.0:f32[] + n:i32[1] = broadcast_in_dim 0:i32[] + o:f32[2] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] m n 1.0:f32[] + p:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] o o + q:f32[] = sqrt p + r:bool[] = gt q 0.0:f32[] + s:f32[] = jit[name=_where jaxpr=_where] r q 1.0:f32[] + t:f32[] = div 1.0:f32[] s + u:f32[] = jit[name=_where jaxpr=_where] r t 0.0:f32[] + v:f32[2] = mul u o + w:bool[] = gt l 9.999999960041972e-13:f32[] + x:i32[] = convert_element_type[new_dtype=int32 weak_type=False] w + y:f32[2] = cond[ + branches=( + { lambda ; z:f32[] ba:f32[2] bb:f32[2] bc:f32[]. let in (bb,) } + { lambda ; bd:f32[] be:f32[2] bf:f32[2] bg:f32[]. let + bh:bool[] = gt bd 0.0:f32[] + bi:f32[] = jit[name=_where jaxpr=_where] bh bd 1.0:f32[] + bj:f32[] = div 1.0:f32[] bi + bk:f32[] = jit[name=_where jaxpr=_where] bh bj 0.0:f32[] + bl:f32[2] = mul bk be + in (bl,) } + ) + ] x l f v 0.0:f32[] + bm:i32[1] = broadcast_in_dim 0:i32[] + bn:f32[4,2] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] h bm y + bo:i32[4] = iota[dimension=0 dtype=int32 shape=(4,) sharding=None] + _:f32[4] = broadcast_in_dim 0.0:f32[] + bp:i32[] bq:f32[4,2] br:f32[3] bs:f32[4] _:f32[] _:bool[] = while[ + body_jaxpr={ lambda ; bt:f32[2,2] bu:i32[4] bv:f32[] bw:i32[] bx:f32[4,2] by:f32[3] + bz:f32[4] ca:f32[] cb:bool[]. let + cc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bw + cd:bool[] = lt cc 0:i32[] + ce:i32[] = add cc 4:i32[] + cf:i32[] = select_n cd cc ce + cg:bool[] = lt 0:i32[] 0:i32[] + ch:i32[] = add 0:i32[] 2:i32[] + ci:i32[] = select_n cg 0:i32[] ch + cj:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] bx cf ci + ck:f32[2] = squeeze[dimensions=(0,)] cj + cl:f32[2] = dot_general[ + dimension_numbers=(([1], [0]), ([], [])) + preferred_element_type=float32 + ] bt ck + cm:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] ck cl + cn:bool[] = lt bw 0:i32[] + co:i32[] = add bw 3:i32[] + cp:i32[] = select_n cn bw co + cq:i32[] = convert_element_type[new_dtype=int32 weak_type=False] cp + cr:i32[1] = broadcast_in_dim cq + cs:f32[3] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] by cr cm + ct:bool[] = eq bw 0:i32[] + cu:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ct + cv:f32[2] = cond[ + branches=( + { lambda ; cw:f32[] cx:f32[2] cy:i32[] cz:f32[4] da:f32[4,2] db:f32[2]. let + dc:f32[2] = mul cw cx + dd:f32[2] = sub db dc + de:bool[] = lt cy 0:i32[] + df:i32[] = convert_element_type[ + new_dtype=int32 + weak_type=False + ] cy + dg:i32[] = add df 4:i32[] + dh:i32[] = select_n de cy dg + di:f32[1] = dynamic_slice[slice_sizes=(1,)] cz dh + dj:f32[] = squeeze[dimensions=(0,)] di + dk:i32[] = sub cy 1:i32[] + dl:i32[] = convert_element_type[ + new_dtype=int32 + weak_type=False + ] dk + dm:bool[] = lt dl 0:i32[] + dn:i32[] = add dl 4:i32[] + do:i32[] = select_n dm dl dn + dp:bool[] = lt 0:i32[] 0:i32[] + dq:i32[] = add 0:i32[] 2:i32[] + dr:i32[] = select_n dp 0:i32[] dq + ds:f32[1,2] = dynamic_slice[slice_sizes=(1, 2)] da do dr + dt:f32[2] = squeeze[dimensions=(0,)] ds + du:f32[2] = mul dj dt + dv:f32[2] = sub dd du + in (dv,) } + { lambda ; dw:f32[] dx:f32[2] dy:i32[] dz:f32[4] ea:f32[4,2] eb:f32[2]. let + ec:f32[2] = mul dw dx + ed:f32[2] = sub eb ec + in (ed,) } + ) + ] cu cm ck bw bz bx cl + ee:i32[] = add bw 1:i32[] + ef:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ee + eg:bool[4] = lt bu ef + eh:f32[4] = jit[ + name=_where + jaxpr={ lambda ; eg:bool[4] ei:f32[] ej:f32[]. let + ek:f32[4] = broadcast_in_dim ei + el:f32[4] = broadcast_in_dim ej + eh:f32[4] = select_n eg el ek + in (eh,) } + ] eg 1.0:f32[] 0.0:f32[] + em:f32[4] = dot_general[ + dimension_numbers=(([1], [0]), ([], [])) + preferred_element_type=float32 + ] bx cv + en:f32[4] = mul em eh + eo:f32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,)] en + ep:f32[4,2] = mul eo bx + eq:f32[2] = reduce_sum[axes=(0,) out_sharding=None] ep + er:f32[2] = sub cv eq + es:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] er er + et:f32[] = sqrt es + eu:i32[] = add bw 1:i32[] + ev:bool[] = lt eu 0:i32[] + ew:i32[] = add eu 4:i32[] + ex:i32[] = select_n ev eu ew + ey:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ex + ez:i32[1] = broadcast_in_dim ey + fa:f32[4] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] bz ez et + fb:bool[] = ge et bv + fc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] fb + fd:f32[4,2] = cond[ + branches=( + { lambda ; fe:f32[] ff:f32[2] fg:i32[] fh:f32[4,2]. let + + in (fh,) } + { lambda ; fi:f32[] fj:f32[2] fk:i32[] fl:f32[4,2]. let + fm:bool[] = gt fi 0.0:f32[] + fn:f32[] = jit[name=_where jaxpr=_where] fm fi 1.0:f32[] + fo:f32[] = div 1.0:f32[] fn + fp:f32[] = jit[name=_where jaxpr=_where] fm fo 0.0:f32[] + fq:f32[2] = mul fp fj + fr:i32[] = add fk 1:i32[] + fs:bool[] = lt fr 0:i32[] + ft:i32[] = add fr 4:i32[] + fu:i32[] = select_n fs fr ft + fv:i32[] = convert_element_type[ + new_dtype=int32 + weak_type=False + ] fu + fw:i32[1] = broadcast_in_dim fv + fx:f32[4,2] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] fl fw fq + in (fx,) } + ) + ] fc et er bw bx + fy:i32[] = add bw 1:i32[] + fz:bool[] = ge fy 3:i32[] + ga:i32[] = jit[ + name=remainder + jaxpr={ lambda ; fy:i32[] gb:i32[]. let + gc:bool[] = eq gb 0:i32[] + gd:i32[] = jit[ + name=_where + jaxpr={ lambda ; gc:bool[] ge:i32[] gb:i32[]. let + gd:i32[] = select_n gc gb ge + in (gd,) } + ] gc 1:i32[] gb + gf:i32[] = rem fy gd + gg:bool[] = ne gf 0:i32[] + gh:bool[] = lt gf 0:i32[] + gi:bool[] = lt gd 0:i32[] + gj:bool[] = ne gh gi + gk:bool[] = and gj gg + gl:i32[] = add gf gd + ga:i32[] = select_n gk gf gl + in (ga,) } + ] fy 1:i32[] + gm:bool[] = eq ga 0:i32[] + gn:bool[] = convert_element_type[new_dtype=bool weak_type=False] fz + go:bool[] = convert_element_type[new_dtype=bool weak_type=False] gm + gp:bool[] = or gn go + gq:i32[] = convert_element_type[new_dtype=int32 weak_type=False] gp + gr:bool[] = cond[ + branches=( + { lambda ; gs:f32[] gt:f32[] gu:bool[] gv:f32[]. let in (gu,) } + { lambda ; gw:f32[] gx:f32[] gy:bool[] gz:f32[]. let + ha:bool[] = ge gw gx + in (ha,) } + ) + ] gq et bv cb 0.0:f32[] + in (fy, fd, cs, fa, et, gr) } + body_nconsts=3 + cond_jaxpr={ lambda ; hb:i32[] hc:f32[4,2] hd:f32[3] he:f32[4] hf:f32[] hg:bool[]. let + hh:bool[] = lt hb 3:i32[] + hi:bool[] = convert_element_type[new_dtype=bool weak_type=False] hh + hj:bool[] = and hi hg + in (hj,) } + cond_nconsts=0 + ] e bo 9.999999974752427e-07:f32[] 0:i32[] bn i j 1.0:f32[] True:bool[] + hk:i32[3] = iota[dimension=0 dtype=int32 shape=(3,) sharding=None] + hl:i32[4] = iota[dimension=0 dtype=int32 shape=(4,) sharding=None] + hm:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bp + hn:bool[3] = lt hk hm + ho:f32[3] = abs br + hp:f32[] = reduce_max[axes=(0,)] ho + hq:f32[4] = abs bs + hr:f32[] = reduce_max[axes=(0,)] hq + hs:f32[] = mul 2.0:f32[] hr + ht:f32[] = add hp hs + hu:f32[] = add ht 1.0:f32[] + hv:f32[3] = jit[ + name=_where + jaxpr={ lambda ; hn:bool[3] br:f32[3] hu:f32[]. let + hw:f32[3] = broadcast_in_dim hu + hv:f32[3] = select_n hn hw br + in (hv,) } + ] hn br hu + hx:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bp + hy:bool[4] = eq hl hx + hz:f32[4] = jit[ + name=_where + jaxpr={ lambda ; hy:bool[4] ia:f32[] bs:f32[4]. let + ib:f32[4] = broadcast_in_dim ia + hz:f32[4] = select_n hy bs ib + in (hz,) } + ] hy 0.0:f32[] bs + ic:f32[3,3] = broadcast_in_dim 0.0:f32[] + _:i32[] id:f32[3,3] = scan[ + _split_transpose=False + jaxpr={ lambda ; ie:f32[3] if:i32[] ig:f32[3,3]. let + ih:i32[] = add if 1:i32[] + ii:bool[] = lt if 0:i32[] + ij:i32[] = convert_element_type[new_dtype=int32 weak_type=False] if + ik:i32[] = add ij 3:i32[] + il:i32[] = select_n ii if ik + im:f32[1] = dynamic_slice[slice_sizes=(1,)] ie il + in:f32[] = squeeze[dimensions=(0,)] im + io:bool[] = lt if 0:i32[] + ip:i32[] = add if 3:i32[] + iq:i32[] = select_n io if ip + ir:bool[] = lt if 0:i32[] + is:i32[] = add if 3:i32[] + it:i32[] = select_n ir if is + iu:i32[] = convert_element_type[new_dtype=int32 weak_type=False] iq + iv:i32[] = convert_element_type[new_dtype=int32 weak_type=False] it + iw:i32[1] = broadcast_in_dim iu + ix:i32[1] = broadcast_in_dim iv + iy:i32[2] = concatenate[dimension=0] iw ix + iz:f32[3,3] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] ig iy in + in (ih, iz) } + length=3 + linear=(False, False, False) + num_carry=2 + num_consts=1 + reverse=False + unroll=1 + ] hv 0:i32[] ic + _:i32[] ja:f32[3,3] = scan[ + _split_transpose=False + jaxpr={ lambda ; jb:f32[4] jc:i32[] jd:f32[3,3]. let + je:i32[] = add jc 1:i32[] + jf:i32[] = add jc 1:i32[] + jg:bool[] = lt jf 0:i32[] + jh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] jf + ji:i32[] = add jh 4:i32[] + jj:i32[] = select_n jg jf ji + jk:f32[1] = dynamic_slice[slice_sizes=(1,)] jb jj + jl:f32[] = squeeze[dimensions=(0,)] jk + jm:i32[] = add jc 1:i32[] + jn:bool[] = lt jc 0:i32[] + jo:i32[] = add jc 3:i32[] + jp:i32[] = select_n jn jc jo + jq:bool[] = lt jm 0:i32[] + jr:i32[] = add jm 3:i32[] + js:i32[] = select_n jq jm jr + jt:i32[] = convert_element_type[new_dtype=int32 weak_type=False] jp + ju:i32[] = convert_element_type[new_dtype=int32 weak_type=False] js + jv:i32[1] = broadcast_in_dim jt + jw:i32[1] = broadcast_in_dim ju + jx:i32[2] = concatenate[dimension=0] jv jw + jy:f32[3,3] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] jd jx jl + jz:i32[] = add jc 1:i32[] + ka:bool[] = lt jz 0:i32[] + kb:i32[] = add jz 3:i32[] + kc:i32[] = select_n ka jz kb + kd:bool[] = lt jc 0:i32[] + ke:i32[] = add jc 3:i32[] + kf:i32[] = select_n kd jc ke + kg:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kc + kh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] kf + ki:i32[1] = broadcast_in_dim kg + kj:i32[1] = broadcast_in_dim kh + kk:i32[2] = concatenate[dimension=0] ki kj + kl:f32[3,3] = scatter[ + dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1), operand_batching_dims=(), scatter_indices_batching_dims=()) + indices_are_sorted=True + mode=GatherScatterMode.FILL_OR_DROP + unique_indices=True + update_consts=() + update_jaxpr=None + ] jy kk jl + in (je, kl) } + length=2 + linear=(False, False, False) + num_carry=2 + num_consts=1 + reverse=False + unroll=1 + ] hz 0:i32[] id + _:f32[3] km:f32[3,3] = jit[ + name=eigh + jaxpr={ lambda ; ja:f32[3,3]. let + kn:f32[3,3] = transpose[permutation=(1, 0)] ja + ko:f32[3,3] = add ja kn + kp:f32[3,3] = div ko 2.0:f32[] + km:f32[3,3] _:f32[3] = eigh[ + algorithm=None + lower=True + sort_eigenvalues=True + subset_by_index=None + ] kp + in (_, km) } + ] ja + kq:f32[3,1] = slice[limit_indices=(3, 1) start_indices=(0, 0) strides=None] km + kr:f32[3] = squeeze[dimensions=(1,)] kq + ks:bool[] = lt bp 0:i32[] + kt:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bp + ku:i32[] = add kt 4:i32[] + kv:i32[] = select_n ks bp ku + kw:f32[1] = dynamic_slice[slice_sizes=(1,)] bs kv + kx:f32[] = squeeze[dimensions=(0,)] kw + ky:i32[] = sub bp 1:i32[] + kz:bool[] = lt ky 0:i32[] + la:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ky + lb:i32[] = add la 3:i32[] + lc:i32[] = select_n kz ky lb + ld:f32[1] = dynamic_slice[slice_sizes=(1,)] kr lc + le:f32[] = squeeze[dimensions=(0,)] ld + lf:f32[] = abs le + lg:f32[] = mul kx lf + _:bool[] = lt lg 9.999999974752427e-07:f32[] + lh:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bp + li:bool[3] = lt g lh + lj:f32[3] = jit[ + name=_where + jaxpr={ lambda ; li:bool[3] lk:f32[] ll:f32[]. let + lm:f32[3] = broadcast_in_dim lk + ln:f32[3] = broadcast_in_dim ll + lj:f32[3] = select_n li ln lm + in (lj,) } + ] li 1.0:f32[] 0.0:f32[] + lo:f32[3] = mul kr lj + lp:f32[3,2] = slice[limit_indices=(3, 2) start_indices=(0, 0) strides=None] bq + lq:f32[2] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] lo lp + lr:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] lq lq + ls:f32[] = sqrt lr + lt:bool[] = gt ls 9.999999960041972e-13:f32[] + lu:i32[] = convert_element_type[new_dtype=int32 weak_type=False] lt + lv:f32[2] = cond[ + branches=( + { lambda ; lw:f32[] lx:f32[2] ly:f32[2] lz:f32[]. let in (ly,) } + { lambda ; ma:f32[] mb:f32[2] mc:f32[2] md:f32[]. let + me:bool[] = gt ma 0.0:f32[] + mf:f32[] = jit[name=_where jaxpr=_where] me ma 1.0:f32[] + mg:f32[] = div 1.0:f32[] mf + mh:f32[] = jit[name=_where jaxpr=_where] me mg 0.0:f32[] + mi:f32[2] = mul mh mb + in (mi,) } + ) + ] lu ls lq v 0.0:f32[] + mj:f32[2] = dot_general[ + dimension_numbers=(([1], [0]), ([], [])) + preferred_element_type=float32 + ] e lv + mk:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] lv mj + ml:f32[] = dot_general[ + dimension_numbers=(([0], [0]), ([], [])) + preferred_element_type=float32 + ] lv lv + mm:f32[] = div mk ml + in (mm,) } \ No newline at end of file diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py new file mode 100644 index 0000000..af826d8 --- /dev/null +++ b/tests/functional/test_functional.py @@ -0,0 +1,243 @@ +import importlib + +import numpy as np +import pytest + + +def _ctx(dtype=np.float64, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def _quadratic_problem(ctx): + sc = importlib.import_module("spacecore") + dom = sc.VectorSpace((2,), ctx) + Q = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 4.0]]), dom, dom, ctx) + linear = sc.InnerProductFunctional(ctx.asarray([1.0, -1.0]), dom, ctx) + return sc.LinOpQuadraticForm(Q, linear, 3.0, ctx) + + +def test_explicit_context_overrides_inferred_contexts(): + sc = importlib.import_module("spacecore") + inferred = _ctx(np.float32, enable_checks=True) + explicit = _ctx(np.float64, enable_checks=False) + dom = sc.VectorSpace((2,), inferred) + Q = sc.DenseLinOp(inferred.asarray([[1.0, 0.0], [0.0, 1.0]]), dom, dom, inferred) + linear = sc.InnerProductFunctional(inferred.asarray([1.0, 2.0]), dom) + + functional = sc.InnerProductFunctional(inferred.asarray([1.0, 2.0]), dom, explicit) + quadratic = sc.LinOpQuadraticForm(Q, linear, 0.0, explicit) + + assert functional.ctx == explicit + assert functional.dtype == np.dtype(np.float64) + assert functional.domain.ctx == explicit + assert quadratic.ctx == explicit + assert quadratic.Q.ctx == explicit + assert quadratic.linear.ctx == explicit + + +def test_domain_conversion_and_membership_checks_work(): + sc = importlib.import_module("spacecore") + source = _ctx(np.float32, enable_checks=True) + explicit = _ctx(np.float64, enable_checks=True) + dom = sc.VectorSpace((2,), source) + functional = sc.InnerProductFunctional(source.asarray([1.0, 2.0]), dom, explicit) + + assert functional.ctx == explicit + assert functional.dtype == np.dtype(np.float64) + assert functional.domain.ctx == explicit + assert functional.domain.ctx.enable_checks is True + assert np.allclose(functional.value(functional.domain.ctx.asarray([3.0, 4.0])), 11.0) + with pytest.raises(Exception): + functional.value(explicit.asarray([1.0, 2.0, 3.0])) + + +def test_call_matches_value(): + ctx = _ctx() + q = _quadratic_problem(ctx) + x = ctx.asarray([2.0, -1.0]) + assert np.allclose(q(x), q.value(x)) + + +def test_inner_product_functional_matches_domain_inner(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + c = ctx.asarray([1.0, -2.0]) + x = ctx.asarray([3.0, 4.0]) + functional = sc.InnerProductFunctional(c, dom, ctx) + + assert np.allclose(functional.value(x), dom.inner(c, x)) + + +def test_matrix_free_linear_functional_has_no_representer(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + c = ctx.asarray([2.0, 3.0]) + x = ctx.asarray([4.0, 5.0]) + functional = sc.MatrixFreeLinearFunctional(lambda y: dom.inner(c, y), dom, ctx) + + assert np.allclose(functional.value(x), 23.0) + with pytest.raises(NotImplementedError): + functional.representer + + +def test_linear_functional_compose_specializes_to_inner_product_functional(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]), X, Y, ctx) + c = ctx.asarray([2.0, -1.0, 0.5]) + F = sc.InnerProductFunctional(c, Y, ctx) + pullback = F.compose(A) + x = ctx.asarray([4.0, -2.0]) + + assert isinstance(pullback, sc.InnerProductFunctional) + np.testing.assert_allclose(pullback.representer, A.H.apply(c)) + np.testing.assert_allclose(pullback.value(x), F.value(A.apply(x))) + + +def test_quadratic_form_compose_specializes_quadratic_and_linear_terms(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]), X, Y, ctx) + Q = sc.IdentityLinOp(Y, ctx) + linear = sc.InnerProductFunctional(ctx.asarray([1.0, -2.0, 0.5]), Y, ctx) + F = sc.LinOpQuadraticForm(Q, linear, 1.25, ctx) + pullback = F.compose(A) + x = ctx.asarray([0.5, -1.5]) + + assert isinstance(pullback, sc.LinOpQuadraticForm) + np.testing.assert_allclose(pullback.value(x), F.value(A.apply(x))) + np.testing.assert_allclose(pullback.grad(x), A.H.apply(F.grad(A.apply(x)))) + + +def test_generic_functional_compose_forwards_value(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((2,), ctx) + A = sc.DiagonalLinOp(ctx.asarray([2.0, -1.0]), X, ctx) + + class SumSquares(sc.Functional): + def value(self, x): + return self.ops.sum(x * x) + + def tree_flatten(self): + return (), (self.domain, self.ctx) + + @classmethod + def tree_unflatten(cls, aux, children): + domain, ctx = aux + return cls(domain, ctx) + + def _convert(self, new_ctx): + return SumSquares(self.domain.convert(new_ctx), new_ctx) + + F = SumSquares(Y, ctx) + pullback = F.compose(A) + x = ctx.asarray([3.0, 4.0]) + + assert isinstance(pullback, sc.ComposedFunctional) + np.testing.assert_allclose(pullback.value(x), F.value(A.apply(x))) + + +def test_functional_compose_rejects_incompatible_codomain(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + A = sc.IdentityLinOp(X, ctx) + F = sc.InnerProductFunctional(ctx.asarray([1.0, 2.0, 3.0]), Y, ctx) + + with pytest.raises(ValueError, match="A.codomain == F.domain"): + F.compose(A) + + +def test_linop_quadratic_value_and_gradient_match_euclidean_hand_computation(): + ctx = _ctx() + q = _quadratic_problem(ctx) + x = ctx.asarray([2.0, -1.0]) + + assert np.allclose(q.value(x), 12.0) + assert np.allclose(q.grad(x), [5.0, -5.0]) + assert np.allclose(q.hess_apply(x), [4.0, -4.0]) + + +def test_linop_quadratic_form_hermitian_gradient_is_q_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + Q = sc.DenseLinOp(ctx.asarray([[2.0, 1.0], [1.0, 4.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(Q, ctx=ctx) + x = ctx.asarray([2.0, -1.0]) + + np.testing.assert_allclose(q.grad(x), Q.apply(x)) + np.testing.assert_allclose(q.hess_apply(x), Q.apply(x)) + + +def test_linop_quadratic_form_rejects_non_hermitian_dense_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + Q = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [0.0, 3.0]]), space, space, ctx) + + with pytest.raises(ValueError, match="Hermitian"): + sc.LinOpQuadraticForm(Q, ctx=ctx) + + +def test_linop_quadratic_form_does_not_validate_matrix_free_hermitian_assumption(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + + def apply(x): + return ctx.asarray([x[0] + 2.0 * x[1], 3.0 * x[1]]) + + def rapply(y): + return ctx.asarray([y[0], 2.0 * y[0] + 3.0 * y[1]]) + + Q = sc.MatrixFreeLinOp(apply, rapply, space, space, ctx) + q = sc.LinOpQuadraticForm(Q, ctx=ctx) + x = ctx.asarray([1.0, 2.0]) + + np.testing.assert_allclose(q.grad(x), Q.apply(x)) + + +def test_linop_quadratic_form_always_rejects_nonscalar_constant(): + sc = importlib.import_module("spacecore") + ctx = _ctx(enable_checks=False) + space = sc.VectorSpace((2,), ctx) + Q = sc.IdentityLinOp(space, ctx) + + with pytest.raises(ValueError, match="scalar batch"): + sc.LinOpQuadraticForm(Q, a=ctx.asarray([0.0, 0.0]), ctx=ctx) + + +def test_vvalue_and_vgrad_match_elementwise_loops(): + ctx = _ctx() + q = _quadratic_problem(ctx) + xs = ctx.asarray([[2.0, -1.0], [0.0, 3.0], [1.5, 2.0]]) + + expected_values = ctx.ops.stack(tuple(q.value(x) for x in xs), axis=0) + expected_grads = ctx.ops.stack(tuple(q.grad(x) for x in xs), axis=0) + + assert np.allclose(q.vvalue(xs), expected_values) + assert np.allclose(q.vgrad(xs), expected_grads) + + +def test_bad_shapes_raise_when_checks_are_enabled(): + ctx = _ctx(enable_checks=True) + q = _quadratic_problem(ctx) + bad = ctx.asarray([1.0, 2.0, 3.0]) + + with pytest.raises(Exception): + q.value(bad) + with pytest.raises(Exception): + q.grad(bad) + with pytest.raises(Exception): + q.vvalue(ctx.asarray([[1.0, 2.0, 3.0]])) diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index 26ceb65..78d59b9 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -2,7 +2,7 @@ import tomllib from pathlib import Path -from tests._helpers import has_jax, has_torch +from tests._helpers import has_cupy, has_jax, has_torch ROOT = Path(__file__).resolve().parents[2] @@ -19,15 +19,25 @@ def test_expected_names_are_exported(): sc = importlib.import_module("spacecore") expected = { "Context", "BackendOps", "NumpyOps", "DenseLinOp", "SparseLinOp", + "ScaledLinOp", "SumLinOp", "ComposedLinOp", "ZeroLinOp", + "IdentityLinOp", "MatrixFreeLinOp", "make_sum", "make_scaled", + "make_composed", "BlockDiagonalLinOp", "StackedLinOp", "SumToSingleLinOp", + "Functional", "LinearFunctional", "InnerProductFunctional", + "ComposedFunctional", "MatrixFreeLinearFunctional", "QuadraticForm", + "LinOpQuadraticForm", "make_functional_composed", "VectorSpace", "HermitianSpace", "ProductSpace", "Space", "DenseArray", "SparseArray", "ArrayLike", "set_context", "get_context", "resolve_context_priority", "register_ops", "set_resolution_policy", "set_dtype_resolution_policy", "get_resolution_policy", "get_dtype_resolution_policy", + "LanczosResult", "lanczos_smallest", + "ExpmMultiplyResult", "expm_multiply", } if has_jax(): expected |= {"JaxOps", "jax_pytree_class"} + if has_cupy(): + expected |= {"CuPyOps"} if has_torch(): expected |= {"TorchOps"} assert expected.issubset(set(sc.__all__)) @@ -38,17 +48,27 @@ def test_top_level_objects_match_source_modules(): backend = importlib.import_module("spacecore.backend") space = importlib.import_module("spacecore.space") linop = importlib.import_module("spacecore.linop") - manager = importlib.import_module("spacecore._contextual.manager") + functional = importlib.import_module("spacecore.functional") + linalg = importlib.import_module("spacecore.linalg") + contextual = importlib.import_module("spacecore._contextual") assert sc.Context is backend.Context assert sc.NumpyOps is backend.NumpyOps + if has_cupy(): + assert sc.CuPyOps is backend.CuPyOps if has_torch(): assert sc.TorchOps is backend.TorchOps assert sc.Space is space.Space assert sc.VectorSpace is space.VectorSpace assert sc.DenseLinOp is linop.DenseLinOp - assert sc.get_context is manager.get_context - assert sc.resolve_context_priority is manager.resolve_context_priority + assert sc.Functional is functional.Functional + assert sc.ComposedFunctional is functional.ComposedFunctional + assert sc.InnerProductFunctional is functional.InnerProductFunctional + assert sc.LanczosResult is linalg.LanczosResult + assert sc.ExpmMultiplyResult is linalg.ExpmMultiplyResult + assert sc.expm_multiply is linalg.expm_multiply + assert sc.get_context is contextual.get_context + assert sc.resolve_context_priority is contextual.resolve_context_priority def test_package_version_matches_project_metadata(): diff --git a/tests/linalg/__init__.py b/tests/linalg/__init__.py new file mode 100644 index 0000000..cfe7516 --- /dev/null +++ b/tests/linalg/__init__.py @@ -0,0 +1 @@ +"""Tests for SpaceCore linear algebra routines.""" diff --git a/tests/linalg/test_expm.py b/tests/linalg/test_expm.py new file mode 100644 index 0000000..55dc504 --- /dev/null +++ b/tests/linalg/test_expm.py @@ -0,0 +1,192 @@ +import importlib + +import numpy as np +import pytest +import scipy.linalg + +from tests._helpers import has_jax, has_torch, jax_complex_dtype, jax_real_dtype, to_numpy +from tests._helpers import torch_real_dtype + + +def _ops_for_backend(name): + sc = importlib.import_module("spacecore") + if name == "numpy": + return sc.NumpyOps() + if name == "jax": + return sc.JaxOps() + if name == "torch": + return sc.TorchOps() + raise ValueError(f"Unknown backend {name!r}.") + + +def _backend_params(): + return [ + pytest.param("numpy", np.float64, id="numpy"), + pytest.param( + "jax", + jax_real_dtype(), + marks=pytest.mark.skipif(not has_jax(), reason="jax is not installed"), + id="jax", + ), + pytest.param( + "torch", + torch_real_dtype(), + marks=pytest.mark.skipif(not has_torch(), reason="torch is not installed"), + id="torch", + ), + ] + + +def _ctx(backend_name="numpy", dtype=np.float64, enable_checks=False): + sc = importlib.import_module("spacecore") + return sc.Context(_ops_for_backend(backend_name), dtype=dtype, enable_checks=enable_checks) + + +def _operator(ctx, matrix): + sc = importlib.import_module("spacecore") + space = sc.VectorSpace((matrix.shape[0],), ctx) + return sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + +def _ground_truth(matrix, vector, t): + return scipy.linalg.expm(t * matrix) @ vector + + +def test_expm_multiply_t_zero_returns_input(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + matrix = np.array([[2.0, 0.5], [0.5, 3.0]]) + A = _operator(ctx, matrix) + v = ctx.asarray([1.0, -2.0]) + + result = sc.expm_multiply(A, v, t=0.0, max_iter=4) + + np.testing.assert_allclose(result.result, v, rtol=1e-12, atol=1e-12) + assert isinstance(result, sc.ExpmMultiplyResult) + + +def test_expm_multiply_rejects_structurally_non_hermitian_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _operator(ctx, np.array([[1.0, 2.0], [0.0, 3.0]])) + v = ctx.asarray([1.0, -2.0]) + + with pytest.raises(ValueError, match="Hermitian"): + sc.expm_multiply(A, v, t=0.1, max_iter=4) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_expm_multiply_matches_dense_ground_truth(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + matrix = np.array([[1.0, 0.25, 0.0], [0.25, 2.0, -0.5], [0.0, -0.5, 3.0]]) + A = _operator(ctx, matrix) + v_np = np.array([1.0, -2.0, 0.5]) + v = ctx.asarray(v_np) + + result = sc.expm_multiply(A, v, t=-0.2, max_iter=8, tol=1e-12) + + np.testing.assert_allclose( + to_numpy(result.result), + _ground_truth(matrix, v_np, -0.2), + rtol=1e-5, + atol=1e-5, + ) + assert bool(to_numpy(result.converged)) + + +def test_expm_multiply_is_linear_in_vector(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _operator(ctx, np.array([[1.0, 0.5], [0.5, 2.0]])) + v1 = ctx.asarray([1.0, -1.0]) + v2 = ctx.asarray([0.5, 2.0]) + alpha = 1.5 + beta = -0.25 + + combined = sc.expm_multiply(A, alpha * v1 + beta * v2, t=0.3, max_iter=6).result + expected = ( + alpha * sc.expm_multiply(A, v1, t=0.3, max_iter=6).result + + beta * sc.expm_multiply(A, v2, t=0.3, max_iter=6).result + ) + + np.testing.assert_allclose(combined, expected, rtol=1e-10, atol=1e-10) + + +def test_expm_multiply_group_property(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _operator(ctx, np.array([[1.0, 0.5], [0.5, 2.0]])) + v = ctx.asarray([1.0, -1.0]) + t1 = 0.2 + t2 = -0.35 + + first = sc.expm_multiply(A, v, t=t1, max_iter=6).result + sequential = sc.expm_multiply(A, first, t=t2, max_iter=6).result + direct = sc.expm_multiply(A, v, t=t1 + t2, max_iter=6).result + + np.testing.assert_allclose(sequential, direct, rtol=1e-10, atol=1e-10) + + +def test_expm_multiply_complex_time_is_unitary_for_hermitian_generator(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + matrix = np.array([[1.0, 0.5 - 0.25j], [0.5 + 0.25j, 2.0]], dtype=np.complex128) + A = _operator(ctx, matrix) + v = ctx.asarray([1.0 + 0.5j, -0.25 + 0.75j]) + + result = sc.expm_multiply(A, v, t=-0.5j, max_iter=6, tol=1e-12).result + + np.testing.assert_allclose( + to_numpy(A.domain.norm(result)), + to_numpy(A.domain.norm(v)), + rtol=1e-10, + atol=1e-10, + ) + + +def test_expm_multiply_complex_time_matches_dense_truth(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + matrix = np.array([[1.0, 0.25], [0.25, 2.0]], dtype=np.complex128) + A = _operator(ctx, matrix) + v_np = np.array([1.0 - 0.5j, 0.25 + 0.75j]) + v = ctx.asarray(v_np) + + result = sc.expm_multiply(A, v, t=-0.5j, max_iter=6, tol=1e-12) + + np.testing.assert_allclose( + to_numpy(result.result), + _ground_truth(matrix, v_np, -0.5j), + rtol=1e-10, + atol=1e-10, + ) + + +def test_expm_multiply_residual_estimate_decreases_with_more_iterations(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + matrix = np.array([[1.0, 0.25, 0.1], [0.25, 2.0, -0.5], [0.1, -0.5, 4.0]]) + A = _operator(ctx, matrix) + v = ctx.asarray([1.0, -2.0, 0.5]) + + low = sc.expm_multiply(A, v, t=0.4, max_iter=1, tol=1e-12) + high = sc.expm_multiply(A, v, t=0.4, max_iter=3, tol=1e-12) + + assert to_numpy(high.residual_estimate) <= to_numpy(low.residual_estimate) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_expm_multiply_jit_matches_eager(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_complex_dtype()) + matrix = np.array([[1.0, 0.25], [0.25, 2.0]], dtype=np.complex128) + A = _operator(ctx, matrix) + v = ctx.asarray([1.0 - 0.5j, 0.25 + 0.75j]) + + eager = sc.expm_multiply(A, v, t=-0.5j, max_iter=6).result + run = jax.jit(lambda op, x: sc.expm_multiply(op, x, t=-0.5j, max_iter=6).result) + compiled = run(A, v) + + np.testing.assert_allclose(to_numpy(compiled), to_numpy(eager), rtol=1e-6, atol=1e-6) diff --git a/tests/linalg/test_krylov.py b/tests/linalg/test_krylov.py new file mode 100644 index 0000000..a85c2fe --- /dev/null +++ b/tests/linalg/test_krylov.py @@ -0,0 +1,524 @@ +import importlib +import inspect + +import numpy as np +import pytest + +from tests._helpers import has_cupy, has_jax, has_torch, jax_real_dtype, to_numpy +from tests._helpers import torch_real_dtype + + +def _backend_params(): + return [ + pytest.param("numpy", np.float64, id="numpy"), + pytest.param( + "jax", + jax_real_dtype(), + marks=pytest.mark.skipif(not has_jax(), reason="jax is not installed"), + id="jax", + ), + pytest.param( + "torch", + torch_real_dtype(), + marks=pytest.mark.skipif(not has_torch(), reason="torch is not installed"), + id="torch", + ), + pytest.param( + "cupy", + np.float64, + marks=pytest.mark.skipif(not has_cupy(), reason="cupy is not installed"), + id="cupy", + ), + ] + + +def _ops_for_backend(name): + sc = importlib.import_module("spacecore") + if name == "numpy": + return sc.NumpyOps() + if name == "jax": + return sc.JaxOps() + if name == "torch": + return sc.TorchOps() + if name == "cupy": + return sc.CuPyOps() + raise ValueError(f"Unknown backend {name!r}.") + + +def _ctx(backend_name="numpy", dtype=np.float64): + sc = importlib.import_module("spacecore") + return sc.Context(_ops_for_backend(backend_name), dtype=dtype, enable_checks=False) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_cg_solves_spd_system(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + space = sc.VectorSpace((2,), ctx) + A = sc.DenseLinOp(ctx.asarray([[4.0, 1.0], [1.0, 3.0]]), space, space, ctx) + b = ctx.asarray([1.0, 2.0]) + + result = sc.cg(A, b, tol=1e-7, maxiter=10) + + np.testing.assert_allclose( + to_numpy(result.x), + np.linalg.solve(np.array([[4.0, 1.0], [1.0, 3.0]]), np.array([1.0, 2.0])), + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose(to_numpy(A.apply(result.x)), to_numpy(b), rtol=1e-5, atol=1e-5) + assert bool(to_numpy(result.converged)) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_lsqr_solves_rectangular_least_squares(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + matrix = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + A = sc.DenseLinOp(ctx.asarray(matrix), domain, codomain, ctx) + b = ctx.asarray([1.0, 2.0, 4.0]) + + result = sc.lsqr(A, b, tol=1e-7, maxiter=10) + + expected, *_ = np.linalg.lstsq(matrix, np.array([1.0, 2.0, 4.0]), rcond=None) + np.testing.assert_allclose(to_numpy(result.x), expected, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(to_numpy(A.H.apply(A.apply(result.x) - b)), [0.0, 0.0], atol=1e-5) + assert bool(to_numpy(result.converged)) + + +def test_lsqr_works_with_matrix_free_linop_and_uses_rapply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + matrix = ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + calls = {"rapply": 0} + + def apply(x): + return matrix @ x + + def rapply(y): + calls["rapply"] += 1 + return matrix.T @ y + + A = sc.MatrixFreeLinOp(apply, rapply, domain, codomain, ctx) + b = ctx.asarray([1.0, 2.0, 3.0]) + + result = sc.lsqr(A, b, tol=1e-8, maxiter=10) + + np.testing.assert_allclose(result.x, [1.0, 2.0], rtol=1e-6, atol=1e-6) + assert calls["rapply"] > 0 + + +def test_cg_solves_complex_hermitian_positive_definite_system(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + space = sc.VectorSpace((2,), ctx) + matrix = np.array([[4.0, 1.0 + 1.0j], [1.0 - 1.0j, 3.0]], dtype=np.complex128) + A = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + b = ctx.asarray([1.0 + 2.0j, 3.0 - 1.0j]) + + result = sc.cg(A, b, tol=1e-10, maxiter=10) + + np.testing.assert_allclose(to_numpy(result.x), np.linalg.solve(matrix, to_numpy(b)), rtol=1e-8) + np.testing.assert_allclose(to_numpy(A.apply(result.x)), to_numpy(b), rtol=1e-8, atol=1e-8) + assert bool(to_numpy(result.converged)) + + +def test_lsqr_solves_complex_least_squares(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + matrix = np.array( + [[1.0 + 1.0j, 0.0], [0.0, 2.0 - 1.0j], [1.0, 1.0j]], + dtype=np.complex128, + ) + A = sc.DenseLinOp(ctx.asarray(matrix), domain, codomain, ctx) + b = ctx.asarray([1.0 - 1.0j, 2.0 + 0.5j, 3.0j]) + + result = sc.lsqr(A, b, tol=1e-10, maxiter=20) + + expected, *_ = np.linalg.lstsq(matrix, to_numpy(b), rcond=None) + np.testing.assert_allclose(to_numpy(result.x), expected, rtol=1e-7, atol=1e-7) + np.testing.assert_allclose( + to_numpy(A.H.apply(A.apply(result.x) - b)), + np.zeros(2, dtype=np.complex128), + atol=1e-7, + ) + assert bool(to_numpy(result.converged)) + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_power_iteration_estimates_dominant_eigenpair(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + space = sc.VectorSpace((2,), ctx) + A = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + x0 = ctx.asarray([1.0, 1.0]) + + result = sc.power_iteration(A, x0=x0, tol=1e-5, maxiter=60) + + np.testing.assert_allclose(to_numpy(result.eigenvalue), 5.0, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + np.abs(to_numpy(result.eigenvector)), + [0.0, 1.0], + rtol=1e-4, + atol=1e-4, + ) + assert bool(to_numpy(result.converged)) + + +def test_power_iteration_accepts_quadratic_form_hessian_action(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + x0 = ctx.asarray([1.0, 1.0]) + + op_result = sc.power_iteration(op, x0=x0, tol=1e-5, maxiter=60) + q_result = sc.power_iteration(q, x0=x0, tol=1e-5, maxiter=60) + + np.testing.assert_allclose(to_numpy(q_result.eigenvalue), to_numpy(op_result.eigenvalue)) + np.testing.assert_allclose( + np.abs(to_numpy(q_result.eigenvector)), + np.abs(to_numpy(op_result.eigenvector)), + rtol=1e-6, + atol=1e-6, + ) + + +def test_power_iteration_dispatches_quadratic_form_before_core(monkeypatch): + sc = importlib.import_module("spacecore") + power_mod = importlib.import_module("spacecore.linalg._power") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + x0 = ctx.asarray([1.0, 0.0]) + captured = {} + + def fake_core(action, x, tol, maxiter, check_every): + captured["action"] = action + captured["x"] = x + return ctx.asarray(0.0), x, ctx.asarray(True), 0, ctx.asarray(0.0) + + monkeypatch.setattr(power_mod, "_power_iteration_core", fake_core) + result = power_mod.power_iteration(q, x0=x0, maxiter=1) + + assert result.eigenvector is x0 + assert isinstance(captured["action"], power_mod._SelfAdjointAction) + assert captured["action"].domain == q.domain + x = ctx.asarray([1.0, 2.0]) + np.testing.assert_allclose(captured["action"].apply(x), q.hess_apply(x)) + + +def test_power_iteration_core_has_no_dispatch_logic(): + power_mod = importlib.import_module("spacecore.linalg._power") + source = inspect.getsource(power_mod._power_iteration_core) + + assert "isinstance" not in source + assert "hasattr" not in source + assert "getattr" not in source + assert "_SelfAdjointAction(" not in source + assert "PowerIterationResult(" not in source + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +def test_lanczos_smallest_approximates_smallest_eigenpair(backend_name, dtype): + sc = importlib.import_module("spacecore") + ctx = _ctx(backend_name, dtype) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + initial = ctx.asarray([1.0, 1.0]) + + result = sc.lanczos_smallest( + op, + initial, + max_iter=2, + tol=1e-8, + ) + + np.testing.assert_allclose(to_numpy(result.eigenvalue), 2.0, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + np.abs(to_numpy(result.eigenvector)), + [1.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + assert bool(to_numpy(result.converged)) + np.testing.assert_allclose(to_numpy(result.residual_norm), 0.0, atol=1e-5) + assert int(to_numpy(result.krylov_dim)) == 2 + + +def test_lanczos_smallest_returns_result_object(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + + result = sc.lanczos_smallest(op, ctx.asarray([1.0, 1.0]), max_iter=2, tol=1e-8) + + assert isinstance(result, sc.LanczosResult) + np.testing.assert_allclose(result.eigenvalue, 2.0) + + +def test_lanczos_smallest_uses_e0_for_zero_initial_vector(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + initial = ctx.asarray([0.0, 0.0]) + + result = sc.lanczos_smallest( + op, + initial, + max_iter=2, + tol=1e-8, + ) + + np.testing.assert_allclose(result.eigenvalue, 2.0, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(result.eigenvector, [1.0, 0.0], rtol=1e-6, atol=1e-6) + + +def test_lanczos_smallest_rejects_invalid_max_iter(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((1,), ctx) + op = sc.IdentityLinOp(space, ctx) + + with pytest.raises(ValueError, match="max_iter"): + sc.lanczos_smallest(op, ctx.asarray([1.0]), max_iter=0) + + +def test_lanczos_smallest_rejects_structurally_non_hermitian_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [0.0, 3.0]]), space, space, ctx) + + with pytest.raises(ValueError, match="Hermitian"): + sc.lanczos_smallest(op, ctx.asarray([1.0, 1.0]), max_iter=2) + + +def test_lanczos_smallest_handles_eigenvalues_larger_than_1e10(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + matrix = np.diag([2.0e12, 3.0e12]) + op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + result = sc.lanczos_smallest( + op, + ctx.asarray([1.0, 1.0]), + max_iter=4, + tol=1e-8, + ) + + np.testing.assert_allclose(to_numpy(result.eigenvalue), 2.0e12, rtol=1e-6) + np.testing.assert_allclose(np.abs(to_numpy(result.eigenvector)), [1.0, 0.0], atol=1e-5) + + +def test_lanczos_smallest_handles_complex_hermitian_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.complex128) + space = sc.VectorSpace((2,), ctx) + matrix = np.array([[2.0, 1.0 + 2.0j], [1.0 - 2.0j, 5.0]], dtype=np.complex128) + op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + result = sc.lanczos_smallest( + op, + ctx.asarray([1.0 + 0.0j, 1.0j]), + max_iter=2, + tol=1e-10, + ) + + expected = np.linalg.eigvalsh(matrix)[0] + np.testing.assert_allclose(to_numpy(result.eigenvalue), expected, rtol=1e-7, atol=1e-7) + np.testing.assert_allclose( + to_numpy(op.apply(result.eigenvector)), + to_numpy(result.eigenvalue) * to_numpy(result.eigenvector), + rtol=1e-6, + atol=1e-6, + ) + + +def test_lanczos_smallest_uses_domain_geometry_for_weighted_inner_product(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + + class WeightedVectorSpace(sc.VectorSpace): + def __init__(self, weights, ctx): + weights = ctx.asarray(weights) + super().__init__(tuple(weights.shape), ctx) + self.weights = weights + + def inner(self, x, y): + if self._enable_checks: + self._check_member(x) + self._check_member(y) + return self.ops.vdot(x, self.weights * y) + + def _convert(self, new_ctx): + return WeightedVectorSpace(new_ctx.asarray(self.weights), new_ctx) + + space = WeightedVectorSpace([1.0, 4.0], ctx) + assert type(space) is not sc.VectorSpace + + matrix = np.array([[2.0, 1.0], [0.25, 0.75]]) + op = sc.DenseLinOp(ctx.asarray(matrix), space, space, ctx) + + result = sc.lanczos_smallest( + op, + ctx.asarray([1.0, 1.0]), + max_iter=2, + tol=1e-12, + ) + + expected = np.min(np.linalg.eigvals(matrix).real) + np.testing.assert_allclose(to_numpy(result.eigenvalue), expected, rtol=1e-7, atol=1e-7) + np.testing.assert_allclose( + to_numpy(op.apply(result.eigenvector)), + to_numpy(result.eigenvalue) * to_numpy(result.eigenvector), + rtol=1e-6, + atol=1e-6, + ) + + +def test_safe_inverse_nonneg_returns_reciprocal_for_positive_values_only(): + sc = importlib.import_module("spacecore") + utils = importlib.import_module("spacecore.linalg._utils") + ctx = _ctx() + + values = ctx.asarray([-2.0, 0.0, 4.0]) + + np.testing.assert_allclose(to_numpy(utils.safe_inverse_nonneg(sc.NumpyOps(), values)), [0.0, 0.0, 0.25]) + + +def test_iterative_solvers_poll_convergence_on_check_interval(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + spd = sc.DenseLinOp(ctx.asarray([[4.0, 1.0], [1.0, 3.0]]), space, space, ctx) + rectangular = sc.DenseLinOp( + ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), + space, + sc.VectorSpace((3,), ctx), + ctx, + ) + diagonal = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + + cg_result = sc.cg(spd, ctx.asarray([1.0, 2.0]), maxiter=65) + lsqr_result = sc.lsqr(rectangular, ctx.asarray([1.0, 2.0, 3.0]), maxiter=65) + power_result = sc.power_iteration(diagonal, x0=ctx.asarray([1.0, 1.0]), maxiter=65) + + assert cg_result.num_iters == 64 + assert lsqr_result.num_iters == 64 + assert power_result.num_iters == 64 + np.testing.assert_allclose(cg_result.residual_norm, 0.0, atol=1e-12) + np.testing.assert_allclose(lsqr_result.normal_residual_norm, 0.0, atol=1e-12) + np.testing.assert_allclose(power_result.residual_norm, 0.0, atol=1e-12) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_cg_jit_compiles_with_operator_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[4.0, 1.0], [1.0, 3.0]]), space, space, ctx) + + solve = jax.jit(lambda A, b: sc.cg(A, b, maxiter=10).x) + x = solve(op, ctx.asarray([1.0, 2.0])) + + np.testing.assert_allclose(to_numpy(x), [0.09090909, 0.63636364], rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_lsqr_jit_compiles_with_operator_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + op = sc.DenseLinOp(ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), domain, codomain, ctx) + + solve = jax.jit(lambda A, b: sc.lsqr(A, b, maxiter=10).x) + x = solve(op, ctx.asarray([1.0, 2.0, 4.0])) + + np.testing.assert_allclose(to_numpy(x), [1.33333333, 2.33333333], rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_power_iteration_jit_compiles_with_operator_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + + run = jax.jit(lambda A, x: sc.power_iteration(A, x0=x, maxiter=60).eigenvalue) + eigenvalue = run(op, ctx.asarray([1.0, 1.0])) + + np.testing.assert_allclose(to_numpy(eigenvalue), 5.0, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_power_iteration_jit_compiles_with_quadratic_form_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + + run = jax.jit(lambda quad, x: sc.power_iteration(quad, x0=x, maxiter=60).eigenvalue) + eigenvalue = run(q, ctx.asarray([1.0, 1.0])) + + np.testing.assert_allclose(to_numpy(eigenvalue), 5.0, rtol=1e-5, atol=1e-5) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_lanczos_smallest_jit_compiles_with_operator_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + + def run(A, initial): + result = sc.lanczos_smallest( + A, + initial, + max_iter=2, + tol=1e-8, + ) + return result.eigenvalue, result.eigenvector + + eigenvalue, eigenvector = jax.jit(run)(op, ctx.asarray([1.0, 1.0])) + + np.testing.assert_allclose(to_numpy(eigenvalue), 2.0, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose( + np.abs(to_numpy(eigenvector)), + [1.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + + +def test_cg_and_power_iteration_reject_rectangular_operator(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + domain = sc.VectorSpace((2,), ctx) + codomain = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), domain, codomain, ctx) + + with pytest.raises(ValueError, match="square LinOp"): + sc.cg(A, ctx.asarray([1.0, 2.0, 3.0])) + with pytest.raises(ValueError, match="square LinOp"): + sc.power_iteration(A) + with pytest.raises(ValueError, match="square LinOp"): + sc.lanczos_smallest(A, ctx.asarray([1.0, 2.0])) diff --git a/tests/linops/test_algebra.py b/tests/linops/test_algebra.py new file mode 100644 index 0000000..e8cccdb --- /dev/null +++ b/tests/linops/test_algebra.py @@ -0,0 +1,306 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_jax, has_torch, jax_complex_dtype, jax_real_dtype +from tests._helpers import to_numpy, torch_complex_dtype + + +def _backend_params(): + params = [pytest.param("numpy", np.complex128, id="numpy")] + params.append( + pytest.param( + "jax", + jax_complex_dtype(), + marks=pytest.mark.skipif(not has_jax(), reason="jax is not installed"), + id="jax", + ) + ) + params.append( + pytest.param( + "torch", + torch_complex_dtype(), + marks=pytest.mark.skipif(not has_torch(), reason="torch is not installed"), + id="torch", + ) + ) + return params + + +def _ops_for_backend(name): + sc = importlib.import_module("spacecore") + if name == "numpy": + return sc.NumpyOps() + if name == "jax": + return sc.JaxOps() + if name == "torch": + return sc.TorchOps() + raise ValueError(f"Unknown backend {name!r}.") + + +def _ctx(dtype=np.complex128, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def _spaces(ctx): + sc = importlib.import_module("spacecore") + return sc.VectorSpace((2,), ctx), sc.VectorSpace((3,), ctx) + + +def _matrix(): + return np.array( + [ + [1.0 + 2.0j, 3.0 - 1.0j], + [-2.0 + 0.5j, 0.25 + 4.0j], + [1.5 - 3.0j, -0.75 + 2.0j], + ] + ) + + +def _square_matrix(): + return np.array([[2.0 - 1.0j, -0.5 + 0.25j], [1.25 + 2.0j, -3.0 + 0.5j]]) + + +def _dense_linop(ctx): + sc = importlib.import_module("spacecore") + dom, cod = _spaces(ctx) + return sc.DenseLinOp(ctx.asarray(_matrix()), dom, cod, ctx) + + +def _dense_same_shape(ctx, scale=1.0): + sc = importlib.import_module("spacecore") + dom, cod = _spaces(ctx) + return sc.DenseLinOp(ctx.asarray(scale * _matrix()), dom, cod, ctx) + + +def _dense_square(ctx): + sc = importlib.import_module("spacecore") + dom = sc.VectorSpace((2,), ctx) + return sc.DenseLinOp(ctx.asarray(_square_matrix()), dom, dom, ctx) + + +def _xy(ctx): + x = ctx.asarray([2.0 - 1.0j, -0.5 + 0.25j]) + y = ctx.asarray([1.0 + 0.5j, -2.0j, 0.75 - 1.25j]) + return x, y + + +def _assert_adjoint_identity(op, x, y, ctx): + lhs = ctx.ops.vdot(op.apply(x), y) + rhs = ctx.ops.vdot(x, op.rapply(y)) + np.testing.assert_allclose(to_numpy(lhs), to_numpy(rhs), rtol=1e-6, atol=1e-6) + + +def _adjoint_cases(ctx): + sc = importlib.import_module("spacecore") + A = _dense_linop(ctx) + B = _dense_same_shape(ctx, scale=0.5 - 0.25j) + C = _dense_square(ctx) + dom, cod = _spaces(ctx) + x, y = _xy(ctx) + z = ctx.asarray([-1.0 + 0.5j, 2.0 - 0.25j]) + + matrix = ctx.asarray(_matrix()) + matrix_free = sc.MatrixFreeLinOp( + lambda v: matrix @ v, + lambda w: ctx.ops.conj(ctx.ops.transpose(matrix)) @ w, + dom, + cod, + ctx, + ) + + return [ + ((2.0 + 3.0j) * A, x, y), + (A + B, x, y), + (A @ C, z, y), + (sc.ZeroLinOp(dom, cod, ctx), x, y), + (sc.IdentityLinOp(dom, ctx), x, x), + (matrix_free, x, y), + (A.H, y, x), + ] + + +@pytest.mark.parametrize("backend_name,dtype", _backend_params()) +@pytest.mark.parametrize("case_index", range(7)) +def test_complex_adjoint_identity_for_algebra_classes(backend_name, dtype, case_index): + sc = importlib.import_module("spacecore") + ctx = sc.Context(_ops_for_backend(backend_name), dtype=dtype) + op, x, y = _adjoint_cases(ctx)[case_index] + + _assert_adjoint_identity(op, x, y, ctx) + + +def test_simplification_canonicalizations(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _dense_linop(ctx) + B = _dense_same_shape(ctx, scale=2.0) + C = _dense_same_shape(ctx, scale=-1.0) + Z = sc.ZeroLinOp(A.domain, A.codomain, ctx) + + assert sc.make_sum((A, Z)) is A + assert isinstance(sc.make_sum((Z, Z)), sc.ZeroLinOp) + assert sc.make_sum((A,)) is A + flattened = sc.make_sum((sc.make_sum((A, B)), C)) + assert isinstance(flattened, sc.SumLinOp) + assert flattened.parts == (A, B, C) + + scaled_zero = sc.make_scaled(0, A) + assert isinstance(scaled_zero, sc.ZeroLinOp) + assert scaled_zero.domain == A.domain + assert scaled_zero.codomain == A.codomain + assert sc.make_scaled(1, A) is A + assert sc.make_scaled(7.0, Z) is Z + folded = sc.make_scaled(2, sc.make_scaled(3, A)) + assert isinstance(folded, sc.ScaledLinOp) + assert folded.scalar == 6 + assert folded.op is A + + I_dom = sc.IdentityLinOp(A.domain, ctx) + I_cod = sc.IdentityLinOp(A.codomain, ctx) + assert sc.make_composed(I_cod, A) is A + assert sc.make_composed(A, I_dom) is A + + out = sc.VectorSpace((4,), ctx) + left_zero = sc.ZeroLinOp(A.codomain, out, ctx) + composed_zero = sc.make_composed(left_zero, A) + assert isinstance(composed_zero, sc.ZeroLinOp) + assert composed_zero.domain == A.domain + assert composed_zero.codomain == out + + +@pytest.mark.parametrize("case_index", range(7)) +def test_double_adjoint_view_returns_literal_original(case_index): + ctx = _ctx() + op, _, _ = _adjoint_cases(ctx)[case_index] + + assert op.H.H is op + + +def test_identity_linop_apply_is_literal_input_when_checks_disabled(): + sc = importlib.import_module("spacecore") + ctx = _ctx(enable_checks=False) + space = sc.VectorSpace((2,), ctx) + op = sc.IdentityLinOp(space, ctx) + x = ctx.asarray([1.0 + 2.0j, 3.0 - 4.0j]) + + assert op.apply(x) is x + assert op.rapply(x) is x + + +def test_identity_linop_apply_equals_input_when_checks_enabled(): + sc = importlib.import_module("spacecore") + ctx = _ctx(enable_checks=True) + space = sc.VectorSpace((2,), ctx) + op = sc.IdentityLinOp(space, ctx) + x = ctx.asarray([1.0 + 2.0j, 3.0 - 4.0j]) + + np.testing.assert_allclose(op.apply(x), x) + np.testing.assert_allclose(op.rapply(x), x) + + +def test_python_sum_starts_from_zero_and_accumulates_linops(): + ctx = _ctx() + A = _dense_same_shape(ctx, scale=1.0) + B = _dense_same_shape(ctx, scale=0.5) + C = _dense_same_shape(ctx, scale=-2.0) + x, _ = _xy(ctx) + + op = sum([A, B, C]) + expected = A.apply(x) + B.apply(x) + C.apply(x) + + np.testing.assert_allclose(op.apply(x), expected) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +@pytest.mark.parametrize("case_index", range(7)) +def test_jax_pytree_roundtrip_for_algebra_classes(case_index): + import jax + + ctx = _ctx() + op, _, _ = _adjoint_cases(ctx)[case_index] + leaves, treedef = jax.tree.flatten(op) + rebuilt = jax.tree.unflatten(treedef, leaves) + + assert rebuilt == op + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_jit_algebra_expression_matches_eager(): + import jax + + sc = importlib.import_module("spacecore") + ctx = sc.Context(sc.JaxOps(), dtype=jax_real_dtype(), enable_checks=False) + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), X, Y, ctx) + B = sc.DenseLinOp(ctx.asarray([[0.5, -1.0], [2.0, 1.0], [-0.5, 3.0]]), X, Y, ctx) + C = sc.DenseLinOp(ctx.asarray([[2.0, -1.0], [0.25, 1.5]]), X, X, ctx) + expr = (2 * A + B) @ C + x = ctx.asarray([1.0, -2.0]) + + apply_jit = jax.jit(lambda op, z: op.apply(z)) + + np.testing.assert_allclose(to_numpy(apply_jit(expr, x)), to_numpy(expr.apply(x))) + + +def test_factories_enforce_same_context_dtype(): + sc = importlib.import_module("spacecore") + ctx32 = sc.Context(sc.NumpyOps(), dtype=np.float32) + ctx64 = sc.Context(sc.NumpyOps(), dtype=np.float64) + X32 = sc.VectorSpace((2,), ctx32) + Y32 = sc.VectorSpace((2,), ctx32) + X64 = sc.VectorSpace((2,), ctx64) + Y64 = sc.VectorSpace((2,), ctx64) + A32 = sc.DenseLinOp(ctx32.asarray([[1.0, 2.0], [3.0, 4.0]]), X32, Y32, ctx32) + A64 = sc.DenseLinOp(ctx64.asarray([[1.0, 2.0], [3.0, 4.0]]), X64, Y64, ctx64) + + with pytest.raises(ValueError, match="same ctx"): + sc.make_sum((A32, A64)) + with pytest.raises(ValueError, match="same ctx"): + sc.make_composed(A32, A64) + + +def test_factories_ignore_enable_checks_when_context_dtype_matches(): + sc = importlib.import_module("spacecore") + checked = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=True) + unchecked = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False) + X_checked = sc.VectorSpace((2,), checked) + X_unchecked = sc.VectorSpace((2,), unchecked) + A = sc.DenseLinOp(checked.asarray([[1.0, 0.0], [0.0, 1.0]]), X_checked, X_checked, checked) + B = sc.DenseLinOp( + unchecked.asarray([[2.0, 0.0], [0.0, 3.0]]), + X_unchecked, + X_unchecked, + unchecked, + ) + + summed = sc.make_sum((A, B)) + composed = sc.make_composed(A, B) + + assert isinstance(summed, sc.SumLinOp) + assert isinstance(composed, sc.ComposedLinOp) + + +def test_factories_enforce_domain_and_codomain_compatibility(): + sc = importlib.import_module("spacecore") + ctx = _ctx(dtype=np.float64) + X = sc.VectorSpace((2,), ctx) + Y = sc.VectorSpace((3,), ctx) + Z = sc.VectorSpace((4,), ctx) + A = sc.DenseLinOp(ctx.asarray(np.ones((3, 2))), X, Y, ctx) + B = sc.DenseLinOp(ctx.asarray(np.ones((4, 2))), X, Z, ctx) + + with pytest.raises(ValueError, match="same domain and codomain"): + sc.make_sum((A, B)) + with pytest.raises(ValueError, match="right.codomain == left.domain"): + sc.make_composed(A, B) + + +def test_base_linop_equality_protocol_does_not_raise(): + A = _dense_linop(_ctx()) + + assert (A == None) is False # noqa: E711 + assert A in [A] diff --git a/tests/linops/test_algebra_linop.py b/tests/linops/test_algebra_linop.py new file mode 100644 index 0000000..4c2ee85 --- /dev/null +++ b/tests/linops/test_algebra_linop.py @@ -0,0 +1,242 @@ +import importlib + +import numpy as np +import pytest + + +def _ctx(dtype=np.float64, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def _op(matrix, dom_shape, cod_shape, ctx=None): + sc = importlib.import_module("spacecore") + ctx = ctx or _ctx() + dom = sc.VectorSpace(dom_shape, ctx) + cod = sc.VectorSpace(cod_shape, ctx) + return sc.DenseLinOp(ctx.asarray(matrix), dom, cod, ctx) + + +def test_algebra_linops_inherit_from_linop(): + sc = importlib.import_module("spacecore") + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,)) + + assert isinstance(2.0 * A, sc.LinOp) + assert isinstance(A + A, sc.LinOp) + assert isinstance(A @ A, sc.LinOp) + assert isinstance(A.H, sc.LinOp) + assert isinstance(sc.ZeroLinOp(A.domain, A.codomain, A.ctx), sc.LinOp) + assert isinstance(sc.IdentityLinOp(A.domain, A.ctx), sc.LinOp) + assert isinstance(sc.MatrixFreeLinOp(A.apply, A.rapply, A.domain, A.codomain, A.ctx), sc.LinOp) + assert issubclass(sc.ScaledLinOp, sc.LinOp) + assert issubclass(sc.SumLinOp, sc.LinOp) + assert issubclass(sc.ComposedLinOp, sc.LinOp) + assert issubclass(sc.ZeroLinOp, sc.LinOp) + assert issubclass(sc.IdentityLinOp, sc.LinOp) + assert issubclass(sc.MatrixFreeLinOp, sc.LinOp) + assert not hasattr(sc, "AdjointLinOp") + + +def test_check_policy_mismatch_does_not_block_algebra(): + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), _ctx(enable_checks=True)) + B = _op([[5.0, 6.0], [7.0, 8.0]], (2,), (2,), _ctx(enable_checks=False)) + + assert isinstance(A + B, importlib.import_module("spacecore").SumLinOp) + assert isinstance(A @ B, importlib.import_module("spacecore").ComposedLinOp) + + +def test_sum_requires_matching_domain_and_codomain(): + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,)) + bad_cod = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), A.ctx) + bad_dom = _op([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], (3,), (2,), A.ctx) + + with pytest.raises(ValueError, match="same domain and codomain"): + _ = A + bad_cod + with pytest.raises(ValueError, match="same domain and codomain"): + _ = A + bad_dom + + +def test_composition_requires_matching_middle_space(): + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,)) + B = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), A.ctx) + C = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), A.ctx) + + assert (A @ B).domain == B.domain + assert (A @ B).codomain == A.codomain + with pytest.raises(ValueError, match="right.codomain == left.domain"): + _ = A @ C + + +def test_scaled_sum_subtraction_and_negation_are_numerically_correct(): + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), ctx) + B = _op([[5.0, 1.0], [-2.0, 3.0]], (2,), (2,), ctx) + x = ctx.asarray([2.0, -1.0]) + y = ctx.asarray([1.0, 3.0]) + dense_a = np.array([[1.0, 2.0], [3.0, 4.0]]) + dense_b = np.array([[5.0, 1.0], [-2.0, 3.0]]) + + expr = 2.0 * A + B - (-A) + + assert expr.domain == A.domain + assert expr.codomain == A.codomain + assert np.allclose(expr.apply(x), (3.0 * dense_a + dense_b) @ np.asarray(x)) + assert np.allclose(expr.rapply(y), (3.0 * dense_a + dense_b).T @ np.asarray(y)) + assert np.allclose((-A).apply(x), -dense_a @ np.asarray(x)) + assert np.allclose((A * 3.0).apply(x), 3.0 * dense_a @ np.asarray(x)) + + +def test_complex_scaled_adjoint_conjugates_scalar(): + ctx = _ctx(np.complex128) + A = _op([[1.0 + 1.0j, 2.0], [3.0j, 4.0 - 2.0j]], (2,), (2,), ctx) + y = ctx.asarray([1.0 - 1.0j, 2.0 + 3.0j]) + dense = np.array([[1.0 + 1.0j, 2.0], [3.0j, 4.0 - 2.0j]]) + alpha = 2.0 + 3.0j + + op = alpha * A + + assert np.allclose(op.rapply(y), np.conj(alpha) * dense.conj().T @ np.asarray(y)) + + +def test_composition_apply_and_adjoint_are_numerically_correct(): + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), ctx) + B = _op([[2.0, -1.0], [0.5, 3.0]], (2,), (2,), ctx) + x = ctx.asarray([4.0, -2.0]) + z = ctx.asarray([1.0, -1.0, 2.0]) + dense_a = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + dense_b = np.array([[2.0, -1.0], [0.5, 3.0]]) + + op = A @ B + + assert op.domain == B.domain + assert op.codomain == A.codomain + assert np.allclose(op.apply(x), dense_a @ dense_b @ np.asarray(x)) + assert np.allclose(op.rapply(z), dense_b.T @ dense_a.T @ np.asarray(z)) + + +def test_H_swaps_spaces_and_double_H_returns_original(): + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), ctx) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + + AH = A.H + AHH = AH.H + + assert AH.ctx == A.ctx + assert AH.domain == A.codomain + assert AH.codomain == A.domain + assert np.allclose(AH.apply(y), A.rapply(y)) + assert np.allclose(AH.rapply(x), A.apply(x)) + assert AHH is A + assert np.allclose(AHH.apply(x), A.apply(x)) + assert np.allclose(AHH.rapply(y), A.rapply(y)) + + +def test_zero_identity_and_matrix_free_rapply_are_numerically_correct(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + + zero = sc.ZeroLinOp(dom, cod, ctx) + identity = sc.IdentityLinOp(dom, ctx) + matrix_free = sc.MatrixFreeLinOp( + lambda v: ctx.asarray(dense @ np.asarray(v)), + lambda w: ctx.asarray(dense.T @ np.asarray(w)), + dom, + cod, + ctx, + ) + + assert np.allclose(zero.apply(x), np.zeros(3)) + assert np.allclose(zero.rapply(y), np.zeros(2)) + assert np.allclose(identity.apply(x), np.asarray(x)) + assert np.allclose(identity.rapply(x), np.asarray(x)) + assert np.allclose(matrix_free.apply(x), dense @ np.asarray(x)) + assert np.allclose(matrix_free.rapply(y), dense.T @ np.asarray(y)) + + +def test_sum_factory_flattens_nested_sums_and_removes_zero_terms(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), ctx) + B = _op([[5.0, 1.0], [-2.0, 3.0]], (2,), (2,), ctx) + Z = sc.ZeroLinOp(A.domain, A.codomain, ctx) + x = ctx.asarray([2.0, -1.0]) + y = ctx.asarray([1.0, 3.0]) + + nested = sc.SumLinOp((A, B)) + simplified = nested + Z + zero_sum = Z + Z + + assert isinstance(simplified, sc.SumLinOp) + assert simplified.parts == (A, B) + assert A + Z is A + assert Z + A is A + assert isinstance(zero_sum, sc.ZeroLinOp) + assert zero_sum.domain == A.domain + assert zero_sum.codomain == A.codomain + + unsimplified = sc.SumLinOp((nested, Z)) + assert np.allclose(simplified.apply(x), unsimplified.apply(x)) + assert np.allclose(simplified.rapply(y), unsimplified.rapply(y)) + + +def test_scaling_factory_simplifies_zero_one_and_nested_scaling(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0]], (2,), (2,), ctx) + x = ctx.asarray([2.0, -1.0]) + y = ctx.asarray([1.0, 3.0]) + dense = np.array([[1.0, 2.0], [3.0, 4.0]]) + + zero = 0 * A + unit = 1 * A + nested = 2 * (3 * A) + + assert isinstance(zero, sc.ZeroLinOp) + assert unit is A + assert isinstance(nested, sc.ScaledLinOp) + assert nested.scalar == 6 + assert nested.op is A + assert np.allclose(zero.apply(x), np.zeros(2)) + assert np.allclose(zero.rapply(y), np.zeros(2)) + assert np.allclose(nested.apply(x), 6 * dense @ np.asarray(x)) + assert np.allclose(nested.rapply(y), 6 * dense.T @ np.asarray(y)) + + +def test_composition_factory_simplifies_identity_and_zero_factors(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + A = _op([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], (2,), (3,), ctx) + id_domain = sc.IdentityLinOp(A.domain, ctx) + id_codomain = sc.IdentityLinOp(A.codomain, ctx) + left_zero = sc.ZeroLinOp(A.codomain, sc.VectorSpace((4,), ctx), ctx) + right_zero = sc.ZeroLinOp(sc.VectorSpace((5,), ctx), A.domain, ctx) + x = ctx.asarray([7.0, 8.0]) + y = ctx.asarray([1.0, -1.0, 2.0]) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + + assert A @ id_domain is A + assert id_codomain @ A is A + + left_simplified = left_zero @ A + right_simplified = A @ right_zero + + assert isinstance(left_simplified, sc.ZeroLinOp) + assert left_simplified.domain == A.domain + assert left_simplified.codomain == left_zero.codomain + assert isinstance(right_simplified, sc.ZeroLinOp) + assert right_simplified.domain == right_zero.domain + assert right_simplified.codomain == A.codomain + + unsimplified_left = sc.ComposedLinOp(left_zero, A) + assert np.allclose((A @ id_domain).apply(x), dense @ np.asarray(x)) + assert np.allclose((id_codomain @ A).rapply(y), dense.T @ np.asarray(y)) + assert np.allclose(left_simplified.apply(x), unsimplified_left.apply(x)) + assert np.allclose(left_simplified.rapply(ctx.asarray([1.0, 2.0, 3.0, 4.0])), np.zeros(2)) diff --git a/tests/linops/test_batched_lifting.py b/tests/linops/test_batched_lifting.py new file mode 100644 index 0000000..efb8f62 --- /dev/null +++ b/tests/linops/test_batched_lifting.py @@ -0,0 +1,175 @@ +import importlib + +import numpy as np +import scipy.sparse as sps + + +def _ctx(): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=np.float64) + + +def _unchecked_ctx(): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False) + + +def _stack_apply(ctx, op, xs): + return ctx.ops.stack(tuple(op.apply(x) for x in xs), axis=0) + + +def _stack_rapply(ctx, op, ys): + return ctx.ops.stack(tuple(op.rapply(y) for y in ys), axis=0) + + +def test_dense_linop_vapply_and_rvapply_match_stacked_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + matrix = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(matrix, dom, cod, ctx) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0], [0.5, 2.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(op.vapply(xs), _stack_apply(ctx, op, xs)) + assert np.allclose(op.rvapply(ys), _stack_rapply(ctx, op, ys)) + + +def test_sparse_linop_vapply_and_rvapply_match_stacked_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 0.0], [0.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(ctx.assparse(sps.csr_matrix(dense)), dom, cod, ctx) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0], [0.5, 2.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(op.vapply(xs), _stack_apply(ctx, op, xs)) + assert np.allclose(op.rvapply(ys), _stack_rapply(ctx, op, ys)) + + +def test_dense_and_sparse_batched_lifting_fast_paths_without_checks(): + sc = importlib.import_module("spacecore") + ctx = _unchecked_ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + matrix = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + sparse = ctx.assparse(sps.csr_matrix([[1.0, 0.0], [0.0, 4.0], [5.0, 6.0]])) + dense_op = sc.DenseLinOp(matrix, dom, cod, ctx) + sparse_op = sc.SparseLinOp(sparse, dom, cod, ctx) + batch_dom = dom.batch((3,), (0,)) + batch_cod = cod.batch((2,), (0,)) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0], [0.5, 2.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(dense_op.vapply(xs, batch_dom), np.asarray(xs) @ np.asarray(matrix).T) + assert np.allclose(dense_op.rvapply(ys, batch_cod), np.asarray(ys) @ np.asarray(matrix)) + assert np.allclose(sparse_op.vapply(xs, batch_dom), (sparse @ np.asarray(xs).T).T) + assert np.allclose(sparse_op.rvapply(ys, batch_cod), (sparse.T @ np.asarray(ys).T).T) + + +def test_diagonal_identity_zero_sum_composed_and_adjoint_batched_lifting(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + d1 = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + d2 = sc.DiagonalLinOp(ctx.asarray([-1.0, 0.5, 4.0]), space, ctx) + identity = sc.IdentityLinOp(space, ctx) + zero = sc.ZeroLinOp(space, space, ctx) + summed = d1 + d2 + zero + composed = d1 @ (d2 + identity) + adjoint = composed.H + xs = ctx.asarray([[1.0, 2.0, 3.0], [4.0, -1.0, 0.0]]) + + for op in (d1, identity, zero, summed, composed): + assert np.allclose(op.vapply(xs), _stack_apply(ctx, op, xs)) + assert np.allclose(op.rvapply(xs), _stack_rapply(ctx, op, xs)) + assert np.allclose(adjoint.vapply(xs), _stack_apply(ctx, adjoint, xs)) + assert np.allclose(adjoint.rvapply(xs), _stack_rapply(ctx, adjoint, xs)) + + +def test_matrix_free_vapply_uses_callback_when_supplied_and_fallback_when_absent(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + matrix = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + calls = {"vapply": 0, "rvapply": 0} + + def apply(x): + return ctx.asarray(matrix @ np.asarray(x)) + + def rapply(y): + return ctx.asarray(matrix.T @ np.asarray(y)) + + def vapply(xs): + calls["vapply"] += 1 + return ctx.asarray(np.asarray(xs) @ matrix.T) + + def rvapply(ys): + calls["rvapply"] += 1 + return ctx.asarray(np.asarray(ys) @ matrix) + + with_callbacks = sc.MatrixFreeLinOp(apply, rapply, dom, cod, ctx, vapply, rvapply) + fallback = sc.MatrixFreeLinOp(apply, rapply, dom, cod, ctx) + xs = ctx.asarray([[7.0, 8.0], [1.0, -1.0]]) + ys = ctx.asarray([[1.0, -1.0, 2.0], [0.0, 3.0, -2.0]]) + + assert np.allclose(with_callbacks.vapply(xs), _stack_apply(ctx, with_callbacks, xs)) + assert np.allclose(with_callbacks.rvapply(ys), _stack_rapply(ctx, with_callbacks, ys)) + assert calls == {"vapply": 1, "rvapply": 1} + assert np.allclose(fallback.vapply(xs), _stack_apply(ctx, fallback, xs)) + assert np.allclose(fallback.rvapply(ys), _stack_rapply(ctx, fallback, ys)) + + +def test_product_linops_batched_lifting_matches_stacked_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + x0 = sc.VectorSpace((2,), ctx) + x1 = sc.VectorSpace((3,), ctx) + y0 = sc.VectorSpace((3,), ctx) + y1 = sc.VectorSpace((2,), ctx) + a0 = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), x0, y0, ctx) + a1 = sc.DenseLinOp(ctx.asarray([[2.0, -1.0, 0.5], [0.0, 3.0, 4.0]]), x1, y1, ctx) + s1 = sc.DenseLinOp(ctx.asarray([[2.0, -1.0], [0.0, 3.0]]), x0, y1, ctx) + block = sc.BlockDiagonalLinOp.from_operators((a0, a1)) + stacked = sc.StackedLinOp.from_operators((a0, s1)) + sum_to_single = sc.SumToSingleLinOp.from_operators((a0.H, s1.H)) + + xb = (ctx.asarray([[1.0, 2.0], [3.0, 4.0]]), ctx.asarray([[5.0, 6.0, 7.0], [1.0, -1.0, 0.5]])) + yb = (ctx.asarray([[1.0, 2.0, 3.0], [0.0, -1.0, 4.0]]), ctx.asarray([[2.0, 1.0], [3.0, -2.0]])) + single_x = ctx.asarray([[1.0, 2.0], [3.0, 4.0]]) + single_y = ctx.asarray([[1.0, 2.0], [3.0, -2.0]]) + + block_v = block.vapply(xb) + block_expected = tuple(ctx.ops.stack(tuple(block.apply((xb[0][i], xb[1][i]))[j] for i in range(2))) for j in range(2)) + assert np.allclose(block_v[0], block_expected[0]) + assert np.allclose(block_v[1], block_expected[1]) + + block_rv = block.rvapply(yb) + block_r_expected = tuple(ctx.ops.stack(tuple(block.rapply((yb[0][i], yb[1][i]))[j] for i in range(2))) for j in range(2)) + assert np.allclose(block_rv[0], block_r_expected[0]) + assert np.allclose(block_rv[1], block_r_expected[1]) + + stacked_v = stacked.vapply(single_x) + stacked_expected = tuple(ctx.ops.stack(tuple(stacked.apply(single_x[i])[j] for i in range(2))) for j in range(2)) + assert np.allclose(stacked_v[0], stacked_expected[0]) + assert np.allclose(stacked_v[1], stacked_expected[1]) + + assert np.allclose( + stacked.rvapply(yb), + ctx.ops.stack(tuple(stacked.rapply((yb[0][i], yb[1][i])) for i in range(2))), + ) + assert np.allclose( + sum_to_single.vapply(yb), + ctx.ops.stack(tuple(sum_to_single.apply((yb[0][i], yb[1][i])) for i in range(2))), + ) + sum_rv = sum_to_single.rvapply(single_y) + sum_r_expected = tuple( + ctx.ops.stack(tuple(sum_to_single.rapply(single_y[i])[j] for i in range(2))) + for j in range(2) + ) + assert np.allclose(sum_rv[0], sum_r_expected[0]) + assert np.allclose(sum_rv[1], sum_r_expected[1]) diff --git a/tests/linops/test_diagonal_linop.py b/tests/linops/test_diagonal_linop.py new file mode 100644 index 0000000..e200a1f --- /dev/null +++ b/tests/linops/test_diagonal_linop.py @@ -0,0 +1,118 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_jax, to_numpy + + +def _ctx(dtype=np.float64, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def test_apply_and_rapply_flat_diagonal(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + x = ctx.asarray([4.0, -1.0, 0.5]) + + np.testing.assert_allclose(op.apply(x), [4.0, -2.0, 1.5]) + np.testing.assert_allclose(op.rapply(x), [4.0, -2.0, 1.5]) + + +def test_apply_and_rapply_tensor_shaped_diagonal(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2, 2), ctx) + op = sc.DiagonalLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0]]), space, ctx) + x = ctx.asarray([[2.0, -1.0], [0.5, 3.0]]) + + np.testing.assert_allclose(op.apply(x), [[2.0, -2.0], [1.5, 12.0]]) + np.testing.assert_allclose(op.rapply(x), [[2.0, -2.0], [1.5, 12.0]]) + + +def test_complex_diagonal_satisfies_adjoint_identity_and_hermitian_predicate(): + sc = importlib.import_module("spacecore") + ctx = _ctx(np.complex128) + space = sc.VectorSpace((2,), ctx) + hermitian = sc.DiagonalLinOp(ctx.asarray([1.0 + 0.0j, 2.0 + 0.0j]), space, ctx) + non_hermitian = sc.DiagonalLinOp(ctx.asarray([1.0 + 2.0j, 3.0 - 1.0j]), space, ctx) + u = ctx.asarray([2.0 - 1.0j, -0.5 + 0.25j]) + v = ctx.asarray([1.5 + 0.5j, -2.0j]) + + lhs = space.inner(non_hermitian.apply(u), v) + rhs = space.inner(u, non_hermitian.rapply(v)) + + np.testing.assert_allclose(to_numpy(lhs), to_numpy(rhs)) + assert hermitian.is_hermitian() is True + assert non_hermitian.is_hermitian() is False + + +def test_vapply_and_rvapply_with_leading_batch_space(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + xs = ctx.asarray([[1.0, 2.0, 3.0], [4.0, -1.0, 0.5]]) + batch_space = space.batch((2,), (0,)) + + expected = ctx.asarray([[1.0, 4.0, 9.0], [4.0, -2.0, 1.5]]) + np.testing.assert_allclose(op.vapply(xs, batch_space), expected) + np.testing.assert_allclose(op.rvapply(xs, batch_space), expected) + + +def test_vapply_and_rvapply_with_non_leading_batch_space(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + xs = ctx.asarray([[1.0, 4.0], [2.0, -1.0], [3.0, 0.5]]) + batch_space = space.batch((2,), (1,)) + + expected = ctx.asarray([[1.0, 4.0], [4.0, -2.0], [9.0, 1.5]]) + np.testing.assert_allclose(op.vapply(xs, batch_space), expected) + np.testing.assert_allclose(op.rvapply(xs, batch_space), expected) + + +def test_to_dense_matches_numpy_diagonal_for_tensor_space(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2, 2), ctx) + diagonal = ctx.asarray([[1.0, 2.0], [3.0, 4.0]]) + op = sc.DiagonalLinOp(diagonal, space, ctx) + + expected = np.diag(np.asarray(diagonal).reshape((4,))).reshape((2, 2, 2, 2)) + + np.testing.assert_allclose(op.to_dense(), expected) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_pytree_flatten_unflatten_round_trip(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + + leaves, treedef = jax.tree_util.tree_flatten(op) + restored = jax.tree_util.tree_unflatten(treedef, leaves) + + assert restored == op + np.testing.assert_allclose(restored.apply(ctx.asarray([1.0, 1.0, 1.0])), [1.0, 2.0, 3.0]) + + +def test_convert_changes_context_dtype(): + sc = importlib.import_module("spacecore") + ctx64 = _ctx(np.float64) + ctx32 = _ctx(np.float32) + space = sc.VectorSpace((3,), ctx64) + op = sc.DiagonalLinOp(ctx64.asarray([1.0, 2.0, 3.0]), space, ctx64) + + converted = op._convert(ctx32) + + assert converted.ctx == ctx32 + assert converted.domain.ctx == ctx32 + assert converted.diagonal.dtype == np.dtype(np.float32) + np.testing.assert_allclose(converted.apply(ctx32.asarray([1.0, 1.0, 1.0])), [1.0, 2.0, 3.0]) diff --git a/tests/linops/test_linop_jit.py b/tests/linops/test_linop_jit.py index 2f30207..32452d6 100644 --- a/tests/linops/test_linop_jit.py +++ b/tests/linops/test_linop_jit.py @@ -38,6 +38,26 @@ def test_dense_linop_jit_apply_and_rapply_with_operator_argument(): np.testing.assert_allclose(to_numpy(rapply_jit(op, y)), [8., 10.]) +def test_decorated_apply_rapply_value_and_grad_jit_compile(): + jax = pytest.importorskip("jax") + + ctx = _jax_ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 1.0], [1.0, 4.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + x = ctx.asarray([3.0, -1.0]) + + apply_jit = jax.jit(lambda A, z: A.apply(z)) + rapply_jit = jax.jit(lambda A, z: A.rapply(z)) + value_jit = jax.jit(lambda functional, z: functional.value(z)) + grad_jit = jax.jit(lambda functional, z: functional.grad(z)) + + np.testing.assert_allclose(to_numpy(apply_jit(op, x)), to_numpy(op.apply(x))) + np.testing.assert_allclose(to_numpy(rapply_jit(op, x)), to_numpy(op.rapply(x))) + np.testing.assert_allclose(to_numpy(value_jit(q, x)), to_numpy(q.value(x))) + np.testing.assert_allclose(to_numpy(grad_jit(q, x)), to_numpy(q.grad(x))) + + def test_tensor_dense_linop_jit_preserves_shapes(): jax = pytest.importorskip("jax") diff --git a/tests/linops/test_to_dense.py b/tests/linops/test_to_dense.py new file mode 100644 index 0000000..a38ec49 --- /dev/null +++ b/tests/linops/test_to_dense.py @@ -0,0 +1,178 @@ +import importlib + +import numpy as np +import pytest +import scipy.sparse as sps + + +def _ctx(): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=np.float64) + + +def _assert_to_dense_matches_apply(op, x): + dense = op.to_dense() + matrix = dense.reshape((np.prod(op.codomain.shape), np.prod(op.domain.shape))) + y_from_dense = matrix @ op.domain.flatten(x) + y_from_apply = op.codomain.flatten(op.apply(x)) + assert np.allclose(y_from_dense, y_from_apply) + + +def test_dense_linop_to_dense_returns_stored_matrix_and_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(A, dom, cod, ctx) + + assert op.to_dense() is A + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_dense_linop_A_returns_stored_dense_representation(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.DenseLinOp(A, dom, cod, ctx) + + assert op.A is A + + +def test_sparse_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(sps.csr_matrix(dense), dom, cod, ctx) + + assert np.allclose(op.to_dense(), dense) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_sparse_linop_A_returns_stored_sparse_representation(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = sps.csr_matrix([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.SparseLinOp(A, dom, cod, ctx) + + assert op.A is A + + +def test_identity_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2, 2), ctx) + op = sc.IdentityLinOp(space, ctx) + + assert np.allclose(op.to_dense().reshape((4, 4)), np.eye(4)) + _assert_to_dense_matches_apply(op, ctx.asarray([[1.0, 2.0], [3.0, 4.0]])) + + +def test_zero_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + op = sc.ZeroLinOp(dom, cod, ctx) + + assert np.allclose(op.to_dense(), np.zeros((3, 2))) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_matrix_free_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.MatrixFreeLinOp( + lambda x: ctx.asarray(dense @ np.asarray(x)), + lambda y: ctx.asarray(dense.T @ np.asarray(y)), + dom, + cod, + ctx, + ) + + assert np.allclose(op.to_dense(), dense) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) + + +def test_matrix_free_linop_A_is_not_implemented_by_default(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + op = sc.MatrixFreeLinOp( + lambda x: ctx.asarray(dense @ np.asarray(x)), + lambda y: ctx.asarray(dense.T @ np.asarray(y)), + dom, + cod, + ctx, + ) + + with pytest.raises(NotImplementedError, match="native numerical representation"): + _ = op.A + + +def test_custom_linop_can_define_A_representation(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + dense = ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + + class CustomLinOp(sc.LinOp): + @property + def A(self): + return {"backend": "custom", "data": dense} + + def apply(self, x): + return ctx.asarray(np.asarray(dense) @ np.asarray(x)) + + def rapply(self, y): + return ctx.asarray(np.asarray(dense).T @ np.asarray(y)) + + def tree_flatten(self): + return (), (self.domain, self.codomain, self.ctx) + + @classmethod + def tree_unflatten(cls, aux, children): + domain, codomain, ctx = aux + return cls(domain, codomain, ctx) + + op = CustomLinOp(dom, cod, ctx) + + assert op.A["backend"] == "custom" + assert op.A["data"] is dense + + +def test_diagonal_linop_A_is_cached(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((3,), ctx) + op = sc.DiagonalLinOp(ctx.asarray([1.0, 2.0, 3.0]), space, ctx) + + A = op.A + + assert op.A is A + assert np.allclose(A, np.diag([1.0, 2.0, 3.0])) + + +def test_sum_linop_to_dense_matches_apply(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = sc.DenseLinOp(ctx.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), dom, cod, ctx) + B = sc.DenseLinOp(ctx.asarray([[0.5, 1.0], [-1.0, 2.0], [3.0, -0.5]]), dom, cod, ctx) + op = A + B + + assert np.allclose(op.to_dense(), A.to_dense() + B.to_dense()) + _assert_to_dense_matches_apply(op, ctx.asarray([7.0, 8.0])) diff --git a/tests/spaces/test_batch_space.py b/tests/spaces/test_batch_space.py new file mode 100644 index 0000000..3053806 --- /dev/null +++ b/tests/spaces/test_batch_space.py @@ -0,0 +1,38 @@ +import importlib + +import numpy as np + + +def test_vector_space_batch_wrapper_shape_and_membership(): + sc = importlib.import_module("spacecore") + ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + x = sc.VectorSpace((2, 3), ctx) + + xb = x.batch(batch_shape=(4,), batch_axes=(0,)) + + assert isinstance(xb, sc.BatchSpace) + assert xb.base == x + assert xb.batch_shape == (4,) + assert xb.batch_axes == (0,) + assert xb.shape == (4, 2, 3) + xb.check_member(ctx.ops.zeros((4, 2, 3), dtype=ctx.dtype)) + + +def test_product_space_batch_wrapper_validates_component_batches(): + sc = importlib.import_module("spacecore") + ctx = sc.Context(sc.NumpyOps(), dtype=np.float64) + x0 = sc.VectorSpace((2,), ctx) + x1 = sc.VectorSpace((3,), ctx) + product = sc.ProductSpace((x0, x1), ctx) + batched = product.batch((5,), (0,)) + + value = ( + ctx.ops.zeros((5, 2), dtype=ctx.dtype), + ctx.ops.zeros((5, 3), dtype=ctx.dtype), + ) + + assert batched.shape == (5, 5) + batched.check_member(value) + zeros = batched.zeros() + assert np.allclose(zeros[0], np.zeros((5, 2))) + assert np.allclose(zeros[1], np.zeros((5, 3))) diff --git a/tests/test_backend_ops_complex.py b/tests/test_backend_ops_complex.py new file mode 100644 index 0000000..512b194 --- /dev/null +++ b/tests/test_backend_ops_complex.py @@ -0,0 +1,41 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_cupy, has_jax, has_torch, jax_complex_dtype, to_numpy +from tests._helpers import torch_complex_dtype + + +def _check_vdot_conjugates_first_argument(ops, dtype): + x = ops.asarray([1.0 + 2.0j, 3.0 + 4.0j], dtype=dtype) + y = ops.asarray([5.0 + 6.0j, 7.0 + 8.0j], dtype=dtype) + + np.testing.assert_allclose(to_numpy(ops.vdot(x, y)), 70.0 - 8.0j) + + +def test_numpy_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.NumpyOps(), np.complex128) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.JaxOps(), jax_complex_dtype()) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.TorchOps(), torch_complex_dtype()) + + +@pytest.mark.skipif(not has_cupy(), reason="cupy is not installed") +def test_cupy_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.CuPyOps(), np.complex128) diff --git a/tutorials/8_Linalg_MatrixFree.ipynb b/tutorials/8_Linalg_MatrixFree.ipynb new file mode 100644 index 0000000..f195685 --- /dev/null +++ b/tutorials/8_Linalg_MatrixFree.ipynb @@ -0,0 +1,473 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": [ + "# Matrix-free linear algebra\n", + "\n", + "This guide demonstrates SpaceCore's iterative linear algebra routines on `MatrixFreeLinOp` objects.\n", + "\n", + "A matrix-free operator is useful when the action of an operator is cheap, but building or storing the full matrix would be wasteful. In SpaceCore, the only requirements are:\n", + "\n", + "- a domain space,\n", + "- a codomain space,\n", + "- a forward action `apply(x)`,\n", + "- an adjoint action `rapply(y)`.\n", + "\n", + "The solvers below never need a dense matrix representation of the operator." + ] + }, + { + "cell_type": "code", + "id": "imports", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.583864Z", + "start_time": "2026-05-21T15:58:04.539898Z" + } + }, + "source": [ + "import numpy as np\n", + "\n", + "import spacecore as sc" + ], + "outputs": [], + "execution_count": 50 + }, + { + "cell_type": "markdown", + "id": "context", + "metadata": {}, + "source": [ + "## Backend context and vector space\n", + "\n", + "We use NumPy here to keep the notebook easy to run. The same operators can be converted to other supported backends when their callbacks are written using backend-compatible operations." + ] + }, + { + "cell_type": "code", + "id": "setup", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.612618Z", + "start_time": "2026-05-21T15:58:04.600433Z" + } + }, + "source": [ + "ctx = sc.Context(sc.NumpyOps(), dtype=np.float64, enable_checks=False)\n", + "n = 50000\n", + "X = sc.VectorSpace((n,), ctx)\n", + "\n", + "grid = np.linspace(0.0, np.pi, n)" + ], + "outputs": [], + "execution_count": 51 + }, + { + "cell_type": "markdown", + "id": "spd-op", + "metadata": {}, + "source": [ + "## A square Hermitian positive-definite operator\n", + "\n", + "This operator acts like a positive diagonal matrix, but we do not build a matrix object. The callback stores only the diagonal coefficients and multiplies elementwise.\n", + "\n", + "Because the operator is real and self-adjoint, the forward and adjoint callbacks are the same function." + ] + }, + { + "cell_type": "code", + "id": "spd-op-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.696219Z", + "start_time": "2026-05-21T15:58:04.628355Z" + } + }, + "source": [ + "diag = ctx.asarray(np.concatenate(([0.25], np.linspace(1.0, 2.0, n - 2), [6.0])))\n", + "\n", + "def spd_apply(x):\n", + " return diag * x\n", + "\n", + "A = sc.MatrixFreeLinOp(spd_apply, spd_apply, X, X, ctx)\n", + "\n", + "A.domain.shape, A.codomain.shape" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "((50000,), (50000,))" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 52 + }, + { + "cell_type": "markdown", + "id": "cg", + "metadata": {}, + "source": [ + "## Conjugate gradient: solve `A x = b`\n", + "\n", + "`cg` is for square Hermitian positive-definite systems. We create a known solution, apply the operator to get the right-hand side, and ask CG to recover the solution.\n", + "\n", + "The result object summarizes convergence metadata and avoids printing the full solution vector in its `repr`." + ] + }, + { + "cell_type": "code", + "id": "cg-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.725542Z", + "start_time": "2026-05-21T15:58:04.699252Z" + } + }, + "source": [ + "x_true = ctx.asarray(np.sin(grid) + 0.2 * np.cos(3.0 * grid))\n", + "b = A.apply(x_true)\n", + "\n", + "cg_result = sc.cg(A, b, tol=1e-8, maxiter=256)\n", + "cg_result" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "CGResult(converged=True, num_iters=64, residual_norm=1.18582e-08, x=)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 53 + }, + { + "cell_type": "code", + "id": "cg-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.738009Z", + "start_time": "2026-05-21T15:58:04.726043Z" + } + }, + "source": [ + "relative_error = ctx.ops.norm(cg_result.x - x_true) / ctx.ops.norm(x_true)\n", + "float(relative_error)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "4.5131997046538686e-11" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 54 + }, + { + "cell_type": "markdown", + "id": "lsqr-op", + "metadata": {}, + "source": [ + "## A rectangular matrix-free operator\n", + "\n", + "`lsqr` works for rectangular least-squares problems. Here `B : R^n -> R^{2n}` maps a vector into two measurement channels:\n", + "\n", + "$$\n", + "B x = \\begin{bmatrix}x \\\\ w \\odot x\\end{bmatrix}.\n", + "$$\n", + "\n", + "The adjoint combines the two channels:\n", + "\n", + "$$\n", + "B^* y = y_1 + w \\odot y_2.\n", + "$$\n", + "\n", + "Again, no matrix is built." + ] + }, + { + "cell_type": "code", + "id": "lsqr-op-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.750694Z", + "start_time": "2026-05-21T15:58:04.739637Z" + } + }, + "source": [ + "Y = sc.VectorSpace((2 * n,), ctx)\n", + "weights = ctx.asarray(np.linspace(0.5, 1.5, n))\n", + "\n", + "def rectangular_apply(x):\n", + " return ctx.ops.concatenate((x, weights * x), axis=0)\n", + "\n", + "def rectangular_rapply(y):\n", + " return y[:n] + weights * y[n:]\n", + "\n", + "B = sc.MatrixFreeLinOp(rectangular_apply, rectangular_rapply, X, Y, ctx)\n", + "\n", + "B.domain.shape, B.codomain.shape" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "((50000,), (100000,))" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 55 + }, + { + "cell_type": "markdown", + "id": "lsqr", + "metadata": {}, + "source": [ + "## LSQR: solve a least-squares problem\n", + "\n", + "We add a small deterministic perturbation to the measurements so the problem is not just a perfectly consistent copy of the original vector. LSQR minimizes `||B x - data||` using `B.apply` and `B.H.apply` internally." + ] + }, + { + "cell_type": "code", + "id": "lsqr-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.783046Z", + "start_time": "2026-05-21T15:58:04.754755Z" + } + }, + "source": [ + "x_ls_true = ctx.asarray(np.cos(2.0 * grid))\n", + "noise = 0.01 * ctx.asarray(np.sin(np.linspace(0.0, 4.0 * np.pi, 2 * n)))\n", + "data = B.apply(x_ls_true) + noise\n", + "\n", + "lsqr_result = sc.lsqr(B, data, tol=1e-10, maxiter=256)\n", + "lsqr_result" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "LSQRResult(converged=True, num_iters=64, residual_norm=0.304159, normal_residual_norm=8.83974e-14, x=)" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 56 + }, + { + "cell_type": "code", + "id": "lsqr-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.795722Z", + "start_time": "2026-05-21T15:58:04.783535Z" + } + }, + "source": [ + "normal_residual = ctx.ops.norm(B.H.apply(B.apply(lsqr_result.x) - data))\n", + "float(normal_residual)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "8.839736157139945e-14" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 57 + }, + { + "cell_type": "markdown", + "id": "power", + "metadata": {}, + "source": [ + "## Power iteration: estimate the dominant eigenpair\n", + "\n", + "`power_iteration` estimates the largest-magnitude eigenvalue of a square operator. For our diagonal example, the answer should be the largest entry in `diag`." + ] + }, + { + "cell_type": "code", + "id": "power-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.812249Z", + "start_time": "2026-05-21T15:58:04.797109Z" + } + }, + "source": [ + "x0 = ctx.asarray(np.ones(n))\n", + "power_result = sc.power_iteration(A, x0=x0, tol=1e-10, maxiter=256)\n", + "power_result" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "PowerIterationResult(converged=True, num_iters=64, eigenvalue=6, residual_norm=3.25687e-29, eigenvector=)" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 58 + }, + { + "cell_type": "code", + "id": "power-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:04.826966Z", + "start_time": "2026-05-21T15:58:04.812958Z" + } + }, + "source": [ + "float(power_result.eigenvalue), float(diag[-1])" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(6.0, 6.0)" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 59 + }, + { + "cell_type": "markdown", + "id": "lanczos", + "metadata": {}, + "source": [ + "## Stochastic Lanczos: estimate the smallest eigenpair\n", + "\n", + "`stochastic_lanczos` builds a Krylov subspace from an initial domain element and returns a Ritz approximation to the smallest eigenpair. The returned eigenvector is a member of `A.domain`, not a raw matrix column from an internal representation." + ] + }, + { + "cell_type": "code", + "id": "lanczos-code", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:05.233074Z", + "start_time": "2026-05-21T15:58:04.828930Z" + } + }, + "source": [ + "initial = ctx.asarray(np.ones(n))\n", + "smallest_eigenvalue, smallest_eigenvector = sc.stochastic_lanczos(\n", + " A,\n", + " initial,\n", + " max_iter=64,\n", + " tol=1e-10,\n", + ")\n", + "\n", + "float(smallest_eigenvalue), smallest_eigenvector.shape" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(0.25, (50000,))" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 60 + }, + { + "cell_type": "code", + "id": "lanczos-check", + "metadata": { + "ExecuteTime": { + "end_time": "2026-05-21T15:58:05.249991Z", + "start_time": "2026-05-21T15:58:05.233738Z" + } + }, + "source": [ + "ritz_residual = ctx.ops.norm(A.apply(smallest_eigenvector) - smallest_eigenvalue * smallest_eigenvector)\n", + "float(ritz_residual), float(diag[0])" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(6.179265594934942e-15, 0.25)" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 61 + }, + { + "cell_type": "markdown", + "id": "wrap", + "metadata": {}, + "source": [ + "## What to take away\n", + "\n", + "- `MatrixFreeLinOp` lets you use iterative algorithms without building a matrix.\n", + "- `cg` is for square Hermitian positive-definite solves.\n", + "- `lsqr` is for general rectangular least-squares problems.\n", + "- `power_iteration` gives a dominant eigenpair estimate.\n", + "- `stochastic_lanczos` gives a smallest Ritz eigenpair estimate from a Krylov subspace.\n", + "- Solver result objects are compact to display, while full arrays remain available as attributes such as `cg_result.x`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}