From 6ea2d7a8225c0bb529b61456e47d9820fa642f87 Mon Sep 17 00:00:00 2001 From: Danylo Lykov Date: Mon, 18 May 2026 23:48:10 -0500 Subject: [PATCH 1/6] sampler: device-side concat + single d2h in _sample_batches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous np.concatenate(batches)[:shots] triggered (1) one __array__ call per jax.Array in the list — each forcing its own d2h transfer — and (2) a fresh host buffer allocation plus a host-side memcpy from the d2h'd buffer into the concat output. On large bool tensors (e.g. 500k shots × 528 detector bits = 264 MB) the host concat alone was about 1 s on top of the PCIe transfer. Concatenating on device first and then doing a single np.asarray means one d2h, no extra host memcpy. For the common batch_size == shots case we also skip the jnp.concatenate (single-element batches list). No behavior change, no new flag. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/tsim/sampler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/tsim/sampler.py b/src/tsim/sampler.py index bfe4e74..d2015d8 100644 --- a/src/tsim/sampler.py +++ b/src/tsim/sampler.py @@ -335,7 +335,13 @@ def _sample_batches( batches.append(samples) - result = np.concatenate(batches)[:shots] + # Concatenate on device, then a single d2h. The prior + # np.concatenate(batches) triggered per-batch __array__ (one d2h each) + # plus a host-side memcpy into a fresh numpy buffer for the concat + # output. For big bool tensors (e.g. 500k shots × 528 detector bits) + # the host memcpy alone was ~1 s on top of the PCIe transfer. + combined = batches[0] if len(batches) == 1 else jnp.concatenate(batches, axis=0) + result = np.asarray(combined)[:shots] if compute_reference: assert reference is not None From ed1e6f569e9217adeafe8b2dbe6a8c63b494d1c5 Mon Sep 17 00:00:00 2001 From: Danylo Lykov Date: Mon, 18 May 2026 23:51:10 -0500 Subject: [PATCH 2/6] exact_scalar: replace associative_scan with lax.scan in .sum / .prod MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ExactScalarArray.sum and .prod previously did: scanned = lax.associative_scan(op, ..., axis=axis) return take(scanned, -1, axis=axis) That builds the full O(N) prefix tensor along the scan axis just to throw all but the last element away. For the batches the sampler now reaches (B up to 500k+, T up to several dozen) the prefix tensor was the dominant memory cost in .sum / .prod and the limiting factor for batch size. Replace with a lax.scan-based reduction that keeps a single (power, coeffs) carry — O(1) extra memory along the scan axis. Compute depth becomes O(N) sequential instead of O(log N) parallel, but the public .sum/.prod API only needs the final value, so the parallel-prefix advantage of associative_scan doesn't apply. unroll=4 default keeps XLA's traced body small enough that compile times stay reasonable for typical scan lengths. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/tsim/core/exact_scalar.py | 42 +++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/src/tsim/core/exact_scalar.py b/src/tsim/core/exact_scalar.py index e076675..92ad81a 100644 --- a/src/tsim/core/exact_scalar.py +++ b/src/tsim/core/exact_scalar.py @@ -89,6 +89,35 @@ def _scalar_to_complex(data: jax.Array) -> jax.Array: return data[..., 0] + data[..., 1] * _E4 + data[..., 2] * 1j + data[..., 3] * _E4D +# Sane default for the lax.scan unroll factor in _reduce_along_scan. Keeps the +# unrolled body small enough that XLA compile times stay reasonable for the +# T values that ExactScalarArray.sum / .prod actually see (typically ≤ a few +# dozen). +_SCAN_UNROLL = 4 + + +def _reduce_along_scan(power, coeffs, op, axis): + """lax.scan-based reduction along ``axis`` returning only the final carry. + + Equivalent to ``lax.associative_scan(op, ..., axis=axis)`` followed by + ``take(..., -1, axis=axis)``, but keeps a single (power, coeffs) carry + through the scan instead of materialising the full prefix tensor — + O(1) extra memory along the scan axis instead of O(N). + """ + if axis < 0: + axis += power.ndim + power_t = jnp.moveaxis(power, axis, 0) + coeffs_t = jnp.moveaxis(coeffs, axis, 0) + init = (power_t[0], coeffs_t[0]) + rest = (power_t[1:], coeffs_t[1:]) + + def step(carry, x): + return op(carry, x), None + + (final_power, final_coeffs), _ = lax.scan(step, init, rest, unroll=_SCAN_UNROLL) + return final_power, final_coeffs + + class ExactScalarArray(eqx.Module): """Exact scalar array for ZX-calculus phase arithmetic using dyadic representation. @@ -135,11 +164,9 @@ def sum(self, axis: int = -1) -> "ExactScalarArray": if axis < 0: axis += self.power.ndim - scanned_power, scanned_coeffs = lax.associative_scan( - _scalar_add_with_power, (self.power, self.coeffs), axis=axis + result_power, result_coeffs = _reduce_along_scan( + self.power, self.coeffs, _scalar_add_with_power, axis, ) - result_power = jnp.take(scanned_power, indices=-1, axis=axis) - result_coeffs = jnp.take(scanned_coeffs, indices=-1, axis=axis) return ExactScalarArray(result_coeffs, result_power) def prod(self, axis: int = -1) -> "ExactScalarArray": @@ -164,12 +191,9 @@ def prod(self, axis: int = -1) -> "ExactScalarArray": result_coeffs = result_coeffs.at[..., 0].set(1) return ExactScalarArray(result_coeffs) - scanned_power, scanned_coeffs = lax.associative_scan( - _scalar_mul_with_power, (self.power, self.coeffs), axis=axis + result_power, result_coeffs = _reduce_along_scan( + self.power, self.coeffs, _scalar_mul_with_power, axis, ) - result_power = jnp.take(scanned_power, indices=-1, axis=axis) - result_coeffs = jnp.take(scanned_coeffs, indices=-1, axis=axis) - return ExactScalarArray(result_coeffs, result_power) def to_complex(self) -> jax.Array: From b26283103f10390961352fa3296399ff0f49ea25 Mon Sep 17 00:00:00 2001 From: Danylo Lykov Date: Thu, 21 May 2026 15:18:40 -0500 Subject: [PATCH 3/6] sampler: pinned-host d2h via cuda.bindings when available MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Default numpy d2h (np.asarray on a jax.Array) lands in pageable host memory, which forces the CUDA driver to stage the transfer through internal pinned scratch before copying into the user buffer. The second hop is bound by host DRAM bandwidth and caps the effective d2h throughput well below PCIe line rate. Pinning the destination skips the staging hop. Measured throughput at 3.4 GB transfer (rotated surface code d=7, 10M shots): H100 (gen4 x16) B200 (gen5 x16) cp.asnumpy / np.asarray (pageable) 1.9 GB/s 4.1 GB/s cudaMemcpy → cudaHostAlloc (pinned) 23.4 GB/s 51.4 GB/s Translated to per-shot wall on the surface-code-noise sweep: p vanilla pre-pin (+G) post-pin (+G) 1e-6 (10M shots) 0.084µs 0.162µs 0.021µs 1e-4 (10M shots) 0.176µs 0.188µs 0.021µs 1e-2 (10K shots) 4.86µs 0.102µs 0.091µs The pinned path lifts +G from "loses to vanilla below p~1e-5" to "monotonically faster across the whole sweep." Implementation: - New tsim.utils.cuda_helpers module: _PinnedBuf (RAII over cudaHostAlloc), alloc_pinned_numpy (returns a pinned-backed ndarray with lifetime tied to the underlying region via ctypes + ndarray.base), copy_d2h (the public entry, picks pinned fast path or numpy fallback based on import of cuda.bindings). - sampler._sample_batches replaces np.asarray(combined)[:shots] with copy_d2h(combined)[:shots] and adds the matching jax.block_until_ready before the call. cuda.bindings is a soft dep — when import fails, copy_d2h falls back to np.array, preserving the prior behavior. The pyproject.toml is unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/tsim/sampler.py | 4 +- src/tsim/utils/cuda_helpers.py | 114 +++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 src/tsim/utils/cuda_helpers.py diff --git a/src/tsim/sampler.py b/src/tsim/sampler.py index d2015d8..3a3831f 100644 --- a/src/tsim/sampler.py +++ b/src/tsim/sampler.py @@ -17,6 +17,7 @@ from tsim.core.graph import prepare_graph from tsim.core.types import CompiledComponent, CompiledProgram from tsim.noise.channels import ChannelSampler +from tsim.utils.cuda_helpers import copy_d2h if TYPE_CHECKING: from jax import Array as PRNGKey @@ -341,7 +342,8 @@ def _sample_batches( # output. For big bool tensors (e.g. 500k shots × 528 detector bits) # the host memcpy alone was ~1 s on top of the PCIe transfer. combined = batches[0] if len(batches) == 1 else jnp.concatenate(batches, axis=0) - result = np.asarray(combined)[:shots] + jax.block_until_ready(combined) + result = copy_d2h(combined)[:shots] if compute_reference: assert reference is not None diff --git a/src/tsim/utils/cuda_helpers.py b/src/tsim/utils/cuda_helpers.py new file mode 100644 index 0000000..fcb7fec --- /dev/null +++ b/src/tsim/utils/cuda_helpers.py @@ -0,0 +1,114 @@ +"""CUDA-runtime helpers used by the sampler hot path. + +Lazy-imports ``cuda.bindings``. When the import succeeds, host-pinned +allocations + a direct ``cudaMemcpy`` replace numpy's default pageable d2h: +the pageable path forces the driver to stage the transfer through internal +pinned scratch before copying into the user buffer (host-DRAM-bandwidth +bound), while a pinned destination skips the staging hop and lets d2h reach +PCIe line rate. When the import fails, ``copy_d2h`` transparently falls +back to ``numpy.array``. +""" + +from __future__ import annotations + +import ctypes + +import numpy as np + +try: + from cuda.bindings import runtime as cudart + _CUDA_BINDINGS_AVAILABLE = True +except Exception: + _CUDA_BINDINGS_AVAILABLE = False + + +class _PinnedBuf: + """RAII wrapper for a ``cudaHostAlloc``'d region. + + The Python instance owns the lifetime; ``cudaFreeHost`` runs in + ``__del__`` when no references remain (typically when the wrapping + numpy view is garbage-collected). + """ + + __slots__ = ("ptr", "nbytes") + + def __init__(self, nbytes: int): + err, ptr = cudart.cudaHostAlloc(nbytes, cudart.cudaHostAllocDefault) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaHostAlloc({nbytes}) failed: {err}") + self.ptr = int(ptr) + self.nbytes = nbytes + + def __del__(self): + if self.ptr: + try: + cudart.cudaFreeHost(self.ptr) + except Exception: + # cudart may be torn down at interpreter exit. + pass + self.ptr = 0 + + +def alloc_pinned_numpy(nbytes: int, dtype, shape) -> np.ndarray: + """Allocate a pinned host region and return it as an ndarray view. + + The returned array's ``base`` chain pins the underlying ``_PinnedBuf`` + alive until the array and all derived views are dropped; only then does + ``cudaFreeHost`` run. + + Args: + nbytes: Size of the underlying allocation in bytes. Must be at least + ``prod(shape) * dtype.itemsize``. + dtype: numpy-compatible dtype for the returned view. + shape: Shape of the returned view. + + Returns: + ndarray of the requested shape and dtype, backed by pinned memory. + + Raises: + RuntimeError: if cuda.bindings is unavailable, or the underlying + ``cudaHostAlloc`` fails. + """ + if not _CUDA_BINDINGS_AVAILABLE: + raise RuntimeError( + "cuda.bindings not importable; install 'cuda-bindings' or use " + "copy_d2h() for a transparent fallback." + ) + buf = _PinnedBuf(nbytes) + carr = (ctypes.c_uint8 * nbytes).from_address(buf.ptr) + carr._owner = buf # arr.base = carr; carr._owner = buf → buf stays alive + return np.frombuffer(carr, dtype=np.uint8).view(dtype).reshape(shape) + + +def copy_d2h(src, *, dst: np.ndarray | None = None) -> np.ndarray: + """Device-to-host copy, pinned-destination fast path when available. + + Args: + src: Single-device contiguous array-like exposing + ``unsafe_buffer_pointer()``, ``nbytes``, ``shape``, and ``dtype``. + The caller must sync to the source's stream before invocation + (``jax.block_until_ready(src)`` for a jax.Array). + dst: Optional pre-allocated pinned ndarray to write into. Must have + at least ``src.nbytes`` bytes. + + Returns: + ndarray with the same shape and dtype as ``src``. + """ + if not _CUDA_BINDINGS_AVAILABLE: + # np.asarray(jax.Array) is a read-only zero-copy view; callers + # mutate the return (e.g. XOR detectors with the reference sample) + # so allocate fresh and copy. + out = np.empty(src.shape, dtype=src.dtype) + out[:] = src + return out + if dst is None: + dst = alloc_pinned_numpy(src.nbytes, src.dtype, src.shape) + err = cudart.cudaMemcpy( + dst.ctypes.data, + src.unsafe_buffer_pointer(), + src.nbytes, + cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, + )[0] + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaMemcpy d2h failed: {err}") + return dst From de4c5607d4ed2d342d4045c00984c415f685922a Mon Sep 17 00:00:00 2001 From: Danylo Lykov Date: Fri, 22 May 2026 13:06:12 -0500 Subject: [PATCH 4/6] exact_scalar: bump _SCAN_UNROLL to 16 + final fixpoint pass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit unroll ------ Sweeping ``_SCAN_UNROLL ∈ {1, 4, 8, 16, True}`` on msc_3 cutting 100k (H100): unroll us/shot compile_s 1 6.16 6.82 4 5.35 6.78 (prior default) 8 5.07 6.77 16 4.22 6.82 True 4.24 6.66 16 ties ``True`` on speed and stays explicit. Compile time is flat across the range, so there's no compile-cost reason to keep 4. Verified on star_d7 cat5 500k: unroll=16 gives 1.81 µs/shot vs 1.96 at unroll=4 (no regression). Net vs upstream/main vanilla on msc_3 cutting 100k: 4.22 vs 5.31 µs/shot (-21%). fixpoint -------- ``_scalar_add_with_power`` and ``_scalar_mul_with_power`` each apply one ``_reduce_power_coeffs_step`` per call. The tree-shaped ``lax.associative_scan`` this PR replaced naturally produces canonical form (gcd of coeffs is odd) because each combine sees two operands of equal accumulation depth — one reduce per node suffices. Sequential ``lax.scan`` accumulation can lag canonical form by up to ``log2(N)`` reductions for an N-element scan, breaking ``test_sum_reduces_while_adding``. Bring the carry back to canonical form with a ``lax.while_loop`` that early-exits as soon as no element reduces further. Most inputs converge in 1-2 iters; on msc_3 cutting / star_d7 cat5 benches the added cost is inside measurement noise. cuda_helpers ------------ Also gate ``copy_d2h``'s pinned-d2h fast path on the source actually living on a CUDA device. JAX on a CPU-only jaxlib reports ``src.unsafe_buffer_pointer()`` as a host pointer; feeding that to ``cudaMemcpy(..., DeviceToHost)`` returns ``cudaErrorInvalidValue``. Detect via ``src.devices()`` platform and route to the numpy fallback there too. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/tsim/core/exact_scalar.py | 36 +++++++++++++++++++++++++++++----- src/tsim/utils/cuda_helpers.py | 26 ++++++++++++++++++++---- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/src/tsim/core/exact_scalar.py b/src/tsim/core/exact_scalar.py index 92ad81a..42687d5 100644 --- a/src/tsim/core/exact_scalar.py +++ b/src/tsim/core/exact_scalar.py @@ -89,11 +89,13 @@ def _scalar_to_complex(data: jax.Array) -> jax.Array: return data[..., 0] + data[..., 1] * _E4 + data[..., 2] * 1j + data[..., 3] * _E4D -# Sane default for the lax.scan unroll factor in _reduce_along_scan. Keeps the -# unrolled body small enough that XLA compile times stay reasonable for the -# T values that ExactScalarArray.sum / .prod actually see (typically ≤ a few -# dozen). -_SCAN_UNROLL = 4 +# lax.scan unroll factor in _reduce_along_scan. Higher unroll reduces +# kernel-launch overhead at the cost of a larger compiled body; on the +# msc_3 cutting 100k bench, sweeping unroll ∈ {1, 4, 8, 16, True} gives +# per-shot wall {6.16, 5.35, 5.07, 4.22, 4.24} µs and compile_s flat +# at ~6.8 s. 16 ties `True` on speed without committing to whatever +# `True` means in a future JAX. Bumped from 4. +_SCAN_UNROLL = 16 def _reduce_along_scan(power, coeffs, op, axis): @@ -115,6 +117,30 @@ def step(carry, x): return op(carry, x), None (final_power, final_coeffs), _ = lax.scan(step, init, rest, unroll=_SCAN_UNROLL) + + # Final fixpoint pass. ``_scalar_add_with_power`` and + # ``_scalar_mul_with_power`` each apply one ``_reduce_power_coeffs_step`` + # per call. The tree-shaped ``lax.associative_scan`` we replaced + # naturally hit canonical form (gcd of coeffs is odd) because each + # combine sees two operands of equal accumulation depth — one reduce + # suffices per node. Sequential ``lax.scan`` accumulation can lag + # canonical form by up to ``log2(N)`` reductions for an N-element scan; + # iterate with ``lax.while_loop`` so we only pay for the iters that + # actually reduce. Most inputs converge in 1-2 steps; on the msc_3 + # cutting / star_d7 cat5 benches the added cost is in measurement noise. + def _fixpoint_cond(state): + _, _, did_change = state + return did_change + + def _fixpoint_body(state): + p, c, _ = state + new_p, new_c = _reduce_power_coeffs_step(p, c) + return new_p, new_c, jnp.any(new_p != p) + + init_state = (final_power, final_coeffs, jnp.bool_(True)) + final_power, final_coeffs, _ = lax.while_loop( + _fixpoint_cond, _fixpoint_body, init_state, + ) return final_power, final_coeffs diff --git a/src/tsim/utils/cuda_helpers.py b/src/tsim/utils/cuda_helpers.py index fcb7fec..bd42cad 100644 --- a/src/tsim/utils/cuda_helpers.py +++ b/src/tsim/utils/cuda_helpers.py @@ -22,6 +22,20 @@ _CUDA_BINDINGS_AVAILABLE = False +def _src_on_host(src) -> bool: + """True iff ``src`` lives on a CPU device — its buffer pointer would be + a host address and unsafe to feed to ``cudaMemcpy(..., DeviceToHost)``. + """ + devices = getattr(src, "devices", None) + if devices is None: + return False + try: + ds = list(devices()) + except TypeError: + ds = list(devices) + return any(getattr(d, "platform", "").lower() == "cpu" for d in ds) + + class _PinnedBuf: """RAII wrapper for a ``cudaHostAlloc``'d region. @@ -94,10 +108,14 @@ def copy_d2h(src, *, dst: np.ndarray | None = None) -> np.ndarray: Returns: ndarray with the same shape and dtype as ``src``. """ - if not _CUDA_BINDINGS_AVAILABLE: - # np.asarray(jax.Array) is a read-only zero-copy view; callers - # mutate the return (e.g. XOR detectors with the reference sample) - # so allocate fresh and copy. + if not _CUDA_BINDINGS_AVAILABLE or _src_on_host(src): + # Fallback path. Two cases: + # 1. cuda.bindings isn't importable + # 2. src lives on a CPU device (JAX on a CPU jaxlib build, or any + # array marked CPU) — its ``unsafe_buffer_pointer`` is a host + # pointer, so ``cudaMemcpy(..., DeviceToHost)`` would fail. + # ``np.asarray(jax.Array)`` returns a read-only zero-copy view in + # newer JAX, so allocate fresh and copy in to a writable buffer. out = np.empty(src.shape, dtype=src.dtype) out[:] = src return out From 81203a631381716dd2ba198427859c75257d2632 Mon Sep 17 00:00:00 2001 From: Danylo Lykov Date: Fri, 22 May 2026 13:09:16 -0500 Subject: [PATCH 5/6] exact_scalar: strip workload-specific residue from comments Co-Authored-By: Claude Opus 4.7 --- src/tsim/core/exact_scalar.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/tsim/core/exact_scalar.py b/src/tsim/core/exact_scalar.py index 42687d5..b86ef2f 100644 --- a/src/tsim/core/exact_scalar.py +++ b/src/tsim/core/exact_scalar.py @@ -89,12 +89,9 @@ def _scalar_to_complex(data: jax.Array) -> jax.Array: return data[..., 0] + data[..., 1] * _E4 + data[..., 2] * 1j + data[..., 3] * _E4D -# lax.scan unroll factor in _reduce_along_scan. Higher unroll reduces -# kernel-launch overhead at the cost of a larger compiled body; on the -# msc_3 cutting 100k bench, sweeping unroll ∈ {1, 4, 8, 16, True} gives -# per-shot wall {6.16, 5.35, 5.07, 4.22, 4.24} µs and compile_s flat -# at ~6.8 s. 16 ties `True` on speed without committing to whatever -# `True` means in a future JAX. Bumped from 4. +# lax.scan unroll factor in _reduce_along_scan. Higher unroll trades +# compiled-body size for fewer kernel-launches. 16 avoids committing to +# whatever `True` means in a future JAX. _SCAN_UNROLL = 16 @@ -119,15 +116,11 @@ def step(carry, x): (final_power, final_coeffs), _ = lax.scan(step, init, rest, unroll=_SCAN_UNROLL) # Final fixpoint pass. ``_scalar_add_with_power`` and - # ``_scalar_mul_with_power`` each apply one ``_reduce_power_coeffs_step`` - # per call. The tree-shaped ``lax.associative_scan`` we replaced - # naturally hit canonical form (gcd of coeffs is odd) because each - # combine sees two operands of equal accumulation depth — one reduce - # suffices per node. Sequential ``lax.scan`` accumulation can lag - # canonical form by up to ``log2(N)`` reductions for an N-element scan; - # iterate with ``lax.while_loop`` so we only pay for the iters that - # actually reduce. Most inputs converge in 1-2 steps; on the msc_3 - # cutting / star_d7 cat5 benches the added cost is in measurement noise. + # ``_scalar_mul_with_power`` each apply only one + # ``_reduce_power_coeffs_step`` per call, so a sequential scan over N + # elements can lag canonical form (gcd of coeffs is odd) by up to + # ``log2(N)`` reductions. Iterate with ``lax.while_loop`` so we only + # pay for the iters that actually reduce. def _fixpoint_cond(state): _, _, did_change = state return did_change From ee704c8dd88b790ecad8c7cfb83c05c48309dbdc Mon Sep 17 00:00:00 2001 From: Danylo Lykov Date: Tue, 2 Jun 2026 02:42:06 -0500 Subject: [PATCH 6/6] cuda_helpers: satisfy ruff D205/D401/RUF023/SIM105/D413 + pyright --- src/tsim/utils/cuda_helpers.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/tsim/utils/cuda_helpers.py b/src/tsim/utils/cuda_helpers.py index bd42cad..e557db9 100644 --- a/src/tsim/utils/cuda_helpers.py +++ b/src/tsim/utils/cuda_helpers.py @@ -11,20 +11,29 @@ from __future__ import annotations +import contextlib import ctypes +from typing import TYPE_CHECKING, Any import numpy as np -try: - from cuda.bindings import runtime as cudart - _CUDA_BINDINGS_AVAILABLE = True -except Exception: +if TYPE_CHECKING: + cudart: Any = None _CUDA_BINDINGS_AVAILABLE = False +else: + try: + from cuda.bindings import runtime as cudart + _CUDA_BINDINGS_AVAILABLE = True + except Exception: + cudart = None + _CUDA_BINDINGS_AVAILABLE = False def _src_on_host(src) -> bool: - """True iff ``src`` lives on a CPU device — its buffer pointer would be - a host address and unsafe to feed to ``cudaMemcpy(..., DeviceToHost)``. + """Return True iff ``src`` lives on a CPU device. + + Its buffer pointer would be a host address and unsafe to feed to + ``cudaMemcpy(..., DeviceToHost)``. """ devices = getattr(src, "devices", None) if devices is None: @@ -44,7 +53,7 @@ class _PinnedBuf: numpy view is garbage-collected). """ - __slots__ = ("ptr", "nbytes") + __slots__ = ("nbytes", "ptr") def __init__(self, nbytes: int): err, ptr = cudart.cudaHostAlloc(nbytes, cudart.cudaHostAllocDefault) @@ -55,11 +64,9 @@ def __init__(self, nbytes: int): def __del__(self): if self.ptr: - try: + # cudart may be torn down at interpreter exit. + with contextlib.suppress(Exception): cudart.cudaFreeHost(self.ptr) - except Exception: - # cudart may be torn down at interpreter exit. - pass self.ptr = 0 @@ -82,6 +89,7 @@ def alloc_pinned_numpy(nbytes: int, dtype, shape) -> np.ndarray: Raises: RuntimeError: if cuda.bindings is unavailable, or the underlying ``cudaHostAlloc`` fails. + """ if not _CUDA_BINDINGS_AVAILABLE: raise RuntimeError( @@ -90,7 +98,7 @@ def alloc_pinned_numpy(nbytes: int, dtype, shape) -> np.ndarray: ) buf = _PinnedBuf(nbytes) carr = (ctypes.c_uint8 * nbytes).from_address(buf.ptr) - carr._owner = buf # arr.base = carr; carr._owner = buf → buf stays alive + carr._owner = buf # type: ignore[attr-defined] # arr.base = carr; carr._owner = buf → buf stays alive return np.frombuffer(carr, dtype=np.uint8).view(dtype).reshape(shape) @@ -107,6 +115,7 @@ def copy_d2h(src, *, dst: np.ndarray | None = None) -> np.ndarray: Returns: ndarray with the same shape and dtype as ``src``. + """ if not _CUDA_BINDINGS_AVAILABLE or _src_on_host(src): # Fallback path. Two cases: