Motivation
In tsim's sampling pipeline, detectors often decouple from the observables in the ZX diagram. When they do, they compile to direct components, which are computed classically in NumPy on the CPU (f_params[:, direct_f_indices] ^ direct_flips) before any JAX/GPU routine runs. The expensive part of sampling is the per-component autoregressive JAX loop in sample_program.
This opens an optimization for postselected simulations. If a shot is going to be discarded because a postselected detector fired, and that detector is one of the classically computed direct detectors, we can discard the shot before running any JAX routines for it, skipping all the expensive work for that shot.
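A minimal NumPy sketch of that early-discard decision, assuming f_params is a (shots, num_f) boolean array of sampled error parameters, direct_f_indices holds one f-column per direct detector, and direct_flips is a per-detector boolean constant. The function name and the direct_mask argument (the postselection mask restricted to direct detectors, in direct-detector order) are hypothetical.

```python
import numpy as np


def direct_discard_flags(f_params, direct_f_indices, direct_flips, direct_mask):
    """Compute direct detector bits and per-shot discard flags on the CPU.

    Hypothetical helper. Assumed shapes:
      f_params         (shots, num_f) bool, sampled error parameters
      direct_f_indices (num_direct,)  int, one f-column per direct detector
      direct_flips     (num_direct,)  bool, constant flip per direct detector
      direct_mask      (num_direct,)  bool, which direct detectors are postselected
    """
    # Same expression as in the pipeline: gather columns, then XOR the flips.
    direct_bits = f_params[:, direct_f_indices] ^ direct_flips
    # Discard a shot if any postselected direct detector fired.
    discard = np.any(direct_bits & direct_mask, axis=1)
    return direct_bits, discard
```

Everything here runs on the CPU; only rows where discard is False would ever need to reach the JAX stage.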
Proposed API
A new keyword on CompiledDetectorSampler.sample:
def sample(self, shots, *, ..., postselection_mask: np.ndarray | None = None):
- postselection_mask: a plain boolean array of length num_detectors (detectors only, not observables). A shot is discarded if any masked detector fires.
- Independent of prepend_observables / append_observables / separate_observables, which still act on the returned rows.
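The contract above (a plain boolean array of length num_detectors, observables excluded) could be enforced with a small check at the top of sample. This is a hypothetical helper, not existing tsim code; it rejects non-boolean input rather than casting it, so that an integer index list is not silently mistaken for a mask.

```python
import numpy as np


def check_postselection_mask(mask, num_detectors):
    """Validate a postselection_mask argument (hypothetical helper).

    Returns None when no postselection is requested, otherwise the mask
    as a NumPy boolean array of shape (num_detectors,).
    """
    if mask is None:
        return None
    mask = np.asarray(mask)
    # Reject integer/float input instead of casting: an index list such as
    # [0, 3] would otherwise be misread as a boolean mask.
    if mask.dtype != np.bool_:
        raise TypeError(f"postselection_mask must be boolean, got dtype {mask.dtype}")
    # Detectors only: observables are not part of the mask.
    if mask.shape != (num_detectors,):
        raise ValueError(
            f"postselection_mask must have shape ({num_detectors},), got {mask.shape}"
        )
    return mask
```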
Return semantics
sample(shots, ...) still returns one row per drawn shot (shape stays (shots, num_outputs) modulo the usual observable/bit-packing options). No rows are dropped.
- Survivors (no masked detector fired): the complete, correct sample.
- Discarded via a direct postselected detector: the JAX loop is skipped for that shot. Direct detector columns (computed in NumPy for every shot) hold their actual sampled values, including the fired postselected detector. Columns that were never computed (JAX-component outputs) are filled with False.
The caller recovers the surviving shots by re-applying postselection_mask to the returned detector columns. Because the fired postselected detectors always hold their actual values, this is reliable.
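A sketch of that recovery step on the caller's side, assuming the unpacked (not bit-packed) output layout in which the first num_detectors columns are the detectors, i.e. without prepend_observables; with prepend_observables the slice would need to be offset by the number of observables. The helper name is hypothetical.

```python
import numpy as np


def surviving_rows(samples, postselection_mask):
    """Boolean row index of shots in which no postselected detector fired.

    Hypothetical caller-side helper. Assumes `samples` is unpacked (bool,
    shape (shots, num_outputs)) with the detector columns first.
    """
    num_detectors = postselection_mask.shape[0]
    detectors = samples[:, :num_detectors]
    # Discarded shots always carry a fired postselected detector column,
    # computed either in NumPy (direct) or by the full JAX run (non-direct).
    return ~np.any(detectors & postselection_mask, axis=1)
```

Typical use would be kept = samples[surviving_rows(samples, mask)].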
Note: A postselected detector that is non-direct (it lives inside a JAX component) cannot be evaluated without running JAX, so any shot not already discarded by a direct detector is computed in full regardless. Only direct postselected detectors trigger the JAX-skip / partial-row behavior. Correctness is preserved either way; the speedup only materializes when postselected detectors are direct. This is not the optimal approach, but it is far simpler to implement (no GPU load balancing is required). Since most detectors in QEC circuits compile to direct components, this should be efficient in practice.
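Splitting the mask into the part that can skip JAX and the part that cannot might look like this, assuming a hypothetical direct_detector_indices array listing which detector ids compiled to direct components. The direct part is returned in direct-detector order so it can be combined with NumPy-computed direct bits; the component part stays full-length and only marks detectors whose values are known after JAX runs.

```python
import numpy as np


def partition_mask(postselection_mask, direct_detector_indices):
    """Split a detector-level postselection mask (hypothetical helper).

    Returns:
      direct_part    (num_direct,) bool, in direct-detector order; these
                     detectors are checked in NumPy and can skip JAX.
      component_part (num_detectors,) bool; postselected detectors that live
                     inside JAX components and are only known after JAX runs.
    """
    is_direct = np.zeros(postselection_mask.shape[0], dtype=bool)
    is_direct[direct_detector_indices] = True
    direct_part = postselection_mask[direct_detector_indices]
    component_part = postselection_mask & ~is_direct
    return direct_part, component_part
```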
Implementation sketch
In _sample_batches / sample_program:
- Sample error params in NumPy and compute the direct detector bits for every shot (already cheap, CPU-only).
- Partition postselection_mask into its direct and component parts. Apply the direct part to flag discarded shots.
- Feed only the surviving shots into the JAX sample_program. To keep the JAX batch size fixed (avoiding recompilation), the NumPy side accumulates survivors (their error params and original shot indices) in a buffer and dispatches a JAX call only once it has a full batch of B; the final leftover chunk is padded up to B and the padding rows are discarded. (Fully direct circuits skip JAX entirely.)
- Scatter JAX results back into their original shot rows; discarded shots keep their JAX-output columns as False.
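The fixed-batch dispatch and scatter steps can be sketched in plain NumPy. This simplified version assumes all shots' error params are already in memory, so slicing the survivor indices into chunks of B plays the role of the accumulation buffer. run_jax_batch is a hypothetical stand-in for sample_program that maps a (B, num_f) array to (B, num_jax_outputs) booleans; zero-valued padding rows are an arbitrary choice since their outputs are dropped.

```python
import numpy as np


def run_survivors_in_fixed_batches(
    f_params, discard, run_jax_batch, num_jax_outputs, batch_size
):
    """Run the JAX stage on surviving shots only, always with batch_size rows.

    Simplified sketch: all error params are already in memory, so chunking
    the survivor indices stands in for the accumulation buffer.
    run_jax_batch is a hypothetical stand-in for sample_program mapping a
    (batch_size, num_f) array to a (batch_size, num_jax_outputs) bool array.
    """
    shots = f_params.shape[0]
    # Discarded shots keep their JAX-output columns as False.
    jax_out = np.zeros((shots, num_jax_outputs), dtype=bool)
    survivors = np.flatnonzero(~discard)
    for start in range(0, survivors.size, batch_size):
        idx = survivors[start:start + batch_size]
        chunk = f_params[idx]
        n = idx.size
        if n < batch_size:
            # Pad the final chunk so JAX always sees the same shape (no recompile).
            pad = np.zeros((batch_size - n,) + chunk.shape[1:], dtype=chunk.dtype)
            chunk = np.concatenate([chunk, pad], axis=0)
        out = np.asarray(run_jax_batch(chunk))
        # Scatter real rows back to their original shot positions; drop padding.
        jax_out[idx] = out[:n]
    return jax_out
```

If every shot is discarded, the loop body never runs and no JAX call is made.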
Notes / caveats
- Determinism: results are already only reproducible at a fixed batch size; postselection doesn't change the number of draws (shots), so this is unaffected beyond the existing caveat.
- Reference samples: for direct detectors the noiseless reference is a constant, so use_detector_reference_sample can be applied in NumPy before the postselection check, keeping behavior consistent.
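A sketch of that ordering, assuming use_detector_reference_sample amounts to XORing each detector with its noiseless reference value and that the direct detectors' references are available as a constant boolean vector (direct_reference, hypothetical). Folding the reference in first means the postselection check sees the same bits the caller later sees in the returned rows.

```python
import numpy as np


def postselect_direct_with_reference(direct_bits, direct_reference, direct_mask, use_reference):
    """Apply the constant reference to direct bits, then the postselection check.

    Hypothetical helper; assumes the reference acts as a per-detector XOR.
    """
    if use_reference:
        # One broadcast XOR on the CPU: the reference is constant across shots.
        direct_bits = direct_bits ^ direct_reference
    # Check postselection on the same bits that will be returned to the caller.
    discard = np.any(direct_bits & direct_mask, axis=1)
    return direct_bits, discard
```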
Out of scope
- Observable postselection (detectors only).
- Postselection on CompiledMeasurementSampler.