diff --git a/samplomatic/builders/build.py b/samplomatic/builders/build.py index 21ca0e50..e5aee01c 100644 --- a/samplomatic/builders/build.py +++ b/samplomatic/builders/build.py @@ -17,12 +17,18 @@ from qiskit.circuit import QuantumCircuit from ..aliases import CircuitInstruction +from ..exceptions import SamplexBuildError from ..pre_samplex import PreSamplex from ..samplex import Samplex from .builder import Builder from .get_builders import get_builders from .specs import InstructionSpec -from .template_builder import TemplateState +from .template_builder import ( + LeftBoxTemplateBuilder, + PassthroughTemplateBuilder, + RightBoxTemplateBuilder, + TemplateState, +) def _build_stream( @@ -69,6 +75,24 @@ def _build( for idx, nested_instr in enumerate(_build_stream(stream, template_builder, samplex_builder)): # assume the nested instruction is a box for now, handle other control flow ops later inner_template_builder, inner_samplex_builder = get_builders(nested_instr) + if isinstance(template_builder, PassthroughTemplateBuilder): + if isinstance(inner_template_builder, LeftBoxTemplateBuilder): + # Upcoming box is left-dressed, so when we get back to the passthrough builder we + # need to track non-cliffords, in case then next box is right dressed. + template_builder.track_noncliffords = True + template_builder.found_noncliffords = False + elif ( + isinstance(inner_template_builder, RightBoxTemplateBuilder) + and template_builder.found_noncliffords + ): + raise SamplexBuildError( + "Cannot have non-clifford gate between a left-dressed box" + " and a right-dressed box (which involve that qubit)." + ) + else: + template_builder.track_noncliffords = False + template_builder.found_noncliffords = False + qubit_remapping = dict(zip(nested_instr.operation.body.qubits, nested_instr.qubits)) remapped_template_state = template_builder.state.remap(qubit_remapping, idx) diff --git a/samplomatic/builders/template_builder/passthrough_template_builder.py b/samplomatic/builders/template_builder/passthrough_template_builder.py index 49502a41..1108517b 100644 --- a/samplomatic/builders/template_builder/passthrough_template_builder.py +++ b/samplomatic/builders/template_builder/passthrough_template_builder.py @@ -23,6 +23,11 @@ class PassthroughTemplateBuilder(Builder[TemplateState, InstructionSpec]): """Template builder that passes all instructions through.""" + def __init__(self): + super().__init__() + self.track_noncliffords = False + self.found_noncliffords = False + def parse(self, instr: CircuitInstruction) -> InstructionSpec: """Parse a single non-box instruction. @@ -53,6 +58,8 @@ def parse(self, instr: CircuitInstruction) -> InstructionSpec: clbit_idxs=self.state.get_condition_clbits(instr.operation.condition), ) else: + if self.track_noncliffords and instr.operation.name in ("rx", "rz"): + self.found_noncliffords = True return InstructionSpec( params=self.state.append_remapped_gate(instr), mode=InstructionMode.PROPAGATE ) diff --git a/samplomatic/builders/template_builder/template_state.py b/samplomatic/builders/template_builder/template_state.py index 4b4697d0..43229e33 100644 --- a/samplomatic/builders/template_builder/template_state.py +++ b/samplomatic/builders/template_builder/template_state.py @@ -96,9 +96,13 @@ def append_remapped_gate( new_params = [] param_mapping = [] - for param in instr.operation.params: - param_mapping.append([self.param_iter.idx, param]) - new_params.append(next(self.param_iter)) + if instr.operation.is_parameterized(): + # Note: It is assumed here that if is_parameterized() is true, then all parameters + # are ParameterExpressions. This is true for now because all of our parametrized + # gates have a single parameter. + for param in instr.operation.params: + param_mapping.append([self.param_iter.idx, param]) + new_params.append(next(self.param_iter)) new_qubits = [self.qubit_map.get(qubit, qubit) for qubit in instr.qubits] new_operation = type(instr.operation)(*new_params) if new_params else instr.operation @@ -122,9 +126,13 @@ def remap_subcircuit(self, circuit: QuantumCircuit) -> tuple[QuantumCircuit, Par ] new_params = [] instr_param_mapping = [] - for param in instr.operation.params: - instr_param_mapping.append([self.param_iter.idx, param]) - new_params.append(next(self.param_iter)) + if instr.operation.is_parameterized(): + # Note: It is assumed here that if is_parameterized() is true, then all parameters + # are ParameterExpressions. This is true for now because all of our parametrized + # gates have a single parameter. + for param in instr.operation.params: + instr_param_mapping.append([self.param_iter.idx, param]) + new_params.append(next(self.param_iter)) new_operation = type(instr.operation)(*new_params) if new_params else instr.operation remapped_circuit.append(CircuitInstruction(new_operation, new_qubits, instr.clbits)) diff --git a/samplomatic/pre_samplex/pre_samplex.py b/samplomatic/pre_samplex/pre_samplex.py index 3f296331..63168a88 100644 --- a/samplomatic/pre_samplex/pre_samplex.py +++ b/samplomatic/pre_samplex/pre_samplex.py @@ -62,10 +62,14 @@ CollectZ2ToOutputNode, CombineRegistersNode, InjectNoiseNode, + LeftConjugationNode, LeftMultiplicationNode, + LeftU2ParametricConjugationNode, LeftU2ParametricMultiplicationNode, PauliPastCliffordNode, + RightConjugationNode, RightMultiplicationNode, + RightU2ParametricConjugationNode, RightU2ParametricMultiplicationNode, SliceRegisterNode, TwirlSamplingNode, @@ -1357,6 +1361,28 @@ def add_propagate_node( combined_register_name, np.array(list(pre_propagate.partition), dtype=np.intp), ) + elif mode is InstructionMode.PROPAGATE: + # What's left are the supported non-clifford gates rz\rx + combined_register_type = VirtualType.U2 + if pre_propagate.operation.is_parameterized(): + param_idxs = [ + samplex.append_parameter_expression(param) + for _, param in pre_propagate.spec.params + ] + if pre_propagate.direction is Direction.LEFT: + propagate_node = RightU2ParametricConjugationNode( + op_name, combined_register_name, param_idxs + ) + else: + propagate_node = LeftU2ParametricConjugationNode( + op_name, combined_register_name, param_idxs + ) + else: + operand = U2Register(np.array(pre_propagate.operation).reshape(1, 1, 2, 2)) + if pre_propagate.direction is Direction.LEFT: + propagate_node = RightConjugationNode(operand, combined_register_name) + else: + propagate_node = LeftConjugationNode(operand, combined_register_name) else: raise SamplexBuildError( f"Encountered unsupported {op_name} propragation with mode {mode} and " @@ -1377,8 +1403,6 @@ def add_propagate_node( node_idx = samplex.add_node(propagate_node) samplex.add_edge(combine_node_idx, node_idx) else: - # TODO: It should be possible to not add a slice node in this case, if there is - # a single predecessor. node_idx = combine_node_idx pre_nodes_to_nodes[pre_propagate_idx] = node_idx diff --git a/samplomatic/samplex/nodes/__init__.py b/samplomatic/samplex/nodes/__init__.py index 75d74bd1..7cc1bd72 100644 --- a/samplomatic/samplex/nodes/__init__.py +++ b/samplomatic/samplex/nodes/__init__.py @@ -18,14 +18,21 @@ from .conversion_node import ConversionNode from .evaluation_node import EvaluationNode from .inject_noise_node import InjectNoiseNode -from .multiplication_node import LeftMultiplicationNode, RightMultiplicationNode +from .multiplication_node import ( + LeftConjugationNode, + LeftMultiplicationNode, + RightConjugationNode, + RightMultiplicationNode, +) from .node import Node from .pauli_past_clifford_node import PauliPastCliffordNode from .sampling_node import SamplingNode from .slice_register_node import SliceRegisterNode from .twirl_sampling_node import TwirlSamplingNode from .u2_param_multiplication_node import ( + LeftU2ParametricConjugationNode, LeftU2ParametricMultiplicationNode, + RightU2ParametricConjugationNode, RightU2ParametricMultiplicationNode, - U2ParametricMultiplicationNode, + U2ParametricTransformationNode, ) diff --git a/samplomatic/samplex/nodes/multiplication_node.py b/samplomatic/samplex/nodes/multiplication_node.py index bf41df83..8f6e2e73 100644 --- a/samplomatic/samplex/nodes/multiplication_node.py +++ b/samplomatic/samplex/nodes/multiplication_node.py @@ -19,12 +19,12 @@ from .evaluation_node import EvaluationNode -class MultiplicationNode(EvaluationNode): - """Abstract parent for nodes that perform multiplication against a fixed register. +class TransformationNode(EvaluationNode): + """Abstract parent for nodes that perform transformation against a fixed register. Args: - operand: The fixed group elements by which to multiply. - register_name: The name of the register to multiply with. + operand: The fixed group elements by which to transform. + register_name: The name of the register to transform with. Raises: SamplexConstructionError: If ``operand`` has more than one sample. @@ -56,7 +56,7 @@ def get_style(self): return super().get_style().append_data("Fixed Operand", repr(self._operand)) -class LeftMultiplicationNode(MultiplicationNode): +class LeftMultiplicationNode(TransformationNode): """Perform left multiplication of a fixed register against a given register. Args: @@ -71,7 +71,7 @@ def evaluate(self, registers: dict[RegisterName, VirtualRegister], *_): registers[self._register_name].left_inplace_multiply(self._operand) -class RightMultiplicationNode(MultiplicationNode): +class RightMultiplicationNode(TransformationNode): """Perform right multiplication of a fixed register against a given register. Args: @@ -84,3 +84,39 @@ class RightMultiplicationNode(MultiplicationNode): def evaluate(self, registers: dict[RegisterName, VirtualRegister], *_): registers[self._register_name].inplace_multiply(self._operand) + + +class LeftConjugationNode(TransformationNode): + """Perform left conjugation of a fixed register against a given register. + + Performs operand*reg*(operand^{\dagger}). + + Args: + operand: The fixed group elements by which to conjugate. + register_name: The name of the register to conjugate with. + + Raises: + SamplexConstructionError: If ``operand`` has more than one sample. + """ + + def evaluate(self, registers: dict[RegisterName, VirtualRegister], *_): + registers[self._register_name].left_inplace_multiply(self._operand) + registers[self._register_name].inplace_multiply(self._operand.invert()) + + +class RightConjugationNode(TransformationNode): + """Perform right conjugation of a fixed register against a given register. + + Performs (operand^{\dagger})*reg*operand. + + Args: + operand: The fixed group elements by which to conjugate. + register_name: The name of the register to conjugate with. + + Raises: + SamplexConstructionError: If ``operand`` has more than one sample. + """ + + def evaluate(self, registers: dict[RegisterName, VirtualRegister], *_): + registers[self._register_name].left_inplace_multiply(self._operand.invert()) + registers[self._register_name].inplace_multiply(self._operand) diff --git a/samplomatic/samplex/nodes/u2_param_multiplication_node.py b/samplomatic/samplex/nodes/u2_param_multiplication_node.py index 6aa221ff..cca0f4fb 100644 --- a/samplomatic/samplex/nodes/u2_param_multiplication_node.py +++ b/samplomatic/samplex/nodes/u2_param_multiplication_node.py @@ -23,11 +23,11 @@ from .evaluation_node import EvaluationNode -class U2ParametricMultiplicationNode(EvaluationNode): - """Abstract parent node for nodes doing multiplication on a :class:`~.U2Register`. +class U2ParametricTransformationNode(EvaluationNode): + """Abstract parent node for nodes doing transformation on a :class:`~.U2Register`. The node stores a parametric representation of a one-qubit gate or gates from the - original circuit to perform multiplication on a registry. + original circuit to perform transformation on a registry. Limited to the gates ``rz`` or ``rx``, and all gates within the node are of the same type. @@ -98,7 +98,7 @@ def _get_operation(self, parameter_values: np.ndarray) -> U2Register: return U2Register(result) -class LeftU2ParametricMultiplicationNode(U2ParametricMultiplicationNode): +class LeftU2ParametricMultiplicationNode(U2ParametricTransformationNode): """Perform parametric left multiplication on a :class:`~.U2Register`. The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the @@ -141,7 +141,7 @@ def evaluate( registers[self._register_name].left_inplace_multiply(self._get_operation(parameter_values)) -class RightU2ParametricMultiplicationNode(U2ParametricMultiplicationNode): +class RightU2ParametricMultiplicationNode(U2ParametricTransformationNode): """Perform parametric right multiplication on a :class:`~.U2Register`. The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the @@ -182,3 +182,93 @@ def evaluate( ) registers[self._register_name].inplace_multiply(self._get_operation(parameter_values)) + + +class LeftU2ParametricConjugationNode(U2ParametricTransformationNode): + """Perform parametric left conjugation on a :class:`~.U2Register`. + + The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the + original circuit and performs a :math:`g*reg*g^{\dagger}` conjugation, where :math:`reg` is + the existing ``U2Register``. This is consistent with a traveling virtual gate going from + right to left. + + :math:`g` is limited to the gates ``rz`` or ``rx``, and all gates within the node are of the + same type:math:. + + Args: + operand: The gate type, given as a string. + register_name: The name of the register the operation is applied to. + param_idxs: List of ``ParamIndex`` for the parameter expressions specifying the gate + arguments. List order must match the order subsystems in the register. + + Raises: + SamplexConstructionError: if `param_idxs` is empty. + """ + + def evaluate( + self, registers: dict[RegisterName, VirtualRegister], parameter_values: np.ndarray + ): + """Evaluate this node. + + Args: + registers: At least those registers needed by this node to read from or write to. + parameter_values: The evaluated values of the parameter expressions in indices + ``self.parameter_idxs``, at the same order. + + Raises: + SamplexRuntimeError: If the number of parameter values doesn't match the number of + parameter expressions in ``self.parameter_idxs``. + """ + if len(parameter_values) != self.num_parameters: + raise SamplexRuntimeError( + f"Expected {self.num_parameters} parameter values instead got " + f"{len(parameter_values)}" + ) + + registers[self._register_name].left_inplace_multiply(self._get_operation(parameter_values)) + registers[self._register_name].inplace_multiply(self._get_operation(-parameter_values)) + + +class RightU2ParametricConjugationNode(U2ParametricTransformationNode): + """Perform parametric right conjugation on a :class:`~.U2Register`. + + The node stores a parametric representation :math:`g` of a one-qubit gate or gates from the + original circuit and performs a :math:`g^{\dagger}*reg*g` conjugation, where :math:`reg` is + the existing ``U2Register``. This is consistent with a traveling virtual gate going from + left to right. + + :math:`g` is limited to the gates ``rz`` or ``rx``, and all gates within the node are of the + same type:math:. + + Args: + operand: The gate type, given as a string. + register_name: The name of the register the operation is applied to. + param_idxs: List of ``ParamIndex`` for the parameter expressions specifying the gate + arguments. List order must match the order subsystems in the register. + + Raises: + SamplexConstructionError: if `param_idxs` is empty. + """ + + def evaluate( + self, registers: dict[RegisterName, VirtualRegister], parameter_values: np.ndarray + ): + """Evaluate this node. + + Args: + registers: At least those registers needed by this node to read from or write to. + parameter_values: The evaluated values of the parameter expressions in indices + ``self.parameter_idxs``, at the same order. + + Raises: + SamplexRuntimeError: If the number of parameter values doesn't match the number of + parameter expressions in ``self.parameter_idxs``. + """ + if len(parameter_values) != self.num_parameters: + raise SamplexRuntimeError( + f"Expected {self.num_parameters} parameter values instead got " + f"{len(parameter_values)}" + ) + + registers[self._register_name].left_inplace_multiply(self._get_operation(-parameter_values)) + registers[self._register_name].inplace_multiply(self._get_operation(parameter_values)) diff --git a/test/integration/test_dynamic_circuits.py b/test/integration/test_dynamic_circuits.py index aeefb46b..58fce4ae 100644 --- a/test/integration/test_dynamic_circuits.py +++ b/test/integration/test_dynamic_circuits.py @@ -233,6 +233,7 @@ def test_non_twirled_conditional(self, save_plot): circuit.measure(0, 0) with circuit.if_test((circuit.clbits[0], 1)) as _else: circuit.sx(1) + circuit.rz(1.2, 1) with _else: circuit.x(1) circuit.measure_all() diff --git a/test/integration/test_parametric_twirling_samples.py b/test/integration/test_parametric_twirling_samples.py index ab926cdd..e3a97090 100644 --- a/test/integration/test_parametric_twirling_samples.py +++ b/test/integration/test_parametric_twirling_samples.py @@ -84,6 +84,34 @@ def make_circuits(): yield circuit, "parametric_right_box" + circuit = QuantumCircuit(1) + with circuit.box([Twirl(dressing="left")]): + circuit.h(0) + circuit.rz(-1.2, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.x(0) + circuit.rz(1.2, 0) + circuit.rx(Parameter("a"), 0) + circuit.rz(Parameter("b"), 0) + with circuit.box([Twirl(dressing="right")]): + circuit.sx(0) + circuit.rz(1.5, 0) + + yield circuit, "parameterized_nonclifford_between_right_right_boxes" + + circuit = QuantumCircuit(2) + with circuit.box([Twirl(dressing="left")]): + circuit.rz(1.2, 0) + circuit.cx(0, 1) + circuit.rx(Parameter("a"), 0) + circuit.rz(Parameter("b"), 0) + with circuit.box([Twirl(dressing="left")]): + circuit.cx(1, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0, 1) + + yield circuit, "parameterized_nonclifford_between_left_left_boxes" + def pytest_generate_tests(metafunc): if "circuit" in metafunc.fixturenames: diff --git a/test/integration/test_static_twirling_samples.py b/test/integration/test_static_twirling_samples.py index 911bc6f4..ca849f99 100644 --- a/test/integration/test_static_twirling_samples.py +++ b/test/integration/test_static_twirling_samples.py @@ -35,6 +35,20 @@ def make_circuits(): yield circuit, f"{op_name}_{str(pair1).replace(' ', '')}_{str(pair2).replace(' ', '')}" + circuit = QuantumCircuit(2) + with circuit.box([Twirl()]): + circuit.noop(0, 1) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0, 1) + circuit.cx(0, 1) + circuit.rz(1.2, 0) + circuit.h(1) + circuit.sx(0) + circuit.x(1) + circuit.rx(1.2, 1) + + yield circuit, "instructions_outside_boxes_chain" + circuit = QuantumCircuit(4) with circuit.box([Twirl(dressing="left")]): circuit.cx(1, 0) @@ -229,6 +243,30 @@ def make_circuits(): yield circuit, "propagate_through_merged_invariant_gates" + circuit = QuantumCircuit(2) + with circuit.box([Twirl(dressing="left")]): + circuit.rz(1.2, 0) + circuit.cx(0, 1) + circuit.rx(1.2, 0) + with circuit.box([Twirl(dressing="left")]): + circuit.cx(1, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0, 1) + + yield circuit, "nonclifford_between_left_left_boxes" + + circuit = QuantumCircuit(2) + with circuit.box([Twirl(dressing="left")]): + circuit.noop(0, 1) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0, 1) + circuit.rx(1.2, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.cx(0, 1) + circuit.rx(1.2, 1) + + yield circuit, "nonclifford_between_right_right_boxes" + def pytest_generate_tests(metafunc): if "circuit" in metafunc.fixturenames: diff --git a/test/unit/test_builders/test_general_build_errors.py b/test/unit/test_builders/test_general_build_errors.py new file mode 100644 index 00000000..82f64d9c --- /dev/null +++ b/test/unit/test_builders/test_general_build_errors.py @@ -0,0 +1,55 @@ +# This code is a Qiskit project. +# +# (C) Copyright IBM 2025. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Testing of error raising during building process. + +Some build errors are hard to replicate without going through the entire build process. +This file is meant for such cases.""" + +import pytest +from qiskit.circuit import Parameter, QuantumCircuit + +from samplomatic import Twirl +from samplomatic.builders import pre_build +from samplomatic.exceptions import SamplexBuildError + + +class TestGeneralBuildErrors: + def _pre_build_and_assert_error(self, circuit, error_type, message): + with pytest.raises(error_type, match=message): + pre_build(circuit) + + def test_nonclifford_between_left_right_boxes(self): + circuit = QuantumCircuit(1) + with circuit.box([Twirl(dressing="left")]): + circuit.noop(0) + circuit.rx(1.2, 0) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0) + self._pre_build_and_assert_error( + circuit, + SamplexBuildError, + "Cannot have non-clifford gate between a left-dressed box and a right-dressed box", + ) + + def test_parametric_nonclifford_between_left_right_boxes(self): + circuit = QuantumCircuit(1) + with circuit.box([Twirl(dressing="left")]): + circuit.noop(0) + circuit.rz(Parameter("a"), 0) + with circuit.box([Twirl(dressing="right")]): + circuit.noop(0) + self._pre_build_and_assert_error( + circuit, + SamplexBuildError, + "Cannot have non-clifford gate between a left-dressed box and a right-dressed box", + ) diff --git a/test/unit/test_builders/test_template_builder/test_build.py b/test/unit/test_builders/test_template_builder/test_build.py index a6145ba9..c2f96238 100644 --- a/test/unit/test_builders/test_template_builder/test_build.py +++ b/test/unit/test_builders/test_template_builder/test_build.py @@ -131,7 +131,7 @@ def test_box_decomposition(self): with circuit.box([Twirl(decomposition="rzrx")]): circuit.cx(0, 1) - circuit.rz(Parameter("c"), 0) + circuit.x(0) with circuit.box([Twirl(decomposition="rzsx", dressing="right")]): circuit.cx(0, 1) @@ -143,7 +143,7 @@ def test_box_decomposition(self): assert template.num_qubits == 2 assert template.num_clbits == 2 - assert [p.name for p in template.parameters] == [f"p{str(i).zfill(2)}" for i in range(13)] + assert [p.name for p in template.parameters] == [f"p{str(i).zfill(2)}" for i in range(12)] expected_names = ["h"] @@ -151,7 +151,7 @@ def test_box_decomposition(self): expected_names += ["rz", "rx", "rz"] * 2 + ["barrier", "cx"] expected_names += ["barrier"] - expected_names += ["rz"] + expected_names += ["x"] expected_names += ["barrier"] expected_names += ["cx", "barrier"] + ["rz", "sx", "rz", "sx", "rz"] * 2 diff --git a/test/unit/test_samplex/test_nodes/test_multiplication_node.py b/test/unit/test_samplex/test_nodes/test_multiplication_node.py index 0e0ed7ef..fc41f179 100644 --- a/test/unit/test_samplex/test_nodes/test_multiplication_node.py +++ b/test/unit/test_samplex/test_nodes/test_multiplication_node.py @@ -18,7 +18,12 @@ from samplomatic.annotations import VirtualType from samplomatic.distributions import HaarU2, UniformPauli from samplomatic.exceptions import SamplexConstructionError -from samplomatic.samplex.nodes import LeftMultiplicationNode, RightMultiplicationNode +from samplomatic.samplex.nodes import ( + LeftConjugationNode, + LeftMultiplicationNode, + RightConjugationNode, + RightMultiplicationNode, +) from samplomatic.virtual_registers import U2Register @@ -80,3 +85,73 @@ def test_writes_to(self): node = RightMultiplicationNode(U2Register.identity(3, 1), "a") assert node.writes_to() == {"a": ({0, 1, 2}, VirtualType.U2)} assert node.outgoing_register_type is VirtualType.U2 + + +class TestRightConjugationNode: + def test_instantiation_errors(self): + """Test that errors are properly raised during instantiation""" + with pytest.raises( + SamplexConstructionError, + match=re.escape("Expected fixed operand to have only one sample but it has 7"), + ): + RightConjugationNode(U2Register.identity(5, 7), "a") + + @pytest.mark.parametrize("distribution_type", [HaarU2, UniformPauli]) + def test_multiply(self, distribution_type, rng): + """Test left multiply""" + operand = distribution_type(5).sample(1, rng) + register = distribution_type(5).sample(7, rng) + node = RightConjugationNode(operand, "a") + assert node.outgoing_register_type is operand.TYPE + + registers = {"a": register.copy()} + node.evaluate(registers, []) + expected_registers = {"a": register.copy()} + node = LeftMultiplicationNode(operand.invert(), "a") + node.evaluate(expected_registers) + node = RightMultiplicationNode(operand, "a") + node.evaluate(expected_registers) + + assert list(registers) == ["a"] + assert np.allclose(expected_registers["a"].virtual_gates, registers["a"].virtual_gates) + + def test_writes_to(self): + """Test writes to""" + node = RightConjugationNode(U2Register.identity(3, 1), "a") + assert node.writes_to() == {"a": ({0, 1, 2}, VirtualType.U2)} + assert node.outgoing_register_type is VirtualType.U2 + + +class TestLeftConjugationNode: + def test_instantiation_errors(self): + """Test that errors are properly raised during instantiation""" + with pytest.raises( + SamplexConstructionError, + match=re.escape("Expected fixed operand to have only one sample but it has 7"), + ): + LeftConjugationNode(U2Register.identity(5, 7), "a") + + @pytest.mark.parametrize("distribution_type", [HaarU2, UniformPauli]) + def test_multiply(self, distribution_type, rng): + """Test left multiply""" + operand = distribution_type(5).sample(1, rng) + register = distribution_type(5).sample(7, rng) + node = LeftConjugationNode(operand, "a") + assert node.outgoing_register_type is operand.TYPE + + registers = {"a": register.copy()} + node.evaluate(registers, []) + expected_registers = {"a": register.copy()} + node = LeftMultiplicationNode(operand, "a") + node.evaluate(expected_registers) + node = RightMultiplicationNode(operand.invert(), "a") + node.evaluate(expected_registers) + + assert list(registers) == ["a"] + assert np.allclose(expected_registers["a"].virtual_gates, registers["a"].virtual_gates) + + def test_writes_to(self): + """Test writes to""" + node = LeftConjugationNode(U2Register.identity(3, 1), "a") + assert node.writes_to() == {"a": ({0, 1, 2}, VirtualType.U2)} + assert node.outgoing_register_type is VirtualType.U2 diff --git a/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py b/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py index 28e9079b..d2a88e45 100644 --- a/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py +++ b/test/unit/test_samplex/test_nodes/test_u2_param_multiplication_node.py @@ -20,7 +20,9 @@ from samplomatic.distributions import HaarU2 from samplomatic.exceptions import SamplexConstructionError, SamplexRuntimeError from samplomatic.samplex.nodes import ( + LeftU2ParametricConjugationNode, LeftU2ParametricMultiplicationNode, + RightU2ParametricConjugationNode, RightU2ParametricMultiplicationNode, ) @@ -113,3 +115,83 @@ def test_left_multiply(self, gate, matrix, rng): assert np.allclose( registers["a"].virtual_gates, np.matmul(register.virtual_gates, operation) ) + + +class TestLeftU2ParamConjugationNode: + def test_instantiation_errors(self): + """Test that errors are properly raised during instantiation""" + with pytest.raises( + SamplexConstructionError, + match="Expected at least one element in param_idxs", + ): + LeftU2ParametricConjugationNode("rz", "a", []) + + def test_evaluation_errors(self): + """Test that errors are properly raised during evaluation""" + node = LeftU2ParametricConjugationNode("rz", "a", [0, 1, 2]) + registers = {} + with pytest.raises( + SamplexRuntimeError, match=re.escape("Expected 3 parameter values instead got 1") + ): + node.evaluate(registers, [1]) + + @pytest.mark.parametrize("gate", ["rx", "rz"]) + def test_left_conjugation(self, gate, rng): + """Test left conjugation""" + register_shape = [5, 2] + + params = rng.random(register_shape[0]) * 2 * np.pi + haar_dist = HaarU2(register_shape[0]) + node = LeftU2ParametricConjugationNode(gate, "a", [i for i in range(register_shape[0])]) + register = haar_dist.sample(register_shape[1], rng) + registers = {"a": register.copy()} + node.evaluate(registers, params) + + # g*reg*(g^dagger) + expected_registers = {"a": register.copy()} + node = LeftU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) + node.evaluate(expected_registers, params) + node = RightU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) + node.evaluate(expected_registers, -params) + + assert np.allclose(registers["a"].virtual_gates, expected_registers["a"].virtual_gates) + + +class TestRightU2ParamConjugationNode: + def test_instantiation_errors(self): + """Test that errors are properly raised during instantiation""" + with pytest.raises( + SamplexConstructionError, + match="Expected at least one element in param_idxs", + ): + RightU2ParametricConjugationNode("rz", "a", []) + + def test_evaluation_errors(self): + """Test that errors are properly raised during evaluation""" + node = RightU2ParametricConjugationNode("rz", "a", [0, 1, 2]) + registers = {} + with pytest.raises( + SamplexRuntimeError, match=re.escape("Expected 3 parameter values instead got 1") + ): + node.evaluate(registers, [1]) + + @pytest.mark.parametrize("gate", ["rx", "rz"]) + def test_right_conjugation(self, gate, rng): + """Test right conjugation""" + register_shape = [5, 2] + + params = rng.random(register_shape[0]) * 2 * np.pi + haar_dist = HaarU2(register_shape[0]) + node = RightU2ParametricConjugationNode(gate, "a", [i for i in range(register_shape[0])]) + register = haar_dist.sample(register_shape[1], rng) + registers = {"a": register.copy()} + node.evaluate(registers, params) + + # (g^dagger)*reg*g + expected_registers = {"a": register.copy()} + node = LeftU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) + node.evaluate(expected_registers, -params) + node = RightU2ParametricMultiplicationNode(gate, "a", [i for i in range(register_shape[0])]) + node.evaluate(expected_registers, params) + + assert np.allclose(registers["a"].virtual_gates, expected_registers["a"].virtual_gates)