diff --git a/graphix/flow/core.py b/graphix/flow/core.py index fa674c3fd..22a792242 100644 --- a/graphix/flow/core.py +++ b/graphix/flow/core.py @@ -11,10 +11,12 @@ from typing import TYPE_CHECKING, Generic, TypeVar import networkx as nx +import numpy as np # `override` introduced in Python 3.12, `assert_never` introduced in Python 3.11 from typing_extensions import assert_never, override +from graphix._linalg import MatGF2, solve_f2_linear_system from graphix.circ_ext.extraction import ( CliffordMap, ExtractionResult, @@ -266,6 +268,54 @@ def to_gflow(self: XZCorrections[_PM_co]) -> GFlow[_PM_co]: gf.check_well_formed() # Raises a `FlowError` if the partial order and the correction function are not compatible. return gf + def to_pauli_flow(self: XZCorrections[_AM_co]) -> PauliFlow[_AM_co]: + r"""Extract a Pauli flow from XZ-corrections. + + This method does not invoke the flow-extraction routine on the underlying open graph. + Instead, it reconstructs, for every measured node, a correction set whose future + part matches the observed XZ-corrections and which satisfies the Pauli-flow + propositions (P1--P9; see :meth:`PauliFlow.check_well_formed`). + + Returns + ------- + PauliFlow[_AM_co] + + Raises + ------ + FlowError + If no Pauli flow is compatible with the XZ-corrections. + + Notes + ----- + See Theorem 4 in Ref. [1]. Compared with :meth:`to_gflow`, the difficulty is that + Pauli-flow correction sets may contain *anachronical corrections*: corrections + targeting :math:`X`- or :math:`Y`-measured nodes in the present or past of the + corrected node. Such corrections never appear in the pattern because + :meth:`PauliFlow.to_corrections` keeps only the part of each correction set in the + future (the ``& future`` filter). They must therefore be reconstructed. + + For each measured node ``i`` this is cast as a linear system over GF(2): membership + of future nodes in ``p(i)`` is pinned by the X-corrections of ``i``; the free + variables are the anachronical (:math:`X`- or :math:`Y`-measured, non-future) + candidates and, where the local proposition allows it, ``i`` itself; and the + equations encode the odd-neighbourhood constraints (Z-corrections on future nodes, + P2 on past non-(:math:`Y`/ :math:`Z`) nodes, the P3 coupling on past + :math:`Y`-measured nodes, and the local proposition P4--P9 on ``i``). The system + is reduced with :meth:`graphix._linalg.MatGF2.gauss_elimination` and solved with + :func:`graphix._linalg.solve_f2_linear_system`. + + The subsequent call to :meth:`PauliFlow.check_well_formed` is a regression guard: + the GF(2) construction satisfies the propositions by design. + + References + ---------- + [1] Browne et al., 2007 New J. Phys. 9 250 (arXiv:quant-ph/0702212). + """ + correction_function = _reconstruct_pauli_correction_function(self) + pf = PauliFlow(self.og, correction_function, self.partial_order_layers) + pf.check_well_formed() + return pf + def to_bloch(self: XZCorrections[Measurement]) -> XZCorrections[BlochMeasurement]: """Return the XZ-corrections where all measurements in the open graph are converted to Bloch. @@ -1244,6 +1294,170 @@ def check_well_formed(self) -> None: raise PartialOrderError(PartialOrderErrorReason.IncorrectNodes) +def _gf2_neighbour_parity(neighbours: AbstractSet[int], subset: AbstractSet[int]) -> int: + """Return ``len(neighbours & subset) % 2``.""" + return len(neighbours & subset) % 2 + + +def _solve_pauli_correcting_set( + xz: XZCorrections[_AM_co], + node: int, + future: AbstractSet[int], + adjacency: Mapping[int, AbstractSet[int]], + labels: Mapping[int, Plane | Axis], + non_inputs: AbstractSet[int], +) -> set[int] | None: + """Reconstruct the Pauli-flow correction set of a single measured node. + + Parameters + ---------- + xz : XZCorrections[_AM_co] + XZ-corrections from which the visible correction entries are read. + node : int + Measured node whose correcting set is reconstructed. + future : AbstractSet[int] + Nodes in the future of ``node`` at the time of its measurement. + adjacency : Mapping[int, AbstractSet[int]] + Open-graph adjacency lists. + labels : Mapping[int, Plane | Axis] + Measurement labels of the measured nodes. + non_inputs : AbstractSet[int] + Non-input nodes of the open graph. + + Returns + ------- + set[int] | None + The reconstructed correcting set, or ``None`` if no solution exists. + + Notes + ----- + See :meth:`XZCorrections.to_pauli_flow` for the GF(2) system solved here. + """ + graph_nodes = set(xz.og.graph.nodes) + x_members = set(xz.x_corrections.get(node, ())) + z_members = set(xz.z_corrections.get(node, ())) + label = labels[node] + + fixed_members = set(x_members) + if label in {Plane.XZ, Plane.YZ, Axis.Z}: + if node not in non_inputs: + return None + fixed_members.add(node) + elif label == Plane.XY: + fixed_members.discard(node) + + nonfuture_others = graph_nodes - set(future) - {node} + free_candidates = sorted( + candidate + for candidate in nonfuture_others + if candidate in non_inputs and labels.get(candidate) in {Axis.X, Axis.Y} + ) + self_is_free = label in {Axis.X, Axis.Y} and node in non_inputs + variables = [*free_candidates, node] if self_is_free else free_candidates + var_index = {variable: index for index, variable in enumerate(variables)} + + def parity_from_fixed(graph_node: int) -> int: + return _gf2_neighbour_parity(adjacency[graph_node], fixed_members) + + def variable_coefficients(graph_node: int) -> list[int]: + return [1 if variable in adjacency[graph_node] else 0 for variable in variables] + + matrix_rows: list[list[int]] = [] + rhs: list[int] = [] + + for graph_node in future: + matrix_rows.append(variable_coefficients(graph_node)) + rhs.append((1 if graph_node in z_members else 0) ^ parity_from_fixed(graph_node)) + + for graph_node in nonfuture_others: + node_label = labels.get(graph_node) + if node_label is not None and node_label not in {Axis.Y, Axis.Z}: + matrix_rows.append(variable_coefficients(graph_node)) + rhs.append(parity_from_fixed(graph_node)) + + for graph_node in nonfuture_others: + if labels.get(graph_node) == Axis.Y: + row = variable_coefficients(graph_node) + if graph_node in var_index: + row[var_index[graph_node]] ^= 1 + matrix_rows.append(row) + rhs.append(parity_from_fixed(graph_node)) + + if label == Axis.Y: + row = variable_coefficients(node) + if node in var_index: + row[var_index[node]] ^= 1 + matrix_rows.append(row) + rhs.append(1 ^ parity_from_fixed(node)) + elif label != Axis.Z: + target_parity = 0 if label == Plane.YZ else 1 + matrix_rows.append(variable_coefficients(node)) + rhs.append(target_parity ^ parity_from_fixed(node)) + + if not matrix_rows: + return set(fixed_members) + + n_vars = len(variables) + if n_vars == 0: + return None if any(rhs) else set(fixed_members) + + augmented = np.array([[*row, column] for row, column in zip(matrix_rows, rhs, strict=True)], dtype=np.uint8).view( + MatGF2 + ) + reduced = augmented.gauss_elimination(ncols=n_vars) + lhs = MatGF2(reduced[:, :n_vars]) + rhs_column = reduced[:, n_vars] + for row_index in range(lhs.shape[0]): + if not lhs[row_index].any() and rhs_column[row_index] != 0: + return None + solution = solve_f2_linear_system(lhs, MatGF2(rhs_column)) + + correcting_set = set(fixed_members) + correcting_set.update(variable for variable, bit in zip(variables, solution, strict=True) if int(bit)) + return correcting_set + + +def _reconstruct_pauli_correction_function(xz: XZCorrections[_AM_co]) -> dict[int, set[int]]: + """Reconstruct a Pauli-flow correction function from XZ-corrections. + + Parameters + ---------- + xz : XZCorrections[_AM_co] + Well-formed XZ-corrections to invert. + + Returns + ------- + dict[int, set[int]] + Pauli-flow correction function. + + Raises + ------ + FlowError + If no Pauli flow is compatible with ``xz``. + """ + og = xz.og + adjacency: dict[int, set[int]] = {n: set(og.graph.neighbors(n)) for n in og.graph.nodes} + labels: dict[int, Plane | Axis] = {n: meas.to_plane_or_axis() for n, meas in og.measurements.items()} + non_inputs = set(og.graph.nodes) - set(og.input_nodes) + + future_of: dict[int, set[int]] = {} + accumulated: set[int] = set() + for layer in xz.partial_order_layers: + for graph_node in layer: + future_of[graph_node] = set(accumulated) + accumulated |= set(layer) + + correction_function: dict[int, set[int]] = {} + for measured_node in og.measurements: + correcting_set = _solve_pauli_correcting_set( + xz, measured_node, future_of[measured_node], adjacency, labels, non_inputs + ) + if correcting_set is None: + raise FlowError + correction_function[measured_node] = correcting_set + return correction_function + + def _corrections_to_dag( x_corrections: Mapping[int, AbstractSet[int]], z_corrections: Mapping[int, AbstractSet[int]] ) -> nx.DiGraph[int]: diff --git a/graphix/pattern.py b/graphix/pattern.py index bcc18035b..250ce6e97 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -22,7 +22,7 @@ from graphix import command, optimization from graphix.command import CommandKind, Node -from graphix.flow.exceptions import FlowError +from graphix.flow.exceptions import FlowError, XZCorrectionsError from graphix.fundamentals import Plane from graphix.measurements import BlochMeasurement, Measurement, Outcome, toggle_outcome from graphix.pretty_print import OutputFormat, pattern_to_str @@ -991,6 +991,41 @@ def extract_gflow(self) -> GFlow[BlochMeasurement]: """ return self.extract_xzcorrections().downcast_bloch().to_gflow() + def extract_pauli_flow(self) -> PauliFlow[Measurement]: + r"""Extract the Pauli flow structure from the current measurement pattern. + + This method does not call the flow-extraction routine on the underlying open graph, + but constructs the Pauli flow from the pattern corrections instead. Unlike open-graph + flow extraction, this guarantees a flow that generates *this* pattern: the decisive + criterion is that :meth:`PauliFlow.to_corrections` reproduces the pattern's + XZ-corrections exactly. + + Returns + ------- + PauliFlow[Measurement] + The Pauli flow associated with the current pattern. + + Raises + ------ + FlowError + If the pattern is empty or if the extracted structure does not satisfy + the well-formedness conditions required for a valid Pauli flow. + ValueError + If `N` commands in the pattern do not represent a :math:`|+\rangle` state or if the pattern corrections form closed loops. + + Notes + ----- + The notes provided in :func:`self.extract_causal_flow` apply here as well. + Anachronical corrections omitted from the pattern (Theorem 4 in Ref. [1]) are + recovered in :meth:`XZCorrections.to_pauli_flow`; see that method for the GF(2) + reconstruction strategy. + + References + ---------- + [1] Browne et al., 2007 New J. Phys. 9 250 (arXiv:quant-ph/0702212). + """ + return self.extract_xzcorrections().to_pauli_flow() + def extract_xzcorrections(self) -> XZCorrections[Measurement]: """Extract the XZ-corrections from the current measurement pattern. @@ -1425,20 +1460,23 @@ def draw( if flow_from_pattern: try: - xz_corrections = self.extract_xzcorrections().downcast_bloch() - except TypeError: + xz_corrections = self.extract_xzcorrections() + except (XZCorrectionsError, ValueError): pass else: try: - flow = xz_corrections.to_causal_flow() - except FlowError: + flow = xz_corrections.downcast_bloch().to_causal_flow() + except (TypeError, FlowError): try: - flow = xz_corrections.to_gflow() - except FlowError: - warn( - "The pattern is not consistent with a causal flow or a gflow. An attempt to be extract the flow from the underlying open graph will be made.", - stacklevel=stacklevel, - ) + flow = xz_corrections.downcast_bloch().to_gflow() + except (TypeError, FlowError): + try: + flow = xz_corrections.to_pauli_flow() + except FlowError: + warn( + "The pattern is not consistent with a causal flow, a gflow or a Pauli flow. An attempt to be extract the flow from the underlying open graph will be made.", + stacklevel=stacklevel, + ) if flow is None: og = self.extract_opengraph() diff --git a/tests/test_flow_core.py b/tests/test_flow_core.py index 8e722bd93..9a64d2121 100644 --- a/tests/test_flow_core.py +++ b/tests/test_flow_core.py @@ -413,6 +413,21 @@ def test_corrections_to_pattern(self, test_case: XZCorrectionsTestCase, fx_rng: state = pattern.simulate_pattern(input_state=PlanarState(plane, alpha), rng=fx_rng) assert state.isclose(state_ref) + @pytest.mark.parametrize("test_case", prepare_test_xzcorrections()) + def test_corrections_to_pauli_flow(self, test_case: XZCorrectionsTestCase) -> None: + """Tests the round trip Flow -> XZCorrections -> PauliFlow -> XZCorrections.""" + flow = test_case.flow + flow.check_well_formed() + corrections = flow.to_corrections() + corrections.check_well_formed() + pauli_flow = corrections.to_pauli_flow() + pauli_flow.check_well_formed() + assert pauli_flow.correction_function == flow.correction_function + assert pauli_flow.partial_order_layers == flow.partial_order_layers + round_trip = pauli_flow.to_corrections() + assert round_trip.x_corrections == corrections.x_corrections + assert round_trip.z_corrections == corrections.z_corrections + class TestFlow: """Bundle for unit tests of :class:`PauliFlow` and children.""" diff --git a/tests/test_pattern.py b/tests/test_pattern.py index 6a87c4016..363fbc4e2 100644 --- a/tests/test_pattern.py +++ b/tests/test_pattern.py @@ -972,6 +972,36 @@ def test_extract_gflow(self, fx_rng: Generator, test_case: PatternFlowTestCase) with pytest.raises(FlowError): test_case.pattern.extract_gflow() + def test_extract_pauli_flow_user_example(self) -> None: + """Extract Pauli flow directly from a pattern with no causal flow or gflow.""" + pattern = Pattern( + input_nodes=[0], + cmds=[ + N(1), + N(2), + N(3), + E((0, 1)), + E((1, 2)), + E((2, 3)), + M(0, Measurement.X), + X(3, {0}), + M(1, Measurement.X), + Z(3, {1}), + M(2, Measurement.X), + X(3, {2}), + ], + output_nodes=[3], + ) + + pauli_flow = pattern.extract_pauli_flow() + + assert pauli_flow.correction_function[0] == frozenset({1, 3}) + assert pauli_flow.correction_function[1] == frozenset({2}) + assert pauli_flow.correction_function[2] == frozenset({3}) + + pauli_flow_og = pattern.extract_opengraph().extract_pauli_flow() + assert pauli_flow.correction_function == pauli_flow_og.correction_function + # From open graph def test_extract_cflow_og(self, fx_rng: Generator) -> None: alpha = 2 * np.pi * fx_rng.random() diff --git a/tests/test_pauli_flow_extraction.py b/tests/test_pauli_flow_extraction.py new file mode 100644 index 000000000..b8a9e92a1 --- /dev/null +++ b/tests/test_pauli_flow_extraction.py @@ -0,0 +1,215 @@ +r"""Tests for Pauli-flow extraction from a pattern / XZ-corrections. + +Correctness criterion +--------------------- +A reconstructed Pauli flow ``pf`` generates the original pattern if and only if +``pf.check_well_formed()`` succeeds *and* ``pf.to_corrections()`` reproduces the pattern's +X- and Z-corrections exactly. The latter round-trip property is the decisive check: it +guarantees that the flow generates *this* pattern (and not merely some Pauli flow of the +underlying open graph, which need not be unique). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import networkx as nx +import numpy as np +import pytest + +from graphix import Measurement, OpenGraph, Pattern +from graphix.command import E, M, N, X, Z +from graphix.flow.core import XZCorrections +from graphix.flow.exceptions import FlowError +from graphix.opengraph import OpenGraphError +from graphix.pattern import DrawPatternAnnotations +from tests.test_flow_core import prepare_test_xzcorrections + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + from collections.abc import Set as AbstractSet + + from numpy.random import Generator + + +def _norm(corrections: Mapping[int, AbstractSet[int]]) -> dict[int, frozenset[int]]: + """Drop empty correction sets to compare correction dictionaries up to empty entries.""" + return {k: frozenset(v) for k, v in corrections.items() if v} + + +def _assert_round_trip(pattern: Pattern) -> None: + xz = pattern.extract_xzcorrections() + pf = xz.to_pauli_flow() + assert pf.is_well_formed() + rt = pf.to_corrections() + assert _norm(rt.x_corrections) == _norm(xz.x_corrections) + assert _norm(rt.z_corrections) == _norm(xz.z_corrections) + + +def _correction_function(pattern: Pattern) -> dict[int, set[int]]: + pf = pattern.extract_pauli_flow() + return {k: set(v) for k, v in pf.correction_function.items()} + + +def _causal_pattern() -> Pattern: + return Pattern(input_nodes=[0], cmds=[N(1), E((0, 1)), M(0, Measurement.XY(0)), X(1, {0})], output_nodes=[1]) + + +def _gflow_pattern() -> Pattern: + return Pattern( + input_nodes=[0], + cmds=[ + N(1), + N(2), + N(3), + E((0, 1)), + E((0, 2)), + E((1, 2)), + E((1, 3)), + M(0, Measurement.XY(0.1)), + X(2, {0}), + X(3, {0}), + M(1, Measurement.XZ(0.2)), + Z(2, {1}), + Z(3, {1}), + X(2, {1}), + ], + output_nodes=[2, 3], + ) + + +def _pauli_pattern() -> Pattern: + return Pattern( + input_nodes=[0], + cmds=[ + N(1), + N(2), + N(3), + E((0, 1)), + E((1, 2)), + E((2, 3)), + M(0, Measurement.X), + X(3, {0}), + M(1, Measurement.X), + Z(3, {1}), + M(2, Measurement.X), + X(3, {2}), + ], + output_nodes=[3], + ) + + +def test_extract_pauli_flow_causal_example() -> None: + pattern = _causal_pattern() + assert _correction_function(pattern) == {0: {1}} + _assert_round_trip(pattern) + + +def test_extract_pauli_flow_gflow_example() -> None: + pattern = _gflow_pattern() + assert _correction_function(pattern) == {0: {2, 3}, 1: {1, 2}} + _assert_round_trip(pattern) + + +def test_extract_pauli_flow_pauli_example() -> None: + pattern = _pauli_pattern() + assert _correction_function(pattern) == {0: {1, 3}, 1: {2}, 2: {3}} + _assert_round_trip(pattern) + + +def test_extract_pauli_flow_matches_reference_flows() -> None: + """Reconstructed ``p`` must match the reference correction function, not only round-trip XZ.""" + for test_case in prepare_test_xzcorrections(): + flow = test_case.flow + xz = flow.to_corrections() + pf = xz.to_pauli_flow() + assert pf.correction_function == flow.correction_function + + +@pytest.mark.filterwarnings("ignore:Open graph with non-inferred Pauli measurements.") +def test_extract_pauli_flow_pauli_opengraph() -> None: + og = OpenGraph( + graph=nx.Graph([(0, 2), (2, 4), (3, 4), (4, 6), (1, 4), (1, 6), (2, 3), (3, 5), (2, 6), (3, 6)]), + input_nodes=[0], + output_nodes=[5, 6], + measurements={ + 0: Measurement.XY(0.1), + 1: Measurement.XZ(0.1), + 2: Measurement.Y, + 3: Measurement.XY(0.1), + 4: Measurement.Z, + }, + ) + _assert_round_trip(og.to_pattern()) + + +_MEASUREMENTS: list[Callable[[Generator], Measurement]] = [ + lambda r: Measurement.XY(float(r.random())), + lambda r: Measurement.XZ(float(r.random())), + lambda r: Measurement.YZ(float(r.random())), + lambda _r: Measurement.X, + lambda _r: Measurement.Y, + lambda _r: Measurement.Z, +] + + +@pytest.mark.filterwarnings("ignore:Open graph with non-inferred Pauli measurements.") +def test_extract_pauli_flow_randomized_round_trip() -> None: + tested = 0 + for seed in range(400): + rng = np.random.default_rng(seed) + n = int(rng.integers(4, 10)) + graph = nx.gnp_random_graph(n, 0.45, seed=seed) + if graph.number_of_edges() == 0: + continue + nodes = list(graph.nodes()) + rng.shuffle(nodes) + n_out = int(rng.integers(1, max(2, n // 2))) + n_in = int(rng.integers(0, max(1, n // 2))) + outputs = nodes[:n_out] + inputs = nodes[n_out : n_out + n_in] + measurements = { + m: _MEASUREMENTS[int(rng.integers(0, len(_MEASUREMENTS)))](rng) for m in nodes if m not in outputs + } + try: + pattern = OpenGraph( + graph=graph, input_nodes=inputs, output_nodes=outputs, measurements=measurements + ).to_pattern() + except OpenGraphError: + continue + _assert_round_trip(pattern) + tested += 1 + assert tested >= 30 + + +def test_draw_flow_from_pattern_skips_opengraph_for_pauli_measurements() -> None: + """``Pattern.draw`` must extract Pauli flow from corrections without ``downcast_bloch``.""" + pattern = _pauli_pattern() + mock_gv = MagicMock() + with ( + patch("graphix.visualization.GraphVisualizer.from_flow", return_value=mock_gv) as from_flow, + patch.object(pattern, "extract_opengraph") as extract_opengraph, + ): + pattern.draw(annotations=DrawPatternAnnotations.Flow, flow_from_pattern=True) + extract_opengraph.assert_not_called() + flow = from_flow.call_args.kwargs["flow"] + assert flow.correction_function[0] == frozenset({1, 3}) + + +def test_to_pauli_flow_raises_when_no_flow_exists() -> None: + og1 = OpenGraph(graph=nx.Graph([(0, 1)]), input_nodes=[0], output_nodes=[1], measurements={0: Measurement.Z}) + with pytest.raises(FlowError): + XZCorrections(og1, {}, {}, [{1}, {0}]).to_pauli_flow() + + graph: nx.Graph[int] = nx.Graph() + graph.add_node(0) + graph.add_edge(1, 2) + og2 = OpenGraph( + graph=graph, + input_nodes=[], + output_nodes=[2], + measurements={0: Measurement.XY(0.1), 1: Measurement.XY(0.1)}, + ) + with pytest.raises(FlowError): + XZCorrections(og2, {}, {}, [{2}, {1}, {0}]).to_pauli_flow()