Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Fast path in the detector sampler for components whose output is deterministically given by a single error variable. These components now skip the JAX compilation and autoregressive sampling pipeline, significantly speeding up detector sampling for surface-code circuits at low physical error rates.
- `CompiledDetectorSampler.sample` now accepts an optional `postselection_mask` argument for postselected simulations (#41). The mask has length `num_detectors`; a shot is discarded when any masked detector fires. When a shot is discarded by a masked direct detector, the expensive JAX autoregressive loop is skipped for that shot while still returning one row per requested shot. Such discarded rows retain their direct detector columns and fill all other columns with zero; callers recover surviving shots by re-applying the mask to the returned detector columns. Masks that select no direct detector (for example, masks that only target detectors inside JAX components) fall back to the standard sampling path. Fully-direct circuits continue to use the NumPy fast path.



Expand Down
4 changes: 4 additions & 0 deletions docs/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,7 @@ uv run just doc
```

This will launch a local server to preview the documentation. You can also run `uv run just doc-build` to build the documentation without launching the server.

Postselection support lives in `CompiledDetectorSampler.sample(postselection_mask=...)`.
See the **Postselected simulations** section in `docs/index.md` and unit tests in
`test/unit/test_postselection.py`.
46 changes: 46 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,52 @@ When set to `True`, a noiseless reference sample is computed and XORed with the
results, so that output values represent deviations from the noiseless baseline.
Note that this feature should be used carefully. If detectors or observables are not deterministic, this may lead to incorrect statistics.

## Postselected simulations

For postselected QEC experiments, pass a boolean mask to
`CompiledDetectorSampler.sample`. The mask has length `num_detectors`; a shot is
*discarded* when any masked detector fires.

```python
import numpy as np

c = tsim.Circuit(
"""
X_ERROR(0.01) 0 1
M 0 1
DETECTOR rec[-2]
DETECTOR rec[-1] rec[-2]
OBSERVABLE_INCLUDE(0) rec[-1]
"""
)
sampler = c.compile_detector_sampler()
mask = np.array([True, False]) # postselect on detector 0

samples = sampler.sample(
shots=10_000,
postselection_mask=mask,
append_observables=True,
)
keep = ~np.any(samples[:, : c.num_detectors] & mask, axis=1)
survivors = samples[keep]
```

`sample` always returns exactly `shots` rows. Shots discarded by a **direct**
postselected detector skip the expensive JAX autoregressive loop; their direct
detector columns are still correct, and all other columns are filled with
`False`. Re-apply the mask to the detector columns (as above) to recover the
surviving shots. Masked detectors that live inside a JAX component cannot be
evaluated without running JAX, so shots flagged only by such detectors are always
computed in full and are filtered out by the same re-applied mask.

Postselection is independent of `prepend_observables`, `append_observables`,
`separate_observables`, and `bit_packed`. When combined with
`use_detector_reference_sample`, the reference XOR is applied before the
postselection discard check. On surviving rows it is applied to every detector
column; on direct-discarded partial rows it is applied only to direct detector
columns (component columns stay `False`). When combined with
`use_observable_reference_sample`, the reference XOR is applied to every row
that ran JAX; direct-discarded partial rows are left unchanged.

## Benchmarks

With GPU acceleration, Tsim can achieve sampling throughput for low-magic circuits that approaches the throughput of Stim on Clifford circuits of the same size. The figure below shows a comparison for [distillation circuits](https://arxiv.org/html/2412.15165v1) (35 and 85 qubits), [cultivation circuits](https://arxiv.org/abs/2409.17595), and rotated surface code circuits.
Expand Down
247 changes: 243 additions & 4 deletions src/tsim/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,71 @@ def __init__(
and not self._direct_flips.any()
and np.array_equal(self._direct_f_indices, np.arange(n_direct))
)
self._direct_global_indices = np.asarray(
prog.output_order[:n_direct], dtype=np.int32
)
self._direct_output_mask = np.zeros(prog.num_outputs, dtype=np.bool_)
if n_direct > 0:
self._direct_output_mask[self._direct_global_indices] = True
self._direct_detector_mask = self._direct_output_mask[
: self._num_detectors
].copy()

def _compute_direct_outputs(self, f_params_np: np.ndarray) -> np.ndarray:
    """Scatter direct output bits into a full (batch, num_outputs) bool array.

    Columns that are not direct outputs are ``False``. The returned array
    never aliases ``f_params_np``, so callers may modify it in place.

    Args:
        f_params_np: Sampled error parameters, one row per shot. The bits
            are reinterpreted with ``.view(np.bool_)``, so this assumes a
            one-byte integer dtype holding 0/1 values (the reference-sample
            path passes ``np.uint8``).

    Returns:
        Boolean array of shape ``(batch, num_outputs)``.
    """
    batch = f_params_np.shape[0]
    num_outputs = self._program.num_outputs
    n_direct = len(self._direct_f_indices)
    if n_direct == 0:
        return np.zeros((batch, num_outputs), dtype=np.bool_)

    if self._direct_zero_copy:
        # Leading columns already hold the direct bits: reinterpret them
        # without copying.
        raw = f_params_np[:, :n_direct].view(np.bool_)
    else:
        # Gather the relevant columns and apply the per-output flips.
        raw = (f_params_np[:, self._direct_f_indices] ^ self._direct_flips).view(
            np.bool_
        )

    columns = self._direct_global_indices
    # Returning the raw block as-is is only valid when every output is direct
    # AND the direct outputs already sit at their global positions. Check the
    # latter explicitly instead of relying on an invariant set up elsewhere.
    if (
        self._direct_zero_copy
        and n_direct == num_outputs
        and np.array_equal(columns, np.arange(n_direct))
    ):
        return raw.copy()

    out = np.zeros((batch, num_outputs), dtype=np.bool_)
    out[:, columns] = raw
    return out

def _compute_reference_sample(self) -> np.ndarray:
    """Evaluate the program once with every error parameter set to zero.

    The channel sampler is never asked for a draw, so its RNG state is left
    untouched. When the program has components, ``self._key`` is split to
    obtain the subkey handed to ``sample_program``.

    Returns:
        The single noiseless output row as a boolean array.
    """
    zero_params = np.zeros(
        (1, self._channel_sampler.signature_matrix.shape[1]), dtype=np.uint8
    )
    if self._program.components:
        self._key, subkey = jax.random.split(self._key)
        first_row = sample_program(self._program, jnp.asarray(zero_params), subkey)[0]
        return np.asarray(first_row, dtype=np.bool_)
    # Without components the NumPy direct path yields the complete row.
    return self._compute_direct_outputs(zero_params)[0]

def _resolve_batch_size(
self,
shots: int,
batch_size: int | None,
*,
compute_reference: bool,
) -> int:
"""Choose a uniform JAX batch size for ``shots`` samples."""
# No explicit size given: use the fewest batches that respect the estimated
# maximum, then spread the shots evenly across that many batches.
if batch_size is None:
max_batch_size = self._estimate_batch_size()
num_batches = max(1, ceil(shots / max_batch_size))
batch_size = ceil(shots / num_batches)
# The batches would be filled exactly by ``shots``; grow the size by one row.
# Presumably this leaves a spare slot for the reference sample - TODO confirm.
# NOTE(review): confirm whether this adjustment is also meant to apply to a
# caller-supplied ``batch_size``.
if compute_reference and batch_size * ceil(shots / batch_size) == shots:
batch_size += 1
return batch_size

def _peak_bytes_per_sample(self) -> int:
"""Estimate peak device memory per sample from compiled program structure."""
Expand Down Expand Up @@ -301,8 +366,12 @@ def _sample_batches(
return empty, np.zeros(self._program.num_outputs, dtype=np.bool_)
return empty

if not self._program.components and not compute_reference:
return self._sample_direct(shots)
if not self._program.components:
samples = self._sample_direct(shots)
if compute_reference:
reference = self._compute_reference_sample()
return samples, reference
return samples

if batch_size is None:
max_batch_size = self._estimate_batch_size()
Expand Down Expand Up @@ -342,6 +411,131 @@ def _sample_batches(
return result, reference
return result

def _sample_batches_with_postselection(
self,
shots: int,
batch_size: int | None,
*,
postselection_mask: np.ndarray,
compute_reference: bool = False,
xor_detector_ref: bool = False,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]:
"""Sample with postselection, skipping JAX for direct discarded shots.

Shots discarded by a direct masked detector are filled with their
direct detector columns and ``False`` elsewhere; JAX is never called for
those shots. Survivors are buffered until a full batch of
``batch_size`` is ready, then dispatched to ``sample_program`` in one
call. The final partial batch is padded to keep the JAX batch size
fixed (avoiding recompilation) and the padding rows are discarded.
"""
# Returns ``(samples, reference_or_None, was_discarded)``; ``was_discarded``
# marks the rows that skipped JAX.
if shots < 0:
raise ValueError(f"shots must be non-negative, got {shots}")
if batch_size is not None and batch_size < 1:
raise ValueError(f"batch_size must be at least 1, got {batch_size}")

num_outputs = self._program.num_outputs
# Nothing to sample: return empty arrays (and an all-False reference when
# one was requested).
if shots == 0:
empty = np.empty((0, num_outputs), dtype=np.bool_)
empty_discarded = np.empty(0, dtype=np.bool_)
if compute_reference:
return empty, np.zeros(num_outputs, dtype=np.bool_), empty_discarded
return empty, None, empty_discarded

# Only masked detectors that are also direct outputs can be checked
# without running JAX.
postselect_direct = postselection_mask & self._direct_detector_mask

# No JAX components: every row comes from the NumPy direct path and is
# complete, so no row is reported as discarded.
if not self._program.components:
samples = self._sample_direct(shots)
if compute_reference:
reference = self._compute_reference_sample()
if xor_detector_ref:
samples[:, : self._num_detectors] ^= reference[
: self._num_detectors
]
return samples, reference, np.zeros(shots, dtype=np.bool_)
return samples, None, np.zeros(shots, dtype=np.bool_)

# Pick a batch size when the caller did not supply one.
if batch_size is None:
batch_size = self._resolve_batch_size(
shots, batch_size, compute_reference=False
)

reference: np.ndarray | None = None
if compute_reference:
reference = self._compute_reference_sample()

# Rows start as all-False; discarded rows keep that value outside their
# direct detector columns.
result = np.zeros((shots, num_outputs), dtype=np.bool_)
was_discarded = np.zeros(shots, dtype=np.bool_)
# Pending survivors: their f-parameter rows and their row indices in ``result``.
survivor_f_buf: list[np.ndarray] = []
survivor_idx_buf: list[int] = []
shot_idx = 0

# Run one JAX batch and write its first ``n_valid`` rows to ``result[indices]``.
def _dispatch(f_batch: np.ndarray, indices: list[int], n_valid: int) -> None:
self._key, subkey = jax.random.split(self._key)
jax_out = np.asarray(
sample_program(self._program, jnp.asarray(f_batch), subkey)
)
result[indices[:n_valid]] = jax_out[:n_valid]

# Dispatch every full batch currently buffered; with ``final=True`` also
# dispatch the remaining partial batch.
def _flush(*, final: bool = False) -> None:
nonlocal survivor_f_buf, survivor_idx_buf
while len(survivor_f_buf) >= batch_size:
_dispatch(
np.stack(survivor_f_buf[:batch_size]),
survivor_idx_buf[:batch_size],
batch_size,
)
survivor_f_buf = survivor_f_buf[batch_size:]
survivor_idx_buf = survivor_idx_buf[batch_size:]

# Pad the remainder up to ``batch_size`` with copies of the first survivor
# row; the padded outputs are dropped because only ``n_valid`` rows are
# written back.
if final and survivor_f_buf:
n_valid = len(survivor_f_buf)
f_stack = np.stack(survivor_f_buf)
f_batch = np.empty((batch_size, f_stack.shape[1]), dtype=f_stack.dtype)
f_batch[:n_valid] = f_stack
f_batch[n_valid:] = f_stack[0]
_dispatch(f_batch, survivor_idx_buf, n_valid)
survivor_f_buf = []
survivor_idx_buf = []

# Draw noise in chunks and classify each shot using only direct detector bits.
while shot_idx < shots:
chunk = min(batch_size, shots - shot_idx)
f_params_np = self._channel_sampler.sample(chunk)
direct_full = self._compute_direct_outputs(f_params_np)
det_cols = direct_full[:, : self._num_detectors]
# The discard decision is made on reference-XORed detector values.
if xor_detector_ref and reference is not None:
det_cols = det_cols ^ reference[: self._num_detectors]

discarded = (det_cols & postselect_direct).any(axis=1)

# Store the raw (not yet reference-XORed) direct detector bits; rows of
# survivors are overwritten by the JAX output when dispatched.
result[shot_idx : shot_idx + chunk, : self._num_detectors] = direct_full[
:, : self._num_detectors
]
was_discarded[shot_idx : shot_idx + chunk] = discarded

survivor_local = np.flatnonzero(~discarded)
if survivor_local.size:
survivor_f_buf.extend(f_params_np[survivor_local])
survivor_idx_buf.extend((shot_idx + survivor_local).tolist())

shot_idx += chunk
_flush()

_flush(final=True)

# Apply the detector reference: every detector column on survivors, only the
# direct detector columns on discarded rows.
if xor_detector_ref and reference is not None:
det_ref = reference[: self._num_detectors]
survivors = ~was_discarded
result[survivors, : self._num_detectors] ^= det_ref
result[was_discarded, : self._num_detectors] ^= (
det_ref & self._direct_detector_mask
)

if compute_reference:
assert reference is not None
return result, reference, was_discarded
return result, None, was_discarded

def _sample_direct(self, shots: int) -> np.ndarray:
"""Fast path when all components are direct (pure numpy, no JAX)."""
f_params = self._channel_sampler.sample(shots)
Expand Down Expand Up @@ -509,6 +703,7 @@ def sample(
bit_packed: bool = False,
use_detector_reference_sample: bool = False,
use_observable_reference_sample: bool = False,
postselection_mask: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]: ...

@overload
Expand All @@ -523,6 +718,7 @@ def sample(
bit_packed: bool = False,
use_detector_reference_sample: bool = False,
use_observable_reference_sample: bool = False,
postselection_mask: np.ndarray | None = None,
) -> np.ndarray: ...

def sample(
Expand All @@ -536,6 +732,7 @@ def sample(
bit_packed: bool = False,
use_detector_reference_sample: bool = False,
use_observable_reference_sample: bool = False,
postselection_mask: np.ndarray | None = None,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Return detector samples from the circuit.

Expand Down Expand Up @@ -568,6 +765,12 @@ def sample(
results represent deviations from the noiseless baseline. This should
only be used when observables are deterministic. Otherwise, it can
unpredictably change the results.
postselection_mask: Optional boolean array of length ``num_detectors``.
When set, shots where any masked direct detector fires skip the JAX
sampling loop. All ``shots`` rows are still returned: survivors contain
the full sample, while discarded rows retain direct detector columns
and fill component columns with ``False``. Re-apply the mask to the
returned detector columns to recover surviving shots.

Returns:
A numpy array or tuple of numpy arrays containing the samples.
Expand All @@ -587,7 +790,44 @@ def sample(
use_detector_reference_sample or use_observable_reference_sample
)

if compute_reference:
if postselection_mask is not None:
mask = np.asarray(postselection_mask, dtype=np.bool_)
if mask.shape != (self._num_detectors,):
raise ValueError(
f"postselection_mask must have shape ({self._num_detectors},), "
f"got {mask.shape}"
)
if postselection_mask is not mask:
postselection_mask = mask
if (
not (postselection_mask & self._direct_detector_mask).any()
or not self._program.components
):
postselection_mask = None

if postselection_mask is not None:
if compute_reference:
samples, reference, direct_discarded = (
self._sample_batches_with_postselection(
shots,
batch_size,
postselection_mask=postselection_mask,
compute_reference=True,
xor_detector_ref=use_detector_reference_sample,
)
)
assert reference is not None
num_detectors = self._num_detectors
if use_observable_reference_sample:
obs_ref = reference[num_detectors:]
samples[~direct_discarded, num_detectors:] ^= obs_ref
else:
samples, _, _ = self._sample_batches_with_postselection(
shots,
batch_size,
postselection_mask=postselection_mask,
)
elif compute_reference:
samples, reference = self._sample_batches(
shots, batch_size, compute_reference=True
)
Expand Down Expand Up @@ -618,7 +858,6 @@ def sample(
)

return _maybe_bit_pack(det_samples, bit_packed=bit_packed)
# TODO: don't compute observables if they are discarded here


class CompiledStateProbs(_CompiledSamplerBase):
Expand Down
Loading