Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 52 additions & 9 deletions src/tsim/core/exact_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,54 @@ def _scalar_to_complex(data: jax.Array) -> jax.Array:
return data[..., 0] + data[..., 1] * _E4 + data[..., 2] * 1j + data[..., 3] * _E4D


# lax.scan unroll factor used by _reduce_along_scan. A higher unroll trades a
# larger compiled body for fewer kernel launches. A fixed 16 (rather than
# ``unroll=True``) avoids committing to whatever `True` means in a future JAX.
_SCAN_UNROLL = 16


def _reduce_along_scan(power, coeffs, op, axis):
    """Fold ``op`` sequentially along ``axis`` and return only the final carry.

    Intended as a drop-in for ``lax.associative_scan(op, ..., axis=axis)``
    followed by ``take(..., -1, axis=axis)``. Only one ``(power, coeffs)``
    carry is threaded through ``lax.scan``; the scan emits no per-step
    outputs, so the full prefix tensor is never materialised.

    Args:
        power: Array of powers; ``axis`` is interpreted relative to its rank.
        coeffs: Array of coefficients, moved along the same axis as ``power``.
        op: Binary combiner taking and returning ``(power, coeffs)`` pairs.
        axis: Axis to reduce over. Negative values count from the end of
            ``power``'s dimensions.

    Returns:
        ``(power, coeffs)`` of the fully reduced result, with ``axis`` removed.

    Note:
        The first slice along ``axis`` seeds the carry, so ``axis`` must be
        non-empty.
    """
    # Resolve a negative axis against ``power`` before it is applied to
    # ``coeffs`` as well (NOTE(review): presumably ``coeffs`` carries extra
    # trailing dimensions, so a raw negative axis would differ — confirm).
    scan_axis = axis + power.ndim if axis < 0 else axis

    powers_first = jnp.moveaxis(power, scan_axis, 0)
    coeffs_first = jnp.moveaxis(coeffs, scan_axis, 0)

    def _fold(acc, elem):
        # No per-step output: only the running carry is kept.
        return op(acc, elem), None

    seed = (powers_first[0], coeffs_first[0])
    tail = (powers_first[1:], coeffs_first[1:])
    (acc_power, acc_coeffs), _ = lax.scan(_fold, seed, tail, unroll=_SCAN_UNROLL)

    # Final fixpoint pass. Per the original author's note, the add/mul
    # combiners each apply only one ``_reduce_power_coeffs_step`` per call,
    # so a sequential scan over N elements can lag canonical form by up to
    # ``log2(N)`` reductions. Keep reducing until a step leaves ``power``
    # unchanged, so only the iterations that actually reduce are paid for.
    def _still_reducing(state):
        return state[2]

    def _reduce_once(state):
        cur_power, cur_coeffs, _ = state
        next_power, next_coeffs = _reduce_power_coeffs_step(cur_power, cur_coeffs)
        return next_power, next_coeffs, jnp.any(next_power != cur_power)

    # The flag starts as True so the reduction step runs at least once.
    out_power, out_coeffs, _ = lax.while_loop(
        _still_reducing,
        _reduce_once,
        (acc_power, acc_coeffs, jnp.bool_(True)),
    )
    return out_power, out_coeffs


class ExactScalarArray(eqx.Module):
"""Exact scalar array for ZX-calculus phase arithmetic using dyadic representation.

Expand Down Expand Up @@ -135,11 +183,9 @@ def sum(self, axis: int = -1) -> "ExactScalarArray":
if axis < 0:
axis += self.power.ndim

scanned_power, scanned_coeffs = lax.associative_scan(
_scalar_add_with_power, (self.power, self.coeffs), axis=axis
result_power, result_coeffs = _reduce_along_scan(
self.power, self.coeffs, _scalar_add_with_power, axis,
)
result_power = jnp.take(scanned_power, indices=-1, axis=axis)
result_coeffs = jnp.take(scanned_coeffs, indices=-1, axis=axis)
return ExactScalarArray(result_coeffs, result_power)

def prod(self, axis: int = -1) -> "ExactScalarArray":
Expand All @@ -164,12 +210,9 @@ def prod(self, axis: int = -1) -> "ExactScalarArray":
result_coeffs = result_coeffs.at[..., 0].set(1)
return ExactScalarArray(result_coeffs)

scanned_power, scanned_coeffs = lax.associative_scan(
_scalar_mul_with_power, (self.power, self.coeffs), axis=axis
result_power, result_coeffs = _reduce_along_scan(
self.power, self.coeffs, _scalar_mul_with_power, axis,
)
result_power = jnp.take(scanned_power, indices=-1, axis=axis)
result_coeffs = jnp.take(scanned_coeffs, indices=-1, axis=axis)

return ExactScalarArray(result_coeffs, result_power)

def to_complex(self) -> jax.Array:
Expand Down
10 changes: 9 additions & 1 deletion src/tsim/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tsim.core.graph import prepare_graph
from tsim.core.types import CompiledComponent, CompiledProgram
from tsim.noise.channels import ChannelSampler
from tsim.utils.cuda_helpers import copy_d2h

if TYPE_CHECKING:
from jax import Array as PRNGKey
Expand Down Expand Up @@ -335,7 +336,14 @@ def _sample_batches(

batches.append(samples)

result = np.concatenate(batches)[:shots]
# Concatenate on device, then a single d2h. The prior
# np.concatenate(batches) triggered per-batch __array__ (one d2h each)
# plus a host-side memcpy into a fresh numpy buffer for the concat
# output. For big bool tensors (e.g. 500k shots × 528 detector bits)
# the host memcpy alone was ~1 s on top of the PCIe transfer.
combined = batches[0] if len(batches) == 1 else jnp.concatenate(batches, axis=0)
jax.block_until_ready(combined)
result = copy_d2h(combined)[:shots]

if compute_reference:
assert reference is not None
Expand Down
141 changes: 141 additions & 0 deletions src/tsim/utils/cuda_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""CUDA-runtime helpers used by the sampler hot path.

Lazy-imports ``cuda.bindings``. When the import succeeds, host-pinned
allocations + a direct ``cudaMemcpy`` replace numpy's default pageable d2h:
the pageable path forces the driver to stage the transfer through internal
pinned scratch before copying into the user buffer (host-DRAM-bandwidth
bound), while a pinned destination skips the staging hop and lets d2h reach
PCIe line rate. When the import fails (or the source already lives on a CPU
device), ``copy_d2h`` transparently falls back to an ordinary copy into a
freshly allocated, pageable ``numpy`` array.
"""

