# Unroll factor handed to ``lax.scan`` inside ``_reduce_along_scan``. A larger
# value makes the compiled loop body bigger in exchange for fewer kernel
# launches. A fixed 16 is used rather than ``True`` so the behaviour does not
# depend on how a future JAX release interprets ``unroll=True``.
_SCAN_UNROLL = 16


def _reduce_along_scan(power, coeffs, op, axis):
    """Fold ``op`` over ``axis`` with ``lax.scan`` and return only the final carry.

    Intended as a drop-in for ``lax.associative_scan(op, ..., axis=axis)``
    followed by ``take(..., -1, axis=axis)``: a single ``(power, coeffs)``
    carry is threaded through the scan, so the full prefix tensor is never
    materialised (O(1) extra memory along the scan axis instead of O(N)).

    Args:
        power: Power array; ``axis`` indexes the elements being combined.
        coeffs: Coefficient array, reduced in lockstep with ``power``.
        op: Binary combiner called as ``op(carry, x)`` where both arguments
            are ``(power, coeffs)`` pairs; it must return a pair of the same
            structure.
        axis: Axis to reduce over. Negative values count from the end of
            ``power``.

    Returns:
        ``(power, coeffs)`` with ``axis`` removed, after repeatedly applying
        ``_reduce_power_coeffs_step`` until the power stops changing.
    """
    if axis < 0:
        axis += power.ndim

    # lax.scan walks the leading axis, so bring the reduction axis to the front.
    fronted = (jnp.moveaxis(power, axis, 0), jnp.moveaxis(coeffs, axis, 0))
    seed = tuple(arr[0] for arr in fronted)
    remainder = tuple(arr[1:] for arr in fronted)

    folded, _ = lax.scan(
        lambda acc, item: (op(acc, item), None),
        seed,
        remainder,
        unroll=_SCAN_UNROLL,
    )
    scanned_power, scanned_coeffs = folded

    # Final fixpoint pass. ``_scalar_add_with_power`` and
    # ``_scalar_mul_with_power`` each apply only one
    # ``_reduce_power_coeffs_step`` per call, so a sequential scan over N
    # elements can lag canonical form (gcd of coeffs is odd) by up to
    # ``log2(N)`` reductions. ``lax.while_loop`` keeps reducing only for as
    # long as some power still changes.
    def keep_reducing(loop_state):
        return loop_state[2]

    def reduce_once(loop_state):
        cur_power, cur_coeffs, _ = loop_state
        next_power, next_coeffs = _reduce_power_coeffs_step(cur_power, cur_coeffs)
        return next_power, next_coeffs, jnp.any(next_power != cur_power)

    canon_power, canon_coeffs, _ = lax.while_loop(
        keep_reducing,
        reduce_once,
        (scanned_power, scanned_coeffs, jnp.bool_(True)),
    )
    return canon_power, canon_coeffs
"""CUDA-runtime helpers used by the sampler hot path.

Lazy-imports ``cuda.bindings``. When the import succeeds, host-pinned
allocations + a direct ``cudaMemcpy`` replace numpy's default pageable d2h:
the pageable path forces the driver to stage the transfer through internal
pinned scratch before copying into the user buffer (host-DRAM-bandwidth
bound), while a pinned destination skips the staging hop and lets d2h reach
PCIe line rate. When the import fails, or the source is not a single-device
CUDA buffer, ``copy_d2h`` transparently falls back to a plain numpy copy.
"""

import contextlib
import ctypes
from typing import TYPE_CHECKING, Any, Optional

import numpy as np

if TYPE_CHECKING:
    # Static analysers see an untyped placeholder instead of the optional
    # dependency.
    cudart: Any = None
    _CUDA_BINDINGS_AVAILABLE = False
else:
    try:
        from cuda.bindings import runtime as cudart

        _CUDA_BINDINGS_AVAILABLE = True
    except Exception:
        # Deliberately broad: any failure while importing the optional
        # dependency simply selects the numpy fallback.
        cudart = None
        _CUDA_BINDINGS_AVAILABLE = False


