diff --git a/CHANGELOG.md b/CHANGELOG.md index 125d4f9..918458f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Fast path in the detector sampler for components whose output is deterministically given by a single error variable. These components now skip the JAX compilation and autoregressive sampling pipeline, significantly speeding up detector sampling for surface-code circuits at low physical error rates. +- `CompiledDetectorSampler.sample` now accepts an optional `postselection_mask` argument for postselected simulations. When masked direct detectors fire, expensive JAX sampling is skipped for those shots while still returning one row per requested shot (#41). diff --git a/src/tsim/sampler.py b/src/tsim/sampler.py index bfe4e74..06b3d49 100644 --- a/src/tsim/sampler.py +++ b/src/tsim/sampler.py @@ -224,6 +224,51 @@ def __init__( and not self._direct_flips.any() and np.array_equal(self._direct_f_indices, np.arange(n_direct)) ) + direct_output_indices = np.asarray(prog.output_order[:n_direct], dtype=np.int32) + self._direct_output_indices = direct_output_indices + self._direct_detector_indices = direct_output_indices[ + direct_output_indices < self._num_detectors + ] + + def _scatter_direct_bits(self, f_params: np.ndarray) -> np.ndarray: + """Scatter direct output bits into global output order.""" + batch_size = f_params.shape[0] + num_outputs = self._program.num_outputs + n_direct = len(self._direct_f_indices) + if n_direct == 0: + return np.zeros((batch_size, num_outputs), dtype=np.bool_) + + if self._direct_zero_copy: + direct_concat = f_params[:, :n_direct].view(np.bool_) + else: + direct_concat = ( + f_params[:, self._direct_f_indices] ^ self._direct_flips + ).view(np.bool_) + + result = np.zeros((batch_size, num_outputs), dtype=np.bool_) + result[:, self._direct_output_indices] = direct_concat + return result + + def _compute_reference_sample(self) -> 
np.ndarray: + """Return the noiseless reference sample (f_params=0).""" + if not self._program.components: + num_f = self._channel_sampler.signature_matrix.shape[1] + f_params = np.zeros((1, num_f), dtype=np.uint8) + return self._scatter_direct_bits(f_params)[0] + + f_params_np = self._channel_sampler.sample(1) + f_params_np[0] = 0 + self._key, subkey = jax.random.split(self._key) + samples = sample_program(self._program, jnp.asarray(f_params_np), subkey) + return np.asarray(samples[0], dtype=np.bool_) + + def _resolve_batch_size(self, shots: int, batch_size: int | None) -> int: + """Resolve an effective JAX batch size for postselection sampling.""" + if batch_size is not None: + return batch_size + max_batch_size = self._estimate_batch_size() + num_batches = max(1, ceil(shots / max_batch_size)) + return ceil(shots / num_batches) def _peak_bytes_per_sample(self) -> int: """Estimate peak device memory per sample from compiled program structure.""" @@ -352,6 +397,102 @@ def _sample_direct(self, shots: int) -> np.ndarray: result = result[:, self._direct_reindex] return result.view(np.bool_) + def _sample_batches_postselection( + self, + shots: int, + batch_size: int | None, + postselection_mask: np.ndarray, + *, + detector_reference: np.ndarray | None = None, + ) -> np.ndarray: + """Sample with postselection, skipping JAX for directly discarded shots.""" + if shots < 0: + raise ValueError(f"shots must be non-negative, got {shots}") + if batch_size is not None and batch_size < 1: + raise ValueError(f"batch_size must be at least 1, got {batch_size}") + + postselection_mask = np.asarray(postselection_mask, dtype=np.bool_) + num_detectors = self._num_detectors + if postselection_mask.shape != (num_detectors,): + raise ValueError( + "postselection_mask must have shape " + f"({num_detectors},), got {postselection_mask.shape}" + ) + + if shots == 0: + return np.empty((0, self._program.num_outputs), dtype=np.bool_) + + if not self._program.components: + samples = 
self._sample_direct(shots) + if detector_reference is not None: + samples[:, :num_detectors] ^= detector_reference + return samples + + batch_size = self._resolve_batch_size(shots, batch_size) + direct_postselect_mask = np.zeros(num_detectors, dtype=np.bool_) + direct_postselect_mask[self._direct_detector_indices] = postselection_mask[ + self._direct_detector_indices + ] + + result = np.zeros((shots, self._program.num_outputs), dtype=np.bool_) + survivor_f_params: list[np.ndarray] = [] + survivor_indices: list[int] = [] + + def flush_survivors(*, pad: bool) -> None: + nonlocal survivor_f_params, survivor_indices + if not survivor_f_params: + return + + n_survivors = len(survivor_f_params) + if pad and n_survivors < batch_size: + pad_count = batch_size - n_survivors + f_batch = np.stack( + survivor_f_params + [survivor_f_params[-1]] * pad_count + ) + else: + f_batch = np.stack(survivor_f_params) + + self._key, subkey = jax.random.split(self._key) + jax_samples = np.asarray( + sample_program(self._program, jnp.asarray(f_batch), subkey), + dtype=np.bool_, + ) + for i, shot_idx in enumerate(survivor_indices): + result[shot_idx] = jax_samples[i] + + survivor_f_params = [] + survivor_indices = [] + + fetch_size = max(batch_size, 1) + shot_idx = 0 + while shot_idx < shots: + chunk_size = min(fetch_size, shots - shot_idx) + f_params_chunk = self._channel_sampler.sample(chunk_size) + direct_chunk = self._scatter_direct_bits(f_params_chunk) + + for local_idx in range(chunk_size): + global_idx = shot_idx + local_idx + result[global_idx] = direct_chunk[local_idx] + + det_bits = result[global_idx, :num_detectors].copy() + if detector_reference is not None: + det_bits ^= detector_reference + + if np.any(det_bits & direct_postselect_mask): + continue + + survivor_f_params.append(f_params_chunk[local_idx]) + survivor_indices.append(global_idx) + if len(survivor_f_params) == batch_size: + flush_survivors(pad=False) + + shot_idx += chunk_size + + flush_survivors(pad=True) + if 
detector_reference is not None: + result[:, :num_detectors] ^= detector_reference + return result + def __repr__(self) -> str: """Return a string representation with compilation statistics.""" n_direct = len(self._program.direct_f_indices) @@ -509,6 +650,7 @@ def sample( bit_packed: bool = False, use_detector_reference_sample: bool = False, use_observable_reference_sample: bool = False, + postselection_mask: np.ndarray | None = None, ) -> tuple[np.ndarray, np.ndarray]: ... @overload @@ -523,6 +665,7 @@ def sample( bit_packed: bool = False, use_detector_reference_sample: bool = False, use_observable_reference_sample: bool = False, + postselection_mask: np.ndarray | None = None, ) -> np.ndarray: ... def sample( @@ -536,6 +679,7 @@ def sample( bit_packed: bool = False, use_detector_reference_sample: bool = False, use_observable_reference_sample: bool = False, + postselection_mask: np.ndarray | None = None, ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: """Return detector samples from the circuit. @@ -568,6 +712,11 @@ def sample( results represent deviations from the noiseless baseline. This should only be used when observables are deterministic. Otherwise, it can unpredictably change the results. + postselection_mask: Optional boolean array of length ``num_detectors``. + When set, shots where any masked detector fires may skip expensive + JAX sampling for direct postselected detectors. All ``shots`` rows + are still returned; filter survivors by re-applying the mask to the + detector columns. Returns: A numpy array or tuple of numpy arrays containing the samples. 
"""Unit tests for the ``postselection_mask`` argument of detector sampling."""

import numpy as np
import pytest
import stim

import tsim.sampler as sampler_module
from tsim.circuit import Circuit

MIXED_DIRECT_CIRCUIT = """
X_ERROR(0.1) 0
M 0
DETECTOR rec[-1]
DETECTOR rec[-1] rec[-1]
"""

FULLY_DIRECT_CIRCUIT = """
X_ERROR(0.5) 0
M 0
DETECTOR rec[-1]
"""

ALWAYS_DISCARD_CIRCUIT = """
X_ERROR(1) 0
M 0
DETECTOR rec[-1]
DETECTOR rec[-1] rec[-1]
"""

# Two detectors plus one observable; shared by the layout and
# observable-reference tests.
OBSERVABLE_CIRCUIT = """
R 0 1 2
X 2
M 0 1 2
DETECTOR rec[-2]
DETECTOR rec[-3]
OBSERVABLE_INCLUDE(0) rec[-1]
"""


def _mixed_sampler(seed: int = 0):
    """Compile a detector sampler for ``MIXED_DIRECT_CIRCUIT``."""
    return Circuit(MIXED_DIRECT_CIRCUIT).compile_detector_sampler(seed=seed)


def _survivor_mask(samples: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return True for rows that survive postselection on detector columns.

    Only the first ``len(mask)`` columns of ``samples`` are inspected, so rows
    with appended observable columns can be passed as well.
    """
    mask = np.asarray(mask, dtype=np.bool_)
    detector_columns = samples[:, : mask.shape[0]]
    return ~np.any(detector_columns & mask, axis=1)


def _record_jax_batch_sizes(monkeypatch) -> list[int]:
    """Wrap ``sampler_module.sample_program`` and record each call's row count.

    Returns:
        A list that receives ``f_params.shape[0]`` for every call made after
        this helper returns.
    """
    batch_sizes: list[int] = []
    original = sampler_module.sample_program

    def counting_sample_program(program, f_params, key):
        batch_sizes.append(int(f_params.shape[0]))
        return original(program, f_params, key)

    monkeypatch.setattr(sampler_module, "sample_program", counting_sample_program)
    return batch_sizes


def test_postselection_mask_invalid_shape_raises():
    sampler = _mixed_sampler()
    with pytest.raises(ValueError, match="postselection_mask must have shape"):
        sampler.sample(1, postselection_mask=np.array([True, False, True]))


def test_postselection_preserves_row_count():
    sampler = _mixed_sampler(seed=1)
    mask = np.array([True, False])
    samples = sampler.sample(20, postselection_mask=mask, batch_size=4)
    assert samples.shape == (20, 2)


def test_postselection_discarded_rows_have_false_non_direct_columns():
    sampler = _mixed_sampler(seed=2)
    mask = np.array([True, False])
    samples = sampler.sample(100, postselection_mask=mask, batch_size=8)

    # Detector 0 is direct; detector 1 is a compiled component.
    discarded = samples[:, 0] & mask[0]
    assert np.any(discarded)
    assert np.all(~samples[discarded, 1])


def test_postselection_survivors_match_unpostselected_sampling():
    seed = 7
    mask = np.array([True, False])
    with_post = _mixed_sampler(seed=seed).sample(
        200, postselection_mask=mask, batch_size=16
    )
    without_post = _mixed_sampler(seed=seed).sample(200, batch_size=16)

    survivors = _survivor_mask(with_post, mask)
    assert np.array_equal(with_post[survivors], without_post[survivors])


def test_postselection_none_matches_default():
    seed = 3
    default = _mixed_sampler(seed=seed).sample(10, batch_size=4)
    explicit_none = _mixed_sampler(seed=seed).sample(
        10, batch_size=4, postselection_mask=None
    )
    assert np.array_equal(default, explicit_none)


def test_postselection_fully_direct_matches_unpostselected_sampling():
    seed = 11
    mask = np.array([True])
    sampler = Circuit(FULLY_DIRECT_CIRCUIT).compile_detector_sampler(seed=seed)
    samples = sampler.sample(500, postselection_mask=mask)

    sampler2 = Circuit(FULLY_DIRECT_CIRCUIT).compile_detector_sampler(seed=seed)
    reference = sampler2.sample(500)

    assert np.array_equal(samples, reference)


def test_postselection_skips_jax_for_direct_discards(monkeypatch):
    sampler = Circuit(ALWAYS_DISCARD_CIRCUIT).compile_detector_sampler(seed=0)
    mask = np.array([True, False])
    calls = _record_jax_batch_sizes(monkeypatch)

    samples = sampler.sample(10, postselection_mask=mask, batch_size=4)
    assert np.all(samples[:, 0])
    assert np.all(~samples[:, 1])
    assert calls == []


def test_postselection_with_detector_reference_sample():
    seed = 0
    mask = np.array([True, False])
    kwargs = {
        "use_detector_reference_sample": True,
        "batch_size": 4,
    }

    with_post = _mixed_sampler(seed=seed).sample(40, postselection_mask=mask, **kwargs)
    survivors = _survivor_mask(with_post, mask)
    assert survivors.any()
    assert not np.any(with_post[survivors] & mask)

    discarded = ~survivors
    assert discarded.any()
    assert np.all(with_post[discarded, 0])
    assert np.all(~with_post[discarded, 1])


def test_postselection_with_detector_reference_fully_direct():
    mask = np.array([True])
    circuit = Circuit(FULLY_DIRECT_CIRCUIT)
    samples = circuit.compile_detector_sampler(seed=2).sample(
        50,
        postselection_mask=mask,
        use_detector_reference_sample=True,
    )
    survivors = _survivor_mask(samples, mask)
    assert survivors.any()
    assert not np.any(samples[survivors] & mask)


def test_postselection_with_observable_reference_sample():
    c = Circuit(OBSERVABLE_CIRCUIT)
    mask = np.array([True, False])
    kwargs = {
        "separate_observables": True,
        "use_observable_reference_sample": True,
        "postselection_mask": mask,
        "batch_size": 4,
    }

    # Same seed and arguments must reproduce the same detectors and observables.
    dets1, obs1 = c.compile_detector_sampler(seed=3).sample(20, **kwargs)
    dets2, obs2 = c.compile_detector_sampler(seed=3).sample(20, **kwargs)
    assert np.array_equal(dets1, dets2)
    assert np.array_equal(obs1, obs2)


def test_postselection_zero_shots():
    sampler = _mixed_sampler()
    mask = np.array([True, False])
    result = sampler.sample(0, postselection_mask=mask)
    assert result.shape == (0, 2)

    dets, obs = sampler.sample(0, postselection_mask=mask, separate_observables=True)
    assert dets.shape == (0, 2)
    assert obs.shape == (0, 0)


def test_postselection_negative_shots_raises():
    sampler = _mixed_sampler()
    with pytest.raises(ValueError, match="shots must be non-negative"):
        sampler.sample(-1, postselection_mask=np.array([True, False]))


def test_postselection_invalid_batch_size_raises():
    sampler = _mixed_sampler()
    with pytest.raises(ValueError, match="batch_size must be at least 1"):
        sampler.sample(1, batch_size=0, postselection_mask=np.array([True, False]))


def test_postselection_respects_output_layout_flags():
    sampler = Circuit(OBSERVABLE_CIRCUIT).compile_detector_sampler(seed=0)
    mask = np.array([True, False])

    appended = sampler.sample(2, append_observables=True, postselection_mask=mask)
    assert appended.shape == (2, 3)

    prepended = sampler.sample(2, prepend_observables=True, postselection_mask=mask)
    assert prepended.shape == (2, 3)

    dets, obs = sampler.sample(2, separate_observables=True, postselection_mask=mask)
    assert dets.shape == (2, 2)
    assert obs.shape == (2, 1)


def test_postselection_non_direct_mask_does_not_skip_jax(monkeypatch):
    """Postselection on a non-direct detector still runs JAX for every shot."""
    sampler = _mixed_sampler(seed=9)
    mask = np.array([False, True])
    jax_rows = _record_jax_batch_sizes(monkeypatch)

    sampler.sample(16, postselection_mask=mask, batch_size=8)
    assert sum(jax_rows) == 16


def test_postselection_batch_padding(monkeypatch):
    """Leftover survivors are padded to a full JAX batch and padding is discarded."""
    sampler = _mixed_sampler(seed=4)
    mask = np.array([True, False])
    seen_batch_sizes = _record_jax_batch_sizes(monkeypatch)

    sampler.sample(10, postselection_mask=mask, batch_size=4)
    assert seen_batch_sizes
    assert all(batch_size == 4 for batch_size in seen_batch_sizes)


def test_postselection_surface_code_fully_direct():
    """Typical QEC circuits are fully direct; postselection must not change samples."""
    circ = stim.Circuit.generated(
        "surface_code:rotated_memory_x",
        distance=3,
        rounds=2,
        after_clifford_depolarization=0.01,
    )
    c = Circuit.from_stim_program(circ)
    mask = np.zeros(c.num_detectors, dtype=np.bool_)
    mask[0] = True

    sampler = c.compile_detector_sampler(seed=0)
    with_mask = sampler.sample(100, postselection_mask=mask, batch_size=16)
    without_mask = c.compile_detector_sampler(seed=0).sample(100, batch_size=16)
    assert np.array_equal(with_mask, without_mask)


def test_postselection_caller_can_filter_survivors():
    sampler = _mixed_sampler(seed=6)
    mask = np.array([True, False])
    samples = sampler.sample(100, postselection_mask=mask, batch_size=8)

    survivors = _survivor_mask(samples, mask)
    assert survivors.any()
    assert not np.any(samples[survivors] & mask)
    assert np.any(samples[~survivors] & mask)