Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ A `Space` describes the structure and geometry of values:
- `VectorSpace` for Euclidean vectors and tensors;
- `HermitianSpace` for Hermitian or symmetric matrices;
- `ProductSpace` for Cartesian products of spaces.
- `BatchSpace` for batched elements such as `X.batch((B,), (0,))`,
representing `B` independent copies of `X`.

Algorithms should use space methods such as `zeros`, `add`, `scale`, `axpy`,
`inner`, `norm`, `flatten`, and `unflatten` instead of hard-coding backend array
Expand All @@ -168,6 +170,20 @@ A `LinOp` represents a linear operator between spaces:
Operators expose `apply` and `rapply`, so algorithms can use a linear map and
its adjoint without depending on the storage format.

For batched inputs, `vapply(xs)` and `rvapply(ys)` lift the operator over the
leading batch axis:

```python
XB = X.batch(batch_shape=(B,), batch_axes=(0,))
YB = Y.batch(batch_shape=(B,), batch_axes=(0,))

ys = A.vapply(xs, batch_space=XB) # xs in XB, ys in YB
xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB
```

The fallback uses backend `vmap`; dense, sparse, diagonal, identity, zero,
algebraic, and product-structured operators provide specialized batched paths.

## Who should use this?

SpaceCore is aimed at people writing optimization, inverse-problem, optimal
Expand Down
10 changes: 10 additions & 0 deletions docs/source/api/linops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ actions.
spacecore.linop.LinOp
spacecore.linop.ProductLinOp
spacecore.linop.DenseLinOp
spacecore.linop.DiagonalLinOp
spacecore.linop.SparseLinOp
spacecore.linop.BlockDiagonalLinOp
spacecore.linop.StackedLinOp
Expand Down Expand Up @@ -42,6 +43,15 @@ DenseLinOp
:inherited-members:
:show-inheritance:

DiagonalLinOp
-------------

.. autoclass:: spacecore.linop.DiagonalLinOp
:members:
:undoc-members:
:inherited-members:
:show-inheritance:

SparseLinOp
-----------

Expand Down
10 changes: 10 additions & 0 deletions docs/source/api/spaces.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Spaces define element structure, geometry, flattening, and validation.
spacecore.space.VectorSpace
spacecore.space.HermitianSpace
spacecore.space.ProductSpace
spacecore.space.BatchSpace
spacecore.space.SpaceCheck
spacecore.space.SpaceValidationError

Expand Down Expand Up @@ -49,6 +50,15 @@ ProductSpace
:inherited-members:
:show-inheritance:

BatchSpace
----------

.. autoclass:: spacecore.space.BatchSpace
:members:
:undoc-members:
:inherited-members:
:show-inheritance:

Validation
----------

Expand Down
16 changes: 16 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ A ``Space`` describes the structure and geometry of values:
* ``VectorSpace`` for Euclidean vectors and tensors;
* ``HermitianSpace`` for Hermitian or symmetric matrices;
* ``ProductSpace`` for Cartesian products of spaces.
* ``BatchSpace`` for batched elements such as ``X.batch((B,), (0,))``,
representing ``B`` independent copies of ``X``.

Algorithms should use space methods such as ``zeros``, ``add``, ``scale``,
``axpy``, ``inner``, ``norm``, ``flatten``, and ``unflatten`` instead of
Expand All @@ -176,6 +178,20 @@ A ``LinOp`` represents a linear operator between spaces:
Operators expose ``apply`` and ``rapply``, so algorithms can use a linear map
and its adjoint without depending on the storage format.

For batched inputs, ``vapply(xs)`` and ``rvapply(ys)`` lift the operator over
the leading batch axis:

.. code-block:: python

XB = X.batch(batch_shape=(B,), batch_axes=(0,))
YB = Y.batch(batch_shape=(B,), batch_axes=(0,))

ys = A.vapply(xs, batch_space=XB) # xs in XB, ys in YB
xs2 = A.rvapply(ys, batch_space=YB) # ys in YB, xs2 in XB

The fallback uses backend ``vmap``; dense, sparse, diagonal, identity, zero,
algebraic, and product-structured operators provide specialized batched paths.

Who should use this?
--------------------

Expand Down
33 changes: 33 additions & 0 deletions docs/source/tutorials/linops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ the abstraction for linear maps between spaces.
Current implemented operator types are:

