diff --git a/README.md b/README.md index 7566010..ae4e09a 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ mathematical objects: - a `Space` knows the structure and geometry of its elements; - a `LinOp` maps one space to another; +- a `Functional` maps a space element to a scalar; - backend-specific array creation and operations live behind `BackendOps`. The result is ordinary Python code whose core numerical logic is not tied to @@ -28,7 +29,7 @@ one array library. Mental model: ```text -BackendOps -> Context -> Space/LinOp -> Algorithm +BackendOps -> Context -> Space/LinOp/Functional -> Algorithm ``` ## Write once, run twice @@ -184,6 +185,17 @@ xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB The fallback uses backend `vmap`; dense, sparse, diagonal, identity, zero, algebraic, and product-structured operators provide specialized batched paths. +### `Functional` + +A `Functional` represents a scalar-valued map on a space. `LinearFunctional` +covers maps such as `<c, x>`, `MatrixFreeLinearFunctional` wraps a callable +without storing a representer, and `LinOpQuadraticForm` represents objectives +such as `0.5 * <x, Qx> + ell(x) + a`. + +For batched inputs, `vvalue(xs)` evaluates independently over leading batch +axes. Quadratic forms that define gradients also expose `grad(x)` and +`vgrad(xs)`. + ## Who should use this? SpaceCore is aimed at people writing optimization, inverse-problem, optimal diff --git a/docs/source/api/functionals.rst b/docs/source/api/functionals.rst new file mode 100644 index 0000000..147e6bf --- /dev/null +++ b/docs/source/api/functionals.rst @@ -0,0 +1,60 @@ +Functionals API +=============== + +Functionals represent scalar-valued maps on spaces, including linear +functionals and quadratic forms. + +..
autosummary:: + :nosignatures: + + spacecore.functional.Functional + spacecore.functional.LinearFunctional + spacecore.functional.InnerProductFunctional + spacecore.functional.MatrixFreeLinearFunctional + spacecore.functional.QuadraticForm + spacecore.functional.LinOpQuadraticForm + +Functional +---------- + +.. autoclass:: spacecore.functional.Functional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +Linear functionals +------------------ + +.. autoclass:: spacecore.functional.LinearFunctional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.functional.InnerProductFunctional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.functional.MatrixFreeLinearFunctional + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +Quadratic forms +--------------- + +.. autoclass:: spacecore.functional.QuadraticForm + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + +.. autoclass:: spacecore.functional.LinOpQuadraticForm + :members: + :undoc-members: + :inherited-members: + :show-inheritance: diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 259ac9d..654fbb8 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -11,3 +11,4 @@ directives for public objects instead of dumping entire modules. context spaces linops + functionals diff --git a/docs/source/index.rst b/docs/source/index.rst index 01185b4..470a78d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,7 @@ mathematical objects: * a ``Space`` knows the structure and geometry of its elements; * a ``LinOp`` maps one space to another; +* a ``Functional`` maps a space element to a scalar; * backend-specific array creation and operations live behind ``BackendOps``. The result is ordinary Python code whose core numerical logic is not tied to @@ -31,7 +32,7 @@ Mental model: .. 
code-block:: text - BackendOps -> Context -> Space/LinOp -> Algorithm + BackendOps -> Context -> Space/LinOp/Functional -> Algorithm Write once, run twice --------------------- @@ -192,6 +193,19 @@ the leading batch axis: The fallback uses backend ``vmap``; dense, sparse, diagonal, identity, zero, algebraic, and product-structured operators provide specialized batched paths. +``Functional`` +~~~~~~~~~~~~~~ + +A ``Functional`` represents a scalar-valued map on a space. +``LinearFunctional`` covers maps such as ``<c, x>``, +``MatrixFreeLinearFunctional`` wraps a callable without storing a representer, +and ``LinOpQuadraticForm`` represents objectives such as +``0.5 * <x, Qx> + ell(x) + a``. + +For batched inputs, ``vvalue(xs)`` evaluates independently over leading batch +axes. Quadratic forms that define gradients also expose ``grad(x)`` and +``vgrad(xs)``. + Who should use this? -------------------- diff --git a/spacecore/__init__.py b/spacecore/__init__.py index 4925f7a..91e63c3 100644 --- a/spacecore/__init__.py +++ b/spacecore/__init__.py @@ -28,10 +28,19 @@ make_scaled, make_sum, ) +from .functional import ( + Functional, + InnerProductFunctional, + LinearFunctional, + LinOpQuadraticForm, + MatrixFreeLinearFunctional, + QuadraticForm, +) from .linalg import ( CGResult, LSQRResult, PowerIterationResult, + StochasticLanczosResult, cg, lsqr, power_iteration, @@ -89,9 +98,17 @@ "SumToSingleLinOp", "StackedLinOp", + "Functional", + "LinearFunctional", + "InnerProductFunctional", + "MatrixFreeLinearFunctional", + "QuadraticForm", + "LinOpQuadraticForm", + "CGResult", "LSQRResult", "PowerIterationResult", + "StochasticLanczosResult", "cg", "lsqr", "power_iteration", diff --git a/spacecore/functional/__init__.py b/spacecore/functional/__init__.py new file mode 100644 index 0000000..19bae03 --- /dev/null +++ b/spacecore/functional/__init__.py @@ -0,0 +1,12 @@ +from ._base import Functional +from ._linear import InnerProductFunctional, LinearFunctional, MatrixFreeLinearFunctional +from
._quadratic import LinOpQuadraticForm, QuadraticForm + +__all__ = [ + "Functional", + "InnerProductFunctional", + "LinearFunctional", + "LinOpQuadraticForm", + "MatrixFreeLinearFunctional", + "QuadraticForm", +] diff --git a/spacecore/functional/_base.py b/spacecore/functional/_base.py new file mode 100644 index 0000000..185a827 --- /dev/null +++ b/spacecore/functional/_base.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Generic, TypeVar + +from .._contextual import ContextBound +from .._contextual.manager import ctx_manager +from ..backend import Context +from ..space import Space + + +Domain = TypeVar("Domain", bound=Space) + + +class Functional(ContextBound, Generic[Domain]): + """ + Scalar-valued map on a space. + + ``Functional`` represents a map ``F : X -> K`` without assuming any storage + model. It mirrors the minimal ``LinOp`` contract: the domain is converted + into the resolved context, value checks follow ``ctx.enable_checks``, and + batched evaluation is implemented by a backend ``vmap`` fallback. + """ + + def __init__(self, dom: Domain, ctx: Context | str | None = None) -> None: + ctx = ctx_manager.resolve_context_priority(ctx, dom) + super().__init__(ctx) + self.dom = dom.convert(self.ctx) + self._enable_checks = self.ctx.enable_checks + + @property + def domain(self) -> Domain: + """Domain space of this scalar-valued map.""" + return self.dom + + @abstractmethod + def value(self, x: Any) -> Any: + """ + Evaluate this functional at ``x``. + + Contract: + - x is an element of ``self.domain``; + - the return value is scalar-like in the functional context. 
+ """ + + def __call__(self, x: Any) -> Any: + """Evaluate this functional at ``x``.""" + return self.value(x) + + def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + """Evaluate this functional independently over leading batch axes.""" + return self._fallback_vvalue(xs, batch_space) + + def _infer_batch_shape(self, space: Space, value: Any) -> tuple[int, ...]: + if hasattr(space, "spaces") and isinstance(value, tuple) and value: + return self._infer_batch_shape(space.spaces[0], value[0]) + shape = tuple(getattr(value, "shape", ())) + base_shape = tuple(space.shape) + if not base_shape: + return shape + if len(shape) < len(base_shape) or shape[-len(base_shape):] != base_shape: + raise ValueError( + f"Cannot infer leading batch shape for value shape {shape} " + f"and base space shape {base_shape}." + ) + return shape[: len(shape) - len(base_shape)] + + def _input_batch_space( + self, + space: Space, + value: Any, + batch_space: Space | None, + ) -> Space: + if batch_space is not None: + return batch_space + batch_shape = self._infer_batch_shape(space, value) + return space.batch(batch_shape, tuple(range(len(batch_shape)))) + + def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space: + batch_shape = getattr(input_batch_space, "batch_shape", None) + batch_axes = getattr(input_batch_space, "batch_axes", None) + if batch_shape is None or batch_axes is None: + raise TypeError("batch_space must be a BatchSpace-compatible object.") + return space.batch(tuple(batch_shape), tuple(batch_axes)) + + def _require_leading_batch_axes(self, batch_space: Space) -> tuple[int, ...]: + batch_shape = tuple(getattr(batch_space, "batch_shape", ())) + batch_axes = tuple(getattr(batch_space, "batch_axes", ())) + expected_axes = tuple(range(len(batch_shape))) + if batch_axes != expected_axes: + raise ValueError( + "Functional batching currently expects leading batch axes; " + f"got batch_axes={batch_axes}, expected {expected_axes}." 
+ ) + return batch_shape + + def _vmap_leading(self, fn: Any, batch_ndim: int) -> Any: + mapped = fn + for _ in range(batch_ndim): + mapped = self.ops.vmap(mapped, in_axes=0, out_axes=0) + return mapped + + def _check_scalar_batch(self, values: Any, batch_shape: tuple[int, ...]) -> None: + shape = tuple(getattr(values, "shape", ())) + if shape != batch_shape: + raise ValueError( + f"Expected scalar batch output with shape {batch_shape}, got {shape}." + ) + + def _fallback_vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + in_space = self._input_batch_space(self.domain, xs, batch_space) + batch_shape = self._require_leading_batch_axes(in_space) + if self._enable_checks: + in_space._check_member(xs) + values = self._vmap_leading(self.value, len(batch_shape))(xs) + if self._enable_checks: + self._check_scalar_batch(values, batch_shape) + return values + + def assert_domain(self, x: Any) -> None: + self.dom.check_member(x) + + @abstractmethod + def tree_flatten(self): + ... + + @classmethod + @abstractmethod + def tree_unflatten(cls, aux, children): + ... diff --git a/spacecore/functional/_linear.py b/spacecore/functional/_linear.py new file mode 100644 index 0000000..ca349a3 --- /dev/null +++ b/spacecore/functional/_linear.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Any + +from ._base import Domain, Functional +from ..backend import Context, jax_pytree_class +from ..space import Space + + +def _convert_space_element(space: Space, value: Any) -> Any: + if hasattr(space, "spaces") and isinstance(value, tuple): + if len(value) != len(space.spaces): + raise ValueError( + f"Expected tuple of length {len(space.spaces)}, got {len(value)}." 
+ ) + return tuple( + _convert_space_element(component_space, component) + for component_space, component in zip(space.spaces, value) + ) + return space.ctx.asarray(value) + + +class LinearFunctional(Functional[Domain]): + """Linear scalar-valued map ``ell : X -> K``.""" + + @property + @abstractmethod + def representer(self) -> Any: + """ + Riesz representer of this functional when one is explicitly available. + + Matrix-free functionals may not have a stored representer and should + raise ``NotImplementedError``. + """ + + +@jax_pytree_class +class InnerProductFunctional(LinearFunctional[Domain]): + """ + Linear functional represented by a domain element. + + ``InnerProductFunctional(c, X)`` evaluates ``ell_c(x) = <c, x>_X``. + """ + + def __init__( + self, + c: Any, + dom: Domain, + ctx: Context | str | None = None, + ) -> None: + super().__init__(dom, ctx) + self._c = _convert_space_element(self.domain, c) + if self._enable_checks: + self.domain._check_member(self._c) + + @property + def representer(self) -> Any: + """Stored domain element ``c`` defining ``ell_c(x) = <c, x>``.""" + return self._c + + def value(self, x: Any) -> Any: + """Return ``domain.inner(representer, x)``.""" + if self._enable_checks: + self.domain._check_member(x) + return self.domain.inner(self._c, x) + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return self.domain == other.domain and self.ops.allclose( + self.domain.flatten(self._c), + other.domain.flatten(other._c), + ) + return False + + def tree_flatten(self): + children = (self._c,) + aux = (self.domain, self.ctx) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + domain, ctx = aux + c = children[0] + return cls(c, domain, ctx) + + def _convert(self, new_ctx: Context) -> InnerProductFunctional: + return InnerProductFunctional(self._c, self.domain.convert(new_ctx), new_ctx) + + +@jax_pytree_class +class MatrixFreeLinearFunctional(LinearFunctional[Domain]): + """ + Linear functional
defined by user-supplied evaluation callables. + + No representer is stored or materialized. + """ + + def __init__( + self, + value: Any, + dom: Domain, + ctx: Context | str | None = None, + vvalue: Any | None = None, + ) -> None: + if not callable(value): + raise TypeError(f"value must be callable, got {type(value).__name__}.") + if vvalue is not None and not callable(vvalue): + raise TypeError(f"vvalue must be callable, got {type(vvalue).__name__}.") + super().__init__(dom, ctx) + self.value_fn = value + self.vvalue_fn = vvalue + + @property + def representer(self) -> Any: + raise NotImplementedError( + f"{type(self).__name__} does not store a Riesz representer." + ) + + def value(self, x: Any) -> Any: + """Return ``value_fn(x)``.""" + if self._enable_checks: + self.domain._check_member(x) + y = self.value_fn(x) + if self._enable_checks: + self._check_scalar_batch(y, ()) + return y + + def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any: + """Return ``vvalue_fn(xs)`` when supplied, otherwise use fallback batching.""" + if self.vvalue_fn is None: + return super().vvalue(xs, batch_space) + in_space = self._input_batch_space(self.domain, xs, batch_space) + batch_shape = self._require_leading_batch_axes(in_space) + if self._enable_checks: + in_space._check_member(xs) + values = self.vvalue_fn(xs) + if self._enable_checks: + self._check_scalar_batch(values, batch_shape) + return values + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return ( + self.domain == other.domain + and self.value_fn is other.value_fn + and self.vvalue_fn is other.vvalue_fn + ) + return False + + def tree_flatten(self): + children = () + aux = (self.value_fn, self.domain, self.ctx, self.vvalue_fn) + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + value_fn, domain, ctx, vvalue_fn = aux + return cls(value_fn, domain, ctx, vvalue_fn) + + def _convert(self, new_ctx: Context) -> MatrixFreeLinearFunctional: + return 
MatrixFreeLinearFunctional( + self.value_fn, + self.domain.convert(new_ctx), + new_ctx, + self.vvalue_fn, + ) diff --git a/spacecore/functional/_quadratic.py b/spacecore/functional/_quadratic.py new file mode 100644 index 0000000..ebbe79f --- /dev/null +++ b/spacecore/functional/_quadratic.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import Any + +from ._base import Domain, Functional +from ._linear import LinearFunctional +from .._contextual.manager import ctx_manager +from ..backend import Context, jax_pytree_class +from ..linop import LinOp +from ..space import Space + + +class QuadraticForm(Functional[Domain]): + """Scalar quadratic objective on a space.""" + + def hess_apply(self, x: Any) -> Any: + """Apply the Hessian action at ``x`` when available.""" + raise NotImplementedError(f"{type(self).__name__} does not define hess_apply.") + + def grad(self, x: Any) -> Any: + """Gradient at ``x`` when available.""" + raise NotImplementedError(f"{type(self).__name__} does not define grad.") + + def vgrad(self, xs: Any, batch_space: Space | None = None) -> Any: + """Evaluate ``grad`` independently over leading batch axes.""" + in_space = self._input_batch_space(self.domain, xs, batch_space) + batch_shape = self._require_leading_batch_axes(in_space) + if self._enable_checks: + in_space._check_member(xs) + grads = self._vmap_leading(self.grad, len(batch_shape))(xs) + if self._enable_checks: + self._output_batch_space(self.domain, in_space)._check_member(grads) + return grads + + +@jax_pytree_class +class LinOpQuadraticForm(QuadraticForm[Domain]): + """ + Quadratic form backed by a linear operator. + + ``q(x) = 1/2 * <x, Qx> + linear(x) + a`` with ``Q : X -> X``.
+ """ + + def __init__( + self, + Q: LinOp[Domain, Domain], + linear: LinearFunctional[Domain] | None = None, + a: Any = 0, + ctx: Context | str | None = None, + ) -> None: + if not isinstance(Q, LinOp): + raise TypeError(f"Q must be a LinOp, got {type(Q).__name__}.") + if linear is not None and not isinstance(linear, LinearFunctional): + raise TypeError( + f"linear must be a LinearFunctional or None, got {type(linear).__name__}." + ) + + resolved_ctx = ctx_manager.resolve_context_priority(ctx, Q.domain, Q, linear) + Q = Q.convert(resolved_ctx) + if Q.domain != Q.codomain: + raise ValueError("LinOpQuadraticForm requires Q.domain == Q.codomain.") + if linear is not None: + linear = linear.convert(resolved_ctx) + if linear.domain != Q.domain: + raise ValueError("linear.domain must match Q.domain.") + + super().__init__(Q.domain, resolved_ctx) + self.Q = Q + self.linear = linear + self.a = self.ctx.asarray(a) + if self._enable_checks: + self._check_scalar_batch(self.a, ()) + + def value(self, x: Any) -> Any: + """Return ``1/2 * <x, Qx> + linear(x) + a``.""" + if self._enable_checks: + self.domain._check_member(x) + qx = self.Q.apply(x) + value = 0.5 * self.domain.inner(x, qx) + if self.linear is not None: + value = value + self.linear.value(x) + return value + self.a + + def grad(self, x: Any) -> Any: + """ + Return the Euclidean/Riesz gradient. + + The quadratic part uses the symmetric adjoint part ``(Q + Q*) / 2``. + For self-adjoint ``Q`` this is exactly ``Qx``.
+ """ + if self._enable_checks: + self.domain._check_member(x) + qx = self.Q.apply(x) + qhx = self.Q.rapply(x) + grad = self.domain.scale(0.5, self.domain.add(qx, qhx)) + if self.linear is not None: + grad = self.domain.add(grad, self.linear.representer) + return grad + + def hess_apply(self, x: Any) -> Any: + """Return the self-adjoint Hessian action ``(Q + Q*) x / 2``.""" + if self._enable_checks: + self.domain._check_member(x) + qx = self.Q.apply(x) + qhx = self.Q.rapply(x) + return self.domain.scale(0.5, self.domain.add(qx, qhx)) + + def __eq__(self, other: Any) -> bool: + if type(other) is type(self): + return ( + self.Q == other.Q + and self.linear == other.linear + and self.ops.allclose(self.a, other.a) + ) + return False + + def tree_flatten(self): + children = (self.Q, self.linear, self.a) + aux = () + return children, aux + + @classmethod + def tree_unflatten(cls, aux, children): + Q, linear, a = children + return cls(Q, linear, a, Q.ctx) + + def _convert(self, new_ctx: Context) -> LinOpQuadraticForm: + linear = None if self.linear is None else self.linear.convert(new_ctx) + return LinOpQuadraticForm(self.Q.convert(new_ctx), linear, self.a, new_ctx) diff --git a/spacecore/linalg/__init__.py b/spacecore/linalg/__init__.py index 147ffb2..06be398 100644 --- a/spacecore/linalg/__init__.py +++ b/spacecore/linalg/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from ._cg import CGResult, cg -from ._lanczos import stochastic_lanczos +from ._lanczos import StochasticLanczosResult, stochastic_lanczos from ._lsqr import LSQRResult, lsqr from ._power import PowerIterationResult, power_iteration @@ -9,6 +9,7 @@ "CGResult", "LSQRResult", "PowerIterationResult", + "StochasticLanczosResult", "cg", "lsqr", "power_iteration", diff --git a/spacecore/linalg/_lanczos.py b/spacecore/linalg/_lanczos.py index acf02aa..2f1ec3e 100644 --- a/spacecore/linalg/_lanczos.py +++ b/spacecore/linalg/_lanczos.py @@ -1,11 +1,29 @@ from __future__ import annotations -from 
typing import Any +from typing import Any, NamedTuple from ..linop import LinOp from ..types import DenseArray from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval from ._utils import require_linop, require_square, should_check_iteration +from ._utils import result_repr + + +class StochasticLanczosResult(NamedTuple): + """Result returned by :func:`stochastic_lanczos`.""" + + eigenvalue: Any + eigenvector: Any + + def __repr__(self) -> str: + """Return a compact summary without printing the full eigenvector.""" + return result_repr( + "StochasticLanczosResult", + { + "eigenvalue": self.eigenvalue, + "eigenvector": self.eigenvector, + }, + ) def _check_lanczos_max_iter(max_iter: int) -> int: @@ -22,7 +40,7 @@ def stochastic_lanczos( max_iter: int = 100, tol: float = 1e-6, check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, -) -> tuple[DenseArray, Any]: +) -> StochasticLanczosResult: r"""Approximate the smallest eigenpair of a Hermitian operator. The operator is supplied as a square ``LinOp`` and ``initial_vector`` is an @@ -47,8 +65,9 @@ def stochastic_lanczos( this many iterations, and always on the final iteration. Returns: - A pair ``(eigenvalue, eigenvector)`` for the smallest approximated - eigenpair. + ``StochasticLanczosResult`` containing the smallest approximated + eigenpair. The result supports tuple unpacking as + ``eigenvalue, eigenvector``. 
""" A = require_linop(A) require_square(A, "stochastic_lanczos") @@ -185,4 +204,4 @@ def fill_off(ii: int, T_in: DenseArray) -> DenseArray: den = ops.real(A.domain.inner(x, x)) lam = num / den - return lam, x + return StochasticLanczosResult(lam, x) diff --git a/spacecore/linalg/_power.py b/spacecore/linalg/_power.py index 5358b85..2d40f29 100644 --- a/spacecore/linalg/_power.py +++ b/spacecore/linalg/_power.py @@ -1,8 +1,12 @@ from __future__ import annotations +from collections.abc import Callable from typing import Any, NamedTuple +from ..backend import Context +from ..functional import QuadraticForm from ..linop import LinOp +from ..space import Space from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter from ._utils import default_initial_vector, is_converged, normalize, require_linop from ._utils import require_square, result_repr, should_check_iteration @@ -31,8 +35,32 @@ def __repr__(self) -> str: ) +class _SelfAdjointAction(NamedTuple): + apply: Callable[[Any], Any] + domain: Space + ctx: Context + + @property + def ops(self) -> Any: + return self.ctx.ops + + @property + def dtype(self) -> Any: + return self.ctx.dtype + + +def _action_from_linop(A: LinOp) -> _SelfAdjointAction: + A = require_linop(A) + require_square(A, "power_iteration") + return _SelfAdjointAction(A.apply, A.domain, A.ctx) + + +def _action_from_quadratic_form(q: QuadraticForm) -> _SelfAdjointAction: + return _SelfAdjointAction(q.hess_apply, q.domain, q.ctx) + + def power_iteration( - A: LinOp, + A: LinOp | QuadraticForm, *, x0: Any | None = None, tol: float = 1e-6, @@ -40,25 +68,42 @@ def power_iteration( check_every: int = DEFAULT_CONVERGENCE_CHECK_INTERVAL, ) -> PowerIterationResult: """ - Estimate the dominant eigenpair of a square ``LinOp`` by power iteration. + Estimate the dominant eigenpair of a square ``LinOp`` or Hessian action. - The method uses only ``A.apply`` and domain-space operations. 
It returns the - Rayleigh quotient for the current normalized iterate, the eigenvector + ``A`` may be a square ``LinOp`` or a ``QuadraticForm`` that exposes + ``hess_apply``. Public dispatch converts either input into a fixed + self-adjoint action before entering the numerical loop. The method returns + the Rayleigh quotient for the current normalized iterate, the eigenvector estimate, and the residual norm ``||A x - lambda x||``. The residual-based stopping criterion is refreshed only every ``check_every`` iterations, and always on the final iteration. For spectral-norm estimates of a rectangular operator, call this on ``A.H @ A``. """ - A = require_linop(A) - require_square(A, "power_iteration") - maxiter = check_maxiter(maxiter, A) + if isinstance(A, QuadraticForm): + action = _action_from_quadratic_form(A) + elif isinstance(A, LinOp): + action = _action_from_linop(A) + else: + raise TypeError(f"A must be a LinOp or QuadraticForm, got {type(A).__name__}.") + + maxiter = check_maxiter(maxiter, action) check_every = check_interval(check_every) - x = default_initial_vector(A) if x0 is None else x0 - A.domain.check_member(x) - x, _ = normalize(A.domain, x) - zero = A.ops.asarray(0.0, dtype=A.dtype) - residual_norm = A.domain.norm(x) + float("inf") + x = default_initial_vector(action) if x0 is None else x0 + action.domain.check_member(x) + return PowerIterationResult(*_power_iteration_core(action, x, tol, maxiter, check_every)) + + +def _power_iteration_core( + action: _SelfAdjointAction, + x: Any, + tol: float, + maxiter: int, + check_every: int, +) -> tuple[Any, Any, Any, Any, Any]: + x, _ = normalize(action.domain, x) + zero = action.ops.asarray(0.0, dtype=action.dtype) + residual_norm = action.domain.norm(x) + float("inf") def cond_fun(carry: tuple[Any, Any, Any, int]) -> Any: _eigenvalue, _x, res_norm, k = carry @@ -66,30 +111,30 @@ def cond_fun(carry: tuple[Any, Any, Any, int]) -> Any: def body_fun(carry: tuple[Any, Any, Any, int]) -> tuple[Any, Any, Any, int]: 
_eigenvalue, x, _residual_norm, k = carry - y = A.apply(x) - x_next, _norm_y = normalize(A.domain, y) - y_next = A.apply(x_next) - eigenvalue_next = A.domain.inner(x_next, y_next) + y = action.apply(x) + x_next, _norm_y = normalize(action.domain, y) + y_next = action.apply(x_next) + eigenvalue_next = action.domain.inner(x_next, y_next) k_next = k + 1 def refresh_residual(_: Any) -> Any: - residual = A.codomain.axpy(-eigenvalue_next, x_next, y_next) - return A.codomain.norm(residual) + residual = action.domain.axpy(-eigenvalue_next, x_next, y_next) + return action.domain.norm(residual) - residual_norm_next = A.ops.cond( + residual_norm_next = action.ops.cond( should_check_iteration(k_next, maxiter, check_every), refresh_residual, lambda _: _residual_norm, - A.ops.asarray(0.0, dtype=A.dtype), + action.ops.asarray(0.0, dtype=action.dtype), ) return eigenvalue_next, x_next, residual_norm_next, k_next - eigenvalue, eigenvector, residual_norm, num_iters = A.ops.while_loop( + eigenvalue, eigenvector, residual_norm, num_iters = action.ops.while_loop( cond_fun, body_fun, (zero, x, residual_norm, 0), ) - return PowerIterationResult( + return ( eigenvalue, eigenvector, is_converged(residual_norm, tol), diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py new file mode 100644 index 0000000..3623b9b --- /dev/null +++ b/tests/functional/test_functional.py @@ -0,0 +1,118 @@ +import importlib + +import numpy as np +import pytest + + +def _ctx(dtype=np.float64, enable_checks=True): + sc = importlib.import_module("spacecore") + return sc.Context(sc.NumpyOps(), dtype=dtype, enable_checks=enable_checks) + + +def _quadratic_problem(ctx): + sc = importlib.import_module("spacecore") + dom = sc.VectorSpace((2,), ctx) + Q = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 4.0]]), dom, dom, ctx) + linear = sc.InnerProductFunctional(ctx.asarray([1.0, -1.0]), dom, ctx) + return sc.LinOpQuadraticForm(Q, linear, 3.0, ctx) + + +def 
test_explicit_context_overrides_inferred_contexts(): + sc = importlib.import_module("spacecore") + inferred = _ctx(np.float32, enable_checks=True) + explicit = _ctx(np.float64, enable_checks=False) + dom = sc.VectorSpace((2,), inferred) + Q = sc.DenseLinOp(inferred.asarray([[1.0, 0.0], [0.0, 1.0]]), dom, dom, inferred) + linear = sc.InnerProductFunctional(inferred.asarray([1.0, 2.0]), dom) + + functional = sc.InnerProductFunctional(inferred.asarray([1.0, 2.0]), dom, explicit) + quadratic = sc.LinOpQuadraticForm(Q, linear, 0.0, explicit) + + assert functional.ctx == explicit + assert functional.dtype == np.dtype(np.float64) + assert functional.domain.ctx == explicit + assert quadratic.ctx == explicit + assert quadratic.Q.ctx == explicit + assert quadratic.linear.ctx == explicit + + +def test_domain_conversion_and_membership_checks_work(): + sc = importlib.import_module("spacecore") + source = _ctx(np.float32, enable_checks=True) + explicit = _ctx(np.float64, enable_checks=True) + dom = sc.VectorSpace((2,), source) + functional = sc.InnerProductFunctional(source.asarray([1.0, 2.0]), dom, explicit) + + assert functional.ctx == explicit + assert functional.dtype == np.dtype(np.float64) + assert functional.domain.ctx == explicit + assert functional.domain.ctx.enable_checks is True + assert np.allclose(functional.value(functional.domain.ctx.asarray([3.0, 4.0])), 11.0) + with pytest.raises(Exception): + functional.value(explicit.asarray([1.0, 2.0, 3.0])) + + +def test_call_matches_value(): + ctx = _ctx() + q = _quadratic_problem(ctx) + x = ctx.asarray([2.0, -1.0]) + assert np.allclose(q(x), q.value(x)) + + +def test_inner_product_functional_matches_domain_inner(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + c = ctx.asarray([1.0, -2.0]) + x = ctx.asarray([3.0, 4.0]) + functional = sc.InnerProductFunctional(c, dom, ctx) + + assert np.allclose(functional.value(x), dom.inner(c, x)) + + +def 
test_matrix_free_linear_functional_has_no_representer(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + dom = sc.VectorSpace((2,), ctx) + c = ctx.asarray([2.0, 3.0]) + x = ctx.asarray([4.0, 5.0]) + functional = sc.MatrixFreeLinearFunctional(lambda y: dom.inner(c, y), dom, ctx) + + assert np.allclose(functional.value(x), 23.0) + with pytest.raises(NotImplementedError): + functional.representer + + +def test_linop_quadratic_value_and_gradient_match_euclidean_hand_computation(): + ctx = _ctx() + q = _quadratic_problem(ctx) + x = ctx.asarray([2.0, -1.0]) + + assert np.allclose(q.value(x), 12.0) + assert np.allclose(q.grad(x), [5.0, -5.0]) + assert np.allclose(q.hess_apply(x), [4.0, -4.0]) + + +def test_vvalue_and_vgrad_match_elementwise_loops(): + ctx = _ctx() + q = _quadratic_problem(ctx) + xs = ctx.asarray([[2.0, -1.0], [0.0, 3.0], [1.5, 2.0]]) + + expected_values = ctx.ops.stack(tuple(q.value(x) for x in xs), axis=0) + expected_grads = ctx.ops.stack(tuple(q.grad(x) for x in xs), axis=0) + + assert np.allclose(q.vvalue(xs), expected_values) + assert np.allclose(q.vgrad(xs), expected_grads) + + +def test_bad_shapes_raise_when_checks_are_enabled(): + ctx = _ctx(enable_checks=True) + q = _quadratic_problem(ctx) + bad = ctx.asarray([1.0, 2.0, 3.0]) + + with pytest.raises(Exception): + q.value(bad) + with pytest.raises(Exception): + q.grad(bad) + with pytest.raises(Exception): + q.vvalue(ctx.asarray([[1.0, 2.0, 3.0]])) diff --git a/tests/integration/test_public_api.py b/tests/integration/test_public_api.py index 99b28eb..f352934 100644 --- a/tests/integration/test_public_api.py +++ b/tests/integration/test_public_api.py @@ -23,11 +23,14 @@ def test_expected_names_are_exported(): "IdentityLinOp", "MatrixFreeLinOp", "make_sum", "make_scaled", "make_composed", "BlockDiagonalLinOp", "StackedLinOp", "SumToSingleLinOp", + "Functional", "LinearFunctional", "InnerProductFunctional", + "MatrixFreeLinearFunctional", "QuadraticForm", "LinOpQuadraticForm", 
"VectorSpace", "HermitianSpace", "ProductSpace", "Space", "DenseArray", "SparseArray", "ArrayLike", "set_context", "get_context", "resolve_context_priority", "register_ops", "set_resolution_policy", "set_dtype_resolution_policy", "get_resolution_policy", "get_dtype_resolution_policy", + "StochasticLanczosResult", } if has_jax(): expected |= {"JaxOps", "jax_pytree_class"} @@ -43,6 +46,8 @@ def test_top_level_objects_match_source_modules(): backend = importlib.import_module("spacecore.backend") space = importlib.import_module("spacecore.space") linop = importlib.import_module("spacecore.linop") + functional = importlib.import_module("spacecore.functional") + linalg = importlib.import_module("spacecore.linalg") manager = importlib.import_module("spacecore._contextual.manager") assert sc.Context is backend.Context @@ -54,6 +59,9 @@ def test_top_level_objects_match_source_modules(): assert sc.Space is space.Space assert sc.VectorSpace is space.VectorSpace assert sc.DenseLinOp is linop.DenseLinOp + assert sc.Functional is functional.Functional + assert sc.InnerProductFunctional is functional.InnerProductFunctional + assert sc.StochasticLanczosResult is linalg.StochasticLanczosResult assert sc.get_context is manager.get_context assert sc.resolve_context_priority is manager.resolve_context_priority diff --git a/tests/linalg/test_krylov.py b/tests/linalg/test_krylov.py index 25958f4..40d14d2 100644 --- a/tests/linalg/test_krylov.py +++ b/tests/linalg/test_krylov.py @@ -1,4 +1,5 @@ import importlib +import inspect import numpy as np import pytest @@ -131,6 +132,62 @@ def test_power_iteration_estimates_dominant_eigenpair(backend_name, dtype): assert bool(to_numpy(result.converged)) +def test_power_iteration_accepts_quadratic_form_hessian_action(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + x0 = 
ctx.asarray([1.0, 1.0]) + + op_result = sc.power_iteration(op, x0=x0, tol=1e-5, maxiter=60) + q_result = sc.power_iteration(q, x0=x0, tol=1e-5, maxiter=60) + + np.testing.assert_allclose(to_numpy(q_result.eigenvalue), to_numpy(op_result.eigenvalue)) + np.testing.assert_allclose( + np.abs(to_numpy(q_result.eigenvector)), + np.abs(to_numpy(op_result.eigenvector)), + rtol=1e-6, + atol=1e-6, + ) + + +def test_power_iteration_dispatches_quadratic_form_before_core(monkeypatch): + sc = importlib.import_module("spacecore") + power_mod = importlib.import_module("spacecore.linalg._power") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + x0 = ctx.asarray([1.0, 0.0]) + captured = {} + + def fake_core(action, x, tol, maxiter, check_every): + captured["action"] = action + captured["x"] = x + return ctx.asarray(0.0), x, ctx.asarray(True), 0, ctx.asarray(0.0) + + monkeypatch.setattr(power_mod, "_power_iteration_core", fake_core) + result = power_mod.power_iteration(q, x0=x0, maxiter=1) + + assert result.eigenvector is x0 + assert isinstance(captured["action"], power_mod._SelfAdjointAction) + assert captured["action"].domain == q.domain + x = ctx.asarray([1.0, 2.0]) + np.testing.assert_allclose(captured["action"].apply(x), q.hess_apply(x)) + + +def test_power_iteration_core_has_no_dispatch_logic(): + power_mod = importlib.import_module("spacecore.linalg._power") + source = inspect.getsource(power_mod._power_iteration_core) + + assert "isinstance" not in source + assert "hasattr" not in source + assert "getattr" not in source + assert "_SelfAdjointAction(" not in source + assert "PowerIterationResult(" not in source + + @pytest.mark.parametrize("backend_name,dtype", _backend_params()) def test_stochastic_lanczos_approximates_smallest_eigenpair(backend_name, dtype): sc = importlib.import_module("spacecore") @@ -155,6 +212,20 @@ def 
test_stochastic_lanczos_approximates_smallest_eigenpair(backend_name, dtype) ) +def test_stochastic_lanczos_returns_result_object(): + sc = importlib.import_module("spacecore") + ctx = _ctx() + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + + result = sc.stochastic_lanczos(op, ctx.asarray([1.0, 1.0]), max_iter=2, tol=1e-8) + eigenvalue, eigenvector = result + + assert isinstance(result, sc.StochasticLanczosResult) + np.testing.assert_allclose(eigenvalue, result.eigenvalue) + np.testing.assert_allclose(eigenvector, result.eigenvector) + + def test_stochastic_lanczos_uses_e0_for_zero_initial_vector(): sc = importlib.import_module("spacecore") ctx = _ctx() @@ -251,6 +322,21 @@ def test_power_iteration_jit_compiles_with_operator_argument(): np.testing.assert_allclose(to_numpy(eigenvalue), 5.0, rtol=1e-5, atol=1e-5) +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_power_iteration_jit_compiles_with_quadratic_form_argument(): + jax = pytest.importorskip("jax") + sc = importlib.import_module("spacecore") + ctx = _ctx("jax", jax_real_dtype()) + space = sc.VectorSpace((2,), ctx) + op = sc.DenseLinOp(ctx.asarray([[2.0, 0.0], [0.0, 5.0]]), space, space, ctx) + q = sc.LinOpQuadraticForm(op, ctx=ctx) + + run = jax.jit(lambda quad, x: sc.power_iteration(quad, x0=x, maxiter=60).eigenvalue) + eigenvalue = run(q, ctx.asarray([1.0, 1.0])) + + np.testing.assert_allclose(to_numpy(eigenvalue), 5.0, rtol=1e-5, atol=1e-5) + + @pytest.mark.skipif(not has_jax(), reason="jax is not installed") def test_stochastic_lanczos_jit_compiles_with_operator_argument(): jax = pytest.importorskip("jax")