feat: Support postselection masks #41

@rafaelha

Motivation

In tsim's sampling pipeline, detectors often decouple from the observables in the ZX diagram. Such detectors compile to direct components: they are computed classically in NumPy (f_params[:, direct_f_indices] ^ direct_flips) on the CPU, before any JAX/GPU routine runs. The expensive part of sampling is the per-component autoregressive JAX loop in sample_program.

This opens an optimization for postselected simulations. If a shot is going to be discarded because a postselected detector fired, and that detector is one of the classically computed direct detectors, then we can discard the shot before running any JAX routines for it, skipping all the expensive work for that shot.

Proposed API

A new keyword on CompiledDetectorSampler.sample:

def sample(self, shots, *, ..., postselection_mask: np.ndarray | None = None):
  • postselection_mask: a plain boolean array of length num_detectors (detectors only, not observables). A shot is discarded if any masked detector fires.
  • Independent of prepend_observables / append_observables / separate_observables — those still act on the returned rows.
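
The discard rule can be illustrated with plain NumPy. This is a minimal sketch on made-up data, not tsim output; the array name detectors and its values are assumptions chosen for illustration.

```python
import numpy as np

# Hypothetical toy data: 4 shots, 3 detectors (not produced by tsim).
detectors = np.array(
    [
        [0, 0, 0],  # nothing fired
        [1, 0, 0],  # masked detector 0 fired
        [0, 1, 0],  # only the unmasked detector 1 fired
        [0, 0, 1],  # masked detector 2 fired
    ],
    dtype=bool,
)

# One boolean entry per detector; True marks a postselected detector.
postselection_mask = np.array([True, False, True])

# A shot is discarded if any masked detector fired.
discarded = (detectors & postselection_mask).any(axis=1)
print(discarded.tolist())  # [False, True, False, True]
```

The mask broadcasts across rows, so the check is a single vectorized pass over the detector bits.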

Return semantics

sample(shots, ...) still returns one row per drawn shot (shape stays (shots, num_outputs) modulo the usual observable/bit-packing options). No rows are dropped.

  • Survivors (no masked detector fired): the complete, correct sample.
  • Discarded via a direct postselected detector: the JAX loop is skipped for that shot. Direct detector columns (computed in NumPy for every shot) hold real values, including the fired postselected detector. The columns that were never computed (JAX-component outputs) are filled with False.

The caller recovers the surviving shots by re-applying postselection_mask to the returned detector columns; the fired postselected detectors are always real, so this is reliable.
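
A caller-side filter might look like the sketch below. The helper keep_survivors is hypothetical (it is not part of tsim), and it assumes the array passed in holds detector columns only; if observables are prepended or appended, the detector columns would need to be sliced out first.

```python
import numpy as np


def keep_survivors(detectors: np.ndarray, postselection_mask: np.ndarray) -> np.ndarray:
    """Return only the rows in which no masked detector fired.

    Hypothetical helper; assumes `detectors` has shape (shots, num_detectors)
    and contains detector columns only.
    """
    fired = detectors[:, postselection_mask].any(axis=1)
    return detectors[~fired]


# Made-up rows standing in for the output of sample(..., postselection_mask=...).
detectors = np.array(
    [
        [0, 1, 0],  # survivor: detector 1 is not postselected
        [1, 0, 0],  # discarded: postselected detector 0 fired
        [0, 0, 0],  # survivor
    ],
    dtype=bool,
)
postselection_mask = np.array([True, False, False])

survivors = keep_survivors(detectors, postselection_mask)
print(survivors.shape)  # (2, 3)
```

Because the fired postselected columns always carry real values, the False-filled columns of a discarded row never cause it to be mistaken for a survivor.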

Note: A postselected detector that is non-direct (it lives inside a JAX component) cannot be evaluated without running JAX, so those shots are computed in full regardless. Only direct postselected detectors trigger the JAX-skip / partial-row behavior. Correctness is preserved either way; the speedup only materializes when postselected detectors are direct. This approach is not optimal, but it is far simpler to implement (for example, no GPU load balancing is required). Since most detectors in QEC circuits are direct components, it should still be efficient in practice.

Implementation sketch

In _sample_batches / sample_program:

  1. Sample error params in NumPy and compute the direct detector bits for every shot (already cheap, CPU-only).
  2. Partition postselection_mask into its direct and component parts. Apply the direct part to flag discarded shots.
  3. Feed only the surviving shots into the JAX sample_program. To keep the JAX batch size fixed (avoiding recompilation), the NumPy side accumulates survivors in a buffer and dispatches a JAX call only once it has a full batch of B; the final leftover chunk is padded up to B and the padding discarded. (Fully-direct circuits skip JAX entirely.)
  4. Scatter JAX results back into their original shot rows; discarded shots keep their JAX-output columns as False.
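
The steps above can be sketched in plain NumPy as follows. Everything here is an assumption made for illustration: run_component_stub stands in for the JAX sample_program call, the function name and signature are hypothetical, direct columns are assumed to come before component columns, and all survivors are chunked at once instead of being streamed through a buffer across NumPy batches.

```python
import numpy as np


def run_component_stub(params: np.ndarray) -> np.ndarray:
    """Stand-in for the JAX routine: takes a fixed-size batch, returns one bit per shot."""
    return params.sum(axis=1, keepdims=True) % 2 == 1


def sample_with_direct_postselection(
    params, direct_bits, direct_mask, batch_size,
    run_component=run_component_stub, num_component_outputs=1,
):
    """Hypothetical sketch: skip the expensive routine for discarded shots."""
    shots = params.shape[0]
    # Step 2: flag shots in which a direct postselected detector fired.
    discarded = direct_bits[:, direct_mask].any(axis=1)
    survivors = np.flatnonzero(~discarded)

    # Step 4 default: discarded shots keep False in the component columns.
    component_bits = np.zeros((shots, num_component_outputs), dtype=bool)

    calls = 0
    # Step 3: dispatch survivors in fixed-size batches, padding the last one.
    for start in range(0, survivors.size, batch_size):
        idx = survivors[start:start + batch_size]
        batch = np.zeros((batch_size, params.shape[1]), dtype=params.dtype)
        batch[:idx.size] = params[idx]
        out = run_component(batch)
        # Scatter results to the original rows; padded rows are dropped.
        component_bits[idx] = out[:idx.size]
        calls += 1

    rows = np.concatenate([direct_bits, component_bits], axis=1)
    return rows, discarded, calls


rng = np.random.default_rng(0)
params = rng.random((10, 4)) < 0.5    # step 1: toy error params
direct_bits = params[:, :2]           # pretend two params are direct detectors
direct_mask = np.array([True, False]) # postselect on the first direct detector

rows, discarded, calls = sample_with_direct_postselection(
    params, direct_bits, direct_mask, batch_size=4
)
print(rows.shape, int(discarded.sum()), calls)
```

Every call to the stand-in receives exactly batch_size rows, which is the property that would avoid recompilation in a jitted JAX function.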

Notes / caveats

  • Determinism: results are already reproducible only at a fixed batch size. Postselection doesn't change the number of draws (shots), so it adds nothing beyond that existing caveat.
  • Reference samples: for direct detectors the noiseless reference is a constant, so use_detector_reference_sample can be applied in NumPy before the postselection check, keeping behavior consistent.

Out of scope

  • Observable postselection (this proposal covers detectors only).
  • Postselection on CompiledMeasurementSampler.
