diff --git a/docs/source/design/index.rst b/docs/source/design/index.rst index 1d9df1b..390b4ad 100644 --- a/docs/source/design/index.rst +++ b/docs/source/design/index.rst @@ -10,3 +10,4 @@ users should reason about them. conversion_policy dtype_policy checking_policy + backend_ops_array_api diff --git a/docs/source/tutorials/backend_ops.rst b/docs/source/tutorials/backend_ops.rst index c90a88e..47b58ce 100644 --- a/docs/source/tutorials/backend_ops.rst +++ b/docs/source/tutorials/backend_ops.rst @@ -52,6 +52,14 @@ What BackendOps signifies internals. It mostly wraps NumPy-like methods, while normalizing the minimal signatures that SpaceCore relies on. +Common dense-array methods are implemented once in ``BackendOps`` by delegating +to an Array API compatible ``xp`` namespace. NumPy and PyTorch use +``array-api-compat`` wrappers, while JAX uses ``jax.numpy``. Concrete backend +classes keep behavior that is genuinely backend-specific, such as dtype +sanitation, sparse conversion, indexed updates, device/autograd controls, and +control-flow primitives. ``ops.xp`` is available as an escape hatch, but +portable SpaceCore code should prefer explicit ``ops`` methods. + For example, NumPy and JAX expose different optional arguments for matrix multiplication, but SpaceCore's portable interface only needs the common core: diff --git a/pyproject.toml b/pyproject.toml index 198b1b6..2796cbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ authors = [ { name = "Pavlo Pelikh" } ] dependencies = [ + "array-api-compat>=1.14.0", "numpy>=2.0.0", "scipy>=1.17", ] diff --git a/spacecore/backend/_ops.py b/spacecore/backend/_ops.py index 5bca111..09e2f89 100644 --- a/spacecore/backend/_ops.py +++ b/spacecore/backend/_ops.py @@ -1,54 +1,57 @@ from __future__ import annotations from abc import ABC, abstractmethod +import importlib from typing import Any, Sequence, Tuple, Callable, Optional, Type, ClassVar from ..types import DenseArray, SparseArray, DType, ArrayLike, Index, T, X, Y, R, Carry +class LazyNamespace: + def __init__(self, module_name: str) -> None: + self.__name__ = module_name + self.__isabstractmethod__ = False + self._module_name = module_name + self._module: Any | None = None + + def _load(self) -> Any: + if self._module is None: + self._module = importlib.import_module(self._module_name) + return self._module + + def __getattr__(self, name: str) -> Any: + return getattr(self._load(), name) + + @property + def is_loaded(self) -> bool: + return self._module is not None + + class BackendOps(ABC): """ - Backend-agnostic numerical ops interface (portable core). + Public numerical contract for SpaceCore backends. - Contract: - - This base class exposes only the portable subset used by library internals. - - Concrete backends (NumPy/JAX/Torch) may extend these methods with additional - optional keyword parameters (e.g., `order=`, `out=`, `where=`, `like=`, ...). + Common dense-array operations delegate to the backend's Array API-compatible + ``xp`` namespace. Subclasses provide backend-specific sparse conversion, + dtype policy, indexing mutation, control flow, device/autograd behavior, and + operations not covered by the Array API. """ _family: ClassVar[str] _allow_sparse: ClassVar[bool] + xp: ClassVar[Any] + + def __init__(self) -> None: + self._constant_cache: dict[str, DenseArray] = {} @property def family(self) -> str: - """ - Generic backend-agnostic wrapper to backend family identifier. - - Input: - None. - - Output: - String naming the concrete backend family. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Backend family identifier.""" return type(self)._family @property def allow_sparse(self) -> bool: - """ - Generic backend-agnostic wrapper to sparse-array support flag. - - Input: - None. - - Output: - Boolean indicating whether this backend supports sparse arrays. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Whether this backend supports sparse arrays.""" return self._allow_sparse def __eq__(self, other: Any) -> bool: @@ -56,753 +59,412 @@ def __eq__(self, other: Any) -> bool: return self.family == other.family return False + def __hash__(self) -> int: + return hash((type(self).__name__, self.family)) + @property @abstractmethod def dense_array(self) -> Type[Any]: - """ - Generic backend-agnostic wrapper to dense array type. - - Input: - None. - - Output: - Concrete dense array class accepted by this backend. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Dense array type accepted by this backend.""" ... @property @abstractmethod def sparse_array(self) -> Tuple[Type[Any], ...] | None: - """ - Generic backend-agnostic wrapper to sparse array type tuple. - - Input: - None. - - Output: - Concrete sparse array classes accepted by this backend, or None. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Sparse array types accepted by this backend, or None.""" ... @abstractmethod def sanitize_dtype(self, dtype: DType | None) -> DType: - """ - Generic backend-agnostic wrapper to normalize a dtype specifier. - - Input: - dtype: Optional dtype requested by SpaceCore or the caller. - - Output: - Backend dtype object accepted by array constructors. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Normalize a dtype specifier for this backend.""" ... def is_dense(self, x: Any) -> bool: - """ - Generic backend-agnostic wrapper to test for a dense backend array. - - Input: - x: Object to test. - - Output: - True when x is an instance of the backend dense array type. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return whether x is a dense array for this backend.""" return isinstance(x, self.dense_array) def is_sparse(self, x: Any) -> bool: - """ - Generic backend-agnostic wrapper to test for a sparse backend array. - - Input: - x: Object to test. - - Output: - True when x is an instance of a backend sparse array type. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return whether x is a sparse array for this backend.""" return self.sparse_array is not None and isinstance(x, self.sparse_array) def is_array(self, x: Any) -> bool: - """ - Generic backend-agnostic wrapper to test for any backend array. - - Input: - x: Object to test. - - Output: - True when x is dense or sparse for this backend. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Return whether x is any array for this backend.""" return self.is_dense(x) or self.is_sparse(x) @abstractmethod - def get_dtype(self, x: Any) -> DType: - """ - Generic backend-agnostic wrapper to return an array dtype. - - Input: - x: Dense or sparse backend array. - - Output: - Backend dtype associated with x. + def assparse(self, x: Any, dtype: DType | None = None) -> SparseArray: + """Convert input to a backend sparse array.""" + ... - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + @abstractmethod + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: + """Multiply a sparse array by a dense array.""" ... @abstractmethod - def shape(self, x: Any) -> tuple[int, ...]: - """ - Generic backend-agnostic wrapper to return array shape metadata. + def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, + keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: + """Compute a stable log-sum-exp reduction.""" + ... - Input: - x: Dense or sparse backend array. + @abstractmethod + def index_set( + self, + x: DenseArray, + index: Index, + values: ArrayLike, + *, + copy: bool = True, + ) -> DenseArray: + """Set indexed values using backend mutation semantics.""" - Output: - Tuple describing the logical shape of x. + @abstractmethod + def index_add( + self, + x: DenseArray, + index: Index, + values: DenseArray, + *, + copy: bool = True, + ) -> DenseArray: + """Add values into indexed positions using backend mutation semantics.""" + ... - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + @abstractmethod + def ix_(self, *args: Any) -> Any: + """Build open-mesh index arrays.""" + ... @abstractmethod - def ndim(self, x: Any) -> int: - """ - Generic backend-agnostic wrapper to return array rank metadata. + def fori_loop( + self, + lower: int, + upper: int, + body_fun: Callable[[int, T], T], + init_val: T, + ) -> T: + """Run a counted loop primitive.""" - Input: - x: Dense or sparse backend array. + @abstractmethod + def while_loop( + self, + cond_fun: Callable[[T], bool], + body_fun: Callable[[T], T], + init_val: T, + ) -> T: + """Run a while-loop primitive.""" - Output: - Number of dimensions in x. + @abstractmethod + def scan( + self, + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + ) -> Tuple[Carry, Y]: + """Run a scan primitive.""" - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + @abstractmethod + def cond( + self, + pred: bool, + true_fun: Callable[[T], R], + false_fun: Callable[[T], R], + *operands: Any, + ) -> R: + """Run backend-compatible conditional branch selection.""" + ... @abstractmethod - def size(self, x: Any) -> int: - """ - Generic backend-agnostic wrapper to return logical element count. + def allclose_sparse( + self, + a: SparseArray, + b: SparseArray, + rtol: float = 1e-5, + atol: float = 1e-8, + ) -> bool: + """Compare sparse arrays elementwise within tolerances.""" + ... - Input: - x: Dense or sparse backend array. + def _dtype_arg(self, dtype: DType | None) -> DType | None: + return None if dtype is None else self.sanitize_dtype(dtype) - Output: - Total number of logical dense elements. + def _to_axis_tuple(self, axis: int | Sequence[int] | None) -> int | tuple[int, ...] | None: + if axis is None or isinstance(axis, int): + return axis + return tuple(axis) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def _permute_dims(self, x: DenseArray, axes: Sequence[int]) -> DenseArray: + axes = tuple(axes) + if hasattr(self.xp, "permute_dims"): + return self.xp.permute_dims(x, axes) + if hasattr(self.xp, "permute"): + return self.xp.permute(x, axes) + return self.xp.transpose(x, axes=axes) + + def _move_axis_order( + self, + ndim: int, + source: int | Sequence[int], + destination: int | Sequence[int], + ) -> tuple[int, ...]: + src = (source,) if isinstance(source, int) else tuple(source) + dst = (destination,) if isinstance(destination, int) else tuple(destination) + src = tuple(axis + ndim if axis < 0 else axis for axis in src) + dst = tuple(axis + ndim if axis < 0 else axis for axis in dst) + order = [axis for axis in range(ndim) if axis not in src] + for dest, axis in sorted(zip(dst, src, strict=True)): + order.insert(dest, axis) + return tuple(order) @property - @abstractmethod def inf(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to positive infinity scalar. - - Input: - None. - - Output: - Backend scalar representing positive infinity. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Positive infinity as a cached backend scalar.""" + return self._constant("inf", float("inf")) @property - @abstractmethod def nan(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to access a NaN scalar. - - Input: - None. - - Output: - Backend scalar representing NaN. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """NaN as a cached backend scalar.""" + return self._constant("nan", float("nan")) @property - @abstractmethod def pi(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to pi scalar. - - Input: - None. - - Output: - Backend scalar representing pi. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Pi as a cached backend scalar.""" + return self._constant("pi", 3.141592653589793) @property - @abstractmethod def e(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to access Euler's number scalar. + """Euler's number as a cached backend scalar.""" + return self._constant("e", 2.718281828459045) - Input: - None. + def _constant(self, name: str, value: float) -> DenseArray: + if name not in self._constant_cache: + self._constant_cache[name] = self.asarray(value) + return self._constant_cache[name] - Output: - Backend scalar representing Euler's number. + def eps(self, dtype: DType) -> float: + """Machine epsilon for dtype.""" + return float(self.xp.finfo(self.sanitize_dtype(dtype)).eps) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @property - @abstractmethod - def eps(self) -> DenseArray: - """ - Generic backend-agnostic wrapper to machine epsilon scalar. + def get_dtype(self, x: Any) -> DType: + """Return x.dtype after verifying x is a backend array.""" + if self.is_array(x): + return x.dtype + raise TypeError(f"Expected {self.family} array, got {type(x)}.") - Input: - None. + def shape(self, x: Any) -> tuple[int, ...]: + """Return x.shape as a tuple.""" + return tuple(x.shape) - Output: - Backend scalar for float64 machine epsilon. + def ndim(self, x: Any) -> int: + """Return the number of dimensions of x.""" + return int(x.ndim) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def size(self, x: Any) -> int: + """Return the total number of elements in x.""" + result = 1 + for dim in self.shape(x): + result *= int(dim) + return result + + def asarray(self, x: Any, dtype: DType | None = None, **backend_kwargs: Any) -> DenseArray: + """Convert input to a dense backend array (delegates to xp.asarray).""" + if self.is_sparse(x) and hasattr(x, "to_dense"): + x = x.to_dense() + dtype = self._dtype_arg(dtype) + if hasattr(self.xp, "asarray"): + return self.xp.asarray(x, dtype=dtype, **backend_kwargs) + return self.xp.as_tensor(x, dtype=dtype, **backend_kwargs) + + def astype(self, x: DenseArray, dtype: DType | None, **backend_kwargs: Any) -> DenseArray: + """Cast x to dtype, returning x unchanged when dtype is None.""" + if dtype is None: + return x + dtype = self.sanitize_dtype(dtype) + if hasattr(x, "astype"): + return x.astype(dtype, **backend_kwargs) + return x.to(dtype=dtype, **backend_kwargs) - @abstractmethod - def asarray(self, x: Any, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to convert input to a dense array. + def empty(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: + """Create an uninitialized array (delegates to xp.empty).""" + return self.xp.empty(shape, dtype=self._dtype_arg(dtype)) - Input: - x/a: Array-like input and optional dtype or backend conversion parameters. + def zeros(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: + """Create a zero-filled array (delegates to xp.zeros).""" + return self.xp.zeros(shape, dtype=self._dtype_arg(dtype)) - Output: - Dense backend array. + def ones(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: + """Create a one-filled array (delegates to xp.ones).""" + return self.xp.ones(shape, dtype=self._dtype_arg(dtype)) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def zeros_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: + """Create a zero-filled array like x (delegates to xp.zeros_like).""" + return self.xp.zeros_like(x, dtype=self._dtype_arg(dtype)) - @abstractmethod - def astype(self, x: DenseArray, dtype: DType) -> DenseArray: - """ - Generic backend-agnostic wrapper to cast an array to a dtype. + def ones_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: + """Create a one-filled array like x (delegates to xp.ones_like).""" + return self.xp.ones_like(x, dtype=self._dtype_arg(dtype)) - Input: - x: Dense backend array; dtype: target dtype and optional casting controls. + def full_like(self, x: DenseArray, value: Any, dtype: DType | None = None) -> DenseArray: + """Create a value-filled array like x (delegates to xp.full_like).""" + return self.xp.full_like(x, value, dtype=self._dtype_arg(dtype)) - Output: - Dense backend array with the requested dtype. + def arange( + self, + start: int, + stop: int | None = None, + step: int | None = None, + dtype: DType | None = None, + ) -> DenseArray: + """Create an evenly spaced range (delegates to xp.arange).""" + dtype = self._dtype_arg(dtype) + if stop is None: + return self.xp.arange(start, dtype=dtype) + if step is None: + return self.xp.arange(start, stop, dtype=dtype) + return self.xp.arange(start, stop, step, dtype=dtype) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def full(self, shape: Tuple[int, ...], fill_value: Any, dtype: DType | None = None) -> DenseArray: + """Create a value-filled array (delegates to xp.full).""" + return self.xp.full(shape, fill_value, dtype=self._dtype_arg(dtype)) - @abstractmethod - def assparse(self, x: Any, dtype: DType | None = None) -> SparseArray: - """ - Generic backend-agnostic wrapper to convert input to a sparse array. + def eye(self, n: int, m: int | None = None, dtype: DType | None = None) -> DenseArray: + """Create an identity-like matrix (delegates to xp.eye).""" + return self.xp.eye(n, m, dtype=self._dtype_arg(dtype)) - Input: - x: Dense, sparse, or array-like input plus sparse-format options. + def ravel(self, x: DenseArray) -> DenseArray: + """Flatten x to one dimension.""" + if hasattr(self.xp, "ravel"): + return self.xp.ravel(x) + return self.reshape(x, (-1,)) - Output: - Sparse backend array. + def reshape(self, x: DenseArray, shape: Tuple[int, ...] | int) -> DenseArray: + """Reshape x (delegates to xp.reshape).""" + shape_arg = (shape,) if isinstance(shape, int) else shape + return self.xp.reshape(x, shape_arg) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: + """Permute dimensions of x.""" + if axes is None: + axes = tuple(reversed(range(self.ndim(x)))) + return self._permute_dims(x, axes) - @abstractmethod - def empty(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create an uninitialized dense array. + def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: + """Swap two axes of x.""" + if hasattr(self.xp, "swapaxes"): + return self.xp.swapaxes(x, axis1, axis2) + axes = list(range(self.ndim(x))) + axes[axis1], axes[axis2] = axes[axis2], axes[axis1] + return self._permute_dims(x, axes) - Input: - shape: Output shape; dtype and placement options are backend-specific. + def broadcast_to(self, x: DenseArray, shape: Tuple[int, ...]) -> DenseArray: + """Broadcast x to shape (delegates to xp.broadcast_to).""" + return self.xp.broadcast_to(x, shape) - Output: - Dense backend array with uninitialized values. + def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: + """Insert singleton dimensions into x.""" + if isinstance(axis, int): + if hasattr(self.xp, "expand_dims"): + return self.xp.expand_dims(x, axis=axis) + return self.xp.unsqueeze(x, axis) + out = x + ndim = self.ndim(x) + len(axis) + for ax in sorted(a + ndim if a < 0 else a for a in axis): + out = self.expand_dims(out, ax) + return out - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: + """Remove singleton dimensions from x.""" + if axis is None: + axis = tuple(i for i, dim in enumerate(self.shape(x)) if dim == 1) + if not axis: + return x + axis = (axis,) if isinstance(axis, int) else tuple(axis) + return self.xp.squeeze(x, axis=axis) - @abstractmethod - def zeros(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a zero-filled dense array. + def moveaxis( + self, + x: DenseArray, + source: int | Sequence[int], + destination: int | Sequence[int], + ) -> DenseArray: + """Move axes of x to new positions.""" + if hasattr(self.xp, "moveaxis"): + return self.xp.moveaxis(x, source, destination) + return self._permute_dims(x, self._move_axis_order(self.ndim(x), source, destination)) - Input: - shape: Output shape; dtype and placement options are backend-specific. + def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray: + """Stack arrays along a new axis (delegates to xp.stack).""" + return self.xp.stack(tuple(arrays), axis=axis) - Output: - Dense backend array filled with zeros. + def conj(self, x: DenseArray) -> DenseArray: + """Complex conjugate of x (delegates to xp.conj).""" + return self.xp.conj(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def real(self, x: DenseArray) -> DenseArray: + """Real component of x (delegates to xp.real).""" + return self.xp.real(x) - @abstractmethod - def ones(self, shape: Tuple[int, ...], dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a one-filled dense array. + def imag(self, x: DenseArray) -> DenseArray: + """Imaginary component of x (delegates to xp.imag).""" + return self.xp.imag(x) - Input: - shape: Output shape; dtype and placement options are backend-specific. + def abs(self, x: DenseArray) -> DenseArray: + """Absolute value of x (delegates to xp.abs).""" + return self.xp.abs(x) - Output: - Dense backend array filled with ones. + def sign(self, x: DenseArray) -> DenseArray: + """Elementwise sign of x (delegates to xp.sign).""" + return self.xp.sign(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def sqrt(self, x: DenseArray) -> DenseArray: + """Elementwise square root of x (delegates to xp.sqrt).""" + return self.xp.sqrt(x) - @abstractmethod - def zeros_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create zeros shaped like another array. + def sum( + self, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + dtype: DType | None = None, + ) -> DenseArray: + """Sum over given axes (delegates to xp.sum).""" + return self.xp.sum( + x, + axis=self._to_axis_tuple(axis), + dtype=self._dtype_arg(dtype), + keepdims=keepdims, + ) - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + def mean( + self, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + ) -> DenseArray: + """Mean over given axes (delegates to xp.mean).""" + return self.xp.mean(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) - Output: - Dense backend array of zeros. + def min( + self, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + ) -> DenseArray: + """Minimum over given axes (delegates to xp.min).""" + return self.xp.min(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def max( + self, + x: DenseArray, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + ) -> DenseArray: + """Maximum over given axes (delegates to xp.max).""" + return self.xp.max(x, axis=self._to_axis_tuple(axis), keepdims=keepdims) - @abstractmethod - def ones_like(self, x: DenseArray, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create ones shaped like another array. - - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. - - Output: - Dense backend array of ones. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def full_like(self, x: DenseArray, value: Any, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create filled values shaped like another array. - - Input: - x: Prototype dense array; value/fill_value and dtype options are backend-specific. - - Output: - Dense backend array filled with the requested value. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def arange(self, start: int, stop: int | None = None, step: int | None = None, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create evenly spaced integer-range values. - - Input: - start, stop, step: Range parameters; dtype and placement options are backend-specific. - - Output: - One-dimensional dense backend array. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def full(self, shape: Tuple[int, ...], fill_value: Any, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a filled dense array. - - Input: - shape: Output shape; fill_value and dtype options are backend-specific. - - Output: - Dense backend array filled with fill_value. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def eye(self, n: int, m: int | None = None, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to create a dense identity-like matrix. - - Input: - n and optional m: Matrix dimensions; dtype and placement options are backend-specific. - - Output: - Two-dimensional dense backend array. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def ravel(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to flatten an array. - - Input: - x: Dense backend array plus optional order parameters. - - Output: - One-dimensional dense backend array view or copy. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def reshape(self, x: DenseArray, shape: Tuple[int, ...] | int) -> DenseArray: - """ - Generic backend-agnostic wrapper to reshape an array. - - Input: - x: Dense backend array; shape: New shape plus backend-specific options. - - Output: - Dense backend array with the requested shape. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to permute array axes. - - Input: - x: Dense backend array; axes: Optional axis order. - - Output: - Dense backend array with permuted axes. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Generic backend-agnostic wrapper to interchange two axes. - - Input: - x: Dense backend array; axis1 and axis2: Axes to swap. - - Output: - Dense backend array with the two axes exchanged. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def broadcast_to(self, x: DenseArray, shape: Tuple[int, ...]) -> DenseArray: - """ - Generic backend-agnostic wrapper to broadcast an array to a shape. - - Input: - x: Dense backend array; shape: Target broadcast shape. - - Output: - Dense backend array with broadcast shape. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Generic backend-agnostic wrapper to insert length-one axes. - - Input: - x: Dense backend array; axis: Position or positions to insert. - - Output: - Dense backend array with expanded rank. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to remove length-one axes. - - Input: - x: Dense backend array; axis: Optional axes to squeeze. - - Output: - Dense backend array with selected singleton dimensions removed. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def moveaxis( - self, - x: DenseArray, - source: int | Sequence[int], - destination: int | Sequence[int], - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to move axes to new positions. - - Input: - x: Dense backend array; source and destination: Axis positions. - - Output: - Dense backend array with moved axes. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def stack(self, arrays: Sequence[DenseArray], axis: int = 0) -> DenseArray: - """ - Generic backend-agnostic wrapper to stack arrays along a new axis. - - Input: - arrays: Sequence of dense backend arrays; axis: New axis position. - - Output: - Dense backend array containing stacked inputs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def conj(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute complex conjugates. - - Input: - x: Dense backend array. - - Output: - Dense backend array with conjugated values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def real(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to extract real components. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing real components. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def imag(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to extract imaginary components. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing imaginary components. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def abs(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute absolute values. - - Input: - x: Dense backend array. - - Output: - Dense backend array of absolute values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sign(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute signs elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of signs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sqrt(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute square roots elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of square roots. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sum( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - dtype: DType | None = None, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by summation. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing sums. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def mean( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by arithmetic mean. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing means. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def min( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by minimum. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing minima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def max( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by maximum. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing maxima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def prod( self, x: DenseArray, @@ -810,662 +472,192 @@ def prod( keepdims: bool = False, dtype: DType | None = None, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to reduce by product. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing products. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def trace(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to sum diagonal entries. - - Input: - x: Dense backend array plus optional diagonal and axis controls. - - Output: - Dense backend array or scalar containing trace values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def argsort(self, x: DenseArray, axis: int = -1) -> DenseArray: - """ - Generic backend-agnostic wrapper to return sorting indices. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense integer backend array of indices. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sort(self, x: DenseArray, axis: int = -1) -> DenseArray: - """ - Generic backend-agnostic wrapper to sort values. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense backend array with sorted values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Generic backend-agnostic wrapper to return indices of minimum values. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Generic backend-agnostic wrapper to return indices of maximum values. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def vdot(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute a conjugating vector dot product. - - Input: - x, y: Dense backend arrays accepted by the backend vdot operation. - - Output: - Backend scalar or dense array containing the dot product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def matmul(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute matrix products. - - Input: - a, b: Dense backend arrays with matrix-multiplication-compatible shapes. - - Output: - Dense backend array containing the product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to multiply sparse and dense arrays. - - Input: - a: Sparse backend array; b: Dense backend array. - - Output: - Dense backend array containing the product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute a Kronecker product. - - Input: - a, b: Dense backend arrays. - - Output: - Dense backend array containing the Kronecker product. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to evaluate an Einstein summation expression. - - Input: - subscripts: Einstein summation string; operands: Dense backend arrays. - - Output: - Dense backend array containing the contraction result. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def eigh(self, x: DenseArray) -> tuple[DenseArray, DenseArray]: - """ - Generic backend-agnostic wrapper to compute Hermitian eigenpairs. - - Input: - x: Dense Hermitian or symmetric backend array. - - Output: - Tuple of dense backend arrays containing eigenvalues and eigenvectors. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def norm( - self, - x: DenseArray, - ord: int | str | None = None, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute vector or matrix norms. - - Input: - x: Dense backend array; ord, axis, and keepdims select the norm. - - Output: - Dense backend array or scalar containing norm values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def solve(self, A: DenseArray, b: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to solve dense linear systems. - - Input: - A: Dense coefficient array; b: Dense right-hand side array. - - Output: - Dense backend array solving A @ x = b. + """Product over given axes (delegates to xp.prod).""" + return self.xp.prod( + x, + axis=self._to_axis_tuple(axis), + dtype=self._dtype_arg(dtype), + keepdims=keepdims, + ) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def trace(self, x: DenseArray) -> DenseArray: + """Trace of a matrix (delegates to xp.trace when available).""" + if hasattr(self.xp, "trace"): + return self.xp.trace(x) + return self.sum(self.diagonal(x)) - @abstractmethod - def eigvalsh(self, A: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute Hermitian eigenvalues. + def argsort(self, x: DenseArray, axis: int = -1) -> DenseArray: + """Indices that sort x (delegates to xp.argsort).""" + return self.xp.argsort(x, axis=axis) - Input: - A: Dense Hermitian or symmetric backend array. + def sort(self, x: DenseArray, axis: int = -1) -> DenseArray: + """Sort x along an axis (delegates to xp.sort).""" + return self.xp.sort(x, axis=axis) - Output: - Dense backend array containing eigenvalues. + def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: + """Indices of minima (delegates to xp.argmin).""" + return self.xp.argmin(x, axis=axis, keepdims=keepdims) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: + """Indices of maxima (delegates to xp.argmax).""" + return self.xp.argmax(x, axis=axis, keepdims=keepdims) - @abstractmethod - def svd(self, A: DenseArray, full_matrices: bool = True) -> tuple[DenseArray, DenseArray, DenseArray]: + def vdot(self, x: DenseArray, y: DenseArray) -> DenseArray: """ - Generic backend-agnostic wrapper to compute singular value decompositions. - - Input: - A: Dense backend array plus SVD options. - - Output: - Dense backend arrays containing singular vectors and/or singular values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. + Returns sum(conj(x) * y). Matches numpy/jax/torch vdot and Array API + vecdot. DenseLinOp.rapply relies on this for complex inputs. """ + x_flat = self.ravel(x) + y_flat = self.ravel(y) + if hasattr(self.xp, "vdot"): + return self.xp.vdot(x_flat, y_flat) + return self.xp.vecdot(x_flat, y_flat) - @abstractmethod - def cholesky(self, A: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute Cholesky factors. + def matmul( + self, + a: DenseArray, + b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + """Matrix product (delegates to xp.matmul).""" + return self.xp.matmul(a, b, **({} if backend_kwargs is None else backend_kwargs)) - Input: - A: Dense Hermitian positive-definite backend array. + def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: + """Kronecker product (delegates to xp.kron).""" + return self.xp.kron(a, b) - Output: - Dense backend array containing a triangular factor. + def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: + """Einstein summation (delegates to xp.einsum).""" + return self.xp.einsum(subscripts, *operands) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + def eigh( + self, + x: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> tuple[DenseArray, DenseArray]: + """Eigenpairs of a Hermitian dense matrix (delegates to xp.linalg.eigh).""" + if self.is_sparse(x): + raise TypeError("eigh requires a dense array; sparse input is not supported.") + return self.xp.linalg.eigh(x, **({} if backend_kwargs is None else backend_kwargs)) - @abstractmethod - def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, - keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: - """ - Generic backend-agnostic wrapper to compute a stable log-sum-exp reduction. + def norm( + self, + x: DenseArray, + ord: int | str | None = None, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, + ) -> DenseArray: + """Vector or matrix norm (delegates to xp.linalg.norm).""" + return self.xp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) - Input: - a: Dense backend array; axis, weights, and sign options control the reduction. + def solve( + self, + A: DenseArray, + b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + """Solve a dense linear system (delegates to xp.linalg.solve).""" + return self.xp.linalg.solve(A, b, **({} if backend_kwargs is None else backend_kwargs)) - Output: - Dense backend array or tuple containing log-sum-exp results. + def eigvalsh( + self, + A: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + """Eigenvalues of a Hermitian dense matrix (delegates to xp.linalg.eigvalsh).""" + return self.xp.linalg.eigvalsh(A, **({} if backend_kwargs is None else backend_kwargs)) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def svd( + self, + A: DenseArray, + full_matrices: bool = True, + backend_kwargs: dict[str, Any] | None = None, + ) -> tuple[DenseArray, DenseArray, DenseArray]: + """Singular value decomposition (delegates to xp.linalg.svd).""" + return self.xp.linalg.svd( + A, + full_matrices=full_matrices, + **({} if backend_kwargs is None else backend_kwargs), + ) + + def cholesky( + self, + A: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + """Cholesky factorization (delegates to xp.linalg.cholesky).""" + return self.xp.linalg.cholesky(A, **({} if backend_kwargs is None else backend_kwargs)) - @abstractmethod def exp(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute exponentials elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of exponentials. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Elementwise exponential (delegates to xp.exp).""" + return self.xp.exp(x) - @abstractmethod def log(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute natural logarithms elementwise. - - Input: - x: Dense backend array. - - Output: - Dense backend array of logarithms. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Elementwise natural logarithm (delegates to xp.log).""" + return self.xp.log(x) - @abstractmethod def where(self, condition: DenseArray | bool, x: ArrayLike, y: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to select values by condition. - - Input: - condition: Boolean array or scalar; x and y: Values to choose between. - - Output: - Dense backend array containing selected values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Select between x and y by condition (delegates to xp.where).""" + return self.xp.where(condition, x, y) - @abstractmethod def maximum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute elementwise maxima. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing maxima. + """Elementwise maximum (delegates to xp.maximum).""" + return self.xp.maximum(x, y) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod def minimum(self, x: ArrayLike, y: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to compute elementwise minima. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing minima. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Elementwise minimum (delegates to xp.minimum).""" + return self.xp.minimum(x, y) - @abstractmethod def clip(self, x: DenseArray, a_min: ArrayLike, a_max: ArrayLike) -> DenseArray: - """ - Generic backend-agnostic wrapper to clip values into an interval. - - Input: - x: Dense backend array; a_min and a_max: Broadcastable bounds. + """Clip x into [a_min, a_max] (delegates to xp.clip).""" + return self.xp.clip(x, a_min, a_max) - Output: - Dense backend array with clipped values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def isfinite(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to test finiteness elementwise. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Elementwise finite check (delegates to xp.isfinite).""" + return self.xp.isfinite(x) - @abstractmethod def isnan(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to test NaN values elementwise. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def concatenate(self, arrays: Sequence[DenseArray], axis: int = 0, dtype: DType | None = None) -> DenseArray: - """ - Generic backend-agnostic wrapper to join arrays along an existing axis. - - Input: - arrays: Sequence of dense backend arrays; axis and dtype options are backend-specific. - - Output: - Dense backend array containing concatenated inputs. + """Elementwise NaN check (delegates to xp.isnan).""" + return self.xp.isnan(x) - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + def concatenate( + self, + arrays: Sequence[DenseArray], + axis: int = 0, + dtype: DType | None = None, + ) -> DenseArray: + """Concatenate arrays along an existing axis (delegates to xp.concat).""" + if hasattr(self.xp, "concat"): + result = self.xp.concat(tuple(arrays), axis=axis) + else: + result = self.xp.concatenate(tuple(arrays), axis=axis) + return self.astype(result, dtype) if dtype is not None else result - @abstractmethod def take( self, x: DenseArray, indices: DenseArray, axis: int | None = None, ) -> DenseArray: - """ - Generic backend-agnostic wrapper to take values by integer indices. - - Input: - x: Dense backend array; indices: Integer indices; axis and mode options are backend-specific. + """Take entries from x by integer indices (delegates to xp.take).""" + return self.xp.take(x, indices, axis=axis) - Output: - Dense backend array containing selected values. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def diag(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to extract or build a diagonal. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array containing a diagonal view/copy or matrix. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Extract or construct a diagonal (delegates to xp.diag).""" + return self.xp.diag(x) - @abstractmethod def diagonal(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to return selected diagonals. - - Input: - x: Dense backend array plus offset and axis controls. + """Return the main diagonal of x (delegates to xp.diagonal).""" + return self.xp.diagonal(x) - Output: - Dense backend array containing selected diagonals. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod def tril(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to return lower-triangular values. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with upper entries zeroed. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ + """Lower triangle of x (delegates to xp.tril).""" + return self.xp.tril(x) - @abstractmethod def triu(self, x: DenseArray) -> DenseArray: - """ - Generic backend-agnostic wrapper to return upper-triangular values. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with lower entries zeroed. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def index_set( - self, - x: DenseArray, - index: Index, - values: ArrayLike, - *, - copy: bool = True, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to set indexed values. - - Input: - x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. - - Output: - Dense backend array with indexed values set. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def index_add( - self, - x: DenseArray, - index: Index, - values: DenseArray, - *, - copy: bool = True, - ) -> DenseArray: - """ - Generic backend-agnostic wrapper to add into indexed values. - - Input: - x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. - - Output: - Dense backend array with indexed values incremented. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def ix_(self, *args: Any) -> Any: - """ - Generic backend-agnostic wrapper to build open mesh index arrays. - - Input: - args: One-dimensional index arrays or sequences. - - Output: - Tuple of dense backend arrays usable for open-mesh indexing. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def fori_loop( - self, - lower: int, - upper: int, - body_fun: Callable[[int, T], T], - init_val: T, - ) -> T: - """ - Generic backend-agnostic wrapper to run a counted loop primitive. - - Input: - lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def while_loop( - self, - cond_fun: Callable[[T], bool], - body_fun: Callable[[T], T], - init_val: T, - ) -> T: - """ - Generic backend-agnostic wrapper to run a while-loop primitive. - - Input: - cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def scan( - self, - f: Callable[[Carry, X], Tuple[Carry, Y]], - init: Carry, - xs: X, - length: Optional[int] = None, - reverse: bool = False, - unroll: int = 1, - ) -> Tuple[Carry, Y]: - """ - Generic backend-agnostic wrapper to run a scan primitive. - - Input: - f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. - - Output: - Tuple of final carry and stacked outputs. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - - @abstractmethod - def cond( - self, - pred: bool, - true_fun: Callable[[T], R], - false_fun: Callable[[T], R], - *operands: Any, - ) -> R: - """ - Generic backend-agnostic wrapper to run conditional branch selection. - - Input: - pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. - - Output: - Result returned by the selected branch. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... + """Upper triangle of x (delegates to xp.triu).""" + return self.xp.triu(x) - @abstractmethod def allclose( self, a: DenseArray, @@ -1474,41 +666,12 @@ def allclose( atol: float = 1e-8, equal_nan: bool = False, ) -> bool: - """ - Generic backend-agnostic wrapper to compare dense arrays elementwise within tolerances. - - Input: - a, b: Dense backend arrays; rtol, atol, and equal_nan configure comparison. - - Output: - Boolean indicating whether arrays are close. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - @abstractmethod - def allclose_sparse( - self, - a: SparseArray, - b: SparseArray, - rtol: float = 1e-5, - atol: float = 1e-8, - ) -> bool: - """ - Generic backend-agnostic wrapper to compare sparse arrays elementwise within tolerances. - - Input: - a, b: Sparse backend arrays; rtol and atol configure comparison. - - Output: - Boolean indicating whether sparse arrays are close. - - This declaration only specifies the portable SpaceCore interface. - See the concrete backend implementation for backend-specific behavior. - """ - ... - - def __repr__(self): - return f"{type(self).__name__}" + """Return whether dense arrays are close within tolerances.""" + return bool(self.xp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) + + def __repr__(self) -> str: + xp = type(self).xp + xp_state = "" + if isinstance(xp, LazyNamespace): + xp_state = f", xp_loaded={xp.is_loaded!r}" + return f"{type(self).__name__}(family={self.family!r}{xp_state})" diff --git a/spacecore/backend/jax/_ops.py b/spacecore/backend/jax/_ops.py index 6ec1591..04b2eab 100644 --- a/spacecore/backend/jax/_ops.py +++ b/spacecore/backend/jax/_ops.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any, Sequence, Literal, Tuple, Callable, Optional, Type -import inspect from warnings import warn from .._family import BackendFamily @@ -57,20 +56,13 @@ class JaxOps(BackendOps): import jax import jax.numpy as jnp import jax.experimental.sparse as jsparse + xp = jnp _family = BackendFamily.jax.value.lower() _allow_sparse = True def __init__(self) -> None: - self._reshape_supports_copy = "copy" in inspect.signature(self.jnp.reshape).parameters - self._reshape_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.reshape).parameters - self._ravel_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.ravel).parameters - self._zeros_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.zeros).parameters - self._empty_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.empty).parameters - self._zeros_like_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.zeros_like).parameters - self._ones_like_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.ones_like).parameters - self._broadcast_to_supports_out_sharding = "out_sharding" in inspect.signature(self.jnp.broadcast_to).parameters - + super().__init__() def sanitize_dtype(self, dtype: DType | None) -> DType: """ @@ -123,77 +115,6 @@ def sanitize_dtype(self, dtype: DType | None) -> DType: return dt - def get_dtype(self, x: Any) -> DType: - """ - Return an array dtype using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Backend dtype associated with x. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - """ - if self.is_dense(x): - return x.dtype - elif self.is_sparse(x): - return x.dtype - else: - raise TypeError(f'Expected Jax ndarray or BCOO/BCSR, got {type(x)}.') - - def shape(self, x: Any) -> tuple[int, ...]: - """ - Return array shape metadata using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Tuple describing the logical shape of x. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - """ - return tuple(x.shape) - - def ndim(self, x: Any) -> int: - """ - Return array rank metadata using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Number of dimensions in x. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - """ - return int(x.ndim) - - def size(self, x: Any) -> int: - """ - Return logical element count using JAX. - - Input: - x: Dense or sparse backend array. - - Output: - Total number of logical dense elements. - - See: - https://docs.jax.dev/en/latest/jax.Array.html - - Backend-specific notes: - Shape-polymorphic dimensions may not be concrete Python integers inside traced code. - """ - result = 1 - for dim in self.shape(x): - result *= dim - return result - @property def dense_array(self) -> Type[Any]: """ @@ -220,121 +141,6 @@ def sparse_array(self) -> Tuple[Type[Any], ...]: """ return (self.jsparse.BCOO, self.jsparse.BCSR) - @property - def inf(self): - """ - Positive infinity scalar using JAX. - - Returns: - Backend scalar representing positive infinity. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.inf.html - """ - return self.jnp.array(self.jnp.inf) - - @property - def nan(self): - """ - NaN scalar using JAX. - - Input: - None. - - Output: - Backend scalar representing NaN. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.nan.html - """ - return self.jnp.array(self.jnp.nan) - - @property - def pi(self): - """ - Pi scalar using JAX. - - Input: - None. - - Output: - Backend scalar representing pi. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.pi.html - """ - return self.jnp.array(self.jnp.pi) - - @property - def e(self): - """ - Euler number scalar using JAX. - - Input: - None. - - Output: - Backend scalar representing Euler's number. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.e.html - """ - return self.jnp.array(self.jnp.e) - - @property - def eps(self): - """ - Machine epsilon scalar using JAX. - - Input: - None. - - Output: - Backend scalar for float64 machine epsilon. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.finfo.html - """ - return self.jnp.array(self.jnp.finfo(self.jnp.float64).eps) - - def asarray( - self, - a: Any, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] | None = None, - *, - copy: bool | None = None, - device: Any | None = None, - ) -> DenseArray: - """ - Convert input to a dense array using JAX. - - Input: - x/a: Array-like input and optional dtype or backend conversion parameters. - - Output: - Dense backend array. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.asarray.html - """ - return self.jnp.asarray(a, dtype=dtype, order=order, copy=copy, device=device) - - def astype(self, x: DenseArray, dtype: DType, copy: bool = True) -> DenseArray: - """ - Cast an array to a dtype using JAX. - - Input: - x: Dense backend array; dtype: target dtype and optional casting controls. - - Output: - Dense backend array with the requested dtype. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.Array.astype.html - """ - return x.astype(dtype, copy=copy) - def assparse( self, x: Any, @@ -407,1436 +213,192 @@ def assparse( raise ValueError(f"Unknown sparse format: {format!r}") - def empty( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: """ - Create an uninitialized dense array using JAX. + Multiply sparse and dense arrays using JAX. Input: - shape: Output shape; dtype and placement options are backend-specific. + a: Sparse backend array; b: Dense backend array. Output: - Dense backend array with uninitialized values. + Dense backend array containing the product. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.empty.html + https://docs.jax.dev/en/latest/jax.experimental.sparse.html Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + Uses JAX sparse matmul and returns a JAX array; sparse support remains experimental in JAX. """ - if self._empty_supports_out_sharding: - return self.jnp.empty(shape, dtype=dtype, device=device, out_sharding=out_sharding) - return self.jnp.empty(shape, dtype=dtype, device=device) + return a @ b - def zeros( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, + return_sign: bool = False, where: DenseArray | None = None) -> DenseArray | Tuple[DenseArray, DenseArray]: """ - Create a zero-filled dense array using JAX. + Compute a stable log-sum-exp reduction using JAX. Input: - shape: Output shape; dtype and placement options are backend-specific. + a: Dense backend array; axis, weights, and sign options control the reduction. Output: - Dense backend array filled with zeros. + Dense backend array or tuple containing log-sum-exp results. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.zeros.html - - Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + https://docs.jax.dev/en/latest/_autosummary/jax.scipy.special.logsumexp.html """ - if self._zeros_supports_out_sharding: - return self.jnp.zeros(shape, dtype=dtype, device=device, out_sharding=out_sharding) - return self.jnp.zeros(shape, dtype=dtype, device=device) + return self.jax.scipy.special.logsumexp(a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign, where=where) - def ones( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def index_set(self, x: DenseArray, index: Index, values: ArrayLike, *, copy: bool = True): """ - Create a one-filled dense array using JAX. + Set indexed values using JAX. Input: - shape: Output shape; dtype and placement options are backend-specific. + x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. Output: - Dense backend array filled with ones. + Dense backend array with indexed values set. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ones.html + https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html + + Backend-specific notes: + JAX arrays are immutable; copy=False raises NotImplementedError. """ - return self.jnp.ones(shape, dtype=dtype, device=device, out_sharding=out_sharding) + if not copy: + raise NotImplementedError( + "JAX arrays are immutable; copy=False is not supported." + ) + return x.at[index].set(values) - def zeros_like( - self, - x: DenseArray, - dtype: DType | None = None, - shape: Any = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + def ix_(self, *args: Any) -> Any: """ - Create zeros shaped like another array using JAX. + Build open mesh index arrays using JAX. Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + args: One-dimensional index arrays or sequences. Output: - Dense backend array of zeros. + Tuple of dense backend arrays usable for open-mesh indexing. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.zeros_like.html - - Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ix\\_.html """ - kwargs: dict[str, Any] = {"dtype": dtype, "shape": shape, "device": device} - if self._zeros_like_supports_out_sharding: - kwargs["out_sharding"] = out_sharding - return self.jnp.zeros_like(x, **kwargs) + return self.jnp.ix_(*args) - def ones_like( + def fori_loop( self, - x: DenseArray, - dtype: DType | None = None, - shape: Any = None, + lower: int, + upper: int, + body_fun: Callable[[int, T], T], + init_val: T, *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + unroll: int | bool | None = None, + ) -> T: """ - Create ones shaped like another array using JAX. + Run a counted loop primitive using JAX. Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. + lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. Output: - Dense backend array of ones. + Final carry value after loop execution. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ones_like.html + https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html Backend-specific notes: - out_sharding is forwarded only when supported by the installed JAX version. + Loop bounds and unroll behavior follow JAX tracing and compilation rules. """ - kwargs: dict[str, Any] = {"dtype": dtype, "shape": shape, "device": device} - if self._ones_like_supports_out_sharding: - kwargs["out_sharding"] = out_sharding - return self.jnp.ones_like(x, **kwargs) + return self.jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll) - def full_like( + def while_loop( self, - x: DenseArray, - value: Any, - dtype: DType | None = None, - shape: Any = None, - *, - device: Any | None = None, - ) -> DenseArray: + cond_fun: Callable[[T], bool], + body_fun: Callable[[T], T], + init_val: T, + ) -> T: """ - Create filled values shaped like another array using JAX. + Run a while-loop primitive using JAX. Input: - x: Prototype dense array; value/fill_value and dtype options are backend-specific. + cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. Output: - Dense backend array filled with the requested value. + Final carry value after loop execution. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.full_like.html - """ - return self.jnp.full_like(x, value, dtype=dtype, shape=shape, device=device) - - def arange(self, - start: int, - stop: int | None = None, - step: int | None = None, - dtype: DType | None = None, - *, - device: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Create evenly spaced integer-range values using JAX. - - Input: - start, stop, step: Range parameters; dtype and placement options are backend-specific. - - Output: - One-dimensional dense backend array. + https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.arange.html + Backend-specific notes: + Condition and body are staged according to JAX lax control-flow semantics. """ - return self.jnp.arange(start, stop, step, dtype=dtype, device=device, out_sharding=out_sharding) + return self.jax.lax.while_loop(cond_fun, body_fun, init_val) - def full( + def scan( self, - shape: int | Tuple[int, ...], - fill_value: Any, - dtype: DType | None = None, - *, - device: Any | None = None, - ) -> DenseArray: + f: Callable[[Carry, X], Tuple[Carry, Y]], + init: Carry, + xs: X, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + _split_transpose: bool = False, + ) -> Tuple[Carry, Y]: """ - Create a filled dense array using JAX. + Run a scan primitive using JAX. Input: - shape: Output shape; fill_value and dtype options are backend-specific. + f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. Output: - Dense backend array filled with fill_value. + Tuple of final carry and stacked outputs. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.full.html - """ - return self.jnp.full(shape, fill_value, dtype=dtype, device=device) - - def eye( - self, - N: int, - M: int | None = None, - k: int = 0, - dtype: DType | None = None, - *, - device: Any | None = None, - ) -> DenseArray: - """ - Create a dense identity-like matrix using JAX. - - Input: - n and optional m: Matrix dimensions; dtype and placement options are backend-specific. - - Output: - Two-dimensional dense backend array. + https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.eye.html + Backend-specific notes: + Inputs and outputs may be pytrees and are staged according to JAX lax.scan semantics. """ - return self.jnp.eye(N=N, M=M, k=k, dtype=dtype, device=device) + return self.jax.lax.scan(f, init, xs, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose) - def ravel( - self, - a: DenseArray, - order: Literal["C", "F", "A", "K"] = "C", - *, - out_sharding: Any | None = None, - ) -> DenseArray: + def cond( + self, + pred: bool, + true_fun: Callable[[T], R], + false_fun: Callable[[T], R], + *operands: Any, + ) -> R: """ - Flatten an array using JAX. + Run conditional branch selection using JAX. Input: - x: Dense backend array plus optional order parameters. + pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. Output: - One-dimensional dense backend array view or copy. + Result returned by the selected branch. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ravel.html - """ - if order == "C" and out_sharding is None: - return self.jnp.ravel(a) - if self._ravel_supports_out_sharding: - return self.jnp.ravel(a, order=order, out_sharding=out_sharding) - return self.jnp.ravel(a, order=order) + https://docs.jax.dev/en/latest/_autosummary/jax.lax.cond.html - def reshape( - self, - a: DenseArray, - shape: int | Tuple[int, ...], - order: Literal["C", "F", "A"] = "C", - *, - copy: bool | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: + Backend-specific notes: + Branches are staged according to JAX lax.cond semantics rather than Python eager branching. """ - Reshape an array using JAX. - - Input: - x: Dense backend array; shape: New shape plus backend-specific options. - - Output: - Dense backend array with the requested shape. + return self.jax.lax.cond(pred, true_fun, false_fun, *operands) - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.reshape.html - """ - if order == "C" and copy is None and out_sharding is None: - return self.jnp.reshape(a, shape) - kwargs: dict[str, Any] = {"order": order} - if self._reshape_supports_copy: - kwargs["copy"] = copy - if self._reshape_supports_out_sharding: - kwargs["out_sharding"] = out_sharding - return self.jnp.reshape(a, shape, **kwargs) - - def transpose( - self, - x: DenseArray, - axes: Sequence[int] | None = None, - ) -> DenseArray: + def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): """ - Permute array axes using JAX. + Add into indexed values using JAX. Input: - x: Dense backend array; axes: Optional axis order. + x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. Output: - Dense backend array with permuted axes. + Dense backend array with indexed values incremented. See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.transpose.html - """ - return self.jnp.transpose(x, axes=axes) + https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html - def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Interchange two axes using JAX. - - Input: - x: Dense backend array; axis1 and axis2: Axes to swap. - - Output: - Dense backend array with the two axes exchanged. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.swapaxes.html - """ - return self.jnp.swapaxes(x, axis1, axis2) - - def broadcast_to( - self, - x: DenseArray, - shape: int | Tuple[int, ...], - *, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Broadcast an array to a shape using JAX. - - Input: - x: Dense backend array; shape: Target broadcast shape. - - Output: - Dense backend array with broadcast shape. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.broadcast_to.html - """ - if self._broadcast_to_supports_out_sharding: - return self.jnp.broadcast_to(x, shape, out_sharding=out_sharding) - return self.jnp.broadcast_to(x, shape) - - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Insert length-one axes using JAX. - - Input: - x: Dense backend array; axis: Position or positions to insert. - - Output: - Dense backend array with expanded rank. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.expand_dims.html - """ - return self.jnp.expand_dims(x, axis=axis) - - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Remove length-one axes using JAX. - - Input: - x: Dense backend array; axis: Optional axes to squeeze. - - Output: - Dense backend array with selected singleton dimensions removed. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.squeeze.html - """ - return self.jnp.squeeze(x, axis=axis) - - def moveaxis( - self, - x: DenseArray, - source: int | Sequence[int], - destination: int | Sequence[int], - ) -> DenseArray: - """ - Move axes to new positions using JAX. - - Input: - x: Dense backend array; source and destination: Axis positions. - - Output: - Dense backend array with moved axes. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.moveaxis.html - """ - return self.jnp.moveaxis(x, source=source, destination=destination) - - def stack( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - out: Any | None = None, - dtype: DType | None = None, - ) -> DenseArray: - """ - Stack arrays along a new axis using JAX. - - Input: - arrays: Sequence of dense backend arrays; axis: New axis position. - - Output: - Dense backend array containing stacked inputs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.stack.html - """ - return self.jnp.stack(arrays, axis=axis, out=out, dtype=dtype) - - def conj(self, x: DenseArray) -> DenseArray: - """ - Compute complex conjugates using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array with conjugated values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.conj.html - """ - return self.jnp.conj(x) - - def real(self, x: DenseArray) -> DenseArray: - """ - Extract real components using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing real components. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.real.html - """ - return self.jnp.real(x) - - def imag(self, x: DenseArray) -> DenseArray: - """ - Extract imaginary components using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing imaginary components. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.imag.html - """ - return self.jnp.imag(x) - - def abs(self, x: DenseArray) -> DenseArray: - """ - Compute absolute values using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of absolute values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.abs.html - """ - return self.jnp.abs(x) - - def sign(self, x: DenseArray) -> DenseArray: - """ - Compute signs elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of signs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sign.html - """ - return self.jnp.sign(x) - - def sqrt(self, x: DenseArray) -> DenseArray: - """ - Compute square roots elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of square roots. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sqrt.html - """ - return self.jnp.sqrt(x) - - def sum( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: Any | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - promote_integers: bool = True, - ) -> DenseArray: - """ - Reduce by summation using JAX. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing sums. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sum.html - """ - if out is None and initial is None and where is None and promote_integers: - if axis is None and dtype is None and not keepdims: - return self.jnp.sum(a) - return self.jnp.sum(a, axis=axis, dtype=dtype, keepdims=keepdims) - return self.jnp.sum( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - promote_integers=promote_integers, - ) - - def mean( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: None = None, - keepdims: bool = False, - *, - where: DenseArray | None = None, - ) -> DenseArray: - """ - Reduce by arithmetic mean using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing means. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.mean.html - """ - if out is None and where is None: - if axis is None and dtype is None and not keepdims: - return self.jnp.mean(a) - return self.jnp.mean(a, axis=axis, dtype=dtype, keepdims=keepdims) - return self.jnp.mean(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where) - - def min( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - ) -> DenseArray: - """ - Reduce by minimum using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing minima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.min.html - """ - if out is None and initial is None and where is None: - if axis is None and not keepdims: - return self.jnp.min(a) - return self.jnp.min(a, axis=axis, keepdims=keepdims) - return self.jnp.min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def max( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - ) -> DenseArray: - """ - Reduce by maximum using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing maxima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.max.html - """ - if out is None and initial is None and where is None: - if axis is None and not keepdims: - return self.jnp.max(a) - return self.jnp.max(a, axis=axis, keepdims=keepdims) - return self.jnp.max(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def prod( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: Any | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | None = None, - promote_integers: bool = True, - ) -> DenseArray: - """ - Reduce by product using JAX. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing products. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.prod.html - """ - if out is None and initial is None and where is None and promote_integers: - if axis is None and dtype is None and not keepdims: - return self.jnp.prod(a) - return self.jnp.prod(a, axis=axis, dtype=dtype, keepdims=keepdims) - return self.jnp.prod( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - promote_integers=promote_integers, - ) - - def trace( - self, - a: DenseArray, - offset: int | Any = 0, - axis1: int = 0, - axis2: int = 1, - dtype: DType | None = None, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Sum diagonal entries using JAX. - - Input: - x: Dense backend array plus optional diagonal and axis controls. - - Output: - Dense backend array or scalar containing trace values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.trace.html - """ - return self.jnp.trace(a, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out) - - def argsort( - self, - a: DenseArray, - axis: int | None = -1, - kind: None = None, - order: None = None, - stable: bool = True, - descending: bool = False - ) -> DenseArray: - """ - Return sorting indices using JAX. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense integer backend array of indices. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argsort.html - """ - return self.jnp.argsort(a, axis=axis, kind=kind, order=order, stable=stable, descending=descending) - - def sort( - self, - a: DenseArray, - axis: int | None = -1, - kind: None = None, - order: None = None, - stable: bool = True, - descending: bool = False - ) -> DenseArray: - """ - Sort values using JAX. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense backend array with sorted values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.sort.html - """ - return self.jnp.sort(a, axis=axis, kind=kind, order=order, stable=stable, descending=descending) - - def argmin( - self, - a: DenseArray, - axis: int | None = None, - out: Any | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of minimum values using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmin.html - """ - return self.jnp.argmin(a, axis=axis, out=out, keepdims=keepdims) - - def argmax( - self, - a: DenseArray, - axis: int | None = None, - out: Any | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of maximum values using JAX. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.argmax.html - """ - return self.jnp.argmax(a, axis=axis, out=out, keepdims=keepdims) - - def vdot( - self, - a: DenseArray, - b: DenseArray, - *, - precision: Any | None = None, - preferred_element_type: DType | None = None, - ) -> DenseArray: - """ - Compute a conjugating vector dot product using JAX. - - Input: - x, y: Dense backend arrays accepted by the backend vdot operation. - - Output: - Backend scalar or dense array containing the dot product. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.vdot.html - """ - if precision is None and preferred_element_type is None: - return self.jnp.vdot(a, b) - return self.jnp.vdot(a, b, precision=precision, preferred_element_type=preferred_element_type) - - def matmul( - self, - a: DenseArray, - b: DenseArray, - *, - precision: Any | None = None, - preferred_element_type: DType | None = None, - out_sharding: Any | None = None - ) -> DenseArray: - """ - Compute matrix products using JAX. - - Input: - a, b: Dense backend arrays with matrix-multiplication-compatible shapes. - - Output: - Dense backend array containing the product. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.matmul.html - """ - if precision is None and preferred_element_type is None and out_sharding is None: - return self.jnp.matmul(a, b) - return self.jnp.matmul( - a, - b, - precision=precision, - preferred_element_type=preferred_element_type, - out_sharding=out_sharding, - ) - - def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: - """ - Multiply sparse and dense arrays using JAX. - - Input: - a: Sparse backend array; b: Dense backend array. - - Output: - Dense backend array containing the product. - - See: - https://docs.jax.dev/en/latest/jax.experimental.sparse.html - - Backend-specific notes: - Uses JAX sparse matmul and returns a JAX array; sparse support remains experimental in JAX. - """ - return a @ b - - def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Compute a Kronecker product using JAX. - - Input: - a, b: Dense backend arrays. - - Output: - Dense backend array containing the Kronecker product. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.kron.html - """ - return self.jnp.kron(a, b) - - def einsum( - self, - subscripts: str, - /, - *operands: DenseArray, - out: Any | None = None, - optimize: str | bool | list[Tuple[int, ...]] = "auto", - precision: Any | None = None, - preferred_element_type: DType | None = None, - _dot_general: Any | None = None, - out_sharding: Any | None = None, - ) -> DenseArray: - """ - Evaluate an Einstein summation expression using JAX. - - Input: - subscripts: Einstein summation string; operands: Dense backend arrays. - - Output: - Dense backend array containing the contraction result. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.einsum.html - """ - return self.jnp.einsum( - subscripts, - *operands, - out=out, - optimize=optimize, - precision=precision, - preferred_element_type=preferred_element_type, - # _dot_general=_dot_general, - out_sharding=out_sharding, - ) - - def eigh( - self, - x: DenseArray, - UPLO: Literal["L", "U"] = "L", - symmetrize_input: bool = True - ) -> Tuple[DenseArray, DenseArray]: - """ - Compute Hermitian eigenpairs using JAX. - - Input: - x: Dense Hermitian or symmetric backend array. - - Output: - Tuple of dense backend arrays containing eigenvalues and eigenvectors. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.eigh.html - - Backend-specific notes: - SpaceCore rejects sparse input before delegating to JAX dense linear algebra. - """ - if self.is_sparse(x): - raise TypeError("eigh requires a dense array; sparse input is not supported.") - return self.jnp.linalg.eigh(x, UPLO=UPLO, symmetrize_input=symmetrize_input) - - def norm( - self, - x: DenseArray, - ord: int | str | None = None, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: - """ - Compute vector or matrix norms using JAX. - - Input: - x: Dense backend array; ord, axis, and keepdims select the norm. - - Output: - Dense backend array or scalar containing norm values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.norm.html - """ - return self.jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) - - def solve(self, A: DenseArray, b: DenseArray) -> DenseArray: - """ - Solve dense linear systems using JAX. - - Input: - A: Dense coefficient array; b: Dense right-hand side array. - - Output: - Dense backend array solving A @ x = b. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.solve.html - """ - return self.jnp.linalg.solve(A, b) - - def eigvalsh( - self, - A: DenseArray, - UPLO: Literal["L", "U"] = "L", - *, - symmetrize_input: bool = True, - ) -> DenseArray: - """ - Compute Hermitian eigenvalues using JAX. - - Input: - A: Dense Hermitian or symmetric backend array. - - Output: - Dense backend array containing eigenvalues. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.eigvalsh.html - """ - return self.jnp.linalg.eigvalsh(A, UPLO=UPLO, symmetrize_input=symmetrize_input) - - def svd( - self, - A: DenseArray, - full_matrices: bool = True, - compute_uv: bool = True, - hermitian: bool = False, - subset_by_index: tuple[int, int] | None = None, - ) -> DenseArray | Tuple[DenseArray, DenseArray, DenseArray]: - """ - Compute singular value decompositions using JAX. - - Input: - A: Dense backend array plus SVD options. - - Output: - Dense backend arrays containing singular vectors and/or singular values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html - """ - return self.jnp.linalg.svd( - A, - full_matrices=full_matrices, - compute_uv=compute_uv, - hermitian=hermitian, - subset_by_index=subset_by_index, - ) - - def cholesky( - self, - A: DenseArray, - *, - upper: bool = False, - symmetrize_input: bool = True, - ) -> DenseArray: - """ - Compute Cholesky factors using JAX. - - Input: - A: Dense Hermitian positive-definite backend array. - - Output: - Dense backend array containing a triangular factor. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.cholesky.html - """ - return self.jnp.linalg.cholesky(A, upper=upper, symmetrize_input=symmetrize_input) - - def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, - return_sign: bool = False, where: DenseArray | None = None) -> DenseArray | Tuple[DenseArray, DenseArray]: - """ - Compute a stable log-sum-exp reduction using JAX. - - Input: - a: Dense backend array; axis, weights, and sign options control the reduction. - - Output: - Dense backend array or tuple containing log-sum-exp results. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.scipy.special.logsumexp.html - """ - return self.jax.scipy.special.logsumexp(a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign, where=where) - - def exp(self, x: DenseArray) -> DenseArray: - """ - Compute exponentials elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of exponentials. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.exp.html - """ - return self.jnp.exp(x) - - def log(self, x: DenseArray) -> DenseArray: - """ - Compute natural logarithms elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Dense backend array of logarithms. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.log.html - """ - return self.jnp.log(x) - - def maximum(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Compute elementwise maxima using JAX. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing maxima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.maximum.html - """ - return self.jnp.maximum(x, y) - - def minimum(self, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Compute elementwise minima using JAX. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing minima. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.minimum.html - """ - return self.jnp.minimum(x, y) - - def clip(self, x: DenseArray, a_min: DenseArray, a_max: DenseArray) -> DenseArray: - """ - Clip values into an interval using JAX. - - Input: - x: Dense backend array; a_min and a_max: Broadcastable bounds. - - Output: - Dense backend array with clipped values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.clip.html - """ - return self.jnp.clip(x, a_min, a_max) - - def isfinite(self, x: DenseArray) -> DenseArray: - """ - Test finiteness elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isfinite.html - """ - return self.jnp.isfinite(x) - - def isnan(self, x: DenseArray) -> DenseArray: - """ - Test NaN values elementwise using JAX. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.isnan.html - """ - return self.jnp.isnan(x) - - def where(self, condition: DenseArray | bool, x: DenseArray | None = None, y: DenseArray | None = None, *, - size: int | None = None, fill_value: DenseArray | None = None) -> DenseArray: - """ - Select values by condition using JAX. - - Input: - condition: Boolean array or scalar; x and y: Values to choose between. - - Output: - Dense backend array containing selected values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.where.html - """ - return self.jnp.where(condition, x, y, size=size, fill_value=fill_value) - - def concatenate(self, arrays: Sequence[DenseArray], axis: int = 0, dtype: DType | None = None) -> DenseArray: - """ - Join arrays along an existing axis using JAX. - - Input: - arrays: Sequence of dense backend arrays; axis and dtype options are backend-specific. - - Output: - Dense backend array containing concatenated inputs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.concatenate.html - """ - if dtype is None: - return self.jnp.concatenate(arrays, axis=axis) - return self.jnp.concatenate(arrays, axis=axis, dtype=dtype) - - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - out: None = None, - mode: str | None = None, - unique_indices: bool = False, - indices_are_sorted: bool = False, - fill_value: Any | None = None, - ) -> DenseArray: - """ - Take values by integer indices using JAX. - - Input: - x: Dense backend array; indices: Integer indices; axis and mode options are backend-specific. - - Output: - Dense backend array containing selected values. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.take.html - - Backend-specific notes: - Out-of-bounds and mode behavior follow JAX, which can differ from NumPy. - """ - return self.jnp.take( - x, - indices, - axis=axis, - out=out, - mode=mode, - unique_indices=unique_indices, - indices_are_sorted=indices_are_sorted, - fill_value=fill_value, - ) - - def diag(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Extract or build a diagonal using JAX. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array containing a diagonal view/copy or matrix. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.diag.html - """ - return self.jnp.diag(x, k=k) - - def diagonal( - self, - x: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - ) -> DenseArray: - """ - Return selected diagonals using JAX. - - Input: - x: Dense backend array plus offset and axis controls. - - Output: - Dense backend array containing selected diagonals. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.diagonal.html - """ - return self.jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) - - def tril(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return lower-triangular values using JAX. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with upper entries zeroed. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.tril.html - """ - return self.jnp.tril(x, k=k) - - def triu(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return upper-triangular values using JAX. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with lower entries zeroed. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.triu.html - """ - return self.jnp.triu(x, k=k) - - def index_set(self, x: DenseArray, index: Index, values: ArrayLike, *, copy: bool = True): - """ - Set indexed values using JAX. - - Input: - x: Dense backend array; index: Selection; values: Replacement values; copy controls mutation policy. - - Output: - Dense backend array with indexed values set. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html - - Backend-specific notes: - JAX arrays are immutable; copy=False raises NotImplementedError. - """ - if not copy: - raise NotImplementedError( - "JAX arrays are immutable; copy=False is not supported." - ) - return x.at[index].set(values) - - def ix_(self, *args: Any) -> Any: - """ - Build open mesh index arrays using JAX. - - Input: - args: One-dimensional index arrays or sequences. - - Output: - Tuple of dense backend arrays usable for open-mesh indexing. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ix\\_.html - """ - return self.jnp.ix_(*args) - - def fori_loop( - self, - lower: int, - upper: int, - body_fun: Callable[[int, T], T], - init_val: T, - *, - unroll: int | bool | None = None, - ) -> T: - """ - Run a counted loop primitive using JAX. - - Input: - lower, upper: Loop bounds; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.fori_loop.html - - Backend-specific notes: - Loop bounds and unroll behavior follow JAX tracing and compilation rules. - """ - return self.jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll) - - def while_loop( - self, - cond_fun: Callable[[T], bool], - body_fun: Callable[[T], T], - init_val: T, - ) -> T: - """ - Run a while-loop primitive using JAX. - - Input: - cond_fun: Loop condition; body_fun: Loop body; init_val: Initial carry value. - - Output: - Final carry value after loop execution. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html - - Backend-specific notes: - Condition and body are staged according to JAX lax control-flow semantics. - """ - return self.jax.lax.while_loop(cond_fun, body_fun, init_val) - - def scan( - self, - f: Callable[[Carry, X], Tuple[Carry, Y]], - init: Carry, - xs: X, - length: Optional[int] = None, - reverse: bool = False, - unroll: int = 1, - _split_transpose: bool = False, - ) -> Tuple[Carry, Y]: - """ - Run a scan primitive using JAX. - - Input: - f: Scan body; init: Initial carry; xs: Per-step inputs plus scan options. - - Output: - Tuple of final carry and stacked outputs. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.scan.html - - Backend-specific notes: - Inputs and outputs may be pytrees and are staged according to JAX lax.scan semantics. - """ - return self.jax.lax.scan(f, init, xs, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose) - - def cond( - self, - pred: bool, - true_fun: Callable[[T], R], - false_fun: Callable[[T], R], - *operands: Any, - ) -> R: - """ - Run conditional branch selection using JAX. - - Input: - pred: Predicate; true_fun and false_fun: Branch functions; operands: Branch inputs. - - Output: - Result returned by the selected branch. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.lax.cond.html - - Backend-specific notes: - Branches are staged according to JAX lax.cond semantics rather than Python eager branching. - """ - return self.jax.lax.cond(pred, true_fun, false_fun, *operands) - - def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): - """ - Add into indexed values using JAX. - - Input: - x: Dense backend array; index: Selection; values: Values to add; copy controls mutation policy. - - Output: - Dense backend array with indexed values incremented. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.Array.at.html - - Backend-specific notes: - JAX arrays are immutable; copy=False raises NotImplementedError and repeated indices follow JAX scatter-add semantics. + Backend-specific notes: + JAX arrays are immutable; copy=False raises NotImplementedError and repeated indices follow JAX scatter-add semantics. """ if not copy: raise NotImplementedError( @@ -1844,28 +406,6 @@ def index_add(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bo ) return x.at[index].add(values) - def allclose( - self, - a: DenseArray, - b: DenseArray, - rtol: float = 1e-5, - atol: float = 1e-8, - equal_nan: bool = False, - ) -> bool: - """ - Compare dense arrays elementwise within tolerances using JAX. - - Input: - a, b: Dense backend arrays; rtol, atol, and equal_nan configure comparison. - - Output: - Boolean indicating whether arrays are close. - - See: - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.allclose.html - """ - return self.jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - def allclose_sparse( self, a: SparseArray, diff --git a/spacecore/backend/numpy/_ops.py b/spacecore/backend/numpy/_ops.py index 0a309ad..9a919bc 100644 --- a/spacecore/backend/numpy/_ops.py +++ b/spacecore/backend/numpy/_ops.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any, Sequence, Tuple, Literal, Callable, Optional, Type -import inspect from .._family import BackendFamily from .._ops import BackendOps @@ -49,1348 +48,127 @@ class NumpyOps(BackendOps): """ import numpy as np import scipy as sp + import array_api_compat.numpy as xp _family = BackendFamily.numpy.value.lower() _allow_sparse = True - @property - def dense_array(self) -> Type[Any]: - """ - Dense array type using NumPy. - - Returns: - Concrete dense array class accepted by this backend. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html - """ - return self.np.ndarray - - @property - def sparse_array(self) -> Tuple[Type[Any], ...]: - """ - Sparse array type tuple using SciPy. - - Returns: - Concrete sparse array classes accepted by this backend, or None. - - See: - https://docs.scipy.org/doc/scipy/reference/sparse.html - """ - sparse = self.sp.sparse - types: list[type[Any]] = [] - if hasattr(sparse, "spmatrix"): - types.append(sparse.spmatrix) - if hasattr(sparse, "sparray"): - types.append(sparse.sparray) - return tuple(types) - - - def __init__(self) -> None: - self._reshape_supports_copy = "copy" in inspect.signature(self.np.reshape).parameters - - def sanitize_dtype(self, dtype: DType | None) -> DType: - """ - Normalize a dtype specifier using NumPy. - - Input: - dtype: Optional dtype requested by SpaceCore or the caller. - - Output: - Backend dtype object accepted by array constructors. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.dtype.html - """ - - if dtype is None: - return self.np.float64 - return self.np.dtype(dtype) - - def get_dtype(self, x: Any) -> DType: - """ - Return an array dtype using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Backend dtype associated with x. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html - """ - if self.is_dense(x): - return x.dtype - elif self.is_sparse(x): - return x.dtype - else: - raise TypeError(f'Expected Numpy ndarray or SciPy sparse array, got {type(x)}.') - - def shape(self, x: Any) -> tuple[int, ...]: - """ - Return array shape metadata using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Tuple describing the logical shape of x. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html - """ - return tuple(x.shape) - - def ndim(self, x: Any) -> int: - """ - Return array rank metadata using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Number of dimensions in x. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html - """ - return int(x.ndim) - - def size(self, x: Any) -> int: - """ - Return logical element count using NumPy. - - Input: - x: Dense or sparse backend array. - - Output: - Total number of logical dense elements. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.size.html - - Backend-specific notes: - SciPy sparse inputs are reported by logical dense size, not stored entries. - """ - return int(self.np.prod(self.shape(x), dtype=self.np.intp)) - - @property - def inf(self): - """ - Positive infinity scalar using NumPy. - - Returns: - Backend scalar representing positive infinity. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.inf) - - @property - def nan(self): - """ - NaN scalar using NumPy. - - Returns: - Backend scalar representing NaN. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.nan) - - @property - def pi(self): - """ - Pi scalar using NumPy. - - Returns: - Backend scalar representing pi. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.pi) - - @property - def e(self): - """ - Euler number scalar using NumPy. - - Returns: - Backend scalar representing Euler's number. - - See: - https://numpy.org/doc/stable/reference/constants.html - """ - return self.np.array(self.np.e) - - @property - def eps(self): - """ - Machine epsilon scalar using NumPy. - - Returns: - Backend scalar for float64 machine epsilon. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.finfo.html - """ - return self.np.array(self.np.finfo(self.np.float64).eps) - - def asarray( - self, - a: Any, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] | None = None, - *, - device: str | None = None, - copy: bool | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Convert input to a dense array using NumPy. - - Input: - x/a: Array-like input and optional dtype or backend conversion parameters. - - Output: - Dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.asarray.html - """ - return self.np.asarray( - a, - dtype=dtype, - order=order, - device=device, - copy=copy, - like=like, - ) - - def astype( - self, - x: DenseArray, - dtype: DType, - order: Literal["C", "F", "A", "K"] = "K", - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "unsafe", - subok: bool = True, - copy: bool = True, - ) -> DenseArray: - """ - Cast an array to a dtype using NumPy. - - Input: - x: Dense backend array; dtype: target dtype and optional casting controls. - - Output: - Dense backend array with the requested dtype. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ndarray.astype.html - """ - return x.astype(dtype, order=order, casting=casting, subok=subok, copy=copy) - - def assparse(self, x: Any, *, format: Literal["csr", "csc", "coo"] = "csr", dtype: DType | None = None) -> SparseArray: - """ - Convert input to a sparse array using SciPy. - - Input: - x: Dense, sparse, or array-like input plus sparse-format options. - - Output: - Sparse backend array. - - See: - https://docs.scipy.org/doc/scipy/reference/sparse.html - - Backend-specific notes: - SpaceCore currently converts dense inputs to 2-D SciPy sparse matrices in the requested format. - """ - sparse = self.sp.sparse - - if self.is_sparse(x): - if format == "csr": - return x.tocsr() - if format == "csc": - return x.tocsc() - if format == "coo": - return x.tocoo() - raise ValueError(f"Unknown sparse format: {format!r}") - - x_arr = self.asarray(x) - - if x_arr.ndim != 2: - raise ValueError("NumPy/SciPy sparse conversion currently expects a 2D array.") - - if format == "csr": - return sparse.csr_matrix(x_arr) - if format == "csc": - return sparse.csc_matrix(x_arr) - if format == "coo": - return sparse.coo_matrix(x_arr) - - raise ValueError(f"Unknown sparse format: {format!r}") - - def empty( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = float, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create an uninitialized dense array using NumPy. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array with uninitialized values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.empty.html - """ - - return self.np.empty( - shape, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def zeros( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a zero-filled dense array using NumPy. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array filled with zeros. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.zeros.html - """ - return self.np.zeros( - shape, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def ones( - self, - shape: int | Tuple[int, ...], - dtype: DType | None = None, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a one-filled dense array using NumPy. - - Input: - shape: Output shape; dtype and placement options are backend-specific. - - Output: - Dense backend array filled with ones. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ones.html - """ - return self.np.ones( - shape, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def zeros_like( - self, - x: DenseArray, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - subok: bool = True, - shape: int | Tuple[int, ...] | None = None, - ) -> DenseArray: - """ - Create zeros shaped like another array using NumPy. - - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. - - Output: - Dense backend array of zeros. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.zeros_like.html - """ - return self.np.zeros_like(x, dtype=dtype, order=order, subok=subok, shape=shape) - - def ones_like( - self, - x: DenseArray, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - subok: bool = True, - shape: int | Tuple[int, ...] | None = None, - ) -> DenseArray: - """ - Create ones shaped like another array using NumPy. - - Input: - x: Prototype dense array; dtype, shape, and placement options are backend-specific. - - Output: - Dense backend array of ones. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ones_like.html - """ - return self.np.ones_like(x, dtype=dtype, order=order, subok=subok, shape=shape) - - def full_like( - self, - x: DenseArray, - value: Any, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - subok: bool = True, - shape: int | Tuple[int, ...] | None = None, - ) -> DenseArray: - """ - Create filled values shaped like another array using NumPy. - - Input: - x: Prototype dense array; value/fill_value and dtype options are backend-specific. - - Output: - Dense backend array filled with the requested value. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.full_like.html - """ - return self.np.full_like(x, value, dtype=dtype, order=order, subok=subok, shape=shape) - - def arange(self, - start: int, stop: int | None = None, - step: int | None = None, - dtype: DType | None = None, - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create evenly spaced integer-range values using NumPy. - - Input: - start, stop, step: Range parameters; dtype and placement options are backend-specific. - - Output: - One-dimensional dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.arange.html - """ - return self.np.arange( - start, - stop, - step, - dtype=dtype, - device=device, - like=like, - ) - - def full( - self, - shape: int | Tuple[int, ...], - fill_value: Any, - dtype: DType | None = None, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a filled dense array using NumPy. - - Input: - shape: Output shape; fill_value and dtype options are backend-specific. - - Output: - Dense backend array filled with fill_value. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.full.html - """ - return self.np.full( - shape, - fill_value, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def eye( - self, - N: int, - M: int | None = None, - k: int = 0, - dtype: DType | None = float, - order: Literal["C", "F"] = "C", - *, - device: str | None = None, - like: DenseArray | None = None, - ) -> DenseArray: - """ - Create a dense identity-like matrix using NumPy. - - Input: - n and optional m: Matrix dimensions; dtype and placement options are backend-specific. - - Output: - Two-dimensional dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.eye.html - """ - return self.np.eye( - N, - M=M, - k=k, - dtype=dtype, - order=order, - device=device, - like=like, - ) - - def ravel(self, a: DenseArray, order: Literal["C", "F", "A", "K"] = "C") -> DenseArray: - """ - Flatten an array using NumPy. - - Input: - x: Dense backend array plus optional order parameters. - - Output: - One-dimensional dense backend array view or copy. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.ravel.html - """ - return self.np.ravel(a, order=order) - - def reshape( - self, - a: DenseArray, - shape: int | Tuple[int, ...], - order: Literal["C", "F", "A", "K"] = "C", - copy: bool | None = None, - ) -> DenseArray: - """ - Reshape an array using NumPy. - - Input: - x: Dense backend array; shape: New shape plus backend-specific options. - - Output: - Dense backend array with the requested shape. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.reshape.html - """ - if self._reshape_supports_copy: - return self.np.reshape(a, shape, order=order, copy=copy) - if copy: - return self.np.array(a, copy=True).reshape(shape, order=order) - return self.np.reshape(a, shape, order=order) - - def transpose(self, a: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: - """ - Permute array axes using NumPy. - - Input: - x: Dense backend array; axes: Optional axis order. - - Output: - Dense backend array with permuted axes. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.transpose.html - """ - return self.np.transpose(a, axes=axes) - - def swapaxes(self, a: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Interchange two axes using NumPy. - - Input: - x: Dense backend array; axis1 and axis2: Axes to swap. - - Output: - Dense backend array with the two axes exchanged. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.swapaxes.html - """ - return self.np.swapaxes(a, axis1, axis2) - - def broadcast_to( - self, - x: DenseArray, - shape: int | Tuple[int, ...], - subok: bool = False, - ) -> DenseArray: - """ - Broadcast an array to a shape using NumPy. - - Input: - x: Dense backend array; shape: Target broadcast shape. - - Output: - Dense backend array with broadcast shape. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html - """ - return self.np.broadcast_to(x, shape, subok=subok) - - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Insert length-one axes using NumPy. - - Input: - x: Dense backend array; axis: Position or positions to insert. - - Output: - Dense backend array with expanded rank. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html - """ - return self.np.expand_dims(x, axis=axis) - - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Remove length-one axes using NumPy. - - Input: - x: Dense backend array; axis: Optional axes to squeeze. - - Output: - Dense backend array with selected singleton dimensions removed. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html - """ - return self.np.squeeze(x, axis=axis) - - def moveaxis( - self, - x: DenseArray, - source: int | Sequence[int], - destination: int | Sequence[int], - ) -> DenseArray: - """ - Move axes to new positions using NumPy. - - Input: - x: Dense backend array; source and destination: Axis positions. - - Output: - Dense backend array with moved axes. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.moveaxis.html - """ - return self.np.moveaxis(x, source=source, destination=destination) - - def stack( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - out: DenseArray | None = None, - *, - dtype: DType | None = None, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - ) -> DenseArray: - """ - Stack arrays along a new axis using NumPy. - - Input: - arrays: Sequence of dense backend arrays; axis: New axis position. - - Output: - Dense backend array containing stacked inputs. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.stack.html - """ - return self.np.stack(arrays, axis=axis, out=out, dtype=dtype, casting=casting) - - def conj( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute complex conjugates using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array with conjugated values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.conj.html - """ - return self.np.conj( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def abs( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute absolute values using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of absolute values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.abs.html - """ - return self.np.abs( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def sign( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute signs elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of signs. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sign.html - """ - return self.np.sign( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def sqrt( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute square roots elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of square roots. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html - """ - return self.np.sqrt( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def real(self, x: DenseArray) -> DenseArray: - """ - Extract real components using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing real components. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.real.html - """ - return self.np.real(x) - - def imag(self, x: DenseArray) -> DenseArray: - """ - Extract imaginary components using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array containing imaginary components. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.imag.html - """ - return self.np.imag(x) - - def sum( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by summation using NumPy. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing sums. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sum.html - """ - return self.np.sum( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - ) - - def mean( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by arithmetic mean using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing means. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.mean.html - """ - return self.np.mean(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where) - - def min( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by minimum using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing minima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.min.html - """ - return self.np.min(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def max( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by maximum using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense backend array or scalar containing maxima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.max.html - """ - return self.np.max(a, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) - - def prod( - self, - a: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - out: DenseArray | None = None, - keepdims: bool = False, - initial: DenseArray | None = None, - where: DenseArray | bool = True, - ) -> DenseArray: - """ - Reduce by product using NumPy. - - Input: - x: Dense backend array; axis, keepdims, and dtype control the reduction. - - Output: - Dense backend array or scalar containing products. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.prod.html - """ - return self.np.prod( - a, - axis=axis, - dtype=dtype, - out=out, - keepdims=keepdims, - initial=initial, - where=where, - ) - - def trace( - self, - a: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - dtype: DType | None = None, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Sum diagonal entries using NumPy. - - Input: - x: Dense backend array plus optional diagonal and axis controls. - - Output: - Dense backend array or scalar containing trace values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.trace.html - """ - return self.np.trace( - a, - offset=offset, - axis1=axis1, - axis2=axis2, - dtype=dtype, - out=out, - ) - - def argsort( - self, - a: DenseArray, - axis: int = -1, - kind: Literal["quicksort", "mergesort", "heapsort", "stable"] | None = None, - order: str | Sequence[str] | None = None, - *, - stable: bool | None = None, - ) -> DenseArray: - """ - Return sorting indices using NumPy. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense integer backend array of indices. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.argsort.html - """ - return self.np.argsort(a, axis=axis, kind=kind, order=order, stable=stable) - - def sort( - self, - a: DenseArray, - axis: int = -1, - kind: Literal["quicksort", "mergesort", "heapsort", "stable"] | None = None, - order: str | Sequence[str] | None = None, - *, - stable: bool | None = None, - ) -> DenseArray: - """ - Sort values using NumPy. - - Input: - x: Dense backend array; axis and ordering options are backend-specific. - - Output: - Dense backend array with sorted values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.sort.html - """ - return self.np.sort(a, axis=axis, kind=kind, order=order, stable=stable) - - def argmin( - self, - a: DenseArray, - axis: int | None = None, - out: DenseArray | None = None, - *, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of minimum values using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.argmin.html - """ - return self.np.argmin(a, axis=axis, out=out, keepdims=keepdims) - - def argmax( - self, - a: DenseArray, - axis: int | None = None, - out: DenseArray | None = None, - *, - keepdims: bool = False, - ) -> DenseArray: - """ - Return indices of maximum values using NumPy. - - Input: - x: Dense backend array; axis and keepdims control the reduction. - - Output: - Dense integer backend array or scalar of indices. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.argmax.html - """ - return self.np.argmax(a, axis=axis, out=out, keepdims=keepdims) - - def vdot(self, a: DenseArray, b: DenseArray) -> DenseArray: - """ - Compute a conjugating vector dot product using NumPy. - - Input: - x, y: Dense backend arrays accepted by the backend vdot operation. - - Output: - Backend scalar or dense array containing the dot product. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.vdot.html - """ - return self.np.vdot(a, b) - - def matmul( - self, - a: DenseArray, - b: DenseArray, - /, - out: DenseArray | None = None, - *, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute matrix products using NumPy. - - Input: - a, b: Dense backend arrays with matrix-multiplication-compatible shapes. - - Output: - Dense backend array containing the product. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.matmul.html - """ - return self.np.matmul( - a, - b, - out=out, - casting=casting, - order=order, - dtype=dtype, - subok=subok - ) - - def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: - """ - Multiply sparse and dense arrays using SciPy. - - Input: - a: Sparse backend array; b: Dense backend array. - - Output: - Dense backend array containing the product. - - See: - https://docs.scipy.org/doc/scipy/reference/sparse.html - - Backend-specific notes: - Uses SciPy sparse multiplication before returning a dense NumPy result when applicable. - """ - if not self.is_sparse(a): - raise TypeError("sparse_matmul expects `a` to be a SciPy sparse matrix/array.") - if not self.is_dense(b): - raise TypeError("sparse_matmul expects `b` to be a Numpy dense object.") - return a @ b + def __init__(self) -> None: + super().__init__() - def kron(self, a: DenseArray, b: DenseArray) -> DenseArray: + @property + def dense_array(self) -> Type[Any]: """ - Compute a Kronecker product using NumPy. - - Input: - a, b: Dense backend arrays. + Dense array type using NumPy. - Output: - Dense backend array containing the Kronecker product. + Returns: + Concrete dense array class accepted by this backend. See: - https://numpy.org/doc/stable/reference/generated/numpy.kron.html + https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html """ - return self.np.kron(a, b) + return self.np.ndarray - def einsum( - self, - subscripts: str, - *operands: DenseArray, - out: DenseArray | None = None, - dtype: DType | None = None, - order: Literal["C", "F", "A", "K"] = "K", - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "safe", - optimize: bool | str | Sequence[Any] = False, - ) -> DenseArray: + @property + def sparse_array(self) -> Tuple[Type[Any], ...]: """ - Evaluate an Einstein summation expression using NumPy. - - Input: - subscripts: Einstein summation string; operands: Dense backend arrays. + Sparse array type tuple using SciPy. - Output: - Dense backend array containing the contraction result. + Returns: + Concrete sparse array classes accepted by this backend, or None. See: - https://numpy.org/doc/stable/reference/generated/numpy.einsum.html - """ - return self.np.einsum( - subscripts, - *operands, - out=out, - dtype=dtype, - order=order, - casting=casting, - optimize=optimize, - ) - - def eigh(self, a: DenseArray, UPLO: Literal["L", "U"] = "L") -> Tuple[DenseArray, DenseArray]: + https://docs.scipy.org/doc/scipy/reference/sparse.html """ - Compute Hermitian eigenpairs using NumPy. - - Input: - x: Dense Hermitian or symmetric backend array. - - Output: - Tuple of dense backend arrays containing eigenvalues and eigenvectors. + sparse = self.sp.sparse + types: list[type[Any]] = [] + if hasattr(sparse, "spmatrix"): + types.append(sparse.spmatrix) + if hasattr(sparse, "sparray"): + types.append(sparse.sparray) + return tuple(types) - See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.eigh.html - """ - if self.is_sparse(a): - raise TypeError("eigh requires a dense array; sparse input is not supported.") - return self.np.linalg.eigh(a, UPLO=UPLO) - def norm( - self, - x: DenseArray, - ord: int | str | None = None, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - ) -> DenseArray: + def sanitize_dtype(self, dtype: DType | None) -> DType: """ - Compute vector or matrix norms using NumPy. + Normalize a dtype specifier using NumPy. Input: - x: Dense backend array; ord, axis, and keepdims select the norm. + dtype: Optional dtype requested by SpaceCore or the caller. Output: - Dense backend array or scalar containing norm values. + Backend dtype object accepted by array constructors. See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html + https://numpy.org/doc/stable/reference/generated/numpy.dtype.html """ - return self.np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) - def solve(self, A: DenseArray, b: DenseArray) -> DenseArray: + if dtype is None: + return self.np.float64 + return self.np.dtype(dtype) + + def assparse(self, x: Any, *, format: Literal["csr", "csc", "coo"] = "csr", dtype: DType | None = None) -> SparseArray: """ - Solve dense linear systems using NumPy. + Convert input to a sparse array using SciPy. Input: - A: Dense coefficient array; b: Dense right-hand side array. + x: Dense, sparse, or array-like input plus sparse-format options. Output: - Dense backend array solving A @ x = b. + Sparse backend array. See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html - """ - return self.np.linalg.solve(A, b) + https://docs.scipy.org/doc/scipy/reference/sparse.html - def eigvalsh(self, A: DenseArray, UPLO: Literal["L", "U"] = "L") -> DenseArray: + Backend-specific notes: + SpaceCore currently converts dense inputs to 2-D SciPy sparse matrices in the requested format. """ - Compute Hermitian eigenvalues using NumPy. - - Input: - A: Dense Hermitian or symmetric backend array. + sparse = self.sp.sparse - Output: - Dense backend array containing eigenvalues. + if self.is_sparse(x): + if format == "csr": + return x.tocsr() + if format == "csc": + return x.tocsc() + if format == "coo": + return x.tocoo() + raise ValueError(f"Unknown sparse format: {format!r}") - See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.eigvalsh.html - """ - return self.np.linalg.eigvalsh(A, UPLO=UPLO) + x_arr = self.asarray(x) - def svd( - self, - A: DenseArray, - full_matrices: bool = True, - compute_uv: bool = True, - hermitian: bool = False, - ) -> DenseArray | Tuple[DenseArray, DenseArray, DenseArray]: - """ - Compute singular value decompositions using NumPy. + if x_arr.ndim != 2: + raise ValueError("NumPy/SciPy sparse conversion currently expects a 2D array.") - Input: - A: Dense backend array plus SVD options. + if format == "csr": + return sparse.csr_matrix(x_arr) + if format == "csc": + return sparse.csc_matrix(x_arr) + if format == "coo": + return sparse.coo_matrix(x_arr) - Output: - Dense backend arrays containing singular vectors and/or singular values. + raise ValueError(f"Unknown sparse format: {format!r}") - See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html - """ - return self.np.linalg.svd( - A, - full_matrices=full_matrices, - compute_uv=compute_uv, - hermitian=hermitian, - ) - - def cholesky(self, A: DenseArray) -> DenseArray: + def sparse_matmul(self, a: SparseArray, b: DenseArray) -> DenseArray: """ - Compute Cholesky factors using NumPy. + Multiply sparse and dense arrays using SciPy. Input: - A: Dense Hermitian positive-definite backend array. + a: Sparse backend array; b: Dense backend array. Output: - Dense backend array containing a triangular factor. + Dense backend array containing the product. See: - https://numpy.org/doc/stable/reference/generated/numpy.linalg.cholesky.html + https://docs.scipy.org/doc/scipy/reference/sparse.html + + Backend-specific notes: + Uses SciPy sparse multiplication before returning a dense NumPy result when applicable. """ - return self.np.linalg.cholesky(A) + if not self.is_sparse(a): + raise TypeError("sparse_matmul expects `a` to be a SciPy sparse matrix/array.") + if not self.is_dense(b): + raise TypeError("sparse_matmul expects `b` to be a Numpy dense object.") + return a @ b def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: DenseArray | None = None, keepdims: bool = False, return_sign: bool = False) -> DenseArray | Tuple[DenseArray, DenseArray]: @@ -1408,362 +186,6 @@ def logsumexp(self, a: DenseArray, axis: int | Sequence[int] | None = None, b: D """ return self.sp.special.logsumexp(a, axis=axis, b=b, keepdims=keepdims, return_sign=return_sign) - def exp( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute exponentials elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of exponentials. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.exp.html - """ - return self.np.exp( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def log( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute natural logarithms elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Dense backend array of logarithms. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.log.html - """ - return self.np.log( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def maximum( - self, - x: DenseArray, - y: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute elementwise maxima using NumPy. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing maxima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.maximum.html - """ - return self.np.maximum( - x, - y, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def minimum( - self, - x: DenseArray, - y: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Compute elementwise minima using NumPy. - - Input: - x, y: Arrays or scalars accepted by backend broadcasting. - - Output: - Dense backend array containing minima. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.minimum.html - """ - return self.np.minimum( - x, - y, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def clip( - self, - x: DenseArray, - a_min: DenseArray, - a_max: DenseArray, - out: DenseArray | None = None, - **kwargs: Any, - ) -> DenseArray: - """ - Clip values into an interval using NumPy. - - Input: - x: Dense backend array; a_min and a_max: Broadcastable bounds. - - Output: - Dense backend array with clipped values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.clip.html - """ - return self.np.clip(x, a_min, a_max, out=out, **kwargs) - - def isfinite( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Test finiteness elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html - """ - return self.np.isfinite( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def isnan( - self, - x: DenseArray, - /, - out: DenseArray | None = None, - *, - where: DenseArray | bool = True, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - order: Literal["C", "F", "A", "K"] = "K", - dtype: DType | None = None, - subok: bool = True, - ) -> DenseArray: - """ - Test NaN values elementwise using NumPy. - - Input: - x: Dense backend array. - - Output: - Boolean dense backend array. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.isnan.html - """ - return self.np.isnan( - x, - out=out, - where=where, - casting=casting, - order=order, - dtype=dtype, - subok=subok, - ) - - def where(self, condition: DenseArray | bool, x: DenseArray, y: DenseArray) -> DenseArray: - """ - Select values by condition using NumPy. - - Input: - condition: Boolean array or scalar; x and y: Values to choose between. - - Output: - Dense backend array containing selected values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.where.html - """ - return self.np.where(condition, x, y) - - def concatenate( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - out: DenseArray | None = None, - *, - dtype: DType | None = None, - casting: Literal["no", "equiv", "safe", "same_kind", "unsafe"] = "same_kind", - ) -> DenseArray: - """ - Join arrays along an existing axis using NumPy. - - Input: - arrays: Sequence of dense backend arrays; axis and dtype options are backend-specific. - - Output: - Dense backend array containing concatenated inputs. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html - """ - return self.np.concatenate(arrays, axis=axis, out=out, dtype=dtype, casting=casting) - - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - out: DenseArray | None = None, - mode: Literal["raise", "wrap", "clip"] = "raise", - ) -> DenseArray: - """ - Take values by integer indices using NumPy. - - Input: - x: Dense backend array; indices: Integer indices; axis and mode options are backend-specific. - - Output: - Dense backend array containing selected values. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.take.html - """ - return self.np.take(x, indices, axis=axis, out=out, mode=mode) - - def diag(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Extract or build a diagonal using NumPy. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array containing a diagonal view/copy or matrix. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.diag.html - """ - return self.np.diag(x, k=k) - - def diagonal( - self, - x: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - ) -> DenseArray: - """ - Return selected diagonals using NumPy. - - Input: - x: Dense backend array plus offset and axis controls. - - Output: - Dense backend array containing selected diagonals. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.diagonal.html - """ - return self.np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) - - def tril(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return lower-triangular values using NumPy. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with upper entries zeroed. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.tril.html - """ - return self.np.tril(x, k=k) - - def triu(self, x: DenseArray, k: int = 0) -> DenseArray: - """ - Return upper-triangular values using NumPy. - - Input: - x: Dense backend array and optional diagonal offset. - - Output: - Dense backend array with lower entries zeroed. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.triu.html - """ - return self.np.triu(x, k=k) - def index_set(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): """ Set indexed values using NumPy. @@ -2035,28 +457,6 @@ def index_add( self.np.add.at(y, index, values) return y - def allclose( - self, - a: DenseArray, - b: DenseArray, - rtol: float = 1e-5, - atol: float = 1e-8, - equal_nan: bool = False, - ) -> bool: - """ - Compare dense arrays elementwise within tolerances using NumPy. - - Input: - a, b: Dense backend arrays; rtol, atol, and equal_nan configure comparison. - - Output: - Boolean indicating whether arrays are close. - - See: - https://numpy.org/doc/stable/reference/generated/numpy.allclose.html - """ - return self.np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - def allclose_sparse( self, a: SparseArray, diff --git a/spacecore/backend/torch/_ops.py b/spacecore/backend/torch/_ops.py index cd6defb..f1413d7 100644 --- a/spacecore/backend/torch/_ops.py +++ b/spacecore/backend/torch/_ops.py @@ -5,8 +5,8 @@ import numpy as np from .._family import BackendFamily -from .._ops import BackendOps -from ...types import ArrayLike, DenseArray, DType, Index, SparseArray, T, X, Y, R, Carry +from .._ops import BackendOps, LazyNamespace +from ...types import DenseArray, DType, Index, SparseArray, T, X, Y, R, Carry class TorchOps(BackendOps): @@ -52,6 +52,7 @@ class TorchOps(BackendOps): """ import torch + xp = LazyNamespace("array_api_compat.torch") _family = BackendFamily.torch.value.lower() _allow_sparse = True @@ -64,6 +65,9 @@ class TorchOps(BackendOps): torch.sparse_bsc, ) + def __init__(self) -> None: + super().__init__() + @staticmethod def _defined_kwargs(**kwargs: Any) -> dict[str, Any]: return {key: value for key, value in kwargs.items() if value is not None} @@ -180,193 +184,6 @@ def sanitize_dtype(self, dtype: DType | None) -> DType: return mapping[np_dtype] raise TypeError(f"Dtype {np_dtype!r} is not supported by PyTorch.") - def get_dtype(self, x: Any) -> DType: - """ - Return a tensor dtype using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Backend dtype associated with x. - - See: - https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch-dtype - """ - if self.is_array(x): - return x.dtype - raise TypeError(f"Expected PyTorch tensor, got {type(x)}.") - - def shape(self, x: Any) -> tuple[int, ...]: - """ - Return tensor shape metadata using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Tuple describing the logical shape of x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.Tensor.shape.html - """ - return tuple(x.shape) - - def ndim(self, x: Any) -> int: - """ - Return tensor rank metadata using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Number of dimensions in x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.Tensor.ndim.html - """ - return int(x.ndim) - - def size(self, x: Any) -> int: - """ - Return logical element count using PyTorch. - - Input: - x: Dense or sparse backend tensor. - - Output: - Total number of logical dense elements. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.numel.html - """ - return int(x.numel()) - - @property - def inf(self): - """ - Positive infinity scalar using PyTorch. - - Returns: - Backend tensor scalar representing positive infinity. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(float("inf")) - - @property - def nan(self): - """ - NaN scalar using PyTorch. - - Returns: - Backend tensor scalar representing NaN. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(float("nan")) - - @property - def pi(self): - """ - Pi scalar using PyTorch. - - Returns: - Backend tensor scalar representing pi. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(np.pi) - - @property - def e(self): - """ - Euler number scalar using PyTorch. - - Returns: - Backend tensor scalar representing Euler's number. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tensor.html - """ - return self.torch.tensor(np.e) - - @property - def eps(self): - """ - Machine epsilon scalar using PyTorch. - - Returns: - Backend tensor scalar for float64 machine epsilon. - - See: - https://docs.pytorch.org/docs/stable/type_info.html#torch.finfo - """ - return self.torch.tensor(self.torch.finfo(self.torch.float64).eps) - - def asarray( - self, - x: Any, - dtype: DType | None = None, - *, - device: Any | None = None, - copy: bool | None = None, - ) -> DenseArray: - """ - Convert input to a dense tensor using PyTorch. - - Input: - x/a: Array-like input and optional dtype, device, or copy controls. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.as_tensor.html - - Backend-specific notes: - Sparse tensors are densified. Existing tensors keep autograd - metadata according to normal PyTorch conversion rules. - """ - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - if self.is_sparse(x): - x = x.to_dense() - out = self.torch.as_tensor(x, dtype=dtype, device=device) - if copy: - out = out.clone() - return out - - def astype( - self, - x: DenseArray, - dtype: DType, - copy: bool = True, - *, - non_blocking: bool = False, - memory_format: Any | None = None, - ) -> DenseArray: - """ - Cast a tensor to a dtype using PyTorch. - - Input: - x: Dense backend tensor; dtype: Target dtype; copy: Whether to force a copy. - - Output: - Tensor with the requested dtype. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html - """ - return x.to( - dtype=self.sanitize_dtype(dtype), - non_blocking=non_blocking, - copy=copy, - **self._defined_kwargs(memory_format=memory_format), - ) - def assparse( self, x: Any, @@ -427,72 +244,69 @@ def assparse( return out.to_sparse_csc() raise ValueError(f"Unknown sparse format: {format!r}") - def empty( + def asarray( self, - shape: int | Tuple[int, ...], + x: Any, dtype: DType | None = None, *, - out: DenseArray | None = None, - layout: Any | None = None, device: Any | None = None, - requires_grad: bool = False, - pin_memory: bool = False, - memory_format: Any | None = None, + copy: bool | None = None, + backend_kwargs: dict[str, Any] | None = None, ) -> DenseArray: - """ - Create an uninitialized dense tensor using PyTorch. - - Input: - shape: Output shape; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor with uninitialized values. + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + if device is not None: + kwargs["device"] = device + dtype = self.sanitize_dtype(dtype) if dtype is not None else None + if self.is_sparse(x): + x = x.to_dense() + out = self.torch.as_tensor(x, dtype=dtype, **kwargs) + return out.clone() if copy else out - See: - https://docs.pytorch.org/docs/stable/generated/torch.empty.html - """ - return self.torch.empty( - shape, - out=out, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - pin_memory=pin_memory, - **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), + def astype( + self, + x: DenseArray, + dtype: DType | None, + *, + copy: bool = True, + non_blocking: bool = False, + memory_format: Any | None = None, + backend_kwargs: dict[str, Any] | None = None, + ) -> DenseArray: + if dtype is None: + return x + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(memory_format=memory_format)) + return x.to( + dtype=self.sanitize_dtype(dtype), + non_blocking=non_blocking, + copy=copy, + **kwargs, ) - def zeros( + def empty( self, - shape: int | Tuple[int, ...], + shape: Tuple[int, ...], dtype: DType | None = None, *, out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, + pin_memory: bool = False, + memory_format: Any | None = None, ) -> DenseArray: - """ - Create a dense tensor filled with zeros using PyTorch. - - Input: - shape: Output shape; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.zeros.html - """ - return self.torch.zeros( + return self.torch.empty( shape, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device), + pin_memory=pin_memory, + **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), ) - def ones( + def zeros( self, - shape: int | Tuple[int, ...], + shape: Tuple[int, ...], dtype: DType | None = None, *, out: DenseArray | None = None, @@ -500,19 +314,7 @@ def ones( device: Any | None = None, requires_grad: bool = False, ) -> DenseArray: - """ - Create a dense tensor filled with ones using PyTorch. - - Input: - shape: Output shape; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.ones.html - """ - return self.torch.ones( + return self.torch.zeros( shape, out=out, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, @@ -530,18 +332,6 @@ def zeros_like( requires_grad: bool = False, memory_format: Any | None = None, ) -> DenseArray: - """ - Create a zero tensor matching another tensor using PyTorch. - - Input: - x: Reference tensor; dtype and device: Optional overrides. - - Output: - Dense backend tensor with shape matching x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.zeros_like.html - """ return self.torch.zeros_like( x, dtype=self.sanitize_dtype(dtype) if dtype is not None else None, @@ -549,867 +339,93 @@ def zeros_like( **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), ) - def ones_like( + def arange( self, - x: DenseArray, + start: int, + stop: int | None = None, + step: int | None = None, dtype: DType | None = None, *, + out: DenseArray | None = None, layout: Any | None = None, device: Any | None = None, requires_grad: bool = False, - memory_format: Any | None = None, ) -> DenseArray: - """ - Create a one tensor matching another tensor using PyTorch. - - Input: - x: Reference tensor; dtype and device: Optional overrides. - - Output: - Dense backend tensor with shape matching x. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.ones_like.html - """ - return self.torch.ones_like( - x, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), - ) + kwargs = self._defined_kwargs(out=out, layout=layout, device=device) + kwargs["requires_grad"] = requires_grad + dtype = self.sanitize_dtype(dtype) if dtype is not None else None + if stop is None: + return self.torch.arange(start, dtype=dtype, **kwargs) + if step is None: + return self.torch.arange(start, stop, dtype=dtype, **kwargs) + return self.torch.arange(start, stop, step, dtype=dtype, **kwargs) - def full_like( + def sum( self, x: DenseArray, - value: Any, + axis: int | Sequence[int] | None = None, + keepdims: bool = False, dtype: DType | None = None, *, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - memory_format: Any | None = None, + out: DenseArray | None = None, + ) -> DenseArray: + kwargs = {"dim": self._to_axis_tuple(axis), "keepdim": keepdims} + if dtype is not None: + kwargs["dtype"] = self.sanitize_dtype(dtype) + if out is not None: + kwargs["out"] = out + return self.torch.sum(x, **kwargs) + + def matmul( + self, + a: DenseArray, + b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, + *, + out: DenseArray | None = None, + ) -> DenseArray: + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + if out is not None: + kwargs["out"] = out + return self.torch.matmul(a, b, **kwargs) + + def sparse_matmul( + self, + a: SparseArray, + b: DenseArray, + *, + reduce: Literal["sum", "mean", "amax", "amin"] = "sum", ) -> DenseArray: """ - Create a filled tensor matching another tensor using PyTorch. + Matrix-multiply a sparse tensor by a dense tensor using PyTorch. Input: - x: Reference tensor; value: Fill value; dtype and device: Optional overrides. + a: Sparse backend tensor; b: Dense backend tensor or vector. Output: - Dense backend tensor with shape matching x. + Dense backend tensor. See: - https://docs.pytorch.org/docs/stable/generated/torch.full_like.html - """ - return self.torch.full_like( - x, - value, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device, memory_format=memory_format), - ) - - def arange( - self, - start: int | float = 0, - stop: int | float | None = None, - step: int | float | None = None, - dtype: DType | None = None, - *, - out: DenseArray | None = None, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - ) -> DenseArray: - """ - Create a range tensor using PyTorch. - - Input: - start, stop, step: Range parameters; dtype and device: Optional construction parameters. - - Output: - Dense backend tensor containing evenly spaced values. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.arange.html - """ - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - kwargs = self._defined_kwargs(out=out, layout=layout, device=device) - kwargs["requires_grad"] = requires_grad - if stop is None: - return self.torch.arange(start, dtype=dtype, **kwargs) - if step is None: - return self.torch.arange(start, stop, dtype=dtype, **kwargs) - return self.torch.arange(start, stop, step, dtype=dtype, **kwargs) - - def full( - self, - shape: int | Tuple[int, ...], - fill_value: Any, - dtype: DType | None = None, - *, - out: DenseArray | None = None, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - ) -> DenseArray: - """ - Create a filled dense tensor using PyTorch. - - Input: - shape: Output shape; fill_value: Fill value; dtype and device: Optional parameters. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.full.html - """ - return self.torch.full( - shape, - fill_value, - out=out, - dtype=self.sanitize_dtype(dtype) if dtype is not None else None, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device), - ) - - def eye( - self, - n: int, - m: int | None = None, - k: int = 0, - dtype: DType | None = None, - *, - out: DenseArray | None = None, - layout: Any | None = None, - device: Any | None = None, - requires_grad: bool = False, - ) -> DenseArray: - """ - Create a two-dimensional identity-like tensor using PyTorch. - - Input: - n, m: Matrix dimensions; k: Diagonal offset; dtype and device: Optional parameters. - - Output: - Dense backend tensor with ones on the requested diagonal. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.eye.html - - Backend-specific notes: - PyTorch ``torch.eye`` has no diagonal offset parameter, so SpaceCore - constructs the offset diagonal explicitly. - """ - m = n if m is None else m - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - if k == 0: - return self.torch.eye( - n, - m, - out=out, - dtype=dtype, - requires_grad=requires_grad, - **self._defined_kwargs(layout=layout, device=device), - ) - out = self.torch.zeros( - (n, m), - out=out, - dtype=dtype, - requires_grad=False, - **self._defined_kwargs(layout=layout, device=device), - ) - diag_len = min(n, m - k) if k > 0 else min(n + k, m) - if diag_len <= 0: - return out - rows = self.torch.arange(diag_len, device=device) - cols = rows + k - if k < 0: - rows = rows - k - cols = self.torch.arange(diag_len, device=device) - out[rows, cols] = 1 - if requires_grad: - out.requires_grad_() - return out - - def ravel(self, x: DenseArray) -> DenseArray: - """ - Flatten a tensor using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - One-dimensional tensor view or copy following PyTorch semantics. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.ravel.html - """ - return self.torch.ravel(x) - - def reshape(self, x: DenseArray, shape: int | Tuple[int, ...], *, copy: bool | None = None) -> DenseArray: - """ - Reshape a tensor using PyTorch. - - Input: - x: Dense backend tensor; shape: Target shape; copy: Whether to clone first. - - Output: - Reshaped dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.reshape.html - """ - if copy: - x = x.clone() - return self.torch.reshape(x, shape if isinstance(shape, tuple) else (shape,)) - - def transpose(self, x: DenseArray, axes: Sequence[int] | None = None) -> DenseArray: - """ - Permute tensor axes using PyTorch. - - Input: - x: Dense backend tensor; axes: Optional axis order. - - Output: - Tensor with permuted axes. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.permute.html - """ - if axes is None: - axes = tuple(reversed(range(x.ndim))) - return x.permute(tuple(axes)) - - def swapaxes(self, x: DenseArray, axis1: int, axis2: int) -> DenseArray: - """ - Swap two tensor axes using PyTorch. - - Input: - x: Dense backend tensor; axis1, axis2: Axes to swap. - - Output: - Tensor with the requested axes swapped. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.swapaxes.html - """ - return self.torch.swapaxes(x, axis1, axis2) - - def broadcast_to(self, x: DenseArray, shape: int | Tuple[int, ...]) -> DenseArray: - """ - Broadcast a tensor to a shape using PyTorch. - - Input: - x: Dense backend tensor; shape: Target broadcast shape. - - Output: - Broadcasted tensor view following PyTorch broadcasting rules. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.broadcast_to.html - """ - return self.torch.broadcast_to(x, shape) - - def expand_dims(self, x: DenseArray, axis: int | Sequence[int]) -> DenseArray: - """ - Insert singleton dimensions using PyTorch. - - Input: - x: Dense backend tensor; axis: Axis or axes where dimensions are inserted. - - Output: - Tensor with inserted singleton dimensions. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.unsqueeze.html - """ - if isinstance(axis, int): - return self.torch.unsqueeze(x, axis) - ndim = x.ndim + len(axis) - axes = sorted(a + ndim if a < 0 else a for a in axis) - out = x - for ax in axes: - out = self.torch.unsqueeze(out, ax) - return out - - def squeeze(self, x: DenseArray, axis: int | Sequence[int] | None = None) -> DenseArray: - """ - Remove singleton dimensions using PyTorch. - - Input: - x: Dense backend tensor; axis: Optional axis or axes to squeeze. - - Output: - Tensor with singleton dimensions removed. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.squeeze.html - """ - if axis is None: - return self.torch.squeeze(x) - if isinstance(axis, int): - return self.torch.squeeze(x, dim=axis) - out = x - for ax in sorted(axis, reverse=True): - out = self.torch.squeeze(out, dim=ax) - return out - - def moveaxis(self, x: DenseArray, source: int | Sequence[int], destination: int | Sequence[int]) -> DenseArray: - """ - Move tensor axes to new positions using PyTorch. - - Input: - x: Dense backend tensor; source and destination: Axis positions. - - Output: - Tensor with axes moved. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.moveaxis.html - """ - return self.torch.moveaxis(x, source, destination) - - def stack( - self, - arrays: Sequence[DenseArray], - axis: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Stack tensors along a new axis using PyTorch. - - Input: - arrays: Sequence of tensors; axis: New axis; out: Optional output tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.stack.html - """ - arrays = tuple(arrays) - if out is None: - return self.torch.stack(arrays, dim=axis) - return self.torch.stack(arrays, dim=axis, out=out) - - def conj(self, x: DenseArray) -> DenseArray: - """ - Return the complex conjugate using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Tensor containing complex conjugates. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.conj.html - """ - return self.torch.conj(x) - - def real(self, x: DenseArray) -> DenseArray: - """ - Return the real part of a tensor using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Tensor view or value containing real components. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.real.html - """ - return self.torch.real(x) - - def imag(self, x: DenseArray) -> DenseArray: - """ - Return the imaginary part of a tensor using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Tensor view or value containing imaginary components. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.imag.html - """ - return self.torch.imag(x) - - def abs( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise absolute value using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.abs.html - """ - if out is None: - return self.torch.abs(x) - return self.torch.abs(x, out=out) - - def sign( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise sign using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sign.html - """ - if out is None: - return self.torch.sign(x) - return self.torch.sign(x, out=out) - - def sqrt( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise square root using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sqrt.html - """ - if out is None: - return self.torch.sqrt(x) - return self.torch.sqrt(x, out=out) - - def sum( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Sum tensor elements using PyTorch. - - Input: - x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sum.html - """ - if dtype is None: - if out is None: - if axis is None and not keepdims: - return self.torch.sum(x) - return self.torch.sum(x, dim=axis, keepdim=keepdims) - return self.torch.sum(x, dim=axis, keepdim=keepdims, out=out) - dtype = self.sanitize_dtype(dtype) - if out is None: - if axis is None and not keepdims: - return self.torch.sum(x, dtype=dtype) - return self.torch.sum(x, dim=axis, keepdim=keepdims, dtype=dtype) - return self.torch.sum( - x, - dim=axis, - keepdim=keepdims, - dtype=dtype, - out=out, - ) - - def mean( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Average tensor elements using PyTorch. - - Input: - x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.mean.html - """ - if dtype is None: - if out is None: - if axis is None and not keepdims: - return self.torch.mean(x) - return self.torch.mean(x, dim=axis, keepdim=keepdims) - return self.torch.mean(x, dim=axis, keepdim=keepdims, out=out) - dtype = self.sanitize_dtype(dtype) - if out is None: - if axis is None and not keepdims: - return self.torch.mean(x, dtype=dtype) - return self.torch.mean(x, dim=axis, keepdim=keepdims, dtype=dtype) - return self.torch.mean( - x, - dim=axis, - keepdim=keepdims, - dtype=dtype, - out=out, - ) - - def min( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute minimum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.amin.html - """ - if out is None: - if axis is None and not keepdims: - return self.torch.amin(x) - return self.torch.amin(x, dim=axis, keepdim=keepdims) - return self.torch.amin(x, dim=axis, keepdim=keepdims, out=out) - - def max( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute maximum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.amax.html - """ - if out is None: - if axis is None and not keepdims: - return self.torch.amax(x) - return self.torch.amax(x, dim=axis, keepdim=keepdims) - return self.torch.amax(x, dim=axis, keepdim=keepdims, out=out) - - def prod( - self, - x: DenseArray, - axis: int | Sequence[int] | None = None, - dtype: DType | None = None, - keepdims: bool = False, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Multiply tensor elements using PyTorch. - - Input: - x: Dense backend tensor; axis, dtype, keepdims: Reduction controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.prod.html - - Backend-specific notes: - Multiple-axis products are applied one axis at a time because - PyTorch's ``torch.prod`` reduces a single dimension per call. - """ - dtype = self.sanitize_dtype(dtype) if dtype is not None else None - if axis is None: - result = self.torch.prod(x) if dtype is None else self.torch.prod(x, dtype=dtype) - if out is not None: - out.copy_(result) - return out - return result - if isinstance(axis, int): - if out is None: - if dtype is None: - return self.torch.prod(x, dim=axis, keepdim=keepdims) - return self.torch.prod(x, dim=axis, dtype=dtype, keepdim=keepdims) - return self.torch.prod(x, dim=axis, dtype=dtype, keepdim=keepdims, out=out) - result = x - for ax in sorted(axis, reverse=True): - result = self.torch.prod(result, dim=ax, dtype=dtype, keepdim=keepdims) - if out is not None: - out.copy_(result) - return out - return result - - def trace( - self, - x: DenseArray, - offset: int = 0, - axis1: int = 0, - axis2: int = 1, - dtype: DType | None = None, - ) -> DenseArray: - """ - Sum diagonal entries using PyTorch. - - Input: - x: Dense backend tensor; offset, axis1, axis2, dtype: Diagonal controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.diagonal.html - """ - return self.sum(self.diagonal(x, offset=offset, axis1=axis1, axis2=axis2), dtype=dtype) - - def argsort( - self, - x: DenseArray, - axis: int = -1, - stable: bool = False, - descending: bool = False, - ) -> DenseArray: - """ - Return sorting indices using PyTorch. - - Input: - x: Dense backend tensor; axis, stable, descending: Sorting controls. - - Output: - Integer tensor of indices. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.argsort.html - """ - return self.torch.argsort(x, dim=axis, stable=stable, descending=descending) - - def sort( - self, - x: DenseArray, - axis: int = -1, - stable: bool = False, - descending: bool = False, - *, - out: tuple[DenseArray, DenseArray] | None = None, - ) -> DenseArray: - """ - Sort tensor values using PyTorch. - - Input: - x: Dense backend tensor; axis, stable, descending: Sorting controls. - - Output: - Dense backend tensor of sorted values. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sort.html - """ - return self.torch.sort(x, dim=axis, stable=stable, descending=descending, out=out).values - - def argmin(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Return indices of minimum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Integer tensor of indices. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.argmin.html - """ - return self.torch.argmin(x, dim=axis, keepdim=keepdims) - - def argmax(self, x: DenseArray, axis: int | None = None, keepdims: bool = False) -> DenseArray: - """ - Return indices of maximum values using PyTorch. - - Input: - x: Dense backend tensor; axis and keepdims: Reduction controls. - - Output: - Integer tensor of indices. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.argmax.html - """ - return self.torch.argmax(x, dim=axis, keepdim=keepdims) - - def vdot( - self, - x: DenseArray, - y: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute conjugating vector dot product using PyTorch. - - Input: - x, y: Dense backend tensors. - - Output: - Scalar tensor containing the vector dot product. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.vdot.html - """ - x1 = x if x.ndim == 1 else self.torch.ravel(x) - y1 = y if y.ndim == 1 else self.torch.ravel(y) - if out is None: - return self.torch.vdot(x1, y1) - return self.torch.vdot(x1, y1, out=out) - - def matmul( - self, - a: DenseArray, - b: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Matrix-multiply tensors using PyTorch. - - Input: - a, b: Dense backend tensors. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.matmul.html - """ - if out is None: - return self.torch.matmul(a, b) - return self.torch.matmul(a, b, out=out) - - def sparse_matmul( - self, - a: SparseArray, - b: DenseArray, - *, - reduce: Literal["sum", "mean", "amax", "amin"] = "sum", - ) -> DenseArray: - """ - Matrix-multiply a sparse tensor by a dense tensor using PyTorch. - - Input: - a: Sparse backend tensor; b: Dense backend tensor or vector. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.sparse.mm.html + https://docs.pytorch.org/docs/stable/generated/torch.sparse.mm.html """ kwargs = {"reduce": reduce} if reduce != "sum" else {} if b.ndim == 1: return self.torch.sparse.mm(a, b[:, None], **kwargs)[:, 0] return self.torch.sparse.mm(a, b, **kwargs) - def kron( - self, - a: DenseArray, - b: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute the Kronecker product using PyTorch. - - Input: - a, b: Dense backend tensors. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.kron.html - """ - return self.torch.kron(a, b, out=out) - - def einsum(self, subscripts: str, *operands: DenseArray) -> DenseArray: - """ - Evaluate an Einstein summation using PyTorch. - - Input: - subscripts: Einsum expression; operands: Dense backend tensors. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.einsum.html - """ - return self.torch.einsum(subscripts, *operands) - def eigh( self, x: DenseArray, + backend_kwargs: dict[str, Any] | None = None, UPLO: Literal["L", "U"] = "L", *, out: tuple[DenseArray, DenseArray] | None = None, ) -> tuple[DenseArray, DenseArray]: - """ - Compute Hermitian eigenvalues and eigenvectors using PyTorch. - - Input: - x: Dense Hermitian or symmetric backend tensor. - - Output: - Tuple of eigenvalues and eigenvectors. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.eigh.html - """ if self.is_sparse(x): raise TypeError("eigh requires a dense array; sparse input is not supported.") - return self.torch.linalg.eigh(x, UPLO=UPLO, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(out=out)) + return self.torch.linalg.eigh(x, UPLO=UPLO, **kwargs) def norm( self, @@ -1421,18 +437,6 @@ def norm( dtype: DType | None = None, out: DenseArray | None = None, ) -> DenseArray: - """ - Compute vector or matrix norms using PyTorch. - - Input: - x: Dense backend tensor; ord, axis, keepdims: Norm controls. - - Output: - Dense backend tensor or scalar tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.norm.html - """ return self.torch.linalg.norm( x, ord=ord, @@ -1446,97 +450,39 @@ def solve( self, A: DenseArray, b: DenseArray, + backend_kwargs: dict[str, Any] | None = None, *, left: bool = True, out: DenseArray | None = None, ) -> DenseArray: - """ - Solve a linear system using PyTorch. - - Input: - A: Coefficient tensor; b: Right-hand side tensor. - - Output: - Dense backend tensor solving ``A @ x = b``. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.solve.html - """ - return self.torch.linalg.solve(A, b, left=left, out=out) - - def eigvalsh( - self, - A: DenseArray, - UPLO: Literal["L", "U"] = "L", - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute Hermitian eigenvalues using PyTorch. - - Input: - A: Dense Hermitian or symmetric backend tensor. - - Output: - Dense backend tensor of eigenvalues. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.eigvalsh.html - """ - return self.torch.linalg.eigvalsh(A, UPLO=UPLO, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(out=out)) + return self.torch.linalg.solve(A, b, left=left, **kwargs) def svd( self, A: DenseArray, full_matrices: bool = True, - compute_uv: bool = True, - hermitian: bool = False, + backend_kwargs: dict[str, Any] | None = None, *, driver: str | None = None, out: DenseArray | tuple[DenseArray, DenseArray, DenseArray] | None = None, ) -> DenseArray | tuple[DenseArray, DenseArray, DenseArray]: - """ - Compute singular value decomposition using PyTorch. - - Input: - A: Dense backend tensor; full_matrices, compute_uv, hermitian: SVD controls. - - Output: - Singular values or tuple ``(U, S, Vh)``. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.svd.html - - Backend-specific notes: - PyTorch does not expose a ``hermitian`` option for SVD. When - ``compute_uv`` is false, this delegates to ``torch.linalg.svdvals``. - """ - if hermitian: - raise NotImplementedError("PyTorch svd does not expose a hermitian option.") - if not compute_uv: - return self.torch.linalg.svdvals(A, driver=driver, out=out) - return self.torch.linalg.svd(A, full_matrices=full_matrices, driver=driver, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(driver=driver, out=out)) + return self.torch.linalg.svd(A, full_matrices=full_matrices, **kwargs) def cholesky( self, A: DenseArray, + backend_kwargs: dict[str, Any] | None = None, *, upper: bool = False, out: DenseArray | None = None, ) -> DenseArray: - """ - Compute a Cholesky factorization using PyTorch. - - Input: - A: Positive-definite dense backend tensor. - - Output: - Dense backend tensor containing the Cholesky factor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.linalg.cholesky.html - """ - return self.torch.linalg.cholesky(A, upper=upper, out=out) + kwargs = {} if backend_kwargs is None else dict(backend_kwargs) + kwargs.update(self._defined_kwargs(out=out)) + return self.torch.linalg.cholesky(A, upper=upper, **kwargs) def logsumexp( self, @@ -1583,188 +529,18 @@ def logsumexp( return out return result - def exp( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise exponential using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.exp.html - """ - if out is None: - return self.torch.exp(x) - return self.torch.exp(x, out=out) - - def log( - self, - x: DenseArray, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise natural logarithm using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.log.html - """ - if out is None: - return self.torch.log(x) - return self.torch.log(x, out=out) - def where( self, condition: DenseArray | bool, - x: ArrayLike | None = None, - y: ArrayLike | None = None, + x: DenseArray, + y: DenseArray, *, out: DenseArray | None = None, ) -> DenseArray: - """ - Select values conditionally using PyTorch. - - Input: - condition: Boolean tensor or scalar; x, y: Values to select. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.where.html - """ - if x is None and y is None: - return self.torch.where(condition) - if x is None or y is None: - raise TypeError("where requires both x and y when either is provided.") if out is None: return self.torch.where(condition, x, y) return self.torch.where(condition, x, y, out=out) - def maximum( - self, - x: ArrayLike, - y: ArrayLike, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise maximum using PyTorch. - - Input: - x, y: Array-like operands. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.maximum.html - """ - y = y if isinstance(y, self.torch.Tensor) else self.asarray(y, dtype=x.dtype, device=x.device) - if out is None: - return self.torch.maximum(x, y) - return self.torch.maximum( - x, - y, - out=out, - ) - - def minimum( - self, - x: ArrayLike, - y: ArrayLike, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Compute elementwise minimum using PyTorch. - - Input: - x, y: Array-like operands. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.minimum.html - """ - y = y if isinstance(y, self.torch.Tensor) else self.asarray(y, dtype=x.dtype, device=x.device) - if out is None: - return self.torch.minimum(x, y) - return self.torch.minimum( - x, - y, - out=out, - ) - - def clip( - self, - x: DenseArray, - a_min: ArrayLike | None = None, - a_max: ArrayLike | None = None, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Clip tensor values using PyTorch. - - Input: - x: Dense backend tensor; a_min, a_max: Lower and upper bounds. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.clamp.html - """ - if out is None: - return self.torch.clamp(x, min=a_min, max=a_max) - return self.torch.clamp(x, min=a_min, max=a_max, out=out) - - def isfinite(self, x: DenseArray) -> DenseArray: - """ - Test finiteness elementwise using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Boolean dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.isfinite.html - """ - return self.torch.isfinite(x) - - def isnan(self, x: DenseArray) -> DenseArray: - """ - Test NaN values elementwise using PyTorch. - - Input: - x: Dense backend tensor. - - Output: - Boolean dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.isnan.html - """ - return self.torch.isnan(x) - def concatenate( self, arrays: Sequence[DenseArray], @@ -1773,131 +549,12 @@ def concatenate( *, out: DenseArray | None = None, ) -> DenseArray: - """ - Concatenate tensors using PyTorch. - - Input: - arrays: Sequence of dense backend tensors; axis: Concatenation axis; dtype: Optional cast. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.cat.html - """ - arrays = tuple(arrays) if out is None: - result = self.torch.cat(arrays, dim=axis) + result = self.torch.cat(tuple(arrays), dim=axis) else: - result = self.torch.cat(arrays, dim=axis, out=out) + result = self.torch.cat(tuple(arrays), dim=axis, out=out) return self.astype(result, dtype) if dtype is not None else result - def take( - self, - x: DenseArray, - indices: DenseArray, - axis: int | None = None, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Take tensor elements by index using PyTorch. - - Input: - x: Dense backend tensor; indices: Integer indices; axis: Optional selection axis. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.take.html - """ - if axis is None: - result = self.torch.take(x, indices) - if out is not None: - out.copy_(result) - return out - return result - return self.torch.index_select(x, dim=axis, index=indices, out=out) - - def diag( - self, - x: DenseArray, - k: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Extract or construct a diagonal tensor using PyTorch. - - Input: - x: Dense backend tensor; k: Diagonal offset. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.diag.html - """ - return self.torch.diag(x, diagonal=k, out=out) - - def diagonal(self, x: DenseArray, offset: int = 0, axis1: int = 0, axis2: int = 1) -> DenseArray: - """ - Return a tensor diagonal using PyTorch. - - Input: - x: Dense backend tensor; offset, axis1, axis2: Diagonal controls. - - Output: - Dense backend tensor view or value. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.diagonal.html - """ - return self.torch.diagonal(x, offset=offset, dim1=axis1, dim2=axis2) - - def tril( - self, - x: DenseArray, - k: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Return the lower triangular part using PyTorch. - - Input: - x: Dense backend tensor; k: Diagonal offset. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.tril.html - """ - return self.torch.tril(x, diagonal=k, out=out) - - def triu( - self, - x: DenseArray, - k: int = 0, - *, - out: DenseArray | None = None, - ) -> DenseArray: - """ - Return the upper triangular part using PyTorch. - - Input: - x: Dense backend tensor; k: Diagonal offset. - - Output: - Dense backend tensor. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.triu.html - """ - return self.torch.triu(x, diagonal=k, out=out) - def index_set(self, x: DenseArray, index: Index, values: DenseArray, *, copy: bool = True): """ Set indexed tensor values using PyTorch. @@ -2106,21 +763,6 @@ def cond(self, pred: bool, true_fun: Callable[[T], R], false_fun: Callable[[T], """ return true_fun(*operands) if bool(pred) else false_fun(*operands) - def allclose(self, a: DenseArray, b: DenseArray, rtol: float = 1e-5, atol: float = 1e-8, equal_nan: bool = False) -> bool: - """ - Compare dense tensors elementwise within tolerances using PyTorch. - - Input: - a, b: Dense backend tensors; rtol, atol, equal_nan: Comparison controls. - - Output: - Boolean indicating whether tensors are close. - - See: - https://docs.pytorch.org/docs/stable/generated/torch.allclose.html - """ - return bool(self.torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) - def allclose_sparse(self, a: SparseArray, b: SparseArray, rtol: float = 1e-5, atol: float = 1e-8) -> bool: """ Compare sparse tensors elementwise within tolerances using PyTorch. diff --git a/tests/backend/test_backend_ops_delegation.py b/tests/backend/test_backend_ops_delegation.py new file mode 100644 index 0000000..770ce32 --- /dev/null +++ b/tests/backend/test_backend_ops_delegation.py @@ -0,0 +1,238 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import ( + has_jax, + has_torch, + jax_complex_dtype, + jax_real_dtype, + to_numpy, + torch_complex_dtype, + torch_real_dtype, +) + + +DELEGATED_METHODS = ( + "reshape", + "sum", + "eigh", + "trace", + "concatenate", + "transpose", + "matmul", +) + +TORCH_AAC_DELEGATED_METHODS = ( + "mean", + "prod", + "sort", + "argsort", + "argmin", + "argmax", + "clip", + "take", + "diagonal", + "squeeze", +) + + +def test_numpy_ops_inherits_common_delegated_methods(): + sc = importlib.import_module("spacecore") + + assert sc.NumpyOps.xp.__name__ == "array_api_compat.numpy" + for name in DELEGATED_METHODS: + assert getattr(sc.NumpyOps, name) is getattr(sc.BackendOps, name) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_ops_inherits_common_delegated_methods(): + sc = importlib.import_module("spacecore") + + assert sc.JaxOps.xp is sc.JaxOps.jnp + for name in DELEGATED_METHODS: + assert getattr(sc.JaxOps, name) is getattr(sc.BackendOps, name) + + +def test_torch_ops_uses_aac_namespace_when_available(): + sc = importlib.import_module("spacecore") + + assert sc.TorchOps.xp.__name__ == "array_api_compat.torch" + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_ops_inherits_aac_delegated_methods(): + sc = importlib.import_module("spacecore") + + for name in TORCH_AAC_DELEGATED_METHODS: + assert getattr(sc.TorchOps, name) is getattr(sc.BackendOps, name) + + +def _check_raw_delegated_ops(ops, dtype): + x = ops.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + y = ops.reshape(x, (3, 2)) + h = ops.asarray([[2.0, 1.0], [1.0, 3.0]], dtype=dtype) + cube = ops.reshape(ops.arange(24, dtype=dtype), (2, 3, 4)) + singleton = ops.reshape(ops.arange(6, dtype=dtype), (1, 2, 1, 3)) + + np.testing.assert_allclose(to_numpy(ops.matmul(h, ops.asarray([1.0, 2.0], dtype=dtype))), [4.0, 7.0]) + np.testing.assert_allclose(to_numpy(ops.reshape(x, (6,))), np.arange(1.0, 7.0)) + np.testing.assert_allclose(to_numpy(ops.sum(x, axis=0)), [5.0, 7.0, 9.0]) + np.testing.assert_allclose(to_numpy(ops.sum(cube, axis=(0, 2))), np.arange(24.0).reshape(2, 3, 4).sum(axis=(0, 2))) + np.testing.assert_allclose(to_numpy(ops.prod(cube + 1, axis=(0, 2))), (np.arange(24.0).reshape(2, 3, 4) + 1).prod(axis=(0, 2))) + np.testing.assert_allclose(to_numpy(ops.mean(cube, axis=(0, 2))), np.arange(24.0).reshape(2, 3, 4).mean(axis=(0, 2))) + assert ops.shape(ops.squeeze(singleton)) == (2, 3) + np.testing.assert_allclose(to_numpy(ops.trace(h)), 5.0) + np.testing.assert_allclose(to_numpy(ops.concatenate((x, x), axis=0)), np.concatenate((to_numpy(x), to_numpy(x)), axis=0)) + np.testing.assert_allclose(to_numpy(ops.transpose(y)), to_numpy(y).T) + + evals, evecs = ops.eigh(h) + np.testing.assert_allclose(to_numpy(evecs @ ops.diag(evals) @ ops.transpose(evecs)), to_numpy(h), atol=1e-6) + + +def test_numpy_raw_delegated_ops(): + sc = importlib.import_module("spacecore") + + _check_raw_delegated_ops(sc.NumpyOps(), np.float64) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_raw_delegated_ops(): + sc = importlib.import_module("spacecore") + + _check_raw_delegated_ops(sc.JaxOps(), jax_real_dtype()) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_raw_delegated_ops(): + sc = importlib.import_module("spacecore") + + _check_raw_delegated_ops(sc.TorchOps(), torch_real_dtype()) + + +def test_numpy_eps_uses_default_dtype(): + sc = importlib.import_module("spacecore") + ops = sc.NumpyOps() + + np.testing.assert_allclose(ops.eps(ops.sanitize_dtype(None)), np.finfo(ops.sanitize_dtype(None)).eps) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_eps_uses_default_dtype(): + sc = importlib.import_module("spacecore") + ops = sc.JaxOps() + + np.testing.assert_allclose(ops.eps(jax_real_dtype()), np.finfo(jax_real_dtype()).eps) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_eps_uses_default_dtype(): + sc = importlib.import_module("spacecore") + import torch + + ops = sc.TorchOps() + np.testing.assert_allclose(ops.eps(ops.sanitize_dtype(None)), torch.finfo(ops.sanitize_dtype(None)).eps) + + +def test_backend_ops_hash_and_repr(): + sc = importlib.import_module("spacecore") + first = sc.NumpyOps() + second = sc.NumpyOps() + + assert hash(first) == hash(second) + assert {first: 1, second: 2} == {first: 2} + assert repr(first) == "NumpyOps(family='numpy')" + + +def test_numpy_eps_distinguishes_dtype_precision(): + sc = importlib.import_module("spacecore") + ops = sc.NumpyOps() + + assert ops.eps(np.float64) < ops.eps(np.float32) + + +def test_numpy_constants_are_cached_and_astype_none_is_noop(): + sc = importlib.import_module("spacecore") + ops = sc.NumpyOps() + x = ops.asarray([1.0, 2.0]) + + assert ops.inf is ops.inf + assert ops.nan is ops.nan + assert ops.pi is ops.pi + assert ops.e is ops.e + assert ops.astype(x, None) is x + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_constants_are_cached_and_astype_none_is_noop(): + sc = importlib.import_module("spacecore") + ops = sc.JaxOps() + x = ops.asarray([1.0, 2.0], dtype=jax_real_dtype()) + + assert ops.inf is ops.inf + assert ops.nan is ops.nan + assert ops.pi is ops.pi + assert ops.e is ops.e + assert ops.astype(x, None) is x + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_constants_are_cached_and_astype_none_is_noop(): + sc = importlib.import_module("spacecore") + ops = sc.TorchOps() + x = ops.asarray([1.0, 2.0], dtype=torch_real_dtype()) + + assert ops.inf is ops.inf + assert ops.nan is ops.nan + assert ops.pi is ops.pi + assert ops.e is ops.e + assert ops.astype(x, None) is x + + +def _check_complex_adjoint(ops, dtype): + sc = importlib.import_module("spacecore") + ctx = sc.Context(ops, dtype=dtype) + dom = sc.VectorSpace((2,), ctx) + cod = sc.VectorSpace((3,), ctx) + A = ctx.asarray( + [ + [1.0 + 2.0j, 3.0 - 1.0j], + [-2.0 + 0.5j, 0.25 + 4.0j], + [1.5 - 3.0j, -0.75 + 2.0j], + ] + ) + x = ctx.asarray([2.0 - 1.0j, -0.5 + 0.25j]) + y = ctx.asarray([1.0 + 0.5j, -2.0j, 0.75 - 1.25j]) + op = sc.DenseLinOp(A, dom, cod, ctx) + + lhs = ctx.ops.vdot(op.apply(x), y) + rhs = ctx.ops.vdot(x, op.rapply(y)) + np.testing.assert_allclose(to_numpy(lhs), to_numpy(rhs), rtol=1e-6, atol=1e-6) + + H = sc.HermitianSpace(2, ctx=ctx) + raw = ctx.asarray([[1.0 + 0.0j, 2.0 + 3.0j], [-1.0 + 4.0j, -0.5 + 0.0j]]) + herm = H.symmetrize(raw) + evals, evecs = H.eigh(herm) + rebuilt = (evecs * evals) @ ctx.ops.conj(ctx.ops.transpose(evecs)) + np.testing.assert_allclose(to_numpy(rebuilt), to_numpy(herm), rtol=1e-6, atol=1e-6) + + +def test_numpy_complex_adjoint_and_hermitian_eigh(): + sc = importlib.import_module("spacecore") + + _check_complex_adjoint(sc.NumpyOps(), np.complex128) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_complex_adjoint_and_hermitian_eigh(): + sc = importlib.import_module("spacecore") + + _check_complex_adjoint(sc.JaxOps(), jax_complex_dtype()) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_complex_adjoint_and_hermitian_eigh(): + sc = importlib.import_module("spacecore") + + _check_complex_adjoint(sc.TorchOps(), torch_complex_dtype()) diff --git a/tests/backend/test_backend_registry.py b/tests/backend/test_backend_registry.py index f55aea9..65ca47a 100644 --- a/tests/backend/test_backend_registry.py +++ b/tests/backend/test_backend_registry.py @@ -14,55 +14,23 @@ def test_register_ops_adds_backend(): sc = importlib.import_module("spacecore") class DummyOps(sc.BackendOps): + import array_api_compat.numpy as xp + _family = "dummy" _allow_sparse = False - _dense_array = np.ndarray - _sparse_array = object + + @property + def dense_array(self): + return np.ndarray + + @property + def sparse_array(self): + return None + def sanitize_dtype(self, dtype): return np.dtype(dtype) if dtype is not None else np.dtype(np.float64) - def asarray(self, x, dtype=None): return np.asarray(x, dtype=dtype) def assparse(self, x, dtype=None): raise NotImplementedError - def is_dense(self, x): return isinstance(x, np.ndarray) - def is_sparse(self, x): return False - def get_dtype(self, x): return x.dtype - def zeros(self, shape, dtype=None): return np.zeros(shape, dtype=dtype) - def ones(self, shape, dtype=None): return np.ones(shape, dtype=dtype) - def full(self, shape, fill_value, dtype=None): return np.full(shape, fill_value, dtype=dtype) - def empty(self, shape, dtype=None): return np.empty(shape, dtype=dtype) - def eye(self, n, dtype=None): return np.eye(n, dtype=dtype) - def arange(self, *args, **kwargs): return np.arange(*args, **kwargs) - def reshape(self, x, shape): return np.reshape(x, shape) - def ravel(self, x): return np.ravel(x) - def transpose(self, x, axes=None): return np.transpose(x, axes=axes) - def swapaxes(self, x, axis1, axis2): return np.swapaxes(x, axis1, axis2) - def conj(self, x): return np.conj(x) - def sum(self, x, axis=None, **kwargs): return np.sum(x, axis=axis) - def prod(self, x, axis=None, **kwargs): return np.prod(x, axis=axis) - def trace(self, x, **kwargs): return np.trace(x) - def argsort(self, x, **kwargs): return np.argsort(x) - def sort(self, x, **kwargs): return np.sort(x) - def argmin(self, x, **kwargs): return np.argmin(x) - def argmax(self, x, **kwargs): return np.argmax(x) - def vdot(self, a, b, **kwargs): return np.vdot(a, b) - def matmul(self, a, b, **kwargs): return a @ b def sparse_matmul(self, a, b): raise NotImplementedError - def kron(self, a, b): return np.kron(a, b) - def einsum(self, subscripts, *operands, **kwargs): return np.einsum(subscripts, *operands) - def eigh(self, x, **kwargs): return np.linalg.eigh(x) def logsumexp(self, *args, **kwargs): raise NotImplementedError - def exp(self, x): return np.exp(x) - def log(self, x): return np.log(x) - def maximum(self, x, y): return np.maximum(x, y) - def minimum(self, x, y): return np.minimum(x, y) - def where(self, condition, x=None, y=None, **kwargs): return np.where(condition, x, y) - def concatenate(self, arrays, axis=0, dtype=None): - out = np.concatenate(arrays, axis=axis) - return out.astype(dtype) if dtype is not None else out - def stack(self, arrays, axis=0): return np.stack(arrays, axis=axis) - def sqrt(self, x): return np.sqrt(x) - def abs(self, x): return np.abs(x) - def real(self, x): return np.real(x) - def imag(self, x): return np.imag(x) - def sign(self, x): return np.sign(x) def index_set(self, x, index, values, copy=True): y = np.array(x, copy=True) @@ -90,10 +58,12 @@ def index_add(self, x, index, values, copy=True): y=np.array(x, copy=True) y[index]+=values return y - def allclose(self, a, b, **kwargs): return np.allclose(a, b, **kwargs) def allclose_sparse(self, a, b, **kwargs): return False sc.register_ops(DummyOps) assert "dummy" in sc._contextual.manager.ctx_manager.available_ops + ops = DummyOps() + x = ops.reshape(ops.arange(6), (2, 3)) + assert np.allclose(ops.sum(x, axis=0), [3, 5, 7]) @pytest.mark.skipif(not has_torch(), reason="torch is not installed") diff --git a/tests/test_backend_ops_complex.py b/tests/test_backend_ops_complex.py new file mode 100644 index 0000000..c5b95e0 --- /dev/null +++ b/tests/test_backend_ops_complex.py @@ -0,0 +1,33 @@ +import importlib + +import numpy as np +import pytest + +from tests._helpers import has_jax, has_torch, jax_complex_dtype, to_numpy, torch_complex_dtype + + +def _check_vdot_conjugates_first_argument(ops, dtype): + x = ops.asarray([1.0 + 2.0j, 3.0 + 4.0j], dtype=dtype) + y = ops.asarray([5.0 + 6.0j, 7.0 + 8.0j], dtype=dtype) + + np.testing.assert_allclose(to_numpy(ops.vdot(x, y)), 70.0 - 8.0j) + + +def test_numpy_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.NumpyOps(), np.complex128) + + +@pytest.mark.skipif(not has_jax(), reason="jax is not installed") +def test_jax_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.JaxOps(), jax_complex_dtype()) + + +@pytest.mark.skipif(not has_torch(), reason="torch is not installed") +def test_torch_vdot_conjugates_first_argument(): + sc = importlib.import_module("spacecore") + + _check_vdot_conjugates_first_argument(sc.TorchOps(), torch_complex_dtype())