diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index 55c70e53..d8cce7c8 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -13,6 +13,7 @@ import numexpr as ne from collections import defaultdict from math import prod +import array_api_compat from .domain import Domain from .field import Operand, Field @@ -245,10 +246,11 @@ def choose_layout(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) - np.add(arg0.data, arg1.data, out=out.data) + xp.add(arg0.data, arg1.data, out=out.data) # used for einsum string manipulation @@ -664,6 +666,7 @@ def GammaCoord(self, A_tensorsig, B_tensorsig, C_tensorsig): return G def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args out.preset_layout(arg0.layout) # Broadcast @@ -671,7 +674,11 @@ def operate(self, out): arg1_data = self.arg1_ghost_broadcaster.cast(arg1) # Call einsum if out.data.size: - np.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) + if array_api_compat.is_cupy_namespace(xp): + # Cupy does not support output keyword + out.data[:] = xp.einsum(self.einsum_str, arg0_data, arg1_data, optimize=True) + else: + xp.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) @alias("cross") @@ -854,6 +861,7 @@ def __init__(self, arg0, arg1, out=None, **kw): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) @@ -863,7 +871,7 @@ def operate(self, out): # Reshape arg data to broadcast properly for output tensorsig arg0_exp_data = arg0_data.reshape(self.arg0_exp_tshape + arg0_data.shape[len(arg0.tensorsig):]) arg1_exp_data = arg1_data.reshape(self.arg1_exp_tshape + arg1_data.shape[len(arg1.tensorsig):]) - np.multiply(arg0_exp_data, arg1_exp_data, out=out.data) + xp.multiply(arg0_exp_data, arg1_exp_data, out=out.data) class GhostBroadcaster: @@ -919,7 +927,7 @@ def __init__(self, arg0, arg1, out=None,**kw): super().__init__(arg0, arg1, out=out) self.domain = arg1.domain self.tensorsig = arg1.tensorsig - self.dtype = np.result_type(type(arg0), arg1.dtype) + self.dtype = np.result_type(arg0, arg1.dtype) @classmethod def _check_args(cls, *args, **kw): @@ -939,11 +947,12 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg1.layout) # Multiply argument data - np.multiply(arg0, arg1.data, out=out.data) + xp.multiply(arg0, arg1.data, out=out.data) def matrix_dependence(self, *vars): return self.args[1].matrix_dependence(*vars) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 3f6167c9..d9616b10 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -5,6 +5,7 @@ from functools import reduce import inspect from math import prod +import array_api_compat from . import operators from ..libraries import spin_recombination @@ -14,7 +15,7 @@ from ..tools import clenshaw from ..tools.array import reshape_vector, axindex, axslice, interleave_matrices from ..tools.dispatch import MultiClass, SkipDispatchException -from ..tools.general import unify, DeferredTuple +from ..tools.general import unify, DeferredTuple, is_real_dtype, is_complex_dtype from .coords import Coordinate, CartesianCoordinates, S2Coordinates, SphericalCoordinates, PolarCoordinates, AzimuthalCoordinate, DirectProduct from .domain import Domain from .field import Operand, LockedField @@ -438,8 +439,10 @@ class Jacobi(IntervalBasis, metaclass=CachedClass): group_shape = (1,) native_bounds = (-1, 1) transforms = {} - default_dct = "fftw_dct" - default_library = "matrix" + default_cpu_library = "matrix" + default_gpu_library = "matrix" + default_cpu_dct = "fftw" + default_gpu_dct = "matrix" @classmethod def _preprocess_cache_args(cls, coord, size, bounds, a, b, a0, b0, dealias, library): @@ -474,12 +477,6 @@ def _preprocess_cache_args(cls, coord, size, bounds, a, b, a0, b0, dealias, libr dealias = tuple(dealias) if len(dealias) != 1: raise ValueError("Jacobi dealias must have length 1.") - # library: pick default based on (a0, b0) - if library is None: - if a0 == b0 == -1/2: - library = cls.default_dct - else: - library = cls.default_library return (coord, size, bounds, a, b, a0, b0, dealias, library) def __init__(self, coord, size, bounds, a, b, a0=None, b0=None, dealias=(1,), library=None): @@ -503,10 +500,30 @@ def _native_grid(self, scale): N, = self.grid_shape((scale,)) return jacobi.build_grid(N, a=self.a0, b=self.b0) + def get_library(self, dist): + """Get library for transforms.""" + if self.library is None: + if self.a0 == self.b0 == -1/2: + if dist.is_cupy_namespace: + return self.default_gpu_dct + else: + return self.default_cpu_dct + else: + if dist.is_cupy_namespace: + return self.default_gpu_library + else: + return self.default_cpu_library + else: + return self.library + @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" - return self.transforms[self.library](grid_size, self.size, self.a, self.b, self.a0, self.b0) + # Shortcut trivial transforms + if grid_size == 1 or self.size == 1: + return self.transforms["matrix"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) + else: + return self.transforms[self.get_library(dist)](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) # def weights(self, scales): # """Gauss-Jacobi weights.""" @@ -818,7 +835,8 @@ class FourierBase(IntervalBasis): """Base class for RealFourier and ComplexFourier.""" native_bounds = (0, 2*np.pi) - default_library = "fftw" + default_gpu_library = "cupy" + default_cpu_library = "fftw" @classmethod def _preprocess_cache_args(cls, coord, size, bounds, dealias, library): @@ -841,9 +859,6 @@ def _preprocess_cache_args(cls, coord, size, bounds, dealias, library): dealias = tuple(dealias) if len(dealias) != 1: raise ValueError("Fourier dealias must have length 1.") - # library: pick default based on (a0, b0) - if library is None: - library = cls.default_library return (coord, size, bounds, dealias, library) def __init__(self, coord, size, bounds, dealias=(1,), library=None): @@ -912,14 +927,24 @@ def _native_grid(self, scale): N, = self.grid_shape((scale,)) return (2 * np.pi / N) * np.arange(N) + def get_library(self, dist): + """Get library for transforms.""" + if self.library is None: + if dist.is_cupy_namespace: + return self.default_gpu_library + else: + return self.default_cpu_library + else: + return self.library + @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" # Shortcut trivial transforms if grid_size == 1 or self.size == 1: - return self.transforms['matrix'](grid_size, self.size) + return self.transforms["matrix"](grid_size, self.size, dist.array_namespace, dist.dtype) else: - return self.transforms[self.library](grid_size, self.size) + return self.transforms[self.get_library(dist)](grid_size, self.size, dist.array_namespace, dist.dtype) def forward_transform(self, field, axis, gdata, cdata): # Transform @@ -940,9 +965,9 @@ def Fourier(*args, dtype=None, **kw): """Factory function dispatching to RealFourier and ComplexFourier based on provided dtype.""" if dtype is None: raise ValueError("dtype must be specified") - elif dtype == np.float64: + elif is_real_dtype(dtype): return RealFourier(*args, **kw) - elif dtype == np.complex128: + elif is_complex_dtype(dtype): return ComplexFourier(*args, **kw) else: raise ValueError(f"Unrecognized dtype: {dtype}") @@ -2047,15 +2072,6 @@ def _preprocess_cache_args(cls, coordsys, shape, dtype, radii, k, alpha, dealias dealias = tuple(dealias) if len(dealias) != 2: raise ValueError("Annulus dealias must have length 2.") - # azimuth_library: pick default - if azimuth_library is None: - azimuth_library = RealFourier.default_library - # radius_library: pick default based on alpha - if radius_library is None: - if alpha[0] == alpha[1] == -1/2: - radius_library = Jacobi.default_dct - else: - radius_library = Jacobi.default_library return (coordsys, shape, dtype, radii, k, alpha, dealias, azimuth_library, radius_library) def __init__(self, coordsys, shape, dtype, radii=(1,2), k=0, alpha=(-0.5,-0.5), dealias=(1,1), azimuth_library=None, radius_library=None): @@ -6081,6 +6097,7 @@ class CartesianAdvectiveCFL(operators.AdvectiveCFL): @CachedMethod def cfl_spacing(self): + xp = self.array_namespace velocity = self.operand coordsys = velocity.tensorsig[0] spacing = [] @@ -6102,7 +6119,7 @@ def cfl_spacing(self): axis_spacing[:] = dealias * native_spacing * basis.COV.stretch elif basis is None: axis_spacing = np.inf - spacing.append(axis_spacing) + spacing.append(xp.asarray(axis_spacing)) return spacing def compute_cfl_frequency(self, velocity, out): diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index c4cc766f..6dcadfb4 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -10,6 +10,7 @@ from math import prod import numbers from weakref import WeakSet +import array_api_compat from .coords import CoordinateSystem, DirectProduct from ..tools.array import reshape_vector @@ -39,12 +40,16 @@ class Distributor: Parameters ---------- - dim : int - Dimension + coordsystems : CoordinateSystem or tuple of CoordinateSystems + Problem coordinate systems comm : MPI communicator, optional MPI communicator (default: comm world) mesh : tuple of ints, optional Process mesh for parallelization (default: 1-D mesh of available processes) + dtype : data type, optional + Default data type for fields (default: None) + array_namespace : array namespace or string, optional + Array namespace for field data (e.g. numpy or cupy, default: numpy) Attributes ---------- @@ -74,7 +79,7 @@ class Distributor: states) and the paths between them (D transforms and R transposes). """ - def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): + def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespace=np): # Accept single coordsys in place of tuple/list if not isinstance(coordsystems, (tuple, list)): coordsystems = (coordsystems,) @@ -115,6 +120,13 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): self._build_layouts() # Keep set of weak field references self.fields = WeakSet() + # Array module + if isinstance(array_namespace, str): + self.array_namespace = getattr(array_api_compat, array_namespace) + else: + self.array_namespace = array_api_compat.array_namespace(array_namespace.zeros(0)) + self.is_numpy_namespace = array_api_compat.is_numpy_namespace(self.array_namespace) + self.is_cupy_namespace = array_api_compat.is_cupy_namespace(self.array_namespace) @CachedAttribute def cs_by_axis(self): @@ -255,11 +267,12 @@ def IdentityTensor(self, coordsys_in, coordsys_out=None, bases=None, dtype=None) return I def local_grid(self, basis, scale=None): + xp = self.array_namespace # TODO: remove from bases and do it all here? if scale is None: scale = 1 if basis.dim == 1: - return basis.local_grid(self, scale=scale) + return xp.asarray(basis.local_grid(self, scale=scale)) else: raise ValueError("Use `local_grids` for multidimensional bases.") @@ -292,16 +305,18 @@ def local_grid(self, basis, scale=None): # return tuple(grids) def local_grids(self, *bases, scales=None): + xp = self.array_namespace scales = self.remedy_scales(scales) grids = [] for basis in bases: basis_scales = scales[self.first_axis(basis):self.last_axis(basis)+1] - grids.extend(basis.local_grids(self, scales=basis_scales)) + grids.extend(xp.asarray(basis.local_grids(self, scales=basis_scales))) return grids def local_modes(self, basis): # TODO: remove from bases and do it all here? - return basis.local_modes(self) + xp = self.array_namespace + return xp.asarray(basis.local_modes(self)) @CachedAttribute def default_nonconst_groups(self): diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 415edcf6..93330740 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -7,6 +7,7 @@ from functools import partial, reduce from collections import defaultdict import numpy as np +import array_api_compat from mpi4py import MPI from scipy import sparse from scipy.sparse import linalg as splinalg @@ -473,16 +474,19 @@ def evaluate(self): def reinitialize(self, **kw): return self - @staticmethod - def _create_buffer(buffer_size): + def _create_buffer(self, buffer_size): """Create buffer for Field data.""" - if buffer_size == 0: - # FFTW doesn't like allocating size-0 arrays - return np.zeros((0,), dtype=np.float64) + xp = self.array_namespace + if xp == np: + if buffer_size == 0: + # FFTW doesn't like allocating size-0 arrays + return np.zeros((0,), dtype=np.float64) + else: + # Use FFTW SIMD aligned allocation + alloc_doubles = buffer_size // 8 + return fftw.create_buffer(alloc_doubles) else: - # Use FFTW SIMD aligned allocation - alloc_doubles = buffer_size // 8 - return fftw.create_buffer(alloc_doubles) + return xp.zeros(buffer_size) @CachedAttribute def _dealias_buffer_size(self): @@ -516,14 +520,17 @@ def preset_scales(self, scales): def preset_layout(self, layout): """Interpret buffer as data in specified layout.""" + xp = self.array_namespace layout = self.dist.get_layout_object(layout) self.layout = layout tens_shape = [vs.dim for vs in self.tensorsig] local_shape = layout.local_shape(self.domain, self.scales) total_shape = tuple(tens_shape) + tuple(local_shape) - self.data = np.ndarray(shape=total_shape, - dtype=self.dtype, - buffer=self.buffer) + # Create view into buffer + if array_api_compat.is_cupy_namespace(xp): + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, memptr=self.buffer.data) + else: + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer) #self.global_start = layout.start(self.domain, self.scales) @@ -561,6 +568,7 @@ def __init__(self, dist, bases=None, name=None, tensorsig=None, dtype=None): dtype = dist.dtype from .domain import Domain self.dist = dist + self.array_namespace = dist.array_namespace self.name = name self.tensorsig = tensorsig self.dtype = dtype @@ -774,9 +782,15 @@ def allgather_data(self, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # Build global buffers tensor_shape = tuple(cs.dim for cs in self.tensorsig) global_shape = tensor_shape + self.layout.global_shape(self.domain, self.scales) @@ -785,7 +799,7 @@ def allgather_data(self, layout=None): recv_buff = np.empty_like(send_buff) # Combine data via allreduce -- easy but not communication-optimal # Should be optimized using Allgatherv if this is used past startup - send_buff[local_slices] = self.data + send_buff[local_slices] = data self.dist.comm.Allreduce(send_buff, recv_buff, op=MPI.SUM) return recv_buff @@ -793,13 +807,19 @@ def gather_data(self, root=0, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # TODO: Shortcut this for constant fields # Gather data # Should be optimized via Gatherv eventually - pieces = self.dist.comm.gather(self.data, root=root) + pieces = self.dist.comm.gather(data, root=root) # Assemble on root node if self.dist.comm.rank == root: ext_mesh = self.layout.ext_mesh @@ -826,7 +846,7 @@ def allreduce_data_norm(self, layout=None, order=2): if self.dist.comm.size > 1: norm = self.dist.comm.allreduce(norm, op=MPI.SUM) norm = norm ** (1 / order) - return norm + return float(norm) def allreduce_data_max(self, layout=None): return self.allreduce_data_norm(layout=layout, order=np.inf) @@ -907,6 +927,7 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis **kw : dict Other keywords passed to the distribution method. """ + xp = self.dist.array_namespace init_layout = self.layout # Set scales if requested if scales is not None: @@ -926,11 +947,10 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis spatial_slices = self.layout.slices(self.domain, self.scales) local_slices = component_slices + spatial_slices local_data = global_data[local_slices] - if self.is_real: - self.data[:] = local_data - else: - self.data.real[:] = local_data[..., 0] - self.data.imag[:] = local_data[..., 1] + if self.is_complex: + local_data = local_data[..., 0] + 1j * local_data[..., 1] + # Copy to field data + self.data[:] = xp.asarray(local_data, dtype=self.dtype) def low_pass_filter(self, shape=None, scales=None): """ diff --git a/dedalus/core/future.py b/dedalus/core/future.py index 58f9cd9d..ab07e8ff 100644 --- a/dedalus/core/future.py +++ b/dedalus/core/future.py @@ -51,6 +51,7 @@ def __init__(self, *args, out=None): self.original_args = tuple(args) self.out = out self.dist = unify_attributes(args, 'dist', require=False) + self.array_namespace = self.dist.array_namespace #self.domain = Domain(self.dist, self.bases) self._grid_layout = self.dist.grid_layout self._coeff_layout = self.dist.coeff_layout diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index 9a9a993d..d19bd961 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -15,6 +15,7 @@ from math import prod from ..libraries import dedalus_sphere import logging +import array_api_compat logger = logging.getLogger(__name__.split('.')[-1]) from .domain import Domain @@ -378,11 +379,12 @@ def enforce_conditions(self): arg0.require_grid_space() def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args # Multiply in grid layout out.preset_layout(arg0.layout) if out.data.size: - np.power(arg0.data, arg1, out.data) + xp.power(arg0.data, arg1, out.data) def new_operands(self, arg0, arg1, **kw): return Power(arg0, arg1) @@ -498,8 +500,9 @@ def enforce_conditions(self): self.args[i].change_layout(self.layout) def operate(self, out): + xp = self.array_namespace out.preset_layout(self.layout) - np.copyto(out.data, self.func(*self.args, **self.kw)) + xp.copyto(out.data, self.func(*self.args, **self.kw)) class UnaryGridFunction(NonlinearOperator, FutureField): @@ -812,10 +815,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0 = self.args[0] out.preset_layout(arg0.layout) out.lock_to_layouts(self.layouts) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) def new_operand(self, operand, **kw): return Lock(operand, *self.layouts, **kw) @@ -947,6 +951,20 @@ def subspace_matrix(self, layout): # Caching layer to allow insertion of other arguments return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + @CachedMethod + def subspace_matrix_device(self, layout): + """Build matrix operating on local subspace data on device.""" + # Caching layer to allow insertion of other arguments + matrix = self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupy as cp + import cupyx.scipy.sparse as csp + if sparse.issparse(matrix): + matrix = csp.csr_matrix(matrix) + else: + matrix = cp.array(matrix) + return matrix + def group_matrix(self, group): return self._group_matrix(group, self.input_basis, self.output_basis) @@ -954,7 +972,7 @@ def group_matrix(self, group): @CachedMethod def _subspace_matrix(cls, layout, input_basis, output_basis, axis, *args): if cls.subaxis_coupling[0]: - return cls._full_matrix(input_basis, output_basis, *args) + return cls._full_matrix(input_basis, output_basis, *args).astype(layout.dist.dtype) else: input_domain = Domain(layout.dist, bases=[input_basis]) output_domain = Domain(layout.dist, bases=[output_basis]) @@ -968,7 +986,7 @@ def _subspace_matrix(cls, layout, input_basis, output_basis, axis, *args): group_blocks = [cls._group_matrix(group, input_basis, output_basis, *args) for group in groups] arg_size = layout.local_shape(input_domain, scales=1)[axis] out_size = layout.local_shape(output_domain, scales=1)[axis] - return sparse_block_diag(group_blocks, shape=(out_size, arg_size)) + return sparse_block_diag(group_blocks, shape=(out_size, arg_size)).astype(layout.dist.dtype) @staticmethod def _full_matrix(input_basis, output_basis, *args): @@ -987,7 +1005,7 @@ def operate(self, out): # Apply matrix if arg.data.size and out.data.size: data_axis = self.last_axis + len(arg.tensorsig) - apply_matrix(self.subspace_matrix(layout), arg.data, data_axis, out=out.data) + apply_matrix(self.subspace_matrix_device(layout), arg.data, data_axis, out=out.data) else: out.data.fill(0) @@ -1522,9 +1540,10 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) class Convert(SpectralOperator, metaclass=MultiClass): @@ -1624,12 +1643,13 @@ def subspace_matrix(self, layout): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] layout = arg.layout # Copy for grid space if layout.grid_space[self.last_axis]: out.preset_layout(layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) # Revert to matrix application for coeff space else: super().operate(out) @@ -1772,10 +1792,13 @@ def base(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.einsum('ii...', arg.data, out=out.data) - + if array_api_compat.is_cupy_namespace(xp): + out.data[:] = xp.einsum('ii...', arg.data) + else: + xp.einsum('ii...', arg.data, out=out.data) class SphericalTrace(Trace): @@ -1971,6 +1994,7 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace operand = self.args[0] # Set output layout out.preset_layout(operand.layout) @@ -3485,10 +3509,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductDivergence(Divergence): @@ -3534,10 +3559,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalDivergence(Divergence, SphericalEllOperator): @@ -3739,10 +3765,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductCurl(Curl): @@ -3826,10 +3853,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalCurl(Curl, SphericalEllOperator): @@ -4052,10 +4080,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductLaplacian(Laplacian): @@ -4097,10 +4126,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalLaplacian(Laplacian, SphericalEllOperator): diff --git a/dedalus/core/solvers.py b/dedalus/core/solvers.py index c63a5fbb..98776d3b 100644 --- a/dedalus/core/solvers.py +++ b/dedalus/core/solvers.py @@ -67,11 +67,16 @@ def __init__(self, problem, ncc_cutoff=1e-6, max_ncc_terms=None, entry_cutoff=1e self.ncc_cutoff = ncc_cutoff self.max_ncc_terms = max_ncc_terms self.entry_cutoff = entry_cutoff + # Determing matrix coupling if matrix_coupling is None: - matrix_coupling = np.array(problem.matrix_coupling) - # Couple fully separable problems along last axis by default for efficiency - if not np.any(matrix_coupling): - matrix_coupling[-1] = True + # Override with full coupling according to config option + if self.dist.is_cupy_namespace and config['matrix construction'].getboolean('COUPLE_GPU_SUBPROBLEMS'): + matrix_coupling = np.ones_like(problem.matrix_coupling, dtype=bool) + else: + matrix_coupling = np.array(problem.matrix_coupling) + # Couple fully separable problems along last axis by default for efficiency + if not np.any(matrix_coupling): + matrix_coupling[-1] = True else: # Check specified coupling for compatibility problem_coupling = np.array(problem.matrix_coupling) diff --git a/dedalus/core/subsystems.py b/dedalus/core/subsystems.py index 191d63e0..bd0a6c0b 100644 --- a/dedalus/core/subsystems.py +++ b/dedalus/core/subsystems.py @@ -11,13 +11,20 @@ from mpi4py import MPI import uuid from math import prod +import array_api_compat from .domain import Domain -from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv +from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv, copy_to_device, copy_from_device from ..tools.cache import CachedAttribute, CachedMethod from ..tools.general import replace, OrderedSet from ..tools.progress import log_progress +try: + import cupy as cp + import cupyx.scipy.sparse as csp +except ImportError: + pass + import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -118,6 +125,7 @@ def __init__(self, solver, group): self.solver = solver self.problem = problem = solver.problem self.dist = solver.dist + self.array_namespace = solver.dist.array_namespace self.dtype = problem.dtype self.group = group # Determine matrix group using solver matrix dependence @@ -191,11 +199,12 @@ def field_size(self, field): @CachedMethod def _gather_scatter_setup(self, fields): + xp = self.array_namespace # Allocate vector fsizes = tuple(self.field_size(f) for f in fields) fslices = tuple(self.field_slices(f) for f in fields) fshapes = tuple(self.field_shape(f) for f in fields) - data = np.empty(sum(fsizes), dtype=self.dtype) + data = xp.empty(sum(fsizes), dtype=self.dtype) # Make views into data fviews = [] i0 = 0 @@ -248,6 +257,7 @@ def __init__(self, solver, subsystems, group): self.subsystems = subsystems self.group = group self.dist = problem.dist + self.array_namespace = self.dist.array_namespace self.domain = problem.variables[0].domain # HACK self.dtype = problem.dtype # Cross reference from subsystems @@ -279,7 +289,8 @@ def size(self): @CachedAttribute def _compressed_buffer(self): - return np.zeros(self.shape, dtype=self.dtype) + xp = self.array_namespace + return xp.zeros(self.shape, dtype=self.dtype) def coeff_slices(self, domain): return self.subsystems[0].coeff_slices(domain) @@ -300,9 +311,10 @@ def field_size(self, field): return self.subsystems[0].field_size(field) def _build_buffer_views(self, fields): + xp = self.array_namespace # Allocate buffer fsizes = tuple(self.field_size(f) for f in fields) - buffer = np.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) + buffer = xp.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) # Make views into buffer views = [] i0 = 0 @@ -342,7 +354,7 @@ def gather_inputs(self, fields, out=None): # Gather from fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copyto(buffer_view, field_view) # Apply right preconditioner inverse to compress inputs if out is None: out = self._compressed_buffer @@ -354,7 +366,7 @@ def gather_outputs(self, fields, out=None): # Gather from fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copyto(buffer_view, field_view) # Apply left preconditioner to compress outputs if out is None: out = self._compressed_buffer @@ -368,7 +380,7 @@ def scatter_inputs(self, data, fields): # Scatter to fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copyto(field_view, buffer_view) def scatter_outputs(self, data, fields): """Precondition and scatter subproblem data out to output-like field list.""" @@ -377,7 +389,7 @@ def scatter_outputs(self, data, fields): # Scatter to fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copyto(field_view, buffer_view) def inclusion_matrices(self, bases): """List of inclusion matrices.""" @@ -555,24 +567,45 @@ def build_matrices(self, names): left_perm = left_permutation(self, eqns, bc_top=solver.bc_top, interleave_components=solver.interleave_components).tocsr() right_perm = right_permutation(self, vars, tau_left=solver.tau_left, interleave_components=solver.interleave_components).tocsr() - # Preconditioners + # Preconditioners on CPU # TODO: remove astype casting, requires dealing with used types in apply_sparse - self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) - self.pre_left_pinv = self.pre_left.T.tocsr().astype(dtype) - self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) - self.pre_right = self.pre_right_pinv.T.tocsr().astype(dtype) + pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) + pre_left_pinv = pre_left.T.tocsr().astype(dtype) + pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) + pre_right = pre_right_pinv.T.tocsr().astype(dtype) # Check preconditioner pseudoinverses - assert_sparse_pinv(self.pre_left, self.pre_left_pinv) - assert_sparse_pinv(self.pre_right, self.pre_right_pinv) + assert_sparse_pinv(pre_left, pre_left_pinv) + assert_sparse_pinv(pre_right, pre_right_pinv) # Precondition matrices for name in matrices: - matrices[name] = self.pre_left @ matrices[name] @ self.pre_right + matrices[name] = pre_left @ matrices[name] @ pre_right - # Store minimal CSR matrices for fast dot products + # Store minimal CSR matrices on CPU for name, matrix in matrices.items(): - setattr(self, '{:}_min'.format(name), matrix.tocsr()) + setattr(self, f'{name}_min', matrix.tocsr()) + + # Store device copies for fast dot products + xp = solver.dist.array_namespace + if array_api_compat.is_numpy_namespace(xp): + self.pre_left = pre_left + self.pre_left_pinv = pre_left_pinv + self.pre_right_pinv = pre_right_pinv + self.pre_right = pre_right + # Reference current CPU matrices + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', getattr(self, f'{name}_min')) + elif array_api_compat.is_cupy_namespace(xp): + # Copy to device + self.pre_left = csp.csr_matrix(pre_left) + self.pre_left_pinv = csp.csr_matrix(pre_left_pinv) + self.pre_right_pinv = csp.csr_matrix(pre_right_pinv) + self.pre_right = csp.csr_matrix(pre_right) + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', csp.csr_matrix(matrix)) + else: + raise ValueError("Unsupported array namespace: {}".format(xp)) # Store expanded CSR matrices for fast recombination if len(matrices) > 1: diff --git a/dedalus/core/system.py b/dedalus/core/system.py index 23cbb86b..f28cb206 100644 --- a/dedalus/core/system.py +++ b/dedalus/core/system.py @@ -12,45 +12,44 @@ class CoeffSystem: """ - Representation of a collection of fields that don't need to be transformed, - and are therefore stored as a contigous set of coefficient data for - efficient pencil and group manipulation. + Contiguous buffer for data from all subproblems. Parameters ---------- - nfields : int - Number of fields to represent - domain : domain object - Problem domain + subproblems : list of Subproblem objects + Subproblems to represent + dtype : dtype + Data type + array_namespace : array namespace + Array namespace Attributes ---------- data : ndarray - Contiguous buffer for field coefficients - - """ - - """ - var buffer - + Contiguous buffer for data from all subproblems + views : dict + Nested dictionary of views for each subproblem and subsystem """ - def __init__(self, subproblems, dtype): + def __init__(self, subproblems, dtype, array_namespace): + xp = array_namespace # Build buffer total_size = sum(sp.LHS.shape[1]*len(sp.subsystems) for sp in subproblems) - self.data = np.zeros(total_size, dtype=dtype) + self.data = xp.zeros(total_size, dtype=dtype) # Build views i0 = i1 = 0 self.views = views = {} for sp in subproblems: views[sp] = views_sp = {} + # View for each individual subsystem i00 = i0 for ss in sp.subsystems: i1 += sp.LHS.shape[1] views_sp[ss] = self.data[i0:i1] i0 = i1 i11 = i1 + # View combining all subsystems as rows in a matrix if i11 - i00 > 0: views_sp[None] = self.data[i00:i11].reshape((sp.LHS.shape[1], -1)) else: diff --git a/dedalus/core/timesteppers.py b/dedalus/core/timesteppers.py index 81da4c10..0a8c20a9 100644 --- a/dedalus/core/timesteppers.py +++ b/dedalus/core/timesteppers.py @@ -2,10 +2,9 @@ from collections import deque, OrderedDict import numpy as np -from scipy.linalg import blas from .system import CoeffSystem -from ..tools.array import apply_sparse +from ..tools.array import apply_sparse, get_axpy # Public interface @@ -71,7 +70,8 @@ class MultistepIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + self.xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) # Create deque for storing recent timesteps self.dt = deque([0.] * self.steps) @@ -81,16 +81,16 @@ def __init__(self, solver): self.LX = LX = deque() self.F = F = deque() for j in range(self.amax): - MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp)) for j in range(self.bmax): - LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp)) for j in range(self.cmax): - F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp)) # Attributes self._iteration = 0 self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy = get_axpy(self.xp, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -117,7 +117,7 @@ def step(self, dt, wall_time): self.dt[0] = dt # Compute IMEX coefficients - a, b, c = self.compute_coefficients(self.dt, self._iteration) + a, b, c = self.compute_coefficients(self.dt, self._iteration, self.solver.dtype) self._iteration += 1 # Update RHS components and LHS matrices @@ -143,8 +143,8 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Evaluate F(X0) evaluator.evaluate_scheduled(iteration=iteration, wall_time=wall_time, sim_time=sim_time, timestep=dt) @@ -154,7 +154,7 @@ def step(self, dt, wall_time): # Build RHS if RHS.data.size: - np.multiply(c[1], F0.data, out=RHS.data) + self.xp.multiply(c[1], F0.data, out=RHS.data) for j in range(2, len(c)): # RHS.data += c[j] * F[j-1].data axpy(a=c[j], x=F[j-1].data, y=RHS.data) @@ -173,7 +173,7 @@ def step(self, dt, wall_time): if update_LHS: if STORE_EXPANDED_MATRICES: # sp.LHS.data[:] = a0*sp.M_exp.data + b0*sp.L_exp.data - np.multiply(a0, sp.M_exp.data, out=sp.LHS.data) + self.xp.multiply(a0, sp.M_exp.data, out=sp.LHS.data) axpy(a=b0, x=sp.L_exp.data, y=sp.LHS.data) else: sp.LHS = (a0*sp.M_min + b0*sp.L_min) # CREATES TEMPORARY @@ -203,11 +203,11 @@ class CNAB1(MultistepIMEX): steps = 1 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k0, *rest = timesteps @@ -236,11 +236,11 @@ class SBDF1(MultistepIMEX): steps = 1 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k0, *rest = timesteps @@ -268,14 +268,14 @@ class CNAB2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return CNAB1.compute_coefficients(timesteps, iteration) + return CNAB1.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -306,14 +306,14 @@ class MCNAB2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return CNAB1.compute_coefficients(timesteps, iteration) + return CNAB1.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -345,14 +345,14 @@ class SBDF2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return SBDF1.compute_coefficients(timesteps, iteration) + return SBDF1.compute_coefficients(timesteps, iteration, dtype=dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -383,14 +383,14 @@ class CNLF2(MultistepIMEX): steps = 2 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 1: - return CNAB1.compute_coefficients(timesteps, iteration) + return CNAB1.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k1, k0, *rest = timesteps w1 = k1 / k0 @@ -422,14 +422,14 @@ class SBDF3(MultistepIMEX): steps = 3 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 2: - return SBDF2.compute_coefficients(timesteps, iteration) + return SBDF2.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k2, k1, k0, *rest = timesteps w2 = k2 / k1 @@ -463,14 +463,14 @@ class SBDF4(MultistepIMEX): steps = 4 @classmethod - def compute_coefficients(self, timesteps, iteration): + def compute_coefficients(self, timesteps, iteration, dtype): if iteration < 3: - return SBDF3.compute_coefficients(timesteps, iteration) + return SBDF3.compute_coefficients(timesteps, iteration, dtype) - a = np.zeros(self.amax+1) - b = np.zeros(self.bmax+1) - c = np.zeros(self.cmax+1) + a = np.zeros(self.amax+1, dtype=dtype) + b = np.zeros(self.bmax+1, dtype=dtype) + c = np.zeros(self.cmax+1, dtype=dtype) k3, k2, k1, k0, *rest = timesteps w3 = k3 / k2 @@ -539,15 +539,21 @@ class RungeKuttaIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + self.xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) # Create coefficient systems for multistep history - self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype) - self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] - self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] + self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) + self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) for i in range(self.stages)] + self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=self.xp) for i in range(self.stages)] self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy = get_axpy(self.xp, solver.dtype) + + # Cast scheme coefficients + self.A = self.A.astype(self.solver.dtype) + self.H = self.H.astype(self.solver.dtype) + self.c = self.c.astype(self.solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -584,11 +590,12 @@ def step(self, dt, wall_time): # Compute M.X(n,0) and L.X(n,0) # Ensure coeff space before subsystem gathers + # TODO: add option to evaluate this matrix-free (e.g for high-bandwidth NCCs when using fast transforms) evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Compute stages # (M + k Hii L).X(n,i) = M.X(n,0) + k Aij F(n,j) - k Hij L.X(n,j) @@ -601,7 +608,7 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.L_min, spX, axis=0, out=LXi.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LXi.get_subdata(sp)) # Compute F(n,i-1), only doing output on first evaluation if i == 1: @@ -615,7 +622,7 @@ def step(self, dt, wall_time): # Construct RHS(n,i) if RHS.data.size: - np.copyto(RHS.data, MX0.data) + self.xp.copyto(RHS.data, MX0.data) for j in range(0, i): # RHS.data += (k * A[i,j]) * F[j].data axpy(a=(k*A[i,j]), x=F[j].data, y=RHS.data) @@ -632,7 +639,7 @@ def step(self, dt, wall_time): if update_LHS: if STORE_EXPANDED_MATRICES: # sp.LHS.data[:] = sp.M_exp.data + k_Hii*sp.L_exp.data - np.copyto(sp.LHS.data, sp.M_exp.data) + self.xp.copyto(sp.LHS.data, sp.M_exp.data) axpy(a=k_Hii, x=sp.L_exp.data, y=sp.LHS.data) else: sp.LHS = (sp.M_min + k_Hii*sp.L_min) # CREATES TEMPORARY diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 00758fb2..90b11a1d 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -8,13 +8,16 @@ import scipy.fftpack from ..libraries import dedalus_sphere from math import prod +import array_api_compat from . import basis from ..libraries.fftw import fftw_wrappers as fftw from ..tools import jacobi -from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse +from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse, copyto from ..tools.cache import CachedAttribute from ..tools.cache import CachedMethod +from ..tools.general import float_to_complex +from ..tools.linalg_gpu import cupy_solve_upper_csr, CustomCupyUpperTriangularSolver import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -93,19 +96,25 @@ class JacobiTransform(SeparableTransform): Jacobi "a" parameter for the quadrature grid. b0 : int Jacobi "b" parameter for the quadrature grid. + array_namespace : array namespace + Array namespace for the transform. + dtype : dtype + Data type for the transform. Notes ----- TODO: We need to define the normalization we use here. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0, dealias_before_converting=None): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, dealias_before_converting=None): self.N = grid_size self.M = coeff_size self.a = a self.b = b self.a0 = a0 self.b0 = b0 + self.array_namespace = array_namespace + self.dtype = dtype if dealias_before_converting is None: dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING() self.dealias_before_converting = dealias_before_converting @@ -118,6 +127,7 @@ class JacobiMMT(JacobiTransform, SeparableMatrixTransform): @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -141,11 +151,12 @@ def forward_matrix(self): # Truncate to specified coeff_size forward_matrix = forward_matrix[:M, :] # Ensure C ordering for fast dot products - return np.asarray(forward_matrix, order='C') + return xp.asarray(forward_matrix, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -155,11 +166,11 @@ def backward_matrix(self): # Zero higher polynomials for transforms with grid_size < coeff_size polynomials[N:, :] = 0 # Transpose and ensure C ordering for fast dot products - return np.asarray(polynomials.T, order='C') + return xp.asarray(polynomials.T, order='C', dtype=self.dtype) class ComplexFourierTransform(SeparableTransform): - """ + r""" Abstract base class for complex-to-complex Fourier transforms. Parameters @@ -191,19 +202,22 @@ class ComplexFourierTransform(SeparableTransform): If M is even, the ordering is [0, 1, 2, ..., KM, -KM, -KM+1, ..., -1]. """ - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): self.N = grid_size self.M = coeff_size self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace M = self.M KM = self.KM - k = np.arange(M) + k = xp.arange(M) # Wrap around Nyquist mode return (k + KM) % M - KM @@ -215,26 +229,28 @@ class ComplexFourierMMT(ComplexFourierTransform, SeparableMatrixTransform): @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[:, None] - X = np.arange(self.N)[None, :] + X = xp.arange(self.N)[None, :] dX = self.N / 2 / np.pi - quadrature = np.exp(-1j*K*X/dX) / self.N + quadrature = xp.exp(-1j*K*X/dX) / self.N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - quadrature *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + quadrature *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[None, :] - X = np.arange(self.N)[:, None] + X = xp.arange(self.N)[:, None] dX = self.N / 2 / np.pi - functions = np.exp(1j*K*X/dX) + functions = xp.exp(1j*K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - functions *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + functions *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(functions, order='C', dtype=self.dtype) class ComplexFFT(ComplexFourierTransform): @@ -242,29 +258,30 @@ class ComplexFFT(ComplexFourierTransform): def resize_coeffs(self, data_in, data_out, axis, rescale): """Resize and rescale coefficients in standard FFT format by intermediate padding/truncation.""" + xp = self.array_namespace M = self.M Kmax = self.Kmax if Kmax == 0: posfreq = axslice(axis, 0, 1) badfreq = axslice(axis, 1, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 else: posfreq = axslice(axis, 0, Kmax+1) badfreq = axslice(axis, Kmax+1, -Kmax) negfreq = axslice(axis, -Kmax, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 - np.copyto(data_out[negfreq], data_in[negfreq]) + xp.copyto(data_out[negfreq], data_in[negfreq]) else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 - np.multiply(data_in[negfreq], rescale, data_out[negfreq]) + xp.multiply(data_in[negfreq], rescale, data_out[negfreq]) @register_transform(basis.ComplexFourier, 'scipy') @@ -289,6 +306,34 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +@register_transform(basis.ComplexFourier, 'cupy') +class CupyComplexFFT(ComplexFFT): + """Complex-to-complex FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call FFT + temp = self.cufft.fft(gdata, axis=axis) # Creates temporary + # Resize and rescale for unit-amplitude normalization + self.resize_coeffs(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = xp.empty_like(gdata) # Creates temporary + self.resize_coeffs(cdata, temp, axis, rescale=self.N) + # Call FFT + temp = self.cufft.ifft(temp, axis=axis, overwrite_x=True) # Creates temporary + xp.copyto(gdata, temp) + + class FFTWBase: """Abstract base class for FFTW transforms.""" @@ -331,7 +376,7 @@ def backward(self, cdata, gdata, axis): class RealFourierTransform(SeparableTransform): - """ + r""" Abstract base class for real-to-real Fourier transforms. Parameters @@ -368,7 +413,7 @@ class RealFourierTransform(SeparableTransform): where the k = 0 minus-sine mode is zeroed in both directions. """ - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): if coeff_size % 2 != 0: pass#raise ValueError("coeff_size must be even.") self.N = grid_size @@ -376,12 +421,15 @@ def __init__(self, grid_size, coeff_size): self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace # Repeat k's for cos and msin parts - return np.repeat(np.arange(self.KM+1), 2) + return xp.repeat(xp.arange(self.KM+1), 2) @register_transform(basis.RealFourier, 'matrix') @@ -391,37 +439,39 @@ class RealFourierMMT(RealFourierTransform, SeparableMatrixTransform): @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[::2, None] - X = np.arange(N)[None, :] + X = xp.arange(N)[None, :] dX = N / 2 / np.pi - quadrature = np.zeros((M, N)) - quadrature[0::2] = (2 / N) * np.cos(K*X/dX) - quadrature[1::2] = -(2 / N) * np.sin(K*X/dX) + quadrature = xp.zeros((M, N)) + quadrature[0::2] = (2 / N) * xp.cos(K*X/dX) + quadrature[1::2] = -(2 / N) * xp.sin(K*X/dX) quadrature[0] = 1 / N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size quadrature *= self.wavenumbers[:,None] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[None, ::2] - X = np.arange(N)[:, None] + X = xp.arange(N)[:, None] dX = N / 2 / np.pi - functions = np.zeros((N, M)) - functions[:, 0::2] = np.cos(K*X/dX) - functions[:, 1::2] = -np.sin(K*X/dX) + functions = xp.zeros((N, M)) + functions[:, 0::2] = xp.cos(K*X/dX) + functions[:, 1::2] = -xp.sin(K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size functions *= self.wavenumbers[None, :] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + return xp.asarray(functions, order='C', dtype=self.dtype) @register_transform(basis.RealFourier, 'fftpack') @@ -471,40 +521,42 @@ class RealFFT(RealFourierTransform): def unpack_rescale(self, temp, cdata, axis, rescale): """Unpack complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 cos data meancos = axslice(axis, 0, 1) - np.multiply(temp[meancos].real, rescale, cdata[meancos]) + xp.multiply(temp[meancos].real, rescale, cdata[meancos]) # Zero k = 0 msin data cdata[axslice(axis, 1, 2)] = 0 # Unpack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] - np.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) - np.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) + xp.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) + xp.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) # Zero k > Kmax data cdata[axslice(axis, 2*(Kmax+1), None)] = 0 def repack_rescale(self, cdata, temp, axis, rescale): """Repack into complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 data meancos = axslice(axis, 0, 1) if rescale is None: - np.copyto(temp[meancos], cdata[meancos]) + xp.copyto(temp[meancos], cdata[meancos]) else: - np.multiply(cdata[meancos], rescale, temp[meancos]) + xp.multiply(cdata[meancos], rescale, temp[meancos]) # Repack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] if rescale is None: - np.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) else: - np.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) # Zero k > Kmax data temp[axslice(axis, Kmax+1, None)] = 0 @@ -513,6 +565,10 @@ def repack_rescale(self, cdata, temp, axis, rescale): class ScipyRealFFT(RealFFT): """Real-to-real FFT using scipy.fft.""" + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + def forward(self, gdata, cdata, axis): """Apply forward transform along specified axis.""" # Call RFFT @@ -526,7 +582,7 @@ def backward(self, cdata, gdata, axis): # Rescale all modes and combine into complex form shape = list(gdata.shape) shape[axis] = N // 2 + 1 - temp = np.empty(shape=shape, dtype=np.complex128) # Creates temporary + temp = np.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary # Repack into complex form and rescale self.repack_rescale(cdata, temp, axis, rescale=N) # Call IRFFT @@ -534,6 +590,38 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +@register_transform(basis.RealFourier, 'cupy') +class CupyRealFFT(RealFFT): + """Real-to-real FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call RFFT + temp = self.cufft.rfft(gdata, axis=axis) # Creates temporary + # Unpack from complex form and rescale + self.unpack_rescale(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + N = self.N + # Rescale all modes and combine into complex form + shape = list(gdata.shape) + shape[axis] = N // 2 + 1 + temp = xp.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary + # Repack into complex form and rescale + self.repack_rescale(cdata, temp, axis, rescale=N) + # Call IRFFT + temp = self.cufft.irfft(temp, axis=axis, n=N, overwrite_x=True) # Creates temporary + xp.copyto(gdata, temp) + + @register_transform(basis.RealFourier, 'fftw') class FFTWRealFFT(FFTWBase, RealFFT): """Real-to-real FFT using FFTW.""" @@ -630,7 +718,7 @@ def backward(self, cdata, gdata, axis): class CosineTransform(SeparableTransform): - """ + r""" Abstract base class for cosine transforms. Parameters @@ -768,6 +856,33 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +class CupyDCT(FastCosineTransform): + """Fast cosine transform using cupy fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call DCT + temp = self.cufft.dct(gdata, type=2, axis=axis) # Creates temporary + # Resize and rescale for unit-ampltidue normalization + self.resize_rescale_forward(temp, cdata, axis, self.Kmax) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = xp.empty_like(gdata) # Creates temporary + self.resize_rescale_backward(cdata, temp, axis, self.Kmax) + # Call IDCT + temp = self.cufft.dct(temp, type=3, axis=axis, overwrite_x=True) # Creates temporary + copyto(gdata, temp) + + #@register_transform(basis.Cosine, 'fftw') class FFTWDCT(FFTWBase, FastCosineTransform): """Fast cosine transform using FFTW.""" @@ -804,11 +919,11 @@ class FastChebyshevTransform(JacobiTransform): Subclasses should inherit from this class, then a FastCosineTransform subclass. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw): if not a0 == b0 == -1/2: raise ValueError("Fast Chebshev transform requires a0 == b0 == -1/2.") # Jacobi initialization - super().__init__(grid_size, coeff_size, a, b, a0, b0, **kw) + super().__init__(grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw) # DCT initialization to set scaling factors if a != a0 or b != b0: # Modify coeff_size to avoid truncation before conversion @@ -831,15 +946,22 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): else: # Conversion matrices if self.dealias_before_converting and (self.M_orig < self.N): # truncate prior to conversion matrix - self.forward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr() + self.forward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr().astype(dtype) else: # input to conversion matrix not truncated self.forward_conversion = jacobi.conversion_matrix(self.N, a0, b0, a, b) self.forward_conversion.resize(self.M_orig, self.N) - self.forward_conversion = self.forward_conversion.tocsr() - self.backward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr() + self.forward_conversion = self.forward_conversion.tocsr().astype(dtype) + self.backward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b).tocsr().astype(dtype) self.backward_conversion.sum_duplicates() # for faster solve_upper self.resize_rescale_forward = self._resize_rescale_forward_convert self.resize_rescale_backward = self._resize_rescale_backward_convert + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupyx.scipy.sparse as csp + self.forward_conversion = csp.csr_matrix(self.forward_conversion) + self.backward_conversion = csp.csr_matrix(self.backward_conversion) + self.forward_conversion.sum_duplicates() + self.backward_conversion.sum_duplicates() + self.backward_conversion_LU = CustomCupyUpperTriangularSolver(self.backward_conversion) def _resize_rescale_forward(self, data_in, data_out, axis, Kmax): """Resize by padding/trunction and rescale to unit amplitude.""" @@ -881,7 +1003,10 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): # Truncate input before conversion data_in[badfreq] = 0 # Ultraspherical conversion - solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) + if array_api_compat.is_cupy_namespace(self.array_namespace): + cupy_solve_upper_csr(self.backward_conversion_LU, data_in, axis, out=data_in) + else: + solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) # Change sign of odd modes if Kmax_orig > 0: posfreq_odd = axslice(axis, 1, Kmax_orig+1, 2) @@ -890,18 +1015,24 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): super().resize_rescale_backward(data_in, data_out, axis, Kmax_orig) -@register_transform(basis.Jacobi, 'scipy_dct') +@register_transform(basis.Jacobi, 'scipy') class ScipyFastChebyshevTransform(FastChebyshevTransform, ScipyDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance -@register_transform(basis.Jacobi, 'fftw_dct') +@register_transform(basis.Jacobi, 'fftw') class FFTWFastChebyshevTransform(FastChebyshevTransform, FFTWDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance +@register_transform(basis.Jacobi, 'cupy') +class CupyFastChebyshevTransform(FastChebyshevTransform, CupyDCT): + """Fast ultraspherical transform using cupy fft and spectral conversion.""" + pass # Implementation is complete via inheritance + + # class ScipyDST(PolynomialTransform): # def forward_reduced(self): diff --git a/dedalus/dedalus.cfg b/dedalus/dedalus.cfg index edd0595c..04861e6b 100644 --- a/dedalus/dedalus.cfg +++ b/dedalus/dedalus.cfg @@ -31,9 +31,6 @@ [transforms] - # Default transform library (scipy, fftw) - DEFAULT_LIBRARY = fftw - # Transform multiple fields together when possible GROUP_TRANSFORMS = False @@ -71,6 +68,9 @@ [matrix construction] + # Fully couple GPU subproblems + COUPLE_GPU_SUBPROBLEMS = True + # Put BC rows at the top of the matrix BC_TOP = False diff --git a/dedalus/extras/flow_tools.py b/dedalus/extras/flow_tools.py index bc6798ba..b99a5069 100644 --- a/dedalus/extras/flow_tools.py +++ b/dedalus/extras/flow_tools.py @@ -86,7 +86,7 @@ def __init__(self, solver, cadence=1): self.solver = solver self.cadence = cadence - self.reducer = GlobalArrayReducer(solver.dist.comm_cart) + self.reducer = GlobalArrayReducer(solver.dist.comm_cart, solver.dtype) self.properties = solver.evaluator.add_dictionary_handler(iter=cadence) def add_property(self, property, name, precompute_integral=False): @@ -181,7 +181,7 @@ def __init__(self, solver, initial_dt, cadence=1, safety=1., max_dt=np.inf, self.min_change = min_change self.threshold = threshold - self.reducer = GlobalArrayReducer(self.solver.dist.comm_cart) + self.reducer = GlobalArrayReducer(self.solver.dist.comm_cart, solver.dtype) self.frequencies = self.solver.evaluator.add_dictionary_handler(iter=cadence) def compute_dt(self): diff --git a/dedalus/libraries/matsolvers.py b/dedalus/libraries/matsolvers.py index f301d4f2..ede93a1b 100644 --- a/dedalus/libraries/matsolvers.py +++ b/dedalus/libraries/matsolvers.py @@ -5,7 +5,12 @@ import scipy.sparse as sp import scipy.sparse.linalg as spla from functools import partial - +import array_api_compat +try: + import cupyx.scipy.sparse.linalg as cupy_spla + cupy_available = True +except ImportError: + cupy_available = False matsolvers = {} def add_solver(solver): @@ -144,6 +149,21 @@ def __init__(self, matrix, solver=None): relax=self.relax, panel_size=self.panel_size, options=self.options) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + # Avoid cupy splu which requires GPU matrices but transfers them to factorize on CPU + # Run same typecheck as cupy splu + if matrix.dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(self.LU.dtype)) + # Build cupy factorization from scipy factorization of CPU matrices + self.LU = cupy_spla.SuperLU(self.LU) + self.LU.spsm_L_descr = None + self.LU.spsm_U_descr = None + self.solve = self.cupy_solve + + def cupy_solve(self, vector): + from dedalus.tools.linalg_gpu import custom_SuperLU_solve + return custom_SuperLU_solve(self.LU, vector, trans=self.trans) def solve(self, vector): return self.LU.solve(vector, trans=self.trans) @@ -225,6 +245,9 @@ class SparseInverse(SparseSolver): def __init__(self, matrix, solver=None): self.matrix_inverse = spla.inv(matrix.tocsc()) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + self.matrix_inverse = cupy_spla.inv(matrix.tocsc()) def solve(self, vector): return self.matrix_inverse @ vector diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index ab9caf88..e137f75d 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -5,7 +5,10 @@ import scipy.sparse as sp from scipy.sparse import _sparsetools from scipy.sparse import linalg as spla +from scipy.linalg import blas from math import prod +from ..tools import linalg_gpu +import array_api_compat from .config import config from . import linalg as cython_linalg @@ -76,10 +79,20 @@ def expand_pattern(input, pattern): def apply_matrix(matrix, array, axis, **kw): """Apply matrix along any axis of an array.""" - if sparse.isspmatrix(matrix): - return apply_sparse(matrix, array, axis, **kw) + xp = array_api_compat.array_namespace(array) + if array_api_compat.is_numpy_namespace(xp): + if sparse.issparse(matrix): + return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) + elif array_api_compat.is_cupy_namespace(xp): + import cupyx.scipy.sparse as csp + if csp.issparse(matrix): + return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) else: - return apply_dense(matrix, array, axis, **kw) + raise ValueError("Unsupported array type") def apply_dense_einsum(matrix, array, axis, optimize=True, **kw): @@ -173,14 +186,14 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= Apply sparse matrix along any axis of an array. Must be out of place if ouptut is specified. """ - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") + xp = array_api_compat.array_namespace(array) + matrix.sum_duplicates() + matrix.has_canonical_format = True # Check output if out is None: out_shape = list(array.shape) out_shape[axis] = matrix.shape[0] - out = np.empty(out_shape, dtype=array.dtype) + out = xp.empty(out_shape, dtype=array.dtype) elif out is array: raise ValueError("Cannot apply in place") # Check shapes @@ -189,17 +202,27 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= raise ValueError("Axis out of bounds.") if matrix.shape[1] != array.shape[axis] or matrix.shape[0] != out.shape[axis]: raise ValueError("Matrix shape mismatch.") - # Old way if requested - if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: - out.fill(0) - return csr_matvecs(matrix, array, out) - # Promote datatypes - # TODO: find way to optimize this with fused types - matrix_data = matrix.data - if matrix_data.dtype != out.dtype: - matrix_data = matrix_data.astype(out.dtype) - # Call cython routine - cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + # Old way if requested + if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: + out.fill(0) + return csr_matvecs(matrix, array, out) + # Promote datatypes + # TODO: find way to optimize this with fused types + matrix_data = matrix.data + if matrix_data.dtype != out.dtype: + matrix_data = matrix_data.astype(out.dtype) + # Call cython routine + cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + # TODO: check matrix format here without import cupy + linalg_gpu.cupy_apply_csr(matrix, array, axis, out) + else: + raise ValueError("Unsupported array type") return out @@ -208,28 +231,40 @@ def solve_upper_sparse(matrix, rhs, axis, out=None, check_shapes=False, num_thre Solve upper triangular sparse matrix along any axis of an array. Matrix assumed to be nonzero on the diagonals. """ - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") - if not matrix._has_canonical_format: # avoid property hook (without underscore) - matrix.sum_duplicates() - # Setup output = rhs + xp = array_api_compat.array_namespace(rhs) + matrix.sum_duplicates() + matrix.has_canonical_format = True + # Check output if out is None: - out = np.copy(rhs) - elif out is not rhs: - np.copyto(out, rhs) - # Promote datatypes - matrix_data = matrix.data - if matrix_data.dtype != rhs.dtype: - matrix_data = matrix_data.astype(rhs.dtype) - # Check shapes - if check_shapes: - if not (0 <= axis < rhs.ndim): - raise ValueError("Axis out of bounds.") - if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): - raise ValueError("Matrix shape mismatch.") - # Call cython routine - cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + out = xp.empty_like(rhs) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + # Setup output = rhs + copyto(out, rhs) + # Promote datatypes + matrix_data = matrix.data + if matrix_data.dtype != rhs.dtype: + matrix_data = matrix_data.astype(rhs.dtype) + # Check shapes + if check_shapes: + if not (0 <= axis < rhs.ndim): + raise ValueError("Axis out of bounds.") + if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): + raise ValueError("Matrix shape mismatch.") + # Call cython routine + cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + linalg_gpu.cupy_solve_upper_csr(matrix, rhs, axis, out) + else: + raise ValueError("Unsupported array type") + return out def csr_matvec(A_csr, x_vec, out_vec): @@ -353,6 +388,22 @@ def copyto(dest, src): dest[:] = src +def copy_to_device(dest, src): + xp = array_api_compat.array_namespace(dest) + if array_api_compat.is_cupy_namespace(xp): + src = xp.asarray(src) + dest[:] = src + else: + dest[:] = src + + +def copy_from_device(dest, src): + if array_api_compat.is_cupy_array(src): + src.get(out=dest) + else: + dest[:] = src + + def perm_matrix(perm, M=None, source_index=False, sparse=True): """ Build sparse permutation matrix from permutation vector. @@ -474,3 +525,12 @@ def assert_sparse_pinv(A, B): if not sparse_allclose((B @ A).conj().T, B @ A): raise AssertionError("Not a pseudoinverse") + +def get_axpy(array_namespace, dtype): + if array_api_compat.is_numpy_namespace(array_namespace): + return blas.get_blas_funcs('axpy', dtype=dtype) + elif array_api_compat.is_cupy_namespace(array_namespace): + from cupy.cublas import axpy as cublas_axpy + return cublas_axpy + else: + raise ValueError("Unsupported array namespace") diff --git a/dedalus/tools/general.py b/dedalus/tools/general.py index 18eb5ee4..9b8b5746 100644 --- a/dedalus/tools/general.py +++ b/dedalus/tools/general.py @@ -124,3 +124,15 @@ def is_complex_dtype(dtype): dtype = dtype.type return np.iscomplexobj(dtype()) + +def float_to_complex(dtype): + itemsize = np.dtype(dtype).itemsize + complex_dtype = np.dtype(f'complex{16*itemsize}') + return complex_dtype.type + + +def complex_to_float(dtype): + itemsize = np.dtype(dtype).itemsize + float_dtype = np.dtype(f'float{4*itemsize}') + return float_dtype.type + diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py new file mode 100644 index 00000000..4e1c3a33 --- /dev/null +++ b/dedalus/tools/linalg_gpu.py @@ -0,0 +1,549 @@ +"""Linear algebra routines using cupy.""" + +import numpy as np +import math +try: + import cupy as cp + import cupyx.scipy.sparse as csp + import cupyx.scipy.sparse.linalg as cupy_spla + from cupyx import jit + cupy_available = True +except ImportError: + # Mock jit so module can still be imported without cupy + class jit: + @staticmethod + def rawkernel(): + def decorator(func): + return func + return decorator + cupy_available = False + + +def cupy_apply_csr(matrix, array, axis, out): + """Apply CSR matrix to arbitrary axis of array.""" + if not cupy_available: + raise ImportError("cupy must be installed to use GPU linear algebra") + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + # TODO: avoid this explicit conversion + print('WARNING: converting matrix to CSR format') + matrix = csp.csr_matrix(matrix) + #raise ValueError("Matrix must be in CSR format.") + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + out[:] = matrix.dot(array) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + out[:,0] = matrix.dot(array[:,0]) + else: + out[:] = matrix.dot(array) + elif axis == 1: + if array.shape[0] == 1: + out[0,:] = matrix.dot(array[0,:]) + else: + out[:] = matrix.dot(array.T).T + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = array.shape[0] + N2 = array.shape[1] + N3 = array.shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + temp = matrix.dot(x1) + out[:] = temp.reshape(out.shape) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + temp = matrix.dot(x2) + out[:] = temp.reshape(out.shape) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + temp = matrix.dot(x2.T).T + out[:] = temp.reshape(out.shape) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape(((N1, matrix.shape[0], N3))) + cupy_apply_csr_mid(matrix, x3, y3) + + +@jit.rawkernel() +def apply_csr_mid_kernel(data, indices, indptr, x3, y3, N1, N2i, N2o, N3): + n1 = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x # batch index + n3 = jit.blockIdx.y * jit.blockDim.y + jit.threadIdx.y # output column index + if n1 >= N1 or n3 >= N3: + return + # Loop over output rows = CSR matrix rows + for i in range(N2o): + acc = 0 * y3[n1, i, n3] # get right type + start = indptr[i] + end = indptr[i + 1] + for k in range(start, end): + j = indices[k] + acc += data[k] * x3[n1, j, n3] + y3[n1, i, n3] = acc + +def cupy_apply_csr_mid(matrix, array, out): + N1, N2i, N3 = array.shape + N2o = matrix.shape[0] + N1 = cp.uint32(N1) + N2i = cp.uint32(N2i) + N3 = cp.uint32(N3) + N2o = cp.uint32(N2o) + # Choose thread/block config + threads_y = min(1024, N3) # maximize concurrency along n3 + threads_x = 1024 // threads_y # make block have 1024 threads + blockdim = (threads_x, threads_y) + blocks_x = (N1 + threads_x - 1) // threads_x + blocks_y = (N3 + threads_y - 1) // threads_y + griddim = (blocks_x, blocks_y) + # Launch kernel + apply_csr_mid_kernel(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) + + +def custom_spsm(a, b, alpha=1.0, lower=True, unit_diag=False, transa=False, spsm_descr=None): + """Custom spsm wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves a sparse triangular linear system op(a) * x = alpha * op(b). + + Args: + a (cupyx.scipy.sparse.csr_matrix or cupyx.scipy.sparse.coo_matrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): Dense matrix with dimension ``(M, K)``. + alpha (float or complex): Coefficient. + lower (bool): + True: ``a`` is lower triangle matrix. + False: ``a`` is upper triangle matrix. + unit_diag (bool): + True: diagonal part of ``a`` has unit elements. + False: diagonal part of ``a`` has non-unit elements. + transa (bool or str): True, False, 'N', 'T' or 'H'. + 'N' or False: op(a) == ``a``. + 'T' or True: op(a) == ``a.T``. + 'H': op(a) == ``a.conj().T``. + """ + import cupyx + from cupyx import cusparse + import cupy as _cupy + import numpy as _numpy + from cupy._core import _dtype + from cupy_backends.cuda.libs import cusparse as _cusparse + from cupy.cuda import device as _device + from cupyx.cusparse import SpMatDescriptor, DnMatDescriptor + if not cusparse.check_availability('spsm'): + raise RuntimeError('spsm is not available.') + + # Canonicalise transa + if transa is False: + transa = 'N' + elif transa is True: + transa = 'T' + elif transa not in 'NTH': + raise ValueError(f'Unknown transa (actual: {transa})') + + # Check A's type and sparse format + if cupyx.scipy.sparse.isspmatrix_csr(a): + pass + elif cupyx.scipy.sparse.isspmatrix_csc(a): + if transa == 'N': + a = a.T + transa = 'T' + elif transa == 'T': + a = a.T + transa = 'N' + elif transa == 'H': + a = a.conj().T + transa = 'N' + lower = not lower + elif cupyx.scipy.sparse.isspmatrix_coo(a): + pass + else: + raise ValueError('a must be CSR, CSC or COO sparse matrix') + assert a.has_canonical_format + + # Check B's ndim + if b.ndim == 1: + is_b_vector = True + b = b.reshape(-1, 1) + elif b.ndim == 2: + is_b_vector = False + else: + raise ValueError('b.ndim must be 1 or 2') + + # Check shapes + if not (a.shape[0] == a.shape[1] == b.shape[0]): + raise ValueError('mismatched shape') + + # Check dtypes + dtype = a.dtype + if dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(dtype)) + if dtype != b.dtype: + raise TypeError('dtype mismatch') + + # Prepare fill mode + if lower is True: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_LOWER + elif lower is False: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_UPPER + else: + raise ValueError('Unknown lower (actual: {})'.format(lower)) + + # Prepare diag type + if unit_diag is False: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_NON_UNIT + elif unit_diag is True: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_UNIT + else: + raise ValueError('Unknown unit_diag (actual: {})'.format(unit_diag)) + + # Prepare op_a + if transa == 'N': + op_a = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif transa == 'T': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: # transa == 'H' + if dtype.char in 'fd': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + op_a = _cusparse.CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE + + # Prepare op_b + if b._f_contiguous: + op_b = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif b._c_contiguous: + if _cusparse.get_build_version() < 11701: # earlier than CUDA 11.6 + raise ValueError('b must be F-contiguous.') + b = b.T + op_b = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + raise ValueError('b must be F-contiguous or C-contiguous.') + + # Allocate space for matrix C. Note that it is known cusparseSpSM requires + # the output matrix zero initialized. + m, _ = a.shape + if op_b == _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE: + _, n = b.shape + else: + n, _ = b.shape + c_shape = m, n + c = _cupy.zeros(c_shape, dtype=a.dtype, order='f') + + # Prepare descriptors and other parameters + handle = _device.get_cusparse_handle() + mat_a = SpMatDescriptor.create(a) + mat_b = DnMatDescriptor.create(b) + mat_c = DnMatDescriptor.create(c) + if spsm_descr is None: + spsm_descr = _cusparse.spSM_createDescr() + new_spsm_descr = True + else: + spsm_descr, buff = spsm_descr + new_spsm_descr = False + alpha = _numpy.array(alpha, dtype=c.dtype).ctypes + cuda_dtype = _dtype.to_cuda_dtype(c.dtype) + algo = _cusparse.CUSPARSE_SPSM_ALG_DEFAULT + + try: + # Specify Lower|Upper fill mode + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_FILL_MODE, fill_mode) + + # Specify Unit|Non-Unit diagonal type + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_DIAG_TYPE, diag_type) + + # Allocate the workspace needed by the succeeding phases + # Always calculate workspace (buff_size can change even for same spsm + # descriptor) + buff_size = _cusparse.spSM_bufferSize( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr) + + need_analysis = new_spsm_descr + if new_spsm_descr: + buff = _cupy.empty(buff_size, dtype=_cupy.int8) + else: + # Check if buff size grew from that in the cache + if buff is None or buff.size < buff_size: + buff = _cupy.empty(buff_size, dtype=_cupy.int8) + # buff changed so need the analysis phase + need_analysis = True + + # Perform the analysis phase + if need_analysis: + _cusparse.spSM_analysis( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Executes the solve phase + _cusparse.spSM_solve( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Reshape back if B was a vector + if is_b_vector: + c = c.reshape(-1) + + return c, (spsm_descr, buff) + + finally: + # Destroy matrix/vector descriptors + #_cusparse.spSM_destroyDescr(spsm_descr) + pass + + +def custom_SuperLU_solve(self, rhs, trans='N', spsm_descr=None): + """Custom SuperLU solve wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves linear system of equations with one or several right-hand sides. + + Args: + rhs (cupy.ndarray): Right-hand side(s) of equation with dimension + ``(M)`` or ``(M, K)``. + trans (str): 'N', 'T' or 'H'. + 'N': Solves ``A * x = rhs``. + 'T': Solves ``A.T * x = rhs``. + 'H': Solves ``A.conj().T * x = rhs``. + + Returns: + cupy.ndarray: + Solution vector(s) + """ # NOQA + from cupyx import cusparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + if not isinstance(rhs, cupy.ndarray): + raise TypeError('ojb must be cupy.ndarray') + if rhs.ndim not in (1, 2): + raise ValueError('rhs.ndim must be 1 or 2 (actual: {})'. + format(rhs.ndim)) + if rhs.shape[0] != self.shape[0]: + raise ValueError('shape mismatch (self.shape: {}, rhs.shape: {})' + .format(self.shape, rhs.shape)) + if trans not in ('N', 'T', 'H'): + raise ValueError('trans must be \'N\', \'T\', or \'H\'') + + if cusparse.check_availability('spsm') and _should_use_spsm(rhs): + def spsm(A, B, lower, transa, spsm_descr): + return custom_spsm(A, B, lower=lower, transa=transa, spsm_descr=spsm_descr) + sm = spsm + else: + raise NotImplementedError + + x = rhs.astype(self.L.dtype) + if trans == 'N': + if self.perm_r is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_r_rev].T # want to keep f-order + else: + x = x[self._perm_r_rev] + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + if self.perm_c is not None: + x = x[self.perm_c] + else: + if self.perm_c is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_c_rev].T # want to keep f-order + else: + x = x[self._perm_c_rev] + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + if self.perm_r is not None: + x = x[self.perm_r] + + if not x._f_contiguous: + # For compatibility with SciPy + x = x.copy(order='F') + return x + + +class CustomCupyUpperTriangularSolver: + """Hacky class to save spsm_descr for reuse in spsm for triangular solves.""" + + def __init__(self, matrix): + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + # TODO: avoid this explicit conversion + matrix = csp.csr_matrix(matrix) + print('WARNING: converting matrix to CSR format') + #raise ValueError("Matrix must be in CSR format.") + self.matrix = matrix + self.spsm_descr = None + + def solve(self, b, lower=True, overwrite_A=False, overwrite_b=False, + unit_diagonal=False): + """Solves a sparse triangular system ``A x = b``. + + Args: + A (cupyx.scipy.sparse.spmatrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): + Dense vector or matrix with dimension ``(M)`` or ``(M, K)``. + lower (bool): + Whether ``A`` is a lower or upper triangular matrix. + If True, it is lower triangular, otherwise, upper triangular. + overwrite_A (bool): + (not supported) + overwrite_b (bool): + Allows overwriting data in ``b``. + unit_diagonal (bool): + If True, diagonal elements of ``A`` are assumed to be 1 and will + not be referenced. + + Returns: + cupy.ndarray: + Solution to the system ``A x = b``. The shape is the same as ``b``. + """ + from cupyx import cusparse + from cupyx.scipy import sparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + A = self.matrix + + if not (cusparse.check_availability('spsm') or + cusparse.check_availability('csrsm2')): + raise NotImplementedError + + if not sparse.isspmatrix(A): + raise TypeError('A must be cupyx.scipy.sparse.spmatrix') + if not isinstance(b, cupy.ndarray): + raise TypeError('b must be cupy.ndarray') + if A.shape[0] != A.shape[1]: + raise ValueError(f'A must be a square matrix (A.shape: {A.shape})') + if b.ndim not in [1, 2]: + raise ValueError(f'b must be 1D or 2D array (b.shape: {b.shape})') + if A.shape[0] != b.shape[0]: + raise ValueError('The size of dimensions of A must be equal to the ' + 'size of the first dimension of b ' + f'(A.shape: {A.shape}, b.shape: {b.shape})') + if A.dtype.char not in 'fdFD': + raise TypeError(f'unsupported dtype (actual: {A.dtype})') + + if cusparse.check_availability('spsm') and _should_use_spsm(b): + if not (sparse.isspmatrix_csr(A) or + sparse.isspmatrix_csc(A) or + sparse.isspmatrix_coo(A)): + warnings.warn('CSR, CSC or COO format is required. Converting to ' + 'CSR format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + x, self.spsm_descr = custom_spsm(A, b, lower=lower, unit_diag=unit_diagonal, spsm_descr=self.spsm_descr) + elif cusparse.check_availability('csrsm2'): + if not (sparse.isspmatrix_csr(A) or sparse.isspmatrix_csc(A)): + warnings.warn('CSR or CSC format is required. Converting to CSR ' + 'format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + + if (overwrite_b and A.dtype == b.dtype and + (b._c_contiguous or b._f_contiguous)): + x = b + else: + x = b.astype(A.dtype, copy=True) + + cusparse.csrsm2(A, x, lower=lower, unit_diag=unit_diagonal) + else: + assert False + + # TODO: Check if need this (breaks things for float32?) + # if x.dtype.char in 'fF': + # # Note: This is for compatibility with SciPy. + # dtype = numpy.promote_types(x.dtype, 'float64') + # x = x.astype(dtype) + return x + + +def cupy_solve_upper_csr(matrix, array, axis, out): + """Solve upper triangular CSR matrix along specified axis of an array.""" + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + cupy_solve_upper_csr_vec(matrix, array, out) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + cupy_solve_upper_csr_vec(matrix, array[:,0], out[:,0]) + else: + cupy_solve_upper_csr_first(matrix, array, out) + elif axis == 1: + if array.shape[0] == 1: + cupy_solve_upper_csr_vec(matrix, array[0,:], out[0,:]) + else: + cupy_solve_upper_csr_last(matrix, array, out) + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = shape[0] + N2 = shape[1] + N3 = shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + y1 = out.reshape((N2,)) + cupy_solve_upper_csr_vec(matrix, x1, y1) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + y2 = out.reshape((N2, N3)) + cupy_solve_upper_csr_first(matrix, x2, y2) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + y2 = out.reshape((N1, N2)) + cupy_solve_upper_csr_last(matrix, x2, y2) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape((N1, N2, N3)) + cupy_solve_upper_csr_mid(matrix, x3, y3) + + +def cupy_solve_upper_csr_vec(matrix, vec, out): + """Solve upper triangular CSR matrix along a vector.""" + out[:] = matrix.solve(vec, lower=False) + + +def cupy_solve_upper_csr_first(matrix, array, out): + """Solve upper triangular CSR matrix along first axis of 2D array.""" + out[:] = matrix.solve(array, lower=False) + + +def cupy_solve_upper_csr_last(matrix, array, out): + """Solve upper triangular CSR matrix along last axis of 2D array.""" + out.T[:] = matrix.solve(array.T, lower=False) + + +def cupy_solve_upper_csr_mid(matrix, array, out): + """Solve upper triangular CSR matrix along middle axis of 3D array.""" + raise NotImplementedError diff --git a/setup.py b/setup.py index 1cf1d9dd..583009a6 100644 --- a/setup.py +++ b/setup.py @@ -181,6 +181,7 @@ def read(rel_path): # Runtime requirements install_requires = [ + "array-api-compat", "docopt", "h5py >= 3.0.0", "matplotlib",