Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Fast path in the detector sampler for components whose output is deterministically given by a single error variable. These components now skip the JAX compilation and autoregressive sampling pipeline, significantly speeding up detector sampling for surface-code circuits at low physical error rates.
- `CompiledDetectorSampler.sample` now accepts an optional `postselection_mask` argument for postselected simulations. Shots in which a masked detector handled by the direct fast path fires skip the expensive JAX sampling. One row is still returned per requested shot, so callers filter out the discarded shots by re-applying the mask to the detector columns (#41).



Expand Down
172 changes: 169 additions & 3 deletions src/tsim/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,51 @@ def __init__(
and not self._direct_flips.any()
and np.array_equal(self._direct_f_indices, np.arange(n_direct))
)
direct_output_indices = np.asarray(prog.output_order[:n_direct], dtype=np.int32)
self._direct_output_indices = direct_output_indices
self._direct_detector_indices = direct_output_indices[
direct_output_indices < self._num_detectors
]

def _scatter_direct_bits(self, f_params: np.ndarray) -> np.ndarray:
    """Place the direct output bits at their positions in global output order.

    Args:
        f_params: Array with one row per shot; its columns are indexed by
            ``self._direct_f_indices``.
            NOTE(review): the ``.view(np.bool_)`` below presumably relies on
            one-byte 0/1 elements (e.g. ``uint8``) - confirm against the
            channel sampler.

    Returns:
        Boolean array of shape ``(f_params.shape[0], num_outputs)``. Columns
        listed in ``self._direct_output_indices`` carry the direct bits; every
        other column is ``False``.
    """
    direct_count = len(self._direct_f_indices)
    scattered = np.zeros(
        (f_params.shape[0], self._program.num_outputs), dtype=np.bool_
    )
    if direct_count == 0:
        # No direct outputs: nothing to scatter.
        return scattered

    if self._direct_zero_copy:
        # The direct bits are the leading columns, in order and unflipped.
        raw_bits = f_params[:, :direct_count]
    else:
        # Gather the relevant columns and apply the per-output flips.
        raw_bits = f_params[:, self._direct_f_indices] ^ self._direct_flips

    scattered[:, self._direct_output_indices] = raw_bits.view(np.bool_)
    return scattered

def _compute_reference_sample(self) -> np.ndarray:
    """Return the noiseless reference sample, i.e. the outputs for f_params=0.

    Returns:
        One-dimensional array holding one entry per program output.
    """
    if self._program.components:
        # Draw a single row only to obtain an array of the right shape and
        # dtype, then zero it so that no error variable is set.
        zero_params = self._channel_sampler.sample(1)
        zero_params[0] = 0
        self._key, subkey = jax.random.split(self._key)
        reference = sample_program(
            self._program, jnp.asarray(zero_params), subkey
        )
        return np.asarray(reference[0], dtype=np.bool_)

    # Without compiled components every output is a direct bit, so the
    # reference is just the scatter of an all-zero parameter row.
    width = self._channel_sampler.signature_matrix.shape[1]
    zero_row = np.zeros((1, width), dtype=np.uint8)
    return self._scatter_direct_bits(zero_row)[0]

def _resolve_batch_size(self, shots: int, batch_size: int | None) -> int:
"""Resolve an effective JAX batch size for postselection sampling."""
if batch_size is not None:
return batch_size
max_batch_size = self._estimate_batch_size()
num_batches = max(1, ceil(shots / max_batch_size))
return ceil(shots / num_batches)

def _peak_bytes_per_sample(self) -> int:
"""Estimate peak device memory per sample from compiled program structure."""
Expand Down Expand Up @@ -352,6 +397,102 @@ def _sample_direct(self, shots: int) -> np.ndarray:
result = result[:, self._direct_reindex]
return result.view(np.bool_)

def _sample_batches_postselection(
self,
shots: int,
batch_size: int | None,
postselection_mask: np.ndarray,
*,
detector_reference: np.ndarray | None = None,
) -> np.ndarray:
"""Sample with postselection, skipping JAX for directly discarded shots."""
if shots < 0:
raise ValueError(f"shots must be non-negative, got {shots}")
if batch_size is not None and batch_size < 1:
raise ValueError(f"batch_size must be at least 1, got {batch_size}")

postselection_mask = np.asarray(postselection_mask, dtype=np.bool_)
num_detectors = self._num_detectors
if postselection_mask.shape != (num_detectors,):
raise ValueError(
"postselection_mask must have shape "
f"({num_detectors},), got {postselection_mask.shape}"
)

if shots == 0:
return np.empty((0, self._program.num_outputs), dtype=np.bool_)

if not self._program.components:
samples = self._sample_direct(shots)
if detector_reference is not None:
samples[:, :num_detectors] ^= detector_reference
return samples

batch_size = self._resolve_batch_size(shots, batch_size)
direct_postselect_mask = np.zeros(num_detectors, dtype=np.bool_)
direct_postselect_mask[self._direct_detector_indices] = postselection_mask[
self._direct_detector_indices
]

result = np.zeros((shots, self._program.num_outputs), dtype=np.bool_)
survivor_f_params: list[np.ndarray] = []
survivor_indices: list[int] = []

def flush_survivors(*, pad: bool) -> None:
nonlocal survivor_f_params, survivor_indices
if not survivor_f_params:
return

n_survivors = len(survivor_f_params)
if pad and n_survivors < batch_size:
pad_count = batch_size - n_survivors
f_batch = np.stack(
survivor_f_params + [survivor_f_params[-1]] * pad_count
)
else:
f_batch = np.stack(survivor_f_params)

self._key, subkey = jax.random.split(self._key)
jax_samples = np.asarray(
sample_program(self._program, jnp.asarray(f_batch), subkey),
dtype=np.bool_,
)
for i, shot_idx in enumerate(survivor_indices):
result[shot_idx] = jax_samples[i]

survivor_f_params = []
survivor_indices = []

fetch_size = max(batch_size, 1)
shot_idx = 0
while shot_idx < shots:
chunk_size = min(fetch_size, shots - shot_idx)
f_params_chunk = self._channel_sampler.sample(chunk_size)
direct_chunk = self._scatter_direct_bits(f_params_chunk)

for local_idx in range(chunk_size):
global_idx = shot_idx + local_idx
result[global_idx] = direct_chunk[local_idx]

det_bits = result[global_idx, :num_detectors].copy()
if detector_reference is not None:
det_bits ^= detector_reference

if np.any(det_bits & direct_postselect_mask):
continue

survivor_f_params.append(f_params_chunk[local_idx])
survivor_indices.append(global_idx)
if len(survivor_f_params) == batch_size:
flush_survivors(pad=False)

shot_idx += chunk_size

flush_survivors(pad=True)
if detector_reference is not None:
result[:, :num_detectors] ^= detector_reference
return result

def __repr__(self) -> str:
"""Return a string representation with compilation statistics."""
n_direct = len(self._program.direct_f_indices)
Expand Down Expand Up @@ -509,6 +650,7 @@ def sample(
bit_packed: bool = False,
use_detector_reference_sample: bool = False,
use_observable_reference_sample: bool = False,
postselection_mask: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]: ...

@overload
Expand All @@ -523,6 +665,7 @@ def sample(
bit_packed: bool = False,
use_detector_reference_sample: bool = False,
use_observable_reference_sample: bool = False,
postselection_mask: np.ndarray | None = None,
) -> np.ndarray: ...

def sample(
Expand All @@ -536,6 +679,7 @@ def sample(
bit_packed: bool = False,
use_detector_reference_sample: bool = False,
use_observable_reference_sample: bool = False,
postselection_mask: np.ndarray | None = None,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Return detector samples from the circuit.

Expand Down Expand Up @@ -568,6 +712,11 @@ def sample(
results represent deviations from the noiseless baseline. This should
only be used when observables are deterministic. Otherwise, it can
unpredictably change the results.
postselection_mask: Optional boolean array of length ``num_detectors``.
When set, shots where any masked detector fires may skip expensive
JAX sampling for direct postselected detectors. All ``shots`` rows
are still returned; filter survivors by re-applying the mask to the
detector columns.

Returns:
A numpy array or tuple of numpy arrays containing the samples.
Expand All @@ -586,20 +735,37 @@ def sample(
compute_reference = (
use_detector_reference_sample or use_observable_reference_sample
)
num_detectors = self._num_detectors

if compute_reference:
if postselection_mask is not None:
reference: np.ndarray | None = None
if compute_reference:
reference = self._compute_reference_sample()

detector_reference = (
reference[:num_detectors]
if use_detector_reference_sample and reference is not None
else None
)
samples = self._sample_batches_postselection(
shots,
batch_size,
postselection_mask,
detector_reference=detector_reference,
)
if use_observable_reference_sample and reference is not None:
samples[:, num_detectors:] ^= reference[num_detectors:]
elif compute_reference:
samples, reference = self._sample_batches(
shots, batch_size, compute_reference=True
)
num_detectors = self._num_detectors
if use_detector_reference_sample:
samples[:, :num_detectors] ^= reference[:num_detectors]
if use_observable_reference_sample:
samples[:, num_detectors:] ^= reference[num_detectors:]
else:
samples = self._sample_batches(shots, batch_size)

num_detectors = self._num_detectors
det_samples = samples[:, :num_detectors]
obs_samples = samples[:, num_detectors:]

Expand Down
Loading