def _src_devices(src) -> list:
    """Return the devices reported by ``src.devices`` (method or attribute).

    Objects without a ``devices`` member yield an empty list.
    """
    devices = getattr(src, "devices", None)
    if devices is None:
        return []
    if callable(devices):
        # Checked with callable() rather than try/except TypeError so a
        # TypeError raised *inside* devices() is not silently swallowed.
        devices = devices()
    return list(devices)


def _src_on_host(src) -> bool:
    """Return True iff ``src`` lives in host memory.

    That is the case for numpy arrays and for anything reporting a device
    whose ``platform`` is ``"cpu"``. Such a buffer pointer would be a host
    address and unsafe to feed to ``cudaMemcpy(..., DeviceToHost)``.
    """
    if isinstance(src, np.ndarray):
        return True
    return any(
        str(getattr(d, "platform", "") or "").lower() == "cpu"
        for d in _src_devices(src)
    )


def _device_copy_supported(src) -> bool:
    """Return True iff ``src`` can be read with a raw device-to-host memcpy."""
    if not _CUDA_BINDINGS_AVAILABLE or _src_on_host(src):
        return False
    if not callable(getattr(src, "unsafe_buffer_pointer", None)):
        return False
    # A single raw pointer cannot describe an array spread over several
    # devices; those go through the generic conversion instead.
    return len(_src_devices(src)) <= 1


def _dst_view(dst: np.ndarray, shape: tuple, dtype: np.dtype, nbytes: int) -> np.ndarray:
    """Validate ``dst`` and return a ``shape``/``dtype`` view of its first bytes.

    Raises:
        TypeError: if ``dst`` is not an ndarray.
        ValueError: if ``dst`` is read-only, not C-contiguous, or smaller
            than ``nbytes``.
    """
    if not isinstance(dst, np.ndarray):
        raise TypeError(f"dst must be a numpy.ndarray, got {type(dst).__name__}")
    if not dst.flags.writeable:
        raise ValueError("dst must be writable")
    if not dst.flags.c_contiguous:
        raise ValueError("dst must be C-contiguous")
    if dst.nbytes < nbytes:
        raise ValueError(
            f"dst holds {dst.nbytes} bytes but the source needs {nbytes}"
        )
    if dst.shape == shape and dst.dtype == dtype:
        return dst
    # Reinterpret the leading ``nbytes`` bytes; the view starts at the same
    # address as ``dst`` and shares its memory.
    return dst.reshape(-1).view(np.uint8)[:nbytes].view(dtype).reshape(shape)


class _PinnedBuf:
    """Owner of a ``cudaHostAlloc``'d region.

    The Python instance owns the lifetime; ``cudaFreeHost`` runs in
    ``__del__`` when no references remain (typically when the wrapping
    numpy view is garbage-collected).
    """

    __slots__ = ("nbytes", "ptr")

    def __init__(self, nbytes: int):
        # Initialise the slots before anything can raise so that __del__ is
        # always safe to run on a half-constructed instance.
        self.ptr = 0
        self.nbytes = nbytes
        err, ptr = cudart.cudaHostAlloc(nbytes, cudart.cudaHostAllocDefault)
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"cudaHostAlloc({nbytes}) failed: {err}")
        self.ptr = int(ptr)

    def __del__(self):
        ptr = getattr(self, "ptr", 0)
        if ptr:
            # cudart may be torn down at interpreter exit.
            with contextlib.suppress(Exception):
                cudart.cudaFreeHost(ptr)
            self.ptr = 0