* ``DenseLinOp``
* ``DiagonalLinOp``
* ``SparseLinOp``
* ``BlockDiagonalLinOp``
* ``StackedLinOp``
Expand Down Expand Up @@ -60,6 +61,38 @@ operator :math:`A^* : Y \to X`, satisfying

\langle Ax, y\rangle_Y = \langle x, A^*y\rangle_X.

Batched lifting
---------------

For a batch of elements, use ``Space.batch`` to describe the batched space and
``vapply`` or ``rvapply`` to lift the operator:

.. math::

A : X \to Y,
\qquad
A^{(B)} : X^B \to Y^B.

.. code-block:: python

B = 8
XB = X.batch(batch_shape=(B,), batch_axes=(0,))
YB = Y.batch(batch_shape=(B,), batch_axes=(0,))

xs = ctx.asarray(np.ones((B,) + X.shape))
ys = op.vapply(xs, batch_space=XB)
xs_back = op.rvapply(ys, batch_space=YB)

This is equivalent to stacking scalar applications:

.. code-block:: python

ys_ref = ctx.ops.stack(tuple(op.apply(x) for x in xs), axis=0)

The base fallback uses backend ``vmap``. Structured operators override this
path when they can use matrix multiplication, sparse multi-vector products,
broadcasting, or componentwise product-space batching.

DenseLinOp
----------

