From 7cd586373d229cc3862d1f42195f060b960563d8 Mon Sep 17 00:00:00 2001 From: Tsafrir Armon Date: Tue, 12 Aug 2025 12:21:09 +0300 Subject: [PATCH 1/6] Add integration tests --- test/integration/test_parametric_twirling_samples.py | 10 ++++++++++ test/integration/test_static_twirling_samples.py | 11 +++++++++++ 2 files changed, 21 insertions(+) diff --git a/test/integration/test_parametric_twirling_samples.py b/test/integration/test_parametric_twirling_samples.py index e0aa8237..d1109a0a 100644 --- a/test/integration/test_parametric_twirling_samples.py +++ b/test/integration/test_parametric_twirling_samples.py @@ -84,6 +84,16 @@ def make_circuits(): yield circuit, "parametric_right_box" + circuit = QuantumCircuit(1) + with circuit.box([Twirl(dressing="left")]): + circuit.x(0) + circuit.rz(1.2, 0) + circuit.rx(Parameter("a"), 0) + with circuit.box([Twirl(dressing="right")]): + circuit.sx(0) + circuit.rz(1.5, 0) + + yield circuit, "parametric_nonclifford_between_boxes" def pytest_generate_tests(metafunc): if "circuit" in metafunc.fixturenames: diff --git a/test/integration/test_static_twirling_samples.py b/test/integration/test_static_twirling_samples.py index 911bc6f4..2153e910 100644 --- a/test/integration/test_static_twirling_samples.py +++ b/test/integration/test_static_twirling_samples.py @@ -229,6 +229,17 @@ def make_circuits(): yield circuit, "propagate_through_merged_invariant_gates" + circuit = QuantumCircuit(1) + with circuit.box([Twirl(dressing="left")]): + circuit.x(0) + circuit.rz(1.2, 0) + circuit.rx(1.2, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.sx(0) + circuit.rz(1.5, 0) + + yield circuit, "nonclifford_between_boxes" + def pytest_generate_tests(metafunc): if "circuit" in metafunc.fixturenames: From 84495b3928bde12fe8800b929f5ed821b71ce4ea Mon Sep 17 00:00:00 2001 From: Tsafrir Armon Date: Tue, 12 Aug 2025 12:52:46 +0300 Subject: [PATCH 2/6] More integration tests --- .../test_parametric_twirling_samples.py | 14 +++++++++++++- test/integration/test_static_twirling_samples.py | 14 +++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/test/integration/test_parametric_twirling_samples.py b/test/integration/test_parametric_twirling_samples.py index d1109a0a..e2a667a9 100644 --- a/test/integration/test_parametric_twirling_samples.py +++ b/test/integration/test_parametric_twirling_samples.py @@ -93,7 +93,19 @@ def make_circuits(): circuit.sx(0) circuit.rz(1.5, 0) - yield circuit, "parametric_nonclifford_between_boxes" + yield circuit, "parameterize_nonclifford_between_left_right_boxes" + + circuit = QuantumCircuit(2) + with circuit.box([Twirl(dressing="left")]): + circuit.rz(1.2, 0) + circuit.cx(0, 1) + circuit.rx(Parameter("a"), 0) + with circuit.box([Twirl(dressing="left")]): + circuit.cx(1, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0, 1) + + yield circuit, "parameterize_nonclifford_between_left_left_boxes" def pytest_generate_tests(metafunc): if "circuit" in metafunc.fixturenames: diff --git a/test/integration/test_static_twirling_samples.py b/test/integration/test_static_twirling_samples.py index 2153e910..104b2eff 100644 --- a/test/integration/test_static_twirling_samples.py +++ b/test/integration/test_static_twirling_samples.py @@ -238,7 +238,19 @@ def make_circuits(): circuit.sx(0) circuit.rz(1.5, 0) - yield circuit, "nonclifford_between_boxes" + yield circuit, "nonclifford_between_left_right_boxes" + + circuit = QuantumCircuit(2) + with circuit.box([Twirl(dressing="left")]): + circuit.rz(1.2, 0) + circuit.cx(0, 1) + circuit.rx(1.2, 0) + with circuit.box([Twirl(dressing="left")]): + circuit.cx(1, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0, 1) + + yield circuit, "nonclifford_between_left_left_boxes" def pytest_generate_tests(metafunc): From 92739c0507f2411cad730b38378aec4e203f18b2 Mon Sep 17 00:00:00 2001 From: Tsafrir Armon Date: Tue, 12 Aug 2025 14:18:32 +0300 Subject: [PATCH 3/6] Fix for parametric gates up to virtual gate crossing (?) --- samplomatic/pre_samplex/pre_samplex.py | 21 +++- samplomatic/samplex/nodes/__init__.py | 4 +- .../nodes/u2_param_multiplication_node.py | 100 +++++++++++++++++- .../test_parametric_twirling_samples.py | 23 +++- .../test_u2_param_multiplication_node.py | 84 +++++++++++++++ 5 files changed, 222 insertions(+), 10 deletions(-) diff --git a/samplomatic/pre_samplex/pre_samplex.py b/samplomatic/pre_samplex/pre_samplex.py index 0c851e3f..36c9e1ee 100644 --- a/samplomatic/pre_samplex/pre_samplex.py +++ b/samplomatic/pre_samplex/pre_samplex.py @@ -64,6 +64,8 @@ RightU2ParametricMultiplicationNode, SliceRegisterNode, TwirlSamplingNode, + RightU2ParametricConjugationNode, + LeftU2ParametricConjugationNode, ) from ..samplex.nodes.basis_transform_node import MEAS_PAULI_BASIS, PREP_PAULI_BASIS from ..samplex.nodes.pauli_past_clifford_node import ( @@ -1258,6 +1260,23 @@ def add_propagate_node( combined_register_name, np.array(list(pre_propagate.partition), dtype=np.intp), ) + elif mode is InstructionMode.PROPAGATE: + combined_register_type = VirtualType.U2 + if pre_propagate.operation.is_parameterized(): + param_idxs = [ + samplex.append_parameter_expression(param) + for _, param in pre_propagate.spec.params + ] + if pre_propagate.direction is Direction.LEFT: + propagate_node = RightU2ParametricConjugationNode( + op_name, combined_register_name, param_idxs + ) + else: + propagate_node = LeftU2ParametricConjugationNode( + op_name, combined_register_name, param_idxs + ) + else: + raise NotImplementedError() else: raise SamplexBuildError( f"Encountered unsupported {op_name} propragation with mode {mode} and " @@ -1278,8 +1297,6 @@ def add_propagate_node( node_idx = samplex.add_node(propagate_node) samplex.add_edge(combine_node_idx, node_idx) else: - # TODO: It should be possible to not add a slice node in this case, if there is - # a single predecessor. node_idx = combine_node_idx pre_nodes_to_nodes[pre_propagate_idx] = node_idx diff --git a/samplomatic/samplex/nodes/__init__.py b/samplomatic/samplex/nodes/__init__.py index 75d74bd1..9b030657 100644 --- a/samplomatic/samplex/nodes/__init__.py +++ b/samplomatic/samplex/nodes/__init__.py @@ -27,5 +27,7 @@ from .u2_param_multiplication_node import ( LeftU2ParametricMultiplicationNode, RightU2ParametricMultiplicationNode, - U2ParametricMultiplicationNode, + U2ParametricTransformationNode, + RightU2ParametricConjugationNode, + LeftU2ParametricConjugationNode, ) diff --git a/samplomatic/samplex/nodes/u2_param_multiplication_node.py b/samplomatic/samplex/nodes/u2_param_multiplication_node.py index 6aa221ff..c5c0172e 100644 --- a/samplomatic/samplex/nodes/u2_param_multiplication_node.py +++ b/samplomatic/samplex/nodes/u2_param_multiplication_node.py @@ -23,11 +23,11 @@ from .evaluation_node import EvaluationNode -class U2ParametricMultiplicationNode(EvaluationNode): - """Abstract parent node for nodes doing multiplication on a :class:`~.U2Register`. +class U2ParametricTransformationNode(EvaluationNode): + """Abstract parent node for nodes doing transformation on a :class:`~.U2Register`. The node stores a parametric representation of a one-qubit gate or gates from the - original circuit to perform multiplication on a registry. + original circuit to perform transformation on a registry. Limited to the gates ``rz`` or ``rx``, and all gates within the node are of the same type. @@ -98,7 +98,7 @@ def _get_operation(self, parameter_values: np.ndarray) -> U2Register: return U2Register(result) -class LeftU2ParametricMultiplicationNode(U2ParametricMultiplicationNode): +class LeftU2ParametricMultiplicationNode(U2ParametricTransformationNode): """Perform parametric left multiplication on a :class:`~.U2Register`. The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the @@ -141,7 +141,7 @@ def evaluate( registers[self._register_name].left_inplace_multiply(self._get_operation(parameter_values)) -class RightU2ParametricMultiplicationNode(U2ParametricMultiplicationNode): +class RightU2ParametricMultiplicationNode(U2ParametricTransformationNode): """Perform parametric right multiplication on a :class:`~.U2Register`. The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the @@ -182,3 +182,93 @@ def evaluate( ) registers[self._register_name].inplace_multiply(self._get_operation(parameter_values)) + + +class LeftU2ParametricConjugationNode(U2ParametricTransformationNode): + """Perform parametric left conjugation on a :class:`~.U2Register`. + + The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the + original circuit and performs a :math:`g*reg*g^{\dagger}` conjugation, where :math:`reg` is + the existing ``U2Register``. This is consistent with a traveling virtual gate going from + right to left. + + :math:`g` is limited to the gates ``rz`` or ``rx``, and all gates within the node are of the + same type:math:. + + Args: + operand: The gate type, given as a string. + register_name: The name of the register the operation is applied to. + param_idxs: List of ``ParamIndex`` for the parameter expressions specifying the gate + arguments. List order must match the order subsystems in the register. + + Raises: + SamplexConstructionError: if `param_idxs` is empty. + """ + + def evaluate( + self, registers: dict[RegisterName, VirtualRegister], parameter_values: np.ndarray + ): + """Evaluate this node. + + Args: + registers: At least those registers needed by this node to read from or write to. + parameter_values: The evaluated values of the parameter expressions in indices + ``self.parameter_idxs``, at the same order. + + Raises: + SamplexRuntimeError: If the number of parameter values doesn't match the number of + parameter expressions in ``self.parameter_idxs``. + """ + if len(parameter_values) != self.num_parameters: + raise SamplexRuntimeError( + f"Expected {self.num_parameters} parameter values instead got " + f"{len(parameter_values)}" + ) + + registers[self._register_name].left_inplace_multiply(self._get_operation(parameter_values)) + registers[self._register_name].inplace_multiply(self._get_operation(-parameter_values)) + + +class RightU2ParametricConjugationNode(U2ParametricTransformationNode): + """Perform parametric right conjugation on a :class:`~.U2Register`. + + The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the + original circuit and performs a :math:`g^{\dagger}*reg*g` conjugation, where :math:`reg` is + the existing ``U2Register``. This is consistent with a traveling virtual going going from + left to right. + + :math:`g` is limited to the gates ``rz`` or ``rx``, and all gates within the node are of the + same type:math:. + + Args: + operand: The gate type, given as a string. + register_name: The name of the register the operation is applied to. + param_idxs: List of ``ParamIndex`` for the parameter expressions specifying the gate + arguments. List order must match the order subsystems in the register. + + Raises: + SamplexConstructionError: if `param_idxs` is empty. + """ + + def evaluate( + self, registers: dict[RegisterName, VirtualRegister], parameter_values: np.ndarray + ): + """Evaluate this node. + + Args: + registers: At least those registers needed by this node to read from or write to. + parameter_values: The evaluated values of the parameter expressions in indices + ``self.parameter_idxs``, at the same order. + + Raises: + SamplexRuntimeError: If the number of parameter values doesn't match the number of + parameter expressions in ``self.parameter_idxs``. + """ + if len(parameter_values) != self.num_parameters: + raise SamplexRuntimeError( + f"Expected {self.num_parameters} parameter values instead got " + f"{len(parameter_values)}" + ) + + registers[self._register_name].left_inplace_multiply(self._get_operation(-parameter_values)) + registers[self._register_name].inplace_multiply(self._get_operation(parameter_values)) \ No newline at end of file diff --git a/test/integration/test_parametric_twirling_samples.py b/test/integration/test_parametric_twirling_samples.py index e2a667a9..871c837b 100644 --- a/test/integration/test_parametric_twirling_samples.py +++ b/test/integration/test_parametric_twirling_samples.py @@ -86,26 +86,45 @@ def make_circuits(): circuit = QuantumCircuit(1) with circuit.box([Twirl(dressing="left")]): + circuit.h(0) + circuit.rz(-1.2, 0) + with circuit.box([Twirl(dressing="right")]): circuit.x(0) circuit.rz(1.2, 0) circuit.rx(Parameter("a"), 0) + circuit.rz(Parameter("b"), 0) with circuit.box([Twirl(dressing="right")]): circuit.sx(0) circuit.rz(1.5, 0) - yield circuit, "parameterize_nonclifford_between_left_right_boxes" + yield circuit, "parameterized_nonclifford_between_right_right_boxes" + + circuit = QuantumCircuit(1) + with circuit.box([Twirl(dressing="left")]): + # circuit.x(0) + # circuit.rz(1.2, 0) + circuit.noop(0) + circuit.rx(Parameter("a"), 0) + circuit.rz(Parameter("b"), 0) + with circuit.box([Twirl(dressing="right")]): + # circuit.sx(0) + # circuit.rz(1.5, 0) + circuit.noop(0) + + yield circuit, "parameterized_nonclifford_between_left_right_boxes" circuit = QuantumCircuit(2) with circuit.box([Twirl(dressing="left")]): circuit.rz(1.2, 0) circuit.cx(0, 1) circuit.rx(Parameter("a"), 0) + circuit.rz(Parameter("b"), 0) with circuit.box([Twirl(dressing="left")]): circuit.cx(1, 0) with circuit.box([Twirl(dressing="right")]): circuit.noop(0, 1) - yield circuit, "parameterize_nonclifford_between_left_left_boxes" + yield circuit, "parameterized_nonclifford_between_left_left_boxes" def pytest_generate_tests(metafunc): if "circuit" in metafunc.fixturenames: diff --git a/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py b/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py index 28e9079b..c503f504 100644 --- a/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py +++ b/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py @@ -22,6 +22,8 @@ from samplomatic.samplex.nodes import ( LeftU2ParametricMultiplicationNode, RightU2ParametricMultiplicationNode, + LeftU2ParametricConjugationNode, + RightU2ParametricConjugationNode, ) X_MATRIX = np.array([[0, 1], [1, 0]]) @@ -113,3 +115,85 @@ def test_left_multiply(self, gate, matrix, rng): assert np.allclose( registers["a"].virtual_gates, np.matmul(register.virtual_gates, operation) ) + +class TestLeftU2ParamConjugationNode: + def test_instantiation_errors(self): + """Test that errors are properly raised during instantiation""" + with pytest.raises( + SamplexConstructionError, + match="Expected at least one element in param_idxs", + ): + LeftU2ParametricConjugationNode("rz", "a", []) + + def test_evaluation_errors(self): + """Test that errors are properly raised during evaluation""" + node = LeftU2ParametricConjugationNode("rz", "a", [0, 1, 2]) + registers = {} + with pytest.raises( + SamplexRuntimeError, match=re.escape("Expected 3 parameter values instead got 1") + ): + node.evaluate(registers, [1]) + + @pytest.mark.parametrize("gate", ["rx", "rz"]) + def test_left_conjugation(self, gate, rng): + """Test left conjugation""" + register_shape = [5, 2] + + params = rng.random(register_shape[0]) * 2 * np.pi + haar_dist = HaarU2(register_shape[0]) + node = LeftU2ParametricConjugationNode(gate, "a", [i for i in range(register_shape[0])]) + register = haar_dist.sample(register_shape[1], rng) + registers = {"a": register.copy()} + node.evaluate(registers, params) + + # g*reg*(g^dagger) + expected_registers = {"a": register.copy()} + node = LeftU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) + node.evaluate(expected_registers, params) + node = RightU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) + node.evaluate(expected_registers, -params) + + assert np.allclose( + registers["a"].virtual_gates, expected_registers["a"].virtual_gates + ) + +class TestRightU2ParamConjugationNode: + def test_instantiation_errors(self): + """Test that errors are properly raised during instantiation""" + with pytest.raises( + SamplexConstructionError, + match="Expected at least one element in param_idxs", + ): + RightU2ParametricConjugationNode("rz", "a", []) + + def test_evaluation_errors(self): + """Test that errors are properly raised during evaluation""" + node = RightU2ParametricConjugationNode("rz", "a", [0, 1, 2]) + registers = {} + with pytest.raises( + SamplexRuntimeError, match=re.escape("Expected 3 parameter values instead got 1") + ): + node.evaluate(registers, [1]) + + @pytest.mark.parametrize("gate", ["rx", "rz"]) + def test_right_conjugation(self, gate, rng): + """Test right conjugation""" + register_shape = [5, 2] + + params = rng.random(register_shape[0]) * 2 * np.pi + haar_dist = HaarU2(register_shape[0]) + node = RightU2ParametricConjugationNode(gate, "a", [i for i in range(register_shape[0])]) + register = haar_dist.sample(register_shape[1], rng) + registers = {"a": register.copy()} + node.evaluate(registers, params) + + # (g^dagger)*reg*g + expected_registers = {"a": register.copy()} + node = LeftU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) + node.evaluate(expected_registers, -params) + node = RightU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) + node.evaluate(expected_registers, params) + + assert np.allclose( + registers["a"].virtual_gates, expected_registers["a"].virtual_gates + ) \ No newline at end of file From 87689e964209f4632fcc6353d27570856d0f1ca9 Mon Sep 17 00:00:00 2001 From: Tsafrir Armon Date: Wed, 13 Aug 2025 15:27:50 +0300 Subject: [PATCH 4/6] fix for nonparametric non-cliffords --- .../template_builder/template_state.py | 20 +++-- samplomatic/pre_samplex/pre_samplex.py | 12 ++- samplomatic/samplex/nodes/__init__.py | 11 ++- .../samplex/nodes/multiplication_node.py | 48 ++++++++++-- .../nodes/u2_param_multiplication_node.py | 10 +-- test/integration/test_dynamic_circuits.py | 1 + .../test_static_twirling_samples.py | 26 +++++++ .../test_nodes/test_multiplication_node.py | 77 ++++++++++++++++++- 8 files changed, 181 insertions(+), 24 deletions(-) diff --git a/samplomatic/builders/template_builder/template_state.py b/samplomatic/builders/template_builder/template_state.py index 4b4697d0..43229e33 100644 --- a/samplomatic/builders/template_builder/template_state.py +++ b/samplomatic/builders/template_builder/template_state.py @@ -96,9 +96,13 @@ def append_remapped_gate( new_params = [] param_mapping = [] - for param in instr.operation.params: - param_mapping.append([self.param_iter.idx, param]) - new_params.append(next(self.param_iter)) + if instr.operation.is_parameterized(): + # Note: It is assumed here that if is_parameterized() is true, then all parameters + # are ParameterExpressions. This is true for now because all of our parametrized + # gates have a single parameter. + for param in instr.operation.params: + param_mapping.append([self.param_iter.idx, param]) + new_params.append(next(self.param_iter)) new_qubits = [self.qubit_map.get(qubit, qubit) for qubit in instr.qubits] new_operation = type(instr.operation)(*new_params) if new_params else instr.operation @@ -122,9 +126,13 @@ def remap_subcircuit(self, circuit: QuantumCircuit) -> tuple[QuantumCircuit, Par ] new_params = [] instr_param_mapping = [] - for param in instr.operation.params: - instr_param_mapping.append([self.param_iter.idx, param]) - new_params.append(next(self.param_iter)) + if instr.operation.is_parameterized(): + # Note: It is assumed here that if is_parameterized() is true, then all parameters + # are ParameterExpressions. This is true for now because all of our parametrized + # gates have a single parameter. + for param in instr.operation.params: + instr_param_mapping.append([self.param_iter.idx, param]) + new_params.append(next(self.param_iter)) new_operation = type(instr.operation)(*new_params) if new_params else instr.operation remapped_circuit.append(CircuitInstruction(new_operation, new_qubits, instr.clbits)) diff --git a/samplomatic/pre_samplex/pre_samplex.py b/samplomatic/pre_samplex/pre_samplex.py index 36c9e1ee..5f257f12 100644 --- a/samplomatic/pre_samplex/pre_samplex.py +++ b/samplomatic/pre_samplex/pre_samplex.py @@ -57,15 +57,17 @@ CollectZ2ToOutputNode, CombineRegistersNode, InjectNoiseNode, + LeftConjugationNode, LeftMultiplicationNode, + LeftU2ParametricConjugationNode, LeftU2ParametricMultiplicationNode, PauliPastCliffordNode, + RightConjugationNode, RightMultiplicationNode, + RightU2ParametricConjugationNode, RightU2ParametricMultiplicationNode, SliceRegisterNode, TwirlSamplingNode, - RightU2ParametricConjugationNode, - LeftU2ParametricConjugationNode, ) from ..samplex.nodes.basis_transform_node import MEAS_PAULI_BASIS, PREP_PAULI_BASIS from ..samplex.nodes.pauli_past_clifford_node import ( @@ -1276,7 +1278,11 @@ def add_propagate_node( op_name, combined_register_name, param_idxs ) else: - raise NotImplementedError() + operand = U2Register(np.array(pre_propagate.operation).reshape(1, 1, 2, 2)) + if pre_propagate.direction is Direction.LEFT: + propagate_node = RightConjugationNode(operand, combined_register_name) + else: + propagate_node = LeftConjugationNode(operand, combined_register_name) else: raise SamplexBuildError( f"Encountered unsupported {op_name} propragation with mode {mode} and " diff --git a/samplomatic/samplex/nodes/__init__.py b/samplomatic/samplex/nodes/__init__.py index 9b030657..7cc1bd72 100644 --- a/samplomatic/samplex/nodes/__init__.py +++ b/samplomatic/samplex/nodes/__init__.py @@ -18,16 +18,21 @@ from .conversion_node import ConversionNode from .evaluation_node import EvaluationNode from .inject_noise_node import InjectNoiseNode -from .multiplication_node import LeftMultiplicationNode, RightMultiplicationNode +from .multiplication_node import ( + LeftConjugationNode, + LeftMultiplicationNode, + RightConjugationNode, + RightMultiplicationNode, +) from .node import Node from .pauli_past_clifford_node import PauliPastCliffordNode from .sampling_node import SamplingNode from .slice_register_node import SliceRegisterNode from .twirl_sampling_node import TwirlSamplingNode from .u2_param_multiplication_node import ( + LeftU2ParametricConjugationNode, LeftU2ParametricMultiplicationNode, + RightU2ParametricConjugationNode, RightU2ParametricMultiplicationNode, U2ParametricTransformationNode, - RightU2ParametricConjugationNode, - LeftU2ParametricConjugationNode, ) diff --git a/samplomatic/samplex/nodes/multiplication_node.py b/samplomatic/samplex/nodes/multiplication_node.py index bf41df83..8f6e2e73 100644 --- a/samplomatic/samplex/nodes/multiplication_node.py +++ b/samplomatic/samplex/nodes/multiplication_node.py @@ -19,12 +19,12 @@ from .evaluation_node import EvaluationNode -class MultiplicationNode(EvaluationNode): - """Abstract parent for nodes that perform multiplication against a fixed register. +class TransformationNode(EvaluationNode): + """Abstract parent for nodes that perform transformation against a fixed register. Args: - operand: The fixed group elements by which to multiply. - register_name: The name of the register to multiply with. + operand: The fixed group elements by which to transform. + register_name: The name of the register to transform with. Raises: SamplexConstructionError: If ``operand`` has more than one sample. @@ -56,7 +56,7 @@ def get_style(self): return super().get_style().append_data("Fixed Operand", repr(self._operand)) -class LeftMultiplicationNode(MultiplicationNode): +class LeftMultiplicationNode(TransformationNode): """Perform left multiplication of a fixed register against a given register. Args: @@ -71,7 +71,7 @@ def evaluate(self, registers: dict[RegisterName, VirtualRegister], *_): registers[self._register_name].left_inplace_multiply(self._operand) -class RightMultiplicationNode(MultiplicationNode): +class RightMultiplicationNode(TransformationNode): """Perform right multiplication of a fixed register against a given register. Args: @@ -84,3 +84,39 @@ class RightMultiplicationNode(MultiplicationNode): def evaluate(self, registers: dict[RegisterName, VirtualRegister], *_): registers[self._register_name].inplace_multiply(self._operand) + + +class LeftConjugationNode(TransformationNode): + """Perform left conjugation of a fixed register against a given register. + + Performs operand*reg*(operand^{\dagger}). + + Args: + operand: The fixed group elements by which to conjugate. + register_name: The name of the register to conjugate with. + + Raises: + SamplexConstructionError: If ``operand`` has more than one sample. + """ + + def evaluate(self, registers: dict[RegisterName, VirtualRegister], *_): + registers[self._register_name].left_inplace_multiply(self._operand) + registers[self._register_name].inplace_multiply(self._operand.invert()) + + +class RightConjugationNode(TransformationNode): + """Perform right conjugation of a fixed register against a given register. + + Performs (operand^{\dagger})*reg*operand. + + Args: + operand: The fixed group elements by which to conjugate. + register_name: The name of the register to conjugate with. + + Raises: + SamplexConstructionError: If ``operand`` has more than one sample. + """ + + def evaluate(self, registers: dict[RegisterName, VirtualRegister], *_): + registers[self._register_name].left_inplace_multiply(self._operand.invert()) + registers[self._register_name].inplace_multiply(self._operand) diff --git a/samplomatic/samplex/nodes/u2_param_multiplication_node.py b/samplomatic/samplex/nodes/u2_param_multiplication_node.py index c5c0172e..cca0f4fb 100644 --- a/samplomatic/samplex/nodes/u2_param_multiplication_node.py +++ b/samplomatic/samplex/nodes/u2_param_multiplication_node.py @@ -188,7 +188,7 @@ class LeftU2ParametricConjugationNode(U2ParametricTransformationNode): """Perform parametric left conjugation on a :class:`~.U2Register`. The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the - original circuit and performs a :math:`g*reg*g^{\dagger}` conjugation, where :math:`reg` is + original circuit and performs a :math:`g*reg*g^{\dagger}` conjugation, where :math:`reg` is the existing ``U2Register``. This is consistent with a traveling virtual gate going from right to left. @@ -233,8 +233,8 @@ class RightU2ParametricConjugationNode(U2ParametricTransformationNode): """Perform parametric right conjugation on a :class:`~.U2Register`. The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the - original circuit and performs a :math:`g^{\dagger}*reg*g` conjugation, where :math:`reg` is - the existing ``U2Register``. This is consistent with a traveling virtual going going from + original circuit and performs a :math:`g^{\dagger}*reg*g` conjugation, where :math:`reg` is + the existing ``U2Register``. This is consistent with a traveling virtual gate going from left to right. :math:`g` is limited to the gates ``rz`` or ``rx``, and all gates within the node are of the @@ -269,6 +269,6 @@ def evaluate( f"Expected {self.num_parameters} parameter values instead got " f"{len(parameter_values)}" ) - + registers[self._register_name].left_inplace_multiply(self._get_operation(-parameter_values)) - registers[self._register_name].inplace_multiply(self._get_operation(parameter_values)) \ No newline at end of file + registers[self._register_name].inplace_multiply(self._get_operation(parameter_values)) diff --git a/test/integration/test_dynamic_circuits.py b/test/integration/test_dynamic_circuits.py index aeefb46b..58fce4ae 100644 --- a/test/integration/test_dynamic_circuits.py +++ b/test/integration/test_dynamic_circuits.py @@ -233,6 +233,7 @@ def test_non_twirled_conditional(self, save_plot): circuit.measure(0, 0) with circuit.if_test((circuit.clbits[0], 1)) as _else: circuit.sx(1) + circuit.rz(1.2, 1) with _else: circuit.x(1) circuit.measure_all() diff --git a/test/integration/test_static_twirling_samples.py b/test/integration/test_static_twirling_samples.py index 104b2eff..1c7f1c04 100644 --- a/test/integration/test_static_twirling_samples.py +++ b/test/integration/test_static_twirling_samples.py @@ -35,6 +35,20 @@ def make_circuits(): yield circuit, f"{op_name}_{str(pair1).replace(' ', '')}_{str(pair2).replace(' ', '')}" + circuit = QuantumCircuit(2) + with circuit.box([Twirl()]): + circuit.noop(0, 1) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0, 1) + circuit.cx(0, 1) + circuit.rz(1.2, 0) + circuit.h(1) + circuit.sx(0) + circuit.x(1) + circuit.rx(1.2, 1) + + yield circuit, "instructions_outside_boxes_chain" + circuit = QuantumCircuit(4) with circuit.box([Twirl(dressing="left")]): circuit.cx(1, 0) @@ -252,6 +266,18 @@ def make_circuits(): yield circuit, "nonclifford_between_left_left_boxes" + circuit = QuantumCircuit(2) + with circuit.box([Twirl(dressing="left")]): + circuit.noop(0, 1) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0, 1) + circuit.rx(1.2, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.cx(0, 1) + circuit.rx(1.2, 1) + + yield circuit, "nonclifford_between_right_right_boxes" + def pytest_generate_tests(metafunc): if "circuit" in metafunc.fixturenames: diff --git a/test/unit/test_samplex/test_nodes/test_multiplication_node.py b/test/unit/test_samplex/test_nodes/test_multiplication_node.py index 0e0ed7ef..fc41f179 100644 --- a/test/unit/test_samplex/test_nodes/test_multiplication_node.py +++ b/test/unit/test_samplex/test_nodes/test_multiplication_node.py @@ -18,7 +18,12 @@ from samplomatic.annotations import VirtualType from samplomatic.distributions import HaarU2, UniformPauli from samplomatic.exceptions import SamplexConstructionError -from samplomatic.samplex.nodes import LeftMultiplicationNode, RightMultiplicationNode +from samplomatic.samplex.nodes import ( + LeftConjugationNode, + LeftMultiplicationNode, + RightConjugationNode, + RightMultiplicationNode, +) from samplomatic.virtual_registers import U2Register @@ -80,3 +85,73 @@ def test_writes_to(self): node = RightMultiplicationNode(U2Register.identity(3, 1), "a") assert node.writes_to() == {"a": ({0, 1, 2}, VirtualType.U2)} assert node.outgoing_register_type is VirtualType.U2 + + +class TestRightConjugationNode: + def test_instantiation_errors(self): + """Test that errors are properly raised during instantiation""" + with pytest.raises( + SamplexConstructionError, + match=re.escape("Expected fixed operand to have only one sample but it has 7"), + ): + RightConjugationNode(U2Register.identity(5, 7), "a") + + @pytest.mark.parametrize("distribution_type", [HaarU2, UniformPauli]) + def test_multiply(self, distribution_type, rng): + """Test left multiply""" + operand = distribution_type(5).sample(1, rng) + register = distribution_type(5).sample(7, rng) + node = RightConjugationNode(operand, "a") + assert node.outgoing_register_type is operand.TYPE + + registers = {"a": register.copy()} + node.evaluate(registers, []) + expected_registers = {"a": register.copy()} + node = LeftMultiplicationNode(operand.invert(), "a") + node.evaluate(expected_registers) + node = RightMultiplicationNode(operand, "a") + node.evaluate(expected_registers) + + assert list(registers) == ["a"] + assert np.allclose(expected_registers["a"].virtual_gates, registers["a"].virtual_gates) + + def test_writes_to(self): + """Test writes to""" + node = RightConjugationNode(U2Register.identity(3, 1), "a") + assert node.writes_to() == {"a": ({0, 1, 2}, VirtualType.U2)} + assert node.outgoing_register_type is VirtualType.U2 + + +class TestLeftConjugationNode: + def test_instantiation_errors(self): + """Test that errors are properly raised during instantiation""" + with pytest.raises( + SamplexConstructionError, + match=re.escape("Expected fixed operand to have only one sample but it has 7"), + ): + LeftConjugationNode(U2Register.identity(5, 7), "a") + + @pytest.mark.parametrize("distribution_type", [HaarU2, UniformPauli]) + def test_multiply(self, distribution_type, rng): + """Test left multiply""" + operand = distribution_type(5).sample(1, rng) + register = distribution_type(5).sample(7, rng) + node = LeftConjugationNode(operand, "a") + assert node.outgoing_register_type is operand.TYPE + + registers = {"a": register.copy()} + node.evaluate(registers, []) + expected_registers = {"a": register.copy()} + node = LeftMultiplicationNode(operand, "a") + node.evaluate(expected_registers) + node = RightMultiplicationNode(operand.invert(), "a") + node.evaluate(expected_registers) + + assert list(registers) == ["a"] + assert np.allclose(expected_registers["a"].virtual_gates, registers["a"].virtual_gates) + + def test_writes_to(self): + """Test writes to""" + node = LeftConjugationNode(U2Register.identity(3, 1), "a") + assert node.writes_to() == {"a": ({0, 1, 2}, VirtualType.U2)} + assert node.outgoing_register_type is VirtualType.U2 From a35fbb97b847a812c335689a019cc427b989c4c7 Mon Sep 17 00:00:00 2001 From: Tsafrir Armon Date: Thu, 14 Aug 2025 11:53:44 +0300 Subject: [PATCH 5/6] Error if there's a non-clifford between left-right boxes. --- samplomatic/builders/build.py | 26 ++++++++- .../passthrough_template_builder.py | 7 +++ samplomatic/pre_samplex/pre_samplex.py | 1 + .../test_parametric_twirling_samples.py | 15 +---- .../test_static_twirling_samples.py | 11 ---- .../test_general_build_errors.py | 55 +++++++++++++++++++ .../test_template_builder/test_build.py | 6 +- 7 files changed, 92 insertions(+), 29 deletions(-) create mode 100644 test/unit/test_builders/test_general_build_errors.py diff --git a/samplomatic/builders/build.py b/samplomatic/builders/build.py index 21ca0e50..e5aee01c 100644 --- a/samplomatic/builders/build.py +++ b/samplomatic/builders/build.py @@ -17,12 +17,18 @@ from qiskit.circuit import QuantumCircuit from ..aliases import CircuitInstruction +from ..exceptions import SamplexBuildError from ..pre_samplex import PreSamplex from ..samplex import Samplex from .builder import Builder from .get_builders import get_builders from .specs import InstructionSpec -from .template_builder import TemplateState +from .template_builder import ( + LeftBoxTemplateBuilder, + PassthroughTemplateBuilder, + RightBoxTemplateBuilder, + TemplateState, +) def _build_stream( @@ -69,6 +75,24 @@ def _build( for idx, nested_instr in enumerate(_build_stream(stream, template_builder, samplex_builder)): # assume the nested instruction is a box for now, handle other control flow ops later inner_template_builder, inner_samplex_builder = get_builders(nested_instr) + if isinstance(template_builder, PassthroughTemplateBuilder): + if isinstance(inner_template_builder, LeftBoxTemplateBuilder): + # Upcoming box is left-dressed, so when we get back to the passthrough builder we + # need to track non-cliffords, in case then next box is right dressed. + template_builder.track_noncliffords = True + template_builder.found_noncliffords = False + elif ( + isinstance(inner_template_builder, RightBoxTemplateBuilder) + and template_builder.found_noncliffords + ): + raise SamplexBuildError( + "Cannot have non-clifford gate between a left-dressed box" + " and a right-dressed box (which involve that qubit)." + ) + else: + template_builder.track_noncliffords = False + template_builder.found_noncliffords = False + qubit_remapping = dict(zip(nested_instr.operation.body.qubits, nested_instr.qubits)) remapped_template_state = template_builder.state.remap(qubit_remapping, idx) diff --git a/samplomatic/builders/template_builder/passthrough_template_builder.py b/samplomatic/builders/template_builder/passthrough_template_builder.py index 49502a41..1108517b 100644 --- a/samplomatic/builders/template_builder/passthrough_template_builder.py +++ b/samplomatic/builders/template_builder/passthrough_template_builder.py @@ -23,6 +23,11 @@ class PassthroughTemplateBuilder(Builder[TemplateState, InstructionSpec]): """Template builder that passes all instructions through.""" + def __init__(self): + super().__init__() + self.track_noncliffords = False + self.found_noncliffords = False + def parse(self, instr: CircuitInstruction) -> InstructionSpec: """Parse a single non-box instruction. @@ -53,6 +58,8 @@ def parse(self, instr: CircuitInstruction) -> InstructionSpec: clbit_idxs=self.state.get_condition_clbits(instr.operation.condition), ) else: + if self.track_noncliffords and instr.operation.name in ("rx", "rz"): + self.found_noncliffords = True return InstructionSpec( params=self.state.append_remapped_gate(instr), mode=InstructionMode.PROPAGATE ) diff --git a/samplomatic/pre_samplex/pre_samplex.py b/samplomatic/pre_samplex/pre_samplex.py index 2c208c16..63168a88 100644 --- a/samplomatic/pre_samplex/pre_samplex.py +++ b/samplomatic/pre_samplex/pre_samplex.py @@ -1362,6 +1362,7 @@ def add_propagate_node( np.array(list(pre_propagate.partition), dtype=np.intp), ) elif mode is InstructionMode.PROPAGATE: + # What's left are the supported non-clifford gates rz\rx combined_register_type = VirtualType.U2 if pre_propagate.operation.is_parameterized(): param_idxs = [ diff --git a/test/integration/test_parametric_twirling_samples.py b/test/integration/test_parametric_twirling_samples.py index 05f4596a..e3a97090 100644 --- a/test/integration/test_parametric_twirling_samples.py +++ b/test/integration/test_parametric_twirling_samples.py @@ -99,20 +99,6 @@ def make_circuits(): yield circuit, "parameterized_nonclifford_between_right_right_boxes" - circuit = QuantumCircuit(1) - with circuit.box([Twirl(dressing="left")]): - # circuit.x(0) - # circuit.rz(1.2, 0) - circuit.noop(0) - circuit.rx(Parameter("a"), 0) - circuit.rz(Parameter("b"), 0) - with circuit.box([Twirl(dressing="right")]): - # circuit.sx(0) - # circuit.rz(1.5, 0) - circuit.noop(0) - - yield circuit, "parameterized_nonclifford_between_left_right_boxes" - circuit = QuantumCircuit(2) with circuit.box([Twirl(dressing="left")]): circuit.rz(1.2, 0) @@ -126,6 +112,7 @@ def make_circuits(): yield circuit, "parameterized_nonclifford_between_left_left_boxes" + def pytest_generate_tests(metafunc): if "circuit" in metafunc.fixturenames: circuits, descriptions = zip(*make_circuits()) diff --git a/test/integration/test_static_twirling_samples.py b/test/integration/test_static_twirling_samples.py index 1c7f1c04..ca849f99 100644 --- a/test/integration/test_static_twirling_samples.py +++ b/test/integration/test_static_twirling_samples.py @@ -243,17 +243,6 @@ def make_circuits(): yield circuit, "propagate_through_merged_invariant_gates" - circuit = QuantumCircuit(1) - with circuit.box([Twirl(dressing="left")]): - circuit.x(0) - circuit.rz(1.2, 0) - circuit.rx(1.2, 0) - with circuit.box([Twirl(dressing="right")]): - circuit.sx(0) - circuit.rz(1.5, 0) - - yield circuit, "nonclifford_between_left_right_boxes" - circuit = QuantumCircuit(2) with circuit.box([Twirl(dressing="left")]): circuit.rz(1.2, 0) diff --git a/test/unit/test_builders/test_general_build_errors.py b/test/unit/test_builders/test_general_build_errors.py new file mode 100644 index 00000000..82f64d9c --- /dev/null +++ b/test/unit/test_builders/test_general_build_errors.py @@ -0,0 +1,55 @@ +# This code is a Qiskit project. +# +# (C) Copyright IBM 2025. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Testing of error raising during building process. + +Some build errors are hard to replicate without going through the entire build process. +This file is meant for such cases.""" + +import pytest +from qiskit.circuit import Parameter, QuantumCircuit + +from samplomatic import Twirl +from samplomatic.builders import pre_build +from samplomatic.exceptions import SamplexBuildError + + +class TestGeneralBuildErrors: + def _pre_build_and_assert_error(self, circuit, error_type, message): + with pytest.raises(error_type, match=message): + pre_build(circuit) + + def test_nonclifford_between_left_right_boxes(self): + circuit = QuantumCircuit(1) + with circuit.box([Twirl(dressing="left")]): + circuit.noop(0) + circuit.rx(1.2, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0) + self._pre_build_and_assert_error( + circuit, + SamplexBuildError, + "Cannot have non-clifford gate between a left-dressed box and a right-dressed box", + ) + + def test_parametric_nonclifford_between_left_right_boxes(self): + circuit = QuantumCircuit(1) + with circuit.box([Twirl(dressing="left")]): + circuit.noop(0) + circuit.rz(Parameter("a"), 0) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0) + self._pre_build_and_assert_error( + circuit, + SamplexBuildError, + "Cannot have non-clifford gate between a left-dressed box and a right-dressed box", + ) diff --git a/test/unit/test_builders/test_template_builder/test_build.py b/test/unit/test_builders/test_template_builder/test_build.py index a6145ba9..c2f96238 100644 --- a/test/unit/test_builders/test_template_builder/test_build.py +++ b/test/unit/test_builders/test_template_builder/test_build.py @@ -131,7 +131,7 @@ def test_box_decomposition(self): with circuit.box([Twirl(decomposition="rzrx")]): circuit.cx(0, 1) - circuit.rz(Parameter("c"), 0) + circuit.x(0) with circuit.box([Twirl(decomposition="rzsx", dressing="right")]): circuit.cx(0, 1) @@ -143,7 +143,7 @@ def test_box_decomposition(self): assert template.num_qubits == 2 assert template.num_clbits == 2 - assert [p.name for p in template.parameters] == [f"p{str(i).zfill(2)}" for i in range(13)] + assert [p.name for p in template.parameters] == [f"p{str(i).zfill(2)}" for i in range(12)] expected_names = ["h"] @@ -151,7 +151,7 @@ def test_box_decomposition(self): expected_names += ["rz", "rx", "rz"] * 2 + ["barrier", "cx"] expected_names += ["barrier"] - expected_names += ["rz"] + expected_names += ["x"] expected_names += ["barrier"] expected_names += ["cx", "barrier"] + ["rz", "sx", "rz", "sx", "rz"] * 2 From 09b64241a60f7d979451ba2ef3e6d0023570e680 Mon Sep 17 00:00:00 2001 From: Tsafrir Armon Date: Thu, 14 Aug 2025 13:53:27 +0300 Subject: [PATCH 6/6] pre-commit --- .../test_u2_param_multiplication_node.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py b/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py index c503f504..d2a88e45 100644 --- a/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py +++ b/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py @@ -20,10 +20,10 @@ from samplomatic.distributions import HaarU2 from samplomatic.exceptions import SamplexConstructionError, SamplexRuntimeError from samplomatic.samplex.nodes import ( - LeftU2ParametricMultiplicationNode, - RightU2ParametricMultiplicationNode, LeftU2ParametricConjugationNode, + LeftU2ParametricMultiplicationNode, RightU2ParametricConjugationNode, + RightU2ParametricMultiplicationNode, ) X_MATRIX = np.array([[0, 1], [1, 0]]) @@ -116,6 +116,7 @@ def test_left_multiply(self, gate, matrix, rng): registers["a"].virtual_gates, np.matmul(register.virtual_gates, operation) ) + class TestLeftU2ParamConjugationNode: def test_instantiation_errors(self): """Test that errors are properly raised during instantiation""" @@ -153,9 +154,8 @@ def test_left_conjugation(self, gate, rng): node = RightU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) node.evaluate(expected_registers, -params) - assert np.allclose( - registers["a"].virtual_gates, expected_registers["a"].virtual_gates - ) + assert np.allclose(registers["a"].virtual_gates, expected_registers["a"].virtual_gates) + class TestRightU2ParamConjugationNode: def test_instantiation_errors(self): @@ -194,6 +194,4 @@ def test_right_conjugation(self, gate, rng): node = RightU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) node.evaluate(expected_registers, params) - assert np.allclose( - registers["a"].virtual_gates, expected_registers["a"].virtual_gates - ) \ No newline at end of file + assert np.allclose(registers["a"].virtual_gates, expected_registers["a"].virtual_gates)