Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mathematical objects:

- a `Space` knows the structure and geometry of its elements;
- a `LinOp` maps one space to another;
- a `Functional` maps a space element to a scalar;
- backend-specific array creation and operations live behind `BackendOps`.

The result is ordinary Python code whose core numerical logic is not tied to
Expand All @@ -28,7 +29,7 @@ one array library.
Mental model:

```text
BackendOps -> Context -> Space/LinOp -> Algorithm
BackendOps -> Context -> Space/LinOp/Functional -> Algorithm
```

## Write once, run twice
Expand Down Expand Up @@ -184,6 +185,17 @@ xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB
The fallback uses backend `vmap`; dense, sparse, diagonal, identity, zero,
algebraic, and product-structured operators provide specialized batched paths.

### `Functional`

A `Functional` represents a scalar-valued map on a space. `LinearFunctional`
covers maps such as `<c, x>`, `MatrixFreeLinearFunctional` wraps a callable
without storing a representer, and `LinOpQuadraticForm` represents objectives
such as `0.5 * <x, Qx> + ell(x) + a`.

For batched inputs, `vvalue(xs)` evaluates independently over leading batch
axes. Quadratic forms that define gradients also expose `grad(x)` and
`vgrad(xs)`.

## Who should use this?

SpaceCore is aimed at people writing optimization, inverse-problem, optimal
Expand Down
60 changes: 60 additions & 0 deletions docs/source/api/functionals.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
Functionals API
===============

Functionals represent scalar-valued maps on spaces, including linear
functionals and quadratic forms.

.. autosummary::
:nosignatures:

spacecore.functional.Functional
spacecore.functional.LinearFunctional
spacecore.functional.InnerProductFunctional
spacecore.functional.MatrixFreeLinearFunctional
spacecore.functional.QuadraticForm
spacecore.functional.LinOpQuadraticForm

Functional
----------

.. autoclass:: spacecore.functional.Functional
:members:
:undoc-members:
:inherited-members:
:show-inheritance:

Linear functionals
------------------

.. autoclass:: spacecore.functional.LinearFunctional
:members:
:undoc-members:
:inherited-members:
:show-inheritance:

.. autoclass:: spacecore.functional.InnerProductFunctional
:members:
:undoc-members:
:inherited-members:
:show-inheritance:

.. autoclass:: spacecore.functional.MatrixFreeLinearFunctional
:members:
:undoc-members:
:inherited-members:
:show-inheritance:

Quadratic forms
---------------

.. autoclass:: spacecore.functional.QuadraticForm
:members:
:undoc-members:
:inherited-members:
:show-inheritance:

.. autoclass:: spacecore.functional.LinOpQuadraticForm
:members:
:undoc-members:
:inherited-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ directives for public objects instead of dumping entire modules.
context
spaces
linops
functionals
16 changes: 15 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mathematical objects:

* a ``Space`` knows the structure and geometry of its elements;
* a ``LinOp`` maps one space to another;
* a ``Functional`` maps a space element to a scalar;
* backend-specific array creation and operations live behind ``BackendOps``.

The result is ordinary Python code whose core numerical logic is not tied to
Expand All @@ -31,7 +32,7 @@ Mental model:

.. code-block:: text

BackendOps -> Context -> Space/LinOp -> Algorithm
BackendOps -> Context -> Space/LinOp/Functional -> Algorithm

Write once, run twice
---------------------
Expand Down Expand Up @@ -192,6 +193,19 @@ the leading batch axis:
The fallback uses backend ``vmap``; dense, sparse, diagonal, identity, zero,
algebraic, and product-structured operators provide specialized batched paths.

``Functional``
~~~~~~~~~~~~~~

A ``Functional`` represents a scalar-valued map on a space.
``LinearFunctional`` covers maps such as ``<c, x>``,
``MatrixFreeLinearFunctional`` wraps a callable without storing a representer,
and ``LinOpQuadraticForm`` represents objectives such as
``0.5 * <x, Qx> + ell(x) + a``.

For batched inputs, ``vvalue(xs)`` evaluates independently over leading batch
axes. Quadratic forms that define gradients also expose ``grad(x)`` and
``vgrad(xs)``.

Who should use this?
--------------------