def alloc_pinned_numpy(nbytes: int, dtype, shape) -> np.ndarray:
    """Allocate a pinned host region and return it as an ndarray view.

    The returned array's ``base`` chain keeps the underlying ``_PinnedBuf``
    alive until the array and all derived views are dropped; only then does
    ``cudaFreeHost`` run.

    Args:
        nbytes: Size of the underlying allocation in bytes. Must be at least
            ``prod(shape) * dtype.itemsize``; surplus bytes are allocated
            but not exposed through the returned view.
        dtype: numpy-compatible dtype for the returned view.
        shape: Shape of the returned view. If it contains a negative
            placeholder the whole allocation is reshaped and numpy infers
            the missing extent.

    Returns:
        ndarray of the requested shape and dtype, backed by pinned memory
        (a zero-byte request returns an ordinary empty array, since there
        is nothing to pin).

    Raises:
        RuntimeError: if cuda.bindings is unavailable, or the underlying
            ``cudaHostAlloc`` fails.
        ValueError: if ``nbytes`` is too small for ``shape`` and ``dtype``.

    """
    if not _CUDA_BINDINGS_AVAILABLE:
        raise RuntimeError(
            "cuda.bindings not importable; install 'cuda-bindings' or use "
            "copy_d2h() for a transparent fallback."
        )
    np_dtype = np.dtype(dtype)
    dims = np.atleast_1d(np.asarray(shape, dtype=np.int64))
    if (dims < 0).any():
        used = nbytes
    else:
        used = int(dims.prod()) * np_dtype.itemsize
        if nbytes < used:
            raise ValueError(
                f"nbytes={nbytes} is too small for shape {shape!r} of "
                f"dtype {np_dtype} ({used} bytes needed)"
            )
    if nbytes == 0:
        # Avoid a zero-sized cudaHostAlloc and a view over a null pointer.
        return np.empty(0, dtype=np.uint8).view(np_dtype).reshape(shape)
    buf = _PinnedBuf(nbytes)
    carr = (ctypes.c_uint8 * nbytes).from_address(buf.ptr)
    # arr.base -> carr; carr._owner -> buf, so buf stays alive with the array.
    carr._owner = buf  # type: ignore[attr-defined]
    raw = np.frombuffer(carr, dtype=np.uint8)
    return raw[:used].view(np_dtype).reshape(shape)


def copy_d2h(src, *, dst: Optional[np.ndarray] = None) -> np.ndarray:
    """Device-to-host copy, pinned-destination fast path when available.

    Args:
        src: Array-like exposing ``shape`` and ``dtype``. The fast path
            additionally needs a single-device contiguous buffer reachable
            through ``unsafe_buffer_pointer()``; the caller must sync to the
            source's stream before invocation
            (``jax.block_until_ready(src)`` for a jax.Array).
        dst: Optional pre-allocated, writable, C-contiguous ndarray to write
            into (pinned for best throughput). Must have at least as many
            bytes as ``src``. It is honoured on every path.

    Returns:
        ndarray with the same shape and dtype as ``src``. When ``dst`` is
        given the result shares its memory, and is ``dst`` itself if shape
        and dtype already match.

    Raises:
        TypeError: if ``dst`` is not an ndarray.
        ValueError: if ``dst`` is read-only, non-contiguous or too small.
        RuntimeError: if the pinned allocation or ``cudaMemcpy`` fails.

    """
    shape = tuple(src.shape)
    dtype = np.dtype(src.dtype)
    nbytes = int(np.prod(shape, dtype=np.int64)) * dtype.itemsize

    # Validate the destination up front: the fast path hands a raw pointer
    # to cudaMemcpy, so an undersized buffer would overrun host memory.
    out = None if dst is None else _dst_view(dst, shape, dtype, nbytes)

    if nbytes == 0 or not _device_copy_supported(src):
        # Fallback path. Taken when cuda.bindings isn't importable, when
        # src lives on the host (its ``unsafe_buffer_pointer`` would be a
        # host pointer, so ``cudaMemcpy(..., DeviceToHost)`` would fail),
        # when src offers no single raw device pointer, or when there is
        # nothing to copy. ``np.asarray(jax.Array)`` returns a read-only
        # zero-copy view in newer JAX, so copy into a writable buffer.
        if out is None:
            out = np.empty(shape, dtype=dtype)
        # ``[...]`` rather than ``[:]`` so 0-d sources work too.
        out[...] = src
        return out

    if out is None:
        out = alloc_pinned_numpy(nbytes, dtype, shape)
    err = cudart.cudaMemcpy(
        out.ctypes.data,
        src.unsafe_buffer_pointer(),
        nbytes,
        cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
    )[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaMemcpy d2h failed: {err}")
    return out