from __future__ import annotations

import contextlib
import ctypes
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
cudart: Any = None
_CUDA_BINDINGS_AVAILABLE = False
else:
try:
from cuda.bindings import runtime as cudart
_CUDA_BINDINGS_AVAILABLE = True
except Exception:
cudart = None
_CUDA_BINDINGS_AVAILABLE = False


def _src_on_host(src) -> bool:
    """Report whether any device backing ``src`` is a CPU device.

    A CPU-resident array's buffer pointer would be a host address and is
    unsafe to feed to ``cudaMemcpy(..., DeviceToHost)``.

    Args:
        src: Any object. Its ``devices`` attribute, if present, may be either
            a callable returning an iterable of devices or the iterable
            itself.

    Returns:
        True if at least one device has a ``platform`` equal to ``"cpu"``
        (case-insensitive); False otherwise, including when ``src`` exposes
        no ``devices`` attribute at all.
    """
    device_source = getattr(src, "devices", None)
    if device_source is None:
        # No device information: treated as "not on host".
        return False

    try:
        backing = list(device_source())
    except TypeError:
        # ``devices`` was not callable (or its result was not iterable);
        # treat the attribute itself as the collection of devices.
        backing = list(device_source)

    for device in backing:
        # Devices without a ``platform`` attribute count as non-CPU.
        if getattr(device, "platform", "").lower() == "cpu":
            return True
    return False


class _PinnedBuf:
    """Owner of a single ``cudaHostAlloc``'d host region.

    The Python instance owns the lifetime; ``cudaFreeHost`` runs in
    ``__del__`` when no references remain (typically when the wrapping
    numpy view is garbage-collected).

    Attributes are stored in slots:
        ptr: Integer address of the pinned region, or 0 when nothing is
            owned (allocation failed, or the region was already released).
        nbytes: Size of the owned region in bytes (0 when nothing is owned).
    """

    __slots__ = ("nbytes", "ptr")

    def __init__(self, nbytes: int):
        """Allocate ``nbytes`` of pinned host memory.

        Raises:
            RuntimeError: if ``cudaHostAlloc`` does not return ``cudaSuccess``.
        """
        # Initialise the slots before anything can fail: ``__del__`` still
        # runs on a half-constructed instance, and reading an unset slot
        # there would raise AttributeError ("Exception ignored in __del__").
        self.ptr = 0
        self.nbytes = 0
        err, ptr = cudart.cudaHostAlloc(nbytes, cudart.cudaHostAllocDefault)
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"cudaHostAlloc({nbytes}) failed: {err}")
        self.ptr = int(ptr)
        self.nbytes = nbytes

    def __del__(self):
        """Release the pinned region, if one is still owned."""
        # ``getattr`` also covers instances created without ``__init__``.
        ptr = getattr(self, "ptr", 0)
        if ptr:
            # cudart may be torn down at interpreter exit.
            with contextlib.suppress(Exception):
                cudart.cudaFreeHost(ptr)
            # Clear the address so a repeated finaliser call cannot double-free.
            self.ptr = 0