Expand Down
17 changes: 17 additions & 0 deletions spacecore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,19 @@
make_scaled,
make_sum,
)
from .functional import (
Functional,
InnerProductFunctional,
LinearFunctional,
LinOpQuadraticForm,
MatrixFreeLinearFunctional,
QuadraticForm,
)
from .linalg import (
CGResult,
LSQRResult,
PowerIterationResult,
StochasticLanczosResult,
cg,
lsqr,
power_iteration,
Expand Down Expand Up @@ -89,9 +98,17 @@
"SumToSingleLinOp",
"StackedLinOp",

"Functional",
"LinearFunctional",
"InnerProductFunctional",
"MatrixFreeLinearFunctional",
"QuadraticForm",
"LinOpQuadraticForm",

"CGResult",
"LSQRResult",
"PowerIterationResult",
"StochasticLanczosResult",
"cg",
"lsqr",
"power_iteration",
Expand Down
12 changes: 12 additions & 0 deletions spacecore/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from ._base import Functional
from ._linear import InnerProductFunctional, LinearFunctional, MatrixFreeLinearFunctional
from ._quadratic import LinOpQuadraticForm, QuadraticForm

__all__ = [
"Functional",
"InnerProductFunctional",
"LinearFunctional",
"LinOpQuadraticForm",
"MatrixFreeLinearFunctional",
"QuadraticForm",
]
130 changes: 130 additions & 0 deletions spacecore/functional/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Generic, TypeVar

from .._contextual import ContextBound
from .._contextual.manager import ctx_manager
from ..backend import Context
from ..space import Space


Domain = TypeVar("Domain", bound=Space)


class Functional(ContextBound, Generic[Domain]):
"""
Scalar-valued map on a space.

``Functional`` represents a map ``F : X -> K`` without assuming any storage
model. It mirrors the minimal ``LinOp`` contract: the domain is converted
into the resolved context, value checks follow ``ctx.enable_checks``, and
batched evaluation is implemented by a backend ``vmap`` fallback.
"""

def __init__(self, dom: Domain, ctx: Context | str | None = None) -> None:
ctx = ctx_manager.resolve_context_priority(ctx, dom)
super().__init__(ctx)
self.dom = dom.convert(self.ctx)
self._enable_checks = self.ctx.enable_checks

@property
def domain(self) -> Domain:
"""Domain space of this scalar-valued map."""
return self.dom

@abstractmethod
def value(self, x: Any) -> Any:
"""
Evaluate this functional at ``x``.

Contract:
- x is an element of ``self.domain``;
- the return value is scalar-like in the functional context.
"""

def __call__(self, x: Any) -> Any:
"""Evaluate this functional at ``x``."""
return self.value(x)

def vvalue(self, xs: Any, batch_space: Space | None = None) -> Any:
"""Evaluate this functional independently over leading batch axes."""
return self._fallback_vvalue(xs, batch_space)

def _infer_batch_shape(self, space: Space, value: Any) -> tuple[int, ...]:
if hasattr(space, "spaces") and isinstance(value, tuple) and value:
return self._infer_batch_shape(space.spaces[0], value[0])
shape = tuple(getattr(value, "shape", ()))
base_shape = tuple(space.shape)
if not base_shape:
return shape
if len(shape) < len(base_shape) or shape[-len(base_shape):] != base_shape:
raise ValueError(
f"Cannot infer leading batch shape for value shape {shape} "
f"and base space shape {base_shape}."
)
return shape[: len(shape) - len(base_shape)]

def _input_batch_space(
self,
space: Space,
value: Any,
batch_space: Space | None,
) -> Space:
if batch_space is not None:
return batch_space
batch_shape = self._infer_batch_shape(space, value)
return space.batch(batch_shape, tuple(range(len(batch_shape))))

def _output_batch_space(self, space: Space, input_batch_space: Space) -> Space:
batch_shape = getattr(input_batch_space, "batch_shape", None)
batch_axes = getattr(input_batch_space, "batch_axes", None)
if batch_shape is None or batch_axes is None:
raise TypeError("batch_space must be a BatchSpace-compatible object.")
return space.batch(tuple(batch_shape), tuple(batch_axes))

def _require_leading_batch_axes(self, batch_space: Space) -> tuple[int, ...]:
batch_shape = tuple(getattr(batch_space, "batch_shape", ()))
batch_axes = tuple(getattr(batch_space, "batch_axes", ()))
expected_axes = tuple(range(len(batch_shape)))
if batch_axes != expected_axes:
raise ValueError(
"Functional batching currently expects leading batch axes; "
f"got batch_axes={batch_axes}, expected {expected_axes}."
)
return batch_shape

def _vmap_leading(self, fn: Any, batch_ndim: int) -> Any:
mapped = fn
for _ in range(batch_ndim):
mapped = self.ops.vmap(mapped, in_axes=0, out_axes=0)
return mapped

def _check_scalar_batch(self, values: Any, batch_shape: tuple[int, ...]) -> None:
shape = tuple(getattr(values, "shape", ()))
if shape != batch_shape:
raise ValueError(
f"Expected scalar batch output with shape {batch_shape}, got {shape}."
)

def _fallback_vvalue(self, xs: Any, batch_space: Space | None = None) -> Any:
in_space = self._input_batch_space(self.domain, xs, batch_space)
batch_shape = self._require_leading_batch_axes(in_space)
if self._enable_checks:
in_space._check_member(xs)
values = self._vmap_leading(self.value, len(batch_shape))(xs)
if self._enable_checks:
self._check_scalar_batch(values, batch_shape)
return values

def assert_domain(self, x: Any) -> None:
self.dom.check_member(x)

@abstractmethod
def tree_flatten(self):
...

@classmethod
@abstractmethod
def tree_unflatten(cls, aux, children):
...
Loading
Loading