Expand Down
4 changes: 4 additions & 0 deletions spacecore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .linop import (
BlockDiagonalLinOp,
ComposedLinOp,
DiagonalLinOp,
DenseLinOp,
IdentityLinOp,
LinOp,
Expand All @@ -37,6 +38,7 @@
stochastic_lanczos,
)
from .space import (
BatchSpace,
BackendCheck,
DTypeCheck,
HermitianCheck,
Expand Down Expand Up @@ -72,6 +74,7 @@

"LinOp",
"ComposedLinOp",
"DiagonalLinOp",
"DenseLinOp",
"IdentityLinOp",
"MatrixFreeLinOp",
Expand Down Expand Up @@ -100,6 +103,7 @@
"ProductComponentCheck",
"ProductStructureCheck",
"ShapeCheck",
"BatchSpace",
"VectorSpace",
"HermitianSpace",
"ProductSpace",
Expand Down
71 changes: 71 additions & 0 deletions spacecore/backend/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,77 @@ def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray:
"""Stack arrays along a new axis (delegates to xp.stack)."""
return self.xp.stack(tuple(arrays), axis=axis)

def vmap(
self,
fn: Callable,
in_axes: int | Sequence[int | None] | None = 0,
out_axes: int | Sequence[int | None] | None = 0,
) -> Callable:
"""Vectorize ``fn`` over array axes using a Python-loop fallback."""

def axis_for_arg(i: int) -> int | Sequence[int | None] | None:
if isinstance(in_axes, tuple) or isinstance(in_axes, list):
return in_axes[i]
return in_axes

def normalize_axis(axis: int, ndim: int) -> int:
return axis + ndim if axis < 0 else axis

def tree_size(x: Any, axis: Any) -> int | None:
if axis is None:
return None
if isinstance(x, tuple):
axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(x)
for xi, ai in zip(x, axes):
size = tree_size(xi, ai)
if size is not None:
return size
return None
shape = tuple(getattr(x, "shape", ()))
axis = normalize_axis(int(axis), len(shape))
return int(shape[axis])

def tree_take(x: Any, axis: Any, i: int) -> Any:
if axis is None:
return x
if isinstance(x, tuple):
axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(x)
return tuple(tree_take(xi, ai, i) for xi, ai in zip(x, axes))
shape = tuple(getattr(x, "shape", ()))
axis = normalize_axis(int(axis), len(shape))
index = [slice(None)] * len(shape)
index[axis] = i
return x[tuple(index)]

def tree_stack(xs: Sequence[Any], axis: Any) -> Any:
first = xs[0]
if isinstance(first, tuple):
axes = axis if isinstance(axis, (tuple, list)) else (axis,) * len(first)
return tuple(
tree_stack(tuple(x[i] for x in xs), ai)
for i, ai in enumerate(axes)
)
if axis is None:
return first
return self.stack(xs, axis=int(axis))

def mapped(*args: Any) -> Any:
axes = tuple(axis_for_arg(i) for i in range(len(args)))
size = None
for arg, axis in zip(args, axes):
size = tree_size(arg, axis)
if size is not None:
break
if size is None:
return fn(*args)
outputs = tuple(
fn(*(tree_take(arg, axis, i) for arg, axis in zip(args, axes)))
for i in range(size)
)
return tree_stack(outputs, out_axes)

return mapped

def conj(self, x: DenseArray) -> DenseArray:
"""Complex conjugate of x (delegates to xp.conj)."""
return self.xp.conj(x)
Expand Down
9 changes: 9 additions & 0 deletions spacecore/backend/jax/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray:
"""
return a @ b

def vmap(
self,
fn: Callable,
in_axes: int | Sequence[int | None] | None = 0,
out_axes: int | Sequence[int | None] | None = 0,
) -> Callable:
"""Vectorize a function using ``jax.vmap``."""
return self.jax.vmap(fn, in_axes=in_axes, out_axes=out_axes)

def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False,
return_sign: bool = False, where: DenseArray | None = None) -> DenseArray | Tuple[DenseArray, DenseArray]:
"""
Expand Down
14 changes: 14 additions & 0 deletions spacecore/backend/torch/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,20 @@ def sparse_matmul(
return self.torch.sparse.mm(a, b[:, None], **kwargs)[:, 0]
return self.torch.sparse.mm(a, b, **kwargs)

def vmap(
self,
fn: Callable,
in_axes: int | Sequence[int | None] | None = 0,
out_axes: int | Sequence[int | None] | None = 0,
) -> Callable:
"""Vectorize a function using PyTorch's native vmap when available."""
vmap = getattr(self.torch, "vmap", None)
if vmap is None and hasattr(self.torch, "func"):
vmap = getattr(self.torch.func, "vmap", None)
if vmap is None:
return super().vmap(fn, in_axes=in_axes, out_axes=out_axes)
return vmap(fn, in_dims=in_axes, out_dims=out_axes)

def eigh(
self,
x: DenseArray,
Expand Down
14 changes: 13 additions & 1 deletion spacecore/linalg/_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..linop import LinOp
from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter
from ._utils import is_converged, real_inner, require_linop, require_square
from ._utils import safe_inverse, should_check_iteration, threshold
from ._utils import result_repr, safe_inverse, should_check_iteration, threshold


class CGResult(NamedTuple):
Expand All @@ -16,6 +16,18 @@ class CGResult(NamedTuple):
num_iters: Any
residual_norm: Any

def __repr__(self) -> str:
"""Return a compact summary without printing the full solution array."""
return result_repr(
"CGResult",
{
"converged": self.converged,
"num_iters": self.num_iters,
"residual_norm": self.residual_norm,
"x": self.x,
},
)


def cg(
A: LinOp,
Expand Down
15 changes: 14 additions & 1 deletion spacecore/linalg/_lsqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..linop import LinOp
from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter
from ._utils import is_converged, require_linop, safe_inverse, should_check_iteration
from ._utils import threshold
from ._utils import result_repr, threshold


class LSQRResult(NamedTuple):
Expand All @@ -17,6 +17,19 @@ class LSQRResult(NamedTuple):
residual_norm: Any
normal_residual_norm: Any

def __repr__(self) -> str:
"""Return a compact summary without printing the full solution array."""
return result_repr(
"LSQRResult",
{
"converged": self.converged,
"num_iters": self.num_iters,
"residual_norm": self.residual_norm,
"normal_residual_norm": self.normal_residual_norm,
"x": self.x,
},
)


def lsqr(
A: LinOp,
Expand Down
15 changes: 14 additions & 1 deletion spacecore/linalg/_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..linop import LinOp
from ._utils import DEFAULT_CONVERGENCE_CHECK_INTERVAL, check_interval, check_maxiter
from ._utils import default_initial_vector, is_converged, normalize, require_linop
from ._utils import require_square, should_check_iteration
from ._utils import require_square, result_repr, should_check_iteration


class PowerIterationResult(NamedTuple):
Expand All @@ -17,6 +17,19 @@ class PowerIterationResult(NamedTuple):
num_iters: Any
residual_norm: Any

def __repr__(self) -> str:
"""Return a compact summary without printing the full eigenvector."""
return result_repr(
"PowerIterationResult",
{
"converged": self.converged,
"num_iters": self.num_iters,
"eigenvalue": self.eigenvalue,
"residual_norm": self.residual_norm,
"eigenvector": self.eigenvector,
},
)


def power_iteration(
A: LinOp,
Expand Down
Loading
Loading