def alloc_pinned_numpy(nbytes: int, dtype, shape) -> np.ndarray:
    """Allocate a pinned host region and return it as an ndarray view.

    The returned array's ``base`` chain pins the underlying ``_PinnedBuf``
    alive until the array and all derived views are dropped; only then does
    ``cudaFreeHost`` run.

    Args:
        nbytes: Size of the underlying allocation in bytes. Must be at least
            ``prod(shape) * dtype.itemsize``; any surplus bytes are allocated
            but not exposed through the returned view.
        dtype: numpy-compatible dtype for the returned view.
        shape: Shape of the returned view. A shape containing ``-1`` is
            passed to ``reshape`` unchanged and inferred from ``nbytes``.

    Returns:
        ndarray of the requested shape and dtype, backed by pinned memory.

    Raises:
        RuntimeError: if cuda.bindings is unavailable, or the underlying
            ``cudaHostAlloc`` fails.
        ValueError: if ``nbytes`` is too small for ``shape`` and ``dtype``.

    """
    if not _CUDA_BINDINGS_AVAILABLE:
        raise RuntimeError(
            "cuda.bindings not importable; install 'cuda-bindings' or use "
            "copy_d2h() for a transparent fallback."
        )

    # Work out how many bytes the view needs *before* allocating, so an
    # undersized request fails without touching the CUDA allocator.
    count = int(np.prod(shape, dtype=np.int64))
    if count >= 0:
        needed = count * np.dtype(dtype).itemsize
    else:
        # Negative product: ``shape`` has an inferred (-1) dimension. Expose
        # the whole allocation and let ``reshape`` resolve or reject it.
        needed = nbytes
    if nbytes < needed:
        raise ValueError(
            f"nbytes={nbytes} is too small for shape {shape} of dtype "
            f"{np.dtype(dtype)} ({needed} bytes required)"
        )

    buf = _PinnedBuf(nbytes)
    carr = (ctypes.c_uint8 * nbytes).from_address(buf.ptr)
    carr._owner = buf  # type: ignore[attr-defined]  # arr.base = carr; carr._owner = buf → buf stays alive
    raw = np.frombuffer(carr, dtype=np.uint8)
    # Trim to exactly the bytes the view needs; without this, any surplus in
    # ``nbytes`` makes ``reshape`` fail with a size mismatch.
    return raw[:needed].view(dtype).reshape(shape)


def _validate_dst(dst: np.ndarray, nbytes: int) -> None:
    """Reject a destination buffer that a raw ``nbytes`` copy would corrupt."""
    if dst.nbytes < nbytes:
        raise ValueError(
            f"dst holds {dst.nbytes} bytes but the source needs {nbytes}"
        )
    if not (dst.flags.c_contiguous and dst.flags.writeable):
        raise ValueError("dst must be a C-contiguous, writable ndarray")


def copy_d2h(src, *, dst: np.ndarray | None = None) -> np.ndarray:
    """Device-to-host copy, pinned-destination fast path when available.

    Args:
        src: Single-device contiguous array-like exposing
            ``unsafe_buffer_pointer()``, ``nbytes``, ``shape``, and ``dtype``.
            The caller must sync to the source's stream before invocation
            (``jax.block_until_ready(src)`` for a jax.Array). A source without
            ``unsafe_buffer_pointer`` (e.g. a plain ndarray) is copied through
            the fallback path.
        dst: Optional pre-allocated ndarray (ideally pinned) to write into.
            Must be C-contiguous, writable, and hold at least ``src.nbytes``
            bytes.

    Returns:
        When ``dst`` is None, a new ndarray with the same shape and dtype as
        ``src``. Otherwise ``dst`` itself, whose leading ``src.nbytes`` bytes
        now hold the source data.

    Raises:
        ValueError: if ``dst`` is too small, non-contiguous, or read-only.
        RuntimeError: if pinned allocation or ``cudaMemcpy`` fails.

    """
    use_fallback = (
        not _CUDA_BINDINGS_AVAILABLE
        or _src_on_host(src)
        or not hasattr(src, "unsafe_buffer_pointer")
    )
    if use_fallback:
        # Fallback path. Three cases:
        #   1. cuda.bindings isn't importable
        #   2. src lives on a CPU device (JAX on a CPU jaxlib build, or any
        #      array marked CPU) — its ``unsafe_buffer_pointer`` is a host
        #      pointer, so ``cudaMemcpy(..., DeviceToHost)`` would fail.
        #   3. src exposes no raw device pointer at all.
        # ``np.asarray(jax.Array)`` returns a read-only zero-copy view in
        # newer JAX, so allocate fresh and copy in to a writable buffer.
        out = np.empty(src.shape, dtype=src.dtype)
        # ``[...]`` rather than ``[:]``: the latter raises IndexError on 0-d.
        out[...] = src
        if dst is None:
            return out
        # Honour a caller-supplied destination with the same raw-bytes
        # semantics as the cudaMemcpy path below.
        _validate_dst(dst, out.nbytes)
        dst.reshape(-1).view(np.uint8)[: out.nbytes] = out.reshape(-1).view(np.uint8)
        return dst

    nbytes = src.nbytes
    if dst is None:
        dst = alloc_pinned_numpy(nbytes, src.dtype, src.shape)
    else:
        # cudaMemcpy writes ``nbytes`` unconditionally; an undersized or
        # strided destination would be silent memory corruption.
        _validate_dst(dst, nbytes)
    err = cudart.cudaMemcpy(
        dst.ctypes.data,
        src.unsafe_buffer_pointer(),
        nbytes,
        cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
    )[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaMemcpy d2h failed: {err}")
    return dst
Loading